mirror of
https://github.com/shazow/ssh-chat.git
synced 2025-04-18 17:57:41 +03:00
Testing for net.
This commit is contained in:
parent
59ac8bb037
commit
7beb7f99bb
@ -7,7 +7,7 @@ var logger *stdlog.Logger
|
|||||||
|
|
||||||
func SetLogger(w io.Writer) {
|
func SetLogger(w io.Writer) {
|
||||||
flags := stdlog.Flags()
|
flags := stdlog.Flags()
|
||||||
prefix := "[chat] "
|
prefix := "[sshd] "
|
||||||
logger = stdlog.New(w, prefix, flags)
|
logger = stdlog.New(w, prefix, flags)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ package sshd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
@ -19,8 +18,7 @@ func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
l := socket.(SSHListener)
|
l := SSHListener{socket, config}
|
||||||
l.config = config
|
|
||||||
return &l, nil
|
return &l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,15 +39,14 @@ func (l *SSHListener) ServeTerminal() <-chan *Terminal {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer l.Close()
|
defer l.Close()
|
||||||
|
defer close(ch)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := l.Accept()
|
conn, err := l.Accept()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("Failed to accept connection: %v", err)
|
logger.Printf("Failed to accept connection: %v", err)
|
||||||
if err == syscall.EINVAL {
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Goroutineify to resume accepting sockets early
|
// Goroutineify to resume accepting sockets early
|
||||||
|
137
sshd/net_test.go
Normal file
137
sshd/net_test.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
package sshd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: Move some of these into their own package?
|
||||||
|
|
||||||
|
func MakeKey(bits int) (ssh.Signer, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, bits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ssh.NewSignerFromKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientSession(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
|
||||||
|
config := &ssh.ClientConfig{
|
||||||
|
User: name,
|
||||||
|
Auth: []ssh.AuthMethod{
|
||||||
|
ssh.KeyboardInteractive(func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
||||||
|
return
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := ssh.Dial("tcp", host, config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
in, err := session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := session.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = session.Shell()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
handler(out, in)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerInit(t *testing.T) {
|
||||||
|
config := MakeNoAuth()
|
||||||
|
s, err := ListenSSH(":badport", config)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("should fail on bad port")
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err = ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeTerminals(t *testing.T) {
|
||||||
|
signer, err := MakeKey(512)
|
||||||
|
config := MakeNoAuth()
|
||||||
|
config.AddHostKey(signer)
|
||||||
|
|
||||||
|
s, err := ListenSSH(":0", config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
terminals := s.ServeTerminal()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Accept one terminal, read from it, echo back, close.
|
||||||
|
term := <-terminals
|
||||||
|
term.SetPrompt("> ")
|
||||||
|
|
||||||
|
line, err := term.ReadLine()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
_, err = term.Write([]byte("echo: " + line + "\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
term.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
host := s.Addr().String()
|
||||||
|
name := "foo"
|
||||||
|
|
||||||
|
err = NewClientSession(host, name, func(r io.Reader, w io.WriteCloser) {
|
||||||
|
// Consume if there is anything
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
w.Write([]byte("hello\r\n"))
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
_, err := io.Copy(buf, r)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := "> hello\r\necho: hello\r\n"
|
||||||
|
actual := buf.String()
|
||||||
|
if actual != expected {
|
||||||
|
t.Errorf("Got `%s`; expected `%s`", actual, expected)
|
||||||
|
}
|
||||||
|
s.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
@ -1,98 +0,0 @@
|
|||||||
package sshd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server holds all the fields used by a server
|
|
||||||
type Server struct {
|
|
||||||
sshConfig *ssh.ServerConfig
|
|
||||||
done chan struct{}
|
|
||||||
started time.Time
|
|
||||||
sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize a new server
|
|
||||||
func NewServer(privateKey []byte) (*Server, error) {
|
|
||||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
server := Server{
|
|
||||||
done: make(chan struct{}),
|
|
||||||
started: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
config := MakeNoAuth()
|
|
||||||
config.AddHostKey(signer)
|
|
||||||
|
|
||||||
server.sshConfig = config
|
|
||||||
|
|
||||||
return &server, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start starts the server
|
|
||||||
func (s *Server) Start(laddr string) error {
|
|
||||||
// Once a ServerConfig has been configured, connections can be
|
|
||||||
// accepted.
|
|
||||||
socket, err := net.Listen("tcp", laddr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Infof("Listening on %s", laddr)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer socket.Close()
|
|
||||||
for {
|
|
||||||
conn, err := socket.Accept()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("Failed to accept connection: %v", err)
|
|
||||||
if err == syscall.EINVAL {
|
|
||||||
// TODO: Handle shutdown more gracefully?
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Goroutineify to resume accepting sockets early.
|
|
||||||
go func() {
|
|
||||||
// From a standard TCP connection to an encrypted SSH connection
|
|
||||||
sshConn, channels, requests, err := ssh.NewServerConn(conn, s.sshConfig)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("Failed to handshake: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go ssh.DiscardRequests(requests)
|
|
||||||
|
|
||||||
client := NewClient(s, sshConn)
|
|
||||||
go client.handleChannels(channels)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-s.done
|
|
||||||
socket.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the server
|
|
||||||
func (s *Server) Stop() {
|
|
||||||
s.Lock()
|
|
||||||
for _, client := range s.clients {
|
|
||||||
client.Conn.Close()
|
|
||||||
}
|
|
||||||
s.Unlock()
|
|
||||||
|
|
||||||
close(s.done)
|
|
||||||
}
|
|
@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
// Extending ssh/terminal to include a closer interface
|
// Extending ssh/terminal to include a closer interface
|
||||||
type Terminal struct {
|
type Terminal struct {
|
||||||
*terminal.Terminal
|
terminal.Terminal
|
||||||
Conn ssh.Conn
|
Conn ssh.Conn
|
||||||
Channel ssh.Channel
|
Channel ssh.Channel
|
||||||
}
|
}
|
||||||
@ -25,7 +25,7 @@ func NewTerminal(conn ssh.Conn, ch ssh.NewChannel) (*Terminal, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
term := Terminal{
|
term := Terminal{
|
||||||
terminal.NewTerminal(channel, "Connecting..."),
|
*terminal.NewTerminal(channel, "Connecting..."),
|
||||||
conn,
|
conn,
|
||||||
channel,
|
channel,
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user