Files
sjy01-image-proc/pkg/utils/interp_lagrange.go
2024-07-12 11:34:02 +08:00

121 lines
2.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package utils
import (
"fmt"
"sort"
"github.com/nuknal/goNum"
)
// LagrangeInterpolator fits a single Lagrange interpolating polynomial to a
// set of samples and evaluates it. The polynomial is stored in the power
// (monomial) basis so that evaluation does not need the original samples.
type LagrangeInterpolator struct {
	coeffs []float64 // coeffs[k] is the coefficient of x^k (ascending powers)
	n      int       // degree of the fitted polynomial (len(coeffs)-1 after a successful Fit)
}

// Fit computes the power-basis coefficients of the Lagrange interpolating
// polynomial through the points (x[i], y[i]).
//
// The polynomial degree is capped at 9, so at most the first 10 points are
// used and any further points are ignored. x and y must have the same,
// non-zero length, and the x values actually used must be pairwise distinct
// (otherwise the Lagrange basis is undefined). On error the interpolator is
// left unchanged.
func (li *LagrangeInterpolator) Fit(x []float64, y []float64) error {
	const maxDegree = 9 // cap the polynomial degree at 9

	if len(x) == 0 || len(y) != len(x) {
		return fmt.Errorf("invalid input data")
	}
	degree := len(x) - 1
	if degree > maxDegree {
		degree = maxDegree
	}
	n := degree + 1
	xs, ys := x[:n], y[:n]

	// Reject repeated nodes up front: they would make a denominator
	// (x_i - x_j) zero and fill the coefficients with Inf/NaN.
	for i := 1; i < n; i++ {
		for j := 0; j < i; j++ {
			if xs[i] == xs[j] {
				return fmt.Errorf("invalid input data: duplicate x value %v at indices %d and %d", xs[i], j, i)
			}
		}
	}

	coeffs := make([]float64, n) // make zero-initialises
	for i := 0; i < n; i++ {
		// Build the basis polynomial L_i(t) = prod_{j!=i} (t - x_j)/(x_i - x_j)
		// in ascending powers, starting from the constant polynomial 1.
		basis := make([]float64, n)
		basis[0] = 1
		for j := 0; j < n; j++ {
			if j == i {
				continue
			}
			// Multiply by (t - x_j): new[k] = old[k-1] - x_j*old[k].
			// Walking k downward keeps old[k-1] unmodified when it is read.
			for k := n - 1; k >= 0; k-- {
				basis[k] *= -xs[j]
				if k > 0 {
					basis[k] += basis[k-1]
				}
			}
			d := xs[i] - xs[j]
			for k := range basis {
				basis[k] /= d
			}
		}
		for k := range coeffs {
			coeffs[k] += ys[i] * basis[k]
		}
	}

	// Commit only after every step has succeeded.
	li.coeffs = coeffs
	li.n = degree
	return nil
}

// Predict evaluates the fitted polynomial at x. An interpolator without
// coefficients (never successfully fitted) returns 0.
func (li LagrangeInterpolator) Predict(x float64) float64 {
	// Horner's scheme: one multiply-add per coefficient instead of
	// recomputing x^k from scratch for every term.
	y := 0.0
	for k := len(li.coeffs) - 1; k >= 0; k-- {
		y = y*x + li.coeffs[k]
	}
	return y
}

// N returns the degree of the polynomial produced by the most recent
// successful Fit (at most 9). It is 0 for a zero-value interpolator.
func (li LagrangeInterpolator) N() int {
	return li.n
}
// STEP_N is the number of samples (interpolation nodes) InterpLagrange takes
// around the query point, i.e. it interpolates with a polynomial of degree
// STEP_N-1.
// NOTE(review): earlier comments here described 9th-order interpolation,
// but the value is 7 (degree 6); confirm which is intended.
const STEP_N = 7

// InterpLagrange estimates the function value at xq by Lagrange
// interpolation through the STEP_N samples (x[i], y[i]) that
// FindClosestSubset selects around xq.
//
// x is expected to be sorted in ascending order, because FindClosestSubset
// locates xq with a binary search. It returns 0.0 when len(x) != len(y) or
// the input is empty. With a single sample it returns y[0], the constant
// (degree-0) interpolant.
func InterpLagrange(x []float64, y []float64, xq float64) float64 {
	if len(x) != len(y) || len(x) == 0 {
		return 0.0
	}
	if len(x) == 1 {
		// One point defines the constant interpolant. Handle it here rather
		// than passing a single-row matrix to goNum.
		return y[0]
	}

	start, end := FindClosestSubset(x, xq, STEP_N)

	// Interleave x0, y0, x1, y1, ... so that each row of the (m x 2) matrix is
	// one sample (x_i, y_i). This assumes goNum.NewMatrix fills row-major,
	// as the original code did.
	data := make([]float64, 0, 2*(end-start+1))
	for i := start; i <= end; i++ {
		data = append(data, x[i], y[i])
	}
	A := goNum.NewMatrix(len(data)/2, 2, data)

	// This function has no error result, so goNum's second (status) result
	// cannot be propagated and is discarded as before.
	// NOTE(review): confirm what goNum.InterpLagrange returns for yq on
	// failure.
	yq, _ := goNum.InterpLagrange(A, xq)
	return yq
}
// FindClosestSubset returns the inclusive index range [start, end] of a window
// of n consecutive elements of x positioned around xq. It is used to pick the
// interpolation nodes near a query point.
//
// x must be sorted in ascending order, because the insertion point of xq is
// found with sort.Search. The window is centred on that insertion point and
// shifted inward when it would run past either end of x. When len(x) > n the
// window therefore always contains exactly n elements.
//
// Special cases:
//   - If len(x) <= n, the whole slice is returned, i.e. (0, len(x)-1). For an
//     empty x this is the empty range (0, -1).
//   - If n <= 0 (with a non-empty x), the empty range (0, -1) is returned.
func FindClosestSubset(x []float64, xq float64, n int) (int, int) {
	if len(x) <= n {
		return 0, len(x) - 1 // n or fewer elements: use them all
	}
	if n <= 0 {
		return 0, -1
	}

	// idx is the first position with x[idx] >= xq, so xq lies between
	// x[idx-1] and x[idx].
	idx := sort.Search(len(x), func(i int) bool { return x[i] >= xq })

	// Put roughly half the window on each side of xq, then clamp the window
	// into [0, len(x)-n] so it keeps exactly n elements.
	start := idx - n/2
	if start < 0 {
		start = 0
	}
	if start > len(x)-n {
		start = len(x) - n
	}
	return start, start + n - 1
}