43 lines
1.7 KiB
Python
43 lines
1.7 KiB
Python
import importlib
|
|
from pathlib import Path
|
|
|
|
from loguru import logger
|
|
|
|
import config
|
|
from EngineABC import EngineABC, ModelDescription
|
|
from analytics import measure
|
|
from preprocessing import preprocess
|
|
|
|
|
|
class EnginesController:
    """Discover engine plugins under ``engines/`` and route synthesis calls to them.

    On construction, every sub-directory of ``engines`` (resolved relative to
    the current working directory) that contains an ``__init__.py`` is
    imported.  Each class exposed by that module which lists ``EngineABC``
    among its direct bases is instantiated with its own cache directory and
    stored in ``self.engines`` under the class name.
    """

    def __init__(self):
        """Import all engine packages and instantiate the engines they define.

        Raises:
            FileNotFoundError: if the ``engines`` directory does not exist.
            ImportError: if an engine package cannot be loaded.
        """
        self.engines: dict[str, EngineABC] = {}

        engines_dir = Path('engines')
        for module_dir in engines_dir.iterdir():
            init_path = module_dir / '__init__.py'
            if not init_path.exists():
                # Not an engine package (plain file or directory without __init__.py).
                continue

            module = self._load_engine_module(module_dir.name, init_path)
            logger.info(f'imported {init_path.parent.name!r} engine')
            self._register_engines(module)

    @staticmethod
    def _load_engine_module(name: str, init_path: Path):
        """Import the package whose ``__init__.py`` is *init_path* as module *name*.

        Uses the spec-based import API instead of the deprecated
        ``SourceFileLoader.load_module()``.  As with ``load_module()``, the
        module is registered in ``sys.modules`` under *name* and removed again
        if executing it fails.
        """
        # Imported here so the block does not rely on ``import importlib``
        # having implicitly bound the submodule as an attribute.
        import importlib.util
        import sys

        spec = importlib.util.spec_from_file_location(name, str(init_path))
        if spec is None or spec.loader is None:
            raise ImportError(f'cannot load engine module from {str(init_path)!r}')

        module = importlib.util.module_from_spec(spec)
        # Register before executing so imports inside the package can resolve it.
        sys.modules[name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module behind.
            sys.modules.pop(name, None)
            raise
        return module

    def _register_engines(self, module) -> None:
        """Instantiate every direct ``EngineABC`` subclass found in *module*."""
        for obj_name in dir(module):
            obj = getattr(module, obj_name)
            # Only classes that name EngineABC as a *direct* base are registered.
            if not (isinstance(obj, type) and EngineABC in obj.__bases__):
                continue

            cache_path = config.BASE / obj.__name__
            # parents=True: also create config.BASE itself on a fresh setup.
            cache_path.mkdir(parents=True, exist_ok=True)
            self.engines[obj.__name__] = obj(save_path=cache_path)

    def discovery(self) -> list[ModelDescription]:
        """Return the concatenated ``discovery()`` results of all loaded engines."""
        to_return: list[ModelDescription] = []
        for engine in self.engines.values():
            to_return.extend(engine.discovery())
        return to_return

    @measure
    def synth(self, engine: str, model: str, text: str, **kwargs) -> bytes:
        """Preprocess *text* and synthesize it with *model* of the named engine.

        Args:
            engine: key into ``self.engines`` (the engine class name).
            model: model identifier passed through to the engine.
            text: raw input text; it is run through ``preprocess`` first.
            **kwargs: forwarded unchanged to the engine's ``synth``.

        Raises:
            KeyError: if *engine* is not a loaded engine.
        """
        return self.engines[engine].synth(preprocess(text), model, **kwargs)

    def synth_all(self, text: str = 'Съешь ещё 15 этих пирожков') -> list[bytes]:
        """Synthesize *text* with every model reported by ``discovery()``."""
        return [self.synth(model.engine, model.name, text=text) for model in self.discovery()]
|