diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ef569c7 --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +BINARY = ssh-chat +KEY = host_key +PORT = 2022 + +all: $(BINARY) + +**/*.go: + go build ./... + +$(BINARY): **/*.go *.go + go build . + +build: $(BINARY) + +clean: + rm $(BINARY) + +key: $(KEY) + ssh-keygen -f $(KEY) -P '' + +run: $(BINARY) $(KEY) + ./$(BINARY) -i $(KEY) -b ":$(PORT)" -vv diff --git a/client.go b/client.go new file mode 100644 index 0000000..f67288d --- /dev/null +++ b/client.go @@ -0,0 +1,92 @@ +package main + +import ( + "fmt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +const MSG_BUFFER = 10 + +type Client struct { + Server *Server + Msg chan string + Name string +} + +func NewClient(server *Server, name string) *Client { + if name == "" { + name = "Anonymoose" + } + + return &Client{ + Server: server, + Name: name, + Msg: make(chan string, MSG_BUFFER), + } +} + +func (c *Client) handleShell(channel ssh.Channel) { + defer channel.Close() + + term := terminal.NewTerminal(channel, "") + + for { + line, err := term.ReadLine() + if err != nil { + break + } + + switch line { + case "/exit": + channel.Close() + } + + msg := fmt.Sprintf("%s: %s\r\n", c.Name, line) + c.Server.Broadcast(msg) + } +} + +func (c *Client) handleChannels(channels <-chan ssh.NewChannel) { + for ch := range channels { + if t := ch.ChannelType(); t != "session" { + ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + continue + } + + channel, requests, err := ch.Accept() + if err != nil { + logger.Errorf("Could not accept channel: %v", err) + continue + } + + go func(in <-chan *ssh.Request) { + defer channel.Close() + for req := range in { + ok := false + switch req.Type { + case "shell": + if len(req.Payload) == 0 { + ok = true + } + case "pty-req": + // Setup PTY + ok = true + } + req.Reply(ok, nil) + } + }(requests) + + go c.handleShell(channel) + + go func() { + for msg := range c.Msg { + channel.Write([]byte(msg)) + } + }() + + // We don't care about other channels? + return + } +} diff --git a/cmd.go b/cmd.go index e8c4610..08329ce 100644 --- a/cmd.go +++ b/cmd.go @@ -59,17 +59,14 @@ func main() { // Construct interrupt handler sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt) - go func() { - <-sig // Wait for ^C signal - logger.Warningf("Interrupt signal detected, shutting down.") - server.Stop() - }() - done, err := server.Start(options.Bind) + err = server.Start(options.Bind) if err != nil { logger.Errorf("Failed to start server: %v", err) return } - <-done + <-sig // Wait for ^C signal + logger.Warningf("Interrupt signal detected, shutting down.") + server.Stop() } diff --git a/server.go b/server.go index 7fb02e2..da5a45d 100644 --- a/server.go +++ b/server.go @@ -1,19 +1,20 @@ -// TODO: NoClientAuth - package main import ( "fmt" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" "net" + "strings" + "sync" + + "golang.org/x/crypto/ssh" ) type Server struct { sshConfig *ssh.ServerConfig sshSigner *ssh.Signer - socket *net.Listener done chan struct{} + clients []Client + lock sync.Mutex } func NewServer(privateKey []byte) (*Server, error) { @@ -23,120 +24,87 @@ func NewServer(privateKey []byte) (*Server, error) { } config := ssh.ServerConfig{ - NoClientAuth: true, + NoClientAuth: false, + PasswordCallback: func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + return nil, nil + }, + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + return nil, nil + }, } config.AddHostKey(signer) server := Server{ sshConfig: &config, sshSigner: &signer, + done: make(chan struct{}), } return &server, nil } -func (s *Server) handleShell(channel ssh.Channel) { - defer channel.Close() - - term := terminal.NewTerminal(channel, "") - - for { - line, err := term.ReadLine() - if err != nil { - break - } - - switch line { - case "exit": - channel.Close() - } - - term.Write([]byte("you wrote: " + string(line) + "\r\n")) +func (s *Server) Broadcast(msg string) { + logger.Debugf("Broadcast to %d: %s", len(s.clients), strings.TrimRight(msg, "\r\n")) + for _, client := range s.clients { + client.Msg <- msg } } -func (s *Server) handleChannels(channels <-chan ssh.NewChannel) { - for ch := range channels { - if t := ch.ChannelType(); t != "session" { - ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - continue - } - - channel, requests, err := ch.Accept() - if err != nil { - logger.Errorf("Could not accept channel: %v", err) - continue - } - - go func(in <-chan *ssh.Request) { - defer channel.Close() - for req := range in { - logger.Infof("Request: ", req.Type, string(req.Payload)) - - ok := false - switch req.Type { - case "shell": - // We don't accept any commands (Payload), - // only the default shell. - if len(req.Payload) == 0 { - ok = true - } - case "pty-req": - // Responding 'ok' here will let the client - // know we have a pty ready for input - ok = true - case "window-change": - continue //no response - } - req.Reply(ok, nil) - } - }(requests) - - go s.handleShell(channel) - - channel.Write([]byte("Hello")) - } -} - -func (s *Server) Start(laddr string) (<-chan struct{}, error) { +func (s *Server) Start(laddr string) error { // Once a ServerConfig has been configured, connections can be // accepted. socket, err := net.Listen("tcp", laddr) if err != nil { - return nil, err + return err } - s.socket = &socket logger.Infof("Listening on %s", laddr) go func() { for { conn, err := socket.Accept() + if err != nil { - // TODO: Handle shutdown more gracefully. + // TODO: Handle shutdown more gracefully? logger.Errorf("Failed to accept connection, aborting loop: %v", err) return } - // From a standard TCP connection to an encrypted SSH connection - sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig) - if err != nil { - logger.Errorf("Failed to handshake: %v", err) - continue - } + // Goroutineify to resume accepting sockets early. + go func() { + // From a standard TCP connection to an encrypted SSH connection + sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig) + if err != nil { + logger.Errorf("Failed to handshake: %v", err) + return + } - logger.Infof("Connection from: %s, %s, %s", sshConn.RemoteAddr(), sshConn.User(), sshConn.ClientVersion()) + logger.Infof("Connection from: %s, %s, %s", sshConn.RemoteAddr(), sshConn.User(), sshConn.ClientVersion()) - go ssh.DiscardRequests(requests) - go s.handleChannels(channels) + go ssh.DiscardRequests(requests) + + client := NewClient(s, sshConn.User()) + // TODO: mutex this + + s.lock.Lock() + s.clients = append(s.clients, *client) + s.lock.Unlock() + + s.Broadcast(fmt.Sprintf("* Joined: %s", client.Name)) + + go client.handleChannels(channels) + }() } }() - return s.done, nil + go func() { + <-s.done + socket.Close() + }() + + return nil } -func (s *Server) Stop() error { - err := (*s.socket).Close() - s.done <- struct{}{} - return err +func (s *Server) Stop() { + close(s.done) }