Abstracted sshd.Connection; Op works now.

This commit is contained in:
Andrey Petrov 2015-01-10 13:46:36 -08:00
parent d8d5deac1c
commit d5626b7514
7 changed files with 93 additions and 20 deletions

29
auth.go
View File

@ -81,7 +81,16 @@ func (a *Auth) Op(key ssh.PublicKey) {
a.Unlock() a.Unlock()
} }
// Whitelist will set a fingerprint as a whitelisted user. // IsOp checks if a public key is an op.
func (a Auth) IsOp(key ssh.PublicKey) bool {
authkey := NewAuthKey(key)
a.RLock()
_, ok := a.ops[authkey]
a.RUnlock()
return ok
}
// Whitelist will set a public key as a whitelisted user.
func (a *Auth) Whitelist(key ssh.PublicKey) { func (a *Auth) Whitelist(key ssh.PublicKey) {
authkey := NewAuthKey(key) authkey := NewAuthKey(key)
a.Lock() a.Lock()
@ -89,6 +98,15 @@ func (a *Auth) Whitelist(key ssh.PublicKey) {
a.Unlock() a.Unlock()
} }
// IsWhitelisted checks if a public key is whitelisted.
func (a Auth) IsWhitelisted(key ssh.PublicKey) bool {
authkey := NewAuthKey(key)
a.RLock()
_, ok := a.whitelist[authkey]
a.RUnlock()
return ok
}
// Ban will set a fingerprint as banned. // Ban will set a fingerprint as banned.
func (a *Auth) Ban(key ssh.PublicKey) { func (a *Auth) Ban(key ssh.PublicKey) {
authkey := NewAuthKey(key) authkey := NewAuthKey(key)
@ -96,3 +114,12 @@ func (a *Auth) Ban(key ssh.PublicKey) {
a.banned[authkey] = struct{}{} a.banned[authkey] = struct{}{}
a.Unlock() a.Unlock()
} }
// IsBanned will set a fingerprint as banned.
func (a Auth) IsBanned(key ssh.PublicKey) bool {
authkey := NewAuthKey(key)
a.RLock()
_, ok := a.whitelist[authkey]
a.RUnlock()
return ok
}

View File

@ -35,16 +35,16 @@ func TestAuthWhitelist(t *testing.T) {
auth.Whitelist(key) auth.Whitelist(key)
key_clone, err := ClonePublicKey(key) keyClone, err := ClonePublicKey(key)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if string(key_clone.Marshal()) != string(key.Marshal()) { if string(keyClone.Marshal()) != string(key.Marshal()) {
t.Error("Clone key does not match.") t.Error("Clone key does not match.")
} }
ok, err = auth.Check(key_clone) ok, err = auth.Check(keyClone)
if !ok || err != nil { if !ok || err != nil {
t.Error("Failed to permit whitelisted:", err) t.Error("Failed to permit whitelisted:", err)
} }

View File

@ -114,17 +114,18 @@ func (ch *Channel) Send(m Message) {
} }
// Join the channel as a user, will announce. // Join the channel as a user, will announce.
func (ch *Channel) Join(u *User) error { func (ch *Channel) Join(u *User) (*Member, error) {
if ch.closed { if ch.closed {
return ErrChannelClosed return nil, ErrChannelClosed
} }
err := ch.members.Add(&Member{u, false}) member := Member{u, false}
err := ch.members.Add(&member)
if err != nil { if err != nil {
return err return nil, err
} }
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.members.Len()) s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), ch.members.Len())
ch.Send(NewAnnounceMsg(s)) ch.Send(NewAnnounceMsg(s))
return nil return &member, nil
} }
// Leave the channel as a user, will announce. Mostly used during setup. // Leave the channel as a user, will announce. Mostly used during setup.

View File

@ -28,7 +28,7 @@ func TestChannelJoin(t *testing.T) {
go ch.Serve() go ch.Serve()
defer ch.Close() defer ch.Close()
err := ch.Join(u) _, err := ch.Join(u)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -66,7 +66,7 @@ func TestChannelDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) {
ch := NewChannel() ch := NewChannel()
defer ch.Close() defer ch.Close()
err := ch.Join(u) _, err := ch.Join(u)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -101,7 +101,7 @@ func TestChannelQuietToggleBroadcasts(t *testing.T) {
ch := NewChannel() ch := NewChannel()
defer ch.Close() defer ch.Close()
err := ch.Join(u) _, err := ch.Join(u)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -138,7 +138,7 @@ func TestQuietToggleDisplayState(t *testing.T) {
go ch.Serve() go ch.Serve()
defer ch.Close() defer ch.Close()
err := ch.Join(u) _, err := ch.Join(u)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -174,7 +174,7 @@ func TestChannelNames(t *testing.T) {
go ch.Serve() go ch.Serve()
defer ch.Close() defer ch.Close()
err := ch.Join(u) _, err := ch.Join(u)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

17
host.go
View File

@ -46,9 +46,17 @@ func (h *Host) SetMotd(motd string) {
h.motd = motd h.motd = motd
} }
func (h Host) isOp(conn sshd.Connection) bool {
key, ok := conn.PublicKey()
if !ok {
return false
}
return h.auth.IsOp(key)
}
// Connect a specific Terminal to this host and its channel. // Connect a specific Terminal to this host and its channel.
func (h *Host) Connect(term *sshd.Terminal) { func (h *Host) Connect(term *sshd.Terminal) {
name := term.Conn.User() name := term.Conn.Name()
term.AutoCompleteCallback = h.AutoCompleteFunction term.AutoCompleteCallback = h.AutoCompleteFunction
user := chat.NewUserScreen(name, term) user := chat.NewUserScreen(name, term)
@ -60,11 +68,11 @@ func (h *Host) Connect(term *sshd.Terminal) {
}() }()
defer user.Close() defer user.Close()
err := h.channel.Join(user) member, err := h.channel.Join(user)
if err == chat.ErrIdTaken { if err == chat.ErrIdTaken {
// Try again... // Try again...
user.SetName(fmt.Sprintf("Guest%d", h.count)) user.SetName(fmt.Sprintf("Guest%d", h.count))
err = h.channel.Join(user) member, err = h.channel.Join(user)
} }
if err != nil { if err != nil {
logger.Errorf("Failed to join: %s", err) logger.Errorf("Failed to join: %s", err)
@ -75,6 +83,9 @@ func (h *Host) Connect(term *sshd.Terminal) {
term.SetPrompt(GetPrompt(user)) term.SetPrompt(GetPrompt(user))
h.count++ h.count++
// Should the user be op'd?
member.Op = h.isOp(term.Conn)
for { for {
line, err := term.ReadLine() line, err := term.ReadLine()
if err == io.EOF { if err == io.EOF {

View File

@ -19,7 +19,8 @@ func (a RejectAuth) Check(ssh.PublicKey) (bool, error) {
} }
func consume(ch <-chan *Terminal) { func consume(ch <-chan *Terminal) {
for range ch {} for _ = range ch {
}
} }
func TestClientReject(t *testing.T) { func TestClientReject(t *testing.T) {

View File

@ -8,10 +8,43 @@ import (
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
) )
// Connection is an interface with fields necessary to operate an sshd host.
type Connection interface {
PublicKey() (ssh.PublicKey, bool)
Name() string
Close() error
}
type sshConn struct {
*ssh.ServerConn
}
func (c sshConn) PublicKey() (ssh.PublicKey, bool) {
if c.Permissions == nil {
return nil, false
}
s, ok := c.Permissions.Extensions["pubkey"]
if !ok {
return nil, false
}
key, err := ssh.ParsePublicKey([]byte(s))
if err != nil {
return nil, false
}
return key, true
}
func (c sshConn) Name() string {
return c.User()
}
// Extending ssh/terminal to include a closer interface // Extending ssh/terminal to include a closer interface
type Terminal struct { type Terminal struct {
terminal.Terminal terminal.Terminal
Conn *ssh.ServerConn Conn Connection
Channel ssh.Channel Channel ssh.Channel
} }
@ -26,7 +59,7 @@ func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
} }
term := Terminal{ term := Terminal{
*terminal.NewTerminal(channel, "Connecting..."), *terminal.NewTerminal(channel, "Connecting..."),
conn, sshConn{conn},
channel, channel,
} }