feat: add gin BasicAuth for username:password setup in env

This commit is contained in:
kemalelmizan 2024-08-07 12:04:11 +07:00
parent de4fc29773
commit 1283f413f8
4 changed files with 78 additions and 0 deletions

View File

@ -57,6 +57,24 @@ func Host() *url.URL {
}
}
func BasicAuth() (basicAuth map[string]string) {
basicAuthString := Var("OLLAMA_BASIC_AUTH")
parts := strings.Split(basicAuthString, ":")
if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
// Return a default value if the split result is not as expected
basicAuth = map[string]string{
"username": "password",
}
} else {
basicAuth = map[string]string{
parts[0]: parts[1],
}
}
return basicAuth
}
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
func Origins() (origins []string) {
if s := Var("OLLAMA_ORIGINS"); s != "" {
@ -258,6 +276,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"},
"OLLAMA_BASIC_AUTH": {"OLLAMA_BASIC_AUTH", BasicAuth(), "Basic auth (default username:password, will allow all request if empty)"},
}
if runtime.GOOS != "darwin" {
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}

View File

@ -233,3 +233,29 @@ func TestVar(t *testing.T) {
})
}
}
func TestBasicAuth(t *testing.T) {
cases := map[string]struct {
value string
expectUser string
expectPassword string
}{
"empty": {"", "username", "password"},
"valid": {"user1:password1", "user1", "password1"},
"missingPassword": {"user2:", "username", "password"},
"missingUser": {":password2", "username", "password"},
"noColon": {"user3password3", "username", "password"},
}
for name, tt := range cases {
t.Run(name, func(t *testing.T) {
t.Setenv("OLLAMA_BASIC_AUTH", tt.value)
basicAuth := BasicAuth()
password, exists := basicAuth[tt.expectUser]
if !exists || password != tt.expectPassword {
t.Errorf("%s: expected map[%s:%s], got %v", name, tt.expectUser, tt.expectPassword, basicAuth)
}
})
}
}

View File

@ -1064,6 +1064,14 @@ func (s *Server) GenerateRoutes() http.Handler {
allowedHostsMiddleware(s.addr),
)
auth := envconfig.BasicAuth()
// Check if the auth map contains the default credentials
isDefaultAuth := len(auth) == 1 && auth["username"] == "password"
if !isDefaultAuth {
r.Use(gin.BasicAuth(auth))
}
r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler)

View File

@ -342,6 +342,31 @@ func Test_Routes(t *testing.T) {
}
},
},
{
Name: "Authenticated Route Success",
Method: http.MethodGet,
Path: "/api/protected", // replace with your protected route
Setup: func(t *testing.T, req *http.Request) {
req.SetBasicAuth("username", "password")
},
Expected: func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusOK, resp.StatusCode) // replace with your expected status code
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
// Add more assertions based on the expected response body
},
},
{
Name: "Authenticated Route Failure",
Method: http.MethodGet,
Path: "/api/protected", // replace with your protected route
Setup: func(t *testing.T, req *http.Request) {
req.SetBasicAuth("wronguser", "wrongpass")
},
Expected: func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
},
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())