diff --git a/core/external_metadata.go b/core/external_metadata.go index 9ae9c6779..1d7b8825a 100644 --- a/core/external_metadata.go +++ b/core/external_metadata.go @@ -163,28 +163,37 @@ func (e *externalMetadata) SimilarSongs(ctx context.Context, id string, count in return nil, ctx.Err() } - artists := model.Artists{artist.Artist} - artists = append(artists, artist.SimilarArtists...) - weightedSongs := utils.NewWeightedRandomChooser() - for _, a := range artists { + addArtist := func(a model.Artist, weightedSongs *utils.WeightedChooser, count, artistWeight int) error { if utils.IsCtxDone(ctx) { log.Warn(ctx, "SimilarSongs call canceled", ctx.Err()) - return nil, ctx.Err() + return ctx.Err() } topCount := utils.MaxInt(count, 20) topSongs, err := e.getMatchingTopSongs(ctx, e.ag, &auxArtist{Name: a.Name, Artist: a}, topCount) if err != nil { log.Warn(ctx, "Error getting artist's top songs", "artist", a.Name, err) - continue + return nil } - weight := topCount * 4 + weight := topCount * (4 + artistWeight) for _, mf := range topSongs { weightedSongs.Put(mf, weight) weight -= 4 } + return nil + } + + err = addArtist(artist.Artist, weightedSongs, count, 10) + if err != nil { + return nil, err + } + for _, a := range artist.SimilarArtists { + err := addArtist(a, weightedSongs, count, 0) + if err != nil { + return nil, err + } } var similarSongs model.MediaFiles diff --git a/utils/weighted_random_chooser.go b/utils/weighted_random_chooser.go index 8692f148a..78f174402 100644 --- a/utils/weighted_random_chooser.go +++ b/utils/weighted_random_chooser.go @@ -6,29 +6,29 @@ import ( "time" ) -type weightedChooser struct { +type WeightedChooser struct { entries []interface{} weights []int totalWeight int rng *rand.Rand } -func NewWeightedRandomChooser() *weightedChooser { +func NewWeightedRandomChooser() *WeightedChooser { src := rand.NewSource(time.Now().UTC().UnixNano()) - return &weightedChooser{ + return &WeightedChooser{ rng: rand.New(src), // nolint:gosec } } -func (w *weightedChooser) Put(value interface{}, weight int) { +func (w *WeightedChooser) Put(value interface{}, weight int) { w.entries = append(w.entries, value) w.weights = append(w.weights, weight) w.totalWeight += weight } // GetAndRemove choose a random entry based on their weights, and removes it from the list -func (w *weightedChooser) GetAndRemove() (interface{}, error) { +func (w *WeightedChooser) GetAndRemove() (interface{}, error) { if w.totalWeight == 0 { return nil, errors.New("cannot choose from zero weight") } @@ -42,7 +42,7 @@ func (w *weightedChooser) GetAndRemove() (interface{}, error) { } // Based on https://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/ -func (w *weightedChooser) weightedChoice() (int, error) { +func (w *WeightedChooser) weightedChoice() (int, error) { rnd := w.rng.Intn(w.totalWeight) for i, weight := range w.weights { rnd -= weight @@ -53,7 +53,7 @@ func (w *weightedChooser) weightedChoice() (int, error) { return 0, errors.New("internal error - code should not reach this point") } -func (w *weightedChooser) Remove(i int) { +func (w *WeightedChooser) Remove(i int) { w.totalWeight -= w.weights[i] w.weights[i] = w.weights[len(w.weights)-1] @@ -64,6 +64,6 @@ func (w *weightedChooser) Remove(i int) { w.entries = w.entries[:len(w.entries)-1] } -func (w *weightedChooser) Size() int { +func (w *WeightedChooser) Size() int { return len(w.entries) } diff --git a/utils/weighted_random_chooser_test.go b/utils/weighted_random_chooser_test.go index d7303d02c..0d96daf42 100644 --- a/utils/weighted_random_chooser_test.go +++ b/utils/weighted_random_chooser_test.go @@ -6,7 +6,7 @@ import ( ) var _ = Describe("WeightedRandomChooser", func() { - var w *weightedChooser + var w *WeightedChooser BeforeEach(func() { w = NewWeightedRandomChooser() for i := 0; i < 10; i++ {