forked from third-party-mirrors/ollama
wip...
This commit is contained in:
parent
eb1aa97961
commit
6129f30479
@ -24,9 +24,28 @@ type Server struct {
|
||||
model *llama.Model
|
||||
lc *llama.Context
|
||||
batch *llama.Batch
|
||||
|
||||
queue chan Sequence
|
||||
seqs []*Sequence
|
||||
|
||||
// mu guards seqs
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// mu is a package-level mutex.
// NOTE(review): Server also declares its own mu field that guards seqs — confirm whether this package-level mutex is still needed.
var mu sync.Mutex
|
||||
// Sequence is a single request placed on the server queue by the stream handler.
type Sequence struct {
|
||||
// prompt holds the tokens of the request prompt; the stream handler fills it from the result of Tokenize.
prompt []llama.Token
|
||||
// out is the channel the stream handler receives output strings from, forwarding each one to the HTTP client.
// NOTE(review): the visible stream code builds the Sequence without setting out; receiving from a nil channel blocks forever — confirm it is created before use.
out chan string
|
||||
}
|
||||
|
||||
// schedule is a placeholder with no implementation yet; its body contains only
// comments describing the intended behavior.
//
// parallel is presumably the maximum number of sequences handled at once
// (TODO confirm against the "parallel" flag); queue is the receive side of the
// channel of pending sequences.
func schedule(parallel int, queue <-chan Sequence) {
|
||||
// Fill sequences from the queue
|
||||
|
||||
// once a sequence finishes, remove it and add a new one from the queue
|
||||
}
|
||||
|
||||
// process is a placeholder with no implementation yet; the comment in its body
// describes the work it is intended to do.
func process() {
|
||||
// loop through the sequences, fill a batch, decode and sample tokens, responding to appropriate requests
|
||||
}
|
||||
|
||||
func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
|
||||
var request Request
|
||||
@ -40,17 +59,23 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
|
||||
// main loop
|
||||
tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println("tokens", tokens)
|
||||
seq := Sequence{prompt: tokens}
|
||||
s.queue <- seq
|
||||
|
||||
batch := llama.NewBatch(512, 0, 1)
|
||||
// listen for the sequence to finish
|
||||
for {
|
||||
str := <-seq.out
|
||||
if err := json.NewEncoder(w).Encode(&Response{Token: str}); err != nil {
|
||||
log.Println("Failed to encode result:", err)
|
||||
return
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
|
||||
// prompt eval
|
||||
for i, t := range tokens {
|
||||
@ -90,6 +115,7 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func main() {
|
||||
mp := flag.String("model", "", "Path to model binary file")
|
||||
parallel := flag.Int("parallel", 1, "Number of parallel requests to handle")
|
||||
flag.Parse()
|
||||
|
||||
// load the model
|
||||
@ -105,6 +131,8 @@ func main() {
|
||||
server := &Server{
|
||||
model: model,
|
||||
lc: lc,
|
||||
queue: make(chan Sequence, 256),
|
||||
seqs: make([]*Sequence, *parallel),
|
||||
}
|
||||
|
||||
addr := "127.0.0.1:8080"
|
||||
|
Loading…
x
Reference in New Issue
Block a user