Refactor random.WeightedChooser, unsing generics

This commit is contained in:
Deluan 2024-05-17 15:45:34 -04:00
parent a7a4fb522c
commit 4d28d534cc
3 changed files with 68 additions and 41 deletions

View File

@ -267,8 +267,8 @@ func (e *externalMetadata) SimilarSongs(ctx context.Context, id string, count in
return nil, ctx.Err() return nil, ctx.Err()
} }
weightedSongs := random.NewWeightedRandomChooser() weightedSongs := random.NewWeightedChooser[model.MediaFile]()
addArtist := func(a model.Artist, weightedSongs *random.WeightedChooser, count, artistWeight int) error { addArtist := func(a model.Artist, weightedSongs *random.WeightedChooser[model.MediaFile], count, artistWeight int) error {
if utils.IsCtxDone(ctx) { if utils.IsCtxDone(ctx) {
log.Warn(ctx, "SimilarSongs call canceled", ctx.Err()) log.Warn(ctx, "SimilarSongs call canceled", ctx.Err())
return ctx.Err() return ctx.Err()
@ -302,12 +302,12 @@ func (e *externalMetadata) SimilarSongs(ctx context.Context, id string, count in
var similarSongs model.MediaFiles var similarSongs model.MediaFiles
for len(similarSongs) < count && weightedSongs.Size() > 0 { for len(similarSongs) < count && weightedSongs.Size() > 0 {
s, err := weightedSongs.GetAndRemove() s, err := weightedSongs.Pick()
if err != nil { if err != nil {
log.Warn(ctx, "Error getting weighted song", err) log.Warn(ctx, "Error getting weighted song", err)
continue continue
} }
similarSongs = append(similarSongs, s.(model.MediaFile)) similarSongs = append(similarSongs, s)
} }
return similarSongs, nil return similarSongs, nil

View File

@ -2,42 +2,46 @@ package random
import ( import (
"errors" "errors"
"slices"
) )
type WeightedChooser struct { // WeightedChooser allows to randomly choose an entry based on their weights
entries []interface{} // (higher weight = higher chance of being chosen). Based on the subtraction method described in
// https://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/
type WeightedChooser[T any] struct {
entries []T
weights []int weights []int
totalWeight int totalWeight int
} }
func NewWeightedRandomChooser() *WeightedChooser { func NewWeightedChooser[T any]() *WeightedChooser[T] {
return &WeightedChooser{} return &WeightedChooser[T]{}
} }
func (w *WeightedChooser) Add(value interface{}, weight int) { func (w *WeightedChooser[T]) Add(value T, weight int) {
w.entries = append(w.entries, value) w.entries = append(w.entries, value)
w.weights = append(w.weights, weight) w.weights = append(w.weights, weight)
w.totalWeight += weight w.totalWeight += weight
} }
// GetAndRemove choose a random entry based on their weights, and removes it from the list // Pick choose a random entry based on their weights, and removes it from the list
func (w *WeightedChooser) GetAndRemove() (interface{}, error) { func (w *WeightedChooser[T]) Pick() (T, error) {
var empty T
if w.totalWeight == 0 { if w.totalWeight == 0 {
return nil, errors.New("cannot choose from zero weight") return empty, errors.New("cannot choose from zero weight")
} }
i, err := w.weightedChoice() i, err := w.weightedChoice()
if err != nil { if err != nil {
return nil, err return empty, err
} }
entry := w.entries[i] entry := w.entries[i]
w.Remove(i) _ = w.Remove(i)
return entry, nil return entry, nil
} }
// Based on https://eli.thegreenplace.net/2010/01/22/weighted-random-generation-in-python/ func (w *WeightedChooser[T]) weightedChoice() (int, error) {
func (w *WeightedChooser) weightedChoice() (int, error) { if len(w.entries) == 0 {
if w.totalWeight == 0 { return 0, errors.New("cannot choose from empty list")
return 0, errors.New("no choices available")
} }
rnd := Int64(w.totalWeight) rnd := Int64(w.totalWeight)
for i, weight := range w.weights { for i, weight := range w.weights {
@ -49,17 +53,18 @@ func (w *WeightedChooser) weightedChoice() (int, error) {
return 0, errors.New("internal error - code should not reach this point") return 0, errors.New("internal error - code should not reach this point")
} }
func (w *WeightedChooser) Remove(i int) { func (w *WeightedChooser[T]) Remove(i int) error {
if i < 0 || i >= len(w.entries) {
return errors.New("index out of bounds")
}
w.totalWeight -= w.weights[i] w.totalWeight -= w.weights[i]
w.weights[i] = w.weights[len(w.weights)-1] w.weights = slices.Delete(w.weights, i, i+1)
w.weights = w.weights[:len(w.weights)-1] w.entries = slices.Delete(w.entries, i, i+1)
return nil
w.entries[i] = w.entries[len(w.entries)-1]
w.entries[len(w.entries)-1] = nil
w.entries = w.entries[:len(w.entries)-1]
} }
func (w *WeightedChooser) Size() int { func (w *WeightedChooser[T]) Size() int {
return len(w.entries) return len(w.entries)
} }

View File

@ -5,35 +5,57 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
var _ = Describe("WeightedRandomChooser", func() { var _ = Describe("WeightedChooser", func() {
var w *WeightedChooser var w *WeightedChooser[int]
BeforeEach(func() { BeforeEach(func() {
w = NewWeightedRandomChooser() w = NewWeightedChooser[int]()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
w.Add(i, i) w.Add(i, i+1)
} }
}) })
It("removes a random item", func() { It("selects and removes a random item", func() {
Expect(w.Size()).To(Equal(10)) Expect(w.Size()).To(Equal(10))
_, err := w.GetAndRemove() _, err := w.Pick()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(w.Size()).To(Equal(9)) Expect(w.Size()).To(Equal(9))
}) })
It("removes items", func() {
Expect(w.Size()).To(Equal(10))
for i := 0; i < 10; i++ {
Expect(w.Remove(0)).To(Succeed())
}
Expect(w.Size()).To(Equal(0))
})
It("returns error if trying to remove an invalid index", func() {
Expect(w.Size()).To(Equal(10))
Expect(w.Remove(-1)).ToNot(Succeed())
Expect(w.Remove(10000)).ToNot(Succeed())
Expect(w.Size()).To(Equal(10))
})
It("returns the sole item", func() { It("returns the sole item", func() {
w = NewWeightedRandomChooser() ws := NewWeightedChooser[string]()
w.Add("a", 1) ws.Add("a", 1)
Expect(w.GetAndRemove()).To(Equal("a")) Expect(ws.Pick()).To(Equal("a"))
})
It("returns all items from the list", func() {
for i := 0; i < 10; i++ {
Expect(w.Pick()).To(BeElementOf(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
}
Expect(w.Size()).To(Equal(0))
}) })
It("fails when trying to choose from empty set", func() { It("fails when trying to choose from empty set", func() {
w = NewWeightedRandomChooser() w = NewWeightedChooser[int]()
w.Add("a", 1) w.Add(1, 1)
w.Add("b", 1) w.Add(2, 1)
Expect(w.GetAndRemove()).To(BeElementOf("a", "b")) Expect(w.Pick()).To(BeElementOf(1, 2))
Expect(w.GetAndRemove()).To(BeElementOf("a", "b")) Expect(w.Pick()).To(BeElementOf(1, 2))
_, err := w.GetAndRemove() _, err := w.Pick()
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })