api: expose tokenize and detokenize endpoints

This commit is contained in:
Yurzs 2024-09-01 23:35:58 +07:00
parent 5f7b4a5e30
commit 19a388bfb8
No known key found for this signature in database
GPG Key ID: 7F998CD5EA377078
3 changed files with 108 additions and 0 deletions

View File

@ -360,6 +360,24 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*Embedd
return &resp, nil
}
// Tokenize tokenizes a string.
func (c *Client) Tokenize(ctx context.Context, req *TokenizeRequest) (*TokenizeResponse, error) {
var resp TokenizeResponse
if err := c.do(ctx, http.MethodPost, "/api/tokenize", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Detokenize detokenizes a string.
func (c *Client) Detokenize(ctx context.Context, req *DetokenizeRequest) (*DetokenizeResponse, error) {
var resp DetokenizeResponse
if err := c.do(ctx, http.MethodPost, "/api/detokenize", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// CreateBlob creates a blob from a file on the server. digest is the
// expected SHA256 digest of the file, and r represents the file.
func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {

View File

@ -293,6 +293,44 @@ type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
}
// TokenizeRequest is the request passed to [Client.Tokenize].
type TokenizeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
// TokenizeResponse is the response from [Client.Tokenize].
type TokenizeResponse struct {
Model string `json:"model"`
Tokens []int `json:"tokens"`
}
// DetokenizeRequest is the request passed to [Client.Detokenize].
type DetokenizeRequest struct {
Model string `json:"model"`
Tokens []int `json:"tokens"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
// DetokenizeResponse is the response from [Client.Detokenize].
type DetokenizeResponse struct {
Model string `json:"model"`
Text string `json:"text"`
}
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
Model string `json:"model"`

View File

@ -463,6 +463,56 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func (s *Server) TokenizeHandler(c *gin.Context) {
var req api.TokenizeRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
tokens, err := r.Tokenize(c.Request.Context(), req.Prompt)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, api.TokenizeResponse{Model: req.Model, Tokens: tokens})
}
func (s *Server) DetokenizeHandler(c *gin.Context) {
var req api.DetokenizeRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
text, err := r.Detokenize(c.Request.Context(), req.Tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, api.DetokenizeResponse{Model: req.Model, Text: text})
}
func (s *Server) PullHandler(c *gin.Context) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
@ -1086,6 +1136,8 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/tokenize", s.TokenizeHandler)
r.POST("/api/detokenize", s.DetokenizeHandler)
r.POST("/api/create", s.CreateHandler)
r.POST("/api/push", s.PushHandler)
r.POST("/api/copy", s.CopyHandler)