From bb96d455f87719ceabb3e6eeff46f93a4f24478a Mon Sep 17 00:00:00 2001
From: Deluan <deluan@navidrome.org>
Date: Fri, 10 May 2024 15:27:07 -0400
Subject: [PATCH] Replace sync.WaitGroup with more appropriate errgroup.Group

---
 server/subsonic/searching.go | 38 +++++++++++++++---------------------
 1 file changed, 16 insertions(+), 22 deletions(-)

diff --git a/server/subsonic/searching.go b/server/subsonic/searching.go
index 256d2f2c4..7e94b41a8 100644
--- a/server/subsonic/searching.go
+++ b/server/subsonic/searching.go
@@ -6,7 +6,6 @@ import (
 	"net/http"
 	"reflect"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/deluan/sanitize"
@@ -15,6 +14,7 @@ import (
 	"github.com/navidrome/navidrome/server/public"
 	"github.com/navidrome/navidrome/server/subsonic/responses"
 	"github.com/navidrome/navidrome/utils/req"
+	"golang.org/x/sync/errgroup"
 )
 
 type searchParams struct {
@@ -42,45 +42,39 @@ func (api *Router) getSearchParams(r *http.Request) (*searchParams, error) {
 
 type searchFunc[T any] func(q string, offset int, size int) (T, error)
 
-func callSearch[T any](ctx context.Context, wg *sync.WaitGroup, s searchFunc[T], q string, offset, size int, result *T) {
-	defer wg.Done()
-	if size == 0 {
-		return
-	}
-	done := make(chan struct{})
-	go func() {
+func callSearch[T any](ctx context.Context, s searchFunc[T], q string, offset, size int, result *T) func() error {
+	return func() error {
+		if size == 0 {
+			return nil
+		}
 		typ := strings.TrimPrefix(reflect.TypeOf(*result).String(), "model.")
 		var err error
 		start := time.Now()
 		*result, err = s(q, offset, size)
 		if err != nil {
-			log.Error(ctx, "Error searching "+typ, "query", q, err)
+			log.Error(ctx, "Error searching "+typ, "query", q, "elapsed", time.Since(start), err)
 		} else {
 			log.Trace(ctx, "Search for "+typ+" completed", "query", q, "elapsed", time.Since(start))
 		}
-		done <- struct{}{}
-	}()
-	select {
-	case <-done:
-	case <-ctx.Done():
+		return nil
 	}
 }
 
 func (api *Router) searchAll(ctx context.Context, sp *searchParams) (mediaFiles model.MediaFiles, albums model.Albums, artists model.Artists) {
 	start := time.Now()
 	q := sanitize.Accents(strings.ToLower(strings.TrimSuffix(sp.query, "*")))
-	wg := &sync.WaitGroup{}
-	wg.Add(3)
-	go callSearch(ctx, wg, api.ds.MediaFile(ctx).Search, q, sp.songOffset, sp.songCount, &mediaFiles)
-	go callSearch(ctx, wg, api.ds.Album(ctx).Search, q, sp.albumOffset, sp.albumCount, &albums)
-	go callSearch(ctx, wg, api.ds.Artist(ctx).Search, q, sp.artistOffset, sp.artistCount, &artists)
-	wg.Wait()
 
-	if ctx.Err() == nil {
+	// Run searches in parallel
+	g, ctx := errgroup.WithContext(ctx)
+	g.Go(callSearch(ctx, api.ds.MediaFile(ctx).Search, q, sp.songOffset, sp.songCount, &mediaFiles))
+	g.Go(callSearch(ctx, api.ds.Album(ctx).Search, q, sp.albumOffset, sp.albumCount, &albums))
+	g.Go(callSearch(ctx, api.ds.Artist(ctx).Search, q, sp.artistOffset, sp.artistCount, &artists))
+	err := g.Wait()
+	if err == nil {
 		log.Debug(ctx, fmt.Sprintf("Search resulted in %d songs, %d albums and %d artists",
 			len(mediaFiles), len(albums), len(artists)), "query", sp.query, "elapsedTime", time.Since(start))
 	} else {
-		log.Warn(ctx, "Search was interrupted", ctx.Err(), "query", sp.query, "elapsedTime", time.Since(start))
+		log.Warn(ctx, "Search was interrupted", "query", sp.query, "elapsedTime", time.Since(start), err)
 	}
 	return mediaFiles, albums, artists
 }