music-dev-clustering/inference.py

37 lines
1003 B
Python

import persistance
from main import RunResults, centroids_to_songs
from db import get_all_songs, get_conn
from domain import Song
from contextlib import closing
# Load a model, assign all songs to its clusters, and print the cluster(s) containing one specific song
def inference(
    model: RunResults,
    target_file: str = "/music/music_mirror/PoH/Green Day - 21st Century Breakdown.mp3",
):
    """Pretty-print every cluster that contains the song at ``target_file``.

    Loads all songs from the database, groups them by the model's centroids
    via ``centroids_to_songs``, and pprints each cluster in which some song's
    ``File`` equals ``target_file``. Prints nothing if no cluster matches.

    Args:
        model: Loaded run results; its ``centroids`` are indexed by ``id``.
        target_file: Path of the song whose cluster should be shown. Defaults
            to the path that used to be hard-coded here, so existing callers
            see unchanged output.
    """
    from pprint import pprint

    # Only the song list is needed from the DB; release the connection
    # before doing the (potentially slow) cluster assignment.
    with closing(get_conn()) as conn:
        songs: list[Song] = get_all_songs(conn)

    centroids = {cent.id: cent for cent in model.centroids}
    # NOTE(review): assumes centroids_to_songs returns a mapping whose values
    # are iterables of Song objects — confirm against main.py.
    clusters = centroids_to_songs(centroids, songs)

    for cluster in clusters.values():
        # any() stops at the first hit, matching the original inner-loop break.
        if any(song.File == target_file for song in cluster):
            pprint(cluster)
def main(argv=None):
    """Command-line entry point: load a saved model and run ``inference`` on it.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; passing a
            list lets other scripts or tests drive the CLI directly.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run model")
    parser.add_argument("filename", help="path of the saved model to load")
    # parse_args(None) falls back to sys.argv[1:], preserving old behaviour.
    args = parser.parse_args(argv)
    model = persistance.load_model(args.filename)
    inference(model)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()