From 7aab82c246cd5709e562bad8108c0dc087b9647c Mon Sep 17 00:00:00 2001 From: Deluan Date: Wed, 5 Feb 2020 14:12:13 -0500 Subject: [PATCH] feat: enable overriding sql sorting --- persistence/mediafile_repository.go | 4 ++++ persistence/sql_base_repository.go | 37 ++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/persistence/mediafile_repository.go b/persistence/mediafile_repository.go index e0117401d..0c7113bbe 100644 --- a/persistence/mediafile_repository.go +++ b/persistence/mediafile_repository.go @@ -21,6 +21,10 @@ func NewMediaFileRepository(ctx context.Context, o orm.Ormer) *mediaFileReposito r.ctx = ctx r.ormer = o r.tableName = "media_file" + r.sortMappings = map[string]string{ + "artist": "artist asc, album asc, disc_number asc, track_number asc", + "album": "album asc, disc_number asc, track_number asc", + } return r } diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go index e14c7a269..15dcb25f1 100644 --- a/persistence/sql_base_repository.go +++ b/persistence/sql_base_repository.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "text/scanner" "time" . "github.com/Masterminds/squirrel" @@ -15,9 +16,10 @@ import ( ) type sqlRepository struct { - ctx context.Context - tableName string - ormer orm.Ormer + ctx context.Context + tableName string + ormer orm.Ormer + sortMappings map[string]string } const invalidUserId = "-1" @@ -55,11 +57,30 @@ func (r sqlRepository) applyOptions(sq SelectBuilder, options ...model.QueryOpti sq = sq.Offset(uint64(options[0].Offset)) } if options[0].Sort != "" { - if options[0].Order == "desc" { - sq = sq.OrderBy(toSnakeCase(options[0].Sort + " desc")) - } else { - sq = sq.OrderBy(toSnakeCase(options[0].Sort)) + sort := toSnakeCase(options[0].Sort) + if mapping, ok := r.sortMappings[sort]; ok { + sort = mapping } + if !strings.Contains(sort, "asc") && !strings.Contains(sort, "desc") { + sort = sort + " asc" + } + if options[0].Order == "desc" { + var s scanner.Scanner + s.Init(strings.NewReader(sort)) + var newSort string + for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() { + switch s.TokenText() { + case "asc": + newSort += " " + "desc" + case "desc": + newSort += " " + "asc" + default: + newSort += " " + s.TokenText() + } + } + sort = newSort + } + sq = sq.OrderBy(sort) } } return sq @@ -190,7 +211,7 @@ func (r sqlRepository) parseRestOptions(options ...rest.QueryOptions) model.Quer qo := model.QueryOptions{} if len(options) > 0 { qo.Sort = options[0].Sort - qo.Order = options[0].Order + qo.Order = strings.ToLower(options[0].Order) qo.Max = options[0].Max qo.Offset = options[0].Offset if len(options[0].Filters) > 0 {