Compare commits

..

No commits in common. "master" and "v1.0" have entirely different histories.
master ... v1.0

71 changed files with 1135 additions and 5949 deletions

1
.github/FUNDING.yml vendored
View File

@ -1 +0,0 @@
github: shazow

View File

@ -1,63 +0,0 @@
name: Bug Report
description: Create a report to fix something that is broken
title: "[Bug]: "
labels: ["bug", "needs-triage"]
body:
- type: textarea
id: summary
validations:
required: true
attributes:
label: Summary
description: A clear and concise description of what the bug is.
- type: input
id: client-version
attributes:
label: Client version
description: Paste output of `ssh -V`
validations:
required: true
- type: input
id: server-version
attributes:
label: Server version
description: Paste output of `ssh-chat --version`
placeholder: e.g., ssh-chat v0.1.0
validations:
required: true
- type: input
id: latest-server-version
attributes:
label: Latest server version available (at time of report)
description: Check https://github.com/shazow/ssh-chat/releases and paste the latest version
placeholder: e.g., v0.2.0
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: To Reproduce
description: Steps to reproduce the behavior
placeholder: |
1. Full command to run...
2. Resulting output...
render: markdown
validations:
required: true
- type: textarea
id: expected
attributes:
label: Expected behavior
description: A clear and concise description of what you expected to happen.
placeholder: Describe the expected behavior
validations:
required: true
- type: textarea
id: context
attributes:
label: Additional context
description: Add any other context about the problem here.

View File

@ -1,32 +0,0 @@
name: Go
on:
push:
branches: [ master ]
pull_request:
branches: [ master ]
jobs:
build:
name: Build
runs-on: ubuntu-latest
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ^1
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: Get dependencies
run: go get -v -t -d ./...
- name: Build
run: go build -v .
- name: Test
run: go test -race -vet "all" -v ./...

3
.gitignore vendored
View File

@ -1,7 +1,4 @@
/build
host_key host_key
host_key.pub host_key.pub
ssh-chat ssh-chat
*.log *.log
.*
vendor/

View File

@ -1,18 +1,17 @@
language: go
notifications: notifications:
email: false email: false
language: go
go:
- 1.x
env:
- CGO_ENABLED=0 GO111MODULE=on
install: install:
- go get github.com/gordonklaus/ineffassign - export PATH=$PATH:$HOME/gopath/bin
- go get github.com/GeertJohan/fgt
- go get github.com/golang/lint/golint
- make deps
script: script:
- diff -u <(echo -n) <(gofmt -d .) # TODO: Bring this back: - fgt golint
- ineffassign . - make test
- go test -vet "all" -v ./...
go:
- 1.4

View File

@ -1,78 +0,0 @@
# Code of Conduct
This code of conduct applies to both: The `ssh.chat` code participants and the `ssh-chat` code contributors.
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces (such as inside the chat) and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at andrey.petrov@shazow.net. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

View File

@ -1,44 +0,0 @@
# How to Contribute
This is a brief guide on how you can contribute to `ssh-chat`
## Getting Started
Contributions come in the form of bug reports, feature requests, documentation and wiki edits, and pull requests. If you have an issue with a certain feature or encountered a bug, you will refer to the Issues section.
### Submitting an Issue
`ssh-chat` has a lot of issues, and we try to help every one of them as best we can. The best way to submit an issue is simple: check if it already exists using the search bar. If you encounter a bug or want a certain feature, make sure no one else has submitted it before so we can avoid duplicate issues.
When submitting a bug report, make sure you submit very specific details surrounding the bug:
* What did you do to create the bug?
* Was there any error code given or exceptions thrown?
* What operating system are you and which version of OpenSSH are you using?
* If you built from source, what version of Golang did you use to build `ssh-chat`?
These details should help us to come to a solution.
For feature requests, use the search bar to look up if a feature you want has already been requested. If there was an issue already create, you can vote on it using the "thumbs up" emoji.
### Submitting Code
Submitting code is another way to contribute. The best way to start contributing code would be to look at all the open Issues and see if you can find an interesting bug to tackle. Or if there's a feature you want to implement, check if an Issue was opened for it, or even submit the feature request yourself to open up a discussion.
When submitting code, you should, in your commit message, refer to which issue you are working on. That way when the issue is resolved, or if future bugs are introduced because of it, we can refer to the pull request made and try to fix any bugs.
Once submitted, the code must meet the following conditions in order to be accepted:
* Code must be formatted using `gofmt`
* Code must pass code review
* Code must pass the Travis CI testing stage
If the code meets these conditions, then it will be merged into the `master` branch.
### Discussion Channels
Development discussion of `ssh-chat` can be found on Shazow's public `ssh-chat` server. Connect using any `ssh` client with the following:
```bash
$ ssh username@chat.shazow.net
```

View File

@ -1,21 +1,18 @@
FROM golang:alpine AS builder #
# Usage example:
# $ docker build -t ssh-chat .
# $ docker run -d -p 0.0.0.0:(your host machine port):2022 --name ssh-chat ssh-chat
#
FROM golang:1.4
MAINTAINER Alvin Lai <al@alvinlai.com>
WORKDIR /usr/src/app RUN apt-get update
RUN apt-get install -y openssh-client
COPY . . RUN go get github.com/shazow/ssh-chat
RUN apk add make openssh RUN ssh-keygen -f ~/.ssh/id_rsa -t rsa -N ''
RUN make build
EXPOSE 2022
FROM alpine CMD ["-i", "/root/.ssh/id_rsa", "-vv", "--bind", "\":2022\""]
ENTRYPOINT ["ssh-chat"]
RUN apk add openssh
RUN mkdir /root/.ssh
WORKDIR /root/.ssh
RUN ssh-keygen -t rsa -C "chatkey" -f id_rsa
WORKDIR /usr/local/bin
COPY --from=builder /usr/src/app/ssh-chat .
RUN chmod +x ssh-chat
CMD ["/usr/local/bin/ssh-chat"]

45
Gopkg.lock generated
View File

@ -1,45 +0,0 @@
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
[[projects]]
branch = "master"
name = "github.com/alexcesaro/log"
packages = [".","golog"]
revision = "61e686294e58a8698a9e1091268bb4ac1116bd5e"
[[projects]]
branch = "master"
name = "github.com/howeyc/gopass"
packages = ["."]
revision = "bf9dde6d0d2c004a008c27aaee91170c786f6db8"
[[projects]]
name = "github.com/jessevdk/go-flags"
packages = ["."]
revision = "96dc06278ce32a0e9d957d590bb987c81ee66407"
version = "v1.3.0"
[[projects]]
branch = "master"
name = "github.com/shazow/rateio"
packages = ["."]
revision = "e8e00881e5c12090412414be41c04ca9c8a71106"
[[projects]]
branch = "master"
name = "golang.org/x/crypto"
packages = ["curve25519","ed25519","ed25519/internal/edwards25519","internal/chacha20","poly1305","ssh","ssh/terminal"]
revision = "ee41a25c63fb5b74abf2213abb6dee3751e6ac4a"
[[projects]]
branch = "master"
name = "golang.org/x/sys"
packages = ["unix","windows"]
revision = "2c42eef0765b9837fbdab12011af7830f55f88f0"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "48a7f7477a28e61efdd4256fe7f426bfaf93df53b5731e905088c0e9c2f10d3b"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -1,46 +0,0 @@
# Gopkg.toml example
#
# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
# for detailed Gopkg.toml documentation.
#
# required = ["github.com/user/thing/cmd/thing"]
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
#
# [[constraint]]
# name = "github.com/user/project"
# version = "1.0.0"
#
# [[constraint]]
# name = "github.com/user/project2"
# branch = "dev"
# source = "github.com/myfork/project2"
#
# [[override]]
# name = "github.com/x/y"
# version = "2.4.0"
[[constraint]]
branch = "master"
name = "github.com/alexcesaro/log"
[[constraint]]
branch = "master"
name = "github.com/dustin/go-humanize"
[[constraint]]
branch = "master"
name = "github.com/howeyc/gopass"
[[constraint]]
name = "github.com/jessevdk/go-flags"
version = "1.3.0"
[[constraint]]
branch = "master"
name = "github.com/shazow/rateio"
[[constraint]]
branch = "master"
name = "golang.org/x/crypto"

View File

@ -3,13 +3,14 @@ KEY = host_key
PORT = 2022 PORT = 2022
SRCS = %.go SRCS = %.go
VERSION := $(shell git describe --tags --dirty --always 2> /dev/null || echo "dev")
LDFLAGS = -X main.Version=$(VERSION) -extldflags "-static"
all: $(BINARY) all: $(BINARY)
$(BINARY): **/**/*.go **/*.go *.go $(BINARY): deps **/**/*.go **/*.go *.go
go build -ldflags "$(LDFLAGS)" ./cmd/ssh-chat go build -ldflags "-X main.buildCommit `git describe --long --tags --dirty --always`" ./cmd/ssh-chat
deps:
go get ./...
build: $(BINARY) build: $(BINARY)
@ -26,22 +27,5 @@ debug: $(BINARY) $(KEY)
./$(BINARY) --pprof 6060 -i $(KEY) --bind ":$(PORT)" -vv ./$(BINARY) --pprof 6060 -i $(KEY) --bind ":$(PORT)" -vv
test: test:
go test -race -test.timeout 5s ./... go test ./...
golint ./...
release:
# We use static linking for release build. LDFLAGS via
# https://github.com/golang/go/issues/26492
# Can replace LDFLAGS with -static once the issue has been resolved.
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 LDFLAGS='$(LDFLAGS)' ./build_release "github.com/shazow/ssh-chat/cmd/ssh-chat" README.md LICENSE
CGO_ENABLED=0 GOOS=linux GOARCH=386 LDFLAGS='$(LDFLAGS)' ./build_release "github.com/shazow/ssh-chat/cmd/ssh-chat" README.md LICENSE
CGO_ENABLED=0 GOOS=linux GOARCH=arm GOARM=6 LDFLAGS='$(LDFLAGS)' ./build_release "github.com/shazow/ssh-chat/cmd/ssh-chat" README.md LICENSE
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 LDFLAGS='$(LDFLAGS)' ./build_release "github.com/shazow/ssh-chat/cmd/ssh-chat" README.md LICENSE
CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 LDFLAGS='$(LDFLAGS)' ./build_release "github.com/shazow/ssh-chat/cmd/ssh-chat" README.md LICENSE
CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 LDFLAGS='$(LDFLAGS)' ./build_release "github.com/shazow/ssh-chat/cmd/ssh-chat" README.md LICENSE
CGO_ENABLED=0 GOOS=freebsd GOARCH=amd64 LDFLAGS='$(LDFLAGS)' ./build_release "github.com/shazow/ssh-chat/cmd/ssh-chat" README.md LICENSE
CGO_ENABLED=0 GOOS=windows GOARCH=386 LDFLAGS='$(LDFLAGS)' ./build_release "github.com/shazow/ssh-chat/cmd/ssh-chat" README.md LICENSE
deploy: build/ssh-chat-linux_amd64.tgz
ssh -p 2022 ssh.chat tar xvz < build/ssh-chat-linux_amd64.tgz
@echo " --- Ready to deploy ---"
@echo "Run: ssh -t -p 2022 ssh.chat sudo systemctl restart ssh-chat"

33
NOTICE
View File

@ -1,33 +0,0 @@
## x/crypto/ssh/terminal
This project contains a fork of https://github.com/golang/crypto/tree/master/ssh/terminal
under the sshd/terminal directory. The project's original license applies:
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,8 +1,6 @@
[![Build Status](https://travis-ci.org/shazow/ssh-chat.svg?branch=master)](https://travis-ci.org/shazow/ssh-chat) [![Build Status](https://travis-ci.org/shazow/ssh-chat.svg?branch=master)](https://travis-ci.org/shazow/ssh-chat)
[![GoDoc](https://godoc.org/github.com/shazow/ssh-chat?status.svg)](https://godoc.org/github.com/shazow/ssh-chat)
[![Downloads](https://img.shields.io/github/downloads/shazow/ssh-chat/total.svg?color=orange)](https://github.com/shazow/ssh-chat/releases)
[![Bountysource](https://www.bountysource.com/badge/team?team_id=52292&style=bounties_received)](https://www.bountysource.com/teams/ssh-chat/issues?utm_source=ssh-chat&utm_medium=shield&utm_campaign=bounties_received) [![Bountysource](https://www.bountysource.com/badge/team?team_id=52292&style=bounties_received)](https://www.bountysource.com/teams/ssh-chat/issues?utm_source=ssh-chat&utm_medium=shield&utm_campaign=bounties_received)
[![GoDoc](https://godoc.org/github.com/shazow/ssh-chat?status.svg)](https://godoc.org/github.com/shazow/ssh-chat)
# ssh-chat # ssh-chat
@ -12,35 +10,18 @@ Custom SSH server written in Go. Instead of a shell, you get a chat prompt.
Join the party: Join the party:
``` console ```
$ ssh ssh.chat $ ssh chat.shazow.net
``` ```
Please abide by our [project's Code of Conduct](https://github.com/shazow/ssh-chat/blob/master/CODE_OF_CONDUCT.md) while participating in chat. The server's RSA key fingerprint is `e5:d5:d1:75:90:38:42:f6:c7:03:d7:d0:56:7d:6a:db`. If you see something different, you might be [MITM](https://en.wikipedia.org/wiki/Man-in-the-middle_attack)'d.
The host's public key is `ssh.chat ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKPrQofxXqoz2y9A7NFkkENt6iW8/mvpfes3RY/41Oyt` and the fingerprint is `SHA256:yoqMXkCysMTBsvhu2yRoMUl+EmZKlvkN+ZKmL3115xU` (as of 2021-10-13).
If you see something different, you might be [MITM](https://en.wikipedia.org/wiki/Man-in-the-middle_attack)'d.
(Apologies if the server is down, try again shortly.) (Apologies if the server is down, try again shortly.)
## Downloading a release
Recent releases include builds for MacOS (darwin/amd64) and Linux (386,
amd64, and ARM6 for your RaspberryPi).
**[Grab the latest binary release here](https://github.com/shazow/ssh-chat/releases/)**.
Play around with it. Additional [deploy examples are here](https://github.com/shazow/ssh-chat/wiki/Deployment).
## Compiling / Developing ## Compiling / Developing
Most people just want the [latest binary release](https://github.com/shazow/ssh-chat/releases/). If you're sure you want to compile it from source, read on:
You can compile ssh-chat by using `make build`. The resulting binary is portable and You can compile ssh-chat by using `make build`. The resulting binary is portable and
can be run on any system with a similar OS and CPU arch. Go 1.8 or higher is required to compile. can be run on any system with a similar OS and CPU arch. Go 1.3 or higher is required to compile.
If you're developing on this repo, there is a handy Makefile that should set If you're developing on this repo, there is a handy Makefile that should set
things up with `make run`. things up with `make run`.
@ -49,23 +30,20 @@ Additionally, `make debug` runs the server with an http `pprof` server. This all
[http://localhost:6060/debug/pprof/]() and view profiling data. See [http://localhost:6060/debug/pprof/]() and view profiling data. See
[net/http/pprof](http://golang.org/pkg/net/http/pprof/) for more information about `pprof`. [net/http/pprof](http://golang.org/pkg/net/http/pprof/) for more information about `pprof`.
## Quick Start ## Quick Start
``` console ```
Usage: Usage:
ssh-chat [OPTIONS] ssh-chat [OPTIONS]
Application Options: Application Options:
-v, --verbose Show verbose logging. -v, --verbose Show verbose logging.
--version Print version and exit. -i, --identity= Private key to identify server with. (~/.ssh/id_rsa)
-i, --identity= Private key to identify server with. (default: ~/.ssh/id_rsa) --bind= Host and port to listen on. (0.0.0.0:2022)
--bind= Host and port to listen on. (default: 0.0.0.0:2022) --admin= Fingerprint of pubkey to mark as admin.
--admin= File of public keys who are admins. --whitelist= Optional file of pubkey fingerprints that are allowed to connect
--allowlist= Optional file of public keys who are allowed to connect. --motd= Message of the Day file (optional)
--motd= Optional Message of the Day file. --pprof= enable http server for pprof
--log= Write chat log to this file.
--pprof= Enable pprof http server for profiling.
Help Options: Help Options:
-h, --help Show this help message -h, --help Show this help message
@ -74,18 +52,26 @@ Help Options:
After doing `go get github.com/shazow/ssh-chat/...` on this repo, you should be able After doing `go get github.com/shazow/ssh-chat/...` on this repo, you should be able
to run a command like: to run a command like:
``` console ```
$ ssh-chat --verbose --bind ":22" --identity ~/.ssh/id_dsa $ ssh-chat --verbose --bind ":22" --identity ~/.ssh/id_dsa
``` ```
To bind on port 22, you'll need to make sure it's free (move any other ssh To bind on port 22, you'll need to make sure it's free (move any other ssh
daemons to another port) and run ssh-chat as root (or with sudo). daemons to another port) and run ssh-chat as root (or with sudo).
## Frequently Asked Questions ## Deploying with Docker
The FAQs can be found on the project's [Wiki page](https://github.com/shazow/ssh-chat/wiki/FAQ). You can run ssh-chat using a Docker image without manually installing go-lang:
Feel free to submit more questions to be answered and added to the page.
**Note: alvin/ssh-chat has v0 which is not the latest master branch as of this writing (Jan 23, 2015)**
```
$ docker pull alvin/ssh-chat
$ docker run -d -p 0.0.0.0:(your host machine port):2022 --name ssh-chat alvin/ssh-chat
```
See notes in the header of our Dockerfile for details on building your own image.
## License ## License
MIT This project is licensed under the MIT open source license.

294
auth.go
View File

@ -1,48 +1,30 @@
package sshchat package sshchat
import ( import (
"crypto/sha256"
"crypto/subtle"
"encoding/csv"
"errors" "errors"
"fmt"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
"github.com/shazow/ssh-chat/set"
"github.com/shazow/ssh-chat/sshd" "github.com/shazow/ssh-chat/sshd"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// KeyLoader loads public keys, e.g. from an authorized_keys file. // The error returned a key is checked that is not whitelisted, with whitelisting required.
// It must return a nil slice on error. var ErrNotWhitelisted = errors.New("not whitelisted")
type KeyLoader func() ([]ssh.PublicKey, error)
// ErrNotAllowed Is the error returned when a key is checked that is not allowlisted, // The error returned a key is checked that is banned.
// when allowlisting is enabled.
var ErrNotAllowed = errors.New("not allowed")
// ErrBanned is the error returned when a client is banned.
var ErrBanned = errors.New("banned") var ErrBanned = errors.New("banned")
// ErrIncorrectPassphrase is the error returned when a provided passphrase is incorrect.
var ErrIncorrectPassphrase = errors.New("incorrect passphrase")
// newAuthKey returns string from an ssh.PublicKey used to index the key in our lookup. // newAuthKey returns string from an ssh.PublicKey used to index the key in our lookup.
func newAuthKey(key ssh.PublicKey) string { func newAuthKey(key ssh.PublicKey) string {
if key == nil { if key == nil {
return "" return ""
} }
// FIXME: Is there a better way to index pubkeys without marshal'ing them into strings? // FIXME: Is there a way to index pubkeys without marshal'ing them into strings?
return sshd.Fingerprint(key) return sshd.Fingerprint(key)
} }
func newAuthItem(key ssh.PublicKey) set.Item {
return set.StringItem(newAuthKey(key))
}
// newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup. // newAuthAddr returns a string from a net.Addr used to index the address the key in our lookup.
func newAuthAddr(addr net.Addr) string { func newAuthAddr(addr net.Addr) string {
if addr == nil { if addr == nil {
@ -52,109 +34,52 @@ func newAuthAddr(addr net.Addr) string {
return host return host
} }
// Auth stores lookups for bans, allowlists, and ops. It implements the sshd.Auth interface. // Auth stores lookups for bans, whitelists, and ops. It implements the sshd.Auth interface.
// If the contained passphrase is not empty, it complements a allowlist.
type Auth struct { type Auth struct {
passphraseHash []byte sync.RWMutex
bannedAddr *set.Set bannedAddr *Set
bannedClient *set.Set banned *Set
banned *set.Set whitelist *Set
allowlist *set.Set ops *Set
ops *set.Set
settingsMu sync.RWMutex
allowlistMode bool
opLoader KeyLoader
allowlistLoader KeyLoader
} }
// NewAuth creates a new empty Auth. // NewAuth creates a new empty Auth.
func NewAuth() *Auth { func NewAuth() *Auth {
return &Auth{ return &Auth{
bannedAddr: set.New(), bannedAddr: NewSet(),
bannedClient: set.New(), banned: NewSet(),
banned: set.New(), whitelist: NewSet(),
allowlist: set.New(), ops: NewSet(),
ops: set.New(),
}
}
func (a *Auth) AllowlistMode() bool {
a.settingsMu.RLock()
defer a.settingsMu.RUnlock()
return a.allowlistMode
}
func (a *Auth) SetAllowlistMode(value bool) {
a.settingsMu.Lock()
defer a.settingsMu.Unlock()
a.allowlistMode = value
}
// SetPassphrase enables passphrase authentication with the given passphrase.
// If an empty passphrase is given, disable passphrase authentication.
func (a *Auth) SetPassphrase(passphrase string) {
if passphrase == "" {
a.passphraseHash = nil
} else {
hashArray := sha256.Sum256([]byte(passphrase))
a.passphraseHash = hashArray[:]
} }
} }
// AllowAnonymous determines if anonymous users are permitted. // AllowAnonymous determines if anonymous users are permitted.
func (a *Auth) AllowAnonymous() bool { func (a *Auth) AllowAnonymous() bool {
return !a.AllowlistMode() && a.passphraseHash == nil return a.whitelist.Len() == 0
} }
// AcceptPassphrase determines if passphrase authentication is accepted. // Check determines if a pubkey fingerprint is permitted.
func (a *Auth) AcceptPassphrase() bool { func (a *Auth) Check(addr net.Addr, key ssh.PublicKey) (bool, error) {
return a.passphraseHash != nil
}
// CheckBans checks IP, key and client bans.
func (a *Auth) CheckBans(addr net.Addr, key ssh.PublicKey, clientVersion string) error {
authkey := newAuthKey(key) authkey := newAuthKey(key)
var banned bool if a.whitelist.Len() != 0 {
if authkey != "" { // Only check whitelist if there is something in it, otherwise it's disabled.
banned = a.banned.In(authkey) whitelisted := a.whitelist.In(authkey)
if !whitelisted {
return false, ErrNotWhitelisted
}
return true, nil
} }
banned := a.banned.In(authkey)
if !banned { if !banned {
banned = a.bannedAddr.In(newAuthAddr(addr)) banned = a.bannedAddr.In(newAuthAddr(addr))
} }
if !banned { if banned {
banned = a.bannedClient.In(clientVersion) return false, ErrBanned
}
// Ops can bypass bans, just in case we ban ourselves.
if banned && !a.IsOp(key) {
return ErrBanned
} }
return nil return true, nil
}
// CheckPubkey determines if a pubkey fingerprint is permitted.
func (a *Auth) CheckPublicKey(key ssh.PublicKey) error {
authkey := newAuthKey(key)
allowlisted := a.allowlist.In(authkey)
if a.AllowAnonymous() || allowlisted || a.IsOp(key) {
return nil
} else {
return ErrNotAllowed
}
}
// CheckPassphrase determines if a passphrase is permitted.
func (a *Auth) CheckPassphrase(passphrase string) error {
if !a.AcceptPassphrase() {
return errors.New("passphrases not accepted") // this should never happen
}
passedPassphraseHash := sha256.Sum256([]byte(passphrase))
if subtle.ConstantTimeCompare(passedPassphraseHash[:], a.passphraseHash) == 0 {
return ErrIncorrectPassphrase
}
return nil
} }
// Op sets a public key as a known operator. // Op sets a public key as a known operator.
@ -162,79 +87,36 @@ func (a *Auth) Op(key ssh.PublicKey, d time.Duration) {
if key == nil { if key == nil {
return return
} }
authItem := newAuthItem(key) authkey := newAuthKey(key)
if d != 0 { if d != 0 {
a.ops.Set(set.Expire(authItem, d)) a.ops.AddExpiring(authkey, d)
} else { } else {
a.ops.Set(authItem) a.ops.Add(authkey)
} }
logger.Debugf("Added to ops: %q (for %s)", authItem.Key(), d) logger.Debugf("Added to ops: %s (for %s)", authkey, d)
} }
// IsOp checks if a public key is an op. // IsOp checks if a public key is an op.
func (a *Auth) IsOp(key ssh.PublicKey) bool { func (a *Auth) IsOp(key ssh.PublicKey) bool {
if key == nil {
return false
}
authkey := newAuthKey(key) authkey := newAuthKey(key)
return a.ops.In(authkey) return a.ops.In(authkey)
} }
// LoadOps sets the public keys form loader to operators and saves the loader for later use // Whitelist will set a public key as a whitelisted user.
func (a *Auth) LoadOps(loader KeyLoader) error { func (a *Auth) Whitelist(key ssh.PublicKey, d time.Duration) {
a.settingsMu.Lock()
a.opLoader = loader
a.settingsMu.Unlock()
return a.ReloadOps()
}
// ReloadOps sets the public keys from a loader saved in the last call to operators
func (a *Auth) ReloadOps() error {
a.settingsMu.RLock()
defer a.settingsMu.RUnlock()
return addFromLoader(a.opLoader, a.Op)
}
// Allowlist will set a public key as a allowlisted user.
func (a *Auth) Allowlist(key ssh.PublicKey, d time.Duration) {
if key == nil { if key == nil {
return return
} }
var err error authkey := newAuthKey(key)
authItem := newAuthItem(key)
if d != 0 { if d != 0 {
err = a.allowlist.Set(set.Expire(authItem, d)) a.whitelist.AddExpiring(authkey, d)
} else { } else {
err = a.allowlist.Set(authItem) a.whitelist.Add(authkey)
} }
if err == nil { logger.Debugf("Added to whitelist: %s (for %s)", authkey, d)
logger.Debugf("Added to allowlist: %q (for %s)", authItem.Key(), d)
} else {
logger.Errorf("Error adding %q to allowlist for %s: %s", authItem.Key(), d, err)
}
}
// LoadAllowlist adds the public keys from the loader to the allowlist and saves the loader for later use
func (a *Auth) LoadAllowlist(loader KeyLoader) error {
a.settingsMu.Lock()
a.allowlistLoader = loader
a.settingsMu.Unlock()
return a.ReloadAllowlist()
}
// LoadAllowlist adds the public keys from a loader saved in a previous call to the allowlist
func (a *Auth) ReloadAllowlist() error {
a.settingsMu.RLock()
defer a.settingsMu.RUnlock()
return addFromLoader(a.allowlistLoader, a.Allowlist)
}
func addFromLoader(loader KeyLoader, adder func(ssh.PublicKey, time.Duration)) error {
if loader == nil {
return nil
}
keys, err := loader()
for _, key := range keys {
adder(key, 0)
}
return err
} }
// Ban will set a public key as banned. // Ban will set a public key as banned.
@ -247,97 +129,21 @@ func (a *Auth) Ban(key ssh.PublicKey, d time.Duration) {
// BanFingerprint will set a public key fingerprint as banned. // BanFingerprint will set a public key fingerprint as banned.
func (a *Auth) BanFingerprint(authkey string, d time.Duration) { func (a *Auth) BanFingerprint(authkey string, d time.Duration) {
// FIXME: This is a case insensitive key, which isn't great...
authItem := set.StringItem(authkey)
if d != 0 { if d != 0 {
a.banned.Set(set.Expire(authItem, d)) a.banned.AddExpiring(authkey, d)
} else { } else {
a.banned.Set(authItem) a.banned.Add(authkey)
} }
logger.Debugf("Added to banned: %q (for %s)", authItem.Key(), d) logger.Debugf("Added to banned: %s (for %s)", authkey, d)
} }
// BanClient will set client version as banned. Useful for misbehaving bots. // Ban will set an IP address as banned.
func (a *Auth) BanClient(client string, d time.Duration) {
item := set.StringItem(client)
if d != 0 {
a.bannedClient.Set(set.Expire(item, d))
} else {
a.bannedClient.Set(item)
}
logger.Debugf("Added to banned: %q (for %s)", item.Key(), d)
}
// Banned returns the list of banned keys.
func (a *Auth) Banned() (ip []string, fingerprint []string, client []string) {
a.banned.Each(func(key string, _ set.Item) error {
fingerprint = append(fingerprint, key)
return nil
})
a.bannedAddr.Each(func(key string, _ set.Item) error {
ip = append(ip, key)
return nil
})
a.bannedClient.Each(func(key string, _ set.Item) error {
client = append(client, key)
return nil
})
return
}
// BanAddr will set an IP address as banned.
func (a *Auth) BanAddr(addr net.Addr, d time.Duration) { func (a *Auth) BanAddr(addr net.Addr, d time.Duration) {
authItem := set.StringItem(newAuthAddr(addr)) key := newAuthAddr(addr)
if d != 0 { if d != 0 {
a.bannedAddr.Set(set.Expire(authItem, d)) a.bannedAddr.AddExpiring(key, d)
} else { } else {
a.bannedAddr.Set(authItem) a.bannedAddr.Add(key)
} }
logger.Debugf("Added to bannedAddr: %q (for %s)", authItem.Key(), d) logger.Debugf("Added to bannedAddr: %s (for %s)", key, d)
}
// BanQuery takes space-separated key="value" pairs to ban, including ip, fingerprint, client.
// Fields without an = will be treated as a duration, applied to the next field.
// For example: 5s client=foo 10min ip=1.1.1.1
// Will ban client foo for 5 seconds, and ip 1.1.1.1 for 10min.
func (a *Auth) BanQuery(q string) error {
r := csv.NewReader(strings.NewReader(q))
r.Comma = ' '
fields, err := r.Read()
if err != nil {
return err
}
var d time.Duration
if last := fields[len(fields)-1]; !strings.Contains(last, "=") {
d, err = time.ParseDuration(last)
if err != nil {
return err
}
fields = fields[:len(fields)-1]
}
for _, field := range fields {
parts := strings.SplitN(field, "=", 2)
if len(parts) != 2 {
return fmt.Errorf("invalid query: %q", q)
}
key, value := parts[0], parts[1]
switch key {
case "client":
a.BanClient(value, d)
case "fingerprint":
// TODO: Add a validity check?
a.BanFingerprint(value, d)
case "ip":
ip := net.ParseIP(value)
if ip.String() == "" {
return fmt.Errorf("invalid ip value: %q", ip)
}
a.BanAddr(&net.TCPAddr{IP: ip}, d)
default:
return fmt.Errorf("unknown query field: %q", field)
}
}
return nil
} }

View File

@ -21,20 +21,19 @@ func ClonePublicKey(key ssh.PublicKey) (ssh.PublicKey, error) {
return ssh.ParsePublicKey(key.Marshal()) return ssh.ParsePublicKey(key.Marshal())
} }
func TestAuthAllowlist(t *testing.T) { func TestAuthWhitelist(t *testing.T) {
key, err := NewRandomPublicKey(512) key, err := NewRandomPublicKey(512)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
auth := NewAuth() auth := NewAuth()
err = auth.CheckPublicKey(key) ok, err := auth.Check(nil, key)
if err != nil { if !ok || err != nil {
t.Error("Failed to permit in default state:", err) t.Error("Failed to permit in default state:", err)
} }
auth.Allowlist(key, 0) auth.Whitelist(key, 0)
auth.SetAllowlistMode(true)
keyClone, err := ClonePublicKey(key) keyClone, err := ClonePublicKey(key)
if err != nil { if err != nil {
@ -45,9 +44,9 @@ func TestAuthAllowlist(t *testing.T) {
t.Error("Clone key does not match.") t.Error("Clone key does not match.")
} }
err = auth.CheckPublicKey(keyClone) ok, err = auth.Check(nil, keyClone)
if err != nil { if !ok || err != nil {
t.Error("Failed to permit allowlisted:", err) t.Error("Failed to permit whitelisted:", err)
} }
key2, err := NewRandomPublicKey(512) key2, err := NewRandomPublicKey(512)
@ -55,42 +54,9 @@ func TestAuthAllowlist(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = auth.CheckPublicKey(key2) ok, err = auth.Check(nil, key2)
if err == nil { if ok || err == nil {
t.Error("Failed to restrict not allowlisted:", err) t.Error("Failed to restrict not whitelisted:", err)
}
}
func TestAuthPassphrases(t *testing.T) {
auth := NewAuth()
if auth.AcceptPassphrase() {
t.Error("Doesn't known it won't accept passphrases.")
}
auth.SetPassphrase("")
if auth.AcceptPassphrase() {
t.Error("Doesn't known it won't accept passphrases.")
}
err := auth.CheckPassphrase("Pa$$w0rd")
if err == nil {
t.Error("Failed to deny without passphrase:", err)
}
auth.SetPassphrase("Pa$$w0rd")
err = auth.CheckPassphrase("Pa$$w0rd")
if err != nil {
t.Error("Failed to allow vaild passphrase:", err)
}
err = auth.CheckPassphrase("something else")
if err == nil {
t.Error("Failed to restrict wrong passphrase:", err)
}
auth.SetPassphrase("")
if auth.AcceptPassphrase() {
t.Error("Didn't clear passphrase.")
} }
} }

View File

@ -1,67 +0,0 @@
#!/usr/bin/env bash
usage() {
echo "Build and bundle Go releases with the current dir as the build dir."
echo "Usage: $0 PACKAGE [ASSETS...]"
}
main() {
set -eo pipefail
[[ "$TRACE" ]] && set -x
if [[ ! "$1" ]]; then
usage
exit 1
fi
if [[ ! "$GOOS" ]]; then
export GOOS="linux"
echo "Defaulting to GOOS=$GOOS"
fi
if [[ ! "$GOARCH" ]]; then
export GOARCH="amd64"
echo "Defaulting to GOARCH=$GOARCH"
fi
if [[ ! "$BUILDDIR" ]]; then
export BUILDDIR="build"
echo "Defaulting to BUILDDIR=$BUILDDIR"
fi
build "$@"
}
build() {
local package="$1"; shift
local assets="$@"
local bin="$(basename $package)"
local tarball="${bin}-${GOOS}_${GOARCH}.tgz"
local outdir="$BUILDDIR/$bin"
local tardir="$bin"
if [ "$GOOS" == "windows" ]; then
bin="$bin.exe"
fi
if [[ -d "$outdir" ]]; then
echo "err: outdir already exists: $PWD/$outdir"
fi
mkdir -p "$outdir"
go build -ldflags "$LDFLAGS" -o "$outdir/$bin" "$package"
# Stage asset bundle
if [[ "$assets" ]]; then
ln -f $assets "$outdir"
fi
# Create tarball
tar -C "$BUILDDIR" -czvf "$BUILDDIR/$tarball" "$tardir"
# Cleanup
rm -rf "$outdir"
echo "Packaged: $tarball"
}
main "$@"

View File

@ -5,37 +5,35 @@ package chat
import ( import (
"errors" "errors"
"fmt" "fmt"
"sort"
"strings" "strings"
"time"
"github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/chat/message"
"github.com/shazow/ssh-chat/internal/sanitize"
"github.com/shazow/ssh-chat/set"
) )
// ErrInvalidCommand is the error returned when an invalid command is issued. // The error returned when an invalid command is issued.
var ErrInvalidCommand = errors.New("invalid command") var ErrInvalidCommand = errors.New("invalid command")
// ErrNoOwner is the error returned when a command is given without an owner. // The error returned when a command is given without an owner.
var ErrNoOwner = errors.New("command without owner") var ErrNoOwner = errors.New("command without owner")
// ErrMissingArg is the error returned when a command is performed without the necessary // The error returned when a command is performed without the necessary number
// number of arguments. // of arguments.
var ErrMissingArg = errors.New("missing argument") var ErrMissingArg = errors.New("missing argument")
// ErrMissingPrefix is the error returned when a command is added without a prefix. // The error returned when a command is added without a prefix.
var ErrMissingPrefix = errors.New("command missing prefix") var ErrMissingPrefix = errors.New("command missing prefix")
// Command is a definition of a handler for a command. // Command is a definition of a handler for a command.
type Command struct { type Command struct {
Prefix string // The command's key, such as /foo // The command's key, such as /foo
PrefixHelp string // Extra help regarding arguments Prefix string
Help string // help text, if omitted, command is hidden from /help // Extra help regarding arguments
Op bool // does the command require Op permissions? PrefixHelp string
// If omitted, command is hidden from /help
// Handler for the command Help string
Handler func(*Room, message.CommandMsg) error Handler func(*Room, message.CommandMsg) error
// Command requires Op permissions
Op bool
} }
// Commands is a registry of available commands. // Commands is a registry of available commands.
@ -95,10 +93,6 @@ func (c Commands) Help(showOp bool) string {
return help return help
} }
var timeformatDatetime = "2006-01-02 15:04:05"
var timeformatTime = "15:04"
var defaultCommands *Commands var defaultCommands *Commands
func init() { func init() {
@ -153,20 +147,16 @@ func InitCommands(c *Commands) {
} }
u := msg.From() u := msg.From()
member, ok := room.MemberByID(u.ID()) member, ok := room.MemberById(u.Id())
if !ok { if !ok {
return errors.New("failed to find member") return errors.New("failed to find member")
} }
oldID := member.ID() oldId := member.Id()
newID := sanitize.Name(args[0]) member.SetId(SanitizeName(args[0]))
if newID == oldID { err := room.Rename(oldId, member)
return errors.New("new name is the same as the original")
}
member.SetID(newID)
err := room.Rename(oldID, member)
if err != nil { if err != nil {
member.SetID(oldID) member.SetId(oldId)
return err return err
} }
return nil return nil
@ -177,26 +167,9 @@ func InitCommands(c *Commands) {
Prefix: "/names", Prefix: "/names",
Help: "List users who are connected.", Help: "List users who are connected.",
Handler: func(room *Room, msg message.CommandMsg) error { Handler: func(room *Room, msg message.CommandMsg) error {
theme := msg.From().Config().Theme // TODO: colorize
names := room.NamesPrefix("")
colorize := func(u *message.User) string { body := fmt.Sprintf("%d connected: %s", len(names), strings.Join(names, ", "))
return theme.ColorName(u)
}
if theme == nil {
colorize = func(u *message.User) string {
return u.Name()
}
}
names := room.Members.ListPrefix("")
sort.Slice(names, func(i, j int) bool { return names[i].Key() < names[j].Key() })
colNames := make([]string, len(names))
for i, uname := range names {
colNames[i] = colorize(uname.Value().(*Member).User)
}
body := fmt.Sprintf("%d connected: %s", len(colNames), strings.Join(colNames, ", "))
room.Send(message.NewSystemMsg(body, msg.From())) room.Send(message.NewSystemMsg(body, msg.From()))
return nil return nil
}, },
@ -205,36 +178,25 @@ func InitCommands(c *Commands) {
c.Add(Command{ c.Add(Command{
Prefix: "/theme", Prefix: "/theme",
PrefixHelp: "[colors|...]", PrefixHelp: "[mono|colors]",
Help: "Set your color theme.", Help: "Set your color theme.",
Handler: func(room *Room, msg message.CommandMsg) error { Handler: func(room *Room, msg message.CommandMsg) error {
user := msg.From() user := msg.From()
args := msg.Args() args := msg.Args()
cfg := user.Config()
if len(args) == 0 { if len(args) == 0 {
theme := "plain" theme := "plain"
if cfg.Theme != nil { if user.Config.Theme != nil {
theme = cfg.Theme.ID() theme = user.Config.Theme.Id()
} }
var output strings.Builder body := fmt.Sprintf("Current theme: %s", theme)
fmt.Fprintf(&output, "Current theme: %s%s", theme, message.Newline) room.Send(message.NewSystemMsg(body, user))
fmt.Fprintf(&output, " Themes available: ")
for i, t := range message.Themes {
output.WriteString(t.ID())
if i < len(message.Themes)-1 {
output.WriteString(", ")
}
}
room.Send(message.NewSystemMsg(output.String(), user))
return nil return nil
} }
id := args[0] id := args[0]
for _, t := range message.Themes { for _, t := range message.Themes {
if t.ID() == id { if t.Id() == id {
cfg.Theme = &t user.Config.Theme = &t
user.SetConfig(cfg)
body := fmt.Sprintf("Set theme: %s", id) body := fmt.Sprintf("Set theme: %s", id)
room.Send(message.NewSystemMsg(body, user)) room.Send(message.NewSystemMsg(body, user))
return nil return nil
@ -249,12 +211,10 @@ func InitCommands(c *Commands) {
Help: "Silence room announcements.", Help: "Silence room announcements.",
Handler: func(room *Room, msg message.CommandMsg) error { Handler: func(room *Room, msg message.CommandMsg) error {
u := msg.From() u := msg.From()
cfg := u.Config() u.ToggleQuietMode()
cfg.Quiet = !cfg.Quiet
u.SetConfig(cfg)
var body string var body string
if cfg.Quiet { if u.Config.Quiet {
body = "Quiet mode is toggled ON" body = "Quiet mode is toggled ON"
} else { } else {
body = "Quiet mode is toggled OFF" body = "Quiet mode is toggled OFF"
@ -280,250 +240,4 @@ func InitCommands(c *Commands) {
return nil return nil
}, },
}) })
c.Add(Command{
Prefix: "/shrug",
Handler: func(room *Room, msg message.CommandMsg) error {
room.Send(message.NewEmoteMsg(`¯\_(ツ)_/¯`, msg.From()))
return nil
},
})
c.Add(Command{
Prefix: "/timestamp",
PrefixHelp: "[time|datetime]",
Help: "Prefix messages with a timestamp. You can also provide the UTC offset: /timestamp time +5h45m",
Handler: func(room *Room, msg message.CommandMsg) error {
u := msg.From()
cfg := u.Config()
args := msg.Args()
mode := ""
if len(args) >= 1 {
mode = args[0]
}
if len(args) >= 2 {
// FIXME: This is an annoying format to demand from users, but
// hopefully we can make it a non-primary flow if we add GeoIP
// someday.
offset, err := time.ParseDuration(args[1])
if err != nil {
return err
}
cfg.Timezone = time.FixedZone("", int(offset.Seconds()))
}
switch mode {
case "time":
cfg.Timeformat = &timeformatTime
case "datetime":
cfg.Timeformat = &timeformatDatetime
case "":
// Toggle
if cfg.Timeformat != nil {
cfg.Timeformat = nil
} else {
cfg.Timeformat = &timeformatTime
}
case "off":
cfg.Timeformat = nil
default:
return errors.New("timestamp value must be one of: time, datetime, off")
}
u.SetConfig(cfg)
var body string
if cfg.Timeformat != nil {
if cfg.Timezone != nil {
tzname := time.Now().In(cfg.Timezone).Format("MST")
body = fmt.Sprintf("Timestamp is toggled ON, timezone is %q", tzname)
} else {
body = "Timestamp is toggled ON, timezone is UTC"
}
} else {
body = "Timestamp is toggled OFF"
}
room.Send(message.NewSystemMsg(body, u))
return nil
},
})
c.Add(Command{
Prefix: "/ignore",
PrefixHelp: "[USER]",
Help: "Hide messages from USER, /unignore USER to stop hiding.",
Handler: func(room *Room, msg message.CommandMsg) error {
id := strings.TrimSpace(strings.TrimLeft(msg.Body(), "/ignore"))
if id == "" {
// Print ignored names, if any.
var names []string
msg.From().Ignored.Each(func(_ string, item set.Item) error {
names = append(names, item.Key())
return nil
})
var systemMsg string
if len(names) == 0 {
systemMsg = "0 users ignored."
} else {
systemMsg = fmt.Sprintf("%d ignored: %s", len(names), strings.Join(names, ", "))
}
room.Send(message.NewSystemMsg(systemMsg, msg.From()))
return nil
}
if id == msg.From().ID() {
return errors.New("cannot ignore self")
}
target, ok := room.MemberByID(id)
if !ok {
return fmt.Errorf("user not found: %s", id)
}
err := msg.From().Ignored.Add(set.Itemize(id, target))
if err == set.ErrCollision {
return fmt.Errorf("user already ignored: %s", id)
} else if err != nil {
return err
}
room.Send(message.NewSystemMsg(fmt.Sprintf("Ignoring: %s", target.Name()), msg.From()))
return nil
},
})
c.Add(Command{
Prefix: "/unignore",
PrefixHelp: "USER",
Handler: func(room *Room, msg message.CommandMsg) error {
id := strings.TrimSpace(strings.TrimLeft(msg.Body(), "/unignore"))
if id == "" {
return errors.New("must specify user")
}
if err := msg.From().Ignored.Remove(id); err != nil {
return err
}
room.Send(message.NewSystemMsg(fmt.Sprintf("No longer ignoring: %s", id), msg.From()))
return nil
},
})
c.Add(Command{
Prefix: "/focus",
PrefixHelp: "[USER ...]",
Help: "Only show messages from focused users, or $ to reset.",
Handler: func(room *Room, msg message.CommandMsg) error {
ids := strings.TrimSpace(strings.TrimLeft(msg.Body(), "/focus"))
if ids == "" {
// Print focused names, if any.
var names []string
msg.From().Focused.Each(func(_ string, item set.Item) error {
names = append(names, item.Key())
return nil
})
var systemMsg string
if len(names) == 0 {
systemMsg = "Unfocused."
} else {
systemMsg = fmt.Sprintf("Focusing on %d users: %s", len(names), strings.Join(names, ", "))
}
room.Send(message.NewSystemMsg(systemMsg, msg.From()))
return nil
}
n := msg.From().Focused.Clear()
if ids == "$" {
room.Send(message.NewSystemMsg(fmt.Sprintf("Removed focus from %d users.", n), msg.From()))
return nil
}
var focused []string
for _, name := range strings.Split(ids, " ") {
id := sanitize.Name(name)
if id == "" {
continue // Skip
}
focused = append(focused, id)
if err := msg.From().Focused.Set(set.Itemize(id, set.ZeroValue)); err != nil {
return err
}
}
room.Send(message.NewSystemMsg(fmt.Sprintf("Focusing: %s", strings.Join(focused, ", ")), msg.From()))
return nil
},
})
c.Add(Command{
Prefix: "/away",
PrefixHelp: "[REASON]",
Help: "Set away reason, or empty to unset.",
Handler: func(room *Room, msg message.CommandMsg) error {
awayMsg := strings.TrimSpace(strings.TrimLeft(msg.Body(), "/away"))
isAway, _, _ := msg.From().GetAway()
msg.From().SetAway(awayMsg)
if awayMsg != "" {
room.Send(message.NewEmoteMsg("has gone away: "+awayMsg, msg.From()))
return nil
}
if isAway {
room.Send(message.NewEmoteMsg("is back.", msg.From()))
return nil
}
return errors.New("not away. Append a reason message to set away")
},
})
c.Add(Command{
Prefix: "/back",
Help: "Clear away status.",
Handler: func(room *Room, msg message.CommandMsg) error {
isAway, _, _ := msg.From().GetAway()
if isAway {
msg.From().SetAway("")
room.Send(message.NewEmoteMsg("is back.", msg.From()))
return nil
}
return errors.New("must be away to be back")
},
})
c.Add(Command{
Op: true,
Prefix: "/mute",
PrefixHelp: "USER",
Help: "Toggle muting USER, preventing messages from broadcasting.",
Handler: func(room *Room, msg message.CommandMsg) error {
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
args := msg.Args()
if len(args) == 0 {
return errors.New("must specify user")
}
member, ok := room.MemberByID(args[0])
if !ok {
return errors.New("user not found")
}
setMute := !member.IsMuted()
member.SetMute(setMute)
id := member.ID()
if setMute {
room.Send(message.NewSystemMsg("Muted: "+id, msg.From()))
} else {
room.Send(message.NewSystemMsg("Unmuted: "+id, msg.From()))
}
return nil
},
})
} }

View File

@ -1,70 +0,0 @@
package chat
import (
"fmt"
"testing"
"github.com/shazow/ssh-chat/chat/message"
)
func TestAwayCommands(t *testing.T) {
cmds := &Commands{}
InitCommands(cmds)
room := NewRoom()
go room.Serve()
defer room.Close()
// steps are order dependent
// User can be "away" or "not away" using 3 commands "/away [msg]", "/away", "/back"
// 2^3 possible cases, run all and verify state at the end
type step struct {
// input
Msg string
// expected output
IsUserAway bool
AwayMessage string
// expected state change
ExpectsError func(awayBefore bool) bool
}
neverError := func(_ bool) bool { return false }
// if the user was away before, then the error is expected
errorIfAwayBefore := func(awayBefore bool) bool { return awayBefore }
awayStep := step{"/away snorkling", true, "snorkling", neverError}
notAwayStep := step{"/away", false, "", errorIfAwayBefore}
backStep := step{"/back", false, "", errorIfAwayBefore}
steps := []step{awayStep, notAwayStep, backStep}
cases := [][]int{
{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 0, 1}, {2, 1, 0},
}
for _, c := range cases {
t.Run(fmt.Sprintf("Case: %d, %d, %d", c[0], c[1], c[2]), func(t *testing.T) {
u := message.NewUser(message.SimpleID("shark"))
for _, s := range []step{steps[c[0]], steps[c[1]], steps[c[2]]} {
msg, _ := message.NewPublicMsg(s.Msg, u).ParseCommand()
awayBeforeCommand, _, _ := u.GetAway()
err := cmds.Run(room, *msg)
if err != nil && s.ExpectsError(awayBeforeCommand) {
t.Fatalf("unexpected error running the command: %+v", err)
}
isAway, _, awayMsg := u.GetAway()
if isAway != s.IsUserAway {
t.Fatalf("expected user away state '%t' not equals to actual '%t' after message '%s'", s.IsUserAway, isAway, s.Msg)
}
if awayMsg != s.AwayMessage {
t.Fatalf("expected user away message '%s' not equal to actual '%s' after message '%s'", s.AwayMessage, awayMsg, s.Msg)
}
}
})
}
}

View File

@ -1,6 +1,7 @@
/* /*
`chat` package is a server-agnostic implementation of a chat interface, built `chat` package is a server-agnostic implementation of a chat interface, built
to be used as the backend for ssh-chat. with the intention of using with the intention of using as the backend for
ssh-chat.
This package should not know anything about sockets. It should expose io-style This package should not know anything about sockets. It should expose io-style
interfaces and rooms for communicating with any method of transnport. interfaces and rooms for communicating with any method of transnport.

View File

@ -1,13 +1,10 @@
package chat package chat
import ( import "io"
"io" import stdlog "log"
stdlog "log"
)
var logger *stdlog.Logger var logger *stdlog.Logger
// SetLogger changes the logger used for logging inside the package
func SetLogger(w io.Writer) { func SetLogger(w io.Writer) {
flags := stdlog.Flags() flags := stdlog.Flags()
prefix := "[chat] " prefix := "[chat] "

View File

@ -2,25 +2,25 @@ package message
// Identifier is an interface that can uniquely identify itself. // Identifier is an interface that can uniquely identify itself.
type Identifier interface { type Identifier interface {
ID() string Id() string
SetID(string) SetId(string)
Name() string Name() string
} }
// SimpleID is a simple Identifier implementation used for testing. // SimpleId is a simple Identifier implementation used for testing.
type SimpleID string type SimpleId string
// ID returns the ID as a string. // Id returns the Id as a string.
func (i SimpleID) ID() string { func (i SimpleId) Id() string {
return string(i) return string(i)
} }
// SetID is a no-op // SetId is a no-op
func (i SimpleID) SetID(s string) { func (i SimpleId) SetId(s string) {
// no-op // no-op
} }
// Name returns the ID // Name returns the Id
func (i SimpleID) Name() string { func (i SimpleId) Name() string {
return i.ID() return i.Id()
} }

View File

@ -48,21 +48,21 @@ func NewMsg(body string) *Msg {
} }
// Render message based on a theme. // Render message based on a theme.
func (m Msg) Render(t *Theme) string { func (m *Msg) Render(t *Theme) string {
// TODO: Render based on theme // TODO: Render based on theme
// TODO: Cache based on theme // TODO: Cache based on theme
return m.String() return m.String()
} }
func (m Msg) String() string { func (m *Msg) String() string {
return m.body return m.body
} }
func (m Msg) Command() string { func (m *Msg) Command() string {
return "" return ""
} }
func (m Msg) Timestamp() time.Time { func (m *Msg) Timestamp() time.Time {
return m.timestamp return m.timestamp
} }
@ -72,8 +72,8 @@ type PublicMsg struct {
from *User from *User
} }
func NewPublicMsg(body string, from *User) PublicMsg { func NewPublicMsg(body string, from *User) *PublicMsg {
return PublicMsg{ return &PublicMsg{
Msg: Msg{ Msg: Msg{
body: body, body: body,
timestamp: time.Now(), timestamp: time.Now(),
@ -82,11 +82,11 @@ func NewPublicMsg(body string, from *User) PublicMsg {
} }
} }
func (m PublicMsg) From() *User { func (m *PublicMsg) From() *User {
return m.from return m.from
} }
func (m PublicMsg) ParseCommand() (*CommandMsg, bool) { func (m *PublicMsg) ParseCommand() (*CommandMsg, bool) {
// Check if the message is a command // Check if the message is a command
if !strings.HasPrefix(m.body, "/") { if !strings.HasPrefix(m.body, "/") {
return nil, false return nil, false
@ -104,7 +104,7 @@ func (m PublicMsg) ParseCommand() (*CommandMsg, bool) {
return &msg, true return &msg, true
} }
func (m PublicMsg) Render(t *Theme) string { func (m *PublicMsg) Render(t *Theme) string {
if t == nil { if t == nil {
return m.String() return m.String()
} }
@ -112,8 +112,7 @@ func (m PublicMsg) Render(t *Theme) string {
return fmt.Sprintf("%s: %s", t.ColorName(m.from), m.body) return fmt.Sprintf("%s: %s", t.ColorName(m.from), m.body)
} }
// RenderFor renders the message for other users to see. func (m *PublicMsg) RenderFor(cfg UserConfig) string {
func (m PublicMsg) RenderFor(cfg UserConfig) string {
if cfg.Highlight == nil || cfg.Theme == nil { if cfg.Highlight == nil || cfg.Theme == nil {
return m.Render(cfg.Theme) return m.Render(cfg.Theme)
} }
@ -129,19 +128,13 @@ func (m PublicMsg) RenderFor(cfg UserConfig) string {
return fmt.Sprintf("%s: %s", cfg.Theme.ColorName(m.from), body) return fmt.Sprintf("%s: %s", cfg.Theme.ColorName(m.from), body)
} }
// RenderSelf renders the message for when it's echoing your own message. func (m *PublicMsg) String() string {
func (m PublicMsg) RenderSelf(cfg UserConfig) string {
if cfg.Theme == nil {
return fmt.Sprintf("[%s] %s", m.from.Name(), m.body)
}
return fmt.Sprintf("[%s] %s", cfg.Theme.ColorName(m.from), m.body)
}
func (m PublicMsg) String() string {
return fmt.Sprintf("%s: %s", m.from.Name(), m.body) return fmt.Sprintf("%s: %s", m.from.Name(), m.body)
} }
// EmoteMsg is a /me message sent to the room. // EmoteMsg is a /me message sent to the room. It specifically does not
// extend PublicMsg because it doesn't implement MessageFrom to allow the
// sender to see the emote.
type EmoteMsg struct { type EmoteMsg struct {
Msg Msg
from *User from *User
@ -157,15 +150,11 @@ func NewEmoteMsg(body string, from *User) *EmoteMsg {
} }
} }
func (m EmoteMsg) From() *User { func (m *EmoteMsg) Render(t *Theme) string {
return m.from
}
func (m EmoteMsg) Render(t *Theme) string {
return fmt.Sprintf("** %s %s", m.from.Name(), m.body) return fmt.Sprintf("** %s %s", m.from.Name(), m.body)
} }
func (m EmoteMsg) String() string { func (m *EmoteMsg) String() string {
return m.Render(nil) return m.Render(nil)
} }
@ -175,31 +164,22 @@ type PrivateMsg struct {
to *User to *User
} }
func NewPrivateMsg(body string, from *User, to *User) PrivateMsg { func NewPrivateMsg(body string, from *User, to *User) *PrivateMsg {
return PrivateMsg{ return &PrivateMsg{
PublicMsg: NewPublicMsg(body, from), PublicMsg: *NewPublicMsg(body, from),
to: to, to: to,
} }
} }
func (m PrivateMsg) To() *User { func (m *PrivateMsg) To() *User {
return m.to return m.to
} }
func (m PrivateMsg) From() *User { func (m *PrivateMsg) Render(t *Theme) string {
return m.from return fmt.Sprintf("[PM from %s] %s", m.from.Name(), m.body)
} }
func (m PrivateMsg) Render(t *Theme) string { func (m *PrivateMsg) String() string {
format := "[PM from %s] %s"
if t == nil {
return fmt.Sprintf(format, m.from.ID(), m.body)
}
s := fmt.Sprintf(format, m.from.Name(), m.body)
return t.ColorPM(s)
}
func (m PrivateMsg) String() string {
return m.Render(nil) return m.Render(nil)
} }
@ -250,31 +230,31 @@ func NewAnnounceMsg(body string) *AnnounceMsg {
} }
} }
func (m AnnounceMsg) Render(t *Theme) string { func (m *AnnounceMsg) Render(t *Theme) string {
if t == nil { if t == nil {
return m.String() return m.String()
} }
return t.ColorSys(m.String()) return t.ColorSys(m.String())
} }
func (m AnnounceMsg) String() string { func (m *AnnounceMsg) String() string {
return fmt.Sprintf(" * %s", m.body) return fmt.Sprintf(" * %s", m.body)
} }
type CommandMsg struct { type CommandMsg struct {
PublicMsg *PublicMsg
command string command string
args []string args []string
} }
func (m CommandMsg) Command() string { func (m *CommandMsg) Command() string {
return m.command return m.command
} }
func (m CommandMsg) Args() []string { func (m *CommandMsg) Args() []string {
return m.args return m.args
} }
func (m CommandMsg) Body() string { func (m *CommandMsg) Body() string {
return m.body return m.body
} }

View File

@ -11,7 +11,7 @@ func TestMessage(t *testing.T) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
u := NewUser(SimpleID("foo")) u := NewUser(SimpleId("foo"))
expected = "foo: hello" expected = "foo: hello"
actual = NewPublicMsg("hello", u).String() actual = NewPublicMsg("hello", u).String()
if actual != expected { if actual != expected {

View File

@ -21,10 +21,6 @@ func (s *MockScreen) Read(p *[]byte) (n int, err error) {
return len(*p), nil return len(*p), nil
} }
func (s *MockScreen) Close() error {
return nil
}
func TestScreen(t *testing.T) { func TestScreen(t *testing.T) {
var actual, expected []byte var actual, expected []byte

View File

@ -1,8 +1,6 @@
package message package message
import ( import "fmt"
"fmt"
)
const ( const (
// Reset resets the color // Reset resets the color
@ -64,18 +62,6 @@ func (c Color256) Format(s string) string {
return "\033[" + c.String() + "m" + s + Reset return "\033[" + c.String() + "m" + s + Reset
} }
func Color256Palette(colors ...uint8) *Palette {
size := len(colors)
p := make([]Style, 0, size)
for _, color := range colors {
p = append(p, Color256(color))
}
return &Palette{
colors: p,
size: size,
}
}
// No color, used for mono theme // No color, used for mono theme
type Color0 struct{} type Color0 struct{}
@ -97,9 +83,6 @@ type Palette struct {
// Get a color by index, overflows are looped around. // Get a color by index, overflows are looped around.
func (p Palette) Get(i int) Style { func (p Palette) Get(i int) Style {
if p.size == 1 {
return p.colors[0]
}
return p.colors[i%(p.size-1)] return p.colors[i%(p.size-1)]
} }
@ -122,60 +105,45 @@ type Theme struct {
pm Style pm Style
highlight Style highlight Style
names *Palette names *Palette
useID bool
} }
func (theme Theme) ID() string { func (t Theme) Id() string {
return theme.id return t.id
} }
// Colorize name string given some index // Colorize name string given some index
func (theme Theme) ColorName(u *User) string { func (t Theme) ColorName(u *User) string {
var name string if t.names == nil {
if theme.useID { return u.Name()
name = u.ID()
} else {
name = u.Name()
}
if theme.names == nil {
return name
} }
return theme.names.Get(u.colorIdx).Format(name) return t.names.Get(u.colorIdx).Format(u.Name())
} }
// Colorize the PM string // Colorize the PM string
func (theme Theme) ColorPM(s string) string { func (t Theme) ColorPM(s string) string {
if theme.pm == nil { if t.pm == nil {
return s return s
} }
return theme.pm.Format(s) return t.pm.Format(s)
} }
// Colorize the Sys message // Colorize the Sys message
func (theme Theme) ColorSys(s string) string { func (t Theme) ColorSys(s string) string {
if theme.sys == nil { if t.sys == nil {
return s return s
} }
return theme.sys.Format(s) return t.sys.Format(s)
} }
// Highlight a matched string, usually name // Highlight a matched string, usually name
func (theme Theme) Highlight(s string) string { func (t Theme) Highlight(s string) string {
if theme.highlight == nil { if t.highlight == nil {
return s return s
} }
return theme.highlight.Format(s) return t.highlight.Format(s)
}
// Timestamp colorizes the timestamp.
func (theme Theme) Timestamp(s string) string {
if theme.sys == nil {
return s
}
return theme.sys.Format(s)
} }
// List of initialzied themes // List of initialzied themes
@ -184,92 +152,44 @@ var Themes []Theme
// Default theme to use // Default theme to use
var DefaultTheme *Theme var DefaultTheme *Theme
// MonoTheme is a simple theme without colors, useful for testing and bots.
var MonoTheme *Theme
func allColors256() *Palette {
colors := []uint8{}
var i uint8
for i = 0; i < 255; i++ {
colors = append(colors, i)
}
return Color256Palette(colors...)
}
func readableColors256() *Palette { func readableColors256() *Palette {
colors := []uint8{} size := 247
var i uint8 p := Palette{
for i = 0; i < 255; i++ { colors: make([]Style, size),
if i == 0 || i == 7 || i == 8 || i == 15 || i == 16 || i == 17 || i > 230 { size: size,
// Skip 31 Shades of Grey, and one hyperintelligent shade of blue. }
j := 0
for i := 0; i < 256; i++ {
if (16 <= i && i <= 18) || (232 <= i && i <= 237) {
// Remove the ones near black, this is kinda sadpanda.
continue continue
} }
colors = append(colors, i) p.colors[j] = Color256(i)
j++
} }
return Color256Palette(colors...) return &p
} }
func init() { func init() {
palette := readableColors256()
Themes = []Theme{ Themes = []Theme{
{ Theme{
id: "colors", id: "colors",
names: readableColors256(), names: palette,
sys: Color256(245), // Grey sys: palette.Get(8), // Grey
pm: Color256(7), // White pm: palette.Get(7), // White
highlight: style(Bold + "\033[48;5;11m\033[38;5;16m"), // Yellow highlight highlight: style(Bold + "\033[48;5;11m\033[38;5;16m"), // Yellow highlight
}, },
{ Theme{
id: "solarized", id: "mono",
names: Color256Palette(1, 2, 3, 4, 5, 6, 7, 9, 13),
sys: Color256(11), // Yellow
pm: Color256(15), // White
highlight: style(Bold + "\033[48;5;3m\033[38;5;94m"), // Orange highlight
},
{
id: "hacker",
names: Color256Palette(82), // Green
sys: Color256(22), // Another green
pm: Color256(28), // More green, slightly lighter
highlight: style(Bold + "\033[48;5;22m\033[38;5;46m"), // Green on dark green
},
{
id: "mono",
useID: true,
}, },
} }
// Debug for printing colors:
//for _, color := range palette.colors {
// fmt.Print(color.Format(color.String() + " "))
//}
DefaultTheme = &Themes[0] DefaultTheme = &Themes[0]
MonoTheme = &Themes[3]
/* Some debug helpers for your convenience:
// Debug for palettes
printPalette(allColors256())
// Debug for themes
for _, t := range Themes {
printTheme(t)
}
*/
}
func printTheme(t Theme) {
fmt.Println("Printing theme:", t.ID())
if t.names != nil {
for i, color := range t.names.colors {
fmt.Printf("%s ", color.Format(fmt.Sprintf("name%d", i)))
}
fmt.Println("")
}
fmt.Println(t.ColorSys("SystemMsg"))
fmt.Println(t.ColorPM("PrivateMsg"))
fmt.Println(t.Highlight("Highlight"))
fmt.Println("")
}
func printPalette(p *Palette) {
for i, color := range p.colors {
fmt.Printf("%d\t%s\n", i, color.Format(color.String()+" "))
}
} }

View File

@ -1,6 +1,9 @@
package message package message
import "testing" import (
"fmt"
"testing"
)
func TestThemePalette(t *testing.T) { func TestThemePalette(t *testing.T) {
var expected, actual string var expected, actual string
@ -12,21 +15,21 @@ func TestThemePalette(t *testing.T) {
} }
actual = color.String() actual = color.String()
expected = "38;05;6" expected = "38;05;5"
if actual != expected { if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
actual = color.Format("foo") actual = color.Format("foo")
expected = "\033[38;05;6mfoo\033[0m" expected = "\033[38;05;5mfoo\033[0m"
if actual != expected { if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
actual = palette.Get(palette.Len() + 1).String() actual = palette.Get(palette.Len() + 1).String()
expected = "38;05;3" expected = fmt.Sprintf("38;05;%d", 2)
if actual != expected { if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
} }
@ -41,28 +44,28 @@ func TestTheme(t *testing.T) {
} }
actual = color.Format("foo") actual = color.Format("foo")
expected = "\033[38;05;245mfoo\033[0m" expected = "\033[38;05;8mfoo\033[0m"
if actual != expected { if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
actual = colorTheme.ColorSys("foo") actual = colorTheme.ColorSys("foo")
if actual != expected { if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
u := NewUser(SimpleID("foo")) u := NewUser(SimpleId("foo"))
u.colorIdx = 4 u.colorIdx = 4
actual = colorTheme.ColorName(u) actual = colorTheme.ColorName(u)
expected = "\033[38;05;5mfoo\033[0m" expected = "\033[38;05;4mfoo\033[0m"
if actual != expected { if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
msg := NewPublicMsg("hello", u) msg := NewPublicMsg("hello", u)
actual = msg.Render(&colorTheme) actual = msg.Render(&colorTheme)
expected = "\033[38;05;5mfoo\033[0m: hello" expected = "\033[38;05;4mfoo\033[0m: hello"
if actual != expected { if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
} }

View File

@ -8,180 +8,99 @@ import (
"regexp" "regexp"
"sync" "sync"
"time" "time"
"github.com/shazow/ssh-chat/set"
) )
const messageBuffer = 5 const messageBuffer = 20
const messageTimeout = 5 * time.Second
const reHighlight = `\b(%s)\b` const reHighlight = `\b(%s)\b`
const timestampTimeout = 30 * time.Minute
var ErrUserClosed = errors.New("user closed") var ErrUserClosed = errors.New("user closed")
// User definition, implemented set Item interface and io.Writer // User definition, implemented set Item interface and io.Writer
type User struct { type User struct {
Identifier Identifier
OnChange func() Config UserConfig
Ignored set.Interface colorIdx int
Focused set.Interface joined time.Time
colorIdx int msg chan Message
joined time.Time done chan struct{}
msg chan Message replyTo *User // Set when user gets a /msg, for replying.
done chan struct{} closed bool
screen io.WriteCloser
closeOnce sync.Once closeOnce sync.Once
mu sync.Mutex
config UserConfig
replyTo *User // Set when user gets a /msg, for replying.
lastMsg time.Time // When the last message was rendered.
awayReason string // Away reason, "" when not away.
awaySince time.Time // When away was set, 0 when not away.
} }
func NewUser(identity Identifier) *User { func NewUser(identity Identifier) *User {
u := User{ u := User{
Identifier: identity, Identifier: identity,
config: DefaultUserConfig, Config: *DefaultUserConfig,
joined: time.Now(), joined: time.Now(),
msg: make(chan Message, messageBuffer), msg: make(chan Message, messageBuffer),
done: make(chan struct{}), done: make(chan struct{}, 1),
Ignored: set.New(),
Focused: set.New(),
} }
u.setColorIdx(rand.Int()) u.SetColorIdx(rand.Int())
return &u return &u
} }
func NewUserScreen(identity Identifier, screen io.WriteCloser) *User { func NewUserScreen(identity Identifier, screen io.Writer) *User {
u := NewUser(identity) u := NewUser(identity)
u.screen = screen go u.Consume(screen)
return u return u
} }
func (u *User) Joined() time.Time {
return u.joined
}
func (u *User) LastMsg() time.Time {
u.mu.Lock()
defer u.mu.Unlock()
return u.lastMsg
}
// SetAway sets the users away reason and state.
func (u *User) SetAway(msg string) {
u.mu.Lock()
defer u.mu.Unlock()
u.awayReason = msg
if msg == "" {
u.awaySince = time.Time{}
} else {
// Reset away timer even if already away
u.awaySince = time.Now()
}
}
// GetAway returns if the user is away, when they went away, and the reason.
func (u *User) GetAway() (bool, time.Time, string) {
u.mu.Lock()
defer u.mu.Unlock()
return u.awayReason != "", u.awaySince, u.awayReason
}
func (u *User) Config() UserConfig {
u.mu.Lock()
defer u.mu.Unlock()
return u.config
}
func (u *User) SetConfig(cfg UserConfig) {
u.mu.Lock()
u.config = cfg
u.mu.Unlock()
if u.OnChange != nil {
u.OnChange()
}
}
// Rename the user with a new Identifier. // Rename the user with a new Identifier.
func (u *User) SetID(id string) { func (u *User) SetId(id string) {
u.Identifier.SetID(id) u.Identifier.SetId(id)
u.setColorIdx(rand.Int()) u.SetColorIdx(rand.Int())
if u.OnChange != nil {
u.OnChange()
}
} }
// ReplyTo returns the last user that messaged this user. // ReplyTo returns the last user that messaged this user.
func (u *User) ReplyTo() *User { func (u *User) ReplyTo() *User {
u.mu.Lock()
defer u.mu.Unlock()
return u.replyTo return u.replyTo
} }
// SetReplyTo sets the last user to message this user. // SetReplyTo sets the last user to message this user.
func (u *User) SetReplyTo(user *User) { func (u *User) SetReplyTo(user *User) {
u.mu.Lock()
defer u.mu.Unlock()
u.replyTo = user u.replyTo = user
} }
// setColorIdx will set the colorIdx to a specific value, primarily used for // ToggleQuietMode will toggle whether or not quiet mode is enabled
func (u *User) ToggleQuietMode() {
u.Config.Quiet = !u.Config.Quiet
}
// SetColorIdx will set the colorIdx to a specific value, primarily used for
// testing. // testing.
func (u *User) setColorIdx(idx int) { func (u *User) SetColorIdx(idx int) {
u.colorIdx = idx u.colorIdx = idx
} }
// Block until user is closed
func (u *User) Wait() {
<-u.done
}
// Disconnect user, stop accepting messages // Disconnect user, stop accepting messages
func (u *User) Close() { func (u *User) Close() {
u.closeOnce.Do(func() { u.closeOnce.Do(func() {
if u.screen != nil { u.closed = true
if err := u.screen.Close(); err != nil {
logger.Printf("Failed to close user %q screen: %s", u.ID(), err)
}
}
// close(u.msg) TODO: Close?
close(u.done) close(u.done)
close(u.msg)
}) })
} }
// Consume message buffer into the handler. Will block, should be called in a // Consume message buffer into an io.Writer. Will block, should be called in a
// goroutine. // goroutine.
func (u *User) Consume() { // TODO: Not sure if this is a great API.
for { func (u *User) Consume(out io.Writer) {
select { for m := range u.msg {
case <-u.done: u.HandleMsg(m, out)
return
case m, ok := <-u.msg:
if !ok {
return
}
u.HandleMsg(m)
}
} }
} }
// Consume one message and stop, mostly for testing // Consume one message and stop, mostly for testing
func (u *User) ConsumeOne() Message { func (u *User) ConsumeChan() <-chan Message {
return <-u.msg return u.msg
}
// Check if there are pending messages, used for testing
func (u *User) HasMessages() bool {
select {
case msg := <-u.msg:
u.msg <- msg
return true
default:
return false
}
} }
// SetHighlight sets the highlighting regular expression to match string. // SetHighlight sets the highlighting regular expression to match string.
@ -190,79 +109,41 @@ func (u *User) SetHighlight(s string) error {
if err != nil { if err != nil {
return err return err
} }
u.mu.Lock() u.Config.Highlight = re
u.config.Highlight = re
u.mu.Unlock()
return nil return nil
} }
func (u *User) render(m Message) string { func (u *User) render(m Message) string {
cfg := u.Config()
var out string
switch m := m.(type) { switch m := m.(type) {
case PublicMsg: case *PublicMsg:
if u == m.From() { return m.RenderFor(u.Config) + Newline
u.mu.Lock()
u.lastMsg = m.Timestamp()
u.mu.Unlock()
if !cfg.Echo {
return ""
}
out += m.RenderSelf(cfg)
} else if u.Focused.Len() > 0 && !u.Focused.In(m.From().ID()) {
// Skip message during focus
return ""
} else {
out += m.RenderFor(cfg)
}
case *PrivateMsg: case *PrivateMsg:
out += m.Render(cfg.Theme) u.SetReplyTo(m.From())
if cfg.Bell { return m.Render(u.Config.Theme) + Newline
out += Bel
}
case *CommandMsg:
out += m.RenderSelf(cfg)
default: default:
out += m.Render(cfg.Theme) return m.Render(u.Config.Theme) + Newline
} }
if cfg.Timeformat != nil {
ts := m.Timestamp()
if cfg.Timezone != nil {
ts = ts.In(cfg.Timezone)
} else {
ts = ts.UTC()
}
return cfg.Theme.Timestamp(ts.Format(*cfg.Timeformat)) + " " + out + Newline
}
return out + Newline
} }
// writeMsg renders the message and attempts to write it, will Close the user func (u *User) HandleMsg(m Message, out io.Writer) {
// if it fails.
func (u *User) writeMsg(m Message) error {
r := u.render(m) r := u.render(m)
_, err := u.screen.Write([]byte(r)) _, err := out.Write([]byte(r))
if err != nil { if err != nil {
logger.Printf("Write failed to %s, closing: %s", u.ID(), err) logger.Printf("Write failed to %s, closing: %s", u.Name(), err)
u.Close() u.Close()
} }
return err
}
// HandleMsg will render the message to the screen, blocking.
func (u *User) HandleMsg(m Message) error {
return u.writeMsg(m)
} }
// Add message to consume by user // Add message to consume by user
func (u *User) Send(m Message) error { func (u *User) Send(m Message) error {
select { if u.closed {
case <-u.done:
return ErrUserClosed return ErrUserClosed
}
select {
case u.msg <- m: case u.msg <- m:
case <-time.After(messageTimeout): default:
logger.Printf("Message buffer full, closing: %s", u.ID()) logger.Printf("Msg buffer full, closing: %s", u.Name())
u.Close() u.Close()
return ErrUserClosed return ErrUserClosed
} }
@ -271,50 +152,20 @@ func (u *User) Send(m Message) error {
// Container for per-user configurations. // Container for per-user configurations.
type UserConfig struct { type UserConfig struct {
Highlight *regexp.Regexp Highlight *regexp.Regexp
Bell bool Bell bool
Quiet bool Quiet bool
Echo bool // Echo shows your own messages after sending, disabled for bots Theme *Theme
Timeformat *string
Timezone *time.Location
Theme *Theme
} }
// Default user configuration to use // Default user configuration to use
var DefaultUserConfig UserConfig var DefaultUserConfig *UserConfig
func init() { func init() {
DefaultUserConfig = UserConfig{ DefaultUserConfig = &UserConfig{
Bell: true, Bell: true,
Echo: true,
Quiet: false, Quiet: false,
} }
// TODO: Seed random? // TODO: Seed random?
} }
// RecentActiveUsers is a slice of *Users that knows how to be sorted by the
// time of the last message. If no message has been sent, then fall back to the
// time joined instead.
type RecentActiveUsers []*User
func (a RecentActiveUsers) Len() int { return len(a) }
func (a RecentActiveUsers) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a RecentActiveUsers) Less(i, j int) bool {
a[i].mu.Lock()
defer a[i].mu.Unlock()
a[j].mu.Lock()
defer a[j].mu.Unlock()
ai := a[i].lastMsg
if ai.IsZero() {
ai = a[i].joined
}
aj := a[j].lastMsg
if aj.IsZero() {
aj = a[j].joined
}
return ai.After(aj)
}

View File

@ -1,7 +1,6 @@
package message package message
import ( import (
"math/rand"
"reflect" "reflect"
"testing" "testing"
) )
@ -10,17 +9,12 @@ func TestMakeUser(t *testing.T) {
var actual, expected []byte var actual, expected []byte
s := &MockScreen{} s := &MockScreen{}
u := NewUserScreen(SimpleID("foo"), s) u := NewUser(SimpleId("foo"))
cfg := u.Config()
cfg.Theme = MonoTheme // Mono
u.SetConfig(cfg)
m := NewAnnounceMsg("hello") m := NewAnnounceMsg("hello")
defer u.Close() defer u.Close()
u.Send(m) u.Send(m)
u.HandleMsg(u.ConsumeOne()) u.HandleMsg(<-u.ConsumeChan(), s)
s.Read(&actual) s.Read(&actual)
expected = []byte(m.String() + Newline) expected = []byte(m.String() + Newline)
@ -28,34 +22,3 @@ func TestMakeUser(t *testing.T) {
t.Errorf("Got: `%s`; Expected: `%s`", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
} }
func TestRenderTimestamp(t *testing.T) {
var actual, expected []byte
// Reset seed for username color
rand.Seed(1)
s := &MockScreen{}
u := NewUserScreen(SimpleID("foo"), s)
cfg := u.Config()
timefmt := "AA:BB"
cfg.Theme = DefaultTheme
cfg.Timeformat = &timefmt
u.SetConfig(cfg)
if got, want := cfg.Theme.Timestamp("foo"), `foo`+Reset; got != want {
t.Errorf("Wrong timestamp formatting:\n got: %q\nwant: %q", got, want)
}
m := NewPublicMsg("hello", u)
defer u.Close()
u.Send(m)
u.HandleMsg(u.ConsumeOne())
s.Read(&actual)
expected = []byte(`AA:BB` + Reset + ` [foo] hello` + Newline)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Wrong screen output:\n Got: `%q`;\nWant: `%q`", actual, expected)
}
}

View File

@ -4,58 +4,37 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"sort"
"sync" "sync"
"github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/chat/message"
"github.com/shazow/ssh-chat/internal/humantime"
"github.com/shazow/ssh-chat/set"
) )
const historyLen = 20 const historyLen = 20
const roomBuffer = 10 const roomBuffer = 10
// ErrRoomClosed is the error returned when a message is sent to a room that is already // The error returned when a message is sent to a room that is already
// closed. // closed.
var ErrRoomClosed = errors.New("room closed") var ErrRoomClosed = errors.New("room closed")
// ErrInvalidName is the error returned when a user attempts to join with an invalid name, // The error returned when a user attempts to join with an invalid name, such
// such as empty string. // as empty string.
var ErrInvalidName = errors.New("invalid name") var ErrInvalidName = errors.New("invalid name")
// Member is a User with per-Room metadata attached to it. // Member is a User with per-Room metadata attached to it.
type Member struct { type Member struct {
*message.User *message.User
IsOp bool Op bool
// TODO: Move IsOp under mu?
mu sync.Mutex
isMuted bool // When true, messages should not be broadcasted.
}
func (m *Member) IsMuted() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.isMuted
}
func (m *Member) SetMute(muted bool) {
m.mu.Lock()
defer m.mu.Unlock()
m.isMuted = muted
} }
// Room definition, also a Set of User Items // Room definition, also a Set of User Items
type Room struct { type Room struct {
topic string topic string
history *message.History history *message.History
members *idSet
broadcast chan message.Message broadcast chan message.Message
commands Commands commands Commands
closed bool closed bool
closeOnce sync.Once closeOnce sync.Once
Members *set.Set
} }
// NewRoom creates a new room. // NewRoom creates a new room.
@ -65,9 +44,8 @@ func NewRoom() *Room {
return &Room{ return &Room{
broadcast: broadcast, broadcast: broadcast,
history: message.NewHistory(historyLen), history: message.NewHistory(historyLen),
members: newIdSet(),
commands: *defaultCommands, commands: *defaultCommands,
Members: set.New(),
} }
} }
@ -80,11 +58,10 @@ func (r *Room) SetCommands(commands Commands) {
func (r *Room) Close() { func (r *Room) Close() {
r.closeOnce.Do(func() { r.closeOnce.Do(func() {
r.closed = true r.closed = true
r.Members.Each(func(_ string, item set.Item) error { r.members.Each(func(m identified) {
item.Value().(*Member).Close() m.(*Member).Close()
return nil
}) })
r.Members.Clear() r.members.Clear()
close(r.broadcast) close(r.broadcast)
}) })
} }
@ -96,25 +73,6 @@ func (r *Room) SetLogging(out io.Writer) {
// HandleMsg reacts to a message, will block until done. // HandleMsg reacts to a message, will block until done.
func (r *Room) HandleMsg(m message.Message) { func (r *Room) HandleMsg(m message.Message) {
var fromID string
if fromMsg, ok := m.(message.MessageFrom); ok {
fromID = fromMsg.From().ID()
}
if fromID != "" {
if item, err := r.Members.Get(fromID); err != nil {
// Message from a member who is not in the room, this should not happen.
logger.Printf("Room received unexpected message from a non-member: %v", m)
return
} else if member, ok := item.Value().(*Member); ok && member.IsMuted() {
// Short circuit message handling for muted users
if _, ok = m.(*message.CommandMsg); !ok {
member.User.Send(m)
}
return
}
}
switch m := m.(type) { switch m := m.(type) {
case *message.CommandMsg: case *message.CommandMsg:
cmd := *m cmd := *m
@ -125,27 +83,28 @@ func (r *Room) HandleMsg(m message.Message) {
} }
case message.MessageTo: case message.MessageTo:
user := m.To() user := m.To()
if user.Ignored.In(fromID) {
return // Skip ignored
}
user.Send(m) user.Send(m)
default: default:
fromMsg, skip := m.(message.MessageFrom)
var skipUser *message.User
if skip {
skipUser = fromMsg.From()
}
r.history.Add(m) r.history.Add(m)
r.Members.Each(func(_ string, item set.Item) (err error) { r.members.Each(func(u identified) {
user := item.Value().(*Member).User user := u.(*Member).User
if skip && skipUser == user {
if user.Ignored.In(fromID) { // Skip
return // Skip ignored return
} }
if _, ok := m.(*message.AnnounceMsg); ok { if _, ok := m.(*message.AnnounceMsg); ok {
if user.Config().Quiet { if user.Config.Quiet {
return // Skip announcements // Skip
return
} }
} }
user.Send(m) user.Send(m)
return
}) })
} }
} }
@ -172,44 +131,45 @@ func (r *Room) History(u *message.User) {
// Join the room as a user, will announce. // Join the room as a user, will announce.
func (r *Room) Join(u *message.User) (*Member, error) { func (r *Room) Join(u *message.User) (*Member, error) {
// TODO: Check if closed if r.closed {
if u.ID() == "" { return nil, ErrRoomClosed
}
if u.Id() == "" {
return nil, ErrInvalidName return nil, ErrInvalidName
} }
member := &Member{User: u} member := Member{u, false}
err := r.Members.Add(set.Itemize(u.ID(), member)) err := r.members.Add(&member)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: Remove user ID from sets, probably referring to a prior user.
r.History(u) r.History(u)
s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.Members.Len()) s := fmt.Sprintf("%s joined. (Connected: %d)", u.Name(), r.members.Len())
r.Send(message.NewAnnounceMsg(s)) r.Send(message.NewAnnounceMsg(s))
return member, nil return &member, nil
} }
// Leave the room as a user, will announce. Mostly used during setup. // Leave the room as a user, will announce. Mostly used during setup.
func (r *Room) Leave(u *message.User) error { func (r *Room) Leave(u message.Identifier) error {
err := r.Members.Remove(u.ID()) err := r.members.Remove(u)
if err != nil { if err != nil {
return err return err
} }
s := fmt.Sprintf("%s left. (After %s)", u.Name(), humantime.Since(u.Joined())) s := fmt.Sprintf("%s left.", u.Name())
r.Send(message.NewAnnounceMsg(s)) r.Send(message.NewAnnounceMsg(s))
return nil return nil
} }
// Rename member with a new identity. This will not call rename on the member. // Rename member with a new identity. This will not call rename on the member.
func (r *Room) Rename(oldID string, u message.Identifier) error { func (r *Room) Rename(oldId string, identity message.Identifier) error {
if u.ID() == "" { if identity.Id() == "" {
return ErrInvalidName return ErrInvalidName
} }
err := r.Members.Replace(oldID, set.Itemize(u.ID(), u)) err := r.members.Replace(oldId, identity)
if err != nil { if err != nil {
return err return err
} }
s := fmt.Sprintf("%s is now known as %s.", oldID, u.ID()) s := fmt.Sprintf("%s is now known as %s.", oldId, identity.Id())
r.Send(message.NewAnnounceMsg(s)) r.Send(message.NewAnnounceMsg(s))
return nil return nil
} }
@ -217,7 +177,7 @@ func (r *Room) Rename(oldID string, u message.Identifier) error {
// Member returns a corresponding Member object to a User if the Member is // Member returns a corresponding Member object to a User if the Member is
// present in this room. // present in this room.
func (r *Room) Member(u *message.User) (*Member, bool) { func (r *Room) Member(u *message.User) (*Member, bool) {
m, ok := r.MemberByID(u.ID()) m, ok := r.MemberById(u.Id())
if !ok { if !ok {
return nil, false return nil, false
} }
@ -228,22 +188,18 @@ func (r *Room) Member(u *message.User) (*Member, bool) {
return m, true return m, true
} }
// MemberByID Gets a member by an id / name func (r *Room) MemberById(id string) (*Member, bool) {
func (r *Room) MemberByID(id string) (*Member, bool) { m, err := r.members.Get(id)
m, err := r.Members.Get(id)
if err != nil { if err != nil {
return nil, false return nil, false
} }
return m.Value().(*Member), true return m.(*Member), true
} }
// IsOp returns whether a user is an operator in this room. // IsOp returns whether a user is an operator in this room.
func (r *Room) IsOp(u *message.User) bool { func (r *Room) IsOp(u *message.User) bool {
m, ok := r.Member(u) m, ok := r.Member(u)
if !ok { return ok && m.Op
return false
}
return m.IsOp
} }
// Topic of the room. // Topic of the room.
@ -257,21 +213,12 @@ func (r *Room) SetTopic(s string) {
} }
// NamesPrefix lists all members' names with a given prefix, used to query // NamesPrefix lists all members' names with a given prefix, used to query
// for autocompletion purposes. Sorted by which user was last active. // for autocompletion purposes.
func (r *Room) NamesPrefix(prefix string) []string { func (r *Room) NamesPrefix(prefix string) []string {
items := r.Members.ListPrefix(prefix) members := r.members.ListPrefix(prefix)
names := make([]string, len(members))
// Sort results by recently active for i, u := range members {
users := make([]*message.User, 0, len(items)) names[i] = u.(*Member).User.Name()
for _, item := range items {
users = append(users, item.Value().(*Member).User)
}
sort.Sort(message.RecentActiveUsers(users))
// Pull out names
names := make([]string, 0, len(items))
for _, user := range users {
names = append(names, user.ID())
} }
return names return names
} }

View File

@ -1,13 +1,10 @@
package chat package chat
import ( import (
"errors"
"fmt"
"reflect" "reflect"
"testing" "testing"
"github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/chat/message"
"github.com/shazow/ssh-chat/set"
) )
// Used for testing // Used for testing
@ -26,10 +23,6 @@ func (s *MockScreen) Read(p *[]byte) (n int, err error) {
return len(*p), nil return len(*p), nil
} }
func (s *MockScreen) Close() error {
return nil
}
func TestRoomServe(t *testing.T) { func TestRoomServe(t *testing.T) {
ch := NewRoom() ch := NewRoom()
ch.Send(message.NewAnnounceMsg("hello")) ch.Send(message.NewAnnounceMsg("hello"))
@ -39,240 +32,15 @@ func TestRoomServe(t *testing.T) {
expected := " * hello" expected := " * hello"
if actual != expected { if actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
} }
type ScreenedUser struct {
user *message.User
screen *MockScreen
}
func TestIgnore(t *testing.T) {
var buffer []byte
ch := NewRoom()
go ch.Serve()
defer ch.Close()
// Create 3 users, join the room and clear their screen buffers
users := make([]ScreenedUser, 3)
for i := 0; i < 3; i++ {
screen := &MockScreen{}
user := message.NewUserScreen(message.SimpleID(fmt.Sprintf("user%d", i)), screen)
users[i] = ScreenedUser{
user: user,
screen: screen,
}
_, err := ch.Join(user)
if err != nil {
t.Fatal(err)
}
}
for _, u := range users {
for i := 0; i < 3; i++ {
u.user.HandleMsg(u.user.ConsumeOne())
u.screen.Read(&buffer)
}
}
// Use some handy variable names for distinguish between roles
ignorer := users[0]
ignored := users[1]
other := users[2]
// test ignoring unexisting user
if err := sendCommand("/ignore test", ignorer, ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> Err: user not found: test"+message.Newline)
// test ignoring existing user
if err := sendCommand("/ignore "+ignored.user.Name(), ignorer, ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> Ignoring: "+ignored.user.Name()+message.Newline)
// ignoring the same user twice returns an error message and doesn't add the user twice
if err := sendCommand("/ignore "+ignored.user.Name(), ignorer, ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> Err: user already ignored: user1"+message.Newline)
if ignoredList := ignorer.user.Ignored.ListPrefix(""); len(ignoredList) != 1 {
t.Fatalf("should have %d ignored users, has %d", 1, len(ignoredList))
}
// when an emote is sent by an ignored user, it should not be displayed for ignorer
ch.HandleMsg(message.NewEmoteMsg("is crying", ignored.user))
if ignorer.user.HasMessages() {
t.Fatal("should not have emote messages")
}
other.user.HandleMsg(other.user.ConsumeOne())
other.screen.Read(&buffer)
expectOutput(t, buffer, "** "+ignored.user.Name()+" is crying"+message.Newline)
// when a message is sent from the ignored user, it is delivered to non-ignoring users
ch.HandleMsg(message.NewPublicMsg("hello", ignored.user))
other.user.HandleMsg(other.user.ConsumeOne())
other.screen.Read(&buffer)
expectOutput(t, buffer, ignored.user.Name()+": hello"+message.Newline)
// ensure ignorer doesn't have received any message
if ignorer.user.HasMessages() {
t.Fatal("should not have messages")
}
// `/ignore` returns a list of ignored users
if err := sendCommand("/ignore", ignorer, ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> 1 ignored: "+ignored.user.Name()+message.Newline)
// `/unignore [USER]` removes the user from ignored ones
if err := sendCommand("/unignore "+ignored.user.Name(), users[0], ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> No longer ignoring: user1"+message.Newline)
if err := sendCommand("/ignore", users[0], ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> 0 users ignored."+message.Newline)
if ignoredList := ignorer.user.Ignored.ListPrefix(""); len(ignoredList) != 0 {
t.Fatalf("should have %d ignored users, has %d", 0, len(ignoredList))
}
// after unignoring a user, its messages can be received again
ch.HandleMsg(message.NewPublicMsg("hello again!", ignored.user))
// ensure ignorer has received the message
if !ignorer.user.HasMessages() {
t.Fatal("should have messages")
}
ignorer.user.HandleMsg(ignorer.user.ConsumeOne())
ignorer.screen.Read(&buffer)
expectOutput(t, buffer, ignored.user.Name()+": hello again!"+message.Newline)
}
func TestMute(t *testing.T) {
var buffer []byte
ch := NewRoom()
go ch.Serve()
defer ch.Close()
// Create 3 users, join the room and clear their screen buffers
users := make([]ScreenedUser, 3)
members := make([]*Member, 3)
for i := 0; i < 3; i++ {
screen := &MockScreen{}
user := message.NewUserScreen(message.SimpleID(fmt.Sprintf("user%d", i)), screen)
users[i] = ScreenedUser{
user: user,
screen: screen,
}
member, err := ch.Join(user)
if err != nil {
t.Fatal(err)
}
members[i] = member
}
for _, u := range users {
for i := 0; i < 3; i++ {
u.user.HandleMsg(u.user.ConsumeOne())
u.screen.Read(&buffer)
}
}
// Use some handy variable names for distinguish between roles
muter := users[0]
muted := users[1]
other := users[2]
members[0].IsOp = true
// test muting unexisting user
if err := sendCommand("/mute test", muter, ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> Err: user not found"+message.Newline)
// test muting by non-op
if err := sendCommand("/mute "+muted.user.Name(), other, ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> Err: must be op"+message.Newline)
// test muting existing user
if err := sendCommand("/mute "+muted.user.Name(), muter, ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> Muted: "+muted.user.Name()+message.Newline)
if got, want := members[1].IsMuted(), true; got != want {
t.Error("muted user failed to set mute flag")
}
// when an emote is sent by a muted user, it should not be displayed for anyone
ch.HandleMsg(message.NewPublicMsg("hello!", muted.user))
ch.HandleMsg(message.NewEmoteMsg("is crying", muted.user))
if muter.user.HasMessages() {
muter.user.HandleMsg(muter.user.ConsumeOne())
muter.screen.Read(&buffer)
t.Errorf("muter should not have messages: %s", buffer)
}
if other.user.HasMessages() {
other.user.HandleMsg(other.user.ConsumeOne())
other.screen.Read(&buffer)
t.Errorf("other should not have messages: %s", buffer)
}
// test unmuting
if err := sendCommand("/mute "+muted.user.Name(), muter, ch, &buffer); err != nil {
t.Fatal(err)
}
expectOutput(t, buffer, "-> Unmuted: "+muted.user.Name()+message.Newline)
ch.HandleMsg(message.NewPublicMsg("hello again!", muted.user))
other.user.HandleMsg(other.user.ConsumeOne())
other.screen.Read(&buffer)
expectOutput(t, buffer, muted.user.Name()+": hello again!"+message.Newline)
}
func expectOutput(t *testing.T, buffer []byte, expected string) {
t.Helper()
bytes := []byte(expected)
if !reflect.DeepEqual(buffer, bytes) {
t.Errorf("Got: %q; Expected: %q", buffer, expected)
}
}
func sendCommand(cmd string, mock ScreenedUser, room *Room, buffer *[]byte) error {
msg, ok := message.NewPublicMsg(cmd, mock.user).ParseCommand()
if !ok {
return errors.New("cannot parse command message")
}
room.Send(msg)
mock.user.HandleMsg(mock.user.ConsumeOne())
mock.screen.Read(buffer)
return nil
}
func TestRoomJoin(t *testing.T) { func TestRoomJoin(t *testing.T) {
var expected, actual []byte var expected, actual []byte
s := &MockScreen{} s := &MockScreen{}
u := message.NewUserScreen(message.SimpleID("foo"), s) u := message.NewUser(message.SimpleId("foo"))
ch := NewRoom() ch := NewRoom()
go ch.Serve() go ch.Serve()
@ -283,35 +51,35 @@ func TestRoomJoin(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
u.HandleMsg(u.ConsumeOne()) u.HandleMsg(<-u.ConsumeChan(), s)
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline) expected = []byte(" * foo joined. (Connected: 1)" + message.Newline)
s.Read(&actual) s.Read(&actual)
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
ch.Send(message.NewSystemMsg("hello", u)) ch.Send(message.NewSystemMsg("hello", u))
u.HandleMsg(u.ConsumeOne()) u.HandleMsg(<-u.ConsumeChan(), s)
expected = []byte("-> hello" + message.Newline) expected = []byte("-> hello" + message.Newline)
s.Read(&actual) s.Read(&actual)
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
ch.Send(message.ParseInput("/me says hello.", u)) ch.Send(message.ParseInput("/me says hello.", u))
u.HandleMsg(u.ConsumeOne()) u.HandleMsg(<-u.ConsumeChan(), s)
expected = []byte("** foo says hello." + message.Newline) expected = []byte("** foo says hello." + message.Newline)
s.Read(&actual) s.Read(&actual)
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
} }
func TestRoomDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) { func TestRoomDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) {
u := message.NewUser(message.SimpleID("foo")) u := message.NewUser(message.SimpleId("foo"))
u.SetConfig(message.UserConfig{ u.Config = message.UserConfig{
Quiet: true, Quiet: true,
}) }
ch := NewRoom() ch := NewRoom()
defer ch.Close() defer ch.Close()
@ -325,15 +93,11 @@ func TestRoomDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) {
<-ch.broadcast <-ch.broadcast
go func() { go func() {
/* for msg := range u.ConsumeChan() {
for { if _, ok := msg.(*message.AnnounceMsg); ok {
msg := u.ConsumeChan() t.Errorf("Got unexpected `%T`", msg)
if _, ok := msg.(*message.AnnounceMsg); ok {
t.Errorf("Got unexpected `%T`", msg)
}
} }
*/ }
// XXX: Fix this
}() }()
// Call with an AnnounceMsg and all the other types // Call with an AnnounceMsg and all the other types
@ -347,10 +111,10 @@ func TestRoomDoesntBroadcastAnnounceMessagesWhenQuiet(t *testing.T) {
} }
func TestRoomQuietToggleBroadcasts(t *testing.T) { func TestRoomQuietToggleBroadcasts(t *testing.T) {
u := message.NewUser(message.SimpleID("foo")) u := message.NewUser(message.SimpleId("foo"))
u.SetConfig(message.UserConfig{ u.Config = message.UserConfig{
Quiet: true, Quiet: true,
}) }
ch := NewRoom() ch := NewRoom()
defer ch.Close() defer ch.Close()
@ -363,24 +127,20 @@ func TestRoomQuietToggleBroadcasts(t *testing.T) {
// Drain the initial Join message // Drain the initial Join message
<-ch.broadcast <-ch.broadcast
u.SetConfig(message.UserConfig{ u.ToggleQuietMode()
Quiet: false,
})
expectedMsg := message.NewAnnounceMsg("Ignored") expectedMsg := message.NewAnnounceMsg("Ignored")
ch.HandleMsg(expectedMsg) ch.HandleMsg(expectedMsg)
msg := u.ConsumeOne() msg := <-u.ConsumeChan()
if _, ok := msg.(*message.AnnounceMsg); !ok { if _, ok := msg.(*message.AnnounceMsg); !ok {
t.Errorf("Got: `%T`; Expected: `%T`", msg, expectedMsg) t.Errorf("Got: `%T`; Expected: `%T`", msg, expectedMsg)
} }
u.SetConfig(message.UserConfig{ u.ToggleQuietMode()
Quiet: true,
})
ch.HandleMsg(message.NewAnnounceMsg("Ignored")) ch.HandleMsg(message.NewAnnounceMsg("Ignored"))
ch.HandleMsg(message.NewSystemMsg("hello", u)) ch.HandleMsg(message.NewSystemMsg("hello", u))
msg = u.ConsumeOne() msg = <-u.ConsumeChan()
if _, ok := msg.(*message.AnnounceMsg); ok { if _, ok := msg.(*message.AnnounceMsg); ok {
t.Errorf("Got unexpected `%T`", msg) t.Errorf("Got unexpected `%T`", msg)
} }
@ -390,7 +150,7 @@ func TestQuietToggleDisplayState(t *testing.T) {
var expected, actual []byte var expected, actual []byte
s := &MockScreen{} s := &MockScreen{}
u := message.NewUserScreen(message.SimpleID("foo"), s) u := message.NewUser(message.SimpleId("foo"))
ch := NewRoom() ch := NewRoom()
go ch.Serve() go ch.Serve()
@ -401,29 +161,24 @@ func TestQuietToggleDisplayState(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
u.HandleMsg(u.ConsumeOne()) // Drain the initial Join message
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline) <-ch.broadcast
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %q; Expected: %q", actual, expected)
}
ch.Send(message.ParseInput("/quiet", u)) ch.Send(message.ParseInput("/quiet", u))
u.HandleMsg(<-u.ConsumeChan(), s)
u.HandleMsg(u.ConsumeOne())
expected = []byte("-> Quiet mode is toggled ON" + message.Newline) expected = []byte("-> Quiet mode is toggled ON" + message.Newline)
s.Read(&actual) s.Read(&actual)
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
ch.Send(message.ParseInput("/quiet", u)) ch.Send(message.ParseInput("/quiet", u))
u.HandleMsg(<-u.ConsumeChan(), s)
u.HandleMsg(u.ConsumeOne())
expected = []byte("-> Quiet mode is toggled OFF" + message.Newline) expected = []byte("-> Quiet mode is toggled OFF" + message.Newline)
s.Read(&actual) s.Read(&actual)
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
} }
} }
@ -431,7 +186,7 @@ func TestRoomNames(t *testing.T) {
var expected, actual []byte var expected, actual []byte
s := &MockScreen{} s := &MockScreen{}
u := message.NewUserScreen(message.SimpleID("foo"), s) u := message.NewUser(message.SimpleId("foo"))
ch := NewRoom() ch := NewRoom()
go ch.Serve() go ch.Serve()
@ -442,65 +197,14 @@ func TestRoomNames(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
u.HandleMsg(u.ConsumeOne()) // Drain the initial Join message
expected = []byte(" * foo joined. (Connected: 1)" + message.Newline) <-ch.broadcast
s.Read(&actual)
if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %q; Expected: %q", actual, expected)
}
ch.Send(message.ParseInput("/names", u)) ch.Send(message.ParseInput("/names", u))
u.HandleMsg(<-u.ConsumeChan(), s)
u.HandleMsg(u.ConsumeOne())
expected = []byte("-> 1 connected: foo" + message.Newline) expected = []byte("-> 1 connected: foo" + message.Newline)
s.Read(&actual) s.Read(&actual)
if !reflect.DeepEqual(actual, expected) { if !reflect.DeepEqual(actual, expected) {
t.Errorf("Got: %q; Expected: %q", actual, expected) t.Errorf("Got: `%s`; Expected: `%s`", actual, expected)
}
}
func TestRoomNamesPrefix(t *testing.T) {
r := NewRoom()
s := &MockScreen{}
members := []*Member{
&Member{User: message.NewUserScreen(message.SimpleID("aaa"), s)},
&Member{User: message.NewUserScreen(message.SimpleID("aab"), s)},
&Member{User: message.NewUserScreen(message.SimpleID("aac"), s)},
&Member{User: message.NewUserScreen(message.SimpleID("foo"), s)},
}
for _, m := range members {
if err := r.Members.Add(set.Itemize(m.ID(), m)); err != nil {
t.Fatal(err)
}
}
sendMsg := func(from *Member, body string) {
// lastMsg is set during render of self messags, so we can't use NewMsg
from.HandleMsg(message.NewPublicMsg(body, from.User))
}
// Inject some activity
sendMsg(members[2], "hi") // aac
sendMsg(members[0], "hi") // aaa
sendMsg(members[3], "hi") // foo
sendMsg(members[1], "hi") // aab
if got, want := r.NamesPrefix("a"), []string{"aab", "aaa", "aac"}; !reflect.DeepEqual(got, want) {
t.Errorf("got: %q; want: %q", got, want)
}
sendMsg(members[2], "hi") // aac
if got, want := r.NamesPrefix("a"), []string{"aac", "aab", "aaa"}; !reflect.DeepEqual(got, want) {
t.Errorf("got: %q; want: %q", got, want)
}
if got, want := r.NamesPrefix("f"), []string{"foo"}; !reflect.DeepEqual(got, want) {
t.Errorf("got: %q; want: %q", got, want)
}
if got, want := r.NamesPrefix("bar"), []string{}; !reflect.DeepEqual(got, want) {
t.Errorf("got: %q; want: %q", got, want)
} }
} }

17
chat/sanitize.go Normal file
View File

@ -0,0 +1,17 @@
package chat
import "regexp"
var reStripName = regexp.MustCompile("[^\\w.-]")
// SanitizeName returns a name with only allowed characters.
func SanitizeName(s string) string {
return reStripName.ReplaceAllString(s, "")
}
var reStripData = regexp.MustCompile("[^[:ascii:]]")
// SanitizeData returns a string with only allowed characters for client-provided metadata inputs.
func SanitizeData(s string) string {
return reStripData.ReplaceAllString(s, "")
}

147
chat/set.go Normal file
View File

@ -0,0 +1,147 @@
package chat
import (
"errors"
"strings"
"sync"
)
// The error returned when an added id already exists in the set.
var ErrIdTaken = errors.New("id already taken")
// The error returned when a requested item does not exist in the set.
var ErridentifiedMissing = errors.New("item does not exist")
// Interface for an item storeable in the set
type identified interface {
Id() string
}
// Set with string lookup.
// TODO: Add trie for efficient prefix lookup?
type idSet struct {
lookup map[string]identified
sync.RWMutex
}
// newIdSet creates a new set.
func newIdSet() *idSet {
return &idSet{
lookup: map[string]identified{},
}
}
// Clear removes all items and returns the number removed.
func (s *idSet) Clear() int {
s.Lock()
n := len(s.lookup)
s.lookup = map[string]identified{}
s.Unlock()
return n
}
// Len returns the size of the set right now.
func (s *idSet) Len() int {
return len(s.lookup)
}
// In checks if an item exists in this set.
func (s *idSet) In(item identified) bool {
s.RLock()
_, ok := s.lookup[item.Id()]
s.RUnlock()
return ok
}
// Get returns an item with the given Id.
func (s *idSet) Get(id string) (identified, error) {
s.RLock()
item, ok := s.lookup[id]
s.RUnlock()
if !ok {
return nil, ErridentifiedMissing
}
return item, nil
}
// Add item to this set if it does not exist already.
func (s *idSet) Add(item identified) error {
s.Lock()
defer s.Unlock()
_, found := s.lookup[item.Id()]
if found {
return ErrIdTaken
}
s.lookup[item.Id()] = item
return nil
}
// Remove item from this set.
func (s *idSet) Remove(item identified) error {
s.Lock()
defer s.Unlock()
id := item.Id()
_, found := s.lookup[id]
if !found {
return ErridentifiedMissing
}
delete(s.lookup, id)
return nil
}
// Replace item from old id with new identified.
// Used for moving the same identified to a new Id, such as a rename.
func (s *idSet) Replace(oldId string, item identified) error {
s.Lock()
defer s.Unlock()
// Check if it already exists
_, found := s.lookup[item.Id()]
if found {
return ErrIdTaken
}
// Remove oldId
_, found = s.lookup[oldId]
if !found {
return ErridentifiedMissing
}
delete(s.lookup, oldId)
// Add new identified
s.lookup[item.Id()] = item
return nil
}
// Each loops over every item while holding a read lock and applies fn to each
// element.
func (s *idSet) Each(fn func(item identified)) {
s.RLock()
for _, item := range s.lookup {
fn(item)
}
s.RUnlock()
}
// ListPrefix returns a list of items with a prefix, case insensitive.
func (s *idSet) ListPrefix(prefix string) []identified {
r := []identified{}
prefix = strings.ToLower(prefix)
s.RLock()
defer s.RUnlock()
for id, item := range s.lookup {
if !strings.HasPrefix(string(id), prefix) {
continue
}
r = append(r, item)
}
return r
}

View File

@ -4,35 +4,34 @@ import (
"testing" "testing"
"github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/chat/message"
"github.com/shazow/ssh-chat/set"
) )
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
var err error var err error
s := set.New() s := newIdSet()
u := message.NewUser(message.SimpleID("foo")) u := message.NewUser(message.SimpleId("foo"))
if s.In(u.ID()) { if s.In(u) {
t.Errorf("Set should be empty.") t.Errorf("Set should be empty.")
} }
err = s.Add(set.Itemize(u.ID(), u)) err = s.Add(u)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if !s.In(u.ID()) { if !s.In(u) {
t.Errorf("Set should contain user.") t.Errorf("Set should contain user.")
} }
u2 := message.NewUser(message.SimpleID("bar")) u2 := message.NewUser(message.SimpleId("bar"))
err = s.Add(set.Itemize(u2.ID(), u2)) err = s.Add(u2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
err = s.Add(set.Itemize(u2.ID(), u2)) err = s.Add(u2)
if err != set.ErrCollision { if err != ErrIdTaken {
t.Error(err) t.Error(err)
} }

View File

@ -12,44 +12,28 @@ import (
"github.com/alexcesaro/log" "github.com/alexcesaro/log"
"github.com/alexcesaro/log/golog" "github.com/alexcesaro/log/golog"
flags "github.com/jessevdk/go-flags" "github.com/jessevdk/go-flags"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
sshchat "github.com/shazow/ssh-chat" "github.com/shazow/ssh-chat"
"github.com/shazow/ssh-chat/chat" "github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/chat/message"
"github.com/shazow/ssh-chat/sshd" "github.com/shazow/ssh-chat/sshd"
_ "net/http/pprof"
) )
import _ "net/http/pprof"
// Version of the binary, assigned during build.
var Version string = "dev"
// Options contains the flag options // Options contains the flag options
type Options struct { type Options struct {
Admin string `long:"admin" description:"File of public keys who are admins."` Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."`
Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:2022"` Identity string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"`
Identity []string `short:"i" long:"identity" description:"Private key to identify server with." default:"~/.ssh/id_rsa"` Bind string `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:2022"`
Log string `long:"log" description:"Write chat log to this file."` Admin string `long:"admin" description:"File of public keys who are admins."`
Motd string `long:"motd" description:"Optional Message of the Day file."` Whitelist string `long:"whitelist" description:"Optional file of public keys who are allowed to connect."`
Pprof int `long:"pprof" description:"Enable pprof http server for profiling."` Motd string `long:"motd" description:"Optional Message of the Day file."`
Verbose []bool `short:"v" long:"verbose" description:"Show verbose logging."` Log string `long:"log" description:"Write chat log to this file."`
Version bool `long:"version" description:"Print version and exit."` Pprof int `long:"pprof" description:"Enable pprof http server for profiling."`
Allowlist string `long:"allowlist" description:"Optional file of public keys who are allowed to connect."`
Whitelist string `long:"whitelist" dexcription:"Old name for allowlist option"`
Passphrase string `long:"unsafe-passphrase" description:"Require an interactive passphrase to connect. Allowlist feature is more secure."`
} }
const extraHelp = `There are hidden options and easter eggs in ssh-chat. The source code is a good
place to start looking. Some useful links:
* Project Repository:
https://github.com/shazow/ssh-chat
* Project Wiki FAQ:
https://github.com/shazow/ssh-chat/wiki/FAQ
`
var logLevels = []log.Level{ var logLevels = []log.Level{
log.Warning, log.Warning,
log.Info, log.Info,
@ -69,9 +53,7 @@ func main() {
if p == nil { if p == nil {
fmt.Print(err) fmt.Print(err)
} }
if flagErr, ok := err.(*flags.Error); ok && flagErr.Type == flags.ErrHelp { os.Exit(1)
fmt.Print(extraHelp)
}
return return
} }
@ -81,49 +63,42 @@ func main() {
}() }()
} }
if options.Version {
fmt.Println(Version)
return
}
// Figure out the log level // Figure out the log level
numVerbose := len(options.Verbose) numVerbose := len(options.Verbose)
if numVerbose >= len(logLevels) { if numVerbose > len(logLevels) {
numVerbose = len(logLevels) - 1 numVerbose = len(logLevels) - 1
} }
logLevel := logLevels[numVerbose] logLevel := logLevels[numVerbose]
logger := golog.New(os.Stderr, logLevel) sshchat.SetLogger(golog.New(os.Stderr, logLevel))
sshchat.SetLogger(logger)
if logLevel == log.Debug { if logLevel == log.Debug {
// Enable logging from submodules // Enable logging from submodules
chat.SetLogger(os.Stderr) chat.SetLogger(os.Stderr)
sshd.SetLogger(os.Stderr) sshd.SetLogger(os.Stderr)
message.SetLogger(os.Stderr) }
privateKeyPath := options.Identity
if strings.HasPrefix(privateKeyPath, "~/") {
user, err := user.Current()
if err == nil {
privateKeyPath = strings.Replace(privateKeyPath, "~", user.HomeDir, 1)
}
}
privateKey, err := ReadPrivateKey(privateKeyPath)
if err != nil {
fail(2, "Couldn't read private key: %v\n", err)
}
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
fail(3, "Failed to parse key: %v\n", err)
} }
auth := sshchat.NewAuth() auth := sshchat.NewAuth()
config := sshd.MakeAuth(auth) config := sshd.MakeAuth(auth)
config.ServerVersion = "SSH-2.0-Go ssh-chat" config.AddHostKey(signer)
// FIXME: Should we be using config.NoClientAuth = true by default?
for _, privateKeyPath := range options.Identity {
if strings.HasPrefix(privateKeyPath, "~/") {
user, err := user.Current()
if err == nil {
privateKeyPath = strings.Replace(privateKeyPath, "~", user.HomeDir, 1)
}
}
signer, err := ReadPrivateKey(privateKeyPath)
if err != nil {
fail(3, "Failed to read identity private key: %v\n", err)
}
config.AddHostKey(signer)
fmt.Printf("Added server identity: %s\n", sshd.Fingerprint(signer.PublicKey()))
}
s, err := sshd.ListenSSH(options.Bind, config) s, err := sshd.ListenSSH(options.Bind, config)
if err != nil { if err != nil {
@ -136,44 +111,41 @@ func main() {
host := sshchat.NewHost(s, auth) host := sshchat.NewHost(s, auth)
host.SetTheme(message.Themes[0]) host.SetTheme(message.Themes[0])
host.Version = Version
if options.Passphrase != "" { err = fromFile(options.Admin, func(line []byte) error {
auth.SetPassphrase(options.Passphrase) key, _, _, _, err := ssh.ParseAuthorizedKey(line)
} if err != nil {
return err
err = auth.LoadOps(loaderFromFile(options.Admin, logger)) }
auth.Op(key, 0)
return nil
})
if err != nil { if err != nil {
fail(5, "Failed to load admins: %v\n", err) fail(5, "Failed to load admins: %v\n", err)
} }
if options.Allowlist == "" && options.Whitelist != "" { err = fromFile(options.Whitelist, func(line []byte) error {
fmt.Println("--whitelist was renamed to --allowlist.") key, _, _, _, err := ssh.ParseAuthorizedKey(line)
options.Allowlist = options.Whitelist if err != nil {
} return err
err = auth.LoadAllowlist(loaderFromFile(options.Allowlist, logger)) }
auth.Whitelist(key, 0)
return nil
})
if err != nil { if err != nil {
fail(6, "Failed to load allowlist: %v\n", err) fail(6, "Failed to load whitelist: %v\n", err)
} }
auth.SetAllowlistMode(options.Allowlist != "")
if options.Motd != "" { if options.Motd != "" {
host.GetMOTD = func() (string, error) { motd, err := ioutil.ReadFile(options.Motd)
motd, err := ioutil.ReadFile(options.Motd) if err != nil {
if err != nil {
return "", err
}
motdString := string(motd)
// hack to normalize line endings into \r\n
motdString = strings.Replace(motdString, "\r\n", "\n", -1)
motdString = strings.Replace(motdString, "\n", "\r\n", -1)
return motdString, nil
}
if motdString, err := host.GetMOTD(); err != nil {
fail(7, "Failed to load MOTD file: %v\n", err) fail(7, "Failed to load MOTD file: %v\n", err)
} else {
host.SetMotd(motdString)
} }
motdString := strings.TrimSpace(string(motd))
// hack to normalize line endings into \r\n
motdString = strings.Replace(motdString, "\r\n", "\n", -1)
motdString = strings.Replace(motdString, "\n", "\r\n", -1)
host.SetMotd(motdString)
} }
if options.Log == "-" { if options.Log == "-" {
@ -194,34 +166,27 @@ func main() {
<-sig // Wait for ^C signal <-sig // Wait for ^C signal
fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.") fmt.Fprintln(os.Stderr, "Interrupt signal detected, shutting down.")
os.Exit(0)
} }
func loaderFromFile(path string, logger *golog.Logger) sshchat.KeyLoader { func fromFile(path string, handler func(line []byte) error) error {
if path == "" { if path == "" {
// Skip
return nil return nil
} }
return func() ([]ssh.PublicKey, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
var keys []ssh.PublicKey file, err := os.Open(path)
scanner := bufio.NewScanner(file) if err != nil {
for scanner.Scan() { return err
key, _, _, _, err := ssh.ParseAuthorizedKey(scanner.Bytes())
if err != nil {
if err.Error() == "ssh: no key found" {
continue // Skip line
}
return nil, err
}
keys = append(keys, key)
}
if keys == nil {
logger.Warning("file", path, "contained no keys")
}
return keys, nil
} }
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
err := handler(scanner.Bytes())
if err != nil {
return err
}
}
return nil
} }

View File

@ -1,38 +1,49 @@
package main package main
import ( import (
"crypto/x509"
"encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"syscall"
"golang.org/x/crypto/ssh" "code.google.com/p/gopass"
"golang.org/x/term"
) )
// ReadPrivateKey attempts to read your private key and possibly decrypt it if it // ReadPrivateKey attempts to read your private key and possibly decrypt it if it
// requires a passphrase. // requires a passphrase.
// This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`), // This function will prompt for a passphrase on STDIN if the environment variable (`IDENTITY_PASSPHRASE`),
// is not set. // is not set.
func ReadPrivateKey(path string) (ssh.Signer, error) { func ReadPrivateKey(path string) ([]byte, error) {
privateKey, err := ioutil.ReadFile(path) privateKey, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load identity: %v", err) return nil, fmt.Errorf("failed to load identity: %v", err)
} }
pk, err := ssh.ParsePrivateKey(privateKey) block, rest := pem.Decode(privateKey)
if err == nil { if len(rest) > 0 {
} else if _, ok := err.(*ssh.PassphraseMissingError); ok { return nil, fmt.Errorf("extra data when decoding private key")
passphrase := []byte(os.Getenv("IDENTITY_PASSPHRASE")) }
if len(passphrase) == 0 { if !x509.IsEncryptedPEMBlock(block) {
fmt.Println("Enter passphrase to unlock identity private key:", path) return privateKey, nil
passphrase, err = term.ReadPassword(int(syscall.Stdin))
if err != nil {
return nil, fmt.Errorf("couldn't read passphrase: %v", err)
}
}
return ssh.ParsePrivateKeyWithPassphrase(privateKey, passphrase)
} }
return pk, err passphrase := os.Getenv("IDENTITY_PASSPHRASE")
if passphrase == "" {
passphrase, err = gopass.GetPass("Enter passphrase: ")
if err != nil {
return nil, fmt.Errorf("couldn't read passphrase: %v", err)
}
}
der, err := x509.DecryptPEMBlock(block, []byte(passphrase))
if err != nil {
return nil, fmt.Errorf("decrypt failed: %v", err)
}
privateKey = pem.EncodeToMemory(&pem.Block{
Type: block.Type,
Bytes: der,
})
return privateKey, nil
} }

View File

@ -1,13 +0,0 @@
version: '3.2'
services:
app:
container_name: ssh-chat
build: .
ports:
- 2022:2022
restart: unless-stopped
volumes:
- type: bind
source: ~/.ssh/
target: /root/.ssh/
read_only: true

14
go.mod
View File

@ -1,14 +0,0 @@
module github.com/shazow/ssh-chat
require (
github.com/alexcesaro/log v0.0.0-20150915221235-61e686294e58
github.com/jessevdk/go-flags v1.5.0
github.com/shazow/rateio v0.0.0-20200113175441-4461efc8bdc4
golang.org/x/crypto v0.17.0
golang.org/x/sync v0.1.0
golang.org/x/sys v0.15.0
golang.org/x/term v0.15.0
golang.org/x/text v0.14.0
)
go 1.13

50
go.sum
View File

@ -1,50 +0,0 @@
github.com/alexcesaro/log v0.0.0-20150915221235-61e686294e58 h1:MkpmYfld/S8kXqTYI68DfL8/hHXjHogL120Dy00TIxc=
github.com/alexcesaro/log v0.0.0-20150915221235-61e686294e58/go.mod h1:YNfsMyWSs+h+PaYkxGeMVmVCX75Zj/pqdjbu12ciCYE=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=
github.com/shazow/rateio v0.0.0-20200113175441-4461efc8bdc4 h1:zwQ1HBo5FYwn1ksMd19qBCKO8JAWE9wmHivEpkw/DvE=
github.com/shazow/rateio v0.0.0-20200113175441-4461efc8bdc4/go.mod h1:vt2jWY/3Qw1bIzle5thrJWucsLuuX9iUNnp20CqCciI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -1,5 +1,5 @@
/* /*
Package sshchat is an implementation of an ssh server which serves a chat room sshchat package is an implementation of an ssh server which serves a chat room
instead of a shell. instead of a shell.
sshd subdirectory contains the ssh-related pieces which know nothing about chat. sshd subdirectory contains the ssh-related pieces which know nothing about chat.

599
host.go
View File

@ -1,33 +1,27 @@
package sshchat package sshchat
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"strings" "strings"
"sync"
"time" "time"
"golang.org/x/crypto/ssh"
"github.com/shazow/rateio" "github.com/shazow/rateio"
"github.com/shazow/ssh-chat/chat" "github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/chat/message"
"github.com/shazow/ssh-chat/internal/humantime"
"github.com/shazow/ssh-chat/internal/sanitize"
"github.com/shazow/ssh-chat/set"
"github.com/shazow/ssh-chat/sshd" "github.com/shazow/ssh-chat/sshd"
) )
var buildCommit string
const maxInputLength int = 1024 const maxInputLength int = 1024
// GetPrompt will render the terminal prompt string based on the user. // GetPrompt will render the terminal prompt string based on the user.
func GetPrompt(user *message.User) string { func GetPrompt(user *message.User) string {
name := user.Name() name := user.Name()
cfg := user.Config() if user.Config.Theme != nil {
if cfg.Theme != nil { name = user.Config.Theme.ColorName(user)
name = cfg.Theme.ColorName(user)
} }
return fmt.Sprintf("[%s] ", name) return fmt.Sprintf("[%s] ", name)
} }
@ -38,22 +32,13 @@ type Host struct {
*chat.Room *chat.Room
listener *sshd.SSHListener listener *sshd.SSHListener
commands chat.Commands commands chat.Commands
auth *Auth
// Version string to print on /version motd string
Version string auth *Auth
count int
// Default theme // Default theme
theme message.Theme theme message.Theme
mu sync.Mutex
motd string
count int
// GetMOTD is used to reload the motd from an external source
GetMOTD func() (string, error)
// OnUserJoined is used to notify when a user joins a host
OnUserJoined func(*message.User)
} }
// NewHost creates a Host on top of an existing listener. // NewHost creates a Host on top of an existing listener.
@ -77,20 +62,15 @@ func NewHost(listener *sshd.SSHListener, auth *Auth) *Host {
// SetTheme sets the default theme for the host. // SetTheme sets the default theme for the host.
func (h *Host) SetTheme(theme message.Theme) { func (h *Host) SetTheme(theme message.Theme) {
h.mu.Lock()
h.theme = theme h.theme = theme
h.mu.Unlock()
} }
// SetMotd sets the host's message of the day. // SetMotd sets the host's message of the day.
// TODO: Change to SetMOTD
func (h *Host) SetMotd(motd string) { func (h *Host) SetMotd(motd string) {
h.mu.Lock()
h.motd = motd h.motd = motd
h.mu.Unlock()
} }
func (h *Host) isOp(conn sshd.Connection) bool { func (h Host) isOp(conn sshd.Connection) bool {
key := conn.PublicKey() key := conn.PublicKey()
if key == nil { if key == nil {
return false return false
@ -102,102 +82,47 @@ func (h *Host) isOp(conn sshd.Connection) bool {
func (h *Host) Connect(term *sshd.Terminal) { func (h *Host) Connect(term *sshd.Terminal) {
id := NewIdentity(term.Conn) id := NewIdentity(term.Conn)
user := message.NewUserScreen(id, term) user := message.NewUserScreen(id, term)
user.OnChange = func() { user.Config.Theme = &h.theme
term.SetPrompt(GetPrompt(user)) go func() {
user.SetHighlight(user.ID()) // Close term once user is closed.
} user.Wait()
cfg := user.Config() term.Close()
}()
apiMode := strings.ToLower(term.Term()) == "bot"
if apiMode {
cfg.Theme = message.MonoTheme
cfg.Echo = false
} else {
term.SetEnterClear(true) // We provide our own echo rendering
cfg.Theme = &h.theme
}
user.SetConfig(cfg)
go user.Consume()
// Close term once user is closed.
defer user.Close() defer user.Close()
defer term.Close()
h.mu.Lock()
motd := h.motd
count := h.count
h.count++
h.mu.Unlock()
// Send MOTD // Send MOTD
if motd != "" && !apiMode { if h.motd != "" {
user.Send(message.NewAnnounceMsg(motd)) user.Send(message.NewAnnounceMsg(h.motd))
} }
member, err := h.Join(user) member, err := h.Join(user)
if err != nil { if err != nil {
// Try again... // Try again...
id.SetName(fmt.Sprintf("Guest%d", count)) id.SetName(fmt.Sprintf("Guest%d", h.count))
member, err = h.Join(user) member, err = h.Join(user)
} }
if err != nil { if err != nil {
logger.Errorf("[%s] Failed to join: %s", term.Conn.RemoteAddr(), err) logger.Errorf("Failed to join: %s", err)
return return
} }
// Load user config overrides from ENV
// TODO: Would be nice to skip the command parsing pipeline just to load
// config values. Would need to factor out some command handler logic into
// accessible helpers.
env := term.Env()
for _, e := range env {
switch e.Key {
case "SSHCHAT_TIMESTAMP":
if e.Value != "" && e.Value != "0" {
cmd := "/timestamp"
if e.Value != "1" {
cmd += " " + e.Value
}
if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok {
h.Room.HandleMsg(msg)
}
}
case "SSHCHAT_THEME":
cmd := "/theme " + e.Value
if msg, ok := message.NewPublicMsg(cmd, user).ParseCommand(); ok {
h.Room.HandleMsg(msg)
}
}
}
// Successfully joined. // Successfully joined.
if !apiMode { term.SetPrompt(GetPrompt(user))
term.SetPrompt(GetPrompt(user)) term.AutoCompleteCallback = h.AutoCompleteFunction(user)
term.AutoCompleteCallback = h.AutoCompleteFunction(user) user.SetHighlight(user.Name())
user.SetHighlight(user.Name()) h.count++
}
// Should the user be op'd on join? // Should the user be op'd on join?
if h.isOp(term.Conn) { member.Op = h.isOp(term.Conn)
member.IsOp = true
}
ratelimit := rateio.NewSimpleLimiter(3, time.Second*3) ratelimit := rateio.NewSimpleLimiter(3, time.Second*3)
logger.Debugf("[%s] Joined: %s", term.Conn.RemoteAddr(), user.Name())
if h.OnUserJoined != nil {
h.OnUserJoined(user)
}
for { for {
line, err := term.ReadLine() line, err := term.ReadLine()
if err == io.EOF { if err == io.EOF {
// Closed // Closed
break break
} else if err != nil { } else if err != nil {
logger.Errorf("[%s] Terminal reading error: %s", term.Conn.RemoteAddr(), err) logger.Errorf("Terminal reading error: %s", err)
break break
} }
@ -212,60 +137,54 @@ func (h *Host) Connect(term *sshd.Terminal) {
} }
if line == "" { if line == "" {
// Silently ignore empty lines. // Silently ignore empty lines.
term.Write([]byte{})
continue continue
} }
m := message.ParseInput(line, user) m := message.ParseInput(line, user)
if !apiMode {
if m, ok := m.(*message.CommandMsg); ok {
// Other messages render themselves by the room, commands we'll
// have to re-echo ourselves manually.
user.HandleMsg(m)
}
}
// FIXME: Any reason to use h.room.Send(m) instead? // FIXME: Any reason to use h.room.Send(m) instead?
h.HandleMsg(m) h.HandleMsg(m)
if apiMode { cmd := m.Command()
// Skip the remaining rendering workarounds if cmd == "/nick" || cmd == "/theme" {
continue // Hijack /nick command to update terminal synchronously. Wouldn't
// work if we use h.room.Send(m) above.
//
// FIXME: This is hacky, how do we improve the API to allow for
// this? Chat module shouldn't know about terminals.
term.SetPrompt(GetPrompt(user))
user.SetHighlight(user.Name())
} }
} }
err = h.Leave(user) err = h.Leave(user)
if err != nil { if err != nil {
logger.Errorf("[%s] Failed to leave: %s", term.Conn.RemoteAddr(), err) logger.Errorf("Failed to leave: %s", err)
return return
} }
logger.Debugf("[%s] Leaving: %s", term.Conn.RemoteAddr(), user.Name())
} }
// Serve our chat room onto the listener // Serve our chat room onto the listener
func (h *Host) Serve() { func (h *Host) Serve() {
h.listener.HandlerFunc = h.Connect terminals := h.listener.ServeTerminal()
h.listener.Serve()
for term := range terminals {
go h.Connect(term)
}
} }
func (h *Host) completeName(partial string, skipName string) string { func (h Host) completeName(partial string) string {
names := h.NamesPrefix(partial) names := h.NamesPrefix(partial)
if len(names) == 0 { if len(names) == 0 {
// Didn't find anything // Didn't find anything
return "" return ""
} else if name := names[0]; name != skipName {
// First name is not the skipName, great
return name
} else if len(names) > 1 {
// Next candidate
return names[1]
} }
return ""
return names[len(names)-1]
} }
func (h *Host) completeCommand(partial string) string { func (h Host) completeCommand(partial string) string {
for cmd := range h.commands { for cmd, _ := range h.commands {
if strings.HasPrefix(cmd, partial) { if strings.HasPrefix(cmd, partial) {
return cmd return cmd
} }
@ -287,31 +206,22 @@ func (h *Host) AutoCompleteFunction(u *message.User) func(line string, pos int,
fields := strings.Fields(line[:pos]) fields := strings.Fields(line[:pos])
isFirst := len(fields) < 2 isFirst := len(fields) < 2
partial := "" partial := fields[len(fields)-1]
if len(fields) > 0 {
partial = fields[len(fields)-1]
}
posPartial := pos - len(partial) posPartial := pos - len(partial)
var completed string var completed string
if isFirst && strings.HasPrefix(line, "/") { if isFirst && strings.HasPrefix(partial, "/") {
// Command // Command
completed = h.completeCommand(partial) completed = h.completeCommand(partial)
if completed == "/reply" { if completed == "/reply" {
replyTo := u.ReplyTo() replyTo := u.ReplyTo()
if replyTo != nil { if replyTo != nil {
name := replyTo.ID() completed = "/msg " + replyTo.Name()
_, found := h.GetUser(name)
if found {
completed = "/msg " + name
} else {
u.SetReplyTo(nil)
}
} }
} }
} else { } else {
// Name // Name
completed = h.completeName(partial, u.Name()) completed = h.completeName(partial)
if completed == "" { if completed == "" {
return return
} }
@ -332,7 +242,7 @@ func (h *Host) AutoCompleteFunction(u *message.User) func(line string, pos int,
// GetUser returns a message.User based on a name. // GetUser returns a message.User based on a name.
func (h *Host) GetUser(name string) (*message.User, bool) { func (h *Host) GetUser(name string) (*message.User, bool) {
m, ok := h.MemberByID(name) m, ok := h.MemberById(name)
if !ok { if !ok {
return nil, false return nil, false
} }
@ -342,20 +252,6 @@ func (h *Host) GetUser(name string) (*message.User, bool) {
// InitCommands adds host-specific commands to a Commands container. These will // InitCommands adds host-specific commands to a Commands container. These will
// override any existing commands. // override any existing commands.
func (h *Host) InitCommands(c *chat.Commands) { func (h *Host) InitCommands(c *chat.Commands) {
sendPM := func(room *chat.Room, msg string, from *message.User, target *message.User) error {
m := message.NewPrivateMsg(msg, from, target)
room.Send(&m)
txt := fmt.Sprintf("[Sent PM to %s]", target.Name())
if isAway, _, awayReason := target.GetAway(); isAway {
txt += " Away: " + awayReason
}
sysMsg := message.NewSystemMsg(txt, from)
room.Send(sysMsg)
target.SetReplyTo(from)
return nil
}
c.Add(chat.Command{ c.Add(chat.Command{
Prefix: "/msg", Prefix: "/msg",
PrefixHelp: "USER MESSAGE", PrefixHelp: "USER MESSAGE",
@ -374,7 +270,9 @@ func (h *Host) InitCommands(c *chat.Commands) {
return errors.New("user not found") return errors.New("user not found")
} }
return sendPM(room, strings.Join(args[1:], " "), msg.From(), target) m := message.NewPrivateMsg(strings.Join(args[1:], " "), msg.From(), target)
room.Send(m)
return nil
}, },
}) })
@ -394,12 +292,9 @@ func (h *Host) InitCommands(c *chat.Commands) {
return errors.New("no message to reply to") return errors.New("no message to reply to")
} }
_, found := h.GetUser(target.ID()) m := message.NewPrivateMsg(strings.Join(args, " "), msg.From(), target)
if !found { room.Send(m)
return errors.New("user not found") return nil
}
return sendPM(room, strings.Join(args, " "), msg.From(), target)
}, },
}) })
@ -417,15 +312,9 @@ func (h *Host) InitCommands(c *chat.Commands) {
if !ok { if !ok {
return errors.New("user not found") return errors.New("user not found")
} }
id := target.Identifier.(*Identity) id := target.Identifier.(*Identity)
var whois string room.Send(message.NewSystemMsg(id.Whois(), msg.From()))
switch room.IsOp(msg.From()) {
case true:
whois = id.WhoisAdmin(room)
case false:
whois = id.Whois(room)
}
room.Send(message.NewSystemMsg(whois, msg.From()))
return nil return nil
}, },
@ -435,7 +324,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
c.Add(chat.Command{ c.Add(chat.Command{
Prefix: "/version", Prefix: "/version",
Handler: func(room *chat.Room, msg message.CommandMsg) error { Handler: func(room *chat.Room, msg message.CommandMsg) error {
room.Send(message.NewSystemMsg(h.Version, msg.From())) room.Send(message.NewSystemMsg(buildCommit, msg.From()))
return nil return nil
}, },
}) })
@ -444,7 +333,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
c.Add(chat.Command{ c.Add(chat.Command{
Prefix: "/uptime", Prefix: "/uptime",
Handler: func(room *chat.Room, msg message.CommandMsg) error { Handler: func(room *chat.Room, msg message.CommandMsg) error {
room.Send(message.NewSystemMsg(humantime.Since(timeStarted), msg.From())) room.Send(message.NewSystemMsg(time.Now().Sub(timeStarted).String(), msg.From()))
return nil return nil
}, },
}) })
@ -480,8 +369,8 @@ func (h *Host) InitCommands(c *chat.Commands) {
c.Add(chat.Command{ c.Add(chat.Command{
Op: true, Op: true,
Prefix: "/ban", Prefix: "/ban",
PrefixHelp: "QUERY [DURATION]", PrefixHelp: "USER [DURATION]",
Help: "Ban from the server. QUERY can be a username to ban the fingerprint and ip, or quoted \"key=value\" pairs with keys like ip, fingerprint, client.", Help: "Ban USER from the server.",
Handler: func(room *chat.Room, msg message.CommandMsg) error { Handler: func(room *chat.Room, msg message.CommandMsg) error {
// TODO: Would be nice to specify what to ban. Key? Ip? etc. // TODO: Would be nice to specify what to ban. Key? Ip? etc.
if !room.IsOp(msg.From()) { if !room.IsOp(msg.From()) {
@ -493,17 +382,12 @@ func (h *Host) InitCommands(c *chat.Commands) {
return errors.New("must specify user") return errors.New("must specify user")
} }
query := args[0] target, ok := h.GetUser(args[0])
target, ok := h.GetUser(query)
if !ok { if !ok {
query = strings.Join(args, " ")
if strings.Contains(query, "=") {
return h.auth.BanQuery(query)
}
return errors.New("user not found") return errors.New("user not found")
} }
var until time.Duration var until time.Duration = 0
if len(args) > 1 { if len(args) > 1 {
until, _ = time.ParseDuration(args[1]) until, _ = time.ParseDuration(args[1])
} }
@ -516,36 +400,7 @@ func (h *Host) InitCommands(c *chat.Commands) {
room.Send(message.NewAnnounceMsg(body)) room.Send(message.NewAnnounceMsg(body))
target.Close() target.Close()
logger.Debugf("Banned: \n-> %s", id.Whois(room)) logger.Debugf("Banned: \n-> %s", id.Whois())
return nil
},
})
c.Add(chat.Command{
Op: true,
Prefix: "/banned",
Help: "List the current ban conditions.",
Handler: func(room *chat.Room, msg message.CommandMsg) error {
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
bannedIPs, bannedFingerprints, bannedClients := h.auth.Banned()
buf := bytes.Buffer{}
fmt.Fprintf(&buf, "Banned:")
for _, key := range bannedIPs {
fmt.Fprintf(&buf, "\n \"ip=%s\"", key)
}
for _, key := range bannedFingerprints {
fmt.Fprintf(&buf, "\n \"fingerprint=%s\"", key)
}
for _, key := range bannedClients {
fmt.Fprintf(&buf, "\n \"client=%s\"", key)
}
room.Send(message.NewSystemMsg(buf.String(), msg.From()))
return nil return nil
}, },
@ -554,40 +409,26 @@ func (h *Host) InitCommands(c *chat.Commands) {
c.Add(chat.Command{ c.Add(chat.Command{
Op: true, Op: true,
Prefix: "/motd", Prefix: "/motd",
PrefixHelp: "[MESSAGE]", PrefixHelp: "MESSAGE",
Help: "Set a new MESSAGE of the day, or print the motd if no MESSAGE.", Help: "Set the MESSAGE of the day.",
Handler: func(room *chat.Room, msg message.CommandMsg) error { Handler: func(room *chat.Room, msg message.CommandMsg) error {
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
motd := ""
args := msg.Args() args := msg.Args()
user := msg.From() if len(args) > 0 {
motd = strings.Join(args, " ")
h.mu.Lock()
motd := h.motd
h.mu.Unlock()
if len(args) == 0 {
room.Send(message.NewSystemMsg(motd, user))
return nil
}
if !room.IsOp(user) {
return errors.New("must be OP to modify the MOTD")
} }
var err error h.motd = motd
var s string = strings.Join(args, " ") body := fmt.Sprintf("New message of the day set by %s:", msg.From().Name())
room.Send(message.NewAnnounceMsg(body))
if s == "@" { if motd != "" {
if h.GetMOTD == nil { room.Send(message.NewAnnounceMsg(motd))
return errors.New("motd reload not set")
}
if s, err = h.GetMOTD(); err != nil {
return err
}
} }
h.SetMotd(s)
fromMsg := fmt.Sprintf("New message of the day set by %s:", msg.From().Name())
room.Send(message.NewAnnounceMsg(fromMsg + message.Newline + "-> " + s))
return nil return nil
}, },
}) })
@ -595,8 +436,8 @@ func (h *Host) InitCommands(c *chat.Commands) {
c.Add(chat.Command{ c.Add(chat.Command{
Op: true, Op: true,
Prefix: "/op", Prefix: "/op",
PrefixHelp: "USER [DURATION|remove]", PrefixHelp: "USER [DURATION]",
Help: "Set USER as admin. Duration only applies to pubkey reconnects.", Help: "Set USER as admin.",
Handler: func(room *chat.Room, msg message.CommandMsg) error { Handler: func(room *chat.Room, msg message.CommandMsg) error {
if !room.IsOp(msg.From()) { if !room.IsOp(msg.From()) {
return errors.New("must be op") return errors.New("must be op")
@ -607,293 +448,23 @@ func (h *Host) InitCommands(c *chat.Commands) {
return errors.New("must specify user") return errors.New("must specify user")
} }
opValue := true var until time.Duration = 0
var until time.Duration
if len(args) > 1 { if len(args) > 1 {
if args[1] == "remove" { until, _ = time.ParseDuration(args[1])
// Expire instantly
until = time.Duration(1)
opValue = false
} else {
until, _ = time.ParseDuration(args[1])
}
} }
member, ok := room.MemberByID(args[0]) member, ok := room.MemberById(args[0])
if !ok { if !ok {
return errors.New("user not found") return errors.New("user not found")
} }
member.IsOp = opValue member.Op = true
id := member.Identifier.(*Identity) id := member.Identifier.(*Identity)
h.auth.Op(id.PublicKey(), until) h.auth.Op(id.PublicKey(), until)
var body string body := fmt.Sprintf("Made op by %s.", msg.From().Name())
if opValue {
body = fmt.Sprintf("Made op by %s.", msg.From().Name())
} else {
body = fmt.Sprintf("Removed op by %s.", msg.From().Name())
}
room.Send(message.NewSystemMsg(body, member.User)) room.Send(message.NewSystemMsg(body, member.User))
return nil return nil
}, },
}) })
c.Add(chat.Command{
Op: true,
Prefix: "/rename",
PrefixHelp: "USER NEW_NAME [SYMBOL]",
Help: "Rename USER to NEW_NAME, add optional SYMBOL prefix",
Handler: func(room *chat.Room, msg message.CommandMsg) error {
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
args := msg.Args()
if len(args) < 2 {
return errors.New("must specify user and new name")
}
member, ok := room.MemberByID(args[0])
if !ok {
return errors.New("user not found")
}
symbolSet := false
if len(args) == 3 {
s := args[2]
if id, ok := member.Identifier.(*Identity); ok {
id.SetSymbol(s)
} else {
return errors.New("user does not support setting symbol")
}
body := fmt.Sprintf("Assigned symbol %q by %s.", s, msg.From().Name())
room.Send(message.NewSystemMsg(body, member.User))
symbolSet = true
}
oldID := member.ID()
newID := sanitize.Name(args[1])
if newID == oldID && !symbolSet {
return errors.New("new name is the same as the original")
} else if (newID == "" || newID == oldID) && symbolSet {
if member.User.OnChange != nil {
member.User.OnChange()
}
return nil
}
member.SetID(newID)
err := room.Rename(oldID, member)
if err != nil {
member.SetID(oldID)
return err
}
body := fmt.Sprintf("%s was renamed by %s.", oldID, msg.From().Name())
room.Send(message.NewAnnounceMsg(body))
return nil
},
})
forConnectedUsers := func(cmd func(*chat.Member, ssh.PublicKey) error) error {
return h.Members.Each(func(key string, item set.Item) error {
v := item.Value()
if v == nil { // expired between Each and here
return nil
}
user := v.(*chat.Member)
pk := user.Identifier.(*Identity).PublicKey()
return cmd(user, pk)
})
}
forPubkeyUser := func(args []string, cmd func(ssh.PublicKey)) (errors []string) {
invalidUsers := []string{}
invalidKeys := []string{}
noKeyUsers := []string{}
var keyType string
for _, v := range args {
switch {
case keyType != "":
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyType + " " + v))
if err == nil {
cmd(pk)
} else {
invalidKeys = append(invalidKeys, keyType+" "+v)
}
keyType = ""
case strings.HasPrefix(v, "ssh-"):
keyType = v
default:
user, ok := h.GetUser(v)
if ok {
pk := user.Identifier.(*Identity).PublicKey()
if pk == nil {
noKeyUsers = append(noKeyUsers, user.Identifier.Name())
} else {
cmd(pk)
}
} else {
invalidUsers = append(invalidUsers, v)
}
}
}
if len(noKeyUsers) != 0 {
errors = append(errors, fmt.Sprintf("users without a public key: %v", noKeyUsers))
}
if len(invalidUsers) != 0 {
errors = append(errors, fmt.Sprintf("invalid users: %v", invalidUsers))
}
if len(invalidKeys) != 0 {
errors = append(errors, fmt.Sprintf("invalid keys: %v", invalidKeys))
}
return
}
allowlistHelptext := []string{
"Usage: /allowlist help | on | off | add {PUBKEY|USER}... | remove {PUBKEY|USER}... | import [AGE] | reload {keep|flush} | reverify | status",
"help: this help message",
"on, off: set allowlist mode (applies to new connections)",
"add, remove: add or remove keys from the allowlist",
"import: add all keys of users connected since AGE (default 0) ago to the allowlist",
"reload: re-read the allowlist file and keep or discard entries in the current allowlist but not in the file",
"reverify: kick all users not in the allowlist if allowlisting is enabled",
"status: show status information",
}
allowlistImport := func(args []string) (msgs []string, err error) {
var since time.Duration
if len(args) > 0 {
since, err = time.ParseDuration(args[0])
if err != nil {
return
}
}
cutoff := time.Now().Add(-since)
noKeyUsers := []string{}
forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
if user.Joined().Before(cutoff) {
if pk == nil {
noKeyUsers = append(noKeyUsers, user.Identifier.Name())
} else {
h.auth.Allowlist(pk, 0)
}
}
return nil
})
if len(noKeyUsers) != 0 {
msgs = []string{fmt.Sprintf("users without a public key: %v", noKeyUsers)}
}
return
}
allowlistReload := func(args []string) error {
if !(len(args) > 0 && (args[0] == "keep" || args[0] == "flush")) {
return errors.New("must specify whether to keep or flush current entries")
}
if args[0] == "flush" {
h.auth.allowlist.Clear()
}
return h.auth.ReloadAllowlist()
}
allowlistReverify := func(room *chat.Room) []string {
if !h.auth.AllowlistMode() {
return []string{"allowlist is disabled, so nobody will be kicked"}
}
var kicked []string
forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
if h.auth.CheckPublicKey(pk) != nil && !user.IsOp { // we do this check here as well for ops without keys
kicked = append(kicked, user.Name())
user.Close()
}
return nil
})
if kicked != nil {
room.Send(message.NewAnnounceMsg("Kicked during pubkey reverification: " + strings.Join(kicked, ", ")))
}
return nil
}
allowlistStatus := func() (msgs []string) {
if h.auth.AllowlistMode() {
msgs = []string{"allowlist enabled"}
} else {
msgs = []string{"allowlist disabled"}
}
allowlistedUsers := []string{}
allowlistedKeys := []string{}
h.auth.allowlist.Each(func(key string, item set.Item) error {
keyFP := item.Key()
if forConnectedUsers(func(user *chat.Member, pk ssh.PublicKey) error {
if pk != nil && sshd.Fingerprint(pk) == keyFP {
allowlistedUsers = append(allowlistedUsers, user.Name())
return io.EOF
}
return nil
}) == nil {
// if we land here, the key matches no users
allowlistedKeys = append(allowlistedKeys, keyFP)
}
return nil
})
if len(allowlistedUsers) != 0 {
msgs = append(msgs, "Connected users on the allowlist: "+strings.Join(allowlistedUsers, ", "))
}
if len(allowlistedKeys) != 0 {
msgs = append(msgs, "Keys on the allowlist without connected user: "+strings.Join(allowlistedKeys, ", "))
}
return
}
c.Add(chat.Command{
Op: true,
Prefix: "/allowlist",
PrefixHelp: "COMMAND [ARGS...]",
Help: "Modify the allowlist or allowlist state. See /allowlist help for subcommands",
Handler: func(room *chat.Room, msg message.CommandMsg) (err error) {
if !room.IsOp(msg.From()) {
return errors.New("must be op")
}
args := msg.Args()
if len(args) == 0 {
args = []string{"help"}
}
// send exactly one message to preserve order
var replyLines []string
switch args[0] {
case "help":
replyLines = allowlistHelptext
case "on":
h.auth.SetAllowlistMode(true)
case "off":
h.auth.SetAllowlistMode(false)
case "add":
replyLines = forPubkeyUser(args[1:], func(pk ssh.PublicKey) { h.auth.Allowlist(pk, 0) })
case "remove":
replyLines = forPubkeyUser(args[1:], func(pk ssh.PublicKey) { h.auth.Allowlist(pk, 1) })
case "import":
replyLines, err = allowlistImport(args[1:])
case "reload":
err = allowlistReload(args[1:])
case "reverify":
replyLines = allowlistReverify(room)
case "status":
replyLines = allowlistStatus()
default:
err = errors.New("invalid subcommand: " + args[0])
}
if err == nil && replyLines != nil {
room.Send(message.NewSystemMsg(strings.Join(replyLines, "\r\n"), msg.From()))
}
return
},
})
} }

View File

@ -2,477 +2,216 @@ package sshchat
import ( import (
"bufio" "bufio"
"errors" "crypto/rand"
"fmt" "crypto/rsa"
"io" "io"
mathRand "math/rand" "io/ioutil"
"strings" "strings"
"testing" "testing"
"time"
"github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/chat/message"
"github.com/shazow/ssh-chat/sshd" "github.com/shazow/ssh-chat/sshd"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
) )
func stripPrompt(s string) string { func stripPrompt(s string) string {
// FIXME: Is there a better way to do this? pos := strings.LastIndex(s, "\033[K")
if endPos := strings.Index(s, "\x1b[K "); endPos > 0 { if pos < 0 {
return s[endPos+3:] return s
}
if endPos := strings.Index(s, "\x1b[2K "); endPos > 0 {
return s[endPos+4:]
}
if endPos := strings.Index(s, "\x1b[K-> "); endPos > 0 {
return s[endPos+6:]
}
if endPos := strings.Index(s, "] "); endPos > 0 {
return s[endPos+2:]
}
if strings.HasPrefix(s, "-> ") {
return s[3:]
}
return s
}
func TestStripPrompt(t *testing.T) {
tests := []struct {
Input string
Want string
}{
{
Input: "\x1b[A\x1b[2K[quux] hello",
Want: "hello",
},
{
Input: "[foo] \x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[D\x1b[K * Guest1 joined. (Connected: 2)\r",
Want: " * Guest1 joined. (Connected: 2)\r",
},
{
Input: "[foo] \x1b[6D\x1b[K-> From your friendly system.\r",
Want: "From your friendly system.\r",
},
{
Input: "-> Err: must be op.\r",
Want: "Err: must be op.\r",
},
}
for i, tc := range tests {
if got, want := stripPrompt(tc.Input), tc.Want; got != want {
t.Errorf("case #%d:\n got: %q\nwant: %q", i, got, want)
}
} }
return s[pos+3:]
} }
func TestHostGetPrompt(t *testing.T) { func TestHostGetPrompt(t *testing.T) {
var expected, actual string var expected, actual string
// Make the random colors consistent across tests u := message.NewUser(&Identity{nil, "foo"})
mathRand.Seed(1) u.SetColorIdx(2)
u := message.NewUser(&Identity{id: "foo"})
actual = GetPrompt(u) actual = GetPrompt(u)
expected = "[foo] " expected = "[foo] "
if actual != expected { if actual != expected {
t.Errorf("Invalid host prompt:\n Got: %q;\nWant: %q", actual, expected) t.Errorf("Got: %q; Expected: %q", actual, expected)
} }
u.SetConfig(message.UserConfig{ u.Config.Theme = &message.Themes[0]
Theme: &message.Themes[0],
})
actual = GetPrompt(u) actual = GetPrompt(u)
expected = "[\033[38;05;88mfoo\033[0m] " expected = "[\033[38;05;2mfoo\033[0m] "
if actual != expected { if actual != expected {
t.Errorf("Invalid host prompt:\n Got: %q;\nWant: %q", actual, expected) t.Errorf("Got: %q; Expected: %q", actual, expected)
} }
} }
func getHost(t *testing.T, auth *Auth) (*sshd.SSHListener, *Host) {
key, err := sshd.NewRandomSigner(1024)
if err != nil {
t.Fatal(err)
}
var config *ssh.ServerConfig
if auth == nil {
config = sshd.MakeNoAuth()
} else {
config = sshd.MakeAuth(auth)
}
config.AddHostKey(key)
s, err := sshd.ListenSSH("localhost:0", config)
if err != nil {
t.Fatal(err)
}
return s, NewHost(s, auth)
}
func TestHostNameCollision(t *testing.T) { func TestHostNameCollision(t *testing.T) {
s, host := getHost(t, nil) key, err := sshd.NewRandomSigner(512)
defer s.Close() if err != nil {
t.Fatal(err)
newUsers := make(chan *message.User)
host.OnUserJoined = func(u *message.User) {
newUsers <- u
} }
config := sshd.MakeNoAuth()
config.AddHostKey(key)
s, err := sshd.ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
defer s.Close()
host := NewHost(s, nil)
go host.Serve() go host.Serve()
g := errgroup.Group{} done := make(chan struct{}, 1)
// First client // First client
g.Go(func() error { go func() {
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
// second client scanner := bufio.NewScanner(r)
name := (<-newUsers).Name()
if name != "Guest1" { // Consume the initial buffer
t.Errorf("Second client did not get Guest1 name: %q", name) scanner.Scan()
actual := scanner.Text()
if !strings.HasPrefix(actual, "[foo] ") {
t.Errorf("First client failed to get 'foo' name: %q", actual)
} }
return nil
actual = stripPrompt(actual)
expected := " * foo joined. (Connected: 1)"
if actual != expected {
t.Errorf("Got %q; expected %q", actual, expected)
}
// Ready for second client
done <- struct{}{}
scanner.Scan()
actual = stripPrompt(scanner.Text())
expected = " * Guest1 joined. (Connected: 2)"
if actual != expected {
t.Errorf("Got %q; expected %q", actual, expected)
}
// Wrap it up.
close(done)
}) })
}) if err != nil {
t.Fatal(err)
}
}()
// Wait for first client
<-done
// Second client // Second client
g.Go(func() error { err = sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) {
// first client scanner := bufio.NewScanner(r)
name := (<-newUsers).Name()
if name != "foo" {
t.Errorf("First client did not get foo name: %q", name)
}
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
return nil
})
})
if err := g.Wait(); err != nil { // Consume the initial buffer
t.Error(err) scanner.Scan()
actual := scanner.Text()
if !strings.HasPrefix(actual, "[Guest1] ") {
t.Errorf("Second client did not get Guest1 name.")
}
})
if err != nil {
t.Fatal(err)
} }
<-done
} }
func TestHostAllowlist(t *testing.T) { func TestHostWhitelist(t *testing.T) {
key, err := sshd.NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
auth := NewAuth() auth := NewAuth()
s, host := getHost(t, auth) config := sshd.MakeAuth(auth)
config.AddHostKey(key)
s, err := sshd.ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
defer s.Close() defer s.Close()
host := NewHost(s, auth)
go host.Serve() go host.Serve()
target := s.Addr().String() target := s.Addr().String()
clientPrivateKey, err := sshd.NewRandomSigner(512) err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
if err != nil {
t.Error(err)
}
clientkey, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
clientKey := clientPrivateKey.PublicKey()
loadCount := -1
loader := func() ([]ssh.PublicKey, error) {
loadCount++
return [][]ssh.PublicKey{
{},
{clientKey},
}[loadCount], nil
}
auth.LoadAllowlist(loader)
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil }) clientpubkey, _ := ssh.NewPublicKey(clientkey.Public())
if err != nil { auth.Whitelist(clientpubkey, 0)
t.Error(err)
}
auth.SetAllowlistMode(true) err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) {})
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err == nil { if err == nil {
t.Error(err) t.Error("Failed to block unwhitelisted connection.")
} }
err = sshd.ConnectShellWithKey(target, "foo", clientPrivateKey, func(r io.Reader, w io.WriteCloser) error { return nil })
if err == nil {
t.Error(err)
}
auth.ReloadAllowlist()
err = sshd.ConnectShell(target, "foo", func(r io.Reader, w io.WriteCloser) error { return nil })
if err == nil {
t.Error("Failed to block unallowlisted connection.")
}
}
func TestHostAllowlistCommand(t *testing.T) {
s, host := getHost(t, NewAuth())
defer s.Close()
go host.Serve()
users := make(chan *message.User)
host.OnUserJoined = func(u *message.User) {
users <- u
}
kickSignal := make(chan struct{})
clientKey, err := sshd.NewRandomSigner(1024)
if err != nil {
t.Fatal(err)
}
clientKeyFP := sshd.Fingerprint(clientKey.PublicKey())
go sshd.ConnectShellWithKey(s.Addr().String(), "bar", clientKey, func(r io.Reader, w io.WriteCloser) error {
<-kickSignal
n, err := w.Write([]byte("alive and well"))
if n != 0 || err == nil {
t.Error("could write after being kicked")
}
return nil
})
sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error {
<-users
<-users
m, ok := host.MemberByID("foo")
if !ok {
t.Fatal("can't get member foo")
}
scanner := bufio.NewScanner(r)
scanner.Scan() // Joined
scanner.Scan()
assertLineEq := func(expected ...string) {
if !scanner.Scan() {
t.Error("no line available")
}
actual := stripPrompt(scanner.Text())
for _, exp := range expected {
if exp == actual {
return
}
}
t.Errorf("expected %#v, got %q", expected, actual)
}
sendCmd := func(cmd string, formatting ...interface{}) {
host.HandleMsg(message.ParseInput(fmt.Sprintf(cmd, formatting...), m.User))
}
sendCmd("/allowlist")
assertLineEq("Err: must be op\r")
m.IsOp = true
sendCmd("/allowlist")
for _, expected := range [...]string{"Usage", "help", "on, off", "add, remove", "import", "reload", "reverify", "status"} {
if !scanner.Scan() {
t.Error("no line available")
}
if actual := stripPrompt(scanner.Text()); !strings.HasPrefix(actual, expected) {
t.Errorf("Unexpected help message order: have %q, want prefix %q", actual, expected)
}
}
sendCmd("/allowlist on")
if !host.auth.AllowlistMode() {
t.Error("allowlist not enabled after /allowlist on")
}
sendCmd("/allowlist off")
if host.auth.AllowlistMode() {
t.Error("allowlist not disabled after /allowlist off")
}
testKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPUiNw0nQku4pcUCbZcJlIEAIf5bXJYTy/DKI1vh5b+P"
testKeyFP := "SHA256:GJNSl9NUcOS2pZYALn0C5Qgfh5deT+R+FfqNIUvpM9s="
if host.auth.allowlist.Len() != 0 {
t.Error("allowlist not empty before adding anyone")
}
sendCmd("/allowlist add ssh-invalid blah ssh-rsa wrongAsWell invalid foo bar %s", testKey)
assertLineEq("users without a public key: [foo]\r")
assertLineEq("invalid users: [invalid]\r")
assertLineEq("invalid keys: [ssh-invalid blah ssh-rsa wrongAsWell]\r")
if !host.auth.allowlist.In(testKeyFP) || !host.auth.allowlist.In(clientKeyFP) {
t.Error("failed to add keys to allowlist")
}
sendCmd("/allowlist remove invalid bar")
assertLineEq("invalid users: [invalid]\r")
if host.auth.allowlist.In(clientKeyFP) {
t.Error("failed to remove key from allowlist")
}
if !host.auth.allowlist.In(testKeyFP) {
t.Error("removed wrong key")
}
sendCmd("/allowlist import 5h")
if host.auth.allowlist.In(clientKeyFP) {
t.Error("imporrted key not seen long enough")
}
sendCmd("/allowlist import")
assertLineEq("users without a public key: [foo]\r")
if !host.auth.allowlist.In(clientKeyFP) {
t.Error("failed to import key")
}
sendCmd("/allowlist reload keep")
if !host.auth.allowlist.In(testKeyFP) {
t.Error("cleared allowlist to be kept")
}
sendCmd("/allowlist reload flush")
if host.auth.allowlist.In(testKeyFP) {
t.Error("kept allowlist to be cleared")
}
sendCmd("/allowlist reload thisIsWrong")
assertLineEq("Err: must specify whether to keep or flush current entries\r")
sendCmd("/allowlist reload")
assertLineEq("Err: must specify whether to keep or flush current entries\r")
sendCmd("/allowlist reverify")
assertLineEq("allowlist is disabled, so nobody will be kicked\r")
sendCmd("/allowlist on")
sendCmd("/allowlist reverify")
assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r")
assertLineEq(" * Kicked during pubkey reverification: bar\r", " * bar left. (After 0 seconds)\r")
kickSignal <- struct{}{}
sendCmd("/allowlist add " + testKey)
sendCmd("/allowlist status")
assertLineEq("allowlist enabled\r")
assertLineEq(fmt.Sprintf("Keys on the allowlist without connected user: %s\r", testKeyFP))
sendCmd("/allowlist invalidSubcommand")
assertLineEq("Err: invalid subcommand: invalidSubcommand\r")
return nil
})
} }
func TestHostKick(t *testing.T) { func TestHostKick(t *testing.T) {
s, host := getHost(t, NewAuth()) key, err := sshd.NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
auth := NewAuth()
config := sshd.MakeAuth(auth)
config.AddHostKey(key)
s, err := sshd.ListenSSH(":0", config)
if err != nil {
t.Fatal(err)
}
defer s.Close() defer s.Close()
addr := s.Addr().String()
host := NewHost(s, nil)
go host.Serve() go host.Serve()
g := errgroup.Group{}
connected := make(chan struct{}) connected := make(chan struct{})
kicked := make(chan struct{}) done := make(chan struct{})
g.Go(func() error { go func() {
// First client // First client
return sshd.ConnectShell(s.Addr().String(), "foo", func(r io.Reader, w io.WriteCloser) error { err = sshd.ConnectShell(addr, "foo", func(r io.Reader, w io.WriteCloser) {
scanner := bufio.NewScanner(r)
// Consume the initial buffer
scanner.Scan() // Joined
// Make op // Make op
member, _ := host.Room.MemberByID("foo") member, _ := host.Room.MemberById("foo")
if member == nil { member.Op = true
return errors.New("failed to load MemberByID")
}
member.IsOp = true
// Change nicks, make sure op sticks
w.Write([]byte("/nick quux\r\n"))
scanner.Scan() // Prompt
scanner.Scan() // Nick change response
// Block until second client is here // Block until second client is here
connected <- struct{}{} connected <- struct{}{}
scanner.Scan() // Connected message
w.Write([]byte("/kick bar\r\n")) w.Write([]byte("/kick bar\r\n"))
scanner.Scan() // Prompt
scanner.Scan() // Kick result
if actual, expected := stripPrompt(scanner.Text()), " * bar was kicked by quux.\r"; actual != expected {
t.Errorf("Failed to detect kick:\n Got: %q;\nWant: %q", actual, expected)
}
kicked <- struct{}{}
return nil
}) })
}) if err != nil {
t.Fatal(err)
}
}()
g.Go(func() error { go func() {
// Second client // Second client
return sshd.ConnectShell(s.Addr().String(), "bar", func(r io.Reader, w io.WriteCloser) error { err = sshd.ConnectShell(addr, "bar", func(r io.Reader, w io.WriteCloser) {
scanner := bufio.NewScanner(r)
<-connected <-connected
scanner.Scan()
<-kicked // Consume while we're connected. Should break when kicked.
ioutil.ReadAll(r)
if _, err := w.Write([]byte("am I still here?\r\n")); err != io.EOF {
return errors.New("expected to be kicked")
}
scanner.Scan()
if err := scanner.Err(); err == io.EOF {
// All good, we got kicked.
return nil
} else {
return err
}
}) })
}) if err != nil {
t.Fatal(err)
if err := g.Wait(); err != nil {
t.Error(err)
}
}
func TestTimestampEnvConfig(t *testing.T) {
cases := []struct {
input string
timeformat *string
}{
{"", strptr("15:04")},
{"1", strptr("15:04")},
{"0", nil},
{"time +8h", strptr("15:04")},
{"datetime +8h", strptr("2006-01-02 15:04:05")},
}
for _, tc := range cases {
u := connectUserWithConfig(t, "dingus", map[string]string{
"SSHCHAT_TIMESTAMP": tc.input,
})
userConfig := u.Config()
if userConfig.Timeformat != nil && tc.timeformat != nil {
if *userConfig.Timeformat != *tc.timeformat {
t.Fatal("unexpected timeformat:", *userConfig.Timeformat, "expected:", *tc.timeformat)
}
} }
close(done)
}()
select {
case <-done:
case <-time.After(time.Second * 1):
t.Fatal("Timeout.")
} }
} }
func strptr(s string) *string {
return &s
}
func connectUserWithConfig(t *testing.T, name string, envConfig map[string]string) *message.User {
s, host := getHost(t, nil)
defer s.Close()
newUsers := make(chan *message.User)
host.OnUserJoined = func(u *message.User) {
newUsers <- u
}
go host.Serve()
clientConfig := sshd.NewClientConfig(name)
conn, err := ssh.Dial("tcp", s.Addr().String(), clientConfig)
if err != nil {
t.Fatal("unable to connect to test ssh-chat server:", err)
}
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatal("unable to open session:", err)
}
defer session.Close()
for key := range envConfig {
session.Setenv(key, envConfig[key])
}
err = session.Shell()
if err != nil {
t.Fatal("unable to open shell:", err)
}
for u := range newUsers {
if u.Name() == name {
return u
}
}
t.Fatalf("user %s not found in the host", name)
return nil
}

View File

@ -3,111 +3,49 @@ package sshchat
import ( import (
"fmt" "fmt"
"net" "net"
"strings"
"time"
"github.com/shazow/ssh-chat/chat" "github.com/shazow/ssh-chat/chat"
"github.com/shazow/ssh-chat/chat/message" "github.com/shazow/ssh-chat/chat/message"
"github.com/shazow/ssh-chat/internal/humantime"
"github.com/shazow/ssh-chat/internal/sanitize"
"github.com/shazow/ssh-chat/sshd" "github.com/shazow/ssh-chat/sshd"
) )
// Identity is a container for everything that identifies a client. // Identity is a container for everything that identifies a client.
type Identity struct { type Identity struct {
sshd.Connection sshd.Connection
id string id string
symbol string // symbol is displayed as a prefix to the name
created time.Time
} }
// NewIdentity returns a new identity object from an sshd.Connection. // NewIdentity returns a new identity object from an sshd.Connection.
func NewIdentity(conn sshd.Connection) *Identity { func NewIdentity(conn sshd.Connection) *Identity {
return &Identity{ return &Identity{
Connection: conn, Connection: conn,
id: sanitize.Name(conn.Name()), id: chat.SanitizeName(conn.Name()),
created: time.Now(),
} }
} }
// ID returns the name for the Identity func (i Identity) Id() string {
func (i Identity) ID() string {
return i.id return i.id
} }
// SetID Changes the Identity's name func (i *Identity) SetId(id string) {
func (i *Identity) SetID(id string) {
i.id = id i.id = id
} }
// SetName Changes the Identity's name
func (i *Identity) SetName(name string) { func (i *Identity) SetName(name string) {
i.SetID(name) i.SetId(name)
} }
func (i *Identity) SetSymbol(symbol string) {
i.symbol = symbol
}
// Name returns the name for the Identity
func (i Identity) Name() string { func (i Identity) Name() string {
if i.symbol != "" {
return i.symbol + " " + i.id
}
return i.id return i.id
} }
// Whois returns a whois description for non-admin users. func (i Identity) Whois() string {
func (i Identity) Whois(room *chat.Room) string {
fingerprint := "(no public key)"
if i.PublicKey() != nil {
fingerprint = sshd.Fingerprint(i.PublicKey())
}
// TODO: Rewrite this using strings.Builder like WhoisAdmin
awayMsg := ""
if m, ok := room.MemberByID(i.ID()); ok {
isAway, awaySince, awayMessage := m.GetAway()
if isAway {
awayMsg = fmt.Sprintf("%s > away: (%s ago) %s", message.Newline, humantime.Since(awaySince), awayMessage)
}
}
return "name: " + i.Name() + message.Newline +
" > fingerprint: " + fingerprint + message.Newline +
" > client: " + sanitize.Data(string(i.ClientVersion()), 64) + message.Newline +
" > joined: " + humantime.Since(i.created) + " ago" +
awayMsg
}
// WhoisAdmin returns a whois description for admin users.
func (i Identity) WhoisAdmin(room *chat.Room) string {
ip, _, _ := net.SplitHostPort(i.RemoteAddr().String()) ip, _, _ := net.SplitHostPort(i.RemoteAddr().String())
fingerprint := "(no public key)" fingerprint := "(no public key)"
if i.PublicKey() != nil { if i.PublicKey() != nil {
fingerprint = sshd.Fingerprint(i.PublicKey()) fingerprint = sshd.Fingerprint(i.PublicKey())
} }
return fmt.Sprintf("name: %s"+message.Newline+
out := strings.Builder{} " > ip: %s"+message.Newline+
out.WriteString("name: " + i.Name() + message.Newline + " > fingerprint: %s", i.Name(), ip, fingerprint)
" > ip: " + ip + message.Newline +
" > fingerprint: " + fingerprint + message.Newline +
" > client: " + sanitize.Data(string(i.ClientVersion()), 64) + message.Newline +
" > joined: " + humantime.Since(i.created) + " ago")
if member, ok := room.MemberByID(i.ID()); ok {
// Add room-specific whois
if isAway, awaySince, awayMessage := member.GetAway(); isAway {
fmt.Fprintf(&out, message.Newline+" > away: (%s ago) %s", humantime.Since(awaySince), awayMessage)
}
// FIXME: Should these always be present, even if they're false? Maybe
// change that once we add room context to Whois() above.
if !member.LastMsg().IsZero() {
out.WriteString(message.Newline + " > room/messaged: " + humantime.Since(member.LastMsg()) + " ago")
}
if room.IsOp(member.User) {
out.WriteString(message.Newline + " > room/op: true")
}
}
return out.String()
} }

View File

@ -1,21 +0,0 @@
package humantime
import (
"fmt"
"time"
)
// since returns a human-friendly relative time string
func Since(t time.Time) string {
d := time.Since(t)
switch {
case d < time.Minute*2:
return fmt.Sprintf("%0.f seconds", d.Seconds())
case d < time.Hour*2:
return fmt.Sprintf("%0.f minutes", d.Minutes())
case d < time.Hour*48:
return fmt.Sprintf("%0.1f hours", d.Minutes()/60)
}
days := d.Minutes() / (24 * 60)
return fmt.Sprintf("%0.1f days", days)
}

View File

@ -1,41 +0,0 @@
package humantime
import (
"testing"
"time"
)
func TestHumanSince(t *testing.T) {
tests := []struct {
input time.Duration
expected string
}{
{
time.Second * 42,
"42 seconds",
},
{
time.Second * 60 * 5,
"5 minutes",
},
{
time.Minute * 185,
"3.1 hours",
},
{
time.Hour * 49,
"2.0 days",
},
{
time.Hour * (24*900 + 12),
"900.5 days",
},
}
for _, test := range tests {
absolute := time.Now().Add(test.input * -1)
if actual, expected := Since(absolute), test.expected; actual != expected {
t.Errorf("Got: %q; Expected: %q", actual, expected)
}
}
}

View File

@ -1,29 +0,0 @@
package sanitize
import "regexp"
var (
reStripName = regexp.MustCompile("[^\\w.-]")
reStripData = regexp.MustCompile("[^[:ascii:]]|[[:cntrl:]]")
)
const maxLength = 16
// Name returns a name with only allowed characters and a reasonable length
func Name(s string) string {
s = reStripName.ReplaceAllString(s, "")
nameLength := maxLength
if len(s) <= maxLength {
nameLength = len(s)
}
s = s[:nameLength]
return s
}
// Data returns a string with only allowed characters for client-provided metadata inputs.
func Data(s string, maxlen int) string {
if len(s) > maxlen {
s = s[:maxlen]
}
return reStripData.ReplaceAllString(s, "")
}

View File

@ -1,7 +1,7 @@
package sshchat package sshchat
import ( import (
"io/ioutil" "bytes"
"github.com/alexcesaro/log" "github.com/alexcesaro/log"
"github.com/alexcesaro/log/golog" "github.com/alexcesaro/log/golog"
@ -9,12 +9,12 @@ import (
var logger *golog.Logger var logger *golog.Logger
// SetLogger sets the package logging to use l.
func SetLogger(l *golog.Logger) { func SetLogger(l *golog.Logger) {
logger = l logger = l
} }
func init() { func init() {
// Set a default null logger // Set a default null logger
SetLogger(golog.New(ioutil.Discard, log.Debug)) var b bytes.Buffer
SetLogger(golog.New(&b, log.Debug))
} }

View File

@ -1,4 +1 @@
Welcome to ssh-chat, enter /help for more. Welcome to chat.shazow.net, enter /help for more. 
🐛 Please enjoy our selection of bugs, but run your own server if you want to crash it: https://ssh.chat/issues
🍮 Sponsors get an emoji prefix: https://ssh.chat/sponsor
😌 Be nice and follow our Code of Conduct: https://ssh.chat/conduct

70
set.go Normal file
View File

@ -0,0 +1,70 @@
package sshchat
import (
"sync"
"time"
)
type expiringValue struct {
time.Time
}
func (v expiringValue) Bool() bool {
return time.Now().Before(v.Time)
}
type value struct{}
func (v value) Bool() bool {
return true
}
type setValue interface {
Bool() bool
}
// Set with expire-able keys
type Set struct {
lookup map[string]setValue
sync.Mutex
}
// NewSet creates a new set.
func NewSet() *Set {
return &Set{
lookup: map[string]setValue{},
}
}
// Len returns the size of the set right now.
func (s *Set) Len() int {
return len(s.lookup)
}
// In checks if an item exists in this set.
func (s *Set) In(key string) bool {
s.Lock()
v, ok := s.lookup[key]
if ok && !v.Bool() {
ok = false
delete(s.lookup, key)
}
s.Unlock()
return ok
}
// Add item to this set, replace if it exists.
func (s *Set) Add(key string) {
s.Lock()
s.lookup[key] = value{}
s.Unlock()
}
// Add item to this set, replace if it exists.
func (s *Set) AddExpiring(key string, d time.Duration) time.Time {
until := time.Now().Add(d)
s.Lock()
s.lookup[key] = expiringValue{until}
s.Unlock()
return until
}

View File

@ -1,59 +0,0 @@
package set
import "time"
// Interface for an item storeable in the set
type Item interface {
Key() string
Value() interface{}
}
type item struct {
key string
value interface{}
}
func (item *item) Key() string {
return item.key
}
func (item *item) Value() interface{} {
return item.value
}
func Itemize(key string, value interface{}) Item {
return &item{key, value}
}
type StringItem string
func (item StringItem) Key() string {
return string(item)
}
func (item StringItem) Value() interface{} {
return true
}
func Expire(item Item, d time.Duration) Item {
return &ExpiringItem{
Item: item,
Time: time.Now().Add(d),
}
}
type ExpiringItem struct {
Item
time.Time
}
func (item *ExpiringItem) Expired() bool {
return time.Now().After(item.Time)
}
func (item *ExpiringItem) Value() interface{} {
if item.Expired() {
return nil
}
return item.Item.Value()
}

View File

@ -1,215 +0,0 @@
package set
import (
"errors"
"strings"
"sync"
)
// Returned when an added key already exists in the set.
var ErrCollision = errors.New("key already exists")
// Returned when a requested item does not exist in the set.
var ErrMissing = errors.New("item does not exist")
// ZeroValue can be used when we only care about the key, not about the value.
var ZeroValue = struct{}{}
// Interface is the Set interface
type Interface interface {
Clear() int
Each(fn IterFunc) error
// Add only if the item does not already exist
Add(item Item) error
// Set item, override if it already exists
Set(item Item) error
Get(key string) (Item, error)
In(key string) bool
Len() int
ListPrefix(prefix string) []Item
Remove(key string) error
Replace(oldKey string, item Item) error
}
type IterFunc func(key string, item Item) error
type Set struct {
sync.RWMutex
lookup map[string]Item
normalize func(string) string
}
// New creates a new set with case-insensitive keys
func New() *Set {
return &Set{
lookup: map[string]Item{},
normalize: normalize,
}
}
// Clear removes all items and returns the number removed.
func (s *Set) Clear() int {
s.Lock()
n := len(s.lookup)
s.lookup = map[string]Item{}
s.Unlock()
return n
}
// Len returns the size of the set right now.
func (s *Set) Len() int {
s.RLock()
defer s.RUnlock()
return len(s.lookup)
}
// In checks if an item exists in this set.
func (s *Set) In(key string) bool {
key = s.normalize(key)
s.RLock()
item, ok := s.lookup[key]
s.RUnlock()
if ok && item.Value() == nil {
s.cleanup(key)
ok = false
}
return ok
}
// Get returns an item with the given key.
func (s *Set) Get(key string) (Item, error) {
key = s.normalize(key)
s.RLock()
item, ok := s.lookup[key]
s.RUnlock()
if !ok {
return nil, ErrMissing
}
if item.Value() == nil {
s.cleanup(key)
}
return item, nil
}
// Remove potentially expired key (normalized).
func (s *Set) cleanup(key string) {
s.Lock()
item, ok := s.lookup[key]
if ok && item.Value() == nil {
delete(s.lookup, key)
}
s.Unlock()
}
// Add item to this set if it does not exist already.
func (s *Set) Add(item Item) error {
key := s.normalize(item.Key())
s.Lock()
defer s.Unlock()
oldItem, found := s.lookup[key]
if found && oldItem.Value() != nil {
return ErrCollision
}
s.lookup[key] = item
return nil
}
// Set item to this set, even if it already exists.
func (s *Set) Set(item Item) error {
key := s.normalize(item.Key())
s.Lock()
defer s.Unlock()
s.lookup[key] = item
return nil
}
// Remove item from this set.
func (s *Set) Remove(key string) error {
key = s.normalize(key)
s.Lock()
defer s.Unlock()
_, found := s.lookup[key]
if !found {
return ErrMissing
}
delete(s.lookup, key)
return nil
}
// Replace oldKey with a new item, which might be a new key.
// Can be used to rename items.
func (s *Set) Replace(oldKey string, item Item) error {
newKey := s.normalize(item.Key())
oldKey = s.normalize(oldKey)
s.Lock()
defer s.Unlock()
if newKey != oldKey {
// Check if it already exists
_, found := s.lookup[newKey]
if found {
return ErrCollision
}
// Remove oldKey
_, found = s.lookup[oldKey]
if !found {
return ErrMissing
}
delete(s.lookup, oldKey)
}
// Add new item
s.lookup[newKey] = item
return nil
}
// Each loops over every item while holding a read lock and applies fn to each
// element.
func (s *Set) Each(fn IterFunc) error {
var err error
s.RLock()
for key, item := range s.lookup {
if item.Value() == nil {
// Expired
defer s.cleanup(key)
continue
}
if err = fn(key, item); err != nil {
// Abort early
break
}
}
s.RUnlock()
return err
}
// ListPrefix returns a list of items with a prefix, normalized.
// TODO: Add trie for efficient prefix lookup
func (s *Set) ListPrefix(prefix string) []Item {
r := []Item{}
prefix = s.normalize(prefix)
s.Each(func(key string, item Item) error {
if strings.HasPrefix(key, prefix) {
r = append(r, item)
}
return nil
})
return r
}
func normalize(key string) string {
return strings.ToLower(key)
}

View File

@ -1,112 +0,0 @@
package set
import (
"testing"
"time"
)
func TestSetExpiring(t *testing.T) {
s := New()
if s.In("foo") {
t.Error("matched before set.")
}
if err := s.Add(StringItem("foo")); err != nil {
t.Fatalf("failed to add foo: %s", err)
}
if !s.In("foo") {
t.Errorf("not matched after set")
}
if s.Len() != 1 {
t.Error("not len 1 after set")
}
item := Expire(StringItem("asdf"), -time.Nanosecond).(*ExpiringItem)
if !item.Expired() {
t.Errorf("ExpiringItem a nanosec ago is not expiring")
}
if err := s.Add(item); err != nil {
t.Error("Error adding expired item to set: ", err)
}
if s.In("asdf") {
t.Error("Expired item in set")
}
if s.Len() != 1 {
t.Error("not len 1 after expired item")
}
item = &ExpiringItem{nil, time.Now().Add(time.Minute * 5)}
if item.Expired() {
t.Errorf("ExpiringItem in 5 minutes is expiring now")
}
item = Expire(StringItem("bar"), time.Minute*5).(*ExpiringItem)
until := item.Time
if !until.After(time.Now().Add(time.Minute*4)) || !until.Before(time.Now().Add(time.Minute*6)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
}
if item.Value() == nil {
t.Errorf("bar expired immediately")
}
if err := s.Add(item); err != nil {
t.Fatalf("failed to add item: %s", err)
}
itemInLookup, ok := s.lookup["bar"]
if !ok {
t.Fatalf("bar not present in lookup even though it's not expired")
}
if itemInLookup != item {
t.Fatalf("present item %#v != %#v original item", itemInLookup, item)
}
if !s.In("bar") {
t.Errorf("not matched after timed set")
}
if s.Len() != 2 {
t.Error("not len 2 after set")
}
if err := s.Replace("bar", Expire(StringItem("quux"), time.Minute*5)); err != nil {
t.Fatalf("failed to add quux: %s", err)
}
if err := s.Replace("quux", Expire(StringItem("bar"), time.Minute*5)); err != nil {
t.Fatalf("failed to add bar: %s", err)
}
if s.In("quux") {
t.Error("quux in set after replace")
}
if _, err := s.Get("bar"); err != nil {
t.Errorf("failed to get before expiry: %s", err)
}
if err := s.Add(StringItem("barbar")); err != nil {
t.Fatalf("failed to add barbar")
}
if _, err := s.Get("barbar"); err != nil {
t.Errorf("failed to get barbar: %s", err)
}
b := s.ListPrefix("b")
if len(b) != 2 || !anyItemPresentWithKey(b, "bar") || !anyItemPresentWithKey(b, "barbar") {
t.Errorf("b-prefix incorrect: %q", b)
}
if err := s.Remove("bar"); err != nil {
t.Fatalf("failed to remove: %s", err)
}
if s.Len() != 2 {
t.Error("not len 2 after remove")
}
s.Clear()
if s.Len() != 0 {
t.Error("not len 0 after clear")
}
}
func anyItemPresentWithKey(items []Item, key string) bool {
for _, item := range items {
if item.Key() == key {
return true
}
}
return false
}

58
set_test.go Normal file
View File

@ -0,0 +1,58 @@
package sshchat
import (
"testing"
"time"
)
func TestSetExpiring(t *testing.T) {
s := NewSet()
if s.In("foo") {
t.Error("Matched before set.")
}
s.Add("foo")
if !s.In("foo") {
t.Errorf("Not matched after set")
}
if s.Len() != 1 {
t.Error("Not len 1 after set")
}
v := expiringValue{time.Now().Add(-time.Nanosecond * 1)}
if v.Bool() {
t.Errorf("expiringValue now is not expiring")
}
v = expiringValue{time.Now().Add(time.Minute * 2)}
if !v.Bool() {
t.Errorf("expiringValue in 2 minutes is expiring now")
}
until := s.AddExpiring("bar", time.Minute*2)
if !until.After(time.Now().Add(time.Minute*1)) || !until.Before(time.Now().Add(time.Minute*3)) {
t.Errorf("until is not a minute after %s: %s", time.Now(), until)
}
val, ok := s.lookup["bar"]
if !ok {
t.Errorf("bar not in lookup")
}
if !until.Equal(val.(expiringValue).Time) {
t.Errorf("bar's until is not equal to the expected value")
}
if !val.Bool() {
t.Errorf("bar expired immediately")
}
if !s.In("bar") {
t.Errorf("Not matched after timed set")
}
if s.Len() != 2 {
t.Error("Not len 2 after set")
}
s.AddExpiring("bar", time.Nanosecond*1)
if s.In("bar") {
t.Error("Matched after expired timer")
}
}

View File

@ -5,41 +5,26 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"net" "net"
"time"
"github.com/shazow/ssh-chat/internal/sanitize"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// Auth is used to authenticate connections. // Auth is used to authenticate connections based on public keys.
type Auth interface { type Auth interface {
// Whether to allow connections without a public key. // Whether to allow connections without a public key.
AllowAnonymous() bool AllowAnonymous() bool
// If passphrase authentication is accepted // Given address and public key, return if the connection should be permitted.
AcceptPassphrase() bool Check(net.Addr, ssh.PublicKey) (bool, error)
// Given address and public key and client agent string, returns nil if the connection is not banned.
CheckBans(net.Addr, ssh.PublicKey, string) error
// Given a public key, returns nil if the connection should be allowed.
CheckPublicKey(ssh.PublicKey) error
// Given a passphrase, returns nil if the connection should be allowed.
CheckPassphrase(string) error
// BanAddr bans an IP address for the specified amount of time.
BanAddr(net.Addr, time.Duration)
} }
// MakeAuth makes an ssh.ServerConfig which performs authentication against an Auth implementation. // MakeAuth makes an ssh.ServerConfig which performs authentication against an Auth implementation.
// TODO: Switch to using ssh.AuthMethod instead?
func MakeAuth(auth Auth) *ssh.ServerConfig { func MakeAuth(auth Auth) *ssh.ServerConfig {
config := ssh.ServerConfig{ config := ssh.ServerConfig{
NoClientAuth: false, NoClientAuth: false,
// Auth-related things should be constant-time to avoid timing attacks. // Auth-related things should be constant-time to avoid timing attacks.
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
err := auth.CheckBans(conn.RemoteAddr(), key, sanitize.Data(string(conn.ClientVersion()), 64)) ok, err := auth.Check(conn.RemoteAddr(), key)
if err != nil { if !ok {
return nil, err
}
err = auth.CheckPublicKey(key)
if err != nil {
return nil, err return nil, err
} }
perm := &ssh.Permissions{Extensions: map[string]string{ perm := &ssh.Permissions{Extensions: map[string]string{
@ -47,31 +32,11 @@ func MakeAuth(auth Auth) *ssh.ServerConfig {
}} }}
return perm, nil return perm, nil
}, },
// We use KeyboardInteractiveCallback instead of PasswordCallback to
// avoid preventing the client from including a pubkey in the user
// identification.
KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) { KeyboardInteractiveCallback: func(conn ssh.ConnMetadata, challenge ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
err := auth.CheckBans(conn.RemoteAddr(), nil, sanitize.Data(string(conn.ClientVersion()), 64)) if !auth.AllowAnonymous() {
if err != nil { return nil, errors.New("public key authentication required")
return nil, err
}
if auth.AcceptPassphrase() {
var answers []string
answers, err = challenge("", "", []string{"Passphrase required to connect: "}, []bool{true})
if err == nil {
if len(answers) != 1 {
err = errors.New("didn't get passphrase")
} else {
err = auth.CheckPassphrase(answers[0])
if err != nil {
auth.BanAddr(conn.RemoteAddr(), time.Second*2)
}
}
}
} else if !auth.AllowAnonymous() {
err = errors.New("public key authentication required")
} }
_, err := auth.Check(conn.RemoteAddr(), nil)
return nil, err return nil, err
}, },
} }
@ -103,5 +68,5 @@ func MakeNoAuth() *ssh.ServerConfig {
// See: https://anongit.mindrot.org/openssh.git/commit/?id=56d1c83cdd1ac // See: https://anongit.mindrot.org/openssh.git/commit/?id=56d1c83cdd1ac
func Fingerprint(k ssh.PublicKey) string { func Fingerprint(k ssh.PublicKey) string {
hash := sha256.Sum256(k.Marshal()) hash := sha256.Sum256(k.Marshal())
return "SHA256:" + base64.StdEncoding.EncodeToString(hash[:]) return base64.StdEncoding.EncodeToString(hash[:])
} }

View File

@ -26,28 +26,12 @@ func NewClientConfig(name string) *ssh.ClientConfig {
return return
}), }),
}, },
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
}
func NewClientConfigWithKey(name string, key ssh.Signer) *ssh.ClientConfig {
return &ssh.ClientConfig{
User: name,
Auth: []ssh.AuthMethod{ssh.PublicKeys(key)},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
} }
} }
// ConnectShell makes a barebones SSH client session, used for testing. // ConnectShell makes a barebones SSH client session, used for testing.
func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser) error) error { func ConnectShell(host string, name string, handler func(r io.Reader, w io.WriteCloser)) error {
return connectShell(host, NewClientConfig(name), handler) config := NewClientConfig(name)
}
func ConnectShellWithKey(host string, name string, key ssh.Signer, handler func(r io.Reader, w io.WriteCloser) error) error {
return connectShell(host, NewClientConfigWithKey(name, key), handler)
}
func connectShell(host string, config *ssh.ClientConfig, handler func(r io.Reader, w io.WriteCloser) error) error {
conn, err := ssh.Dial("tcp", host, config) conn, err := ssh.Dial("tcp", host, config)
if err != nil { if err != nil {
return err return err
@ -70,11 +54,11 @@ func connectShell(host string, config *ssh.ClientConfig, handler func(r io.Reade
return err return err
} }
/* /* FIXME: Do we want to request a PTY?
err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}) err = session.RequestPty("xterm", 80, 40, ssh.TerminalModes{})
if err != nil { if err != nil {
return err return err
} }
*/ */
err = session.Shell() err = session.Shell()
@ -82,10 +66,7 @@ func connectShell(host string, config *ssh.ClientConfig, handler func(r io.Reade
return err return err
} }
_, err = session.SendRequest("ping", true, nil) handler(out, in)
if err != nil {
return err
}
return handler(out, in) return nil
} }

View File

@ -4,7 +4,6 @@ import (
"errors" "errors"
"net" "net"
"testing" "testing"
"time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -16,39 +15,32 @@ type RejectAuth struct{}
func (a RejectAuth) AllowAnonymous() bool { func (a RejectAuth) AllowAnonymous() bool {
return false return false
} }
func (a RejectAuth) AcceptPassphrase() bool { func (a RejectAuth) Check(net.Addr, ssh.PublicKey) (bool, error) {
return false return false, errRejectAuth
} }
func (a RejectAuth) CheckBans(addr net.Addr, key ssh.PublicKey, clientVersion string) error {
return errRejectAuth func consume(ch <-chan *Terminal) {
for _ = range ch {
}
} }
func (a RejectAuth) CheckPublicKey(ssh.PublicKey) error {
return errRejectAuth
}
func (a RejectAuth) CheckPassphrase(string) error {
return errRejectAuth
}
func (a RejectAuth) BanAddr(net.Addr, time.Duration) {}
func TestClientReject(t *testing.T) { func TestClientReject(t *testing.T) {
signer, err := NewRandomSigner(512) signer, err := NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
config := MakeAuth(RejectAuth{}) config := MakeAuth(RejectAuth{})
config.AddHostKey(signer) config.AddHostKey(signer)
s, err := ListenSSH("localhost:0", config) s, err := ListenSSH(":0", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer s.Close() defer s.Close()
go s.Serve() go consume(s.ServeTerminal())
conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo")) conn, err := ssh.Dial("tcp", s.Addr().String(), NewClientConfig("foo"))
if err == nil { if err == nil {
defer conn.Close() defer conn.Close()
t.Error("Failed to reject conncetion") t.Error("Failed to reject conncetion")
} }
t.Log(err)
} }

View File

@ -5,7 +5,6 @@ import stdlog "log"
var logger *stdlog.Logger var logger *stdlog.Logger
// SetLogger sets the package logging output to use w.
func SetLogger(w io.Writer) { func SetLogger(w io.Writer) {
flags := stdlog.Flags() flags := stdlog.Flags()
prefix := "[sshd] " prefix := "[sshd] "

View File

@ -2,22 +2,19 @@ package sshd
import ( import (
"net" "net"
"time"
"github.com/shazow/rateio" "github.com/shazow/rateio"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// SSHListener is the container for the connection and ssh-related configuration // Container for the connection and ssh-related configuration
type SSHListener struct { type SSHListener struct {
net.Listener net.Listener
config *ssh.ServerConfig config *ssh.ServerConfig
RateLimit func() rateio.Limiter
RateLimit func() rateio.Limiter
HandlerFunc func(term *Terminal)
} }
// ListenSSH makes an SSH listener socket // Make an SSH listener socket
func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) { func ListenSSH(laddr string, config *ssh.ServerConfig) (*SSHListener, error) {
socket, err := net.Listen("tcp", laddr) socket, err := net.Listen("tcp", laddr)
if err != nil { if err != nil {
@ -33,12 +30,6 @@ func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) {
conn = ReadLimitConn(conn, l.RateLimit()) conn = ReadLimitConn(conn, l.RateLimit())
} }
// If the connection doesn't write anything back for too long before we get
// a valid session, it should be dropped.
var handleTimeout = 20 * time.Second
conn.SetReadDeadline(time.Now().Add(handleTimeout))
defer conn.SetReadDeadline(time.Time{})
// Upgrade TCP connection to SSH connection // Upgrade TCP connection to SSH connection
sshConn, channels, requests, err := ssh.NewServerConn(conn, l.config) sshConn, channels, requests, err := ssh.NewServerConn(conn, l.config)
if err != nil { if err != nil {
@ -50,26 +41,33 @@ func (l *SSHListener) handleConn(conn net.Conn) (*Terminal, error) {
return NewSession(sshConn, channels) return NewSession(sshConn, channels)
} }
// Serve Accepts incoming connections as terminal requests and yield them // Accept incoming connections as terminal requests and yield them
func (l *SSHListener) Serve() { func (l *SSHListener) ServeTerminal() <-chan *Terminal {
defer l.Close() ch := make(chan *Terminal)
for {
conn, err := l.Accept()
if err != nil { go func() {
logger.Printf("Failed to accept connection: %s", err) defer l.Close()
break defer close(ch)
}
for {
conn, err := l.Accept()
// Goroutineify to resume accepting sockets early
go func() {
term, err := l.handleConn(conn)
if err != nil { if err != nil {
logger.Printf("[%s] Failed to handshake: %s", conn.RemoteAddr(), err) logger.Printf("Failed to accept connection: %v", err)
conn.Close() // Must be closed to avoid a leak
return return
} }
l.HandlerFunc(term)
}() // Goroutineify to resume accepting sockets early
} go func() {
term, err := l.handleConn(conn)
if err != nil {
logger.Printf("Failed to handshake: %v", err)
return
}
ch <- term
}()
}
}()
return ch
} }

View File

@ -8,12 +8,12 @@ import (
func TestServerInit(t *testing.T) { func TestServerInit(t *testing.T) {
config := MakeNoAuth() config := MakeNoAuth()
s, err := ListenSSH("localhost:badport", config) s, err := ListenSSH(":badport", config)
if err == nil { if err == nil {
t.Fatal("should fail on bad port") t.Fatal("should fail on bad port")
} }
s, err = ListenSSH("localhost:0", config) s, err = ListenSSH(":0", config)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -25,23 +25,16 @@ func TestServerInit(t *testing.T) {
} }
func TestServeTerminals(t *testing.T) { func TestServeTerminals(t *testing.T) {
signer, err := NewRandomSigner(1024) signer, err := NewRandomSigner(512)
if err != nil {
t.Fatal(err)
}
config := MakeNoAuth() config := MakeNoAuth()
config.AddHostKey(signer) config.AddHostKey(signer)
s, err := ListenSSH("localhost:0", config) s, err := ListenSSH(":0", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
terminals := make(chan *Terminal) terminals := s.ServeTerminal()
s.HandlerFunc = func(term *Terminal) {
terminals <- term
}
go s.Serve()
go func() { go func() {
// Accept one terminal, read from it, echo back, close. // Accept one terminal, read from it, echo back, close.
@ -52,7 +45,7 @@ func TestServeTerminals(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
_, err = term.Write([]byte("echo: " + line + "\n")) _, err = term.Write([]byte("echo: " + line + "\r\n"))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -63,24 +56,23 @@ func TestServeTerminals(t *testing.T) {
host := s.Addr().String() host := s.Addr().String()
name := "foo" name := "foo"
err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) error { err = ConnectShell(host, name, func(r io.Reader, w io.WriteCloser) {
// Consume if there is anything // Consume if there is anything
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
w.Write([]byte("hello\r\n")) w.Write([]byte("hello\r\n"))
buf.Reset() buf.Reset()
_, err := io.Copy(buf, r) _, err := io.Copy(buf, r)
if err != nil {
t.Error(err)
}
expected := "> hello\r\necho: hello\r\n" expected := "> hello\r\necho: hello\r\n"
actual := buf.String() actual := buf.String()
if actual != expected { if actual != expected {
if err != nil {
t.Error(err)
}
t.Errorf("Got %q; expected %q", actual, expected) t.Errorf("Got %q; expected %q", actual, expected)
} }
s.Close() s.Close()
return nil
}) })
if err != nil { if err != nil {

View File

@ -6,8 +6,8 @@ import "encoding/binary"
// parsePtyRequest parses the payload of the pty-req message and extracts the // parsePtyRequest parses the payload of the pty-req message and extracts the
// dimensions of the terminal. See RFC 4254, section 6.2. // dimensions of the terminal. See RFC 4254, section 6.2.
func parsePtyRequest(s []byte) (term string, width, height int, ok bool) { func parsePtyRequest(s []byte) (width, height int, ok bool) {
term, s, ok = parseString(s) _, s, ok = parseString(s)
if !ok { if !ok {
return return
} }
@ -28,11 +28,11 @@ func parsePtyRequest(s []byte) (term string, width, height int, ok bool) {
} }
func parseWinchRequest(s []byte) (width, height int, ok bool) { func parseWinchRequest(s []byte) (width, height int, ok bool) {
width32, _, ok := parseUint32(s) width32, s, ok := parseUint32(s)
if !ok { if !ok {
return return
} }
height32, _, ok := parseUint32(s) height32, s, ok := parseUint32(s)
if !ok { if !ok {
return return
} }

View File

@ -44,8 +44,8 @@ type inputLimiter struct {
func NewInputLimiter() rateio.Limiter { func NewInputLimiter() rateio.Limiter {
grace := time.Second * 3 grace := time.Second * 3
return &inputLimiter{ return &inputLimiter{
Amount: 2 << 14, // ~16kb, should be plenty for a high typing rate/copypasta/large key handshakes. Amount: 200 * 4 * 2, // Assume fairly high typing rate + margin for copypasta of links.
Frequency: time.Minute * 1, Frequency: time.Minute * 2,
readCap: 128, // Allow up to 128 bytes per read (anecdotally, 1 character = 52 bytes over ssh) readCap: 128, // Allow up to 128 bytes per read (anecdotally, 1 character = 52 bytes over ssh)
numRead: -1024 * 1024, // Start with a 1mb grace numRead: -1024 * 1024, // Start with a 1mb grace
timeRead: time.Now().Add(grace), timeRead: time.Now().Add(grace),

View File

@ -4,28 +4,16 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync"
"time"
"github.com/shazow/ssh-chat/sshd/terminal"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
) )
var keepaliveInterval = time.Second * 30
var keepaliveRequest = "keepalive@ssh-chat"
// ErrNoSessionChannel is returned when there is no session channel.
var ErrNoSessionChannel = errors.New("no session channel")
// ErrNotSessionChannel is returned when a channel is not a session channel.
var ErrNotSessionChannel = errors.New("terminal requires session channel")
// Connection is an interface with fields necessary to operate an sshd host. // Connection is an interface with fields necessary to operate an sshd host.
type Connection interface { type Connection interface {
PublicKey() ssh.PublicKey PublicKey() ssh.PublicKey
RemoteAddr() net.Addr RemoteAddr() net.Addr
Name() string Name() string
ClientVersion() []byte
Close() error Close() error
} }
@ -55,129 +43,72 @@ func (c sshConn) Name() string {
return c.User() return c.User()
} }
// EnvVar is an environment variable key-value pair // Extending ssh/terminal to include a closer interface
type EnvVar struct {
Key string
Value string
}
func (v EnvVar) String() string {
return v.Key + "=" + v.Value
}
// Env is a wrapper type around []EnvVar with some helper methods
type Env []EnvVar
// Get returns the latest value for a given key, or empty string if not found
func (e Env) Get(key string) string {
for i := len(e) - 1; i >= 0; i-- {
if e[i].Key == key {
return e[i].Value
}
}
return ""
}
// Terminal extends ssh/terminal to include a close method
type Terminal struct { type Terminal struct {
terminal.Terminal terminal.Terminal
Conn Connection Conn Connection
Channel ssh.Channel Channel ssh.Channel
done chan struct{}
closeOnce sync.Once
mu sync.Mutex
env []EnvVar
term string
} }
// Make new terminal from a session channel // Make new terminal from a session channel
// TODO: For v2, make a separate `Serve(ctx context.Context) error` method to activate the Terminal
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) { func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
if ch.ChannelType() != "session" { if ch.ChannelType() != "session" {
return nil, ErrNotSessionChannel return nil, errors.New("terminal requires session channel")
} }
channel, requests, err := ch.Accept() channel, requests, err := ch.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
term := Terminal{ term := Terminal{
Terminal: *terminal.NewTerminal(channel, ""), *terminal.NewTerminal(channel, "Connecting..."),
Conn: sshConn{conn}, sshConn{conn},
Channel: channel, channel,
done: make(chan struct{}),
} }
ready := make(chan struct{}) go term.listen(requests)
go term.listen(requests, ready)
go func() { go func() {
// Keep-Alive Ticker // FIXME: Is this necessary?
ticker := time.Tick(keepaliveInterval) conn.Wait()
for { channel.Close()
select {
case <-ticker:
_, err := channel.SendRequest(keepaliveRequest, true, nil)
if err != nil {
// Connection is gone
logger.Printf("[%s] Keepalive failed, closing terminal: %s", term.Conn.RemoteAddr(), err)
term.Close()
return
}
case <-term.done:
return
}
}
}() }()
// We need to wait for term.ready to acquire a shell before we return, this return &term, nil
// gives the SSH session a chance to populate the env vars and other state.
// TODO: Make the timeout configurable
// TODO: Use context.Context for abort/timeout in the future, will need to change the API.
select {
case <-ready: // shell acquired
return &term, nil
case <-term.done:
return nil, errors.New("terminal aborted")
case <-time.NewTimer(time.Minute).C:
return nil, errors.New("timed out starting terminal")
}
} }
// NewSession Finds a session channel and make a Terminal from it // Find session channel and make a Terminal from it
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (*Terminal, error) { func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (term *Terminal, err error) {
// Make a terminal from the first session found
for ch := range channels { for ch := range channels {
if t := ch.ChannelType(); t != "session" { if t := ch.ChannelType(); t != "session" {
logger.Printf("[%s] Ignored channel type: %s", conn.RemoteAddr(), t)
ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) ch.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t))
continue continue
} }
return NewTerminal(conn, ch) term, err = NewTerminal(conn, ch)
if err == nil {
break
}
} }
return nil, ErrNoSessionChannel if term != nil {
// Reject the rest.
// FIXME: Do we need this?
go func() {
for ch := range channels {
ch.Reject(ssh.Prohibited, "only one session allowed")
}
}()
}
return term, err
} }
// Close terminal and ssh connection // Close terminal and ssh connection
func (t *Terminal) Close() error { func (t *Terminal) Close() error {
var err error return t.Conn.Close()
t.closeOnce.Do(func() {
close(t.done)
if err := t.Channel.Close(); err != nil {
logger.Printf("[%s] Failed to close terminal channel: %s", t.Conn.RemoteAddr(), err)
}
err = t.Conn.Close()
})
return err
} }
// listen negotiates the terminal type and state // Negotiate terminal type and settings
// ready is closed when the terminal is ready. func (t *Terminal) listen(requests <-chan *ssh.Request) {
func (t *Terminal) listen(requests <-chan *ssh.Request, ready chan<- struct{}) {
hasShell := false hasShell := false
for req := range requests { for req := range requests {
@ -189,19 +120,13 @@ func (t *Terminal) listen(requests <-chan *ssh.Request, ready chan<- struct{}) {
if !hasShell { if !hasShell {
ok = true ok = true
hasShell = true hasShell = true
close(ready)
} }
case "pty-req": case "pty-req":
var term string width, height, ok = parsePtyRequest(req.Payload)
term, width, height, ok = parsePtyRequest(req.Payload)
if ok { if ok {
// TODO: Hardcode width to 100000? // TODO: Hardcode width to 100000?
err := t.SetSize(width, height) err := t.SetSize(width, height)
ok = err == nil ok = err == nil
// Save the term:
t.mu.Lock()
t.term = term
t.mu.Unlock()
} }
case "window-change": case "window-change":
width, height, ok = parseWinchRequest(req.Payload) width, height, ok = parseWinchRequest(req.Payload)
@ -210,14 +135,6 @@ func (t *Terminal) listen(requests <-chan *ssh.Request, ready chan<- struct{}) {
err := t.SetSize(width, height) err := t.SetSize(width, height)
ok = err == nil ok = err == nil
} }
case "env":
var v EnvVar
if err := ssh.Unmarshal(req.Payload, &v); err == nil {
t.mu.Lock()
t.env = append(t.env, v)
t.mu.Unlock()
ok = true
}
} }
if req.WantReply { if req.WantReply {
@ -225,24 +142,3 @@ func (t *Terminal) listen(requests <-chan *ssh.Request, ready chan<- struct{}) {
} }
} }
} }
// Env returns a list of environment key-values that have been set. They are
// returned in the order that they have been set, there is no deduplication or
// other pre-processing applied.
func (t *Terminal) Env() Env {
t.mu.Lock()
defer t.mu.Unlock()
return Env(t.env)
}
// Term returns the terminal string value as set by the pty.
// If there was no pty request, it falls back to the TERM value passed in as an
// Env variable.
func (t *Terminal) Term() string {
t.mu.Lock()
defer t.mu.Unlock()
if t.term != "" {
return t.term
}
return Env(t.env).Get("TERM")
}

View File

@ -1,27 +0,0 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

File diff suppressed because it is too large Load Diff

View File

@ -1,435 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build aix darwin dragonfly freebsd linux,!appengine netbsd openbsd windows plan9 solaris
package terminal
import (
"bytes"
"io"
"os"
"runtime"
"testing"
"unicode/utf8"
)
type MockTerminal struct {
toSend []byte
bytesPerRead int
received []byte
}
func (c *MockTerminal) Read(data []byte) (n int, err error) {
n = len(data)
if n == 0 {
return
}
if n > len(c.toSend) {
n = len(c.toSend)
}
if n == 0 {
return 0, io.EOF
}
if c.bytesPerRead > 0 && n > c.bytesPerRead {
n = c.bytesPerRead
}
copy(data, c.toSend[:n])
c.toSend = c.toSend[n:]
return
}
func (c *MockTerminal) Write(data []byte) (n int, err error) {
c.received = append(c.received, data...)
return len(data), nil
}
func TestClose(t *testing.T) {
c := &MockTerminal{}
ss := NewTerminal(c, "> ")
line, err := ss.ReadLine()
if line != "" {
t.Errorf("Expected empty line but got: %s", line)
}
if err != io.EOF {
t.Errorf("Error should have been EOF but got: %s", err)
}
}
var keyPressTests = []struct {
in string
line string
received string
err error
throwAwayLines int
}{
{
err: io.EOF,
},
{
in: "\r",
line: "",
},
{
in: "foo\r",
line: "foo",
},
{
in: "a\x1b[Cb\r", // right
line: "ab",
},
{
in: "a\x1b[Db\r", // left
line: "ba",
},
{
in: "a\177b\r", // backspace
line: "b",
},
{
in: "\x1b[A\r", // up
},
{
in: "\x1b[B\r", // down
},
{
in: "\016\r", // ^P
},
{
in: "\014\r", // ^N
},
{
in: "line\x1b[A\x1b[B\r", // up then down
line: "line",
},
{
in: "line1\rline2\x1b[A\r", // recall previous line.
line: "line1",
throwAwayLines: 1,
},
{
// recall two previous lines and append.
in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r",
line: "line1xxx",
throwAwayLines: 2,
},
{
// Ctrl-A to move to beginning of line followed by ^K to kill
// line.
in: "a b \001\013\r",
line: "",
},
{
// Ctrl-A to move to beginning of line, Ctrl-E to move to end,
// finally ^K to kill nothing.
in: "a b \001\005\013\r",
line: "a b ",
},
{
in: "\027\r",
line: "",
},
{
in: "a\027\r",
line: "",
},
{
in: "a \027\r",
line: "",
},
{
in: "a b\027\r",
line: "a ",
},
{
in: "a b \027\r",
line: "a ",
},
{
in: "one two thr\x1b[D\027\r",
line: "one two r",
},
{
in: "\013\r",
line: "",
},
{
in: "a\013\r",
line: "a",
},
{
in: "ab\x1b[D\013\r",
line: "a",
},
{
in: "Ξεσκεπάζω\r",
line: "Ξεσκεπάζω",
},
{
in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace.
line: "",
throwAwayLines: 1,
},
{
in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter.
line: "£",
throwAwayLines: 1,
},
{
// Ctrl-D at the end of the line should be ignored.
in: "a\004\r",
line: "a",
},
{
// a, b, left, Ctrl-D should erase the b.
in: "ab\x1b[D\004\r",
line: "a",
},
{
// a, b, c, d, left, left, ^U should erase to the beginning of
// the line.
in: "abcd\x1b[D\x1b[D\025\r",
line: "cd",
},
{
// Bracketed paste mode: control sequences should be returned
// verbatim in paste mode.
in: "abc\x1b[200~de\177f\x1b[201~\177\r",
line: "abcde\177",
},
{
// Enter in bracketed paste mode should still work.
in: "abc\x1b[200~d\refg\x1b[201~h\r",
line: "efgh",
throwAwayLines: 1,
},
{
// Lines consisting entirely of pasted data should be indicated as such.
in: "\x1b[200~a\r",
line: "a",
err: ErrPasteIndicator,
},
}
func TestKeyPresses(t *testing.T) {
for i, test := range keyPressTests {
for j := 1; j < len(test.in); j++ {
c := &MockTerminal{
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewTerminal(c, "> ")
for k := 0; k < test.throwAwayLines; k++ {
_, err := ss.ReadLine()
if err != nil {
t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err)
}
}
line, err := ss.ReadLine()
if line != test.line {
t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
break
}
if err != test.err {
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
break
}
}
}
}
var renderTests = []struct {
in string
received string
err error
}{
{
// Cursor move after keyHome (left 4) then enter (right 4, newline)
in: "abcd\x1b[H\r",
received: "> abcd\x1b[4D\x1b[4C\r\n",
},
{
// Write, home, prepend, enter. Prepends rewrites the line.
in: "cdef\x1b[Hab\r",
received: "> cdef" + // Initial input
"\x1b[4Da" + // Move cursor back, insert first char
"cdef" + // Copy over original string
"\x1b[4Dbcdef" + // Repeat for second char with copy
"\x1b[4D" + // Put cursor back in position to insert again
"\x1b[4C\r\n", // Put cursor at the end of the line and newline.
},
}
func TestRender(t *testing.T) {
for i, test := range renderTests {
for j := 1; j < len(test.in); j++ {
c := &MockTerminal{
toSend: []byte(test.in),
bytesPerRead: j,
}
ss := NewTerminal(c, "> ")
_, err := ss.ReadLine()
if err != test.err {
t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
break
}
if test.received != string(c.received) {
t.Errorf("Results rendered from test %d (%d bytes per read) was '%s', expected '%s'", i, j, c.received, test.received)
break
}
}
}
}
func TestPasswordNotSaved(t *testing.T) {
c := &MockTerminal{
toSend: []byte("password\r\x1b[A\r"),
bytesPerRead: 1,
}
ss := NewTerminal(c, "> ")
pw, _ := ss.ReadPassword("> ")
if pw != "password" {
t.Fatalf("failed to read password, got %s", pw)
}
line, _ := ss.ReadLine()
if len(line) > 0 {
t.Fatalf("password was saved in history")
}
}
var setSizeTests = []struct {
width, height int
}{
{40, 13},
{80, 24},
{132, 43},
}
func TestTerminalSetSize(t *testing.T) {
for _, setSize := range setSizeTests {
c := &MockTerminal{
toSend: []byte("password\r\x1b[A\r"),
bytesPerRead: 1,
}
ss := NewTerminal(c, "> ")
ss.SetSize(setSize.width, setSize.height)
pw, _ := ss.ReadPassword("Password: ")
if pw != "password" {
t.Fatalf("failed to read password, got %s", pw)
}
if string(c.received) != "Password: \r\n" {
t.Errorf("failed to set the temporary prompt expected %q, got %q", "Password: ", c.received)
}
}
}
func TestReadPasswordLineEnd(t *testing.T) {
var tests = []struct {
input string
want string
}{
{"\n", ""},
{"\r\n", ""},
{"test\r\n", "test"},
{"testtesttesttes\n", "testtesttesttes"},
{"testtesttesttes\r\n", "testtesttesttes"},
{"testtesttesttesttest\n", "testtesttesttesttest"},
{"testtesttesttesttest\r\n", "testtesttesttesttest"},
}
for _, test := range tests {
buf := new(bytes.Buffer)
if _, err := buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err := readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
if _, err = buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err = readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
}
}
func TestMakeRawState(t *testing.T) {
fd := int(os.Stdout.Fd())
if !IsTerminal(fd) {
t.Skip("stdout is not a terminal; skipping test")
}
st, err := GetState(fd)
if err != nil {
t.Fatalf("failed to get terminal state from GetState: %s", err)
}
if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
t.Skip("MakeRaw not allowed on iOS; skipping test")
}
defer Restore(fd, st)
raw, err := MakeRaw(fd)
if err != nil {
t.Fatalf("failed to get terminal state from MakeRaw: %s", err)
}
if *st != *raw {
t.Errorf("states do not match; was %v, expected %v", raw, st)
}
}
func TestOutputNewlines(t *testing.T) {
// \n should be changed to \r\n in terminal output.
buf := new(bytes.Buffer)
term := NewTerminal(buf, ">")
term.Write([]byte("1\n2\n"))
output := string(buf.Bytes())
const expected = "1\r\n2\r\n"
if output != expected {
t.Errorf("incorrect output: was %q, expected %q", output, expected)
}
}
func TestTerminalvisualLength(t *testing.T) {
var tests = []struct {
input string
want int
}{
{"hello world", 11},
{"babalala", 8},
{"端子", 4},
{"を搭載", 6},
{"baba端子lalaを搭載", 18},
}
for _, test := range tests {
var runes []rune
for i, w := 0, 0; i < len(test.input); i += w {
runeValue, width := utf8.DecodeRuneInString(test.input[i:])
runes = append(runes, runeValue)
w = width
}
output := visualLength(runes)
if output != test.want {
t.Errorf("incorrect [%s] output: was %d, expected %d",
test.input, output, test.want)
}
}
}

View File

@ -1,114 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build aix darwin dragonfly freebsd linux,!appengine netbsd openbsd
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err)
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"golang.org/x/sys/unix"
)
// State contains the state of a terminal.
type State struct {
termios unix.Termios
}
// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
_, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
return err == nil
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if err != nil {
return nil, err
}
oldState := State{termios: *termios}
// This attempts to replicate the behaviour documented for cfmakeraw in
// the termios(3) manpage.
termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, termios); err != nil {
return nil, err
}
return &oldState, nil
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if err != nil {
return nil, err
}
return &State{termios: *termios}, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
return unix.IoctlSetTermios(fd, ioctlWriteTermios, &state.termios)
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
return -1, -1, err
}
return int(ws.Col), int(ws.Row), nil
}
// passwordReader is an io.Reader that reads from a specific file descriptor.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return unix.Read(int(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if err != nil {
return nil, err
}
newState := *termios
newState.Lflag &^= unix.ECHO
newState.Lflag |= unix.ICANON | unix.ISIG
newState.Iflag |= unix.ICRNL
if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, &newState); err != nil {
return nil, err
}
defer unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
return readPasswordLine(passwordReader(fd))
}

View File

@ -1,12 +0,0 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build aix
package terminal
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TCGETS
const ioctlWriteTermios = unix.TCSETS

View File

@ -1,12 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd netbsd openbsd
package terminal
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TIOCGETA
const ioctlWriteTermios = unix.TIOCSETA

View File

@ -1,10 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package terminal
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TCGETS
const ioctlWriteTermios = unix.TCSETS

View File

@ -1,58 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err)
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"fmt"
"runtime"
)
type State struct{}
// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
return false
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

View File

@ -1,125 +0,0 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build solaris
package terminal
import (
"io"
"syscall"
"golang.org/x/sys/unix"
)
// State contains the state of a terminal.
type State struct {
termios unix.Termios
}
// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
_, err := unix.IoctlGetTermio(fd, unix.TCGETA)
return err == nil
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
// see also: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libast/common/uwin/getpass.c
val, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldState := *val
newState := oldState
newState.Lflag &^= syscall.ECHO
newState.Lflag |= syscall.ICANON | syscall.ISIG
newState.Iflag |= syscall.ICRNL
err = unix.IoctlSetTermios(fd, unix.TCSETS, &newState)
if err != nil {
return nil, err
}
defer unix.IoctlSetTermios(fd, unix.TCSETS, &oldState)
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
}
// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
// see http://cr.illumos.org/~webrev/andy_js/1060/
func MakeRaw(fd int) (*State, error) {
termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldState := State{termios: *termios}
termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, termios); err != nil {
return nil, err
}
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, oldState *State) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, &oldState.termios)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
return &State{termios: *termios}, nil
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
return 0, 0, err
}
return int(ws.Col), int(ws.Row), nil
}

View File

@ -1,103 +0,0 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build windows
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err)
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"os"
"golang.org/x/sys/windows"
)
type State struct {
mode uint32
}
// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var st uint32
err := windows.GetConsoleMode(windows.Handle(fd), &st)
return err == nil
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
if err := windows.SetConsoleMode(windows.Handle(fd), raw); err != nil {
return nil, err
}
return &State{st}, nil
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
return &State{st}, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
return windows.SetConsoleMode(windows.Handle(fd), state.mode)
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
var info windows.ConsoleScreenBufferInfo
if err := windows.GetConsoleScreenBufferInfo(windows.Handle(fd), &info); err != nil {
return 0, 0, err
}
return int(info.Size.X), int(info.Size.Y), nil
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
old := st
st &^= (windows.ENABLE_ECHO_INPUT)
st |= (windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
if err := windows.SetConsoleMode(windows.Handle(fd), st); err != nil {
return nil, err
}
defer windows.SetConsoleMode(windows.Handle(fd), old)
var h windows.Handle
p, _ := windows.GetCurrentProcess()
if err := windows.DuplicateHandle(p, windows.Handle(fd), p, &h, 0, false, windows.DUPLICATE_SAME_ACCESS); err != nil {
return nil, err
}
f := os.NewFile(uintptr(h), "stdin")
defer f.Close()
return readPasswordLine(f)
}