set: Improve coverage and cleanup. Switch sshchat package to use it.

This commit is contained in:
Andrey Petrov 2016-08-14 21:03:16 -04:00
parent b0a90315d8
commit 6e02b05f99
6 changed files with 79 additions and 175 deletions

54
auth.go
View File

@ -5,6 +5,7 @@ import (
"net" "net"
"time" "time"
"github.com/shazow/ssh-chat/set"
"github.com/shazow/ssh-chat/sshd" "github.com/shazow/ssh-chat/sshd"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -20,10 +21,14 @@ func newAuthKey(key ssh.PublicKey) string {
if key == nil { if key == nil {
return "" return ""
} }
// FIXME: Is there a way to index pubkeys without marshal'ing them into strings? // FIXME: Is there a better way to index pubkeys without marshal'ing them into strings?
return sshd.Fingerprint(key) return sshd.Fingerprint(key)
} }
func newAuthItem(key ssh.PublicKey) set.Item {
return set.StringItem(newAuthKey(key))
}
// newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup. // newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup.
func newAuthAddr(addr net.Addr) string { func newAuthAddr(addr net.Addr) string {
if addr == nil { if addr == nil {
@ -35,19 +40,19 @@ func newAuthAddr(addr net.Addr) string {
// Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface. // Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
type Auth struct { type Auth struct {
bannedAddr *Set bannedAddr *set.Set
banned *Set banned *set.Set
whitelist *Set whitelist *set.Set
ops *Set ops *set.Set
} }
// NewAuth creates a new empty Auth. // NewAuth creates a new empty Auth.
func NewAuth() *Auth { func NewAuth() *Auth {
return &Auth{ return &Auth{
bannedAddr: NewSet(), bannedAddr: set.New(),
banned: NewSet(), banned: set.New(),
whitelist: NewSet(), whitelist: set.New(),
ops: NewSet(), ops: set.New(),
} }
} }
@ -85,13 +90,13 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
if key == nil { if key == nil {
return return
} }
authkey := newAuthKey(key) authItem := newAuthItem(key)
if d != 0 { if d != 0 {
a.ops.AddExpiring(authkey, d) a.ops.Add(set.Expire(authItem, d))
} else { } else {
a.ops.Add(authkey) a.ops.Add(authItem)
} }
logger.Debugf("Added to ops: %s (for %s)", authkey, d) logger.Debugf("Added to ops: %s (for %s)", authItem.Key(), d)
} }
// IsOp checks if a public key is an op. // IsOp checks if a public key is an op.
@ -108,13 +113,13 @@ func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
if key == nil { if key == nil {
return return
} }
authkey := newAuthKey(key) authItem := newAuthItem(key)
if d != 0 { if d != 0 {
a.whitelist.AddExpiring(authkey, d) a.whitelist.Add(set.Expire(authItem, d))
} else { } else {
a.whitelist.Add(authkey) a.whitelist.Add(authItem)
} }
logger.Debugf("Added to whitelist: %s (for %s)", authkey, d) logger.Debugf("Added to whitelist: %s (for %s)", authItem.Key(), d)
} }
// Ban will set a public key as banned. // Ban will set a public key as banned.
@ -127,21 +132,22 @@ func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
// BanFingerprint will set a public key fingerprint as banned. // BanFingerprint will set a public key fingerprint as banned.
func (a *Auth) BanFingerprint(authkey string, d time.Duration) { func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
authItem := set.StringItem(authkey)
if d != 0 { if d != 0 {
a.banned.AddExpiring(authkey, d) a.banned.Add(set.Expire(authItem, d))
} else { } else {
a.banned.Add(authkey) a.banned.Add(authItem)
} }
logger.Debugf("Added to banned: %s (for %s)", authkey, d) logger.Debugf("Added to banned: %s (for %s)", authItem.Key(), d)
} }
// Ban will set an IP address as banned. // Ban will set an IP address as banned.
func (a *Auth) BanAddr(addr net.Addr, d time.Duration) { func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
key := newAuthAddr(addr) authItem := set.StringItem(addr.String())
if d != 0 { if d != 0 {
a.bannedAddr.AddExpiring(key, d) a.bannedAddr.Add(set.Expire(authItem, d))
} else { } else {
a.bannedAddr.Add(key) a.bannedAddr.Add(authItem)
} }
logger.Debugf("Added to bannedAddr: %s (for %s)", key, d) logger.Debugf("Added to bannedAddr: %s (for %s)", authItem.Key(), d)
} }

72
set.go
View File

@ -1,72 +0,0 @@
package sshchat
import (
"sync"
"time"
)
type expiringValue struct {
time.Time
}
func (v expiringValue) Bool() bool {
return time.Now().Before(v.Time)
}
type value struct{}
func (v value) Bool() bool {
return true
}
type setValue interface {
Bool() bool
}
// Set with expire-able keys
type Set struct {
sync.Mutex
lookup map[string]setValue
}
// NewSet creates a new set.
func NewSet() *Set {
return &Set{
lookup: map[string]setValue{},
}
}
// Len returns the size of the set right now.
func (s *Set) Len() int {
s.Lock()
defer s.Unlock()
return len(s.lookup)
}
// In checks if an item exists in this set.
func (s *Set) In(key string) bool {
s.Lock()
v, ok := s.lookup[key]
if ok && !v.Bool() {
ok = false
delete(s.lookup, key)
}
s.Unlock()
return ok
}
// Add item to this set, replace if it exists.
func (s *Set) Add(key string) {
s.Lock()
s.lookup[key] = value{}
s.Unlock()
}
// Add item to this set, replace if it exists.
func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
until := time.Now().Add(d)
s.Lock()
s.lookup[key] = expiringValue{until}
s.Unlock()
return until
}

View File

@ -15,7 +15,7 @@ func (item StringItem) Key() string {
} }
func (item StringItem) Value() interface{} { func (item StringItem) Value() interface{} {
return string(item) return true
} }
func Expire(item Item, d time.Duration) Item { func Expire(item Item, d time.Duration) Item {

View File

@ -15,6 +15,8 @@ var ErrMissing = errors.New("item does not exist")
// Returned when a nil item is added. Nil values are considered expired and invalid. // Returned when a nil item is added. Nil values are considered expired and invalid.
var ErrNil = errors.New("item value must not be nil") var ErrNil = errors.New("item value must not be nil")
type IterFunc func(key string, item Item) error
type Set struct { type Set struct {
sync.RWMutex sync.RWMutex
lookup map[string]Item lookup map[string]Item
@ -153,24 +155,20 @@ func (s *Set) Replace(oldKey string, item Item) error {
// Each loops over every item while holding a read lock and applies fn to each // Each loops over every item while holding a read lock and applies fn to each
// element. // element.
func (s *Set) Each(fn func(item Item)) { func (s *Set) Each(fn IterFunc) error {
cleanup := []string{}
s.RLock() s.RLock()
defer s.RUnlock()
for key, item := range s.lookup { for key, item := range s.lookup {
if item.Value() == nil { if item.Value() == nil {
cleanup = append(cleanup, key) defer s.cleanup(key)
continue continue
} }
fn(item) if err := fn(key, item); err != nil {
} // Abort early
s.RUnlock() return err
}
if len(cleanup) == 0 {
return
}
for _, key := range cleanup {
s.cleanup(key)
} }
return nil
} }
// ListPrefix returns a list of items with a prefix, normalized. // ListPrefix returns a list of items with a prefix, normalized.
@ -179,8 +177,11 @@ func (s *Set) ListPrefix(prefix string) []Item {
r := []Item{} r := []Item{}
prefix = s.normalize(prefix) prefix = s.normalize(prefix)
s.Each(func(item Item) { s.Each(func(key string, item Item) error {
r = append(r, item) if strings.HasPrefix(key, prefix) {
r = append(r, item)
}
return nil
}) })
return r return r

View File

@ -26,14 +26,14 @@ func TestSetExpiring(t *testing.T) {
t.Errorf("ExpiringItem a nanosec ago is not expiring") t.Errorf("ExpiringItem a nanosec ago is not expiring")
} }
item = &ExpiringItem{nil, time.Now().Add(time.Minute * 2)} item = &ExpiringItem{nil, time.Now().Add(time.Minute * 5)}
if item.Expired() { if item.Expired() {
t.Errorf("ExpiringItem in 2 minutes is expiring now") t.Errorf("ExpiringItem in 2 minutes is expiring now")
} }
item = Expire(StringItem("bar"), time.Minute*2).(*ExpiringItem) item = Expire(StringItem("bar"), time.Minute*5).(*ExpiringItem)
until := item.Time until := item.Time
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) { if !until.After(time.Now().Add(time.Minute*4)) || !until.Before(time.Now().Add(time.Minute*6)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until) t.Errorf("until is not a minute after %s: %s", time.Now(), until)
} }
if item.Value() == nil { if item.Value() == nil {
@ -54,11 +54,38 @@ func TestSetExpiring(t *testing.T) {
if s.Len() != 2 { if s.Len() != 2 {
t.Error("not len 2 after set") t.Error("not len 2 after set")
} }
if err := s.Replace("bar", Expire(StringItem("quux"), time.Minute*5)); err != nil {
t.Fatalf("failed to add quux: %s", err)
}
if err := s.Replace("bar", Expire(StringItem("bar"), time.Minute*5)); err != nil { if err := s.Replace("quux", Expire(StringItem("bar"), time.Minute*5)); err != nil {
t.Fatalf("failed to add bar: %s", err) t.Fatalf("failed to add bar: %s", err)
} }
if !s.In("bar") { if s.In("quux") {
t.Error("failed to match before expiry") t.Error("quux in set after replace")
}
if _, err := s.Get("bar"); err != nil {
t.Errorf("failed to get before expiry: %s", err)
}
if err := s.Add(StringItem("barbar")); err != nil {
t.Fatalf("failed to add barbar")
}
if _, err := s.Get("barbar"); err != nil {
t.Errorf("failed to get barbar: %s", err)
}
b := s.ListPrefix("b")
if len(b) != 2 || b[0].Key() != "bar" || b[1].Key() != "barbar" {
t.Errorf("b-prefix incorrect: %q", b)
}
if err := s.Remove("bar"); err != nil {
t.Fatalf("failed to remove: %s", err)
}
if s.Len() != 2 {
t.Error("not len 2 after remove")
}
s.Clear()
if s.Len() != 0 {
t.Error("not len 0 after clear")
} }
} }

View File

@ -1,58 +0,0 @@
package sshchat
import (
"testing"
"time"
)
func TestSetExpiring(t *testing.T) {
s := NewSet()
if s.In("foo") {
t.Error("Matched before set.")
}
s.Add("foo")
if !s.In("foo") {
t.Errorf("Not matched after set")
}
if s.Len() != 1 {
t.Error("Not len 1 after set")
}
v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
if v.Bool() {
t.Errorf("expiringValue now is not expiring")
}
v = expiringValue{time.Now().Add(time.Minute * 2)}
if !v.Bool() {
t.Errorf("expiringValue in 2 minutes is expiring now")
}
until := s.AddExpiring("bar", time.Minute*2)
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
}
val, ok := s.lookup["bar"]
if !ok {
t.Errorf("bar not in lookup")
}
if !until.Equal(val.(expiringValue).Time) {
t.Errorf("bar's until is not equal to the expected value")
}
if !val.Bool() {
t.Errorf("bar expired immediately")
}
if !s.In("bar") {
t.Errorf("Not matched after timed set")
}
if s.Len() != 2 {
t.Error("Not len 2 after set")
}
s.AddExpiring("bar", time.Nanosecond*1)
if s.In("bar") {
t.Error("Matched after expired timer")
}
}