Compare commits
33 Commits
main
...
jyan/local
Author | SHA1 | Date | |
---|---|---|---|
|
f1b5d939f5 | ||
|
d1b7f8bb07 | ||
|
6d4724a06d | ||
|
c507325288 | ||
|
09431f353d | ||
|
8548d1d596 | ||
|
478b58dd77 | ||
|
24c5e172ca | ||
|
d12717e7dc | ||
|
a80d79536a | ||
|
4c1e188200 | ||
|
689a7cb90d | ||
|
93a8054693 | ||
|
7769602b75 | ||
|
8048ce0816 | ||
|
72314bf4b5 | ||
|
d4ab994ade | ||
|
c44f4825c4 | ||
|
154b59c0b6 | ||
|
8ee1ada22a | ||
|
e9a2ead87a | ||
|
a7721cb1d2 | ||
|
1a6197abb1 | ||
|
9fbd474bf7 | ||
|
7e8d8cc72f | ||
|
cbd98a2e37 | ||
|
ad36d4ff1b | ||
|
461c964941 | ||
|
a993a3a85c | ||
|
f7d64856d5 | ||
|
6b1b85ba3d | ||
|
5603441538 | ||
|
76b4dfcc9e |
@ -17,6 +17,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -24,7 +25,10 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@ -383,3 +387,16 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
|||||||
|
|
||||||
return version.Version, nil
|
return version.Version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Authorization(ctx context.Context, request *http.Request) (string, error) {
|
||||||
|
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
|
||||||
|
|
||||||
|
token, err := auth.Sign(ctx, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// interleave request data into the token
|
||||||
|
key, sig, _ := strings.Cut(token, ":")
|
||||||
|
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
|
||||||
|
}
|
||||||
|
119
auth/auth.go
119
auth/auth.go
@ -3,49 +3,67 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
func keyPath() (string, error) {
|
func privateKey() (ssh.Signer, error) {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||||
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
err := initializeKeypair()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKey()
|
||||||
|
} else if err != nil {
|
||||||
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.ParsePrivateKey(privateKeyFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPublicKey() (string, error) {
|
func GetPublicKey() (ssh.PublicKey, error) {
|
||||||
keyPath, err := keyPath()
|
// try to read pubkey first
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
|
||||||
|
pubKeyFile, err := os.ReadFile(pubkeyPath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
// try from privateKey
|
||||||
|
privateKey, err := privateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
return privateKey.PublicKey(), nil
|
||||||
if err != nil {
|
} else if err != nil {
|
||||||
return "", err
|
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile)
|
||||||
|
return pubKey, err
|
||||||
return strings.TrimSpace(string(publicKey)), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNonce(r io.Reader, length int) (string, error) {
|
func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
@ -58,25 +76,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
keyPath, err := keyPath()
|
privateKey, err := privateKey()
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the pubkey, but remove the type
|
// get the pubkey, but remove the type
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
publicKey, err := GetPublicKey()
|
||||||
parts := bytes.Split(publicKey, []byte(" "))
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
||||||
|
|
||||||
|
parts := bytes.Split(publicKeyBytes, []byte(" "))
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return "", fmt.Errorf("malformed public key")
|
return "", fmt.Errorf("malformed public key")
|
||||||
}
|
}
|
||||||
@ -89,3 +102,49 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
|||||||
// signature is <pubkey>:<signature>
|
// signature is <pubkey>:<signature>
|
||||||
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initializeKeypair() error {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
||||||
|
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
||||||
|
|
||||||
|
_, err = os.Stat(privKeyPath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
||||||
|
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("could not create directory %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
||||||
|
|
||||||
|
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
166
cmd/cmd.go
166
cmd/cmd.go
@ -4,10 +4,8 @@ import (
|
|||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/pem"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -15,6 +13,7 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -263,6 +262,8 @@ func tempZipFiles(path string) (string, error) {
|
|||||||
return tempfile.Name(), nil
|
return tempfile.Name(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrBlobExists = errors.New("blob exists")
|
||||||
|
|
||||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
||||||
bin, err := os.Open(path)
|
bin, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -280,12 +281,120 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||||
|
|
||||||
|
// We check if we can find the models directory locally
|
||||||
|
// If we can, we return the path to the directory
|
||||||
|
// If we can't, we return an error
|
||||||
|
// If the blob exists already, we return the digest
|
||||||
|
dest, err := getLocalPath(cmd.Context(), digest)
|
||||||
|
|
||||||
|
if errors.Is(err, ErrBlobExists) {
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successfully found the model directory
|
||||||
|
if err == nil {
|
||||||
|
// Copy blob in via OS specific copy
|
||||||
|
// Linux errors out to use io.copy
|
||||||
|
err = localCopy(path, dest)
|
||||||
|
if err == nil {
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default copy using io.copy
|
||||||
|
err = defaultCopy(path, dest)
|
||||||
|
if err == nil {
|
||||||
|
return digest, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If at any point copying the blob over locally fails, we default to the copy through the server
|
||||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return digest, nil
|
return digest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLocalPath(ctx context.Context, digest string) (string, error) {
|
||||||
|
ollamaHost := envconfig.Host
|
||||||
|
|
||||||
|
client := http.DefaultClient
|
||||||
|
base := &url.URL{
|
||||||
|
Scheme: ollamaHost.Scheme,
|
||||||
|
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(digest)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody := bytes.NewReader(data)
|
||||||
|
path := fmt.Sprintf("/api/blobs/%s", digest)
|
||||||
|
requestURL := base.JoinPath(path)
|
||||||
|
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
authz, err := api.Authorization(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
request.Header.Set("Authorization", authz)
|
||||||
|
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||||
|
request.Header.Set("X-Redirect-Create", "1")
|
||||||
|
|
||||||
|
resp, err := client.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusTemporaryRedirect {
|
||||||
|
dest := resp.Header.Get("LocalLocation")
|
||||||
|
|
||||||
|
return dest, nil
|
||||||
|
}
|
||||||
|
return "", ErrBlobExists
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultCopy(path string, dest string) error {
|
||||||
|
// This function should be called if the server is local
|
||||||
|
// It should find the model directory, copy the blob over, and return the digest
|
||||||
|
dirPath := filepath.Dir(dest)
|
||||||
|
|
||||||
|
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy blob over
|
||||||
|
sourceFile, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not open source file: %v", err)
|
||||||
|
}
|
||||||
|
defer sourceFile.Close()
|
||||||
|
|
||||||
|
destFile, err := os.Create(dest)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not create destination file: %v", err)
|
||||||
|
}
|
||||||
|
defer destFile.Close()
|
||||||
|
|
||||||
|
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error copying file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = destFile.Sync()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error flushing file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
@ -379,11 +488,12 @@ func errFromUnknownKey(unknownKeyErr error) error {
|
|||||||
if len(matches) > 0 {
|
if len(matches) > 0 {
|
||||||
serverPubKey := matches[0]
|
serverPubKey := matches[0]
|
||||||
|
|
||||||
localPubKey, err := auth.GetPublicKey()
|
publicKey, err := auth.GetPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return unknownKeyErr
|
return unknownKeyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
|
||||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||||
// try the ollama service public key
|
// try the ollama service public key
|
||||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||||
@ -1072,7 +1182,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||||
if err := initializeKeypair(); err != nil {
|
if _, err := auth.GetPublicKey(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1089,52 +1199,6 @@ func RunServer(cmd *cobra.Command, _ []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func initializeKeypair() error {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
|
||||||
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
|
||||||
|
|
||||||
_, err = os.Stat(privKeyPath)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
|
||||||
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
|
||||||
return fmt.Errorf("could not create directory %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
|
||||||
|
|
||||||
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
23
cmd/copy_darwin.go
Normal file
23
cmd/copy_darwin.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func localCopy(src, target string) error {
|
||||||
|
dirPath := filepath.Dir(target)
|
||||||
|
|
||||||
|
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := unix.Clonefile(src, target, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
7
cmd/copy_linux.go
Normal file
7
cmd/copy_linux.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
func localCopy(src, target string) error {
|
||||||
|
return errors.New("no local copy implementation for linux")
|
||||||
|
}
|
67
cmd/copy_windows.go
Normal file
67
cmd/copy_windows.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
func localCopy(src, target string) error {
|
||||||
|
// Create target directory if it doesn't exist
|
||||||
|
dirPath := filepath.Dir(target)
|
||||||
|
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open source file
|
||||||
|
sourceFile, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer sourceFile.Close()
|
||||||
|
|
||||||
|
// Create target file
|
||||||
|
targetFile, err := os.Create(target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer targetFile.Close()
|
||||||
|
|
||||||
|
// Use CopyFileExW to copy the file
|
||||||
|
err = copyFileEx(src, target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFileEx(src, dst string) error {
|
||||||
|
kernel32 := syscall.NewLazyDLL("kernel32.dll")
|
||||||
|
copyFileEx := kernel32.NewProc("CopyFileExW")
|
||||||
|
|
||||||
|
srcPtr, err := syscall.UTF16PtrFromString(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dstPtr, err := syscall.UTF16PtrFromString(dst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
r1, _, err := copyFileEx.Call(
|
||||||
|
uintptr(unsafe.Pointer(srcPtr)),
|
||||||
|
uintptr(unsafe.Pointer(dstPtr)),
|
||||||
|
0, 0, 0, 0)
|
||||||
|
|
||||||
|
if r1 == 0 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -1088,11 +1089,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|||||||
if anonymous {
|
if anonymous {
|
||||||
// no user is associated with the public key, and the request requires non-anonymous access
|
// no user is associated with the public key, and the request requires non-anonymous access
|
||||||
pubKey, nestedErr := auth.GetPublicKey()
|
pubKey, nestedErr := auth.GetPublicKey()
|
||||||
|
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
||||||
if nestedErr != nil {
|
if nestedErr != nil {
|
||||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
|
||||||
}
|
}
|
||||||
// user is associated with the public key, but is not authorized to make the request
|
// user is associated with the public key, but is not authorized to make the request
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -23,8 +24,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/gpu"
|
"github.com/ollama/ollama/gpu"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
@ -928,7 +931,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = os.Stat(path)
|
_, err = os.Stat(path)
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, os.ErrNotExist):
|
case errors.Is(err, os.ErrNotExist):
|
||||||
@ -940,6 +942,11 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) {
|
||||||
|
c.Header("LocalLocation", path)
|
||||||
|
c.Status(http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
layer, err := NewLayer(c.Request.Body, "")
|
layer, err := NewLayer(c.Request.Body, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -955,6 +962,54 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
|||||||
c.Status(http.StatusCreated)
|
c.Status(http.StatusCreated)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) isLocal(c *gin.Context) bool {
|
||||||
|
if authz := c.GetHeader("Authorization"); authz != "" {
|
||||||
|
parts := strings.Split(authz, ":")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
|
||||||
|
requestData, err := base64.StdEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
partialRequestDataParts := strings.Split(string(requestData), ",")
|
||||||
|
if len(partialRequestDataParts) != 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := base64.StdEncoding.DecodeString(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
serverPublicKey, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error(fmt.Sprintf("failed to get server public key: %v", err))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func isLocalIP(ip netip.Addr) bool {
|
func isLocalIP(ip netip.Addr) bool {
|
||||||
if interfaces, err := net.Interfaces(); err == nil {
|
if interfaces, err := net.Interfaces(); err == nil {
|
||||||
for _, iface := range interfaces {
|
for _, iface := range interfaces {
|
||||||
|
@ -10,15 +10,18 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
@ -527,3 +530,62 @@ func TestNormalize(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsLocalReal(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
clientPubLoc := t.TempDir()
|
||||||
|
t.Setenv("HOME", clientPubLoc)
|
||||||
|
|
||||||
|
_, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ctx, _ := gin.CreateTestContext(w)
|
||||||
|
ctx.Request = &http.Request{
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
|
||||||
|
requestURL := url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "localhost:8080",
|
||||||
|
Path: "/api/blobs",
|
||||||
|
}
|
||||||
|
request := &http.Request{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: &requestURL,
|
||||||
|
}
|
||||||
|
s := &Server{}
|
||||||
|
|
||||||
|
authz, err := api.Authorization(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set client authorization header
|
||||||
|
ctx.Request.Header.Set("Authorization", authz)
|
||||||
|
if !s.isLocal(ctx) {
|
||||||
|
t.Fatal("Expected isLocal to return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("different server pubkey", func(t *testing.T) {
|
||||||
|
serverPubLoc := t.TempDir()
|
||||||
|
t.Setenv("HOME", serverPubLoc)
|
||||||
|
_, err := auth.GetPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.isLocal(ctx) {
|
||||||
|
t.Fatal("Expected isLocal to return false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid pubkey", func(t *testing.T) {
|
||||||
|
ctx.Request.Header.Set("Authorization", "sha-25616:invalid")
|
||||||
|
if s.isLocal(ctx) {
|
||||||
|
t.Fatal("Expected isLocal to return false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user