diff --git a/auth.go b/auth.go index 7095260..ced008b 100644 --- a/auth.go +++ b/auth.go @@ -5,6 +5,7 @@ import ( "net" "time" + "github.com/shazow/ssh-chat/set" "github.com/shazow/ssh-chat/sshd" "golang.org/x/crypto/ssh" ) @@ -20,10 +21,14 @@ func newAuthKey(key ssh.PublicKey) string { if key == nil { return "" } - // FIXME: Is there a way to index pubkeys without marshal'ing them into strings? + // FIXME: Is there a better way to index pubkeys without marshal'ing them into strings? return sshd.Fingerprint(key) } +func newAuthItem(key ssh.PublicKey) set.Item { + return set.StringItem(newAuthKey(key)) +} + // newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup. func newAuthAddr(addr net.Addr) string { if addr == nil { @@ -35,19 +40,19 @@ func newAuthAddr(addr net.Addr) string { // Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface. type Auth struct { - bannedAddr *Set - banned *Set - whitelist *Set - ops *Set + bannedAddr *set.Set + banned *set.Set + whitelist *set.Set + ops *set.Set } // NewAuth creates a new empty Auth. func NewAuth() *Auth { return &Auth{ - bannedAddr: NewSet(), - banned: NewSet(), - whitelist: NewSet(), - ops: NewSet(), + bannedAddr: set.New(), + banned: set.New(), + whitelist: set.New(), + ops: set.New(), } } @@ -85,13 +90,13 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) { if key == nil { return } - authkey := newAuthKey(key) + authItem := newAuthItem(key) if d != 0 { - a.ops.AddExpiring(authkey, d) + a.ops.Add(set.Expire(authItem, d)) } else { - a.ops.Add(authkey) + a.ops.Add(authItem) } - logger.Debugf("Added to ops: %s (for %s)", authkey, d) + logger.Debugf("Added to ops: %s (for %s)", authItem.Key(), d) } // IsOp checks if a public key is an op. @@ -108,13 +113,13 @@ func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) { if key == nil { return } - authkey := newAuthKey(key) + authItem := newAuthItem(key) if d != 0 { - a.whitelist.AddExpiring(authkey, d) + a.whitelist.Add(set.Expire(authItem, d)) } else { - a.whitelist.Add(authkey) + a.whitelist.Add(authItem) } - logger.Debugf("Added to whitelist: %s (for %s)", authkey, d) + logger.Debugf("Added to whitelist: %s (for %s)", authItem.Key(), d) } // Ban will set a public key as banned. @@ -127,21 +132,22 @@ func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) { // BanFingerprint will set a public key fingerprint as banned. func (a *Auth) BanFingerprint(authkey string, d time.Duration) { + authItem := set.StringItem(authkey) if d != 0 { - a.banned.AddExpiring(authkey, d) + a.banned.Add(set.Expire(authItem, d)) } else { - a.banned.Add(authkey) + a.banned.Add(authItem) } - logger.Debugf("Added to banned: %s (for %s)", authkey, d) + logger.Debugf("Added to banned: %s (for %s)", authItem.Key(), d) } // Ban will set an IP address as banned. func (a *Auth) BanAddr(addr net.Addr, d time.Duration) { - key := newAuthAddr(addr) + authItem := set.StringItem(addr.String()) if d != 0 { - a.bannedAddr.AddExpiring(key, d) + a.bannedAddr.Add(set.Expire(authItem, d)) } else { - a.bannedAddr.Add(key) + a.bannedAddr.Add(authItem) } - logger.Debugf("Added to bannedAddr: %s (for %s)", key, d) + logger.Debugf("Added to bannedAddr: %s (for %s)", authItem.Key(), d) } diff --git a/set.go b/set.go deleted file mode 100644 index 3e29a57..0000000 --- a/set.go +++ /dev/null @@ -1,72 +0,0 @@ -package sshchat - -import ( - "sync" - "time" -) - -type expiringValue struct { - time.Time -} - -func (v expiringValue) Bool() bool { - return time.Now().Before(v.Time) -} - -type value struct{} - -func (v value) Bool() bool { - return true -} - -type setValue interface { - Bool() bool -} - -// Set with expire-able keys -type Set struct { - sync.Mutex - lookup map[string]setValue -} - -// NewSet creates a new set. -func NewSet() *Set { - return &Set{ - lookup: map[string]setValue{}, - } -} - -// Len returns the size of the set right now. -func (s *Set) Len() int { - s.Lock() - defer s.Unlock() - return len(s.lookup) -} - -// In checks if an item exists in this set. -func (s *Set) In(key string) bool { - s.Lock() - v, ok := s.lookup[key] - if ok && !v.Bool() { - ok = false - delete(s.lookup, key) - } - s.Unlock() - return ok -} - -// Add item to this set, replace if it exists. -func (s *Set) Add(key string) { - s.Lock() - s.lookup[key] = value{} - s.Unlock() -} - -// Add item to this set, replace if it exists. -func (s *Set) AddExpiring(key string, d time.Duration) time.Time { - until := time.Now().Add(d) - s.Lock() - s.lookup[key] = expiringValue{until} - s.Unlock() - return until -} diff --git a/set/item.go b/set/item.go index bcb2a1b..a59fa13 100644 --- a/set/item.go +++ b/set/item.go @@ -15,7 +15,7 @@ func (item StringItem) Key() string { } func (item StringItem) Value() interface{} { - return string(item) + return true } func Expire(item Item, d time.Duration) Item { diff --git a/set/set.go b/set/set.go index 06ccec3..1b489ad 100644 --- a/set/set.go +++ b/set/set.go @@ -15,6 +15,8 @@ var ErrMissing = errors.New("item does not exist") // Returned when a nil item is added. Nil values are considered expired and invalid. var ErrNil = errors.New("item value must not be nil") +type IterFunc func(key string, item Item) error + type Set struct { sync.RWMutex lookup map[string]Item @@ -153,24 +155,20 @@ func (s *Set) Replace(oldKey string, item Item) error { // Each loops over every item while holding a read lock and applies fn to each // element. -func (s *Set) Each(fn func(item Item)) { - cleanup := []string{} +func (s *Set) Each(fn IterFunc) error { s.RLock() + defer s.RUnlock() for key, item := range s.lookup { if item.Value() == nil { - cleanup = append(cleanup, key) + defer s.cleanup(key) continue } - fn(item) - } - s.RUnlock() - - if len(cleanup) == 0 { - return - } - for _, key := range cleanup { - s.cleanup(key) + if err := fn(key, item); err != nil { + // Abort early + return err + } } + return nil } // ListPrefix returns a list of items with a prefix, normalized. @@ -179,8 +177,11 @@ func (s *Set) ListPrefix(prefix string) []Item { r := []Item{} prefix = s.normalize(prefix) - s.Each(func(item Item) { - r = append(r, item) + s.Each(func(key string, item Item) error { + if strings.HasPrefix(key, prefix) { + r = append(r, item) + } + return nil }) return r diff --git a/set/set_test.go b/set/set_test.go index 7b55dc8..f75192d 100644 --- a/set/set_test.go +++ b/set/set_test.go @@ -26,14 +26,14 @@ func TestSetExpiring(t *testing.T) { t.Errorf("ExpiringItem a nanosec ago is not expiring") } - item = &ExpiringItem{nil, time.Now().Add(time.Minute * 2)} + item = &ExpiringItem{nil, time.Now().Add(time.Minute * 5)} if item.Expired() { t.Errorf("ExpiringItem in 2 minutes is expiring now") } - item = Expire(StringItem("bar"), time.Minute*2).(*ExpiringItem) + item = Expire(StringItem("bar"), time.Minute*5).(*ExpiringItem) until := item.Time - if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) { + if !until.After(time.Now().Add(time.Minute*4)) || !until.Before(time.Now().Add(time.Minute*6)) { t.Errorf("until is not a minute after %s: %s", time.Now(), until) } if item.Value() == nil { @@ -54,11 +54,38 @@ func TestSetExpiring(t *testing.T) { if s.Len() != 2 { t.Error("not len 2 after set") } + if err := s.Replace("bar", Expire(StringItem("quux"), time.Minute*5)); err != nil { + t.Fatalf("failed to add quux: %s", err) + } - if err := s.Replace("bar", Expire(StringItem("bar"), time.Minute*5)); err != nil { + if err := s.Replace("quux", Expire(StringItem("bar"), time.Minute*5)); err != nil { t.Fatalf("failed to add bar: %s", err) } - if !s.In("bar") { - t.Error("failed to match before expiry") + if s.In("quux") { + t.Error("quux in set after replace") + } + if _, err := s.Get("bar"); err != nil { + t.Errorf("failed to get before expiry: %s", err) + } + if err := s.Add(StringItem("barbar")); err != nil { + t.Fatalf("failed to add barbar") + } + if _, err := s.Get("barbar"); err != nil { + t.Errorf("failed to get barbar: %s", err) + } + b := s.ListPrefix("b") + if len(b) != 2 || b[0].Key() != "bar" || b[1].Key() != "barbar" { + t.Errorf("b-prefix incorrect: %q", b) + } + + if err := s.Remove("bar"); err != nil { + t.Fatalf("failed to remove: %s", err) + } + if s.Len() != 2 { + t.Error("not len 2 after remove") + } + s.Clear() + if s.Len() != 0 { + t.Error("not len 0 after clear") } } diff --git a/set_test.go b/set_test.go deleted file mode 100644 index 1d7fbef..0000000 --- a/set_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package sshchat - -import ( - "testing" - "time" -) - -func TestSetExpiring(t *testing.T) { - s := NewSet() - if s.In("foo") { - t.Error("Matched before set.") - } - - s.Add("foo") - if !s.In("foo") { - t.Errorf("Not matched after set") - } - if s.Len() != 1 { - t.Error("Not len 1 after set") - } - - v := expiringValue{time.Now().Add(-time.Nanosecond * 1)} - if v.Bool() { - t.Errorf("expiringValue now is not expiring") - } - - v = expiringValue{time.Now().Add(time.Minute * 2)} - if !v.Bool() { - t.Errorf("expiringValue in 2 minutes is expiring now") - } - - until := s.AddExpiring("bar", time.Minute*2) - if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) { - t.Errorf("until is not a minute after %s: %s", time.Now(), until) - } - val, ok := s.lookup["bar"] - if !ok { - t.Errorf("bar not in lookup") - } - if !until.Equal(val.(expiringValue).Time) { - t.Errorf("bar's until is not equal to the expected value") - } - if !val.Bool() { - t.Errorf("bar expired immediately") - } - - if !s.In("bar") { - t.Errorf("Not matched after timed set") - } - if s.Len() != 2 { - t.Error("Not len 2 after set") - } - - s.AddExpiring("bar", time.Nanosecond*1) - if s.In("bar") { - t.Error("Matched after expired timer") - } -}