Compare commits


13 Commits

Author SHA1 Message Date
Roy Han 30823ec925 update readme 2024-08-09 11:32:27 -07:00
Roy Han 89f3bae306 cli 2024-08-09 11:04:26 -07:00
Roy Han ad7e822883 audio processing error prop 2024-08-07 14:05:22 -07:00
Roy Han d503f04b32 expiration 2024-08-07 13:04:57 -07:00
Roy Han 8ccf543c53 chat doc 2024-08-07 13:04:57 -07:00
Roy Han 75ad6309b4 chat support 2024-08-07 13:04:57 -07:00
Roy Han a5181a8c51 error handling 2024-08-07 13:04:57 -07:00
Roy Han 2a9feb0707 model flexibility 2024-08-07 13:04:57 -07:00
Roy Han e4d35198a2 transcribe 2024-08-07 13:04:57 -07:00
Roy Han 17f9dc6d08 save whisper port 2024-08-07 13:04:57 -07:00
Roy Han 97d9dffa80 err check 2024-08-07 13:04:57 -07:00
Roy Han 65483180b9 working poc 2024-08-07 13:04:57 -07:00
Roy Han 1ac92eae7c submodule 2024-08-07 13:04:57 -07:00
11 changed files with 596 additions and 10 deletions

5
.gitmodules vendored
View File

@ -1,4 +1,7 @@
[submodule "llama.cpp"]
path = llm/llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
shallow = true
shallow = true
[submodule "llm/whisper.cpp"]
path = llm/whisper.cpp
url = git@github.com:ggerganov/whisper.cpp.git

api/types.go
View File

@ -36,6 +36,13 @@ func (e StatusError) Error() string {
// ImageData represents the raw binary data of an image file.
type ImageData []byte
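// WhisperRequest configures optional speech-to-text processing for a request.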
type WhisperRequest struct {
Model string `json:"model,omitempty"`
Audio string `json:"audio,omitempty"`
Transcribe bool `json:"transcribe,omitempty"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
}
// GenerateRequest describes a request sent by [Client.Generate]. While you
// have to specify the Model and Prompt fields, all the other fields have
// reasonable defaults for basic uses.
@ -80,6 +87,8 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
Speech *WhisperRequest `json:"speech,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@ -105,6 +114,10 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
Speech *WhisperRequest `json:"speech,omitempty"`
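// RunSpeech enables audio processing even when Speech is nil (defaults apply).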
RunSpeech bool `json:"run_speech,omitempty"`
}
type Tools []Tool
@ -127,6 +140,7 @@ type Message struct {
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Audio string `json:"audio,omitempty"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
@ -450,6 +464,11 @@ type GenerateResponse struct {
Metrics
}
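// WhisperCompletion is the response returned by the whisper server's /inference endpoint.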
type WhisperCompletion struct {
Text string `json:"text"`
Error string `json:"error,omitempty"`
}
// ModelDetails provides details about a model.
type ModelDetails struct {
ParentModel string `json:"parent_model"`

cmd/cmd.go
View File

@ -38,6 +38,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/recorder"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@ -380,6 +381,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
speech, err := cmd.Flags().GetBool("speech")
if err != nil {
return err
}
if speech {
return generateInteractiveAudio(cmd, opts)
}
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
@ -862,6 +871,7 @@ type runOptions struct {
Options map[string]interface{}
MultiModal bool
KeepAlive *api.Duration
Audio bool
}
type displayResponseState struct {
@ -970,6 +980,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req.KeepAlive = opts.KeepAlive
}
if opts.Audio {
req.RunSpeech = true
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
@ -1055,6 +1069,30 @@ func generate(cmd *cobra.Command, opts runOptions) error {
KeepAlive: opts.KeepAlive,
}
speech, err := cmd.Flags().GetBool("speech")
if err != nil {
return err
}
// create temp wav file with the recorder package
if speech {
tempFile, err := os.CreateTemp("", "recording-*.wav")
if err != nil {
return err
}
defer os.Remove(tempFile.Name())
fmt.Print("Speech Mode\n\n")
err = recorder.RecordAudio(tempFile)
if err != nil {
return err
}
request.Speech = &api.WhisperRequest{
Audio: tempFile.Name(),
}
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
@ -1262,6 +1300,7 @@ func NewCLI() *cobra.Command {
RunE: RunHandler,
}
runCmd.Flags().Bool("speech", false, "Speech to text mode")
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")

cmd/interactive.go
View File

@ -20,6 +20,7 @@ import (
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/recorder"
"github.com/ollama/ollama/types/errtypes"
)
@ -51,6 +52,40 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
}
func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error {
for {
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
// create temp wav file with the recorder package
tempFile, err := os.CreateTemp("", "recording-*.wav")
if err != nil {
return err
}
defer os.Remove(tempFile.Name())
err = recorder.RecordAudio(tempFile)
if err != nil {
return err
}
p.StopAndClear()
newMessage := api.Message{Role: "user", Audio: tempFile.Name()}
opts.Audio = true
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
}
}
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")

83
docs/speech.md Normal file
View File

@ -0,0 +1,83 @@
# Speech to Text Prototype
## To run
`make {/path/to/whisper.cpp/server}`
- replace the `whisperServer` path in `routes.go` with the path to the built server binary
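For example, building from the submodule added in this change (`server` is whisper.cpp's example server make target; paths assumed):
```
cd llm/whisper.cpp
make server
```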
## CLI
`./ollama run llama3 [PROMPT] --speech`
- records audio and processes it together with the provided prompt
`./ollama run llama3 --speech`
- enters interactive mode for continuous voice chat
- TODO: fix exiting interactive mode
Note: both modes use the default whisper model unless one is set in the request
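A one-shot run looks roughly like this (prompts printed by `cmd.go` and `recorder.go` below; session illustrative):
```
$ ./ollama run llama3 "What do you think about this quote?" --speech
Speech Mode

Recording. Press any key to stop.
```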
## api/generate
### Request fields
- `speech` (required):
  - `audio` (required): path to audio file
  - `model` (optional): path to whisper model; uses the default if unset
  - `transcribe` (optional): if true, transcribes the audio file and returns the transcript
  - `keep_alive` (optional): sets how long the whisper model stays in memory (default: `5m`)
- `prompt` (optional): if set, the transcript is appended to this prompt before generation
#### Transcription
```
curl http://localhost:11434/api/generate -d '{
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
"transcribe": true,
"keep_alive": "1m"
},
"stream": false
}' | jq
```
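With `transcribe` set, the handler returns before any language model is scheduled: the transcript comes back in a standard generate response with `done_reason` set to `transcribe` (see `GenerateHandler` in `routes.go` below). An illustrative shape, not captured output:
```
{
  "model": "",
  "created_at": "2024-08-07T20:04:57.000Z",
  "response": " And so my fellow Americans, ask not what your country can do for you...",
  "done": true,
  "done_reason": "transcribe"
}
```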
#### Response Generation
```
curl http://localhost:11434/api/generate -d '{
"model": "llama3",
"prompt": "What do you think about this quote?",
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
"keep_alive": "1m"
},
"stream": false
}' | jq
```
## api/chat
### Request fields
- `model` (required): language model to chat with
- `speech` (optional):
  - `model` (optional): path to whisper model; uses the default if unset
  - `keep_alive` (optional): sets how long the whisper model stays in memory (default: `5m`)
- `run_speech` (optional): speech mode runs only if this is true or `speech` is provided
- `messages[].audio` (required): path to an audio file attached to a message; the transcript is appended to that message's content
```
curl http://localhost:11434/api/chat -d '{
"model": "llama3",
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"keep_alive": "10m"
},
"messages": [
{
"role": "system",
"content": "You are a Canadian Nationalist"
},
{
"role": "user",
"content": "What do you think about this quote?",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav"
}
],
"stream": false
}' | jq
```
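The same request can be issued through the Go client; a minimal sketch, assuming the new `Speech` and `Audio` fields ship in the `api` package as declared in `api/types.go` above:
```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama3",
		// Speech (or RunSpeech: true) enables audio processing server-side.
		Speech: &api.WhisperRequest{
			Model: "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
		},
		Messages: []api.Message{
			{
				Role:    "user",
				Content: "What do you think about this quote?",
				// Audio is transcribed and appended to Content by the server.
				Audio: "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
			},
		},
	}

	// The callback receives streamed response chunks as they arrive.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```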

1
go.mod
View File

@ -19,6 +19,7 @@ require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

2
go.sum
View File

@ -115,6 +115,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=

1
llm/whisper.cpp Submodule

@ -0,0 +1 @@
Subproject commit 6739eb83c3ca5cf40d24c6fe8442a761a1eb6248

137
recorder/recorder.go Normal file
View File

@ -0,0 +1,137 @@
package recorder
import (
"encoding/binary"
"fmt"
"os"
"os/signal"
"syscall"
"golang.org/x/sys/unix"
"golang.org/x/term"
"github.com/gordonklaus/portaudio"
)
const (
sampleRate = 16000
numChannels = 1
bitsPerSample = 16
)
func RecordAudio(f *os.File) error {
fmt.Print("Recording. Press any key to stop.\n\n")
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
if err := portaudio.Initialize(); err != nil {
return err
}
defer portaudio.Terminate()
in := make([]int16, 64)
stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in)
if err != nil {
return err
}
defer stream.Close()
err = stream.Start()
if err != nil {
return err
}
// Write WAV header with placeholder sizes
writeWavHeader(f, sampleRate, numChannels, bitsPerSample)
var totalSamples uint32
// Set up terminal input reading
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return err
}
defer term.Restore(int(os.Stdin.Fd()), oldState)
// Create a channel to handle the stop signal
stop := make(chan struct{})
go func() {
_, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1))
if err != nil {
fmt.Println("Error reading from stdin:", err)
return
}
// Send signal to stop recording
stop <- struct{}{}
}()
loop:
for {
err = stream.Read()
if err != nil {
return err
}
err = binary.Write(f, binary.LittleEndian, in)
if err != nil {
return err
}
totalSamples += uint32(len(in))
select {
case <-stop:
break loop
case <-sig:
break loop
default:
}
}
err = stream.Stop()
if err != nil {
return err
}
// Update WAV header with actual sizes
updateWavHeader(f, totalSamples, numChannels, bitsPerSample)
return nil
}
func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) {
subchunk1Size := 16
audioFormat := 1
byteRate := sampleRate * numChannels * (bitsPerSample / 8)
blockAlign := numChannels * (bitsPerSample / 8)
// Write the RIFF header
f.Write([]byte("RIFF"))
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for file size
f.Write([]byte("WAVE"))
// Write the fmt subchunk
f.Write([]byte("fmt "))
binary.Write(f, binary.LittleEndian, uint32(subchunk1Size))
binary.Write(f, binary.LittleEndian, uint16(audioFormat))
binary.Write(f, binary.LittleEndian, uint16(numChannels))
binary.Write(f, binary.LittleEndian, uint32(sampleRate))
binary.Write(f, binary.LittleEndian, uint32(byteRate))
binary.Write(f, binary.LittleEndian, uint16(blockAlign))
binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))
// Write the data subchunk header
f.Write([]byte("data"))
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for data size
}
func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) {
fileSize := 36 + (totalSamples * uint32(numChannels) * uint32(bitsPerSample/8))
dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)
// Seek to the start of the file and write updated sizes
f.Seek(4, 0)
binary.Write(f, binary.LittleEndian, uint32(fileSize))
f.Seek(40, 0)
binary.Write(f, binary.LittleEndian, uint32(dataSize))
}

server/routes.go
View File

@ -10,13 +10,17 @@ import (
"io"
"log/slog"
"math"
"math/rand"
"mime/multipart"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
"time"
@ -105,6 +109,186 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
return runner.llama, model, &opts, nil
}
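// runWhisperServer sends the port of a whisper server for the requested model
// on portCh, starting one (and scheduling its expiry) if none is loaded.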
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
var modelPath string
if speech.Model == "" {
modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin"
} else {
modelPath = speech.Model
}
// default to 5 minutes
var sessionDuration time.Duration
if speech.KeepAlive != nil {
sessionDuration = speech.KeepAlive.Duration
} else {
sessionDuration = 5 * time.Minute
}
s.sched.whisperMu.Lock()
if s.sched.whisperLoaded[modelPath] != nil {
slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
portCh <- *s.sched.whisperLoaded[modelPath]
// Renew the expiration time
s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
s.sched.whisperMu.Unlock()
return
}
whisperServer := "/Users/royhan-ollama/.ollama/server"
// Find an available port for whisper
port := 0
params := []string{}
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
slog.Debug("ResolveTCPAddr failed")
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
}
finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)
cmd := exec.Command(whisperServer, finalParams...)
slog.Info("starting whisper server", "cmd", cmd.String())
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
slog.Error("failed to start whisper server", "error", err)
errCh <- err
return
}
// Wait for server connection
retries := 25
var connErr error
for range retries {
time.Sleep(50 * time.Millisecond)
conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
if err == nil {
conn.Close()
connErr = nil
break
}
connErr = err
}
if connErr != nil {
slog.Error("failed to connect to whisper server", "error", connErr)
errCh <- connErr
return
}
portCh <- port
s.sched.whisperLoaded[modelPath] = &port
s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
s.sched.whisperMu.Unlock()
// Wait for the whisper server to exit
defer func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for range ticker.C {
s.sched.whisperMu.Lock()
if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
slog.Info("exiting whisper server")
delete(s.sched.whisperLoaded, modelPath)
delete(s.sched.whisperExpiresAt, modelPath)
s.sched.whisperMu.Unlock()
if err := cmd.Process.Kill(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slog.Debug("whisper server stopped")
return
}
s.sched.whisperMu.Unlock()
}
}()
}
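// whisperInference posts the audio file at filePath to the local whisper
// server as multipart form data and returns the transcription.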
func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
// Open the file
file, err := os.Open(filePath)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
return nil, err
}
defer file.Close()
// Create a buffer to hold the multipart form data
buffer := &bytes.Buffer{}
writer := multipart.NewWriter(buffer)
// Add the file to the multipart form
part, err := writer.CreateFormFile("file", filepath.Base(filePath))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
return nil, err
}
if _, err := io.Copy(part, file); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
return nil, err
}
// Add other fields to the form
if err := writer.WriteField("temperature", "0.0"); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
return nil, err
}
// Close the writer to finalize the multipart form
if err := writer.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
return nil, err
}
endpoint := fmt.Sprintf("http://localhost:%d/inference", port)
serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
return nil, err
}
serverReq.Header.Set("Content-Type", writer.FormDataContentType())
res, err := http.DefaultClient.Do(serverReq)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
return nil, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
return nil, err
}
var w api.WhisperCompletion
if err := json.Unmarshal(body, &w); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
return nil, err
}
if w.Error != "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": w.Error})
return nil, errors.New(w.Error)
}
return &w, nil
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
@ -129,6 +313,40 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps = append(caps, CapabilityInsert)
}
if req.Speech != nil {
portCh := make(chan int, 1)
errCh := make(chan error, 1)
go s.runWhisperServer(c, portCh, errCh, req.Speech)
var port int
select {
case port = <-portCh:
case err := <-errCh:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
w, err := whisperInference(c, req.Speech.Audio, port)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return
}
if req.Speech.Transcribe {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: w.Text,
Done: true,
DoneReason: "transcribe",
})
return
}
req.Prompt += "\n" + w.Text
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@ -1296,6 +1514,37 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
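// processAudio transcribes each message that carries an audio path and
// appends the transcript to that message's content.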
func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error {
slog.Info("processing audio")
if req == nil {
req = &api.WhisperRequest{}
}
portCh := make(chan int, 1)
errCh := make(chan error, 1)
go s.runWhisperServer(c, portCh, errCh, req)
var port int
select {
case port = <-portCh:
case err := <-errCh:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return err
}
// could parallelize this
for i, msg := range msgs {
if msg.Audio != "" {
w, err := whisperInference(c, msg.Audio, port)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return err
}
msgs[i].Content += "\n" + w.Text
}
}
return nil
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
@ -1340,6 +1589,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
if req.Speech != nil || req.RunSpeech {
if err := processAudio(c, s, msgs, req.Speech); err != nil {
slog.Error("failed to process audio", "error", err)
return
}
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

server/sched.go
View File

@ -46,6 +46,10 @@ type Scheduler struct {
getGpuFn func() gpu.GpuInfoList
getCpuFn func() gpu.GpuInfoList
reschedDelay time.Duration
whisperLoaded map[string]*int
whisperExpiresAt map[string]time.Time
whisperMu sync.Mutex
}
// Default automatic value for number of models we allow per GPU
@ -63,15 +67,17 @@ var ErrMaxQueue = errors.New("server busy, please try again. maximum pending re
func InitScheduler(ctx context.Context) *Scheduler {
maxQueue := envconfig.MaxQueue()
sched := &Scheduler{
pendingReqCh: make(chan *LlmRequest, maxQueue),
finishedReqCh: make(chan *LlmRequest, maxQueue),
expiredCh: make(chan *runnerRef, maxQueue),
unloadedCh: make(chan interface{}, maxQueue),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
pendingReqCh: make(chan *LlmRequest, maxQueue),
finishedReqCh: make(chan *LlmRequest, maxQueue),
expiredCh: make(chan *runnerRef, maxQueue),
unloadedCh: make(chan interface{}, maxQueue),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
whisperLoaded: make(map[string]*int),
whisperExpiresAt: make(map[string]time.Time),
}
sched.loadFn = sched.load
return sched
@ -110,6 +116,10 @@ func (s *Scheduler) Run(ctx context.Context) {
go func() {
s.processCompleted(ctx)
}()
// go func() {
// could clean up whisper servers in init thread
// }
}
func (s *Scheduler) processPending(ctx context.Context) {
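The commented-out idea above — reaping expired whisper servers from a single scheduler goroutine rather than the per-server ticker in `runWhisperServer` — might look like the sketch below. It is not part of this diff and assumes a process handle is also stored on the scheduler (this change keeps the `*exec.Cmd` local to `routes.go`):
```go
// Hypothetical: a whisperCmds map[string]*exec.Cmd field would let the
// scheduler kill the server process; this diff does not store one.
func (s *Scheduler) reapWhisperServers(ctx context.Context) {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			s.whisperMu.Lock()
			for model, expiresAt := range s.whisperExpiresAt {
				if time.Now().After(expiresAt) {
					slog.Info("exiting whisper server", "model", model)
					if cmd := s.whisperCmds[model]; cmd != nil {
						cmd.Process.Kill()
					}
					delete(s.whisperLoaded, model)
					delete(s.whisperExpiresAt, model)
					delete(s.whisperCmds, model)
				}
			}
			s.whisperMu.Unlock()
		}
	}
}
```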