From b5636dba9dc78af25eee8ba7ae71430431383335 Mon Sep 17 00:00:00 2001
From: Chad Etzel <jazzychad@gmail.com>
Date: Sat, 13 Dec 2014 20:28:19 -0800
Subject: [PATCH] add ability to load motd from file - closes #25

---
 client.go |  2 +-
 cmd.go    | 15 +++++++++++++++
 server.go | 12 +++++-------
 3 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/client.go b/client.go
index 9b5166e..3fd2237 100644
--- a/client.go
+++ b/client.go
@@ -321,7 +321,7 @@ func (c *Client) handleShell(channel ssh.Channel) {
 					} else {
 						newmotd = parts[1] + " " + parts[2]
 					}
-					c.Server.SetMotd(c, newmotd)
+					c.Server.SetMotd(newmotd)
 					c.Server.MotdBroadcast(c)
 				}
 
diff --git a/cmd.go b/cmd.go
index 266eeb2..cfedf85 100644
--- a/cmd.go
+++ b/cmd.go
@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"fmt"
 	"io/ioutil"
+	"strings"
 	"os"
 	"os/signal"
 
@@ -18,6 +19,7 @@ type Options struct {
 	Bind      string   `long:"bind" description:"Host and port to listen on." default:"0.0.0.0:22"`
 	Admin     []string `long:"admin" description:"Fingerprint of pubkey to mark as admin."`
 	Whitelist string   `long:"whitelist" description:"Optional file of pubkey fingerprints that are allowed to connect"`
+	Motd      string   `long:"motd" description:"Message of the Day file (optional)"`
 }
 
 var logLevels = []log.Level{
@@ -80,6 +82,19 @@ func main() {
 		}
 	}
 
+	if options.Motd != "" {
+		motd, err := ioutil.ReadFile(options.Motd)
+		if err != nil {
+			logger.Errorf("Failed to load MOTD file: %v", err)
+			return
+		}
+		motdString := string(motd[:])
+		/* hack to normalize line endings into \r\n */
+		motdString = strings.Replace(motdString, "\r\n", "\n", -1)
+		motdString = strings.Replace(motdString, "\n", "\r\n", -1)
+		server.SetMotd(motdString)
+	}
+
 	// Construct interrupt handler
 	sig := make(chan os.Signal, 1)
 	signal.Notify(sig, os.Interrupt)
diff --git a/server.go b/server.go
index fe6d220..163050e 100644
--- a/server.go
+++ b/server.go
@@ -123,21 +123,19 @@ func (s *Server) Privmsg(nick, message string, sender *Client) error {
 	return nil
 }
 
-func (s *Server) SetMotd(client *Client, motd string) {
+func (s *Server) SetMotd(motd string) {
 	s.Lock()
 	s.motd = motd
 	s.Unlock()
 }
 
 func (s *Server) MotdUnicast(client *Client) {
-	client.SysMsg("/** MOTD")
-	client.SysMsg(" * " + ColorString("36", s.motd)) /* a nice cyan color */
-	client.SysMsg(" **/")
+	client.SysMsg("MOTD:\r\n" + ColorString("36", s.motd)) /* a nice cyan color */
 }
 
 func (s *Server) MotdBroadcast(client *Client) {
 	s.Broadcast(ContinuousFormat(SYSTEM_MESSAGE_FORMAT, fmt.Sprintf(" * New MOTD set by %s.", client.ColoredName())), client)
-	s.Broadcast(" /**\r\n" + "  * " + ColorString("36", s.motd) + "\r\n  **/", client)
+	s.Broadcast(ColorString("36", s.motd), client)
 }
 
 func (s *Server) Add(client *Client) {
@@ -237,9 +235,9 @@ func (s *Server) Op(fingerprint string) {
 
 func (s *Server) Whitelist(fingerprint string) {
 	logger.Infof("Adding whitelist: %s", fingerprint)
-	s.lock.Lock()
+	s.Lock()
 	s.whitelist[fingerprint] = struct{}{}
-	s.lock.Unlock()
+	s.Unlock()
 }
 
 func (s *Server) Uptime() string {