From c4d0ac0e2f7a12cf44f4711b47bbc5737c14ce9c Mon Sep 17 00:00:00 2001 From: Dave Cheney Date: Thu, 1 Dec 2011 08:30:16 -0200 Subject: [PATCH] exp/ssh: add Std{in,out,err}Pipe methods to Session R=gustav.paul, cw, agl, rsc, n13m3y3r CC=golang-dev https://golang.org/cl/5433080 --- src/pkg/exp/ssh/session.go | 87 ++++++++++++++++--- src/pkg/exp/ssh/session_test.go | 149 ++++++++++++++++++++++++++++++++ 2 files changed, 224 insertions(+), 12 deletions(-) create mode 100644 src/pkg/exp/ssh/session_test.go diff --git a/src/pkg/exp/ssh/session.go b/src/pkg/exp/ssh/session.go index dab0113f4b..8eea8b287b 100644 --- a/src/pkg/exp/ssh/session.go +++ b/src/pkg/exp/ssh/session.go @@ -54,9 +54,10 @@ type Session struct { *clientChan // the channel backing this session - started bool // true once a Shell or Run is invoked. - copyFuncs []func() error - errch chan error // one send per copyFunc + started bool // true once Start, Run or Shell is invoked. + closeAfterWait []io.Closer + copyFuncs []func() error + errch chan error // one send per copyFunc } // RFC 4254 Section 6.4. @@ -231,7 +232,7 @@ func (s *Session) start() error { return nil } -// Wait waits for the remote command to exit. +// Wait waits for the remote command to exit. func (s *Session) Wait() error { if !s.started { return errors.New("ssh: session not started") @@ -244,11 +245,12 @@ func (s *Session) Wait() error { copyError = err } } - + for _, fd := range s.closeAfterWait { + fd.Close() + } if waitErr != nil { return waitErr } - return copyError } @@ -283,11 +285,15 @@ func (s *Session) stdin() error { s.Stdin = new(bytes.Buffer) } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(&chanWriter{ + w := &chanWriter{ packetWriter: s, peersId: s.peersId, win: s.win, - }, s.Stdin) + } + _, err := io.Copy(w, s.Stdin) + if err1 := w.Close(); err == nil { + err = err1 + } return err }) return nil @@ -298,11 +304,12 @@ func (s *Session) stdout() error { s.Stdout = ioutil.Discard } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(s.Stdout, &chanReader{ + r := &chanReader{ packetWriter: s, peersId: s.peersId, data: s.data, - }) + } + _, err := io.Copy(s.Stdout, r) return err }) return nil @@ -313,16 +320,72 @@ func (s *Session) stderr() error { s.Stderr = ioutil.Discard } s.copyFuncs = append(s.copyFuncs, func() error { - _, err := io.Copy(s.Stderr, &chanReader{ + r := &chanReader{ packetWriter: s, peersId: s.peersId, data: s.dataExt, - }) + } + _, err := io.Copy(s.Stderr, r) return err }) return nil } +// StdinPipe returns a pipe that will be connected to the +// remote command's standard input when the command starts. +func (s *Session) StdinPipe() (io.WriteCloser, error) { + if s.Stdin != nil { + return nil, errors.New("ssh: Stdin already set") + } + if s.started { + return nil, errors.New("ssh: StdinPipe after process started") + } + pr, pw := io.Pipe() + s.Stdin = pr + s.closeAfterWait = append(s.closeAfterWait, pr) + return pw, nil +} + +// StdoutPipe returns a pipe that will be connected to the +// remote command's standard output when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StdoutPipe reader is +// not serviced fast enought it may eventually cause the +// remote command to block. +func (s *Session) StdoutPipe() (io.ReadCloser, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.started { + return nil, errors.New("ssh: StdoutPipe after process started") + } + pr, pw := io.Pipe() + s.Stdout = pw + s.closeAfterWait = append(s.closeAfterWait, pw) + return pr, nil +} + +// StderrPipe returns a pipe that will be connected to the +// remote command's standard error when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StderrPipe reader is +// not serviced fast enought it may eventually cause the +// remote command to block. +func (s *Session) StderrPipe() (io.ReadCloser, error) { + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + if s.started { + return nil, errors.New("ssh: StderrPipe after process started") + } + pr, pw := io.Pipe() + s.Stderr = pw + s.closeAfterWait = append(s.closeAfterWait, pw) + return pr, nil +} + +// TODO(dfc) add Output and CombinedOutput helpers + // NewSession returns a new interactive session on the remote host. func (c *ClientConn) NewSession() (*Session, error) { ch := c.newChan(c.transport) diff --git a/src/pkg/exp/ssh/session_test.go b/src/pkg/exp/ssh/session_test.go new file mode 100644 index 0000000000..4be7746d17 --- /dev/null +++ b/src/pkg/exp/ssh/session_test.go @@ -0,0 +1,149 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session tests. + +import ( + "bytes" + "io" + "testing" +) + +// dial constructs a new test server and returns a *ClientConn. +func dial(t *testing.T) *ClientConn { + pw := password("tiger") + serverConfig.PasswordCallback = func(user, pass string) bool { + return user == "testuser" && pass == string(pw) + } + serverConfig.PubKeyCallback = nil + + l, err := Listen("tcp", "127.0.0.1:0", serverConfig) + if err != nil { + t.Fatalf("unable to listen: %s", err) + } + go func() { + defer l.Close() + conn, err := l.Accept() + if err != nil { + t.Errorf("Unable to accept: %v", err) + return + } + defer conn.Close() + if err := conn.Handshake(); err != nil { + t.Errorf("Unable to handshake: %v", err) + return + } + for { + ch, err := conn.Accept() + if err == io.EOF { + return + } + if err != nil { + t.Errorf("Unable to accept incoming channel request: %v", err) + return + } + if ch.ChannelType() != "session" { + ch.Reject(UnknownChannelType, "unknown channel type") + continue + } + ch.Accept() + go func() { + defer ch.Close() + // this string is returned to stdout + shell := NewServerShell(ch, "golang") + shell.ReadLine() + type exitMsg struct { + PeersId uint32 + Request string + WantReply bool + Status uint32 + } + // TODO(dfc) casting to the concrete type should not be + // necessary to send a packet. + msg := exitMsg{ + PeersId: ch.(*channel).theirId, + Request: "exit-status", + WantReply: false, + Status: 0, + } + ch.(*channel).serverConn.writePacket(marshal(msgChannelRequest, msg)) + }() + } + t.Log("done") + }() + + config := &ClientConfig{ + User: "testuser", + Auth: []ClientAuth{ + ClientAuthPassword(pw), + }, + } + + c, err := Dial("tcp", l.Addr().String(), config) + if err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + return c +} + +// Test a simple string is returned to session.Stdout. +func TestSessionShell(t *testing.T) { + conn := dial(t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %s", err) + } + defer session.Close() + stdout := new(bytes.Buffer) + session.Stdout = stdout + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %s", err) + } + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %s", err) + } + actual := stdout.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it. + +// Test a simple string is returned via StdoutPipe. +func TestSessionStdoutPipe(t *testing.T) { + conn := dial(t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %s", err) + } + defer session.Close() + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("Unable to request StdoutPipe(): %v", err) + } + var buf bytes.Buffer + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %s", err) + } + done := make(chan bool, 1) + go func() { + if _, err := io.Copy(&buf, stdout); err != nil { + t.Errorf("Copy of stdout failed: %v", err) + } + done <- true + }() + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %s", err) + } + <-done + actual := buf.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} -- 2.48.1