From 7677d487041a477728d0889d8b48a9c114c7aebe Mon Sep 17 00:00:00 2001 From: mik2k2 <44849223+mik2k2@users.noreply.github.com> Date: Fri, 24 Dec 2021 12:28:29 +0100 Subject: [PATCH] add loader to allowlist test --- host_test.go | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/host_test.go b/host_test.go index 181341c..34641e5 100644 --- a/host_test.go +++ b/host_test.go @@ -2,8 +2,6 @@ package sshchat import ( "bufio" - "crypto/rand" - "crypto/rsa" "errors" "fmt" "io" @@ -181,20 +179,37 @@ func TestHostAllowlist(t *testing.T) { target := s.Addr().String() - err := sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) + clientPrivateKey, err := sshd.NewRandomSigner(512) + if err != nil { + t.Fatal(err) + } + clientKey := clientPrivateKey.PublicKey() + loadCount := -1 + loader := func() ([]ssh.PublicKey, error) { + loadCount++ + return [][]ssh.PublicKey{ + {}, + {clientKey}, + }[loadCount], nil + } + auth.LoadAllowlist(loader) + + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) if err != nil { t.Error(err) } - clientkey, err := rsa.GenerateKey(rand.Reader, 512) - if err != nil { - t.Fatal(err) + auth.SetAllowlistMode(true) + err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) + if err == nil { + t.Error(err) + } + err = sshd.ConnectShellWithKey(target, "foo", clientPrivateKey, func(r io.Reader, w io.WriteCloser) error { return nil }) + if err == nil { + t.Error(err) } - clientpubkey, _ := ssh.NewPublicKey(clientkey.Public()) - auth.Allowlist(clientpubkey, 0) - auth.SetAllowlistMode(true) - + auth.ReloadAllowlist() err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) if err == nil { t.Error("Failed to block unallowlisted connection.")