Compare commits

...

43 Commits

Author SHA1 Message Date
Josh Yan
f30b54209c revert other pr change 2024-07-24 11:11:59 -07:00
Josh Yan
e39be4f63a short circuit 2024-07-23 17:14:34 -07:00
Josh Yan
b8c3d54f7a set homedir for windows --test 2024-07-23 16:28:42 -07:00
Josh Yan
c8434b0e69 rmv unsued 2024-07-23 16:03:19 -07:00
Josh Yan
65658e4077 default to post 2024-07-23 14:05:48 -07:00
Josh Yan
b29382b86f bin 2024-07-23 13:57:40 -07:00
Josh Yan
2efe2013a1 test 2024-07-23 13:54:22 -07:00
Josh Yan
5c3786f4d5 comments 2024-07-23 13:46:34 -07:00
Josh Yan
33848ad10f serverside copy 2024-07-23 12:26:05 -07:00
Josh Yan
ff06a2916d changes 2024-07-22 15:51:52 -07:00
Josh Yan
d923a59356 testing auth 2024-07-22 15:51:52 -07:00
Josh Yan
2b42ad5754 auth changes' 2024-07-22 15:51:52 -07:00
Josh Yan
e3253e5469 isLocal testing 2024-07-22 15:51:52 -07:00
Josh Yan
35b49739ec timecheck 2024-07-22 15:51:52 -07:00
Josh Yan
bd8596d32b cmt 2024-07-22 15:51:52 -07:00
Josh Yan
b85705162f remove knownhosts 2024-07-22 15:51:52 -07:00
Josh Yan
d62a3a1e2b lint 2024-07-22 15:51:52 -07:00
Josh Yan
de48cd681f clean 2024-07-22 15:51:52 -07:00
Josh Yan
5d0e078057 removed cmt and prints 2024-07-22 15:51:52 -07:00
Josh Yan
8d5739b833 removed client isLocal() 2024-07-22 15:51:52 -07:00
Josh Yan
b5ff0ed4ff lint 2024-07-22 15:51:52 -07:00
Josh Yan
857054f9fa lint 2024-07-22 15:51:52 -07:00
Josh Yan
6dd9be55e2 lint 2024-07-22 15:51:52 -07:00
Josh Yan
d70707a668 syscopy windows 2024-07-22 15:51:52 -07:00
Josh Yan
c88774ffeb os copy 2024-07-22 15:51:52 -07:00
Josh Yan
34d197000d rmv prints 2024-07-22 15:51:52 -07:00
Josh Yan
6c0a8379f6 local copy 2024-07-22 15:51:52 -07:00
Josh Yan
163ee9a8b0 isLocal firstdraft 2024-07-22 15:51:52 -07:00
Josh Yan
de7b2f3948 clean 2024-07-22 15:51:52 -07:00
Josh Yan
f27c66fb0c rm bench 2024-07-22 15:51:52 -07:00
Josh Yan
a238191798 rm config 2024-07-22 15:51:52 -07:00
Josh Yan
6436c7a375 rm config 2024-07-22 15:51:52 -07:00
Josh Yan
896a15874e clean 2024-07-22 15:51:52 -07:00
Josh Yan
56008688a1 local path 2024-07-22 15:51:52 -07:00
Josh Yan
d14d38e940 still works 2024-07-22 15:51:52 -07:00
Josh Yan
03df02883d rebase 2024-07-22 15:51:52 -07:00
Josh Yan
ae49abf80a benchmark 2024-07-22 15:51:52 -07:00
Josh Yan
2c450502db on disk copy 2024-07-22 15:51:52 -07:00
Josh Yan
46b76aeb46 start tests 2024-07-22 15:51:52 -07:00
Josh Yan
0e01da82d6 errorsis 2024-07-22 15:51:31 -07:00
Josh Yan
6b1b85ba3d hide initialize keypair 2024-07-22 15:41:04 -07:00
Josh Yan
5603441538 test 2024-07-22 13:58:50 -07:00
Josh Yan
76b4dfcc9e auth 2024-07-22 13:54:02 -07:00
9 changed files with 449 additions and 89 deletions

View File

@ -17,6 +17,7 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@ -24,7 +25,10 @@ import (
"net/http"
"net/url"
"runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/version"
@ -383,3 +387,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil
}
func Authorization(ctx context.Context, request *http.Request) (string, error) {
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
token, err := auth.Sign(ctx, data)
if err != nil {
return "", err
}
// interleave request data into the token
key, sig, _ := strings.Cut(token, ":")
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
}

View File

@ -3,49 +3,68 @@ package auth
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
)
const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) {
func privateKey() (ssh.Signer, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
return nil, err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if errors.Is(err, os.ErrNotExist) {
err := initializeKeypair()
if err != nil {
return nil, err
}
return privateKey()
} else if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return nil, err
}
return ssh.ParsePrivateKey(privateKeyFile)
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
func GetPublicKey() (ssh.PublicKey, error) {
// try to read pubkey first
home, err := os.UserHomeDir()
if err != nil {
return "", err
return nil, err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
pubKeyFile, err := os.ReadFile(pubkeyPath)
if errors.Is(err, os.ErrNotExist) {
// try from privateKey
privateKey, err := privateKey()
if err != nil {
return nil, fmt.Errorf("failed to read public key: %w", err)
}
return privateKey.PublicKey(), nil
} else if err != nil {
return nil, fmt.Errorf("failed to read public key: %w", err)
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil {
return "", err
}
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
return strings.TrimSpace(string(publicKey)), nil
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile)
return pubKey, err
}
func NewNonce(r io.Reader, length int) (string, error) {
@ -58,25 +77,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
}
func Sign(ctx context.Context, bts []byte) (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
privateKey, err := privateKey()
if err != nil {
return "", err
}
// get the pubkey, but remove the type
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
parts := bytes.Split(publicKey, []byte(" "))
publicKey, err := GetPublicKey()
if err != nil {
return "", err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
parts := bytes.Split(publicKeyBytes, []byte(" "))
if len(parts) < 2 {
return "", fmt.Errorf("malformed public key")
}
@ -89,3 +103,49 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
// signature is <pubkey>:<signature>
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
}
func initializeKeypair() error {
home, err := os.UserHomeDir()
if err != nil {
return err
}
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
_, err = os.Stat(privKeyPath)
if errors.Is(err, os.ErrNotExist) {
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return err
}
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
return fmt.Errorf("could not create directory %w", err)
}
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
return err
}
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
if err != nil {
return err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
return err
}
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
}
return nil
}

View File

@ -4,10 +4,7 @@ import (
"archive/zip"
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/pem"
"errors"
"fmt"
"io"
@ -15,6 +12,7 @@ import (
"math"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
@ -112,7 +110,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile
}
digest, err := createBlob(cmd, client, path)
digest, err := createBlob(cmd, path)
if err != nil {
return err
}
@ -263,7 +261,9 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
var ErrBlobExists = errors.New("blob exists")
func createBlob(cmd *cobra.Command, path string) (string, error) {
bin, err := os.Open(path)
if err != nil {
return "", err
@ -280,12 +280,65 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
}
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
// Use our new CreateBlob request which will include the file path
// The server checks for that file and if the server is local, it will copy the file over
// If the local copy fails, the server will continue to the default local copy
// If that fails, it will continue with the server POST
err = CreateBlob(cmd.Context(), path, digest, bin)
if errors.Is(err, ErrBlobExists) {
return digest, nil
}
if err != nil {
return "", err
}
return digest, nil
}
func CreateBlob(ctx context.Context, src, digest string, r *os.File) (error) {
ollamaHost := envconfig.Host
client := http.DefaultClient
base := &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
}
path := fmt.Sprintf("/api/blobs/%s", digest)
requestURL := base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), r)
if err != nil {
return err
}
authz, err := api.Authorization(ctx, request)
if err != nil {
return err
}
request.Header.Set("Authorization", authz)
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Ollama-File", src)
resp, err := client.Do(request)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusCreated {
return nil
}
if resp.StatusCode == http.StatusOK {
return ErrBlobExists
}
return err
}
func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true
@ -379,11 +432,12 @@ func errFromUnknownKey(unknownKeyErr error) error {
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
publicKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
@ -1072,7 +1126,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
if err := initializeKeypair(); err != nil {
if _, err := auth.GetPublicKey(); err != nil {
return err
}
@ -1089,52 +1143,6 @@ func RunServer(cmd *cobra.Command, _ []string) error {
return err
}
func initializeKeypair() error {
home, err := os.UserHomeDir()
if err != nil {
return err
}
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
_, err = os.Stat(privKeyPath)
if os.IsNotExist(err) {
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return err
}
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
return fmt.Errorf("could not create directory %w", err)
}
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
return err
}
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
if err != nil {
return err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
return err
}
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
}
return nil
}
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {

23
server/copy_darwin.go Normal file
View File

@ -0,0 +1,23 @@
package server
import (
"os"
"path/filepath"
"golang.org/x/sys/unix"
)
func localCopy(src, target string) error {
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
err := unix.Clonefile(src, target, 0)
if err != nil {
return err
}
return nil
}

7
server/copy_linux.go Normal file
View File

@ -0,0 +1,7 @@
package server
import "errors"
func localCopy(src, target string) error {
return errors.New("no local copy implementation for linux")
}

67
server/copy_windows.go Normal file
View File

@ -0,0 +1,67 @@
//go:build windows
// +build windows
package server
import (
"os"
"path/filepath"
"syscall"
"unsafe"
)
func localCopy(src, target string) error {
// Create target directory if it doesn't exist
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Open source file
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
// Create target file
targetFile, err := os.Create(target)
if err != nil {
return err
}
defer targetFile.Close()
// Use CopyFileExW to copy the file
err = copyFileEx(src, target)
if err != nil {
return err
}
return nil
}
func copyFileEx(src, dst string) error {
kernel32 := syscall.NewLazyDLL("kernel32.dll")
copyFileEx := kernel32.NewProc("CopyFileExW")
srcPtr, err := syscall.UTF16PtrFromString(src)
if err != nil {
return err
}
dstPtr, err := syscall.UTF16PtrFromString(dst)
if err != nil {
return err
}
r1, _, err := copyFileEx.Call(
uintptr(unsafe.Pointer(srcPtr)),
uintptr(unsafe.Pointer(dstPtr)),
0, 0, 0, 0)
if r1 == 0 {
return err
}
return nil
}

View File

@ -32,6 +32,7 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
"golang.org/x/crypto/ssh"
)
var (
@ -1088,11 +1089,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey()
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized
}
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
}
// user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized

View File

@ -4,6 +4,7 @@ import (
"bytes"
"cmp"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -23,8 +24,10 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
@ -928,7 +931,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err = os.Stat(path)
switch {
case errors.Is(err, os.ErrNotExist):
@ -941,6 +943,14 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return
}
if c.GetHeader("X-Ollama-File") != "" && s.isLocal(c) {
err = localBlobCopy(c.GetHeader("X-Ollama-File"), path)
if err == nil {
c.Status(http.StatusCreated)
return
}
}
layer, err := NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -955,6 +965,108 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusCreated)
}
func localBlobCopy (src, dest string) error {
_, err := os.Stat(src)
if err != nil {
return err
}
err = localCopy(src, dest)
if err == nil {
return nil
}
err = defaultCopy(src, dest)
if err == nil {
return nil
}
return fmt.Errorf("failed to copy blob")
}
func (s *Server) isLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
if len(parts) != 3 {
return false
}
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
if err != nil {
return false
}
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
requestData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return false
}
partialRequestDataParts := strings.Split(string(requestData), ",")
if len(partialRequestDataParts) != 3 {
return false
}
signature, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
return false
}
serverPublicKey, err := auth.GetPublicKey()
if err != nil {
slog.Error(fmt.Sprintf("failed to get server public key: %v", err))
return false
}
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
return true
}
return false
}
return false
}
func defaultCopy(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Copy blob over
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("could not open source file: %v", err)
}
defer sourceFile.Close()
destFile, err := os.Create(dest)
if err != nil {
return fmt.Errorf("could not create destination file: %v", err)
}
defer destFile.Close()
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
if err != nil {
return fmt.Errorf("error copying file: %v", err)
}
err = destFile.Sync()
if err != nil {
return fmt.Errorf("error flushing file: %v", err)
}
return nil
}
func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces {

View File

@ -10,15 +10,18 @@ import (
"math"
"net/http"
"net/http/httptest"
"net/url"
"os"
"sort"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
@ -527,3 +530,64 @@ func TestNormalize(t *testing.T) {
})
}
}
func TestIsLocalReal(t *testing.T) {
gin.SetMode(gin.TestMode)
clientPubLoc := t.TempDir()
t.Setenv("HOME", clientPubLoc)
t.Setenv("USERPROFILE", clientPubLoc)
_, err := auth.GetPublicKey()
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(w)
ctx.Request = &http.Request{
Header: make(http.Header),
}
requestURL := url.URL{
Scheme: "http",
Host: "localhost:8080",
Path: "/api/blobs",
}
request := &http.Request{
Method: http.MethodPost,
URL: &requestURL,
}
s := &Server{}
authz, err := api.Authorization(ctx, request)
if err != nil {
t.Fatal(err)
}
// Set client authorization header
ctx.Request.Header.Set("Authorization", authz)
if !s.isLocal(ctx) {
t.Fatal("Expected isLocal to return true")
}
t.Run("different server pubkey", func(t *testing.T) {
serverPubLoc := t.TempDir()
t.Setenv("HOME", serverPubLoc)
t.Setenv("USERPROFILE", serverPubLoc)
_, err := auth.GetPublicKey()
if err != nil {
t.Fatal(err)
}
if s.isLocal(ctx) {
t.Fatal("Expected isLocal to return false")
}
})
t.Run("invalid pubkey", func(t *testing.T) {
ctx.Request.Header.Set("Authorization", "sha-25616:invalid")
if s.isLocal(ctx) {
t.Fatal("Expected isLocal to return false")
}
})
}