Compare commits
2 Commits
c8d32959bb
...
c83fd8a2be
Author | SHA1 | Date | |
---|---|---|---|
c83fd8a2be | |||
14bd299929 |
@ -9,7 +9,6 @@ RUN apt update && apt install git gcc libc6-dev -y --no-install-recommends && ap
|
|||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip pip install Cython && \
|
RUN --mount=type=cache,target=/root/.cache/pip pip install Cython && \
|
||||||
pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt && \
|
pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt && \
|
||||||
pip wheel torch numpy --wheel-dir /app/wheels --index-url https://download.pytorch.org/whl/cpu
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import DB
|
import DB
|
||||||
from TTSSilero import Speakers
|
|
||||||
|
|
||||||
|
|
||||||
class SpeakersSettingsAdapterDiscord:
|
class SpeakersSettingsAdapterDiscord:
|
||||||
DEFAULT_SPEAKER = Speakers.kseniya
|
def get_speaker(self, guild_id: int, user_id: int) -> str:
|
||||||
|
|
||||||
def get_speaker(self, guild_id: int, user_id: int) -> Speakers:
|
|
||||||
user_defined_speaker = self.get_speaker_user(guild_id, user_id)
|
user_defined_speaker = self.get_speaker_user(guild_id, user_id)
|
||||||
if user_defined_speaker is None:
|
if user_defined_speaker is None:
|
||||||
return self.get_speaker_global(guild_id)
|
return self.get_speaker_global(guild_id)
|
||||||
@ -14,36 +11,29 @@ class SpeakersSettingsAdapterDiscord:
|
|||||||
else:
|
else:
|
||||||
return user_defined_speaker
|
return user_defined_speaker
|
||||||
|
|
||||||
def get_speaker_global(self, guild_id: int) -> Speakers:
|
def get_speaker_global(self, guild_id: int) -> str | None:
|
||||||
server_speaker_query = DB.ServerSpeaker.select()\
|
server_speaker_query = DB.ServerSpeaker.select()\
|
||||||
.where(DB.ServerSpeaker.server_id == guild_id)
|
.where(DB.ServerSpeaker.server_id == guild_id)
|
||||||
|
|
||||||
if server_speaker_query.count() == 1:
|
if server_speaker_query.count() == 1:
|
||||||
return Speakers(server_speaker_query.get().speaker)
|
return server_speaker_query.get().speaker
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return self.DEFAULT_SPEAKER
|
return None
|
||||||
|
|
||||||
def get_speaker_user(self, guild_id: int, user_id: int) -> Speakers | None:
|
def get_speaker_user(self, guild_id: int, user_id: int) -> str | None:
|
||||||
user_speaker_query = DB.UserServerSpeaker.select()\
|
user_speaker_query = DB.UserServerSpeaker.select()\
|
||||||
.where(DB.UserServerSpeaker.server_id == guild_id)\
|
.where(DB.UserServerSpeaker.server_id == guild_id)\
|
||||||
.where(DB.UserServerSpeaker.user_id == user_id)
|
.where(DB.UserServerSpeaker.user_id == user_id)
|
||||||
|
|
||||||
if user_speaker_query.count() == 1:
|
if user_speaker_query.count() == 1:
|
||||||
return Speakers(user_speaker_query.get().speaker)
|
return user_speaker_query.get().speaker
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
def set_speaker_user(self, guild_id: int, user_id: int, speaker: str) -> None:
|
||||||
def available_speakers(self) -> set[str]:
|
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker).execute()
|
||||||
return {speaker.name for speaker in Speakers}
|
|
||||||
|
|
||||||
def set_speaker_user(self, guild_id: int, user_id: int, speaker: Speakers) -> None:
|
def set_speaker_global(self, guild_id: int, user_id: int, speaker: str) -> None:
|
||||||
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker.value).execute()
|
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker).execute()
|
||||||
|
|
||||||
def set_speaker_global(self, guild_id: int, speaker: Speakers) -> None:
|
|
||||||
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker.value).execute()
|
|
||||||
|
|
||||||
|
|
||||||
speakers_settings_adapter = SpeakersSettingsAdapterDiscord()
|
|
||||||
|
14
TTSServer/Speakers.py
Normal file
14
TTSServer/Speakers.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from typing import Literal
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Argument(BaseModel):
|
||||||
|
type: Literal['str', 'int', 'float']
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDescription(BaseModel):
|
||||||
|
engine: str
|
||||||
|
name: str
|
||||||
|
arguments: dict[str, Argument]
|
||||||
|
description: None | str = None
|
17
TTSServer/TTSServer.py
Normal file
17
TTSServer/TTSServer.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import aiohttp
|
||||||
|
from .Speakers import ModelDescription
|
||||||
|
|
||||||
|
|
||||||
|
class TTSServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.session = aiohttp.ClientSession(base_url='http://localhost:8000')
|
||||||
|
|
||||||
|
async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
|
||||||
|
async with self.session.post(url=f'/synth/{speaker.engine}/model/{speaker.name}', json={'text': text}) as req:
|
||||||
|
req.raise_for_status()
|
||||||
|
return await req.content.read()
|
||||||
|
|
||||||
|
async def discovery(self) -> list[ModelDescription]:
|
||||||
|
async with self.session.get('/discovery') as req:
|
||||||
|
req.raise_for_status()
|
||||||
|
return [ModelDescription(**item) for item in await req.json()]
|
52
TTSServer/TTSServerCached.py
Normal file
52
TTSServer/TTSServerCached.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from .TTSServer import TTSServer
|
||||||
|
from .Speakers import ModelDescription
|
||||||
|
import DB
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
class TTSServerCached(TTSServer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.models_lookup_idx: dict[str, ModelDescription] = dict()
|
||||||
|
|
||||||
|
async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
|
||||||
|
cache_query = DB.SoundCache.select()\
|
||||||
|
.where(DB.SoundCache.text == text)\
|
||||||
|
.where(DB.SoundCache.speaker == speaker)
|
||||||
|
|
||||||
|
if cache_query.count() == 1:
|
||||||
|
with DB.database.atomic():
|
||||||
|
DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1})\
|
||||||
|
.where(DB.SoundCache.text == text)\
|
||||||
|
.where(DB.SoundCache.speaker == speaker).execute()
|
||||||
|
|
||||||
|
cached = cache_query.get().audio
|
||||||
|
|
||||||
|
return cached
|
||||||
|
|
||||||
|
else:
|
||||||
|
synthesized = await super().synthesize_text(text, speaker)
|
||||||
|
DB.SoundCache.create(text=text, speaker=speaker, audio=synthesized)
|
||||||
|
return synthesized
|
||||||
|
|
||||||
|
async def discovery(self) -> list[ModelDescription]:
|
||||||
|
res = await super().discovery()
|
||||||
|
logger.debug(f'Discovered {len(res)} models')
|
||||||
|
self.models_lookup_idx: dict[str, ModelDescription] = {f'{desc.engine}_{desc.name}': desc for desc in res}
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def speaker_to_description(self, speaker_str: str) -> ModelDescription | None:
|
||||||
|
for attempt in range(2):
|
||||||
|
if len(self.models_lookup_idx) > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
await self.discovery()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError('Models discovery seems to return zero models')
|
||||||
|
|
||||||
|
# Default to first model
|
||||||
|
return self.models_lookup_idx.get(speaker_str, None)
|
||||||
|
|
||||||
|
def default_speaker(self) -> ModelDescription:
|
||||||
|
return list(self.models_lookup_idx.values())[0]
|
3
TTSServer/__init__.py
Normal file
3
TTSServer/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .TTSServer import TTSServer
|
||||||
|
from .Speakers import ModelDescription
|
||||||
|
from .TTSServerCached import TTSServerCached
|
@ -1,15 +0,0 @@
|
|||||||
import enum
|
|
||||||
|
|
||||||
|
|
||||||
class Speakers(enum.Enum):
|
|
||||||
aidar = 'aidar'
|
|
||||||
baya = 'baya'
|
|
||||||
kseniya = 'kseniya'
|
|
||||||
irina = 'irina'
|
|
||||||
ruslan = 'ruslan'
|
|
||||||
natasha = 'natasha'
|
|
||||||
thorsten = 'thorsten'
|
|
||||||
tux = 'tux'
|
|
||||||
gilles = 'gilles'
|
|
||||||
lj = 'lj'
|
|
||||||
dilyara = 'dilyara'
|
|
@ -1,63 +0,0 @@
|
|||||||
import contextlib
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import wave
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch.package
|
|
||||||
|
|
||||||
from .Speakers import Speakers
|
|
||||||
from .multi_v2_package import TTSModelMulti_v2
|
|
||||||
|
|
||||||
|
|
||||||
class TTSSilero:
|
|
||||||
def __init__(self, threads: int = 12):
|
|
||||||
device = torch.device('cpu')
|
|
||||||
torch.set_num_threads(threads)
|
|
||||||
local_file = Path(os.getenv('DATA_DIR', '.')) / 'model_multi.pt'
|
|
||||||
|
|
||||||
if not os.path.isfile(local_file):
|
|
||||||
torch.hub.download_url_to_file(
|
|
||||||
'https://models.silero.ai/models/tts/multi/v2_multi.pt',
|
|
||||||
str(local_file)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
|
|
||||||
self.model.to(device)
|
|
||||||
|
|
||||||
self.sample_rate = 16000
|
|
||||||
|
|
||||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
|
||||||
return self.to_wav(self._synthesize_text(text, speaker))
|
|
||||||
|
|
||||||
def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Performs splitting text and synthesizing it
|
|
||||||
|
|
||||||
:param text:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
results_list: list[torch.Tensor] = self.model.apply_tts(
|
|
||||||
texts=[text],
|
|
||||||
speakers=speaker.value,
|
|
||||||
sample_rate=self.sample_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
return results_list
|
|
||||||
|
|
||||||
def to_wav(self, synthesized_text: list[torch.Tensor]) -> bytes:
|
|
||||||
res_io_stream = io.BytesIO()
|
|
||||||
|
|
||||||
with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
|
|
||||||
wf.setnchannels(1)
|
|
||||||
wf.setsampwidth(2)
|
|
||||||
wf.setframerate(self.sample_rate)
|
|
||||||
for result in synthesized_text:
|
|
||||||
wf.writeframes((result * 32767).numpy().astype('int16'))
|
|
||||||
|
|
||||||
res_io_stream.seek(0)
|
|
||||||
|
|
||||||
return res_io_stream.read()
|
|
||||||
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from .TTSSilero import TTSSilero
|
|
||||||
from .Speakers import Speakers
|
|
||||||
import DB
|
|
||||||
import sqlite3
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
|
|
||||||
class TTSSileroCached(TTSSilero):
|
|
||||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
|
||||||
# start = time.time()
|
|
||||||
cache_query = DB.SoundCache.select()\
|
|
||||||
.where(DB.SoundCache.text == text)\
|
|
||||||
.where(DB.SoundCache.speaker == speaker.value)
|
|
||||||
|
|
||||||
if cache_query.count() == 1:
|
|
||||||
with DB.database.atomic():
|
|
||||||
DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1})\
|
|
||||||
.where(DB.SoundCache.text == text)\
|
|
||||||
.where(DB.SoundCache.speaker == speaker.value).execute()
|
|
||||||
|
|
||||||
cached = cache_query.get().audio
|
|
||||||
|
|
||||||
return cached
|
|
||||||
|
|
||||||
else:
|
|
||||||
# logger.debug(f'Starting synthesis')
|
|
||||||
# start2 = time.time()
|
|
||||||
synthesized = super().synthesize_text(text, speaker)
|
|
||||||
# logger.debug(f'Synthesis done in {time.time() - start2} s in {time.time() - start} s after start')
|
|
||||||
DB.SoundCache.create(text=text, speaker=speaker.value, audio=synthesized)
|
|
||||||
|
|
||||||
# logger.debug(f'Cache set in {time.time() - start2} synth end and {time.time() - start2} s after start')
|
|
||||||
return synthesized
|
|
@ -1,3 +0,0 @@
|
|||||||
from .TTSSilero import TTSSilero
|
|
||||||
from .Speakers import Speakers
|
|
||||||
from .TTSSileroCached import TTSSileroCached
|
|
@ -1,148 +0,0 @@
|
|||||||
import re
|
|
||||||
import wave
|
|
||||||
import torch
|
|
||||||
import warnings
|
|
||||||
import contextlib
|
|
||||||
|
|
||||||
# for type hints only
|
|
||||||
|
|
||||||
|
|
||||||
class TTSModelMulti_v2():
|
|
||||||
def __init__(self, model_path, symbols):
|
|
||||||
self.model = self.init_jit_model(model_path)
|
|
||||||
self.symbols = symbols
|
|
||||||
self.device = torch.device('cpu')
|
|
||||||
speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
|
|
||||||
'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
|
|
||||||
self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}
|
|
||||||
|
|
||||||
def init_jit_model(self, model_path: str):
|
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
model = torch.jit.load(model_path, map_location='cpu')
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
def prepare_text_input(self, text, symbols, symbol_to_id=None):
|
|
||||||
if len(text) > 140:
|
|
||||||
warnings.warn('Text string is longer than 140 symbols.')
|
|
||||||
|
|
||||||
if symbol_to_id is None:
|
|
||||||
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
|
||||||
|
|
||||||
text = text.lower()
|
|
||||||
text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
|
|
||||||
text = re.sub(r'\s+', ' ', text).strip()
|
|
||||||
if text[-1] not in ['.', '!', '?']:
|
|
||||||
text = text + '.'
|
|
||||||
text = text + symbols[1]
|
|
||||||
|
|
||||||
text_ohe = [symbol_to_id[s] for s in text if s in symbols]
|
|
||||||
text_tensor = torch.LongTensor(text_ohe)
|
|
||||||
return text_tensor
|
|
||||||
|
|
||||||
def prepare_tts_model_input(self, text: str or list, symbols: str, speakers: list):
|
|
||||||
assert len(speakers) == len(text) or len(speakers) == 1
|
|
||||||
if type(text) == str:
|
|
||||||
text = [text]
|
|
||||||
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
|
||||||
if len(text) == 1:
|
|
||||||
return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])
|
|
||||||
|
|
||||||
text_tensors = []
|
|
||||||
for string in text:
|
|
||||||
string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
|
|
||||||
text_tensors.append(string_tensor)
|
|
||||||
input_lengths, ids_sorted_decreasing = torch.sort(
|
|
||||||
torch.LongTensor([len(t) for t in text_tensors]),
|
|
||||||
dim=0, descending=True)
|
|
||||||
max_input_len = input_lengths[0]
|
|
||||||
batch_size = len(text_tensors)
|
|
||||||
|
|
||||||
text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
|
|
||||||
if len(speakers) == 1:
|
|
||||||
speakers = speakers*batch_size
|
|
||||||
speaker_ids = torch.LongTensor(batch_size).zero_()
|
|
||||||
|
|
||||||
for i, idx in enumerate(ids_sorted_decreasing):
|
|
||||||
text_tensor = text_tensors[idx]
|
|
||||||
in_len = text_tensor.size(0)
|
|
||||||
text_padded[i, :in_len] = text_tensor
|
|
||||||
speaker_ids[i] = speakers[idx]
|
|
||||||
|
|
||||||
return text_padded, speaker_ids, ids_sorted_decreasing
|
|
||||||
|
|
||||||
def process_tts_model_output(self, out, out_lens, ids):
|
|
||||||
out = out.to('cpu')
|
|
||||||
out_lens = out_lens.to('cpu')
|
|
||||||
_, orig_ids = ids.sort()
|
|
||||||
|
|
||||||
proc_outs = []
|
|
||||||
orig_out = out.index_select(0, orig_ids)
|
|
||||||
orig_out_lens = out_lens.index_select(0, orig_ids)
|
|
||||||
|
|
||||||
for i, out_len in enumerate(orig_out_lens):
|
|
||||||
proc_outs.append(orig_out[i][:out_len])
|
|
||||||
return proc_outs
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
self.model = self.model.to(device)
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def get_speakers(self, speakers: str or list):
|
|
||||||
if type(speakers) == str:
|
|
||||||
speakers = [speakers]
|
|
||||||
speaker_ids = []
|
|
||||||
for speaker in speakers:
|
|
||||||
try:
|
|
||||||
speaker_id = self.speaker_to_id[speaker]
|
|
||||||
speaker_ids.append(speaker_id)
|
|
||||||
except Exception:
|
|
||||||
raise ValueError(f'No such speaker: {speaker}')
|
|
||||||
return speaker_ids
|
|
||||||
|
|
||||||
def apply_tts(self, texts: str or list,
|
|
||||||
speakers: str or list,
|
|
||||||
sample_rate: int = 16000):
|
|
||||||
speaker_ids = self.get_speakers(speakers)
|
|
||||||
text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
|
|
||||||
symbols=self.symbols,
|
|
||||||
speakers=speaker_ids)
|
|
||||||
with torch.inference_mode():
|
|
||||||
out, out_lens = self.model(text_padded.to(self.device),
|
|
||||||
speaker_ids.to(self.device),
|
|
||||||
sr=sample_rate)
|
|
||||||
audios = self.process_tts_model_output(out, out_lens, orig_ids)
|
|
||||||
return audios
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def write_wave(path, audio, sample_rate):
|
|
||||||
"""Writes a .wav file.
|
|
||||||
Takes path, PCM audio data, and sample rate.
|
|
||||||
"""
|
|
||||||
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
|
||||||
wf.setnchannels(1)
|
|
||||||
wf.setsampwidth(2)
|
|
||||||
wf.setframerate(sample_rate)
|
|
||||||
wf.writeframes(audio)
|
|
||||||
|
|
||||||
def save_wav(self, texts: str or list,
|
|
||||||
speakers: str or list,
|
|
||||||
audio_pathes: str or list = '',
|
|
||||||
sample_rate: int = 16000):
|
|
||||||
if type(texts) == str:
|
|
||||||
texts = [texts]
|
|
||||||
|
|
||||||
if not audio_pathes:
|
|
||||||
audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
|
|
||||||
if type(audio_pathes) == str:
|
|
||||||
audio_pathes = [audio_pathes]
|
|
||||||
assert len(audio_pathes) == len(texts)
|
|
||||||
|
|
||||||
audio = self.apply_tts(texts=texts,
|
|
||||||
speakers=speakers,
|
|
||||||
sample_rate=sample_rate)
|
|
||||||
for i, _audio in enumerate(audio):
|
|
||||||
self.write_wave(path=audio_pathes[i],
|
|
||||||
audio=(_audio * 32767).numpy().astype('int16'),
|
|
||||||
sample_rate=sample_rate)
|
|
||||||
return audio_pathes
|
|
@ -4,8 +4,8 @@ import time
|
|||||||
from formatting import format_table
|
from formatting import format_table
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from discord.ext.commands import Context
|
from discord.ext.commands import Context
|
||||||
import rfoo
|
# import rfoo
|
||||||
from rfoo.utils import rconsole
|
# from rfoo.utils import rconsole
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Optional, Coroutine
|
from typing import Optional, Coroutine
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -24,7 +24,7 @@ class BotManagement(commands.Cog, SingletonBase):
|
|||||||
instance: Optional['BotManagement'] = None
|
instance: Optional['BotManagement'] = None
|
||||||
|
|
||||||
rfoo_server_thread: Optional[Thread] = None
|
rfoo_server_thread: Optional[Thread] = None
|
||||||
rfoo_server: Optional[rfoo.InetServer] = None
|
rfoo_server: Optional['rfoo.InetServer'] = None
|
||||||
|
|
||||||
# def __new__(cls, *args, **kwargs):
|
# def __new__(cls, *args, **kwargs):
|
||||||
# if cls.instance is None:
|
# if cls.instance is None:
|
||||||
@ -125,9 +125,10 @@ class BotManagement(commands.Cog, SingletonBase):
|
|||||||
|
|
||||||
|
|
||||||
async def setup(bot: commands.Bot):
|
async def setup(bot: commands.Bot):
|
||||||
await bot.add_cog(BotManagement(bot))
|
...
|
||||||
|
# await bot.add_cog(BotManagement(bot))
|
||||||
|
|
||||||
|
|
||||||
async def teardown(bot):
|
# async def teardown(bot):
|
||||||
stop_res = BotManagement(bot).stop_rfoo()
|
# stop_res = BotManagement(bot).stop_rfoo()
|
||||||
logger.info(f'Unloaded rfoo with result {stop_res} during BotManagement unload')
|
# logger.info(f'Unloaded rfoo with result {stop_res} during BotManagement unload')
|
||||||
|
110
cogs/TTSCore.py
110
cogs/TTSCore.py
@ -8,12 +8,14 @@ import discord
|
|||||||
import DB
|
import DB
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from TTSSilero import TTSSileroCached
|
|
||||||
from TTSSilero import Speakers
|
import formatting
|
||||||
|
from TTSServer import TTSServerCached
|
||||||
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
||||||
import Observ
|
import Observ
|
||||||
|
from TTSServer.Speakers import ModelDescription
|
||||||
from cogErrorHandlers import cogErrorHandlers
|
from cogErrorHandlers import cogErrorHandlers
|
||||||
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
|
from SpeakersSettingsAdapterDiscord import SpeakersSettingsAdapterDiscord
|
||||||
|
|
||||||
|
|
||||||
class TTSCore(commands.Cog, Observ.Observer):
|
class TTSCore(commands.Cog, Observ.Observer):
|
||||||
@ -21,9 +23,9 @@ class TTSCore(commands.Cog, Observ.Observer):
|
|||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
||||||
self.bot.subscribe(self) # subscribe for messages that aren't commands
|
self.bot.subscribe(self) # subscribe for messages that aren't commands
|
||||||
self.tts = TTSSileroCached()
|
self.tts = TTSServerCached()
|
||||||
self.tts_queues: dict[int, list[discord.AudioSource]] = defaultdict(list)
|
self.tts_queues: dict[int, list[discord.AudioSource]] = defaultdict(list)
|
||||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
|
self.speakers_adapter: SpeakersSettingsAdapterDiscord = SpeakersSettingsAdapterDiscord()
|
||||||
|
|
||||||
@commands.command('drop')
|
@commands.command('drop')
|
||||||
async def drop_queue(self, ctx: Context):
|
async def drop_queue(self, ctx: Context):
|
||||||
@ -89,11 +91,14 @@ class TTSCore(commands.Cog, Observ.Observer):
|
|||||||
await message.channel.send('We are in different voice channels')
|
await message.channel.send('We are in different voice channels')
|
||||||
return
|
return
|
||||||
|
|
||||||
speaker: Speakers = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
|
speaker_str: str = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
|
||||||
|
speaker: ModelDescription | None = await self.tts.speaker_to_description(speaker_str)
|
||||||
|
if speaker is None:
|
||||||
|
speaker = self.tts.default_speaker()
|
||||||
|
|
||||||
# check if message will fail on synthesis
|
# check if message will fail on synthesis
|
||||||
if DB.SynthesisErrors.select()\
|
if DB.SynthesisErrors.select()\
|
||||||
.where(DB.SynthesisErrors.speaker == speaker.value)\
|
.where(DB.SynthesisErrors.speaker == speaker_str)\
|
||||||
.where(DB.SynthesisErrors.text == message.content)\
|
.where(DB.SynthesisErrors.text == message.content)\
|
||||||
.count() == 1:
|
.count() == 1:
|
||||||
# Then we will not try to synthesis it
|
# Then we will not try to synthesis it
|
||||||
@ -101,7 +106,7 @@ class TTSCore(commands.Cog, Observ.Observer):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
wav_file_like_object = self.tts.synthesize_text(message.content, speaker=speaker)
|
wav_file_like_object: bytes = await self.tts.synthesize_text(message.content, speaker=speaker)
|
||||||
sound_source = FFmpegPCMAudio(wav_file_like_object, pipe=True, stderr=subprocess.PIPE)
|
sound_source = FFmpegPCMAudio(wav_file_like_object, pipe=True, stderr=subprocess.PIPE)
|
||||||
|
|
||||||
except Exception as synth_exception:
|
except Exception as synth_exception:
|
||||||
@ -160,6 +165,95 @@ class TTSCore(commands.Cog, Observ.Observer):
|
|||||||
|
|
||||||
# TODO: leave voice channel after being moved there alone
|
# TODO: leave voice channel after being moved there alone
|
||||||
|
|
||||||
|
# Ex TTSSettings.py
|
||||||
|
@commands.command('getAllSpeakers')
|
||||||
|
async def get_speakers(self, ctx: Context):
|
||||||
|
"""
|
||||||
|
Enumerate all available to set speakers
|
||||||
|
|
||||||
|
:param ctx:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
await ctx.send(
|
||||||
|
formatting.format_table(
|
||||||
|
tuple((model.engine + '_' + model.name, model.description)
|
||||||
|
for model in await self.tts.discovery()),
|
||||||
|
('model', 'description')
|
||||||
|
))
|
||||||
|
|
||||||
|
@commands.command('setPersonalSpeaker')
|
||||||
|
async def set_user_speaker(self, ctx: Context, speaker: str):
|
||||||
|
"""
|
||||||
|
Set personal speaker on this server
|
||||||
|
|
||||||
|
:param ctx:
|
||||||
|
:param speaker:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
if await self.tts.speaker_to_description(speaker) is None:
|
||||||
|
return await ctx.send(
|
||||||
|
f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||||
|
|
||||||
|
self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, speaker)
|
||||||
|
await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`')
|
||||||
|
|
||||||
|
@commands.command('setServerSpeaker')
|
||||||
|
async def set_server_speaker(self, ctx: Context, speaker: str):
|
||||||
|
"""
|
||||||
|
Set global server speaker
|
||||||
|
|
||||||
|
:param ctx:
|
||||||
|
:param speaker:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if await self.tts.speaker_to_description(speaker) is None:
|
||||||
|
return await ctx.send(
|
||||||
|
f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||||
|
|
||||||
|
self.speakers_adapter.set_speaker_global(ctx.guild.id, ctx.author.id, speaker)
|
||||||
|
await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`')
|
||||||
|
|
||||||
|
@commands.command('getSpeaker')
|
||||||
|
async def get_speaker(self, ctx: Context):
|
||||||
|
"""
|
||||||
|
Tell first appropriate speaker for a user, it can be user specified, server specified or server default
|
||||||
|
|
||||||
|
:param ctx:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)
|
||||||
|
|
||||||
|
await ctx.reply(f'Your current speaker is `{speaker}`')
|
||||||
|
|
||||||
|
@commands.command('getPersonalSpeaker')
|
||||||
|
async def get_personal_speaker(self, ctx: Context):
|
||||||
|
"""
|
||||||
|
Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
|
||||||
|
|
||||||
|
:param ctx:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
|
||||||
|
if speaker is None:
|
||||||
|
server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
||||||
|
await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")
|
||||||
|
|
||||||
|
else:
|
||||||
|
await ctx.reply(f"Your personal speaker is `{speaker}`")
|
||||||
|
|
||||||
|
@commands.command('getServerSpeaker')
|
||||||
|
async def get_server_speaker(self, ctx: Context):
|
||||||
|
"""
|
||||||
|
Tell server global speaker
|
||||||
|
|
||||||
|
:param ctx:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
||||||
|
await ctx.send(f"Current server speaker is `{speaker}`")
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
await bot.add_cog(TTSCore(bot))
|
await bot.add_cog(TTSCore(bot))
|
||||||
|
@ -1,102 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
from discord.ext import commands
|
|
||||||
from discord.ext.commands import Context
|
|
||||||
from TTSSilero import Speakers
|
|
||||||
from cogErrorHandlers import cogErrorHandlers
|
|
||||||
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
|
|
||||||
|
|
||||||
|
|
||||||
class TTSSettings(commands.Cog):
|
|
||||||
def __init__(self, bot: commands.Bot):
|
|
||||||
self.bot = bot
|
|
||||||
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
|
||||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
|
|
||||||
|
|
||||||
@commands.command('getAllSpeakers')
|
|
||||||
async def get_speakers(self, ctx: Context):
|
|
||||||
"""
|
|
||||||
Enumerate all available to set speakers
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speakers = '\n'.join(self.speakers_adapter.available_speakers)
|
|
||||||
|
|
||||||
await ctx.send(f"```\n{speakers}```")
|
|
||||||
|
|
||||||
@commands.command('setPersonalSpeaker')
|
|
||||||
async def set_user_speaker(self, ctx: Context, speaker: str):
|
|
||||||
"""
|
|
||||||
Set personal speaker on this server
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:param speaker:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
checked_speaker: Speakers = Speakers(speaker)
|
|
||||||
self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, checked_speaker)
|
|
||||||
await ctx.reply(f'Successfully set **your personal** speaker to `{checked_speaker.value}`')
|
|
||||||
|
|
||||||
except (KeyError, ValueError):
|
|
||||||
await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
|
||||||
|
|
||||||
@commands.command('setServerSpeaker')
|
|
||||||
async def set_server_speaker(self, ctx: Context, speaker: str):
|
|
||||||
"""
|
|
||||||
Set global server speaker
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:param speaker:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
checked_speaker: Speakers = Speakers(speaker)
|
|
||||||
self.speakers_adapter.set_speaker_global(ctx.guild.id, checked_speaker)
|
|
||||||
await ctx.send(f'Successfully set **server** speaker to `{checked_speaker.value}`')
|
|
||||||
|
|
||||||
except (KeyError, ValueError):
|
|
||||||
await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
|
||||||
|
|
||||||
@commands.command('getSpeaker')
|
|
||||||
async def get_speaker(self, ctx: Context):
|
|
||||||
"""
|
|
||||||
Tell first appropriate speaker for a user, it can be user specified, server specified or server default
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)
|
|
||||||
|
|
||||||
await ctx.reply(f'Your current speaker is `{speaker.value}`')
|
|
||||||
|
|
||||||
@commands.command('getPersonalSpeaker')
|
|
||||||
async def get_personal_speaker(self, ctx: Context):
|
|
||||||
"""
|
|
||||||
Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
|
|
||||||
if speaker is None:
|
|
||||||
server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id).value
|
|
||||||
await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")
|
|
||||||
|
|
||||||
else:
|
|
||||||
await ctx.reply(f"Your personal speaker is `{speaker.value}`")
|
|
||||||
|
|
||||||
@commands.command('getServerSpeaker')
|
|
||||||
async def get_server_speaker(self, ctx: Context):
|
|
||||||
"""
|
|
||||||
Tell server global speaker
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
|
||||||
await ctx.send(f"Current server speaker is `{speaker.value}`")
|
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
|
||||||
await bot.add_cog(TTSSettings(bot))
|
|
@ -11,7 +11,7 @@ def format_table(data: Iterable[Iterable[str]], header: Iterable[str] = MISSING)
|
|||||||
|
|
||||||
result = '```\n'
|
result = '```\n'
|
||||||
for row in data:
|
for row in data:
|
||||||
row = [item.replace('`', '\\`') for item in row]
|
row = [str(item).replace('`', '\\`') for item in row]
|
||||||
result += '\t'.join(row) + '\n'
|
result += '\t'.join(row) + '\n'
|
||||||
|
|
||||||
result += '```'
|
result += '```'
|
||||||
|
@ -4,6 +4,4 @@ peewee==3.17.5
|
|||||||
PyNaCl==1.5.0
|
PyNaCl==1.5.0
|
||||||
git+https://github.com/aaiyer/rfoo.git@1555bd4eed204bb6a33a5e313146a6c2813cfe91
|
git+https://github.com/aaiyer/rfoo.git@1555bd4eed204bb6a33a5e313146a6c2813cfe91
|
||||||
# Cython is setup dependency for rfoo
|
# Cython is setup dependency for rfoo
|
||||||
--index-url https://download.pytorch.org/whl/cpu
|
pydantic
|
||||||
torch==2.3.0+cpu
|
|
||||||
numpy==1.26.4
|
|
6
utils.py
6
utils.py
@ -1,10 +1,4 @@
|
|||||||
import winsound
|
import winsound
|
||||||
import io
|
|
||||||
|
|
||||||
|
|
||||||
def save_bytes(filename: str, bytes_audio: bytes) -> None:
|
|
||||||
with open(file=filename, mode='wb') as res_file:
|
|
||||||
res_file.write(bytes_audio)
|
|
||||||
|
|
||||||
|
|
||||||
def play_bytes(bytes_sound: bytes) -> None:
|
def play_bytes(bytes_sound: bytes) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user