mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-14 16:17:17 +03:00
sshd abstraction might be done, untested.
This commit is contained in:
parent
54b593ed47
commit
59ac8bb037
1
cmd.go
1
cmd.go
@ -34,6 +34,7 @@ var logLevels = []log.Level{
|
||||
}
|
||||
|
||||
var buildCommit string
|
||||
|
||||
func main() {
|
||||
options := Options{}
|
||||
parser := flags.NewParser(&options, flags.Default)
|
||||
|
68
sshd/auth.go
Normal file
68
sshd/auth.go
Normal file
@ -0,0 +1,68 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var errBanned = errors.New("banned")
|
||||
var errNotWhitelisted = errors.New("not whitelisted")
|
||||
var errNoInteractive = errors.New("public key authentication required")
|
||||
|
||||
type Auth interface {
|
||||
IsBanned(ssh.PublicKey) bool
|
||||
IsWhitelisted(ssh.PublicKey) bool
|
||||
}
|
||||
|
||||
func MakeAuth(auth Auth) *ssh.ServerConfig {
|
||||
config := ssh.ServerConfig{
|
||||
NoClientAuth: false,
|
||||
// Auth-related things should be constant-time to avoid timing attacks.
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if auth.IsBanned(key) {
|
||||
return nil, errBanned
|
||||
}
|
||||
if !auth.IsWhitelisted(key) {
|
||||
return nil, errNotWhitelisted
|
||||
}
|
||||
perm := &ssh.Permissions{Extensions: map[string]string{"fingerprint": Fingerprint(key)}}
|
||||
return perm, nil
|
||||
},
|
||||
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
if auth.IsBanned(nil) {
|
||||
return nil, errNoInteractive
|
||||
}
|
||||
if !auth.IsWhitelisted(nil) {
|
||||
return nil, errNotWhitelisted
|
||||
}
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
return &config
|
||||
}
|
||||
|
||||
func MakeNoAuth() *ssh.ServerConfig {
|
||||
config := ssh.ServerConfig{
|
||||
NoClientAuth: false,
|
||||
// Auth-related things should be constant-time to avoid timing attacks.
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
return nil, nil
|
||||
},
|
||||
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
return &config
|
||||
}
|
||||
|
||||
func Fingerprint(k ssh.PublicKey) string {
|
||||
hash := sha1.Sum(k.Marshal())
|
||||
r := fmt.Sprintf("% x", hash)
|
||||
return strings.Replace(r, " ", ":", -1)
|
||||
}
|
33
sshd/doc.go
Normal file
33
sshd/doc.go
Normal file
@ -0,0 +1,33 @@
|
||||
package sshd
|
||||
|
||||
/*
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
|
||||
config := MakeNoAuth()
|
||||
config.AddHostKey(signer)
|
||||
|
||||
s, err := ListenSSH("0.0.0.0:22", config)
|
||||
if err != nil {
|
||||
// Handle opening socket error
|
||||
}
|
||||
|
||||
terminals := s.ServeTerminal()
|
||||
|
||||
for term := range terminals {
|
||||
go func() {
|
||||
defer term.Close()
|
||||
term.SetPrompt("...")
|
||||
term.AutoCompleteCallback = nil // ...
|
||||
|
||||
for {
|
||||
line, err := term.Readline()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
term.Write(...)
|
||||
}
|
||||
|
||||
}()
|
||||
}
|
||||
*/
|
22
sshd/logger.go
Normal file
22
sshd/logger.go
Normal file
@ -0,0 +1,22 @@
|
||||
package sshd
|
||||
|
||||
import "io"
|
||||
import stdlog "log"
|
||||
|
||||
var logger *stdlog.Logger
|
||||
|
||||
func SetLogger(w io.Writer) {
|
||||
flags := stdlog.Flags()
|
||||
prefix := "[chat] "
|
||||
logger = stdlog.New(w, prefix, flags)
|
||||
}
|
||||
|
||||
type nullWriter struct{}
|
||||
|
||||
func (nullWriter) Write(data []byte) (int, error) {
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
SetLogger(nullWriter{})
|
||||
}
|
42
sshd/multi.go
Normal file
42
sshd/multi.go
Normal file
@ -0,0 +1,42 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Keep track of multiple errors and coerce them into one error
|
||||
type MultiError []error
|
||||
|
||||
func (e MultiError) Error() string {
|
||||
switch len(e) {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return e[0].Error()
|
||||
default:
|
||||
errs := []string{}
|
||||
for _, err := range e {
|
||||
errs = append(errs, err.Error())
|
||||
}
|
||||
return fmt.Sprintf("%d errors: %s", strings.Join(errs, "; "))
|
||||
}
|
||||
}
|
||||
|
||||
// Keep track of multiple closers and close them all as one closer
|
||||
type MultiCloser []io.Closer
|
||||
|
||||
func (c MultiCloser) Close() error {
|
||||
errors := MultiError{}
|
||||
for _, closer := range c {
|
||||
err := closer.Close()
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
if len(errors) == 0 {
|
||||
return nil
|
||||
}
|
||||
return errors
|
||||
}
|
68
sshd/net.go
Normal file
68
sshd/net.go
Normal file
@ -0,0 +1,68 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Container for the connection and ssh-related configuration
|
||||
type SSHListener struct {
|
||||
net.Listener
|
||||
config *ssh.ServerConfig
|
||||
}
|
||||
|
||||
// Make an SSH listener socket
|
||||
func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) {
|
||||
socket, err := net.Listen("tcp", laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l := socket.(SSHListener)
|
||||
l.config = config
|
||||
return &l, nil
|
||||
}
|
||||
|
||||
func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) {
|
||||
// Upgrade TCP connection to SSH connection
|
||||
sshConn, channels, requests, err := ssh.NewServerConn(conn, l.config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(requests)
|
||||
return NewSession(sshConn, channels)
|
||||
}
|
||||
|
||||
// Accept incoming connections as terminal requests and yield them
|
||||
func (l *SSHListener) ServeTerminal() <-chan *Terminal {
|
||||
ch := make(chan *Terminal)
|
||||
|
||||
go func() {
|
||||
defer l.Close()
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
|
||||
if err != nil {
|
||||
logger.Printf("Failed to accept connection: %v", err)
|
||||
if err == syscall.EINVAL {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Goroutineify to resume accepting sockets early
|
||||
go func() {
|
||||
term, err := l.handleConn(conn)
|
||||
if err != nil {
|
||||
logger.Printf("Failed to handshake: %v", err)
|
||||
return
|
||||
}
|
||||
ch <- term
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
69
sshd/pty.go
Normal file
69
sshd/pty.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Borrowed from go.crypto circa 2011
|
||||
package sshd
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
// parsePtyRequest parses the payload of the pty-req message and extracts the
|
||||
// dimensions of the terminal. See RFC 4254, section 6.2.
|
||||
func parsePtyRequest(s []byte) (width, height int, ok bool) {
|
||||
_, s, ok = parseString(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
width32, s, ok := parseUint32(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
height32, _, ok := parseUint32(s)
|
||||
width = int(width32)
|
||||
height = int(height32)
|
||||
if width < 1 {
|
||||
ok = false
|
||||
}
|
||||
if height < 1 {
|
||||
ok = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseWinchRequest(s []byte) (width, height int, ok bool) {
|
||||
width32, s, ok := parseUint32(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
height32, s, ok := parseUint32(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
width = int(width32)
|
||||
height = int(height32)
|
||||
if width < 1 {
|
||||
ok = false
|
||||
}
|
||||
if height < 1 {
|
||||
ok = false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseString(in []byte) (out string, rest []byte, ok bool) {
|
||||
if len(in) < 4 {
|
||||
return
|
||||
}
|
||||
length := binary.BigEndian.Uint32(in)
|
||||
if uint32(len(in)) < 4+length {
|
||||
return
|
||||
}
|
||||
out = string(in[4 : 4+length])
|
||||
rest = in[4+length:]
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
||||
func parseUint32(in []byte) (uint32, []byte, bool) {
|
||||
if len(in) < 4 {
|
||||
return 0, nil, false
|
||||
}
|
||||
return binary.BigEndian.Uint32(in), in[4:], true
|
||||
}
|
98
sshd/server.go
Normal file
98
sshd/server.go
Normal file
@ -0,0 +1,98 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Server holds all the fields used by a server
|
||||
type Server struct {
|
||||
sshConfig *ssh.ServerConfig
|
||||
done chan struct{}
|
||||
started time.Time
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// Initialize a new server
|
||||
func NewServer(privateKey []byte) (*Server, error) {
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
server := Server{
|
||||
done: make(chan struct{}),
|
||||
started: time.Now(),
|
||||
}
|
||||
|
||||
config := MakeNoAuth()
|
||||
config.AddHostKey(signer)
|
||||
|
||||
server.sshConfig = config
|
||||
|
||||
return &server, nil
|
||||
}
|
||||
|
||||
// Start starts the server
|
||||
func (s *Server) Start(laddr string) error {
|
||||
// Once a ServerConfig has been configured, connections can be
|
||||
// accepted.
|
||||
socket, err := net.Listen("tcp", laddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("Listening on %s", laddr)
|
||||
|
||||
go func() {
|
||||
defer socket.Close()
|
||||
for {
|
||||
conn, err := socket.Accept()
|
||||
|
||||
if err != nil {
|
||||
logger.Printf("Failed to accept connection: %v", err)
|
||||
if err == syscall.EINVAL {
|
||||
// TODO: Handle shutdown more gracefully?
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Goroutineify to resume accepting sockets early.
|
||||
go func() {
|
||||
// From a standard TCP connection to an encrypted SSH connection
|
||||
sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig)
|
||||
if err != nil {
|
||||
logger.Printf("Failed to handshake: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(requests)
|
||||
|
||||
client := NewClient(s, sshConn)
|
||||
go client.handleChannels(channels)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-s.done
|
||||
socket.Close()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *Server) Stop() {
|
||||
s.Lock()
|
||||
for _, client := range s.clients {
|
||||
client.Conn.Close()
|
||||
}
|
||||
s.Unlock()
|
||||
|
||||
close(s.done)
|
||||
}
|
93
sshd/terminal.go
Normal file
93
sshd/terminal.go
Normal file
@ -0,0 +1,93 @@
|
||||
package sshd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
// Extending ssh/terminal to include a closer interface
|
||||
type Terminal struct {
|
||||
*terminal.Terminal
|
||||
Conn ssh.Conn
|
||||
Channel ssh.Channel
|
||||
}
|
||||
|
||||
// Make new terminal from a session channel
|
||||
func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
|
||||
if ch.ChannelType() != "session" {
|
||||
return nil, errors.New("terminal requires session channel")
|
||||
}
|
||||
channel, requests, err := ch.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
term := Terminal{
|
||||
terminal.NewTerminal(channel, "Connecting..."),
|
||||
conn,
|
||||
channel,
|
||||
}
|
||||
|
||||
go term.listen(requests)
|
||||
return &term, nil
|
||||
}
|
||||
|
||||
// Find session channel and make a Terminal from it
|
||||
func NewSession(conn ssh.Conn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
|
||||
for ch := range channels {
|
||||
if t := ch.ChannelType(); t != "session" {
|
||||
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
|
||||
continue
|
||||
}
|
||||
|
||||
term, err = NewTerminal(conn, ch)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return term, err
|
||||
}
|
||||
|
||||
// Close terminal and ssh connection
|
||||
func (t *Terminal) Close() error {
|
||||
return MultiCloser{t.Channel, t.Conn}.Close()
|
||||
}
|
||||
|
||||
// Negotiate terminal type and settings
|
||||
func (t *Terminal) listen(requests <-chan *ssh.Request) {
|
||||
hasShell := false
|
||||
|
||||
for req := range requests {
|
||||
var width, height int
|
||||
var ok bool
|
||||
|
||||
switch req.Type {
|
||||
case "shell":
|
||||
if !hasShell {
|
||||
ok = true
|
||||
hasShell = true
|
||||
}
|
||||
case "pty-req":
|
||||
width, height, ok = parsePtyRequest(req.Payload)
|
||||
if ok {
|
||||
// TODO: Hardcode width to 100000?
|
||||
err := t.SetSize(width, height)
|
||||
ok = err == nil
|
||||
}
|
||||
case "window-change":
|
||||
width, height, ok = parseWinchRequest(req.Payload)
|
||||
if ok {
|
||||
// TODO: Hardcode width to 100000?
|
||||
err := t.SetSize(width, height)
|
||||
ok = err == nil
|
||||
}
|
||||
}
|
||||
|
||||
if req.WantReply {
|
||||
req.Reply(ok, nil)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user