This commit is contained in:
norohind 2022-03-14 21:55:46 +03:00
commit d3eaa8bb70
Signed by: norohind
GPG Key ID: 01C3BECC26FB59E1
8 changed files with 411 additions and 0 deletions

27
Cache.py Normal file
View File

@ -0,0 +1,27 @@
import io
import sqlite3
class Cache:
    """Tiny persistent key/value cache (used for synthesized voice data), backed by SQLite."""

    SCHEMA = "create table if not exists cache (key text primary key, value blob);"
    # "insert or replace": overwriting an existing key must not raise
    # sqlite3.IntegrityError on the primary-key constraint.
    SET = "insert or replace into cache (key, value) values (:key, :value);"
    GET = "select value from cache where key = :key;"

    def __init__(self, db_path: str = 'voice_cache.sqlite'):
        """Open (or create) the cache database at *db_path* and ensure the table exists.

        :param db_path: SQLite database file; defaults to the original hard-coded name
                        so existing callers are unaffected.
        """
        self.connection = sqlite3.connect(db_path)
        self.connection.execute(self.SCHEMA)

    def set(self, key: str, value: bytes) -> None:
        """Store *value* under *key*, overwriting any previous entry.

        The ``with self.connection`` block commits the transaction on success.
        """
        with self.connection:
            self.connection.execute(self.SET, {'key': key, 'value': value})

    def get(self, key: str):
        """Return the cached blob for *key* wrapped in io.BytesIO, or None on a miss."""
        res = self.connection.execute(self.GET, {'key': key}).fetchone()
        if res is None:
            return None
        else:
            return io.BytesIO(res[0])
cache = Cache()

47
FFmpegPCMAudioModified.py Normal file
View File

@ -0,0 +1,47 @@
import subprocess
import shlex
import io
from discord.opus import Encoder
import discord
# Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571
class FFmpegPCMAudio(discord.AudioSource):
    """AudioSource that decodes a whole payload through ffmpeg eagerly.

    Unlike discord.py's stock FFmpegPCMAudio, this variant runs ffmpeg to
    completion inside __init__ (via communicate) and buffers all decoded PCM
    in memory, so the source can be fed from an in-memory bytes object.
    """

    def __init__(self, source, *, executable='ffmpeg', pipe=False, stderr=None, before_options=None, options=None):
        # When pipe=True, *source* is raw bytes fed to ffmpeg's stdin;
        # otherwise it is treated as a filename/URL passed after '-i'.
        stdin = None if not pipe else source
        args = [executable]
        if isinstance(before_options, str):
            args.extend(shlex.split(before_options))
        args.append('-i')
        args.append('-' if pipe else source)
        # Output format discord voice expects: signed 16-bit LE PCM, 48 kHz, stereo.
        args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning'))
        if isinstance(options, str):
            args.extend(shlex.split(options))
        args.append('pipe:1')  # decoded PCM goes to ffmpeg's stdout
        self._process = None
        try:
            self._process = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=stderr)
            # communicate() waits for ffmpeg to exit and returns all of its
            # stdout, so the entire decoded track is buffered in memory here.
            self._stdout = io.BytesIO(
                self._process.communicate(input=stdin)[0]
            )
        except FileNotFoundError:
            raise discord.ClientException(executable + ' was not found.') from None
        except subprocess.SubprocessError as exc:
            raise discord.ClientException('Popen failed: {0.__class__.__name__}: {0}'.format(exc)) from exc

    def read(self):
        """Return one frame of PCM (Encoder.FRAME_SIZE bytes), or b'' at end of stream."""
        ret = self._stdout.read(Encoder.FRAME_SIZE)
        if len(ret) != Encoder.FRAME_SIZE:
            # A short read means the buffer is exhausted; any trailing partial
            # frame is dropped, which ends playback.
            return b''
        return ret

    def cleanup(self):
        """Kill the ffmpeg process if one is still tracked and drop the handle."""
        proc = self._process
        if proc is None:
            return
        proc.kill()
        if proc.poll() is None:
            proc.communicate()
        self._process = None

15
TTSSilero/Speakers.py Normal file
View File

@ -0,0 +1,15 @@
import enum
# All voice identifiers supported by the Silero multi-speaker model.
# Built with the functional Enum API; each member's value equals its name,
# matching the original class-based declaration exactly.
Speakers = enum.Enum(
    'Speakers',
    [(voice, voice) for voice in (
        'aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
        'thorsten', 'tux', 'gilles', 'lj', 'dilyara',
    )],
)

62
TTSSilero/TTSSilero.py Normal file
View File

@ -0,0 +1,62 @@
import contextlib
import os
import io
import wave
import torch.package
from .Speakers import Speakers
from .multi_v2_package import TTSModelMulti_v2
class TTS:
    """Thin wrapper around the Silero multi-speaker TTS model.

    Loads the model on CPU (downloading it on first run) and synthesizes
    16 kHz mono 16-bit WAV audio into in-memory streams.
    """

    def __init__(self, threads: int = 12):
        """Load the Silero model, downloading the package file first if absent.

        :param threads: number of CPU threads torch may use for inference
        """
        device = torch.device('cpu')
        torch.set_num_threads(threads)
        local_file = 'model_multi.pt'
        if not os.path.isfile(local_file):
            torch.hub.download_url_to_file(
                'https://models.silero.ai/models/tts/multi/v2_multi.pt',
                local_file
            )
        self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
        self.model.to(device)
        # Hz; must match the frame rate written into the wav header by to_wav().
        self.sample_rate = 16000

    def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya, seek: int = None) -> io.BytesIO:
        """Synthesize *text* with *speaker* and return an in-memory WAV stream.

        :param text: text to speak
        :param speaker: voice to use
        :param seek: optional position to seek the returned stream to
                     (pass 0 to get a stream ready for reading)
        :return: io.BytesIO containing a complete WAV file
        """
        return self.to_wav(self._synthesize_text(text, speaker), seek)

    def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
        """Run the model on *text* and return raw float waveform tensors."""
        results_list: list[torch.Tensor] = self.model.apply_tts(
            texts=[text],
            speakers=speaker.value,
            sample_rate=self.sample_rate
        )
        return results_list

    def to_wav(self, synthesized_text: list[torch.Tensor], seek: int = None) -> io.BytesIO:
        """Encode waveform tensors into a mono 16-bit WAV stream.

        :param synthesized_text: waveform tensors as produced by _synthesize_text
        :param seek: optional position to seek the stream to before returning
        :return: io.BytesIO containing the WAV data
        """
        res_io_stream = io.BytesIO()
        with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
            wf.setnchannels(1)   # mono
            wf.setsampwidth(2)   # 16-bit samples
            wf.setframerate(self.sample_rate)
            for result in synthesized_text:
                # Scale float samples (assumed in [-1, 1]) to int16 PCM.
                wf.writeframes((result * 32767).numpy().astype('int16'))
        # Original used the fragile `type(seek) is int`; any non-None offset
        # callers pass should be honored.
        if seek is not None:
            res_io_stream.seek(seek)
        return res_io_stream

2
TTSSilero/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .TTSSilero import TTS
from .Speakers import Speakers

View File

@ -0,0 +1,148 @@
import re
import wave
import torch
import warnings
import contextlib
# for type hints only
class TTSModelMulti_v2():
    """Local copy of the Silero multi-speaker v2 TTS wrapper class.

    Mirrors the class shipped inside the torch.package model archive so other
    modules can reference it (e.g. for type hints) without importing the package.
    """

    def __init__(self, model_path, symbols):
        # *symbols* is the model's character vocabulary string; symbols[1] is
        # the end-of-text marker and symbols[2:] are the speakable characters
        # (see prepare_text_input).
        self.model = self.init_jit_model(model_path)
        self.symbols = symbols
        self.device = torch.device('cpu')
        speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
                    'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
        self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}

    def init_jit_model(self, model_path: str):
        """Load the TorchScript model from *model_path* in eval mode.

        NOTE: disables autograd globally, not just for this model.
        """
        torch.set_grad_enabled(False)
        model = torch.jit.load(model_path, map_location='cpu')
        model.eval()
        return model

    def prepare_text_input(self, text, symbols, symbol_to_id=None):
        """Normalize *text* and encode it as a LongTensor of symbol ids.

        Lowercases, strips characters outside the vocabulary, collapses
        whitespace, guarantees terminal punctuation and appends the
        end-of-text symbol (symbols[1]).
        """
        if len(text) > 140:
            warnings.warn('Text string is longer than 140 symbols.')
        if symbol_to_id is None:
            symbol_to_id = {s: i for i, s in enumerate(symbols)}
        text = text.lower()
        # Drop any character not in the speakable part of the vocabulary.
        text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        # NOTE(review): raises IndexError if the text normalizes to empty —
        # confirm callers never pass vocabulary-free input.
        if text[-1] not in ['.', '!', '?']:
            text = text + '.'
        text = text + symbols[1]  # end-of-text marker
        text_ohe = [symbol_to_id[s] for s in text if s in symbols]
        text_tensor = torch.LongTensor(text_ohe)
        return text_tensor

    def prepare_tts_model_input(self, text: str or list, symbols: str, speakers: list):
        """Batch-encode *text* (and speaker ids) for the model.

        Returns (text_padded, speaker_ids, ids) where *ids* records the
        length-descending order the batch was sorted into, so outputs can be
        restored to input order by process_tts_model_output.
        """
        assert len(speakers) == len(text) or len(speakers) == 1
        if type(text) == str:
            text = [text]
        symbol_to_id = {s: i for i, s in enumerate(symbols)}
        if len(text) == 1:
            # Single utterance: no padding or sorting needed.
            return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])
        text_tensors = []
        for string in text:
            string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
            text_tensors.append(string_tensor)
        # Sort utterances by length, longest first.
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(t) for t in text_tensors]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]
        batch_size = len(text_tensors)
        # Pad with ones — presumably id 1 (the end-of-text symbol); confirm
        # against the model's expected pad value.
        text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
        if len(speakers) == 1:
            speakers = speakers*batch_size  # broadcast single speaker over the batch
        speaker_ids = torch.LongTensor(batch_size).zero_()
        for i, idx in enumerate(ids_sorted_decreasing):
            text_tensor = text_tensors[idx]
            in_len = text_tensor.size(0)
            text_padded[i, :in_len] = text_tensor
            speaker_ids[i] = speakers[idx]
        return text_padded, speaker_ids, ids_sorted_decreasing

    def process_tts_model_output(self, out, out_lens, ids):
        """Trim padded model output to real lengths and restore original order."""
        out = out.to('cpu')
        out_lens = out_lens.to('cpu')
        # *ids* maps sorted position -> original index; sorting it yields the
        # inverse permutation in orig_ids.
        _, orig_ids = ids.sort()
        proc_outs = []
        orig_out = out.index_select(0, orig_ids)
        orig_out_lens = out_lens.index_select(0, orig_ids)
        for i, out_len in enumerate(orig_out_lens):
            proc_outs.append(orig_out[i][:out_len])
        return proc_outs

    def to(self, device):
        """Move the underlying jit model to *device* and record it."""
        self.model = self.model.to(device)
        self.device = device

    def get_speakers(self, speakers: str or list):
        """Map speaker name(s) to model speaker ids.

        :raises ValueError: if a name is not in speaker_to_id
        """
        if type(speakers) == str:
            speakers = [speakers]
        speaker_ids = []
        for speaker in speakers:
            try:
                speaker_id = self.speaker_to_id[speaker]
                speaker_ids.append(speaker_id)
            except Exception:
                raise ValueError(f'No such speaker: {speaker}')
        return speaker_ids

    def apply_tts(self, texts: str or list,
                  speakers: str or list,
                  sample_rate: int = 16000):
        """Synthesize *texts* and return a list of waveform tensors in input order."""
        speaker_ids = self.get_speakers(speakers)
        text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
                                                                          symbols=self.symbols,
                                                                          speakers=speaker_ids)
        with torch.inference_mode():
            out, out_lens = self.model(text_padded.to(self.device),
                                       speaker_ids.to(self.device),
                                       sr=sample_rate)
        audios = self.process_tts_model_output(out, out_lens, orig_ids)
        return audios

    @staticmethod
    def write_wave(path, audio, sample_rate):
        """Writes a .wav file.

        Takes path, PCM audio data (int16 bytes/array), and sample rate;
        output is mono 16-bit.
        """
        with contextlib.closing(wave.open(path, 'wb')) as wf:
            wf.setnchannels(1)   # mono
            wf.setsampwidth(2)   # 16-bit PCM
            wf.setframerate(sample_rate)
            wf.writeframes(audio)

    def save_wav(self, texts: str or list,
                 speakers: str or list,
                 audio_pathes: str or list = '',
                 sample_rate: int = 16000):
        """Synthesize *texts* and write one wav file per text.

        :param audio_pathes: target path(s); auto-generated test_NNN.wav names
                             are used when empty
        :return: the list of paths written
        """
        if type(texts) == str:
            texts = [texts]
        if not audio_pathes:
            audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
        if type(audio_pathes) == str:
            audio_pathes = [audio_pathes]
        assert len(audio_pathes) == len(texts)
        audio = self.apply_tts(texts=texts,
                               speakers=speakers,
                               sample_rate=sample_rate)
        for i, _audio in enumerate(audio):
            self.write_wave(path=audio_pathes[i],
                            # scale float [-1, 1] samples to int16 PCM
                            audio=(_audio * 32767).numpy().astype('int16'),
                            sample_rate=sample_rate)
        return audio_pathes

97
main.py Normal file
View File

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
import os
import time
import discord
import signal
import asyncio
from loguru import logger
import utils
from TTSSilero import TTS
from TTSSilero import Speakers
from FFmpegPCMAudioModified import FFmpegPCMAudio
from Cache import cache
# Shared TTS engine; constructing it loads (and, on first run, downloads)
# the Silero model, so it is created once at import time.
# (Removed: a dead module-level string literal containing an old interactive
# smoke-test loop — it was a runtime no-op referencing the Windows-only
# utils.play_bytes_io helper.)
tts = TTS()
class DiscordTTSBot(discord.Client):
    """Discord client that speaks text messages in the author's voice channel.

    Any text-channel message starting with '-' is synthesized (with an SQLite
    cache in front of the model) and played; '-exit' makes the bot leave the
    voice channel.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Make SIGTERM/SIGINT (Ctrl+C) close the discord connection cleanly.
        signal.signal(signal.SIGTERM, self.shutdown)
        signal.signal(signal.SIGINT, self.shutdown)
        logger.info('Shutdown callbacks registered')

    def shutdown(self, sig, frame):
        """Signal handler: schedule a graceful client shutdown on the running loop."""
        logger.info(f'Shutting down by signal: {sig}')
        asyncio.create_task(self.close())

    async def on_ready(self):
        logger.debug('Bot is ready')

    async def on_message(self, message: discord.Message):
        # Ignore our own messages and anything without the command prefix.
        if message.author == self.user:
            return
        if not message.content.startswith('-'):
            return
        if isinstance(message.channel, discord.TextChannel):
            logger.info(f'Message: {message.content}')
            user_voice_state = message.author.voice
            # BUG FIX: the original tested startswith('/exit'), which was
            # unreachable behind the '-' prefix guard above; accept '-exit'
            # (and keep '/exit' for compatibility, though it cannot match).
            if message.content.startswith(('-exit', '/exit')):
                if message.guild.voice_client is not None:
                    logger.debug('Disconnecting from voice channel')
                    await message.guild.voice_client.disconnect(force=False)
                    await message.channel.send("Left voice channel")
                    return
                else:
                    await message.channel.send("I'm not in any voice channel")
                    return

            if user_voice_state is None:
                await message.channel.send("You're not in a voice channel")
                return

            # Reuse the guild's existing voice connection, or join the author's channel.
            # noinspection PyTypeChecker
            voice_client: discord.VoiceClient = message.guild.voice_client
            if voice_client is None:
                voice_client: discord.VoiceClient = await user_voice_state.channel.connect()

            # Cache lookup keyed on the raw message text.
            cached = cache.get(message.content)
            if cached is not None:
                wav_file_like_object = cached
                logger.debug(f'Cache lookup for {message.content!r} successful')
            else:
                synthesis_start = time.time()
                wav_file_like_object = tts.synthesize_text(message.content, seek=0)
                logger.debug(f'Synthesis took {time.time() - synthesis_start} s')
                cache.set(message.content, wav_file_like_object.read())
                logger.debug(f'Set cache for {message.content!r}')
                wav_file_like_object.seek(0)  # rewind after the read() above

            sound_source = FFmpegPCMAudio(
                wav_file_like_object.read(),
                pipe=True
            )
            voice_client.play(sound_source, after=lambda e: logger.debug(f"Player done, {e=}"))
# Wire up intents (message_content is required to read command text) and run
# the client until it is closed by a shutdown signal.
intents = discord.Intents.default()
intents.message_content = True

discord_client = DiscordTTSBot(intents=intents)

# Original created a loop but never registered it as current nor closed it;
# set it so libraries relying on the "current" loop see ours, and always
# close it on exit.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
    loop.run_until_complete(discord_client.start(os.environ['DISCORD_TOKEN']))
finally:
    loop.close()
logger.debug('Shutdown completed')

13
utils.py Normal file
View File

@ -0,0 +1,13 @@
import winsound
import io
def save_bytes_io(filename: str, bytes_stream: io.BytesIO) -> None:
    """Dump the full contents of *bytes_stream* to *filename* (binary mode).

    The stream is rewound first, so the whole buffer is written regardless of
    its current position; afterwards the position is left at the end.
    """
    bytes_stream.seek(0)
    payload = bytes_stream.read()
    with open(filename, mode='wb') as out:
        out.write(payload)
def play_bytes_io(bytes_stream: io.BytesIO) -> None:
    """Play an in-memory WAV stream through the speakers.

    Windows-only: relies on the winsound stdlib module (SND_MEMORY plays the
    wav data directly from the bytes passed in). Rewinds the stream first.
    """
    bytes_stream.seek(0)
    winsound.PlaySound(bytes_stream.read(), winsound.SND_MEMORY)