diff --git a/host.go b/host.go index dedbb3c..62e6e1a 100644 --- a/host.go +++ b/host.go @@ -49,6 +49,8 @@ type Host struct { // GetMOTD is used to reload the motd from an external source GetMOTD func() (string, error) + // OnUserJoined is used to notify when a user joins a host + OnUserJoined func(*message.User) } // NewHost creates a Host on top of an existing listener. @@ -114,32 +116,6 @@ func (h *Host) Connect(term *sshd.Terminal) { } user.SetConfig(cfg) - - // Load user config overrides from ENV - // TODO: Would be nice to skip the command parsing pipeline just to load - // config values. Would need to factor out some command handler logic into - // accessible helpers. - env := term.Env() - for _, e := range env { - switch e.Key { - case "SSHCHAT_TIMESTAMP": - if e.Value != "" && e.Value != "0" { - cmd := "/timestamp" - if e.Value != "1" { - cmd += " " + e.Value - } - if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok { - h.Room.HandleMsg(msg) - } - } - case "SSHCHAT_THEME": - cmd := "/theme " + e.Value - if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok { - h.Room.HandleMsg(msg) - } - } - } - go user.Consume() // Close term once user is closed. @@ -168,6 +144,31 @@ func (h *Host) Connect(term *sshd.Terminal) { return } + // Load user config overrides from ENV + // TODO: Would be nice to skip the command parsing pipeline just to load + // config values. Would need to factor out some command handler logic into + // accessible helpers. + env := term.Env() + for _, e := range env { + switch e.Key { + case "SSHCHAT_TIMESTAMP": + if e.Value != "" && e.Value != "0" { + cmd := "/timestamp" + if e.Value != "1" { + cmd += " " + e.Value + } + if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok { + h.Room.HandleMsg(msg) + } + } + case "SSHCHAT_THEME": + cmd := "/theme " + e.Value + if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok { + h.Room.HandleMsg(msg) + } + } + } + // Successfully joined. if !apiMode { term.SetPrompt(GetPrompt(user)) @@ -183,6 +184,10 @@ func (h *Host) Connect(term *sshd.Terminal) { logger.Debugf("[%s] Joined: %s", term.Conn.RemoteAddr(), user.Name()) + if h.OnUserJoined != nil { + h.OnUserJoined(user) + } + for { line, err := term.ReadLine() if err == io.EOF { diff --git a/host_test.go b/host_test.go index 76b6d0a..8c190e6 100644 --- a/host_test.go +++ b/host_test.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/rsa" "errors" + "fmt" "io" mathRand "math/rand" "strings" @@ -78,7 +79,7 @@ func TestHostGetPrompt(t *testing.T) { func TestHostNameCollision(t *testing.T) { t.Skip("Test is flakey on CI. :(") - + key, err := sshd.NewRandomSigner(512) if err != nil { t.Fatal(err) @@ -284,3 +285,85 @@ func TestHostKick(t *testing.T) { t.Error(err) } } + +func TestTimestampEnvConfig(t *testing.T) { + cases := []struct { + input string + timeformat *string + }{ + {"", strptr("15:04")}, + {"1", strptr("15:04")}, + {"0", nil}, + {"time +8h", strptr("15:04")}, + {"datetime +8h", strptr("2006-01-02 15:04:05")}, + } + for _, tc := range cases { + u, err := connectUserWithConfig("dingus", map[string]string{ + "SSHCHAT_TIMESTAMP": tc.input, + }) + if err != nil { + t.Fatal(err) + } + userConfig := u.Config() + if userConfig.Timeformat != nil && tc.timeformat != nil { + if *userConfig.Timeformat != *tc.timeformat { + t.Fatal("unexpected timeformat:", *userConfig.Timeformat, "expected:", *tc.timeformat) + } + } + } +} + +func strptr(s string) *string { + return &s +} + +func connectUserWithConfig(name string, envConfig map[string]string) (*message.User, error) { + key, err := sshd.NewRandomSigner(512) + if err != nil { + return nil, fmt.Errorf("unable to create signer: %w", err) + } + config := sshd.MakeNoAuth() + config.AddHostKey(key) + + s, err := sshd.ListenSSH("localhost:0", config) + if err != nil { + return nil, fmt.Errorf("unable to create a test server: %w", err) + } + defer s.Close() + host := NewHost(s, nil) + + newUsers := make(chan *message.User) + host.OnUserJoined = func(u *message.User) { + newUsers <- u + } + go host.Serve() + + clientConfig := sshd.NewClientConfig(name) + conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig) + if err != nil { + return nil, fmt.Errorf("unable to connect to test ssh-chat server: %w", err) + } + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + return nil, fmt.Errorf("unable to open session: %w", err) + } + defer session.Close() + + for key := range envConfig { + session.Setenv(key, envConfig[key]) + } + + err = session.Shell() + if err != nil { + return nil, fmt.Errorf("unable to open shell: %w", err) + } + + for u := range newUsers { + if u.Name() == name { + return u, nil + } + } + return nil, fmt.Errorf("user %s not found in the host", name) +}