var (
errBadUint = os.ErrorString("gob: encoded unsigned integer out of range");
+ errBadType = os.ErrorString("gob: unknown type id or corrupted data");
errRange = os.ErrorString("gob: internal error: field numbers out of bounds");
errNotStruct = os.ErrorString("gob: TODO: can only handle structs")
)
return os.ErrorString("gob: decode can't handle " + rt.String())
}
typeLock.Lock();
+ if _, ok := idToType[wireId]; !ok {
+ typeLock.Unlock();
+ return errBadType;
+ }
enginePtr, err := getDecEnginePtr(wireId, rt);
typeLock.Unlock();
if err != nil {
dec.state.err = nil;
for {
// Read a count.
- nbytes, err := decodeUintReader(dec.r, dec.oneByte);
- if err != nil {
- return err;
+ var nbytes uint64;
+ nbytes, dec.state.err = decodeUintReader(dec.r, dec.oneByte);
+ if dec.state.err != nil {
+ break;
}
-
// Allocate the buffer.
if nbytes > uint64(len(dec.buf)) {
dec.buf = make([]byte, nbytes + 1000);
// Read the data
var n int;
- n, err = dec.r.Read(dec.buf[0:nbytes]);
- if err != nil {
- return err;
+ n, dec.state.err = io.ReadFull(dec.r, dec.buf[0:nbytes]);
+ if dec.state.err != nil {
+ break;
}
if n < int(nbytes) {
- return os.ErrorString("gob decode: short read");
+ dec.state.err = io.ErrUnexpectedEOF;
+ break;
}
// Receive a type id.
import (
"bytes";
"gob";
+ "io";
"os";
"reflect";
"strings";
badTypeCheck(new(ET3), false, "different name of field", t);
badTypeCheck(new(ET4), true, "different type of field", t);
}
+
+func corruptDataCheck(s string, err os.Error, t *testing.T) {
+ b := bytes.NewBuffer(strings.Bytes(s));
+ dec := NewDecoder(b);
+ dec.Decode(new(ET2));
+ if dec.state.err != err {
+ t.Error("expected error", err, "got", dec.state.err);
+ }
+}
+
+// Check that we survive bad data.
+func TestBadData(t *testing.T) {
+ corruptDataCheck("\x01\x01\x01", os.EOF, t);
+ corruptDataCheck("\x7Fhi", io.ErrUnexpectedEOF, t);
+ corruptDataCheck("\x03now is the time for all good men", errBadType, t);
+}