From a0b8a32eb4a933f7ad3bf687892f05f85dd75fee Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan <jmorganca@gmail.com>
Date: Mon, 15 Apr 2024 12:09:32 -0400
Subject: [PATCH] Terminate subprocess if receiving `SIGINT` or `SIGTERM`
 signals while model is loading (#3653)

* terminate subprocess if receiving `SIGINT` or `SIGTERM` signals while model is loading

* use `unload` in signal handler
---
 llm/server.go    | 15 ++-------------
 server/routes.go | 40 +++++++++++++++++++++-------------------
 2 files changed, 23 insertions(+), 32 deletions(-)

diff --git a/llm/server.go b/llm/server.go
index 707f0b8b..4c1f9634 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -17,7 +17,6 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
-	"slices"
 	"strconv"
 	"strings"
 	"time"
@@ -36,10 +35,6 @@ type LlamaServer struct {
 	options api.Options
 }
 
-var cpuOnlyFamilies = []string{
-	"mamba",
-}
-
 func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
 	f, err := os.Open(model)
 	if err != nil {
@@ -91,7 +86,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	memoryRequiredPartial := memoryMinimum + graphPartialOffload
 
 	if info.Library != "metal" {
-		if memoryRequiredPartial > memoryAvailable || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture()) {
+		if memoryRequiredPartial > memoryAvailable {
 			info.Library = "cpu"
 		}
 	}
@@ -277,12 +272,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 			_ = s.cmd.Wait()
 		}()
 
-		if err = s.waitUntilRunning(); err != nil {
-			slog.Error("error starting llama server", "server", servers[i], "error", err)
-			s.Close()
-			finalErr = err
-			continue
-		}
 		return s, nil
 	}
 
@@ -383,7 +372,7 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
 	return nil
 }
 
-func (s *LlamaServer) waitUntilRunning() error {
+func (s *LlamaServer) WaitUntilRunning() error {
 	start := time.Now()
 	// TODO we need to wire up a better way to detect hangs during model load and startup of the server
 	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
diff --git a/server/routes.go b/server/routes.go
index d1e7f4cd..b0d36b14 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -68,6 +68,18 @@ var loaded struct {
 
 var defaultSessionDuration = 5 * time.Minute
 
+func unload() {
+	if loaded.llama != nil {
+		loaded.llama.Close()
+	}
+
+	loaded.llama = nil
+	loaded.model = ""
+	loaded.adapters = nil
+	loaded.projectors = nil
+	loaded.Options = nil
+}
+
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
 func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
 	ctx, cancel := context.WithTimeout(c, 10*time.Second)
@@ -83,12 +95,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
 	if needLoad {
 		if loaded.llama != nil {
 			slog.Info("changing loaded model")
-			loaded.llama.Close()
-			loaded.llama = nil
-			loaded.model = ""
-			loaded.adapters = nil
-			loaded.projectors = nil
-			loaded.Options = nil
+			unload()
 		}
 
 		llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
@@ -108,22 +115,19 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
 		loaded.projectors = model.ProjectorPaths
 		loaded.llama = llama
 		loaded.Options = &opts
+
+		if err = llama.WaitUntilRunning(); err != nil {
+			slog.Error("error loading llama server", "error", err)
+			unload()
+			return err
+		}
 	}
 
 	if loaded.expireTimer == nil {
 		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
 			loaded.mu.Lock()
 			defer loaded.mu.Unlock()
-
-			if loaded.llama != nil {
-				loaded.llama.Close()
-			}
-
-			loaded.llama = nil
-			loaded.model = ""
-			loaded.adapters = nil
-			loaded.projectors = nil
-			loaded.Options = nil
+			unload()
 		})
 	}
 
@@ -1146,9 +1150,7 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		if loaded.llama != nil {
-			loaded.llama.Close()
-		}
+		unload()
 		gpu.Cleanup()
 		os.Exit(0)
 	}()