From a56e588c8ebcd46d5b4f30f831621e15d85ad75b Mon Sep 17 00:00:00 2001
From: Deluan <deluan@deluan.com>
Date: Fri, 8 May 2020 13:57:32 -0400
Subject: [PATCH] Create relation table for playlist tracks

---
 ...0200516140647_add_playlist_tracks_table.go | 100 ++++++++++
 engine/playlists.go                           | 121 ++++++------
 model/playlist.go                             |   3 +-
 persistence/helpers.go                        |   2 +-
 persistence/persistence_suite_test.go         |  13 +-
 persistence/playlist_repository.go            | 183 +++++++-----------
 server/subsonic/playlists.go                  |   2 +-
 7 files changed, 250 insertions(+), 174 deletions(-)
 create mode 100644 db/migration/20200516140647_add_playlist_tracks_table.go

diff --git a/db/migration/20200516140647_add_playlist_tracks_table.go b/db/migration/20200516140647_add_playlist_tracks_table.go
new file mode 100644
index 000000000..f90b1d383
--- /dev/null
+++ b/db/migration/20200516140647_add_playlist_tracks_table.go
@@ -0,0 +1,100 @@
+package migration
+
+import (
+	"database/sql"
+	"strings"
+
+	"github.com/deluan/navidrome/log"
+	"github.com/pressly/goose"
+)
+
+func init() {
+	goose.AddMigration(Up20200516140647, Down20200516140647)
+}
+
+func Up20200516140647(tx *sql.Tx) error {
+	_, err := tx.Exec(`
+create table if not exists playlist_tracks
+(
+	id integer default 0 not null, 
+    playlist_id varchar(255) not null, 
+	media_file_id varchar(255) not null
+);
+
+create unique index if not exists playlist_tracks_pos
+	on playlist_tracks (playlist_id, id);
+`)
+	if err != nil {
+		return err
+	}
+	rows, err := tx.Query("select id, tracks from playlist")
+	if err != nil {
+		return err
+	}
+	defer rows.Close()
+	var id, tracks string
+	for rows.Next() {
+		err := rows.Scan(&id, &tracks)
+		if err != nil {
+			return err
+		}
+		err = Up20200516140647UpdatePlaylistTracks(tx, id, tracks)
+		if err != nil {
+			return err
+		}
+	}
+	err = rows.Err()
+	if err != nil {
+		return err
+	}
+
+	_, err = tx.Exec(`
+create table playlist_dg_tmp
+(
+	id varchar(255) not null
+		primary key,
+	name varchar(255) default '' not null,
+	comment varchar(255) default '' not null,
+	duration real default 0 not null,
+	song_count integer default 0 not null,
+	owner varchar(255) default '' not null,
+	public bool default FALSE not null,
+	created_at datetime,
+	updated_at datetime
+);
+
+insert into playlist_dg_tmp(id, name, comment, duration, owner, public, created_at, updated_at) 
+	select id, name, comment, duration, owner, public, created_at, updated_at from playlist;
+
+drop table playlist;
+
+alter table playlist_dg_tmp rename to playlist;
+
+create index playlist_name
+	on playlist (name);
+
+update playlist set song_count = (select count(*) from playlist_tracks where playlist_id = playlist.id)
+where id <> ''
+
+`)
+	return err
+}
+
+func Up20200516140647UpdatePlaylistTracks(tx *sql.Tx, id string, tracks string) error {
+	trackList := strings.Split(tracks, ",")
+	stmt, err := tx.Prepare("insert into playlist_tracks (playlist_id, media_file_id, id) values (?, ?, ?)")
+	if err != nil {
+		return err
+	}
+	for i, trackId := range trackList {
+		_, err := stmt.Exec(id, trackId, i)
+		if err != nil {
+			log.Error("Error adding track to playlist", "playlistId", id, "trackId", trackId, err)
+		}
+	}
+	return nil
+}
+
+func Down20200516140647(tx *sql.Tx) error {
+	return nil
+}
diff --git a/engine/playlists.go b/engine/playlists.go
index ef0860ed7..a0e075aed 100644
--- a/engine/playlists.go
+++ b/engine/playlists.go
@@ -26,30 +26,33 @@ type playlists struct {
 }
 
 func (p *playlists) Create(ctx context.Context, playlistId, name string, ids []string) error {
-	owner := p.getUser(ctx)
-	var pls *model.Playlist
-	var err error
-	// If playlistID is present, override tracks
-	if playlistId != "" {
-		pls, err = p.ds.Playlist(ctx).Get(playlistId)
-		if err != nil {
-			return err
-		}
-		if owner != pls.Owner {
-			return model.ErrNotAuthorized
-		}
-		pls.Tracks = nil
-	} else {
-		pls = &model.Playlist{
-			Name:  name,
-			Owner: owner,
-		}
-	}
-	for _, id := range ids {
-		pls.Tracks = append(pls.Tracks, model.MediaFile{ID: id})
-	}
+	return p.ds.WithTx(func(tx model.DataStore) error {
+		owner := p.getUser(ctx)
+		var pls *model.Playlist
+		var err error
 
-	return p.ds.Playlist(ctx).Put(pls)
+		// If playlistID is present, override tracks
+		if playlistId != "" {
+			pls, err = tx.Playlist(ctx).Get(playlistId)
+			if err != nil {
+				return err
+			}
+			if owner != pls.Owner {
+				return model.ErrNotAuthorized
+			}
+			pls.Tracks = nil
+		} else {
+			pls = &model.Playlist{
+				Name:  name,
+				Owner: owner,
+			}
+		}
+		for _, id := range ids {
+			pls.Tracks = append(pls.Tracks, model.MediaFile{ID: id})
+		}
+
+		return tx.Playlist(ctx).Put(pls)
+	})
 }
 
 func (p *playlists) getUser(ctx context.Context) string {
@@ -61,46 +64,50 @@ func (p *playlists) getUser(ctx context.Context) string {
 }
 
 func (p *playlists) Delete(ctx context.Context, playlistId string) error {
-	pls, err := p.ds.Playlist(ctx).Get(playlistId)
-	if err != nil {
-		return err
-	}
+	return p.ds.WithTx(func(tx model.DataStore) error {
+		pls, err := tx.Playlist(ctx).Get(playlistId)
+		if err != nil {
+			return err
+		}
 
-	owner := p.getUser(ctx)
-	if owner != pls.Owner {
-		return model.ErrNotAuthorized
-	}
-	return p.ds.Playlist(ctx).Delete(playlistId)
+		owner := p.getUser(ctx)
+		if owner != pls.Owner {
+			return model.ErrNotAuthorized
+		}
+		return tx.Playlist(ctx).Delete(playlistId)
+	})
 }
 
 func (p *playlists) Update(ctx context.Context, playlistId string, name *string, idsToAdd []string, idxToRemove []int) error {
-	pls, err := p.ds.Playlist(ctx).Get(playlistId)
-	if err != nil {
-		return err
-	}
-
-	owner := p.getUser(ctx)
-	if owner != pls.Owner {
-		return model.ErrNotAuthorized
-	}
-
-	if name != nil {
-		pls.Name = *name
-	}
-	newTracks := model.MediaFiles{}
-	for i, t := range pls.Tracks {
-		if utils.IntInSlice(i, idxToRemove) {
-			continue
+	return p.ds.WithTx(func(tx model.DataStore) error {
+		pls, err := tx.Playlist(ctx).Get(playlistId)
+		if err != nil {
+			return err
 		}
-		newTracks = append(newTracks, t)
-	}
 
-	for _, id := range idsToAdd {
-		newTracks = append(newTracks, model.MediaFile{ID: id})
-	}
-	pls.Tracks = newTracks
+		owner := p.getUser(ctx)
+		if owner != pls.Owner {
+			return model.ErrNotAuthorized
+		}
 
-	return p.ds.Playlist(ctx).Put(pls)
+		if name != nil {
+			pls.Name = *name
+		}
+		newTracks := model.MediaFiles{}
+		for i, t := range pls.Tracks {
+			if utils.IntInSlice(i, idxToRemove) {
+				continue
+			}
+			newTracks = append(newTracks, t)
+		}
+
+		for _, id := range idsToAdd {
+			newTracks = append(newTracks, model.MediaFile{ID: id})
+		}
+		pls.Tracks = newTracks
+
+		return tx.Playlist(ctx).Put(pls)
+	})
 }
 
 func (p *playlists) GetAll(ctx context.Context) (model.Playlists, error) {
@@ -134,7 +141,7 @@ func (p *playlists) Get(ctx context.Context, id string) (*PlaylistInfo, error) {
 	plsInfo := &PlaylistInfo{
 		Id:        pl.ID,
 		Name:      pl.Name,
-		SongCount: len(pl.Tracks),
+		SongCount: pl.SongCount,
 		Duration:  int(pl.Duration),
 		Public:    pl.Public,
 		Owner:     pl.Owner,
diff --git a/model/playlist.go b/model/playlist.go
index c82923bfc..8e4f552f7 100644
--- a/model/playlist.go
+++ b/model/playlist.go
@@ -3,10 +3,11 @@ package model
 import "time"
 
 type Playlist struct {
-	ID        string     `json:"id"`
+	ID        string     `json:"id"          orm:"column(id)"`
 	Name      string     `json:"name"`
 	Comment   string     `json:"comment"`
 	Duration  float32    `json:"duration"`
+	SongCount int        `json:"songCount"`
 	Owner     string     `json:"owner"`
 	Public    bool       `json:"public"`
 	Tracks    MediaFiles `json:"tracks"`
diff --git a/persistence/helpers.go b/persistence/helpers.go
index 70a611dd7..597841f6d 100644
--- a/persistence/helpers.go
+++ b/persistence/helpers.go
@@ -23,7 +23,7 @@ func toSqlArgs(rec interface{}) (map[string]interface{}, error) {
 	err = json.Unmarshal(b, &m)
 	r := make(map[string]interface{}, len(m))
 	for f, v := range m {
-		if !utils.StringInSlice(f, model.AnnotationFields) {
+		if !utils.StringInSlice(f, model.AnnotationFields) && v != nil {
 			r[toSnakeCase(f)] = v
 		}
 	}
diff --git a/persistence/persistence_suite_test.go b/persistence/persistence_suite_test.go
index 1263a518a..dbaabc4ae 100644
--- a/persistence/persistence_suite_test.go
+++ b/persistence/persistence_suite_test.go
@@ -65,12 +65,13 @@ var (
 
 var (
 	plsBest = model.Playlist{
-		ID:      "10",
-		Name:    "Best",
-		Comment: "No Comments",
-		Owner:   "userid",
-		Public:  true,
-		Tracks:  model.MediaFiles{{ID: "1001"}, {ID: "1003"}},
+		ID:        "10",
+		Name:      "Best",
+		Comment:   "No Comments",
+		Owner:     "userid",
+		Public:    true,
+		SongCount: 2,
+		Tracks:    model.MediaFiles{{ID: "1001"}, {ID: "1003"}},
 	}
 	plsCool       = model.Playlist{ID: "11", Name: "Cool", Tracks: model.MediaFiles{{ID: "1004"}}}
 	testPlaylists = model.Playlists{plsBest, plsCool}
diff --git a/persistence/playlist_repository.go b/persistence/playlist_repository.go
index 8f226930c..b586eb3c8 100644
--- a/persistence/playlist_repository.go
+++ b/persistence/playlist_repository.go
@@ -2,7 +2,6 @@ package persistence
 
 import (
 	"context"
-	"strings"
 	"time"
 
 	. "github.com/Masterminds/squirrel"
@@ -12,18 +11,6 @@ import (
 	"github.com/deluan/rest"
 )
 
-type playlist struct {
-	ID        string `orm:"column(id)"`
-	Name      string
-	Comment   string
-	Duration  float32
-	Owner     string
-	Public    bool
-	Tracks    string
-	CreatedAt time.Time
-	UpdatedAt time.Time
-}
-
 type playlistRepository struct {
 	sqlRepository
 	sqlRestful
@@ -46,6 +33,11 @@ func (r *playlistRepository) Exists(id string) (bool, error) {
 }
 
 func (r *playlistRepository) Delete(id string) error {
+	del := Delete("playlist_tracks").Where(Eq{"playlist_id": id})
+	_, err := r.executeSQL(del)
+	if err != nil {
+		return err
+	}
 	return r.delete(Eq{"id": id})
 }
 
@@ -54,121 +46,96 @@ func (r *playlistRepository) Put(p *model.Playlist) error {
 		p.CreatedAt = time.Now()
 	}
 	p.UpdatedAt = time.Now()
-	pls := r.fromModel(p)
-	_, err := r.put(pls.ID, pls)
+
+	// Save tracks for later and set it to nil, to avoid trying to save it to the DB
+	tracks := p.Tracks
+	p.Tracks = nil
+
+	id, err := r.put(p.ID, p)
+	if err != nil {
+		return err
+	}
+	err = r.updateTracks(id, tracks)
 	return err
 }
 
 func (r *playlistRepository) Get(id string) (*model.Playlist, error) {
 	sel := r.newSelect().Columns("*").Where(Eq{"id": id})
-	var res playlist
-	err := r.queryOne(sel, &res)
-	pls := r.toModel(&res)
+	var pls model.Playlist
+	err := r.queryOne(sel, &pls)
+	if err != nil {
+		return nil, err
+	}
+	err = r.loadTracks(&pls)
 	return &pls, err
 }
 
 func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playlists, error) {
 	sel := r.newSelect(options...).Columns("*")
-	var res []playlist
+	var res model.Playlists
 	err := r.queryAll(sel, &res)
-	return r.toModels(res), err
+	if err != nil {
+		return nil, err
+	}
+	err = r.loadAllTracks(res)
+	return res, err
 }
 
-func (r *playlistRepository) toModels(all []playlist) model.Playlists {
-	result := make(model.Playlists, len(all))
-	for i := range all {
-		p := all[i]
-		result[i] = r.toModel(&p)
-	}
-	return result
-}
-
-func (r *playlistRepository) toModel(p *playlist) model.Playlist {
-	pls := model.Playlist{
-		ID:        p.ID,
-		Name:      p.Name,
-		Comment:   p.Comment,
-		Duration:  p.Duration,
-		Owner:     p.Owner,
-		Public:    p.Public,
-		CreatedAt: p.CreatedAt,
-		UpdatedAt: p.UpdatedAt,
-	}
-	if strings.TrimSpace(p.Tracks) != "" {
-		tracks := strings.Split(p.Tracks, ",")
-		for _, t := range tracks {
-			pls.Tracks = append(pls.Tracks, model.MediaFile{ID: t})
-		}
-	}
-	pls.Tracks = r.loadTracks(&pls)
-	return pls
-}
-
-func (r *playlistRepository) fromModel(p *model.Playlist) playlist {
-	pls := playlist{
-		ID:        p.ID,
-		Name:      p.Name,
-		Comment:   p.Comment,
-		Owner:     p.Owner,
-		Public:    p.Public,
-		CreatedAt: p.CreatedAt,
-		UpdatedAt: p.UpdatedAt,
-	}
-	// TODO Update duration with a SQL query, instead of loading all tracks
-	p.Tracks = r.loadTracks(p)
-	var newTracks []string
-	for _, t := range p.Tracks {
-		newTracks = append(newTracks, t.ID)
-		pls.Duration += t.Duration
-	}
-	pls.Tracks = strings.Join(newTracks, ",")
-	return pls
-}
-
-// TODO: Introduce a relation table for Playlist <-> MediaFiles, and rewrite this method in pure SQL
-func (r *playlistRepository) loadTracks(p *model.Playlist) model.MediaFiles {
-	if len(p.Tracks) == 0 {
-		return nil
+func (r *playlistRepository) updateTracks(id string, tracks model.MediaFiles) error {
+	// Remove old tracks
+	del := Delete("playlist_tracks").Where(Eq{"playlist_id": id})
+	_, err := r.executeSQL(del)
+	if err != nil {
+		return err
 	}
 
-	// Collect all ids
-	ids := make([]string, len(p.Tracks))
-	for i, t := range p.Tracks {
-		ids[i] = t.ID
-	}
-
-	// Break the list in chunks, up to 50 items, to avoid hitting SQLITE_MAX_FUNCTION_ARG limit
-	const chunkSize = 50
-	var chunks [][]string
-	for i := 0; i < len(ids); i += chunkSize {
-		end := i + chunkSize
-		if end > len(ids) {
-			end = len(ids)
-		}
-
-		chunks = append(chunks, ids[i:end])
-	}
-
-	// Query each chunk of media_file ids and store results in a map
-	mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
-	trackMap := map[string]model.MediaFile{}
-	for i := range chunks {
-		idsFilter := Eq{"id": chunks[i]}
-		tracks, err := mfRepo.GetAll(model.QueryOptions{Filters: idsFilter})
+	// Add new tracks
+	for i, t := range tracks {
+		ins := Insert("playlist_tracks").Columns("playlist_id", "media_file_id", "id").
+			Values(id, t.ID, i)
+		_, err = r.executeSQL(ins)
 		if err != nil {
-			log.Error(r.ctx, "Could not load playlist's tracks", "playlistName", p.Name, "playlistId", p.ID, err)
-		}
-		for _, t := range tracks {
-			trackMap[t.ID] = t
+			return err
 		}
 	}
 
-	// Create a new list of tracks with the same order as the original
-	newTracks := make(model.MediaFiles, len(p.Tracks))
-	for i, t := range p.Tracks {
-		newTracks[i] = trackMap[t.ID]
+	// Get total playlist duration and count
+	statsSql := Select("sum(duration) as duration", "count(*) as count").From("media_file").
+		Join("playlist_tracks f on f.media_file_id = media_file.id").
+		Where(Eq{"playlist_id": id})
+	var res struct{ Duration, Count float32 }
+	err = r.queryOne(statsSql, &res)
+	if err != nil {
+		return err
 	}
-	return newTracks
+
+	// Update total playlist duration and count
+	upd := Update(r.tableName).
+		Set("duration", res.Duration).
+		Set("song_count", res.Count).
+		Where(Eq{"id": id})
+	_, err = r.executeSQL(upd)
+	return err
+}
+
+func (r *playlistRepository) loadAllTracks(all model.Playlists) error {
+	for i := range all {
+		if err := r.loadTracks(&all[i]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (r *playlistRepository) loadTracks(pls *model.Playlist) (err error) {
+	tracksQuery := Select("media_file.*").From("media_file").
+		Join("playlist_tracks f on f.media_file_id = media_file.id").
+		Where(Eq{"f.playlist_id": pls.ID}).OrderBy("f.id")
+	err = r.queryAll(tracksQuery, &pls.Tracks)
+	if err != nil {
+		log.Error("Error loading playlist tracks", "playlist", pls.Name, "id", pls.ID)
+	}
+	return
 }
 
 func (r *playlistRepository) Count(options ...rest.QueryOptions) (int64, error) {
diff --git a/server/subsonic/playlists.go b/server/subsonic/playlists.go
index 039ede726..bca3e7396 100644
--- a/server/subsonic/playlists.go
+++ b/server/subsonic/playlists.go
@@ -32,7 +32,7 @@ func (c *PlaylistsController) GetPlaylists(w http.ResponseWriter, r *http.Reques
 		playlists[i].Id = p.ID
 		playlists[i].Name = p.Name
 		playlists[i].Comment = p.Comment
-		playlists[i].SongCount = len(p.Tracks)
+		playlists[i].SongCount = p.SongCount
 		playlists[i].Duration = int(p.Duration)
 		playlists[i].Owner = p.Owner
 		playlists[i].Public = p.Public