From 7beb7f99bbc60e38c4f99f23d4cddd72f62190d0 Mon Sep 17 00:00:00 2001 From: Andrey Petrov Date: Mon, 22 Dec 2014 15:53:30 -0800 Subject: [PATCH] Testing for net. --- sshd/logger.go | 2 +- sshd/net.go | 9 ++-- sshd/net_test.go | 137 +++++++++++++++++++++++++++++++++++++++++++++++ sshd/server.go | 98 --------------------------------- sshd/terminal.go | 4 +- 5 files changed, 143 insertions(+), 107 deletions(-) create mode 100644 sshd/net_test.go delete mode 100644 sshd/server.go diff --git a/sshd/logger.go b/sshd/logger.go index 49a4456..9f6998f 100644 --- a/sshd/logger.go +++ b/sshd/logger.go @@ -7,7 +7,7 @@ var logger *stdlog.Logger func SetLogger(w io.Writer) { flags := stdlog.Flags() - prefix := "[chat] " + prefix := "[sshd] " logger = stdlog.New(w, prefix, flags) } diff --git a/sshd/net.go b/sshd/net.go index ba34bc0..6a30976 100644 --- a/sshd/net.go +++ b/sshd/net.go @@ -2,7 +2,6 @@ package sshd import ( "net" - "syscall" "golang.org/x/crypto/ssh" ) @@ -19,8 +18,7 @@ func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) { if err != nil { return nil, err } - l := socket.(SSHListener) - l.config = config + l := SSHListener{socket, config} return &l, nil } @@ -41,15 +39,14 @@ func (l *SSHListener) ServeTerminal() <-chan *Terminal { go func() { defer l.Close() + defer close(ch) for { conn, err := l.Accept() if err != nil { logger.Printf("Failed to accept connection: %v", err) - if err == syscall.EINVAL { - return - } + return } // Goroutineify to resume accepting sockets early diff --git a/sshd/net_test.go b/sshd/net_test.go new file mode 100644 index 0000000..6ec4311 --- /dev/null +++ b/sshd/net_test.go @@ -0,0 +1,137 @@ +package sshd + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "io" + "testing" + + "golang.org/x/crypto/ssh" +) + +// TODO: Move some of these into their own package? + +func MakeKey(bits int) (ssh.Signer, error) { + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, err + } + return ssh.NewSignerFromKey(key) +} + +func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error { + config := &ssh.ClientConfig{ + User: name, + Auth: []ssh.AuthMethod{ + ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + return + }), + }, + } + + conn, err := ssh.Dial("tcp", host, config) + if err != nil { + return err + } + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + return err + } + defer session.Close() + + in, err := session.StdinPipe() + if err != nil { + return err + } + + out, err := session.StdoutPipe() + if err != nil { + return err + } + + err = session.Shell() + if err != nil { + return err + } + + handler(out, in) + + return nil +} + +func TestServerInit(t *testing.T) { + config := MakeNoAuth() + s, err := ListenSSH(":badport", config) + if err == nil { + t.Fatal("should fail on bad port") + } + + s, err = ListenSSH(":0", config) + if err != nil { + t.Error(err) + } + + err = s.Close() + if err != nil { + t.Error(err) + } +} + +func TestServeTerminals(t *testing.T) { + signer, err := MakeKey(512) + config := MakeNoAuth() + config.AddHostKey(signer) + + s, err := ListenSSH(":0", config) + if err != nil { + t.Fatal(err) + } + + terminals := s.ServeTerminal() + + go func() { + // Accept one terminal, read from it, echo back, close. + term := <-terminals + term.SetPrompt("> ") + + line, err := term.ReadLine() + if err != nil { + t.Error(err) + } + _, err = term.Write([]byte("echo: " + line + "\r\n")) + if err != nil { + t.Error(err) + } + + term.Close() + }() + + host := s.Addr().String() + name := "foo" + + err = NewClientSession(host, name, func(r io.Reader, w io.WriteCloser) { + // Consume if there is anything + buf := new(bytes.Buffer) + w.Write([]byte("hello\r\n")) + + buf.Reset() + _, err := io.Copy(buf, r) + if err != nil { + t.Error(err) + } + + expected := "> hello\r\necho: hello\r\n" + actual := buf.String() + if actual != expected { + t.Errorf("Got `%s`; expected `%s`", actual, expected) + } + s.Close() + }) + + if err != nil { + t.Fatal(err) + } +} diff --git a/sshd/server.go b/sshd/server.go deleted file mode 100644 index cd8980c..0000000 --- a/sshd/server.go +++ /dev/null @@ -1,98 +0,0 @@ -package sshd - -import ( - "net" - "sync" - "syscall" - "time" - - "golang.org/x/crypto/ssh" -) - -// Server holds all the fields used by a server -type Server struct { - sshConfig *ssh.ServerConfig - done chan struct{} - started time.Time - sync.RWMutex -} - -// Initialize a new server -func NewServer(privateKey []byte) (*Server, error) { - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - return nil, err - } - - server := Server{ - done: make(chan struct{}), - started: time.Now(), - } - - config := MakeNoAuth() - config.AddHostKey(signer) - - server.sshConfig = config - - return &server, nil -} - -// Start starts the server -func (s *Server) Start(laddr string) error { - // Once a ServerConfig has been configured, connections can be - // accepted. - socket, err := net.Listen("tcp", laddr) - if err != nil { - return err - } - - logger.Infof("Listening on %s", laddr) - - go func() { - defer socket.Close() - for { - conn, err := socket.Accept() - - if err != nil { - logger.Printf("Failed to accept connection: %v", err) - if err == syscall.EINVAL { - // TODO: Handle shutdown more gracefully? - return - } - } - - // Goroutineify to resume accepting sockets early. - go func() { - // From a standard TCP connection to an encrypted SSH connection - sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig) - if err != nil { - logger.Printf("Failed to handshake: %v", err) - return - } - - go ssh.DiscardRequests(requests) - - client := NewClient(s, sshConn) - go client.handleChannels(channels) - }() - } - }() - - go func() { - <-s.done - socket.Close() - }() - - return nil -} - -// Stop stops the server -func (s *Server) Stop() { - s.Lock() - for _, client := range s.clients { - client.Conn.Close() - } - s.Unlock() - - close(s.done) -} diff --git a/sshd/terminal.go b/sshd/terminal.go index e872bf6..51597c6 100644 --- a/sshd/terminal.go +++ b/sshd/terminal.go @@ -10,7 +10,7 @@ import ( // Extending ssh/terminal to include a closer interface type Terminal struct { - *terminal.Terminal + terminal.Terminal Conn ssh.Conn Channel ssh.Channel } @@ -25,7 +25,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) { return nil, err } term := Terminal{ - terminal.NewTerminal(channel, "Connecting..."), + *terminal.NewTerminal(channel, "Connecting..."), conn, channel, }