diff --git a/server/routes.go b/server/routes.go index f9749db2..8c95f158 100644 --- a/server/routes.go +++ b/server/routes.go @@ -216,7 +216,11 @@ func GenerateHandler(c *gin.Context) { case req.Prompt != "": prompt, err = promptFromRequestParams(c, model, req) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + status := http.StatusInternalServerError + if errors.Is(err, errInvalidRole) { + status = http.StatusBadRequest + } + c.JSON(status, gin.H{"error": err.Error()}) return } sendContext = true @@ -325,6 +329,8 @@ func promptFromRequestParams(c *gin.Context, model *Model, req api.GenerateReque return prompt.String(), nil } +var errInvalidRole = errors.New("invalid message role") + func promptFromMessages(model *Model, messages []api.Message) (string, error) { flush := func(vars *PromptVars, model *Model, prompt *strings.Builder) error { p, err := model.Prompt(vars) @@ -360,6 +366,8 @@ func promptFromMessages(model *Model, messages []api.Message) (string, error) { vars.Prompt = m.Content case "assistant": prompt.WriteString(m.Content) + default: + return "", fmt.Errorf("%w %q, role must be one of [system, user, assistant]", errInvalidRole, m.Role) } }