commit 6129f30479 (parent eb1aa97961)
Author: jmorganca
Date:   2024-05-16 13:52:38 -07:00

@@ -24,9 +24,28 @@ type Server struct {
	model *llama.Model
	lc    *llama.Context
	batch *llama.Batch

	// queue receives incoming requests; seqs holds the fixed set of
	// slots currently being decoded in parallel
	queue chan Sequence
	seqs  []*Sequence

	// mu guards seqs
	mu sync.Mutex
}
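
// Requests flow through the server in two stages, per the stubs below: the
// stream handler enqueues a Sequence on queue, schedule moves it into a free
// slot in seqs, and process batch-decodes whatever slots are occupied.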
// A Sequence is a single generation request: the tokenized prompt and a
// channel on which decoded tokens are streamed back to the handler.
type Sequence struct {
	prompt []llama.Token
	out    chan string
}
func schedule(parallel int, queue <-chan Sequence) {
	// fill the decode slots with sequences from the queue; once a
	// sequence finishes, remove it and pull the next one from the queue
}
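
// One possible shape for the empty stub above; a sketch only, written as a
// hypothetical Server method so it can reach s.seqs and s.mu (this commit
// keeps schedule as a free function with an empty body, and the parallel
// count is implied by len(s.seqs) here):
//
//	func (s *Server) schedule(queue <-chan Sequence) {
//		for seq := range queue {
//			seq := seq
//			placed := false
//			for !placed {
//				s.mu.Lock()
//				for i := range s.seqs {
//					if s.seqs[i] == nil { // free decode slot
//						s.seqs[i] = &seq
//						placed = true
//						break
//					}
//				}
//				s.mu.Unlock()
//				if !placed {
//					time.Sleep(10 * time.Millisecond) // all slots busy; retry
//				}
//			}
//		}
//	}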
func process() {
	// loop over the active sequences, fill a batch, decode it, sample a
	// token for each sequence, and stream each token to the right request
}
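
// A sketch of the decode loop described above, again as a hypothetical
// Server method. addToBatch, decode, and sampleSlot are stand-ins for
// whatever the llama bindings expose; only Sequence, seqs, mu, and the
// shared batch come from this commit:
//
//	func (s *Server) process() {
//		for {
//			s.mu.Lock()
//			s.batch.Clear() // stand-in: reset the shared batch
//			for i, seq := range s.seqs {
//				if seq == nil {
//					continue
//				}
//				// stand-in: add the sequence's pending tokens to the
//				// batch, tagged with slot i so logits route back to it
//				addToBatch(s.batch, seq, i)
//			}
//			s.mu.Unlock()
//			decode(s.lc, s.batch) // stand-in: one forward pass
//			for i, seq := range s.seqs {
//				if seq == nil {
//					continue
//				}
//				piece, done := sampleSlot(s.lc, i) // stand-in sampler
//				if done {
//					close(seq.out) // ends the handler's range loop
//					s.mu.Lock()
//					s.seqs[i] = nil // free the slot for schedule
//					s.mu.Unlock()
//					continue
//				}
//				seq.out <- piece
//			}
//		}
//	}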
func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
	var request Request
@@ -40,17 +59,23 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Transfer-Encoding", "chunked")
	w.WriteHeader(http.StatusOK)
	enc := json.NewEncoder(w)
	// tokenize the prompt and hand the request off to the scheduler
	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
	if err != nil {
		log.Println("failed to tokenize prompt:", err)
		return
	}
	fmt.Println("tokens", tokens)

	seq := Sequence{prompt: tokens, out: make(chan string)}
	s.queue <- seq
	batch := llama.NewBatch(512, 0, 1)

	// stream tokens back as they are produced; the loop ends when the
	// processor closes seq.out
	for str := range seq.out {
		if err := enc.Encode(&Response{Token: str}); err != nil {
			log.Println("failed to encode result:", err)
			return
		}
		w.(http.Flusher).Flush()
	}
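
// Request and Response are defined elsewhere in the file; from their use in
// this handler they are presumably close to the following (an assumption,
// not shown in this diff):
//
//	type Request struct {
//		Prompt string `json:"prompt"`
//	}
//
//	type Response struct {
//		Token string `json:"token"`
//	}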
	// prompt eval
	for i, t := range tokens {
@@ -90,6 +115,7 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
func main() {
	mp := flag.String("model", "", "Path to model binary file")
	parallel := flag.Int("parallel", 1, "Number of parallel requests to handle")
	flag.Parse()

	// load the model
@@ -105,6 +131,8 @@ func main() {
	server := &Server{
		model: model,
		lc:    lc,
		queue: make(chan Sequence, 256),
		seqs:  make([]*Sequence, *parallel),
	}

	addr := "127.0.0.1:8080"
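
// The rest of main falls outside this hunk. A guess at the startup wiring,
// using the commit's own free functions; the "/" handler path is an
// assumption:
//
//	go schedule(*parallel, server.queue)
//	go process()
//
//	http.HandleFunc("/", server.stream)
//	log.Println("listening on", addr)
//	log.Fatal(http.ListenAndServe(addr, nil))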