Files
sjy01-image-proc/pkg/producer/image_registration.go
2024-10-29 16:30:02 +08:00

507 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package producer
import (
"fmt"
"image"
"image/color"
"math"
"os"
"sync"
"github.com/airbusgeo/godal"
log "github.com/sirupsen/logrus"
"gocv.io/x/gocv"
"starwiz.cn/sjy01/image-proc/pkg/auxilary"
"starwiz.cn/sjy01/image-proc/pkg/payload"
)
// Registrate is declared as an empty interface; it specifies no methods.
// NOTE(review): no use of this type is visible in this file — confirm it is still needed.
type Registrate interface{}
// Package-level constants describing the raw image layout, the block grid used
// for phase correlation, and the supported resampling methods.
const (
// MssBands is the number of MSS bands stored in a raw MSS file.
MssBands = 4
// PixelBytes is the size of one raw pixel in bytes (pixels are loaded as 16-bit unsigned).
PixelBytes = 2
PanWidth = payload.PAN_PIXEL_WIDTH // PAN width in pixels
// MssWidth is the MSS width in pixels.
MssWidth = payload.MSS_PIXEL_WIDTH
// BlockNH is the number of blocks along the image height used for phase correlation.
BlockNH = 8
// BlockNW is the number of blocks along the image width used for phase correlation.
BlockNW = 4
OverlappedBlockLines = 1000 // number of extra rows appended to each block so neighbouring blocks overlap
// DownSampled selects registration after down-sampling PAN to the MSS size.
DownSampled ResampleMethod = "down_sample_pan"
// UpSampled selects registration after up-sampling MSS to the PAN size.
UpSampled ResampleMethod = "up_sample_mss"
PanResolution = 1.3 // mm/pixel
// MssResolution is the MSS counterpart of PanResolution.
// NOTE(review): presumably in the same unit as PanResolution — confirm.
MssResolution = 5.2
// ReferenceShiftYB1..ReferenceShiftYB4 are not referenced in the visible part of this file.
// NOTE(review): presumably per-band reference Y shifts — confirm against their users.
ReferenceShiftYB1 = 100.0
ReferenceShiftYB2 = 200.0
ReferenceShiftYB3 = 300.0
ReferenceShiftYB4 = 400.0
)
// ResampleMethod names the resampling strategy used by a Registrator;
// the defined values are DownSampled and UpSampled.
type ResampleMethod string
// Registrator holds the PAN and MSS images together with the intermediate
// results (per-block phase shifts, fitted distortion coefficients, registered
// bands) produced while co-registering the MSS bands with the PAN image.
//
// It contains a sync.Mutex and must therefore be used through a pointer.
type Registrator struct {
Params Params
// PanImage is the raw PAN image; PanHeight/PanWidth are its recorded size in pixels.
PanImage gocv.Mat
PanHeight int
PanWidth int
// MssImages holds one Mat per MSS band; MssHeight/MssWidth are their recorded size in pixels.
MssImages [4]gocv.Mat
MssHeight int
MssWidth int
// shiftMutex guards concurrent appends to phaseShifts.
shiftMutex sync.Mutex
// phaseShifts stores, per band, the phase-correlation result of every block.
phaseShifts [4][]PhaseShiftM
deltaXCoeffs [4][]float64 // linear model of the intra-image distortion: X-direction deformation as a function of the horizontal position
deltaYCoeffs [4][]float64 // non-linear model of the intra-image distortion: Y-direction deformation as a function of the horizontal position
registeredMssImages [4]gocv.Mat // MSS images after registration
rgbirImage gocv.Mat
// resampleMethod selects between the down-sampled PAN and up-sampled MSS workflows.
resampleMethod ResampleMethod
auxHeads []*auxilary.AuxFrameHead
auxBoxes []*auxilary.AuxFocalBox
AuxPlatforms []*auxilary.AuxPlatform
GPSs *auxilary.GPSs
AttQuaternion *auxilary.Attitudes
ImageTime *auxilary.ImageTime
report Report
}
// NewRegistrator returns a Registrator configured to use the given resampling
// method. All other fields are left at their zero values.
func NewRegistrator(rsmethod ResampleMethod) *Registrator {
	return &Registrator{resampleMethod: rsmethod}
}
// LoadPanRaw reads the raw PAN file named by r.Params.PanRawFile and wraps its
// bytes in a 16-bit unsigned Mat that is PanWidth pixels wide.
//
// The number of rows is derived from the file size; a trailing partial row is
// ignored. On success PanImage, PanHeight and PanWidth are populated. The
// method then registers the GDAL drivers and prints the capabilities of the
// GTiff driver.
//
// It returns an error when the file cannot be read, when it holds less than
// one complete row, when the Mat cannot be created, or when the GTiff driver
// is not available (this last case used to panic).
func (r *Registrator) LoadPanRaw() error {
	data, err := os.ReadFile(r.Params.PanRawFile)
	if err != nil {
		log.Error("Error reading raw file: ", err)
		return err
	}

	const rowBytes = PanWidth * PixelBytes
	height := len(data) / rowBytes
	if height == 0 {
		// A zero-height image would only fail much later (resize / region
		// extraction), so reject it here with a clear message.
		err = fmt.Errorf("PAN raw file holds %d bytes, less than one row of %d bytes", len(data), rowBytes)
		log.Error(err)
		return err
	}

	r.PanImage, err = gocv.NewMatFromBytes(height, PanWidth, gocv.MatTypeCV16U, data)
	if err != nil {
		log.Error("Error creating Mat from bytes: ", err)
		return err
	}
	r.PanHeight = height
	r.PanWidth = PanWidth

	godal.RegisterAll()
	const driverName = "Gtiff"
	hDriver, ok := godal.RasterDriver(driverName)
	if !ok {
		// Report the missing driver to the caller instead of panicking.
		err = fmt.Errorf("GDAL raster driver %q not found", driverName)
		log.Error(err)
		return err
	}

	md := hDriver.Metadatas()
	if md["DCAP_CREATE"] == "YES" {
		fmt.Printf("Driver GTiff supports Create() method.\n")
	}
	if md["DCAP_CREATECOPY"] == "YES" {
		fmt.Printf("Driver Gtiff supports CreateCopy() method.\n")
	}
	fmt.Println("Gtiff driver name:", hDriver.LongName(), hDriver.ShortName())
	return nil
}
// LoadMssRaw reads the raw MSS file named by r.Params.MssRawFile and splits it
// into one 16-bit unsigned Mat per band.
//
// The file is read as a sequence of lines; each line stores MssBands
// consecutive band rows of MssWidth pixels. The number of lines is derived
// from the file size and a trailing partial line is ignored. On success
// MssImages, MssHeight and MssWidth are populated.
//
// It returns an error when the file cannot be read, when it holds less than
// one complete line, or when a band Mat cannot be created.
func (r *Registrator) LoadMssRaw() error {
	data, err := os.ReadFile(r.Params.MssRawFile)
	if err != nil {
		log.Error("Error reading raw file: ", err)
		return err
	}

	const (
		bandRowBytes = MssWidth * PixelBytes   // one row of a single band
		lineBytes    = bandRowBytes * MssBands // one line holding all bands
	)
	height := len(data) / lineBytes
	if height == 0 {
		err = fmt.Errorf("MSS raw file holds %d bytes, less than one line of %d bytes", len(data), lineBytes)
		log.Error(err)
		return err
	}

	for band := 0; band < MssBands; band++ {
		// The final size of each band plane is known up front, so allocate it
		// once instead of growing it row by row.
		plane := make([]byte, 0, height*bandRowBytes)
		for h := 0; h < height; h++ {
			start := h*lineBytes + band*bandRowBytes
			plane = append(plane, data[start:start+bandRowBytes]...)
		}
		r.MssImages[band], err = gocv.NewMatFromBytes(height, MssWidth, gocv.MatTypeCV16U, plane)
		if err != nil {
			log.Error("Error creating Mat from bytes: ", err)
			return err
		}
	}
	r.MssHeight = height
	r.MssWidth = MssWidth
	return nil
}
// DoPhaseCorrelation runs the registration workflow selected by the
// Registrator's resample method: the up-sampled MSS workflow for UpSampled,
// and the down-sampled PAN workflow for every other value.
func (r *Registrator) DoPhaseCorrelation() error {
	if r.resampleMethod == UpSampled {
		return r.CalcUpPhaseCorrelation()
	}
	return r.CalcDownPhaseCorrelation()
}
// CalcDownPhaseCorrelation computes phase-correlation offsets after
// down-sampling the PAN image to the MSS size, then registers the MSS bands.
//
// The workflow has two passes:
//  1. the four MSS bands are correlated against band 1, the shifts are
//     filtered, the distortion coefficients are fitted and the bands are
//     co-registered;
//  2. the registered band 1 is correlated against the down-sampled PAN image
//     and all registered bands are aligned with PAN.
//
// It returns an error when the MSS height is not exactly a quarter of the PAN
// height, or when any step that reports an error fails.
func (r *Registrator) CalcDownPhaseCorrelation() error {
	// The MSS height must be exactly 1/4 of the PAN height.
	if r.MssHeight*4 != r.PanHeight {
		err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
		log.Error(err)
		return err
	}

	// Down-sample PAN; it is the reference image of the second pass.
	downsampledPanImage := gocv.NewMat()
	defer downsampledPanImage.Close()
	gocv.Resize(r.PanImage, &downsampledPanImage,
		image.Point{X: r.MssWidth, Y: r.MssHeight}, 0, 0, gocv.InterpolationCubic)
	log.Println("down sampled PAN images size:", downsampledPanImage.Size())

	// Block size. It is computed once, before the co-registration below
	// changes r.MssHeight, and reused by both passes.
	blockHeight := r.MssHeight / BlockNH
	blockWidth := r.MssWidth / BlockNW

	// Pass 1: register the 4 MSS bands against band 1.
	err := r.doPhaseCorrelation(r.MssImages[0],
		[]gocv.Mat{r.MssImages[0], r.MssImages[1], r.MssImages[2], r.MssImages[3]},
		r.MssHeight, r.MssWidth, blockHeight, blockWidth)
	if err != nil {
		return err
	}
	r.fileterPhaseShift([]float64{64, 64, 64, 64}, true)
	// Propagate failures instead of silently dropping the returned errors.
	if err = r.calcMSSDeltaCoeffs(4); err != nil {
		return err
	}
	if err = r.DoMSSCoRegistration(false); err != nil {
		return err
	}

	// Optional extra pass (disabled): register once more after edge detection.
	// var mssEdges []gocv.Mat
	// for band := 0; band < len(r.registeredMssImages); band++ {
	// edge := CV_Sobel(r.registeredMssImages[band])
	// mssEdges = append(mssEdges, edge)
	// }
	// r.doPhaseCorrelation(mssEdges[0], mssEdges,
	// r.MssHeight, r.MssWidth, blockHeight, blockWidth)
	// r.fileterPhaseShift([]float64{5, 5, 5, 5}, false)
	// r.calcMSSDeltaCoeffs(4)
	// r.DoMSSCoRegistration(true)

	// Pass 2: register against the (down-sampled) PAN image.
	err = r.doPhaseCorrelation(downsampledPanImage,
		[]gocv.Mat{r.registeredMssImages[0]},
		r.MssHeight, r.MssWidth, blockHeight, blockWidth)
	if err != nil {
		return err
	}
	r.fileterPhaseShift([]float64{30.0}, true)
	if err = r.calcMSSDeltaCoeffs(1); err != nil {
		return err
	}
	return r.DoPANCoRegistration()
}
// CalcUpPhaseCorrelation computes phase-correlation offsets after up-sampling
// the MSS bands to the PAN size.
//
// The method is currently disabled: its first statement is a logrus Fatal
// call, which logs the message and terminates the process, so the code below
// it does not run.
// NOTE(review): consider returning an error instead of terminating the process.
// NOTE(review): unlike CalcDownPhaseCorrelation, this workflow never calls
// fileterPhaseShift or calcMSSDeltaCoeffs — confirm before enabling it.
func (r *Registrator) CalcUpPhaseCorrelation() error {
	log.Fatal("unsuppotted up-resample method")

	// The MSS height must be exactly 1/4 of the PAN height.
	if r.MssHeight*4 != r.PanHeight {
		err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
		log.Error(err)
		return err
	}

	// Register the 4 MSS bands against band 1.
	err := r.doPhaseCorrelation(r.MssImages[0],
		[]gocv.Mat{r.MssImages[0], r.MssImages[1], r.MssImages[2], r.MssImages[3]},
		r.MssHeight, r.MssWidth, r.MssHeight/BlockNH, r.MssWidth/BlockNW)
	if err != nil {
		return err
	}
	if err = r.DoMSSCoRegistration(false); err != nil {
		return err
	}

	upsampledMssImages := make([]gocv.Mat, MssBands)
	// The up-sampled copies are temporaries; release them on every exit path.
	defer func() {
		for i := range upsampledMssImages {
			upsampledMssImages[i].Close()
		}
	}()
	for i := 0; i < MssBands; i++ {
		upsampledMssImages[i] = gocv.NewMat()
		gocv.Resize(r.registeredMssImages[i], &upsampledMssImages[i],
			image.Point{X: r.PanWidth, Y: r.PanHeight}, 0, 0, gocv.InterpolationCubic)
	}
	fmt.Println("up sampled MSS images size:", upsampledMssImages[0].Size())

	// Block size - BlockNH, BlockNW % 4 == 0
	blockHeight := r.PanHeight / BlockNH
	blockWidth := r.PanWidth / BlockNW
	log.Infof("blockHeight=%d, blockWidth=%d", blockHeight, blockWidth)

	// Register against the PAN image. Both images are PAN-sized here, so the
	// blocks are clipped against the PAN dimensions, not the MSS ones.
	err = r.doPhaseCorrelation(r.PanImage,
		[]gocv.Mat{upsampledMssImages[0]},
		r.PanHeight, r.PanWidth, blockHeight, blockWidth)
	if err != nil {
		return err
	}
	return r.DoPANCoRegistration()
}
// doPhaseCorrelation splits the images into a BlockNH x BlockNW grid and, for
// every block, measures the phase shift between base and each image in
// mssImages. Blocks are processed concurrently, one goroutine per block.
//
// Parameters:
//   - base: reference image.
//   - mssImages: images to compare with base; element i is stored as band i.
//     At most len(r.phaseShifts) images are accepted.
//   - height, width: bounds the block rectangles are clipped to.
//   - blockHeight, blockWidth: nominal block size; every block is extended
//     downwards by OverlappedBlockLines rows.
//
// For every band index below len(mssImages) the previous phase shifts and
// delta coefficients are discarded. The new shifts are stored in row-major
// block order, so the result does not depend on goroutine scheduling. Blocks
// that are empty after clipping are skipped.
//
// It returns an error only when too many images are supplied.
func (r *Registrator) doPhaseCorrelation(base gocv.Mat,
	mssImages []gocv.Mat,
	height, width,
	blockHeight, blockWidth int) error {
	if len(mssImages) > len(r.phaseShifts) {
		return fmt.Errorf("phase correlation supports at most %d images, got %d",
			len(r.phaseShifts), len(mssImages))
	}

	const blockCount = BlockNH * BlockNW
	// Every goroutine writes only to the slots of its own block, so the
	// per-block results need no locking and keep a deterministic order.
	slots := make([][]PhaseShiftM, len(mssImages))
	measured := make([]bool, blockCount)
	for band := range mssImages {
		slots[band] = make([]PhaseShiftM, blockCount)
		r.phaseShifts[band] = make([]PhaseShiftM, 0, blockCount)
		r.deltaXCoeffs[band] = make([]float64, 0)
		r.deltaYCoeffs[band] = make([]float64, 0)
	}

	var wg sync.WaitGroup
	for bh := 0; bh < BlockNH; bh++ {
		for bw := 0; bw < BlockNW; bw++ {
			wg.Add(1)
			go func(bh, bw int) {
				defer wg.Done()
				x0 := bw * blockWidth
				y0 := bh * blockHeight
				x1 := x0 + blockWidth
				// The Y offset can be large, so extend the block by the
				// overlap rows to reduce boundary effects.
				y1 := y0 + blockHeight + OverlappedBlockLines
				if x1 > width || y1 > height {
					log.Debugf("Block out of range. x0=%d, y0=%d, x1=%d, y1=%d", x0, y0, x1, y1)
				}
				if x1 > width {
					x1 = width
				}
				if y1 > height {
					y1 = height
				}
				if x1 <= x0 || y1 <= y0 {
					// image.Rect would silently swap the corners of such a
					// rectangle; there is nothing meaningful to correlate.
					log.Warnf("skip empty block (%d, %d): x0=%d, y0=%d, x1=%d, y1=%d",
						bh, bw, x0, y0, x1, y1)
					return
				}

				var shiftM PhaseShiftM
				shiftM.Block.width = x1 - x0
				shiftM.Block.height = y1 - y0
				shiftM.Block.coord.X = x0 // x of the block's top-left corner
				shiftM.Block.coord.Y = y0 // y of the block's top-left corner

				rect := image.Rect(x0, y0, x1, y1)
				idx := bh*BlockNW + bw
				panBlock := base.Region(rect)
				defer panBlock.Close()
				for band := range mssImages {
					log.Debug("processing band:", band+1, ",block:", bh, rect)
					mssBlock := mssImages[band].Region(rect)
					phaseShift, response := r.calculateBlockPhaseShift(panBlock, mssBlock)
					mssBlock.Close()
					shiftM.dx = phaseShift.X
					shiftM.dy = phaseShift.Y
					shiftM.response = response
					slots[band][idx] = shiftM
				}
				measured[idx] = true
			}(bh, bw)
		}
	}
	wg.Wait()

	// Publish the measured blocks in row-major order.
	for band := range mssImages {
		for idx, ok := range measured {
			if ok {
				r.phaseShifts[band] = append(r.phaseShifts[band], slots[band][idx])
			}
		}
	}

	for i := 0; i < len(mssImages); i++ {
		for _, shift := range r.phaseShifts[i] {
			if shift.response > 0.4 || shift.dx > 8 || shift.dy > 8 {
				log.Debugf("Band %d, block %d, dx=%f, dy=%f, response=%f",
					i, shift.Block.coord.X, shift.dx, shift.dy, shift.response)
			}
		}
	}
	return nil
}
// Clean closes every Mat held by the Registrator: the PAN image, the raw MSS
// bands, the registered MSS bands and the RGB-IR image, in that order.
func (r *Registrator) Clean() {
	r.PanImage.Close()
	for band := range r.MssImages {
		r.MssImages[band].Close()
	}
	for band := range r.registeredMssImages {
		r.registeredMssImages[band].Close()
	}
	r.rgbirImage.Close()
}
// calcMSSDeltaCoeffs fits, for each of the first `bands` bands, the measured
// block shifts as polynomials of the block's horizontal centre: a degree-1
// polynomial for dx (stored in deltaXCoeffs) and a degree-2 polynomial for dy
// (stored in deltaYCoeffs).
//
// Shifts whose dx or dy is NaN are ignored. A band with fewer than three
// usable shifts, or whose fit fails, is logged and skipped. The fitted
// polynomials are logged once all bands have been processed. The method
// always returns nil.
func (r *Registrator) calcMSSDeltaCoeffs(bands int) error {
	for band := 0; band < bands; band++ {
		var centers, shiftsX, shiftsY []float64
		used := 0
		for k := range r.phaseShifts[band] {
			s := &r.phaseShifts[band][k]
			if math.IsNaN(float64(s.dx)) || math.IsNaN(float64(s.dy)) {
				continue
			}
			// An empirical threshold on dy (skip dy < 64.0) used to be
			// applied here; it is currently disabled.
			used++
			// The sample position is the horizontal centre of the block.
			centers = append(centers, float64(s.Block.coord.X+s.Block.width/2))
			log.Debugf("effective shift value: %v, cx: %v, dx: %v, dy: %v",
				used, s.Block.coord.X, s.dx, s.dy)
			shiftsX = append(shiftsX, float64(s.dx))
			shiftsY = append(shiftsY, float64(s.dy))
		}

		if len(centers) < 3 {
			log.Errorf("No effective shift value found for band %d, skip delta coefficients calculation", band+1)
			continue
		}

		var err error
		r.deltaXCoeffs[band], err = PolynomialFit(centers, shiftsX, 1)
		if err != nil {
			log.Error("Error fitting deltaX coeffs: ", err)
			continue
		}
		r.deltaYCoeffs[band], err = PolynomialFit(centers, shiftsY, 2)
		if err != nil {
			log.Error("Error fitting deltaY coeffs: ", err)
			continue
		}
	}

	// Report every band that ended up with a complete set of coefficients.
	for band := 0; band < bands; band++ {
		cx, cy := r.deltaXCoeffs[band], r.deltaYCoeffs[band]
		if len(cx) >= 2 && len(cy) >= 3 {
			log.Printf("Band %d:\n delta_x = %.6f*x + %.6f, \n delta_y = %.6f*x^2 + %.6f*x + %.6f\n",
				band+1,
				cx[1], cx[0],
				cy[2], cy[1], cy[0])
		}
	}
	return nil
}
// newRemapMaps builds the X and Y lookup maps (MssHeight x MssWidth, 32-bit
// float) used to remap an image with the given distortion polynomials.
//
// coefX must hold at least 2 coefficients (constant, linear) and coefY at
// least 3 (constant, linear, quadratic). For the UpSampled method the
// polynomials are evaluated at 4x-scaled coordinates and the result is scaled
// back; otherwise they are evaluated at the pixel coordinates directly.
//
// The maps are written through the Mats' float32 data slices instead of one
// cgo call per pixel, and the column-dependent part of the polynomials is
// computed once per column. The caller owns the returned Mats and must close
// them. On error no Mat is returned open.
func (r *Registrator) newRemapMaps(coefX, coefY []float64) (gocv.Mat, gocv.Mat, error) {
	rows, cols := r.MssHeight, r.MssWidth
	mapX := gocv.NewMatWithSize(rows, cols, gocv.MatTypeCV32FC1)
	mapY := gocv.NewMatWithSize(rows, cols, gocv.MatTypeCV32FC1)

	xs, err := mapX.DataPtrFloat32()
	var ys []float32
	if err == nil {
		ys, err = mapY.DataPtrFloat32()
	}
	if err == nil && rows > 0 && cols > 0 && (len(xs) < rows*cols || len(ys) < rows*cols) {
		err = fmt.Errorf("remap map buffers too small: got %d and %d values, need %d",
			len(xs), len(ys), rows*cols)
	}
	if err != nil {
		mapX.Close()
		mapY.Close()
		return gocv.Mat{}, gocv.Mat{}, err
	}
	if rows <= 0 || cols <= 0 {
		return mapX, mapY, nil
	}

	upsampled := r.resampleMethod == UpSampled

	// Per-column values: the X map does not depend on the row, and the Y map
	// is a per-column polynomial plus the row coordinate.
	colX := make([]float32, cols)
	colPolyY := make([]float64, cols)
	for x := 0; x < cols; x++ {
		if upsampled {
			xx := float64(x * MssBands)
			colX[x] = float32((coefX[1]*xx + coefX[0] + xx) / MssBands)
			colPolyY[x] = coefY[2]*(xx*xx) + coefY[1]*xx + coefY[0]
		} else {
			fx := float64(x)
			colX[x] = float32(coefX[1]*fx + coefX[0] + fx)
			colPolyY[x] = coefY[2]*float64(x*x) + coefY[1]*fx + coefY[0]
		}
	}

	for y := 0; y < rows; y++ {
		copy(xs[y*cols:(y+1)*cols], colX)
		rowY := ys[y*cols : (y+1)*cols]
		if upsampled {
			yy := float64(y * MssBands)
			for x := range rowY {
				rowY[x] = float32((colPolyY[x] + yy) / MssBands)
			}
		} else {
			fy := float64(y)
			for x := range rowY {
				rowY[x] = float32(colPolyY[x] + fy)
			}
		}
	}
	return mapX, mapY, nil
}

// DoMSSCoRegistration remaps every MSS band with the band's fitted distortion
// polynomials (deltaXCoeffs / deltaYCoeffs).
//
// With byEdge == false the raw band in MssImages is remapped into a new Mat
// stored in registeredMssImages; a band without coefficients is copied
// unchanged. With byEdge == true the already registered band is remapped
// again, and a band without coefficients is left untouched. Any Mat
// previously stored in registeredMssImages is closed before it is replaced.
//
// After all bands are processed the recorded heights are reduced by 360 MSS
// rows and 360*4 PAN rows; the Mats themselves are not cropped here.
//
// It returns an error only when the remap maps cannot be built; in that case
// the heights are left unchanged.
func (r *Registrator) DoMSSCoRegistration(byEdge bool) error {
	for band := 0; band < MssBands; band++ {
		if len(r.deltaXCoeffs[band]) < 2 || len(r.deltaYCoeffs[band]) < 3 {
			log.Errorf("delta coefficients not calculated, skip co-registration %d", band+1)
			if !byEdge {
				r.registeredMssImages[band].Close()
				r.registeredMssImages[band] = r.MssImages[band].Clone()
			}
			continue
		}

		mapX, mapY, err := r.newRemapMaps(r.deltaXCoeffs[band], r.deltaYCoeffs[band])
		if err != nil {
			return fmt.Errorf("building remap maps for band %d: %w", band+1, err)
		}

		log.Println("co-registration for band", band+1)
		if !byEdge {
			// Release the previous result so that repeated calls do not leak.
			r.registeredMssImages[band].Close()
			r.registeredMssImages[band] = gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV16UC1)
			gocv.Remap(r.MssImages[band],
				&r.registeredMssImages[band],
				&mapX, &mapY,
				gocv.InterpolationCubic,
				gocv.BorderConstant,
				color.RGBA{0, 0, 0, 0})
		} else {
			gocv.Remap(r.registeredMssImages[band],
				&r.registeredMssImages[band],
				&mapX, &mapY,
				gocv.InterpolationCubic,
				gocv.BorderConstant,
				color.RGBA{0, 0, 0, 0})
		}
		mapX.Close()
		mapY.Close()
	}

	// Drop the trailing 360 MSS rows and the matching 360*4 PAN rows.
	r.PanHeight -= 360 * 4
	r.MssHeight -= 360
	return nil
}

// DoPANCoRegistration aligns all registered MSS bands with the PAN image by
// remapping each of them with the band-1 distortion polynomials
// (deltaXCoeffs[0] / deltaYCoeffs[0]).
//
// When those coefficients are missing the method logs the fact and returns
// nil without changing anything. Otherwise every Mat in registeredMssImages
// is replaced by its remapped version (the old Mat is closed) and the
// recorded heights are reduced by 120 MSS rows and 120*4 PAN rows.
//
// It returns an error only when the remap maps cannot be built.
func (r *Registrator) DoPANCoRegistration() error {
	if len(r.deltaXCoeffs[0]) < 2 || len(r.deltaYCoeffs[0]) < 3 {
		log.Error("delta coefficients not calculated, skip co-registration")
		return nil
	}

	mapX, mapY, err := r.newRemapMaps(r.deltaXCoeffs[0], r.deltaYCoeffs[0])
	if err != nil {
		return fmt.Errorf("building remap maps for PAN alignment: %w", err)
	}
	defer mapX.Close()
	defer mapY.Close()

	log.Println("co-registration for MSS (Align with PAN)")
	for i := 0; i < MssBands; i++ {
		registeredMSS := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV16UC1)
		gocv.Remap(r.registeredMssImages[i],
			&registeredMSS,
			&mapX, &mapY,
			gocv.InterpolationCubic,
			gocv.BorderConstant,
			color.RGBA{0, 0, 0, 0})
		r.registeredMssImages[i].Close()
		r.registeredMssImages[i] = registeredMSS
	}

	// Drop the trailing 120 MSS rows and the matching 120*4 PAN rows.
	r.PanHeight -= 120 * 4
	r.MssHeight -= 120
	return nil
}