mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-06-06 18:33:05 +03:00
Channel Member now wrapping User with metadata, new Auth struct.
This commit is contained in:
parent
6874601c0b
commit
6a662bf358
@ -13,12 +13,17 @@ const channelBuffer = 10
|
|||||||
// closed.
|
// closed.
|
||||||
var ErrChannelClosed = errors.New("channel closed")
|
var ErrChannelClosed = errors.New("channel closed")
|
||||||
|
|
||||||
|
// Member is a User with per-Channel metadata attached to it.
|
||||||
|
type Member struct {
|
||||||
|
*User
|
||||||
|
Op bool
|
||||||
|
}
|
||||||
|
|
||||||
// Channel definition, also a Set of User Items
|
// Channel definition, also a Set of User Items
|
||||||
type Channel struct {
|
type Channel struct {
|
||||||
topic string
|
topic string
|
||||||
history *History
|
history *History
|
||||||
users *Set
|
members *Set
|
||||||
ops *Set
|
|
||||||
broadcast chan Message
|
broadcast chan Message
|
||||||
commands Commands
|
commands Commands
|
||||||
closed bool
|
closed bool
|
||||||
@ -32,8 +37,7 @@ func NewChannel() *Channel {
|
|||||||
return &Channel{
|
return &Channel{
|
||||||
broadcast: broadcast,
|
broadcast: broadcast,
|
||||||
history: NewHistory(historyLen),
|
history: NewHistory(historyLen),
|
||||||
users: NewSet(),
|
members: NewSet(),
|
||||||
ops: NewSet(),
|
|
||||||
commands: *defaultCommands,
|
commands: *defaultCommands,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -47,10 +51,10 @@ func (ch *Channel) SetCommands(commands Commands) {
|
|||||||
func (ch *Channel) Close() {
|
func (ch *Channel) Close() {
|
||||||
ch.closeOnce.Do(func() {
|
ch.closeOnce.Do(func() {
|
||||||
ch.closed = true
|
ch.closed = true
|
||||||
ch.users.Each(func(u Item) {
|
ch.members.Each(func(u Item) {
|
||||||
u.(*User).Close()
|
u.(*User).Close()
|
||||||
})
|
})
|
||||||
ch.users.Clear()
|
ch.members.Clear()
|
||||||
close(ch.broadcast)
|
close(ch.broadcast)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -75,8 +79,8 @@ func (ch *Channel) HandleMsg(m Message) {
|
|||||||
skipUser = fromMsg.From()
|
skipUser = fromMsg.From()
|
||||||
}
|
}
|
||||||
|
|
||||||
ch.users.Each(func(u Item) {
|
ch.members.Each(func(u Item) {
|
||||||
user := u.(*User)
|
user := u.(*Member).User
|
||||||
if skip && skipUser == user {
|
if skip && skipUser == user {
|
||||||
// Skip
|
// Skip
|
||||||
return
|
return
|
||||||
@ -108,18 +112,18 @@ func (ch *Channel) Join(u *User) error {
|
|||||||
if ch.closed {
|
if ch.closed {
|
||||||
return ErrChannelClosed
|
return ErrChannelClosed
|
||||||
}
|
}
|
||||||
err := ch.users.Add(u)
|
err := ch.members.Add(&Member{u, false})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.users.Len())
|
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.members.Len())
|
||||||
ch.Send(NewAnnounceMsg(s))
|
ch.Send(NewAnnounceMsg(s))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Leave the channel as a user, will announce.
|
// Leave the channel as a user, will announce. Mostly used during setup.
|
||||||
func (ch *Channel) Leave(u *User) error {
|
func (ch *Channel) Leave(u *User) error {
|
||||||
err := ch.users.Remove(u)
|
err := ch.members.Remove(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -128,6 +132,26 @@ func (ch *Channel) Leave(u *User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Member returns a corresponding Member object to a User if the Member is
|
||||||
|
// present in this channel.
|
||||||
|
func (ch *Channel) Member(u *User) (*Member, bool) {
|
||||||
|
m, err := ch.members.Get(u.Id())
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// Check that it's the same user
|
||||||
|
if m.(*Member).User != u {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return m.(*Member), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOp returns whether a user is an operator in this channel.
|
||||||
|
func (ch *Channel) IsOp(u *User) bool {
|
||||||
|
m, ok := ch.Member(u)
|
||||||
|
return ok && m.Op
|
||||||
|
}
|
||||||
|
|
||||||
// Topic of the channel.
|
// Topic of the channel.
|
||||||
func (ch *Channel) Topic() string {
|
func (ch *Channel) Topic() string {
|
||||||
return ch.topic
|
return ch.topic
|
||||||
@ -141,9 +165,9 @@ func (ch *Channel) SetTopic(s string) {
|
|||||||
// NamesPrefix lists all members' names with a given prefix, used to query
|
// NamesPrefix lists all members' names with a given prefix, used to query
|
||||||
// for autocompletion purposes.
|
// for autocompletion purposes.
|
||||||
func (ch *Channel) NamesPrefix(prefix string) []string {
|
func (ch *Channel) NamesPrefix(prefix string) []string {
|
||||||
users := ch.users.ListPrefix(prefix)
|
members := ch.members.ListPrefix(prefix)
|
||||||
names := make([]string, len(users))
|
names := make([]string, len(members))
|
||||||
for i, u := range users {
|
for i, u := range members {
|
||||||
names[i] = u.(*User).Name()
|
names[i] = u.(*User).Name()
|
||||||
}
|
}
|
||||||
return names
|
return names
|
||||||
|
@ -98,9 +98,8 @@ func init() {
|
|||||||
c.Add(Command{
|
c.Add(Command{
|
||||||
Prefix: "/help",
|
Prefix: "/help",
|
||||||
Handler: func(channel *Channel, msg CommandMsg) error {
|
Handler: func(channel *Channel, msg CommandMsg) error {
|
||||||
user := msg.From()
|
op := channel.IsOp(msg.From())
|
||||||
op := channel.ops.In(user)
|
channel.Send(NewSystemMsg(channel.commands.Help(op), msg.From()))
|
||||||
channel.Send(NewSystemMsg(channel.commands.Help(op), user))
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@ -193,11 +192,12 @@ func init() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
c.Add(Command{
|
c.Add(Command{
|
||||||
|
Op: true,
|
||||||
Prefix: "/op",
|
Prefix: "/op",
|
||||||
PrefixHelp: "USER",
|
PrefixHelp: "USER",
|
||||||
Help: "Mark user as admin.",
|
Help: "Mark user as admin.",
|
||||||
Handler: func(channel *Channel, msg CommandMsg) error {
|
Handler: func(channel *Channel, msg CommandMsg) error {
|
||||||
if !channel.ops.In(msg.From()) {
|
if !channel.IsOp(msg.From()) {
|
||||||
return errors.New("must be op")
|
return errors.New("must be op")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,13 +206,14 @@ func init() {
|
|||||||
return errors.New("must specify user")
|
return errors.New("must specify user")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Add support for fingerprint-based op'ing.
|
// TODO: Add support for fingerprint-based op'ing. This will
|
||||||
user, err := channel.users.Get(Id(args[0]))
|
// probably need to live in host land.
|
||||||
|
member, err := channel.members.Get(Id(args[0]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("user not found")
|
return errors.New("user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
channel.ops.Add(user)
|
member.(*Member).Op = true
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
17
cmd.go
17
cmd.go
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -93,8 +94,8 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: MakeAuth
|
auth := Auth{}
|
||||||
config := sshd.MakeNoAuth()
|
config := sshd.MakeAuth(auth)
|
||||||
config.AddHostKey(signer)
|
config.AddHostKey(signer)
|
||||||
|
|
||||||
s, err := sshd.ListenSSH(options.Bind, config)
|
s, err := sshd.ListenSSH(options.Bind, config)
|
||||||
@ -106,11 +107,10 @@ func main() {
|
|||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
host := NewHost(s)
|
host := NewHost(s)
|
||||||
go host.Serve()
|
host.auth = &auth
|
||||||
|
|
||||||
/* TODO:
|
|
||||||
for _, fingerprint := range options.Admin {
|
for _, fingerprint := range options.Admin {
|
||||||
server.Op(fingerprint)
|
auth.Op(fingerprint)
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.Whitelist != "" {
|
if options.Whitelist != "" {
|
||||||
@ -123,7 +123,7 @@ func main() {
|
|||||||
|
|
||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
server.Whitelist(scanner.Text())
|
auth.Whitelist(scanner.Text())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,9 +137,10 @@ func main() {
|
|||||||
// hack to normalize line endings into \r\n
|
// hack to normalize line endings into \r\n
|
||||||
motdString = strings.Replace(motdString, "\r\n", "\n", -1)
|
motdString = strings.Replace(motdString, "\r\n", "\n", -1)
|
||||||
motdString = strings.Replace(motdString, "\n", "\r\n", -1)
|
motdString = strings.Replace(motdString, "\n", "\r\n", -1)
|
||||||
server.SetMotd(motdString)
|
host.SetMotd(motdString)
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
go host.Serve()
|
||||||
|
|
||||||
// Construct interrupt handler
|
// Construct interrupt handler
|
||||||
sig := make(chan os.Signal, 1)
|
sig := make(chan os.Signal, 1)
|
||||||
|
12
host.go
12
host.go
@ -14,9 +14,12 @@ import (
|
|||||||
type Host struct {
|
type Host struct {
|
||||||
listener *sshd.SSHListener
|
listener *sshd.SSHListener
|
||||||
channel *chat.Channel
|
channel *chat.Channel
|
||||||
|
|
||||||
|
motd string
|
||||||
|
auth *Auth
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHost creates a Host on top of an existing listener
|
// NewHost creates a Host on top of an existing listener.
|
||||||
func NewHost(listener *sshd.SSHListener) *Host {
|
func NewHost(listener *sshd.SSHListener) *Host {
|
||||||
ch := chat.NewChannel()
|
ch := chat.NewChannel()
|
||||||
h := Host{
|
h := Host{
|
||||||
@ -27,7 +30,12 @@ func NewHost(listener *sshd.SSHListener) *Host {
|
|||||||
return &h
|
return &h
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect a specific Terminal to this host and its channel
|
// SetMotd sets the host's message of the day.
|
||||||
|
func (h *Host) SetMotd(motd string) {
|
||||||
|
h.motd = motd
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect a specific Terminal to this host and its channel.
|
||||||
func (h *Host) Connect(term *sshd.Terminal) {
|
func (h *Host) Connect(term *sshd.Terminal) {
|
||||||
name := term.Conn.User()
|
name := term.Conn.User()
|
||||||
term.AutoCompleteCallback = h.AutoCompleteFunction
|
term.AutoCompleteCallback = h.AutoCompleteFunction
|
||||||
|
29
sshd/auth.go
29
sshd/auth.go
@ -9,13 +9,9 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errBanned = errors.New("banned")
|
|
||||||
var errNotWhitelisted = errors.New("not whitelisted")
|
|
||||||
var errNoInteractive = errors.New("public key authentication required")
|
|
||||||
|
|
||||||
type Auth interface {
|
type Auth interface {
|
||||||
IsBanned(ssh.PublicKey) bool
|
AllowAnonymous() bool
|
||||||
IsWhitelisted(ssh.PublicKey) bool
|
Check(string) (bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeAuth(auth Auth) *ssh.ServerConfig {
|
func MakeAuth(auth Auth) *ssh.ServerConfig {
|
||||||
@ -23,21 +19,17 @@ func MakeAuth(auth Auth) *ssh.ServerConfig {
|
|||||||
NoClientAuth: false,
|
NoClientAuth: false,
|
||||||
// Auth-related things should be constant-time to avoid timing attacks.
|
// Auth-related things should be constant-time to avoid timing attacks.
|
||||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
if auth.IsBanned(key) {
|
fingerprint := Fingerprint(key)
|
||||||
return nil, errBanned
|
ok, err := auth.Check(fingerprint)
|
||||||
|
if !ok {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
if !auth.IsWhitelisted(key) {
|
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": fingerprint}}
|
||||||
return nil, errNotWhitelisted
|
|
||||||
}
|
|
||||||
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
|
|
||||||
return perm, nil
|
return perm, nil
|
||||||
},
|
},
|
||||||
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||||
if auth.IsBanned(nil) {
|
if !auth.AllowAnonymous() {
|
||||||
return nil, errNoInteractive
|
return nil, errors.New("public key authentication required")
|
||||||
}
|
|
||||||
if !auth.IsWhitelisted(nil) {
|
|
||||||
return nil, errNotWhitelisted
|
|
||||||
}
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
@ -51,7 +43,8 @@ func MakeNoAuth() *ssh.ServerConfig {
|
|||||||
NoClientAuth: false,
|
NoClientAuth: false,
|
||||||
// Auth-related things should be constant-time to avoid timing attacks.
|
// Auth-related things should be constant-time to avoid timing attacks.
|
||||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
return nil, nil
|
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
|
||||||
|
return perm, nil
|
||||||
},
|
},
|
||||||
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -11,12 +11,12 @@ import (
|
|||||||
// Extending ssh/terminal to include a closer interface
|
// Extending ssh/terminal to include a closer interface
|
||||||
type Terminal struct {
|
type Terminal struct {
|
||||||
terminal.Terminal
|
terminal.Terminal
|
||||||
Conn ssh.Conn
|
Conn *ssh.ServerConn
|
||||||
Channel ssh.Channel
|
Channel ssh.Channel
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make new terminal from a session channel
|
// Make new terminal from a session channel
|
||||||
func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
|
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
|
||||||
if ch.ChannelType() != "session" {
|
if ch.ChannelType() != "session" {
|
||||||
return nil, errors.New("terminal requires session channel")
|
return nil, errors.New("terminal requires session channel")
|
||||||
}
|
}
|
||||||
@ -41,7 +41,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Find session channel and make a Terminal from it
|
// Find session channel and make a Terminal from it
|
||||||
func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
|
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
|
||||||
for ch := range channels {
|
for ch := range channels {
|
||||||
if t := ch.ChannelType(); t != "session" {
|
if t := ch.ChannelType(); t != "session" {
|
||||||
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user