diff --git a/cmd.go b/cmd.go index 29c5563..266eeb2 100644 --- a/cmd.go +++ b/cmd.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "fmt" "io/ioutil" "os" @@ -12,10 +13,11 @@ import ( ) type Options struct { - Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."` - Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"` - Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:22"` - Admin []string `long:"admin" description:"Fingerprint of pubkey to mark as admin."` + Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."` + Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"` + Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:22"` + Admin []string `long:"admin" description:"Fingerprint of pubkey to mark as admin."` + Whitelist string `long:"whitelist" description:"Optional file of pubkey fingerprints that are allowed to connect"` } var logLevels = []log.Level{ @@ -60,6 +62,24 @@ func main() { return } + for _, fingerprint := range options.Admin { + server.Op(fingerprint) + } + + if options.Whitelist != "" { + file, err := os.Open(options.Whitelist) + if err != nil { + logger.Errorf("Could not open whitelist file") + return + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + server.Whitelist(scanner.Text()) + } + } + // Construct interrupt handler sig := make(chan os.Signal, 1) signal.Notify(sig, os.Interrupt) @@ -70,10 +90,6 @@ func main() { return } - for _, fingerprint := range options.Admin { - server.Op(fingerprint) - } - <-sig // Wait for ^C signal logger.Warningf("Interrupt signal detected, shutting down.") server.Stop() diff --git a/server.go b/server.go index bf86515..b6a209b 100644 --- a/server.go +++ b/server.go @@ -32,6 +32,7 @@ type Server struct { count int history *History motd string + whitelist map[string]struct{} // fingerprint lookup admins map[string]struct{} // fingerprint lookup bannedPk map[string]*time.Time // fingerprint lookup bannedIp map[net.Addr]*time.Time @@ -45,15 +46,16 @@ func NewServer(privateKey []byte) (*Server, error) { } server := Server{ - done: make(chan struct{}), - clients: Clients{}, - count: 0, - history: NewHistory(HISTORY_LEN), - motd: "Message of the Day! Modify with /motd", - admins: map[string]struct{}{}, - bannedPk: map[string]*time.Time{}, - bannedIp: map[net.Addr]*time.Time{}, - started: time.Now(), + done: make(chan struct{}), + clients: Clients{}, + count: 0, + history: NewHistory(HISTORY_LEN), + motd: "Message of the Day! Modify with /motd", + whitelist: map[string]struct{}{}, + admins: map[string]struct{}{}, + bannedPk: map[string]*time.Time{}, + bannedIp: map[net.Addr]*time.Time{}, + started: time.Now(), } config := ssh.ServerConfig{ @@ -64,6 +66,9 @@ func NewServer(privateKey []byte) (*Server, error) { if server.IsBanned(fingerprint) { return nil, fmt.Errorf("Banned.") } + if !server.IsWhitelisted(fingerprint) { + return nil, fmt.Errorf("Not Whitelisted.") + } perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}} return perm, nil }, @@ -230,6 +235,13 @@ func (s *Server) Op(fingerprint string) { s.lock.Unlock() } +func (s *Server) Whitelist(fingerprint string) { + logger.Infof("Adding whitelist: %s", fingerprint) + s.lock.Lock() + s.whitelist[fingerprint] = struct{}{} + s.lock.Unlock() +} + func (s *Server) Uptime() string { return time.Now().Sub(s.started).String() } @@ -239,6 +251,17 @@ func (s *Server) IsOp(client *Client) bool { return r } +func (s *Server) IsWhitelisted(fingerprint string) bool { + /* if no whitelist, anyone is welcome */ + if len(s.whitelist) == 0 { + return true + } + + /* otherwise, check for whitelist presence */ + _, r := s.whitelist[fingerprint] + return r +} + func (s *Server) IsBanned(fingerprint string) bool { ban, hasBan := s.bannedPk[fingerprint] if !hasBan {