From 53076d1bd13703c9a743d9e212225359e8f55136 Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Fri, 27 Sep 2024 11:12:34 +0100 Subject: [PATCH 1/4] cli: pull models without starting server, fixes #3369 --- .gitignore | 3 ++- cmd/cmd.go | 49 ++++++++++++++++++++++++++++++++++++++++++---- server/auth.go | 2 +- server/download.go | 10 +++++----- server/images.go | 12 ++++++------ server/model.go | 2 +- server/routes.go | 4 ++-- server/upload.go | 10 +++++----- 8 files changed, 67 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 87f8b007..795719a5 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ llm/build build/*/*/* !build/**/placeholder llama/build -__debug_bin* \ No newline at end of file +__debug_bin* +.ccls-cache diff --git a/cmd/cmd.go b/cmd/cmd.go index dc288e43..03131394 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -873,17 +873,58 @@ func CopyHandler(cmd *cobra.Command, args []string) error { return nil } +type pullFn func(ctx context.Context, name string, fn api.PullProgressFunc) error + +func getLocalPuller(insecure bool) (p pullFn, err error) { + p = func(ctx context.Context, name string, fn api.PullProgressFunc) error { + opts := &server.RegistryOptions{ + Insecure: insecure, + } + f := func(r api.ProgressResponse) { + fn(r) + } + return server.PullModel(ctx, name, opts, f) + } + return p, nil +} + +func getAPIPuller(insecure bool) (p pullFn, err error) { + client, err := api.ClientFromEnvironment() + if err != nil { + return p, err + } + p = func(ctx context.Context, name string, fn api.PullProgressFunc) error { + req := api.PullRequest{Name: name, Insecure: insecure} + return client.Pull(ctx, &req, fn) + } + return p, nil +} + func PullHandler(cmd *cobra.Command, args []string) error { insecure, err := cmd.Flags().GetBool("insecure") if err != nil { return err } - - client, err := api.ClientFromEnvironment() + local, err := cmd.Flags().GetBool("local") if err != nil { return err } + var pf pullFn + if local { + fmt.Println("pulling to local machine") + pf, err = getLocalPuller(insecure) + } else { + fmt.Println("requesting pull from server") + pf, err = getAPIPuller(insecure) + } + if err != nil { + return err + } + if pf == nil { + return fmt.Errorf("could not get puller: %w", err) + } + p := progress.NewProgress(os.Stderr) defer p.Stop() @@ -919,8 +960,7 @@ func PullHandler(cmd *cobra.Command, args []string) error { return nil } - request := api.PullRequest{Name: args[0], Insecure: insecure} - if err := client.Pull(cmd.Context(), &request, fn); err != nil { + if err := pf(cmd.Context(), args[0], fn); err != nil { return err } @@ -1372,6 +1412,7 @@ func NewCLI() *cobra.Command { } pullCmd.Flags().Bool("insecure", false, "Use an insecure registry") + pullCmd.Flags().Bool("local", false, "Pull to the local machine instead of making a server request") pushCmd := &cobra.Command{ Use: "push MODEL", diff --git a/server/auth.go b/server/auth.go index dcef5bf9..f7d27303 100644 --- a/server/auth.go +++ b/server/auth.go @@ -67,7 +67,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st headers.Add("Authorization", signature) - response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, ®istryOptions{}) + response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, &RegistryOptions{}) if err != nil { return "", err } diff --git a/server/download.go b/server/download.go index a3b53189..12ce8886 100644 --- a/server/download.go +++ b/server/download.go @@ -122,7 +122,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) { return n, nil } -func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { +func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { partFilePaths, err := filepath.Glob(b.Name + "-partial-*") if err != nil { return err @@ -176,7 +176,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r return nil } -func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) { +func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) { defer close(b.done) b.err = b.run(ctx, requestURL, opts) } @@ -207,7 +207,7 @@ func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error { } } -func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { +func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { defer blobDownloadManager.Delete(b.Digest) ctx, b.CancelFunc = context.WithCancel(ctx) @@ -228,7 +228,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis for { // shallow clone opts to be used in the closure // without affecting the outer opts. - newOpts := new(registryOptions) + newOpts := new(RegistryOptions) *newOpts = *opts newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error { @@ -454,7 +454,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) type downloadOpts struct { mp ModelPath digest string - regOpts *registryOptions + regOpts *RegistryOptions fn func(api.ProgressResponse) } diff --git a/server/images.go b/server/images.go index c88edc69..5af4ab01 100644 --- a/server/images.go +++ b/server/images.go @@ -49,7 +49,7 @@ const ( CapabilityInsert = Capability("insert") ) -type registryOptions struct { +type RegistryOptions struct { Insecure bool Username string Password string @@ -795,7 +795,7 @@ func PruneDirectory(path string) error { return nil } -func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { +func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) @@ -844,7 +844,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu return nil } -func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { +func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) // build deleteMap to prune unused layers @@ -950,7 +950,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu return nil } -func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) { +func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*Manifest, error) { requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) headers := make(http.Header) @@ -1011,7 +1011,7 @@ func getTokenSubject(token string) string { return fmt.Sprintf("%s", sub) } -func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { +func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { anonymous := true // access will default to anonymous if no user is found associated with the public key for range 2 { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) @@ -1069,7 +1069,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR return nil, errUnauthorized } -func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) { +func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure { requestURL.Scheme = "http" } diff --git a/server/model.go b/server/model.go index 124693d3..49487764 100644 --- a/server/model.go +++ b/server/model.go @@ -34,7 +34,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe m, err := ParseNamedManifest(name) switch { case errors.Is(err, os.ErrNotExist): - if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil { + if err := PullModel(ctx, name.String(), &RegistryOptions{}, fn); err != nil { return nil, err } diff --git a/server/routes.go b/server/routes.go index 23f9dbfd..64dca76b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -521,7 +521,7 @@ func (s *Server) PullHandler(c *gin.Context) { ch <- r } - regOpts := ®istryOptions{ + regOpts := &RegistryOptions{ Insecure: req.Insecure, } @@ -570,7 +570,7 @@ func (s *Server) PushHandler(c *gin.Context) { ch <- r } - regOpts := ®istryOptions{ + regOpts := &RegistryOptions{ Insecure: req.Insecure, } diff --git a/server/upload.go b/server/upload.go index 020e8955..d2f68e9e 100644 --- a/server/upload.go +++ b/server/upload.go @@ -50,7 +50,7 @@ const ( maxUploadPartSize int64 = 1000 * format.MegaByte ) -func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { +func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { p, err := GetBlobsPath(b.Digest) if err != nil { return err @@ -122,7 +122,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg // Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded // in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error. -func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) { +func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { defer blobUploadManager.Delete(b.Digest) ctx, b.CancelFunc = context.WithCancel(ctx) @@ -211,7 +211,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) { b.done = true } -func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error { +func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error { headers := make(http.Header) headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", strconv.FormatInt(part.Size, 10)) @@ -256,7 +256,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL * // retry uploading to the redirect URL for try := range maxRetries { - err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, ®istryOptions{}) + err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, &RegistryOptions{}) switch { case errors.Is(err, context.Canceled): return err @@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() { p.written = 0 } -func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error { +func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error { requestURL := mp.BaseURL() requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest) From 10253ce34414f7dcc918f42f092b08adb11ce88b Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Fri, 27 Sep 2024 17:42:11 +0100 Subject: [PATCH 2/4] fix: ensure preflight check is skipped when pulling locally --- cmd/cmd.go | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 03131394..78808194 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -894,6 +894,9 @@ func getAPIPuller(insecure bool) (p pullFn, err error) { return p, err } p = func(ctx context.Context, name string, fn api.PullProgressFunc) error { + if err := preflightCheck(ctx, client); err != nil { + return err + } req := api.PullRequest{Name: name, Insecure: insecure} return client.Pull(ctx, &req, fn) } @@ -1274,11 +1277,15 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { if err != nil { return err } - if err := client.Heartbeat(cmd.Context()); err != nil { + return preflightCheck(cmd.Context(), client) +} + +func preflightCheck(ctx context.Context, client *api.Client) error { + if err := client.Heartbeat(ctx); err != nil { if !strings.Contains(err.Error(), " refused") { return err } - if err := startApp(cmd.Context(), client); err != nil { + if err := startApp(ctx, client); err != nil { return errors.New("could not connect to ollama app, is it running?") } } @@ -1404,11 +1411,10 @@ func NewCLI() *cobra.Command { } pullCmd := &cobra.Command{ - Use: "pull MODEL", - Short: "Pull a model from a registry", - Args: cobra.ExactArgs(1), - PreRunE: checkServerHeartbeat, - RunE: PullHandler, + Use: "pull MODEL", + Short: "Pull a model from a registry", + Args: cobra.ExactArgs(1), + RunE: PullHandler, } pullCmd.Flags().Bool("insecure", false, "Use an insecure registry") From dfcc0caf8c48e3976092848820d1ee78d13e9458 Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Fri, 27 Sep 2024 18:04:20 +0100 Subject: [PATCH 3/4] fix: ensure that keypairs are generated before pulling models --- cmd/cmd.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/cmd.go b/cmd/cmd.go index 78808194..9591bf3c 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -877,6 +877,9 @@ type pullFn func(ctx context.Context, name string, fn api.PullProgressFunc) erro func getLocalPuller(insecure bool) (p pullFn, err error) { p = func(ctx context.Context, name string, fn api.PullProgressFunc) error { + if err := initializeKeypair(); err != nil { + return err + } opts := &server.RegistryOptions{ Insecure: insecure, } From 469c1b59d6a062449ca14ceadc361280103b7779 Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Mon, 30 Sep 2024 12:46:30 +0100 Subject: [PATCH 4/4] refactor: remove unnecessary nil check as per code review --- cmd/cmd.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 9591bf3c..73c57851 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -927,9 +927,6 @@ func PullHandler(cmd *cobra.Command, args []string) error { if err != nil { return err } - if pf == nil { - return fmt.Errorf("could not get puller: %w", err) - } p := progress.NewProgress(os.Stderr) defer p.Stop()