more fixes for mllama
This commit is contained in:
parent 5da1043680
commit c48e2cfc0d
@@ -675,7 +675,6 @@ const maxBufferSize = 512 * format.KiloByte
type ImageData struct {
    Data          []byte    `json:"data"`
    ID            int       `json:"id"`
    ImageData     []float32 `json:"image_data"`
    AspectRatioID int       `json:"aspect_ratio_id"`
}
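Elsewhere in this commit the preprocessed pixel values for mllama are carried in the existing Data byte slice, packed with encoding/binary (see the chatPrompt hunk further down), rather than in a separate float32 field. A minimal standalone sketch of that packing; the variable names are illustrative, not taken from the commit:

package main

import (
    "bytes"
    "encoding/binary"
    "fmt"
)

func main() {
    // Hypothetical output of image preprocessing: float32 pixel values plus an
    // aspect ratio bucket ID.
    pixels := []float32{0.1, 0.2, 0.3}
    aspectRatioID := 2

    // Pack the float32 slice into bytes, little-endian, the same way the new
    // chatPrompt code serializes preprocessed image data before storing it in
    // llm.ImageData.Data.
    buf := new(bytes.Buffer)
    if err := binary.Write(buf, binary.LittleEndian, pixels); err != nil {
        panic(err)
    }

    fmt.Println(len(buf.Bytes()), aspectRatioID) // 12 2
}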
@@ -159,11 +159,7 @@ func PadImage(img image.Image, outputSize, aspectRatio image.Point) image.Image
    }

    dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
    centerX := (paddedSize.X - img.Bounds().Max.X) / 2
    centerY := (paddedSize.Y - img.Bounds().Max.Y) / 2
    pos := image.Rect(centerX, centerY, centerX+img.Bounds().Max.X, centerY+img.Bounds().Max.Y)

    draw.Draw(dst, pos, img, image.Point{0, 0}, draw.Over)
    draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)

    return dst
}
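The change above stops centering the source image on the padded canvas: the centering math is dropped and the image is drawn at its own bounds, i.e. anchored at the top-left corner. A small self-contained sketch of the new placement behavior; the helper name is mine, not the commit's:

package main

import (
    "fmt"
    "image"
    "image/draw"
)

// placeTopLeft mirrors the new behavior in PadImage: the source image is drawn
// into the padded canvas at its own bounds (the origin), not centered.
func placeTopLeft(img image.Image, paddedSize image.Point) *image.RGBA {
    dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
    draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
    return dst
}

func main() {
    src := image.NewRGBA(image.Rect(0, 0, 2, 2))
    out := placeTopLeft(src, image.Point{X: 4, Y: 4})
    fmt.Println(out.Bounds()) // (0,0)-(4,4); src occupies the top-left 2x2 corner
}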
@@ -3,7 +3,10 @@ package server
import (
    "bytes"
    "context"
    "encoding/binary"
    "fmt"
    "log/slog"
    "strings"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/llm"
@@ -18,6 +21,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
    var system []api.Message

    // always include the last message
    n := len(msgs) - 1
    // in reverse, find all messages that fit into context window
@@ -39,16 +43,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
            return "", nil, err
        }

        c := len(s)
        ctxLen := len(s)
        if m.ProjectorPaths != nil {
            for _, m := range msgs[i:] {
                // images are represented as 768 sized embeddings
                // TODO: get embedding length from project metadata
                c += 768 * len(m.Images)
                ctxLen += 768 * len(m.Images)
            }
        }

        if c > opts.NumCtx {
        if ctxLen > opts.NumCtx {
            slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
            break
        } else {
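The rename from c to ctxLen does not change the estimate: each attached image is still budgeted as a 768-entry embedding when checking whether candidate messages fit in the context window (the TODO notes this should eventually come from projector metadata). A standalone illustration of that bookkeeping with made-up numbers:

package main

import "fmt"

func main() {
    numCtx := 2048 // opts.NumCtx: configured context window
    tokens := 1900 // tokens in the candidate messages
    numImages := 1 // images attached to those messages

    // Mirror of the estimate in chatPrompt: every image counts as a
    // fixed 768-entry embedding on top of the text tokens.
    ctxLen := tokens + 768*numImages

    fmt.Println(ctxLen > numCtx) // true: older messages would be truncated
}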
@@ -56,35 +60,58 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
        }
    }

    // truncate any messages that do not fit into the context window
    var b bytes.Buffer
    if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
        return "", nil, err
    currMsgIdx := n

    if checkMllamaModelFamily(m) {
        lastMsgIdx := len(msgs) - 1
        if len(msgs[lastMsgIdx].Images) == 1 {
            data, aspectRatioID, err := imageproc.Preprocess(msgs[lastMsgIdx].Images[0])
            if err != nil {
                return "", nil, err
            }

            buf := new(bytes.Buffer)
            err = binary.Write(buf, binary.LittleEndian, data)
            if err != nil {
                return "", nil, err
            }

            imgData := llm.ImageData{
                Data:          buf.Bytes(),
                AspectRatioID: aspectRatioID,
            }

            msgs[lastMsgIdx].Content = strings.TrimSpace("<|image|>" + msgs[lastMsgIdx].Content)
            images = append(images, imgData)
        }
    }

    preprocess := checkMllamaModelFamily(m)

    for _, m := range msgs[n:] {
        for _, i := range m.Images {
            if preprocess {
                data, aspectRatioID, err := imageproc.Preprocess(i)
                if err != nil {
                    return "", nil, err
                }
                images = append(images, llm.ImageData{
                    ID:            len(images),
                    ImageData:     data,
                    AspectRatioID: aspectRatioID,
                })
            } else {
                images = append(images, llm.ImageData{
                    ID:   len(images),
                    Data: i,
                })
    for cnt, msg := range msgs[currMsgIdx:] {
        for _, i := range msg.Images {
            imgData := llm.ImageData{
                ID:   len(images),
                Data: i,
            }

            imageTag := fmt.Sprintf("[img-%d]", imgData.ID)
            prompt := msg.Content

            if !strings.Contains(prompt, "[img]") {
                prompt = strings.TrimSpace("[img] " + prompt)
            }
            prompt = strings.Replace(prompt, "[img]", imageTag, 1)
            msgs[currMsgIdx+cnt].Content = prompt

            images = append(images, imgData)
        }
    }

    // truncate any messages that do not fit into the context window
    var b bytes.Buffer
    if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
        return "", nil, err
    }

    return b.String(), images, nil
}
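After this refactor the image-tag handling lives in chatPrompt itself: for mllama models the last message's single image is preprocessed and the content gets a <|image|> prefix, while other models get a numbered [img-N] tag spliced into the message text. A small standalone sketch of the [img-N] substitution; tagImage is an illustrative helper, not a function from the commit:

package main

import (
    "fmt"
    "strings"
)

// tagImage mirrors the substitution in the new chatPrompt loop: if the message
// has no explicit "[img]" placeholder, one is prepended, then the first
// placeholder is rewritten to a numbered "[img-N]" tag.
func tagImage(content string, id int) string {
    if !strings.Contains(content, "[img]") {
        content = strings.TrimSpace("[img] " + content)
    }
    return strings.Replace(content, "[img]", fmt.Sprintf("[img-%d]", id), 1)
}

func main() {
    fmt.Println(tagImage("describe this picture", 0)) // [img-0] describe this picture
    fmt.Println(tagImage("before [img] after", 1))    // before [img-1] after
}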
@@ -119,20 +119,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        return
    }

    model, err := GetModel(req.Model)
    if err != nil {
        switch {
        case os.IsNotExist(err):
            c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
        case err.Error() == "invalid model name":
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        default:
            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        }
        return
    }

    // expire the runner
    if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
        model, err := GetModel(req.Model)
        if err != nil {
            switch {
            case os.IsNotExist(err):
                c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
            case err.Error() == "invalid model name":
                c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            default:
                c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
            }
            return
        }
        s.sched.expireRunner(model)

        c.JSON(http.StatusOK, api.GenerateResponse{
@@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {

    checkpointLoaded := time.Now()

    // load the model
    if req.Prompt == "" {
        c.JSON(http.StatusOK, api.GenerateResponse{
            Model: req.Model,
@@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
        return
    }

    isMllama := checkMllamaModelFamily(model)
    if isMllama && len(req.Images) > 1 {
        c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
        return
    }

    images := make([]llm.ImageData, len(req.Images))
    for i := range req.Images {
        images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
@@ -212,7 +220,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
    }

    for _, i := range images {
        msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
        if isMllama {
            msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
        } else {
            msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
        }
    }

    values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
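Taken together with the chat-prompt changes, an mllama-family model now accepts at most one image per generate request and refers to it with a bare <|image|> marker instead of an [img-N] tag. A rough client-side sketch; the model name and image bytes are placeholders, not from the commit:

package main

import (
    "github.com/ollama/ollama/api"
)

func main() {
    imageBytes := []byte{} // raw image file contents would go here

    // A single image is accepted; sending more than one to an mllama model now
    // fails with HTTP 400 "this model only supports one image".
    _ = api.GenerateRequest{
        Model:  "llama3.2-vision", // placeholder model name
        Prompt: "What is in this picture?",
        Images: []api.ImageData{imageBytes},
    }
}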
@@ -421,22 +421,22 @@ func TestGenerate(t *testing.T) {

    t.Run("missing body", func(t *testing.T) {
        w := createRequest(t, s.GenerateHandler, nil)
        if w.Code != http.StatusBadRequest {
            t.Errorf("expected status 400, got %d", w.Code)
        if w.Code != http.StatusNotFound {
            t.Errorf("expected status 404, got %d", w.Code)
        }

        if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
        if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
            t.Errorf("mismatch (-got +want):\n%s", diff)
        }
    })

    t.Run("missing model", func(t *testing.T) {
        w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
        if w.Code != http.StatusBadRequest {
            t.Errorf("expected status 400, got %d", w.Code)
        if w.Code != http.StatusNotFound {
            t.Errorf("expected status 404, got %d", w.Code)
        }

        if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
        if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
            t.Errorf("mismatch (-got +want):\n%s", diff)
        }
    })
@@ -5,7 +5,6 @@ import (
    "embed"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "math"
    "slices"
@@ -302,22 +301,10 @@ func (t *Template) Execute(w io.Writer, v Values) error {
// into a single message. collate also collects and returns all system messages.
// collate mutates message content adding image tags ([img-%d]) as needed
func collate(msgs []api.Message) (string, []*api.Message) {
    var n int

    var system []string
    var collated []*api.Message
    for i := range msgs {
        msg := msgs[i]
        for range msg.Images {
            imageTag := fmt.Sprintf("[img-%d]", n)
            if !strings.Contains(msg.Content, "[img]") {
                msg.Content = strings.TrimSpace("[img] " + msg.Content)
            }

            msg.Content = strings.Replace(msg.Content, "[img]", imageTag, 1)
            n++
        }

        if msg.Role == "system" {
            system = append(system, msg.Content)
        }