Compare commits
135 Commits
main
...
bmizerany/
Author | SHA1 | Date | |
---|---|---|---|
|
a2cf25caf7 | ||
|
aca112a308 | ||
|
1389e6926a | ||
|
a292cde2f3 | ||
|
ab9e476551 | ||
|
d721228a2b | ||
|
d3c6400487 | ||
|
06c21f00eb | ||
|
8b62eaf059 | ||
|
6c1c0f9f1a | ||
|
6a4b3c3823 | ||
|
38e7ddb39d | ||
|
f595dea189 | ||
|
be7fe0d6d8 | ||
|
98dbc1202a | ||
|
bdff89bc4c | ||
|
2100129e83 | ||
|
4eb7acf84b | ||
|
ff68227ca1 | ||
|
d2aef85dda | ||
|
1407fd3d4a | ||
|
07f27312fa | ||
|
6ba495d4a3 | ||
|
bd446a72cc | ||
|
5615f60bb0 | ||
|
348378ef56 | ||
|
81e8c49ac2 | ||
|
0172726a58 | ||
|
c84f9b07b0 | ||
|
712eaa4b09 | ||
|
2f241692bd | ||
|
d35a6a577f | ||
|
14a6f85e9e | ||
|
45d8d22785 | ||
|
e201627c63 | ||
|
5e76860c47 | ||
|
fb0782b7a9 | ||
|
0bee38f6b5 | ||
|
b24f1ad587 | ||
|
0bea2b8916 | ||
|
e4d65d5aef | ||
|
6e464ebef8 | ||
|
f2c17682b0 | ||
|
f0e6c563e2 | ||
|
a187851900 | ||
|
2e600aa398 | ||
|
0c78e6c23d | ||
|
c0d4f55f3e | ||
|
d67ff60643 | ||
|
3cb07b3ac9 | ||
|
fc595d89bf | ||
|
e1de015fbc | ||
|
1c04951cac | ||
|
95559adee3 | ||
|
9821ca28e8 | ||
|
c5768ceffe | ||
|
0b65220936 | ||
|
a5f05236f7 | ||
|
3cf109ec59 | ||
|
7c7f56a7fb | ||
|
bf8e0c09c9 | ||
|
a6b8bdf938 | ||
|
f51197a814 | ||
|
713e2feacf | ||
|
7f1faf6c12 | ||
|
ad6f020bd8 | ||
|
e052bd8c0f | ||
|
aef832b298 | ||
|
805a92e6f2 | ||
|
0c46151700 | ||
|
bfe89d6fa0 | ||
|
92b7e40fde | ||
|
d510a90214 | ||
|
a4fd06d603 | ||
|
cfe0bb6bb6 | ||
|
6aa9795c4f | ||
|
5041000a28 | ||
|
7cd939690a | ||
|
42cda9dd46 | ||
|
6917865bf3 | ||
|
2633fcb149 | ||
|
58de2b8d4a | ||
|
de72688b35 | ||
|
cbb367b1df | ||
|
18160475c4 | ||
|
31e9b3dd15 | ||
|
e28cfdf813 | ||
|
2751c26da7 | ||
|
45ca3c80e8 | ||
|
d85fbd0e99 | ||
|
acf1cb1dc4 | ||
|
c787b8b2dd | ||
|
9f2d8d2117 | ||
|
d42c3f6be1 | ||
|
4ea3e9efa6 | ||
|
2e1ea6ecaa | ||
|
6d2da77ce2 | ||
|
def4d902bf | ||
|
76a202c04e | ||
|
f7cfe946dc | ||
|
005b6373e2 | ||
|
d54e0fb3b2 | ||
|
bdd05e0ae0 | ||
|
1a346640db | ||
|
f5883070f8 | ||
|
adc23d5f96 | ||
|
a10a11b9d3 | ||
|
94befe366a | ||
|
c95f97689b | ||
|
618eb5b909 | ||
|
eb75418be9 | ||
|
9959da05de | ||
|
aff7970628 | ||
|
628f1feb36 | ||
|
ce3125afd5 | ||
|
f488652ba7 | ||
|
2318ed2919 | ||
|
b1b8be33d9 | ||
|
876f7eab81 | ||
|
7cfc8a0838 | ||
|
fd411b3cf6 | ||
|
04f38cf3f4 | ||
|
c0eddb10fd | ||
|
60ef0e6b4a | ||
|
48c60c01e2 | ||
|
eb2c442a01 | ||
|
c87fe7df48 | ||
|
5182a1dfb1 | ||
|
a32e7857b2 | ||
|
6acc205de0 | ||
|
f6e02d4bc7 | ||
|
e1d457c73e | ||
|
cd5df121a5 | ||
|
112ffed189 | ||
|
c49947dcf5 |
28
go.mod
28
go.mod
@ -10,7 +10,7 @@ require (
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang/protobuf v1.5.0 // indirect
|
||||
github.com/google/uuid v1.0.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.7.0
|
||||
@ -19,23 +19,35 @@ require (
|
||||
golang.org/x/sync v0.3.0
|
||||
)
|
||||
|
||||
require github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
|
||||
require (
|
||||
github.com/minio/minio-go/v7 v7.0.69
|
||||
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
|
||||
kr.dev/diff v0.3.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.0.8 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v1.12.0 // indirect
|
||||
github.com/klauspost/compress v1.17.6 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.14 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.8.1 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gonum.org/v1/gonum v0.8.2 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
@ -53,7 +65,7 @@ require (
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
@ -63,12 +75,12 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/crypto v0.19.0
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0
|
||||
golang.org/x/term v0.13.0
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sys v0.17.0
|
||||
golang.org/x/term v0.17.0
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
51
go.sum
51
go.sum
@ -26,6 +26,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@ -86,8 +88,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
|
||||
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
@ -95,9 +97,12 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
|
||||
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
@ -115,6 +120,12 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0=
|
||||
github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ=
|
||||
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
||||
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@ -129,6 +140,7 @@ github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9/go.mod h1:nR7l3gM6u
|
||||
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@ -138,8 +150,11 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
|
||||
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
@ -181,8 +196,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@ -205,8 +220,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -226,18 +241,18 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@ -292,6 +307,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@ -303,4 +320,6 @@ gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
|
||||
gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw=
|
||||
kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
113
x/api/api.go
Normal file
113
x/api/api.go
Normal file
@ -0,0 +1,113 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/build"
|
||||
"github.com/ollama/ollama/x/client/ollama/apitype"
|
||||
"github.com/ollama/ollama/x/oweb"
|
||||
"github.com/ollama/ollama/x/registry"
|
||||
regtype "github.com/ollama/ollama/x/registry/apitype"
|
||||
)
|
||||
|
||||
// Common API Errors
|
||||
var (
|
||||
errUnqualifiedRef = oweb.Invalid("invalid", "name", "must be fully qualified")
|
||||
errRefNotFound = oweb.Invalid("not_found", "name", "no such model")
|
||||
)
|
||||
|
||||
// Server is the HTTP API server for this package. It serves requests through
// ServeHTTP.
type Server struct {
|
||||
// Build is the build server consulted by the handlers for manifest data
// and layer files.
Build *build.Server
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
oweb.Serve(s.serveHTTP, w, r)
|
||||
}
|
||||
|
||||
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.URL.Path {
|
||||
case "/v1/push":
|
||||
return s.handlePush(w, r)
|
||||
default:
|
||||
return oweb.ErrNotFound
|
||||
}
|
||||
}
|
||||
|
||||
// want reports whether r was made with the given HTTP method and targets
// exactly the given URL path.
func want(r *http.Request, method, path string) bool {
	if r.Method != method {
		return false
	}
	return r.URL.Path == path
}
|
||||
|
||||
func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "POST" {
|
||||
return oweb.ErrMethodNotAllowed
|
||||
}
|
||||
|
||||
params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if params.Name == "" {
|
||||
return oweb.Missing("name")
|
||||
}
|
||||
|
||||
const registryURLTODO = "http://localhost:8888"
|
||||
|
||||
man, err := s.Build.ManifestData(params.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, build.ErrNotFound) {
|
||||
return errRefNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c := registry.Client{BaseURL: registryURLTODO}
|
||||
requirements, err := c.Push(r.Context(), params.Name, man, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var uploads []regtype.CompletePart
|
||||
for _, rq := range requirements {
|
||||
l, err := s.Build.LayerFile(rq.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = func() error {
|
||||
f, err := os.Open(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
cp, err := registry.PushLayer(r.Context(), f, rq.URL, rq.Offset, rq.Size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uploads = append(uploads, cp)
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// commit the manifest to the registry
|
||||
requirements, err = c.Push(r.Context(), params.Name, man, ®istry.PushParams{
|
||||
CompleteParts: uploads,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, r := range requirements {
|
||||
err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest))
|
||||
}
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
return oweb.ErrNotFound
|
||||
}
|
209
x/build/build.go
Normal file
209
x/build/build.go
Normal file
@ -0,0 +1,209 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/build/internal/blobstore"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
)
|
||||
|
||||
// Sentinel errors returned by this package. Compare with errors.Is.

// ErrIncompleteRef is returned by Build when the ref is not fully qualified.
var ErrIncompleteRef = errors.New("unqualified ref")

// ErrBuildPresentInRef indicates a ref that carries a build part.
var ErrBuildPresentInRef = errors.New("build present in ref")

// ErrUnsupportedModelFormat indicates a FROM source in a model format that
// is not supported.
var ErrUnsupportedModelFormat = errors.New("unsupported model format")

// ErrMissingFileType indicates model metadata lacking the
// 'general.file_type' key.
var ErrMissingFileType = errors.New("missing 'general.file_type' key")

// ErrNotFound indicates that a requested manifest or layer does not exist.
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// mediaType names the kind of content a manifest layer holds; it is the
// type of the MediaType field of layerJSON.
type mediaType string
|
||||
|
||||
// Known media types
|
||||
const (
|
||||
// mediaTypeModel marks the layer holding the model weights; WeightsFile
// looks for the first layer with this media type.
mediaTypeModel mediaType = "application/vnd.ollama.image.model"
|
||||
)
|
||||
|
||||
// Server builds models and serves their manifests and layer files out of a
// blob store. Create one with Open.
type Server struct {
|
||||
// st is the blob store opened by Open; manifests are kept under its
// directory and layer paths are resolved through it.
st *blobstore.Store
|
||||
}
|
||||
|
||||
// Open starts a new build server that uses dir as the base directory for all
|
||||
// build artifacts. If dir is empty, DefaultDir is used.
|
||||
//
|
||||
// It returns an error if the provided or default dir cannot be initialized.
|
||||
func Open(dir string) (*Server, error) {
|
||||
if dir == "" {
|
||||
var err error
|
||||
dir, err = DefaultDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
st, err := blobstore.Open(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Server{st: st}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Build(ref string, f model.File) error {
|
||||
mp := model.ParseName(ref)
|
||||
if !mp.IsCompleteNoBuild() {
|
||||
return fmt.Errorf("%w: %q", ErrIncompleteRef, ref)
|
||||
}
|
||||
|
||||
// 1. Resolve FROM
|
||||
// a. If it's a local file (gguf), hash it and add it to the store.
|
||||
// c. If it's a remote file (http), refuse.
|
||||
// 2. Turn other pragmas into layers, and add them to the store.
|
||||
// 3. Create a manifest from the layers.
|
||||
// 4. Store the manifest in the manifest cache
|
||||
// 5. Done.
|
||||
|
||||
if f.From == "" {
|
||||
return &model.FileError{Pragma: "FROM", Message: "missing"}
|
||||
}
|
||||
|
||||
var layers []layerJSON
|
||||
|
||||
id, info, size, err := s.importModel(f.From)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, layerJSON{
|
||||
ID: id,
|
||||
MediaType: mediaTypeModel,
|
||||
Size: size,
|
||||
})
|
||||
|
||||
id, size, err = blobstore.PutString(s.st, f.License)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, layerJSON{
|
||||
ID: id,
|
||||
MediaType: "text/plain",
|
||||
Size: size,
|
||||
})
|
||||
|
||||
data, err := json.Marshal(manifestJSON{Layers: layers})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.setManifestData(
|
||||
mp.WithBuild(info.FileType.String()),
|
||||
data,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) LayerFile(digest string) (string, error) {
|
||||
fileName := s.st.OutputFilename(blobstore.ParseID(digest))
|
||||
_, err := os.Stat(fileName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return "", fmt.Errorf("%w: %q", ErrNotFound, digest)
|
||||
}
|
||||
return fileName, nil
|
||||
}
|
||||
|
||||
func (s *Server) ManifestData(ref string) ([]byte, error) {
|
||||
data, _, err := s.resolve(model.ParseName(ref))
|
||||
return data, err
|
||||
}
|
||||
|
||||
// WeightFile returns the absolute path to the weights file for the given model ref.
|
||||
func (s *Server) WeightsFile(ref string) (string, error) {
|
||||
m, err := s.getManifest(model.ParseName(ref))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if l.MediaType == mediaTypeModel {
|
||||
return s.st.OutputFilename(l.ID), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("missing weights layer for %q", ref)
|
||||
}
|
||||
|
||||
// resolve returns the data for the given ref, if any.
|
||||
//
|
||||
// TODO: This should ideally return an ID, but the current on
|
||||
// disk layout is that the actual manifest is stored in the "ref" instead of
|
||||
// a pointer to a content-addressed blob. I (bmizerany) think we should
|
||||
// change the on-disk layout to store the manifest in a content-addressed
|
||||
// blob, and then have the ref point to that blob. This would simplify the
|
||||
// code, allow us to have integrity checks on the manifest, and clean up
|
||||
// this interface.
|
||||
func (s *Server) resolve(ref model.Name) (data []byte, fileName string, err error) {
|
||||
fileName, err = s.refFileName(ref)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
data, err = os.ReadFile(fileName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref)
|
||||
}
|
||||
if err != nil {
|
||||
// do not wrap the error here, as it is likely an I/O error
|
||||
// and we want to preserve the absraction since we may not
|
||||
// be on disk later.
|
||||
return nil, "", fmt.Errorf("manifest read error: %v", err)
|
||||
}
|
||||
return data, fileName, nil
|
||||
}
|
||||
|
||||
func (s *Server) SetManifestData(ref string, data []byte) error {
|
||||
return s.setManifestData(model.ParseName(ref), data)
|
||||
}
|
||||
|
||||
// Set sets the data for the given ref.
|
||||
func (s *Server) setManifestData(mp model.Name, data []byte) error {
|
||||
path, err := s.refFileName(mp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0666); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) refFileName(mp model.Name) (string, error) {
|
||||
if !mp.IsComplete() {
|
||||
return "", fmt.Errorf("ref not fully qualified: %q", mp)
|
||||
}
|
||||
return filepath.Join(s.st.Dir(), "manifests", filepath.Join(mp.Parts()...)), nil
|
||||
}
|
||||
|
||||
// manifestJSON is the JSON form of a model manifest as written by Build and
// read back by getManifest.
type manifestJSON struct {
|
||||
// Layers is the list of layers in the manifest.
|
||||
Layers []layerJSON `json:"layers"`
|
||||
}
|
||||
|
||||
// layerJSON is a layer in a model manifest.
|
||||
type layerJSON struct {
|
||||
// ID is the ID of the layer. It is encoded under the "digest" JSON key.
|
||||
ID blobstore.ID `json:"digest"`
|
||||
// MediaType names the kind of content the layer holds.
MediaType mediaType `json:"mediaType"`
|
||||
// Size is the layer size as reported when the layer was stored.
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
func (s *Server) getManifest(ref model.Name) (manifestJSON, error) {
|
||||
data, path, err := s.resolve(ref)
|
||||
if err != nil {
|
||||
return manifestJSON{}, err
|
||||
}
|
||||
var m manifestJSON
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err}
|
||||
}
|
||||
return m, nil
|
||||
}
|
163
x/build/build_test.go
Normal file
163
x/build/build_test.go
Normal file
@ -0,0 +1,163 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
)
|
||||
|
||||
// qualifiedRef is the model ref shared by the tests in this file.
const qualifiedRef = "x/y/z:latest+Q4_0"
|
||||
|
||||
func TestServerBuildErrors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
s, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("unqualified ref", func(t *testing.T) {
|
||||
err := s.Build("x", model.File{})
|
||||
if !errors.Is(err, ErrIncompleteRef) {
|
||||
t.Fatalf("Build() err = %v; want unqualified ref", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM pragma missing", func(t *testing.T) {
|
||||
err := s.Build(qualifiedRef, model.File{})
|
||||
var e *model.FileError
|
||||
if !errors.As(err, &e) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if e.Pragma != "FROM" {
|
||||
t.Errorf("e.Pragma = %s; want FROM", e.Pragma)
|
||||
}
|
||||
if e.Message != "missing" {
|
||||
t.Errorf("e.Message = %s; want missing", e.Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM file not found", func(t *testing.T) {
|
||||
err := s.Build(qualifiedRef, model.File{From: "bar"})
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("Build() err = %v; want file not found", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM gguf", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
// Write a gguf file without general.file_type metadata.
|
||||
w.write("gguf", ""+
|
||||
"GGUF"+ // magic
|
||||
"\x03\x00\x00\x00"+ // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
|
||||
"",
|
||||
)
|
||||
|
||||
err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")})
|
||||
if !errors.Is(err, ErrMissingFileType) {
|
||||
t.Fatalf("Build() err = %#v; want missing file type", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM obscure dir", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
w.mkdirAll("unknown")
|
||||
if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
|
||||
t.Fatalf("Build() err = %#v; want unsupported model type", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM unsupported model type", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
from := w.write("unknown", "unknown content")
|
||||
err := s.Build(qualifiedRef, model.File{From: from})
|
||||
if !errors.Is(err, ErrUnsupportedModelFormat) {
|
||||
t.Fatalf("Build() err = %#v; want unsupported model type", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildBasicGGUF(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
w.write("gguf", ""+
|
||||
"GGUF"+ // magic
|
||||
"\x03\x00\x00\x00"+ // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
|
||||
|
||||
// general.file_type key
|
||||
"\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length
|
||||
"general.file_type"+ // key
|
||||
"\x04\x00\x00\x00"+ // type (uint32)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value
|
||||
"",
|
||||
)
|
||||
|
||||
dir := t.TempDir()
|
||||
s, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
|
||||
t.Logf("file: %s", p)
|
||||
return nil
|
||||
})
|
||||
|
||||
_, err = s.WeightsFile("unknown/y/z:latest+Q4_0")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("WeightsFile() err = %v; want not found", err)
|
||||
}
|
||||
|
||||
path, err := s.WeightsFile("x/y/z:latest+Q4_0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
info, err := gguf.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if info.FileType != gguf.TypeQ4_0 {
|
||||
t.Errorf("info.FileType = %d; want 1", info.FileType)
|
||||
}
|
||||
}
|
||||
|
||||
// work is a scratch directory owned by a test. Its helpers fail the owning
// test immediately when a file system operation fails.
type work struct {
	t   testing.TB
	dir string
}

// newWorkDir returns a work rooted at a fresh temporary directory of t.
func newWorkDir(t *testing.T) work {
	var w work
	w.t = t
	w.dir = t.TempDir()
	return w
}

// write stores content in the file called name inside the work directory
// and returns the full path of that file.
func (w work) write(name, content string) (path string) {
	w.t.Helper()
	path = w.fileName(name)
	err := os.WriteFile(path, []byte(content), 0644)
	if err != nil {
		w.t.Fatal(err)
	}
	return path
}

// fileName returns the path that name has inside the work directory. It does
// not touch the file system.
func (w work) fileName(name string) string {
	w.t.Helper()
	full := filepath.Join(w.dir, name)
	return full
}

// mkdirAll creates the directory path, and any missing parents, inside the
// work directory.
func (w work) mkdirAll(path string) {
	w.t.Helper()
	target := filepath.Join(w.dir, path)
	if err := os.MkdirAll(target, 0755); err != nil {
		w.t.Fatal(err)
	}
}
|
12
x/build/convert.go
Normal file
12
x/build/convert.go
Normal file
@ -0,0 +1,12 @@
|
||||
package build
|
||||
|
||||
func convertSafeTensorToGGUF(path string) (ggufPath string, err error) {
|
||||
// TODO: decine on hueristic for converting safetensor to gguf and
|
||||
// the errors that can be returned. For now, we just say
|
||||
// "unsupported", however it may be intended to be a valid safe
|
||||
// tensor but we hit an error in the conversion.
|
||||
//
|
||||
// I (bmizernay) think this will naturally evolve as we implement
|
||||
// the conversion.
|
||||
return "", ErrUnsupportedModelFormat
|
||||
}
|
28
x/build/default.go
Normal file
28
x/build/default.go
Normal file
@ -0,0 +1,28 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// defaultDir computes the default models directory once and caches both the
// directory and any error for every later call.
var defaultDir = sync.OnceValues(func() (string, error) {
	if fromEnv := os.Getenv("OLLAMA_MODELS"); fromEnv != "" {
		return fromEnv, nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	return filepath.Join(home, ".ollama", "models"), nil
})

// DefaultDir returns the default directory for models: the value of the
// OLLAMA_MODELS environment variable if it is set to a non-empty value,
// otherwise ".ollama/models" under the user's home directory.
//
// The result is computed on the first call and reused afterwards, so later
// changes to the environment are not observed.
func DefaultDir() (string, error) {
	dir, err := defaultDir()
	return dir, err
}
|
59
x/build/import.go
Normal file
59
x/build/import.go
Normal file
@ -0,0 +1,59 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/build/internal/blobstore"
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
)
|
||||
|
||||
func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
|
||||
return blobstore.ID{}, gguf.Info{}, 0, err
|
||||
}
|
||||
|
||||
func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return s.importSafeTensor(path)
|
||||
} else {
|
||||
return s.importGGUF(path)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
info, err := gguf.StatReader(f)
|
||||
if errors.Is(err, gguf.ErrBadMagic) {
|
||||
return importError(ErrUnsupportedModelFormat)
|
||||
}
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
|
||||
if info.FileType == 0 {
|
||||
return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
|
||||
}
|
||||
id, size, err := s.st.Put(f)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
return id, info, size, nil
|
||||
}
|
||||
|
||||
func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
path, err := convertSafeTensorToGGUF(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
return s.importGGUF(path)
|
||||
}
|
329
x/build/internal/blobstore/blob.go
Normal file
329
x/build/internal/blobstore/blob.go
Normal file
@ -0,0 +1,329 @@
|
||||
// Package blobstore implements a blob store.
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
var (
	// ErrInvalidID reports an invalid blob ID.
	//
	// NOTE(review): nothing in the visible part of this file returns it;
	// ParseID signals failure with the zero ID instead — confirm intended use.
	ErrInvalidID = errors.New("invalid ID")
)

// HashSize is the size in bytes of the SHA-256 digest held by an ID.
const HashSize = 32

// An ID is a blob output key, the hash of an output of a computation.
// The zero ID is the invalid ID.
type ID struct {
	a [HashSize]byte
}

// MarshalText returns the String form of id. It never fails.
func (id ID) MarshalText() ([]byte, error) {
	return []byte(id.String()), nil
}

// UnmarshalText sets *id to ParseID(text). Text that does not parse leaves
// *id as the zero (invalid) ID; the returned error is always nil, so callers
// must check Valid.
func (id *ID) UnmarshalText(text []byte) error {
	*id = ParseID(string(text))
	return nil
}

// ParseID parses s, which must have the form "sha256-" followed by exactly
// HashSize*2 hexadecimal digits, and returns the corresponding ID. Any other
// input yields the zero ID, for which Valid reports false.
func ParseID(s string) ID {
	const prefix = "sha256-"
	h, ok := strings.CutPrefix(s, prefix)
	if !ok {
		return ID{}
	}

	if len(h) != HashSize*2 {
		return ID{}
	}

	var b []byte
	_, err := fmt.Sscanf(h, "%x", &b)
	if err != nil {
		return ID{}
	}
	// Sscanf skips leading white space and stops quietly at the first
	// non-hex rune, so a short result means h was not made up solely of hex
	// digits. Reject it rather than return a truncated, zero-padded ID.
	if len(b) != HashSize {
		return ID{}
	}

	var id ID
	copy(id.a[:], b)
	return id
}

// String returns id in "sha256-<hex>" form, or "" if id is not valid.
func (id ID) String() string {
	if !id.Valid() {
		return ""
	}
	return fmt.Sprintf("sha256-%x", id.a[:])
}

// Valid reports whether id is not the zero ID.
func (id ID) Valid() bool {
	return id != ID{}
}

// Match reports whether the digest held by id equals h.
func (id ID) Match(h [HashSize]byte) bool {
	return id.a == h
}
|
||||
|
||||
// A Store is a blob store, backed by a file system directory tree.
type Store struct {
	dir string           // root directory; blobs are kept in dir/blobs
	now func() time.Time // clock; Open sets it to time.Now
}

// Open returns the Store rooted at dir. dir must already exist and be a
// directory; its "blobs" subdirectory is created if it is missing.
//
// NOTE(review): earlier documentation for this function promised that
// several processes on one machine could share a store safely by
// coordinating through operating system file locks (and warned against
// network file systems). No locking is visible in this file — confirm before
// relying on concurrent use.
func Open(dir string) (*Store, error) {
	fi, err := os.Stat(dir)
	switch {
	case err != nil:
		return nil, err
	case !fi.IsDir():
		return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")}
	}

	blobs := filepath.Join(dir, "blobs")
	if err := os.MkdirAll(blobs, 0777); err != nil {
		return nil, err
	}
	return &Store{dir: dir, now: time.Now}, nil
}

// Dir returns the root directory the store was opened with.
func (s *Store) Dir() string {
	root := s.dir
	return root
}
|
||||
|
||||
// fileName returns the name of the blob file corresponding to the given id.
|
||||
func (s *Store) fileName(id ID) string {
|
||||
return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:]))
|
||||
}
|
||||
|
||||
// An entryNotFoundError reports that a store entry could not be found. Err,
// when non-nil, is the underlying reason.
type entryNotFoundError struct {
	Err error
}

// Error returns "store entry not found", followed by the underlying reason
// when there is one.
func (e *entryNotFoundError) Error() string {
	if e.Err != nil {
		return fmt.Sprintf("store entry not found: %v", e.Err)
	}
	return "store entry not found"
}

// Unwrap returns the underlying reason, which may be nil.
func (e *entryNotFoundError) Unwrap() error {
	return e.Err
}
|
||||
|
||||
// An Entry describes a blob found in the store. Get fills it from the blob
// file's os.Stat result.
type Entry struct {
	// presumably prevents Entry values from being compared with ==; see the
	// structs package — confirm.
	_ structs.Incomparable

	ID   ID        // ID the entry was looked up by
	Size int64     // size of the blob file in bytes
	Time time.Time // when added to store
}
|
||||
|
||||
// GetFile looks up the blob ID in the store and returns
|
||||
// the name of the corresponding data file.
|
||||
func GetFile(s *Store, id ID) (file string, entry Entry, err error) {
|
||||
entry, err = s.Get(id)
|
||||
if err != nil {
|
||||
return "", Entry{}, err
|
||||
}
|
||||
file = s.OutputFilename(entry.ID)
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return "", Entry{}, &entryNotFoundError{Err: err}
|
||||
}
|
||||
if info.Size() != entry.Size {
|
||||
return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")}
|
||||
}
|
||||
return file, entry, nil
|
||||
}
|
||||
|
||||
// GetBytes looks up the blob ID in the store and returns
|
||||
// the corresponding output bytes.
|
||||
// GetBytes should only be used for data that can be expected to fit in memory.
|
||||
func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
|
||||
entry, err := s.Get(id)
|
||||
if err != nil {
|
||||
return nil, entry, err
|
||||
}
|
||||
data, _ := os.ReadFile(s.OutputFilename(entry.ID))
|
||||
if entry.ID.Match(sha256.Sum256(data)) {
|
||||
return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
|
||||
}
|
||||
return data, entry, nil
|
||||
}
|
||||
|
||||
// OutputFilename returns the name of the blob file for the given ID.
|
||||
func (s *Store) OutputFilename(id ID) string {
|
||||
file := s.fileName(id)
|
||||
// TODO(bmizerany): touch as "used" for cache trimming. (see
|
||||
// cache.go in cmd/go/internal/cache for the full reference implementation to go off of.
|
||||
return file
|
||||
}
|
||||
|
||||
// Get looks up the blob ID in the store,
|
||||
// returning the corresponding output ID and file size, if any.
|
||||
// Note that finding an output ID does not guarantee that the
|
||||
// saved file for that output ID is still available.
|
||||
func (s *Store) Get(id ID) (Entry, error) {
|
||||
file := s.fileName(id)
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return Entry{}, &entryNotFoundError{Err: err}
|
||||
}
|
||||
return Entry{
|
||||
ID: id,
|
||||
Size: info.Size(),
|
||||
Time: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close currently does nothing and always returns nil.
func (s *Store) Close() error {
	// TODO(bmizerany): return c.Trim()
	return nil
}
|
||||
|
||||
// Put stores the data read from the given file into the store as ID.
|
||||
//
|
||||
// It may read file twice. The content of file must not change between the
|
||||
// two passes.
|
||||
func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) {
|
||||
return s.put(file)
|
||||
}
|
||||
|
||||
func PutBytes(s *Store, data []byte) (ID, int64, error) {
|
||||
return s.Put(bytes.NewReader(data))
|
||||
}
|
||||
|
||||
func PutString(s *Store, data string) (ID, int64, error) {
|
||||
return s.Put(strings.NewReader(data))
|
||||
}
|
||||
|
||||
func (s *Store) put(file io.ReadSeeker) (ID, int64, error) {
|
||||
// Compute output ID.
|
||||
h := sha256.New()
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
return ID{}, 0, err
|
||||
}
|
||||
size, err := io.Copy(h, file)
|
||||
if err != nil {
|
||||
return ID{}, 0, err
|
||||
}
|
||||
var out ID
|
||||
h.Sum(out.a[:0])
|
||||
|
||||
// Copy to blob file (if not already present).
|
||||
if err := s.copyFile(file, out, size); err != nil {
|
||||
return out, size, err
|
||||
}
|
||||
|
||||
// TODO: Add to manifest index.
|
||||
return out, size, nil
|
||||
}
|
||||
|
||||
// copyFile copies file into the store, expecting it to have the given
|
||||
// output ID and size, if that file is not present already.
|
||||
func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error {
|
||||
name := s.fileName(out)
|
||||
println("name", name)
|
||||
info, err := os.Stat(name)
|
||||
if err == nil && info.Size() == size {
|
||||
// Check hash.
|
||||
if f, err := os.Open(name); err == nil {
|
||||
h := sha256.New()
|
||||
io.Copy(h, f)
|
||||
f.Close()
|
||||
var out2 ID
|
||||
h.Sum(out2.a[:0])
|
||||
if out == out2 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Hash did not match. Fall through and rewrite file.
|
||||
}
|
||||
|
||||
// Copy file to blobs directory.
|
||||
mode := os.O_RDWR | os.O_CREATE
|
||||
if err == nil && info.Size() > size { // shouldn't happen but fix in case
|
||||
mode |= os.O_TRUNC
|
||||
}
|
||||
f, err := os.OpenFile(name, mode, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if size == 0 {
|
||||
// File now exists with correct size.
|
||||
// Only one possible zero-length file, so contents are OK too.
|
||||
// Early return here makes sure there's a "last byte" for code below.
|
||||
return nil
|
||||
}
|
||||
|
||||
// From here on, if any of the I/O writing the file fails,
|
||||
// we make a best-effort attempt to truncate the file f
|
||||
// before returning, to avoid leaving bad bytes in the file.
|
||||
|
||||
// Copy file to f, but also into h to double-check hash.
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
h := sha256.New()
|
||||
w := io.MultiWriter(f, h)
|
||||
if _, err := io.CopyN(w, file, size-1); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
// Check last byte before writing it; writing it will make the size match
|
||||
// what other processes expect to find and might cause them to start
|
||||
// using the file.
|
||||
buf := make([]byte, 1)
|
||||
if _, err := file.Read(buf); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
h.Write(buf)
|
||||
sum := h.Sum(nil)
|
||||
if !bytes.Equal(sum, out.a[:]) {
|
||||
f.Truncate(0)
|
||||
return fmt.Errorf("file content changed underfoot")
|
||||
}
|
||||
|
||||
// Commit manifest entry.
|
||||
if _, err := f.Write(buf); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
// Data might not have been written,
|
||||
// but file may look like it is the right size.
|
||||
// To be extra careful, remove stored file.
|
||||
os.Remove(name)
|
||||
return err
|
||||
}
|
||||
os.Chtimes(name, s.now(), s.now()) // mainly for tests
|
||||
|
||||
return nil
|
||||
}
|
54
x/build/internal/blobstore/blob_test.go
Normal file
54
x/build/internal/blobstore/blob_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseID(t *testing.T) {
|
||||
const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
var invalid = strings.Repeat("\x00", HashSize*2)
|
||||
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", invalid},
|
||||
{"sha256-", invalid},
|
||||
{"sha256-" + valid, valid},
|
||||
|
||||
{"" + valid, invalid}, // no prefix
|
||||
{"sha123-" + valid, invalid}, // invalid prefix
|
||||
{"sha256-" + valid[1:], invalid}, // too short
|
||||
{"sha256-" + valid + "a", invalid}, // too long
|
||||
{"sha256-!" + valid[1:], invalid}, // invalid hex
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
// sanity check
|
||||
if len(tt.want) > HashSize*2 {
|
||||
panic("invalid test")
|
||||
}
|
||||
|
||||
got := ParseID(tt.in)
|
||||
|
||||
wantValid := tt.want != invalid
|
||||
if wantValid {
|
||||
if !got.Valid() {
|
||||
t.Errorf("ParseID(%q).Valid() = false; want true", tt.in)
|
||||
}
|
||||
if got.String() != "sha256-"+tt.want {
|
||||
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want)
|
||||
}
|
||||
} else {
|
||||
if got.Valid() {
|
||||
t.Errorf("ParseID(%q).Valid() = true; want false", tt.in)
|
||||
}
|
||||
if got.String() != "" {
|
||||
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
128
x/build/internal/blobstore/store_test.go
Normal file
128
x/build/internal/blobstore/store_test.go
Normal file
@ -0,0 +1,128 @@
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"iter"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
const (
|
||||
blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
|
||||
)
|
||||
|
||||
func TestStoreBasicBlob(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
checkDir(t, dir, nil)
|
||||
|
||||
st, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
st.now = func() time.Time { return now }
|
||||
|
||||
checkDir(t, dir, []string{
|
||||
"blobs/",
|
||||
})
|
||||
|
||||
id, size, err := PutBytes(st, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id != ParseID(blobNameHello) {
|
||||
t.Errorf("unexpected ID: %s", id)
|
||||
}
|
||||
if size != 5 {
|
||||
t.Errorf("unexpected size: %d", size)
|
||||
}
|
||||
|
||||
checkDir(t, dir, []string{
|
||||
"blobs/",
|
||||
"blobs/" + blobNameHello,
|
||||
})
|
||||
|
||||
got, err := st.Get(id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, got, Entry{
|
||||
ID: id,
|
||||
Size: 5,
|
||||
Time: now,
|
||||
})
|
||||
|
||||
file := st.OutputFilename(id)
|
||||
wantFile := filepath.Join(dir, "blobs", blobNameHello)
|
||||
if file != wantFile {
|
||||
t.Errorf("unexpected file: %s", file)
|
||||
}
|
||||
|
||||
// Check tags
|
||||
name := model.ParseName("registry.ollama.ai/library/test:latest+KQED")
|
||||
|
||||
t.Logf("RESOLVING: %q", name.Parts())
|
||||
|
||||
}
|
||||
|
||||
// checkDir checks that the directory at dir contains the files in want. The
|
||||
// files in want must be relative to dir.
|
||||
//
|
||||
// direcotories are suffixed with a slash (e.g. "foo/" instead of "foo").
|
||||
//
|
||||
// want must be in lexicographic order.
|
||||
func checkDir(t testing.TB, dir string, want []string) {
|
||||
t.Helper()
|
||||
|
||||
var matches []string
|
||||
for path, err := range walkDir(dir) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("found %s", path)
|
||||
if path == "./" {
|
||||
continue
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
matches = append(matches, path)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, matches, want)
|
||||
}
|
||||
|
||||
var errStop = errors.New("stop")
|
||||
|
||||
func walkDir(dir string) iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path, err = filepath.Rel(dir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
if info.IsDir() {
|
||||
path += "/"
|
||||
}
|
||||
if !yield(path, nil) {
|
||||
return errStop
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, errStop) && err != nil {
|
||||
yield("", err)
|
||||
}
|
||||
}
|
||||
}
|
31
x/client/ollama/apitype/apitype.go
Normal file
31
x/client/ollama/apitype/apitype.go
Normal file
@ -0,0 +1,31 @@
|
||||
package apitype
|
||||
|
||||
import "time"
|
||||
|
||||
// Message is a role/content pair, encoded in JSON as "role" and "content".
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||||
|
||||
// Model describes a model as exchanged with the Ollama API.
type Model struct {
	Ref    string `json:"ref"`
	Digest string `json:"digest"`
	Size   int64  `json:"size"`

	// ModifiedAt is read by Modified as nanoseconds since the Unix epoch.
	ModifiedAt int64 `json:"modified"`
}

// Modified returns ModifiedAt as a time.Time, treating it as a count of
// nanoseconds since the Unix epoch.
func (m Model) Modified() time.Time {
	return time.Unix(0, m.ModifiedAt)
}

// Modifed is the original, misspelled name of Modified, kept so existing
// callers continue to compile.
//
// Deprecated: use Modified instead.
func (m Model) Modifed() time.Time {
	return m.Modified()
}
|
||||
|
||||
// PushRequest is the JSON request body for a push.
type PushRequest struct {
	Name     string `json:"name"` // Ref is the official term, "name" is for backward compatibility with existing clients.
	Insecure bool   `json:"insecure"`
	Stream   bool   `json:"stream"`
}
|
||||
|
||||
// PushStatus carries a status string, a digest and a total.
//
// NOTE(review): presumably one progress update reported by the server during
// a push — confirm against the server implementation.
type PushStatus struct {
	Status string `json:"status"`
	Digest string `json:"digest"`
	Total  int64  `json:"total"`
}
|
173
x/client/ollama/ollama.go
Normal file
173
x/client/ollama/ollama.go
Normal file
@ -0,0 +1,173 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/client/ollama/apitype"
|
||||
"github.com/ollama/ollama/x/types/empty"
|
||||
)
|
||||
|
||||
// TODO(bmizerany): PROGRESS INDICATORS!!!!
|
||||
|
||||
// DefaultBaseURL is the base URL used when OLLAMA_BASE_URL is not set.
const DefaultBaseURL = "http://localhost:11434"

// envBaseURL is the base URL given to clients made by Default: the value of
// the OLLAMA_BASE_URL environment variable when it is non-empty, and
// DefaultBaseURL otherwise. It is read once, at package initialization.
var envBaseURL = func() string {
	if fromEnv := os.Getenv("OLLAMA_BASE_URL"); fromEnv != "" {
		return fromEnv
	}
	return DefaultBaseURL
}()

// Default returns a new client that uses the default base URL and the
// default HTTP client.
func Default() *Client {
	c := new(Client)
	c.BaseURL = envBaseURL
	return c
}

// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to
// true for any instance of Client to work.
//
// NOTE(review): no code in the visible part of this file reads the flag —
// confirm where, or whether, it is enforced.
var I_Acknowledge_This_API_Is_Under_Development bool

// Client is a client for the Ollama API.
type Client struct {
	// BaseURL is the base URL of the Ollama API. Request paths are appended
	// to it as is.
	BaseURL string

	// HTTPClient is the HTTP client to use. If nil, http.DefaultClient is
	// used.
	HTTPClient *http.Client
}
|
||||
|
||||
// Build requests the remote Ollama service to build a model. It uploads any
// source files the server needs.
//
// NOTE: Build is not implemented yet; calling it panics.
func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error {
	panic("not implemented")
}
|
||||
|
||||
// Push requests the remote Ollama service to push a model to the server.
|
||||
func (c *Client) Push(ctx context.Context, ref string) error {
|
||||
_, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref})
|
||||
return err
|
||||
}
|
||||
|
||||
// Pull is not implemented yet; calling it panics.
func (c *Client) Pull(ctx context.Context, ref string) error {
	panic("not implemented")
}

// List is not implemented yet; calling it panics.
func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] {
	panic("not implemented")
}

// Show is not implemented yet; calling it panics.
func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) {
	panic("not implemented")
}

// Remove is not implemented yet; calling it panics.
func (c *Client) Remove(ctx context.Context, ref string) error {
	panic("not implemented")
}

// Copy is not implemented yet; calling it panics.
func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error {
	panic("not implemented")
}

// Run is not implemented yet; calling it panics.
func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error {
	panic("not implemented")
}
|
||||
|
||||
// Error is the error value Do returns when the server answers with a
// non-2xx status and a JSON error body.
type Error struct {
	// Status is the HTTP status code returned by the server.
	Status int `json:"status"`

	// Code specifies a machine readable code indicating the class of
	// error this error is. See http://docs.ollama.com/errors for a full
	// list of error codes.
	Code string `json:"code"`

	// Message is a human-readable message that describes the error. It
	// may change across versions of the API, so it should not be used for
	// programmatic decisions.
	Message string `json:"message,omitempty"`

	// Field is the field in the request that caused the error, if any.
	Field string `json:"field,omitempty"`

	// Value is the value of the field that caused the error, if any.
	Value string `json:"value,omitempty"`
}

// Error formats e as "ollama: <code>[ <field>][: <value>][: <message>]",
// leaving out the parts that are empty.
func (e *Error) Error() string {
	msg := "ollama: " + e.Code
	if e.Field != "" {
		msg += " " + e.Field
	}
	for _, part := range []string{e.Value, e.Message} {
		if part != "" {
			msg += ": " + part
		}
	}
	return msg
}
|
||||
|
||||
// Do encodes in and sends it in a request to the Ollama server and decodes
|
||||
// the response into Res, or an error response (non-2xx) into an *Error, or
|
||||
// any error encounted decoding the response.
|
||||
func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) {
|
||||
var body bytes.Buffer
|
||||
// TODO(bmizerany): pool and reuse this buffer AND the encoder
|
||||
if err := encodeJSON(&body, in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
urlStr := c.BaseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, method, urlStr, &body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hc := cmp.Or(c.HTTPClient, http.DefaultClient)
|
||||
res, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
var buf bytes.Buffer
|
||||
body := io.TeeReader(res.Body, &buf)
|
||||
e, err := decodeJSON[Error](body)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ollama: invalid error response from server (status %d): %q", res.StatusCode, buf.String())
|
||||
return nil, err
|
||||
}
|
||||
return nil, e
|
||||
}
|
||||
|
||||
return decodeJSON[Res](res.Body)
|
||||
}
|
||||
|
||||
// decodeJSON decodes one JSON value from r into a new value of type T and
// returns a pointer to it, or nil and the decoding error.
//
// NOTE: This and encodeJSON are copied from oweb.go. Please do not try to
// consolidate them, so that ollama/client stays free of dependencies which
// are moving targets and do not pull enough weight to justify their
// inclusion.
func decodeJSON[T any](r io.Reader) (*T, error) {
	out := new(T)
	if err := json.NewDecoder(r).Decode(out); err != nil {
		return nil, err
	}
	return out, nil
}
|
||||
|
||||
// encodeJSON writes v to w as JSON followed by a newline.
//
// NOTE: see the NOTE on decodeJSON about why this is a copy.
func encodeJSON(w io.Writer, v any) error {
	// TODO(bmizerany): pool and reuse encoder
	enc := json.NewEncoder(w)
	return enc.Encode(v)
}
|
100
x/cmd/bllamo/bllamo.go
Normal file
100
x/cmd/bllamo/bllamo.go
Normal file
@ -0,0 +1,100 @@
|
||||
// Bllamo is a (new) tool for managing Ollama models.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// bllamo <command> [arguments]
|
||||
//
|
||||
// The commands are:
|
||||
//
|
||||
// build build a model from a Modelfile
|
||||
// list list all models
|
||||
// push push a model to an ollama registry
|
||||
// pull pull a model from an ollama registry
|
||||
// delete delete a model from an ollama registry
|
||||
// help display help for a command
|
||||
package main
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/api"
|
||||
"github.com/ollama/ollama/x/build"
|
||||
"github.com/ollama/ollama/x/client/ollama"
|
||||
"github.com/ollama/ollama/x/registry"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
if len(args) < 1 {
|
||||
fmt.Fprintln(os.Stderr, "bllamo: no command provided")
|
||||
os.Exit(2)
|
||||
}
|
||||
if err := Main(flag.Args()...); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// TODOUsage is a placeholder error returned by commands invoked with the
// wrong number of arguments, until real usage messages exist.
var TODOUsage = fmt.Errorf("TODO: usage")

// commands maps a command name to the function implementing it. Each
// function receives the arguments that follow the command name.
var commands = map[string]func(ctx context.Context, args ...string) error{
	"build":    cmdBuild,
	"push":     cmdPush,
	"serve":    cmdServe,
	"registry": cmdRegistry,
}
|
||||
|
||||
// Main is the entry point for the blammo command.
|
||||
func Main(args ...string) error {
|
||||
cmd := args[0]
|
||||
args = args[1:]
|
||||
if f, ok := commands[cmd]; ok {
|
||||
ctx := context.TODO()
|
||||
return f(ctx, args...)
|
||||
}
|
||||
return fmt.Errorf("blammo: unknown command %q", cmd)
|
||||
}
|
||||
|
||||
func cmdBuild(ctx context.Context, args ...string) error {
|
||||
var v struct {
|
||||
Modelfile string `flag:"f,the Modelfile to use"`
|
||||
}
|
||||
|
||||
fs := readFlags("build", args, &v)
|
||||
if fs.NArg() != 1 {
|
||||
return TODOUsage
|
||||
}
|
||||
|
||||
modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ollama.Default().Build(ctx, args[0], modelfile, os.DirFS("."))
|
||||
}
|
||||
|
||||
func cmdRegistry(_ context.Context, _ ...string) error {
|
||||
var s registry.Server
|
||||
return http.ListenAndServe(":8888", &s)
|
||||
}
|
||||
|
||||
func cmdServe(ctx context.Context, args ...string) error {
|
||||
bs, err := build.Open("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return http.ListenAndServe(":11434", &api.Server{Build: bs})
|
||||
}
|
||||
|
||||
func cmdPush(ctx context.Context, args ...string) error {
|
||||
fs := readFlags("push", args, nil)
|
||||
if fs.NArg() != 1 {
|
||||
return TODOUsage
|
||||
}
|
||||
return ollama.Default().Push(ctx, fs.Arg(0))
|
||||
}
|
59
x/cmd/bllamo/flags.go
Normal file
59
x/cmd/bllamo/flags.go
Normal file
@ -0,0 +1,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// readFlags builds a flag.FlagSet called name from the struct that v points
// to, parses args with it, and returns the parsed set.
//
// Every settable field of *v that carries a `flag:"name,usage"` tag becomes
// a flag. The text after the first comma is the usage string and may be
// omitted. Only string and bool fields are supported; a tagged field of any
// other kind panics. Fields without a flag tag are ignored.
//
// A nil v yields a flag set with no flags. Otherwise v must be a non-nil
// pointer to a struct, or readFlags panics.
//
// The flag set uses flag.ExitOnError, so a parse failure ends the process.
// Example usage:
//
//	func main() {
//		var flags struct {
//			Modelfile string `flag:"f,path to the Modelfile"`
//		}
//
//		fs := readFlags("build", os.Args[1:], &flags)
//		ref := fs.Arg(0)
//	}
func readFlags(name string, args []string, v any) *flag.FlagSet {
	fs := flag.NewFlagSet(name, flag.ExitOnError)

	if v != nil {
		rv := reflect.ValueOf(v)
		if rv.Kind() != reflect.Pointer || rv.Elem().Kind() != reflect.Struct {
			panic(fmt.Sprintf("readFlags: v must be a non-nil pointer to a struct, got %T", v))
		}
		// Work on the struct itself: fields are only addressable and
		// settable when reached through the pointer.
		rv = rv.Elem()
		rt := rv.Type()

		for i := 0; i < rt.NumField(); i++ {
			f := rv.Field(i)
			if !f.CanSet() {
				continue
			}

			tag := rt.Field(i).Tag.Get("flag")
			if tag == "" {
				continue
			}
			flagName, usage, _ := strings.Cut(tag, ",")

			// TODO(bmizerany): add more types as needed
			switch f.Kind() {
			case reflect.String:
				fs.StringVar(f.Addr().Interface().(*string), flagName, "", usage)
			case reflect.Bool:
				fs.BoolVar(f.Addr().Interface().(*bool), flagName, false, usage)
			default:
				panic(fmt.Sprintf("unsupported type %v", f.Kind()))
			}
		}
	}

	// With flag.ExitOnError, Parse exits the process on failure, so there is
	// no error left to handle here.
	fs.Parse(args)
	return fs
}
|
97
x/cmd/gguf/gguf.go
Normal file
97
x/cmd/gguf/gguf.go
Normal file
@ -0,0 +1,97 @@
|
||||
// Gguf is a tool for learning about GGUF files.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// gguf [flags] <file>
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := Main(os.Stdout, os.Args[1:]...); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Main(stdout io.Writer, args ...string) error {
|
||||
fs := flag.NewFlagSet("gguf", flag.ExitOnError)
|
||||
flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)")
|
||||
|
||||
fs.Usage = func() {
|
||||
io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
io.WriteString(stdout, "Usage:\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
io.WriteString(stdout, "\tgguf [flags] <file>\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
var numFlags int
|
||||
fs.VisitAll(func(*flag.Flag) { numFlags++ })
|
||||
if numFlags > 0 {
|
||||
io.WriteString(stdout, "Flags:\n")
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
}
|
||||
fs.Parse(args)
|
||||
|
||||
if fs.NArg() != 1 {
|
||||
fs.Usage()
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
file := fs.Arg(0)
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
g, err := gguf.ReadFile(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0)
|
||||
defer tw.Flush()
|
||||
|
||||
fmt.Fprintf(tw, "version:\t%d\n", g.Version())
|
||||
|
||||
for m, err := range g.Metadata {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if len(m.Values) > 5 {
|
||||
fmt.Fprintf(tw, "meta:\t%q: ... (%d values)\n", m.Key, len(m.Values))
|
||||
} else {
|
||||
fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values)
|
||||
}
|
||||
}
|
||||
|
||||
var i int
|
||||
var totalLayerBytes uint64
|
||||
var offGPU bool
|
||||
for t, err := range g.Tensors {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
totalLayerBytes += t.Size
|
||||
if totalLayerBytes > *flagGPU {
|
||||
offGPU = true
|
||||
}
|
||||
|
||||
const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n"
|
||||
fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU)
|
||||
|
||||
i++
|
||||
}
|
||||
return nil
|
||||
}
|
376
x/encoding/gguf/gguf.go
Normal file
376
x/encoding/gguf/gguf.go
Normal file
@ -0,0 +1,376 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
// TODO(bmizerany): determine a more reasonable value for MaxDimensions.

// MaxDimensions is the largest dimension count accepted for a single tensor;
// tensor descriptors that declare more are rejected as mangled.
const MaxDimensions uint32 = 1_000_000
|
||||
|
||||
// Errors
var (
	// ErrBadMagic is returned when the magic bytes at the start of the
	// file cannot be read or do not match the GGUF magic. This is useful
	// for detecting if the file is not a gguf file.
	ErrBadMagic = errors.New("gguf: bad magic")

	// ErrUnsupportedVersion is returned when the file declares a version
	// this package does not read.
	ErrUnsupportedVersion = errors.New("gguf: unsupported version")

	// ErrMangled is returned when the file contains values that cannot be
	// valid, such as a negative length or too many tensor dimensions.
	ErrMangled = errors.New("gguf: mangled data")
)
|
||||
|
||||
// Type identifies the element/quantization type of a tensor as stored in a
// GGUF file.
type Type uint32

// Known tensor types. The numeric values are part of the file format; 4 and 5
// are not assigned a name here.
const (
	TypeF32 Type = iota
	TypeF16
	TypeQ4_0
	TypeQ4_1
	_ // 4: unassigned
	_ // 5: unassigned
	TypeQ5_0
	TypeQ5_1
	TypeQ8_0
	TypeQ8_1
	TypeQ2_K
	TypeQ3_K
	TypeQ4_K
	TypeQ5_K
	TypeQ6_K
	TypeQ8_K
	TypeI8
	TypeI16
	TypeI32
	TypeCount
)

// typeNames maps each known Type to its display name.
var typeNames = map[Type]string{
	TypeF32:   "F32",
	TypeF16:   "F16",
	TypeQ4_0:  "Q4_0",
	TypeQ4_1:  "Q4_1",
	TypeQ5_0:  "Q5_0",
	TypeQ5_1:  "Q5_1",
	TypeQ8_0:  "Q8_0",
	TypeQ8_1:  "Q8_1",
	TypeQ2_K:  "Q2_K",
	TypeQ3_K:  "Q3_K",
	TypeQ4_K:  "Q4_K",
	TypeQ5_K:  "Q5_K",
	TypeQ6_K:  "Q6_K",
	TypeQ8_K:  "Q8_K",
	TypeI8:    "I8",
	TypeI16:   "I16",
	TypeI32:   "I32",
	TypeCount: "COUNT",
}

// String returns the display name of t, or a "(!unknown_type N!)" marker
// carrying the numeric value when t has no known name.
func (t Type) String() string {
	name, known := typeNames[t]
	if !known || name == "" {
		return "(!unknown_type " + strconv.FormatUint(uint64(t), 10) + "!)"
	}
	return name
}
|
||||
|
||||
// ValueType is the type of a metadata value.
type ValueType uint32

// Known metadata value types. The numeric values are part of the file format.
const (
	ValueTypeUint8 ValueType = iota
	ValueTypeInt8
	ValueTypeUint16
	ValueTypeInt16
	ValueTypeUint32
	ValueTypeInt32
	ValueTypeFloat32
	ValueTypeBool
	ValueTypeString
	ValueTypeArray
	ValueTypeUint64
	ValueTypeInt64
	ValueTypeFloat64
)

// metaTypeNames maps each known ValueType to its display name.
var metaTypeNames = map[ValueType]string{
	ValueTypeUint8:   "uint8",
	ValueTypeInt8:    "int8",
	ValueTypeUint16:  "uint16",
	ValueTypeInt16:   "int16",
	ValueTypeUint32:  "uint32",
	ValueTypeInt32:   "int32",
	ValueTypeFloat32: "float32",
	ValueTypeBool:    "bool",
	ValueTypeString:  "string",
	ValueTypeArray:   "array",
	ValueTypeUint64:  "uint64",
	ValueTypeInt64:   "int64",
	ValueTypeFloat64: "float64",
}

// String returns the display name of t, or a "(!unknown_value_type N!)"
// marker carrying the numeric value when t has no known name.
func (t ValueType) String() string {
	name, known := metaTypeNames[t]
	if !known || name == "" {
		return fmt.Sprintf("(!unknown_value_type %d!)", t)
	}
	return name
}
|
||||
|
||||
// TensorInfo describes one tensor as listed in a GGUF file.
type TensorInfo struct {
	// Name is the tensor's name as stored in the file.
	Name string
	// Dimensions holds the extent of each dimension, in file order.
	Dimensions []uint64
	// Type is the tensor's element/quantization type.
	Type Type
	// Offset is the tensor's data offset as stored in the file.
	Offset uint64
	// Size is not stored in the file. File.Tensors fills it in as the
	// difference between this tensor's Offset and the previous tensor's.
	Size uint64
}
|
||||
|
||||
type MetaValue struct {
|
||||
Type ValueType
|
||||
Value []byte
|
||||
}
|
||||
|
||||
func (v MetaValue) String() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(v.Type.String())
|
||||
b.WriteString("(")
|
||||
switch v.Type {
|
||||
case ValueTypeArray:
|
||||
b.WriteString("[...]")
|
||||
case ValueTypeString:
|
||||
b.WriteString(strconv.Quote(string(v.Value)))
|
||||
case ValueTypeBool:
|
||||
if len(v.Value) == 0 {
|
||||
b.WriteString("(!invalid bool)")
|
||||
}
|
||||
switch v.Value[0] {
|
||||
case 0:
|
||||
b.WriteString("false")
|
||||
case 1:
|
||||
b.WriteString("true")
|
||||
default:
|
||||
b.WriteString("!invalid bool")
|
||||
}
|
||||
case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64:
|
||||
var buf [8]byte
|
||||
if len(v.Value) < 8 {
|
||||
copy(buf[:], v.Value)
|
||||
}
|
||||
fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:]))
|
||||
default:
|
||||
fmt.Fprintf(&b, "%v", v.Value)
|
||||
}
|
||||
b.WriteString(")")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type MetaEntry struct {
|
||||
Key string
|
||||
Type ValueType
|
||||
Values []MetaValue
|
||||
}
|
||||
|
||||
func (e MetaEntry) String() string {
|
||||
if len(e.Values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return string(e.Values[0].Value)
|
||||
}
|
||||
|
||||
func (e MetaEntry) Uint32() uint32 {
|
||||
if len(e.Values) == 0 {
|
||||
return 0
|
||||
}
|
||||
return binary.LittleEndian.Uint32(e.Values[0].Value)
|
||||
}
|
||||
|
||||
func (e MetaEntry) FileType() Type {
|
||||
if len(e.Values) == 0 {
|
||||
return TypeCount
|
||||
}
|
||||
return Type(e.Uint32())
|
||||
}
|
||||
|
||||
func (e MetaEntry) GoString() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(e.Key)
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Type.String())
|
||||
b.WriteString("(")
|
||||
for i, v := range e.Values {
|
||||
if i > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteString(v.String())
|
||||
}
|
||||
b.WriteString(")")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Info is the summary of a GGUF file returned by Stat and StatReader.
type Info struct {
	_ structs.Incomparable // prevent comparison of Info values so we can change the implementation later

	// Version is the GGUF version declared in the file header.
	Version int
	// FileType is taken from the "general.file_type" metadata entry; it is
	// left at its zero value when the file has no such entry.
	FileType Type
}
|
||||
|
||||
func Stat(path string) (Info, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
return StatReader(f)
|
||||
}
|
||||
|
||||
// StatReader reads the header information from r and returns an Info
|
||||
// struct with the version and file type.
|
||||
//
|
||||
// It returns an error if any.
|
||||
//
|
||||
// As a special case, it returns ErrBadMagic if the file does not start with
|
||||
// the magic bytes. This can be used to detect if the file is not a GGUF
|
||||
// file.
|
||||
func StatReader(r io.ReadSeeker) (Info, error) {
|
||||
if _, err := r.Seek(0, 0); err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
f, err := ReadFile(r)
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
info := Info{Version: f.Version()}
|
||||
for m, err := range f.Metadata {
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
if m.Key == "general.file_type" {
|
||||
if m.Type != ValueTypeUint32 {
|
||||
return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32)
|
||||
}
|
||||
info.FileType = m.FileType()
|
||||
}
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
type File struct {
|
||||
version uint32
|
||||
numMetaValues uint64
|
||||
numTensors uint64
|
||||
|
||||
gr *ggufReader
|
||||
}
|
||||
|
||||
// ReadFile reads header information from r and returns a File, ready for
|
||||
// iteration over Metadata and Tensors.
|
||||
func ReadFile(r io.Reader) (*File, error) {
|
||||
f, err := readFile(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) Version() int {
|
||||
return int(f.version)
|
||||
}
|
||||
|
||||
// Metadata iterates over the metadata in the file. It must be exhausted
|
||||
// before calling Tensors.
|
||||
//
|
||||
// It is not resumable.
|
||||
func (f *File) Metadata(yield func(MetaEntry, error) bool) {
|
||||
var n int
|
||||
for range f.numMetaValues {
|
||||
meta, err := f.gr.readMetaEntry()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading metadata entry %d: %w", n, err)
|
||||
yield(MetaEntry{}, err)
|
||||
return
|
||||
}
|
||||
if !yield(meta, nil) {
|
||||
return
|
||||
}
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
// Tensors iterates over the tensors in the file. It must only be called
|
||||
// after exhausting the metadata iterator.
|
||||
//
|
||||
// It is not resumable.
|
||||
func (f *File) Tensors(yield func(TensorInfo, error) bool) {
|
||||
var last TensorInfo
|
||||
for range f.numTensors {
|
||||
info, err := f.gr.readTensorInfo()
|
||||
|
||||
// If the last tensor had a valid offset, yield it.
|
||||
//
|
||||
// NOTE: No tensor should have an offset of 0 because the
|
||||
// offset is the start of the tensor data which is always
|
||||
// afer the magic bytes, version, numMetaValues, and
|
||||
// numTensors, which MUST all be non-zero bytes as per the
|
||||
// GGUF spec.
|
||||
if last.Offset > 0 {
|
||||
if !yield(last, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
yield(TensorInfo{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Tensor data does not include size, so we need to
|
||||
// calculate it based on the offset of the previous tensor
|
||||
// offset to the current.
|
||||
offset0 := last.Offset
|
||||
last = info
|
||||
last.Size = info.Offset - offset0
|
||||
}
|
||||
if last.Offset > 0 {
|
||||
yield(last, nil)
|
||||
}
|
||||
}
|
||||
|
||||
var magicBytes = []byte{0x47, 0x47, 0x55, 0x46}
|
||||
|
||||
func readFile(r io.Reader) (*File, error) {
|
||||
gr := &ggufReader{r: &reader{r: r}}
|
||||
magic, err := gr.next(4)
|
||||
if err != nil {
|
||||
return nil, errors.Join(err, ErrBadMagic)
|
||||
}
|
||||
if !bytes.Equal(magic, magicBytes) {
|
||||
return nil, ErrBadMagic
|
||||
}
|
||||
version, err := gr.readUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != 3 {
|
||||
return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version)
|
||||
}
|
||||
numTensors, err := gr.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
numMetaValues, err := gr.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info := &File{
|
||||
version: version,
|
||||
|
||||
numMetaValues: numMetaValues,
|
||||
numTensors: numTensors,
|
||||
gr: gr,
|
||||
}
|
||||
return info, nil
|
||||
}
|
345
x/encoding/gguf/gguf_test.go
Normal file
345
x/encoding/gguf/gguf_test.go
Normal file
@ -0,0 +1,345 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
func TestStat(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
data string
|
||||
wantInfo Info
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
wantErr: ErrBadMagic,
|
||||
},
|
||||
{
|
||||
name: "bad magic",
|
||||
data: "\xBB\xAA\xDD\x00",
|
||||
wantErr: ErrBadMagic,
|
||||
},
|
||||
{
|
||||
name: "bad version",
|
||||
data: string(magicBytes) +
|
||||
"\x02\x00\x00\x00" + // version
|
||||
"",
|
||||
wantErr: ErrUnsupportedVersion,
|
||||
},
|
||||
{
|
||||
name: "valid general.file_type",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// general.file_type key
|
||||
"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"general.file_type" + // key
|
||||
"\x04\x00\x00\x00" + // type (uint32)
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value
|
||||
"",
|
||||
wantInfo: Info{
|
||||
Version: 3,
|
||||
FileType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
info, err := StatReader(strings.NewReader(tt.data))
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("err = %v; want %q", err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
diff.Test(t, t.Errorf, info, tt.wantInfo)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadInfo(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
data string
|
||||
|
||||
wantMeta []MetaEntry
|
||||
wantTensor []TensorInfo
|
||||
wantReadErr error
|
||||
wantMetaErr error
|
||||
wantTensorErr error
|
||||
wantInfo Info
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
wantReadErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "bad magic",
|
||||
data: "\xBB\xAA\xDD\x00",
|
||||
wantReadErr: ErrBadMagic,
|
||||
},
|
||||
{
|
||||
name: "bad version",
|
||||
data: string(magicBytes) +
|
||||
"\x02\x00\x00\x00" + // version
|
||||
"",
|
||||
wantReadErr: ErrUnsupportedVersion,
|
||||
},
|
||||
{
|
||||
name: "no metadata or tensors",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"",
|
||||
wantReadErr: nil,
|
||||
},
|
||||
{
|
||||
name: "good metadata",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"",
|
||||
wantMeta: []MetaEntry{
|
||||
{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "good metadata with multiple values",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// MetaEntry 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"x" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"XX" + // string value
|
||||
|
||||
// MetaEntry 2
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"y" + // key
|
||||
"\x04\x00\x00\x00" + // type (uint32)
|
||||
"\x99\x88\x77\x66" + // uint32 value
|
||||
"",
|
||||
wantMeta: []MetaEntry{
|
||||
{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
|
||||
Type: ValueTypeString,
|
||||
Value: []byte("XX"),
|
||||
}}},
|
||||
{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
|
||||
Type: ValueTypeUint32,
|
||||
Value: []byte{0x99, 0x88, 0x77, 0x66},
|
||||
}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "negative string length in meta key",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"",
|
||||
wantMetaErr: ErrMangled,
|
||||
},
|
||||
|
||||
// Tensor tests
|
||||
{
|
||||
name: "good tensor",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
|
||||
// dimensions
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"",
|
||||
wantTensor: []TensorInfo{
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 256,
|
||||
Size: 256,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "too many dimensions",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
|
||||
"\x00\x00\x00\x01" + // dimensions length
|
||||
"",
|
||||
wantTensorErr: ErrMangled,
|
||||
},
|
||||
{
|
||||
name: "size computed",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
|
||||
|
||||
// Tensor 2
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"",
|
||||
wantTensor: []TensorInfo{
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 256,
|
||||
Size: 256,
|
||||
},
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 768,
|
||||
Size: 512,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f, err := ReadFile(strings.NewReader(tt.data))
|
||||
if err != nil {
|
||||
if !errors.Is(err, tt.wantReadErr) {
|
||||
t.Fatalf("unexpected ReadFile error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var got []MetaEntry
|
||||
for meta, err := range f.Metadata {
|
||||
if !errors.Is(err, tt.wantMetaErr) {
|
||||
t.Fatalf("err = %v; want %v", err, ErrMangled)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
got = append(got, meta)
|
||||
}
|
||||
diff.Test(t, t.Errorf, got, tt.wantMeta)
|
||||
|
||||
var gotT []TensorInfo
|
||||
for tinfo, err := range f.Tensors {
|
||||
if !errors.Is(err, tt.wantTensorErr) {
|
||||
t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
gotT = append(gotT, tinfo)
|
||||
}
|
||||
diff.Test(t, t.Errorf, gotT, tt.wantTensor)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzReadInfo(f *testing.F) {
|
||||
f.Add(string(magicBytes))
|
||||
f.Add(string(magicBytes) +
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"")
|
||||
f.Add(string(magicBytes) +
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"")
|
||||
|
||||
f.Fuzz(func(t *testing.T, data string) {
|
||||
gf, err := ReadFile(strings.NewReader(data))
|
||||
if err != nil {
|
||||
t.Logf("ReadFile error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
for _, err := range gf.Metadata {
|
||||
if err != nil {
|
||||
t.Logf("metadata error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
}
|
||||
for tinfo, err := range gf.Tensors {
|
||||
if err != nil {
|
||||
t.Logf("tensor error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
if tinfo.Offset <= 0 {
|
||||
t.Logf("invalid tensor offset: %+v", t)
|
||||
t.Skip()
|
||||
}
|
||||
if tinfo.Size <= 0 {
|
||||
t.Logf("invalid tensor size: %+v", t)
|
||||
t.Skip()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
195
x/encoding/gguf/ggufio.go
Normal file
195
x/encoding/gguf/ggufio.go
Normal file
@ -0,0 +1,195 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
)
|
||||
|
||||
type ggufReader struct {
|
||||
r *reader
|
||||
n int
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaEntry() (MetaEntry, error) {
|
||||
key, err := r.readString()
|
||||
if err != nil {
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
typ, err := r.readValueType()
|
||||
if err != nil {
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
var values []MetaValue
|
||||
for v, err := range r.readMetaValues(typ) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("(key=%q type=%s): %w", key, typ, err)
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
return MetaEntry{
|
||||
Key: string(key),
|
||||
Type: typ,
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaValue(typ ValueType) (MetaValue, error) {
|
||||
var value []byte
|
||||
var err error
|
||||
switch typ {
|
||||
case ValueTypeUint8, ValueTypeInt8:
|
||||
value, err = r.next(1)
|
||||
case ValueTypeUint16, ValueTypeInt16:
|
||||
value, err = r.next(2)
|
||||
case ValueTypeUint32, ValueTypeInt32, ValueTypeFloat32:
|
||||
value, err = r.next(4)
|
||||
case ValueTypeUint64, ValueTypeInt64, ValueTypeFloat64:
|
||||
value, err = r.next(8)
|
||||
case ValueTypeBool:
|
||||
value, err = r.next(1)
|
||||
case ValueTypeString:
|
||||
value, err = r.readString()
|
||||
case ValueTypeArray:
|
||||
err = fmt.Errorf("nested arrays are not supported")
|
||||
default:
|
||||
err = fmt.Errorf("unsupported metadata type: %d", typ)
|
||||
}
|
||||
if err != nil {
|
||||
return MetaValue{}, err
|
||||
}
|
||||
return MetaValue{
|
||||
Type: typ,
|
||||
Value: bytes.Clone(value),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] {
|
||||
return func(yield func(MetaValue, error) bool) {
|
||||
if typ == ValueTypeArray {
|
||||
atyp, err := r.readValueType()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid type: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
n, err := r.readUint64()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid length: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
for i := range n {
|
||||
v, err := r.readMetaValue(atyp)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
if !yield(v, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
v, err := r.readMetaValue(typ)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading metadata value: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
yield(v, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ggufReader) readValueType() (ValueType, error) {
|
||||
typ, err := r.readUint32()
|
||||
return ValueType(typ), err
|
||||
}
|
||||
|
||||
func (r *ggufReader) readTensorInfo() (TensorInfo, error) {
|
||||
name, err := r.readString()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
numDimensions, err := r.readUint32()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
if numDimensions > MaxDimensions {
|
||||
return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions)
|
||||
}
|
||||
|
||||
dims := make([]uint64, numDimensions)
|
||||
for i := range dims {
|
||||
d, err := r.readUint64()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
dims[i] = d
|
||||
}
|
||||
typ, err := r.readUint32()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
offset, err := r.readUint64()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
// TODO(bmizerany): check offset is multiple of ALIGNMENT
|
||||
return TensorInfo{
|
||||
Name: string(name),
|
||||
Dimensions: dims,
|
||||
Type: Type(typ),
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) next(n int) ([]byte, error) {
|
||||
if n < 0 {
|
||||
return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled)
|
||||
}
|
||||
w := r.r.window()
|
||||
for len(w) < n {
|
||||
if r.r.extend() == 0 {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
w = r.r.window()
|
||||
}
|
||||
r.r.release(n)
|
||||
r.n += n
|
||||
return w[:n], nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readString() ([]byte, error) {
|
||||
n, err := r.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bmizerany): limit max string length
|
||||
return r.next(int(n))
|
||||
}
|
||||
|
||||
func (r *ggufReader) readUint32() (uint32, error) {
|
||||
b, err := r.next(4)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint32(b)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readUint64() (uint64, error) {
|
||||
b, err := r.next(8)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint64(b)
|
||||
return n, nil
|
||||
}
|
70
x/encoding/gguf/reader.go
Normal file
70
x/encoding/gguf/reader.go
Normal file
@ -0,0 +1,70 @@
|
||||
package gguf
|
||||
|
||||
import "io"
|
||||
|
||||
// A reader implements a sliding window over an io.Reader.
//
// The window is data[offset:]. Once a read from r returns an error, that
// error is kept in err and extend stops reading.
type reader struct {
	data   []byte
	offset int
	r      io.Reader
	err    error
}

// release discards n bytes from the front of the window.
func (b *reader) release(n int) {
	b.offset = b.offset + n
}

// window returns the current window.
// The window is invalidated by calls to release or extend.
func (b *reader) window() []byte {
	unread := b.data[b.offset:]
	return unread
}

// tuning constants for reader.extend.
const (
	newBufferSize = 8 << 10
	minReadSize   = newBufferSize >> 2
)

// extend extends the window with data from the underlying reader and returns
// the number of bytes added. It returns 0 without reading once a previous
// read has returned an error.
func (b *reader) extend() int {
	if b.err != nil {
		return 0
	}

	live := len(b.data) - b.offset
	if live == 0 {
		// Nothing unread: restart at the front of the buffer.
		b.data, b.offset = b.data[:0], 0
	}

	switch {
	case cap(b.data)-len(b.data) >= minReadSize:
		// enough free space already exists between len and cap.
	case cap(b.data)-live >= minReadSize:
		// enough space appears once the unread data moves to the front.
		b.compact()
	default:
		// otherwise a larger buffer is needed.
		b.grow()
	}

	// The unread data ends at offset+live; new bytes are appended there.
	end := b.offset + live
	n, err := b.r.Read(b.data[end:cap(b.data)])
	// reduce length to the existing data plus what was just read.
	b.data = b.data[:end+n]
	b.err = err
	return n
}

// grow allocates a larger buffer and moves the active data to its front.
func (b *reader) grow() {
	bigger := make([]byte, max(cap(b.data)*2, newBufferSize))
	copy(bigger, b.data[b.offset:])
	b.data, b.offset = bigger, 0
}

// compact moves the active data to the front of the buffer.
func (b *reader) compact() {
	copy(b.data, b.data[b.offset:])
	b.offset = 0
}
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5")
|
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a")
|
134
x/model/digest.go
Normal file
134
x/model/digest.go
Normal file
@ -0,0 +1,134 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Digest is the digest of a model Manifest, held in its textual form
// "<digest-type>-<digest>". It is a comparable, immutable value type.
//
// The zero Digest is not a valid digest.
type Digest struct {
	s string
}

// Type returns the digest type, i.e. everything before the first "-".
// When the digest holds no "-" (such as the zero Digest) the whole
// underlying string is returned.
//
// Example:
//
//	ParseDigest("sha256-1234").Type() // returns "sha256"
func (d Digest) Type() string {
	sep := strings.Index(d.s, "-")
	if sep < 0 {
		return d.s
	}
	return d.s[:sep]
}

// String returns the digest as "<digest-type>-<digest>". The zero
// (invalid) Digest yields the empty string.
func (d Digest) String() string {
	return d.s
}

// IsValid reports whether the digest is non-zero.
//
// A valid digest may be created only by ParseDigest, or
// ParseName(name).Digest().
func (d Digest) IsValid() bool {
	return len(d.s) > 0
}
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (d Digest) MarshalText() ([]byte, error) {
|
||||
return []byte(d.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||
func (d *Digest) UnmarshalText(text []byte) error {
|
||||
if d.IsValid() {
|
||||
return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
|
||||
}
|
||||
*d = ParseDigest(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogValue implements slog.Value.
|
||||
func (d Digest) LogValue() slog.Value {
|
||||
return slog.StringValue(d.String())
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = Digest{}
|
||||
_ sql.Scanner = (*Digest)(nil)
|
||||
_ slog.LogValuer = Digest{}
|
||||
)
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (d *Digest) Scan(src any) error {
|
||||
if d.IsValid() {
|
||||
return errors.New("model.Digest: illegal Scan on valid Digest")
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
*d = ParseDigest(v)
|
||||
return nil
|
||||
case []byte:
|
||||
*d = ParseDigest(string(v))
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("model.Digest: invalid Scan source %T", src)
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (d Digest) Value() (driver.Value, error) {
|
||||
return d.String(), nil
|
||||
}
|
||||
|
||||
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
|
||||
// Digest.
|
||||
func ParseDigest(s string) Digest {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
if ok && isValidDigestType(typ) && isValidHex(digest) {
|
||||
return Digest{s: s}
|
||||
}
|
||||
return Digest{}
|
||||
}
|
||||
|
||||
// isValidDigest returns true if the given string in the form of
|
||||
// "<digest-type>-<digest>", and <digest-type> is in the form of [a-z0-9]+
|
||||
// and <digest> is a valid hex string.
|
||||
//
|
||||
// It does not check if the digest is a valid hash for the given digest
|
||||
// type, or restrict the digest type to a known set of types. This is left
|
||||
// up to ueers of this package.
|
||||
func isValidDigest(s string) bool {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
res := ok && isValidDigestType(typ) && isValidHex(digest)
|
||||
fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
|
||||
return res
|
||||
}
|
||||
|
||||
// isValidDigestType reports whether s is a non-empty string made only of
// ASCII lowercase letters and ASCII digits, i.e. it matches [a-z0-9]+.
func isValidDigestType(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, r := range s {
		// unicode.IsLower and unicode.IsDigit also accept non-ASCII
		// runes (e.g. "ß" or Arabic-Indic digits), so bound the rune to
		// ASCII first. Invalid UTF-8 decodes to U+FFFD and is rejected
		// by the same bound.
		if r > unicode.MaxASCII || (!unicode.IsLower(r) && !unicode.IsDigit(r)) {
			return false
		}
	}
	return true
}
|
||||
|
||||
// isValidHex reports whether s is a non-empty string consisting solely of
// the bytes '0'-'9' and 'a'-'f'. Uppercase hex digits are rejected.
func isValidHex(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case '0' <= c && c <= '9':
			// decimal digit
		case 'a' <= c && c <= 'f':
			// lowercase hex letter
		default:
			return false
		}
	}
	return true
}
|
83
x/model/digest_test.go
Normal file
83
x/model/digest_test.go
Normal file
@ -0,0 +1,83 @@
|
||||
package model
|
||||
|
||||
import "testing"
|
||||
|
||||
// - test scan
|
||||
// - test marshal text
|
||||
// - test unmarshal text
|
||||
// - test log value
|
||||
// - test string
|
||||
// - test type
|
||||
// - test digest
|
||||
// - test valid
|
||||
// - test driver valuer
|
||||
// - test sql scanner
|
||||
// - test parse digest
|
||||
|
||||
// testDigests maps an input string to the Digest that ParseDigest is
// expected to return for it. Inputs that must be rejected map to the zero
// Digest.
var testDigests = map[string]Digest{
	"":                 {},
	"sha256-1234":      {s: "sha256-1234"},
	"sha256-5678":      {s: "sha256-5678"},
	"blake2-9abc":      {s: "blake2-9abc"},
	"-1234":            {}, // empty digest type
	"sha256-":          {}, // empty digest
	"sha256-1234-5678": {}, // "-" is not a hex character
	"sha256-P":         {}, // invalid hex
	"sha256-1234P":     {}, // invalid hex suffix
	"---":              {},
}
|
||||
|
||||
func TestDigestParse(t *testing.T) {
|
||||
// Test cases.
|
||||
for s, want := range testDigests {
|
||||
got := ParseDigest(s)
|
||||
t.Logf("ParseDigest(%q) = %#v", s, got)
|
||||
if got != want {
|
||||
t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDigestString(t *testing.T) {
|
||||
// Test cases.
|
||||
for s, d := range testDigests {
|
||||
want := s
|
||||
if !d.IsValid() {
|
||||
want = ""
|
||||
}
|
||||
got := d.String()
|
||||
if got != want {
|
||||
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
got = ParseDigest(s).String()
|
||||
if got != want {
|
||||
t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDigestUnmarshalText(t *testing.T) {
|
||||
const testDigest = "sha256-1234"
|
||||
t.Run("UnmarshalText (into Valid)", func(t *testing.T) {
|
||||
d := ParseDigest(testDigest)
|
||||
if !d.IsValid() {
|
||||
panic("invalid test")
|
||||
}
|
||||
if err := d.UnmarshalText(nil); err == nil {
|
||||
t.Errorf("UnmarshalText on valid Digest did not return error")
|
||||
}
|
||||
if d.String() != testDigest {
|
||||
t.Errorf("UnmarshalText on valid Digest changed Digest: %q", d.String())
|
||||
}
|
||||
})
|
||||
t.Run("UnmarshalText make safe copy", func(t *testing.T) {
|
||||
data := []byte(testDigest)
|
||||
var d Digest
|
||||
d.UnmarshalText(data)
|
||||
data[0] = 'x'
|
||||
if d.String() != testDigest {
|
||||
t.Errorf("UnmarshalText did not make a safe copy")
|
||||
}
|
||||
})
|
||||
}
|
132
x/model/file.go
Normal file
132
x/model/file.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Package model implements the File and Name types for working with and
|
||||
// representing Modelfiles and model Names.
|
||||
//
|
||||
// The Name type should be used when working with model names, and the File
|
||||
// type should be used when working with Modelfiles.
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"iter"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParamPragma is one key/value pair, as collected by ParseFile from a
// PARAMETER pragma.
type ParamPragma struct {
	Key, Value string
}
|
||||
|
||||
// MessagePragma is one role/content pair, as collected by ParseFile from a
// MESSAGE pragma.
type MessagePragma struct {
	Role, Content string
}
|
||||
|
||||
type File struct {
|
||||
// From is a required pragma that specifies the source of the model,
|
||||
// either on disk, or by reference (see model.ParseName).
|
||||
From string
|
||||
|
||||
// Optional
|
||||
Params []ParamPragma
|
||||
Template string
|
||||
System string
|
||||
Adapter string
|
||||
Messages []MessagePragma
|
||||
|
||||
License string
|
||||
}
|
||||
|
||||
// FileError describes a problem with a single pragma of a Modelfile.
type FileError struct {
	Pragma, Message string
}

// Error implements the error interface, formatting the error as
// "<pragma>: <message>".
func (e *FileError) Error() string {
	text := e.Pragma
	text += ": "
	text += e.Message
	return text
}
|
||||
|
||||
// Pragma represents a single pragma in a Modelfile.
type Pragma struct {
	// The pragma name
	Name string

	// Args contains the user-defined arguments for the pragma. If no
	// arguments were provided, it is nil.
	Args []string
}

// Arg returns the i'th argument of the pragma, or the empty string when i
// is out of range (negative, or not less than len(p.Args)).
func (p Pragma) Arg(i int) string {
	// A negative index would otherwise panic on the slice access below.
	if i < 0 || i >= len(p.Args) {
		return ""
	}
	return p.Args[i]
}
|
||||
|
||||
func FilePragmas(r io.Reader) iter.Seq2[Pragma, error] {
|
||||
return func(yield func(Pragma, error) bool) {
|
||||
sc := bufio.NewScanner(r)
|
||||
for sc.Scan() {
|
||||
line := sc.Text()
|
||||
|
||||
// TODO(bmizerany): set a max num fields/args to
|
||||
// prevent mem bloat
|
||||
args := strings.Fields(line)
|
||||
if len(args) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
p := Pragma{
|
||||
Name: strings.ToUpper(args[0]),
|
||||
}
|
||||
if p.Name == "MESSAGE" {
|
||||
// handle special case where message content
|
||||
// is space separated on the _rest_ of the
|
||||
// line like: `MESSAGE user Is Ontario in
|
||||
// Canada?`
|
||||
panic("TODO")
|
||||
}
|
||||
if len(args) > 1 {
|
||||
p.Args = args[1:]
|
||||
}
|
||||
if !yield(p, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if sc.Err() != nil {
|
||||
yield(Pragma{}, sc.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ParseFile(r io.Reader) (File, error) {
|
||||
var f File
|
||||
for p, err := range FilePragmas(r) {
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
switch p.Name {
|
||||
case "FROM":
|
||||
f.From = p.Arg(0)
|
||||
case "PARAMETER":
|
||||
f.Params = append(f.Params, ParamPragma{
|
||||
Key: strings.ToLower(p.Arg(0)),
|
||||
Value: p.Arg(1),
|
||||
})
|
||||
case "TEMPLATE":
|
||||
f.Template = p.Arg(0)
|
||||
case "SYSTEM":
|
||||
f.System = p.Arg(0)
|
||||
case "ADAPTER":
|
||||
f.Adapter = p.Arg(0)
|
||||
case "MESSAGE":
|
||||
f.Messages = append(f.Messages, MessagePragma{
|
||||
Role: p.Arg(0),
|
||||
Content: p.Arg(1),
|
||||
})
|
||||
case "LICENSE":
|
||||
f.License = p.Arg(0)
|
||||
}
|
||||
}
|
||||
return f, nil
|
||||
}
|
593
x/model/name.go
Normal file
593
x/model/name.go
Normal file
@ -0,0 +1,593 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
// Errors
var (
	// ErrIncompleteName is not used by this package, but is exported so that
	// other packages do not need to invent their own error type when they
	// need to return an error for an incomplete name.
	ErrIncompleteName = errors.New("incomplete model name")

	// ErrInvalidDigest is the error value for reporting an invalid digest.
	ErrInvalidDigest = errors.New("invalid digest")
)
|
||||
|
||||
const MaxNamePartLen = 128
|
||||
|
||||
type PartKind int
|
||||
|
||||
// Levels of concreteness
|
||||
const (
|
||||
// Each value aligns with its index in the Name.parts array.
|
||||
|
||||
PartHost PartKind = iota
|
||||
PartNamespace
|
||||
PartModel
|
||||
PartTag
|
||||
PartBuild
|
||||
PartDigest
|
||||
|
||||
// Invalid is a special part that is used to indicate that a part is
|
||||
// invalid. It is not a valid part of a Name.
|
||||
//
|
||||
// It should be kept as the last part in the list.
|
||||
PartInvalid
|
||||
)
|
||||
|
||||
var kindNames = map[PartKind]string{
|
||||
PartHost: "Host",
|
||||
PartNamespace: "Namespace",
|
||||
PartModel: "Name",
|
||||
PartTag: "Tag",
|
||||
PartBuild: "Build",
|
||||
PartDigest: "Digest",
|
||||
PartInvalid: "Invalid",
|
||||
}
|
||||
|
||||
func (k PartKind) String() string {
|
||||
return cmp.Or(kindNames[k], "Unknown")
|
||||
}
|
||||
|
||||
// Name is an opaque reference to a model. It holds the parts of a model
// with the case preserved, but is not directly comparable with other Names
// since model names can be represented with different casing depending on
// the use case. For instance, "Mistral" and "mistral" are the same model
// but each version may have come from different sources (e.g. copied from a
// Web page, or from a file path).
//
// Valid Names can ONLY be constructed by calling [ParseName].
//
// A Name is valid if and only if it has a valid Model part. The other parts
// are optional.
//
// A Name is considered "complete" if it has all parts up to and including
// the build present. To check if a Name is complete, use [Name.IsComplete].
//
// To compare two names in a case-insensitive manner, use [Name.EqualFold].
//
// The parts of a Name are:
//
//   - Host: the domain of the model (optional)
//   - Namespace: the namespace of the model (optional)
//   - Model: the name of the model (required)
//   - Tag: the tag of the model (optional)
//   - Build: the build of the model; usually the quantization or "file type" (optional)
//   - Digest: the digest of the model, as "<digest-type>-<digest>" (optional)
//
// The parts can be obtained in their original form by calling [Name.Parts].
//
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
//
// To make a Name by filling in missing parts from another Name, use [Fill].
type Name struct {
	_     structs.Incomparable
	parts [6]string // host, namespace, model, tag, build, digest

	// TODO(bmizerany): track offsets and hold s (raw string) here? We
	// could pack the offsets all into a single uint64 since the first
	// parts take less bits since their max offset is less than the max
	// offset of the next part. This would save a ton of bytes per Name
	// and mean zero allocations for String.
}
|
||||
|
||||
// ParseName parses s into a Name. The input string must be a valid string
|
||||
// representation of a model name in the form:
|
||||
//
|
||||
// [host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
|
||||
//
|
||||
// The name part is required, all others are optional. If a part is missing,
|
||||
// it is left empty in the returned Name. If a part is invalid, the zero Ref
|
||||
// value is returned.
|
||||
//
|
||||
// The build part is normalized to uppercase.
|
||||
//
|
||||
// Examples of valid paths:
|
||||
//
|
||||
// "example.com/library/mistral:7b+x"
|
||||
// "example.com/eva/mistral:7b+Q4_0"
|
||||
// "mistral:7b+x"
|
||||
// "example.com/mike/mistral:latest+Q4_0"
|
||||
// "example.com/bruce/mistral:latest"
|
||||
// "example.com/mistral:7b+Q4_0@sha256-1234567890abcdef"
|
||||
//
|
||||
// Examples of invalid paths:
|
||||
//
|
||||
// "example.com/mistral:7b+"
|
||||
// "example.com/mistral:7b+Q4_0+"
|
||||
// "x/y/z/z:8n+I"
|
||||
// ""
|
||||
//
|
||||
// It returns the zero value if any part is invalid.
|
||||
//
|
||||
// As a rule of thumb, an valid name is one that can be round-tripped with
|
||||
// the [Name.String] method. That means ("x+") is invalid because
|
||||
// [Name.String] will not print a "+" if the build is empty.
|
||||
func ParseName(s string) Name {
|
||||
var r Name
|
||||
for kind, part := range Parts(s) {
|
||||
if kind == PartInvalid {
|
||||
return Name{}
|
||||
}
|
||||
if kind == PartDigest && !ParseDigest(part).IsValid() {
|
||||
return Name{}
|
||||
}
|
||||
r.parts[kind] = part
|
||||
}
|
||||
if r.IsValid() || r.IsResolved() {
|
||||
return r
|
||||
}
|
||||
return Name{}
|
||||
}
|
||||
|
||||
func MustParseName(s string) Name {
|
||||
r := ParseName(s)
|
||||
if !r.IsValid() {
|
||||
panic("model.MustParseName: invalid name: " + s)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Fill fills in the missing parts of dst with the parts of src.
|
||||
//
|
||||
// The returned Name will only be valid if dst is valid.
|
||||
func Fill(dst, src Name) Name {
|
||||
var r Name
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithBuild returns a copy of r with the build set to the given string.
|
||||
func (r Name) WithBuild(build string) Name {
|
||||
r.parts[PartBuild] = build
|
||||
return r
|
||||
}
|
||||
|
||||
func (r Name) WithDigest(digest Digest) Name {
|
||||
r.parts[PartDigest] = digest.String()
|
||||
return r
|
||||
}
|
||||
|
||||
var mapHashSeed = maphash.MakeSeed()
|
||||
|
||||
// MapHash returns a case insensitive hash for use in maps and equality
|
||||
// checks. For a convienent way to compare names, use [Name.EqualFold].
|
||||
func (r Name) MapHash() uint64 {
|
||||
// correctly hash the parts with case insensitive comparison
|
||||
var h maphash.Hash
|
||||
h.SetSeed(mapHashSeed)
|
||||
for _, part := range r.Parts() {
|
||||
// downcase the part for hashing
|
||||
for i := range part {
|
||||
c := part[i]
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
c = c - 'A' + 'a'
|
||||
}
|
||||
h.WriteByte(c)
|
||||
}
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
func (r Name) slice(from, to PartKind) Name {
|
||||
var v Name
|
||||
copy(v.parts[from:to+1], r.parts[from:to+1])
|
||||
return v
|
||||
}
|
||||
|
||||
// DisplayModel returns the a display string composed of the model only.
|
||||
func (r Name) DisplayModel() string {
|
||||
return r.parts[PartModel]
|
||||
}
|
||||
|
||||
// DisplayFullest returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
//
|
||||
// It does not include the build part. For the fullest possible display
|
||||
// string with the build, use [Name.String].
|
||||
func (r Name) DisplayFullest() string {
|
||||
return r.slice(PartHost, PartTag).String()
|
||||
}
|
||||
|
||||
// DisplayShort returns the fullest possible display string in form:
|
||||
//
|
||||
// <model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayShort() string {
|
||||
return r.slice(PartModel, PartTag).String()
|
||||
}
|
||||
|
||||
// DisplayLong returns the fullest possible display string in form:
|
||||
//
|
||||
// <namespace>/<model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayLong() string {
|
||||
return r.slice(PartNamespace, PartTag).String()
|
||||
}
|
||||
|
||||
var seps = [...]string{
|
||||
PartHost: "/",
|
||||
PartNamespace: "/",
|
||||
PartModel: ":",
|
||||
PartTag: "+",
|
||||
PartBuild: "@",
|
||||
PartDigest: "",
|
||||
}
|
||||
|
||||
// WriteTo implements io.WriterTo. It writes the fullest possible display
|
||||
// string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
|
||||
//
|
||||
// Missing parts and their seperators are not written.
|
||||
//
|
||||
// The full digest is always prefixed with "@". That is if [Name.IsValid]
|
||||
// reports false and [Name.IsResolved] reports true, then the string is
|
||||
// returned as "@<digest-type>-<digest>".
|
||||
func (r Name) writeTo(w io.StringWriter) {
|
||||
var partsWritten int
|
||||
for i := range r.parts {
|
||||
if r.parts[i] == "" {
|
||||
continue
|
||||
}
|
||||
if partsWritten > 0 || i == int(PartDigest) {
|
||||
w.WriteString(seps[i-1])
|
||||
}
|
||||
w.WriteString(r.parts[i])
|
||||
partsWritten++
|
||||
}
|
||||
}
|
||||
|
||||
var builderPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &strings.Builder{}
|
||||
},
|
||||
}
|
||||
|
||||
// String returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
//
|
||||
// For the fullest possible display string without the build, use
|
||||
// [Name.DisplayFullest].
|
||||
func (r Name) String() string {
|
||||
b := builderPool.Get().(*strings.Builder)
|
||||
defer builderPool.Put(b)
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
r.writeTo(b)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// GoString implements fmt.GoStringer. It returns a string suitable for
|
||||
// debugging and logging. It is similar to [Name.String] but it always
|
||||
// returns a string that includes all parts of the Name, with missing parts
|
||||
// replaced with a ("?").
|
||||
func (r Name) GoString() string {
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(r.parts[i], "?")
|
||||
}
|
||||
return r.String()
|
||||
}
|
||||
|
||||
// LogValue implements slog.Valuer.
|
||||
func (r Name) LogValue() slog.Value {
|
||||
return slog.StringValue(r.GoString())
|
||||
}
|
||||
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (r Name) MarshalText() ([]byte, error) {
|
||||
b := bufPool.Get().(*bytes.Buffer)
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
defer bufPool.Put(b)
|
||||
r.writeTo(b)
|
||||
// TODO: We can remove this alloc if/when
|
||||
// https://github.com/golang/go/issues/62384 lands.
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
//
|
||||
// It is an error to call UnmarshalText on a valid Name.
|
||||
func (r *Name) UnmarshalText(text []byte) error {
|
||||
if r.IsValid() {
|
||||
// The invariant of UnmarshalText is that it should only be
|
||||
// called on an invalid/zero Name. If we allow UnmarshalText
|
||||
// on a valid Name, then the Name will be mutated, breaking
|
||||
// the immutability of the Name.
|
||||
return errors.New("model.Name: illegal UnmarshalText on valid Name")
|
||||
}
|
||||
|
||||
// The contract of UnmarshalText is that we copy to keep the text.
|
||||
*r = ParseName(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = Name{}
|
||||
_ sql.Scanner = (*Name)(nil)
|
||||
)
|
||||
|
||||
// Scan implements [database/sql.Scanner].
|
||||
func (r *Name) Scan(src any) error {
|
||||
if r.IsValid() {
|
||||
// The invariant of Scan is that it should only be called on an
|
||||
// invalid/zero Name. If we allow Scan on a valid Name, then the
|
||||
// Name will be mutated, breaking the immutability of the Name.
|
||||
return errors.New("model.Name: illegal Scan on valid Name")
|
||||
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
*r = ParseName(v)
|
||||
return nil
|
||||
case []byte:
|
||||
*r = ParseName(string(v))
|
||||
return nil
|
||||
}
|
||||
return errors.New("model.Name: invalid Scan source")
|
||||
}
|
||||
|
||||
// Value implements [database/sql/driver.Valuer].
|
||||
func (r Name) Value() (driver.Value, error) {
|
||||
return r.String(), nil
|
||||
}
|
||||
|
||||
// IsComplete reports whether the Name is fully qualified. That is it has a
|
||||
// domain, namespace, name, tag, and build.
|
||||
func (r Name) IsComplete() bool {
|
||||
return !slices.Contains(r.parts[:PartDigest], "")
|
||||
}
|
||||
|
||||
// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the
|
||||
// build part to be present.
|
||||
func (r Name) IsCompleteNoBuild() bool {
|
||||
return !slices.Contains(r.parts[:PartBuild], "")
|
||||
}
|
||||
|
||||
// IsResolved reports true if the Name has a valid digest.
|
||||
//
|
||||
// It is possible to have a valid Name, or a complete Name that is not
|
||||
// resolved.
|
||||
func (r Name) IsResolved() bool {
|
||||
return r.Digest().IsValid()
|
||||
}
|
||||
|
||||
// Digest returns the digest part of the Name, if any.
|
||||
//
|
||||
// If Digest returns a non-empty string, then [Name.IsResolved] will return
|
||||
// true, and digest is considered valid.
|
||||
func (r Name) Digest() Digest {
|
||||
// This was already validated by ParseName, so we can just return it.
|
||||
return Digest{r.parts[PartDigest]}
|
||||
}
|
||||
|
||||
// EqualFold reports whether r and o are equivalent model names, ignoring
|
||||
// case.
|
||||
func (r Name) EqualFold(o Name) bool {
|
||||
return r.CompareFold(o) == 0
|
||||
}
|
||||
|
||||
// CompareFold performs a case-insensitive cmp.Compare on r and o.
|
||||
//
|
||||
// This can be used with [slices.SortFunc].
|
||||
//
|
||||
// For simple equality checks, use [Name.EqualFold].
|
||||
func (r Name) CompareFold(o Name) int {
|
||||
return slices.CompareFunc(r.parts[:], o.parts[:], compareFold)
|
||||
}
|
||||
|
||||
// compareFold compares a and b rune by rune after folding ASCII upper-case
// letters to lower case. The result follows the cmp.Compare convention.
func compareFold(a, b string) int {
	left, right := []rune(a), []rune(b)
	byFoldedRune := func(x, y rune) int {
		return cmp.Compare(downcase(x), downcase(y))
	}
	return slices.CompareFunc(left, right, byFoldedRune)
}

// downcase maps the ASCII letters 'A'-'Z' to 'a'-'z' and returns every
// other rune unchanged.
func downcase(r rune) rune {
	if 'A' <= r && r <= 'Z' {
		return r + ('a' - 'A')
	}
	return r
}
|
||||
|
||||
// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
|
||||
|
||||
// Parts returns the parts of the Name in order of concreteness.
|
||||
//
|
||||
// The length of the returned slice is always 5.
|
||||
func (r Name) Parts() []string {
|
||||
return slices.Clone(r.parts[:])
|
||||
}
|
||||
|
||||
// Parts returns a sequence of the parts of a Name string from most specific
|
||||
// to least specific.
|
||||
//
|
||||
// It normalizes the input string by removing "http://" and "https://" only.
|
||||
// No other normalization is done.
|
||||
func Parts(s string) iter.Seq2[PartKind, string] {
|
||||
return func(yield func(PartKind, string) bool) {
|
||||
if strings.HasPrefix(s, "http://") {
|
||||
s = s[len("http://"):]
|
||||
}
|
||||
if strings.HasPrefix(s, "https://") {
|
||||
s = s[len("https://"):]
|
||||
}
|
||||
|
||||
if len(s) > MaxNamePartLen || len(s) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
yieldValid := func(kind PartKind, part string) bool {
|
||||
if !isValidPart(kind, part) {
|
||||
yield(PartInvalid, "")
|
||||
return false
|
||||
}
|
||||
return yield(kind, part)
|
||||
}
|
||||
|
||||
partLen := 0
|
||||
state, j := PartDigest, len(s)
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
if partLen++; partLen > MaxNamePartLen {
|
||||
// catch a part that is too long early, so
|
||||
// we don't keep spinning on it, waiting for
|
||||
// an isInValidPart check which would scan
|
||||
// over it again.
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
|
||||
switch s[i] {
|
||||
case '@':
|
||||
switch state {
|
||||
case PartDigest:
|
||||
if !yieldValid(PartDigest, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
if i == 0 {
|
||||
// This is the form
|
||||
// "@<digest>" which is valid.
|
||||
//
|
||||
// We're done.
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartBuild, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case '+':
|
||||
switch state {
|
||||
case PartBuild, PartDigest:
|
||||
if !yieldValid(PartBuild, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartTag, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case ':':
|
||||
switch state {
|
||||
case PartTag, PartBuild, PartDigest:
|
||||
if !yieldValid(PartTag, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartModel, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case '/':
|
||||
switch state {
|
||||
case PartModel, PartTag, PartBuild, PartDigest:
|
||||
if !yieldValid(PartModel, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j = PartNamespace, i
|
||||
case PartNamespace:
|
||||
if !yieldValid(PartNamespace, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartHost, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !isValidByte(state, s[i]) {
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state <= PartNamespace {
|
||||
yieldValid(state, s[:j])
|
||||
} else {
|
||||
yieldValid(PartModel, s[:j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsValid returns true if the Name hPartas a valid nick. To know if a Name is
|
||||
// "complete", use [Name.IsComplete].
|
||||
func (r Name) IsValid() bool {
|
||||
// Parts ensures we only have valid parts, so no need to validate
|
||||
// them here, only check if we have a name or not.
|
||||
return r.parts[PartModel] != ""
|
||||
}
|
||||
|
||||
// isValidPart returns Parttrue if given part is valid ascii [a-zA-Z0-9_\.-]
|
||||
func isValidPart(kind PartKind, s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, c := range []byte(s) {
|
||||
if !isValidByte(kind, c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidByte(kind PartKind, c byte) bool {
|
||||
if kind == PartNamespace && c == '.' {
|
||||
return false
|
||||
}
|
||||
if c == '.' || c == '-' {
|
||||
return true
|
||||
}
|
||||
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
572
x/model/name_test.go
Normal file
572
x/model/name_test.go
Normal file
@ -0,0 +1,572 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fields is the comparable, flattened form of a Name used by the tests.
type fields struct {
	host, namespace, model, tag, build, digest string
}
|
||||
|
||||
func fieldsFromName(p Name) fields {
|
||||
return fields{
|
||||
host: p.parts[PartHost],
|
||||
namespace: p.parts[PartNamespace],
|
||||
model: p.parts[PartModel],
|
||||
tag: p.parts[PartTag],
|
||||
build: p.parts[PartBuild],
|
||||
digest: p.parts[PartDigest],
|
||||
}
|
||||
}
|
||||
|
||||
var testNames = map[string]fields{
|
||||
"mistral:latest": {model: "mistral", tag: "latest"},
|
||||
"mistral": {model: "mistral"},
|
||||
"mistral:30B": {model: "mistral", tag: "30B"},
|
||||
"mistral:7b": {model: "mistral", tag: "7b"},
|
||||
"mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"mistral+KQED": {model: "mistral", build: "KQED"},
|
||||
"mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
|
||||
"mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
|
||||
"llama2": {model: "llama2"},
|
||||
"user/model": {namespace: "user", model: "model"},
|
||||
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
|
||||
|
||||
// invalid digest
|
||||
"mistral:latest@invalid256-": {},
|
||||
"mistral:latest@-123": {},
|
||||
"mistral:latest@!-123": {},
|
||||
"mistral:latest@1-!": {},
|
||||
"mistral:latest@": {},
|
||||
|
||||
// resolved
|
||||
"x@sha123-1": {model: "x", digest: "sha123-1"},
|
||||
"@sha456-2": {digest: "sha456-2"},
|
||||
|
||||
"@@sha123-1": {},
|
||||
|
||||
// preserves case for build
|
||||
"x+b": {model: "x", build: "b"},
|
||||
|
||||
// invalid (includes fuzzing trophies)
|
||||
" / / : + ": {},
|
||||
" / : + ": {},
|
||||
" : + ": {},
|
||||
" + ": {},
|
||||
" : ": {},
|
||||
" / ": {},
|
||||
" /": {},
|
||||
"/ ": {},
|
||||
"/": {},
|
||||
":": {},
|
||||
"+": {},
|
||||
|
||||
// (".") in namespace is not allowed
|
||||
"invalid.com/7b+x": {},
|
||||
|
||||
"invalid:7b+Q4_0:latest": {},
|
||||
"in valid": {},
|
||||
"invalid/y/z/foo": {},
|
||||
"/0": {},
|
||||
"0 /0": {},
|
||||
"0 /": {},
|
||||
"0/": {},
|
||||
":/0": {},
|
||||
"+0/00000": {},
|
||||
"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
|
||||
"0//0": {},
|
||||
"m+^^^": {},
|
||||
"file:///etc/passwd": {},
|
||||
"file:///etc/passwd:latest": {},
|
||||
"file:///etc/passwd:latest+u": {},
|
||||
|
||||
strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
|
||||
strings.Repeat("a", MaxNamePartLen+1): {},
|
||||
}
|
||||
|
||||
func TestNameParts(t *testing.T) {
|
||||
var p Name
|
||||
if w, g := int(PartDigest+1), len(p.Parts()); w != g {
|
||||
t.Errorf("Parts() = %d; want %d", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamePartString(t *testing.T) {
|
||||
if g := PartKind(-2).String(); g != "Unknown" {
|
||||
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
|
||||
}
|
||||
for kind, name := range kindNames {
|
||||
if g := kind.String(); g != name {
|
||||
t.Errorf("%s = %q; want %q", kind, g, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
for baseName, want := range testNames {
|
||||
for _, prefix := range []string{"", "https://", "http://"} {
|
||||
// We should get the same results with or without the
|
||||
// http(s) prefixes
|
||||
s := prefix + baseName
|
||||
|
||||
t.Run(s, func(t *testing.T) {
|
||||
for kind, part := range Parts(s) {
|
||||
t.Logf("Part: %s: %q", kind, part)
|
||||
}
|
||||
|
||||
name := ParseName(s)
|
||||
got := fieldsFromName(name)
|
||||
if got != want {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
// test round-trip
|
||||
if !ParseName(name.String()).EqualFold(name) {
|
||||
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
|
||||
}
|
||||
|
||||
if name.IsValid() && name.DisplayModel() == "" {
|
||||
t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
|
||||
} else if !name.IsValid() && name.DisplayModel() != "" {
|
||||
t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
|
||||
}
|
||||
|
||||
if name.IsResolved() && !name.Digest().IsValid() {
|
||||
t.Errorf("Resolved() = true; Digest() = %q; want non-empty digest", got.digest)
|
||||
} else if !name.IsResolved() && name.Digest().IsValid() {
|
||||
t.Errorf("Resolved() = false; Digest() = %q; want empty digest", got.digest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteWithAndWithoutBuild(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
complete bool
|
||||
completeNoBuild bool
|
||||
}{
|
||||
{"", false, false},
|
||||
{"incomplete/mistral:7b+x", false, false},
|
||||
{"incomplete/mistral:7b+Q4_0", false, false},
|
||||
{"incomplete:7b+x", false, false},
|
||||
{"complete.com/x/mistral:latest+Q4_0", true, true},
|
||||
{"complete.com/x/mistral:latest", false, true},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseName(tt.in)
|
||||
t.Logf("ParseName(%q) = %#v", tt.in, p)
|
||||
if g := p.IsComplete(); g != tt.complete {
|
||||
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
|
||||
}
|
||||
if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
|
||||
t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Complete uses Parts which returns a slice, but it should be
|
||||
// inlined when used in Complete, preventing any allocations or
|
||||
// escaping to the heap.
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(ParseName("complete.com/x/mistral:latest+Q4_0").IsComplete())
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("Complete allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameLogValue(t *testing.T) {
|
||||
cases := []string{
|
||||
"example.com/library/mistral:latest+Q4_0",
|
||||
"mistral:latest",
|
||||
"mistral:7b+Q4_0",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
log := slog.New(slog.NewTextHandler(&b, nil))
|
||||
name := ParseName(s)
|
||||
log.Info("", "name", name)
|
||||
want := fmt.Sprintf("name=%s", name.GoString())
|
||||
got := b.String()
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("expected log output to contain %q; got %q", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameDisplay(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantShort string
|
||||
wantLong string
|
||||
wantComplete string
|
||||
wantString string
|
||||
wantModel string
|
||||
wantGoString string // default is tt.in
|
||||
}{
|
||||
{
|
||||
name: "Complete Name",
|
||||
in: "example.com/library/mistral:latest+Q4_0",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "library/mistral:latest",
|
||||
wantComplete: "example.com/library/mistral:latest",
|
||||
wantModel: "mistral",
|
||||
wantGoString: "example.com/library/mistral:latest+Q4_0@?",
|
||||
},
|
||||
{
|
||||
name: "Short Name",
|
||||
in: "mistral:latest",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "mistral:latest",
|
||||
wantComplete: "mistral:latest",
|
||||
wantModel: "mistral",
|
||||
wantGoString: "?/?/mistral:latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "Long Name",
|
||||
in: "library/mistral:latest",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "library/mistral:latest",
|
||||
wantComplete: "library/mistral:latest",
|
||||
wantModel: "mistral",
|
||||
wantGoString: "?/library/mistral:latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "Case Preserved",
|
||||
in: "Library/Mistral:Latest",
|
||||
wantShort: "Mistral:Latest",
|
||||
wantLong: "Library/Mistral:Latest",
|
||||
wantComplete: "Library/Mistral:Latest",
|
||||
wantModel: "Mistral",
|
||||
wantGoString: "?/Library/Mistral:Latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "With digest",
|
||||
in: "Library/Mistral:Latest@sha256-123456",
|
||||
wantShort: "Mistral:Latest",
|
||||
wantLong: "Library/Mistral:Latest",
|
||||
wantComplete: "Library/Mistral:Latest",
|
||||
wantModel: "Mistral",
|
||||
wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := ParseName(tt.in)
|
||||
if g := p.DisplayShort(); g != tt.wantShort {
|
||||
t.Errorf("DisplayShort = %q; want %q", g, tt.wantShort)
|
||||
}
|
||||
if g := p.DisplayLong(); g != tt.wantLong {
|
||||
t.Errorf("DisplayLong = %q; want %q", g, tt.wantLong)
|
||||
}
|
||||
if g := p.DisplayFullest(); g != tt.wantComplete {
|
||||
t.Errorf("DisplayFullest = %q; want %q", g, tt.wantComplete)
|
||||
}
|
||||
if g := p.String(); g != tt.in {
|
||||
t.Errorf("String(%q) = %q; want %q", tt.in, g, tt.in)
|
||||
}
|
||||
if g := p.DisplayModel(); g != tt.wantModel {
|
||||
t.Errorf("Model = %q; want %q", g, tt.wantModel)
|
||||
}
|
||||
|
||||
tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
|
||||
if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
|
||||
t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameAllocs(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(ParseName("example.com/mistral:7b+Q4_0"))
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("ParseName allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseName(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for range b.N {
|
||||
keep(ParseName("example.com/mistral:7b+Q4_0"))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNameDisplay(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
r := ParseName("example.com/mistral:7b+Q4_0")
|
||||
b.Run("Short", func(b *testing.B) {
|
||||
for range b.N {
|
||||
keep(r.DisplayShort())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzParseName(f *testing.F) {
|
||||
f.Add("example.com/mistral:7b+Q4_0")
|
||||
f.Add("example.com/mistral:7b+q4_0")
|
||||
f.Add("example.com/mistral:7b+x")
|
||||
f.Add("x/y/z:8n+I")
|
||||
f.Fuzz(func(t *testing.T, s string) {
|
||||
r0 := ParseName(s)
|
||||
if !r0.IsValid() {
|
||||
if !r0.EqualFold(Name{}) {
|
||||
t.Errorf("expected invalid path to be zero value; got %#v", r0)
|
||||
}
|
||||
t.Skipf("invalid path: %q", s)
|
||||
}
|
||||
|
||||
for _, p := range r0.Parts() {
|
||||
if len(p) > MaxNamePartLen {
|
||||
t.Errorf("part too long: %q", p)
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.EqualFold(r0.String(), s) {
|
||||
t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
|
||||
}
|
||||
|
||||
r1 := ParseName(r0.String())
|
||||
if !r0.EqualFold(r1) {
|
||||
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFill(t *testing.T) {
|
||||
cases := []struct {
|
||||
dst string
|
||||
src string
|
||||
want string
|
||||
}{
|
||||
{"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
{"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
{"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.dst, func(t *testing.T) {
|
||||
r := Fill(ParseName(tt.dst), ParseName(tt.src))
|
||||
if r.String() != tt.want {
|
||||
t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameTextMarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
wantErr error
|
||||
}{
|
||||
{"example.com/mistral:latest+Q4_0", "", nil},
|
||||
{"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
|
||||
{"mistral:latest", "mistral:latest", nil},
|
||||
{"mistral", "mistral", nil},
|
||||
{"mistral:7b", "mistral:7b", nil},
|
||||
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseName(tt.in)
|
||||
got, err := p.MarshalText()
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
|
||||
}
|
||||
if string(got) != tt.want {
|
||||
t.Errorf("MarshalText() = %q; want %q", got, tt.want)
|
||||
}
|
||||
|
||||
var r Name
|
||||
if err := r.UnmarshalText(got); err != nil {
|
||||
t.Fatalf("UnmarshalText() error = %v; want nil", err)
|
||||
}
|
||||
if !r.EqualFold(p) {
|
||||
t.Errorf("UnmarshalText() = %q; want %q", r, p)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("UnmarshalText into valid Name", func(t *testing.T) {
|
||||
// UnmarshalText should not be called on a valid Name.
|
||||
p := MustParseName("x")
|
||||
if err := p.UnmarshalText([]byte("mistral:latest+Q4_0")); err == nil {
|
||||
t.Error("UnmarshalText() = nil; want error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TextMarshal allocs", func(t *testing.T) {
|
||||
var data []byte
|
||||
name := ParseName("example.com/ns/mistral:latest+Q4_0")
|
||||
if !name.IsComplete() {
|
||||
// sanity check
|
||||
panic("sanity check failed")
|
||||
}
|
||||
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
var err error
|
||||
data, err = name.MarshalText()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal("MarshalText() = 0; want non-zero")
|
||||
}
|
||||
})
|
||||
if allocs > 0 {
|
||||
// TODO: Update when/if this lands:
|
||||
// https://github.com/golang/go/issues/62384
|
||||
//
|
||||
// Currently, the best we can do is 1 alloc.
|
||||
t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UnmarshalTest makes safe copy", func(t *testing.T) {
|
||||
// UnmarshalText should make a copy of the data.
|
||||
data := []byte("mistral:latest+Q4_0")
|
||||
p := Name{}
|
||||
if err := p.UnmarshalText(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data[0] = 'x'
|
||||
if p.String() != "mistral:latest+Q4_0" {
|
||||
t.Errorf("UnmarshalText() did not make a copy")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSQL(t *testing.T) {
|
||||
t.Run("Scan for already valid Name", func(t *testing.T) {
|
||||
p := MustParseName("x")
|
||||
if err := p.Scan("mistral:latest+Q4_0"); err == nil {
|
||||
t.Error("Scan() = nil; want error")
|
||||
}
|
||||
})
|
||||
t.Run("Scan for invalid Name", func(t *testing.T) {
|
||||
p := Name{}
|
||||
if err := p.Scan("mistral:latest+Q4_0"); err != nil {
|
||||
t.Errorf("Scan() = %v; want nil", err)
|
||||
}
|
||||
if p.String() != "mistral:latest+Q4_0" {
|
||||
t.Errorf("String() = %q; want %q", p, "mistral:latest+Q4_0")
|
||||
}
|
||||
})
|
||||
t.Run("Value", func(t *testing.T) {
|
||||
p := MustParseName("x")
|
||||
if g, err := p.Value(); err != nil {
|
||||
t.Errorf("Value() error = %v; want nil", err)
|
||||
} else if g != "x" {
|
||||
t.Errorf("Value() = %q; want %q", g, "x")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNameStringAllocs(t *testing.T) {
|
||||
name := ParseName("example.com/ns/mistral:latest+Q4_0")
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(name.String())
|
||||
})
|
||||
if allocs > 1 {
|
||||
t.Errorf("String allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleFill() {
|
||||
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
|
||||
r := Fill(ParseName("mistral"), defaults)
|
||||
fmt.Println(r)
|
||||
|
||||
// Output:
|
||||
// registry.ollama.com/library/mistral:latest+Q4_0
|
||||
}
|
||||
|
||||
func ExampleName_MapHash() {
|
||||
m := map[uint64]bool{}
|
||||
|
||||
// key 1
|
||||
m[ParseName("mistral:latest+q4").MapHash()] = true
|
||||
m[ParseName("miSTRal:latest+Q4").MapHash()] = true
|
||||
m[ParseName("mistral:LATest+Q4").MapHash()] = true
|
||||
|
||||
// key 2
|
||||
m[ParseName("mistral:LATest").MapHash()] = true
|
||||
|
||||
fmt.Println(len(m))
|
||||
// Output:
|
||||
// 2
|
||||
}
|
||||
|
||||
func ExampleName_CompareFold_sort() {
|
||||
names := []Name{
|
||||
ParseName("mistral:latest"),
|
||||
ParseName("mistRal:7b+q4"),
|
||||
ParseName("MIstral:7b"),
|
||||
}
|
||||
|
||||
slices.SortFunc(names, Name.CompareFold)
|
||||
|
||||
for _, n := range names {
|
||||
fmt.Println(n)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// MIstral:7b
|
||||
// mistRal:7b+q4
|
||||
// mistral:latest
|
||||
}
|
||||
|
||||
func ExampleName_completeAndResolved() {
|
||||
for _, s := range []string{
|
||||
"x/y/z:latest+q4_0@sha123-1",
|
||||
"x/y/z:latest+q4_0",
|
||||
"@sha123-1",
|
||||
} {
|
||||
p := ParseName(s)
|
||||
fmt.Printf("complete:%v resolved:%v digest:%s\n", p.IsComplete(), p.IsResolved(), p.Digest())
|
||||
}
|
||||
|
||||
// Output:
|
||||
// complete:true resolved:true digest:sha123-1
|
||||
// complete:true resolved:false digest:
|
||||
// complete:false resolved:true digest:sha123-1
|
||||
}
|
||||
|
||||
func ExampleName_DisplayFullest() {
|
||||
for _, s := range []string{
|
||||
"example.com/jmorganca/mistral:latest+Q4_0",
|
||||
"mistral:latest+Q4_0",
|
||||
"mistral:latest",
|
||||
} {
|
||||
fmt.Println(ParseName(s).DisplayFullest())
|
||||
}
|
||||
|
||||
// Output:
|
||||
// example.com/jmorganca/mistral:latest
|
||||
// mistral:latest
|
||||
// mistral:latest
|
||||
}
|
||||
|
||||
// keep returns v unchanged. The tests and benchmarks pass results through it
// so that the call under measurement has a consumer.
func keep[T any](v T) T {
	return v
}
|
2
x/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa
vendored
Normal file
2
x/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("/0")
|
2
x/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6
vendored
Normal file
2
x/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("0//0")
|
2
x/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d
vendored
Normal file
2
x/model/testdata/fuzz/FuzzParseRef/3e3b70dba384074d
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("0 /0")
|
2
x/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab
vendored
Normal file
2
x/model/testdata/fuzz/FuzzParseRef/71f1fdff711b6dab
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("+0/00000")
|
2
x/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608
vendored
Normal file
2
x/model/testdata/fuzz/FuzzParseRef/82c2975c430ac608
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string(":")
|
2
x/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948
vendored
Normal file
2
x/model/testdata/fuzz/FuzzParseRef/b51b1c875e61a948
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91")
|
89
x/oweb/oweb.go
Normal file
89
x/oweb/oweb.go
Normal file
@ -0,0 +1,89 @@
|
||||
package oweb
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/ollama/ollama/x/client/ollama"
|
||||
)
|
||||
|
||||
func Missing(field string) error {
|
||||
return &ollama.Error{
|
||||
Status: 400,
|
||||
Code: "missing",
|
||||
Field: field,
|
||||
Message: fmt.Sprintf("%s is required", field),
|
||||
}
|
||||
}
|
||||
|
||||
func Invalid(field, value, format string, args ...any) error {
|
||||
return &ollama.Error{
|
||||
Status: 400,
|
||||
Code: "invalid",
|
||||
Field: field,
|
||||
Value: value,
|
||||
Message: fmt.Sprintf(format, args...),
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience errors
|
||||
var (
|
||||
ErrNotFound = &ollama.Error{Status: 404, Code: "not_found"}
|
||||
ErrInternal = &ollama.Error{Status: 500, Code: "internal_error"}
|
||||
ErrMethodNotAllowed = &ollama.Error{Status: 405, Code: "method_not_allowed"}
|
||||
)
|
||||
|
||||
// HandlerFunc is an HTTP handler that reports failure by returning an error
// instead of writing the error response itself.
type HandlerFunc func(w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) {
|
||||
if err := h(w, r); err != nil {
|
||||
// TODO: take a slog.Logger
|
||||
log.Printf("error: %v", err)
|
||||
var oe *ollama.Error
|
||||
if !errors.As(err, &oe) {
|
||||
oe = ErrInternal
|
||||
}
|
||||
oe.Status = cmp.Or(oe.Status, 400)
|
||||
w.WriteHeader(oe.Status)
|
||||
if err := EncodeJSON(w, oe); err != nil {
|
||||
log.Printf("error encoding error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) {
|
||||
v, err := DecodeJSON[T](r)
|
||||
|
||||
// Handle common JSON syntax errors
|
||||
var e *json.SyntaxError
|
||||
if errors.As(err, &e) {
|
||||
return nil, Invalid(field, "", e.Error())
|
||||
}
|
||||
|
||||
// Handle type errors
|
||||
var se *json.UnmarshalTypeError
|
||||
if errors.As(err, &se) {
|
||||
return nil, Invalid(field, se.Value, "expected %s", se.Type)
|
||||
}
|
||||
|
||||
// Return v and err as they were.
|
||||
return v, err
|
||||
}
|
||||
|
||||
// DecodeJSON decodes a single JSON value from r into a new T and returns a
// pointer to it.
//
// The returned pointer is never nil: a JSON null decodes to a pointer to the
// zero T, and on error a pointer to a fresh zero T is returned alongside the
// error.
func DecodeJSON[T any](r io.Reader) (*T, error) {
	// Decode into an allocated T rather than into a nil *T. Decoding "null"
	// into a nil pointer leaves it nil, which would hand callers (nil, nil).
	v := new(T)
	if err := json.NewDecoder(r).Decode(v); err != nil {
		// Do not expose a partially decoded value.
		return new(T), err
	}
	return v, nil
}
|
||||
|
||||
// EncodeJSON writes v to w as JSON using a json.Encoder, returning any
// encoding or write error.
func EncodeJSON(w io.Writer, v any) error {
	enc := json.NewEncoder(w)
	return enc.Encode(v)
}
|
46
x/registry/apitype/apitype.go
Normal file
46
x/registry/apitype/apitype.go
Normal file
@ -0,0 +1,46 @@
|
||||
package apitype
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Manifest struct {
|
||||
Layers []Layer `json:"layers"`
|
||||
}
|
||||
|
||||
// CompletePart identifies one finished upload by the URL it was sent to and
// the ETag that came back.
type CompletePart struct {
	// URL contains partNumber and uploadId from server.
	URL string `json:"url"`

	// ETag is the ETag reported for the upload.
	ETag string `json:"etag"`
}
|
||||
|
||||
// Layer describes one manifest layer by digest, media type and size.
type Layer struct {
	// Digest identifies the layer's content.
	Digest string `json:"digest"`

	// MediaType is the layer's media type.
	MediaType string `json:"mediaType"`

	// Size is the layer size — presumably in bytes; confirm against callers.
	Size int64 `json:"size"`
}
|
||||
|
||||
type PushRequest struct {
|
||||
Name string `json:"ref"`
|
||||
Manifest json.RawMessage `json:"manifest"`
|
||||
|
||||
// Parts is a list of upload parts that the client upload in the previous
|
||||
// push.
|
||||
CompleteParts []CompletePart `json:"part_uploads"`
|
||||
}
|
||||
|
||||
// Requirement describes one upload the client must perform before the
// manifest can be committed.
type Requirement struct {
	// Digest identifies the layer the upload belongs to.
	Digest string `json:"digest"`

	// Offset is the starting offset of the range to upload.
	Offset int64 `json:"offset"`

	// Size is the length of the range to upload.
	//
	// The key is lowercase like every other key in this package.
	// encoding/json matches keys case-insensitively when decoding, so Go
	// peers that still send "Size" continue to decode.
	Size int64 `json:"size"`

	// URL is the url to PUT the layer to.
	//
	// Clients must include it as the URL, along with the ETag in the
	// response headers from the PUT request, in the CompleteParts
	// ("part_uploads") of the next push request.
	URL string `json:"url"`
}
|
||||
|
||||
type PushResponse struct {
|
||||
// Requirements is a list of digests that the client needs to push before
|
||||
// repushing the manifest.
|
||||
Requirements []Requirement `json:"requirements,omitempty"`
|
||||
}
|
102
x/registry/client.go
Normal file
102
x/registry/client.go
Normal file
@ -0,0 +1,102 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/client/ollama"
|
||||
"github.com/ollama/ollama/x/registry/apitype"
|
||||
)
|
||||
|
||||
// Client is a registry client. It is converted to an *ollama.Client to
// perform requests (see oclient).
type Client struct {
	// BaseURL is the base URL of the registry server.
	BaseURL string

	// HTTPClient is the HTTP client handed to the underlying ollama client.
	HTTPClient *http.Client
}
|
||||
|
||||
func (c *Client) oclient() *ollama.Client {
|
||||
return (*ollama.Client)(c)
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
CompleteParts []apitype.CompletePart
|
||||
}
|
||||
|
||||
// Push pushes a manifest to the server.
|
||||
func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) {
|
||||
p = cmp.Or(p, &PushParams{})
|
||||
// TODO(bmizerany): backoff
|
||||
v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
|
||||
Name: ref,
|
||||
Manifest: manifest,
|
||||
CompleteParts: p.CompleteParts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v.Requirements, nil
|
||||
}
|
||||
|
||||
func PushLayer(ctx context.Context, body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
|
||||
var zero apitype.CompletePart
|
||||
if off < 0 {
|
||||
return zero, errors.New("off must be >0")
|
||||
}
|
||||
|
||||
file := io.NewSectionReader(body, off, n)
|
||||
req, err := http.NewRequest("PUT", url, file)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
req.ContentLength = n
|
||||
|
||||
// TODO(bmizerany): take content type param
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
if n >= 0 {
|
||||
req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
|
||||
}
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
e := parseS3Error(res)
|
||||
return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
|
||||
}
|
||||
etag := strings.Trim(res.Header.Get("ETag"), `"`)
|
||||
cp := apitype.CompletePart{
|
||||
URL: url,
|
||||
ETag: etag,
|
||||
// TODO(bmizerany): checksum
|
||||
}
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
// s3Error is the XML <Error> document decoded by parseS3Error from a failed
// upload response.
type s3Error struct {
	XMLName xml.Name `xml:"Error"`

	Code      string `xml:"Code"`
	Message   string `xml:"Message"`
	Resource  string `xml:"Resource"`
	RequestId string `xml:"RequestId"`
}
|
||||
|
||||
func (e *s3Error) Error() string {
|
||||
return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
|
||||
}
|
||||
|
||||
// parseS3Error parses an XML error response from S3.
|
||||
func parseS3Error(res *http.Response) error {
|
||||
var se *s3Error
|
||||
if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
|
||||
return err
|
||||
}
|
||||
return se
|
||||
}
|
256
x/registry/server.go
Normal file
256
x/registry/server.go
Normal file
@ -0,0 +1,256 @@
|
||||
// Package registry implements an Ollama registry client and server.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"github.com/ollama/ollama/x/client/ollama"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"github.com/ollama/ollama/x/oweb"
|
||||
"github.com/ollama/ollama/x/registry/apitype"
|
||||
"github.com/ollama/ollama/x/utils/upload"
|
||||
)
|
||||
|
||||
// Defaults
const (
	// DefaultUploadChunkSize is the upload chunk size used when
	// Server.UploadChunkSize is zero: 50 MiB.
	DefaultUploadChunkSize = 50 << 20
)
|
||||
|
||||
type Server struct {
|
||||
UploadChunkSize int64 // default is DefaultUploadChunkSize
|
||||
S3Client *minio.Client
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if err := s.serveHTTP(w, r); err != nil {
|
||||
log.Printf("error: %v", err) // TODO(bmizerany): take a slog.Logger
|
||||
var e *ollama.Error
|
||||
if !errors.As(err, &e) {
|
||||
e = oweb.ErrInternal
|
||||
}
|
||||
w.WriteHeader(cmp.Or(e.Status, 400))
|
||||
if err := oweb.EncodeJSON(w, e); err != nil {
|
||||
log.Printf("error encoding error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.URL.Path {
|
||||
case "/v1/push":
|
||||
return s.handlePush(w, r)
|
||||
case "/v1/pull":
|
||||
return s.handlePull(w, r)
|
||||
default:
|
||||
return oweb.ErrNotFound
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) uploadChunkSize() int64 {
|
||||
return cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)
|
||||
}
|
||||
|
||||
func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
|
||||
const bucketTODO = "test"
|
||||
const minimumMultipartSize = 5 * 1024 * 1024 // S3 spec
|
||||
|
||||
pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mp := model.ParseName(pr.Name)
|
||||
if !mp.IsComplete() {
|
||||
return oweb.Invalid("name", pr.Name, "must be complete")
|
||||
}
|
||||
|
||||
m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mcc := &minio.Core{Client: s.s3()}
|
||||
// TODO(bmizerany): complete uploads before stats for any with ETag
|
||||
|
||||
type completeParts struct {
|
||||
key string
|
||||
parts []minio.CompletePart
|
||||
}
|
||||
|
||||
completePartsByUploadID := make(map[string]completeParts)
|
||||
for _, mcp := range pr.CompleteParts {
|
||||
// parse the URL
|
||||
u, err := url.Parse(mcp.URL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
|
||||
// Check if this is a part upload, if not, skip
|
||||
uploadID := q.Get("uploadId")
|
||||
if uploadID == "" {
|
||||
// not a part upload
|
||||
continue
|
||||
}
|
||||
|
||||
// PartNumber is required
|
||||
queryPartNumber := q.Get("partNumber")
|
||||
partNumber, err := strconv.Atoi(queryPartNumber)
|
||||
if err != nil {
|
||||
return oweb.Invalid("partNumber", queryPartNumber, "")
|
||||
}
|
||||
if partNumber < 1 {
|
||||
return oweb.Invalid("partNumber", queryPartNumber, "must be >= 1")
|
||||
}
|
||||
|
||||
// ETag is required
|
||||
if mcp.ETag == "" {
|
||||
return oweb.Missing("etag")
|
||||
}
|
||||
|
||||
cp := completePartsByUploadID[uploadID]
|
||||
cp.key = u.Path
|
||||
cp.parts = append(cp.parts, minio.CompletePart{
|
||||
PartNumber: partNumber,
|
||||
ETag: mcp.ETag,
|
||||
})
|
||||
completePartsByUploadID[uploadID] = cp
|
||||
}
|
||||
|
||||
for uploadID, cp := range completePartsByUploadID {
|
||||
var zeroOpts minio.PutObjectOptions
|
||||
|
||||
// TODO: gross fix!!!!!!!!!!!!!!!
|
||||
key := strings.TrimPrefix(cp.key, "/"+bucketTODO+"/")
|
||||
|
||||
fmt.Printf("Completing multipart upload %s %s %v\n", bucketTODO, key, cp.parts)
|
||||
_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, key, uploadID, cp.parts, zeroOpts)
|
||||
if err != nil {
|
||||
var e minio.ErrorResponse
|
||||
if errors.As(err, &e) && e.Code == "NoSuchUpload" {
|
||||
return oweb.Invalid("uploadId", uploadID, "")
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var requirements []apitype.Requirement
|
||||
for _, l := range m.Layers {
|
||||
// TODO(bmizerany): do in parallel
|
||||
if l.Size == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO(bmizerany): "global" throttle of rate of transfer
|
||||
pushed, err := s.statObject(r.Context(), l.Digest)
|
||||
if err != nil {
|
||||
println("ERROR:", "statObject", err)
|
||||
return err
|
||||
}
|
||||
if !pushed {
|
||||
key := path.Join("blobs", l.Digest)
|
||||
if l.Size < minimumMultipartSize {
|
||||
// single part upload
|
||||
fmt.Printf("Presigning single %s %s\n", bucketTODO, key)
|
||||
signedURL, err := s.s3().PresignedPutObject(r.Context(), bucketTODO, key, 15*time.Minute)
|
||||
if err != nil {
|
||||
println("ERROR:", "presign single", err)
|
||||
return err
|
||||
}
|
||||
requirements = append(requirements, apitype.Requirement{
|
||||
Digest: l.Digest,
|
||||
Size: l.Size,
|
||||
URL: signedURL.String(),
|
||||
})
|
||||
} else {
|
||||
uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Presigning multi %s %s %s\n", bucketTODO, key, uploadID)
|
||||
for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
|
||||
const timeToStartUpload = 15 * time.Minute
|
||||
|
||||
signedURL, err := s.s3().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
|
||||
"partNumber": []string{strconv.Itoa(partNumber)},
|
||||
"uploadId": []string{uploadID},
|
||||
})
|
||||
if err != nil {
|
||||
println("ERROR:", "presign multi", err)
|
||||
return err
|
||||
}
|
||||
|
||||
requirements = append(requirements, apitype.Requirement{
|
||||
Digest: l.Digest,
|
||||
Offset: c.Offset,
|
||||
Size: c.N,
|
||||
URL: signedURL.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(requirements) == 0 {
|
||||
// Commit the manifest
|
||||
body := bytes.NewReader(pr.Manifest)
|
||||
path := path.Join("manifests", path.Join(mp.Parts()...))
|
||||
_, err := s.s3().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements})
|
||||
}
|
||||
|
||||
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
// lookup manifest
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) {
|
||||
// HEAD the object
|
||||
path := path.Join("blobs", digest)
|
||||
_, err = s.s3().StatObject(ctx, "test", path, minio.StatObjectOptions{})
|
||||
if err != nil {
|
||||
if isNoSuchKey(err) {
|
||||
err = nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isNoSuchKey(err error) bool {
|
||||
var e minio.ErrorResponse
|
||||
return errors.As(err, &e) && e.Code == "NoSuchKey"
|
||||
}
|
||||
|
||||
func (s *Server) s3() *minio.Client {
|
||||
if s.S3Client != nil {
|
||||
return s.S3Client
|
||||
}
|
||||
s3, err := minio.New("localhost:9000", &minio.Options{
|
||||
Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
|
||||
Secure: false,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s3
|
||||
}
|
473
x/registry/server_test.go
Normal file
473
x/registry/server_test.go
Normal file
@ -0,0 +1,473 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"github.com/ollama/ollama/x/registry/apitype"
|
||||
"github.com/ollama/ollama/x/utils/backoff"
|
||||
"github.com/ollama/ollama/x/utils/upload"
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
// const ref = "registry.ollama.ai/x/y:latest+Z"
|
||||
// const manifest = `{
|
||||
// "layers": [
|
||||
// {"digest": "sha256-1", "size": 1},
|
||||
// {"digest": "sha256-2", "size": 2},
|
||||
// {"digest": "sha256-3", "size": 3}
|
||||
// ]
|
||||
// }`
|
||||
|
||||
// ts := newTestServer(t)
|
||||
// ts.pushNotOK(ref, `{}`, &ollama.Error{
|
||||
// Status: 400,
|
||||
// Code: "invalid",
|
||||
// Message: "name must be fully qualified",
|
||||
// })
|
||||
|
||||
// ts.push(ref, `{
|
||||
// "layers": [
|
||||
// {"digest": "sha256-1", "size": 1},
|
||||
// {"digest": "sha256-2", "size": 2},
|
||||
// {"digest": "sha256-3", "size": 3}
|
||||
// ]
|
||||
// }`)
|
||||
|
||||
// tWriter is an io.Writer that forwards everything written to it to the log
// of the test it wraps.
type tWriter struct {
	t *testing.T
}

// Write logs p with t.Logf and reports all of p as written. It never returns
// an error.
func (w tWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	w.t.Logf("%s", p)
	return n, nil
}
|
||||
|
||||
func TestPushBasic(t *testing.T) {
|
||||
const MB = 1024 * 1024
|
||||
|
||||
mc := startMinio(t, true)
|
||||
|
||||
defer func() {
|
||||
mcc := &minio.Core{Client: mc}
|
||||
// fail if there are any incomplete uploads
|
||||
for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
|
||||
t.Errorf("incomplete: %v", x)
|
||||
}
|
||||
}()
|
||||
|
||||
const ref = "registry.ollama.ai/x/y:latest+Z"
|
||||
|
||||
// Upload two small layers and one large layer that will
|
||||
// trigger a multipart upload.
|
||||
manifest := []byte(`{
|
||||
"layers": [
|
||||
{"digest": "sha256-1", "size": 1},
|
||||
{"digest": "sha256-2", "size": 2},
|
||||
{"digest": "sha256-3", "size": 11000000}
|
||||
]
|
||||
}`)
|
||||
|
||||
hs := httptest.NewServer(&Server{
|
||||
S3Client: mc,
|
||||
UploadChunkSize: 5 * MB,
|
||||
})
|
||||
t.Cleanup(hs.Close)
|
||||
c := &Client{BaseURL: hs.URL}
|
||||
|
||||
requirements, err := c.Push(context.Background(), ref, manifest, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(requirements) < 3 {
|
||||
t.Errorf("expected at least 3 requirements; got %d", len(requirements))
|
||||
t.Logf("requirements: %v", requirements)
|
||||
}
|
||||
|
||||
var uploaded []apitype.CompletePart
|
||||
for i, r := range requirements {
|
||||
t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
|
||||
|
||||
cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
uploaded = append(uploaded, cp)
|
||||
}
|
||||
|
||||
requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
|
||||
CompleteParts: uploaded,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(requirements) != 0 {
|
||||
t.Errorf("unexpected requirements: %v", requirements)
|
||||
}
|
||||
|
||||
var paths []string
|
||||
keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
|
||||
Recursive: true,
|
||||
})
|
||||
for k := range keys {
|
||||
paths = append(paths, k.Key)
|
||||
}
|
||||
|
||||
t.Logf("paths: %v", paths)
|
||||
|
||||
diff.Test(t, t.Errorf, paths, []string{
|
||||
"blobs/sha256-1",
|
||||
"blobs/sha256-2",
|
||||
"blobs/sha256-3",
|
||||
"manifests/registry.ollama.ai/x/y/latest/Z",
|
||||
})
|
||||
|
||||
obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer obj.Close()
|
||||
|
||||
var gotM apitype.Manifest
|
||||
if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, gotM, apitype.Manifest{
|
||||
Layers: []apitype.Layer{
|
||||
{Digest: "sha256-1", Size: 1},
|
||||
{Digest: "sha256-2", Size: 2},
|
||||
{Digest: "sha256-3", Size: 11000000},
|
||||
},
|
||||
})
|
||||
|
||||
// checksum the blobs
|
||||
for i, l := range gotM.Layers {
|
||||
obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer obj.Close()
|
||||
|
||||
info, err := obj.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
|
||||
|
||||
if msg := checkABCs(obj, int(l.Size)); msg != "" {
|
||||
t.Errorf("[%d] %s", i, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
|
||||
// presigning a multipart upload, uploading the parts, and completing the
|
||||
// upload. It is for future reference and should not be deleted. This flow
|
||||
// is tricky and if we get it wrong in our server, we can refer back to this
|
||||
// as a "back to basics" test/reference.
|
||||
func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
|
||||
t.Skip("skipping reference test; unskip when needed")
|
||||
|
||||
mc := startMinio(t, true)
|
||||
mcc := &minio.Core{Client: mc}
|
||||
|
||||
uploadID, err := mcc.NewMultipartUpload(context.Background(), "test", "theKey", minio.PutObjectOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var completed []minio.CompletePart
|
||||
const size int64 = 10 * 1024 * 1024
|
||||
const chunkSize = 5 * 1024 * 1024
|
||||
|
||||
for partNumber, c := range upload.Chunks(size, chunkSize) {
|
||||
u, err := mcc.Presign(context.Background(), "PUT", "test", "theKey", 15*time.Minute, url.Values{
|
||||
"partNumber": {strconv.Itoa(partNumber)},
|
||||
"uploadId": {uploadID},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("[partNumber=%d]: %v", partNumber, err)
|
||||
}
|
||||
t.Logf("[partNumber=%d]: %v", partNumber, u)
|
||||
|
||||
var body abcReader
|
||||
cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N)
|
||||
if err != nil {
|
||||
t.Fatalf("[partNumber=%d]: %v", partNumber, err)
|
||||
}
|
||||
t.Logf("completed part: %v", cp)
|
||||
|
||||
// behave like server here (don't cheat and use partNumber)
|
||||
// instead get partNumber from the URL
|
||||
retPartNumber, err := strconv.Atoi(u.Query().Get("partNumber"))
|
||||
if err != nil {
|
||||
t.Fatalf("[partNumber=%d]: %v", partNumber, err)
|
||||
}
|
||||
|
||||
completed = append(completed, minio.CompletePart{
|
||||
PartNumber: retPartNumber,
|
||||
ETag: cp.ETag,
|
||||
})
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// fail if there are any incomplete uploads
|
||||
for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
|
||||
t.Errorf("incomplete: %v", x)
|
||||
}
|
||||
}()
|
||||
|
||||
info, err := mcc.CompleteMultipartUpload(context.Background(), "test", "theKey", uploadID, completed, minio.PutObjectOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("completed: %v", info)
|
||||
|
||||
// Check key in bucket
|
||||
obj, err := mc.GetObject(context.Background(), "test", "theKey", minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer obj.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, obj); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gotSum := h.Sum(nil)
|
||||
|
||||
h.Reset()
|
||||
var body abcReader
|
||||
if _, err := io.CopyN(h, &body, size); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wantSum := h.Sum(nil)
|
||||
|
||||
if !bytes.Equal(gotSum, wantSum) {
|
||||
t.Errorf("got sum = %x; want %x", gotSum, wantSum)
|
||||
}
|
||||
}
|
||||
|
||||
// availableAddr returns a localhost "host:port" address that was free when
// the call was made. It listens on port 0 so the OS picks the port, records
// the resulting address, and closes the listener before returning.
//
// It panics if the listener cannot be created.
func availableAddr() string {
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic(err)
	}
	addr := ln.Addr().String()
	ln.Close()
	return addr
}
|
||||
|
||||
// tracing is "experimental" and may be removed in the future, I can't get it to
|
||||
// work consistently, but I'm leaving it in for now.
|
||||
func startMinio(t *testing.T, trace bool) *minio.Client {
|
||||
t.Helper()
|
||||
|
||||
// Trace is enabled by setting the OLLAMA_MINIO_TRACE environment or
|
||||
// explicitly setting trace to true.
|
||||
trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")
|
||||
|
||||
dir := t.TempDir()
|
||||
|
||||
t.Cleanup(func() {
|
||||
// TODO(bmizerany): trim temp dir based on dates so that
|
||||
// future runs may be able to inspect results for some time.
|
||||
})
|
||||
|
||||
waitAndMaybeLogError := func(cmd *exec.Cmd) {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
var e *exec.ExitError
|
||||
if errors.As(err, &e) {
|
||||
if e.Exited() {
|
||||
return
|
||||
}
|
||||
t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
|
||||
t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
|
||||
t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited())
|
||||
t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
|
||||
} else {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
t.Logf("startMinio: %s exit error: %v", cmd.Path, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel must be called first so do wait to add to Cleanup
|
||||
// stack as last cleanup.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
deadline, ok := t.Deadline()
|
||||
if ok {
|
||||
ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
|
||||
}
|
||||
|
||||
t.Logf(">> minio: minio server %s", dir)
|
||||
|
||||
addr := availableAddr()
|
||||
cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.WaitDelay = 3 * time.Second
|
||||
cmd.Cancel = func() error {
|
||||
return cmd.Process.Signal(syscall.SIGQUIT)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("startMinio: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
waitAndMaybeLogError(cmd)
|
||||
})
|
||||
|
||||
mc, err := minio.New(addr, &minio.Options{
|
||||
Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
|
||||
Secure: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("startMinio: %v", err)
|
||||
}
|
||||
|
||||
// wait for server to start with exponential backoff
|
||||
for _, err := range backoff.Upto(ctx, 1*time.Second) {
|
||||
if err != nil {
|
||||
t.Fatalf("startMinio: %v", err)
|
||||
}
|
||||
// try list buckets to see if server is up
|
||||
if _, err := mc.ListBuckets(ctx); err == nil {
|
||||
break
|
||||
}
|
||||
t.Logf("startMinio: server is offline; retrying")
|
||||
}
|
||||
|
||||
if trace {
|
||||
cmd := exec.CommandContext(ctx, "mc", "admin", "trace", "--verbose", "test")
|
||||
cmd.Env = append(os.Environ(),
|
||||
"MC_HOST_test=http://minioadmin:minioadmin@"+addr,
|
||||
)
|
||||
cmd.WaitDelay = 3 * time.Second
|
||||
cmd.Cancel = func() error {
|
||||
return cmd.Process.Signal(syscall.SIGQUIT)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("startMinio: %v", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("startMinio: %v", err)
|
||||
}
|
||||
|
||||
doneLogging := make(chan struct{})
|
||||
sc := bufio.NewScanner(stdout)
|
||||
go func() {
|
||||
defer close(doneLogging)
|
||||
|
||||
// Scan lines until the process exits.
|
||||
for sc.Scan() {
|
||||
t.Logf("startMinio: mc trace: %s", sc.Text())
|
||||
}
|
||||
_ = sc.Err() // ignore (not important)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
waitAndMaybeLogError(cmd)
|
||||
|
||||
// Make sure we do not log after test exists to
|
||||
// avoid panic.
|
||||
<-doneLogging
|
||||
})
|
||||
}
|
||||
|
||||
if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
|
||||
t.Fatalf("startMinio: %v", err)
|
||||
}
|
||||
return mc
|
||||
}
|
||||
|
||||
// contextForTest returns a context tied to the test's deadline, together with
// a doneLogging function.
//
// If the test has no deadline, the result is context.Background() and a no-op
// doneLogging. Otherwise the context expires 100ms before the test deadline,
// and a cleanup is registered that cancels the context and then blocks until
// doneLogging has been called. The caller must therefore call doneLogging
// exactly once, after its last Log/Error/Fatalf call and before the test
// returns.
func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
	deadline, hasDeadline := t.Deadline()
	if !hasDeadline {
		return context.Background(), func() {}
	}

	logged := make(chan struct{})
	ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
	t.Cleanup(func() {
		cancel()
		<-logged
	})

	doneLogging = func() { close(logged) }
	return ctx, doneLogging
}
|
||||
|
||||
// abcReader is an endless source of bytes that cycles through theABCs. The
// zero value is ready to use and starts at 'a'.
type abcReader struct {
	pos int // index into theABCs of the next byte Read will produce
}

// theABCs is the sequence that abcReader repeats.
const theABCs = "abcdefghijklmnopqrstuvwxyz"

// Read fills all of p with the next bytes of the repeating alphabet,
// continuing where the previous Read stopped. It always returns len(p) and a
// nil error.
func (r *abcReader) Read(p []byte) (n int, err error) {
	for i := range p {
		p[i] = theABCs[r.pos]
		r.pos = (r.pos + 1) % len(theABCs)
	}
	return len(p), nil
}

// ReadAt fills all of p with the bytes the repeating alphabet holds starting
// at absolute offset off. It neither uses nor changes the position used by
// Read. A negative off yields (0, error); otherwise it returns len(p) and a
// nil error.
func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
	if off < 0 {
		return 0, errors.New("abcReader.ReadAt: negative offset")
	}
	// Reduce the offset first so the per-byte arithmetic cannot overflow.
	start := int(off % int64(len(theABCs)))
	for i := range p {
		p[i] = theABCs[(start+i%len(theABCs))%len(theABCs)]
	}
	return len(p), nil
}
|
||||
|
||||
func checkABCs(r io.Reader, size int) (reason string) {
|
||||
h := sha256.New()
|
||||
n, err := io.CopyN(h, &abcReader{}, int64(size))
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
if n != int64(size) {
|
||||
panic("short read; should not happen")
|
||||
}
|
||||
want := h.Sum(nil)
|
||||
h = sha256.New()
|
||||
n, err = io.Copy(h, r)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
if n != int64(size) {
|
||||
return fmt.Sprintf("got len(r) = %d; want %d", n, size)
|
||||
}
|
||||
got := h.Sum(nil)
|
||||
if !bytes.Equal(got, want) {
|
||||
return fmt.Sprintf("got sum = %x; want %x", got, want)
|
||||
}
|
||||
return ""
|
||||
}
|
4
x/types/empty/message.go
Normal file
4
x/types/empty/message.go
Normal file
@ -0,0 +1,4 @@
|
||||
package empty
|
||||
|
||||
// Message is an empty struct used as a placeholder when encoding JSON
// messages.
type Message struct{}
|
15
x/types/structs/structs.go
Normal file
15
x/types/structs/structs.go
Normal file
@ -0,0 +1,15 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package structs contains the Incomparable type.
|
||||
package structs
|
||||
|
||||
// Incomparable is a zero-width type that cannot be compared. Placing it as
// the first field of a struct makes that struct non-comparable (no ==, not
// usable as a map key), usually without adding any width to the struct
// (unless the struct has only small fields).
//
// Making a struct incomparable prevents misuse through ==, and it can also
// shrink generated binaries, because the compiler can leave the equality
// funcs out of the binary.
type Incomparable [0]func()
|
12
x/types/they/want.go
Normal file
12
x/types/they/want.go
Normal file
@ -0,0 +1,12 @@
|
||||
package they
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Want reports whether r was made with the given HTTP method and its URL
// path begins with pathPrefix.
func Want(r *http.Request, method string, pathPrefix string) bool {
	if r.Method != method {
		return false
	}
	return strings.HasPrefix(r.URL.Path, pathPrefix)
}
|
58
x/utils/backoff/backoff.go
Normal file
58
x/utils/backoff/backoff.go
Normal file
@ -0,0 +1,58 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"iter"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrMaxAttempts is never returned by this package. It exists so that callers
// that limit their own number of retries have a ready-made error to signal
// that the limit was exceeded, instead of each inventing their own.
var ErrMaxAttempts = errors.New("max retries exceeded")
|
||||
|
||||
// Upto implements a backoff strategy that yields nil errors until the
|
||||
// context is canceled, the maxRetries is exceeded, or yield returns false.
|
||||
//
|
||||
// The backoff strategy is a simple exponential backoff with a maximum
|
||||
// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times
|
||||
// the current backoff, in order to prevent accidental "thundering herd"
|
||||
// problems.
|
||||
func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
|
||||
var n int
|
||||
return func(yield func(int, error) bool) {
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
yield(n, ctx.Err())
|
||||
return
|
||||
}
|
||||
|
||||
n++
|
||||
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := time.Duration(n*n) * 10 * time.Millisecond
|
||||
if d > maxBackoff {
|
||||
d = maxBackoff
|
||||
}
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
t := time.NewTimer(d)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Stop()
|
||||
case <-t.C:
|
||||
if !yield(n, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
29
x/utils/upload/upload.go
Normal file
29
x/utils/upload/upload.go
Normal file
@ -0,0 +1,29 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"iter"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Chunk[I constraints.Integer] struct {
|
||||
Offset I
|
||||
N I
|
||||
}
|
||||
|
||||
// Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
|
||||
// and size of the chunk. The last chunk may be smaller than chunkSize if size is
|
||||
// not a multiple of chunkSize.
|
||||
//
|
||||
// The first part number is 1 and increases monotonically.
|
||||
func Chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, Chunk[I]] {
|
||||
return func(yield func(int, Chunk[I]) bool) {
|
||||
var n int
|
||||
for off := I(0); off < size; off += chunkSize {
|
||||
n++
|
||||
if !yield(n, Chunk[I]{off, min(chunkSize, size-off)}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
44
x/utils/upload/upload_test.go
Normal file
44
x/utils/upload/upload_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
func TestChunks(t *testing.T) {
|
||||
const size = 101
|
||||
const chunkSize = 10
|
||||
var got []Chunk[int]
|
||||
var lastN int
|
||||
for n, c := range Chunks(size, chunkSize) {
|
||||
if n != lastN+1 {
|
||||
t.Errorf("n = %d; want %d", n, lastN+1)
|
||||
}
|
||||
got = append(got, c)
|
||||
lastN = n
|
||||
}
|
||||
|
||||
want := []Chunk[int]{
|
||||
{0, 10},
|
||||
{10, 10},
|
||||
{20, 10},
|
||||
{30, 10},
|
||||
{40, 10},
|
||||
{50, 10},
|
||||
{60, 10},
|
||||
{70, 10},
|
||||
{80, 10},
|
||||
{90, 10},
|
||||
{100, 1},
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, got, want)
|
||||
}
|
||||
|
||||
func TestChunksBreak(t *testing.T) {
|
||||
for _, _ = range Chunks(1, 1) {
|
||||
return
|
||||
}
|
||||
t.Fatal("expected break")
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user