add loader to allowlist test

This commit is contained in:
mik2k2 2021-12-24 12:28:29 +01:00
parent 27997bcdf6
commit 7677d48704

View File

@ -2,8 +2,6 @@ package sshchat
import (
"bufio"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"io"
@ -181,20 +179,37 @@ func TestHostAllowlist(t *testing.T) {
target := s.Addr().String()
err := sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
clientPrivateKey, err := sshd.NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
clientKey := clientPrivateKey.PublicKey()
loadCount := -1
loader := func() ([]ssh.PublicKey, error) {
loadCount++
return [][]ssh.PublicKey{
{},
{clientKey},
}[loadCount], nil
}
auth.LoadAllowlist(loader)
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err != nil {
t.Error(err)
}
clientkey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatal(err)
auth.SetAllowlistMode(true)
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err == nil {
t.Error(err)
}
err = sshd.ConnectShellWithKey(target, "foo", clientPrivateKey, func(r io.Reader, w io.WriteCloser) error { return nil })
if err == nil {
t.Error(err)
}
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
auth.Allowlist(clientpubkey, 0)
auth.SetAllowlistMode(true)
auth.ReloadAllowlist()
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err == nil {
t.Error("Failed to block unallowlisted connection.")