cli: pull models without starting server, fixes #3369
This commit is contained in:
parent
f40bb398f6
commit
53076d1bd1
3
.gitignore
vendored
3
.gitignore
vendored
@ -15,4 +15,5 @@ llm/build
|
||||
build/*/*/*
|
||||
!build/**/placeholder
|
||||
llama/build
|
||||
__debug_bin*
|
||||
__debug_bin*
|
||||
.ccls-cache
|
||||
|
49
cmd/cmd.go
49
cmd/cmd.go
@ -873,17 +873,58 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type pullFn func(ctx context.Context, name string, fn api.PullProgressFunc) error
|
||||
|
||||
func getLocalPuller(insecure bool) (p pullFn, err error) {
|
||||
p = func(ctx context.Context, name string, fn api.PullProgressFunc) error {
|
||||
opts := &server.RegistryOptions{
|
||||
Insecure: insecure,
|
||||
}
|
||||
f := func(r api.ProgressResponse) {
|
||||
fn(r)
|
||||
}
|
||||
return server.PullModel(ctx, name, opts, f)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func getAPIPuller(insecure bool) (p pullFn, err error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return p, err
|
||||
}
|
||||
p = func(ctx context.Context, name string, fn api.PullProgressFunc) error {
|
||||
req := api.PullRequest{Name: name, Insecure: insecure}
|
||||
return client.Pull(ctx, &req, fn)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
insecure, err := cmd.Flags().GetBool("insecure")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
local, err := cmd.Flags().GetBool("local")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var pf pullFn
|
||||
if local {
|
||||
fmt.Println("pulling to local machine")
|
||||
pf, err = getLocalPuller(insecure)
|
||||
} else {
|
||||
fmt.Println("requesting pull from server")
|
||||
pf, err = getAPIPuller(insecure)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pf == nil {
|
||||
return fmt.Errorf("could not get puller: %w", err)
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
@ -919,8 +960,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
||||
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
|
||||
if err := pf(cmd.Context(), args[0], fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1372,6 +1412,7 @@ func NewCLI() *cobra.Command {
|
||||
}
|
||||
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
pullCmd.Flags().Bool("local", false, "Pull to the local machine instead of making a server request")
|
||||
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push MODEL",
|
||||
|
@ -67,7 +67,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
|
||||
|
||||
headers.Add("Authorization", signature)
|
||||
|
||||
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, ®istryOptions{})
|
||||
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, &RegistryOptions{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
||||
if err != nil {
|
||||
return err
|
||||
@ -176,7 +176,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
|
||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
|
||||
defer close(b.done)
|
||||
b.err = b.run(ctx, requestURL, opts)
|
||||
}
|
||||
@ -207,7 +207,7 @@ func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
defer blobDownloadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
@ -228,7 +228,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
for {
|
||||
// shallow clone opts to be used in the closure
|
||||
// without affecting the outer opts.
|
||||
newOpts := new(registryOptions)
|
||||
newOpts := new(RegistryOptions)
|
||||
*newOpts = *opts
|
||||
|
||||
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
@ -454,7 +454,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
type downloadOpts struct {
|
||||
mp ModelPath
|
||||
digest string
|
||||
regOpts *registryOptions
|
||||
regOpts *RegistryOptions
|
||||
fn func(api.ProgressResponse)
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ const (
|
||||
CapabilityInsert = Capability("insert")
|
||||
)
|
||||
|
||||
type registryOptions struct {
|
||||
type RegistryOptions struct {
|
||||
Insecure bool
|
||||
Username string
|
||||
Password string
|
||||
@ -795,7 +795,7 @@ func PruneDirectory(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
@ -844,7 +844,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
// build deleteMap to prune unused layers
|
||||
@ -950,7 +950,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*Manifest, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
|
||||
headers := make(http.Header)
|
||||
@ -1011,7 +1011,7 @@ func getTokenSubject(token string) string {
|
||||
return fmt.Sprintf("%s", sub)
|
||||
}
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
anonymous := true // access will default to anonymous if no user is found associated with the public key
|
||||
for range 2 {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
@ -1069,7 +1069,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
|
||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
requestURL.Scheme = "http"
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
m, err := ParseNamedManifest(name)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||
if err := PullModel(ctx, name.String(), &RegistryOptions{}, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -521,7 +521,7 @@ func (s *Server) PullHandler(c *gin.Context) {
|
||||
ch <- r
|
||||
}
|
||||
|
||||
regOpts := ®istryOptions{
|
||||
regOpts := &RegistryOptions{
|
||||
Insecure: req.Insecure,
|
||||
}
|
||||
|
||||
@ -570,7 +570,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||
ch <- r
|
||||
}
|
||||
|
||||
regOpts := ®istryOptions{
|
||||
regOpts := &RegistryOptions{
|
||||
Insecure: req.Insecure,
|
||||
}
|
||||
|
||||
|
@ -50,7 +50,7 @@ const (
|
||||
maxUploadPartSize int64 = 1000 * format.MegaByte
|
||||
)
|
||||
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -122,7 +122,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
|
||||
|
||||
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
|
||||
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
|
||||
func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||
func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
|
||||
defer blobUploadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
@ -211,7 +211,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||
b.done = true
|
||||
}
|
||||
|
||||
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
|
||||
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", strconv.FormatInt(part.Size, 10))
|
||||
@ -256,7 +256,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||
|
||||
// retry uploading to the redirect URL
|
||||
for try := range maxRetries {
|
||||
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, ®istryOptions{})
|
||||
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, &RegistryOptions{})
|
||||
switch {
|
||||
case errors.Is(err, context.Canceled):
|
||||
return err
|
||||
@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
|
||||
p.written = 0
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user