Utilize ttsserver api for synth
This commit is contained in:
parent
c8d32959bb
commit
db62a98f0a
7
DB.py
7
DB.py
@ -1,11 +1,10 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import peewee
|
||||
|
||||
DB_PATH = Path(os.getenv('DATA_DIR', '.')) / 'voice_cache.sqlite'
|
||||
database = peewee.SqliteDatabase(str(DB_PATH))
|
||||
import config
|
||||
|
||||
database = peewee.SqliteDatabase(str(config.DB_PATH))
|
||||
|
||||
|
||||
class BaseModel(peewee.Model):
|
||||
|
@ -8,8 +8,7 @@ ENV PYTHONUNBUFFERED 1
|
||||
RUN apt update && apt install git gcc libc6-dev -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/*
|
||||
COPY requirements.txt .
|
||||
RUN --mount=type=cache,target=/root/.cache/pip pip install Cython && \
|
||||
pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt && \
|
||||
pip wheel torch numpy --wheel-dir /app/wheels --index-url https://download.pytorch.org/whl/cpu
|
||||
pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt
|
||||
|
||||
|
||||
|
||||
@ -20,7 +19,6 @@ ENV PYTHONUNBUFFERED 1
|
||||
|
||||
RUN useradd -ms /bin/bash silero_user && \
|
||||
apt update && apt install ffmpeg -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/*
|
||||
USER silero_user
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@ -30,5 +28,6 @@ COPY --from=builder /app/requirements.txt .
|
||||
RUN pip install --no-cache /wheels/*
|
||||
|
||||
COPY . .
|
||||
USER silero_user
|
||||
|
||||
CMD ["python3", "/app/main.py"]
|
||||
|
@ -1,8 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import peewee
|
||||
from discord.ext import commands
|
||||
|
||||
import DB
|
||||
|
||||
|
||||
# from loguru import logger
|
||||
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
import subprocess
|
||||
import shlex
|
||||
import io
|
||||
from discord.opus import Encoder
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
import discord
|
||||
from discord.opus import Encoder
|
||||
|
||||
|
||||
# Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
|
||||
import discord
|
||||
|
||||
|
||||
class Observer:
|
||||
@abstractmethod
|
||||
def update(self, message: discord.Message) -> None:
|
||||
async def update(self, message: discord.Message) -> None:
|
||||
raise NotImplemented
|
||||
|
@ -1,4 +1,5 @@
|
||||
import discord
|
||||
|
||||
from .Observer import Observer
|
||||
|
||||
|
||||
|
@ -1,12 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import DB
|
||||
from TTSSilero import Speakers
|
||||
|
||||
|
||||
class SpeakersSettingsAdapterDiscord:
|
||||
DEFAULT_SPEAKER = Speakers.kseniya
|
||||
|
||||
def get_speaker(self, guild_id: int, user_id: int) -> Speakers:
|
||||
def get_speaker(self, guild_id: int, user_id: int) -> str:
|
||||
user_defined_speaker = self.get_speaker_user(guild_id, user_id)
|
||||
if user_defined_speaker is None:
|
||||
return self.get_speaker_global(guild_id)
|
||||
@ -14,36 +11,29 @@ class SpeakersSettingsAdapterDiscord:
|
||||
else:
|
||||
return user_defined_speaker
|
||||
|
||||
def get_speaker_global(self, guild_id: int) -> Speakers:
|
||||
def get_speaker_global(self, guild_id: int) -> str | None:
|
||||
server_speaker_query = DB.ServerSpeaker.select()\
|
||||
.where(DB.ServerSpeaker.server_id == guild_id)
|
||||
|
||||
if server_speaker_query.count() == 1:
|
||||
return Speakers(server_speaker_query.get().speaker)
|
||||
return server_speaker_query.get().speaker
|
||||
|
||||
else:
|
||||
return self.DEFAULT_SPEAKER
|
||||
return None
|
||||
|
||||
def get_speaker_user(self, guild_id: int, user_id: int) -> Speakers | None:
|
||||
def get_speaker_user(self, guild_id: int, user_id: int) -> str | None:
|
||||
user_speaker_query = DB.UserServerSpeaker.select()\
|
||||
.where(DB.UserServerSpeaker.server_id == guild_id)\
|
||||
.where(DB.UserServerSpeaker.user_id == user_id)
|
||||
|
||||
if user_speaker_query.count() == 1:
|
||||
return Speakers(user_speaker_query.get().speaker)
|
||||
return user_speaker_query.get().speaker
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def available_speakers(self) -> set[str]:
|
||||
return {speaker.name for speaker in Speakers}
|
||||
def set_speaker_user(self, guild_id: int, user_id: int, speaker: str) -> None:
|
||||
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker).execute()
|
||||
|
||||
def set_speaker_user(self, guild_id: int, user_id: int, speaker: Speakers) -> None:
|
||||
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker.value).execute()
|
||||
|
||||
def set_speaker_global(self, guild_id: int, speaker: Speakers) -> None:
|
||||
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker.value).execute()
|
||||
|
||||
|
||||
speakers_settings_adapter = SpeakersSettingsAdapterDiscord()
|
||||
def set_speaker_global(self, guild_id: int, user_id: int, speaker: str) -> None:
|
||||
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker).execute()
|
||||
|
15
TTSServer/Speakers.py
Normal file
15
TTSServer/Speakers.py
Normal file
@ -0,0 +1,15 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Argument(BaseModel):
|
||||
type: Literal['str', 'int', 'float']
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class ModelDescription(BaseModel):
|
||||
engine: str
|
||||
name: str
|
||||
arguments: dict[str, Argument]
|
||||
description: None | str = None
|
19
TTSServer/TTSServer.py
Normal file
19
TTSServer/TTSServer.py
Normal file
@ -0,0 +1,19 @@
|
||||
import aiohttp
|
||||
|
||||
import config
|
||||
from .Speakers import ModelDescription
|
||||
|
||||
|
||||
class TTSServer:
|
||||
def __init__(self):
|
||||
self.session = aiohttp.ClientSession(base_url=config.BASE_URL)
|
||||
|
||||
async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
|
||||
async with self.session.post(url=f'/synth/{speaker.engine}/model/{speaker.name}', json={'text': text}) as req:
|
||||
req.raise_for_status()
|
||||
return await req.content.read()
|
||||
|
||||
async def discovery(self) -> list[ModelDescription]:
|
||||
async with self.session.get('/discovery') as req:
|
||||
req.raise_for_status()
|
||||
return [ModelDescription(**item) for item in await req.json()]
|
53
TTSServer/TTSServerCached.py
Normal file
53
TTSServer/TTSServerCached.py
Normal file
@ -0,0 +1,53 @@
|
||||
from loguru import logger
|
||||
|
||||
import DB
|
||||
from .Speakers import ModelDescription
|
||||
from .TTSServer import TTSServer
|
||||
|
||||
|
||||
class TTSServerCached(TTSServer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.models_lookup_idx: dict[str, ModelDescription] = dict()
|
||||
|
||||
async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
|
||||
cache_query = DB.SoundCache.select() \
|
||||
.where(DB.SoundCache.text == text) \
|
||||
.where(DB.SoundCache.speaker == speaker)
|
||||
|
||||
if cache_query.count() == 1:
|
||||
with DB.database.atomic():
|
||||
DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1}) \
|
||||
.where(DB.SoundCache.text == text) \
|
||||
.where(DB.SoundCache.speaker == speaker).execute()
|
||||
|
||||
cached = cache_query.get().audio
|
||||
|
||||
return cached
|
||||
|
||||
else:
|
||||
synthesized = await super().synthesize_text(text, speaker)
|
||||
DB.SoundCache.create(text=text, speaker=speaker, audio=synthesized)
|
||||
return synthesized
|
||||
|
||||
async def discovery(self) -> list[ModelDescription]:
|
||||
res = await super().discovery()
|
||||
logger.debug(f'Discovered {len(res)} models')
|
||||
self.models_lookup_idx: dict[str, ModelDescription] = {f'{desc.engine}_{desc.name}': desc for desc in res}
|
||||
return res
|
||||
|
||||
async def speaker_to_description(self, speaker_str: str) -> ModelDescription | None:
|
||||
for attempt in range(2):
|
||||
if len(self.models_lookup_idx) > 0:
|
||||
break
|
||||
|
||||
await self.discovery()
|
||||
|
||||
else:
|
||||
raise RuntimeError('Models discovery seems to return zero models')
|
||||
|
||||
# Default to first model
|
||||
return self.models_lookup_idx.get(speaker_str, None)
|
||||
|
||||
def default_speaker(self) -> ModelDescription:
|
||||
return list(self.models_lookup_idx.values())[0]
|
3
TTSServer/__init__.py
Normal file
3
TTSServer/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .Speakers import ModelDescription
|
||||
from .TTSServer import TTSServer
|
||||
from .TTSServerCached import TTSServerCached
|
@ -1,15 +0,0 @@
|
||||
import enum
|
||||
|
||||
|
||||
class Speakers(enum.Enum):
|
||||
aidar = 'aidar'
|
||||
baya = 'baya'
|
||||
kseniya = 'kseniya'
|
||||
irina = 'irina'
|
||||
ruslan = 'ruslan'
|
||||
natasha = 'natasha'
|
||||
thorsten = 'thorsten'
|
||||
tux = 'tux'
|
||||
gilles = 'gilles'
|
||||
lj = 'lj'
|
||||
dilyara = 'dilyara'
|
@ -1,63 +0,0 @@
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
from pathlib import Path
|
||||
|
||||
import torch.package
|
||||
|
||||
from .Speakers import Speakers
|
||||
from .multi_v2_package import TTSModelMulti_v2
|
||||
|
||||
|
||||
class TTSSilero:
|
||||
def __init__(self, threads: int = 12):
|
||||
device = torch.device('cpu')
|
||||
torch.set_num_threads(threads)
|
||||
local_file = Path(os.getenv('DATA_DIR', '.')) / 'model_multi.pt'
|
||||
|
||||
if not os.path.isfile(local_file):
|
||||
torch.hub.download_url_to_file(
|
||||
'https://models.silero.ai/models/tts/multi/v2_multi.pt',
|
||||
str(local_file)
|
||||
)
|
||||
|
||||
self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
|
||||
self.model.to(device)
|
||||
|
||||
self.sample_rate = 16000
|
||||
|
||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
||||
return self.to_wav(self._synthesize_text(text, speaker))
|
||||
|
||||
def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
|
||||
"""
|
||||
Performs splitting text and synthesizing it
|
||||
|
||||
:param text:
|
||||
:return:
|
||||
"""
|
||||
|
||||
results_list: list[torch.Tensor] = self.model.apply_tts(
|
||||
texts=[text],
|
||||
speakers=speaker.value,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
return results_list
|
||||
|
||||
def to_wav(self, synthesized_text: list[torch.Tensor]) -> bytes:
|
||||
res_io_stream = io.BytesIO()
|
||||
|
||||
with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(self.sample_rate)
|
||||
for result in synthesized_text:
|
||||
wf.writeframes((result * 32767).numpy().astype('int16'))
|
||||
|
||||
res_io_stream.seek(0)
|
||||
|
||||
return res_io_stream.read()
|
||||
|
||||
|
@ -1,36 +0,0 @@
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
from .TTSSilero import TTSSilero
|
||||
from .Speakers import Speakers
|
||||
import DB
|
||||
import sqlite3
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TTSSileroCached(TTSSilero):
|
||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
||||
# start = time.time()
|
||||
cache_query = DB.SoundCache.select()\
|
||||
.where(DB.SoundCache.text == text)\
|
||||
.where(DB.SoundCache.speaker == speaker.value)
|
||||
|
||||
if cache_query.count() == 1:
|
||||
with DB.database.atomic():
|
||||
DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1})\
|
||||
.where(DB.SoundCache.text == text)\
|
||||
.where(DB.SoundCache.speaker == speaker.value).execute()
|
||||
|
||||
cached = cache_query.get().audio
|
||||
|
||||
return cached
|
||||
|
||||
else:
|
||||
# logger.debug(f'Starting synthesis')
|
||||
# start2 = time.time()
|
||||
synthesized = super().synthesize_text(text, speaker)
|
||||
# logger.debug(f'Synthesis done in {time.time() - start2} s in {time.time() - start} s after start')
|
||||
DB.SoundCache.create(text=text, speaker=speaker.value, audio=synthesized)
|
||||
|
||||
# logger.debug(f'Cache set in {time.time() - start2} synth end and {time.time() - start2} s after start')
|
||||
return synthesized
|
@ -1,3 +0,0 @@
|
||||
from .TTSSilero import TTSSilero
|
||||
from .Speakers import Speakers
|
||||
from .TTSSileroCached import TTSSileroCached
|
@ -1,148 +0,0 @@
|
||||
import re
|
||||
import wave
|
||||
import torch
|
||||
import warnings
|
||||
import contextlib
|
||||
|
||||
# for type hints only
|
||||
|
||||
|
||||
class TTSModelMulti_v2():
|
||||
def __init__(self, model_path, symbols):
|
||||
self.model = self.init_jit_model(model_path)
|
||||
self.symbols = symbols
|
||||
self.device = torch.device('cpu')
|
||||
speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
|
||||
'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
|
||||
self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}
|
||||
|
||||
def init_jit_model(self, model_path: str):
|
||||
torch.set_grad_enabled(False)
|
||||
model = torch.jit.load(model_path, map_location='cpu')
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def prepare_text_input(self, text, symbols, symbol_to_id=None):
|
||||
if len(text) > 140:
|
||||
warnings.warn('Text string is longer than 140 symbols.')
|
||||
|
||||
if symbol_to_id is None:
|
||||
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
|
||||
text = text.lower()
|
||||
text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
if text[-1] not in ['.', '!', '?']:
|
||||
text = text + '.'
|
||||
text = text + symbols[1]
|
||||
|
||||
text_ohe = [symbol_to_id[s] for s in text if s in symbols]
|
||||
text_tensor = torch.LongTensor(text_ohe)
|
||||
return text_tensor
|
||||
|
||||
def prepare_tts_model_input(self, text: str or list, symbols: str, speakers: list):
|
||||
assert len(speakers) == len(text) or len(speakers) == 1
|
||||
if type(text) == str:
|
||||
text = [text]
|
||||
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
if len(text) == 1:
|
||||
return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])
|
||||
|
||||
text_tensors = []
|
||||
for string in text:
|
||||
string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
|
||||
text_tensors.append(string_tensor)
|
||||
input_lengths, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([len(t) for t in text_tensors]),
|
||||
dim=0, descending=True)
|
||||
max_input_len = input_lengths[0]
|
||||
batch_size = len(text_tensors)
|
||||
|
||||
text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
|
||||
if len(speakers) == 1:
|
||||
speakers = speakers*batch_size
|
||||
speaker_ids = torch.LongTensor(batch_size).zero_()
|
||||
|
||||
for i, idx in enumerate(ids_sorted_decreasing):
|
||||
text_tensor = text_tensors[idx]
|
||||
in_len = text_tensor.size(0)
|
||||
text_padded[i, :in_len] = text_tensor
|
||||
speaker_ids[i] = speakers[idx]
|
||||
|
||||
return text_padded, speaker_ids, ids_sorted_decreasing
|
||||
|
||||
def process_tts_model_output(self, out, out_lens, ids):
|
||||
out = out.to('cpu')
|
||||
out_lens = out_lens.to('cpu')
|
||||
_, orig_ids = ids.sort()
|
||||
|
||||
proc_outs = []
|
||||
orig_out = out.index_select(0, orig_ids)
|
||||
orig_out_lens = out_lens.index_select(0, orig_ids)
|
||||
|
||||
for i, out_len in enumerate(orig_out_lens):
|
||||
proc_outs.append(orig_out[i][:out_len])
|
||||
return proc_outs
|
||||
|
||||
def to(self, device):
|
||||
self.model = self.model.to(device)
|
||||
self.device = device
|
||||
|
||||
def get_speakers(self, speakers: str or list):
|
||||
if type(speakers) == str:
|
||||
speakers = [speakers]
|
||||
speaker_ids = []
|
||||
for speaker in speakers:
|
||||
try:
|
||||
speaker_id = self.speaker_to_id[speaker]
|
||||
speaker_ids.append(speaker_id)
|
||||
except Exception:
|
||||
raise ValueError(f'No such speaker: {speaker}')
|
||||
return speaker_ids
|
||||
|
||||
def apply_tts(self, texts: str or list,
|
||||
speakers: str or list,
|
||||
sample_rate: int = 16000):
|
||||
speaker_ids = self.get_speakers(speakers)
|
||||
text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
|
||||
symbols=self.symbols,
|
||||
speakers=speaker_ids)
|
||||
with torch.inference_mode():
|
||||
out, out_lens = self.model(text_padded.to(self.device),
|
||||
speaker_ids.to(self.device),
|
||||
sr=sample_rate)
|
||||
audios = self.process_tts_model_output(out, out_lens, orig_ids)
|
||||
return audios
|
||||
|
||||
@staticmethod
|
||||
def write_wave(path, audio, sample_rate):
|
||||
"""Writes a .wav file.
|
||||
Takes path, PCM audio data, and sample rate.
|
||||
"""
|
||||
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio)
|
||||
|
||||
def save_wav(self, texts: str or list,
|
||||
speakers: str or list,
|
||||
audio_pathes: str or list = '',
|
||||
sample_rate: int = 16000):
|
||||
if type(texts) == str:
|
||||
texts = [texts]
|
||||
|
||||
if not audio_pathes:
|
||||
audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
|
||||
if type(audio_pathes) == str:
|
||||
audio_pathes = [audio_pathes]
|
||||
assert len(audio_pathes) == len(texts)
|
||||
|
||||
audio = self.apply_tts(texts=texts,
|
||||
speakers=speakers,
|
||||
sample_rate=sample_rate)
|
||||
for i, _audio in enumerate(audio):
|
||||
self.write_wave(path=audio_pathes[i],
|
||||
audio=(_audio * 32767).numpy().astype('int16'),
|
||||
sample_rate=sample_rate)
|
||||
return audio_pathes
|
@ -1,6 +1,6 @@
|
||||
from loguru import logger
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class cogErrorHandlers:
|
||||
|
@ -1,9 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import datetime
|
||||
import time
|
||||
|
||||
import discord
|
||||
from discord.ext import commands, tasks
|
||||
from discord.ext.commands import Context
|
||||
import datetime
|
||||
import time
|
||||
|
||||
|
||||
class BotInformation(commands.Cog):
|
||||
@ -15,9 +16,9 @@ class BotInformation(commands.Cog):
|
||||
@commands.command('info')
|
||||
async def get_info(self, ctx: Context):
|
||||
info = f"""
|
||||
Text-To-Speech bot, based on Silero TTS model (<https://github.com/snakers4/silero-models>)
|
||||
Text-To-Speech bot
|
||||
License: `GNU GENERAL PUBLIC LICENSE Version 3`
|
||||
Author discord: `a31#6403`
|
||||
Author discord: `@a31`
|
||||
Author email: `a31@demb.uk`
|
||||
Source code on github: <https://github.com/norohind/SileroTTSBot>
|
||||
Source code on gitea a31's instance: <https://gitea.demb.uk/a31/SileroTTSBot>
|
||||
|
@ -1,15 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import time
|
||||
from formatting import format_table
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
import rfoo
|
||||
from rfoo.utils import rconsole
|
||||
try:
|
||||
import rfoo
|
||||
from rfoo.utils import rconsole
|
||||
|
||||
except ImportError as e:
|
||||
print("not importing rfoo", e)
|
||||
|
||||
from threading import Thread
|
||||
from typing import Optional, Coroutine
|
||||
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
from loguru import logger
|
||||
|
||||
from formatting import format_table
|
||||
|
||||
|
||||
class SingletonBase:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
@ -24,7 +31,7 @@ class BotManagement(commands.Cog, SingletonBase):
|
||||
instance: Optional['BotManagement'] = None
|
||||
|
||||
rfoo_server_thread: Optional[Thread] = None
|
||||
rfoo_server: Optional[rfoo.InetServer] = None
|
||||
rfoo_server: Optional['rfoo.InetServer'] = None
|
||||
|
||||
# def __new__(cls, *args, **kwargs):
|
||||
# if cls.instance is None:
|
||||
@ -105,7 +112,7 @@ class BotManagement(commands.Cog, SingletonBase):
|
||||
@commands.command('shutdown')
|
||||
async def shutdown(self, ctx: Context) -> None:
|
||||
"""Shutdown bot with hope docker daemon will restart it, `@a31` and `@furiz__` has rights for it"""
|
||||
if ctx.author.id in (420130693696323585, # @a31
|
||||
if ctx.author.id in (420130693696323585, # @a31
|
||||
444819880781545472): # @furiz__
|
||||
|
||||
log_msg = f"Got shutdown command by {ctx.author}"
|
||||
@ -116,7 +123,7 @@ class BotManagement(commands.Cog, SingletonBase):
|
||||
await (await self.bot.fetch_user(420130693696323585)).send(log_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(f'Failed to send shutdown message',)
|
||||
logger.opt(exception=e).warning(f'Failed to send shutdown message', )
|
||||
|
||||
self.bot.loop.create_task(self.bot.close())
|
||||
|
||||
@ -127,7 +134,7 @@ class BotManagement(commands.Cog, SingletonBase):
|
||||
async def setup(bot: commands.Bot):
|
||||
await bot.add_cog(BotManagement(bot))
|
||||
|
||||
|
||||
async def teardown(bot):
|
||||
stop_res = BotManagement(bot).stop_rfoo()
|
||||
logger.info(f'Unloaded rfoo with result {stop_res} during BotManagement unload')
|
||||
|
||||
|
134
cogs/TTSCore.py
134
cogs/TTSCore.py
@ -2,18 +2,21 @@
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Union, Optional
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
import discord
|
||||
import DB
|
||||
from typing import Union, Optional
|
||||
from loguru import logger
|
||||
from TTSSilero import TTSSileroCached
|
||||
from TTSSilero import Speakers
|
||||
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
||||
|
||||
import DB
|
||||
import Observ
|
||||
import formatting
|
||||
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
||||
from SpeakersSettingsAdapterDiscord import SpeakersSettingsAdapterDiscord
|
||||
from TTSServer import TTSServerCached
|
||||
from TTSServer.Speakers import ModelDescription
|
||||
from cogErrorHandlers import cogErrorHandlers
|
||||
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
|
||||
|
||||
|
||||
class TTSCore(commands.Cog, Observ.Observer):
|
||||
@ -21,9 +24,9 @@ class TTSCore(commands.Cog, Observ.Observer):
|
||||
self.bot = bot
|
||||
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
||||
self.bot.subscribe(self) # subscribe for messages that aren't commands
|
||||
self.tts = TTSSileroCached()
|
||||
self.tts = TTSServerCached()
|
||||
self.tts_queues: dict[int, list[discord.AudioSource]] = defaultdict(list)
|
||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
|
||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = SpeakersSettingsAdapterDiscord()
|
||||
|
||||
@commands.command('drop')
|
||||
async def drop_queue(self, ctx: Context):
|
||||
@ -89,25 +92,28 @@ class TTSCore(commands.Cog, Observ.Observer):
|
||||
await message.channel.send('We are in different voice channels')
|
||||
return
|
||||
|
||||
speaker: Speakers = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
|
||||
speaker_str: str = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
|
||||
speaker: ModelDescription | None = await self.tts.speaker_to_description(speaker_str)
|
||||
if speaker is None:
|
||||
speaker = self.tts.default_speaker()
|
||||
|
||||
# check if message will fail on synthesis
|
||||
if DB.SynthesisErrors.select()\
|
||||
.where(DB.SynthesisErrors.speaker == speaker.value)\
|
||||
.where(DB.SynthesisErrors.text == message.content)\
|
||||
if DB.SynthesisErrors.select() \
|
||||
.where(DB.SynthesisErrors.speaker == speaker_str) \
|
||||
.where(DB.SynthesisErrors.text == message.content) \
|
||||
.count() == 1:
|
||||
# Then we will not try to synthesis it
|
||||
await message.channel.send(f"I will not synthesis this message due to TTS engine limitations")
|
||||
return
|
||||
|
||||
try:
|
||||
wav_file_like_object = self.tts.synthesize_text(message.content, speaker=speaker)
|
||||
wav_file_like_object: bytes = await self.tts.synthesize_text(message.content, speaker=speaker)
|
||||
sound_source = FFmpegPCMAudio(wav_file_like_object, pipe=True, stderr=subprocess.PIPE)
|
||||
|
||||
except Exception as synth_exception:
|
||||
logger.opt(exception=True).warning(f'Exception on synthesize {message.content!r}: {synth_exception}')
|
||||
await message.channel.send(f'Synthesize error')
|
||||
DB.SynthesisErrors.create(speaker=speaker.value, text=message.content)
|
||||
DB.SynthesisErrors.create(speaker=speaker, text=message.content)
|
||||
return
|
||||
|
||||
else:
|
||||
@ -115,13 +121,15 @@ class TTSCore(commands.Cog, Observ.Observer):
|
||||
if voice_client.is_playing():
|
||||
# Then we need to enqueue prepared sound for playing through self.tts_queues mechanism
|
||||
self.tts_queues[message.guild.id].append(sound_source)
|
||||
await message.channel.send(f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}")
|
||||
await message.channel.send(
|
||||
f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}")
|
||||
return
|
||||
|
||||
voice_client.play(sound_source, after=lambda e: self.queue_player(message))
|
||||
|
||||
except Exception as play_exception:
|
||||
logger.opt(exception=True).warning(f'Exception on playing for: {message.guild.name}[#{message.channel.name}]: {message.author.display_name} / {play_exception}')
|
||||
logger.opt(exception=True).warning(
|
||||
f'Exception on playing for: {message.guild.name}[#{message.channel.name}]: {message.author.display_name} / {play_exception}')
|
||||
await message.channel.send(f'Playing error')
|
||||
return
|
||||
|
||||
@ -151,7 +159,8 @@ class TTSCore(commands.Cog, Observ.Observer):
|
||||
pass
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState):
|
||||
async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState,
|
||||
after: discord.VoiceState):
|
||||
if after.channel is None:
|
||||
members = before.channel.members
|
||||
if len(members) == 1:
|
||||
@ -160,6 +169,95 @@ class TTSCore(commands.Cog, Observ.Observer):
|
||||
|
||||
# TODO: leave voice channel after being moved there alone
|
||||
|
||||
# Ex TTSSettings.py
|
||||
@commands.command('getAllSpeakers')
|
||||
async def get_speakers(self, ctx: Context):
|
||||
"""
|
||||
Enumerate all available to set speakers
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
|
||||
await ctx.send(
|
||||
formatting.format_table(
|
||||
tuple((model.engine + '_' + model.name, model.description)
|
||||
for model in await self.tts.discovery()),
|
||||
('model', 'description')
|
||||
))
|
||||
|
||||
@commands.command('setPersonalSpeaker')
|
||||
async def set_user_speaker(self, ctx: Context, speaker: str):
|
||||
"""
|
||||
Set personal speaker on this server
|
||||
|
||||
:param ctx:
|
||||
:param speaker:
|
||||
:return:
|
||||
"""
|
||||
|
||||
if await self.tts.speaker_to_description(speaker) is None:
|
||||
return await ctx.send(
|
||||
f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||
|
||||
self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, speaker)
|
||||
await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`')
|
||||
|
||||
@commands.command('setServerSpeaker')
|
||||
async def set_server_speaker(self, ctx: Context, speaker: str):
|
||||
"""
|
||||
Set global server speaker
|
||||
|
||||
:param ctx:
|
||||
:param speaker:
|
||||
:return:
|
||||
"""
|
||||
if await self.tts.speaker_to_description(speaker) is None:
|
||||
return await ctx.send(
|
||||
f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||
|
||||
self.speakers_adapter.set_speaker_global(ctx.guild.id, ctx.author.id, speaker)
|
||||
await ctx.reply(f'Successfully set **your personal** speaker to `{speaker}`')
|
||||
|
||||
@commands.command('getSpeaker')
|
||||
async def get_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell first appropriate speaker for a user, it can be user specified, server specified or server default
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)
|
||||
|
||||
await ctx.reply(f'Your current speaker is `{speaker}`')
|
||||
|
||||
@commands.command('getPersonalSpeaker')
|
||||
async def get_personal_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
|
||||
if speaker is None:
|
||||
server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
||||
await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")
|
||||
|
||||
else:
|
||||
await ctx.reply(f"Your personal speaker is `{speaker}`")
|
||||
|
||||
@commands.command('getServerSpeaker')
|
||||
async def get_server_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell server global speaker
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
||||
await ctx.send(f"Current server speaker is `{speaker}`")
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(TTSCore(bot))
|
||||
|
@ -1,102 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
from TTSSilero import Speakers
|
||||
from cogErrorHandlers import cogErrorHandlers
|
||||
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
|
||||
|
||||
|
||||
class TTSSettings(commands.Cog):
|
||||
def __init__(self, bot: commands.Bot):
|
||||
self.bot = bot
|
||||
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
|
||||
|
||||
@commands.command('getAllSpeakers')
|
||||
async def get_speakers(self, ctx: Context):
|
||||
"""
|
||||
Enumerate all available to set speakers
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speakers = '\n'.join(self.speakers_adapter.available_speakers)
|
||||
|
||||
await ctx.send(f"```\n{speakers}```")
|
||||
|
||||
@commands.command('setPersonalSpeaker')
|
||||
async def set_user_speaker(self, ctx: Context, speaker: str):
|
||||
"""
|
||||
Set personal speaker on this server
|
||||
|
||||
:param ctx:
|
||||
:param speaker:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
checked_speaker: Speakers = Speakers(speaker)
|
||||
self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, checked_speaker)
|
||||
await ctx.reply(f'Successfully set **your personal** speaker to `{checked_speaker.value}`')
|
||||
|
||||
except (KeyError, ValueError):
|
||||
await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||
|
||||
@commands.command('setServerSpeaker')
|
||||
async def set_server_speaker(self, ctx: Context, speaker: str):
|
||||
"""
|
||||
Set global server speaker
|
||||
|
||||
:param ctx:
|
||||
:param speaker:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
checked_speaker: Speakers = Speakers(speaker)
|
||||
self.speakers_adapter.set_speaker_global(ctx.guild.id, checked_speaker)
|
||||
await ctx.send(f'Successfully set **server** speaker to `{checked_speaker.value}`')
|
||||
|
||||
except (KeyError, ValueError):
|
||||
await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
||||
|
||||
@commands.command('getSpeaker')
|
||||
async def get_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell first appropriate speaker for a user, it can be user specified, server specified or server default
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)
|
||||
|
||||
await ctx.reply(f'Your current speaker is `{speaker.value}`')
|
||||
|
||||
@commands.command('getPersonalSpeaker')
|
||||
async def get_personal_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
|
||||
if speaker is None:
|
||||
server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id).value
|
||||
await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")
|
||||
|
||||
else:
|
||||
await ctx.reply(f"Your personal speaker is `{speaker.value}`")
|
||||
|
||||
@commands.command('getServerSpeaker')
|
||||
async def get_server_speaker(self, ctx: Context):
|
||||
"""
|
||||
Tell server global speaker
|
||||
|
||||
:param ctx:
|
||||
:return:
|
||||
"""
|
||||
speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
||||
await ctx.send(f"Current server speaker is `{speaker.value}`")
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
await bot.add_cog(TTSSettings(bot))
|
@ -1,8 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
import DB
|
||||
from loguru import logger
|
||||
|
||||
import DB
|
||||
import DynamicCommandPrefix
|
||||
from cogErrorHandlers import cogErrorHandlers
|
||||
|
||||
|
5
config.py
Normal file
5
config.py
Normal file
@ -0,0 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
DB_PATH = Path(os.getenv('DATA_DIR', '.')) / 'voice_cache.sqlite'
|
||||
BASE_URL = os.environ['BASE_URL']
|
@ -11,7 +11,7 @@ def format_table(data: Iterable[Iterable[str]], header: Iterable[str] = MISSING)
|
||||
|
||||
result = '```\n'
|
||||
for row in data:
|
||||
row = [item.replace('`', '\\`') for item in row]
|
||||
row = [str(item).replace('`', '\\`') for item in row]
|
||||
result += '\t'.join(row) + '\n'
|
||||
|
||||
result += '```'
|
||||
|
13
main.py
13
main.py
@ -1,12 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
from discord.ext import commands
|
||||
import discord
|
||||
import signal
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from loguru import logger
|
||||
from DynamicCommandPrefix import dynamic_command_prefix
|
||||
|
||||
import Observ
|
||||
from DynamicCommandPrefix import dynamic_command_prefix
|
||||
|
||||
LOG_FILE_ENABLED = os.getenv('LOG_FILE_ENABLED', 'true').lower() == 'true'
|
||||
|
||||
@ -68,6 +70,7 @@ async def main():
|
||||
|
||||
await discord_client.start(os.environ['DISCORD_TOKEN'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(main())
|
||||
|
@ -4,6 +4,4 @@ peewee==3.17.5
|
||||
PyNaCl==1.5.0
|
||||
git+https://github.com/aaiyer/rfoo.git@1555bd4eed204bb6a33a5e313146a6c2813cfe91
|
||||
# Cython is setup dependency for rfoo
|
||||
--index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.3.0+cpu
|
||||
numpy==1.26.4
|
||||
pydantic
|
Loading…
x
Reference in New Issue
Block a user