From adbe812a41b47a10cfa9feb9406caad3971752db Mon Sep 17 00:00:00 2001 From: Andrey Petrov Date: Tue, 9 Dec 2014 19:26:55 -0800 Subject: [PATCH] Set term width properly. --- README.md | 17 ++++++------- client.go | 72 ++++++++++++++++++++++++++++++++++++++++++------------- pty.go | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++++ server.go | 23 +++++++++++------- 4 files changed, 148 insertions(+), 33 deletions(-) create mode 100644 pty.go diff --git a/README.md b/README.md index a2108a4..3e6aab3 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,12 @@ Coming real soon. ## TODO: -* Welcome message. -* set term width properly -* client map rather than list -* backfill chat history -* tab completion -* /help -* /about -* /list +* [x] Welcome message. +* [x] set term width properly +* [x] client map rather than list +* [ ] backfill chat history +* [ ] tab completion +* [ ] /help +* [ ] /about +* [ ] /list +* [ ] pubkey fingerprint diff --git a/client.go b/client.go index 6f4acca..0f756fa 100644 --- a/client.go +++ b/client.go @@ -10,37 +10,55 @@ import ( const MSG_BUFFER = 10 type Client struct { - Server *Server - Msg chan string - Name string + Server *Server + Conn *ssh.ServerConn + Msg chan string + Name string + term *terminal.Terminal + termWidth int + termHeight int } -func NewClient(server *Server, name string) *Client { +func NewClient(server *Server, conn *ssh.ServerConn, name string) *Client { if name == "" { name = "Anonymoose" } return &Client{ Server: server, + Conn: conn, Name: name, Msg: make(chan string, MSG_BUFFER), } } +func (c *Client) Resize(width int, height int) error { + err := c.term.SetSize(width, height) + if err != nil { + logger.Errorf("Resize failed: %dx%d", width, height) + return err + } + c.termWidth, c.termHeight = width, height + return nil +} + +func (c *Client) sendWelcome() { + msg := fmt.Sprintf("Welcome to ssh-chat. Enter /help for more.\r\n") + c.Msg <- msg + +} + func (c *Client) handleShell(channel ssh.Channel) { defer channel.Close() - prompt := fmt.Sprintf("%s> ", c.Name) - term := terminal.NewTerminal(channel, prompt) - go func() { for msg := range c.Msg { - term.Write([]byte(msg)) + c.term.Write([]byte(msg)) } }() for { - line, err := term.ReadLine() + line, err := c.term.ReadLine() if err != nil { break } @@ -50,13 +68,16 @@ func (c *Client) handleShell(channel ssh.Channel) { channel.Close() } - term.Write(term.Escape.Reset) + //c.term.Write(c.term.Escape.Reset) msg := fmt.Sprintf("%s: %s\r\n", c.Name, line) c.Server.Broadcast(msg, c) } + } func (c *Client) handleChannels(channels <-chan ssh.NewChannel) { + prompt := fmt.Sprintf("[%s] ", c.Name) + for ch := range channels { if t := ch.ChannelType(); t != "session" { ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) @@ -69,25 +90,42 @@ func (c *Client) handleChannels(channels <-chan ssh.NewChannel) { continue } + c.term = terminal.NewTerminal(channel, prompt) + go func(in <-chan *ssh.Request) { defer channel.Close() + hasShell := false for req := range in { - ok := false + var width, height int + var ok bool + switch req.Type { case "shell": - if len(req.Payload) == 0 { + if c.term != nil && !hasShell { + go c.handleShell(channel) ok = true + hasShell = true } case "pty-req": - // Setup PTY - ok = true + width, height, ok = parsePtyRequest(req.Payload) + if ok { + err := c.Resize(width, height) + ok = err == nil + } + case "window-change": + width, height, ok = parseWinchRequest(req.Payload) + if ok { + err := c.Resize(width, height) + ok = err == nil + } + } + + if req.WantReply { + req.Reply(ok, nil) } - req.Reply(ok, nil) } }(requests) - go c.handleShell(channel) - // We don't care about other channels? return } diff --git a/pty.go b/pty.go new file mode 100644 index 0000000..e635fba --- /dev/null +++ b/pty.go @@ -0,0 +1,69 @@ +// Borrowed from go.crypto circa 2011 +package main + +import "encoding/binary" + +// parsePtyRequest parses the payload of the pty-req message and extracts the +// dimensions of the terminal. See RFC 4254, section 6.2. +func parsePtyRequest(s []byte) (width, height int, ok bool) { + _, s, ok = parseString(s) + if !ok { + return + } + width32, s, ok := parseUint32(s) + if !ok { + return + } + height32, _, ok := parseUint32(s) + width = int(width32) + height = int(height32) + if width < 1 { + ok = false + } + if height < 1 { + ok = false + } + return +} + +func parseWinchRequest(s []byte) (width, height int, ok bool) { + width32, s, ok := parseUint32(s) + if !ok { + return + } + height32, s, ok := parseUint32(s) + if !ok { + return + } + + width = int(width32) + height = int(height32) + if width < 1 { + ok = false + } + if height < 1 { + ok = false + } + return +} + +func parseString(in []byte) (out string, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + if uint32(len(in)) < 4+length { + return + } + out = string(in[4 : 4+length]) + rest = in[4+length:] + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} diff --git a/server.go b/server.go index b4436e7..ba702c6 100644 --- a/server.go +++ b/server.go @@ -13,7 +13,7 @@ type Server struct { sshConfig *ssh.ServerConfig sshSigner *ssh.Signer done chan struct{} - clients map[Client]struct{} + clients map[*Client]struct{} lock sync.Mutex } @@ -29,6 +29,7 @@ func NewServer(privateKey []byte) (*Server, error) { return nil, nil }, PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + // fingerprint := md5.Sum(key.Marshal() return nil, nil }, } @@ -38,7 +39,7 @@ func NewServer(privateKey []byte) (*Server, error) { sshConfig: &config, sshSigner: &signer, done: make(chan struct{}), - clients: map[Client]struct{}{}, + clients: map[*Client]struct{}{}, } return &server, nil @@ -47,7 +48,7 @@ func NewServer(privateKey []byte) (*Server, error) { func (s *Server) Broadcast(msg string, except *Client) { logger.Debugf("Broadcast to %d: %s", len(s.clients), strings.TrimRight(msg, "\r\n")) for client := range s.clients { - if except != nil && client == *except { + if except != nil && client == except { continue } client.Msg <- msg @@ -87,23 +88,25 @@ func (s *Server) Start(laddr string) error { go ssh.DiscardRequests(requests) - client := NewClient(s, sshConn.User()) + client := NewClient(s, sshConn, sshConn.User()) // TODO: mutex this s.lock.Lock() - s.clients[*client] = struct{}{} + s.clients[client] = struct{}{} num := len(s.clients) s.lock.Unlock() - s.Broadcast(fmt.Sprintf("* Joined: %s (%d present)\r\n", client.Name, num), nil) + client.sendWelcome() + + s.Broadcast(fmt.Sprintf("* %s joined. (Total connected: %d)\r\n", client.Name, num), nil) go func() { sshConn.Wait() s.lock.Lock() - delete(s.clients, *client) + delete(s.clients, client) s.lock.Unlock() - s.Broadcast(fmt.Sprintf("* Left: %s\r\n", client.Name), nil) + s.Broadcast(fmt.Sprintf("* %s left.\r\n", client.Name), nil) }() go client.handleChannels(channels) @@ -120,5 +123,9 @@ func (s *Server) Start(laddr string) error { } func (s *Server) Stop() { + for client := range s.clients { + client.Conn.Close() + } + close(s.done) }