diff --git a/Makefile b/Makefile index 43a897a..5136b14 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ $(KEY): ssh-keygen -f $(KEY) -P '' run: $(BINARY) $(KEY) - ./$(BINARY) -i $(KEY) --bind ":$(PORT)" -vv + ./$(BINARY) -i $(KEY) --bind "127.0.0.1:$(PORT)" -vv debug: $(BINARY) $(KEY) ./$(BINARY) --pprof 6060 -i $(KEY) --bind ":$(PORT)" -vv diff --git a/client.go b/client.go index 996171a..bf09c0d 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package sshchat import ( + "io" "sync" "time" @@ -9,22 +10,17 @@ import ( "github.com/shazow/ssh-chat/sshd" ) +type multiTerm interface { + Connections() []sshd.Connection + Add(*sshd.Terminal) + ReadLine() (string, error) + io.WriteCloser +} + type client struct { Member sync.Mutex - conns []sshd.Connection -} - -func (cl *client) Connections() []sshd.Connection { - return cl.conns -} - -func (cl *client) Close() error { - // TODO: Stack errors? - for _, conn := range cl.conns { - conn.Close() - } - return nil + multiTerm } type Member interface { diff --git a/host.go b/host.go index 9c43599..912b794 100644 --- a/host.go +++ b/host.go @@ -32,10 +32,9 @@ type Host struct { // Default theme theme message.Theme - mu sync.Mutex - motd string - count int - clients map[chat.Member][]client + mu sync.Mutex + motd string + count int } // NewHost creates a Host on top of an existing listener. @@ -46,7 +45,6 @@ func NewHost(listener *sshd.SSHListener, auth *Auth) *Host { listener: listener, commands: chat.Commands{}, auth: auth, - clients: map[chat.Member][]client{}, } // Make our own commands registry instance. @@ -72,15 +70,30 @@ func (h *Host) SetMotd(motd string) { h.mu.Unlock() } +var globalUser *client + // Connect a specific Terminal to this host and its room. -func (h *Host) Connect(term *sshd.Terminal) { - requestedName := term.Conn.Name() - screen := message.BufferedScreen(requestedName, term) - user := &client{ - Member: screen, - conns: []sshd.Connection{term.Conn}, +func (h *Host) Connect(t *sshd.Terminal) { + // XXX: Hack to test multiple users per key + if globalUser != nil { + globalUser.Add(t) + return } + conn := t.Conn + remoteAddr := conn.RemoteAddr() + requestedName := conn.Name() + term := sshd.MultiTerm(t) + screen := message.BufferedScreen(requestedName, term) + + user := &client{ + Member: screen, + multiTerm: term, + } + defer user.Close() + // XXX: Hack to test multiple users per key + globalUser = user + h.mu.Lock() motd := h.motd count := h.count @@ -91,10 +104,6 @@ func (h *Host) Connect(term *sshd.Terminal) { cfg.Theme = &h.theme user.SetConfig(cfg) - // Close term once user is closed. - defer screen.Close() - defer term.Close() - go screen.Consume() // Send MOTD @@ -109,7 +118,7 @@ func (h *Host) Connect(term *sshd.Terminal) { member, err = h.Join(user) } if err != nil { - logger.Errorf("[%s] Failed to join: %s", term.Conn.RemoteAddr(), err) + logger.Errorf("[%s] Failed to join: %s", conn.RemoteAddr(), err) return } @@ -118,27 +127,29 @@ func (h *Host) Connect(term *sshd.Terminal) { term.AutoCompleteCallback = h.AutoCompleteFunction(user) user.SetHighlight(user.Name()) + // XXX: Mark multiterm as ready? + // Should the user be op'd on join? - if key := term.Conn.PublicKey(); key != nil { + if key := conn.PublicKey(); key != nil { authItem, err := h.auth.ops.Get(newAuthKey(key)) if err == nil { err = h.Room.Ops.Add(set.Rename(authItem, member.ID())) } } if err != nil { - logger.Warningf("[%s] Failed to op: %s", term.Conn.RemoteAddr(), err) + logger.Warningf("[%s] Failed to op: %s", remoteAddr, err) } ratelimit := rateio.NewSimpleLimiter(3, time.Second*3) - logger.Debugf("[%s] Joined: %s", term.Conn.RemoteAddr(), user.Name()) + logger.Debugf("[%s] Joined: %s", remoteAddr, user.Name()) for { - line, err := term.ReadLine() + line, err := user.ReadLine() if err == io.EOF { // Closed break } else if err != nil { - logger.Errorf("[%s] Terminal reading error: %s", term.Conn.RemoteAddr(), err) + logger.Errorf("[%s] Terminal reading error: %s", remoteAddr, err) break } @@ -175,10 +186,10 @@ func (h *Host) Connect(term *sshd.Terminal) { err = h.Leave(user) if err != nil { - logger.Errorf("[%s] Failed to leave: %s", term.Conn.RemoteAddr(), err) + logger.Errorf("[%s] Failed to leave: %s", remoteAddr, err) return } - logger.Debugf("[%s] Leaving: %s", term.Conn.RemoteAddr(), user.Name()) + logger.Debugf("[%s] Leaving: %s", remoteAddr, user.Name()) } // Serve our chat room onto the listener diff --git a/sshd/multiterm.go b/sshd/multiterm.go new file mode 100644 index 0000000..7d020f7 --- /dev/null +++ b/sshd/multiterm.go @@ -0,0 +1,131 @@ +package sshd + +import ( + "fmt" + "io" + "sync" +) + +type termLine struct { + Term *Terminal + Line string + Err error +} + +func MultiTerm(terms ...*Terminal) *multiTerm { + mt := &multiTerm{ + lines: make(chan termLine), + } + for _, t := range terms { + mt.Add(t) + } + return mt +} + +type multiTerm struct { + AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool) + + mu sync.Mutex + terms []*Terminal + add chan *Terminal + lines chan termLine + prompt string +} + +func (mt *multiTerm) SetPrompt(prompt string) { + mt.mu.Lock() + mt.prompt = prompt + mt.mu.Unlock() + for _, t := range mt.Terminals() { + t.SetPrompt(prompt) + } +} + +func (mt *multiTerm) Connections() []Connection { + terms := mt.Terminals() + conns := make([]Connection, len(terms)) + for _, term := range terms { + conns = append(conns, term.Conn) + } + return conns +} + +func (mt *multiTerm) Terminals() []*Terminal { + mt.mu.Lock() + terms := mt.terms + mt.mu.Unlock() + return terms +} + +func (mt *multiTerm) Add(t *Terminal) { + mt.mu.Lock() + mt.terms = append(mt.terms, t) + prompt := mt.prompt + mt.mu.Unlock() + t.AutoCompleteCallback = mt.AutoCompleteCallback + t.SetPrompt(prompt) + + go func() { + var line termLine + for { + line.Line, line.Err = t.ReadLine() + line.Term = t + mt.lines <- line + if line.Err != nil { + // FIXME: Should we not abort on all errors? + break + } + } + }() +} + +func (mt *multiTerm) ReadLine() (string, error) { + line := <-mt.lines + mt.mu.Lock() + prompt := mt.prompt + mt.mu.Unlock() + if line.Err == nil { + // Write the line to all the other terminals + for _, w := range mt.Terminals() { + if w == line.Term { + continue + } + // XXX: This is super hacky and frankly wrong. + w.Write([]byte(prompt + line.Line + "\n\r")) + // TODO: Remove terminal if it fails to write? + } + } + return line.Line, line.Err +} + +func (mt *multiTerm) Write(p []byte) (n int, err error) { + for _, w := range mt.Terminals() { + n, err = w.Write(p) + if err != nil { + return + } + if n != len(p) { + err = io.ErrShortWrite + return + } + } + return len(p), nil +} + +func (mt *multiTerm) Close() error { + mt.mu.Lock() + var errs []error + for _, t := range mt.terms { + if err := t.Close(); err != nil { + errs = append(errs, err) + } + } + mt.terms = nil + mt.mu.Unlock() + + if len(errs) == 0 { + return nil + } + + return fmt.Errorf("%d errors: %q", len(errs), errs) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..5bac432 --- /dev/null +++ b/util.go @@ -0,0 +1,12 @@ +package sshchat + +import "fmt" + +type multiError []error + +func (err multiError) Error() string { + if len(err) == 0 { + return "" + } + return fmt.Sprintf("%d errors: %q", len(err), err) +}