Compare commits

..

2 Commits

Author SHA1 Message Date
3a9989d0de
Merge branch 'master' into silero_v3 2022-04-21 16:17:37 +03:00
bb28efe437
Try v3 models 2022-04-21 16:17:10 +03:00
31 changed files with 855 additions and 521 deletions

5
.gitignore vendored
View File

@ -1,5 +0,0 @@
.idea
*__pycache__
*.sqlite
*.log
model_multi.pt

12
CLISynth.py Normal file
View File

@ -0,0 +1,12 @@
import TTSSilero
from time import time
import utils
# Manual smoke test: synthesize one fixed phrase and report how long it took.
tts = TTSSilero.TTSSilero()
# while msg := input('$ '):
started_at = time()
audio = tts.synthesize_text(
    """Миша, давай заведем старый жигуль ... ви ви ви ви ви ви ви ви ви ви ви ви ви ви... ви ви ви ви ви ви ви ви ви ви... ви ви пр пр пр ви ви ви пр пр пр пр пр ви ви ви ви... миша, не сиди, помоги дотолкать до гаража, может там придётся перебирать её"""
)
elapsed = time() - started_at
print('synthesize took ', str(elapsed))
# utils.play_bytes(audio)

4
DB.py
View File

@ -2,9 +2,7 @@ from datetime import datetime
import peewee
import config
database = peewee.SqliteDatabase(str(config.DB_PATH))
database = peewee.SqliteDatabase('voice_cache.sqlite')
class BaseModel(peewee.Model):

View File

@ -1,33 +0,0 @@
# Stage "builder": produce wheels for everything in requirements.txt.
FROM python:3.10-slim as builder
WORKDIR /app
# Python settings for the build stage: no .pyc files, unbuffered output.
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
# Build-time packages only (git, gcc, libc headers); the apt lists are removed in the same layer.
RUN apt update && apt install git gcc libc6-dev -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
# Install Cython, then build wheels (without resolving dependencies) into /app/wheels.
# The pip cache directory is a build cache mount.
RUN --mount=type=cache,target=/root/.cache/pip pip install Cython && \
pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt
# Final stage: runtime image that installs the wheels built above.
FROM python:3.10-slim
ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1
# Create the unprivileged user the bot runs as and install ffmpeg.
RUN useradd -ms /bin/bash silero_user && \
apt update && apt install ffmpeg -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Bring over only the build artefacts from the builder stage.
COPY --from=builder /app/wheels /wheels
COPY --from=builder /app/requirements.txt .
RUN pip install --no-cache /wheels/*
# Copy the application sources, then drop privileges before starting the bot.
COPY . .
USER silero_user
CMD ["python3", "/app/main.py"]

View File

@ -1,11 +1,8 @@
# -*- coding: utf-8 -*-
import discord
import peewee
from discord.ext import commands
import peewee
import DB
# from loguru import logger

View File

@ -1,10 +1,8 @@
import io
import shlex
import subprocess
import discord
import shlex
import io
from discord.opus import Encoder
import discord
# Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571

View File

@ -1,9 +1,8 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
import discord
class Observer:
@abstractmethod
async def update(self, message: discord.Message) -> None:
def update(self, message: discord.Message) -> None:
    """Receive ``message`` from the observed subject; subclasses must override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # ``NotImplemented`` is the sentinel returned by binary-operator methods,
    # not an exception class: raising it fails with TypeError instead of
    # signalling a missing override.
    raise NotImplementedError

View File

@ -1,5 +1,4 @@
import discord
from .Observer import Observer

View File

@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
import DB
from TTSSilero import Speakers
class SpeakersSettingsAdapterDiscord:
def get_speaker(self, guild_id: int, user_id: int) -> str:
DEFAULT_SPEAKER = Speakers.kseniya
def get_speaker(self, guild_id: int, user_id: int) -> Speakers:
user_defined_speaker = self.get_speaker_user(guild_id, user_id)
if user_defined_speaker is None:
return self.get_speaker_global(guild_id)
@ -11,29 +14,36 @@ class SpeakersSettingsAdapterDiscord:
else:
return user_defined_speaker
def get_speaker_global(self, guild_id: int) -> str | None:
def get_speaker_global(self, guild_id: int) -> Speakers:
server_speaker_query = DB.ServerSpeaker.select()\
.where(DB.ServerSpeaker.server_id == guild_id)
if server_speaker_query.count() == 1:
return server_speaker_query.get().speaker
return Speakers(server_speaker_query.get().speaker)
else:
return None
return self.DEFAULT_SPEAKER
def get_speaker_user(self, guild_id: int, user_id: int) -> str | None:
def get_speaker_user(self, guild_id: int, user_id: int) -> Speakers | None:
user_speaker_query = DB.UserServerSpeaker.select()\
.where(DB.UserServerSpeaker.server_id == guild_id)\
.where(DB.UserServerSpeaker.user_id == user_id)
if user_speaker_query.count() == 1:
return user_speaker_query.get().speaker
return Speakers(user_speaker_query.get().speaker)
else:
return None
def set_speaker_user(self, guild_id: int, user_id: int, speaker: str) -> None:
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker).execute()
@property
def available_speakers(self) -> set[str]:
return {speaker.name for speaker in Speakers}
def set_speaker_global(self, guild_id: int, user_id: int, speaker: str) -> None:
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker).execute()
def set_speaker_user(self, guild_id: int, user_id: int, speaker: Speakers) -> None:
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker.value).execute()
def set_speaker_global(self, guild_id: int, speaker: Speakers) -> None:
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker.value).execute()
speakers_settings_adapter = SpeakersSettingsAdapterDiscord()

View File

@ -1,15 +0,0 @@
from typing import Literal
from pydantic import BaseModel
class Argument(BaseModel):
    """Description of a single argument of a TTS model."""

    # One of the literal type names 'str', 'int' or 'float'.
    type: Literal['str', 'int', 'float']
    # Optional free-text description; defaults to None when absent.
    description: str | None = None
class ModelDescription(BaseModel):
    """Description of one TTS model: its engine, name and arguments."""

    # Engine the model belongs to.
    engine: str
    # Model name within the engine.
    name: str
    # Argument name -> argument description.
    arguments: dict[str, Argument]
    # Optional free-text description; defaults to None when absent.
    description: None | str = None

View File

@ -1,19 +0,0 @@
import aiohttp
import config
from .Speakers import ModelDescription
class TTSServer:
    """Asynchronous HTTP client for the TTS service located at ``config.BASE_URL``."""

    def __init__(self):
        # Every request path below is relative to the configured base URL.
        self.session = aiohttp.ClientSession(base_url=config.BASE_URL)

    async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
        """Request synthesis of ``text`` with the model ``speaker`` and return the response body.

        :param text: text to synthesize, sent as the JSON field ``text``
        :param speaker: model whose ``engine`` and ``name`` select the endpoint
        :return: raw bytes of the response body
        """
        endpoint = f'/synth/{speaker.engine}/model/{speaker.name}'
        payload = {'text': text}
        async with self.session.post(url=endpoint, json=payload) as response:
            # Error statuses are turned into exceptions here.
            response.raise_for_status()
            audio = await response.content.read()

        return audio

    async def discovery(self) -> list[ModelDescription]:
        """Query ``/discovery`` and parse every returned item into a ``ModelDescription``."""
        async with self.session.get('/discovery') as response:
            response.raise_for_status()
            raw_models = await response.json()

        return [ModelDescription(**entry) for entry in raw_models]

View File

@ -1,53 +0,0 @@
from loguru import logger
import DB
from .Speakers import ModelDescription
from .TTSServer import TTSServer
class TTSServerCached(TTSServer):
    """``TTSServer`` that caches synthesized audio in ``DB.SoundCache``.

    It also keeps an in-memory index of discovered models, keyed by
    ``'<engine>_<name>'``.
    """

    def __init__(self):
        super().__init__()
        # Filled (and replaced) by discovery(); empty until the first call.
        self.models_lookup_idx: dict[str, ModelDescription] = dict()

    async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
        """Return audio for ``text``/``speaker``, synthesizing only on a cache miss.

        On a hit the row's ``usages`` counter is incremented; on a miss the
        freshly synthesized audio is stored before being returned.
        """
        # NOTE(review): ``speaker`` is a ModelDescription compared against a DB
        # column; confirm how the SoundCache.speaker field serializes it.
        cache_query = DB.SoundCache.select() \
            .where(DB.SoundCache.text == text) \
            .where(DB.SoundCache.speaker == speaker)

        if cache_query.count() == 1:
            with DB.database.atomic():
                DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1}) \
                    .where(DB.SoundCache.text == text) \
                    .where(DB.SoundCache.speaker == speaker).execute()

            return cache_query.get().audio

        synthesized = await super().synthesize_text(text, speaker)
        DB.SoundCache.create(text=text, speaker=speaker, audio=synthesized)
        return synthesized

    async def discovery(self) -> list[ModelDescription]:
        """Run discovery on the server and rebuild the lookup index from the result."""
        res = await super().discovery()
        logger.debug(f'Discovered {len(res)} models')
        self.models_lookup_idx = {f'{desc.engine}_{desc.name}': desc for desc in res}
        return res

    async def speaker_to_description(self, speaker_str: str) -> ModelDescription | None:
        """Resolve an ``'<engine>_<name>'`` string to its model description.

        Discovery is attempted up to two times while the index is empty.

        :return: the matching description, or None when the name is unknown
        :raises RuntimeError: if the index is still empty after both attempts
        """
        for _attempt in range(2):
            if self.models_lookup_idx:
                break

            await self.discovery()

        # Check the index itself rather than relying on for/else: the last
        # discovery attempt may have succeeded even though the loop never
        # reached its ``break``.
        if not self.models_lookup_idx:
            raise RuntimeError('Models discovery seems to return zero models')

        # Unknown names yield None; callers choose their own fallback.
        return self.models_lookup_idx.get(speaker_str, None)

    def default_speaker(self) -> ModelDescription:
        """Return the first model in the index (IndexError if the index is empty)."""
        return list(self.models_lookup_idx.values())[0]

View File

@ -1,3 +0,0 @@
from .Speakers import ModelDescription
from .TTSServer import TTSServer
from .TTSServerCached import TTSServerCached

15
TTSSilero/Speakers.py Normal file
View File

@ -0,0 +1,15 @@
import enum
class Speakers(enum.Enum):
    """Closed set of speaker identifiers; each member's value equals its name."""

    aidar = 'aidar'
    baya = 'baya'
    kseniya = 'kseniya'
    irina = 'irina'
    ruslan = 'ruslan'
    natasha = 'natasha'
    thorsten = 'thorsten'
    tux = 'tux'
    gilles = 'gilles'
    lj = 'lj'
    dilyara = 'dilyara'

62
TTSSilero/TTSSilero.py Normal file
View File

@ -0,0 +1,62 @@
import os
import io
import torch.package
from .Speakers import Speakers
from multi_acc_v3_package import TTSModelMultiAcc_v3
class TTSSilero:
    """Text-to-speech synthesizer running the Silero ``ru_v3`` package on the CPU."""

    def __init__(self, threads: int = 24):
        """Load the model, downloading ``model.pt`` first if it is not present.

        :param threads: value passed to ``torch.set_num_threads``
        """
        cpu = torch.device('cpu')
        torch.set_num_threads(threads)

        # The model archive lives in the working directory and is fetched once.
        model_file = 'model.pt'
        if not os.path.isfile(model_file):
            torch.hub.download_url_to_file('https://models.silero.ai/models/tts/ru/ru_v3.pt', model_file)

        importer = torch.package.PackageImporter(model_file)
        self.model: TTSModelMultiAcc_v3 = importer.load_pickle("tts_models", "model")
        self.model.to(cpu)

        # print(self.model.speakers)

        self.sample_rate = 48000

    def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
        """Synthesize ``text`` with ``speaker`` and return it as WAV file bytes."""
        waveform = self._synthesize_text(text, speaker)
        return self.to_wav(waveform)

    def _synthesize_text(self, text: str, speaker: Speakers) -> torch.Tensor:
        """
        Run the model on the text and return its output tensor
        :param text: text handed to ``apply_tts`` unchanged
        :param speaker: speaker whose ``value`` is passed to the model
        :return: result of ``apply_tts`` at ``self.sample_rate``
        """
        return self.model.apply_tts(
            text=text,
            speaker=speaker.value,
            sample_rate=self.sample_rate,
        )

    def to_wav(self, synthesized_text: torch.Tensor) -> bytes:
        """Scale the tensor to int16 and encode it as an in-memory WAV file."""
        pcm = (synthesized_text * 32767).numpy().astype('int16')
        buffer = io.BytesIO()
        self.model.write_wave(buffer, pcm, self.sample_rate)
        return buffer.getvalue()

View File

@ -0,0 +1,36 @@
import time
from typing import Union
from .TTSSilero import TTSSilero
from .Speakers import Speakers
import DB
import sqlite3
from loguru import logger
class TTSSileroCached(TTSSilero):
    """``TTSSilero`` that stores synthesized audio in ``DB.SoundCache`` and reuses it."""

    def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
        """Return WAV bytes for ``text``/``speaker``, synthesizing only on a cache miss.

        A hit increments the row's ``usages`` counter; a miss stores the new audio.
        """
        speaker_name = speaker.value
        lookup = DB.SoundCache.select()\
            .where(DB.SoundCache.text == text)\
            .where(DB.SoundCache.speaker == speaker_name)

        if lookup.count() != 1:
            # Miss: synthesize, remember the result, hand it back.
            audio = super().synthesize_text(text, speaker)
            DB.SoundCache.create(text=text, speaker=speaker_name, audio=audio)
            return audio

        # Hit: bump the usage counter, then serve the stored audio.
        with DB.database.atomic():
            DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1})\
                .where(DB.SoundCache.text == text)\
                .where(DB.SoundCache.speaker == speaker_name).execute()

        return lookup.get().audio

3
TTSSilero/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .TTSSilero import TTSSilero
from .Speakers import Speakers
from .TTSSileroCached import TTSSileroCached

View File

@ -0,0 +1,148 @@
import re
import wave
import torch
import warnings
import contextlib
# for type hints only
class TTSModelMulti_v2():
    """Wrapper around a TorchScript multi-speaker TTS model.

    Covers text encoding, speaker lookup, batching, inference and writing
    the result to WAV files.
    """

    def __init__(self, model_path, symbols):
        """Load the TorchScript model and remember the symbol alphabet.

        :param model_path: path handed to ``torch.jit.load``
        :param symbols: string of symbols; index 1 is appended to every text
            and everything from index 2 on is the set of allowed characters
        """
        self.model = self.init_jit_model(model_path)
        self.symbols = symbols
        self.device = torch.device('cpu')
        speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
                    'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
        self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}

    def init_jit_model(self, model_path: str):
        """Load the model onto the CPU in eval mode; disables autograd globally."""
        torch.set_grad_enabled(False)
        model = torch.jit.load(model_path, map_location='cpu')
        model.eval()
        return model

    def prepare_text_input(self, text, symbols, symbol_to_id=None):
        """Normalize ``text`` and encode it as a 1-D ``LongTensor`` of symbol ids.

        The text is lower-cased, stripped of characters outside ``symbols[2:]``,
        whitespace-collapsed, terminated with '.' if it lacks sentence
        punctuation, and finally suffixed with ``symbols[1]``.

        :raises ValueError: if nothing is left after removing unsupported characters
        """
        if len(text) > 140:
            warnings.warn('Text string is longer than 140 symbols.')

        if symbol_to_id is None:
            symbol_to_id = {s: i for i, s in enumerate(symbols)}

        text = text.lower()
        # Escape the alphabet so characters such as ']' or '-' are taken
        # literally inside the character class.
        text = re.sub('[^{}]'.format(re.escape(symbols[2:])), '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        if not text:
            # Previously this surfaced as an opaque IndexError from text[-1].
            raise ValueError('Text contains no symbols supported by the model')
        if not text.endswith(('.', '!', '?')):
            text = text + '.'
        text = text + symbols[1]

        text_ohe = [symbol_to_id[s] for s in text if s in symbols]
        text_tensor = torch.LongTensor(text_ohe)
        return text_tensor

    def prepare_tts_model_input(self, text: 'str | list', symbols: str, speakers: list):
        """Encode one or many texts into model inputs.

        :param text: a single string or a list of strings
        :param symbols: symbol alphabet, see ``prepare_text_input``
        :param speakers: speaker ids; one per text, or a single id for all texts
        :return: ``(text_padded, speaker_ids, ids_sorted_decreasing)``; for a
            batch the rows are ordered by decreasing encoded length and the
            third tensor holds the original index of every row
        """
        # Wrap a bare string first, so the check below counts texts rather
        # than characters.
        if isinstance(text, str):
            text = [text]
        assert len(speakers) == len(text) or len(speakers) == 1

        symbol_to_id = {s: i for i, s in enumerate(symbols)}
        if len(text) == 1:
            return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])

        text_tensors = [self.prepare_text_input(string, symbols, symbol_to_id)
                        for string in text]

        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(t) for t in text_tensors]),
            dim=0, descending=True)
        max_input_len = int(input_lengths[0])

        batch_size = len(text_tensors)

        # Rows are padded with 1 up to the longest text.
        text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
        if len(speakers) == 1:
            speakers = speakers*batch_size
        speaker_ids = torch.zeros(batch_size, dtype=torch.long)

        for i, idx in enumerate(ids_sorted_decreasing.tolist()):
            text_tensor = text_tensors[idx]
            in_len = text_tensor.size(0)
            text_padded[i, :in_len] = text_tensor
            speaker_ids[i] = speakers[idx]

        return text_padded, speaker_ids, ids_sorted_decreasing

    def process_tts_model_output(self, out, out_lens, ids):
        """Restore the original text order and trim every output to its length."""
        out = out.to('cpu')
        out_lens = out_lens.to('cpu')
        # Sorting the permutation yields its inverse.
        _, orig_ids = ids.sort()

        orig_out = out.index_select(0, orig_ids)
        orig_out_lens = out_lens.index_select(0, orig_ids)

        return [orig_out[i][:out_len] for i, out_len in enumerate(orig_out_lens)]

    def to(self, device):
        """Move the model to ``device`` and remember it for later inputs."""
        self.model = self.model.to(device)
        self.device = device

    def get_speakers(self, speakers: 'str | list'):
        """Translate one speaker name or a list of names into speaker ids.

        :raises ValueError: if a name is not a known speaker
        """
        if isinstance(speakers, str):
            speakers = [speakers]
        speaker_ids = []
        for speaker in speakers:
            try:
                speaker_ids.append(self.speaker_to_id[speaker])
            except (KeyError, TypeError):
                # KeyError: unknown name; TypeError: unhashable value.
                raise ValueError(f'No such speaker: {speaker}') from None
        return speaker_ids

    def apply_tts(self, texts: 'str | list',
                  speakers: 'str | list',
                  sample_rate: int = 16000):
        """Synthesize ``texts`` with ``speakers``.

        :return: list with one trimmed output per text, in input order
        """
        speaker_ids = self.get_speakers(speakers)
        text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
                                                                          symbols=self.symbols,
                                                                          speakers=speaker_ids)
        with torch.inference_mode():
            out, out_lens = self.model(text_padded.to(self.device),
                                       speaker_ids.to(self.device),
                                       sr=sample_rate)
        audios = self.process_tts_model_output(out, out_lens, orig_ids)
        return audios

    @staticmethod
    def write_wave(path, audio, sample_rate):
        """Writes a .wav file.
        Takes path, PCM audio data, and sample rate.
        The file is written as mono with 2-byte samples.
        """
        with contextlib.closing(wave.open(path, 'wb')) as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sample_rate)
            wf.writeframes(audio)

    def save_wav(self, texts: 'str | list',
                 speakers: 'str | list',
                 audio_pathes: 'str | list' = '',
                 sample_rate: int = 16000):
        """Synthesize ``texts`` and write each result to its own WAV file.

        :param audio_pathes: one path per text; when empty, files are named
            ``test_000.wav``, ``test_001.wav``, ...
        :return: list of the paths written
        """
        if isinstance(texts, str):
            texts = [texts]

        if not audio_pathes:
            audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
        if isinstance(audio_pathes, str):
            audio_pathes = [audio_pathes]
        assert len(audio_pathes) == len(texts)

        audio = self.apply_tts(texts=texts,
                               speakers=speakers,
                               sample_rate=sample_rate)
        for i, _audio in enumerate(audio):
            self.write_wave(path=audio_pathes[i],
                            audio=(_audio * 32767).numpy().astype('int16'),
                            sample_rate=sample_rate)
        return audio_pathes

View File

@ -1,6 +1,6 @@
from loguru import logger
from discord.ext import commands
from discord.ext.commands import Context
from loguru import logger
class cogErrorHandlers:

View File

@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
import datetime
import time
import discord
from discord.ext import commands, tasks
from discord.ext.commands import Context
import datetime
import time
class BotInformation(commands.Cog):
@ -16,12 +15,12 @@ class BotInformation(commands.Cog):
@commands.command('info')
async def get_info(self, ctx: Context):
info = f"""
Text-To-Speech bot
Text-To-Speech bot, based on Silero TTS model (<https://github.com/snakers4/silero-models>)
License: `GNU GENERAL PUBLIC LICENSE Version 3`
Author discord: `@a31`
Author email: `a31@demb.uk`
Author discord: `a31#6403`
Author email: `a31@demb.design`
Source code on github: <https://github.com/norohind/SileroTTSBot>
Source code on gitea a31's instance: <https://gitea.demb.uk/a31/SileroTTSBot>
Source code on gitea's a31 instance: <https://gitea.demb.design/a31/SileroTTSBot>
Invite link: https://discord.com/oauth2/authorize?client_id={self.bot.user.id}&scope=bot%20applications.commands
"""

View File

@ -1,140 +0,0 @@
# -*- coding: utf-8 -*-
import asyncio
import time
try:
import rfoo
from rfoo.utils import rconsole
except ImportError as e:
print("not importing rfoo", e)
from threading import Thread
from typing import Optional, Coroutine
from discord.ext import commands
from discord.ext.commands import Context
from loguru import logger
from formatting import format_table
class SingletonBase:
    """Mixin that makes a class hand out one shared instance.

    The instance is stored in the class attribute ``instance``. ``__init__``
    still runs on every instantiation, so constructor arguments are re-applied
    to the shared object each time.
    """

    def __new__(cls, *args, **kwargs):
        # Default to None so subclasses need not declare ``instance`` up front.
        existing = getattr(cls, 'instance', None)
        if existing is None:
            parent_new = super().__new__
            if parent_new is object.__new__:
                # object.__new__ rejects constructor arguments.
                existing = parent_new(cls)
            else:
                existing = parent_new(cls, *args, **kwargs)
            cls.instance = existing

        # logger.debug(id(getattr(cls, 'instance')))
        return existing
class BotManagement(commands.Cog, SingletonBase):
    """Owner/maintenance commands: guild and voice listings, rfoo console, shutdown."""

    instance: Optional['BotManagement'] = None
    rfoo_server_thread: Optional[Thread] = None
    rfoo_server: Optional['rfoo.InetServer'] = None

    def __init__(self, bot: commands.Bot):
        self.bot = bot

    # Owner only: send a table of every guild the bot is in (name, id).
    @commands.is_owner()
    @commands.command('listServers')
    async def list_servers(self, ctx: Context):
        text_table = format_table(((server.name, str(server.id)) for server in self.bot.guilds), ('name', 'id'))
        await ctx.channel.send(text_table)

    # Owner only: send a table of guild names that have a voice client.
    @commands.is_owner()
    @commands.command('listVoice')
    async def list_voice_connections(self, ctx: Context):
        text_table = format_table(((it.guild.name,) for it in self.bot.voice_clients))
        await ctx.channel.send(text_table)

    def start_rfoo(self) -> bool:
        # True if started, False if already started
        # NOTE(review): if the optional ``rfoo`` import at the top of the file
        # failed, this raises NameError - confirm that is acceptable.
        if self.rfoo_server_thread is None:
            self.rfoo_server = rfoo.InetServer(rconsole.ConsoleHandler, {'bot': self.bot, 'ct': self.ct})
            self.rfoo_server_thread = Thread(target=lambda: self.rfoo_server.start(rfoo.LOOPBACK, 54321))
            self.rfoo_server_thread.daemon = True
            self.rfoo_server_thread.start()
            logger.info('Rfoo thread started by msg')
            return True

        return False

    def stop_rfoo(self) -> bool:
        # True if stopped, False if it was not running
        if self.rfoo_server_thread is not None:
            self.rfoo_server.stop()
            # Reset explicitly instead of ``del`` + class-attribute fallback.
            self.rfoo_server_thread = None
            logger.info('Rfoo thread stopped by msg')
            return True

        return False

    # Owner only: start the rfoo console server.
    @commands.is_owner()
    @commands.command('rfooStart')
    async def start(self, ctx: Context):
        if self.start_rfoo():
            await ctx.send('Rfoo thread started')

        else:
            await ctx.send('Rfoo thread already started')

    # Owner only: stop the rfoo console server.
    @commands.is_owner()
    @commands.command('rfooStop')
    async def stop(self, ctx: Context):
        if self.stop_rfoo():
            await ctx.send('Rfoo server stopped')

        else:
            await ctx.send('Rfoo server already stopped')

    def ct(self, coro: Coroutine):
        """
        ct - short from create_task
        Run ``coro`` on the bot's event loop and block until it finishes.

        Exposed to the rfoo console, which runs in its own thread, so the
        coroutine is submitted with ``run_coroutine_threadsafe``; calling
        ``loop.create_task`` from a foreign thread is not thread-safe.
        Must not be called from the event-loop thread itself (it would block
        the loop it is waiting on).

        :return: the coroutine's result; its exception, if any, is raised
        """
        future = asyncio.run_coroutine_threadsafe(coro, self.bot.loop)
        return future.result()

    @commands.command('shutdown')
    async def shutdown(self, ctx: Context) -> None:
        """Shutdown bot with hope docker daemon will restart it, `@a31` and `@furiz__` has rights for it"""
        if ctx.author.id in (420130693696323585,  # @a31
                             444819880781545472):  # @furiz__

            log_msg = f"Got shutdown command by {ctx.author}"
            logger.info(log_msg)
            await ctx.reply('Shutting down')
            try:
                # Best effort: notify @a31 directly; failure must not block shutdown.
                await (await self.bot.fetch_user(420130693696323585)).send(log_msg)

            except Exception as e:
                logger.opt(exception=e).warning('Failed to send shutdown message')

            # Close from a separate task so this command handler can return first.
            self.bot.loop.create_task(self.bot.close())

        else:
            await ctx.reply('No rights for you')
async def setup(bot: commands.Bot):
    """Extension entry point: create the BotManagement cog and add it to ``bot``."""
    cog = BotManagement(bot)
    await bot.add_cog(cog)
async def teardown(bot):
    """Extension unload hook: stop the rfoo server and log whether it was running."""
    # BotManagement is a singleton, so this reaches the already-registered cog.
    management = BotManagement(bot)
    stopped = management.stop_rfoo()
    logger.info(f'Unloaded rfoo with result {stopped} during BotManagement unload')

View File

@ -2,21 +2,18 @@
import subprocess
import time
from collections import defaultdict
from typing import Union, Optional
import discord
from discord.ext import commands
from discord.ext.commands import Context
from loguru import logger
import discord
import DB
import Observ
import formatting
from typing import Union
from loguru import logger
from TTSSilero import TTSSileroCached
from TTSSilero import Speakers
from FFmpegPCMAudioModified import FFmpegPCMAudio
from SpeakersSettingsAdapterDiscord import SpeakersSettingsAdapterDiscord
from TTSServer import TTSServerCached
from TTSServer.Speakers import ModelDescription
import Observ
from cogErrorHandlers import cogErrorHandlers
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
class TTSCore(commands.Cog, Observ.Observer):
@ -24,33 +21,12 @@ class TTSCore(commands.Cog, Observ.Observer):
self.bot = bot
self.cog_command_error = cogErrorHandlers.missing_argument_handler
self.bot.subscribe(self) # subscribe for messages that aren't commands
self.tts = TTSServerCached()
self.tts = TTSSileroCached()
self.tts_queues: dict[int, list[discord.AudioSource]] = defaultdict(list)
self.speakers_adapter: SpeakersSettingsAdapterDiscord = SpeakersSettingsAdapterDiscord()
@commands.command('drop')
async def drop_queue(self, ctx: Context):
"""
Drop tts queue for current server
:param ctx:
:return:
"""
try:
del self.tts_queues[ctx.guild.id]
await ctx.send('Queue dropped')
except KeyError:
await ctx.send('Failed on dropping queue')
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
@commands.command('exit')
async def leave_voice(self, ctx: Context):
"""
Disconnect bot from current channel, it also drop messages queue
:param ctx:
:return:
"""
if ctx.guild.voice_client is not None:
await ctx.guild.voice_client.disconnect(force=False)
await ctx.channel.send(f"Left voice channel")
@ -88,67 +64,43 @@ class TTSCore(commands.Cog, Observ.Observer):
if voice_client is None:
voice_client: discord.VoiceClient = await user_voice_state.channel.connect()
if user_voice_state.channel.id != voice_client.channel.id:
await message.channel.send('We are in different voice channels')
return
speaker_str: str = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
speaker: ModelDescription | None = await self.tts.speaker_to_description(speaker_str)
if speaker is None:
speaker = self.tts.default_speaker()
speaker: Speakers = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
# check if message will fail on synthesis
if DB.SynthesisErrors.select() \
.where(DB.SynthesisErrors.speaker == speaker_str) \
.where(DB.SynthesisErrors.text == message.content) \
if DB.SynthesisErrors.select()\
.where(DB.SynthesisErrors.speaker == speaker.value)\
.where(DB.SynthesisErrors.text == message.content)\
.count() == 1:
# Then we will not try to synthesis it
await message.channel.send(f"I will not synthesis this message due to TTS engine limitations")
return
try:
wav_file_like_object: bytes = await self.tts.synthesize_text(message.content, speaker=speaker)
wav_file_like_object = self.tts.synthesize_text(message.content, speaker=speaker)
sound_source = FFmpegPCMAudio(wav_file_like_object, pipe=True, stderr=subprocess.PIPE)
if voice_client.is_playing():
# Then we need to enqueue prepared sound for playing through self.tts_queues mechanism
self.tts_queues[message.guild.id].append(sound_source)
await message.channel.send(f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}")
return
voice_client.play(sound_source, after=lambda e: self.queue_player(message))
except Exception as synth_exception:
logger.opt(exception=True).warning(f'Exception on synthesize {message.content!r}: {synth_exception}')
await message.channel.send(f'Synthesize error')
DB.SynthesisErrors.create(speaker=speaker, text=message.content)
return
else:
try:
if voice_client.is_playing():
# Then we need to enqueue prepared sound for playing through self.tts_queues mechanism
self.tts_queues[message.guild.id].append(sound_source)
await message.channel.send(
f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}")
return
voice_client.play(sound_source, after=lambda e: self.queue_player(message))
except Exception as play_exception:
logger.opt(exception=True).warning(
f'Exception on playing for: {message.guild.name}[#{message.channel.name}]: {message.author.display_name} / {play_exception}')
await message.channel.send(f'Playing error')
return
await message.channel.send(f'Internal error')
DB.SynthesisErrors.create(speaker=speaker.value, text=message.content)
def queue_player(self, message: discord.Message):
voice_client: Union[discord.VoiceClient, None] = message.guild.voice_client
if voice_client is None:
# don't play anything and clear queue for whole guild
del self.tts_queues[message.guild.id]
return
for sound_source in self.tts_queues[message.guild.id]:
if len(self.tts_queues[message.guild.id]) == 0:
return
voice_client: Optional[discord.VoiceClient] = message.guild.voice_client
if voice_client is None:
# don't play anything and clear queue for whole guild
break
try:
voice_client.play(sound_source)
except discord.errors.ClientException: # Here we expect Not connected to voice
break
voice_client.play(sound_source)
while voice_client.is_playing():
time.sleep(0.1)
@ -159,105 +111,13 @@ class TTSCore(commands.Cog, Observ.Observer):
pass
@commands.Cog.listener()
async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState,
after: discord.VoiceState):
async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState):
if after.channel is None:
members = before.channel.members
if len(members) == 1:
if members[0].id == self.bot.user.id:
await before.channel.guild.voice_client.disconnect(force=False)
# TODO: leave voice channel after being moved there alone
# Ex TTSSettings.py
@commands.command('getAllSpeakers')
async def get_speakers(self, ctx: Context):
"""
Enumerate all available to set speakers
:param ctx:
:return:
"""
await ctx.send(
formatting.format_table(
tuple((model.engine + '_' + model.name, model.description)
for model in await self.tts.discovery()),
('model', 'description')
))
@commands.command('setPersonalSpeaker')
async def set_user_speaker(self, ctx: Context, speaker: str):
"""
Set personal speaker on this server
:param ctx:
:param speaker:
:return:
"""
if await self.tts.speaker_to_description(speaker) is None:
return await ctx.send(
f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, speaker)
await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`')
@commands.command('setServerSpeaker')
async def set_server_speaker(self, ctx: Context, speaker: str):
"""
Set global server speaker
:param ctx:
:param speaker:
:return:
"""
if await self.tts.speaker_to_description(speaker) is None:
return await ctx.send(
f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
self.speakers_adapter.set_speaker_global(ctx.guild.id, ctx.author.id, speaker)
await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`')
@commands.command('getSpeaker')
async def get_speaker(self, ctx: Context):
"""
Tell first appropriate speaker for a user, it can be user specified, server specified or server default
:param ctx:
:return:
"""
speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)
await ctx.reply(f'Your current speaker is `{speaker}`')
@commands.command('getPersonalSpeaker')
async def get_personal_speaker(self, ctx: Context):
"""
Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
:param ctx:
:return:
"""
speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
if speaker is None:
server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")
else:
await ctx.reply(f"Your personal speaker is `{speaker}`")
@commands.command('getServerSpeaker')
async def get_server_speaker(self, ctx: Context):
"""
Tell server global speaker
:param ctx:
:return:
"""
speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
await ctx.send(f"Current server speaker is `{speaker}`")
async def setup(bot):
    """Extension entry point: create the TTSCore cog and add it to ``bot``."""
    cog = TTSCore(bot)
    await bot.add_cog(cog)

102
cogs/TTSSettings.py Normal file
View File

@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
from discord.ext import commands
from discord.ext.commands import Context
from TTSSilero import Speakers
from cogErrorHandlers import cogErrorHandlers
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
class TTSSettings(commands.Cog):
    # Commands for listing speakers and reading/changing the personal and
    # server-wide speaker settings.

    def __init__(self, bot: commands.Bot):
        self.bot = bot
        self.cog_command_error = cogErrorHandlers.missing_argument_handler
        self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter

    @commands.command('getAllSpeakers')
    async def get_speakers(self, ctx: Context):
        """
        Enumerate all available to set speakers
        :param ctx:
        :return:
        """
        # available_speakers is a set; sort it so the listing is stable.
        speakers = '\n'.join(sorted(self.speakers_adapter.available_speakers))

        await ctx.send(f"```\n{speakers}```")

    @commands.command('setPersonalSpeaker')
    async def set_user_speaker(self, ctx: Context, speaker: str):
        """
        Set personal speaker on this server
        :param ctx:
        :param speaker:
        :return:
        """
        # Only the name validation is guarded, so storage or Discord errors
        # are not misreported as an invalid speaker.
        try:
            checked_speaker: Speakers = Speakers(speaker)

        except (KeyError, ValueError):
            await ctx.send("Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
            return

        self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, checked_speaker)
        await ctx.reply(f'Successfully set **your personal** speaker to `{checked_speaker.value}`')

    @commands.command('setServerSpeaker')
    async def set_server_speaker(self, ctx: Context, speaker: str):
        """
        Set global server speaker
        :param ctx:
        :param speaker:
        :return:
        """
        try:
            checked_speaker: Speakers = Speakers(speaker)

        except (KeyError, ValueError):
            await ctx.send("Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
            return

        self.speakers_adapter.set_speaker_global(ctx.guild.id, checked_speaker)
        await ctx.send(f'Successfully set **server** speaker to `{checked_speaker.value}`')

    @commands.command('getSpeaker')
    async def get_speaker(self, ctx: Context):
        """
        Tell first appropriate speaker for a user, it can be user specified, server specified or server default
        :param ctx:
        :return:
        """
        speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)

        await ctx.reply(f'Your current speaker is `{speaker.value}`')

    @commands.command('getPersonalSpeaker')
    async def get_personal_speaker(self, ctx: Context):
        """
        Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
        :param ctx:
        :return:
        """
        speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
        if speaker is None:
            server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id).value
            await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")

        else:
            await ctx.reply(f"Your personal speaker is `{speaker.value}`")

    @commands.command('getServerSpeaker')
    async def get_server_speaker(self, ctx: Context):
        """
        Tell server global speaker
        :param ctx:
        :return:
        """
        speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
        await ctx.send(f"Current server speaker is `{speaker.value}`")
async def setup(bot):
    """Extension entry point: create the TTSSettings cog and add it to ``bot``."""
    cog = TTSSettings(bot)
    await bot.add_cog(cog)

View File

@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
from discord.ext import commands
from discord.ext.commands import Context
from loguru import logger
import DB
from loguru import logger
import DynamicCommandPrefix
from cogErrorHandlers import cogErrorHandlers

View File

@ -1,5 +0,0 @@
import os
from pathlib import Path
# SQLite file location: <DATA_DIR>/voice_cache.sqlite, with DATA_DIR
# defaulting to the current directory.
DB_PATH = Path(os.getenv('DATA_DIR', '.')).joinpath('voice_cache.sqlite')
# Mandatory setting: a missing BASE_URL variable raises KeyError at import time.
BASE_URL = os.environ['BASE_URL']

View File

@ -1,18 +0,0 @@
from typing import Iterable
class MISSING:
    """Sentinel type: the default of ``header``, meaning "no header row"."""
    pass


def format_table(data: Iterable[Iterable[str]], header: Iterable[str] = MISSING) -> str:
    """Render rows as a tab-separated table inside a Markdown code fence.

    Every cell is converted with ``str`` and its backticks are escaped so a
    cell cannot terminate the fence.

    :param data: rows, each an iterable of cells
    :param header: optional first row; omitted when left at ``MISSING``
    :return: the fenced table, one line per row
    """
    # Identity is the right test for a sentinel; ``!=`` would invoke a custom
    # ``__ne__`` on whatever header object the caller passes.
    if header is not MISSING:
        data = (header, *data)

    lines = [
        '\t'.join(str(item).replace('`', '\\`') for item in row) + '\n'
        for row in data
    ]
    return '```\n' + ''.join(lines) + '```'

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.1 KiB

24
main.py
View File

@ -1,19 +1,15 @@
# -*- coding: utf-8 -*-
import asyncio
import os
import signal
import discord
from discord.ext import commands
import discord
import signal
import asyncio
from loguru import logger
import Observ
from DynamicCommandPrefix import dynamic_command_prefix
import Observ
LOG_FILE_ENABLED = os.getenv('LOG_FILE_ENABLED', 'true').lower() == 'true'
if LOG_FILE_ENABLED:
logger.add('offlineTTSBot.log', backtrace=True, diagnose=False, rotation='5MB')
logger.add('offlineTTSBot.log', backtrace=True, diagnose=False, rotation='5MB')
"""
while msg := input('$ '):
@ -53,7 +49,7 @@ class DiscordTTSBot(commands.Bot, Observ.Subject):
await self.notify(ctx.message)
else:
logger.opt(exception=exception).warning(f'Global error caught:')
raise exception
intents = discord.Intents.default()
@ -70,8 +66,6 @@ async def main():
await discord_client.start(os.environ['DISCORD_TOKEN'])
if __name__ == '__main__':
loop = asyncio.new_event_loop()
loop.run_until_complete(main())
logger.debug('Shutdown completed')
loop = asyncio.new_event_loop()
loop.run_until_complete(main())
logger.debug('Shutdown completed')

381
multi_acc_v3_package.py Normal file
View File

@ -0,0 +1,381 @@
import re
import wave
import torch
import warnings
import contextlib
import xml.etree.ElementTree as ET
class TTSModelMultiAcc_v3():
def __init__(self, model_path, symbols, speaker_to_id, emb_dim=128):
    """Load the TorchScript model and set up the lookup tables used by the class.

    :param model_path: path passed to ``init_jit_model``.
    :param symbols: string of symbols; ``prepare_text_input`` keeps only characters in ``symbols[3:]``.
    :param speaker_to_id: mapping of speaker name to id; must contain ``'random'``.
    :param emb_dim: width of the random-voice embedding used by ``load_random_voice``.
    :raises AssertionError: if ``speaker_to_id`` has no ``'random'`` entry.
    """
    torch.set_grad_enabled(False)  # NOTE: process-wide switch, inference only
    self.model = self.init_jit_model(model_path)
    self.symbols = symbols
    self.device = torch.device('cpu')
    self.speaker_to_id = speaker_to_id
    self.speakers = list(speaker_to_id.keys())
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if 'random' not in self.speakers:
        raise AssertionError("`speaker_to_id` must contain the 'random' speaker")

    # Per-character mapping from Cyrillic letters (and the en dash) to ASCII.
    self.ru_ascii_dict = dict(zip('абвгдеёжзийклмнопрстуфхцчшщъыьэюя–',
                                  'abvgde1jzi2klmnoprstufhc4w35y6789='))

    # ssml tags
    # 'x-weak' is the SSML spelling (it matches 'x-strong' below). The old
    # 'x_weak' key is kept as an alias so inputs that used it keep working.
    self.strength2time = {'x-weak': 25, 'weak': 75, 'medium': 150, 'strong': 300, 'x-strong': 1000,
                          'x_weak': 25}
    self.rate2value = {'x-slow': 0.5, 'slow': 0.8, 'medium': 1., 'fast': 1.2, 'x-fast': 1.5}
    self.pitch2value = {'x-low': 0.6, 'low': 0.8, 'medium': 1., 'high': 1.2, 'x-high': 1.4, 'robot': 0.}

    self.emb_dim = emb_dim
    self.random_emb = None
    self.debug = False

    # Values listed in "should be in ..." error messages; the legacy alias is not advertised.
    self.valid_tags = {'break': {'strength': [name for name in self.strength2time if name != 'x_weak']},
                       'prosody': {'rate': list(self.rate2value.keys()),
                                   'pitch': list(self.pitch2value.keys())}}
def init_jit_model(self, model_path: str):
    """Load a TorchScript model from ``model_path`` onto the CPU in eval mode.

    Also disables gradient tracking through ``torch.set_grad_enabled(False)``.
    """
    torch.set_grad_enabled(False)
    jit_model = torch.jit.load(model_path, map_location='cpu')
    jit_model.eval()
    return jit_model
def ru_to_ascii(self, sentence):
    """Map each character through ``self.ru_ascii_dict``; unmapped characters pass through unchanged."""
    lookup = self.ru_ascii_dict
    return ''.join(lookup.get(char, char) for char in sentence)
def prepare_text_input(self, text):
"""Normalise ``text`` and derive the ASCII forms used by the model.

Returns a ``(sentence, clean_sentence, has_text)`` tuple: ``sentence`` is the
filtered text passed through ``ru_to_ascii``, ``clean_sentence`` is ``sentence``
reduced to ``a-z``, ``1-9``, ``-`` and spaces, and ``has_text`` tells whether
``clean_sentence`` contains anything besides spaces.
"""
text = text.lower()
# NOTE(review): the three search strings below render as empty here; they
# presumably hold non-printing characters in the real file -- TODO confirm
# before touching this line (a literal '' search would insert '-' everywhere).
text = text.replace('', '').replace('', '').replace('', '-')
# Drop every character not listed in self.symbols[3:].
# NOTE(review): the symbols are interpolated into a character class unescaped;
# presumably none of them is special inside [...] -- verify against the symbol set.
text = re.sub(r'[^{}]'.format(self.symbols[3:]), '', text)
# Collapse whitespace runs to single spaces and trim the ends.
text = re.sub(r'\s+', ' ', text).strip()
sentence = self.ru_to_ascii(text)
clean_sentence = re.sub(r'[^a-z1-9\- ]', '', sentence)
has_text = len(clean_sentence.replace(' ', '')) > 0
return sentence, clean_sentence, has_text
def prepare_tts_model_input(self, text: str,
                            ssml: bool,
                            speaker_ids: list):
    """Split the input into per-segment lists for the model call.

    :param text: plain text, or SSML markup when ``ssml`` is true.
    :param ssml: choose ``process_ssml`` (true) or ``process_simple_text`` (false).
    :param speaker_ids: speaker ids, converted to a ``torch.LongTensor``.
    :return: ``(sentences, clean_sentences, break_lens, prosody_rates,
        prosody_pitches, speaker_ids)``; the first five are parallel lists.
    :raises ValueError: if the input produced no segments (e.g. ``<speak></speak>``).
    """
    if ssml:
        segments = self.process_ssml(text)
    else:
        segments = self.process_simple_text(text)
    # Previously an empty list failed later with "not enough values to unpack".
    if not segments:
        raise ValueError('Input contains no text segments to synthesize')

    # Read by key rather than relying on the insertion order of dict values.
    sentences = [segment['text'] for segment in segments]
    clean_sentences = [segment['clean_text'] for segment in segments]
    break_lens = [segment['break_time'] for segment in segments]
    prosody_rates = [segment['prosody_rate'] for segment in segments]
    prosody_pitches = [segment['prosody_pitch'] for segment in segments]

    full_text_len = sum(len(sentence) for sentence in sentences)
    if full_text_len > 1000:
        warnings.warn('Text string is longer than 1000 symbols.')

    speaker_ids = torch.LongTensor(speaker_ids)
    return sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches, speaker_ids
def to(self, device):
    """Move ``self.model.tts_model`` to ``device`` and record it in ``self.device``."""
    moved = self.model.tts_model.to(device)
    self.model.tts_model = moved
    self.device = device
def get_speakers(self, speaker: str, voice_path=None):
    """Resolve ``speaker`` to a one-element list holding its id.

    For ``'random'`` the voice is first prepared through ``load_random_voice``.

    :param speaker: speaker name, a key of ``self.speaker_to_id``.
    :param voice_path: optional path forwarded to ``load_random_voice``.
    :raises ValueError: if the speaker is unknown or loading the voice failed;
        the underlying error is attached as ``__cause__``.
    """
    try:
        if speaker == 'random':
            self.load_random_voice(voice_path)
        speaker_id = self.speaker_to_id.get(speaker, None)
        if speaker_id is None:
            raise ValueError(f"`speaker` should be in {', '.join(self.speakers)}")
    except Exception as e:
        # Chain explicitly so the original failure stays visible in tracebacks.
        raise ValueError(f'Failed to load speaker: {speaker}, error: {e}') from e
    return [speaker_id]
def process_simple_text(self, text):
    """Wrap plain text into the single-segment list format shared with SSML input.

    :param text: raw input text, normalised through ``prepare_text_input``.
    :return: a one-element list with the keys ``text``, ``clean_text``,
        ``break_time`` (``None``), ``prosody_rate`` and ``prosody_pitch`` (both ``1.``).
    :raises ValueError: if nothing pronounceable is left after normalisation.
    """
    sentence, clean_sentence, has_text = self.prepare_text_input(text)
    if not has_text:
        # Previously a bare ``ValueError`` with no message.
        raise ValueError('Text has no pronounceable symbols left after cleaning')
    return [{'text': sentence,
             'clean_text': clean_sentence,
             'break_time': None,
             'prosody_rate': 1.,
             'prosody_pitch': 1.}]
def process_ssml(self, ssml_text):
    """Parse SSML markup into the list of text segments fed to the model.

    :param ssml_text: SSML document whose root element must be ``<speak>``.
    :return: the segment list produced by ``process_ssml_tag_dict``.
    :raises ValueError: on malformed XML or any failure while parsing/processing.
    :raises AssertionError: if the root tag is not ``<speak>``, or when
        ``process_ssml_element`` reports an invalid tag (passed through unchanged).
    """
    # Collapsing whitespace also removes every newline, so no further
    # newline clean-up is needed.
    ssml_text = re.sub(r'\s+', ' ', ssml_text).strip()
    try:
        # NOTE(review): ElementTree does not guard against maliciously crafted
        # XML; if ssml_text comes from untrusted users, consider hardening.
        root = ET.fromstring(ssml_text)
    except Exception as e:
        raise ValueError("Invalid XML format") from e
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if root.tag != 'speak':
        raise AssertionError("Invalid SSML format: <speak> tag is essential")

    try:
        ssml_parsed = self.process_ssml_element(root)
        if self.debug:
            print(ssml_parsed)
    except AssertionError:
        # Tag-validation errors carry user-facing messages; pass them through.
        raise
    except Exception as e:
        raise ValueError(f"Failed to parse SSML: {e}") from e

    try:
        clean_text_list = self.process_ssml_tag_dict(ssml_parsed)
        if self.debug:
            print(clean_text_list)
    except Exception as e:
        raise ValueError(f"Failed to process SSML: {e}") from e
    return clean_text_list
def process_ssml_tag_dict(self, text_break_list):
    """Turn parsed SSML items into model-ready segments.

    Items with text (and always the first item) become their own segment.
    A later text-less item that carries a break is folded into the previous
    segment, which keeps the longer of the two pauses. Break times are
    divided by 12.5 and truncated to ``int``.
    """
    segments = []
    for index, item in enumerate(text_break_list):
        pause = item['break']
        prosody = item['prosody']
        text, clean_text, has_text = self.prepare_text_input(item['text'])
        break_time = None if pause['time'] is None else int(pause['time'] / 12.5)

        if has_text or index == 0:
            segments.append({'text': self.check_text_break(text, pause),
                             'clean_text': clean_text,
                             'break_time': break_time,
                             'prosody_rate': prosody['rate'],
                             'prosody_pitch': prosody['pitch']})
            continue

        if pause['strength'] is None or not segments:
            continue

        # Text-less break: attach it to the segment emitted before it.
        previous = segments[-1]
        previous['text'] = self.check_text_break(previous['text'], pause)
        if previous['break_time'] is None:
            previous['break_time'] = break_time
        else:
            previous['break_time'] = max(break_time, previous['break_time'])
    return segments
def check_text_break(self, text, tbreak):
    """Append ``'.'`` to empty text, or to text followed by a break that lacks final punctuation."""
    # TODO fx dash
    followed_by_break = tbreak['strength'] is not None
    needs_stop = len(text) == 0 or (followed_by_break and text[-1] not in '!,-.:;?–…')
    return text + '.' if needs_stop else text
def process_ssml_element(self, element, def_strength='strong', def_rate=1., def_pitch=1.):
"""Recursively flatten an SSML element into a list of text items.

Each item is a dict with the keys 'text', 'break' ({'strength', 'time'})
and 'prosody' ({'rate', 'pitch'}).

:param element: element exposing ``.text`` and iterable children with
    ``.tag``, ``.attrib`` and ``.tail`` (``process_ssml`` passes an ElementTree root).
:param def_strength: only forwarded to the recursive call for <prosody> children.
:param def_rate: prosody rate inherited by text directly inside ``element``.
:param def_pitch: prosody pitch inherited by text directly inside ``element``.
:return: list of item dicts in document order.
"""
parsed = []
last_tag = None
# Text that precedes the first child element.
head_text_parsed = self.process_head_tail_text(element.text, def_rate, def_pitch)
parsed.extend(head_text_parsed)
for child in element:
if child.tag == 'break':
break_strength, break_ts = self.process_break_attrib(child.attrib)
# A leading <break> has no preceding text, so a '.' placeholder item carries it.
if len(parsed) == 0:
parsed.append({'text': '.',
'break': {'strength': None,
'time': None},
'prosody': {'rate': def_rate,
'pitch': def_pitch}})
# The break is stored on the item that precedes it.
parsed[-1]['break'] = {'strength': break_strength,
'time': break_ts}
elif child.tag == 'prosody':
prosody_rate, prosody_pitch, change_rate, change_pitch = self.process_prosody(child.attrib)
# Only attributes present on the tag override the inherited values.
child_rate = prosody_rate if change_rate else def_rate
child_pitch = prosody_pitch if change_pitch else def_pitch
child_parsed = self.process_ssml_element(child, def_strength, child_rate, child_pitch)
parsed.extend(child_parsed)
elif child.tag in ['p', 's']:
# Sentences end with a 'strong' pause, paragraphs with an 'x-strong' one.
break_strength = 'strong' if child.tag == 's' else 'x-strong'
child_parsed = self.process_ssml_element(child, break_strength, def_rate, def_pitch)
# NOTE(review): `last_tag is not None or last_tag != child.tag` is always true
# (a None last_tag never equals a tag string), so this reduces to
# `len(parsed) > 0` -- confirm the intended condition.
if len(parsed) > 0 and (parsed[-1]['text'] or last_tag is not None or last_tag != child.tag):
# Pause before the <p>/<s>: keep the longer of the existing and implied pause.
if parsed[-1]['break']['strength'] is None:
parsed[-1]['break'] = {'strength': break_strength,
'time': self.strength2time[break_strength]}
else:
last_time = parsed[-1]['break']['time']
parsed[-1]['break'] = {'strength': break_strength,
'time': max(last_time, self.strength2time[break_strength])}
# Pause after the <p>/<s>: same rule, applied to its last item.
if len(child_parsed) > 0:
if child_parsed[-1]['break']['strength'] is None:
child_parsed[-1]['break'] = {'strength': break_strength,
'time': self.strength2time[break_strength]}
else:
last_time = child_parsed[-1]['break']['time']
child_parsed[-1]['break'] = {'strength': break_strength,
'time': max(last_time, self.strength2time[break_strength])}
parsed.extend(child_parsed)
else:
warnings.warn(f"Current model doesn't support SSML tag: {child.tag}")
last_tag = child.tag
# Text that follows the child's closing tag.
if child.tail:
tail_text = child.tail
# Punctuation right after a closing tag belongs to the preceding item.
if tail_text[0] in '.,!?…–;:' and len(parsed) > 0:
lost_punct = tail_text[0]
parsed[-1]['text'] = parsed[-1]['text'].strip() + lost_punct
# NOTE(review): a tail consisting of one punctuation character is not
# shortened here, so it is also appended below as its own item -- confirm intended.
if len(tail_text) > 1:
tail_text = tail_text[1:]
tail_text_parsed = self.process_head_tail_text(tail_text, def_rate, def_pitch)
parsed.extend(tail_text_parsed)
return parsed
def process_head_tail_text(self, element_text, def_rate, def_pitch):
    """Wrap element text in a one-item list carrying no break and the given prosody.

    Returns an empty list for ``None``. Newlines are removed, remaining
    whitespace runs collapse to one space, and the ends are stripped.
    """
    if element_text is None:
        return []
    without_newlines = element_text.replace('\n', '')
    collapsed = re.sub(r'\s+', ' ', without_newlines).strip()
    return [{'text': collapsed,
             'break': {'strength': None, 'time': None},
             'prosody': {'rate': def_rate, 'pitch': def_pitch}}]
def process_break_attrib(self, attrib):
    """Resolve the attributes of a ``<break>`` tag to ``(strength, milliseconds)``.

    :param attrib: attribute mapping of the tag; ``strength`` and ``time`` are read.
    :return: ``(strength, break_ts)`` where ``break_ts`` is capped at 5000 ms.
        With ``time`` given, strength is raised to 'strong' / 'x-strong' once
        the time reaches those thresholds; otherwise the time is looked up
        from the strength (default 'medium').
    :raises AssertionError: if ``time`` lacks an 'ms'/'s' suffix, or ``strength``
        is unknown and no ``time`` is given.
    :raises ValueError: if the numeric part of ``time`` cannot be parsed.
    """
    for k in attrib:
        if k not in ('strength', 'time'):
            warnings.warn(f"Current model doesn't support SSML <break> attrib: {k}")
    strength = attrib.get('strength', 'medium')
    break_time = attrib.get('time')

    if break_time is not None:
        # Parse as float so decimal values such as "1.5s" are accepted.
        if break_time.endswith('ms'):
            break_ts = int(round(float(break_time[:-2])))
        elif break_time.endswith('s'):
            break_ts = int(round(float(break_time[:-1]) * 1000))
        else:
            raise AssertionError("Invalid <break> tag, time should end with 'ms' or 's'")
        if break_ts >= self.strength2time['x-strong']:
            strength = 'x-strong'
        elif break_ts >= self.strength2time['strong']:
            strength = 'strong'
    elif strength in self.strength2time:
        break_ts = self.strength2time[strength]
    else:
        raise AssertionError(f"Invalid <break> tag, strength should be in {', '.join(self.valid_tags['break']['strength'])}")

    if break_ts > 5000:
        warnings.warn('Current model supports pauses less than 5 sec')
        break_ts = 5000
    return strength, break_ts
def process_prosody(self, attrib):
    """Resolve the attributes of a ``<prosody>`` tag to numeric multipliers.

    :param attrib: attribute mapping of the tag; ``rate`` and ``pitch`` are read.
        ``rate`` is a name from ``self.rate2value`` or a percentage ("120%" -> 1.2).
        ``pitch`` is a name from ``self.pitch2value`` or a signed percentage
        relative to 1.0 ("+20%" -> 1.2, "-20%" -> 0.8).
    :return: ``(rate_val, pitch_val, change_rate, change_pitch)``; the flags
        tell which attributes were present, and absent values default to ``1.``.
    :raises AssertionError: if neither attribute is given, a name is unknown,
        or a pitch percentage has no sign.
    :raises ValueError: if the numeric part of a percentage cannot be parsed.
    """
    for k in attrib:
        if k not in ('rate', 'pitch'):
            warnings.warn(f"Current model doesn't support SSML <prosody> attrib: {k}")
    rate = attrib.get('rate')
    pitch = attrib.get('pitch')
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if rate is None and pitch is None:
        raise AssertionError("Empty <prosody> tag")

    change_rate = rate is not None
    if not change_rate:
        rate_val = 1.
    elif rate.endswith('%'):
        # float() also accepts decimal percentages such as "87.5%".
        rate_val = float(rate.replace('%', '')) / 100
    else:
        rate_val = self.rate2value.get(rate)
        if rate_val is None:
            raise AssertionError(f"Invalid <prosody> tag, rate should be in {', '.join(self.valid_tags['prosody']['rate'])}")

    change_pitch = pitch is not None
    if not change_pitch:
        pitch_val = 1.
    elif pitch.endswith('%'):
        sign = pitch[0]
        # Previously an unsigned value lost its first digit ("50%" -> "0").
        if sign not in ('+', '-'):
            raise AssertionError("Invalid <prosody> tag, pitch percentage should start with '+' or '-'")
        offset = float(pitch.replace('%', '')[1:]) / 100
        pitch_val = 1. + offset if sign == '+' else 1. - offset
    else:
        pitch_val = self.pitch2value.get(pitch)
        if pitch_val is None:
            raise AssertionError(f"Invalid <prosody> tag, pitch should be in {', '.join(self.valid_tags['prosody']['pitch'])}")

    return rate_val, pitch_val, change_rate, change_pitch
def apply_tts(self, text=None,
              ssml_text=None,
              speaker: str = 'xenia',
              sample_rate: int = 48000,
              put_accent=True,
              put_yo=True,
              voice_path=None):
    """Synthesize speech for plain text or SSML and return the audio tensor.

    :param text: plain input text; used when ``ssml_text`` is ``None``.
    :param ssml_text: SSML markup; takes precedence over ``text`` when given.
    :param speaker: speaker name, must be in ``self.speakers``.
    :param sample_rate: one of 8000, 24000 or 48000.
    :param put_accent: forwarded to the model as ``put_accent``.
    :param put_yo: forwarded to the model as ``put_yo``.
    :param voice_path: optional voice file forwarded to ``get_speakers``.
    :return: the first item of the model output, moved to the CPU.
    :raises AssertionError: on an invalid ``sample_rate``/``speaker`` or when
        both ``text`` and ``ssml_text`` are ``None``.
    :raises Exception: if the model raises ``RuntimeError``; the original error
        is attached as ``__cause__``.
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if sample_rate not in (8000, 24000, 48000):
        raise AssertionError(f"`sample_rate` should be in [8000, 24000, 48000], current value is {sample_rate}")
    if speaker not in self.speakers:
        raise AssertionError(f"`speaker` should be in {', '.join(self.speakers)}")
    if text is None and ssml_text is None:
        raise AssertionError("Both `text` and `ssml_text` are empty")

    ssml = ssml_text is not None
    input_text = ssml_text if ssml else text
    speaker_ids = self.get_speakers(speaker, voice_path)
    (sentences, clean_sentences, break_lens,
     prosody_rates, prosody_pitches, sp_ids) = self.prepare_tts_model_input(input_text,
                                                                            ssml=ssml,
                                                                            speaker_ids=speaker_ids)

    with torch.no_grad():
        try:
            out, _out_lens = self.model(sentences=sentences,
                                        clean_sentences=clean_sentences,
                                        break_lens=break_lens,
                                        prosody_rates=prosody_rates,
                                        prosody_pitches=prosody_pitches,
                                        speaker_ids=sp_ids,
                                        sr=sample_rate,
                                        device=str(self.device),
                                        put_yo=put_yo,
                                        put_accent=put_accent)
        except RuntimeError as e:
            # Keep the real failure attached instead of discarding it.
            raise Exception("Model couldn't generate your text, probably it's too long") from e
    audio = out.to('cpu')[0]
    return audio
@staticmethod
def write_wave(path, audio, sample_rate):
    """Write ``audio`` to ``path`` as a mono, 16-bit PCM .wav file.

    Takes the output path, the PCM audio data and the sample rate in Hz.
    """
    # The wave writer is a context manager itself and closes on exit.
    with wave.open(path, 'wb') as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio)
def save_wav(self, text=None,
             ssml_text=None,
             speaker: str = 'xenia',
             audio_path: str = '',
             sample_rate: int = 48000,
             put_accent=True,
             put_yo=True):
    """Synthesize speech with ``apply_tts`` and store it through ``write_wave``.

    The audio is scaled by 32767 and cast to int16 before writing.
    Returns the path written, which is ``'test.wav'`` when ``audio_path`` is empty.
    """
    target_path = audio_path or 'test.wav'
    waveform = self.apply_tts(text=text,
                              ssml_text=ssml_text,
                              speaker=speaker,
                              sample_rate=sample_rate,
                              put_yo=put_yo,
                              put_accent=put_accent)
    pcm = (waveform * 32767).numpy().astype('int16')
    self.write_wave(path=target_path, audio=pcm, sample_rate=sample_rate)
    return target_path
def load_random_voice(self, voice_path=None):
"""Install a voice embedding into the last speaker-embedding row of the model.

With ``voice_path`` left as ``None`` a new embedding is sampled and remembered
in ``self.random_emb``; otherwise the embedding is loaded from ``voice_path``.
"""
if voice_path is None:
# Fresh 2 x emb_dim embedding drawn from a standard normal distribution.
random_emb = torch.randn(2, self.emb_dim, requires_grad=False).to(self.device)
self.random_emb = random_emb
print("Generated new voice")
else:
# NOTE(review): torch.load unpickles the file -- only load voice files from trusted sources.
random_emb = torch.load(voice_path, map_location=self.device)
print(f"Loaded voice from {voice_path}")
# NOTE(review): presumably this early return belongs to the `else` branch
# (skip re-applying an embedding equal to the stored one); at function level
# it would also skip every freshly generated voice -- confirm against the original file.
# NOTE(review): a loaded embedding is not stored in self.random_emb here -- confirm intended.
if self.random_emb is not None and torch.equal(self.random_emb, random_emb):
return
# Row 0 feeds the tacotron speaker embedding; row 1 is split in half between
# the duration predictor and the pitch predictor.
mel_weight = random_emb[0]
dur_weight = random_emb[1, :self.emb_dim//2]
p_weight = random_emb[1, self.emb_dim//2:]
self.model.tts_model.tacotron.speaker_embedding.weight[-1] = mel_weight
self.model.tts_model.dur_predictor.dur_pred.speaker_embedding.weight[-1] = dur_weight
self.model.tts_model.pitch_predictor.pitch_pred.speaker_embedding.weight[-1] = p_weight
def save_random_voice(self, voice_path):
    """Save the embedding held in ``self.random_emb`` to ``voice_path`` with ``torch.save``.

    :raises AssertionError: if ``self.random_emb`` is ``None``.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if self.random_emb is None:
        raise AssertionError("No generated random voice")
    torch.save(self.random_emb, voice_path)
    print(f"Saved generated voice to {voice_path}")

View File

@ -1,7 +1,14 @@
discord-py==2.3.2
loguru==0.7.2
peewee==3.17.5
PyNaCl==1.5.0
git+https://github.com/aaiyer/rfoo.git@1555bd4eed204bb6a33a5e313146a6c2813cfe91
# Cython is setup dependency for rfoo
pydantic
aiohttp==3.7.4.post0
async-timeout==3.0.1
attrs==21.4.0
chardet==4.0.0
idna==3.3
multidict==6.0.2
numpy==1.22.3
torch==1.11.0
typing-extensions==4.1.1
yarl==1.7.2
git+https://github.com/Rapptz/discord.py.git#egg=discord-py
loguru~=0.6.0
peewee~=3.14.10
PyNaCl

View File

@ -1,4 +1,10 @@
import winsound
import io
def save_bytes(filename: str, bytes_audio: bytes) -> None:
    """Write ``bytes_audio`` to ``filename`` in binary mode, replacing any existing file."""
    with open(filename, 'wb') as output:
        output.write(bytes_audio)
def play_bytes(bytes_sound: bytes) -> None: