From bb6ab2391ca8965baf312394fe6d6a5af9606a7a Mon Sep 17 00:00:00 2001 From: Javier Martinez Date: Fri, 21 Jun 2024 11:17:34 +0200 Subject: [PATCH] allow to return embeddings using cli --- api/types.go | 2 ++ cmd/cmd.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/api/types.go b/api/types.go index 0a1189e7..dd7ad467 100644 --- a/api/types.go +++ b/api/types.go @@ -223,6 +223,8 @@ type EmbeddingRequest struct { // EmbeddingResponse is the response from [Client.Embeddings]. type EmbeddingResponse struct { Embedding []float64 `json:"embedding"` + + Metrics } // CreateRequest is the request passed to [Client.Create]. diff --git a/cmd/cmd.go b/cmd/cmd.go index 68197f72..a54b2f6d 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -369,6 +369,39 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generateInteractive(cmd, opts) } +func EmbeddingsHandler(cmd *cobra.Command, args []string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + name := args[0] + + show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name}) + var statusError api.StatusError + switch { + case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: + if err := PullHandler(cmd, []string{name}); err != nil { + return err + } + + show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name}) + if err != nil { + return err + } + case err != nil: + return err + } + + opts := runOptions{ + Model: name, + Prompt: strings.Join(args[1:], " "), + ParentModel: show.Details.ParentModel, + } + + return embeddings(cmd, opts) +} + func errFromUnknownKey(unknownKeyErr error) error { // find SSH public key in the error message sshKeyPattern := `ssh-\w+ [^\s"]+` @@ -1068,6 +1101,49 @@ func generate(cmd *cobra.Command, opts runOptions) error { return nil } +func embeddings(cmd *cobra.Command, opts runOptions) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + p := progress.NewProgress(os.Stderr) + defer p.StopAndClear() + + spinner := progress.NewSpinner("") + p.Add("", spinner) + + req := api.EmbeddingRequest{ + Model: opts.Model, + Prompt: opts.Prompt, + Options: opts.Options, + } + + response, err := client.Embeddings(cmd.Context(), &req) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + + // cast response.Embedding to a string + embedding := fmt.Sprintf("%v", response.Embedding) + p.StopAndClear() + fmt.Println(embedding) + + verbose, err := cmd.Flags().GetBool("verbose") + if err != nil { + return err + } + + if verbose { + response.Summary() + } + + return nil +} + func RunServer(cmd *cobra.Command, _ []string) error { if err := initializeKeypair(); err != nil { return err @@ -1249,6 +1325,19 @@ func NewCLI() *cobra.Command { runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().String("format", "", "Response format (e.g. json)") + + embsCmd := &cobra.Command{ + Use: "embeddings MODEL [PROMPT]", + Short: "Get embeddings from a model", + Args: cobra.MinimumNArgs(1), + PreRunE: checkServerHeartbeat, + RunE: EmbeddingsHandler, + } + + embsCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)") + embsCmd.Flags().Bool("verbose", false, "Show timings for response") + embsCmd.Flags().Bool("insecure", false, "Use an insecure registry") + serveCmd := &cobra.Command{ Use: "serve", Aliases: []string{"start"}, @@ -1316,6 +1405,7 @@ func NewCLI() *cobra.Command { createCmd, showCmd, runCmd, + embsCmd, pullCmd, pushCmd, listCmd, @@ -1353,6 +1443,7 @@ func NewCLI() *cobra.Command { createCmd, showCmd, runCmd, + embsCmd, pullCmd, pushCmd, listCmd,