Fix SSHCHAT_TIMESTAMP env variables (#392)

* Fixes Env Vars to pass config to ssh-chat.

The env vars were beign parsed and set to the host
before the user was even added to the host and
hence ignored. This change moves the env var parsing
to after initializing the user.

TODO: tests, completeness+reliability

* cleaned up the test

* reduced test flakyness by adding wait instead of being optimistic

Co-authored-by: Akshay <akshay.shekher@gmail.com>
This commit is contained in:
Akshay Shekher 2021-05-02 09:18:31 -07:00 committed by GitHub
parent 46eaf037e3
commit e1e534344e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 115 additions and 27 deletions

57
host.go
View File

@ -49,6 +49,8 @@ type Host struct {
// GetMOTD is used to reload the motd from an external source
GetMOTD func() (string, error)
// OnUserJoined is used to notify when a user joins a host
OnUserJoined func(*message.User)
}
// NewHost creates a Host on top of an existing listener.
@ -114,32 +116,6 @@ func (h *Host) Connect(term *sshd.Terminal) {
}
user.SetConfig(cfg)
// Load user config overrides from ENV
// TODO: Would be nice to skip the command parsing pipeline just to load
// config values. Would need to factor out some command handler logic into
// accessible helpers.
env := term.Env()
for _, e := range env {
switch e.Key {
case "SSHCHAT_TIMESTAMP":
if e.Value != "" && e.Value != "0" {
cmd := "/timestamp"
if e.Value != "1" {
cmd += " " + e.Value
}
if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok {
h.Room.HandleMsg(msg)
}
}
case "SSHCHAT_THEME":
cmd := "/theme " + e.Value
if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok {
h.Room.HandleMsg(msg)
}
}
}
go user.Consume()
// Close term once user is closed.
@ -168,6 +144,31 @@ func (h *Host) Connect(term *sshd.Terminal) {
return
}
// Load user config overrides from ENV
// TODO: Would be nice to skip the command parsing pipeline just to load
// config values. Would need to factor out some command handler logic into
// accessible helpers.
env := term.Env()
for _, e := range env {
switch e.Key {
case "SSHCHAT_TIMESTAMP":
if e.Value != "" && e.Value != "0" {
cmd := "/timestamp"
if e.Value != "1" {
cmd += " " + e.Value
}
if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok {
h.Room.HandleMsg(msg)
}
}
case "SSHCHAT_THEME":
cmd := "/theme " + e.Value
if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok {
h.Room.HandleMsg(msg)
}
}
}
// Successfully joined.
if !apiMode {
term.SetPrompt(GetPrompt(user))
@ -183,6 +184,10 @@ func (h *Host) Connect(term *sshd.Terminal) {
logger.Debugf("[%s] Joined: %s", term.Conn.RemoteAddr(), user.Name())
if h.OnUserJoined != nil {
h.OnUserJoined(user)
}
for {
line, err := term.ReadLine()
if err == io.EOF {

View File

@ -5,6 +5,7 @@ import (
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"io"
mathRand "math/rand"
"strings"
@ -78,7 +79,7 @@ func TestHostGetPrompt(t *testing.T) {
func TestHostNameCollision(t *testing.T) {
t.Skip("Test is flakey on CI. :(")
key, err := sshd.NewRandomSigner(512)
if err != nil {
t.Fatal(err)
@ -284,3 +285,85 @@ func TestHostKick(t *testing.T) {
t.Error(err)
}
}
func TestTimestampEnvConfig(t *testing.T) {
cases := []struct {
input string
timeformat *string
}{
{"", strptr("15:04")},
{"1", strptr("15:04")},
{"0", nil},
{"time +8h", strptr("15:04")},
{"datetime +8h", strptr("2006-01-02 15:04:05")},
}
for _, tc := range cases {
u, err := connectUserWithConfig("dingus", map[string]string{
"SSHCHAT_TIMESTAMP": tc.input,
})
if err != nil {
t.Fatal(err)
}
userConfig := u.Config()
if userConfig.Timeformat != nil && tc.timeformat != nil {
if *userConfig.Timeformat != *tc.timeformat {
t.Fatal("unexpected timeformat:", *userConfig.Timeformat, "expected:", *tc.timeformat)
}
}
}
}
func strptr(s string) *string {
return &s
}
func connectUserWithConfig(name string, envConfig map[string]string) (*message.User, error) {
key, err := sshd.NewRandomSigner(512)
if err != nil {
return nil, fmt.Errorf("unable to create signer: %w", err)
}
config := sshd.MakeNoAuth()
config.AddHostKey(key)
s, err := sshd.ListenSSH("localhost:0", config)
if err != nil {
return nil, fmt.Errorf("unable to create a test server: %w", err)
}
defer s.Close()
host := NewHost(s, nil)
newUsers := make(chan *message.User)
host.OnUserJoined = func(u *message.User) {
newUsers <- u
}
go host.Serve()
clientConfig := sshd.NewClientConfig(name)
conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig)
if err != nil {
return nil, fmt.Errorf("unable to connect to test ssh-chat server: %w", err)
}
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
return nil, fmt.Errorf("unable to open session: %w", err)
}
defer session.Close()
for key := range envConfig {
session.Setenv(key, envConfig[key])
}
err = session.Shell()
if err != nil {
return nil, fmt.Errorf("unable to open shell: %w", err)
}
for u := range newUsers {
if u.Name() == name {
return u, nil
}
}
return nil, fmt.Errorf("user %s not found in the host", name)
}