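// Package main implements a minimal HTTP server that streams tokens from
// a llama.cpp model via ollama's Go bindings: each request's prompt is
// tokenized, evaluated in a batch, and completed one greedily sampled
// token at a time.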
package main

import (
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"

	"github.com/ollama/ollama/llama"
)
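// Request is the JSON body accepted by the streaming handler.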
type Request struct {
	Prompt string `json:"prompt"`
}
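// Response carries a single generated token back to the client; the
// stream is a sequence of these objects, one per line.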
type Response struct {
	Token string `json:"token"`
}
// Server bundles a loaded model with its inference context; each request
// builds its own batch in the handler.
type Server struct {
	model *llama.Model
	lc    *llama.Context
}
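// mu serializes access to the llama context: a llama.cpp context is not
// safe for concurrent Decode calls, so only one request may generate at
// a time.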
var mu sync.Mutex
func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
	var request Request
	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
		http.Error(w, "Bad request", http.StatusBadRequest)
		return
	}

	// Tokenize the prompt before any headers are written, so that a failure
	// can still be reported with an error status.
	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
	if err != nil {
		http.Error(w, "Failed to tokenize prompt", http.StatusInternalServerError)
		return
	}

	// debug output: the prompt token ids
	fmt.Println("tokens", tokens)

	// Set the headers to indicate streaming (net/http chunks responses
	// without a Content-Length automatically)
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Transfer-Encoding", "chunked")
	w.WriteHeader(http.StatusOK)

	enc := json.NewEncoder(w)
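	// A batch carries the tokens for one Decode call; the arguments mirror
	// llama_batch_init: room for 512 tokens, no embedding inputs, one
	// sequence at most.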
	batch := llama.NewBatch(512, 0, 1)

	// prompt eval: add every prompt token at its position in sequence 0
	// (note: the batch holds at most 512 tokens, so a longer prompt would
	// overflow it)
	for i, t := range tokens {
		batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
	}
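	// Main generation loop: each pass decodes the pending batch, greedily
	// samples one token from the logits, streams it to the client, then
	// re-submits that token as a single-token batch at position n. It stops
	// at an end-of-sequence token or after 2048 positions.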
	for n := batch.NumTokens(); n < 2048; n++ {
		mu.Lock()
		err = s.lc.Decode(batch)
		if err != nil {
			// unlock before returning so later requests are not deadlocked
			mu.Unlock()
			log.Println("Failed to decode:", err)
			return
		}

		// sample the highest-probability token from the latest logits
		token := s.lc.SampleTokenGreedy(batch)
		mu.Unlock()

		// if it's an end of sequence token, break
		if s.model.TokenIsEog(token) {
			break
		}

		// convert the token id to its text piece
		str := s.model.TokenToPiece(token)

		if err := enc.Encode(&Response{Token: str}); err != nil {
			log.Println("Failed to encode result:", err)
			return
		}

		// flush after every token so the client sees output immediately
		w.(http.Flusher).Flush()

		// feed the sampled token back in as the next single-token batch
		batch.Clear()
		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
	}
}
func main() {
	mp := flag.String("model", "", "Path to model binary file")
	flag.Parse()

	// load the model
	llama.BackendInit()
	params := llama.NewModelParams()
	model := llama.LoadModelFromFile(*mp, params)
	ctxParams := llama.NewContextParams()
	lc := llama.NewContextWithModel(model, ctxParams)
	// LoadModelFromFile does not return an error here; a bad model path
	// typically surfaces as a failure to create the context.
	if lc == nil {
		panic("Failed to create context")
	}

	server := &Server{
		model: model,
		lc:    lc,
	}
	addr := "127.0.0.1:8080"
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		log.Println("Listen error:", err)
		return
	}
	defer listener.Close()

	// route every request to the streaming handler
	httpServer := http.Server{
		Handler: http.HandlerFunc(server.stream),
	}

	log.Println("Server listening on", addr)
	if err := httpServer.Serve(listener); err != nil {
		log.Fatal("server error:", err)
	}
}
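// Example usage (the model path and prompt are illustrative):
//
//	go run . -model /path/to/model.gguf
//	curl -s http://127.0.0.1:8080 -d '{"prompt":"Why is the sky blue?"}'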