change approach, set BasicAuthKey to only password instead of username:password pair, with ollama as user gin account

This commit is contained in:
kemalelmizan 2024-08-07 12:42:34 +07:00
parent 1283f413f8
commit 32e0dca416
4 changed files with 31 additions and 57 deletions

View File

@ -57,24 +57,6 @@ func Host() *url.URL {
}
}
func BasicAuth() (basicAuth map[string]string) {
basicAuthString := Var("OLLAMA_BASIC_AUTH")
parts := strings.Split(basicAuthString, ":")
if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
// Return a default value if the split result is not as expected
basicAuth = map[string]string{
"username": "password",
}
} else {
basicAuth = map[string]string{
parts[0]: parts[1],
}
}
return basicAuth
}
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" {
@ -179,6 +161,8 @@ var (
RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES")
GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL")
HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION")
BasicAuthKey = String("OLLAMA_BASIC_AUTH_KEY")
)
func RunnersDir() (p string) {
@ -276,7 +260,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"},
"OLLAMA_BASIC_AUTH": {"OLLAMA_BASIC_AUTH", BasicAuth(), "Basic auth (default username:password, will allow all request if empty)"},
"OLLAMA_BASIC_AUTH_KEY": {"OLLAMA_BASIC_AUTH_KEY", BasicAuthKey(), "Basic auth key for user ollama, will allow all request if empty"},
}
if runtime.GOOS != "darwin" {
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}

View File

@ -233,29 +233,3 @@ func TestVar(t *testing.T) {
})
}
}
func TestBasicAuth(t *testing.T) {
cases := map[string]struct {
value string
expectUser string
expectPassword string
}{
"empty": {"", "username", "password"},
"valid": {"user1:password1", "user1", "password1"},
"missingPassword": {"user2:", "username", "password"},
"missingUser": {":password2", "username", "password"},
"noColon": {"user3password3", "username", "password"},
}
for name, tt := range cases {
t.Run(name, func(t *testing.T) {
t.Setenv("OLLAMA_BASIC_AUTH", tt.value)
basicAuth := BasicAuth()
password, exists := basicAuth[tt.expectUser]
if !exists || password != tt.expectPassword {
t.Errorf("%s: expected map[%s:%s], got %v", name, tt.expectUser, tt.expectPassword, basicAuth)
}
})
}
}

View File

@ -1064,12 +1064,13 @@ func (s *Server) GenerateRoutes() http.Handler {
allowedHostsMiddleware(s.addr),
)
auth := envconfig.BasicAuth()
// Check if the auth map contains the default credentials
isDefaultAuth := len(auth) == 1 && auth["username"] == "password"
if !isDefaultAuth {
r.Use(gin.BasicAuth(auth))
ollamaAuthKey := envconfig.BasicAuthKey()
if ollamaAuthKey != "" {
r.Use(
gin.BasicAuth(gin.Accounts{
"ollama": ollamaAuthKey,
}),
)
}
r.POST("/api/pull", s.PullModelHandler)

View File

@ -345,28 +345,43 @@ func Test_Routes(t *testing.T) {
{
Name: "Authenticated Route Success",
Method: http.MethodGet,
Path: "/api/protected", // replace with your protected route
Path: "/api/version",
Setup: func(t *testing.T, req *http.Request) {
req.SetBasicAuth("username", "password")
t.Setenv("OLLAMA_BASIC_AUTH_KEY", "password")
req.SetBasicAuth("ollama", "password")
},
Expected: func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusOK, resp.StatusCode) // replace with your expected status code
assert.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// Add more assertions based on the expected response body
},
},
{
Name: "Authenticated Route Failure",
Method: http.MethodGet,
Path: "/api/protected", // replace with your protected route
Path: "/api/version",
Setup: func(t *testing.T, req *http.Request) {
req.SetBasicAuth("wronguser", "wrongpass")
t.Setenv("OLLAMA_BASIC_AUTH_KEY", "password")
req.SetBasicAuth("ollama", "wrongpassword")
},
Expected: func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
},
},
{
Name: "BasicAuthKey not set",
Method: http.MethodGet,
Path: "/api/version",
Setup: func(t *testing.T, req *http.Request) {
t.Setenv("OLLAMA_BASIC_AUTH_KEY", "")
req.SetBasicAuth("ollama", "wrongpassword")
},
Expected: func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
},
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())