From 35af37a2cb7097dcbac2a0f88eb2636436f82d2a Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Jul 2023 11:59:42 -0700 Subject: [PATCH 1/7] session id --- api/types.go | 8 +++++--- cmd/cmd.go | 24 ++++++++++++++-------- llama/llama.go | 18 ++++++++-------- server/routes.go | 53 +++++++++++++++++++++++++++++++++--------------- 4 files changed, 67 insertions(+), 36 deletions(-) diff --git a/api/types.go b/api/types.go index 07ce8122..42b0c470 100644 --- a/api/types.go +++ b/api/types.go @@ -28,9 +28,10 @@ func (e StatusError) Error() string { } type GenerateRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - Context []int `json:"context,omitempty"` + SessionID int64 `json:"session_id"` + Model string `json:"model"` + Prompt string `json:"prompt"` + Context []int `json:"context,omitempty"` Options `json:"options"` } @@ -81,6 +82,7 @@ type ListResponseModel struct { } type GenerateResponse struct { + SessionID int64 `json:"session_id"` Model string `json:"model"` CreatedAt time.Time `json:"created_at"` Response string `json:"response,omitempty"` diff --git a/cmd/cmd.go b/cmd/cmd.go index 7761b03b..b9c07cff 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error { return generateBatch(cmd, args[0]) } -var generateContextKey struct{} +type generateContextKey string func generate(cmd *cobra.Command, model, prompt string) error { if len(strings.TrimSpace(prompt)) > 0 { @@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error { var latest api.GenerateResponse - generateContext, ok := cmd.Context().Value(generateContextKey).([]int) + generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int) if !ok { generateContext = []int{} } - request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext} - fn := func(resp api.GenerateResponse) error { + generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64) + if !ok { + generateSession = 0 + } + + request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession} + fn := func(response api.GenerateResponse) error { if !spinner.IsFinished() { spinner.Finish() } - latest = resp + latest = response - fmt.Print(resp.Response) - - cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context)) + fmt.Print(response.Response) return nil } @@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error { if verbose { latest.Summary() } + + ctx := cmd.Context() + ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context) + ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID) + cmd.SetContext(ctx) } return nil diff --git a/llama/llama.go b/llama/llama.go index 37b4c143..a48c5965 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -91,7 +91,7 @@ import ( "github.com/jmorganca/ollama/api" ) -type llama struct { +type LLM struct { params *C.struct_llama_context_params model *C.struct_llama_model ctx *C.struct_llama_context @@ -99,12 +99,12 @@ type llama struct { api.Options } -func New(model string, opts api.Options) (*llama, error) { +func New(model string, opts api.Options) (*LLM, error) { if _, err := os.Stat(model); err != nil { return nil, err } - llm := llama{Options: opts} + llm := LLM{Options: opts} C.llama_backend_init(C.bool(llm.UseNUMA)) @@ -144,14 +144,14 @@ func New(model string, opts api.Options) (*llama, error) { return &llm, nil } -func (llm *llama) Close() { +func (llm *LLM) Close() { defer C.llama_free_model(llm.model) defer C.llama_free(llm.ctx) C.llama_print_timings(llm.ctx) } -func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { +func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { if input := llm.tokenize(prompt); input != nil { embd := make([]C.llama_token, len(ctx)) for i := range ctx { @@ -164,7 +164,7 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse return errors.New("llama: tokenize") } -func (llm *llama) tokenize(prompt string) []C.llama_token { +func (llm *LLM) tokenize(prompt string) []C.llama_token { cPrompt := C.CString(prompt) defer C.free(unsafe.Pointer(cPrompt)) @@ -176,7 +176,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token { return nil } -func (llm *llama) detokenize(tokens ...C.llama_token) string { +func (llm *LLM) detokenize(tokens ...C.llama_token) string { var sb strings.Builder for _, token := range tokens { sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token))) @@ -185,7 +185,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string { return sb.String() } -func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error { +func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error { var opts C.struct_llama_sample_options opts.repeat_penalty = C.float(llm.RepeatPenalty) opts.frequency_penalty = C.float(llm.FrequencyPenalty) @@ -256,7 +256,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) return nil } -func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { +func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { numVocab := int(C.llama_n_vocab(llm.ctx)) logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab) diff --git a/server/routes.go b/server/routes.go index aabcb718..93a04cd7 100644 --- a/server/routes.go +++ b/server/routes.go @@ -11,6 +11,7 @@ import ( "os" "path/filepath" "strings" + "sync" "time" "dario.cat/mergo" @@ -21,7 +22,17 @@ import ( "github.com/jmorganca/ollama/llama" ) +var mu sync.Mutex + +var activeSession struct { + ID int64 + *llama.LLM +} + func GenerateHandler(c *gin.Context) { + mu.Lock() + defer mu.Unlock() + start := time.Now() var req api.GenerateRequest @@ -36,15 +47,31 @@ func GenerateHandler(c *gin.Context) { return } - opts := api.DefaultOptions() - if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + if req.SessionID == 0 || req.SessionID != activeSession.ID { + if activeSession.LLM != nil { + activeSession.Close() + activeSession.LLM = nil + } - if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + opts := api.DefaultOptions() + if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + llm, err := llama.New(model.ModelPath, opts) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + activeSession.ID = time.Now().UnixNano() + activeSession.LLM = llm } prompt, err := model.Prompt(req) @@ -53,19 +80,13 @@ func GenerateHandler(c *gin.Context) { return } - llm, err := llama.New(model.ModelPath, opts) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - defer llm.Close() - ch := make(chan any) go func() { defer close(ch) fn := func(r api.GenerateResponse) { r.Model = req.Model r.CreatedAt = time.Now().UTC() + r.SessionID = activeSession.ID if r.Done { r.TotalDuration = time.Since(start) } @@ -73,7 +94,7 @@ func GenerateHandler(c *gin.Context) { ch <- r } - if err := llm.Predict(req.Context, prompt, fn); err != nil { + if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() From 32aec66e6ad909759da45c18d0a4504e0dd73fc1 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Jul 2023 12:02:02 -0700 Subject: [PATCH 2/7] add load duration --- api/types.go | 5 +++++ server/routes.go | 7 +++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/api/types.go b/api/types.go index 42b0c470..dccfbf7a 100644 --- a/api/types.go +++ b/api/types.go @@ -91,6 +91,7 @@ type GenerateResponse struct { Context []int `json:"context,omitempty"` TotalDuration time.Duration `json:"total_duration,omitempty"` + LoadDuration time.Duration `json:"load_duration,omitempty"` PromptEvalCount int `json:"prompt_eval_count,omitempty"` PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` EvalCount int `json:"eval_count,omitempty"` @@ -102,6 +103,10 @@ func (r *GenerateResponse) Summary() { fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration) } + if r.LoadDuration > 0 { + fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration) + } + if r.PromptEvalCount > 0 { fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount) } diff --git a/server/routes.go b/server/routes.go index 93a04cd7..c3f27ec8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -33,7 +33,7 @@ func GenerateHandler(c *gin.Context) { mu.Lock() defer mu.Unlock() - start := time.Now() + checkpointStart := time.Now() var req api.GenerateRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -74,6 +74,8 @@ func GenerateHandler(c *gin.Context) { activeSession.LLM = llm } + checkpointLoaded := time.Now() + prompt, err := model.Prompt(req) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -88,7 +90,8 @@ func GenerateHandler(c *gin.Context) { r.CreatedAt = time.Now().UTC() r.SessionID = activeSession.ID if r.Done { - r.TotalDuration = time.Since(start) + r.TotalDuration = time.Since(checkpointStart) + r.LoadDuration = checkpointLoaded.Sub(checkpointStart) } ch <- r From 3003fc03fcd2b12919433506dfc675b30cdca85f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 19 Jul 2023 12:47:15 -0700 Subject: [PATCH 3/7] update predict code --- api/types.go | 5 +- llama/llama.go | 245 ++++++++++++++++++++++++++++++++----------------- llama/utils.go | 107 ++------------------- 3 files changed, 176 insertions(+), 181 deletions(-) diff --git a/api/types.go b/api/types.go index dccfbf7a..5f8c3891 100644 --- a/api/types.go +++ b/api/types.go @@ -134,6 +134,7 @@ type Options struct { // Model options NumCtx int `json:"num_ctx,omitempty"` + NumKeep int `json:"num_keep,omitempty"` NumBatch int `json:"num_batch,omitempty"` NumGPU int `json:"num_gpu,omitempty"` MainGPU int `json:"main_gpu,omitempty"` @@ -158,6 +159,7 @@ type Options struct { Mirostat int `json:"mirostat,omitempty"` MirostatTau float32 `json:"mirostat_tau,omitempty"` MirostatEta float32 `json:"mirostat_eta,omitempty"` + PenalizeNewline bool `json:"penalize_newline,omitempty"` NumThread int `json:"num_thread,omitempty"` } @@ -176,7 +178,7 @@ func DefaultOptions() Options { UseMMap: true, UseMLock: false, - RepeatLastN: 512, + RepeatLastN: 64, RepeatPenalty: 1.1, FrequencyPenalty: 0.0, PresencePenalty: 0.0, @@ -188,6 +190,7 @@ func DefaultOptions() Options { Mirostat: 0, MirostatTau: 5.0, MirostatEta: 0.1, + PenalizeNewline: true, NumThread: runtime.NumCPU(), } diff --git a/llama/llama.go b/llama/llama.go index a48c5965..9f5066f3 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -1,8 +1,8 @@ package llama /* -#cgo CPPFLAGS: -O3 -DNDEBUG=1 -DGGML_USE_K_QUANTS -#cgo CXXFLAGS: -std=c++11 +#cgo CPPFLAGS: -O3 -Wall -Wextra -Werror -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS +#cgo CXXFLAGS: -std=gnu++11 #cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders #include @@ -21,6 +21,7 @@ struct llama_sample_options int mirostat; float mirostat_tau; float mirostat_eta; + bool penalize_newline; }; llama_token llama_sample( @@ -37,6 +38,8 @@ llama_token llama_sample( false, }; + struct llama_token_data newline = candidates_p.data[llama_token_nl()]; + llama_sample_repetition_penalty( ctx, &candidates_p, last_tokens, n_last_tokens, @@ -47,6 +50,10 @@ llama_token llama_sample( last_tokens, n_last_tokens, opts->frequency_penalty, opts->presence_penalty); + if (!opts->penalize_newline) { + candidates_p.data[llama_token_nl()] = newline; + } + if (opts->temperature <= 0) { return llama_sample_token_greedy(ctx, &candidates_p); } @@ -82,9 +89,9 @@ import ( "errors" "fmt" "io" + "log" "os" "strings" - "time" "unicode/utf8" "unsafe" @@ -96,6 +103,10 @@ type LLM struct { model *C.struct_llama_model ctx *C.struct_llama_context + last []C.llama_token + embd []C.llama_token + cursor int + api.Options } @@ -152,16 +163,98 @@ func (llm *LLM) Close() { } func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { - if input := llm.tokenize(prompt); input != nil { - embd := make([]C.llama_token, len(ctx)) - for i := range ctx { - embd[i] = C.llama_token(ctx[i]) - } + C.llama_reset_timings(llm.ctx) - return llm.generate(append(embd, input...), fn) + tokens := make([]C.llama_token, len(ctx)) + for i := range tokens { + tokens[i] = C.llama_token(ctx[i]) } - return errors.New("llama: tokenize") + if len(tokens) == 0 { + tokens = llm.tokenize(" ") + } + + llm.marshalPrompt(tokens, prompt) + + C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed)) + + var b bytes.Buffer + for { + token, err := llm.next() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + + b.WriteString(llm.detokenize(token)) + if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax { + fn(api.GenerateResponse{Response: b.String()}) + b.Reset() + } + } + + last := make([]int, 0, len(llm.last)) + for _, i := range llm.last { + if i != 0 { + last = append(last, int(i)) + } + } + + timings := C.llama_get_timings(llm.ctx) + fn(api.GenerateResponse{ + Done: true, + Context: last, + PromptEvalCount: int(timings.n_p_eval), + PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)), + EvalCount: int(timings.n_eval), + EvalDuration: parseDurationMs(float64(timings.t_eval_ms)), + }) + + return nil +} + +func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token { + tokens := append(ctx, llm.tokenize(prompt)...) + if llm.NumKeep < 0 { + llm.NumKeep = len(tokens) + } + + // min(llm.NumCtx - 4, llm.NumKeep) + if llm.NumCtx-4 < llm.NumKeep { + llm.NumKeep = llm.NumCtx - 4 + } + + if len(tokens) >= llm.NumCtx { + // truncate input + numLeft := (llm.NumCtx - llm.NumKeep) / 2 + truncated := tokens[:llm.NumKeep] + erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft + truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...) + copy(llm.last, tokens[len(tokens)-llm.NumCtx:]) + + tokens = truncated + log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated)) + } else { + llm.last = make([]C.llama_token, llm.NumCtx-len(tokens)) + llm.last = append(llm.last, tokens...) + } + + var i int + for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ { + // noop + } + + llm.embd = tokens + if i == len(tokens) { + // evaluate at least one token to generate logits + i-- + } + + llm.cursor = i + + log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:])) + return tokens } func (llm *LLM) tokenize(prompt string) []C.llama_token { @@ -185,98 +278,86 @@ func (llm *LLM) detokenize(tokens ...C.llama_token) string { return sb.String() } -func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error { - var opts C.struct_llama_sample_options - opts.repeat_penalty = C.float(llm.RepeatPenalty) - opts.frequency_penalty = C.float(llm.FrequencyPenalty) - opts.presence_penalty = C.float(llm.PresencePenalty) - opts.temperature = C.float(llm.Temperature) - opts.top_k = C.int(llm.TopK) - opts.top_p = C.float(llm.TopP) - opts.tfs_z = C.float(llm.TFSZ) - opts.typical_p = C.float(llm.TypicalP) - opts.mirostat = C.int(llm.Mirostat) - opts.mirostat_tau = C.float(llm.MirostatTau) - opts.mirostat_eta = C.float(llm.MirostatEta) +func (llm *LLM) next() (C.llama_token, error) { + if len(llm.embd) >= llm.NumCtx { + numLeft := (llm.NumCtx - llm.NumKeep) / 2 + truncated := llm.embd[:llm.NumKeep] + truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...) - output := deque[C.llama_token]{capacity: llm.NumCtx} - - context := deque[int]{capacity: llm.NumCtx / 2} - for _, in := range input { - context.PushLeft(int(in)) + llm.embd = truncated + llm.cursor = llm.NumKeep + log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor) } - var b bytes.Buffer - for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) { - if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 { - return errors.New("llama: eval") - } - - token, err := llm.sample(output, &opts) - if errors.Is(err, io.EOF) { + for { + if llm.cursor >= len(llm.embd) { break - } else if err != nil { - return err } - b.WriteString(llm.detokenize(token)) - if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax { - // call the callback - fn(api.GenerateResponse{ - Response: b.String(), - }) - - output.PushLeft(token) - context.PushLeft(int(token)) - b.Reset() + numEval := len(llm.embd) - llm.cursor + if numEval > llm.NumBatch { + numEval = llm.NumBatch } - input = []C.llama_token{token} + if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 { + return 0, fmt.Errorf("llama_eval: %d", retval) + } + + llm.cursor += numEval } - dur := func(ms float64) time.Duration { - d, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) - if err != nil { - panic(err) - } + var sampleOpts C.struct_llama_sample_options + sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty) + sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty) + sampleOpts.presence_penalty = C.float(llm.PresencePenalty) + sampleOpts.temperature = C.float(llm.Temperature) + sampleOpts.top_k = C.int(llm.TopK) + sampleOpts.top_p = C.float(llm.TopP) + sampleOpts.tfs_z = C.float(llm.TFSZ) + sampleOpts.typical_p = C.float(llm.TypicalP) + sampleOpts.mirostat = C.int(llm.Mirostat) + sampleOpts.mirostat_tau = C.float(llm.MirostatTau) + sampleOpts.mirostat_eta = C.float(llm.MirostatEta) + sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline) - return d - } - - timings := C.llama_get_timings(llm.ctx) - fn(api.GenerateResponse{ - Done: true, - Context: context.Data(), - PromptEvalCount: int(timings.n_p_eval), - PromptEvalDuration: dur(float64(timings.t_p_eval_ms)), - EvalCount: int(timings.n_eval), - EvalDuration: dur(float64(timings.t_eval_ms)), - }) - - return nil -} - -func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { - numVocab := int(C.llama_n_vocab(llm.ctx)) + numVocab := C.llama_n_vocab(llm.ctx) logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab) - candidates := deque[C.struct_llama_token_data]{capacity: numVocab} - for i := 0; i < candidates.Cap(); i++ { - candidates.PushLeft(C.struct_llama_token_data{ + // TODO: logit bias + + candidates := make([]C.llama_token_data, numVocab) + for i := range logits { + candidates[i] = C.llama_token_data{ id: C.int(i), logit: logits[i], p: 0, - }) + } } + repeatLastN := llm.RepeatLastN + if len(llm.last) < repeatLastN { + repeatLastN = len(llm.last) + } + + if llm.NumCtx < repeatLastN { + repeatLastN = llm.NumCtx + } + + lastN := llm.last[len(llm.last)-repeatLastN:] + token := C.llama_sample( llm.ctx, - unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()), - unsafe.SliceData(output.Data()), C.size_t(output.Len()), - opts) - if token != C.llama_token_eos() { - return token, nil + unsafe.SliceData(candidates), C.size_t(len(candidates)), + unsafe.SliceData(lastN), C.size_t(len(lastN)), + &sampleOpts, + ) + + llm.last = append(llm.last, token) + llm.embd = append(llm.embd, token) + + if token == C.llama_token_eos() { + return 0, io.EOF } - return 0, io.EOF + return token, nil } diff --git a/llama/utils.go b/llama/utils.go index b0db27d4..8b52ad5c 100644 --- a/llama/utils.go +++ b/llama/utils.go @@ -1,104 +1,15 @@ package llama -type node[T any] struct { - t T - next *node[T] - prev *node[T] -} +import ( + "fmt" + "time" +) -type deque[T any] struct { - head *node[T] - tail *node[T] - size int - capacity int -} - -func (d *deque[T]) Empty() bool { - return d.size == 0 -} - -func (d *deque[T]) Len() int { - return d.size -} - -func (d *deque[T]) Cap() int { - return d.capacity -} - -func (d *deque[T]) Push(t T) { - if d.capacity > 0 && d.size >= d.capacity { - d.PopLeft() +func parseDurationMs(ms float64) time.Duration { + dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) + if err != nil { + panic(err) } - n := node[T]{t: t} - if d.head != nil { - n.next = d.head - d.head.prev = &n - d.head = &n - } else { - d.head = &n - d.tail = &n - } - - d.size++ -} - -func (d *deque[T]) PushLeft(t T) { - if d.capacity > 0 && d.size >= d.capacity { - d.Pop() - } - - n := node[T]{t: t} - if d.tail != nil { - n.prev = d.tail - d.tail.next = &n - d.tail = &n - } else { - d.head = &n - d.tail = &n - } - - d.size++ -} - -func (d *deque[T]) Pop() *T { - if d.Empty() { - return nil - } - - head := d.head - d.head = head.next - if d.head != nil { - d.head.prev = nil - } else { - d.tail = nil - } - - d.size-- - return &head.t -} - -func (d *deque[T]) PopLeft() *T { - if d.Empty() { - return nil - } - - tail := d.tail - d.tail = tail.prev - if d.tail != nil { - d.tail.next = nil - } else { - d.head = nil - } - - d.size-- - return &tail.t -} - -func (d *deque[T]) Data() (data []T) { - for n := d.head; n != nil; n = n.next { - data = append(data, n.t) - } - - return data + return dur } From f62a882760c78d799d4044b633a8dda37e097ac8 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 19 Jul 2023 15:00:28 -0700 Subject: [PATCH 4/7] add session expiration --- api/types.go | 45 +++++++++++++++++++++++++++++++---- llama/llama.go | 14 +++++++++++ server/routes.go | 61 ++++++++++++++++++++++++++++++++++++------------ 3 files changed, 100 insertions(+), 20 deletions(-) diff --git a/api/types.go b/api/types.go index 5f8c3891..24666462 100644 --- a/api/types.go +++ b/api/types.go @@ -1,7 +1,9 @@ package api import ( + "encoding/json" "fmt" + "math" "os" "runtime" "time" @@ -28,10 +30,12 @@ func (e StatusError) Error() string { } type GenerateRequest struct { - SessionID int64 `json:"session_id"` - Model string `json:"model"` - Prompt string `json:"prompt"` - Context []int `json:"context,omitempty"` + SessionID int64 `json:"session_id"` + SessionDuration Duration `json:"session_duration,omitempty"` + + Model string `json:"model"` + Prompt string `json:"prompt"` + Context []int `json:"context,omitempty"` Options `json:"options"` } @@ -82,7 +86,9 @@ type ListResponseModel struct { } type GenerateResponse struct { - SessionID int64 `json:"session_id"` + SessionID int64 `json:"session_id"` + SessionExpiresAt time.Time `json:"session_expires_at"` + Model string `json:"model"` CreatedAt time.Time `json:"created_at"` Response string `json:"response,omitempty"` @@ -195,3 +201,32 @@ func DefaultOptions() Options { NumThread: runtime.NumCPU(), } } + +type Duration struct { + time.Duration +} + +func (d *Duration) UnmarshalJSON(b []byte) (err error) { + var v any + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + d.Duration = 5 * time.Minute + + switch t := v.(type) { + case float64: + if t < 0 { + t = math.MaxFloat64 + } + + d.Duration = time.Duration(t) + case string: + d.Duration, err = time.ParseDuration(t) + if err != nil { + return err + } + } + + return nil +} diff --git a/llama/llama.go b/llama/llama.go index 9f5066f3..5919b4bd 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -92,6 +92,7 @@ import ( "log" "os" "strings" + "sync" "unicode/utf8" "unsafe" @@ -107,6 +108,9 @@ type LLM struct { embd []C.llama_token cursor int + mu sync.Mutex + gc bool + api.Options } @@ -156,6 +160,11 @@ func New(model string, opts api.Options) (*LLM, error) { } func (llm *LLM) Close() { + llm.gc = true + + llm.mu.Lock() + defer llm.mu.Unlock() + defer C.llama_free_model(llm.model) defer C.llama_free(llm.ctx) @@ -163,6 +172,9 @@ func (llm *LLM) Close() { } func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { + llm.mu.Lock() + defer llm.mu.Unlock() + C.llama_reset_timings(llm.ctx) tokens := make([]C.llama_token, len(ctx)) @@ -185,6 +197,8 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) break } else if err != nil { return err + } else if llm.gc { + return io.EOF } b.WriteString(llm.detokenize(token)) diff --git a/server/routes.go b/server/routes.go index c3f27ec8..cc9df958 100644 --- a/server/routes.go +++ b/server/routes.go @@ -22,16 +22,19 @@ import ( "github.com/jmorganca/ollama/llama" ) -var mu sync.Mutex - var activeSession struct { - ID int64 - *llama.LLM + mu sync.Mutex + + id int64 + llm *llama.LLM + + expireAt time.Time + expireTimer *time.Timer } func GenerateHandler(c *gin.Context) { - mu.Lock() - defer mu.Unlock() + activeSession.mu.Lock() + defer activeSession.mu.Unlock() checkpointStart := time.Now() @@ -47,10 +50,10 @@ func GenerateHandler(c *gin.Context) { return } - if req.SessionID == 0 || req.SessionID != activeSession.ID { - if activeSession.LLM != nil { - activeSession.Close() - activeSession.LLM = nil + if req.SessionID == 0 || req.SessionID != activeSession.id { + if activeSession.llm != nil { + activeSession.llm.Close() + activeSession.llm = nil } opts := api.DefaultOptions() @@ -70,10 +73,34 @@ func GenerateHandler(c *gin.Context) { return } - activeSession.ID = time.Now().UnixNano() - activeSession.LLM = llm + activeSession.id = time.Now().UnixNano() + activeSession.llm = llm } + sessionDuration := req.SessionDuration + sessionID := activeSession.id + + activeSession.expireAt = time.Now().Add(sessionDuration.Duration) + if activeSession.expireTimer == nil { + activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() { + activeSession.mu.Lock() + defer activeSession.mu.Unlock() + + if sessionID != activeSession.id { + return + } + + if time.Now().Before(activeSession.expireAt) { + return + } + + activeSession.llm.Close() + activeSession.llm = nil + activeSession.id = 0 + }) + } + activeSession.expireTimer.Reset(sessionDuration.Duration) + checkpointLoaded := time.Now() prompt, err := model.Prompt(req) @@ -86,9 +113,13 @@ func GenerateHandler(c *gin.Context) { go func() { defer close(ch) fn := func(r api.GenerateResponse) { + activeSession.expireAt = time.Now().Add(sessionDuration.Duration) + activeSession.expireTimer.Reset(sessionDuration.Duration) + r.Model = req.Model r.CreatedAt = time.Now().UTC() - r.SessionID = activeSession.ID + r.SessionID = activeSession.id + r.SessionExpiresAt = activeSession.expireAt.UTC() if r.Done { r.TotalDuration = time.Since(checkpointStart) r.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -97,7 +128,7 @@ func GenerateHandler(c *gin.Context) { ch <- r } - if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil { + if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -247,7 +278,7 @@ func ListModelsHandler(c *gin.Context) { return } - c.JSON(http.StatusOK, api.ListResponse{models}) + c.JSON(http.StatusOK, api.ListResponse{Models: models}) } func CopyModelHandler(c *gin.Context) { From c4904161891cb077834155798099410af2bbfed9 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 20 Jul 2023 09:29:43 -0700 Subject: [PATCH 5/7] lock on llm.lock(); decrease batch size --- api/types.go | 2 +- llama/llama.go | 16 ++++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/api/types.go b/api/types.go index 24666462..fc00adb1 100644 --- a/api/types.go +++ b/api/types.go @@ -177,7 +177,7 @@ func DefaultOptions() Options { UseNUMA: false, NumCtx: 2048, - NumBatch: 512, + NumBatch: 32, NumGPU: 1, LowVRAM: false, F16KV: true, diff --git a/llama/llama.go b/llama/llama.go index 5919b4bd..07dd8a13 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -172,9 +172,6 @@ func (llm *LLM) Close() { } func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { - llm.mu.Lock() - defer llm.mu.Unlock() - C.llama_reset_timings(llm.ctx) tokens := make([]C.llama_token, len(ctx)) @@ -193,12 +190,12 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) var b bytes.Buffer for { token, err := llm.next() - if errors.Is(err, io.EOF) { + if llm.gc { + return nil + } else if errors.Is(err, io.EOF) { break } else if err != nil { return err - } else if llm.gc { - return io.EOF } b.WriteString(llm.detokenize(token)) @@ -293,6 +290,9 @@ func (llm *LLM) detokenize(tokens ...C.llama_token) string { } func (llm *LLM) next() (C.llama_token, error) { + llm.mu.Lock() + defer llm.mu.Unlock() + if len(llm.embd) >= llm.NumCtx { numLeft := (llm.NumCtx - llm.NumKeep) / 2 truncated := llm.embd[:llm.NumKeep] @@ -304,6 +304,10 @@ func (llm *LLM) next() (C.llama_token, error) { } for { + if llm.gc { + return 0, io.EOF + } + if llm.cursor >= len(llm.embd) { break } From cca61181cb08995ffc2ac93439425ac3fa997a5b Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 25 Jul 2023 15:51:32 -0700 Subject: [PATCH 6/7] sample metrics --- api/types.go | 11 +++++++++++ llama/llama.go | 2 ++ 2 files changed, 13 insertions(+) diff --git a/api/types.go b/api/types.go index fc00adb1..6208cb6e 100644 --- a/api/types.go +++ b/api/types.go @@ -98,6 +98,8 @@ type GenerateResponse struct { TotalDuration time.Duration `json:"total_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"` + SampleCount int `json:"sample_count,omitempty"` + SampleDuration time.Duration `json:"sample_duration,omitempty"` PromptEvalCount int `json:"prompt_eval_count,omitempty"` PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` EvalCount int `json:"eval_count,omitempty"` @@ -113,6 +115,15 @@ func (r *GenerateResponse) Summary() { fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration) } + if r.SampleCount > 0 { + fmt.Fprintf(os.Stderr, "sample count: %d token(s)\n", r.SampleCount) + } + + if r.SampleDuration > 0 { + fmt.Fprintf(os.Stderr, "sample duration: %s\n", r.SampleDuration) + fmt.Fprintf(os.Stderr, "sample rate: %.2f tokens/s\n", float64(r.SampleCount)/r.SampleDuration.Seconds()) + } + if r.PromptEvalCount > 0 { fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount) } diff --git a/llama/llama.go b/llama/llama.go index 07dd8a13..e2c30f1f 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -216,6 +216,8 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) fn(api.GenerateResponse{ Done: true, Context: last, + SampleCount: int(timings.n_sample), + SampleDuration: parseDurationMs(float64(timings.t_sample_ms)), PromptEvalCount: int(timings.n_p_eval), PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)), EvalCount: int(timings.n_eval), From 688661ab9b5bb821555efda75f016689dfd1b2da Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Thu, 27 Jul 2023 16:51:01 -0400 Subject: [PATCH 7/7] increase default batch size to 1024 --- api/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/types.go b/api/types.go index 6208cb6e..8f12b5f9 100644 --- a/api/types.go +++ b/api/types.go @@ -188,7 +188,7 @@ func DefaultOptions() Options { UseNUMA: false, NumCtx: 2048, - NumBatch: 32, + NumBatch: 1024, NumGPU: 1, LowVRAM: false, F16KV: true,