diff --git a/chat/set.go b/chat/set.go index c5aef93..7ee9281 100644 --- a/chat/set.go +++ b/chat/set.go @@ -89,7 +89,7 @@ func (s *Set) Remove(item Item) error { defer s.Unlock() id := item.Id() _, found := s.lookup[id] - if found { + if !found { return ErrItemMissing } delete(s.lookup, id) diff --git a/client.go b/client.go deleted file mode 100644 index 1d19e76..0000000 --- a/client.go +++ /dev/null @@ -1,511 +0,0 @@ -package main - -import ( - "fmt" - "strings" - "sync" - "time" - - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/terminal" -) - -const ( - // MsgBuffer is the length of the message buffer - MsgBuffer int = 50 - - // MaxMsgLength is the maximum length of a message - MaxMsgLength int = 1024 - - // HelpText is the text returned by /help - HelpText string = `Available commands: - /about - About this chat. - /exit - Exit the chat. - /help - Show this help text. - /list - List the users that are currently connected. - /beep - Enable BEL notifications on mention. - /me $ACTION - Show yourself doing an action. - /nick $NAME - Rename yourself to a new name. - /whois $NAME - Display information about another connected user. - /msg $NAME $MESSAGE - Sends a private message to a user. - /motd - Prints the Message of the Day. - /theme [color|mono] - Set client theme.` - - // OpHelpText is the additional text returned by /help if the client is an Op - OpHelpText string = `Available operator commands: - /ban $NAME - Banish a user from the chat - /kick $NAME - Kick em' out. - /op $NAME - Promote a user to server operator. - /silence $NAME - Revoke a user's ability to speak. - /shutdown $MESSAGE - Broadcast message and shutdown server. - /motd $MESSAGE - Set message shown whenever somebody joins. - /whitelist $FINGERPRINT - Add fingerprint to whitelist, prevent anyone else from joining. - /whitelist github.com/$USER - Add github user's pubkeys to whitelist.` - - // AboutText is the text returned by /about - AboutText string = `ssh-chat is made by @shazow. - - It is a custom ssh server built in Go to serve a chat experience - instead of a shell. - - Source: https://github.com/shazow/ssh-chat - - For more, visit shazow.net or follow at twitter.com/shazow` - - // RequiredWait is the time a client is required to wait between messages - RequiredWait time.Duration = time.Second / 2 -) - -// Client holds all the fields used by the client -type Client struct { - Server *Server - Conn *ssh.ServerConn - Msg chan string - Name string - Color string - Op bool - ready chan struct{} - term *terminal.Terminal - termWidth int - termHeight int - silencedUntil time.Time - lastTX time.Time - beepMe bool - colorMe bool - closed bool - sync.RWMutex -} - -// NewClient constructs a new client -func NewClient(server *Server, conn *ssh.ServerConn) *Client { - return &Client{ - Server: server, - Conn: conn, - Name: conn.User(), - Color: RandomColor256(), - Msg: make(chan string, MsgBuffer), - ready: make(chan struct{}, 1), - lastTX: time.Now(), - colorMe: true, - } -} - -// ColoredName returns the client name in its color -func (c *Client) ColoredName() string { - return ColorString(c.Color, c.Name) -} - -// SysMsg sends a message in continuous format over the message channel -func (c *Client) SysMsg(msg string, args ...interface{}) { - c.Send(ContinuousFormat(systemMessageFormat, "-> "+fmt.Sprintf(msg, args...))) -} - -// Write writes the given message -func (c *Client) Write(msg string) { - if !c.colorMe { - msg = DeColorString(msg) - } - c.term.Write([]byte(msg + "\r\n")) -} - -// WriteLines writes multiple messages -func (c *Client) WriteLines(msg []string) { - for _, line := range msg { - c.Write(line) - } -} - -// Send sends the given message -func (c *Client) Send(msg string) { - if len(msg) > MaxMsgLength || c.closed { - return - } - select { - case c.Msg <- msg: - default: - logger.Errorf("Msg buffer full, dropping: %s (%s)", c.Name, c.Conn.RemoteAddr()) - c.Conn.Close() - } -} - -// SendLines sends multiple messages -func (c *Client) SendLines(msg []string) { - for _, line := range msg { - c.Send(line) - } -} - -// IsSilenced checks if the client is silenced -func (c *Client) IsSilenced() bool { - return c.silencedUntil.After(time.Now()) -} - -// Silence silences a client for the given duration -func (c *Client) Silence(d time.Duration) { - c.silencedUntil = time.Now().Add(d) -} - -// Resize resizes the client to the given width and height -func (c *Client) Resize(width, height int) error { - width = 1000000 // TODO: Remove this dirty workaround for text overflow once ssh/terminal is fixed - err := c.term.SetSize(width, height) - if err != nil { - logger.Errorf("Resize failed: %dx%d", width, height) - return err - } - c.termWidth, c.termHeight = width, height - return nil -} - -// Rename renames the client to the given name -func (c *Client) Rename(name string) { - c.Name = name - var prompt string - - if c.colorMe { - prompt = c.ColoredName() - } else { - prompt = c.Name - } - - c.term.SetPrompt(fmt.Sprintf("[%s] ", prompt)) -} - -// Fingerprint returns the fingerprint -func (c *Client) Fingerprint() string { - if c.Conn.Permissions == nil { - return "" - } - return c.Conn.Permissions.Extensions["fingerprint"] -} - -// Emote formats and sends an emote -func (c *Client) Emote(message string) { - formatted := fmt.Sprintf("** %s%s", c.ColoredName(), message) - if c.IsSilenced() || len(message) > 1000 { - c.SysMsg("Message rejected") - } - c.Server.Broadcast(formatted, nil) -} - -func (c *Client) handleShell(channel ssh.Channel) { - defer channel.Close() - defer c.Conn.Close() - - // FIXME: This shouldn't live here, need to restructure the call chaining. - c.Server.Add(c) - go func() { - // Block until done, then remove. - c.Conn.Wait() - c.closed = true - c.Server.Remove(c) - close(c.Msg) - }() - - go func() { - for msg := range c.Msg { - c.Write(msg) - } - }() - - for { - line, err := c.term.ReadLine() - if err != nil { - break - } - - parts := strings.SplitN(line, " ", 3) - isCmd := strings.HasPrefix(parts[0], "/") - - if isCmd { - // TODO: Factor this out. - switch parts[0] { - case "/test-colors": // Shh, this command is a secret! - c.Write(ColorString("32", "Lorem ipsum dolor sit amet,")) - c.Write("consectetur " + ColorString("31;1", "adipiscing") + " elit.") - case "/exit": - channel.Close() - case "/help": - c.SysMsg(strings.Replace(HelpText, "\n", "\r\n", -1)) - if c.Server.IsOp(c) { - c.SysMsg(strings.Replace(OpHelpText, "\n", "\r\n", -1)) - } - case "/about": - c.SysMsg(strings.Replace(AboutText, "\n", "\r\n", -1)) - case "/uptime": - c.SysMsg(c.Server.Uptime()) - case "/beep": - c.beepMe = !c.beepMe - if c.beepMe { - c.SysMsg("I'll beep you good.") - } else { - c.SysMsg("No more beeps. :(") - } - case "/me": - me := strings.TrimLeft(line, "/me") - if me == "" { - me = " is at a loss for words." - } - c.Emote(me) - case "/slap": - slappee := "themself" - if len(parts) > 1 { - slappee = parts[1] - if len(parts[1]) > 100 { - slappee = "some long-named jerk" - } - } - c.Emote(fmt.Sprintf(" slaps %s around a bit with a large trout.", slappee)) - case "/nick": - if len(parts) == 2 { - c.Server.Rename(c, parts[1]) - } else { - c.SysMsg("Missing $NAME from: /nick $NAME") - } - case "/whois": - if len(parts) >= 2 { - client := c.Server.Who(parts[1]) - if client != nil { - version := reStripText.ReplaceAllString(string(client.Conn.ClientVersion()), "") - if len(version) > 100 { - version = "Evil Jerk with a superlong string" - } - c.SysMsg("%s is %s via %s", client.ColoredName(), client.Fingerprint(), version) - } else { - c.SysMsg("No such name: %s", parts[1]) - } - } else { - c.SysMsg("Missing $NAME from: /whois $NAME") - } - case "/names", "/list": - names := "" - nameList := c.Server.List(nil) - for _, name := range nameList { - names += c.Server.Who(name).ColoredName() + systemMessageFormat + ", " - } - if len(names) > 2 { - names = names[:len(names)-2] - } - c.SysMsg("%d connected: %s", len(nameList), names) - case "/ban": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $NAME from: /ban $NAME") - } else { - client := c.Server.Who(parts[1]) - if client == nil { - c.SysMsg("No such name: %s", parts[1]) - } else { - fingerprint := client.Fingerprint() - client.SysMsg("Banned by %s.", c.ColoredName()) - c.Server.Ban(fingerprint, nil) - client.Conn.Close() - c.Server.Broadcast(fmt.Sprintf("* %s was banned by %s", parts[1], c.ColoredName()), nil) - } - } - case "/op": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $NAME from: /op $NAME") - } else { - client := c.Server.Who(parts[1]) - if client == nil { - c.SysMsg("No such name: %s", parts[1]) - } else { - fingerprint := client.Fingerprint() - client.SysMsg("Made op by %s.", c.ColoredName()) - c.Server.Op(fingerprint) - } - } - case "/kick": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $NAME from: /kick $NAME") - } else { - client := c.Server.Who(parts[1]) - if client == nil { - c.SysMsg("No such name: %s", parts[1]) - } else { - client.SysMsg("Kicked by %s.", c.ColoredName()) - client.Conn.Close() - c.Server.Broadcast(fmt.Sprintf("* %s was kicked by %s", parts[1], c.ColoredName()), nil) - } - } - case "/silence": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) < 2 { - c.SysMsg("Missing $NAME from: /silence $NAME") - } else { - duration := time.Duration(5) * time.Minute - if len(parts) >= 3 { - parsedDuration, err := time.ParseDuration(parts[2]) - if err == nil { - duration = parsedDuration - } - } - client := c.Server.Who(parts[1]) - if client == nil { - c.SysMsg("No such name: %s", parts[1]) - } else { - client.Silence(duration) - client.SysMsg("Silenced for %s by %s.", duration, c.ColoredName()) - } - } - case "/shutdown": - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else { - var split = strings.SplitN(line, " ", 2) - var msg string - if len(split) > 1 { - msg = split[1] - } else { - msg = "" - } - // Shutdown after 5 seconds - go func() { - c.Server.Broadcast(ColorString("31", msg), nil) - time.Sleep(time.Second * 5) - c.Server.Stop() - }() - } - case "/msg": /* Send a PM */ - /* Make sure we have a recipient and a message */ - if len(parts) < 2 { - c.SysMsg("Missing $NAME from: /msg $NAME $MESSAGE") - break - } else if len(parts) < 3 { - c.SysMsg("Missing $MESSAGE from: /msg $NAME $MESSAGE") - break - } - /* Ask the server to send the message */ - if err := c.Server.Privmsg(parts[1], parts[2], c); nil != err { - c.SysMsg("Unable to send message to %v: %v", parts[1], err) - } - case "/motd": /* print motd */ - if !c.Server.IsOp(c) { - c.Server.MotdUnicast(c) - } else if len(parts) < 2 { - c.Server.MotdUnicast(c) - } else { - var newmotd string - if len(parts) == 2 { - newmotd = parts[1] - } else { - newmotd = parts[1] + " " + parts[2] - } - c.Server.SetMotd(newmotd) - c.Server.MotdBroadcast(c) - } - case "/theme": - if len(parts) < 2 { - c.SysMsg("Missing $THEME from: /theme $THEME") - c.SysMsg("Choose either color or mono") - } else { - // Sets colorMe attribute of client - if parts[1] == "mono" { - c.colorMe = false - } else if parts[1] == "color" { - c.colorMe = true - } - // Rename to reset prompt - c.Rename(c.Name) - } - - case "/whitelist": /* whitelist a fingerprint */ - if !c.Server.IsOp(c) { - c.SysMsg("You're not an admin.") - } else if len(parts) != 2 { - c.SysMsg("Missing $FINGERPRINT from: /whitelist $FINGERPRINT") - } else { - fingerprint := parts[1] - go func() { - err = c.Server.Whitelist(fingerprint) - if err != nil { - c.SysMsg("Error adding to whitelist: %s", err) - } else { - c.SysMsg("Added %s to the whitelist", fingerprint) - } - }() - } - case "/version": - c.SysMsg("Version " + buildCommit) - - default: - c.SysMsg("Invalid command: %s", line) - } - continue - } - - msg := fmt.Sprintf("%s: %s", c.ColoredName(), line) - /* Rate limit */ - if time.Now().Sub(c.lastTX) < RequiredWait { - c.SysMsg("Rate limiting in effect.") - continue - } - if c.IsSilenced() || len(msg) > 1000 || len(line) < 1 { - c.SysMsg("Message rejected.") - continue - } - c.Server.Broadcast(msg, c) - c.lastTX = time.Now() - } - -} - -func (c *Client) handleChannels(channels <-chan ssh.NewChannel) { - prompt := fmt.Sprintf("[%s] ", c.ColoredName()) - - hasShell := false - - for ch := range channels { - if t := ch.ChannelType(); t != "session" { - ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - continue - } - - channel, requests, err := ch.Accept() - if err != nil { - logger.Errorf("Could not accept channel: %v", err) - continue - } - defer channel.Close() - - c.term = terminal.NewTerminal(channel, prompt) - c.term.AutoCompleteCallback = c.Server.AutoCompleteFunction - - for req := range requests { - var width, height int - var ok bool - - switch req.Type { - case "shell": - if c.term != nil && !hasShell { - go c.handleShell(channel) - ok = true - hasShell = true - } - case "pty-req": - width, height, ok = parsePtyRequest(req.Payload) - if ok { - err := c.Resize(width, height) - ok = err == nil - } - case "window-change": - width, height, ok = parseWinchRequest(req.Payload) - if ok { - err := c.Resize(width, height) - ok = err == nil - } - } - - if req.WantReply { - req.Reply(ok, nil) - } - } - } -} diff --git a/cmd.go b/cmd.go index 9d242f5..212e3f2 100644 --- a/cmd.go +++ b/cmd.go @@ -64,6 +64,12 @@ func main() { logLevel := logLevels[numVerbose] logger = golog.New(os.Stderr, logLevel) + if logLevel == log.Debug { + // Enable logging from submodules + chat.SetLogger(os.Stderr) + sshd.SetLogger(os.Stderr) + } + privateKeyPath := options.Identity if strings.HasPrefix(privateKeyPath, "~") { user, err := user.Current() @@ -95,41 +101,7 @@ func main() { } defer s.Close() - terminals := s.ServeTerminal() - channel := chat.NewChannel() - - // TODO: Move this elsewhere - go func() { - for term := range terminals { - go func() { - defer term.Close() - name := term.Conn.User() - term.SetPrompt(fmt.Sprintf("[%s] ", name)) - // TODO: term.AutoCompleteCallback = ... - user := chat.NewUserScreen(name, term) - defer user.Close() - channel.Join(user) - - go func() { - // FIXME: This isn't working. - user.Wait() - channel.Leave(user) - }() - - for { - // TODO: Handle commands etc? - line, err := term.ReadLine() - if err != nil { - break - } - m := chat.NewMessage(line).From(user) - channel.Send(*m) - } - - // TODO: Handle disconnect sooner (currently closes channel before removing) - }() - } - }() + go Serve(s) /* TODO: for _, fingerprint := range options.Admin { diff --git a/colors.go b/colors.go deleted file mode 100644 index 6bfc5ca..0000000 --- a/colors.go +++ /dev/null @@ -1,82 +0,0 @@ -package main - -import ( - "fmt" - "math/rand" - "regexp" - "strings" - "time" -) - -const ( - // Reset resets the color - Reset = "\033[0m" - - // Bold makes the following text bold - Bold = "\033[1m" - - // Dim dims the following text - Dim = "\033[2m" - - // Italic makes the following text italic - Italic = "\033[3m" - - // Underline underlines the following text - Underline = "\033[4m" - - // Blink blinks the following text - Blink = "\033[5m" - - // Invert inverts the following text - Invert = "\033[7m" -) - -var colors = []string{"31", "32", "33", "34", "35", "36", "37", "91", "92", "93", "94", "95", "96", "97"} - -// deColor is used for removing ANSI Escapes -var deColor = regexp.MustCompile("\033\\[[\\d;]+m") - -// DeColorString removes all color from the given string -func DeColorString(s string) string { - s = deColor.ReplaceAllString(s, "") - return s -} - -func randomReadableColor() int { - for { - i := rand.Intn(256) - if (16 <= i && i <= 18) || (232 <= i && i <= 237) { - // Remove the ones near black, this is kinda sadpanda. - continue - } - return i - } -} - -// RandomColor256 returns a random (of 256) color -func RandomColor256() string { - return fmt.Sprintf("38;05;%d", randomReadableColor()) -} - -// RandomColor returns a random color -func RandomColor() string { - return colors[rand.Intn(len(colors))] -} - -// ColorString returns a message in the given color -func ColorString(color string, msg string) string { - return Bold + "\033[" + color + "m" + msg + Reset -} - -// RandomColorInit initializes the random seed -func RandomColorInit() { - rand.Seed(time.Now().UTC().UnixNano()) -} - -// ContinuousFormat is a horrible hack to "continue" the previous string color -// and format after a RESET has been encountered. -// -// This is not HTML where you can just do a to resume your previous formatting! -func ContinuousFormat(format string, str string) string { - return systemMessageFormat + strings.Replace(str, Reset, format, -1) + Reset -} diff --git a/history.go b/history.go deleted file mode 100644 index 74ef513..0000000 --- a/history.go +++ /dev/null @@ -1,59 +0,0 @@ -// TODO: Split this out into its own module, it's kinda neat. -package main - -import "sync" - -// History contains the history entries -type History struct { - entries []string - head int - size int - lock sync.Mutex -} - -// NewHistory constructs a new history of the given size -func NewHistory(size int) *History { - return &History{ - entries: make([]string, size), - } -} - -// Add adds the given entry to the entries in the history -func (h *History) Add(entry string) { - h.lock.Lock() - defer h.lock.Unlock() - - max := cap(h.entries) - h.head = (h.head + 1) % max - h.entries[h.head] = entry - if h.size < max { - h.size++ - } -} - -// Len returns the number of entries in the history -func (h *History) Len() int { - return h.size -} - -// Get the entry with the given number -func (h *History) Get(num int) []string { - h.lock.Lock() - defer h.lock.Unlock() - - max := cap(h.entries) - if num > h.size { - num = h.size - } - - r := make([]string, num) - for i := 0; i < num; i++ { - idx := (h.head - i) % max - if idx < 0 { - idx += max - } - r[num-i-1] = h.entries[idx] - } - - return r -} diff --git a/history_test.go b/history_test.go deleted file mode 100644 index 0eab1c7..0000000 --- a/history_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package main - -import ( - "reflect" - "testing" -) - -func TestHistory(t *testing.T) { - var r, expected []string - var size int - - h := NewHistory(5) - - r = h.Get(10) - expected = []string{} - if !reflect.DeepEqual(r, expected) { - t.Errorf("Got: %v, Expected: %v", r, expected) - } - - h.Add("1") - - if size = h.Len(); size != 1 { - t.Errorf("Wrong size: %v", size) - } - - r = h.Get(1) - expected = []string{"1"} - if !reflect.DeepEqual(r, expected) { - t.Errorf("Got: %v, Expected: %v", r, expected) - } - - h.Add("2") - h.Add("3") - h.Add("4") - h.Add("5") - h.Add("6") - - if size = h.Len(); size != 5 { - t.Errorf("Wrong size: %v", size) - } - - r = h.Get(2) - expected = []string{"5", "6"} - if !reflect.DeepEqual(r, expected) { - t.Errorf("Got: %v, Expected: %v", r, expected) - } - - r = h.Get(10) - expected = []string{"2", "3", "4", "5", "6"} - if !reflect.DeepEqual(r, expected) { - t.Errorf("Got: %v, Expected: %v", r, expected) - } -} diff --git a/pty.go b/pty.go deleted file mode 100644 index e635fba..0000000 --- a/pty.go +++ /dev/null @@ -1,69 +0,0 @@ -// Borrowed from go.crypto circa 2011 -package main - -import "encoding/binary" - -// parsePtyRequest parses the payload of the pty-req message and extracts the -// dimensions of the terminal. See RFC 4254, section 6.2. -func parsePtyRequest(s []byte) (width, height int, ok bool) { - _, s, ok = parseString(s) - if !ok { - return - } - width32, s, ok := parseUint32(s) - if !ok { - return - } - height32, _, ok := parseUint32(s) - width = int(width32) - height = int(height32) - if width < 1 { - ok = false - } - if height < 1 { - ok = false - } - return -} - -func parseWinchRequest(s []byte) (width, height int, ok bool) { - width32, s, ok := parseUint32(s) - if !ok { - return - } - height32, s, ok := parseUint32(s) - if !ok { - return - } - - width = int(width32) - height = int(height32) - if width < 1 { - ok = false - } - if height < 1 { - ok = false - } - return -} - -func parseString(in []byte) (out string, rest []byte, ok bool) { - if len(in) < 4 { - return - } - length := binary.BigEndian.Uint32(in) - if uint32(len(in)) < 4+length { - return - } - out = string(in[4 : 4+length]) - rest = in[4+length:] - ok = true - return -} - -func parseUint32(in []byte) (uint32, []byte, bool) { - if len(in) < 4 { - return 0, nil, false - } - return binary.BigEndian.Uint32(in), in[4:], true -} diff --git a/serve.go b/serve.go new file mode 100644 index 0000000..0e97ca8 --- /dev/null +++ b/serve.go @@ -0,0 +1,53 @@ +package main + +import ( + "fmt" + + "github.com/shazow/ssh-chat/chat" + "github.com/shazow/ssh-chat/sshd" +) + +func HandleTerminal(term *sshd.Terminal, channel *chat.Channel) { + defer term.Close() + name := term.Conn.User() + term.SetPrompt(fmt.Sprintf("[%s] ", name)) + // TODO: term.AutoCompleteCallback = ... + + user := chat.NewUserScreen(name, term) + defer user.Close() + + err := channel.Join(user) + if err != nil { + logger.Errorf("Failed to join: %s", err) + return + } + defer func() { + err := channel.Leave(user) + if err != nil { + logger.Errorf("Failed to leave: %s", err) + } + }() + + for { + // TODO: Handle commands etc? + line, err := term.ReadLine() + if err != nil { + // TODO: Catch EOF specifically? + logger.Errorf("Terminal reading error: %s", err) + return + } + m := chat.NewMessage(line).From(user) + channel.Send(*m) + } +} + +// Serve a chat service onto the sshd server. +func Serve(listener *sshd.SSHListener) { + terminals := listener.ServeTerminal() + channel := chat.NewChannel() + + for term := range terminals { + go HandleTerminal(term, channel) + } + +} diff --git a/server.go b/server.go deleted file mode 100644 index 181d2f7..0000000 --- a/server.go +++ /dev/null @@ -1,515 +0,0 @@ -package main - -import ( - "bufio" - "crypto/md5" - "encoding/base64" - "fmt" - "net" - "net/http" - "regexp" - "strings" - "sync" - "syscall" - "time" - - "golang.org/x/crypto/ssh" -) - -const ( - maxNameLength = 32 - historyLength = 20 - systemMessageFormat = "\033[1;90m" - privateMessageFormat = "\033[1m" - highlightFormat = Bold + "\033[48;5;11m\033[38;5;16m" - beep = "\007" -) - -var ( - reStripText = regexp.MustCompile("[^0-9A-Za-z_.-]") -) - -// Clients is a map of clients -type Clients map[string]*Client - -// Server holds all the fields used by a server -type Server struct { - sshConfig *ssh.ServerConfig - done chan struct{} - clients Clients - count int - history *History - motd string - whitelist map[string]struct{} // fingerprint lookup - admins map[string]struct{} // fingerprint lookup - bannedPK map[string]*time.Time // fingerprint lookup - started time.Time - sync.RWMutex -} - -// NewServer constructs a new server -func NewServer(privateKey []byte) (*Server, error) { - signer, err := ssh.ParsePrivateKey(privateKey) - if err != nil { - return nil, err - } - - server := Server{ - done: make(chan struct{}), - clients: Clients{}, - count: 0, - history: NewHistory(historyLength), - motd: "", - whitelist: map[string]struct{}{}, - admins: map[string]struct{}{}, - bannedPK: map[string]*time.Time{}, - started: time.Now(), - } - - config := ssh.ServerConfig{ - NoClientAuth: false, - // Auth-related things should be constant-time to avoid timing attacks. - PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - fingerprint := Fingerprint(key) - if server.IsBanned(fingerprint) { - return nil, fmt.Errorf("Banned.") - } - if !server.IsWhitelisted(fingerprint) { - return nil, fmt.Errorf("Not Whitelisted.") - } - perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}} - return perm, nil - }, - KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { - if server.IsBanned("") { - return nil, fmt.Errorf("Interactive login disabled.") - } - if !server.IsWhitelisted("") { - return nil, fmt.Errorf("Not Whitelisted.") - } - return nil, nil - }, - } - config.AddHostKey(signer) - - server.sshConfig = &config - - return &server, nil -} - -// Len returns the number of clients -func (s *Server) Len() int { - return len(s.clients) -} - -// SysMsg broadcasts the given message to everyone -func (s *Server) SysMsg(msg string, args ...interface{}) { - s.Broadcast(ContinuousFormat(systemMessageFormat, " * "+fmt.Sprintf(msg, args...)), nil) -} - -// Broadcast broadcasts the given message to everyone except for the given client -func (s *Server) Broadcast(msg string, except *Client) { - logger.Debugf("Broadcast to %d: %s", s.Len(), msg) - s.history.Add(msg) - - s.RLock() - defer s.RUnlock() - - for _, client := range s.clients { - if except != nil && client == except { - continue - } - - if strings.Contains(msg, client.Name) { - // Turn message red if client's name is mentioned, and send BEL if they have enabled beeping - personalMsg := strings.Replace(msg, client.Name, highlightFormat+client.Name+Reset, -1) - if client.beepMe { - personalMsg += beep - } - client.Send(personalMsg) - } else { - client.Send(msg) - } - } -} - -// Privmsg sends a message to a particular nick, if it exists -func (s *Server) Privmsg(nick, message string, sender *Client) error { - // Get the recipient - target, ok := s.clients[strings.ToLower(nick)] - if !ok { - return fmt.Errorf("no client with that nick") - } - // Send the message - target.Msg <- fmt.Sprintf(beep+"[PM from %v] %s%v%s", sender.ColoredName(), privateMessageFormat, message, Reset) - logger.Debugf("PM from %v to %v: %v", sender.Name, nick, message) - return nil -} - -// SetMotd sets the Message of the Day (MOTD) -func (s *Server) SetMotd(motd string) { - s.motd = motd -} - -// MotdUnicast sends the MOTD as a SysMsg -func (s *Server) MotdUnicast(client *Client) { - if s.motd == "" { - return - } - client.SysMsg(s.motd) -} - -// MotdBroadcast broadcasts the MOTD -func (s *Server) MotdBroadcast(client *Client) { - if s.motd == "" { - return - } - s.Broadcast(ContinuousFormat(systemMessageFormat, fmt.Sprintf(" * New MOTD set by %s.", client.ColoredName())), client) - s.Broadcast(s.motd, client) -} - -// Add adds the client to the list of clients -func (s *Server) Add(client *Client) { - go func() { - s.MotdUnicast(client) - client.SendLines(s.history.Get(10)) - }() - - s.Lock() - s.count++ - - newName, err := s.proposeName(client.Name) - if err != nil { - client.SysMsg("Your name '%s' is not available, renamed to '%s'. Use /nick to change it.", client.Name, ColorString(client.Color, newName)) - } - - client.Rename(newName) - s.clients[strings.ToLower(client.Name)] = client - num := len(s.clients) - s.Unlock() - - s.Broadcast(ContinuousFormat(systemMessageFormat, fmt.Sprintf(" * %s joined. (Total connected: %d)", client.Name, num)), client) -} - -// Remove removes the given client from the list of clients -func (s *Server) Remove(client *Client) { - s.Lock() - delete(s.clients, strings.ToLower(client.Name)) - s.Unlock() - - s.SysMsg("%s left.", client.Name) -} - -func (s *Server) proposeName(name string) (string, error) { - // Assumes caller holds lock. - var err error - name = reStripText.ReplaceAllString(name, "") - - if len(name) > maxNameLength { - name = name[:maxNameLength] - } else if len(name) == 0 { - name = fmt.Sprintf("Guest%d", s.count) - } - - _, collision := s.clients[strings.ToLower(name)] - if collision { - err = fmt.Errorf("%s is not available", name) - name = fmt.Sprintf("Guest%d", s.count) - } - - return name, err -} - -// Rename renames the given client (user) -func (s *Server) Rename(client *Client, newName string) { - var oldName string - if strings.ToLower(newName) == strings.ToLower(client.Name) { - oldName = client.Name - client.Rename(newName) - } else { - s.Lock() - newName, err := s.proposeName(newName) - if err != nil { - client.SysMsg("%s", err) - s.Unlock() - return - } - - // TODO: Use a channel/goroutine for adding clients, rather than locks? - delete(s.clients, strings.ToLower(client.Name)) - oldName = client.Name - client.Rename(newName) - s.clients[strings.ToLower(client.Name)] = client - s.Unlock() - } - s.SysMsg("%s is now known as %s.", ColorString(client.Color, oldName), ColorString(client.Color, newName)) -} - -// List lists the clients with the given prefix -func (s *Server) List(prefix *string) []string { - r := []string{} - - s.RLock() - defer s.RUnlock() - - for name, client := range s.clients { - if prefix != nil && !strings.HasPrefix(name, strings.ToLower(*prefix)) { - continue - } - r = append(r, client.Name) - } - - return r -} - -// Who returns the client with a given name -func (s *Server) Who(name string) *Client { - return s.clients[strings.ToLower(name)] -} - -// Op adds the given fingerprint to the list of admins -func (s *Server) Op(fingerprint string) { - logger.Infof("Adding admin: %s", fingerprint) - s.Lock() - s.admins[fingerprint] = struct{}{} - s.Unlock() -} - -// Whitelist adds the given fingerprint to the whitelist -func (s *Server) Whitelist(fingerprint string) error { - if fingerprint == "" { - return fmt.Errorf("Invalid fingerprint.") - } - if strings.HasPrefix(fingerprint, "github.com/") { - return s.whitelistIdentityURL(fingerprint) - } - - return s.whitelistFingerprint(fingerprint) -} - -func (s *Server) whitelistIdentityURL(user string) error { - logger.Infof("Adding github account %s to whitelist", user) - - user = strings.Replace(user, "github.com/", "", -1) - keys, err := getGithubPubKeys(user) - if err != nil { - return err - } - if len(keys) == 0 { - return fmt.Errorf("No keys for github user %s", user) - } - for _, key := range keys { - fingerprint := Fingerprint(key) - s.whitelistFingerprint(fingerprint) - } - return nil -} - -func (s *Server) whitelistFingerprint(fingerprint string) error { - logger.Infof("Adding whitelist: %s", fingerprint) - s.Lock() - s.whitelist[fingerprint] = struct{}{} - s.Unlock() - return nil -} - -// Client for getting github pub keys -var client = http.Client{ - Timeout: time.Duration(10 * time.Second), -} - -// Returns an array of public keys for the given github user URL -func getGithubPubKeys(user string) ([]ssh.PublicKey, error) { - resp, err := client.Get("http://github.com/" + user + ".keys") - if err != nil { - return nil, err - } - defer resp.Body.Close() - - pubs := []ssh.PublicKey{} - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - text := scanner.Text() - if text == "Not Found" { - continue - } - - splitKey := strings.SplitN(text, " ", -1) - - // In case of malformated key - if len(splitKey) < 2 { - continue - } - - bodyDecoded, err := base64.StdEncoding.DecodeString(splitKey[1]) - if err != nil { - return nil, err - } - - pub, err := ssh.ParsePublicKey(bodyDecoded) - if err != nil { - return nil, err - } - - pubs = append(pubs, pub) - } - return pubs, nil -} - -// Uptime returns the time since the server was started -func (s *Server) Uptime() string { - return time.Now().Sub(s.started).String() -} - -// IsOp checks if the given client is Op -func (s *Server) IsOp(client *Client) bool { - _, r := s.admins[client.Fingerprint()] - return r -} - -// IsWhitelisted checks if the given fingerprint is whitelisted -func (s *Server) IsWhitelisted(fingerprint string) bool { - /* if no whitelist, anyone is welcome */ - if len(s.whitelist) == 0 { - return true - } - - /* otherwise, check for whitelist presence */ - _, r := s.whitelist[fingerprint] - return r -} - -// IsBanned checks if the given fingerprint is banned -func (s *Server) IsBanned(fingerprint string) bool { - ban, hasBan := s.bannedPK[fingerprint] - if !hasBan { - return false - } - if ban == nil { - return true - } - if ban.Before(time.Now()) { - s.Unban(fingerprint) - return false - } - return true -} - -// Ban bans a fingerprint for the given duration -func (s *Server) Ban(fingerprint string, duration *time.Duration) { - var until *time.Time - s.Lock() - if duration != nil { - when := time.Now().Add(*duration) - until = &when - } - s.bannedPK[fingerprint] = until - s.Unlock() -} - -// Unban unbans a banned fingerprint -func (s *Server) Unban(fingerprint string) { - s.Lock() - delete(s.bannedPK, fingerprint) - s.Unlock() -} - -// Start starts the server -func (s *Server) Start(laddr string) error { - // Once a ServerConfig has been configured, connections can be - // accepted. - socket, err := net.Listen("tcp", laddr) - if err != nil { - return err - } - - logger.Infof("Listening on %s", laddr) - - go func() { - defer socket.Close() - for { - conn, err := socket.Accept() - - if err != nil { - logger.Errorf("Failed to accept connection: %v", err) - if err == syscall.EINVAL { - // TODO: Handle shutdown more gracefully? - return - } - } - - // Goroutineify to resume accepting sockets early. - go func() { - // From a standard TCP connection to an encrypted SSH connection - sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig) - if err != nil { - logger.Errorf("Failed to handshake: %v", err) - return - } - - version := reStripText.ReplaceAllString(string(sshConn.ClientVersion()), "") - if len(version) > 100 { - version = "Evil Jerk with a superlong string" - } - logger.Infof("Connection #%d from: %s, %s, %s", s.count+1, sshConn.RemoteAddr(), sshConn.User(), version) - - go ssh.DiscardRequests(requests) - - client := NewClient(s, sshConn) - go client.handleChannels(channels) - }() - } - }() - - go func() { - <-s.done - socket.Close() - }() - - return nil -} - -// AutoCompleteFunction handles auto completion of nicks -func (s *Server) AutoCompleteFunction(line string, pos int, key rune) (newLine string, newPos int, ok bool) { - if key == 9 { - shortLine := strings.Split(line[:pos], " ") - partialNick := shortLine[len(shortLine)-1] - - nicks := s.List(&partialNick) - if len(nicks) > 0 { - nick := nicks[len(nicks)-1] - posPartialNick := pos - len(partialNick) - if len(shortLine) < 2 { - nick += ": " - } else { - nick += " " - } - newLine = strings.Replace(line[posPartialNick:], - partialNick, nick, 1) - newLine = line[:posPartialNick] + newLine - newPos = pos + (len(nick) - len(partialNick)) - ok = true - } - } else { - ok = false - } - return -} - -// Stop stops the server -func (s *Server) Stop() { - s.Lock() - for _, client := range s.clients { - client.Conn.Close() - } - s.Unlock() - - close(s.done) -} - -// Fingerprint returns the fingerprint based on a public key -func Fingerprint(k ssh.PublicKey) string { - hash := md5.Sum(k.Marshal()) - r := fmt.Sprintf("% x", hash) - return strings.Replace(r, " ", ":", -1) -} diff --git a/sshd/terminal.go b/sshd/terminal.go index 51597c6..14318b3 100644 --- a/sshd/terminal.go +++ b/sshd/terminal.go @@ -31,6 +31,12 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) { } go term.listen(requests) + go func() { + // FIXME: Is this necessary? + conn.Wait() + channel.Close() + }() + return &term, nil } @@ -48,12 +54,22 @@ func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, } } + if term != nil { + // Reject the rest. + // FIXME: Do we need this? + go func() { + for ch := range channels { + ch.Reject(ssh.Prohibited, "only one session allowed") + } + }() + } + return term, err } // Close terminal and ssh connection func (t *Terminal) Close() error { - return MultiCloser{t.Channel, t.Conn}.Close() + return t.Conn.Close() } // Negotiate terminal type and settings