321 lines
8.8 KiB
Go
321 lines
8.8 KiB
Go
package main
|
||
|
||
import (
|
||
"bufio"
|
||
"fmt"
|
||
"image"
|
||
"math"
|
||
"os"
|
||
|
||
"github.com/chai2010/tiff"
|
||
log "github.com/sirupsen/logrus"
|
||
"gocv.io/x/gocv"
|
||
)
|
||
|
||
// Registrate is an empty interface type: it declares no methods, so every
// value satisfies it. It is not referenced by the code visible in this file.
type Registrate interface{}
|
||
|
||
// ResampleMethod names the strategy used to bring PAN and MSS imagery to a
// common resolution before registration.
type ResampleMethod string

// Raw file geometry. Samples are PixelBytes wide (loaded as 16-bit Mats);
// an MSS raw file carries MssBands bands per row.
const (
	PixelBytes = 2
	MssBands   = 4
	PanWidth   = 9344 // PAN width in pixels
	MssWidth   = 2336 // MSS width in pixels (one quarter of PanWidth)
)

// Block grid for per-block phase correlation. BlockNH is the number of
// horizontal strips; BlockNW is not referenced by the visible code
// (presumably the intended number of columns).
const (
	BlockNW = 10
	BlockNH = 5
)

// Available resampling strategies.
const (
	UpSampled   ResampleMethod = "up_sample_mss"
	DownSampled ResampleMethod = "down_sample_pan"
)
|
||
|
||
type Registrator struct {
|
||
PanImage gocv.Mat
|
||
PanHeight int
|
||
PanWidth int
|
||
|
||
MssImages [4]gocv.Mat
|
||
MssHeight int
|
||
MssWidth int
|
||
|
||
phaseShifts [4][]*PhaseShiftM
|
||
registeredImages [4]gocv.Mat
|
||
rgbirImage gocv.Mat
|
||
|
||
resampleMethod ResampleMethod
|
||
}
|
||
|
||
func NewRegistrator() *Registrator {
|
||
var r Registrator
|
||
|
||
return &r
|
||
}
|
||
|
||
func (r *Registrator) LoadPanRaw(raw string) error {
|
||
data, err := os.ReadFile(raw)
|
||
if err != nil {
|
||
log.Error("Error reading raw file: ", err)
|
||
return err
|
||
}
|
||
|
||
height := len(data) / (PanWidth * PixelBytes)
|
||
r.PanImage, err = gocv.NewMatFromBytes(height, PanWidth, gocv.MatTypeCV16U, data)
|
||
if err != nil {
|
||
log.Error("Error creating Mat from bytes: ", err)
|
||
return err
|
||
}
|
||
|
||
r.PanHeight = height
|
||
r.PanWidth = PanWidth
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) LoadMssRaw(raw string) error {
|
||
data, err := os.ReadFile(raw)
|
||
if err != nil {
|
||
log.Error("Error reading raw file: ", err)
|
||
return err
|
||
}
|
||
|
||
height := len(data) / (MssWidth * PixelBytes * MssBands)
|
||
mssData := make([][]byte, MssBands)
|
||
for h := 0; h < height; h++ {
|
||
row := data[h*MssWidth*MssBands*PixelBytes : (h+1)*MssWidth*MssBands*PixelBytes]
|
||
for i := 0; i < MssBands; i++ {
|
||
mssData[i] = append(mssData[i], row[i*MssWidth*PixelBytes:(i+1)*MssWidth*PixelBytes]...)
|
||
}
|
||
}
|
||
|
||
for i := 0; i < MssBands; i++ {
|
||
r.MssImages[i], err = gocv.NewMatFromBytes(height, MssWidth, gocv.MatTypeCV16U, mssData[i])
|
||
if err != nil {
|
||
log.Error("Error creating Mat from bytes: ", err)
|
||
return err
|
||
}
|
||
}
|
||
|
||
r.MssHeight = height
|
||
r.MssWidth = MssWidth
|
||
|
||
return nil
|
||
}
|
||
|
||
// 将PAN降采样后计算相位相关的偏移量
|
||
func (r *Registrator) DoDownPhaseCorrelation() error {
|
||
// 确保 MSS 高度是 PAN 高度的 1/4
|
||
if r.MssHeight*4 != r.PanHeight {
|
||
err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
|
||
log.Error(err)
|
||
return err
|
||
}
|
||
|
||
// 将PAN将采样作为轮廓匹配基准图像
|
||
downsampledPanImage := gocv.NewMat()
|
||
gocv.Resize(r.PanImage, &downsampledPanImage,
|
||
image.Point{X: r.MssWidth, Y: r.MssHeight}, 0, 0, gocv.InterpolationLinear)
|
||
|
||
// 对每个 MSS 波段图像进行上采样
|
||
// upsampledMSSImages := make([]gocv.Mat, MssBands)
|
||
|
||
for i := 0; i < MssBands; i++ {
|
||
// upsampledMSSImages[i] = gocv.NewMat()
|
||
// gocv.Resize(r.MssImages[i], &upsampledMSSImages[i],
|
||
// image.Point{X: r.PanWidth, Y: r.PanHeight}, 0, 0, gocv.InterpolationLinear)
|
||
|
||
// r.msToRaw(upsampledMSSImages[i].ToBytes(), fmt.Sprintf("MSS%d.RAW", i+1))
|
||
}
|
||
|
||
fmt.Println("down sampled PAN images size:", downsampledPanImage.Size())
|
||
|
||
// 分块高度
|
||
blockHeight := r.MssHeight / BlockNH
|
||
|
||
alignedMssData := make([][]byte, MssBands)
|
||
|
||
for band := 0; band < MssBands; band++ {
|
||
alignedMSSImage := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV16U)
|
||
for bh := 0; bh < BlockNH; bh++ {
|
||
// 读取 PAN 和 MSS 分块数据
|
||
y1 := (bh + 1) * blockHeight
|
||
if bh == BlockNH-1 {
|
||
y1 = r.MssHeight
|
||
}
|
||
|
||
var shiftM PhaseShiftM
|
||
shiftM.Block.width = r.MssWidth // 块宽度
|
||
shiftM.Block.height = y1 - bh*blockHeight // 块高度
|
||
shiftM.Block.coord.X = 0 // 块左上角x坐标
|
||
shiftM.Block.coord.Y = bh * blockHeight // 块左上角y坐标
|
||
|
||
rect := image.Rect(
|
||
shiftM.Block.coord.X, shiftM.Block.coord.Y,
|
||
shiftM.Block.coord.X+shiftM.Block.width, shiftM.Block.coord.Y+shiftM.Block.height,
|
||
)
|
||
log.Println("Band", band+1, ", processing block", bh, rect)
|
||
|
||
panBlock := downsampledPanImage.Region(rect)
|
||
mssBlock := r.MssImages[band].Region(rect)
|
||
|
||
// 处理每个分块
|
||
alignedBlock, phaseShift := r.processBlock(panBlock, mssBlock)
|
||
shiftM.dx = phaseShift.X
|
||
shiftM.dy = phaseShift.Y
|
||
|
||
r.phaseShifts[band] = append(r.phaseShifts[band], &shiftM)
|
||
|
||
// alignedBlockData := alignedBlock.ToBytes()
|
||
// alignedMssData[band] = append(alignedMssData[band], alignedBlockData...)
|
||
|
||
// if alignedMSSImage.Empty() {
|
||
// alignedMSSImage = alignedBlock.Clone()
|
||
// } else {
|
||
// gocv.Vconcat(alignedMSSImage, alignedBlock, &alignedMSSImage)
|
||
// alignedBlock.Close()
|
||
// }
|
||
|
||
panBlock.Close()
|
||
mssBlock.Close()
|
||
alignedBlock.Close()
|
||
}
|
||
|
||
r.registeredImages[band] = alignedMSSImage
|
||
}
|
||
|
||
// 使用平均偏移量来做平移变换
|
||
for band := 0; band < MssBands; band++ {
|
||
var efficientShiftM int
|
||
var xTotal, yTotal float32
|
||
for _, shift := range r.phaseShifts[band] {
|
||
if math.IsNaN(float64(shift.dx)) || math.IsNaN(float64(shift.dy)) {
|
||
continue
|
||
}
|
||
efficientShiftM += 1
|
||
xTotal += shift.dx
|
||
yTotal += shift.dy
|
||
}
|
||
dx := xTotal / float32(efficientShiftM)
|
||
dy := yTotal / float32(efficientShiftM)
|
||
log.Println("Band", band+1, "average shift:", dx, dy, "efficientShiftM:", efficientShiftM)
|
||
|
||
translationMat := gocv.NewMatWithSize(2, 3, gocv.MatTypeCV32F)
|
||
translationMat.SetFloatAt(0, 0, 1)
|
||
translationMat.SetFloatAt(0, 1, 0)
|
||
translationMat.SetFloatAt(0, 2, dx)
|
||
translationMat.SetFloatAt(1, 0, 0)
|
||
translationMat.SetFloatAt(1, 1, 1)
|
||
translationMat.SetFloatAt(1, 2, dy)
|
||
|
||
alignedMss := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV32FC1)
|
||
cvtMss := gocv.NewMat()
|
||
r.MssImages[band].ConvertTo(&cvtMss, gocv.MatTypeCV32FC1)
|
||
// 手动平移像素
|
||
outY := math.MaxInt
|
||
for y := 0; y < r.MssHeight; y++ {
|
||
for x := 0; x < r.MssWidth; x++ {
|
||
// 计算新的坐标
|
||
newX := x + int(dx)
|
||
newY := y + int(dy)
|
||
|
||
// 如果新坐标在图像范围内,进行像素值赋值
|
||
if newX >= 0 && newX < r.MssWidth && newY >= 0 && newY < r.MssHeight {
|
||
alignedMss.SetFloatAt(y, x, cvtMss.GetFloatAt(newY, newX))
|
||
} else {
|
||
// 如果新坐标不在图像范围内,设置为黑色
|
||
alignedMss.SetFloatAt(y, x, 0)
|
||
if outY > y {
|
||
outY = y
|
||
log.Println("Warning: pixel out of range", x, y)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// gocv.WarpAffine(cvtMss, &alignedMss, translationMat, image.Pt(cvtMss.Size()[1], cvtMss.Size()[0]))
|
||
alignedMss.ConvertTo(&alignedMss, gocv.MatTypeCV16U)
|
||
alignedMssData[band] = append(alignedMssData[band], alignedMss.ToBytes()...)
|
||
|
||
translationMat.Close()
|
||
cvtMss.Close()
|
||
alignedMss.Close()
|
||
}
|
||
|
||
r.mssToRaw(alignedMssData)
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) panToTiff(panImage gocv.Mat, filePath string) error {
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) mssToTiff(registeredImages [4]gocv.Mat, filePath string) error {
|
||
// 创建合并后的图像(RGBIR)
|
||
rgbirImage := gocv.NewMatWithSize(r.PanHeight, r.PanWidth, gocv.MatTypeCV16UC4) // 4通道,16位
|
||
|
||
for y := 0; y < r.PanHeight; y++ {
|
||
for x := 0; x < r.PanWidth; x++ {
|
||
r := registeredImages[0].GetShortAt(y, x)
|
||
g := registeredImages[1].GetShortAt(y, x)
|
||
b := registeredImages[2].GetShortAt(y, x)
|
||
ir := registeredImages[3].GetShortAt(y, x)
|
||
rgbirImage.SetShortAt(y, x*4+0, r)
|
||
rgbirImage.SetShortAt(y, x*4+1, g)
|
||
rgbirImage.SetShortAt(y, x*4+2, b)
|
||
rgbirImage.SetShortAt(y, x*4+3, ir)
|
||
}
|
||
}
|
||
|
||
// 将合并后的图像保存为TIFF文件
|
||
fileName := "data/registered_rgbir.tiff"
|
||
tiffFile, err := os.Create(fileName)
|
||
if err != nil {
|
||
fmt.Println("Error creating TIFF file:", err)
|
||
return err
|
||
}
|
||
defer tiffFile.Close()
|
||
|
||
// 使用tiff库保存图像
|
||
img, err := rgbirImage.ToImage()
|
||
if err != nil {
|
||
fmt.Println("Error converting Mat to image:", err)
|
||
return err
|
||
}
|
||
if err := tiff.Encode(tiffFile, img, nil); err != nil {
|
||
fmt.Println("Error encoding TIFF file:", err)
|
||
return err
|
||
}
|
||
|
||
fmt.Println("Saved", fileName)
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) mssToRaw(mssData [][]byte) error {
|
||
f, err := os.OpenFile("data/downsampled_registered_mss.RAW", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
log.Println("Writing downsampled registered MSS to RAW file...", len(mssData[0])*4)
|
||
log.Println("width:", r.MssWidth*PixelBytes*4)
|
||
log.Println("height:", r.MssHeight)
|
||
w := bufio.NewWriter(f)
|
||
for row := 0; row < r.MssHeight; row++ {
|
||
w.Write(mssData[0][row*r.MssWidth*PixelBytes : (row+1)*r.MssWidth*PixelBytes])
|
||
w.Write(mssData[1][row*r.MssWidth*PixelBytes : (row+1)*r.MssWidth*PixelBytes])
|
||
w.Write(mssData[2][row*r.MssWidth*PixelBytes : (row+1)*r.MssWidth*PixelBytes])
|
||
w.Write(mssData[3][row*r.MssWidth*PixelBytes : (row+1)*r.MssWidth*PixelBytes])
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (r *Registrator) bytesToRaw(mssData []byte, filePath string) error {
|
||
f, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
w := bufio.NewWriter(f)
|
||
w.Write(mssData)
|
||
|
||
return nil
|
||
}
|