forked from third-party-mirrors/ollama
Comparing main...jmorganca/ — 1 commit

Commit e117483ef6
@@ -98,7 +98,8 @@ type ChatResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Message Message `json:"message"`
 
 	Done bool `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
 
 	Metrics
 }
@@ -265,8 +266,9 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response string `json:"response"`
 
 	Done bool `json:"done"`
+	DoneReason string `json:"done_reason,omitempty"`
 	Context []int `json:"context,omitempty"`
 
 	Metrics
 }
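
Note the `omitempty` tag on the new field: `done_reason` is serialized only when set, so it appears on the final message of a stream and stays out of intermediate chunks. A minimal sketch of the wire format, using a hypothetical trimmed-down struct (the real `GenerateResponse` carries more fields):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// resp is a hypothetical subset of GenerateResponse, for illustration only.
type resp struct {
	Response   string `json:"response"`
	Done       bool   `json:"done"`
	DoneReason string `json:"done_reason,omitempty"`
}

func main() {
	chunk, _ := json.Marshal(resp{Response: "Hi", Done: false})
	final, _ := json.Marshal(resp{Done: true, DoneReason: "stop"})
	fmt.Println(string(chunk)) // {"response":"Hi","done":false}
	fmt.Println(string(final)) // {"response":"","done":true,"done_reason":"stop"}
}
```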
@@ -509,10 +509,13 @@ type ImageData struct {
 }
 
 type completion struct {
 	Content string `json:"content"`
 	Model string `json:"model"`
 	Prompt string `json:"prompt"`
 	Stop bool `json:"stop"`
+	StoppedEos bool `json:"stopped_eos"`
+	StoppedWord bool `json:"stopped_word"`
+	StoppedLimit bool `json:"stopped_limit"`
 
 	Timings struct {
 		PredictedN int `json:"predicted_n"`
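
The three new `Stopped*` flags appear to mirror the stop-reason booleans the llama.cpp example server reports in its completion payloads (an assumption; check the upstream server for the authoritative field list). A small self-contained sketch of decoding such a payload:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// completion here is a subset of the struct above, for illustration only.
type completion struct {
	Content      string `json:"content"`
	Stop         bool   `json:"stop"`
	StoppedEos   bool   `json:"stopped_eos"`
	StoppedWord  bool   `json:"stopped_word"`
	StoppedLimit bool   `json:"stopped_limit"`
}

func main() {
	// Hypothetical final event from the llama.cpp server.
	evt := []byte(`{"content":"", "stop":true, "stopped_limit":true}`)
	var c completion
	if err := json.Unmarshal(evt, &c); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", c) // StoppedLimit:true maps to done_reason "limit"
}
```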
@@ -532,6 +535,7 @@ type CompletionRequest struct {
 type CompletionResponse struct {
 	Content string
 	Done bool
+	DoneReason string
 	PromptEvalCount int
 	PromptEvalDuration time.Duration
 	EvalCount int
@@ -648,6 +652,8 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 				return fmt.Errorf("error parsing llm response stream: %s", line)
 			}
 
+			fmt.Println("c", string(evt))
+
 			var c completion
 			if err := json.Unmarshal(evt, &c); err != nil {
 				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
@@ -674,8 +680,18 @@ func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn
 			}
 
 			if c.Stop {
+				var doneReason string
+				switch {
+				case c.StoppedEos:
+					doneReason = "stop"
+				case c.StoppedWord:
+					doneReason = "stop"
+				case c.StoppedLimit:
+					doneReason = "limit"
+				}
 				fn(CompletionResponse{
 					Done: true,
+					DoneReason: doneReason,
 					PromptEvalCount: c.Timings.PromptN,
 					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 					EvalCount: c.Timings.PredictedN,
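
The switch above keeps separate cases for end-of-sequence and stop-word hits, but both surface as "stop"; only a prediction-limit hit surfaces as "limit". A standalone sketch of the same mapping as a hypothetical pure function:

```go
package main

import "fmt"

// doneReason condenses the switch in the diff; the original keeps separate
// cases for EOS and stop-word, but both yield "stop".
func doneReason(stoppedEos, stoppedWord, stoppedLimit bool) string {
	switch {
	case stoppedEos, stoppedWord:
		return "stop"
	case stoppedLimit:
		return "limit"
	}
	return "" // no flag set: leave the reason empty
}

func main() {
	fmt.Println(doneReason(true, false, false)) // stop
	fmt.Println(doneReason(false, true, false)) // stop
	fmt.Println(doneReason(false, false, true)) // limit
}
```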
@@ -91,7 +91,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
 }
 
 // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
-func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
+func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, int, error) {
 	type prompt struct {
 		System string
 		Prompt string
@@ -138,7 +138,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 
 			p.Response = msg.Content
 		default:
-			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+			return "", 0, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
 		}
 	}
 
@@ -151,7 +151,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 	for i, p := range prompts {
 		tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 
 		prompts[i].tokens = tokens + len(prompts[i].images)*768
@@ -160,15 +160,17 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 	// truncate images and prompts starting from the beginning of the list
 	// until either one prompt remains or the total tokens fits the context window
 	// TODO (jmorganca): this doesn't account for the context window room required for the response
+	var required int
 	for {
-		var required int
+		required = 0
 		for _, p := range prompts {
 			required += p.tokens
 		}
 
 		required += 1 // for bos token
 
-		if required <= window {
+		// leave ~1024 tokens for generation
+		if required <= max(1024, window/2) {
 			slog.Debug("prompt now fits in context window", "required", required, "window", window)
 			break
 		}
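
With this change the prompt must fit under `max(1024, window/2)` rather than the full window, reserving the remainder for generation; note that for large contexts the reserve is half the window, somewhat more than the ~1024 tokens the new comment suggests. A worked example of the threshold (the builtin `max` requires Go 1.21+):

```go
package main

import "fmt"

func main() {
	// Threshold the prompt must fit under, per context window size.
	for _, window := range []int{2048, 4096, 8192} {
		fmt.Printf("window=%d threshold=%d reserved=%d\n",
			window, max(1024, window/2), window-max(1024, window/2))
	}
	// window=2048 threshold=1024 reserved=1024
	// window=4096 threshold=2048 reserved=2048
	// window=8192 threshold=4096 reserved=4096
}
```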
@@ -194,7 +196,7 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 
 		tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 
 		prompts[0].tokens = tokens + len(prompts[0].images)*768
@@ -212,10 +214,10 @@ func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(str
 		// last prompt should leave the response unrendered (for completion)
 		rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
 		if err != nil {
-			return "", err
+			return "", 0, err
 		}
 		sb.WriteString(rendered)
 	}
 
-	return sb.String(), nil
+	return sb.String(), required, nil
 }
@@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
+			got, _, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
 			if err != nil {
 				t.Errorf("error = %v", err)
 			}
@@ -234,9 +234,10 @@ func GenerateHandler(c *gin.Context) {
 	// of `raw` mode so we need to check for it too
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
 		c.JSON(http.StatusOK, api.GenerateResponse{
 			CreatedAt: time.Now().UTC(),
 			Model: req.Model,
 			Done: true,
+			DoneReason: "load",
 		})
 		return
 	}
@@ -289,6 +290,14 @@ func GenerateHandler(c *gin.Context) {
 		prompt = sb.String()
 	}
 
+	tokens, err := loaded.llama.Tokenize(c.Request.Context(), prompt)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	opts.NumPredict = max(opts.NumCtx-len(tokens), 0)
+
 	slog.Debug("generate handler", "prompt", prompt)
 
 	ch := make(chan any)
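
The handler now tokenizes the rendered prompt and caps `NumPredict` to whatever context remains, clamping at zero so an over-long prompt cannot yield a negative limit. A sketch of the arithmetic with made-up numbers:

```go
package main

import "fmt"

func main() {
	numCtx := 4096       // model context window (opts.NumCtx)
	promptTokens := 3100 // len(tokens) after tokenizing the prompt
	fmt.Println(max(numCtx-promptTokens, 0)) // 996 tokens left for generation

	promptTokens = 5000 // prompt longer than the window
	fmt.Println(max(numCtx-promptTokens, 0)) // 0: clamped, never negative
}
```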
@@ -307,10 +316,11 @@ func GenerateHandler(c *gin.Context) {
 		}
 
 		resp := api.GenerateResponse{
 			Model: req.Model,
 			CreatedAt: time.Now().UTC(),
 			Done: r.Done,
-			Response: r.Content,
+			DoneReason: r.DoneReason,
+			Response: r.Content,
 			Metrics: api.Metrics{
 				PromptEvalCount: r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
@@ -1219,17 +1229,17 @@ func streamResponse(c *gin.Context, ch chan any) {
 }
 
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
-func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
+func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, int, error) {
 	encode := func(s string) ([]int, error) {
 		return loaded.llama.Tokenize(ctx, s)
 	}
 
-	prompt, err := ChatPrompt(template, messages, numCtx, encode)
+	prompt, tokens, err := ChatPrompt(template, messages, numCtx, encode)
 	if err != nil {
-		return "", err
+		return "", 0, err
 	}
 
-	return prompt, nil
+	return prompt, tokens, nil
 }
 
 func ChatHandler(c *gin.Context) {
@@ -1309,19 +1319,22 @@ func ChatHandler(c *gin.Context) {
 		}, req.Messages...)
 	}
 
-	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
+	prompt, tokens, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
+	opts.NumPredict = max(opts.NumCtx-tokens, 0)
+
 	// an empty request loads the model
 	if len(req.Messages) == 0 || prompt == "" {
 		resp := api.ChatResponse{
 			CreatedAt: time.Now().UTC(),
 			Model: req.Model,
 			Done: true,
-			Message: api.Message{Role: "assistant"},
+			DoneReason: "load",
+			Message: api.Message{Role: "assistant"},
 		}
 		c.JSON(http.StatusOK, resp)
 		return
@@ -1356,10 +1369,11 @@ func ChatHandler(c *gin.Context) {
 		loaded.expireTimer.Reset(sessionDuration)
 
 		resp := api.ChatResponse{
 			Model: req.Model,
 			CreatedAt: time.Now().UTC(),
 			Message: api.Message{Role: "assistant", Content: r.Content},
 			Done: r.Done,
+			DoneReason: r.DoneReason,
 			Metrics: api.Metrics{
 				PromptEvalCount: r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,