/ban and /op now support durations, also all other auth things in the api.

This commit is contained in:
Andrey Petrov 2015-01-19 19:16:37 -08:00
parent 797d8c92be
commit 69ea63bf88
8 changed files with 209 additions and 63 deletions

104
auth.go
View File

@ -4,7 +4,9 @@ import (
"errors" "errors"
"net" "net"
"sync" "sync"
"time"
"github.com/shazow/ssh-chat/sshd"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -14,16 +16,13 @@ var ErrNotWhitelisted = errors.New("not whitelisted")
// The error returned a key is checked that is banned. // The error returned a key is checked that is banned.
var ErrBanned = errors.New("banned") var ErrBanned = errors.New("banned")
// AuthKey is the type that our lookups are keyed against.
type AuthKey string
// NewAuthKey returns string from an ssh.PublicKey. // NewAuthKey returns string from an ssh.PublicKey.
func NewAuthKey(key ssh.PublicKey) string { func NewAuthKey(key ssh.PublicKey) string {
if key == nil { if key == nil {
return "" return ""
} }
// FIXME: Is there a way to index pubkeys without marshal'ing them into strings? // FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
return string(key.Marshal()) return sshd.Fingerprint(key)
} }
// NewAuthAddr returns a string from a net.Addr // NewAuthAddr returns a string from a net.Addr
@ -39,50 +38,43 @@ func NewAuthAddr(addr net.Addr) string {
// TODO: Add timed auth by using a time.Time instead of struct{} for values. // TODO: Add timed auth by using a time.Time instead of struct{} for values.
type Auth struct { type Auth struct {
sync.RWMutex sync.RWMutex
bannedAddr map[string]struct{} bannedAddr *Set
banned map[string]struct{} banned *Set
whitelist map[string]struct{} whitelist *Set
ops map[string]struct{} ops *Set
} }
// NewAuth creates a new default Auth. // NewAuth creates a new default Auth.
func NewAuth() *Auth { func NewAuth() *Auth {
return &Auth{ return &Auth{
bannedAddr: make(map[string]struct{}), bannedAddr: NewSet(),
banned: make(map[string]struct{}), banned: NewSet(),
whitelist: make(map[string]struct{}), whitelist: NewSet(),
ops: make(map[string]struct{}), ops: NewSet(),
} }
} }
// AllowAnonymous determines if anonymous users are permitted. // AllowAnonymous determines if anonymous users are permitted.
func (a Auth) AllowAnonymous() bool { func (a Auth) AllowAnonymous() bool {
a.RLock() return a.whitelist.Len() == 0
ok := len(a.whitelist) == 0
a.RUnlock()
return ok
} }
// Check determines if a pubkey fingerprint is permitted. // Check determines if a pubkey fingerprint is permitted.
func (a Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) { func (a *Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) {
authkey := NewAuthKey(key) authkey := NewAuthKey(key)
a.RLock() if a.whitelist.Len() != 0 {
defer a.RUnlock()
if len(a.whitelist) > 0 {
// Only check whitelist if there is something in it, otherwise it's disabled. // Only check whitelist if there is something in it, otherwise it's disabled.
whitelisted := a.whitelist.In(authkey)
_, whitelisted := a.whitelist[authkey]
if !whitelisted { if !whitelisted {
return false, ErrNotWhitelisted return false, ErrNotWhitelisted
} }
return true, nil return true, nil
} }
_, banned := a.banned[authkey] banned := a.banned.In(authkey)
if !banned { if !banned {
_, banned = a.bannedAddr[NewAuthAddr(addr)] banned = a.bannedAddr.In(NewAuthAddr(addr))
} }
if banned { if banned {
return false, ErrBanned return false, ErrBanned
@ -91,60 +83,68 @@ func (a Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) {
return true, nil return true, nil
} }
// Op will set a fingerprint as a known operator. // Op sets a public key as a known operator.
func (a *Auth) Op(key ssh.PublicKey) { func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
if key == nil { if key == nil {
// Don't process empty keys.
return return
} }
authkey := NewAuthKey(key) authkey := NewAuthKey(key)
a.Lock() if d != 0 {
a.ops[authkey] = struct{}{} a.ops.AddExpiring(authkey, d)
a.Unlock() } else {
a.ops.Add(authkey)
}
logger.Debugf("Added to ops: %s (for %s)", authkey, d)
} }
// IsOp checks if a public key is an op. // IsOp checks if a public key is an op.
func (a Auth) IsOp(key ssh.PublicKey) bool { func (a *Auth) IsOp(key ssh.PublicKey) bool {
if key == nil { if key == nil {
return false return false
} }
authkey := NewAuthKey(key) authkey := NewAuthKey(key)
a.RLock() return a.ops.In(authkey)
_, ok := a.ops[authkey]
a.RUnlock()
return ok
} }
// Whitelist will set a public key as a whitelisted user. // Whitelist will set a public key as a whitelisted user.
func (a *Auth) Whitelist(key ssh.PublicKey) { func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
if key == nil { if key == nil {
// Don't process empty keys.
return return
} }
authkey := NewAuthKey(key) authkey := NewAuthKey(key)
a.Lock() if d != 0 {
a.whitelist[authkey] = struct{}{} a.whitelist.AddExpiring(authkey, d)
a.Unlock() } else {
a.whitelist.Add(authkey)
}
logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
} }
// Ban will set a public key as banned. // Ban will set a public key as banned.
func (a *Auth) Ban(key ssh.PublicKey) { func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
if key == nil { if key == nil {
// Don't process empty keys.
return return
} }
authkey := NewAuthKey(key) a.BanFingerprint(NewAuthKey(key), d)
}
a.Lock() // BanFingerprint will set a public key fingerprint as banned.
a.banned[authkey] = struct{}{} func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
a.Unlock() if d != 0 {
a.banned.AddExpiring(authkey, d)
} else {
a.banned.Add(authkey)
}
logger.Debugf("Added to banned: %s (for %s)", authkey, d)
} }
// Ban will set an IP address as banned. // Ban will set an IP address as banned.
func (a *Auth) BanAddr(addr net.Addr) { func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
key := NewAuthAddr(addr) key := NewAuthAddr(addr)
if d != 0 {
a.Lock() a.bannedAddr.AddExpiring(key, d)
a.bannedAddr[key] = struct{}{} } else {
a.Unlock() a.bannedAddr.Add(key)
}
logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
} }

View File

@ -33,7 +33,7 @@ func TestAuthWhitelist(t *testing.T) {
t.Error("Failed to permit in default state:", err) t.Error("Failed to permit in default state:", err)
} }
auth.Whitelist(key) auth.Whitelist(key, 0)
keyClone, err := ClonePublicKey(key) keyClone, err := ClonePublicKey(key)
if err != nil { if err != nil {

5
cmd.go
View File

@ -116,8 +116,7 @@ func main() {
if err != nil { if err != nil {
return err return err
} }
auth.Op(key) auth.Op(key, 0)
logger.Debugf("Added admin: %s", sshd.Fingerprint(key))
return nil return nil
}) })
if err != nil { if err != nil {
@ -130,7 +129,7 @@ func main() {
if err != nil { if err != nil {
return err return err
} }
auth.Whitelist(key) auth.Whitelist(key, 0)
logger.Debugf("Whitelisted: %s", line) logger.Debugf("Whitelisted: %s", line)
return nil return nil
}) })

22
host.go
View File

@ -342,7 +342,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
c.Add(chat.Command{ c.Add(chat.Command{
Op: true, Op: true,
Prefix: "/ban", Prefix: "/ban",
PrefixHelp: "USER", PrefixHelp: "USER [DURATION]",
Help: "Ban USER from the server.", Help: "Ban USER from the server.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error { Handler: func(room *chat.Room, msg chat.CommandMsg) error {
// TODO: Would be nice to specify what to ban. Key? Ip? etc. // TODO: Would be nice to specify what to ban. Key? Ip? etc.
@ -360,9 +360,14 @@ func (h *Host) InitCommands(c *chat.Commands) {
return errors.New("user not found") return errors.New("user not found")
} }
var until time.Duration = 0
if len(args) > 1 {
until, _ = time.ParseDuration(args[1])
}
id := target.Identifier.(*Identity) id := target.Identifier.(*Identity)
h.auth.Ban(id.PublicKey()) h.auth.Ban(id.PublicKey(), until)
h.auth.BanAddr(id.RemoteAddr()) h.auth.BanAddr(id.RemoteAddr(), until)
body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name()) body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name())
room.Send(chat.NewAnnounceMsg(body)) room.Send(chat.NewAnnounceMsg(body))
@ -404,7 +409,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
c.Add(chat.Command{ c.Add(chat.Command{
Op: true, Op: true,
Prefix: "/op", Prefix: "/op",
PrefixHelp: "USER", PrefixHelp: "USER [DURATION]",
Help: "Set USER as admin.", Help: "Set USER as admin.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error { Handler: func(room *chat.Room, msg chat.CommandMsg) error {
if !room.IsOp(msg.From()) { if !room.IsOp(msg.From()) {
@ -412,17 +417,22 @@ func (h *Host) InitCommands(c *chat.Commands) {
} }
args := msg.Args() args := msg.Args()
if len(args) != 1 { if len(args) == 0 {
return errors.New("must specify user") return errors.New("must specify user")
} }
var until time.Duration = 0
if len(args) > 1 {
until, _ = time.ParseDuration(args[1])
}
member, ok := room.MemberById(args[0]) member, ok := room.MemberById(args[0])
if !ok { if !ok {
return errors.New("user not found") return errors.New("user not found")
} }
member.Op = true member.Op = true
id := member.Identifier.(*Identity) id := member.Identifier.(*Identity)
h.auth.Op(id.PublicKey()) h.auth.Op(id.PublicKey(), until)
body := fmt.Sprintf("Made op by %s.", msg.From().Name()) body := fmt.Sprintf("Made op by %s.", msg.From().Name())
room.Send(chat.NewSystemMsg(body, member.User)) room.Send(chat.NewSystemMsg(body, member.User))

View File

@ -150,7 +150,7 @@ func TestHostWhitelist(t *testing.T) {
} }
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public()) clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
auth.Whitelist(clientpubkey) auth.Whitelist(clientpubkey, 0)
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {}) err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
if err == nil { if err == nil {

View File

@ -1,7 +1,16 @@
package main package main
import ( import (
"bytes"
"github.com/alexcesaro/log"
"github.com/alexcesaro/log/golog" "github.com/alexcesaro/log/golog"
) )
var logger *golog.Logger var logger *golog.Logger
func init() {
// Set a default null logger
var b bytes.Buffer
logger = golog.New(&b, log.Debug)
}

70
set.go Normal file
View File

@ -0,0 +1,70 @@
package main
import (
"sync"
"time"
)
type expiringValue struct {
time.Time
}
func (v expiringValue) Bool() bool {
return time.Now().Before(v.Time)
}
type value struct{}
func (v value) Bool() bool {
return true
}
type SetValue interface {
Bool() bool
}
// Set with expire-able keys
type Set struct {
lookup map[string]SetValue
sync.Mutex
}
// NewSet creates a new set.
func NewSet() *Set {
return &Set{
lookup: map[string]SetValue{},
}
}
// Len returns the size of the set right now.
func (s *Set) Len() int {
return len(s.lookup)
}
// In checks if an item exists in this set.
func (s *Set) In(key string) bool {
s.Lock()
v, ok := s.lookup[key]
if ok && !v.Bool() {
ok = false
delete(s.lookup, key)
}
s.Unlock()
return ok
}
// Add item to this set, replace if it exists.
func (s *Set) Add(key string) {
s.Lock()
s.lookup[key] = value{}
s.Unlock()
}
// Add item to this set, replace if it exists.
func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
until := time.Now().Add(d)
s.Lock()
s.lookup[key] = expiringValue{until}
s.Unlock()
return until
}

58
set_test.go Normal file
View File

@ -0,0 +1,58 @@
package main
import (
"testing"
"time"
)
func TestSetExpiring(t *testing.T) {
s := NewSet()
if s.In("foo") {
t.Error("Matched before set.")
}
s.Add("foo")
if !s.In("foo") {
t.Errorf("Not matched after set")
}
if s.Len() != 1 {
t.Error("Not len 1 after set")
}
v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
if v.Bool() {
t.Errorf("expiringValue now is not expiring")
}
v = expiringValue{time.Now().Add(time.Minute * 2)}
if !v.Bool() {
t.Errorf("expiringValue in 2 minutes is expiring now")
}
until := s.AddExpiring("bar", time.Minute*2)
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
}
val, ok := s.lookup["bar"]
if !ok {
t.Errorf("bar not in lookup")
}
if !until.Equal(val.(expiringValue).Time) {
t.Errorf("bar's until is not equal to the expected value")
}
if !val.Bool() {
t.Errorf("bar expired immediately")
}
if !s.In("bar") {
t.Errorf("Not matched after timed set")
}
if s.Len() != 2 {
t.Error("Not len 2 after set")
}
s.AddExpiring("bar", time.Nanosecond*1)
if s.In("bar") {
t.Error("Matched after expired timer")
}
}