diff --git a/client.go b/client.go index f67288d..6f4acca 100644 --- a/client.go +++ b/client.go @@ -30,7 +30,14 @@ func NewClient(server *Server, name string) *Client { func (c *Client) handleShell(channel ssh.Channel) { defer channel.Close() - term := terminal.NewTerminal(channel, "") + prompt := fmt.Sprintf("%s> ", c.Name) + term := terminal.NewTerminal(channel, prompt) + + go func() { + for msg := range c.Msg { + term.Write([]byte(msg)) + } + }() for { line, err := term.ReadLine() @@ -43,8 +50,9 @@ func (c *Client) handleShell(channel ssh.Channel) { channel.Close() } + term.Write(term.Escape.Reset) msg := fmt.Sprintf("%s: %s\r\n", c.Name, line) - c.Server.Broadcast(msg) + c.Server.Broadcast(msg, c) } } @@ -80,12 +88,6 @@ func (c *Client) handleChannels(channels <-chan ssh.NewChannel) { go c.handleShell(channel) - go func() { - for msg := range c.Msg { - channel.Write([]byte(msg)) - } - }() - // We don't care about other channels? return } diff --git a/server.go b/server.go index da5a45d..ae4644e 100644 --- a/server.go +++ b/server.go @@ -43,9 +43,12 @@ func NewServer(privateKey []byte) (*Server, error) { return &server, nil } -func (s *Server) Broadcast(msg string) { +func (s *Server) Broadcast(msg string, except *Client) { logger.Debugf("Broadcast to %d: %s", len(s.clients), strings.TrimRight(msg, "\r\n")) for _, client := range s.clients { + if except != nil && client == *except { + continue + } client.Msg <- msg } } @@ -88,9 +91,10 @@ func (s *Server) Start(laddr string) error { s.lock.Lock() s.clients = append(s.clients, *client) + num := len(s.clients) s.lock.Unlock() - s.Broadcast(fmt.Sprintf("* Joined: %s", client.Name)) + s.Broadcast(fmt.Sprintf("* Joined: %s (%d present)\r\n", client.Name, num), nil) go client.handleChannels(channels) }()