diff --git a/cmd/ssh-chat/cmd.go b/cmd/ssh-chat/cmd.go index 0c0fb99..a4e02c3 100644 --- a/cmd/ssh-chat/cmd.go +++ b/cmd/ssh-chat/cmd.go @@ -105,7 +105,7 @@ func main() { fail(4, "Failed to listen on socket: %v\n", err) } defer s.Close() - s.RateLimit = true + s.RateLimit = sshd.NewInputLimiter fmt.Printf("Listening for connections on %v\n", s.Addr().String()) diff --git a/sshd/net.go b/sshd/net.go index 69a30da..84d6269 100644 --- a/sshd/net.go +++ b/sshd/net.go @@ -2,7 +2,6 @@ package sshd import ( "net" - "time" "github.com/shazow/rateio" "golang.org/x/crypto/ssh" @@ -12,7 +11,7 @@ import ( type SSHListener struct { net.Listener config *ssh.ServerConfig - RateLimit bool + RateLimit func() rateio.Limiter } // Make an SSH listener socket @@ -26,9 +25,9 @@ func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) { } func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) { - if l.RateLimit { + if l.RateLimit != nil { // TODO: Configurable Limiter? - conn = ReadLimitConn(conn, rateio.NewGracefulLimiter(1024*10, time.Minute*2, time.Second*3)) + conn = ReadLimitConn(conn, l.RateLimit()) } // Upgrade TCP connection to SSH connection diff --git a/sshd/ratelimit.go b/sshd/ratelimit.go index c80f0ac..c76dc46 100644 --- a/sshd/ratelimit.go +++ b/sshd/ratelimit.go @@ -3,6 +3,7 @@ package sshd import ( "io" "net" + "time" "github.com/shazow/rateio" ) @@ -23,3 +24,48 @@ func ReadLimitConn(conn net.Conn, limiter rateio.Limiter) net.Conn { Reader: rateio.NewReader(conn, limiter), } } + +// Count each read as 1 unless it exceeds some number of bytes. +type inputLimiter struct { + // TODO: Could do all kinds of fancy things here, like be more forgiving of + // connections that have been around for a while. + + Amount int + Frequency time.Duration + + remaining int + readCap int + numRead int + timeRead time.Time +} + +// NewInputLimiter returns a rateio.Limiter with sensible defaults for +// differentiating between humans typing and bots spamming. +func NewInputLimiter() rateio.Limiter { + grace := time.Second * 3 + return &inputLimiter{ + Amount: 200 * 4 * 2, // Assume fairly high typing rate + margin for copypasta of links. + Frequency: time.Minute * 2, + readCap: 128, // Allow up to 128 bytes per read (anecdotally, 1 character = 52 bytes over ssh) + numRead: -1024 * 1024, // Start with a 1mb grace + timeRead: time.Now().Add(grace), + } +} + +// Count applies 1 if n limit.Amount { + return rateio.ErrRateExceeded + } + return nil +}