pure python init
This commit is contained in:
commit
d3bb2e1b6b
19
anaylyse_results.py
Normal file
19
anaylyse_results.py
Normal file
@ -0,0 +1,19 @@
|
||||
from pprint import pprint
|
||||
import json
|
||||
from pathlib import Path
|
||||
from main import RunResults, Iteration
|
||||
|
||||
|
||||
def main():
    """Print a one-line summary for every JSON result file in the current directory.

    Each line holds, comma-separated: the run's ``k``, the ``wss`` and
    ``cent_move_delta`` of the last recorded iteration (two decimals each),
    and the number of recorded iterations.

    Raises:
        KeyError / IndexError: if a ``*.json`` file lacks the expected keys or
            has an empty ``iterations`` list.
    """
    # Sorted so the report order is deterministic; glob() yields files in
    # arbitrary, platform-dependent order.
    for json_file in sorted(Path('.').glob('*.json')):
        # Decode explicitly as UTF-8 instead of relying on the locale default.
        with open(json_file, mode='r', encoding='utf-8') as file:
            content = json.load(file)

        iterations = content['iterations']
        last_iteration = iterations[-1]

        k = content['k']
        wss = last_iteration['wss']
        delta = last_iteration['cent_move_delta']
        print(f"{k}, {wss:.2f}, {delta:.2f}, {len(iterations)}")
|
||||
|
||||
|
||||
# Script entry point: produce the report only when this file is executed directly.
if __name__ == '__main__':
    main()
|
278
main.py
Normal file
278
main.py
Normal file
@ -0,0 +1,278 @@
|
||||
import datetime
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, asdict
|
||||
from math import sqrt
|
||||
from pprint import pprint
|
||||
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row, class_row
|
||||
from tqdm import tqdm
|
||||
|
||||
# Names of the numeric feature attributes that update_centroids averages when
# it recomputes a centroid. The order matches the field order of EntityBase.
IMPORTANT_FIELDS: tuple[str, ...] = (
    'Tempo',
    'Zcr',
    # Mean/StdDev pairs: MeanSpectralCentroid, StdDevSpectralCentroid, ...
    *(
        f'{statistic}{measure}'
        for measure in ('SpectralCentroid', 'SpectralRolloff', 'SpectralFlatness', 'Loudness')
        for statistic in ('Mean', 'StdDev')
    ),
    # Chroma1 .. Chroma10
    *(f'Chroma{index}' for index in range(1, 11)),
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class UpdateCentroidsRes:
    """Immutable result of one centroid-update pass (see update_centroids)."""

    # Sum of absolute per-field changes over all centroids that were updated.
    cent_move_delta: float
    # Within-cluster sum of squared errors (WSS).
    wss: float
|
||||
|
||||
|
||||
@dataclass
class EntityBase:
    """Feature vector shared by songs and centroids.

    The ``@`` operator between two instances yields the Euclidean distance
    over all twenty feature fields declared here.
    """

    Tempo: float
    Zcr: float
    MeanSpectralCentroid: float
    StdDevSpectralCentroid: float
    MeanSpectralRolloff: float
    StdDevSpectralRolloff: float
    MeanSpectralFlatness: float
    StdDevSpectralFlatness: float
    MeanLoudness: float
    StdDevLoudness: float
    Chroma1: float
    Chroma2: float
    Chroma3: float
    Chroma4: float
    Chroma5: float
    Chroma6: float
    Chroma7: float
    Chroma8: float
    Chroma9: float
    Chroma10: float

    def __matmul__(self, other: "EntityBase") -> float:
        """Return the Euclidean distance between ``self`` and ``other``."""
        # Kept unrolled (no getattr loop): this is evaluated for every
        # song/centroid pair. Terms are accumulated in declaration order.
        total = (self.Tempo - other.Tempo) ** 2
        total += (self.Zcr - other.Zcr) ** 2
        total += (self.MeanSpectralCentroid - other.MeanSpectralCentroid) ** 2
        total += (self.StdDevSpectralCentroid - other.StdDevSpectralCentroid) ** 2
        total += (self.MeanSpectralRolloff - other.MeanSpectralRolloff) ** 2
        total += (self.StdDevSpectralRolloff - other.StdDevSpectralRolloff) ** 2
        total += (self.MeanSpectralFlatness - other.MeanSpectralFlatness) ** 2
        total += (self.StdDevSpectralFlatness - other.StdDevSpectralFlatness) ** 2
        total += (self.MeanLoudness - other.MeanLoudness) ** 2
        total += (self.StdDevLoudness - other.StdDevLoudness) ** 2
        total += (self.Chroma1 - other.Chroma1) ** 2
        total += (self.Chroma2 - other.Chroma2) ** 2
        total += (self.Chroma3 - other.Chroma3) ** 2
        total += (self.Chroma4 - other.Chroma4) ** 2
        total += (self.Chroma5 - other.Chroma5) ** 2
        total += (self.Chroma6 - other.Chroma6) ** 2
        total += (self.Chroma7 - other.Chroma7) ** 2
        total += (self.Chroma8 - other.Chroma8) ** 2
        total += (self.Chroma9 - other.Chroma9) ** 2
        total += (self.Chroma10 - other.Chroma10) ** 2
        return sqrt(total)
|
||||
|
||||
|
||||
@dataclass
class Centroid(EntityBase):
    """Cluster centre: an EntityBase feature vector plus a numeric identifier."""

    # Populated from the ``id`` column selected in choose_centroids; used as
    # the key of the centroid and cluster dictionaries.
    id: int
|
||||
|
||||
|
||||
@dataclass
class Song(EntityBase):
    """A track: an EntityBase feature vector plus the value of its ``File`` column."""

    File: str

    def __repr__(self):
        # Show only the file value; the twenty feature fields are omitted.
        return self.File

    # str() and repr() intentionally produce the same text.
    __str__ = __repr__
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Iteration:
    """Immutable metrics recorded for a single clustering iteration."""

    # Zero-based iteration number (taken from the loop counter in run).
    n: int
    # Within-cluster sum of squared errors measured in this iteration.
    wss: float
    # Total centroid movement measured in this iteration.
    cent_move_delta: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunResults:
    """Everything one clustering run writes to its JSON result file."""

    # Final centroid positions.
    centroids: list[Centroid]
    # Number of centroids requested for the run.
    k: int
    # Per-iteration metrics, in execution order.
    iterations: list[Iteration]
|
||||
|
||||
|
||||
def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]:
    """Pick up to ``k`` random rows of ``bliss_tracks`` as initial centroids.

    Args:
        conn: Open database connection.
        k: Maximum number of rows to select (bound to the query's ``LIMIT``).

    Returns:
        Mapping from the ``id`` assigned by ``row_number()`` in the query to
        the ``Centroid`` built from that row.
    """
    sql = """
        select
            row_number() over () as id, *
        from (
            select
                "Tempo",
                "Zcr",
                "MeanSpectralCentroid",
                "StdDevSpectralCentroid",
                "MeanSpectralRolloff",
                "StdDevSpectralRolloff",
                "MeanSpectralFlatness",
                "StdDevSpectralFlatness",
                "MeanLoudness",
                "StdDevLoudness",
                "Chroma1",
                "Chroma2",
                "Chroma3",
                "Chroma4",
                "Chroma5",
                "Chroma6",
                "Chroma7",
                "Chroma8",
                "Chroma9",
                "Chroma10"
            FROM bliss_tracks bt
            ORDER BY RANDOM()
            LIMIT %(k)s
        ) with_ids;"""

    # class_row makes the cursor return Centroid instances directly.
    with conn.cursor(row_factory=class_row(Centroid)) as cur:
        cur.execute(sql, {"k": k})
        return {centroid.id: centroid for centroid in cur.fetchall()}
|
||||
|
||||
|
||||
def get_all_songs(conn: psycopg.Connection) -> list[Song]:
    """Load every row of ``bliss_tracks`` as a ``Song``.

    Args:
        conn: Open database connection.

    Returns:
        One ``Song`` per row, holding the ``File`` value and all feature columns.
    """
    sql = """
        select
            "File",
            "Tempo",
            "Zcr",
            "MeanSpectralCentroid",
            "StdDevSpectralCentroid",
            "MeanSpectralRolloff",
            "StdDevSpectralRolloff",
            "MeanSpectralFlatness",
            "StdDevSpectralFlatness",
            "MeanLoudness",
            "StdDevLoudness",
            "Chroma1",
            "Chroma2",
            "Chroma3",
            "Chroma4",
            "Chroma5",
            "Chroma6",
            "Chroma7",
            "Chroma8",
            "Chroma9",
            "Chroma10"
        FROM bliss_tracks;"""

    # class_row makes the cursor return Song instances directly.
    with conn.cursor(row_factory=class_row(Song)) as cur:
        cur.execute(sql)
        all_rows = cur.fetchall()
    return all_rows
|
||||
|
||||
|
||||
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
    """Group songs by the id of their nearest centroid.

    Args:
        centroids: Current centroids keyed by id.
        songs: Songs to assign.

    Returns:
        Mapping from centroid id to the songs closest to that centroid.
        Centroids that attract no song have no entry.
    """
    clusters: dict[int, list[Song]] = defaultdict(list)
    candidates = centroids.values()

    progress = tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1)
    for song in progress:
        # ``song.__matmul__(c)`` is ``song @ c``, the distance to candidate ``c``;
        # on ties min() keeps the first candidate in dict order.
        nearest = min(candidates, key=song.__matmul__)
        clusters[nearest.id].append(song)

    return clusters
|
||||
|
||||
|
||||
def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dict[int, Centroid]) -> UpdateCentroidsRes:
    """Move each centroid to the mean of its assigned songs, in place.

    Args:
        centroids_with_songs: Songs grouped by centroid id.
        centroids: Centroids keyed by id; the objects are mutated.

    Returns:
        The accumulated absolute per-field movement of all updated centroids
        and the within-cluster sum of squared distances, measured against the
        already-updated centroid positions.
    """
    moved = 0.0
    squared_error = 0.0

    progress = tqdm(centroids_with_songs.items(), desc="Updating centroids", disable=False, delay=1)
    for cent_id, members in progress:
        target: Centroid = centroids[cent_id]
        member_count = len(members)

        for field_name in IMPORTANT_FIELDS:
            previous = getattr(target, field_name)
            mean = sum(getattr(song, field_name) for song in members) / member_count
            setattr(target, field_name, mean)
            moved += abs(previous - mean)

        # Computed after the move, so distances are to the new position.
        squared_error += sum((target @ song) ** 2 for song in members)

    return UpdateCentroidsRes(cent_move_delta=moved, wss=squared_error)
|
||||
|
||||
|
||||
def run(conn: psycopg.Connection, k: int, max_iterations: int = 10, delta_threshold: float = 0.5):
    """Run one clustering pass and write its results to a JSON file.

    Picks ``k`` random initial centroids, then alternates between assigning
    songs to their nearest centroid and recomputing centroids until the total
    centroid movement drops below ``delta_threshold`` or ``max_iterations``
    iterations have run. The final centroids, ``k`` and per-iteration metrics
    are written to ``<ISO timestamp with ':' replaced by '-'>.json`` in the
    current directory.

    Args:
        conn: Open database connection.
        k: Number of centroids to request.
        max_iterations: Upper bound on assign/update rounds (previously the
            hard-coded value 10).
        delta_threshold: Stop once an iteration's ``cent_move_delta`` is
            below this value (previously the hard-coded value 0.5).
    """
    print(f'Running with k {k}')
    centroids = choose_centroids(conn, k=k)
    all_songs = get_all_songs(conn)

    iterations: list[Iteration] = []

    for iteration_num in range(max_iterations):
        centroids_with_songs = centroids_to_songs(centroids, all_songs)
        # update_centroids mutates the Centroid objects held in ``centroids``.
        upd_cent_res = update_centroids(centroids_with_songs, centroids)

        iterations.append(Iteration(
            n=iteration_num,
            cent_move_delta=upd_cent_res.cent_move_delta,
            wss=upd_cent_res.wss,
        ))
        pprint(iterations[-1])

        if upd_cent_res.cent_move_delta < delta_threshold:
            print("Leaving loop by delta")
            break
    else:
        # for/else: reached only when the loop finished without ``break``.
        print("Leaving loop by attempts exceeding")

    # ':' is replaced because it is not allowed in file names on some systems.
    result_path = f'{datetime.datetime.now().isoformat().replace(":", "-")}.json'
    with open(result_path, mode='w', encoding='utf-8') as res_file:
        results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))
        json.dump(results, res_file, indent=4)
|
||||
|
||||
|
||||
def main(k: int):
    """Connect to the database, run one clustering pass, and close the connection.

    Args:
        k: Number of centroids passed on to ``run``.

    Raises:
        Whatever ``psycopg.connect`` or ``run`` raises; the connection is
        closed before the exception propagates.
    """
    # NOTE(review): the connection settings, including the password, are
    # hard-coded in source; consider loading them from configuration.
    conn = psycopg.connect(
        "dbname=postgres user=postgres port=5555 host=192.168.1.68 password=music_recommendation_service_postgres",
        row_factory=dict_row,
    )

    try:
        run(conn, k)
    finally:
        # Close on success as well as on any exception (including
        # KeyboardInterrupt); previously the success path left it open.
        conn.close()
|
||||
|
||||
|
||||
# Script entry point: expects the number of centroids K as the only argument.
if __name__ == "__main__":
    from sys import argv, exit as _exit

    # Report a usage error instead of an IndexError/ValueError traceback.
    if len(argv) < 2:
        _exit(f"usage: {argv[0]} K")
    try:
        requested_k = int(argv[1])
    except ValueError:
        _exit(f"K must be an integer, got {argv[1]!r}")

    main(requested_k)
|
Loading…
x
Reference in New Issue
Block a user