319 lines
7.7 KiB
Go
319 lines
7.7 KiB
Go
package imageproc
|
||
|
||
import (
|
||
"bufio"
|
||
"fmt"
|
||
"image"
|
||
"os"
|
||
|
||
"github.com/airbusgeo/godal"
|
||
log "github.com/sirupsen/logrus"
|
||
"gocv.io/x/gocv"
|
||
)
|
||
|
||
// Registrate is an empty interface type; no methods are declared on it in
// this file. NOTE(review): it is not referenced by the code visible here —
// confirm whether it is still needed.
type Registrate interface{}
|
||
|
||
// ResampleMethod names the strategy used to bring the PAN and MSS images onto
// a common grid before registration (see the DownSampled/UpSampled values).
type ResampleMethod string

// Raw input geometry.
const (
	// MssBands is the number of spectral bands stored in an MSS raw file.
	MssBands = 4
	// PixelBytes is the size in bytes of one sample (images are loaded as CV16U).
	PixelBytes = 2
	// PanWidth is the PAN image width in pixels (four times MssWidth).
	PanWidth = 9344
	// MssWidth is the MSS image width in pixels.
	MssWidth = 2336
)

// Block grid used for phase correlation.
const (
	// BlockNH is the number of horizontal strips the image height is split into.
	BlockNH = 5
	// BlockNW is presumably the number of blocks across the width; it is not
	// used by the code in this file.
	BlockNW = 10
)

// Available resampling strategies.
const (
	DownSampled ResampleMethod = "down_sample_pan"
	UpSampled   ResampleMethod = "up_sample_mss"
)
|
||
|
||
type Registrator struct {
|
||
PanImage gocv.Mat
|
||
PanHeight int
|
||
PanWidth int
|
||
|
||
MssImages [4]gocv.Mat
|
||
MssHeight int
|
||
MssWidth int
|
||
|
||
phaseShifts [4][]*PhaseShiftM
|
||
|
||
registeredMssImages [4]gocv.Mat // 平移处理后升采样到PAN分辨率的MSS图像
|
||
rgbirImage gocv.Mat
|
||
|
||
resampleMethod ResampleMethod
|
||
}
|
||
|
||
func NewRegistrator() *Registrator {
|
||
var r Registrator
|
||
|
||
return &r
|
||
}
|
||
|
||
func (r *Registrator) LoadPanRaw(raw string) error {
|
||
data, err := os.ReadFile(raw)
|
||
if err != nil {
|
||
log.Error("Error reading raw file: ", err)
|
||
return err
|
||
}
|
||
|
||
height := len(data) / (PanWidth * PixelBytes)
|
||
r.PanImage, err = gocv.NewMatFromBytes(height, PanWidth, gocv.MatTypeCV16U, data)
|
||
if err != nil {
|
||
log.Error("Error creating Mat from bytes: ", err)
|
||
return err
|
||
}
|
||
|
||
r.PanHeight = height
|
||
r.PanWidth = PanWidth
|
||
|
||
godal.RegisterAll()
|
||
hDriver, ok := godal.RasterDriver("Gtiff")
|
||
if !ok {
|
||
panic("Gtiff not found")
|
||
}
|
||
md := hDriver.Metadatas()
|
||
if md["DCAP_CREATE"] == "YES" {
|
||
fmt.Printf("Driver GTiff supports Create() method.\n")
|
||
}
|
||
if md["DCAP_CREATECOPY"] == "YES" {
|
||
fmt.Printf("Driver Gtiff supports CreateCopy() method.\n")
|
||
}
|
||
|
||
fmt.Println("Gtiff driver name:", hDriver.LongName(), hDriver.ShortName())
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) LoadMssRaw(raw string) error {
|
||
data, err := os.ReadFile(raw)
|
||
if err != nil {
|
||
log.Error("Error reading raw file: ", err)
|
||
return err
|
||
}
|
||
|
||
height := len(data) / (MssWidth * PixelBytes * MssBands)
|
||
mssData := make([][]byte, MssBands)
|
||
for h := 0; h < height; h++ {
|
||
row := data[h*MssWidth*MssBands*PixelBytes : (h+1)*MssWidth*MssBands*PixelBytes]
|
||
for i := 0; i < MssBands; i++ {
|
||
mssData[i] = append(mssData[i], row[i*MssWidth*PixelBytes:(i+1)*MssWidth*PixelBytes]...)
|
||
}
|
||
}
|
||
|
||
for i := 0; i < MssBands; i++ {
|
||
r.MssImages[i], err = gocv.NewMatFromBytes(height, MssWidth, gocv.MatTypeCV16U, mssData[i])
|
||
if err != nil {
|
||
log.Error("Error creating Mat from bytes: ", err)
|
||
return err
|
||
}
|
||
}
|
||
|
||
r.MssHeight = height
|
||
r.MssWidth = MssWidth
|
||
|
||
return nil
|
||
}
|
||
|
||
// 将PAN降采样后计算相位相关的偏移量
|
||
func (r *Registrator) DoDownPhaseCorrelation() error {
|
||
// 确保 MSS 高度是 PAN 高度的 1/4
|
||
if r.MssHeight*4 != r.PanHeight {
|
||
err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
|
||
log.Error(err)
|
||
return err
|
||
}
|
||
|
||
// 将PAN将采样作为轮廓匹配基准图像
|
||
downsampledPanImage := gocv.NewMat()
|
||
gocv.Resize(r.PanImage, &downsampledPanImage,
|
||
image.Point{X: r.MssWidth, Y: r.MssHeight}, 0, 0, gocv.InterpolationLinear)
|
||
fmt.Println("down sampled PAN images size:", downsampledPanImage.Size())
|
||
|
||
// 分块高度
|
||
blockHeight := r.MssHeight / BlockNH
|
||
for band := 0; band < MssBands; band++ {
|
||
for bh := 0; bh < BlockNH; bh++ {
|
||
// 读取 PAN 和 MSS 分块数据
|
||
y1 := (bh + 1) * blockHeight
|
||
if bh == BlockNH-1 {
|
||
y1 = r.MssHeight
|
||
}
|
||
|
||
var shiftM PhaseShiftM
|
||
shiftM.Block.width = r.MssWidth // 块宽度
|
||
shiftM.Block.height = y1 - bh*blockHeight // 块高度
|
||
shiftM.Block.coord.X = 0 // 块左上角x坐标
|
||
shiftM.Block.coord.Y = bh * blockHeight // 块左上角y坐标
|
||
|
||
rect := image.Rect(
|
||
shiftM.Block.coord.X, shiftM.Block.coord.Y,
|
||
shiftM.Block.coord.X+shiftM.Block.width, shiftM.Block.coord.Y+shiftM.Block.height,
|
||
)
|
||
log.Println("Band", band+1, ", processing block", bh, rect)
|
||
|
||
panBlock := downsampledPanImage.Region(rect)
|
||
mssBlock := r.MssImages[band].Region(rect)
|
||
|
||
// 处理每个分块
|
||
phaseShift := r.calculateBlockPhaseShift(panBlock, mssBlock)
|
||
shiftM.dx = phaseShift.X
|
||
shiftM.dy = phaseShift.Y
|
||
|
||
r.phaseShifts[band] = append(r.phaseShifts[band], &shiftM)
|
||
|
||
panBlock.Close()
|
||
mssBlock.Close()
|
||
}
|
||
}
|
||
|
||
if err := r.DoMssPhaseShift(); err != nil {
|
||
log.Error("Error calculating MSS phase shift: ", err)
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) SaveRegisteredMssToRaw(raw string) error {
|
||
f, err := os.OpenFile(raw, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
var mssData [4][]byte
|
||
for i := 0; i < MssBands; i++ {
|
||
mssData[i] = r.registeredMssImages[i].ToBytes()
|
||
}
|
||
|
||
width := r.MssWidth * PixelBytes
|
||
height := r.MssHeight
|
||
log.Println("Writing registered MSS to RAW file...", len(mssData[0])*4)
|
||
log.Println("width:", width)
|
||
log.Println("height:", height)
|
||
w := bufio.NewWriter(f)
|
||
for row := 0; row < height; row++ {
|
||
w.Write(mssData[0][row*width : (row+1)*width])
|
||
w.Write(mssData[1][row*width : (row+1)*width])
|
||
w.Write(mssData[2][row*width : (row+1)*width])
|
||
w.Write(mssData[3][row*width : (row+1)*width])
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) bytesToRaw(mssData []byte, filePath string) error {
|
||
f, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
w := bufio.NewWriter(f)
|
||
w.Write(mssData)
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) SaveRegisteredMssToGDALGTiff(tiffFile string) error {
|
||
log.Println("Saving registered MSS to TIFF file:", tiffFile)
|
||
|
||
width := r.MssWidth
|
||
height := r.MssHeight
|
||
|
||
// 创建合并后的图像(RGBIR)
|
||
r.rgbirImage = gocv.NewMatWithSize(height, width, gocv.MatTypeCV16UC4) // 4通道,16位
|
||
|
||
for y := 0; y < height; y++ {
|
||
for x := 0; x < width; x++ {
|
||
red := r.registeredMssImages[0].GetShortAt(y, x)
|
||
green := r.registeredMssImages[1].GetShortAt(y, x)
|
||
blue := r.registeredMssImages[2].GetShortAt(y, x)
|
||
ir := r.registeredMssImages[3].GetShortAt(y, x)
|
||
r.rgbirImage.SetShortAt(y, x*4+0, red)
|
||
r.rgbirImage.SetShortAt(y, x*4+1, green)
|
||
r.rgbirImage.SetShortAt(y, x*4+2, blue)
|
||
r.rgbirImage.SetShortAt(y, x*4+3, ir)
|
||
}
|
||
}
|
||
|
||
// 创建一个二维切片来存储图像数据
|
||
data := make([][]uint16, MssBands)
|
||
for i := range data {
|
||
data[i] = make([]uint16, width*height)
|
||
}
|
||
|
||
// 从gocv.Mat中提取数据
|
||
for y := 0; y < height; y++ {
|
||
for x := 0; x < width; x++ {
|
||
for b := 0; b < MssBands; b++ {
|
||
data[b][y*width+x] = uint16(r.rgbirImage.GetShortAt(y, x*4+b))
|
||
}
|
||
}
|
||
}
|
||
|
||
ds, err := godal.Create(godal.GTiff,
|
||
tiffFile,
|
||
MssBands,
|
||
godal.UInt16,
|
||
width, height)
|
||
if err != nil {
|
||
log.Error("Error creating TIFF file: ", err)
|
||
return err
|
||
}
|
||
defer ds.Close()
|
||
|
||
for b := 0; b < MssBands; b++ {
|
||
band := ds.Bands()[b]
|
||
err := band.IO(godal.IOWrite,
|
||
0, 0,
|
||
data[b],
|
||
width, height,
|
||
godal.PixelSpacing(2),
|
||
godal.LineSpacing(width*2))
|
||
if err != nil {
|
||
log.Error("Failed to write data to band:", err)
|
||
return err
|
||
}
|
||
}
|
||
|
||
log.Info("Saved registered mss to ", tiffFile)
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) SavePanToGDALGTiff(tiffFile string) error {
|
||
log.Println("Saving PAN image to TIFF file:", tiffFile)
|
||
|
||
width := r.PanWidth
|
||
height := r.PanHeight
|
||
|
||
ds, err := godal.Create(godal.GTiff, tiffFile, 1, godal.UInt16, width, height)
|
||
if err != nil {
|
||
log.Error("Error creating TIFF file: ", err)
|
||
return err
|
||
}
|
||
defer ds.Close()
|
||
|
||
// 将通道的数据转换为字节数组
|
||
data := make([]uint16, width*height)
|
||
for y := 0; y < height; y++ {
|
||
for x := 0; x < width; x++ {
|
||
data[y*width+x] = uint16(r.PanImage.GetShortAt(y, x))
|
||
}
|
||
}
|
||
|
||
band := ds.Bands()[0]
|
||
err = band.IO(godal.IOWrite,
|
||
0, 0,
|
||
data,
|
||
width, height,
|
||
godal.PixelSpacing(2),
|
||
godal.LineSpacing(width*2))
|
||
if err != nil {
|
||
log.Error("Failed to write data to band:", err)
|
||
return err
|
||
}
|
||
|
||
log.Info("Saved pan image to ", tiffFile)
|
||
|
||
return nil
|
||
}
|