pure python init
This commit is contained in:
commit
d3bb2e1b6b
19
anaylyse_results.py
Normal file
19
anaylyse_results.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from pprint import pprint
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from main import RunResults, Iteration
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Print a one-line summary for every ``*.json`` file in the current directory.

    Each printed line has the form ``k, wss, delta, iteration_count`` where
    ``wss`` and ``delta`` come from the last entry of the file's
    ``iterations`` list, formatted with two decimals.

    Files whose ``iterations`` list is empty are reported on stderr and
    skipped, because there is no final iteration to read the metrics from.

    Raises:
        KeyError: if a file lacks the ``iterations`` key, or a non-empty
            result lacks ``k``, ``wss`` or ``cent_move_delta``.
        json.JSONDecodeError: if a file is not valid JSON.
    """
    import sys

    for json_file in Path('.').glob('*.json'):
        # Read explicitly as UTF-8 instead of relying on the platform default.
        with open(json_file, mode='r', encoding='utf-8') as file:
            content = json.load(file)

        iterations = content['iterations']
        if not iterations:
            print(f"{json_file}: no iterations recorded, skipping", file=sys.stderr)
            continue

        k = content['k']
        last_iteration = iterations[-1]
        wss = last_iteration['wss']
        delta = last_iteration['cent_move_delta']
        print(f"{k}, {wss:.2f}, {delta:.2f}, {len(iterations)}")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: summarise the results files only when run directly.
if __name__ == '__main__':
    main()
|
278
main.py
Normal file
278
main.py
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from math import sqrt
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
import psycopg
|
||||||
|
from psycopg.rows import dict_row, class_row
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Names of the numeric feature attributes that are averaged when centroids
# are recomputed: ten scalar descriptors followed by Chroma1..Chroma10.
IMPORTANT_FIELDS: tuple[str, ...] = (
    'Tempo', 'Zcr',
    'MeanSpectralCentroid', 'StdDevSpectralCentroid',
    'MeanSpectralRolloff', 'StdDevSpectralRolloff',
    'MeanSpectralFlatness', 'StdDevSpectralFlatness',
    'MeanLoudness', 'StdDevLoudness',
    *(f'Chroma{index}' for index in range(1, 11)),
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class UpdateCentroidsRes:
    """Immutable summary produced by one centroid-update pass."""

    # Sum of the absolute per-field changes over all updated centroids.
    cent_move_delta: float
    # Within-cluster sum of squared errors (WSS).
    wss: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EntityBase:
    """Twenty-dimensional feature vector shared by songs and centroids.

    ``a @ b`` is overloaded to return the Euclidean distance between two
    entities over all twenty float fields.
    """

    # Scalar descriptors.
    Tempo: float
    Zcr: float
    MeanSpectralCentroid: float
    StdDevSpectralCentroid: float
    MeanSpectralRolloff: float
    StdDevSpectralRolloff: float
    MeanSpectralFlatness: float
    StdDevSpectralFlatness: float
    MeanLoudness: float
    StdDevLoudness: float
    # Chroma components.
    Chroma1: float
    Chroma2: float
    Chroma3: float
    Chroma4: float
    Chroma5: float
    Chroma6: float
    Chroma7: float
    Chroma8: float
    Chroma9: float
    Chroma10: float

    def __matmul__(self, other: "EntityBase") -> float:
        """Return the Euclidean distance between ``self`` and ``other``."""
        # Squared differences are accumulated in field declaration order,
        # one plain float addition per field.
        squared = (self.Tempo - other.Tempo) ** 2
        squared += (self.Zcr - other.Zcr) ** 2
        squared += (self.MeanSpectralCentroid - other.MeanSpectralCentroid) ** 2
        squared += (self.StdDevSpectralCentroid - other.StdDevSpectralCentroid) ** 2
        squared += (self.MeanSpectralRolloff - other.MeanSpectralRolloff) ** 2
        squared += (self.StdDevSpectralRolloff - other.StdDevSpectralRolloff) ** 2
        squared += (self.MeanSpectralFlatness - other.MeanSpectralFlatness) ** 2
        squared += (self.StdDevSpectralFlatness - other.StdDevSpectralFlatness) ** 2
        squared += (self.MeanLoudness - other.MeanLoudness) ** 2
        squared += (self.StdDevLoudness - other.StdDevLoudness) ** 2
        squared += (self.Chroma1 - other.Chroma1) ** 2
        squared += (self.Chroma2 - other.Chroma2) ** 2
        squared += (self.Chroma3 - other.Chroma3) ** 2
        squared += (self.Chroma4 - other.Chroma4) ** 2
        squared += (self.Chroma5 - other.Chroma5) ** 2
        squared += (self.Chroma6 - other.Chroma6) ** 2
        squared += (self.Chroma7 - other.Chroma7) ** 2
        squared += (self.Chroma8 - other.Chroma8) ** 2
        squared += (self.Chroma9 - other.Chroma9) ** 2
        squared += (self.Chroma10 - other.Chroma10) ** 2
        return sqrt(squared)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Centroid(EntityBase):
    """Cluster centre: an entity feature vector tagged with a numeric id."""

    # Key under which the centroid is stored and its songs are grouped.
    id: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Song(EntityBase):
    """A track: an entity feature vector plus the value of its ``File`` column."""

    File: str

    def __repr__(self) -> str:
        # Defined in the class body, so the dataclass does not generate its
        # own field-by-field repr; a song is shown as its ``File`` value.
        return self.File

    def __str__(self) -> str:
        return self.File
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class Iteration:
    """Immutable record of the metrics observed in one clustering iteration."""

    # Iteration number.
    n: int
    # Within-cluster sum of squared errors reported by the centroid update.
    wss: float
    # Total centroid movement reported by the centroid update.
    cent_move_delta: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class RunResults:
    """Immutable bundle of everything serialised for one clustering run."""

    # Centroids as they stand when the run finishes.
    centroids: list[Centroid]
    # Number of clusters the run was asked for.
    k: int
    # Per-iteration metrics, in the order the iterations were executed.
    iterations: list[Iteration]
|
||||||
|
|
||||||
|
|
||||||
|
def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]:
    """Pick random rows of ``bliss_tracks`` to serve as the initial centroids.

    The database orders the rows randomly, keeps at most ``k`` of them and
    numbers them with ``row_number()``; that number becomes the centroid id.

    Args:
        conn: open psycopg connection used to run the query.
        k: maximum number of rows (centroids) to select.

    Returns:
        Mapping from centroid id to the ``Centroid`` built from that row.
    """
    query = """
    select
        row_number() over () as id, *
    from (
        select
            "Tempo",
            "Zcr",
            "MeanSpectralCentroid",
            "StdDevSpectralCentroid",
            "MeanSpectralRolloff",
            "StdDevSpectralRolloff",
            "MeanSpectralFlatness",
            "StdDevSpectralFlatness",
            "MeanLoudness",
            "StdDevLoudness",
            "Chroma1",
            "Chroma2",
            "Chroma3",
            "Chroma4",
            "Chroma5",
            "Chroma6",
            "Chroma7",
            "Chroma8",
            "Chroma9",
            "Chroma10"
        FROM bliss_tracks bt
        ORDER BY RANDOM()
        LIMIT %(k)s
    ) with_ids;"""

    # class_row makes the cursor yield Centroid instances instead of raw rows.
    with conn.cursor(row_factory=class_row(Centroid)) as cursor:
        cursor.execute(query, {"k": k})
        seeds = cursor.fetchall()

    return {centroid.id: centroid for centroid in seeds}
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_songs(conn: psycopg.Connection) -> list[Song]:
    """Load every row of ``bliss_tracks`` as a ``Song``.

    Args:
        conn: open psycopg connection used to run the query.

    Returns:
        One ``Song`` per table row, holding ``File`` and the twenty
        feature columns.
    """
    query = """
    select
        "File",
        "Tempo",
        "Zcr",
        "MeanSpectralCentroid",
        "StdDevSpectralCentroid",
        "MeanSpectralRolloff",
        "StdDevSpectralRolloff",
        "MeanSpectralFlatness",
        "StdDevSpectralFlatness",
        "MeanLoudness",
        "StdDevLoudness",
        "Chroma1",
        "Chroma2",
        "Chroma3",
        "Chroma4",
        "Chroma5",
        "Chroma6",
        "Chroma7",
        "Chroma8",
        "Chroma9",
        "Chroma10"
    FROM bliss_tracks;"""

    # class_row makes the cursor yield Song instances instead of raw rows.
    with conn.cursor(row_factory=class_row(Song)) as cursor:
        cursor.execute(query)
        songs = cursor.fetchall()

    return songs
|
||||||
|
|
||||||
|
|
||||||
|
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
    """Group every song under the id of its nearest centroid.

    Distance is measured with the entities' ``@`` operator. Centroids that
    attract no song do not appear in the result.

    Args:
        centroids: current centroids keyed by id.
        songs: all songs to assign.

    Returns:
        A ``defaultdict`` mapping centroid id to the songs assigned to it.
    """
    assignment: dict[int, list[Song]] = defaultdict(list)
    candidates = centroids.values()

    progress = tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1)
    for track in progress:
        # Scan all centroids for the one with the smallest distance to the track.
        nearest = min(candidates, key=lambda candidate: track @ candidate)
        assignment[nearest.id].append(track)

    return assignment
|
||||||
|
|
||||||
|
|
||||||
|
def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes:
    """Move each centroid to the mean of its assigned songs, in place.

    Only centroids present in ``centroids_with_songs`` are touched. For each
    of them every field named in ``IMPORTANT_FIELDS`` is replaced by the
    average of that field over the assigned songs.

    Args:
        centroids_with_songs: songs grouped by centroid id.
        centroids: centroids keyed by id; their fields are overwritten.

    Returns:
        The summed absolute per-field movement of all centroids and the
        within-cluster sum of squared distances measured after the update.
    """
    moved = 0.0
    squared_error = 0.0

    groups = tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1)
    for cent_id, members in groups:
        target: Centroid = centroids[cent_id]
        member_count = len(members)

        for name in IMPORTANT_FIELDS:
            previous = getattr(target, name)
            mean = sum(getattr(member, name) for member in members) / member_count
            setattr(target, name, mean)
            moved += abs(previous - mean)

        # Measured against the centroid's new position.
        squared_error += sum((target @ member) ** 2 for member in members)

    return UpdateCentroidsRes(cent_move_delta=moved, wss=squared_error)
|
||||||
|
|
||||||
|
|
||||||
|
def run(conn: psycopg.Connection, k: int, *, max_iterations: int = 10, delta_threshold: float = 0.5):
    """Run one k-means clustering and write its results to a JSON file.

    Initial centroids are chosen at random, then songs are repeatedly
    assigned to their nearest centroid and the centroids are moved to the
    mean of their songs. The loop stops as soon as the total centroid
    movement of an iteration drops below ``delta_threshold`` or after
    ``max_iterations`` iterations, whichever comes first.

    The final centroids, ``k`` and the per-iteration metrics are written to
    ``<ISO timestamp with ':' replaced by '-'>.json`` in the current directory.

    Args:
        conn: open psycopg connection used to load centroids and songs.
        k: number of clusters to request.
        max_iterations: upper bound on assignment/update rounds.
        delta_threshold: movement below which the run counts as converged.
    """
    print(f'Running with k {k}')
    centroids = choose_centroids(conn, k=k)
    all_songs = get_all_songs(conn)

    iterations: list[Iteration] = []

    for iteration_num in range(max_iterations):
        centroids_with_songs = centroids_to_songs(centroids, all_songs)
        # update_centroids mutates the Centroid objects held in ``centroids``.
        upd_cent_res = update_centroids(centroids_with_songs, centroids)

        iterations.append(Iteration(
            n=iteration_num,
            cent_move_delta=upd_cent_res.cent_move_delta,
            wss=upd_cent_res.wss,
        ))
        pprint(iterations[-1])

        if upd_cent_res.cent_move_delta < delta_threshold:
            print("Leaving loop by delta")
            break
    else:
        # Reached only when the loop ran out of iterations without a break.
        print("Leaving loop by attempts exceeding")

    # Serialise before opening the file so that a failure here does not
    # leave an empty results file behind.
    results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))

    # ':' is replaced because it is not allowed in file names on some systems.
    file_name = f'{datetime.datetime.now().isoformat().replace(":", "-")}.json'
    with open(file_name, mode='w', encoding='utf-8') as res_file:
        json.dump(results, res_file, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def main(k: int):
    """Connect to the database, run one clustering for ``k`` clusters, then disconnect.

    The connection is closed whether the run succeeds, fails or is
    interrupted; any exception (including ``KeyboardInterrupt``) propagates
    to the caller unchanged.

    Args:
        k: number of clusters to pass to ``run``.
    """
    # NOTE(review): host, port and password are hard-coded in the source;
    # consider moving them to configuration or the environment.
    conn = psycopg.connect(
        "dbname=postgres user=postgres port=5555 host=192.168.1.68 password=music_recommendation_service_postgres",
        row_factory=dict_row,
    )

    try:
        run(conn, k)
    finally:
        # Previously the connection was closed only on the error path and
        # leaked after a successful run.
        conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: the first command-line argument is the number of clusters.
if __name__ == "__main__":
    from sys import argv

    # Fail with a readable usage message instead of an IndexError traceback
    # when the argument is missing; extra arguments are still ignored.
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} K")

    main(int(argv[1]))
|
Loading…
x
Reference in New Issue
Block a user