Compare commits
4 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
0e01da82d6 | ||
|
6b1b85ba3d | ||
|
5603441538 | ||
|
76b4dfcc9e |
120
auth/auth.go
120
auth/auth.go
@ -3,49 +3,68 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultPrivateKey = "id_ed25519"
|
const defaultPrivateKey = "id_ed25519"
|
||||||
|
|
||||||
func keyPath() (string, error) {
|
func privateKey() (ssh.Signer, error) {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||||
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
err := initializeKeypair()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKey()
|
||||||
|
} else if err != nil {
|
||||||
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.ParsePrivateKey(privateKeyFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPublicKey() (string, error) {
|
func GetPublicKey() (ssh.PublicKey, error) {
|
||||||
keyPath, err := keyPath()
|
// try to read pubkey first
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
|
||||||
|
pubKeyFile, err := os.ReadFile(pubkeyPath)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
// try from privateKey
|
||||||
|
privateKey, err := privateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
return privateKey.PublicKey(), nil
|
||||||
if err != nil {
|
} else if err != nil {
|
||||||
return "", err
|
return nil, fmt.Errorf("failed to read public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile)
|
||||||
|
return pubKey, err
|
||||||
return strings.TrimSpace(string(publicKey)), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNonce(r io.Reader, length int) (string, error) {
|
func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
@ -58,25 +77,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
keyPath, err := keyPath()
|
privateKey, err := privateKey()
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyFile, err := os.ReadFile(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the pubkey, but remove the type
|
// get the pubkey, but remove the type
|
||||||
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
publicKey, err := GetPublicKey()
|
||||||
parts := bytes.Split(publicKey, []byte(" "))
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
|
||||||
|
|
||||||
|
parts := bytes.Split(publicKeyBytes, []byte(" "))
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return "", fmt.Errorf("malformed public key")
|
return "", fmt.Errorf("malformed public key")
|
||||||
}
|
}
|
||||||
@ -89,3 +103,49 @@ func Sign(ctx context.Context, bts []byte) (string, error) {
|
|||||||
// signature is <pubkey>:<signature>
|
// signature is <pubkey>:<signature>
|
||||||
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initializeKeypair() error {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
||||||
|
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
||||||
|
|
||||||
|
_, err = os.Stat(privKeyPath)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
||||||
|
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("could not create directory %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
||||||
|
|
||||||
|
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
54
cmd/cmd.go
54
cmd/cmd.go
@ -4,10 +4,7 @@ import (
|
|||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -379,11 +376,12 @@ func errFromUnknownKey(unknownKeyErr error) error {
|
|||||||
if len(matches) > 0 {
|
if len(matches) > 0 {
|
||||||
serverPubKey := matches[0]
|
serverPubKey := matches[0]
|
||||||
|
|
||||||
localPubKey, err := auth.GetPublicKey()
|
publicKey, err := auth.GetPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return unknownKeyErr
|
return unknownKeyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
|
||||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||||
// try the ollama service public key
|
// try the ollama service public key
|
||||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||||
@ -1072,7 +1070,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||||
if err := initializeKeypair(); err != nil {
|
if _, err := auth.GetPublicKey(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1089,52 +1087,6 @@ func RunServer(cmd *cobra.Command, _ []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func initializeKeypair() error {
|
|
||||||
home, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
|
||||||
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
|
||||||
|
|
||||||
_, err = os.Stat(privKeyPath)
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
|
||||||
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
privateKeyBytes, err := ssh.MarshalPrivateKey(cryptoPrivateKey, "")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
|
||||||
return fmt.Errorf("could not create directory %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privateKeyBytes), 0o600); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sshPublicKey, err := ssh.NewPublicKey(cryptoPublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
publicKeyBytes := ssh.MarshalAuthorizedKey(sshPublicKey)
|
|
||||||
|
|
||||||
if err := os.WriteFile(pubKeyPath, publicKeyBytes, 0o644); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Your new public key is: \n\n%s\n", publicKeyBytes)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -1088,11 +1089,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|||||||
if anonymous {
|
if anonymous {
|
||||||
// no user is associated with the public key, and the request requires non-anonymous access
|
// no user is associated with the public key, and the request requires non-anonymous access
|
||||||
pubKey, nestedErr := auth.GetPublicKey()
|
pubKey, nestedErr := auth.GetPublicKey()
|
||||||
|
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
|
||||||
if nestedErr != nil {
|
if nestedErr != nil {
|
||||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
|
||||||
}
|
}
|
||||||
// user is associated with the public key, but is not authorized to make the request
|
// user is associated with the public key, but is not authorized to make the request
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
|
Loading…
x
Reference in New Issue
Block a user