diff --git a/api/client.go b/api/client.go index 4e434fae..f153f32e 100644 --- a/api/client.go +++ b/api/client.go @@ -5,153 +5,108 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" "net/http" - "sync" + "net/url" ) type Client struct { - URL string - HTTP http.Client + base url.URL } -func checkError(resp *http.Response, body []byte) error { - if resp.StatusCode >= 200 && resp.StatusCode < 400 { - return nil +func NewClient(hosts ...string) *Client { + host := "127.0.0.1:11434" + if len(hosts) > 0 { + host = hosts[0] } - apiError := Error{Code: int32(resp.StatusCode)} - - if err := json.Unmarshal(body, &apiError); err != nil { - // Use the full body as the message if we fail to decode a response. - apiError.Message = string(body) + return &Client{ + base: url.URL{Scheme: "http", Host: host}, } - - return apiError } -func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error { - var reqBody io.Reader - var data []byte - var err error - if reqData != nil { - data, err = json.Marshal(reqData) - if err != nil { - return err - } - reqBody = bytes.NewReader(data) +type options struct { + requestBody io.Reader + responseFunc func(bts []byte) error +} + +func OptionRequestBody(data any) func(*options) { + bts, err := json.Marshal(data) + if err != nil { + panic(err) } - url := fmt.Sprintf("%s%s", c.URL, path) + return func(opts *options) { + opts.requestBody = bytes.NewReader(bts) + } +} - req, err := http.NewRequestWithContext(ctx, method, url, reqBody) +func OptionResponseFunc(fn func([]byte) error) func(*options) { + return func(opts *options) { + opts.responseFunc = fn + } +} + +func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*options)) error { + var opts options + for _, fn := range fns { + fn(&opts) + } + + request, err := http.NewRequestWithContext(ctx, method, c.base.JoinPath(path).String(), opts.requestBody) if err != nil { return err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") + request.Header.Set("Content-Type", "application/json") + request.Header.Set("Accept", "application/json") - res, err := c.HTTP.Do(req) + response, err := http.DefaultClient.Do(request) if err != nil { return err } - defer res.Body.Close() + defer response.Body.Close() - reader := bufio.NewReader(res.Body) - - for { - line, err := reader.ReadBytes('\n') - if err != nil { - if err == io.EOF { - break - } else { - return err // Handle other errors + if opts.responseFunc != nil { + scanner := bufio.NewScanner(response.Body) + for scanner.Scan() { + if err := opts.responseFunc(scanner.Bytes()); err != nil { + return err } } - if err := checkError(res, line); err != nil { - return err - } - callback(bytes.TrimSuffix(line, []byte("\n"))) } return nil } -func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error { - var reqBody io.Reader - var data []byte - var err error - if reqData != nil { - data, err = json.Marshal(reqData) - if err != nil { - return err - } - reqBody = bytes.NewReader(data) - } +type GenerateResponseFunc func(GenerateResponse) error - url := fmt.Sprintf("%s%s", c.URL, path) +func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error { + return c.stream(ctx, http.MethodPost, "/api/generate", + OptionRequestBody(req), + OptionResponseFunc(func(bts []byte) error { + var resp GenerateResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } - req, err := http.NewRequestWithContext(ctx, method, url, reqBody) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - - respObj, err := c.HTTP.Do(req) - if err != nil { - return err - } - defer respObj.Body.Close() - - respBody, err := io.ReadAll(respObj.Body) - if err != nil { - return err - } - - if err := checkError(respObj, respBody); err != nil { - return err - } - - if len(respBody) > 0 && respData != nil { - if err := json.Unmarshal(respBody, respData); err != nil { - return err - } - } - return nil + return fn(resp) + }), + ) } -func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(token string)) (*GenerateResponse, error) { - var res GenerateResponse - if err := c.stream(ctx, http.MethodPost, "/api/generate", req, func(token []byte) { - callback(string(token)) - }); err != nil { - return nil, err - } +type PullProgressFunc func(PullProgress) error - return &res, nil -} - -func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error { - var wg sync.WaitGroup - wg.Add(1) - if err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) { - var progress PullProgress - if err := json.Unmarshal(progressBytes, &progress); err != nil { - fmt.Println(err) - return - } - if progress.Completed >= progress.Total { - wg.Done() - } - callback(progress) - }); err != nil { - return err - } - - wg.Wait() - return nil +func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { + return c.stream(ctx, http.MethodPost, "/api/pull", + OptionRequestBody(req), + OptionResponseFunc(func(bts []byte) error { + var resp PullProgress + if err := json.Unmarshal(bts, &resp); err != nil { + return err + } + + return fn(resp) + }), + ) } diff --git a/api/types.go b/api/types.go index d4eb63ad..39167ffd 100644 --- a/api/types.go +++ b/api/types.go @@ -23,8 +23,8 @@ type PullRequest struct { } type PullProgress struct { - Total int `json:"total"` - Completed int `json:"completed"` + Total int64 `json:"total"` + Completed int64 `json:"completed"` Percent float64 `json:"percent"` } diff --git a/cmd/cmd.go b/cmd/cmd.go index 1619be09..e6cfa7ff 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,18 +1,22 @@ package cmd import ( + "bufio" "context" + "errors" "fmt" "log" "net" "os" "path" - "sync" + "time" + + "github.com/schollz/progressbar/v3" + "github.com/spf13/cobra" + "golang.org/x/term" - "github.com/gosuri/uiprogress" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/server" - "github.com/spf13/cobra" ) func cacheDir() string { @@ -24,46 +28,126 @@ func cacheDir() string { return path.Join(home, ".ollama") } -func bytesToGB(bytes int) float64 { - return float64(bytes) / float64(1<<30) +func RunRun(cmd *cobra.Command, args []string) error { + _, err := os.Stat(args[0]) + switch { + case errors.Is(err, os.ErrNotExist): + if err := pull(args[0]); err != nil { + return err + } + + fmt.Println("Up to date.") + case err != nil: + return err + } + + return RunGenerate(cmd, args) } -func run(model string) error { - client, err := NewAPIClient() - if err != nil { - return err - } - pr := api.PullRequest{ - Model: model, - } - var bar *uiprogress.Bar - mutex := &sync.Mutex{} - var progressData api.PullProgress +func pull(model string) error { + client := api.NewClient() - pullCallback := func(progress api.PullProgress) { - mutex.Lock() - progressData = progress - if bar == nil { - uiprogress.Start() - bar = uiprogress.AddBar(int(progress.Total)) - bar.PrependFunc(func(b *uiprogress.Bar) string { - return fmt.Sprintf("Downloading: %.2f GB / %.2f GB", bytesToGB(progressData.Completed), bytesToGB(progressData.Total)) - }) - bar.AppendFunc(func(b *uiprogress.Bar) string { - return fmt.Sprintf(" %d%%", int((float64(progressData.Completed)/float64(progressData.Total))*100)) - }) + var bar *progressbar.ProgressBar + return client.Pull( + context.Background(), + &api.PullRequest{Model: model}, + func(progress api.PullProgress) error { + if bar == nil { + bar = progressbar.DefaultBytes(progress.Total) + } + + return bar.Set64(progress.Completed) + }, + ) +} + +func RunGenerate(_ *cobra.Command, args []string) error { + if len(args) > 1 { + return generateOneshot(args[0], args[1:]...) + } + + if term.IsTerminal(int(os.Stdin.Fd())) { + return generateInteractive(args[0]) + } + + return generateBatch(args[0]) +} + +func generate(model, prompt string) error { + client := api.NewClient() + + spinner := progressbar.NewOptions(-1, + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionThrottle(60*time.Millisecond), + progressbar.OptionSpinnerType(14), + progressbar.OptionSetRenderBlankState(true), + progressbar.OptionSetElapsedTime(false), + progressbar.OptionClearOnFinish(), + ) + + go func() { + for range time.Tick(60 * time.Millisecond) { + if spinner.IsFinished() { + break + } + + spinner.Add(1) } - bar.Set(int(progress.Completed)) - mutex.Unlock() - } - if err := client.Pull(context.Background(), &pr, pullCallback); err != nil { - return err - } - fmt.Println("Up to date.") + }() + + client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error { + if !spinner.IsFinished() { + spinner.Finish() + } + + fmt.Print(resp.Response) + return nil + }) + + fmt.Println() + fmt.Println() return nil } -func serve() error { +func generateOneshot(model string, prompts ...string) error { + for _, prompt := range prompts { + fmt.Printf(">>> %s\n", prompt) + if err := generate(model, prompt); err != nil { + return err + } + } + + return nil +} + +func generateInteractive(model string) error { + fmt.Print(">>> ") + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + if err := generate(model, scanner.Text()); err != nil { + return err + } + + fmt.Print(">>> ") + } + + return nil +} + +func generateBatch(model string) error { + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + prompt := scanner.Text() + fmt.Printf(">>> %s\n", prompt) + if err := generate(model, prompt); err != nil { + return err + } + } + + return nil +} + +func RunServer(_ *cobra.Command, _ []string) error { ln, err := net.Listen("tcp", "127.0.0.1:11434") if err != nil { return err @@ -72,49 +156,36 @@ func serve() error { return server.Serve(ln) } -func NewAPIClient() (*api.Client, error) { - return &api.Client{ - URL: "http://localhost:11434", - }, nil -} - func NewCLI() *cobra.Command { log.SetFlags(log.LstdFlags | log.Lshortfile) rootCmd := &cobra.Command{ - Use: "ollama", - Short: "Large language model runner", + Use: "ollama", + Short: "Large language model runner", + SilenceUsage: true, CompletionOptions: cobra.CompletionOptions{ DisableDefaultCmd: true, }, - PersistentPreRun: func(cmd *cobra.Command, args []string) { - // Disable usage printing on errors - cmd.SilenceUsage = true + PersistentPreRunE: func(_ *cobra.Command, args []string) error { // create the models directory and it's parent - if err := os.MkdirAll(path.Join(cacheDir(), "models"), 0o700); err != nil { - panic(err) - } + return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700) }, } cobra.EnableCommandSorting = false runCmd := &cobra.Command{ - Use: "run MODEL", + Use: "run MODEL [PROMPT]", Short: "Run a model", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - return run(args[0]) - }, + Args: cobra.MinimumNArgs(1), + RunE: RunRun, } serveCmd := &cobra.Command{ Use: "serve", Aliases: []string{"start"}, Short: "Start ollama", - RunE: func(cmd *cobra.Command, args []string) error { - return serve() - }, + RunE: RunServer, } rootCmd.AddCommand( diff --git a/go.mod b/go.mod index 6ca336d1..c2e15346 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,15 @@ go 1.20 require ( github.com/gin-gonic/gin v1.9.1 - github.com/gosuri/uiprogress v0.0.1 github.com/spf13/cobra v1.7.0 ) +require ( + github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/rivo/uniseg v0.2.0 // indirect +) + require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect @@ -18,7 +23,6 @@ require ( github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.5.9 // indirect - github.com/gosuri/uilive v0.0.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect @@ -28,6 +32,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/schollz/progressbar/v3 v3.13.1 github.com/spf13/pflag v1.0.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect @@ -35,6 +40,7 @@ require ( golang.org/x/crypto v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.10.0 // indirect + golang.org/x/term v0.10.0 golang.org/x/text v0.10.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 065bb0db..2adee49d 100644 --- a/go.sum +++ b/go.sum @@ -28,14 +28,11 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/gosuri/uilive v0.0.4 h1:hUEBpQDj8D8jXgtCdBu7sWsy5sbW/5GhuO8KBwJ2jyY= -github.com/gosuri/uilive v0.0.4/go.mod h1:V/epo5LjjlDE5RJUcqx8dbw+zc93y5Ya3yg8tfZ74VI= -github.com/gosuri/uiprogress v0.0.1 h1:0kpv/XY/qTmFWl/SkaJykZXrBBzwwadmW8fRb7RJSxw= -github.com/gosuri/uiprogress v0.0.1/go.mod h1:C1RTYn4Sc7iEyf6j8ft5dyoZ4212h8G1ol9QQluh5+0= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= @@ -43,8 +40,13 @@ github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4= github.com/lithammer/fuzzysearch v1.1.8/go.mod h1:IdqeyBClc3FFqSzYq/MXESsS4S0FsZ5ajtkr5xPLts4= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= +github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -54,7 +56,11 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/schollz/progressbar/v3 v3.13.1 h1:o8rySDYiQ59Mwzy2FELeHY5ZARXZTVJC7iHD6PEFUiE= +github.com/schollz/progressbar/v3 v3.13.1/go.mod h1:xvrbki8kfT1fzWzBT/UZd9L6GA+jdL7HAgq2RFnO6fQ= github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -99,6 +105,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= @@ -106,6 +113,9 @@ golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= diff --git a/server/models.go b/server/models.go index b0cccf46..ac41d11d 100644 --- a/server/models.go +++ b/server/models.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "net/http" "os" "path" @@ -30,6 +29,15 @@ type Model struct { License string `json:"license"` } +func (m *Model) FullName() string { + home, err := os.UserHomeDir() + if err != nil { + panic(err) + } + + return path.Join(home, ".ollama", "models", m.Name+".bin") +} + func pull(model string, progressCh chan<- api.PullProgress) error { remote, err := getRemote(model) if err != nil { @@ -45,7 +53,7 @@ func getRemote(model string) (*Model, error) { return nil, fmt.Errorf("failed to get directory: %w", err) } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read directory: %w", err) } @@ -64,13 +72,6 @@ func getRemote(model string) (*Model, error) { func saveModel(model *Model, progressCh chan<- api.PullProgress) error { // this models cache directory is created by the server on startup - home, err := os.UserHomeDir() - if err != nil { - return fmt.Errorf("failed to get home directory: %w", err) - } - modelsCache := path.Join(home, ".ollama", "models") - - fileName := path.Join(modelsCache, model.Name+".bin") client := &http.Client{} req, err := http.NewRequest("GET", model.URL, nil) @@ -78,16 +79,16 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { return fmt.Errorf("failed to download model: %w", err) } // check for resume - alreadyDownloaded := 0 - fileInfo, err := os.Stat(fileName) + alreadyDownloaded := int64(0) + fileInfo, err := os.Stat(model.FullName()) if err != nil { if !os.IsNotExist(err) { return fmt.Errorf("failed to check resume model file: %w", err) } // file doesn't exist, create it now } else { - alreadyDownloaded = int(fileInfo.Size()) - req.Header.Add("Range", "bytes="+strconv.Itoa(alreadyDownloaded)+"-") + alreadyDownloaded = fileInfo.Size() + req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded)) } resp, err := client.Do(req) @@ -111,13 +112,13 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { return fmt.Errorf("failed to download model: %s", resp.Status) } - out, err := os.OpenFile(fileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) + out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { panic(err) } defer out.Close() - totalSize, _ := strconv.Atoi(resp.Header.Get("Content-Length")) + totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) buf := make([]byte, 1024) totalBytes := alreadyDownloaded @@ -134,7 +135,8 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error { if _, err := out.Write(buf[:n]); err != nil { return err } - totalBytes += n + + totalBytes += int64(n) // send progress updates progressCh <- api.PullProgress{ diff --git a/server/routes.go b/server/routes.go index 0ca3d10e..4831b7ad 100644 --- a/server/routes.go +++ b/server/routes.go @@ -37,6 +37,10 @@ func generate(c *gin.Context) { return } + if remoteModel, _ := getRemote(req.Model); remoteModel != nil { + req.Model = remoteModel.FullName() + } + model, err := llama.New(req.Model, llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(gpulayers)) if err != nil { fmt.Println("Loading the model failed:", err.Error())