mirror of
https://github.com/navidrome/navidrome.git
synced 2025-04-13 02:37:18 +03:00
Refactor random.WeightedChooser, unsing generics
This commit is contained in:
parent
a7a4fb522c
commit
4d28d534cc
@ -267,8 +267,8 @@ func (e *externalMetadata) SimilarSongs(ctx context.Context, id string, count in
|
|||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
weightedSongs := random.NewWeightedRandomChooser()
|
weightedSongs := random.NewWeightedChooser[model.MediaFile]()
|
||||||
addArtist := func(a model.Artist, weightedSongs *random.WeightedChooser, count, artistWeight int) error {
|
addArtist := func(a model.Artist, weightedSongs *random.WeightedChooser[model.MediaFile], count, artistWeight int) error {
|
||||||
if utils.IsCtxDone(ctx) {
|
if utils.IsCtxDone(ctx) {
|
||||||
log.Warn(ctx, "SimilarSongs call canceled", ctx.Err())
|
log.Warn(ctx, "SimilarSongs call canceled", ctx.Err())
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
@ -302,12 +302,12 @@ func (e *externalMetadata) SimilarSongs(ctx context.Context, id string, count in
|
|||||||
|
|
||||||
var similarSongs model.MediaFiles
|
var similarSongs model.MediaFiles
|
||||||
for len(similarSongs) < count && weightedSongs.Size() > 0 {
|
for len(similarSongs) < count && weightedSongs.Size() > 0 {
|
||||||
s, err := weightedSongs.GetAndRemove()
|
s, err := weightedSongs.Pick()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(ctx, "Error getting weighted song", err)
|
log.Warn(ctx, "Error getting weighted song", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
similarSongs = append(similarSongs, s.(model.MediaFile))
|
similarSongs = append(similarSongs, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
return similarSongs, nil
|
return similarSongs, nil
|
||||||
|
@ -2,42 +2,46 @@ package random
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"slices"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WeightedChooser struct {
|
// WeightedChooser allows to randomly choose an entry based on their weights
|
||||||
entries []interface{}
|
// (higher weight = higher chance of being chosen). Based on the subtraction method described in
|
||||||
|
// https://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/
|
||||||
|
type WeightedChooser[T any] struct {
|
||||||
|
entries []T
|
||||||
weights []int
|
weights []int
|
||||||
totalWeight int
|
totalWeight int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWeightedRandomChooser() *WeightedChooser {
|
func NewWeightedChooser[T any]() *WeightedChooser[T] {
|
||||||
return &WeightedChooser{}
|
return &WeightedChooser[T]{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WeightedChooser) Add(value interface{}, weight int) {
|
func (w *WeightedChooser[T]) Add(value T, weight int) {
|
||||||
w.entries = append(w.entries, value)
|
w.entries = append(w.entries, value)
|
||||||
w.weights = append(w.weights, weight)
|
w.weights = append(w.weights, weight)
|
||||||
w.totalWeight += weight
|
w.totalWeight += weight
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAndRemove choose a random entry based on their weights, and removes it from the list
|
// Pick choose a random entry based on their weights, and removes it from the list
|
||||||
func (w *WeightedChooser) GetAndRemove() (interface{}, error) {
|
func (w *WeightedChooser[T]) Pick() (T, error) {
|
||||||
|
var empty T
|
||||||
if w.totalWeight == 0 {
|
if w.totalWeight == 0 {
|
||||||
return nil, errors.New("cannot choose from zero weight")
|
return empty, errors.New("cannot choose from zero weight")
|
||||||
}
|
}
|
||||||
i, err := w.weightedChoice()
|
i, err := w.weightedChoice()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return empty, err
|
||||||
}
|
}
|
||||||
entry := w.entries[i]
|
entry := w.entries[i]
|
||||||
w.Remove(i)
|
_ = w.Remove(i)
|
||||||
return entry, nil
|
return entry, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Based on https://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/
|
func (w *WeightedChooser[T]) weightedChoice() (int, error) {
|
||||||
func (w *WeightedChooser) weightedChoice() (int, error) {
|
if len(w.entries) == 0 {
|
||||||
if w.totalWeight == 0 {
|
return 0, errors.New("cannot choose from empty list")
|
||||||
return 0, errors.New("no choices available")
|
|
||||||
}
|
}
|
||||||
rnd := Int64(w.totalWeight)
|
rnd := Int64(w.totalWeight)
|
||||||
for i, weight := range w.weights {
|
for i, weight := range w.weights {
|
||||||
@ -49,17 +53,18 @@ func (w *WeightedChooser) weightedChoice() (int, error) {
|
|||||||
return 0, errors.New("internal error - code should not reach this point")
|
return 0, errors.New("internal error - code should not reach this point")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WeightedChooser) Remove(i int) {
|
func (w *WeightedChooser[T]) Remove(i int) error {
|
||||||
|
if i < 0 || i >= len(w.entries) {
|
||||||
|
return errors.New("index out of bounds")
|
||||||
|
}
|
||||||
|
|
||||||
w.totalWeight -= w.weights[i]
|
w.totalWeight -= w.weights[i]
|
||||||
|
|
||||||
w.weights[i] = w.weights[len(w.weights)-1]
|
w.weights = slices.Delete(w.weights, i, i+1)
|
||||||
w.weights = w.weights[:len(w.weights)-1]
|
w.entries = slices.Delete(w.entries, i, i+1)
|
||||||
|
return nil
|
||||||
w.entries[i] = w.entries[len(w.entries)-1]
|
|
||||||
w.entries[len(w.entries)-1] = nil
|
|
||||||
w.entries = w.entries[:len(w.entries)-1]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WeightedChooser) Size() int {
|
func (w *WeightedChooser[T]) Size() int {
|
||||||
return len(w.entries)
|
return len(w.entries)
|
||||||
}
|
}
|
||||||
|
@ -5,35 +5,57 @@ import (
|
|||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("WeightedRandomChooser", func() {
|
var _ = Describe("WeightedChooser", func() {
|
||||||
var w *WeightedChooser
|
var w *WeightedChooser[int]
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
w = NewWeightedRandomChooser()
|
w = NewWeightedChooser[int]()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
w.Add(i, i)
|
w.Add(i, i+1)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
It("removes a random item", func() {
|
It("selects and removes a random item", func() {
|
||||||
Expect(w.Size()).To(Equal(10))
|
Expect(w.Size()).To(Equal(10))
|
||||||
_, err := w.GetAndRemove()
|
_, err := w.Pick()
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(w.Size()).To(Equal(9))
|
Expect(w.Size()).To(Equal(9))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("removes items", func() {
|
||||||
|
Expect(w.Size()).To(Equal(10))
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
Expect(w.Remove(0)).To(Succeed())
|
||||||
|
}
|
||||||
|
Expect(w.Size()).To(Equal(0))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns error if trying to remove an invalid index", func() {
|
||||||
|
Expect(w.Size()).To(Equal(10))
|
||||||
|
Expect(w.Remove(-1)).ToNot(Succeed())
|
||||||
|
Expect(w.Remove(10000)).ToNot(Succeed())
|
||||||
|
Expect(w.Size()).To(Equal(10))
|
||||||
|
})
|
||||||
|
|
||||||
It("returns the sole item", func() {
|
It("returns the sole item", func() {
|
||||||
w = NewWeightedRandomChooser()
|
ws := NewWeightedChooser[string]()
|
||||||
w.Add("a", 1)
|
ws.Add("a", 1)
|
||||||
Expect(w.GetAndRemove()).To(Equal("a"))
|
Expect(ws.Pick()).To(Equal("a"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns all items from the list", func() {
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
Expect(w.Pick()).To(BeElementOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
|
||||||
|
}
|
||||||
|
Expect(w.Size()).To(Equal(0))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("fails when trying to choose from empty set", func() {
|
It("fails when trying to choose from empty set", func() {
|
||||||
w = NewWeightedRandomChooser()
|
w = NewWeightedChooser[int]()
|
||||||
w.Add("a", 1)
|
w.Add(1, 1)
|
||||||
w.Add("b", 1)
|
w.Add(2, 1)
|
||||||
Expect(w.GetAndRemove()).To(BeElementOf("a", "b"))
|
Expect(w.Pick()).To(BeElementOf(1, 2))
|
||||||
Expect(w.GetAndRemove()).To(BeElementOf("a", "b"))
|
Expect(w.Pick()).To(BeElementOf(1, 2))
|
||||||
_, err := w.GetAndRemove()
|
_, err := w.Pick()
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user