diff --git a/envconfig/config.go b/envconfig/config.go index 74d4102d..bcc28bb6 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -57,24 +57,6 @@ func Host() *url.URL { } } -func BasicAuth() (basicAuth map[string]string) { - basicAuthString := Var("OLLAMA_BASIC_AUTH") - parts := strings.Split(basicAuthString, ":") - - if len(parts) < 2 || parts[0] == "" || parts[1] == "" { - // Return a default value if the split result is not as expected - basicAuth = map[string]string{ - "username": "password", - } - } else { - basicAuth = map[string]string{ - parts[0]: parts[1], - } - } - - return basicAuth -} - // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. func Origins() (origins []string) { if s := Var("OLLAMA_ORIGINS"); s != "" { @@ -179,6 +161,8 @@ var ( RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES") GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL") HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") + + BasicAuthKey = String("OLLAMA_BASIC_AUTH_KEY") ) func RunnersDir() (p string) { @@ -276,7 +260,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"}, - "OLLAMA_BASIC_AUTH": {"OLLAMA_BASIC_AUTH", BasicAuth(), "Basic auth (default username:password, will allow all request if empty)"}, + "OLLAMA_BASIC_AUTH_KEY": {"OLLAMA_BASIC_AUTH_KEY", BasicAuthKey(), "Basic auth key for user ollama, will allow all request if empty"}, } if runtime.GOOS != "darwin" { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 2a610e78..92a500f1 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -233,29 +233,3 @@ func TestVar(t *testing.T) { }) } } - -func TestBasicAuth(t *testing.T) { - cases := map[string]struct { - value string - expectUser string - expectPassword string - }{ - "empty": {"", "username", "password"}, - "valid": {"user1:password1", "user1", "password1"}, - "missingPassword": {"user2:", "username", "password"}, - "missingUser": {":password2", "username", "password"}, - "noColon": {"user3password3", "username", "password"}, - } - - for name, tt := range cases { - t.Run(name, func(t *testing.T) { - t.Setenv("OLLAMA_BASIC_AUTH", tt.value) - basicAuth := BasicAuth() - - password, exists := basicAuth[tt.expectUser] - if !exists || password != tt.expectPassword { - t.Errorf("%s: expected map[%s:%s], got %v", name, tt.expectUser, tt.expectPassword, basicAuth) - } - }) - } -} diff --git a/server/routes.go b/server/routes.go index 9fe5c30b..ec5f7202 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1064,12 +1064,13 @@ func (s *Server) GenerateRoutes() http.Handler { allowedHostsMiddleware(s.addr), ) - auth := envconfig.BasicAuth() - // Check if the auth map contains the default credentials - isDefaultAuth := len(auth) == 1 && auth["username"] == "password" - - if !isDefaultAuth { - r.Use(gin.BasicAuth(auth)) + ollamaAuthKey := envconfig.BasicAuthKey() + if ollamaAuthKey != "" { + r.Use( + gin.BasicAuth(gin.Accounts{ + "ollama": ollamaAuthKey, + }), + ) } r.POST("/api/pull", s.PullModelHandler) diff --git a/server/routes_test.go b/server/routes_test.go index e33002a5..87c6fd5d 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -345,28 +345,43 @@ func Test_Routes(t *testing.T) { { Name: "Authenticated Route Success", Method: http.MethodGet, - Path: "/api/protected", // replace with your protected route + Path: "/api/version", Setup: func(t *testing.T, req *http.Request) { - req.SetBasicAuth("username", "password") + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "password") + req.SetBasicAuth("ollama", "password") }, Expected: func(t *testing.T, resp *http.Response) { - assert.Equal(t, http.StatusOK, resp.StatusCode) // replace with your expected status code + assert.Equal(t, http.StatusOK, resp.StatusCode) _, err := io.ReadAll(resp.Body) require.NoError(t, err) - // Add more assertions based on the expected response body }, }, { Name: "Authenticated Route Failure", Method: http.MethodGet, - Path: "/api/protected", // replace with your protected route + Path: "/api/version", Setup: func(t *testing.T, req *http.Request) { - req.SetBasicAuth("wronguser", "wrongpass") + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "password") + req.SetBasicAuth("ollama", "wrongpassword") }, Expected: func(t *testing.T, resp *http.Response) { assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) }, }, + { + Name: "BasicAuthKey not set", + Method: http.MethodGet, + Path: "/api/version", + Setup: func(t *testing.T, req *http.Request) { + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "") + req.SetBasicAuth("ollama", "wrongpassword") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir())