Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
949fc4eafa |
16
api/types.go
16
api/types.go
@ -31,6 +31,22 @@ func (e StatusError) Error() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// /api/chat
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatResponse struct {
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Message Message `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
type GenerateRequest struct {
|
type GenerateRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
|
@ -54,6 +54,54 @@ type Model struct {
|
|||||||
Embeddings []vector.Embedding
|
Embeddings []vector.Embedding
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) ChatPrompt(messages []api.Message) (string, error) {
|
||||||
|
tmpl, err := template.New("").Parse(m.Template)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var vars struct {
|
||||||
|
System string
|
||||||
|
Prompt string
|
||||||
|
First bool
|
||||||
|
}
|
||||||
|
|
||||||
|
vars.First = true
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
flush := func() {
|
||||||
|
tmpl.Execute(&sb, vars)
|
||||||
|
vars.System = ""
|
||||||
|
vars.Prompt = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// build the chat history from messages
|
||||||
|
for _, m := range messages {
|
||||||
|
if m.Role == "system" {
|
||||||
|
if vars.System != "" {
|
||||||
|
flush()
|
||||||
|
}
|
||||||
|
vars.System = m.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Role == "user" {
|
||||||
|
if vars.Prompt != "" {
|
||||||
|
flush()
|
||||||
|
}
|
||||||
|
vars.Prompt = m.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Role == "assistant" {
|
||||||
|
flush()
|
||||||
|
sb.Write([]byte(m.Content))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
flush()
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
|
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
|
||||||
t := m.Template
|
t := m.Template
|
||||||
if request.Template != "" {
|
if request.Template != "" {
|
||||||
|
@ -156,6 +156,54 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ChatModelHandler(c *gin.Context) {
|
||||||
|
loaded.mu.Lock()
|
||||||
|
defer loaded.mu.Unlock()
|
||||||
|
|
||||||
|
var req api.ChatRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := GetModel(req.Model)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, err := model.ChatPrompt(req.Messages)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var response string
|
||||||
|
fn := func(r api.GenerateResponse) {
|
||||||
|
response += r.Response
|
||||||
|
}
|
||||||
|
|
||||||
|
workDir := c.GetString("workDir")
|
||||||
|
if err := load(c.Request.Context(), workDir, model, nil, defaultSessionDuration); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(prompt)
|
||||||
|
|
||||||
|
if err := loaded.llm.Predict(c.Request.Context(), []int{}, prompt, fn); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, api.ChatResponse{
|
||||||
|
Message: api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: response,
|
||||||
|
},
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateHandler(c *gin.Context) {
|
func GenerateHandler(c *gin.Context) {
|
||||||
loaded.mu.Lock()
|
loaded.mu.Lock()
|
||||||
defer loaded.mu.Unlock()
|
defer loaded.mu.Unlock()
|
||||||
@ -552,6 +600,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
r.POST("/api/chat", ChatModelHandler)
|
||||||
r.POST("/api/pull", PullModelHandler)
|
r.POST("/api/pull", PullModelHandler)
|
||||||
r.POST("/api/generate", GenerateHandler)
|
r.POST("/api/generate", GenerateHandler)
|
||||||
r.POST("/api/embeddings", EmbeddingHandler)
|
r.POST("/api/embeddings", EmbeddingHandler)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user