使用gamma校正提升jpg亮度

This commit is contained in:
nuknal
2024-05-30 11:22:34 +08:00
parent 7d9ec46750
commit 07ee4d88d4
18 changed files with 174 additions and 98 deletions

View File

@@ -0,0 +1,373 @@
package imageproc
import (
"fmt"
"image"
"image/color"
"math"
"os"
"sync"
"github.com/airbusgeo/godal"
log "github.com/sirupsen/logrus"
"gocv.io/x/gocv"
)
type Registrate interface{}
const (
MssBands = 4
PixelBytes = 2
PanWidth = 9344 // 像素宽度
MssWidth = 2336
BlockNH = 4
BlockNW = 8
OverlappedBlockLines = 3000 // 重叠块的行数
DownSampled ResampleMethod = "down_sample_pan"
UpSampled ResampleMethod = "up_sample_mss"
)
type ResampleMethod string
type Registrator struct {
Params Params
PanImage gocv.Mat
PanHeight int
PanWidth int
MssImages [4]gocv.Mat
MssHeight int
MssWidth int
shiftMutex sync.Mutex
phaseShifts [4][]PhaseShiftM
deltaXCoeffs [4][]float64 // 图像内畸变线性变换捕捉图像在水平方向上引起的X方向的变形
deltaYCoeffs [4][]float64 // 图像内畸变非线性变换捕捉图像在水平方向上引起的Y方向的变形
registeredMssImages [4]gocv.Mat // 配准后的MSS图像
rgbirImage gocv.Mat
resampleMethod ResampleMethod
}
func NewRegistrator(rsmethod ResampleMethod) *Registrator {
var r Registrator
r.resampleMethod = rsmethod
return &r
}
func (r *Registrator) LoadPanRaw() error {
data, err := os.ReadFile(r.Params.PanRawFile)
if err != nil {
log.Error("Error reading raw file: ", err)
return err
}
height := len(data) / (PanWidth * PixelBytes)
r.PanImage, err = gocv.NewMatFromBytes(height, PanWidth, gocv.MatTypeCV16U, data)
if err != nil {
log.Error("Error creating Mat from bytes: ", err)
return err
}
r.PanHeight = height
r.PanWidth = PanWidth
godal.RegisterAll()
hDriver, ok := godal.RasterDriver("Gtiff")
if !ok {
panic("Gtiff not found")
}
md := hDriver.Metadatas()
if md["DCAP_CREATE"] == "YES" {
fmt.Printf("Driver GTiff supports Create() method.\n")
}
if md["DCAP_CREATECOPY"] == "YES" {
fmt.Printf("Driver Gtiff supports CreateCopy() method.\n")
}
fmt.Println("Gtiff driver name:", hDriver.LongName(), hDriver.ShortName())
return nil
}
func (r *Registrator) LoadMssRaw() error {
data, err := os.ReadFile(r.Params.MssRawFile)
if err != nil {
log.Error("Error reading raw file: ", err)
return err
}
height := len(data) / (MssWidth * PixelBytes * MssBands)
mssData := make([][]byte, MssBands)
for h := 0; h < height; h++ {
row := data[h*MssWidth*MssBands*PixelBytes : (h+1)*MssWidth*MssBands*PixelBytes]
for i := 0; i < MssBands; i++ {
mssData[i] = append(mssData[i], row[i*MssWidth*PixelBytes:(i+1)*MssWidth*PixelBytes]...)
}
}
for i := 0; i < MssBands; i++ {
r.MssImages[i], err = gocv.NewMatFromBytes(height, MssWidth, gocv.MatTypeCV16U, mssData[i])
if err != nil {
log.Error("Error creating Mat from bytes: ", err)
return err
}
}
r.MssHeight = height
r.MssWidth = MssWidth
return nil
}
func (r *Registrator) DoPhaseCorrelation() error {
switch r.resampleMethod {
case UpSampled:
return r.CalcUpPhaseCorrelation()
default:
return r.CalcDownPhaseCorrelation()
}
}
// 将PAN降采样后计算相位相关的偏移量
func (r *Registrator) CalcDownPhaseCorrelation() error {
// 确保 MSS 高度是 PAN 高度的 1/4
if r.MssHeight*4 != r.PanHeight {
err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
log.Error(err)
return err
}
// 将PAN将采样作为轮廓匹配基准图像
downsampledPanImage := gocv.NewMat()
gocv.Resize(r.PanImage, &downsampledPanImage,
image.Point{X: r.MssWidth, Y: r.MssHeight}, 0, 0, gocv.InterpolationCubic)
log.Println("down sampled PAN images size:", downsampledPanImage.Size())
// 分块高度
blockHeight := r.MssHeight / BlockNH
blockWidth := r.MssWidth / BlockNW
return r.calcPhaseCorrelation(downsampledPanImage, r.MssImages, r.MssHeight, r.MssWidth, blockHeight, blockWidth)
}
// 将MSS升采样采样后计算相位相关的偏移量
func (r *Registrator) CalcUpPhaseCorrelation() error {
// 确保 MSS 高度是 PAN 高度的 1/4
if r.MssHeight*4 != r.PanHeight {
err := fmt.Errorf("MSS height is not 1/4 of PAN height, invalid raw file")
log.Error(err)
return err
}
// 将PAN将采样作为轮廓匹配基准图像
var upsampledMssImages [MssBands]gocv.Mat
for i := 0; i < MssBands; i++ {
upsampledMssImages[i] = gocv.NewMat()
gocv.Resize(r.MssImages[i], &upsampledMssImages[i],
image.Point{X: r.PanWidth, Y: r.PanHeight}, 0, 0, gocv.InterpolationCubic)
}
fmt.Println("up sampled MSS images size:", upsampledMssImages[0].Size())
// 分块高度 - BlockNH, BlockNW % 4 == 0
blockHeight := r.PanHeight / BlockNH
blockWidth := r.PanWidth / BlockNW
log.Infof("blockHeight=%d, blockWidth=%d", blockHeight, blockWidth)
return r.calcPhaseCorrelation(r.PanImage, upsampledMssImages, r.PanHeight, r.PanWidth, blockHeight, blockWidth)
}
func (r *Registrator) calcPhaseCorrelation(panImage gocv.Mat,
mssImages [4]gocv.Mat,
height, width,
blockHeight, blockWidth int) error {
var wg sync.WaitGroup
for bh := 0; bh < BlockNH; bh++ {
for bw := 0; bw < BlockNW; bw++ {
wg.Add(1)
go func(bh, bw int) {
defer wg.Done()
x0 := bw * blockWidth
y0 := bh * blockHeight
x1 := (bw + 1) * blockWidth
y1 := (bh + 1) * blockHeight
y1 += OverlappedBlockLines // Y偏移量过大需要将重叠块的行数加上以避免边界影响
if x1 > width || y1 > height {
log.Debugf("Block out of range. x0=%d, y0=%d, x1=%d, y1=%d", x0, y0, x1, y1)
}
if y1 > height {
y1 = height
}
var shiftM PhaseShiftM
shiftM.Block.width = x1 - x0
shiftM.Block.height = y1 - y0
shiftM.Block.coord.X = x0 // 块左上角x坐标
shiftM.Block.coord.Y = y0 // 块左上角y坐标
rect := image.Rect(
x0, y0,
x1, y1,
)
panBlock := panImage.Region(rect)
for band := 0; band < MssBands; band++ {
log.Debug("processing band:", band+1, ",block:", bh, rect)
mssBlock := mssImages[band].Region(rect)
// 处理每个分块
phaseShift, response := r.calculateBlockPhaseShift(panBlock, mssBlock)
shiftM.dx = phaseShift.X
shiftM.dy = phaseShift.Y
shiftM.response = response
r.shiftMutex.Lock()
r.phaseShifts[band] = append(r.phaseShifts[band], shiftM)
r.shiftMutex.Unlock()
mssBlock.Close()
}
panBlock.Close()
}(bh, bw)
}
}
wg.Wait()
for i := 0; i < MssBands; i++ {
for _, shift := range r.phaseShifts[i] {
if shift.response > 0.4 || shift.dx > 8 || shift.dy > 8 {
log.Debugf("Band %d, block %d, dx=%f, dy=%f, response=%f",
i, shift.Block.coord.X, shift.dx, shift.dy, shift.response)
}
}
}
return r.calcDeltaCoeffs()
}
func (r *Registrator) Clean() {
r.PanImage.Close()
for i := 0; i < MssBands; i++ {
r.MssImages[i].Close()
}
for i := 0; i < MssBands; i++ {
r.registeredMssImages[i].Close()
}
r.rgbirImage.Close()
}
func (r *Registrator) calcDeltaCoeffs() error {
// 计算每个通道的delta多项式拟合系数
for i := 0; i < MssBands; i++ {
var cx []float64
var dx []float64
var dy []float64
effectShift := 0
for _, shift := range r.phaseShifts[i] {
if math.IsNaN(float64(shift.dx)) || math.IsNaN(float64(shift.dy)) {
continue
}
// 经验值过滤
if shift.dy < 64.0 {
continue
}
effectShift++
cx = append(cx, float64(shift.Block.coord.X+shift.Block.width/2)) // MSS 块在X方向没有分块
log.Debugf("effective shift value: %v, cx: %v, dx: %v, dy: %v",
effectShift, shift.Block.coord.X, shift.dx, shift.dy)
dx = append(dx, float64(shift.dx))
dy = append(dy, float64(shift.dy))
}
if len(cx) < 3 {
log.Errorf("No effective shift value found for band %d, skip delta coefficients calculation", i+1)
continue
} else {
var err error
if r.deltaXCoeffs[i], err = PolynomialFit(cx, dx, 1); err != nil {
log.Error("Error fitting deltaX coeffs: ", err)
return err
}
if r.deltaYCoeffs[i], err = PolynomialFit(cx, dy, 2); err != nil {
log.Error("Error fitting deltaY coeffs: ", err)
return err
}
}
}
for i := 0; i < MssBands; i++ {
if len(r.deltaXCoeffs[i]) < 2 || len(r.deltaYCoeffs[i]) < 3 {
continue
}
log.Printf("Band %d:\n delta_x = %.6f*x + %.6f, \n delta_y = %.6f*x^2 + %.6f*x + %.6f\n",
i+1,
r.deltaXCoeffs[i][1], r.deltaXCoeffs[i][0],
r.deltaYCoeffs[i][2], r.deltaYCoeffs[i][1], r.deltaYCoeffs[i][0])
}
return nil
}
func (r *Registrator) DoCoRegestration() error {
for band := 0; band < MssBands; band++ {
if len(r.deltaXCoeffs[band]) < 2 || len(r.deltaYCoeffs[band]) < 3 {
log.Error("delta coefficients not calculated, skip co-registration")
r.registeredMssImages[band] = r.MssImages[band].Clone()
continue
}
mapX := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV32FC1)
mapY := gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV32FC1)
for y := 0; y < r.MssHeight; y++ {
for x := 0; x < r.MssWidth; x++ {
var dx, dy float64
if r.resampleMethod == UpSampled {
xx := float64(x * MssBands)
yy := float64(y * MssBands)
dx = (r.deltaXCoeffs[band][1]*float64(xx) + r.deltaXCoeffs[band][0] + xx) / MssBands
dy = (r.deltaYCoeffs[band][2]*float64(xx*xx) + r.deltaYCoeffs[band][1]*float64(xx) + r.deltaYCoeffs[band][0] + yy) / MssBands
} else {
dx = r.deltaXCoeffs[band][1]*float64(x) + r.deltaXCoeffs[band][0] + float64(x)
dy = r.deltaYCoeffs[band][2]*float64(x*x) + r.deltaYCoeffs[band][1]*float64(x) + r.deltaYCoeffs[band][0] + float64(y)
}
// if band+1 == 4 {
// fmt.Println("band:", band+1, "x:", x, "map_x:", mx, "y:", y, "map_y:", my)
// }
// mapX.SetFloatAt(y, x, float32(x)+float32(r.deltaXCoeffs[band][0]))
// mapY.SetFloatAt(y, x, float32(y)+float32(r.deltaYCoeffs[band][0]))
mapX.SetFloatAt(y, x, float32(dx))
mapY.SetFloatAt(y, x, float32(dy))
}
}
log.Println("co-registration for band", band+1)
r.registeredMssImages[band] = gocv.NewMatWithSize(r.MssHeight, r.MssWidth, gocv.MatTypeCV16UC1)
gocv.Remap(r.MssImages[band],
&r.registeredMssImages[band],
&mapX, &mapY,
gocv.InterpolationCubic,
gocv.BorderConstant,
color.RGBA{0, 0, 0, 0})
}
return nil
}