mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-13 07:37:17 +03:00
/ban and /op now support durations, also all other auth things in the api.
This commit is contained in:
parent
797d8c92be
commit
69ea63bf88
104
auth.go
104
auth.go
@ -4,7 +4,9 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/shazow/ssh-chat/sshd"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@ -14,16 +16,13 @@ var ErrNotWhitelisted = errors.New("not whitelisted")
|
||||
// The error returned a key is checked that is banned.
|
||||
var ErrBanned = errors.New("banned")
|
||||
|
||||
// AuthKey is the type that our lookups are keyed against.
|
||||
type AuthKey string
|
||||
|
||||
// NewAuthKey returns string from an ssh.PublicKey.
|
||||
func NewAuthKey(key ssh.PublicKey) string {
|
||||
if key == nil {
|
||||
return ""
|
||||
}
|
||||
// FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
|
||||
return string(key.Marshal())
|
||||
return sshd.Fingerprint(key)
|
||||
}
|
||||
|
||||
// NewAuthAddr returns a string from a net.Addr
|
||||
@ -39,50 +38,43 @@ func NewAuthAddr(addr net.Addr) string {
|
||||
// TODO: Add timed auth by using a time.Time instead of struct{} for values.
|
||||
type Auth struct {
|
||||
sync.RWMutex
|
||||
bannedAddr map[string]struct{}
|
||||
banned map[string]struct{}
|
||||
whitelist map[string]struct{}
|
||||
ops map[string]struct{}
|
||||
bannedAddr *Set
|
||||
banned *Set
|
||||
whitelist *Set
|
||||
ops *Set
|
||||
}
|
||||
|
||||
// NewAuth creates a new default Auth.
|
||||
func NewAuth() *Auth {
|
||||
return &Auth{
|
||||
bannedAddr: make(map[string]struct{}),
|
||||
banned: make(map[string]struct{}),
|
||||
whitelist: make(map[string]struct{}),
|
||||
ops: make(map[string]struct{}),
|
||||
bannedAddr: NewSet(),
|
||||
banned: NewSet(),
|
||||
whitelist: NewSet(),
|
||||
ops: NewSet(),
|
||||
}
|
||||
}
|
||||
|
||||
// AllowAnonymous determines if anonymous users are permitted.
|
||||
func (a Auth) AllowAnonymous() bool {
|
||||
a.RLock()
|
||||
ok := len(a.whitelist) == 0
|
||||
a.RUnlock()
|
||||
return ok
|
||||
return a.whitelist.Len() == 0
|
||||
}
|
||||
|
||||
// Check determines if a pubkey fingerprint is permitted.
|
||||
func (a Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) {
|
||||
func (a *Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) {
|
||||
authkey := NewAuthKey(key)
|
||||
|
||||
a.RLock()
|
||||
defer a.RUnlock()
|
||||
|
||||
if len(a.whitelist) > 0 {
|
||||
if a.whitelist.Len() != 0 {
|
||||
// Only check whitelist if there is something in it, otherwise it's disabled.
|
||||
|
||||
_, whitelisted := a.whitelist[authkey]
|
||||
whitelisted := a.whitelist.In(authkey)
|
||||
if !whitelisted {
|
||||
return false, ErrNotWhitelisted
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
_, banned := a.banned[authkey]
|
||||
banned := a.banned.In(authkey)
|
||||
if !banned {
|
||||
_, banned = a.bannedAddr[NewAuthAddr(addr)]
|
||||
banned = a.bannedAddr.In(NewAuthAddr(addr))
|
||||
}
|
||||
if banned {
|
||||
return false, ErrBanned
|
||||
@ -91,60 +83,68 @@ func (a Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Op will set a fingerprint as a known operator.
|
||||
func (a *Auth) Op(key ssh.PublicKey) {
|
||||
// Op sets a public key as a known operator.
|
||||
func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
|
||||
if key == nil {
|
||||
// Don't process empty keys.
|
||||
return
|
||||
}
|
||||
authkey := NewAuthKey(key)
|
||||
a.Lock()
|
||||
a.ops[authkey] = struct{}{}
|
||||
a.Unlock()
|
||||
if d != 0 {
|
||||
a.ops.AddExpiring(authkey, d)
|
||||
} else {
|
||||
a.ops.Add(authkey)
|
||||
}
|
||||
logger.Debugf("Added to ops: %s (for %s)", authkey, d)
|
||||
}
|
||||
|
||||
// IsOp checks if a public key is an op.
|
||||
func (a Auth) IsOp(key ssh.PublicKey) bool {
|
||||
func (a *Auth) IsOp(key ssh.PublicKey) bool {
|
||||
if key == nil {
|
||||
return false
|
||||
}
|
||||
authkey := NewAuthKey(key)
|
||||
a.RLock()
|
||||
_, ok := a.ops[authkey]
|
||||
a.RUnlock()
|
||||
return ok
|
||||
return a.ops.In(authkey)
|
||||
}
|
||||
|
||||
// Whitelist will set a public key as a whitelisted user.
|
||||
func (a *Auth) Whitelist(key ssh.PublicKey) {
|
||||
func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
|
||||
if key == nil {
|
||||
// Don't process empty keys.
|
||||
return
|
||||
}
|
||||
authkey := NewAuthKey(key)
|
||||
a.Lock()
|
||||
a.whitelist[authkey] = struct{}{}
|
||||
a.Unlock()
|
||||
if d != 0 {
|
||||
a.whitelist.AddExpiring(authkey, d)
|
||||
} else {
|
||||
a.whitelist.Add(authkey)
|
||||
}
|
||||
logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
|
||||
}
|
||||
|
||||
// Ban will set a public key as banned.
|
||||
func (a *Auth) Ban(key ssh.PublicKey) {
|
||||
func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
|
||||
if key == nil {
|
||||
// Don't process empty keys.
|
||||
return
|
||||
}
|
||||
authkey := NewAuthKey(key)
|
||||
a.BanFingerprint(NewAuthKey(key), d)
|
||||
}
|
||||
|
||||
a.Lock()
|
||||
a.banned[authkey] = struct{}{}
|
||||
a.Unlock()
|
||||
// BanFingerprint will set a public key fingerprint as banned.
|
||||
func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
|
||||
if d != 0 {
|
||||
a.banned.AddExpiring(authkey, d)
|
||||
} else {
|
||||
a.banned.Add(authkey)
|
||||
}
|
||||
logger.Debugf("Added to banned: %s (for %s)", authkey, d)
|
||||
}
|
||||
|
||||
// Ban will set an IP address as banned.
|
||||
func (a *Auth) BanAddr(addr net.Addr) {
|
||||
func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
|
||||
key := NewAuthAddr(addr)
|
||||
|
||||
a.Lock()
|
||||
a.bannedAddr[key] = struct{}{}
|
||||
a.Unlock()
|
||||
if d != 0 {
|
||||
a.bannedAddr.AddExpiring(key, d)
|
||||
} else {
|
||||
a.bannedAddr.Add(key)
|
||||
}
|
||||
logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ func TestAuthWhitelist(t *testing.T) {
|
||||
t.Error("Failed to permit in default state:", err)
|
||||
}
|
||||
|
||||
auth.Whitelist(key)
|
||||
auth.Whitelist(key, 0)
|
||||
|
||||
keyClone, err := ClonePublicKey(key)
|
||||
if err != nil {
|
||||
|
5
cmd.go
5
cmd.go
@ -116,8 +116,7 @@ func main() {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
auth.Op(key)
|
||||
logger.Debugf("Added admin: %s", sshd.Fingerprint(key))
|
||||
auth.Op(key, 0)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@ -130,7 +129,7 @@ func main() {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
auth.Whitelist(key)
|
||||
auth.Whitelist(key, 0)
|
||||
logger.Debugf("Whitelisted: %s", line)
|
||||
return nil
|
||||
})
|
||||
|
22
host.go
22
host.go
@ -342,7 +342,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
||||
c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/ban",
|
||||
PrefixHelp: "USER",
|
||||
PrefixHelp: "USER [DURATION]",
|
||||
Help: "Ban USER from the server.",
|
||||
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||
// TODO: Would be nice to specify what to ban. Key? Ip? etc.
|
||||
@ -360,9 +360,14 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
var until time.Duration = 0
|
||||
if len(args) > 1 {
|
||||
until, _ = time.ParseDuration(args[1])
|
||||
}
|
||||
|
||||
id := target.Identifier.(*Identity)
|
||||
h.auth.Ban(id.PublicKey())
|
||||
h.auth.BanAddr(id.RemoteAddr())
|
||||
h.auth.Ban(id.PublicKey(), until)
|
||||
h.auth.BanAddr(id.RemoteAddr(), until)
|
||||
|
||||
body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name())
|
||||
room.Send(chat.NewAnnounceMsg(body))
|
||||
@ -404,7 +409,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
||||
c.Add(chat.Command{
|
||||
Op: true,
|
||||
Prefix: "/op",
|
||||
PrefixHelp: "USER",
|
||||
PrefixHelp: "USER [DURATION]",
|
||||
Help: "Set USER as admin.",
|
||||
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||
if !room.IsOp(msg.From()) {
|
||||
@ -412,17 +417,22 @@ func (h *Host) InitCommands(c *chat.Commands) {
|
||||
}
|
||||
|
||||
args := msg.Args()
|
||||
if len(args) != 1 {
|
||||
if len(args) == 0 {
|
||||
return errors.New("must specify user")
|
||||
}
|
||||
|
||||
var until time.Duration = 0
|
||||
if len(args) > 1 {
|
||||
until, _ = time.ParseDuration(args[1])
|
||||
}
|
||||
|
||||
member, ok := room.MemberById(args[0])
|
||||
if !ok {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
member.Op = true
|
||||
id := member.Identifier.(*Identity)
|
||||
h.auth.Op(id.PublicKey())
|
||||
h.auth.Op(id.PublicKey(), until)
|
||||
|
||||
body := fmt.Sprintf("Made op by %s.", msg.From().Name())
|
||||
room.Send(chat.NewSystemMsg(body, member.User))
|
||||
|
@ -150,7 +150,7 @@ func TestHostWhitelist(t *testing.T) {
|
||||
}
|
||||
|
||||
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
||||
auth.Whitelist(clientpubkey)
|
||||
auth.Whitelist(clientpubkey, 0)
|
||||
|
||||
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||
if err == nil {
|
||||
|
@ -1,7 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/alexcesaro/log"
|
||||
"github.com/alexcesaro/log/golog"
|
||||
)
|
||||
|
||||
var logger *golog.Logger
|
||||
|
||||
func init() {
|
||||
// Set a default null logger
|
||||
var b bytes.Buffer
|
||||
logger = golog.New(&b, log.Debug)
|
||||
}
|
||||
|
70
set.go
Normal file
70
set.go
Normal file
@ -0,0 +1,70 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type expiringValue struct {
|
||||
time.Time
|
||||
}
|
||||
|
||||
func (v expiringValue) Bool() bool {
|
||||
return time.Now().Before(v.Time)
|
||||
}
|
||||
|
||||
type value struct{}
|
||||
|
||||
func (v value) Bool() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type SetValue interface {
|
||||
Bool() bool
|
||||
}
|
||||
|
||||
// Set with expire-able keys
|
||||
type Set struct {
|
||||
lookup map[string]SetValue
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// NewSet creates a new set.
|
||||
func NewSet() *Set {
|
||||
return &Set{
|
||||
lookup: map[string]SetValue{},
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the size of the set right now.
|
||||
func (s *Set) Len() int {
|
||||
return len(s.lookup)
|
||||
}
|
||||
|
||||
// In checks if an item exists in this set.
|
||||
func (s *Set) In(key string) bool {
|
||||
s.Lock()
|
||||
v, ok := s.lookup[key]
|
||||
if ok && !v.Bool() {
|
||||
ok = false
|
||||
delete(s.lookup, key)
|
||||
}
|
||||
s.Unlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
// Add item to this set, replace if it exists.
|
||||
func (s *Set) Add(key string) {
|
||||
s.Lock()
|
||||
s.lookup[key] = value{}
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Add item to this set, replace if it exists.
|
||||
func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
|
||||
until := time.Now().Add(d)
|
||||
s.Lock()
|
||||
s.lookup[key] = expiringValue{until}
|
||||
s.Unlock()
|
||||
return until
|
||||
}
|
58
set_test.go
Normal file
58
set_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSetExpiring(t *testing.T) {
|
||||
s := NewSet()
|
||||
if s.In("foo") {
|
||||
t.Error("Matched before set.")
|
||||
}
|
||||
|
||||
s.Add("foo")
|
||||
if !s.In("foo") {
|
||||
t.Errorf("Not matched after set")
|
||||
}
|
||||
if s.Len() != 1 {
|
||||
t.Error("Not len 1 after set")
|
||||
}
|
||||
|
||||
v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
|
||||
if v.Bool() {
|
||||
t.Errorf("expiringValue now is not expiring")
|
||||
}
|
||||
|
||||
v = expiringValue{time.Now().Add(time.Minute * 2)}
|
||||
if !v.Bool() {
|
||||
t.Errorf("expiringValue in 2 minutes is expiring now")
|
||||
}
|
||||
|
||||
until := s.AddExpiring("bar", time.Minute*2)
|
||||
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
|
||||
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
|
||||
}
|
||||
val, ok := s.lookup["bar"]
|
||||
if !ok {
|
||||
t.Errorf("bar not in lookup")
|
||||
}
|
||||
if !until.Equal(val.(expiringValue).Time) {
|
||||
t.Errorf("bar's until is not equal to the expected value")
|
||||
}
|
||||
if !val.Bool() {
|
||||
t.Errorf("bar expired immediately")
|
||||
}
|
||||
|
||||
if !s.In("bar") {
|
||||
t.Errorf("Not matched after timed set")
|
||||
}
|
||||
if s.Len() != 2 {
|
||||
t.Error("Not len 2 after set")
|
||||
}
|
||||
|
||||
s.AddExpiring("bar", time.Nanosecond*1)
|
||||
if s.In("bar") {
|
||||
t.Error("Matched after expired timer")
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user