diff --git a/envconfig/config.go b/envconfig/config.go index b82b773d..74d4102d 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -57,6 +57,24 @@ func Host() *url.URL { } } +func BasicAuth() (basicAuth map[string]string) { + basicAuthString := Var("OLLAMA_BASIC_AUTH") + parts := strings.Split(basicAuthString, ":") + + if len(parts) < 2 || parts[0] == "" || parts[1] == "" { + // Return a default value if the split result is not as expected + basicAuth = map[string]string{ + "username": "password", + } + } else { + basicAuth = map[string]string{ + parts[0]: parts[1], + } + } + + return basicAuth +} + // Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable. func Origins() (origins []string) { if s := Var("OLLAMA_ORIGINS"); s != "" { @@ -258,6 +276,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"}, + "OLLAMA_BASIC_AUTH": {"OLLAMA_BASIC_AUTH", BasicAuth(), "Basic auth (default username:password, will allow all request if empty)"}, } if runtime.GOOS != "darwin" { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} diff --git a/envconfig/config_test.go b/envconfig/config_test.go index 92a500f1..2a610e78 100644 --- a/envconfig/config_test.go +++ b/envconfig/config_test.go @@ -233,3 +233,29 @@ func TestVar(t *testing.T) { }) } } + +func TestBasicAuth(t *testing.T) { + cases := map[string]struct { + value string + expectUser string + expectPassword string + }{ + "empty": {"", "username", "password"}, + "valid": {"user1:password1", "user1", "password1"}, + "missingPassword": {"user2:", "username", "password"}, + "missingUser": {":password2", "username", "password"}, + "noColon": {"user3password3", "username", "password"}, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + t.Setenv("OLLAMA_BASIC_AUTH", tt.value) + basicAuth := BasicAuth() + + password, exists := basicAuth[tt.expectUser] + if !exists || password != tt.expectPassword { + t.Errorf("%s: expected map[%s:%s], got %v", name, tt.expectUser, tt.expectPassword, basicAuth) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index b9c66b65..9fe5c30b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1064,6 +1064,14 @@ func (s *Server) GenerateRoutes() http.Handler { allowedHostsMiddleware(s.addr), ) + auth := envconfig.BasicAuth() + // Check if the auth map contains the default credentials + isDefaultAuth := len(auth) == 1 && auth["username"] == "password" + + if !isDefaultAuth { + r.Use(gin.BasicAuth(auth)) + } + r.POST("/api/pull", s.PullModelHandler) r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) diff --git a/server/routes_test.go b/server/routes_test.go index ef7248ef..e33002a5 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -342,6 +342,31 @@ func Test_Routes(t *testing.T) { } }, }, + { + Name: "Authenticated Route Success", + Method: http.MethodGet, + Path: "/api/protected", // replace with your protected route + Setup: func(t *testing.T, req *http.Request) { + req.SetBasicAuth("username", "password") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) // replace with your expected status code + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) + // Add more assertions based on the expected response body + }, + }, + { + Name: "Authenticated Route Failure", + Method: http.MethodGet, + Path: "/api/protected", // replace with your protected route + Setup: func(t *testing.T, req *http.Request) { + req.SetBasicAuth("wronguser", "wrongpass") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir())