diff --git a/AudioController.py b/AudioController.py index 1c6891e..4ec0c1a 100644 --- a/AudioController.py +++ b/AudioController.py @@ -1,5 +1,4 @@ import queue -import time import comtypes import psutil @@ -47,6 +46,7 @@ class AudioController: """ def __init__(self): + self.running = True self.per_session_callbacks_class = PerSessionCallbacks self._sessions: dict[int, AudioSession] = dict() # Mapping pid to session @@ -58,6 +58,11 @@ class AudioController: self.view = ServerSideView(self.outbound_q, self.inbound_q) + def shutdown_callback(self, sig, frame): + """Gets called by signal module as handler""" + logger.info(f'Shutting down by signal {sig}') + self.running = False + def get_process(self, pid: int) -> psutil.Process: return self._sessions[pid].Process @@ -86,9 +91,12 @@ class AudioController: new_session.register_notification(self.per_session_callbacks_class(new_session.ProcessId, self)) # Notifying - self.outbound_q.put(Events.NewSession(new_session.ProcessId)) - self.outbound_q.put(Events.SetName(new_session.ProcessId, get_app_name(new_session.Process))) - # TODO: Send also volume, mute, state + pid = new_session.ProcessId + self.outbound_q.put(Events.NewSession(pid)) + self.outbound_q.put(Events.SetName(pid, get_app_name(new_session.Process))) + self.outbound_q.put(Events.VolumeChanged(pid, self.get_volume(pid))) + self.outbound_q.put(Events.MuteStateChanged(pid, self.is_muted(pid))) + # TODO: Send also state else: logger.debug("None's process session", new_session, new_session.ProcessId) @@ -113,7 +121,7 @@ class AudioController: def _state_change_tick(self): try: - msg = self._state_change_q.get(timeout=3) + msg = self._state_change_q.get(timeout=0.1) logger.trace(f'New state message {msg}') except Empty: @@ -180,6 +188,10 @@ class AudioController: except Exception: logger.opt(exception=True).warning(f'Failed to unregister_notification() for pid {pid}') + # Notify ServerSideView to stop + self.view.running = False + self.view.join(1) + def set_mute(self, pid: int, is_muted: bool): logger.trace(f'Set mute for {pid} {is_muted=}') self._sessions[pid].SimpleAudioVolume.SetMute(int(is_muted), None) @@ -224,7 +236,7 @@ class AudioController: # self.perform_discover() logger.debug(f'Starting blocking') self.view.start() - while True: + while self.running: # time.sleep(1) self._state_change_tick() self._inbound_q_tick() diff --git a/Events.py b/Events.py index 7ba809c..b962ef0 100644 --- a/Events.py +++ b/Events.py @@ -1,4 +1,6 @@ +from typing import TypeVar, Generator from dataclasses import dataclass, field +from functools import lru_cache """ Processes unique identifies by their PIDs. @@ -33,17 +35,26 @@ From client to server: 2. Mute toggle PID +3. New client + *Literally nothing* + # Set PID to any value + # On this event `ServerSideView` should send full state to clients + # Note: This event should be sent by Transport, not client itself + Cases: 1. New Session: Send `New Session` event Send `Name Changed` event Send `Volume Changed` event Send `Mute State Changed` event - Send `State changed` event - # Let's call this set of events a full view since it fully describes current information about a session + Send `State changed` event; Still TODO: + # This set of events fully describes state of a session -2. Session closed +2. Session closed: Send `Session closed` event + +3. New client: + Send events as in `New Session` case """ @@ -110,13 +121,27 @@ class MuteToggle(ClientToServerEvent): ... +@dataclass +class NewClient(ClientToServerEvent): + ... + + +T = TypeVar('T') + + +@lru_cache def lookup_event(event_name: str) -> Event: - subclasses = dict() - to_handle = [Event] + for cls in enumerate_subclasses(Event): + if cls.__name__ == event_name: + return cls + + raise ValueError(f'Lookup {event_name} failed') + + +def enumerate_subclasses(base: type[T]) -> Generator[T, None, None]: + to_handle = [base] while len(to_handle) > 0: current_item = to_handle.pop() for subclass in current_item.__subclasses__(): - subclasses[subclass.__name__] = subclass + yield subclass to_handle.append(subclass) - - return subclasses[event_name] diff --git a/NetworkTransport.py b/NetworkTransport.py new file mode 100644 index 0000000..f7681d4 --- /dev/null +++ b/NetworkTransport.py @@ -0,0 +1,80 @@ +from typing import Callable +from loguru import logger +from dataclasses import asdict +import socket +import selectors +import Events +import json +from TransportABC import TransportABC + + +class NetworkTransport(TransportABC): + def __init__(self, rcv_callback: Callable[[Events.ClientToServerEvent], None]): + self._selector = selectors.DefaultSelector() + self.view_rcv_callback = rcv_callback + + self._sock = socket.socket() + self._sock.bind(('localhost', 54683)) + self._sock.listen(100) + self._sock.setblocking(False) + self._selector.register(self._sock, selectors.EVENT_READ, self._accept) + + self._connections: list[socket.socket] = list() + + def send(self, msg: Events.ServerToClientEvent): + """This method gets called by `ServerSideView` when it wants to send a message to the client""" + + # logger.debug(f'Sending {asdict(msg)}') + msg = json.dumps(asdict(msg)).encode() + b'\n' # TODO: Remove new line probably + self._send_to_all(msg) + + def _send_to_all(self, msg: bytes): + for conn in self._connections: + conn.sendall(msg) + + def _accept(self, sock: socket.socket, mask: int): + """Callback which get called when accepting new connection""" + + conn, addr = sock.accept() + logger.debug(f'Net: Accepted {conn.getpeername()}') + conn.setblocking(False) + self._selector.register(conn, selectors.EVENT_READ, self._on_socket_receive) + self._connections.append(conn) + self.view_rcv_callback(Events.NewClient(-1)) + + def _close_conn(self, conn: socket.socket): + logger.debug(f'Net: Closing connection to {conn.getpeername()}') + self._selector.unregister(conn) + self._connections.remove(conn) + conn.close() + + def _on_socket_receive(self, conn: socket.socket, mask: int): + data = conn.recv(1000) + if not data: + self._close_conn(conn) + return + + try: + event_dict = json.loads(data) + event_name = event_dict['event'] + event_cls = Events.lookup_event(event_name) + del event_dict['event'] + logger.trace(f'Passing msg {event_dict} from client {conn.getpeername()}') + event = event_cls(**event_dict) # noqa + self.view_rcv_callback(event) + + except Exception: + logger.opt(exception=True).warning(f"Couldn't parse message from client: {data}") + + def tick(self): + events = self._selector.select(timeout=0) + for key, mask in events: + callback = key.data + callback(key.fileobj, mask) + + def shutdown(self): + logger.debug(f'Net: Shutting down') + while len(self._connections) > 0: + self._close_conn(self._connections[0]) + + logger.trace(f'Net: Shutdown completed, clients disconnected') diff --git a/ProcessAudioController.py b/ProcessAudioController.py deleted file mode 100644 index df185ba..0000000 --- a/ProcessAudioController.py +++ /dev/null @@ -1,78 +0,0 @@ -from pycaw.pycaw import AudioUtilities -import pycaw.utils -from get_app_name import get_app_name - - -def get_process_session(pid: int) -> pycaw.utils.AudioSession | None: - sessions = AudioUtilities.GetAllSessions() - for session in sessions: - if session.Process and session.Process.pid == pid: - return session - - -class ProcessAudioController: - def __init__(self, *, pid: int = None, audio_session: pycaw.utils.AudioSession = None): - if pid is not None: - self._process_session = get_process_session(pid) - - if audio_session is not None: - self._process_session = audio_session - - self.process = self._process_session.Process - self.process_description = get_app_name(self.process) - - def mute(self): - self._process_session.SimpleAudioVolume.SetMute(1, None) - print(self.process.name(), "has been muted.") # debug - - def unmute(self): - self._process_session.SimpleAudioVolume.SetMute(0, None) - print(self.process.name(), "has been unmuted.") # debug - - def get_process_volume(self): - return self._process_session.SimpleAudioVolume.GetMasterVolume() - - @property - def volume(self): - return self.get_process_volume() - - def set_volume(self, decibels: float): - new_volume = min(1.0, max(0.0, decibels)) - self._process_session.SimpleAudioVolume.SetMasterVolume(new_volume, None) - print("Volume set to", new_volume) # debug - - def decrease_volume(self, decibels: float): - volume = max(0.0, self.volume - decibels) - self._process_session.SimpleAudioVolume.SetMasterVolume(volume, None) - print("Volume reduced to", volume) # debug - - def increase_volume(self, decibels: float): - # 1.0 is the max value, raise by decibels - new_volume = min(1.0, self.volume + decibels) - self._process_session.SimpleAudioVolume.SetMasterVolume(new_volume, None) - print("Volume raised to", new_volume) # debug - - -class AudioController: - processes: dict[int, ProcessAudioController] = dict() # PIDs as keys - _selected_process: Optional[ProcessAudioController] = None - - def __init__(self, view: ViewABC): - self.view = view - for session in AudioUtilities.GetAllSessions(): - if session.ProcessId != 0: - audio_process_controller = ProcessAudioController(audio_session=session) - self.processes[audio_process_controller.process_description] = audio_process_controller - - if len(self.processes) > 0: - self.selected_process = next(iter(self.processes)) - - @property - def selected_process(self) -> Optional[ProcessAudioController]: - return self._selected_process - - @selected_process.setter - def selected_process(self, pid_to_select: int): - self._selected_process = self.processes[pid_to_select] - self.view.select_process_callback(self.selected_process) - diff --git a/README.md b/README.md new file mode 100644 index 0000000..d11660c --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +A backend for application for remote control over windows mixer. +You can find events reference in `Events.py`, those events 1:1 map to json (dictionaries) they produce. +For now transport over tcp sockets is implemented. diff --git a/ServerSideView.py b/ServerSideView.py index 32ae96e..faa17bb 100644 --- a/ServerSideView.py +++ b/ServerSideView.py @@ -3,17 +3,19 @@ from queue import Queue from threading import Thread from loguru import logger import Events -from typing import TypedDict +# from typing import TypedDict from dataclasses import asdict -from Transport import Transport + +from TransportABC import TransportABC +from NetworkTransport import NetworkTransport -class SessionState(TypedDict): - pid: int - volume: float - is_muted: bool - is_active: bool - name: str +# class SessionState(TypedDict): +# pid: int +# volume: float +# is_muted: bool +# is_active: bool +# name: str class ServerSideView(Thread): @@ -24,16 +26,17 @@ class ServerSideView(Thread): The common concept: `AudioController` put messages from callbacks to queue which reads `ServerSideView` which keep up with `ClientSideView` (a client). - `ClientSideView` sends `Events` over `Transport` to `ServerSideView` which put messages to another queue - which reads `AudioController` which performs action specified in messages. + `ClientSideView` sends `Events` over `Transport` to `ServerSideView`. `Transport` calls `ServerSideView` callback + which put messages to queue which is reading by `AudioController` which performs action specified in messages. `AudioController`'s work with queues performs in main thread. Callback calls by pycaw performs in pycaw's internal threads. `ServerSideView` executing in its own thread. """ daemon = True + running = True - def __init__(self, inbound_q: Queue, outbound_q: Queue, transport: Transport = Transport()): + def __init__(self, inbound_q: Queue, outbound_q: Queue): """ :param inbound_q: Queue from AudioController to ServerSideView :param outbound_q: Queue from ServerSideView to AudioController @@ -42,32 +45,41 @@ class ServerSideView(Thread): self.inbound_q = inbound_q self.outbound_q = outbound_q - self.transport = transport + self.transport: TransportABC = NetworkTransport(self.rcv_callback) - self._state: dict[int, SessionState] = dict() # Holds current state of sessions received from AudioController + self._state: dict[int, dict[str, int | float | str]] = dict() # Holds current state of sessions received from AudioController # PID : SessionState + def rcv_callback(self, event: Events.ClientToServerEvent): + if isinstance(event, Events.NewClient): + self.inbound_q.put(event) + + else: + self.outbound_q.put(event) + def run(self) -> None: - while True: + while self.running: try: - msg: Events.ServerToClientEvent = self.inbound_q.get_nowait() + msg: Events.Event = self.inbound_q.get(timeout=0.1) except queue.Empty: pass else: - logger.debug(msg) - self._update_state(msg) - self.transport.send(msg) + # logger.debug(msg) + if isinstance(msg, Events.ServerToClientEvent): + self._update_state(msg) + self.transport.send(msg) + + elif isinstance(msg, Events.NewClient): + self._send_full_state() + + else: + logger.warning(f'Unknown event {msg}') self.transport.tick() - try: - new_msg = self.transport.receive() - self.outbound_q.put(new_msg) - - except queue.Empty: - pass + self.transport.shutdown() def _update_state(self, event: Events.ServerToClientEvent) -> None: if isinstance(event, Events.NewSession): @@ -82,4 +94,27 @@ class ServerSideView(Thread): self._state[event.PID].update(dicted) - logger.debug(f'state: {self._state}') + # logger.trace(f'state: {self._state}') + + def _send_full_state(self): + """Send full state of sessions to clients""" + logger.trace(f'Sending full state') + subclasses = tuple(Events.enumerate_subclasses(Events.ServerToClientEvent)) + for session in self._state.values(): + for cls in subclasses: + if cls.__name__ == 'SessionClosed': + continue + + try: + kwargs = dict() + for field in cls.__dict__['__dataclass_fields__'].keys(): + if field != 'event': + # args.append(session[field]) + kwargs[field] = session[field] + + event: Events.ServerToClientEvent = cls(**kwargs) # Noqa + self.transport.send(event) + + except KeyError: # We don't have appropriate field in state for this kind of events + # logger.debug(f'Passing {cls}') + pass diff --git a/Transport.py b/Transport.py deleted file mode 100644 index 749cbfb..0000000 --- a/Transport.py +++ /dev/null @@ -1,67 +0,0 @@ -from loguru import logger -from dataclasses import asdict -import socket -import selectors -from queue import Queue -import Events -import json - - -class Transport: - def __init__(self): - self._selector = selectors.DefaultSelector() - self._from_net_q = Queue() - - self._sock = socket.socket() - self._sock.bind(('localhost', 54683)) - self._sock.listen(100) - self._sock.setblocking(False) - self._selector.register(self._sock, selectors.EVENT_READ, self._accept) - - self._connections: list[socket.socket] = list() - # self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # self.sock.bind(("127.0.0.1", 54683)) - ... - - def send(self, msg: Events.ServerToClientEvent): - logger.debug(f'Sending {asdict(msg)}') - msg = json.dumps(asdict(msg)).encode() - for conn in self._connections: - conn.sendall(msg) - - def receive(self) -> Events.ClientToServerEvent: - return self._from_net_q.get_nowait() - - def _accept(self, sock: socket.socket, mask: int): - conn, addr = sock.accept() # Should be ready - print('accepted', conn, 'from', addr) - conn.setblocking(False) - self._selector.register(conn, selectors.EVENT_READ, self._read) - self._connections.append(conn) - - def _read(self, conn: socket.socket, mask: int): - data = conn.recv(1000) - if not data: - logger.debug(f'Closing connection to {conn.getpeername()}') - self._selector.unregister(conn) - self._connections.remove(conn) - conn.close() - return - - try: - event_dict = json.loads(data) - event_name = event_dict['event'] - event_cls = Events.lookup_event(event_name) - del event_dict['event'] - logger.trace(f'Passing msg {event_dict} from client {conn.getpeername()}') - self._from_net_q.put(event_cls(**event_dict)) - - except Exception: - logger.opt(exception=True).warning(f"Couldn't parse message from client: {data}") - - def tick(self): - events = self._selector.select(timeout=0) - for key, mask in events: - callback = key.data - callback(key.fileobj, mask) diff --git a/TransportABC.py b/TransportABC.py new file mode 100644 index 0000000..5e7db23 --- /dev/null +++ b/TransportABC.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from typing import Callable + +import Events + + +class TransportABC(ABC): + @abstractmethod + def __init__(self, rcv_callback: Callable[[Events.ClientToServerEvent], None]): + """Should call rcv_callback in order to pass received from client event""" + + @abstractmethod + def send(self, msg: Events.ServerToClientEvent): + """This method gets called by `ServerSideView` when it has an event to send to client""" + + @abstractmethod + def tick(self): + """This method get called by `ServerSideView` every little piece of time in order to allow + `Transport` to handle inbound messages (or other stuff `Transport` should do continuously""" + + @abstractmethod + def shutdown(self): + """Gets called by `ServerSideView` on program shutdown, the class should clean up all connections + and such staff""" diff --git a/main.py b/main.py index 8da4ac4..8ebeccd 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ import sys; sys.coinit_flags = 0 # noqa from loguru import logger import logging +import signal class InterceptHandler(logging.Handler): @@ -29,6 +30,10 @@ import AudioController mgr = AudioUtilities.GetAudioSessionManager() audio_controller = AudioController.AudioController() + +signal.signal(signal.SIGTERM, audio_controller.shutdown_callback) +signal.signal(signal.SIGINT, audio_controller.shutdown_callback) + callback = AudioController.SessionCreateCallback(audio_controller) mgr.RegisterSessionNotification(callback) @@ -43,4 +48,3 @@ except KeyboardInterrupt: finally: mgr.UnregisterSessionNotification(callback) audio_controller.pre_shutdown() - logger.debug(audio_controller._sessions)