The health endpoint needs a little more work to show progress as Ollama expects, but we can at least return the right status and leave comments for the future.
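For example, with the defaults in main() the endpoint is served at http://127.0.0.1:8080/health and, once the model has loaded, returns a body shaped like {"status":"ok","progress":1} (the exact progress value is simply whatever the model-loading callback last reported).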
package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"log/slog"
	"math"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llama"
)

type Sequence struct {
	// number of tokens evaluated
	nPast int

	// batch index
	iBatch int

	// number of tokens predicted so far
	numPredicted int

	// tokens left to evaluate
	tokens []int

	// tokens that have been generated but not returned yet (e.g. for stop sequences)
	// TODO (jmorganca): simplify this
	pendingResponses []string

	// channel to send responses over
	responses chan string

	// channel to stop decoding (such as if the remote connection is closed)
	quit chan bool

	// number of tokens to predict
	numPredict int

	samplingCtx *llama.SamplingContext

	// channel to send back the embedding if embedding only
	embedding chan []float32

	// stop sequences
	stop []string

	// number of tokens to keep at the beginning when shifting context window
	numKeep int

	// true if an embedding is to be returned instead of text generation
	embeddingOnly bool

	doneReason string

	// Metrics
	startProcessingTime time.Time
	startGenerationTime time.Time
	numDecoded          int
	numPromptTokens     int
}

type NewSequenceParams struct {
	numPredict     int
	stop           []string
	numKeep        int
	samplingParams *llama.SamplingParams
	embedding      bool
}

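// NewSequence tokenizes the prompt, truncates it to fit the per-sequence
// context window (keeping the first numKeep tokens), and primes an optional
// sampling context with the prompt tokens.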
func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence {
	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
	if err != nil {
		panic(err)
	}

	if params.numKeep < 0 {
		params.numKeep = len(tokens)
	}
	// Subtracting 4 ensures that at least 1 token can be discarded during shift
	params.numKeep = min(params.numKeep, s.numCtx-4)
	params.numKeep += s.bosToken

	// truncate to fit in context window
	if len(tokens) > s.numCtx {
		slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep)
		newTokens := tokens[:params.numKeep]
		newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...)
		tokens = newTokens
	}

	var sc *llama.SamplingContext
	if params.samplingParams != nil {
		sc = llama.NewSamplingContext(*params.samplingParams)
		for _, t := range tokens {
			sc.Accept(s.lc, t, false)
		}
	}

	return &Sequence{
		tokens:           tokens,
		numPromptTokens:  len(tokens),
		numPredict:       params.numPredict,
		pendingResponses: make([]string, 0),
		responses:        make(chan string, 1),
		quit:             make(chan bool, 1),
		embedding:        make(chan []float32, 1),
		samplingCtx:      sc,
		embeddingOnly:    params.embedding,
		stop:             params.stop,
		numKeep:          params.numKeep,
	}
}

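// Server owns the loaded llama model and context, the optional clip
// (projector) context, and the fixed set of parallel decode slots in seqs.
// Access to seqs is guarded by mu/cond: the HTTP handlers place new
// sequences into free slots and run()/processBatch() drains them.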
type Server struct {
	model *llama.Model
	lc    *llama.Context
	cc    *llama.ClipContext

	batchSize int

	// parallel is the number of parallel requests to handle
	parallel int

	// seqs is the list of parallel sequences being evaluated
	// TODO (jmorganca): this can probably be moved into run()
	seqs []*Sequence

	// context window size
	numCtx int

	// bosToken is 1 if the model requires a beginning of sequence token, 0 otherwise
	bosToken int

	mu sync.Mutex

	cond *sync.Cond

	progress float32

	status string
}

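// allNil reports whether every decode slot is currently free.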
func (s *Server) allNil() bool {
	for _, item := range s.seqs {
		if item != nil {
			return false
		}
	}
	return true
}

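// shiftContext makes room in a full context window by discarding roughly half
// of the tokens past numKeep from the KV cache and shifting the remainder
// down. For example, with numKeep=24 and nPast=2040, numDiscard is 1008 and
// nPast drops to 1032 after the shift.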
func (s *Server) shiftContext(seqIndex int) {
	seq := s.seqs[seqIndex]

	numLeft := seq.nPast - seq.numKeep
	numDiscard := numLeft / 2

	slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast,
		"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)

	s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard)
	s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard)

	seq.nPast -= numDiscard
}

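// incompleteUnicode reports whether token ends in a truncated UTF-8 sequence,
// e.g. only the lead byte of a multi-byte character. Pieces that end this way
// are held back until the remaining bytes arrive in a later token.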
func incompleteUnicode(token string) bool {
	incomplete := false

	// check if there is an incomplete UTF-8 character at the end
	for i := 1; i < 5 && i <= len(token); i++ {
		c := token[len(token)-i]

		if (c & 0xc0) == 0x80 {
			// continuation byte: 10xxxxxx
			continue
		}

		if (c & 0xe0) == 0xc0 {
			// 2-byte character: 110xxxxx ...
			incomplete = i < 2
		} else if (c & 0xf0) == 0xe0 {
			// 3-byte character: 1110xxxx ...
			incomplete = i < 3
		} else if (c & 0xf8) == 0xf0 {
			// 4-byte character: 11110xxx ...
			incomplete = i < 4
		}

		// else 1-byte character or invalid byte
		break
	}

	return incomplete
}

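// removeSequence finishes a sequence: it records the done reason, closes its
// response and embedding channels, clears its KV cache entries, and frees the
// slot. It is called from processBatch while s.mu is held.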
func (s *Server) removeSequence(seqIndex int, reason string) {
	seq := s.seqs[seqIndex]

	seq.doneReason = reason
	close(seq.responses)
	close(seq.embedding)
	seq.pendingResponses = []string{}
	seq.samplingCtx.Free()
	s.lc.KvCacheSeqRm(seqIndex, 0, -1)
	s.seqs[seqIndex] = nil
}

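// run drives decoding: it repeatedly builds and processes batches until the
// context is cancelled.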
func (s *Server) run(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		default:
			s.processBatch()
		}
	}
}

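// processBatch performs one decode step: it waits until at least one slot is
// occupied, batches up to batchSize pending tokens per sequence, decodes the
// batch, then samples one new token for each sequence that has finished its
// prompt, streaming pieces back on that sequence's responses channel.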
func (s *Server) processBatch() {
	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
	defer batch.Free()

	s.mu.Lock()
	for s.allNil() {
		s.cond.Wait() // Wait until an item is added
	}
	defer s.mu.Unlock()

	slog.Debug("Processing batch", "seqs", len(s.seqs))

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// if past the num predict limit
		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
			s.removeSequence(i, "limit")
			continue
		}

		if seq.nPast+len(seq.tokens) > s.numCtx {
			s.shiftContext(i)
		}

		if seq.startProcessingTime.IsZero() {
			seq.startProcessingTime = time.Now()
		}

		var numTokensProcessed int
		for j, t := range seq.tokens {
			// todo: make this n_batch
			if j >= s.batchSize {
				break
			}
			batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
			seq.nPast++
			numTokensProcessed++
		}
		seq.tokens = seq.tokens[numTokensProcessed:]
		seq.iBatch = batch.NumTokens() - 1
	}

	if batch.NumTokens() == 0 {
		return
	}

	err := s.lc.Decode(batch)
	if err != nil {
		slog.Error("failed to decode batch", "error", err)
		panic("Failed to decode")
	}

	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}

		// don't sample prompt processing
		if len(seq.tokens) != 0 {
			continue
		}

		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			embd := s.lc.GetEmbeddingsSeq(i)
			if embd == nil {
				embd = s.lc.GetEmbeddingsIth(seq.iBatch)
			}

			seq.embedding <- embd
			s.removeSequence(i, "")
			continue
		}

		// sample a token
		token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)

		seq.samplingCtx.Accept(s.lc, token, true)
		seq.numDecoded += 1
		if seq.numDecoded == 1 {
			seq.startGenerationTime = time.Now()
		}
		piece := s.model.TokenToPiece(token)

		seq.numPredicted++

		slog.Debug("sampled", "piece", piece)

		// if it's an end of sequence token, break
		// TODO: just end this sequence
		if s.model.TokenIsEog(token) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece

			// TODO: end the sequence instead of quitting the pool
			s.removeSequence(i, "stop")
			continue
		}

		seq.tokens = []int{token}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")

		if incompleteUnicode(sequence) {
			continue
		}

		if ok, stop := findStop(sequence, seq.stop); ok {
			slog.Info("hit stop token", "stop", seq.stop)

			truncated := truncateStop(seq.pendingResponses, stop)

			// a bare break here would only exit the select, so label the
			// loop to stop sending once the client has gone away
		flushStop:
			for _, p := range truncated {
				select {
				case seq.responses <- p:
				case <-seq.quit:
					break flushStop
				}
			}

			s.removeSequence(i, "stop")
			continue
		}

		if containsStopSuffix(sequence, seq.stop) {
			continue
		}

	flush:
		for _, p := range seq.pendingResponses {
			select {
			case seq.responses <- p:
			case <-seq.quit:
				s.removeSequence(i, "connection")
				// removeSequence closes the responses channel, so stop
				// sending to avoid a send on a closed channel
				break flush
			}
		}

		seq.pendingResponses = []string{}
	}
}

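// Options carries the sampling and generation options for a request; the
// JSON field names mirror the llama.cpp server API (n_predict, tfs_z, ...).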
type Options struct {
	api.Runner

	NumKeep          int      `json:"n_keep"`
	Seed             int      `json:"seed"`
	NumPredict       int      `json:"n_predict"`
	TopK             int      `json:"top_k"`
	TopP             float32  `json:"top_p"`
	MinP             float32  `json:"min_p"`
	TFSZ             float32  `json:"tfs_z"`
	TypicalP         float32  `json:"typical_p"`
	RepeatLastN      int      `json:"repeat_last_n"`
	Temperature      float32  `json:"temperature"`
	RepeatPenalty    float32  `json:"repeat_penalty"`
	PresencePenalty  float32  `json:"presence_penalty"`
	FrequencyPenalty float32  `json:"frequency_penalty"`
	Mirostat         int      `json:"mirostat"`
	MirostatTau      float32  `json:"mirostat_tau"`
	MirostatEta      float32  `json:"mirostat_eta"`
	PenalizeNewline  bool     `json:"penalize_nl"`
	Stop             []string `json:"stop"`
}

type CompletionRequest struct {
	Prompt  string   `json:"prompt"`
	Images  []string `json:"images"`
	Grammar string   `json:"grammar"`

	Options
}

type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
	PromptN     int     `json:"prompt_n"`
	PromptMS    float64 `json:"prompt_ms"`
}

type CompletionResponse struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`

	Model        string  `json:"model,omitempty"`
	Prompt       string  `json:"prompt,omitempty"`
	StoppedLimit bool    `json:"stopped_limit,omitempty"`
	PredictedN   int     `json:"predicted_n,omitempty"`
	PredictedMS  float64 `json:"predicted_ms,omitempty"`
	PromptN      int     `json:"prompt_n,omitempty"`
	PromptMS     float64 `json:"prompt_ms,omitempty"`

	Timings Timings `json:"timings"`
}

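// completion handles POST /completion: it builds a sequence from the request,
// places it in a free slot, and streams newline-delimited CompletionResponse
// chunks until the sequence stops. For example, a minimal request body might
// look like (illustrative values only):
//
//	{"prompt": "why is the sky blue?", "temperature": 0.8, "stop": ["\n"]}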
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
	var req CompletionRequest
	req.Options = Options(api.DefaultOptions())
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	// Set the headers to indicate streaming
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")
	w.WriteHeader(http.StatusOK)

	var samplingParams llama.SamplingParams
	samplingParams.TopK = req.TopK
	samplingParams.TopP = req.TopP
	samplingParams.MinP = req.MinP
	samplingParams.TfsZ = req.TFSZ
	samplingParams.TypicalP = req.TypicalP
	samplingParams.Temp = req.Temperature
	samplingParams.RepeatLastN = req.RepeatLastN
	samplingParams.PenaltyRepeat = req.RepeatPenalty
	samplingParams.PenaltyFreq = req.FrequencyPenalty
	samplingParams.PenaltyPresent = req.PresencePenalty
	samplingParams.Mirostat = req.Mirostat
	samplingParams.MirostatTau = req.MirostatTau
	samplingParams.MirostatEta = req.MirostatEta
	samplingParams.PenalizeNl = req.PenalizeNewline
	samplingParams.Seed = uint32(req.Seed)
	samplingParams.Grammar = req.Grammar

	seq := s.NewSequence(req.Prompt, NewSequenceParams{
		numPredict:     req.NumPredict,
		stop:           req.Stop,
		numKeep:        req.NumKeep,
		samplingParams: &samplingParams,
		embedding:      false,
	})

	// TODO (jmorganca): add to sequence queue instead of
	// failing if a slot isn't available
	s.mu.Lock()
	for i, sq := range s.seqs {
		if sq == nil {
			s.seqs[i] = seq
			s.cond.Signal()
			break
		}
	}
	s.mu.Unlock()

	// stream the response
	for content := range seq.responses {
		if err := json.NewEncoder(w).Encode(&CompletionResponse{
			Content: content,
		}); err != nil {
			log.Println("Failed to encode result:", err)
			close(seq.quit)
			return
		}

		flusher, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
			close(seq.quit)
			return
		}

		flusher.Flush()
	}

	// Send the stop
	if err := json.NewEncoder(w).Encode(&CompletionResponse{
		Stop: true,
		Timings: Timings{
			PromptN:     seq.numPromptTokens,
			PromptMS:    float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
			PredictedN:  seq.numDecoded,
			PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
		},
	}); err != nil {
		log.Println("Failed to encode result:", err)
		return
	}

	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
		return
	}

	flusher.Flush()
}

type EmbeddingRequest struct {
	Content []string `json:"content"`
}

type EmbeddingResponse struct {
	Embedding [][]float32 `json:"embedding"`
}

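// embeddings handles POST /embedding: each entry in content becomes an
// embedding-only sequence, and the handler blocks until every embedding has
// been produced before writing a single EmbeddingResponse.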
// TODO (jmorganca): is it safe to do this concurrently with decoding?
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
	var req EmbeddingRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	w.Header().Set("Content-Type", "application/json")

	slog.Debug("embedding request", "content", req.Content)
	seqs := make([]*Sequence, len(req.Content))
	embeddings := make([][]float32, len(req.Content))
	var processed int
	for i, content := range req.Content {
		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
	}

	// TODO - refactor to go routines to add seq's and drain the responses
	// so we don't stall until each set is iterated through
	var drained int
	for processed < len(seqs) {
		s.mu.Lock()
		for i, sq := range s.seqs {
			if processed >= len(seqs) {
				break
			}
			if sq == nil {
				s.seqs[i] = seqs[processed]
				processed += 1
			}
		}
		s.cond.Signal()
		s.mu.Unlock()

		// drain only the sequences scheduled so far that haven't been read
		// yet; each embedding channel delivers a single result and is then
		// closed, so reading it a second time would return nil
		for i := drained; i < processed; i++ {
			embeddings[i] = <-seqs[i].embedding
		}
		drained = processed
	}

	if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
		Embedding: embeddings,
	}); err != nil {
		log.Println("Failed to encode result:", err)
		return
	}
}

type HealthResponse struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress"`
}

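// health handles GET /health, reporting the loading status and progress so
// the Ollama server can tell when the runner is ready. Once the model has
// loaded, the response looks roughly like {"status":"ok","progress":1},
// where progress is whatever the model-loading callback last reported.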
// TODO (jmorganca): is it safe to do this concurrently with updating status?
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&HealthResponse{
		Status:   s.status,
		Progress: s.progress,
	}); err != nil {
		log.Println("Failed to encode result:", err)
		return
	}
}

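// main parses llama.cpp-server-style flags, loads the model (plus optional
// LoRA and clip projector), starts the decode loop, and serves /completion,
// /embedding and /health on the requested port.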
func main() {
	mpath := flag.String("model", "", "Path to model binary file")
	ppath := flag.String("mmproj", "", "Path to projector binary file")
	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
	batchSize := flag.Int("batch-size", 512, "Batch size")
	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
	flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
	kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
	lpath := flag.String("lora", "", "Path to lora layer file")
	port := flag.Int("port", 8080, "Port to expose the server on")
	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")

	// TODO: these flags are not yet implemented but are wired up so that
	// flag parsing stays aligned
	embedding := flag.Bool("embedding", false, "enable embedding vector output (default: disabled)")
	logDisable := flag.Bool("log-disable", false, "disables logging to a file")
	verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
	f32 := flag.Bool("memory-f32", false, "use f32 instead of f16 for memory key+value (default: disabled) not recommended: doubles context memory required and no measurable increase in quality")
	noMmap := flag.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
	mlock := flag.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
	tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")

	flag.Parse()
	level := slog.LevelInfo
	if *verbose {
		level = slog.LevelDebug
	}
	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level:     level,
		AddSource: true,
		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
			if attr.Key == slog.SourceKey {
				source := attr.Value.Any().(*slog.Source)
				source.File = filepath.Base(source.File)
			}
			return attr
		},
	})
	slog.SetDefault(slog.New(handler))

	// TODO actually implement...
	if *embedding {
		slog.Warn("embeddings not yet supported")
	}
	if *logDisable {
		slog.Info("ignoring --log-disable")
	}
	if *f32 {
		slog.Warn("memory-f32 not yet supported")
	}
	if *noMmap {
		slog.Warn("no-mmap not yet supported")
	}
	if *mlock {
		slog.Warn("mlock not yet supported")
	}
	if *tensorSplit != "" {
		slog.Warn("tensor-split not yet implemented")
	}

	server := &Server{
		numCtx:    *kvSize / *parallel,
		batchSize: *batchSize,
		parallel:  *parallel,
		seqs:      make([]*Sequence, *parallel),
		status:    "loading model",
	}

	// TODO (jessegross): This should be in a separate goroutine so we can report progress,
	// otherwise Ollama can timeout for large model loads
	// load the model
	llama.BackendInit()
	params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
		slog.Debug("Loading model", "progress %", math.Round(float64(progress*100)))
		server.progress = progress
	})
	server.model = llama.LoadModelFromFile(*mpath, params)

	if *lpath != "" {
		err := server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
		if err != nil {
			panic(err)
		}
	}

	ctxParams := llama.NewContextParams(*kvSize, *threads, *flashAttention)
	server.lc = llama.NewContextWithModel(server.model, ctxParams)

	if server.model.ShouldAddBOSToken() {
		server.bosToken = 1
	}

	if *ppath != "" {
		server.cc = llama.NewClipContext(*ppath)
	}

	server.cond = sync.NewCond(&server.mu)

	ctx, cancel := context.WithCancel(context.Background())
	go server.run(ctx)

	addr := "127.0.0.1:" + strconv.Itoa(*port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		fmt.Println("Listen error:", err)
		return
	}
	defer listener.Close()

	mux := http.NewServeMux()
	mux.HandleFunc("/embedding", server.embeddings)
	mux.HandleFunc("/completion", server.completion)
	mux.HandleFunc("/health", server.health)

	httpServer := http.Server{
		Handler: mux,
	}

	server.status = "ok"

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
	}

	cancel()
}