From e1e534344eca3d8fc468d958d0ff3742f7cf5358 Mon Sep 17 00:00:00 2001 From: Akshay Shekher Date: Sun, 2 May 2021 09:18:31 -0700 Subject: [PATCH] Fix SSHCHAT_TIMESTAMP env variables (#392) * Fixes Env Vars to pass config to ssh-chat. The env vars were beign parsed and set to the host before the user was even added to the host and hence ignored. This change moves the env var parsing to after initializing the user. TODO: tests, completeness+reliability * cleaned up the test * reduced test flakyness by adding wait instead of being optimistic Co-authored-by: Akshay --- host.go | 57 +++++++++++++++++++---------------- host_test.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 115 insertions(+), 27 deletions(-) diff --git a/host.go b/host.go index dedbb3c..62e6e1a 100644 --- a/host.go +++ b/host.go @@ -49,6 +49,8 @@ type Host struct { // GetMOTD is used to reload the motd from an external source GetMOTD func() (string, error) + // OnUserJoined is used to notify when a user joins a host + OnUserJoined func(*message.User) } // NewHost creates a Host on top of an existing listener. @@ -114,32 +116,6 @@ func (h *Host) Connect(term *sshd.Terminal) { } user.SetConfig(cfg) - - // Load user config overrides from ENV - // TODO: Would be nice to skip the command parsing pipeline just to load - // config values. Would need to factor out some command handler logic into - // accessible helpers. - env := term.Env() - for _, e := range env { - switch e.Key { - case "SSHCHAT_TIMESTAMP": - if e.Value != "" && e.Value != "0" { - cmd := "/timestamp" - if e.Value != "1" { - cmd += " " + e.Value - } - if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok { - h.Room.HandleMsg(msg) - } - } - case "SSHCHAT_THEME": - cmd := "/theme " + e.Value - if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok { - h.Room.HandleMsg(msg) - } - } - } - go user.Consume() // Close term once user is closed. @@ -168,6 +144,31 @@ func (h *Host) Connect(term *sshd.Terminal) { return } + // Load user config overrides from ENV + // TODO: Would be nice to skip the command parsing pipeline just to load + // config values. Would need to factor out some command handler logic into + // accessible helpers. + env := term.Env() + for _, e := range env { + switch e.Key { + case "SSHCHAT_TIMESTAMP": + if e.Value != "" && e.Value != "0" { + cmd := "/timestamp" + if e.Value != "1" { + cmd += " " + e.Value + } + if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok { + h.Room.HandleMsg(msg) + } + } + case "SSHCHAT_THEME": + cmd := "/theme " + e.Value + if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok { + h.Room.HandleMsg(msg) + } + } + } + // Successfully joined. if !apiMode { term.SetPrompt(GetPrompt(user)) @@ -183,6 +184,10 @@ func (h *Host) Connect(term *sshd.Terminal) { logger.Debugf("[%s] Joined: %s", term.Conn.RemoteAddr(), user.Name()) + if h.OnUserJoined != nil { + h.OnUserJoined(user) + } + for { line, err := term.ReadLine() if err == io.EOF { diff --git a/host_test.go b/host_test.go index 76b6d0a..8c190e6 100644 --- a/host_test.go +++ b/host_test.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/rsa" "errors" + "fmt" "io" mathRand "math/rand" "strings" @@ -78,7 +79,7 @@ func TestHostGetPrompt(t *testing.T) { func TestHostNameCollision(t *testing.T) { t.Skip("Test is flakey on CI. :(") - + key, err := sshd.NewRandomSigner(512) if err != nil { t.Fatal(err) @@ -284,3 +285,85 @@ func TestHostKick(t *testing.T) { t.Error(err) } } + +func TestTimestampEnvConfig(t *testing.T) { + cases := []struct { + input string + timeformat *string + }{ + {"", strptr("15:04")}, + {"1", strptr("15:04")}, + {"0", nil}, + {"time +8h", strptr("15:04")}, + {"datetime +8h", strptr("2006-01-02 15:04:05")}, + } + for _, tc := range cases { + u, err := connectUserWithConfig("dingus", map[string]string{ + "SSHCHAT_TIMESTAMP": tc.input, + }) + if err != nil { + t.Fatal(err) + } + userConfig := u.Config() + if userConfig.Timeformat != nil && tc.timeformat != nil { + if *userConfig.Timeformat != *tc.timeformat { + t.Fatal("unexpected timeformat:", *userConfig.Timeformat, "expected:", *tc.timeformat) + } + } + } +} + +func strptr(s string) *string { + return &s +} + +func connectUserWithConfig(name string, envConfig map[string]string) (*message.User, error) { + key, err := sshd.NewRandomSigner(512) + if err != nil { + return nil, fmt.Errorf("unable to create signer: %w", err) + } + config := sshd.MakeNoAuth() + config.AddHostKey(key) + + s, err := sshd.ListenSSH("localhost:0", config) + if err != nil { + return nil, fmt.Errorf("unable to create a test server: %w", err) + } + defer s.Close() + host := NewHost(s, nil) + + newUsers := make(chan *message.User) + host.OnUserJoined = func(u *message.User) { + newUsers <- u + } + go host.Serve() + + clientConfig := sshd.NewClientConfig(name) + conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig) + if err != nil { + return nil, fmt.Errorf("unable to connect to test ssh-chat server: %w", err) + } + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + return nil, fmt.Errorf("unable to open session: %w", err) + } + defer session.Close() + + for key := range envConfig { + session.Setenv(key, envConfig[key]) + } + + err = session.Shell() + if err != nil { + return nil, fmt.Errorf("unable to open shell: %w", err) + } + + for u := range newUsers { + if u.Name() == name { + return u, nil + } + } + return nil, fmt.Errorf("user %s not found in the host", name) +}