Files
sjy01-image-proc/image_registration.go
2024-05-27 14:55:59 +08:00

510 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package imageproc
import (
"bufio"
"fmt"
"image"
"image/color"
"math"
"os"
"sync"
"github.com/airbusgeo/godal"
log "github.com/sirupsen/logrus"
"gocv.io/x/gocv"
)
// Registrate is a placeholder interface for registration implementations;
// it currently declares no methods.
type Registrate interface{}

// ResampleMethod names the strategy used to bring the PAN and MSS images to
// a common resolution before phase correlation.
type ResampleMethod string

// Raw image geometry.
const (
	// MssBands is the number of spectral bands in an MSS raw file.
	MssBands = 4
	// PixelBytes is the size in bytes of one 16-bit sample.
	PixelBytes = 2
	// PanWidth is the PAN image width in pixels.
	PanWidth = 9344
	// MssWidth is the width in pixels of each MSS band.
	MssWidth = 2336
)

// Block tiling used by the phase-correlation passes.
const (
	// BlockNH is the number of blocks along the image height.
	BlockNH = 8
	// BlockNW is the number of blocks along the image width.
	BlockNW = 4
	// OverlappedBlockLines is the number of extra rows a block extends into
	// the next one.
	OverlappedBlockLines = 2000
)

// Resampling strategies.
const (
	// DownSampled identifies the "downsample the PAN image" strategy.
	DownSampled ResampleMethod = "down_sample_pan"
	// UpSampled identifies the "upsample the MSS bands" strategy.
	UpSampled ResampleMethod = "up_sample_mss"
)
// Registrator co-registers the MSS (multispectral) band images of a scene
// against its PAN image. It holds the loaded images, the per-block phase
// shifts and fitted distortion coefficients computed along the way, and the
// registered outputs.
type Registrator struct {
	// PanImage is the 16-bit single-channel PAN image loaded by LoadPanRaw.
	PanImage  gocv.Mat
	PanHeight int
	PanWidth  int
	// MssImages holds one 16-bit single-channel Mat per MSS band (LoadMssRaw).
	MssImages [4]gocv.Mat
	MssHeight int
	MssWidth  int
	// shiftMutex guards phaseShifts while blocks are processed concurrently.
	shiftMutex sync.Mutex
	// phaseShifts[band] collects the per-block phase-correlation results.
	phaseShifts [4][]PhaseShiftM
	// deltaXCoeffs[band] are polynomial coefficients (index = power of x)
	// fitted to the block dx shifts. Original note (translated): intra-image
	// distortion (non-linear transform), capturing the vertical-direction
	// deformation induced along the horizontal direction.
	// NOTE(review): the fit models the X shift, so this note may be swapped
	// with the one on deltaYCoeffs — confirm.
	deltaXCoeffs [4][]float64
	// deltaYCoeffs[band] are polynomial coefficients (index = power of x)
	// fitted to the block dy shifts. Original note (translated): intra-image
	// distortion (non-linear transform), capturing the horizontal-direction
	// deformation induced along the vertical direction.
	deltaYCoeffs [4][]float64
	// registeredMssImages holds the MSS band images after registration.
	registeredMssImages [4]gocv.Mat
	// rgbirImage is the 4-channel 16-bit merge of the registered bands,
	// built when saving to GeoTIFF.
	rgbirImage gocv.Mat
	// resampleMethod names the resampling strategy (DownSampled/UpSampled).
	resampleMethod ResampleMethod
}
// NewRegistrator allocates and returns a zero-valued Registrator. Images are
// loaded into it separately via LoadPanRaw and LoadMssRaw.
func NewRegistrator() *Registrator {
	return new(Registrator)
}
// LoadPanRaw reads the headerless raw PAN image at path raw into r.PanImage
// and records its size in r.PanHeight / r.PanWidth.
//
// The file is treated as rows of PanWidth 16-bit samples (PixelBytes bytes
// each) and wrapped in a CV_16U Mat. The row count is derived from the file
// size. Trailing bytes that do not form a complete row are ignored with a
// warning, and a file shorter than one row is rejected. r.PanImage is only
// replaced when the Mat was created successfully.
//
// It also registers all GDAL drivers and checks that the GTiff driver is
// available, because the Save*ToGDALGTiff methods write GeoTIFFs through it.
// A missing driver is returned as an error instead of panicking.
func (r *Registrator) LoadPanRaw(raw string) error {
	data, err := os.ReadFile(raw)
	if err != nil {
		log.Error("Error reading raw file: ", err)
		return err
	}
	rowBytes := PanWidth * PixelBytes
	height := len(data) / rowBytes
	if height == 0 {
		err := fmt.Errorf("PAN raw file %q holds %d bytes, less than one %d-byte row", raw, len(data), rowBytes)
		log.Error(err)
		return err
	}
	if extra := len(data) % rowBytes; extra != 0 {
		log.Warnf("PAN raw file %q has %d trailing bytes that do not form a full row; ignoring them", raw, extra)
	}
	pan, err := gocv.NewMatFromBytes(height, PanWidth, gocv.MatTypeCV16U, data[:height*rowBytes])
	if err != nil {
		log.Error("Error creating Mat from bytes: ", err)
		return err
	}
	r.PanImage = pan
	r.PanHeight = height
	r.PanWidth = PanWidth

	// Output is written through GDAL's GTiff driver; fail early (without
	// panicking) if it is not available.
	godal.RegisterAll()
	hDriver, ok := godal.RasterDriver(godal.GTiff)
	if !ok {
		err := fmt.Errorf("GDAL raster driver %q not found", godal.GTiff)
		log.Error(err)
		return err
	}
	md := hDriver.Metadatas()
	log.Debugf("GTiff driver: long name %q, short name %q, DCAP_CREATE=%q, DCAP_CREATECOPY=%q",
		hDriver.LongName(), hDriver.ShortName(), md["DCAP_CREATE"], md["DCAP_CREATECOPY"])
	return nil
}
// LoadMssRaw reads the headerless raw multispectral (MSS) image at path raw
// and splits it into MssBands single-band CV_16U Mats stored in r.MssImages.
//
// Each file row holds MssBands consecutive band lines of MssWidth 16-bit
// samples (band-interleaved-by-line). Band i of a row is therefore the byte
// range [i*MssWidth*PixelBytes, (i+1)*MssWidth*PixelBytes) of that row. The
// height is derived from the file size. Trailing bytes that do not form a
// complete row are ignored with a warning, and a file shorter than one row
// is rejected. r.MssImages is only replaced when every band Mat was created.
func (r *Registrator) LoadMssRaw(raw string) error {
	data, err := os.ReadFile(raw)
	if err != nil {
		log.Error("Error reading raw file: ", err)
		return err
	}
	lineBytes := MssWidth * PixelBytes // one band line
	rowBytes := lineBytes * MssBands   // one file row = MssBands band lines
	height := len(data) / rowBytes
	if height == 0 {
		err := fmt.Errorf("MSS raw file %q holds %d bytes, less than one %d-byte row", raw, len(data), rowBytes)
		log.Error(err)
		return err
	}
	if extra := len(data) % rowBytes; extra != 0 {
		log.Warnf("MSS raw file %q has %d trailing bytes that do not form a full row; ignoring them", raw, extra)
	}

	// De-interleave into one contiguous buffer per band. Pre-sizing avoids
	// repeated reallocation while appending row by row.
	var mssData [MssBands][]byte
	for i := range mssData {
		mssData[i] = make([]byte, 0, height*lineBytes)
	}
	for h := 0; h < height; h++ {
		row := data[h*rowBytes : (h+1)*rowBytes]
		for i := 0; i < MssBands; i++ {
			mssData[i] = append(mssData[i], row[i*lineBytes:(i+1)*lineBytes]...)
		}
	}

	var bands [MssBands]gocv.Mat
	for i := 0; i < MssBands; i++ {
		m, err := gocv.NewMatFromBytes(height, MssWidth, gocv.MatTypeCV16U, mssData[i])
		if err != nil {
			// Release the bands created so far instead of leaking them.
			for j := 0; j < i; j++ {
				bands[j].Close()
			}
			log.Error("Error creating Mat from bytes: ", err)
			return err
		}
		bands[i] = m
	}
	r.MssImages = bands
	r.MssHeight = height
	r.MssWidth = MssWidth
	return nil
}
// CalcDownPhaseCorrelation downsamples the PAN image to MSS size and computes
// phase-correlation shifts between it and every MSS band. It then fits the
// distortion polynomials with calcDeltaCoeffs.
//
// The images are cut into BlockNH horizontal strips spanning the full MSS
// width. Each strip is extended by 800 extra rows (clamped to the image) to
// limit edge effects. One PhaseShiftM per strip is appended to
// r.phaseShifts[band]. Blocks with response > 0.4 or dy > 8 are printed.
//
// It returns an error if the MSS image is empty, if the MSS height is not a
// quarter of the PAN height, or if the polynomial fit fails.
func (r *Registrator) CalcDownPhaseCorrelation() error {
	if r.MssHeight <= 0 {
		err := fmt.Errorf("MSS image is empty, load the raw files first")
		log.Error(err)
		return err
	}
	// The MSS height must be exactly 1/4 of the PAN height.
	if r.MssHeight*4 != r.PanHeight {
		err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
		log.Error(err)
		return err
	}
	// Downsample PAN to MSS size; it is the reference image for matching.
	// The Mat owns native memory and must be released when we are done.
	downsampledPanImage := gocv.NewMat()
	defer downsampledPanImage.Close()
	gocv.Resize(r.PanImage, &downsampledPanImage,
		image.Point{X: r.MssWidth, Y: r.MssHeight}, 0, 0, gocv.InterpolationCubic)
	fmt.Println("down sampled PAN images size:", downsampledPanImage.Size())

	const overlapLines = 800 // extra rows each strip extends into the next one
	blockHeight := r.MssHeight / BlockNH
	for band := 0; band < MssBands; band++ {
		for bh := 0; bh < BlockNH; bh++ {
			y0 := bh * blockHeight
			y1 := (bh+1)*blockHeight + overlapLines
			if y1 > r.MssHeight {
				y1 = r.MssHeight
			}
			var shiftM PhaseShiftM
			shiftM.Block.width = r.MssWidth // full width: no split along X
			shiftM.Block.height = y1 - y0
			shiftM.Block.coord.X = 0  // block top-left x
			shiftM.Block.coord.Y = y0 // block top-left y
			rect := image.Rect(
				shiftM.Block.coord.X, shiftM.Block.coord.Y,
				shiftM.Block.coord.X+shiftM.Block.width, shiftM.Block.coord.Y+shiftM.Block.height,
			)
			log.Println("Band", band+1, ", processing block", bh, rect)
			panBlock := downsampledPanImage.Region(rect)
			mssBlock := r.MssImages[band].Region(rect)
			phaseShift, response := r.calculateBlockPhaseShift(panBlock, mssBlock)
			panBlock.Close()
			mssBlock.Close()
			shiftM.dx = phaseShift.X
			shiftM.dy = phaseShift.Y
			shiftM.response = response
			r.phaseShifts[band] = append(r.phaseShifts[band], shiftM)
		}
	}
	for i := 0; i < MssBands; i++ {
		for j, shift := range r.phaseShifts[i] {
			if shift.response > 0.4 || shift.dy > 8 {
				fmt.Printf("Band %d, block %d, dx=%f, dy=%f, response=%f\n",
					i, j, shift.dx, shift.dy, shift.response)
			}
		}
	}
	return r.calcDeltaCoeffs()
}
// CalcUpPhaseCorrelation upsamples every MSS band to PAN size and computes
// phase-correlation shifts between the PAN image and each band, block by
// block.
//
// The images are tiled into BlockNH x BlockNW blocks. Each block is extended
// downwards by OverlappedBlockLines rows (clamped to the image) to limit
// edge effects. Blocks are processed in one goroutine each. Every goroutine
// writes only its own result slot, so the results are appended to
// r.phaseShifts[band] in row-major block order regardless of scheduling.
// Blocks with response > 0.4, dx > 8 or dy > 8 are printed.
//
// It returns an error if the MSS image is empty or if the MSS height is not
// a quarter of the PAN height.
func (r *Registrator) CalcUpPhaseCorrelation() error {
	if r.MssHeight <= 0 {
		err := fmt.Errorf("MSS image is empty, load the raw files first")
		log.Error(err)
		return err
	}
	// The MSS height must be exactly 1/4 of the PAN height.
	if r.MssHeight*4 != r.PanHeight {
		err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
		log.Error(err)
		return err
	}
	// Upsample every MSS band to PAN size; PAN itself is the reference image.
	// These PAN-sized Mats hold a lot of native memory and are released on return.
	var upsampledMssImages [MssBands]gocv.Mat
	for i := 0; i < MssBands; i++ {
		upsampledMssImages[i] = gocv.NewMat()
		defer upsampledMssImages[i].Close()
		gocv.Resize(r.MssImages[i], &upsampledMssImages[i],
			image.Point{X: r.PanWidth, Y: r.PanHeight}, 0, 0, gocv.InterpolationCubic)
	}
	fmt.Println("up sampled MSS images size:", upsampledMssImages[0].Size())

	// Block size; BlockNH and BlockNW are expected to divide the image evenly.
	blockHeight := r.PanHeight / BlockNH
	blockWidth := r.PanWidth / BlockNW
	log.Infof("blockHeight=%d, blockWidth=%d", blockHeight, blockWidth)

	// One result slot per (band, block). Slot index bh*BlockNW+bw is written
	// by exactly one goroutine, so no lock is needed and the order is stable.
	const nBlocks = BlockNH * BlockNW
	var results [MssBands][]PhaseShiftM
	for band := range results {
		results[band] = make([]PhaseShiftM, nBlocks)
	}

	var wg sync.WaitGroup
	for bh := 0; bh < BlockNH; bh++ {
		for bw := 0; bw < BlockNW; bw++ {
			wg.Add(1)
			go func(bh, bw int) {
				defer wg.Done()
				x0 := bw * blockWidth
				y0 := bh * blockHeight
				x1 := (bw + 1) * blockWidth
				// Large Y offsets need the overlap rows to avoid boundary effects.
				y1 := (bh+1)*blockHeight + OverlappedBlockLines
				if x1 > r.PanWidth || y1 > r.PanHeight {
					log.Warnf("Block out of range. x0=%d, y0=%d, x1=%d, y1=%d", x0, y0, x1, y1)
				}
				if y1 > r.PanHeight {
					y1 = r.PanHeight
				}
				var shiftM PhaseShiftM
				shiftM.Block.width = x1 - x0
				shiftM.Block.height = y1 - y0
				shiftM.Block.coord.X = x0 // block top-left x
				shiftM.Block.coord.Y = y0 // block top-left y
				rect := image.Rect(x0, y0, x1, y1)

				panBlock := r.PanImage.Region(rect)
				defer panBlock.Close()
				slot := bh*BlockNW + bw
				for band := 0; band < MssBands; band++ {
					log.Println("Band", band+1, ", processing block", bh, bw, rect)
					mssBlock := upsampledMssImages[band].Region(rect)
					phaseShift, response := r.calculateBlockPhaseShift(panBlock, mssBlock)
					mssBlock.Close()
					shiftM.dx = phaseShift.X
					shiftM.dy = phaseShift.Y
					shiftM.response = response
					results[band][slot] = shiftM
				}
			}(bh, bw)
		}
	}
	wg.Wait()

	r.shiftMutex.Lock()
	for band := range results {
		r.phaseShifts[band] = append(r.phaseShifts[band], results[band]...)
	}
	r.shiftMutex.Unlock()

	for i := 0; i < MssBands; i++ {
		for j, shift := range r.phaseShifts[i] {
			if shift.response > 0.4 || shift.dx > 8 || shift.dy > 8 {
				fmt.Printf("Band %d, block %d, dx=%f, dy=%f, response=%f\n",
					i, j, shift.dx, shift.dy, shift.response)
			}
		}
	}
	return nil
}
// SaveRegisteredMssToRaw writes the registered MSS bands to path raw in the
// same band-interleaved-by-line layout that LoadMssRaw reads. For every
// image row it writes the row of band 0, then bands 1, 2 and 3, each
// MssWidth 16-bit samples.
//
// The band data is validated before the file is touched. The file is
// created or truncated with mode 0644. Any write, flush or close error is
// returned; previously the buffered writer was never flushed and the file
// never closed, so the tail of the output was lost.
func (r *Registrator) SaveRegisteredMssToRaw(raw string) (err error) {
	rowBytes := r.MssWidth * PixelBytes
	height := r.MssHeight
	var mssData [MssBands][]byte
	for i := 0; i < MssBands; i++ {
		mssData[i] = r.registeredMssImages[i].ToBytes()
		if len(mssData[i]) < rowBytes*height {
			err := fmt.Errorf("registered MSS band %d holds %d bytes, expected %d; run DoCoRegestration first",
				i+1, len(mssData[i]), rowBytes*height)
			log.Error(err)
			return err
		}
	}
	log.Println("Writing registered MSS to RAW file...", len(mssData[0])*4)
	log.Println("width:", rowBytes)
	log.Println("height:", height)

	f, err := os.OpenFile(raw, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	w := bufio.NewWriter(f)
	for row := 0; row < height; row++ {
		lo, hi := row*rowBytes, (row+1)*rowBytes
		for band := 0; band < MssBands; band++ {
			if _, err := w.Write(mssData[band][lo:hi]); err != nil {
				return fmt.Errorf("writing row %d of band %d to %q: %w", row, band+1, raw, err)
			}
		}
	}
	if err := w.Flush(); err != nil {
		return fmt.Errorf("flushing %q: %w", raw, err)
	}
	return nil
}
// bytesToRaw writes mssData verbatim to filePath, creating or truncating the
// file with mode 0644, and returns any open, write or close error. The
// receiver is not used.
//
// The previous implementation wrapped the file in a bufio.Writer that was
// never flushed and never closed the file, so small payloads were never
// written at all.
func (r *Registrator) bytesToRaw(mssData []byte, filePath string) error {
	return os.WriteFile(filePath, mssData, 0644)
}
// SaveRegisteredMssToGDALGTiff writes the registered MSS bands to tiffFile as
// a MssBands-band UInt16 GeoTIFF, in the band order of r.registeredMssImages.
//
// As a side effect it (re)builds r.rgbirImage, a MssWidth x MssHeight
// 4-channel 16-bit Mat whose channels 0..3 are registered bands 0..3
// (red, green, blue, IR). Any previously built rgbirImage is released first
// so repeated calls do not leak it. The geotransform is set with
// setGeoTransform, passing 1.2*4 as the last argument (presumably the MSS
// pixel size, i.e. 4x the 1.2 used for PAN — confirm).
//
// Errors from creating the dataset, writing a band or closing the dataset
// (GDAL flushes the file on close) are returned.
func (r *Registrator) SaveRegisteredMssToGDALGTiff(tiffFile string) (err error) {
	log.Println("Saving registered MSS to TIFF file:", tiffFile)
	width := r.MssWidth
	height := r.MssHeight

	// Build the merged RGBIR image and, in the same pass, gather each band
	// into a contiguous buffer for GDAL.
	r.rgbirImage.Close()
	r.rgbirImage = gocv.NewMatWithSize(height, width, gocv.MatTypeCV16UC4) // 4 channels, 16 bit
	data := make([][]uint16, MssBands)
	for b := range data {
		data[b] = make([]uint16, width*height)
	}
	for y := 0; y < height; y++ {
		for x := 0; x < width; x++ {
			for b := 0; b < MssBands; b++ {
				v := r.registeredMssImages[b].GetShortAt(y, x)
				r.rgbirImage.SetShortAt(y, x*MssBands+b, v)
				data[b][y*width+x] = uint16(v)
			}
		}
	}

	ds, err := godal.Create(godal.GTiff,
		tiffFile,
		MssBands,
		godal.UInt16,
		width, height)
	if err != nil {
		log.Error("Error creating TIFF file: ", err)
		return err
	}
	// Closing is where GDAL flushes the GeoTIFF, so its error must not be dropped.
	defer func() {
		if cerr := ds.Close(); cerr != nil {
			log.Error("Error closing TIFF file: ", cerr)
			if err == nil {
				err = cerr
			}
		}
	}()
	setGeoTransform(ds, 0, 0, float64(width), float64(height), 1.2*4)

	bands := ds.Bands()
	for b := 0; b < MssBands; b++ {
		if err := bands[b].IO(godal.IOWrite,
			0, 0,
			data[b],
			width, height,
			godal.PixelSpacing(2),
			godal.LineSpacing(width*2)); err != nil {
			log.Error("Failed to write data to band:", err)
			return err
		}
	}
	log.Info("Saved registered mss to ", tiffFile)
	return nil
}
// SavePanToGDALGTiff writes r.PanImage to tiffFile as a single-band UInt16
// GeoTIFF. It sets the "NBITS"="16" metadata item on the dataset and calls
// setGeoTransform with 1.2 as the last argument (presumably the PAN pixel
// size — confirm).
//
// Errors from creating the dataset, writing the band or closing the dataset
// (GDAL flushes the file on close) are returned.
func (r *Registrator) SavePanToGDALGTiff(tiffFile string) (err error) {
	log.Println("Saving PAN image to TIFF file:", tiffFile)
	width := r.PanWidth
	height := r.PanHeight
	ds, err := godal.Create(godal.GTiff, tiffFile, 1, godal.UInt16, width, height)
	if err != nil {
		log.Error("Error creating TIFF file: ", err)
		return err
	}
	// Closing is where GDAL flushes the GeoTIFF, so its error must not be dropped.
	defer func() {
		if cerr := ds.Close(); cerr != nil {
			log.Error("Error closing TIFF file: ", cerr)
			if err == nil {
				err = cerr
			}
		}
	}()
	setGeoTransform(ds, 0, 0, float64(width), float64(height), 1.2)
	ds.SetMetadata("NBITS", "16")

	// Copy the PAN samples into a contiguous uint16 buffer for GDAL.
	data := make([]uint16, width*height)
	for y := 0; y < height; y++ {
		for x := 0; x < width; x++ {
			data[y*width+x] = uint16(r.PanImage.GetShortAt(y, x))
		}
	}
	band := ds.Bands()[0]
	if err := band.IO(godal.IOWrite,
		0, 0,
		data,
		width, height,
		godal.PixelSpacing(2),
		godal.LineSpacing(width*2)); err != nil {
		log.Error("Failed to write data to band:", err)
		return err
	}
	log.Info("Saved pan image to ", tiffFile)
	return nil
}
// Clean releases every Mat owned by r, in this order: the PAN image, the MSS
// band images, the registered MSS images, and finally the merged RGBIR
// image.
func (r *Registrator) Clean() {
	r.PanImage.Close()

	bands := r.MssImages[:]
	for idx := range bands {
		bands[idx].Close()
	}

	registered := r.registeredMssImages[:]
	for idx := range registered {
		registered[idx].Close()
	}

	r.rgbirImage.Close()
}
// calcDeltaCoeffs fits, for every MSS band, polynomials to the block phase
// shifts in r.phaseShifts[band]:
//
//   - dx is fitted with a degree-1 polynomial and stored in r.deltaXCoeffs[band].
//   - dy is fitted with a degree-2 polynomial and stored in r.deltaYCoeffs[band].
//
// Both are fitted against the abscissa Block.coord.X + j, where j is the
// block's index in r.phaseShifts[band]. Blocks whose dx or dy is NaN are
// skipped. The coefficient index is the power of x.
//
// It returns an error if a band has fewer than 3 usable blocks (the minimum
// for the quadratic fit), or if PolynomialFit returns fewer coefficients than
// the requested degree needs. This replaces an index-out-of-range panic in
// the coefficient log below.
func (r *Registrator) calcDeltaCoeffs() error {
	const minSamples = 3 // a degree-2 fit needs at least 3 points
	for i := 0; i < MssBands; i++ {
		shifts := r.phaseShifts[i]
		cx := make([]float64, 0, len(shifts))
		dx := make([]float64, 0, len(shifts))
		dy := make([]float64, 0, len(shifts))
		for j, shift := range shifts {
			if math.IsNaN(float64(shift.dx)) || math.IsNaN(float64(shift.dy)) {
				continue
			}
			// MSS blocks are not split along X (X is 0 in the downsampled
			// pass), so the block index j serves as the abscissa there.
			x := float64(shift.Block.coord.X + j)
			cx = append(cx, x)
			dx = append(dx, float64(shift.dx))
			dy = append(dy, float64(shift.dy))
			fmt.Println("effectShift:", len(cx), "cx:", x, "dy:", shift.dy)
		}
		if len(cx) < minSamples {
			err := fmt.Errorf("band %d: only %d valid phase shifts, need at least %d for the polynomial fit",
				i+1, len(cx), minSamples)
			log.Error(err)
			return err
		}
		r.deltaXCoeffs[i] = PolynomialFit(cx, dx, 1)
		r.deltaYCoeffs[i] = PolynomialFit(cx, dy, 2)
		if len(r.deltaXCoeffs[i]) < 2 || len(r.deltaYCoeffs[i]) < 3 {
			err := fmt.Errorf("band %d: polynomial fit returned %d/%d coefficients, want 2/3",
				i+1, len(r.deltaXCoeffs[i]), len(r.deltaYCoeffs[i]))
			log.Error(err)
			return err
		}
	}
	for i := 0; i < MssBands; i++ {
		log.Printf("Band %d:\n delta_x = %.6f*x + %.6f, \n delta_y = %.6f*x^2 + %.6f*x + %.6f\n",
			i+1,
			r.deltaXCoeffs[i][1], r.deltaXCoeffs[i][0],
			r.deltaYCoeffs[i][2], r.deltaYCoeffs[i][1], r.deltaYCoeffs[i][0])
	}
	return nil
}
// DoCoRegestration resamples each MSS band into r.registeredMssImages using
// gocv.Remap with bicubic interpolation and a constant zero border.
//
// Only the constant terms of the fitted polynomials are applied. Every output
// pixel (x, y) samples the source at
// (x + deltaXCoeffs[band][0], y + deltaYCoeffs[band][0]), which is a pure
// per-band translation. The higher-order terms fitted by calcDeltaCoeffs
// (dx = c1*x + c0, dy = c2*x^2 + c1*x + c0) are currently not used.
//
// It returns an error if the MSS image is empty or the coefficients have not
// been computed. Both cases previously caused an index-out-of-range panic.
// The temporary remap Mats and any registered image from a previous call are
// released.
func (r *Registrator) DoCoRegestration() error {
	if r.MssHeight <= 0 || r.MssWidth <= 0 {
		err := fmt.Errorf("MSS image is empty, load the raw files first")
		log.Error(err)
		return err
	}
	for band := 0; band < MssBands; band++ {
		if len(r.deltaXCoeffs[band]) == 0 || len(r.deltaYCoeffs[band]) == 0 {
			err := fmt.Errorf("band %d has no distortion coefficients, run the phase correlation first", band+1)
			log.Error(err)
			return err
		}
	}
	for band := 0; band < MssBands; band++ {
		// The applied shift is the same for every pixel, so convert it once.
		offX := float32(r.deltaXCoeffs[band][0])
		offY := float32(r.deltaYCoeffs[band][0])
		mapX := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV32FC1)
		mapY := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV32FC1)
		for y := 0; y < r.MssHeight; y++ {
			srcY := float32(y) + offY
			for x := 0; x < r.MssWidth; x++ {
				mapX.SetFloatAt(y, x, float32(x)+offX)
				mapY.SetFloatAt(y, x, srcY)
			}
		}
		log.Println("co-registration for band", band+1)
		// Release the result of a previous run before replacing it.
		r.registeredMssImages[band].Close()
		r.registeredMssImages[band] = gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV16UC1)
		gocv.Remap(r.MssImages[band],
			&r.registeredMssImages[band],
			&mapX, &mapY,
			gocv.InterpolationCubic,
			gocv.BorderConstant,
			color.RGBA{0, 0, 0, 0})
		mapX.Close()
		mapY.Close()
	}
	return nil
}