From d3bb2e1b6bc27620ea1369a9b969bcf00dd8af08 Mon Sep 17 00:00:00 2001 From: norohind <60548839+norohind@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:25:22 +0300 Subject: [PATCH] pure python init --- anaylyse_results.py | 19 +++ main.py | 278 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 297 insertions(+) create mode 100644 anaylyse_results.py create mode 100644 main.py diff --git a/anaylyse_results.py b/anaylyse_results.py new file mode 100644 index 0000000..84f61b0 --- /dev/null +++ b/anaylyse_results.py @@ -0,0 +1,19 @@ +from pprint import pprint +import json +from pathlib import Path +from main import RunResults, Iteration + + +def main(): + for json_file in Path('.').glob('*.json'): + with open(json_file, mode='r') as file: + content = json.load(file) + + k = content['k'] + wss = content['iterations'][-1]['wss'] + delta = content['iterations'][-1]['cent_move_delta'] + print(f"{k}, {wss:.2f}, {delta:.2f}, {len(content['iterations'])}") + + +if __name__ == '__main__': + main() diff --git a/main.py b/main.py new file mode 100644 index 0000000..1e9cd71 --- /dev/null +++ b/main.py @@ -0,0 +1,278 @@ +import datetime +import json +from collections import defaultdict +from dataclasses import dataclass, asdict +from math import sqrt +from pprint import pprint + +import psycopg +from psycopg.rows import dict_row, class_row +from tqdm import tqdm + +IMPORTANT_FIELDS: tuple[str, ...] = ( + 'Tempo', + 'Zcr', + 'MeanSpectralCentroid', + 'StdDevSpectralCentroid', + 'MeanSpectralRolloff', + 'StdDevSpectralRolloff', + 'MeanSpectralFlatness', + 'StdDevSpectralFlatness', + 'MeanLoudness', + 'StdDevLoudness', + 'Chroma1', + 'Chroma2', + 'Chroma3', + 'Chroma4', + 'Chroma5', + 'Chroma6', + 'Chroma7', + 'Chroma8', + 'Chroma9', + 'Chroma10', +) + + +@dataclass(frozen=True) +class UpdateCentroidsRes: + cent_move_delta: float + wss: float # wss - within-cluster-sum of squared errors + + +@dataclass +class EntityBase: + Tempo: float + Zcr: float + MeanSpectralCentroid: float + StdDevSpectralCentroid: float + MeanSpectralRolloff: float + StdDevSpectralRolloff: float + MeanSpectralFlatness: float + StdDevSpectralFlatness: float + MeanLoudness: float + StdDevLoudness: float + Chroma1: float + Chroma2: float + Chroma3: float + Chroma4: float + Chroma5: float + Chroma6: float + Chroma7: float + Chroma8: float + Chroma9: float + Chroma10: float + + def __matmul__(self, other: "EntityBase") -> float: + """Euclidian distance""" + return sqrt( + (self.Tempo - other.Tempo) ** 2 + + (self.Zcr - other.Zcr) ** 2 + + (self.MeanSpectralCentroid - other.MeanSpectralCentroid) ** 2 + + (self.StdDevSpectralCentroid - other.StdDevSpectralCentroid) ** 2 + + (self.MeanSpectralRolloff - other.MeanSpectralRolloff) ** 2 + + (self.StdDevSpectralRolloff - other.StdDevSpectralRolloff) ** 2 + + (self.MeanSpectralFlatness - other.MeanSpectralFlatness) ** 2 + + (self.StdDevSpectralFlatness - other.StdDevSpectralFlatness) ** 2 + + (self.MeanLoudness - other.MeanLoudness) ** 2 + + (self.StdDevLoudness - other.StdDevLoudness) ** 2 + + (self.Chroma1 - other.Chroma1) ** 2 + + (self.Chroma2 - other.Chroma2) ** 2 + + (self.Chroma3 - other.Chroma3) ** 2 + + (self.Chroma4 - other.Chroma4) ** 2 + + (self.Chroma5 - other.Chroma5) ** 2 + + (self.Chroma6 - other.Chroma6) ** 2 + + (self.Chroma7 - other.Chroma7) ** 2 + + (self.Chroma8 - other.Chroma8) ** 2 + + (self.Chroma9 - other.Chroma9) ** 2 + + (self.Chroma10 - other.Chroma10) ** 2 + ) + + +@dataclass +class Centroid(EntityBase): + id: int + + +@dataclass +class Song(EntityBase): + File: str + + def __str__(self): + return self.File + + __repr__ = __str__ + + +@dataclass(frozen=True) +class Iteration: + n: int + wss: float + cent_move_delta: float + + +@dataclass(frozen=True) +class RunResults: + centroids: list[Centroid] + k: int + iterations: list[Iteration] + + +def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]: + final_dict: dict[int, Centroid] = dict() + QUERY = """ +select + row_number() over () as id, * + from ( + select + "Tempo", + "Zcr", + "MeanSpectralCentroid", + "StdDevSpectralCentroid", + "MeanSpectralRolloff", + "StdDevSpectralRolloff", + "MeanSpectralFlatness", + "StdDevSpectralFlatness", + "MeanLoudness", + "StdDevLoudness", + "Chroma1", + "Chroma2", + "Chroma3", + "Chroma4", + "Chroma5", + "Chroma6", + "Chroma7", + "Chroma8", + "Chroma9", + "Chroma10" + FROM bliss_tracks bt + ORDER BY RANDOM() + LIMIT %(k)s +) with_ids;""" + with conn.cursor(row_factory=class_row(Centroid)) as cur: + cur.execute(QUERY, {"k": k}) + for song in cur.fetchall(): + final_dict[song.id] = song + + return final_dict + + +def get_all_songs(conn: psycopg.Connection) -> list[Song]: + QUERY = """ + select + "File", + "Tempo", + "Zcr", + "MeanSpectralCentroid", + "StdDevSpectralCentroid", + "MeanSpectralRolloff", + "StdDevSpectralRolloff", + "MeanSpectralFlatness", + "StdDevSpectralFlatness", + "MeanLoudness", + "StdDevLoudness", + "Chroma1", + "Chroma2", + "Chroma3", + "Chroma4", + "Chroma5", + "Chroma6", + "Chroma7", + "Chroma8", + "Chroma9", + "Chroma10" + FROM bliss_tracks;""" + with conn.cursor(row_factory=class_row(Song)) as cur: + cur.execute(QUERY) + return cur.fetchall() + + +def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]: + centroids_with_songs: dict[int, list[Song]] = defaultdict(list) + + # Now we need to assign songs to centroids + for song in tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1): + # print(f"Song: {song.File}") + + # Need to find the closest centroid + closest_centroid = min(centroids.values(), key=lambda centroid: song @ centroid) + # print(f"We've selected group of centroid: {closest_centroid.id} with distance {closest_centroid @ song}") + centroids_with_songs[closest_centroid.id].append(song) + + return centroids_with_songs + + +def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes: + """Returns total delta""" + total_delta = 0.0 + wss = 0.0 + + for cent_id, cent_songs in tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1): + centroid: Centroid = centroids[cent_id] + + for field_name in IMPORTANT_FIELDS: + old_field_value = getattr(centroid, field_name) + avg_value_of_field = sum(getattr(song, field_name) for song in cent_songs) / len(cent_songs) + setattr(centroid, field_name, avg_value_of_field) + + total_delta += abs(old_field_value - avg_value_of_field) + + wss += sum((centroid @ song) ** 2 for song in cent_songs) + + return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss) + + +def run(conn: psycopg.Connection, k: int): + print(f'Running with k {k}') + centroids = choose_centroids(conn, k=k) + all_songs = get_all_songs(conn) + + iterations: list[Iteration] = list() + + for iteration_num in range(10): + centroids_with_songs = centroids_to_songs(centroids, all_songs) + upd_cent_res = update_centroids(centroids_with_songs, centroids) # Centroids are passed by reference + + iterations.append(Iteration( + n=iteration_num, + cent_move_delta=upd_cent_res.cent_move_delta, + wss=upd_cent_res.wss + )) + + pprint(iterations[-1]) + + if upd_cent_res.cent_move_delta < 0.5: + print("Leaving loop by delta") + break + + else: + print("Leaving loop by attempts exceeding") + + # print("Clusters:") + # for songs in centroids_to_songs(centroids, all_songs).values(): + # if len(songs) > 1: + # pprint(songs) + + with open(f'{datetime.datetime.now().isoformat().replace(":", "-")}.json', mode='w', encoding='utf-8') as res_file: + results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations)) + json.dump(results, res_file, indent=4) + + +def main(k: int): + conn = psycopg.connect( + "dbname=postgres user=postgres port=5555 host=192.168.1.68 password=music_recommendation_service_postgres", + row_factory=dict_row, + ) + + try: + # for k in range(10000, 16000, 1000): + run(conn, k) + + except (Exception, KeyboardInterrupt) as e: + conn.close() + raise e + + +if __name__ == "__main__": + from sys import argv + + main(int(argv[1]))