Merge branch 'refactor'

This commit is contained in:
Andrey Petrov 2015-01-20 14:27:24 -08:00
commit 5c72b1a121
44 changed files with 3674 additions and 1261 deletions

View File

@ -28,5 +28,5 @@ debug: $(BINARY) $(KEY)
./$(BINARY) --pprof 6060 -i $(KEY) --bind ":$(PORT)" -vv ./$(BINARY) --pprof 6060 -i $(KEY) --bind ":$(PORT)" -vv
test: test:
go test . go test ./...
golint golint ./...

View File

@ -19,8 +19,6 @@ The server's RSA key fingerprint is `e5:d5:d1:75:90:38:42:f6:c7:03:d7:d0:56:7d:6
## Compiling / Developing ## Compiling / Developing
**If you're going to be diving into the code, please use the [refactor branch](https://github.com/shazow/ssh-chat/tree/refactor) or see [issue #87](https://github.com/shazow/ssh-chat/pull/87).** It's not quite at feature parity yet, but the code is way nicer. The master branch is what's running on chat.shazow.net, but that will change soon.
You can compile ssh-chat by using `make build`. The resulting binary is portable and You can compile ssh-chat by using `make build`. The resulting binary is portable and
can be run on any system with a similar OS and CPU arch. Go 1.3 or higher is required to compile. can be run on any system with a similar OS and CPU arch. Go 1.3 or higher is required to compile.
@ -40,7 +38,7 @@ Usage:
Application Options: Application Options:
-v, --verbose Show verbose logging. -v, --verbose Show verbose logging.
-i, --identity= Private key to identify server with. (~/.ssh/id_rsa) -i, --identity= Private key to identify server with. (~/.ssh/id_rsa)
--bind= Host and port to listen on. (0.0.0.0:22) --bind= Host and port to listen on. (0.0.0.0:2022)
--admin= Fingerprint of pubkey to mark as admin. --admin= Fingerprint of pubkey to mark as admin.
--whitelist= Optional file of pubkey fingerprints that are allowed to connect --whitelist= Optional file of pubkey fingerprints that are allowed to connect
--motd= Message of the Day file (optional) --motd= Message of the Day file (optional)
@ -54,7 +52,7 @@ After doing `go get github.com/shazow/ssh-chat` on this repo, you should be able
to run a command like: to run a command like:
``` ```
$ ssh-chat --verbose --bind ":2022" --identity ~/.ssh/id_dsa $ ssh-chat --verbose --bind ":22" --identity ~/.ssh/id_dsa
``` ```
To bind on port 22, you'll need to make sure it's free (move any other ssh To bind on port 22, you'll need to make sure it's free (move any other ssh

150
auth.go Normal file
View File

@ -0,0 +1,150 @@
package main
import (
"errors"
"net"
"sync"
"time"
"github.com/shazow/ssh-chat/sshd"
"golang.org/x/crypto/ssh"
)
// The error returned a key is checked that is not whitelisted, with whitelisting required.
var ErrNotWhitelisted = errors.New("not whitelisted")
// The error returned a key is checked that is banned.
var ErrBanned = errors.New("banned")
// NewAuthKey returns string from an ssh.PublicKey.
func NewAuthKey(key ssh.PublicKey) string {
if key == nil {
return ""
}
// FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
return sshd.Fingerprint(key)
}
// NewAuthAddr returns a string from a net.Addr
func NewAuthAddr(addr net.Addr) string {
if addr == nil {
return ""
}
host, _, _ := net.SplitHostPort(addr.String())
return host
}
// Auth stores fingerprint lookups
// TODO: Add timed auth by using a time.Time instead of struct{} for values.
type Auth struct {
sync.RWMutex
bannedAddr *Set
banned *Set
whitelist *Set
ops *Set
}
// NewAuth creates a new default Auth.
func NewAuth() *Auth {
return &Auth{
bannedAddr: NewSet(),
banned: NewSet(),
whitelist: NewSet(),
ops: NewSet(),
}
}
// AllowAnonymous determines if anonymous users are permitted.
func (a Auth) AllowAnonymous() bool {
return a.whitelist.Len() == 0
}
// Check determines if a pubkey fingerprint is permitted.
func (a *Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) {
authkey := NewAuthKey(key)
if a.whitelist.Len() != 0 {
// Only check whitelist if there is something in it, otherwise it's disabled.
whitelisted := a.whitelist.In(authkey)
if !whitelisted {
return false, ErrNotWhitelisted
}
return true, nil
}
banned := a.banned.In(authkey)
if !banned {
banned = a.bannedAddr.In(NewAuthAddr(addr))
}
if banned {
return false, ErrBanned
}
return true, nil
}
// Op sets a public key as a known operator.
func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
if key == nil {
return
}
authkey := NewAuthKey(key)
if d != 0 {
a.ops.AddExpiring(authkey, d)
} else {
a.ops.Add(authkey)
}
logger.Debugf("Added to ops: %s (for %s)", authkey, d)
}
// IsOp checks if a public key is an op.
func (a *Auth) IsOp(key ssh.PublicKey) bool {
if key == nil {
return false
}
authkey := NewAuthKey(key)
return a.ops.In(authkey)
}
// Whitelist will set a public key as a whitelisted user.
func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
if key == nil {
return
}
authkey := NewAuthKey(key)
if d != 0 {
a.whitelist.AddExpiring(authkey, d)
} else {
a.whitelist.Add(authkey)
}
logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
}
// Ban will set a public key as banned.
func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
if key == nil {
return
}
a.BanFingerprint(NewAuthKey(key), d)
}
// BanFingerprint will set a public key fingerprint as banned.
func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
if d != 0 {
a.banned.AddExpiring(authkey, d)
} else {
a.banned.Add(authkey)
}
logger.Debugf("Added to banned: %s (for %s)", authkey, d)
}
// Ban will set an IP address as banned.
func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
key := NewAuthAddr(addr)
if d != 0 {
a.bannedAddr.AddExpiring(key, d)
} else {
a.bannedAddr.Add(key)
}
logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
}

62
auth_test.go Normal file
View File

@ -0,0 +1,62 @@
package main
import (
"crypto/rand"
"crypto/rsa"
"testing"
"golang.org/x/crypto/ssh"
)
func NewRandomPublicKey(bits int) (ssh.PublicKey, error) {
key, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
return ssh.NewPublicKey(key.Public())
}
func ClonePublicKey(key ssh.PublicKey) (ssh.PublicKey, error) {
return ssh.ParsePublicKey(key.Marshal())
}
func TestAuthWhitelist(t *testing.T) {
key, err := NewRandomPublicKey(512)
if err != nil {
t.Fatal(err)
}
auth := NewAuth()
ok, err := auth.Check(nil, key)
if !ok || err != nil {
t.Error("Failed to permit in default state:", err)
}
auth.Whitelist(key, 0)
keyClone, err := ClonePublicKey(key)
if err != nil {
t.Fatal(err)
}
if string(keyClone.Marshal()) != string(key.Marshal()) {
t.Error("Clone key does not match.")
}
ok, err = auth.Check(nil, keyClone)
if !ok || err != nil {
t.Error("Failed to permit whitelisted:", err)
}
key2, err := NewRandomPublicKey(512)
if err != nil {
t.Fatal(err)
}
ok, err = auth.Check(nil, key2)
if ok || err == nil {
t.Error("Failed to restrict not whitelisted:", err)
}
}

241
chat/command.go Normal file
View File

@ -0,0 +1,241 @@
package chat
// FIXME: Would be sweet if we could piggyback on a cli parser or something.
import (
"errors"
"fmt"
"strings"
)
// The error returned when an invalid command is issued.
var ErrInvalidCommand = errors.New("invalid command")
// The error returned when a command is given without an owner.
var ErrNoOwner = errors.New("command without owner")
// The error returned when a command is performed without the necessary number
// of arguments.
var ErrMissingArg = errors.New("missing argument")
// The error returned when a command is added without a prefix.
var ErrMissingPrefix = errors.New("command missing prefix")
// Command is a definition of a handler for a command.
type Command struct {
// The command's key, such as /foo
Prefix string
// Extra help regarding arguments
PrefixHelp string
// If omitted, command is hidden from /help
Help string
Handler func(*Room, CommandMsg) error
// Command requires Op permissions
Op bool
}
// Commands is a registry of available commands.
type Commands map[string]*Command
// Add will register a command. If help string is empty, it will be hidden from
// Help().
func (c Commands) Add(cmd Command) error {
if cmd.Prefix == "" {
return ErrMissingPrefix
}
c[cmd.Prefix] = &cmd
return nil
}
// Alias will add another command for the same handler, won't get added to help.
func (c Commands) Alias(command string, alias string) error {
cmd, ok := c[command]
if !ok {
return ErrInvalidCommand
}
c[alias] = cmd
return nil
}
// Run executes a command message.
func (c Commands) Run(room *Room, msg CommandMsg) error {
if msg.From == nil {
return ErrNoOwner
}
cmd, ok := c[msg.Command()]
if !ok {
return ErrInvalidCommand
}
return cmd.Handler(room, msg)
}
// Help will return collated help text as one string.
func (c Commands) Help(showOp bool) string {
// Filter by op
op := []*Command{}
normal := []*Command{}
for _, cmd := range c {
if cmd.Op {
op = append(op, cmd)
} else {
normal = append(normal, cmd)
}
}
help := "Available commands:" + Newline + NewCommandsHelp(normal).String()
if showOp {
help += Newline + "-> Operator commands:" + Newline + NewCommandsHelp(op).String()
}
return help
}
var defaultCommands *Commands
func init() {
defaultCommands = &Commands{}
InitCommands(defaultCommands)
}
// InitCommands injects default commands into a Commands registry.
func InitCommands(c *Commands) {
c.Add(Command{
Prefix: "/help",
Handler: func(room *Room, msg CommandMsg) error {
op := room.IsOp(msg.From())
room.Send(NewSystemMsg(room.commands.Help(op), msg.From()))
return nil
},
})
c.Add(Command{
Prefix: "/me",
Handler: func(room *Room, msg CommandMsg) error {
me := strings.TrimLeft(msg.body, "/me")
if me == "" {
me = "is at a loss for words."
} else {
me = me[1:]
}
room.Send(NewEmoteMsg(me, msg.From()))
return nil
},
})
c.Add(Command{
Prefix: "/exit",
Help: "Exit the chat.",
Handler: func(room *Room, msg CommandMsg) error {
msg.From().Close()
return nil
},
})
c.Alias("/exit", "/quit")
c.Add(Command{
Prefix: "/nick",
PrefixHelp: "NAME",
Help: "Rename yourself.",
Handler: func(room *Room, msg CommandMsg) error {
args := msg.Args()
if len(args) != 1 {
return ErrMissingArg
}
u := msg.From()
member, ok := room.MemberById(u.Id())
if !ok {
return errors.New("failed to find member")
}
oldId := member.Id()
member.SetId(SanitizeName(args[0]))
err := room.Rename(oldId, member)
if err != nil {
member.SetId(oldId)
return err
}
return nil
},
})
c.Add(Command{
Prefix: "/names",
Help: "List users who are connected.",
Handler: func(room *Room, msg CommandMsg) error {
// TODO: colorize
names := room.NamesPrefix("")
body := fmt.Sprintf("%d connected: %s", len(names), strings.Join(names, ", "))
room.Send(NewSystemMsg(body, msg.From()))
return nil
},
})
c.Alias("/names", "/list")
c.Add(Command{
Prefix: "/theme",
PrefixHelp: "[mono|colors]",
Help: "Set your color theme.",
Handler: func(room *Room, msg CommandMsg) error {
user := msg.From()
args := msg.Args()
if len(args) == 0 {
theme := "plain"
if user.Config.Theme != nil {
theme = user.Config.Theme.Id()
}
body := fmt.Sprintf("Current theme: %s", theme)
room.Send(NewSystemMsg(body, user))
return nil
}
id := args[0]
for _, t := range Themes {
if t.Id() == id {
user.Config.Theme = &t
body := fmt.Sprintf("Set theme: %s", id)
room.Send(NewSystemMsg(body, user))
return nil
}
}
return errors.New("theme not found")
},
})
c.Add(Command{
Prefix: "/quiet",
Help: "Silence room announcements.",
Handler: func(room *Room, msg CommandMsg) error {
u := msg.From()
u.ToggleQuietMode()
var body string
if u.Config.Quiet {
body = "Quiet mode is toggled ON"
} else {
body = "Quiet mode is toggled OFF"
}
room.Send(NewSystemMsg(body, u))
return nil
},
})
c.Add(Command{
Prefix: "/slap",
PrefixHelp: "NAME",
Handler: func(room *Room, msg CommandMsg) error {
var me string
args := msg.Args()
if len(args) == 0 {
me = "slaps themselves around a bit with a large trout."
} else {
me = fmt.Sprintf("slaps %s around a bit with a large trout.", strings.Join(args, " "))
}
room.Send(NewEmoteMsg(me, msg.From()))
return nil
},
})
}

13
chat/doc.go Normal file
View File

@ -0,0 +1,13 @@
/*
`chat` package is a server-agnostic implementation of a chat interface, built
with the intention of using with the intention of using as the backend for
ssh-chat.
This package should not know anything about sockets. It should expose io-style
interfaces and rooms for communicating with any method of transnport.
TODO: Add usage examples here.
*/
package chat

58
chat/help.go Normal file
View File

@ -0,0 +1,58 @@
package chat
import (
"fmt"
"sort"
"strings"
)
type helpItem struct {
Prefix string
Text string
}
type help struct {
items []helpItem
prefixWidth int
}
// NewCommandsHelp creates a help container from a commands container.
func NewCommandsHelp(c []*Command) fmt.Stringer {
lookup := map[string]struct{}{}
h := help{
items: []helpItem{},
}
for _, cmd := range c {
if cmd.Help == "" {
// Skip hidden commands.
continue
}
_, exists := lookup[cmd.Prefix]
if exists {
// Duplicate (alias)
continue
}
lookup[cmd.Prefix] = struct{}{}
prefix := fmt.Sprintf("%s %s", cmd.Prefix, cmd.PrefixHelp)
h.add(helpItem{prefix, cmd.Help})
}
return &h
}
func (h *help) add(item helpItem) {
h.items = append(h.items, item)
if len(item.Prefix) > h.prefixWidth {
h.prefixWidth = len(item.Prefix)
}
}
func (h help) String() string {
r := []string{}
format := fmt.Sprintf("%%-%ds - %%s", h.prefixWidth)
for _, item := range h.items {
r = append(r, fmt.Sprintf(format, item.Prefix, item.Text))
}
sort.Strings(r)
return strings.Join(r, Newline)
}

View File

@ -1,27 +1,33 @@
// TODO: Split this out into its own module, it's kinda neat. package chat
package main
import "sync" import (
"fmt"
"io"
"sync"
)
const timestampFmt = "2006-01-02 15:04:05"
// History contains the history entries // History contains the history entries
type History struct { type History struct {
entries []string sync.RWMutex
entries []Message
head int head int
size int size int
lock sync.Mutex out io.Writer
} }
// NewHistory constructs a new history of the given size // NewHistory constructs a new history of the given size
func NewHistory(size int) *History { func NewHistory(size int) *History {
return &History{ return &History{
entries: make([]string, size), entries: make([]Message, size),
} }
} }
// Add adds the given entry to the entries in the history // Add adds the given entry to the entries in the history
func (h *History) Add(entry string) { func (h *History) Add(entry Message) {
h.lock.Lock() h.Lock()
defer h.lock.Unlock() defer h.Unlock()
max := cap(h.entries) max := cap(h.entries)
h.head = (h.head + 1) % max h.head = (h.head + 1) % max
@ -29,6 +35,10 @@ func (h *History) Add(entry string) {
if h.size < max { if h.size < max {
h.size++ h.size++
} }
if h.out != nil {
fmt.Fprintf(h.out, "[%s] %s\n", entry.Timestamp().UTC().Format(timestampFmt), entry.String())
}
} }
// Len returns the number of entries in the history // Len returns the number of entries in the history
@ -37,16 +47,16 @@ func (h *History) Len() int {
} }
// Get the entry with the given number // Get the entry with the given number
func (h *History) Get(num int) []string { func (h *History) Get(num int) []Message {
h.lock.Lock() h.RLock()
defer h.lock.Unlock() defer h.RUnlock()
max := cap(h.entries) max := cap(h.entries)
if num > h.size { if num > h.size {
num = h.size num = h.size
} }
r := make([]string, num) r := make([]Message, num)
for i := 0; i < num; i++ { for i := 0; i < num; i++ {
idx := (h.head - i) % max idx := (h.head - i) % max
if idx < 0 { if idx < 0 {
@ -57,3 +67,10 @@ func (h *History) Get(num int) []string {
return r return r
} }
// SetOutput sets the output for logging added messages
func (h *History) SetOutput(w io.Writer) {
h.Lock()
h.out = w
h.Unlock()
}

62
chat/history_test.go Normal file
View File

@ -0,0 +1,62 @@
package chat
import "testing"
func msgEqual(a []Message, b []Message) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].String() != b[i].String() {
return false
}
}
return true
}
func TestHistory(t *testing.T) {
var r, expected []Message
var size int
h := NewHistory(5)
r = h.Get(10)
expected = []Message{}
if !msgEqual(r, expected) {
t.Errorf("Got: %v, Expected: %v", r, expected)
}
h.Add(NewMsg("1"))
if size = h.Len(); size != 1 {
t.Errorf("Wrong size: %v", size)
}
r = h.Get(1)
expected = []Message{NewMsg("1")}
if !msgEqual(r, expected) {
t.Errorf("Got: %v, Expected: %v", r, expected)
}
h.Add(NewMsg("2"))
h.Add(NewMsg("3"))
h.Add(NewMsg("4"))
h.Add(NewMsg("5"))
h.Add(NewMsg("6"))
if size = h.Len(); size != 5 {
t.Errorf("Wrong size: %v", size)
}
r = h.Get(2)
expected = []Message{NewMsg("5"), NewMsg("6")}
if !msgEqual(r, expected) {
t.Errorf("Got: %v, Expected: %v", r, expected)
}
r = h.Get(10)
expected = []Message{NewMsg("2"), NewMsg("3"), NewMsg("4"), NewMsg("5"), NewMsg("6")}
if !msgEqual(r, expected) {
t.Errorf("Got: %v, Expected: %v", r, expected)
}
}

22
chat/logger.go Normal file
View File

@ -0,0 +1,22 @@
package chat
import "io"
import stdlog "log"
var logger *stdlog.Logger
func SetLogger(w io.Writer) {
flags := stdlog.Flags()
prefix := "[chat] "
logger = stdlog.New(w, prefix, flags)
}
type nullWriter struct{}
func (nullWriter) Write(data []byte) (int, error) {
return len(data), nil
}
func init() {
SetLogger(nullWriter{})
}

257
chat/message.go Normal file
View File

@ -0,0 +1,257 @@
package chat
import (
"fmt"
"strings"
"time"
)
// Message is an interface for messages.
type Message interface {
Render(*Theme) string
String() string
Command() string
Timestamp() time.Time
}
type MessageTo interface {
Message
To() *User
}
type MessageFrom interface {
Message
From() *User
}
func ParseInput(body string, from *User) Message {
m := NewPublicMsg(body, from)
cmd, isCmd := m.ParseCommand()
if isCmd {
return cmd
}
return m
}
// Msg is a base type for other message types.
type Msg struct {
body string
timestamp time.Time
// TODO: themeCache *map[*Theme]string
}
func NewMsg(body string) *Msg {
return &Msg{
body: body,
timestamp: time.Now(),
}
}
// Render message based on a theme.
func (m *Msg) Render(t *Theme) string {
// TODO: Render based on theme
// TODO: Cache based on theme
return m.String()
}
func (m *Msg) String() string {
return m.body
}
func (m *Msg) Command() string {
return ""
}
func (m *Msg) Timestamp() time.Time {
return m.timestamp
}
// PublicMsg is any message from a user sent to the room.
type PublicMsg struct {
Msg
from *User
}
func NewPublicMsg(body string, from *User) *PublicMsg {
return &PublicMsg{
Msg: Msg{
body: body,
timestamp: time.Now(),
},
from: from,
}
}
func (m *PublicMsg) From() *User {
return m.from
}
func (m *PublicMsg) ParseCommand() (*CommandMsg, bool) {
// Check if the message is a command
if !strings.HasPrefix(m.body, "/") {
return nil, false
}
// Parse
// TODO: Handle quoted fields properly
fields := strings.Fields(m.body)
command, args := fields[0], fields[1:]
msg := CommandMsg{
PublicMsg: m,
command: command,
args: args,
}
return &msg, true
}
func (m *PublicMsg) Render(t *Theme) string {
if t == nil {
return m.String()
}
return fmt.Sprintf("%s: %s", t.ColorName(m.from), m.body)
}
func (m *PublicMsg) RenderFor(cfg UserConfig) string {
if cfg.Highlight == nil || cfg.Theme == nil {
return m.Render(cfg.Theme)
}
if !cfg.Highlight.MatchString(m.body) {
return m.Render(cfg.Theme)
}
body := cfg.Highlight.ReplaceAllString(m.body, cfg.Theme.Highlight("${1}"))
if cfg.Bell {
body += Bel
}
return fmt.Sprintf("%s: %s", cfg.Theme.ColorName(m.from), body)
}
func (m *PublicMsg) String() string {
return fmt.Sprintf("%s: %s", m.from.Name(), m.body)
}
// EmoteMsg is a /me message sent to the room. It specifically does not
// extend PublicMsg because it doesn't implement MessageFrom to allow the
// sender to see the emote.
type EmoteMsg struct {
Msg
from *User
}
func NewEmoteMsg(body string, from *User) *EmoteMsg {
return &EmoteMsg{
Msg: Msg{
body: body,
timestamp: time.Now(),
},
from: from,
}
}
func (m *EmoteMsg) Render(t *Theme) string {
return fmt.Sprintf("** %s %s", m.from.Name(), m.body)
}
func (m *EmoteMsg) String() string {
return m.Render(nil)
}
// PrivateMsg is a message sent to another user, not shown to anyone else.
type PrivateMsg struct {
PublicMsg
to *User
}
func NewPrivateMsg(body string, from *User, to *User) *PrivateMsg {
return &PrivateMsg{
PublicMsg: *NewPublicMsg(body, from),
to: to,
}
}
func (m *PrivateMsg) To() *User {
return m.to
}
func (m *PrivateMsg) Render(t *Theme) string {
return fmt.Sprintf("[PM from %s] %s", m.from.Name(), m.body)
}
func (m *PrivateMsg) String() string {
return m.Render(nil)
}
// SystemMsg is a response sent from the server directly to a user, not shown
// to anyone else. Usually in response to something, like /help.
type SystemMsg struct {
Msg
to *User
}
func NewSystemMsg(body string, to *User) *SystemMsg {
return &SystemMsg{
Msg: Msg{
body: body,
timestamp: time.Now(),
},
to: to,
}
}
func (m *SystemMsg) Render(t *Theme) string {
if t == nil {
return m.String()
}
return t.ColorSys(m.String())
}
func (m *SystemMsg) String() string {
return fmt.Sprintf("-> %s", m.body)
}
func (m *SystemMsg) To() *User {
return m.to
}
// AnnounceMsg is a message sent from the server to everyone, like a join or
// leave event.
type AnnounceMsg struct {
Msg
}
func NewAnnounceMsg(body string) *AnnounceMsg {
return &AnnounceMsg{
Msg: Msg{
body: body,
timestamp: time.Now(),
},
}
}
func (m *AnnounceMsg) Render(t *Theme) string {
if t == nil {
return m.String()
}
return t.ColorSys(m.String())
}
func (m *AnnounceMsg) String() string {
return fmt.Sprintf(" * %s", m.body)
}
type CommandMsg struct {
*PublicMsg
command string
args []string
room *Room
}
func (m *CommandMsg) Command() string {
return m.command
}
func (m *CommandMsg) Args() []string {
return m.args
}

52
chat/message_test.go Normal file
View File

@ -0,0 +1,52 @@
package chat
import "testing"
type testId string
func (i testId) Id() string {
return string(i)
}
func (i testId) SetId(s string) {
// no-op
}
func (i testId) Name() string {
return i.Id()
}
func TestMessage(t *testing.T) {
var expected, actual string
expected = " * foo"
actual = NewAnnounceMsg("foo").String()
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
u := NewUser(testId("foo"))
expected = "foo: hello"
actual = NewPublicMsg("hello", u).String()
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
expected = "** foo sighs."
actual = NewEmoteMsg("sighs.", u).String()
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
expected = "-> hello"
actual = NewSystemMsg("hello", u).String()
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
expected = "[PM from foo] hello"
actual = NewPrivateMsg("hello", u, u).String()
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}
// TODO: Add theme rendering tests

222
chat/room.go Normal file
View File

@ -0,0 +1,222 @@
package chat
import (
"errors"
"fmt"
"io"
"sync"
)
const historyLen = 20
const roomBuffer = 10
// The error returned when a message is sent to a room that is already
// closed.
var ErrRoomClosed = errors.New("room closed")
// The error returned when a user attempts to join with an invalid name, such
// as empty string.
var ErrInvalidName = errors.New("invalid name")
// Member is a User with per-Room metadata attached to it.
type Member struct {
*User
Op bool
}
// Room definition, also a Set of User Items
type Room struct {
topic string
history *History
members *Set
broadcast chan Message
commands Commands
closed bool
closeOnce sync.Once
}
// NewRoom creates a new room.
func NewRoom() *Room {
broadcast := make(chan Message, roomBuffer)
return &Room{
broadcast: broadcast,
history: NewHistory(historyLen),
members: NewSet(),
commands: *defaultCommands,
}
}
// SetCommands sets the room's command handlers.
func (r *Room) SetCommands(commands Commands) {
r.commands = commands
}
// Close the room and all the users it contains.
func (r *Room) Close() {
r.closeOnce.Do(func() {
r.closed = true
r.members.Each(func(m Identifier) {
m.(*Member).Close()
})
r.members.Clear()
close(r.broadcast)
})
}
// SetLogging sets logging output for the room's history
func (r *Room) SetLogging(out io.Writer) {
r.history.SetOutput(out)
}
// HandleMsg reacts to a message, will block until done.
func (r *Room) HandleMsg(m Message) {
switch m := m.(type) {
case *CommandMsg:
cmd := *m
err := r.commands.Run(r, cmd)
if err != nil {
m := NewSystemMsg(fmt.Sprintf("Err: %s", err), cmd.from)
go r.HandleMsg(m)
}
case MessageTo:
user := m.To()
user.Send(m)
default:
fromMsg, skip := m.(MessageFrom)
var skipUser *User
if skip {
skipUser = fromMsg.From()
}
r.history.Add(m)
r.members.Each(func(u Identifier) {
user := u.(*Member).User
if skip && skipUser == user {
// Skip
return
}
if _, ok := m.(*AnnounceMsg); ok {
if user.Config.Quiet {
// Skip
return
}
}
user.Send(m)
})
}
}
// Serve will consume the broadcast room and handle the messages, should be
// run in a goroutine.
func (r *Room) Serve() {
for m := range r.broadcast {
go r.HandleMsg(m)
}
}
// Send message, buffered by a chan.
func (r *Room) Send(m Message) {
r.broadcast <- m
}
// History feeds the room's recent message history to the user's handler.
func (r *Room) History(u *User) {
for _, m := range r.history.Get(historyLen) {
u.Send(m)
}
}
// Join the room as a user, will announce.
func (r *Room) Join(u *User) (*Member, error) {
if r.closed {
return nil, ErrRoomClosed
}
if u.Id() == "" {
return nil, ErrInvalidName
}
member := Member{u, false}
err := r.members.Add(&member)
if err != nil {
return nil, err
}
r.History(u)
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.members.Len())
r.Send(NewAnnounceMsg(s))
return &member, nil
}
// Leave the room as a user, will announce. Mostly used during setup.
func (r *Room) Leave(u *User) error {
err := r.members.Remove(u)
if err != nil {
return err
}
s := fmt.Sprintf("%s left.", u.Name())
r.Send(NewAnnounceMsg(s))
return nil
}
// Rename member with a new identity. This will not call rename on the member.
func (r *Room) Rename(oldId string, identity Identifier) error {
if identity.Id() == "" {
return ErrInvalidName
}
err := r.members.Replace(oldId, identity)
if err != nil {
return err
}
s := fmt.Sprintf("%s is now known as %s.", oldId, identity.Id())
r.Send(NewAnnounceMsg(s))
return nil
}
// Member returns a corresponding Member object to a User if the Member is
// present in this room.
func (r *Room) Member(u *User) (*Member, bool) {
m, ok := r.MemberById(u.Id())
if !ok {
return nil, false
}
// Check that it's the same user
if m.User != u {
return nil, false
}
return m, true
}
func (r *Room) MemberById(id string) (*Member, bool) {
m, err := r.members.Get(id)
if err != nil {
return nil, false
}
return m.(*Member), true
}
// IsOp returns whether a user is an operator in this room.
func (r *Room) IsOp(u *User) bool {
m, ok := r.Member(u)
return ok && m.Op
}
// Topic of the room.
func (r *Room) Topic() string {
return r.topic
}
// SetTopic will set the topic of the room.
func (r *Room) SetTopic(s string) {
r.topic = s
}
// NamesPrefix lists all members' names with a given prefix, used to query
// for autocompletion purposes.
func (r *Room) NamesPrefix(prefix string) []string {
members := r.members.ListPrefix(prefix)
names := make([]string, len(members))
for i, u := range members {
names[i] = u.(*Member).User.Name()
}
return names
}

192
chat/room_test.go Normal file
View File

@ -0,0 +1,192 @@
package chat
import (
"reflect"
"testing"
)
func TestRoomServe(t *testing.T) {
ch := NewRoom()
ch.Send(NewAnnounceMsg("hello"))
received := <-ch.broadcast
actual := received.String()
expected := " * hello"
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}
func TestRoomJoin(t *testing.T) {
var expected, actual []byte
s := &MockScreen{}
u := NewUser(testId("foo"))
ch := NewRoom()
go ch.Serve()
defer ch.Close()
_, err := ch.Join(u)
if err != nil {
t.Fatal(err)
}
u.ConsumeOne(s)
expected = []byte(" * foo joined. (Connected: 1)" + Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
ch.Send(NewSystemMsg("hello", u))
u.ConsumeOne(s)
expected = []byte("-> hello" + Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
ch.Send(ParseInput("/me says hello.", u))
u.ConsumeOne(s)
expected = []byte("** foo says hello." + Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}
func TestRoomDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) {
u := NewUser(testId("foo"))
u.Config = UserConfig{
Quiet: true,
}
ch := NewRoom()
defer ch.Close()
_, err := ch.Join(u)
if err != nil {
t.Fatal(err)
}
// Drain the initial Join message
<-ch.broadcast
go func() {
for msg := range u.msg {
if _, ok := msg.(*AnnounceMsg); ok {
t.Errorf("Got unexpected `%T`", msg)
}
}
}()
// Call with an AnnounceMsg and all the other types
// and assert we received only non-announce messages
ch.HandleMsg(NewAnnounceMsg("Ignored"))
// Assert we still get all other types of messages
ch.HandleMsg(NewEmoteMsg("hello", u))
ch.HandleMsg(NewSystemMsg("hello", u))
ch.HandleMsg(NewPrivateMsg("hello", u, u))
ch.HandleMsg(NewPublicMsg("hello", u))
}
func TestRoomQuietToggleBroadcasts(t *testing.T) {
u := NewUser(testId("foo"))
u.Config = UserConfig{
Quiet: true,
}
ch := NewRoom()
defer ch.Close()
_, err := ch.Join(u)
if err != nil {
t.Fatal(err)
}
// Drain the initial Join message
<-ch.broadcast
u.ToggleQuietMode()
expectedMsg := NewAnnounceMsg("Ignored")
ch.HandleMsg(expectedMsg)
msg := <-u.msg
if _, ok := msg.(*AnnounceMsg); !ok {
t.Errorf("Got: `%T`; Expected: `%T`", msg, expectedMsg)
}
u.ToggleQuietMode()
ch.HandleMsg(NewAnnounceMsg("Ignored"))
ch.HandleMsg(NewSystemMsg("hello", u))
msg = <-u.msg
if _, ok := msg.(*AnnounceMsg); ok {
t.Errorf("Got unexpected `%T`", msg)
}
}
func TestQuietToggleDisplayState(t *testing.T) {
var expected, actual []byte
s := &MockScreen{}
u := NewUser(testId("foo"))
ch := NewRoom()
go ch.Serve()
defer ch.Close()
_, err := ch.Join(u)
if err != nil {
t.Fatal(err)
}
// Drain the initial Join message
<-ch.broadcast
ch.Send(ParseInput("/quiet", u))
u.ConsumeOne(s)
expected = []byte("-> Quiet mode is toggled ON" + Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
ch.Send(ParseInput("/quiet", u))
u.ConsumeOne(s)
expected = []byte("-> Quiet mode is toggled OFF" + Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}
func TestRoomNames(t *testing.T) {
var expected, actual []byte
s := &MockScreen{}
u := NewUser(testId("foo"))
ch := NewRoom()
go ch.Serve()
defer ch.Close()
_, err := ch.Join(u)
if err != nil {
t.Fatal(err)
}
// Drain the initial Join message
<-ch.broadcast
ch.Send(ParseInput("/names", u))
u.ConsumeOne(s)
expected = []byte("-> 1 connected: foo" + Newline)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}

17
chat/sanitize.go Normal file
View File

@ -0,0 +1,17 @@
package chat
import "regexp"
var reStripName = regexp.MustCompile("[^\\w.-]")
// SanitizeName returns a name with only allowed characters.
func SanitizeName(s string) string {
return reStripName.ReplaceAllString(s, "")
}
var reStripData = regexp.MustCompile("[^[:ascii:]]")
// SanitizeData returns a string with only allowed characters for client-provided metadata inputs.
func SanitizeData(s string) string {
return reStripData.ReplaceAllString(s, "")
}

51
chat/screen_test.go Normal file
View File

@ -0,0 +1,51 @@
package chat
import (
"reflect"
"testing"
)
// Used for testing
type MockScreen struct {
buffer []byte
}
func (s *MockScreen) Write(data []byte) (n int, err error) {
s.buffer = append(s.buffer, data...)
return len(data), nil
}
func (s *MockScreen) Read(p *[]byte) (n int, err error) {
*p = s.buffer
s.buffer = []byte{}
return len(*p), nil
}
func TestScreen(t *testing.T) {
var actual, expected []byte
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %v; Expected: %v", actual, expected)
}
actual = []byte("foo")
expected = []byte("foo")
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %v; Expected: %v", actual, expected)
}
s := &MockScreen{}
expected = nil
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %v; Expected: %v", actual, expected)
}
expected = []byte("hello, world")
s.Write(expected)
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %v; Expected: %v", actual, expected)
}
}

142
chat/set.go Normal file
View File

@ -0,0 +1,142 @@
package chat
import (
"errors"
"strings"
"sync"
)
// The error returned when an added id already exists in the set.
var ErrIdTaken = errors.New("id already taken")
// The error returned when a requested item does not exist in the set.
var ErrItemMissing = errors.New("item does not exist")
// Set with string lookup.
// TODO: Add trie for efficient prefix lookup?
type Set struct {
lookup map[string]Identifier
sync.RWMutex
}
// NewSet creates a new set.
func NewSet() *Set {
return &Set{
lookup: map[string]Identifier{},
}
}
// Clear removes all items and returns the number removed.
func (s *Set) Clear() int {
s.Lock()
n := len(s.lookup)
s.lookup = map[string]Identifier{}
s.Unlock()
return n
}
// Len returns the size of the set right now.
func (s *Set) Len() int {
return len(s.lookup)
}
// In checks if an item exists in this set.
func (s *Set) In(item Identifier) bool {
s.RLock()
_, ok := s.lookup[item.Id()]
s.RUnlock()
return ok
}
// Get returns an item with the given Id.
func (s *Set) Get(id string) (Identifier, error) {
s.RLock()
item, ok := s.lookup[id]
s.RUnlock()
if !ok {
return nil, ErrItemMissing
}
return item, nil
}
// Add item to this set if it does not exist already.
func (s *Set) Add(item Identifier) error {
s.Lock()
defer s.Unlock()
_, found := s.lookup[item.Id()]
if found {
return ErrIdTaken
}
s.lookup[item.Id()] = item
return nil
}
// Remove item from this set.
func (s *Set) Remove(item Identifier) error {
s.Lock()
defer s.Unlock()
id := item.Id()
_, found := s.lookup[id]
if !found {
return ErrItemMissing
}
delete(s.lookup, id)
return nil
}
// Replace item from old id with new Identifier.
// Used for moving the same identifier to a new Id, such as a rename.
func (s *Set) Replace(oldId string, item Identifier) error {
s.Lock()
defer s.Unlock()
// Check if it already exists
_, found := s.lookup[item.Id()]
if found {
return ErrIdTaken
}
// Remove oldId
_, found = s.lookup[oldId]
if !found {
return ErrItemMissing
}
delete(s.lookup, oldId)
// Add new identifier
s.lookup[item.Id()] = item
return nil
}
// Each loops over every item while holding a read lock and applies fn to each
// element.
func (s *Set) Each(fn func(item Identifier)) {
s.RLock()
for _, item := range s.lookup {
fn(item)
}
s.RUnlock()
}
// ListPrefix returns a list of items with a prefix, case insensitive.
func (s *Set) ListPrefix(prefix string) []Identifier {
r := []Identifier{}
prefix = strings.ToLower(prefix)
s.RLock()
defer s.RUnlock()
for id, item := range s.lookup {
if !strings.HasPrefix(string(id), prefix) {
continue
}
r = append(r, item)
}
return r
}

38
chat/set_test.go Normal file
View File

@ -0,0 +1,38 @@
package chat
import "testing"
func TestSet(t *testing.T) {
var err error
s := NewSet()
u := NewUser(testId("foo"))
if s.In(u) {
t.Errorf("Set should be empty.")
}
err = s.Add(u)
if err != nil {
t.Error(err)
}
if !s.In(u) {
t.Errorf("Set should contain user.")
}
u2 := NewUser(testId("bar"))
err = s.Add(u2)
if err != nil {
t.Error(err)
}
err = s.Add(u2)
if err != ErrIdTaken {
t.Error(err)
}
size := s.Len()
if size != 2 {
t.Errorf("Set wrong size: %d (expected %d)", size, 2)
}
}

195
chat/theme.go Normal file
View File

@ -0,0 +1,195 @@
package chat
import "fmt"
const (
// Reset resets the color
Reset = "\033[0m"
// Bold makes the following text bold
Bold = "\033[1m"
// Dim dims the following text
Dim = "\033[2m"
// Italic makes the following text italic
Italic = "\033[3m"
// Underline underlines the following text
Underline = "\033[4m"
// Blink blinks the following text
Blink = "\033[5m"
// Invert inverts the following text
Invert = "\033[7m"
// Newline
Newline = "\r\n"
// BEL
Bel = "\007"
)
// Interface for Styles
type Style interface {
String() string
Format(string) string
}
// General hardcoded style, mostly used as a crutch until we flesh out the
// framework to support backgrounds etc.
type style string
func (c style) String() string {
return string(c)
}
func (c style) Format(s string) string {
return c.String() + s + Reset
}
// 256 color type, for terminals who support it
type Color256 uint8
// String version of this color
func (c Color256) String() string {
return fmt.Sprintf("38;05;%d", c)
}
// Return formatted string with this color
func (c Color256) Format(s string) string {
return "\033[" + c.String() + "m" + s + Reset
}
// No color, used for mono theme
type Color0 struct{}
// No-op for Color0
func (c Color0) String() string {
return ""
}
// No-op for Color0
func (c Color0) Format(s string) string {
return s
}
// Container for a collection of colors
type Palette struct {
colors []Style
size int
}
// Get a color by index, overflows are looped around.
func (p Palette) Get(i int) Style {
return p.colors[i%(p.size-1)]
}
func (p Palette) Len() int {
return p.size
}
func (p Palette) String() string {
r := ""
for _, c := range p.colors {
r += c.Format("X")
}
return r
}
// Collection of settings for chat
type Theme struct {
id string
sys Style
pm Style
highlight Style
names *Palette
}
func (t Theme) Id() string {
return t.id
}
// Colorize name string given some index
func (t Theme) ColorName(u *User) string {
if t.names == nil {
return u.Name()
}
return t.names.Get(u.colorIdx).Format(u.Name())
}
// Colorize the PM string
func (t Theme) ColorPM(s string) string {
if t.pm == nil {
return s
}
return t.pm.Format(s)
}
// Colorize the Sys message
func (t Theme) ColorSys(s string) string {
if t.sys == nil {
return s
}
return t.sys.Format(s)
}
// Highlight a matched string, usually name
func (t Theme) Highlight(s string) string {
if t.highlight == nil {
return s
}
return t.highlight.Format(s)
}
// List of initialzied themes
var Themes []Theme
// Default theme to use
var DefaultTheme *Theme
func readableColors256() *Palette {
size := 247
p := Palette{
colors: make([]Style, size),
size: size,
}
j := 0
for i := 0; i < 256; i++ {
if (16 <= i && i <= 18) || (232 <= i && i <= 237) {
// Remove the ones near black, this is kinda sadpanda.
continue
}
p.colors[j] = Color256(i)
j++
}
return &p
}
func init() {
palette := readableColors256()
Themes = []Theme{
Theme{
id: "colors",
names: palette,
sys: palette.Get(8), // Grey
pm: palette.Get(7), // White
highlight: style(Bold + "\033[48;5;11m\033[38;5;16m"), // Yellow highlight
},
Theme{
id: "mono",
},
}
// Debug for printing colors:
//for _, color := range palette.colors {
// fmt.Print(color.Format(color.String() + " "))
//}
DefaultTheme = &Themes[0]
}

71
chat/theme_test.go Normal file
View File

@ -0,0 +1,71 @@
package chat
import (
"fmt"
"testing"
)
func TestThemePalette(t *testing.T) {
var expected, actual string
palette := readableColors256()
color := palette.Get(5)
if color == nil {
t.Fatal("Failed to return a color from palette.")
}
actual = color.String()
expected = "38;05;5"
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
actual = color.Format("foo")
expected = "\033[38;05;5mfoo\033[0m"
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
actual = palette.Get(palette.Len() + 1).String()
expected = fmt.Sprintf("38;05;%d", 2)
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}
func TestTheme(t *testing.T) {
var expected, actual string
colorTheme := Themes[0]
color := colorTheme.sys
if color == nil {
t.Fatal("Sys color should not be empty for first theme.")
}
actual = color.Format("foo")
expected = "\033[38;05;8mfoo\033[0m"
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
actual = colorTheme.ColorSys("foo")
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
u := NewUser(testId("foo"))
u.colorIdx = 4
actual = colorTheme.ColorName(u)
expected = "\033[38;05;4mfoo\033[0m"
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
msg := NewPublicMsg("hello", u)
actual = msg.Render(&colorTheme)
expected = "\033[38;05;4mfoo\033[0m: hello"
if actual != expected {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}

178
chat/user.go Normal file
View File

@ -0,0 +1,178 @@
package chat
import (
"errors"
"fmt"
"io"
"math/rand"
"regexp"
"sync"
"time"
)
const messageBuffer = 20
const reHighlight = `\b(%s)\b`
var ErrUserClosed = errors.New("user closed")
// Identifier is an interface that can uniquely identify itself.
type Identifier interface {
Id() string
SetId(string)
Name() string
}
// User definition, implemented set Item interface and io.Writer
type User struct {
Identifier
Config UserConfig
colorIdx int
joined time.Time
msg chan Message
done chan struct{}
replyTo *User // Set when user gets a /msg, for replying.
closed bool
closeOnce sync.Once
}
func NewUser(identity Identifier) *User {
u := User{
Identifier: identity,
Config: *DefaultUserConfig,
joined: time.Now(),
msg: make(chan Message, messageBuffer),
done: make(chan struct{}, 1),
}
u.SetColorIdx(rand.Int())
return &u
}
func NewUserScreen(identity Identifier, screen io.Writer) *User {
u := NewUser(identity)
go u.Consume(screen)
return u
}
// Rename the user with a new Identifier.
func (u *User) SetId(id string) {
u.Identifier.SetId(id)
u.SetColorIdx(rand.Int())
}
// ReplyTo returns the last user that messaged this user.
func (u *User) ReplyTo() *User {
return u.replyTo
}
// SetReplyTo sets the last user to message this user.
func (u *User) SetReplyTo(user *User) {
u.replyTo = user
}
// ToggleQuietMode will toggle whether or not quiet mode is enabled
func (u *User) ToggleQuietMode() {
u.Config.Quiet = !u.Config.Quiet
}
// SetColorIdx will set the colorIdx to a specific value, primarily used for
// testing.
func (u *User) SetColorIdx(idx int) {
u.colorIdx = idx
}
// Block until user is closed
func (u *User) Wait() {
<-u.done
}
// Disconnect user, stop accepting messages
func (u *User) Close() {
u.closeOnce.Do(func() {
u.closed = true
close(u.done)
close(u.msg)
})
}
// Consume message buffer into an io.Writer. Will block, should be called in a
// goroutine.
// TODO: Not sure if this is a great API.
func (u *User) Consume(out io.Writer) {
for m := range u.msg {
u.HandleMsg(m, out)
}
}
// Consume one message and stop, mostly for testing
func (u *User) ConsumeOne(out io.Writer) {
u.HandleMsg(<-u.msg, out)
}
// SetHighlight sets the highlighting regular expression to match string.
func (u *User) SetHighlight(s string) error {
re, err := regexp.Compile(fmt.Sprintf(reHighlight, s))
if err != nil {
return err
}
u.Config.Highlight = re
return nil
}
func (u *User) render(m Message) string {
switch m := m.(type) {
case *PublicMsg:
return m.RenderFor(u.Config) + Newline
case *PrivateMsg:
u.SetReplyTo(m.From())
return m.Render(u.Config.Theme) + Newline
default:
return m.Render(u.Config.Theme) + Newline
}
}
func (u *User) HandleMsg(m Message, out io.Writer) {
r := u.render(m)
_, err := out.Write([]byte(r))
if err != nil {
logger.Printf("Write failed to %s, closing: %s", u.Name(), err)
u.Close()
}
}
// Add message to consume by user
func (u *User) Send(m Message) error {
if u.closed {
return ErrUserClosed
}
select {
case u.msg <- m:
default:
logger.Printf("Msg buffer full, closing: %s", u.Name())
u.Close()
return ErrUserClosed
}
return nil
}
// Container for per-user configurations.
type UserConfig struct {
Highlight *regexp.Regexp
Bell bool
Quiet bool
Theme *Theme
}
// Default user configuration to use
var DefaultUserConfig *UserConfig
func init() {
DefaultUserConfig = &UserConfig{
Bell: true,
Quiet: false,
}
// TODO: Seed random?
}

24
chat/user_test.go Normal file
View File

@ -0,0 +1,24 @@
package chat
import (
"reflect"
"testing"
)
func TestMakeUser(t *testing.T) {
var actual, expected []byte
s := &MockScreen{}
u := NewUser(testId("foo"))
m := NewAnnounceMsg("hello")
defer u.Close()
u.Send(m)
u.ConsumeOne(s)
s.Read(&actual)
expected = []byte(m.String() + Newline)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}

547
client.go
View File

@ -1,547 +0,0 @@
package main
import (
"fmt"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
const (
// MsgBuffer is the length of the message buffer
MsgBuffer int = 20
// MaxMsgLength is the maximum length of a message
MaxMsgLength int = 1024
// MaxNamesList is the max number of items to return in a /names command
MaxNamesList int = 20
// HelpText is the text returned by /help
HelpText string = `Available commands:
/about - About this chat.
/exit - Exit the chat.
/help - Show this help text.
/list - List the users that are currently connected.
/beep - Enable BEL notifications on mention.
/me $ACTION - Show yourself doing an action.
/nick $NAME - Rename yourself to a new name.
/whois $NAME - Display information about another connected user.
/msg $NAME $MESSAGE - Sends a private message to a user.
/motd - Prints the Message of the Day.
/theme [color|mono] - Set client theme.`
// OpHelpText is the additional text returned by /help if the client is an Op
OpHelpText string = `Available operator commands:
/ban $NAME - Banish a user from the chat
/unban $FINGERPRINT - Unban a fingerprint
/banned - List all banned fingerprints
/kick $NAME - Kick em' out.
/op $NAME - Promote a user to server operator.
/silence $NAME - Revoke a user's ability to speak.
/shutdown $MESSAGE - Broadcast message and shutdown server.
/motd $MESSAGE - Set message shown whenever somebody joins.
/whitelist $FINGERPRINT - Add fingerprint to whitelist, prevent anyone else from joining.
/whitelist github.com/$USER - Add github user's pubkeys to whitelist.`
// AboutText is the text returned by /about
AboutText string = `ssh-chat is made by @shazow.
It is a custom ssh server built in Go to serve a chat experience
instead of a shell.
Source: https://github.com/shazow/ssh-chat
For more, visit shazow.net or follow at twitter.com/shazow`
// RequiredWait is the time a client is required to wait between messages
RequiredWait time.Duration = time.Second / 2
)
// Client holds all the fields used by the client
type Client struct {
Server *Server
Conn *ssh.ServerConn
Msg chan string
Name string
Color string
Op bool
ready chan struct{}
term *terminal.Terminal
termWidth int
termHeight int
silencedUntil time.Time
lastTX time.Time
beepMe bool
colorMe bool
closed bool
sync.RWMutex
}
// NewClient constructs a new client
func NewClient(server *Server, conn *ssh.ServerConn) *Client {
return &Client{
Server: server,
Conn: conn,
Name: conn.User(),
Color: RandomColor256(),
Msg: make(chan string, MsgBuffer),
ready: make(chan struct{}, 1),
lastTX: time.Now(),
colorMe: true,
}
}
// ColoredName returns the client name in its color
func (c *Client) ColoredName() string {
return ColorString(c.Color, c.Name)
}
// SysMsg sends a message in continuous format over the message channel
func (c *Client) SysMsg(msg string, args ...interface{}) {
c.Send(ContinuousFormat(systemMessageFormat, "-> "+fmt.Sprintf(msg, args...)))
}
// Write writes the given message
func (c *Client) Write(msg string) {
if !c.colorMe {
msg = DeColorString(msg)
}
c.term.Write([]byte(msg + "\r\n"))
}
// WriteLines writes multiple messages
func (c *Client) WriteLines(msg []string) {
for _, line := range msg {
c.Write(line)
}
}
// Send sends the given message
func (c *Client) Send(msg string) {
if len(msg) > MaxMsgLength || c.closed {
return
}
select {
case c.Msg <- msg:
default:
logger.Errorf("Msg buffer full, dropping: %s (%s)", c.Name, c.Conn.RemoteAddr())
c.Conn.Conn.Close()
}
}
// SendLines sends multiple messages
func (c *Client) SendLines(msg []string) {
for _, line := range msg {
c.Send(line)
}
}
// IsSilenced checks if the client is silenced
func (c *Client) IsSilenced() bool {
return c.silencedUntil.After(time.Now())
}
// Silence silences a client for the given duration
func (c *Client) Silence(d time.Duration) {
c.silencedUntil = time.Now().Add(d)
}
// Resize resizes the client to the given width and height
func (c *Client) Resize(width, height int) error {
width = 1000000 // TODO: Remove this dirty workaround for text overflow once ssh/terminal is fixed
err := c.term.SetSize(width, height)
if err != nil {
logger.Errorf("Resize failed: %dx%d", width, height)
return err
}
c.termWidth, c.termHeight = width, height
return nil
}
// Rename renames the client to the given name
func (c *Client) Rename(name string) {
c.Name = name
var prompt string
if c.colorMe {
prompt = c.ColoredName()
} else {
prompt = c.Name
}
c.term.SetPrompt(fmt.Sprintf("[%s] ", prompt))
}
// Fingerprint returns the fingerprint
func (c *Client) Fingerprint() string {
if c.Conn.Permissions == nil {
return ""
}
return c.Conn.Permissions.Extensions["fingerprint"]
}
// Emote formats and sends an emote
func (c *Client) Emote(message string) {
formatted := fmt.Sprintf("** %s%s", c.ColoredName(), message)
if c.IsSilenced() || len(message) > 1000 {
c.SysMsg("Message rejected")
}
c.Server.Broadcast(formatted, nil)
}
func (c *Client) handleShell(channel ssh.Channel) {
defer channel.Close()
// FIXME: This shouldn't live here, need to restructure the call chaining.
c.Server.Add(c)
go func() {
// Block until done, then remove.
c.Conn.Wait()
c.closed = true
c.Server.Remove(c)
close(c.Msg)
}()
go func() {
for msg := range c.Msg {
c.Write(msg)
}
}()
for {
line, err := c.term.ReadLine()
if err != nil {
break
}
parts := strings.SplitN(line, " ", 3)
isCmd := strings.HasPrefix(parts[0], "/")
if isCmd {
// TODO: Factor this out.
switch parts[0] {
case "/test-colors": // Shh, this command is a secret!
c.Write(ColorString("32", "Lorem ipsum dolor sit amet,"))
c.Write("consectetur " + ColorString("31;1", "adipiscing") + " elit.")
case "/exit":
channel.Close()
case "/help":
c.SysMsg(strings.Replace(HelpText, "\n", "\r\n", -1))
if c.Server.IsOp(c) {
c.SysMsg(strings.Replace(OpHelpText, "\n", "\r\n", -1))
}
case "/about":
c.SysMsg(strings.Replace(AboutText, "\n", "\r\n", -1))
case "/uptime":
c.SysMsg(c.Server.Uptime())
case "/beep":
c.beepMe = !c.beepMe
if c.beepMe {
c.SysMsg("I'll beep you good.")
} else {
c.SysMsg("No more beeps. :(")
}
case "/me":
me := strings.TrimLeft(line, "/me")
if me == "" {
me = " is at a loss for words."
}
c.Emote(me)
case "/slap":
slappee := "themself"
if len(parts) > 1 {
slappee = parts[1]
if len(parts[1]) > 100 {
slappee = "some long-named jerk"
}
}
c.Emote(fmt.Sprintf(" slaps %s around a bit with a large trout.", slappee))
case "/nick":
if len(parts) == 2 {
c.Server.Rename(c, parts[1])
} else {
c.SysMsg("Missing $NAME from: /nick $NAME")
}
case "/whois":
if len(parts) >= 2 {
client := c.Server.Who(parts[1])
if client != nil {
version := reStripText.ReplaceAllString(string(client.Conn.ClientVersion()), "")
if len(version) > 100 {
version = "Evil Jerk with a superlong string"
}
c.SysMsg("%s is %s via %s", client.ColoredName(), client.Fingerprint(), version)
} else {
c.SysMsg("No such name: %s", parts[1])
}
} else {
c.SysMsg("Missing $NAME from: /whois $NAME")
}
case "/names", "/list":
coloredNames := []string{}
for _, name := range c.Server.List(nil) {
coloredNames = append(coloredNames, c.Server.Who(name).ColoredName())
}
num := len(coloredNames)
if len(coloredNames) > MaxNamesList {
others := fmt.Sprintf("and %d others.", len(coloredNames)-MaxNamesList)
coloredNames = coloredNames[:MaxNamesList]
coloredNames = append(coloredNames, others)
}
c.SysMsg("%d connected: %s", num, strings.Join(coloredNames, systemMessageFormat+", "))
case "/ban":
if !c.Server.IsOp(c) {
c.SysMsg("You're not an admin.")
} else if len(parts) != 2 {
c.SysMsg("Missing $NAME from: /ban $NAME")
} else {
client := c.Server.Who(parts[1])
if client == nil {
c.SysMsg("No such name: %s", parts[1])
} else {
fingerprint := client.Fingerprint()
client.SysMsg("Banned by %s.", c.ColoredName())
c.Server.Ban(fingerprint, nil)
client.Conn.Close()
c.Server.Broadcast(fmt.Sprintf("* %s was banned by %s", parts[1], c.ColoredName()), nil)
}
}
case "/unban":
if !c.Server.IsOp(c) {
c.SysMsg("You're not an admin.")
} else if len(parts) != 2 {
c.SysMsg("Missing $FINGERPRINT from: /unban $FINGERPRINT")
} else {
fingerprint := parts[1]
isBanned := c.Server.IsBanned(fingerprint)
if !isBanned {
c.SysMsg("No such banned fingerprint: %s", fingerprint)
} else {
c.Server.Unban(fingerprint)
c.Server.Broadcast(fmt.Sprintf("* %s was unbanned by %s", fingerprint, c.ColoredName()), nil)
}
}
case "/banned":
if !c.Server.IsOp(c) {
c.SysMsg("You're not an admin.")
} else if len(parts) != 1 {
c.SysMsg("Too many arguments for /banned")
} else {
for fingerprint := range c.Server.bannedPK {
c.SysMsg("Banned fingerprint: %s", fingerprint)
}
}
case "/op":
if !c.Server.IsOp(c) {
c.SysMsg("You're not an admin.")
} else if len(parts) != 2 {
c.SysMsg("Missing $NAME from: /op $NAME")
} else {
client := c.Server.Who(parts[1])
if client == nil {
c.SysMsg("No such name: %s", parts[1])
} else {
fingerprint := client.Fingerprint()
if fingerprint == "" {
c.SysMsg("Cannot op user without fingerprint.")
} else {
client.SysMsg("Made op by %s.", c.ColoredName())
c.Server.Op(fingerprint)
}
}
}
case "/kick":
if !c.Server.IsOp(c) {
c.SysMsg("You're not an admin.")
} else if len(parts) != 2 {
c.SysMsg("Missing $NAME from: /kick $NAME")
} else {
client := c.Server.Who(parts[1])
if client == nil {
c.SysMsg("No such name: %s", parts[1])
} else {
client.SysMsg("Kicked by %s.", c.ColoredName())
client.Conn.Close()
c.Server.Broadcast(fmt.Sprintf("* %s was kicked by %s", parts[1], c.ColoredName()), nil)
}
}
case "/silence":
if !c.Server.IsOp(c) {
c.SysMsg("You're not an admin.")
} else if len(parts) < 2 {
c.SysMsg("Missing $NAME from: /silence $NAME")
} else {
duration := time.Duration(5) * time.Minute
if len(parts) >= 3 {
parsedDuration, err := time.ParseDuration(parts[2])
if err == nil {
duration = parsedDuration
}
}
client := c.Server.Who(parts[1])
if client == nil {
c.SysMsg("No such name: %s", parts[1])
} else {
client.Silence(duration)
client.SysMsg("Silenced for %s by %s.", duration, c.ColoredName())
}
}
case "/shutdown":
if !c.Server.IsOp(c) {
c.SysMsg("You're not an admin.")
} else {
var split = strings.SplitN(line, " ", 2)
var msg string
if len(split) > 1 {
msg = split[1]
} else {
msg = ""
}
// Shutdown after 5 seconds
go func() {
c.Server.Broadcast(ColorString("31", msg), nil)
time.Sleep(time.Second * 5)
c.Server.Stop()
}()
}
case "/msg": /* Send a PM */
/* Make sure we have a recipient and a message */
if len(parts) < 2 {
c.SysMsg("Missing $NAME from: /msg $NAME $MESSAGE")
break
} else if len(parts) < 3 {
c.SysMsg("Missing $MESSAGE from: /msg $NAME $MESSAGE")
break
}
/* Ask the server to send the message */
if err := c.Server.Privmsg(parts[1], parts[2], c); nil != err {
c.SysMsg("Unable to send message to %v: %v", parts[1], err)
}
case "/motd": /* print motd */
if !c.Server.IsOp(c) {
c.Server.MotdUnicast(c)
} else if len(parts) < 2 {
c.Server.MotdUnicast(c)
} else {
var newmotd string
if len(parts) == 2 {
newmotd = parts[1]
} else {
newmotd = parts[1] + " " + parts[2]
}
c.Server.SetMotd(newmotd)
c.Server.MotdBroadcast(c)
}
case "/theme":
if len(parts) < 2 {
c.SysMsg("Missing $THEME from: /theme $THEME")
c.SysMsg("Choose either color or mono")
} else {
// Sets colorMe attribute of client
if parts[1] == "mono" {
c.colorMe = false
} else if parts[1] == "color" {
c.colorMe = true
}
// Rename to reset prompt
c.Rename(c.Name)
}
case "/whitelist": /* whitelist a fingerprint */
if !c.Server.IsOp(c) {
c.SysMsg("You're not an admin.")
} else if len(parts) != 2 {
c.SysMsg("Missing $FINGERPRINT from: /whitelist $FINGERPRINT")
} else {
fingerprint := parts[1]
go func() {
err = c.Server.Whitelist(fingerprint)
if err != nil {
c.SysMsg("Error adding to whitelist: %s", err)
} else {
c.SysMsg("Added %s to the whitelist", fingerprint)
}
}()
}
case "/version":
c.SysMsg("Version " + buildCommit)
default:
c.SysMsg("Invalid command: %s", line)
}
continue
}
msg := fmt.Sprintf("%s: %s", c.ColoredName(), line)
/* Rate limit */
if time.Now().Sub(c.lastTX) < RequiredWait {
c.SysMsg("Rate limiting in effect.")
continue
}
if c.IsSilenced() || len(msg) > 1000 || len(line) < 1 {
c.SysMsg("Message rejected.")
continue
}
c.Server.Broadcast(msg, c)
c.lastTX = time.Now()
}
}
func (c *Client) handleChannels(channels <-chan ssh.NewChannel) {
prompt := fmt.Sprintf("[%s] ", c.ColoredName())
hasShell := false
for ch := range channels {
if t := ch.ChannelType(); t != "session" {
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
continue
}
channel, requests, err := ch.Accept()
if err != nil {
logger.Errorf("Could not accept channel: %v", err)
continue
}
defer channel.Close()
c.term = terminal.NewTerminal(channel, prompt)
c.term.AutoCompleteCallback = c.Server.AutoCompleteFunction
for req := range requests {
var width, height int
var ok bool
switch req.Type {
case "shell":
if c.term != nil && !hasShell {
go c.handleShell(channel)
ok = true
hasShell = true
}
case "pty-req":
width, height, ok = parsePtyRequest(req.Payload)
if ok {
err := c.Resize(width, height)
ok = err == nil
}
case "window-change":
width, height, ok = parseWinchRequest(req.Payload)
if ok {
err := c.Resize(width, height)
ok = err == nil
}
}
if req.WantReply {
req.Reply(ok, nil)
}
}
}
}

136
cmd.go
View File

@ -13,6 +13,10 @@ import (
"github.com/alexcesaro/log" "github.com/alexcesaro/log"
"github.com/alexcesaro/log/golog" "github.com/alexcesaro/log/golog"
"github.com/jessevdk/go-flags" "github.com/jessevdk/go-flags"
"golang.org/x/crypto/ssh"
"github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/sshd"
) )
import _ "net/http/pprof" import _ "net/http/pprof"
@ -20,10 +24,11 @@ import _ "net/http/pprof"
type Options struct { type Options struct {
Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."` Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."`
Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"` Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"`
Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:22"` Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:2022"`
Admin []string `long:"admin" description:"Fingerprint of pubkey to mark as admin."` Admin string `long:"admin" description:"File of public keys who are admins."`
Whitelist string `long:"whitelist" description:"Optional file of pubkey fingerprints who are allowed to connect."` Whitelist string `long:"whitelist" description:"Optional file of public keys who are allowed to connect."`
Motd string `long:"motd" description:"Optional Message of the Day file."` Motd string `long:"motd" description:"Optional Message of the Day file."`
Log string `long:"log" description:"Write chat log to this file."`
Pprof int `long:"pprof" description:"Enable pprof http server for profiling."` Pprof int `long:"pprof" description:"Enable pprof http server for profiling."`
} }
@ -34,6 +39,7 @@ var logLevels = []log.Level{
} }
var buildCommit string var buildCommit string
func main() { func main() {
options := Options{} options := Options{}
parser := flags.NewParser(&options, flags.Default) parser := flags.NewParser(&options, flags.Default)
@ -42,6 +48,7 @@ func main() {
if p == nil { if p == nil {
fmt.Print(err) fmt.Print(err)
} }
os.Exit(1)
return return
} }
@ -51,54 +58,84 @@ func main() {
}() }()
} }
// Initialize seed for random colors
RandomColorInit()
// Figure out the log level // Figure out the log level
numVerbose := len(options.Verbose) numVerbose := len(options.Verbose)
if numVerbose > len(logLevels) { if numVerbose > len(logLevels) {
numVerbose = len(logLevels) numVerbose = len(logLevels) - 1
} }
logLevel := logLevels[numVerbose] logLevel := logLevels[numVerbose]
logger = golog.New(os.Stderr, logLevel) logger = golog.New(os.Stderr, logLevel)
if logLevel == log.Debug {
// Enable logging from submodules
chat.SetLogger(os.Stderr)
sshd.SetLogger(os.Stderr)
}
privateKeyPath := options.Identity privateKeyPath := options.Identity
if strings.HasPrefix(privateKeyPath, "~") { if strings.HasPrefix(privateKeyPath, "~/") {
user, err := user.Current() user, err := user.Current()
if err == nil { if err == nil {
privateKeyPath = strings.Replace(privateKeyPath, "~", user.HomeDir, 1) privateKeyPath = strings.Replace(privateKeyPath, "~", user.HomeDir, 1)
} }
} }
privateKey, err := ioutil.ReadFile(privateKeyPath) privateKey, err := ReadPrivateKey(privateKeyPath)
if err != nil { if err != nil {
logger.Errorf("Failed to load identity: %v", err) logger.Errorf("Couldn't read private key: %v", err)
return os.Exit(2)
} }
server, err := NewServer(privateKey) signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil { if err != nil {
logger.Errorf("Failed to create server: %v", err) logger.Errorf("Failed to parse key: %v", err)
return os.Exit(3)
} }
for _, fingerprint := range options.Admin { auth := NewAuth()
server.Op(fingerprint) config := sshd.MakeAuth(auth)
} config.AddHostKey(signer)
if options.Whitelist != "" { s, err := sshd.ListenSSH(options.Bind, config)
file, err := os.Open(options.Whitelist)
if err != nil { if err != nil {
logger.Errorf("Could not open whitelist file") logger.Errorf("Failed to listen on socket: %v", err)
return os.Exit(4)
} }
defer file.Close() defer s.Close()
s.RateLimit = true
scanner := bufio.NewScanner(file) fmt.Printf("Listening for connections on %v\n", s.Addr().String())
for scanner.Scan() {
server.Whitelist(scanner.Text()) host := NewHost(s)
host.auth = auth
host.theme = &chat.Themes[0]
err = fromFile(options.Admin, func(line []byte) error {
key, _, _, _, err := ssh.ParseAuthorizedKey(line)
if err != nil {
return err
} }
auth.Op(key, 0)
return nil
})
if err != nil {
logger.Errorf("Failed to load admins: %v", err)
os.Exit(5)
}
err = fromFile(options.Whitelist, func(line []byte) error {
key, _, _, _, err := ssh.ParseAuthorizedKey(line)
if err != nil {
return err
}
auth.Whitelist(key, 0)
logger.Debugf("Whitelisted: %s", line)
return nil
})
if err != nil {
logger.Errorf("Failed to load whitelist: %v", err)
os.Exit(5)
} }
if options.Motd != "" { if options.Motd != "" {
@ -107,24 +144,53 @@ func main() {
logger.Errorf("Failed to load MOTD file: %v", err) logger.Errorf("Failed to load MOTD file: %v", err)
return return
} }
motdString := string(motd[:]) motdString := strings.TrimSpace(string(motd))
/* hack to normalize line endings into \r\n */ // hack to normalize line endings into \r\n
motdString = strings.Replace(motdString, "\r\n", "\n", -1) motdString = strings.Replace(motdString, "\r\n", "\n", -1)
motdString = strings.Replace(motdString, "\n", "\r\n", -1) motdString = strings.Replace(motdString, "\n", "\r\n", -1)
server.SetMotd(motdString) host.SetMotd(motdString)
} }
if options.Log == "-" {
host.SetLogging(os.Stdout)
} else if options.Log != "" {
fp, err := os.OpenFile(options.Log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
logger.Errorf("Failed to open log file for writing: %v", err)
return
}
host.SetLogging(fp)
}
go host.Serve()
// Construct interrupt handler // Construct interrupt handler
sig := make(chan os.Signal, 1) sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt) signal.Notify(sig, os.Interrupt)
err = server.Start(options.Bind)
if err != nil {
logger.Errorf("Failed to start server: %v", err)
return
}
<-sig // Wait for ^C signal <-sig // Wait for ^C signal
logger.Warningf("Interrupt signal detected, shutting down.") logger.Warningf("Interrupt signal detected, shutting down.")
server.Stop() os.Exit(0)
}
func fromFile(path string, handler func(line []byte) error) error {
if path == "" {
// Skip
return nil
}
file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
err := handler(scanner.Bytes())
if err != nil {
return err
}
}
return nil
} }

View File

@ -1,82 +0,0 @@
package main
import (
"fmt"
"math/rand"
"regexp"
"strings"
"time"
)
const (
// Reset resets the color
Reset = "\033[0m"
// Bold makes the following text bold
Bold = "\033[1m"
// Dim dims the following text
Dim = "\033[2m"
// Italic makes the following text italic
Italic = "\033[3m"
// Underline underlines the following text
Underline = "\033[4m"
// Blink blinks the following text
Blink = "\033[5m"
// Invert inverts the following text
Invert = "\033[7m"
)
var colors = []string{"31", "32", "33", "34", "35", "36", "37", "91", "92", "93", "94", "95", "96", "97"}
// deColor is used for removing ANSI Escapes
var deColor = regexp.MustCompile("\033\\[[\\d;]+m")
// DeColorString removes all color from the given string
func DeColorString(s string) string {
s = deColor.ReplaceAllString(s, "")
return s
}
func randomReadableColor() int {
for {
i := rand.Intn(256)
if (16 <= i && i <= 18) || (232 <= i && i <= 237) {
// Remove the ones near black, this is kinda sadpanda.
continue
}
return i
}
}
// RandomColor256 returns a random (of 256) color
func RandomColor256() string {
return fmt.Sprintf("38;05;%d", randomReadableColor())
}
// RandomColor returns a random color
func RandomColor() string {
return colors[rand.Intn(len(colors))]
}
// ColorString returns a message in the given color
func ColorString(color string, msg string) string {
return Bold + "\033[" + color + "m" + msg + Reset
}
// RandomColorInit initializes the random seed
func RandomColorInit() {
rand.Seed(time.Now().UTC().UnixNano())
}
// ContinuousFormat is a horrible hack to "continue" the previous string color
// and format after a RESET has been encountered.
//
// This is not HTML where you can just do a </style> to resume your previous formatting!
func ContinuousFormat(format string, str string) string {
return systemMessageFormat + strings.Replace(str, Reset, format, -1) + Reset
}

View File

@ -1,53 +0,0 @@
package main
import (
"reflect"
"testing"
)
func TestHistory(t *testing.T) {
var r, expected []string
var size int
h := NewHistory(5)
r = h.Get(10)
expected = []string{}
if !reflect.DeepEqual(r, expected) {
t.Errorf("Got: %v, Expected: %v", r, expected)
}
h.Add("1")
if size = h.Len(); size != 1 {
t.Errorf("Wrong size: %v", size)
}
r = h.Get(1)
expected = []string{"1"}
if !reflect.DeepEqual(r, expected) {
t.Errorf("Got: %v, Expected: %v", r, expected)
}
h.Add("2")
h.Add("3")
h.Add("4")
h.Add("5")
h.Add("6")
if size = h.Len(); size != 5 {
t.Errorf("Wrong size: %v", size)
}
r = h.Get(2)
expected = []string{"5", "6"}
if !reflect.DeepEqual(r, expected) {
t.Errorf("Got: %v, Expected: %v", r, expected)
}
r = h.Get(10)
expected = []string{"2", "3", "4", "5", "6"}
if !reflect.DeepEqual(r, expected) {
t.Errorf("Got: %v, Expected: %v", r, expected)
}
}

461
host.go Normal file
View File

@ -0,0 +1,461 @@
package main
import (
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/shazow/rateio"
"github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/sshd"
)
const maxInputLength int = 1024
// GetPrompt will render the terminal prompt string based on the user.
func GetPrompt(user *chat.User) string {
name := user.Name()
if user.Config.Theme != nil {
name = user.Config.Theme.ColorName(user)
}
return fmt.Sprintf("[%s] ", name)
}
// Host is the bridge between sshd and chat modules
// TODO: Should be easy to add support for multiple rooms, if we want.
type Host struct {
*chat.Room
listener *sshd.SSHListener
commands chat.Commands
motd string
auth *Auth
count int
// Default theme
theme *chat.Theme
}
// NewHost creates a Host on top of an existing listener.
func NewHost(listener *sshd.SSHListener) *Host {
room := chat.NewRoom()
h := Host{
Room: room,
listener: listener,
commands: chat.Commands{},
}
// Make our own commands registry instance.
chat.InitCommands(&h.commands)
h.InitCommands(&h.commands)
room.SetCommands(h.commands)
go room.Serve()
return &h
}
// SetMotd sets the host's message of the day.
func (h *Host) SetMotd(motd string) {
h.motd = motd
}
func (h Host) isOp(conn sshd.Connection) bool {
key := conn.PublicKey()
if key == nil {
return false
}
return h.auth.IsOp(key)
}
// Connect a specific Terminal to this host and its room.
func (h *Host) Connect(term *sshd.Terminal) {
id := NewIdentity(term.Conn)
user := chat.NewUserScreen(id, term)
user.Config.Theme = h.theme
go func() {
// Close term once user is closed.
user.Wait()
term.Close()
}()
defer user.Close()
// Send MOTD
if h.motd != "" {
user.Send(chat.NewAnnounceMsg(h.motd))
}
member, err := h.Join(user)
if err != nil {
// Try again...
id.SetName(fmt.Sprintf("Guest%d", h.count))
member, err = h.Join(user)
}
if err != nil {
logger.Errorf("Failed to join: %s", err)
return
}
// Successfully joined.
term.SetPrompt(GetPrompt(user))
term.AutoCompleteCallback = h.AutoCompleteFunction(user)
user.SetHighlight(user.Name())
h.count++
// Should the user be op'd on join?
member.Op = h.isOp(term.Conn)
ratelimit := rateio.NewSimpleLimiter(3, time.Second*3)
for {
line, err := term.ReadLine()
if err == io.EOF {
// Closed
break
} else if err != nil {
logger.Errorf("Terminal reading error: %s", err)
break
}
err = ratelimit.Count(1)
if err != nil {
user.Send(chat.NewSystemMsg("Message rejected: Rate limiting is in effect.", user))
continue
}
if len(line) > maxInputLength {
user.Send(chat.NewSystemMsg("Message rejected: Input too long.", user))
continue
}
if line == "" {
// Silently ignore empty lines.
continue
}
m := chat.ParseInput(line, user)
// FIXME: Any reason to use h.room.Send(m) instead?
h.HandleMsg(m)
cmd := m.Command()
if cmd == "/nick" || cmd == "/theme" {
// Hijack /nick command to update terminal synchronously. Wouldn't
// work if we use h.room.Send(m) above.
//
// FIXME: This is hacky, how do we improve the API to allow for
// this? Chat module shouldn't know about terminals.
term.SetPrompt(GetPrompt(user))
user.SetHighlight(user.Name())
}
}
err = h.Leave(user)
if err != nil {
logger.Errorf("Failed to leave: %s", err)
return
}
}
// Serve our chat room onto the listener
func (h *Host) Serve() {
terminals := h.listener.ServeTerminal()
for term := range terminals {
go h.Connect(term)
}
}
func (h Host) completeName(partial string) string {
names := h.NamesPrefix(partial)
if len(names) == 0 {
// Didn't find anything
return ""
}
return names[len(names)-1]
}
func (h Host) completeCommand(partial string) string {
for cmd, _ := range h.commands {
if strings.HasPrefix(cmd, partial) {
return cmd
}
}
return ""
}
// AutoCompleteFunction returns a callback for terminal autocompletion
func (h *Host) AutoCompleteFunction(u *chat.User) func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
return func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
if key != 9 {
return
}
if strings.HasSuffix(line[:pos], " ") {
// Don't autocomplete spaces.
return
}
fields := strings.Fields(line[:pos])
isFirst := len(fields) < 2
partial := fields[len(fields)-1]
posPartial := pos - len(partial)
var completed string
if isFirst && strings.HasPrefix(partial, "/") {
// Command
completed = h.completeCommand(partial)
if completed == "/reply" {
replyTo := u.ReplyTo()
if replyTo != nil {
completed = "/msg " + replyTo.Name()
}
}
} else {
// Name
completed = h.completeName(partial)
if completed == "" {
return
}
if isFirst {
completed += ":"
}
}
completed += " "
// Reposition the cursor
newLine = strings.Replace(line[posPartial:], partial, completed, 1)
newLine = line[:posPartial] + newLine
newPos = pos + (len(completed) - len(partial))
ok = true
return
}
}
// GetUser returns a chat.User based on a name.
func (h *Host) GetUser(name string) (*chat.User, bool) {
m, ok := h.MemberById(name)
if !ok {
return nil, false
}
return m.User, true
}
// InitCommands adds host-specific commands to a Commands container. These will
// override any existing commands.
func (h *Host) InitCommands(c *chat.Commands) {
c.Add(chat.Command{
Prefix: "/msg",
PrefixHelp: "USER MESSAGE",
Help: "Send MESSAGE to USER.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
args := msg.Args()
switch len(args) {
case 0:
return errors.New("must specify user")
case 1:
return errors.New("must specify message")
}
target, ok := h.GetUser(args[0])
if !ok {
return errors.New("user not found")
}
m := chat.NewPrivateMsg(strings.Join(args[1:], " "), msg.From(), target)
room.Send(m)
return nil
},
})
c.Add(chat.Command{
Prefix: "/reply",
PrefixHelp: "MESSAGE",
Help: "Reply with MESSAGE to the previous private message.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
args := msg.Args()
switch len(args) {
case 0:
return errors.New("must specify message")
}
target := msg.From().ReplyTo()
if target == nil {
return errors.New("no message to reply to")
}
m := chat.NewPrivateMsg(strings.Join(args, " "), msg.From(), target)
room.Send(m)
return nil
},
})
c.Add(chat.Command{
Prefix: "/whois",
PrefixHelp: "USER",
Help: "Information about USER.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
args := msg.Args()
if len(args) == 0 {
return errors.New("must specify user")
}
target, ok := h.GetUser(args[0])
if !ok {
return errors.New("user not found")
}
id := target.Identifier.(*Identity)
room.Send(chat.NewSystemMsg(id.Whois(), msg.From()))
return nil
},
})
// Hidden commands
c.Add(chat.Command{
Prefix: "/version",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
room.Send(chat.NewSystemMsg(buildCommit, msg.From()))
return nil
},
})
timeStarted := time.Now()
c.Add(chat.Command{
Prefix: "/uptime",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
room.Send(chat.NewSystemMsg(time.Now().Sub(timeStarted).String(), msg.From()))
return nil
},
})
// Op commands
c.Add(chat.Command{
Op: true,
Prefix: "/kick",
PrefixHelp: "USER",
Help: "Kick USER from the server.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
args := msg.Args()
if len(args) == 0 {
return errors.New("must specify user")
}
target, ok := h.GetUser(args[0])
if !ok {
return errors.New("user not found")
}
body := fmt.Sprintf("%s was kicked by %s.", target.Name(), msg.From().Name())
room.Send(chat.NewAnnounceMsg(body))
target.Close()
return nil
},
})
c.Add(chat.Command{
Op: true,
Prefix: "/ban",
PrefixHelp: "USER [DURATION]",
Help: "Ban USER from the server.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
// TODO: Would be nice to specify what to ban. Key? Ip? etc.
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
args := msg.Args()
if len(args) == 0 {
return errors.New("must specify user")
}
target, ok := h.GetUser(args[0])
if !ok {
return errors.New("user not found")
}
var until time.Duration = 0
if len(args) > 1 {
until, _ = time.ParseDuration(args[1])
}
id := target.Identifier.(*Identity)
h.auth.Ban(id.PublicKey(), until)
h.auth.BanAddr(id.RemoteAddr(), until)
body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name())
room.Send(chat.NewAnnounceMsg(body))
target.Close()
logger.Debugf("Banned: \n-> %s", id.Whois())
return nil
},
})
c.Add(chat.Command{
Op: true,
Prefix: "/motd",
PrefixHelp: "MESSAGE",
Help: "Set the MESSAGE of the day.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
motd := ""
args := msg.Args()
if len(args) > 0 {
motd = strings.Join(args, " ")
}
h.motd = motd
body := fmt.Sprintf("New message of the day set by %s:", msg.From().Name())
room.Send(chat.NewAnnounceMsg(body))
if motd != "" {
room.Send(chat.NewAnnounceMsg(motd))
}
return nil
},
})
c.Add(chat.Command{
Op: true,
Prefix: "/op",
PrefixHelp: "USER [DURATION]",
Help: "Set USER as admin.",
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
args := msg.Args()
if len(args) == 0 {
return errors.New("must specify user")
}
var until time.Duration = 0
if len(args) > 1 {
until, _ = time.ParseDuration(args[1])
}
member, ok := room.MemberById(args[0])
if !ok {
return errors.New("user not found")
}
member.Op = true
id := member.Identifier.(*Identity)
h.auth.Op(id.PublicKey(), until)
body := fmt.Sprintf("Made op by %s.", msg.From().Name())
room.Send(chat.NewSystemMsg(body, member.User))
return nil
},
})
}

218
host_test.go Normal file
View File

@ -0,0 +1,218 @@
package main
import (
"bufio"
"crypto/rand"
"crypto/rsa"
"io"
"io/ioutil"
"strings"
"testing"
"time"
"github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/sshd"
"golang.org/x/crypto/ssh"
)
func stripPrompt(s string) string {
pos := strings.LastIndex(s, "\033[K")
if pos < 0 {
return s
}
return s[pos+3:]
}
func TestHostGetPrompt(t *testing.T) {
var expected, actual string
u := chat.NewUser(&Identity{nil, "foo"})
u.SetColorIdx(2)
actual = GetPrompt(u)
expected = "[foo] "
if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected)
}
u.Config.Theme = &chat.Themes[0]
actual = GetPrompt(u)
expected = "[\033[38;05;2mfoo\033[0m] "
if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected)
}
}
func TestHostNameCollision(t *testing.T) {
key, err := sshd.NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
config := sshd.MakeNoAuth()
config.AddHostKey(key)
s, err := sshd.ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
defer s.Close()
host := NewHost(s)
go host.Serve()
done := make(chan struct{}, 1)
// First client
go func() {
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
scanner.Scan()
actual := scanner.Text()
if !strings.HasPrefix(actual, "[foo] ") {
t.Errorf("First client failed to get 'foo' name.")
}
actual = stripPrompt(actual)
expected := " * foo joined. (Connected: 1)"
if actual != expected {
t.Errorf("Got %q; expected %q", actual, expected)
}
// Ready for second client
done <- struct{}{}
scanner.Scan()
actual = stripPrompt(scanner.Text())
expected = " * Guest1 joined. (Connected: 2)"
if actual != expected {
t.Errorf("Got %q; expected %q", actual, expected)
}
// Wrap it up.
close(done)
})
if err != nil {
t.Fatal(err)
}
}()
// Wait for first client
<-done
// Second client
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
scanner.Scan()
actual := scanner.Text()
if !strings.HasPrefix(actual, "[Guest1] ") {
t.Errorf("Second client did not get Guest1 name.")
}
})
if err != nil {
t.Fatal(err)
}
<-done
}
func TestHostWhitelist(t *testing.T) {
key, err := sshd.NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
auth := NewAuth()
config := sshd.MakeAuth(auth)
config.AddHostKey(key)
s, err := sshd.ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
defer s.Close()
host := NewHost(s)
host.auth = auth
go host.Serve()
target := s.Addr().String()
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
if err != nil {
t.Error(err)
}
clientkey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatal(err)
}
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
auth.Whitelist(clientpubkey, 0)
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
if err == nil {
t.Error("Failed to block unwhitelisted connection.")
}
}
func TestHostKick(t *testing.T) {
key, err := sshd.NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
auth := NewAuth()
config := sshd.MakeAuth(auth)
config.AddHostKey(key)
s, err := sshd.ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
defer s.Close()
addr := s.Addr().String()
host := NewHost(s)
go host.Serve()
connected := make(chan struct{})
done := make(chan struct{})
go func() {
// First client
err = sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
// Make op
member, _ := host.Room.MemberById("foo")
member.Op = true
// Block until second client is here
connected <- struct{}{}
w.Write([]byte("/kick bar\r\n"))
})
if err != nil {
t.Fatal(err)
}
}()
go func() {
// Second client
err = sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
<-connected
// Consume while we're connected. Should break when kicked.
ioutil.ReadAll(r)
})
if err != nil {
t.Fatal(err)
}
close(done)
}()
select {
case <-done:
case <-time.After(time.Second * 1):
t.Fatal("Timeout.")
}
}

50
identity.go Normal file
View File

@ -0,0 +1,50 @@
package main
import (
"fmt"
"net"
"github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/sshd"
)
// Identity is a container for everything that identifies a client.
type Identity struct {
sshd.Connection
id string
}
// NewIdentity returns a new identity object from an sshd.Connection.
func NewIdentity(conn sshd.Connection) *Identity {
return &Identity{
Connection: conn,
id: chat.SanitizeName(conn.Name()),
}
}
func (i Identity) Id() string {
return i.id
}
func (i *Identity) SetId(id string) {
i.id = id
}
func (i *Identity) SetName(name string) {
i.SetId(name)
}
func (i Identity) Name() string {
return i.id
}
func (i Identity) Whois() string {
ip, _, _ := net.SplitHostPort(i.RemoteAddr().String())
fingerprint := "(no public key)"
if i.PublicKey() != nil {
fingerprint = sshd.Fingerprint(i.PublicKey())
}
return fmt.Sprintf("name: %s"+chat.Newline+
" > ip: %s"+chat.Newline+
" > fingerprint: %s", i.Name(), ip, fingerprint)
}

49
key.go Normal file
View File

@ -0,0 +1,49 @@
package main
import (
"crypto/x509"
"encoding/pem"
"fmt"
"io/ioutil"
"os"
"code.google.com/p/gopass"
)
// ReadPrivateKey attempts to read your private key and possibly decrypt it if it
// requires a passphrase.
// This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`),
// is not set.
func ReadPrivateKey(path string) ([]byte, error) {
privateKey, err := ioutil.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to load identity: %v", err)
}
block, rest := pem.Decode(privateKey)
if len(rest) > 0 {
return nil, fmt.Errorf("extra data when decoding private key")
}
if !x509.IsEncryptedPEMBlock(block) {
return privateKey, nil
}
passphrase := os.Getenv("IDENTITY_PASSPHRASE")
if passphrase == "" {
passphrase, err = gopass.GetPass("Enter passphrase: ")
if err != nil {
return nil, fmt.Errorf("couldn't read passphrase: %v", err)
}
}
der, err := x509.DecryptPEMBlock(block, []byte(passphrase))
if err != nil {
return nil, fmt.Errorf("decrypt failed: %v", err)
}
privateKey = pem.EncodeToMemory(&pem.Block{
Type: block.Type,
Bytes: der,
})
return privateKey, nil
}

View File

@ -1,7 +1,16 @@
package main package main
import ( import (
"bytes"
"github.com/alexcesaro/log"
"github.com/alexcesaro/log/golog" "github.com/alexcesaro/log/golog"
) )
var logger *golog.Logger var logger *golog.Logger
func init() {
// Set a default null logger
var b bytes.Buffer
logger = golog.New(&b, log.Debug)
}

519
server.go
View File

@ -1,519 +0,0 @@
package main
import (
"bufio"
"crypto/md5"
"encoding/base64"
"fmt"
"net"
"net/http"
"regexp"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/crypto/ssh"
)
const (
maxNameLength = 32
historyLength = 20
systemMessageFormat = "\033[1;90m"
privateMessageFormat = "\033[1m"
highlightFormat = Bold + "\033[48;5;11m\033[38;5;16m"
beep = "\007"
)
var (
reStripText = regexp.MustCompile("[^0-9A-Za-z_.-]")
)
// Clients is a map of clients
type Clients map[string]*Client
// Server holds all the fields used by a server
type Server struct {
sshConfig *ssh.ServerConfig
done chan struct{}
clients Clients
count int
history *History
motd string
whitelist map[string]struct{} // fingerprint lookup
admins map[string]struct{} // fingerprint lookup
bannedPK map[string]*time.Time // fingerprint lookup
started time.Time
sync.RWMutex
}
// NewServer constructs a new server
func NewServer(privateKey []byte) (*Server, error) {
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
return nil, err
}
server := Server{
done: make(chan struct{}),
clients: Clients{},
count: 0,
history: NewHistory(historyLength),
motd: "",
whitelist: map[string]struct{}{},
admins: map[string]struct{}{},
bannedPK: map[string]*time.Time{},
started: time.Now(),
}
config := ssh.ServerConfig{
NoClientAuth: false,
// Auth-related things should be constant-time to avoid timing attacks.
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
fingerprint := Fingerprint(key)
if server.IsBanned(fingerprint) {
return nil, fmt.Errorf("Banned.")
}
if !server.IsWhitelisted(fingerprint) {
return nil, fmt.Errorf("Not Whitelisted.")
}
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
return perm, nil
},
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
if server.IsBanned("") {
return nil, fmt.Errorf("Interactive login disabled.")
}
if !server.IsWhitelisted("") {
return nil, fmt.Errorf("Not Whitelisted.")
}
return nil, nil
},
}
config.AddHostKey(signer)
server.sshConfig = &config
return &server, nil
}
// Len returns the number of clients
func (s *Server) Len() int {
return len(s.clients)
}
// SysMsg broadcasts the given message to everyone
func (s *Server) SysMsg(msg string, args ...interface{}) {
s.Broadcast(ContinuousFormat(systemMessageFormat, " * "+fmt.Sprintf(msg, args...)), nil)
}
// Broadcast broadcasts the given message to everyone except for the given client
func (s *Server) Broadcast(msg string, except *Client) {
logger.Debugf("Broadcast to %d: %s", s.Len(), msg)
s.history.Add(msg)
s.RLock()
defer s.RUnlock()
for _, client := range s.clients {
if except != nil && client == except {
continue
}
if strings.Contains(msg, client.Name) {
// Turn message red if client's name is mentioned, and send BEL if they have enabled beeping
personalMsg := strings.Replace(msg, client.Name, highlightFormat+client.Name+Reset, -1)
if client.beepMe {
personalMsg += beep
}
client.Send(personalMsg)
} else {
client.Send(msg)
}
}
}
// Privmsg sends a message to a particular nick, if it exists
func (s *Server) Privmsg(nick, message string, sender *Client) error {
// Get the recipient
target, ok := s.clients[strings.ToLower(nick)]
if !ok {
return fmt.Errorf("no client with that nick")
}
// Send the message
target.Msg <- fmt.Sprintf(beep+"[PM from %v] %s%v%s", sender.ColoredName(), privateMessageFormat, message, Reset)
logger.Debugf("PM from %v to %v: %v", sender.Name, nick, message)
return nil
}
// SetMotd sets the Message of the Day (MOTD)
func (s *Server) SetMotd(motd string) {
s.motd = motd
}
// MotdUnicast sends the MOTD as a SysMsg
func (s *Server) MotdUnicast(client *Client) {
if s.motd == "" {
return
}
client.SysMsg(s.motd)
}
// MotdBroadcast broadcasts the MOTD
func (s *Server) MotdBroadcast(client *Client) {
if s.motd == "" {
return
}
s.Broadcast(ContinuousFormat(systemMessageFormat, fmt.Sprintf(" * New MOTD set by %s.", client.ColoredName())), client)
s.Broadcast(s.motd, client)
}
// Add adds the client to the list of clients
func (s *Server) Add(client *Client) {
go func() {
s.MotdUnicast(client)
client.SendLines(s.history.Get(10))
}()
s.Lock()
s.count++
newName, err := s.proposeName(client.Name)
if err != nil {
client.SysMsg("Your name '%s' is not available, renamed to '%s'. Use /nick <name> to change it.", client.Name, ColorString(client.Color, newName))
}
client.Rename(newName)
s.clients[strings.ToLower(client.Name)] = client
num := len(s.clients)
s.Unlock()
s.Broadcast(ContinuousFormat(systemMessageFormat, fmt.Sprintf(" * %s joined. (Total connected: %d)", client.Name, num)), client)
}
// Remove removes the given client from the list of clients
func (s *Server) Remove(client *Client) {
s.Lock()
delete(s.clients, strings.ToLower(client.Name))
s.Unlock()
s.SysMsg("%s left.", client.Name)
}
func (s *Server) proposeName(name string) (string, error) {
// Assumes caller holds lock.
var err error
name = reStripText.ReplaceAllString(name, "")
if len(name) > maxNameLength {
name = name[:maxNameLength]
} else if len(name) == 0 {
name = fmt.Sprintf("Guest%d", s.count)
}
_, collision := s.clients[strings.ToLower(name)]
if collision {
err = fmt.Errorf("%s is not available", name)
name = fmt.Sprintf("Guest%d", s.count)
}
return name, err
}
// Rename renames the given client (user)
func (s *Server) Rename(client *Client, newName string) {
var oldName string
if strings.ToLower(newName) == strings.ToLower(client.Name) {
oldName = client.Name
client.Rename(newName)
} else {
s.Lock()
newName, err := s.proposeName(newName)
if err != nil {
client.SysMsg("%s", err)
s.Unlock()
return
}
// TODO: Use a channel/goroutine for adding clients, rather than locks?
delete(s.clients, strings.ToLower(client.Name))
oldName = client.Name
client.Rename(newName)
s.clients[strings.ToLower(client.Name)] = client
s.Unlock()
}
s.SysMsg("%s is now known as %s.", ColorString(client.Color, oldName), ColorString(client.Color, client.Name))
}
// List lists the clients with the given prefix
func (s *Server) List(prefix *string) []string {
r := []string{}
s.RLock()
defer s.RUnlock()
for name, client := range s.clients {
if prefix != nil && !strings.HasPrefix(name, strings.ToLower(*prefix)) {
continue
}
r = append(r, client.Name)
}
return r
}
// Who returns the client with a given name
func (s *Server) Who(name string) *Client {
return s.clients[strings.ToLower(name)]
}
// Op adds the given fingerprint to the list of admins
func (s *Server) Op(fingerprint string) {
logger.Infof("Adding admin: %s", fingerprint)
s.Lock()
s.admins[fingerprint] = struct{}{}
s.Unlock()
}
// Whitelist adds the given fingerprint to the whitelist
func (s *Server) Whitelist(fingerprint string) error {
if fingerprint == "" {
return fmt.Errorf("Invalid fingerprint.")
}
if strings.HasPrefix(fingerprint, "github.com/") {
return s.whitelistIdentityURL(fingerprint)
}
return s.whitelistFingerprint(fingerprint)
}
func (s *Server) whitelistIdentityURL(user string) error {
logger.Infof("Adding github account %s to whitelist", user)
user = strings.Replace(user, "github.com/", "", -1)
keys, err := getGithubPubKeys(user)
if err != nil {
return err
}
if len(keys) == 0 {
return fmt.Errorf("No keys for github user %s", user)
}
for _, key := range keys {
fingerprint := Fingerprint(key)
s.whitelistFingerprint(fingerprint)
}
return nil
}
func (s *Server) whitelistFingerprint(fingerprint string) error {
logger.Infof("Adding whitelist: %s", fingerprint)
s.Lock()
s.whitelist[fingerprint] = struct{}{}
s.Unlock()
return nil
}
// Client for getting github pub keys
var client = http.Client{
Timeout: time.Duration(10 * time.Second),
}
// Returns an array of public keys for the given github user URL
func getGithubPubKeys(user string) ([]ssh.PublicKey, error) {
resp, err := client.Get("http://github.com/" + user + ".keys")
if err != nil {
return nil, err
}
defer resp.Body.Close()
pubs := []ssh.PublicKey{}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
text := scanner.Text()
if text == "Not Found" {
continue
}
splitKey := strings.SplitN(text, " ", -1)
// In case of malformated key
if len(splitKey) < 2 {
continue
}
bodyDecoded, err := base64.StdEncoding.DecodeString(splitKey[1])
if err != nil {
return nil, err
}
pub, err := ssh.ParsePublicKey(bodyDecoded)
if err != nil {
return nil, err
}
pubs = append(pubs, pub)
}
return pubs, nil
}
// Uptime returns the time since the server was started
func (s *Server) Uptime() string {
return time.Now().Sub(s.started).String()
}
// IsOp checks if the given client is Op
func (s *Server) IsOp(client *Client) bool {
fingerprint := client.Fingerprint()
if fingerprint == "" {
return false
}
_, r := s.admins[client.Fingerprint()]
return r
}
// IsWhitelisted checks if the given fingerprint is whitelisted
func (s *Server) IsWhitelisted(fingerprint string) bool {
/* if no whitelist, anyone is welcome */
if len(s.whitelist) == 0 {
return true
}
/* otherwise, check for whitelist presence */
_, r := s.whitelist[fingerprint]
return r
}
// IsBanned checks if the given fingerprint is banned
func (s *Server) IsBanned(fingerprint string) bool {
ban, hasBan := s.bannedPK[fingerprint]
if !hasBan {
return false
}
if ban == nil {
return true
}
if ban.Before(time.Now()) {
s.Unban(fingerprint)
return false
}
return true
}
// Ban bans a fingerprint for the given duration
func (s *Server) Ban(fingerprint string, duration *time.Duration) {
var until *time.Time
s.Lock()
if duration != nil {
when := time.Now().Add(*duration)
until = &when
}
s.bannedPK[fingerprint] = until
s.Unlock()
}
// Unban unbans a banned fingerprint
func (s *Server) Unban(fingerprint string) {
s.Lock()
delete(s.bannedPK, fingerprint)
s.Unlock()
}
// Start starts the server
func (s *Server) Start(laddr string) error {
// Once a ServerConfig has been configured, connections can be
// accepted.
socket, err := net.Listen("tcp", laddr)
if err != nil {
return err
}
logger.Infof("Listening on %s", laddr)
go func() {
defer socket.Close()
for {
conn, err := socket.Accept()
if err != nil {
logger.Errorf("Failed to accept connection: %v", err)
if err == syscall.EINVAL {
// TODO: Handle shutdown more gracefully?
return
}
}
// Goroutineify to resume accepting sockets early.
go func() {
// From a standard TCP connection to an encrypted SSH connection
sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil {
logger.Errorf("Failed to handshake: %v", err)
return
}
version := reStripText.ReplaceAllString(string(sshConn.ClientVersion()), "")
if len(version) > 100 {
version = "Evil Jerk with a superlong string"
}
logger.Infof("Connection #%d from: %s, %s, %s", s.count+1, sshConn.RemoteAddr(), sshConn.User(), version)
go ssh.DiscardRequests(requests)
client := NewClient(s, sshConn)
go client.handleChannels(channels)
}()
}
}()
go func() {
<-s.done
socket.Close()
}()
return nil
}
// AutoCompleteFunction handles auto completion of nicks
func (s *Server) AutoCompleteFunction(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
if key == 9 {
shortLine := strings.Split(line[:pos], " ")
partialNick := shortLine[len(shortLine)-1]
nicks := s.List(&partialNick)
if len(nicks) > 0 {
nick := nicks[len(nicks)-1]
posPartialNick := pos - len(partialNick)
if len(shortLine) < 2 {
nick += ": "
} else {
nick += " "
}
newLine = strings.Replace(line[posPartialNick:],
partialNick, nick, 1)
newLine = line[:posPartialNick] + newLine
newPos = pos + (len(nick) - len(partialNick))
ok = true
}
} else {
ok = false
}
return
}
// Stop stops the server
func (s *Server) Stop() {
s.Lock()
for _, client := range s.clients {
client.Conn.Close()
}
s.Unlock()
close(s.done)
}
// Fingerprint returns the fingerprint based on a public key
func Fingerprint(k ssh.PublicKey) string {
hash := md5.Sum(k.Marshal())
r := fmt.Sprintf("% x", hash)
return strings.Replace(r, " ", ":", -1)
}

70
set.go Normal file
View File

@ -0,0 +1,70 @@
package main
import (
"sync"
"time"
)
type expiringValue struct {
time.Time
}
func (v expiringValue) Bool() bool {
return time.Now().Before(v.Time)
}
type value struct{}
func (v value) Bool() bool {
return true
}
type SetValue interface {
Bool() bool
}
// Set with expire-able keys
type Set struct {
lookup map[string]SetValue
sync.Mutex
}
// NewSet creates a new set.
func NewSet() *Set {
return &Set{
lookup: map[string]SetValue{},
}
}
// Len returns the size of the set right now.
func (s *Set) Len() int {
return len(s.lookup)
}
// In checks if an item exists in this set.
func (s *Set) In(key string) bool {
s.Lock()
v, ok := s.lookup[key]
if ok && !v.Bool() {
ok = false
delete(s.lookup, key)
}
s.Unlock()
return ok
}
// Add item to this set, replace if it exists.
func (s *Set) Add(key string) {
s.Lock()
s.lookup[key] = value{}
s.Unlock()
}
// Add item to this set, replace if it exists.
func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
until := time.Now().Add(d)
s.Lock()
s.lookup[key] = expiringValue{until}
s.Unlock()
return until
}

58
set_test.go Normal file
View File

@ -0,0 +1,58 @@
package main
import (
"testing"
"time"
)
func TestSetExpiring(t *testing.T) {
s := NewSet()
if s.In("foo") {
t.Error("Matched before set.")
}
s.Add("foo")
if !s.In("foo") {
t.Errorf("Not matched after set")
}
if s.Len() != 1 {
t.Error("Not len 1 after set")
}
v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
if v.Bool() {
t.Errorf("expiringValue now is not expiring")
}
v = expiringValue{time.Now().Add(time.Minute * 2)}
if !v.Bool() {
t.Errorf("expiringValue in 2 minutes is expiring now")
}
until := s.AddExpiring("bar", time.Minute*2)
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
}
val, ok := s.lookup["bar"]
if !ok {
t.Errorf("bar not in lookup")
}
if !until.Equal(val.(expiringValue).Time) {
t.Errorf("bar's until is not equal to the expected value")
}
if !val.Bool() {
t.Errorf("bar expired immediately")
}
if !s.In("bar") {
t.Errorf("Not matched after timed set")
}
if s.Len() != 2 {
t.Error("Not len 2 after set")
}
s.AddExpiring("bar", time.Nanosecond*1)
if s.In("bar") {
t.Error("Matched after expired timer")
}
}

72
sshd/auth.go Normal file
View File

@ -0,0 +1,72 @@
package sshd
import (
"crypto/sha256"
"encoding/base64"
"errors"
"net"
"golang.org/x/crypto/ssh"
)
// Auth is used to authenticate connections based on public keys.
type Auth interface {
// Whether to allow connections without a public key.
AllowAnonymous() bool
// Given address and public key, return if the connection should be permitted.
Check(net.Addr, ssh.PublicKey) (bool, error)
}
// MakeAuth makes an ssh.ServerConfig which performs authentication against an Auth implementation.
func MakeAuth(auth Auth) *ssh.ServerConfig {
config := ssh.ServerConfig{
NoClientAuth: false,
// Auth-related things should be constant-time to avoid timing attacks.
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
ok, err := auth.Check(conn.RemoteAddr(), key)
if !ok {
return nil, err
}
perm := &ssh.Permissions{Extensions: map[string]string{
"pubkey": string(key.Marshal()),
}}
return perm, nil
},
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
if !auth.AllowAnonymous() {
return nil, errors.New("public key authentication required")
}
_, err := auth.Check(conn.RemoteAddr(), nil)
return nil, err
},
}
return &config
}
// MakeNoAuth makes a simple ssh.ServerConfig which allows all connections.
// Primarily used for testing.
func MakeNoAuth() *ssh.ServerConfig {
config := ssh.ServerConfig{
NoClientAuth: false,
// Auth-related things should be constant-time to avoid timing attacks.
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
perm := &ssh.Permissions{Extensions: map[string]string{
"pubkey": string(key.Marshal()),
}}
return perm, nil
},
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
return nil, nil
},
}
return &config
}
// Fingerprint performs a SHA256 BASE64 fingerprint of the PublicKey, similar to OpenSSH.
// See: https://anongit.mindrot.org/openssh.git/commit/?id=56d1c83cdd1ac
func Fingerprint(k ssh.PublicKey) string {
hash := sha256.Sum256(k.Marshal())
return base64.StdEncoding.EncodeToString(hash[:])
}

72
sshd/client.go Normal file
View File

@ -0,0 +1,72 @@
package sshd
import (
"crypto/rand"
"crypto/rsa"
"io"
"golang.org/x/crypto/ssh"
)
// NewRandomSigner generates a random key of a desired bit length.
func NewRandomSigner(bits int) (ssh.Signer, error) {
key, err := rsa.GenerateKey(rand.Reader, bits)
if err != nil {
return nil, err
}
return ssh.NewSignerFromKey(key)
}
// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial.
func NewClientConfig(name string) *ssh.ClientConfig {
return &ssh.ClientConfig{
User: name,
Auth: []ssh.AuthMethod{
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
return
}),
},
}
}
// ConnectShell makes a barebones SSH client session, used for testing.
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
config := NewClientConfig(name)
conn, err := ssh.Dial("tcp", host, config)
if err != nil {
return err
}
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
return err
}
defer session.Close()
in, err := session.StdinPipe()
if err != nil {
return err
}
out, err := session.StdoutPipe()
if err != nil {
return err
}
/* FIXME: Do we want to request a PTY?
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
if err != nil {
return err
}
*/
err = session.Shell()
if err != nil {
return err
}
handler(out, in)
return nil
}

46
sshd/client_test.go Normal file
View File

@ -0,0 +1,46 @@
package sshd
import (
"errors"
"net"
"testing"
"golang.org/x/crypto/ssh"
)
var errRejectAuth = errors.New("not welcome here")
type RejectAuth struct{}
func (a RejectAuth) AllowAnonymous() bool {
return false
}
func (a RejectAuth) Check(net.Addr, ssh.PublicKey) (bool, error) {
return false, errRejectAuth
}
func consume(ch <-chan *Terminal) {
for _ = range ch {
}
}
func TestClientReject(t *testing.T) {
signer, err := NewRandomSigner(512)
config := MakeAuth(RejectAuth{})
config.AddHostKey(signer)
s, err := ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
defer s.Close()
go consume(s.ServeTerminal())
conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo"))
if err == nil {
defer conn.Close()
t.Error("Failed to reject conncetion")
}
t.Log(err)
}

34
sshd/doc.go Normal file
View File

@ -0,0 +1,34 @@
package sshd
/*
signer, err := ssh.ParsePrivateKey(privateKey)
config := MakeNoAuth()
config.AddHostKey(signer)
s, err := ListenSSH("0.0.0.0:2022", config)
if err != nil {
// Handle opening socket error
}
defer s.Close()
terminals := s.ServeTerminal()
for term := range terminals {
go func() {
defer term.Close()
term.SetPrompt("...")
term.AutoCompleteCallback = nil // ...
for {
line, err := term.ReadLine()
if err != nil {
break
}
term.Write(...)
}
}()
}
*/

22
sshd/logger.go Normal file
View File

@ -0,0 +1,22 @@
package sshd
import "io"
import stdlog "log"
var logger *stdlog.Logger
func SetLogger(w io.Writer) {
flags := stdlog.Flags()
prefix := "[sshd] "
logger = stdlog.New(w, prefix, flags)
}
type nullWriter struct{}
func (nullWriter) Write(data []byte) (int, error) {
return len(data), nil
}
func init() {
SetLogger(nullWriter{})
}

74
sshd/net.go Normal file
View File

@ -0,0 +1,74 @@
package sshd
import (
"net"
"time"
"github.com/shazow/rateio"
"golang.org/x/crypto/ssh"
)
// Container for the connection and ssh-related configuration
type SSHListener struct {
net.Listener
config *ssh.ServerConfig
RateLimit bool
}
// Make an SSH listener socket
func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) {
socket, err := net.Listen("tcp", laddr)
if err != nil {
return nil, err
}
l := SSHListener{Listener: socket, config: config}
return &l, nil
}
func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) {
if l.RateLimit {
// TODO: Configurable Limiter?
conn = ReadLimitConn(conn, rateio.NewGracefulLimiter(1024*10, time.Minute*2, time.Second*3))
}
// Upgrade TCP connection to SSH connection
sshConn, channels, requests, err := ssh.NewServerConn(conn, l.config)
if err != nil {
return nil, err
}
// FIXME: Disconnect if too many faulty requests? (Avoid DoS.)
go ssh.DiscardRequests(requests)
return NewSession(sshConn, channels)
}
// Accept incoming connections as terminal requests and yield them
func (l *SSHListener) ServeTerminal() <-chan *Terminal {
ch := make(chan *Terminal)
go func() {
defer l.Close()
defer close(ch)
for {
conn, err := l.Accept()
if err != nil {
logger.Printf("Failed to accept connection: %v", err)
return
}
// Goroutineify to resume accepting sockets early
go func() {
term, err := l.handleConn(conn)
if err != nil {
logger.Printf("Failed to handshake: %v", err)
return
}
ch <- term
}()
}
}()
return ch
}

81
sshd/net_test.go Normal file
View File

@ -0,0 +1,81 @@
package sshd
import (
"bytes"
"io"
"testing"
)
func TestServerInit(t *testing.T) {
config := MakeNoAuth()
s, err := ListenSSH(":badport", config)
if err == nil {
t.Fatal("should fail on bad port")
}
s, err = ListenSSH(":0", config)
if err != nil {
t.Error(err)
}
err = s.Close()
if err != nil {
t.Error(err)
}
}
func TestServeTerminals(t *testing.T) {
signer, err := NewRandomSigner(512)
config := MakeNoAuth()
config.AddHostKey(signer)
s, err := ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
terminals := s.ServeTerminal()
go func() {
// Accept one terminal, read from it, echo back, close.
term := <-terminals
term.SetPrompt("> ")
line, err := term.ReadLine()
if err != nil {
t.Error(err)
}
_, err = term.Write([]byte("echo: " + line + "\r\n"))
if err != nil {
t.Error(err)
}
term.Close()
}()
host := s.Addr().String()
name := "foo"
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
// Consume if there is anything
buf := new(bytes.Buffer)
w.Write([]byte("hello\r\n"))
buf.Reset()
_, err := io.Copy(buf, r)
if err != nil {
t.Error(err)
}
expected := "> hello\r\necho: hello\r\n"
actual := buf.String()
if actual != expected {
t.Errorf("Got %q; expected %q", actual, expected)
}
s.Close()
})
if err != nil {
t.Fatal(err)
}
}

View File

@ -1,8 +1,9 @@
// Borrowed from go.crypto circa 2011 package sshd
package main
import "encoding/binary" import "encoding/binary"
// Helpers below are borrowed from go.crypto circa 2011:
// parsePtyRequest parses the payload of the pty-req message and extracts the // parsePtyRequest parses the payload of the pty-req message and extracts the
// dimensions of the terminal. See RFC 4254, section 6.2. // dimensions of the terminal. See RFC 4254, section 6.2.
func parsePtyRequest(s []byte) (width, height int, ok bool) { func parsePtyRequest(s []byte) (width, height int, ok bool) {

25
sshd/ratelimit.go Normal file
View File

@ -0,0 +1,25 @@
package sshd
import (
"io"
"net"
"github.com/shazow/rateio"
)
type limitedConn struct {
net.Conn
io.Reader // Our rate-limited io.Reader for net.Conn
}
func (r *limitedConn) Read(p []byte) (n int, err error) {
return r.Reader.Read(p)
}
// ReadLimitConn returns a net.Conn whose io.Reader interface is rate-limited by limiter.
func ReadLimitConn(conn net.Conn, limiter rateio.Limiter) net.Conn {
return &limitedConn{
Conn: conn,
Reader: rateio.NewReader(conn, limiter),
}
}

144
sshd/terminal.go Normal file
View File

@ -0,0 +1,144 @@
package sshd
import (
"errors"
"fmt"
"net"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
// Connection is an interface with fields necessary to operate an sshd host.
type Connection interface {
PublicKey() ssh.PublicKey
RemoteAddr() net.Addr
Name() string
Close() error
}
type sshConn struct {
*ssh.ServerConn
}
func (c sshConn) PublicKey() ssh.PublicKey {
if c.Permissions == nil {
return nil
}
s, ok := c.Permissions.Extensions["pubkey"]
if !ok {
return nil
}
key, err := ssh.ParsePublicKey([]byte(s))
if err != nil {
return nil
}
return key
}
func (c sshConn) Name() string {
return c.User()
}
// Extending ssh/terminal to include a closer interface
type Terminal struct {
terminal.Terminal
Conn Connection
Channel ssh.Channel
}
// Make new terminal from a session channel
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
if ch.ChannelType() != "session" {
return nil, errors.New("terminal requires session channel")
}
channel, requests, err := ch.Accept()
if err != nil {
return nil, err
}
term := Terminal{
*terminal.NewTerminal(channel, "Connecting..."),
sshConn{conn},
channel,
}
go term.listen(requests)
go func() {
// FIXME: Is this necessary?
conn.Wait()
channel.Close()
}()
return &term, nil
}
// Find session channel and make a Terminal from it
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
for ch := range channels {
if t := ch.ChannelType(); t != "session" {
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
continue
}
term, err = NewTerminal(conn, ch)
if err == nil {
break
}
}
if term != nil {
// Reject the rest.
// FIXME: Do we need this?
go func() {
for ch := range channels {
ch.Reject(ssh.Prohibited, "only one session allowed")
}
}()
}
return term, err
}
// Close terminal and ssh connection
func (t *Terminal) Close() error {
return t.Conn.Close()
}
// Negotiate terminal type and settings
func (t *Terminal) listen(requests <-chan *ssh.Request) {
hasShell := false
for req := range requests {
var width, height int
var ok bool
switch req.Type {
case "shell":
if !hasShell {
ok = true
hasShell = true
}
case "pty-req":
width, height, ok = parsePtyRequest(req.Payload)
if ok {
// TODO: Hardcode width to 100000?
err := t.SetSize(width, height)
ok = err == nil
}
case "window-change":
width, height, ok = parseWinchRequest(req.Payload)
if ok {
// TODO: Hardcode width to 100000?
err := t.SetSize(width, height)
ok = err == nil
}
}
if req.WantReply {
req.Reply(ok, nil)
}
}
}