mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-13 15:47:17 +03:00
terminal: Disconnect sooner and more reliably
This commit is contained in:
parent
50d2be3a88
commit
f0db74c874
@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@ -13,6 +14,9 @@ import (
|
||||
var keepaliveInterval = time.Second * 30
|
||||
var keepaliveRequest = "keepalive@ssh-chat"
|
||||
|
||||
var ErrNoSessionChannel = errors.New("no session channel")
|
||||
var ErrNotSessionChannel = errors.New("terminal requires session channel")
|
||||
|
||||
// Connection is an interface with fields necessary to operate an sshd host.
|
||||
type Connection interface {
|
||||
PublicKey() ssh.PublicKey
|
||||
@ -52,37 +56,43 @@ type Terminal struct {
|
||||
terminal.Terminal
|
||||
Conn Connection
|
||||
Channel ssh.Channel
|
||||
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// Make new terminal from a session channel
|
||||
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
|
||||
if ch.ChannelType() != "session" {
|
||||
return nil, errors.New("terminal requires session channel")
|
||||
return nil, ErrNotSessionChannel
|
||||
}
|
||||
channel, requests, err := ch.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
term := Terminal{
|
||||
*terminal.NewTerminal(channel, "Connecting..."),
|
||||
sshConn{conn},
|
||||
channel,
|
||||
Terminal: *terminal.NewTerminal(channel, "Connecting..."),
|
||||
Conn: sshConn{conn},
|
||||
Channel: channel,
|
||||
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go term.listen(requests)
|
||||
go func() {
|
||||
// FIXME: Is this necessary?
|
||||
conn.Wait()
|
||||
channel.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for range time.Tick(keepaliveInterval) {
|
||||
// TODO: Could break out earlier with a select if we want, rather than waiting for an error.
|
||||
_, err := channel.SendRequest(keepaliveRequest, true, nil)
|
||||
if err != nil {
|
||||
// Connection is gone
|
||||
conn.Close()
|
||||
// Keep-Alive Ticker
|
||||
ticker := time.Tick(keepaliveInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ticker:
|
||||
_, err := channel.SendRequest(keepaliveRequest, true, nil)
|
||||
if err != nil {
|
||||
// Connection is gone
|
||||
term.Close()
|
||||
return
|
||||
}
|
||||
case <-term.done:
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -92,35 +102,29 @@ func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
|
||||
}
|
||||
|
||||
// Find session channel and make a Terminal from it
|
||||
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
|
||||
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (*Terminal, error) {
|
||||
// Make a terminal from the first session found
|
||||
for ch := range channels {
|
||||
if t := ch.ChannelType(); t != "session" {
|
||||
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
||||
continue
|
||||
}
|
||||
|
||||
term, err = NewTerminal(conn, ch)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
return NewTerminal(conn, ch)
|
||||
}
|
||||
|
||||
if term != nil {
|
||||
// Reject the rest.
|
||||
// FIXME: Do we need this?
|
||||
go func() {
|
||||
for ch := range channels {
|
||||
ch.Reject(ssh.Prohibited, "only one session allowed")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return term, err
|
||||
return nil, ErrNoSessionChannel
|
||||
}
|
||||
|
||||
// Close terminal and ssh connection
|
||||
func (t *Terminal) Close() error {
|
||||
return t.Conn.Close()
|
||||
var err error
|
||||
t.closeOnce.Do(func() {
|
||||
close(t.done)
|
||||
t.Channel.Close()
|
||||
err = t.Conn.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Negotiate terminal type and settings
|
||||
|
Loading…
x
Reference in New Issue
Block a user