diff --git a/api/api_test.go b/api/api_test.go index 016aef882..4c49cef61 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -34,3 +34,14 @@ func Get(url string, testCase string) (*http.Request, *httptest.ResponseRecorder return r, w } + +func GetWithHeader(url string, header, value, testCase string) (*http.Request, *httptest.ResponseRecorder) { + r, _ := http.NewRequest("GET", url, nil) + r.Header.Add(header, value) + w := httptest.NewRecorder() + beego.BeeApp.Handlers.ServeHTTP(w, r) + + beego.Debug("testing", testCase, fmt.Sprintf("\nUrl: %s\nStatus Code: [%d]\n%s", r.URL, w.Code, w.Body.String())) + + return r, w +} diff --git a/api/base_api_controller.go b/api/base_api_controller.go index f1fd6db11..03768da5f 100644 --- a/api/base_api_controller.go +++ b/api/base_api_controller.go @@ -1,7 +1,6 @@ package api import ( - "context" "encoding/xml" "fmt" "strconv" @@ -13,10 +12,7 @@ import ( "github.com/cloudsonic/sonic-server/utils" ) -type BaseAPIController struct { - beego.Controller - context context.Context -} +type BaseAPIController struct{ beego.Controller } func (c *BaseAPIController) NewEmpty() responses.Subsonic { return responses.Subsonic{Status: "ok", Version: beego.AppConfig.String("apiVersion")} diff --git a/api/validation.go b/api/validation.go index a66d5b5f0..c5916abfd 100644 --- a/api/validation.go +++ b/api/validation.go @@ -27,18 +27,12 @@ func Validate(controller BaseAPIController) { } } -func getData(c BaseAPIController, name string) string { - if v, ok := c.Ctx.Input.GetData(name).(string); ok { - return v - } - return "" -} - func addNewContext(c BaseAPIController) { ctx := context.Background() - id := getData(c, "requestId") - c.context = context.WithValue(ctx, "requestId", id) + id := c.Ctx.Input.GetData("requestId") + ctx = context.WithValue(ctx, "requestId", id) + c.Ctx.Input.SetData("context", ctx) } func checkParameters(c BaseAPIController) { @@ -54,6 +48,11 @@ func checkParameters(c BaseAPIController) { if c.GetString("p") == "" && (c.GetString("s") == "" || c.GetString("t") == "") { logWarn(c, "Missing authentication information") } + ctx := c.Ctx.Input.GetData("context").(context.Context) + ctx = context.WithValue(ctx, "user", c.GetString("u")) + ctx = context.WithValue(ctx, "client", c.GetString("c")) + ctx = context.WithValue(ctx, "version", c.GetString("v")) + c.Ctx.Input.SetData("context", ctx) } func authenticate(c BaseAPIController) { diff --git a/api/validation_test.go b/api/validation_test.go index cc1413473..7e51a42e1 100644 --- a/api/validation_test.go +++ b/api/validation_test.go @@ -5,6 +5,10 @@ import ( "fmt" "testing" + "context" + + "github.com/astaxie/beego" + "github.com/cloudsonic/sonic-server/api" "github.com/cloudsonic/sonic-server/api/responses" "github.com/cloudsonic/sonic-server/tests" . "github.com/smartystreets/goconvey/convey" @@ -74,3 +78,39 @@ func TestAuthentication(t *testing.T) { }) }) } + +type mockController struct { + api.BaseAPIController +} + +func (c *mockController) Get() { + actualContext = c.Ctx.Input.GetData("context").(context.Context) + c.Ctx.WriteString("OK") +} + +var actualContext context.Context + +func TestContext(t *testing.T) { + tests.Init(t, false) + beego.Router("/rest/mocktest", &mockController{}) + + Convey("Subject: Context", t, func() { + _, w := GetWithHeader("/rest/mocktest?u=deluan&p=wordpass&c=testClient&v=1.0.0", "X-Request-Id", "123123", "TestContext") + Convey("The status should be 'OK'", func() { + resp := string(w.Body.Bytes()) + So(resp, ShouldEqual, "OK") + }) + Convey("user should be set", func() { + So(actualContext.Value("user"), ShouldEqual, "deluan") + }) + Convey("client should be set", func() { + So(actualContext.Value("client"), ShouldEqual, "testClient") + }) + Convey("version should be set", func() { + So(actualContext.Value("version"), ShouldEqual, "1.0.0") + }) + Convey("context should be set", func() { + So(actualContext.Value("requestId"), ShouldEqual, "123123") + }) + }) +} diff --git a/init/router.go b/init/router.go index 0cadeb2ca..fd10afb0a 100644 --- a/init/router.go +++ b/init/router.go @@ -75,7 +75,6 @@ func initFilters() { if id == "" { id = uuid.NewV4().String() } - ctx.Input.SetData("requestId", id) }