main: Use x/crypto/ssh helpers for parsing passworded keys

This commit is contained in:
Andrey Petrov 2020-04-20 15:34:42 -04:00
parent 5c71e9b242
commit daad9ba07b
2 changed files with 16 additions and 34 deletions

View File

@ -112,14 +112,9 @@ func main() {
} }
} }
privateKey, err := ReadPrivateKey(privateKeyPath) signer, err := ReadPrivateKey(privateKeyPath)
if err != nil { if err != nil {
fail(2, "Couldn't read private key: %v\n", err) fail(3, "Failed to read identity private key: %v\n", err)
}
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
fail(3, "Failed to parse key: %v\n", err)
} }
auth := sshchat.NewAuth() auth := sshchat.NewAuth()

View File

@ -1,33 +1,27 @@
package main package main
import ( import (
"crypto/x509"
"encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"github.com/howeyc/gopass" "github.com/howeyc/gopass"
"golang.org/x/crypto/ssh"
) )
// ReadPrivateKey attempts to read your private key and possibly decrypt it if it // ReadPrivateKey attempts to read your private key and possibly decrypt it if it
// requires a passphrase. // requires a passphrase.
// This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`), // This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`),
// is not set. // is not set.
func ReadPrivateKey(path string) ([]byte, error) { func ReadPrivateKey(path string) (ssh.Signer, error) {
privateKey, err := ioutil.ReadFile(path) privateKey, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load identity: %v", err) return nil, fmt.Errorf("failed to load identity: %v", err)
} }
block, rest := pem.Decode(privateKey) pk, err := ssh.ParsePrivateKey(privateKey)
if len(rest) > 0 { if err == nil {
return nil, fmt.Errorf("extra data when decoding private key") } else if _, ok := err.(*ssh.PassphraseMissingError); ok {
}
if !x509.IsEncryptedPEMBlock(block) {
return privateKey, nil
}
passphrase := []byte(os.Getenv("IDENTITY_PASSPHRASE")) passphrase := []byte(os.Getenv("IDENTITY_PASSPHRASE"))
if len(passphrase) == 0 { if len(passphrase) == 0 {
fmt.Print("Enter passphrase: ") fmt.Print("Enter passphrase: ")
@ -36,15 +30,8 @@ func ReadPrivateKey(path string) ([]byte, error) {
return nil, fmt.Errorf("couldn't read passphrase: %v", err) return nil, fmt.Errorf("couldn't read passphrase: %v", err)
} }
} }
der, err := x509.DecryptPEMBlock(block, passphrase) return ssh.ParsePrivateKeyWithPassphrase(privateKey, passphrase)
if err != nil {
return nil, fmt.Errorf("decrypt failed: %v", err)
} }
privateKey = pem.EncodeToMemory(&pem.Block{ return pk, err
Type: block.Type,
Bytes: der,
})
return privateKey, nil
} }