diff --git a/sshd/terminal.go b/sshd/terminal.go index e71749b..8d4b725 100644 --- a/sshd/terminal.go +++ b/sshd/terminal.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net" + "sync" "time" "golang.org/x/crypto/ssh" @@ -13,6 +14,9 @@ import ( var keepaliveInterval = time.Second * 30 var keepaliveRequest = "keepalive@ssh-chat" +var ErrNoSessionChannel = errors.New("no session channel") +var ErrNotSessionChannel = errors.New("terminal requires session channel") + // Connection is an interface with fields necessary to operate an sshd host. type Connection interface { PublicKey() ssh.PublicKey @@ -52,37 +56,43 @@ type Terminal struct { terminal.Terminal Conn Connection Channel ssh.Channel + + done chan struct{} + closeOnce sync.Once } // Make new terminal from a session channel func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) { if ch.ChannelType() != "session" { - return nil, errors.New("terminal requires session channel") + return nil, ErrNotSessionChannel } channel, requests, err := ch.Accept() if err != nil { return nil, err } term := Terminal{ - *terminal.NewTerminal(channel, "Connecting..."), - sshConn{conn}, - channel, + Terminal: *terminal.NewTerminal(channel, "Connecting..."), + Conn: sshConn{conn}, + Channel: channel, + + done: make(chan struct{}), } go term.listen(requests) - go func() { - // FIXME: Is this necessary? - conn.Wait() - channel.Close() - }() go func() { - for range time.Tick(keepaliveInterval) { - // TODO: Could break out earlier with a select if we want, rather than waiting for an error. - _, err := channel.SendRequest(keepaliveRequest, true, nil) - if err != nil { - // Connection is gone - conn.Close() + // Keep-Alive Ticker + ticker := time.Tick(keepaliveInterval) + for { + select { + case <-ticker: + _, err := channel.SendRequest(keepaliveRequest, true, nil) + if err != nil { + // Connection is gone + term.Close() + return + } + case <-term.done: return } } @@ -92,35 +102,29 @@ func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) { } // Find session channel and make a Terminal from it -func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) { +func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (*Terminal, error) { + // Make a terminal from the first session found for ch := range channels { if t := ch.ChannelType(); t != "session" { ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) continue } - term, err = NewTerminal(conn, ch) - if err == nil { - break - } + return NewTerminal(conn, ch) } - if term != nil { - // Reject the rest. - // FIXME: Do we need this? - go func() { - for ch := range channels { - ch.Reject(ssh.Prohibited, "only one session allowed") - } - }() - } - - return term, err + return nil, ErrNoSessionChannel } // Close terminal and ssh connection func (t *Terminal) Close() error { - return t.Conn.Close() + var err error + t.closeOnce.Do(func() { + close(t.done) + t.Channel.Close() + err = t.Conn.Close() + }) + return err } // Negotiate terminal type and settings