diff --git a/api/types.go b/api/types.go index e5291a02..cc26f948 100644 --- a/api/types.go +++ b/api/types.go @@ -291,6 +291,8 @@ type EmbeddingRequest struct { // EmbeddingResponse is the response from [Client.Embeddings]. type EmbeddingResponse struct { Embedding []float64 `json:"embedding"` + + Metrics } // CreateRequest is the request passed to [Client.Create]. diff --git a/cmd/cmd.go b/cmd/cmd.go index b8c9c640..0e67130a 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -512,6 +512,39 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } +func EmbeddingsHandler(cmd *cobra.Command, args []string) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + name := args[0] + + show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name}) + var statusError api.StatusError + switch { + case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: + if err := PullHandler(cmd, []string{name}); err != nil { + return err + } + + show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name}) + if err != nil { + return err + } + case err != nil: + return err + } + + opts := runOptions{ + Model: name, + Prompt: strings.Join(args[1:], " "), + ParentModel: show.Details.ParentModel, + } + + return embeddings(cmd, opts) +} + func errFromUnknownKey(unknownKeyErr error) error { // find SSH public key in the error message sshKeyPattern := `ssh-\w+ [^\s"]+` @@ -1199,6 +1232,49 @@ func generate(cmd *cobra.Command, opts runOptions) error { return nil } +func embeddings(cmd *cobra.Command, opts runOptions) error { + client, err := api.ClientFromEnvironment() + if err != nil { + return err + } + + p := progress.NewProgress(os.Stderr) + defer p.StopAndClear() + + spinner := progress.NewSpinner("") + p.Add("", spinner) + + req := api.EmbeddingRequest{ + Model: opts.Model, + Prompt: opts.Prompt, + Options: opts.Options, + } + + response, err := client.Embeddings(cmd.Context(), &req) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil + } + return err + } + + // cast response.Embedding to a string + embedding := fmt.Sprintf("%v", response.Embedding) + p.StopAndClear() + fmt.Println(embedding) + + verbose, err := cmd.Flags().GetBool("verbose") + if err != nil { + return err + } + + if verbose { + response.Summary() + } + + return nil +} + func RunServer(_ *cobra.Command, _ []string) error { if err := initializeKeypair(); err != nil { return err @@ -1381,6 +1457,18 @@ func NewCLI() *cobra.Command { runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().String("format", "", "Response format (e.g. json)") + embsCmd := &cobra.Command{ + Use: "embeddings MODEL [PROMPT]", + Short: "Get embeddings from a model", + Args: cobra.MinimumNArgs(1), + PreRunE: checkServerHeartbeat, + RunE: EmbeddingsHandler, + } + + embsCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)") + embsCmd.Flags().Bool("verbose", false, "Show timings for response") + embsCmd.Flags().Bool("insecure", false, "Use an insecure registry") + stopCmd := &cobra.Command{ Use: "stop MODEL", Short: "Stop a running model", @@ -1456,6 +1544,7 @@ func NewCLI() *cobra.Command { createCmd, showCmd, runCmd, + embsCmd, stopCmd, pullCmd, pushCmd, @@ -1496,6 +1585,7 @@ func NewCLI() *cobra.Command { createCmd, showCmd, runCmd, + embsCmd, stopCmd, pullCmd, pushCmd,