This commit is contained in:
Josh Yan 2024-08-28 15:44:21 -07:00
parent c41bbb45bd
commit cf8af774ab
2 changed files with 6 additions and 6 deletions

View File

@ -190,7 +190,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "BertModel": case "BertModel":
conv = &bertModel{} conv = &bertModel{}
case "CohereForCausalLM": case "CohereForCausalLM":
conv = &commandr{} conv = &commandrModel{}
default: default:
return errors.New("unsupported architecture") return errors.New("unsupported architecture")
} }

View File

@ -6,7 +6,7 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
) )
type commandr struct { type commandrModel struct {
ModelParameters ModelParameters
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"` HiddenSize uint32 `json:"hidden_size"`
@ -22,9 +22,9 @@ type commandr struct {
NCtx uint32 `json:"n_ctx"` NCtx uint32 `json:"n_ctx"`
} }
var _ ModelConverter = (*commandr)(nil) var _ ModelConverter = (*commandrModel)(nil)
func (p *commandr) KV(t *Tokenizer) llm.KV { func (p *commandrModel) KV(t *Tokenizer) llm.KV {
kv := p.ModelParameters.KV(t) kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "command-r" kv["general.architecture"] = "command-r"
kv["general.name"] = "command-r" kv["general.name"] = "command-r"
@ -47,7 +47,7 @@ func (p *commandr) KV(t *Tokenizer) llm.KV {
return kv return kv
} }
func (p *commandr) Tensors(ts []Tensor) []llm.Tensor { func (p *commandrModel) Tensors(ts []Tensor) []llm.Tensor {
var out []llm.Tensor var out []llm.Tensor
for _, t := range ts { for _, t := range ts {
out = append(out, llm.Tensor{ out = append(out, llm.Tensor{
@ -61,7 +61,7 @@ func (p *commandr) Tensors(ts []Tensor) []llm.Tensor {
return out return out
} }
func (p *commandr) Replacements() []string { func (p *commandrModel) Replacements() []string {
return []string{ return []string{
"self_attn.q_norm", "attn_q_norm", "self_attn.q_norm", "attn_q_norm",
"self_attn.k_norm", "attn_k_norm", "self_attn.k_norm", "attn_k_norm",