diff --git a/.dockerignore b/.dockerignore index a1f8beae..43f2e07d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,8 +1,9 @@ .vscode ollama app +macapp dist llm/llama.cpp .env .cache -test_data \ No newline at end of file +test_data diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f5174c33..d2534302 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -2,17 +2,22 @@ name: test on: pull_request: + paths: + - '**/*' + - '!docs/**' + - '!examples/**' + - '!README.md' jobs: generate: strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest, windows-2019] arch: [amd64, arm64] exclude: - os: ubuntu-latest arch: arm64 - - os: windows-latest + - os: windows-2019 arch: arm64 runs-on: ${{ matrix.os }} env: @@ -21,10 +26,21 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.22' cache: true - run: go get ./... + - run: | + $gopath=(get-command go).source | split-path -parent + & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1" + cd $env:GITHUB_WORKSPACE + $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" + $env:PATH="$gopath;$env:PATH" + go generate -x ./... + if: ${{ startsWith(matrix.os, 'windows-') }} + name: "Windows Go Generate" - run: go generate -x ./... + if: ${{ ! startsWith(matrix.os, 'windows-') }} + name: "Unix Go Generate" - uses: actions/upload-artifact@v4 with: name: ${{ matrix.os }}-${{ matrix.arch }}-libraries @@ -34,7 +50,7 @@ jobs: matrix: cuda-version: - '11.8.0' - runs-on: ubuntu-latest + runs-on: linux container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04 steps: - run: | @@ -46,7 +62,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v4 with: - go-version: '1.21' + go-version: '1.22' cache: true - run: go get ./... - run: | @@ -62,9 +78,8 @@ jobs: strategy: matrix: rocm-version: - - '5.7.1' - '6.0' - runs-on: ubuntu-latest + runs-on: linux container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }} steps: - run: | @@ -76,7 +91,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v4 with: - go-version: '1.21' + go-version: '1.22' cache: true - run: go get ./... - run: | @@ -91,26 +106,26 @@ jobs: lint: strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest, windows-2019] arch: [amd64, arm64] exclude: - os: ubuntu-latest arch: arm64 - - os: windows-latest + - os: windows-2019 arch: arm64 - os: macos-latest arch: amd64 runs-on: ${{ matrix.os }} env: GOARCH: ${{ matrix.arch }} - CGO_ENABLED: "1" + CGO_ENABLED: '1' steps: - uses: actions/checkout@v4 with: submodules: recursive - uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.22' cache: false - run: | mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/ @@ -130,24 +145,24 @@ jobs: needs: generate strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest, windows-2019] arch: [amd64] exclude: - os: ubuntu-latest arch: arm64 - - os: windows-latest + - os: windows-2019 arch: arm64 runs-on: ${{ matrix.os }} env: GOARCH: ${{ matrix.arch }} - CGO_ENABLED: "1" + CGO_ENABLED: '1' steps: - uses: actions/checkout@v4 with: submodules: recursive - uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.22' cache: true - run: go get - uses: actions/download-artifact@v4 diff --git a/.gitignore b/.gitignore index 97f73481..388175f7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ ggml-metal.metal .cache *.exe .idea -test_data \ No newline at end of file +test_data +*.crt \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 7c921df8..0cbc7e34 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,7 @@ -ARG GOLANG_VERSION=1.21.3 +ARG GOLANG_VERSION=1.22.1 ARG CMAKE_VERSION=3.22.1 ARG CUDA_VERSION=11.3.1 +ARG ROCM_VERSION=6.0 # Copy the minimal context we need to run the generate scripts FROM scratch AS llm-code @@ -28,7 +29,7 @@ WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate ARG CGO_CFLAGS RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh -FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete AS rocm-5-build-amd64 +FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64 ARG CMAKE_VERSION COPY ./scripts/rh_linux_deps.sh / RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh @@ -39,18 +40,14 @@ WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate ARG CGO_CFLAGS ARG AMDGPU_TARGETS RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh +RUN mkdir /tmp/scratch && \ + for dep in $(cat /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/x86_64/rocm*/lib/deps.txt) ; do \ + cp ${dep} /tmp/scratch/ || exit 1 ; \ + done && \ + (cd /opt/rocm/lib && tar cf - rocblas/library) | (cd /tmp/scratch/ && tar xf - ) && \ + mkdir -p /go/src/github.com/jmorganca/ollama/dist/deps/ && \ + (cd /tmp/scratch/ && tar czvf /go/src/github.com/jmorganca/ollama/dist/deps/ollama-linux-amd64-rocm.tgz . ) -FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete AS rocm-6-build-amd64 -ARG CMAKE_VERSION -COPY ./scripts/rh_linux_deps.sh / -RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh -ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH -ENV LIBRARY_PATH /opt/amdgpu/lib64 -COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/ -WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate -ARG CGO_CFLAGS -ARG AMDGPU_TARGETS -RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh FROM --platform=linux/amd64 centos:7 AS cpu-builder-amd64 ARG CMAKE_VERSION @@ -91,11 +88,11 @@ COPY . . COPY --from=cpu_avx-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/ COPY --from=cpu_avx2-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/ COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/ -COPY --from=rocm-5-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/ -COPY --from=rocm-6-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/ +COPY --from=rocm-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/ +COPY --from=rocm-build-amd64 /go/src/github.com/jmorganca/ollama/dist/deps/ ./dist/deps/ ARG GOFLAGS ARG CGO_CFLAGS -RUN go build . +RUN go build -trimpath . # Intermediate stage used for ./scripts/build_linux.sh FROM --platform=linux/arm64 cpu-build-arm64 AS build-arm64 @@ -106,7 +103,7 @@ COPY . . COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/ ARG GOFLAGS ARG CGO_CFLAGS -RUN go build . +RUN go build -trimpath . # Runtime stages FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64 @@ -117,7 +114,7 @@ RUN apt-get update && apt-get install -y ca-certificates COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama # Radeon images are much larger so we keep it distinct from the CPU/CUDA image -FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm +FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete as runtime-rocm RUN update-pciids COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama EXPOSE 11434 @@ -132,6 +129,7 @@ ENV OLLAMA_HOST 0.0.0.0 ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility +ENV NVIDIA_VISIBLE_DEVICES=all ENTRYPOINT ["/bin/ollama"] CMD ["serve"] diff --git a/README.md b/README.md index 98f95703..f30d8769 100644 --- a/README.md +++ b/README.md @@ -10,16 +10,16 @@ Get up and running with large language models locally. ### macOS -[Download](https://ollama.ai/download/Ollama-darwin.zip) +[Download](https://ollama.com/download/Ollama-darwin.zip) -### Windows +### Windows preview -Coming soon! For now, you can install Ollama on Windows via WSL2. +[Download](https://ollama.com/download/OllamaSetup.exe) -### Linux & WSL2 +### Linux ``` -curl https://ollama.ai/install.sh | sh +curl -fsSL https://ollama.com/install.sh | sh ``` [Manual install instructions](https://github.com/jmorganca/ollama/blob/main/docs/linux.md) @@ -35,7 +35,7 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla ## Quickstart -To run and chat with [Llama 2](https://ollama.ai/library/llama2): +To run and chat with [Llama 2](https://ollama.com/library/llama2): ``` ollama run llama2 @@ -43,9 +43,9 @@ ollama run llama2 ## Model library -Ollama supports a list of open-source models available on [ollama.ai/library](https://ollama.ai/library 'ollama model library') +Ollama supports a list of models available on [ollama.com/library](https://ollama.com/library 'ollama model library') -Here are some example open-source models that can be downloaded: +Here are some example models that can be downloaded: | Model | Parameters | Size | Download | | ------------------ | ---------- | ----- | ------------------------------ | @@ -62,6 +62,8 @@ Here are some example open-source models that can be downloaded: | Orca Mini | 3B | 1.9GB | `ollama run orca-mini` | | Vicuna | 7B | 3.8GB | `ollama run vicuna` | | LLaVA | 7B | 4.5GB | `ollama run llava` | +| Gemma | 2B | 1.4GB | `ollama run gemma:2b` | +| Gemma | 7B | 4.8GB | `ollama run gemma:7b` | > Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models. @@ -200,18 +202,21 @@ brew install cmake go ``` Then generate dependencies: + ``` go generate ./... ``` + Then build the binary: + ``` go build . ``` More detailed instructions can be found in the [developer guide](https://github.com/jmorganca/ollama/blob/main/docs/development.md) - ### Running local builds + Next, start the server: ``` @@ -253,20 +258,28 @@ See the [API documentation](./docs/api.md) for all endpoints. ## Community Integrations ### Web & Desktop + - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) +- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted) - [HTML UI](https://github.com/rtcfirefly/ollama-ui) - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file) - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui) -- [Web UI](https://github.com/ollama-webui/ollama-webui) +- [Open WebUI](https://github.com/open-webui/open-webui) - [Ollamac](https://github.com/kevinhermawan/Ollamac) -- [big-AGI](https://github.com/enricoros/big-agi/blob/main/docs/config-ollama.md) +- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md) - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core) - [Amica](https://github.com/semperai/amica) - [chatd](https://github.com/BruceMacD/chatd) - [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI) - [MindMac](https://mindmac.app) - +- [NextJS Web Interface for Ollama](https://github.com/jakobhoeg/nextjs-ollama-llm-ui) +- [Msty](https://msty.app) +- [Chatbox](https://github.com/Bin-Huang/Chatbox) +- [WinForm Ollama Copilot](https://github.com/tgraupmann/WinForm_Ollama_Copilot) +- [NextChat](https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web) with [Get Started Doc](https://docs.nextchat.dev/models/ollama) +- [Odin Runes](https://github.com/leonid20000/OdinRunes) +- [LLM-X: Progressive Web App](https://github.com/mrdjohnson/llm-x) ### Terminal @@ -275,10 +288,14 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Emacs client](https://github.com/zweifisch/ollama) - [gen.nvim](https://github.com/David-Kunz/gen.nvim) - [ollama.nvim](https://github.com/nomnivore/ollama.nvim) +- [ollama-chat.nvim](https://github.com/gerazov/ollama-chat.nvim) - [ogpt.nvim](https://github.com/huynle/ogpt.nvim) - [gptel Emacs client](https://github.com/karthink/gptel) - [Oatmeal](https://github.com/dustinblackman/oatmeal) - [cmdh](https://github.com/pgibler/cmdh) +- [tenere](https://github.com/pythops/tenere) +- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/). +- [ShellOracle](https://github.com/djcopley/ShellOracle) ### Database @@ -287,12 +304,15 @@ See the [API documentation](./docs/api.md) for all endpoints. ### Package managers - [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/) +- [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama) ### Libraries - [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa) - [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example) +- [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java) - [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html) +- [LangChain4j](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-ollama) - [LiteLLM](https://github.com/BerriAI/litellm) - [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp) - [Ollama for Ruby](https://github.com/gbaptista/ollama-ai) @@ -305,8 +325,10 @@ See the [API documentation](./docs/api.md) for all endpoints. - [LangChainDart](https://github.com/davidmigloz/langchain_dart) - [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama) - [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md) +- [Elixir LangChain](https://github.com/brainlid/langchain) - [Ollama for R - rollama](https://github.com/JBGruber/rollama) - +- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex) +- [Ollama Connector for SAP ABAP](https://github.com/b-tocs/abap_btocs_ollama) ### Mobile @@ -320,6 +342,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Continue](https://github.com/continuedev/continue) - [Obsidian Ollama plugin](https://github.com/hinterdupfinger/obsidian-ollama) - [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq) +- [NotesOllama](https://github.com/andersrex/notesollama) (Apple Notes Ollama plugin) - [Dagger Chatbot](https://github.com/samalba/dagger-chatbot) - [Discord AI Bot](https://github.com/mekb-turtle/discord-ai-bot) - [Ollama Telegram Bot](https://github.com/ruecat/ollama-telegram) @@ -327,6 +350,9 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama) - [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama) - [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot) +- [Copilot for Obsidian plugin](https://github.com/logancyang/obsidian-copilot) +- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt) - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama) - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama) - +- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace) +- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension) diff --git a/api/client.go b/api/client.go index 4fe8668a..36019206 100644 --- a/api/client.go +++ b/api/client.go @@ -21,7 +21,7 @@ import ( type Client struct { base *url.URL - http http.Client + http *http.Client } func checkError(resp *http.Response, body []byte) error { @@ -66,30 +66,13 @@ func ClientFromEnvironment() (*Client, error) { } } - client := Client{ + return &Client{ base: &url.URL{ Scheme: scheme, Host: net.JoinHostPort(host, port), }, - } - - mockRequest, err := http.NewRequest(http.MethodHead, client.base.String(), nil) - if err != nil { - return nil, err - } - - proxyURL, err := http.ProxyFromEnvironment(mockRequest) - if err != nil { - return nil, err - } - - client.http = http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - }, - } - - return &client, nil + http: http.DefaultClient, + }, nil } func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { diff --git a/api/types.go b/api/types.go index 609c4a8a..3169e11f 100644 --- a/api/types.go +++ b/api/types.go @@ -83,7 +83,7 @@ type Metrics struct { EvalDuration time.Duration `json:"eval_duration,omitempty"` } -// Options specfied in GenerateRequest, if you add a new option here add it to the API docs also +// Options specified in GenerateRequest, if you add a new option here add it to the API docs also type Options struct { Runner @@ -121,7 +121,6 @@ type Runner struct { VocabOnly bool `json:"vocab_only,omitempty"` UseMMap bool `json:"use_mmap,omitempty"` UseMLock bool `json:"use_mlock,omitempty"` - EmbeddingOnly bool `json:"embedding_only,omitempty"` RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"` RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"` NumThread int `json:"num_thread,omitempty"` @@ -395,7 +394,6 @@ func DefaultOptions() Options { UseMLock: false, UseMMap: true, UseNUMA: false, - EmbeddingOnly: true, }, } } @@ -415,8 +413,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { switch t := v.(type) { case float64: if t < 0 { - t = math.MaxFloat64 - d.Duration = time.Duration(t) + d.Duration = time.Duration(math.MaxInt64) } else { d.Duration = time.Duration(t * float64(time.Second)) } @@ -426,8 +423,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { return err } if d.Duration < 0 { - mf := math.MaxFloat64 - d.Duration = time.Duration(mf) + d.Duration = time.Duration(math.MaxInt64) } } diff --git a/app/.gitignore b/app/.gitignore index 8296128d..0aa24794 100644 --- a/app/.gitignore +++ b/app/.gitignore @@ -1,92 +1 @@ -# Logs -logs -*.log -npm-debug.log* -yarn-debug.log* -yarn-error.log* -lerna-debug.log* - -# Diagnostic reports (https://nodejs.org/api/report.html) -report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json - -# Runtime data -pids -*.pid -*.seed -*.pid.lock -.DS_Store - -# Directory for instrumented libs generated by jscoverage/JSCover -lib-cov - -# Coverage directory used by tools like istanbul -coverage -*.lcov - -# nyc test coverage -.nyc_output - -# node-waf configuration -.lock-wscript - -# Compiled binary addons (https://nodejs.org/api/addons.html) -build/Release - -# Dependency directories -node_modules/ -jspm_packages/ - -# TypeScript v1 declaration files -typings/ - -# TypeScript cache -*.tsbuildinfo - -# Optional npm cache directory -.npm - -# Optional eslint cache -.eslintcache - -# Optional REPL history -.node_repl_history - -# Output of 'npm pack' -*.tgz - -# Yarn Integrity file -.yarn-integrity - -# dotenv environment variables file -.env -.env.test - -# parcel-bundler cache (https://parceljs.org/) -.cache - -# next.js build output -.next - -# nuxt.js build output -.nuxt - -# vuepress build output -.vuepress/dist - -# Serverless directories -.serverless/ - -# FuseBox cache -.fusebox/ - -# DynamoDB Local files -.dynamodb/ - -# Webpack -.webpack/ - -# Vite -.vite/ - -# Electron-Forge -out/ +ollama.syso diff --git a/app/README.md b/app/README.md index cc34d745..883d7ab7 100644 --- a/app/README.md +++ b/app/README.md @@ -1,21 +1,22 @@ -# Desktop +# Ollama App -This app builds upon Ollama to provide a desktop experience for running models. +## Linux -## Developing +TODO -First, build the `ollama` binary: +## MacOS + +TODO + +## Windows + +If you want to build the installer, youll need to install +- https://jrsoftware.org/isinfo.php + + +In the top directory of this repo, run the following powershell script +to build the ollama CLI, ollama app, and ollama installer. ``` -cd .. -go build . +powershell -ExecutionPolicy Bypass -File .\scripts\build_windows.ps1 ``` - -Then run the desktop app with `npm start`: - -``` -cd app -npm install -npm start -``` - diff --git a/app/assets/app.ico b/app/assets/app.ico new file mode 100644 index 00000000..875924f2 Binary files /dev/null and b/app/assets/app.ico differ diff --git a/app/assets/assets.go b/app/assets/assets.go new file mode 100644 index 00000000..6fed2d0e --- /dev/null +++ b/app/assets/assets.go @@ -0,0 +1,17 @@ +package assets + +import ( + "embed" + "io/fs" +) + +//go:embed *.ico +var icons embed.FS + +func ListIcons() ([]string, error) { + return fs.Glob(icons, "*") +} + +func GetIcon(filename string) ([]byte, error) { + return icons.ReadFile(filename) +} diff --git a/app/assets/setup.bmp b/app/assets/setup.bmp new file mode 100644 index 00000000..ff58b909 Binary files /dev/null and b/app/assets/setup.bmp differ diff --git a/app/assets/tray.ico b/app/assets/tray.ico new file mode 100644 index 00000000..e63616c5 Binary files /dev/null and b/app/assets/tray.ico differ diff --git a/app/assets/tray_upgrade.ico b/app/assets/tray_upgrade.ico new file mode 100644 index 00000000..d2083051 Binary files /dev/null and b/app/assets/tray_upgrade.ico differ diff --git a/app/lifecycle/getstarted_nonwindows.go b/app/lifecycle/getstarted_nonwindows.go new file mode 100644 index 00000000..c36d14c0 --- /dev/null +++ b/app/lifecycle/getstarted_nonwindows.go @@ -0,0 +1,9 @@ +//go:build !windows + +package lifecycle + +import "fmt" + +func GetStarted() error { + return fmt.Errorf("GetStarted not implemented") +} diff --git a/app/lifecycle/getstarted_windows.go b/app/lifecycle/getstarted_windows.go new file mode 100644 index 00000000..092c3c17 --- /dev/null +++ b/app/lifecycle/getstarted_windows.go @@ -0,0 +1,44 @@ +package lifecycle + +import ( + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" + "syscall" +) + +func GetStarted() error { + const CREATE_NEW_CONSOLE = 0x00000010 + var err error + bannerScript := filepath.Join(AppDir, "ollama_welcome.ps1") + args := []string{ + // TODO once we're signed, the execution policy bypass should be removed + "powershell", "-noexit", "-ExecutionPolicy", "Bypass", "-nologo", "-file", bannerScript, + } + args[0], err = exec.LookPath(args[0]) + if err != nil { + return err + } + + // Make sure the script actually exists + _, err = os.Stat(bannerScript) + if err != nil { + return fmt.Errorf("getting started banner script error %s", err) + } + + slog.Info(fmt.Sprintf("opening getting started terminal with %v", args)) + attrs := &os.ProcAttr{ + Files: []*os.File{os.Stdin, os.Stdout, os.Stderr}, + Sys: &syscall.SysProcAttr{CreationFlags: CREATE_NEW_CONSOLE, HideWindow: false}, + } + proc, err := os.StartProcess(args[0], args, attrs) + + if err != nil { + return fmt.Errorf("unable to start getting started shell %w", err) + } + + slog.Debug(fmt.Sprintf("getting started terminal PID: %d", proc.Pid)) + return proc.Release() +} diff --git a/app/lifecycle/lifecycle.go b/app/lifecycle/lifecycle.go new file mode 100644 index 00000000..14a85b11 --- /dev/null +++ b/app/lifecycle/lifecycle.go @@ -0,0 +1,92 @@ +package lifecycle + +import ( + "context" + "fmt" + "log" + "log/slog" + "os" + "os/signal" + "syscall" + + "github.com/jmorganca/ollama/app/store" + "github.com/jmorganca/ollama/app/tray" +) + +func Run() { + InitLogging() + + ctx, cancel := context.WithCancel(context.Background()) + var done chan int + + t, err := tray.NewTray() + if err != nil { + log.Fatalf("Failed to start: %s", err) + } + callbacks := t.GetCallbacks() + + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + + go func() { + slog.Debug("starting callback loop") + for { + select { + case <-callbacks.Quit: + slog.Debug("quit called") + t.Quit() + case <-signals: + slog.Debug("shutting down due to signal") + t.Quit() + case <-callbacks.Update: + err := DoUpgrade(cancel, done) + if err != nil { + slog.Warn(fmt.Sprintf("upgrade attempt failed: %s", err)) + } + case <-callbacks.ShowLogs: + ShowLogs() + case <-callbacks.DoFirstUse: + err := GetStarted() + if err != nil { + slog.Warn(fmt.Sprintf("Failed to launch getting started shell: %s", err)) + } + } + } + }() + + // Are we first use? + if !store.GetFirstTimeRun() { + slog.Debug("First time run") + err = t.DisplayFirstUseNotification() + if err != nil { + slog.Debug(fmt.Sprintf("XXX failed to display first use notification %v", err)) + } + store.SetFirstTimeRun(true) + } else { + slog.Debug("Not first time, skipping first run notification") + } + + if IsServerRunning(ctx) { + slog.Info("Detected another instance of ollama running, exiting") + os.Exit(1) + } else { + done, err = SpawnServer(ctx, CLIName) + if err != nil { + // TODO - should we retry in a backoff loop? + // TODO - should we pop up a warning and maybe add a menu item to view application logs? + slog.Error(fmt.Sprintf("Failed to spawn ollama server %s", err)) + done = make(chan int, 1) + done <- 1 + } + } + + StartBackgroundUpdaterChecker(ctx, t.UpdateAvailable) + + t.Run() + cancel() + slog.Info("Waiting for ollama server to shutdown...") + if done != nil { + <-done + } + slog.Info("Ollama app exiting") +} diff --git a/app/lifecycle/logging.go b/app/lifecycle/logging.go new file mode 100644 index 00000000..98df9b41 --- /dev/null +++ b/app/lifecycle/logging.go @@ -0,0 +1,46 @@ +package lifecycle + +import ( + "fmt" + "log/slog" + "os" + "path/filepath" +) + +func InitLogging() { + level := slog.LevelInfo + + if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { + level = slog.LevelDebug + } + + var logFile *os.File + var err error + // Detect if we're a GUI app on windows, and if not, send logs to console + if os.Stderr.Fd() != 0 { + // Console app detected + logFile = os.Stderr + // TODO - write one-line to the app.log file saying we're running in console mode to help avoid confusion + } else { + logFile, err = os.OpenFile(AppLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755) + if err != nil { + slog.Error(fmt.Sprintf("failed to create server log %v", err)) + return + } + } + handler := slog.NewTextHandler(logFile, &slog.HandlerOptions{ + Level: level, + AddSource: true, + ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.SourceKey { + source := attr.Value.Any().(*slog.Source) + source.File = filepath.Base(source.File) + } + return attr + }, + }) + + slog.SetDefault(slog.New(handler)) + + slog.Info("ollama app started") +} diff --git a/app/lifecycle/logging_nonwindows.go b/app/lifecycle/logging_nonwindows.go new file mode 100644 index 00000000..50b3a638 --- /dev/null +++ b/app/lifecycle/logging_nonwindows.go @@ -0,0 +1,9 @@ +//go:build !windows + +package lifecycle + +import "log/slog" + +func ShowLogs() { + slog.Warn("ShowLogs not yet implemented") +} diff --git a/app/lifecycle/logging_windows.go b/app/lifecycle/logging_windows.go new file mode 100644 index 00000000..8f20337f --- /dev/null +++ b/app/lifecycle/logging_windows.go @@ -0,0 +1,19 @@ +package lifecycle + +import ( + "fmt" + "log/slog" + "os/exec" + "syscall" +) + +func ShowLogs() { + cmd_path := "c:\\Windows\\system32\\cmd.exe" + slog.Debug(fmt.Sprintf("viewing logs with start %s", AppDataDir)) + cmd := exec.Command(cmd_path, "/c", "start", AppDataDir) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: false, CreationFlags: 0x08000000} + err := cmd.Start() + if err != nil { + slog.Error(fmt.Sprintf("Failed to open log dir: %s", err)) + } +} diff --git a/app/lifecycle/paths.go b/app/lifecycle/paths.go new file mode 100644 index 00000000..e4f2dbd9 --- /dev/null +++ b/app/lifecycle/paths.go @@ -0,0 +1,79 @@ +package lifecycle + +import ( + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "runtime" + "strings" +) + +var ( + AppName = "ollama app" + CLIName = "ollama" + AppDir = "/opt/Ollama" + AppDataDir = "/opt/Ollama" + // TODO - should there be a distinct log dir? + UpdateStageDir = "/tmp" + AppLogFile = "/tmp/ollama_app.log" + ServerLogFile = "/tmp/ollama.log" + UpgradeLogFile = "/tmp/ollama_update.log" + Installer = "OllamaSetup.exe" +) + +func init() { + if runtime.GOOS == "windows" { + AppName += ".exe" + CLIName += ".exe" + // Logs, configs, downloads go to LOCALAPPDATA + localAppData := os.Getenv("LOCALAPPDATA") + AppDataDir = filepath.Join(localAppData, "Ollama") + UpdateStageDir = filepath.Join(AppDataDir, "updates") + AppLogFile = filepath.Join(AppDataDir, "app.log") + ServerLogFile = filepath.Join(AppDataDir, "server.log") + UpgradeLogFile = filepath.Join(AppDataDir, "upgrade.log") + + // Executables are stored in APPDATA + AppDir = filepath.Join(localAppData, "Programs", "Ollama") + + // Make sure we have PATH set correctly for any spawned children + paths := strings.Split(os.Getenv("PATH"), ";") + // Start with whatever we find in the PATH/LD_LIBRARY_PATH + found := false + for _, path := range paths { + d, err := filepath.Abs(path) + if err != nil { + continue + } + if strings.EqualFold(AppDir, d) { + found = true + } + } + if !found { + paths = append(paths, AppDir) + + pathVal := strings.Join(paths, ";") + slog.Debug("setting PATH=" + pathVal) + err := os.Setenv("PATH", pathVal) + if err != nil { + slog.Error(fmt.Sprintf("failed to update PATH: %s", err)) + } + } + + // Make sure our logging dir exists + _, err := os.Stat(AppDataDir) + if errors.Is(err, os.ErrNotExist) { + if err := os.MkdirAll(AppDataDir, 0o755); err != nil { + slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err)) + } + } + + } else if runtime.GOOS == "darwin" { + // TODO + AppName += ".app" + // } else if runtime.GOOS == "linux" { + // TODO + } +} diff --git a/app/lifecycle/server.go b/app/lifecycle/server.go new file mode 100644 index 00000000..1cb689a2 --- /dev/null +++ b/app/lifecycle/server.go @@ -0,0 +1,139 @@ +package lifecycle + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/jmorganca/ollama/api" +) + +func getCLIFullPath(command string) string { + cmdPath := "" + appExe, err := os.Executable() + if err == nil { + cmdPath = filepath.Join(filepath.Dir(appExe), command) + _, err := os.Stat(cmdPath) + if err == nil { + return cmdPath + } + } + cmdPath, err = exec.LookPath(command) + if err == nil { + _, err := os.Stat(cmdPath) + if err == nil { + return cmdPath + } + } + pwd, err := os.Getwd() + if err == nil { + cmdPath = filepath.Join(pwd, command) + _, err = os.Stat(cmdPath) + if err == nil { + return cmdPath + } + } + + return command +} + +func SpawnServer(ctx context.Context, command string) (chan int, error) { + done := make(chan int) + + logDir := filepath.Dir(ServerLogFile) + _, err := os.Stat(logDir) + if errors.Is(err, os.ErrNotExist) { + if err := os.MkdirAll(logDir, 0o755); err != nil { + return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err) + } + } + + cmd := getCmd(ctx, getCLIFullPath(command)) + // send stdout and stderr to a file + stdout, err := cmd.StdoutPipe() + if err != nil { + return done, fmt.Errorf("failed to spawn server stdout pipe %s", err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return done, fmt.Errorf("failed to spawn server stderr pipe %s", err) + } + stdin, err := cmd.StdinPipe() + if err != nil { + return done, fmt.Errorf("failed to spawn server stdin pipe %s", err) + } + + // TODO - rotation + logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755) + if err != nil { + return done, fmt.Errorf("failed to create server log %w", err) + } + go func() { + defer logFile.Close() + io.Copy(logFile, stdout) //nolint:errcheck + }() + go func() { + defer logFile.Close() + io.Copy(logFile, stderr) //nolint:errcheck + }() + + // run the command and wait for it to finish + if err := cmd.Start(); err != nil { + return done, fmt.Errorf("failed to start server %w", err) + } + if cmd.Process != nil { + slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid)) + } + slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile)) + + go func() { + // Keep the server running unless we're shuttind down the app + crashCount := 0 + for { + cmd.Wait() //nolint:errcheck + stdin.Close() + var code int + if cmd.ProcessState != nil { + code = cmd.ProcessState.ExitCode() + } + + select { + case <-ctx.Done(): + slog.Debug(fmt.Sprintf("server shutdown with exit code %d", code)) + done <- code + return + default: + crashCount++ + slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code)) + time.Sleep(500 * time.Millisecond) + if err := cmd.Start(); err != nil { + slog.Error(fmt.Sprintf("failed to restart server %s", err)) + // Keep trying, but back off if we keep failing + time.Sleep(time.Duration(crashCount) * time.Second) + } + } + } + }() + return done, nil +} + +func IsServerRunning(ctx context.Context) bool { + client, err := api.ClientFromEnvironment() + if err != nil { + slog.Info("unable to connect to server") + return false + } + err = client.Heartbeat(ctx) + if err != nil { + slog.Debug(fmt.Sprintf("heartbeat from server: %s", err)) + slog.Info("unable to connect to server") + return false + } + return true +} diff --git a/app/lifecycle/server_unix.go b/app/lifecycle/server_unix.go new file mode 100644 index 00000000..c35f8b5b --- /dev/null +++ b/app/lifecycle/server_unix.go @@ -0,0 +1,12 @@ +//go:build !windows + +package lifecycle + +import ( + "context" + "os/exec" +) + +func getCmd(ctx context.Context, cmd string) *exec.Cmd { + return exec.CommandContext(ctx, cmd, "serve") +} diff --git a/app/lifecycle/server_windows.go b/app/lifecycle/server_windows.go new file mode 100644 index 00000000..3044e526 --- /dev/null +++ b/app/lifecycle/server_windows.go @@ -0,0 +1,13 @@ +package lifecycle + +import ( + "context" + "os/exec" + "syscall" +) + +func getCmd(ctx context.Context, exePath string) *exec.Cmd { + cmd := exec.CommandContext(ctx, exePath, "serve") + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: 0x08000000} + return cmd +} diff --git a/app/lifecycle/updater.go b/app/lifecycle/updater.go new file mode 100644 index 00000000..f26e32af --- /dev/null +++ b/app/lifecycle/updater.go @@ -0,0 +1,228 @@ +package lifecycle + +import ( + "context" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "mime" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/jmorganca/ollama/auth" + "github.com/jmorganca/ollama/version" +) + +var ( + UpdateCheckURLBase = "https://ollama.com/api/update" + UpdateDownloaded = false + UpdateCheckInterval = 60 * 60 * time.Second +) + +// TODO - maybe move up to the API package? +type UpdateResponse struct { + UpdateURL string `json:"url"` + UpdateVersion string `json:"version"` +} + +func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) { + var updateResp UpdateResponse + + requestURL, err := url.Parse(UpdateCheckURLBase) + if err != nil { + return false, updateResp + } + + query := requestURL.Query() + query.Add("os", runtime.GOOS) + query.Add("arch", runtime.GOARCH) + query.Add("version", version.Version) + query.Add("ts", fmt.Sprintf("%d", time.Now().Unix())) + + nonce, err := auth.NewNonce(rand.Reader, 16) + if err != nil { + return false, updateResp + } + + query.Add("nonce", nonce) + requestURL.RawQuery = query.Encode() + + data := []byte(fmt.Sprintf("%s,%s", http.MethodGet, requestURL.RequestURI())) + signature, err := auth.Sign(ctx, data) + if err != nil { + return false, updateResp + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil) + if err != nil { + slog.Warn(fmt.Sprintf("failed to check for update: %s", err)) + return false, updateResp + } + req.Header.Set("Authorization", signature) + req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) + + slog.Debug("checking for available update", "requestURL", requestURL) + resp, err := http.DefaultClient.Do(req) + if err != nil { + slog.Warn(fmt.Sprintf("failed to check for update: %s", err)) + return false, updateResp + } + defer resp.Body.Close() + + if resp.StatusCode == 204 { + slog.Debug("check update response 204 (current version is up to date)") + return false, updateResp + } + body, err := io.ReadAll(resp.Body) + if err != nil { + slog.Warn(fmt.Sprintf("failed to read body response: %s", err)) + } + + if resp.StatusCode != 200 { + slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body))) + return false, updateResp + } + err = json.Unmarshal(body, &updateResp) + if err != nil { + slog.Warn(fmt.Sprintf("malformed response checking for update: %s", err)) + return false, updateResp + } + // Extract the version string from the URL in the github release artifact path + updateResp.UpdateVersion = path.Base(path.Dir(updateResp.UpdateURL)) + + slog.Info("New update available at " + updateResp.UpdateURL) + return true, updateResp +} + +func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error { + // Do a head first to check etag info + req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil) + if err != nil { + return err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("error checking update: %w", err) + } + if resp.StatusCode != 200 { + return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode) + } + resp.Body.Close() + etag := strings.Trim(resp.Header.Get("etag"), "\"") + if etag == "" { + slog.Debug("no etag detected, falling back to filename based dedup") + etag = "_" + } + filename := Installer + _, params, err := mime.ParseMediaType(resp.Header.Get("content-disposition")) + if err == nil { + filename = params["filename"] + } + + stageFilename := filepath.Join(UpdateStageDir, etag, filename) + + // Check to see if we already have it downloaded + _, err = os.Stat(stageFilename) + if err == nil { + slog.Info("update already downloaded") + return nil + } + + cleanupOldDownloads() + + req.Method = http.MethodGet + resp, err = http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("error checking update: %w", err) + } + defer resp.Body.Close() + etag = strings.Trim(resp.Header.Get("etag"), "\"") + if etag == "" { + slog.Debug("no etag detected, falling back to filename based dedup") // TODO probably can get rid of this redundant log + etag = "_" + } + + stageFilename = filepath.Join(UpdateStageDir, etag, filename) + + _, err = os.Stat(filepath.Dir(stageFilename)) + if errors.Is(err, os.ErrNotExist) { + if err := os.MkdirAll(filepath.Dir(stageFilename), 0o755); err != nil { + return fmt.Errorf("create ollama dir %s: %v", filepath.Dir(stageFilename), err) + } + } + + payload, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read body response: %w", err) + } + fp, err := os.OpenFile(stageFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) + if err != nil { + return fmt.Errorf("write payload %s: %w", stageFilename, err) + } + defer fp.Close() + if n, err := fp.Write(payload); err != nil || n != len(payload) { + return fmt.Errorf("write payload %s: %d vs %d -- %w", stageFilename, n, len(payload), err) + } + slog.Info("new update downloaded " + stageFilename) + + UpdateDownloaded = true + return nil +} + +func cleanupOldDownloads() { + files, err := os.ReadDir(UpdateStageDir) + if err != nil && errors.Is(err, os.ErrNotExist) { + // Expected behavior on first run + return + } else if err != nil { + slog.Warn(fmt.Sprintf("failed to list stage dir: %s", err)) + return + } + for _, file := range files { + fullname := filepath.Join(UpdateStageDir, file.Name()) + slog.Debug("cleaning up old download: " + fullname) + err = os.RemoveAll(fullname) + if err != nil { + slog.Warn(fmt.Sprintf("failed to cleanup stale update download %s", err)) + } + } +} + +func StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) { + go func() { + // Don't blast an update message immediately after startup + // time.Sleep(30 * time.Second) + time.Sleep(3 * time.Second) + + for { + available, resp := IsNewReleaseAvailable(ctx) + if available { + err := DownloadNewRelease(ctx, resp) + if err != nil { + slog.Error(fmt.Sprintf("failed to download new release: %s", err)) + } + err = cb(resp.UpdateVersion) + if err != nil { + slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err)) + } + } + select { + case <-ctx.Done(): + slog.Debug("stopping background update checker") + return + default: + time.Sleep(UpdateCheckInterval) + } + } + }() +} diff --git a/app/lifecycle/updater_nonwindows.go b/app/lifecycle/updater_nonwindows.go new file mode 100644 index 00000000..0f213b34 --- /dev/null +++ b/app/lifecycle/updater_nonwindows.go @@ -0,0 +1,12 @@ +//go:build !windows + +package lifecycle + +import ( + "context" + "fmt" +) + +func DoUpgrade(cancel context.CancelFunc, done chan int) error { + return fmt.Errorf("DoUpgrade not yet implemented") +} diff --git a/app/lifecycle/updater_windows.go b/app/lifecycle/updater_windows.go new file mode 100644 index 00000000..f26c43c9 --- /dev/null +++ b/app/lifecycle/updater_windows.go @@ -0,0 +1,80 @@ +package lifecycle + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/exec" + "path/filepath" +) + +func DoUpgrade(cancel context.CancelFunc, done chan int) error { + files, err := filepath.Glob(filepath.Join(UpdateStageDir, "*", "*.exe")) // TODO generalize for multiplatform + if err != nil { + return fmt.Errorf("failed to lookup downloads: %s", err) + } + if len(files) == 0 { + return fmt.Errorf("no update downloads found") + } else if len(files) > 1 { + // Shouldn't happen + slog.Warn(fmt.Sprintf("multiple downloads found, using first one %v", files)) + } + installerExe := files[0] + + slog.Info("starting upgrade with " + installerExe) + slog.Info("upgrade log file " + UpgradeLogFile) + + // When running in debug mode, we'll be "verbose" and let the installer pop up and prompt + installArgs := []string{ + "/CLOSEAPPLICATIONS", // Quit the tray app if it's still running + "/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd + "/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed + } + // When we're not in debug mode, make the upgrade as quiet as possible (no GUI, no prompts) + // TODO - temporarily disable since we're pinning in debug mode for the preview + // if debug := os.Getenv("OLLAMA_DEBUG"); debug == "" { + installArgs = append(installArgs, + "/SP", // Skip the "This will install... Do you wish to continue" prompt + "/SUPPRESSMSGBOXES", + "/SILENT", + "/VERYSILENT", + ) + // } + + // Safeguard in case we have requests in flight that need to drain... + slog.Info("Waiting for server to shutdown") + cancel() + if done != nil { + <-done + } else { + // Shouldn't happen + slog.Warn("done chan was nil, not actually waiting") + } + + slog.Debug(fmt.Sprintf("starting installer: %s %v", installerExe, installArgs)) + os.Chdir(filepath.Dir(UpgradeLogFile)) //nolint:errcheck + cmd := exec.Command(installerExe, installArgs...) + + if err := cmd.Start(); err != nil { + return fmt.Errorf("unable to start ollama app %w", err) + } + + if cmd.Process != nil { + err = cmd.Process.Release() + if err != nil { + slog.Error(fmt.Sprintf("failed to release server process: %s", err)) + } + } else { + // TODO - some details about why it didn't start, or is this a pedantic error case? + return fmt.Errorf("installer process did not start") + } + + // TODO should we linger for a moment and check to make sure it's actually running by checking the pid? + + slog.Info("Installer started in background, exiting") + + os.Exit(0) + // Not reached + return nil +} diff --git a/app/main.go b/app/main.go new file mode 100644 index 00000000..57d8b1c1 --- /dev/null +++ b/app/main.go @@ -0,0 +1,12 @@ +package main + +// Compile with the following to get rid of the cmd pop up on windows +// go build -ldflags="-H windowsgui" . + +import ( + "github.com/jmorganca/ollama/app/lifecycle" +) + +func main() { + lifecycle.Run() +} diff --git a/app/ollama.iss b/app/ollama.iss new file mode 100644 index 00000000..df61ac4c --- /dev/null +++ b/app/ollama.iss @@ -0,0 +1,159 @@ +; Inno Setup Installer for Ollama +; +; To build the installer use the build script invoked from the top of the source tree +; +; powershell -ExecutionPolicy Bypass -File .\scripts\build_windows.ps + + +#define MyAppName "Ollama" +#if GetEnv("PKG_VERSION") != "" + #define MyAppVersion GetEnv("PKG_VERSION") +#else + #define MyAppVersion "0.0.0" +#endif +#define MyAppPublisher "Ollama" +#define MyAppURL "https://ollama.com/" +#define MyAppExeName "ollama app.exe" +#define MyIcon ".\assets\app.ico" + +[Setup] +; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications. +; (To generate a new GUID, click Tools | Generate GUID inside the IDE.) +AppId={{44E83376-CE68-45EB-8FC1-393500EB558C} +AppName={#MyAppName} +AppVersion={#MyAppVersion} +VersionInfoVersion={#MyAppVersion} +;AppVerName={#MyAppName} {#MyAppVersion} +AppPublisher={#MyAppPublisher} +AppPublisherURL={#MyAppURL} +AppSupportURL={#MyAppURL} +AppUpdatesURL={#MyAppURL} +ArchitecturesAllowed=x64 +ArchitecturesInstallIn64BitMode=x64 +DefaultDirName={localappdata}\Programs\{#MyAppName} +DefaultGroupName={#MyAppName} +DisableProgramGroupPage=yes +PrivilegesRequired=lowest +OutputBaseFilename="OllamaSetup" +SetupIconFile={#MyIcon} +UninstallDisplayIcon={uninstallexe} +Compression=lzma2 +SolidCompression=no +WizardStyle=modern +ChangesEnvironment=yes +OutputDir=..\dist\ + +; Disable logging once everything's battle tested +; Filename will be %TEMP%\Setup Log*.txt +SetupLogging=yes +CloseApplications=yes +RestartApplications=no + +; https://jrsoftware.org/ishelp/index.php?topic=setup_wizardimagefile +WizardSmallImageFile=.\assets\setup.bmp + +; TODO verifty actual min windows version... +; OG Win 10 +MinVersion=10.0.10240 + +; First release that supports WinRT UI Composition for win32 apps +; MinVersion=10.0.17134 +; First release with XAML Islands - possible UI path forward +; MinVersion=10.0.18362 + +; quiet... +DisableDirPage=yes +DisableFinishedPage=yes +DisableReadyMemo=yes +DisableReadyPage=yes +DisableStartupPrompt=yes +DisableWelcomePage=yes + +; TODO - percentage can't be set less than 100, so how to make it shorter? +; WizardSizePercent=100,80 + +#if GetEnv("KEY_CONTAINER") +SignTool=MySignTool +SignedUninstaller=yes +#endif + +SetupMutex=OllamaSetupMutex + +[Languages] +Name: "english"; MessagesFile: "compiler:Default.isl" + +[LangOptions] +DialogFontSize=12 + +[Files] +Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit +Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit +Source: "..\dist\windeps\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit +Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion +Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion +; Assumes v5.7, may need adjustments for v6 +#if GetEnv("HIP_PATH") != "" + Source: "{#GetEnv('HIP_PATH')}\bin\hipblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion + Source: "{#GetEnv('HIP_PATH')}\bin\rocblas.dll"; DestDir: "{app}\rocm\"; Flags: ignoreversion + ; amdhip64.dll dependency comes from the driver and must be installed already + Source: "{#GetEnv('HIP_PATH')}\bin\rocblas\library\*"; DestDir: "{app}\rocm\rocblas\library\"; Flags: ignoreversion +#endif + + +[Icons] +Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; IconFilename: "{app}\app.ico" +Name: "{userstartup}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; IconFilename: "{app}\app.ico" +Name: "{userprograms}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; IconFilename: "{app}\app.ico" + +[Run] +Filename: "{cmd}"; Parameters: "/C set PATH={app};%PATH% & ""{app}\{#MyAppExeName}"""; Flags: postinstall nowait runhidden + +[UninstallRun] +; Filename: "{cmd}"; Parameters: "/C ""taskkill /im ''{#MyAppExeName}'' /f /t"; Flags: runhidden +; Filename: "{cmd}"; Parameters: "/C ""taskkill /im ollama.exe /f /t"; Flags: runhidden +Filename: "taskkill"; Parameters: "/im ""{#MyAppExeName}"" /f /t"; Flags: runhidden +Filename: "taskkill"; Parameters: "/im ""ollama.exe"" /f /t"; Flags: runhidden +; HACK! need to give the server and app enough time to exit +; TODO - convert this to a Pascal code script so it waits until they're no longer running, then completes +Filename: "{cmd}"; Parameters: "/c timeout 5"; Flags: runhidden + +[UninstallDelete] +Type: filesandordirs; Name: "{%TEMP}\ollama*" +Type: filesandordirs; Name: "{%LOCALAPPDATA}\Ollama" +Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama" +Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models" +Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history" +; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved + +[Messages] +WizardReady=Ollama Windows Preview +ReadyLabel1=%nLet's get you up and running with your own large language models. +SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or finish the other installer, then click OK to continue with this install, or Cancel to exit. + + +;FinishedHeadingLabel=Run your first model +;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama2 +;ClickFinish=%n + +[Registry] +Root: HKCU; Subkey: "Environment"; \ + ValueType: expandsz; ValueName: "Path"; ValueData: "{olddata};{app}"; \ + Check: NeedsAddPath('{app}') + +[Code] + +function NeedsAddPath(Param: string): boolean; +var + OrigPath: string; +begin + if not RegQueryStringValue(HKEY_CURRENT_USER, + 'Environment', + 'Path', OrigPath) + then begin + Result := True; + exit; + end; + { look for the path with leading and trailing semicolon } + { Pos() returns 0 if not found } + Result := Pos(';' + ExpandConstant(Param) + ';', ';' + OrigPath + ';') = 0; +end; diff --git a/app/ollama.rc b/app/ollama.rc new file mode 100644 index 00000000..acd84493 --- /dev/null +++ b/app/ollama.rc @@ -0,0 +1,29 @@ +#include + +VS_VERSION_INFO VERSIONINFO + FILEFLAGSMASK 0x3fL +#ifdef _DEBUG + FILEFLAGS 0x1L +#else + FILEFLAGS 0x0L +#endif + FILEOS 0x40004L + FILETYPE 0x1L + FILESUBTYPE 0x0L +BEGIN + BLOCK "StringFileInfo" + BEGIN + BLOCK "040904b0" + BEGIN + VALUE "FileDescription", "Ollama" + VALUE "InternalName", "Ollama" + VALUE "OriginalFilename", "ollama app.exe" + VALUE "ProductName", "Ollama" + END + END + + BLOCK "VarFileInfo" + BEGIN + VALUE "Translation", 0x409, 1200 + END +END diff --git a/app/ollama_welcome.ps1 b/app/ollama_welcome.ps1 new file mode 100644 index 00000000..e7056952 --- /dev/null +++ b/app/ollama_welcome.ps1 @@ -0,0 +1,8 @@ +# TODO - consider ANSI colors and maybe ASCII art... +write-host "" +write-host "Welcome to Ollama!" +write-host "" +write-host "Run your first model:" +write-host "" +write-host "`tollama run llama2" +write-host "" \ No newline at end of file diff --git a/app/store/store.go b/app/store/store.go new file mode 100644 index 00000000..13a75a60 --- /dev/null +++ b/app/store/store.go @@ -0,0 +1,98 @@ +package store + +import ( + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "sync" + + "github.com/google/uuid" +) + +type Store struct { + ID string `json:"id"` + FirstTimeRun bool `json:"first-time-run"` +} + +var ( + lock sync.Mutex + store Store +) + +func GetID() string { + lock.Lock() + defer lock.Unlock() + if store.ID == "" { + initStore() + } + return store.ID + +} + +func GetFirstTimeRun() bool { + lock.Lock() + defer lock.Unlock() + if store.ID == "" { + initStore() + } + return store.FirstTimeRun +} + +func SetFirstTimeRun(val bool) { + lock.Lock() + defer lock.Unlock() + if store.FirstTimeRun == val { + return + } + store.FirstTimeRun = val + writeStore(getStorePath()) +} + +// lock must be held +func initStore() { + storeFile, err := os.Open(getStorePath()) + if err == nil { + defer storeFile.Close() + err = json.NewDecoder(storeFile).Decode(&store) + if err == nil { + slog.Debug(fmt.Sprintf("loaded existing store %s - ID: %s", getStorePath(), store.ID)) + return + } + } else if !errors.Is(err, os.ErrNotExist) { + slog.Debug(fmt.Sprintf("unexpected error searching for store: %s", err)) + } + slog.Debug("initializing new store") + store.ID = uuid.New().String() + writeStore(getStorePath()) +} + +func writeStore(storeFilename string) { + ollamaDir := filepath.Dir(storeFilename) + _, err := os.Stat(ollamaDir) + if errors.Is(err, os.ErrNotExist) { + if err := os.MkdirAll(ollamaDir, 0o755); err != nil { + slog.Error(fmt.Sprintf("create ollama dir %s: %v", ollamaDir, err)) + return + } + } + payload, err := json.Marshal(store) + if err != nil { + slog.Error(fmt.Sprintf("failed to marshal store: %s", err)) + return + } + fp, err := os.OpenFile(storeFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) + if err != nil { + slog.Error(fmt.Sprintf("write store payload %s: %v", storeFilename, err)) + return + } + defer fp.Close() + if n, err := fp.Write(payload); err != nil || n != len(payload) { + slog.Error(fmt.Sprintf("write store payload %s: %d vs %d -- %v", storeFilename, n, len(payload), err)) + return + } + slog.Debug("Store contents: " + string(payload)) + slog.Info(fmt.Sprintf("wrote store: %s", storeFilename)) +} diff --git a/app/store/store_darwin.go b/app/store/store_darwin.go new file mode 100644 index 00000000..e53d8525 --- /dev/null +++ b/app/store/store_darwin.go @@ -0,0 +1,13 @@ +package store + +import ( + "os" + "path/filepath" +) + +func getStorePath() string { + // TODO - system wide location? + + home := os.Getenv("HOME") + return filepath.Join(home, "Library", "Application Support", "Ollama", "config.json") +} diff --git a/app/store/store_linux.go b/app/store/store_linux.go new file mode 100644 index 00000000..3aac9b01 --- /dev/null +++ b/app/store/store_linux.go @@ -0,0 +1,16 @@ +package store + +import ( + "os" + "path/filepath" +) + +func getStorePath() string { + if os.Geteuid() == 0 { + // TODO where should we store this on linux for system-wide operation? + return "/etc/ollama/config.json" + } + + home := os.Getenv("HOME") + return filepath.Join(home, ".ollama", "config.json") +} diff --git a/app/store/store_windows.go b/app/store/store_windows.go new file mode 100644 index 00000000..ba06b82c --- /dev/null +++ b/app/store/store_windows.go @@ -0,0 +1,11 @@ +package store + +import ( + "os" + "path/filepath" +) + +func getStorePath() string { + localAppData := os.Getenv("LOCALAPPDATA") + return filepath.Join(localAppData, "Ollama", "config.json") +} diff --git a/app/tray/commontray/types.go b/app/tray/commontray/types.go new file mode 100644 index 00000000..ed633dc9 --- /dev/null +++ b/app/tray/commontray/types.go @@ -0,0 +1,24 @@ +package commontray + +var ( + Title = "Ollama" + ToolTip = "Ollama" + + UpdateIconName = "tray_upgrade" + IconName = "tray" +) + +type Callbacks struct { + Quit chan struct{} + Update chan struct{} + DoFirstUse chan struct{} + ShowLogs chan struct{} +} + +type OllamaTray interface { + GetCallbacks() Callbacks + Run() + UpdateAvailable(ver string) error + DisplayFirstUseNotification() error + Quit() +} diff --git a/app/tray/tray.go b/app/tray/tray.go new file mode 100644 index 00000000..47b204d6 --- /dev/null +++ b/app/tray/tray.go @@ -0,0 +1,33 @@ +package tray + +import ( + "fmt" + "runtime" + + "github.com/jmorganca/ollama/app/assets" + "github.com/jmorganca/ollama/app/tray/commontray" +) + +func NewTray() (commontray.OllamaTray, error) { + extension := ".png" + if runtime.GOOS == "windows" { + extension = ".ico" + } + iconName := commontray.UpdateIconName + extension + updateIcon, err := assets.GetIcon(iconName) + if err != nil { + return nil, fmt.Errorf("failed to load icon %s: %w", iconName, err) + } + iconName = commontray.IconName + extension + icon, err := assets.GetIcon(iconName) + if err != nil { + return nil, fmt.Errorf("failed to load icon %s: %w", iconName, err) + } + + tray, err := InitPlatformTray(icon, updateIcon) + if err != nil { + return nil, err + } + + return tray, nil +} diff --git a/app/tray/tray_nonwindows.go b/app/tray/tray_nonwindows.go new file mode 100644 index 00000000..6c30c3c2 --- /dev/null +++ b/app/tray/tray_nonwindows.go @@ -0,0 +1,13 @@ +//go:build !windows + +package tray + +import ( + "fmt" + + "github.com/jmorganca/ollama/app/tray/commontray" +) + +func InitPlatformTray(icon, updateIcon []byte) (commontray.OllamaTray, error) { + return nil, fmt.Errorf("NOT IMPLEMENTED YET") +} diff --git a/app/tray/tray_windows.go b/app/tray/tray_windows.go new file mode 100644 index 00000000..8ac4e478 --- /dev/null +++ b/app/tray/tray_windows.go @@ -0,0 +1,10 @@ +package tray + +import ( + "github.com/jmorganca/ollama/app/tray/commontray" + "github.com/jmorganca/ollama/app/tray/wintray" +) + +func InitPlatformTray(icon, updateIcon []byte) (commontray.OllamaTray, error) { + return wintray.InitTray(icon, updateIcon) +} diff --git a/app/tray/wintray/eventloop.go b/app/tray/wintray/eventloop.go new file mode 100644 index 00000000..a0af9787 --- /dev/null +++ b/app/tray/wintray/eventloop.go @@ -0,0 +1,184 @@ +//go:build windows + +package wintray + +import ( + "fmt" + "log/slog" + "sync" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + quitOnce sync.Once +) + +func (t *winTray) Run() { + nativeLoop() +} + +func nativeLoop() { + // Main message pump. + slog.Debug("starting event handling loop") + m := &struct { + WindowHandle windows.Handle + Message uint32 + Wparam uintptr + Lparam uintptr + Time uint32 + Pt point + LPrivate uint32 + }{} + for { + ret, _, err := pGetMessage.Call(uintptr(unsafe.Pointer(m)), 0, 0, 0) + + // If the function retrieves a message other than WM_QUIT, the return value is nonzero. + // If the function retrieves the WM_QUIT message, the return value is zero. + // If there is an error, the return value is -1 + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms644936(v=vs.85).aspx + switch int32(ret) { + case -1: + slog.Error(fmt.Sprintf("get message failure: %v", err)) + return + case 0: + return + default: + pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck + pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck + + } + } +} + +// WindowProc callback function that processes messages sent to a window. +// https://msdn.microsoft.com/en-us/library/windows/desktop/ms633573(v=vs.85).aspx +func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam uintptr) (lResult uintptr) { + const ( + WM_RBUTTONUP = 0x0205 + WM_LBUTTONUP = 0x0202 + WM_COMMAND = 0x0111 + WM_ENDSESSION = 0x0016 + WM_CLOSE = 0x0010 + WM_DESTROY = 0x0002 + WM_MOUSEMOVE = 0x0200 + WM_LBUTTONDOWN = 0x0201 + ) + switch message { + case WM_COMMAND: + menuItemId := int32(wParam) + // https://docs.microsoft.com/en-us/windows/win32/menurc/wm-command#menus + switch menuItemId { + case quitMenuID: + select { + case t.callbacks.Quit <- struct{}{}: + // should not happen but in case not listening + default: + slog.Error("no listener on Quit") + } + case updateMenuID: + select { + case t.callbacks.Update <- struct{}{}: + // should not happen but in case not listening + default: + slog.Error("no listener on Update") + } + case diagLogsMenuID: + select { + case t.callbacks.ShowLogs <- struct{}{}: + // should not happen but in case not listening + default: + slog.Error("no listener on ShowLogs") + } + default: + slog.Debug(fmt.Sprintf("Unexpected menu item id: %d", menuItemId)) + } + case WM_CLOSE: + boolRet, _, err := pDestroyWindow.Call(uintptr(t.window)) + if boolRet == 0 { + slog.Error(fmt.Sprintf("failed to destroy window: %s", err)) + } + err = t.wcex.unregister() + if err != nil { + slog.Error(fmt.Sprintf("failed to uregister windo %s", err)) + } + case WM_DESTROY: + // same as WM_ENDSESSION, but throws 0 exit code after all + defer pPostQuitMessage.Call(uintptr(int32(0))) //nolint:errcheck + fallthrough + case WM_ENDSESSION: + t.muNID.Lock() + if t.nid != nil { + err := t.nid.delete() + if err != nil { + slog.Error(fmt.Sprintf("failed to delete nid: %s", err)) + } + } + t.muNID.Unlock() + case t.wmSystrayMessage: + switch lParam { + case WM_MOUSEMOVE, WM_LBUTTONDOWN: + // Ignore these... + case WM_RBUTTONUP, WM_LBUTTONUP: + err := t.showMenu() + if err != nil { + slog.Error(fmt.Sprintf("failed to show menu: %s", err)) + } + case 0x405: // TODO - how is this magic value derived for the notification left click + if t.pendingUpdate { + select { + case t.callbacks.Update <- struct{}{}: + // should not happen but in case not listening + default: + slog.Error("no listener on Update") + } + } else { + select { + case t.callbacks.DoFirstUse <- struct{}{}: + // should not happen but in case not listening + default: + slog.Error("no listener on DoFirstUse") + } + } + case 0x404: // Middle click or close notification + // slog.Debug("doing nothing on close of first time notification") + default: + // 0x402 also seems common - what is it? + slog.Debug(fmt.Sprintf("unmanaged app message, lParm: 0x%x", lParam)) + } + case t.wmTaskbarCreated: // on explorer.exe restarts + t.muNID.Lock() + err := t.nid.add() + if err != nil { + slog.Error(fmt.Sprintf("failed to refresh the taskbar on explorer restart: %s", err)) + } + t.muNID.Unlock() + default: + // Calls the default window procedure to provide default processing for any window messages that an application does not process. + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms633572(v=vs.85).aspx + lResult, _, _ = pDefWindowProc.Call( + uintptr(hWnd), + uintptr(message), + uintptr(wParam), + uintptr(lParam), + ) + } + return +} + +func (t *winTray) Quit() { + quitOnce.Do(quit) +} + +func quit() { + boolRet, _, err := pPostMessage.Call( + uintptr(wt.window), + WM_CLOSE, + 0, + 0, + ) + if boolRet == 0 { + slog.Error(fmt.Sprintf("failed to post close message on shutdown %s", err)) + } +} diff --git a/app/tray/wintray/menus.go b/app/tray/wintray/menus.go new file mode 100644 index 00000000..74defa67 --- /dev/null +++ b/app/tray/wintray/menus.go @@ -0,0 +1,71 @@ +//go:build windows + +package wintray + +import ( + "fmt" + "log/slog" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + updatAvailableMenuID = 1 + updateMenuID = updatAvailableMenuID + 1 + separatorMenuID = updateMenuID + 1 + diagLogsMenuID = separatorMenuID + 1 + diagSeparatorMenuID = diagLogsMenuID + 1 + quitMenuID = diagSeparatorMenuID + 1 +) + +func (t *winTray) initMenus() error { + if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil { + return fmt.Errorf("unable to create menu entries %w\n", err) + } + if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil { + return fmt.Errorf("unable to create menu entries %w", err) + } + if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil { + return fmt.Errorf("unable to create menu entries %w\n", err) + } + return nil +} + +func (t *winTray) UpdateAvailable(ver string) error { + if !t.updateNotified { + slog.Debug("updating menu and sending notification for new update") + if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil { + return fmt.Errorf("unable to create menu entries %w", err) + } + if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil { + return fmt.Errorf("unable to create menu entries %w", err) + } + if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil { + return fmt.Errorf("unable to create menu entries %w", err) + } + iconFilePath, err := iconBytesToFilePath(wt.updateIcon) + if err != nil { + return fmt.Errorf("unable to write icon data to temp file: %w", err) + } + if err := wt.setIcon(iconFilePath); err != nil { + return fmt.Errorf("unable to set icon: %w", err) + } + t.updateNotified = true + + t.pendingUpdate = true + // Now pop up the notification + t.muNID.Lock() + defer t.muNID.Unlock() + copy(t.nid.InfoTitle[:], windows.StringToUTF16(updateTitle)) + copy(t.nid.Info[:], windows.StringToUTF16(fmt.Sprintf(updateMessage, ver))) + t.nid.Flags |= NIF_INFO + t.nid.Timeout = 10 + t.nid.Size = uint32(unsafe.Sizeof(*wt.nid)) + err = t.nid.modify() + if err != nil { + return err + } + } + return nil +} diff --git a/app/tray/wintray/messages.go b/app/tray/wintray/messages.go new file mode 100644 index 00000000..d364c716 --- /dev/null +++ b/app/tray/wintray/messages.go @@ -0,0 +1,15 @@ +//go:build windows + +package wintray + +const ( + firstTimeTitle = "Ollama is running" + firstTimeMessage = "Click here to get started" + updateTitle = "Update available" + updateMessage = "Ollama version %s is ready to install" + + quitMenuTitle = "Quit Ollama" + updateAvailableMenuTitle = "An update is available" + updateMenutTitle = "Restart to update" + diagLogsMenuTitle = "View logs" +) diff --git a/app/tray/wintray/notifyicon.go b/app/tray/wintray/notifyicon.go new file mode 100644 index 00000000..47071669 --- /dev/null +++ b/app/tray/wintray/notifyicon.go @@ -0,0 +1,66 @@ +//go:build windows + +package wintray + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +// Contains information that the system needs to display notifications in the notification area. +// Used by Shell_NotifyIcon. +// https://msdn.microsoft.com/en-us/library/windows/desktop/bb773352(v=vs.85).aspx +// https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159 +type notifyIconData struct { + Size uint32 + Wnd windows.Handle + ID, Flags, CallbackMessage uint32 + Icon windows.Handle + Tip [128]uint16 + State, StateMask uint32 + Info [256]uint16 + // Timeout, Version uint32 + Timeout uint32 + + InfoTitle [64]uint16 + InfoFlags uint32 + GuidItem windows.GUID + BalloonIcon windows.Handle +} + +func (nid *notifyIconData) add() error { + const NIM_ADD = 0x00000000 + res, _, err := pShellNotifyIcon.Call( + uintptr(NIM_ADD), + uintptr(unsafe.Pointer(nid)), + ) + if res == 0 { + return err + } + return nil +} + +func (nid *notifyIconData) modify() error { + const NIM_MODIFY = 0x00000001 + res, _, err := pShellNotifyIcon.Call( + uintptr(NIM_MODIFY), + uintptr(unsafe.Pointer(nid)), + ) + if res == 0 { + return err + } + return nil +} + +func (nid *notifyIconData) delete() error { + const NIM_DELETE = 0x00000002 + res, _, err := pShellNotifyIcon.Call( + uintptr(NIM_DELETE), + uintptr(unsafe.Pointer(nid)), + ) + if res == 0 { + return err + } + return nil +} diff --git a/app/tray/wintray/tray.go b/app/tray/wintray/tray.go new file mode 100644 index 00000000..365cfb82 --- /dev/null +++ b/app/tray/wintray/tray.go @@ -0,0 +1,485 @@ +//go:build windows + +package wintray + +import ( + "crypto/md5" + "encoding/hex" + "fmt" + "log/slog" + "os" + "path/filepath" + "sort" + "sync" + "unsafe" + + "github.com/jmorganca/ollama/app/tray/commontray" + "golang.org/x/sys/windows" +) + +// Helpful sources: https://github.com/golang/exp/blob/master/shiny/driver/internal/win32 + +// Contains information about loaded resources +type winTray struct { + instance, + icon, + cursor, + window windows.Handle + + loadedImages map[string]windows.Handle + muLoadedImages sync.RWMutex + + // menus keeps track of the submenus keyed by the menu item ID, plus 0 + // which corresponds to the main popup menu. + menus map[uint32]windows.Handle + muMenus sync.RWMutex + menuOf map[uint32]windows.Handle + muMenuOf sync.RWMutex + // menuItemIcons maintains the bitmap of each menu item (if applies). It's + // needed to show the icon correctly when showing a previously hidden menu + // item again. + // menuItemIcons map[uint32]windows.Handle + // muMenuItemIcons sync.RWMutex + visibleItems map[uint32][]uint32 + muVisibleItems sync.RWMutex + + nid *notifyIconData + muNID sync.RWMutex + wcex *wndClassEx + + wmSystrayMessage, + wmTaskbarCreated uint32 + + pendingUpdate bool + updateNotified bool // Only pop up the notification once - TODO consider daily nag? + // Callbacks + callbacks commontray.Callbacks + normalIcon []byte + updateIcon []byte +} + +var wt winTray + +func (t *winTray) GetCallbacks() commontray.Callbacks { + return t.callbacks +} + +func InitTray(icon, updateIcon []byte) (*winTray, error) { + wt.callbacks.Quit = make(chan struct{}) + wt.callbacks.Update = make(chan struct{}) + wt.callbacks.ShowLogs = make(chan struct{}) + wt.callbacks.DoFirstUse = make(chan struct{}) + wt.normalIcon = icon + wt.updateIcon = updateIcon + if err := wt.initInstance(); err != nil { + return nil, fmt.Errorf("Unable to init instance: %w\n", err) + } + + if err := wt.createMenu(); err != nil { + return nil, fmt.Errorf("Unable to create menu: %w\n", err) + } + + iconFilePath, err := iconBytesToFilePath(wt.normalIcon) + if err != nil { + return nil, fmt.Errorf("Unable to write icon data to temp file: %w", err) + } + if err := wt.setIcon(iconFilePath); err != nil { + return nil, fmt.Errorf("Unable to set icon: %w", err) + } + + return &wt, wt.initMenus() +} + +func (t *winTray) initInstance() error { + const ( + className = "OllamaClass" + windowName = "" + ) + + t.wmSystrayMessage = WM_USER + 1 + t.visibleItems = make(map[uint32][]uint32) + t.menus = make(map[uint32]windows.Handle) + t.menuOf = make(map[uint32]windows.Handle) + + t.loadedImages = make(map[string]windows.Handle) + + taskbarEventNamePtr, _ := windows.UTF16PtrFromString("TaskbarCreated") + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms644947 + res, _, err := pRegisterWindowMessage.Call( + uintptr(unsafe.Pointer(taskbarEventNamePtr)), + ) + if res == 0 { // success 0xc000-0xfff + return fmt.Errorf("failed to register window: %w", err) + } + t.wmTaskbarCreated = uint32(res) + + instanceHandle, _, err := pGetModuleHandle.Call(0) + if instanceHandle == 0 { + return err + } + t.instance = windows.Handle(instanceHandle) + + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms648072(v=vs.85).aspx + iconHandle, _, err := pLoadIcon.Call(0, uintptr(IDI_APPLICATION)) + if iconHandle == 0 { + return err + } + t.icon = windows.Handle(iconHandle) + + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms648391(v=vs.85).aspx + cursorHandle, _, err := pLoadCursor.Call(0, uintptr(IDC_ARROW)) + if cursorHandle == 0 { + return err + } + t.cursor = windows.Handle(cursorHandle) + + classNamePtr, err := windows.UTF16PtrFromString(className) + if err != nil { + return err + } + + windowNamePtr, err := windows.UTF16PtrFromString(windowName) + if err != nil { + return err + } + + t.wcex = &wndClassEx{ + Style: CS_HREDRAW | CS_VREDRAW, + WndProc: windows.NewCallback(t.wndProc), + Instance: t.instance, + Icon: t.icon, + Cursor: t.cursor, + Background: windows.Handle(6), // (COLOR_WINDOW + 1) + ClassName: classNamePtr, + IconSm: t.icon, + } + if err := t.wcex.register(); err != nil { + return err + } + + windowHandle, _, err := pCreateWindowEx.Call( + uintptr(0), + uintptr(unsafe.Pointer(classNamePtr)), + uintptr(unsafe.Pointer(windowNamePtr)), + uintptr(WS_OVERLAPPEDWINDOW), + uintptr(CW_USEDEFAULT), + uintptr(CW_USEDEFAULT), + uintptr(CW_USEDEFAULT), + uintptr(CW_USEDEFAULT), + uintptr(0), + uintptr(0), + uintptr(t.instance), + uintptr(0), + ) + if windowHandle == 0 { + return err + } + t.window = windows.Handle(windowHandle) + + pShowWindow.Call(uintptr(t.window), uintptr(SW_HIDE)) //nolint:errcheck + + boolRet, _, err := pUpdateWindow.Call(uintptr(t.window)) + if boolRet == 0 { + slog.Error(fmt.Sprintf("failed to update window: %s", err)) + } + + t.muNID.Lock() + defer t.muNID.Unlock() + t.nid = ¬ifyIconData{ + Wnd: windows.Handle(t.window), + ID: 100, + Flags: NIF_MESSAGE, + CallbackMessage: t.wmSystrayMessage, + } + t.nid.Size = uint32(unsafe.Sizeof(*t.nid)) + + return t.nid.add() +} + +func (t *winTray) createMenu() error { + + menuHandle, _, err := pCreatePopupMenu.Call() + if menuHandle == 0 { + return err + } + t.menus[0] = windows.Handle(menuHandle) + + // https://msdn.microsoft.com/en-us/library/windows/desktop/ms647575(v=vs.85).aspx + mi := struct { + Size, Mask, Style, Max uint32 + Background windows.Handle + ContextHelpID uint32 + MenuData uintptr + }{ + Mask: MIM_APPLYTOSUBMENUS, + } + mi.Size = uint32(unsafe.Sizeof(mi)) + + res, _, err := pSetMenuInfo.Call( + uintptr(t.menus[0]), + uintptr(unsafe.Pointer(&mi)), + ) + if res == 0 { + return err + } + return nil +} + +// Contains information about a menu item. +// https://msdn.microsoft.com/en-us/library/windows/desktop/ms647578(v=vs.85).aspx +type menuItemInfo struct { + Size, Mask, Type, State uint32 + ID uint32 + SubMenu, Checked, Unchecked windows.Handle + ItemData uintptr + TypeData *uint16 + Cch uint32 + BMPItem windows.Handle +} + +func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title string, disabled bool) error { + titlePtr, err := windows.UTF16PtrFromString(title) + if err != nil { + return err + } + + mi := menuItemInfo{ + Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE, + Type: MFT_STRING, + ID: uint32(menuItemId), + TypeData: titlePtr, + Cch: uint32(len(title)), + } + mi.Size = uint32(unsafe.Sizeof(mi)) + if disabled { + mi.State |= MFS_DISABLED + } + + var res uintptr + t.muMenus.RLock() + menu := t.menus[parentId] + t.muMenus.RUnlock() + if t.getVisibleItemIndex(parentId, menuItemId) != -1 { + // We set the menu item info based on the menuID + boolRet, _, err := pSetMenuItemInfo.Call( + uintptr(menu), + uintptr(menuItemId), + 0, + uintptr(unsafe.Pointer(&mi)), + ) + if boolRet == 0 { + return fmt.Errorf("failed to set menu item: %w", err) + } + } + + if res == 0 { + // Menu item does not already exist, create it + t.muMenus.RLock() + submenu, exists := t.menus[menuItemId] + t.muMenus.RUnlock() + if exists { + mi.Mask |= MIIM_SUBMENU + mi.SubMenu = submenu + } + t.addToVisibleItems(parentId, menuItemId) + position := t.getVisibleItemIndex(parentId, menuItemId) + res, _, err = pInsertMenuItem.Call( + uintptr(menu), + uintptr(position), + 1, + uintptr(unsafe.Pointer(&mi)), + ) + if res == 0 { + t.delFromVisibleItems(parentId, menuItemId) + return err + } + t.muMenuOf.Lock() + t.menuOf[menuItemId] = menu + t.muMenuOf.Unlock() + } + + return nil +} + +func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error { + + mi := menuItemInfo{ + Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE, + Type: MFT_SEPARATOR, + ID: uint32(menuItemId), + } + + mi.Size = uint32(unsafe.Sizeof(mi)) + + t.addToVisibleItems(parentId, menuItemId) + position := t.getVisibleItemIndex(parentId, menuItemId) + t.muMenus.RLock() + menu := uintptr(t.menus[parentId]) + t.muMenus.RUnlock() + res, _, err := pInsertMenuItem.Call( + menu, + uintptr(position), + 1, + uintptr(unsafe.Pointer(&mi)), + ) + if res == 0 { + return err + } + + return nil +} + +// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error { +// const ERROR_SUCCESS syscall.Errno = 0 + +// t.muMenus.RLock() +// menu := uintptr(t.menus[parentId]) +// t.muMenus.RUnlock() +// res, _, err := pRemoveMenu.Call( +// menu, +// uintptr(menuItemId), +// MF_BYCOMMAND, +// ) +// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS { +// return err +// } +// t.delFromVisibleItems(parentId, menuItemId) + +// return nil +// } + +func (t *winTray) showMenu() error { + p := point{} + boolRet, _, err := pGetCursorPos.Call(uintptr(unsafe.Pointer(&p))) + if boolRet == 0 { + return err + } + boolRet, _, err = pSetForegroundWindow.Call(uintptr(t.window)) + if boolRet == 0 { + slog.Warn(fmt.Sprintf("failed to bring menu to foreground: %s", err)) + } + + boolRet, _, err = pTrackPopupMenu.Call( + uintptr(t.menus[0]), + TPM_BOTTOMALIGN|TPM_LEFTALIGN, + uintptr(p.X), + uintptr(p.Y), + 0, + uintptr(t.window), + 0, + ) + if boolRet == 0 { + return err + } + + return nil +} + +func (t *winTray) delFromVisibleItems(parent, val uint32) { + t.muVisibleItems.Lock() + defer t.muVisibleItems.Unlock() + visibleItems := t.visibleItems[parent] + for i, itemval := range visibleItems { + if val == itemval { + t.visibleItems[parent] = append(visibleItems[:i], visibleItems[i+1:]...) + break + } + } +} + +func (t *winTray) addToVisibleItems(parent, val uint32) { + t.muVisibleItems.Lock() + defer t.muVisibleItems.Unlock() + if visibleItems, exists := t.visibleItems[parent]; !exists { + t.visibleItems[parent] = []uint32{val} + } else { + newvisible := append(visibleItems, val) + sort.Slice(newvisible, func(i, j int) bool { return newvisible[i] < newvisible[j] }) + t.visibleItems[parent] = newvisible + } +} + +func (t *winTray) getVisibleItemIndex(parent, val uint32) int { + t.muVisibleItems.RLock() + defer t.muVisibleItems.RUnlock() + for i, itemval := range t.visibleItems[parent] { + if val == itemval { + return i + } + } + return -1 +} + +func iconBytesToFilePath(iconBytes []byte) (string, error) { + bh := md5.Sum(iconBytes) + dataHash := hex.EncodeToString(bh[:]) + iconFilePath := filepath.Join(os.TempDir(), "ollama_temp_icon_"+dataHash) + + if _, err := os.Stat(iconFilePath); os.IsNotExist(err) { + if err := os.WriteFile(iconFilePath, iconBytes, 0644); err != nil { + return "", err + } + } + return iconFilePath, nil +} + +// Loads an image from file and shows it in tray. +// Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx +func (t *winTray) setIcon(src string) error { + + h, err := t.loadIconFrom(src) + if err != nil { + return err + } + + t.muNID.Lock() + defer t.muNID.Unlock() + t.nid.Icon = h + t.nid.Flags |= NIF_ICON + t.nid.Size = uint32(unsafe.Sizeof(*t.nid)) + + return t.nid.modify() +} + +// Loads an image from file to be shown in tray or menu item. +// LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx +func (t *winTray) loadIconFrom(src string) (windows.Handle, error) { + + // Save and reuse handles of loaded images + t.muLoadedImages.RLock() + h, ok := t.loadedImages[src] + t.muLoadedImages.RUnlock() + if !ok { + srcPtr, err := windows.UTF16PtrFromString(src) + if err != nil { + return 0, err + } + res, _, err := pLoadImage.Call( + 0, + uintptr(unsafe.Pointer(srcPtr)), + IMAGE_ICON, + 0, + 0, + LR_LOADFROMFILE|LR_DEFAULTSIZE, + ) + if res == 0 { + return 0, err + } + h = windows.Handle(res) + t.muLoadedImages.Lock() + t.loadedImages[src] = h + t.muLoadedImages.Unlock() + } + return h, nil +} + +func (t *winTray) DisplayFirstUseNotification() error { + t.muNID.Lock() + defer t.muNID.Unlock() + copy(t.nid.InfoTitle[:], windows.StringToUTF16(firstTimeTitle)) + copy(t.nid.Info[:], windows.StringToUTF16(firstTimeMessage)) + t.nid.Flags |= NIF_INFO + t.nid.Size = uint32(unsafe.Sizeof(*wt.nid)) + + return t.nid.modify() +} diff --git a/app/tray/wintray/w32api.go b/app/tray/wintray/w32api.go new file mode 100644 index 00000000..a1e0381d --- /dev/null +++ b/app/tray/wintray/w32api.go @@ -0,0 +1,89 @@ +//go:build windows + +package wintray + +import ( + "runtime" + + "golang.org/x/sys/windows" +) + +var ( + k32 = windows.NewLazySystemDLL("Kernel32.dll") + u32 = windows.NewLazySystemDLL("User32.dll") + s32 = windows.NewLazySystemDLL("Shell32.dll") + + pCreatePopupMenu = u32.NewProc("CreatePopupMenu") + pCreateWindowEx = u32.NewProc("CreateWindowExW") + pDefWindowProc = u32.NewProc("DefWindowProcW") + pDestroyWindow = u32.NewProc("DestroyWindow") + pDispatchMessage = u32.NewProc("DispatchMessageW") + pGetCursorPos = u32.NewProc("GetCursorPos") + pGetMessage = u32.NewProc("GetMessageW") + pGetModuleHandle = k32.NewProc("GetModuleHandleW") + pInsertMenuItem = u32.NewProc("InsertMenuItemW") + pLoadCursor = u32.NewProc("LoadCursorW") + pLoadIcon = u32.NewProc("LoadIconW") + pLoadImage = u32.NewProc("LoadImageW") + pPostMessage = u32.NewProc("PostMessageW") + pPostQuitMessage = u32.NewProc("PostQuitMessage") + pRegisterClass = u32.NewProc("RegisterClassExW") + pRegisterWindowMessage = u32.NewProc("RegisterWindowMessageW") + pSetForegroundWindow = u32.NewProc("SetForegroundWindow") + pSetMenuInfo = u32.NewProc("SetMenuInfo") + pSetMenuItemInfo = u32.NewProc("SetMenuItemInfoW") + pShellNotifyIcon = s32.NewProc("Shell_NotifyIconW") + pShowWindow = u32.NewProc("ShowWindow") + pTrackPopupMenu = u32.NewProc("TrackPopupMenu") + pTranslateMessage = u32.NewProc("TranslateMessage") + pUnregisterClass = u32.NewProc("UnregisterClassW") + pUpdateWindow = u32.NewProc("UpdateWindow") +) + +const ( + CS_HREDRAW = 0x0002 + CS_VREDRAW = 0x0001 + CW_USEDEFAULT = 0x80000000 + IDC_ARROW = 32512 // Standard arrow + IDI_APPLICATION = 32512 + IMAGE_ICON = 1 // Loads an icon + LR_DEFAULTSIZE = 0x00000040 // Loads default-size icon for windows(SM_CXICON x SM_CYICON) if cx, cy are set to zero + LR_LOADFROMFILE = 0x00000010 // Loads the stand-alone image from the file + MF_BYCOMMAND = 0x00000000 + MFS_DISABLED = 0x00000003 + MFT_SEPARATOR = 0x00000800 + MFT_STRING = 0x00000000 + MIIM_BITMAP = 0x00000080 + MIIM_FTYPE = 0x00000100 + MIIM_ID = 0x00000002 + MIIM_STATE = 0x00000001 + MIIM_STRING = 0x00000040 + MIIM_SUBMENU = 0x00000004 + MIM_APPLYTOSUBMENUS = 0x80000000 + NIF_ICON = 0x00000002 + NIF_INFO = 0x00000010 + NIF_MESSAGE = 0x00000001 + SW_HIDE = 0 + TPM_BOTTOMALIGN = 0x0020 + TPM_LEFTALIGN = 0x0000 + WM_CLOSE = 0x0010 + WM_USER = 0x0400 + WS_CAPTION = 0x00C00000 + WS_MAXIMIZEBOX = 0x00010000 + WS_MINIMIZEBOX = 0x00020000 + WS_OVERLAPPED = 0x00000000 + WS_OVERLAPPEDWINDOW = WS_OVERLAPPED | WS_CAPTION | WS_SYSMENU | WS_THICKFRAME | WS_MINIMIZEBOX | WS_MAXIMIZEBOX + WS_SYSMENU = 0x00080000 + WS_THICKFRAME = 0x00040000 +) + +// Not sure if this is actually needed on windows +func init() { + runtime.LockOSThread() +} + +// The POINT structure defines the x- and y- coordinates of a point. +// https://msdn.microsoft.com/en-us/library/windows/desktop/dd162805(v=vs.85).aspx +type point struct { + X, Y int32 +} diff --git a/app/tray/wintray/winclass.go b/app/tray/wintray/winclass.go new file mode 100644 index 00000000..9ce71d00 --- /dev/null +++ b/app/tray/wintray/winclass.go @@ -0,0 +1,45 @@ +//go:build windows + +package wintray + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +// Contains window class information. +// It is used with the RegisterClassEx and GetClassInfoEx functions. +// https://msdn.microsoft.com/en-us/library/ms633577.aspx +type wndClassEx struct { + Size, Style uint32 + WndProc uintptr + ClsExtra, WndExtra int32 + Instance, Icon, Cursor, Background windows.Handle + MenuName, ClassName *uint16 + IconSm windows.Handle +} + +// Registers a window class for subsequent use in calls to the CreateWindow or CreateWindowEx function. +// https://msdn.microsoft.com/en-us/library/ms633587.aspx +func (w *wndClassEx) register() error { + w.Size = uint32(unsafe.Sizeof(*w)) + res, _, err := pRegisterClass.Call(uintptr(unsafe.Pointer(w))) + if res == 0 { + return err + } + return nil +} + +// Unregisters a window class, freeing the memory required for the class. +// https://msdn.microsoft.com/en-us/library/ms644899.aspx +func (w *wndClassEx) unregister() error { + res, _, err := pUnregisterClass.Call( + uintptr(unsafe.Pointer(w.ClassName)), + uintptr(w.Instance), + ) + if res == 0 { + return err + } + return nil +} diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 00000000..ca64670d --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,61 @@ +package auth + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + + "golang.org/x/crypto/ssh" +) + +const defaultPrivateKey = "id_ed25519" + +func NewNonce(r io.Reader, length int) (string, error) { + nonce := make([]byte, length) + if _, err := io.ReadFull(r, nonce); err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(nonce), nil +} + +func Sign(ctx context.Context, bts []byte) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) + + privateKeyFile, err := os.ReadFile(keyPath) + if err != nil { + slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) + return "", err + } + + privateKey, err := ssh.ParsePrivateKey(privateKeyFile) + if err != nil { + return "", err + } + + // get the pubkey, but remove the type + publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) + parts := bytes.Split(publicKey, []byte(" ")) + if len(parts) < 2 { + return "", fmt.Errorf("malformed public key") + } + + signedData, err := privateKey.Sign(rand.Reader, bts) + if err != nil { + return "", err + } + + // signature is : + return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil +} diff --git a/cmd/cmd.go b/cmd/cmd.go index 915fa993..b9afb2e2 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,6 +1,7 @@ package cmd import ( + "archive/zip" "bytes" "context" "crypto/ed25519" @@ -14,7 +15,6 @@ import ( "net" "net/http" "os" - "os/exec" "os/signal" "path/filepath" "runtime" @@ -22,9 +22,12 @@ import ( "syscall" "time" + "github.com/containerd/console" + "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" + "golang.org/x/exp/slices" "golang.org/x/term" "github.com/jmorganca/ollama/api" @@ -85,22 +88,82 @@ func CreateHandler(cmd *cobra.Command, args []string) error { path = filepath.Join(filepath.Dir(filename), path) } - bin, err := os.Open(path) + fi, err := os.Stat(path) if errors.Is(err, os.ErrNotExist) && c.Name == "model" { continue } else if err != nil { return err } - defer bin.Close() - hash := sha256.New() - if _, err := io.Copy(hash, bin); err != nil { - return err + // TODO make this work w/ adapters + if fi.IsDir() { + tf, err := os.CreateTemp("", "ollama-tf") + if err != nil { + return err + } + defer os.RemoveAll(tf.Name()) + + zf := zip.NewWriter(tf) + + files, err := filepath.Glob(filepath.Join(path, "model-*.safetensors")) + if err != nil { + return err + } + + if len(files) == 0 { + return fmt.Errorf("no safetensors files were found in '%s'", path) + } + + // add the safetensor config file + tokenizer + files = append(files, filepath.Join(path, "config.json")) + files = append(files, filepath.Join(path, "added_tokens.json")) + files = append(files, filepath.Join(path, "tokenizer.model")) + + for _, fn := range files { + f, err := os.Open(fn) + if os.IsNotExist(err) && strings.HasSuffix(fn, "added_tokens.json") { + continue + } else if err != nil { + return err + } + + fi, err := f.Stat() + if err != nil { + return err + } + + h, err := zip.FileInfoHeader(fi) + if err != nil { + return err + } + + h.Name = filepath.Base(fn) + h.Method = zip.Store + + w, err := zf.CreateHeader(h) + if err != nil { + return err + } + + _, err = io.Copy(w, f) + if err != nil { + return err + } + + } + + if err := zf.Close(); err != nil { + return err + } + + if err := tf.Close(); err != nil { + return err + } + path = tf.Name() } - bin.Seek(0, io.SeekStart) - digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) - if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { + digest, err := createBlob(cmd, client, path) + if err != nil { return err } @@ -139,6 +202,26 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return nil } +func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { + bin, err := os.Open(path) + if err != nil { + return "", err + } + defer bin.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, bin); err != nil { + return "", err + } + bin.Seek(0, io.SeekStart) + + digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) + if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { + return "", err + } + return digest, nil +} + func RunHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -146,19 +229,68 @@ func RunHandler(cmd *cobra.Command, args []string) error { } name := args[0] + // check if the model exists on the server - _, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name}) + show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name}) var statusError api.StatusError switch { case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: if err := PullHandler(cmd, []string{name}); err != nil { return err } + + show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name}) + if err != nil { + return err + } case err != nil: return err } - return RunGenerate(cmd, args) + interactive := true + + opts := runOptions{ + Model: args[0], + WordWrap: os.Getenv("TERM") == "xterm-256color", + Options: map[string]interface{}{}, + MultiModal: slices.Contains(show.Details.Families, "clip"), + ParentModel: show.Details.ParentModel, + } + + format, err := cmd.Flags().GetString("format") + if err != nil { + return err + } + opts.Format = format + + prompts := args[1:] + // prepend stdin to the prompt if provided + if !term.IsTerminal(int(os.Stdin.Fd())) { + in, err := io.ReadAll(os.Stdin) + if err != nil { + return err + } + + prompts = append([]string{string(in)}, prompts...) + opts.WordWrap = false + interactive = false + } + opts.Prompt = strings.Join(prompts, " ") + if len(prompts) > 0 { + interactive = false + } + + nowrap, err := cmd.Flags().GetBool("nowordwrap") + if err != nil { + return err + } + opts.WordWrap = !nowrap + + if !interactive { + return generate(cmd, opts) + } + + return generateInteractive(cmd, opts) } func PushHandler(cmd *cobra.Command, args []string) error { @@ -410,51 +542,6 @@ func PullHandler(cmd *cobra.Command, args []string) error { return nil } -func RunGenerate(cmd *cobra.Command, args []string) error { - interactive := true - - opts := runOptions{ - Model: args[0], - WordWrap: os.Getenv("TERM") == "xterm-256color", - Options: map[string]interface{}{}, - } - - format, err := cmd.Flags().GetString("format") - if err != nil { - return err - } - opts.Format = format - - prompts := args[1:] - // prepend stdin to the prompt if provided - if !term.IsTerminal(int(os.Stdin.Fd())) { - in, err := io.ReadAll(os.Stdin) - if err != nil { - return err - } - - prompts = append([]string{string(in)}, prompts...) - opts.WordWrap = false - interactive = false - } - opts.Prompt = strings.Join(prompts, " ") - if len(prompts) > 0 { - interactive = false - } - - nowrap, err := cmd.Flags().GetBool("nowordwrap") - if err != nil { - return err - } - opts.WordWrap = !nowrap - - if !interactive { - return generate(cmd, opts) - } - - return generateInteractive(cmd, opts) -} - type generateContextKey string type runOptions struct { @@ -630,10 +717,18 @@ func generate(cmd *cobra.Command, opts runOptions) error { return nil } + if opts.MultiModal { + opts.Prompt, opts.Images, err = extractFileData(opts.Prompt) + if err != nil { + return err + } + } + request := api.GenerateRequest{ Model: opts.Model, Prompt: opts.Prompt, Context: generateContext, + Images: opts.Images, Format: opts.Format, System: opts.System, Template: opts.Template, @@ -672,7 +767,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { } func RunServer(cmd *cobra.Command, _ []string) error { - host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST")) + host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'")) if err != nil { host, port = "127.0.0.1", "11434" if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil { @@ -704,59 +799,42 @@ func initializeKeypair() error { _, err = os.Stat(privKeyPath) if os.IsNotExist(err) { fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath) - _, privKey, err := ed25519.GenerateKey(rand.Reader) + cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader) if err != nil { return err } - privKeyBytes, err := format.OpenSSHPrivateKey(privKey, "") + privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "") if err != nil { return err } - err = os.MkdirAll(filepath.Dir(privKeyPath), 0o755) - if err != nil { + if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil { return fmt.Errorf("could not create directory %w", err) } - err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600) + if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil { + return err + } + + sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey) if err != nil { return err } - sshPrivateKey, err := ssh.NewSignerFromKey(privKey) - if err != nil { + publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey) + + if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil { return err } - pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey()) - - err = os.WriteFile(pubKeyPath, pubKeyData, 0o644) - if err != nil { - return err - } - - fmt.Printf("Your new public key is: \n\n%s\n", string(pubKeyData)) + fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes) } return nil } -func startMacApp(ctx context.Context, client *api.Client) error { - exe, err := os.Executable() - if err != nil { - return err - } - link, err := os.Readlink(exe) - if err != nil { - return err - } - if !strings.Contains(link, "Ollama.app") { - return fmt.Errorf("could not find ollama app") - } - path := strings.Split(link, "Ollama.app") - if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil { - return err - } +//nolint:unused +func waitForServer(ctx context.Context, client *api.Client) error { // wait for the server to start timeout := time.After(5 * time.Second) tick := time.Tick(500 * time.Millisecond) @@ -770,6 +848,7 @@ func startMacApp(ctx context.Context, client *api.Client) error { } } } + } func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { @@ -778,15 +857,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { return err } if err := client.Heartbeat(cmd.Context()); err != nil { - if !strings.Contains(err.Error(), "connection refused") { + if !strings.Contains(err.Error(), " refused") { return err } - if runtime.GOOS == "darwin" { - if err := startMacApp(cmd.Context(), client); err != nil { - return fmt.Errorf("could not connect to ollama app, is it running?") - } - } else { - return fmt.Errorf("could not connect to ollama server, run 'ollama serve' to start it") + if err := startApp(cmd.Context(), client); err != nil { + return fmt.Errorf("could not connect to ollama app, is it running?") } } return nil @@ -812,10 +887,23 @@ func versionHandler(cmd *cobra.Command, _ []string) { } } +func appendHostEnvDocs(cmd *cobra.Command) { + const hostEnvDocs = ` +Environment Variables: + OLLAMA_HOST The host:port or base URL of the Ollama server (e.g. http://localhost:11434) +` + cmd.SetUsageTemplate(cmd.UsageTemplate() + hostEnvDocs) +} + func NewCLI() *cobra.Command { log.SetFlags(log.LstdFlags | log.Lshortfile) cobra.EnableCommandSorting = false + if runtime.GOOS == "windows" { + // Enable colorful ANSI escape code in Windows terminal (disabled by default) + console.ConsoleFromFile(os.Stdout) //nolint:errcheck + } + rootCmd := &cobra.Command{ Use: "ollama", Short: "Large language model runner", @@ -872,7 +960,6 @@ func NewCLI() *cobra.Command { runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().String("format", "", "Response format (e.g. json)") - serveCmd := &cobra.Command{ Use: "serve", Aliases: []string{"start"}, @@ -880,6 +967,13 @@ func NewCLI() *cobra.Command { Args: cobra.ExactArgs(0), RunE: RunServer, } + serveCmd.SetUsageTemplate(serveCmd.UsageTemplate() + ` +Environment Variables: + + OLLAMA_HOST The host:port to bind to (default "127.0.0.1:11434") + OLLAMA_ORIGINS A comma separated list of allowed origins. + OLLAMA_MODELS The path to the models directory (default is "~/.ollama/models") +`) pullCmd := &cobra.Command{ Use: "pull MODEL", @@ -908,7 +1002,6 @@ func NewCLI() *cobra.Command { PreRunE: checkServerHeartbeat, RunE: ListHandler, } - copyCmd := &cobra.Command{ Use: "cp SOURCE TARGET", Short: "Copy a model", @@ -925,6 +1018,19 @@ func NewCLI() *cobra.Command { RunE: DeleteHandler, } + for _, cmd := range []*cobra.Command{ + createCmd, + showCmd, + runCmd, + pullCmd, + pushCmd, + listCmd, + copyCmd, + deleteCmd, + } { + appendHostEnvDocs(cmd) + } + rootCmd.AddCommand( serveCmd, createCmd, diff --git a/cmd/interactive.go b/cmd/interactive.go index d337e555..82e3642a 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "os" + "path/filepath" "regexp" "sort" "strings" @@ -98,6 +99,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") + + if opts.MultiModal { + fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file")) + } + fmt.Fprintln(os.Stderr, "") } @@ -207,6 +213,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { switch multiline { case MultilineSystem: opts.System = sb.String() + opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) fmt.Println("Set system message.") sb.Reset() case MultilineTemplate: @@ -226,7 +233,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Fprintln(&sb) multiline = MultilinePrompt scanner.Prompt.UseAlt = true - break } case scanner.Pasting: fmt.Fprintln(&sb, line) @@ -348,11 +354,21 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } if args[1] == "system" { - opts.System = sb.String() + opts.System = sb.String() // for display in modelfile + newMessage := api.Message{Role: "system", Content: sb.String()} + // Check if the slice is not empty and the last message is from 'system' + if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" { + // Replace the last message + opts.Messages[len(opts.Messages)-1] = newMessage + } else { + opts.Messages = append(opts.Messages, newMessage) + } fmt.Println("Set system message.") + sb.Reset() } else if args[1] == "template" { opts.Template = sb.String() fmt.Println("Set prompt template.") + sb.Reset() } sb.Reset() @@ -454,7 +470,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } else { usage() } - case line == "/exit", line == "/bye": + case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"): return nil case strings.HasPrefix(line, "/"): args := strings.Fields(line) @@ -487,29 +503,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { if err != nil { return err } - newMessage.Content = msg - // reset the context if we find another image + // clear all previous images for better responses if len(images) > 0 { - newMessage.Images = append(newMessage.Images, images...) - // reset the context for the new image - opts.Messages = []api.Message{} - } else { - if len(opts.Messages) > 1 { - newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...) + for i := range opts.Messages { + opts.Messages[i].Images = nil } } - if len(newMessage.Images) == 0 { - fmt.Println("This model requires you to add a jpeg, png, or svg image.") - fmt.Println() - sb.Reset() - continue - } + + newMessage.Content = msg + newMessage.Images = images } - if opts.System != "" { - opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) - } opts.Messages = append(opts.Messages, newMessage) assistant, err := chat(cmd, opts) @@ -603,10 +608,10 @@ func extractFileData(input string) (string, []api.ImageData, error) { if os.IsNotExist(err) { continue } - fmt.Printf("Couldn't process image: %q\n", err) + fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err) return "", imgs, err } - fmt.Printf("Added image '%s'\n", nfp) + fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp) input = strings.ReplaceAll(input, fp, "") imgs = append(imgs, data) } @@ -627,7 +632,7 @@ func getImageData(filePath string) ([]byte, error) { } contentType := http.DetectContentType(buf) - allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"} + allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"} if !slices.Contains(allowedTypes, contentType) { return nil, fmt.Errorf("invalid image type: %s", contentType) } diff --git a/cmd/start_darwin.go b/cmd/start_darwin.go new file mode 100644 index 00000000..7e3000f0 --- /dev/null +++ b/cmd/start_darwin.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "os/exec" + "strings" + + "github.com/jmorganca/ollama/api" +) + +func startApp(ctx context.Context, client *api.Client) error { + exe, err := os.Executable() + if err != nil { + return err + } + link, err := os.Readlink(exe) + if err != nil { + return err + } + if !strings.Contains(link, "Ollama.app") { + return fmt.Errorf("could not find ollama app") + } + path := strings.Split(link, "Ollama.app") + if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil { + return err + } + return waitForServer(ctx, client) +} diff --git a/cmd/start_default.go b/cmd/start_default.go new file mode 100644 index 00000000..664c2d1f --- /dev/null +++ b/cmd/start_default.go @@ -0,0 +1,14 @@ +//go:build !windows && !darwin + +package cmd + +import ( + "context" + "fmt" + + "github.com/jmorganca/ollama/api" +) + +func startApp(ctx context.Context, client *api.Client) error { + return fmt.Errorf("could not connect to ollama server, run 'ollama serve' to start it") +} diff --git a/cmd/start_windows.go b/cmd/start_windows.go new file mode 100644 index 00000000..b9a423cf --- /dev/null +++ b/cmd/start_windows.go @@ -0,0 +1,58 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" + + "github.com/jmorganca/ollama/api" +) + +func startApp(ctx context.Context, client *api.Client) error { + // log.Printf("XXX Attempting to find and start ollama app") + AppName := "ollama app.exe" + exe, err := os.Executable() + if err != nil { + return err + } + appExe := filepath.Join(filepath.Dir(exe), AppName) + _, err = os.Stat(appExe) + if errors.Is(err, os.ErrNotExist) { + // Try the standard install location + localAppData := os.Getenv("LOCALAPPDATA") + appExe = filepath.Join(localAppData, "Ollama", AppName) + _, err := os.Stat(appExe) + if errors.Is(err, os.ErrNotExist) { + // Finally look in the path + appExe, err = exec.LookPath(AppName) + if err != nil { + return fmt.Errorf("could not locate ollama app") + } + } + } + // log.Printf("XXX attempting to start app %s", appExe) + + cmd_path := "c:\\Windows\\system32\\cmd.exe" + cmd := exec.Command(cmd_path, "/c", appExe) + // TODO - these hide flags aren't working - still pops up a command window for some reason + cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true} + + // TODO this didn't help either... + cmd.Stdin = strings.NewReader("") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + return fmt.Errorf("unable to start ollama app %w", err) + } + + if cmd.Process != nil { + defer cmd.Process.Release() //nolint:errcheck + } + return waitForServer(ctx, client) +} diff --git a/convert/convert.go b/convert/convert.go new file mode 100644 index 00000000..ba23080c --- /dev/null +++ b/convert/convert.go @@ -0,0 +1,331 @@ +package convert + +import ( + "bytes" + "cmp" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "regexp" + "slices" + + "github.com/mitchellh/mapstructure" + "google.golang.org/protobuf/proto" + + "github.com/jmorganca/ollama/convert/sentencepiece" + "github.com/jmorganca/ollama/llm" +) + +type Params struct { + Architectures []string `json:"architectures"` + VocabSize int `json:"vocab_size"` + HiddenSize int `json:"hidden_size"` // n_embd + HiddenLayers int `json:"num_hidden_layers"` // n_layer + ContextSize int `json:"max_position_embeddings"` + IntermediateSize int `json:"intermediate_size"` + AttentionHeads int `json:"num_attention_heads"` // n_head + KeyValHeads int `json:"num_key_value_heads"` + NormEPS float64 `json:"rms_norm_eps"` + RopeFreqBase float64 `json:"rope_theta"` + BoSTokenID int `json:"bos_token_id"` + EoSTokenID int `json:"eos_token_id"` +} + +type MetaData struct { + Type string `mapstructure:"dtype"` + Shape []int `mapstructure:"shape"` + Offsets []int `mapstructure:"data_offsets"` +} + +func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) { + f, err := os.Open(fn) + if err != nil { + return []llm.Tensor{}, 0, err + } + defer f.Close() + + var jsonSize uint64 + binary.Read(f, binary.LittleEndian, &jsonSize) + + buf := make([]byte, jsonSize) + _, err = io.ReadFull(f, buf) + if err != nil { + return []llm.Tensor{}, 0, err + } + + d := json.NewDecoder(bytes.NewBuffer(buf)) + d.UseNumber() + var parsed map[string]interface{} + if err = d.Decode(&parsed); err != nil { + return []llm.Tensor{}, 0, err + } + + var keys []string + for k := range parsed { + keys = append(keys, k) + } + + slices.Sort(keys) + + slog.Info("converting layers") + + var tensors []llm.Tensor + for _, k := range keys { + vals := parsed[k].(map[string]interface{}) + var data MetaData + if err = mapstructure.Decode(vals, &data); err != nil { + return []llm.Tensor{}, 0, err + } + + var size uint64 + var kind uint32 + switch len(data.Shape) { + case 0: + // metadata + continue + case 1: + // convert to float32 + kind = 0 + size = uint64(data.Shape[0] * 4) + case 2: + // convert to float16 + kind = 1 + size = uint64(data.Shape[0] * data.Shape[1] * 2) + } + + ggufName, err := GetTensorName(k) + if err != nil { + slog.Error("%v", err) + return []llm.Tensor{}, 0, err + } + + shape := []uint64{0, 0, 0, 0} + for i := range data.Shape { + shape[i] = uint64(data.Shape[i]) + } + + t := llm.Tensor{ + Name: ggufName, + Kind: kind, + Offset: offset, + Shape: shape[:], + FileName: fn, + OffsetPadding: 8 + jsonSize, + FileOffsets: []uint64{uint64(data.Offsets[0]), uint64(data.Offsets[1])}, + } + slog.Debug(fmt.Sprintf("%v", t)) + tensors = append(tensors, t) + offset += size + } + return tensors, offset, nil +} + +func GetSafeTensors(dirpath string) ([]llm.Tensor, error) { + var tensors []llm.Tensor + files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors")) + if err != nil { + return []llm.Tensor{}, err + } + + var offset uint64 + for _, f := range files { + var t []llm.Tensor + var err error + t, offset, err = ReadSafeTensors(f, offset) + if err != nil { + slog.Error("%v", err) + return []llm.Tensor{}, err + } + tensors = append(tensors, t...) + } + return tensors, nil +} + +func GetParams(dirpath string) (*Params, error) { + f, err := os.Open(filepath.Join(dirpath, "config.json")) + if err != nil { + return nil, err + } + defer f.Close() + + var params Params + + d := json.NewDecoder(f) + err = d.Decode(¶ms) + if err != nil { + return nil, err + } + + return ¶ms, nil +} + +// Details on gguf's tokenizer can be found at: +// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#tokenizer +type Vocab struct { + Tokens []string + Scores []float32 + Types []int32 +} + +func LoadTokens(dirpath string) (*Vocab, error) { + slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model"))) + in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model")) + if err != nil { + return nil, err + } + + // To regenerate sentencepiece from the protobufs use: + // protoc -I=./ --go_out=./ sentencepiece_model.proto + modelProto := &sentencepiece.ModelProto{} + if err := proto.Unmarshal(in, modelProto); err != nil { + return nil, err + } + + v := &Vocab{ + Tokens: make([]string, 0), + Scores: make([]float32, 0), + Types: make([]int32, 0), + } + + pieces := modelProto.GetPieces() + for _, p := range pieces { + v.Tokens = append(v.Tokens, p.GetPiece()) + v.Scores = append(v.Scores, p.GetScore()) + t := p.GetType() + v.Types = append(v.Types, int32(t)) + } + + slog.Info(fmt.Sprintf("vocab size: %d", len(v.Tokens))) + + // add any additional tokens + addIn, err := os.ReadFile(filepath.Join(dirpath, "added_tokens.json")) + if os.IsNotExist(err) { + return v, nil + } else if err != nil { + return nil, err + } + + slog.Info("reading user defined tokens") + + var extraTokenData map[string]int + if err := json.Unmarshal(addIn, &extraTokenData); err != nil { + return nil, err + } + + type token struct { + key string + pos int + } + + extraTokens := make([]token, 0) + for k, id := range extraTokenData { + extraTokens = append(extraTokens, token{k, id}) + } + + slices.SortFunc(extraTokens, func(a, b token) int { + return cmp.Compare(a.pos, b.pos) + }) + + numToks := len(v.Tokens) + + for cnt, t := range extraTokens { + // the token id should match the specific index for the total number of tokens + if t.pos != cnt+numToks { + return nil, fmt.Errorf("token ID '%d' for '%s' doesn't match total token size", t.pos, t.key) + } + v.Tokens = append(v.Tokens, t.key) + v.Scores = append(v.Scores, -1000.0) + v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined)) + } + slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens))) + + return v, nil +} + +func GetTensorName(n string) (string, error) { + tMap := map[string]string{ + "model.embed_tokens.weight": "token_embd.weight", + "model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight", + "model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight", + "model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight", + "model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight", + "model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight", + "model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight", + "model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight", + "model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight", + "model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight", + "lm_head.weight": "output.weight", + "model.norm.weight": "output_norm.weight", + } + + v, ok := tMap[n] + if ok { + return v, nil + } + + // quick hack to rename the layers to gguf format + for k, v := range tMap { + re := regexp.MustCompile(k) + newName := re.ReplaceAllString(n, v) + if newName != n { + return newName, nil + } + } + + return "", fmt.Errorf("couldn't find a layer name for '%s'", n) +} + +func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab) (string, error) { + c := llm.ContainerGGUF{ + ByteOrder: binary.LittleEndian, + } + + m := llm.NewGGUFModel(&c) + m.Tensors = tensors + m.KV["general.architecture"] = "llama" + m.KV["general.name"] = name + m.KV["llama.context_length"] = uint32(params.ContextSize) + m.KV["llama.embedding_length"] = uint32(params.HiddenSize) + m.KV["llama.block_count"] = uint32(params.HiddenLayers) + m.KV["llama.feed_forward_length"] = uint32(params.IntermediateSize) + m.KV["llama.rope.dimension_count"] = uint32(128) + m.KV["llama.attention.head_count"] = uint32(params.AttentionHeads) + m.KV["llama.attention.head_count_kv"] = uint32(params.KeyValHeads) + m.KV["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS) + m.KV["llama.rope.freq_base"] = float32(params.RopeFreqBase) + m.KV["general.file_type"] = uint32(1) + m.KV["tokenizer.ggml.model"] = "llama" + + m.KV["tokenizer.ggml.tokens"] = vocab.Tokens + m.KV["tokenizer.ggml.scores"] = vocab.Scores + m.KV["tokenizer.ggml.token_type"] = vocab.Types + + m.KV["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID) + m.KV["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID) + m.KV["tokenizer.ggml.unknown_token_id"] = uint32(0) + m.KV["tokenizer.ggml.add_bos_token"] = true + m.KV["tokenizer.ggml.add_eos_token"] = false + + // llamacpp sets the chat template, however we don't need to set it since we pass it in through a layer + // m.KV["tokenizer.chat_template"] = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" // XXX removeme + + c.V3.NumTensor = uint64(len(tensors)) + c.V3.NumKV = uint64(len(m.KV)) + + f, err := os.CreateTemp("", "ollama-gguf") + if err != nil { + return "", err + } + defer f.Close() + + err = m.Encode(f) + if err != nil { + return "", err + } + + return f.Name(), nil +} diff --git a/convert/sentencepiece/sentencepiece_model.pb.go b/convert/sentencepiece/sentencepiece_model.pb.go new file mode 100644 index 00000000..5c8db9bc --- /dev/null +++ b/convert/sentencepiece/sentencepiece_model.pb.go @@ -0,0 +1,1497 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.32.0 +// protoc v4.25.2 +// source: sentencepiece_model.proto + +package sentencepiece + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Model type. only have UNIGRAM now. +type TrainerSpec_ModelType int32 + +const ( + TrainerSpec_UNIGRAM TrainerSpec_ModelType = 1 // Unigram language model with dynamic algorithm + TrainerSpec_BPE TrainerSpec_ModelType = 2 // Byte Pair Encoding + TrainerSpec_WORD TrainerSpec_ModelType = 3 // Delimitered by whitespace. + TrainerSpec_CHAR TrainerSpec_ModelType = 4 // tokenizes into character sequence +) + +// Enum value maps for TrainerSpec_ModelType. +var ( + TrainerSpec_ModelType_name = map[int32]string{ + 1: "UNIGRAM", + 2: "BPE", + 3: "WORD", + 4: "CHAR", + } + TrainerSpec_ModelType_value = map[string]int32{ + "UNIGRAM": 1, + "BPE": 2, + "WORD": 3, + "CHAR": 4, + } +) + +func (x TrainerSpec_ModelType) Enum() *TrainerSpec_ModelType { + p := new(TrainerSpec_ModelType) + *p = x + return p +} + +func (x TrainerSpec_ModelType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (TrainerSpec_ModelType) Descriptor() protoreflect.EnumDescriptor { + return file_sentencepiece_model_proto_enumTypes[0].Descriptor() +} + +func (TrainerSpec_ModelType) Type() protoreflect.EnumType { + return &file_sentencepiece_model_proto_enumTypes[0] +} + +func (x TrainerSpec_ModelType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Do not use. +func (x *TrainerSpec_ModelType) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = TrainerSpec_ModelType(num) + return nil +} + +// Deprecated: Use TrainerSpec_ModelType.Descriptor instead. +func (TrainerSpec_ModelType) EnumDescriptor() ([]byte, []int) { + return file_sentencepiece_model_proto_rawDescGZIP(), []int{0, 0} +} + +type ModelProto_SentencePiece_Type int32 + +const ( + ModelProto_SentencePiece_NORMAL ModelProto_SentencePiece_Type = 1 // normal symbol + ModelProto_SentencePiece_UNKNOWN ModelProto_SentencePiece_Type = 2 // unknown symbol. only for now. + ModelProto_SentencePiece_CONTROL ModelProto_SentencePiece_Type = 3 // control symbols. , , <2ja> etc. + ModelProto_SentencePiece_USER_DEFINED ModelProto_SentencePiece_Type = 4 // user defined symbols. + // Typical usage of USER_DEFINED symbol + // is placeholder. + ModelProto_SentencePiece_BYTE ModelProto_SentencePiece_Type = 6 // byte symbols. Used when `byte_fallback` is true. + ModelProto_SentencePiece_UNUSED ModelProto_SentencePiece_Type = 5 // this piece is not used. +) + +// Enum value maps for ModelProto_SentencePiece_Type. +var ( + ModelProto_SentencePiece_Type_name = map[int32]string{ + 1: "NORMAL", + 2: "UNKNOWN", + 3: "CONTROL", + 4: "USER_DEFINED", + 6: "BYTE", + 5: "UNUSED", + } + ModelProto_SentencePiece_Type_value = map[string]int32{ + "NORMAL": 1, + "UNKNOWN": 2, + "CONTROL": 3, + "USER_DEFINED": 4, + "BYTE": 6, + "UNUSED": 5, + } +) + +func (x ModelProto_SentencePiece_Type) Enum() *ModelProto_SentencePiece_Type { + p := new(ModelProto_SentencePiece_Type) + *p = x + return p +} + +func (x ModelProto_SentencePiece_Type) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (ModelProto_SentencePiece_Type) Descriptor() protoreflect.EnumDescriptor { + return file_sentencepiece_model_proto_enumTypes[1].Descriptor() +} + +func (ModelProto_SentencePiece_Type) Type() protoreflect.EnumType { + return &file_sentencepiece_model_proto_enumTypes[1] +} + +func (x ModelProto_SentencePiece_Type) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Do not use. +func (x *ModelProto_SentencePiece_Type) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = ModelProto_SentencePiece_Type(num) + return nil +} + +// Deprecated: Use ModelProto_SentencePiece_Type.Descriptor instead. +func (ModelProto_SentencePiece_Type) EnumDescriptor() ([]byte, []int) { + return file_sentencepiece_model_proto_rawDescGZIP(), []int{3, 0, 0} +} + +// TrainerSpec encodes a various parameters for SentencePiece training. +// Next id: 55 +type TrainerSpec struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + extensionFields protoimpl.ExtensionFields + + // ///////////////////////////////////////////////////////////////// + // General parameters + // + // Input corpus files. + // + // Trainer accepts the following two formats: + // A) Monolingual: plain text, one sentence per line. + // B) Bilingual: TSV, source sentence target sentence + // When bilingual data is passed, shared vocabulary model is built. + // Note that the input file must be raw corpus, not a preprocessed corpus. + // Trainer only loads the first `input_sentence_size` sentences specified + // with this parameter. + Input []string `protobuf:"bytes,1,rep,name=input" json:"input,omitempty"` + // Input corpus format: + // "text": one-sentence-per-line text format (default) + // "tsv": sentence freq + InputFormat *string `protobuf:"bytes,7,opt,name=input_format,json=inputFormat" json:"input_format,omitempty"` + // Output model file prefix. + // .model and .vocab are generated. + ModelPrefix *string `protobuf:"bytes,2,opt,name=model_prefix,json=modelPrefix" json:"model_prefix,omitempty"` + ModelType *TrainerSpec_ModelType `protobuf:"varint,3,opt,name=model_type,json=modelType,enum=sentencepiece.TrainerSpec_ModelType,def=1" json:"model_type,omitempty"` + // Vocabulary size. 8k is the default size. + VocabSize *int32 `protobuf:"varint,4,opt,name=vocab_size,json=vocabSize,def=8000" json:"vocab_size,omitempty"` + // List of the languages this model can accept. + // Since the model is language-agnostic, this field is used as a reference. + AcceptLanguage []string `protobuf:"bytes,5,rep,name=accept_language,json=acceptLanguage" json:"accept_language,omitempty"` + // Size of self-test samples, which are encoded in the model file. + SelfTestSampleSize *int32 `protobuf:"varint,6,opt,name=self_test_sample_size,json=selfTestSampleSize,def=0" json:"self_test_sample_size,omitempty"` + // Whether to use DP version of sentencepiece. Use it with TSV input format + // (requires precomputed word tab counts to work). + EnableDifferentialPrivacy *bool `protobuf:"varint,50,opt,name=enable_differential_privacy,json=enableDifferentialPrivacy,def=0" json:"enable_differential_privacy,omitempty"` + // Set these parameters if you need DP version of sentencepiece. + // std of noise to add. + DifferentialPrivacyNoiseLevel *float32 `protobuf:"fixed32,51,opt,name=differential_privacy_noise_level,json=differentialPrivacyNoiseLevel,def=0" json:"differential_privacy_noise_level,omitempty"` + // Clipping threshold to apply after adding noise. All the words with + // frequency less than this value are dropped. + DifferentialPrivacyClippingThreshold *uint64 `protobuf:"varint,52,opt,name=differential_privacy_clipping_threshold,json=differentialPrivacyClippingThreshold,def=0" json:"differential_privacy_clipping_threshold,omitempty"` + // ///////////////////////////////////////////////////////////////// + // Training parameters. + // + // Uses characters which cover the corpus with the ratio of `chars_coverage`. + // This parameter determines the set of basic Alphabet of sentence piece. + // 1.0 - `chars_coverage` characters are treated as UNK. + // See also required_chars field. + CharacterCoverage *float32 `protobuf:"fixed32,10,opt,name=character_coverage,json=characterCoverage,def=0.9995" json:"character_coverage,omitempty"` + // Maximum size of sentences the trainer loads from `input` parameter. + // Trainer simply loads the `input` files in sequence. + // It is better to shuffle the input corpus randomly. + InputSentenceSize *uint64 `protobuf:"varint,11,opt,name=input_sentence_size,json=inputSentenceSize,def=0" json:"input_sentence_size,omitempty"` + ShuffleInputSentence *bool `protobuf:"varint,19,opt,name=shuffle_input_sentence,json=shuffleInputSentence,def=1" json:"shuffle_input_sentence,omitempty"` + // Maximum size of sentences to make seed sentence pieces. + // Extended suffix array is constructed to extract frequent + // sub-strings from the corpus. This uses 20N working space, + // where N is the size of corpus. + // + // Deprecated: Marked as deprecated in sentencepiece_model.proto. + MiningSentenceSize *int32 `protobuf:"varint,12,opt,name=mining_sentence_size,json=miningSentenceSize" json:"mining_sentence_size,omitempty"` + // Maximum size of sentences to train sentence pieces. + // + // Deprecated: Marked as deprecated in sentencepiece_model.proto. + TrainingSentenceSize *int32 `protobuf:"varint,13,opt,name=training_sentence_size,json=trainingSentenceSize" json:"training_sentence_size,omitempty"` + // The size of seed sentencepieces. + // `seed_sentencepiece_size` must be larger than `vocab_size`. + SeedSentencepieceSize *int32 `protobuf:"varint,14,opt,name=seed_sentencepiece_size,json=seedSentencepieceSize,def=1000000" json:"seed_sentencepiece_size,omitempty"` + // In every EM sub-iterations, keeps top + // `shrinking_factor` * `current sentencepieces size` with respect to + // the loss of the sentence piece. This value should be smaller than 1.0. + ShrinkingFactor *float32 `protobuf:"fixed32,15,opt,name=shrinking_factor,json=shrinkingFactor,def=0.75" json:"shrinking_factor,omitempty"` + // The maximum sentence length in byte. The sentences with the length + // larger than `max_sentence_length` is simply ignored. + // Longer input tends to bring the following risks: + // - Overflow during EM training (unigram language model only) + // - Performance drop because of O(n log n) cost in BPE. + MaxSentenceLength *int32 `protobuf:"varint,18,opt,name=max_sentence_length,json=maxSentenceLength,def=4192" json:"max_sentence_length,omitempty"` + // Number of threads in the training. + NumThreads *int32 `protobuf:"varint,16,opt,name=num_threads,json=numThreads,def=16" json:"num_threads,omitempty"` + // Number of EM sub iterations. + NumSubIterations *int32 `protobuf:"varint,17,opt,name=num_sub_iterations,json=numSubIterations,def=2" json:"num_sub_iterations,omitempty"` + // ///////////////////////////////////////////////////////////////// + // SentencePiece parameters which control the shapes of sentence piece. + // + // Maximum length of sentencepiece. + MaxSentencepieceLength *int32 `protobuf:"varint,20,opt,name=max_sentencepiece_length,json=maxSentencepieceLength,def=16" json:"max_sentencepiece_length,omitempty"` + // Uses Unicode script to split sentence pieces. + // When `split_by_unicode_script` is true, we do not allow sentence piece to + // include multiple Unicode scripts, e.g. "F1" is not a valid piece. + // Exception: CJ characters (Hiragana/Katakana/Han) are all handled + // as one script type, since Japanese word can consist of multiple scripts. + // This exception is always applied regardless of the accept-language + // parameter. + SplitByUnicodeScript *bool `protobuf:"varint,21,opt,name=split_by_unicode_script,json=splitByUnicodeScript,def=1" json:"split_by_unicode_script,omitempty"` + // When `split_by_number` is true, put a boundary between number and + // non-number transition. If we want to treat "F1" is one token, set this flag + // to be false. + SplitByNumber *bool `protobuf:"varint,23,opt,name=split_by_number,json=splitByNumber,def=1" json:"split_by_number,omitempty"` + // Use a white space to split sentence pieces. + // When `split_by_whitespace` is false, we may have the piece containing + // a white space in the middle. e.g., "in_the". + SplitByWhitespace *bool `protobuf:"varint,22,opt,name=split_by_whitespace,json=splitByWhitespace,def=1" json:"split_by_whitespace,omitempty"` + // Adds whitespace symbol (_) as a suffix instead of prefix. e.g., _hello => + // hello_. When `treat_whitespace_as_suffix` is true, + // NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end + // of sentence. + TreatWhitespaceAsSuffix *bool `protobuf:"varint,24,opt,name=treat_whitespace_as_suffix,json=treatWhitespaceAsSuffix,def=0" json:"treat_whitespace_as_suffix,omitempty"` + // Allows pieces that only contain whitespaces instead of appearing only as + // prefix or suffix of other pieces. + AllowWhitespaceOnlyPieces *bool `protobuf:"varint,26,opt,name=allow_whitespace_only_pieces,json=allowWhitespaceOnlyPieces,def=0" json:"allow_whitespace_only_pieces,omitempty"` + // Split all digits (0-9) into separate pieces. + SplitDigits *bool `protobuf:"varint,25,opt,name=split_digits,json=splitDigits,def=0" json:"split_digits,omitempty"` + // Defines the pre-tokenization delimiter. + // When specified, no pieces crossing this delimiter is not included + // in the vocab. Then the delimiter string is virtually ignored + // during the training. This field can allows constraints on the vocabulary + // selection. Note that this field is available on unigram mode. + PretokenizationDelimiter *string `protobuf:"bytes,53,opt,name=pretokenization_delimiter,json=pretokenizationDelimiter,def=" json:"pretokenization_delimiter,omitempty"` + // ///////////////////////////////////////////////////////////////// + // Vocabulary management + // + // Defines control symbols used as an indicator to + // change the behavior of the decoder. and are pre-defined. + // We can use this field to encode various meta information, + // including language indicator in multilingual model. + // These symbols are not visible to users, but visible to + // the decoder. Note that when the input sentence contains control symbols, + // they are not treated as one token, but segmented into normal pieces. + // Control symbols must be inserted independently from the segmentation. + ControlSymbols []string `protobuf:"bytes,30,rep,name=control_symbols,json=controlSymbols" json:"control_symbols,omitempty"` + // Defines user defined symbols. + // These symbols are added with extremely high score + // so they are always treated as one unique symbol in any context. + // Typical usage of user_defined_symbols is placeholder for named entities. + UserDefinedSymbols []string `protobuf:"bytes,31,rep,name=user_defined_symbols,json=userDefinedSymbols" json:"user_defined_symbols,omitempty"` + // Defines required characters. Each UTF8 character in this string is included + // in the character set regardless of character_coverage value. Unlike + // user_defined_symbols, these characters have scores based on the frequency + // on input sentences, and the model can form subwords using characters + // in this field. + RequiredChars *string `protobuf:"bytes,36,opt,name=required_chars,json=requiredChars" json:"required_chars,omitempty"` + // Decomposes unknown pieces into UTF-8 bytes. + ByteFallback *bool `protobuf:"varint,35,opt,name=byte_fallback,json=byteFallback,def=0" json:"byte_fallback,omitempty"` + // When creating the vocabulary file, defines whether or not to additionally + // output the score for each piece. + VocabularyOutputPieceScore *bool `protobuf:"varint,32,opt,name=vocabulary_output_piece_score,json=vocabularyOutputPieceScore,def=1" json:"vocabulary_output_piece_score,omitempty"` + // `vocab_size` is treated as hard limit. Crash if + // the model can not produce the vocab of size `vocab_size`, + // When `hard_vocab_limit` is false, vocab_size is treated + // as soft limit. Note that when model_type=char, + // always assumes hard_vocab_limit = false. + HardVocabLimit *bool `protobuf:"varint,33,opt,name=hard_vocab_limit,json=hardVocabLimit,def=1" json:"hard_vocab_limit,omitempty"` + // use all symbols for vocab extraction. This flag is valid + // if model type is either CHAR or WORD + UseAllVocab *bool `protobuf:"varint,34,opt,name=use_all_vocab,json=useAllVocab,def=0" json:"use_all_vocab,omitempty"` + // ///////////////////////////////////////////////////////////////// + // Reserved special meta tokens. + // * -1 is not used. + // * unk_id must not be -1. + // Id must starts with 0 and be contigous. + UnkId *int32 `protobuf:"varint,40,opt,name=unk_id,json=unkId,def=0" json:"unk_id,omitempty"` // + BosId *int32 `protobuf:"varint,41,opt,name=bos_id,json=bosId,def=1" json:"bos_id,omitempty"` // + EosId *int32 `protobuf:"varint,42,opt,name=eos_id,json=eosId,def=2" json:"eos_id,omitempty"` // + PadId *int32 `protobuf:"varint,43,opt,name=pad_id,json=padId,def=-1" json:"pad_id,omitempty"` // (padding) + UnkPiece *string `protobuf:"bytes,45,opt,name=unk_piece,json=unkPiece,def=" json:"unk_piece,omitempty"` + BosPiece *string `protobuf:"bytes,46,opt,name=bos_piece,json=bosPiece,def=" json:"bos_piece,omitempty"` + EosPiece *string `protobuf:"bytes,47,opt,name=eos_piece,json=eosPiece,def=" json:"eos_piece,omitempty"` + PadPiece *string `protobuf:"bytes,48,opt,name=pad_piece,json=padPiece,def=" json:"pad_piece,omitempty"` + // Encodes into U+2047 (DOUBLE QUESTION MARK), + // since this character can be useful both for user and + // developer. We can easily figure out that is emitted. + UnkSurface *string `protobuf:"bytes,44,opt,name=unk_surface,json=unkSurface,def= ⁇ " json:"unk_surface,omitempty"` + // Increase bit depth to allow unigram model training on large + // (>10M sentences) corpora. A Side-effect of enabling this flag + // is increased memory usage. + TrainExtremelyLargeCorpus *bool `protobuf:"varint,49,opt,name=train_extremely_large_corpus,json=trainExtremelyLargeCorpus,def=0" json:"train_extremely_large_corpus,omitempty"` + // Path to a seed sentencepieces file, with one tab-separated + // seed sentencepiece frequency per line. + SeedSentencepiecesFile *string `protobuf:"bytes,54,opt,name=seed_sentencepieces_file,json=seedSentencepiecesFile,def=" json:"seed_sentencepieces_file,omitempty"` +} + +// Default values for TrainerSpec fields. +const ( + Default_TrainerSpec_ModelType = TrainerSpec_UNIGRAM + Default_TrainerSpec_VocabSize = int32(8000) + Default_TrainerSpec_SelfTestSampleSize = int32(0) + Default_TrainerSpec_EnableDifferentialPrivacy = bool(false) + Default_TrainerSpec_DifferentialPrivacyNoiseLevel = float32(0) + Default_TrainerSpec_DifferentialPrivacyClippingThreshold = uint64(0) + Default_TrainerSpec_CharacterCoverage = float32(0.9994999766349792) + Default_TrainerSpec_InputSentenceSize = uint64(0) + Default_TrainerSpec_ShuffleInputSentence = bool(true) + Default_TrainerSpec_SeedSentencepieceSize = int32(1000000) + Default_TrainerSpec_ShrinkingFactor = float32(0.75) + Default_TrainerSpec_MaxSentenceLength = int32(4192) + Default_TrainerSpec_NumThreads = int32(16) + Default_TrainerSpec_NumSubIterations = int32(2) + Default_TrainerSpec_MaxSentencepieceLength = int32(16) + Default_TrainerSpec_SplitByUnicodeScript = bool(true) + Default_TrainerSpec_SplitByNumber = bool(true) + Default_TrainerSpec_SplitByWhitespace = bool(true) + Default_TrainerSpec_TreatWhitespaceAsSuffix = bool(false) + Default_TrainerSpec_AllowWhitespaceOnlyPieces = bool(false) + Default_TrainerSpec_SplitDigits = bool(false) + Default_TrainerSpec_PretokenizationDelimiter = string("") + Default_TrainerSpec_ByteFallback = bool(false) + Default_TrainerSpec_VocabularyOutputPieceScore = bool(true) + Default_TrainerSpec_HardVocabLimit = bool(true) + Default_TrainerSpec_UseAllVocab = bool(false) + Default_TrainerSpec_UnkId = int32(0) + Default_TrainerSpec_BosId = int32(1) + Default_TrainerSpec_EosId = int32(2) + Default_TrainerSpec_PadId = int32(-1) + Default_TrainerSpec_UnkPiece = string("") + Default_TrainerSpec_BosPiece = string("") + Default_TrainerSpec_EosPiece = string("") + Default_TrainerSpec_PadPiece = string("") + Default_TrainerSpec_UnkSurface = string(" ⁇ ") + Default_TrainerSpec_TrainExtremelyLargeCorpus = bool(false) + Default_TrainerSpec_SeedSentencepiecesFile = string("") +) + +func (x *TrainerSpec) Reset() { + *x = TrainerSpec{} + if protoimpl.UnsafeEnabled { + mi := &file_sentencepiece_model_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TrainerSpec) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TrainerSpec) ProtoMessage() {} + +func (x *TrainerSpec) ProtoReflect() protoreflect.Message { + mi := &file_sentencepiece_model_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TrainerSpec.ProtoReflect.Descriptor instead. +func (*TrainerSpec) Descriptor() ([]byte, []int) { + return file_sentencepiece_model_proto_rawDescGZIP(), []int{0} +} + +func (x *TrainerSpec) GetInput() []string { + if x != nil { + return x.Input + } + return nil +} + +func (x *TrainerSpec) GetInputFormat() string { + if x != nil && x.InputFormat != nil { + return *x.InputFormat + } + return "" +} + +func (x *TrainerSpec) GetModelPrefix() string { + if x != nil && x.ModelPrefix != nil { + return *x.ModelPrefix + } + return "" +} + +func (x *TrainerSpec) GetModelType() TrainerSpec_ModelType { + if x != nil && x.ModelType != nil { + return *x.ModelType + } + return Default_TrainerSpec_ModelType +} + +func (x *TrainerSpec) GetVocabSize() int32 { + if x != nil && x.VocabSize != nil { + return *x.VocabSize + } + return Default_TrainerSpec_VocabSize +} + +func (x *TrainerSpec) GetAcceptLanguage() []string { + if x != nil { + return x.AcceptLanguage + } + return nil +} + +func (x *TrainerSpec) GetSelfTestSampleSize() int32 { + if x != nil && x.SelfTestSampleSize != nil { + return *x.SelfTestSampleSize + } + return Default_TrainerSpec_SelfTestSampleSize +} + +func (x *TrainerSpec) GetEnableDifferentialPrivacy() bool { + if x != nil && x.EnableDifferentialPrivacy != nil { + return *x.EnableDifferentialPrivacy + } + return Default_TrainerSpec_EnableDifferentialPrivacy +} + +func (x *TrainerSpec) GetDifferentialPrivacyNoiseLevel() float32 { + if x != nil && x.DifferentialPrivacyNoiseLevel != nil { + return *x.DifferentialPrivacyNoiseLevel + } + return Default_TrainerSpec_DifferentialPrivacyNoiseLevel +} + +func (x *TrainerSpec) GetDifferentialPrivacyClippingThreshold() uint64 { + if x != nil && x.DifferentialPrivacyClippingThreshold != nil { + return *x.DifferentialPrivacyClippingThreshold + } + return Default_TrainerSpec_DifferentialPrivacyClippingThreshold +} + +func (x *TrainerSpec) GetCharacterCoverage() float32 { + if x != nil && x.CharacterCoverage != nil { + return *x.CharacterCoverage + } + return Default_TrainerSpec_CharacterCoverage +} + +func (x *TrainerSpec) GetInputSentenceSize() uint64 { + if x != nil && x.InputSentenceSize != nil { + return *x.InputSentenceSize + } + return Default_TrainerSpec_InputSentenceSize +} + +func (x *TrainerSpec) GetShuffleInputSentence() bool { + if x != nil && x.ShuffleInputSentence != nil { + return *x.ShuffleInputSentence + } + return Default_TrainerSpec_ShuffleInputSentence +} + +// Deprecated: Marked as deprecated in sentencepiece_model.proto. +func (x *TrainerSpec) GetMiningSentenceSize() int32 { + if x != nil && x.MiningSentenceSize != nil { + return *x.MiningSentenceSize + } + return 0 +} + +// Deprecated: Marked as deprecated in sentencepiece_model.proto. +func (x *TrainerSpec) GetTrainingSentenceSize() int32 { + if x != nil && x.TrainingSentenceSize != nil { + return *x.TrainingSentenceSize + } + return 0 +} + +func (x *TrainerSpec) GetSeedSentencepieceSize() int32 { + if x != nil && x.SeedSentencepieceSize != nil { + return *x.SeedSentencepieceSize + } + return Default_TrainerSpec_SeedSentencepieceSize +} + +func (x *TrainerSpec) GetShrinkingFactor() float32 { + if x != nil && x.ShrinkingFactor != nil { + return *x.ShrinkingFactor + } + return Default_TrainerSpec_ShrinkingFactor +} + +func (x *TrainerSpec) GetMaxSentenceLength() int32 { + if x != nil && x.MaxSentenceLength != nil { + return *x.MaxSentenceLength + } + return Default_TrainerSpec_MaxSentenceLength +} + +func (x *TrainerSpec) GetNumThreads() int32 { + if x != nil && x.NumThreads != nil { + return *x.NumThreads + } + return Default_TrainerSpec_NumThreads +} + +func (x *TrainerSpec) GetNumSubIterations() int32 { + if x != nil && x.NumSubIterations != nil { + return *x.NumSubIterations + } + return Default_TrainerSpec_NumSubIterations +} + +func (x *TrainerSpec) GetMaxSentencepieceLength() int32 { + if x != nil && x.MaxSentencepieceLength != nil { + return *x.MaxSentencepieceLength + } + return Default_TrainerSpec_MaxSentencepieceLength +} + +func (x *TrainerSpec) GetSplitByUnicodeScript() bool { + if x != nil && x.SplitByUnicodeScript != nil { + return *x.SplitByUnicodeScript + } + return Default_TrainerSpec_SplitByUnicodeScript +} + +func (x *TrainerSpec) GetSplitByNumber() bool { + if x != nil && x.SplitByNumber != nil { + return *x.SplitByNumber + } + return Default_TrainerSpec_SplitByNumber +} + +func (x *TrainerSpec) GetSplitByWhitespace() bool { + if x != nil && x.SplitByWhitespace != nil { + return *x.SplitByWhitespace + } + return Default_TrainerSpec_SplitByWhitespace +} + +func (x *TrainerSpec) GetTreatWhitespaceAsSuffix() bool { + if x != nil && x.TreatWhitespaceAsSuffix != nil { + return *x.TreatWhitespaceAsSuffix + } + return Default_TrainerSpec_TreatWhitespaceAsSuffix +} + +func (x *TrainerSpec) GetAllowWhitespaceOnlyPieces() bool { + if x != nil && x.AllowWhitespaceOnlyPieces != nil { + return *x.AllowWhitespaceOnlyPieces + } + return Default_TrainerSpec_AllowWhitespaceOnlyPieces +} + +func (x *TrainerSpec) GetSplitDigits() bool { + if x != nil && x.SplitDigits != nil { + return *x.SplitDigits + } + return Default_TrainerSpec_SplitDigits +} + +func (x *TrainerSpec) GetPretokenizationDelimiter() string { + if x != nil && x.PretokenizationDelimiter != nil { + return *x.PretokenizationDelimiter + } + return Default_TrainerSpec_PretokenizationDelimiter +} + +func (x *TrainerSpec) GetControlSymbols() []string { + if x != nil { + return x.ControlSymbols + } + return nil +} + +func (x *TrainerSpec) GetUserDefinedSymbols() []string { + if x != nil { + return x.UserDefinedSymbols + } + return nil +} + +func (x *TrainerSpec) GetRequiredChars() string { + if x != nil && x.RequiredChars != nil { + return *x.RequiredChars + } + return "" +} + +func (x *TrainerSpec) GetByteFallback() bool { + if x != nil && x.ByteFallback != nil { + return *x.ByteFallback + } + return Default_TrainerSpec_ByteFallback +} + +func (x *TrainerSpec) GetVocabularyOutputPieceScore() bool { + if x != nil && x.VocabularyOutputPieceScore != nil { + return *x.VocabularyOutputPieceScore + } + return Default_TrainerSpec_VocabularyOutputPieceScore +} + +func (x *TrainerSpec) GetHardVocabLimit() bool { + if x != nil && x.HardVocabLimit != nil { + return *x.HardVocabLimit + } + return Default_TrainerSpec_HardVocabLimit +} + +func (x *TrainerSpec) GetUseAllVocab() bool { + if x != nil && x.UseAllVocab != nil { + return *x.UseAllVocab + } + return Default_TrainerSpec_UseAllVocab +} + +func (x *TrainerSpec) GetUnkId() int32 { + if x != nil && x.UnkId != nil { + return *x.UnkId + } + return Default_TrainerSpec_UnkId +} + +func (x *TrainerSpec) GetBosId() int32 { + if x != nil && x.BosId != nil { + return *x.BosId + } + return Default_TrainerSpec_BosId +} + +func (x *TrainerSpec) GetEosId() int32 { + if x != nil && x.EosId != nil { + return *x.EosId + } + return Default_TrainerSpec_EosId +} + +func (x *TrainerSpec) GetPadId() int32 { + if x != nil && x.PadId != nil { + return *x.PadId + } + return Default_TrainerSpec_PadId +} + +func (x *TrainerSpec) GetUnkPiece() string { + if x != nil && x.UnkPiece != nil { + return *x.UnkPiece + } + return Default_TrainerSpec_UnkPiece +} + +func (x *TrainerSpec) GetBosPiece() string { + if x != nil && x.BosPiece != nil { + return *x.BosPiece + } + return Default_TrainerSpec_BosPiece +} + +func (x *TrainerSpec) GetEosPiece() string { + if x != nil && x.EosPiece != nil { + return *x.EosPiece + } + return Default_TrainerSpec_EosPiece +} + +func (x *TrainerSpec) GetPadPiece() string { + if x != nil && x.PadPiece != nil { + return *x.PadPiece + } + return Default_TrainerSpec_PadPiece +} + +func (x *TrainerSpec) GetUnkSurface() string { + if x != nil && x.UnkSurface != nil { + return *x.UnkSurface + } + return Default_TrainerSpec_UnkSurface +} + +func (x *TrainerSpec) GetTrainExtremelyLargeCorpus() bool { + if x != nil && x.TrainExtremelyLargeCorpus != nil { + return *x.TrainExtremelyLargeCorpus + } + return Default_TrainerSpec_TrainExtremelyLargeCorpus +} + +func (x *TrainerSpec) GetSeedSentencepiecesFile() string { + if x != nil && x.SeedSentencepiecesFile != nil { + return *x.SeedSentencepiecesFile + } + return Default_TrainerSpec_SeedSentencepiecesFile +} + +// NormalizerSpec encodes a various parameters for string normalizaiton +type NormalizerSpec struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + extensionFields protoimpl.ExtensionFields + + // name of normalization rule. + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + // Pre-compiled normalization rule created by + // Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method. + // Usually this field is set by Builder::GetNormalizerSpec() method. + PrecompiledCharsmap []byte `protobuf:"bytes,2,opt,name=precompiled_charsmap,json=precompiledCharsmap" json:"precompiled_charsmap,omitempty"` + // Adds dummy whitespace at the beginning of text in order to + // treat "world" in "world" and "hello world" in the same way. + AddDummyPrefix *bool `protobuf:"varint,3,opt,name=add_dummy_prefix,json=addDummyPrefix,def=1" json:"add_dummy_prefix,omitempty"` + // Removes leading, trailing, and duplicate internal whitespace. + RemoveExtraWhitespaces *bool `protobuf:"varint,4,opt,name=remove_extra_whitespaces,json=removeExtraWhitespaces,def=1" json:"remove_extra_whitespaces,omitempty"` + // Replaces whitespace with meta symbol. + // This field must be true to train sentence piece model. + EscapeWhitespaces *bool `protobuf:"varint,5,opt,name=escape_whitespaces,json=escapeWhitespaces,def=1" json:"escape_whitespaces,omitempty"` + // Custom normalization rule file in TSV format. + // https://github.com/google/sentencepiece/blob/master/doc/normalization.md + // This field is only used in SentencePieceTrainer::Train() method, which + // compiles the rule into the binary rule stored in `precompiled_charsmap`. + NormalizationRuleTsv *string `protobuf:"bytes,6,opt,name=normalization_rule_tsv,json=normalizationRuleTsv" json:"normalization_rule_tsv,omitempty"` +} + +// Default values for NormalizerSpec fields. +const ( + Default_NormalizerSpec_AddDummyPrefix = bool(true) + Default_NormalizerSpec_RemoveExtraWhitespaces = bool(true) + Default_NormalizerSpec_EscapeWhitespaces = bool(true) +) + +func (x *NormalizerSpec) Reset() { + *x = NormalizerSpec{} + if protoimpl.UnsafeEnabled { + mi := &file_sentencepiece_model_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *NormalizerSpec) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*NormalizerSpec) ProtoMessage() {} + +func (x *NormalizerSpec) ProtoReflect() protoreflect.Message { + mi := &file_sentencepiece_model_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use NormalizerSpec.ProtoReflect.Descriptor instead. +func (*NormalizerSpec) Descriptor() ([]byte, []int) { + return file_sentencepiece_model_proto_rawDescGZIP(), []int{1} +} + +func (x *NormalizerSpec) GetName() string { + if x != nil && x.Name != nil { + return *x.Name + } + return "" +} + +func (x *NormalizerSpec) GetPrecompiledCharsmap() []byte { + if x != nil { + return x.PrecompiledCharsmap + } + return nil +} + +func (x *NormalizerSpec) GetAddDummyPrefix() bool { + if x != nil && x.AddDummyPrefix != nil { + return *x.AddDummyPrefix + } + return Default_NormalizerSpec_AddDummyPrefix +} + +func (x *NormalizerSpec) GetRemoveExtraWhitespaces() bool { + if x != nil && x.RemoveExtraWhitespaces != nil { + return *x.RemoveExtraWhitespaces + } + return Default_NormalizerSpec_RemoveExtraWhitespaces +} + +func (x *NormalizerSpec) GetEscapeWhitespaces() bool { + if x != nil && x.EscapeWhitespaces != nil { + return *x.EscapeWhitespaces + } + return Default_NormalizerSpec_EscapeWhitespaces +} + +func (x *NormalizerSpec) GetNormalizationRuleTsv() string { + if x != nil && x.NormalizationRuleTsv != nil { + return *x.NormalizationRuleTsv + } + return "" +} + +// Proto to store samples for self-testing. +type SelfTestData struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + extensionFields protoimpl.ExtensionFields + + Samples []*SelfTestData_Sample `protobuf:"bytes,1,rep,name=samples" json:"samples,omitempty"` +} + +func (x *SelfTestData) Reset() { + *x = SelfTestData{} + if protoimpl.UnsafeEnabled { + mi := &file_sentencepiece_model_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SelfTestData) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SelfTestData) ProtoMessage() {} + +func (x *SelfTestData) ProtoReflect() protoreflect.Message { + mi := &file_sentencepiece_model_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SelfTestData.ProtoReflect.Descriptor instead. +func (*SelfTestData) Descriptor() ([]byte, []int) { + return file_sentencepiece_model_proto_rawDescGZIP(), []int{2} +} + +func (x *SelfTestData) GetSamples() []*SelfTestData_Sample { + if x != nil { + return x.Samples + } + return nil +} + +// ModelProto stores model parameters. +// SentencePieceProcessor is supposed to be self-contained. +// All settings/parameters which may change the behavior must be encoded +// in ModelProto. +type ModelProto struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + extensionFields protoimpl.ExtensionFields + + // Sentence pieces with scores. + Pieces []*ModelProto_SentencePiece `protobuf:"bytes,1,rep,name=pieces" json:"pieces,omitempty"` + // Spec used to generate this model file. + TrainerSpec *TrainerSpec `protobuf:"bytes,2,opt,name=trainer_spec,json=trainerSpec" json:"trainer_spec,omitempty"` + // Spec for text normalization. + NormalizerSpec *NormalizerSpec `protobuf:"bytes,3,opt,name=normalizer_spec,json=normalizerSpec" json:"normalizer_spec,omitempty"` + // Stores sample input and its expected segmentation to verify the model. + SelfTestData *SelfTestData `protobuf:"bytes,4,opt,name=self_test_data,json=selfTestData" json:"self_test_data,omitempty"` + // Spec for text de-normalization. + DenormalizerSpec *NormalizerSpec `protobuf:"bytes,5,opt,name=denormalizer_spec,json=denormalizerSpec" json:"denormalizer_spec,omitempty"` +} + +func (x *ModelProto) Reset() { + *x = ModelProto{} + if protoimpl.UnsafeEnabled { + mi := &file_sentencepiece_model_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ModelProto) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ModelProto) ProtoMessage() {} + +func (x *ModelProto) ProtoReflect() protoreflect.Message { + mi := &file_sentencepiece_model_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ModelProto.ProtoReflect.Descriptor instead. +func (*ModelProto) Descriptor() ([]byte, []int) { + return file_sentencepiece_model_proto_rawDescGZIP(), []int{3} +} + +func (x *ModelProto) GetPieces() []*ModelProto_SentencePiece { + if x != nil { + return x.Pieces + } + return nil +} + +func (x *ModelProto) GetTrainerSpec() *TrainerSpec { + if x != nil { + return x.TrainerSpec + } + return nil +} + +func (x *ModelProto) GetNormalizerSpec() *NormalizerSpec { + if x != nil { + return x.NormalizerSpec + } + return nil +} + +func (x *ModelProto) GetSelfTestData() *SelfTestData { + if x != nil { + return x.SelfTestData + } + return nil +} + +func (x *ModelProto) GetDenormalizerSpec() *NormalizerSpec { + if x != nil { + return x.DenormalizerSpec + } + return nil +} + +type SelfTestData_Sample struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Input *string `protobuf:"bytes,1,opt,name=input" json:"input,omitempty"` + Expected *string `protobuf:"bytes,2,opt,name=expected" json:"expected,omitempty"` +} + +func (x *SelfTestData_Sample) Reset() { + *x = SelfTestData_Sample{} + if protoimpl.UnsafeEnabled { + mi := &file_sentencepiece_model_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SelfTestData_Sample) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SelfTestData_Sample) ProtoMessage() {} + +func (x *SelfTestData_Sample) ProtoReflect() protoreflect.Message { + mi := &file_sentencepiece_model_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SelfTestData_Sample.ProtoReflect.Descriptor instead. +func (*SelfTestData_Sample) Descriptor() ([]byte, []int) { + return file_sentencepiece_model_proto_rawDescGZIP(), []int{2, 0} +} + +func (x *SelfTestData_Sample) GetInput() string { + if x != nil && x.Input != nil { + return *x.Input + } + return "" +} + +func (x *SelfTestData_Sample) GetExpected() string { + if x != nil && x.Expected != nil { + return *x.Expected + } + return "" +} + +type ModelProto_SentencePiece struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + extensionFields protoimpl.ExtensionFields + + Piece *string `protobuf:"bytes,1,opt,name=piece" json:"piece,omitempty"` // piece must not be empty. + Score *float32 `protobuf:"fixed32,2,opt,name=score" json:"score,omitempty"` + Type *ModelProto_SentencePiece_Type `protobuf:"varint,3,opt,name=type,enum=sentencepiece.ModelProto_SentencePiece_Type,def=1" json:"type,omitempty"` +} + +// Default values for ModelProto_SentencePiece fields. +const ( + Default_ModelProto_SentencePiece_Type = ModelProto_SentencePiece_NORMAL +) + +func (x *ModelProto_SentencePiece) Reset() { + *x = ModelProto_SentencePiece{} + if protoimpl.UnsafeEnabled { + mi := &file_sentencepiece_model_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ModelProto_SentencePiece) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ModelProto_SentencePiece) ProtoMessage() {} + +func (x *ModelProto_SentencePiece) ProtoReflect() protoreflect.Message { + mi := &file_sentencepiece_model_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ModelProto_SentencePiece.ProtoReflect.Descriptor instead. +func (*ModelProto_SentencePiece) Descriptor() ([]byte, []int) { + return file_sentencepiece_model_proto_rawDescGZIP(), []int{3, 0} +} + +func (x *ModelProto_SentencePiece) GetPiece() string { + if x != nil && x.Piece != nil { + return *x.Piece + } + return "" +} + +func (x *ModelProto_SentencePiece) GetScore() float32 { + if x != nil && x.Score != nil { + return *x.Score + } + return 0 +} + +func (x *ModelProto_SentencePiece) GetType() ModelProto_SentencePiece_Type { + if x != nil && x.Type != nil { + return *x.Type + } + return Default_ModelProto_SentencePiece_Type +} + +var File_sentencepiece_model_proto protoreflect.FileDescriptor + +var file_sentencepiece_model_proto_rawDesc = []byte{ + 0x0a, 0x19, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x5f, + 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0d, 0x73, 0x65, 0x6e, + 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x22, 0xc6, 0x12, 0x0a, 0x0b, 0x54, + 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, + 0x70, 0x75, 0x74, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x69, 0x6e, 0x70, 0x75, 0x74, + 0x12, 0x21, 0x0a, 0x0c, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x66, 0x6f, 0x72, 0x6d, 0x61, 0x74, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x46, 0x6f, 0x72, + 0x6d, 0x61, 0x74, 0x12, 0x21, 0x0a, 0x0c, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x70, 0x72, 0x65, + 0x66, 0x69, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x6d, 0x6f, 0x64, 0x65, 0x6c, + 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x4c, 0x0a, 0x0a, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, + 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x24, 0x2e, 0x73, 0x65, 0x6e, + 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x2e, 0x54, 0x72, 0x61, 0x69, 0x6e, + 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x79, 0x70, 0x65, + 0x3a, 0x07, 0x55, 0x4e, 0x49, 0x47, 0x52, 0x41, 0x4d, 0x52, 0x09, 0x6d, 0x6f, 0x64, 0x65, 0x6c, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x23, 0x0a, 0x0a, 0x76, 0x6f, 0x63, 0x61, 0x62, 0x5f, 0x73, 0x69, + 0x7a, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x3a, 0x04, 0x38, 0x30, 0x30, 0x30, 0x52, 0x09, + 0x76, 0x6f, 0x63, 0x61, 0x62, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x27, 0x0a, 0x0f, 0x61, 0x63, 0x63, + 0x65, 0x70, 0x74, 0x5f, 0x6c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x0e, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x4c, 0x61, 0x6e, 0x67, 0x75, 0x61, + 0x67, 0x65, 0x12, 0x34, 0x0a, 0x15, 0x73, 0x65, 0x6c, 0x66, 0x5f, 0x74, 0x65, 0x73, 0x74, 0x5f, + 0x73, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x05, 0x3a, 0x01, 0x30, 0x52, 0x12, 0x73, 0x65, 0x6c, 0x66, 0x54, 0x65, 0x73, 0x74, 0x53, 0x61, + 0x6d, 0x70, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x45, 0x0a, 0x1b, 0x65, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x5f, 0x64, 0x69, 0x66, 0x66, 0x65, 0x72, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, + 0x70, 0x72, 0x69, 0x76, 0x61, 0x63, 0x79, 0x18, 0x32, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, + 0x61, 0x6c, 0x73, 0x65, 0x52, 0x19, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x44, 0x69, 0x66, 0x66, + 0x65, 0x72, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x50, 0x72, 0x69, 0x76, 0x61, 0x63, 0x79, 0x12, + 0x4a, 0x0a, 0x20, 0x64, 0x69, 0x66, 0x66, 0x65, 0x72, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, + 0x70, 0x72, 0x69, 0x76, 0x61, 0x63, 0x79, 0x5f, 0x6e, 0x6f, 0x69, 0x73, 0x65, 0x5f, 0x6c, 0x65, + 0x76, 0x65, 0x6c, 0x18, 0x33, 0x20, 0x01, 0x28, 0x02, 0x3a, 0x01, 0x30, 0x52, 0x1d, 0x64, 0x69, + 0x66, 0x66, 0x65, 0x72, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x50, 0x72, 0x69, 0x76, 0x61, 0x63, + 0x79, 0x4e, 0x6f, 0x69, 0x73, 0x65, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x58, 0x0a, 0x27, 0x64, + 0x69, 0x66, 0x66, 0x65, 0x72, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x70, 0x72, 0x69, 0x76, + 0x61, 0x63, 0x79, 0x5f, 0x63, 0x6c, 0x69, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x5f, 0x74, 0x68, 0x72, + 0x65, 0x73, 0x68, 0x6f, 0x6c, 0x64, 0x18, 0x34, 0x20, 0x01, 0x28, 0x04, 0x3a, 0x01, 0x30, 0x52, + 0x24, 0x64, 0x69, 0x66, 0x66, 0x65, 0x72, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x50, 0x72, 0x69, + 0x76, 0x61, 0x63, 0x79, 0x43, 0x6c, 0x69, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x54, 0x68, 0x72, 0x65, + 0x73, 0x68, 0x6f, 0x6c, 0x64, 0x12, 0x35, 0x0a, 0x12, 0x63, 0x68, 0x61, 0x72, 0x61, 0x63, 0x74, + 0x65, 0x72, 0x5f, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x61, 0x67, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, + 0x02, 0x3a, 0x06, 0x30, 0x2e, 0x39, 0x39, 0x39, 0x35, 0x52, 0x11, 0x63, 0x68, 0x61, 0x72, 0x61, + 0x63, 0x74, 0x65, 0x72, 0x43, 0x6f, 0x76, 0x65, 0x72, 0x61, 0x67, 0x65, 0x12, 0x31, 0x0a, 0x13, + 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x5f, 0x73, + 0x69, 0x7a, 0x65, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x3a, 0x01, 0x30, 0x52, 0x11, 0x69, 0x6e, + 0x70, 0x75, 0x74, 0x53, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, + 0x3a, 0x0a, 0x16, 0x73, 0x68, 0x75, 0x66, 0x66, 0x6c, 0x65, 0x5f, 0x69, 0x6e, 0x70, 0x75, 0x74, + 0x5f, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x13, 0x20, 0x01, 0x28, 0x08, 0x3a, + 0x04, 0x74, 0x72, 0x75, 0x65, 0x52, 0x14, 0x73, 0x68, 0x75, 0x66, 0x66, 0x6c, 0x65, 0x49, 0x6e, + 0x70, 0x75, 0x74, 0x53, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x34, 0x0a, 0x14, 0x6d, + 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x5f, 0x73, + 0x69, 0x7a, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x05, 0x42, 0x02, 0x18, 0x01, 0x52, 0x12, 0x6d, + 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x53, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x53, 0x69, 0x7a, + 0x65, 0x12, 0x38, 0x0a, 0x16, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x5f, 0x73, 0x65, + 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x0d, 0x20, 0x01, 0x28, + 0x05, 0x42, 0x02, 0x18, 0x01, 0x52, 0x14, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x53, + 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x3f, 0x0a, 0x17, 0x73, + 0x65, 0x65, 0x64, 0x5f, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, + 0x65, 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x05, 0x3a, 0x07, 0x31, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x52, 0x15, 0x73, 0x65, 0x65, 0x64, 0x53, 0x65, 0x6e, 0x74, 0x65, + 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x2f, 0x0a, 0x10, + 0x73, 0x68, 0x72, 0x69, 0x6e, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x61, 0x63, 0x74, 0x6f, 0x72, + 0x18, 0x0f, 0x20, 0x01, 0x28, 0x02, 0x3a, 0x04, 0x30, 0x2e, 0x37, 0x35, 0x52, 0x0f, 0x73, 0x68, + 0x72, 0x69, 0x6e, 0x6b, 0x69, 0x6e, 0x67, 0x46, 0x61, 0x63, 0x74, 0x6f, 0x72, 0x12, 0x34, 0x0a, + 0x13, 0x6d, 0x61, 0x78, 0x5f, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x5f, 0x6c, 0x65, + 0x6e, 0x67, 0x74, 0x68, 0x18, 0x12, 0x20, 0x01, 0x28, 0x05, 0x3a, 0x04, 0x34, 0x31, 0x39, 0x32, + 0x52, 0x11, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x4c, 0x65, 0x6e, + 0x67, 0x74, 0x68, 0x12, 0x23, 0x0a, 0x0b, 0x6e, 0x75, 0x6d, 0x5f, 0x74, 0x68, 0x72, 0x65, 0x61, + 0x64, 0x73, 0x18, 0x10, 0x20, 0x01, 0x28, 0x05, 0x3a, 0x02, 0x31, 0x36, 0x52, 0x0a, 0x6e, 0x75, + 0x6d, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12, 0x2f, 0x0a, 0x12, 0x6e, 0x75, 0x6d, 0x5f, + 0x73, 0x75, 0x62, 0x5f, 0x69, 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x11, + 0x20, 0x01, 0x28, 0x05, 0x3a, 0x01, 0x32, 0x52, 0x10, 0x6e, 0x75, 0x6d, 0x53, 0x75, 0x62, 0x49, + 0x74, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x3c, 0x0a, 0x18, 0x6d, 0x61, 0x78, + 0x5f, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x5f, 0x6c, + 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x14, 0x20, 0x01, 0x28, 0x05, 0x3a, 0x02, 0x31, 0x36, 0x52, + 0x16, 0x6d, 0x61, 0x78, 0x53, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, + 0x65, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x12, 0x3b, 0x0a, 0x17, 0x73, 0x70, 0x6c, 0x69, 0x74, + 0x5f, 0x62, 0x79, 0x5f, 0x75, 0x6e, 0x69, 0x63, 0x6f, 0x64, 0x65, 0x5f, 0x73, 0x63, 0x72, 0x69, + 0x70, 0x74, 0x18, 0x15, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x04, 0x74, 0x72, 0x75, 0x65, 0x52, 0x14, + 0x73, 0x70, 0x6c, 0x69, 0x74, 0x42, 0x79, 0x55, 0x6e, 0x69, 0x63, 0x6f, 0x64, 0x65, 0x53, 0x63, + 0x72, 0x69, 0x70, 0x74, 0x12, 0x2c, 0x0a, 0x0f, 0x73, 0x70, 0x6c, 0x69, 0x74, 0x5f, 0x62, 0x79, + 0x5f, 0x6e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x17, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x04, 0x74, + 0x72, 0x75, 0x65, 0x52, 0x0d, 0x73, 0x70, 0x6c, 0x69, 0x74, 0x42, 0x79, 0x4e, 0x75, 0x6d, 0x62, + 0x65, 0x72, 0x12, 0x34, 0x0a, 0x13, 0x73, 0x70, 0x6c, 0x69, 0x74, 0x5f, 0x62, 0x79, 0x5f, 0x77, + 0x68, 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x16, 0x20, 0x01, 0x28, 0x08, 0x3a, + 0x04, 0x74, 0x72, 0x75, 0x65, 0x52, 0x11, 0x73, 0x70, 0x6c, 0x69, 0x74, 0x42, 0x79, 0x57, 0x68, + 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x42, 0x0a, 0x1a, 0x74, 0x72, 0x65, 0x61, + 0x74, 0x5f, 0x77, 0x68, 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x5f, 0x61, 0x73, 0x5f, + 0x73, 0x75, 0x66, 0x66, 0x69, 0x78, 0x18, 0x18, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, + 0x6c, 0x73, 0x65, 0x52, 0x17, 0x74, 0x72, 0x65, 0x61, 0x74, 0x57, 0x68, 0x69, 0x74, 0x65, 0x73, + 0x70, 0x61, 0x63, 0x65, 0x41, 0x73, 0x53, 0x75, 0x66, 0x66, 0x69, 0x78, 0x12, 0x46, 0x0a, 0x1c, + 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x5f, 0x77, 0x68, 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, + 0x5f, 0x6f, 0x6e, 0x6c, 0x79, 0x5f, 0x70, 0x69, 0x65, 0x63, 0x65, 0x73, 0x18, 0x1a, 0x20, 0x01, + 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x19, 0x61, 0x6c, 0x6c, 0x6f, 0x77, + 0x57, 0x68, 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x4f, 0x6e, 0x6c, 0x79, 0x50, 0x69, + 0x65, 0x63, 0x65, 0x73, 0x12, 0x28, 0x0a, 0x0c, 0x73, 0x70, 0x6c, 0x69, 0x74, 0x5f, 0x64, 0x69, + 0x67, 0x69, 0x74, 0x73, 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, + 0x65, 0x52, 0x0b, 0x73, 0x70, 0x6c, 0x69, 0x74, 0x44, 0x69, 0x67, 0x69, 0x74, 0x73, 0x12, 0x3d, + 0x0a, 0x19, 0x70, 0x72, 0x65, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x5f, 0x64, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x72, 0x18, 0x35, 0x20, 0x01, 0x28, + 0x09, 0x3a, 0x00, 0x52, 0x18, 0x70, 0x72, 0x65, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x69, 0x7a, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x44, 0x65, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x72, 0x12, 0x27, 0x0a, + 0x0f, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x5f, 0x73, 0x79, 0x6d, 0x62, 0x6f, 0x6c, 0x73, + 0x18, 0x1e, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x53, + 0x79, 0x6d, 0x62, 0x6f, 0x6c, 0x73, 0x12, 0x30, 0x0a, 0x14, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x64, + 0x65, 0x66, 0x69, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x79, 0x6d, 0x62, 0x6f, 0x6c, 0x73, 0x18, 0x1f, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x12, 0x75, 0x73, 0x65, 0x72, 0x44, 0x65, 0x66, 0x69, 0x6e, 0x65, + 0x64, 0x53, 0x79, 0x6d, 0x62, 0x6f, 0x6c, 0x73, 0x12, 0x25, 0x0a, 0x0e, 0x72, 0x65, 0x71, 0x75, + 0x69, 0x72, 0x65, 0x64, 0x5f, 0x63, 0x68, 0x61, 0x72, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x43, 0x68, 0x61, 0x72, 0x73, 0x12, + 0x2a, 0x0a, 0x0d, 0x62, 0x79, 0x74, 0x65, 0x5f, 0x66, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, + 0x18, 0x23, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0c, 0x62, + 0x79, 0x74, 0x65, 0x46, 0x61, 0x6c, 0x6c, 0x62, 0x61, 0x63, 0x6b, 0x12, 0x47, 0x0a, 0x1d, 0x76, + 0x6f, 0x63, 0x61, 0x62, 0x75, 0x6c, 0x61, 0x72, 0x79, 0x5f, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, + 0x5f, 0x70, 0x69, 0x65, 0x63, 0x65, 0x5f, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x18, 0x20, 0x20, 0x01, + 0x28, 0x08, 0x3a, 0x04, 0x74, 0x72, 0x75, 0x65, 0x52, 0x1a, 0x76, 0x6f, 0x63, 0x61, 0x62, 0x75, + 0x6c, 0x61, 0x72, 0x79, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x50, 0x69, 0x65, 0x63, 0x65, 0x53, + 0x63, 0x6f, 0x72, 0x65, 0x12, 0x2e, 0x0a, 0x10, 0x68, 0x61, 0x72, 0x64, 0x5f, 0x76, 0x6f, 0x63, + 0x61, 0x62, 0x5f, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x21, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x04, + 0x74, 0x72, 0x75, 0x65, 0x52, 0x0e, 0x68, 0x61, 0x72, 0x64, 0x56, 0x6f, 0x63, 0x61, 0x62, 0x4c, + 0x69, 0x6d, 0x69, 0x74, 0x12, 0x29, 0x0a, 0x0d, 0x75, 0x73, 0x65, 0x5f, 0x61, 0x6c, 0x6c, 0x5f, + 0x76, 0x6f, 0x63, 0x61, 0x62, 0x18, 0x22, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, + 0x73, 0x65, 0x52, 0x0b, 0x75, 0x73, 0x65, 0x41, 0x6c, 0x6c, 0x56, 0x6f, 0x63, 0x61, 0x62, 0x12, + 0x18, 0x0a, 0x06, 0x75, 0x6e, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x28, 0x20, 0x01, 0x28, 0x05, 0x3a, + 0x01, 0x30, 0x52, 0x05, 0x75, 0x6e, 0x6b, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x06, 0x62, 0x6f, 0x73, + 0x5f, 0x69, 0x64, 0x18, 0x29, 0x20, 0x01, 0x28, 0x05, 0x3a, 0x01, 0x31, 0x52, 0x05, 0x62, 0x6f, + 0x73, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x06, 0x65, 0x6f, 0x73, 0x5f, 0x69, 0x64, 0x18, 0x2a, 0x20, + 0x01, 0x28, 0x05, 0x3a, 0x01, 0x32, 0x52, 0x05, 0x65, 0x6f, 0x73, 0x49, 0x64, 0x12, 0x19, 0x0a, + 0x06, 0x70, 0x61, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x2b, 0x20, 0x01, 0x28, 0x05, 0x3a, 0x02, 0x2d, + 0x31, 0x52, 0x05, 0x70, 0x61, 0x64, 0x49, 0x64, 0x12, 0x22, 0x0a, 0x09, 0x75, 0x6e, 0x6b, 0x5f, + 0x70, 0x69, 0x65, 0x63, 0x65, 0x18, 0x2d, 0x20, 0x01, 0x28, 0x09, 0x3a, 0x05, 0x3c, 0x75, 0x6e, + 0x6b, 0x3e, 0x52, 0x08, 0x75, 0x6e, 0x6b, 0x50, 0x69, 0x65, 0x63, 0x65, 0x12, 0x20, 0x0a, 0x09, + 0x62, 0x6f, 0x73, 0x5f, 0x70, 0x69, 0x65, 0x63, 0x65, 0x18, 0x2e, 0x20, 0x01, 0x28, 0x09, 0x3a, + 0x03, 0x3c, 0x73, 0x3e, 0x52, 0x08, 0x62, 0x6f, 0x73, 0x50, 0x69, 0x65, 0x63, 0x65, 0x12, 0x21, + 0x0a, 0x09, 0x65, 0x6f, 0x73, 0x5f, 0x70, 0x69, 0x65, 0x63, 0x65, 0x18, 0x2f, 0x20, 0x01, 0x28, + 0x09, 0x3a, 0x04, 0x3c, 0x2f, 0x73, 0x3e, 0x52, 0x08, 0x65, 0x6f, 0x73, 0x50, 0x69, 0x65, 0x63, + 0x65, 0x12, 0x22, 0x0a, 0x09, 0x70, 0x61, 0x64, 0x5f, 0x70, 0x69, 0x65, 0x63, 0x65, 0x18, 0x30, + 0x20, 0x01, 0x28, 0x09, 0x3a, 0x05, 0x3c, 0x70, 0x61, 0x64, 0x3e, 0x52, 0x08, 0x70, 0x61, 0x64, + 0x50, 0x69, 0x65, 0x63, 0x65, 0x12, 0x26, 0x0a, 0x0b, 0x75, 0x6e, 0x6b, 0x5f, 0x73, 0x75, 0x72, + 0x66, 0x61, 0x63, 0x65, 0x18, 0x2c, 0x20, 0x01, 0x28, 0x09, 0x3a, 0x05, 0x20, 0xe2, 0x81, 0x87, + 0x20, 0x52, 0x0a, 0x75, 0x6e, 0x6b, 0x53, 0x75, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x46, 0x0a, + 0x1c, 0x74, 0x72, 0x61, 0x69, 0x6e, 0x5f, 0x65, 0x78, 0x74, 0x72, 0x65, 0x6d, 0x65, 0x6c, 0x79, + 0x5f, 0x6c, 0x61, 0x72, 0x67, 0x65, 0x5f, 0x63, 0x6f, 0x72, 0x70, 0x75, 0x73, 0x18, 0x31, 0x20, + 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x19, 0x74, 0x72, 0x61, 0x69, + 0x6e, 0x45, 0x78, 0x74, 0x72, 0x65, 0x6d, 0x65, 0x6c, 0x79, 0x4c, 0x61, 0x72, 0x67, 0x65, 0x43, + 0x6f, 0x72, 0x70, 0x75, 0x73, 0x12, 0x3a, 0x0a, 0x18, 0x73, 0x65, 0x65, 0x64, 0x5f, 0x73, 0x65, + 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x73, 0x5f, 0x66, 0x69, 0x6c, + 0x65, 0x18, 0x36, 0x20, 0x01, 0x28, 0x09, 0x3a, 0x00, 0x52, 0x16, 0x73, 0x65, 0x65, 0x64, 0x53, + 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x73, 0x46, 0x69, 0x6c, + 0x65, 0x22, 0x35, 0x0a, 0x09, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0b, + 0x0a, 0x07, 0x55, 0x4e, 0x49, 0x47, 0x52, 0x41, 0x4d, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x42, + 0x50, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x4f, 0x52, 0x44, 0x10, 0x03, 0x12, 0x08, + 0x0a, 0x04, 0x43, 0x48, 0x41, 0x52, 0x10, 0x04, 0x2a, 0x09, 0x08, 0xc8, 0x01, 0x10, 0x80, 0x80, + 0x80, 0x80, 0x02, 0x22, 0xbd, 0x02, 0x0a, 0x0e, 0x4e, 0x6f, 0x72, 0x6d, 0x61, 0x6c, 0x69, 0x7a, + 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x31, 0x0a, 0x14, 0x70, 0x72, + 0x65, 0x63, 0x6f, 0x6d, 0x70, 0x69, 0x6c, 0x65, 0x64, 0x5f, 0x63, 0x68, 0x61, 0x72, 0x73, 0x6d, + 0x61, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x13, 0x70, 0x72, 0x65, 0x63, 0x6f, 0x6d, + 0x70, 0x69, 0x6c, 0x65, 0x64, 0x43, 0x68, 0x61, 0x72, 0x73, 0x6d, 0x61, 0x70, 0x12, 0x2e, 0x0a, + 0x10, 0x61, 0x64, 0x64, 0x5f, 0x64, 0x75, 0x6d, 0x6d, 0x79, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, + 0x78, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x04, 0x74, 0x72, 0x75, 0x65, 0x52, 0x0e, 0x61, + 0x64, 0x64, 0x44, 0x75, 0x6d, 0x6d, 0x79, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x3e, 0x0a, + 0x18, 0x72, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x5f, 0x65, 0x78, 0x74, 0x72, 0x61, 0x5f, 0x77, 0x68, + 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x3a, + 0x04, 0x74, 0x72, 0x75, 0x65, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x45, 0x78, 0x74, + 0x72, 0x61, 0x57, 0x68, 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x12, 0x33, 0x0a, + 0x12, 0x65, 0x73, 0x63, 0x61, 0x70, 0x65, 0x5f, 0x77, 0x68, 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, + 0x63, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x04, 0x74, 0x72, 0x75, 0x65, 0x52, + 0x11, 0x65, 0x73, 0x63, 0x61, 0x70, 0x65, 0x57, 0x68, 0x69, 0x74, 0x65, 0x73, 0x70, 0x61, 0x63, + 0x65, 0x73, 0x12, 0x34, 0x0a, 0x16, 0x6e, 0x6f, 0x72, 0x6d, 0x61, 0x6c, 0x69, 0x7a, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x5f, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x74, 0x73, 0x76, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x14, 0x6e, 0x6f, 0x72, 0x6d, 0x61, 0x6c, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x75, 0x6c, 0x65, 0x54, 0x73, 0x76, 0x2a, 0x09, 0x08, 0xc8, 0x01, 0x10, 0x80, 0x80, + 0x80, 0x80, 0x02, 0x22, 0x93, 0x01, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x66, 0x54, 0x65, 0x73, 0x74, + 0x44, 0x61, 0x74, 0x61, 0x12, 0x3c, 0x0a, 0x07, 0x73, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, + 0x70, 0x69, 0x65, 0x63, 0x65, 0x2e, 0x53, 0x65, 0x6c, 0x66, 0x54, 0x65, 0x73, 0x74, 0x44, 0x61, + 0x74, 0x61, 0x2e, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x07, 0x73, 0x61, 0x6d, 0x70, 0x6c, + 0x65, 0x73, 0x1a, 0x3a, 0x0a, 0x06, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x69, 0x6e, 0x70, 0x75, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x69, 0x6e, 0x70, + 0x75, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x78, 0x70, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x78, 0x70, 0x65, 0x63, 0x74, 0x65, 0x64, 0x2a, 0x09, + 0x08, 0xc8, 0x01, 0x10, 0x80, 0x80, 0x80, 0x80, 0x02, 0x22, 0xd7, 0x04, 0x0a, 0x0a, 0x4d, 0x6f, + 0x64, 0x65, 0x6c, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x3f, 0x0a, 0x06, 0x70, 0x69, 0x65, 0x63, + 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x27, 0x2e, 0x73, 0x65, 0x6e, 0x74, 0x65, + 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x50, 0x72, + 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x50, 0x69, 0x65, 0x63, + 0x65, 0x52, 0x06, 0x70, 0x69, 0x65, 0x63, 0x65, 0x73, 0x12, 0x3d, 0x0a, 0x0c, 0x74, 0x72, 0x61, + 0x69, 0x6e, 0x65, 0x72, 0x5f, 0x73, 0x70, 0x65, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x2e, + 0x54, 0x72, 0x61, 0x69, 0x6e, 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, 0x52, 0x0b, 0x74, 0x72, 0x61, + 0x69, 0x6e, 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, 0x12, 0x46, 0x0a, 0x0f, 0x6e, 0x6f, 0x72, 0x6d, + 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x72, 0x5f, 0x73, 0x70, 0x65, 0x63, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1d, 0x2e, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, + 0x65, 0x2e, 0x4e, 0x6f, 0x72, 0x6d, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, + 0x52, 0x0e, 0x6e, 0x6f, 0x72, 0x6d, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, + 0x12, 0x41, 0x0a, 0x0e, 0x73, 0x65, 0x6c, 0x66, 0x5f, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x64, 0x61, + 0x74, 0x61, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x73, 0x65, 0x6e, 0x74, 0x65, + 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x2e, 0x53, 0x65, 0x6c, 0x66, 0x54, 0x65, 0x73, + 0x74, 0x44, 0x61, 0x74, 0x61, 0x52, 0x0c, 0x73, 0x65, 0x6c, 0x66, 0x54, 0x65, 0x73, 0x74, 0x44, + 0x61, 0x74, 0x61, 0x12, 0x4a, 0x0a, 0x11, 0x64, 0x65, 0x6e, 0x6f, 0x72, 0x6d, 0x61, 0x6c, 0x69, + 0x7a, 0x65, 0x72, 0x5f, 0x73, 0x70, 0x65, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, + 0x2e, 0x73, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x2e, 0x4e, + 0x6f, 0x72, 0x6d, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, 0x52, 0x10, 0x64, + 0x65, 0x6e, 0x6f, 0x72, 0x6d, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x72, 0x53, 0x70, 0x65, 0x63, 0x1a, + 0xe6, 0x01, 0x0a, 0x0d, 0x53, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x50, 0x69, 0x65, 0x63, + 0x65, 0x12, 0x14, 0x0a, 0x05, 0x70, 0x69, 0x65, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x70, 0x69, 0x65, 0x63, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x63, 0x6f, 0x72, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x02, 0x52, 0x05, 0x73, 0x63, 0x6f, 0x72, 0x65, 0x12, 0x48, 0x0a, + 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x73, 0x65, + 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, 0x2e, 0x4d, 0x6f, 0x64, 0x65, + 0x6c, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x53, 0x65, 0x6e, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x50, + 0x69, 0x65, 0x63, 0x65, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x3a, 0x06, 0x4e, 0x4f, 0x52, 0x4d, 0x41, + 0x4c, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x54, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, + 0x0a, 0x0a, 0x06, 0x4e, 0x4f, 0x52, 0x4d, 0x41, 0x4c, 0x10, 0x01, 0x12, 0x0b, 0x0a, 0x07, 0x55, + 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x43, 0x4f, 0x4e, 0x54, + 0x52, 0x4f, 0x4c, 0x10, 0x03, 0x12, 0x10, 0x0a, 0x0c, 0x55, 0x53, 0x45, 0x52, 0x5f, 0x44, 0x45, + 0x46, 0x49, 0x4e, 0x45, 0x44, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x42, 0x59, 0x54, 0x45, 0x10, + 0x06, 0x12, 0x0a, 0x0a, 0x06, 0x55, 0x4e, 0x55, 0x53, 0x45, 0x44, 0x10, 0x05, 0x2a, 0x09, 0x08, + 0xc8, 0x01, 0x10, 0x80, 0x80, 0x80, 0x80, 0x02, 0x2a, 0x09, 0x08, 0xc8, 0x01, 0x10, 0x80, 0x80, + 0x80, 0x80, 0x02, 0x42, 0x13, 0x48, 0x03, 0x5a, 0x0f, 0x2e, 0x2f, 0x73, 0x65, 0x6e, 0x74, 0x65, + 0x6e, 0x63, 0x65, 0x70, 0x69, 0x65, 0x63, 0x65, +} + +var ( + file_sentencepiece_model_proto_rawDescOnce sync.Once + file_sentencepiece_model_proto_rawDescData = file_sentencepiece_model_proto_rawDesc +) + +func file_sentencepiece_model_proto_rawDescGZIP() []byte { + file_sentencepiece_model_proto_rawDescOnce.Do(func() { + file_sentencepiece_model_proto_rawDescData = protoimpl.X.CompressGZIP(file_sentencepiece_model_proto_rawDescData) + }) + return file_sentencepiece_model_proto_rawDescData +} + +var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_sentencepiece_model_proto_goTypes = []interface{}{ + (TrainerSpec_ModelType)(0), // 0: sentencepiece.TrainerSpec.ModelType + (ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type + (*TrainerSpec)(nil), // 2: sentencepiece.TrainerSpec + (*NormalizerSpec)(nil), // 3: sentencepiece.NormalizerSpec + (*SelfTestData)(nil), // 4: sentencepiece.SelfTestData + (*ModelProto)(nil), // 5: sentencepiece.ModelProto + (*SelfTestData_Sample)(nil), // 6: sentencepiece.SelfTestData.Sample + (*ModelProto_SentencePiece)(nil), // 7: sentencepiece.ModelProto.SentencePiece +} +var file_sentencepiece_model_proto_depIdxs = []int32{ + 0, // 0: sentencepiece.TrainerSpec.model_type:type_name -> sentencepiece.TrainerSpec.ModelType + 6, // 1: sentencepiece.SelfTestData.samples:type_name -> sentencepiece.SelfTestData.Sample + 7, // 2: sentencepiece.ModelProto.pieces:type_name -> sentencepiece.ModelProto.SentencePiece + 2, // 3: sentencepiece.ModelProto.trainer_spec:type_name -> sentencepiece.TrainerSpec + 3, // 4: sentencepiece.ModelProto.normalizer_spec:type_name -> sentencepiece.NormalizerSpec + 4, // 5: sentencepiece.ModelProto.self_test_data:type_name -> sentencepiece.SelfTestData + 3, // 6: sentencepiece.ModelProto.denormalizer_spec:type_name -> sentencepiece.NormalizerSpec + 1, // 7: sentencepiece.ModelProto.SentencePiece.type:type_name -> sentencepiece.ModelProto.SentencePiece.Type + 8, // [8:8] is the sub-list for method output_type + 8, // [8:8] is the sub-list for method input_type + 8, // [8:8] is the sub-list for extension type_name + 8, // [8:8] is the sub-list for extension extendee + 0, // [0:8] is the sub-list for field type_name +} + +func init() { file_sentencepiece_model_proto_init() } +func file_sentencepiece_model_proto_init() { + if File_sentencepiece_model_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TrainerSpec); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + case 3: + return &v.extensionFields + default: + return nil + } + } + file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*NormalizerSpec); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + case 3: + return &v.extensionFields + default: + return nil + } + } + file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SelfTestData); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + case 3: + return &v.extensionFields + default: + return nil + } + } + file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ModelProto); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + case 3: + return &v.extensionFields + default: + return nil + } + } + file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SelfTestData_Sample); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ModelProto_SentencePiece); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + case 3: + return &v.extensionFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_sentencepiece_model_proto_rawDesc, + NumEnums: 2, + NumMessages: 6, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_sentencepiece_model_proto_goTypes, + DependencyIndexes: file_sentencepiece_model_proto_depIdxs, + EnumInfos: file_sentencepiece_model_proto_enumTypes, + MessageInfos: file_sentencepiece_model_proto_msgTypes, + }.Build() + File_sentencepiece_model_proto = out.File + file_sentencepiece_model_proto_rawDesc = nil + file_sentencepiece_model_proto_goTypes = nil + file_sentencepiece_model_proto_depIdxs = nil +} diff --git a/convert/sentencepiece_model.proto b/convert/sentencepiece_model.proto new file mode 100644 index 00000000..5dc02d6c --- /dev/null +++ b/convert/sentencepiece_model.proto @@ -0,0 +1,333 @@ +// Copyright 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.! + +syntax = "proto2"; + +// TODO(taku): Needs to use LITE RUNTIME in OSS release. +option optimize_for = LITE_RUNTIME; +option go_package = "./sentencepiece"; + +package sentencepiece; + +// TrainerSpec encodes a various parameters for SentencePiece training. +// Next id: 55 +message TrainerSpec { + /////////////////////////////////////////////////////////////////// + // General parameters + // + // Input corpus files. + // Trainer accepts the following two formats: + // A) Monolingual: plain text, one sentence per line. + // B) Bilingual: TSV, source sentence target sentence + // When bilingual data is passed, shared vocabulary model is built. + // Note that the input file must be raw corpus, not a preprocessed corpus. + // Trainer only loads the first `input_sentence_size` sentences specified + // with this parameter. + repeated string input = 1; + + // Input corpus format: + // "text": one-sentence-per-line text format (default) + // "tsv": sentence freq + optional string input_format = 7; + + // Output model file prefix. + // .model and .vocab are generated. + optional string model_prefix = 2; + + // Model type. only have UNIGRAM now. + enum ModelType { + UNIGRAM = 1; // Unigram language model with dynamic algorithm + BPE = 2; // Byte Pair Encoding + WORD = 3; // Delimitered by whitespace. + CHAR = 4; // tokenizes into character sequence + } + optional ModelType model_type = 3 [default = UNIGRAM]; + + // Vocabulary size. 8k is the default size. + optional int32 vocab_size = 4 [default = 8000]; + + // List of the languages this model can accept. + // Since the model is language-agnostic, this field is used as a reference. + repeated string accept_language = 5; + + // Size of self-test samples, which are encoded in the model file. + optional int32 self_test_sample_size = 6 [default = 0]; + + // Whether to use DP version of sentencepiece. Use it with TSV input format + // (requires precomputed word tab counts to work). + optional bool enable_differential_privacy = 50 [default = false]; + // Set these parameters if you need DP version of sentencepiece. + // std of noise to add. + optional float differential_privacy_noise_level = 51 [default = 0.0]; + // Clipping threshold to apply after adding noise. All the words with + // frequency less than this value are dropped. + optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; + + /////////////////////////////////////////////////////////////////// + // Training parameters. + // + // Uses characters which cover the corpus with the ratio of `chars_coverage`. + // This parameter determines the set of basic Alphabet of sentence piece. + // 1.0 - `chars_coverage` characters are treated as UNK. + // See also required_chars field. + optional float character_coverage = 10 [default = 0.9995]; + + // Maximum size of sentences the trainer loads from `input` parameter. + // Trainer simply loads the `input` files in sequence. + // It is better to shuffle the input corpus randomly. + optional uint64 input_sentence_size = 11 [default = 0]; + optional bool shuffle_input_sentence = 19 [default = true]; + + // Maximum size of sentences to make seed sentence pieces. + // Extended suffix array is constructed to extract frequent + // sub-strings from the corpus. This uses 20N working space, + // where N is the size of corpus. + optional int32 mining_sentence_size = 12 [deprecated = true]; + + // Maximum size of sentences to train sentence pieces. + optional int32 training_sentence_size = 13 [deprecated = true]; + + // The size of seed sentencepieces. + // `seed_sentencepiece_size` must be larger than `vocab_size`. + optional int32 seed_sentencepiece_size = 14 [default = 1000000]; + + // In every EM sub-iterations, keeps top + // `shrinking_factor` * `current sentencepieces size` with respect to + // the loss of the sentence piece. This value should be smaller than 1.0. + optional float shrinking_factor = 15 [default = 0.75]; + + // The maximum sentence length in byte. The sentences with the length + // larger than `max_sentence_length` is simply ignored. + // Longer input tends to bring the following risks: + // * Overflow during EM training (unigram language model only) + // * Performance drop because of O(n log n) cost in BPE. + optional int32 max_sentence_length = 18 [default = 4192]; + + // Number of threads in the training. + optional int32 num_threads = 16 [default = 16]; + + // Number of EM sub iterations. + optional int32 num_sub_iterations = 17 [default = 2]; + + /////////////////////////////////////////////////////////////////// + // SentencePiece parameters which control the shapes of sentence piece. + // + // Maximum length of sentencepiece. + optional int32 max_sentencepiece_length = 20 [default = 16]; + + // Uses Unicode script to split sentence pieces. + // When `split_by_unicode_script` is true, we do not allow sentence piece to + // include multiple Unicode scripts, e.g. "F1" is not a valid piece. + // Exception: CJ characters (Hiragana/Katakana/Han) are all handled + // as one script type, since Japanese word can consist of multiple scripts. + // This exception is always applied regardless of the accept-language + // parameter. + optional bool split_by_unicode_script = 21 [default = true]; + + // When `split_by_number` is true, put a boundary between number and + // non-number transition. If we want to treat "F1" is one token, set this flag + // to be false. + optional bool split_by_number = 23 [default = true]; + + // Use a white space to split sentence pieces. + // When `split_by_whitespace` is false, we may have the piece containing + // a white space in the middle. e.g., "in_the". + optional bool split_by_whitespace = 22 [default = true]; + + // Adds whitespace symbol (_) as a suffix instead of prefix. e.g., _hello => + // hello_. When `treat_whitespace_as_suffix` is true, + // NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end + // of sentence. + optional bool treat_whitespace_as_suffix = 24 [default = false]; + + // Allows pieces that only contain whitespaces instead of appearing only as + // prefix or suffix of other pieces. + optional bool allow_whitespace_only_pieces = 26 [default = false]; + + // Split all digits (0-9) into separate pieces. + optional bool split_digits = 25 [default = false]; + + // Defines the pre-tokenization delimiter. + // When specified, no pieces crossing this delimiter is not included + // in the vocab. Then the delimiter string is virtually ignored + // during the training. This field can allows constraints on the vocabulary + // selection. Note that this field is available on unigram mode. + optional string pretokenization_delimiter = 53 [ default = ""]; + + /////////////////////////////////////////////////////////////////// + // Vocabulary management + // + // Defines control symbols used as an indicator to + // change the behavior of the decoder. and are pre-defined. + // We can use this field to encode various meta information, + // including language indicator in multilingual model. + // These symbols are not visible to users, but visible to + // the decoder. Note that when the input sentence contains control symbols, + // they are not treated as one token, but segmented into normal pieces. + // Control symbols must be inserted independently from the segmentation. + repeated string control_symbols = 30; + + // Defines user defined symbols. + // These symbols are added with extremely high score + // so they are always treated as one unique symbol in any context. + // Typical usage of user_defined_symbols is placeholder for named entities. + repeated string user_defined_symbols = 31; + + // Defines required characters. Each UTF8 character in this string is included + // in the character set regardless of character_coverage value. Unlike + // user_defined_symbols, these characters have scores based on the frequency + // on input sentences, and the model can form subwords using characters + // in this field. + optional string required_chars = 36; + + // Decomposes unknown pieces into UTF-8 bytes. + optional bool byte_fallback = 35 [default = false]; + + // When creating the vocabulary file, defines whether or not to additionally + // output the score for each piece. + optional bool vocabulary_output_piece_score = 32 [default = true]; + + // `vocab_size` is treated as hard limit. Crash if + // the model can not produce the vocab of size `vocab_size`, + // When `hard_vocab_limit` is false, vocab_size is treated + // as soft limit. Note that when model_type=char, + // always assumes hard_vocab_limit = false. + optional bool hard_vocab_limit = 33 [default = true]; + + // use all symbols for vocab extraction. This flag is valid + // if model type is either CHAR or WORD + optional bool use_all_vocab = 34 [default = false]; + + /////////////////////////////////////////////////////////////////// + // Reserved special meta tokens. + // * -1 is not used. + // * unk_id must not be -1. + // Id must starts with 0 and be contigous. + optional int32 unk_id = 40 [default = 0]; // + optional int32 bos_id = 41 [default = 1]; // + optional int32 eos_id = 42 [default = 2]; // + optional int32 pad_id = 43 [default = -1]; // (padding) + optional string unk_piece = 45 [default = ""]; + optional string bos_piece = 46 [default = ""]; + optional string eos_piece = 47 [default = ""]; + optional string pad_piece = 48 [default = ""]; + + // Encodes into U+2047 (DOUBLE QUESTION MARK), + // since this character can be useful both for user and + // developer. We can easily figure out that is emitted. + optional string unk_surface = 44 [default = " \xE2\x81\x87 "]; + + // Increase bit depth to allow unigram model training on large + // (>10M sentences) corpora. A Side-effect of enabling this flag + // is increased memory usage. + optional bool train_extremely_large_corpus = 49 [default = false]; + + // Path to a seed sentencepieces file, with one tab-separated + // seed sentencepiece frequency per line. + optional string seed_sentencepieces_file = 54 [default = ""]; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// NormalizerSpec encodes a various parameters for string normalizaiton +message NormalizerSpec { + // name of normalization rule. + optional string name = 1; + + // Pre-compiled normalization rule created by + // Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method. + // Usually this field is set by Builder::GetNormalizerSpec() method. + optional bytes precompiled_charsmap = 2; + + // Adds dummy whitespace at the beginning of text in order to + // treat "world" in "world" and "hello world" in the same way. + optional bool add_dummy_prefix = 3 [default = true]; + + // Removes leading, trailing, and duplicate internal whitespace. + optional bool remove_extra_whitespaces = 4 [default = true]; + + // Replaces whitespace with meta symbol. + // This field must be true to train sentence piece model. + optional bool escape_whitespaces = 5 [default = true]; + + // Custom normalization rule file in TSV format. + // https://github.com/google/sentencepiece/blob/master/doc/normalization.md + // This field is only used in SentencePieceTrainer::Train() method, which + // compiles the rule into the binary rule stored in `precompiled_charsmap`. + optional string normalization_rule_tsv = 6; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// Proto to store samples for self-testing. +message SelfTestData { + message Sample { + optional string input = 1; + optional string expected = 2; + } + repeated Sample samples = 1; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} + +// ModelProto stores model parameters. +// SentencePieceProcessor is supposed to be self-contained. +// All settings/parameters which may change the behavior must be encoded +// in ModelProto. +message ModelProto { + message SentencePiece { + enum Type { + NORMAL = 1; // normal symbol + UNKNOWN = 2; // unknown symbol. only for now. + CONTROL = 3; // control symbols. , , <2ja> etc. + USER_DEFINED = 4; // user defined symbols. + // Typical usage of USER_DEFINED symbol + // is placeholder. + BYTE = 6; // byte symbols. Used when `byte_fallback` is true. + UNUSED = 5; // this piece is not used. + } + optional string piece = 1; // piece must not be empty. + optional float score = 2; + optional Type type = 3 [default = NORMAL]; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; + } + + // Sentence pieces with scores. + repeated SentencePiece pieces = 1; + + // Spec used to generate this model file. + optional TrainerSpec trainer_spec = 2; + + // Spec for text normalization. + optional NormalizerSpec normalizer_spec = 3; + + // Stores sample input and its expected segmentation to verify the model. + optional SelfTestData self_test_data = 4; + + // Spec for text de-normalization. + optional NormalizerSpec denormalizer_spec = 5; + + // Customized extensions: the range of field numbers + // are open to third-party extensions. + extensions 200 to max; +} diff --git a/docs/README.md b/docs/README.md index fd5b902f..ea058e60 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,25 +1,21 @@ # Documentation -To get started, see the project's **[quickstart](../README.md#quickstart)**. +### Getting Started +* [Quickstart](../README.md#quickstart) +* [Examples](../examples) +* [Importing models](./import.md) from GGUF, Pytorch and Safetensors +* [Linux Documentation](./linux.md) +* [Windows Documentation](./windows.md) +* [Docker Documentation](https://hub.docker.com/r/ollama/ollama) -Ollama is a tool for running AI models on your hardware. Many users will choose to use the Command Line Interface (CLI) to work with Ollama. Learn more about all the commands in the CLI in the **[Main Readme](../README.md)**. +### Reference -Use the RESTful API using any language, including Python, JavaScript, Typescript, Go, Rust, and many more. Learn more about using the API in the **[API Documentation](./api.md)**. +* [API Reference](./api.md) +* [Modelfile Reference](./modelfile.md) +* [OpenAI Compatibility](./openai.md) -Create new models or modify models already in the library using the Modelfile. Learn more about the Modelfile syntax in the **[Modelfile Documentation](./modelfile.md)**. +### Resources -Import models using source model weights found on Hugging Face and similar sites by referring to the **[Import Documentation](./import.md)**. - -Installing on Linux in most cases is easy using the script on Ollama.ai. To get more detail about the install, including CUDA drivers, see the **[Linux Documentation](./linux.md)**. - -Many of our users like the flexibility of using our official Docker Image. Learn more about using Docker with Ollama using the **[Docker Documentation](https://hub.docker.com/r/ollama/ollama)**. - -It is easy to install on Linux and Mac, but many users will choose to build Ollama on their own. To do this, refer to the **[Development Documentation](./development.md)**. - -If encountering a problem with Ollama, the best place to start is the logs. Find more information about them here in the **[Troubleshooting Guide](./troubleshooting.md)**. - -Finally for all the questions that don't fit anywhere else, there is the **[FAQ](./faq.md)** - -[Tutorials](./tutorials.md) apply the documentation to tasks. - -For working code examples of using Ollama, see [Examples](../examples). +* [Troubleshooting Guide](./troubleshooting.md) +* [FAQ](./faq.md) +* [Development guide](./development.md) diff --git a/docs/api.md b/docs/api.md index 0202b7e8..5ec92a82 100644 --- a/docs/api.md +++ b/docs/api.md @@ -49,11 +49,12 @@ Advanced parameters (optional): - `template`: the prompt template to use (overrides what is defined in the `Modelfile`) - `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects -- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API. +- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API +- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) #### JSON mode -Enable JSON mode by setting the `format` parameter to `json`. This will structure the response as a valid JSON object. See the JSON mode [example](#generate-request-json-mode) below. +Enable JSON mode by setting the `format` parameter to `json`. This will structure the response as a valid JSON object. See the JSON mode [example](#request-json-mode) below. > Note: it's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts whitespace. @@ -246,6 +247,23 @@ curl http://localhost:11434/api/generate -d '{ }' ``` +#### Request (Reproducible outputs) + +For reproducible outputs, set `temperature` to 0 and `seed` to a number: + +##### Request + +```shell +curl http://localhost:11434/api/generate -d '{ + "model": "mistral", + "prompt": "Why is the sky blue?", + "options": { + "seed": 123, + "temperature": 0 + } +}' +``` + ##### Response ```json @@ -303,7 +321,6 @@ curl http://localhost:11434/api/generate -d '{ "vocab_only": false, "use_mmap": true, "use_mlock": false, - "embedding_only": false, "rope_frequency_base": 1.1, "rope_frequency_scale": 0.8, "num_thread": 8 @@ -379,6 +396,7 @@ Advanced parameters (optional): - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `template`: the prompt template to use (overrides what is defined in the `Modelfile`) - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects +- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) ### Examples @@ -542,7 +560,7 @@ curl http://localhost:11434/api/chat -d '{ "role": "user", "content": "what is in this image?", "images": ["iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"] - }, + } ] }' ``` @@ -568,6 +586,46 @@ curl http://localhost:11434/api/chat -d '{ } ``` +#### Chat request (Reproducible outputs) + +##### Request + +```shell +curl http://localhost:11434/api/chat -d '{ + "model": "llama2", + "messages": [ + { + "role": "user", + "content": "Hello!" + } + ], + "options": { + "seed": 101, + "temperature": 0 + } +}' +``` + +##### Response + +```json +{ + "model": "registry.ollama.ai/library/llama2:latest", + "created_at": "2023-12-12T14:13:43.416799Z", + "message": { + "role": "assistant", + "content": "Hello! How are you today?" + }, + "done": true, + "total_duration": 5191566416, + "load_duration": 2154458, + "prompt_eval_count": 26, + "prompt_eval_duration": 383809000, + "eval_count": 298, + "eval_duration": 4799921000 +} +``` + ## Create a Model ```shell @@ -958,6 +1016,7 @@ Generate embeddings from a model Advanced parameters: - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` +- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) ### Examples @@ -965,7 +1024,7 @@ Advanced parameters: ```shell curl http://localhost:11434/api/embeddings -d '{ - "model": "llama2", + "model": "all-minilm", "prompt": "Here is an article about llamas..." }' ``` diff --git a/docs/development.md b/docs/development.md index 59651b1f..3973123a 100644 --- a/docs/development.md +++ b/docs/development.md @@ -3,7 +3,7 @@ Install required tools: - cmake version 3.24 or higher -- go version 1.21 or higher +- go version 1.22 or higher - gcc version 11.4.0 or higher ```bash @@ -42,15 +42,15 @@ Now you can run `ollama`: #### Linux CUDA (NVIDIA) -*Your operating system distribution may already have packages for NVIDIA CUDA. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!* +_Your operating system distribution may already have packages for NVIDIA CUDA. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_ Install `cmake` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) -development and runtime packages. +development and runtime packages. Typically the build scripts will auto-detect CUDA, however, if your Linux distro or installation approach uses unusual paths, you can specify the location by specifying an environment variable `CUDA_LIB_DIR` to the location of the shared -libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize +libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70") Then generate dependencies: @@ -67,15 +67,15 @@ go build . #### Linux ROCm (AMD) -*Your operating system distribution may already have packages for AMD ROCm and CLBlast. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!* +_Your operating system distribution may already have packages for AMD ROCm and CLBlast. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_ -Install [CLBlast](https://github.com/CNugteren/CLBlast/blob/master/doc/installation.md) and [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) developement packages first, as well as `cmake` and `golang`. +Install [CLBlast](https://github.com/CNugteren/CLBlast/blob/master/doc/installation.md) and [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) development packages first, as well as `cmake` and `golang`. Typically the build scripts will auto-detect ROCm, however, if your Linux distro or installation approach uses unusual paths, you can specify the location by specifying an environment variable `ROCM_PATH` to the location of the ROCm install (typically `/opt/rocm`), and `CLBlast_DIR` to the location of the -CLBlast install (typically `/usr/lib/cmake/CLBlast`). You can also customize +CLBlast install (typically `/usr/lib/cmake/CLBlast`). You can also customize the AMD GPU targets by setting AMDGPU_TARGETS (e.g. `AMDGPU_TARGETS="gfx1101;gfx1102"`) ``` @@ -88,17 +88,17 @@ Then build the binary: go build . ``` -ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root. +ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root. #### Advanced CPU Settings By default, running `go generate ./...` will compile a few different variations of the LLM library based on common CPU families and vector math capabilities, including a lowest-common-denominator which should run on almost any 64 bit CPU -somewhat slowly. At runtime, Ollama will auto-detect the optimal variation to -load. If you would like to build a CPU-based build customized for your +somewhat slowly. At runtime, Ollama will auto-detect the optimal variation to +load. If you would like to build a CPU-based build customized for your processor, you can set `OLLAMA_CUSTOM_CPU_DEFS` to the llama.cpp flags you would -like to use. For example, to compile an optimized binary for an Intel i9-9880H, +like to use. For example, to compile an optimized binary for an Intel i9-9880H, you might use: ``` @@ -108,8 +108,7 @@ go build . #### Containerized Linux Build -If you have Docker available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting binary is placed in `./dist` - +If you have Docker available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting binary is placed in `./dist` ### Windows @@ -117,8 +116,8 @@ Note: The windows build for Ollama is still under development. Install required tools: -- MSVC toolchain - C/C++ and cmake as minimal requirements -- go version 1.21 or higher +- MSVC toolchain - C/C++ and cmake as minimal requirements - You must build from a "Developer Shell" with the environment variables set +- go version 1.22 or higher - MinGW (pick one variant) with GCC. - - @@ -133,6 +132,13 @@ go build . #### Windows CUDA (NVIDIA) -In addition to the common Windows development tools described above, install: +In addition to the common Windows development tools described above, install CUDA **AFTER** you install MSVC. - [NVIDIA CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) + + +#### Windows ROCm (AMD Radeon) + +In addition to the common Windows development tools described above, install AMDs HIP package **AFTER** you install MSVC + +- [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html) \ No newline at end of file diff --git a/docs/faq.md b/docs/faq.md index fd94b39b..805f3fa4 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -2,12 +2,40 @@ ## How can I upgrade Ollama? -To upgrade Ollama, run the installation process again. On the Mac, click the Ollama icon in the menubar and choose the restart option if an update is available. +Ollama on macOS and Windows will automatically download updates. Click on the taskbar or menubar item and then click "Restart to update" to apply the update. Updates can also be installed by downloading the latest version [manually](https://ollama.com/download/). + +On Linux, re-run the install script: + +``` +curl -fsSL https://ollama.com/install.sh | sh +``` ## How can I view the logs? Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs. +## How can I specify the context window size? + +By default, Ollama uses a context window size of 2048 tokens. + +To change this when using `ollama run`, use `/set parameter`: + +``` +/set parameter num_ctx 4096 +``` + +When using the API, specify the `num_ctx` parameter: + +``` +curl http://localhost:11434/api/generate -d '{ + "model": "llama2", + "prompt": "Why is the sky blue?", + "options": { + "num_ctx": 4096 + } +}' +``` + ## How do I configure Ollama server? Ollama server can be configured with environment variables. @@ -46,6 +74,21 @@ If Ollama is run as a systemd service, environment variables should be set using systemctl restart ollama ``` +### Setting environment variables on Windows + +On windows, Ollama inherits your user and system environment variables. + +1. First Quit Ollama by clicking on it in the task bar + +2. Edit system environment variables from the control panel + +3. Edit or create New variable(s) for your user account for `OLLAMA_HOST`, `OLLAMA_MODELS`, etc. + +4. Click OK/Apply to save + +5. Run `ollama` from a new terminal window + + ## How can I expose Ollama on my network? Ollama binds 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable. @@ -60,8 +103,9 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e ## Where are models stored? -- macOS: `~/.ollama/models`. +- macOS: `~/.ollama/models` - Linux: `/usr/share/ollama/.ollama/models` +- Windows: `C:\Users\\.ollama\models` ### How do I set them to a different location? @@ -69,9 +113,9 @@ If a different directory needs to be used, set the environment variable `OLLAMA_ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform. -## Does Ollama send my prompts and answers back to Ollama.ai to use in any way? +## Does Ollama send my prompts and answers back to ollama.com? -No, Ollama runs entirely locally, and conversation data will never leave your machine. +No. Ollama runs locally, and conversation data does not leave your machine. ## How can I use Ollama in Visual Studio Code? @@ -115,3 +159,37 @@ This can impact both installing Ollama, as well as downloading models. Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernel (WSL)` adapter, right click and select `Properties`. Click on `Configure` and open the `Advanced` tab. Search through each of the properties until you find `Large Send Offload Version 2 (IPv4)` and `Large Send Offload Version 2 (IPv6)`. *Disable* both of these properties. + +## How can I pre-load a model to get faster response times? + +If you are using the API you can preload a model by sending the Ollama server an empty request. This works with both the `/api/generate` and `/api/chat` API endpoints. + +To preload the mistral model using the generate endpoint, use: +```shell +curl http://localhost:11434/api/generate -d '{"model": "mistral"}' +``` + +To use the chat completions endpoint, use: +```shell +curl http://localhost:11434/api/chat -d '{"model": "mistral"}' +``` + +## How do I keep a model loaded in memory or make it unload immediately? + +By default models are kept in memory for 5 minutes before being unloaded. This allows for quicker response times if you are making numerous requests to the LLM. You may, however, want to free up the memory before the 5 minutes have elapsed or keep the model loaded indefinitely. Use the `keep_alive` parameter with either the `/api/generate` and `/api/chat` API endpoints to control how long the model is left in memory. + +The `keep_alive` parameter can be set to: +* a duration string (such as "10m" or "24h") +* a number in seconds (such as 3600) +* any negative number which will keep the model loaded in memory (e.g. -1 or "-1m") +* '0' which will unload the model immediately after generating a response + +For example, to preload a model and leave it in memory use: +```shell +curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": -1}' +``` + +To unload the model and free up memory use: +```shell +curl http://localhost:11434/api/generate -d '{"model": "llama2", "keep_alive": 0}' +``` diff --git a/docs/import.md b/docs/import.md index e940f6d1..672916b5 100644 --- a/docs/import.md +++ b/docs/import.md @@ -15,7 +15,7 @@ FROM ./mistral-7b-v0.1.Q4_0.gguf (Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`: ``` -FROM ./q4_0.bin +FROM ./mistral-7b-v0.1.Q4_0.gguf TEMPLATE "[INST] {{ .Prompt }} [/INST]" ``` @@ -37,55 +37,69 @@ ollama run example "What is your favourite condiment?" ## Importing (PyTorch & Safetensors) -### Supported models +> Importing from PyTorch and Safetensors is a longer process than importing from GGUF. Improvements that make it easier are a work in progress. -Ollama supports a set of model architectures, with support for more coming soon: +### Setup -- Llama & Mistral -- Falcon & RW -- BigCode +First, clone the `ollama/ollama` repo: -To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`). +``` +git clone git@github.com:ollama/ollama.git ollama +cd ollama +``` -### Step 1: Clone the HuggingFace repository (optional) +and then fetch its `llama.cpp` submodule: + +```shell +git submodule init +git submodule update llm/llama.cpp +``` + +Next, install the Python dependencies: + +``` +python3 -m venv llm/llama.cpp/.venv +source llm/llama.cpp/.venv/bin/activate +pip install -r llm/llama.cpp/requirements.txt +``` + +Then build the `quantize` tool: + +``` +make -C llm/llama.cpp quantize +``` + +### Clone the HuggingFace repository (optional) If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model. +Install [Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), verify it's installed, and then clone the model's repository: + ``` git lfs install -git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 -cd Mistral-7B-Instruct-v0.1 +git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 model ``` -### Step 2: Convert and quantize to a `.bin` file (optional, for PyTorch and Safetensors) +### Convert the model -If the model is in PyTorch or Safetensors format, a [Docker image](https://hub.docker.com/r/ollama/quantize) with the tooling required to convert and quantize models is available. - -First, Install [Docker](https://www.docker.com/get-started/). - -Next, to convert and quantize your model, run: +> Note: some model architectures require using specific convert scripts. For example, Qwen models require running `convert-hf-to-gguf.py` instead of `convert.py` ``` -docker run --rm -v .:/model ollama/quantize -q q4_0 /model +python llm/llama.cpp/convert.py ./model --outtype f16 --outfile converted.bin ``` -This will output two files into the directory: +### Quantize the model -- `f16.bin`: the model converted to GGUF -- `q4_0.bin` the model quantized to a 4-bit quantization (Ollama will use this file to create the Ollama model) +``` +llm/llama.cpp/quantize converted.bin quantized.bin q4_0 +``` ### Step 3: Write a `Modelfile` Next, create a `Modelfile` for your model: ``` -FROM ./q4_0.bin -``` - -(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`: - -``` -FROM ./q4_0.bin +FROM quantized.bin TEMPLATE "[INST] {{ .Prompt }} [/INST]" ``` @@ -109,9 +123,12 @@ ollama run example "What is your favourite condiment?" Publishing models is in early alpha. If you'd like to publish your model to share with others, follow these steps: -1. Create [an account](https://ollama.ai/signup) -2. Run `cat ~/.ollama/id_ed25519.pub` to view your Ollama public key. Copy this to the clipboard. -3. Add your public key to your [Ollama account](https://ollama.ai/settings/keys) +1. Create [an account](https://ollama.com/signup) +2. Copy your Ollama public key: + - macOS: `cat ~/.ollama/id_ed25519.pub` + - Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub` + - Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub` +3. Add your public key to your [Ollama account](https://ollama.com/settings/keys) Next, copy your model to your username's namespace: @@ -125,7 +142,7 @@ Then push the model: ollama push /example ``` -After publishing, your model will be available at `https://ollama.ai//example`. +After publishing, your model will be available at `https://ollama.com//example`. ## Quantization reference @@ -149,47 +166,3 @@ The quantization options are as follow (from highest highest to lowest levels of - `q6_K` - `q8_0` - `f16` - -## Manually converting & quantizing models - -### Prerequisites - -Start by cloning the `llama.cpp` repo to your machine in another directory: - -``` -git clone https://github.com/ggerganov/llama.cpp.git -cd llama.cpp -``` - -Next, install the Python dependencies: - -``` -pip install -r requirements.txt -``` - -Finally, build the `quantize` tool: - -``` -make quantize -``` - -### Convert the model - -Run the correct conversion script for your model architecture: - -```shell -# LlamaForCausalLM or MistralForCausalLM -python convert.py - -# FalconForCausalLM -python convert-falcon-hf-to-gguf.py - -# GPTBigCodeForCausalLM -python convert-starcoder-hf-to-gguf.py -``` - -### Quantize the model - -``` -quantize /ggml-model-f32.bin /q4_0.bin q4_0 -``` diff --git a/docs/linux.md b/docs/linux.md index abd63320..0ef4a30f 100644 --- a/docs/linux.md +++ b/docs/linux.md @@ -3,11 +3,21 @@ ## Install Install Ollama running this one-liner: + > + ```bash -curl https://ollama.ai/install.sh | sh +curl -fsSL https://ollama.com/install.sh | sh ``` +## AMD Radeon GPU support + +While AMD has contributed the `amdgpu` driver upstream to the official linux +kernel source, the version is older and may not support all ROCm features. We +recommend you install the latest driver from +https://www.amd.com/en/support/linux-drivers for best support of your Radeon +GPU. + ## Manual install ### Download the `ollama` binary @@ -15,7 +25,7 @@ curl https://ollama.ai/install.sh | sh Ollama is distributed as a self-contained binary. Download it to a directory in your PATH: ```bash -sudo curl -L https://ollama.ai/download/ollama-linux-amd64 -o /usr/bin/ollama +sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama sudo chmod +x /usr/bin/ollama ``` @@ -62,6 +72,11 @@ Verify that the drivers are installed by running the following command, which sh nvidia-smi ``` +### Install ROCm (optional - for Radeon GPUs) +[Download and Install](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html) + +Make sure to install ROCm v6 + ### Start Ollama Start Ollama using `systemd`: @@ -75,13 +90,13 @@ sudo systemctl start ollama Update ollama by running the install script again: ```bash -curl https://ollama.ai/install.sh | sh +curl -fsSL https://ollama.com/install.sh | sh ``` Or by downloading the ollama binary: ```bash -sudo curl -L https://ollama.ai/download/ollama-linux-amd64 -o /usr/bin/ollama +sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama sudo chmod +x /usr/bin/ollama ``` @@ -110,6 +125,7 @@ sudo rm $(which ollama) ``` Remove the downloaded models and Ollama service user and group: + ```bash sudo rm -r /usr/share/ollama sudo userdel ollama diff --git a/docs/modelfile.md b/docs/modelfile.md index 6d6ac152..1d0030f4 100644 --- a/docs/modelfile.md +++ b/docs/modelfile.md @@ -67,13 +67,13 @@ To use this: More examples are available in the [examples directory](../examples). -### `Modelfile`s in [ollama.ai/library][1] +### `Modelfile`s in [ollama.com/library][1] -There are two ways to view `Modelfile`s underlying the models in [ollama.ai/library][1]: +There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]: - Option 1: view a details page from a model's tags page: - 1. Go to a particular model's tags (e.g. https://ollama.ai/library/llama2/tags) - 2. Click on a tag (e.g. https://ollama.ai/library/llama2:13b) + 1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags) + 2. Click on a tag (e.g. https://ollama.com/library/llama2:13b) 3. Scroll down to "Layers" - Note: if the [`FROM` instruction](#from-required) is not present, it means the model was created from a local file @@ -86,7 +86,7 @@ There are two ways to view `Modelfile`s underlying the models in [ollama.ai/libr # FROM llama2:13b FROM /root/.ollama/models/blobs/sha256:123abc - TEMPLATE """[INST] {{ if and .First .System }}<>{{ .System }}<> + TEMPLATE """[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] """ SYSTEM """""" @@ -154,31 +154,23 @@ PARAMETER ### TEMPLATE -`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model. +`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message, a user's message and the response from the model. Note: syntax may be model specific. Templates use Go [template syntax](https://pkg.go.dev/text/template). #### Template Variables -| Variable | Description | -| ----------------- | ------------------------------------------------------------------------------------------------------------- | -| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. | -| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. | -| `{{ .Response }}` | The response from the LLM, if not specified response is appended to the end of the template. | -| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. | +| Variable | Description | +| ----------------- | --------------------------------------------------------------------------------------------- | +| `{{ .System }}` | The system message used to specify custom behavior. | +| `{{ .Prompt }}` | The user prompt message. | +| `{{ .Response }}` | The response from the model. When generating a response, text after this variable is omitted. | -```modelfile -TEMPLATE """ -{{- if .First }} -### System: -{{ .System }} -{{- end }} - -### User: -{{ .Prompt }} - -### Response: +``` +TEMPLATE """{{ if .System }}<|im_start|>system +{{ .System }}<|im_end|> +{{ end }}{{ if .Prompt }}<|im_start|>user +{{ .Prompt }}<|im_end|> +{{ end }}<|im_start|>assistant """ - -SYSTEM """""" ``` ### SYSTEM @@ -225,4 +217,4 @@ MESSAGE assistant yes - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments. - Instructions can be in any order. In the examples, the `FROM` instruction is first to keep it easily readable. -[1]: https://ollama.ai/library +[1]: https://ollama.com/library diff --git a/docs/openai.md b/docs/openai.md new file mode 100644 index 00000000..5808ae75 --- /dev/null +++ b/docs/openai.md @@ -0,0 +1,141 @@ +# OpenAI compatibility + +> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/jmorganca/ollama/blob/main/docs/api.md). + +Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama. + +## Usage + +### OpenAI Python library + +```python +from openai import OpenAI + +client = OpenAI( + base_url='http://localhost:11434/v1/', + + # required but ignored + api_key='ollama', +) + +chat_completion = client.chat.completions.create( + messages=[ + { + 'role': 'user', + 'content': 'Say this is a test', + } + ], + model='llama2', +) +``` + +### OpenAI JavaScript library + +```javascript +import OpenAI from 'openai' + +const openai = new OpenAI({ + baseURL: 'http://localhost:11434/v1/', + + // required but ignored + apiKey: 'ollama', +}) + +const chatCompletion = await openai.chat.completions.create({ + messages: [{ role: 'user', content: 'Say this is a test' }], + model: 'llama2', +}) +``` + +### `curl` + +``` +curl http://localhost:11434/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llama2", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello!" + } + ] + }' +``` + +## Endpoints + +### `/v1/chat/completions` + +#### Supported features + +- [x] Chat completions +- [x] Streaming +- [x] JSON mode +- [x] Reproducible outputs +- [ ] Vision +- [ ] Function calling +- [ ] Logprobs + +#### Supported request fields + +- [x] `model` +- [x] `messages` + - [x] Text `content` + - [ ] Array of `content` parts +- [x] `frequency_penalty` +- [x] `presence_penalty` +- [x] `response_format` +- [x] `seed` +- [x] `stop` +- [x] `stream` +- [x] `temperature` +- [x] `top_p` +- [x] `max_tokens` +- [ ] `logit_bias` +- [ ] `tools` +- [ ] `tool_choice` +- [ ] `user` +- [ ] `n` + +#### Notes + +- Setting `seed` will always set `temperature` to `0` +- `finish_reason` will always be `stop` +- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached + +## Models + +Before using a model, pull it locally `ollama pull`: + +```shell +ollama pull llama2 +``` + +### Default model names + +For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name: + +``` +ollama cp llama2 gpt-3.5-turbo +``` + +Afterwards, this new model name can be specified the `model` field: + +```shell +curl http://localhost:11434/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "Hello!" + } + ] + }' +``` diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 1367194e..46f5bf51 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -1,60 +1,111 @@ -# How to troubleshoot issues - -Sometimes Ollama may not perform as expected. One of the best ways to figure out what happened is to take a look at the logs. Find the logs on Mac by running the command: - -```shell -cat ~/.ollama/logs/server.log -``` - -On Linux systems with systemd, the logs can be found with this command: - -```shell -journalctl -u ollama -``` - -When you run Ollama in a container, the logs go to stdout/stderr in the container: - -```shell -docker logs -``` -(Use `docker ps` to find the container name) - -If manually running `ollama serve` in a terminal, the logs will be on that terminal. - -Join the [Discord](https://discord.gg/ollama) for help interpreting the logs. - -## LLM libraries - -Ollama includes multiple LLM libraries compiled for different GPUs and CPU -vector features. Ollama tries to pick the best one based on the capabilities of -your system. If this autodetection has problems, or you run into other problems -(e.g. crashes in your GPU) you can workaround this by forcing a specific LLM -library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest -but most compatible is `cpu`. Rosetta emulation under MacOS will work with the -`cpu` library. - -In the server log, you will see a message that looks something like this (varies -from release to release): - -``` -Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5] -``` - -**Experimental LLM Library Override** - -You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass -autodetection, so for example, if you have a CUDA card, but want to force the -CPU LLM library with AVX2 vector support, use: - -``` -OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve -``` - -You can see what features your CPU has with the following. -``` -cat /proc/cpuinfo| grep flags | head -1 -``` - -## Known issues - -* N/A \ No newline at end of file +# How to troubleshoot issues + +Sometimes Ollama may not perform as expected. One of the best ways to figure out what happened is to take a look at the logs. Find the logs on **Mac** by running the command: + +```shell +cat ~/.ollama/logs/server.log +``` + +On **Linux** systems with systemd, the logs can be found with this command: + +```shell +journalctl -u ollama +``` + +When you run Ollama in a **container**, the logs go to stdout/stderr in the container: + +```shell +docker logs +``` +(Use `docker ps` to find the container name) + +If manually running `ollama serve` in a terminal, the logs will be on that terminal. + +When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `+R` and type in: +- `explorer %LOCALAPPDATA%\Ollama` to view logs +- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH) +- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored +- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories + +To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal +```powershell +$env:OLLAMA_DEBUG="1" +& "ollama app.exe" +``` + +Join the [Discord](https://discord.gg/ollama) for help interpreting the logs. + +## LLM libraries + +Ollama includes multiple LLM libraries compiled for different GPUs and CPU +vector features. Ollama tries to pick the best one based on the capabilities of +your system. If this autodetection has problems, or you run into other problems +(e.g. crashes in your GPU) you can workaround this by forcing a specific LLM +library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest +but most compatible is `cpu`. Rosetta emulation under MacOS will work with the +`cpu` library. + +In the server log, you will see a message that looks something like this (varies +from release to release): + +``` +Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5] +``` + +**Experimental LLM Library Override** + +You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass +autodetection, so for example, if you have a CUDA card, but want to force the +CPU LLM library with AVX2 vector support, use: + +``` +OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve +``` + +You can see what features your CPU has with the following. +``` +cat /proc/cpuinfo| grep flags | head -1 +``` + +## AMD Radeon GPU Support + +Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In +some cases you can force the system to try to use a similar LLVM target that is +close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4) +however, ROCm does not currently support this target. The closest support is +`gfx1030`. You can use the environment variable `HSA_OVERRIDE_GFX_VERSION` with +`x.y.z` syntax. So for example, to force the system to run on the RX 5400, you +would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the +server. If you have an unsupported AMD GPU you can experiment using the list of +supported types below. + +At this time, the known supported GPU types are the following LLVM Targets. +This table shows some example GPUs that map to these LLVM targets: +| **LLVM Target** | **An Example GPU** | +|-----------------|---------------------| +| gfx900 | Radeon RX Vega 56 | +| gfx906 | Radeon Instinct MI50 | +| gfx908 | Radeon Instinct MI100 | +| gfx90a | Radeon Instinct MI210 | +| gfx940 | Radeon Instinct MI300 | +| gfx941 | | +| gfx942 | | +| gfx1030 | Radeon PRO V620 | +| gfx1100 | Radeon PRO W7900 | +| gfx1101 | Radeon PRO W7700 | +| gfx1102 | Radeon RX 7600 | + +AMD is working on enhancing ROCm v6 to broaden support for families of GPUs in a +future release which should increase support for more GPUs. + +Reach out on [Discord](https://discord.gg/ollama) or file an +[issue](https://github.com/ollama/ollama/issues) for additional help. + +## Installing older versions on Linux + +If you run into problems on Linux and want to install an older version you can tell the install script +which version to install. + +```sh +curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.27" sh +``` diff --git a/docs/tutorials/langchainpy.md b/docs/tutorials/langchainpy.md index ac712923..f6ee4fa3 100644 --- a/docs/tutorials/langchainpy.md +++ b/docs/tutorials/langchainpy.md @@ -42,12 +42,12 @@ text_splitter=RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0) all_splits = text_splitter.split_documents(data) ``` -It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install GPT4All chromadb` +It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb` ```python from langchain.embeddings import OllamaEmbeddings from langchain.vectorstores import Chroma -oembed = OllamaEmbeddings(base_url="http://localhost:11434", model="llama2") +oembed = OllamaEmbeddings(base_url="http://localhost:11434", model="nomic-embed-text") vectorstore = Chroma.from_documents(documents=all_splits, embedding=oembed) ``` @@ -66,7 +66,7 @@ The next thing is to send the question and the relevant parts of the docs to the ```python from langchain.chains import RetrievalQA qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever()) -qachain({"query": question}) +qachain.invoke({"query": question}) ``` The answer received from this chain was: diff --git a/docs/tutorials/nvidia-jetson.md b/docs/tutorials/nvidia-jetson.md index 85cf741c..2d3adb98 100644 --- a/docs/tutorials/nvidia-jetson.md +++ b/docs/tutorials/nvidia-jetson.md @@ -17,7 +17,7 @@ Prerequisites: Here are the steps: -- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.ai/install.sh | sh` +- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh` - Stop the Ollama service: `sudo systemctl stop ollama` - Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson 'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'` diff --git a/docs/windows.md b/docs/windows.md new file mode 100644 index 00000000..49d579c9 --- /dev/null +++ b/docs/windows.md @@ -0,0 +1,47 @@ +# Ollama Windows Preview + +Welcome to the Ollama Windows preview. + +No more WSL required! + +Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support. +After installing Ollama Windows Preview, Ollama will run in the background and +the `ollama` command line is available in `cmd`, `powershell` or your favorite +terminal application. As usual the Ollama [api](./api.md) will be served on +`http://localhost:11434`. + +As this is a preview release, you should expect a few bugs here and there. If +you run into a problem you can reach out on +[Discord](https://discord.gg/ollama), or file an +[issue](https://github.com/ollama/ollama/issues). +Logs will often be helpful in dianosing the problem (see +[Troubleshooting](#troubleshooting) below) + +## System Requirements + +* Windows 10 or newer, Home or Pro +* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card +* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card + +## API Access + +Here's a quick example showing API access from `powershell` +```powershell +(Invoke-WebRequest -method POST -Body '{"model":"llama2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json +``` + +## Troubleshooting + +While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds +a "view logs" menu item to the app, and increses logging for the GUI app and +server. + +Ollama on Windows stores files in a few different locations. You can view them in +the explorer window by hitting `+R` and type in: +- `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates + - *app.log* contains logs from the GUI application + - *server.log* contains the server logs + - *upgrade.log* contains log output for upgrades +- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH) +- `explorer %HOMEPATH%\.ollama` contains models and configuration +- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories diff --git a/examples/jupyter-notebook/ollama.ipynb b/examples/jupyter-notebook/ollama.ipynb index d57e2057..bee353cb 100644 --- a/examples/jupyter-notebook/ollama.ipynb +++ b/examples/jupyter-notebook/ollama.ipynb @@ -8,7 +8,7 @@ "outputs": [], "source": [ "# Download and run the Ollama Linux install script\n", - "!curl https://ollama.ai/install.sh | sh\n", + "!curl -fsSL https://ollama.com/install.sh | sh\n", "!command -v systemctl >/dev/null && sudo systemctl stop ollama" ] }, diff --git a/examples/kubernetes/README.md b/examples/kubernetes/README.md index cb5f39f9..c522ba76 100644 --- a/examples/kubernetes/README.md +++ b/examples/kubernetes/README.md @@ -2,28 +2,28 @@ ## Prerequisites -- Ollama: https://ollama.ai/download +- Ollama: https://ollama.com/download - Kubernetes cluster. This example will use Google Kubernetes Engine. ## Steps 1. Create the Ollama namespace, daemon set, and service - ```bash - kubectl apply -f cpu.yaml - ``` + ```bash + kubectl apply -f cpu.yaml + ``` 1. Port forward the Ollama service to connect and use it locally - ```bash - kubectl -n ollama port-forward service/ollama 11434:80 - ``` + ```bash + kubectl -n ollama port-forward service/ollama 11434:80 + ``` 1. Pull and run a model, for example `orca-mini:3b` - ```bash - ollama run orca-mini:3b - ``` + ```bash + ollama run orca-mini:3b + ``` ## (Optional) Hardware Acceleration diff --git a/examples/langchain-python-rag-websummary/README.md b/examples/langchain-python-rag-websummary/README.md index 9ccc54cc..3f3b9873 100644 --- a/examples/langchain-python-rag-websummary/README.md +++ b/examples/langchain-python-rag-websummary/README.md @@ -1,6 +1,6 @@ # LangChain Web Summarization -This example summarizes the website, [https://ollama.ai/blog/run-llama2-uncensored-locally](https://ollama.ai/blog/run-llama2-uncensored-locally) +This example summarizes the website, [https://ollama.com/blog/run-llama2-uncensored-locally](https://ollama.com/blog/run-llama2-uncensored-locally) ## Running the Example diff --git a/examples/langchain-python-rag-websummary/main.py b/examples/langchain-python-rag-websummary/main.py index 2bb25d75..cd2ef47f 100644 --- a/examples/langchain-python-rag-websummary/main.py +++ b/examples/langchain-python-rag-websummary/main.py @@ -2,7 +2,7 @@ from langchain.llms import Ollama from langchain.document_loaders import WebBaseLoader from langchain.chains.summarize import load_summarize_chain -loader = WebBaseLoader("https://ollama.ai/blog/run-llama2-uncensored-locally") +loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally") docs = loader.load() llm = Ollama(model="llama2") diff --git a/examples/modelfile-tweetwriter/readme.md b/examples/modelfile-tweetwriter/readme.md deleted file mode 100644 index 51111259..00000000 --- a/examples/modelfile-tweetwriter/readme.md +++ /dev/null @@ -1,23 +0,0 @@ -# Example Modelfile - Tweetwriter - -This simple examples shows what you can do without any code, simply relying on a Modelfile. The file has two instructions: - -1. FROM - The From instructions defines the parent model to use for this one. If you choose a model from the library, you can enter just the model name. For all other models, you need to specify the namespace as well. You could also use a local file. Just include the relative path to the converted, quantized model weights file. To learn more about creating that file, see the `import.md` file in the docs folder of this repository. -2. SYSTEM - This defines the system prompt for the model and overrides the system prompt from the parent model. - -## Running the Example - -1. Create the model: - - ```bash - ollama create tweetwriter - ``` - -2. Enter a topic to generate a tweet about. -3. Show the Modelfile in the REPL. - - ```bash - /show modelfile - ``` - - Notice that the FROM and SYSTEM match what was in the file. But there is also a TEMPLATE and PARAMETER. These are inherited from the parent model. \ No newline at end of file diff --git a/examples/python-loganalysis/readme.md b/examples/python-loganalysis/readme.md index 828e8de2..60c57217 100644 --- a/examples/python-loganalysis/readme.md +++ b/examples/python-loganalysis/readme.md @@ -40,13 +40,13 @@ You are a log file analyzer. You will receive a set of lines from a log file for """ ``` -This model is available at https://ollama.ai/mattw/loganalyzer. You can customize it and add to your own namespace using the command `ollama create -f ` then `ollama push `. +This model is available at https://ollama.com/mattw/loganalyzer. You can customize it and add to your own namespace using the command `ollama create -f ` then `ollama push `. Then loganalysis.py scans all the lines in the given log file and searches for the word 'error'. When the word is found, the 10 lines before and after are set as the prompt for a call to the Generate API. ```python data = { - "prompt": "\n".join(error_logs), + "prompt": "\n".join(error_logs), "model": "mattw/loganalyzer" } ``` diff --git a/examples/typescript-mentors/README.md b/examples/typescript-mentors/README.md index 5ab1cc55..c3ce9c82 100644 --- a/examples/typescript-mentors/README.md +++ b/examples/typescript-mentors/README.md @@ -29,9 +29,9 @@ You can also add your own character to be chosen at random when you ask a questi ```bash ollama pull stablebeluga2:70b-q4_K_M ``` - + 2. Create a new character: - + ```bash npm run charactergen "Lorne Greene" ``` @@ -41,15 +41,15 @@ You can also add your own character to be chosen at random when you ask a questi 3. Now you can create a model with this command: ```bash - ollama create /lornegreene -f lornegreene/Modelfile + ollama create /lornegreene -f lornegreene/Modelfile ``` - `YourNamespace` is whatever name you set up when you signed up at [https://ollama.ai/signup](https://ollama.ai/signup). + `username` is whatever name you set up when you signed up at [https://ollama.com/signup](https://ollama.com/signup). -4. To add this to your mentors, you will have to update the code as follows. On line 8 of `mentors.ts`, add an object to the array, replacing `` with the namespace you used above. +4. To add this to your mentors, you will have to update the code as follows. On line 8 of `mentors.ts`, add an object to the array, replacing `` with the username you used above. ```bash - {ns: "", char: "Lorne Greene"} + {ns: "", char: "Lorne Greene"} ``` ## Review the Code diff --git a/format/openssh.go b/format/openssh.go deleted file mode 100644 index e642e358..00000000 --- a/format/openssh.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Code originally from https://go-review.googlesource.com/c/crypto/+/218620 - -// TODO: replace with upstream once the above change is merged and released. - -package format - -import ( - "crypto" - "crypto/ed25519" - "crypto/rand" - "encoding/binary" - "encoding/pem" - "fmt" - - "golang.org/x/crypto/ssh" -) - -const privateKeyAuthMagic = "openssh-key-v1\x00" - -type openSSHEncryptedPrivateKey struct { - CipherName string - KDFName string - KDFOptions string - KeysCount uint32 - PubKey []byte - KeyBlocks []byte -} - -type openSSHPrivateKey struct { - Check1 uint32 - Check2 uint32 - Keytype string - Rest []byte `ssh:"rest"` -} - -type openSSHEd25519PrivateKey struct { - Pub []byte - Priv []byte - Comment string - Pad []byte `ssh:"rest"` -} - -func OpenSSHPrivateKey(key crypto.PrivateKey, comment string) (*pem.Block, error) { - var check uint32 - if err := binary.Read(rand.Reader, binary.BigEndian, &check); err != nil { - return nil, err - } - - var pk1 openSSHPrivateKey - pk1.Check1 = check - pk1.Check2 = check - - var w openSSHEncryptedPrivateKey - w.KeysCount = 1 - - if k, ok := key.(*ed25519.PrivateKey); ok { - key = *k - } - - switch k := key.(type) { - case ed25519.PrivateKey: - pub, priv := k[32:], k - key := openSSHEd25519PrivateKey{ - Pub: pub, - Priv: priv, - Comment: comment, - } - - pk1.Keytype = ssh.KeyAlgoED25519 - pk1.Rest = ssh.Marshal(key) - - w.PubKey = ssh.Marshal(struct { - KeyType string - Pub []byte - }{ - ssh.KeyAlgoED25519, pub, - }) - default: - return nil, fmt.Errorf("ssh: unknown key type %T", k) - } - - w.KeyBlocks = openSSHPadding(ssh.Marshal(pk1), 8) - - w.CipherName, w.KDFName, w.KDFOptions = "none", "none", "" - - return &pem.Block{ - Type: "OPENSSH PRIVATE KEY", - Bytes: append([]byte(privateKeyAuthMagic), ssh.Marshal(w)...), - }, nil -} - -func openSSHPadding(block []byte, blocksize int) []byte { - for i, j := 0, len(block); (j+i)%blocksize != 0; i++ { - block = append(block, byte(i+1)) - } - - return block -} diff --git a/go.mod b/go.mod index 57ec2495..74f75b47 100644 --- a/go.mod +++ b/go.mod @@ -1,21 +1,43 @@ module github.com/jmorganca/ollama -go 1.21 +go 1.22 + +toolchain go1.22.0 require ( + github.com/containerd/console v1.0.3 + github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/emirpasic/gods v1.18.1 github.com/gin-gonic/gin v1.9.1 + github.com/golang/protobuf v1.5.0 + github.com/google/uuid v1.0.0 + github.com/mitchellh/mapstructure v1.5.0 github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.4 + github.com/x448/float16 v0.8.4 golang.org/x/sync v0.3.0 ) +require github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9 + require ( + github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc // indirect + github.com/chewxy/hm v1.0.0 // indirect + github.com/chewxy/math32 v1.0.8 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/flatbuffers v1.12.0 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/xtgo/set v1.0.0 // indirect + go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gonum.org/v1/gonum v0.8.2 // indirect + gorgonia.org/vecf32 v0.9.0 // indirect + gorgonia.org/vecf64 v0.9.0 // indirect ) require ( @@ -36,7 +58,6 @@ require ( github.com/mattn/go-isatty v0.0.19 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect @@ -48,6 +69,6 @@ require ( golang.org/x/sys v0.13.0 golang.org/x/term v0.13.0 golang.org/x/text v0.13.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + google.golang.org/protobuf v1.30.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ff6bcbd9..d1a75b56 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,38 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc h1:zvQ6w7KwtQWgMQiewOF9tFtundRMVZFSAksNV6ogzuY= +github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc/go.mod h1:c9sxoIT3YgLxH4UhLOCKaBlEojuMhVYpk4Ntv3opUTQ= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= +github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= +github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= +github.com/chewxy/math32 v1.0.8 h1:fU5E4Ec4Z+5RtRAi3TovSxUjQPkgRh+HbP7tKB2OFbM= +github.com/chewxy/math32 v1.0.8/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw= +github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY= +github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/gin-contrib/cors v1.4.0 h1:oJ6gwtUl3lqV0WEIwM/LxPF1QZ5qe2lGWdY2+bz7y0g= @@ -22,6 +44,7 @@ github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= @@ -34,15 +57,44 @@ github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QX github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/flatbuffers v1.12.0 h1:/PtAHvnBY4Kqnx/xCQ3OIV9uYcSFGScBsWI3Oogeh6w= +github.com/google/flatbuffers v1.12.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= @@ -63,6 +115,8 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -70,14 +124,17 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= -github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0= -github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y= +github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9 h1:DV4iXjNn6fGeDl1AkZ1I0QB/0DBjrc7kPpxHrmuDzW4= +github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9/go.mod h1:nR7l3gM6ubiOm+mCkmmUyIBUcBAyiUmW6dQrDZhugFE= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= @@ -91,6 +148,8 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -107,20 +166,63 @@ github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6 github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= +github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 h1:lGdhQUN/cnWdSH3291CUuxSEqc+AsGTiDxPP3r2J0l4= +go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -131,12 +233,56 @@ golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM= +gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f h1:Yv4xsIx7HZOoyUGSJ2ksDyWE2qIBXROsZKt2ny3hCGM= +google.golang.org/genproto v0.0.0-20200911024640-645f7a48b24f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0= +google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc/cmd/protoc-gen-go-grpc v0.0.0-20200910201057-6591123024b3/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= @@ -151,4 +297,10 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg= +gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA= +gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A= +gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/gpu/amd_common.go b/gpu/amd_common.go new file mode 100644 index 00000000..deb931ff --- /dev/null +++ b/gpu/amd_common.go @@ -0,0 +1,58 @@ +//go:build linux || windows + +package gpu + +import ( + "fmt" + "log/slog" + "os" + "path/filepath" + "strconv" + "strings" +) + +// Determine if the given ROCm lib directory is usable by checking for existence of some glob patterns +func rocmLibUsable(libDir string) bool { + slog.Debug("evaluating potential rocm lib dir " + libDir) + for _, g := range ROCmLibGlobs { + res, _ := filepath.Glob(filepath.Join(libDir, g)) + if len(res) == 0 { + return false + } + } + return true +} + +func GetSupportedGFX(libDir string) ([]string, error) { + var ret []string + files, err := filepath.Glob(filepath.Join(libDir, "rocblas", "library", "TensileLibrary_lazy_gfx*.dat")) + if err != nil { + return nil, err + } + for _, file := range files { + ret = append(ret, strings.TrimSuffix(strings.TrimPrefix(filepath.Base(file), "TensileLibrary_lazy_"), ".dat")) + } + return ret, nil +} + +func amdSetVisibleDevices(ids []int, skip map[int]interface{}) { + // Set the visible devices if not already set + // TODO - does sort order matter? + devices := []string{} + for i := range ids { + slog.Debug(fmt.Sprintf("i=%d", i)) + if _, skipped := skip[i]; skipped { + slog.Debug("skipped") + continue + } + devices = append(devices, strconv.Itoa(i)) + } + slog.Debug(fmt.Sprintf("devices=%v", devices)) + + val := strings.Join(devices, ",") + err := os.Setenv("HIP_VISIBLE_DEVICES", val) + if err != nil { + slog.Warn(fmt.Sprintf("failed to set env: %s", err)) + } + slog.Debug("HIP_VISIBLE_DEVICES=" + val) +} diff --git a/gpu/amd_hip_windows.go b/gpu/amd_hip_windows.go new file mode 100644 index 00000000..14a6c7d6 --- /dev/null +++ b/gpu/amd_hip_windows.go @@ -0,0 +1,141 @@ +package gpu + +import ( + "fmt" + "log/slog" + "strconv" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + hipSuccess = 0 + hipErrorNoDevice = 100 +) + +type hipDevicePropMinimal struct { + Name [256]byte + unused1 [140]byte + GcnArchName [256]byte // gfx#### + iGPU int // Doesn't seem to actually report correctly + unused2 [128]byte +} + +// Wrap the amdhip64.dll library for GPU discovery +type HipLib struct { + dll windows.Handle + hipGetDeviceCount uintptr + hipGetDeviceProperties uintptr + hipMemGetInfo uintptr + hipSetDevice uintptr + hipDriverGetVersion uintptr +} + +func NewHipLib() (*HipLib, error) { + h, err := windows.LoadLibrary("amdhip64.dll") + if err != nil { + return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err) + } + hl := &HipLib{} + hl.dll = h + hl.hipGetDeviceCount, err = windows.GetProcAddress(hl.dll, "hipGetDeviceCount") + if err != nil { + return nil, err + } + hl.hipGetDeviceProperties, err = windows.GetProcAddress(hl.dll, "hipGetDeviceProperties") + if err != nil { + return nil, err + } + hl.hipMemGetInfo, err = windows.GetProcAddress(hl.dll, "hipMemGetInfo") + if err != nil { + return nil, err + } + hl.hipSetDevice, err = windows.GetProcAddress(hl.dll, "hipSetDevice") + if err != nil { + return nil, err + } + hl.hipDriverGetVersion, err = windows.GetProcAddress(hl.dll, "hipDriverGetVersion") + if err != nil { + return nil, err + } + return hl, nil +} + +// The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup +// so we have to unload/reset the library after we do our initial discovery +// to make sure our updates to that variable are processed by llama.cpp +func (hl *HipLib) Release() { + err := windows.FreeLibrary(hl.dll) + if err != nil { + slog.Warn(fmt.Sprintf("failed to unload amdhip64.dll: %s", err)) + } + hl.dll = 0 +} + +func (hl *HipLib) AMDDriverVersion() (string, error) { + if hl.dll == 0 { + return "", fmt.Errorf("dll has been unloaded") + } + var version int + status, _, err := syscall.SyscallN(hl.hipDriverGetVersion, uintptr(unsafe.Pointer(&version))) + if status != hipSuccess { + return "", fmt.Errorf("failed call to hipDriverGetVersion: %d %s", status, err) + } + return strconv.Itoa(version), nil +} + +func (hl *HipLib) HipGetDeviceCount() int { + if hl.dll == 0 { + slog.Error("dll has been unloaded") + return 0 + } + var count int + status, _, err := syscall.SyscallN(hl.hipGetDeviceCount, uintptr(unsafe.Pointer(&count))) + if status == hipErrorNoDevice { + slog.Info("AMD ROCm reports no devices found") + return 0 + } + if status != hipSuccess { + slog.Warn(fmt.Sprintf("failed call to hipGetDeviceCount: %d %s", status, err)) + } + return count +} + +func (hl *HipLib) HipSetDevice(device int) error { + if hl.dll == 0 { + return fmt.Errorf("dll has been unloaded") + } + status, _, err := syscall.SyscallN(hl.hipSetDevice, uintptr(device)) + if status != hipSuccess { + return fmt.Errorf("failed call to hipSetDevice: %d %s", status, err) + } + return nil +} + +func (hl *HipLib) HipGetDeviceProperties(device int) (*hipDevicePropMinimal, error) { + if hl.dll == 0 { + return nil, fmt.Errorf("dll has been unloaded") + } + var props hipDevicePropMinimal + status, _, err := syscall.SyscallN(hl.hipGetDeviceProperties, uintptr(unsafe.Pointer(&props)), uintptr(device)) + if status != hipSuccess { + return nil, fmt.Errorf("failed call to hipGetDeviceProperties: %d %s", status, err) + } + return &props, nil +} + +// free, total, err +func (hl *HipLib) HipMemGetInfo() (uint64, uint64, error) { + if hl.dll == 0 { + return 0, 0, fmt.Errorf("dll has been unloaded") + } + var totalMemory uint64 + var freeMemory uint64 + status, _, err := syscall.SyscallN(hl.hipMemGetInfo, uintptr(unsafe.Pointer(&freeMemory)), uintptr(unsafe.Pointer(&totalMemory))) + if status != hipSuccess { + return 0, 0, fmt.Errorf("failed call to hipMemGetInfo: %d %s", status, err) + } + return freeMemory, totalMemory, nil +} diff --git a/gpu/amd_linux.go b/gpu/amd_linux.go new file mode 100644 index 00000000..2a2d22b6 --- /dev/null +++ b/gpu/amd_linux.go @@ -0,0 +1,418 @@ +package gpu + +import ( + "bufio" + "errors" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "slices" + "strconv" + "strings" +) + +// Discovery logic for AMD/ROCm GPUs + +const ( + DriverVersionFile = "/sys/module/amdgpu/version" + AMDNodesSysfsDir = "/sys/class/kfd/kfd/topology/nodes/" + GPUPropertiesFileGlob = AMDNodesSysfsDir + "*/properties" + + // Prefix with the node dir + GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line + GPUUsedMemoryFileGlob = "mem_banks/*/used_memory" + RocmStandardLocation = "/opt/rocm/lib" +) + +var ( + // Used to validate if the given ROCm lib is usable + ROCmLibGlobs = []string{"libhipblas.so.2*", "rocblas"} // TODO - probably include more coverage of files here... +) + +// Gather GPU information from the amdgpu driver if any supported GPUs are detected +// HIP_VISIBLE_DEVICES will be set if we detect a mix of unsupported and supported devices +// and the user hasn't already set this variable +func AMDGetGPUInfo(resp *GpuInfo) { + // TODO - DRY this out with windows + if !AMDDetected() { + return + } + skip := map[int]interface{}{} + + // Opportunistic logging of driver version to aid in troubleshooting + ver, err := AMDDriverVersion() + if err == nil { + slog.Info("AMD Driver: " + ver) + } else { + // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU + slog.Warn(fmt.Sprintf("ollama recommends running the https://www.amd.com/en/support/linux-drivers: %s", err)) + } + + // If the user has specified exactly which GPUs to use, look up their memory + visibleDevices := os.Getenv("HIP_VISIBLE_DEVICES") + if visibleDevices != "" { + ids := []int{} + for _, idStr := range strings.Split(visibleDevices, ",") { + id, err := strconv.Atoi(idStr) + if err != nil { + slog.Warn(fmt.Sprintf("malformed HIP_VISIBLE_DEVICES=%s %s", visibleDevices, err)) + } else { + ids = append(ids, id) + } + } + amdProcMemLookup(resp, nil, ids) + return + } + + // Gather GFX version information from all detected cards + gfx := AMDGFXVersions() + verStrings := []string{} + for i, v := range gfx { + verStrings = append(verStrings, v.ToGFXString()) + if v.Major == 0 { + // Silently skip CPUs + skip[i] = struct{}{} + continue + } + if v.Major < 9 { + // TODO consider this a build-time setting if we can support 8xx family GPUs + slog.Warn(fmt.Sprintf("amdgpu [%d] too old %s", i, v.ToGFXString())) + skip[i] = struct{}{} + } + } + slog.Info(fmt.Sprintf("detected amdgpu versions %v", verStrings)) + + // Abort if all GPUs are skipped + if len(skip) >= len(gfx) { + slog.Info("all detected amdgpus are skipped, falling back to CPU") + return + } + + // If we got this far, then we have at least 1 GPU that's a ROCm candidate, so make sure we have a lib + libDir, err := AMDValidateLibDir() + if err != nil { + slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err)) + return + } + + gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION") + if gfxOverride == "" { + supported, err := GetSupportedGFX(libDir) + if err != nil { + slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) + return + } + slog.Debug(fmt.Sprintf("rocm supported GPU types %v", supported)) + + for i, v := range gfx { + if !slices.Contains[[]string, string](supported, v.ToGFXString()) { + slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported)) + // TODO - consider discrete markdown just for ROCM troubleshooting? + slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") + skip[i] = struct{}{} + } else { + slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString())) + } + } + } else { + slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) + } + + if len(skip) >= len(gfx) { + slog.Info("all detected amdgpus are skipped, falling back to CPU") + return + } + + ids := make([]int, len(gfx)) + i := 0 + for k := range gfx { + ids[i] = k + i++ + } + amdProcMemLookup(resp, skip, ids) + if resp.memInfo.DeviceCount == 0 { + return + } + if len(skip) > 0 { + amdSetVisibleDevices(ids, skip) + } +} + +// Walk the sysfs nodes for the available GPUs and gather information from them +// skipping over any devices in the skip map +func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) { + resp.memInfo.DeviceCount = 0 + resp.memInfo.TotalMemory = 0 + resp.memInfo.FreeMemory = 0 + if len(ids) == 0 { + slog.Debug("discovering all amdgpu devices") + entries, err := os.ReadDir(AMDNodesSysfsDir) + if err != nil { + slog.Warn(fmt.Sprintf("failed to read amdgpu sysfs %s - %s", AMDNodesSysfsDir, err)) + return + } + for _, node := range entries { + if !node.IsDir() { + continue + } + id, err := strconv.Atoi(node.Name()) + if err != nil { + slog.Warn("malformed amdgpu sysfs node id " + node.Name()) + continue + } + ids = append(ids, id) + } + } + slog.Debug(fmt.Sprintf("discovering amdgpu devices %v", ids)) + + for _, id := range ids { + if _, skipped := skip[id]; skipped { + continue + } + totalMemory := uint64(0) + usedMemory := uint64(0) + propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUTotalMemoryFileGlob) + propFiles, err := filepath.Glob(propGlob) + if err != nil { + slog.Warn(fmt.Sprintf("error looking up total GPU memory: %s %s", propGlob, err)) + } + // 1 or more memory banks - sum the values of all of them + for _, propFile := range propFiles { + fp, err := os.Open(propFile) + if err != nil { + slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", propFile, err)) + continue + } + defer fp.Close() + scanner := bufio.NewScanner(fp) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "size_in_bytes") { + ver := strings.Fields(line) + if len(ver) != 2 { + slog.Warn("malformed " + line) + continue + } + bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64) + if err != nil { + slog.Warn("malformed int " + line) + continue + } + totalMemory += bankSizeInBytes + } + } + } + if totalMemory == 0 { + continue + } + usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(id), GPUUsedMemoryFileGlob) + usedFiles, err := filepath.Glob(usedGlob) + if err != nil { + slog.Warn(fmt.Sprintf("error looking up used GPU memory: %s %s", usedGlob, err)) + continue + } + for _, usedFile := range usedFiles { + fp, err := os.Open(usedFile) + if err != nil { + slog.Warn(fmt.Sprintf("failed to open sysfs node file %s: %s", usedFile, err)) + continue + } + defer fp.Close() + data, err := io.ReadAll(fp) + if err != nil { + slog.Warn(fmt.Sprintf("failed to read sysfs node file %s: %s", usedFile, err)) + continue + } + used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64) + if err != nil { + slog.Warn(fmt.Sprintf("malformed used memory %s: %s", string(data), err)) + continue + } + usedMemory += used + } + slog.Info(fmt.Sprintf("[%d] amdgpu totalMemory %d", id, totalMemory)) + slog.Info(fmt.Sprintf("[%d] amdgpu freeMemory %d", id, (totalMemory - usedMemory))) + resp.memInfo.DeviceCount++ + resp.memInfo.TotalMemory += totalMemory + resp.memInfo.FreeMemory += (totalMemory - usedMemory) + } + if resp.memInfo.DeviceCount > 0 { + resp.Library = "rocm" + } +} + +// Quick check for AMD driver so we can skip amdgpu discovery if not present +func AMDDetected() bool { + // Some driver versions (older?) don't have a version file, so just lookup the parent dir + sysfsDir := filepath.Dir(DriverVersionFile) + _, err := os.Stat(sysfsDir) + if errors.Is(err, os.ErrNotExist) { + slog.Debug("amdgpu driver not detected " + sysfsDir) + return false + } else if err != nil { + slog.Debug(fmt.Sprintf("error looking up amd driver %s %s", sysfsDir, err)) + return false + } + return true +} + +func setupLink(source, target string) error { + if err := os.RemoveAll(target); err != nil { + return fmt.Errorf("failed to remove old rocm directory %s %w", target, err) + } + if err := os.Symlink(source, target); err != nil { + return fmt.Errorf("failed to create link %s => %s %w", source, target, err) + } + slog.Debug(fmt.Sprintf("host rocm linked %s => %s", source, target)) + return nil +} + +// Ensure the AMD rocm lib dir is wired up +// Prefer to use host installed ROCm, as long as it meets our minimum requirements +// failing that, tell the user how to download it on their own +func AMDValidateLibDir() (string, error) { + // We rely on the rpath compiled into our library to find rocm + // so we establish a symlink to wherever we find it on the system + // to /rocm + payloadsDir, err := PayloadsDir() + if err != nil { + return "", err + } + + // If we already have a rocm dependency wired, nothing more to do + rocmTargetDir := filepath.Clean(filepath.Join(payloadsDir, "..", "rocm")) + if rocmLibUsable(rocmTargetDir) { + return rocmTargetDir, nil + } + + // next to the running binary + exe, err := os.Executable() + if err == nil { + peerDir := filepath.Dir(exe) + if rocmLibUsable(peerDir) { + slog.Debug("detected ROCM next to ollama executable " + peerDir) + return rocmTargetDir, setupLink(peerDir, rocmTargetDir) + } + peerDir = filepath.Join(filepath.Dir(exe), "rocm") + if rocmLibUsable(peerDir) { + slog.Debug("detected ROCM next to ollama executable " + peerDir) + return rocmTargetDir, setupLink(peerDir, rocmTargetDir) + } + } + + // Well known ollama installer path + installedRocmDir := "/usr/share/ollama/lib/rocm" + if rocmLibUsable(installedRocmDir) { + return rocmTargetDir, setupLink(installedRocmDir, rocmTargetDir) + } + + // Prefer explicit HIP env var + hipPath := os.Getenv("HIP_PATH") + if hipPath != "" { + hipLibDir := filepath.Join(hipPath, "lib") + if rocmLibUsable(hipLibDir) { + slog.Debug("detected ROCM via HIP_PATH=" + hipPath) + return rocmTargetDir, setupLink(hipLibDir, rocmTargetDir) + } + } + + // Scan the library path for potential matches + ldPaths := strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":") + for _, ldPath := range ldPaths { + d, err := filepath.Abs(ldPath) + if err != nil { + continue + } + if rocmLibUsable(d) { + return rocmTargetDir, setupLink(d, rocmTargetDir) + } + } + + // Well known location(s) + if rocmLibUsable("/opt/rocm/lib") { + return rocmTargetDir, setupLink("/opt/rocm/lib", rocmTargetDir) + } + + // If we still haven't found a usable rocm, the user will have to install it on their own + slog.Warn("amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install") + return "", fmt.Errorf("no suitable rocm found, falling back to CPU") +} + +func AMDDriverVersion() (string, error) { + _, err := os.Stat(DriverVersionFile) + if err != nil { + return "", fmt.Errorf("amdgpu version file missing: %s %w", DriverVersionFile, err) + } + fp, err := os.Open(DriverVersionFile) + if err != nil { + return "", err + } + defer fp.Close() + verString, err := io.ReadAll(fp) + if err != nil { + return "", err + } + return strings.TrimSpace(string(verString)), nil +} + +func AMDGFXVersions() map[int]Version { + res := map[int]Version{} + matches, _ := filepath.Glob(GPUPropertiesFileGlob) + for _, match := range matches { + fp, err := os.Open(match) + if err != nil { + slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err)) + continue + } + defer fp.Close() + i, err := strconv.Atoi(filepath.Base(filepath.Dir(match))) + if err != nil { + slog.Debug(fmt.Sprintf("failed to parse node ID %s", err)) + continue + } + + scanner := bufio.NewScanner(fp) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "gfx_target_version") { + ver := strings.Fields(line) + if len(ver) != 2 || len(ver[1]) < 5 { + + if ver[1] == "0" { + // Silently skip the CPU + continue + } else { + slog.Debug("malformed " + line) + } + res[i] = Version{ + Major: 0, + Minor: 0, + Patch: 0, + } + continue + } + l := len(ver[1]) + patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32) + minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32) + major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32) + if err1 != nil || err2 != nil || err3 != nil { + slog.Debug("malformed int " + line) + continue + } + + res[i] = Version{ + Major: uint(major), + Minor: uint(minor), + Patch: uint(patch), + } + } + } + } + return res +} + +func (v Version) ToGFXString() string { + return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch) +} diff --git a/gpu/amd_windows.go b/gpu/amd_windows.go new file mode 100644 index 00000000..be1be567 --- /dev/null +++ b/gpu/amd_windows.go @@ -0,0 +1,180 @@ +package gpu + +import ( + "bytes" + "fmt" + "log/slog" + "os" + "path/filepath" + "slices" + "strings" +) + +const ( + RocmStandardLocation = "C:\\Program Files\\AMD\\ROCm\\5.7\\bin" // TODO glob? + + // TODO We're lookinng for this exact name to detect iGPUs since hipGetDeviceProperties never reports integrated==true + iGPUName = "AMD Radeon(TM) Graphics" +) + +var ( + // Used to validate if the given ROCm lib is usable + ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here... +) + +func AMDGetGPUInfo(resp *GpuInfo) { + hl, err := NewHipLib() + if err != nil { + slog.Debug(err.Error()) + return + } + defer hl.Release() + skip := map[int]interface{}{} + ids := []int{} + resp.memInfo.DeviceCount = 0 + resp.memInfo.TotalMemory = 0 + resp.memInfo.FreeMemory = 0 + + ver, err := hl.AMDDriverVersion() + if err == nil { + slog.Info("AMD Driver: " + ver) + } else { + // For now this is benign, but we may eventually need to fail compatibility checks + slog.Debug(fmt.Sprintf("error looking up amd driver version: %s", err)) + } + + // Note: the HIP library automatically handles HIP_VISIBLE_DEVICES + count := hl.HipGetDeviceCount() + if count == 0 { + return + } + libDir, err := AMDValidateLibDir() + if err != nil { + slog.Warn(fmt.Sprintf("unable to verify rocm library, will use cpu: %s", err)) + return + } + + var supported []string + gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION") + if gfxOverride == "" { + supported, err = GetSupportedGFX(libDir) + if err != nil { + slog.Warn(fmt.Sprintf("failed to lookup supported GFX types, falling back to CPU mode: %s", err)) + return + } + } else { + slog.Debug("skipping rocm gfx compatibility check with HSA_OVERRIDE_GFX_VERSION=" + gfxOverride) + } + + slog.Info(fmt.Sprintf("detected %d hip devices", count)) + for i := 0; i < count; i++ { + ids = append(ids, i) + err = hl.HipSetDevice(i) + if err != nil { + slog.Warn(fmt.Sprintf("[%d] %s", i, err)) + skip[i] = struct{}{} + continue + } + + props, err := hl.HipGetDeviceProperties(i) + if err != nil { + slog.Warn(fmt.Sprintf("[%d] %s", i, err)) + skip[i] = struct{}{} + continue + } + n := bytes.IndexByte(props.Name[:], 0) + name := string(props.Name[:n]) + slog.Info(fmt.Sprintf("[%d] Name: %s", i, name)) + n = bytes.IndexByte(props.GcnArchName[:], 0) + gfx := string(props.GcnArchName[:n]) + slog.Info(fmt.Sprintf("[%d] GcnArchName: %s", i, gfx)) + //slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 + // TODO Why isn't props.iGPU accurate!? + if strings.EqualFold(name, iGPUName) { + slog.Info(fmt.Sprintf("iGPU detected [%d] skipping", i)) + skip[i] = struct{}{} + continue + } + if gfxOverride == "" { + if !slices.Contains[[]string, string](supported, gfx) { + slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, gfx, libDir, supported)) + // TODO - consider discrete markdown just for ROCM troubleshooting? + slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") + skip[i] = struct{}{} + continue + } else { + slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, gfx)) + } + } + + totalMemory, freeMemory, err := hl.HipMemGetInfo() + if err != nil { + slog.Warn(fmt.Sprintf("[%d] %s", i, err)) + continue + } + + // TODO according to docs, freeMem may lie on windows! + slog.Info(fmt.Sprintf("[%d] Total Mem: %d", i, totalMemory)) + slog.Info(fmt.Sprintf("[%d] Free Mem: %d", i, freeMemory)) + resp.memInfo.DeviceCount++ + resp.memInfo.TotalMemory += totalMemory + resp.memInfo.FreeMemory += freeMemory + } + if resp.memInfo.DeviceCount > 0 { + resp.Library = "rocm" + } + // Abort if all GPUs are skipped + if len(skip) >= count { + slog.Info("all detected amdgpus are skipped, falling back to CPU") + return + } + if len(skip) > 0 { + amdSetVisibleDevices(ids, skip) + } + UpdatePath(libDir) +} + +func AMDValidateLibDir() (string, error) { + // On windows non-admins typically can't create links + // so instead of trying to rely on rpath and a link in + // $LibDir/rocm, we instead rely on setting PATH to point + // to the location of the ROCm library + + // Installer payload location if we're running the installed binary + exe, err := os.Executable() + if err == nil { + rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm") + if rocmLibUsable(rocmTargetDir) { + slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir) + return rocmTargetDir, nil + } + } + + // Installer payload (if we're running from some other location) + localAppData := os.Getenv("LOCALAPPDATA") + appDir := filepath.Join(localAppData, "Programs", "Ollama") + rocmTargetDir := filepath.Join(appDir, "rocm") + if rocmLibUsable(rocmTargetDir) { + slog.Debug("detected ollama installed ROCm at " + rocmTargetDir) + return rocmTargetDir, nil + } + + // Prefer explicit HIP env var + hipPath := os.Getenv("HIP_PATH") + if hipPath != "" { + hipLibDir := filepath.Join(hipPath, "bin") + if rocmLibUsable(hipLibDir) { + slog.Debug("detected ROCM via HIP_PATH=" + hipPath) + return hipLibDir, nil + } + } + + // Well known location(s) + if rocmLibUsable(RocmStandardLocation) { + return RocmStandardLocation, nil + } + + // Should not happen on windows since we include it in the installer, but stand-alone binary might hit this + slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm") + return "", fmt.Errorf("no suitable rocm found, falling back to CPU") +} diff --git a/gpu/assets.go b/gpu/assets.go new file mode 100644 index 00000000..dacfa5ee --- /dev/null +++ b/gpu/assets.go @@ -0,0 +1,67 @@ +package gpu + +import ( + "fmt" + "log/slog" + "os" + "path/filepath" + "runtime" + "strings" + "sync" +) + +var ( + lock sync.Mutex + payloadsDir = "" +) + +func PayloadsDir() (string, error) { + lock.Lock() + defer lock.Unlock() + if payloadsDir == "" { + tmpDir, err := os.MkdirTemp("", "ollama") + if err != nil { + return "", fmt.Errorf("failed to generate tmp dir: %w", err) + } + // We create a distinct subdirectory for payloads within the tmpdir + // This will typically look like /tmp/ollama3208993108/runners on linux + payloadsDir = filepath.Join(tmpDir, "runners") + } + return payloadsDir, nil +} + +func Cleanup() { + lock.Lock() + defer lock.Unlock() + if payloadsDir != "" { + // We want to fully clean up the tmpdir parent of the payloads dir + tmpDir := filepath.Clean(filepath.Join(payloadsDir, "..")) + slog.Debug("cleaning up", "dir", tmpDir) + err := os.RemoveAll(tmpDir) + if err != nil { + slog.Warn("failed to clean up", "dir", tmpDir, "err", err) + } + } +} + +func UpdatePath(dir string) { + if runtime.GOOS == "windows" { + tmpDir := filepath.Dir(dir) + pathComponents := strings.Split(os.Getenv("PATH"), ";") + i := 0 + for _, comp := range pathComponents { + if strings.EqualFold(comp, dir) { + return + } + // Remove any other prior paths to our temp dir + if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) { + pathComponents[i] = comp + i++ + } + } + newPath := strings.Join(append([]string{dir}, pathComponents...), ";") + slog.Info(fmt.Sprintf("Updating PATH to %s", newPath)) + os.Setenv("PATH", newPath) + } + // linux and darwin rely on rpath +} diff --git a/gpu/gpu.go b/gpu/gpu.go index 6e67e653..e0c18e26 100644 --- a/gpu/gpu.go +++ b/gpu/gpu.go @@ -24,7 +24,6 @@ import ( type handles struct { cuda *C.cuda_handle_t - rocm *C.rocm_handle_t } var gpuMutex sync.Mutex @@ -54,39 +53,23 @@ var CudaWindowsGlobs = []string{ "c:\\Windows\\System32\\nvml.dll", } -var RocmLinuxGlobs = []string{ - "/opt/rocm*/lib*/librocm_smi64.so*", -} - -var RocmWindowsGlobs = []string{ - "c:\\Windows\\System32\\rocm_smi64.dll", -} - // Note: gpuMutex must already be held func initGPUHandles() { // TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing - gpuHandles = &handles{nil, nil} + gpuHandles = &handles{nil} var cudaMgmtName string var cudaMgmtPatterns []string - var rocmMgmtName string - var rocmMgmtPatterns []string switch runtime.GOOS { case "windows": cudaMgmtName = "nvml.dll" cudaMgmtPatterns = make([]string, len(CudaWindowsGlobs)) copy(cudaMgmtPatterns, CudaWindowsGlobs) - rocmMgmtName = "rocm_smi64.dll" - rocmMgmtPatterns = make([]string, len(RocmWindowsGlobs)) - copy(rocmMgmtPatterns, RocmWindowsGlobs) case "linux": cudaMgmtName = "libnvidia-ml.so" cudaMgmtPatterns = make([]string, len(CudaLinuxGlobs)) copy(cudaMgmtPatterns, CudaLinuxGlobs) - rocmMgmtName = "librocm_smi64.so" - rocmMgmtPatterns = make([]string, len(RocmLinuxGlobs)) - copy(rocmMgmtPatterns, RocmLinuxGlobs) default: return } @@ -101,16 +84,6 @@ func initGPUHandles() { return } } - - rocmLibPaths := FindGPULibs(rocmMgmtName, rocmMgmtPatterns) - if len(rocmLibPaths) > 0 { - rocm := LoadROCMMgmt(rocmLibPaths) - if rocm != nil { - slog.Info("Radeon GPU detected") - gpuHandles.rocm = rocm - return - } - } } func GetGPUInfo() GpuInfo { @@ -149,43 +122,10 @@ func GetGPUInfo() GpuInfo { slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)) } } - } else if gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") { - C.rocm_check_vram(*gpuHandles.rocm, &memInfo) - if memInfo.err != nil { - slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))) - C.free(unsafe.Pointer(memInfo.err)) - } else if memInfo.igpu_index >= 0 && memInfo.count == 1 { - // Only one GPU detected and it appears to be an integrated GPU - skip it - slog.Info("ROCm unsupported integrated GPU detected") - } else if memInfo.count > 0 { - if memInfo.igpu_index >= 0 { - // We have multiple GPUs reported, and one of them is an integrated GPU - // so we have to set the env var to bypass it - // If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it - val := os.Getenv("ROCR_VISIBLE_DEVICES") - if val == "" { - devices := []string{} - for i := 0; i < int(memInfo.count); i++ { - if i == int(memInfo.igpu_index) { - continue - } - devices = append(devices, strconv.Itoa(i)) - } - val = strings.Join(devices, ",") - os.Setenv("ROCR_VISIBLE_DEVICES", val) - } - slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val)) - } - resp.Library = "rocm" - var version C.rocm_version_resp_t - C.rocm_get_version(*gpuHandles.rocm, &version) - verString := C.GoString(version.str) - if version.status == 0 { - resp.Variant = "v" + verString - } else { - slog.Info(fmt.Sprintf("failed to look up ROCm version: %s", verString)) - } - C.free(unsafe.Pointer(version.str)) + } else { + AMDGetGPUInfo(&resp) + if resp.Library != "" { + return resp } } if resp.Library == "" { @@ -219,6 +159,15 @@ func getCPUMem() (memInfo, error) { } func CheckVRAM() (int64, error) { + userLimit := os.Getenv("OLLAMA_MAX_VRAM") + if userLimit != "" { + avail, err := strconv.ParseInt(userLimit, 10, 64) + if err != nil { + return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err) + } + slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail)) + return avail, nil + } gpuInfo := GetGPUInfo() if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") { // leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead @@ -306,23 +255,6 @@ func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t { return nil } -func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t { - var resp C.rocm_init_resp_t - resp.rh.verbose = getVerboseState() - for _, libPath := range rocmLibPaths { - lib := C.CString(libPath) - defer C.free(unsafe.Pointer(lib)) - C.rocm_init(lib, &resp) - if resp.err != nil { - slog.Info(fmt.Sprintf("Unable to load ROCm management library %s: %s", libPath, C.GoString(resp.err))) - C.free(unsafe.Pointer(resp.err)) - } else { - return &resp.rh - } - } - return nil -} - func getVerboseState() C.uint16_t { if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { return C.uint16_t(1) diff --git a/gpu/gpu_darwin.go b/gpu/gpu_darwin.go index 97907511..f2856e29 100644 --- a/gpu/gpu_darwin.go +++ b/gpu/gpu_darwin.go @@ -2,32 +2,38 @@ package gpu +/* +#cgo CFLAGS: -x objective-c +#cgo LDFLAGS: -framework Foundation -framework CoreGraphics -framework Metal +#include "gpu_info_darwin.h" +*/ import "C" import ( + "fmt" + "log/slog" + "os" "runtime" - - "github.com/pbnjay/memory" + "strconv" ) // CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs func CheckVRAM() (int64, error) { + userLimit := os.Getenv("OLLAMA_MAX_VRAM") + if userLimit != "" { + avail, err := strconv.ParseInt(userLimit, 10, 64) + if err != nil { + return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err) + } + slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail)) + return avail, nil + } + if runtime.GOARCH == "amd64" { // gpu not supported, this may not be metal return 0, nil } - - // on macOS, there's already buffer for available vram (see below) so just return the total - systemMemory := int64(memory.TotalMemory()) - - // macOS limits how much memory is available to the GPU based on the amount of system memory - // TODO: handle case where iogpu.wired_limit_mb is set to a higher value - if systemMemory <= 36*1024*1024*1024 { - systemMemory = systemMemory * 2 / 3 - } else { - systemMemory = systemMemory * 3 / 4 - } - - return systemMemory, nil + recommendedMaxVRAM := int64(C.getRecommendedMaxVRAM()) + return recommendedMaxVRAM, nil } func GetGPUInfo() GpuInfo { diff --git a/gpu/gpu_info.h b/gpu/gpu_info.h index e52d2066..8186a3f0 100644 --- a/gpu/gpu_info.h +++ b/gpu/gpu_info.h @@ -53,7 +53,6 @@ void cpu_check_ram(mem_info_t *resp); #endif #include "gpu_info_cuda.h" -#include "gpu_info_rocm.h" #endif // __GPU_INFO_H__ #endif // __APPLE__ \ No newline at end of file diff --git a/gpu/gpu_info_cuda.c b/gpu/gpu_info_cuda.c index d877ff0c..509bf5c6 100644 --- a/gpu/gpu_info_cuda.c +++ b/gpu/gpu_info_cuda.c @@ -124,31 +124,31 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) { // When in verbose mode, report more information about // the card we discover, but don't fail on error ret = (*h.nvmlDeviceGetName)(device, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { + if (ret != NVML_SUCCESS) { LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret); } else { LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf); } ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { + if (ret != NVML_SUCCESS) { LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret); } else { LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf); } ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { + if (ret != NVML_SUCCESS) { LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret); } else { LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf); } ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { + if (ret != NVML_SUCCESS) { LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret); } else { LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf); } ret = (*h.nvmlDeviceGetBrand)(device, &brand); - if (ret != RSMI_STATUS_SUCCESS) { + if (ret != NVML_SUCCESS) { LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret); } else { LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand); @@ -156,7 +156,7 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) { } LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total); - LOG(h.verbose, "[%d] CUDA usedMem %ld\n", i, memInfo.free); + LOG(h.verbose, "[%d] CUDA usedMem %ld\n", i, memInfo.used); resp->total += memInfo.total; resp->free += memInfo.free; diff --git a/gpu/gpu_info_darwin.h b/gpu/gpu_info_darwin.h new file mode 100644 index 00000000..6ba30c0a --- /dev/null +++ b/gpu/gpu_info_darwin.h @@ -0,0 +1,3 @@ +#import +#include +uint64_t getRecommendedMaxVRAM(); diff --git a/gpu/gpu_info_darwin.m b/gpu/gpu_info_darwin.m new file mode 100644 index 00000000..06d7b69b --- /dev/null +++ b/gpu/gpu_info_darwin.m @@ -0,0 +1,11 @@ +//go:build darwin +#include "gpu_info_darwin.h" + +uint64_t getRecommendedMaxVRAM() +{ + id device = MTLCreateSystemDefaultDevice(); + uint64_t result = device.recommendedMaxWorkingSetSize; + CFRelease(device); + return result; +} + diff --git a/gpu/gpu_info_rocm.c b/gpu/gpu_info_rocm.c deleted file mode 100644 index 7ac88611..00000000 --- a/gpu/gpu_info_rocm.c +++ /dev/null @@ -1,198 +0,0 @@ -#ifndef __APPLE__ - -#include "gpu_info_rocm.h" - -#include - -void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) { - rsmi_status_t ret; - resp->err = NULL; - const int buflen = 256; - char buf[buflen + 1]; - int i; - struct lookup { - char *s; - void **p; - } l[] = { - {"rsmi_init", (void *)&resp->rh.rsmi_init}, - {"rsmi_shut_down", (void *)&resp->rh.rsmi_shut_down}, - {"rsmi_dev_memory_total_get", (void *)&resp->rh.rsmi_dev_memory_total_get}, - {"rsmi_dev_memory_usage_get", (void *)&resp->rh.rsmi_dev_memory_usage_get}, - {"rsmi_version_get", (void *)&resp->rh.rsmi_version_get}, - {"rsmi_num_monitor_devices", (void*)&resp->rh.rsmi_num_monitor_devices}, - {"rsmi_dev_id_get", (void*)&resp->rh.rsmi_dev_id_get}, - {"rsmi_dev_name_get", (void *)&resp->rh.rsmi_dev_name_get}, - {"rsmi_dev_brand_get", (void *)&resp->rh.rsmi_dev_brand_get}, - {"rsmi_dev_vendor_name_get", (void *)&resp->rh.rsmi_dev_vendor_name_get}, - {"rsmi_dev_vram_vendor_get", (void *)&resp->rh.rsmi_dev_vram_vendor_get}, - {"rsmi_dev_serial_number_get", (void *)&resp->rh.rsmi_dev_serial_number_get}, - {"rsmi_dev_subsystem_name_get", (void *)&resp->rh.rsmi_dev_subsystem_name_get}, - {"rsmi_dev_vbios_version_get", (void *)&resp->rh.rsmi_dev_vbios_version_get}, - {NULL, NULL}, - }; - - resp->rh.handle = LOAD_LIBRARY(rocm_lib_path, RTLD_LAZY); - if (!resp->rh.handle) { - char *msg = LOAD_ERR(); - snprintf(buf, buflen, - "Unable to load %s library to query for Radeon GPUs: %s\n", - rocm_lib_path, msg); - free(msg); - resp->err = strdup(buf); - return; - } - - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->rh.verbose, "wiring rocm management library functions in %s\n", rocm_lib_path); - - for (i = 0; l[i].s != NULL; i++) { - // TODO once we've squashed the remaining corner cases remove this log - LOG(resp->rh.verbose, "dlsym: %s\n", l[i].s); - - *l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s); - if (!l[i].p) { - resp->rh.handle = NULL; - char *msg = LOAD_ERR(); - LOG(resp->rh.verbose, "dlerr: %s\n", msg); - UNLOAD_LIBRARY(resp->rh.handle); - snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, - msg); - free(msg); - resp->err = strdup(buf); - return; - } - } - - ret = (*resp->rh.rsmi_init)(0); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(resp->rh.verbose, "rsmi_init err: %d\n", ret); - UNLOAD_LIBRARY(resp->rh.handle); - resp->rh.handle = NULL; - snprintf(buf, buflen, "rocm vram init failure: %d", ret); - resp->err = strdup(buf); - } - - return; -} - -void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) { - resp->err = NULL; - resp->igpu_index = -1; - uint64_t totalMem = 0; - uint64_t usedMem = 0; - rsmi_status_t ret; - const int buflen = 256; - char buf[buflen + 1]; - int i; - - if (h.handle == NULL) { - resp->err = strdup("rocm handle not initialized"); - return; - } - - ret = (*h.rsmi_num_monitor_devices)(&resp->count); - if (ret != RSMI_STATUS_SUCCESS) { - snprintf(buf, buflen, "unable to get device count: %d", ret); - resp->err = strdup(buf); - return; - } - LOG(h.verbose, "discovered %d ROCm GPU Devices\n", resp->count); - - resp->total = 0; - resp->free = 0; - for (i = 0; i < resp->count; i++) { - if (h.verbose) { - // When in verbose mode, report more information about - // the card we discover, but don't fail on error - ret = (*h.rsmi_dev_name_get)(i, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(h.verbose, "rsmi_dev_name_get failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] ROCm device name: %s\n", i, buf); - } - ret = (*h.rsmi_dev_brand_get)(i, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(h.verbose, "rsmi_dev_brand_get failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] ROCm brand: %s\n", i, buf); - } - ret = (*h.rsmi_dev_vendor_name_get)(i, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(h.verbose, "rsmi_dev_vendor_name_get failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] ROCm vendor: %s\n", i, buf); - } - ret = (*h.rsmi_dev_vram_vendor_get)(i, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(h.verbose, "rsmi_dev_vram_vendor_get failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] ROCm VRAM vendor: %s\n", i, buf); - } - ret = (*h.rsmi_dev_serial_number_get)(i, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(h.verbose, "rsmi_dev_serial_number_get failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] ROCm S/N: %s\n", i, buf); - } - ret = (*h.rsmi_dev_subsystem_name_get)(i, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(h.verbose, "rsmi_dev_subsystem_name_get failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] ROCm subsystem name: %s\n", i, buf); - } - ret = (*h.rsmi_dev_vbios_version_get)(i, buf, buflen); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(h.verbose, "rsmi_dev_vbios_version_get failed: %d\n", ret); - } else { - LOG(h.verbose, "[%d] ROCm vbios version: %s\n", i, buf); - } - } - - // Get total memory - used memory for available memory - ret = (*h.rsmi_dev_memory_total_get)(i, RSMI_MEM_TYPE_VRAM, &totalMem); - if (ret != RSMI_STATUS_SUCCESS) { - snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret); - resp->err = strdup(buf); - return; - } - ret = (*h.rsmi_dev_memory_usage_get)(i, RSMI_MEM_TYPE_VRAM, &usedMem); - if (ret != RSMI_STATUS_SUCCESS) { - snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret); - resp->err = strdup(buf); - return; - } - LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem); - LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem); - if (totalMem < 1024 * 1024 * 1024) { - // Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory - LOG(h.verbose, "[%d] ROCm integrated GPU\n", i); - resp->igpu_index = i; - } else { - resp->total += totalMem; - resp->free += totalMem - usedMem; - } - } -} - -void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) { - const int buflen = 256; - char buf[buflen + 1]; - if (h.handle == NULL) { - resp->str = strdup("rocm handle not initialized"); - resp->status = 1; - return; - } - rsmi_version_t ver; - rsmi_status_t ret; - ret = h.rsmi_version_get(&ver); - if (ret != RSMI_STATUS_SUCCESS) { - snprintf(buf, buflen, "unexpected response on version lookup %d", ret); - resp->status = 1; - } else { - snprintf(buf, buflen, "%d", ver.major); - resp->status = 0; - } - resp->str = strdup(buf); -} - -#endif // __APPLE__ diff --git a/gpu/gpu_info_rocm.h b/gpu/gpu_info_rocm.h deleted file mode 100644 index 0a8d50c0..00000000 --- a/gpu/gpu_info_rocm.h +++ /dev/null @@ -1,59 +0,0 @@ -#ifndef __APPLE__ -#ifndef __GPU_INFO_ROCM_H__ -#define __GPU_INFO_ROCM_H__ -#include "gpu_info.h" - -// Just enough typedef's to dlopen/dlsym for memory information -typedef enum rsmi_status_return { - RSMI_STATUS_SUCCESS = 0, - // Other values omitted for now... -} rsmi_status_t; - -typedef enum rsmi_memory_type { - RSMI_MEM_TYPE_VRAM = 0, - RSMI_MEM_TYPE_VIS_VRAM, - RSMI_MEM_TYPE_GTT, -} rsmi_memory_type_t; - - typedef struct { - uint32_t major; - uint32_t minor; - uint32_t patch; - const char *build; - } rsmi_version_t; - -typedef struct rocm_handle { - void *handle; - uint16_t verbose; - rsmi_status_t (*rsmi_init)(uint64_t); - rsmi_status_t (*rsmi_shut_down)(void); - rsmi_status_t (*rsmi_dev_memory_total_get)(uint32_t, rsmi_memory_type_t, uint64_t *); - rsmi_status_t (*rsmi_dev_memory_usage_get)(uint32_t, rsmi_memory_type_t, uint64_t *); - rsmi_status_t (*rsmi_version_get) (rsmi_version_t *version); - rsmi_status_t (*rsmi_num_monitor_devices) (uint32_t *); - rsmi_status_t (*rsmi_dev_id_get)(uint32_t, uint16_t *); - rsmi_status_t (*rsmi_dev_name_get) (uint32_t,char *,size_t); - rsmi_status_t (*rsmi_dev_brand_get) (uint32_t, char *, uint32_t); - rsmi_status_t (*rsmi_dev_vendor_name_get) (uint32_t, char *, uint32_t); - rsmi_status_t (*rsmi_dev_vram_vendor_get) (uint32_t, char *, uint32_t); - rsmi_status_t (*rsmi_dev_serial_number_get) (uint32_t, char *, uint32_t); - rsmi_status_t (*rsmi_dev_subsystem_name_get) (uint32_t, char *, uint32_t); - rsmi_status_t (*rsmi_dev_vbios_version_get) (uint32_t, char *, uint32_t); -} rocm_handle_t; - -typedef struct rocm_init_resp { - char *err; // If err is non-null handle is invalid - rocm_handle_t rh; -} rocm_init_resp_t; - -typedef struct rocm_version_resp { - rsmi_status_t status; - char *str; // Contains version or error string if status != 0 -} rocm_version_resp_t; - -void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp); -void rocm_check_vram(rocm_handle_t rh, mem_info_t *resp); -void rocm_get_version(rocm_handle_t rh, rocm_version_resp_t *resp); - -#endif // __GPU_INFO_ROCM_H__ -#endif // __APPLE__ \ No newline at end of file diff --git a/gpu/types.go b/gpu/types.go index 24fa4a24..67727180 100644 --- a/gpu/types.go +++ b/gpu/types.go @@ -16,3 +16,9 @@ type GpuInfo struct { // TODO add other useful attributes about the card here for discovery information } + +type Version struct { + Major uint + Minor uint + Patch uint +} diff --git a/llm/dyn_ext_server.c b/llm/dyn_ext_server.c index 47dc4e99..dab49f85 100644 --- a/llm/dyn_ext_server.c +++ b/llm/dyn_ext_server.c @@ -14,17 +14,14 @@ #define LOAD_LIBRARY(lib, flags) LoadLibrary(lib) #define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym) #define UNLOAD_LIBRARY(handle) FreeLibrary(handle) -inline char *LOAD_ERR() { - LPSTR messageBuffer = NULL; - size_t size = FormatMessageA( - FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | - FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPSTR)&messageBuffer, 0, NULL); - char *resp = strdup(messageBuffer); - LocalFree(messageBuffer); - return resp; -} +#define LOAD_ERR() ({\ + LPSTR messageBuffer = NULL; \ + size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \ + NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \ + char *resp = strdup(messageBuffer); \ + LocalFree(messageBuffer); \ + resp; \ +}) #else #include #define LOAD_LIBRARY(lib, flags) dlopen(lib, flags) diff --git a/llm/dyn_ext_server.go b/llm/dyn_ext_server.go index 782fd382..e981be94 100644 --- a/llm/dyn_ext_server.go +++ b/llm/dyn_ext_server.go @@ -28,13 +28,13 @@ import ( "log/slog" "os" "path/filepath" - "runtime" "strings" "sync" "time" "unsafe" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/gpu" ) type dynExtServer struct { @@ -72,7 +72,7 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts slog.Info("concurrent llm servers not yet supported, waiting for prior server to complete") mutex.Lock() } - updatePath(filepath.Dir(library)) + gpu.UpdatePath(filepath.Dir(library)) libPath := C.CString(library) defer C.free(unsafe.Pointer(libPath)) resp := newExtServerResp(512) @@ -106,7 +106,12 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts sparams.memory_f16 = C.bool(opts.F16KV) sparams.use_mlock = C.bool(opts.UseMLock) sparams.use_mmap = C.bool(opts.UseMMap) - sparams.numa = C.bool(opts.UseNUMA) + + if opts.UseNUMA { + sparams.numa = C.int(1) + } else { + sparams.numa = C.int(0) + } sparams.lora_adapters = nil for i := 0; i < len(adapters); i++ { @@ -143,7 +148,8 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts } slog.Info("Initializing llama server") - initResp := newExtServerResp(128) + slog.Debug(fmt.Sprintf("server params: %+v", sparams)) + initResp := newExtServerResp(512) defer freeExtServerResp(initResp) C.dyn_llama_server_init(llm.s, &sparams, &initResp) if initResp.id < 0 { @@ -161,13 +167,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error { resp := newExtServerResp(128) defer freeExtServerResp(resp) - var imageData []ImageData + if len(predict.Images) > 0 { - for cnt, i := range predict.Images { - imageData = append(imageData, ImageData{Data: i, ID: cnt}) - } + slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images))) } - slog.Info(fmt.Sprintf("loaded %d images", len(imageData))) request := map[string]any{ "prompt": predict.Prompt, @@ -189,7 +192,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu "penalize_nl": predict.Options.PenalizeNewline, "seed": predict.Options.Seed, "stop": predict.Options.Stop, - "image_data": imageData, + "image_data": predict.Images, "cache_prompt": true, } @@ -261,7 +264,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu }) } - if p.Stop { + if p.Stop || bool(result.stop) { fn(PredictResult{ Done: true, PromptEvalCount: p.Timings.PromptN, @@ -363,25 +366,3 @@ func (llm *dynExtServer) Close() { C.dyn_llama_server_stop(llm.s) mutex.Unlock() } - -func updatePath(dir string) { - if runtime.GOOS == "windows" { - tmpDir := filepath.Dir(dir) - pathComponents := strings.Split(os.Getenv("PATH"), ";") - i := 0 - for _, comp := range pathComponents { - if strings.EqualFold(comp, dir) { - return - } - // Remove any other prior paths to our temp dir - if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) { - pathComponents[i] = comp - i++ - } - } - newPath := strings.Join(append([]string{dir}, pathComponents...), ";") - slog.Info(fmt.Sprintf("Updating PATH to %s", newPath)) - os.Setenv("PATH", newPath) - } - // linux and darwin rely on rpath -} diff --git a/llm/ext_server/ext_server.cpp b/llm/ext_server/ext_server.cpp index b59b46d2..4a9d120d 100644 --- a/llm/ext_server/ext_server.cpp +++ b/llm/ext_server/ext_server.cpp @@ -1,4 +1,5 @@ #include "ext_server.h" +#include // Necessary evil since the server types are not defined in a header #include "server.cpp" @@ -27,8 +28,24 @@ // Expose the llama server as a callable extern "C" API llama_server_context *llama = NULL; std::thread ext_server_thread; +bool shutting_down = false; +std::atomic_int recv_counter; +// RAII wrapper for tracking in-flight recv calls +class atomicRecv { + public: + atomicRecv(std::atomic &atomic) : atomic(atomic) { + ++this->atomic; + } + ~atomicRecv() { + --this->atomic; + } + private: + std::atomic &atomic; +}; + void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) { + recv_counter = 0; assert(err != NULL && sparams != NULL); log_set_target(stderr); if (!sparams->verbose_logging) { @@ -63,7 +80,7 @@ void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) { params.main_gpu = sparams->main_gpu; params.use_mlock = sparams->use_mlock; params.use_mmap = sparams->use_mmap; - params.numa = sparams->numa; + params.numa = (ggml_numa_strategy)sparams->numa; params.embedding = sparams->embedding; if (sparams->model != NULL) { params.model = sparams->model; @@ -94,18 +111,15 @@ void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) { } #endif - llama_backend_init(params.numa); + llama_backend_init(); + llama_numa_init(params.numa); - // load the model - if (!llama->load_model(params)) { - // TODO - consider modifying the logging logic or patching load_model so - // we can capture more detailed error messages and pass them back to the - // caller for better UX - err->id = -1; - snprintf(err->msg, err->msg_len, "error loading model %s", - params.model.c_str()); - return; - } + if (!llama->load_model(params)) { + // an error occured that was not thrown + err->id = -1; + snprintf(err->msg, err->msg_len, "error loading model %s", params.model.c_str()); + return; + } llama->initialize(); } catch (std::exception &e) { @@ -128,9 +142,9 @@ void llama_server_start() { llama->queue_tasks.on_new_task(std::bind( &llama_server_context::process_single_task, llama, std::placeholders::_1)); llama->queue_tasks.on_finish_multitask(std::bind( - &llama_server_context::on_finish_multitask, llama, std::placeholders::_1)); - llama->queue_tasks.on_all_tasks_finished(std::bind( - &llama_server_context::run_on_all_tasks_finished, llama)); + &llama_server_context::on_finish_multitask, llama, std::placeholders::_1)); + llama->queue_tasks.on_run_slots(std::bind( + &llama_server_context::update_slots, llama)); llama->queue_results.on_multitask_update(std::bind( &llama_server_queue::update_multitask, &llama->queue_tasks, @@ -151,7 +165,14 @@ void llama_server_start() { void llama_server_stop() { assert(llama != NULL); + // Shutdown any in-flight requests and block incoming requests. LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n"); + shutting_down = true; + + while (recv_counter.load() > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + // This may take a while for any pending tasks to drain // TODO - consider a timeout to cancel tasks if it's taking too long llama->queue_tasks.terminate(); @@ -159,6 +180,7 @@ void llama_server_stop() { delete llama; llama = NULL; LOG_TEE("llama server shutdown complete\n"); + shutting_down = false; } void llama_server_completion(const char *json_req, ext_server_resp_t *resp) { @@ -166,6 +188,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) { resp->id = -1; resp->msg[0] = '\0'; try { + if (shutting_down) { + throw std::runtime_error("server shutting down"); + } json data = json::parse(json_req); resp->id = llama->queue_tasks.get_new_id(); llama->queue_results.add_waiting_task_id(resp->id); @@ -180,13 +205,13 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) { void llama_server_completion_next_result(const int task_id, ext_server_task_result_t *resp) { assert(llama != NULL && resp != NULL); - std::string msg; resp->id = -1; resp->stop = false; resp->error = false; resp->json_resp = NULL; std::string result_json; try { + atomicRecv ar(recv_counter); task_result result = llama->queue_results.recv(task_id); result_json = result.result_json.dump(-1, ' ', false, json::error_handler_t::replace); @@ -203,6 +228,11 @@ void llama_server_completion_next_result(const int task_id, llama->request_cancel(task_id); LOG_TEE("next result removing waiting task ID: %d\n", task_id); llama->queue_results.remove_waiting_task_id(task_id); + } else if (shutting_down) { + LOG_TEE("aborting completion due to shutdown %d\n", task_id); + llama->request_cancel(task_id); + llama->queue_results.remove_waiting_task_id(task_id); + resp->stop = true; } } catch (std::exception &e) { resp->error = true; @@ -251,6 +281,9 @@ void llama_server_tokenize(const char *json_req, char **json_resp, err->id = 0; err->msg[0] = '\0'; try { + if (shutting_down) { + throw std::runtime_error("server shutting down"); + } const json body = json::parse(json_req); std::vector tokens; if (body.count("content") != 0) { @@ -284,6 +317,9 @@ void llama_server_detokenize(const char *json_req, char **json_resp, err->id = 0; err->msg[0] = '\0'; try { + if (shutting_down) { + throw std::runtime_error("server shutting down"); + } const json body = json::parse(json_req); std::string content; if (body.count("tokens") != 0) { @@ -311,6 +347,9 @@ void llama_server_embedding(const char *json_req, char **json_resp, err->id = 0; err->msg[0] = '\0'; try { + if (shutting_down) { + throw std::runtime_error("server shutting down"); + } const json body = json::parse(json_req); json prompt; if (body.count("content") != 0) { @@ -321,6 +360,7 @@ void llama_server_embedding(const char *json_req, char **json_resp, const int task_id = llama->queue_tasks.get_new_id(); llama->queue_results.add_waiting_task_id(task_id); llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1); + atomicRecv ar(recv_counter); task_result result = llama->queue_results.recv(task_id); std::string result_json = result.result_json.dump(); const std::string::size_type size = result_json.size() + 1; diff --git a/llm/ext_server/ext_server.h b/llm/ext_server/ext_server.h index 8eefb3cc..9b9ce2ec 100644 --- a/llm/ext_server/ext_server.h +++ b/llm/ext_server/ext_server.h @@ -41,7 +41,7 @@ typedef struct ext_server_params { int32_t main_gpu; // the GPU that is used for scratch and small tensors bool use_mlock; // force system to keep model in RAM bool use_mmap; // use mmap if possible - bool numa; // attempt optimizations that help on some NUMA systems + int numa; // attempt optimizations that help on some NUMA systems bool embedding; // get only sentence embedding ext_server_lora_adapter_t *lora_adapters; char *mmproj; diff --git a/llm/generate/gen_common.sh b/llm/generate/gen_common.sh index 43d3dce5..d8d906b4 100644 --- a/llm/generate/gen_common.sh +++ b/llm/generate/gen_common.sh @@ -1,4 +1,4 @@ -# common logic accross linux and darwin +# common logic across linux and darwin init_vars() { case "${GOARCH}" in @@ -65,15 +65,17 @@ apply_patches() { echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt fi - # apply temporary patches until fix is upstream - for patch in ../patches/*.diff; do - for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do - (cd ${LLAMACPP_DIR}; git checkout ${file}) + if [ -n "$(ls -A ../patches/*.diff)" ]; then + # apply temporary patches until fix is upstream + for patch in ../patches/*.diff; do + for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do + (cd ${LLAMACPP_DIR}; git checkout ${file}) + done done - done - for patch in ../patches/*.diff; do - (cd ${LLAMACPP_DIR} && git apply ${patch}) - done + for patch in ../patches/*.diff; do + (cd ${LLAMACPP_DIR} && git apply ${patch}) + done + fi # Avoid duplicate main symbols when we link into the cgo binary sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp && @@ -99,7 +101,7 @@ compress_libs() { pids="" rm -rf ${BUILD_DIR}/lib/*.${LIB_EXT}*.gz for lib in ${BUILD_DIR}/lib/*.${LIB_EXT}* ; do - gzip --best -f ${lib} & + gzip -n --best -f ${lib} & pids+=" $!" done echo @@ -112,4 +114,12 @@ compress_libs() { # Keep the local tree clean after we're done with the build cleanup() { (cd ${LLAMACPP_DIR}/examples/server/ && git checkout CMakeLists.txt server.cpp) + + if [ -n "$(ls -A ../patches/*.diff)" ]; then + for patch in ../patches/*.diff; do + for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do + (cd ${LLAMACPP_DIR}; git checkout ${file}) + done + done + fi } diff --git a/llm/generate/gen_darwin.sh b/llm/generate/gen_darwin.sh index 4b806b02..59bdc801 100755 --- a/llm/generate/gen_darwin.sh +++ b/llm/generate/gen_darwin.sh @@ -60,7 +60,7 @@ case "${GOARCH}" in compress_libs ;; "arm64") - CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}" + CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}" BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/metal" EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders" build diff --git a/llm/generate/gen_linux.sh b/llm/generate/gen_linux.sh index 82c8c75b..c2437c99 100755 --- a/llm/generate/gen_linux.sh +++ b/llm/generate/gen_linux.sh @@ -21,7 +21,6 @@ amdGPUs() { return fi GPU_LIST=( - "gfx803" "gfx900" "gfx906:xnack-" "gfx908:xnack-" @@ -128,6 +127,11 @@ if [ -z "${CUDA_LIB_DIR}" ] && [ -d /opt/cuda/targets/x86_64-linux/lib ]; then CUDA_LIB_DIR=/opt/cuda/targets/x86_64-linux/lib fi +# Allow override in case libcudart is in the wrong place +if [ -z "${CUDART_LIB_DIR}" ]; then + CUDART_LIB_DIR="${CUDA_LIB_DIR}" +fi + if [ -d "${CUDA_LIB_DIR}" ]; then echo "CUDA libraries detected - building dynamic CUDA library" init_vars @@ -151,6 +155,8 @@ if [ -d "${CUDA_LIB_DIR}" ]; then cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/lib/" elif [ -e "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" ]; then cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/lib/" + elif [ -e "${CUDART_LIB_DIR}/${lib}" ]; then + cp -d ${CUDART_LIB_DIR}/${lib}* "${BUILD_DIR}/lib/" else cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/lib/" fi @@ -173,17 +179,27 @@ fi if [ -d "${ROCM_PATH}" ]; then echo "ROCm libraries detected - building dynamic ROCm library" - if [ -f ${ROCM_PATH}/lib/librocm_smi64.so.? ]; then - ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocm_smi64.so.? | cut -f3 -d. || true) + if [ -f ${ROCM_PATH}/lib/librocblas.so.*.*.????? ]; then + ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocblas.so.*.*.????? | cut -f5 -d. || true) fi init_vars CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)" BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/rocm${ROCM_VARIANT}" - EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,${ROCM_PATH}/lib,-rpath,/opt/amdgpu/lib/x86_64-linux-gnu/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu" + EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu" build - # Note: the ROCM libs and runtime library files are too large to embed, so we depend on - # them being present at runtime on the host + # Record the ROCM dependencies + rm -f "${BUILD_DIR}/lib/deps.txt" + touch "${BUILD_DIR}/lib/deps.txt" + for dep in $(ldd "${BUILD_DIR}/lib/libext_server.so" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e rocm -e amdgpu -e libtinfo ); do + echo "${dep}" >> "${BUILD_DIR}/lib/deps.txt" + done + # bomb out if for some reason we didn't get a few deps + if [ $(cat "${BUILD_DIR}/lib/deps.txt" | wc -l ) -lt 8 ] ; then + cat "${BUILD_DIR}/lib/deps.txt" + echo "ERROR: deps file short" + exit 1 + fi compress_libs fi diff --git a/llm/generate/gen_windows.ps1 b/llm/generate/gen_windows.ps1 index f7a241cc..579b2bca 100644 --- a/llm/generate/gen_windows.ps1 +++ b/llm/generate/gen_windows.ps1 @@ -2,23 +2,58 @@ $ErrorActionPreference = "Stop" +function amdGPUs { + if ($env:AMDGPU_TARGETS) { + return $env:AMDGPU_TARGETS + } + # TODO - load from some common data file for linux + windows build consistency + $GPU_LIST = @( + "gfx900" + "gfx906:xnack-" + "gfx908:xnack-" + "gfx90a:xnack+" + "gfx90a:xnack-" + "gfx1010" + "gfx1012" + "gfx1030" + "gfx1100" + "gfx1101" + "gfx1102" + ) + $GPU_LIST -join ';' +} + function init_vars { + # Verify the environment is a Developer Shell for MSVC 2019 + write-host $env:VSINSTALLDIR + if (($env:VSINSTALLDIR -eq $null)) { + Write-Error "`r`nBUILD ERROR - YOUR DEVELOPMENT ENVIRONMENT IS NOT SET UP CORRECTLY`r`nTo build Ollama you must run from an MSVC Developer Shell`r`nSee .\docs\development.md for instructions to set up your dev environment" + exit 1 + } + $script:SRC_DIR = $(resolve-path "..\..\") $script:llamacppDir = "../llama.cpp" - $script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-A","x64") + $script:cmakeDefs = @( + "-DBUILD_SHARED_LIBS=on", + "-DLLAMA_NATIVE=off" + ) $script:cmakeTargets = @("ext_server") $script:ARCH = "amd64" # arm not yet supported. if ($env:CGO_CFLAGS -contains "-g") { - $script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on") + $script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo") $script:config = "RelWithDebInfo" } else { - $script:cmakeDefs += @("-DLLAMA_SERVER_VERBOSE=off") + $script:cmakeDefs += @("-DLLAMA_SERVER_VERBOSE=off", "-DCMAKE_BUILD_TYPE=Release") $script:config = "Release" } + if ($null -ne $env:CMAKE_SYSTEM_VERSION) { + $script:cmakeDefs += @("-DCMAKE_SYSTEM_VERSION=${env:CMAKE_SYSTEM_VERSION}") + } # Try to find the CUDA dir if ($env:CUDA_LIB_DIR -eq $null) { $d=(get-command -ea 'silentlycontinue' nvcc).path if ($d -ne $null) { $script:CUDA_LIB_DIR=($d| split-path -parent) + $script:CUDA_INCLUDE_DIR=($script:CUDA_LIB_DIR|split-path -parent)+"\include" } } else { $script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR @@ -30,6 +65,11 @@ function init_vars { } else { $script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES } + # Note: 10 Windows Kit signtool crashes with GCP's plugin + ${script:SignTool}="C:\Program Files (x86)\Windows Kits\8.1\bin\x64\signtool.exe" + if ("${env:KEY_CONTAINER}") { + ${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt") + } } function git_module_setup { @@ -56,8 +96,8 @@ function apply_patches { } # Checkout each file + Set-Location -Path ${script:llamacppDir} foreach ($file in $filePaths) { - Set-Location -Path ${script:llamacppDir} git checkout $file } } @@ -89,13 +129,23 @@ function install { md "${script:buildDir}/lib" -ea 0 > $null cp "${script:buildDir}/bin/${script:config}/ext_server.dll" "${script:buildDir}/lib" cp "${script:buildDir}/bin/${script:config}/llama.dll" "${script:buildDir}/lib" - # Display the dll dependencies in the build log if ($script:DUMPBIN -ne $null) { & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/${script:config}/ext_server.dll" | select-string ".dll" } } +function sign { + if ("${env:KEY_CONTAINER}") { + write-host "Signing ${script:buildDir}/lib/*.dll" + foreach ($file in (get-childitem "${script:buildDir}/lib/*.dll")){ + & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` + /csp "Google Cloud KMS Provider" /kc "${env:KEY_CONTAINER}" $file + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + } +} + function compress_libs { if ($script:GZIP -eq $null) { write-host "gzip not installed, not compressing files" @@ -109,8 +159,23 @@ function compress_libs { } function cleanup { + $patches = Get-ChildItem "../patches/*.diff" + foreach ($patch in $patches) { + # Extract file paths from the patch file + $filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object { + $parts = $_ -split ' ' + ($parts[1] -split '/', 2)[1] + } + + # Checkout each file + Set-Location -Path ${script:llamacppDir} + foreach ($file in $filePaths) { + git checkout $file + } + } Set-Location "${script:llamacppDir}/examples/server" git checkout CMakeLists.txt server.cpp + } init_vars @@ -122,54 +187,89 @@ apply_patches # -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen # -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver -$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on", "-DLLAMA_NATIVE=off") +$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on") -$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs +init_vars +$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu" write-host "Building LCD CPU" build install +sign compress_libs -$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs +init_vars +$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx" write-host "Building AVX CPU" build install +sign compress_libs -$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs +init_vars +$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx2" write-host "Building AVX2 CPU" build install +sign compress_libs if ($null -ne $script:CUDA_LIB_DIR) { # Then build cuda as a dynamically loaded library - $nvcc = (get-command -ea 'silentlycontinue' nvcc) - if ($null -ne $nvcc) { - $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename - } + $nvcc = "$script:CUDA_LIB_DIR\nvcc.exe" + $script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename if ($null -ne $script:CUDA_VERSION) { $script:CUDA_VARIANT="_"+$script:CUDA_VERSION } init_vars $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT" - $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}") + $script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}") + write-host "Building CUDA" build install - cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib" - cp "${script:CUDA_LIB_DIR}/cublas64_*.dll" "${script:buildDir}/lib" - cp "${script:CUDA_LIB_DIR}/cublasLt64_*.dll" "${script:buildDir}/lib" + sign compress_libs } -# TODO - actually implement ROCm support on windows -$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/rocm" -rm -ea 0 -recurse -force -path "${script:buildDir}/lib" -md "${script:buildDir}/lib" -ea 0 > $null -echo $null >> "${script:buildDir}/lib/.generated" +if ($null -ne $env:HIP_PATH) { + $script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename + if ($null -ne $script:ROCM_VERSION) { + $script:ROCM_VARIANT="_v"+$script:ROCM_VERSION + } + + init_vars + $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT" + $script:cmakeDefs += @( + "-G", "Ninja", + "-DCMAKE_C_COMPILER=clang.exe", + "-DCMAKE_CXX_COMPILER=clang++.exe", + "-DLLAMA_HIPBLAS=on", + "-DLLAMA_AVX=on", + "-DLLAMA_AVX2=off", + "-DCMAKE_POSITION_INDEPENDENT_CODE=on", + "-DAMDGPU_TARGETS=$(amdGPUs)", + "-DGPU_TARGETS=$(amdGPUs)" + ) + + # Make sure the ROCm binary dir is first in the path + $env:PATH="$env:HIP_PATH\bin;$env:VSINSTALLDIR\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja;$env:PATH" + + # We have to clobber the LIB var from the developer shell for clang to work properly + $env:LIB="" + + write-host "Building ROCm" + build + # Ninja doesn't prefix with config name + ${script:config}="" + install + if ($null -ne $script:DUMPBIN) { + & "$script:DUMPBIN" /dependents "${script:buildDir}/bin/${script:config}/ext_server.dll" | select-string ".dll" + } + sign + compress_libs +} cleanup -write-host "`ngo generate completed" \ No newline at end of file +write-host "`ngo generate completed" diff --git a/llm/ggla.go b/llm/ggla.go new file mode 100644 index 00000000..e22dd59f --- /dev/null +++ b/llm/ggla.go @@ -0,0 +1,152 @@ +package llm + +import ( + "encoding/binary" + "errors" + "io" + "slices" +) + +type ContainerGGLA struct { + version uint32 +} + +func (c *ContainerGGLA) Name() string { + return "ggla" +} + +func (c *ContainerGGLA) Decode(rso *readSeekOffset) (model, error) { + binary.Read(rso, binary.LittleEndian, &c.version) + + switch c.version { + case 1: + default: + return nil, errors.New("invalid version") + } + + model := newModelGGLA(c) + err := model.decode(rso) + return model, err +} + +type ModelGGLA struct { + *ContainerGGLA + + kv KV + tensors []Tensor +} + +func newModelGGLA(container *ContainerGGLA) *ModelGGLA { + return &ModelGGLA{ + ContainerGGLA: container, + kv: make(KV), + } +} + +func (m *ModelGGLA) decode(rso *readSeekOffset) error { + var r uint32 + if err := binary.Read(rso, binary.LittleEndian, &r); err != nil { + return err + } + m.kv["r"] = r + + var alpha uint32 + if err := binary.Read(rso, binary.LittleEndian, &alpha); err != nil { + return err + } + m.kv["alpha"] = alpha + + for { + var dims uint32 + if err := binary.Read(rso, binary.LittleEndian, &dims); err != nil { + return err + } + + var namesize uint32 + if err := binary.Read(rso, binary.LittleEndian, &namesize); err != nil { + return err + } + + var t Tensor + if err := binary.Read(rso, binary.LittleEndian, &t.Kind); err != nil { + return err + } + + t.Shape = make([]uint64, dims) + for i := 0; uint32(i) < dims; i++ { + var shape32 uint32 + if err := binary.Read(rso, binary.LittleEndian, &shape32); err != nil { + return err + } + + t.Shape[i] = uint64(shape32) + } + + // ggla tensor shape is reversed + // ref: https://github.com/ggerganov/llama.cpp/blob/29ae62d2ae163e2b68aa0ad3bf2ab4636de0c957/convert-lora-to-ggml.py#L44 + slices.Reverse(t.Shape) + + name := make([]byte, namesize) + if err := binary.Read(rso, binary.LittleEndian, &name); err != nil { + return err + } + + t.Name = string(name) + + if _, err := rso.Seek((rso.offset+31)&-32, io.SeekStart); err != nil { + return err + } + + t.Offset = uint64(rso.offset) + + if _, err := rso.Seek(int64(t.Size()), io.SeekCurrent); err != nil { + return err + } + + m.tensors = append(m.tensors, t) + } +} + +func (m *ModelGGLA) KV() KV { + return m.kv +} + +func (m *ModelGGLA) Tensor() []Tensor { + return m.tensors +} + +func (*ModelGGLA) ModelFamily() string { + return "ggla" +} + +func (*ModelGGLA) ModelType() string { + panic("not implemented") +} + +func (*ModelGGLA) FileType() string { + panic("not implemented") +} + +func (*ModelGGLA) NumLayers() uint32 { + panic("not implemented") +} + +func (*ModelGGLA) NumGQA() uint32 { + panic("not implemented") +} + +func (*ModelGGLA) NumEmbed() uint32 { + panic("not implemented") +} + +func (*ModelGGLA) NumHead() uint32 { + panic("not implemented") +} + +func (*ModelGGLA) NumHeadKv() uint32 { + panic("not implemented") +} + +func (*ModelGGLA) NumCtx() uint32 { + panic("not implemented") +} diff --git a/llm/ggml.go b/llm/ggml.go index 3fb0539c..88cd9e13 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -31,6 +31,11 @@ const ( fileTypeQ5_K_S fileTypeQ5_K_M fileTypeQ6_K + fileTypeIQ2_XXS + fileTypeIQ2_XS + fileTypeQ2_K_S + fileTypeQ3_K_XS + fileTypeIQ3_XXS ) func fileType(fileType uint32) string { @@ -69,6 +74,16 @@ func fileType(fileType uint32) string { return "Q5_K_M" case fileTypeQ6_K: return "Q6_K" + case fileTypeIQ2_XXS: + return "IQ2_XXS" + case fileTypeIQ2_XS: + return "IQ2_XS" + case fileTypeQ2_K_S: + return "Q2_K_S" + case fileTypeQ3_K_XS: + return "Q3_K_XS" + case fileTypeIQ3_XXS: + return "IQ3_XXS" default: return "unknown" } @@ -91,32 +106,6 @@ type container interface { Decode(*readSeekOffset) (model, error) } -type containerLORA struct { - version uint32 -} - -func (c *containerLORA) Name() string { - return "ggla" -} - -func (c *containerLORA) Decode(rso *readSeekOffset) (model, error) { - var version uint32 - binary.Read(rso, binary.LittleEndian, &version) - - switch version { - case 1: - default: - return nil, errors.New("invalid version") - } - - c.version = version - - // remaining file contents aren't decoded - rso.Seek(0, io.SeekEnd) - - return nil, nil -} - const ( // Magic constant for `ggml` files (unversioned). FILE_MAGIC_GGML = 0x67676d6c @@ -146,17 +135,19 @@ func DecodeGGML(r io.ReadSeeker) (*GGML, error) { case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT: return nil, ErrUnsupportedFormat case FILE_MAGIC_GGLA: - c = &containerLORA{} + c = &ContainerGGLA{} case FILE_MAGIC_GGUF_LE: - c = &containerGGUF{bo: binary.LittleEndian} + c = &ContainerGGUF{ByteOrder: binary.LittleEndian} case FILE_MAGIC_GGUF_BE: - c = &containerGGUF{bo: binary.BigEndian} + c = &ContainerGGUF{ByteOrder: binary.BigEndian} default: return nil, errors.New("invalid file magic") } model, err := c.Decode(&ro) - if err != nil { + if errors.Is(err, io.EOF) { + // noop + } else if err != nil { return nil, err } diff --git a/llm/gguf.go b/llm/gguf.go index 436be42c..61c55148 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -5,12 +5,20 @@ import ( "encoding/binary" "fmt" "io" + "log/slog" + "os" + "regexp" + + "github.com/d4l3k/go-bfloat16" + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" + "github.com/x448/float16" "github.com/jmorganca/ollama/format" ) -type containerGGUF struct { - bo binary.ByteOrder +type ContainerGGUF struct { + ByteOrder binary.ByteOrder Version uint32 @@ -23,23 +31,28 @@ type containerGGUF struct { NumTensor uint64 NumKV uint64 } + + V3 struct { + NumTensor uint64 + NumKV uint64 + } } -func (c *containerGGUF) Name() string { +func (c *ContainerGGUF) Name() string { return "gguf" } -func (c *containerGGUF) Decode(rso *readSeekOffset) (model, error) { - binary.Read(rso, c.bo, &c.Version) +func (c *ContainerGGUF) Decode(rso *readSeekOffset) (model, error) { + binary.Read(rso, c.ByteOrder, &c.Version) switch c.Version { case 1: - binary.Read(rso, c.bo, &c.V1) + binary.Read(rso, c.ByteOrder, &c.V1) default: - binary.Read(rso, c.bo, &c.V2) + binary.Read(rso, c.ByteOrder, &c.V2) } - model := newGGUFModel(c) + model := NewGGUFModel(c) if err := model.Decode(rso); err != nil { return nil, err } @@ -48,47 +61,61 @@ func (c *containerGGUF) Decode(rso *readSeekOffset) (model, error) { } const ( - ggufTypeUint8 uint32 = iota - ggufTypeInt8 - ggufTypeUint16 - ggufTypeInt16 - ggufTypeUint32 - ggufTypeInt32 - ggufTypeFloat32 - ggufTypeBool - ggufTypeString - ggufTypeArray - ggufTypeUint64 - ggufTypeInt64 - ggufTypeFloat64 + _ uint32 = iota + GGUFTokenNormal + GGUFTokenUnknown + GGUFTokenControl + GGUFTokenUserDefined + GGUFTokenUnused + GGUFTokenByte ) -type kv map[string]any +const ( + GGUFTypeUint8 uint32 = iota + GGUFTypeInt8 + GGUFTypeUint16 + GGUFTypeInt16 + GGUFTypeUint32 + GGUFTypeInt32 + GGUFTypeFloat32 + GGUFTypeBool + GGUFTypeString + GGUFTypeArray + GGUFTypeUint64 + GGUFTypeInt64 + GGUFTypeFloat64 +) -type tensor struct { - name string - kind uint32 - offset uint64 +type KV map[string]any + +type Tensor struct { + Name string + Kind uint32 + Offset uint64 // shape is the number of elements in each dimension - shape [4]uint64 + Shape []uint64 + + FileName string + OffsetPadding uint64 + FileOffsets []uint64 } -func (t tensor) blockSize() uint64 { +func (t Tensor) BlockSize() uint64 { switch { - case t.kind < 2: + case t.Kind < 2: return 1 - case t.kind < 10: + case t.Kind < 10: return 32 default: return 256 } } -func (t tensor) typeSize() uint64 { - blockSize := t.blockSize() +func (t Tensor) TypeSize() uint64 { + blockSize := t.BlockSize() - switch t.kind { + switch t.Kind { case 0: // FP32 return 4 case 1: // FP16 @@ -115,36 +142,80 @@ func (t tensor) typeSize() uint64 { return 2 + 2 + 12 + blockSize/8 + blockSize/2 case 14: // Q6_K return blockSize/2 + blockSize/4 + blockSize/16 + 2 + case 15: // Q8_K + return 2 + blockSize + 2*blockSize/16 + case 16: // IQ2_XXS + return 2 + 2*blockSize/8 + case 17: // IQ2_XS + return 2 + 2*blockSize/8 + blockSize/32 + case 18: // IQ3_XXS + return 2 + 3*blockSize/8 default: return 0 } } -func (t tensor) parameters() uint64 { - return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3] +func (t Tensor) Parameters() uint64 { + var count uint64 = 1 + for _, n := range t.Shape { + count *= n + } + return count } -func (t tensor) size() uint64 { - return t.parameters() * t.typeSize() / t.blockSize() +func (t Tensor) Size() uint64 { + return t.Parameters() * t.TypeSize() / t.BlockSize() } -type ggufModel struct { - *containerGGUF +func (t Tensor) Repack(data []uint16, heads int) ([]uint16, error) { + n := tensor.New(tensor.WithShape(int(t.Shape[0]), int(t.Shape[1])), tensor.WithBacking(data)) + origShape := n.Shape().Clone() - kv - tensors []tensor + // reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf + if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil { + return []uint16{}, err + } + + if err := n.T(0, 2, 1, 3); err != nil { + return []uint16{}, err + } + + if err := n.Reshape(origShape...); err != nil { + return []uint16{}, err + } + + if err := n.Transpose(); err != nil { + return []uint16{}, err + } + newN, err := native.SelectU16(n, 1) + if err != nil { + return []uint16{}, err + } + + var fullTensor []uint16 + for _, v := range newN { + fullTensor = append(fullTensor, v...) + } + return fullTensor, nil +} + +type GGUFModel struct { + *ContainerGGUF + + KV + Tensors []Tensor parameters uint64 } -func newGGUFModel(container *containerGGUF) *ggufModel { - return &ggufModel{ - containerGGUF: container, - kv: make(kv), +func NewGGUFModel(container *ContainerGGUF) *GGUFModel { + return &GGUFModel{ + ContainerGGUF: container, + KV: make(KV), } } -func (llm *ggufModel) NumTensor() uint64 { +func (llm *GGUFModel) NumTensor() uint64 { if llm.Version == 1 { return uint64(llm.V1.NumTensor) } @@ -152,7 +223,7 @@ func (llm *ggufModel) NumTensor() uint64 { return llm.V2.NumTensor } -func (llm *ggufModel) NumKV() uint64 { +func (llm *GGUFModel) NumKV() uint64 { if llm.Version == 1 { return uint64(llm.V1.NumKV) } @@ -160,15 +231,15 @@ func (llm *ggufModel) NumKV() uint64 { return llm.V2.NumKV } -func (llm *ggufModel) ModelFamily() string { - if t, ok := llm.kv["general.architecture"].(string); ok { +func (llm *GGUFModel) ModelFamily() string { + if t, ok := llm.KV["general.architecture"].(string); ok { return t } return "unknown" } -func (llm *ggufModel) ModelType() string { +func (llm *GGUFModel) ModelType() string { if llm.parameters > 0 { return format.HumanNumber(llm.parameters) } @@ -176,15 +247,393 @@ func (llm *ggufModel) ModelType() string { return "unknown" } -func (llm *ggufModel) FileType() string { - if t, ok := llm.kv["general.file_type"].(uint32); ok { +func (llm *GGUFModel) FileType() string { + if t, ok := llm.KV["general.file_type"].(uint32); ok { return fileType(t) } return "unknown" } -func (llm *ggufModel) Decode(rso *readSeekOffset) error { +func (llm *GGUFModel) Encode(f *os.File) error { + // this mimics the order of the llama.cpp convert script + kOrder := []string{ + "general.architecture", + "general.name", + "llama.context_length", + "llama.embedding_length", + "llama.block_count", + "llama.feed_forward_length", + "llama.rope.dimension_count", + "llama.attention.head_count", + "llama.attention.head_count_kv", + "llama.attention.layer_norm_rms_epsilon", + "llama.rope.freq_base", + "general.file_type", + "tokenizer.ggml.model", + "tokenizer.ggml.tokens", + "tokenizer.ggml.scores", + "tokenizer.ggml.token_type", + "tokenizer.ggml.bos_token_id", + "tokenizer.ggml.eos_token_id", + "tokenizer.ggml.unknown_token_id", + "tokenizer.ggml.add_bos_token", + "tokenizer.ggml.add_eos_token", + "tokenizer.chat_template", + } + + if err := binary.Write(f, llm.ByteOrder, []byte("GGUF")); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, uint32(3)); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, uint64(llm.V3.NumTensor)); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, uint64(llm.V3.NumKV)); err != nil { + return err + } + + for _, k := range kOrder { + val, ok := llm.KV[k] + if !ok { + continue + } + + if err := binary.Write(f, llm.ByteOrder, uint64(len(k))); err != nil { + return err + } + if err := binary.Write(f, llm.ByteOrder, []byte(k)); err != nil { + return err + } + + switch v := val.(type) { + case uint32: + if err := binary.Write(f, llm.ByteOrder, GGUFTypeUint32); err != nil { + return err + } + + if err := llm.writeUint32(f, v); err != nil { + return err + } + case float32: + if err := binary.Write(f, llm.ByteOrder, GGUFTypeFloat32); err != nil { + return err + } + + if err := llm.writeF32(f, v); err != nil { + return err + } + case bool: + if err := binary.Write(f, llm.ByteOrder, GGUFTypeBool); err != nil { + return err + } + + if err := llm.writeBool(f, v); err != nil { + return err + } + case string: + if err := binary.Write(f, llm.ByteOrder, GGUFTypeString); err != nil { + return err + } + + if err := llm.writeString(f, v); err != nil { + return err + } + case []int32: + if err := binary.Write(f, llm.ByteOrder, GGUFTypeArray); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, GGUFTypeInt32); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, uint64(len(v))); err != nil { + return err + } + for _, i := range v { + if err := llm.writeInt32(f, i); err != nil { + return err + } + } + case []uint32: + if err := binary.Write(f, llm.ByteOrder, GGUFTypeArray); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, GGUFTypeUint32); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, uint64(len(v))); err != nil { + return err + } + for _, i := range v { + if err := llm.writeUint32(f, i); err != nil { + return err + } + } + case []float32: + if err := binary.Write(f, llm.ByteOrder, GGUFTypeArray); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, GGUFTypeFloat32); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, uint64(len(v))); err != nil { + return err + } + for _, fl := range v { + if err := llm.writeF32(f, fl); err != nil { + return err + } + } + case []string: + if err := binary.Write(f, llm.ByteOrder, GGUFTypeArray); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, GGUFTypeString); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, uint64(len(v))); err != nil { + return err + } + + for _, s := range v { + if err := llm.writeString(f, s); err != nil { + return err + } + } + } + } + + // write layer metadata + for _, t := range llm.Tensors { + if err := llm.writeString(f, t.Name); err != nil { + return err + } + + // the dimensions of the tensor + dims := 1 + if t.Shape[1] > 0 { + dims = 2 + } + + if err := binary.Write(f, llm.ByteOrder, uint32(dims)); err != nil { + return err + } + + for i := 0; i < dims; i++ { + if err := binary.Write(f, llm.ByteOrder, uint64(t.Shape[dims-1-i])); err != nil { + return err + } + } + + if err := binary.Write(f, llm.ByteOrder, uint32(t.Kind)); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, uint64(t.Offset)); err != nil { + return err + } + } + + offset, terr := f.Seek(0, io.SeekCurrent) + if terr != nil { + return terr + } + slog.Debug(fmt.Sprintf("tensors offset = %x", offset)) + + if err := llm.writePadding(f, 32); err != nil { + return err + } + + var dataFile *os.File + var currentFile string + var err error + for _, t := range llm.Tensors { + if currentFile != t.FileName { + if f != nil { + dataFile.Close() + } + currentFile = t.FileName + dataFile, err = os.Open(t.FileName) + if err != nil { + fmt.Println(err) + return err + } + } + + dataFile.Seek(int64(t.OffsetPadding+t.FileOffsets[0]), 0) + + pattern := `^blk\.[0-9]+\.attn_(?Pq|k)\.weight$` + re, err := regexp.Compile(pattern) + if err != nil { + return err + } + + matches := re.FindAllStringSubmatch(t.Name, -1) + if len(matches) > 0 { + layerSize := t.FileOffsets[1] - t.FileOffsets[0] + + var err error + tData := make([]uint16, layerSize/2) + if err = binary.Read(dataFile, llm.ByteOrder, tData); err != nil { + return err + } + + layerType := matches[0][re.SubexpIndex("layer")] + var heads uint32 + switch layerType { + case "q": + heads = llm.KV["llama.attention.head_count"].(uint32) + case "k": + heads = llm.KV["llama.attention.head_count_kv"].(uint32) + if heads == 0 { + heads = llm.KV["llama.attention.head_count"].(uint32) + } + } + + tData, err = t.Repack(tData, int(heads)) + if err != nil { + return err + } + + var buf []byte + for _, n := range tData { + buf = binary.LittleEndian.AppendUint16(buf, n) + } + + tempBuf := make([]uint16, len(tData)) + tDataF32 := bfloat16.DecodeFloat32(buf) + for cnt, v := range tDataF32 { + tDataF16 := float16.Fromfloat32(v) + tempBuf[cnt] = uint16(tDataF16) + } + + if err = binary.Write(f, llm.ByteOrder, tempBuf); err != nil { + return err + } + + if err := llm.writePadding(f, 32); err != nil { + return err + } + continue + } + + remaining := t.FileOffsets[1] - t.FileOffsets[0] + + bufSize := uint64(10240) + var finished bool + for { + data := make([]byte, min(bufSize, remaining)) + + b, err := io.ReadFull(dataFile, data) + remaining -= uint64(b) + + if err == io.EOF || remaining <= 0 { + finished = true + } else if err != nil { + return err + } + + // convert bfloat16 -> ieee float32 + tDataF32 := bfloat16.DecodeFloat32(data) + + switch t.Kind { + case 0: + if err := binary.Write(f, llm.ByteOrder, tDataF32); err != nil { + return err + } + case 1: + // convert float32 -> float16 + tempBuf := make([]uint16, len(data)/2) + for cnt, v := range tDataF32 { + tDataF16 := float16.Fromfloat32(v) + tempBuf[cnt] = uint16(tDataF16) + } + if err := binary.Write(f, llm.ByteOrder, tempBuf); err != nil { + return err + } + } + if finished { + break + } + } + + if err := llm.writePadding(f, 32); err != nil { + return err + } + } + f.Close() + + return nil +} + +func (llm *GGUFModel) writePadding(f *os.File, align int64) error { + // gguf file padding is defined in https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#file-structure + offset, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return err + } + padding := ((offset + align - 1) / align) * align + buf := make([]byte, padding-offset) + if err := binary.Write(f, llm.ByteOrder, buf); err != nil { + return err + } + + return nil +} + +func (llm *GGUFModel) writeInt32(f *os.File, v int32) error { + if err := binary.Write(f, llm.ByteOrder, v); err != nil { + return err + } + return nil +} + +func (llm *GGUFModel) writeUint32(f *os.File, v uint32) error { + if err := binary.Write(f, llm.ByteOrder, v); err != nil { + return err + } + return nil +} + +func (llm *GGUFModel) writeF32(f *os.File, v float32) error { + if err := binary.Write(f, llm.ByteOrder, v); err != nil { + return err + } + return nil +} + +func (llm *GGUFModel) writeBool(f *os.File, b bool) error { + if err := binary.Write(f, llm.ByteOrder, b); err != nil { + return err + } + return nil +} + +func (llm *GGUFModel) writeString(f *os.File, s string) error { + if err := binary.Write(f, llm.ByteOrder, uint64(len(s))); err != nil { + return err + } + + if err := binary.Write(f, llm.ByteOrder, []byte(s)); err != nil { + return err + } + return nil +} + +func (llm *GGUFModel) Decode(rso *readSeekOffset) error { // decode key-values for i := 0; uint64(i) < llm.NumKV(); i++ { k, err := llm.readString(rso) @@ -196,36 +645,36 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error { var v any switch vtype { - case ggufTypeUint8: + case GGUFTypeUint8: v = llm.readU8(rso) - case ggufTypeInt8: + case GGUFTypeInt8: v = llm.readI8(rso) - case ggufTypeUint16: + case GGUFTypeUint16: v = llm.readU16(rso) - case ggufTypeInt16: + case GGUFTypeInt16: v = llm.readI16(rso) - case ggufTypeUint32: + case GGUFTypeUint32: v = llm.readU32(rso) - case ggufTypeInt32: + case GGUFTypeInt32: v = llm.readI32(rso) - case ggufTypeUint64: + case GGUFTypeUint64: v = llm.readU64(rso) - case ggufTypeInt64: + case GGUFTypeInt64: v = llm.readI64(rso) - case ggufTypeFloat32: + case GGUFTypeFloat32: v = llm.readF32(rso) - case ggufTypeFloat64: + case GGUFTypeFloat64: v = llm.readF64(rso) - case ggufTypeBool: + case GGUFTypeBool: v = llm.readBool(rso) - case ggufTypeString: + case GGUFTypeString: s, err := llm.readString(rso) if err != nil { return err } v = s - case ggufTypeArray: + case GGUFTypeArray: a, err := llm.readArray(rso) if err != nil { return err @@ -236,7 +685,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error { return fmt.Errorf("invalid type: %d", vtype) } - llm.kv[k] = v + llm.KV[k] = v } // decode tensors @@ -254,33 +703,33 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error { shape[i] = llm.readU64(rso) } - tensor := tensor{ - name: name, - kind: llm.readU32(rso), - offset: llm.readU64(rso), - shape: shape, + tensor := Tensor{ + Name: name, + Kind: llm.readU32(rso), + Offset: llm.readU64(rso), + Shape: shape[:], } - llm.tensors = append(llm.tensors, tensor) - llm.parameters += tensor.parameters() + llm.Tensors = append(llm.Tensors, tensor) + llm.parameters += tensor.Parameters() } - alignment, ok := llm.kv["general.alignment"].(uint32) + alignment, ok := llm.KV["general.alignment"].(uint32) if !ok { alignment = 32 } rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent) - for _, tensor := range llm.tensors { - padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1) + for _, tensor := range llm.Tensors { + padded := (int64(tensor.Size()) + int64(alignment) - 1) & ^(int64(alignment) - 1) rso.Seek(padded, io.SeekCurrent) } return nil } -func (llm *ggufModel) NumLayers() uint32 { - value, exists := llm.kv[fmt.Sprintf("%s.block_count", llm.ModelFamily())] +func (llm *GGUFModel) NumLayers() uint32 { + value, exists := llm.KV[fmt.Sprintf("%s.block_count", llm.ModelFamily())] if !exists { return 0 } @@ -288,8 +737,8 @@ func (llm *ggufModel) NumLayers() uint32 { return value.(uint32) } -func (llm *ggufModel) NumHead() uint32 { - value, exists := llm.kv[fmt.Sprintf("%s.attention.head_count", llm.ModelFamily())] +func (llm *GGUFModel) NumHead() uint32 { + value, exists := llm.KV[fmt.Sprintf("%s.attention.head_count", llm.ModelFamily())] if !exists { return 0 } @@ -297,8 +746,8 @@ func (llm *ggufModel) NumHead() uint32 { return value.(uint32) } -func (llm *ggufModel) NumEmbed() uint32 { - value, exists := llm.kv[fmt.Sprintf("%s.embedding_length", llm.ModelFamily())] +func (llm *GGUFModel) NumEmbed() uint32 { + value, exists := llm.KV[fmt.Sprintf("%s.embedding_length", llm.ModelFamily())] if !exists { return 0 } @@ -306,8 +755,8 @@ func (llm *ggufModel) NumEmbed() uint32 { return value.(uint32) } -func (llm *ggufModel) NumHeadKv() uint32 { - value, exists := llm.kv[fmt.Sprintf("%s.attention.head_count_kv", llm.ModelFamily())] +func (llm *GGUFModel) NumHeadKv() uint32 { + value, exists := llm.KV[fmt.Sprintf("%s.attention.head_count_kv", llm.ModelFamily())] if !exists { return 0 } @@ -315,8 +764,8 @@ func (llm *ggufModel) NumHeadKv() uint32 { return value.(uint32) } -func (llm *ggufModel) NumCtx() uint32 { - value, exists := llm.kv[fmt.Sprintf("%s.context_length", llm.ModelFamily())] +func (llm *GGUFModel) NumCtx() uint32 { + value, exists := llm.KV[fmt.Sprintf("%s.context_length", llm.ModelFamily())] if !exists { return 0 } @@ -324,7 +773,7 @@ func (llm *ggufModel) NumCtx() uint32 { return value.(uint32) } -func (llm *ggufModel) NumGQA() uint32 { +func (llm *GGUFModel) NumGQA() uint32 { numHeadKv := llm.NumHeadKv() if numHeadKv == 0 { return 0 @@ -333,75 +782,75 @@ func (llm *ggufModel) NumGQA() uint32 { return llm.NumHead() / numHeadKv } -func (llm ggufModel) readU8(r io.Reader) uint8 { +func (llm GGUFModel) readU8(r io.Reader) uint8 { var u8 uint8 - binary.Read(r, llm.bo, &u8) + binary.Read(r, llm.ByteOrder, &u8) return u8 } -func (llm ggufModel) readI8(r io.Reader) int8 { +func (llm GGUFModel) readI8(r io.Reader) int8 { var i8 int8 - binary.Read(r, llm.bo, &i8) + binary.Read(r, llm.ByteOrder, &i8) return i8 } -func (llm ggufModel) readU16(r io.Reader) uint16 { +func (llm GGUFModel) readU16(r io.Reader) uint16 { var u16 uint16 - binary.Read(r, llm.bo, &u16) + binary.Read(r, llm.ByteOrder, &u16) return u16 } -func (llm ggufModel) readI16(r io.Reader) int16 { +func (llm GGUFModel) readI16(r io.Reader) int16 { var i16 int16 - binary.Read(r, llm.bo, &i16) + binary.Read(r, llm.ByteOrder, &i16) return i16 } -func (llm ggufModel) readU32(r io.Reader) uint32 { +func (llm GGUFModel) readU32(r io.Reader) uint32 { var u32 uint32 - binary.Read(r, llm.bo, &u32) + binary.Read(r, llm.ByteOrder, &u32) return u32 } -func (llm ggufModel) readI32(r io.Reader) int32 { +func (llm GGUFModel) readI32(r io.Reader) int32 { var i32 int32 - binary.Read(r, llm.bo, &i32) + binary.Read(r, llm.ByteOrder, &i32) return i32 } -func (llm ggufModel) readU64(r io.Reader) uint64 { +func (llm GGUFModel) readU64(r io.Reader) uint64 { var u64 uint64 - binary.Read(r, llm.bo, &u64) + binary.Read(r, llm.ByteOrder, &u64) return u64 } -func (llm ggufModel) readI64(r io.Reader) int64 { +func (llm GGUFModel) readI64(r io.Reader) int64 { var i64 int64 - binary.Read(r, llm.bo, &i64) + binary.Read(r, llm.ByteOrder, &i64) return i64 } -func (llm ggufModel) readF32(r io.Reader) float32 { +func (llm GGUFModel) readF32(r io.Reader) float32 { var f32 float32 - binary.Read(r, llm.bo, &f32) + binary.Read(r, llm.ByteOrder, &f32) return f32 } -func (llm ggufModel) readF64(r io.Reader) float64 { +func (llm GGUFModel) readF64(r io.Reader) float64 { var f64 float64 - binary.Read(r, llm.bo, &f64) + binary.Read(r, llm.ByteOrder, &f64) return f64 } -func (llm ggufModel) readBool(r io.Reader) bool { +func (llm GGUFModel) readBool(r io.Reader) bool { var b bool - binary.Read(r, llm.bo, &b) + binary.Read(r, llm.ByteOrder, &b) return b } -func (llm ggufModel) readStringV1(r io.Reader) (string, error) { +func (llm GGUFModel) readStringV1(r io.Reader) (string, error) { var nameLength uint32 - binary.Read(r, llm.bo, &nameLength) + binary.Read(r, llm.ByteOrder, &nameLength) var b bytes.Buffer if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil { @@ -414,13 +863,13 @@ func (llm ggufModel) readStringV1(r io.Reader) (string, error) { return b.String(), nil } -func (llm ggufModel) readString(r io.Reader) (string, error) { +func (llm GGUFModel) readString(r io.Reader) (string, error) { if llm.Version == 1 { return llm.readStringV1(r) } var nameLength uint64 - binary.Read(r, llm.bo, &nameLength) + binary.Read(r, llm.ByteOrder, &nameLength) var b bytes.Buffer if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil { @@ -430,29 +879,29 @@ func (llm ggufModel) readString(r io.Reader) (string, error) { return b.String(), nil } -func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) { +func (llm *GGUFModel) readArrayV1(r io.Reader) (arr []any, err error) { atype := llm.readU32(r) n := llm.readU32(r) for i := 0; uint32(i) < n; i++ { switch atype { - case ggufTypeUint8: + case GGUFTypeUint8: arr = append(arr, llm.readU8(r)) - case ggufTypeInt8: + case GGUFTypeInt8: arr = append(arr, llm.readI8(r)) - case ggufTypeUint16: + case GGUFTypeUint16: arr = append(arr, llm.readU16(r)) - case ggufTypeInt16: + case GGUFTypeInt16: arr = append(arr, llm.readI16(r)) - case ggufTypeUint32: + case GGUFTypeUint32: arr = append(arr, llm.readU32(r)) - case ggufTypeInt32: + case GGUFTypeInt32: arr = append(arr, llm.readI32(r)) - case ggufTypeFloat32: + case GGUFTypeFloat32: arr = append(arr, llm.readF32(r)) - case ggufTypeBool: + case GGUFTypeBool: arr = append(arr, llm.readBool(r)) - case ggufTypeString: + case GGUFTypeString: s, err := llm.readStringV1(r) if err != nil { return nil, err @@ -467,7 +916,7 @@ func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) { return } -func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) { +func (llm *GGUFModel) readArray(r io.Reader) (arr []any, err error) { if llm.Version == 1 { return llm.readArrayV1(r) } @@ -477,29 +926,29 @@ func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) { for i := 0; uint64(i) < n; i++ { switch atype { - case ggufTypeUint8: + case GGUFTypeUint8: arr = append(arr, llm.readU8(r)) - case ggufTypeInt8: + case GGUFTypeInt8: arr = append(arr, llm.readI8(r)) - case ggufTypeUint16: + case GGUFTypeUint16: arr = append(arr, llm.readU16(r)) - case ggufTypeInt16: + case GGUFTypeInt16: arr = append(arr, llm.readI16(r)) - case ggufTypeUint32: + case GGUFTypeUint32: arr = append(arr, llm.readU32(r)) - case ggufTypeInt32: + case GGUFTypeInt32: arr = append(arr, llm.readI32(r)) - case ggufTypeUint64: + case GGUFTypeUint64: arr = append(arr, llm.readU64(r)) - case ggufTypeInt64: + case GGUFTypeInt64: arr = append(arr, llm.readI64(r)) - case ggufTypeFloat32: + case GGUFTypeFloat32: arr = append(arr, llm.readF32(r)) - case ggufTypeFloat64: + case GGUFTypeFloat64: arr = append(arr, llm.readF64(r)) - case ggufTypeBool: + case GGUFTypeBool: arr = append(arr, llm.readBool(r)) - case ggufTypeString: + case GGUFTypeString: s, err := llm.readString(r) if err != nil { return nil, err diff --git a/llm/llama.cpp b/llm/llama.cpp index d2f650cb..ceca1aef 160000 --- a/llm/llama.cpp +++ b/llm/llama.cpp @@ -1 +1 @@ -Subproject commit d2f650cb5b04ee2726663e79b47da5efe196ce00 +Subproject commit ceca1aef0738b57951cd12c603c3477e75312dec diff --git a/llm/llama.go b/llm/llama.go index 80b4f75b..a5d2036a 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -62,7 +62,7 @@ const maxRetries = 3 type PredictOpts struct { Prompt string Format string - Images []api.ImageData + Images []ImageData Options api.Options } diff --git a/llm/llm.go b/llm/llm.go index 8e2f0714..a3f59d2e 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -6,6 +6,7 @@ import ( "log/slog" "os" "runtime" + "slices" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/gpu" @@ -19,7 +20,11 @@ type LLM interface { Close() } -func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) { +var cpuOnlyFamilies = []string{ + "mamba", +} + +func New(model string, adapters, projectors []string, opts api.Options) (LLM, error) { if _, err := os.Stat(model); err != nil { return nil, err } @@ -48,13 +53,18 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options) size := ggml.Size // fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value - kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead()) + kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(max(ggml.NumHead(), 1)) // this amount is the overhead + tensors in memory // TODO: get this from the llama.cpp's graph calculations instead of // estimating it's 1/6 * kv_cache_size * num_gqa graph := int64(ggml.NumGQA()) * kv / 6 + // certain model architectures don't support gpu inference yet + if slices.Contains(cpuOnlyFamilies, ggml.ModelFamily()) { + opts.NumGPU = 0 + } + info := gpu.GetGPUInfo() switch runtime.GOOS { case "darwin": @@ -63,9 +73,7 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options) } if size+kv+graph > vram { - slog.Info("not enough vram available, falling back to CPU only") - info.Library = "cpu" - info.Variant = gpu.GetCPUVariant() + slog.Info("not enough vram available, setting num_gpu=0") opts.NumGPU = 0 break } @@ -124,8 +132,8 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options) } // Give any native cgo implementations an opportunity to initialize -func Init(workdir string) error { - return nativeInit(workdir) +func Init() error { + return nativeInit() } func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (LLM, error) { @@ -143,6 +151,16 @@ func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []stri } } + // We stage into a temp directory, and if we've been idle for a while, it may have been reaped + _, err := os.Stat(dynLibs[0]) + if err != nil { + slog.Info(fmt.Sprintf("%s has disappeared, reloading libraries", dynLibs[0])) + err = nativeInit() + if err != nil { + return nil, err + } + } + err2 := fmt.Errorf("unable to locate suitable llm library") for _, dynLib := range dynLibs { srv, err := newDynExtServer(dynLib, model, adapters, projectors, opts) diff --git a/llm/patches/01-cache.diff b/llm/patches/01-cache.diff index 79f8d002..e2cd30a2 100644 --- a/llm/patches/01-cache.diff +++ b/llm/patches/01-cache.diff @@ -1,30 +1,21 @@ diff --git a/examples/server/server.cpp b/examples/server/server.cpp -index a48582ad..9fffffd8 100644 +index 8fe5e0b1..3e82acb9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp -@@ -1564,12 +1564,6 @@ struct llama_server_context - LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); - } +@@ -997,13 +997,15 @@ struct llama_server_context + slot.n_sent_text += result.text_to_send.size(); + // add the token to slot queue and cache + } +- slot.add_token_string(result); ++ + if (slot.params.stream) + { + send_partial_response(slot, result); + } + } -- LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); -- -- llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); -- -- slot.cache_tokens = prompt_tokens; -- - if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0) - { - // we have to evaluate at least 1 token to generate logits. -@@ -1581,6 +1575,12 @@ struct llama_server_context - } - } - -+ LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); ++ slot.add_token_string(result); + -+ llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); -+ -+ slot.cache_tokens = prompt_tokens; -+ - LOG_VERBOSE("prompt ingested", { - {"n_past", slot.n_past}, - {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, + if (incomplete) + { + slot.has_next_token = true; diff --git a/llm/patches/02-cudaleaks.diff b/llm/patches/02-cudaleaks.diff new file mode 100644 index 00000000..e7bbb745 --- /dev/null +++ b/llm/patches/02-cudaleaks.diff @@ -0,0 +1,117 @@ +diff --git a/examples/server/server.cpp b/examples/server/server.cpp +index 8fe5e0b1..53bf39c1 100644 +--- a/examples/server/server.cpp ++++ b/examples/server/server.cpp +@@ -31,6 +31,10 @@ + #include + #include + ++#ifdef GGML_USE_CUBLAS ++extern "C" GGML_CALL void ggml_free_cublas(void); ++#endif ++ + using json = nlohmann::json; + + struct server_params { +@@ -363,6 +367,10 @@ struct llama_server_context + llama_free_model(model); + model = nullptr; + } ++ ++#ifdef GGML_USE_CUBLAS ++ ggml_free_cublas(); ++#endif + } + + bool load_model(const gpt_params ¶ms_) +@@ -3543,6 +3551,7 @@ int main(int argc, char **argv) + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); ++ sigaction(SIGUSR1, &sigint_action, NULL); + #elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; +diff --git a/ggml-cuda.cu b/ggml-cuda.cu +index 72bcec8c..6c934e8c 100644 +--- a/ggml-cuda.cu ++++ b/ggml-cuda.cu +@@ -43,6 +43,7 @@ + #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) + #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 + #define cublasCreate hipblasCreate ++#define cublasDestroy hipblasDestroy + #define cublasGemmEx hipblasGemmEx + #define cublasGemmBatchedEx hipblasGemmBatchedEx + #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +@@ -8751,10 +8752,10 @@ GGML_CALL bool ggml_cublas_loaded(void) { + return g_cublas_loaded; + } + +-GGML_CALL void ggml_init_cublas() { +- static bool initialized = false; ++static bool g_cublas_initialized = false; + +- if (!initialized) { ++GGML_CALL void ggml_init_cublas() { ++ if (!g_cublas_initialized) { + + #ifdef __HIP_PLATFORM_AMD__ + // Workaround for a rocBLAS bug when using multiple graphics cards: +@@ -8764,7 +8765,7 @@ GGML_CALL void ggml_init_cublas() { + #endif + + if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) { +- initialized = true; ++ g_cublas_initialized = true; + g_cublas_loaded = false; + fprintf(stderr, "%s: no " GGML_CUDA_NAME " devices found, " GGML_CUDA_NAME " will be disabled\n", __func__); + return; +@@ -8835,7 +8836,7 @@ GGML_CALL void ggml_init_cublas() { + // configure logging to stdout + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + +- initialized = true; ++ g_cublas_initialized = true; + g_cublas_loaded = true; + } + } +@@ -12490,3 +12491,23 @@ GGML_CALL int ggml_backend_cuda_reg_devices() { + } + return device_count; + } ++ ++ ++extern "C" GGML_CALL void ggml_free_cublas(void); ++GGML_CALL void ggml_free_cublas(void) { ++ for (int id = 0; id < g_device_count; ++id) { ++#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) ++ if (g_device_caps[id].vmm) { ++ CU_CHECK(cuMemUnmap(g_cuda_pool_addr[id], g_cuda_pool_size[id])); ++ g_cuda_pool_size[id] = 0; ++ g_cuda_pool_addr[id] = 0; ++ } ++#endif ++ // TODO: free legacy non-vmm memory ++ // destroy cublas handle ++ CUBLAS_CHECK(cublasDestroy(g_cublas_handles[id])); ++ g_cublas_handles[id] = nullptr; ++ } ++ ++ g_cublas_initialized = false; ++} +\ No newline at end of file +diff --git a/ggml-cuda.h b/ggml-cuda.h +index b1ebd61d..6dd58ddf 100644 +--- a/ggml-cuda.h ++++ b/ggml-cuda.h +@@ -23,6 +23,9 @@ GGML_API GGML_CALL void ggml_init_cublas(void); + // Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`. + GGML_API GGML_CALL bool ggml_cublas_loaded(void); + ++// Release CUDA resources ++GGML_API GGML_CALL void ggml_free_cublas(void); ++ + GGML_API GGML_CALL void * ggml_cuda_host_malloc(size_t size); + GGML_API GGML_CALL void ggml_cuda_host_free(void * ptr); + diff --git a/llm/patches/02-shutdown.diff b/llm/patches/02-shutdown.diff deleted file mode 100644 index 4c247cc0..00000000 --- a/llm/patches/02-shutdown.diff +++ /dev/null @@ -1,90 +0,0 @@ -diff --git a/examples/server/server.cpp b/examples/server/server.cpp -index 11dd82c3..311495a8 100644 ---- a/examples/server/server.cpp -+++ b/examples/server/server.cpp -@@ -28,6 +28,7 @@ - #include - #include - #include -+#include - - using json = nlohmann::json; - -@@ -2394,6 +2395,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con - } - } - -+std::function shutdown_handler; -+inline void signal_handler(int signal) { shutdown_handler(signal); } -+ - int main(int argc, char **argv) - { - #if SERVER_VERBOSE != 1 -@@ -3014,8 +3018,14 @@ int main(int argc, char **argv) - std::placeholders::_2, - std::placeholders::_3 - )); -- llama.queue_tasks.start_loop(); - -+ shutdown_handler = [&](int) { -+ llama.queue_tasks.terminate(); -+ }; -+ signal(SIGTERM, signal_handler); -+ signal(SIGINT, signal_handler); -+ llama.queue_tasks.start_loop(); -+ svr.stop(); - t.join(); - - llama_backend_free(); -diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp -index 70cce072..2acb1eab 100644 ---- a/examples/server/utils.hpp -+++ b/examples/server/utils.hpp -@@ -6,6 +6,7 @@ - #include - #include - #include -+#include - - #include "json.hpp" - -@@ -190,6 +191,7 @@ inline std::string format_chatml(std::vector messages) - struct llama_server_queue { - int id = 0; - std::mutex mutex_tasks; -+ std::atomic running; - // queues - std::vector queue_tasks; - std::vector queue_tasks_deferred; -@@ -248,9 +250,15 @@ struct llama_server_queue { - queue_tasks_deferred.clear(); - } - -- // Start the main loop. This call is blocking -- [[noreturn]] -+ // end the start_loop routine -+ void terminate() { -+ running = false; -+ condition_tasks.notify_all(); -+ } -+ -+ // Start the main loop. - void start_loop() { -+ running = true; - while (true) { - // new task arrived - LOG_VERBOSE("have new task", {}); -@@ -294,8 +302,12 @@ struct llama_server_queue { - { - std::unique_lock lock(mutex_tasks); - if (queue_tasks.empty()) { -+ if (!running.load()) { -+ LOG_VERBOSE("ending start_loop", {}); -+ return; -+ } - condition_tasks.wait(lock, [&]{ -- return !queue_tasks.empty(); -+ return (!queue_tasks.empty() || !running.load()); - }); - } - } diff --git a/llm/patches/03-load_exception.diff b/llm/patches/03-load_exception.diff new file mode 100644 index 00000000..9e838fa9 --- /dev/null +++ b/llm/patches/03-load_exception.diff @@ -0,0 +1,44 @@ +diff --git a/llama.cpp b/llama.cpp +index 4225f955..7b762f86 100644 +--- a/llama.cpp ++++ b/llama.cpp +@@ -4756,7 +4756,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam + } + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); +- return -1; ++ throw; + } + + return 0; +@@ -12102,16 +12102,22 @@ struct llama_model * llama_load_model_from_file( + }; + } + +- int status = llama_model_load(path_model, *model, params); +- GGML_ASSERT(status <= 0); +- if (status < 0) { +- if (status == -1) { +- LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); +- } else if (status == -2) { +- LLAMA_LOG_INFO("%s: cancelled model load\n", __func__); ++ try { ++ int status = llama_model_load(path_model, *model, params); ++ GGML_ASSERT(status <= 0); ++ if (status < 0) { ++ if (status == -1) { ++ LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); ++ } else if (status == -2) { ++ LLAMA_LOG_INFO("%s: cancelled model load\n", __func__); ++ } ++ delete model; ++ return nullptr; + } ++ } catch (...) { ++ LLAMA_LOG_ERROR("%s: exception loading model\n", __func__); + delete model; +- return nullptr; ++ throw; + } + + return model; diff --git a/llm/payload_common.go b/llm/payload_common.go index d0218979..010ba73d 100644 --- a/llm/payload_common.go +++ b/llm/payload_common.go @@ -90,6 +90,7 @@ func getDynLibs(gpuInfo gpu.GpuInfo) []string { if len(dynLibs) == 0 { dynLibs = []string{availableDynLibs["cpu"]} } + slog.Debug(fmt.Sprintf("ordered list of LLM libraries to try %v", dynLibs)) return dynLibs } @@ -102,22 +103,15 @@ func rocmDynLibPresent() bool { return false } -func nativeInit(workdir string) error { - slog.Info("Extracting dynamic libraries...") - if runtime.GOOS == "darwin" { - err := extractPayloadFiles(workdir, "llama.cpp/ggml-metal.metal") - if err != nil { - if err == payloadMissing { - // TODO perhaps consider this a hard failure on arm macs? - slog.Info("ggml-meta.metal payload missing") - return nil - } - return err - } - os.Setenv("GGML_METAL_PATH_RESOURCES", workdir) +func nativeInit() error { + payloadsDir, err := gpu.PayloadsDir() + if err != nil { + return err } - libs, err := extractDynamicLibs(workdir, "llama.cpp/build/*/*/*/lib/*") + slog.Info(fmt.Sprintf("Extracting dynamic libraries to %s ...", payloadsDir)) + + libs, err := extractDynamicLibs(payloadsDir, "llama.cpp/build/*/*/*/lib/*") if err != nil { if err == payloadMissing { slog.Info(fmt.Sprintf("%s", payloadMissing)) @@ -148,17 +142,13 @@ func nativeInit(workdir string) error { return nil } -func extractDynamicLibs(workDir, glob string) ([]string, error) { +func extractDynamicLibs(payloadsDir, glob string) ([]string, error) { files, err := fs.Glob(libEmbed, glob) if err != nil || len(files) == 0 { return nil, payloadMissing } libs := []string{} - // TODO consider making this idempotent with some sort of persistent directory (where we store models probably) - // and tracking by version so we don't reexpand the files every time - // Also maybe consider lazy loading only what is needed - g := new(errgroup.Group) for _, file := range files { pathComps := strings.Split(file, "/") @@ -171,14 +161,14 @@ func extractDynamicLibs(workDir, glob string) ([]string, error) { g.Go(func() error { // llama.cpp/build/$OS/$GOARCH/$VARIANT/lib/$LIBRARY // Include the variant in the path to avoid conflicts between multiple server libs - targetDir := filepath.Join(workDir, pathComps[pathComponentCount-3]) + targetDir := filepath.Join(payloadsDir, pathComps[pathComponentCount-3]) srcFile, err := libEmbed.Open(file) if err != nil { return fmt.Errorf("read payload %s: %v", file, err) } defer srcFile.Close() if err := os.MkdirAll(targetDir, 0o755); err != nil { - return fmt.Errorf("create payload temp dir %s: %v", workDir, err) + return fmt.Errorf("create payload lib dir %s: %v", payloadsDir, err) } src := io.Reader(srcFile) filename := file @@ -195,19 +185,13 @@ func extractDynamicLibs(workDir, glob string) ([]string, error) { libs = append(libs, destFile) } - _, err = os.Stat(destFile) - switch { - case errors.Is(err, os.ErrNotExist): - destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) - if err != nil { - return fmt.Errorf("write payload %s: %v", file, err) - } - defer destFile.Close() - if _, err := io.Copy(destFile, src); err != nil { - return fmt.Errorf("copy payload %s: %v", file, err) - } - case err != nil: - return fmt.Errorf("stat payload %s: %v", file, err) + destFp, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) + if err != nil { + return fmt.Errorf("write payload %s: %v", file, err) + } + defer destFp.Close() + if _, err := io.Copy(destFp, src); err != nil { + return fmt.Errorf("copy payload %s: %v", file, err) } return nil }) @@ -215,50 +199,6 @@ func extractDynamicLibs(workDir, glob string) ([]string, error) { return libs, g.Wait() } -func extractPayloadFiles(workDir, glob string) error { - files, err := fs.Glob(libEmbed, glob) - if err != nil || len(files) == 0 { - return payloadMissing - } - - for _, file := range files { - srcFile, err := libEmbed.Open(file) - if err != nil { - return fmt.Errorf("read payload %s: %v", file, err) - } - defer srcFile.Close() - if err := os.MkdirAll(workDir, 0o755); err != nil { - return fmt.Errorf("create payload temp dir %s: %v", workDir, err) - } - src := io.Reader(srcFile) - filename := file - if strings.HasSuffix(file, ".gz") { - src, err = gzip.NewReader(src) - if err != nil { - return fmt.Errorf("decompress payload %s: %v", file, err) - } - filename = strings.TrimSuffix(filename, ".gz") - } - - destFile := filepath.Join(workDir, filepath.Base(filename)) - _, err = os.Stat(destFile) - switch { - case errors.Is(err, os.ErrNotExist): - destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) - if err != nil { - return fmt.Errorf("write payload %s: %v", file, err) - } - defer destFile.Close() - if _, err := io.Copy(destFile, src); err != nil { - return fmt.Errorf("copy payload %s: %v", file, err) - } - case err != nil: - return fmt.Errorf("stat payload %s: %v", file, err) - } - } - return nil -} - func verifyDriverAccess() error { if runtime.GOOS != "linux" { return nil diff --git a/llm/payload_darwin_amd64.go b/llm/payload_darwin_amd64.go index a1c70ba9..dfeeb9cf 100644 --- a/llm/payload_darwin_amd64.go +++ b/llm/payload_darwin_amd64.go @@ -4,5 +4,5 @@ import ( "embed" ) -//go:embed llama.cpp/ggml-metal.metal llama.cpp/build/darwin/x86_64/*/lib/*.dylib* +//go:embed llama.cpp/build/darwin/x86_64/*/lib/*.dylib* var libEmbed embed.FS diff --git a/llm/payload_linux.go b/llm/payload_linux.go index fc366209..276705c7 100644 --- a/llm/payload_linux.go +++ b/llm/payload_linux.go @@ -4,5 +4,5 @@ import ( "embed" ) -//go:embed llama.cpp/build/linux/*/*/lib/*.so* +//go:embed llama.cpp/build/linux/*/*/lib/* var libEmbed embed.FS diff --git a/app/.eslintrc.json b/macapp/.eslintrc.json similarity index 100% rename from app/.eslintrc.json rename to macapp/.eslintrc.json diff --git a/macapp/.gitignore b/macapp/.gitignore new file mode 100644 index 00000000..8296128d --- /dev/null +++ b/macapp/.gitignore @@ -0,0 +1,92 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock +.DS_Store + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# TypeScript v1 declaration files +typings/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env +.env.test + +# parcel-bundler cache (https://parceljs.org/) +.cache + +# next.js build output +.next + +# nuxt.js build output +.nuxt + +# vuepress build output +.vuepress/dist + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# Webpack +.webpack/ + +# Vite +.vite/ + +# Electron-Forge +out/ diff --git a/macapp/README.md b/macapp/README.md new file mode 100644 index 00000000..cc34d745 --- /dev/null +++ b/macapp/README.md @@ -0,0 +1,21 @@ +# Desktop + +This app builds upon Ollama to provide a desktop experience for running models. + +## Developing + +First, build the `ollama` binary: + +``` +cd .. +go build . +``` + +Then run the desktop app with `npm start`: + +``` +cd app +npm install +npm start +``` + diff --git a/app/assets/icon.icns b/macapp/assets/icon.icns similarity index 100% rename from app/assets/icon.icns rename to macapp/assets/icon.icns diff --git a/app/assets/iconDarkTemplate.png b/macapp/assets/iconDarkTemplate.png similarity index 100% rename from app/assets/iconDarkTemplate.png rename to macapp/assets/iconDarkTemplate.png diff --git a/app/assets/iconDarkTemplate@2x.png b/macapp/assets/iconDarkTemplate@2x.png similarity index 100% rename from app/assets/iconDarkTemplate@2x.png rename to macapp/assets/iconDarkTemplate@2x.png diff --git a/app/assets/iconDarkUpdateTemplate.png b/macapp/assets/iconDarkUpdateTemplate.png similarity index 100% rename from app/assets/iconDarkUpdateTemplate.png rename to macapp/assets/iconDarkUpdateTemplate.png diff --git a/app/assets/iconDarkUpdateTemplate@2x.png b/macapp/assets/iconDarkUpdateTemplate@2x.png similarity index 100% rename from app/assets/iconDarkUpdateTemplate@2x.png rename to macapp/assets/iconDarkUpdateTemplate@2x.png diff --git a/app/assets/iconTemplate.png b/macapp/assets/iconTemplate.png similarity index 100% rename from app/assets/iconTemplate.png rename to macapp/assets/iconTemplate.png diff --git a/app/assets/iconTemplate@2x.png b/macapp/assets/iconTemplate@2x.png similarity index 100% rename from app/assets/iconTemplate@2x.png rename to macapp/assets/iconTemplate@2x.png diff --git a/app/assets/iconUpdateTemplate.png b/macapp/assets/iconUpdateTemplate.png similarity index 100% rename from app/assets/iconUpdateTemplate.png rename to macapp/assets/iconUpdateTemplate.png diff --git a/app/assets/iconUpdateTemplate@2x.png b/macapp/assets/iconUpdateTemplate@2x.png similarity index 100% rename from app/assets/iconUpdateTemplate@2x.png rename to macapp/assets/iconUpdateTemplate@2x.png diff --git a/app/forge.config.ts b/macapp/forge.config.ts similarity index 100% rename from app/forge.config.ts rename to macapp/forge.config.ts diff --git a/app/package-lock.json b/macapp/package-lock.json similarity index 100% rename from app/package-lock.json rename to macapp/package-lock.json diff --git a/app/package.json b/macapp/package.json similarity index 100% rename from app/package.json rename to macapp/package.json diff --git a/app/postcss.config.js b/macapp/postcss.config.js similarity index 100% rename from app/postcss.config.js rename to macapp/postcss.config.js diff --git a/app/src/app.css b/macapp/src/app.css similarity index 100% rename from app/src/app.css rename to macapp/src/app.css diff --git a/app/src/app.tsx b/macapp/src/app.tsx similarity index 100% rename from app/src/app.tsx rename to macapp/src/app.tsx diff --git a/app/src/declarations.d.ts b/macapp/src/declarations.d.ts similarity index 100% rename from app/src/declarations.d.ts rename to macapp/src/declarations.d.ts diff --git a/app/src/index.html b/macapp/src/index.html similarity index 100% rename from app/src/index.html rename to macapp/src/index.html diff --git a/app/src/index.ts b/macapp/src/index.ts similarity index 100% rename from app/src/index.ts rename to macapp/src/index.ts diff --git a/app/src/install.ts b/macapp/src/install.ts similarity index 100% rename from app/src/install.ts rename to macapp/src/install.ts diff --git a/app/src/ollama.svg b/macapp/src/ollama.svg similarity index 100% rename from app/src/ollama.svg rename to macapp/src/ollama.svg diff --git a/app/src/preload.ts b/macapp/src/preload.ts similarity index 100% rename from app/src/preload.ts rename to macapp/src/preload.ts diff --git a/app/src/renderer.tsx b/macapp/src/renderer.tsx similarity index 100% rename from app/src/renderer.tsx rename to macapp/src/renderer.tsx diff --git a/app/tailwind.config.js b/macapp/tailwind.config.js similarity index 100% rename from app/tailwind.config.js rename to macapp/tailwind.config.js diff --git a/app/tsconfig.json b/macapp/tsconfig.json similarity index 100% rename from app/tsconfig.json rename to macapp/tsconfig.json diff --git a/app/webpack.main.config.ts b/macapp/webpack.main.config.ts similarity index 100% rename from app/webpack.main.config.ts rename to macapp/webpack.main.config.ts diff --git a/app/webpack.plugins.ts b/macapp/webpack.plugins.ts similarity index 100% rename from app/webpack.plugins.ts rename to macapp/webpack.plugins.ts diff --git a/app/webpack.renderer.config.ts b/macapp/webpack.renderer.config.ts similarity index 100% rename from app/webpack.renderer.config.ts rename to macapp/webpack.renderer.config.ts diff --git a/app/webpack.rules.ts b/macapp/webpack.rules.ts similarity index 100% rename from app/webpack.rules.ts rename to macapp/webpack.rules.ts diff --git a/openai/openai.go b/openai/openai.go new file mode 100644 index 00000000..4f495569 --- /dev/null +++ b/openai/openai.go @@ -0,0 +1,322 @@ +// openai package provides middleware for partial compatibility with the OpenAI REST API +package openai + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/jmorganca/ollama/api" +) + +type Error struct { + Message string `json:"message"` + Type string `json:"type"` + Param interface{} `json:"param"` + Code *string `json:"code"` +} + +type ErrorResponse struct { + Error Error `json:"error"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type Choice struct { + Index int `json:"index"` + Message Message `json:"message"` + FinishReason *string `json:"finish_reason"` +} + +type ChunkChoice struct { + Index int `json:"index"` + Delta Message `json:"delta"` + FinishReason *string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type ResponseFormat struct { + Type string `json:"type"` +} + +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream bool `json:"stream"` + MaxTokens *int `json:"max_tokens"` + Seed *int `json:"seed"` + Stop any `json:"stop"` + Temperature *float64 `json:"temperature"` + FrequencyPenalty *float64 `json:"frequency_penalty"` + PresencePenalty *float64 `json:"presence_penalty_penalty"` + TopP *float64 `json:"top_p"` + ResponseFormat *ResponseFormat `json:"response_format"` +} + +type ChatCompletion struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage,omitempty"` +} + +type ChatCompletionChunk struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []ChunkChoice `json:"choices"` +} + +func NewError(code int, message string) ErrorResponse { + var etype string + switch code { + case http.StatusBadRequest: + etype = "invalid_request_error" + case http.StatusNotFound: + etype = "not_found_error" + default: + etype = "api_error" + } + + return ErrorResponse{Error{Type: etype, Message: message}} +} + +func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { + return ChatCompletion{ + Id: id, + Object: "chat.completion", + Created: r.CreatedAt.Unix(), + Model: r.Model, + SystemFingerprint: "fp_ollama", + Choices: []Choice{{ + Index: 0, + Message: Message{Role: r.Message.Role, Content: r.Message.Content}, + FinishReason: func(done bool) *string { + if done { + reason := "stop" + return &reason + } + return nil + }(r.Done), + }}, + Usage: Usage{ + // TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count + PromptTokens: r.PromptEvalCount, + CompletionTokens: r.EvalCount, + TotalTokens: r.PromptEvalCount + r.EvalCount, + }, + } +} + +func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { + return ChatCompletionChunk{ + Id: id, + Object: "chat.completion.chunk", + Created: time.Now().Unix(), + Model: r.Model, + SystemFingerprint: "fp_ollama", + Choices: []ChunkChoice{ + { + Index: 0, + Delta: Message{Role: "assistant", Content: r.Message.Content}, + FinishReason: func(done bool) *string { + if done { + reason := "stop" + return &reason + } + return nil + }(r.Done), + }, + }, + } +} + +func fromRequest(r ChatCompletionRequest) api.ChatRequest { + var messages []api.Message + for _, msg := range r.Messages { + messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content}) + } + + options := make(map[string]interface{}) + + switch stop := r.Stop.(type) { + case string: + options["stop"] = []string{stop} + case []interface{}: + var stops []string + for _, s := range stop { + if str, ok := s.(string); ok { + stops = append(stops, str) + } + } + options["stop"] = stops + } + + if r.MaxTokens != nil { + options["num_predict"] = *r.MaxTokens + } + + if r.Temperature != nil { + options["temperature"] = *r.Temperature * 2.0 + } else { + options["temperature"] = 1.0 + } + + if r.Seed != nil { + options["seed"] = *r.Seed + + // temperature=0 is required for reproducible outputs + options["temperature"] = 0.0 + } + + if r.FrequencyPenalty != nil { + options["frequency_penalty"] = *r.FrequencyPenalty * 2.0 + } + + if r.PresencePenalty != nil { + options["presence_penalty"] = *r.PresencePenalty * 2.0 + } + + if r.TopP != nil { + options["top_p"] = *r.TopP + } else { + options["top_p"] = 1.0 + } + + var format string + if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" { + format = "json" + } + + return api.ChatRequest{ + Model: r.Model, + Messages: messages, + Format: format, + Options: options, + Stream: &r.Stream, + } +} + +type writer struct { + stream bool + id string + gin.ResponseWriter +} + +func (w *writer) writeError(code int, data []byte) (int, error) { + var serr api.StatusError + err := json.Unmarshal(data, &serr) + if err != nil { + return 0, err + } + + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error())) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *writer) writeResponse(data []byte) (int, error) { + var chatResponse api.ChatResponse + err := json.Unmarshal(data, &chatResponse) + if err != nil { + return 0, err + } + + // chat chunk + if w.stream { + d, err := json.Marshal(toChunk(w.id, chatResponse)) + if err != nil { + return 0, err + + } + + w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) + if err != nil { + return 0, err + } + + if chatResponse.Done { + _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) + if err != nil { + return 0, err + } + } + + return len(data), nil + } + + // chat completion + w.ResponseWriter.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse)) + if err != nil { + return 0, err + } + + return len(data), nil +} + +func (w *writer) Write(data []byte) (int, error) { + code := w.ResponseWriter.Status() + if code != http.StatusOK { + return w.writeError(code, data) + } + + return w.writeResponse(data) +} + +func Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + var req ChatCompletionRequest + err := c.ShouldBindJSON(&req) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) + return + } + + if len(req.Messages) == 0 { + c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'")) + return + } + + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) + return + } + + c.Request.Body = io.NopCloser(&b) + + w := &writer{ + ResponseWriter: c.Writer, + stream: req.Stream, + id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), + } + + c.Writer = w + + c.Next() + } +} diff --git a/readline/buffer.go b/readline/buffer.go index 7ccb1916..52e8a56c 100644 --- a/readline/buffer.go +++ b/readline/buffer.go @@ -19,10 +19,9 @@ type Buffer struct { func NewBuffer(prompt *Prompt) (*Buffer, error) { fd := int(os.Stdout.Fd()) - width, height, err := term.GetSize(fd) - if err != nil { - fmt.Println("Error getting size:", err) - return nil, err + width, height := 80, 24 + if termWidth, termHeight, err := term.GetSize(fd); err == nil { + width, height = termWidth, termHeight } lwidth := width - len(prompt.prompt()) diff --git a/readline/readline.go b/readline/readline.go index 202d9fa7..8ba7d89c 100644 --- a/readline/readline.go +++ b/readline/readline.go @@ -32,6 +32,8 @@ func (p *Prompt) placeholder() string { type Terminal struct { outchan chan rune + rawmode bool + termios any } type Instance struct { @@ -60,6 +62,16 @@ func New(prompt Prompt) (*Instance, error) { } func (i *Instance) Readline() (string, error) { + if !i.Terminal.rawmode { + fd := int(syscall.Stdin) + termios, err := SetRawMode(fd) + if err != nil { + return "", err + } + i.Terminal.rawmode = true + i.Terminal.termios = termios + } + prompt := i.Prompt.prompt() if i.Pasting { // force alt prompt when pasting @@ -67,13 +79,12 @@ func (i *Instance) Readline() (string, error) { } fmt.Print(prompt) - fd := int(syscall.Stdin) - termios, err := SetRawMode(fd) - if err != nil { - return "", err - } - // nolint: errcheck - defer UnsetRawMode(fd, termios) + defer func() { + fd := int(syscall.Stdin) + // nolint: errcheck + UnsetRawMode(fd, i.Terminal.termios) + i.Terminal.rawmode = false + }() buf, _ := NewBuffer(i.Prompt) @@ -205,7 +216,8 @@ func (i *Instance) Readline() (string, error) { case CharCtrlW: buf.DeleteWord() case CharCtrlZ: - return handleCharCtrlZ(fd, termios) + fd := int(syscall.Stdin) + return handleCharCtrlZ(fd, i.Terminal.termios) case CharEnter: output := buf.String() if output != "" { @@ -236,8 +248,16 @@ func (i *Instance) HistoryDisable() { } func NewTerminal() (*Terminal, error) { + fd := int(syscall.Stdin) + termios, err := SetRawMode(fd) + if err != nil { + return nil, err + } + t := &Terminal{ outchan: make(chan rune), + rawmode: true, + termios: termios, } go t.ioloop() diff --git a/readline/readline_unix.go b/readline/readline_unix.go index 73930c3d..76cff8c8 100644 --- a/readline/readline_unix.go +++ b/readline/readline_unix.go @@ -6,8 +6,9 @@ import ( "syscall" ) -func handleCharCtrlZ(fd int, termios *Termios) (string, error) { - if err := UnsetRawMode(fd, termios); err != nil { +func handleCharCtrlZ(fd int, termios any) (string, error) { + t := termios.(*Termios) + if err := UnsetRawMode(fd, t); err != nil { return "", err } diff --git a/readline/readline_windows.go b/readline/readline_windows.go index c8178903..b4e96b25 100644 --- a/readline/readline_windows.go +++ b/readline/readline_windows.go @@ -1,6 +1,6 @@ package readline -func handleCharCtrlZ(fd int, state *State) (string, error) { +func handleCharCtrlZ(fd int, state any) (string, error) { // not supported return "", nil } diff --git a/readline/term.go b/readline/term.go index 45757e6a..9d747162 100644 --- a/readline/term.go +++ b/readline/term.go @@ -25,8 +25,9 @@ func SetRawMode(fd int) (*Termios, error) { return termios, setTermios(fd, &newTermios) } -func UnsetRawMode(fd int, termios *Termios) error { - return setTermios(fd, termios) +func UnsetRawMode(fd int, termios any) error { + t := termios.(*Termios) + return setTermios(fd, t) } // IsTerminal returns true if the given file descriptor is a terminal. diff --git a/readline/term_windows.go b/readline/term_windows.go index 3d1c80e1..a40fbaa3 100644 --- a/readline/term_windows.go +++ b/readline/term_windows.go @@ -56,7 +56,8 @@ func SetRawMode(fd int) (*State, error) { return &State{st}, nil } -func UnsetRawMode(fd int, state *State) error { - _, _, err := syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(state.mode), 0) +func UnsetRawMode(fd int, state any) error { + s := state.(*State) + _, _, err := syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(s.mode), 0) return err } diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index 381bcba5..cb561123 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -10,8 +10,8 @@ mkdir -p dist for TARGETARCH in arm64 amd64; do rm -rf llm/llama.cpp/build GOOS=darwin GOARCH=$TARGETARCH go generate ./... - CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -o dist/ollama-darwin-$TARGETARCH - CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -cover -o dist/ollama-darwin-$TARGETARCH-cov + CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -trimpath -o dist/ollama-darwin-$TARGETARCH + CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -trimpath -cover -o dist/ollama-darwin-$TARGETARCH-cov done lipo -create -output dist/ollama dist/ollama-darwin-arm64 dist/ollama-darwin-amd64 @@ -24,13 +24,13 @@ fi chmod +x dist/ollama # build and optionally sign the mac app -npm install --prefix app +npm install --prefix macapp if [ -n "$APPLE_IDENTITY" ]; then - npm run --prefix app make:sign + npm run --prefix macapp make:sign else - npm run --prefix app make + npm run --prefix macapp make fi -cp app/out/make/zip/darwin/universal/Ollama-darwin-universal-$VERSION.zip dist/Ollama-darwin.zip +cp macapp/out/make/zip/darwin/universal/Ollama-darwin-universal-$VERSION.zip dist/Ollama-darwin.zip # sign the binary and rename it if [ -n "$APPLE_IDENTITY" ]; then diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 40054ca6..a3aa4264 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -5,13 +5,15 @@ set -eu export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")} export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'" +IMAGE_NAME=${IMAGE_NAME:-"ollama/ollama"} +BUILD_PLATFORM=${BUILD_PLATFORM:-"linux/arm64,linux/amd64"} docker build \ --load \ - --platform=linux/arm64,linux/amd64 \ + --platform=${BUILD_PLATFORM} \ --build-arg=VERSION \ --build-arg=GOFLAGS \ -f Dockerfile \ - -t ollama/ollama:$VERSION \ + -t ${IMAGE_NAME}:$VERSION \ . docker build \ @@ -21,5 +23,12 @@ docker build \ --build-arg=GOFLAGS \ --target runtime-rocm \ -f Dockerfile \ - -t ollama/ollama:$VERSION-rocm \ + -t ${IMAGE_NAME}:$VERSION-rocm \ . + +docker tag ${IMAGE_NAME}:$VERSION ${IMAGE_NAME}:latest +docker tag ${IMAGE_NAME}:$VERSION-rocm ${IMAGE_NAME}:rocm + +echo "To release, run:" +echo " docker push ${IMAGE_NAME}:$VERSION && docker push ${IMAGE_NAME}:latest" +echo " docker push ${IMAGE_NAME}:$VERSION-rocm && docker push ${IMAGE_NAME}:rocm" \ No newline at end of file diff --git a/scripts/build_linux.sh b/scripts/build_linux.sh index 338dbcd5..e6db485a 100755 --- a/scripts/build_linux.sh +++ b/scripts/build_linux.sh @@ -22,5 +22,10 @@ for TARGETARCH in ${BUILD_ARCH}; do . docker create --platform linux/$TARGETARCH --name builder-$TARGETARCH builder:$TARGETARCH docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$TARGETARCH + + if [ "$TARGETARCH" = "amd64" ]; then + docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/dist/deps/ ./dist/ + fi + docker rm builder-$TARGETARCH done diff --git a/scripts/build_remote.py b/scripts/build_remote.py index 314232ac..2ab58ad7 100755 --- a/scripts/build_remote.py +++ b/scripts/build_remote.py @@ -60,13 +60,17 @@ subprocess.check_call(['ssh', netloc, 'cd', path, ';', 'git', 'checkout', branch # subprocess.check_call(['ssh', netloc, 'cd', path, ';', 'env']) # TODO - or consider paramiko maybe -print("Performing generate") -subprocess.check_call(['ssh', netloc, 'cd', path, ';', GoCmd, 'generate', './...']) +print("Running Windows Build Script") +subprocess.check_call(['ssh', netloc, 'cd', path, ';', "powershell", "-ExecutionPolicy", "Bypass", "-File", "./scripts/build_windows.ps1"]) -print("Building") -subprocess.check_call(['ssh', netloc, 'cd', path, ';', GoCmd, 'build', '.']) +# print("Building") +# subprocess.check_call(['ssh', netloc, 'cd', path, ';', GoCmd, 'build', '.']) print("Copying built result") subprocess.check_call(['scp', netloc +":"+ path + "/ollama.exe", './dist/']) +print("Copying installer") +subprocess.check_call(['scp', netloc +":"+ path + "/dist/Ollama Setup.exe", './dist/']) + + diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 new file mode 100644 index 00000000..ee184799 --- /dev/null +++ b/scripts/build_windows.ps1 @@ -0,0 +1,133 @@ +#!powershell +# +# powershell -ExecutionPolicy Bypass -File .\scripts\build_windows.ps1 +# +# gcloud auth application-default login + +$ErrorActionPreference = "Stop" + +function checkEnv() { + write-host "Locating required tools and paths" + $script:SRC_DIR=$PWD + if (!$env:VCToolsRedistDir) { + $MSVC_INSTALL=(Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation + $env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0] + } + $script:NVIDIA_DIR=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\")[0] + $script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0] + + $script:DEPS_DIR="${script:SRC_DIR}\dist\windeps" + $env:CGO_ENABLED="1" + echo "Checking version" + if (!$env:VERSION) { + $data=(git describe --tags --first-parent --abbrev=7 --long --dirty --always) + $pattern="v(.+)" + if ($data -match $pattern) { + $script:VERSION=$matches[1] + } + } else { + $script:VERSION=$env:VERSION + } + $pattern = "(\d+[.]\d+[.]\d+)-(\d+)-" + if ($script:VERSION -match $pattern) { + $script:PKG_VERSION=$matches[1] + "." + $matches[2] + } else { + $script:PKG_VERSION=$script:VERSION + } + write-host "Building Ollama $script:VERSION with package version $script:PKG_VERSION" + + # Check for signing key + if ("${env:KEY_CONTAINER}") { + ${script:OLLAMA_CERT}=$(resolve-path "${script:SRC_DIR}\ollama_inc.crt") + Write-host "Code signing enabled" + # Note: 10 Windows Kit signtool crashes with GCP's plugin + ${script:SignTool}="C:\Program Files (x86)\Windows Kits\8.1\bin\x64\signtool.exe" + } else { + write-host "Code signing disabled - please set KEY_CONTAINERS to sign and copy ollama_inc.crt to the top of the source tree" + } + +} + + +function buildOllama() { + write-host "Building ollama CLI" + & go generate ./... + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & go build -trimpath -ldflags "-s -w -X=github.com/jmorganca/ollama/version.Version=$script:VERSION -X=github.com/jmorganca/ollama/server.mode=release" . + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + if ("${env:KEY_CONTAINER}") { + & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` + /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + New-Item -ItemType Directory -Path .\dist -Force + cp .\ollama.exe .\dist\ollama-windows-amd64.exe +} + +function buildApp() { + write-host "Building Ollama App" + cd "${script:SRC_DIR}\app" + & windres -l 0 -o ollama.syso ollama.rc + & go build -trimpath -ldflags "-s -w -H windowsgui -X=github.com/jmorganca/ollama/version.Version=$script:VERSION -X=github.com/jmorganca/ollama/server.mode=release" . + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + if ("${env:KEY_CONTAINER}") { + & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` + /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} app.exe + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } +} + +function gatherDependencies() { + write-host "Gathering runtime dependencies" + cd "${script:SRC_DIR}" + rm -ea 0 -recurse -force -path "${script:DEPS_DIR}" + md "${script:DEPS_DIR}" -ea 0 > $null + + # TODO - this varies based on host build system and MSVC version - drive from dumpbin output + # currently works for Win11 + MSVC 2019 + Cuda V11 + cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\" + cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\" + cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\" + + cp "${script:NVIDIA_DIR}\cudart64_*.dll" "${script:DEPS_DIR}\" + cp "${script:NVIDIA_DIR}\cublas64_*.dll" "${script:DEPS_DIR}\" + cp "${script:NVIDIA_DIR}\cublasLt64_*.dll" "${script:DEPS_DIR}\" + + cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\" + if ("${env:KEY_CONTAINER}") { + write-host "about to sign" + foreach ($file in (get-childitem "${script:DEPS_DIR}/cu*.dll") + @("${script:SRC_DIR}\dist\ollama_welcome.ps1")){ + write-host "signing $file" + & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" ` + /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} $file + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + } + +} + +function buildInstaller() { + write-host "Building Ollama Installer" + cd "${script:SRC_DIR}\app" + $env:PKG_VERSION=$script:PKG_VERSION + if ("${env:KEY_CONTAINER}") { + & "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss + } else { + & "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss + } + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} +} + +try { + checkEnv + buildOllama + buildApp + gatherDependencies + buildInstaller +} catch { + write-host "Build Failed" + write-host $_ +} finally { + set-location $script:SRC_DIR + $env:PKG_VERSION="" +} diff --git a/scripts/install.sh b/scripts/install.sh index e9e2ebf2..f8f1d6f4 100644 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -61,7 +61,7 @@ if [ -n "$NEEDS" ]; then fi status "Downloading ollama..." -curl --fail --show-error --location --progress-bar -o $TEMP_DIR/ollama "https://ollama.ai/download/ollama-linux-$ARCH" +curl --fail --show-error --location --progress-bar -o $TEMP_DIR/ollama "https://ollama.com/download/ollama-linux-$ARCH" for BINDIR in /usr/local/bin /usr/bin /bin; do echo $PATH | grep -q $BINDIR && break || continue @@ -72,7 +72,7 @@ $SUDO install -o0 -g0 -m755 -d $BINDIR $SUDO install -o0 -g0 -m755 $TEMP_DIR/ollama $BINDIR/ollama install_success() { - status 'The Ollama API is now available at 0.0.0.0:11434.' + status 'The Ollama API is now available at 127.0.0.1:11434.' status 'Install complete. Run "ollama" from the command line.' } trap install_success EXIT @@ -88,6 +88,10 @@ configure_systemd() { status "Adding ollama user to render group..." $SUDO usermod -a -G render ollama fi + if getent group video >/dev/null 2>&1; then + status "Adding ollama user to video group..." + $SUDO usermod -a -G video ollama + fi status "Adding current user to ollama group..." $SUDO usermod -a -G ollama $(whoami) diff --git a/server/auth.go b/server/auth.go index 8faf1e4b..5af85ff6 100644 --- a/server/auth.go +++ b/server/auth.go @@ -1,7 +1,6 @@ package server import ( - "bytes" "context" "crypto/rand" "crypto/sha256" @@ -10,161 +9,87 @@ import ( "encoding/json" "fmt" "io" - "log/slog" "net/http" "net/url" - "os" - "path/filepath" "strconv" "strings" "time" - "golang.org/x/crypto/ssh" - "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/auth" ) -type AuthRedirect struct { +type registryChallenge struct { Realm string Service string Scope string } -type SignatureData struct { - Method string - Path string - Data []byte -} - -func generateNonce(length int) (string, error) { - nonce := make([]byte, length) - _, err := rand.Read(nonce) - if err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(nonce), nil -} - -func (r AuthRedirect) URL() (*url.URL, error) { +func (r registryChallenge) URL() (*url.URL, error) { redirectURL, err := url.Parse(r.Realm) if err != nil { return nil, err } values := redirectURL.Query() - values.Add("service", r.Service) - for _, s := range strings.Split(r.Scope, " ") { values.Add("scope", s) } values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10)) - nonce, err := generateNonce(16) + nonce, err := auth.NewNonce(rand.Reader, 16) if err != nil { return nil, err } + values.Add("nonce", nonce) redirectURL.RawQuery = values.Encode() return redirectURL, nil } -func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) { - redirectURL, err := redirData.URL() +func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) { + redirectURL, err := challenge.URL() if err != nil { return "", err } - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - - keyPath := filepath.Join(home, ".ollama", "id_ed25519") - - rawKey, err := os.ReadFile(keyPath) - if err != nil { - slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) - return "", err - } - - s := SignatureData{ - Method: http.MethodGet, - Path: redirectURL.String(), - Data: nil, - } - - sig, err := s.Sign(rawKey) - if err != nil { - return "", err - } + sha256sum := sha256.Sum256(nil) + data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))))) headers := make(http.Header) - headers.Set("Authorization", sig) - resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil) - if err != nil { - slog.Info(fmt.Sprintf("couldn't get token: %q", err)) - return "", err - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - body, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body) - } - - respBody, err := io.ReadAll(resp.Body) + signature, err := auth.Sign(ctx, data) if err != nil { return "", err } - var tok api.TokenResponse - if err := json.Unmarshal(respBody, &tok); err != nil { + headers.Add("Authorization", signature) + + response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil) + if err != nil { + return "", err + } + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return "", fmt.Errorf("%d: %v", response.StatusCode, err) + } + + if response.StatusCode >= http.StatusBadRequest { + if len(body) > 0 { + return "", fmt.Errorf("%d: %s", response.StatusCode, body) + } else { + return "", fmt.Errorf("%d", response.StatusCode) + } + } + + var token api.TokenResponse + if err := json.Unmarshal(body, &token); err != nil { return "", err } - return tok.Token, nil -} - -// Bytes returns a byte slice of the data to sign for the request -func (s SignatureData) Bytes() []byte { - // We first derive the content hash of the request body using: - // base64(hex(sha256(request body))) - - hash := sha256.Sum256(s.Data) - hashHex := make([]byte, hex.EncodedLen(len(hash))) - hex.Encode(hashHex, hash[:]) - contentHash := base64.StdEncoding.EncodeToString(hashHex) - - // We then put the entire request together in a serialize string using: - // ",," - // e.g. "GET,http://localhost,OTdkZjM1O..." - - return []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ",")) -} - -// SignData takes a SignatureData object and signs it with a raw private key -func (s SignatureData) Sign(rawKey []byte) (string, error) { - signer, err := ssh.ParsePrivateKey(rawKey) - if err != nil { - return "", err - } - - // get the pubkey, but remove the type - pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey()) - parts := bytes.Split(pubKey, []byte(" ")) - if len(parts) < 2 { - return "", fmt.Errorf("malformed public key") - } - - signedData, err := signer.Sign(nil, s.Bytes()) - if err != nil { - return "", err - } - - // signature is : - sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)) - return sig, nil + return token.Token, nil } diff --git a/server/download.go b/server/download.go index f089bd41..f6d199b9 100644 --- a/server/download.go +++ b/server/download.go @@ -85,7 +85,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) { return n, nil } -func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { +func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { partFilePaths, err := filepath.Glob(b.Name + "-partial-*") if err != nil { return err @@ -137,11 +137,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R return nil } -func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) { +func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) { b.err = b.run(ctx, requestURL, opts) } -func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { +func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { defer blobDownloadManager.Delete(b.Digest) ctx, b.CancelFunc = context.WithCancel(ctx) @@ -210,7 +210,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis return nil } -func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { +func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { headers := make(http.Header) @@ -334,7 +334,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) type downloadOpts struct { mp ModelPath digest string - regOpts *RegistryOptions + regOpts *registryOptions fn func(api.ProgressResponse) } diff --git a/server/images.go b/server/images.go index 503dd8e2..06f8ffd9 100644 --- a/server/images.go +++ b/server/images.go @@ -1,6 +1,7 @@ package server import ( + "archive/zip" "bytes" "context" "crypto/sha256" @@ -9,6 +10,7 @@ import ( "errors" "fmt" "io" + "io/fs" "log" "log/slog" "net/http" @@ -19,17 +21,17 @@ import ( "strconv" "strings" "text/template" - "text/template/parse" "golang.org/x/exp/slices" "github.com/jmorganca/ollama/api" + "github.com/jmorganca/ollama/convert" "github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/version" ) -type RegistryOptions struct { +type registryOptions struct { Insecure bool Username string Password string @@ -53,155 +55,15 @@ type Model struct { Messages []Message } +func (m *Model) IsEmbedding() bool { + return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") +} + type Message struct { Role string `json:"role"` Content string `json:"content"` } -type PromptVars struct { - System string - Prompt string - Response string - First bool -} - -// extractParts extracts the parts of the template before and after the {{.Response}} node. -func extractParts(tmplStr string) (pre string, post string, err error) { - tmpl, err := template.New("").Parse(tmplStr) - if err != nil { - return "", "", err - } - - var foundResponse bool - - for _, node := range tmpl.Tree.Root.Nodes { - if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" { - foundResponse = true - } - if !foundResponse { - pre += node.String() - } else { - post += node.String() - } - } - - return pre, post, nil -} - -func Prompt(promptTemplate string, p PromptVars) (string, error) { - var prompt strings.Builder - // Use the "missingkey=zero" option to handle missing variables without panicking - tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate) - if err != nil { - return "", err - } - - vars := map[string]any{ - "System": p.System, - "Prompt": p.Prompt, - "Response": p.Response, - "First": p.First, - } - - var sb strings.Builder - if err := tmpl.Execute(&sb, vars); err != nil { - return "", err - } - prompt.WriteString(sb.String()) - - if !strings.Contains(prompt.String(), p.Response) { - // if the response is not in the prompt template, append it to the end - prompt.WriteString(p.Response) - } - - return prompt.String(), nil -} - -// PreResponsePrompt returns the prompt before the response tag -func (m *Model) PreResponsePrompt(p PromptVars) (string, error) { - pre, _, err := extractParts(m.Template) - if err != nil { - return "", err - } - - return Prompt(pre, p) -} - -// PostResponseTemplate returns the template after the response tag -func (m *Model) PostResponseTemplate(p PromptVars) (string, error) { - if p.System == "" { - // use the default system prompt for this model if one is not specified - p.System = m.System - } - _, post, err := extractParts(m.Template) - if err != nil { - return "", err - } - - if post == "" { - // if there is no post-response template, return the provided response - return p.Response, nil - } - - return Prompt(post, p) -} - -type ChatHistory struct { - Prompts []PromptVars - CurrentImages []api.ImageData - LastSystem string -} - -// ChatPrompts returns a list of formatted chat prompts from a list of messages -func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) { - // build the prompt from the list of messages - var currentImages []api.ImageData - lastSystem := m.System - currentVars := PromptVars{ - First: true, - System: m.System, - } - - prompts := []PromptVars{} - - for _, msg := range msgs { - switch strings.ToLower(msg.Role) { - case "system": - // if this is the first message it overrides the system prompt in the modelfile - if !currentVars.First && currentVars.System != "" { - prompts = append(prompts, currentVars) - currentVars = PromptVars{} - } - currentVars.System = msg.Content - lastSystem = msg.Content - case "user": - if currentVars.Prompt != "" { - prompts = append(prompts, currentVars) - currentVars = PromptVars{} - } - currentVars.Prompt = msg.Content - currentImages = msg.Images - case "assistant": - currentVars.Response = msg.Content - prompts = append(prompts, currentVars) - currentVars = PromptVars{} - default: - return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) - } - } - - // Append the last set of vars if they are non-empty - if currentVars.Prompt != "" || currentVars.System != "" { - prompts = append(prompts, currentVars) - } - - return &ChatHistory{ - Prompts: prompts, - CurrentImages: currentImages, - LastSystem: lastSystem, - }, nil -} - type ManifestV2 struct { SchemaVersion int `json:"schemaVersion"` MediaType string `json:"mediaType"` @@ -457,7 +319,27 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars c.Args = blobPath } - bin, err := os.Open(realpath(modelFileDir, c.Args)) + pathName := realpath(modelFileDir, c.Args) + + ggufName, err := convertSafetensors(name, pathName) + if err != nil { + var pathErr *fs.PathError + switch { + case errors.Is(err, zip.ErrFormat): + // it's not a safetensor archive + case errors.As(err, &pathErr): + // it's not a file on disk, could be a model reference + default: + return err + } + } + + if ggufName != "" { + pathName = ggufName + defer os.RemoveAll(ggufName) + } + + bin, err := os.Open(pathName) if err != nil { // not a file on disk so must be a model reference modelpath := ParseModelPath(c.Args) @@ -465,7 +347,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars switch { case errors.Is(err, os.ErrNotExist): fn(api.ProgressResponse{Status: "pulling model"}) - if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { + if err := PullModel(ctx, c.Args, ®istryOptions{}, fn); err != nil { return err } @@ -591,7 +473,13 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars } defer bin.Close() - layer, err := NewLayer(bin, mediatype) + ggml, err := llm.DecodeGGML(bin) + if err != nil { + return err + } + + sr := io.NewSectionReader(bin, 0, ggml.Size) + layer, err := NewLayer(sr, mediatype) if err != nil { return err } @@ -733,6 +621,73 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars return nil } +func convertSafetensors(name, fn string) (string, error) { + r, err := zip.OpenReader(fn) + if err != nil { + return "", err + } + defer r.Close() + + tempDir, err := os.MkdirTemp("", "ollama-convert") + if err != nil { + return "", err + } + defer os.RemoveAll(tempDir) + + for _, f := range r.File { + fpath := filepath.Join(tempDir, f.Name) + outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) + if err != nil { + return "", err + } + + rc, err := f.Open() + if err != nil { + return "", err + } + + _, err = io.Copy(outFile, rc) + if err != nil { + return "", err + } + + outFile.Close() + rc.Close() + } + + params, err := convert.GetParams(tempDir) + if err != nil { + return "", err + } + + SupportedArchs := []string{ + "MistralForCausalLM", + } + + for _, arch := range params.Architectures { + if !slices.Contains(SupportedArchs, arch) { + return "", fmt.Errorf("this safetensors model is not yet supported") + } + } + + t, err := convert.GetSafeTensors(tempDir) + if err != nil { + return "", err + } + + vocab, err := convert.LoadTokens(tempDir) + if err != nil { + return "", err + } + + fn, err = convert.WriteGGUF(name, t, params, vocab) + if err != nil { + return "", err + } + + return fn, nil +} + func CopyModel(src, dest string) error { srcModelPath := ParseModelPath(src) srcPath, err := srcModelPath.GetManifestPath() @@ -985,7 +940,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }} return buf.String(), nil } -func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) fn(api.ProgressResponse{Status: "retrieving manifest"}) @@ -1035,7 +990,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return nil } -func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error { +func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { mp := ParseModelPath(name) var manifest *ManifestV2 @@ -1141,7 +1096,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu return nil } -func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) { +func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) { requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) headers := make(http.Header) @@ -1173,7 +1128,7 @@ func GetSHA256Digest(r io.Reader) (string, int64) { var errUnauthorized = fmt.Errorf("unauthorized") -func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) { +func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { for i := 0; i < 2; i++ { resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) if err != nil { @@ -1187,9 +1142,8 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR switch { case resp.StatusCode == http.StatusUnauthorized: // Handle authentication error with one retry - auth := resp.Header.Get("www-authenticate") - authRedir := ParseAuthRedirectString(auth) - token, err := getAuthToken(ctx, authRedir) + challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate")) + token, err := getAuthorizationToken(ctx, challenge) if err != nil { return nil, err } @@ -1216,7 +1170,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR return nil, errUnauthorized } -func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) { +func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) { if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure { requestURL.Scheme = "http" } @@ -1249,18 +1203,7 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header req.ContentLength = contentLength } - proxyURL, err := http.ProxyFromEnvironment(req) - if err != nil { - return nil, err - } - - client := http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - }, - } - - resp, err := client.Do(req) + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } @@ -1291,10 +1234,10 @@ func getValue(header, key string) string { return header[startIdx:endIdx] } -func ParseAuthRedirectString(authStr string) AuthRedirect { +func parseRegistryChallenge(authStr string) registryChallenge { authStr = strings.TrimPrefix(authStr, "Bearer ") - return AuthRedirect{ + return registryChallenge{ Realm: getValue(authStr, "realm"), Service: getValue(authStr, "service"), Scope: getValue(authStr, "scope"), diff --git a/server/images_test.go b/server/images_test.go deleted file mode 100644 index 0f63a19b..00000000 --- a/server/images_test.go +++ /dev/null @@ -1,423 +0,0 @@ -package server - -import ( - "bytes" - "strings" - "testing" - - "github.com/jmorganca/ollama/api" -) - -func TestPrompt(t *testing.T) { - tests := []struct { - name string - template string - vars PromptVars - want string - wantErr bool - }{ - { - name: "System Prompt", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - vars: PromptVars{ - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", - }, - { - name: "System Prompt with Response", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - vars: PromptVars{ - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "I don't know.", - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.", - }, - { - name: "Conditional Logic Nodes", - template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - vars: PromptVars{ - First: true, - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "I don't know.", - }, - want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := Prompt(tt.template, tt.vars) - if (err != nil) != tt.wantErr { - t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("Prompt() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestModel_PreResponsePrompt(t *testing.T) { - tests := []struct { - name string - template string - vars PromptVars - want string - wantErr bool - }{ - { - name: "No Response in Template", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - vars: PromptVars{ - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", - }, - { - name: "Response in Template", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - vars: PromptVars{ - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ", - }, - { - name: "Response in Template with Trailing Formatting", - template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>", - vars: PromptVars{ - Prompt: "What are the potion ingredients?", - }, - want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n", - }, - { - name: "Response in Template with Alternative Formatting", - template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", - vars: PromptVars{ - Prompt: "What are the potion ingredients?", - }, - want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n", - }, - } - - for _, tt := range tests { - m := Model{Template: tt.template} - t.Run(tt.name, func(t *testing.T) { - got, err := m.PreResponsePrompt(tt.vars) - if (err != nil) != tt.wantErr { - t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestModel_PostResponsePrompt(t *testing.T) { - tests := []struct { - name string - template string - vars PromptVars - want string - wantErr bool - }{ - { - name: "No Response in Template", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - vars: PromptVars{ - Response: "I don't know.", - }, - want: "I don't know.", - }, - { - name: "Response in Template", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", - vars: PromptVars{ - Response: "I don't know.", - }, - want: "I don't know.", - }, - { - name: "Response in Template with Trailing Formatting", - template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>", - vars: PromptVars{ - Response: "I don't know.", - }, - want: "I don't know.<|im_end|>", - }, - { - name: "Response in Template with Alternative Formatting", - template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", - vars: PromptVars{ - Response: "I don't know.", - }, - want: "I don't know.<|im_end|>", - }, - } - - for _, tt := range tests { - m := Model{Template: tt.template} - t.Run(tt.name, func(t *testing.T) { - got, err := m.PostResponseTemplate(tt.vars) - if (err != nil) != tt.wantErr { - t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) { - tests := []struct { - name string - template string - preVars PromptVars - postVars PromptVars - want string - wantErr bool - }{ - { - name: "Response in Template", - template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>", - preVars: PromptVars{ - Prompt: "What are the potion ingredients?", - }, - postVars: PromptVars{ - Prompt: "What are the potion ingredients?", - Response: "Sugar.", - }, - want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>", - }, - { - name: "No Response in Template", - template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n", - preVars: PromptVars{ - Prompt: "What are the potion ingredients?", - }, - postVars: PromptVars{ - Prompt: "What are the potion ingredients?", - Response: "Spice.", - }, - want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.", - }, - } - - for _, tt := range tests { - m := Model{Template: tt.template} - t.Run(tt.name, func(t *testing.T) { - pre, err := m.PreResponsePrompt(tt.preVars) - if (err != nil) != tt.wantErr { - t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr) - return - } - post, err := m.PostResponseTemplate(tt.postVars) - if err != nil { - t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr) - return - } - result := pre + post - if result != tt.want { - t.Errorf("Prompt() got = %v, want %v", result, tt.want) - } - }) - } -} - -func chatHistoryEqual(a, b ChatHistory) bool { - if len(a.Prompts) != len(b.Prompts) { - return false - } - if len(a.CurrentImages) != len(b.CurrentImages) { - return false - } - for i, v := range a.Prompts { - if v != b.Prompts[i] { - return false - } - } - for i, v := range a.CurrentImages { - if !bytes.Equal(v, b.CurrentImages[i]) { - return false - } - } - return a.LastSystem == b.LastSystem -} - -func TestChat(t *testing.T) { - tests := []struct { - name string - model Model - msgs []api.Message - want ChatHistory - wantErr string - }{ - { - name: "Single Message", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - }, - msgs: []api.Message{ - { - Role: "system", - Content: "You are a Wizard.", - }, - { - Role: "user", - Content: "What are the potion ingredients?", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - First: true, - }, - }, - LastSystem: "You are a Wizard.", - }, - }, - { - name: "Message History", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - }, - msgs: []api.Message{ - { - Role: "system", - Content: "You are a Wizard.", - }, - { - Role: "user", - Content: "What are the potion ingredients?", - }, - { - Role: "assistant", - Content: "sugar", - }, - { - Role: "user", - Content: "Anything else?", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "sugar", - First: true, - }, - { - Prompt: "Anything else?", - }, - }, - LastSystem: "You are a Wizard.", - }, - }, - { - name: "Assistant Only", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - }, - msgs: []api.Message{ - { - Role: "assistant", - Content: "everything nice", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - Response: "everything nice", - First: true, - }, - }, - }, - }, - { - name: "Last system message is preserved from modelfile", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - System: "You are Mojo Jojo.", - }, - msgs: []api.Message{ - { - Role: "user", - Content: "hi", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are Mojo Jojo.", - Prompt: "hi", - First: true, - }, - }, - LastSystem: "You are Mojo Jojo.", - }, - }, - { - name: "Last system message is preserved from messages", - model: Model{ - Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - System: "You are Mojo Jojo.", - }, - msgs: []api.Message{ - { - Role: "system", - Content: "You are Professor Utonium.", - }, - }, - want: ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are Professor Utonium.", - First: true, - }, - }, - LastSystem: "You are Professor Utonium.", - }, - }, - { - name: "Invalid Role", - msgs: []api.Message{ - { - Role: "not-a-role", - Content: "howdy", - }, - }, - wantErr: "invalid role: not-a-role", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.model.ChatPrompts(tt.msgs) - if tt.wantErr != "" { - if err == nil { - t.Errorf("ChatPrompt() expected error, got nil") - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr) - } - return - } - if !chatHistoryEqual(*got, tt.want) { - t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want) - } - }) - } -} diff --git a/server/prompt.go b/server/prompt.go new file mode 100644 index 00000000..88da5b6b --- /dev/null +++ b/server/prompt.go @@ -0,0 +1,221 @@ +package server + +import ( + "fmt" + "log/slog" + "strings" + "text/template" + "text/template/parse" + + "github.com/jmorganca/ollama/api" +) + +// isResponseNode checks if the node contains .Response +func isResponseNode(node *parse.ActionNode) bool { + for _, cmd := range node.Pipe.Cmds { + for _, arg := range cmd.Args { + if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 { + if fieldNode.Ident[0] == "Response" { + return true + } + } + } + } + return false +} + +// formatTemplateForResponse formats the template AST to: +// 1. remove all nodes after the first .Response (if generate=true) +// 2. add a .Response node to the end if it doesn't exist +// TODO(jmorganca): this should recursively cut the template before the first .Response +func formatTemplateForResponse(tmpl *template.Template, generate bool) { + var found bool + for i, node := range tmpl.Tree.Root.Nodes { + if actionNode, ok := node.(*parse.ActionNode); ok { + if isResponseNode(actionNode) { + found = true + if generate { + tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1] + break + } + } + } + } + + if !found { + // add the response node if it doesn't exist + responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}} + responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}} + responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode} + tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode) + } +} + +// Prompt renders a prompt from a template. If generate is set to true, +// the response and parts of the template following it are not rendered +func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) { + parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl) + if err != nil { + return "", err + } + + formatTemplateForResponse(parsed, generate) + + vars := map[string]any{ + "System": system, + "Prompt": prompt, + "Response": response, + } + + var sb strings.Builder + if err := parsed.Execute(&sb, vars); err != nil { + return "", err + } + + return sb.String(), nil +} + +func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) { + rendered, err := Prompt(tmpl, system, prompt, response, false) + if err != nil { + return 0, err + } + + tokens, err := encode(rendered) + if err != nil { + slog.Error("failed to encode prompt", "err", err) + return 0, err + } + + return len(tokens), err +} + +// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size +func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { + type prompt struct { + System string + Prompt string + Response string + + images []int + tokens int + } + + var p prompt + + // iterate through messages to build up {system,user,response} prompts + var imgId int + var prompts []prompt + for _, msg := range messages { + switch strings.ToLower(msg.Role) { + case "system": + if p.System != "" || p.Prompt != "" || p.Response != "" { + prompts = append(prompts, p) + p = prompt{} + } + + p.System = msg.Content + case "user": + if p.Prompt != "" || p.Response != "" { + prompts = append(prompts, p) + p = prompt{} + } + + var sb strings.Builder + for range msg.Images { + fmt.Fprintf(&sb, "[img-%d] ", imgId) + p.images = append(p.images, imgId) + imgId += 1 + } + + sb.WriteString(msg.Content) + p.Prompt = sb.String() + case "assistant": + if p.Response != "" { + prompts = append(prompts, p) + p = prompt{} + } + + p.Response = msg.Content + default: + return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) + } + } + + // add final prompt + if p.System != "" || p.Prompt != "" || p.Response != "" { + prompts = append(prompts, p) + } + + // calculate token lengths for each prompt, estimating 768 tokens per images + for i, p := range prompts { + tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode) + if err != nil { + return "", err + } + + prompts[i].tokens = tokens + len(prompts[i].images)*768 + } + + // truncate images and prompts starting from the beginning of the list + // until either one prompt remains or the total tokens fits the context window + // TODO (jmorganca): this doesn't account for the context window room required for the response + for { + var required int + for _, p := range prompts { + required += p.tokens + } + + required += 1 // for bos token + + if required <= window { + slog.Debug("prompt now fits in context window", "required", required, "window", window) + break + } + + prompt := &prompts[0] + + if len(prompt.images) > 1 { + img := prompt.images[0] + slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window) + prompt.images = prompt.images[1:] + prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1) + prompt.tokens -= 768 + continue + } + + if len(prompts) > 1 { + slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window) + system := prompt.System + prompts = prompts[1:] + + if system != "" && prompts[0].System == "" { + prompts[0].System = system + + tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode) + if err != nil { + return "", err + } + + prompts[0].tokens = tokens + len(prompts[0].images)*768 + } + + continue + } + + // stop truncating if there's only one prompt left + break + } + + var sb strings.Builder + for i, p := range prompts { + // last prompt should leave the response unrendered (for completion) + rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1) + if err != nil { + return "", err + } + sb.WriteString(rendered) + } + + return sb.String(), nil +} diff --git a/server/prompt_test.go b/server/prompt_test.go new file mode 100644 index 00000000..500ee522 --- /dev/null +++ b/server/prompt_test.go @@ -0,0 +1,205 @@ +package server + +import ( + "strings" + "testing" + + "github.com/jmorganca/ollama/api" +) + +func TestPrompt(t *testing.T) { + tests := []struct { + name string + template string + system string + prompt string + response string + generate bool + want string + }{ + { + name: "simple prompt", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", + }, + { + name: "implicit response", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + response: "I don't know.", + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.", + }, + { + name: "response", + template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + response: "I don't know.", + want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.", + }, + { + name: "cut", + template: "{{ .System }}{{ .Prompt }}{{ .Response }}", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + response: "I don't know.", + generate: true, + want: "You are a Wizard.What are the potion ingredients?I don't know.", + }, + { + name: "nocut", + template: "{{ .System }}{{ .Prompt }}{{ .Response }}", + system: "You are a Wizard.", + prompt: "What are the potion ingredients?", + response: "I don't know.", + want: "You are a Wizard.What are the potion ingredients?I don't know.", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate) + if err != nil { + t.Errorf("error = %v", err) + } + + if got != tc.want { + t.Errorf("got = %v, want %v", got, tc.want) + } + }) + } +} + +func TestChatPrompt(t *testing.T) { + tests := []struct { + name string + template string + messages []api.Message + window int + want string + }{ + { + name: "simple prompt", + template: "[INST] {{ .Prompt }} [/INST]", + messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + window: 1024, + want: "[INST] Hello [/INST]", + }, + { + name: "with system message", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> Hello [/INST]", + }, + { + name: "with response", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "I am?"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> Hello [/INST] I am?", + }, + { + name: "with implicit response", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST]", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "I am?"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> Hello [/INST]I am?", + }, + { + name: "with conversation", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "What are the potion ingredients?"}, + {Role: "assistant", Content: "sugar"}, + {Role: "user", Content: "Anything else?"}, + }, + window: 1024, + want: "[INST] <>You are a Wizard.<> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ", + }, + { + name: "with truncation", + template: "{{ .System }} {{ .Prompt }} {{ .Response }} ", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "I am?"}, + {Role: "user", Content: "Why is the sky blue?"}, + {Role: "assistant", Content: "The sky is blue from rayleigh scattering"}, + }, + window: 10, + want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering", + }, + { + name: "images", + template: "{{ .System }} {{ .Prompt }}", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}}, + }, + window: 1024, + want: "You are a Wizard. [img-0] Hello", + }, + { + name: "images truncated", + template: "{{ .System }} {{ .Prompt }}", + messages: []api.Message{ + {Role: "system", Content: "You are a Wizard."}, + {Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}}, + }, + window: 1024, + want: "You are a Wizard. [img-0] [img-1] Hello", + }, + { + name: "empty list", + template: "{{ .System }} {{ .Prompt }}", + messages: []api.Message{}, + window: 1024, + want: "", + }, + { + name: "empty prompt", + template: "[INST] {{ if .System }}<>{{ .System }}<> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ", + messages: []api.Message{ + {Role: "user", Content: ""}, + }, + window: 1024, + want: "", + }, + } + + encode := func(s string) ([]int, error) { + words := strings.Fields(s) + return make([]int, len(words)), nil + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode) + if err != nil { + t.Errorf("error = %v", err) + } + + if got != tc.want { + t.Errorf("got: %q, want: %q", got, tc.want) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index 01a898a8..d99c858c 100644 --- a/server/routes.go +++ b/server/routes.go @@ -10,6 +10,7 @@ import ( "log/slog" "net" "net/http" + "net/netip" "os" "os/signal" "path/filepath" @@ -22,10 +23,12 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "golang.org/x/exp/slices" "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/gpu" "github.com/jmorganca/ollama/llm" + "github.com/jmorganca/ollama/openai" "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/version" ) @@ -33,7 +36,7 @@ import ( var mode string = gin.DebugMode type Server struct { - WorkDir string + addr net.Addr } func init() { @@ -64,8 +67,6 @@ var defaultSessionDuration = 5 * time.Minute // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error { - workDir := c.GetString("workDir") - needLoad := loaded.runner == nil || // is there a model loaded? loaded.ModelPath != model.ModelPath || // has the base model changed? !reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed? @@ -80,7 +81,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D loaded.Options = nil } - llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts) + llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts) if err != nil { // some older models are not compatible with newer versions of llama.cpp // show a generalized compatibility error until there is a better way to @@ -135,6 +136,12 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options return opts, nil } +func isSupportedImageType(image []byte) bool { + contentType := http.DetectContentType(image) + allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"} + return slices.Contains(allowedTypes, contentType) +} + func GenerateHandler(c *gin.Context) { loaded.mu.Lock() defer loaded.mu.Unlock() @@ -165,6 +172,13 @@ func GenerateHandler(c *gin.Context) { return } + for _, img := range req.Images { + if !isSupportedImageType(img) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) + return + } + } + model, err := GetModel(req.Model) if err != nil { var pErr *fs.PathError @@ -176,6 +190,11 @@ func GenerateHandler(c *gin.Context) { return } + if model.IsEmbedding() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"}) + return + } + opts, err := modelOptions(model, req.Options) if err != nil { if errors.Is(err, api.ErrInvalidOpts) { @@ -199,6 +218,8 @@ func GenerateHandler(c *gin.Context) { } // an empty request loads the model + // note: for a short while template was used in lieu + // of `raw` mode so we need to check for it too if req.Prompt == "" && req.Template == "" && req.System == "" { c.JSON(http.StatusOK, api.GenerateResponse{ CreatedAt: time.Now().UTC(), @@ -211,49 +232,52 @@ func GenerateHandler(c *gin.Context) { checkpointLoaded := time.Now() var prompt string - var promptVars PromptVars switch { case req.Raw: prompt = req.Prompt case req.Prompt != "": - if req.Template != "" { - // override the default model template - model.Template = req.Template + if req.Template == "" { + req.Template = model.Template } - var rebuild strings.Builder + if req.System == "" { + req.System = model.System + } + + slog.Debug("generate handler", "prompt", req.Prompt) + slog.Debug("generate handler", "template", req.Template) + slog.Debug("generate handler", "system", req.System) + + var sb strings.Builder + for i := range req.Images { + fmt.Fprintf(&sb, "[img-%d] ", i) + } + + sb.WriteString(req.Prompt) + + p, err := Prompt(req.Template, req.System, sb.String(), "", true) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + sb.Reset() if req.Context != nil { - // TODO: context is deprecated, at some point the context logic within this conditional should be removed - prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context) + prev, err := loaded.runner.Decode(c.Request.Context(), req.Context) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - // Remove leading spaces from prevCtx if present - prevCtx = strings.TrimPrefix(prevCtx, " ") - rebuild.WriteString(prevCtx) - } - promptVars = PromptVars{ - System: req.System, - Prompt: req.Prompt, - First: len(req.Context) == 0, + sb.WriteString(prev) } - if promptVars.System == "" { - promptVars.System = model.System - } + sb.WriteString(p) - p, err := model.PreResponsePrompt(promptVars) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - rebuild.WriteString(p) - prompt = rebuild.String() + prompt = sb.String() } - slog.Debug(fmt.Sprintf("prompt: %s", prompt)) + slog.Debug("generate handler", "prompt", prompt) ch := make(chan any) var generated strings.Builder @@ -289,30 +313,39 @@ func GenerateHandler(c *gin.Context) { resp.LoadDuration = checkpointLoaded.Sub(checkpointStart) if !req.Raw { - // append the generated text to the history and template it if needed - promptVars.Response = generated.String() - result, err := model.PostResponseTemplate(promptVars) + p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // TODO (jmorganca): encode() should not strip special tokens + tokens, err := loaded.runner.Encode(c.Request.Context(), p) if err != nil { ch <- gin.H{"error": err.Error()} return } - embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result) - if err != nil { - ch <- gin.H{"error": err.Error()} - return - } - resp.Context = embd + + resp.Context = append(req.Context, tokens...) } } ch <- resp } + var images []llm.ImageData + for i := range req.Images { + images = append(images, llm.ImageData{ + ID: i, + Data: req.Images[i], + }) + } + // Start prediction predictReq := llm.PredictOpts{ Prompt: prompt, Format: req.Format, - Images: req.Images, + Images: images, Options: opts, } if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { @@ -351,7 +384,7 @@ func GenerateHandler(c *gin.Context) { streamResponse(c, ch) } -func EmbeddingHandler(c *gin.Context) { +func EmbeddingsHandler(c *gin.Context) { loaded.mu.Lock() defer loaded.mu.Unlock() @@ -404,8 +437,9 @@ func EmbeddingHandler(c *gin.Context) { return } - if !loaded.Options.EmbeddingOnly { - c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"}) + // an empty request loads the model + if req.Prompt == "" { + c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}}) return } @@ -451,7 +485,7 @@ func PullModelHandler(c *gin.Context) { ch <- r } - regOpts := &RegistryOptions{ + regOpts := ®istryOptions{ Insecure: req.Insecure, } @@ -500,7 +534,7 @@ func PushModelHandler(c *gin.Context) { ch <- r } - regOpts := &RegistryOptions{ + regOpts := ®istryOptions{ Insecure: req.Insecure, } @@ -871,15 +905,83 @@ var defaultAllowOrigins = []string{ "0.0.0.0", } -func NewServer() (*Server, error) { - workDir, err := os.MkdirTemp("", "ollama") - if err != nil { - return nil, err +func isLocalIP(ip netip.Addr) bool { + if interfaces, err := net.Interfaces(); err == nil { + for _, iface := range interfaces { + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, a := range addrs { + if parsed, _, err := net.ParseCIDR(a.String()); err == nil { + if parsed.String() == ip.String() { + return true + } + } + } + } } - return &Server{ - WorkDir: workDir, - }, nil + return false +} + +func allowedHost(host string) bool { + if host == "" || host == "localhost" { + return true + } + + if hostname, err := os.Hostname(); err == nil && host == hostname { + return true + } + + var tlds = []string{ + "localhost", + "local", + "internal", + } + + // check if the host is a local TLD + for _, tld := range tlds { + if strings.HasSuffix(host, "."+tld) { + return true + } + } + + return false +} + +func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { + return func(c *gin.Context) { + if addr == nil { + c.Next() + return + } + + if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() { + c.Next() + return + } + + host, _, err := net.SplitHostPort(c.Request.Host) + if err != nil { + host = c.Request.Host + } + + if addr, err := netip.ParseAddr(host); err == nil { + if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) { + c.Next() + return + } + } + + if allowedHost(host) { + c.Next() + return + } + + c.AbortWithStatus(http.StatusForbidden) + } } func (s *Server) GenerateRoutes() http.Handler { @@ -905,16 +1007,13 @@ func (s *Server) GenerateRoutes() http.Handler { r := gin.Default() r.Use( cors.New(config), - func(c *gin.Context) { - c.Set("workDir", s.WorkDir) - c.Next() - }, + allowedHostsMiddleware(s.addr), ) r.POST("/api/pull", PullModelHandler) r.POST("/api/generate", GenerateHandler) r.POST("/api/chat", ChatHandler) - r.POST("/api/embeddings", EmbeddingHandler) + r.POST("/api/embeddings", EmbeddingsHandler) r.POST("/api/create", CreateModelHandler) r.POST("/api/push", PushModelHandler) r.POST("/api/copy", CopyModelHandler) @@ -923,6 +1022,9 @@ func (s *Server) GenerateRoutes() http.Handler { r.POST("/api/blobs/:digest", CreateBlobHandler) r.HEAD("/api/blobs/:digest", HeadBlobHandler) + // Compatibility endpoints + r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler) + for _, method := range []string{http.MethodGet, http.MethodHead} { r.Handle(method, "/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") @@ -974,10 +1076,7 @@ func Serve(ln net.Listener) error { } } - s, err := NewServer() - if err != nil { - return err - } + s := &Server{addr: ln.Addr()} r := s.GenerateRoutes() slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version)) @@ -993,11 +1092,11 @@ func Serve(ln net.Listener) error { if loaded.runner != nil { loaded.runner.Close() } - os.RemoveAll(s.WorkDir) + gpu.Cleanup() os.Exit(0) }() - if err := llm.Init(s.WorkDir); err != nil { + if err := llm.Init(); err != nil { return fmt.Errorf("unable to initialize llm library %w", err) } if runtime.GOOS == "linux" { // TODO - windows too @@ -1060,6 +1159,20 @@ func streamResponse(c *gin.Context, ch chan any) { }) } +// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model +func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) { + encode := func(s string) ([]int, error) { + return loaded.runner.Encode(ctx, s) + } + + prompt, err := ChatPrompt(template, messages, numCtx, encode) + if err != nil { + return "", err + } + + return prompt, nil +} + func ChatHandler(c *gin.Context) { loaded.mu.Lock() defer loaded.mu.Unlock() @@ -1098,6 +1211,11 @@ func ChatHandler(c *gin.Context) { return } + if model.IsEmbedding() { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"}) + return + } + opts, err := modelOptions(model, req.Options) if err != nil { if errors.Is(err, api.ErrInvalidOpts) { @@ -1120,8 +1238,26 @@ func ChatHandler(c *gin.Context) { return } + checkpointLoaded := time.Now() + + // if the first message is not a system message, then add the model's default system message + if len(req.Messages) > 0 && req.Messages[0].Role != "system" { + req.Messages = append([]api.Message{ + { + Role: "system", + Content: model.System, + }, + }, req.Messages...) + } + + prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + // an empty request loads the model - if len(req.Messages) == 0 { + if len(req.Messages) == 0 || prompt == "" { resp := api.ChatResponse{ CreatedAt: time.Now().UTC(), Model: req.Model, @@ -1132,20 +1268,24 @@ func ChatHandler(c *gin.Context) { return } - checkpointLoaded := time.Now() + // only send images that are in the prompt + var i int + var images []llm.ImageData + for _, m := range req.Messages { + for _, img := range m.Images { + if !isSupportedImageType(img) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"}) + return + } - chat, err := model.ChatPrompts(req.Messages) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - prompt, err := trimmedPrompt(c.Request.Context(), chat, model) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) { + images = append(images, llm.ImageData{Data: img, ID: i}) + } + i += 1 + } } - slog.Debug(fmt.Sprintf("prompt: %s", prompt)) + slog.Debug("chat handler", "prompt", prompt, "images", len(images)) ch := make(chan any) @@ -1182,7 +1322,7 @@ func ChatHandler(c *gin.Context) { predictReq := llm.PredictOpts{ Prompt: prompt, Format: req.Format, - Images: chat.CurrentImages, + Images: images, Options: opts, } if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { @@ -1220,101 +1360,3 @@ func ChatHandler(c *gin.Context) { streamResponse(c, ch) } - -// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model -type promptInfo struct { - vars PromptVars - tokenLen int -} - -// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length, -// while preserving the most recent system message. -func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) { - if len(chat.Prompts) == 0 { - return "", nil - } - - var promptsToAdd []promptInfo - var totalTokenLength int - var systemPromptIncluded bool - - // reverse iterate through the prompts to build the prompt string in a way that fits the max context length - for i := len(chat.Prompts) - 1; i >= 0; i-- { - promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1) - if err != nil { - return "", err - } - - encodedTokens, err := loaded.runner.Encode(ctx, promptText) - if err != nil { - return "", err - } - - if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 { - break // reached max context length, stop adding more prompts - } - - totalTokenLength += len(encodedTokens) - systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != "" - promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)}) - } - - // ensure the system prompt is included, if not already - if chat.LastSystem != "" && !systemPromptIncluded { - var err error - promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd) - if err != nil { - return "", err - } - } - - promptsToAdd[len(promptsToAdd)-1].vars.First = true - - // construct the final prompt string from the prompts which fit within the context window - var result string - for i, prompt := range promptsToAdd { - promptText, err := promptString(model, prompt.vars, i == 0) - if err != nil { - return "", err - } - result = promptText + result - } - return result, nil -} - -// promptString applies the model template to the prompt -func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) { - if isMostRecent { - p, err := model.PreResponsePrompt(vars) - if err != nil { - return "", fmt.Errorf("pre-response template: %w", err) - } - return p, nil - } - p, err := Prompt(model.Template, vars) - if err != nil { - return "", err - } - return p, nil -} - -// includeSystemPrompt adjusts the prompts to include the system prompt. -func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) { - systemTokens, err := loaded.runner.Encode(ctx, systemPrompt) - if err != nil { - return nil, err - } - - for i := len(promptsToAdd) - 1; i >= 0; i-- { - if totalTokenLength+len(systemTokens) <= loaded.NumCtx { - promptsToAdd[i].vars.System = systemPrompt - return promptsToAdd[:i+1], nil - } - totalTokenLength -= promptsToAdd[i].tokenLen - } - - // if got here, system did not fit anywhere, so return the most recent prompt with the system message set - recent := promptsToAdd[len(promptsToAdd)-1] - recent.vars.System = systemPrompt - return []promptInfo{recent}, nil -} diff --git a/server/routes_test.go b/server/routes_test.go index 9c53dc20..bbed02ed 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -21,12 +21,6 @@ import ( "github.com/jmorganca/ollama/version" ) -func setupServer(t *testing.T) (*Server, error) { - t.Helper() - - return NewServer() -} - func Test_Routes(t *testing.T) { type testCase struct { Name string @@ -207,9 +201,7 @@ func Test_Routes(t *testing.T) { }, } - s, err := setupServer(t) - assert.Nil(t, err) - + s := Server{} router := s.GenerateRoutes() httpSrv := httptest.NewServer(router) @@ -241,236 +233,6 @@ func Test_Routes(t *testing.T) { } } -func Test_ChatPrompt(t *testing.T) { - tests := []struct { - name string - template string - chat *ChatHistory - numCtx int - runner MockLLM - want string - wantErr string - }{ - { - name: "Single Message", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - First: true, - }, - }, - LastSystem: "You are a Wizard.", - }, - numCtx: 1, - runner: MockLLM{ - encoding: []int{1}, // fit the ctxLen - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", - }, - { - name: "First Message", - template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "eye of newt", - First: true, - }, - { - Prompt: "Anything else?", - }, - }, - LastSystem: "You are a Wizard.", - }, - numCtx: 2, - runner: MockLLM{ - encoding: []int{1}, // fit the ctxLen - }, - want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]", - }, - { - name: "Message History", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - System: "You are a Wizard.", - Prompt: "What are the potion ingredients?", - Response: "sugar", - First: true, - }, - { - Prompt: "Anything else?", - }, - }, - LastSystem: "You are a Wizard.", - }, - numCtx: 4, - runner: MockLLM{ - encoding: []int{1}, // fit the ctxLen, 1 for each message - }, - want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]", - }, - { - name: "Assistant Only", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Response: "everything nice", - First: true, - }, - }, - }, - numCtx: 1, - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] [/INST]everything nice", - }, - { - name: "Message History Truncated, No System", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Prompt: "What are the potion ingredients?", - Response: "sugar", - First: true, - }, - { - Prompt: "Anything else?", - Response: "spice", - }, - { - Prompt: "... and?", - }, - }, - }, - numCtx: 2, // only 1 message from history and most recent message - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]", - }, - { - name: "System is Preserved when Truncated", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Prompt: "What are the magic words?", - Response: "abracadabra", - }, - { - Prompt: "What is the spell for invisibility?", - }, - }, - LastSystem: "You are a wizard.", - }, - numCtx: 2, - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]", - }, - { - name: "System is Preserved when Length Exceeded", - template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Prompt: "What are the magic words?", - Response: "abracadabra", - }, - { - Prompt: "What is the spell for invisibility?", - }, - }, - LastSystem: "You are a wizard.", - }, - numCtx: 1, - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]", - }, - { - name: "First is Preserved when Truncated", - template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]", - - chat: &ChatHistory{ - Prompts: []PromptVars{ - // first message omitted for test - { - Prompt: "Do you have a magic hat?", - Response: "Of course.", - }, - { - Prompt: "What is the spell for invisibility?", - }, - }, - LastSystem: "You are a wizard.", - }, - numCtx: 3, // two most recent messages and room for system message - runner: MockLLM{ - encoding: []int{1}, - }, - want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]", - }, - { - name: "Most recent message is returned when longer than ctxLen", - template: "[INST] {{ .Prompt }} [/INST]", - - chat: &ChatHistory{ - Prompts: []PromptVars{ - { - Prompt: "What is the spell for invisibility?", - First: true, - }, - }, - }, - numCtx: 1, // two most recent messages - runner: MockLLM{ - encoding: []int{1, 2}, - }, - want: "[INST] What is the spell for invisibility? [/INST]", - }, - } - - for _, testCase := range tests { - tt := testCase - m := &Model{ - Template: tt.template, - } - t.Run(tt.name, func(t *testing.T) { - loaded.runner = &tt.runner - loaded.Options = &api.Options{ - Runner: api.Runner{ - NumCtx: tt.numCtx, - }, - } - got, err := trimmedPrompt(context.Background(), tt.chat, m) - if tt.wantErr != "" { - if err == nil { - t.Errorf("ChatPrompt() expected error, got nil") - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr) - } - } - if got != tt.want { - t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want) - } - }) - } -} - type MockLLM struct { encoding []int } diff --git a/server/upload.go b/server/upload.go index 3609b308..4da34052 100644 --- a/server/upload.go +++ b/server/upload.go @@ -12,7 +12,6 @@ import ( "net/http" "net/url" "os" - "strings" "sync" "sync/atomic" "time" @@ -49,7 +48,7 @@ const ( maxUploadPartSize int64 = 1000 * format.MegaByte ) -func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { +func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { p, err := GetBlobsPath(b.Digest) if err != nil { return err @@ -121,7 +120,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg // Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded // in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error. -func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { +func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) { defer blobUploadManager.Delete(b.Digest) ctx, b.CancelFunc = context.WithCancel(ctx) @@ -177,16 +176,14 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { requestURL := <-b.nextURL // calculate md5 checksum and add it to the commit request - var sb strings.Builder + md5sum := md5.New() for _, part := range b.Parts { - sb.Write(part.Sum(nil)) + md5sum.Write(part.Sum(nil)) } - md5sum := md5.Sum([]byte(sb.String())) - values := requestURL.Query() values.Add("digest", b.Digest) - values.Add("etag", fmt.Sprintf("%x-%d", md5sum, len(b.Parts))) + values.Add("etag", fmt.Sprintf("%x-%d", md5sum.Sum(nil), len(b.Parts))) requestURL.RawQuery = values.Encode() headers := make(http.Header) @@ -212,7 +209,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) { b.done = true } -func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error { +func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error { headers := make(http.Header) headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Length", fmt.Sprintf("%d", part.Size)) @@ -277,9 +274,8 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL * case resp.StatusCode == http.StatusUnauthorized: w.Rollback() - auth := resp.Header.Get("www-authenticate") - authRedir := ParseAuthRedirectString(auth) - token, err := getAuthToken(ctx, authRedir) + challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate")) + token, err := getAuthorizationToken(ctx, challenge) if err != nil { return err } @@ -364,7 +360,7 @@ func (p *progressWriter) Rollback() { p.written = 0 } -func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error { +func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error { requestURL := mp.BaseURL() requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)