diff --git a/client.go b/client.go index f3d886b..da6223a 100644 --- a/client.go +++ b/client.go @@ -415,8 +415,12 @@ func (c *Client) handleShell(channel ssh.Channel) { c.SysMsg("Missing $FINGERPRINT from: /whitelist $FINGERPRINT") } else { fingerprint := parts[1] - c.Server.Whitelist(fingerprint) - c.SysMsg("Added %s to the whitelist", fingerprint) + err = c.Server.Whitelist(fingerprint) + if err != nil { + c.SysMsg("Error adding to whitelist: %s", err) + } else { + c.SysMsg("Added %s to the whitelist", fingerprint) + } } default: diff --git a/server.go b/server.go index bd4230e..ea86f09 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,10 @@ import ( "sync" "syscall" "time" + "net/http" + "io/ioutil" + "encoding/base64" + "errors" "golang.org/x/crypto/ssh" ) @@ -260,11 +264,65 @@ func (s *Server) Op(fingerprint string) { } // Whitelist adds the given fingerprint to the whitelist -func (s *Server) Whitelist(fingerprint string) { - logger.Infof("Adding whitelist: %s", fingerprint) - s.Lock() - s.whitelist[fingerprint] = struct{}{} - s.Unlock() +func (s *Server) Whitelist(fingerprint string) error { + if strings.HasPrefix(fingerprint, "github.com/") { + logger.Infof("Adding github account %s to whitelist", fingerprint) + + keys, err := getGithubKey(fingerprint) + if err != nil { + return err + } + if len(keys) == 0 { + return errors.New(fmt.Sprintf("No github user %s", fingerprint)) + } + for _, key := range keys { + fingerprint = Fingerprint(key) + logger.Infof("Adding whitelist: %s", fingerprint) + s.Lock() + s.whitelist[fingerprint] = struct{}{} + s.Unlock() + } + } else { + logger.Infof("Adding whitelist: %s", fingerprint) + s.Lock() + s.whitelist[fingerprint] = struct{}{} + s.Unlock() + } + return nil +} + +var r *regexp.Regexp = regexp.MustCompile(`ssh-rsa ([A-Za-z0-9\+=\/]+)\s*`) +func getGithubKey(url string) ([]ssh.PublicKey, error) { + resp, err := http.Get("http://" + url + ".keys") + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + body_str := string(body) + keys := r.FindAllStringSubmatch(body_str, -1) + pubs := make([]ssh.PublicKey, 0, 3) + for _, key := range keys { + if(len(key) < 2) { + continue + } + + body_decoded, err := base64.StdEncoding.DecodeString(key[1]) + if err != nil { + return nil, err + } + + pub, err := ssh.ParsePublicKey(body_decoded) + if err != nil { + return nil, err + } + + pubs = append(pubs, pub) + } + return pubs, nil } // Uptime returns the time since the server was started