ollama/llama/server/main.go
2024-09-03 21:15:12 -04:00

127 lines
2.4 KiB
Go

package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"net"
"net/http"
"sync"
"github.com/ollama/ollama/llama"
)
// Wire types exchanged by the streaming endpoint.
type (
	// Request is the JSON body accepted by Server.stream.
	Request struct {
		// Prompt is the text that is tokenized and then continued.
		Prompt string `json:"prompt"`
	}

	// Response is one streamed JSON object carrying a single generated piece.
	Response struct {
		// Token is the text piece of one generated token.
		Token string `json:"token"`
	}
)
// Server bundles the llama state shared by every request handled by stream.
type Server struct {
	// model is the loaded model; stream uses it to tokenize prompts and to
	// turn sampled tokens back into text.
	model *llama.Model

	// lc is the single inference context; stream holds mu while it decodes
	// with and samples from it.
	lc *llama.Context

	// batch is not read by any code in this file; stream allocates its own
	// batch per request.
	batch *llama.Batch
}
// mu guards the shared llama context: stream holds it while it decodes with
// and samples from Server.lc.
var (
	mu sync.Mutex
)
// stream is the HTTP handler for every request. It decodes a JSON Request
// from the body, tokenizes the prompt, greedily generates tokens with the
// shared context s.lc, and writes each generated piece to the client as a
// newline-terminated JSON Response, flushing after every token.
//
// Generation stops at an end-of-generation token, once the next position
// would reach maxTokens, when the client disconnects, or when a decode or
// write fails. Invalid input is rejected with 400 before streaming starts;
// once the 200 header has been sent, errors can only end the stream early.
//
// The package-level mu is held for the whole generation rather than per
// step: every request decodes into sequence 0 of the one shared context, so
// interleaving the steps of concurrent requests would mix their state.
// Releasing it via defer also guarantees it is unlocked if anything panics.
func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
	const (
		// maxTokens is passed to Tokenize and caps the position of the last
		// generated token.
		maxTokens = 2048
		// batchSize is the first argument to llama.NewBatch, presumably the
		// batch's token capacity (TODO confirm). The whole prompt is added to
		// one batch, so it also caps the prompt length.
		batchSize = 512
	)

	var request Request
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	// Check for streaming support up front instead of panicking on a type
	// assertion after the first token has been written.
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}

	// Tokenize and validate before writing the header so that failures can
	// still be reported with a proper status code.
	tokens, err := s.model.Tokenize(request.Prompt, maxTokens, true, true)
	if err != nil {
		// NOTE(review): presumably this fails mainly for prompts longer than
		// maxTokens, hence 400; confirm against llama.Tokenize.
		http.Error(w, fmt.Sprintf("tokenize prompt: %v", err), http.StatusBadRequest)
		return
	}
	if len(tokens) == 0 {
		// Nothing to decode and nothing to sample from.
		http.Error(w, "prompt produced no tokens", http.StatusBadRequest)
		return
	}
	if len(tokens) > batchSize {
		http.Error(w, fmt.Sprintf("prompt too long: %d tokens (max %d)", len(tokens), batchSize), http.StatusBadRequest)
		return
	}
	log.Println("tokens", tokens)

	// No Content-Length is set, so net/http chunks the response itself.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	enc := json.NewEncoder(w)

	mu.Lock()
	defer mu.Unlock()

	// NOTE(review): every request uses sequence 0 starting at position 0 and
	// nothing here clears that sequence between requests, so cached state
	// from an earlier request may be visible to a later one. Clear it here if
	// the llama package exposes a KV-cache API.
	// NOTE(review): this batch is never released; defer its cleanup here if
	// the llama package provides one.
	batch := llama.NewBatch(batchSize, 0, 1)

	// prompt eval
	for i, t := range tokens {
		batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
	}

	// main loop: n is the position the next sampled token will occupy.
	for n := batch.NumTokens(); n < maxTokens; n++ {
		// Stop generating for a client that has gone away.
		if r.Context().Err() != nil {
			return
		}

		if err := s.lc.Decode(batch); err != nil {
			log.Println("Failed to decode:", err)
			return
		}
		token := s.lc.SampleTokenGreedy(batch)
		if s.model.TokenIsEog(token) {
			break
		}

		if err := enc.Encode(&Response{Token: s.model.TokenToPiece(token)}); err != nil {
			log.Println("Failed to encode result:", err)
			return
		}
		flusher.Flush()

		// Feed the sampled token back in as the only entry of the next batch.
		batch.Clear()
		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
	}
}
// main parses flags, loads the model named by -model, creates one inference
// context for it, and serves Server.stream for every request on -addr.
//
// Flags:
//
//	-model  path to the model file (required)
//	-addr   TCP address to listen on (default 127.0.0.1:8080)
//
// Any setup or serve failure is fatal and exits with a non-zero status.
func main() {
	mp := flag.String("model", "", "Path to model binary file")
	addr := flag.String("addr", "127.0.0.1:8080", "Address to listen on")
	flag.Parse()

	if *mp == "" {
		log.Fatal("missing required -model flag")
	}

	// load the model
	llama.BackendInit()
	params := llama.NewModelParams()
	model := llama.LoadModelFromFile(*mp, params)
	// NOTE(review): confirm that LoadModelFromFile signals failure with nil;
	// if it always returns a non-nil wrapper this check never fires.
	if model == nil {
		log.Fatalf("failed to load model %q", *mp)
	}

	ctxParams := llama.NewContextParams()
	lc := llama.NewContextWithModel(model, ctxParams)
	if lc == nil {
		log.Fatal("failed to create context")
	}

	server := &Server{
		model: model,
		lc:    lc,
	}

	listener, err := net.Listen("tcp", *addr)
	if err != nil {
		log.Fatalf("listen on %s: %v", *addr, err)
	}

	// Serve closes the listener when it returns, so no separate Close is
	// needed.
	// NOTE(review): no timeouts are configured. A WriteTimeout would cut off
	// long generations, but a ReadHeaderTimeout is worth considering.
	httpServer := &http.Server{
		Handler: http.HandlerFunc(server.stream),
	}
	log.Println("Server listening on", *addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatalf("server error: %v", err)
	}
}