Files
sjy01-image-proc/image_registration.go
2024-05-25 09:24:19 +08:00

321 lines
8.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bufio"
"fmt"
"image"
"math"
"os"
"github.com/chai2010/tiff"
log "github.com/sirupsen/logrus"
"gocv.io/x/gocv"
)
// Registrate is an empty interface: it declares no methods, and nothing in
// the code visible here refers to it.
type Registrate interface{}
// Fixed geometry of the raw input files, the block grid used while
// registering, and the known ResampleMethod values.
const (
// MssBands is the number of MSS bands stored in an MSS raw file.
MssBands = 4
// PixelBytes is the size in bytes of one pixel sample; the loaders build
// 16-bit (CV16U) Mats from the raw bytes.
PixelBytes = 2
PanWidth = 9344 // PAN width in pixels
// MssWidth is the MSS width in pixels.
MssWidth = 2336
// BlockNH is the number of blocks an image is split into along its height.
BlockNH = 5
// BlockNW is not used by the code visible here — presumably the number of
// blocks along the width; TODO confirm.
BlockNW = 10
// DownSampled and UpSampled are the two ResampleMethod values. Judging by
// their string values they mean "downsample PAN" and "upsample MSS";
// neither is referenced by the code visible here.
DownSampled ResampleMethod = "down_sample_pan"
UpSampled ResampleMethod = "up_sample_mss"
)
// ResampleMethod is the string type of the DownSampled and UpSampled
// constants.
type ResampleMethod string
// Registrator holds the loaded PAN and MSS images together with the
// intermediate results produced while registering the MSS bands against PAN.
type Registrator struct {
// PanImage is the PAN image; LoadPanRaw fills it with a CV16U Mat.
PanImage gocv.Mat
// PanHeight is the PAN row count LoadPanRaw derives from the file size.
PanHeight int
// PanWidth is the PAN column count; LoadPanRaw sets it to the PanWidth constant.
PanWidth int
// MssImages holds one CV16U Mat per MSS band, filled by LoadMssRaw.
MssImages [4]gocv.Mat
// MssHeight is the MSS row count LoadMssRaw derives from the file size.
MssHeight int
// MssWidth is the MSS column count; LoadMssRaw sets it to the MssWidth constant.
MssWidth int
// phaseShifts collects, per band, one entry per processed block; entries
// are appended by DoDownPhaseCorrelation.
phaseShifts [4][]*PhaseShiftM
// registeredImages is assigned per band by DoDownPhaseCorrelation.
registeredImages [4]gocv.Mat
// rgbirImage is not referenced by the code visible here — presumably the
// merged RGB+IR image; TODO confirm.
rgbirImage gocv.Mat
// resampleMethod is not referenced by the code visible here — presumably
// selects between DownSampled and UpSampled; TODO confirm.
resampleMethod ResampleMethod
}
// NewRegistrator allocates a zero-valued Registrator and returns a pointer to
// it. No images are loaded; call LoadPanRaw and LoadMssRaw to populate it.
func NewRegistrator() *Registrator {
	return new(Registrator)
}
// LoadPanRaw reads the headerless raw file at path raw and stores it in
// r.PanImage as a CV16U Mat that is PanWidth columns wide.
//
// The row count is len(data) / (PanWidth * PixelBytes). A file that does not
// contain at least one complete row is rejected with an error. Bytes of a
// trailing incomplete row are ignored, and a warning is logged about them.
//
// r.PanImage, r.PanHeight and r.PanWidth are only modified when loading
// succeeds. Read and Mat-creation errors are logged and returned unchanged.
func (r *Registrator) LoadPanRaw(raw string) error {
	data, err := os.ReadFile(raw)
	if err != nil {
		log.Error("Error reading raw file: ", err)
		return err
	}

	const rowBytes = PanWidth * PixelBytes
	height := len(data) / rowBytes
	if height == 0 {
		// Without this check a short, non-empty file would produce a 0-row Mat
		// and report success.
		err := fmt.Errorf("PAN raw file %q has %d bytes, less than one row of %d bytes",
			raw, len(data), rowBytes)
		log.Error(err)
		return err
	}
	if extra := len(data) % rowBytes; extra != 0 {
		log.Warnf("PAN raw file %q: ignoring %d trailing bytes of an incomplete row", raw, extra)
	}

	// Build into a local first so a failure does not clobber a previously
	// loaded image.
	panImage, err := gocv.NewMatFromBytes(height, PanWidth, gocv.MatTypeCV16U, data)
	if err != nil {
		log.Error("Error creating Mat from bytes: ", err)
		return err
	}

	r.PanImage = panImage
	r.PanHeight = height
	r.PanWidth = PanWidth
	return nil
}
// LoadMssRaw reads the headerless raw file at path raw and splits it into
// MssBands CV16U Mats stored in r.MssImages.
//
// The file is read as consecutive lines of MssBands * MssWidth * PixelBytes
// bytes. Within a line the bands follow each other, each occupying
// MssWidth * PixelBytes bytes, so band i of image row h is the i-th segment of
// line h.
//
// A file that does not contain at least one complete line is rejected with an
// error. Bytes of a trailing incomplete line are ignored, and a warning is
// logged about them.
//
// r.MssImages, r.MssHeight and r.MssWidth are only modified when every band
// was created successfully. Read and Mat-creation errors are logged and
// returned unchanged.
func (r *Registrator) LoadMssRaw(raw string) error {
	data, err := os.ReadFile(raw)
	if err != nil {
		log.Error("Error reading raw file: ", err)
		return err
	}

	const bandRowBytes = MssWidth * PixelBytes
	const lineBytes = bandRowBytes * MssBands
	height := len(data) / lineBytes
	if height == 0 {
		err := fmt.Errorf("MSS raw file %q has %d bytes, less than one line of %d bytes",
			raw, len(data), lineBytes)
		log.Error(err)
		return err
	}
	if extra := len(data) % lineBytes; extra != 0 {
		log.Warnf("MSS raw file %q: ignoring %d trailing bytes of an incomplete line", raw, extra)
	}

	// De-interleave into one exactly-sized buffer per band instead of growing
	// the buffers row by row.
	planes := make([][]byte, MssBands)
	for b := range planes {
		planes[b] = make([]byte, height*bandRowBytes)
	}
	for h := 0; h < height; h++ {
		line := data[h*lineBytes : (h+1)*lineBytes]
		for b := 0; b < MssBands; b++ {
			copy(planes[b][h*bandRowBytes:(h+1)*bandRowBytes], line[b*bandRowBytes:(b+1)*bandRowBytes])
		}
	}

	// Create all Mats before touching r so a failure cannot leave r.MssImages
	// half-populated.
	images := make([]gocv.Mat, 0, MssBands)
	for b := 0; b < MssBands; b++ {
		m, err := gocv.NewMatFromBytes(height, MssWidth, gocv.MatTypeCV16U, planes[b])
		if err != nil {
			log.Error("Error creating Mat from bytes: ", err)
			// Release the bands that were already created.
			for j := range images {
				images[j].Close()
			}
			return err
		}
		images = append(images, m)
	}

	for b := range images {
		r.MssImages[b] = images[b]
	}
	r.MssHeight = height
	r.MssWidth = MssWidth
	return nil
}
// DoDownPhaseCorrelation registers each MSS band against the PAN image at MSS
// resolution and writes the shifted bands through mssToRaw.
//
// Steps:
//  1. PAN is resized (linear interpolation) to MssWidth x MssHeight.
//  2. Each band is cut into BlockNH full-width horizontal strips, the last one
//     absorbing the remainder rows. processBlock is called on every PAN/MSS
//     strip pair and its shift is recorded in r.phaseShifts[band].
//  3. Per band, the shifts that are not NaN are averaged and the band is
//     translated by the integer part (truncated toward zero) of that average.
//  4. The translated bands are handed to mssToRaw.
//
// An error is returned when MssHeight*4 != PanHeight, when MssHeight is
// smaller than BlockNH, when a band has no usable (non-NaN) shift, or when
// mssToRaw fails.
//
// r.phaseShifts[band] is reset at the start of each call, so a repeated call
// does not average in shifts from an earlier run.
func (r *Registrator) DoDownPhaseCorrelation() error {
	// The MSS height must be exactly 1/4 of the PAN height.
	if r.MssHeight*4 != r.PanHeight {
		err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
		log.Error(err)
		return err
	}
	// Every strip needs at least one row. This also rejects the case where
	// both heights are still zero.
	if r.MssHeight < BlockNH {
		err := fmt.Errorf("MSS height %d is smaller than the block count %d", r.MssHeight, BlockNH)
		log.Error(err)
		return err
	}

	// Downsample PAN to the MSS size; it is the reference image for matching.
	downsampledPanImage := gocv.NewMat()
	defer downsampledPanImage.Close()
	gocv.Resize(r.PanImage, &downsampledPanImage,
		image.Point{X: r.MssWidth, Y: r.MssHeight}, 0, 0, gocv.InterpolationLinear)
	fmt.Println("down sampled PAN images size:", downsampledPanImage.Size())

	// Height of one strip.
	blockHeight := r.MssHeight / BlockNH
	for band := 0; band < MssBands; band++ {
		r.phaseShifts[band] = nil
		// NOTE(review): this Mat is only allocated, never written, before it is
		// stored in r.registeredImages — confirm what consumers expect there.
		alignedMSSImage := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV16U)
		for bh := 0; bh < BlockNH; bh++ {
			// Bottom edge of this strip; the last strip runs to the image end.
			y1 := (bh + 1) * blockHeight
			if bh == BlockNH-1 {
				y1 = r.MssHeight
			}
			var shiftM PhaseShiftM
			shiftM.Block.width = r.MssWidth           // strip width
			shiftM.Block.height = y1 - bh*blockHeight // strip height
			shiftM.Block.coord.X = 0                  // x of the strip's top-left corner
			shiftM.Block.coord.Y = bh * blockHeight   // y of the strip's top-left corner
			rect := image.Rect(
				shiftM.Block.coord.X, shiftM.Block.coord.Y,
				shiftM.Block.coord.X+shiftM.Block.width, shiftM.Block.coord.Y+shiftM.Block.height,
			)
			log.Println("Band", band+1, ", processing block", bh, rect)

			panBlock := downsampledPanImage.Region(rect)
			mssBlock := r.MssImages[band].Region(rect)
			alignedBlock, phaseShift := r.processBlock(panBlock, mssBlock)
			shiftM.dx = phaseShift.X
			shiftM.dy = phaseShift.Y
			r.phaseShifts[band] = append(r.phaseShifts[band], &shiftM)

			panBlock.Close()
			mssBlock.Close()
			alignedBlock.Close()
		}
		r.registeredImages[band] = alignedMSSImage
	}

	// Translate every band by its average shift.
	alignedMssData := make([][]byte, MssBands)
	for band := 0; band < MssBands; band++ {
		dx, dy, efficientShiftM := meanValidShift(r.phaseShifts[band])
		if efficientShiftM == 0 {
			// Dividing by zero here would give NaN, and converting NaN to int
			// yields an unspecified shift.
			err := fmt.Errorf("band %d: no valid phase shift, cannot register", band+1)
			log.Error(err)
			return err
		}
		log.Println("Band", band+1, "average shift:", dx, dy, "efficientShiftM:", efficientShiftM)
		alignedMssData[band] = r.shiftBandToBytes(band, int(dx), int(dy))
	}
	return r.mssToRaw(alignedMssData)
}

// meanValidShift averages dx and dy over the entries of shifts whose dx and dy
// are both not NaN. n is the number of entries used; dx and dy are zero when
// n is zero.
func meanValidShift(shifts []*PhaseShiftM) (dx, dy float32, n int) {
	var xTotal, yTotal float32
	for _, shift := range shifts {
		if math.IsNaN(float64(shift.dx)) || math.IsNaN(float64(shift.dy)) {
			continue
		}
		n++
		xTotal += shift.dx
		yTotal += shift.dy
	}
	if n == 0 {
		return 0, 0, 0
	}
	return xTotal / float32(n), yTotal / float32(n), n
}

// shiftBandToBytes returns the bytes of MSS band `band` translated so that
// output pixel (y, x) takes the value of input pixel (y+dy, x+dx). Output
// pixels whose source falls outside the image are set to 0, and a warning is
// logged for the first such pixel. The result is converted to CV16U before its
// bytes are returned.
func (r *Registrator) shiftBandToBytes(band, dx, dy int) []byte {
	cvtMss := gocv.NewMat()
	defer cvtMss.Close()
	r.MssImages[band].ConvertTo(&cvtMss, gocv.MatTypeCV32FC1)

	alignedMss := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV32FC1)
	defer alignedMss.Close()

	// NOTE(review): this moves pixels one at a time; gocv.WarpAffine with a
	// translation matrix is a possible faster replacement — verify equivalence.
	warned := false
	for y := 0; y < r.MssHeight; y++ {
		for x := 0; x < r.MssWidth; x++ {
			srcX := x + dx
			srcY := y + dy
			if srcX >= 0 && srcX < r.MssWidth && srcY >= 0 && srcY < r.MssHeight {
				alignedMss.SetFloatAt(y, x, cvtMss.GetFloatAt(srcY, srcX))
				continue
			}
			// No source pixel: fill with black.
			alignedMss.SetFloatAt(y, x, 0)
			if !warned {
				warned = true
				log.Println("Warning: pixel out of range", x, y)
			}
		}
	}

	alignedMss.ConvertTo(&alignedMss, gocv.MatTypeCV16U)
	return alignedMss.ToBytes()
}
// panToTiff is an unimplemented placeholder: it ignores panImage and filePath,
// writes nothing, and always returns nil.
func (r *Registrator) panToTiff(panImage gocv.Mat, filePath string) error {
return nil
}
// mssToTiff merges the four single-channel 16-bit images in registeredImages
// into one 4-channel image of r.PanHeight x r.PanWidth pixels and encodes it
// as a TIFF file. The images are read in index order 0..3 (named r, g, b, ir
// by the original code).
//
// The file is written to filePath. When filePath is empty, the legacy location
// "data/registered_rgbir.tiff" is used.
//
// Every input image must have at least r.PanHeight rows and r.PanWidth
// columns; otherwise an error is returned before any pixel is read. The output
// file is only created after the in-memory conversion has succeeded, so a
// failed conversion does not leave an empty file behind.
//
// NOTE(review): gocv's Mat.ToImage may not support CV16UC4 input — confirm
// against the gocv version in use. If it does not, this function returns the
// conversion error on every call.
func (r *Registrator) mssToTiff(registeredImages [4]gocv.Mat, filePath string) error {
	for i := range registeredImages {
		rows, cols := registeredImages[i].Rows(), registeredImages[i].Cols()
		if rows < r.PanHeight || cols < r.PanWidth {
			err := fmt.Errorf("registered image %d is %dx%d, need at least %dx%d",
				i, cols, rows, r.PanWidth, r.PanHeight)
			fmt.Println("Error merging registered images:", err)
			return err
		}
	}

	// Merged 4-channel, 16-bit image.
	rgbirImage := gocv.NewMatWithSize(r.PanHeight, r.PanWidth, gocv.MatTypeCV16UC4)
	defer rgbirImage.Close()
	for y := 0; y < r.PanHeight; y++ {
		for x := 0; x < r.PanWidth; x++ {
			// The four channels of pixel x sit at consecutive 16-bit slots
			// x*4 .. x*4+3 within the row.
			for ch := range registeredImages {
				rgbirImage.SetShortAt(y, x*4+ch, registeredImages[ch].GetShortAt(y, x))
			}
		}
	}

	img, err := rgbirImage.ToImage()
	if err != nil {
		fmt.Println("Error converting Mat to image:", err)
		return err
	}

	fileName := filePath
	if fileName == "" {
		fileName = "data/registered_rgbir.tiff"
	}
	tiffFile, err := os.Create(fileName)
	if err != nil {
		fmt.Println("Error creating TIFF file:", err)
		return err
	}
	if err := tiff.Encode(tiffFile, img, nil); err != nil {
		tiffFile.Close()
		fmt.Println("Error encoding TIFF file:", err)
		return err
	}
	// Close reports deferred write failures, so its error must not be dropped.
	if err := tiffFile.Close(); err != nil {
		fmt.Println("Error closing TIFF file:", err)
		return err
	}
	fmt.Println("Saved", fileName)
	return nil
}
// mssToRaw writes the MSS bands in mssData to
// "data/downsampled_registered_mss.RAW", interleaved by line: for every image
// row, the row of band 0 is written first, then bands 1, 2 and 3. Each band
// row is r.MssWidth * PixelBytes bytes long.
//
// mssData must hold at least MssBands slices, each at least
// r.MssHeight * r.MssWidth * PixelBytes bytes long; otherwise an error is
// returned and no file is created.
//
// The buffered writer is flushed and the file closed before returning, and
// any error from opening, writing, flushing or closing is returned.
func (r *Registrator) mssToRaw(mssData [][]byte) error {
	rowBytes := r.MssWidth * PixelBytes
	bandBytes := r.MssHeight * rowBytes
	if len(mssData) < MssBands {
		return fmt.Errorf("mssToRaw: got %d bands, need %d", len(mssData), MssBands)
	}
	for band := 0; band < MssBands; band++ {
		if len(mssData[band]) < bandBytes {
			return fmt.Errorf("mssToRaw: band %d has %d bytes, need %d",
				band, len(mssData[band]), bandBytes)
		}
	}

	f, err := os.OpenFile("data/downsampled_registered_mss.RAW", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
	if err != nil {
		return err
	}
	log.Println("Writing downsampled registered MSS to RAW file...", len(mssData[0])*4)
	log.Println("width:", r.MssWidth*PixelBytes*4)
	log.Println("height:", r.MssHeight)

	w := bufio.NewWriter(f)
	for row := 0; row < r.MssHeight; row++ {
		for band := 0; band < MssBands; band++ {
			// A bufio.Writer remembers the first write error and reports it
			// again from Flush, so it is checked once below.
			w.Write(mssData[band][row*rowBytes : (row+1)*rowBytes])
		}
	}
	// Without Flush the bytes still held in the buffer never reach the file.
	if err := w.Flush(); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}
// bytesToRaw writes mssData to the file at filePath, creating it if needed and
// truncating it if it already exists. It returns any error from opening,
// writing or closing the file.
//
// os.WriteFile replaces the former buffered writer, which was never flushed
// and whose file was never closed, so data could be lost.
func (r *Registrator) bytesToRaw(mssData []byte, filePath string) error {
	return os.WriteFile(filePath, mssData, 0777)
}