From 8cdd4e317d8a66a2bfa61fb0959a0965cb677929 Mon Sep 17 00:00:00 2001
From: Deluan <deluan@deluan.com>
Date: Thu, 19 Mar 2020 21:09:57 -0400
Subject: [PATCH] feat: allow restful filter customization per field

---
 persistence/album_repository.go    |  3 +++
 persistence/sql_base_repository.go | 26 +++++++++++++++++++++-----
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/persistence/album_repository.go b/persistence/album_repository.go
index 034159c32..3427543cc 100644
--- a/persistence/album_repository.go
+++ b/persistence/album_repository.go
@@ -23,6 +23,9 @@ func NewAlbumRepository(ctx context.Context, o orm.Ormer) model.AlbumRepository
 	r.sortMappings = map[string]string{
 		"artist": "compilation asc, album_artist asc, name asc",
 	}
+	r.filterMappings = map[string]filterFunc{
+		"compilation": booleanFilter,
+	}
 
 	return r
 }
diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go
index 9b21db352..5db0cae4f 100644
--- a/persistence/sql_base_repository.go
+++ b/persistence/sql_base_repository.go
@@ -15,11 +15,14 @@ import (
 	"github.com/google/uuid"
 )
 
+type filterFunc = func(field string, value interface{}) Sqlizer
+
 type sqlRepository struct {
-	ctx          context.Context
-	tableName    string
-	ormer        orm.Ormer
-	sortMappings map[string]string
+	ctx            context.Context
+	tableName      string
+	ormer          orm.Ormer
+	sortMappings   map[string]string
+	filterMappings map[string]filterFunc
 }
 
 const invalidUserId = "-1"
@@ -226,10 +229,23 @@ func (r sqlRepository) parseRestOptions(options ...rest.QueryOptions) model.Quer
 		if len(options[0].Filters) > 0 {
 			filters := And{}
 			for f, v := range options[0].Filters {
-				filters = append(filters, Like{f: fmt.Sprintf("%s%%", v)})
+				if ff, ok := r.filterMappings[f]; ok {
+					filters = append(filters, ff(f, v))
+				} else {
+					filters = append(filters, startsWithFilter(f, v))
+				}
 			}
 			qo.Filters = filters
 		}
 	}
 	return qo
 }
+
+func startsWithFilter(field string, value interface{}) Like {
+	return Like{field: fmt.Sprintf("%s%%", value)}
+}
+
+func booleanFilter(field string, value interface{}) Sqlizer {
+	v := strings.ToLower(value.(string))
+	return Eq{field: strings.ToLower(v) == "true"}
+}