diff --git a/cmd/ssh-chat/cmd.go b/cmd/ssh-chat/cmd.go index 7d76783..bd89b19 100644 --- a/cmd/ssh-chat/cmd.go +++ b/cmd/ssh-chat/cmd.go @@ -112,14 +112,9 @@ func main() { } } - privateKey, err := ReadPrivateKey(privateKeyPath) + signer, err := ReadPrivateKey(privateKeyPath) if err != nil { - fail(2, "Couldn't read private key: %v\n", err) - } - - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - fail(3, "Failed to parse key: %v\n", err) + fail(3, "Failed to read identity private key: %v\n", err) } auth := sshchat.NewAuth() diff --git a/cmd/ssh-chat/key.go b/cmd/ssh-chat/key.go index e7135a9..470efa4 100644 --- a/cmd/ssh-chat/key.go +++ b/cmd/ssh-chat/key.go @@ -1,50 +1,37 @@ package main import ( - "crypto/x509" - "encoding/pem" "fmt" "io/ioutil" "os" "github.com/howeyc/gopass" + "golang.org/x/crypto/ssh" ) // ReadPrivateKey attempts to read your private key and possibly decrypt it if it // requires a passphrase. // This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`), // is not set. -func ReadPrivateKey(path string) ([]byte, error) { +func ReadPrivateKey(path string) (ssh.Signer, error) { privateKey, err := ioutil.ReadFile(path) if err != nil { return nil, fmt.Errorf("failed to load identity: %v", err) } - block, rest := pem.Decode(privateKey) - if len(rest) > 0 { - return nil, fmt.Errorf("extra data when decoding private key") - } - if !x509.IsEncryptedPEMBlock(block) { - return privateKey, nil - } - - passphrase := []byte(os.Getenv("IDENTITY_PASSPHRASE")) - if len(passphrase) == 0 { - fmt.Print("Enter passphrase: ") - passphrase, err = gopass.GetPasswd() - if err != nil { - return nil, fmt.Errorf("couldn't read passphrase: %v", err) + pk, err := ssh.ParsePrivateKey(privateKey) + if err == nil { + } else if _, ok := err.(*ssh.PassphraseMissingError); ok { + passphrase := []byte(os.Getenv("IDENTITY_PASSPHRASE")) + if len(passphrase) == 0 { + fmt.Print("Enter passphrase: ") + passphrase, err = gopass.GetPasswd() + if err != nil { + return nil, fmt.Errorf("couldn't read passphrase: %v", err) + } } - } - der, err := x509.DecryptPEMBlock(block, passphrase) - if err != nil { - return nil, fmt.Errorf("decrypt failed: %v", err) + return ssh.ParsePrivateKeyWithPassphrase(privateKey, passphrase) } - privateKey = pem.EncodeToMemory(&pem.Block{ - Type: block.Type, - Bytes: der, - }) - - return privateKey, nil + return pk, err }