diff --git a/maloja/database/sqldb.py b/maloja/database/sqldb.py index a51c3e7..0d4345f 100644 --- a/maloja/database/sqldb.py +++ b/maloja/database/sqldb.py @@ -1,3 +1,5 @@ +from typing import TypedDict, Optional, cast + import sqlalchemy as sql from sqlalchemy.dialects.sqlite import insert as sqliteinsert import json @@ -213,6 +215,25 @@ def set_maloja_info(info,dbconn=None): # The last two fields are not returned under normal circumstances +class AlbumDict(TypedDict): + albumtitle: str + artists: list[str] + + +class TrackDict(TypedDict): + artists: list[str] + title: str + album: AlbumDict + length: int | None + + +class ScrobbleDict(TypedDict): + time: int + track: TrackDict + duration: int + origin: str + extra: Optional[dict] + rawscrobble: Optional[dict] ##### Conversions between DB and dicts @@ -222,121 +243,129 @@ def set_maloja_info(info,dbconn=None): ### DB -> DICT -def scrobbles_db_to_dict(rows,include_internal=False,dbconn=None): - tracks = get_tracks_map(set(row.track_id for row in rows),dbconn=dbconn) +def scrobbles_db_to_dict(rows, include_internal=False, dbconn=None) -> list[ScrobbleDict]: + tracks: list[TrackDict] = get_tracks_map(set(row.track_id for row in rows), dbconn=dbconn) return [ - { + cast({ **{ - "time":row.timestamp, - "track":tracks[row.track_id], - "duration":row.duration, - "origin":row.origin, + "time": row.timestamp, + "track": tracks[row.track_id], + "duration": row.duration, + "origin": row.origin }, **({ - "extra":json.loads(row.extra or '{}'), - "rawscrobble":json.loads(row.rawscrobble or '{}') + "extra": json.loads(row.extra or '{}'), + "rawscrobble": json.loads(row.rawscrobble or '{}') } if include_internal else {}) - } + }, ScrobbleDict) for row in rows ] -def scrobble_db_to_dict(row,dbconn=None): - return scrobbles_db_to_dict([row],dbconn=dbconn)[0] -def tracks_db_to_dict(rows,dbconn=None): - artists = get_artists_of_tracks(set(row.id for row in rows),dbconn=dbconn) - albums = get_albums_map(set(row.album_id for row in rows),dbconn=dbconn) +def scrobble_db_to_dict(row, dbconn=None) -> ScrobbleDict: + return scrobbles_db_to_dict([row], dbconn=dbconn)[0] + + +def tracks_db_to_dict(rows, dbconn=None) -> list[TrackDict]: + artists = get_artists_of_tracks(set(row.id for row in rows), dbconn=dbconn) + albums = get_albums_map(set(row.album_id for row in rows), dbconn=dbconn) return [ - { + cast({ "artists":artists[row.id], "title":row.title, "album":albums.get(row.album_id), "length":row.length - } + }, TrackDict) for row in rows ] -def track_db_to_dict(row,dbconn=None): - return tracks_db_to_dict([row],dbconn=dbconn)[0] -def artists_db_to_dict(rows,dbconn=None): +def track_db_to_dict(row, dbconn=None) -> TrackDict: + return tracks_db_to_dict([row], dbconn=dbconn)[0] + + +def artists_db_to_dict(rows, dbconn=None) -> list[str]: return [ row.name for row in rows ] -def artist_db_to_dict(row,dbconn=None): - return artists_db_to_dict([row],dbconn=dbconn)[0] -def albums_db_to_dict(rows,dbconn=None): - artists = get_artists_of_albums(set(row.id for row in rows),dbconn=dbconn) +def artist_db_to_dict(row, dbconn=None) -> str: + return artists_db_to_dict([row], dbconn=dbconn)[0] + + +def albums_db_to_dict(rows, dbconn=None) -> list[AlbumDict]: + artists = get_artists_of_albums(set(row.id for row in rows), dbconn=dbconn) return [ - { - "artists":artists.get(row.id), - "albumtitle":row.albtitle, - } + cast({ + "artists": artists.get(row.id), + "albumtitle": row.albtitle, + }, AlbumDict) for row in rows ] -def album_db_to_dict(row,dbconn=None): - return albums_db_to_dict([row],dbconn=dbconn)[0] - +def album_db_to_dict(row, dbconn=None) -> AlbumDict: + return albums_db_to_dict([row], dbconn=dbconn)[0] ### DICT -> DB # These should return None when no data is in the dict so they can be used for update statements -def scrobble_dict_to_db(info,update_album=False,dbconn=None): +def scrobble_dict_to_db(info: ScrobbleDict, update_album=False, dbconn=None): return { - "timestamp":info.get('time'), - "origin":info.get('origin'), - "duration":info.get('duration'), - "track_id":get_track_id(info.get('track'),update_album=update_album,dbconn=dbconn), - "extra":json.dumps(info.get('extra')) if info.get('extra') else None, - "rawscrobble":json.dumps(info.get('rawscrobble')) if info.get('rawscrobble') else None + "timestamp": info.get('time'), + "origin": info.get('origin'), + "duration": info.get('duration'), + "track_id": get_track_id(info.get('track'), update_album=update_album, dbconn=dbconn), + "extra": json.dumps(info.get('extra')) if info.get('extra') else None, + "rawscrobble": json.dumps(info.get('rawscrobble')) if info.get('rawscrobble') else None } -def track_dict_to_db(info,dbconn=None): + +def track_dict_to_db(info: TrackDict, dbconn=None): return { - "title":info.get('title'), - "title_normalized":normalize_name(info.get('title','')) or None, - "length":info.get('length') + "title": info.get('title'), + "title_normalized": normalize_name(info.get('title', '')) or None, + "length": info.get('length') } -def artist_dict_to_db(info,dbconn=None): + +def artist_dict_to_db(info: str, dbconn=None): return { "name": info, - "name_normalized":normalize_name(info) + "name_normalized": normalize_name(info) } -def album_dict_to_db(info,dbconn=None): + +def album_dict_to_db(info: AlbumDict, dbconn=None): return { - "albtitle":info.get('albumtitle'), - "albtitle_normalized":normalize_name(info.get('albumtitle')) + "albtitle": info.get('albumtitle'), + "albtitle_normalized": normalize_name(info.get('albumtitle')) } - ##### Actual Database interactions # TODO: remove all resolve_id args and do that logic outside the caching to improve hit chances # TODO: maybe also factor out all intitial get entity funcs (some here, some in __init__) and throw exceptions @connection_provider -def add_scrobble(scrobbledict,update_album=False,dbconn=None): - _, ex, er = add_scrobbles([scrobbledict],update_album=update_album,dbconn=dbconn) +def add_scrobble(scrobbledict: ScrobbleDict, update_album=False, dbconn=None): + _, ex, er = add_scrobbles([scrobbledict], update_album=update_album, dbconn=dbconn) if er > 0: - raise exc.DuplicateTimestamp(existing_scrobble=None,rejected_scrobble=scrobbledict) + raise exc.DuplicateTimestamp(existing_scrobble=None, rejected_scrobble=scrobbledict) # TODO: actually pass existing scrobble elif ex > 0: raise exc.DuplicateScrobble(scrobble=scrobbledict) + @connection_provider -def add_scrobbles(scrobbleslist,update_album=False,dbconn=None): +def add_scrobbles(scrobbleslist: list[ScrobbleDict], update_album=False, dbconn=None) -> tuple[int, int, int]: with SCROBBLE_LOCK: @@ -349,7 +378,7 @@ def add_scrobbles(scrobbleslist,update_album=False,dbconn=None): success, exists, errors = 0, 0, 0 for s in scrobbleslist: - scrobble_entry = scrobble_dict_to_db(s,update_album=update_album,dbconn=dbconn) + scrobble_entry = scrobble_dict_to_db(s, update_album=update_album, dbconn=dbconn) try: dbconn.execute(DB['scrobbles'].insert().values( **scrobble_entry @@ -365,13 +394,13 @@ def add_scrobbles(scrobbleslist,update_album=False,dbconn=None): else: errors += 1 - if errors > 0: log(f"{errors} Scrobbles have not been written to database (duplicate timestamps)!", color='red') if exists > 0: log(f"{exists} Scrobbles have not been written to database (already exist)", color='orange') return success, exists, errors + @connection_provider -def delete_scrobble(scrobble_id,dbconn=None): +def delete_scrobble(scrobble_id: int, dbconn=None) -> bool: with SCROBBLE_LOCK: @@ -385,7 +414,7 @@ def delete_scrobble(scrobble_id,dbconn=None): @connection_provider -def add_track_to_album(track_id,album_id,replace=False,dbconn=None): +def add_track_to_album(track_id: int, album_id: int, replace=False, dbconn=None) -> bool: conditions = [ DB['tracks'].c.id == track_id @@ -414,39 +443,39 @@ def add_track_to_album(track_id,album_id,replace=False,dbconn=None): # ALL OF RECORDED HISTORY in order to display top weeks # lmao # TODO: figure out something better - - return True + @connection_provider -def add_tracks_to_albums(track_to_album_id_dict,replace=False,dbconn=None): +def add_tracks_to_albums(track_to_album_id_dict: dict[int, int], replace=False, dbconn=None) -> bool: for track_id in track_to_album_id_dict: - add_track_to_album(track_id,track_to_album_id_dict[track_id],replace=replace,dbconn=dbconn) + add_track_to_album(track_id,track_to_album_id_dict[track_id], replace=replace, dbconn=dbconn) + return True + @connection_provider -def remove_album(*track_ids,dbconn=None): +def remove_album(*track_ids: list[int], dbconn=None) -> bool: DB['tracks'].update().where( DB['tracks'].c.track_id.in_(track_ids) ).values( album_id=None ) + return True + ### these will 'get' the ID of an entity, creating it if necessary @cached_wrapper @connection_provider -def get_track_id(trackdict,create_new=True,update_album=False,dbconn=None): +def get_track_id(trackdict: TrackDict, create_new=True, update_album=False, dbconn=None) -> int | None: ntitle = normalize_name(trackdict['title']) - artist_ids = [get_artist_id(a,create_new=create_new,dbconn=dbconn) for a in trackdict['artists']] + artist_ids = [get_artist_id(a, create_new=create_new, dbconn=dbconn) for a in trackdict['artists']] artist_ids = list(set(artist_ids)) - - - op = DB['tracks'].select().where( - DB['tracks'].c.title_normalized==ntitle + DB['tracks'].c.title_normalized == ntitle ) result = dbconn.execute(op).all() for row in result: @@ -456,7 +485,7 @@ def get_track_id(trackdict,create_new=True,update_album=False,dbconn=None): op = DB['trackartists'].select( # DB['trackartists'].c.artist_id ).where( - DB['trackartists'].c.track_id==row.id + DB['trackartists'].c.track_id == row.id ) result = dbconn.execute(op).all() match_artist_ids = [r.artist_id for r in result] @@ -472,14 +501,14 @@ def get_track_id(trackdict,create_new=True,update_album=False,dbconn=None): album_id = get_album_id(trackdict['album'],create_new=(update_album or not row.album_id),dbconn=dbconn) add_track_to_album(row.id,album_id,replace=update_album,dbconn=dbconn) - return row.id - if not create_new: return None + if not create_new: + return None #print("Creating new track") op = DB['tracks'].insert().values( - **track_dict_to_db(trackdict,dbconn=dbconn) + **track_dict_to_db(trackdict, dbconn=dbconn) ) result = dbconn.execute(op) track_id = result.inserted_primary_key[0] @@ -494,24 +523,26 @@ def get_track_id(trackdict,create_new=True,update_album=False,dbconn=None): #print("Created",trackdict['title'],track_id) if trackdict.get('album'): - add_track_to_album(track_id,get_album_id(trackdict['album'],dbconn=dbconn),dbconn=dbconn) + add_track_to_album(track_id, get_album_id(trackdict['album'], dbconn=dbconn), dbconn=dbconn) return track_id + @cached_wrapper @connection_provider -def get_artist_id(artistname,create_new=True,dbconn=None): +def get_artist_id(artistname: str, create_new=True, dbconn=None) -> int | None: nname = normalize_name(artistname) #print("looking for",nname) op = DB['artists'].select().where( - DB['artists'].c.name_normalized==nname + DB['artists'].c.name_normalized == nname ) result = dbconn.execute(op).all() for row in result: #print("ID for",artistname,"was",row[0]) return row.id - if not create_new: return None + if not create_new: + return None op = DB['artists'].insert().values( name=artistname, @@ -524,15 +555,15 @@ def get_artist_id(artistname,create_new=True,dbconn=None): @cached_wrapper @connection_provider -def get_album_id(albumdict,create_new=True,ignore_albumartists=False,dbconn=None): +def get_album_id(albumdict: AlbumDict, create_new=True, ignore_albumartists=False, dbconn=None) -> int | None: ntitle = normalize_name(albumdict['albumtitle']) - artist_ids = [get_artist_id(a,dbconn=dbconn) for a in (albumdict.get('artists') or [])] + artist_ids = [get_artist_id(a, dbconn=dbconn) for a in (albumdict.get('artists') or [])] artist_ids = list(set(artist_ids)) op = DB['albums'].select( # DB['albums'].c.id ).where( - DB['albums'].c.albtitle_normalized==ntitle + DB['albums'].c.albtitle_normalized == ntitle ) result = dbconn.execute(op).all() for row in result: @@ -545,7 +576,7 @@ def get_album_id(albumdict,create_new=True,ignore_albumartists=False,dbconn=None op = DB['albumartists'].select( # DB['albumartists'].c.artist_id ).where( - DB['albumartists'].c.album_id==row.id + DB['albumartists'].c.album_id == row.id ) result = dbconn.execute(op).all() match_artist_ids = [r.artist_id for r in result] @@ -554,11 +585,11 @@ def get_album_id(albumdict,create_new=True,ignore_albumartists=False,dbconn=None #print("ID for",albumdict['title'],"was",row[0]) return row.id - if not create_new: return None - + if not create_new: + return None op = DB['albums'].insert().values( - **album_dict_to_db(albumdict,dbconn=dbconn) + **album_dict_to_db(albumdict, dbconn=dbconn) ) result = dbconn.execute(op) album_id = result.inserted_primary_key[0] @@ -573,18 +604,15 @@ def get_album_id(albumdict,create_new=True,ignore_albumartists=False,dbconn=None return album_id - - ### Edit existing - @connection_provider -def edit_scrobble(scrobble_id,scrobbleupdatedict,dbconn=None): +def edit_scrobble(scrobble_id: int, scrobbleupdatedict: dict, dbconn=None) -> bool: dbentry = scrobble_dict_to_db(scrobbleupdatedict,dbconn=dbconn) - dbentry = {k:v for k,v in dbentry.items() if v} + dbentry = {k: v for k, v in dbentry.items() if v} - print("Updating scrobble",dbentry) + print("Updating scrobble", dbentry) with SCROBBLE_LOCK: @@ -595,97 +623,97 @@ def edit_scrobble(scrobble_id,scrobbleupdatedict,dbconn=None): ) dbconn.execute(op) + return True + # edit function only for primary db information (not linked fields) @connection_provider -def edit_artist(id,artistupdatedict,dbconn=None): +def edit_artist(artist_id: int, artistupdatedict: str, dbconn=None) -> bool: - artist = get_artist(id) + artist = get_artist(artist_id) changedartist = artistupdatedict # well - dbentry = artist_dict_to_db(artistupdatedict,dbconn=dbconn) - dbentry = {k:v for k,v in dbentry.items() if v} + dbentry = artist_dict_to_db(artistupdatedict, dbconn=dbconn) + dbentry = {k: v for k, v in dbentry.items() if v} - existing_artist_id = get_artist_id(changedartist,create_new=False,dbconn=dbconn) - if existing_artist_id not in (None,id): + existing_artist_id = get_artist_id(changedartist, create_new=False, dbconn=dbconn) + if existing_artist_id not in (None, artist_id): raise exc.ArtistExists(changedartist) op = DB['artists'].update().where( - DB['artists'].c.id==id + DB['artists'].c.id == artist_id ).values( **dbentry ) result = dbconn.execute(op) - return True + # edit function only for primary db information (not linked fields) @connection_provider -def edit_track(id,trackupdatedict,dbconn=None): +def edit_track(track_id: int, trackupdatedict: dict, dbconn=None) -> bool: - track = get_track(id,dbconn=dbconn) - changedtrack = {**track,**trackupdatedict} + track = get_track(track_id, dbconn=dbconn) + changedtrack: TrackDict = {**track, **trackupdatedict} - dbentry = track_dict_to_db(trackupdatedict,dbconn=dbconn) - dbentry = {k:v for k,v in dbentry.items() if v} + dbentry = track_dict_to_db(trackupdatedict, dbconn=dbconn) + dbentry = {k: v for k, v in dbentry.items() if v} - existing_track_id = get_track_id(changedtrack,create_new=False,dbconn=dbconn) - if existing_track_id not in (None,id): + existing_track_id = get_track_id(changedtrack, create_new=False, dbconn=dbconn) + if existing_track_id not in (None, track_id): raise exc.TrackExists(changedtrack) op = DB['tracks'].update().where( - DB['tracks'].c.id==id + DB['tracks'].c.id == track_id ).values( **dbentry ) result = dbconn.execute(op) - return True + # edit function only for primary db information (not linked fields) @connection_provider -def edit_album(id,albumupdatedict,dbconn=None): +def edit_album(album_id: int, albumupdatedict: dict, dbconn=None) -> bool: - album = get_album(id,dbconn=dbconn) - changedalbum = {**album,**albumupdatedict} + album = get_album(album_id, dbconn=dbconn) + changedalbum: AlbumDict = {**album, **albumupdatedict} - dbentry = album_dict_to_db(albumupdatedict,dbconn=dbconn) - dbentry = {k:v for k,v in dbentry.items() if v} + dbentry = album_dict_to_db(albumupdatedict, dbconn=dbconn) + dbentry = {k: v for k, v in dbentry.items() if v} - existing_album_id = get_album_id(changedalbum,create_new=False,dbconn=dbconn) - if existing_album_id not in (None,id): + existing_album_id = get_album_id(changedalbum, create_new=False, dbconn=dbconn) + if existing_album_id not in (None, album_id): raise exc.TrackExists(changedalbum) op = DB['albums'].update().where( - DB['albums'].c.id==id + DB['albums'].c.id == album_id ).values( **dbentry ) result = dbconn.execute(op) - return True ### Edit associations @connection_provider -def add_artists_to_tracks(track_ids,artist_ids,dbconn=None): +def add_artists_to_tracks(track_ids: list[int], artist_ids: list[int], dbconn=None) -> bool: op = DB['trackartists'].insert().values([ - {'track_id':track_id,'artist_id':artist_id} + {'track_id': track_id, 'artist_id': artist_id} for track_id in track_ids for artist_id in artist_ids ]) result = dbconn.execute(op) - # the resulting tracks could now be duplicates of existing ones # this also takes care of clean_db merge_duplicate_tracks(dbconn=dbconn) - return True + @connection_provider -def remove_artists_from_tracks(track_ids,artist_ids,dbconn=None): +def remove_artists_from_tracks(track_ids: list[int], artist_ids: list[int], dbconn=None) -> bool: # only tracks that have at least one other artist subquery = DB['trackartists'].select().where( @@ -703,16 +731,14 @@ def remove_artists_from_tracks(track_ids,artist_ids,dbconn=None): ) result = dbconn.execute(op) - # the resulting tracks could now be duplicates of existing ones # this also takes care of clean_db merge_duplicate_tracks(dbconn=dbconn) - return True @connection_provider -def add_artists_to_albums(album_ids,artist_ids,dbconn=None): +def add_artists_to_albums(album_ids: list[int], artist_ids: list[int], dbconn=None) -> bool: op = DB['albumartists'].insert().values([ {'album_id':album_id,'artist_id':artist_id} @@ -720,16 +746,14 @@ def add_artists_to_albums(album_ids,artist_ids,dbconn=None): ]) result = dbconn.execute(op) - # the resulting albums could now be duplicates of existing ones # this also takes care of clean_db merge_duplicate_albums(dbconn=dbconn) - return True @connection_provider -def remove_artists_from_albums(album_ids,artist_ids,dbconn=None): +def remove_artists_from_albums(album_ids: list[int], artist_ids: list[int], dbconn=None) -> bool: # no check here, albums are allowed to have zero artists @@ -741,17 +765,16 @@ def remove_artists_from_albums(album_ids,artist_ids,dbconn=None): ) result = dbconn.execute(op) - # the resulting albums could now be duplicates of existing ones # this also takes care of clean_db merge_duplicate_albums(dbconn=dbconn) - return True + ### Merge @connection_provider -def merge_tracks(target_id,source_ids,dbconn=None): +def merge_tracks(target_id: int, source_ids: list[int], dbconn=None) -> bool: op = DB['scrobbles'].update().where( DB['scrobbles'].c.track_id.in_(source_ids) @@ -760,11 +783,11 @@ def merge_tracks(target_id,source_ids,dbconn=None): ) result = dbconn.execute(op) clean_db(dbconn=dbconn) - return True + @connection_provider -def merge_artists(target_id,source_ids,dbconn=None): +def merge_artists(target_id: int, source_ids: list[int], dbconn=None) -> bool: # some tracks could already have multiple of the to be merged artists @@ -792,7 +815,6 @@ def merge_artists(target_id,source_ids,dbconn=None): result = dbconn.execute(op) - # same for albums op = DB['albumartists'].select().where( DB['albumartists'].c.artist_id.in_(source_ids + [target_id]) @@ -813,7 +835,6 @@ def merge_artists(target_id,source_ids,dbconn=None): result = dbconn.execute(op) - # tracks_artists = {} # for row in result: # tracks_artists.setdefault(row.track_id,[]).append(row.artist_id) @@ -830,15 +851,14 @@ def merge_artists(target_id,source_ids,dbconn=None): # result = dbconn.execute(op) # this could have created duplicate tracks and albums - merge_duplicate_tracks(artist_id=target_id,dbconn=dbconn) - merge_duplicate_albums(artist_id=target_id,dbconn=dbconn) + merge_duplicate_tracks(artist_id=target_id, dbconn=dbconn) + merge_duplicate_albums(artist_id=target_id, dbconn=dbconn) clean_db(dbconn=dbconn) - return True @connection_provider -def merge_albums(target_id,source_ids,dbconn=None): +def merge_albums(target_id: int, source_ids: list[int], dbconn=None) -> bool: op = DB['tracks'].update().where( DB['tracks'].c.album_id.in_(source_ids) @@ -847,7 +867,6 @@ def merge_albums(target_id,source_ids,dbconn=None): ) result = dbconn.execute(op) clean_db(dbconn=dbconn) - return True @@ -1622,48 +1641,52 @@ def get_credited_artists(*artists,dbconn=None): @cached_wrapper @connection_provider -def get_track(id,dbconn=None): +def get_track(track_id: int, dbconn=None) -> TrackDict: op = DB['tracks'].select().where( - DB['tracks'].c.id==id + DB['tracks'].c.id == track_id ) result = dbconn.execute(op).all() trackinfo = result[0] - return track_db_to_dict(trackinfo,dbconn=dbconn) + return track_db_to_dict(trackinfo, dbconn=dbconn) + @cached_wrapper @connection_provider -def get_artist(id,dbconn=None): +def get_artist(artist_id: int, dbconn=None) -> str: op = DB['artists'].select().where( - DB['artists'].c.id==id + DB['artists'].c.id == artist_id ) result = dbconn.execute(op).all() artistinfo = result[0] - return artist_db_to_dict(artistinfo,dbconn=dbconn) + return artist_db_to_dict(artistinfo, dbconn=dbconn) + @cached_wrapper @connection_provider -def get_album(id,dbconn=None): +def get_album(album_id: int, dbconn=None) -> AlbumDict: op = DB['albums'].select().where( - DB['albums'].c.id==id + DB['albums'].c.id == album_id ) result = dbconn.execute(op).all() albuminfo = result[0] - return album_db_to_dict(albuminfo,dbconn=dbconn) + return album_db_to_dict(albuminfo, dbconn=dbconn) + @cached_wrapper @connection_provider -def get_scrobble(timestamp, include_internal=False, dbconn=None): +def get_scrobble(timestamp: int, include_internal=False, dbconn=None) -> ScrobbleDict: op = DB['scrobbles'].select().where( - DB['scrobbles'].c.timestamp==timestamp + DB['scrobbles'].c.timestamp == timestamp ) result = dbconn.execute(op).all() scrobble = result[0] return scrobbles_db_to_dict(rows=[scrobble], include_internal=include_internal)[0] + @cached_wrapper @connection_provider def search_artist(searchterm,dbconn=None):