diff --git a/cmd.go b/cmd.go index 34840ad..89ad01a 100644 --- a/cmd.go +++ b/cmd.go @@ -34,6 +34,7 @@ var logLevels = []log.Level{ } var buildCommit string + func main() { options := Options{} parser := flags.NewParser(&options, flags.Default) diff --git a/sshd/auth.go b/sshd/auth.go new file mode 100644 index 0000000..d271a85 --- /dev/null +++ b/sshd/auth.go @@ -0,0 +1,68 @@ +package sshd + +import ( + "crypto/sha1" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/ssh" +) + +var errBanned = errors.New("banned") +var errNotWhitelisted = errors.New("not whitelisted") +var errNoInteractive = errors.New("public key authentication required") + +type Auth interface { + IsBanned(ssh.PublicKey) bool + IsWhitelisted(ssh.PublicKey) bool +} + +func MakeAuth(auth Auth) *ssh.ServerConfig { + config := ssh.ServerConfig{ + NoClientAuth: false, + // Auth-related things should be constant-time to avoid timing attacks. + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if auth.IsBanned(key) { + return nil, errBanned + } + if !auth.IsWhitelisted(key) { + return nil, errNotWhitelisted + } + perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}} + return perm, nil + }, + KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { + if auth.IsBanned(nil) { + return nil, errNoInteractive + } + if !auth.IsWhitelisted(nil) { + return nil, errNotWhitelisted + } + return nil, nil + }, + } + + return &config +} + +func MakeNoAuth() *ssh.ServerConfig { + config := ssh.ServerConfig{ + NoClientAuth: false, + // Auth-related things should be constant-time to avoid timing attacks. + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + return nil, nil + }, + KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { + return nil, nil + }, + } + + return &config +} + +func Fingerprint(k ssh.PublicKey) string { + hash := sha1.Sum(k.Marshal()) + r := fmt.Sprintf("% x", hash) + return strings.Replace(r, " ", ":", -1) +} diff --git a/sshd/doc.go b/sshd/doc.go new file mode 100644 index 0000000..4d94d8b --- /dev/null +++ b/sshd/doc.go @@ -0,0 +1,33 @@ +package sshd + +/* + + signer, err := ssh.ParsePrivateKey(privateKey) + + config := MakeNoAuth() + config.AddHostKey(signer) + + s, err := ListenSSH("0.0.0.0:22", config) + if err != nil { + // Handle opening socket error + } + + terminals := s.ServeTerminal() + + for term := range terminals { + go func() { + defer term.Close() + term.SetPrompt("...") + term.AutoCompleteCallback = nil // ... + + for { + line, err := term.Readline() + if err != nil { + break + } + term.Write(...) + } + + }() + } +*/ diff --git a/sshd/logger.go b/sshd/logger.go new file mode 100644 index 0000000..49a4456 --- /dev/null +++ b/sshd/logger.go @@ -0,0 +1,22 @@ +package sshd + +import "io" +import stdlog "log" + +var logger *stdlog.Logger + +func SetLogger(w io.Writer) { + flags := stdlog.Flags() + prefix := "[chat] " + logger = stdlog.New(w, prefix, flags) +} + +type nullWriter struct{} + +func (nullWriter) Write(data []byte) (int, error) { + return len(data), nil +} + +func init() { + SetLogger(nullWriter{}) +} diff --git a/sshd/multi.go b/sshd/multi.go new file mode 100644 index 0000000..62447a7 --- /dev/null +++ b/sshd/multi.go @@ -0,0 +1,42 @@ +package sshd + +import ( + "fmt" + "io" + "strings" +) + +// Keep track of multiple errors and coerce them into one error +type MultiError []error + +func (e MultiError) Error() string { + switch len(e) { + case 0: + return "" + case 1: + return e[0].Error() + default: + errs := []string{} + for _, err := range e { + errs = append(errs, err.Error()) + } + return fmt.Sprintf("%d errors: %s", strings.Join(errs, "; ")) + } +} + +// Keep track of multiple closers and close them all as one closer +type MultiCloser []io.Closer + +func (c MultiCloser) Close() error { + errors := MultiError{} + for _, closer := range c { + err := closer.Close() + if err != nil { + errors = append(errors, err) + } + } + if len(errors) == 0 { + return nil + } + return errors +} diff --git a/sshd/net.go b/sshd/net.go new file mode 100644 index 0000000..ba34bc0 --- /dev/null +++ b/sshd/net.go @@ -0,0 +1,68 @@ +package sshd + +import ( + "net" + "syscall" + + "golang.org/x/crypto/ssh" +) + +// Container for the connection and ssh-related configuration +type SSHListener struct { + net.Listener + config *ssh.ServerConfig +} + +// Make an SSH listener socket +func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) { + socket, err := net.Listen("tcp", laddr) + if err != nil { + return nil, err + } + l := socket.(SSHListener) + l.config = config + return &l, nil +} + +func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) { + // Upgrade TCP connection to SSH connection + sshConn, channels, requests, err := ssh.NewServerConn(conn, l.config) + if err != nil { + return nil, err + } + + go ssh.DiscardRequests(requests) + return NewSession(sshConn, channels) +} + +// Accept incoming connections as terminal requests and yield them +func (l *SSHListener) ServeTerminal() <-chan *Terminal { + ch := make(chan *Terminal) + + go func() { + defer l.Close() + + for { + conn, err := l.Accept() + + if err != nil { + logger.Printf("Failed to accept connection: %v", err) + if err == syscall.EINVAL { + return + } + } + + // Goroutineify to resume accepting sockets early + go func() { + term, err := l.handleConn(conn) + if err != nil { + logger.Printf("Failed to handshake: %v", err) + return + } + ch <- term + }() + } + }() + + return ch +} diff --git a/sshd/pty.go b/sshd/pty.go new file mode 100644 index 0000000..5aecc3e --- /dev/null +++ b/sshd/pty.go @@ -0,0 +1,69 @@ +// Borrowed from go.crypto circa 2011 +package sshd + +import "encoding/binary" + +// parsePtyRequest parses the payload of the pty-req message and extracts the +// dimensions of the terminal. See RFC 4254, section 6.2. +func parsePtyRequest(s []byte) (width, height int, ok bool) { + _, s, ok = parseString(s) + if !ok { + return + } + width32, s, ok := parseUint32(s) + if !ok { + return + } + height32, _, ok := parseUint32(s) + width = int(width32) + height = int(height32) + if width < 1 { + ok = false + } + if height < 1 { + ok = false + } + return +} + +func parseWinchRequest(s []byte) (width, height int, ok bool) { + width32, s, ok := parseUint32(s) + if !ok { + return + } + height32, s, ok := parseUint32(s) + if !ok { + return + } + + width = int(width32) + height = int(height32) + if width < 1 { + ok = false + } + if height < 1 { + ok = false + } + return +} + +func parseString(in []byte) (out string, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + if uint32(len(in)) < 4+length { + return + } + out = string(in[4 : 4+length]) + rest = in[4+length:] + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} diff --git a/sshd/server.go b/sshd/server.go new file mode 100644 index 0000000..cd8980c --- /dev/null +++ b/sshd/server.go @@ -0,0 +1,98 @@ +package sshd + +import ( + "net" + "sync" + "syscall" + "time" + + "golang.org/x/crypto/ssh" +) + +// Server holds all the fields used by a server +type Server struct { + sshConfig *ssh.ServerConfig + done chan struct{} + started time.Time + sync.RWMutex +} + +// Initialize a new server +func NewServer(privateKey []byte) (*Server, error) { + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + return nil, err + } + + server := Server{ + done: make(chan struct{}), + started: time.Now(), + } + + config := MakeNoAuth() + config.AddHostKey(signer) + + server.sshConfig = config + + return &server, nil +} + +// Start starts the server +func (s *Server) Start(laddr string) error { + // Once a ServerConfig has been configured, connections can be + // accepted. + socket, err := net.Listen("tcp", laddr) + if err != nil { + return err + } + + logger.Infof("Listening on %s", laddr) + + go func() { + defer socket.Close() + for { + conn, err := socket.Accept() + + if err != nil { + logger.Printf("Failed to accept connection: %v", err) + if err == syscall.EINVAL { + // TODO: Handle shutdown more gracefully? + return + } + } + + // Goroutineify to resume accepting sockets early. + go func() { + // From a standard TCP connection to an encrypted SSH connection + sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig) + if err != nil { + logger.Printf("Failed to handshake: %v", err) + return + } + + go ssh.DiscardRequests(requests) + + client := NewClient(s, sshConn) + go client.handleChannels(channels) + }() + } + }() + + go func() { + <-s.done + socket.Close() + }() + + return nil +} + +// Stop stops the server +func (s *Server) Stop() { + s.Lock() + for _, client := range s.clients { + client.Conn.Close() + } + s.Unlock() + + close(s.done) +} diff --git a/sshd/terminal.go b/sshd/terminal.go new file mode 100644 index 0000000..e872bf6 --- /dev/null +++ b/sshd/terminal.go @@ -0,0 +1,93 @@ +package sshd + +import ( + "errors" + "fmt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +// Extending ssh/terminal to include a closer interface +type Terminal struct { + *terminal.Terminal + Conn ssh.Conn + Channel ssh.Channel +} + +// Make new terminal from a session channel +func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) { + if ch.ChannelType() != "session" { + return nil, errors.New("terminal requires session channel") + } + channel, requests, err := ch.Accept() + if err != nil { + return nil, err + } + term := Terminal{ + terminal.NewTerminal(channel, "Connecting..."), + conn, + channel, + } + + go term.listen(requests) + return &term, nil +} + +// Find session channel and make a Terminal from it +func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, err error) { + for ch := range channels { + if t := ch.ChannelType(); t != "session" { + ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + continue + } + + term, err = NewTerminal(conn, ch) + if err == nil { + break + } + } + + return term, err +} + +// Close terminal and ssh connection +func (t *Terminal) Close() error { + return MultiCloser{t.Channel, t.Conn}.Close() +} + +// Negotiate terminal type and settings +func (t *Terminal) listen(requests <-chan *ssh.Request) { + hasShell := false + + for req := range requests { + var width, height int + var ok bool + + switch req.Type { + case "shell": + if !hasShell { + ok = true + hasShell = true + } + case "pty-req": + width, height, ok = parsePtyRequest(req.Payload) + if ok { + // TODO: Hardcode width to 100000? + err := t.SetSize(width, height) + ok = err == nil + } + case "window-change": + width, height, ok = parseWinchRequest(req.Payload) + if ok { + // TODO: Hardcode width to 100000? + err := t.SetSize(width, height) + ok = err == nil + } + } + + if req.WantReply { + req.Reply(ok, nil) + } + } +}