diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go index 543cf713f..8b3cf854e 100644 --- a/persistence/sql_base_repository.go +++ b/persistence/sql_base_repository.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "strings" - "text/scanner" "time" . "github.com/Masterminds/squirrel" @@ -56,35 +55,46 @@ func (r sqlRepository) applyOptions(sq SelectBuilder, options ...model.QueryOpti sq = sq.Offset(uint64(options[0].Offset)) } if options[0].Sort != "" { - sort := toSnakeCase(options[0].Sort) - if mapping, ok := r.sortMappings[sort]; ok { - sort = mapping - } - if !strings.Contains(sort, "asc") && !strings.Contains(sort, "desc") { - sort = sort + " asc" - } - if options[0].Order == "desc" { - var s scanner.Scanner - s.Init(strings.NewReader(sort)) - var newSort string - for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() { - switch s.TokenText() { - case "asc": - newSort += " " + "desc" - case "desc": - newSort += " " + "asc" - default: - newSort += " " + s.TokenText() - } - } - sort = newSort - } - sq = sq.OrderBy(sort) + sq = sq.OrderBy(r.buildSortOrder(options[0].Sort, options[0].Order)) } } return sq } +func (r sqlRepository) buildSortOrder(sort, order string) string { + if mapping, ok := r.sortMappings[sort]; ok { + sort = mapping + } + + sort = toSnakeCase(sort) + order = strings.ToLower(strings.TrimSpace(order)) + var reverseOrder string + if order == "desc" { + reverseOrder = "asc" + } else { + order = "asc" + reverseOrder = "desc" + } + + var newSort []string + parts := strings.FieldsFunc(sort, func(c rune) bool { return c == ',' }) + for _, p := range parts { + f := strings.Fields(p) + newField := []string{f[0]} + if len(f) == 1 { + newField = append(newField, order) + } else { + if f[1] == "asc" { + newField = append(newField, order) + } else { + newField = append(newField, reverseOrder) + } + } + newSort = append(newSort, strings.Join(newField, " ")) + } + return strings.Join(newSort, ", ") +} + func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder { if len(options) > 0 && options[0].Filters != nil { sq = sq.Where(options[0].Filters) diff --git a/persistence/sql_base_repository_test.go b/persistence/sql_base_repository_test.go new file mode 100644 index 000000000..7181af633 --- /dev/null +++ b/persistence/sql_base_repository_test.go @@ -0,0 +1,64 @@ +package persistence + +import ( + "github.com/Masterminds/squirrel" + "github.com/deluan/navidrome/model" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("sqlRepository", func() { + r := sqlRepository{} + Describe("applyOptions", func() { + var sq squirrel.SelectBuilder + BeforeEach(func() { + sq = squirrel.Select("*").From("test") + }) + It("does not add any clauses when options is empty", func() { + sq = r.applyOptions(sq, model.QueryOptions{}) + sql, _, _ := sq.ToSql() + Expect(sql).To(Equal("SELECT * FROM test")) + }) + It("adds all option clauses", func() { + sq = r.applyOptions(sq, model.QueryOptions{ + Sort: "name", + Order: "desc", + Max: 1, + Offset: 2, + }) + sql, _, _ := sq.ToSql() + Expect(sql).To(Equal("SELECT * FROM test ORDER BY name desc LIMIT 1 OFFSET 2")) + }) + }) + + Describe("buildSortOrder", func() { + Context("single field", func() { + It("sorts by specified field", func() { + sql := r.buildSortOrder("name", "desc") + Expect(sql).To(Equal("name desc")) + }) + It("defaults to 'asc'", func() { + sql := r.buildSortOrder("name", "") + Expect(sql).To(Equal("name asc")) + }) + It("inverts pre-defined order", func() { + sql := r.buildSortOrder("name desc", "desc") + Expect(sql).To(Equal("name asc")) + }) + It("forces snake case for field names", func() { + sql := r.buildSortOrder("AlbumArtist", "asc") + Expect(sql).To(Equal("album_artist asc")) + }) + }) + Context("multiple fields", func() { + It("handles multiple fields", func() { + sql := r.buildSortOrder("name desc,age asc, status desc ", "asc") + Expect(sql).To(Equal("name desc, age asc, status desc")) + }) + It("inverts multiple fields", func() { + sql := r.buildSortOrder("name desc, age, status asc", "desc") + Expect(sql).To(Equal("name asc, age desc, status desc")) + }) + }) + }) +})