ttsserver/EnginesController.py
2024-05-29 02:39:47 +03:00

43 lines
1.7 KiB
Python

import importlib
import importlib.util
import sys
from pathlib import Path

from loguru import logger

import config
from EngineABC import EngineABC, ModelDescription
from analytics import measure
from preprocessing import preprocess
class EnginesController:
    """Load every TTS engine package under ``engines/`` and dispatch calls to it.

    Engines are discovered once, at construction time. Each sub-directory of
    ``Path('engines')`` (relative to the current working directory) that
    contains an ``__init__.py`` is imported. Every class found in that module
    which lists ``EngineABC`` as a *direct* base class is instantiated once and
    stored in ``self.engines`` under its class name.
    """

    def __init__(self):
        self.engines: dict[str, EngineABC] = dict()
        engines_dir = Path('engines')
        # Sorted so that load order, and therefore the order of discovery()
        # and synth_all() results, does not depend on the filesystem's
        # directory-listing order.
        for module_dir in sorted(engines_dir.iterdir()):
            init_path = module_dir / '__init__.py'
            if not init_path.exists():
                continue  # not an engine package (plain file, data dir, ...)
            module = self._load_engine_module(module_dir.name, init_path)
            logger.info(f'imported {init_path.parent.name!r} engine')
            for obj_name in dir(module):
                obj = getattr(module, obj_name)
                # Only classes that list EngineABC directly in their bases
                # qualify: EngineABC itself and grandchild classes are skipped.
                if not (isinstance(obj, type) and EngineABC in obj.__bases__):
                    continue
                obj: type[EngineABC]
                # Each engine gets its own cache directory named after its class.
                cache_path = config.BASE / obj.__name__
                # parents=True: do not crash when config.BASE itself has not
                # been created yet (e.g. a fresh install).
                cache_path.mkdir(parents=True, exist_ok=True)
                instance: EngineABC = obj(save_path=cache_path)
                # Keyed by class name: if the same class name shows up again
                # (e.g. re-exported by another engine package), the later
                # instance replaces the earlier one.
                self.engines[obj.__name__] = instance

    @staticmethod
    def _load_engine_module(name: str, init_path: Path):
        """Import the package whose ``__init__.py`` is *init_path* as module *name*.

        This replaces the deprecated ``SourceFileLoader.load_module()``. As
        with that call, the module is registered in ``sys.modules`` under
        *name* before its code runs, so the package's own relative imports
        can resolve. It is removed again if executing the code raises.

        Returns:
            The executed module object.

        Raises:
            ImportError: if no spec/loader can be built for *init_path*.
            Any exception raised while executing the package's code.
        """
        spec = importlib.util.spec_from_file_location(name, str(init_path))
        if spec is None or spec.loader is None:
            raise ImportError(f'cannot load engine package from {init_path}', name=name)
        module = importlib.util.module_from_spec(spec)
        sys.modules[name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module registered.
            sys.modules.pop(name, None)
            raise
        return module

    def discovery(self) -> list[ModelDescription]:
        """Return the concatenated model descriptions of all loaded engines.

        Engines are queried in ``self.engines`` insertion order, i.e. the
        order in which they were loaded.
        """
        return [model for engine in self.engines.values() for model in engine.discovery()]

    @measure  # project decorator from `analytics`; presumably records timing — not visible here
    def synth(self, engine: str, model: str, text: str, **kwargs) -> bytes:
        """Synthesize *text* with *model* of the engine registered as *engine*.

        *text* is passed through ``preprocess`` first. Note that the engine's
        own ``synth`` takes ``(text, model, **kwargs)``, in that order.

        Raises:
            KeyError: if *engine* is not the class name of a loaded engine.
        """
        return self.engines[engine].synth(preprocess(text), model, **kwargs)

    def synth_all(self, text: str = 'Съешь ещё 15 этих пирожков') -> list[bytes]:
        """Synthesize *text* once with every model reported by ``discovery()``.

        The default is a short Russian sample sentence. Each call goes
        through ``synth``, so preprocessing and the ``measure`` wrapper apply
        per model. Results come back in ``discovery()`` order.
        """
        return [self.synth(model.engine, model.name, text=text) for model in self.discovery()]