numpy + numba speedup ~10x times

This commit is contained in:
norohind 2023-12-06 17:08:49 +03:00
parent d3bb2e1b6b
commit abc4251c7c
Signed by: norohind
GPG Key ID: 01C3BECC26FB59E1

100
main.py
View File

@ -1,10 +1,12 @@
import datetime
import json
from collections import defaultdict
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, field
from numba import jit, njit
from math import sqrt
from pprint import pprint
import numpy as np
import psycopg
from psycopg.rows import dict_row, class_row
from tqdm import tqdm
@ -33,67 +35,60 @@ IMPORTANT_FIELDS: tuple[str, ...] = (
)
class DimensionDescriptor:
def __set_name__(self, owner, name):
self.name: str = name
self.idx: int = IMPORTANT_FIELDS.index(name)
def __get__(self, instance: 'EntityBase', owner) -> float:
return instance.np_arr[self.idx]
def __set__(self, instance: 'EntityBase', value: float):
instance.np_arr[self.idx] = value
@njit(cache=True)
def dist(x: np.ndarray, y: np.ndarray) -> float:
return np.sqrt(np.sum(np.square(x - y)))
@dataclass(frozen=True)
class UpdateCentroidsRes:
cent_move_delta: float
wss: float # wss - within-cluster-sum of squared errors
@dataclass
@dataclass(kw_only=True)
class EntityBase:
Tempo: float
Zcr: float
MeanSpectralCentroid: float
StdDevSpectralCentroid: float
MeanSpectralRolloff: float
StdDevSpectralRolloff: float
MeanSpectralFlatness: float
StdDevSpectralFlatness: float
MeanLoudness: float
StdDevLoudness: float
Chroma1: float
Chroma2: float
Chroma3: float
Chroma4: float
Chroma5: float
Chroma6: float
Chroma7: float
Chroma8: float
Chroma9: float
Chroma10: float
def __matmul__(self, other: "EntityBase") -> float:
"""Euclidian distance"""
return sqrt(
(self.Tempo - other.Tempo) ** 2
+ (self.Zcr - other.Zcr) ** 2
+ (self.MeanSpectralCentroid - other.MeanSpectralCentroid) ** 2
+ (self.StdDevSpectralCentroid - other.StdDevSpectralCentroid) ** 2
+ (self.MeanSpectralRolloff - other.MeanSpectralRolloff) ** 2
+ (self.StdDevSpectralRolloff - other.StdDevSpectralRolloff) ** 2
+ (self.MeanSpectralFlatness - other.MeanSpectralFlatness) ** 2
+ (self.StdDevSpectralFlatness - other.StdDevSpectralFlatness) ** 2
+ (self.MeanLoudness - other.MeanLoudness) ** 2
+ (self.StdDevLoudness - other.StdDevLoudness) ** 2
+ (self.Chroma1 - other.Chroma1) ** 2
+ (self.Chroma2 - other.Chroma2) ** 2
+ (self.Chroma3 - other.Chroma3) ** 2
+ (self.Chroma4 - other.Chroma4) ** 2
+ (self.Chroma5 - other.Chroma5) ** 2
+ (self.Chroma6 - other.Chroma6) ** 2
+ (self.Chroma7 - other.Chroma7) ** 2
+ (self.Chroma8 - other.Chroma8) ** 2
+ (self.Chroma9 - other.Chroma9) ** 2
+ (self.Chroma10 - other.Chroma10) ** 2
)
np_arr: np.ndarray = field(init=False, repr=False, default_factory=lambda: np.zeros(len(IMPORTANT_FIELDS)))
Tempo: float = DimensionDescriptor()
Zcr: float = DimensionDescriptor()
MeanSpectralCentroid: float = DimensionDescriptor()
StdDevSpectralCentroid: float = DimensionDescriptor()
MeanSpectralRolloff: float = DimensionDescriptor()
StdDevSpectralRolloff: float = DimensionDescriptor()
MeanSpectralFlatness: float = DimensionDescriptor()
StdDevSpectralFlatness: float = DimensionDescriptor()
MeanLoudness: float = DimensionDescriptor()
StdDevLoudness: float = DimensionDescriptor()
Chroma1: float = DimensionDescriptor()
Chroma2: float = DimensionDescriptor()
Chroma3: float = DimensionDescriptor()
Chroma4: float = DimensionDescriptor()
Chroma5: float = DimensionDescriptor()
Chroma6: float = DimensionDescriptor()
Chroma7: float = DimensionDescriptor()
Chroma8: float = DimensionDescriptor()
Chroma9: float = DimensionDescriptor()
Chroma10: float = DimensionDescriptor()
@dataclass
@dataclass(kw_only=True)
class Centroid(EntityBase):
id: int
@dataclass
@dataclass(kw_only=True)
class Song(EntityBase):
File: str
@ -187,14 +182,16 @@ def get_all_songs(conn: psycopg.Connection) -> list[Song]:
def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]:
ldist = dist # Local reference for faster lookup
centroids_with_songs: dict[int, list[Song]] = defaultdict(list)
# Now we need to assign songs to centroids
for song in tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1):
# print(f"Song: {song.File}")
song_np_arr = song.np_arr
# Need to find the closest centroid
closest_centroid = min(centroids.values(), key=lambda centroid: song @ centroid)
closest_centroid = min(centroids.values(), key=lambda centroid: ldist(song_np_arr, centroid.np_arr))
# print(f"We've selected group of centroid: {closest_centroid.id} with distance {closest_centroid @ song}")
centroids_with_songs[closest_centroid.id].append(song)
@ -216,7 +213,7 @@ def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dic
total_delta += abs(old_field_value - avg_value_of_field)
wss += sum((centroid @ song) ** 2 for song in cent_songs)
wss += sum((dist(centroid.np_arr, song.np_arr)) ** 2 for song in cent_songs)
return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss)
@ -269,6 +266,7 @@ def main(k: int):
except (Exception, KeyboardInterrupt) as e:
conn.close()
print('Connection closed')
raise e