package main import ( "bufio" "fmt" "image" "os" "github.com/airbusgeo/godal" log "github.com/sirupsen/logrus" "gocv.io/x/gocv" ) type Registrate interface{} const ( MssBands = 4 PixelBytes = 2 PanWidth = 9344 // 像素宽度 MssWidth = 2336 BlockNH = 5 BlockNW = 10 DownSampled ResampleMethod = "down_sample_pan" UpSampled ResampleMethod = "up_sample_mss" ) type ResampleMethod string type Registrator struct { PanImage gocv.Mat PanHeight int PanWidth int MssImages [4]gocv.Mat MssHeight int MssWidth int phaseShifts [4][]*PhaseShiftM registeredMssImages [4]gocv.Mat // 平移处理后升采样到PAN分辨率的MSS图像 rgbirImage gocv.Mat resampleMethod ResampleMethod } func NewRegistrator() *Registrator { var r Registrator return &r } func (r *Registrator) LoadPanRaw(raw string) error { data, err := os.ReadFile(raw) if err != nil { log.Error("Error reading raw file: ", err) return err } height := len(data) / (PanWidth * PixelBytes) r.PanImage, err = gocv.NewMatFromBytes(height, PanWidth, gocv.MatTypeCV16U, data) if err != nil { log.Error("Error creating Mat from bytes: ", err) return err } r.PanHeight = height r.PanWidth = PanWidth godal.RegisterAll() hDriver, ok := godal.RasterDriver("Gtiff") if !ok { panic("Gtiff not found") } md := hDriver.Metadatas() if md["DCAP_CREATE"] == "YES" { fmt.Printf("Driver GTiff supports Create() method.\n") } if md["DCAP_CREATECOPY"] == "YES" { fmt.Printf("Driver Gtiff supports CreateCopy() method.\n") } fmt.Println("Gtiff driver name:", hDriver.LongName(), hDriver.ShortName()) return nil } func (r *Registrator) LoadMssRaw(raw string) error { data, err := os.ReadFile(raw) if err != nil { log.Error("Error reading raw file: ", err) return err } height := len(data) / (MssWidth * PixelBytes * MssBands) mssData := make([][]byte, MssBands) for h := 0; h < height; h++ { row := data[h*MssWidth*MssBands*PixelBytes : (h+1)*MssWidth*MssBands*PixelBytes] for i := 0; i < MssBands; i++ { mssData[i] = append(mssData[i], row[i*MssWidth*PixelBytes:(i+1)*MssWidth*PixelBytes]...) } } for i := 0; i < MssBands; i++ { r.MssImages[i], err = gocv.NewMatFromBytes(height, MssWidth, gocv.MatTypeCV16U, mssData[i]) if err != nil { log.Error("Error creating Mat from bytes: ", err) return err } } r.MssHeight = height r.MssWidth = MssWidth return nil } // 将PAN降采样后计算相位相关的偏移量 func (r *Registrator) DoDownPhaseCorrelation() error { // 确保 MSS 高度是 PAN 高度的 1/4 if r.MssHeight*4 != r.PanHeight { err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file") log.Error(err) return err } // 将PAN将采样作为轮廓匹配基准图像 downsampledPanImage := gocv.NewMat() gocv.Resize(r.PanImage, &downsampledPanImage, image.Point{X: r.MssWidth, Y: r.MssHeight}, 0, 0, gocv.InterpolationLinear) // 对每个 MSS 波段图像进行上采样 // upsampledMSSImages := make([]gocv.Mat, MssBands) for i := 0; i < MssBands; i++ { // upsampledMSSImages[i] = gocv.NewMat() // gocv.Resize(r.MssImages[i], &upsampledMSSImages[i], // image.Point{X: r.PanWidth, Y: r.PanHeight}, 0, 0, gocv.InterpolationLinear) // r.msToRaw(upsampledMSSImages[i].ToBytes(), fmt.Sprintf("MSS%d.RAW", i+1)) } fmt.Println("down sampled PAN images size:", downsampledPanImage.Size()) // 分块高度 blockHeight := r.MssHeight / BlockNH for band := 0; band < MssBands; band++ { for bh := 0; bh < BlockNH; bh++ { // 读取 PAN 和 MSS 分块数据 y1 := (bh + 1) * blockHeight if bh == BlockNH-1 { y1 = r.MssHeight } var shiftM PhaseShiftM shiftM.Block.width = r.MssWidth // 块宽度 shiftM.Block.height = y1 - bh*blockHeight // 块高度 shiftM.Block.coord.X = 0 // 块左上角x坐标 shiftM.Block.coord.Y = bh * blockHeight // 块左上角y坐标 rect := image.Rect( shiftM.Block.coord.X, shiftM.Block.coord.Y, shiftM.Block.coord.X+shiftM.Block.width, shiftM.Block.coord.Y+shiftM.Block.height, ) log.Println("Band", band+1, ", processing block", bh, rect) panBlock := downsampledPanImage.Region(rect) mssBlock := r.MssImages[band].Region(rect) // 处理每个分块 phaseShift := r.calculateBlockPhaseShift(panBlock, mssBlock) shiftM.dx = phaseShift.X shiftM.dy = phaseShift.Y r.phaseShifts[band] = append(r.phaseShifts[band], &shiftM) panBlock.Close() mssBlock.Close() } } alignedMssData, err := r.DoMssPhaseShift() if err != nil { log.Error("Error calculating MSS phase shift: ", err) return err } r.mssToRaw(alignedMssData) // r.bytesToRaw(r.PanImage.ToBytes(), "data/PAN.RAW") // r.SavePanToGDALGTiff("data/pan.tiff") r.SaveRegisteredMssToGDALGTiff("data/registered_mss.tiff") return nil } func (r *Registrator) mssToRaw(mssData [][]byte) error { f, err := os.OpenFile("data/downsampled_registered_mss.RAW", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777) if err != nil { return err } width := r.MssWidth * PixelBytes height := r.MssHeight log.Println("Writing downsampled registered MSS to RAW file...", len(mssData[0])*4) log.Println("width:", width) log.Println("height:", height) w := bufio.NewWriter(f) for row := 0; row < height; row++ { w.Write(mssData[0][row*width : (row+1)*width]) w.Write(mssData[1][row*width : (row+1)*width]) w.Write(mssData[2][row*width : (row+1)*width]) w.Write(mssData[3][row*width : (row+1)*width]) } return nil } func (r *Registrator) bytesToRaw(mssData []byte, filePath string) error { f, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777) if err != nil { return err } w := bufio.NewWriter(f) w.Write(mssData) return nil } func (r *Registrator) SaveRegisteredMssToGDALGTiff(tiffFile string) error { log.Println("Saving registered MSS to TIFF file...") width := r.MssWidth height := r.MssHeight // 创建合并后的图像(RGBIR) r.rgbirImage = gocv.NewMatWithSize(height, width, gocv.MatTypeCV16UC4) // 4通道,16位 for y := 0; y < height; y++ { for x := 0; x < width; x++ { red := r.registeredMssImages[0].GetShortAt(y, x) green := r.registeredMssImages[1].GetShortAt(y, x) blue := r.registeredMssImages[2].GetShortAt(y, x) ir := r.registeredMssImages[3].GetShortAt(y, x) r.rgbirImage.SetShortAt(y, x*4+0, red) r.rgbirImage.SetShortAt(y, x*4+1, green) r.rgbirImage.SetShortAt(y, x*4+2, blue) r.rgbirImage.SetShortAt(y, x*4+3, ir) } } // 创建一个二维切片来存储图像数据 data := make([][]uint16, MssBands) for i := range data { data[i] = make([]uint16, width*height) } // 从gocv.Mat中提取数据 for y := 0; y < height; y++ { for x := 0; x < width; x++ { for b := 0; b < MssBands; b++ { data[b][y*width+x] = uint16(r.rgbirImage.GetShortAt(y, x*4+b)) } } } ds, err := godal.Create(godal.GTiff, tiffFile, MssBands, godal.UInt16, width, height) if err != nil { log.Error("Error creating TIFF file: ", err) return err } defer ds.Close() for b := 0; b < MssBands; b++ { band := ds.Bands()[b] err := band.IO(godal.IOWrite, 0, 0, data[b], width, height, godal.PixelSpacing(2), godal.LineSpacing(width*2)) if err != nil { log.Error("Failed to write data to band:", err) return err } } log.Info("saved registered mss to tiff file: ", tiffFile) return nil } func (r *Registrator) SavePanToGDALGTiff(tiffFile string) error { log.Println("Saving PAN image to TIFF file...") width := r.PanWidth height := r.PanHeight ds, err := godal.Create(godal.GTiff, tiffFile, 1, godal.UInt16, width, height) if err != nil { log.Error("Error creating TIFF file: ", err) return err } defer ds.Close() // 将通道的数据转换为字节数组 data := make([]uint16, width*height) for y := 0; y < height; y++ { for x := 0; x < width; x++ { data[y*width+x] = uint16(r.PanImage.GetShortAt(y, x)) } } band := ds.Bands()[0] err = band.IO(godal.IOWrite, 0, 0, data, width, height, godal.PixelSpacing(2), godal.LineSpacing(width*2)) if err != nil { log.Error("Failed to write data to band:", err) return err } log.Info("saved pan image to tiff file: ", tiffFile) return nil }