diff --git a/api/types.go b/api/types.go
index 96324edb..61383a36 100644
--- a/api/types.go
+++ b/api/types.go
@@ -31,6 +31,28 @@ func (e StatusError) Error() string {
 	}
 }
 
+// Message is one entry of a chat conversation exchanged with /api/chat.
+// Role names the author ("system", "user" or "assistant") and Content holds
+// the message text.
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// ChatRequest is the request body accepted by /api/chat: the name of the
+// model to run and the conversation so far.
+type ChatRequest struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+}
+
+// ChatResponse is the response body returned by /api/chat: the assistant's
+// reply and the time it was produced.
+type ChatResponse struct {
+	CreatedAt time.Time `json:"created_at"`
+	Message   Message   `json:"message"`
+}
+
 type GenerateRequest struct {
 	Model string `json:"model"`
 	Prompt string `json:"prompt"`
diff --git a/server/images.go b/server/images.go
index cd5224c9..55b84cab 100644
--- a/server/images.go
+++ b/server/images.go
@@ -54,6 +54,81 @@ type Model struct {
 	Embeddings []vector.Embedding
 }
 
+// ChatPrompt renders a chat history into a single prompt string using the
+// model's template.
+//
+// A system message and a user message are rendered through the template
+// together as one turn; a repeated system or user message starts a new turn.
+// An assistant message first renders any pending turn and is then appended
+// verbatim. Messages with any other role are ignored. The template's First
+// variable is true only for the first turn rendered. An error is returned if
+// the template cannot be parsed or executed.
+func (m *Model) ChatPrompt(messages []api.Message) (string, error) {
+	tmpl, err := template.New("").Parse(m.Template)
+	if err != nil {
+		return "", err
+	}
+
+	var vars struct {
+		System string
+		Prompt string
+		First  bool
+	}
+
+	vars.First = true
+
+	var sb strings.Builder
+
+	// flush renders the pending turn, if any, and resets vars for the next one.
+	flush := func() error {
+		if vars.System == "" && vars.Prompt == "" {
+			return nil
+		}
+
+		if err := tmpl.Execute(&sb, vars); err != nil {
+			return err
+		}
+
+		vars.System = ""
+		vars.Prompt = ""
+		vars.First = false
+		return nil
+	}
+
+	// build the chat history from messages
+	for _, msg := range messages {
+		switch msg.Role {
+		case "system":
+			// a second system message starts a new turn
+			if vars.System != "" {
+				if err := flush(); err != nil {
+					return "", err
+				}
+			}
+			vars.System = msg.Content
+		case "user":
+			// a second user message starts a new turn
+			if vars.Prompt != "" {
+				if err := flush(); err != nil {
+					return "", err
+				}
+			}
+			vars.Prompt = msg.Content
+		case "assistant":
+			if err := flush(); err != nil {
+				return "", err
+			}
+			sb.WriteString(msg.Content)
+		}
+	}
+
+	if err := flush(); err != nil {
+		return "", err
+	}
+
+	return sb.String(), nil
+}
+
 func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
 	t := m.Template
 	if request.Template != "" {
diff --git a/server/routes.go b/server/routes.go
index c68011df..5543574c 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -156,6 +156,58 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 	return nil
 }
 
+// ChatModelHandler handles POST /api/chat. It renders the request's messages
+// into a prompt with the model's template, runs the model on it, and replies
+// with the complete assistant message as a single JSON object.
+func ChatModelHandler(c *gin.Context) {
+	loaded.mu.Lock()
+	defer loaded.mu.Unlock()
+
+	var req api.ChatRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	model, err := GetModel(req.Model)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	prompt, err := model.ChatPrompt(req.Messages)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	// accumulate every piece of the response passed to fn into one message
+	var response string
+	fn := func(r api.GenerateResponse) {
+		response += r.Response
+	}
+
+	workDir := c.GetString("workDir")
+	if err := load(c.Request.Context(), workDir, model, nil, defaultSessionDuration); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	if err := loaded.llm.Predict(c.Request.Context(), []int{}, prompt, fn); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		// stop here so a success body is not written after the error
+		return
+	}
+
+	c.JSON(http.StatusOK, api.ChatResponse{
+		Message: api.Message{
+			Role:    "assistant",
+			Content: response,
+		},
+		CreatedAt: time.Now().UTC(),
+	})
+}
+
 func GenerateHandler(c *gin.Context) {
 	loaded.mu.Lock()
 	defer loaded.mu.Unlock()
@@ -552,6 +604,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 		},
 	)
 
+	r.POST("/api/chat", ChatModelHandler)
 	r.POST("/api/pull", PullModelHandler)
 	r.POST("/api/generate", GenerateHandler)
 	r.POST("/api/embeddings", EmbeddingHandler)