diff --git a/auth.go b/auth.go index 26a217c..8c86b26 100644 --- a/auth.go +++ b/auth.go @@ -4,7 +4,9 @@ import ( "errors" "net" "sync" + "time" + "github.com/shazow/ssh-chat/sshd" "golang.org/x/crypto/ssh" ) @@ -14,16 +16,13 @@ var ErrNotWhitelisted = errors.New("not whitelisted") // The error returned a key is checked that is banned. var ErrBanned = errors.New("banned") -// AuthKey is the type that our lookups are keyed against. -type AuthKey string - // NewAuthKey returns string from an ssh.PublicKey. func NewAuthKey(key ssh.PublicKey) string { if key == nil { return "" } // FIXME: Is there a way to index pubkeys without marshal'ing them into strings? - return string(key.Marshal()) + return sshd.Fingerprint(key) } // NewAuthAddr returns a string from a net.Addr @@ -39,50 +38,43 @@ func NewAuthAddr(addr net.Addr) string { // TODO: Add timed auth by using a time.Time instead of struct{} for values. type Auth struct { sync.RWMutex - bannedAddr map[string]struct{} - banned map[string]struct{} - whitelist map[string]struct{} - ops map[string]struct{} + bannedAddr *Set + banned *Set + whitelist *Set + ops *Set } // NewAuth creates a new default Auth. func NewAuth() *Auth { return &Auth{ - bannedAddr: make(map[string]struct{}), - banned: make(map[string]struct{}), - whitelist: make(map[string]struct{}), - ops: make(map[string]struct{}), + bannedAddr: NewSet(), + banned: NewSet(), + whitelist: NewSet(), + ops: NewSet(), } } // AllowAnonymous determines if anonymous users are permitted. func (a Auth) AllowAnonymous() bool { - a.RLock() - ok := len(a.whitelist) == 0 - a.RUnlock() - return ok + return a.whitelist.Len() == 0 } // Check determines if a pubkey fingerprint is permitted. -func (a Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) { +func (a *Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) { authkey := NewAuthKey(key) - a.RLock() - defer a.RUnlock() - - if len(a.whitelist) > 0 { + if a.whitelist.Len() != 0 { // Only check whitelist if there is something in it, otherwise it's disabled. - - _, whitelisted := a.whitelist[authkey] + whitelisted := a.whitelist.In(authkey) if !whitelisted { return false, ErrNotWhitelisted } return true, nil } - _, banned := a.banned[authkey] + banned := a.banned.In(authkey) if !banned { - _, banned = a.bannedAddr[NewAuthAddr(addr)] + banned = a.bannedAddr.In(NewAuthAddr(addr)) } if banned { return false, ErrBanned @@ -91,60 +83,68 @@ func (a Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) { return true, nil } -// Op will set a fingerprint as a known operator. -func (a *Auth) Op(key ssh.PublicKey) { +// Op sets a public key as a known operator. +func (a *Auth) Op(key ssh.PublicKey, d time.Duration) { if key == nil { - // Don't process empty keys. return } authkey := NewAuthKey(key) - a.Lock() - a.ops[authkey] = struct{}{} - a.Unlock() + if d != 0 { + a.ops.AddExpiring(authkey, d) + } else { + a.ops.Add(authkey) + } + logger.Debugf("Added to ops: %s (for %s)", authkey, d) } // IsOp checks if a public key is an op. -func (a Auth) IsOp(key ssh.PublicKey) bool { +func (a *Auth) IsOp(key ssh.PublicKey) bool { if key == nil { return false } authkey := NewAuthKey(key) - a.RLock() - _, ok := a.ops[authkey] - a.RUnlock() - return ok + return a.ops.In(authkey) } // Whitelist will set a public key as a whitelisted user. -func (a *Auth) Whitelist(key ssh.PublicKey) { +func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) { if key == nil { - // Don't process empty keys. return } authkey := NewAuthKey(key) - a.Lock() - a.whitelist[authkey] = struct{}{} - a.Unlock() + if d != 0 { + a.whitelist.AddExpiring(authkey, d) + } else { + a.whitelist.Add(authkey) + } + logger.Debugf("Added to whitelist: %s (for %s)", authkey, d) } // Ban will set a public key as banned. -func (a *Auth) Ban(key ssh.PublicKey) { +func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) { if key == nil { - // Don't process empty keys. return } - authkey := NewAuthKey(key) + a.BanFingerprint(NewAuthKey(key), d) +} - a.Lock() - a.banned[authkey] = struct{}{} - a.Unlock() +// BanFingerprint will set a public key fingerprint as banned. +func (a *Auth) BanFingerprint(authkey string, d time.Duration) { + if d != 0 { + a.banned.AddExpiring(authkey, d) + } else { + a.banned.Add(authkey) + } + logger.Debugf("Added to banned: %s (for %s)", authkey, d) } // Ban will set an IP address as banned. -func (a *Auth) BanAddr(addr net.Addr) { +func (a *Auth) BanAddr(addr net.Addr, d time.Duration) { key := NewAuthAddr(addr) - - a.Lock() - a.bannedAddr[key] = struct{}{} - a.Unlock() + if d != 0 { + a.bannedAddr.AddExpiring(key, d) + } else { + a.bannedAddr.Add(key) + } + logger.Debugf("Added to bannedAddr: %s (for %s)", key, d) } diff --git a/auth_test.go b/auth_test.go index eb29773..981a1d6 100644 --- a/auth_test.go +++ b/auth_test.go @@ -33,7 +33,7 @@ func TestAuthWhitelist(t *testing.T) { t.Error("Failed to permit in default state:", err) } - auth.Whitelist(key) + auth.Whitelist(key, 0) keyClone, err := ClonePublicKey(key) if err != nil { diff --git a/cmd.go b/cmd.go index a342a02..e60acb9 100644 --- a/cmd.go +++ b/cmd.go @@ -116,8 +116,7 @@ func main() { if err != nil { return err } - auth.Op(key) - logger.Debugf("Added admin: %s", sshd.Fingerprint(key)) + auth.Op(key, 0) return nil }) if err != nil { @@ -130,7 +129,7 @@ func main() { if err != nil { return err } - auth.Whitelist(key) + auth.Whitelist(key, 0) logger.Debugf("Whitelisted: %s", line) return nil }) diff --git a/host.go b/host.go index 2e1a57f..d099772 100644 --- a/host.go +++ b/host.go @@ -342,7 +342,7 @@ func (h *Host) InitCommands(c *chat.Commands) { c.Add(chat.Command{ Op: true, Prefix: "/ban", - PrefixHelp: "USER", + PrefixHelp: "USER [DURATION]", Help: "Ban USER from the server.", Handler: func(room *chat.Room, msg chat.CommandMsg) error { // TODO: Would be nice to specify what to ban. Key? Ip? etc. @@ -360,9 +360,14 @@ func (h *Host) InitCommands(c *chat.Commands) { return errors.New("user not found") } + var until time.Duration = 0 + if len(args) > 1 { + until, _ = time.ParseDuration(args[1]) + } + id := target.Identifier.(*Identity) - h.auth.Ban(id.PublicKey()) - h.auth.BanAddr(id.RemoteAddr()) + h.auth.Ban(id.PublicKey(), until) + h.auth.BanAddr(id.RemoteAddr(), until) body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name()) room.Send(chat.NewAnnounceMsg(body)) @@ -404,7 +409,7 @@ func (h *Host) InitCommands(c *chat.Commands) { c.Add(chat.Command{ Op: true, Prefix: "/op", - PrefixHelp: "USER", + PrefixHelp: "USER [DURATION]", Help: "Set USER as admin.", Handler: func(room *chat.Room, msg chat.CommandMsg) error { if !room.IsOp(msg.From()) { @@ -412,17 +417,22 @@ func (h *Host) InitCommands(c *chat.Commands) { } args := msg.Args() - if len(args) != 1 { + if len(args) == 0 { return errors.New("must specify user") } + var until time.Duration = 0 + if len(args) > 1 { + until, _ = time.ParseDuration(args[1]) + } + member, ok := room.MemberById(args[0]) if !ok { return errors.New("user not found") } member.Op = true id := member.Identifier.(*Identity) - h.auth.Op(id.PublicKey()) + h.auth.Op(id.PublicKey(), until) body := fmt.Sprintf("Made op by %s.", msg.From().Name()) room.Send(chat.NewSystemMsg(body, member.User)) diff --git a/host_test.go b/host_test.go index 2e79fb0..76bbe6b 100644 --- a/host_test.go +++ b/host_test.go @@ -150,7 +150,7 @@ func TestHostWhitelist(t *testing.T) { } clientpubkey, _ := ssh.NewPublicKey(clientkey.Public()) - auth.Whitelist(clientpubkey) + auth.Whitelist(clientpubkey, 0) err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) if err == nil { diff --git a/logger.go b/logger.go index 8fe9842..4fabd05 100644 --- a/logger.go +++ b/logger.go @@ -1,7 +1,16 @@ package main import ( + "bytes" + + "github.com/alexcesaro/log" "github.com/alexcesaro/log/golog" ) var logger *golog.Logger + +func init() { + // Set a default null logger + var b bytes.Buffer + logger = golog.New(&b, log.Debug) +} diff --git a/set.go b/set.go new file mode 100644 index 0000000..86afe13 --- /dev/null +++ b/set.go @@ -0,0 +1,70 @@ +package main + +import ( + "sync" + "time" +) + +type expiringValue struct { + time.Time +} + +func (v expiringValue) Bool() bool { + return time.Now().Before(v.Time) +} + +type value struct{} + +func (v value) Bool() bool { + return true +} + +type SetValue interface { + Bool() bool +} + +// Set with expire-able keys +type Set struct { + lookup map[string]SetValue + sync.Mutex +} + +// NewSet creates a new set. +func NewSet() *Set { + return &Set{ + lookup: map[string]SetValue{}, + } +} + +// Len returns the size of the set right now. +func (s *Set) Len() int { + return len(s.lookup) +} + +// In checks if an item exists in this set. +func (s *Set) In(key string) bool { + s.Lock() + v, ok := s.lookup[key] + if ok && !v.Bool() { + ok = false + delete(s.lookup, key) + } + s.Unlock() + return ok +} + +// Add item to this set, replace if it exists. +func (s *Set) Add(key string) { + s.Lock() + s.lookup[key] = value{} + s.Unlock() +} + +// Add item to this set, replace if it exists. +func (s *Set) AddExpiring(key string, d time.Duration) time.Time { + until := time.Now().Add(d) + s.Lock() + s.lookup[key] = expiringValue{until} + s.Unlock() + return until +} diff --git a/set_test.go b/set_test.go new file mode 100644 index 0000000..0a4b9ea --- /dev/null +++ b/set_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "testing" + "time" +) + +func TestSetExpiring(t *testing.T) { + s := NewSet() + if s.In("foo") { + t.Error("Matched before set.") + } + + s.Add("foo") + if !s.In("foo") { + t.Errorf("Not matched after set") + } + if s.Len() != 1 { + t.Error("Not len 1 after set") + } + + v := expiringValue{time.Now().Add(-time.Nanosecond * 1)} + if v.Bool() { + t.Errorf("expiringValue now is not expiring") + } + + v = expiringValue{time.Now().Add(time.Minute * 2)} + if !v.Bool() { + t.Errorf("expiringValue in 2 minutes is expiring now") + } + + until := s.AddExpiring("bar", time.Minute*2) + if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) { + t.Errorf("until is not a minute after %s: %s", time.Now(), until) + } + val, ok := s.lookup["bar"] + if !ok { + t.Errorf("bar not in lookup") + } + if !until.Equal(val.(expiringValue).Time) { + t.Errorf("bar's until is not equal to the expected value") + } + if !val.Bool() { + t.Errorf("bar expired immediately") + } + + if !s.In("bar") { + t.Errorf("Not matched after timed set") + } + if s.Len() != 2 { + t.Error("Not len 2 after set") + } + + s.AddExpiring("bar", time.Nanosecond*1) + if s.In("bar") { + t.Error("Matched after expired timer") + } +}