init
This commit is contained in:
commit
d3eaa8bb70
27
Cache.py
Normal file
27
Cache.py
Normal file
@ -0,0 +1,27 @@
|
||||
import io
|
||||
import sqlite3
|
||||
|
||||
|
||||
class Cache:
    """Tiny persistent key/value cache backed by SQLite; values are raw BLOBs."""

    SCHEMA = "create table if not exists cache (key text primary key, value blob);"
    # "insert or replace" makes set() idempotent: re-caching an existing key
    # overwrites it instead of raising sqlite3.IntegrityError on the primary key.
    SET = "insert or replace into cache (key, value) values (:key, :value);"
    GET = "select value from cache where key = :key;"

    def __init__(self, db_path: str = 'voice_cache.sqlite'):
        """Open (or create) the database and ensure the cache table exists.

        :param db_path: SQLite file to use; defaults to the historical
            'voice_cache.sqlite' so existing callers are unaffected.
        """
        self.connection = sqlite3.connect(db_path)
        self.connection.execute(self.SCHEMA)

    def set(self, key: str, value: bytes) -> None:
        """Store *value* under *key*, overwriting any previous entry."""
        with self.connection:  # commits on success, rolls back on error
            self.connection.execute(self.SET, {'key': key, 'value': value})

    def get(self, key: str):
        """Return the cached value wrapped in an io.BytesIO, or None on a miss."""
        res = self.connection.execute(self.GET, {'key': key}).fetchone()
        if res is None:
            return None
        return io.BytesIO(res[0])
|
||||
|
||||
|
||||
# Module-level singleton shared by importers (used as ``from Cache import cache``).
cache = Cache()
|
47
FFmpegPCMAudioModified.py
Normal file
47
FFmpegPCMAudioModified.py
Normal file
@ -0,0 +1,47 @@
|
||||
import subprocess
|
||||
import shlex
|
||||
import io
|
||||
from discord.opus import Encoder
|
||||
import discord
|
||||
|
||||
# Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571
|
||||
|
||||
|
||||
class FFmpegPCMAudio(discord.AudioSource):
    """AudioSource that runs ffmpeg to completion up front and serves the
    decoded PCM from an in-memory buffer, which lets the input arrive via a
    pipe (bytes) instead of a file path.

    Based on https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571
    """

    def __init__(self, source, *, executable='ffmpeg', pipe=False, stderr=None, before_options=None, options=None):
        stdin = source if pipe else None

        # Assemble the ffmpeg command line: [exe] [before_options] -i <src>
        # <fixed pcm output settings> [options] pipe:1
        cmd = [executable]
        if isinstance(before_options, str):
            cmd += shlex.split(before_options)
        cmd += ['-i', '-' if pipe else source]
        cmd += ['-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning']
        if isinstance(options, str):
            cmd += shlex.split(options)
        cmd += ['pipe:1']

        self._process = None
        try:
            self._process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=stderr)
            # Drain ffmpeg completely; every later read() is served from here.
            pcm_bytes, _ = self._process.communicate(input=stdin)
            self._stdout = io.BytesIO(pcm_bytes)
        except FileNotFoundError:
            raise discord.ClientException(executable + ' was not found.') from None
        except subprocess.SubprocessError as exc:
            raise discord.ClientException('Popen failed: {0.__class__.__name__}: {0}'.format(exc)) from exc

    def read(self):
        """Return one full frame of PCM, or b'' once the buffer is exhausted."""
        frame = self._stdout.read(Encoder.FRAME_SIZE)
        return frame if len(frame) == Encoder.FRAME_SIZE else b''

    def cleanup(self):
        """Kill and reap the ffmpeg process, if one is still attached."""
        process, self._process = self._process, None
        if process is None:
            return
        process.kill()
        if process.poll() is None:
            process.communicate()
|
15
TTSSilero/Speakers.py
Normal file
15
TTSSilero/Speakers.py
Normal file
@ -0,0 +1,15 @@
|
||||
import enum
|
||||
|
||||
|
||||
# Voices recognised by the Silero multi_v2 model. Each member's value is the
# raw speaker-id string handed through to the model, so name == value.
Speakers = enum.Enum(
    'Speakers',
    {name: name for name in (
        'aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
        'thorsten', 'tux', 'gilles', 'lj', 'dilyara',
    )},
)
|
62
TTSSilero/TTSSilero.py
Normal file
62
TTSSilero/TTSSilero.py
Normal file
@ -0,0 +1,62 @@
|
||||
import contextlib
|
||||
import os
|
||||
import io
|
||||
import wave
|
||||
import torch.package
|
||||
|
||||
from .Speakers import Speakers
|
||||
from .multi_v2_package import TTSModelMulti_v2
|
||||
|
||||
|
||||
class TTS:
    """Thin wrapper around the Silero multi_v2 text-to-speech model that
    produces in-memory WAV streams."""

    def __init__(self, threads: int = 12):
        """Load the packaged model (downloading it on first use) onto CPU.

        :param threads: number of intra-op CPU threads torch may use.
        """
        device = torch.device('cpu')
        torch.set_num_threads(threads)
        local_file = 'model_multi.pt'

        # Fetch the packaged model once; subsequent runs reuse the local copy.
        if not os.path.isfile(local_file):
            torch.hub.download_url_to_file(
                'https://models.silero.ai/models/tts/multi/v2_multi.pt',
                local_file
            )

        self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
        self.model.to(device)

        # Output sample rate in Hz, shared by synthesis and WAV encoding.
        self.sample_rate = 16000

    def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya, seek: int = None) -> io.BytesIO:
        """Synthesize *text* and return it as an in-memory WAV stream.

        :param text: text to speak.
        :param speaker: voice to use.
        :param seek: if not None, position the stream there before returning
            (pass 0 so the stream is ready to be read from the start).
        """
        return self.to_wav(self._synthesize_text(text, speaker), seek)

    def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
        """Run the model on *text*; returns one float audio tensor per input string."""
        results_list: list[torch.Tensor] = self.model.apply_tts(
            texts=[text],
            speakers=speaker.value,
            sample_rate=self.sample_rate
        )
        return results_list

    def to_wav(self, synthesized_text: list[torch.Tensor], seek: int = None) -> io.BytesIO:
        """Encode synthesized tensors as a mono 16-bit WAV held in memory.

        :param synthesized_text: float tensors (expected in [-1, 1]) from the model.
        :param seek: optional stream position to seek to before returning.
        :return: BytesIO containing a complete WAV file.
        """
        res_io_stream = io.BytesIO()

        with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)  # 16-bit samples
            wf.setframerate(self.sample_rate)
            for result in synthesized_text:
                # Scale [-1, 1] floats up to int16 PCM.
                wf.writeframes((result * 32767).numpy().astype('int16'))

        # Idiom fix: test for "caller supplied a position" directly instead
        # of the fragile `type(seek) is int` check.
        if seek is not None:
            res_io_stream.seek(seek)

        return res_io_stream
|
||||
|
||||
|
2
TTSSilero/__init__.py
Normal file
2
TTSSilero/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .TTSSilero import TTS
|
||||
from .Speakers import Speakers
|
148
TTSSilero/multi_v2_package.py
Normal file
148
TTSSilero/multi_v2_package.py
Normal file
@ -0,0 +1,148 @@
|
||||
import re
|
||||
import wave
|
||||
import torch
|
||||
import warnings
|
||||
import contextlib
|
||||
|
||||
# for type hints only
|
||||
|
||||
|
||||
class TTSModelMulti_v2():
    """Vendored wrapper around the Silero multi_v2 TorchScript TTS model.

    Loads a JIT-compiled model and turns text into speaker-conditioned audio
    tensors (``apply_tts``) or 16-bit mono WAV files (``save_wav``).
    """

    def __init__(self, model_path, symbols):
        # model_path: path to the TorchScript model file.
        # symbols: alphabet string used to encode input text. prepare_text_input
        #   assumes special symbols occupy indices 0-1 and real characters start
        #   at index 2 — assumption from usage below, TODO confirm upstream.
        self.model = self.init_jit_model(model_path)
        self.symbols = symbols
        self.device = torch.device('cpu')
        speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
                    'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
        # Map each known speaker name to its integer id (its list position).
        self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}

    def init_jit_model(self, model_path: str):
        """Load the TorchScript model on CPU in no-grad, eval mode."""
        torch.set_grad_enabled(False)
        model = torch.jit.load(model_path, map_location='cpu')
        model.eval()
        return model

    def prepare_text_input(self, text, symbols, symbol_to_id=None):
        """Normalize *text* and encode it as a 1-D LongTensor of symbol ids.

        Lowercases, strips characters outside the model alphabet, collapses
        whitespace, guarantees terminal punctuation, then appends the
        end-of-string symbol (symbols[1]).
        """
        if len(text) > 140:
            warnings.warn('Text string is longer than 140 symbols.')

        if symbol_to_id is None:
            symbol_to_id = {s: i for i, s in enumerate(symbols)}

        text = text.lower()
        # Drop every character not in the alphabet (symbols[2:]). NOTE(review):
        # the slice is interpolated into a regex character class unescaped, so
        # regex-special characters in the alphabet would misbehave — verify.
        text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        if text[-1] not in ['.', '!', '?']:
            text = text + '.'
        text = text + symbols[1]  # end-of-sequence marker

        text_ohe = [symbol_to_id[s] for s in text if s in symbols]
        text_tensor = torch.LongTensor(text_ohe)
        return text_tensor

    def prepare_tts_model_input(self, text: str or list, symbols: str, speakers: list):
        """Batch-encode texts and speaker ids for the model.

        Returns ``(text_padded, speaker_ids, ids)`` where *ids* records the
        length-descending sort permutation needed to restore original order.
        """
        assert len(speakers) == len(text) or len(speakers) == 1
        if type(text) == str:
            text = [text]
        symbol_to_id = {s: i for i, s in enumerate(symbols)}
        if len(text) == 1:
            # Single-text fast path: no padding or sorting required.
            return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])

        text_tensors = []
        for string in text:
            string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
            text_tensors.append(string_tensor)
        # Sort by encoded length (descending), remembering the permutation.
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(t) for t in text_tensors]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]
        batch_size = len(text_tensors)

        # Pad every sequence with ones up to the longest in the batch.
        text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
        if len(speakers) == 1:
            # Broadcast a single speaker id across the whole batch.
            speakers = speakers*batch_size
        speaker_ids = torch.LongTensor(batch_size).zero_()

        for i, idx in enumerate(ids_sorted_decreasing):
            text_tensor = text_tensors[idx]
            in_len = text_tensor.size(0)
            text_padded[i, :in_len] = text_tensor
            speaker_ids[i] = speakers[idx]

        return text_padded, speaker_ids, ids_sorted_decreasing

    def process_tts_model_output(self, out, out_lens, ids):
        """Undo the length sort and trim each output to its reported length."""
        out = out.to('cpu')
        out_lens = out_lens.to('cpu')
        # Sorting the permutation yields the inverse permutation.
        _, orig_ids = ids.sort()

        proc_outs = []
        orig_out = out.index_select(0, orig_ids)
        orig_out_lens = out_lens.index_select(0, orig_ids)

        for i, out_len in enumerate(orig_out_lens):
            proc_outs.append(orig_out[i][:out_len])
        return proc_outs

    def to(self, device):
        """Move the underlying model to *device* and remember the choice."""
        self.model = self.model.to(device)
        self.device = device

    def get_speakers(self, speakers: str or list):
        """Translate speaker name(s) to integer ids; raises ValueError on unknown names."""
        if type(speakers) == str:
            speakers = [speakers]
        speaker_ids = []
        for speaker in speakers:
            try:
                speaker_id = self.speaker_to_id[speaker]
                speaker_ids.append(speaker_id)
            except Exception:
                raise ValueError(f'No such speaker: {speaker}')
        return speaker_ids

    def apply_tts(self, texts: str or list,
                  speakers: str or list,
                  sample_rate: int = 16000):
        """Synthesize *texts*; returns a list of audio tensors, one per text.

        :param texts: one string or a batch of strings.
        :param speakers: one speaker name shared by the batch, or one per text.
        :param sample_rate: output rate in Hz, forwarded to the model.
        """
        speaker_ids = self.get_speakers(speakers)
        text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
                                                                          symbols=self.symbols,
                                                                          speakers=speaker_ids)
        with torch.inference_mode():
            out, out_lens = self.model(text_padded.to(self.device),
                                       speaker_ids.to(self.device),
                                       sr=sample_rate)
        audios = self.process_tts_model_output(out, out_lens, orig_ids)
        return audios

    @staticmethod
    def write_wave(path, audio, sample_rate):
        """Writes a .wav file.

        Takes path, PCM audio data (16-bit mono), and sample rate.
        """
        with contextlib.closing(wave.open(path, 'wb')) as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sample_rate)
            wf.writeframes(audio)

    def save_wav(self, texts: str or list,
                 speakers: str or list,
                 audio_pathes: str or list = '',
                 sample_rate: int = 16000):
        """Synthesize *texts* and write each result to a WAV file.

        Paths default to ``test_000.wav``-style names; returns the list of
        paths written.
        """
        if type(texts) == str:
            texts = [texts]

        if not audio_pathes:
            audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
        if type(audio_pathes) == str:
            audio_pathes = [audio_pathes]
        assert len(audio_pathes) == len(texts)

        audio = self.apply_tts(texts=texts,
                               speakers=speakers,
                               sample_rate=sample_rate)
        for i, _audio in enumerate(audio):
            # Scale [-1, 1] floats to int16 PCM before writing.
            self.write_wave(path=audio_pathes[i],
                            audio=(_audio * 32767).numpy().astype('int16'),
                            sample_rate=sample_rate)
        return audio_pathes
|
97
main.py
Normal file
97
main.py
Normal file
@ -0,0 +1,97 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import time
|
||||
|
||||
import discord
|
||||
import signal
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
|
||||
import utils
|
||||
from TTSSilero import TTS
|
||||
from TTSSilero import Speakers
|
||||
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
||||
from Cache import cache
|
||||
|
||||
# Build the synthesizer once at import time (downloads the model on first run).
tts = TTS()
# Leftover manual REPL used for local testing, disabled by being a bare
# string literal (a no-op expression statement).
"""
while msg := input('$ '):
    start = time.time()
    audio = tts.synthesize_text(msg, speaker=Speakers.kseniya)
    print('synthesize took ', str(time.time() - start))
    utils.play_bytes_io(audio)
"""
|
||||
|
||||
|
||||
class DiscordTTSBot(discord.Client):
    """Discord client that reads '-'-prefixed text messages aloud in the
    author's voice channel, caching synthesized audio in SQLite."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Install signal handlers so SIGTERM / Ctrl+C close the client cleanly.
        signal.signal(signal.SIGTERM, self.shutdown)
        signal.signal(signal.SIGINT, self.shutdown)
        logger.info('Shutdown callbacks registered')

    def shutdown(self, sig, frame):
        """Signal handler: schedule a graceful client close on the event loop."""
        logger.info(f'Shutting down by signal: {sig}')
        # Python signal handlers run in the main thread, where
        # run_until_complete()'s loop is active, so create_task is available.
        asyncio.create_task(self.close())

    async def on_ready(self):
        logger.debug('Bot is ready')

    async def on_message(self, message: discord.Message):
        """Handle a command message: manage voice connections and play TTS."""
        # Ignore our own messages and anything without the command prefix.
        if message.author == self.user:
            return

        if not message.content.startswith('-'):
            return

        if isinstance(message.channel, discord.TextChannel):
            logger.info(f'Message: {message.content}')
            user_voice_state = message.author.voice
            # BUG FIX: every command must start with '-' (guard above), so the
            # old '/exit' check could never match; accept '-exit' instead.
            if message.content.startswith('-exit'):
                if message.guild.voice_client is not None:
                    logger.debug('Disconnecting from voice channel')
                    await message.guild.voice_client.disconnect(force=False)
                    await message.channel.send("Left voice channel")
                    return

                else:
                    await message.channel.send("I'm not in any voice channel")
                    return

            if user_voice_state is None:
                await message.channel.send("You're not in a voice channel")
                return

            # Reuse the guild's existing voice connection, or join the
            # author's channel if there is none yet.
            # noinspection PyTypeChecker
            voice_client: discord.VoiceClient = message.guild.voice_client
            if voice_client is None:
                voice_client: discord.VoiceClient = await user_voice_state.channel.connect()

            cached = cache.get(message.content)
            if cached is not None:
                wav_file_like_object = cached
                logger.debug(f'Cache lookup for {message.content!r} successful')

            else:
                synthesis_start = time.time()
                wav_file_like_object = tts.synthesize_text(message.content, seek=0)
                logger.debug(f'Synthesis took {time.time() - synthesis_start} s')
                cache.set(message.content, wav_file_like_object.read())
                logger.debug(f'Set cache for {message.content!r}')
                # Rewind: the cache write above consumed the stream.
                wav_file_like_object.seek(0)

            sound_source = FFmpegPCMAudio(
                wav_file_like_object.read(),
                pipe=True
            )
            voice_client.play(sound_source, after=lambda e: logger.debug(f"Player done, {e=}"))
|
||||
|
||||
|
||||
intents = discord.Intents.default()
# message_content is a privileged intent; required to read command text.
intents.message_content = True

discord_client = DiscordTTSBot(intents=intents)

# Run the bot until a shutdown signal closes the client; the token comes
# from the environment so it is never committed to the repo.
loop = asyncio.new_event_loop()
loop.run_until_complete(discord_client.start(os.environ['DISCORD_TOKEN']))
logger.debug('Shutdown completed')
|
13
utils.py
Normal file
13
utils.py
Normal file
@ -0,0 +1,13 @@
|
||||
import winsound
|
||||
import io
|
||||
|
||||
|
||||
def save_bytes_io(filename: str, bytes_stream: io.BytesIO) -> None:
    """Persist the full contents of *bytes_stream* to *filename*.

    The stream is rewound first, so the whole buffer is written regardless
    of its current read position.
    """
    bytes_stream.seek(0)
    payload = bytes_stream.read()
    with open(file=filename, mode='wb') as out_file:
        out_file.write(payload)
|
||||
|
||||
|
||||
def play_bytes_io(bytes_stream: io.BytesIO) -> None:
    """Play the in-memory WAV synchronously via winsound (Windows only)."""
    bytes_stream.seek(0)  # rewind so the full buffer is played
    winsound.PlaySound(bytes_stream.read(), winsound.SND_MEMORY)
|
Loading…
x
Reference in New Issue
Block a user