Channel Member now wrapping User with metadata, new Auth struct.

This commit is contained in:
Andrey Petrov 2015-01-01 18:40:10 -08:00
parent 6874601c0b
commit 6a662bf358
6 changed files with 80 additions and 53 deletions

View File

@ -13,12 +13,17 @@ const channelBuffer = 10
// closed. // closed.
var ErrChannelClosed = errors.New("channel closed") var ErrChannelClosed = errors.New("channel closed")
// Member is a User with per-Channel metadata attached to it.
type Member struct {
*User
Op bool
}
// Channel definition, also a Set of User Items // Channel definition, also a Set of User Items
type Channel struct { type Channel struct {
topic string topic string
history *History history *History
users *Set members *Set
ops *Set
broadcast chan Message broadcast chan Message
commands Commands commands Commands
closed bool closed bool
@ -32,8 +37,7 @@ func NewChannel() *Channel {
return &Channel{ return &Channel{
broadcast: broadcast, broadcast: broadcast,
history: NewHistory(historyLen), history: NewHistory(historyLen),
users: NewSet(), members: NewSet(),
ops: NewSet(),
commands: *defaultCommands, commands: *defaultCommands,
} }
} }
@ -47,10 +51,10 @@ func (ch *Channel) SetCommands(commands Commands) {
func (ch *Channel) Close() { func (ch *Channel) Close() {
ch.closeOnce.Do(func() { ch.closeOnce.Do(func() {
ch.closed = true ch.closed = true
ch.users.Each(func(u Item) { ch.members.Each(func(u Item) {
u.(*User).Close() u.(*User).Close()
}) })
ch.users.Clear() ch.members.Clear()
close(ch.broadcast) close(ch.broadcast)
}) })
} }
@ -75,8 +79,8 @@ func (ch *Channel) HandleMsg(m Message) {
skipUser = fromMsg.From() skipUser = fromMsg.From()
} }
ch.users.Each(func(u Item) { ch.members.Each(func(u Item) {
user := u.(*User) user := u.(*Member).User
if skip && skipUser == user { if skip && skipUser == user {
// Skip // Skip
return return
@ -108,18 +112,18 @@ func (ch *Channel) Join(u *User) error {
if ch.closed { if ch.closed {
return ErrChannelClosed return ErrChannelClosed
} }
err := ch.users.Add(u) err := ch.members.Add(&Member{u, false})
if err != nil { if err != nil {
return err return err
} }
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.users.Len()) s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.members.Len())
ch.Send(NewAnnounceMsg(s)) ch.Send(NewAnnounceMsg(s))
return nil return nil
} }
// Leave the channel as a user, will announce. // Leave the channel as a user, will announce. Mostly used during setup.
func (ch *Channel) Leave(u *User) error { func (ch *Channel) Leave(u *User) error {
err := ch.users.Remove(u) err := ch.members.Remove(u)
if err != nil { if err != nil {
return err return err
} }
@ -128,6 +132,26 @@ func (ch *Channel) Leave(u *User) error {
return nil return nil
} }
// Member returns a corresponding Member object to a User if the Member is
// present in this channel.
func (ch *Channel) Member(u *User) (*Member, bool) {
m, err := ch.members.Get(u.Id())
if err != nil {
return nil, false
}
// Check that it's the same user
if m.(*Member).User != u {
return nil, false
}
return m.(*Member), true
}
// IsOp returns whether a user is an operator in this channel.
func (ch *Channel) IsOp(u *User) bool {
m, ok := ch.Member(u)
return ok && m.Op
}
// Topic of the channel. // Topic of the channel.
func (ch *Channel) Topic() string { func (ch *Channel) Topic() string {
return ch.topic return ch.topic
@ -141,9 +165,9 @@ func (ch *Channel) SetTopic(s string) {
// NamesPrefix lists all members' names with a given prefix, used to query // NamesPrefix lists all members' names with a given prefix, used to query
// for autocompletion purposes. // for autocompletion purposes.
func (ch *Channel) NamesPrefix(prefix string) []string { func (ch *Channel) NamesPrefix(prefix string) []string {
users := ch.users.ListPrefix(prefix) members := ch.members.ListPrefix(prefix)
names := make([]string, len(users)) names := make([]string, len(members))
for i, u := range users { for i, u := range members {
names[i] = u.(*User).Name() names[i] = u.(*User).Name()
} }
return names return names

View File

@ -98,9 +98,8 @@ func init() {
c.Add(Command{ c.Add(Command{
Prefix: "/help", Prefix: "/help",
Handler: func(channel *Channel, msg CommandMsg) error { Handler: func(channel *Channel, msg CommandMsg) error {
user := msg.From() op := channel.IsOp(msg.From())
op := channel.ops.In(user) channel.Send(NewSystemMsg(channel.commands.Help(op), msg.From()))
channel.Send(NewSystemMsg(channel.commands.Help(op), user))
return nil return nil
}, },
}) })
@ -193,11 +192,12 @@ func init() {
}) })
c.Add(Command{ c.Add(Command{
Op: true,
Prefix: "/op", Prefix: "/op",
PrefixHelp: "USER", PrefixHelp: "USER",
Help: "Mark user as admin.", Help: "Mark user as admin.",
Handler: func(channel *Channel, msg CommandMsg) error { Handler: func(channel *Channel, msg CommandMsg) error {
if !channel.ops.In(msg.From()) { if !channel.IsOp(msg.From()) {
return errors.New("must be op") return errors.New("must be op")
} }
@ -206,13 +206,14 @@ func init() {
return errors.New("must specify user") return errors.New("must specify user")
} }
// TODO: Add support for fingerprint-based op'ing. // TODO: Add support for fingerprint-based op'ing. This will
user, err := channel.users.Get(Id(args[0])) // probably need to live in host land.
member, err := channel.members.Get(Id(args[0]))
if err != nil { if err != nil {
return errors.New("user not found") return errors.New("user not found")
} }
channel.ops.Add(user) member.(*Member).Op = true
return nil return nil
}, },
}) })

17
cmd.go
View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"bufio"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -93,8 +94,8 @@ func main() {
return return
} }
// TODO: MakeAuth auth := Auth{}
config := sshd.MakeNoAuth() config := sshd.MakeAuth(auth)
config.AddHostKey(signer) config.AddHostKey(signer)
s, err := sshd.ListenSSH(options.Bind, config) s, err := sshd.ListenSSH(options.Bind, config)
@ -106,11 +107,10 @@ func main() {
defer s.Close() defer s.Close()
host := NewHost(s) host := NewHost(s)
go host.Serve() host.auth = &auth
/* TODO:
for _, fingerprint := range options.Admin { for _, fingerprint := range options.Admin {
server.Op(fingerprint) auth.Op(fingerprint)
} }
if options.Whitelist != "" { if options.Whitelist != "" {
@ -123,7 +123,7 @@ func main() {
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
server.Whitelist(scanner.Text()) auth.Whitelist(scanner.Text())
} }
} }
@ -137,9 +137,10 @@ func main() {
// hack to normalize line endings into \r\n // hack to normalize line endings into \r\n
motdString = strings.Replace(motdString, "\r\n", "\n", -1) motdString = strings.Replace(motdString, "\r\n", "\n", -1)
motdString = strings.Replace(motdString, "\n", "\r\n", -1) motdString = strings.Replace(motdString, "\n", "\r\n", -1)
server.SetMotd(motdString) host.SetMotd(motdString)
} }
*/
go host.Serve()
// Construct interrupt handler // Construct interrupt handler
sig := make(chan os.Signal, 1) sig := make(chan os.Signal, 1)

12
host.go
View File

@ -14,9 +14,12 @@ import (
type Host struct { type Host struct {
listener *sshd.SSHListener listener *sshd.SSHListener
channel *chat.Channel channel *chat.Channel
motd string
auth *Auth
} }
// NewHost creates a Host on top of an existing listener // NewHost creates a Host on top of an existing listener.
func NewHost(listener *sshd.SSHListener) *Host { func NewHost(listener *sshd.SSHListener) *Host {
ch := chat.NewChannel() ch := chat.NewChannel()
h := Host{ h := Host{
@ -27,7 +30,12 @@ func NewHost(listener *sshd.SSHListener) *Host {
return &h return &h
} }
// Connect a specific Terminal to this host and its channel // SetMotd sets the host's message of the day.
func (h *Host) SetMotd(motd string) {
h.motd = motd
}
// Connect a specific Terminal to this host and its channel.
func (h *Host) Connect(term *sshd.Terminal) { func (h *Host) Connect(term *sshd.Terminal) {
name := term.Conn.User() name := term.Conn.User()
term.AutoCompleteCallback = h.AutoCompleteFunction term.AutoCompleteCallback = h.AutoCompleteFunction

View File

@ -9,13 +9,9 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
var errBanned = errors.New("banned")
var errNotWhitelisted = errors.New("not whitelisted")
var errNoInteractive = errors.New("public key authentication required")
type Auth interface { type Auth interface {
IsBanned(ssh.PublicKey) bool AllowAnonymous() bool
IsWhitelisted(ssh.PublicKey) bool Check(string) (bool, error)
} }
func MakeAuth(auth Auth) *ssh.ServerConfig { func MakeAuth(auth Auth) *ssh.ServerConfig {
@ -23,21 +19,17 @@ func MakeAuth(auth Auth) *ssh.ServerConfig {
NoClientAuth: false, NoClientAuth: false,
// Auth-related things should be constant-time to avoid timing attacks. // Auth-related things should be constant-time to avoid timing attacks.
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if auth.IsBanned(key) { fingerprint := Fingerprint(key)
return nil, errBanned ok, err := auth.Check(fingerprint)
if !ok {
return nil, err
} }
if !auth.IsWhitelisted(key) { perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
return nil, errNotWhitelisted
}
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
return perm, nil return perm, nil
}, },
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
if auth.IsBanned(nil) { if !auth.AllowAnonymous() {
return nil, errNoInteractive return nil, errors.New("public key authentication required")
}
if !auth.IsWhitelisted(nil) {
return nil, errNotWhitelisted
} }
return nil, nil return nil, nil
}, },
@ -51,7 +43,8 @@ func MakeNoAuth() *ssh.ServerConfig {
NoClientAuth: false, NoClientAuth: false,
// Auth-related things should be constant-time to avoid timing attacks. // Auth-related things should be constant-time to avoid timing attacks.
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
return nil, nil perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
return perm, nil
}, },
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
return nil, nil return nil, nil

View File

@ -11,12 +11,12 @@ import (
// Extending ssh/terminal to include a closer interface // Extending ssh/terminal to include a closer interface
type Terminal struct { type Terminal struct {
terminal.Terminal terminal.Terminal
Conn ssh.Conn Conn *ssh.ServerConn
Channel ssh.Channel Channel ssh.Channel
} }
// Make new terminal from a session channel // Make new terminal from a session channel
func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) { func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
if ch.ChannelType() != "session" { if ch.ChannelType() != "session" {
return nil, errors.New("terminal requires session channel") return nil, errors.New("terminal requires session channel")
} }
@ -41,7 +41,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
} }
// Find session channel and make a Terminal from it // Find session channel and make a Terminal from it
func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, err error) { func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
for ch := range channels { for ch := range channels {
if t := ch.ChannelType(); t != "session" { if t := ch.ChannelType(); t != "session" {
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))