}
}
+func TestHuffmanDecodeExcessPadding(t *testing.T) {
+ tests := [][]byte{
+ {0xff}, // Padding Exceeds 7 bits
+ {0x1f, 0xff}, // {"a", 1 byte excess padding}
+ {0x1f, 0xff, 0xff}, // {"a", 2 byte excess padding}
+ {0x1f, 0xff, 0xff, 0xff}, // {"a", 3 byte excess padding}
+ {0xff, 0x9f, 0xff, 0xff, 0xff}, // {"a", 29 bit excess padding}
+ {'R', 0xbc, '0', 0xff, 0xff, 0xff, 0xff}, // Padding ends on partial symbol.
+ }
+ for i, in := range tests {
+ var buf bytes.Buffer
+ if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+ t.Errorf("test-%d: decode(%q) = %v; want ErrInvalidHuffman", i, in, err)
+ }
+ }
+}
+
+func TestHuffmanDecodeEOS(t *testing.T) {
+ in := []byte{0xff, 0xff, 0xff, 0xff, 0xfc} // {EOS, "?"}
+ var buf bytes.Buffer
+ if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+ t.Errorf("error = %v; want ErrInvalidHuffman", err)
+ }
+}
+
+func TestHuffmanDecodeMaxLengthOnTrailingByte(t *testing.T) {
+ in := []byte{0x00, 0x01} // {"0", "0", "0"}
+ var buf bytes.Buffer
+ if err := huffmanDecode(&buf, 2, in); err != ErrStringLength {
+ t.Errorf("error = %v; want ErrStringLength", err)
+ }
+}
+
+func TestHuffmanDecodeCorruptPadding(t *testing.T) {
+ in := []byte{0x00}
+ var buf bytes.Buffer
+ if _, err := HuffmanDecode(&buf, in); err != ErrInvalidHuffman {
+ t.Errorf("error = %v; want ErrInvalidHuffman", err)
+ }
+}
+
func TestHuffmanDecode(t *testing.T) {
tests := []struct {
inHex, want string
// maxLen bytes will return ErrStringLength.
func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
n := rootHuffmanNode
- cur, nbits := uint(0), uint8(0)
+ // cur is the bit buffer that has not been fed into n.
+ // cbits is the number of low order bits in cur that are valid.
+ // sbits is the number of bits of the symbol prefix being decoded.
+ cur, cbits, sbits := uint(0), uint8(0), uint8(0)
for _, b := range v {
cur = cur<<8 | uint(b)
- nbits += 8
- for nbits >= 8 {
- idx := byte(cur >> (nbits - 8))
+ cbits += 8
+ sbits += 8
+ for cbits >= 8 {
+ idx := byte(cur >> (cbits - 8))
n = n.children[idx]
if n == nil {
return ErrInvalidHuffman
return ErrStringLength
}
buf.WriteByte(n.sym)
- nbits -= n.codeLen
+ cbits -= n.codeLen
n = rootHuffmanNode
+ sbits = cbits
} else {
- nbits -= 8
+ cbits -= 8
}
}
}
- for nbits > 0 {
- n = n.children[byte(cur<<(8-nbits))]
- if n.children != nil || n.codeLen > nbits {
+ for cbits > 0 {
+ n = n.children[byte(cur<<(8-cbits))]
+ if n == nil {
+ return ErrInvalidHuffman
+ }
+ if n.children != nil || n.codeLen > cbits {
break
}
+ if maxLen != 0 && buf.Len() == maxLen {
+ return ErrStringLength
+ }
buf.WriteByte(n.sym)
- nbits -= n.codeLen
+ cbits -= n.codeLen
n = rootHuffmanNode
+ sbits = cbits
+ }
+ if sbits > 7 {
+ // Either there was an incomplete symbol, or overlong padding.
+ // Both are decoding errors per RFC 7541 section 5.2.
+ return ErrInvalidHuffman
}
+ if mask := uint(1<<cbits - 1); cur&mask != mask {
+ // Trailing bits must be a prefix of EOS per RFC 7541 section 5.2.
+ return ErrInvalidHuffman
+ }
+
return nil
}