Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
3a9989d0de | |||
bb28efe437 |
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,5 +0,0 @@
|
||||
.idea
|
||||
*__pycache__
|
||||
*.sqlite
|
||||
*.log
|
||||
model_multi.pt
|
12
CLISynth.py
Normal file
12
CLISynth.py
Normal file
@ -0,0 +1,12 @@
|
||||
import TTSSilero
|
||||
from time import time
|
||||
import utils
|
||||
|
||||
tts = TTSSilero.TTSSilero()
|
||||
|
||||
# while msg := input('$ '):
|
||||
start = time()
|
||||
audio = tts.synthesize_text(
|
||||
"""Миша, давай заведем старый жигуль ... ви ви ви ви ви ви ви ви ви ви ви ви ви ви... ви ви ви ви ви ви ви ви ви ви... ви ви пр пр пр ви ви ви пр пр пр пр пр ви ви ви ви... миша, не сиди, помоги дотолкать до гаража, может там придётся перебирать её""")
|
||||
print('synthesize took ', str(time() - start))
|
||||
# utils.play_bytes(audio)
|
4
DB.py
4
DB.py
@ -2,9 +2,7 @@ from datetime import datetime
|
||||
|
||||
import peewee
|
||||
|
||||
import config
|
||||
|
||||
database = peewee.SqliteDatabase(str(config.DB_PATH))
|
||||
database = peewee.SqliteDatabase('voice_cache.sqlite')
|
||||
|
||||
|
||||
class BaseModel(peewee.Model):
|
||||
|
33
Dockerfile
33
Dockerfile
@ -1,33 +0,0 @@
|
||||
FROM python:3.10-slim as builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
|
||||
RUN apt update && apt install git gcc libc6-dev -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/*
|
||||
COPY requirements.txt .
|
||||
RUN --mount=type=cache,target=/root/.cache/pip pip install Cython && \
|
||||
pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt
|
||||
|
||||
|
||||
|
||||
FROM python:3.10-slim
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE 1
|
||||
ENV PYTHONUNBUFFERED 1
|
||||
|
||||
RUN useradd -ms /bin/bash silero_user && \
|
||||
apt update && apt install ffmpeg -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /app/wheels /wheels
|
||||
COPY --from=builder /app/requirements.txt .
|
||||
|
||||
RUN pip install --no-cache /wheels/*
|
||||
|
||||
COPY . .
|
||||
USER silero_user
|
||||
|
||||
CMD ["python3", "/app/main.py"]
|
@ -1,11 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import discord
|
||||
import peewee
|
||||
from discord.ext import commands
|
||||
|
||||
import peewee
|
||||
import DB
|
||||
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
|
||||
|
@ -1,10 +1,8 @@
|
||||
import io
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
import discord
|
||||
import shlex
|
||||
import io
|
||||
from discord.opus import Encoder
|
||||
|
||||
import discord
|
||||
|
||||
# Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import discord
|
||||
|
||||
|
||||
class Observer:
|
||||
@abstractmethod
|
||||
async def update(self, message: discord.Message) -> None:
|
||||
def update(self, message: discord.Message) -> None:
|
||||
raise NotImplemented
|
||||
|
@ -1,5 +1,4 @@
|
||||
import discord
|
||||
|
||||
from .Observer import Observer
|
||||
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import DB
|
||||
from TTSSilero import Speakers
|
||||
|
||||
|
||||
class SpeakersSettingsAdapterDiscord:
|
||||
def get_speaker(self, guild_id: int, user_id: int) -> str:
|
||||
DEFAULT_SPEAKER = Speakers.kseniya
|
||||
|
||||
def get_speaker(self, guild_id: int, user_id: int) -> Speakers:
|
||||
user_defined_speaker = self.get_speaker_user(guild_id, user_id)
|
||||
if user_defined_speaker is None:
|
||||
return self.get_speaker_global(guild_id)
|
||||
@ -11,29 +14,36 @@ class SpeakersSettingsAdapterDiscord:
|
||||
else:
|
||||
return user_defined_speaker
|
||||
|
||||
def get_speaker_global(self, guild_id: int) -> str | None:
|
||||
def get_speaker_global(self, guild_id: int) -> Speakers:
|
||||
server_speaker_query = DB.ServerSpeaker.select()\
|
||||
.where(DB.ServerSpeaker.server_id == guild_id)
|
||||
|
||||
if server_speaker_query.count() == 1:
|
||||
return server_speaker_query.get().speaker
|
||||
return Speakers(server_speaker_query.get().speaker)
|
||||
|
||||
else:
|
||||
return None
|
||||
return self.DEFAULT_SPEAKER
|
||||
|
||||
def get_speaker_user(self, guild_id: int, user_id: int) -> str | None:
|
||||
def get_speaker_user(self, guild_id: int, user_id: int) -> Speakers | None:
|
||||
user_speaker_query = DB.UserServerSpeaker.select()\
|
||||
.where(DB.UserServerSpeaker.server_id == guild_id)\
|
||||
.where(DB.UserServerSpeaker.user_id == user_id)
|
||||
|
||||
if user_speaker_query.count() == 1:
|
||||
return user_speaker_query.get().speaker
|
||||
return Speakers(user_speaker_query.get().speaker)
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_speaker_user(self, guild_id: int, user_id: int, speaker: str) -> None:
|
||||
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker).execute()
|
||||
@property
|
||||
def available_speakers(self) -> set[str]:
|
||||
return {speaker.name for speaker in Speakers}
|
||||
|
||||
def set_speaker_global(self, guild_id: int, user_id: int, speaker: str) -> None:
|
||||
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker).execute()
|
||||
def set_speaker_user(self, guild_id: int, user_id: int, speaker: Speakers) -> None:
|
||||
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker.value).execute()
|
||||
|
||||
def set_speaker_global(self, guild_id: int, speaker: Speakers) -> None:
|
||||
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker.value).execute()
|
||||
|
||||
|
||||
speakers_settings_adapter = SpeakersSettingsAdapterDiscord()
|
||||
|
@ -1,15 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Argument(BaseModel):
|
||||
type: Literal['str', 'int', 'float']
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class ModelDescription(BaseModel):
|
||||
engine: str
|
||||
name: str
|
||||
arguments: dict[str, Argument]
|
||||
description: None | str = None
|
@ -1,19 +0,0 @@
|
||||
import aiohttp
|
||||
|
||||
import config
|
||||
from .Speakers import ModelDescription
|
||||
|
||||
|
||||
class TTSServer:
|
||||
def __init__(self):
|
||||
self.session = aiohttp.ClientSession(base_url=config.BASE_URL)
|
||||
|
||||
async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
|
||||
async with self.session.post(url=f'/synth/{speaker.engine}/model/{speaker.name}', json={'text': text}) as req:
|
||||
req.raise_for_status()
|
||||
return await req.content.read()
|
||||
|
||||
async def discovery(self) -> list[ModelDescription]:
|
||||
async with self.session.get('/discovery') as req:
|
||||
req.raise_for_status()
|
||||
return [ModelDescription(**item) for item in await req.json()]
|
@ -1,53 +0,0 @@
|
||||
from loguru import logger
|
||||
|
||||
import DB
|
||||
from .Speakers import ModelDescription
|
||||
from .TTSServer import TTSServer
|
||||
|
||||
|
||||
class TTSServerCached(TTSServer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.models_lookup_idx: dict[str, ModelDescription] = dict()
|
||||
|
||||
async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
|
||||
cache_query = DB.SoundCache.select() \
|
||||
.where(DB.SoundCache.text == text) \
|
||||
.where(DB.SoundCache.speaker == speaker)
|
||||
|
||||
if cache_query.count() == 1:
|
||||
with DB.database.atomic():
|
||||
DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1}) \
|
||||
.where(DB.SoundCache.text == text) \
|
||||
.where(DB.SoundCache.speaker == speaker).execute()
|
||||
|
||||
cached = cache_query.get().audio
|
||||
|
||||
return cached
|
||||
|
||||
else:
|
||||
synthesized = await super().synthesize_text(text, speaker)
|
||||
DB.SoundCache.create(text=text, speaker=speaker, audio=synthesized)
|
||||
return synthesized
|
||||
|
||||
async def discovery(self) -> list[ModelDescription]:
|
||||
res = await super().discovery()
|
||||
logger.debug(f'Discovered {len(res)} models')
|
||||
self.models_lookup_idx: dict[str, ModelDescription] = {f'{desc.engine}_{desc.name}': desc for desc in res}
|
||||
return res
|
||||
|
||||
async def speaker_to_description(self, speaker_str: str) -> ModelDescription | None:
|
||||
for attempt in range(2):
|
||||
if len(self.models_lookup_idx) > 0:
|
||||
break
|
||||
|
||||
await self.discovery()
|
||||
|
||||
else:
|
||||
raise RuntimeError('Models discovery seems to return zero models')
|
||||
|
||||
# Default to first model
|
||||
return self.models_lookup_idx.get(speaker_str, None)
|
||||
|
||||
def default_speaker(self) -> ModelDescription:
|
||||
return list(self.models_lookup_idx.values())[0]
|
@ -1,3 +0,0 @@
|
||||
from .Speakers import ModelDescription
|
||||
from .TTSServer import TTSServer
|
||||
from .TTSServerCached import TTSServerCached
|
15
TTSSilero/Speakers.py
Normal file
15
TTSSilero/Speakers.py
Normal file
@ -0,0 +1,15 @@
|
||||
import enum
|
||||
|
||||
|
||||
class Speakers(enum.Enum):
|
||||
aidar = 'aidar'
|
||||
baya = 'baya'
|
||||
kseniya = 'kseniya'
|
||||
irina = 'irina'
|
||||
ruslan = 'ruslan'
|
||||
natasha = 'natasha'
|
||||
thorsten = 'thorsten'
|
||||
tux = 'tux'
|
||||
gilles = 'gilles'
|
||||
lj = 'lj'
|
||||
dilyara = 'dilyara'
|
62
TTSSilero/TTSSilero.py
Normal file
62
TTSSilero/TTSSilero.py
Normal file
@ -0,0 +1,62 @@
|
||||
import os
|
||||
import io
|
||||
import torch.package
|
||||
|
||||
from .Speakers import Speakers
|
||||
from multi_acc_v3_package import TTSModelMultiAcc_v3
|
||||
|
||||
|
||||
class TTSSilero:
|
||||
def __init__(self, threads: int = 24):
|
||||
device = torch.device('cpu')
|
||||
torch.set_num_threads(threads)
|
||||
"""
|
||||
local_file = 'model_multi.pt'
|
||||
|
||||
if not os.path.isfile(local_file):
|
||||
torch.hub.download_url_to_file(
|
||||
'https://models.silero.ai/models/tts/multi/v2_multi.pt',
|
||||
local_file
|
||||
)
|
||||
|
||||
self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
|
||||
self.model.to(device)
|
||||
"""
|
||||
local_file = 'model.pt'
|
||||
|
||||
if not os.path.isfile(local_file):
|
||||
torch.hub.download_url_to_file('https://models.silero.ai/models/tts/ru/ru_v3.pt', local_file)
|
||||
|
||||
self.model: TTSModelMultiAcc_v3 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
|
||||
self.model.to(device)
|
||||
# print(self.model.speakers)
|
||||
|
||||
self.sample_rate = 48000
|
||||
|
||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
||||
return self.to_wav(self._synthesize_text(text, speaker))
|
||||
|
||||
def _synthesize_text(self, text: str, speaker: Speakers) -> torch.Tensor:
|
||||
"""
|
||||
Performs splitting text and synthesizing it
|
||||
|
||||
:param text:
|
||||
:return:
|
||||
"""
|
||||
|
||||
result: torch.Tensor = self.model.apply_tts(
|
||||
text=text,
|
||||
speaker=speaker.value,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def to_wav(self, synthesized_text: torch.Tensor) -> bytes:
|
||||
res_io_stream = io.BytesIO()
|
||||
self.model.write_wave(res_io_stream, (synthesized_text * 32767).numpy().astype('int16'), self.sample_rate)
|
||||
res_io_stream.seek(0)
|
||||
|
||||
return res_io_stream.read()
|
||||
|
||||
|
36
TTSSilero/TTSSileroCached.py
Normal file
36
TTSSilero/TTSSileroCached.py
Normal file
@ -0,0 +1,36 @@
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
from .TTSSilero import TTSSilero
|
||||
from .Speakers import Speakers
|
||||
import DB
|
||||
import sqlite3
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TTSSileroCached(TTSSilero):
|
||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
||||
# start = time.time()
|
||||
cache_query = DB.SoundCache.select()\
|
||||
.where(DB.SoundCache.text == text)\
|
||||
.where(DB.SoundCache.speaker == speaker.value)
|
||||
|
||||
if cache_query.count() == 1:
|
||||
with DB.database.atomic():
|
||||
DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1})\
|
||||
.where(DB.SoundCache.text == text)\
|
||||
.where(DB.SoundCache.speaker == speaker.value).execute()
|
||||
|
||||
cached = cache_query.get().audio
|
||||
|
||||
return cached
|
||||
|
||||
else:
|
||||
# logger.debug(f'Starting synthesis')
|
||||
# start2 = time.time()
|
||||
synthesized = super().synthesize_text(text, speaker)
|
||||
# logger.debug(f'Synthesis done in {time.time() - start2} s in {time.time() - start} s after start')
|
||||
DB.SoundCache.create(text=text, speaker=speaker.value, audio=synthesized)
|
||||
|
||||
# logger.debug(f'Cache set in {time.time() - start2} synth end and {time.time() - start2} s after start')
|
||||
return synthesized
|
3
TTSSilero/__init__.py
Normal file
3
TTSSilero/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .TTSSilero import TTSSilero
|
||||
from .Speakers import Speakers
|
||||
from .TTSSileroCached import TTSSileroCached
|
148
TTSSilero/multi_v2_package.py
Normal file
148
TTSSilero/multi_v2_package.py
Normal file
@ -0,0 +1,148 @@
|
||||
import re
|
||||
import wave
|
||||
import torch
|
||||
import warnings
|
||||
import contextlib
|
||||
|
||||
# for type hints only
|
||||
|
||||
|
||||
class TTSModelMulti_v2():
|
||||
def __init__(self, model_path, symbols):
|
||||
self.model = self.init_jit_model(model_path)
|
||||
self.symbols = symbols
|
||||
self.device = torch.device('cpu')
|
||||
speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
|
||||
'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
|
||||
self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}
|
||||
|
||||
def init_jit_model(self, model_path: str):
|
||||
torch.set_grad_enabled(False)
|
||||
model = torch.jit.load(model_path, map_location='cpu')
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def prepare_text_input(self, text, symbols, symbol_to_id=None):
|
||||
if len(text) > 140:
|
||||
warnings.warn('Text string is longer than 140 symbols.')
|
||||
|
||||
if symbol_to_id is None:
|
||||
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
|
||||
text = text.lower()
|
||||
text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
if text[-1] not in ['.', '!', '?']:
|
||||
text = text + '.'
|
||||
text = text + symbols[1]
|
||||
|
||||
text_ohe = [symbol_to_id[s] for s in text if s in symbols]
|
||||
text_tensor = torch.LongTensor(text_ohe)
|
||||
return text_tensor
|
||||
|
||||
def prepare_tts_model_input(self, text: str or list, symbols: str, speakers: list):
|
||||
assert len(speakers) == len(text) or len(speakers) == 1
|
||||
if type(text) == str:
|
||||
text = [text]
|
||||
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
if len(text) == 1:
|
||||
return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])
|
||||
|
||||
text_tensors = []
|
||||
for string in text:
|
||||
string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
|
||||
text_tensors.append(string_tensor)
|
||||
input_lengths, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([len(t) for t in text_tensors]),
|
||||
dim=0, descending=True)
|
||||
max_input_len = input_lengths[0]
|
||||
batch_size = len(text_tensors)
|
||||
|
||||
text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
|
||||
if len(speakers) == 1:
|
||||
speakers = speakers*batch_size
|
||||
speaker_ids = torch.LongTensor(batch_size).zero_()
|
||||
|
||||
for i, idx in enumerate(ids_sorted_decreasing):
|
||||
text_tensor = text_tensors[idx]
|
||||
in_len = text_tensor.size(0)
|
||||
text_padded[i, :in_len] = text_tensor
|
||||
speaker_ids[i] = speakers[idx]
|
||||
|
||||
return text_padded, speaker_ids, ids_sorted_decreasing
|
||||
|
||||
def process_tts_model_output(self, out, out_lens, ids):
|
||||
out = out.to('cpu')
|
||||
out_lens = out_lens.to('cpu')
|
||||
_, orig_ids = ids.sort()
|
||||
|
||||
proc_outs = []
|
||||
orig_out = out.index_select(0, orig_ids)
|
||||
orig_out_lens = out_lens.index_select(0, orig_ids)
|
||||
|
||||
for i, out_len in enumerate(orig_out_lens):
|
||||
proc_outs.append(orig_out[i][:out_len])
|
||||
return proc_outs
|
||||
|
||||
def to(self, device):
|
||||
self.model = self.model.to(device)
|
||||
self.device = device
|
||||
|
||||
def get_speakers(self, speakers: str or list):
|
||||
if type(speakers) == str:
|
||||
speakers = [speakers]
|
||||
speaker_ids = []
|
||||
for speaker in speakers:
|
||||
try:
|
||||
speaker_id = self.speaker_to_id[speaker]
|
||||
speaker_ids.append(speaker_id)
|
||||
except Exception:
|
||||
raise ValueError(f'No such speaker: {speaker}')
|
||||
return speaker_ids
|
||||
|
||||
def apply_tts(self, texts: str or list,
|
||||
speakers: str or list,
|
||||
sample_rate: int = 16000):
|
||||
speaker_ids = self.get_speakers(speakers)
|
||||
text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
|
||||
symbols=self.symbols,
|
||||
speakers=speaker_ids)
|
||||
with torch.inference_mode():
|
||||
out, out_lens = self.model(text_padded.to(self.device),
|
||||
speaker_ids.to(self.device),
|
||||
sr=sample_rate)
|
||||
audios = self.process_tts_model_output(out, out_lens, orig_ids)
|
||||
return audios
|
||||
|
||||
@staticmethod
|
||||
def write_wave(path, audio, sample_rate):
|
||||
"""Writes a .wav file.
|
||||
Takes path, PCM audio data, and sample rate.
|
||||
"""
|
||||
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio)
|
||||
|
||||
def save_wav(self, texts: str or list,
|
||||
speakers: str or list,
|
||||
audio_pathes: str or list = '',
|
||||
sample_rate: int = 16000):
|
||||
if type(texts) == str:
|
||||
texts = [texts]
|
||||
|
||||
if not audio_pathes:
|
||||
audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
|
||||
if type(audio_pathes) == str:
|
||||
audio_pathes = [audio_pathes]
|
||||
assert len(audio_pathes) == len(texts)
|
||||
|
||||
audio = self.apply_tts(texts=texts,
|
||||
speakers=speakers,
|
||||
sample_rate=sample_rate)
|
||||
for i, _audio in enumerate(audio):
|
||||
self.write_wave(path=audio_pathes[i],
|
||||
audio=(_audio * 32767).numpy().astype('int16'),
|
||||
sample_rate=sample_rate)
|
||||
return audio_pathes
|
@ -1,6 +1,6 @@
|
||||
from loguru import logger
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class cogErrorHandlers:
|
||||
|
@ -1,10 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import time
|
||||
|
||||
import discord
|
||||
from discord.ext import commands, tasks
|
||||
from discord.ext.commands import Context
|
||||
import datetime
|
||||
import time
|
||||
|
||||
|
||||
class BotInformation(commands.Cog):
|
||||
@ -16,12 +15,12 @@ class BotInformation(commands.Cog):
|
||||
@commands.command('info')
|
||||
async def get_info(self, ctx: Context):
|
||||
info = f"""
|
||||
Text-To-Speech bot
|
||||
Text-To-Speech bot, based on Silero TTS model (<https://github.com/snakers4/silero-models>)
|
||||
License: `GNU GENERAL PUBLIC LICENSE Version 3`
|
||||
Author discord: `@a31`
|
||||
Author email: `a31@demb.uk`
|
||||
Author discord: `a31#6403`
|
||||
Author email: `a31@demb.design`
|
||||
Source code on github: <https://github.com/norohind/SileroTTSBot>
|
||||
Source code on gitea a31's instance: <https://gitea.demb.uk/a31/SileroTTSBot>
|
||||
Source code on gitea's a31 instance: <https://gitea.demb.design/a31/SileroTTSBot>
|
||||
Invite link: https://discord.com/oauth2/authorize?client_id={self.bot.user.id}&scope=bot%20applications.commands
|
||||
"""
|
||||
|
||||
|
@ -1,140 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import time
|
||||
try:
|
||||
import rfoo
|
||||
from rfoo.utils import rconsole
|
||||
|
||||
except ImportError as e:
|
||||
print("not importing rfoo", e)
|
||||
|
||||
from threading import Thread
|
||||
from typing import Optional, Coroutine
|
||||
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
from loguru import logger
|
||||
|
||||
from formatting import format_table
|
||||
|
||||
|
||||
class SingletonBase:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if getattr(cls, 'instance') is None:
|
||||
cls.instance = super().__new__(cls, *args, **kwargs)
|
||||
|
||||
# logger.debug(id(getattr(cls, 'instance')))
|
||||
return getattr(cls, 'instance')
|
||||
|
||||
|
||||
class BotManagement(commands.Cog, SingletonBase):
|
||||
instance: Optional['BotManagement'] = None
|
||||
|
||||
rfoo_server_thread: Optional[Thread] = None
|
||||
rfoo_server: Optional['rfoo.InetServer'] = None
|
||||
|
||||
# def __new__(cls, *args, **kwargs):
|
||||
# if cls.instance is None:
|
||||
# cls.instance = super().__new__(cls, *args, **kwargs)
|
||||
#
|
||||
# return cls.instance
|
||||
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
|
||||
@commands.is_owner()
|
||||
@commands.command('listServers')
|
||||
async def list_servers(self, ctx: Context):
|
||||
text_table = format_table(((server.name, str(server.id)) for server in self.bot.guilds), ('name', 'id'))
|
||||
await ctx.channel.send(text_table)
|
||||
|
||||
@commands.is_owner()
|
||||
@commands.command('listVoice')
|
||||
async def list_voice_connections(self, ctx: Context):
|
||||
text_table = format_table(((it.guild.name,) for it in self.bot.voice_clients))
|
||||
await ctx.channel.send(text_table)
|
||||
|
||||
def start_rfoo(self) -> bool:
|
||||
# True if started, False if already started
|
||||
if self.rfoo_server_thread is None:
|
||||
self.rfoo_server = rfoo.InetServer(rconsole.ConsoleHandler, {'bot': self.bot, 'ct': self.ct})
|
||||
self.rfoo_server_thread = Thread(target=lambda: self.rfoo_server.start(rfoo.LOOPBACK, 54321))
|
||||
self.rfoo_server_thread.daemon = True
|
||||
self.rfoo_server_thread.start()
|
||||
logger.info('Rfoo thread started by msg')
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def stop_rfoo(self) -> bool:
|
||||
if self.rfoo_server_thread is not None:
|
||||
self.rfoo_server.stop()
|
||||
del self.rfoo_server_thread
|
||||
logger.info('Rfoo thread stopped by msg')
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@commands.is_owner()
|
||||
@commands.command('rfooStart')
|
||||
async def start(self, ctx: Context):
|
||||
if self.start_rfoo():
|
||||
await ctx.send('Rfoo thread started')
|
||||
|
||||
else:
|
||||
await ctx.send('Rfoo thread already started')
|
||||
|
||||
@commands.is_owner()
|
||||
@commands.command('rfooStop')
|
||||
async def stop(self, ctx: Context):
|
||||
if self.stop_rfoo():
|
||||
await ctx.send('Rfoo server stopped')
|
||||
|
||||
else:
|
||||
await ctx.send('Rfoo server already stopped')
|
||||
|
||||
def ct(self, coro: Coroutine):
|
||||
"""
|
||||
ct - short from create_task
|
||||
execute coroutine and get result
|
||||
"""
|
||||
|
||||
task = self.bot.loop.create_task(coro)
|
||||
while not task.done():
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
return task.result()
|
||||
|
||||
except asyncio.exceptions.InvalidStateError:
|
||||
return task.exception()
|
||||
|
||||
@commands.command('shutdown')
|
||||
async def shutdown(self, ctx: Context) -> None:
|
||||
"""Shutdown bot with hope docker daemon will restart it, `@a31` and `@furiz__` has rights for it"""
|
||||
if ctx.author.id in (420130693696323585, # @a31
|
||||
444819880781545472): # @furiz__
|
||||
|
||||
log_msg = f"Got shutdown command by {ctx.author}"
|
||||
logger.info(log_msg)
|
||||
await ctx.reply('Shutting down')
|
||||
|
||||
try:
|
||||
await (await self.bot.fetch_user(420130693696323585)).send(log_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(f'Failed to send shutdown message', )
|
||||
|
||||
self.bot.loop.create_task(self.bot.close())
|
||||
|
||||
else:
|
||||
await ctx.reply('No rights for you')
|
||||
|
||||
|
||||
async def setup(bot: commands.Bot):
|
||||
await bot.add_cog(BotManagement(bot))
|
||||
|
||||
async def teardown(bot):
|
||||
stop_res = BotManagement(bot).stop_rfoo()
|
||||
logger.info(f'Unloaded rfoo with result {stop_res} during BotManagement unload')
|
||||
|
204
cogs/TTSCore.py
204
cogs/TTSCore.py
@ -2,21 +2,18 @@
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Union, Optional
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
from loguru import logger
|
||||
|
||||
import discord
|
||||
import DB
|
||||
import Observ
|
||||
import formatting
|
||||
from typing import Union
|
||||
from loguru import logger
|
||||
from TTSSilero import TTSSileroCached
|
||||
from TTSSilero import Speakers
|
||||
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
||||
from SpeakersSettingsAdapterDiscord import SpeakersSettingsAdapterDiscord
|
||||
from TTSServer import TTSServerCached
|
||||
from TTSServer.Speakers import ModelDescription
|
||||
import Observ
|
||||
from cogErrorHandlers import cogErrorHandlers
|
||||
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
|
||||
|
||||
|
||||
class TTSCore(commands.Cog, Observ.Observer):
|
||||
@ -24,33 +21,12 @@ class TTSCore(commands.Cog, Observ.Observer):
|
||||
self.bot = bot
|
||||
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
||||
self.bot.subscribe(self) # subscribe for messages that aren't commands
|
||||
self.tts = TTSServerCached()
|
||||
self.tts = TTSSileroCached()
|
||||
self.tts_queues: dict[int, list[discord.AudioSource]] = defaultdict(list)
|
||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = SpeakersSettingsAdapterDiscord()
|
||||
|
||||
@commands.command('drop')
|
||||
async def drop_queue(self, ctx: Context):
|
||||
"""
|
||||
Drop tts queue for current server
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
del self.tts_queues[ctx.guild.id]
|
||||
await ctx.send('Queue dropped')
|
||||
|
||||
except KeyError:
|
||||
await ctx.send('Failed on dropping queue')
|
||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
|
||||
|
||||
@commands.command('exit')
|
||||
async def leave_voice(self, ctx: Context):
|
||||
"""
|
||||
Disconnect bot from current channel, it also drop messages queue
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
if ctx.guild.voice_client is not None:
|
||||
await ctx.guild.voice_client.disconnect(force=False)
|
||||
await ctx.channel.send(f"Left voice channel")
|
||||
@ -88,67 +64,43 @@ class TTSCore(commands.Cog, Observ.Observer):
|
||||
if voice_client is None:
|
||||
voice_client: discord.VoiceClient = await user_voice_state.channel.connect()
|
||||
|
||||
if user_voice_state.channel.id != voice_client.channel.id:
|
||||
await message.channel.send('We are in different voice channels')
|
||||
return
|
||||
|
||||
speaker_str: str = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
|
||||
speaker: ModelDescription | None = await self.tts.speaker_to_description(speaker_str)
|
||||
if speaker is None:
|
||||
speaker = self.tts.default_speaker()
|
||||
speaker: Speakers = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
|
||||
|
||||
# check if message will fail on synthesis
|
||||
if DB.SynthesisErrors.select() \
|
||||
.where(DB.SynthesisErrors.speaker == speaker_str) \
|
||||
.where(DB.SynthesisErrors.text == message.content) \
|
||||
if DB.SynthesisErrors.select()\
|
||||
.where(DB.SynthesisErrors.speaker == speaker.value)\
|
||||
.where(DB.SynthesisErrors.text == message.content)\
|
||||
.count() == 1:
|
||||
# Then we will not try to synthesis it
|
||||
await message.channel.send(f"I will not synthesis this message due to TTS engine limitations")
|
||||
return
|
||||
|
||||
try:
|
||||
wav_file_like_object: bytes = await self.tts.synthesize_text(message.content, speaker=speaker)
|
||||
wav_file_like_object = self.tts.synthesize_text(message.content, speaker=speaker)
|
||||
sound_source = FFmpegPCMAudio(wav_file_like_object, pipe=True, stderr=subprocess.PIPE)
|
||||
if voice_client.is_playing():
|
||||
# Then we need to enqueue prepared sound for playing through self.tts_queues mechanism
|
||||
|
||||
self.tts_queues[message.guild.id].append(sound_source)
|
||||
await message.channel.send(f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}")
|
||||
return
|
||||
|
||||
voice_client.play(sound_source, after=lambda e: self.queue_player(message))
|
||||
|
||||
except Exception as synth_exception:
|
||||
logger.opt(exception=True).warning(f'Exception on synthesize {message.content!r}: {synth_exception}')
|
||||
await message.channel.send(f'Synthesize error')
|
||||
DB.SynthesisErrors.create(speaker=speaker, text=message.content)
|
||||
return
|
||||
|
||||
else:
|
||||
try:
|
||||
if voice_client.is_playing():
|
||||
# Then we need to enqueue prepared sound for playing through self.tts_queues mechanism
|
||||
self.tts_queues[message.guild.id].append(sound_source)
|
||||
await message.channel.send(
|
||||
f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}")
|
||||
return
|
||||
|
||||
voice_client.play(sound_source, after=lambda e: self.queue_player(message))
|
||||
|
||||
except Exception as play_exception:
|
||||
logger.opt(exception=True).warning(
|
||||
f'Exception on playing for: {message.guild.name}[#{message.channel.name}]: {message.author.display_name} / {play_exception}')
|
||||
await message.channel.send(f'Playing error')
|
||||
return
|
||||
await message.channel.send(f'Internal error')
|
||||
DB.SynthesisErrors.create(speaker=speaker.value, text=message.content)
|
||||
|
||||
def queue_player(self, message: discord.Message):
|
||||
voice_client: Union[discord.VoiceClient, None] = message.guild.voice_client
|
||||
if voice_client is None:
|
||||
# don't play anything and clear queue for whole guild
|
||||
del self.tts_queues[message.guild.id]
|
||||
return
|
||||
|
||||
for sound_source in self.tts_queues[message.guild.id]:
|
||||
if len(self.tts_queues[message.guild.id]) == 0:
|
||||
return
|
||||
|
||||
voice_client: Optional[discord.VoiceClient] = message.guild.voice_client
|
||||
if voice_client is None:
|
||||
# don't play anything and clear queue for whole guild
|
||||
break
|
||||
|
||||
try:
|
||||
voice_client.play(sound_source)
|
||||
|
||||
except discord.errors.ClientException: # Here we expect Not connected to voice
|
||||
break
|
||||
|
||||
voice_client.play(sound_source)
|
||||
while voice_client.is_playing():
|
||||
time.sleep(0.1)
|
||||
|
||||
@ -159,105 +111,13 @@ class TTSCore(commands.Cog, Observ.Observer):
|
||||
pass
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState,
|
||||
after: discord.VoiceState):
|
||||
async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState):
|
||||
if after.channel is None:
|
||||
members = before.channel.members
|
||||
if len(members) == 1:
|
||||
if members[0].id == self.bot.user.id:
|
||||
await before.channel.guild.voice_client.disconnect(force=False)
|
||||
|
||||
# TODO: leave voice channel after being moved there alone
|
||||
|
||||
# Ex TTSSettings.py
|
||||
@commands.command('getAllSpeakers')
|
||||
async def get_speakers(self, ctx: Context):
|
||||
"""
|
||||
Enumerate all available to set speakers
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
|
||||
await ctx.send(
|
||||
formatting.format_table(
|
||||
tuple((model.engine + '_' + model.name, model.description)
|
||||
for model in await self.tts.discovery()),
|
||||
('model', 'description')
|
||||
))
|
||||
|
||||
@commands.command('setPersonalSpeaker')
|
||||
async def set_user_speaker(self, ctx: Context, speaker: str):
|
||||
"""
|
||||
Set personal speaker on this server
|
||||
|
||||
:param ctx:
|
||||
:param speaker:
|
||||
:return:
|
||||
"""
|
||||
|
||||
if await self.tts.speaker_to_description(speaker) is None:
|
||||
return await ctx.send(
|
||||
f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||
|
||||
self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, speaker)
|
||||
await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`')
|
||||
|
||||
@commands.command('setServerSpeaker')
|
||||
async def set_server_speaker(self, ctx: Context, speaker: str):
|
||||
"""
|
||||
Set global server speaker
|
||||
|
||||
:param ctx:
|
||||
:param speaker:
|
||||
:return:
|
||||
"""
|
||||
if await self.tts.speaker_to_description(speaker) is None:
|
||||
return await ctx.send(
|
||||
f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||
|
||||
self.speakers_adapter.set_speaker_global(ctx.guild.id, ctx.author.id, speaker)
|
||||
await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`')
|
||||
|
||||
@commands.command('getSpeaker')
|
||||
async def get_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell first appropriate speaker for a user, it can be user specified, server specified or server default
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)
|
||||
|
||||
await ctx.reply(f'Your current speaker is `{speaker}`')
|
||||
|
||||
@commands.command('getPersonalSpeaker')
|
||||
async def get_personal_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
|
||||
if speaker is None:
|
||||
server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
||||
await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")
|
||||
|
||||
else:
|
||||
await ctx.reply(f"Your personal speaker is `{speaker}`")
|
||||
|
||||
@commands.command('getServerSpeaker')
|
||||
async def get_server_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell server global speaker
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
||||
await ctx.send(f"Current server speaker is `{speaker}`")
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(TTSCore(bot))
|
||||
|
102
cogs/TTSSettings.py
Normal file
102
cogs/TTSSettings.py
Normal file
@ -0,0 +1,102 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
from TTSSilero import Speakers
|
||||
from cogErrorHandlers import cogErrorHandlers
|
||||
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
|
||||
|
||||
|
||||
class TTSSettings(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
|
||||
|
||||
@commands.command('getAllSpeakers')
|
||||
async def get_speakers(self, ctx: Context):
|
||||
"""
|
||||
Enumerate all available to set speakers
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speakers = '\n'.join(self.speakers_adapter.available_speakers)
|
||||
|
||||
await ctx.send(f"```\n{speakers}```")
|
||||
|
||||
@commands.command('setPersonalSpeaker')
|
||||
async def set_user_speaker(self, ctx: Context, speaker: str):
|
||||
"""
|
||||
Set personal speaker on this server
|
||||
|
||||
:param ctx:
|
||||
:param speaker:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
checked_speaker: Speakers = Speakers(speaker)
|
||||
self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, checked_speaker)
|
||||
await ctx.reply(f'Successfully set **your personal** speaker to `{checked_speaker.value}`')
|
||||
|
||||
except (KeyError, ValueError):
|
||||
await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||
|
||||
@commands.command('setServerSpeaker')
|
||||
async def set_server_speaker(self, ctx: Context, speaker: str):
|
||||
"""
|
||||
Set global server speaker
|
||||
|
||||
:param ctx:
|
||||
:param speaker:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
checked_speaker: Speakers = Speakers(speaker)
|
||||
self.speakers_adapter.set_speaker_global(ctx.guild.id, checked_speaker)
|
||||
await ctx.send(f'Successfully set **server** speaker to `{checked_speaker.value}`')
|
||||
|
||||
except (KeyError, ValueError):
|
||||
await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||
|
||||
@commands.command('getSpeaker')
|
||||
async def get_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell first appropriate speaker for a user, it can be user specified, server specified or server default
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)
|
||||
|
||||
await ctx.reply(f'Your current speaker is `{speaker.value}`')
|
||||
|
||||
@commands.command('getPersonalSpeaker')
|
||||
async def get_personal_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
|
||||
if speaker is None:
|
||||
server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id).value
|
||||
await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")
|
||||
|
||||
else:
|
||||
await ctx.reply(f"Your personal speaker is `{speaker.value}`")
|
||||
|
||||
@commands.command('getServerSpeaker')
|
||||
async def get_server_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell server global speaker
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
||||
await ctx.send(f"Current server speaker is `{speaker.value}`")
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(TTSSettings(bot))
|
@ -1,9 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
from loguru import logger
|
||||
|
||||
import DB
|
||||
from loguru import logger
|
||||
import DynamicCommandPrefix
|
||||
from cogErrorHandlers import cogErrorHandlers
|
||||
|
||||
|
@ -1,5 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
DB_PATH = Path(os.getenv('DATA_DIR', '.')) / 'voice_cache.sqlite'
|
||||
BASE_URL = os.environ['BASE_URL']
|
@ -1,18 +0,0 @@
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
class MISSING:
|
||||
pass
|
||||
|
||||
|
||||
def format_table(data: Iterable[Iterable[str]], header: Iterable[str] = MISSING) -> str:
|
||||
if header != MISSING:
|
||||
data = (header, *data)
|
||||
|
||||
result = '```\n'
|
||||
for row in data:
|
||||
row = [str(item).replace('`', '\\`') for item in row]
|
||||
result += '\t'.join(row) + '\n'
|
||||
|
||||
result += '```'
|
||||
return result
|
BIN
logo_200x200.png
BIN
logo_200x200.png
Binary file not shown.
Before Width: | Height: | Size: 4.1 KiB |
24
main.py
24
main.py
@ -1,19 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import discord
|
||||
import signal
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
|
||||
import Observ
|
||||
from DynamicCommandPrefix import dynamic_command_prefix
|
||||
import Observ
|
||||
|
||||
LOG_FILE_ENABLED = os.getenv('LOG_FILE_ENABLED', 'true').lower() == 'true'
|
||||
|
||||
if LOG_FILE_ENABLED:
|
||||
logger.add('offlineTTSBot.log', backtrace=True, diagnose=False, rotation='5MB')
|
||||
logger.add('offlineTTSBot.log', backtrace=True, diagnose=False, rotation='5MB')
|
||||
|
||||
"""
|
||||
while msg := input('$ '):
|
||||
@ -53,7 +49,7 @@ class DiscordTTSBot(commands.Bot, Observ.Subject):
|
||||
await self.notify(ctx.message)
|
||||
|
||||
else:
|
||||
logger.opt(exception=exception).warning(f'Global error caught:')
|
||||
raise exception
|
||||
|
||||
|
||||
intents = discord.Intents.default()
|
||||
@ -70,8 +66,6 @@ async def main():
|
||||
|
||||
await discord_client.start(os.environ['DISCORD_TOKEN'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(main())
|
||||
logger.debug('Shutdown completed')
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(main())
|
||||
logger.debug('Shutdown completed')
|
||||
|
381
multi_acc_v3_package.py
Normal file
381
multi_acc_v3_package.py
Normal file
@ -0,0 +1,381 @@
|
||||
import re
|
||||
import wave
|
||||
import torch
|
||||
import warnings
|
||||
import contextlib
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
||||
class TTSModelMultiAcc_v3():
|
||||
def __init__(self, model_path, symbols, speaker_to_id, emb_dim=128):
|
||||
torch.set_grad_enabled(False)
|
||||
self.model = self.init_jit_model(model_path)
|
||||
self.symbols = symbols
|
||||
self.device = torch.device('cpu')
|
||||
self.speaker_to_id = speaker_to_id
|
||||
self.speakers = list(speaker_to_id.keys())
|
||||
assert 'random' in self.speakers
|
||||
self.ru_ascii_dict = {r: asc for r, asc in zip('абвгдеёжзийклмнопрстуфхцчшщъыьэюя–',
|
||||
'abvgde1jzi2klmnoprstufhc4w35y6789=')}
|
||||
# ssml tags
|
||||
self.strength2time = {'x_weak': 25, 'weak': 75, 'medium': 150, 'strong': 300, 'x-strong': 1000}
|
||||
self.rate2value = {'x-slow': 0.5, 'slow': 0.8, 'medium': 1., 'fast': 1.2, 'x-fast': 1.5}
|
||||
self.pitch2value = {'x-low': 0.6, 'low': 0.8, 'medium': 1., 'high': 1.2, 'x-high': 1.4, 'robot': 0.}
|
||||
self.emb_dim = emb_dim
|
||||
|
||||
self.random_emb = None
|
||||
self.debug = False
|
||||
|
||||
self.valid_tags = {'break': {'strength': list(self.strength2time.keys())},
|
||||
'prosody': {'rate': list(self.rate2value.keys()),
|
||||
'pitch': list(self.pitch2value.keys())}}
|
||||
|
||||
def init_jit_model(self, model_path: str):
|
||||
torch.set_grad_enabled(False)
|
||||
model = torch.jit.load(model_path, map_location='cpu')
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def ru_to_ascii(self, sentence):
|
||||
ascii_list = [self.ru_ascii_dict.get(s, s) for s in sentence]
|
||||
ascii_text = ''.join(ascii_list)
|
||||
return ascii_text
|
||||
|
||||
def prepare_text_input(self, text):
|
||||
|
||||
text = text.lower()
|
||||
text = text.replace('—', '–').replace('–', '–').replace('‑', '-')
|
||||
text = re.sub(r'[^{}]'.format(self.symbols[3:]), '', text)
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
sentence = self.ru_to_ascii(text)
|
||||
|
||||
clean_sentence = re.sub(r'[^a-z1-9\- ]', '', sentence)
|
||||
has_text = len(clean_sentence.replace(' ', '')) > 0
|
||||
return sentence, clean_sentence, has_text
|
||||
|
||||
def prepare_tts_model_input(self, text: str,
|
||||
ssml: bool,
|
||||
speaker_ids: list):
|
||||
if ssml:
|
||||
clean_text_list = self.process_ssml(text)
|
||||
else:
|
||||
clean_text_list = self.process_simple_text(text)
|
||||
|
||||
sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches = map(list, zip(*map(dict.values, clean_text_list)))
|
||||
full_text_len = sum([len(s) for s in sentences])
|
||||
if full_text_len > 1000:
|
||||
warnings.warn('Text string is longer than 1000 symbols.')
|
||||
|
||||
speaker_ids = torch.LongTensor(speaker_ids)
|
||||
return sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches, speaker_ids
|
||||
|
||||
def to(self, device):
|
||||
self.model.tts_model = self.model.tts_model.to(device)
|
||||
self.device = device
|
||||
|
||||
def get_speakers(self, speaker: str, voice_path=None):
|
||||
try:
|
||||
if speaker == 'random':
|
||||
self.load_random_voice(voice_path)
|
||||
speaker_id = self.speaker_to_id.get(speaker, None)
|
||||
if speaker_id is None:
|
||||
raise ValueError(f"`speaker` should be in {', '.join(self.speakers)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f'Failed to load speaker: {speaker}, error: {e}')
|
||||
return [speaker_id]
|
||||
|
||||
def process_simple_text(self, text):
|
||||
sentence, clean_sentence, has_text = self.prepare_text_input(text)
|
||||
if not has_text:
|
||||
raise ValueError
|
||||
simple_text_dict = [{'text': sentence,
|
||||
'clean_text': clean_sentence,
|
||||
'break_time': None,
|
||||
'prosody_rate': 1.,
|
||||
'prosody_pitch': 1.}]
|
||||
return simple_text_dict
|
||||
|
||||
def process_ssml(self, ssml_text):
|
||||
ssml_text = re.sub(r'\s+', ' ', ssml_text).strip().replace('\n ', '\n')
|
||||
try:
|
||||
root = ET.fromstring(ssml_text)
|
||||
except Exception:
|
||||
raise ValueError("Invalid XML format")
|
||||
assert root.tag == 'speak', "Invalid SSML format: <speak> tag is essential"
|
||||
try:
|
||||
ssml_parsed = self.process_ssml_element(root)
|
||||
if self.debug:
|
||||
print(ssml_parsed)
|
||||
except AssertionError as ae:
|
||||
raise ae
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse SSML: {e}")
|
||||
try:
|
||||
clean_text_list = self.process_ssml_tag_dict(ssml_parsed)
|
||||
if self.debug:
|
||||
print(clean_text_list)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to process SSML: {e}")
|
||||
return clean_text_list
|
||||
|
||||
def process_ssml_tag_dict(self, text_break_list):
|
||||
proc_text_break_list = []
|
||||
for i, text_break_prosody in enumerate(text_break_list):
|
||||
tbreak = text_break_prosody['break']
|
||||
tprosody = text_break_prosody['prosody']
|
||||
text, clean_text, has_text = self.prepare_text_input(text_break_prosody['text'])
|
||||
break_time = int(tbreak['time']/12.5) if tbreak['time'] is not None else None
|
||||
if has_text or i == 0:
|
||||
text = self.check_text_break(text, tbreak)
|
||||
proc_text_break_list.append({'text': text,
|
||||
'clean_text': clean_text,
|
||||
'break_time': break_time,
|
||||
'prosody_rate': tprosody['rate'],
|
||||
'prosody_pitch': tprosody['pitch']})
|
||||
elif tbreak['strength'] is not None and len(proc_text_break_list) > 0:
|
||||
text = self.check_text_break(proc_text_break_list[-1]['text'], tbreak)
|
||||
proc_text_break_list[-1]['text'] = text
|
||||
if proc_text_break_list[-1]['break_time'] is None:
|
||||
proc_text_break_list[-1]['break_time'] = break_time
|
||||
else:
|
||||
proc_text_break_list[-1]['break_time'] = max(break_time, proc_text_break_list[-1]['break_time'])
|
||||
return proc_text_break_list
|
||||
|
||||
def check_text_break(self, text, tbreak):
|
||||
if len(text) == 0 or tbreak['strength'] is not None and text[-1] not in '!,-.:;?–…': # TODO fx dash
|
||||
text = text + '.'
|
||||
return text
|
||||
|
||||
def process_ssml_element(self, element, def_strength='strong', def_rate=1., def_pitch=1.):
|
||||
parsed = []
|
||||
last_tag = None
|
||||
head_text_parsed = self.process_head_tail_text(element.text, def_rate, def_pitch)
|
||||
parsed.extend(head_text_parsed)
|
||||
|
||||
for child in element:
|
||||
if child.tag == 'break':
|
||||
break_strength, break_ts = self.process_break_attrib(child.attrib)
|
||||
if len(parsed) == 0:
|
||||
parsed.append({'text': '.',
|
||||
'break': {'strength': None,
|
||||
'time': None},
|
||||
'prosody': {'rate': def_rate,
|
||||
'pitch': def_pitch}})
|
||||
parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': break_ts}
|
||||
|
||||
elif child.tag == 'prosody':
|
||||
prosody_rate, prosody_pitch, change_rate, change_pitch = self.process_prosody(child.attrib)
|
||||
child_rate = prosody_rate if change_rate else def_rate
|
||||
child_pitch = prosody_pitch if change_pitch else def_pitch
|
||||
child_parsed = self.process_ssml_element(child, def_strength, child_rate, child_pitch)
|
||||
parsed.extend(child_parsed)
|
||||
|
||||
elif child.tag in ['p', 's']:
|
||||
break_strength = 'strong' if child.tag == 's' else 'x-strong'
|
||||
child_parsed = self.process_ssml_element(child, break_strength, def_rate, def_pitch)
|
||||
if len(parsed) > 0 and (parsed[-1]['text'] or last_tag is not None or last_tag != child.tag):
|
||||
if parsed[-1]['break']['strength'] is None:
|
||||
parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': self.strength2time[break_strength]}
|
||||
else:
|
||||
last_time = parsed[-1]['break']['time']
|
||||
parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': max(last_time, self.strength2time[break_strength])}
|
||||
if len(child_parsed) > 0:
|
||||
if child_parsed[-1]['break']['strength'] is None:
|
||||
child_parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': self.strength2time[break_strength]}
|
||||
else:
|
||||
last_time = child_parsed[-1]['break']['time']
|
||||
child_parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': max(last_time, self.strength2time[break_strength])}
|
||||
parsed.extend(child_parsed)
|
||||
else:
|
||||
warnings.warn(f"Current model doesn't support SSML tag: {child.tag}")
|
||||
|
||||
last_tag = child.tag
|
||||
|
||||
if child.tail:
|
||||
tail_text = child.tail
|
||||
if tail_text[0] in '.,!?…–;:' and len(parsed) > 0:
|
||||
lost_punct = tail_text[0]
|
||||
parsed[-1]['text'] = parsed[-1]['text'].strip() + lost_punct
|
||||
if len(tail_text) > 1:
|
||||
tail_text = tail_text[1:]
|
||||
tail_text_parsed = self.process_head_tail_text(tail_text, def_rate, def_pitch)
|
||||
parsed.extend(tail_text_parsed)
|
||||
return parsed
|
||||
|
||||
def process_head_tail_text(self, element_text, def_rate, def_pitch):
|
||||
text_parsed = []
|
||||
if element_text is None:
|
||||
return text_parsed
|
||||
proc_text = element_text.replace('\n', '')
|
||||
proc_text = re.sub(r'\s+', ' ', proc_text).strip()
|
||||
text_parsed.append({'text': proc_text,
|
||||
'break': {'strength': None,
|
||||
'time': None},
|
||||
'prosody': {'rate': def_rate,
|
||||
'pitch': def_pitch}})
|
||||
return text_parsed
|
||||
|
||||
def process_break_attrib(self, attrib):
|
||||
for k in attrib.keys():
|
||||
if k not in ['strength', 'time']:
|
||||
warnings.warn(f"Current model doesn't support SSML <break> attrib: {k}")
|
||||
strength = attrib.get('strength', 'medium')
|
||||
break_time = attrib.get('time', None)
|
||||
if break_time is not None:
|
||||
if break_time.endswith('ms'):
|
||||
break_ts = int(break_time[:-2])
|
||||
elif break_time.endswith('s'):
|
||||
break_ts = int(break_time[:-1]) * 1000
|
||||
else:
|
||||
raise AssertionError("Invalid <break> tag, time should end with 'ms' or 's'")
|
||||
|
||||
if break_ts >= self.strength2time['x-strong']:
|
||||
strength = 'x-strong'
|
||||
elif break_ts >= self.strength2time['strong']:
|
||||
strength = 'strong'
|
||||
else:
|
||||
if strength in self.strength2time:
|
||||
break_ts = self.strength2time[strength]
|
||||
else:
|
||||
raise AssertionError(f"Invalid <break> tag, strength should be in {', '.join(self.valid_tags['break']['strength'])}")
|
||||
if break_ts > 5000:
|
||||
warnings.warn('Cuurent model supports pauses less than 5 sec')
|
||||
break_ts = 5000
|
||||
return strength, break_ts
|
||||
|
||||
def process_prosody(self, attrib):
|
||||
for k in attrib.keys():
|
||||
if k not in ['rate', 'pitch']:
|
||||
warnings.warn(f"Current model doesn't support SSML <prosody> attrib: {k}")
|
||||
rate = attrib.get('rate', None)
|
||||
pitch = attrib.get('pitch', None)
|
||||
assert rate is not None or pitch is not None, "Empty <prosody> tag"
|
||||
if rate is not None:
|
||||
change_rate = True
|
||||
if rate.endswith('%'):
|
||||
rate_val = int(rate.replace('%', '')) / 100
|
||||
else:
|
||||
rate_val = self.rate2value.get(rate, None)
|
||||
if rate_val is None:
|
||||
raise AssertionError(f"Invalid <prosody> tag, rate should be in {', '.join(self.valid_tags['prosody']['rate'])}")
|
||||
else:
|
||||
change_rate = False
|
||||
rate_val = 1.
|
||||
if pitch is not None:
|
||||
change_pitch = True
|
||||
if pitch.endswith('%'):
|
||||
pitch_val = int(pitch.replace('%', '')[1:]) / 100
|
||||
if pitch[0] == '+':
|
||||
pitch_val = 1. + pitch_val
|
||||
else:
|
||||
pitch_val = 1. - pitch_val
|
||||
else:
|
||||
pitch_val = self.pitch2value.get(pitch, None)
|
||||
if pitch_val is None:
|
||||
raise AssertionError(f"Invalid <prosody> tag, pitch should be in {', '.join(self.valid_tags['prosody']['pitch'])}")
|
||||
else:
|
||||
change_pitch = False
|
||||
pitch_val = 1.
|
||||
return rate_val, pitch_val, change_rate, change_pitch
|
||||
|
||||
def apply_tts(self, text=None,
|
||||
ssml_text=None,
|
||||
speaker: str = 'xenia',
|
||||
sample_rate: int = 48000,
|
||||
put_accent=True,
|
||||
put_yo=True,
|
||||
voice_path=None):
|
||||
|
||||
assert sample_rate in [8000, 24000, 48000], f"`sample_rate` should be in [8000, 24000, 48000], current value is {sample_rate}"
|
||||
assert speaker in self.speakers, f"`speaker` should be in {', '.join(self.speakers)}"
|
||||
assert text is not None or ssml_text is not None, "Both `text` and `ssml_text` are empty"
|
||||
|
||||
ssml = ssml_text is not None
|
||||
if ssml:
|
||||
input_text = ssml_text
|
||||
else:
|
||||
input_text = text
|
||||
speaker_ids = self.get_speakers(speaker, voice_path)
|
||||
sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches, sp_ids = self.prepare_tts_model_input(input_text,
|
||||
ssml=ssml,
|
||||
speaker_ids=speaker_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
try:
|
||||
out, out_lens = self.model(sentences=sentences,
|
||||
clean_sentences=clean_sentences,
|
||||
break_lens=break_lens,
|
||||
prosody_rates=prosody_rates,
|
||||
prosody_pitches=prosody_pitches,
|
||||
speaker_ids=sp_ids,
|
||||
sr=sample_rate,
|
||||
device=str(self.device),
|
||||
put_yo=put_yo,
|
||||
put_accent=put_accent
|
||||
)
|
||||
except RuntimeError as e:
|
||||
raise Exception("Model couldn't generate your text, probably it's too long")
|
||||
audio = out.to('cpu')[0]
|
||||
return audio
|
||||
|
||||
@staticmethod
|
||||
def write_wave(path, audio, sample_rate):
|
||||
"""Writes a .wav file.
|
||||
Takes path, PCM audio data, and sample rate.
|
||||
"""
|
||||
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio)
|
||||
|
||||
def save_wav(self, text=None,
|
||||
ssml_text=None,
|
||||
speaker: str = 'xenia',
|
||||
audio_path: str = '',
|
||||
sample_rate: int = 48000,
|
||||
put_accent=True,
|
||||
put_yo=True):
|
||||
|
||||
if not audio_path:
|
||||
audio_path = 'test.wav'
|
||||
|
||||
audio = self.apply_tts(text=text,
|
||||
ssml_text=ssml_text,
|
||||
speaker=speaker,
|
||||
sample_rate=sample_rate,
|
||||
put_yo=put_yo,
|
||||
put_accent=put_accent)
|
||||
self.write_wave(path=audio_path,
|
||||
audio=(audio * 32767).numpy().astype('int16'),
|
||||
sample_rate=sample_rate)
|
||||
return audio_path
|
||||
|
||||
def load_random_voice(self, voice_path=None):
|
||||
if voice_path is None:
|
||||
random_emb = torch.randn(2, self.emb_dim, requires_grad=False).to(self.device)
|
||||
self.random_emb = random_emb
|
||||
print("Generated new voice")
|
||||
else:
|
||||
random_emb = torch.load(voice_path, map_location=self.device)
|
||||
print(f"Loaded voice from {voice_path}")
|
||||
if self.random_emb is not None and torch.equal(self.random_emb, random_emb):
|
||||
return
|
||||
mel_weight = random_emb[0]
|
||||
dur_weight = random_emb[1, :self.emb_dim//2]
|
||||
p_weight = random_emb[1, self.emb_dim//2:]
|
||||
|
||||
self.model.tts_model.tacotron.speaker_embedding.weight[-1] = mel_weight
|
||||
self.model.tts_model.dur_predictor.dur_pred.speaker_embedding.weight[-1] = dur_weight
|
||||
self.model.tts_model.pitch_predictor.pitch_pred.speaker_embedding.weight[-1] = p_weight
|
||||
|
||||
def save_random_voice(self, voice_path):
|
||||
assert self.random_emb is not None, "No generated random voice"
|
||||
torch.save(self.random_emb, voice_path)
|
||||
print(f"Saved generated voice to {voice_path}")
|
@ -1,7 +1,14 @@
|
||||
discord-py==2.3.2
|
||||
loguru==0.7.2
|
||||
peewee==3.17.5
|
||||
PyNaCl==1.5.0
|
||||
git+https://github.com/aaiyer/rfoo.git@1555bd4eed204bb6a33a5e313146a6c2813cfe91
|
||||
# Cython is setup dependency for rfoo
|
||||
pydantic
|
||||
aiohttp==3.7.4.post0
|
||||
async-timeout==3.0.1
|
||||
attrs==21.4.0
|
||||
chardet==4.0.0
|
||||
idna==3.3
|
||||
multidict==6.0.2
|
||||
numpy==1.22.3
|
||||
torch==1.11.0
|
||||
typing-extensions==4.1.1
|
||||
yarl==1.7.2
|
||||
git+https://github.com/Rapptz/discord.py.git#egg=discord-py
|
||||
loguru~=0.6.0
|
||||
peewee~=3.14.10
|
||||
PyNaCl
|
Loading…
x
Reference in New Issue
Block a user