terminal: Disconnect sooner and more reliably

This commit is contained in:
Andrey Petrov 2016-07-24 22:56:38 -04:00
parent 50d2be3a88
commit f0db74c874

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -13,6 +14,9 @@ import (
var keepaliveInterval = time.Second * 30 var keepaliveInterval = time.Second * 30
var keepaliveRequest = "keepalive@ssh-chat" var keepaliveRequest = "keepalive@ssh-chat"
var ErrNoSessionChannel = errors.New("no session channel")
var ErrNotSessionChannel = errors.New("terminal requires session channel")
// Connection is an interface with fields necessary to operate an sshd host. // Connection is an interface with fields necessary to operate an sshd host.
type Connection interface { type Connection interface {
PublicKey() ssh.PublicKey PublicKey() ssh.PublicKey
@ -52,37 +56,43 @@ type Terminal struct {
terminal.Terminal terminal.Terminal
Conn Connection Conn Connection
Channel ssh.Channel Channel ssh.Channel
done chan struct{}
closeOnce sync.Once
} }
// Make new terminal from a session channel // Make new terminal from a session channel
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) { func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
if ch.ChannelType() != "session" { if ch.ChannelType() != "session" {
return nil, errors.New("terminal requires session channel") return nil, ErrNotSessionChannel
} }
channel, requests, err := ch.Accept() channel, requests, err := ch.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
term := Terminal{ term := Terminal{
*terminal.NewTerminal(channel, "Connecting..."), Terminal: *terminal.NewTerminal(channel, "Connecting..."),
sshConn{conn}, Conn: sshConn{conn},
channel, Channel: channel,
done: make(chan struct{}),
} }
go term.listen(requests) go term.listen(requests)
go func() {
// FIXME: Is this necessary?
conn.Wait()
channel.Close()
}()
go func() { go func() {
for range time.Tick(keepaliveInterval) { // Keep-Alive Ticker
// TODO: Could break out earlier with a select if we want, rather than waiting for an error. ticker := time.Tick(keepaliveInterval)
_, err := channel.SendRequest(keepaliveRequest, true, nil) for {
if err != nil { select {
// Connection is gone case <-ticker:
conn.Close() _, err := channel.SendRequest(keepaliveRequest, true, nil)
if err != nil {
// Connection is gone
term.Close()
return
}
case <-term.done:
return return
} }
} }
@ -92,35 +102,29 @@ func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
} }
// Find session channel and make a Terminal from it // Find session channel and make a Terminal from it
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) { func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (*Terminal, error) {
// Make a terminal from the first session found
for ch := range channels { for ch := range channels {
if t := ch.ChannelType(); t != "session" { if t := ch.ChannelType(); t != "session" {
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
continue continue
} }
term, err = NewTerminal(conn, ch) return NewTerminal(conn, ch)
if err == nil {
break
}
} }
if term != nil { return nil, ErrNoSessionChannel
// Reject the rest.
// FIXME: Do we need this?
go func() {
for ch := range channels {
ch.Reject(ssh.Prohibited, "only one session allowed")
}
}()
}
return term, err
} }
// Close terminal and ssh connection // Close terminal and ssh connection
func (t *Terminal) Close() error { func (t *Terminal) Close() error {
return t.Conn.Close() var err error
t.closeOnce.Do(func() {
close(t.done)
t.Channel.Close()
err = t.Conn.Close()
})
return err
} }
// Negotiate terminal type and settings // Negotiate terminal type and settings