From 1283f413f824a43b33c7c08e85320a6463fab1bb Mon Sep 17 00:00:00 2001 From: kemalelmizan Date: Wed, 7 Aug 2024 12:04:11 +0700 Subject: [PATCH 1/2] feat: add gin BasicAuth for username:password setup in env --- envconfig/config.go | 19 +++++++++++++++++++ envconfig/config_test.go | 26 ++++++++++++++++++++++++++ server/routes.go | 8 ++++++++ server/routes_test.go | 25 +++++++++++++++++++++++++ 4 files changed, 78 insertions(+) diff --git a/envconfig/config.go b/envconfig/config.go index b82b773d..74d4102d 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -57,6 +57,24 @@ func Host() *url.URL { } } +func BasicAuth() (basicAuth map[string]string) { + basicAuthString := Var("OLLAMA_BASIC_AUTH") + parts := strings.Split(basicAuthString, ":") + + if len(parts) < 2 || parts[0] == "" || parts[1] == "" { + // Return a default value if the split result is not as expected + basicAuth = map[string]string{ + "username": "password", + } + } else { + basicAuth = map[string]string{ + parts[0]: parts[1], + } + } + + return basicAuth +} + // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. func Origins() (origins []string) { if s := Var("OLLAMA_ORIGINS"); s != "" { @@ -258,6 +276,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"}, + "OLLAMA_BASIC_AUTH": {"OLLAMA_BASIC_AUTH", BasicAuth(), "Basic auth (default username:password, will allow all request if empty)"}, } if runtime.GOOS != "darwin" { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 92a500f1..2a610e78 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -233,3 +233,29 @@ func TestVar(t *testing.T) { }) } } + +func TestBasicAuth(t *testing.T) { + cases := map[string]struct { + value string + expectUser string + expectPassword string + }{ + "empty": {"", "username", "password"}, + "valid": {"user1:password1", "user1", "password1"}, + "missingPassword": {"user2:", "username", "password"}, + "missingUser": {":password2", "username", "password"}, + "noColon": {"user3password3", "username", "password"}, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + t.Setenv("OLLAMA_BASIC_AUTH", tt.value) + basicAuth := BasicAuth() + + password, exists := basicAuth[tt.expectUser] + if !exists || password != tt.expectPassword { + t.Errorf("%s: expected map[%s:%s], got %v", name, tt.expectUser, tt.expectPassword, basicAuth) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index b9c66b65..9fe5c30b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1064,6 +1064,14 @@ func (s *Server) GenerateRoutes() http.Handler { allowedHostsMiddleware(s.addr), ) + auth := envconfig.BasicAuth() + // Check if the auth map contains the default credentials + isDefaultAuth := len(auth) == 1 && auth["username"] == "password" + + if !isDefaultAuth { + r.Use(gin.BasicAuth(auth)) + } + r.POST("/api/pull", s.PullModelHandler) r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) diff --git a/server/routes_test.go b/server/routes_test.go index ef7248ef..e33002a5 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -342,6 +342,31 @@ func Test_Routes(t *testing.T) { } }, }, + { + Name: "Authenticated Route Success", + Method: http.MethodGet, + Path: "/api/protected", // replace with your protected route + Setup: func(t *testing.T, req *http.Request) { + req.SetBasicAuth("username", "password") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) // replace with your expected status code + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) + // Add more assertions based on the expected response body + }, + }, + { + Name: "Authenticated Route Failure", + Method: http.MethodGet, + Path: "/api/protected", // replace with your protected route + Setup: func(t *testing.T, req *http.Request) { + req.SetBasicAuth("wronguser", "wrongpass") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir()) From 32e0dca416ce9ef1ec1067070bcb299bd24dacd3 Mon Sep 17 00:00:00 2001 From: kemalelmizan Date: Wed, 7 Aug 2024 12:42:34 +0700 Subject: [PATCH 2/2] change approach, set BasicAuthKey to only password instead of username:password pair, with ollama as user gin account --- envconfig/config.go | 22 +++------------------- envconfig/config_test.go | 26 -------------------------- server/routes.go | 13 +++++++------ server/routes_test.go | 27 +++++++++++++++++++++------ 4 files changed, 31 insertions(+), 57 deletions(-) diff --git a/envconfig/config.go b/envconfig/config.go index 74d4102d..bcc28bb6 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -57,24 +57,6 @@ func Host() *url.URL { } } -func BasicAuth() (basicAuth map[string]string) { - basicAuthString := Var("OLLAMA_BASIC_AUTH") - parts := strings.Split(basicAuthString, ":") - - if len(parts) < 2 || parts[0] == "" || parts[1] == "" { - // Return a default value if the split result is not as expected - basicAuth = map[string]string{ - "username": "password", - } - } else { - basicAuth = map[string]string{ - parts[0]: parts[1], - } - } - - return basicAuth -} - // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. func Origins() (origins []string) { if s := Var("OLLAMA_ORIGINS"); s != "" { @@ -179,6 +161,8 @@ var ( RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES") GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL") HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") + + BasicAuthKey = String("OLLAMA_BASIC_AUTH_KEY") ) func RunnersDir() (p string) { @@ -276,7 +260,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"}, - "OLLAMA_BASIC_AUTH": {"OLLAMA_BASIC_AUTH", BasicAuth(), "Basic auth (default username:password, will allow all request if empty)"}, + "OLLAMA_BASIC_AUTH_KEY": {"OLLAMA_BASIC_AUTH_KEY", BasicAuthKey(), "Basic auth key for user ollama, will allow all request if empty"}, } if runtime.GOOS != "darwin" { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 2a610e78..92a500f1 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -233,29 +233,3 @@ func TestVar(t *testing.T) { }) } } - -func TestBasicAuth(t *testing.T) { - cases := map[string]struct { - value string - expectUser string - expectPassword string - }{ - "empty": {"", "username", "password"}, - "valid": {"user1:password1", "user1", "password1"}, - "missingPassword": {"user2:", "username", "password"}, - "missingUser": {":password2", "username", "password"}, - "noColon": {"user3password3", "username", "password"}, - } - - for name, tt := range cases { - t.Run(name, func(t *testing.T) { - t.Setenv("OLLAMA_BASIC_AUTH", tt.value) - basicAuth := BasicAuth() - - password, exists := basicAuth[tt.expectUser] - if !exists || password != tt.expectPassword { - t.Errorf("%s: expected map[%s:%s], got %v", name, tt.expectUser, tt.expectPassword, basicAuth) - } - }) - } -} diff --git a/server/routes.go b/server/routes.go index 9fe5c30b..ec5f7202 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1064,12 +1064,13 @@ func (s *Server) GenerateRoutes() http.Handler { allowedHostsMiddleware(s.addr), ) - auth := envconfig.BasicAuth() - // Check if the auth map contains the default credentials - isDefaultAuth := len(auth) == 1 && auth["username"] == "password" - - if !isDefaultAuth { - r.Use(gin.BasicAuth(auth)) + ollamaAuthKey := envconfig.BasicAuthKey() + if ollamaAuthKey != "" { + r.Use( + gin.BasicAuth(gin.Accounts{ + "ollama": ollamaAuthKey, + }), + ) } r.POST("/api/pull", s.PullModelHandler) diff --git a/server/routes_test.go b/server/routes_test.go index e33002a5..87c6fd5d 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -345,28 +345,43 @@ func Test_Routes(t *testing.T) { { Name: "Authenticated Route Success", Method: http.MethodGet, - Path: "/api/protected", // replace with your protected route + Path: "/api/version", Setup: func(t *testing.T, req *http.Request) { - req.SetBasicAuth("username", "password") + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "password") + req.SetBasicAuth("ollama", "password") }, Expected: func(t *testing.T, resp *http.Response) { - assert.Equal(t, http.StatusOK, resp.StatusCode) // replace with your expected status code + assert.Equal(t, http.StatusOK, resp.StatusCode) _, err := io.ReadAll(resp.Body) require.NoError(t, err) - // Add more assertions based on the expected response body }, }, { Name: "Authenticated Route Failure", Method: http.MethodGet, - Path: "/api/protected", // replace with your protected route + Path: "/api/version", Setup: func(t *testing.T, req *http.Request) { - req.SetBasicAuth("wronguser", "wrongpass") + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "password") + req.SetBasicAuth("ollama", "wrongpassword") }, Expected: func(t *testing.T, resp *http.Response) { assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) }, }, + { + Name: "BasicAuthKey not set", + Method: http.MethodGet, + Path: "/api/version", + Setup: func(t *testing.T, req *http.Request) { + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "") + req.SetBasicAuth("ollama", "wrongpassword") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir())