set: Improve coverage and cleanup. Switch sshchat package to use it.

This commit is contained in:
Andrey Petrov 2016-08-14 21:03:16 -04:00
parent b0a90315d8
commit 6e02b05f99
6 changed files with 79 additions and 175 deletions

54
auth.go
View File

@ -5,6 +5,7 @@ import (
"net"
"time"
"github.com/shazow/ssh-chat/set"
"github.com/shazow/ssh-chat/sshd"
"golang.org/x/crypto/ssh"
)
@ -20,10 +21,14 @@ func newAuthKey(key ssh.PublicKey) string {
if key == nil {
return ""
}
// FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
// FIXME: Is there a better way to index pubkeys without marshal'ing them into strings?
return sshd.Fingerprint(key)
}
func newAuthItem(key ssh.PublicKey) set.Item {
return set.StringItem(newAuthKey(key))
}
// newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup.
func newAuthAddr(addr net.Addr) string {
if addr == nil {
@ -35,19 +40,19 @@ func newAuthAddr(addr net.Addr) string {
// Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
type Auth struct {
bannedAddr *Set
banned *Set
whitelist *Set
ops *Set
bannedAddr *set.Set
banned *set.Set
whitelist *set.Set
ops *set.Set
}
// NewAuth creates a new empty Auth.
func NewAuth() *Auth {
return &Auth{
bannedAddr: NewSet(),
banned: NewSet(),
whitelist: NewSet(),
ops: NewSet(),
bannedAddr: set.New(),
banned: set.New(),
whitelist: set.New(),
ops: set.New(),
}
}
@ -85,13 +90,13 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
if key == nil {
return
}
authkey := newAuthKey(key)
authItem := newAuthItem(key)
if d != 0 {
a.ops.AddExpiring(authkey, d)
a.ops.Add(set.Expire(authItem, d))
} else {
a.ops.Add(authkey)
a.ops.Add(authItem)
}
logger.Debugf("Added to ops: %s (for %s)", authkey, d)
logger.Debugf("Added to ops: %s (for %s)", authItem.Key(), d)
}
// IsOp checks if a public key is an op.
@ -108,13 +113,13 @@ func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
if key == nil {
return
}
authkey := newAuthKey(key)
authItem := newAuthItem(key)
if d != 0 {
a.whitelist.AddExpiring(authkey, d)
a.whitelist.Add(set.Expire(authItem, d))
} else {
a.whitelist.Add(authkey)
a.whitelist.Add(authItem)
}
logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
logger.Debugf("Added to whitelist: %s (for %s)", authItem.Key(), d)
}
// Ban will set a public key as banned.
@ -127,21 +132,22 @@ func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
// BanFingerprint will set a public key fingerprint as banned.
func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
authItem := set.StringItem(authkey)
if d != 0 {
a.banned.AddExpiring(authkey, d)
a.banned.Add(set.Expire(authItem, d))
} else {
a.banned.Add(authkey)
a.banned.Add(authItem)
}
logger.Debugf("Added to banned: %s (for %s)", authkey, d)
logger.Debugf("Added to banned: %s (for %s)", authItem.Key(), d)
}
// Ban will set an IP address as banned.
func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
key := newAuthAddr(addr)
authItem := set.StringItem(addr.String())
if d != 0 {
a.bannedAddr.AddExpiring(key, d)
a.bannedAddr.Add(set.Expire(authItem, d))
} else {
a.bannedAddr.Add(key)
a.bannedAddr.Add(authItem)
}
logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
logger.Debugf("Added to bannedAddr: %s (for %s)", authItem.Key(), d)
}

72
set.go
View File

@ -1,72 +0,0 @@
package sshchat
import (
"sync"
"time"
)
type expiringValue struct {
time.Time
}
func (v expiringValue) Bool() bool {
return time.Now().Before(v.Time)
}
type value struct{}
func (v value) Bool() bool {
return true
}
type setValue interface {
Bool() bool
}
// Set with expire-able keys
type Set struct {
sync.Mutex
lookup map[string]setValue
}
// NewSet creates a new set.
func NewSet() *Set {
return &Set{
lookup: map[string]setValue{},
}
}
// Len returns the size of the set right now.
func (s *Set) Len() int {
s.Lock()
defer s.Unlock()
return len(s.lookup)
}
// In checks if an item exists in this set.
func (s *Set) In(key string) bool {
s.Lock()
v, ok := s.lookup[key]
if ok && !v.Bool() {
ok = false
delete(s.lookup, key)
}
s.Unlock()
return ok
}
// Add item to this set, replace if it exists.
func (s *Set) Add(key string) {
s.Lock()
s.lookup[key] = value{}
s.Unlock()
}
// Add item to this set, replace if it exists.
func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
until := time.Now().Add(d)
s.Lock()
s.lookup[key] = expiringValue{until}
s.Unlock()
return until
}

View File

@ -15,7 +15,7 @@ func (item StringItem) Key() string {
}
func (item StringItem) Value() interface{} {
return string(item)
return true
}
func Expire(item Item, d time.Duration) Item {

View File

@ -15,6 +15,8 @@ var ErrMissing = errors.New("item does not exist")
// Returned when a nil item is added. Nil values are considered expired and invalid.
var ErrNil = errors.New("item value must not be nil")
type IterFunc func(key string, item Item) error
type Set struct {
sync.RWMutex
lookup map[string]Item
@ -153,24 +155,20 @@ func (s *Set) Replace(oldKey string, item Item) error {
// Each loops over every item while holding a read lock and applies fn to each
// element.
func (s *Set) Each(fn func(item Item)) {
cleanup := []string{}
func (s *Set) Each(fn IterFunc) error {
s.RLock()
defer s.RUnlock()
for key, item := range s.lookup {
if item.Value() == nil {
cleanup = append(cleanup, key)
defer s.cleanup(key)
continue
}
fn(item)
}
s.RUnlock()
if len(cleanup) == 0 {
return
}
for _, key := range cleanup {
s.cleanup(key)
if err := fn(key, item); err != nil {
// Abort early
return err
}
}
return nil
}
// ListPrefix returns a list of items with a prefix, normalized.
@ -179,8 +177,11 @@ func (s *Set) ListPrefix(prefix string) []Item {
r := []Item{}
prefix = s.normalize(prefix)
s.Each(func(item Item) {
r = append(r, item)
s.Each(func(key string, item Item) error {
if strings.HasPrefix(key, prefix) {
r = append(r, item)
}
return nil
})
return r

View File

@ -26,14 +26,14 @@ func TestSetExpiring(t *testing.T) {
t.Errorf("ExpiringItem a nanosec ago is not expiring")
}
item = &ExpiringItem{nil, time.Now().Add(time.Minute * 2)}
item = &ExpiringItem{nil, time.Now().Add(time.Minute * 5)}
if item.Expired() {
t.Errorf("ExpiringItem in 2 minutes is expiring now")
}
item = Expire(StringItem("bar"), time.Minute*2).(*ExpiringItem)
item = Expire(StringItem("bar"), time.Minute*5).(*ExpiringItem)
until := item.Time
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
if !until.After(time.Now().Add(time.Minute*4)) || !until.Before(time.Now().Add(time.Minute*6)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
}
if item.Value() == nil {
@ -54,11 +54,38 @@ func TestSetExpiring(t *testing.T) {
if s.Len() != 2 {
t.Error("not len 2 after set")
}
if err := s.Replace("bar", Expire(StringItem("quux"), time.Minute*5)); err != nil {
t.Fatalf("failed to add quux: %s", err)
}
if err := s.Replace("bar", Expire(StringItem("bar"), time.Minute*5)); err != nil {
if err := s.Replace("quux", Expire(StringItem("bar"), time.Minute*5)); err != nil {
t.Fatalf("failed to add bar: %s", err)
}
if !s.In("bar") {
t.Error("failed to match before expiry")
if s.In("quux") {
t.Error("quux in set after replace")
}
if _, err := s.Get("bar"); err != nil {
t.Errorf("failed to get before expiry: %s", err)
}
if err := s.Add(StringItem("barbar")); err != nil {
t.Fatalf("failed to add barbar")
}
if _, err := s.Get("barbar"); err != nil {
t.Errorf("failed to get barbar: %s", err)
}
b := s.ListPrefix("b")
if len(b) != 2 || b[0].Key() != "bar" || b[1].Key() != "barbar" {
t.Errorf("b-prefix incorrect: %q", b)
}
if err := s.Remove("bar"); err != nil {
t.Fatalf("failed to remove: %s", err)
}
if s.Len() != 2 {
t.Error("not len 2 after remove")
}
s.Clear()
if s.Len() != 0 {
t.Error("not len 0 after clear")
}
}

View File

@ -1,58 +0,0 @@
package sshchat
import (
"testing"
"time"
)
func TestSetExpiring(t *testing.T) {
s := NewSet()
if s.In("foo") {
t.Error("Matched before set.")
}
s.Add("foo")
if !s.In("foo") {
t.Errorf("Not matched after set")
}
if s.Len() != 1 {
t.Error("Not len 1 after set")
}
v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
if v.Bool() {
t.Errorf("expiringValue now is not expiring")
}
v = expiringValue{time.Now().Add(time.Minute * 2)}
if !v.Bool() {
t.Errorf("expiringValue in 2 minutes is expiring now")
}
until := s.AddExpiring("bar", time.Minute*2)
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
}
val, ok := s.lookup["bar"]
if !ok {
t.Errorf("bar not in lookup")
}
if !until.Equal(val.(expiringValue).Time) {
t.Errorf("bar's until is not equal to the expected value")
}
if !val.Bool() {
t.Errorf("bar expired immediately")
}
if !s.In("bar") {
t.Errorf("Not matched after timed set")
}
if s.Len() != 2 {
t.Error("Not len 2 after set")
}
s.AddExpiring("bar", time.Nanosecond*1)
if s.In("bar") {
t.Error("Matched after expired timer")
}
}