From 19ce10e49e29a405e7ddf7e6853a7df104c94006 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Wed, 1 May 2024 20:14:32 -0700 Subject: [PATCH] add a /tokenize endpoint --- api/types.go | 12 ++++++++ server/routes.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/api/types.go b/api/types.go index c210d419..d3668b89 100644 --- a/api/types.go +++ b/api/types.go @@ -195,6 +195,18 @@ type EmbeddingResponse struct { Embedding []float64 `json:"embedding"` } +type TokenizeRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + KeepAlive *Duration `json:"keep_alive,omitempty"` + + Options map[string]interface{} `json:"options"` +} + +type TokenizeResponse struct { + Tokens []int `json:"tokens"` +} + // CreateRequest is the request passed to [Client.Create]. type CreateRequest struct { Model string `json:"model"` diff --git a/server/routes.go b/server/routes.go index ec9f0e76..938330c8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -407,6 +407,84 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { c.JSON(http.StatusOK, resp) } +func (s *Server) TokenizeHandler(c *gin.Context) { + var req api.TokenizeRequest + err := c.ShouldBindJSON(&req) + switch { + case errors.Is(err, io.EOF): + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + case err != nil: + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if req.Model == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"}) + return + } + + model, err := GetModel(req.Model) + if err != nil { + var pErr *fs.PathError + if errors.As(err, &pErr) { + c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + opts, err := modelOptions(model, req.Options) + if err != nil { + if errors.Is(err, api.ErrInvalidOpts) { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var sessionDuration time.Duration + if req.KeepAlive == nil { + sessionDuration = getDefaultSessionDuration() + } else { + sessionDuration = req.KeepAlive.Duration + } + + rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration) + var runner *runnerRef + select { + case runner = <-rCh: + case err = <-eCh: + if errors.Is(err, context.Canceled) { + c.JSON(499, gin.H{"error": "request canceled"}) + return + } + + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // an empty request loads the model + if req.Prompt == "" { + c.JSON(http.StatusOK, api.TokenizeResponse{Tokens: []int{}}) + return + } + + tokens, err := runner.llama.Tokenize(c.Request.Context(), req.Prompt) + if err != nil { + slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) + return + } + + resp := api.TokenizeResponse{ + Tokens: tokens, + } + c.JSON(http.StatusOK, resp) +} + func (s *Server) PullModelHandler(c *gin.Context) { var req api.PullRequest err := c.ShouldBindJSON(&req) @@ -967,6 +1045,7 @@ func (s *Server) GenerateRoutes() http.Handler { r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) r.POST("/api/embeddings", s.EmbeddingsHandler) + r.POST("/api/tokenize", s.TokenizeHandler) r.POST("/api/create", s.CreateModelHandler) r.POST("/api/push", s.PushModelHandler) r.POST("/api/copy", s.CopyModelHandler)