ollama/llama/runner/runner.go
Jesse Gross eccd4dd8d2 runner.go: Use correct JSON field names for runners
The fields for inference parameters are very similar between the
Ollama API and Ollama/runners. However, some of the names are
slightly different. For these fields (such as NumKeep and
NumPredict), the values from Ollama were never read properly and
defaults were always used.

In the future, we can share a single interface rather than duplicating
structs. However, this keeps the interface consistent with minimal
changes in Ollama as long as we continue to use server.cpp.
2024-09-03 21:15:14 -04:00
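
As a quick sketch of the mismatch this commit fixes (the struct below is illustrative, not the runner's actual Options type): encoding/json silently ignores a key that does not match a field's JSON tag, so a value sent under the Ollama API name never reaches a field tagged with the runner/server.cpp name and the default is used instead.

	// Sketch only; not part of runner.go.
	package main

	import (
		"encoding/json"
		"fmt"
	)

	func main() {
		opts := struct {
			NumPredict int `json:"n_predict"` // runner-style field name
		}{NumPredict: -1} // stand-in default

		// The caller sends the API-style name, which never matches the tag.
		json.Unmarshal([]byte(`{"num_predict": 10}`), &opts)

		fmt.Println(opts.NumPredict) // prints -1: the value was never read
	}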

package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"log/slog"
	"math"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llama"
)
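
// Sequence holds the state of a single generation or embedding request while
// it occupies one of the server's parallel slots.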
type Sequence struct {
	// number of tokens evaluated
	nPast int

	// batch index
	iBatch int

	// number of tokens predicted so far
	numPredicted int

	// tokens left to evaluate
	tokens []int

	// channel to send responses over
	responses chan string

	// number of tokens to predict
	numPredict int

	samplingCtx *llama.SamplingContext

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of tokens to keep at the beginning when shifting context window
	numKeep int

	// true if an embedding is to be returned instead of generated text
	embeddingOnly bool

	doneReason string

	// Metrics
	t_start_process_prompt time.Time
	t_start_genereration   time.Time
	n_decoded              int
	n_prompt_tokens        int
}

type NewSequenceParams struct {
	numPredict     int
	stop           []string
	numKeep        int
	samplingParams *llama.SamplingParams
	embedding      bool
}
func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence {
	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
	if err != nil {
		panic(err)
	}

	if params.numKeep < 0 {
		params.numKeep = len(tokens)
	}

	// Subtracting 4 ensures that at least 1 token can be discarded during shift
	params.numKeep = min(params.numKeep, s.numCtx-4)
	params.numKeep += s.bosToken

	// truncate to fit in context window
	if len(tokens) > s.numCtx {
		slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep)
		newTokens := tokens[:params.numKeep]
		newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...)
		tokens = newTokens
	}

	var sc *llama.SamplingContext
	if params.samplingParams != nil {
		sc = llama.NewSamplingContext(*params.samplingParams)
		for _, t := range tokens {
			sc.Accept(s.lc, t, false)
		}
	}

	return &Sequence{
		tokens:          tokens,
		n_prompt_tokens: len(tokens),
		numPredict:      params.numPredict,
		responses:       make(chan string, 1),
		embedding:       make(chan []float32, 1),
		samplingCtx:     sc,
		embeddingOnly:   params.embedding,
		stop:            params.stop,
		numKeep:         params.numKeep,
	}
}
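
// Server multiplexes all in-flight requests onto a fixed number of parallel
// sequence slots that share a single llama context.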
type Server struct {
	model *llama.Model
	lc    *llama.Context
	cc    *llama.ClipContext

	batchSize int

	// parallel is the number of parallel requests to handle
	parallel int

	// seqs is the list of parallel sequences being evaluated
	// TODO (jmorganca): this can probably be moved into run()
	seqs []*Sequence

	// context window size
	numCtx int

	// does this model require a beginning of sequence token?
	bosToken int

	mu sync.Mutex

	cond *sync.Cond

	progress float32

	status string
}
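
// allNil reports whether every sequence slot is currently empty.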
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}

	return true
}
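
// shiftContext makes room in the KV cache for a sequence that has filled the
// context window: it keeps the first numKeep entries, removes half of what
// remains and shifts the surviving entries down.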
func (s *Server) shiftContext(seqIndex int) {
	seq := s.seqs[seqIndex]

	numLeft := seq.nPast - seq.numKeep
	numDiscard := numLeft / 2

	slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast,
		"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)

	s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard)
	s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard)

	seq.nPast -= numDiscard
}
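
// run is the main decode loop: it packs pending tokens from every active
// sequence into one batch, decodes it, then samples the next token for each
// sequence and streams the result until the sequence finishes.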
func (s *Server) run(ctx context.Context) {
	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
	defer batch.Free()

	// build up stop sequences as we recognize them
	// TODO (jmorganca): simplify this
	pieces := make([][]string, s.parallel)

	for {
		select {
		case <-ctx.Done():
			return
		default:
			slog.Debug("Processing batch", "seqs", len(s.seqs))
			s.mu.Lock()
			for s.allNil() {
				s.cond.Wait() // Wait until an item is added
			}
			s.mu.Unlock()

			for i, seq := range s.seqs {
				if seq == nil {
					continue
				}

				// if past the num predict limit
				if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
					seq.doneReason = "limit"
					close(seq.responses)
					s.lc.KvCacheSeqRm(i, 0, -1)
					s.seqs[i] = nil
					continue
				}

				if seq.nPast+len(seq.tokens) > s.numCtx {
					s.shiftContext(i)
				}

				if seq.t_start_process_prompt.IsZero() {
					seq.t_start_process_prompt = time.Now()
				}

				var numTokensProcessed int
				for j, t := range seq.tokens {
					// todo: make this n_batch
					if j >= s.batchSize {
						break
					}
					batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
					seq.nPast++
					numTokensProcessed++
				}
				seq.tokens = seq.tokens[numTokensProcessed:]
				seq.iBatch = batch.NumTokens() - 1
			}

			if batch.NumTokens() == 0 {
				continue
			}

			err := s.lc.Decode(batch)
			if err != nil {
				slog.Error("failed to decode batch", "error", err)
				panic("Failed to decode")
			}

			for i, seq := range s.seqs {
				if seq == nil {
					continue
				}

				// don't sample prompt processing
				if len(seq.tokens) != 0 {
					continue
				}

				// if done processing the prompt, generate an embedding and return
				if seq.embeddingOnly {
					embd := s.lc.GetEmbeddingsSeq(i)
					if embd == nil {
						embd = s.lc.GetEmbeddingsIth(seq.iBatch)
					}

					seq.embedding <- embd
					close(seq.embedding)
					s.lc.KvCacheSeqRm(i, 0, -1)
					s.seqs[i] = nil
					continue
				}

				// sample a token
				// logits := s.lc.GetLogitsIth(ibatch[i])
				// token := s.lc.SampleTokenGreedy(logits)
				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)

				seq.samplingCtx.Accept(s.lc, token, true)

				seq.n_decoded += 1
				if seq.n_decoded == 1 {
					seq.t_start_genereration = time.Now()
				}

				piece := s.model.TokenToPiece(token)

				seq.numPredicted++

				slog.Debug("sampled", "piece", piece)

				// if it's an end of sequence token, break
				// TODO: just end this sequence
				if s.model.TokenIsEog(token) {
					// TODO: end the sequence instead of quitting the pool
					s.lc.KvCacheSeqRm(i, 0, -1)

					// TODO (jmorganca): we should send this back
					// as it's important for the /api/generate context
					// seq.responses <- piece

					seq.doneReason = "stop"
					close(seq.responses)
					seq.samplingCtx.Free()
					pieces[i] = []string{}
					s.seqs[i] = nil
					continue
				}

				seq.tokens = []int{token}

				pieces[i] = append(pieces[i], piece)
				sequence := strings.Join(pieces[i], "")

				if ok, stop := findStop(sequence, seq.stop); ok {
					slog.Info("hit stop token", "stop", seq.stop)

					truncated := truncateStop(pieces[i], stop)

					for _, p := range truncated {
						seq.responses <- p
					}

					s.lc.KvCacheSeqRm(i, 0, -1)
					seq.doneReason = "stop"
					close(seq.responses)
					seq.samplingCtx.Free()
					pieces[i] = []string{}
					s.seqs[i] = nil
					continue
				}

				if containsStopSuffix(sequence, seq.stop) {
					continue
				}

				for _, p := range pieces[i] {
					seq.responses <- p
				}

				pieces[i] = []string{}
			}

			batch.Clear()
		}
	}
}
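
// Options carries the inference parameters for a request. The JSON field
// names match the llama.cpp server.cpp naming (n_keep, n_predict, ...) rather
// than the Ollama API naming.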
type Options struct {
	api.Runner

	NumKeep          int      `json:"n_keep"`
	Seed             int      `json:"seed"`
	NumPredict       int      `json:"n_predict"`
	TopK             int      `json:"top_k"`
	TopP             float32  `json:"top_p"`
	MinP             float32  `json:"min_p"`
	TFSZ             float32  `json:"tfs_z"`
	TypicalP         float32  `json:"typical_p"`
	RepeatLastN      int      `json:"repeat_last_n"`
	Temperature      float32  `json:"temperature"`
	RepeatPenalty    float32  `json:"repeat_penalty"`
	PresencePenalty  float32  `json:"presence_penalty"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	Mirostat         int      `json:"mirostat"`
	MirostatTau      float32  `json:"mirostat_tau"`
	MirostatEta      float32  `json:"mirostat_eta"`
	PenalizeNewline  bool     `json:"penalize_nl"`
	Stop             []string `json:"stop"`
}

type CompletionRequest struct {
	Prompt  string   `json:"prompt"`
	Images  []string `json:"images"`
	Grammar string   `json:"grammar"`

	Options
}

type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

type CompletionResponse struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`

	Model        string  `json:"model,omitempty"`
	Prompt       string  `json:"prompt,omitempty"`
	StoppedLimit bool    `json:"stopped_limit,omitempty"`
	PredictedN   int     `json:"predicted_n,omitempty"`
	PredictedMS  float64 `json:"predicted_ms,omitempty"`
	PromptN      int     `json:"prompt_n,omitempty"`
	PromptMS     float64 `json:"prompt_ms,omitempty"`

	Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req CompletionRequest
	req.Options = Options(api.DefaultOptions())
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")
	w.WriteHeader(http.StatusOK)

	var samplingParams llama.SamplingParams
	samplingParams.TopK = req.TopK
	samplingParams.TopP = req.TopP
	samplingParams.TfsZ = req.TFSZ
	samplingParams.TypicalP = req.TypicalP
	samplingParams.Temp = req.Temperature
	samplingParams.PenaltyRepeat = req.RepeatPenalty
	samplingParams.PenaltyFreq = req.FrequencyPenalty
	samplingParams.PenaltyPresent = req.PresencePenalty
	samplingParams.Mirostat = req.Mirostat
	samplingParams.MirostatTau = req.MirostatTau
	samplingParams.MirostatEta = req.MirostatEta
	samplingParams.PenalizeNl = req.PenalizeNewline
	samplingParams.Seed = uint32(req.Seed)
	samplingParams.Grammar = req.Grammar

	seq := s.NewSequence(req.Prompt, NewSequenceParams{
		numPredict:     req.NumPredict,
		stop:           req.Stop,
		numKeep:        req.NumKeep,
		samplingParams: &samplingParams,
		embedding:      false,
	})

	// TODO (jmorganca): add to sequence queue instead of
	// failing if a slot isn't available
	s.mu.Lock()
	for i, sq := range s.seqs {
		if sq == nil {
			s.seqs[i] = seq
			s.cond.Signal()
			break
		}
	}
	s.mu.Unlock()

	// stream the response
	for content := range seq.responses {
		if err := json.NewEncoder(w).Encode(&CompletionResponse{
			Content: content,
		}); err != nil {
			log.Println("Failed to encode result:", err)
			return
		}

		flusher, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
			return
		}

		flusher.Flush()
	}

	// Send the stop
	if err := json.NewEncoder(w).Encode(&CompletionResponse{
		Stop: true,
		Timings: Timings{
			PromptN:     seq.n_prompt_tokens,
			PromptMS:    float64(seq.t_start_genereration.Sub(seq.t_start_process_prompt).Milliseconds()),
			PredictedN:  seq.n_decoded,
			PredictedMS: float64(time.Since(seq.t_start_genereration).Milliseconds()),
		},
	}); err != nil {
		log.Println("Failed to encode result:", err)
		return
	}

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	flusher.Flush()
}

type EmbeddingRequest struct {
	Content []string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding [][]float32 `json:"embedding"`
}

// TODO (jmorganca): is it safe to do this concurrently with decoding?
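// embeddings queues one embedding-only sequence per input string and waits
// for all of the resulting vectors before responding.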
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
	var req EmbeddingRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	slog.Debug("embedding request", "content", req.Content)

	seqs := make([]*Sequence, len(req.Content))
	embeddings := make([][]float32, len(req.Content))
	var processed int
	for i, content := range req.Content {
		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
	}

	// TODO - refactor to go routines to add seq's and drain the responses
	// so we don't stall until each set is iterated through
	for processed < len(seqs) {
		s.mu.Lock()
		for i, sq := range s.seqs {
			if processed >= len(seqs) {
				break
			}

			if sq == nil {
				s.seqs[i] = seqs[processed]
				processed += 1
			}
		}
		s.cond.Signal()
		s.mu.Unlock()

		// read only the embeddings that haven't been drained yet; re-reading
		// an already closed embedding channel would overwrite the earlier
		// result with nil
		for i := 0; i < processed; i++ {
			if embeddings[i] == nil {
				embeddings[i] = <-seqs[i].embedding
			}
		}
	}

	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
		Embedding: embeddings,
	}); err != nil {
		log.Println("Failed to encode result:", err)
		return
	}
}

type HealthResponse struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress"`
}

// TODO (jmorganca): is it safe to do this concurrently with decoding?
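// health reports the server's load status and progress so callers can poll
// until the model is ready.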
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&HealthResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		log.Println("Failed to encode result:", err)
		return
	}
}

func main() {
	mpath := flag.String("model", "", "Path to model binary file")
	ppath := flag.String("mmproj", "", "Path to projector binary file")
	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
	batchSize := flag.Int("batch-size", 512, "Batch size")
	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
	flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
	numCtx := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
	lpath := flag.String("lora", "", "Path to lora layer file")
	port := flag.Int("port", 8080, "Port to expose the server on")
	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")

	// TODO not yet implemented but wired to keep the parsing aligned
	embedding := flag.Bool("embedding", false, "enable embedding vector output (default: disabled)")
	logDisable := flag.Bool("log-disable", false, "disables logging to a file")
	verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
	f32 := flag.Bool("memory-f32", false, "use f32 instead of f16 for memory key+value (default: disabled) not recommended: doubles context memory required and no measurable increase in quality")
	noMmap := flag.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
	mlock := flag.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
	tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")

	flag.Parse()

	level := slog.LevelInfo
	if *verbose {
		level = slog.LevelDebug
	}
	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level:     level,
		AddSource: true,
		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
			if attr.Key == slog.SourceKey {
				source := attr.Value.Any().(*slog.Source)
				source.File = filepath.Base(source.File)
			}
			return attr
		},
	})
	slog.SetDefault(slog.New(handler))

	// TODO actually implement...
	if *embedding {
		slog.Warn("embeddings not yet supported")
	}
	if *logDisable {
		slog.Info("ignoring --log-disable")
	}
	if *f32 {
		slog.Warn("memory-f32 not yet supported")
	}
	if *noMmap {
		slog.Warn("no-mmap not yet supported")
	}
	if *mlock {
		slog.Warn("mlock not yet supported")
	}
	if *tensorSplit != "" {
		slog.Warn("tensor-split not yet implemented")
	}

	server := &Server{
		numCtx:    *numCtx,
		batchSize: *batchSize,
		parallel:  *parallel,
		seqs:      make([]*Sequence, *parallel),
		status:    "loading",
	}

	// load the model
	llama.BackendInit()
	params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
		slog.Debug("Loading model", "progress %", math.Round(float64(progress*100)))
		server.progress = progress
	})
	server.model = llama.LoadModelFromFile(*mpath, params)

	if *lpath != "" {
		err := server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
		if err != nil {
			panic(err)
		}
	}

	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
	server.lc = llama.NewContextWithModel(server.model, ctxParams)

	if server.model.ShouldAddBOSToken() {
		server.bosToken = 1
	}

	if *ppath != "" {
		server.cc = llama.NewClipContext(*ppath)
	}

	server.cond = sync.NewCond(&server.mu)

	ctx, cancel := context.WithCancel(context.Background())
	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return
	}
	defer listener.Close()

	mux := http.NewServeMux()
	mux.HandleFunc("/embedding", server.embeddings)
	mux.HandleFunc("/completion", server.completion)
	mux.HandleFunc("/health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	server.status = "ok"

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
	}

	cancel()
}