diff --git a/auth.go b/auth.go index 27a7077..6357564 100644 --- a/auth.go +++ b/auth.go @@ -17,6 +17,15 @@ type Auth struct { sync.RWMutex } +// NewAuth creates a new default Auth. +func NewAuth() Auth { + return Auth{ + whitelist: make(map[string]struct{}), + banned: make(map[string]struct{}), + ops: make(map[string]struct{}), + } +} + // AllowAnonymous determines if anonymous users are permitted. func (a Auth) AllowAnonymous() bool { a.RLock() diff --git a/cmd.go b/cmd.go index e7bfc68..4e78b77 100644 --- a/cmd.go +++ b/cmd.go @@ -2,6 +2,8 @@ package main import ( "bufio" + "crypto/x509" + "encoding/pem" "fmt" "io/ioutil" "net/http" @@ -14,6 +16,7 @@ import ( "github.com/alexcesaro/log/golog" "github.com/jessevdk/go-flags" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" "github.com/shazow/ssh-chat/chat" "github.com/shazow/ssh-chat/sshd" @@ -80,21 +83,19 @@ func main() { } } - privateKey, err := ioutil.ReadFile(privateKeyPath) + privateKey, err := readPrivateKey(privateKeyPath) if err != nil { - logger.Errorf("Failed to load identity: %v", err) + logger.Errorf("Couldn't read private key: %v", err) os.Exit(2) - return } signer, err := ssh.ParsePrivateKey(privateKey) if err != nil { logger.Errorf("Failed to parse key: %v", err) os.Exit(3) - return } - auth := Auth{} + auth := NewAuth() config := sshd.MakeAuth(auth) config.AddHostKey(signer) @@ -102,7 +103,6 @@ func main() { if err != nil { logger.Errorf("Failed to listen on socket: %v", err) os.Exit(4) - return } defer s.Close() @@ -150,3 +150,43 @@ func main() { logger.Warningf("Interrupt signal detected, shutting down.") os.Exit(0) } + +// readPrivateKey attempts to read your private key and possibly decrypt it if it +// requires a passphrase. +// This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`), +// is not set. +func readPrivateKey(privateKeyPath string) ([]byte, error) { + privateKey, err := ioutil.ReadFile(privateKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to load identity: %v", err) + } + + block, rest := pem.Decode(privateKey) + if len(rest) > 0 { + return nil, fmt.Errorf("extra data when decoding private key") + } + if !x509.IsEncryptedPEMBlock(block) { + return privateKey, nil + } + + passphrase := []byte(os.Getenv("IDENTITY_PASSPHRASE")) + if len(passphrase) == 0 { + fmt.Printf("Enter passphrase: ") + passphrase, err = terminal.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + return nil, fmt.Errorf("couldn't read passphrase: %v", err) + } + fmt.Println() + } + der, err := x509.DecryptPEMBlock(block, passphrase) + if err != nil { + return nil, fmt.Errorf("decrypt failed: %v", err) + } + + privateKey = pem.EncodeToMemory(&pem.Block{ + Type: block.Type, + Bytes: der, + }) + + return privateKey, nil +}