diff --git a/envconfig/config.go b/envconfig/config.go index e80c67ba..5b31ef5f 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -180,6 +180,8 @@ var ( RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES") GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL") HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") + + BasicAuthKey = String("OLLAMA_BASIC_AUTH_KEY") ) func Uint(key string, defaultValue uint) func() uint { @@ -249,6 +251,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, + "OLLAMA_BASIC_AUTH_KEY": {"OLLAMA_BASIC_AUTH_KEY", BasicAuthKey(), "Basic auth key for user ollama, will allow all request if empty"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/server/routes.go b/server/routes.go index c5fd3293..1416a5f6 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1145,6 +1145,15 @@ func (s *Server) GenerateRoutes() http.Handler { allowedHostsMiddleware(s.addr), ) + ollamaAuthKey := envconfig.BasicAuthKey() + if ollamaAuthKey != "" { + r.Use( + gin.BasicAuth(gin.Accounts{ + "ollama": ollamaAuthKey, + }), + ) + } + r.POST("/api/pull", s.PullHandler) r.POST("/api/generate", s.GenerateHandler) r.POST("/api/chat", s.ChatHandler) diff --git a/server/routes_test.go b/server/routes_test.go index bd5b56af..309bbeba 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -438,6 +438,46 @@ func Test_Routes(t *testing.T) { } }, }, + { + Name: "Authenticated Route Success", + Method: http.MethodGet, + Path: "/api/version", + Setup: func(t *testing.T, req *http.Request) { + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "password") + req.SetBasicAuth("ollama", "password") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) + }, + }, + { + Name: "Authenticated Route Failure", + Method: http.MethodGet, + Path: "/api/version", + Setup: func(t *testing.T, req *http.Request) { + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "password") + req.SetBasicAuth("ollama", "wrongpassword") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }, + }, + { + Name: "BasicAuthKey not set", + Method: http.MethodGet, + Path: "/api/version", + Setup: func(t *testing.T, req *http.Request) { + t.Setenv("OLLAMA_BASIC_AUTH_KEY", "") + req.SetBasicAuth("ollama", "wrongpassword") + }, + Expected: func(t *testing.T, resp *http.Response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + _, err := io.ReadAll(resp.Body) + require.NoError(t, err) + }, + }, } t.Setenv("OLLAMA_MODELS", t.TempDir())