pansharpen

This commit is contained in:
nuknal
2024-05-27 18:14:03 +08:00
parent 8a15159d05
commit 106fc37aad
6 changed files with 351 additions and 202 deletions

View File

@@ -21,9 +21,9 @@ const (
PixelBytes = 2
PanWidth = 9344 // 像素宽度
MssWidth = 2336
BlockNH = 8
BlockNW = 4
OverlappedBlockLines = 2000 // 重叠块的行数
BlockNH = 4
BlockNW = 16
OverlappedBlockLines = 3000 // 重叠块的行数
DownSampled ResampleMethod = "down_sample_pan"
UpSampled ResampleMethod = "up_sample_mss"
)
@@ -41,8 +41,8 @@ type Registrator struct {
shiftMutex sync.Mutex
phaseShifts [4][]PhaseShiftM
deltaXCoeffs [4][]float64 // Polynomial fitting coefficients: 图像内畸变(非线性变换),捕捉图像在水平方向上引起的垂直方向的变形
deltaYCoeffs [4][]float64 // Polynomial fitting coefficients: 图像内畸变(非线性变换),捕捉图像在垂直方向上引起的水平方向的变形
deltaXCoeffs [4][]float64 // 图像内畸变(非线性变换),捕捉图像在水平方向上引起的垂直方向的变形
deltaYCoeffs [4][]float64 // 图像内畸变(非线性变换),捕捉图像在垂直方向上引起的水平方向的变形
registeredMssImages [4]gocv.Mat // 配准后的MSS图像
rgbirImage gocv.Mat
@@ -52,6 +52,7 @@ type Registrator struct {
func NewRegistrator() *Registrator {
var r Registrator
r.resampleMethod = DownSampled
return &r
}
@@ -121,6 +122,15 @@ func (r *Registrator) LoadMssRaw(raw string) error {
return nil
}
func (r *Registrator) DoPhaseCorrelation() error {
switch r.resampleMethod {
case UpSampled:
return r.CalcUpPhaseCorrelation()
default:
return r.CalcDownPhaseCorrelation()
}
}
// 将PAN降采样后计算相位相关的偏移量
func (r *Registrator) CalcDownPhaseCorrelation() error {
// 确保 MSS 高度是 PAN 高度的 1/4
@@ -134,63 +144,14 @@ func (r *Registrator) CalcDownPhaseCorrelation() error {
downsampledPanImage := gocv.NewMat()
gocv.Resize(r.PanImage, &downsampledPanImage,
image.Point{X: r.MssWidth, Y: r.MssHeight}, 0, 0, gocv.InterpolationCubic)
fmt.Println("down sampled PAN images size:", downsampledPanImage.Size())
log.Println("down sampled PAN images size:", downsampledPanImage.Size())
// 分块高度
blockHeight := r.MssHeight / BlockNH
for band := 0; band < MssBands; band++ {
for bh := 0; bh < BlockNH; bh++ {
// 读取 PAN 和 MSS 分块数据
y1 := (bh+1)*blockHeight + 800
if y1 > r.MssHeight {
y1 = r.MssHeight
}
blockWidth := r.MssWidth / BlockNW
var shiftM PhaseShiftM
shiftM.Block.width = r.MssWidth // 块宽度
shiftM.Block.height = y1 - bh*blockHeight // 块高度
shiftM.Block.coord.X = 0 // 块左上角x坐标
shiftM.Block.coord.Y = bh * blockHeight // 块左上角y坐标
return r.calcPhaseCorrelation(downsampledPanImage, r.MssImages, r.MssHeight, r.MssWidth, blockHeight, blockWidth)
rect := image.Rect(
shiftM.Block.coord.X, shiftM.Block.coord.Y,
shiftM.Block.coord.X+shiftM.Block.width, shiftM.Block.coord.Y+shiftM.Block.height,
)
log.Println("Band", band+1, ", processing block", bh, rect)
panBlock := downsampledPanImage.Region(rect)
mssBlock := r.MssImages[band].Region(rect)
// 处理每个分块
phaseShift, response := r.calculateBlockPhaseShift(panBlock, mssBlock)
shiftM.dx = phaseShift.X
shiftM.dy = phaseShift.Y
shiftM.response = response
r.phaseShifts[band] = append(r.phaseShifts[band], shiftM)
panBlock.Close()
mssBlock.Close()
}
}
// if err := r.DoMssPhaseShift(); err != nil {
// log.Error("Error calculating MSS phase shift: ", err)
// return err
// }
for i := 0; i < MssBands; i++ {
for j, shift := range r.phaseShifts[i] {
if shift.response > 0.4 || shift.dy > 8 {
fmt.Printf("Band %d, block %d, dx=%f, dy=%f, response=%f\n",
i, j, shift.dx, shift.dy, shift.response)
}
}
}
r.calcDeltaCoeffs()
return nil
}
// 将MSS升采样采样后计算相位相关的偏移量
@@ -218,6 +179,13 @@ func (r *Registrator) CalcUpPhaseCorrelation() error {
log.Infof("blockHeight=%d, blockWidth=%d", blockHeight, blockWidth)
return r.calcPhaseCorrelation(r.PanImage, upsampledMssImages, r.PanHeight, r.PanWidth, blockHeight, blockWidth)
}
func (r *Registrator) calcPhaseCorrelation(panImage gocv.Mat,
mssImages [4]gocv.Mat,
height, width,
blockHeight, blockWidth int) error {
var wg sync.WaitGroup
for bh := 0; bh < BlockNH; bh++ {
@@ -231,12 +199,12 @@ func (r *Registrator) CalcUpPhaseCorrelation() error {
y1 := (bh + 1) * blockHeight
y1 += OverlappedBlockLines // Y偏移量过大需要将重叠块的行数加上以避免边界影响
if x1 > r.PanWidth || y1 > r.PanHeight {
if x1 > width || y1 > height {
log.Warnf("Block out of range. x0=%d, y0=%d, x1=%d, y1=%d", x0, y0, x1, y1)
}
if y1 > r.PanHeight {
y1 = r.PanHeight
if y1 > height {
y1 = height
}
var shiftM PhaseShiftM
@@ -250,10 +218,10 @@ func (r *Registrator) CalcUpPhaseCorrelation() error {
x1, y1,
)
panBlock := r.PanImage.Region(rect)
panBlock := panImage.Region(rect)
for band := 0; band < MssBands; band++ {
log.Println("Band", band+1, ", processing block", bh, rect)
mssBlock := upsampledMssImages[band].Region(rect)
mssBlock := mssImages[band].Region(rect)
// 处理每个分块
phaseShift, response := r.calculateBlockPhaseShift(panBlock, mssBlock)
@@ -277,12 +245,14 @@ func (r *Registrator) CalcUpPhaseCorrelation() error {
for i := 0; i < MssBands; i++ {
for _, shift := range r.phaseShifts[i] {
if shift.response > 0.4 || shift.dx > 8 || shift.dy > 8 {
fmt.Printf("Band %d, block %d, dx=%f, dy=%f, response=%f\n",
log.Printf("Band %d, block %d, dx=%f, dy=%f, response=%f\n",
i, shift.Block.coord.X, shift.dx, shift.dy, shift.response)
}
}
}
r.calcDeltaCoeffs()
return nil
}
@@ -313,127 +283,6 @@ func (r *Registrator) SaveRegisteredMssToRaw(raw string) error {
return nil
}
func (r *Registrator) bytesToRaw(mssData []byte, filePath string) error {
f, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0777)
if err != nil {
return err
}
w := bufio.NewWriter(f)
w.Write(mssData)
return nil
}
func (r *Registrator) SaveRegisteredMssToGDALGTiff(tiffFile string) error {
log.Println("Saving registered MSS to TIFF file:", tiffFile)
width := r.MssWidth
height := r.MssHeight
// 创建合并后的图像RGBIR
r.rgbirImage = gocv.NewMatWithSize(height, width, gocv.MatTypeCV16UC4) // 4通道16位
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
red := r.registeredMssImages[0].GetShortAt(y, x)
green := r.registeredMssImages[1].GetShortAt(y, x)
blue := r.registeredMssImages[2].GetShortAt(y, x)
ir := r.registeredMssImages[3].GetShortAt(y, x)
r.rgbirImage.SetShortAt(y, x*4+0, red)
r.rgbirImage.SetShortAt(y, x*4+1, green)
r.rgbirImage.SetShortAt(y, x*4+2, blue)
r.rgbirImage.SetShortAt(y, x*4+3, ir)
}
}
// 创建一个二维切片来存储图像数据
data := make([][]uint16, MssBands)
for i := range data {
data[i] = make([]uint16, width*height)
}
// 从gocv.Mat中提取数据
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
for b := 0; b < MssBands; b++ {
data[b][y*width+x] = uint16(r.rgbirImage.GetShortAt(y, x*4+b))
}
}
}
ds, err := godal.Create(godal.GTiff,
tiffFile,
MssBands,
godal.UInt16,
width, height)
if err != nil {
log.Error("Error creating TIFF file: ", err)
return err
}
defer ds.Close()
setGeoTransform(ds, 0, 0, float64(width), float64(height), 1.2*4)
for b := 0; b < MssBands; b++ {
band := ds.Bands()[b]
err := band.IO(godal.IOWrite,
0, 0,
data[b],
width, height,
godal.PixelSpacing(2),
godal.LineSpacing(width*2))
if err != nil {
log.Error("Failed to write data to band:", err)
return err
}
}
log.Info("Saved registered mss to ", tiffFile)
return nil
}
func (r *Registrator) SavePanToGDALGTiff(tiffFile string) error {
log.Println("Saving PAN image to TIFF file:", tiffFile)
width := r.PanWidth
height := r.PanHeight
ds, err := godal.Create(godal.GTiff, tiffFile, 1, godal.UInt16, width, height)
if err != nil {
log.Error("Error creating TIFF file: ", err)
return err
}
defer ds.Close()
setGeoTransform(ds, 0, 0, float64(width), float64(height), 1.2)
ds.SetMetadata("NBITS", "16")
// 将通道的数据转换为uint16数组
data := make([]uint16, width*height)
for y := 0; y < height; y++ {
for x := 0; x < width; x++ {
data[y*width+x] = uint16(r.PanImage.GetShortAt(y, x))
}
}
band := ds.Bands()[0]
err = band.IO(godal.IOWrite,
0, 0,
data,
width, height,
godal.PixelSpacing(2),
godal.LineSpacing(width*2))
if err != nil {
log.Error("Failed to write data to band:", err)
return err
}
log.Info("Saved pan image to ", tiffFile)
return nil
}
func (r *Registrator) Clean() {
r.PanImage.Close()
for i := 0; i < MssBands; i++ {
@@ -454,14 +303,19 @@ func (r *Registrator) calcDeltaCoeffs() error {
var dx []float64
var dy []float64
effectShift := 0
for j, shift := range r.phaseShifts[i] {
for _, shift := range r.phaseShifts[i] {
if math.IsNaN(float64(shift.dx)) || math.IsNaN(float64(shift.dy)) {
continue
}
// 经验值过滤
if shift.dy < 64.0 {
continue
}
effectShift++
cx = append(cx, float64(shift.Block.coord.X+j)) // MSS 块在X方向没有分块
fmt.Println("effectShift:", effectShift, "cx:", shift.Block.coord.X, "dy:", shift.dy)
cx = append(cx, float64(shift.Block.coord.X+shift.Block.width/2)) // MSS 块在X方向没有分块
log.Debug("effective shift value:", effectShift, "cx:", shift.Block.coord.X, "dy:", shift.dy)
dx = append(dx, float64(shift.dx))
dy = append(dy, float64(shift.dy))
@@ -487,11 +341,25 @@ func (r *Registrator) DoCoRegestration() error {
mapY := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV32FC1)
for y := 0; y < r.MssHeight; y++ {
for x := 0; x < r.MssWidth; x++ {
// dx := r.deltaXCoeffs[band][1]*float64(x) + r.deltaXCoeffs[band][0] + float64(x)
// dy := r.deltaYCoeffs[band][2]*float64(x*x) + r.deltaYCoeffs[band][1]*float64(x) + r.deltaYCoeffs[band][0] + float64(y)
// fmt.Println("x:", x, "dx:", dx, "y:", y, "dy:", dy)
mapX.SetFloatAt(y, x, float32(x)+float32(r.deltaXCoeffs[band][0]))
mapY.SetFloatAt(y, x, float32(y)+float32(r.deltaYCoeffs[band][0]))
var dx, dy float64
if r.resampleMethod == UpSampled {
xx := float64(x * MssBands)
yy := float64(y * MssBands)
dx = (r.deltaXCoeffs[band][1]*float64(xx) + r.deltaXCoeffs[band][0] + xx) / MssBands
dy = (r.deltaYCoeffs[band][2]*float64(xx*xx) + r.deltaYCoeffs[band][1]*float64(xx) + r.deltaYCoeffs[band][0] + yy) / MssBands
} else {
dx = r.deltaXCoeffs[band][1]*float64(x) + r.deltaXCoeffs[band][0] + float64(x)
dy = r.deltaYCoeffs[band][2]*float64(x*x) + r.deltaYCoeffs[band][1]*float64(x) + r.deltaYCoeffs[band][0] + float64(y)
}
// if band+1 == 4 {
// fmt.Println("band:", band+1, "x:", x, "map_x:", mx, "y:", y, "map_y:", my)
// }
// mapX.SetFloatAt(y, x, float32(x)+float32(r.deltaXCoeffs[band][0]))
// mapY.SetFloatAt(y, x, float32(y)+float32(r.deltaYCoeffs[band][0]))
mapX.SetFloatAt(y, x, float32(dx))
mapY.SetFloatAt(y, x, float32(dy))
}
}