37 lines
1003 B
Python
37 lines
1003 B
Python
import persistance
|
|
from main import RunResults, centroids_to_songs
|
|
from db import get_all_songs, get_conn
|
|
from domain import Song
|
|
from contextlib import closing
|
|
|
|
# Load a saved clustering model and print the cluster(s) containing a chosen song
|
|
|
|
|
|
def inference(
    model: RunResults,
    target_file: str = (
        "/music/music_mirror/PoH/Green Day - 21st Century Breakdown.mp3"
    ),
) -> None:
    """Pretty-print every cluster that contains the song at *target_file*.

    All songs are read from the database, grouped around the model's
    centroids via ``centroids_to_songs``, and each resulting cluster in which
    some song's ``File`` equals *target_file* is printed with ``pprint``.

    Args:
        model: A loaded run; each entry of ``model.centroids`` must expose
            an ``id`` attribute (used as the mapping key below).
        target_file: ``File`` value of the song whose cluster should be
            shown. Defaults to the path that used to be hard-coded here, so
            existing callers see the same output.
    """
    from pprint import pprint

    # The connection is only needed by get_all_songs; close it before the
    # (connection-free) clustering step.
    with closing(get_conn()) as conn:
        songs: list[Song] = get_all_songs(conn)

    centroids = {cent.id: cent for cent in model.centroids}
    clusters = centroids_to_songs(centroids, songs)

    for cluster in clusters.values():
        # Print a cluster once if any of its songs matches the target.
        if any(song.File == target_file for song in cluster):
            pprint(cluster)
|
|
|
|
|
|
def main(argv: "list[str] | None" = None) -> None:
    """Command-line entry point: load a model file and run ``inference`` on it.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the original command-line behaviour; passing
            a list lets tests and other scripts drive the entry point.

    Raises:
        SystemExit: Raised by argparse when the required ``filename`` is
            missing or arguments are invalid, or after ``-h``/``--help``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run model")
    parser.add_argument(
        "filename", help="model file passed to persistance.load_model"
    )
    args = parser.parse_args(argv)

    model = persistance.load_model(args.filename)
    inference(model)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only run the CLI when executed as a script; importing stays side-effect free.
    main()
|