Split to files, create inference.py, analyse_results.py
This commit is contained in:
parent
abc4251c7c
commit
805eab808b
@ -1,19 +1,43 @@
|
||||
from pprint import pprint
|
||||
import json
|
||||
from pathlib import Path
|
||||
from main import RunResults, Iteration
|
||||
from main import RunResults, Iteration, Centroid
|
||||
import matplotlib.pyplot as plt
|
||||
from persistance import run_result_from_dict
|
||||
|
||||
|
||||
def process(runs: list[RunResults]) -> None:
    """Scatter-plot the final wss of every run against its k.

    One point is drawn per run: x is ``run.k`` and y is the ``wss`` of the
    run's last iteration. The figure is displayed with ``plt.show()``.

    Args:
        runs: Runs to plot. The list itself is left untouched.

    Raises:
        IndexError: If a run has no iterations.
    """
    # Sort a copy: plotting should not reorder the caller's list.
    ordered = sorted(runs, key=lambda run: run.k)
    ks = tuple(run.k for run in ordered)
    final_wss = tuple(run.iterations[-1].wss for run in ordered)

    plt.scatter(ks, final_wss)
    plt.ylabel("wss")
    plt.xlabel("K")
    # To label every point with its k:
    # for i_x, i_y in zip(ks, final_wss):
    #     plt.text(i_x, i_y, '{}'.format(i_x))

    plt.show()
|
||||
|
||||
|
||||
# def process(runs: list[RunResults]) -> None:
|
||||
# for run in runs:
|
||||
# if run.k == 1000:
|
||||
# print(run.k, run.iterations[-1].wss)
|
||||
|
||||
|
||||
def main():
|
||||
for json_file in Path('.').glob('*.json'):
|
||||
with open(json_file, mode='r') as file:
|
||||
content = json.load(file)
|
||||
runs = list()
|
||||
for json_file in Path(".").glob("*.json"):
|
||||
with open(json_file, mode="r") as file:
|
||||
try:
|
||||
runs.append(run_result_from_dict(json.load(file), str(json_file)))
|
||||
|
||||
k = content['k']
|
||||
wss = content['iterations'][-1]['wss']
|
||||
delta = content['iterations'][-1]['cent_move_delta']
|
||||
print(f"{k}, {wss:.2f}, {delta:.2f}, {len(content['iterations'])}")
|
||||
except Exception as e:
|
||||
print(json_file)
|
||||
raise e
|
||||
|
||||
process(runs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
80
db.py
Normal file
80
db.py
Normal file
@ -0,0 +1,80 @@
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row, class_row
|
||||
from domain import Centroid, Song
|
||||
|
||||
|
||||
def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]:
    """Sample up to ``k`` random rows of ``bliss_tracks`` as centroids.

    The inner query shuffles the table with ``ORDER BY RANDOM()`` and keeps
    ``k`` rows; the outer query numbers them with ``row_number()`` so every
    row carries an ``id``. Rows are built as ``Centroid`` objects through
    psycopg's ``class_row`` factory.

    Args:
        conn: psycopg connection the query is executed on.
        k: Value bound to the query's ``LIMIT``.

    Returns:
        Mapping from ``centroid.id`` to the ``Centroid`` built from that row.
    """
    sql = """
    select
        row_number() over () as id, *
    from (
        select
            "Tempo",
            "Zcr",
            "MeanSpectralCentroid",
            "StdDevSpectralCentroid",
            "MeanSpectralRolloff",
            "StdDevSpectralRolloff",
            "MeanSpectralFlatness",
            "StdDevSpectralFlatness",
            "MeanLoudness",
            "StdDevLoudness",
            "Chroma1",
            "Chroma2",
            "Chroma3",
            "Chroma4",
            "Chroma5",
            "Chroma6",
            "Chroma7",
            "Chroma8",
            "Chroma9",
            "Chroma10"
        FROM bliss_tracks bt
        ORDER BY RANDOM()
        LIMIT %(k)s
    ) with_ids;"""
    with conn.cursor(row_factory=class_row(Centroid)) as cur:
        cur.execute(sql, {"k": k})
        return {centroid.id: centroid for centroid in cur.fetchall()}
|
||||
|
||||
|
||||
def get_all_songs(conn: psycopg.Connection) -> list[Song]:
    """Fetch every row of ``bliss_tracks`` as a ``Song``.

    Selects the ``File`` column plus the twenty feature columns and lets
    psycopg's ``class_row`` factory build one ``Song`` per row.

    Args:
        conn: psycopg connection the query is executed on.

    Returns:
        All songs returned by the query, as a list.
    """
    sql = """
    select
        "File",
        "Tempo",
        "Zcr",
        "MeanSpectralCentroid",
        "StdDevSpectralCentroid",
        "MeanSpectralRolloff",
        "StdDevSpectralRolloff",
        "MeanSpectralFlatness",
        "StdDevSpectralFlatness",
        "MeanLoudness",
        "StdDevLoudness",
        "Chroma1",
        "Chroma2",
        "Chroma3",
        "Chroma4",
        "Chroma5",
        "Chroma6",
        "Chroma7",
        "Chroma8",
        "Chroma9",
        "Chroma10"
    FROM bliss_tracks;"""
    with conn.cursor(row_factory=class_row(Song)) as cur:
        cur.execute(sql)
        songs = cur.fetchall()
    return songs
|
||||
|
||||
|
||||
def get_conn() -> psycopg.Connection:
    """Open a connection to the PostgreSQL database used by this project.

    The password is read from the ``PGPASSWORD`` environment variable; the
    other connection settings are fixed here.

    Returns:
        A new psycopg connection.

    Raises:
        KeyError: If ``PGPASSWORD`` is not set in the environment.
    """
    from os import environ

    # Settings are passed as keyword arguments rather than interpolated into a
    # conninfo string: psycopg quotes each value itself, so a password holding
    # spaces, quotes or backslashes cannot corrupt the connection string.
    return psycopg.connect(
        dbname="postgres",
        user="postgres",
        port=5555,
        host="192.168.1.68",
        password=environ["PGPASSWORD"],
    )
|
80
domain.py
Normal file
80
domain.py
Normal file
@ -0,0 +1,80 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Names of the numeric track features. A name's position in this tuple is the
# index of its value inside an entity's ``np_arr``.
IMPORTANT_FIELDS: tuple[str, ...] = (
    "Tempo",
    "Zcr",
    "MeanSpectralCentroid",
    "StdDevSpectralCentroid",
    "MeanSpectralRolloff",
    "StdDevSpectralRolloff",
    "MeanSpectralFlatness",
    "StdDevSpectralFlatness",
    "MeanLoudness",
    "StdDevLoudness",
    *(f"Chroma{number}" for number in range(1, 11)),
)


class DimensionDescriptor:
    """Expose one slot of ``instance.np_arr`` as a named attribute.

    The attribute name the descriptor is bound to must be listed in
    ``IMPORTANT_FIELDS``; its position there selects the array slot.
    """

    name: str
    idx: int

    def __set_name__(self, owner, name):
        """Remember the attribute name and resolve its array index.

        Raises:
            ValueError: If ``name`` is not in ``IMPORTANT_FIELDS``.
        """
        self.name = name
        self.idx = IMPORTANT_FIELDS.index(name)

    def __get__(self, instance: "EntityBase", owner) -> float:
        """Return the value stored at this descriptor's slot of ``np_arr``.

        On class-level access ``instance`` is ``None``, so the lookup raises
        ``AttributeError``. ``dataclasses`` interprets that as "this field has
        no default", which is what keeps the descriptor-backed fields of
        ``EntityBase`` required; do not return ``self`` here.
        """
        return instance.np_arr[self.idx]

    def __set__(self, instance: "EntityBase", value: float):
        """Store ``value`` at this descriptor's slot of ``np_arr``."""
        instance.np_arr[self.idx] = value
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class EntityBase:
    """Base for entities described by the ``IMPORTANT_FIELDS`` features.

    Every feature is a keyword-only constructor argument backed by a
    ``DimensionDescriptor``, so reading or writing e.g. ``entity.Tempo``
    reads or writes the matching slot of ``entity.np_arr``.
    """

    # Declared first so the array already exists when the generated __init__
    # assigns the descriptor-backed fields below, which write into it.
    # compare=False keeps the array out of the generated __eq__: comparing two
    # ndarrays yields an array, whose truth value is ambiguous and raises
    # ValueError. The feature fields below still take part in equality and
    # carry the same values.
    np_arr: np.ndarray = field(
        init=False,
        repr=False,
        compare=False,
        default_factory=lambda: np.zeros(len(IMPORTANT_FIELDS)),
    )
    Tempo: float = DimensionDescriptor()
    Zcr: float = DimensionDescriptor()
    MeanSpectralCentroid: float = DimensionDescriptor()
    StdDevSpectralCentroid: float = DimensionDescriptor()
    MeanSpectralRolloff: float = DimensionDescriptor()
    StdDevSpectralRolloff: float = DimensionDescriptor()
    MeanSpectralFlatness: float = DimensionDescriptor()
    StdDevSpectralFlatness: float = DimensionDescriptor()
    MeanLoudness: float = DimensionDescriptor()
    StdDevLoudness: float = DimensionDescriptor()
    Chroma1: float = DimensionDescriptor()
    Chroma2: float = DimensionDescriptor()
    Chroma3: float = DimensionDescriptor()
    Chroma4: float = DimensionDescriptor()
    Chroma5: float = DimensionDescriptor()
    Chroma6: float = DimensionDescriptor()
    Chroma7: float = DimensionDescriptor()
    Chroma8: float = DimensionDescriptor()
    Chroma9: float = DimensionDescriptor()
    Chroma10: float = DimensionDescriptor()
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class Centroid(EntityBase):
    """A cluster centre: the ``EntityBase`` features plus an integer ``id``."""

    id: int
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class Song(EntityBase):
    """A track: the ``EntityBase`` features plus its ``File`` value."""

    File: str

    def __str__(self):
        """Return the song's ``File`` value."""
        return self.File

    # Reuse __str__ for repr so containers of songs print as plain file names;
    # dataclass does not replace a __repr__ defined in the class body.
    __repr__ = __str__
|
36
inference.py
Normal file
36
inference.py
Normal file
@ -0,0 +1,36 @@
|
||||
import persistance
|
||||
from main import RunResults, centroids_to_songs
|
||||
from db import get_all_songs, get_conn
|
||||
from domain import Song
|
||||
from contextlib import closing
|
||||
|
||||
# Load model and show songs assigned to clusters
|
||||
|
||||
|
||||
def inference(
    model: RunResults,
    target_file: str = "/music/music_mirror/PoH/Green Day - 21st Century Breakdown.mp3",
) -> None:
    """Print every cluster of ``model`` that contains ``target_file``.

    All songs are loaded from the database and assigned to the model's
    centroids with ``centroids_to_songs``. Each resulting cluster that holds a
    song whose ``File`` equals ``target_file`` is pretty-printed.

    Args:
        model: Run results whose centroids define the clusters.
        target_file: ``File`` value of the song to look for. Defaults to the
            track this script was written to inspect.
    """
    from pprint import pprint

    # The connection is only needed to load the songs; close it right after.
    with closing(get_conn()) as conn:
        songs: list[Song] = get_all_songs(conn)

    centroids = {cent.id: cent for cent in model.centroids}
    clusters = centroids_to_songs(centroids, songs)
    for cluster in clusters.values():
        if any(song.File == target_file for song in cluster):
            pprint(cluster)
|
||||
|
||||
|
||||
def main():
    """Load the model named on the command line and run ``inference`` on it."""
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Run model")
    cli.add_argument("filename")
    filename = cli.parse_args().filename

    inference(persistance.load_model(filename))


if __name__ == "__main__":
    main()
|
196
main.py
196
main.py
@ -1,50 +1,17 @@
|
||||
import datetime
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from numba import jit, njit
|
||||
from math import sqrt
|
||||
from dataclasses import dataclass, asdict
|
||||
from pprint import pprint
|
||||
|
||||
import numpy as np
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row, class_row
|
||||
from numba import njit
|
||||
from tqdm import tqdm
|
||||
|
||||
IMPORTANT_FIELDS: tuple[str, ...] = (
|
||||
'Tempo',
|
||||
'Zcr',
|
||||
'MeanSpectralCentroid',
|
||||
'StdDevSpectralCentroid',
|
||||
'MeanSpectralRolloff',
|
||||
'StdDevSpectralRolloff',
|
||||
'MeanSpectralFlatness',
|
||||
'StdDevSpectralFlatness',
|
||||
'MeanLoudness',
|
||||
'StdDevLoudness',
|
||||
'Chroma1',
|
||||
'Chroma2',
|
||||
'Chroma3',
|
||||
'Chroma4',
|
||||
'Chroma5',
|
||||
'Chroma6',
|
||||
'Chroma7',
|
||||
'Chroma8',
|
||||
'Chroma9',
|
||||
'Chroma10',
|
||||
)
|
||||
|
||||
|
||||
class DimensionDescriptor:
|
||||
def __set_name__(self, owner, name):
|
||||
self.name: str = name
|
||||
self.idx: int = IMPORTANT_FIELDS.index(name)
|
||||
|
||||
def __get__(self, instance: 'EntityBase', owner) -> float:
|
||||
return instance.np_arr[self.idx]
|
||||
|
||||
def __set__(self, instance: 'EntityBase', value: float):
|
||||
instance.np_arr[self.idx] = value
|
||||
import db
|
||||
from db import choose_centroids, get_all_songs
|
||||
from domain import Centroid, Song, IMPORTANT_FIELDS
|
||||
|
||||
|
||||
@njit(cache=True)
|
||||
@ -58,46 +25,6 @@ class UpdateCentroidsRes:
|
||||
wss: float # wss - within-cluster-sum of squared errors
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class EntityBase:
|
||||
np_arr: np.ndarray = field(init=False, repr=False, default_factory=lambda: np.zeros(len(IMPORTANT_FIELDS)))
|
||||
Tempo: float = DimensionDescriptor()
|
||||
Zcr: float = DimensionDescriptor()
|
||||
MeanSpectralCentroid: float = DimensionDescriptor()
|
||||
StdDevSpectralCentroid: float = DimensionDescriptor()
|
||||
MeanSpectralRolloff: float = DimensionDescriptor()
|
||||
StdDevSpectralRolloff: float = DimensionDescriptor()
|
||||
MeanSpectralFlatness: float = DimensionDescriptor()
|
||||
StdDevSpectralFlatness: float = DimensionDescriptor()
|
||||
MeanLoudness: float = DimensionDescriptor()
|
||||
StdDevLoudness: float = DimensionDescriptor()
|
||||
Chroma1: float = DimensionDescriptor()
|
||||
Chroma2: float = DimensionDescriptor()
|
||||
Chroma3: float = DimensionDescriptor()
|
||||
Chroma4: float = DimensionDescriptor()
|
||||
Chroma5: float = DimensionDescriptor()
|
||||
Chroma6: float = DimensionDescriptor()
|
||||
Chroma7: float = DimensionDescriptor()
|
||||
Chroma8: float = DimensionDescriptor()
|
||||
Chroma9: float = DimensionDescriptor()
|
||||
Chroma10: float = DimensionDescriptor()
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Centroid(EntityBase):
|
||||
id: int
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Song(EntityBase):
|
||||
File: str
|
||||
|
||||
def __str__(self):
|
||||
return self.File
|
||||
|
||||
__repr__ = __str__
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Iteration:
|
||||
n: int
|
||||
@ -106,79 +33,14 @@ class Iteration:
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RunResults:
|
||||
class RunResults: # Aka the model
|
||||
centroids: list[Centroid]
|
||||
k: int
|
||||
iterations: list[Iteration]
|
||||
filename: str = ""
|
||||
|
||||
|
||||
def choose_centroids(conn: psycopg.Connection, k: int = 1000) -> dict[int, Centroid]:
|
||||
final_dict: dict[int, Centroid] = dict()
|
||||
QUERY = """
|
||||
select
|
||||
row_number() over () as id, *
|
||||
from (
|
||||
select
|
||||
"Tempo",
|
||||
"Zcr",
|
||||
"MeanSpectralCentroid",
|
||||
"StdDevSpectralCentroid",
|
||||
"MeanSpectralRolloff",
|
||||
"StdDevSpectralRolloff",
|
||||
"MeanSpectralFlatness",
|
||||
"StdDevSpectralFlatness",
|
||||
"MeanLoudness",
|
||||
"StdDevLoudness",
|
||||
"Chroma1",
|
||||
"Chroma2",
|
||||
"Chroma3",
|
||||
"Chroma4",
|
||||
"Chroma5",
|
||||
"Chroma6",
|
||||
"Chroma7",
|
||||
"Chroma8",
|
||||
"Chroma9",
|
||||
"Chroma10"
|
||||
FROM bliss_tracks bt
|
||||
ORDER BY RANDOM()
|
||||
LIMIT %(k)s
|
||||
) with_ids;"""
|
||||
with conn.cursor(row_factory=class_row(Centroid)) as cur:
|
||||
cur.execute(QUERY, {"k": k})
|
||||
for song in cur.fetchall():
|
||||
final_dict[song.id] = song
|
||||
|
||||
return final_dict
|
||||
|
||||
|
||||
def get_all_songs(conn: psycopg.Connection) -> list[Song]:
|
||||
QUERY = """
|
||||
select
|
||||
"File",
|
||||
"Tempo",
|
||||
"Zcr",
|
||||
"MeanSpectralCentroid",
|
||||
"StdDevSpectralCentroid",
|
||||
"MeanSpectralRolloff",
|
||||
"StdDevSpectralRolloff",
|
||||
"MeanSpectralFlatness",
|
||||
"StdDevSpectralFlatness",
|
||||
"MeanLoudness",
|
||||
"StdDevLoudness",
|
||||
"Chroma1",
|
||||
"Chroma2",
|
||||
"Chroma3",
|
||||
"Chroma4",
|
||||
"Chroma5",
|
||||
"Chroma6",
|
||||
"Chroma7",
|
||||
"Chroma8",
|
||||
"Chroma9",
|
||||
"Chroma10"
|
||||
FROM bliss_tracks;"""
|
||||
with conn.cursor(row_factory=class_row(Song)) as cur:
|
||||
cur.execute(QUERY)
|
||||
return cur.fetchall()
|
||||
def __str__(self):
    """Return a one-line summary of the run.

    The summary holds the wss of the last iteration (two decimals, or
    ``n/a`` when there are no iterations), ``k``, the number of iterations
    and the filename.
    """
    # ".2f" is the valid spelling of "two decimals"; a spec of ".f2" makes
    # str() raise ValueError. k is an integer, so it is printed as-is.
    final_wss = f"{self.iterations[-1].wss:.2f}" if self.iterations else "n/a"
    return f"wss: {final_wss} k: {self.k} n: {len(self.iterations)} {self.filename}"
|
||||
|
||||
|
||||
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
|
||||
@ -187,12 +49,17 @@ def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dic
|
||||
|
||||
# Now we need to assign songs to centroids
|
||||
for song in tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1):
|
||||
# print(f"Song: {song.File}")
|
||||
# Performance for 10000 centroids for this progress bar
|
||||
# pure python 15 it/s
|
||||
# np.linalg.norm + jit 20 it/s
|
||||
# np.linalg.norm 21 it/s
|
||||
# np.dot 40 it/s
|
||||
# np.sqrt(np.sum(np.square(x - y))) + jit 150 it/s <- Current implementation
|
||||
|
||||
song_np_arr = song.np_arr
|
||||
|
||||
# Need to find the closest centroid
|
||||
closest_centroid = min(centroids.values(), key=lambda centroid: ldist(song_np_arr, centroid.np_arr))
|
||||
# print(f"We've selected group of centroid: {closest_centroid.id} with distance {closest_centroid @ song}")
|
||||
centroids_with_songs[closest_centroid.id].append(song)
|
||||
|
||||
return centroids_with_songs
|
||||
@ -219,7 +86,7 @@ def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dic
|
||||
|
||||
|
||||
def run(conn: psycopg.Connection, k: int):
|
||||
print(f'Running with k {k}')
|
||||
print(f"Running with k {k}")
|
||||
centroids = choose_centroids(conn, k=k)
|
||||
all_songs = get_all_songs(conn)
|
||||
|
||||
@ -229,11 +96,13 @@ def run(conn: psycopg.Connection, k: int):
|
||||
centroids_with_songs = centroids_to_songs(centroids, all_songs)
|
||||
upd_cent_res = update_centroids(centroids_with_songs, centroids) # Centroids are passed by reference
|
||||
|
||||
iterations.append(Iteration(
|
||||
n=iteration_num,
|
||||
cent_move_delta=upd_cent_res.cent_move_delta,
|
||||
wss=upd_cent_res.wss
|
||||
))
|
||||
iterations.append(
|
||||
Iteration(
|
||||
n=iteration_num,
|
||||
cent_move_delta=upd_cent_res.cent_move_delta,
|
||||
wss=upd_cent_res.wss,
|
||||
)
|
||||
)
|
||||
|
||||
pprint(iterations[-1])
|
||||
|
||||
@ -244,21 +113,16 @@ def run(conn: psycopg.Connection, k: int):
|
||||
else:
|
||||
print("Leaving loop by attempts exceeding")
|
||||
|
||||
# print("Clusters:")
|
||||
# for songs in centroids_to_songs(centroids, all_songs).values():
|
||||
# if len(songs) > 1:
|
||||
# pprint(songs)
|
||||
|
||||
with open(f'{datetime.datetime.now().isoformat().replace(":", "-")}.json', mode='w', encoding='utf-8') as res_file:
|
||||
with open(f'{datetime.datetime.now().isoformat().replace(":", "-")}.json', mode="w", encoding="utf-8") as res_file:
|
||||
results = asdict(RunResults(centroids=list(centroids.values()), k=k, iterations=iterations))
|
||||
for cent_d in results["centroids"]:
|
||||
del cent_d["np_arr"]
|
||||
|
||||
json.dump(results, res_file, indent=4)
|
||||
|
||||
|
||||
def main(k: int):
|
||||
conn = psycopg.connect(
|
||||
"dbname=postgres user=postgres port=5555 host=192.168.1.68 password=music_recommendation_service_postgres",
|
||||
row_factory=dict_row,
|
||||
)
|
||||
conn = db.get_conn()
|
||||
|
||||
try:
|
||||
# for k in range(10000, 16000, 1000):
|
||||
@ -266,7 +130,7 @@ def main(k: int):
|
||||
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
conn.close()
|
||||
print('Connection closed')
|
||||
print("Connection closed")
|
||||
raise e
|
||||
|
||||
|
||||
|
18
persistance.py
Normal file
18
persistance.py
Normal file
@ -0,0 +1,18 @@
|
||||
import json
|
||||
from main import RunResults, Iteration, Centroid
|
||||
|
||||
|
||||
def run_result_from_dict(d: dict, filename: str) -> RunResults:
    """Rebuild a ``RunResults`` from its dictionary form.

    Args:
        d: Mapping with the keys ``"centroids"``, ``"iterations"`` and ``"k"``;
            every centroid and iteration entry is a dict of constructor
            keyword arguments.
        filename: Stored on the result as ``filename``.

    Returns:
        The reconstructed ``RunResults``.

    Raises:
        KeyError: If one of the expected keys is missing from ``d``.
    """
    raw_centroids, raw_iterations = d["centroids"], d["iterations"]
    k = d["k"]
    iterations = [Iteration(**kwargs) for kwargs in raw_iterations]
    centroids = [Centroid(**kwargs) for kwargs in raw_centroids]
    return RunResults(k=k, iterations=iterations, centroids=centroids, filename=filename)
|
||||
|
||||
|
||||
def load_model(filename: str) -> RunResults:
    """Load a saved run from the JSON file ``filename``.

    Args:
        filename: Path of the JSON file; it is also recorded on the result.

    Returns:
        The ``RunResults`` built by ``run_result_from_dict``.
    """
    # Result files are written with encoding="utf-8", so read them the same
    # way instead of relying on the platform's default encoding.
    with open(filename, mode="r", encoding="utf-8") as f:
        return run_result_from_dict(json.load(f), filename)
|
Loading…
x
Reference in New Issue
Block a user