856 lines
20 KiB
Go
856 lines
20 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"cmp"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"golang.org/x/exp/slices"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/auth"
|
|
"github.com/ollama/ollama/format"
|
|
"github.com/ollama/ollama/llm"
|
|
"github.com/ollama/ollama/server/envconfig"
|
|
"github.com/ollama/ollama/types/errtypes"
|
|
"github.com/ollama/ollama/types/model"
|
|
"github.com/ollama/ollama/version"
|
|
)
|
|
|
|
// registryOptions carries the transport and credential settings used when
// talking to a model registry.
type registryOptions struct {
	// Insecure selects plain HTTP instead of HTTPS.
	Insecure bool

	// Username and Password are sent as HTTP basic auth; Token, when
	// non-empty, is sent as a bearer token and takes precedence over them.
	Username, Password, Token string
}
|
|
|
|
type Model struct {
|
|
Name model.Name
|
|
Config ConfigV2
|
|
ModelPath string
|
|
ParentModel string
|
|
AdapterPaths []string
|
|
ProjectorPaths []string
|
|
Template string
|
|
System string
|
|
License []string
|
|
Digest string
|
|
Options map[string]interface{}
|
|
Messages []Message
|
|
}
|
|
|
|
func (m *Model) IsEmbedding() bool {
|
|
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
|
|
}
|
|
|
|
// String renders m as Modelfile text. Commands are emitted in this order:
// the weights ("model"), adapters, projectors, template, system prompt,
// parameters, licenses, and messages.
//
// Parameters are emitted in sorted key order so repeated calls produce
// identical output. Each message is encoded as "role: content", the same
// form CreateModel splits with strings.Cut(args, ": ").
func (m *Model) String() string {
	var modelfile model.File

	add := func(name, args string) {
		modelfile.Commands = append(modelfile.Commands, model.Command{Name: name, Args: args})
	}

	add("model", m.ModelPath)

	for _, adapter := range m.AdapterPaths {
		add("adapter", adapter)
	}

	// Projectors are written as additional "model" commands.
	for _, projector := range m.ProjectorPaths {
		add("model", projector)
	}

	if m.Template != "" {
		add("template", m.Template)
	}

	if m.System != "" {
		add("system", m.System)
	}

	// Map iteration order is randomized; sort the keys for deterministic output.
	keys := make([]string, 0, len(m.Options))
	for k := range m.Options {
		keys = append(keys, k)
	}
	slices.Sort(keys)

	for _, k := range keys {
		switch v := m.Options[k].(type) {
		case []any:
			// list-valued parameters become one command per element
			for _, s := range v {
				add(k, fmt.Sprintf("%v", s))
			}
		default:
			add(k, fmt.Sprintf("%v", v))
		}
	}

	for _, license := range m.License {
		add("license", license)
	}

	for _, msg := range m.Messages {
		add("message", fmt.Sprintf("%s: %s", msg.Role, msg.Content))
	}

	return modelfile.String()
}
|
|
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
}
|
|
|
|
type ManifestV2 struct {
|
|
SchemaVersion int `json:"schemaVersion"`
|
|
MediaType string `json:"mediaType"`
|
|
Config *Layer `json:"config"`
|
|
Layers []*Layer `json:"layers"`
|
|
}
|
|
|
|
type (
	// ConfigV2 is the image config blob. CreateModel writes it and GetModel
	// decodes it into Model.Config.
	ConfigV2 struct {
		ModelFormat   string   `json:"model_format"`
		ModelFamily   string   `json:"model_family"`
		ModelFamilies []string `json:"model_families"`
		ModelType     string   `json:"model_type"`
		FileType      string   `json:"file_type"`

		// required by spec
		Architecture string `json:"architecture"`
		OS           string `json:"os"`
		RootFS       RootFS `json:"rootfs"`
	}

	// RootFS lists the digests of the layers that make up the image.
	RootFS struct {
		Type    string   `json:"type"`
		DiffIDs []string `json:"diff_ids"`
	}
)
|
|
|
|
// GetModel resolves name to its local manifest and builds a Model from the
// manifest's config blob and layers.
//
// Layers are dispatched on media type:
//   - the weights layer sets ModelPath and ParentModel;
//   - adapter and projector layers append their blob paths;
//   - template/prompt and system layers are read as text;
//   - params and messages layers are JSON-decoded;
//   - license layers are appended to License.
//
// Embed layers are ignored with a deprecation warning. Template defaults to
// "{{ .Prompt }}" when no template or prompt layer is present.
func GetModel(name model.Name) (*Model, error) {
	manifest, err := ParseNamedManifest(name)
	if err != nil {
		return nil, err
	}

	// Named m (not model) so the model package stays addressable in this scope.
	m := &Model{
		Name:     name,
		Digest:   manifest.digest,
		Template: "{{ .Prompt }}",
		License:  []string{},
	}

	configPath, err := GetBlobsPath(manifest.Config.Digest)
	if err != nil {
		return nil, err
	}

	if err := decodeJSONFromFile(configPath, &m.Config); err != nil {
		return nil, err
	}

	for _, layer := range manifest.Layers {
		filename, err := GetBlobsPath(layer.Digest)
		if err != nil {
			return nil, err
		}

		switch layer.MediaType {
		case "application/vnd.ollama.image.model":
			m.ModelPath = filename
			m.ParentModel = layer.From
		case "application/vnd.ollama.image.embed":
			// Deprecated in versions > 0.1.2
			// TODO: remove this warning in a future version
			slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
		case "application/vnd.ollama.image.adapter":
			m.AdapterPaths = append(m.AdapterPaths, filename)
		case "application/vnd.ollama.image.projector":
			m.ProjectorPaths = append(m.ProjectorPaths, filename)
		case "application/vnd.ollama.image.template", "application/vnd.ollama.image.prompt":
			// Both media types populate Template; a later layer overrides an earlier one.
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			m.Template = string(bts)
		case "application/vnd.ollama.image.system":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			m.System = string(bts)
		case "application/vnd.ollama.image.params":
			// parse model options parameters into a map so that we can see which fields have been specified explicitly
			if err := decodeJSONFromFile(filename, &m.Options); err != nil {
				return nil, err
			}
		case "application/vnd.ollama.image.messages":
			if err := decodeJSONFromFile(filename, &m.Messages); err != nil {
				return nil, err
			}
		case "application/vnd.ollama.image.license":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			m.License = append(m.License, string(bts))
		}
	}

	return m, nil
}

// decodeJSONFromFile decodes one JSON value from the file at path into v.
// The file is closed before returning, so callers that loop do not keep
// descriptors open until their own function returns.
func decodeJSONFromFile(path string, v any) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	return json.NewDecoder(f).Decode(v)
}
|
|
|
|
// realpath resolves a path named in a Modelfile. Resolution proceeds as follows:
//   - If from cannot be made absolute, it is returned unchanged.
//   - "~" and "~/..." expand against the user's home directory. If the home
//     directory is unknown, the absolute form of from is returned.
//   - Otherwise, if from exists relative to rel (the Modelfile's directory),
//     that joined path is returned.
//   - Failing that, the absolute form of from is returned; that form is
//     relative to the process working directory.
func realpath(rel, from string) string {
	abspath, err := filepath.Abs(from)
	if err != nil {
		return from
	}

	// Only "~" paths need the home directory. Looking it up unconditionally
	// meant a missing $HOME also disabled the Modelfile-relative lookup below.
	if from == "~" || strings.HasPrefix(from, "~/") {
		home, err := os.UserHomeDir()
		if err != nil {
			return abspath
		}

		if from == "~" {
			return home
		}

		return filepath.Join(home, from[2:])
	}

	candidate := filepath.Join(rel, from)
	if _, err := os.Stat(candidate); err == nil {
		// this is a file relative to the Modelfile
		return candidate
	}

	return abspath
}
|
|
|
|
// CreateModel builds the model name from the commands in modelfile. It
// writes any new layers plus a config blob, then writes the manifest.
//
// How commands are handled:
//   - "model"/"adapter" arguments may be a model name, an "@<digest>" blob
//     reference, or a path. Paths are resolved against modelFileDir via
//     realpath.
//   - When quantization is non-empty, F16/F32 GGUF weight layers are
//     quantized to that file type. Any other source file type is an error.
//   - "template" and "system" replace inherited layers of the same kind;
//     "license" accumulates.
//   - New messages replace inherited ones. New parameters override inherited
//     ones key by key.
//
// Progress is reported through fn. Unless pruning is disabled, layers of the
// previous manifest for name are passed to RemoveLayers when CreateModel returns.
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *model.File, fn func(resp api.ProgressResponse)) (err error) {
	config := ConfigV2{
		OS:           "linux",
		Architecture: "amd64",
		RootFS: RootFS{
			Type: "layers",
		},
	}

	var messages []*api.Message
	parameters := make(map[string]any)

	var layers []*Layer
	for _, c := range modelfile.Commands {
		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)

		switch c.Name {
		case "model", "adapter":
			var baseLayers []*layerWithGGML
			if name := model.ParseName(c.Args); name.IsValid() {
				baseLayers, err = parseFromModel(ctx, name, fn)
				if err != nil {
					return err
				}
			} else if strings.HasPrefix(c.Args, "@") {
				blobpath, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
				if err != nil {
					return err
				}

				blob, err := os.Open(blobpath)
				if err != nil {
					return err
				}
				defer blob.Close()

				baseLayers, err = parseFromFile(ctx, blob, fn)
				if err != nil {
					return err
				}
			} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
				defer file.Close()

				baseLayers, err = parseFromFile(ctx, file, fn)
				if err != nil {
					return err
				}
			} else {
				return fmt.Errorf("invalid model reference: %s", c.Args)
			}

			for _, baseLayer := range baseLayers {
				if quantization != "" &&
					baseLayer.MediaType == "application/vnd.ollama.image.model" &&
					baseLayer.GGML != nil &&
					baseLayer.GGML.Name() == "gguf" {
					want, err := llm.ParseFileType(quantization)
					if err != nil {
						return err
					}

					ft := baseLayer.GGML.KV().FileType()
					if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
						return errors.New("quantization is only supported for F16 and F32 models")
					} else if want != ft {
						fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})

						blob, err := GetBlobsPath(baseLayer.Digest)
						if err != nil {
							return err
						}

						temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
						if err != nil {
							return err
						}
						defer temp.Close()
						defer os.Remove(temp.Name())

						if err := llm.Quantize(blob, temp.Name(), want); err != nil {
							return err
						}

						baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
						if err != nil {
							return err
						}
					}
				}

				if baseLayer.GGML != nil {
					// NOTE(review): quantization above only replaces baseLayer.Layer.
					// baseLayer.GGML still describes the source weights, so FileType
					// (and ModelType) below reflect the pre-quantization file.
					// Confirm whether that is intended.
					config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
					config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
				}

				layers = append(layers, baseLayer.Layer)
			}
		case "license", "template", "system":
			blob := strings.NewReader(c.Args)
			layer, err := NewLayer(blob, mediatype)
			if err != nil {
				return err
			}

			if c.Name != "license" {
				// replace
				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
					return layer.MediaType == mediatype
				})
			}

			layers = append(layers, layer)
		case "message":
			role, content, ok := strings.Cut(c.Args, ": ")
			if !ok {
				return fmt.Errorf("invalid message: %s", c.Args)
			}

			messages = append(messages, &api.Message{Role: role, Content: content})
		default:
			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
			if err != nil {
				return err
			}

			for k, v := range ps {
				if ks, ok := parameters[k].([]string); ok {
					parameters[k] = append(ks, v.([]string)...)
				} else if vs, ok := v.([]string); ok {
					parameters[k] = vs
				} else {
					parameters[k] = v
				}
			}
		}
	}

	var err2 error
	layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
		switch layer.MediaType {
		case "application/vnd.ollama.image.messages":
			// If there are new messages, remove the inherited ones. The media
			// type must match the layer written below (".messages", plural).
			// The previous ".message" never matched, so stale inherited
			// messages layers were kept.
			return len(messages) > 0
		case "application/vnd.ollama.image.params":
			// merge inherited parameters with new ones
			r, err := layer.Open()
			if err != nil {
				err2 = err
				return false
			}
			defer r.Close()

			var ps map[string]any
			if err := json.NewDecoder(r).Decode(&ps); err != nil {
				err2 = err
				return false
			}

			for k, v := range ps {
				if _, ok := parameters[k]; !ok {
					parameters[k] = v
				}
			}

			return true
		default:
			return false
		}
	})

	if err2 != nil {
		return err2
	}

	if len(messages) > 0 {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(messages); err != nil {
			return err
		}

		layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
		if err != nil {
			return err
		}

		layers = append(layers, layer)
	}

	if len(parameters) > 0 {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(parameters); err != nil {
			return err
		}

		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
		if err != nil {
			return err
		}

		layers = append(layers, layer)
	}

	digests := make([]string, len(layers))
	for i, layer := range layers {
		digests[i] = layer.Digest
	}

	config.RootFS.DiffIDs = digests

	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(config); err != nil {
		return err
	}

	layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
	if err != nil {
		return err
	}

	for _, layer := range append(layers, layer) {
		if layer.status != "" {
			fn(api.ProgressResponse{Status: layer.status})
		}
	}

	if !envconfig.NoPrune {
		if old, err := ParseNamedManifest(name); err == nil {
			//nolint:errcheck
			defer old.RemoveLayers()
		}
	}

	fn(api.ProgressResponse{Status: "writing manifest"})
	if err := WriteManifest(name, layer, layers); err != nil {
		return err
	}

	fn(api.ProgressResponse{Status: "success"})
	return nil
}
|
|
|
|
// CopyModel duplicates the manifest file of src under dst. Only the
// manifest is copied, so both names reference the same blobs. Both names
// must be fully qualified. Copying a model onto itself is a no-op.
func CopyModel(src, dst model.Name) (err error) {
	if !dst.IsFullyQualified() {
		return model.Unqualified(dst)
	}
	if !src.IsFullyQualified() {
		return model.Unqualified(src)
	}

	if src.Filepath() == dst.Filepath() {
		return nil
	}

	manifests, err := GetManifestPath()
	if err != nil {
		return err
	}

	dstpath := filepath.Join(manifests, dst.Filepath())
	if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
		return err
	}

	srcpath := filepath.Join(manifests, src.Filepath())
	srcfile, err := os.Open(srcpath)
	if err != nil {
		return err
	}
	defer srcfile.Close()

	dstfile, err := os.Create(dstpath)
	if err != nil {
		return err
	}
	defer func() {
		// Write errors on a file may only surface at Close. Report it when the
		// copy itself succeeded rather than silently dropping it.
		if cerr := dstfile.Close(); err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(dstfile, srcfile)
	return err
}
|
|
|
|
// PushModel uploads the locally stored model name to its registry. Every
// layer and the config blob are uploaded first, then the manifest is PUT.
// name must be fully qualified. Progress is reported through fn.
func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
	if !name.IsFullyQualified() {
		return model.Unqualified(name)
	}

	fn(api.ProgressResponse{Status: "retrieving manifest"})

	m, err := ParseNamedManifest(name)
	if err != nil {
		return err
	}

	scheme := "https"
	if opts.Insecure {
		scheme = "http"
	}

	// baseURL addresses the repository: <scheme>://<host>/v2/<namespace>/<model>.
	baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
	if err != nil {
		return err
	}

	for _, layer := range append(m.Layers, m.Config) {
		if err := uploadBlob(ctx, uploadOptions{name: name, baseURL: baseURL, layer: layer, regOpts: &opts, fn: fn}); err != nil {
			slog.Info(fmt.Sprintf("error uploading blob: %v", err))
			return err
		}
	}

	fn(api.ProgressResponse{Status: "pushing manifest"})

	// baseURL already carries the /v2/<namespace>/<model> path, so only the
	// manifests suffix is appended, matching pullModelManifest. Joining
	// "v2", namespace and model again would duplicate the repository path.
	requestURL := baseURL.JoinPath("manifests", name.Tag)

	manifestJSON, err := json.Marshal(m)
	if err != nil {
		return err
	}

	headers := make(http.Header)
	headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
	resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), &opts)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	fn(api.ProgressResponse{Status: "success"})

	return nil
}
|
|
|
|
// PullModel downloads the model name from its registry. It fetches the
// manifest, downloads and verifies every layer plus the config blob, and
// writes the manifest locally. Unless pruning is disabled, the layers of
// any previous manifest for name are then passed to RemoveLayers.
//
// Manifest fetch errors are wrapped with %w, so callers can still match
// os.ErrNotExist (returned by makeRequestWithRetry on a 404).
func PullModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
	if !name.IsFullyQualified() {
		return model.Unqualified(name)
	}

	// Remember the current manifest (if any) so its layers can be handed to
	// RemoveLayers after the new manifest is written.
	old, _ := ParseNamedManifest(name)

	scheme := "https"
	if opts.Insecure {
		scheme = "http"
	}

	baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
	if err != nil {
		return err
	}

	fn(api.ProgressResponse{Status: "pulling manifest"})
	m, err := pullModelManifest(ctx, name, baseURL, &opts)
	if err != nil {
		return fmt.Errorf("pull model manifest: %w", err)
	}

	layers := append(m.Layers, m.Config)
	for _, layer := range layers {
		if err := downloadBlob(
			ctx,
			downloadOptions{
				name:    name,
				baseURL: baseURL,
				digest:  layer.Digest,
				regOpts: &opts,
				fn:      fn,
			}); err != nil {
			return err
		}
	}

	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
	for _, layer := range layers {
		if err := layer.Verify(); err != nil {
			// best effort: drop the corrupt blob so a retry re-downloads it
			_ = layer.Remove()
			return err
		}
	}

	fn(api.ProgressResponse{Status: "writing manifest"})
	if err := WriteManifest(name, m.Config, m.Layers); err != nil {
		return err
	}

	if !envconfig.NoPrune && old != nil {
		fn(api.ProgressResponse{Status: "removing any unused layers"})
		// best effort: the pull itself already succeeded
		_ = old.RemoveLayers()
	}

	fn(api.ProgressResponse{Status: "success"})
	return nil
}
|
|
|
|
// pullModelManifest fetches and decodes the manifest for name.Tag from
// baseURL. baseURL is expected to already carry the /v2/<namespace>/<model>
// path, as built by PullModel.
//
// A body that decodes without a config reference (including a JSON null) is
// rejected, so callers never receive a nil manifest or a nil Config.
func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*ManifestV2, error) {
	requestURL := baseURL.JoinPath("manifests", name.Tag)

	headers := make(http.Header)
	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var m ManifestV2
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		return nil, err
	}

	// PullModel dereferences m.Config right away; fail here instead of
	// panicking there.
	if m.Config == nil {
		return nil, errors.New("invalid manifest: missing config")
	}

	return &m, nil
}
|
|
|
|
// errUnauthorized is returned when the registry keeps rejecting a request
// with 401 even after a token has been obtained.
var errUnauthorized = errors.New("unauthorized: access denied")
|
|
|
|
// getTokenSubject extracts the "sub" claim from a JWT. It only base64url-
// decodes the payload segment; the signature is NOT validated. Malformed
// input is logged and yields "".
func getTokenSubject(token string) string {
	fail := func(msg string) string {
		slog.Error(msg)
		return ""
	}

	segments := strings.Split(token, ".")
	if len(segments) != 3 {
		return fail("jwt token does not contain 3 parts")
	}

	claimsJSON, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		return fail(fmt.Sprintf("failed to decode jwt payload: %v", err))
	}

	var claims map[string]interface{}
	if err = json.Unmarshal(claimsJSON, &claims); err != nil {
		return fail(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
	}

	subject, found := claims["sub"]
	if !found {
		return fail("jwt does not contain 'sub' field")
	}

	return fmt.Sprintf("%s", subject)
}
|
|
|
|
// makeRequestWithRetry sends a registry request. On a 401 it fetches a token
// for the returned challenge, stores it in regOpts.Token (so the retry, and
// any later call sharing regOpts, sends it), rewinds body, and tries once more.
//
// Outcomes by status:
//   - 404: os.ErrNotExist.
//   - Any other status >= 400: an error carrying the status code and
//     response body.
//   - Success: the caller owns resp.Body and must close it. Every non-success
//     response body is closed here so its connection is released.
//   - Two 401s: *errtypes.UnknownOllamaKey when the token subject was
//     "anonymous" (errUnauthorized if the public key cannot be read);
//     otherwise errUnauthorized.
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
	anonymous := true // access will default to anonymous if no user is found associated with the public key
	for i := 0; i < 2; i++ {
		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
		if err != nil {
			if !errors.Is(err, context.Canceled) {
				slog.Info(fmt.Sprintf("request failed: %v", err))
			}

			return nil, err
		}

		switch {
		case resp.StatusCode == http.StatusUnauthorized:
			// This response is discarded; release its connection before retrying.
			resp.Body.Close()

			// Handle authentication error with one retry
			challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
			token, err := getAuthorizationToken(ctx, challenge)
			if err != nil {
				return nil, err
			}
			anonymous = getTokenSubject(token) == "anonymous"
			regOpts.Token = token
			if body != nil {
				_, err = body.Seek(0, io.SeekStart)
				if err != nil {
					return nil, err
				}
			}
		case resp.StatusCode == http.StatusNotFound:
			resp.Body.Close()
			return nil, os.ErrNotExist
		case resp.StatusCode >= http.StatusBadRequest:
			responseBody, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
			}
			return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
		default:
			return resp, nil
		}
	}

	// Both attempts were rejected with 401.
	if anonymous {
		// no user is associated with the public key, and the request requires non-anonymous access
		pubKey, nestedErr := auth.GetPublicKey()
		if nestedErr != nil {
			slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
			return nil, errUnauthorized
		}
		return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
	}
	// user is associated with the public key, but is not authorized to make the request
	return nil, errUnauthorized
}
|
|
|
|
// makeRequest performs a single HTTP request against a registry; status
// codes are not inspected here.
//
// What gets sent:
//   - The headers are copied onto the request, not shared, so the caller's
//     map is left untouched.
//   - An Authorization header is added from regOpts; a bearer Token is
//     preferred over basic Username/Password.
//   - A User-Agent identifying this build is set.
//   - A Content-Length header, if present, is parsed into req.ContentLength.
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
	// NOTE: this rewrites the caller's URL in place.
	if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
		requestURL.Scheme = "http"
	}

	req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
	if err != nil {
		return nil, err
	}

	if headers != nil {
		// Clone so the Authorization and User-Agent headers set below do not
		// leak into the caller's map, which makeRequestWithRetry reuses
		// across attempts.
		req.Header = headers.Clone()
	}

	if regOpts != nil {
		if regOpts.Token != "" {
			req.Header.Set("Authorization", "Bearer "+regOpts.Token)
		} else if regOpts.Username != "" && regOpts.Password != "" {
			req.SetBasicAuth(regOpts.Username, regOpts.Password)
		}
	}

	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))

	if s := req.Header.Get("Content-Length"); s != "" {
		contentLength, err := strconv.ParseInt(s, 10, 64)
		if err != nil {
			return nil, err
		}

		req.ContentLength = contentLength
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}

	return resp, nil
}
|
|
|
|
// getValue returns the value of the first `key=` parameter in a
// WWW-Authenticate parameter list, or "" if key is absent or has no value.
//
// Quoted values run to the first '"' that is followed by ',' or by the end of
// the header. An embedded quote not followed by ',' is kept in the value.
// Unquoted token values run to the next ',' and are trimmed of surrounding
// whitespace.
func getValue(header, key string) string {
	startIdx := strings.Index(header, key+"=")
	if startIdx == -1 {
		return ""
	}

	// Index of the first character after '='. The header comes from a remote
	// server, so don't assume anything follows it.
	startIdx += len(key) + 1
	if startIdx >= len(header) {
		return ""
	}

	if header[startIdx] != '"' {
		// unquoted token value
		rest := header[startIdx:]
		if end := strings.IndexByte(rest, ','); end != -1 {
			rest = rest[:end]
		}
		return strings.TrimSpace(rest)
	}

	startIdx++ // skip the opening quote
	endIdx := startIdx

	for endIdx < len(header) {
		if header[endIdx] == '"' {
			if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
				endIdx++
				continue
			}
			break
		}
		endIdx++
	}

	return header[startIdx:endIdx]
}
|
|
|
|
func parseRegistryChallenge(authStr string) registryChallenge {
|
|
authStr = strings.TrimPrefix(authStr, "Bearer ")
|
|
|
|
return registryChallenge{
|
|
Realm: getValue(authStr, "realm"),
|
|
Service: getValue(authStr, "service"),
|
|
Scope: getValue(authStr, "scope"),
|
|
}
|
|
}
|