numpy + numba speedup ~10x times
This commit is contained in:
parent
d3bb2e1b6b
commit
abc4251c7c
100
main.py
100
main.py
@ -1,10 +1,12 @@
|
||||
import datetime
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, asdict
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from numba import jit, njit
|
||||
from math import sqrt
|
||||
from pprint import pprint
|
||||
|
||||
import numpy as np
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row, class_row
|
||||
from tqdm import tqdm
|
||||
@ -33,67 +35,60 @@ IMPORTANT_FIELDS: tuple[str, ...] = (
|
||||
)
|
||||
|
||||
|
||||
class DimensionDescriptor:
|
||||
def __set_name__(self, owner, name):
|
||||
self.name: str = name
|
||||
self.idx: int = IMPORTANT_FIELDS.index(name)
|
||||
|
||||
def __get__(self, instance: 'EntityBase', owner) -> float:
|
||||
return instance.np_arr[self.idx]
|
||||
|
||||
def __set__(self, instance: 'EntityBase', value: float):
|
||||
instance.np_arr[self.idx] = value
|
||||
|
||||
|
||||
@njit(cache=True)
|
||||
def dist(x: np.ndarray, y: np.ndarray) -> float:
|
||||
return np.sqrt(np.sum(np.square(x - y)))
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UpdateCentroidsRes:
|
||||
cent_move_delta: float
|
||||
wss: float # wss - within-cluster-sum of squared errors
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class EntityBase:
|
||||
Tempo: float
|
||||
Zcr: float
|
||||
MeanSpectralCentroid: float
|
||||
StdDevSpectralCentroid: float
|
||||
MeanSpectralRolloff: float
|
||||
StdDevSpectralRolloff: float
|
||||
MeanSpectralFlatness: float
|
||||
StdDevSpectralFlatness: float
|
||||
MeanLoudness: float
|
||||
StdDevLoudness: float
|
||||
Chroma1: float
|
||||
Chroma2: float
|
||||
Chroma3: float
|
||||
Chroma4: float
|
||||
Chroma5: float
|
||||
Chroma6: float
|
||||
Chroma7: float
|
||||
Chroma8: float
|
||||
Chroma9: float
|
||||
Chroma10: float
|
||||
|
||||
def __matmul__(self, other: "EntityBase") -> float:
|
||||
"""Euclidian distance"""
|
||||
return sqrt(
|
||||
(self.Tempo - other.Tempo) ** 2
|
||||
+ (self.Zcr - other.Zcr) ** 2
|
||||
+ (self.MeanSpectralCentroid - other.MeanSpectralCentroid) ** 2
|
||||
+ (self.StdDevSpectralCentroid - other.StdDevSpectralCentroid) ** 2
|
||||
+ (self.MeanSpectralRolloff - other.MeanSpectralRolloff) ** 2
|
||||
+ (self.StdDevSpectralRolloff - other.StdDevSpectralRolloff) ** 2
|
||||
+ (self.MeanSpectralFlatness - other.MeanSpectralFlatness) ** 2
|
||||
+ (self.StdDevSpectralFlatness - other.StdDevSpectralFlatness) ** 2
|
||||
+ (self.MeanLoudness - other.MeanLoudness) ** 2
|
||||
+ (self.StdDevLoudness - other.StdDevLoudness) ** 2
|
||||
+ (self.Chroma1 - other.Chroma1) ** 2
|
||||
+ (self.Chroma2 - other.Chroma2) ** 2
|
||||
+ (self.Chroma3 - other.Chroma3) ** 2
|
||||
+ (self.Chroma4 - other.Chroma4) ** 2
|
||||
+ (self.Chroma5 - other.Chroma5) ** 2
|
||||
+ (self.Chroma6 - other.Chroma6) ** 2
|
||||
+ (self.Chroma7 - other.Chroma7) ** 2
|
||||
+ (self.Chroma8 - other.Chroma8) ** 2
|
||||
+ (self.Chroma9 - other.Chroma9) ** 2
|
||||
+ (self.Chroma10 - other.Chroma10) ** 2
|
||||
)
|
||||
np_arr: np.ndarray = field(init=False, repr=False, default_factory=lambda: np.zeros(len(IMPORTANT_FIELDS)))
|
||||
Tempo: float = DimensionDescriptor()
|
||||
Zcr: float = DimensionDescriptor()
|
||||
MeanSpectralCentroid: float = DimensionDescriptor()
|
||||
StdDevSpectralCentroid: float = DimensionDescriptor()
|
||||
MeanSpectralRolloff: float = DimensionDescriptor()
|
||||
StdDevSpectralRolloff: float = DimensionDescriptor()
|
||||
MeanSpectralFlatness: float = DimensionDescriptor()
|
||||
StdDevSpectralFlatness: float = DimensionDescriptor()
|
||||
MeanLoudness: float = DimensionDescriptor()
|
||||
StdDevLoudness: float = DimensionDescriptor()
|
||||
Chroma1: float = DimensionDescriptor()
|
||||
Chroma2: float = DimensionDescriptor()
|
||||
Chroma3: float = DimensionDescriptor()
|
||||
Chroma4: float = DimensionDescriptor()
|
||||
Chroma5: float = DimensionDescriptor()
|
||||
Chroma6: float = DimensionDescriptor()
|
||||
Chroma7: float = DimensionDescriptor()
|
||||
Chroma8: float = DimensionDescriptor()
|
||||
Chroma9: float = DimensionDescriptor()
|
||||
Chroma10: float = DimensionDescriptor()
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class Centroid(EntityBase):
|
||||
id: int
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class Song(EntityBase):
|
||||
File: str
|
||||
|
||||
@ -187,14 +182,16 @@ def get_all_songs(conn: psycopg.Connection) -> list[Song]:
|
||||
|
||||
|
||||
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
|
||||
ldist = dist # Local reference for faster lookup
|
||||
centroids_with_songs: dict[int, list[Song]] = defaultdict(list)
|
||||
|
||||
# Now we need to assign songs to centroids
|
||||
for song in tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1):
|
||||
# print(f"Song: {song.File}")
|
||||
song_np_arr = song.np_arr
|
||||
|
||||
# Need to find the closest centroid
|
||||
closest_centroid = min(centroids.values(), key=lambda centroid: song @ centroid)
|
||||
closest_centroid = min(centroids.values(), key=lambda centroid: ldist(song_np_arr, centroid.np_arr))
|
||||
# print(f"We've selected group of centroid: {closest_centroid.id} with distance {closest_centroid @ song}")
|
||||
centroids_with_songs[closest_centroid.id].append(song)
|
||||
|
||||
@ -216,7 +213,7 @@ def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dic
|
||||
|
||||
total_delta += abs(old_field_value - avg_value_of_field)
|
||||
|
||||
wss += sum((centroid @ song) ** 2 for song in cent_songs)
|
||||
wss += sum((dist(centroid.np_arr, song.np_arr)) ** 2 for song in cent_songs)
|
||||
|
||||
return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss)
|
||||
|
||||
@ -269,6 +266,7 @@ def main(k: int):
|
||||
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
conn.close()
|
||||
print('Connection closed')
|
||||
raise e
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user