Compare commits

...

3 Commits

Author SHA1 Message Date
Michael Yang
b7860f12ad reference license, template, system as files 2024-07-01 09:23:08 -07:00
Michael Yang
e401a23d62 cmd build context 2024-07-01 09:01:28 -07:00
Michael Yang
d46f6ea512 err on insecure path 2024-07-01 08:36:37 -07:00
4 changed files with 136 additions and 79 deletions

View File

@ -3,6 +3,7 @@ package cmd
import (
"archive/zip"
"bytes"
"cmp"
"context"
"crypto/ed25519"
"crypto/rand"
@ -11,6 +12,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log"
"math"
"net"
@ -70,49 +72,67 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
home, err := os.UserHomeDir()
if err != nil {
return err
}
status := "transferring model data"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
createCtx, err := cmd.Flags().GetString("context")
if err != nil {
return err
}
createCtx = cmp.Or(createCtx, filepath.Dir(filename))
fsys := os.DirFS(createCtx)
for i := range modelfile.Commands {
switch modelfile.Commands[i].Name {
case "model", "adapter":
path := modelfile.Commands[i].Args
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(home, path[2:])
}
if slices.Contains([]string{"model", "adapter", "license", "template", "system"}, modelfile.Commands[i].Name) {
p := filepath.Clean(modelfile.Commands[i].Args)
if !filepath.IsAbs(path) {
path = filepath.Join(filepath.Dir(filename), path)
}
fi, err := os.Stat(path)
fi, err := fs.Stat(fsys, p)
if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
continue
} else if err != nil {
return err
}
if fi.IsDir() {
// this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
if err != nil {
return err
}
defer os.RemoveAll(tempfile)
switch modelfile.Commands[i].Name {
case "model", "adapter":
if fi.IsDir() {
// this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
sub, err := fs.Sub(fsys, p)
if err != nil {
return err
}
path = tempfile
temp, err := os.CreateTemp(createCtx, "*.zip")
if err != nil {
return err
}
defer temp.Close()
defer os.RemoveAll(temp.Name())
if err := zipFiles(sub, temp); err != nil {
return err
}
p, err = filepath.Rel(createCtx, temp.Name())
if err != nil {
return err
}
}
case "license", "template", "system":
t, err := detectContentType(fsys, p)
if errors.Is(err, os.ErrNotExist) {
continue
} else if err != nil {
return err
} else if t != "text/plain" {
return fmt.Errorf("invalid content type for %s: %s", modelfile.Commands[i].Name, p)
}
}
digest, err := createBlob(cmd, client, path)
digest, err := createBlob(cmd, client, fsys, p)
if err != nil {
return err
}
@ -155,42 +175,34 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
func detectContentType(fsys fs.FS, name string) (string, error) {
f, err := fsys.Open(name)
if err != nil {
return "", err
}
defer tempfile.Close()
defer f.Close()
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
bts, err := io.ReadAll(io.LimitReader(f, 512))
if err != nil {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
return contentType, nil
}
func zipFiles(fsys fs.FS, w io.Writer) error {
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
matches, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, err
}
for _, safetensor := range matches {
if ct, err := detectContentType(safetensor); err != nil {
for _, match := range matches {
if ct, err := detectContentType(fsys, match); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}
@ -198,73 +210,73 @@ func tempZipFiles(path string) (string, error) {
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
if st, _ := glob("model*.safetensors", "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
} else if pt, _ := glob("pytorch_model*.bin", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
} else if pt, _ := glob("consolidated*.pth", "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else {
return "", errors.New("no safetensors or torch files found")
return errors.New("no safetensors or torch files found")
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
js, err := glob("*.json", "text/plain")
if err != nil {
return "", err
return err
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
if tks, _ := glob("tokenizer.model", "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
} else if tks, _ := glob("**/tokenizer.model", "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
zipfile := zip.NewWriter(tempfile)
zipfile := zip.NewWriter(w)
defer zipfile.Close()
for _, file := range files {
f, err := os.Open(file)
f, err := fsys.Open(file)
if err != nil {
return "", err
return err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return "", err
return err
}
zfi, err := zip.FileInfoHeader(fi)
if err != nil {
return "", err
return err
}
zf, err := zipfile.CreateHeader(zfi)
if err != nil {
return "", err
return err
}
if _, err := io.Copy(zf, f); err != nil {
return "", err
return err
}
}
return tempfile.Name(), nil
return nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path)
func sha256sum(fsys fs.FS, name string) (string, error) {
bin, err := fsys.Open(name)
if err != nil {
return "", err
}
@ -275,14 +287,25 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err
}
if _, err := bin.Seek(0, io.SeekStart); err != nil {
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, fsys fs.FS, name string) (string, error) {
bin, err := fsys.Open(name)
if err != nil {
return "", err
}
defer bin.Close()
digest, err := sha256sum(fsys, name)
if err != nil {
return "", err
}
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
return "", err
}
return digest, nil
}
@ -1226,6 +1249,7 @@ func NewCLI() *cobra.Command {
createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
createCmd.Flags().StringP("context", "C", "", "Context for the model")
showCmd := &cobra.Command{
Use: "show MODEL",

View File

@ -459,7 +459,22 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
})
}
blob := strings.NewReader(c.Args)
var blob io.Reader = strings.NewReader(c.Args)
if strings.HasPrefix(c.Args, "@") {
p, err := GetBlobsPath(strings.TrimPrefix(c.Args, "@"))
if err != nil {
return err
}
b, err := os.Open(p)
if err != nil {
return err
}
defer b.Close()
blob = b
}
layer, err := NewLayer(blob, mediatype)
if err != nil {
return err

View File

@ -11,7 +11,6 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
@ -91,12 +90,11 @@ func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse))
fn(api.ProgressResponse{Status: "unpacking model metadata"})
for _, f := range r.File {
n := filepath.Join(p, f.Name)
if !strings.HasPrefix(n, p) {
slog.Warn("skipped extracting file outside of context", "name", f.Name)
continue
if !filepath.IsLocal(f.Name) {
return fmt.Errorf("%w: %s", zip.ErrInsecurePath, f.Name)
}
n := filepath.Join(p, f.Name)
if err := os.MkdirAll(filepath.Dir(n), 0o750); err != nil {
return err
}

View File

@ -3,10 +3,12 @@ package server
import (
"archive/zip"
"bytes"
"errors"
"io"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/ollama/ollama/api"
@ -39,13 +41,31 @@ func TestExtractFromZipFile(t *testing.T) {
cases := []struct {
name string
expect []string
err error
}{
{
name: "good",
expect: []string{"good"},
},
{
name: filepath.Join("..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"),
name: strings.Join([]string{"path", "..", "to", "good"}, string(os.PathSeparator)),
expect: []string{filepath.Join("to", "good")},
},
{
name: strings.Join([]string{"path", "..", "to", "..", "good"}, string(os.PathSeparator)),
expect: []string{"good"},
},
{
name: strings.Join([]string{"path", "to", "..", "..", "good"}, string(os.PathSeparator)),
expect: []string{"good"},
},
{
name: strings.Join([]string{"..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "..", "bad"}, string(os.PathSeparator)),
err: zip.ErrInsecurePath,
},
{
name: strings.Join([]string{"path", "..", "..", "to", "bad"}, string(os.PathSeparator)),
err: zip.ErrInsecurePath,
},
}
@ -55,7 +75,7 @@ func TestExtractFromZipFile(t *testing.T) {
defer f.Close()
tempDir := t.TempDir()
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); err != nil {
if err := extractFromZipFile(tempDir, f, func(api.ProgressResponse) {}); !errors.Is(err, tt.err) {
t.Fatal(err)
}