diff --git a/DB.py b/DB.py index 7139bd7..0b968be 100644 --- a/DB.py +++ b/DB.py @@ -1,11 +1,10 @@ -import os from datetime import datetime -from pathlib import Path import peewee -DB_PATH = Path(os.getenv('DATA_DIR', '.')) / 'voice_cache.sqlite' -database = peewee.SqliteDatabase(str(DB_PATH)) +import config + +database = peewee.SqliteDatabase(str(config.DB_PATH)) class BaseModel(peewee.Model): diff --git a/Dockerfile b/Dockerfile index 6e3073d..9de4b60 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,8 +8,7 @@ ENV PYTHONUNBUFFERED 1 RUN apt update && apt install git gcc libc6-dev -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/* COPY requirements.txt . RUN --mount=type=cache,target=/root/.cache/pip pip install Cython && \ - pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt && \ - pip wheel torch numpy --wheel-dir /app/wheels --index-url https://download.pytorch.org/whl/cpu + pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt @@ -20,7 +19,6 @@ ENV PYTHONUNBUFFERED 1 RUN useradd -ms /bin/bash silero_user && \ apt update && apt install ffmpeg -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/* -USER silero_user WORKDIR /app @@ -30,5 +28,6 @@ COPY --from=builder /app/requirements.txt . RUN pip install --no-cache /wheels/* COPY . . +USER silero_user CMD ["python3", "/app/main.py"] diff --git a/DynamicCommandPrefix.py b/DynamicCommandPrefix.py index ea0e14d..548b55c 100644 --- a/DynamicCommandPrefix.py +++ b/DynamicCommandPrefix.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- import discord -from discord.ext import commands import peewee +from discord.ext import commands + import DB + + # from loguru import logger diff --git a/FFmpegPCMAudioModified.py b/FFmpegPCMAudioModified.py index 1c00475..6cf1943 100644 --- a/FFmpegPCMAudioModified.py +++ b/FFmpegPCMAudioModified.py @@ -1,8 +1,10 @@ -import subprocess -import shlex import io -from discord.opus import Encoder +import shlex +import subprocess + import discord +from discord.opus import Encoder + # Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571 diff --git a/Observ/Observer.py b/Observ/Observer.py index d17d793..affdd31 100644 --- a/Observ/Observer.py +++ b/Observ/Observer.py @@ -1,8 +1,9 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod + import discord class Observer: @abstractmethod - def update(self, message: discord.Message) -> None: + async def update(self, message: discord.Message) -> None: raise NotImplemented diff --git a/Observ/Subject.py b/Observ/Subject.py index 65c0b57..ce42785 100644 --- a/Observ/Subject.py +++ b/Observ/Subject.py @@ -1,4 +1,5 @@ import discord + from .Observer import Observer diff --git a/SpeakersSettingsAdapterDiscord.py b/SpeakersSettingsAdapterDiscord.py index 51a46ec..1927904 100644 --- a/SpeakersSettingsAdapterDiscord.py +++ b/SpeakersSettingsAdapterDiscord.py @@ -1,12 +1,9 @@ # -*- coding: utf-8 -*- import DB -from TTSSilero import Speakers class SpeakersSettingsAdapterDiscord: - DEFAULT_SPEAKER = Speakers.kseniya - - def get_speaker(self, guild_id: int, user_id: int) -> Speakers: + def get_speaker(self, guild_id: int, user_id: int) -> str: user_defined_speaker = self.get_speaker_user(guild_id, user_id) if user_defined_speaker is None: return self.get_speaker_global(guild_id) @@ -14,36 +11,29 @@ class SpeakersSettingsAdapterDiscord: else: return user_defined_speaker - def get_speaker_global(self, guild_id: int) -> Speakers: + def get_speaker_global(self, guild_id: int) -> str | None: server_speaker_query = DB.ServerSpeaker.select()\ .where(DB.ServerSpeaker.server_id == guild_id) if server_speaker_query.count() == 1: - return Speakers(server_speaker_query.get().speaker) + return server_speaker_query.get().speaker else: - return self.DEFAULT_SPEAKER + return None - def get_speaker_user(self, guild_id: int, user_id: int) -> Speakers | None: + def get_speaker_user(self, guild_id: int, user_id: int) -> str | None: user_speaker_query = DB.UserServerSpeaker.select()\ .where(DB.UserServerSpeaker.server_id == guild_id)\ .where(DB.UserServerSpeaker.user_id == user_id) if user_speaker_query.count() == 1: - return Speakers(user_speaker_query.get().speaker) + return user_speaker_query.get().speaker else: return None - @property - def available_speakers(self) -> set[str]: - return {speaker.name for speaker in Speakers} + def set_speaker_user(self, guild_id: int, user_id: int, speaker: str) -> None: + DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker).execute() - def set_speaker_user(self, guild_id: int, user_id: int, speaker: Speakers) -> None: - DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker.value).execute() - - def set_speaker_global(self, guild_id: int, speaker: Speakers) -> None: - DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker.value).execute() - - -speakers_settings_adapter = SpeakersSettingsAdapterDiscord() + def set_speaker_global(self, guild_id: int, user_id: int, speaker: str) -> None: + DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker).execute() diff --git a/TTSServer/Speakers.py b/TTSServer/Speakers.py new file mode 100644 index 0000000..a5d30f0 --- /dev/null +++ b/TTSServer/Speakers.py @@ -0,0 +1,15 @@ +from typing import Literal + +from pydantic import BaseModel + + +class Argument(BaseModel): + type: Literal['str', 'int', 'float'] + description: str | None = None + + +class ModelDescription(BaseModel): + engine: str + name: str + arguments: dict[str, Argument] + description: None | str = None diff --git a/TTSServer/TTSServer.py b/TTSServer/TTSServer.py new file mode 100644 index 0000000..089fa4f --- /dev/null +++ b/TTSServer/TTSServer.py @@ -0,0 +1,19 @@ +import aiohttp + +import config +from .Speakers import ModelDescription + + +class TTSServer: + def __init__(self): + self.session = aiohttp.ClientSession(base_url=config.BASE_URL) + + async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes: + async with self.session.post(url=f'/synth/{speaker.engine}/model/{speaker.name}', json={'text': text}) as req: + req.raise_for_status() + return await req.content.read() + + async def discovery(self) -> list[ModelDescription]: + async with self.session.get('/discovery') as req: + req.raise_for_status() + return [ModelDescription(**item) for item in await req.json()] diff --git a/TTSServer/TTSServerCached.py b/TTSServer/TTSServerCached.py new file mode 100644 index 0000000..33dae3d --- /dev/null +++ b/TTSServer/TTSServerCached.py @@ -0,0 +1,53 @@ +from loguru import logger + +import DB +from .Speakers import ModelDescription +from .TTSServer import TTSServer + + +class TTSServerCached(TTSServer): + def __init__(self): + super().__init__() + self.models_lookup_idx: dict[str, ModelDescription] = dict() + + async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes: + cache_query = DB.SoundCache.select() \ + .where(DB.SoundCache.text == text) \ + .where(DB.SoundCache.speaker == speaker) + + if cache_query.count() == 1: + with DB.database.atomic(): + DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1}) \ + .where(DB.SoundCache.text == text) \ + .where(DB.SoundCache.speaker == speaker).execute() + + cached = cache_query.get().audio + + return cached + + else: + synthesized = await super().synthesize_text(text, speaker) + DB.SoundCache.create(text=text, speaker=speaker, audio=synthesized) + return synthesized + + async def discovery(self) -> list[ModelDescription]: + res = await super().discovery() + logger.debug(f'Discovered {len(res)} models') + self.models_lookup_idx: dict[str, ModelDescription] = {f'{desc.engine}_{desc.name}': desc for desc in res} + return res + + async def speaker_to_description(self, speaker_str: str) -> ModelDescription | None: + for attempt in range(2): + if len(self.models_lookup_idx) > 0: + break + + await self.discovery() + + else: + raise RuntimeError('Models discovery seems to return zero models') + + # Default to first model + return self.models_lookup_idx.get(speaker_str, None) + + def default_speaker(self) -> ModelDescription: + return list(self.models_lookup_idx.values())[0] diff --git a/TTSServer/__init__.py b/TTSServer/__init__.py new file mode 100644 index 0000000..97eb52f --- /dev/null +++ b/TTSServer/__init__.py @@ -0,0 +1,3 @@ +from .Speakers import ModelDescription +from .TTSServer import TTSServer +from .TTSServerCached import TTSServerCached diff --git a/TTSSilero/Speakers.py b/TTSSilero/Speakers.py deleted file mode 100644 index 1bfb4a2..0000000 --- a/TTSSilero/Speakers.py +++ /dev/null @@ -1,15 +0,0 @@ -import enum - - -class Speakers(enum.Enum): - aidar = 'aidar' - baya = 'baya' - kseniya = 'kseniya' - irina = 'irina' - ruslan = 'ruslan' - natasha = 'natasha' - thorsten = 'thorsten' - tux = 'tux' - gilles = 'gilles' - lj = 'lj' - dilyara = 'dilyara' diff --git a/TTSSilero/TTSSilero.py b/TTSSilero/TTSSilero.py deleted file mode 100644 index bfdbd7c..0000000 --- a/TTSSilero/TTSSilero.py +++ /dev/null @@ -1,63 +0,0 @@ -import contextlib -import io -import os -import wave -from pathlib import Path - -import torch.package - -from .Speakers import Speakers -from .multi_v2_package import TTSModelMulti_v2 - - -class TTSSilero: - def __init__(self, threads: int = 12): - device = torch.device('cpu') - torch.set_num_threads(threads) - local_file = Path(os.getenv('DATA_DIR', '.')) / 'model_multi.pt' - - if not os.path.isfile(local_file): - torch.hub.download_url_to_file( - 'https://models.silero.ai/models/tts/multi/v2_multi.pt', - str(local_file) - ) - - self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model") - self.model.to(device) - - self.sample_rate = 16000 - - def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes: - return self.to_wav(self._synthesize_text(text, speaker)) - - def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]: - """ - Performs splitting text and synthesizing it - - :param text: - :return: - """ - - results_list: list[torch.Tensor] = self.model.apply_tts( - texts=[text], - speakers=speaker.value, - sample_rate=self.sample_rate - ) - - return results_list - - def to_wav(self, synthesized_text: list[torch.Tensor]) -> bytes: - res_io_stream = io.BytesIO() - - with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(self.sample_rate) - for result in synthesized_text: - wf.writeframes((result * 32767).numpy().astype('int16')) - - res_io_stream.seek(0) - - return res_io_stream.read() - - diff --git a/TTSSilero/TTSSileroCached.py b/TTSSilero/TTSSileroCached.py deleted file mode 100644 index 75ba3e6..0000000 --- a/TTSSilero/TTSSileroCached.py +++ /dev/null @@ -1,36 +0,0 @@ -import time -from typing import Union - -from .TTSSilero import TTSSilero -from .Speakers import Speakers -import DB -import sqlite3 -from loguru import logger - - -class TTSSileroCached(TTSSilero): - def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes: - # start = time.time() - cache_query = DB.SoundCache.select()\ - .where(DB.SoundCache.text == text)\ - .where(DB.SoundCache.speaker == speaker.value) - - if cache_query.count() == 1: - with DB.database.atomic(): - DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1})\ - .where(DB.SoundCache.text == text)\ - .where(DB.SoundCache.speaker == speaker.value).execute() - - cached = cache_query.get().audio - - return cached - - else: - # logger.debug(f'Starting synthesis') - # start2 = time.time() - synthesized = super().synthesize_text(text, speaker) - # logger.debug(f'Synthesis done in {time.time() - start2} s in {time.time() - start} s after start') - DB.SoundCache.create(text=text, speaker=speaker.value, audio=synthesized) - - # logger.debug(f'Cache set in {time.time() - start2} synth end and {time.time() - start2} s after start') - return synthesized diff --git a/TTSSilero/__init__.py b/TTSSilero/__init__.py deleted file mode 100644 index 4adca7f..0000000 --- a/TTSSilero/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .TTSSilero import TTSSilero -from .Speakers import Speakers -from .TTSSileroCached import TTSSileroCached diff --git a/TTSSilero/multi_v2_package.py b/TTSSilero/multi_v2_package.py deleted file mode 100644 index 5cbe4f7..0000000 --- a/TTSSilero/multi_v2_package.py +++ /dev/null @@ -1,148 +0,0 @@ -import re -import wave -import torch -import warnings -import contextlib - -# for type hints only - - -class TTSModelMulti_v2(): - def __init__(self, model_path, symbols): - self.model = self.init_jit_model(model_path) - self.symbols = symbols - self.device = torch.device('cpu') - speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha', - 'thorsten', 'tux', 'gilles', 'lj', 'dilyara'] - self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)} - - def init_jit_model(self, model_path: str): - torch.set_grad_enabled(False) - model = torch.jit.load(model_path, map_location='cpu') - model.eval() - return model - - def prepare_text_input(self, text, symbols, symbol_to_id=None): - if len(text) > 140: - warnings.warn('Text string is longer than 140 symbols.') - - if symbol_to_id is None: - symbol_to_id = {s: i for i, s in enumerate(symbols)} - - text = text.lower() - text = re.sub(r'[^{}]'.format(symbols[2:]), '', text) - text = re.sub(r'\s+', ' ', text).strip() - if text[-1] not in ['.', '!', '?']: - text = text + '.' - text = text + symbols[1] - - text_ohe = [symbol_to_id[s] for s in text if s in symbols] - text_tensor = torch.LongTensor(text_ohe) - return text_tensor - - def prepare_tts_model_input(self, text: str or list, symbols: str, speakers: list): - assert len(speakers) == len(text) or len(speakers) == 1 - if type(text) == str: - text = [text] - symbol_to_id = {s: i for i, s in enumerate(symbols)} - if len(text) == 1: - return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0]) - - text_tensors = [] - for string in text: - string_tensor = self.prepare_text_input(string, symbols, symbol_to_id) - text_tensors.append(string_tensor) - input_lengths, ids_sorted_decreasing = torch.sort( - torch.LongTensor([len(t) for t in text_tensors]), - dim=0, descending=True) - max_input_len = input_lengths[0] - batch_size = len(text_tensors) - - text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32) - if len(speakers) == 1: - speakers = speakers*batch_size - speaker_ids = torch.LongTensor(batch_size).zero_() - - for i, idx in enumerate(ids_sorted_decreasing): - text_tensor = text_tensors[idx] - in_len = text_tensor.size(0) - text_padded[i, :in_len] = text_tensor - speaker_ids[i] = speakers[idx] - - return text_padded, speaker_ids, ids_sorted_decreasing - - def process_tts_model_output(self, out, out_lens, ids): - out = out.to('cpu') - out_lens = out_lens.to('cpu') - _, orig_ids = ids.sort() - - proc_outs = [] - orig_out = out.index_select(0, orig_ids) - orig_out_lens = out_lens.index_select(0, orig_ids) - - for i, out_len in enumerate(orig_out_lens): - proc_outs.append(orig_out[i][:out_len]) - return proc_outs - - def to(self, device): - self.model = self.model.to(device) - self.device = device - - def get_speakers(self, speakers: str or list): - if type(speakers) == str: - speakers = [speakers] - speaker_ids = [] - for speaker in speakers: - try: - speaker_id = self.speaker_to_id[speaker] - speaker_ids.append(speaker_id) - except Exception: - raise ValueError(f'No such speaker: {speaker}') - return speaker_ids - - def apply_tts(self, texts: str or list, - speakers: str or list, - sample_rate: int = 16000): - speaker_ids = self.get_speakers(speakers) - text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts, - symbols=self.symbols, - speakers=speaker_ids) - with torch.inference_mode(): - out, out_lens = self.model(text_padded.to(self.device), - speaker_ids.to(self.device), - sr=sample_rate) - audios = self.process_tts_model_output(out, out_lens, orig_ids) - return audios - - @staticmethod - def write_wave(path, audio, sample_rate): - """Writes a .wav file. - Takes path, PCM audio data, and sample rate. - """ - with contextlib.closing(wave.open(path, 'wb')) as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(sample_rate) - wf.writeframes(audio) - - def save_wav(self, texts: str or list, - speakers: str or list, - audio_pathes: str or list = '', - sample_rate: int = 16000): - if type(texts) == str: - texts = [texts] - - if not audio_pathes: - audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))] - if type(audio_pathes) == str: - audio_pathes = [audio_pathes] - assert len(audio_pathes) == len(texts) - - audio = self.apply_tts(texts=texts, - speakers=speakers, - sample_rate=sample_rate) - for i, _audio in enumerate(audio): - self.write_wave(path=audio_pathes[i], - audio=(_audio * 32767).numpy().astype('int16'), - sample_rate=sample_rate) - return audio_pathes diff --git a/cogErrorHandlers.py b/cogErrorHandlers.py index d9410e7..f8e4c43 100644 --- a/cogErrorHandlers.py +++ b/cogErrorHandlers.py @@ -1,6 +1,6 @@ -from loguru import logger from discord.ext import commands from discord.ext.commands import Context +from loguru import logger class cogErrorHandlers: diff --git a/cogs/BotInformation.py b/cogs/BotInformation.py index fb49d84..febe88b 100644 --- a/cogs/BotInformation.py +++ b/cogs/BotInformation.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- +import datetime +import time + import discord from discord.ext import commands, tasks from discord.ext.commands import Context -import datetime -import time class BotInformation(commands.Cog): @@ -15,9 +16,9 @@ class BotInformation(commands.Cog): @commands.command('info') async def get_info(self, ctx: Context): info = f""" -Text-To-Speech bot, based on Silero TTS model () +Text-To-Speech bot License: `GNU GENERAL PUBLIC LICENSE Version 3` -Author discord: `a31#6403` +Author discord: `@a31` Author email: `a31@demb.uk` Source code on github: Source code on gitea a31's instance: diff --git a/cogs/BotManagement.py b/cogs/BotManagement.py index 8d5cf34..94900a5 100644 --- a/cogs/BotManagement.py +++ b/cogs/BotManagement.py @@ -1,15 +1,22 @@ # -*- coding: utf-8 -*- import asyncio import time -from formatting import format_table -from discord.ext import commands -from discord.ext.commands import Context -import rfoo -from rfoo.utils import rconsole +try: + import rfoo + from rfoo.utils import rconsole + +except ImportError as e: + print("not importing rfoo", e) + from threading import Thread from typing import Optional, Coroutine + +from discord.ext import commands +from discord.ext.commands import Context from loguru import logger +from formatting import format_table + class SingletonBase: def __new__(cls, *args, **kwargs): @@ -24,7 +31,7 @@ class BotManagement(commands.Cog, SingletonBase): instance: Optional['BotManagement'] = None rfoo_server_thread: Optional[Thread] = None - rfoo_server: Optional[rfoo.InetServer] = None + rfoo_server: Optional['rfoo.InetServer'] = None # def __new__(cls, *args, **kwargs): # if cls.instance is None: @@ -105,7 +112,7 @@ class BotManagement(commands.Cog, SingletonBase): @commands.command('shutdown') async def shutdown(self, ctx: Context) -> None: """Shutdown bot with hope docker daemon will restart it, `@a31` and `@furiz__` has rights for it""" - if ctx.author.id in (420130693696323585, # @a31 + if ctx.author.id in (420130693696323585, # @a31 444819880781545472): # @furiz__ log_msg = f"Got shutdown command by {ctx.author}" @@ -116,7 +123,7 @@ class BotManagement(commands.Cog, SingletonBase): await (await self.bot.fetch_user(420130693696323585)).send(log_msg) except Exception as e: - logger.opt(exception=e).warning(f'Failed to send shutdown message',) + logger.opt(exception=e).warning(f'Failed to send shutdown message', ) self.bot.loop.create_task(self.bot.close()) @@ -127,7 +134,7 @@ class BotManagement(commands.Cog, SingletonBase): async def setup(bot: commands.Bot): await bot.add_cog(BotManagement(bot)) - async def teardown(bot): stop_res = BotManagement(bot).stop_rfoo() logger.info(f'Unloaded rfoo with result {stop_res} during BotManagement unload') + diff --git a/cogs/TTSCore.py b/cogs/TTSCore.py index bd68a8c..4be15ff 100644 --- a/cogs/TTSCore.py +++ b/cogs/TTSCore.py @@ -2,18 +2,21 @@ import subprocess import time from collections import defaultdict +from typing import Union, Optional + +import discord from discord.ext import commands from discord.ext.commands import Context -import discord -import DB -from typing import Union, Optional from loguru import logger -from TTSSilero import TTSSileroCached -from TTSSilero import Speakers -from FFmpegPCMAudioModified import FFmpegPCMAudio + +import DB import Observ +import formatting +from FFmpegPCMAudioModified import FFmpegPCMAudio +from SpeakersSettingsAdapterDiscord import SpeakersSettingsAdapterDiscord +from TTSServer import TTSServerCached +from TTSServer.Speakers import ModelDescription from cogErrorHandlers import cogErrorHandlers -from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord class TTSCore(commands.Cog, Observ.Observer): @@ -21,9 +24,9 @@ class TTSCore(commands.Cog, Observ.Observer): self.bot = bot self.cog_command_error = cogErrorHandlers.missing_argument_handler self.bot.subscribe(self) # subscribe for messages that aren't commands - self.tts = TTSSileroCached() + self.tts = TTSServerCached() self.tts_queues: dict[int, list[discord.AudioSource]] = defaultdict(list) - self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter + self.speakers_adapter: SpeakersSettingsAdapterDiscord = SpeakersSettingsAdapterDiscord() @commands.command('drop') async def drop_queue(self, ctx: Context): @@ -89,25 +92,28 @@ class TTSCore(commands.Cog, Observ.Observer): await message.channel.send('We are in different voice channels') return - speaker: Speakers = self.speakers_adapter.get_speaker(message.guild.id, message.author.id) + speaker_str: str = self.speakers_adapter.get_speaker(message.guild.id, message.author.id) + speaker: ModelDescription | None = await self.tts.speaker_to_description(speaker_str) + if speaker is None: + speaker = self.tts.default_speaker() # check if message will fail on synthesis - if DB.SynthesisErrors.select()\ - .where(DB.SynthesisErrors.speaker == speaker.value)\ - .where(DB.SynthesisErrors.text == message.content)\ + if DB.SynthesisErrors.select() \ + .where(DB.SynthesisErrors.speaker == speaker_str) \ + .where(DB.SynthesisErrors.text == message.content) \ .count() == 1: # Then we will not try to synthesis it await message.channel.send(f"I will not synthesis this message due to TTS engine limitations") return try: - wav_file_like_object = self.tts.synthesize_text(message.content, speaker=speaker) + wav_file_like_object: bytes = await self.tts.synthesize_text(message.content, speaker=speaker) sound_source = FFmpegPCMAudio(wav_file_like_object, pipe=True, stderr=subprocess.PIPE) except Exception as synth_exception: logger.opt(exception=True).warning(f'Exception on synthesize {message.content!r}: {synth_exception}') await message.channel.send(f'Synthesize error') - DB.SynthesisErrors.create(speaker=speaker.value, text=message.content) + DB.SynthesisErrors.create(speaker=speaker, text=message.content) return else: @@ -115,13 +121,15 @@ class TTSCore(commands.Cog, Observ.Observer): if voice_client.is_playing(): # Then we need to enqueue prepared sound for playing through self.tts_queues mechanism self.tts_queues[message.guild.id].append(sound_source) - await message.channel.send(f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}") + await message.channel.send( + f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}") return voice_client.play(sound_source, after=lambda e: self.queue_player(message)) except Exception as play_exception: - logger.opt(exception=True).warning(f'Exception on playing for: {message.guild.name}[#{message.channel.name}]: {message.author.display_name} / {play_exception}') + logger.opt(exception=True).warning( + f'Exception on playing for: {message.guild.name}[#{message.channel.name}]: {message.author.display_name} / {play_exception}') await message.channel.send(f'Playing error') return @@ -151,7 +159,8 @@ class TTSCore(commands.Cog, Observ.Observer): pass @commands.Cog.listener() - async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState): + async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState, + after: discord.VoiceState): if after.channel is None: members = before.channel.members if len(members) == 1: @@ -160,6 +169,95 @@ class TTSCore(commands.Cog, Observ.Observer): # TODO: leave voice channel after being moved there alone + # Ex TTSSettings.py + @commands.command('getAllSpeakers') + async def get_speakers(self, ctx: Context): + """ + Enumerate all available to set speakers + + :param ctx: + :return: + """ + + await ctx.send( + formatting.format_table( + tuple((model.engine + '_' + model.name, model.description) + for model in await self.tts.discovery()), + ('model', 'description') + )) + + @commands.command('setPersonalSpeaker') + async def set_user_speaker(self, ctx: Context, speaker: str): + """ + Set personal speaker on this server + + :param ctx: + :param speaker: + :return: + """ + + if await self.tts.speaker_to_description(speaker) is None: + return await ctx.send( + f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command") + + self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, speaker) + await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`') + + @commands.command('setServerSpeaker') + async def set_server_speaker(self, ctx: Context, speaker: str): + """ + Set global server speaker + + :param ctx: + :param speaker: + :return: + """ + if await self.tts.speaker_to_description(speaker) is None: + return await ctx.send( + f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command") + + self.speakers_adapter.set_speaker_global(ctx.guild.id, ctx.author.id, speaker) + await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`') + + @commands.command('getSpeaker') + async def get_speaker(self, ctx: Context): + """ + Tell first appropriate speaker for a user, it can be user specified, server specified or server default + + :param ctx: + :return: + """ + speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id) + + await ctx.reply(f'Your current speaker is `{speaker}`') + + @commands.command('getPersonalSpeaker') + async def get_personal_speaker(self, ctx: Context): + """ + Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one + + :param ctx: + :return: + """ + speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id) + if speaker is None: + server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id) + await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`") + + else: + await ctx.reply(f"Your personal speaker is `{speaker}`") + + @commands.command('getServerSpeaker') + async def get_server_speaker(self, ctx: Context): + """ + Tell server global speaker + + :param ctx: + :return: + """ + speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id) + await ctx.send(f"Current server speaker is `{speaker}`") + async def setup(bot): await bot.add_cog(TTSCore(bot)) diff --git a/cogs/TTSSettings.py b/cogs/TTSSettings.py deleted file mode 100644 index 4e9ca4b..0000000 --- a/cogs/TTSSettings.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding: utf-8 -*- -from discord.ext import commands -from discord.ext.commands import Context -from TTSSilero import Speakers -from cogErrorHandlers import cogErrorHandlers -from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord - - -class TTSSettings(commands.Cog): - def __init__(self, bot: commands.Bot): - self.bot = bot - self.cog_command_error = cogErrorHandlers.missing_argument_handler - self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter - - @commands.command('getAllSpeakers') - async def get_speakers(self, ctx: Context): - """ - Enumerate all available to set speakers - - :param ctx: - :return: - """ - speakers = '\n'.join(self.speakers_adapter.available_speakers) - - await ctx.send(f"```\n{speakers}```") - - @commands.command('setPersonalSpeaker') - async def set_user_speaker(self, ctx: Context, speaker: str): - """ - Set personal speaker on this server - - :param ctx: - :param speaker: - :return: - """ - try: - checked_speaker: Speakers = Speakers(speaker) - self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, checked_speaker) - await ctx.reply(f'Successfully set **your personal** speaker to `{checked_speaker.value}`') - - except (KeyError, ValueError): - await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command") - - @commands.command('setServerSpeaker') - async def set_server_speaker(self, ctx: Context, speaker: str): - """ - Set global server speaker - - :param ctx: - :param speaker: - :return: - """ - try: - checked_speaker: Speakers = Speakers(speaker) - self.speakers_adapter.set_speaker_global(ctx.guild.id, checked_speaker) - await ctx.send(f'Successfully set **server** speaker to `{checked_speaker.value}`') - - except (KeyError, ValueError): - await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command") - - @commands.command('getSpeaker') - async def get_speaker(self, ctx: Context): - """ - Tell first appropriate speaker for a user, it can be user specified, server specified or server default - - :param ctx: - :return: - """ - speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id) - - await ctx.reply(f'Your current speaker is `{speaker.value}`') - - @commands.command('getPersonalSpeaker') - async def get_personal_speaker(self, ctx: Context): - """ - Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one - - :param ctx: - :return: - """ - speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id) - if speaker is None: - server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id).value - await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`") - - else: - await ctx.reply(f"Your personal speaker is `{speaker.value}`") - - @commands.command('getServerSpeaker') - async def get_server_speaker(self, ctx: Context): - """ - Tell server global speaker - - :param ctx: - :return: - """ - speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id) - await ctx.send(f"Current server speaker is `{speaker.value}`") - - -async def setup(bot): - await bot.add_cog(TTSSettings(bot)) diff --git a/cogs/prefixConfiguration.py b/cogs/prefixConfiguration.py index da75d51..597048d 100644 --- a/cogs/prefixConfiguration.py +++ b/cogs/prefixConfiguration.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from discord.ext import commands from discord.ext.commands import Context -import DB from loguru import logger + +import DB import DynamicCommandPrefix from cogErrorHandlers import cogErrorHandlers diff --git a/config.py b/config.py new file mode 100644 index 0000000..df86bb5 --- /dev/null +++ b/config.py @@ -0,0 +1,5 @@ +import os +from pathlib import Path + +DB_PATH = Path(os.getenv('DATA_DIR', '.')) / 'voice_cache.sqlite' +BASE_URL = os.environ['BASE_URL'] diff --git a/formatting.py b/formatting.py index 6a0e3c3..18c5569 100644 --- a/formatting.py +++ b/formatting.py @@ -11,7 +11,7 @@ def format_table(data: Iterable[Iterable[str]], header: Iterable[str] = MISSING) result = '```\n' for row in data: - row = [item.replace('`', '\\`') for item in row] + row = [str(item).replace('`', '\\`') for item in row] result += '\t'.join(row) + '\n' result += '```' diff --git a/main.py b/main.py index c780019..1eec9dd 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- -import os -from discord.ext import commands -import discord -import signal import asyncio +import os +import signal + +import discord +from discord.ext import commands from loguru import logger -from DynamicCommandPrefix import dynamic_command_prefix + import Observ +from DynamicCommandPrefix import dynamic_command_prefix LOG_FILE_ENABLED = os.getenv('LOG_FILE_ENABLED', 'true').lower() == 'true' @@ -68,6 +70,7 @@ async def main(): await discord_client.start(os.environ['DISCORD_TOKEN']) + if __name__ == '__main__': loop = asyncio.new_event_loop() loop.run_until_complete(main()) diff --git a/requirements.txt b/requirements.txt index f20eab3..6c2c252 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,4 @@ peewee==3.17.5 PyNaCl==1.5.0 git+https://github.com/aaiyer/rfoo.git@1555bd4eed204bb6a33a5e313146a6c2813cfe91 # Cython is setup dependency for rfoo ---index-url https://download.pytorch.org/whl/cpu -torch==2.3.0+cpu -numpy==1.26.4 \ No newline at end of file +pydantic \ No newline at end of file diff --git a/utils.py b/utils.py index 201c492..d9f6e0a 100644 --- a/utils.py +++ b/utils.py @@ -1,10 +1,4 @@ import winsound -import io - - -def save_bytes(filename: str, bytes_audio: bytes) -> None: - with open(file=filename, mode='wb') as res_file: - res_file.write(bytes_audio) def play_bytes(bytes_sound: bytes) -> None: