diff --git a/client/registry/apitype/apitype.go b/client/registry/apitype/apitype.go new file mode 100644 index 00000000..1a246a62 --- /dev/null +++ b/client/registry/apitype/apitype.go @@ -0,0 +1,95 @@ +package apitype + +import ( + "cmp" + "encoding/json" + "log/slog" + "net/url" + "slices" +) + +type Manifest struct { + Layers []*Layer `json:"layers"` +} + +type CompletePart struct { + URL string `json:"url"` // contains partNumber and uploadId from server + ETag string `json:"etag"` +} + +func queryFromString(s string) url.Values { + u, err := url.Parse(s) + if err != nil { + return nil + } + return u.Query() +} + +func (cp *CompletePart) Compare(o *CompletePart) int { + qa := queryFromString(cp.URL) + qb := queryFromString(o.URL) + return cmp.Or( + cmp.Compare(qa.Get("partNumber"), qb.Get("partNumber")), + cmp.Compare(qa.Get("uploadId"), qb.Get("uploadId")), + cmp.Compare(cp.ETag, o.ETag), + ) +} + +func SortCompleteParts(a []*CompletePart) { + slices.SortFunc(a, (*CompletePart).Compare) +} + +type Layer struct { + Digest string `json:"digest"` + MediaType string `json:"mediaType"` + Size int64 `json:"size"` + + // If present, URL is a remote location of the layer for fetching. + URL string `json:"url,omitempty"` +} + +func (l *Layer) LogValue() slog.Value { + return slog.GroupValue( + slog.String("digest", l.Digest), + slog.String("mediaType", l.MediaType), + slog.Int64("size", l.Size), + slog.String("url", l.URL), + ) +} + +type PushRequest struct { + Name string `json:"ref"` + Manifest json.RawMessage `json:"manifest,omitempty"` + + // Parts is a list of upload parts that the client upload in the previous + // push. + CompleteParts []*CompletePart `json:"part_uploads"` +} + +type Need struct { + Digest string `json:"digest"` + + Start int64 `json:"start"` + End int64 `json:"end"` + + // URL is the url to PUT the layer to. + // + // Clients must include it as the URL, along with the ETag in the + // response headers from the PUT request, in the next push request + // in the Uploaded field. + URL string `json:"url"` +} + +type PushResponse struct { + // Needs is a list of digests that the client needs to push before + // repushing the manifest. + Needs []*Need `json:"requirements,omitempty"` +} + +type PullResponse struct { + // Name is the name of the model being pulled. + Name string `json:"name"` + + // Manifest is the manifest of the model being pulled. + Manifest *Manifest `json:"manifest"` +} diff --git a/client/registry/registry.go b/client/registry/registry.go new file mode 100644 index 00000000..04fb7864 --- /dev/null +++ b/client/registry/registry.go @@ -0,0 +1,384 @@ +package registry + +import ( + "cmp" + "context" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "io" + "iter" + "log/slog" + "net/http" + "net/url" + "os" + "sync" + + "github.com/ollama/ollama/client/ollama" + "github.com/ollama/ollama/client/registry/apitype" + "golang.org/x/exp/constraints" + "golang.org/x/sync/errgroup" +) + +// Errors +var ( + ErrLayerNotFound = errors.New("layer not found") +) + +type Client struct { + BaseURL string + + // TODO(bmizerany): remove NameFill (once we remove model dep here) + NameFill string + + Logger *slog.Logger +} + +func (c *Client) logger() *slog.Logger { + return cmp.Or(c.Logger, slog.Default()) +} + +func (c *Client) oclient() *ollama.Client { + return &ollama.Client{ + BaseURL: c.BaseURL, + } +} + +type ReadAtSeekCloser interface { + io.ReaderAt + io.Seeker + io.Closer +} + +type Cache interface { + // LayerFile returns the absolute file path to the layer file for + // the given model digest. + // + // If the digest is invalid, or the layer does not exist, the empty + // string is returned. + LayerFile(digest string) string + + // OpenLayer opens the layer file for the given model digest and + // returns it, or an if any. The caller is responsible for closing + // the returned file. + OpenLayer(digest string) (ReadAtSeekCloser, error) + + // PutLayerFile moves the layer file at fromPath to the cache for + // the given model digest. It is a hack intended to short circuit a + // file copy operation. + // + // TODO(bmizerany): remove this; find a better way. Once we move + // this into a build package, we should be able to get rid of this. + PutLayerFile(digest, fromPath string) error + + // SetManifestData sets the provided manifest data for the given + // model name. If the manifest data is empty, the manifest is + // removed. If the manifeest exists, it is overwritten. + // + // It is an error to call SetManifestData with a name that is not + // complete. + SetManifestData(name string, data []byte) error + + // ManifestData returns the manifest data for the given model name. + // + // If the name incomplete, or the manifest does not exist, the empty + // string is returned. + ManifestData(name string) []byte +} + +// Pull pulls the manifest for name, and downloads any of its required +// layers that are not already in the cache. It returns an error if any part +// of the process fails, specifically: +func (c *Client) Pull(ctx context.Context, cache Cache, name string) error { + log := c.logger().With("name", name) + + pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil) + if err != nil { + return fmt.Errorf("ollama: pull: %w: %s", err, name) + } + + if pr.Manifest == nil || len(pr.Manifest.Layers) == 0 { + return fmt.Errorf("ollama: pull: invalid manifest: %s: no layers found", name) + } + + // download required layers we do not already have + for _, l := range pr.Manifest.Layers { + if cache.LayerFile(l.Digest) != "" { + continue + } + err := func() error { + log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size) + log.Debug("starting download") + + // TODO(bmizerany): stop using temp which might not + // be on same device as cache.... instead let cache + // give us a place to store parts... + tmpFile, err := os.CreateTemp("", "ollama-download-") + if err != nil { + return err + } + defer func() { + tmpFile.Close() + os.Remove(tmpFile.Name()) // in case we fail before committing + }() + + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(8) // TODO(bmizerany): make this configurable + + // TODO(bmizerany): make chunk size configurable + const chunkSize = 50 * 1024 * 1024 // 50MB + chunks(l.Size, chunkSize)(func(_ int, rng chunkRange[int64]) bool { + g.Go(func() (err error) { + defer func() { + if err == nil { + return + } + safeURL := redactAmzSignature(l.URL) + err = fmt.Errorf("%w: %s %s bytes=%s: %s", err, pr.Name, l.Digest, rng, safeURL) + }() + + log.Debug("downloading", "range", rng) + + // TODO(bmizerany): retry + // TODO(bmizerany): use real http client + // TODO(bmizerany): resumable + // TODO(bmizerany): multipart download + req, err := http.NewRequestWithContext(ctx, "GET", l.URL, nil) + if err != nil { + return err + } + req.Header.Set("Range", "bytes="+rng.String()) + + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode/100 != 2 { + log.Debug("unexpected non-2XX status code", "status", res.StatusCode) + return fmt.Errorf("unexpected status code fetching layer: %d", res.StatusCode) + } + if res.ContentLength != rng.Size() { + return fmt.Errorf("unexpected content length: %d", res.ContentLength) + } + w := io.NewOffsetWriter(tmpFile, rng.Start) + _, err = io.Copy(w, res.Body) + return err + }) + return true + }) + if err := g.Wait(); err != nil { + return err + } + + tmpFile.Close() // release our hold on the file before moving it + return cache.PutLayerFile(l.Digest, tmpFile.Name()) + }() + if err != nil { + return fmt.Errorf("ollama: pull: %w", err) + } + } + + // do not store the presigned URLs in the cache + for i := range pr.Manifest.Layers { + pr.Manifest.Layers[i].URL = "" + } + data, err := json.Marshal(pr.Manifest) + if err != nil { + return err + } + + // TODO(bmizerany): remove dep on model.Name + return cache.SetManifestData(name, data) +} + +type nopSeeker struct { + io.Reader +} + +func (nopSeeker) Seek(int64, int) (int64, error) { + return 0, nil +} + +// Push pushes a manifest to the server and responds to the server's +// requests for layer uploads, if any, and finally commits the manifest for +// name. It returns an error if any part of the process fails, specifically: +// +// If the server requests layers not found in the cache, ErrLayerNotFound is +// returned. +func (c *Client) Push(ctx context.Context, cache Cache, name string) error { + // TODO(bmizerany): remove dep on model.Name + manifest := cache.ManifestData(name) + if len(manifest) == 0 { + return fmt.Errorf("manifest not found: %s", name) + } + + var mu sync.Mutex + var completed []*apitype.CompletePart + push := func() (*apitype.PushResponse, error) { + v, err := ollama.Do[*apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{ + Name: name, + Manifest: manifest, + CompleteParts: completed, + }) + if err != nil { + return nil, fmt.Errorf("Do: %w", err) + } + return v, nil + } + + pr, err := push() + if err != nil { + return err + } + + var g errgroup.Group + for _, need := range pr.Needs { + g.Go(func() error { + f, err := cache.OpenLayer(need.Digest) + if err != nil { + return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest) + } + defer f.Close() + + cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End) + if err != nil { + return fmt.Errorf("PushLayer: %w: %s", err, need.Digest) + } + mu.Lock() + completed = append(completed, cp) + mu.Unlock() + return nil + }) + } + if err := g.Wait(); err != nil { + return fmt.Errorf("Push: Required: %w", err) + } + + if len(completed) > 0 { + pr, err := push() + if err != nil { + return err + } + if len(pr.Needs) > 0 { + var errs []error + for _, r := range pr.Needs { + errs = append(errs, fmt.Errorf("Push: server failed to find part: %q", r.Digest)) + } + return errors.Join(errs...) + } + } + + return cache.SetManifestData(name, manifest) +} + +func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) { + if start < 0 || end < start { + return nil, errors.New("start must satisfy 0 <= start <= end") + } + + file := io.NewSectionReader(body, start, end-start+1) + req, err := http.NewRequest("PUT", url, file) + if err != nil { + return nil, err + } + req.ContentLength = end - start + 1 + + // TODO(bmizerany): take content type param + req.Header.Set("Content-Type", "text/plain") + + if start != 0 || end != 0 { + req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", start, end)) + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != 200 { + e := parseS3Error(res) + return nil, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e) + } + cp := &apitype.CompletePart{ + URL: url, + ETag: res.Header.Get("ETag"), + // TODO(bmizerany): checksum + } + return cp, nil +} + +type s3Error struct { + XMLName xml.Name `xml:"Error"` + Code string `xml:"Code"` + Message string `xml:"Message"` + Resource string `xml:"Resource"` + RequestId string `xml:"RequestId"` +} + +func (e *s3Error) Error() string { + return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message) +} + +// parseS3Error parses an XML error response from S3. +func parseS3Error(res *http.Response) error { + var se *s3Error + if err := xml.NewDecoder(res.Body).Decode(&se); err != nil { + return err + } + return se +} + +// TODO: replace below by using upload pkg after we have rangefunc; until +// then, we need to keep this free of rangefunc for now. +type chunkRange[I constraints.Integer] struct { + // Start is the byte offset of the chunk. + Start I + + // End is the byte offset of the last byte in the chunk. + End I +} + +func (c chunkRange[I]) Size() I { + return c.End - c.Start + 1 +} + +func (c chunkRange[I]) String() string { + return fmt.Sprintf("%d-%d", c.Start, c.End) +} + +func (c chunkRange[I]) LogValue() slog.Value { + return slog.StringValue(c.String()) +} + +// Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset +// and size of the chunk. The last chunk may be smaller than chunkSize if size is +// not a multiple of chunkSize. +// +// The first part number is 1 and increases monotonically. +func chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, chunkRange[I]] { + return func(yield func(int, chunkRange[I]) bool) { + var n int + for off := I(0); off < size; off += chunkSize { + n++ + if !yield(n, chunkRange[I]{ + Start: off, + End: off + min(chunkSize, size-off) - 1, + }) { + return + } + } + } +} + +func redactAmzSignature(s string) string { + u, err := url.Parse(s) + if err != nil { + return "" + } + q := u.Query() + q.Set("X-Amz-Signature", "REDACTED") + u.RawQuery = q.Encode() + return u.String() +}