diff --git a/host.go b/host.go index a564613..62e6e1a 100644 --- a/host.go +++ b/host.go @@ -49,6 +49,8 @@ type Host struct { // GetMOTD is used to reload the motd from an external source GetMOTD func() (string, error) + // OnUserJoined is used to notify when a user joins a host + OnUserJoined func(*message.User) } // NewHost creates a Host on top of an existing listener. @@ -182,6 +184,10 @@ func (h *Host) Connect(term *sshd.Terminal) { logger.Debugf("[%s] Joined: %s", term.Conn.RemoteAddr(), user.Name()) + if h.OnUserJoined != nil { + h.OnUserJoined(user) + } + for { line, err := term.ReadLine() if err == io.EOF { diff --git a/host_test.go b/host_test.go index f5dd6ab..8c190e6 100644 --- a/host_test.go +++ b/host_test.go @@ -331,6 +331,11 @@ func connectUserWithConfig(name string, envConfig map[string]string) (*message.U } defer s.Close() host := NewHost(s, nil) + + newUsers := make(chan *message.User) + host.OnUserJoined = func(u *message.User) { + newUsers <- u + } go host.Serve() clientConfig := sshd.NewClientConfig(name) @@ -355,9 +360,10 @@ func connectUserWithConfig(name string, envConfig map[string]string) (*message.U return nil, fmt.Errorf("unable to open shell: %w", err) } - u, ok := host.GetUser(name) - if !ok { - return nil, fmt.Errorf("user %s not found in host", name) + for u := range newUsers { + if u.Name() == name { + return u, nil + } } - return u, nil + return nil, fmt.Errorf("user %s not found in the host", name) }