diff --git a/llama/runner/runner.go b/llama/runner/runner.go
index 61f1e7f0..2abe2f9e 100644
--- a/llama/runner/runner.go
+++ b/llama/runner/runner.go
@@ -14,6 +14,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
@@ -50,6 +51,12 @@ type Sequence struct {
 	embeddingOnly bool
 
 	doneReason string
+
+	// Metrics
+	t_start_process_prompt time.Time
+	t_start_generation     time.Time
+	n_decoded              int
+	n_prompt_tokens        int
 }
 
 // prompt returns true if the prompt is still being processed
@@ -80,12 +87,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param
 	}
 
 	return &Sequence{
-		tokens:        tokens,
-		responses:     make(chan string, 1),
-		embedding:     make(chan []float32, 1),
-		samplingCtx:   sc,
-		embeddingOnly: embedding,
-		stop:          stop,
+		tokens:          tokens,
+		n_prompt_tokens: len(tokens),
+		responses:       make(chan string, 1),
+		embedding:       make(chan []float32, 1),
+		samplingCtx:     sc,
+		embeddingOnly:   embedding,
+		stop:            stop,
 	}
 }
 
@@ -161,6 +169,10 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
+				if seq.t_start_process_prompt.IsZero() {
+					seq.t_start_process_prompt = time.Now()
+				}
+
 				for j, t := range seq.tokens {
 					// todo: make this n_batch
 					if j > s.batchSize {
@@ -207,6 +219,10 @@ func (s *Server) run(ctx context.Context) {
 
 				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
 				seq.samplingCtx.Accept(s.lc, token, true)
+				seq.n_decoded += 1
+				if seq.n_decoded == 1 {
+					seq.t_start_generation = time.Now()
+				}
 				piece := s.model.TokenToPiece(token)
 
 				seq.numPredicted++
@@ -278,8 +294,26 @@ type CompletionRequest struct {
 	api.Options
 }
 
+type Timings struct {
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
+}
+
 type CompletionResponse struct {
-	Token string `json:"token"`
+	Content string `json:"content"`
+	Stop    bool   `json:"stop"`
+
+	Model        string  `json:"model,omitempty"`
+	Prompt       string  `json:"prompt,omitempty"`
+	StoppedLimit bool    `json:"stopped_limit,omitempty"`
+	PredictedN   int     `json:"predicted_n,omitempty"`
+	PredictedMS  float64 `json:"predicted_ms,omitempty"`
+	PromptN      int     `json:"prompt_n,omitempty"`
+	PromptMS     float64 `json:"prompt_ms,omitempty"`
+
+	Timings Timings `json:"timings"`
 }
 
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -326,9 +360,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Unlock()
 
 	// stream the response
-	for token := range seq.responses {
+	for content := range seq.responses {
 		if err := json.NewEncoder(w).Encode(&CompletionResponse{
-			Token: token,
+			Content: content,
 		}); err != nil {
 			log.Println("Failed to encode result:", err)
 			return
@@ -342,6 +376,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 
 		flusher.Flush()
 	}
+
+	// Send the stop
+	if err := json.NewEncoder(w).Encode(&CompletionResponse{
+		Stop: true,
+		Timings: Timings{
+			PromptN:     seq.n_prompt_tokens,
+			PromptMS:    float64(seq.t_start_generation.Sub(seq.t_start_process_prompt).Milliseconds()),
+			PredictedN:  seq.n_decoded,
+			PredictedMS: float64(time.Since(seq.t_start_generation).Milliseconds()),
+		},
+	}); err != nil {
+		log.Println("Failed to encode result:", err)
+		return
+	}
+
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
+		return
+	}
+
+	flusher.Flush()
 }
 
 type EmbeddingRequest struct {
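
The handler above writes one JSON object per line (json.NewEncoder plus Flush) and finishes with a stop message carrying the timings, so a minimal client-side sketch of how the new stream shape could be consumed is shown below. This is illustrative only and not part of the patch: the runner address, the /completion path, and the "prompt" request field are assumptions, since the full CompletionRequest and route registration are not visible in this diff; the response field names come from the json tags added above.

// client_sketch.go - hypothetical consumer of the streamed completion responses.
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Mirrors only the fields of the patched CompletionResponse that this sketch uses.
type timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

type completionResponse struct {
	Content string  `json:"content"`
	Stop    bool    `json:"stop"`
	Timings timings `json:"timings"`
}

func main() {
	// Hypothetical request body and endpoint; the exact request fields and
	// route are not shown in the diff.
	body := bytes.NewBufferString(`{"prompt": "why is the sky blue?"}`)
	resp, err := http.Post("http://localhost:8080/completion", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The server emits one JSON object per line and flushes after each one,
	// so a line scanner is enough to split the stream.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var cr completionResponse
		if err := json.Unmarshal(scanner.Bytes(), &cr); err != nil {
			panic(err)
		}

		if cr.Stop {
			// The final message carries the metrics added by the patch:
			// prompt token count/time and generated token count/time.
			fmt.Printf("\nprompt: %d tokens in %.0f ms, generation: %d tokens in %.0f ms (%.2f tok/s)\n",
				cr.Timings.PromptN, cr.Timings.PromptMS,
				cr.Timings.PredictedN, cr.Timings.PredictedMS,
				float64(cr.Timings.PredictedN)/(cr.Timings.PredictedMS/1000.0))
			break
		}

		fmt.Print(cr.Content)
	}
	if err := scanner.Err(); err != nil {
		panic(err)
	}
}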