package sshchat import ( "crypto/sha256" "crypto/subtle" "encoding/csv" "errors" "fmt" "net" "strings" "sync" "time" "github.com/shazow/ssh-chat/set" "github.com/shazow/ssh-chat/sshd" "golang.org/x/crypto/ssh" ) // KeyLoader loads public keys, e.g. from an authorized_keys file. // It must return a nil slice on error. type KeyLoader func() ([]ssh.PublicKey, error) // ErrNotAllowed Is the error returned when a key is checked that is not allowlisted, // when allowlisting is enabled. var ErrNotAllowed = errors.New("not allowed") // ErrBanned is the error returned when a client is banned. var ErrBanned = errors.New("banned") // ErrIncorrectPassphrase is the error returned when a provided passphrase is incorrect. var ErrIncorrectPassphrase = errors.New("incorrect passphrase") // newAuthKey returns string from an ssh.PublicKey used to index the key in our lookup. func newAuthKey(key ssh.PublicKey) string { if key == nil { return "" } // FIXME: Is there a better way to index pubkeys without marshal'ing them into strings? return sshd.Fingerprint(key) } func newAuthItem(key ssh.PublicKey) set.Item { return set.StringItem(newAuthKey(key)) } // newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup. func newAuthAddr(addr net.Addr) string { if addr == nil { return "" } host, _, _ := net.SplitHostPort(addr.String()) return host } // Auth stores lookups for bans, allowlists, and ops. It implements the sshd.Auth interface. // If the contained passphrase is not empty, it complements a allowlist. type Auth struct { passphraseHash []byte bannedAddr *set.Set bannedClient *set.Set banned *set.Set allowlist *set.Set ops *set.Set settingsMu sync.RWMutex allowlistMode bool opLoader KeyLoader allowlistLoader KeyLoader } // NewAuth creates a new empty Auth. func NewAuth() *Auth { return &Auth{ bannedAddr: set.New(), bannedClient: set.New(), banned: set.New(), allowlist: set.New(), ops: set.New(), } } func (a *Auth) AllowlistMode() bool { a.settingsMu.RLock() defer a.settingsMu.RUnlock() return a.allowlistMode } func (a *Auth) SetAllowlistMode(value bool) { a.settingsMu.Lock() defer a.settingsMu.Unlock() a.allowlistMode = value } // SetPassphrase enables passphrase authentication with the given passphrase. // If an empty passphrase is given, disable passphrase authentication. func (a *Auth) SetPassphrase(passphrase string) { if passphrase == "" { a.passphraseHash = nil } else { hashArray := sha256.Sum256([]byte(passphrase)) a.passphraseHash = hashArray[:] } } // AllowAnonymous determines if anonymous users are permitted. func (a *Auth) AllowAnonymous() bool { return !a.AllowlistMode() && a.passphraseHash == nil } // AcceptPassphrase determines if passphrase authentication is accepted. func (a *Auth) AcceptPassphrase() bool { return a.passphraseHash != nil } // CheckBans checks IP, key and client bans. func (a *Auth) CheckBans(addr net.Addr, key ssh.PublicKey, clientVersion string) error { authkey := newAuthKey(key) var banned bool if authkey != "" { banned = a.banned.In(authkey) } if !banned { banned = a.bannedAddr.In(newAuthAddr(addr)) } if !banned { banned = a.bannedClient.In(clientVersion) } // Ops can bypass bans, just in case we ban ourselves. if banned && !a.IsOp(key) { return ErrBanned } return nil } // CheckPubkey determines if a pubkey fingerprint is permitted. func (a *Auth) CheckPublicKey(key ssh.PublicKey) error { authkey := newAuthKey(key) allowlisted := a.allowlist.In(authkey) if a.AllowAnonymous() || allowlisted || a.IsOp(key) { return nil } else { return ErrNotAllowed } } // CheckPassphrase determines if a passphrase is permitted. func (a *Auth) CheckPassphrase(passphrase string) error { if !a.AcceptPassphrase() { return errors.New("passphrases not accepted") // this should never happen } passedPassphraseHash := sha256.Sum256([]byte(passphrase)) if subtle.ConstantTimeCompare(passedPassphraseHash[:], a.passphraseHash) == 0 { return ErrIncorrectPassphrase } return nil } // Op sets a public key as a known operator. func (a *Auth) Op(key ssh.PublicKey, d time.Duration) { if key == nil { return } authItem := newAuthItem(key) if d != 0 { a.ops.Set(set.Expire(authItem, d)) } else { a.ops.Set(authItem) } logger.Debugf("Added to ops: %q (for %s)", authItem.Key(), d) } // IsOp checks if a public key is an op. func (a *Auth) IsOp(key ssh.PublicKey) bool { authkey := newAuthKey(key) return a.ops.In(authkey) } // LoadOps sets the public keys form loader to operators and saves the loader for later use func (a *Auth) LoadOps(loader KeyLoader) error { a.settingsMu.Lock() a.opLoader = loader a.settingsMu.Unlock() return a.ReloadOps() } // ReloadOps sets the public keys from a loader saved in the last call to operators func (a *Auth) ReloadOps() error { a.settingsMu.RLock() defer a.settingsMu.RUnlock() return addFromLoader(a.opLoader, a.Op) } // Allowlist will set a public key as a allowlisted user. func (a *Auth) Allowlist(key ssh.PublicKey, d time.Duration) { if key == nil { return } var err error authItem := newAuthItem(key) if d != 0 { err = a.allowlist.Set(set.Expire(authItem, d)) } else { err = a.allowlist.Set(authItem) } if err == nil { logger.Debugf("Added to allowlist: %q (for %s)", authItem.Key(), d) } else { logger.Errorf("Error adding %q to allowlist for %s: %s", authItem.Key(), d, err) } } // LoadAllowlist adds the public keys from the loader to the allowlist and saves the loader for later use func (a *Auth) LoadAllowlist(loader KeyLoader) error { a.settingsMu.Lock() a.allowlistLoader = loader a.settingsMu.Unlock() return a.ReloadAllowlist() } // LoadAllowlist adds the public keys from a loader saved in a previous call to the allowlist func (a *Auth) ReloadAllowlist() error { a.settingsMu.RLock() defer a.settingsMu.RUnlock() return addFromLoader(a.allowlistLoader, a.Allowlist) } func addFromLoader(loader KeyLoader, adder func(ssh.PublicKey, time.Duration)) error { if loader == nil { return nil } keys, err := loader() for _, key := range keys { adder(key, 0) } return err } // Ban will set a public key as banned. func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) { if key == nil { return } a.BanFingerprint(newAuthKey(key), d) } // BanFingerprint will set a public key fingerprint as banned. func (a *Auth) BanFingerprint(authkey string, d time.Duration) { // FIXME: This is a case insensitive key, which isn't great... authItem := set.StringItem(authkey) if d != 0 { a.banned.Set(set.Expire(authItem, d)) } else { a.banned.Set(authItem) } logger.Debugf("Added to banned: %q (for %s)", authItem.Key(), d) } // BanClient will set client version as banned. Useful for misbehaving bots. func (a *Auth) BanClient(client string, d time.Duration) { item := set.StringItem(client) if d != 0 { a.bannedClient.Set(set.Expire(item, d)) } else { a.bannedClient.Set(item) } logger.Debugf("Added to banned: %q (for %s)", item.Key(), d) } // Banned returns the list of banned keys. func (a *Auth) Banned() (ip []string, fingerprint []string, client []string) { a.banned.Each(func(key string, _ set.Item) error { fingerprint = append(fingerprint, key) return nil }) a.bannedAddr.Each(func(key string, _ set.Item) error { ip = append(ip, key) return nil }) a.bannedClient.Each(func(key string, _ set.Item) error { client = append(client, key) return nil }) return } // BanAddr will set an IP address as banned. func (a *Auth) BanAddr(addr net.Addr, d time.Duration) { authItem := set.StringItem(newAuthAddr(addr)) if d != 0 { a.bannedAddr.Set(set.Expire(authItem, d)) } else { a.bannedAddr.Set(authItem) } logger.Debugf("Added to bannedAddr: %q (for %s)", authItem.Key(), d) } // BanQuery takes space-separated key="value" pairs to ban, including ip, fingerprint, client. // Fields without an = will be treated as a duration, applied to the next field. // For example: 5s client=foo 10min ip=1.1.1.1 // Will ban client foo for 5 seconds, and ip 1.1.1.1 for 10min. func (a *Auth) BanQuery(q string) error { r := csv.NewReader(strings.NewReader(q)) r.Comma = ' ' fields, err := r.Read() if err != nil { return err } var d time.Duration if last := fields[len(fields)-1]; !strings.Contains(last, "=") { d, err = time.ParseDuration(last) if err != nil { return err } fields = fields[:len(fields)-1] } for _, field := range fields { parts := strings.SplitN(field, "=", 2) if len(parts) != 2 { return fmt.Errorf("invalid query: %q", q) } key, value := parts[0], parts[1] switch key { case "client": a.BanClient(value, d) case "fingerprint": // TODO: Add a validity check? a.BanFingerprint(value, d) case "ip": ip := net.ParseIP(value) if ip.String() == "" { return fmt.Errorf("invalid ip value: %q", ip) } a.BanAddr(&net.TCPAddr{IP: ip}, d) default: return fmt.Errorf("unknown query field: %q", field) } } return nil }