From abc4251c7c26f94923c6f9699e790303269df5bd Mon Sep 17 00:00:00 2001 From: norohind <60548839+norohind@users.noreply.github.com> Date: Wed, 6 Dec 2023 17:08:49 +0300 Subject: [PATCH] numpy + numba speedup ~10x times --- main.py | 100 +++++++++++++++++++++++++++----------------------------- 1 file changed, 49 insertions(+), 51 deletions(-) diff --git a/main.py b/main.py index 1e9cd71..43a900d 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,12 @@ import datetime import json from collections import defaultdict -from dataclasses import dataclass, asdict +from dataclasses import dataclass, asdict, field +from numba import jit, njit from math import sqrt from pprint import pprint +import numpy as np import psycopg from psycopg.rows import dict_row, class_row from tqdm import tqdm @@ -33,67 +35,60 @@ IMPORTANT_FIELDS: tuple[str, ...] = ( ) +class DimensionDescriptor: + def __set_name__(self, owner, name): + self.name: str = name + self.idx: int = IMPORTANT_FIELDS.index(name) + + def __get__(self, instance: 'EntityBase', owner) -> float: + return instance.np_arr[self.idx] + + def __set__(self, instance: 'EntityBase', value: float): + instance.np_arr[self.idx] = value + + +@njit(cache=True) +def dist(x: np.ndarray, y: np.ndarray) -> float: + return np.sqrt(np.sum(np.square(x - y))) + + @dataclass(frozen=True) class UpdateCentroidsRes: cent_move_delta: float wss: float # wss - within-cluster-sum of squared errors -@dataclass +@dataclass(kw_only=True) class EntityBase: - Tempo: float - Zcr: float - MeanSpectralCentroid: float - StdDevSpectralCentroid: float - MeanSpectralRolloff: float - StdDevSpectralRolloff: float - MeanSpectralFlatness: float - StdDevSpectralFlatness: float - MeanLoudness: float - StdDevLoudness: float - Chroma1: float - Chroma2: float - Chroma3: float - Chroma4: float - Chroma5: float - Chroma6: float - Chroma7: float - Chroma8: float - Chroma9: float - Chroma10: float - - def __matmul__(self, other: "EntityBase") -> float: - """Euclidian distance""" - return sqrt( - (self.Tempo - other.Tempo) ** 2 - + (self.Zcr - other.Zcr) ** 2 - + (self.MeanSpectralCentroid - other.MeanSpectralCentroid) ** 2 - + (self.StdDevSpectralCentroid - other.StdDevSpectralCentroid) ** 2 - + (self.MeanSpectralRolloff - other.MeanSpectralRolloff) ** 2 - + (self.StdDevSpectralRolloff - other.StdDevSpectralRolloff) ** 2 - + (self.MeanSpectralFlatness - other.MeanSpectralFlatness) ** 2 - + (self.StdDevSpectralFlatness - other.StdDevSpectralFlatness) ** 2 - + (self.MeanLoudness - other.MeanLoudness) ** 2 - + (self.StdDevLoudness - other.StdDevLoudness) ** 2 - + (self.Chroma1 - other.Chroma1) ** 2 - + (self.Chroma2 - other.Chroma2) ** 2 - + (self.Chroma3 - other.Chroma3) ** 2 - + (self.Chroma4 - other.Chroma4) ** 2 - + (self.Chroma5 - other.Chroma5) ** 2 - + (self.Chroma6 - other.Chroma6) ** 2 - + (self.Chroma7 - other.Chroma7) ** 2 - + (self.Chroma8 - other.Chroma8) ** 2 - + (self.Chroma9 - other.Chroma9) ** 2 - + (self.Chroma10 - other.Chroma10) ** 2 - ) + np_arr: np.ndarray = field(init=False, repr=False, default_factory=lambda: np.zeros(len(IMPORTANT_FIELDS))) + Tempo: float = DimensionDescriptor() + Zcr: float = DimensionDescriptor() + MeanSpectralCentroid: float = DimensionDescriptor() + StdDevSpectralCentroid: float = DimensionDescriptor() + MeanSpectralRolloff: float = DimensionDescriptor() + StdDevSpectralRolloff: float = DimensionDescriptor() + MeanSpectralFlatness: float = DimensionDescriptor() + StdDevSpectralFlatness: float = DimensionDescriptor() + MeanLoudness: float = DimensionDescriptor() + StdDevLoudness: float = DimensionDescriptor() + Chroma1: float = DimensionDescriptor() + Chroma2: float = DimensionDescriptor() + Chroma3: float = DimensionDescriptor() + Chroma4: float = DimensionDescriptor() + Chroma5: float = DimensionDescriptor() + Chroma6: float = DimensionDescriptor() + Chroma7: float = DimensionDescriptor() + Chroma8: float = DimensionDescriptor() + Chroma9: float = DimensionDescriptor() + Chroma10: float = DimensionDescriptor() -@dataclass +@dataclass(kw_only=True) class Centroid(EntityBase): id: int -@dataclass +@dataclass(kw_only=True) class Song(EntityBase): File: str @@ -187,14 +182,16 @@ def get_all_songs(conn: psycopg.Connection) -> list[Song]: def centroids_to_songs(centroids: dict[int, Centroid], songs: list[Song]) -> dict[int, list[Song]]: + ldist = dist # Local reference for faster lookup centroids_with_songs: dict[int, list[Song]] = defaultdict(list) # Now we need to assign songs to centroids for song in tqdm(songs, desc="Assigning songs to centroids", disable=False, mininterval=1, delay=1): # print(f"Song: {song.File}") + song_np_arr = song.np_arr # Need to find the closest centroid - closest_centroid = min(centroids.values(), key=lambda centroid: song @ centroid) + closest_centroid = min(centroids.values(), key=lambda centroid: ldist(song_np_arr, centroid.np_arr)) # print(f"We've selected group of centroid: {closest_centroid.id} with distance {closest_centroid @ song}") centroids_with_songs[closest_centroid.id].append(song) @@ -216,7 +213,7 @@ def update_centroids(centroids_with_songs: dict[int, list[Song]], centroids: dic total_delta += abs(old_field_value - avg_value_of_field) - wss += sum((centroid @ song) ** 2 for song in cent_songs) + wss += sum((dist(centroid.np_arr, song.np_arr)) ** 2 for song in cent_songs) return UpdateCentroidsRes(cent_move_delta=total_delta, wss=wss) @@ -269,6 +266,7 @@ def main(k: int): except (Exception, KeyboardInterrupt) as e: conn.close() + print('Connection closed') raise e