Merge 9fbfa74219fa74d030c816cc99d3bee0d1c1923d into d7eb05b9361febead29a74e71ddffc2ebeff5302

This commit is contained in:
Javier Martinez 2024-11-14 13:56:55 +08:00 committed by GitHub
commit 74423ce545
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 92 additions and 0 deletions

View File

@ -291,6 +291,8 @@ type EmbeddingRequest struct {
// EmbeddingResponse is the response from [Client.Embeddings].
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
Metrics
}
// CreateRequest is the request passed to [Client.Create].

View File

@ -512,6 +512,39 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts)
}
func EmbeddingsHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
name := args[0]
show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
if err := PullHandler(cmd, []string{name}); err != nil {
return err
}
show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
if err != nil {
return err
}
case err != nil:
return err
}
opts := runOptions{
Model: name,
Prompt: strings.Join(args[1:], " "),
ParentModel: show.Details.ParentModel,
}
return embeddings(cmd, opts)
}
func errFromUnknownKey(unknownKeyErr error) error {
// find SSH public key in the error message
sshKeyPattern := `ssh-\w+ [^\s"]+`
@ -1199,6 +1232,49 @@ func generate(cmd *cobra.Command, opts runOptions) error {
return nil
}
func embeddings(cmd *cobra.Command, opts runOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
req := api.EmbeddingRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Options: opts.Options,
}
response, err := client.Embeddings(cmd.Context(), &req)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
return err
}
// cast response.Embedding to a string
embedding := fmt.Sprintf("%v", response.Embedding)
p.StopAndClear()
fmt.Println(embedding)
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return err
}
if verbose {
response.Summary()
}
return nil
}
func RunServer(_ *cobra.Command, _ []string) error {
if err := initializeKeypair(); err != nil {
return err
@ -1381,6 +1457,18 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)")
embsCmd := &cobra.Command{
Use: "embeddings MODEL [PROMPT]",
Short: "Get embeddings from a model",
Args: cobra.MinimumNArgs(1),
PreRunE: checkServerHeartbeat,
RunE: EmbeddingsHandler,
}
embsCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
embsCmd.Flags().Bool("verbose", false, "Show timings for response")
embsCmd.Flags().Bool("insecure", false, "Use an insecure registry")
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",
@ -1456,6 +1544,7 @@ func NewCLI() *cobra.Command {
createCmd,
showCmd,
runCmd,
embsCmd,
stopCmd,
pullCmd,
pushCmd,
@ -1496,6 +1585,7 @@ func NewCLI() *cobra.Command {
createCmd,
showCmd,
runCmd,
embsCmd,
stopCmd,
pullCmd,
pushCmd,