Compare commits

12 Commits
main...brucemacd/

Commits (SHA1):
  5690cd5ab2
  3befd3dad9
  75657e30e0
  72ce843921
  1eff6853eb
  6e6b3baa23
  50ba1dcc7e
  2d5d926ce0
  8c044c8dd2
  e03fbf558d
  4718ecc62e
  9c21d23a35
api/types.go (55 changes)

@@ -31,18 +31,24 @@ func (e StatusError) Error() string {
 }
 
 type GenerateRequest struct {
-    Model    string `json:"model"`
-    Prompt   string `json:"prompt"`
-    System   string `json:"system"`
-    Template string `json:"template"`
-    Context  []int  `json:"context,omitempty"`
-    Stream   *bool  `json:"stream,omitempty"`
-    Raw      bool   `json:"raw,omitempty"`
-    Format   string `json:"format"`
+    Model    string    `json:"model"`
+    Prompt   string    `json:"prompt"`
+    System   string    `json:"system"`
+    Template string    `json:"template"`
+    Context  []int     `json:"context,omitempty"` // DEPRECATED: context is deprecated, use messages instead
+    Messages []Message `json:"messages,omitempty"`
+    Stream   *bool     `json:"stream,omitempty"`
+    Raw      bool      `json:"raw,omitempty"`
+    Format   string    `json:"format"`
 
     Options map[string]interface{} `json:"options"`
 }
 
+type Message struct {
+    Role    string `json:"role"` // one of ["system", "user", "assistant"]
+    Content string `json:"content"`
+}
+
 // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
 type Options struct {
     Runner
@@ -87,6 +93,23 @@ type Runner struct {
     NumThread int `json:"num_thread,omitempty"`
 }
 
+type GenerateResponse struct {
+    Model     string    `json:"model"`
+    CreatedAt time.Time `json:"created_at"`
+    Response  string    `json:"response,omitempty"` // the latest response chunk when streaming
+    Message   *Message  `json:"message,omitempty"`  // the latest message chunk when streaming
+
+    Done    bool  `json:"done"`
+    Context []int `json:"context,omitempty"`
+
+    TotalDuration      time.Duration `json:"total_duration,omitempty"`
+    LoadDuration       time.Duration `json:"load_duration,omitempty"`
+    PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
+    PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+    EvalCount          int           `json:"eval_count,omitempty"`
+    EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+}
+
 type EmbeddingRequest struct {
     Model  string `json:"model"`
     Prompt string `json:"prompt"`
@@ -164,20 +187,8 @@ type TokenResponse struct {
     Token string `json:"token"`
 }
 
-type GenerateResponse struct {
-    Model     string    `json:"model"`
-    CreatedAt time.Time `json:"created_at"`
-    Response  string    `json:"response"`
-
-    Done    bool  `json:"done"`
-    Context []int `json:"context,omitempty"`
-
-    TotalDuration      time.Duration `json:"total_duration,omitempty"`
-    LoadDuration       time.Duration `json:"load_duration,omitempty"`
-    PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
-    PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
-    EvalCount          int           `json:"eval_count,omitempty"`
-    EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+func (r *GenerateRequest) Empty() bool {
+    return r.Prompt == "" && r.Template == "" && r.System == "" && len(r.Messages) == 0
 }
 
 func (r *GenerateResponse) Summary() {
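Not part of the diff: a minimal sketch of how a Go client on this branch could exercise the new `Messages` field end to end. It only uses identifiers that appear in this comparison (api.ClientFromEnvironment, client.Generate, api.GenerateRequest, api.Message); the model name and prompt text are placeholders.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"strings"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := api.GenerateRequest{
		Model: "llama2",
		Messages: []api.Message{
			{Role: "user", Content: "why is the sky blue?"},
		},
	}

	// Each streamed chunk carries the latest assistant message fragment.
	var reply strings.Builder
	fn := func(r api.GenerateResponse) error {
		if r.Message != nil {
			reply.WriteString(r.Message.Content)
		}
		return nil
	}

	if err := client.Generate(context.Background(), &req, fn); err != nil {
		log.Fatal(err)
	}
	fmt.Println(reply.String())
}
```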
cmd/cmd.go (75 changes)

@@ -427,7 +427,10 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 
     // output is being piped
     if !term.IsTerminal(int(os.Stdout.Fd())) {
-        return generate(cmd, args[0], strings.Join(prompts, " "), false, format)
+        if _, err := generate(cmd, false, api.GenerateRequest{Model: args[0], Prompt: strings.Join(prompts, " "), Format: format}); err != nil {
+            return err
+        }
+        return nil
     }
 
     wordWrap := os.Getenv("TERM") == "xterm-256color"
@@ -442,18 +445,19 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 
     // prompts are provided via stdin or args so don't enter interactive mode
     if len(prompts) > 0 {
-        return generate(cmd, args[0], strings.Join(prompts, " "), wordWrap, format)
+        if _, err := generate(cmd, wordWrap, api.GenerateRequest{Model: args[0], Prompt: strings.Join(prompts, " "), Format: format}); err != nil {
+            return err
+        }
+        return nil
     }
 
     return generateInteractive(cmd, args[0], wordWrap, format)
 }
 
-type generateContextKey string
-
-func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format string) error {
+func generate(cmd *cobra.Command, wordWrap bool, request api.GenerateRequest) (*api.Message, error) {
     client, err := api.ClientFromEnvironment()
     if err != nil {
-        return err
+        return nil, err
     }
 
     p := progress.NewProgress(os.Stderr)
@@ -464,11 +468,6 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 
     var latest api.GenerateResponse
 
-    generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
-    if !ok {
-        generateContext = []int{}
-    }
-
     termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
     if err != nil {
         wordWrap = false
@@ -490,14 +489,30 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
     var currentLineLength int
     var wordBuffer string
 
-    request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, Format: format}
-    fn := func(response api.GenerateResponse) error {
+    var role string
+    var fullResponse strings.Builder
+
+    fn := func(generated api.GenerateResponse) error {
         p.StopAndClear()
 
-        latest = response
+        latest = generated
 
+        if generated.Response == "" && generated.Message == nil {
+            // warm-up response
+            return nil
+        }
+        var content string
+        if generated.Message != nil {
+            role = generated.Message.Role
+            content = generated.Message.Content
+        } else {
+            content = generated.Response
+        }
+
+        fullResponse.WriteString(content)
+
         if wordWrap {
-            for _, ch := range response.Response {
+            for _, ch := range content {
                 if currentLineLength+1 > termWidth-5 {
                     // backtrack the length of the last word and clear to the end of the line
                     fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
@@ -518,7 +533,7 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
                 }
             }
         } else {
-            fmt.Print(response.Response)
+            fmt.Print(content)
         }
 
         return nil
@@ -526,41 +541,39 @@ func generate(cmd *cobra.Command, model, prompt string, wordWrap bool, format st
 
     if err := client.Generate(cancelCtx, &request, fn); err != nil {
         if strings.Contains(err.Error(), "context canceled") && abort {
-            return nil
+            return nil, nil
         }
-        return err
+        return nil, err
     }
-    if prompt != "" {
+
+    if request.Prompt != "" || request.Messages != nil {
+        // spacing for readability, a message was sent
         fmt.Println()
         fmt.Println()
     }
 
     if !latest.Done {
         if abort {
-            return nil
+            return nil, nil
         }
-        return errors.New("unexpected end of response")
+        return nil, errors.New("unexpected end of response")
     }
 
     verbose, err := cmd.Flags().GetBool("verbose")
     if err != nil {
-        return err
+        return nil, err
     }
 
     if verbose {
        latest.Summary()
     }
 
-    ctx := cmd.Context()
-    ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
-    cmd.SetContext(ctx)
-
-    return nil
+    return &api.Message{Role: role, Content: fullResponse.String()}, nil
 }
 
 func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format string) error {
     // load the model
-    if err := generate(cmd, model, "", false, ""); err != nil {
+    if _, err := generate(cmd, false, api.GenerateRequest{Model: model}); err != nil {
         return err
     }
 
@@ -614,6 +627,7 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
     defer fmt.Printf(readline.EndBracketedPaste)
 
     var multiLineBuffer string
+    messages := make([]api.Message, 0)
 
     for {
         line, err := scanner.Readline()
@@ -756,9 +770,12 @@ func generateInteractive(cmd *cobra.Command, model string, wordWrap bool, format
         }
 
         if len(line) > 0 && line[0] != '/' {
-            if err := generate(cmd, model, line, wordWrap, format); err != nil {
+            messages = append(messages, api.Message{Role: "user", Content: line})
+            assistant, err := generate(cmd, wordWrap, api.GenerateRequest{Model: model, Messages: messages, Format: format})
+            if err != nil {
                 return err
             }
+            messages = append(messages, *assistant)
         }
     }
 }
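Also not part of the diff: a hypothetical helper that mirrors the interactive-loop change above (append the user line, call generate, append the returned assistant message), shown only to make the chat-memory flow explicit. The name chatTurn and its shape are invented for illustration; the calls it makes are the ones introduced in cmd/cmd.go on this branch.

```go
// chatTurn is a hypothetical helper, not part of this change. It mirrors the
// loop above: append the user's line, run generate, and grow the history with
// the assistant's reply so the next turn carries the whole conversation.
func chatTurn(cmd *cobra.Command, model, line string, history []api.Message) ([]api.Message, error) {
	history = append(history, api.Message{Role: "user", Content: line})

	assistant, err := generate(cmd, false, api.GenerateRequest{Model: model, Messages: history})
	if err != nil {
		return nil, err
	}
	if assistant == nil {
		// generate returns nil, nil when the user aborts; keep the history as-is
		return history, nil
	}
	return append(history, *assistant), nil
}
```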
docs/api.md (78 changes)

@@ -24,7 +24,7 @@ All durations are returned in nanoseconds.
 
 ### Streaming responses
 
-Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.
+Certain endpoints stream responses as JSON objects.
 
 ## Generate a completion
 
@@ -32,22 +32,23 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
 POST /api/generate
 ```
 
-Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.
+Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
 
 ### Parameters
 
+`model` and *one* of `prompt` or `messages` is required.
 - `model`: (required) the [model name](#model-names)
 - `prompt`: the prompt to generate a response for
+- `messages`: the messages of the chat, this can be used to keep a chat memory
 
 Advanced parameters (optional):
 
 - `format`: the format to return a response in. Currently the only accepted value is `json`
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
 - `system`: system prompt to (overrides what is defined in the `Modelfile`)
 - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
-- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
-- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
+- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing messages yourself.
 
 ### JSON mode
 
@@ -57,12 +58,17 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
 
 ### Examples
 
-#### Request
+#### Request (Chat Mode)
 
 ```shell
-curl http://localhost:11434/api/generate -d '{
+curl -X POST http://localhost:11434/api/generate -d '{
   "model": "llama2",
-  "prompt": "Why is the sky blue?"
+  "messages": [
+    {
+      "role": "user",
+      "content": "why is the sky blue?"
+    }
+  ]
 }'
 ```
 
@@ -74,7 +80,10 @@ A stream of JSON objects is returned:
 {
   "model": "llama2",
   "created_at": "2023-08-04T08:52:19.385406455-07:00",
-  "response": "The",
+  "message": {
+    "role": "assisant",
+    "content": "The"
+  },
   "done": false
 }
 ```
@@ -89,8 +98,7 @@ The final response in the stream also includes additional data about the generat
 - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
 - `eval_count`: number of tokens the response
 - `eval_duration`: time in nanoseconds spent generating the response
-- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
-- `response`: empty if the response was streamed, if not streamed, this will contain the full response
+- `message`: omitted if the response was streamed, if not streamed, this will contain the full response
 
 To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
 
@@ -98,8 +106,44 @@ To calculate how fast the response is generated in tokens per second (token/s),
 {
   "model": "llama2",
   "created_at": "2023-08-04T19:22:45.499127Z",
-  "response": "",
-  "context": [1, 2, 3],
   "done": true,
   "total_duration": 5589157167,
   "load_duration": 3013701500,
+  "sample_count": 114,
+  "sample_duration": 81442000,
+  "prompt_eval_count": 46,
+  "prompt_eval_duration": 1160282000,
+  "eval_count": 113,
+  "eval_duration": 1325948000
+}
+```
+
+#### Request (Prompt)
+
+```shell
+curl -X POST http://localhost:11434/api/generate -d '{
+  "model": "llama2",
+  "prompt": "Why is the sky blue?"
+}'
+```
+
+#### Response
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T08:52:19.385406455-07:00",
+  "response": "The",
+  "done": false
+}
+```
+
+Final response:
+
+```json
+{
+  "model": "llama2",
+  "created_at": "2023-08-04T19:22:45.499127Z",
+  "done": true,
+  "total_duration": 5589157167,
+  "load_duration": 3013701500,
@@ -114,6 +158,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
 
 #### Request (No streaming)
 
+A response can be recieved in one reply when streaming is off. This applies to `prompt` and `messages`.
+
 ```shell
 curl http://localhost:11434/api/generate -d '{
   "model": "llama2",
@@ -131,7 +177,6 @@ If `stream` is set to `false`, the response will be a single JSON object:
   "model": "llama2",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
-  "context": [1, 2, 3],
   "done": true,
   "total_duration": 5589157167,
   "load_duration": 3013701500,
@@ -144,9 +189,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
 }
 ```
 
-#### Request (Raw mode)
+#### Request (Raw Mode)
 
-In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
+In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
 
 ```shell
 curl http://localhost:11434/api/generate -d '{
@@ -275,7 +320,6 @@ curl http://localhost:11434/api/generate -d '{
   "model": "llama2",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
-  "context": [1, 2, 3],
   "done": true,
   "total_duration": 5589157167,
   "load_duration": 3013701500,
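The docs above say to compute tokens per second as `eval_count` / `eval_duration`, with durations in nanoseconds. A small illustrative Go snippet (not part of this branch) that does that conversion against the GenerateResponse fields defined in api/types.go; the numbers come from the example final response.

```go
package main

import (
	"fmt"
	"time"

	"github.com/jmorganca/ollama/api"
)

// tokensPerSecond computes eval_count / eval_duration, with the duration
// converted from nanoseconds to seconds as described in docs/api.md.
func tokensPerSecond(r api.GenerateResponse) float64 {
	if r.EvalDuration <= 0 {
		return 0
	}
	return float64(r.EvalCount) / r.EvalDuration.Seconds()
}

func main() {
	// Values taken from the example final response above.
	r := api.GenerateResponse{EvalCount: 113, EvalDuration: 1325948000 * time.Nanosecond}
	fmt.Printf("%.1f tokens/s\n", tokensPerSecond(r)) // ~85.2 tokens/s
}
```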
@@ -125,11 +125,11 @@ PARAMETER <parameter> <parametervalue>
 
 #### Template Variables
 
-| Variable        | Description                                                                                                   |
-| --------------- | ------------------------------------------------------------------------------------------------------------ |
-| `{{ .System }}` | The system prompt used to specify custom behavior, this must also be set in the Modelfile as an instruction.  |
-| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input.                  |
-| `{{ .First }}`  | A boolean value used to render specific template information for the first generation of a session.           |
+| Variable                          | Description                                                                                                   |
+| --------------------------------- | ------------------------------------------------------------------------------------------------------------ |
+| `{{ .System }}`                   | The system prompt used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
+| `{{ .Prompt }}` or `{{ .User }}`  | The incoming prompt from the user, this is not specified in the model file and will be set based on input.   |
+| `{{ .First }}`                    | A boolean value used to render specific template information for the first generation of a session.          |
 
 ```modelfile
 TEMPLATE """
llm/llama.go (26 changes)

@@ -527,21 +527,9 @@ type prediction struct {
 
 const maxBufferSize = 512 * format.KiloByte
 
-func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
-    prevConvo, err := llm.Decode(ctx, prevContext)
-    if err != nil {
-        return err
-    }
-
-    // Remove leading spaces from prevConvo if present
-    prevConvo = strings.TrimPrefix(prevConvo, " ")
-
-    var nextContext strings.Builder
-    nextContext.WriteString(prevConvo)
-    nextContext.WriteString(prompt)
-
+func (llm *llama) Predict(ctx context.Context, prompt string, format string, fn func(api.GenerateResponse)) error {
     request := map[string]any{
-        "prompt":    nextContext.String(),
+        "prompt":    prompt,
         "stream":    true,
         "n_predict": llm.NumPredict,
         "n_keep":    llm.NumKeep,
@@ -620,19 +608,15 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
             }
 
             if p.Content != "" {
-                fn(api.GenerateResponse{Response: p.Content})
-                nextContext.WriteString(p.Content)
+                fn(api.GenerateResponse{
+                    Response: p.Content,
+                })
             }
 
             if p.Stop {
-                embd, err := llm.Encode(ctx, nextContext.String())
-                if err != nil {
-                    return fmt.Errorf("encoding context: %v", err)
-                }
 
                 fn(api.GenerateResponse{
                     Done:               true,
-                    Context:            embd,
                     PromptEvalCount:    p.Timings.PromptN,
                     PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
                     EvalCount:          p.Timings.PredictedN,
@@ -14,7 +14,7 @@ import (
 )
 
 type LLM interface {
-    Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
+    Predict(context.Context, string, string, func(api.GenerateResponse)) error
     Embedding(context.Context, string) ([]float64, error)
     Encode(context.Context, string) ([]int, error)
     Decode(context.Context, []int) (string, error)
|
@ -48,29 +48,21 @@ type Model struct {
|
||||
Options map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
|
||||
t := m.Template
|
||||
if request.Template != "" {
|
||||
t = request.Template
|
||||
}
|
||||
type PromptVars struct {
|
||||
System string
|
||||
Prompt string // prompt and user are considered the same thing
|
||||
User string
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").Parse(t)
|
||||
func (m *Model) Prompt(vars *PromptVars) (string, error) {
|
||||
tmpl, err := template.New("").Parse(m.Template)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var vars struct {
|
||||
First bool
|
||||
System string
|
||||
Prompt string
|
||||
}
|
||||
|
||||
vars.First = len(request.Context) == 0
|
||||
vars.System = m.System
|
||||
vars.Prompt = request.Prompt
|
||||
|
||||
if request.System != "" {
|
||||
vars.System = request.System
|
||||
if vars.System == "" {
|
||||
// use the default system prompt for this model if one is not specified
|
||||
vars.System = m.System
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
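An illustrative fragment (not in the diff, assumed to run in the same package as Model) showing the new PromptVars flow and the default-system-prompt fallback added above. The llama2-style template string is made up for the example.

```go
// Illustrative only; the template below is an invented llama2-style example.
m := Model{
	System:   "You are a concise assistant.",
	Template: "{{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}[INST] {{ .Prompt }} [/INST]",
}

// Leaving PromptVars.System empty falls back to the model's default system prompt.
p, err := m.Prompt(&PromptVars{Prompt: "why is the sky blue?"})
if err != nil {
	panic(err)
}
fmt.Println(p)
// Output (roughly):
// <<SYS>>You are a concise assistant.<</SYS>> [INST] why is the sky blue? [/INST]
```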
@@ -2,17 +2,15 @@ package server
 
 import (
     "testing"
-
-    "github.com/jmorganca/ollama/api"
 )
 
 func TestModelPrompt(t *testing.T) {
-    var m Model
-    req := api.GenerateRequest{
+    m := Model{
         Template: "a{{ .Prompt }}b",
-        Prompt:   "<h1>",
     }
-    s, err := m.Prompt(req)
+    s, err := m.Prompt(&PromptVars{
+        Prompt: "<h1>",
+    })
     if err != nil {
         t.Fatal(err)
     }
server/routes.go (158 changes)

@@ -168,8 +168,11 @@ func GenerateHandler(c *gin.Context) {
     case len(req.Format) > 0 && req.Format != "json":
         c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
         return
-    case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
-        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
+    case req.Raw && (len(req.Context) > 0 || len(req.Messages) > 0 || req.System != "" || req.Template != ""):
+        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, context, or messages"})
         return
+    case len(req.Messages) > 0 && (len(req.Context) > 0 || req.Raw || req.System != "" || req.Template != "" || req.Prompt != ""):
+        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "cannot specify context, raw, system, template, or prompt when using messages"})
+        return
     }
 
@@ -199,20 +202,36 @@ func GenerateHandler(c *gin.Context) {
 
     checkpointLoaded := time.Now()
 
-    prompt := req.Prompt
-    if !req.Raw {
-        prompt, err = model.Prompt(req)
+    var prompt string
+    sendContext := false
+    switch {
+    case req.Raw:
+        prompt = req.Prompt
+    case len(req.Messages) > 0:
+        prompt, err = promptFromMessages(model, req.Messages)
         if err != nil {
             c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
             return
         }
+    case req.Prompt != "":
+        prompt, err = promptFromRequestParams(c, model, req)
+        if err != nil {
+            status := http.StatusInternalServerError
+            if errors.Is(err, errInvalidRole) {
+                status = http.StatusBadRequest
+            }
+            c.JSON(status, gin.H{"error": err.Error()})
+            return
+        }
+        sendContext = true
     }
 
+    var generated strings.Builder
     ch := make(chan any)
     go func() {
         defer close(ch)
         // an empty request loads the model
-        if req.Prompt == "" && req.Template == "" && req.System == "" {
+        if req.Empty() {
             ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
             return
         }
@@ -223,44 +242,147 @@ func GenerateHandler(c *gin.Context) {
 
             r.Model = req.Model
             r.CreatedAt = time.Now().UTC()
 
+            // build up the full response to send back the context
+            if _, err := generated.WriteString(r.Response); err != nil {
+                ch <- gin.H{"error": err.Error()}
+                return
+            }
+
             if r.Done {
                 r.TotalDuration = time.Since(checkpointStart)
                 r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+                // if the response expects a context, encode it and send it back
+                if sendContext {
+                    embd, err := loaded.runner.Encode(c.Request.Context(), prompt+generated.String())
+                    if err != nil {
+                        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+                        return
+                    }
+                    r.Context = embd
+                }
             }
 
-            if req.Raw {
-                // in raw mode the client must manage history on their own
-                r.Context = nil
+            // determine if the client should get a prompt/response or message
+            if len(req.Messages) > 0 && !r.Done {
+                r.Message = &api.Message{Role: "assistant", Content: r.Response}
+                // do not send back response in the case of messages
+                r.Response = ""
             }
 
             ch <- r
         }
 
-        if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
+        if err := loaded.runner.Predict(c.Request.Context(), prompt, req.Format, fn); err != nil {
             ch <- gin.H{"error": err.Error()}
         }
     }()
 
     if req.Stream != nil && !*req.Stream {
-        var response api.GenerateResponse
-        generated := ""
+        // Wait for the channel to close
+        var r api.GenerateResponse
         for resp := range ch {
-            if r, ok := resp.(api.GenerateResponse); ok {
-                generated += r.Response
-                response = r
-            } else {
+            var ok bool
+            if r, ok = resp.(api.GenerateResponse); !ok {
                 c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
                 return
             }
         }
-
-        response.Response = generated
-        c.JSON(http.StatusOK, response)
+        if len(req.Messages) > 0 {
+            r.Message = &api.Message{Role: "assistant", Content: generated.String()}
+        } else {
+            r.Response = generated.String()
+        }
+        c.JSON(http.StatusOK, r)
         return
     }
 
     streamResponse(c, ch)
 }
 
+func promptFromRequestParams(c *gin.Context, model *Model, req api.GenerateRequest) (string, error) {
+    if req.Template != "" {
+        // override the default model template
+        model.Template = req.Template
+    }
+
+    var prompt strings.Builder
+    if req.Context != nil {
+        // TODO: context is deprecated, at some point the context logic within this conditional should be removed
+        prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
+        if err != nil {
+            return "", err
+        }
+
+        // Remove leading spaces from prevCtx if present
+        prevCtx = strings.TrimPrefix(prevCtx, " ")
+        prompt.WriteString(prevCtx)
+    }
+    p, err := model.Prompt(&PromptVars{
+        System: req.System,
+        Prompt: req.Prompt,
+        User:   req.Prompt,
+    })
+    if err != nil {
+        return "", err
+    }
+    prompt.WriteString(p)
+    return prompt.String(), nil
+}
+
+var errInvalidRole = errors.New("invalid message role")
+
+func promptFromMessages(model *Model, messages []api.Message) (string, error) {
+    flush := func(vars *PromptVars, model *Model, prompt *strings.Builder) error {
+        p, err := model.Prompt(vars)
+        if err != nil {
+            return err
+        }
+        prompt.WriteString(p)
+
+        vars.Prompt = ""
+        vars.User = ""
+        vars.System = ""
+        return nil
+    }
+
+    var prompt strings.Builder
+    vars := &PromptVars{}
+    for _, m := range messages {
+        if (m.Role == "system" || m.Role == "user") && vars.User != "" {
+            if err := flush(vars, model, &prompt); err != nil {
+                return "", err
+            }
+        }
+
+        if m.Role == "assistant" && (vars.User != "" || vars.System != "") {
+            if err := flush(vars, model, &prompt); err != nil {
+                return "", err
+            }
+        }
+
+        switch m.Role {
+        case "system":
+            vars.System = m.Content
+        case "user":
+            vars.Prompt = m.Content
+            vars.User = m.Content
+        case "assistant":
+            prompt.WriteString(m.Content)
+        default:
+            return "", fmt.Errorf("%w %q, role must be one of [system, user, assistant]", errInvalidRole, m.Role)
+        }
+    }
+
+    if vars.Prompt != "" || vars.System != "" {
+        if err := flush(vars, model, &prompt); err != nil {
+            return "", err
+        }
+    }
+
+    return prompt.String(), nil
+}
+
 func EmbeddingHandler(c *gin.Context) {
     loaded.mu.Lock()
     defer loaded.mu.Unlock()
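A sketch of an additional test (not in this diff) that would pin down how promptFromMessages flattens a conversation. It reuses the "a{{ .Prompt }}b" template from the existing TestModelPrompt and assumes it lives in the same server package, with "testing" and the api package imported as in the test file above.

```go
// Sketch only: user turns are rendered through the template, assistant turns
// are appended verbatim between renderings, so the history becomes one prompt.
func TestPromptFromMessages(t *testing.T) {
	m := Model{
		Template: "a{{ .Prompt }}b",
	}
	messages := []api.Message{
		{Role: "user", Content: "1"},
		{Role: "assistant", Content: "2"},
		{Role: "user", Content: "3"},
	}
	s, err := promptFromMessages(&m, messages)
	if err != nil {
		t.Fatal(err)
	}
	if s != "a1b2a3b" {
		t.Errorf("got %q, want %q", s, "a1b2a3b")
	}
}
```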