add a /tokenize endpoint
This commit is contained in:
parent
7e2bceceee
commit
19ce10e49e
12
api/types.go
12
api/types.go
@ -195,6 +195,18 @@ type EmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type TokenizeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
type TokenizeResponse struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
// CreateRequest is the request passed to [Client.Create].
|
||||
type CreateRequest struct {
|
||||
Model string `json:"model"`
|
||||
|
@ -407,6 +407,84 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) TokenizeHandler(c *gin.Context) {
|
||||
var req api.TokenizeRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
case err != nil:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
var pErr *fs.PathError
|
||||
if errors.As(err, &pErr) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
opts, err := modelOptions(model, req.Options)
|
||||
if err != nil {
|
||||
if errors.Is(err, api.ErrInvalidOpts) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var sessionDuration time.Duration
|
||||
if req.KeepAlive == nil {
|
||||
sessionDuration = getDefaultSessionDuration()
|
||||
} else {
|
||||
sessionDuration = req.KeepAlive.Duration
|
||||
}
|
||||
|
||||
rCh, eCh := s.sched.GetRunner(c.Request.Context(), model, opts, sessionDuration)
|
||||
var runner *runnerRef
|
||||
select {
|
||||
case runner = <-rCh:
|
||||
case err = <-eCh:
|
||||
if errors.Is(err, context.Canceled) {
|
||||
c.JSON(499, gin.H{"error": "request canceled"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// an empty request loads the model
|
||||
if req.Prompt == "" {
|
||||
c.JSON(http.StatusOK, api.TokenizeResponse{Tokens: []int{}})
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := runner.llama.Tokenize(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.TokenizeResponse{
|
||||
Tokens: tokens,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) PullModelHandler(c *gin.Context) {
|
||||
var req api.PullRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
@ -967,6 +1045,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
r.POST("/api/tokenize", s.TokenizeHandler)
|
||||
r.POST("/api/create", s.CreateModelHandler)
|
||||
r.POST("/api/push", s.PushModelHandler)
|
||||
r.POST("/api/copy", s.CopyModelHandler)
|
||||
|
Loading…
x
Reference in New Issue
Block a user