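"""Naive k-means clustering of songs over their numeric features (IMPORTANT_FIELDS).

Repeatedly assigns every song to its nearest centroid and moves each
centroid to the mean of its assigned songs, stopping once the total
centroid movement falls below a threshold or the iteration budget is
exhausted. The fitted model is dumped to a timestamped JSON file.
"""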

import datetime
import json
from collections import defaultdict
from dataclasses import dataclass, asdict
from pprint import pprint
import numpy as np
import psycopg
from numba import njit
from tqdm import tqdm
import db
from db import choose_centroids, get_all_songs
from domain import Centroid, Song, IMPORTANT_FIELDS
@njit(cache=True)
def dist(x: np.ndarray, y: np.ndarray) -> float:
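    """Euclidean (L2) distance between two equal-length vectors."""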
    return np.sqrt(np.sum(np.square(x - y)))

@dataclass(frozen=True)
class UpdateCentroidsRes:
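    """Result of a single update step: total centroid movement and current WSS."""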
    cent_move_delta: float
    wss: float  # Within-cluster sum of squared errors

@dataclass(frozen=True)
class Iteration:
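    """Stats recorded for one k-means iteration."""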
    n: int
    wss: float
    cent_move_delta: float

@dataclass(frozen=True)
class RunResults: # Aka the model
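    """The fitted model: final centroids, k, and the per-iteration log."""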
    centroids: list[Centroid]
    k: int
    iterations: list[Iteration]
    filename: str = ""

    def __str__(self):
        return f"wss: {self.iterations[-1].wss:.2f} k: {self.k} n: {len(self.iterations)} {self.filename}"

def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
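    """Assignment step: group songs by the id of their nearest centroid."""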
    ldist = dist  # Local reference for faster lookup
    centroids_with_songs: dict[int, list[Song]] = defaultdict(list)
    # Assign each song to its closest centroid
    for song in tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1):
        # Throughput of this progress bar with 10000 centroids:
        # pure python                               15 it/s
        # np.linalg.norm + jit                      20 it/s
        # np.linalg.norm                            21 it/s
        # np.dot                                    40 it/s
        # np.sqrt(np.sum(np.square(x - y))) + jit  150 it/s <- current implementation
        song_np_arr = song.np_arr
        # Find the closest centroid
        closest_centroid = min(centroids.values(), key=lambda centroid: ldist(song_np_arr, centroid.np_arr))
        centroids_with_songs[closest_centroid.id].append(song)
    return centroids_with_songs

def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes:
"""Returns total delta"""
total_delta = 0.0
wss = 0.0
for cent_id, cent_songs in tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1):
centroid: Centroid = centroids[cent_id]
for field_name in IMPORTANT_FIELDS:
old_field_value = getattr(centroid, field_name)
avg_value_of_field = sum(getattr(song, field_name) for song in cent_songs) / len(cent_songs)
setattr(centroid, field_name, avg_value_of_field)
total_delta += abs(old_field_value - avg_value_of_field)
wss += sum((dist(centroid.np_arr, song.np_arr)) ** 2 for song in cent_songs)
return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss)
def run(conn: psycopg.Connection, k: int):
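    """Fit k-means with k clusters on all songs and dump the model to JSON."""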
print(f"Running with k {k}")
centroids = choose_centroids(conn, k=k)
all_songs = get_all_songs(conn)
iterations: list[Iteration] = list()
for iteration_num in range(10):
centroids_with_songs = centroids_to_songs(centroids, all_songs)
upd_cent_res = update_centroids(centroids_with_songs, centroids) # Centroids are passed by reference
iterations.append(
Iteration(
n=iteration_num,
cent_move_delta=upd_cent_res.cent_move_delta,
wss=upd_cent_res.wss,
)
)
pprint(iterations[-1])
if upd_cent_res.cent_move_delta < 0.5:
print("Leaving loop by delta")
break
else:
print("Leaving loop by attempts exceeding")
with open(f'{datetime.datetime.now().isoformat().replace(":", "-")}.json', mode="w", encoding="utf-8") as res_file:
results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))
for cent_d in results["centroids"]:
del cent_d["np_arr"]
json.dump(results, res_file, indent=4)
def main(k: int):
    conn = db.get_conn()
    try:
        # for k in range(10000, 16000, 1000):
        run(conn, k)
    finally:
        # Close the connection on success, failure, and Ctrl-C alike
        conn.close()
        print("Connection closed")

if __name__ == "__main__":
    from sys import argv

    main(int(argv[1]))