diff --git a/server/images.go b/server/images.go index 73efc1c5..e8111b72 100644 --- a/server/images.go +++ b/server/images.go @@ -1113,7 +1113,11 @@ func GetSHA256Digest(r io.Reader) (string, int) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n) } +type requestContextKey string + func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *RegistryOptions) (string, error) { + retry, _ := ctx.Value(requestContextKey("retry")).(int) + url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository()) if layer.From != "" { url = fmt.Sprintf("%s/v2/%s/blobs/uploads/?mount=%s&from=%s", mp.Registry, mp.GetNamespaceRepository(), layer.Digest, layer.From) @@ -1126,8 +1130,25 @@ func startUpload(ctx context.Context, mp ModelPath, layer *Layer, regOpts *Regis } defer resp.Body.Close() - // Check for success - if resp.StatusCode != http.StatusAccepted && resp.StatusCode != http.StatusCreated { + switch resp.StatusCode { + case http.StatusAccepted, http.StatusCreated: + // noop + case http.StatusUnauthorized: + if retry > MaxRetries { + return "", fmt.Errorf("max retries exceeded: %s", resp.Status) + } + + auth := resp.Header.Get("www-authenticate") + authRedir := ParseAuthRedirectString(auth) + token, err := getAuthToken(ctx, authRedir, regOpts) + if err != nil { + return "", err + } + + regOpts.Token = token + ctx = context.WithValue(ctx, requestContextKey("retry"), retry+1) + return startUpload(ctx, mp, layer, regOpts) + default: body, _ := io.ReadAll(resp.Body) return "", fmt.Errorf("on upload registry responded with code %d: %s", resp.StatusCode, body) } @@ -1229,15 +1250,6 @@ func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Lay } func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { - retryCtx := ctx.Value("retries") - var retries int - var ok bool - if retries, ok = retryCtx.(int); ok { - if retries > MaxRetries { - return nil, fmt.Errorf("maximum retries hit; are you sure you have access to this resource?") - } - } - if !strings.HasPrefix(url, "http") { if regOpts.Insecure { url = "http://" + url @@ -1246,18 +1258,7 @@ func makeRequest(ctx context.Context, method, url string, headers map[string]str } } - // make a copy of the body in case we need to try the call to makeRequest again - var buf bytes.Buffer - if body != nil { - _, err := io.Copy(&buf, body) - if err != nil { - return nil, err - } - } - - bodyCopy := bytes.NewReader(buf.Bytes()) - - req, err := http.NewRequest(method, url, bodyCopy) + req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, err } @@ -1281,25 +1282,12 @@ func makeRequest(ctx context.Context, method, url string, headers map[string]str return nil }, } + resp, err := client.Do(req) if err != nil { return nil, err } - // if the request is unauthenticated, try to authenticate and make the request again - if resp.StatusCode == http.StatusUnauthorized { - auth := resp.Header.Get("Www-Authenticate") - authRedir := ParseAuthRedirectString(string(auth)) - token, err := getAuthToken(ctx, authRedir, regOpts) - if err != nil { - return nil, err - } - regOpts.Token = token - bodyCopy = bytes.NewReader(buf.Bytes()) - ctx = context.WithValue(ctx, "retries", retries+1) - return makeRequest(ctx, method, url, headers, bodyCopy, regOpts) - } - return resp, nil }