Files
sjy01-image-proc/pkg/producer/phase_correlation.go
2024-11-01 15:52:53 +08:00

140 lines
3.2 KiB
Go

package producer
import (
"errors"
"image"
"math/cmplx"
"github.com/duke-git/lancet/v2/slice"
"github.com/mjibson/go-dsp/fft"
log "github.com/sirupsen/logrus"
"gocv.io/x/gocv"
)
// PhaseShiftM holds one phase-correlation measurement together with the
// block it was computed for.
type PhaseShiftM struct {
	dx, dy   float32 // measured shift along x and y (presumably pixels; confirm with callers)
	response float64 // response value that accompanied the shift
	Block    Block   // block this measurement belongs to
}
// Block describes a rectangular region of an image by its size and position.
type Block struct {
	width, height int
	coord         image.Point // top-left corner of the block in the original image
}
// CV_PhaseCorrelate estimates the shift between panBlock and mssBlock using
// gocv.PhaseCorrelate with a Hanning window.
//
// Both inputs are first converted into temporary CV_32FC1 matrices, and the
// window is created with the size and type of the converted pan block. The
// shift and response reported by gocv.PhaseCorrelate are logged at debug
// level and returned unchanged.
func CV_PhaseCorrelate(panBlock, mssBlock gocv.Mat) (gocv.Point2f, float64) {
	panF32 := gocv.NewMat()
	defer panF32.Close()
	mssF32 := gocv.NewMat()
	defer mssF32.Close()

	panBlock.ConvertTo(&panF32, gocv.MatTypeCV32FC1)
	mssBlock.ConvertTo(&mssF32, gocv.MatTypeCV32FC1)

	rows, cols, matType := panF32.Rows(), panF32.Cols(), panF32.Type()
	window := gocv.NewMatWithSize(rows, cols, matType)
	defer window.Close()
	gocv.CreateHanningWindow(&window, image.Pt(cols, rows), matType)

	shift, response := gocv.PhaseCorrelate(panF32, mssF32, window)
	log.Debugf("Block shift: dx = %f, dy = %f. response = %f", shift.X, shift.Y, response)
	return shift, response
}
// fileterPhaseShift prunes the stored phase-shift measurements in place.
//
// thredholds[i] is the reference dy for the group r.phaseShifts[i]. A
// measurement is kept only when its response does not exceed 0.999999 and its
// dy lies strictly within 20 of that reference value. Groups with no
// corresponding threshold are left untouched.
//
// An error is returned, and nothing is modified, when more than 4 thresholds
// are supplied or when there are more thresholds than phase-shift groups.
func (r *ImgProc) fileterPhaseShift(thredholds []float64) error {
	const (
		maxThresholds = 4
		maxResponse   = 0.999999
		dyTolerance   = 20
	)

	if len(thredholds) > maxThresholds {
		return errors.New("thresholds length should be at most 4")
	}
	// Guard the indexing below so a short phaseShifts yields an error
	// instead of an index-out-of-range panic.
	if len(thredholds) > len(r.phaseShifts) {
		return errors.New("more thresholds than phase shift groups")
	}

	for i, th := range thredholds {
		lower, upper := float32(th-dyTolerance), float32(th+dyTolerance)
		r.phaseShifts[i] = slice.Filter(r.phaseShifts[i], func(_ int, m PhaseShiftM) bool {
			// NOTE(review): near-perfect responses are discarded; presumably
			// they come from degenerate blocks. Confirm the intent.
			if m.response > maxResponse {
				return false
			}
			return m.dy > lower && m.dy < upper
		})
	}
	return nil
}
// PhaseCorrelate estimates the shift between panBlock and mssBlock with a
// phase correlation built on the go-dsp FFT routines.
//
// The blocks are transformed with toFreqDomain, the cross-power spectrum
// pan * conj(mss) is normalized to unit magnitude, and the inverse 2-D FFT
// of that spectrum is searched for its largest positive real value. The
// (x, y) index of that peak is returned as the shift; the returned response
// is always 0.0.
//
// When either block is empty, the block sizes differ, or a transform fails,
// the problem is logged and a zero shift with response 0.0 is returned.
//
// NOTE(review): the peak index is returned as-is, i.e. in [0, width) x
// [0, height). It is not wrapped into a signed shift and not refined to
// sub-pixel precision, so results are not directly comparable with
// CV_PhaseCorrelate. Confirm what callers expect.
func PhaseCorrelate(panBlock, mssBlock gocv.Mat) (gocv.Point2f, float64) {
	rows, cols := mssBlock.Rows(), mssBlock.Cols()
	if rows == 0 || cols == 0 {
		log.Error("phase correlate: empty block")
		return gocv.Point2f{}, 0.0
	}
	// Both spectra are indexed with the same bounds below, so a size mismatch
	// would either panic or silently correlate unrelated frequency bins.
	if panBlock.Rows() != rows || panBlock.Cols() != cols {
		log.Errorf("phase correlate: block size mismatch: pan %dx%d, mss %dx%d",
			panBlock.Cols(), panBlock.Rows(), cols, rows)
		return gocv.Point2f{}, 0.0
	}

	// Compute the Fourier transforms.
	panFreqDomain, _, _, err := toFreqDomain(panBlock)
	if err != nil {
		log.Error(err)
		return gocv.Point2f{}, 0.0
	}
	mssFreqDomain, width, height, err := toFreqDomain(mssBlock)
	if err != nil {
		log.Error(err)
		return gocv.Point2f{}, 0.0
	}

	// Compute the conjugate product and normalize it to unit magnitude so
	// that only phase information remains. Bins where either spectrum (or
	// their product) is zero stay zero.
	crossPowerSpectrum := make([][]complex128, height)
	for y := range crossPowerSpectrum {
		row := make([]complex128, width)
		for x := range row {
			p, m := panFreqDomain[y][x], mssFreqDomain[y][x]
			if p == 0 || m == 0 {
				continue
			}
			cross := p * cmplx.Conj(m)
			if cross == 0 {
				continue
			}
			row[x] = cross / complex(cmplx.Abs(cross), 0)
		}
		crossPowerSpectrum[y] = row
	}

	ifftResult := fft.IFFT2(crossPowerSpectrum)

	// Find the maximum and the shift parameters it corresponds to. Only
	// strictly positive values can become the peak; otherwise (0, 0) is kept.
	peak := 0.0
	var dx, dy float64
	for y, row := range ifftResult {
		for x, v := range row {
			if re := real(v); re > peak {
				peak = re
				dx = float64(x)
				dy = float64(y)
			}
		}
	}

	log.Debugf("Block shift: dx = %f, dy = %f. response = %f", dx, dy, 0.0)
	return gocv.Point2f{X: float32(dx), Y: float32(dy)}, 0.0
}
// toFreqDomain converts input into a complex-valued matrix and returns its
// 2-D forward FFT.
//
// Every sample is read with GetShortAt and reinterpreted as an unsigned
// 16-bit value; it becomes the real part of the matrix entry, with a zero
// imaginary part.
// NOTE(review): this assumes input is a single-channel 16-bit matrix.
// Confirm with callers.
//
// It returns the spectrum indexed as [row][col], followed by the width
// (columns) and height (rows) of input. When input has no rows, an error is
// returned together with a nil spectrum and zero dimensions, rather than
// handing fft.FFT2 an empty matrix.
func toFreqDomain(input gocv.Mat) ([][]complex128, int, int, error) {
	height, width := input.Rows(), input.Cols()
	if height <= 0 {
		return nil, 0, 0, errors.New("toFreqDomain: input matrix has no rows")
	}

	samples := make([][]complex128, height)
	for y := range samples {
		row := make([]complex128, width)
		for x := range row {
			// The uint16 conversion reinterprets the signed short so values
			// above 32767 do not turn negative.
			row[x] = complex(float64(uint16(input.GetShortAt(y, x))), 0)
		}
		samples[y] = row
	}

	return fft.FFT2(samples), width, height, nil
}