mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-15 08:30:36 +03:00
Merge branch 'refactor'
This commit is contained in:
commit
5c72b1a121
4
Makefile
4
Makefile
@ -28,5 +28,5 @@ debug: $(BINARY) $(KEY)
|
|||||||
./$(BINARY) --pprof 6060 -i $(KEY) --bind ":$(PORT)" -vv
|
./$(BINARY) --pprof 6060 -i $(KEY) --bind ":$(PORT)" -vv
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test .
|
go test ./...
|
||||||
golint
|
golint ./...
|
||||||
|
@ -19,8 +19,6 @@ The server's RSA key fingerprint is `e5:d5:d1:75:90:38:42:f6:c7:03:d7:d0:56:7d:6
|
|||||||
|
|
||||||
## Compiling / Developing
|
## Compiling / Developing
|
||||||
|
|
||||||
**If you're going to be diving into the code, please use the [refactor branch](https://github.com/shazow/ssh-chat/tree/refactor) or see [issue #87](https://github.com/shazow/ssh-chat/pull/87).** It's not quite at feature parity yet, but the code is way nicer. The master branch is what's running on chat.shazow.net, but that will change soon.
|
|
||||||
|
|
||||||
You can compile ssh-chat by using `make build`. The resulting binary is portable and
|
You can compile ssh-chat by using `make build`. The resulting binary is portable and
|
||||||
can be run on any system with a similar OS and CPU arch. Go 1.3 or higher is required to compile.
|
can be run on any system with a similar OS and CPU arch. Go 1.3 or higher is required to compile.
|
||||||
|
|
||||||
@ -40,7 +38,7 @@ Usage:
|
|||||||
Application Options:
|
Application Options:
|
||||||
-v, --verbose Show verbose logging.
|
-v, --verbose Show verbose logging.
|
||||||
-i, --identity= Private key to identify server with. (~/.ssh/id_rsa)
|
-i, --identity= Private key to identify server with. (~/.ssh/id_rsa)
|
||||||
--bind= Host and port to listen on. (0.0.0.0:22)
|
--bind= Host and port to listen on. (0.0.0.0:2022)
|
||||||
--admin= Fingerprint of pubkey to mark as admin.
|
--admin= Fingerprint of pubkey to mark as admin.
|
||||||
--whitelist= Optional file of pubkey fingerprints that are allowed to connect
|
--whitelist= Optional file of pubkey fingerprints that are allowed to connect
|
||||||
--motd= Message of the Day file (optional)
|
--motd= Message of the Day file (optional)
|
||||||
@ -54,7 +52,7 @@ After doing `go get github.com/shazow/ssh-chat` on this repo, you should be able
|
|||||||
to run a command like:
|
to run a command like:
|
||||||
|
|
||||||
```
|
```
|
||||||
$ ssh-chat --verbose --bind ":2022" --identity ~/.ssh/id_dsa
|
$ ssh-chat --verbose --bind ":22" --identity ~/.ssh/id_dsa
|
||||||
```
|
```
|
||||||
|
|
||||||
To bind on port 22, you'll need to make sure it's free (move any other ssh
|
To bind on port 22, you'll need to make sure it's free (move any other ssh
|
||||||
|
150
auth.go
Normal file
150
auth.go
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shazow/ssh-chat/sshd"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The error returned a key is checked that is not whitelisted, with whitelisting required.
|
||||||
|
var ErrNotWhitelisted = errors.New("not whitelisted")
|
||||||
|
|
||||||
|
// The error returned a key is checked that is banned.
|
||||||
|
var ErrBanned = errors.New("banned")
|
||||||
|
|
||||||
|
// NewAuthKey returns string from an ssh.PublicKey.
|
||||||
|
func NewAuthKey(key ssh.PublicKey) string {
|
||||||
|
if key == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
|
||||||
|
return sshd.Fingerprint(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuthAddr returns a string from a net.Addr
|
||||||
|
func NewAuthAddr(addr net.Addr) string {
|
||||||
|
if addr == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
host, _, _ := net.SplitHostPort(addr.String())
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth stores fingerprint lookups
|
||||||
|
// TODO: Add timed auth by using a time.Time instead of struct{} for values.
|
||||||
|
type Auth struct {
|
||||||
|
sync.RWMutex
|
||||||
|
bannedAddr *Set
|
||||||
|
banned *Set
|
||||||
|
whitelist *Set
|
||||||
|
ops *Set
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuth creates a new default Auth.
|
||||||
|
func NewAuth() *Auth {
|
||||||
|
return &Auth{
|
||||||
|
bannedAddr: NewSet(),
|
||||||
|
banned: NewSet(),
|
||||||
|
whitelist: NewSet(),
|
||||||
|
ops: NewSet(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowAnonymous determines if anonymous users are permitted.
|
||||||
|
func (a Auth) AllowAnonymous() bool {
|
||||||
|
return a.whitelist.Len() == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check determines if a pubkey fingerprint is permitted.
|
||||||
|
func (a *Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) {
|
||||||
|
authkey := NewAuthKey(key)
|
||||||
|
|
||||||
|
if a.whitelist.Len() != 0 {
|
||||||
|
// Only check whitelist if there is something in it, otherwise it's disabled.
|
||||||
|
whitelisted := a.whitelist.In(authkey)
|
||||||
|
if !whitelisted {
|
||||||
|
return false, ErrNotWhitelisted
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
banned := a.banned.In(authkey)
|
||||||
|
if !banned {
|
||||||
|
banned = a.bannedAddr.In(NewAuthAddr(addr))
|
||||||
|
}
|
||||||
|
if banned {
|
||||||
|
return false, ErrBanned
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Op sets a public key as a known operator.
|
||||||
|
func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
|
||||||
|
if key == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authkey := NewAuthKey(key)
|
||||||
|
if d != 0 {
|
||||||
|
a.ops.AddExpiring(authkey, d)
|
||||||
|
} else {
|
||||||
|
a.ops.Add(authkey)
|
||||||
|
}
|
||||||
|
logger.Debugf("Added to ops: %s (for %s)", authkey, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOp checks if a public key is an op.
|
||||||
|
func (a *Auth) IsOp(key ssh.PublicKey) bool {
|
||||||
|
if key == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
authkey := NewAuthKey(key)
|
||||||
|
return a.ops.In(authkey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Whitelist will set a public key as a whitelisted user.
|
||||||
|
func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
|
||||||
|
if key == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authkey := NewAuthKey(key)
|
||||||
|
if d != 0 {
|
||||||
|
a.whitelist.AddExpiring(authkey, d)
|
||||||
|
} else {
|
||||||
|
a.whitelist.Add(authkey)
|
||||||
|
}
|
||||||
|
logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ban will set a public key as banned.
|
||||||
|
func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
|
||||||
|
if key == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a.BanFingerprint(NewAuthKey(key), d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BanFingerprint will set a public key fingerprint as banned.
|
||||||
|
func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
|
||||||
|
if d != 0 {
|
||||||
|
a.banned.AddExpiring(authkey, d)
|
||||||
|
} else {
|
||||||
|
a.banned.Add(authkey)
|
||||||
|
}
|
||||||
|
logger.Debugf("Added to banned: %s (for %s)", authkey, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ban will set an IP address as banned.
|
||||||
|
func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
|
||||||
|
key := NewAuthAddr(addr)
|
||||||
|
if d != 0 {
|
||||||
|
a.bannedAddr.AddExpiring(key, d)
|
||||||
|
} else {
|
||||||
|
a.bannedAddr.Add(key)
|
||||||
|
}
|
||||||
|
logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
|
||||||
|
}
|
62
auth_test.go
Normal file
62
auth_test.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewRandomPublicKey(bits int) (ssh.PublicKey, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, bits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.NewPublicKey(key.Public())
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClonePublicKey(key ssh.PublicKey) (ssh.PublicKey, error) {
|
||||||
|
return ssh.ParsePublicKey(key.Marshal())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthWhitelist(t *testing.T) {
|
||||||
|
key, err := NewRandomPublicKey(512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := NewAuth()
|
||||||
|
ok, err := auth.Check(nil, key)
|
||||||
|
if !ok || err != nil {
|
||||||
|
t.Error("Failed to permit in default state:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.Whitelist(key, 0)
|
||||||
|
|
||||||
|
keyClone, err := ClonePublicKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(keyClone.Marshal()) != string(key.Marshal()) {
|
||||||
|
t.Error("Clone key does not match.")
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = auth.Check(nil, keyClone)
|
||||||
|
if !ok || err != nil {
|
||||||
|
t.Error("Failed to permit whitelisted:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key2, err := NewRandomPublicKey(512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = auth.Check(nil, key2)
|
||||||
|
if ok || err == nil {
|
||||||
|
t.Error("Failed to restrict not whitelisted:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
241
chat/command.go
Normal file
241
chat/command.go
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
// FIXME: Would be sweet if we could piggyback on a cli parser or something.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The error returned when an invalid command is issued.
|
||||||
|
var ErrInvalidCommand = errors.New("invalid command")
|
||||||
|
|
||||||
|
// The error returned when a command is given without an owner.
|
||||||
|
var ErrNoOwner = errors.New("command without owner")
|
||||||
|
|
||||||
|
// The error returned when a command is performed without the necessary number
|
||||||
|
// of arguments.
|
||||||
|
var ErrMissingArg = errors.New("missing argument")
|
||||||
|
|
||||||
|
// The error returned when a command is added without a prefix.
|
||||||
|
var ErrMissingPrefix = errors.New("command missing prefix")
|
||||||
|
|
||||||
|
// Command is a definition of a handler for a command.
|
||||||
|
type Command struct {
|
||||||
|
// The command's key, such as /foo
|
||||||
|
Prefix string
|
||||||
|
// Extra help regarding arguments
|
||||||
|
PrefixHelp string
|
||||||
|
// If omitted, command is hidden from /help
|
||||||
|
Help string
|
||||||
|
Handler func(*Room, CommandMsg) error
|
||||||
|
// Command requires Op permissions
|
||||||
|
Op bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commands is a registry of available commands.
|
||||||
|
type Commands map[string]*Command
|
||||||
|
|
||||||
|
// Add will register a command. If help string is empty, it will be hidden from
|
||||||
|
// Help().
|
||||||
|
func (c Commands) Add(cmd Command) error {
|
||||||
|
if cmd.Prefix == "" {
|
||||||
|
return ErrMissingPrefix
|
||||||
|
}
|
||||||
|
|
||||||
|
c[cmd.Prefix] = &cmd
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alias will add another command for the same handler, won't get added to help.
|
||||||
|
func (c Commands) Alias(command string, alias string) error {
|
||||||
|
cmd, ok := c[command]
|
||||||
|
if !ok {
|
||||||
|
return ErrInvalidCommand
|
||||||
|
}
|
||||||
|
c[alias] = cmd
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run executes a command message.
|
||||||
|
func (c Commands) Run(room *Room, msg CommandMsg) error {
|
||||||
|
if msg.From == nil {
|
||||||
|
return ErrNoOwner
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd, ok := c[msg.Command()]
|
||||||
|
if !ok {
|
||||||
|
return ErrInvalidCommand
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmd.Handler(room, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Help will return collated help text as one string.
|
||||||
|
func (c Commands) Help(showOp bool) string {
|
||||||
|
// Filter by op
|
||||||
|
op := []*Command{}
|
||||||
|
normal := []*Command{}
|
||||||
|
for _, cmd := range c {
|
||||||
|
if cmd.Op {
|
||||||
|
op = append(op, cmd)
|
||||||
|
} else {
|
||||||
|
normal = append(normal, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
help := "Available commands:" + Newline + NewCommandsHelp(normal).String()
|
||||||
|
if showOp {
|
||||||
|
help += Newline + "-> Operator commands:" + Newline + NewCommandsHelp(op).String()
|
||||||
|
}
|
||||||
|
return help
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultCommands *Commands
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
defaultCommands = &Commands{}
|
||||||
|
InitCommands(defaultCommands)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitCommands injects default commands into a Commands registry.
|
||||||
|
func InitCommands(c *Commands) {
|
||||||
|
c.Add(Command{
|
||||||
|
Prefix: "/help",
|
||||||
|
Handler: func(room *Room, msg CommandMsg) error {
|
||||||
|
op := room.IsOp(msg.From())
|
||||||
|
room.Send(NewSystemMsg(room.commands.Help(op), msg.From()))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(Command{
|
||||||
|
Prefix: "/me",
|
||||||
|
Handler: func(room *Room, msg CommandMsg) error {
|
||||||
|
me := strings.TrimLeft(msg.body, "/me")
|
||||||
|
if me == "" {
|
||||||
|
me = "is at a loss for words."
|
||||||
|
} else {
|
||||||
|
me = me[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
room.Send(NewEmoteMsg(me, msg.From()))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(Command{
|
||||||
|
Prefix: "/exit",
|
||||||
|
Help: "Exit the chat.",
|
||||||
|
Handler: func(room *Room, msg CommandMsg) error {
|
||||||
|
msg.From().Close()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Alias("/exit", "/quit")
|
||||||
|
|
||||||
|
c.Add(Command{
|
||||||
|
Prefix: "/nick",
|
||||||
|
PrefixHelp: "NAME",
|
||||||
|
Help: "Rename yourself.",
|
||||||
|
Handler: func(room *Room, msg CommandMsg) error {
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) != 1 {
|
||||||
|
return ErrMissingArg
|
||||||
|
}
|
||||||
|
u := msg.From()
|
||||||
|
|
||||||
|
member, ok := room.MemberById(u.Id())
|
||||||
|
if !ok {
|
||||||
|
return errors.New("failed to find member")
|
||||||
|
}
|
||||||
|
|
||||||
|
oldId := member.Id()
|
||||||
|
member.SetId(SanitizeName(args[0]))
|
||||||
|
err := room.Rename(oldId, member)
|
||||||
|
if err != nil {
|
||||||
|
member.SetId(oldId)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(Command{
|
||||||
|
Prefix: "/names",
|
||||||
|
Help: "List users who are connected.",
|
||||||
|
Handler: func(room *Room, msg CommandMsg) error {
|
||||||
|
// TODO: colorize
|
||||||
|
names := room.NamesPrefix("")
|
||||||
|
body := fmt.Sprintf("%d connected: %s", len(names), strings.Join(names, ", "))
|
||||||
|
room.Send(NewSystemMsg(body, msg.From()))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Alias("/names", "/list")
|
||||||
|
|
||||||
|
c.Add(Command{
|
||||||
|
Prefix: "/theme",
|
||||||
|
PrefixHelp: "[mono|colors]",
|
||||||
|
Help: "Set your color theme.",
|
||||||
|
Handler: func(room *Room, msg CommandMsg) error {
|
||||||
|
user := msg.From()
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) == 0 {
|
||||||
|
theme := "plain"
|
||||||
|
if user.Config.Theme != nil {
|
||||||
|
theme = user.Config.Theme.Id()
|
||||||
|
}
|
||||||
|
body := fmt.Sprintf("Current theme: %s", theme)
|
||||||
|
room.Send(NewSystemMsg(body, user))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
id := args[0]
|
||||||
|
for _, t := range Themes {
|
||||||
|
if t.Id() == id {
|
||||||
|
user.Config.Theme = &t
|
||||||
|
body := fmt.Sprintf("Set theme: %s", id)
|
||||||
|
room.Send(NewSystemMsg(body, user))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.New("theme not found")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(Command{
|
||||||
|
Prefix: "/quiet",
|
||||||
|
Help: "Silence room announcements.",
|
||||||
|
Handler: func(room *Room, msg CommandMsg) error {
|
||||||
|
u := msg.From()
|
||||||
|
u.ToggleQuietMode()
|
||||||
|
|
||||||
|
var body string
|
||||||
|
if u.Config.Quiet {
|
||||||
|
body = "Quiet mode is toggled ON"
|
||||||
|
} else {
|
||||||
|
body = "Quiet mode is toggled OFF"
|
||||||
|
}
|
||||||
|
room.Send(NewSystemMsg(body, u))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(Command{
|
||||||
|
Prefix: "/slap",
|
||||||
|
PrefixHelp: "NAME",
|
||||||
|
Handler: func(room *Room, msg CommandMsg) error {
|
||||||
|
var me string
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) == 0 {
|
||||||
|
me = "slaps themselves around a bit with a large trout."
|
||||||
|
} else {
|
||||||
|
me = fmt.Sprintf("slaps %s around a bit with a large trout.", strings.Join(args, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
room.Send(NewEmoteMsg(me, msg.From()))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
13
chat/doc.go
Normal file
13
chat/doc.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
/*
|
||||||
|
`chat` package is a server-agnostic implementation of a chat interface, built
|
||||||
|
with the intention of using with the intention of using as the backend for
|
||||||
|
ssh-chat.
|
||||||
|
|
||||||
|
This package should not know anything about sockets. It should expose io-style
|
||||||
|
interfaces and rooms for communicating with any method of transnport.
|
||||||
|
|
||||||
|
TODO: Add usage examples here.
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
package chat
|
58
chat/help.go
Normal file
58
chat/help.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type helpItem struct {
|
||||||
|
Prefix string
|
||||||
|
Text string
|
||||||
|
}
|
||||||
|
|
||||||
|
type help struct {
|
||||||
|
items []helpItem
|
||||||
|
prefixWidth int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCommandsHelp creates a help container from a commands container.
|
||||||
|
func NewCommandsHelp(c []*Command) fmt.Stringer {
|
||||||
|
lookup := map[string]struct{}{}
|
||||||
|
h := help{
|
||||||
|
items: []helpItem{},
|
||||||
|
}
|
||||||
|
for _, cmd := range c {
|
||||||
|
if cmd.Help == "" {
|
||||||
|
// Skip hidden commands.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, exists := lookup[cmd.Prefix]
|
||||||
|
if exists {
|
||||||
|
// Duplicate (alias)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lookup[cmd.Prefix] = struct{}{}
|
||||||
|
prefix := fmt.Sprintf("%s %s", cmd.Prefix, cmd.PrefixHelp)
|
||||||
|
h.add(helpItem{prefix, cmd.Help})
|
||||||
|
}
|
||||||
|
return &h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *help) add(item helpItem) {
|
||||||
|
h.items = append(h.items, item)
|
||||||
|
if len(item.Prefix) > h.prefixWidth {
|
||||||
|
h.prefixWidth = len(item.Prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h help) String() string {
|
||||||
|
r := []string{}
|
||||||
|
format := fmt.Sprintf("%%-%ds - %%s", h.prefixWidth)
|
||||||
|
for _, item := range h.items {
|
||||||
|
r = append(r, fmt.Sprintf(format, item.Prefix, item.Text))
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(r)
|
||||||
|
return strings.Join(r, Newline)
|
||||||
|
}
|
@ -1,27 +1,33 @@
|
|||||||
// TODO: Split this out into its own module, it's kinda neat.
|
package chat
|
||||||
package main
|
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const timestampFmt = "2006-01-02 15:04:05"
|
||||||
|
|
||||||
// History contains the history entries
|
// History contains the history entries
|
||||||
type History struct {
|
type History struct {
|
||||||
entries []string
|
sync.RWMutex
|
||||||
|
entries []Message
|
||||||
head int
|
head int
|
||||||
size int
|
size int
|
||||||
lock sync.Mutex
|
out io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHistory constructs a new history of the given size
|
// NewHistory constructs a new history of the given size
|
||||||
func NewHistory(size int) *History {
|
func NewHistory(size int) *History {
|
||||||
return &History{
|
return &History{
|
||||||
entries: make([]string, size),
|
entries: make([]Message, size),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add adds the given entry to the entries in the history
|
// Add adds the given entry to the entries in the history
|
||||||
func (h *History) Add(entry string) {
|
func (h *History) Add(entry Message) {
|
||||||
h.lock.Lock()
|
h.Lock()
|
||||||
defer h.lock.Unlock()
|
defer h.Unlock()
|
||||||
|
|
||||||
max := cap(h.entries)
|
max := cap(h.entries)
|
||||||
h.head = (h.head + 1) % max
|
h.head = (h.head + 1) % max
|
||||||
@ -29,6 +35,10 @@ func (h *History) Add(entry string) {
|
|||||||
if h.size < max {
|
if h.size < max {
|
||||||
h.size++
|
h.size++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.out != nil {
|
||||||
|
fmt.Fprintf(h.out, "[%s] %s\n", entry.Timestamp().UTC().Format(timestampFmt), entry.String())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns the number of entries in the history
|
// Len returns the number of entries in the history
|
||||||
@ -37,16 +47,16 @@ func (h *History) Len() int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the entry with the given number
|
// Get the entry with the given number
|
||||||
func (h *History) Get(num int) []string {
|
func (h *History) Get(num int) []Message {
|
||||||
h.lock.Lock()
|
h.RLock()
|
||||||
defer h.lock.Unlock()
|
defer h.RUnlock()
|
||||||
|
|
||||||
max := cap(h.entries)
|
max := cap(h.entries)
|
||||||
if num > h.size {
|
if num > h.size {
|
||||||
num = h.size
|
num = h.size
|
||||||
}
|
}
|
||||||
|
|
||||||
r := make([]string, num)
|
r := make([]Message, num)
|
||||||
for i := 0; i < num; i++ {
|
for i := 0; i < num; i++ {
|
||||||
idx := (h.head - i) % max
|
idx := (h.head - i) % max
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
@ -57,3 +67,10 @@ func (h *History) Get(num int) []string {
|
|||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetOutput sets the output for logging added messages
|
||||||
|
func (h *History) SetOutput(w io.Writer) {
|
||||||
|
h.Lock()
|
||||||
|
h.out = w
|
||||||
|
h.Unlock()
|
||||||
|
}
|
62
chat/history_test.go
Normal file
62
chat/history_test.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func msgEqual(a []Message, b []Message) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i].String() != b[i].String() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHistory(t *testing.T) {
|
||||||
|
var r, expected []Message
|
||||||
|
var size int
|
||||||
|
|
||||||
|
h := NewHistory(5)
|
||||||
|
|
||||||
|
r = h.Get(10)
|
||||||
|
expected = []Message{}
|
||||||
|
if !msgEqual(r, expected) {
|
||||||
|
t.Errorf("Got: %v, Expected: %v", r, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Add(NewMsg("1"))
|
||||||
|
|
||||||
|
if size = h.Len(); size != 1 {
|
||||||
|
t.Errorf("Wrong size: %v", size)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = h.Get(1)
|
||||||
|
expected = []Message{NewMsg("1")}
|
||||||
|
if !msgEqual(r, expected) {
|
||||||
|
t.Errorf("Got: %v, Expected: %v", r, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.Add(NewMsg("2"))
|
||||||
|
h.Add(NewMsg("3"))
|
||||||
|
h.Add(NewMsg("4"))
|
||||||
|
h.Add(NewMsg("5"))
|
||||||
|
h.Add(NewMsg("6"))
|
||||||
|
|
||||||
|
if size = h.Len(); size != 5 {
|
||||||
|
t.Errorf("Wrong size: %v", size)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = h.Get(2)
|
||||||
|
expected = []Message{NewMsg("5"), NewMsg("6")}
|
||||||
|
if !msgEqual(r, expected) {
|
||||||
|
t.Errorf("Got: %v, Expected: %v", r, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
r = h.Get(10)
|
||||||
|
expected = []Message{NewMsg("2"), NewMsg("3"), NewMsg("4"), NewMsg("5"), NewMsg("6")}
|
||||||
|
if !msgEqual(r, expected) {
|
||||||
|
t.Errorf("Got: %v, Expected: %v", r, expected)
|
||||||
|
}
|
||||||
|
}
|
22
chat/logger.go
Normal file
22
chat/logger.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
import stdlog "log"
|
||||||
|
|
||||||
|
var logger *stdlog.Logger
|
||||||
|
|
||||||
|
func SetLogger(w io.Writer) {
|
||||||
|
flags := stdlog.Flags()
|
||||||
|
prefix := "[chat] "
|
||||||
|
logger = stdlog.New(w, prefix, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
type nullWriter struct{}
|
||||||
|
|
||||||
|
func (nullWriter) Write(data []byte) (int, error) {
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
SetLogger(nullWriter{})
|
||||||
|
}
|
257
chat/message.go
Normal file
257
chat/message.go
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message is an interface for messages.
|
||||||
|
type Message interface {
|
||||||
|
Render(*Theme) string
|
||||||
|
String() string
|
||||||
|
Command() string
|
||||||
|
Timestamp() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageTo interface {
|
||||||
|
Message
|
||||||
|
To() *User
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageFrom interface {
|
||||||
|
Message
|
||||||
|
From() *User
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseInput(body string, from *User) Message {
|
||||||
|
m := NewPublicMsg(body, from)
|
||||||
|
cmd, isCmd := m.ParseCommand()
|
||||||
|
if isCmd {
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Msg is a base type for other message types.
|
||||||
|
type Msg struct {
|
||||||
|
body string
|
||||||
|
timestamp time.Time
|
||||||
|
// TODO: themeCache *map[*Theme]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMsg(body string) *Msg {
|
||||||
|
return &Msg{
|
||||||
|
body: body,
|
||||||
|
timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render message based on a theme.
|
||||||
|
func (m *Msg) Render(t *Theme) string {
|
||||||
|
// TODO: Render based on theme
|
||||||
|
// TODO: Cache based on theme
|
||||||
|
return m.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Msg) String() string {
|
||||||
|
return m.body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Msg) Command() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Msg) Timestamp() time.Time {
|
||||||
|
return m.timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublicMsg is any message from a user sent to the room.
|
||||||
|
type PublicMsg struct {
|
||||||
|
Msg
|
||||||
|
from *User
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPublicMsg(body string, from *User) *PublicMsg {
|
||||||
|
return &PublicMsg{
|
||||||
|
Msg: Msg{
|
||||||
|
body: body,
|
||||||
|
timestamp: time.Now(),
|
||||||
|
},
|
||||||
|
from: from,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *PublicMsg) From() *User {
|
||||||
|
return m.from
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *PublicMsg) ParseCommand() (*CommandMsg, bool) {
|
||||||
|
// Check if the message is a command
|
||||||
|
if !strings.HasPrefix(m.body, "/") {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse
|
||||||
|
// TODO: Handle quoted fields properly
|
||||||
|
fields := strings.Fields(m.body)
|
||||||
|
command, args := fields[0], fields[1:]
|
||||||
|
msg := CommandMsg{
|
||||||
|
PublicMsg: m,
|
||||||
|
command: command,
|
||||||
|
args: args,
|
||||||
|
}
|
||||||
|
return &msg, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *PublicMsg) Render(t *Theme) string {
|
||||||
|
if t == nil {
|
||||||
|
return m.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s: %s", t.ColorName(m.from), m.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *PublicMsg) RenderFor(cfg UserConfig) string {
|
||||||
|
if cfg.Highlight == nil || cfg.Theme == nil {
|
||||||
|
return m.Render(cfg.Theme)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Highlight.MatchString(m.body) {
|
||||||
|
return m.Render(cfg.Theme)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := cfg.Highlight.ReplaceAllString(m.body, cfg.Theme.Highlight("${1}"))
|
||||||
|
if cfg.Bell {
|
||||||
|
body += Bel
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s", cfg.Theme.ColorName(m.from), body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *PublicMsg) String() string {
|
||||||
|
return fmt.Sprintf("%s: %s", m.from.Name(), m.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmoteMsg is a /me message sent to the room. It specifically does not
|
||||||
|
// extend PublicMsg because it doesn't implement MessageFrom to allow the
|
||||||
|
// sender to see the emote.
|
||||||
|
type EmoteMsg struct {
|
||||||
|
Msg
|
||||||
|
from *User
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEmoteMsg(body string, from *User) *EmoteMsg {
|
||||||
|
return &EmoteMsg{
|
||||||
|
Msg: Msg{
|
||||||
|
body: body,
|
||||||
|
timestamp: time.Now(),
|
||||||
|
},
|
||||||
|
from: from,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *EmoteMsg) Render(t *Theme) string {
|
||||||
|
return fmt.Sprintf("** %s %s", m.from.Name(), m.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *EmoteMsg) String() string {
|
||||||
|
return m.Render(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrivateMsg is a message sent to another user, not shown to anyone else.
|
||||||
|
type PrivateMsg struct {
|
||||||
|
PublicMsg
|
||||||
|
to *User
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPrivateMsg(body string, from *User, to *User) *PrivateMsg {
|
||||||
|
return &PrivateMsg{
|
||||||
|
PublicMsg: *NewPublicMsg(body, from),
|
||||||
|
to: to,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *PrivateMsg) To() *User {
|
||||||
|
return m.to
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *PrivateMsg) Render(t *Theme) string {
|
||||||
|
return fmt.Sprintf("[PM from %s] %s", m.from.Name(), m.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *PrivateMsg) String() string {
|
||||||
|
return m.Render(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemMsg is a response sent from the server directly to a user, not shown
|
||||||
|
// to anyone else. Usually in response to something, like /help.
|
||||||
|
type SystemMsg struct {
|
||||||
|
Msg
|
||||||
|
to *User
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSystemMsg(body string, to *User) *SystemMsg {
|
||||||
|
return &SystemMsg{
|
||||||
|
Msg: Msg{
|
||||||
|
body: body,
|
||||||
|
timestamp: time.Now(),
|
||||||
|
},
|
||||||
|
to: to,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SystemMsg) Render(t *Theme) string {
|
||||||
|
if t == nil {
|
||||||
|
return m.String()
|
||||||
|
}
|
||||||
|
return t.ColorSys(m.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SystemMsg) String() string {
|
||||||
|
return fmt.Sprintf("-> %s", m.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SystemMsg) To() *User {
|
||||||
|
return m.to
|
||||||
|
}
|
||||||
|
|
||||||
|
// AnnounceMsg is a message sent from the server to everyone, like a join or
|
||||||
|
// leave event.
|
||||||
|
type AnnounceMsg struct {
|
||||||
|
Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAnnounceMsg(body string) *AnnounceMsg {
|
||||||
|
return &AnnounceMsg{
|
||||||
|
Msg: Msg{
|
||||||
|
body: body,
|
||||||
|
timestamp: time.Now(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AnnounceMsg) Render(t *Theme) string {
|
||||||
|
if t == nil {
|
||||||
|
return m.String()
|
||||||
|
}
|
||||||
|
return t.ColorSys(m.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AnnounceMsg) String() string {
|
||||||
|
return fmt.Sprintf(" * %s", m.body)
|
||||||
|
}
|
||||||
|
|
||||||
|
type CommandMsg struct {
|
||||||
|
*PublicMsg
|
||||||
|
command string
|
||||||
|
args []string
|
||||||
|
room *Room
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CommandMsg) Command() string {
|
||||||
|
return m.command
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CommandMsg) Args() []string {
|
||||||
|
return m.args
|
||||||
|
}
|
52
chat/message_test.go
Normal file
52
chat/message_test.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
type testId string
|
||||||
|
|
||||||
|
func (i testId) Id() string {
|
||||||
|
return string(i)
|
||||||
|
}
|
||||||
|
func (i testId) SetId(s string) {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
func (i testId) Name() string {
|
||||||
|
return i.Id()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessage(t *testing.T) {
|
||||||
|
var expected, actual string
|
||||||
|
|
||||||
|
expected = " * foo"
|
||||||
|
actual = NewAnnounceMsg("foo").String()
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
expected = "foo: hello"
|
||||||
|
actual = NewPublicMsg("hello", u).String()
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = "** foo sighs."
|
||||||
|
actual = NewEmoteMsg("sighs.", u).String()
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = "-> hello"
|
||||||
|
actual = NewSystemMsg("hello", u).String()
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = "[PM from foo] hello"
|
||||||
|
actual = NewPrivateMsg("hello", u, u).String()
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Add theme rendering tests
|
222
chat/room.go
Normal file
222
chat/room.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const historyLen = 20
|
||||||
|
const roomBuffer = 10
|
||||||
|
|
||||||
|
// The error returned when a message is sent to a room that is already
|
||||||
|
// closed.
|
||||||
|
var ErrRoomClosed = errors.New("room closed")
|
||||||
|
|
||||||
|
// The error returned when a user attempts to join with an invalid name, such
|
||||||
|
// as empty string.
|
||||||
|
var ErrInvalidName = errors.New("invalid name")
|
||||||
|
|
||||||
|
// Member is a User with per-Room metadata attached to it.
|
||||||
|
type Member struct {
|
||||||
|
*User
|
||||||
|
Op bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Room definition, also a Set of User Items
|
||||||
|
type Room struct {
|
||||||
|
topic string
|
||||||
|
history *History
|
||||||
|
members *Set
|
||||||
|
broadcast chan Message
|
||||||
|
commands Commands
|
||||||
|
closed bool
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoom creates a new room.
|
||||||
|
func NewRoom() *Room {
|
||||||
|
broadcast := make(chan Message, roomBuffer)
|
||||||
|
|
||||||
|
return &Room{
|
||||||
|
broadcast: broadcast,
|
||||||
|
history: NewHistory(historyLen),
|
||||||
|
members: NewSet(),
|
||||||
|
commands: *defaultCommands,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCommands sets the room's command handlers.
|
||||||
|
func (r *Room) SetCommands(commands Commands) {
|
||||||
|
r.commands = commands
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the room and all the users it contains.
|
||||||
|
func (r *Room) Close() {
|
||||||
|
r.closeOnce.Do(func() {
|
||||||
|
r.closed = true
|
||||||
|
r.members.Each(func(m Identifier) {
|
||||||
|
m.(*Member).Close()
|
||||||
|
})
|
||||||
|
r.members.Clear()
|
||||||
|
close(r.broadcast)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogging sets logging output for the room's history
|
||||||
|
func (r *Room) SetLogging(out io.Writer) {
|
||||||
|
r.history.SetOutput(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleMsg reacts to a message, will block until done.
|
||||||
|
func (r *Room) HandleMsg(m Message) {
|
||||||
|
switch m := m.(type) {
|
||||||
|
case *CommandMsg:
|
||||||
|
cmd := *m
|
||||||
|
err := r.commands.Run(r, cmd)
|
||||||
|
if err != nil {
|
||||||
|
m := NewSystemMsg(fmt.Sprintf("Err: %s", err), cmd.from)
|
||||||
|
go r.HandleMsg(m)
|
||||||
|
}
|
||||||
|
case MessageTo:
|
||||||
|
user := m.To()
|
||||||
|
user.Send(m)
|
||||||
|
default:
|
||||||
|
fromMsg, skip := m.(MessageFrom)
|
||||||
|
var skipUser *User
|
||||||
|
if skip {
|
||||||
|
skipUser = fromMsg.From()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.history.Add(m)
|
||||||
|
r.members.Each(func(u Identifier) {
|
||||||
|
user := u.(*Member).User
|
||||||
|
if skip && skipUser == user {
|
||||||
|
// Skip
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := m.(*AnnounceMsg); ok {
|
||||||
|
if user.Config.Quiet {
|
||||||
|
// Skip
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
user.Send(m)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve will consume the broadcast room and handle the messages, should be
|
||||||
|
// run in a goroutine.
|
||||||
|
func (r *Room) Serve() {
|
||||||
|
for m := range r.broadcast {
|
||||||
|
go r.HandleMsg(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send message, buffered by a chan.
|
||||||
|
func (r *Room) Send(m Message) {
|
||||||
|
r.broadcast <- m
|
||||||
|
}
|
||||||
|
|
||||||
|
// History feeds the room's recent message history to the user's handler.
|
||||||
|
func (r *Room) History(u *User) {
|
||||||
|
for _, m := range r.history.Get(historyLen) {
|
||||||
|
u.Send(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join the room as a user, will announce.
|
||||||
|
func (r *Room) Join(u *User) (*Member, error) {
|
||||||
|
if r.closed {
|
||||||
|
return nil, ErrRoomClosed
|
||||||
|
}
|
||||||
|
if u.Id() == "" {
|
||||||
|
return nil, ErrInvalidName
|
||||||
|
}
|
||||||
|
member := Member{u, false}
|
||||||
|
err := r.members.Add(&member)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r.History(u)
|
||||||
|
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.members.Len())
|
||||||
|
r.Send(NewAnnounceMsg(s))
|
||||||
|
return &member, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Leave the room as a user, will announce. Mostly used during setup.
|
||||||
|
func (r *Room) Leave(u *User) error {
|
||||||
|
err := r.members.Remove(u)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s := fmt.Sprintf("%s left.", u.Name())
|
||||||
|
r.Send(NewAnnounceMsg(s))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rename member with a new identity. This will not call rename on the member.
|
||||||
|
func (r *Room) Rename(oldId string, identity Identifier) error {
|
||||||
|
if identity.Id() == "" {
|
||||||
|
return ErrInvalidName
|
||||||
|
}
|
||||||
|
err := r.members.Replace(oldId, identity)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s := fmt.Sprintf("%s is now known as %s.", oldId, identity.Id())
|
||||||
|
r.Send(NewAnnounceMsg(s))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Member returns a corresponding Member object to a User if the Member is
|
||||||
|
// present in this room.
|
||||||
|
func (r *Room) Member(u *User) (*Member, bool) {
|
||||||
|
m, ok := r.MemberById(u.Id())
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// Check that it's the same user
|
||||||
|
if m.User != u {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return m, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Room) MemberById(id string) (*Member, bool) {
|
||||||
|
m, err := r.members.Get(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return m.(*Member), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOp returns whether a user is an operator in this room.
|
||||||
|
func (r *Room) IsOp(u *User) bool {
|
||||||
|
m, ok := r.Member(u)
|
||||||
|
return ok && m.Op
|
||||||
|
}
|
||||||
|
|
||||||
|
// Topic of the room.
|
||||||
|
func (r *Room) Topic() string {
|
||||||
|
return r.topic
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTopic will set the topic of the room.
|
||||||
|
func (r *Room) SetTopic(s string) {
|
||||||
|
r.topic = s
|
||||||
|
}
|
||||||
|
|
||||||
|
// NamesPrefix lists all members' names with a given prefix, used to query
|
||||||
|
// for autocompletion purposes.
|
||||||
|
func (r *Room) NamesPrefix(prefix string) []string {
|
||||||
|
members := r.members.ListPrefix(prefix)
|
||||||
|
names := make([]string, len(members))
|
||||||
|
for i, u := range members {
|
||||||
|
names[i] = u.(*Member).User.Name()
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
192
chat/room_test.go
Normal file
192
chat/room_test.go
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRoomServe(t *testing.T) {
|
||||||
|
ch := NewRoom()
|
||||||
|
ch.Send(NewAnnounceMsg("hello"))
|
||||||
|
|
||||||
|
received := <-ch.broadcast
|
||||||
|
actual := received.String()
|
||||||
|
expected := " * hello"
|
||||||
|
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomJoin(t *testing.T) {
|
||||||
|
var expected, actual []byte
|
||||||
|
|
||||||
|
s := &MockScreen{}
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
|
||||||
|
ch := NewRoom()
|
||||||
|
go ch.Serve()
|
||||||
|
defer ch.Close()
|
||||||
|
|
||||||
|
_, err := ch.Join(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u.ConsumeOne(s)
|
||||||
|
expected = []byte(" * foo joined. (Connected: 1)" + Newline)
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch.Send(NewSystemMsg("hello", u))
|
||||||
|
u.ConsumeOne(s)
|
||||||
|
expected = []byte("-> hello" + Newline)
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch.Send(ParseInput("/me says hello.", u))
|
||||||
|
u.ConsumeOne(s)
|
||||||
|
expected = []byte("** foo says hello." + Newline)
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) {
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
u.Config = UserConfig{
|
||||||
|
Quiet: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := NewRoom()
|
||||||
|
defer ch.Close()
|
||||||
|
|
||||||
|
_, err := ch.Join(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain the initial Join message
|
||||||
|
<-ch.broadcast
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for msg := range u.msg {
|
||||||
|
if _, ok := msg.(*AnnounceMsg); ok {
|
||||||
|
t.Errorf("Got unexpected `%T`", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Call with an AnnounceMsg and all the other types
|
||||||
|
// and assert we received only non-announce messages
|
||||||
|
ch.HandleMsg(NewAnnounceMsg("Ignored"))
|
||||||
|
// Assert we still get all other types of messages
|
||||||
|
ch.HandleMsg(NewEmoteMsg("hello", u))
|
||||||
|
ch.HandleMsg(NewSystemMsg("hello", u))
|
||||||
|
ch.HandleMsg(NewPrivateMsg("hello", u, u))
|
||||||
|
ch.HandleMsg(NewPublicMsg("hello", u))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomQuietToggleBroadcasts(t *testing.T) {
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
u.Config = UserConfig{
|
||||||
|
Quiet: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := NewRoom()
|
||||||
|
defer ch.Close()
|
||||||
|
|
||||||
|
_, err := ch.Join(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain the initial Join message
|
||||||
|
<-ch.broadcast
|
||||||
|
|
||||||
|
u.ToggleQuietMode()
|
||||||
|
|
||||||
|
expectedMsg := NewAnnounceMsg("Ignored")
|
||||||
|
ch.HandleMsg(expectedMsg)
|
||||||
|
msg := <-u.msg
|
||||||
|
if _, ok := msg.(*AnnounceMsg); !ok {
|
||||||
|
t.Errorf("Got: `%T`; Expected: `%T`", msg, expectedMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
u.ToggleQuietMode()
|
||||||
|
|
||||||
|
ch.HandleMsg(NewAnnounceMsg("Ignored"))
|
||||||
|
ch.HandleMsg(NewSystemMsg("hello", u))
|
||||||
|
msg = <-u.msg
|
||||||
|
if _, ok := msg.(*AnnounceMsg); ok {
|
||||||
|
t.Errorf("Got unexpected `%T`", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuietToggleDisplayState(t *testing.T) {
|
||||||
|
var expected, actual []byte
|
||||||
|
|
||||||
|
s := &MockScreen{}
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
|
||||||
|
ch := NewRoom()
|
||||||
|
go ch.Serve()
|
||||||
|
defer ch.Close()
|
||||||
|
|
||||||
|
_, err := ch.Join(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain the initial Join message
|
||||||
|
<-ch.broadcast
|
||||||
|
|
||||||
|
ch.Send(ParseInput("/quiet", u))
|
||||||
|
u.ConsumeOne(s)
|
||||||
|
expected = []byte("-> Quiet mode is toggled ON" + Newline)
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch.Send(ParseInput("/quiet", u))
|
||||||
|
u.ConsumeOne(s)
|
||||||
|
expected = []byte("-> Quiet mode is toggled OFF" + Newline)
|
||||||
|
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomNames(t *testing.T) {
|
||||||
|
var expected, actual []byte
|
||||||
|
|
||||||
|
s := &MockScreen{}
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
|
||||||
|
ch := NewRoom()
|
||||||
|
go ch.Serve()
|
||||||
|
defer ch.Close()
|
||||||
|
|
||||||
|
_, err := ch.Join(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain the initial Join message
|
||||||
|
<-ch.broadcast
|
||||||
|
|
||||||
|
ch.Send(ParseInput("/names", u))
|
||||||
|
u.ConsumeOne(s)
|
||||||
|
expected = []byte("-> 1 connected: foo" + Newline)
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
17
chat/sanitize.go
Normal file
17
chat/sanitize.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import "regexp"
|
||||||
|
|
||||||
|
var reStripName = regexp.MustCompile("[^\\w.-]")
|
||||||
|
|
||||||
|
// SanitizeName returns a name with only allowed characters.
|
||||||
|
func SanitizeName(s string) string {
|
||||||
|
return reStripName.ReplaceAllString(s, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
var reStripData = regexp.MustCompile("[^[:ascii:]]")
|
||||||
|
|
||||||
|
// SanitizeData returns a string with only allowed characters for client-provided metadata inputs.
|
||||||
|
func SanitizeData(s string) string {
|
||||||
|
return reStripData.ReplaceAllString(s, "")
|
||||||
|
}
|
51
chat/screen_test.go
Normal file
51
chat/screen_test.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Used for testing
|
||||||
|
type MockScreen struct {
|
||||||
|
buffer []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MockScreen) Write(data []byte) (n int, err error) {
|
||||||
|
s.buffer = append(s.buffer, data...)
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MockScreen) Read(p *[]byte) (n int, err error) {
|
||||||
|
*p = s.buffer
|
||||||
|
s.buffer = []byte{}
|
||||||
|
return len(*p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScreen(t *testing.T) {
|
||||||
|
var actual, expected []byte
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: %v; Expected: %v", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = []byte("foo")
|
||||||
|
expected = []byte("foo")
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: %v; Expected: %v", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &MockScreen{}
|
||||||
|
|
||||||
|
expected = nil
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: %v; Expected: %v", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = []byte("hello, world")
|
||||||
|
s.Write(expected)
|
||||||
|
s.Read(&actual)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: %v; Expected: %v", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
142
chat/set.go
Normal file
142
chat/set.go
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The error returned when an added id already exists in the set.
|
||||||
|
var ErrIdTaken = errors.New("id already taken")
|
||||||
|
|
||||||
|
// The error returned when a requested item does not exist in the set.
|
||||||
|
var ErrItemMissing = errors.New("item does not exist")
|
||||||
|
|
||||||
|
// Set with string lookup.
|
||||||
|
// TODO: Add trie for efficient prefix lookup?
|
||||||
|
type Set struct {
|
||||||
|
lookup map[string]Identifier
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSet creates a new set.
|
||||||
|
func NewSet() *Set {
|
||||||
|
return &Set{
|
||||||
|
lookup: map[string]Identifier{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all items and returns the number removed.
|
||||||
|
func (s *Set) Clear() int {
|
||||||
|
s.Lock()
|
||||||
|
n := len(s.lookup)
|
||||||
|
s.lookup = map[string]Identifier{}
|
||||||
|
s.Unlock()
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the size of the set right now.
|
||||||
|
func (s *Set) Len() int {
|
||||||
|
return len(s.lookup)
|
||||||
|
}
|
||||||
|
|
||||||
|
// In checks if an item exists in this set.
|
||||||
|
func (s *Set) In(item Identifier) bool {
|
||||||
|
s.RLock()
|
||||||
|
_, ok := s.lookup[item.Id()]
|
||||||
|
s.RUnlock()
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns an item with the given Id.
|
||||||
|
func (s *Set) Get(id string) (Identifier, error) {
|
||||||
|
s.RLock()
|
||||||
|
item, ok := s.lookup[id]
|
||||||
|
s.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrItemMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
return item, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add item to this set if it does not exist already.
|
||||||
|
func (s *Set) Add(item Identifier) error {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
_, found := s.lookup[item.Id()]
|
||||||
|
if found {
|
||||||
|
return ErrIdTaken
|
||||||
|
}
|
||||||
|
|
||||||
|
s.lookup[item.Id()] = item
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove item from this set.
|
||||||
|
func (s *Set) Remove(item Identifier) error {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
id := item.Id()
|
||||||
|
_, found := s.lookup[id]
|
||||||
|
if !found {
|
||||||
|
return ErrItemMissing
|
||||||
|
}
|
||||||
|
delete(s.lookup, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace item from old id with new Identifier.
|
||||||
|
// Used for moving the same identifier to a new Id, such as a rename.
|
||||||
|
func (s *Set) Replace(oldId string, item Identifier) error {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
// Check if it already exists
|
||||||
|
_, found := s.lookup[item.Id()]
|
||||||
|
if found {
|
||||||
|
return ErrIdTaken
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove oldId
|
||||||
|
_, found = s.lookup[oldId]
|
||||||
|
if !found {
|
||||||
|
return ErrItemMissing
|
||||||
|
}
|
||||||
|
delete(s.lookup, oldId)
|
||||||
|
|
||||||
|
// Add new identifier
|
||||||
|
s.lookup[item.Id()] = item
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each loops over every item while holding a read lock and applies fn to each
|
||||||
|
// element.
|
||||||
|
func (s *Set) Each(fn func(item Identifier)) {
|
||||||
|
s.RLock()
|
||||||
|
for _, item := range s.lookup {
|
||||||
|
fn(item)
|
||||||
|
}
|
||||||
|
s.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListPrefix returns a list of items with a prefix, case insensitive.
|
||||||
|
func (s *Set) ListPrefix(prefix string) []Identifier {
|
||||||
|
r := []Identifier{}
|
||||||
|
prefix = strings.ToLower(prefix)
|
||||||
|
|
||||||
|
s.RLock()
|
||||||
|
defer s.RUnlock()
|
||||||
|
|
||||||
|
for id, item := range s.lookup {
|
||||||
|
if !strings.HasPrefix(string(id), prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
r = append(r, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
38
chat/set_test.go
Normal file
38
chat/set_test.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSet(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
s := NewSet()
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
|
||||||
|
if s.In(u) {
|
||||||
|
t.Errorf("Set should be empty.")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.Add(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.In(u) {
|
||||||
|
t.Errorf("Set should contain user.")
|
||||||
|
}
|
||||||
|
|
||||||
|
u2 := NewUser(testId("bar"))
|
||||||
|
err = s.Add(u2)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.Add(u2)
|
||||||
|
if err != ErrIdTaken {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := s.Len()
|
||||||
|
if size != 2 {
|
||||||
|
t.Errorf("Set wrong size: %d (expected %d)", size, 2)
|
||||||
|
}
|
||||||
|
}
|
195
chat/theme.go
Normal file
195
chat/theme.go
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Reset resets the color
|
||||||
|
Reset = "\033[0m"
|
||||||
|
|
||||||
|
// Bold makes the following text bold
|
||||||
|
Bold = "\033[1m"
|
||||||
|
|
||||||
|
// Dim dims the following text
|
||||||
|
Dim = "\033[2m"
|
||||||
|
|
||||||
|
// Italic makes the following text italic
|
||||||
|
Italic = "\033[3m"
|
||||||
|
|
||||||
|
// Underline underlines the following text
|
||||||
|
Underline = "\033[4m"
|
||||||
|
|
||||||
|
// Blink blinks the following text
|
||||||
|
Blink = "\033[5m"
|
||||||
|
|
||||||
|
// Invert inverts the following text
|
||||||
|
Invert = "\033[7m"
|
||||||
|
|
||||||
|
// Newline
|
||||||
|
Newline = "\r\n"
|
||||||
|
|
||||||
|
// BEL
|
||||||
|
Bel = "\007"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Interface for Styles
|
||||||
|
type Style interface {
|
||||||
|
String() string
|
||||||
|
Format(string) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// General hardcoded style, mostly used as a crutch until we flesh out the
|
||||||
|
// framework to support backgrounds etc.
|
||||||
|
type style string
|
||||||
|
|
||||||
|
func (c style) String() string {
|
||||||
|
return string(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c style) Format(s string) string {
|
||||||
|
return c.String() + s + Reset
|
||||||
|
}
|
||||||
|
|
||||||
|
// 256 color type, for terminals who support it
|
||||||
|
type Color256 uint8
|
||||||
|
|
||||||
|
// String version of this color
|
||||||
|
func (c Color256) String() string {
|
||||||
|
return fmt.Sprintf("38;05;%d", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return formatted string with this color
|
||||||
|
func (c Color256) Format(s string) string {
|
||||||
|
return "\033[" + c.String() + "m" + s + Reset
|
||||||
|
}
|
||||||
|
|
||||||
|
// No color, used for mono theme
|
||||||
|
type Color0 struct{}
|
||||||
|
|
||||||
|
// No-op for Color0
|
||||||
|
func (c Color0) String() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// No-op for Color0
|
||||||
|
func (c Color0) Format(s string) string {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Container for a collection of colors
|
||||||
|
type Palette struct {
|
||||||
|
colors []Style
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get a color by index, overflows are looped around.
|
||||||
|
func (p Palette) Get(i int) Style {
|
||||||
|
return p.colors[i%(p.size-1)]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Palette) Len() int {
|
||||||
|
return p.size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Palette) String() string {
|
||||||
|
r := ""
|
||||||
|
for _, c := range p.colors {
|
||||||
|
r += c.Format("X")
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collection of settings for chat
|
||||||
|
type Theme struct {
|
||||||
|
id string
|
||||||
|
sys Style
|
||||||
|
pm Style
|
||||||
|
highlight Style
|
||||||
|
names *Palette
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Theme) Id() string {
|
||||||
|
return t.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// Colorize name string given some index
|
||||||
|
func (t Theme) ColorName(u *User) string {
|
||||||
|
if t.names == nil {
|
||||||
|
return u.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.names.Get(u.colorIdx).Format(u.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Colorize the PM string
|
||||||
|
func (t Theme) ColorPM(s string) string {
|
||||||
|
if t.pm == nil {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.pm.Format(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Colorize the Sys message
|
||||||
|
func (t Theme) ColorSys(s string) string {
|
||||||
|
if t.sys == nil {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.sys.Format(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Highlight a matched string, usually name
|
||||||
|
func (t Theme) Highlight(s string) string {
|
||||||
|
if t.highlight == nil {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return t.highlight.Format(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List of initialzied themes
|
||||||
|
var Themes []Theme
|
||||||
|
|
||||||
|
// Default theme to use
|
||||||
|
var DefaultTheme *Theme
|
||||||
|
|
||||||
|
func readableColors256() *Palette {
|
||||||
|
size := 247
|
||||||
|
p := Palette{
|
||||||
|
colors: make([]Style, size),
|
||||||
|
size: size,
|
||||||
|
}
|
||||||
|
j := 0
|
||||||
|
for i := 0; i < 256; i++ {
|
||||||
|
if (16 <= i && i <= 18) || (232 <= i && i <= 237) {
|
||||||
|
// Remove the ones near black, this is kinda sadpanda.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p.colors[j] = Color256(i)
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
return &p
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
palette := readableColors256()
|
||||||
|
|
||||||
|
Themes = []Theme{
|
||||||
|
Theme{
|
||||||
|
id: "colors",
|
||||||
|
names: palette,
|
||||||
|
sys: palette.Get(8), // Grey
|
||||||
|
pm: palette.Get(7), // White
|
||||||
|
highlight: style(Bold + "\033[48;5;11m\033[38;5;16m"), // Yellow highlight
|
||||||
|
},
|
||||||
|
Theme{
|
||||||
|
id: "mono",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug for printing colors:
|
||||||
|
//for _, color := range palette.colors {
|
||||||
|
// fmt.Print(color.Format(color.String() + " "))
|
||||||
|
//}
|
||||||
|
|
||||||
|
DefaultTheme = &Themes[0]
|
||||||
|
}
|
71
chat/theme_test.go
Normal file
71
chat/theme_test.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestThemePalette(t *testing.T) {
|
||||||
|
var expected, actual string
|
||||||
|
|
||||||
|
palette := readableColors256()
|
||||||
|
color := palette.Get(5)
|
||||||
|
if color == nil {
|
||||||
|
t.Fatal("Failed to return a color from palette.")
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = color.String()
|
||||||
|
expected = "38;05;5"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = color.Format("foo")
|
||||||
|
expected = "\033[38;05;5mfoo\033[0m"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = palette.Get(palette.Len() + 1).String()
|
||||||
|
expected = fmt.Sprintf("38;05;%d", 2)
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTheme(t *testing.T) {
|
||||||
|
var expected, actual string
|
||||||
|
|
||||||
|
colorTheme := Themes[0]
|
||||||
|
color := colorTheme.sys
|
||||||
|
if color == nil {
|
||||||
|
t.Fatal("Sys color should not be empty for first theme.")
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = color.Format("foo")
|
||||||
|
expected = "\033[38;05;8mfoo\033[0m"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = colorTheme.ColorSys("foo")
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
u.colorIdx = 4
|
||||||
|
actual = colorTheme.ColorName(u)
|
||||||
|
expected = "\033[38;05;4mfoo\033[0m"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := NewPublicMsg("hello", u)
|
||||||
|
actual = msg.Render(&colorTheme)
|
||||||
|
expected = "\033[38;05;4mfoo\033[0m: hello"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
178
chat/user.go
Normal file
178
chat/user.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"regexp"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const messageBuffer = 20
|
||||||
|
const reHighlight = `\b(%s)\b`
|
||||||
|
|
||||||
|
var ErrUserClosed = errors.New("user closed")
|
||||||
|
|
||||||
|
// Identifier is an interface that can uniquely identify itself.
|
||||||
|
type Identifier interface {
|
||||||
|
Id() string
|
||||||
|
SetId(string)
|
||||||
|
Name() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// User definition, implemented set Item interface and io.Writer
|
||||||
|
type User struct {
|
||||||
|
Identifier
|
||||||
|
Config UserConfig
|
||||||
|
colorIdx int
|
||||||
|
joined time.Time
|
||||||
|
msg chan Message
|
||||||
|
done chan struct{}
|
||||||
|
replyTo *User // Set when user gets a /msg, for replying.
|
||||||
|
closed bool
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUser(identity Identifier) *User {
|
||||||
|
u := User{
|
||||||
|
Identifier: identity,
|
||||||
|
Config: *DefaultUserConfig,
|
||||||
|
joined: time.Now(),
|
||||||
|
msg: make(chan Message, messageBuffer),
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
u.SetColorIdx(rand.Int())
|
||||||
|
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserScreen(identity Identifier, screen io.Writer) *User {
|
||||||
|
u := NewUser(identity)
|
||||||
|
go u.Consume(screen)
|
||||||
|
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rename the user with a new Identifier.
|
||||||
|
func (u *User) SetId(id string) {
|
||||||
|
u.Identifier.SetId(id)
|
||||||
|
u.SetColorIdx(rand.Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplyTo returns the last user that messaged this user.
|
||||||
|
func (u *User) ReplyTo() *User {
|
||||||
|
return u.replyTo
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReplyTo sets the last user to message this user.
|
||||||
|
func (u *User) SetReplyTo(user *User) {
|
||||||
|
u.replyTo = user
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToggleQuietMode will toggle whether or not quiet mode is enabled
|
||||||
|
func (u *User) ToggleQuietMode() {
|
||||||
|
u.Config.Quiet = !u.Config.Quiet
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetColorIdx will set the colorIdx to a specific value, primarily used for
|
||||||
|
// testing.
|
||||||
|
func (u *User) SetColorIdx(idx int) {
|
||||||
|
u.colorIdx = idx
|
||||||
|
}
|
||||||
|
|
||||||
|
// Block until user is closed
|
||||||
|
func (u *User) Wait() {
|
||||||
|
<-u.done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect user, stop accepting messages
|
||||||
|
func (u *User) Close() {
|
||||||
|
u.closeOnce.Do(func() {
|
||||||
|
u.closed = true
|
||||||
|
close(u.done)
|
||||||
|
close(u.msg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume message buffer into an io.Writer. Will block, should be called in a
|
||||||
|
// goroutine.
|
||||||
|
// TODO: Not sure if this is a great API.
|
||||||
|
func (u *User) Consume(out io.Writer) {
|
||||||
|
for m := range u.msg {
|
||||||
|
u.HandleMsg(m, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume one message and stop, mostly for testing
|
||||||
|
func (u *User) ConsumeOne(out io.Writer) {
|
||||||
|
u.HandleMsg(<-u.msg, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHighlight sets the highlighting regular expression to match string.
|
||||||
|
func (u *User) SetHighlight(s string) error {
|
||||||
|
re, err := regexp.Compile(fmt.Sprintf(reHighlight, s))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
u.Config.Highlight = re
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) render(m Message) string {
|
||||||
|
switch m := m.(type) {
|
||||||
|
case *PublicMsg:
|
||||||
|
return m.RenderFor(u.Config) + Newline
|
||||||
|
case *PrivateMsg:
|
||||||
|
u.SetReplyTo(m.From())
|
||||||
|
return m.Render(u.Config.Theme) + Newline
|
||||||
|
default:
|
||||||
|
return m.Render(u.Config.Theme) + Newline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) HandleMsg(m Message, out io.Writer) {
|
||||||
|
r := u.render(m)
|
||||||
|
_, err := out.Write([]byte(r))
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("Write failed to %s, closing: %s", u.Name(), err)
|
||||||
|
u.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add message to consume by user
|
||||||
|
func (u *User) Send(m Message) error {
|
||||||
|
if u.closed {
|
||||||
|
return ErrUserClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case u.msg <- m:
|
||||||
|
default:
|
||||||
|
logger.Printf("Msg buffer full, closing: %s", u.Name())
|
||||||
|
u.Close()
|
||||||
|
return ErrUserClosed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Container for per-user configurations.
|
||||||
|
type UserConfig struct {
|
||||||
|
Highlight *regexp.Regexp
|
||||||
|
Bell bool
|
||||||
|
Quiet bool
|
||||||
|
Theme *Theme
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default user configuration to use
|
||||||
|
var DefaultUserConfig *UserConfig
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
DefaultUserConfig = &UserConfig{
|
||||||
|
Bell: true,
|
||||||
|
Quiet: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Seed random?
|
||||||
|
}
|
24
chat/user_test.go
Normal file
24
chat/user_test.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMakeUser(t *testing.T) {
|
||||||
|
var actual, expected []byte
|
||||||
|
|
||||||
|
s := &MockScreen{}
|
||||||
|
u := NewUser(testId("foo"))
|
||||||
|
m := NewAnnounceMsg("hello")
|
||||||
|
|
||||||
|
defer u.Close()
|
||||||
|
u.Send(m)
|
||||||
|
u.ConsumeOne(s)
|
||||||
|
|
||||||
|
s.Read(&actual)
|
||||||
|
expected = []byte(m.String() + Newline)
|
||||||
|
if !reflect.DeepEqual(actual, expected) {
|
||||||
|
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
547
client.go
547
client.go
@ -1,547 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"golang.org/x/crypto/ssh/terminal"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// MsgBuffer is the length of the message buffer
|
|
||||||
MsgBuffer int = 20
|
|
||||||
|
|
||||||
// MaxMsgLength is the maximum length of a message
|
|
||||||
MaxMsgLength int = 1024
|
|
||||||
|
|
||||||
// MaxNamesList is the max number of items to return in a /names command
|
|
||||||
MaxNamesList int = 20
|
|
||||||
|
|
||||||
// HelpText is the text returned by /help
|
|
||||||
HelpText string = `Available commands:
|
|
||||||
/about - About this chat.
|
|
||||||
/exit - Exit the chat.
|
|
||||||
/help - Show this help text.
|
|
||||||
/list - List the users that are currently connected.
|
|
||||||
/beep - Enable BEL notifications on mention.
|
|
||||||
/me $ACTION - Show yourself doing an action.
|
|
||||||
/nick $NAME - Rename yourself to a new name.
|
|
||||||
/whois $NAME - Display information about another connected user.
|
|
||||||
/msg $NAME $MESSAGE - Sends a private message to a user.
|
|
||||||
/motd - Prints the Message of the Day.
|
|
||||||
/theme [color|mono] - Set client theme.`
|
|
||||||
|
|
||||||
// OpHelpText is the additional text returned by /help if the client is an Op
|
|
||||||
OpHelpText string = `Available operator commands:
|
|
||||||
/ban $NAME - Banish a user from the chat
|
|
||||||
/unban $FINGERPRINT - Unban a fingerprint
|
|
||||||
/banned - List all banned fingerprints
|
|
||||||
/kick $NAME - Kick em' out.
|
|
||||||
/op $NAME - Promote a user to server operator.
|
|
||||||
/silence $NAME - Revoke a user's ability to speak.
|
|
||||||
/shutdown $MESSAGE - Broadcast message and shutdown server.
|
|
||||||
/motd $MESSAGE - Set message shown whenever somebody joins.
|
|
||||||
/whitelist $FINGERPRINT - Add fingerprint to whitelist, prevent anyone else from joining.
|
|
||||||
/whitelist github.com/$USER - Add github user's pubkeys to whitelist.`
|
|
||||||
|
|
||||||
// AboutText is the text returned by /about
|
|
||||||
AboutText string = `ssh-chat is made by @shazow.
|
|
||||||
|
|
||||||
It is a custom ssh server built in Go to serve a chat experience
|
|
||||||
instead of a shell.
|
|
||||||
|
|
||||||
Source: https://github.com/shazow/ssh-chat
|
|
||||||
|
|
||||||
For more, visit shazow.net or follow at twitter.com/shazow`
|
|
||||||
|
|
||||||
// RequiredWait is the time a client is required to wait between messages
|
|
||||||
RequiredWait time.Duration = time.Second / 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// Client holds all the fields used by the client
|
|
||||||
type Client struct {
|
|
||||||
Server *Server
|
|
||||||
Conn *ssh.ServerConn
|
|
||||||
Msg chan string
|
|
||||||
Name string
|
|
||||||
Color string
|
|
||||||
Op bool
|
|
||||||
ready chan struct{}
|
|
||||||
term *terminal.Terminal
|
|
||||||
termWidth int
|
|
||||||
termHeight int
|
|
||||||
silencedUntil time.Time
|
|
||||||
lastTX time.Time
|
|
||||||
beepMe bool
|
|
||||||
colorMe bool
|
|
||||||
closed bool
|
|
||||||
sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient constructs a new client
|
|
||||||
func NewClient(server *Server, conn *ssh.ServerConn) *Client {
|
|
||||||
return &Client{
|
|
||||||
Server: server,
|
|
||||||
Conn: conn,
|
|
||||||
Name: conn.User(),
|
|
||||||
Color: RandomColor256(),
|
|
||||||
Msg: make(chan string, MsgBuffer),
|
|
||||||
ready: make(chan struct{}, 1),
|
|
||||||
lastTX: time.Now(),
|
|
||||||
colorMe: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColoredName returns the client name in its color
|
|
||||||
func (c *Client) ColoredName() string {
|
|
||||||
return ColorString(c.Color, c.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SysMsg sends a message in continuous format over the message channel
|
|
||||||
func (c *Client) SysMsg(msg string, args ...interface{}) {
|
|
||||||
c.Send(ContinuousFormat(systemMessageFormat, "-> "+fmt.Sprintf(msg, args...)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes the given message
|
|
||||||
func (c *Client) Write(msg string) {
|
|
||||||
if !c.colorMe {
|
|
||||||
msg = DeColorString(msg)
|
|
||||||
}
|
|
||||||
c.term.Write([]byte(msg + "\r\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteLines writes multiple messages
|
|
||||||
func (c *Client) WriteLines(msg []string) {
|
|
||||||
for _, line := range msg {
|
|
||||||
c.Write(line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send sends the given message
|
|
||||||
func (c *Client) Send(msg string) {
|
|
||||||
if len(msg) > MaxMsgLength || c.closed {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case c.Msg <- msg:
|
|
||||||
default:
|
|
||||||
logger.Errorf("Msg buffer full, dropping: %s (%s)", c.Name, c.Conn.RemoteAddr())
|
|
||||||
c.Conn.Conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendLines sends multiple messages
|
|
||||||
func (c *Client) SendLines(msg []string) {
|
|
||||||
for _, line := range msg {
|
|
||||||
c.Send(line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsSilenced checks if the client is silenced
|
|
||||||
func (c *Client) IsSilenced() bool {
|
|
||||||
return c.silencedUntil.After(time.Now())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Silence silences a client for the given duration
|
|
||||||
func (c *Client) Silence(d time.Duration) {
|
|
||||||
c.silencedUntil = time.Now().Add(d)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resize resizes the client to the given width and height
|
|
||||||
func (c *Client) Resize(width, height int) error {
|
|
||||||
width = 1000000 // TODO: Remove this dirty workaround for text overflow once ssh/terminal is fixed
|
|
||||||
err := c.term.SetSize(width, height)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Resize failed: %dx%d", width, height)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.termWidth, c.termHeight = width, height
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rename renames the client to the given name
|
|
||||||
func (c *Client) Rename(name string) {
|
|
||||||
c.Name = name
|
|
||||||
var prompt string
|
|
||||||
|
|
||||||
if c.colorMe {
|
|
||||||
prompt = c.ColoredName()
|
|
||||||
} else {
|
|
||||||
prompt = c.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
c.term.SetPrompt(fmt.Sprintf("[%s] ", prompt))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fingerprint returns the fingerprint
|
|
||||||
func (c *Client) Fingerprint() string {
|
|
||||||
if c.Conn.Permissions == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return c.Conn.Permissions.Extensions["fingerprint"]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emote formats and sends an emote
|
|
||||||
func (c *Client) Emote(message string) {
|
|
||||||
formatted := fmt.Sprintf("** %s%s", c.ColoredName(), message)
|
|
||||||
if c.IsSilenced() || len(message) > 1000 {
|
|
||||||
c.SysMsg("Message rejected")
|
|
||||||
}
|
|
||||||
c.Server.Broadcast(formatted, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleShell(channel ssh.Channel) {
|
|
||||||
defer channel.Close()
|
|
||||||
|
|
||||||
// FIXME: This shouldn't live here, need to restructure the call chaining.
|
|
||||||
c.Server.Add(c)
|
|
||||||
go func() {
|
|
||||||
// Block until done, then remove.
|
|
||||||
c.Conn.Wait()
|
|
||||||
c.closed = true
|
|
||||||
c.Server.Remove(c)
|
|
||||||
close(c.Msg)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for msg := range c.Msg {
|
|
||||||
c.Write(msg)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
line, err := c.term.ReadLine()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.SplitN(line, " ", 3)
|
|
||||||
isCmd := strings.HasPrefix(parts[0], "/")
|
|
||||||
|
|
||||||
if isCmd {
|
|
||||||
// TODO: Factor this out.
|
|
||||||
switch parts[0] {
|
|
||||||
case "/test-colors": // Shh, this command is a secret!
|
|
||||||
c.Write(ColorString("32", "Lorem ipsum dolor sit amet,"))
|
|
||||||
c.Write("consectetur " + ColorString("31;1", "adipiscing") + " elit.")
|
|
||||||
case "/exit":
|
|
||||||
channel.Close()
|
|
||||||
case "/help":
|
|
||||||
c.SysMsg(strings.Replace(HelpText, "\n", "\r\n", -1))
|
|
||||||
if c.Server.IsOp(c) {
|
|
||||||
c.SysMsg(strings.Replace(OpHelpText, "\n", "\r\n", -1))
|
|
||||||
}
|
|
||||||
case "/about":
|
|
||||||
c.SysMsg(strings.Replace(AboutText, "\n", "\r\n", -1))
|
|
||||||
case "/uptime":
|
|
||||||
c.SysMsg(c.Server.Uptime())
|
|
||||||
case "/beep":
|
|
||||||
c.beepMe = !c.beepMe
|
|
||||||
if c.beepMe {
|
|
||||||
c.SysMsg("I'll beep you good.")
|
|
||||||
} else {
|
|
||||||
c.SysMsg("No more beeps. :(")
|
|
||||||
}
|
|
||||||
case "/me":
|
|
||||||
me := strings.TrimLeft(line, "/me")
|
|
||||||
if me == "" {
|
|
||||||
me = " is at a loss for words."
|
|
||||||
}
|
|
||||||
c.Emote(me)
|
|
||||||
case "/slap":
|
|
||||||
slappee := "themself"
|
|
||||||
if len(parts) > 1 {
|
|
||||||
slappee = parts[1]
|
|
||||||
if len(parts[1]) > 100 {
|
|
||||||
slappee = "some long-named jerk"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Emote(fmt.Sprintf(" slaps %s around a bit with a large trout.", slappee))
|
|
||||||
case "/nick":
|
|
||||||
if len(parts) == 2 {
|
|
||||||
c.Server.Rename(c, parts[1])
|
|
||||||
} else {
|
|
||||||
c.SysMsg("Missing $NAME from: /nick $NAME")
|
|
||||||
}
|
|
||||||
case "/whois":
|
|
||||||
if len(parts) >= 2 {
|
|
||||||
client := c.Server.Who(parts[1])
|
|
||||||
if client != nil {
|
|
||||||
version := reStripText.ReplaceAllString(string(client.Conn.ClientVersion()), "")
|
|
||||||
if len(version) > 100 {
|
|
||||||
version = "Evil Jerk with a superlong string"
|
|
||||||
}
|
|
||||||
c.SysMsg("%s is %s via %s", client.ColoredName(), client.Fingerprint(), version)
|
|
||||||
} else {
|
|
||||||
c.SysMsg("No such name: %s", parts[1])
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.SysMsg("Missing $NAME from: /whois $NAME")
|
|
||||||
}
|
|
||||||
case "/names", "/list":
|
|
||||||
coloredNames := []string{}
|
|
||||||
for _, name := range c.Server.List(nil) {
|
|
||||||
coloredNames = append(coloredNames, c.Server.Who(name).ColoredName())
|
|
||||||
}
|
|
||||||
num := len(coloredNames)
|
|
||||||
if len(coloredNames) > MaxNamesList {
|
|
||||||
others := fmt.Sprintf("and %d others.", len(coloredNames)-MaxNamesList)
|
|
||||||
coloredNames = coloredNames[:MaxNamesList]
|
|
||||||
coloredNames = append(coloredNames, others)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SysMsg("%d connected: %s", num, strings.Join(coloredNames, systemMessageFormat+", "))
|
|
||||||
case "/ban":
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.SysMsg("You're not an admin.")
|
|
||||||
} else if len(parts) != 2 {
|
|
||||||
c.SysMsg("Missing $NAME from: /ban $NAME")
|
|
||||||
} else {
|
|
||||||
client := c.Server.Who(parts[1])
|
|
||||||
if client == nil {
|
|
||||||
c.SysMsg("No such name: %s", parts[1])
|
|
||||||
} else {
|
|
||||||
fingerprint := client.Fingerprint()
|
|
||||||
client.SysMsg("Banned by %s.", c.ColoredName())
|
|
||||||
c.Server.Ban(fingerprint, nil)
|
|
||||||
client.Conn.Close()
|
|
||||||
c.Server.Broadcast(fmt.Sprintf("* %s was banned by %s", parts[1], c.ColoredName()), nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "/unban":
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.SysMsg("You're not an admin.")
|
|
||||||
} else if len(parts) != 2 {
|
|
||||||
c.SysMsg("Missing $FINGERPRINT from: /unban $FINGERPRINT")
|
|
||||||
} else {
|
|
||||||
fingerprint := parts[1]
|
|
||||||
isBanned := c.Server.IsBanned(fingerprint)
|
|
||||||
if !isBanned {
|
|
||||||
c.SysMsg("No such banned fingerprint: %s", fingerprint)
|
|
||||||
} else {
|
|
||||||
c.Server.Unban(fingerprint)
|
|
||||||
c.Server.Broadcast(fmt.Sprintf("* %s was unbanned by %s", fingerprint, c.ColoredName()), nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "/banned":
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.SysMsg("You're not an admin.")
|
|
||||||
} else if len(parts) != 1 {
|
|
||||||
c.SysMsg("Too many arguments for /banned")
|
|
||||||
} else {
|
|
||||||
for fingerprint := range c.Server.bannedPK {
|
|
||||||
c.SysMsg("Banned fingerprint: %s", fingerprint)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "/op":
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.SysMsg("You're not an admin.")
|
|
||||||
} else if len(parts) != 2 {
|
|
||||||
c.SysMsg("Missing $NAME from: /op $NAME")
|
|
||||||
} else {
|
|
||||||
client := c.Server.Who(parts[1])
|
|
||||||
if client == nil {
|
|
||||||
c.SysMsg("No such name: %s", parts[1])
|
|
||||||
} else {
|
|
||||||
fingerprint := client.Fingerprint()
|
|
||||||
if fingerprint == "" {
|
|
||||||
c.SysMsg("Cannot op user without fingerprint.")
|
|
||||||
} else {
|
|
||||||
client.SysMsg("Made op by %s.", c.ColoredName())
|
|
||||||
c.Server.Op(fingerprint)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "/kick":
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.SysMsg("You're not an admin.")
|
|
||||||
} else if len(parts) != 2 {
|
|
||||||
c.SysMsg("Missing $NAME from: /kick $NAME")
|
|
||||||
} else {
|
|
||||||
client := c.Server.Who(parts[1])
|
|
||||||
if client == nil {
|
|
||||||
c.SysMsg("No such name: %s", parts[1])
|
|
||||||
} else {
|
|
||||||
client.SysMsg("Kicked by %s.", c.ColoredName())
|
|
||||||
client.Conn.Close()
|
|
||||||
c.Server.Broadcast(fmt.Sprintf("* %s was kicked by %s", parts[1], c.ColoredName()), nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "/silence":
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.SysMsg("You're not an admin.")
|
|
||||||
} else if len(parts) < 2 {
|
|
||||||
c.SysMsg("Missing $NAME from: /silence $NAME")
|
|
||||||
} else {
|
|
||||||
duration := time.Duration(5) * time.Minute
|
|
||||||
if len(parts) >= 3 {
|
|
||||||
parsedDuration, err := time.ParseDuration(parts[2])
|
|
||||||
if err == nil {
|
|
||||||
duration = parsedDuration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
client := c.Server.Who(parts[1])
|
|
||||||
if client == nil {
|
|
||||||
c.SysMsg("No such name: %s", parts[1])
|
|
||||||
} else {
|
|
||||||
client.Silence(duration)
|
|
||||||
client.SysMsg("Silenced for %s by %s.", duration, c.ColoredName())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "/shutdown":
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.SysMsg("You're not an admin.")
|
|
||||||
} else {
|
|
||||||
var split = strings.SplitN(line, " ", 2)
|
|
||||||
var msg string
|
|
||||||
if len(split) > 1 {
|
|
||||||
msg = split[1]
|
|
||||||
} else {
|
|
||||||
msg = ""
|
|
||||||
}
|
|
||||||
// Shutdown after 5 seconds
|
|
||||||
go func() {
|
|
||||||
c.Server.Broadcast(ColorString("31", msg), nil)
|
|
||||||
time.Sleep(time.Second * 5)
|
|
||||||
c.Server.Stop()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
case "/msg": /* Send a PM */
|
|
||||||
/* Make sure we have a recipient and a message */
|
|
||||||
if len(parts) < 2 {
|
|
||||||
c.SysMsg("Missing $NAME from: /msg $NAME $MESSAGE")
|
|
||||||
break
|
|
||||||
} else if len(parts) < 3 {
|
|
||||||
c.SysMsg("Missing $MESSAGE from: /msg $NAME $MESSAGE")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
/* Ask the server to send the message */
|
|
||||||
if err := c.Server.Privmsg(parts[1], parts[2], c); nil != err {
|
|
||||||
c.SysMsg("Unable to send message to %v: %v", parts[1], err)
|
|
||||||
}
|
|
||||||
case "/motd": /* print motd */
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.Server.MotdUnicast(c)
|
|
||||||
} else if len(parts) < 2 {
|
|
||||||
c.Server.MotdUnicast(c)
|
|
||||||
} else {
|
|
||||||
var newmotd string
|
|
||||||
if len(parts) == 2 {
|
|
||||||
newmotd = parts[1]
|
|
||||||
} else {
|
|
||||||
newmotd = parts[1] + " " + parts[2]
|
|
||||||
}
|
|
||||||
c.Server.SetMotd(newmotd)
|
|
||||||
c.Server.MotdBroadcast(c)
|
|
||||||
}
|
|
||||||
case "/theme":
|
|
||||||
if len(parts) < 2 {
|
|
||||||
c.SysMsg("Missing $THEME from: /theme $THEME")
|
|
||||||
c.SysMsg("Choose either color or mono")
|
|
||||||
} else {
|
|
||||||
// Sets colorMe attribute of client
|
|
||||||
if parts[1] == "mono" {
|
|
||||||
c.colorMe = false
|
|
||||||
} else if parts[1] == "color" {
|
|
||||||
c.colorMe = true
|
|
||||||
}
|
|
||||||
// Rename to reset prompt
|
|
||||||
c.Rename(c.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "/whitelist": /* whitelist a fingerprint */
|
|
||||||
if !c.Server.IsOp(c) {
|
|
||||||
c.SysMsg("You're not an admin.")
|
|
||||||
} else if len(parts) != 2 {
|
|
||||||
c.SysMsg("Missing $FINGERPRINT from: /whitelist $FINGERPRINT")
|
|
||||||
} else {
|
|
||||||
fingerprint := parts[1]
|
|
||||||
go func() {
|
|
||||||
err = c.Server.Whitelist(fingerprint)
|
|
||||||
if err != nil {
|
|
||||||
c.SysMsg("Error adding to whitelist: %s", err)
|
|
||||||
} else {
|
|
||||||
c.SysMsg("Added %s to the whitelist", fingerprint)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
case "/version":
|
|
||||||
c.SysMsg("Version " + buildCommit)
|
|
||||||
|
|
||||||
default:
|
|
||||||
c.SysMsg("Invalid command: %s", line)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := fmt.Sprintf("%s: %s", c.ColoredName(), line)
|
|
||||||
/* Rate limit */
|
|
||||||
if time.Now().Sub(c.lastTX) < RequiredWait {
|
|
||||||
c.SysMsg("Rate limiting in effect.")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c.IsSilenced() || len(msg) > 1000 || len(line) < 1 {
|
|
||||||
c.SysMsg("Message rejected.")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
c.Server.Broadcast(msg, c)
|
|
||||||
c.lastTX = time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleChannels(channels <-chan ssh.NewChannel) {
|
|
||||||
prompt := fmt.Sprintf("[%s] ", c.ColoredName())
|
|
||||||
|
|
||||||
hasShell := false
|
|
||||||
|
|
||||||
for ch := range channels {
|
|
||||||
if t := ch.ChannelType(); t != "session" {
|
|
||||||
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, requests, err := ch.Accept()
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Could not accept channel: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
defer channel.Close()
|
|
||||||
|
|
||||||
c.term = terminal.NewTerminal(channel, prompt)
|
|
||||||
c.term.AutoCompleteCallback = c.Server.AutoCompleteFunction
|
|
||||||
|
|
||||||
for req := range requests {
|
|
||||||
var width, height int
|
|
||||||
var ok bool
|
|
||||||
|
|
||||||
switch req.Type {
|
|
||||||
case "shell":
|
|
||||||
if c.term != nil && !hasShell {
|
|
||||||
go c.handleShell(channel)
|
|
||||||
ok = true
|
|
||||||
hasShell = true
|
|
||||||
}
|
|
||||||
case "pty-req":
|
|
||||||
width, height, ok = parsePtyRequest(req.Payload)
|
|
||||||
if ok {
|
|
||||||
err := c.Resize(width, height)
|
|
||||||
ok = err == nil
|
|
||||||
}
|
|
||||||
case "window-change":
|
|
||||||
width, height, ok = parseWinchRequest(req.Payload)
|
|
||||||
if ok {
|
|
||||||
err := c.Resize(width, height)
|
|
||||||
ok = err == nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.WantReply {
|
|
||||||
req.Reply(ok, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
136
cmd.go
136
cmd.go
@ -13,6 +13,10 @@ import (
|
|||||||
"github.com/alexcesaro/log"
|
"github.com/alexcesaro/log"
|
||||||
"github.com/alexcesaro/log/golog"
|
"github.com/alexcesaro/log/golog"
|
||||||
"github.com/jessevdk/go-flags"
|
"github.com/jessevdk/go-flags"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/shazow/ssh-chat/chat"
|
||||||
|
"github.com/shazow/ssh-chat/sshd"
|
||||||
)
|
)
|
||||||
import _ "net/http/pprof"
|
import _ "net/http/pprof"
|
||||||
|
|
||||||
@ -20,10 +24,11 @@ import _ "net/http/pprof"
|
|||||||
type Options struct {
|
type Options struct {
|
||||||
Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."`
|
Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."`
|
||||||
Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"`
|
Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"`
|
||||||
Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:22"`
|
Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:2022"`
|
||||||
Admin []string `long:"admin" description:"Fingerprint of pubkey to mark as admin."`
|
Admin string `long:"admin" description:"File of public keys who are admins."`
|
||||||
Whitelist string `long:"whitelist" description:"Optional file of pubkey fingerprints who are allowed to connect."`
|
Whitelist string `long:"whitelist" description:"Optional file of public keys who are allowed to connect."`
|
||||||
Motd string `long:"motd" description:"Optional Message of the Day file."`
|
Motd string `long:"motd" description:"Optional Message of the Day file."`
|
||||||
|
Log string `long:"log" description:"Write chat log to this file."`
|
||||||
Pprof int `long:"pprof" description:"Enable pprof http server for profiling."`
|
Pprof int `long:"pprof" description:"Enable pprof http server for profiling."`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,6 +39,7 @@ var logLevels = []log.Level{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var buildCommit string
|
var buildCommit string
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
options := Options{}
|
options := Options{}
|
||||||
parser := flags.NewParser(&options, flags.Default)
|
parser := flags.NewParser(&options, flags.Default)
|
||||||
@ -42,6 +48,7 @@ func main() {
|
|||||||
if p == nil {
|
if p == nil {
|
||||||
fmt.Print(err)
|
fmt.Print(err)
|
||||||
}
|
}
|
||||||
|
os.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,54 +58,84 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize seed for random colors
|
|
||||||
RandomColorInit()
|
|
||||||
|
|
||||||
// Figure out the log level
|
// Figure out the log level
|
||||||
numVerbose := len(options.Verbose)
|
numVerbose := len(options.Verbose)
|
||||||
if numVerbose > len(logLevels) {
|
if numVerbose > len(logLevels) {
|
||||||
numVerbose = len(logLevels)
|
numVerbose = len(logLevels) - 1
|
||||||
}
|
}
|
||||||
|
|
||||||
logLevel := logLevels[numVerbose]
|
logLevel := logLevels[numVerbose]
|
||||||
logger = golog.New(os.Stderr, logLevel)
|
logger = golog.New(os.Stderr, logLevel)
|
||||||
|
|
||||||
|
if logLevel == log.Debug {
|
||||||
|
// Enable logging from submodules
|
||||||
|
chat.SetLogger(os.Stderr)
|
||||||
|
sshd.SetLogger(os.Stderr)
|
||||||
|
}
|
||||||
|
|
||||||
privateKeyPath := options.Identity
|
privateKeyPath := options.Identity
|
||||||
if strings.HasPrefix(privateKeyPath, "~") {
|
if strings.HasPrefix(privateKeyPath, "~/") {
|
||||||
user, err := user.Current()
|
user, err := user.Current()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
privateKeyPath = strings.Replace(privateKeyPath, "~", user.HomeDir, 1)
|
privateKeyPath = strings.Replace(privateKeyPath, "~", user.HomeDir, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err := ioutil.ReadFile(privateKeyPath)
|
privateKey, err := ReadPrivateKey(privateKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to load identity: %v", err)
|
logger.Errorf("Couldn't read private key: %v", err)
|
||||||
return
|
os.Exit(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
server, err := NewServer(privateKey)
|
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to create server: %v", err)
|
logger.Errorf("Failed to parse key: %v", err)
|
||||||
return
|
os.Exit(3)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, fingerprint := range options.Admin {
|
auth := NewAuth()
|
||||||
server.Op(fingerprint)
|
config := sshd.MakeAuth(auth)
|
||||||
}
|
config.AddHostKey(signer)
|
||||||
|
|
||||||
if options.Whitelist != "" {
|
s, err := sshd.ListenSSH(options.Bind, config)
|
||||||
file, err := os.Open(options.Whitelist)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Could not open whitelist file")
|
logger.Errorf("Failed to listen on socket: %v", err)
|
||||||
return
|
os.Exit(4)
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer s.Close()
|
||||||
|
s.RateLimit = true
|
||||||
|
|
||||||
scanner := bufio.NewScanner(file)
|
fmt.Printf("Listening for connections on %v\n", s.Addr().String())
|
||||||
for scanner.Scan() {
|
|
||||||
server.Whitelist(scanner.Text())
|
host := NewHost(s)
|
||||||
|
host.auth = auth
|
||||||
|
host.theme = &chat.Themes[0]
|
||||||
|
|
||||||
|
err = fromFile(options.Admin, func(line []byte) error {
|
||||||
|
key, _, _, _, err := ssh.ParseAuthorizedKey(line)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
auth.Op(key, 0)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to load admins: %v", err)
|
||||||
|
os.Exit(5)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fromFile(options.Whitelist, func(line []byte) error {
|
||||||
|
key, _, _, _, err := ssh.ParseAuthorizedKey(line)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
auth.Whitelist(key, 0)
|
||||||
|
logger.Debugf("Whitelisted: %s", line)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to load whitelist: %v", err)
|
||||||
|
os.Exit(5)
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.Motd != "" {
|
if options.Motd != "" {
|
||||||
@ -107,24 +144,53 @@ func main() {
|
|||||||
logger.Errorf("Failed to load MOTD file: %v", err)
|
logger.Errorf("Failed to load MOTD file: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
motdString := string(motd[:])
|
motdString := strings.TrimSpace(string(motd))
|
||||||
/* hack to normalize line endings into \r\n */
|
// hack to normalize line endings into \r\n
|
||||||
motdString = strings.Replace(motdString, "\r\n", "\n", -1)
|
motdString = strings.Replace(motdString, "\r\n", "\n", -1)
|
||||||
motdString = strings.Replace(motdString, "\n", "\r\n", -1)
|
motdString = strings.Replace(motdString, "\n", "\r\n", -1)
|
||||||
server.SetMotd(motdString)
|
host.SetMotd(motdString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if options.Log == "-" {
|
||||||
|
host.SetLogging(os.Stdout)
|
||||||
|
} else if options.Log != "" {
|
||||||
|
fp, err := os.OpenFile(options.Log, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to open log file for writing: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
host.SetLogging(fp)
|
||||||
|
}
|
||||||
|
|
||||||
|
go host.Serve()
|
||||||
|
|
||||||
// Construct interrupt handler
|
// Construct interrupt handler
|
||||||
sig := make(chan os.Signal, 1)
|
sig := make(chan os.Signal, 1)
|
||||||
signal.Notify(sig, os.Interrupt)
|
signal.Notify(sig, os.Interrupt)
|
||||||
|
|
||||||
err = server.Start(options.Bind)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Failed to start server: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
<-sig // Wait for ^C signal
|
<-sig // Wait for ^C signal
|
||||||
logger.Warningf("Interrupt signal detected, shutting down.")
|
logger.Warningf("Interrupt signal detected, shutting down.")
|
||||||
server.Stop()
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromFile(path string, handler func(line []byte) error) error {
|
||||||
|
if path == "" {
|
||||||
|
// Skip
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(file)
|
||||||
|
for scanner.Scan() {
|
||||||
|
err := handler(scanner.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
82
colors.go
82
colors.go
@ -1,82 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Reset resets the color
|
|
||||||
Reset = "\033[0m"
|
|
||||||
|
|
||||||
// Bold makes the following text bold
|
|
||||||
Bold = "\033[1m"
|
|
||||||
|
|
||||||
// Dim dims the following text
|
|
||||||
Dim = "\033[2m"
|
|
||||||
|
|
||||||
// Italic makes the following text italic
|
|
||||||
Italic = "\033[3m"
|
|
||||||
|
|
||||||
// Underline underlines the following text
|
|
||||||
Underline = "\033[4m"
|
|
||||||
|
|
||||||
// Blink blinks the following text
|
|
||||||
Blink = "\033[5m"
|
|
||||||
|
|
||||||
// Invert inverts the following text
|
|
||||||
Invert = "\033[7m"
|
|
||||||
)
|
|
||||||
|
|
||||||
var colors = []string{"31", "32", "33", "34", "35", "36", "37", "91", "92", "93", "94", "95", "96", "97"}
|
|
||||||
|
|
||||||
// deColor is used for removing ANSI Escapes
|
|
||||||
var deColor = regexp.MustCompile("\033\\[[\\d;]+m")
|
|
||||||
|
|
||||||
// DeColorString removes all color from the given string
|
|
||||||
func DeColorString(s string) string {
|
|
||||||
s = deColor.ReplaceAllString(s, "")
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomReadableColor() int {
|
|
||||||
for {
|
|
||||||
i := rand.Intn(256)
|
|
||||||
if (16 <= i && i <= 18) || (232 <= i && i <= 237) {
|
|
||||||
// Remove the ones near black, this is kinda sadpanda.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RandomColor256 returns a random (of 256) color
|
|
||||||
func RandomColor256() string {
|
|
||||||
return fmt.Sprintf("38;05;%d", randomReadableColor())
|
|
||||||
}
|
|
||||||
|
|
||||||
// RandomColor returns a random color
|
|
||||||
func RandomColor() string {
|
|
||||||
return colors[rand.Intn(len(colors))]
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColorString returns a message in the given color
|
|
||||||
func ColorString(color string, msg string) string {
|
|
||||||
return Bold + "\033[" + color + "m" + msg + Reset
|
|
||||||
}
|
|
||||||
|
|
||||||
// RandomColorInit initializes the random seed
|
|
||||||
func RandomColorInit() {
|
|
||||||
rand.Seed(time.Now().UTC().UnixNano())
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContinuousFormat is a horrible hack to "continue" the previous string color
|
|
||||||
// and format after a RESET has been encountered.
|
|
||||||
//
|
|
||||||
// This is not HTML where you can just do a </style> to resume your previous formatting!
|
|
||||||
func ContinuousFormat(format string, str string) string {
|
|
||||||
return systemMessageFormat + strings.Replace(str, Reset, format, -1) + Reset
|
|
||||||
}
|
|
@ -1,53 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHistory(t *testing.T) {
|
|
||||||
var r, expected []string
|
|
||||||
var size int
|
|
||||||
|
|
||||||
h := NewHistory(5)
|
|
||||||
|
|
||||||
r = h.Get(10)
|
|
||||||
expected = []string{}
|
|
||||||
if !reflect.DeepEqual(r, expected) {
|
|
||||||
t.Errorf("Got: %v, Expected: %v", r, expected)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.Add("1")
|
|
||||||
|
|
||||||
if size = h.Len(); size != 1 {
|
|
||||||
t.Errorf("Wrong size: %v", size)
|
|
||||||
}
|
|
||||||
|
|
||||||
r = h.Get(1)
|
|
||||||
expected = []string{"1"}
|
|
||||||
if !reflect.DeepEqual(r, expected) {
|
|
||||||
t.Errorf("Got: %v, Expected: %v", r, expected)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.Add("2")
|
|
||||||
h.Add("3")
|
|
||||||
h.Add("4")
|
|
||||||
h.Add("5")
|
|
||||||
h.Add("6")
|
|
||||||
|
|
||||||
if size = h.Len(); size != 5 {
|
|
||||||
t.Errorf("Wrong size: %v", size)
|
|
||||||
}
|
|
||||||
|
|
||||||
r = h.Get(2)
|
|
||||||
expected = []string{"5", "6"}
|
|
||||||
if !reflect.DeepEqual(r, expected) {
|
|
||||||
t.Errorf("Got: %v, Expected: %v", r, expected)
|
|
||||||
}
|
|
||||||
|
|
||||||
r = h.Get(10)
|
|
||||||
expected = []string{"2", "3", "4", "5", "6"}
|
|
||||||
if !reflect.DeepEqual(r, expected) {
|
|
||||||
t.Errorf("Got: %v, Expected: %v", r, expected)
|
|
||||||
}
|
|
||||||
}
|
|
461
host.go
Normal file
461
host.go
Normal file
@ -0,0 +1,461 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shazow/rateio"
|
||||||
|
"github.com/shazow/ssh-chat/chat"
|
||||||
|
"github.com/shazow/ssh-chat/sshd"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxInputLength int = 1024
|
||||||
|
|
||||||
|
// GetPrompt will render the terminal prompt string based on the user.
|
||||||
|
func GetPrompt(user *chat.User) string {
|
||||||
|
name := user.Name()
|
||||||
|
if user.Config.Theme != nil {
|
||||||
|
name = user.Config.Theme.ColorName(user)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[%s] ", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host is the bridge between sshd and chat modules
|
||||||
|
// TODO: Should be easy to add support for multiple rooms, if we want.
|
||||||
|
type Host struct {
|
||||||
|
*chat.Room
|
||||||
|
listener *sshd.SSHListener
|
||||||
|
commands chat.Commands
|
||||||
|
|
||||||
|
motd string
|
||||||
|
auth *Auth
|
||||||
|
count int
|
||||||
|
|
||||||
|
// Default theme
|
||||||
|
theme *chat.Theme
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHost creates a Host on top of an existing listener.
|
||||||
|
func NewHost(listener *sshd.SSHListener) *Host {
|
||||||
|
room := chat.NewRoom()
|
||||||
|
h := Host{
|
||||||
|
Room: room,
|
||||||
|
listener: listener,
|
||||||
|
commands: chat.Commands{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make our own commands registry instance.
|
||||||
|
chat.InitCommands(&h.commands)
|
||||||
|
h.InitCommands(&h.commands)
|
||||||
|
room.SetCommands(h.commands)
|
||||||
|
|
||||||
|
go room.Serve()
|
||||||
|
return &h
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMotd sets the host's message of the day.
|
||||||
|
func (h *Host) SetMotd(motd string) {
|
||||||
|
h.motd = motd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Host) isOp(conn sshd.Connection) bool {
|
||||||
|
key := conn.PublicKey()
|
||||||
|
if key == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return h.auth.IsOp(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect a specific Terminal to this host and its room.
|
||||||
|
func (h *Host) Connect(term *sshd.Terminal) {
|
||||||
|
id := NewIdentity(term.Conn)
|
||||||
|
user := chat.NewUserScreen(id, term)
|
||||||
|
user.Config.Theme = h.theme
|
||||||
|
go func() {
|
||||||
|
// Close term once user is closed.
|
||||||
|
user.Wait()
|
||||||
|
term.Close()
|
||||||
|
}()
|
||||||
|
defer user.Close()
|
||||||
|
|
||||||
|
// Send MOTD
|
||||||
|
if h.motd != "" {
|
||||||
|
user.Send(chat.NewAnnounceMsg(h.motd))
|
||||||
|
}
|
||||||
|
|
||||||
|
member, err := h.Join(user)
|
||||||
|
if err != nil {
|
||||||
|
// Try again...
|
||||||
|
id.SetName(fmt.Sprintf("Guest%d", h.count))
|
||||||
|
member, err = h.Join(user)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to join: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successfully joined.
|
||||||
|
term.SetPrompt(GetPrompt(user))
|
||||||
|
term.AutoCompleteCallback = h.AutoCompleteFunction(user)
|
||||||
|
user.SetHighlight(user.Name())
|
||||||
|
h.count++
|
||||||
|
|
||||||
|
// Should the user be op'd on join?
|
||||||
|
member.Op = h.isOp(term.Conn)
|
||||||
|
ratelimit := rateio.NewSimpleLimiter(3, time.Second*3)
|
||||||
|
|
||||||
|
for {
|
||||||
|
line, err := term.ReadLine()
|
||||||
|
if err == io.EOF {
|
||||||
|
// Closed
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
logger.Errorf("Terminal reading error: %s", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ratelimit.Count(1)
|
||||||
|
if err != nil {
|
||||||
|
user.Send(chat.NewSystemMsg("Message rejected: Rate limiting is in effect.", user))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(line) > maxInputLength {
|
||||||
|
user.Send(chat.NewSystemMsg("Message rejected: Input too long.", user))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if line == "" {
|
||||||
|
// Silently ignore empty lines.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
m := chat.ParseInput(line, user)
|
||||||
|
|
||||||
|
// FIXME: Any reason to use h.room.Send(m) instead?
|
||||||
|
h.HandleMsg(m)
|
||||||
|
|
||||||
|
cmd := m.Command()
|
||||||
|
if cmd == "/nick" || cmd == "/theme" {
|
||||||
|
// Hijack /nick command to update terminal synchronously. Wouldn't
|
||||||
|
// work if we use h.room.Send(m) above.
|
||||||
|
//
|
||||||
|
// FIXME: This is hacky, how do we improve the API to allow for
|
||||||
|
// this? Chat module shouldn't know about terminals.
|
||||||
|
term.SetPrompt(GetPrompt(user))
|
||||||
|
user.SetHighlight(user.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.Leave(user)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to leave: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve our chat room onto the listener
|
||||||
|
func (h *Host) Serve() {
|
||||||
|
terminals := h.listener.ServeTerminal()
|
||||||
|
|
||||||
|
for term := range terminals {
|
||||||
|
go h.Connect(term)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Host) completeName(partial string) string {
|
||||||
|
names := h.NamesPrefix(partial)
|
||||||
|
if len(names) == 0 {
|
||||||
|
// Didn't find anything
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return names[len(names)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Host) completeCommand(partial string) string {
|
||||||
|
for cmd, _ := range h.commands {
|
||||||
|
if strings.HasPrefix(cmd, partial) {
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// AutoCompleteFunction returns a callback for terminal autocompletion
|
||||||
|
func (h *Host) AutoCompleteFunction(u *chat.User) func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
|
||||||
|
return func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
|
||||||
|
if key != 9 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(line[:pos], " ") {
|
||||||
|
// Don't autocomplete spaces.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(line[:pos])
|
||||||
|
isFirst := len(fields) < 2
|
||||||
|
partial := fields[len(fields)-1]
|
||||||
|
posPartial := pos - len(partial)
|
||||||
|
|
||||||
|
var completed string
|
||||||
|
if isFirst && strings.HasPrefix(partial, "/") {
|
||||||
|
// Command
|
||||||
|
completed = h.completeCommand(partial)
|
||||||
|
if completed == "/reply" {
|
||||||
|
replyTo := u.ReplyTo()
|
||||||
|
if replyTo != nil {
|
||||||
|
completed = "/msg " + replyTo.Name()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Name
|
||||||
|
completed = h.completeName(partial)
|
||||||
|
if completed == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if isFirst {
|
||||||
|
completed += ":"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
completed += " "
|
||||||
|
|
||||||
|
// Reposition the cursor
|
||||||
|
newLine = strings.Replace(line[posPartial:], partial, completed, 1)
|
||||||
|
newLine = line[:posPartial] + newLine
|
||||||
|
newPos = pos + (len(completed) - len(partial))
|
||||||
|
ok = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUser returns a chat.User based on a name.
|
||||||
|
func (h *Host) GetUser(name string) (*chat.User, bool) {
|
||||||
|
m, ok := h.MemberById(name)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return m.User, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitCommands adds host-specific commands to a Commands container. These will
|
||||||
|
// override any existing commands.
|
||||||
|
func (h *Host) InitCommands(c *chat.Commands) {
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Prefix: "/msg",
|
||||||
|
PrefixHelp: "USER MESSAGE",
|
||||||
|
Help: "Send MESSAGE to USER.",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
args := msg.Args()
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
return errors.New("must specify user")
|
||||||
|
case 1:
|
||||||
|
return errors.New("must specify message")
|
||||||
|
}
|
||||||
|
|
||||||
|
target, ok := h.GetUser(args[0])
|
||||||
|
if !ok {
|
||||||
|
return errors.New("user not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
m := chat.NewPrivateMsg(strings.Join(args[1:], " "), msg.From(), target)
|
||||||
|
room.Send(m)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Prefix: "/reply",
|
||||||
|
PrefixHelp: "MESSAGE",
|
||||||
|
Help: "Reply with MESSAGE to the previous private message.",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
args := msg.Args()
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
return errors.New("must specify message")
|
||||||
|
}
|
||||||
|
|
||||||
|
target := msg.From().ReplyTo()
|
||||||
|
if target == nil {
|
||||||
|
return errors.New("no message to reply to")
|
||||||
|
}
|
||||||
|
|
||||||
|
m := chat.NewPrivateMsg(strings.Join(args, " "), msg.From(), target)
|
||||||
|
room.Send(m)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Prefix: "/whois",
|
||||||
|
PrefixHelp: "USER",
|
||||||
|
Help: "Information about USER.",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) == 0 {
|
||||||
|
return errors.New("must specify user")
|
||||||
|
}
|
||||||
|
|
||||||
|
target, ok := h.GetUser(args[0])
|
||||||
|
if !ok {
|
||||||
|
return errors.New("user not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
id := target.Identifier.(*Identity)
|
||||||
|
room.Send(chat.NewSystemMsg(id.Whois(), msg.From()))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hidden commands
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Prefix: "/version",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
room.Send(chat.NewSystemMsg(buildCommit, msg.From()))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
timeStarted := time.Now()
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Prefix: "/uptime",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
room.Send(chat.NewSystemMsg(time.Now().Sub(timeStarted).String(), msg.From()))
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Op commands
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Op: true,
|
||||||
|
Prefix: "/kick",
|
||||||
|
PrefixHelp: "USER",
|
||||||
|
Help: "Kick USER from the server.",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
if !room.IsOp(msg.From()) {
|
||||||
|
return errors.New("must be op")
|
||||||
|
}
|
||||||
|
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) == 0 {
|
||||||
|
return errors.New("must specify user")
|
||||||
|
}
|
||||||
|
|
||||||
|
target, ok := h.GetUser(args[0])
|
||||||
|
if !ok {
|
||||||
|
return errors.New("user not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
body := fmt.Sprintf("%s was kicked by %s.", target.Name(), msg.From().Name())
|
||||||
|
room.Send(chat.NewAnnounceMsg(body))
|
||||||
|
target.Close()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Op: true,
|
||||||
|
Prefix: "/ban",
|
||||||
|
PrefixHelp: "USER [DURATION]",
|
||||||
|
Help: "Ban USER from the server.",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
// TODO: Would be nice to specify what to ban. Key? Ip? etc.
|
||||||
|
if !room.IsOp(msg.From()) {
|
||||||
|
return errors.New("must be op")
|
||||||
|
}
|
||||||
|
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) == 0 {
|
||||||
|
return errors.New("must specify user")
|
||||||
|
}
|
||||||
|
|
||||||
|
target, ok := h.GetUser(args[0])
|
||||||
|
if !ok {
|
||||||
|
return errors.New("user not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var until time.Duration = 0
|
||||||
|
if len(args) > 1 {
|
||||||
|
until, _ = time.ParseDuration(args[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
id := target.Identifier.(*Identity)
|
||||||
|
h.auth.Ban(id.PublicKey(), until)
|
||||||
|
h.auth.BanAddr(id.RemoteAddr(), until)
|
||||||
|
|
||||||
|
body := fmt.Sprintf("%s was banned by %s.", target.Name(), msg.From().Name())
|
||||||
|
room.Send(chat.NewAnnounceMsg(body))
|
||||||
|
target.Close()
|
||||||
|
|
||||||
|
logger.Debugf("Banned: \n-> %s", id.Whois())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Op: true,
|
||||||
|
Prefix: "/motd",
|
||||||
|
PrefixHelp: "MESSAGE",
|
||||||
|
Help: "Set the MESSAGE of the day.",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
if !room.IsOp(msg.From()) {
|
||||||
|
return errors.New("must be op")
|
||||||
|
}
|
||||||
|
|
||||||
|
motd := ""
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) > 0 {
|
||||||
|
motd = strings.Join(args, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.motd = motd
|
||||||
|
body := fmt.Sprintf("New message of the day set by %s:", msg.From().Name())
|
||||||
|
room.Send(chat.NewAnnounceMsg(body))
|
||||||
|
if motd != "" {
|
||||||
|
room.Send(chat.NewAnnounceMsg(motd))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Add(chat.Command{
|
||||||
|
Op: true,
|
||||||
|
Prefix: "/op",
|
||||||
|
PrefixHelp: "USER [DURATION]",
|
||||||
|
Help: "Set USER as admin.",
|
||||||
|
Handler: func(room *chat.Room, msg chat.CommandMsg) error {
|
||||||
|
if !room.IsOp(msg.From()) {
|
||||||
|
return errors.New("must be op")
|
||||||
|
}
|
||||||
|
|
||||||
|
args := msg.Args()
|
||||||
|
if len(args) == 0 {
|
||||||
|
return errors.New("must specify user")
|
||||||
|
}
|
||||||
|
|
||||||
|
var until time.Duration = 0
|
||||||
|
if len(args) > 1 {
|
||||||
|
until, _ = time.ParseDuration(args[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
member, ok := room.MemberById(args[0])
|
||||||
|
if !ok {
|
||||||
|
return errors.New("user not found")
|
||||||
|
}
|
||||||
|
member.Op = true
|
||||||
|
id := member.Identifier.(*Identity)
|
||||||
|
h.auth.Op(id.PublicKey(), until)
|
||||||
|
|
||||||
|
body := fmt.Sprintf("Made op by %s.", msg.From().Name())
|
||||||
|
room.Send(chat.NewSystemMsg(body, member.User))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
218
host_test.go
Normal file
218
host_test.go
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shazow/ssh-chat/chat"
|
||||||
|
"github.com/shazow/ssh-chat/sshd"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func stripPrompt(s string) string {
|
||||||
|
pos := strings.LastIndex(s, "\033[K")
|
||||||
|
if pos < 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[pos+3:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostGetPrompt(t *testing.T) {
|
||||||
|
var expected, actual string
|
||||||
|
|
||||||
|
u := chat.NewUser(&Identity{nil, "foo"})
|
||||||
|
u.SetColorIdx(2)
|
||||||
|
|
||||||
|
actual = GetPrompt(u)
|
||||||
|
expected = "[foo] "
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Config.Theme = &chat.Themes[0]
|
||||||
|
actual = GetPrompt(u)
|
||||||
|
expected = "[\033[38;05;2mfoo\033[0m] "
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got: %q; Expected: %q", actual, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostNameCollision(t *testing.T) {
|
||||||
|
key, err := sshd.NewRandomSigner(512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
config := sshd.MakeNoAuth()
|
||||||
|
config.AddHostKey(key)
|
||||||
|
|
||||||
|
s, err := sshd.ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
host := NewHost(s)
|
||||||
|
go host.Serve()
|
||||||
|
|
||||||
|
done := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
// First client
|
||||||
|
go func() {
|
||||||
|
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
|
// Consume the initial buffer
|
||||||
|
scanner.Scan()
|
||||||
|
actual := scanner.Text()
|
||||||
|
if !strings.HasPrefix(actual, "[foo] ") {
|
||||||
|
t.Errorf("First client failed to get 'foo' name.")
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = stripPrompt(actual)
|
||||||
|
expected := " * foo joined. (Connected: 1)"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got %q; expected %q", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ready for second client
|
||||||
|
done <- struct{}{}
|
||||||
|
|
||||||
|
scanner.Scan()
|
||||||
|
actual = stripPrompt(scanner.Text())
|
||||||
|
expected = " * Guest1 joined. (Connected: 2)"
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got %q; expected %q", actual, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap it up.
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for first client
|
||||||
|
<-done
|
||||||
|
|
||||||
|
// Second client
|
||||||
|
err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
|
||||||
|
// Consume the initial buffer
|
||||||
|
scanner.Scan()
|
||||||
|
actual := scanner.Text()
|
||||||
|
if !strings.HasPrefix(actual, "[Guest1] ") {
|
||||||
|
t.Errorf("Second client did not get Guest1 name.")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostWhitelist(t *testing.T) {
|
||||||
|
key, err := sshd.NewRandomSigner(512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := NewAuth()
|
||||||
|
config := sshd.MakeAuth(auth)
|
||||||
|
config.AddHostKey(key)
|
||||||
|
|
||||||
|
s, err := sshd.ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
host := NewHost(s)
|
||||||
|
host.auth = auth
|
||||||
|
go host.Serve()
|
||||||
|
|
||||||
|
target := s.Addr().String()
|
||||||
|
|
||||||
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientkey, err := rsa.GenerateKey(rand.Reader, 512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
|
||||||
|
auth.Whitelist(clientpubkey, 0)
|
||||||
|
|
||||||
|
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Failed to block unwhitelisted connection.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostKick(t *testing.T) {
|
||||||
|
key, err := sshd.NewRandomSigner(512)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth := NewAuth()
|
||||||
|
config := sshd.MakeAuth(auth)
|
||||||
|
config.AddHostKey(key)
|
||||||
|
|
||||||
|
s, err := sshd.ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
addr := s.Addr().String()
|
||||||
|
host := NewHost(s)
|
||||||
|
go host.Serve()
|
||||||
|
|
||||||
|
connected := make(chan struct{})
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// First client
|
||||||
|
err = sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
|
||||||
|
// Make op
|
||||||
|
member, _ := host.Room.MemberById("foo")
|
||||||
|
member.Op = true
|
||||||
|
|
||||||
|
// Block until second client is here
|
||||||
|
connected <- struct{}{}
|
||||||
|
w.Write([]byte("/kick bar\r\n"))
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Second client
|
||||||
|
err = sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
|
||||||
|
<-connected
|
||||||
|
|
||||||
|
// Consume while we're connected. Should break when kicked.
|
||||||
|
ioutil.ReadAll(r)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second * 1):
|
||||||
|
t.Fatal("Timeout.")
|
||||||
|
}
|
||||||
|
}
|
50
identity.go
Normal file
50
identity.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/shazow/ssh-chat/chat"
|
||||||
|
"github.com/shazow/ssh-chat/sshd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Identity is a container for everything that identifies a client.
|
||||||
|
type Identity struct {
|
||||||
|
sshd.Connection
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIdentity returns a new identity object from an sshd.Connection.
|
||||||
|
func NewIdentity(conn sshd.Connection) *Identity {
|
||||||
|
return &Identity{
|
||||||
|
Connection: conn,
|
||||||
|
id: chat.SanitizeName(conn.Name()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i Identity) Id() string {
|
||||||
|
return i.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Identity) SetId(id string) {
|
||||||
|
i.id = id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Identity) SetName(name string) {
|
||||||
|
i.SetId(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i Identity) Name() string {
|
||||||
|
return i.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i Identity) Whois() string {
|
||||||
|
ip, _, _ := net.SplitHostPort(i.RemoteAddr().String())
|
||||||
|
fingerprint := "(no public key)"
|
||||||
|
if i.PublicKey() != nil {
|
||||||
|
fingerprint = sshd.Fingerprint(i.PublicKey())
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("name: %s"+chat.Newline+
|
||||||
|
" > ip: %s"+chat.Newline+
|
||||||
|
" > fingerprint: %s", i.Name(), ip, fingerprint)
|
||||||
|
}
|
49
key.go
Normal file
49
key.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"code.google.com/p/gopass"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadPrivateKey attempts to read your private key and possibly decrypt it if it
|
||||||
|
// requires a passphrase.
|
||||||
|
// This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`),
|
||||||
|
// is not set.
|
||||||
|
func ReadPrivateKey(path string) ([]byte, error) {
|
||||||
|
privateKey, err := ioutil.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load identity: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, rest := pem.Decode(privateKey)
|
||||||
|
if len(rest) > 0 {
|
||||||
|
return nil, fmt.Errorf("extra data when decoding private key")
|
||||||
|
}
|
||||||
|
if !x509.IsEncryptedPEMBlock(block) {
|
||||||
|
return privateKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
passphrase := os.Getenv("IDENTITY_PASSPHRASE")
|
||||||
|
if passphrase == "" {
|
||||||
|
passphrase, err = gopass.GetPass("Enter passphrase: ")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("couldn't read passphrase: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
der, err := x509.DecryptPEMBlock(block, []byte(passphrase))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decrypt failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: block.Type,
|
||||||
|
Bytes: der,
|
||||||
|
})
|
||||||
|
|
||||||
|
return privateKey, nil
|
||||||
|
}
|
@ -1,7 +1,16 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
"github.com/alexcesaro/log"
|
||||||
"github.com/alexcesaro/log/golog"
|
"github.com/alexcesaro/log/golog"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger *golog.Logger
|
var logger *golog.Logger
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Set a default null logger
|
||||||
|
var b bytes.Buffer
|
||||||
|
logger = golog.New(&b, log.Debug)
|
||||||
|
}
|
||||||
|
519
server.go
519
server.go
@ -1,519 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"crypto/md5"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
maxNameLength = 32
|
|
||||||
historyLength = 20
|
|
||||||
systemMessageFormat = "\033[1;90m"
|
|
||||||
privateMessageFormat = "\033[1m"
|
|
||||||
highlightFormat = Bold + "\033[48;5;11m\033[38;5;16m"
|
|
||||||
beep = "\007"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
reStripText = regexp.MustCompile("[^0-9A-Za-z_.-]")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Clients is a map of clients
|
|
||||||
type Clients map[string]*Client
|
|
||||||
|
|
||||||
// Server holds all the fields used by a server
|
|
||||||
type Server struct {
|
|
||||||
sshConfig *ssh.ServerConfig
|
|
||||||
done chan struct{}
|
|
||||||
clients Clients
|
|
||||||
count int
|
|
||||||
history *History
|
|
||||||
motd string
|
|
||||||
whitelist map[string]struct{} // fingerprint lookup
|
|
||||||
admins map[string]struct{} // fingerprint lookup
|
|
||||||
bannedPK map[string]*time.Time // fingerprint lookup
|
|
||||||
started time.Time
|
|
||||||
sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewServer constructs a new server
|
|
||||||
func NewServer(privateKey []byte) (*Server, error) {
|
|
||||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
server := Server{
|
|
||||||
done: make(chan struct{}),
|
|
||||||
clients: Clients{},
|
|
||||||
count: 0,
|
|
||||||
history: NewHistory(historyLength),
|
|
||||||
motd: "",
|
|
||||||
whitelist: map[string]struct{}{},
|
|
||||||
admins: map[string]struct{}{},
|
|
||||||
bannedPK: map[string]*time.Time{},
|
|
||||||
started: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
config := ssh.ServerConfig{
|
|
||||||
NoClientAuth: false,
|
|
||||||
// Auth-related things should be constant-time to avoid timing attacks.
|
|
||||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
|
||||||
fingerprint := Fingerprint(key)
|
|
||||||
if server.IsBanned(fingerprint) {
|
|
||||||
return nil, fmt.Errorf("Banned.")
|
|
||||||
}
|
|
||||||
if !server.IsWhitelisted(fingerprint) {
|
|
||||||
return nil, fmt.Errorf("Not Whitelisted.")
|
|
||||||
}
|
|
||||||
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
|
|
||||||
return perm, nil
|
|
||||||
},
|
|
||||||
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
|
||||||
if server.IsBanned("") {
|
|
||||||
return nil, fmt.Errorf("Interactive login disabled.")
|
|
||||||
}
|
|
||||||
if !server.IsWhitelisted("") {
|
|
||||||
return nil, fmt.Errorf("Not Whitelisted.")
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config.AddHostKey(signer)
|
|
||||||
|
|
||||||
server.sshConfig = &config
|
|
||||||
|
|
||||||
return &server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Len returns the number of clients
|
|
||||||
func (s *Server) Len() int {
|
|
||||||
return len(s.clients)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SysMsg broadcasts the given message to everyone
|
|
||||||
func (s *Server) SysMsg(msg string, args ...interface{}) {
|
|
||||||
s.Broadcast(ContinuousFormat(systemMessageFormat, " * "+fmt.Sprintf(msg, args...)), nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Broadcast broadcasts the given message to everyone except for the given client
|
|
||||||
func (s *Server) Broadcast(msg string, except *Client) {
|
|
||||||
logger.Debugf("Broadcast to %d: %s", s.Len(), msg)
|
|
||||||
s.history.Add(msg)
|
|
||||||
|
|
||||||
s.RLock()
|
|
||||||
defer s.RUnlock()
|
|
||||||
|
|
||||||
for _, client := range s.clients {
|
|
||||||
if except != nil && client == except {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(msg, client.Name) {
|
|
||||||
// Turn message red if client's name is mentioned, and send BEL if they have enabled beeping
|
|
||||||
personalMsg := strings.Replace(msg, client.Name, highlightFormat+client.Name+Reset, -1)
|
|
||||||
if client.beepMe {
|
|
||||||
personalMsg += beep
|
|
||||||
}
|
|
||||||
client.Send(personalMsg)
|
|
||||||
} else {
|
|
||||||
client.Send(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Privmsg sends a message to a particular nick, if it exists
|
|
||||||
func (s *Server) Privmsg(nick, message string, sender *Client) error {
|
|
||||||
// Get the recipient
|
|
||||||
target, ok := s.clients[strings.ToLower(nick)]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("no client with that nick")
|
|
||||||
}
|
|
||||||
// Send the message
|
|
||||||
target.Msg <- fmt.Sprintf(beep+"[PM from %v] %s%v%s", sender.ColoredName(), privateMessageFormat, message, Reset)
|
|
||||||
logger.Debugf("PM from %v to %v: %v", sender.Name, nick, message)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMotd sets the Message of the Day (MOTD)
|
|
||||||
func (s *Server) SetMotd(motd string) {
|
|
||||||
s.motd = motd
|
|
||||||
}
|
|
||||||
|
|
||||||
// MotdUnicast sends the MOTD as a SysMsg
|
|
||||||
func (s *Server) MotdUnicast(client *Client) {
|
|
||||||
if s.motd == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
client.SysMsg(s.motd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MotdBroadcast broadcasts the MOTD
|
|
||||||
func (s *Server) MotdBroadcast(client *Client) {
|
|
||||||
if s.motd == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.Broadcast(ContinuousFormat(systemMessageFormat, fmt.Sprintf(" * New MOTD set by %s.", client.ColoredName())), client)
|
|
||||||
s.Broadcast(s.motd, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds the client to the list of clients
|
|
||||||
func (s *Server) Add(client *Client) {
|
|
||||||
go func() {
|
|
||||||
s.MotdUnicast(client)
|
|
||||||
client.SendLines(s.history.Get(10))
|
|
||||||
}()
|
|
||||||
|
|
||||||
s.Lock()
|
|
||||||
s.count++
|
|
||||||
|
|
||||||
newName, err := s.proposeName(client.Name)
|
|
||||||
if err != nil {
|
|
||||||
client.SysMsg("Your name '%s' is not available, renamed to '%s'. Use /nick <name> to change it.", client.Name, ColorString(client.Color, newName))
|
|
||||||
}
|
|
||||||
|
|
||||||
client.Rename(newName)
|
|
||||||
s.clients[strings.ToLower(client.Name)] = client
|
|
||||||
num := len(s.clients)
|
|
||||||
s.Unlock()
|
|
||||||
|
|
||||||
s.Broadcast(ContinuousFormat(systemMessageFormat, fmt.Sprintf(" * %s joined. (Total connected: %d)", client.Name, num)), client)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove removes the given client from the list of clients
|
|
||||||
func (s *Server) Remove(client *Client) {
|
|
||||||
s.Lock()
|
|
||||||
delete(s.clients, strings.ToLower(client.Name))
|
|
||||||
s.Unlock()
|
|
||||||
|
|
||||||
s.SysMsg("%s left.", client.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) proposeName(name string) (string, error) {
|
|
||||||
// Assumes caller holds lock.
|
|
||||||
var err error
|
|
||||||
name = reStripText.ReplaceAllString(name, "")
|
|
||||||
|
|
||||||
if len(name) > maxNameLength {
|
|
||||||
name = name[:maxNameLength]
|
|
||||||
} else if len(name) == 0 {
|
|
||||||
name = fmt.Sprintf("Guest%d", s.count)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, collision := s.clients[strings.ToLower(name)]
|
|
||||||
if collision {
|
|
||||||
err = fmt.Errorf("%s is not available", name)
|
|
||||||
name = fmt.Sprintf("Guest%d", s.count)
|
|
||||||
}
|
|
||||||
|
|
||||||
return name, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rename renames the given client (user)
|
|
||||||
func (s *Server) Rename(client *Client, newName string) {
|
|
||||||
var oldName string
|
|
||||||
if strings.ToLower(newName) == strings.ToLower(client.Name) {
|
|
||||||
oldName = client.Name
|
|
||||||
client.Rename(newName)
|
|
||||||
} else {
|
|
||||||
s.Lock()
|
|
||||||
newName, err := s.proposeName(newName)
|
|
||||||
if err != nil {
|
|
||||||
client.SysMsg("%s", err)
|
|
||||||
s.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Use a channel/goroutine for adding clients, rather than locks?
|
|
||||||
delete(s.clients, strings.ToLower(client.Name))
|
|
||||||
oldName = client.Name
|
|
||||||
client.Rename(newName)
|
|
||||||
s.clients[strings.ToLower(client.Name)] = client
|
|
||||||
s.Unlock()
|
|
||||||
}
|
|
||||||
s.SysMsg("%s is now known as %s.", ColorString(client.Color, oldName), ColorString(client.Color, client.Name))
|
|
||||||
}
|
|
||||||
|
|
||||||
// List lists the clients with the given prefix
|
|
||||||
func (s *Server) List(prefix *string) []string {
|
|
||||||
r := []string{}
|
|
||||||
|
|
||||||
s.RLock()
|
|
||||||
defer s.RUnlock()
|
|
||||||
|
|
||||||
for name, client := range s.clients {
|
|
||||||
if prefix != nil && !strings.HasPrefix(name, strings.ToLower(*prefix)) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
r = append(r, client.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// Who returns the client with a given name
|
|
||||||
func (s *Server) Who(name string) *Client {
|
|
||||||
return s.clients[strings.ToLower(name)]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Op adds the given fingerprint to the list of admins
|
|
||||||
func (s *Server) Op(fingerprint string) {
|
|
||||||
logger.Infof("Adding admin: %s", fingerprint)
|
|
||||||
s.Lock()
|
|
||||||
s.admins[fingerprint] = struct{}{}
|
|
||||||
s.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Whitelist adds the given fingerprint to the whitelist
|
|
||||||
func (s *Server) Whitelist(fingerprint string) error {
|
|
||||||
if fingerprint == "" {
|
|
||||||
return fmt.Errorf("Invalid fingerprint.")
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(fingerprint, "github.com/") {
|
|
||||||
return s.whitelistIdentityURL(fingerprint)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.whitelistFingerprint(fingerprint)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) whitelistIdentityURL(user string) error {
|
|
||||||
logger.Infof("Adding github account %s to whitelist", user)
|
|
||||||
|
|
||||||
user = strings.Replace(user, "github.com/", "", -1)
|
|
||||||
keys, err := getGithubPubKeys(user)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if len(keys) == 0 {
|
|
||||||
return fmt.Errorf("No keys for github user %s", user)
|
|
||||||
}
|
|
||||||
for _, key := range keys {
|
|
||||||
fingerprint := Fingerprint(key)
|
|
||||||
s.whitelistFingerprint(fingerprint)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) whitelistFingerprint(fingerprint string) error {
|
|
||||||
logger.Infof("Adding whitelist: %s", fingerprint)
|
|
||||||
s.Lock()
|
|
||||||
s.whitelist[fingerprint] = struct{}{}
|
|
||||||
s.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Client for getting github pub keys
|
|
||||||
var client = http.Client{
|
|
||||||
Timeout: time.Duration(10 * time.Second),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns an array of public keys for the given github user URL
|
|
||||||
func getGithubPubKeys(user string) ([]ssh.PublicKey, error) {
|
|
||||||
resp, err := client.Get("http://github.com/" + user + ".keys")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
pubs := []ssh.PublicKey{}
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
for scanner.Scan() {
|
|
||||||
text := scanner.Text()
|
|
||||||
if text == "Not Found" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
splitKey := strings.SplitN(text, " ", -1)
|
|
||||||
|
|
||||||
// In case of malformated key
|
|
||||||
if len(splitKey) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyDecoded, err := base64.StdEncoding.DecodeString(splitKey[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pub, err := ssh.ParsePublicKey(bodyDecoded)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pubs = append(pubs, pub)
|
|
||||||
}
|
|
||||||
return pubs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uptime returns the time since the server was started
|
|
||||||
func (s *Server) Uptime() string {
|
|
||||||
return time.Now().Sub(s.started).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsOp checks if the given client is Op
|
|
||||||
func (s *Server) IsOp(client *Client) bool {
|
|
||||||
fingerprint := client.Fingerprint()
|
|
||||||
if fingerprint == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
_, r := s.admins[client.Fingerprint()]
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsWhitelisted checks if the given fingerprint is whitelisted
|
|
||||||
func (s *Server) IsWhitelisted(fingerprint string) bool {
|
|
||||||
/* if no whitelist, anyone is welcome */
|
|
||||||
if len(s.whitelist) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
/* otherwise, check for whitelist presence */
|
|
||||||
_, r := s.whitelist[fingerprint]
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsBanned checks if the given fingerprint is banned
|
|
||||||
func (s *Server) IsBanned(fingerprint string) bool {
|
|
||||||
ban, hasBan := s.bannedPK[fingerprint]
|
|
||||||
if !hasBan {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if ban == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if ban.Before(time.Now()) {
|
|
||||||
s.Unban(fingerprint)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ban bans a fingerprint for the given duration
|
|
||||||
func (s *Server) Ban(fingerprint string, duration *time.Duration) {
|
|
||||||
var until *time.Time
|
|
||||||
s.Lock()
|
|
||||||
if duration != nil {
|
|
||||||
when := time.Now().Add(*duration)
|
|
||||||
until = &when
|
|
||||||
}
|
|
||||||
s.bannedPK[fingerprint] = until
|
|
||||||
s.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unban unbans a banned fingerprint
|
|
||||||
func (s *Server) Unban(fingerprint string) {
|
|
||||||
s.Lock()
|
|
||||||
delete(s.bannedPK, fingerprint)
|
|
||||||
s.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the server
|
|
||||||
func (s *Server) Start(laddr string) error {
|
|
||||||
// Once a ServerConfig has been configured, connections can be
|
|
||||||
// accepted.
|
|
||||||
socket, err := net.Listen("tcp", laddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("Listening on %s", laddr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer socket.Close()
|
|
||||||
for {
|
|
||||||
conn, err := socket.Accept()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Failed to accept connection: %v", err)
|
|
||||||
if err == syscall.EINVAL {
|
|
||||||
// TODO: Handle shutdown more gracefully?
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Goroutineify to resume accepting sockets early.
|
|
||||||
go func() {
|
|
||||||
// From a standard TCP connection to an encrypted SSH connection
|
|
||||||
sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Failed to handshake: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
version := reStripText.ReplaceAllString(string(sshConn.ClientVersion()), "")
|
|
||||||
if len(version) > 100 {
|
|
||||||
version = "Evil Jerk with a superlong string"
|
|
||||||
}
|
|
||||||
logger.Infof("Connection #%d from: %s, %s, %s", s.count+1, sshConn.RemoteAddr(), sshConn.User(), version)
|
|
||||||
|
|
||||||
go ssh.DiscardRequests(requests)
|
|
||||||
|
|
||||||
client := NewClient(s, sshConn)
|
|
||||||
go client.handleChannels(channels)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-s.done
|
|
||||||
socket.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AutoCompleteFunction handles auto completion of nicks
|
|
||||||
func (s *Server) AutoCompleteFunction(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
|
|
||||||
if key == 9 {
|
|
||||||
shortLine := strings.Split(line[:pos], " ")
|
|
||||||
partialNick := shortLine[len(shortLine)-1]
|
|
||||||
|
|
||||||
nicks := s.List(&partialNick)
|
|
||||||
if len(nicks) > 0 {
|
|
||||||
nick := nicks[len(nicks)-1]
|
|
||||||
posPartialNick := pos - len(partialNick)
|
|
||||||
if len(shortLine) < 2 {
|
|
||||||
nick += ": "
|
|
||||||
} else {
|
|
||||||
nick += " "
|
|
||||||
}
|
|
||||||
newLine = strings.Replace(line[posPartialNick:],
|
|
||||||
partialNick, nick, 1)
|
|
||||||
newLine = line[:posPartialNick] + newLine
|
|
||||||
newPos = pos + (len(nick) - len(partialNick))
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ok = false
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the server
|
|
||||||
func (s *Server) Stop() {
|
|
||||||
s.Lock()
|
|
||||||
for _, client := range s.clients {
|
|
||||||
client.Conn.Close()
|
|
||||||
}
|
|
||||||
s.Unlock()
|
|
||||||
|
|
||||||
close(s.done)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fingerprint returns the fingerprint based on a public key
|
|
||||||
func Fingerprint(k ssh.PublicKey) string {
|
|
||||||
hash := md5.Sum(k.Marshal())
|
|
||||||
r := fmt.Sprintf("% x", hash)
|
|
||||||
return strings.Replace(r, " ", ":", -1)
|
|
||||||
}
|
|
70
set.go
Normal file
70
set.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type expiringValue struct {
|
||||||
|
time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v expiringValue) Bool() bool {
|
||||||
|
return time.Now().Before(v.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
type value struct{}
|
||||||
|
|
||||||
|
func (v value) Bool() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type SetValue interface {
|
||||||
|
Bool() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set with expire-able keys
|
||||||
|
type Set struct {
|
||||||
|
lookup map[string]SetValue
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSet creates a new set.
|
||||||
|
func NewSet() *Set {
|
||||||
|
return &Set{
|
||||||
|
lookup: map[string]SetValue{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the size of the set right now.
|
||||||
|
func (s *Set) Len() int {
|
||||||
|
return len(s.lookup)
|
||||||
|
}
|
||||||
|
|
||||||
|
// In checks if an item exists in this set.
|
||||||
|
func (s *Set) In(key string) bool {
|
||||||
|
s.Lock()
|
||||||
|
v, ok := s.lookup[key]
|
||||||
|
if ok && !v.Bool() {
|
||||||
|
ok = false
|
||||||
|
delete(s.lookup, key)
|
||||||
|
}
|
||||||
|
s.Unlock()
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add item to this set, replace if it exists.
|
||||||
|
func (s *Set) Add(key string) {
|
||||||
|
s.Lock()
|
||||||
|
s.lookup[key] = value{}
|
||||||
|
s.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add item to this set, replace if it exists.
|
||||||
|
func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
|
||||||
|
until := time.Now().Add(d)
|
||||||
|
s.Lock()
|
||||||
|
s.lookup[key] = expiringValue{until}
|
||||||
|
s.Unlock()
|
||||||
|
return until
|
||||||
|
}
|
58
set_test.go
Normal file
58
set_test.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSetExpiring(t *testing.T) {
|
||||||
|
s := NewSet()
|
||||||
|
if s.In("foo") {
|
||||||
|
t.Error("Matched before set.")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Add("foo")
|
||||||
|
if !s.In("foo") {
|
||||||
|
t.Errorf("Not matched after set")
|
||||||
|
}
|
||||||
|
if s.Len() != 1 {
|
||||||
|
t.Error("Not len 1 after set")
|
||||||
|
}
|
||||||
|
|
||||||
|
v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
|
||||||
|
if v.Bool() {
|
||||||
|
t.Errorf("expiringValue now is not expiring")
|
||||||
|
}
|
||||||
|
|
||||||
|
v = expiringValue{time.Now().Add(time.Minute * 2)}
|
||||||
|
if !v.Bool() {
|
||||||
|
t.Errorf("expiringValue in 2 minutes is expiring now")
|
||||||
|
}
|
||||||
|
|
||||||
|
until := s.AddExpiring("bar", time.Minute*2)
|
||||||
|
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
|
||||||
|
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
|
||||||
|
}
|
||||||
|
val, ok := s.lookup["bar"]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("bar not in lookup")
|
||||||
|
}
|
||||||
|
if !until.Equal(val.(expiringValue).Time) {
|
||||||
|
t.Errorf("bar's until is not equal to the expected value")
|
||||||
|
}
|
||||||
|
if !val.Bool() {
|
||||||
|
t.Errorf("bar expired immediately")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.In("bar") {
|
||||||
|
t.Errorf("Not matched after timed set")
|
||||||
|
}
|
||||||
|
if s.Len() != 2 {
|
||||||
|
t.Error("Not len 2 after set")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.AddExpiring("bar", time.Nanosecond*1)
|
||||||
|
if s.In("bar") {
|
||||||
|
t.Error("Matched after expired timer")
|
||||||
|
}
|
||||||
|
}
|
72
sshd/auth.go
Normal file
72
sshd/auth.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Auth is used to authenticate connections based on public keys.
|
||||||
|
type Auth interface {
|
||||||
|
// Whether to allow connections without a public key.
|
||||||
|
AllowAnonymous() bool
|
||||||
|
// Given address and public key, return if the connection should be permitted.
|
||||||
|
Check(net.Addr, ssh.PublicKey) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeAuth makes an ssh.ServerConfig which performs authentication against an Auth implementation.
|
||||||
|
func MakeAuth(auth Auth) *ssh.ServerConfig {
|
||||||
|
config := ssh.ServerConfig{
|
||||||
|
NoClientAuth: false,
|
||||||
|
// Auth-related things should be constant-time to avoid timing attacks.
|
||||||
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
ok, err := auth.Check(conn.RemoteAddr(), key)
|
||||||
|
if !ok {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
perm := &ssh.Permissions{Extensions: map[string]string{
|
||||||
|
"pubkey": string(key.Marshal()),
|
||||||
|
}}
|
||||||
|
return perm, nil
|
||||||
|
},
|
||||||
|
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||||
|
if !auth.AllowAnonymous() {
|
||||||
|
return nil, errors.New("public key authentication required")
|
||||||
|
}
|
||||||
|
_, err := auth.Check(conn.RemoteAddr(), nil)
|
||||||
|
return nil, err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeNoAuth makes a simple ssh.ServerConfig which allows all connections.
|
||||||
|
// Primarily used for testing.
|
||||||
|
func MakeNoAuth() *ssh.ServerConfig {
|
||||||
|
config := ssh.ServerConfig{
|
||||||
|
NoClientAuth: false,
|
||||||
|
// Auth-related things should be constant-time to avoid timing attacks.
|
||||||
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
perm := &ssh.Permissions{Extensions: map[string]string{
|
||||||
|
"pubkey": string(key.Marshal()),
|
||||||
|
}}
|
||||||
|
return perm, nil
|
||||||
|
},
|
||||||
|
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fingerprint performs a SHA256 BASE64 fingerprint of the PublicKey, similar to OpenSSH.
|
||||||
|
// See: https://anongit.mindrot.org/openssh.git/commit/?id=56d1c83cdd1ac
|
||||||
|
func Fingerprint(k ssh.PublicKey) string {
|
||||||
|
hash := sha256.Sum256(k.Marshal())
|
||||||
|
return base64.StdEncoding.EncodeToString(hash[:])
|
||||||
|
}
|
72
sshd/client.go
Normal file
72
sshd/client.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRandomSigner generates a random key of a desired bit length.
|
||||||
|
func NewRandomSigner(bits int) (ssh.Signer, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, bits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ssh.NewSignerFromKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientConfig creates a barebones ssh.ClientConfig to be used with ssh.Dial.
|
||||||
|
func NewClientConfig(name string) *ssh.ClientConfig {
|
||||||
|
return &ssh.ClientConfig{
|
||||||
|
User: name,
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
||||||
|
return
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectShell makes a barebones SSH client session, used for testing.
|
||||||
|
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||||
|
config := NewClientConfig(name)
|
||||||
|
conn, err := ssh.Dial("tcp", host, config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
in, err := session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := session.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
/* FIXME: Do we want to request a PTY?
|
||||||
|
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
err = session.Shell()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
handler(out, in)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
46
sshd/client_test.go
Normal file
46
sshd/client_test.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errRejectAuth = errors.New("not welcome here")
|
||||||
|
|
||||||
|
type RejectAuth struct{}
|
||||||
|
|
||||||
|
func (a RejectAuth) AllowAnonymous() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
func (a RejectAuth) Check(net.Addr, ssh.PublicKey) (bool, error) {
|
||||||
|
return false, errRejectAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
func consume(ch <-chan *Terminal) {
|
||||||
|
for _ = range ch {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientReject(t *testing.T) {
|
||||||
|
signer, err := NewRandomSigner(512)
|
||||||
|
config := MakeAuth(RejectAuth{})
|
||||||
|
config.AddHostKey(signer)
|
||||||
|
|
||||||
|
s, err := ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
go consume(s.ServeTerminal())
|
||||||
|
|
||||||
|
conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo"))
|
||||||
|
if err == nil {
|
||||||
|
defer conn.Close()
|
||||||
|
t.Error("Failed to reject conncetion")
|
||||||
|
}
|
||||||
|
t.Log(err)
|
||||||
|
}
|
34
sshd/doc.go
Normal file
34
sshd/doc.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||||
|
|
||||||
|
config := MakeNoAuth()
|
||||||
|
config.AddHostKey(signer)
|
||||||
|
|
||||||
|
s, err := ListenSSH("0.0.0.0:2022", config)
|
||||||
|
if err != nil {
|
||||||
|
// Handle opening socket error
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
terminals := s.ServeTerminal()
|
||||||
|
|
||||||
|
for term := range terminals {
|
||||||
|
go func() {
|
||||||
|
defer term.Close()
|
||||||
|
term.SetPrompt("...")
|
||||||
|
term.AutoCompleteCallback = nil // ...
|
||||||
|
|
||||||
|
for {
|
||||||
|
line, err := term.ReadLine()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
term.Write(...)
|
||||||
|
}
|
||||||
|
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
*/
|
22
sshd/logger.go
Normal file
22
sshd/logger.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
import stdlog "log"
|
||||||
|
|
||||||
|
var logger *stdlog.Logger
|
||||||
|
|
||||||
|
func SetLogger(w io.Writer) {
|
||||||
|
flags := stdlog.Flags()
|
||||||
|
prefix := "[sshd] "
|
||||||
|
logger = stdlog.New(w, prefix, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
type nullWriter struct{}
|
||||||
|
|
||||||
|
func (nullWriter) Write(data []byte) (int, error) {
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
SetLogger(nullWriter{})
|
||||||
|
}
|
74
sshd/net.go
Normal file
74
sshd/net.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/shazow/rateio"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Container for the connection and ssh-related configuration
|
||||||
|
type SSHListener struct {
|
||||||
|
net.Listener
|
||||||
|
config *ssh.ServerConfig
|
||||||
|
RateLimit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make an SSH listener socket
|
||||||
|
func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) {
|
||||||
|
socket, err := net.Listen("tcp", laddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l := SSHListener{Listener: socket, config: config}
|
||||||
|
return &l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) {
|
||||||
|
if l.RateLimit {
|
||||||
|
// TODO: Configurable Limiter?
|
||||||
|
conn = ReadLimitConn(conn, rateio.NewGracefulLimiter(1024*10, time.Minute*2, time.Second*3))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade TCP connection to SSH connection
|
||||||
|
sshConn, channels, requests, err := ssh.NewServerConn(conn, l.config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME: Disconnect if too many faulty requests? (Avoid DoS.)
|
||||||
|
go ssh.DiscardRequests(requests)
|
||||||
|
return NewSession(sshConn, channels)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept incoming connections as terminal requests and yield them
|
||||||
|
func (l *SSHListener) ServeTerminal() <-chan *Terminal {
|
||||||
|
ch := make(chan *Terminal)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer l.Close()
|
||||||
|
defer close(ch)
|
||||||
|
|
||||||
|
for {
|
||||||
|
conn, err := l.Accept()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("Failed to accept connection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Goroutineify to resume accepting sockets early
|
||||||
|
go func() {
|
||||||
|
term, err := l.handleConn(conn)
|
||||||
|
if err != nil {
|
||||||
|
logger.Printf("Failed to handshake: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch <- term
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch
|
||||||
|
}
|
81
sshd/net_test.go
Normal file
81
sshd/net_test.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerInit(t *testing.T) {
|
||||||
|
config := MakeNoAuth()
|
||||||
|
s, err := ListenSSH(":badport", config)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("should fail on bad port")
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err = ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeTerminals(t *testing.T) {
|
||||||
|
signer, err := NewRandomSigner(512)
|
||||||
|
config := MakeNoAuth()
|
||||||
|
config.AddHostKey(signer)
|
||||||
|
|
||||||
|
s, err := ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
terminals := s.ServeTerminal()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Accept one terminal, read from it, echo back, close.
|
||||||
|
term := <-terminals
|
||||||
|
term.SetPrompt("> ")
|
||||||
|
|
||||||
|
line, err := term.ReadLine()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
_, err = term.Write([]byte("echo: " + line + "\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
term.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
host := s.Addr().String()
|
||||||
|
name := "foo"
|
||||||
|
|
||||||
|
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
|
||||||
|
// Consume if there is anything
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
w.Write([]byte("hello\r\n"))
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
_, err := io.Copy(buf, r)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "> hello\r\necho: hello\r\n"
|
||||||
|
actual := buf.String()
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got %q; expected %q", actual, expected)
|
||||||
|
}
|
||||||
|
s.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
@ -1,8 +1,9 @@
|
|||||||
// Borrowed from go.crypto circa 2011
|
package sshd
|
||||||
package main
|
|
||||||
|
|
||||||
import "encoding/binary"
|
import "encoding/binary"
|
||||||
|
|
||||||
|
// Helpers below are borrowed from go.crypto circa 2011:
|
||||||
|
|
||||||
// parsePtyRequest parses the payload of the pty-req message and extracts the
|
// parsePtyRequest parses the payload of the pty-req message and extracts the
|
||||||
// dimensions of the terminal. See RFC 4254, section 6.2.
|
// dimensions of the terminal. See RFC 4254, section 6.2.
|
||||||
func parsePtyRequest(s []byte) (width, height int, ok bool) {
|
func parsePtyRequest(s []byte) (width, height int, ok bool) {
|
25
sshd/ratelimit.go
Normal file
25
sshd/ratelimit.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/shazow/rateio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type limitedConn struct {
|
||||||
|
net.Conn
|
||||||
|
io.Reader // Our rate-limited io.Reader for net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *limitedConn) Read(p []byte) (n int, err error) {
|
||||||
|
return r.Reader.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadLimitConn returns a net.Conn whose io.Reader interface is rate-limited by limiter.
|
||||||
|
func ReadLimitConn(conn net.Conn, limiter rateio.Limiter) net.Conn {
|
||||||
|
return &limitedConn{
|
||||||
|
Conn: conn,
|
||||||
|
Reader: rateio.NewReader(conn, limiter),
|
||||||
|
}
|
||||||
|
}
|
144
sshd/terminal.go
Normal file
144
sshd/terminal.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
"golang.org/x/crypto/ssh/terminal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Connection is an interface with fields necessary to operate an sshd host.
|
||||||
|
type Connection interface {
|
||||||
|
PublicKey() ssh.PublicKey
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
Name() string
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshConn struct {
|
||||||
|
*ssh.ServerConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c sshConn) PublicKey() ssh.PublicKey {
|
||||||
|
if c.Permissions == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s, ok := c.Permissions.Extensions["pubkey"]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := ssh.ParsePublicKey([]byte(s))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c sshConn) Name() string {
|
||||||
|
return c.User()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extending ssh/terminal to include a closer interface
|
||||||
|
type Terminal struct {
|
||||||
|
terminal.Terminal
|
||||||
|
Conn Connection
|
||||||
|
Channel ssh.Channel
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make new terminal from a session channel
|
||||||
|
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
|
||||||
|
if ch.ChannelType() != "session" {
|
||||||
|
return nil, errors.New("terminal requires session channel")
|
||||||
|
}
|
||||||
|
channel, requests, err := ch.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
term := Terminal{
|
||||||
|
*terminal.NewTerminal(channel, "Connecting..."),
|
||||||
|
sshConn{conn},
|
||||||
|
channel,
|
||||||
|
}
|
||||||
|
|
||||||
|
go term.listen(requests)
|
||||||
|
go func() {
|
||||||
|
// FIXME: Is this necessary?
|
||||||
|
conn.Wait()
|
||||||
|
channel.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &term, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find session channel and make a Terminal from it
|
||||||
|
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
|
||||||
|
for ch := range channels {
|
||||||
|
if t := ch.ChannelType(); t != "session" {
|
||||||
|
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
term, err = NewTerminal(conn, ch)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if term != nil {
|
||||||
|
// Reject the rest.
|
||||||
|
// FIXME: Do we need this?
|
||||||
|
go func() {
|
||||||
|
for ch := range channels {
|
||||||
|
ch.Reject(ssh.Prohibited, "only one session allowed")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return term, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close terminal and ssh connection
|
||||||
|
func (t *Terminal) Close() error {
|
||||||
|
return t.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Negotiate terminal type and settings
|
||||||
|
func (t *Terminal) listen(requests <-chan *ssh.Request) {
|
||||||
|
hasShell := false
|
||||||
|
|
||||||
|
for req := range requests {
|
||||||
|
var width, height int
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
switch req.Type {
|
||||||
|
case "shell":
|
||||||
|
if !hasShell {
|
||||||
|
ok = true
|
||||||
|
hasShell = true
|
||||||
|
}
|
||||||
|
case "pty-req":
|
||||||
|
width, height, ok = parsePtyRequest(req.Payload)
|
||||||
|
if ok {
|
||||||
|
// TODO: Hardcode width to 100000?
|
||||||
|
err := t.SetSize(width, height)
|
||||||
|
ok = err == nil
|
||||||
|
}
|
||||||
|
case "window-change":
|
||||||
|
width, height, ok = parseWinchRequest(req.Payload)
|
||||||
|
if ok {
|
||||||
|
// TODO: Hardcode width to 100000?
|
||||||
|
err := t.SetSize(width, height)
|
||||||
|
ok = err == nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.WantReply {
|
||||||
|
req.Reply(ok, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user