pure python init

This commit is contained in:
norohind 2023-12-06 14:25:22 +03:00
commit d3bb2e1b6b
Signed by: norohind
GPG Key ID: 01C3BECC26FB59E1
2 changed files with 297 additions and 0 deletions

19
anaylyse_results.py Normal file
View File

@ -0,0 +1,19 @@
from pprint import pprint
import json
from pathlib import Path
from main import RunResults, Iteration
def main():
    """Print one summary line per JSON result file in the current directory.

    Every ``*.json`` file in the working directory is parsed and reported as
    ``k, wss, cent_move_delta, iteration_count``, where ``wss`` and
    ``cent_move_delta`` are taken from the last entry of ``iterations`` and
    printed with two decimals.

    Raises:
        KeyError / IndexError: if a file lacks the ``k`` or ``iterations``
            keys, or its ``iterations`` list is empty.
        json.JSONDecodeError: if a file is not valid JSON.
    """
    # Path.glob yields entries in arbitrary order; sorting makes the report
    # reproducible between invocations.
    for json_file in sorted(Path('.').glob('*.json')):
        # Explicit encoding instead of the platform's locale default.
        with open(json_file, mode='r', encoding='utf-8') as file:
            content = json.load(file)

        k = content['k']
        iterations = content['iterations']
        last_iteration = iterations[-1]
        wss = last_iteration['wss']
        delta = last_iteration['cent_move_delta']
        print(f"{k}, {wss:.2f}, {delta:.2f}, {len(iterations)}")


if __name__ == '__main__':
    main()

278
main.py Normal file
View File

@ -0,0 +1,278 @@
import datetime
import json
from collections import defaultdict
from dataclasses import dataclass, asdict
from math import sqrt
from pprint import pprint
import psycopg
from psycopg.rows import dict_row, class_row
from tqdm import tqdm
# Names of the numeric feature attributes that update_centroids averages when
# it recomputes a centroid. The order matches the field order of EntityBase.
IMPORTANT_FIELDS: tuple[str, ...] = (
    'Tempo', 'Zcr',
    'MeanSpectralCentroid', 'StdDevSpectralCentroid',
    'MeanSpectralRolloff', 'StdDevSpectralRolloff',
    'MeanSpectralFlatness', 'StdDevSpectralFlatness',
    'MeanLoudness', 'StdDevLoudness',
    # 'Chroma1' .. 'Chroma10'
    *(f'Chroma{index}' for index in range(1, 11)),
)
@dataclass(frozen=True)
class UpdateCentroidsRes:
    """Immutable pair of metrics returned by ``update_centroids``."""

    # Sum of the absolute per-field changes over all updated centroids.
    cent_move_delta: float
    # Within-cluster sum of squared errors (squared song-to-centroid distances).
    wss: float
@dataclass
class EntityBase:
    """A point in the 20-dimensional feature space used for clustering.

    The attribute names mirror the quoted column names selected from
    ``bliss_tracks``, which is why they are not snake_case. ``a @ b`` yields
    the Euclidean distance between two points.
    """

    Tempo: float
    Zcr: float

    MeanSpectralCentroid: float
    StdDevSpectralCentroid: float
    MeanSpectralRolloff: float
    StdDevSpectralRolloff: float
    MeanSpectralFlatness: float
    StdDevSpectralFlatness: float

    MeanLoudness: float
    StdDevLoudness: float

    Chroma1: float
    Chroma2: float
    Chroma3: float
    Chroma4: float
    Chroma5: float
    Chroma6: float
    Chroma7: float
    Chroma8: float
    Chroma9: float
    Chroma10: float

    def __matmul__(self, other: "EntityBase") -> float:
        """Return the Euclidean distance between ``self`` and ``other``.

        Only the twenty feature fields declared on this class take part;
        attributes added by subclasses are ignored.
        """
        # Accumulate the squared differences one field at a time, in
        # declaration order.
        squared = (self.Tempo - other.Tempo) ** 2
        squared += (self.Zcr - other.Zcr) ** 2
        squared += (self.MeanSpectralCentroid - other.MeanSpectralCentroid) ** 2
        squared += (self.StdDevSpectralCentroid - other.StdDevSpectralCentroid) ** 2
        squared += (self.MeanSpectralRolloff - other.MeanSpectralRolloff) ** 2
        squared += (self.StdDevSpectralRolloff - other.StdDevSpectralRolloff) ** 2
        squared += (self.MeanSpectralFlatness - other.MeanSpectralFlatness) ** 2
        squared += (self.StdDevSpectralFlatness - other.StdDevSpectralFlatness) ** 2
        squared += (self.MeanLoudness - other.MeanLoudness) ** 2
        squared += (self.StdDevLoudness - other.StdDevLoudness) ** 2
        squared += (self.Chroma1 - other.Chroma1) ** 2
        squared += (self.Chroma2 - other.Chroma2) ** 2
        squared += (self.Chroma3 - other.Chroma3) ** 2
        squared += (self.Chroma4 - other.Chroma4) ** 2
        squared += (self.Chroma5 - other.Chroma5) ** 2
        squared += (self.Chroma6 - other.Chroma6) ** 2
        squared += (self.Chroma7 - other.Chroma7) ** 2
        squared += (self.Chroma8 - other.Chroma8) ** 2
        squared += (self.Chroma9 - other.Chroma9) ** 2
        squared += (self.Chroma10 - other.Chroma10) ** 2
        return sqrt(squared)
@dataclass
class Centroid(EntityBase):
    """Cluster centre: an ``EntityBase`` feature vector plus an identifier."""

    # Key under which the centroid is stored in the centroid mapping;
    # choose_centroids fills it from the query's row_number().
    id: int
@dataclass
class Song(EntityBase):
    """A track: an ``EntityBase`` feature vector plus its ``File`` value."""

    File: str

    def __repr__(self):
        # Show only the file so printed song collections stay short.
        return self.File

    __str__ = __repr__
@dataclass(frozen=True)
class Iteration:
    """Immutable record of the metrics measured in one clustering pass."""

    # Index of the pass, as produced by the loop in run().
    n: int
    # Within-cluster sum of squared errors after the pass.
    wss: float
    # Total centroid movement reported for the pass.
    cent_move_delta: float
@dataclass(frozen=True)
class RunResults:
    """Immutable summary of one clustering run, serialised to JSON by ``run``."""

    # Centroids as they stand when the run finishes.
    centroids: list[Centroid]
    # Number of centroids that was requested for the run.
    k: int
    # Per-pass metrics, in execution order.
    iterations: list[Iteration]
def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]:
    """Pick up to ``k`` random rows of ``bliss_tracks`` as initial centroids.

    The query shuffles the table with ``ORDER BY RANDOM()``, keeps ``k`` rows
    and numbers them with ``row_number()``; that number becomes both the
    centroid ``id`` and its key in the returned mapping.

    Args:
        conn: Open psycopg connection to query.
        k: Maximum number of centroids to select.

    Returns:
        Mapping of centroid id to ``Centroid``.
    """
    sql = """
select
row_number() over () as id, *
from (
select
"Tempo",
"Zcr",
"MeanSpectralCentroid",
"StdDevSpectralCentroid",
"MeanSpectralRolloff",
"StdDevSpectralRolloff",
"MeanSpectralFlatness",
"StdDevSpectralFlatness",
"MeanLoudness",
"StdDevLoudness",
"Chroma1",
"Chroma2",
"Chroma3",
"Chroma4",
"Chroma5",
"Chroma6",
"Chroma7",
"Chroma8",
"Chroma9",
"Chroma10"
FROM bliss_tracks bt
ORDER BY RANDOM()
LIMIT %(k)s
) with_ids;"""
    # class_row builds each Centroid straight from the row's columns.
    with conn.cursor(row_factory=class_row(Centroid)) as cursor:
        cursor.execute(sql, {"k": k})
        return {centroid.id: centroid for centroid in cursor.fetchall()}
def get_all_songs(conn: psycopg.Connection) -> list[Song]:
    """Load every row of ``bliss_tracks`` as a ``Song``.

    Args:
        conn: Open psycopg connection to query.

    Returns:
        All tracks, each carrying its ``File`` value and feature columns.
    """
    sql = """
select
"File",
"Tempo",
"Zcr",
"MeanSpectralCentroid",
"StdDevSpectralCentroid",
"MeanSpectralRolloff",
"StdDevSpectralRolloff",
"MeanSpectralFlatness",
"StdDevSpectralFlatness",
"MeanLoudness",
"StdDevLoudness",
"Chroma1",
"Chroma2",
"Chroma3",
"Chroma4",
"Chroma5",
"Chroma6",
"Chroma7",
"Chroma8",
"Chroma9",
"Chroma10"
FROM bliss_tracks;"""
    # class_row builds each Song straight from the row's columns.
    with conn.cursor(row_factory=class_row(Song)) as cursor:
        cursor.execute(sql)
        songs = cursor.fetchall()
    return songs
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
    """Group songs by the id of their nearest centroid.

    Distance is measured with the ``@`` operator of ``EntityBase``. Centroids
    that attract no song have no entry in the result.

    Args:
        centroids: Candidate centroids keyed by id.
        songs: Songs to distribute.

    Returns:
        A ``defaultdict`` mapping centroid id to the songs assigned to it.
    """
    assignment: dict[int, list[Song]] = defaultdict(list)
    candidates = centroids.values()

    progress = tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1)
    for track in progress:
        # On ties min() keeps the first candidate in mapping order.
        nearest = min(candidates, key=lambda candidate: track @ candidate)
        assignment[nearest.id].append(track)

    return assignment
def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes:
    """Move each centroid to the mean of its assigned songs, in place.

    For every entry of ``centroids_with_songs`` the matching centroid has each
    field listed in ``IMPORTANT_FIELDS`` replaced by the average of that field
    over the assigned songs. Centroids without an entry are left untouched.

    Args:
        centroids_with_songs: Songs grouped by centroid id.
        centroids: Centroids keyed by id; mutated by this call.

    Returns:
        ``UpdateCentroidsRes`` with ``cent_move_delta`` (sum of absolute
        per-field changes over all updated centroids) and ``wss`` (sum of
        squared distances between each song and its *updated* centroid).

    Raises:
        KeyError: if a centroid id in ``centroids_with_songs`` is missing
            from ``centroids``.
    """
    total_delta = 0.0
    wss = 0.0

    for cent_id, cent_songs in tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1):
        if not cent_songs:
            # An empty group has no mean (it would divide by zero). Treat it
            # like a centroid with no entry at all: it stays where it is and
            # contributes nothing to either metric.
            continue

        centroid: Centroid = centroids[cent_id]
        song_count = len(cent_songs)

        for field_name in IMPORTANT_FIELDS:
            old_field_value = getattr(centroid, field_name)
            avg_value_of_field = sum(getattr(song, field_name) for song in cent_songs) / song_count
            setattr(centroid, field_name, avg_value_of_field)
            total_delta += abs(old_field_value - avg_value_of_field)

        # Measured after the move, i.e. against the new cluster mean.
        wss += sum((centroid @ song) ** 2 for song in cent_songs)

    return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss)
def run(conn: psycopg.Connection, k: int, max_iterations: int = 10, delta_threshold: float = 0.5):
    """Cluster all songs around ``k`` centroids and write the outcome to JSON.

    Random initial centroids are chosen, then songs are repeatedly assigned to
    their nearest centroid and the centroids recomputed. The loop stops when
    the total centroid movement of a pass drops below ``delta_threshold`` or
    after ``max_iterations`` passes. The final centroids, ``k`` and the
    per-pass metrics are written to ``<ISO timestamp>.json`` (colons replaced
    by dashes) in the current working directory.

    Args:
        conn: Open psycopg connection used to load centroids and songs.
        k: Number of centroids to request.
        max_iterations: Upper bound on assign/update passes.
        delta_threshold: A pass whose ``cent_move_delta`` is below this value
            ends the run.
    """
    print(f'Running with k {k}')
    centroids = choose_centroids(conn, k=k)
    all_songs = get_all_songs(conn)

    iterations: list[Iteration] = []
    for iteration_num in range(max_iterations):
        centroids_with_songs = centroids_to_songs(centroids, all_songs)
        # update_centroids mutates the Centroid objects held in `centroids`.
        upd_cent_res = update_centroids(centroids_with_songs, centroids)

        iterations.append(Iteration(
            n=iteration_num,
            cent_move_delta=upd_cent_res.cent_move_delta,
            wss=upd_cent_res.wss,
        ))
        pprint(iterations[-1])

        if upd_cent_res.cent_move_delta < delta_threshold:
            print("Leaving loop by delta")
            break
    else:
        # Reached only when the loop was never broken out of.
        print("Leaving loop by attempts exceeding")

    # Colons are replaced in the timestamp before it is used as a file name.
    file_name = f'{datetime.datetime.now().isoformat().replace(":", "-")}.json'
    with open(file_name, mode='w', encoding='utf-8') as res_file:
        results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))
        json.dump(results, res_file, indent=4)
def main(k: int):
    """Connect to the database, run one clustering with ``k`` centroids, disconnect.

    Args:
        k: Number of centroids to pass to ``run``.
    """
    # NOTE(review): host and credentials are hard-coded in source; consider
    # reading the DSN from the environment or a config file instead.
    conn = psycopg.connect(
        "dbname=postgres user=postgres port=5555 host=192.168.1.68 password=music_recommendation_service_postgres",
        row_factory=dict_row,
    )
    try:
        run(conn, k)
    finally:
        # Close on success as well as on failure (including Ctrl-C); the
        # in-flight exception, if any, propagates with its original traceback.
        conn.close()


if __name__ == "__main__":
    from sys import argv

    main(int(argv[1]))