commit d3bb2e1b6bc27620ea1369a9b969bcf00dd8af08
Author: norohind <60548839+norohind@users.noreply.github.com>
Date:   Wed Dec 6 14:25:22 2023 +0300

    pure python init

diff --git a/anaylyse_results.py b/anaylyse_results.py
new file mode 100644
index 0000000..84f61b0
--- /dev/null
+++ b/anaylyse_results.py
@@ -0,0 +1,19 @@
+from pprint import pprint
+import json
+from pathlib import Path
+from main import RunResults, Iteration
+
+
+def main():
+    for json_file in Path('.').glob('*.json'):
+        with open(json_file, mode='r') as file:
+            content = json.load(file)
+
+        k = content['k']
+        wss = content['iterations'][-1]['wss']
+        delta = content['iterations'][-1]['cent_move_delta']
+        print(f"{k}, {wss:.2f}, {delta:.2f}, {len(content['iterations'])}")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..1e9cd71
--- /dev/null
+++ b/main.py
@@ -0,0 +1,278 @@
+import datetime
+import json
+from collections import defaultdict
+from dataclasses import dataclass, asdict
+from math import sqrt
+from pprint import pprint
+
+import psycopg
+from psycopg.rows import dict_row, class_row
+from tqdm import tqdm
+
+IMPORTANT_FIELDS: tuple[str, ...] = (
+    'Tempo',
+    'Zcr',
+    'MeanSpectralCentroid',
+    'StdDevSpectralCentroid',
+    'MeanSpectralRolloff',
+    'StdDevSpectralRolloff',
+    'MeanSpectralFlatness',
+    'StdDevSpectralFlatness',
+    'MeanLoudness',
+    'StdDevLoudness',
+    'Chroma1',
+    'Chroma2',
+    'Chroma3',
+    'Chroma4',
+    'Chroma5',
+    'Chroma6',
+    'Chroma7',
+    'Chroma8',
+    'Chroma9',
+    'Chroma10',
+)
+
+
+@dataclass(frozen=True)
+class UpdateCentroidsRes:
+    cent_move_delta: float
+    wss: float  # wss - within-cluster-sum of squared errors
+
+
+@dataclass
+class EntityBase:
+    Tempo: float
+    Zcr: float
+    MeanSpectralCentroid: float
+    StdDevSpectralCentroid: float
+    MeanSpectralRolloff: float
+    StdDevSpectralRolloff: float
+    MeanSpectralFlatness: float
+    StdDevSpectralFlatness: float
+    MeanLoudness: float
+    StdDevLoudness: float
+    Chroma1: float
+    Chroma2: float
+    Chroma3: float
+    Chroma4: float
+    Chroma5: float
+    Chroma6: float
+    Chroma7: float
+    Chroma8: float
+    Chroma9: float
+    Chroma10: float
+
+    def __matmul__(self, other: "EntityBase") -> float:
+        """Euclidian distance"""
+        return sqrt(
+            (self.Tempo - other.Tempo) ** 2
+            + (self.Zcr - other.Zcr) ** 2
+            + (self.MeanSpectralCentroid - other.MeanSpectralCentroid) ** 2
+            + (self.StdDevSpectralCentroid - other.StdDevSpectralCentroid) ** 2
+            + (self.MeanSpectralRolloff - other.MeanSpectralRolloff) ** 2
+            + (self.StdDevSpectralRolloff - other.StdDevSpectralRolloff) ** 2
+            + (self.MeanSpectralFlatness - other.MeanSpectralFlatness) ** 2
+            + (self.StdDevSpectralFlatness - other.StdDevSpectralFlatness) ** 2
+            + (self.MeanLoudness - other.MeanLoudness) ** 2
+            + (self.StdDevLoudness - other.StdDevLoudness) ** 2
+            + (self.Chroma1 - other.Chroma1) ** 2
+            + (self.Chroma2 - other.Chroma2) ** 2
+            + (self.Chroma3 - other.Chroma3) ** 2
+            + (self.Chroma4 - other.Chroma4) ** 2
+            + (self.Chroma5 - other.Chroma5) ** 2
+            + (self.Chroma6 - other.Chroma6) ** 2
+            + (self.Chroma7 - other.Chroma7) ** 2
+            + (self.Chroma8 - other.Chroma8) ** 2
+            + (self.Chroma9 - other.Chroma9) ** 2
+            + (self.Chroma10 - other.Chroma10) ** 2
+        )
+
+
+@dataclass
+class Centroid(EntityBase):
+    id: int
+
+
+@dataclass
+class Song(EntityBase):
+    File: str
+
+    def __str__(self):
+        return self.File
+
+    __repr__ = __str__
+
+
+@dataclass(frozen=True)
+class Iteration:
+    n: int
+    wss: float
+    cent_move_delta: float
+
+
+@dataclass(frozen=True)
+class RunResults:
+    centroids: list[Centroid]
+    k: int
+    iterations: list[Iteration]
+
+
+def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]:
+    final_dict: dict[int, Centroid] = dict()
+    QUERY = """
+select 
+	row_number() over () as id, *
+	from (	
+		select 
+			"Tempo",
+			"Zcr",
+			"MeanSpectralCentroid",
+			"StdDevSpectralCentroid",
+			"MeanSpectralRolloff",
+			"StdDevSpectralRolloff",
+			"MeanSpectralFlatness",
+			"StdDevSpectralFlatness",
+			"MeanLoudness",
+			"StdDevLoudness",
+			"Chroma1",
+			"Chroma2",
+			"Chroma3",
+			"Chroma4",
+			"Chroma5",
+			"Chroma6",
+			"Chroma7",
+			"Chroma8",
+			"Chroma9",
+			"Chroma10"
+		FROM bliss_tracks bt
+		ORDER BY RANDOM() 
+		LIMIT %(k)s
+) with_ids;"""
+    with conn.cursor(row_factory=class_row(Centroid)) as cur:
+        cur.execute(QUERY, {"k": k})
+        for song in cur.fetchall():
+            final_dict[song.id] = song
+
+        return final_dict
+
+
+def get_all_songs(conn: psycopg.Connection) -> list[Song]:
+    QUERY = """
+    select 
+        "File",
+        "Tempo",
+        "Zcr",
+        "MeanSpectralCentroid",
+        "StdDevSpectralCentroid",
+        "MeanSpectralRolloff",
+        "StdDevSpectralRolloff",
+        "MeanSpectralFlatness",
+        "StdDevSpectralFlatness",
+        "MeanLoudness",
+        "StdDevLoudness",
+        "Chroma1",
+        "Chroma2",
+        "Chroma3",
+        "Chroma4",
+        "Chroma5",
+        "Chroma6",
+        "Chroma7",
+        "Chroma8",
+        "Chroma9",
+        "Chroma10"
+    FROM bliss_tracks;"""
+    with conn.cursor(row_factory=class_row(Song)) as cur:
+        cur.execute(QUERY)
+        return cur.fetchall()
+
+
+def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
+    centroids_with_songs: dict[int, list[Song]] = defaultdict(list)
+
+    # Now we need to assign songs to centroids
+    for song in tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1):
+        # print(f"Song: {song.File}")
+
+        # Need to find the closest centroid
+        closest_centroid = min(centroids.values(), key=lambda centroid: song @ centroid)
+        # print(f"We've selected group of centroid: {closest_centroid.id} with distance {closest_centroid @ song}")
+        centroids_with_songs[closest_centroid.id].append(song)
+
+    return centroids_with_songs
+
+
+def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes:
+    """Returns total delta"""
+    total_delta = 0.0
+    wss = 0.0
+
+    for cent_id, cent_songs in tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1):
+        centroid: Centroid = centroids[cent_id]
+
+        for field_name in IMPORTANT_FIELDS:
+            old_field_value = getattr(centroid, field_name)
+            avg_value_of_field = sum(getattr(song, field_name) for song in cent_songs) / len(cent_songs)
+            setattr(centroid, field_name, avg_value_of_field)
+
+            total_delta += abs(old_field_value - avg_value_of_field)
+
+        wss += sum((centroid @ song) ** 2 for song in cent_songs)
+
+    return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss)
+
+
+def run(conn: psycopg.Connection, k: int):
+    print(f'Running with k {k}')
+    centroids = choose_centroids(conn, k=k)
+    all_songs = get_all_songs(conn)
+
+    iterations: list[Iteration] = list()
+
+    for iteration_num in range(10):
+        centroids_with_songs = centroids_to_songs(centroids, all_songs)
+        upd_cent_res = update_centroids(centroids_with_songs, centroids)  # Centroids are passed by reference
+
+        iterations.append(Iteration(
+            n=iteration_num,
+            cent_move_delta=upd_cent_res.cent_move_delta,
+            wss=upd_cent_res.wss
+        ))
+
+        pprint(iterations[-1])
+
+        if upd_cent_res.cent_move_delta < 0.5:
+            print("Leaving loop by delta")
+            break
+
+    else:
+        print("Leaving loop by attempts exceeding")
+
+    # print("Clusters:")
+    # for songs in centroids_to_songs(centroids, all_songs).values():
+    #     if len(songs) > 1:
+    #         pprint(songs)
+
+    with open(f'{datetime.datetime.now().isoformat().replace(":", "-")}.json', mode='w', encoding='utf-8') as res_file:
+        results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))
+        json.dump(results, res_file, indent=4)
+
+
+def main(k: int):
+    conn = psycopg.connect(
+        "dbname=postgres user=postgres port=5555 host=192.168.1.68 password=music_recommendation_service_postgres",
+        row_factory=dict_row,
+    )
+
+    try:
+        # for k in range(10000, 16000, 1000):
+        run(conn, k)
+
+    except (Exception, KeyboardInterrupt) as e:
+        conn.close()
+        raise e
+
+
+if __name__ == "__main__":
+    from sys import argv
+
+    main(int(argv[1]))