import (
"fmt"
"math"
+ "math/big"
"math/rand"
"strings"
"testing"
benchmarkDCT(b, idct)
}
+const testSlowVsBig = true
+
func TestDCT(t *testing.T) {
blocks := make([]block, len(testBlocks))
copy(blocks, testBlocks[:])
- // Append some randomly generated blocks of varying sparseness.
+ // All zeros
+ blocks = append(blocks, block{})
+
+ // Every possible unit impulse.
+ for i := range blockSize {
+ var b block
+ b[i] = 255
+ blocks = append(blocks, b)
+ }
+
+ // Every element set to 255.
+ var ones block
+ for i := range ones {
+ ones[i] = 255
+ }
+ blocks = append(blocks, ones)
+
+ // Every possible inverted unit impulse.
+ for i := range blockSize {
+ ones[i] = 0
+ blocks = append(blocks, ones)
+ ones[i] = 255
+ }
+
+ // Some randomly generated blocks of varying sparseness.
r := rand.New(rand.NewSource(123))
for i := 0; i < 100; i++ {
b := block{}
blocks = append(blocks, b)
}
- // Check that the FDCT and IDCT functions are inverses, after a scale and
- // level shift. Scaling reduces the rounding errors in the conversion from
- // floats to ints.
- for i, b := range blocks {
- got, want := b, b
- slowFDCT(&got)
- slowIDCT(&got)
- for j := range got {
- got[j] = got[j]/8 + 128
- }
- if d := differ(&got, &want, 2); d >= 0 {
- t.Errorf("i=%d: IDCT(FDCT) (diff at %d,%d)\nsrc\n%s\ngot\n%s\nwant\n%s\n", i, d/8, d%8, &b, &got, &want)
+ // Check that the slow FDCT and IDCT functions are inverses,
+ // after a scale and level shift.
+ // Scaling reduces the rounding errors in the conversion.
+ // The “fast” ones are not inverses because the fast IDCT
+ // is optimized for 8-bit inputs, not full 16-bit ones.
+ slowRoundTrip := func(b *block) {
+ slowFDCT(b)
+ slowIDCT(b)
+ for j := range b {
+ b[j] = b[j]/8 + 128
}
}
+ nop := func(*block) {}
+ testDCT(t, "IDCT(FDCT)", blocks, slowRoundTrip, nop, 1, 8)
- // Check that the optimized and slow FDCT implementations agree.
- // The fdct function already does a scale and level shift.
- for i, b := range blocks {
- got, want := b, b
- fdct(&got)
- slowFDCT(&want)
- if d := differ(&got, &want, 2); d >= 0 {
- t.Errorf("i=%d: FDCT (diff at %d,%d)\nsrc\n%s\ngot\n%s\nwant\n%s\n", i, d/8, d%8, &b, &got, &want)
- }
+ if testSlowVsBig {
+ testDCT(t, "slowFDCT", blocks, slowFDCT, slowerFDCT, 0, 64)
+ testDCT(t, "slowIDCT", blocks, slowIDCT, slowerIDCT, 0, 64)
}
- // Check that the optimized and slow IDCT implementations agree.
- for i, b := range blocks {
- got, want := b, b
- idct(&got)
- slowIDCT(&want)
- if d := differ(&got, &want, 2); d >= 0 {
- t.Errorf("i=%d: IDCT (diff at %d,%d)\nsrc\n%s\ngot\n%s\nwant\n%s\n", i, d/8, d%8, &b, &got, &want)
+ // Check that the optimized and slow FDCT and IDCT implementations agree.
+ testDCT(t, "FDCT", blocks, fdct, slowFDCT, 1, 16)
+ testDCT(t, "IDCT", blocks, idct, slowIDCT, 1, 8)
+}
+
+func testDCT(t *testing.T, name string, blocks []block, fhave, fwant func(*block), tolerance int32, maxCloseCalls int) { // testDCT runs subtest name: for each block it applies fhave and fwant to separate copies and fails if any element differs by more than tolerance or if more than maxCloseCalls elements are close calls.
+ t.Run(name, func(t *testing.T) {
+ totalClose := 0 // close calls summed over all blocks, for the log line below
+ for i, b := range blocks {
+ have, want := b, b // separate copies, so b is still the untouched "src" in the failure message
+ fhave(&have)
+ fwant(&want)
+ d, n := differ(&have, &want, tolerance) // d: first index out of tolerance, or -1; n: close calls in this block
+ if d >= 0 || n > maxCloseCalls {
+ fail := ""
+ if d >= 0 {
+ fail = fmt.Sprintf("diff at %d,%d", d/8, d%8) // report the flat index as its two 8x8 coordinates
+ }
+ if n > maxCloseCalls {
+ if fail != "" {
+ fail += "; " // both failure reasons apply; separate them
+ }
+ fail += fmt.Sprintf("%d close calls", n)
+ }
+ t.Errorf("i=%d: %s (%s)\nsrc\n%s\nhave\n%s\nwant\n%s\n",
+ i, name, fail, &b, &have, &want)
+ }
+ totalClose += n
}
- }
+ if tolerance > 0 { // with tolerance 0, differ counts every element as a close call, so the total says nothing
+ t.Logf("%d/%d total close calls", totalClose, len(blocks)*blockSize)
+ }
+ })
}
-// differ reports whether any pair-wise elements in b0 and b1 differ by more than 'ok'.
-// That tolerance is because there isn't a single definitive decoding of
-// a given JPEG image, even before the YCbCr to RGB conversion; implementations
+// differ returns the index of the first pair of corresponding elements in b0 and b1
+// that differ by more than 'ok' in magnitude, along with the total number of pairs
+// that differ by at least ok in magnitude ("close calls", which include the pairs beyond ok).
+//
+// There isn't a single definitive decoding of a given JPEG image,
+// even before the YCbCr to RGB conversion; implementations
// can have different IDCT rounding errors.
-// If there is a difference, differ returns the index of the first difference.
-// Otherwise it returns -1.
-func differ(b0, b1 *block, ok int32) int {
+//
+// If no pair differs by more than ok, index is -1. closeCalls is 0 only if, in addition, no pair differs by at least ok; when ok is 0 every pair counts as a close call.
+func differ(b0, b1 *block, ok int32) (index, closeCalls int) {
+ index = -1
for i := range b0 {
delta := b0[i] - b1[i]
if delta < -ok || ok < delta { // strictly outside the tolerance
- return i
+ if index < 0 { // remember only the first such index, but keep scanning to count close calls
+ index = i
+ }
+ }
+ if delta <= -ok || ok <= delta { // at or beyond the tolerance boundary
+ closeCalls++
}
}
- return -1
+ return
}
// alpha returns 1 if i is 0 and returns √2 otherwise.
return math.Sqrt2
}
+// bigAlpha returns 1 if i is 0 and returns √2 otherwise, as a *big.Float. The result is the shared package-level bigFloat1 or bigFloatSqrt2 rather than a copy, so callers must not modify it.
+func bigAlpha(i int) *big.Float {
+ if i == 0 {
+ return bigFloat1
+ }
+ return bigFloatSqrt2
+}
+
var cosines = [32]float64{
+1.0000000000000000000000000000000000000000000000000000000000000000, // cos(π/16 * 0)
+0.9807852804032304491261822361342390369739337308933360950029160885, // cos(π/16 * 1)
+0.9807852804032304491261822361342390369739337308933360950029160885, // cos(π/16 * 31)
}
+func bigFloat(s string) *big.Float { // bigFloat parses s into a newly allocated *big.Float and panics if s cannot be parsed.
+ f, ok := new(big.Float).SetString(s) // NOTE(review): per the math/big docs, parsing into a zero-precision Float uses a 64-bit mantissa, so digits of s beyond roughly the 19th significant one are rounded away; confirm that is the intended precision
+ if !ok {
+ panic("bad float") // only reachable with a malformed literal, i.e. a programming error in this file
+ }
+ return f
+}
+
+var ( // the two values bigAlpha returns
+ bigFloat1 = big.NewFloat(1) // 1
+ bigFloatSqrt2 = bigFloat("1.41421356237309504880168872420969807856967187537694807317667974") // √2
+)
+
+var bigCosines = [32]*big.Float{ // bigCosines[k] holds cos(π/16 * k) as a *big.Float for k = 0..31; slowerFDCT and slowerIDCT index it modulo 32.
+ bigFloat("+1.0000000000000000000000000000000000000000000000000000000000000000"), // cos(π/16 * 0)
+ bigFloat("+0.9807852804032304491261822361342390369739337308933360950029160885"), // cos(π/16 * 1)
+ bigFloat("+0.9238795325112867561281831893967882868224166258636424861150977312"), // cos(π/16 * 2)
+ bigFloat("+0.8314696123025452370787883776179057567385608119872499634461245902"), // cos(π/16 * 3)
+ bigFloat("+0.7071067811865475244008443621048490392848359376884740365883398689"), // cos(π/16 * 4)
+ bigFloat("+0.5555702330196022247428308139485328743749371907548040459241535282"), // cos(π/16 * 5)
+ bigFloat("+0.3826834323650897717284599840303988667613445624856270414338006356"), // cos(π/16 * 6)
+ bigFloat("+0.1950903220161282678482848684770222409276916177519548077545020894"), // cos(π/16 * 7)
+
+ bigFloat("-0.0000000000000000000000000000000000000000000000000000000000000000"), // cos(π/16 * 8)
+ bigFloat("-0.1950903220161282678482848684770222409276916177519548077545020894"), // cos(π/16 * 9)
+ bigFloat("-0.3826834323650897717284599840303988667613445624856270414338006356"), // cos(π/16 * 10)
+ bigFloat("-0.5555702330196022247428308139485328743749371907548040459241535282"), // cos(π/16 * 11)
+ bigFloat("-0.7071067811865475244008443621048490392848359376884740365883398689"), // cos(π/16 * 12)
+ bigFloat("-0.8314696123025452370787883776179057567385608119872499634461245902"), // cos(π/16 * 13)
+ bigFloat("-0.9238795325112867561281831893967882868224166258636424861150977312"), // cos(π/16 * 14)
+ bigFloat("-0.9807852804032304491261822361342390369739337308933360950029160885"), // cos(π/16 * 15)
+
+ bigFloat("-1.0000000000000000000000000000000000000000000000000000000000000000"), // cos(π/16 * 16)
+ bigFloat("-0.9807852804032304491261822361342390369739337308933360950029160885"), // cos(π/16 * 17)
+ bigFloat("-0.9238795325112867561281831893967882868224166258636424861150977312"), // cos(π/16 * 18)
+ bigFloat("-0.8314696123025452370787883776179057567385608119872499634461245902"), // cos(π/16 * 19)
+ bigFloat("-0.7071067811865475244008443621048490392848359376884740365883398689"), // cos(π/16 * 20)
+ bigFloat("-0.5555702330196022247428308139485328743749371907548040459241535282"), // cos(π/16 * 21)
+ bigFloat("-0.3826834323650897717284599840303988667613445624856270414338006356"), // cos(π/16 * 22)
+ bigFloat("-0.1950903220161282678482848684770222409276916177519548077545020894"), // cos(π/16 * 23)
+
+ bigFloat("+0.0000000000000000000000000000000000000000000000000000000000000000"), // cos(π/16 * 24)
+ bigFloat("+0.1950903220161282678482848684770222409276916177519548077545020894"), // cos(π/16 * 25)
+ bigFloat("+0.3826834323650897717284599840303988667613445624856270414338006356"), // cos(π/16 * 26)
+ bigFloat("+0.5555702330196022247428308139485328743749371907548040459241535282"), // cos(π/16 * 27)
+ bigFloat("+0.7071067811865475244008443621048490392848359376884740365883398689"), // cos(π/16 * 28)
+ bigFloat("+0.8314696123025452370787883776179057567385608119872499634461245902"), // cos(π/16 * 29)
+ bigFloat("+0.9238795325112867561281831893967882868224166258636424861150977312"), // cos(π/16 * 30)
+ bigFloat("+0.9807852804032304491261822361342390369739337308933360950029160885"), // cos(π/16 * 31)
+}
+
// slowFDCT performs the 8*8 2-dimensional forward discrete cosine transform:
//
// dst[u,v] = (1/8) * Σ_x Σ_y alpha(u) * alpha(v) * src[x,y] *
//
// b acts as both dst and src.
func slowFDCT(b *block) {
- var dst [blockSize]float64
+ var dst block
for v := 0; v < 8; v++ {
for u := 0; u < 8; u++ {
sum := 0.0
cosines[((2*y+1)*v)%32]
}
}
- dst[8*v+u] = sum
+ dst[8*v+u] = int32(math.Round(sum))
}
}
- // Convert from float64 to int32.
- for i := range dst {
- b[i] = int32(dst[i] + 0.5)
+ *b = dst
+}
+
+// slowerFDCT is slowFDCT but using big.Floats to validate slowFDCT. TestDCT uses it as the reference side of the "slowFDCT" comparison, with tolerance 0. b acts as both dst and src.
+func slowerFDCT(b *block) {
+ var dst block // results go to a scratch block because every output reads all of b
+ for v := 0; v < 8; v++ {
+ for u := 0; u < 8; u++ {
+ sum := big.NewFloat(0)
+ for y := 0; y < 8; y++ {
+ for x := 0; x < 8; x++ {
+ f := big.NewFloat(float64(b[8*y+x] - 128)) // level shift: subtract 128 from the sample before transforming
+ f = new(big.Float).Mul(f, bigAlpha(u)) // fresh Float per step, so the shared bigAlpha/bigCosines values are never written to
+ f = new(big.Float).Mul(f, bigAlpha(v))
+ f = new(big.Float).Mul(f, bigCosines[((2*x+1)*u)%32]) // %32 wraps the index into the 32-entry table
+ f = new(big.Float).Mul(f, bigCosines[((2*y+1)*v)%32])
+ sum = new(big.Float).Add(sum, f)
+ }
+ }
+ // Int64 truncates toward zero, so add ±0.5
+ // as needed to round to the nearest integer, halves away from zero.
+ if sum.Sign() > 0 {
+ sum = new(big.Float).Add(sum, big.NewFloat(+0.5))
+ } else {
+ sum = new(big.Float).Add(sum, big.NewFloat(-0.5))
+ }
+ i, _ := sum.Int64() // the Accuracy result is dropped: truncation is exactly what is wanted here
+ dst[8*v+u] = int32(i)
+ }
}
+ *b = dst
}
// slowIDCT performs the 8*8 2-dimensional inverse discrete cosine transform:
//
// b acts as both dst and src.
func slowIDCT(b *block) {
- var dst [blockSize]float64
+ var dst block
for y := 0; y < 8; y++ {
for x := 0; x < 8; x++ {
sum := 0.0
cosines[((2*y+1)*v)%32]
}
}
- dst[8*y+x] = sum / 8
+ dst[8*y+x] = int32(math.Round(sum / 8))
}
}
- // Convert from float64 to int32.
- for i := range dst {
- b[i] = int32(dst[i] + 0.5)
+ *b = dst
+}
+
+// slowerIDCT is slowIDCT but using big.Floats to validate slowIDCT. TestDCT uses it as the reference side of the "slowIDCT" comparison, with tolerance 0. b acts as both dst and src.
+func slowerIDCT(b *block) {
+ var dst block // results go to a scratch block because every output reads all of b
+ for y := 0; y < 8; y++ {
+ for x := 0; x < 8; x++ {
+ sum := big.NewFloat(0)
+ for v := 0; v < 8; v++ {
+ for u := 0; u < 8; u++ {
+ f := big.NewFloat(float64(b[8*v+u]))
+ f = new(big.Float).Mul(f, bigAlpha(u)) // fresh Float per step, so the shared bigAlpha/bigCosines values are never written to
+ f = new(big.Float).Mul(f, bigAlpha(v))
+ f = new(big.Float).Mul(f, bigCosines[((2*x+1)*u)%32]) // %32 wraps the index into the 32-entry table
+ f = new(big.Float).Mul(f, bigCosines[((2*y+1)*v)%32])
+ f = new(big.Float).Quo(f, big.NewFloat(8)) // each term is divided by 8 before summing
+ sum = new(big.Float).Add(sum, f)
+ }
+ }
+ // Int64 truncates toward zero, so add ±0.5
+ // as needed to round to the nearest integer, halves away from zero.
+ if sum.Sign() > 0 {
+ sum = new(big.Float).Add(sum, big.NewFloat(+0.5))
+ } else {
+ sum = new(big.Float).Add(sum, big.NewFloat(-0.5))
+ }
+ i, _ := sum.Int64() // the Accuracy result is dropped: truncation is exactly what is wanted here
+ dst[8*y+x] = int32(i)
+ }
}
+ *b = dst
}
func (b *block) String() string {