Compare commits: main...royh/whisp (13 commits)

Commits: 30823ec925, 89f3bae306, ad7e822883, d503f04b32, 8ccf543c53, 75ad6309b4, a5181a8c51, 2a9feb0707, e4d35198a2, 17f9dc6d08, 97d9dffa80, 65483180b9, 1ac92eae7c
**.gitmodules** (+3)

```diff
@@ -2,3 +2,6 @@
 	path = llm/llama.cpp
 	url = https://github.com/ggerganov/llama.cpp.git
 	shallow = true
+[submodule "llm/whisper.cpp"]
+	path = llm/whisper.cpp
+	url = git@github.com:ggerganov/whisper.cpp.git
```
**api/types.go** (+19)

```diff
@@ -36,6 +36,13 @@ func (e StatusError) Error() string {
 // ImageData represents the raw binary data of an image file.
 type ImageData []byte
 
+type WhisperRequest struct {
+	Model      string    `json:"model,omitempty"`
+	Audio      string    `json:"audio,omitempty"`
+	Transcribe bool      `json:"transcribe,omitempty"`
+	KeepAlive  *Duration `json:"keep_alive,omitempty"`
+}
+
 // GenerateRequest describes a request sent by [Client.Generate]. While you
 // have to specify the Model and Prompt fields, all the other fields have
 // reasonable defaults for basic uses.
@@ -80,6 +87,8 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
+
+	Speech *WhisperRequest `json:"speech,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +114,10 @@ type ChatRequest struct {
 
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
+
+	Speech *WhisperRequest `json:"speech,omitempty"`
+
+	RunSpeech bool `json:"run_speech,omitempty"`
 }
 
 type Tools []Tool
@@ -127,6 +140,7 @@ type Message struct {
 	Content   string      `json:"content"`
 	Images    []ImageData `json:"images,omitempty"`
 	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
+	Audio     string      `json:"audio,omitempty"`
 }
 
 func (m *Message) UnmarshalJSON(b []byte) error {
@@ -450,6 +464,11 @@ type GenerateResponse struct {
 	Metrics
 }
 
+type WhisperCompletion struct {
+	Text  string `json:"text"`
+	Error string `json:"error,omitempty"`
+}
+
 // ModelDetails provides details about a model.
 type ModelDetails struct {
 	ParentModel string `json:"parent_model"`
```
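For reference, a minimal client-side sketch of how the new fields compose, assuming this branch is built and serving locally; the model name and audio path are placeholders, and the audio path is resolved on the server host:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "llama3",                // placeholder model name
		Prompt: "What do you think about this quote?",
		Speech: &api.WhisperRequest{
			Audio: "/path/to/recording.wav", // placeholder; read server-side
		},
	}

	// The callback fires once per streamed chunk.
	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		fmt.Print(r.Response)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```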
**cmd/cmd.go** (+39)

```diff
@@ -38,6 +38,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
+	"github.com/ollama/ollama/recorder"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
@@ -380,6 +381,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 			}
 		}
 
+		speech, err := cmd.Flags().GetBool("speech")
+		if err != nil {
+			return err
+		}
+
+		if speech {
+			return generateInteractiveAudio(cmd, opts)
+		}
 		return generateInteractive(cmd, opts)
 	}
 	return generate(cmd, opts)
@@ -862,6 +871,7 @@ type runOptions struct {
 	Options    map[string]interface{}
 	MultiModal bool
 	KeepAlive  *api.Duration
+	Audio      bool
 }
 
 type displayResponseState struct {
@@ -970,6 +980,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		req.KeepAlive = opts.KeepAlive
 	}
 
+	if opts.Audio {
+		req.RunSpeech = true
+	}
+
 	if err := client.Chat(cancelCtx, req, fn); err != nil {
 		if errors.Is(err, context.Canceled) {
 			return nil, nil
@@ -1055,6 +1069,30 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		KeepAlive: opts.KeepAlive,
 	}
 
+	speech, err := cmd.Flags().GetBool("speech")
+	if err != nil {
+		return err
+	}
+
+	// create temp wav file with the recorder package
+	if speech {
+		tempFile, err := os.CreateTemp("", "recording-*.wav")
+		if err != nil {
+			return err
+		}
+		defer os.Remove(tempFile.Name())
+
+		fmt.Print("Speech Mode\n\n")
+
+		err = recorder.RecordAudio(tempFile)
+		if err != nil {
+			return err
+		}
+
+		request.Speech = &api.WhisperRequest{
+			Audio: tempFile.Name(),
+		}
+	}
 	if err := client.Generate(ctx, &request, fn); err != nil {
 		if errors.Is(err, context.Canceled) {
 			return nil
@@ -1262,6 +1300,7 @@ func NewCLI() *cobra.Command {
 		RunE: RunHandler,
 	}
 
+	runCmd.Flags().Bool("speech", false, "Speech to text mode")
 	runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
```
**cmd/interactive.go**

```diff
@@ -20,6 +20,7 @@ import (
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/readline"
+	"github.com/ollama/ollama/recorder"
 	"github.com/ollama/ollama/types/errtypes"
 )
 
@@ -51,6 +52,40 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
 	return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
 }
 
+func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error {
+	for {
+		p := progress.NewProgress(os.Stderr)
+		spinner := progress.NewSpinner("")
+		p.Add("", spinner)
+
+		// create temp wav file with the recorder package
+		tempFile, err := os.CreateTemp("", "recording-*.wav")
+		if err != nil {
+			return err
+		}
+		defer os.Remove(tempFile.Name())
+
+		err = recorder.RecordAudio(tempFile)
+		if err != nil {
+			return err
+		}
+
+		p.StopAndClear()
+
+		newMessage := api.Message{Role: "user", Audio: tempFile.Name()}
+		opts.Audio = true
+		opts.Messages = append(opts.Messages, newMessage)
+
+		assistant, err := chat(cmd, opts)
+		if err != nil {
+			return err
+		}
+		if assistant != nil {
+			opts.Messages = append(opts.Messages, *assistant)
+		}
+	}
+}
+
 func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	usage := func() {
 		fmt.Fprintln(os.Stderr, "Available Commands:")
```
**docs/speech.md** (new file, +83)

# Speech to Text Prototype

## To run
`make {/path/to/whisper.cpp/server}`
- replace `whisperServer` in `routes.go` with the path to the built server binary

## CLI
`./ollama run llama3 [PROMPT] --speech`
- records voice audio and processes it together with the provided prompt

`./ollama run llama3 --speech`
- enters interactive mode for continuous voice chat
- TODO: fix exiting interactive mode

Note: uses the default whisper model if none is given

## api/generate
### Request fields
- `speech` (required):
  - `audio` (required): path to the audio file
  - `model` (optional): path to a whisper model; uses the default if omitted
  - `transcribe` (optional): if true, transcribes the audio file and returns the text
  - `keep_alive` (optional): how long the whisper model stays loaded in memory (default: `5m`)
- `prompt` (optional): if set, passed in along with the transcribed audio

#### Transcription
```
curl http://localhost:11434/api/generate -d '{
  "speech": {
    "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
    "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
    "transcribe": true,
    "keep_alive": "1m"
  },
  "stream": false
}' | jq
```

#### Response Generation
```
curl http://localhost:11434/api/generate -d '{
  "model": "llama3",
  "prompt": "What do you think about this quote?",
  "speech": {
    "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
    "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
    "keep_alive": "1m"
  },
  "stream": false
}' | jq
```

## api/chat
### Request fields
- `model` (required): language model to chat with
- `speech` (optional):
  - `model` (optional): path to a whisper model; uses the default if omitted
  - `keep_alive` (optional): how long the whisper model stays loaded in memory (default: `5m`)
- `run_speech` (optional): speech mode runs only if this flag is true or `speech` is provided
- `messages` (required): a message's `audio` field holds the path to its audio file

```
curl http://localhost:11434/api/chat -d '{
  "model": "llama3",
  "speech": {
    "model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
    "keep_alive": "10m"
  },
  "messages": [
    {
      "role": "system",
      "content": "You are a Canadian Nationalist"
    },
    {
      "role": "user",
      "content": "What do you think about this quote?",
      "audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav"
    }
  ],
  "stream": false
}' | jq
```
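For comparison with the curl call above, the same chat request as a Go sketch against this branch's `api` package; the model name and audio path are placeholders, and setting `run_speech` alone falls back to the default whisper model:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama3", // placeholder model name
		Messages: []api.Message{
			{
				Role:    "user",
				Content: "What do you think about this quote?",
				Audio:   "/path/to/jfk.wav", // placeholder; read server-side
			},
		},
		RunSpeech: true, // no Speech block, so the default whisper model is used
	}

	// Print each streamed chunk of the assistant reply.
	err = client.Chat(context.Background(), req, func(r api.ChatResponse) error {
		fmt.Print(r.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```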
**go.mod** (+1)

```diff
@@ -19,6 +19,7 @@ require (
 	github.com/agnivade/levenshtein v1.1.1
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/google/go-cmp v0.6.0
+	github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
```
**go.sum** (+2)

```diff
@@ -115,6 +115,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
+github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
```
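A practical note not stated in the diff: gordonklaus/portaudio is a cgo binding, so building this branch also needs the native PortAudio library installed, e.g. `brew install portaudio` on macOS or `apt-get install portaudio19-dev` on Debian/Ubuntu.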
**llm/whisper.cpp** (new submodule, +1)

```diff
@@ -0,0 +1 @@
+Subproject commit 6739eb83c3ca5cf40d24c6fe8442a761a1eb6248
```
**recorder/recorder.go** (new file, +137)

```go
package recorder

import (
	"encoding/binary"
	"fmt"
	"os"
	"os/signal"
	"syscall"

	"golang.org/x/sys/unix"
	"golang.org/x/term"

	"github.com/gordonklaus/portaudio"
)

const (
	sampleRate    = 16000
	numChannels   = 1
	bitsPerSample = 16
)

func RecordAudio(f *os.File) error {
	fmt.Print("Recording. Press any key to stop.\n\n")

	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt, syscall.SIGTERM)

	portaudio.Initialize()
	defer portaudio.Terminate()

	in := make([]int16, 64)
	stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in)
	if err != nil {
		return err
	}
	defer stream.Close()

	err = stream.Start()
	if err != nil {
		return err
	}

	// Write WAV header with placeholder sizes
	writeWavHeader(f, sampleRate, numChannels, bitsPerSample)

	var totalSamples uint32

	// Set up terminal input reading
	oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
	if err != nil {
		return err
	}
	defer term.Restore(int(os.Stdin.Fd()), oldState)

	// Create a channel to handle the stop signal
	stop := make(chan struct{})

	go func() {
		_, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1))
		if err != nil {
			fmt.Println("Error reading from stdin:", err)
			return
		}
		// Send signal to stop recording
		stop <- struct{}{}
	}()

loop:
	for {
		err = stream.Read()
		if err != nil {
			return err
		}

		err = binary.Write(f, binary.LittleEndian, in)
		if err != nil {
			return err
		}
		totalSamples += uint32(len(in))

		select {
		case <-stop:
			break loop
		case <-sig:
			break loop
		default:
		}
	}

	err = stream.Stop()
	if err != nil {
		return err
	}

	// Update WAV header with actual sizes
	updateWavHeader(f, totalSamples, numChannels, bitsPerSample)

	return nil
}

func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) {
	subchunk1Size := 16
	audioFormat := 1
	byteRate := sampleRate * numChannels * (bitsPerSample / 8)
	blockAlign := numChannels * (bitsPerSample / 8)

	// Write the RIFF header
	f.Write([]byte("RIFF"))
	binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for file size
	f.Write([]byte("WAVE"))

	// Write the fmt subchunk
	f.Write([]byte("fmt "))
	binary.Write(f, binary.LittleEndian, uint32(subchunk1Size))
	binary.Write(f, binary.LittleEndian, uint16(audioFormat))
	binary.Write(f, binary.LittleEndian, uint16(numChannels))
	binary.Write(f, binary.LittleEndian, uint32(sampleRate))
	binary.Write(f, binary.LittleEndian, uint32(byteRate))
	binary.Write(f, binary.LittleEndian, uint16(blockAlign))
	binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))

	// Write the data subchunk header
	f.Write([]byte("data"))
	binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for data size
}

func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) {
	fileSize := 36 + (totalSamples * uint32(numChannels) * uint32(bitsPerSample/8))
	dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)

	// Seek to the start of the file and write updated sizes
	f.Seek(4, 0)
	binary.Write(f, binary.LittleEndian, uint32(fileSize))

	f.Seek(40, 0)
	binary.Write(f, binary.LittleEndian, uint32(dataSize))
}
```
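The header bookkeeping works because the file uses the canonical 44-byte WAV layout: with a 16-byte fmt chunk, the RIFF size field sits at byte offset 4 and the data-chunk size at offset 40, which is exactly what the two `Seek` calls in `updateWavHeader` patch. A minimal standalone sketch of how this package is driven, mirroring the call sites in `cmd/cmd.go`:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/recorder"
)

func main() {
	// RecordAudio expects an open *os.File and writes a 16 kHz,
	// mono, 16-bit PCM WAV until a key is pressed.
	f, err := os.CreateTemp("", "recording-*.wav")
	if err != nil {
		log.Fatal(err)
	}
	defer os.Remove(f.Name())

	if err := recorder.RecordAudio(f); err != nil {
		log.Fatal(err)
	}
	fmt.Println("saved:", f.Name())
}
```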
**server/routes.go** (+256)

```diff
@@ -10,13 +10,17 @@ import (
 	"io"
 	"log/slog"
 	"math"
+	"math/rand"
+	"mime/multipart"
 	"net"
 	"net/http"
 	"net/netip"
 	"os"
+	"os/exec"
 	"os/signal"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"strings"
 	"syscall"
 	"time"
```
The second hunk (`@@ -105,6 +109,186 @@`) inserts two new functions after `scheduleRunner`. The first, `runWhisperServer`, starts (or reuses) a whisper.cpp server process for the requested model and expires it after its keep-alive window:

```go
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
	var modelPath string
	if speech.Model == "" {
		modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin"
	} else {
		modelPath = speech.Model
	}

	// default to 5 minutes
	var sessionDuration time.Duration
	if speech.KeepAlive != nil {
		sessionDuration = speech.KeepAlive.Duration
	} else {
		sessionDuration = 5 * time.Minute
	}

	s.sched.whisperMu.Lock()
	if s.sched.whisperLoaded[modelPath] != nil {
		slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
		portCh <- *s.sched.whisperLoaded[modelPath]
		// Renew the expiration time
		s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
		s.sched.whisperMu.Unlock()
		return
	}

	whisperServer := "/Users/royhan-ollama/.ollama/server"

	// Find an available port for whisper
	port := 0
	params := []string{}
	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
		var l *net.TCPListener
		if l, err = net.ListenTCP("tcp", a); err == nil {
			port = l.Addr().(*net.TCPAddr).Port
			l.Close()
		}
	}
	if port == 0 {
		slog.Debug("ResolveTCPAddr failed")
		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
	}
	finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)

	cmd := exec.Command(whisperServer, finalParams...)
	slog.Info("starting whisper server", "cmd", cmd.String())
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	err := cmd.Start()
	if err != nil {
		slog.Error("failed to start whisper server", "error", err)
		errCh <- err
		return
	}

	// Wait for server connection
	retries := 25
	var connErr error
	for range retries {
		time.Sleep(50 * time.Millisecond)
		conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
		if err == nil {
			conn.Close()
			connErr = nil
			break
		}
		connErr = err
	}

	if connErr != nil {
		slog.Error("failed to connect to whisper server", "error", connErr)
		errCh <- connErr
		return
	}

	portCh <- port
	s.sched.whisperLoaded[modelPath] = &port
	s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)

	s.sched.whisperMu.Unlock()

	// Wait for the whisper server to exit
	defer func() {
		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()
		for range ticker.C {
			s.sched.whisperMu.Lock()
			if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
				slog.Info("exiting whisper server")
				delete(s.sched.whisperLoaded, modelPath)
				delete(s.sched.whisperExpiresAt, modelPath)
				s.sched.whisperMu.Unlock()

				if err := cmd.Process.Kill(); err != nil {
					c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
					return
				}

				slog.Debug("whisper server stopped")
				return
			}
			s.sched.whisperMu.Unlock()
		}
	}()
}
```
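The port selection above is the common bind-to-port-0 idiom. Distilled into a standalone helper (standard library only, nothing branch-specific), it looks like the sketch below; the branch keeps a random-ephemeral-port fallback because the probe can fail, and there is an inherent race between closing the probe listener and the whisper server binding the port:

```go
package main

import (
	"fmt"
	"log"
	"net"
)

// freePort asks the OS for an unused TCP port by binding to port 0,
// reading the assigned port, then releasing the listener.
func freePort() (int, error) {
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return 0, err
	}
	defer l.Close()
	return l.Addr().(*net.TCPAddr).Port, nil
}

func main() {
	p, err := freePort()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("picked port:", p)
}
```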
The second new function, `whisperInference`, posts an audio file to the running whisper.cpp server as a multipart form and decodes the JSON reply:

```go
func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
	// Open the file
	file, err := os.Open(filePath)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
		return nil, err
	}
	defer file.Close()

	// Create a buffer to hold the multipart form data
	buffer := &bytes.Buffer{}
	writer := multipart.NewWriter(buffer)

	// Add the file to the multipart form
	part, err := writer.CreateFormFile("file", filepath.Base(filePath))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
		return nil, err
	}

	if _, err := io.Copy(part, file); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
		return nil, err
	}

	// Add other fields to the form
	if err := writer.WriteField("temperature", "0.0"); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
		return nil, err
	}

	// Close the writer to finalize the multipart form
	if err := writer.Close(); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
		return nil, err
	}

	endpoint := fmt.Sprintf("http://localhost:%s/inference", strconv.Itoa(port))

	serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
		return nil, err
	}

	serverReq.Header.Set("Content-Type", writer.FormDataContentType())

	res, err := http.DefaultClient.Do(serverReq)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
		return nil, err
	}
	defer res.Body.Close()

	body, err := io.ReadAll(res.Body)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
		return nil, err
	}

	var w api.WhisperCompletion
	if err := json.Unmarshal(body, &w); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
		return nil, err
	}

	if w.Error != "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": w.Error})
		return nil, fmt.Errorf(w.Error)
	}

	return &w, nil
}
```
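To make the decoding step concrete, here is a self-contained sketch of the JSON shapes `WhisperCompletion` expects back from the whisper.cpp server; the payload text is illustrative (it echoes the jfk.wav sample), not captured output:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"
)

// Mirrors api.WhisperCompletion from this branch.
type WhisperCompletion struct {
	Text  string `json:"text"`
	Error string `json:"error,omitempty"`
}

func main() {
	// Illustrative success and error payloads.
	for _, body := range []string{
		`{"text": " And so my fellow Americans, ask not what your country can do for you..."}`,
		`{"error": "failed to load audio"}`,
	} {
		var w WhisperCompletion
		if err := json.Unmarshal([]byte(body), &w); err != nil {
			log.Fatal(err)
		}
		if w.Error != "" {
			fmt.Println("error:", w.Error)
			continue
		}
		fmt.Println("text:", w.Text)
	}
}
```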
In `GenerateHandler`, speech requests are intercepted before the runner is scheduled:

```diff
@@ -129,6 +313,40 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		caps = append(caps, CapabilityInsert)
 	}
 
+	if req.Speech != nil {
+		portCh := make(chan int, 1)
+		errCh := make(chan error, 1)
+		go s.runWhisperServer(c, portCh, errCh, req.Speech)
+
+		var port int
+
+		select {
+		case port = <-portCh:
+		case err := <-errCh:
+			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
+			return
+		}
+
+		w, err := whisperInference(c, req.Speech.Audio, port)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
+			return
+		}
+
+		if req.Speech.Transcribe {
+			c.JSON(http.StatusOK, api.GenerateResponse{
+				Model:      req.Model,
+				CreatedAt:  time.Now().UTC(),
+				Response:   w.Text,
+				Done:       true,
+				DoneReason: "transcribe",
+			})
+			return
+		}
+
+		req.Prompt += "\n" + w.Text
+	}
+
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
```
For chat, a `processAudio` helper transcribes every message that carries an audio path, and `ChatHandler` invokes it when speech is requested:

```diff
@@ -1296,6 +1514,37 @@ func (s *Server) ProcessHandler(c *gin.Context) {
 	c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
 }
 
+func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error {
+	slog.Info("processing audio")
+	if req == nil {
+		req = &api.WhisperRequest{}
+	}
+	portCh := make(chan int, 1)
+	errCh := make(chan error, 1)
+	go s.runWhisperServer(c, portCh, errCh, req)
+
+	var port int
+	select {
+	case port = <-portCh:
+	case err := <-errCh:
+		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return err
+	}
+
+	// could parallelize this
+	for i, msg := range msgs {
+		if msg.Audio != "" {
+			w, err := whisperInference(c, msg.Audio, port)
+			if err != nil {
+				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
+				return err
+			}
+			msgs[i].Content += "\n" + w.Text
+		}
+	}
+	return nil
+}
+
 func (s *Server) ChatHandler(c *gin.Context) {
 	checkpointStart := time.Now()
 
@@ -1340,6 +1589,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
 	}
 
+	if req.Speech != nil || req.RunSpeech {
+		if err := processAudio(c, s, msgs, req.Speech); err != nil {
+			slog.Error("failed to process audio", "error", err)
+			return
+		}
+	}
+
 	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
```
**server/sched.go**

```diff
@@ -46,6 +46,10 @@ type Scheduler struct {
 	getGpuFn     func() gpu.GpuInfoList
 	getCpuFn     func() gpu.GpuInfoList
 	reschedDelay time.Duration
+
+	whisperLoaded    map[string]*int
+	whisperExpiresAt map[string]time.Time
+	whisperMu        sync.Mutex
 }
 
 // Default automatic value for number of models we allow per GPU
@@ -63,15 +67,17 @@ var ErrMaxQueue = errors.New("server busy, please try again. maximum pending re
 func InitScheduler(ctx context.Context) *Scheduler {
 	maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, maxQueue),
-		finishedReqCh: make(chan *LlmRequest, maxQueue),
-		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan interface{}, maxQueue),
-		loaded:        make(map[string]*runnerRef),
-		newServerFn:   llm.NewLlamaServer,
-		getGpuFn:      gpu.GetGPUInfo,
-		getCpuFn:      gpu.GetCPUInfo,
-		reschedDelay:  250 * time.Millisecond,
+		pendingReqCh:     make(chan *LlmRequest, maxQueue),
+		finishedReqCh:    make(chan *LlmRequest, maxQueue),
+		expiredCh:        make(chan *runnerRef, maxQueue),
+		unloadedCh:       make(chan interface{}, maxQueue),
+		loaded:           make(map[string]*runnerRef),
+		newServerFn:      llm.NewLlamaServer,
+		getGpuFn:         gpu.GetGPUInfo,
+		getCpuFn:         gpu.GetCPUInfo,
+		reschedDelay:     250 * time.Millisecond,
+		whisperLoaded:    make(map[string]*int),
+		whisperExpiresAt: make(map[string]time.Time),
 	}
 	sched.loadFn = sched.load
 	return sched
@@ -110,6 +116,10 @@ func (s *Scheduler) Run(ctx context.Context) {
 	go func() {
 		s.processCompleted(ctx)
 	}()
+
+	// go func() {
+	// could clean up whisper servers in init thread
+	// }
 }
 
 func (s *Scheduler) processPending(ctx context.Context) {
```
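The commented-out block hints at centralizing expiry in a single scheduler-owned goroutine instead of the per-server ticker inside `runWhisperServer`. A hypothetical sketch of that idea follows; the function name, signature, and 5-second period are assumptions, and actually killing the child processes would additionally require storing each `exec.Cmd`, which the current maps do not:

```go
package main

import (
	"context"
	"sync"
	"time"
)

// cleanupWhisper is a hypothetical centralized janitor: it periodically
// drops expired whisper servers from the bookkeeping maps. (Killing the
// processes would need the exec.Cmd handles as well.)
func cleanupWhisper(ctx context.Context, mu *sync.Mutex, loaded map[string]*int, expiresAt map[string]time.Time) {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			mu.Lock()
			now := time.Now()
			for model, exp := range expiresAt {
				if now.After(exp) {
					delete(loaded, model)
					delete(expiresAt, model)
				}
			}
			mu.Unlock()
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second)
	defer cancel()

	var mu sync.Mutex
	port := 12345
	loaded := map[string]*int{"ggml-base.en.bin": &port}
	expiresAt := map[string]time.Time{"ggml-base.en.bin": time.Now().Add(time.Second)}

	cleanupWhisper(ctx, &mu, loaded, expiresAt)
}
```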