277 lines
7.9 KiB
Python
277 lines
7.9 KiB
Python
import datetime
|
|
import json
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, asdict, field
|
|
from numba import jit, njit
|
|
from math import sqrt
|
|
from pprint import pprint
|
|
|
|
import numpy as np
|
|
import psycopg
|
|
from psycopg.rows import dict_row, class_row
|
|
from tqdm import tqdm
|
|
|
|
# Feature columns used for clustering, in the fixed order that maps 1:1
# onto positions of the per-entity feature vector (see DimensionDescriptor).
IMPORTANT_FIELDS: tuple[str, ...] = (
    'Tempo',
    'Zcr',
    'MeanSpectralCentroid',
    'StdDevSpectralCentroid',
    'MeanSpectralRolloff',
    'StdDevSpectralRolloff',
    'MeanSpectralFlatness',
    'StdDevSpectralFlatness',
    'MeanLoudness',
    'StdDevLoudness',
) + tuple(f'Chroma{i}' for i in range(1, 11))
|
|
|
|
|
|
class DimensionDescriptor:
    """Descriptor exposing one slot of the owner's ``np_arr`` vector as a
    named float attribute.

    The attribute name must be a member of IMPORTANT_FIELDS; its position
    in that tuple is the index into ``EntityBase.np_arr`` that this
    descriptor reads and writes.
    """

    def __set_name__(self, owner, name):
        # Bind the descriptor to its np_arr slot at class-creation time.
        self.name: str = name
        self.idx: int = IMPORTANT_FIELDS.index(name)

    def __get__(self, instance: 'EntityBase', owner) -> float:
        # Fix: accessing the attribute on the class itself (instance is
        # None, e.g. via introspection or dataclass default handling)
        # previously crashed with AttributeError on `None.np_arr`.
        # Returning the descriptor object is the conventional protocol.
        if instance is None:
            return self
        return instance.np_arr[self.idx]

    def __set__(self, instance: 'EntityBase', value: float):
        instance.np_arr[self.idx] = value
|
|
|
|
|
|
@njit(cache=True)
def dist(x: np.ndarray, y: np.ndarray) -> float:
    """Euclidean (L2) distance between two 1-D feature vectors."""
    diff = x - y
    return np.sqrt(np.dot(diff, diff))
|
|
|
|
|
|
@dataclass(frozen=True)
class UpdateCentroidsRes:
    """Result of one update_centroids() pass."""

    cent_move_delta: float  # total absolute per-dimension movement of all centroids
    wss: float  # wss - within-cluster-sum of squared errors
|
|
|
|
|
|
@dataclass(kw_only=True)
class EntityBase:
    """Base for feature-vector entities (songs and centroids).

    Each of the 20 named attributes below is a DimensionDescriptor that
    reads/writes one slot of the single backing numpy vector ``np_arr``,
    in IMPORTANT_FIELDS order. The generated ``__init__`` (kw_only)
    receives the column values as keyword arguments — e.g. from
    psycopg's class_row — and routes them through the descriptors into
    ``np_arr``, so distance math can work on the whole vector at once.
    """

    # Backing store: one float per IMPORTANT_FIELDS entry, zero-initialised.
    # init=False: filled via the descriptor fields, never passed directly.
    np_arr: np.ndarray = field(init=False, repr=False, default_factory=lambda: np.zeros(len(IMPORTANT_FIELDS)))
    Tempo: float = DimensionDescriptor()
    Zcr: float = DimensionDescriptor()
    MeanSpectralCentroid: float = DimensionDescriptor()
    StdDevSpectralCentroid: float = DimensionDescriptor()
    MeanSpectralRolloff: float = DimensionDescriptor()
    StdDevSpectralRolloff: float = DimensionDescriptor()
    MeanSpectralFlatness: float = DimensionDescriptor()
    StdDevSpectralFlatness: float = DimensionDescriptor()
    MeanLoudness: float = DimensionDescriptor()
    StdDevLoudness: float = DimensionDescriptor()
    Chroma1: float = DimensionDescriptor()
    Chroma2: float = DimensionDescriptor()
    Chroma3: float = DimensionDescriptor()
    Chroma4: float = DimensionDescriptor()
    Chroma5: float = DimensionDescriptor()
    Chroma6: float = DimensionDescriptor()
    Chroma7: float = DimensionDescriptor()
    Chroma8: float = DimensionDescriptor()
    Chroma9: float = DimensionDescriptor()
    Chroma10: float = DimensionDescriptor()
|
|
|
|
|
|
@dataclass(kw_only=True)
class Centroid(EntityBase):
    """A cluster centre with its feature vector (inherited from EntityBase)."""

    # Synthetic key assigned by row_number() in the initial-selection SQL.
    id: int
|
|
|
|
|
|
@dataclass(kw_only=True)
class Song(EntityBase):
    """A track row: file identifier plus its feature vector."""

    # File column from bliss_tracks; used as the song's display name.
    File: str

    def __str__(self):
        return self.File

    # Reuse __str__ so songs also display as their file name inside
    # containers (e.g. when pprint-ing cluster lists).
    __repr__ = __str__
|
|
|
|
|
|
@dataclass(frozen=True)
class Iteration:
    """Convergence metrics recorded after one k-means iteration."""

    n: int  # 0-based iteration number
    wss: float  # within-cluster sum of squared errors after the update
    cent_move_delta: float  # total centroid movement in this iteration
|
|
|
|
|
|
@dataclass(frozen=True)
class RunResults:
    """Aggregate persisted to the result JSON file at the end of run()."""

    centroids: list[Centroid]  # final centroid positions
    k: int  # requested number of clusters
    iterations: list[Iteration]  # per-iteration convergence metrics
|
|
|
|
|
|
def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]:
    """Pick `k` random tracks from bliss_tracks as the initial centroids.

    Args:
        conn: open psycopg connection.
        k: number of centroids to sample (passed as a bound parameter).

    Returns:
        Mapping of generated centroid id (1..k, from row_number()) -> Centroid.
    """
    # Single source of truth: build the select list from IMPORTANT_FIELDS
    # instead of hand-maintaining the same 20 quoted column names here.
    columns = ', '.join(f'"{name}"' for name in IMPORTANT_FIELDS)
    query = f"""
        select
            row_number() over () as id, *
        from (
            select
                {columns}
            FROM bliss_tracks bt
            ORDER BY RANDOM()
            LIMIT %(k)s
        ) with_ids;"""

    with conn.cursor(row_factory=class_row(Centroid)) as cur:
        cur.execute(query, {"k": k})
        return {centroid.id: centroid for centroid in cur.fetchall()}
|
|
|
|
|
|
def get_all_songs(conn: psycopg.Connection) -> list[Song]:
    """Load every track (File name + feature columns) from bliss_tracks.

    The feature column list is generated from IMPORTANT_FIELDS so it can
    never drift out of sync with the fields EntityBase exposes.
    """
    columns = ', '.join(f'"{name}"' for name in IMPORTANT_FIELDS)
    query = f"""
        select
            "File",
            {columns}
        FROM bliss_tracks;"""

    with conn.cursor(row_factory=class_row(Song)) as cur:
        cur.execute(query)
        return cur.fetchall()
|
|
|
|
|
|
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
    """Assign every song to its nearest centroid.

    Returns a mapping centroid id -> songs for which that centroid is the
    closest in Euclidean feature-space distance. Centroids that attract no
    songs do not appear in the result.
    """
    distance = dist  # bind once: local lookup is cheaper in the hot loop
    candidates = list(centroids.values())
    assignment: dict[int, list[Song]] = defaultdict(list)

    progress = tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1)
    for song in progress:
        features = song.np_arr
        # Nearest centroid wins the song.
        nearest = min(candidates, key=lambda centroid: distance(features, centroid.np_arr))
        assignment[nearest.id].append(song)

    return assignment
|
|
|
|
|
|
def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes:
    """Move every centroid to the mean of the songs assigned to it.

    Mutates the Centroid objects in `centroids` in place. Centroids with
    no assigned songs are left untouched (they have no key in
    `centroids_with_songs`).

    Returns:
        UpdateCentroidsRes with
        - cent_move_delta: summed absolute per-dimension movement of all
          centroids (the convergence criterion used by run());
        - wss: within-cluster sum of squared distances after the move.
    """
    total_delta = 0.0
    wss = 0.0

    for cent_id, cent_songs in tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1):
        centroid = centroids[cent_id]

        # Vectorized mean over the members' backing vectors replaces the
        # previous per-field loop, which paid len(IMPORTANT_FIELDS) *
        # len(cent_songs) descriptor getattr/setattr calls per centroid.
        member_matrix = np.stack([song.np_arr for song in cent_songs])
        new_position = member_matrix.mean(axis=0)

        total_delta += float(np.abs(centroid.np_arr - new_position).sum())
        centroid.np_arr = new_position  # descriptors read through np_arr, so this updates every field

        # WSS contribution: squared distances of members to the updated centroid.
        wss += float(np.square(member_matrix - new_position).sum())

    return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss)
|
|
|
|
|
|
def run(conn: psycopg.Connection, k: int):
    """Run k-means with `k` clusters and dump the results to a JSON file.

    Iterates at most 10 times, stopping early once the total centroid
    movement falls below 0.5. The output file is named after the current
    timestamp (':' replaced with '-' for filesystem safety).
    """
    print(f'Running with k {k}')
    centroids = choose_centroids(conn, k=k)
    all_songs = get_all_songs(conn)

    iterations: list[Iteration] = list()

    for iteration_num in range(10):
        centroids_with_songs = centroids_to_songs(centroids, all_songs)
        upd_cent_res = update_centroids(centroids_with_songs, centroids)  # Centroids are passed by reference

        iterations.append(Iteration(
            n=iteration_num,
            cent_move_delta=upd_cent_res.cent_move_delta,
            wss=upd_cent_res.wss
        ))

        pprint(iterations[-1])

        if upd_cent_res.cent_move_delta < 0.5:
            print("Leaving loop by delta")
            break

    else:
        print("Leaving loop by attempts exceeding")

    def _json_default(obj):
        # Bug fix: asdict() keeps each centroid's `np_arr` as a numpy
        # ndarray, which json.dump cannot serialize — every run used to
        # crash with TypeError right at the end. Arrays become lists; the
        # scalar fields are np.float64 (a float subclass) and serialize
        # natively.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

    with open(f'{datetime.datetime.now().isoformat().replace(":", "-")}.json', mode='w', encoding='utf-8') as res_file:
        results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))
        json.dump(results, res_file, indent=4, default=_json_default)
|
|
|
|
|
|
def main(k: int):
    """Connect to the database and run the clustering once with the given k.

    The connection is closed whether the run succeeds or fails.
    """
    # NOTE(review): credentials are hard-coded here; move them to
    # environment variables or a config file before sharing this script.
    conn = psycopg.connect(
        "dbname=postgres user=postgres port=5555 host=192.168.1.68 password=music_recommendation_service_postgres",
        row_factory=dict_row,
    )

    try:
        # for k in range(10000, 16000, 1000):
        run(conn, k)
    finally:
        # Bug fix: the connection was previously closed only on an
        # exception, so a successful run leaked it. try/finally closes it
        # on every path and still lets exceptions (incl. KeyboardInterrupt)
        # propagate unchanged.
        conn.close()
        print('Connection closed')
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Single required CLI argument: k, the number of clusters.
    main(int(sys.argv[1]))
|