From 98218d045e2ebc848714c8aac46c78f6febc11cf Mon Sep 17 00:00:00 2001
From: Guilherme Souza <32180229+gqgs@users.noreply.github.com>
Date: Sat, 18 May 2024 15:10:53 -0300
Subject: [PATCH] Deterministic pagination in random albums sort (#1841)

* Deterministic pagination in random albums sort

* Reseed on first random page

* Add unit tests

* Use rand in Subsonic API

* Use different seeds per user on SEEDEDRAND() SQLite3 function

* Small refactor

* Fix id mismatch

* Add seeded random to media_file (subsonic endpoint `getRandomSongs`)

* Refactor

* Remove unneeded import

---------

Co-authored-by: Deluan <deluan@navidrome.org>
---
 db/db.go                            | 11 ++++++--
 persistence/album_repository.go     |  5 ++--
 persistence/mediafile_repository.go |  5 ++--
 persistence/sql_base_repository.go  | 13 +++++++++
 server/subsonic/filter/filters.go   |  4 +--
 utils/hasher/hasher.go              | 44 +++++++++++++++++++++++++++++
 utils/hasher/hasher_test.go         | 36 +++++++++++++++++++++++
 7 files changed, 110 insertions(+), 8 deletions(-)
 create mode 100644 utils/hasher/hasher.go
 create mode 100644 utils/hasher/hasher_test.go

diff --git a/db/db.go b/db/db.go
index b8fd3a36f..cf0ce2cfb 100644
--- a/db/db.go
+++ b/db/db.go
@@ -5,10 +5,11 @@ import (
 	"embed"
 	"fmt"
 
-	_ "github.com/mattn/go-sqlite3"
+	"github.com/mattn/go-sqlite3"
 	"github.com/navidrome/navidrome/conf"
 	_ "github.com/navidrome/navidrome/db/migrations"
 	"github.com/navidrome/navidrome/log"
+	"github.com/navidrome/navidrome/utils/hasher"
 	"github.com/navidrome/navidrome/utils/singleton"
 	"github.com/pressly/goose/v3"
 )
@@ -25,13 +26,19 @@ const migrationsFolder = "migrations"
 
 func Db() *sql.DB {
 	return singleton.GetInstance(func() *sql.DB {
+		sql.Register(Driver+"_custom", &sqlite3.SQLiteDriver{
+			ConnectHook: func(conn *sqlite3.SQLiteConn) error {
+				return conn.RegisterFunc("SEEDEDRAND", hasher.HashFunc(), false)
+			},
+		})
+
 		Path = conf.Server.DbPath
 		if Path == ":memory:" {
 			Path = "file::memory:?cache=shared&_foreign_keys=on"
 			conf.Server.DbPath = Path
 		}
 		log.Debug("Opening DataBase", "dbPath", Path, "driver", Driver)
-		instance, err := sql.Open(Driver, Path)
+		instance, err := sql.Open(Driver+"_custom", Path)
 		if err != nil {
 			panic(err)
 		}
diff --git a/persistence/album_repository.go b/persistence/album_repository.go
index 862852d7e..c820fc13a 100644
--- a/persistence/album_repository.go
+++ b/persistence/album_repository.go
@@ -75,7 +75,7 @@ func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumReposito
 			"artist":         "compilation asc, COALESCE(NULLIF(sort_album_artist_name,''),order_album_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc",
 			"albumArtist":    "compilation asc, COALESCE(NULLIF(sort_album_artist_name,''),order_album_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc",
 			"max_year":       "coalesce(nullif(original_date,''), cast(max_year as text)), release_date, name, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc",
-			"random":         "RANDOM()",
+			"random":         r.seededRandomSort(),
 			"recently_added": recentlyAddedSort(),
 		}
 	} else {
@@ -84,7 +84,7 @@ func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumReposito
 			"artist":         "compilation asc, order_album_artist_name asc, order_album_name asc",
 			"albumArtist":    "compilation asc, order_album_artist_name asc, order_album_name asc",
 			"max_year":       "coalesce(nullif(original_date,''), cast(max_year as text)), release_date, name, order_album_name asc",
-			"random":         "RANDOM()",
+			"random":         r.seededRandomSort(),
 			"recently_added": recentlyAddedSort(),
 		}
 	}
@@ -180,6 +180,7 @@ func (r *albumRepository) GetAll(options ...model.QueryOptions) (model.Albums, e
 }
 
 func (r *albumRepository) GetAllWithoutGenres(options ...model.QueryOptions) (model.Albums, error) {
+	r.resetSeededRandom(options)
 	sq := r.selectAlbum(options...)
 	var dba dbAlbums
 	err := r.queryAll(sq, &dba)
diff --git a/persistence/mediafile_repository.go b/persistence/mediafile_repository.go
index 5c018f34a..6c476a4fe 100644
--- a/persistence/mediafile_repository.go
+++ b/persistence/mediafile_repository.go
@@ -36,7 +36,7 @@ func NewMediaFileRepository(ctx context.Context, db dbx.Builder) *mediaFileRepos
 			"title":     "COALESCE(NULLIF(sort_title,''),title)",
 			"artist":    "COALESCE(NULLIF(sort_artist_name,''),order_artist_name) asc, COALESCE(NULLIF(sort_album_name,''),order_album_name) asc, release_date asc, disc_number asc, track_number asc",
 			"album":     "COALESCE(NULLIF(sort_album_name,''),order_album_name) asc, release_date asc, disc_number asc, track_number asc, COALESCE(NULLIF(sort_artist_name,''),order_artist_name) asc, COALESCE(NULLIF(sort_title,''),title) asc",
-			"random":    "RANDOM()",
+			"random":    r.seededRandomSort(),
 			"createdAt": "media_file.created_at",
 		}
 	} else {
@@ -44,7 +44,7 @@ func NewMediaFileRepository(ctx context.Context, db dbx.Builder) *mediaFileRepos
 			"title":     "order_title",
 			"artist":    "order_artist_name asc, order_album_name asc, release_date asc, disc_number asc, track_number asc",
 			"album":     "order_album_name asc, release_date asc, disc_number asc, track_number asc, order_artist_name asc, title asc",
-			"random":    "RANDOM()",
+			"random":    r.seededRandomSort(),
 			"createdAt": "media_file.created_at",
 		}
 	}
@@ -102,6 +102,7 @@ func (r *mediaFileRepository) Get(id string) (*model.MediaFile, error) {
 }
 
 func (r *mediaFileRepository) GetAll(options ...model.QueryOptions) (model.MediaFiles, error) {
+	r.resetSeededRandom(options)
 	sq := r.selectMediaFile(options...)
 	res := model.MediaFiles{}
 	err := r.queryAll(sq, &res, options...)
diff --git a/persistence/sql_base_repository.go b/persistence/sql_base_repository.go
index d5282516d..da5ac6a3d 100644
--- a/persistence/sql_base_repository.go
+++ b/persistence/sql_base_repository.go
@@ -14,6 +14,7 @@ import (
 	"github.com/navidrome/navidrome/log"
 	"github.com/navidrome/navidrome/model"
 	"github.com/navidrome/navidrome/model/request"
+	"github.com/navidrome/navidrome/utils/hasher"
 	"github.com/pocketbase/dbx"
 )
 
@@ -137,6 +138,30 @@ func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOpti
 	return sq
 }
 
+// seededRandomSort returns the SQL expression mapped to the "random" sort key: a
+// call to the custom SEEDEDRAND SQL function, keyed by the repository's table
+// name followed by the ID of the user stored in the repository context.
+func (r sqlRepository) seededRandomSort() string {
+	// The second return value is discarded, so whatever ID UserFrom hands back is used as is.
+	user, _ := request.UserFrom(r.ctx)
+	seedKey := r.tableName + user.ID
+	return fmt.Sprintf("SEEDEDRAND('%s', id)", seedKey)
+}
+
+// resetSeededRandom reseeds the hasher entry for this table and user when the
+// first query option asks for the "random" sort at offset 0. Any other request
+// leaves the current seed untouched.
+func (r sqlRepository) resetSeededRandom(options []model.QueryOptions) {
+	if len(options) == 0 {
+		return
+	}
+	if first := options[0]; first.Offset != 0 || first.Sort != "random" {
+		return
+	}
+	user, _ := request.UserFrom(r.ctx)
+	hasher.Reseed(r.tableName + user.ID)
+}
+
 func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
 	query, args, err := r.toSQL(sq)
 	if err != nil {
diff --git a/server/subsonic/filter/filters.go b/server/subsonic/filter/filters.go
index fca482f33..87fb4804e 100644
--- a/server/subsonic/filter/filters.go
+++ b/server/subsonic/filter/filters.go
@@ -24,7 +24,7 @@ func AlbumsByFrequent() Options {
 }
 
 func AlbumsByRandom() Options {
-	return Options{Sort: "random()"}
+	return Options{Sort: "random"}
 }
 
 func AlbumsByName() Options {
@@ -100,7 +100,7 @@ func SongsByAlbum(albumId string) Options {
 
 func SongsByRandom(genre string, fromYear, toYear int) Options {
 	options := Options{
-		Sort: "random()",
+		Sort: "random",
 	}
 	ff := squirrel.And{}
 	if genre != "" {
diff --git a/utils/hasher/hasher.go b/utils/hasher/hasher.go
new file mode 100644
index 000000000..78566913a
--- /dev/null
+++ b/utils/hasher/hasher.go
@@ -0,0 +1,71 @@
+// Package hasher hashes strings with seeds that are tracked per id and can be
+// regenerated on demand.
+package hasher
+
+import (
+	"hash/maphash"
+	"sync"
+)
+
+// instance is the package-wide hasher behind Reseed and HashFunc.
+var instance = NewHasher()
+
+// Reseed replaces the seed associated with id in the package-wide hasher.
+func Reseed(id string) {
+	instance.Reseed(id)
+}
+
+// HashFunc returns the hashing function of the package-wide hasher.
+func HashFunc() func(id, str string) uint64 {
+	return instance.HashFunc()
+}
+
+// hasher keeps one maphash seed per id. The mutex guards the map because the
+// package-wide instance is reachable from any goroutine, and unsynchronized
+// concurrent map access is a fatal error in Go.
+type hasher struct {
+	mu    sync.RWMutex
+	seeds map[string]maphash.Seed
+}
+
+// NewHasher returns a hasher with no seeds assigned yet.
+func NewHasher() *hasher {
+	return &hasher{seeds: make(map[string]maphash.Seed)}
+}
+
+// Reseed generates a new seed for the given id
+func (h *hasher) Reseed(id string) {
+	seed := maphash.MakeSeed()
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	h.seeds[id] = seed
+}
+
+// seedFor returns the seed assigned to id, creating and storing one on first use.
+func (h *hasher) seedFor(id string) maphash.Seed {
+	h.mu.RLock()
+	seed, ok := h.seeds[id]
+	h.mu.RUnlock()
+	if ok {
+		return seed
+	}
+
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	// Check again: another goroutine may have stored a seed while no lock was held.
+	if seed, ok = h.seeds[id]; !ok {
+		seed = maphash.MakeSeed()
+		h.seeds[id] = seed
+	}
+	return seed
+}
+
+// HashFunc returns a function that hashes a string using the seed for the given id
+func (h *hasher) HashFunc() func(id, str string) uint64 {
+	return func(id, str string) uint64 {
+		var hash maphash.Hash
+		hash.SetSeed(h.seedFor(id))
+		_, _ = hash.WriteString(str) // WriteString never fails
+		return hash.Sum64()
+	}
+}
diff --git a/utils/hasher/hasher_test.go b/utils/hasher/hasher_test.go
new file mode 100644
index 000000000..3a1f9dfde
--- /dev/null
+++ b/utils/hasher/hasher_test.go
@@ -0,0 +1,46 @@
+package hasher_test
+
+import (
+	"testing"
+
+	"github.com/navidrome/navidrome/utils/hasher"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+// TestHasher hooks the Ginkgo specs below into `go test`; without a RunSpecs
+// entry point the registered specs are never executed.
+func TestHasher(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "Hasher Suite")
+}
+
+// Specs for the package-level HashFunc/Reseed pair.
+var _ = Describe("HashFunc", func() {
+	const input = "123e4567e89b12d3a456426614174000"
+
+	It("hashes the input and returns the sum", func() {
+		hashFunc := hasher.HashFunc()
+		sum := hashFunc("1", input)
+		Expect(sum > 0).To(BeTrue())
+	})
+
+	It("hashes the input, reseeds and returns a different sum", func() {
+		hashFunc := hasher.HashFunc()
+		sum := hashFunc("1", input)
+		hasher.Reseed("1")
+		sum2 := hashFunc("1", input)
+		Expect(sum).NotTo(Equal(sum2))
+	})
+
+	It("keeps different hashes for different ids", func() {
+		hashFunc := hasher.HashFunc()
+		sum := hashFunc("1", input)
+		sum2 := hashFunc("2", input)
+
+		Expect(sum).NotTo(Equal(sum2))
+
+		Expect(sum).To(Equal(hashFunc("1", input)))
+		Expect(sum2).To(Equal(hashFunc("2", input)))
+	})
+})