// Read version string as specified by RFC 4253, section 4.2.
func readVersion(r io.Reader) ([]byte, error) {
versionString := make([]byte, 0, 64)
- var ok, seenCR bool
+ var ok bool
var buf [1]byte
forEachByte:
for len(versionString) < maxVersionStringBytes {
if err != nil {
return nil, err
}
- b := buf[0]
-
- if !seenCR {
- if b == '\r' {
- seenCR = true
- }
- } else {
- if b == '\n' {
- ok = true
- break forEachByte
- } else {
- seenCR = false
- }
+ // The RFC says that the version should be terminated with \r\n
+ // but several SSH servers actually only send a \n.
+ if buf[0] == '\n' {
+ ok = true
+ break forEachByte
}
- versionString = append(versionString, b)
+ versionString = append(versionString, buf[0])
}
if !ok {
- return nil, errors.New("failed to read version string")
+ return nil, errors.New("ssh: failed to read version string")
}
- // We need to remove the CR from versionString
- return versionString[:len(versionString)-1], nil
+ // There might be a '\r' on the end which we should remove.
+ if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' {
+ versionString = versionString[:len(versionString)-1]
+ }
+ return versionString, nil
}
)
func TestReadVersion(t *testing.T) {
- buf := []byte(serverVersion)
+ buf := serverVersion
result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
if err != nil {
t.Errorf("readVersion didn't read version correctly: %s", err)
}
}
+func TestReadVersionWithJustLF(t *testing.T) {
+ var buf []byte
+ buf = append(buf, serverVersion...)
+ buf = buf[:len(buf)-1]
+ buf[len(buf)-1] = '\n'
+ result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
+ if err != nil {
+ t.Error("readVersion failed to handle just a \n")
+ }
+ if !bytes.Equal(buf[:len(buf)-1], result) {
+ t.Errorf("version read did not match expected: got %x, want %x", result, buf[:len(buf)-1])
+ }
+}
+
func TestReadVersionTooLong(t *testing.T) {
buf := make([]byte, maxVersionStringBytes+1)
if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
}
func TestReadVersionWithoutCRLF(t *testing.T) {
- buf := []byte(serverVersion)
+ buf := serverVersion
buf = buf[:len(buf)-1]
if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
t.Error("readVersion did not notice \\n was missing")