feat: add gin BasicAuth for username:password setup in env
This commit is contained in:
parent
de4fc29773
commit
1283f413f8
@ -57,6 +57,24 @@ func Host() *url.URL {
|
||||
}
|
||||
}
|
||||
|
||||
func BasicAuth() (basicAuth map[string]string) {
|
||||
basicAuthString := Var("OLLAMA_BASIC_AUTH")
|
||||
parts := strings.Split(basicAuthString, ":")
|
||||
|
||||
if len(parts) < 2 || parts[0] == "" || parts[1] == "" {
|
||||
// Return a default value if the split result is not as expected
|
||||
basicAuth = map[string]string{
|
||||
"username": "password",
|
||||
}
|
||||
} else {
|
||||
basicAuth = map[string]string{
|
||||
parts[0]: parts[1],
|
||||
}
|
||||
}
|
||||
|
||||
return basicAuth
|
||||
}
|
||||
|
||||
// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
|
||||
func Origins() (origins []string) {
|
||||
if s := Var("OLLAMA_ORIGINS"); s != "" {
|
||||
@ -258,6 +276,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir(), "Location for runners"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"},
|
||||
"OLLAMA_BASIC_AUTH": {"OLLAMA_BASIC_AUTH", BasicAuth(), "Basic auth (default username:password, will allow all request if empty)"},
|
||||
}
|
||||
if runtime.GOOS != "darwin" {
|
||||
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}
|
||||
|
@ -233,3 +233,29 @@ func TestVar(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
cases := map[string]struct {
|
||||
value string
|
||||
expectUser string
|
||||
expectPassword string
|
||||
}{
|
||||
"empty": {"", "username", "password"},
|
||||
"valid": {"user1:password1", "user1", "password1"},
|
||||
"missingPassword": {"user2:", "username", "password"},
|
||||
"missingUser": {":password2", "username", "password"},
|
||||
"noColon": {"user3password3", "username", "password"},
|
||||
}
|
||||
|
||||
for name, tt := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_BASIC_AUTH", tt.value)
|
||||
basicAuth := BasicAuth()
|
||||
|
||||
password, exists := basicAuth[tt.expectUser]
|
||||
if !exists || password != tt.expectPassword {
|
||||
t.Errorf("%s: expected map[%s:%s], got %v", name, tt.expectUser, tt.expectPassword, basicAuth)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1064,6 +1064,14 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
allowedHostsMiddleware(s.addr),
|
||||
)
|
||||
|
||||
auth := envconfig.BasicAuth()
|
||||
// Check if the auth map contains the default credentials
|
||||
isDefaultAuth := len(auth) == 1 && auth["username"] == "password"
|
||||
|
||||
if !isDefaultAuth {
|
||||
r.Use(gin.BasicAuth(auth))
|
||||
}
|
||||
|
||||
r.POST("/api/pull", s.PullModelHandler)
|
||||
r.POST("/api/generate", s.GenerateHandler)
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
|
@ -342,6 +342,31 @@ func Test_Routes(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Authenticated Route Success",
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/protected", // replace with your protected route
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
req.SetBasicAuth("username", "password")
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode) // replace with your expected status code
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
// Add more assertions based on the expected response body
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Authenticated Route Failure",
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/protected", // replace with your protected route
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
req.SetBasicAuth("wronguser", "wrongpass")
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
Loading…
x
Reference in New Issue
Block a user