Set term width properly.

This commit is contained in:
Andrey Petrov 2014-12-09 19:26:55 -08:00
parent da934daf4a
commit adbe812a41
4 changed files with 148 additions and 33 deletions

View File

@ -5,11 +5,12 @@ Coming real soon.
## TODO: ## TODO:
* Welcome message. * [x] Welcome message.
* set term width properly * [x] set term width properly
* client map rather than list * [x] client map rather than list
* backfill chat history * [ ] backfill chat history
* tab completion * [ ] tab completion
* /help * [ ] /help
* /about * [ ] /about
* /list * [ ] /list
* [ ] pubkey fingerprint

View File

@ -10,37 +10,55 @@ import (
const MSG_BUFFER = 10 const MSG_BUFFER = 10
type Client struct { type Client struct {
Server *Server Server *Server
Msg chan string Conn *ssh.ServerConn
Name string Msg chan string
Name string
term *terminal.Terminal
termWidth int
termHeight int
} }
func NewClient(server *Server, name string) *Client { func NewClient(server *Server, conn *ssh.ServerConn, name string) *Client {
if name == "" { if name == "" {
name = "Anonymoose" name = "Anonymoose"
} }
return &Client{ return &Client{
Server: server, Server: server,
Conn: conn,
Name: name, Name: name,
Msg: make(chan string, MSG_BUFFER), Msg: make(chan string, MSG_BUFFER),
} }
} }
func (c *Client) Resize(width int, height int) error {
err := c.term.SetSize(width, height)
if err != nil {
logger.Errorf("Resize failed: %dx%d", width, height)
return err
}
c.termWidth, c.termHeight = width, height
return nil
}
func (c *Client) sendWelcome() {
msg := fmt.Sprintf("Welcome to ssh-chat. Enter /help for more.\r\n")
c.Msg <- msg
}
func (c *Client) handleShell(channel ssh.Channel) { func (c *Client) handleShell(channel ssh.Channel) {
defer channel.Close() defer channel.Close()
prompt := fmt.Sprintf("%s> ", c.Name)
term := terminal.NewTerminal(channel, prompt)
go func() { go func() {
for msg := range c.Msg { for msg := range c.Msg {
term.Write([]byte(msg)) c.term.Write([]byte(msg))
} }
}() }()
for { for {
line, err := term.ReadLine() line, err := c.term.ReadLine()
if err != nil { if err != nil {
break break
} }
@ -50,13 +68,16 @@ func (c *Client) handleShell(channel ssh.Channel) {
channel.Close() channel.Close()
} }
term.Write(term.Escape.Reset) //c.term.Write(c.term.Escape.Reset)
msg := fmt.Sprintf("%s: %s\r\n", c.Name, line) msg := fmt.Sprintf("%s: %s\r\n", c.Name, line)
c.Server.Broadcast(msg, c) c.Server.Broadcast(msg, c)
} }
} }
func (c *Client) handleChannels(channels <-chan ssh.NewChannel) { func (c *Client) handleChannels(channels <-chan ssh.NewChannel) {
prompt := fmt.Sprintf("[%s] ", c.Name)
for ch := range channels { for ch := range channels {
if t := ch.ChannelType(); t != "session" { if t := ch.ChannelType(); t != "session" {
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
@ -69,25 +90,42 @@ func (c *Client) handleChannels(channels <-chan ssh.NewChannel) {
continue continue
} }
c.term = terminal.NewTerminal(channel, prompt)
go func(in <-chan *ssh.Request) { go func(in <-chan *ssh.Request) {
defer channel.Close() defer channel.Close()
hasShell := false
for req := range in { for req := range in {
ok := false var width, height int
var ok bool
switch req.Type { switch req.Type {
case "shell": case "shell":
if len(req.Payload) == 0 { if c.term != nil && !hasShell {
go c.handleShell(channel)
ok = true ok = true
hasShell = true
} }
case "pty-req": case "pty-req":
// Setup PTY width, height, ok = parsePtyRequest(req.Payload)
ok = true if ok {
err := c.Resize(width, height)
ok = err == nil
}
case "window-change":
width, height, ok = parseWinchRequest(req.Payload)
if ok {
err := c.Resize(width, height)
ok = err == nil
}
}
if req.WantReply {
req.Reply(ok, nil)
} }
req.Reply(ok, nil)
} }
}(requests) }(requests)
go c.handleShell(channel)
// We don't care about other channels? // We don't care about other channels?
return return
} }

69
pty.go Normal file
View File

@ -0,0 +1,69 @@
// Borrowed from go.crypto circa 2011
package main
import "encoding/binary"
// parsePtyRequest parses the payload of the pty-req message and extracts the
// dimensions of the terminal. See RFC 4254, section 6.2.
func parsePtyRequest(s []byte) (width, height int, ok bool) {
_, s, ok = parseString(s)
if !ok {
return
}
width32, s, ok := parseUint32(s)
if !ok {
return
}
height32, _, ok := parseUint32(s)
width = int(width32)
height = int(height32)
if width < 1 {
ok = false
}
if height < 1 {
ok = false
}
return
}
func parseWinchRequest(s []byte) (width, height int, ok bool) {
width32, s, ok := parseUint32(s)
if !ok {
return
}
height32, s, ok := parseUint32(s)
if !ok {
return
}
width = int(width32)
height = int(height32)
if width < 1 {
ok = false
}
if height < 1 {
ok = false
}
return
}
func parseString(in []byte) (out string, rest []byte, ok bool) {
if len(in) < 4 {
return
}
length := binary.BigEndian.Uint32(in)
if uint32(len(in)) < 4+length {
return
}
out = string(in[4 : 4+length])
rest = in[4+length:]
ok = true
return
}
func parseUint32(in []byte) (uint32, []byte, bool) {
if len(in) < 4 {
return 0, nil, false
}
return binary.BigEndian.Uint32(in), in[4:], true
}

View File

@ -13,7 +13,7 @@ type Server struct {
sshConfig *ssh.ServerConfig sshConfig *ssh.ServerConfig
sshSigner *ssh.Signer sshSigner *ssh.Signer
done chan struct{} done chan struct{}
clients map[Client]struct{} clients map[*Client]struct{}
lock sync.Mutex lock sync.Mutex
} }
@ -29,6 +29,7 @@ func NewServer(privateKey []byte) (*Server, error) {
return nil, nil return nil, nil
}, },
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
// fingerprint := md5.Sum(key.Marshal()
return nil, nil return nil, nil
}, },
} }
@ -38,7 +39,7 @@ func NewServer(privateKey []byte) (*Server, error) {
sshConfig: &config, sshConfig: &config,
sshSigner: &signer, sshSigner: &signer,
done: make(chan struct{}), done: make(chan struct{}),
clients: map[Client]struct{}{}, clients: map[*Client]struct{}{},
} }
return &server, nil return &server, nil
@ -47,7 +48,7 @@ func NewServer(privateKey []byte) (*Server, error) {
func (s *Server) Broadcast(msg string, except *Client) { func (s *Server) Broadcast(msg string, except *Client) {
logger.Debugf("Broadcast to %d: %s", len(s.clients), strings.TrimRight(msg, "\r\n")) logger.Debugf("Broadcast to %d: %s", len(s.clients), strings.TrimRight(msg, "\r\n"))
for client := range s.clients { for client := range s.clients {
if except != nil && client == *except { if except != nil && client == except {
continue continue
} }
client.Msg <- msg client.Msg <- msg
@ -87,23 +88,25 @@ func (s *Server) Start(laddr string) error {
go ssh.DiscardRequests(requests) go ssh.DiscardRequests(requests)
client := NewClient(s, sshConn.User()) client := NewClient(s, sshConn, sshConn.User())
// TODO: mutex this // TODO: mutex this
s.lock.Lock() s.lock.Lock()
s.clients[*client] = struct{}{} s.clients[client] = struct{}{}
num := len(s.clients) num := len(s.clients)
s.lock.Unlock() s.lock.Unlock()
s.Broadcast(fmt.Sprintf("* Joined: %s (%d present)\r\n", client.Name, num), nil) client.sendWelcome()
s.Broadcast(fmt.Sprintf("* %s joined. (Total connected: %d)\r\n", client.Name, num), nil)
go func() { go func() {
sshConn.Wait() sshConn.Wait()
s.lock.Lock() s.lock.Lock()
delete(s.clients, *client) delete(s.clients, client)
s.lock.Unlock() s.lock.Unlock()
s.Broadcast(fmt.Sprintf("* Left: %s\r\n", client.Name), nil) s.Broadcast(fmt.Sprintf("* %s left.\r\n", client.Name), nil)
}() }()
go client.handleChannels(channels) go client.handleChannels(channels)
@ -120,5 +123,9 @@ func (s *Server) Start(laddr string) error {
} }
func (s *Server) Stop() { func (s *Server) Stop() {
for client := range s.clients {
client.Conn.Close()
}
close(s.done) close(s.done)
} }