141 lines
4.3 KiB
Python
141 lines
4.3 KiB
Python
import datetime
|
|
import json
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, asdict
|
|
from pprint import pprint
|
|
|
|
import numpy as np
|
|
import psycopg
|
|
from numba import njit
|
|
from tqdm import tqdm
|
|
|
|
import db
|
|
from db import choose_centroids, get_all_songs
|
|
from domain import Centroid, Song, IMPORTANT_FIELDS
|
|
|
|
|
|
@njit(cache=True)
def dist(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between ``x`` and ``y``.

    The function is compiled by numba (``@njit``) and is evaluated once per
    song/centroid pair, so the arithmetic is kept as plain element-wise
    NumPy operations.
    """
    difference = x - y
    squared = np.square(difference)
    return np.sqrt(np.sum(squared))
|
|
|
|
|
|
@dataclass(frozen=True)
class UpdateCentroidsRes:
    """Immutable summary returned by one centroid-update pass."""

    # Sum of absolute per-field changes across all updated centroids.
    cent_move_delta: float
    # Within-cluster sum of squared errors (WSS).
    wss: float
|
|
|
|
|
|
@dataclass(frozen=True)
class Iteration:
    """Immutable record of the statistics gathered for one loop iteration."""

    # Zero-based index of the iteration.
    n: int
    # Within-cluster sum of squared errors after this iteration.
    wss: float
    # Total centroid movement measured during this iteration.
    cent_move_delta: float
|
|
|
|
|
|
@dataclass(frozen=True)
class RunResults:  # Aka the model
    """Outcome of a clustering run: the centroids plus per-iteration stats.

    Attributes are documented inline. ``str(results)`` gives a one-line
    summary with the last iteration's WSS, ``k``, the number of iterations
    and the filename.
    """

    # Centroids as they stand at the end of the run.
    centroids: list[Centroid]
    # Number of clusters requested for the run.
    k: int
    # One entry per executed iteration, in execution order.
    iterations: list[Iteration]
    # Defaults to ""; not populated by the code visible in this module.
    filename: str = ""

    def __str__(self) -> str:
        """Return a short human-readable summary of the run."""
        # ".2f" is the valid fixed-point spec; the previous ".f2" made every
        # str() call raise ValueError("Invalid format specifier").
        # An empty iteration list is reported as "n/a" instead of raising
        # IndexError on iterations[-1].
        last_wss = f"{self.iterations[-1].wss:.2f}" if self.iterations else "n/a"
        # k is an integer, so it is printed without a float format spec.
        return f"wss: {last_wss} k: {self.k} n: {len(self.iterations)} {self.filename}"
|
|
|
|
|
|
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
    """Group songs by the id of their nearest centroid.

    Every song is compared against every centroid using ``dist`` on their
    ``np_arr`` attributes; on ties the first centroid in iteration order wins.

    Returns:
        A ``defaultdict(list)`` mapping centroid id to the songs assigned to
        it. Centroids that attract no song have no entry.

    Raises:
        ValueError: if ``songs`` is non-empty while ``centroids`` is empty
            (``min`` over an empty sequence).
    """
    distance = dist  # Local alias: cheaper name lookup inside the hot loop
    assignment: dict[int, list[Song]] = defaultdict(list)

    progress = tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1)
    for song in progress:
        # Throughput measured on this progress bar with 10000 centroids:
        #   pure python                              15 it/s
        #   np.linalg.norm + jit                     20 it/s
        #   np.linalg.norm                           21 it/s
        #   np.dot                                   40 it/s
        #   np.sqrt(np.sum(np.square(x - y))) + jit 150 it/s  <- current implementation
        song_vector = song.np_arr

        def distance_to_song(candidate):
            # Distance from the current song to one candidate centroid
            return distance(song_vector, candidate.np_arr)

        nearest = min(centroids.values(), key=distance_to_song)
        assignment[nearest.id].append(song)

    return assignment
|
|
|
|
|
|
def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes:
    """Move each centroid to the mean of its assigned songs, in place.

    For every centroid id in ``centroids_with_songs``, each attribute named in
    ``IMPORTANT_FIELDS`` is overwritten on the centroid with the average of
    that attribute over the centroid's songs.

    Returns:
        UpdateCentroidsRes with
        - ``cent_move_delta``: sum of absolute per-field changes over all
          updated centroids;
        - ``wss``: sum over all songs of the squared ``dist`` between the
          song and its centroid, evaluated after the field update.
    """
    moved_total = 0.0
    squared_error_total = 0.0

    clusters = tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1)
    for cent_id, members in clusters:
        target: Centroid = centroids[cent_id]
        member_count = len(members)

        for field in IMPORTANT_FIELDS:
            previous = getattr(target, field)
            mean_value = sum(getattr(member, field) for member in members) / member_count
            setattr(target, field, mean_value)

            moved_total += abs(previous - mean_value)

        # NOTE(review): assumes target.np_arr reflects the fields just
        # written above - confirm against the Centroid definition.
        cluster_error = sum(dist(target.np_arr, member.np_arr) ** 2 for member in members)
        squared_error_total += cluster_error

    return UpdateCentroidsRes(cent_move_delta=moved_total, wss=squared_error_total)
|
|
|
|
|
|
def run(conn: psycopg.Connection, k: int, *, max_iterations: int = 10, delta_threshold: float = 0.5) -> None:
    """Run the clustering loop for ``k`` centroids and dump the result to JSON.

    Initial centroids and all songs are loaded through ``conn``. Each
    iteration assigns songs to their nearest centroid and then updates the
    centroids; its statistics are printed and recorded. The loop stops early
    once the centroid movement drops below ``delta_threshold``.

    Args:
        conn: open database connection passed to ``choose_centroids`` and
            ``get_all_songs``.
        k: number of centroids to request.
        max_iterations: upper bound on loop iterations (default 10, the
            previously hard-coded value).
        delta_threshold: stop as soon as an iteration's ``cent_move_delta``
            is below this value (default 0.5, the previously hard-coded
            value).

    Side effects:
        Prints progress and writes ``<ISO timestamp with ':' -> '-'>.json``
        into the current working directory.
    """
    print(f"Running with k {k}")
    centroids = choose_centroids(conn, k=k)
    all_songs = get_all_songs(conn)

    iterations: list[Iteration] = []

    for iteration_num in range(max_iterations):
        centroids_with_songs = centroids_to_songs(centroids, all_songs)
        # update_centroids mutates the Centroid objects held in `centroids`
        upd_cent_res = update_centroids(centroids_with_songs, centroids)

        iterations.append(
            Iteration(
                n=iteration_num,
                cent_move_delta=upd_cent_res.cent_move_delta,
                wss=upd_cent_res.wss,
            )
        )

        pprint(iterations[-1])

        if upd_cent_res.cent_move_delta < delta_threshold:
            print("Leaving loop by delta")
            break
    else:
        # for/else: reached only when the loop was never broken out of
        print("Leaving loop by attempts exceeding")

    # Build the payload before touching the filesystem so that a failure
    # here does not leave an empty result file behind.
    results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))
    for cent_d in results["centroids"]:
        # Presumably a numpy array, which json cannot serialise - TODO confirm
        del cent_d["np_arr"]

    # ':' is replaced because it is not allowed in filenames on some systems
    out_name = f'{datetime.datetime.now().isoformat().replace(":", "-")}.json'
    with open(out_name, mode="w", encoding="utf-8") as res_file:
        json.dump(results, res_file, indent=4)
|
|
|
|
|
|
def main(k: int) -> None:
    """Open a database connection, execute ``run`` for ``k``, then close it.

    The connection is closed on every exit path: normal completion, an
    exception, or Ctrl-C. Previously it was closed only when an exception
    occurred and leaked after a successful run. Exceptions still propagate
    to the caller unchanged.
    """
    conn = db.get_conn()

    try:
        # Sweep variant kept for reference:
        # for k in range(10000, 16000, 1000):
        run(conn, k)
    finally:
        conn.close()
        print("Connection closed")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the first command-line argument is k.
    from sys import argv

    # Fail with a readable usage message instead of a bare IndexError
    # traceback when k is not supplied.
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} K")

    main(int(argv[1]))
|