From 7221b49b9882fe202e0aac1af8df437c300f42c2 Mon Sep 17 00:00:00 2001 From: Deluan Date: Thu, 14 Oct 2021 15:23:13 -0400 Subject: [PATCH] More tests --- persistence/sql_smartplaylist.go | 79 +++++++++++++-------------- persistence/sql_smartplaylist_test.go | 79 ++++++++++++++++++++++++++- 2 files changed, 116 insertions(+), 42 deletions(-) diff --git a/persistence/sql_smartplaylist.go b/persistence/sql_smartplaylist.go index a788c414a..d7b44f9fb 100644 --- a/persistence/sql_smartplaylist.go +++ b/persistence/sql_smartplaylist.go @@ -128,56 +128,43 @@ var dateRuleType = reflect.TypeOf(dateRule{}) type dateRule model.Rule func (r dateRule) ToSql() (string, []interface{}, error) { - var dates []time.Time + var date time.Time var err error var sq Sqlizer switch r.Operator { case "is": - if dates, err = r.parseDates(); err != nil { - return "", nil, err - } - sq = Eq{r.Field: dates} + date, err = r.parseDate(r.Value) + sq = Eq{r.Field: date} case "is not": - if dates, err = r.parseDates(); err != nil { - return "", nil, err - } - sq = NotEq{r.Field: dates} + date, err = r.parseDate(r.Value) + sq = NotEq{r.Field: date} case "is before": - if dates, err = r.parseDates(); err != nil { - return "", nil, err - } - sq = Lt{r.Field: dates[0]} + date, err = r.parseDate(r.Value) + sq = Lt{r.Field: date} case "is after": - if dates, err = r.parseDates(); err != nil { - return "", nil, err - } - sq = Gt{r.Field: dates[0]} + date, err = r.parseDate(r.Value) + sq = Gt{r.Field: date} case "is in the range": - if dates, err = r.parseDates(); err != nil { - return "", nil, err + var dates []time.Time + if dates, err = r.parseDates(); err == nil { + sq = And{GtOrEq{r.Field: dates[0]}, LtOrEq{r.Field: dates[1]}} } - if len(dates) != 2 { - return "", nil, fmt.Errorf("not a valid date range: %s", r.Value) - } - sq = And{Gt{r.Field: dates[0]}, Lt{r.Field: dates[1]}} case "in the last": sq, err = r.inTheLast(false) - if err != nil { - return "", nil, err - } case "not in the last": sq, err = r.inTheLast(true) - if err != nil { - return "", nil, err - } default: - return "", nil, errors.New("operator not supported: " + r.Operator) + err = errors.New("operator not supported: " + r.Operator) + } + if err != nil { + return "", nil, err } return sq.ToSql() } func (r dateRule) inTheLast(invert bool) (Sqlizer, error) { - v, err := strconv.ParseInt(r.Value.(string), 10, 64) + str := fmt.Sprintf("%v", r.Value) + v, err := strconv.ParseInt(str, 10, 64) if err != nil { return nil, err } @@ -188,22 +175,34 @@ func (r dateRule) inTheLast(invert bool) (Sqlizer, error) { return Gt{r.Field: period}, nil } +func (r dateRule) parseDate(date interface{}) (time.Time, error) { + input, ok := date.(string) + if !ok { + return time.Time{}, fmt.Errorf("invalid date: %v", date) + } + d, err := time.Parse("2006-01-02", input) + if err != nil { + return time.Time{}, fmt.Errorf("invalid date: %v", date) + } + return d, nil +} + func (r dateRule) parseDates() ([]time.Time, error) { - var input []string - switch v := r.Value.(type) { - case string: - input = append(input, v) - case []string: - input = append(input, v...) + input, ok := r.Value.([]string) + if !ok { + return nil, fmt.Errorf("invalid date range: %s", r.Value) } var dates []time.Time for _, s := range input { - d, err := time.Parse("2006-01-02", s) + d, err := r.parseDate(s) if err != nil { - return nil, errors.New("invalid date: " + s) + return nil, fmt.Errorf("invalid date '%v' in range %v", s, input) } dates = append(dates, d) } + if len(dates) != 2 { + return nil, fmt.Errorf("not a valid date range: %s", r.Value) + } return dates, nil } @@ -254,7 +253,7 @@ func (e errorSqlizer) ToSql() (sql string, args []interface{}, err error) { func (rg RuleGroup) ruleToSqlizer(r model.Rule) Sqlizer { ruleDef := fieldMap[strings.ToLower(r.Field)] if ruleDef == nil { - return errorSqlizer("invalid smart playlist field " + r.Field) + return errorSqlizer(fmt.Sprintf("invalid smart playlist field '%s'", r.Field)) } r.Field = ruleDef.dbField r.Operator = strings.ToLower(r.Operator) diff --git a/persistence/sql_smartplaylist_test.go b/persistence/sql_smartplaylist_test.go index 4454655d7..abbe364b5 100644 --- a/persistence/sql_smartplaylist_test.go +++ b/persistence/sql_smartplaylist_test.go @@ -43,6 +43,14 @@ var _ = Describe("SmartPlaylist", func() { lastMonth := time.Now().Add(-30 * 24 * time.Hour) Expect(args).To(ConsistOf("%love%", 1980, 1989, true, BeTemporally("~", lastMonth, time.Second), "zé", "4")) }) + It("returns an error if field is invalid", func() { + r := pls.Rules[0].(model.Rule) + r.Field = "INVALID" + pls.Rules[0] = r + sel := pls.AddFilters(squirrel.Select("media_file").Columns("*")) + _, _, err := sel.ToSql() + Expect(err).To(MatchError("invalid smart playlist field 'INVALID'")) + }) }) Describe("fieldMap", func() { @@ -78,12 +86,12 @@ var _ = Describe("SmartPlaylist", func() { Describe("numberRule", func() { DescribeTable("operators", - func(operator, expectedSql string, expectedValue ...interface{}) { + func(operator, expectedSql string, expectedValue int) { r := numberRule{Field: "year", Operator: operator, Value: 1985} sql, args, err := r.ToSql() Expect(err).ToNot(HaveOccurred()) Expect(sql).To(Equal(expectedSql)) - Expect(args).To(ConsistOf(expectedValue...)) + Expect(args).To(ConsistOf(expectedValue)) }, Entry("is", "is", "year = ?", 1985), Entry("is not", "is not", "year <> ?", 1985), @@ -99,4 +107,71 @@ var _ = Describe("SmartPlaylist", func() { Expect(args).To(ConsistOf(1981, 1990)) }) }) + + Describe("dateRule", func() { + dateStr := "2021-10-14" + date, _ := time.Parse("2006-01-02", dateStr) + DescribeTable("simple operators", + func(operator, expectedSql string, expectedValue time.Time) { + r := dateRule{Field: "lastPlayed", Operator: operator, Value: dateStr} + sql, args, err := r.ToSql() + Expect(err).ToNot(HaveOccurred()) + Expect(sql).To(Equal(expectedSql)) + Expect(args).To(ConsistOf(expectedValue)) + }, + Entry("is", "is", "lastPlayed = ?", date), + Entry("is not", "is not", "lastPlayed <> ?", date), + Entry("is before", "is before", "lastPlayed < ?", date), + Entry("is after", "is after", "lastPlayed > ?", date), + ) + + DescribeTable("period operators", + func(operator, expectedSql string, expectedValue time.Time) { + r := dateRule{Field: "lastPlayed", Operator: operator, Value: 90} + sql, args, err := r.ToSql() + Expect(err).ToNot(HaveOccurred()) + Expect(sql).To(Equal(expectedSql)) + Expect(args).To(ConsistOf(BeTemporally("~", expectedValue, 25*time.Hour))) + }, + Entry("in the last", "in the last", "lastPlayed > ?", date.Add(-90*24*time.Hour)), + Entry("not in the last", "not in the last", "lastPlayed < ?", date.Add(-90*24*time.Hour)), + ) + + It("accepts string as the 'in the last' operator value", func() { + r := dateRule{Field: "lastPlayed", Operator: "in the last", Value: "90"} + _, args, _ := r.ToSql() + Expect(args).To(ConsistOf(BeTemporally("~", date.Add(-90*24*time.Hour), 25*time.Hour))) + }) + + It("implements the 'is in the range' operator", func() { + date2Str := "2021-09-14" + date2, _ := time.Parse("2006-01-02", date2Str) + + r := dateRule{Field: "lastPlayed", Operator: "is in the range", Value: []string{date2Str, dateStr}} + sql, args, err := r.ToSql() + Expect(err).ToNot(HaveOccurred()) + Expect(sql).To(Equal("(lastPlayed >= ? AND lastPlayed <= ?)")) + Expect(args).To(ConsistOf(BeTemporally("~", date2, 25*time.Hour), BeTemporally("~", date, 25*time.Hour))) + }) + + It("returns error if date is invalid", func() { + r := dateRule{Field: "lastPlayed", Operator: "is", Value: "INVALID"} + _, _, err := r.ToSql() + Expect(err).To(MatchError("invalid date: INVALID")) + }) + }) + + Describe("boolRule", func() { + DescribeTable("operators", + func(operator, expectedSql string, expectedValue ...interface{}) { + r := boolRule{Field: "loved", Operator: operator} + sql, args, err := r.ToSql() + Expect(err).ToNot(HaveOccurred()) + Expect(sql).To(Equal(expectedSql)) + Expect(args).To(ConsistOf(expectedValue...)) + }, + Entry("is true", "is true", "loved = ?", true), + Entry("is false", "is false", "loved = ?", false), + ) + }) })