mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-12 23:27:17 +03:00
set: Improve coverage and cleanup. Switch sshchat package to use it.
This commit is contained in:
parent
b0a90315d8
commit
6e02b05f99
54
auth.go
54
auth.go
@ -5,6 +5,7 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/shazow/ssh-chat/set"
|
||||
"github.com/shazow/ssh-chat/sshd"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
@ -20,10 +21,14 @@ func newAuthKey(key ssh.PublicKey) string {
|
||||
if key == nil {
|
||||
return ""
|
||||
}
|
||||
// FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
|
||||
// FIXME: Is there a better way to index pubkeys without marshal'ing them into strings?
|
||||
return sshd.Fingerprint(key)
|
||||
}
|
||||
|
||||
func newAuthItem(key ssh.PublicKey) set.Item {
|
||||
return set.StringItem(newAuthKey(key))
|
||||
}
|
||||
|
||||
// newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup.
|
||||
func newAuthAddr(addr net.Addr) string {
|
||||
if addr == nil {
|
||||
@ -35,19 +40,19 @@ func newAuthAddr(addr net.Addr) string {
|
||||
|
||||
// Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
|
||||
type Auth struct {
|
||||
bannedAddr *Set
|
||||
banned *Set
|
||||
whitelist *Set
|
||||
ops *Set
|
||||
bannedAddr *set.Set
|
||||
banned *set.Set
|
||||
whitelist *set.Set
|
||||
ops *set.Set
|
||||
}
|
||||
|
||||
// NewAuth creates a new empty Auth.
|
||||
func NewAuth() *Auth {
|
||||
return &Auth{
|
||||
bannedAddr: NewSet(),
|
||||
banned: NewSet(),
|
||||
whitelist: NewSet(),
|
||||
ops: NewSet(),
|
||||
bannedAddr: set.New(),
|
||||
banned: set.New(),
|
||||
whitelist: set.New(),
|
||||
ops: set.New(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,13 +90,13 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
|
||||
if key == nil {
|
||||
return
|
||||
}
|
||||
authkey := newAuthKey(key)
|
||||
authItem := newAuthItem(key)
|
||||
if d != 0 {
|
||||
a.ops.AddExpiring(authkey, d)
|
||||
a.ops.Add(set.Expire(authItem, d))
|
||||
} else {
|
||||
a.ops.Add(authkey)
|
||||
a.ops.Add(authItem)
|
||||
}
|
||||
logger.Debugf("Added to ops: %s (for %s)", authkey, d)
|
||||
logger.Debugf("Added to ops: %s (for %s)", authItem.Key(), d)
|
||||
}
|
||||
|
||||
// IsOp checks if a public key is an op.
|
||||
@ -108,13 +113,13 @@ func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
|
||||
if key == nil {
|
||||
return
|
||||
}
|
||||
authkey := newAuthKey(key)
|
||||
authItem := newAuthItem(key)
|
||||
if d != 0 {
|
||||
a.whitelist.AddExpiring(authkey, d)
|
||||
a.whitelist.Add(set.Expire(authItem, d))
|
||||
} else {
|
||||
a.whitelist.Add(authkey)
|
||||
a.whitelist.Add(authItem)
|
||||
}
|
||||
logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
|
||||
logger.Debugf("Added to whitelist: %s (for %s)", authItem.Key(), d)
|
||||
}
|
||||
|
||||
// Ban will set a public key as banned.
|
||||
@ -127,21 +132,22 @@ func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
|
||||
|
||||
// BanFingerprint will set a public key fingerprint as banned.
|
||||
func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
|
||||
authItem := set.StringItem(authkey)
|
||||
if d != 0 {
|
||||
a.banned.AddExpiring(authkey, d)
|
||||
a.banned.Add(set.Expire(authItem, d))
|
||||
} else {
|
||||
a.banned.Add(authkey)
|
||||
a.banned.Add(authItem)
|
||||
}
|
||||
logger.Debugf("Added to banned: %s (for %s)", authkey, d)
|
||||
logger.Debugf("Added to banned: %s (for %s)", authItem.Key(), d)
|
||||
}
|
||||
|
||||
// Ban will set an IP address as banned.
|
||||
func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
|
||||
key := newAuthAddr(addr)
|
||||
authItem := set.StringItem(addr.String())
|
||||
if d != 0 {
|
||||
a.bannedAddr.AddExpiring(key, d)
|
||||
a.bannedAddr.Add(set.Expire(authItem, d))
|
||||
} else {
|
||||
a.bannedAddr.Add(key)
|
||||
a.bannedAddr.Add(authItem)
|
||||
}
|
||||
logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
|
||||
logger.Debugf("Added to bannedAddr: %s (for %s)", authItem.Key(), d)
|
||||
}
|
||||
|
72
set.go
72
set.go
@ -1,72 +0,0 @@
|
||||
package sshchat
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type expiringValue struct {
|
||||
time.Time
|
||||
}
|
||||
|
||||
func (v expiringValue) Bool() bool {
|
||||
return time.Now().Before(v.Time)
|
||||
}
|
||||
|
||||
type value struct{}
|
||||
|
||||
func (v value) Bool() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type setValue interface {
|
||||
Bool() bool
|
||||
}
|
||||
|
||||
// Set with expire-able keys
|
||||
type Set struct {
|
||||
sync.Mutex
|
||||
lookup map[string]setValue
|
||||
}
|
||||
|
||||
// NewSet creates a new set.
|
||||
func NewSet() *Set {
|
||||
return &Set{
|
||||
lookup: map[string]setValue{},
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the size of the set right now.
|
||||
func (s *Set) Len() int {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return len(s.lookup)
|
||||
}
|
||||
|
||||
// In checks if an item exists in this set.
|
||||
func (s *Set) In(key string) bool {
|
||||
s.Lock()
|
||||
v, ok := s.lookup[key]
|
||||
if ok && !v.Bool() {
|
||||
ok = false
|
||||
delete(s.lookup, key)
|
||||
}
|
||||
s.Unlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// Add item to this set, replace if it exists.
|
||||
func (s *Set) Add(key string) {
|
||||
s.Lock()
|
||||
s.lookup[key] = value{}
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Add item to this set, replace if it exists.
|
||||
func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
|
||||
until := time.Now().Add(d)
|
||||
s.Lock()
|
||||
s.lookup[key] = expiringValue{until}
|
||||
s.Unlock()
|
||||
return until
|
||||
}
|
@ -15,7 +15,7 @@ func (item StringItem) Key() string {
|
||||
}
|
||||
|
||||
func (item StringItem) Value() interface{} {
|
||||
return string(item)
|
||||
return true
|
||||
}
|
||||
|
||||
func Expire(item Item, d time.Duration) Item {
|
||||
|
29
set/set.go
29
set/set.go
@ -15,6 +15,8 @@ var ErrMissing = errors.New("item does not exist")
|
||||
// Returned when a nil item is added. Nil values are considered expired and invalid.
|
||||
var ErrNil = errors.New("item value must not be nil")
|
||||
|
||||
type IterFunc func(key string, item Item) error
|
||||
|
||||
type Set struct {
|
||||
sync.RWMutex
|
||||
lookup map[string]Item
|
||||
@ -153,24 +155,20 @@ func (s *Set) Replace(oldKey string, item Item) error {
|
||||
|
||||
// Each loops over every item while holding a read lock and applies fn to each
|
||||
// element.
|
||||
func (s *Set) Each(fn func(item Item)) {
|
||||
cleanup := []string{}
|
||||
func (s *Set) Each(fn IterFunc) error {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
for key, item := range s.lookup {
|
||||
if item.Value() == nil {
|
||||
cleanup = append(cleanup, key)
|
||||
defer s.cleanup(key)
|
||||
continue
|
||||
}
|
||||
fn(item)
|
||||
}
|
||||
s.RUnlock()
|
||||
|
||||
if len(cleanup) == 0 {
|
||||
return
|
||||
}
|
||||
for _, key := range cleanup {
|
||||
s.cleanup(key)
|
||||
if err := fn(key, item); err != nil {
|
||||
// Abort early
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPrefix returns a list of items with a prefix, normalized.
|
||||
@ -179,8 +177,11 @@ func (s *Set) ListPrefix(prefix string) []Item {
|
||||
r := []Item{}
|
||||
prefix = s.normalize(prefix)
|
||||
|
||||
s.Each(func(item Item) {
|
||||
r = append(r, item)
|
||||
s.Each(func(key string, item Item) error {
|
||||
if strings.HasPrefix(key, prefix) {
|
||||
r = append(r, item)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return r
|
||||
|
@ -26,14 +26,14 @@ func TestSetExpiring(t *testing.T) {
|
||||
t.Errorf("ExpiringItem a nanosec ago is not expiring")
|
||||
}
|
||||
|
||||
item = &ExpiringItem{nil, time.Now().Add(time.Minute * 2)}
|
||||
item = &ExpiringItem{nil, time.Now().Add(time.Minute * 5)}
|
||||
if item.Expired() {
|
||||
t.Errorf("ExpiringItem in 2 minutes is expiring now")
|
||||
}
|
||||
|
||||
item = Expire(StringItem("bar"), time.Minute*2).(*ExpiringItem)
|
||||
item = Expire(StringItem("bar"), time.Minute*5).(*ExpiringItem)
|
||||
until := item.Time
|
||||
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
|
||||
if !until.After(time.Now().Add(time.Minute*4)) || !until.Before(time.Now().Add(time.Minute*6)) {
|
||||
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
|
||||
}
|
||||
if item.Value() == nil {
|
||||
@ -54,11 +54,38 @@ func TestSetExpiring(t *testing.T) {
|
||||
if s.Len() != 2 {
|
||||
t.Error("not len 2 after set")
|
||||
}
|
||||
if err := s.Replace("bar", Expire(StringItem("quux"), time.Minute*5)); err != nil {
|
||||
t.Fatalf("failed to add quux: %s", err)
|
||||
}
|
||||
|
||||
if err := s.Replace("bar", Expire(StringItem("bar"), time.Minute*5)); err != nil {
|
||||
if err := s.Replace("quux", Expire(StringItem("bar"), time.Minute*5)); err != nil {
|
||||
t.Fatalf("failed to add bar: %s", err)
|
||||
}
|
||||
if !s.In("bar") {
|
||||
t.Error("failed to match before expiry")
|
||||
if s.In("quux") {
|
||||
t.Error("quux in set after replace")
|
||||
}
|
||||
if _, err := s.Get("bar"); err != nil {
|
||||
t.Errorf("failed to get before expiry: %s", err)
|
||||
}
|
||||
if err := s.Add(StringItem("barbar")); err != nil {
|
||||
t.Fatalf("failed to add barbar")
|
||||
}
|
||||
if _, err := s.Get("barbar"); err != nil {
|
||||
t.Errorf("failed to get barbar: %s", err)
|
||||
}
|
||||
b := s.ListPrefix("b")
|
||||
if len(b) != 2 || b[0].Key() != "bar" || b[1].Key() != "barbar" {
|
||||
t.Errorf("b-prefix incorrect: %q", b)
|
||||
}
|
||||
|
||||
if err := s.Remove("bar"); err != nil {
|
||||
t.Fatalf("failed to remove: %s", err)
|
||||
}
|
||||
if s.Len() != 2 {
|
||||
t.Error("not len 2 after remove")
|
||||
}
|
||||
s.Clear()
|
||||
if s.Len() != 0 {
|
||||
t.Error("not len 0 after clear")
|
||||
}
|
||||
}
|
||||
|
58
set_test.go
58
set_test.go
@ -1,58 +0,0 @@
|
||||
package sshchat
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSetExpiring(t *testing.T) {
|
||||
s := NewSet()
|
||||
if s.In("foo") {
|
||||
t.Error("Matched before set.")
|
||||
}
|
||||
|
||||
s.Add("foo")
|
||||
if !s.In("foo") {
|
||||
t.Errorf("Not matched after set")
|
||||
}
|
||||
if s.Len() != 1 {
|
||||
t.Error("Not len 1 after set")
|
||||
}
|
||||
|
||||
v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
|
||||
if v.Bool() {
|
||||
t.Errorf("expiringValue now is not expiring")
|
||||
}
|
||||
|
||||
v = expiringValue{time.Now().Add(time.Minute * 2)}
|
||||
if !v.Bool() {
|
||||
t.Errorf("expiringValue in 2 minutes is expiring now")
|
||||
}
|
||||
|
||||
until := s.AddExpiring("bar", time.Minute*2)
|
||||
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
|
||||
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
|
||||
}
|
||||
val, ok := s.lookup["bar"]
|
||||
if !ok {
|
||||
t.Errorf("bar not in lookup")
|
||||
}
|
||||
if !until.Equal(val.(expiringValue).Time) {
|
||||
t.Errorf("bar's until is not equal to the expected value")
|
||||
}
|
||||
if !val.Bool() {
|
||||
t.Errorf("bar expired immediately")
|
||||
}
|
||||
|
||||
if !s.In("bar") {
|
||||
t.Errorf("Not matched after timed set")
|
||||
}
|
||||
if s.Len() != 2 {
|
||||
t.Error("Not len 2 after set")
|
||||
}
|
||||
|
||||
s.AddExpiring("bar", time.Nanosecond*1)
|
||||
if s.In("bar") {
|
||||
t.Error("Matched after expired timer")
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user