Merge 10802044b8dce04a4aa816d169cb1bbf7cf30e5a into d7eb05b9361febead29a74e71ddffc2ebeff5302

This commit is contained in:
Adrian Hesketh 2024-11-14 13:55:04 +08:00 committed by GitHub
commit 125ce5d046
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 78 additions and 30 deletions

1
.gitignore vendored
View File

@ -15,4 +15,5 @@ build/*/*/*
!build/**/placeholder
llama/build
__debug_bin*
.ccls-cache
llama/vendor

View File

@ -907,13 +907,57 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
return nil
}
type pullFn func(ctx context.Context, name string, fn api.PullProgressFunc) error
func getLocalPuller(insecure bool) (p pullFn, err error) {
p = func(ctx context.Context, name string, fn api.PullProgressFunc) error {
if err := initializeKeypair(); err != nil {
return err
}
opts := &server.RegistryOptions{
Insecure: insecure,
}
f := func(r api.ProgressResponse) {
fn(r)
}
return server.PullModel(ctx, name, opts, f)
}
return p, nil
}
func getAPIPuller(insecure bool) (p pullFn, err error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return p, err
}
p = func(ctx context.Context, name string, fn api.PullProgressFunc) error {
if err := preflightCheck(ctx, client); err != nil {
return err
}
req := api.PullRequest{Name: name, Insecure: insecure}
return client.Pull(ctx, &req, fn)
}
return p, nil
}
func PullHandler(cmd *cobra.Command, args []string) error {
insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return err
}
local, err := cmd.Flags().GetBool("local")
if err != nil {
return err
}
client, err := api.ClientFromEnvironment()
var pf pullFn
if local {
fmt.Println("pulling to local machine")
pf, err = getLocalPuller(insecure)
} else {
fmt.Println("requesting pull from server")
pf, err = getAPIPuller(insecure)
}
if err != nil {
return err
}
@ -953,8 +997,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return nil
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
if err := pf(cmd.Context(), args[0], fn); err != nil {
return err
}
@ -1268,11 +1311,15 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
if err != nil {
return err
}
if err := client.Heartbeat(cmd.Context()); err != nil {
return preflightCheck(cmd.Context(), client)
}
func preflightCheck(ctx context.Context, client *api.Client) error {
if err := client.Heartbeat(ctx); err != nil {
if !strings.Contains(err.Error(), " refused") {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
if err := startApp(ctx, client); err != nil {
return errors.New("could not connect to ollama app, is it running?")
}
}
@ -1398,14 +1445,14 @@ func NewCLI() *cobra.Command {
}
pullCmd := &cobra.Command{
Use: "pull MODEL",
Short: "Pull a model from a registry",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: PullHandler,
Use: "pull MODEL",
Short: "Pull a model from a registry",
Args: cobra.ExactArgs(1),
RunE: PullHandler,
}
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
pullCmd.Flags().Bool("local", false, "Pull to the local machine instead of making a server request")
pushCmd := &cobra.Command{
Use: "push MODEL",

View File

@ -67,7 +67,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
headers.Add("Authorization", signature)
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, &registryOptions{})
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, &RegistryOptions{})
if err != nil {
return "", err
}

View File

@ -122,7 +122,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
return n, nil
}
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil {
return err
@ -176,7 +176,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
return nil
}
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
defer close(b.done)
b.err = b.run(ctx, requestURL, opts)
}
@ -207,7 +207,7 @@ func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
}
}
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
@ -228,7 +228,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
for {
// shallow clone opts to be used in the closure
// without affecting the outer opts.
newOpts := new(registryOptions)
newOpts := new(RegistryOptions)
*newOpts = *opts
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
@ -454,7 +454,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
type downloadOpts struct {
mp ModelPath
digest string
regOpts *registryOptions
regOpts *RegistryOptions
fn func(api.ProgressResponse)
}

View File

@ -50,7 +50,7 @@ const (
CapabilityInsert = Capability("insert")
)
type registryOptions struct {
type RegistryOptions struct {
Insecure bool
Username string
Password string
@ -797,7 +797,7 @@ func PruneDirectory(path string) error {
return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
@ -846,7 +846,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
// build deleteMap to prune unused layers
@ -952,7 +952,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
@ -1013,7 +1013,7 @@ func getTokenSubject(token string) string {
return fmt.Sprintf("%s", sub)
}
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
anonymous := true // access will default to anonymous if no user is found associated with the public key
for range 2 {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
@ -1071,7 +1071,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
return nil, errUnauthorized
}
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
requestURL.Scheme = "http"
}

View File

@ -34,7 +34,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
if err := PullModel(ctx, name.String(), &RegistryOptions{}, fn); err != nil {
return nil, err
}

View File

@ -552,7 +552,7 @@ func (s *Server) PullHandler(c *gin.Context) {
ch <- r
}
regOpts := &registryOptions{
regOpts := &RegistryOptions{
Insecure: req.Insecure,
}
@ -601,7 +601,7 @@ func (s *Server) PushHandler(c *gin.Context) {
ch <- r
}
regOpts := &registryOptions{
regOpts := &RegistryOptions{
Insecure: req.Insecure,
}

View File

@ -50,7 +50,7 @@ const (
maxUploadPartSize int64 = 1000 * format.MegaByte
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
p, err := GetBlobsPath(b.Digest)
if err != nil {
return err
@ -122,7 +122,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
@ -211,7 +211,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
b.done = true
}
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.FormatInt(part.Size, 10))
@ -256,7 +256,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
// retry uploading to the redirect URL
for try := range maxRetries {
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, &registryOptions{})
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, &RegistryOptions{})
switch {
case errors.Is(err, context.Canceled):
return err
@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)