x/model: limit part len, not entire len

Limiting the whole name length comes naturally with part name length
restrictions. This aligns with Docker's registry behavior.
This commit is contained in:
Blake Mizerany 2024-04-05 18:24:50 -07:00
parent bf8e0c09c9
commit 7c7f56a7fb
3 changed files with 123 additions and 71 deletions

View File

@ -1,14 +1,8 @@
// Package model implements the File and Name types for working with and
// representing Modelfiles and model Names.
//
// The Name type is designed for safety and correctness. It is an opaque
// reference to a model, and holds the parts of a model, casing preserved,
// but is not directly comparable with other Names since model names can be
// represented with different casing depending on the use case.
//
// Names should never be compared or parsed manually. Instead, use the
// [Name.EqualFold] method to compare two names in a case-insensitive
// manner, and [ParseName] to create a Name from a string, safely.
// The Name type should be used when working with model names, and the File
// type should be used when working with Modelfiles.
package model
import (

View File

@ -2,6 +2,7 @@ package model
import (
"cmp"
"errors"
"hash/maphash"
"iter"
"slices"
@ -10,21 +11,19 @@ import (
"github.com/ollama/ollama/x/types/structs"
)
const MaxNameLength = 255
type NamePart int
// Levels of concreteness
const (
Invalid NamePart = iota
Host
Namespace
Model
Tag
Build
// Errors
var (
// ErrIncompleteName is not used by this package, but is exported so that
// other packages do not need to invent their own error type when they
// need to return an error for an incomplete name.
ErrIncompleteName = errors.New("incomplete model name")
)
var kindNames = map[NamePart]string{
// MaxNamePartLen is the maximum length, in bytes, of a single part of a
// Name. NameParts yields Invalid when a part exceeds this length.
//
// NOTE(review): NameParts also rejects any whole input string longer than
// MaxNamePartLen before splitting it into parts; confirm that this is
// intended, since it effectively caps the entire name at this length too.
const MaxNamePartLen = 128
// NamePartKind identifies which part of a Name a value refers to. Its zero
// value is Invalid.
type NamePartKind int
var kindNames = map[NamePartKind]string{
Invalid: "Invalid",
Host: "Host",
Namespace: "Namespace",
@ -33,12 +32,36 @@ var kindNames = map[NamePart]string{
Build: "Build",
}
// Name is an opaque reference to a model. It holds the parts of a model,
// casing preserved, and provides methods for comparing and manipulating
// them in a case-insensitive manner.
// String returns the human-readable name registered for k in kindNames.
// Kinds that have no entry (or an empty entry) are rendered as a loud
// placeholder so they stand out in logs and test output.
func (k NamePartKind) String() string {
	if label := kindNames[k]; label != "" {
		return label
	}
	return "!(UNKNOWN PART KIND)"
}
// Levels of concreteness, listed from Host (least specific) to Build (most
// specific). Invalid is the zero value of NamePartKind.
const (
Invalid NamePartKind = iota // yielded by NameParts when a part fails validation
Host
Namespace
Model
Tag
Build
)
// Name is an opaque reference to a model. It holds the parts of a model
// with the case preserved, but is not directly comparable with other Names
// since model names can be represented with different casing depending on
// the use case. For instance, "Mistral" and "mistral" are the same model
// but each version may have come from different sources (e.g. copied from a
// Web page, or from a file path).
//
// To create a Name, use [ParseName]. To compare two names, use
// [Name.EqualFold]. To use a name as a key in a map, use [Name.MapHash].
// Valid Names can ONLY be constructed by calling [ParseName].
//
// A Name is valid if and only if it has a valid Model part. The other parts
// are optional.
//
// A Name is considered "complete" if it has all parts present. To check if a
// Name is complete, use [Name.Complete].
//
// To compare two names in a case-insensitive manner, use [Name.EqualFold].
//
// The parts of a Name are:
//
@ -124,7 +147,7 @@ func ParseName(s string) Name {
// Fill fills in the missing parts of dst with the parts of src.
//
// Use this for merging a fully qualified ref with a partial ref, such as
// Use this for merging a fully qualified Name with a partial Name, such as
// when filling in missing parts with defaults.
//
// The returned Name will only be valid if dst is valid.
@ -144,6 +167,23 @@ func (r Name) WithBuild(build string) Name {
return r
}
// Has reports whether r holds a non-empty part of the given kind.
//
// It returns false for Invalid and for any kind other than Host,
// Namespace, Model, Tag, or Build.
func (r Name) Has(kind NamePartKind) bool {
	// part stays empty for unrecognized kinds, which reports false below.
	var part string
	switch kind {
	case Host:
		part = r.host
	case Namespace:
		part = r.namespace
	case Model:
		part = r.model
	case Tag:
		part = r.tag
	case Build:
		part = r.build
	}
	return part != ""
}
var mapHashSeed = maphash.MakeSeed()
// MapHash returns a case insensitive hash for use in maps and equality
@ -165,9 +205,10 @@ func (r Name) MapHash() uint64 {
return h.Sum64()
}
// Format returns a string representation of the ref with the given
// concreteness. If a part is missing, it is replaced with a loud
// placeholder.
// DisplayModel returns only the model part of the Name, exactly as stored.
func (r Name) DisplayModel() string {
return r.model
}
func (r Name) DisplayFull() string {
return (Name{
host: cmp.Or(r.host, "!(MISSING DOMAIN)"),
@ -178,27 +219,7 @@ func (r Name) DisplayFull() string {
}).String()
}
// DisplayModel returns only the model part of the Name, exactly as stored.
func (r Name) DisplayModel() string {
return r.model
}
// Has reports whether r holds a non-empty part of the given kind. Kinds
// other than Host, Namespace, Model, Tag, and Build always report false.
func (r Name) Has(kind NamePart) bool {
	present := false
	switch kind {
	case Host:
		present = r.host != ""
	case Namespace:
		present = r.namespace != ""
	case Model:
		present = r.model != ""
	case Tag:
		present = r.tag != ""
	case Build:
		present = r.build != ""
	}
	return present
}
// DisplayCompact returns a compact display string of the ref with only the
// DisplayCompact returns a compact display string of the Name with only the
// model and tag parts.
func (r Name) DisplayCompact() string {
return (Name{
@ -207,7 +228,7 @@ func (r Name) DisplayCompact() string {
}).String()
}
// DisplayShort returns a short display string of the ref with only the
// DisplayShort returns a short display string of the Name with only the
// model, tag, and build parts.
func (r Name) DisplayShort() string {
return (Name{
@ -217,7 +238,7 @@ func (r Name) DisplayShort() string {
}).String()
}
// DisplayLong returns a long display string of the ref including namespace,
// DisplayLong returns a long display string of the Name including namespace,
// model, tag, and build parts.
func (r Name) DisplayLong() string {
return (Name{
@ -228,7 +249,7 @@ func (r Name) DisplayLong() string {
}).String()
}
// String returns the fully qualified ref string.
// String returns the fully qualified Name string.
func (r Name) String() string {
var b strings.Builder
if r.host != "" {
@ -251,7 +272,7 @@ func (r Name) String() string {
return b.String()
}
// Complete reports whether the ref is fully qualified. That is it has a
// Complete reports whether the Name is fully qualified. That is it has a
// domain, namespace, name, tag, and build.
func (r Name) Complete() bool {
return r.Valid() && !slices.Contains(r.Parts(), "")
@ -262,7 +283,7 @@ func (r Name) Complete() bool {
// TODO(bmizerany): LogValue
// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
// Parts returns the parts of the ref in order of concreteness.
// Parts returns the parts of the Name in order of concreteness.
//
// The length of the returned slice is always 5.
func (r Name) Parts() []string {
@ -287,7 +308,7 @@ func (r Name) EqualFold(o Name) bool {
return r.MapHash() == o.MapHash()
}
// Parts returns a sequence of the parts of a ref string from most specific
// NameParts returns a sequence of the parts of a Name string from most specific
// to least specific.
//
// It normalizes the input string by removing "http://" and "https://" only.
@ -295,8 +316,8 @@ func (r Name) EqualFold(o Name) bool {
//
// As a special case, question marks are ignored so they may be used as
// placeholders for missing parts in string literals.
func NameParts(s string) iter.Seq2[NamePart, string] {
return func(yield func(NamePart, string) bool) {
func NameParts(s string) iter.Seq2[NamePartKind, string] {
return func(yield func(NamePartKind, string) bool) {
if strings.HasPrefix(s, "http://") {
s = s[len("http://"):]
}
@ -304,11 +325,11 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
s = s[len("https://"):]
}
if len(s) > MaxNameLength || len(s) == 0 {
if len(s) > MaxNamePartLen || len(s) == 0 {
return
}
yieldValid := func(kind NamePart, part string) bool {
yieldValid := func(kind NamePartKind, part string) bool {
if !isValidPart(kind, part) {
yield(Invalid, "")
return false
@ -316,8 +337,13 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
return yield(kind, part)
}
partLen := 0
state, j := Build, len(s)
for i := len(s) - 1; i >= 0; i-- {
if partLen++; partLen > MaxNamePartLen {
yield(Invalid, "")
return
}
switch s[i] {
case '+':
switch state {
@ -325,7 +351,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
if !yieldValid(Build, s[i+1:j]) {
return
}
state, j = Tag, i
state, j, partLen = Tag, i, 0
default:
yield(Invalid, "")
return
@ -336,7 +362,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
if !yieldValid(Tag, s[i+1:j]) {
return
}
state, j = Model, i
state, j, partLen = Model, i, 0
default:
yield(Invalid, "")
return
@ -352,7 +378,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
if !yieldValid(Namespace, s[i+1:j]) {
return
}
state, j = Host, i
state, j, partLen = Host, i, 0
default:
yield(Invalid, "")
return
@ -373,7 +399,7 @@ func NameParts(s string) iter.Seq2[NamePart, string] {
}
}
// Valid returns true if the ref has a valid nick. To know if a ref is
// Valid returns true if the Name has a valid nick. To know if a Name is
// "complete", use Complete.
func (r Name) Valid() bool {
// Parts ensures we only have valid parts, so no need to validate
@ -382,7 +408,7 @@ func (r Name) Valid() bool {
}
// isValidPart returns true if given part is valid ascii [a-zA-Z0-9_\.-]
func isValidPart(kind NamePart, s string) bool {
func isValidPart(kind NamePartKind, s string) bool {
if s == "" {
return false
}
@ -394,7 +420,7 @@ func isValidPart(kind NamePart, s string) bool {
return true
}
func isValidByte(kind NamePart, c byte) bool {
func isValidByte(kind NamePartKind, c byte) bool {
if kind == Namespace && c == '.' {
return false
}

View File

@ -52,8 +52,8 @@ var testNames = map[string]Name{
"file:///etc/passwd:latest": {},
"file:///etc/passwd:latest+u": {},
strings.Repeat("a", MaxNameLength): {model: strings.Repeat("a", MaxNameLength)},
strings.Repeat("a", MaxNameLength+1): {},
strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
strings.Repeat("a", MaxNamePartLen+1): {},
}
func TestNameParts(t *testing.T) {
@ -64,6 +64,34 @@ func TestNameParts(t *testing.T) {
}
}
func TestPartTooLong(t *testing.T) {
for i := Host; i <= Build; i++ {
t.Run(i.String(), func(t *testing.T) {
var p Name
switch i {
case Host:
p.host = strings.Repeat("a", MaxNamePartLen+1)
case Namespace:
p.namespace = strings.Repeat("a", MaxNamePartLen+1)
case Model:
p.model = strings.Repeat("a", MaxNamePartLen+1)
case Tag:
p.tag = strings.Repeat("a", MaxNamePartLen+1)
case Build:
p.build = strings.Repeat("a", MaxNamePartLen+1)
}
s := strings.Trim(p.String(), "+:/")
if len(s) != MaxNamePartLen+1 {
t.Errorf("len(String()) = %d; want %d", len(s), MaxNamePartLen+1)
t.Logf("String() = %q", s)
}
if ParseName(p.String()).Valid() {
t.Errorf("Valid(%q) = true; want false", p)
}
})
}
}
func TestParseName(t *testing.T) {
for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} {
@ -210,7 +238,7 @@ func FuzzParseName(f *testing.F) {
}
for _, p := range r0.Parts() {
if len(p) > MaxNameLength {
if len(p) > MaxNamePartLen {
t.Errorf("part too long: %q", p)
}
}
@ -261,11 +289,15 @@ func ExampleFill() {
func ExampleName_MapHash() {
m := map[uint64]bool{}
// key 1
m[ParseName("mistral:latest+q4").MapHash()] = true
m[ParseName("miSTRal:latest+Q4").MapHash()] = true
m[ParseName("mistral:LATest+Q4").MapHash()] = true
// key 2
m[ParseName("mistral:LATest").MapHash()] = true
fmt.Println(len(m))
// Output:
// 1
// 2
}