This commit is contained in:
nuknal
2024-05-25 19:43:17 +08:00
parent b9f61cd26f
commit e3d98cb959
6 changed files with 2122 additions and 166 deletions

View File

@@ -4,10 +4,9 @@ import (
"bufio"
"fmt"
"image"
"math"
"os"
"github.com/chai2010/tiff"
"github.com/airbusgeo/godal"
log "github.com/sirupsen/logrus"
"gocv.io/x/gocv"
)
@@ -36,9 +35,10 @@ type Registrator struct {
MssHeight int
MssWidth int
phaseShifts [4][]*PhaseShiftM
registeredImages [4]gocv.Mat
rgbirImage gocv.Mat
phaseShifts [4][]*PhaseShiftM
registeredMssImages [4]gocv.Mat // 平移处理后升采样到PAN分辨率的MSS图像
rgbirImage gocv.Mat
resampleMethod ResampleMethod
}
@@ -66,6 +66,21 @@ func (r *Registrator) LoadPanRaw(raw string) error {
r.PanHeight = height
r.PanWidth = PanWidth
godal.RegisterAll()
hDriver, ok := godal.RasterDriver("Gtiff")
if !ok {
panic("Gtiff not found")
}
md := hDriver.Metadatas()
if md["DCAP_CREATE"] == "YES" {
fmt.Printf("Driver GTiff supports Create() method.\n")
}
if md["DCAP_CREATECOPY"] == "YES" {
fmt.Printf("Driver Gtiff supports CreateCopy() method.\n")
}
fmt.Println("Gtiff driver name:", hDriver.LongName(), hDriver.ShortName())
return nil
}
@@ -128,11 +143,7 @@ func (r *Registrator) DoDownPhaseCorrelation() error {
// 分块高度
blockHeight := r.MssHeight / BlockNH
alignedMssData := make([][]byte, MssBands)
for band := 0; band < MssBands; band++ {
alignedMSSImage := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV16U)
for bh := 0; bh < BlockNH; bh++ {
// 读取 PAN 和 MSS 分块数据
y1 := (bh + 1) * blockHeight
@@ -156,153 +167,48 @@ func (r *Registrator) DoDownPhaseCorrelation() error {
mssBlock := r.MssImages[band].Region(rect)
// 处理每个分块
alignedBlock, phaseShift := r.processBlock(panBlock, mssBlock)
phaseShift := r.calculateBlockPhaseShift(panBlock, mssBlock)
shiftM.dx = phaseShift.X
shiftM.dy = phaseShift.Y
r.phaseShifts[band] = append(r.phaseShifts[band], &shiftM)
// alignedBlockData := alignedBlock.ToBytes()
// alignedMssData[band] = append(alignedMssData[band], alignedBlockData...)
// if alignedMSSImage.Empty() {
// alignedMSSImage = alignedBlock.Clone()
// } else {
// gocv.Vconcat(alignedMSSImage, alignedBlock, &alignedMSSImage)
// alignedBlock.Close()
// }
panBlock.Close()
mssBlock.Close()
alignedBlock.Close()
}
r.registeredImages[band] = alignedMSSImage
}
// 使用平均偏移量来做平移变换
for band := 0; band < MssBands; band++ {
var efficientShiftM int
var xTotal, yTotal float32
for _, shift := range r.phaseShifts[band] {
if math.IsNaN(float64(shift.dx)) || math.IsNaN(float64(shift.dy)) {
continue
}
efficientShiftM += 1
xTotal += shift.dx
yTotal += shift.dy
}
dx := xTotal / float32(efficientShiftM)
dy := yTotal / float32(efficientShiftM)
log.Println("Band", band+1, "average shift:", dx, dy, "efficientShiftM:", efficientShiftM)
translationMat := gocv.NewMatWithSize(2, 3, gocv.MatTypeCV32F)
translationMat.SetFloatAt(0, 0, 1)
translationMat.SetFloatAt(0, 1, 0)
translationMat.SetFloatAt(0, 2, dx)
translationMat.SetFloatAt(1, 0, 0)
translationMat.SetFloatAt(1, 1, 1)
translationMat.SetFloatAt(1, 2, dy)
alignedMss := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV32FC1)
cvtMss := gocv.NewMat()
r.MssImages[band].ConvertTo(&cvtMss, gocv.MatTypeCV32FC1)
// 手动平移像素
outY := math.MaxInt
for y := 0; y < r.MssHeight; y++ {
for x := 0; x < r.MssWidth; x++ {
// 计算新的坐标
newX := x + int(dx)
newY := y + int(dy)
// 如果新坐标在图像范围内,进行像素值赋值
if newX >= 0 && newX < r.MssWidth && newY >= 0 && newY < r.MssHeight {
alignedMss.SetFloatAt(y, x, cvtMss.GetFloatAt(newY, newX))
} else {
// 如果新坐标不在图像范围内,设置为黑色
alignedMss.SetFloatAt(y, x, 0)
if outY > y {
outY = y
log.Println("Warning: pixel out of range", x, y)
}
}
}
}
// gocv.WarpAffine(cvtMss, &alignedMss, translationMat, image.Pt(cvtMss.Size()[1], cvtMss.Size()[0]))
alignedMss.ConvertTo(&alignedMss, gocv.MatTypeCV16U)
alignedMssData[band] = append(alignedMssData[band], alignedMss.ToBytes()...)
translationMat.Close()
cvtMss.Close()
alignedMss.Close()
alignedMssData, err := r.DoMssPhaseShift()
if err != nil {
log.Error("Error calculating MSS phase shift: ", err)
return err
}
r.mssToRaw(alignedMssData)
// r.bytesToRaw(r.PanImage.ToBytes(), "data/PAN.RAW")
// r.SavePanToGDALGTiff("data/pan.tiff")
r.SaveRegisteredMssToGDALGTiff("data/registered_mss.tiff")
return nil
}
func (r *Registrator) panToTiff(panImage gocv.Mat, filePath string) error {
return nil
}
func (r *Registrator) mssToTiff(registeredImages [4]gocv.Mat, filePath string) error {
// 创建合并后的图像RGBIR
rgbirImage := gocv.NewMatWithSize(r.PanHeight, r.PanWidth, gocv.MatTypeCV16UC4) // 4通道16位
for y := 0; y < r.PanHeight; y++ {
for x := 0; x < r.PanWidth; x++ {
r := registeredImages[0].GetShortAt(y, x)
g := registeredImages[1].GetShortAt(y, x)
b := registeredImages[2].GetShortAt(y, x)
ir := registeredImages[3].GetShortAt(y, x)
rgbirImage.SetShortAt(y, x*4+0, r)
rgbirImage.SetShortAt(y, x*4+1, g)
rgbirImage.SetShortAt(y, x*4+2, b)
rgbirImage.SetShortAt(y, x*4+3, ir)
}
}
// 将合并后的图像保存为TIFF文件
fileName := "data/registered_rgbir.tiff"
tiffFile, err := os.Create(fileName)
if err != nil {
fmt.Println("Error creating TIFF file:", err)
return err
}
defer tiffFile.Close()
// 使用tiff库保存图像
img, err := rgbirImage.ToImage()
if err != nil {
fmt.Println("Error converting Mat to image:", err)
return err
}
if err := tiff.Encode(tiffFile, img, nil); err != nil {
fmt.Println("Error encoding TIFF file:", err)
return err
}
fmt.Println("Saved", fileName)
return nil
}
func (r *Registrator) mssToRaw(mssData [][]byte) error {
f, err := os.OpenFile("data/downsampled_registered_mss.RAW", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
if err != nil {
return err
}
width := r.MssWidth * PixelBytes
height := r.MssHeight
log.Println("Writing downsampled registered MSS to RAW file...", len(mssData[0])*4)
log.Println("width:", r.MssWidth*PixelBytes*4)
log.Println("height:", r.MssHeight)
log.Println("width:", width)
log.Println("height:", height)
w := bufio.NewWriter(f)
for row := 0; row < r.MssHeight; row++ {
w.Write(mssData[0][row*r.MssWidth*PixelBytes : (row+1)*r.MssWidth*PixelBytes])
w.Write(mssData[1][row*r.MssWidth*PixelBytes : (row+1)*r.MssWidth*PixelBytes])
w.Write(mssData[2][row*r.MssWidth*PixelBytes : (row+1)*r.MssWidth*PixelBytes])
w.Write(mssData[3][row*r.MssWidth*PixelBytes : (row+1)*r.MssWidth*PixelBytes])
for row := 0; row < height; row++ {
w.Write(mssData[0][row*width : (row+1)*width])
w.Write(mssData[1][row*width : (row+1)*width])
w.Write(mssData[2][row*width : (row+1)*width])
w.Write(mssData[3][row*width : (row+1)*width])
}
return nil
@@ -318,3 +224,108 @@ func (r *Registrator) bytesToRaw(mssData []byte, filePath string) error {
return nil
}
func (r *Registrator) SaveRegisteredMssToGDALGTiff(tiffFile string) error {
log.Println("Saving registered MSS to TIFF file...")
width := r.MssWidth
height := r.MssHeight
// 创建合并后的图像RGBIR
r.rgbirImage = gocv.NewMatWithSize(height, width, gocv.MatTypeCV16UC4) // 4通道16位
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
red := r.registeredMssImages[0].GetShortAt(y, x)
green := r.registeredMssImages[1].GetShortAt(y, x)
blue := r.registeredMssImages[2].GetShortAt(y, x)
ir := r.registeredMssImages[3].GetShortAt(y, x)
r.rgbirImage.SetShortAt(y, x*4+0, red)
r.rgbirImage.SetShortAt(y, x*4+1, green)
r.rgbirImage.SetShortAt(y, x*4+2, blue)
r.rgbirImage.SetShortAt(y, x*4+3, ir)
}
}
// 创建一个二维切片来存储图像数据
data := make([][]uint16, MssBands)
for i := range data {
data[i] = make([]uint16, width*height)
}
// 从gocv.Mat中提取数据
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
for b := 0; b < MssBands; b++ {
data[b][y*width+x] = uint16(r.rgbirImage.GetShortAt(y, x*4+b))
}
}
}
ds, err := godal.Create(godal.GTiff,
tiffFile,
MssBands,
godal.UInt16,
width, height)
if err != nil {
log.Error("Error creating TIFF file: ", err)
return err
}
defer ds.Close()
for b := 0; b < MssBands; b++ {
band := ds.Bands()[b]
err := band.IO(godal.IOWrite,
0, 0,
data[b],
width, height,
godal.PixelSpacing(2),
godal.LineSpacing(width*2))
if err != nil {
log.Error("Failed to write data to band:", err)
return err
}
}
log.Info("saved registered mss to tiff file: ", tiffFile)
return nil
}
func (r *Registrator) SavePanToGDALGTiff(tiffFile string) error {
log.Println("Saving PAN image to TIFF file...")
width := r.PanWidth
height := r.PanHeight
ds, err := godal.Create(godal.GTiff, tiffFile, 1, godal.UInt16, width, height)
if err != nil {
log.Error("Error creating TIFF file: ", err)
return err
}
defer ds.Close()
// 将通道的数据转换为字节数组
data := make([]uint16, width*height)
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
data[y*width+x] = uint16(r.PanImage.GetShortAt(y, x))
}
}
band := ds.Bands()[0]
err = band.IO(godal.IOWrite,
0, 0,
data,
width, height,
godal.PixelSpacing(2),
godal.LineSpacing(width*2))
if err != nil {
log.Error("Failed to write data to band:", err)
return err
}
log.Info("saved pan image to tiff file: ", tiffFile)
return nil
}