From d8d5deac1c6b04f169ba6ddc31738142c925ec9e Mon Sep 17 00:00:00 2001
From: Andrey Petrov <andrey.petrov@shazow.net>
Date: Sat, 10 Jan 2015 12:44:06 -0800
Subject: [PATCH] Use authorized_keys-style public keys rather than
 fingerprints.

Tests for whitelisting.
---
 auth.go             |  61 +++++++++++++++++---------
 auth_test.go        |  62 +++++++++++++++++++++++++++
 cmd.go              | 101 ++++++++++++++++++++------------------------
 host_test.go        |  47 ++++++++++++++++++++-
 key.go              |  49 +++++++++++++++++++++
 sshd/auth.go        |  31 +++++++++-----
 sshd/client.go      |   4 +-
 sshd/client_test.go |  44 +++++++++++++++++++
 sshd/net.go         |   1 +
 sshd/net_test.go    |   4 +-
 10 files changed, 310 insertions(+), 94 deletions(-)
 create mode 100644 auth_test.go
 create mode 100644 key.go
 create mode 100644 sshd/client_test.go

diff --git a/auth.go b/auth.go
index 6357564..4d6de86 100644
--- a/auth.go
+++ b/auth.go
@@ -5,24 +5,39 @@ import (
 	"sync"
 
 	"github.com/shazow/ssh-chat/sshd"
+	"golang.org/x/crypto/ssh"
 )
 
+// The error returned a key is checked that is not whitelisted, with whitelisting required.
+var ErrNotWhitelisted = errors.New("not whitelisted")
+
+// The error returned a key is checked that is banned.
+var ErrBanned = errors.New("banned")
+
+// AuthKey is the type that our lookups are keyed against.
+type AuthKey string
+
+// NewAuthKey returns an AuthKey from an ssh.PublicKey.
+func NewAuthKey(key ssh.PublicKey) AuthKey {
+	// FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
+	return AuthKey(string(key.Marshal()))
+}
+
 // Auth stores fingerprint lookups
 type Auth struct {
-	whitelist map[string]struct{}
-	banned    map[string]struct{}
-	ops       map[string]struct{}
-
 	sshd.Auth
 	sync.RWMutex
+	whitelist map[AuthKey]struct{}
+	banned    map[AuthKey]struct{}
+	ops       map[AuthKey]struct{}
 }
 
 // NewAuth creates a new default Auth.
-func NewAuth() Auth {
-	return Auth{
-		whitelist: make(map[string]struct{}),
-		banned:    make(map[string]struct{}),
-		ops:       make(map[string]struct{}),
+func NewAuth() *Auth {
+	return &Auth{
+		whitelist: make(map[AuthKey]struct{}),
+		banned:    make(map[AuthKey]struct{}),
+		ops:       make(map[AuthKey]struct{}),
 	}
 }
 
@@ -35,43 +50,49 @@ func (a Auth) AllowAnonymous() bool {
 }
 
 // Check determines if a pubkey fingerprint is permitted.
-func (a Auth) Check(fingerprint string) (bool, error) {
+func (a Auth) Check(key ssh.PublicKey) (bool, error) {
+	authkey := NewAuthKey(key)
+
 	a.RLock()
 	defer a.RUnlock()
 
 	if len(a.whitelist) > 0 {
 		// Only check whitelist if there is something in it, otherwise it's disabled.
-		_, whitelisted := a.whitelist[fingerprint]
+
+		_, whitelisted := a.whitelist[authkey]
 		if !whitelisted {
-			return false, errors.New("not whitelisted")
+			return false, ErrNotWhitelisted
 		}
 	}
 
-	_, banned := a.banned[fingerprint]
+	_, banned := a.banned[authkey]
 	if banned {
-		return false, errors.New("banned")
+		return false, ErrBanned
 	}
 
 	return true, nil
 }
 
 // Op will set a fingerprint as a known operator.
-func (a *Auth) Op(fingerprint string) {
+func (a *Auth) Op(key ssh.PublicKey) {
+	authkey := NewAuthKey(key)
 	a.Lock()
-	a.ops[fingerprint] = struct{}{}
+	a.ops[authkey] = struct{}{}
 	a.Unlock()
 }
 
 // Whitelist will set a fingerprint as a whitelisted user.
-func (a *Auth) Whitelist(fingerprint string) {
+func (a *Auth) Whitelist(key ssh.PublicKey) {
+	authkey := NewAuthKey(key)
 	a.Lock()
-	a.whitelist[fingerprint] = struct{}{}
+	a.whitelist[authkey] = struct{}{}
 	a.Unlock()
 }
 
 // Ban will set a fingerprint as banned.
-func (a *Auth) Ban(fingerprint string) {
+func (a *Auth) Ban(key ssh.PublicKey) {
+	authkey := NewAuthKey(key)
 	a.Lock()
-	a.banned[fingerprint] = struct{}{}
+	a.banned[authkey] = struct{}{}
 	a.Unlock()
 }
diff --git a/auth_test.go b/auth_test.go
new file mode 100644
index 0000000..cb7e521
--- /dev/null
+++ b/auth_test.go
@@ -0,0 +1,62 @@
+package main
+
+import (
+	"crypto/rand"
+	"crypto/rsa"
+	"testing"
+
+	"golang.org/x/crypto/ssh"
+)
+
+func NewRandomPublicKey(bits int) (ssh.PublicKey, error) {
+	key, err := rsa.GenerateKey(rand.Reader, bits)
+	if err != nil {
+		return nil, err
+	}
+
+	return ssh.NewPublicKey(key.Public())
+}
+
+func ClonePublicKey(key ssh.PublicKey) (ssh.PublicKey, error) {
+	return ssh.ParsePublicKey(key.Marshal())
+}
+
+func TestAuthWhitelist(t *testing.T) {
+	key, err := NewRandomPublicKey(512)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	auth := NewAuth()
+	ok, err := auth.Check(key)
+	if !ok || err != nil {
+		t.Error("Failed to permit in default state:", err)
+	}
+
+	auth.Whitelist(key)
+
+	key_clone, err := ClonePublicKey(key)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(key_clone.Marshal()) != string(key.Marshal()) {
+		t.Error("Clone key does not match.")
+	}
+
+	ok, err = auth.Check(key_clone)
+	if !ok || err != nil {
+		t.Error("Failed to permit whitelisted:", err)
+	}
+
+	key2, err := NewRandomPublicKey(512)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ok, err = auth.Check(key2)
+	if ok || err == nil {
+		t.Error("Failed to restrict not whitelisted:", err)
+	}
+
+}
diff --git a/cmd.go b/cmd.go
index 9f5104b..e244361 100644
--- a/cmd.go
+++ b/cmd.go
@@ -2,8 +2,6 @@ package main
 
 import (
 	"bufio"
-	"crypto/x509"
-	"encoding/pem"
 	"fmt"
 	"io/ioutil"
 	"net/http"
@@ -16,7 +14,6 @@ import (
 	"github.com/alexcesaro/log/golog"
 	"github.com/jessevdk/go-flags"
 	"golang.org/x/crypto/ssh"
-	"golang.org/x/crypto/ssh/terminal"
 
 	"github.com/shazow/ssh-chat/chat"
 	"github.com/shazow/ssh-chat/sshd"
@@ -25,13 +22,13 @@ import _ "net/http/pprof"
 
 // Options contains the flag options
 type Options struct {
-	Verbose   []bool   `short:"v" long:"verbose" description:"Show verbose logging."`
-	Identity  string   `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"`
-	Bind      string   `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:2022"`
-	Admin     []string `long:"admin" description:"Fingerprint of pubkey to mark as admin."`
-	Whitelist string   `long:"whitelist" description:"Optional file of pubkey fingerprints who are allowed to connect."`
-	Motd      string   `long:"motd" description:"Optional Message of the Day file."`
-	Pprof     int      `long:"pprof" description:"Enable pprof http server for profiling."`
+	Verbose   []bool `short:"v" long:"verbose" description:"Show verbose logging."`
+	Identity  string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"`
+	Bind      string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:2022"`
+	Admin     string `long:"admin" description:"File of public keys who are admins."`
+	Whitelist string `long:"whitelist" description:"Optional file of public keys who are allowed to connect."`
+	Motd      string `long:"motd" description:"Optional Message of the Day file."`
+	Pprof     int    `long:"pprof" description:"Enable pprof http server for profiling."`
 }
 
 var logLevels = []log.Level{
@@ -83,7 +80,7 @@ func main() {
 		}
 	}
 
-	privateKey, err := readPrivateKey(privateKeyPath)
+	privateKey, err := ReadPrivateKey(privateKeyPath)
 	if err != nil {
 		logger.Errorf("Couldn't read private key: %v", err)
 		os.Exit(2)
@@ -109,25 +106,35 @@ func main() {
 	fmt.Printf("Listening for connections on %v\n", s.Addr().String())
 
 	host := NewHost(s)
-	host.auth = &auth
+	host.auth = auth
 	host.theme = &chat.Themes[0]
 
-	for _, fingerprint := range options.Admin {
-		auth.Op(fingerprint)
+	err = fromFile(options.Admin, func(line []byte) error {
+		key, _, _, _, err := ssh.ParseAuthorizedKey(line)
+		if err != nil {
+			return err
+		}
+		auth.Op(key)
+		logger.Debugf("Added admin: %s", line)
+		return nil
+	})
+	if err != nil {
+		logger.Errorf("Failed to load admins: %v", err)
+		os.Exit(5)
 	}
 
-	if options.Whitelist != "" {
-		file, err := os.Open(options.Whitelist)
+	err = fromFile(options.Whitelist, func(line []byte) error {
+		key, _, _, _, err := ssh.ParseAuthorizedKey(line)
 		if err != nil {
-			logger.Errorf("Could not open whitelist file")
-			return
-		}
-		defer file.Close()
-
-		scanner := bufio.NewScanner(file)
-		for scanner.Scan() {
-			auth.Whitelist(scanner.Text())
+			return err
 		}
+		auth.Whitelist(key)
+		logger.Debugf("Whitelisted: %s", line)
+		return nil
+	})
+	if err != nil {
+		logger.Errorf("Failed to load whitelist: %v", err)
+		os.Exit(5)
 	}
 
 	if options.Motd != "" {
@@ -154,42 +161,24 @@ func main() {
 	os.Exit(0)
 }
 
-// readPrivateKey attempts to read your private key and possibly decrypt it if it
-// requires a passphrase.
-// This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`),
-// is not set.
-func readPrivateKey(privateKeyPath string) ([]byte, error) {
-	privateKey, err := ioutil.ReadFile(privateKeyPath)
+func fromFile(path string, handler func(line []byte) error) error {
+	if path == "" {
+		// Skip
+		return nil
+	}
+
+	file, err := os.Open(path)
 	if err != nil {
-		return nil, fmt.Errorf("failed to load identity: %v", err)
+		return err
 	}
+	defer file.Close()
 
-	block, rest := pem.Decode(privateKey)
-	if len(rest) > 0 {
-		return nil, fmt.Errorf("extra data when decoding private key")
-	}
-	if !x509.IsEncryptedPEMBlock(block) {
-		return privateKey, nil
-	}
-
-	passphrase := []byte(os.Getenv("IDENTITY_PASSPHRASE"))
-	if len(passphrase) == 0 {
-		fmt.Printf("Enter passphrase: ")
-		passphrase, err = terminal.ReadPassword(int(os.Stdin.Fd()))
+	scanner := bufio.NewScanner(file)
+	for scanner.Scan() {
+		err := handler(scanner.Bytes())
 		if err != nil {
-			return nil, fmt.Errorf("couldn't read passphrase: %v", err)
+			return err
 		}
-		fmt.Println()
 	}
-	der, err := x509.DecryptPEMBlock(block, passphrase)
-	if err != nil {
-		return nil, fmt.Errorf("decrypt failed: %v", err)
-	}
-
-	privateKey = pem.EncodeToMemory(&pem.Block{
-		Type:  block.Type,
-		Bytes: der,
-	})
-
-	return privateKey, nil
+	return nil
 }
diff --git a/host_test.go b/host_test.go
index d86c353..abee9f3 100644
--- a/host_test.go
+++ b/host_test.go
@@ -2,12 +2,15 @@ package main
 
 import (
 	"bufio"
+	"crypto/rand"
+	"crypto/rsa"
 	"io"
 	"strings"
 	"testing"
 
 	"github.com/shazow/ssh-chat/chat"
 	"github.com/shazow/ssh-chat/sshd"
+	"golang.org/x/crypto/ssh"
 )
 
 func stripPrompt(s string) string {
@@ -39,7 +42,7 @@ func TestHostGetPrompt(t *testing.T) {
 }
 
 func TestHostNameCollision(t *testing.T) {
-	key, err := sshd.NewRandomKey(512)
+	key, err := sshd.NewRandomSigner(512)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -50,6 +53,7 @@ func TestHostNameCollision(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer s.Close()
 	host := NewHost(s)
 	go host.Serve()
 
@@ -110,5 +114,44 @@ func TestHostNameCollision(t *testing.T) {
 	}
 
 	<-done
-	s.Close()
+}
+
+func TestHostWhitelist(t *testing.T) {
+	key, err := sshd.NewRandomSigner(512)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	auth := NewAuth()
+	config := sshd.MakeAuth(auth)
+	config.AddHostKey(key)
+
+	s, err := sshd.ListenSSH(":0", config)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer s.Close()
+	host := NewHost(s)
+	host.auth = auth
+	go host.Serve()
+
+	target := s.Addr().String()
+
+	err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {})
+	if err != nil {
+		t.Error(err)
+	}
+
+	clientkey, err := rsa.GenerateKey(rand.Reader, 512)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
+	auth.Whitelist(clientpubkey)
+
+	err = sshd.NewClientSession(target, "foo", func(r io.Reader, w io.WriteCloser) {})
+	if err == nil {
+		t.Error("Failed to block unwhitelisted connection.")
+	}
 }
diff --git a/key.go b/key.go
new file mode 100644
index 0000000..0135e1b
--- /dev/null
+++ b/key.go
@@ -0,0 +1,49 @@
+package main
+
+import (
+	"crypto/x509"
+	"encoding/pem"
+	"fmt"
+	"io/ioutil"
+	"os"
+
+	"code.google.com/p/gopass"
+)
+
+// ReadPrivateKey attempts to read your private key and possibly decrypt it if it
+// requires a passphrase.
+// This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`),
+// is not set.
+func ReadPrivateKey(path string) ([]byte, error) {
+	privateKey, err := ioutil.ReadFile(path)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load identity: %v", err)
+	}
+
+	block, rest := pem.Decode(privateKey)
+	if len(rest) > 0 {
+		return nil, fmt.Errorf("extra data when decoding private key")
+	}
+	if !x509.IsEncryptedPEMBlock(block) {
+		return privateKey, nil
+	}
+
+	passphrase := os.Getenv("IDENTITY_PASSPHRASE")
+	if passphrase == "" {
+		passphrase, err = gopass.GetPass("Enter passphrase: ")
+		if err != nil {
+			return nil, fmt.Errorf("couldn't read passphrase: %v", err)
+		}
+	}
+	der, err := x509.DecryptPEMBlock(block, []byte(passphrase))
+	if err != nil {
+		return nil, fmt.Errorf("decrypt failed: %v", err)
+	}
+
+	privateKey = pem.EncodeToMemory(&pem.Block{
+		Type:  block.Type,
+		Bytes: der,
+	})
+
+	return privateKey, nil
+}
diff --git a/sshd/auth.go b/sshd/auth.go
index 90134e5..339d158 100644
--- a/sshd/auth.go
+++ b/sshd/auth.go
@@ -1,30 +1,34 @@
 package sshd
 
 import (
-	"crypto/sha1"
+	"crypto/sha256"
+	"encoding/base64"
 	"errors"
-	"fmt"
-	"strings"
 
 	"golang.org/x/crypto/ssh"
 )
 
+// Auth is used to authenticate connections based on public keys.
 type Auth interface {
+	// Whether to allow connections without a public key.
 	AllowAnonymous() bool
-	Check(string) (bool, error)
+	// Given public key, return if the connection should be permitted.
+	Check(ssh.PublicKey) (bool, error)
 }
 
+// MakeAuth makes an ssh.ServerConfig which performs authentication against an Auth implementation.
 func MakeAuth(auth Auth) *ssh.ServerConfig {
 	config := ssh.ServerConfig{
 		NoClientAuth: false,
 		// Auth-related things should be constant-time to avoid timing attacks.
 		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
-			fingerprint := Fingerprint(key)
-			ok, err := auth.Check(fingerprint)
+			ok, err := auth.Check(key)
 			if !ok {
 				return nil, err
 			}
-			perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
+			perm := &ssh.Permissions{Extensions: map[string]string{
+				"pubkey": string(ssh.MarshalAuthorizedKey(key)),
+			}}
 			return perm, nil
 		},
 		KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
@@ -38,12 +42,16 @@ func MakeAuth(auth Auth) *ssh.ServerConfig {
 	return &config
 }
 
+// MakeNoAuth makes a simple ssh.ServerConfig which allows all connections.
+// Primarily used for testing.
 func MakeNoAuth() *ssh.ServerConfig {
 	config := ssh.ServerConfig{
 		NoClientAuth: false,
 		// Auth-related things should be constant-time to avoid timing attacks.
 		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
-			perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
+			perm := &ssh.Permissions{Extensions: map[string]string{
+				"pubkey": string(ssh.MarshalAuthorizedKey(key)),
+			}}
 			return perm, nil
 		},
 		KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
@@ -54,8 +62,9 @@ func MakeNoAuth() *ssh.ServerConfig {
 	return &config
 }
 
+// Fingerprint performs a SHA256 BASE64 fingerprint of the PublicKey, similar to OpenSSH.
+// See: https://anongit.mindrot.org/openssh.git/commit/?id=56d1c83cdd1ac
 func Fingerprint(k ssh.PublicKey) string {
-	hash := sha1.Sum(k.Marshal())
-	r := fmt.Sprintf("% x", hash)
-	return strings.Replace(r, " ", ":", -1)
+	hash := sha256.Sum256(k.Marshal())
+	return base64.StdEncoding.EncodeToString(hash[:])
 }
diff --git a/sshd/client.go b/sshd/client.go
index 60dab6e..9a01065 100644
--- a/sshd/client.go
+++ b/sshd/client.go
@@ -8,8 +8,8 @@ import (
 	"golang.org/x/crypto/ssh"
 )
 
-// NewRandomKey generates a random key of a desired bit length.
-func NewRandomKey(bits int) (ssh.Signer, error) {
+// NewRandomSigner generates a random key of a desired bit length.
+func NewRandomSigner(bits int) (ssh.Signer, error) {
 	key, err := rsa.GenerateKey(rand.Reader, bits)
 	if err != nil {
 		return nil, err
diff --git a/sshd/client_test.go b/sshd/client_test.go
new file mode 100644
index 0000000..2fd109f
--- /dev/null
+++ b/sshd/client_test.go
@@ -0,0 +1,44 @@
+package sshd
+
+import (
+	"errors"
+	"testing"
+
+	"golang.org/x/crypto/ssh"
+)
+
+var errRejectAuth = errors.New("not welcome here")
+
+type RejectAuth struct{}
+
+func (a RejectAuth) AllowAnonymous() bool {
+	return false
+}
+func (a RejectAuth) Check(ssh.PublicKey) (bool, error) {
+	return false, errRejectAuth
+}
+
+func consume(ch <-chan *Terminal) {
+	for range ch {}
+}
+
+func TestClientReject(t *testing.T) {
+	signer, err := NewRandomSigner(512)
+	config := MakeAuth(RejectAuth{})
+	config.AddHostKey(signer)
+
+	s, err := ListenSSH(":0", config)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer s.Close()
+
+	go consume(s.ServeTerminal())
+
+	conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo"))
+	if err == nil {
+		defer conn.Close()
+		t.Error("Failed to reject conncetion")
+	}
+	t.Log(err)
+}
diff --git a/sshd/net.go b/sshd/net.go
index 6a30976..bb94432 100644
--- a/sshd/net.go
+++ b/sshd/net.go
@@ -29,6 +29,7 @@ func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) {
 		return nil, err
 	}
 
+	// FIXME: Disconnect if too many faulty requests? (Avoid DoS.)
 	go ssh.DiscardRequests(requests)
 	return NewSession(sshConn, channels)
 }
diff --git a/sshd/net_test.go b/sshd/net_test.go
index 8321b30..724ce77 100644
--- a/sshd/net_test.go
+++ b/sshd/net_test.go
@@ -6,8 +6,6 @@ import (
 	"testing"
 )
 
-// TODO: Move some of these into their own package?
-
 func TestServerInit(t *testing.T) {
 	config := MakeNoAuth()
 	s, err := ListenSSH(":badport", config)
@@ -27,7 +25,7 @@ func TestServerInit(t *testing.T) {
 }
 
 func TestServeTerminals(t *testing.T) {
-	signer, err := NewRandomKey(512)
+	signer, err := NewRandomSigner(512)
 	config := MakeNoAuth()
 	config.AddHostKey(signer)