From 090d08422b361bcbef82a04d0e6e160caaad8f89 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 2 Oct 2023 13:34:07 -0700 Subject: [PATCH] handle unexpected eofs --- server/download.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/server/download.go b/server/download.go index 6023de31..f3a5b378 100644 --- a/server/download.go +++ b/server/download.go @@ -45,8 +45,6 @@ type blobDownloadPart struct { } func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { - b.done = make(chan struct{}, 1) - partFilePaths, err := filepath.Glob(b.Name + "-partial-*") if err != nil { return err @@ -109,6 +107,9 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis b.Truncate(b.Total) + b.done = make(chan struct{}, 1) + defer close(b.done) + g, ctx := errgroup.WithContext(ctx) g.SetLimit(64) for i := range b.Parts { @@ -154,7 +155,6 @@ func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *Regis return err } - close(b.done) return nil } @@ -174,14 +174,19 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, i defer resp.Body.Close() n, err := io.Copy(w, io.TeeReader(resp.Body, b)) - if err != nil && !errors.Is(err, io.EOF) { + if err != nil && !errors.Is(err, context.Canceled) { // rollback progress b.Completed.Add(-n) return err } part.Completed += n - return b.writePart(partName, part) + if err := b.writePart(partName, part); err != nil { + return err + } + + // return nil or context.Canceled + return err } func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) { @@ -221,6 +226,10 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) ticker := time.NewTicker(60 * time.Millisecond) for { select { + case <-b.done: + if b.Completed.Load() != b.Total { + return io.ErrUnexpectedEOF + } case <-ticker.C: case <-ctx.Done(): if b.refCount.Add(-1) == 0 {