
import datetime
import json
from collections import defaultdict
from dataclasses import dataclass, asdict, field
from numba import njit
from pprint import pprint
import numpy as np
import psycopg
from psycopg.rows import dict_row, class_row
from tqdm import tqdm
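# Audio feature columns used as clustering dimensions; the tuple order fixes the layout of each entity's np_arr vector.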
IMPORTANT_FIELDS: tuple[str, ...] = (
'Tempo',
'Zcr',
'MeanSpectralCentroid',
'StdDevSpectralCentroid',
'MeanSpectralRolloff',
'StdDevSpectralRolloff',
'MeanSpectralFlatness',
'StdDevSpectralFlatness',
'MeanLoudness',
'StdDevLoudness',
'Chroma1',
'Chroma2',
'Chroma3',
'Chroma4',
'Chroma5',
'Chroma6',
'Chroma7',
'Chroma8',
'Chroma9',
'Chroma10',
)
class DimensionDescriptor:
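    """Maps a named feature attribute onto its slot in the owning instance's np_arr vector."""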
def __set_name__(self, owner, name):
self.name: str = name
self.idx: int = IMPORTANT_FIELDS.index(name)
def __get__(self, instance: 'EntityBase', owner) -> float:
return instance.np_arr[self.idx]
def __set__(self, instance: 'EntityBase', value: float):
instance.np_arr[self.idx] = value
@njit(cache=True)
def dist(x: np.ndarray, y: np.ndarray) -> float:
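    """Euclidean distance between two feature vectors (compiled with numba's njit)."""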
return np.sqrt(np.sum(np.square(x - y)))
@dataclass(frozen=True)
class UpdateCentroidsRes:
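    """Result of one update step: total centroid movement and within-cluster sum of squared errors."""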
cent_move_delta: float
wss: float # wss - within-cluster-sum of squared errors
@dataclass(kw_only=True)
class EntityBase:
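    """Base entity: every feature attribute below is a DimensionDescriptor view into the shared np_arr vector."""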
np_arr: np.ndarray = field(init=False, repr=False, default_factory=lambda: np.zeros(len(IMPORTANT_FIELDS)))
Tempo: float = DimensionDescriptor()
Zcr: float = DimensionDescriptor()
MeanSpectralCentroid: float = DimensionDescriptor()
StdDevSpectralCentroid: float = DimensionDescriptor()
MeanSpectralRolloff: float = DimensionDescriptor()
StdDevSpectralRolloff: float = DimensionDescriptor()
MeanSpectralFlatness: float = DimensionDescriptor()
StdDevSpectralFlatness: float = DimensionDescriptor()
MeanLoudness: float = DimensionDescriptor()
StdDevLoudness: float = DimensionDescriptor()
Chroma1: float = DimensionDescriptor()
Chroma2: float = DimensionDescriptor()
Chroma3: float = DimensionDescriptor()
Chroma4: float = DimensionDescriptor()
Chroma5: float = DimensionDescriptor()
Chroma6: float = DimensionDescriptor()
Chroma7: float = DimensionDescriptor()
Chroma8: float = DimensionDescriptor()
Chroma9: float = DimensionDescriptor()
Chroma10: float = DimensionDescriptor()
@dataclass(kw_only=True)
class Centroid(EntityBase):
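    """A cluster centroid; id is the row_number() assigned in choose_centroids."""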
id: int
@dataclass(kw_only=True)
class Song(EntityBase):
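    """A track from bliss_tracks, identified by its File path."""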
File: str
def __str__(self):
return self.File
__repr__ = __str__
@dataclass(frozen=True)
class Iteration:
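    """Stats recorded for one k-means iteration."""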
n: int
wss: float
cent_move_delta: float
@dataclass(frozen=True)
class RunResults:
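    """Payload serialized to the results JSON: final centroids, k, and per-iteration stats."""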
centroids: list[Centroid]
k: int
iterations: list[Iteration]
def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]:
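    """Pick k random rows from bliss_tracks as the initial centroids, keyed by a generated id."""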
final_dict: dict[int, Centroid] = dict()
QUERY = """
select
row_number() over () as id, *
from (
select
"Tempo",
"Zcr",
"MeanSpectralCentroid",
"StdDevSpectralCentroid",
"MeanSpectralRolloff",
"StdDevSpectralRolloff",
"MeanSpectralFlatness",
"StdDevSpectralFlatness",
"MeanLoudness",
"StdDevLoudness",
"Chroma1",
"Chroma2",
"Chroma3",
"Chroma4",
"Chroma5",
"Chroma6",
"Chroma7",
"Chroma8",
"Chroma9",
"Chroma10"
FROM bliss_tracks bt
ORDER BY RANDOM()
LIMIT %(k)s
) with_ids;"""
with conn.cursor(row_factory=class_row(Centroid)) as cur:
cur.execute(QUERY, {"k": k})
for song in cur.fetchall():
final_dict[song.id] = song
return final_dict
def get_all_songs(conn: psycopg.Connection) -> list[Song]:
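    """Load the feature vector of every track in bliss_tracks."""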
QUERY = """
select
"File",
"Tempo",
"Zcr",
"MeanSpectralCentroid",
"StdDevSpectralCentroid",
"MeanSpectralRolloff",
"StdDevSpectralRolloff",
"MeanSpectralFlatness",
"StdDevSpectralFlatness",
"MeanLoudness",
"StdDevLoudness",
"Chroma1",
"Chroma2",
"Chroma3",
"Chroma4",
"Chroma5",
"Chroma6",
"Chroma7",
"Chroma8",
"Chroma9",
"Chroma10"
FROM bliss_tracks;"""
with conn.cursor(row_factory=class_row(Song)) as cur:
cur.execute(QUERY)
return cur.fetchall()
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
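    """Assignment step: group songs by the id of their nearest centroid."""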
ldist = dist # Local reference for faster lookup
centroids_with_songs: dict[int, list[Song]] = defaultdict(list)
# Now we need to assign songs to centroids
for song in tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1):
# print(f"Song: {song.File}")
song_np_arr = song.np_arr
# Need to find the closest centroid
closest_centroid = min(centroids.values(), key=lambda centroid: ldist(song_np_arr, centroid.np_arr))
# print(f"We've selected group of centroid: {closest_centroid.id} with distance {closest_centroid @ song}")
centroids_with_songs[closest_centroid.id].append(song)
return centroids_with_songs
def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes:
"""Returns total delta"""
total_delta = 0.0
wss = 0.0
for cent_id, cent_songs in tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1):
centroid: Centroid = centroids[cent_id]
for field_name in IMPORTANT_FIELDS:
old_field_value = getattr(centroid, field_name)
avg_value_of_field = sum(getattr(song, field_name) for song in cent_songs) / len(cent_songs)
setattr(centroid, field_name, avg_value_of_field)
total_delta += abs(old_field_value - avg_value_of_field)
wss += sum((dist(centroid.np_arr, song.np_arr)) ** 2 for song in cent_songs)
return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss)
def run(conn: psycopg.Connection, k: int):
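    """Run k-means for up to 10 iterations (early exit when total centroid movement drops below 0.5)
    and dump the final centroids plus per-iteration stats to a timestamped JSON file."""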
print(f'Running with k {k}')
centroids = choose_centroids(conn, k=k)
all_songs = get_all_songs(conn)
iterations: list[Iteration] = list()
for iteration_num in range(10):
centroids_with_songs = centroids_to_songs(centroids, all_songs)
upd_cent_res = update_centroids(centroids_with_songs, centroids) # Centroids are passed by reference
iterations.append(Iteration(
n=iteration_num,
cent_move_delta=upd_cent_res.cent_move_delta,
wss=upd_cent_res.wss
))
pprint(iterations[-1])
if upd_cent_res.cent_move_delta < 0.5:
print("Leaving loop by delta")
break
else:
print("Leaving loop by attempts exceeding")
# print("Clusters:")
# for songs in centroids_to_songs(centroids, all_songs).values():
# if len(songs) > 1:
# pprint(songs)
with open(f'{datetime.datetime.now().isoformat().replace(":", "-")}.json', mode='w', encoding='utf-8') as res_file:
results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))
json.dump(results, res_file, indent=4)
def main(k: int):
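    """Open the Postgres connection and run a single clustering pass with the given k."""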
conn = psycopg.connect(
"dbname=postgres user=postgres port=5555 host=192.168.1.68 password=music_recommendation_service_postgres",
row_factory=dict_row,
)
    try:
        # for k in range(10000, 16000, 1000):
        run(conn, k)
    finally:
        conn.close()
        print('Connection closed')
if __name__ == "__main__":
from sys import argv
main(int(argv[1]))