diff --git a/server/subsonic/album_lists_test.go b/server/subsonic/album_lists_test.go index 6a4d97c4b..8b5fa8342 100644 --- a/server/subsonic/album_lists_test.go +++ b/server/subsonic/album_lists_test.go @@ -40,7 +40,7 @@ var _ = Describe("AlbumListController", func() { Describe("GetAlbumList", func() { It("should return list of the type specified", func() { - r := newTestRequest("type=newest", "offset=10", "size=20") + r := newGetRequest("type=newest", "offset=10", "size=20") listGen.data = engine.Entries{ {Id: "1"}, {Id: "2"}, } @@ -54,7 +54,7 @@ var _ = Describe("AlbumListController", func() { }) It("should fail if missing type parameter", func() { - r := newTestRequest() + r := newGetRequest() _, err := controller.GetAlbumList(w, r) Expect(err).To(MatchError("Required string parameter 'type' is not present")) @@ -62,7 +62,7 @@ var _ = Describe("AlbumListController", func() { It("should return error if call fails", func() { listGen.err = errors.New("some issue") - r := newTestRequest("type=newest") + r := newGetRequest("type=newest") _, err := controller.GetAlbumList(w, r) @@ -72,7 +72,7 @@ var _ = Describe("AlbumListController", func() { Describe("GetAlbumList2", func() { It("should return list of the type specified", func() { - r := newTestRequest("type=newest", "offset=10", "size=20") + r := newGetRequest("type=newest", "offset=10", "size=20") listGen.data = engine.Entries{ {Id: "1"}, {Id: "2"}, } @@ -86,7 +86,7 @@ var _ = Describe("AlbumListController", func() { }) It("should fail if missing type parameter", func() { - r := newTestRequest() + r := newGetRequest() _, err := controller.GetAlbumList2(w, r) Expect(err).To(MatchError("Required string parameter 'type' is not present")) @@ -94,7 +94,7 @@ var _ = Describe("AlbumListController", func() { It("should return error if call fails", func() { listGen.err = errors.New("some issue") - r := newTestRequest("type=newest") + r := newGetRequest("type=newest") _, err := controller.GetAlbumList2(w, r) diff --git a/server/subsonic/api.go b/server/subsonic/api.go index fe1920eb9..b4a8e7db7 100644 --- a/server/subsonic/api.go +++ b/server/subsonic/api.go @@ -45,6 +45,7 @@ func (api *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (api *Router) routes() http.Handler { r := chi.NewRouter() + r.Use(postFormToQueryParams) r.Use(checkRequiredParameters) // Add validation middleware if not disabled diff --git a/server/subsonic/media_retrieval_test.go b/server/subsonic/media_retrieval_test.go index 5f14c542a..9f5d969ae 100644 --- a/server/subsonic/media_retrieval_test.go +++ b/server/subsonic/media_retrieval_test.go @@ -42,7 +42,7 @@ var _ = Describe("MediaRetrievalController", func() { Describe("GetCoverArt", func() { It("should return data for that id", func() { cover.data = "image data" - r := newTestRequest("id=34", "size=128") + r := newGetRequest("id=34", "size=128") _, err := controller.GetCoverArt(w, r) Expect(err).To(BeNil()) @@ -52,7 +52,7 @@ var _ = Describe("MediaRetrievalController", func() { }) It("should fail if missing id parameter", func() { - r := newTestRequest() + r := newGetRequest() _, err := controller.GetCoverArt(w, r) Expect(err).To(MatchError("id parameter required")) @@ -60,7 +60,7 @@ var _ = Describe("MediaRetrievalController", func() { It("should fail when the file is not found", func() { cover.err = model.ErrNotFound - r := newTestRequest("id=34", "size=128") + r := newGetRequest("id=34", "size=128") _, err := controller.GetCoverArt(w, r) Expect(err).To(MatchError("Cover not found")) @@ -68,7 +68,7 @@ var _ = Describe("MediaRetrievalController", func() { It("should fail when there is an unknown error", func() { cover.err = errors.New("weird error") - r := newTestRequest("id=34", "size=128") + r := newGetRequest("id=34", "size=128") _, err := controller.GetCoverArt(w, r) Expect(err).To(MatchError("Internal Error")) diff --git a/server/subsonic/middlewares.go b/server/subsonic/middlewares.go index f977c82b8..57c2d131f 100644 --- a/server/subsonic/middlewares.go +++ b/server/subsonic/middlewares.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net/http" + "net/url" + "strings" "github.com/deluan/navidrome/engine" "github.com/deluan/navidrome/log" @@ -11,6 +13,24 @@ import ( "github.com/deluan/navidrome/server/subsonic/responses" ) +func postFormToQueryParams(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + if err != nil { + SendError(w, r, NewError(responses.ErrorGeneric, err.Error())) + } + var parts []string + for key, values := range r.Form { + for _, v := range values { + parts = append(parts, url.QueryEscape(key)+"="+url.QueryEscape(v)) + } + } + r.URL.RawQuery = strings.Join(parts, "&") + + next.ServeHTTP(w, r) + }) +} + func checkRequiredParameters(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requiredParameters := []string{"u", "v", "c"} diff --git a/server/subsonic/middlewares_test.go b/server/subsonic/middlewares_test.go index c3cf48c1c..b3c92ae47 100644 --- a/server/subsonic/middlewares_test.go +++ b/server/subsonic/middlewares_test.go @@ -13,8 +13,18 @@ import ( . "github.com/onsi/gomega" ) -func newTestRequest(queryParams ...string) *http.Request { - r := httptest.NewRequest("get", "/ping?"+strings.Join(queryParams, "&"), nil) +func newGetRequest(queryParams ...string) *http.Request { + r := httptest.NewRequest("GET", "/ping?"+strings.Join(queryParams, "&"), nil) + ctx := r.Context() + return r.WithContext(log.NewContext(ctx)) +} + +func newPostRequest(queryParam string, formFields ...string) *http.Request { + r, err := http.NewRequest("POST", "/ping?"+queryParam, strings.NewReader(strings.Join(formFields, "&"))) + if err != nil { + panic(err) + } + r.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") ctx := r.Context() return r.WithContext(log.NewContext(ctx)) } @@ -28,9 +38,37 @@ var _ = Describe("Middlewares", func() { w = httptest.NewRecorder() }) + Describe("ParsePostForm", func() { + It("converts any filed in a x-www-form-urlencoded POST into query params", func() { + r := newPostRequest("a=abc", "u=user", "v=1.15", "c=test") + cp := postFormToQueryParams(next) + cp.ServeHTTP(w, r) + + Expect(next.req.URL.Query().Get("a")).To(Equal("abc")) + Expect(next.req.URL.Query().Get("u")).To(Equal("user")) + Expect(next.req.URL.Query().Get("v")).To(Equal("1.15")) + Expect(next.req.URL.Query().Get("c")).To(Equal("test")) + }) + It("adds repeated params", func() { + r := newPostRequest("a=abc", "id=1", "id=2") + cp := postFormToQueryParams(next) + cp.ServeHTTP(w, r) + + Expect(next.req.URL.Query().Get("a")).To(Equal("abc")) + Expect(next.req.URL.Query()["id"]).To(ConsistOf("1", "2")) + }) + It("overrides query params with same key", func() { + r := newPostRequest("a=query", "a=body") + cp := postFormToQueryParams(next) + cp.ServeHTTP(w, r) + + Expect(next.req.URL.Query().Get("a")).To(Equal("body")) + }) + }) + Describe("CheckParams", func() { It("passes when all required params are available", func() { - r := newTestRequest("u=user", "v=1.15", "c=test") + r := newGetRequest("u=user", "v=1.15", "c=test") cp := checkRequiredParameters(next) cp.ServeHTTP(w, r) @@ -41,7 +79,7 @@ var _ = Describe("Middlewares", func() { }) It("fails when user is missing", func() { - r := newTestRequest("v=1.15", "c=test") + r := newGetRequest("v=1.15", "c=test") cp := checkRequiredParameters(next) cp.ServeHTTP(w, r) @@ -50,7 +88,7 @@ var _ = Describe("Middlewares", func() { }) It("fails when version is missing", func() { - r := newTestRequest("u=user", "c=test") + r := newGetRequest("u=user", "c=test") cp := checkRequiredParameters(next) cp.ServeHTTP(w, r) @@ -59,7 +97,7 @@ var _ = Describe("Middlewares", func() { }) It("fails when client is missing", func() { - r := newTestRequest("u=user", "v=1.15") + r := newGetRequest("u=user", "v=1.15") cp := checkRequiredParameters(next) cp.ServeHTTP(w, r) @@ -75,7 +113,7 @@ var _ = Describe("Middlewares", func() { }) It("passes all parameters to users.Authenticate ", func() { - r := newTestRequest("u=valid", "p=password", "t=token", "s=salt") + r := newGetRequest("u=valid", "p=password", "t=token", "s=salt") cp := authenticate(mockedUser)(next) cp.ServeHTTP(w, r) @@ -89,7 +127,7 @@ var _ = Describe("Middlewares", func() { }) It("fails authentication with wrong password", func() { - r := newTestRequest("u=invalid", "", "", "") + r := newGetRequest("u=invalid", "", "", "") cp := authenticate(mockedUser)(next) cp.ServeHTTP(w, r)