]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: add Std{in,out,err}Pipe methods to Session
authorDave Cheney <dave@cheney.net>
Thu, 1 Dec 2011 10:30:16 +0000 (08:30 -0200)
committerGustavo Niemeyer <gustavo@niemeyer.net>
Thu, 1 Dec 2011 10:30:16 +0000 (08:30 -0200)
R=gustav.paul, cw, agl, rsc, n13m3y3r
CC=golang-dev
https://golang.org/cl/5433080

src/pkg/exp/ssh/session.go
src/pkg/exp/ssh/session_test.go [new file with mode: 0644]

index dab0113f4b5bc00007b3e74e716947a25c0dd6bd..8eea8b287b721eb7d0fc8c4164cd691628e47eb2 100644 (file)
@@ -54,9 +54,10 @@ type Session struct {
 
        *clientChan // the channel backing this session
 
-       started   bool // true once a Shell or Run is invoked.
-       copyFuncs []func() error
-       errch     chan error // one send per copyFunc
+       started        bool // true once Start, Run or Shell is invoked.
+       closeAfterWait []io.Closer
+       copyFuncs      []func() error
+       errch          chan error // one send per copyFunc
 }
 
 // RFC 4254 Section 6.4.
@@ -231,7 +232,7 @@ func (s *Session) start() error {
        return nil
 }
 
-// Wait waits for the remote command to exit. 
+// Wait waits for the remote command to exit.
 func (s *Session) Wait() error {
        if !s.started {
                return errors.New("ssh: session not started")
@@ -244,11 +245,12 @@ func (s *Session) Wait() error {
                        copyError = err
                }
        }
-
+       for _, fd := range s.closeAfterWait {
+               fd.Close()
+       }
        if waitErr != nil {
                return waitErr
        }
-
        return copyError
 }
 
@@ -283,11 +285,15 @@ func (s *Session) stdin() error {
                s.Stdin = new(bytes.Buffer)
        }
        s.copyFuncs = append(s.copyFuncs, func() error {
-               _, err := io.Copy(&chanWriter{
+               w := &chanWriter{
                        packetWriter: s,
                        peersId:      s.peersId,
                        win:          s.win,
-               }, s.Stdin)
+               }
+               _, err := io.Copy(w, s.Stdin)
+               if err1 := w.Close(); err == nil {
+                       err = err1
+               }
                return err
        })
        return nil
@@ -298,11 +304,12 @@ func (s *Session) stdout() error {
                s.Stdout = ioutil.Discard
        }
        s.copyFuncs = append(s.copyFuncs, func() error {
-               _, err := io.Copy(s.Stdout, &chanReader{
+               r := &chanReader{
                        packetWriter: s,
                        peersId:      s.peersId,
                        data:         s.data,
-               })
+               }
+               _, err := io.Copy(s.Stdout, r)
                return err
        })
        return nil
@@ -313,16 +320,72 @@ func (s *Session) stderr() error {
                s.Stderr = ioutil.Discard
        }
        s.copyFuncs = append(s.copyFuncs, func() error {
-               _, err := io.Copy(s.Stderr, &chanReader{
+               r := &chanReader{
                        packetWriter: s,
                        peersId:      s.peersId,
                        data:         s.dataExt,
-               })
+               }
+               _, err := io.Copy(s.Stderr, r)
                return err
        })
        return nil
 }
 
+// StdinPipe returns a pipe that will be connected to the 
+// remote command's standard input when the command starts.
+func (s *Session) StdinPipe() (io.WriteCloser, error) {
+       if s.Stdin != nil {
+               return nil, errors.New("ssh: Stdin already set")
+       }
+       if s.started {
+               return nil, errors.New("ssh: StdinPipe after process started")
+       }
+       pr, pw := io.Pipe()
+       s.Stdin = pr
+       s.closeAfterWait = append(s.closeAfterWait, pr)
+       return pw, nil
+}
+
+// StdoutPipe returns a pipe that will be connected to the 
+// remote command's standard output when the command starts.
+// There is a fixed amount of buffering that is shared between
+// stdout and stderr streams. If the StdoutPipe reader is 
+// not serviced fast enought it may eventually cause the 
+// remote command to block.
+func (s *Session) StdoutPipe() (io.ReadCloser, error) {
+       if s.Stdout != nil {
+               return nil, errors.New("ssh: Stdout already set")
+       }
+       if s.started {
+               return nil, errors.New("ssh: StdoutPipe after process started")
+       }
+       pr, pw := io.Pipe()
+       s.Stdout = pw
+       s.closeAfterWait = append(s.closeAfterWait, pw)
+       return pr, nil
+}
+
+// StderrPipe returns a pipe that will be connected to the 
+// remote command's standard error when the command starts.
+// There is a fixed amount of buffering that is shared between
+// stdout and stderr streams. If the StderrPipe reader is 
+// not serviced fast enought it may eventually cause the 
+// remote command to block.
+func (s *Session) StderrPipe() (io.ReadCloser, error) {
+       if s.Stderr != nil {
+               return nil, errors.New("ssh: Stderr already set")
+       }
+       if s.started {
+               return nil, errors.New("ssh: StderrPipe after process started")
+       }
+       pr, pw := io.Pipe()
+       s.Stderr = pw
+       s.closeAfterWait = append(s.closeAfterWait, pw)
+       return pr, nil
+}
+
+// TODO(dfc) add Output and CombinedOutput helpers
+
 // NewSession returns a new interactive session on the remote host.
 func (c *ClientConn) NewSession() (*Session, error) {
        ch := c.newChan(c.transport)
diff --git a/src/pkg/exp/ssh/session_test.go b/src/pkg/exp/ssh/session_test.go
new file mode 100644 (file)
index 0000000..4be7746
--- /dev/null
@@ -0,0 +1,149 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+// Session tests.
+
+import (
+       "bytes"
+       "io"
+       "testing"
+)
+
+// dial constructs a new test server and returns a *ClientConn.
+func dial(t *testing.T) *ClientConn {
+       pw := password("tiger")
+       serverConfig.PasswordCallback = func(user, pass string) bool {
+               return user == "testuser" && pass == string(pw)
+       }
+       serverConfig.PubKeyCallback = nil
+
+       l, err := Listen("tcp", "127.0.0.1:0", serverConfig)
+       if err != nil {
+               t.Fatalf("unable to listen: %s", err)
+       }
+       go func() {
+               defer l.Close()
+               conn, err := l.Accept()
+               if err != nil {
+                       t.Errorf("Unable to accept: %v", err)
+                       return
+               }
+               defer conn.Close()
+               if err := conn.Handshake(); err != nil {
+                       t.Errorf("Unable to handshake: %v", err)
+                       return
+               }
+               for {
+                       ch, err := conn.Accept()
+                       if err == io.EOF {
+                               return
+                       }
+                       if err != nil {
+                               t.Errorf("Unable to accept incoming channel request: %v", err)
+                               return
+                       }
+                       if ch.ChannelType() != "session" {
+                               ch.Reject(UnknownChannelType, "unknown channel type")
+                               continue
+                       }
+                       ch.Accept()
+                       go func() {
+                               defer ch.Close()
+                               // this string is returned to stdout
+                               shell := NewServerShell(ch, "golang")
+                               shell.ReadLine()
+                               type exitMsg struct {
+                                       PeersId   uint32
+                                       Request   string
+                                       WantReply bool
+                                       Status    uint32
+                               }
+                               // TODO(dfc) casting to the concrete type should not be
+                               // necessary to send a packet.
+                               msg := exitMsg{
+                                       PeersId:   ch.(*channel).theirId,
+                                       Request:   "exit-status",
+                                       WantReply: false,
+                                       Status:    0,
+                               }
+                               ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg))
+                       }()
+               }
+               t.Log("done")
+       }()
+
+       config := &ClientConfig{
+               User: "testuser",
+               Auth: []ClientAuth{
+                       ClientAuthPassword(pw),
+               },
+       }
+
+       c, err := Dial("tcp", l.Addr().String(), config)
+       if err != nil {
+               t.Fatalf("unable to dial remote side: %s", err)
+       }
+       return c
+}
+
+// Test a simple string is returned to session.Stdout.
+func TestSessionShell(t *testing.T) {
+       conn := dial(t)
+       defer conn.Close()
+       session, err := conn.NewSession()
+       if err != nil {
+               t.Fatalf("Unable to request new session: %s", err)
+       }
+       defer session.Close()
+       stdout := new(bytes.Buffer)
+       session.Stdout = stdout
+       if err := session.Shell(); err != nil {
+               t.Fatalf("Unable to execute command: %s", err)
+       }
+       if err := session.Wait(); err != nil {
+               t.Fatalf("Remote command did not exit cleanly: %s", err)
+       }
+       actual := stdout.String()
+       if actual != "golang" {
+               t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
+       }
+}
+
+// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.
+
+// Test a simple string is returned via StdoutPipe.
+func TestSessionStdoutPipe(t *testing.T) {
+       conn := dial(t)
+       defer conn.Close()
+       session, err := conn.NewSession()
+       if err != nil {
+               t.Fatalf("Unable to request new session: %s", err)
+       }
+       defer session.Close()
+       stdout, err := session.StdoutPipe()
+       if err != nil {
+               t.Fatalf("Unable to request StdoutPipe(): %v", err)
+       }
+       var buf bytes.Buffer
+       if err := session.Shell(); err != nil {
+               t.Fatalf("Unable to execute command: %s", err)
+       }
+       done := make(chan bool, 1)
+       go func() {
+               if _, err := io.Copy(&buf, stdout); err != nil {
+                       t.Errorf("Copy of stdout failed: %v", err)
+               }
+               done <- true
+       }()
+       if err := session.Wait(); err != nil {
+               t.Fatalf("Remote command did not exit cleanly: %s", err)
+       }
+       <-done
+       actual := buf.String()
+       if actual != "golang" {
+               t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
+       }
+}