Merge ca04f2a0edebbac583baf0d8fff68e72eb68678e into 67691e410db7a50b07a64858820b14de9aa91314

commit 16c12a2278
Author: Jesse Gross (committed by GitHub)
Date:   2024-11-14 12:43:32 +01:00


@@ -20,6 +20,8 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
 )
@@ -203,38 +205,51 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 }
 
 type Server struct {
-	model *llama.Model
-	lc    *llama.Context
+	// is the server ready to process requests?
+	// protects access to model and image
+	ready sync.WaitGroup
 
-	// required for image embeddings
+	// loaded model
+	model *llama.Model
+
+	// image model context for multi-modal models
 	image *ImageContext
 
+	// status for external health reporting - loading, ready to serve, etc.
+	status ServerStatus
+
+	// current progress on loading the model
+	progress float32
+
+	// number of simultaneous requests to handle
+	parallel int
+
+	// maximum number of elements in a batch (per sequence)
 	// TODO (jmorganca): make this n_batch
 	batchSize int
 
-	// parallel is the number of parallel requests to handle
-	parallel int
+	// protects access to everything below this line
+	// this is context state needed for decoding
+	mu sync.Mutex
 
-	// seqs is the list of parallel sequences being evaluated
-	// TODO (jmorganca): this can probably be moved into run()
+	// indicates that data is ready for processing
+	cond *sync.Cond
+
+	// decoding state
+	lc *llama.Context
+
+	// the list of simultaneous sequences being evaluated
 	seqs []*Sequence
 
+	// seqs can have a maximum of parallel entries, which
+	// is enforced by seqsSem
+	seqsSem *semaphore.Weighted
+
 	// KV cache
 	cache *InputCache
 
 	// next sequence for prompt processing to avoid starvation
 	nextSeq int
-
-	// is the server ready to process requests?
-	ready sync.WaitGroup
-
-	mu sync.Mutex
-
-	cond *sync.Cond
-
-	progress float32
-
-	status ServerStatus
 }
 
 func (s *Server) allNil() bool {
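The reorganized struct pairs mu ("protects access to everything below this line") with cond ("indicates that data is ready for processing"). Below is a minimal standalone sketch of that mutex/condition-variable pattern; the names (worker, pending) are illustrative and not taken from the runner.

package main

import "sync"

type worker struct {
	mu      sync.Mutex // protects everything below
	cond    *sync.Cond // signals that pending has data
	pending []string   // shared state guarded by mu
}

func newWorker() *worker {
	w := &worker{}
	w.cond = sync.NewCond(&w.mu)
	return w
}

// submit appends an item and wakes the processing loop.
func (w *worker) submit(item string) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.pending = append(w.pending, item)
	w.cond.Signal()
}

// next blocks on the condition variable until work is available.
func (w *worker) next() string {
	w.mu.Lock()
	defer w.mu.Unlock()
	for len(w.pending) == 0 {
		w.cond.Wait() // atomically releases mu while sleeping
	}
	item := w.pending[0]
	w.pending = w.pending[1:]
	return item
}

func main() {
	w := newWorker()
	go w.submit("prompt")
	_ = w.next()
}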
@@ -609,8 +624,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jmorganca): add to sequence queue instead of
-	// failing if a slot isn't available
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@@ -693,7 +713,13 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
+	// Ensure that a place to put the sequence is available
+	if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
+		slog.Error("Failed to acquire semaphore", "error", err)
+		return
+	}
+	defer s.seqsSem.Release(1)
+
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
@ -848,6 +874,7 @@ func main() {
batchSize: *batchSize, batchSize: *batchSize,
parallel: *parallel, parallel: *parallel,
seqs: make([]*Sequence, *parallel), seqs: make([]*Sequence, *parallel),
seqsSem: semaphore.NewWeighted(int64(*parallel)),
status: ServerStatusLoadingModel, status: ServerStatusLoadingModel,
} }
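The Acquire/Release pairing in both handlers is the standard admission-control pattern from golang.org/x/sync/semaphore: a request that arrives while all parallel slots are busy now waits in Acquire instead of failing, and canceling the request context aborts the wait. A self-contained sketch of the same pattern, with assumed names (maxParallel, request IDs) used only for illustration:

package main

import (
	"context"
	"fmt"
	"sync"

	"golang.org/x/sync/semaphore"
)

func main() {
	const maxParallel = 2 // stands in for the *parallel flag
	sem := semaphore.NewWeighted(int64(maxParallel))

	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			// Blocks until a slot frees up; passing a request-scoped
			// context (like r.Context() in the handlers above) would
			// abort the wait if the client disconnects.
			if err := sem.Acquire(context.Background(), 1); err != nil {
				return
			}
			defer sem.Release(1)
			fmt.Println("processing request", id)
		}(i)
	}
	wg.Wait()
}

At most maxParallel goroutines print concurrently; the rest queue inside Acquire, which is exactly the behavior the removed TODO comments asked for.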