commit d3eaa8bb70918c3cc78aee4050e2532b2710b3c9
Author: norohind <60548839+norohind@users.noreply.github.com>
Date:   Mon Mar 14 21:55:46 2022 +0300

    init

diff --git a/Cache.py b/Cache.py
new file mode 100644
index 0000000..4e4571c
--- /dev/null
+++ b/Cache.py
@@ -0,0 +1,34 @@
+import io
+import sqlite3
+from typing import Optional
+
+
+class Cache:
+    """Persistent key -> bytes cache stored in a local SQLite database file."""
+
+    SCHEMA = "create table if not exists cache (key text primary key, value blob);"
+    # `insert or replace` keeps set() idempotent: a plain `insert` raises
+    # sqlite3.IntegrityError when the primary key is already present.
+    SET = "insert or replace into cache (key, value) values (:key, :value);"
+    GET = "select value from cache where key = :key;"
+
+    def __init__(self):
+        self.connection = sqlite3.connect('voice_cache.sqlite')
+        self.connection.execute(self.SCHEMA)
+
+    def set(self, key: str, value: bytes) -> None:
+        """Store `value` under `key`, overwriting any existing entry."""
+        # Used as a context manager the connection commits on success and rolls back on error.
+        with self.connection:
+            self.connection.execute(self.SET, {'key': key, 'value': value})
+
+    def get(self, key: str) -> Optional[io.BytesIO]:
+        """Return the cached bytes wrapped in a BytesIO, or None when `key` is absent."""
+        res = self.connection.execute(self.GET, {'key': key}).fetchone()
+        if res is None:
+            return None
+
+        return io.BytesIO(res[0])
+
+
+cache = Cache()
diff --git a/FFmpegPCMAudioModified.py b/FFmpegPCMAudioModified.py
new file mode 100644
index 0000000..1c00475
--- /dev/null
+++ b/FFmpegPCMAudioModified.py
@@ -0,0 +1,57 @@
+import subprocess
+import shlex
+import io
+from discord.opus import Encoder
+import discord
+
+# Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571
+
+
+class FFmpegPCMAudio(discord.AudioSource):
+    """Audio source that runs ffmpeg to completion and serves its PCM output from memory.
+
+    With pipe=True `source` is the input data written to ffmpeg's stdin;
+    otherwise `source` is handed to ffmpeg as the `-i` argument.
+    """
+
+    def __init__(self, source, *, executable='ffmpeg', pipe=False, stderr=None, before_options=None, options=None):
+        stdin = source if pipe else None
+        args = [executable]
+        if isinstance(before_options, str):
+            args.extend(shlex.split(before_options))
+        args.append('-i')
+        args.append('-' if pipe else source)
+        args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning'))
+        if isinstance(options, str):
+            args.extend(shlex.split(options))
+        args.append('pipe:1')
+        self._process = None
+        try:
+            self._process = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=stderr)
+            # communicate() feeds stdin, closes it and waits for ffmpeg to exit,
+            # so the complete converted stream is buffered before read() is called.
+            self._stdout = io.BytesIO(
+                self._process.communicate(input=stdin)[0]
+            )
+        except FileNotFoundError:
+            raise discord.ClientException(executable + ' was not found.') from None
+        except subprocess.SubprocessError as exc:
+            raise discord.ClientException('Popen failed: {0.__class__.__name__}: {0}'.format(exc)) from exc
+
+    def read(self):
+        """Return the next Encoder.FRAME_SIZE bytes, or b'' once less than a full frame is left."""
+        ret = self._stdout.read(Encoder.FRAME_SIZE)
+        if len(ret) != Encoder.FRAME_SIZE:
+            return b''
+        return ret
+
+    def cleanup(self):
+        """Kill and reap the ffmpeg process, then drop the reference to it."""
+        proc = self._process
+        if proc is None:
+            return
+        proc.kill()
+        if proc.poll() is None:
+            proc.communicate()
+
+        self._process = None
diff --git a/TTSSilero/Speakers.py b/TTSSilero/Speakers.py
new file mode 100644
index 0000000..1bfb4a2
--- /dev/null
+++ b/TTSSilero/Speakers.py
@@ -0,0 +1,17 @@
+import enum
+
+
+class Speakers(enum.Enum):
+    """Voice names that can be passed to TTS.synthesize_text as `speaker`."""
+
+    aidar = 'aidar'
+    baya = 'baya'
+    kseniya = 'kseniya'
+    irina = 'irina'
+    ruslan = 'ruslan'
+    natasha = 'natasha'
+    thorsten = 'thorsten'
+    tux = 'tux'
+    gilles = 'gilles'
+    lj = 'lj'
+    dilyara = 'dilyara'
diff --git a/TTSSilero/TTSSilero.py b/TTSSilero/TTSSilero.py
new file mode 100644
index 0000000..9a35238
--- /dev/null
+++ b/TTSSilero/TTSSilero.py
@@ -0,0 +1,80 @@
+import contextlib
+import os
+import io
+import wave
+from typing import Optional
+
+import torch.package
+
+from .Speakers import Speakers
+from .multi_v2_package import TTSModelMulti_v2
+
+
+class TTS:
+    """Wrapper around the Silero multi v2 model that renders text to in-memory WAV data."""
+
+    def __init__(self, threads: int = 12):
+        device = torch.device('cpu')
+        torch.set_num_threads(threads)
+        local_file = 'model_multi.pt'
+
+        # Fetch the model package only when it is not already on disk.
+        if not os.path.isfile(local_file):
+            torch.hub.download_url_to_file(
+                'https://models.silero.ai/models/tts/multi/v2_multi.pt',
+                local_file
+            )
+
+        self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
+        self.model.to(device)
+
+        self.sample_rate = 16000
+
+    def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya,
+                        seek: Optional[int] = None) -> io.BytesIO:
+        """
+        Synthesizes text and returns it as a WAV stream
+
+        :param text: text to synthesize
+        :param speaker: voice to use
+        :param seek: position to seek the resulting stream to, None to leave it untouched
+        :return: in-memory WAV data
+        """
+        return self.to_wav(self._synthesize_text(text, speaker), seek)
+
+    def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
+        """
+        Synthesizes text as a single utterance, no splitting is performed
+
+        :param text: text to synthesize
+        :param speaker: voice to use
+        :return: audio tensors as returned by the model's apply_tts
+        """
+        return self.model.apply_tts(
+            texts=[text],
+            speakers=speaker.value,
+            sample_rate=self.sample_rate
+        )
+
+    def to_wav(self, synthesized_text: list[torch.Tensor], seek: Optional[int] = None) -> io.BytesIO:
+        """
+        Encodes audio tensors as a single mono 16-bit WAV stream at self.sample_rate
+
+        :param synthesized_text: audio tensors, written one after another
+        :param seek: position to seek the resulting stream to, None to leave it untouched
+        :return: in-memory WAV data
+        """
+        res_io_stream = io.BytesIO()
+
+        with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
+            wf.setnchannels(1)
+            wf.setsampwidth(2)
+            wf.setframerate(self.sample_rate)
+            for result in synthesized_text:
+                # Scale to the int16 range (assumes samples within [-1, 1] - TODO confirm).
+                wf.writeframes((result * 32767).numpy().astype('int16'))
+
+        if isinstance(seek, int):
+            res_io_stream.seek(seek)
+
+        return res_io_stream
diff --git a/TTSSilero/__init__.py b/TTSSilero/__init__.py
new file mode 100644
index 0000000..216a167
--- /dev/null
+++ b/TTSSilero/__init__.py
@@ -0,0 +1,2 @@
+from .TTSSilero import TTS
+from .Speakers import Speakers
diff --git a/TTSSilero/multi_v2_package.py b/TTSSilero/multi_v2_package.py
new file mode 100644
index 0000000..5cbe4f7
--- /dev/null
+++ b/TTSSilero/multi_v2_package.py
@@ -0,0 +1,156 @@
+import re
+import wave
+import torch
+import warnings
+import contextlib
+from typing import Union
+
+# for type hints only
+
+
+class TTSModelMulti_v2():
+    def __init__(self, model_path, symbols):
+        self.model = self.init_jit_model(model_path)
+        self.symbols = symbols
+        self.device = torch.device('cpu')
+        speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
+                    'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
+        self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}
+
+    def init_jit_model(self, model_path: str):
+        torch.set_grad_enabled(False)
+        model = torch.jit.load(model_path, map_location='cpu')
+        model.eval()
+        return model
+
+    def prepare_text_input(self, text, symbols, symbol_to_id=None):
+        """Normalize `text` and encode it as a LongTensor of symbol ids."""
+        if len(text) > 140:
+            warnings.warn('Text string is longer than 140 symbols.')
+
+        if symbol_to_id is None:
+            symbol_to_id = {s: i for i, s in enumerate(symbols)}
+
+        text = text.lower()
+        text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
+        text = re.sub(r'\s+', ' ', text).strip()
+        # `not text` avoids an IndexError when nothing is left after the filtering above.
+        if not text or text[-1] not in ['.', '!', '?']:
+            text = text + '.'
+        text = text + symbols[1]
+
+        text_ohe = [symbol_to_id[s] for s in text if s in symbols]
+        text_tensor = torch.LongTensor(text_ohe)
+        return text_tensor
+
+    def prepare_tts_model_input(self, text: Union[str, list], symbols: str, speakers: list):
+        """Build the (texts, speaker ids, sort order) tensors passed to the model."""
+        # Wrap a bare string first: len() of a str counts characters, not texts,
+        # which made the check below meaningless for string input.
+        if isinstance(text, str):
+            text = [text]
+        assert len(speakers) == len(text) or len(speakers) == 1
+        symbol_to_id = {s: i for i, s in enumerate(symbols)}
+        if len(text) == 1:
+            return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])
+
+        text_tensors = []
+        for string in text:
+            string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
+            text_tensors.append(string_tensor)
+        input_lengths, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([len(t) for t in text_tensors]),
+            dim=0, descending=True)
+        max_input_len = input_lengths[0]
+        batch_size = len(text_tensors)
+
+        text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
+        if len(speakers) == 1:
+            speakers = speakers*batch_size
+        speaker_ids = torch.LongTensor(batch_size).zero_()
+
+        for i, idx in enumerate(ids_sorted_decreasing):
+            text_tensor = text_tensors[idx]
+            in_len = text_tensor.size(0)
+            text_padded[i, :in_len] = text_tensor
+            speaker_ids[i] = speakers[idx]
+
+        return text_padded, speaker_ids, ids_sorted_decreasing
+
+    def process_tts_model_output(self, out, out_lens, ids):
+        """Restore the original text order and cut each output to its own length."""
+        out = out.to('cpu')
+        out_lens = out_lens.to('cpu')
+        _, orig_ids = ids.sort()
+
+        proc_outs = []
+        orig_out = out.index_select(0, orig_ids)
+        orig_out_lens = out_lens.index_select(0, orig_ids)
+
+        for i, out_len in enumerate(orig_out_lens):
+            proc_outs.append(orig_out[i][:out_len])
+        return proc_outs
+
+    def to(self, device):
+        self.model = self.model.to(device)
+        self.device = device
+
+    def get_speakers(self, speakers: Union[str, list]):
+        """Map speaker name(s) to model speaker ids, raising ValueError on an unknown name."""
+        if isinstance(speakers, str):
+            speakers = [speakers]
+        speaker_ids = []
+        for speaker in speakers:
+            try:
+                speaker_id = self.speaker_to_id[speaker]
+                speaker_ids.append(speaker_id)
+            except (KeyError, TypeError):
+                raise ValueError(f'No such speaker: {speaker}') from None
+        return speaker_ids
+
+    def apply_tts(self, texts: Union[str, list],
+                  speakers: Union[str, list],
+                  sample_rate: int = 16000):
+        speaker_ids = self.get_speakers(speakers)
+        text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
+                                                                          symbols=self.symbols,
+                                                                          speakers=speaker_ids)
+        with torch.inference_mode():
+            out, out_lens = self.model(text_padded.to(self.device),
+                                       speaker_ids.to(self.device),
+                                       sr=sample_rate)
+        audios = self.process_tts_model_output(out, out_lens, orig_ids)
+        return audios
+
+    @staticmethod
+    def write_wave(path, audio, sample_rate):
+        """Writes a .wav file.
+        Takes path, PCM audio data, and sample rate.
+        """
+        with contextlib.closing(wave.open(path, 'wb')) as wf:
+            wf.setnchannels(1)
+            wf.setsampwidth(2)
+            wf.setframerate(sample_rate)
+            wf.writeframes(audio)
+
+    def save_wav(self, texts: Union[str, list],
+                 speakers: Union[str, list],
+                 audio_pathes: Union[str, list] = '',
+                 sample_rate: int = 16000):
+        if isinstance(texts, str):
+            texts = [texts]
+
+        if not audio_pathes:
+            audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
+        if isinstance(audio_pathes, str):
+            audio_pathes = [audio_pathes]
+        assert len(audio_pathes) == len(texts)
+
+        audio = self.apply_tts(texts=texts,
+                               speakers=speakers,
+                               sample_rate=sample_rate)
+        for i, _audio in enumerate(audio):
+            self.write_wave(path=audio_pathes[i],
+                            audio=(_audio * 32767).numpy().astype('int16'),
+                            sample_rate=sample_rate)
+        return audio_pathes
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..32baaa7
--- /dev/null
+++ b/main.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+import os
+import time
+
+import discord
+import signal
+import asyncio
+from loguru import logger
+
+import utils
+from TTSSilero import TTS
+from TTSSilero import Speakers
+from FFmpegPCMAudioModified import FFmpegPCMAudio
+from Cache import cache
+
+tts = TTS()
+"""
+while msg := input('$ '):
+    start = time.time()
+    audio = tts.synthesize_text(msg, speaker=Speakers.kseniya)
+    print('synthesize took ', str(time.time() - start))
+    utils.play_bytes_io(audio)
+"""
+
+
+class DiscordTTSBot(discord.Client):
+    """Discord client that speaks '-'-prefixed text messages in a voice channel."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        signal.signal(signal.SIGTERM, self.shutdown)
+        signal.signal(signal.SIGINT, self.shutdown)
+        logger.info('Shutdown callbacks registered')
+
+    def shutdown(self, sig, frame):
+        """Signal handler: log the signal and schedule closing of the client."""
+        logger.info(f'Shutting down by signal: {sig}')
+        # NOTE(review): create_task needs a running event loop - confirm signals only arrive while it runs.
+        asyncio.create_task(self.close())
+
+    async def on_ready(self):
+        logger.debug('Bot is ready')
+
+    async def on_message(self, message: discord.Message):
+        """Handle '/exit' (leave the voice channel) and '-<text>' (speak the message)."""
+        if message.author == self.user:
+            return
+
+        if not isinstance(message.channel, discord.TextChannel):
+            return
+
+        # '/exit' has to be checked before the '-' prefix filter below: it does not start
+        # with '-', so behind the filter this branch could never be reached.
+        if message.content.startswith('/exit'):
+            logger.info(f'Message: {message.content}')
+            await self._leave_voice_channel(message)
+            return
+
+        if not message.content.startswith('-'):
+            return
+
+        logger.info(f'Message: {message.content}')
+        user_voice_state = message.author.voice
+        if user_voice_state is None:
+            await message.channel.send("You're not in a voice channel")
+            return
+
+        # noinspection PyTypeChecker
+        voice_client: discord.VoiceClient = message.guild.voice_client
+        if voice_client is None:
+            voice_client = await user_voice_state.channel.connect()
+
+        sound_source = FFmpegPCMAudio(self._get_wav_bytes(message.content), pipe=True)
+        voice_client.play(sound_source, after=lambda e: logger.debug(f"Player done, {e=}"))
+
+    @staticmethod
+    async def _leave_voice_channel(message: discord.Message) -> None:
+        """Disconnect from the guild's voice channel, or say that there is none."""
+        voice_client = message.guild.voice_client
+        if voice_client is None:
+            await message.channel.send("I'm not in any voice channel")
+            return
+
+        logger.debug('Disconnecting from voice channel')
+        await voice_client.disconnect(force=False)
+        await message.channel.send("Left voice channel")
+
+    @staticmethod
+    def _get_wav_bytes(text: str) -> bytes:
+        """Return WAV data for `text` from the cache, synthesizing and caching it on a miss."""
+        cached = cache.get(text)
+        if cached is not None:
+            logger.debug(f'Cache lookup for {text!r} successful')
+            return cached.getvalue()
+
+        synthesis_start = time.time()
+        wav_bytes = tts.synthesize_text(text, seek=0).getvalue()
+        logger.debug(f'Synthesis took {time.time() - synthesis_start} s')
+        cache.set(text, wav_bytes)
+        logger.debug(f'Set cache for {text!r}')
+        return wav_bytes
+
+
+intents = discord.Intents.default()
+intents.message_content = True
+
+discord_client = DiscordTTSBot(intents=intents)
+
+loop = asyncio.new_event_loop()
+loop.run_until_complete(discord_client.start(os.environ['DISCORD_TOKEN']))
+logger.debug('Shutdown completed')
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..4925f43
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,18 @@
+import io
+
+
+def save_bytes_io(filename: str, bytes_stream: io.BytesIO) -> None:
+    """Write the whole content of `bytes_stream` to the file `filename`."""
+    with open(file=filename, mode='wb') as res_file:
+        bytes_stream.seek(0)
+        res_file.write(bytes_stream.read())
+
+
+def play_bytes_io(bytes_stream: io.BytesIO) -> None:
+    """Play the WAV data held in `bytes_stream` through winsound (Windows only)."""
+    # winsound exists only on Windows; importing it lazily keeps this module, and main.py
+    # which imports it, importable on other platforms.
+    import winsound
+
+    bytes_stream.seek(0)
+    winsound.PlaySound(bytes_stream.read(), winsound.SND_MEMORY)