Compare commits
25 Commits
Author | SHA1 | Date | |
---|---|---|---|
db62a98f0a | |||
c8d32959bb | |||
b103494f2a | |||
ed2dfa0133 | |||
67d2d09fa3 | |||
763a6a247d | |||
155df0d9f3 | |||
d6d416cd1a | |||
71e18a595f | |||
d23ef5b2fa | |||
3dc3431801 | |||
236e24af4a | |||
92303b9a3a | |||
e538b6c23b | |||
f431733940 | |||
e9ef1aca3c | |||
4eb68360ed | |||
eba6a289f8 | |||
e041913d5b | |||
bce9abd10a | |||
c84831aac2 | |||
10b280d6dc | |||
84182d0d28 | |||
c6c5d9c9a7 | |||
4b9f1cbad1 |
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
.idea
|
||||||
|
*__pycache__
|
||||||
|
*.sqlite
|
||||||
|
*.log
|
||||||
|
model_multi.pt
|
4
DB.py
4
DB.py
@ -2,7 +2,9 @@ from datetime import datetime
|
|||||||
|
|
||||||
import peewee
|
import peewee
|
||||||
|
|
||||||
database = peewee.SqliteDatabase('voice_cache.sqlite')
|
import config
|
||||||
|
|
||||||
|
database = peewee.SqliteDatabase(str(config.DB_PATH))
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(peewee.Model):
|
class BaseModel(peewee.Model):
|
||||||
|
33
Dockerfile
Normal file
33
Dockerfile
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
FROM python:3.10-slim as builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE 1
|
||||||
|
ENV PYTHONUNBUFFERED 1
|
||||||
|
|
||||||
|
RUN apt update && apt install git gcc libc6-dev -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/*
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip pip install Cython && \
|
||||||
|
pip wheel --no-deps --wheel-dir /app/wheels -r requirements.txt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FROM python:3.10-slim
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE 1
|
||||||
|
ENV PYTHONUNBUFFERED 1
|
||||||
|
|
||||||
|
RUN useradd -ms /bin/bash silero_user && \
|
||||||
|
apt update && apt install ffmpeg -y --no-install-recommends && apt clean && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /app/wheels /wheels
|
||||||
|
COPY --from=builder /app/requirements.txt .
|
||||||
|
|
||||||
|
RUN pip install --no-cache /wheels/*
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
USER silero_user
|
||||||
|
|
||||||
|
CMD ["python3", "/app/main.py"]
|
@ -1,8 +1,11 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
|
||||||
import peewee
|
import peewee
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
import DB
|
import DB
|
||||||
|
|
||||||
|
|
||||||
# from loguru import logger
|
# from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import subprocess
|
|
||||||
import shlex
|
|
||||||
import io
|
import io
|
||||||
from discord.opus import Encoder
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
from discord.opus import Encoder
|
||||||
|
|
||||||
|
|
||||||
# Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571
|
# Huge thanks to https://github.com/Rapptz/discord.py/issues/5192#issuecomment-669515571
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
|
|
||||||
class Observer:
|
class Observer:
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(self, message: discord.Message) -> None:
|
async def update(self, message: discord.Message) -> None:
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import discord
|
import discord
|
||||||
|
|
||||||
from .Observer import Observer
|
from .Observer import Observer
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,9 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import DB
|
import DB
|
||||||
from TTSSilero import Speakers
|
|
||||||
|
|
||||||
|
|
||||||
class SpeakersSettingsAdapterDiscord:
|
class SpeakersSettingsAdapterDiscord:
|
||||||
DEFAULT_SPEAKER = Speakers.kseniya
|
def get_speaker(self, guild_id: int, user_id: int) -> str:
|
||||||
|
|
||||||
def get_speaker(self, guild_id: int, user_id: int) -> Speakers:
|
|
||||||
user_defined_speaker = self.get_speaker_user(guild_id, user_id)
|
user_defined_speaker = self.get_speaker_user(guild_id, user_id)
|
||||||
if user_defined_speaker is None:
|
if user_defined_speaker is None:
|
||||||
return self.get_speaker_global(guild_id)
|
return self.get_speaker_global(guild_id)
|
||||||
@ -14,36 +11,29 @@ class SpeakersSettingsAdapterDiscord:
|
|||||||
else:
|
else:
|
||||||
return user_defined_speaker
|
return user_defined_speaker
|
||||||
|
|
||||||
def get_speaker_global(self, guild_id: int) -> Speakers:
|
def get_speaker_global(self, guild_id: int) -> str | None:
|
||||||
server_speaker_query = DB.ServerSpeaker.select()\
|
server_speaker_query = DB.ServerSpeaker.select()\
|
||||||
.where(DB.ServerSpeaker.server_id == guild_id)
|
.where(DB.ServerSpeaker.server_id == guild_id)
|
||||||
|
|
||||||
if server_speaker_query.count() == 1:
|
if server_speaker_query.count() == 1:
|
||||||
return Speakers(server_speaker_query.get().speaker)
|
return server_speaker_query.get().speaker
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return self.DEFAULT_SPEAKER
|
return None
|
||||||
|
|
||||||
def get_speaker_user(self, guild_id: int, user_id: int) -> Speakers | None:
|
def get_speaker_user(self, guild_id: int, user_id: int) -> str | None:
|
||||||
user_speaker_query = DB.UserServerSpeaker.select()\
|
user_speaker_query = DB.UserServerSpeaker.select()\
|
||||||
.where(DB.UserServerSpeaker.server_id == guild_id)\
|
.where(DB.UserServerSpeaker.server_id == guild_id)\
|
||||||
.where(DB.UserServerSpeaker.user_id == user_id)
|
.where(DB.UserServerSpeaker.user_id == user_id)
|
||||||
|
|
||||||
if user_speaker_query.count() == 1:
|
if user_speaker_query.count() == 1:
|
||||||
return Speakers(user_speaker_query.get().speaker)
|
return user_speaker_query.get().speaker
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
def set_speaker_user(self, guild_id: int, user_id: int, speaker: str) -> None:
|
||||||
def available_speakers(self) -> set[str]:
|
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker).execute()
|
||||||
return {speaker.name for speaker in Speakers}
|
|
||||||
|
|
||||||
def set_speaker_user(self, guild_id: int, user_id: int, speaker: Speakers) -> None:
|
def set_speaker_global(self, guild_id: int, user_id: int, speaker: str) -> None:
|
||||||
DB.UserServerSpeaker.replace(server_id=guild_id, user_id=user_id, speaker=speaker.value).execute()
|
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker).execute()
|
||||||
|
|
||||||
def set_speaker_global(self, guild_id: int, speaker: Speakers) -> None:
|
|
||||||
DB.ServerSpeaker.replace(server_id=guild_id, speaker=speaker.value).execute()
|
|
||||||
|
|
||||||
|
|
||||||
speakers_settings_adapter = SpeakersSettingsAdapterDiscord()
|
|
||||||
|
15
TTSServer/Speakers.py
Normal file
15
TTSServer/Speakers.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Argument(BaseModel):
|
||||||
|
type: Literal['str', 'int', 'float']
|
||||||
|
description: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDescription(BaseModel):
|
||||||
|
engine: str
|
||||||
|
name: str
|
||||||
|
arguments: dict[str, Argument]
|
||||||
|
description: None | str = None
|
19
TTSServer/TTSServer.py
Normal file
19
TTSServer/TTSServer.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import aiohttp
|
||||||
|
|
||||||
|
import config
|
||||||
|
from .Speakers import ModelDescription
|
||||||
|
|
||||||
|
|
||||||
|
class TTSServer:
|
||||||
|
def __init__(self):
|
||||||
|
self.session = aiohttp.ClientSession(base_url=config.BASE_URL)
|
||||||
|
|
||||||
|
async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
|
||||||
|
async with self.session.post(url=f'/synth/{speaker.engine}/model/{speaker.name}', json={'text': text}) as req:
|
||||||
|
req.raise_for_status()
|
||||||
|
return await req.content.read()
|
||||||
|
|
||||||
|
async def discovery(self) -> list[ModelDescription]:
|
||||||
|
async with self.session.get('/discovery') as req:
|
||||||
|
req.raise_for_status()
|
||||||
|
return [ModelDescription(**item) for item in await req.json()]
|
53
TTSServer/TTSServerCached.py
Normal file
53
TTSServer/TTSServerCached.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
import DB
|
||||||
|
from .Speakers import ModelDescription
|
||||||
|
from .TTSServer import TTSServer
|
||||||
|
|
||||||
|
|
||||||
|
class TTSServerCached(TTSServer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.models_lookup_idx: dict[str, ModelDescription] = dict()
|
||||||
|
|
||||||
|
async def synthesize_text(self, text: str, speaker: ModelDescription) -> bytes:
|
||||||
|
cache_query = DB.SoundCache.select() \
|
||||||
|
.where(DB.SoundCache.text == text) \
|
||||||
|
.where(DB.SoundCache.speaker == speaker)
|
||||||
|
|
||||||
|
if cache_query.count() == 1:
|
||||||
|
with DB.database.atomic():
|
||||||
|
DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1}) \
|
||||||
|
.where(DB.SoundCache.text == text) \
|
||||||
|
.where(DB.SoundCache.speaker == speaker).execute()
|
||||||
|
|
||||||
|
cached = cache_query.get().audio
|
||||||
|
|
||||||
|
return cached
|
||||||
|
|
||||||
|
else:
|
||||||
|
synthesized = await super().synthesize_text(text, speaker)
|
||||||
|
DB.SoundCache.create(text=text, speaker=speaker, audio=synthesized)
|
||||||
|
return synthesized
|
||||||
|
|
||||||
|
async def discovery(self) -> list[ModelDescription]:
|
||||||
|
res = await super().discovery()
|
||||||
|
logger.debug(f'Discovered {len(res)} models')
|
||||||
|
self.models_lookup_idx: dict[str, ModelDescription] = {f'{desc.engine}_{desc.name}': desc for desc in res}
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def speaker_to_description(self, speaker_str: str) -> ModelDescription | None:
|
||||||
|
for attempt in range(2):
|
||||||
|
if len(self.models_lookup_idx) > 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
await self.discovery()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError('Models discovery seems to return zero models')
|
||||||
|
|
||||||
|
# Default to first model
|
||||||
|
return self.models_lookup_idx.get(speaker_str, None)
|
||||||
|
|
||||||
|
def default_speaker(self) -> ModelDescription:
|
||||||
|
return list(self.models_lookup_idx.values())[0]
|
3
TTSServer/__init__.py
Normal file
3
TTSServer/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .Speakers import ModelDescription
|
||||||
|
from .TTSServer import TTSServer
|
||||||
|
from .TTSServerCached import TTSServerCached
|
@ -1,15 +0,0 @@
|
|||||||
import enum
|
|
||||||
|
|
||||||
|
|
||||||
class Speakers(enum.Enum):
|
|
||||||
aidar = 'aidar'
|
|
||||||
baya = 'baya'
|
|
||||||
kseniya = 'kseniya'
|
|
||||||
irina = 'irina'
|
|
||||||
ruslan = 'ruslan'
|
|
||||||
natasha = 'natasha'
|
|
||||||
thorsten = 'thorsten'
|
|
||||||
tux = 'tux'
|
|
||||||
gilles = 'gilles'
|
|
||||||
lj = 'lj'
|
|
||||||
dilyara = 'dilyara'
|
|
@ -1,61 +0,0 @@
|
|||||||
import contextlib
|
|
||||||
import os
|
|
||||||
import io
|
|
||||||
import wave
|
|
||||||
import torch.package
|
|
||||||
|
|
||||||
from .Speakers import Speakers
|
|
||||||
from .multi_v2_package import TTSModelMulti_v2
|
|
||||||
|
|
||||||
|
|
||||||
class TTSSilero:
|
|
||||||
def __init__(self, threads: int = 12):
|
|
||||||
device = torch.device('cpu')
|
|
||||||
torch.set_num_threads(threads)
|
|
||||||
local_file = 'model_multi.pt'
|
|
||||||
|
|
||||||
if not os.path.isfile(local_file):
|
|
||||||
torch.hub.download_url_to_file(
|
|
||||||
'https://models.silero.ai/models/tts/multi/v2_multi.pt',
|
|
||||||
local_file
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
|
|
||||||
self.model.to(device)
|
|
||||||
|
|
||||||
self.sample_rate = 16000
|
|
||||||
|
|
||||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
|
||||||
return self.to_wav(self._synthesize_text(text, speaker))
|
|
||||||
|
|
||||||
def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Performs splitting text and synthesizing it
|
|
||||||
|
|
||||||
:param text:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
|
|
||||||
results_list: list[torch.Tensor] = self.model.apply_tts(
|
|
||||||
texts=[text],
|
|
||||||
speakers=speaker.value,
|
|
||||||
sample_rate=self.sample_rate
|
|
||||||
)
|
|
||||||
|
|
||||||
return results_list
|
|
||||||
|
|
||||||
def to_wav(self, synthesized_text: list[torch.Tensor]) -> bytes:
|
|
||||||
res_io_stream = io.BytesIO()
|
|
||||||
|
|
||||||
with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
|
|
||||||
wf.setnchannels(1)
|
|
||||||
wf.setsampwidth(2)
|
|
||||||
wf.setframerate(self.sample_rate)
|
|
||||||
for result in synthesized_text:
|
|
||||||
wf.writeframes((result * 32767).numpy().astype('int16'))
|
|
||||||
|
|
||||||
res_io_stream.seek(0)
|
|
||||||
|
|
||||||
return res_io_stream.read()
|
|
||||||
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from .TTSSilero import TTSSilero
|
|
||||||
from .Speakers import Speakers
|
|
||||||
import DB
|
|
||||||
import sqlite3
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
|
|
||||||
class TTSSileroCached(TTSSilero):
|
|
||||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
|
||||||
# start = time.time()
|
|
||||||
cache_query = DB.SoundCache.select()\
|
|
||||||
.where(DB.SoundCache.text == text)\
|
|
||||||
.where(DB.SoundCache.speaker == speaker.value)
|
|
||||||
|
|
||||||
if cache_query.count() == 1:
|
|
||||||
with DB.database.atomic():
|
|
||||||
DB.SoundCache.update({DB.SoundCache.usages: DB.SoundCache.usages + 1})\
|
|
||||||
.where(DB.SoundCache.text == text)\
|
|
||||||
.where(DB.SoundCache.speaker == speaker.value).execute()
|
|
||||||
|
|
||||||
cached = cache_query.get().audio
|
|
||||||
|
|
||||||
return cached
|
|
||||||
|
|
||||||
else:
|
|
||||||
# logger.debug(f'Starting synthesis')
|
|
||||||
# start2 = time.time()
|
|
||||||
synthesized = super().synthesize_text(text, speaker)
|
|
||||||
# logger.debug(f'Synthesis done in {time.time() - start2} s in {time.time() - start} s after start')
|
|
||||||
DB.SoundCache.create(text=text, speaker=speaker.value, audio=synthesized)
|
|
||||||
|
|
||||||
# logger.debug(f'Cache set in {time.time() - start2} synth end and {time.time() - start2} s after start')
|
|
||||||
return synthesized
|
|
@ -1,3 +0,0 @@
|
|||||||
from .TTSSilero import TTSSilero
|
|
||||||
from .Speakers import Speakers
|
|
||||||
from .TTSSileroCached import TTSSileroCached
|
|
@ -1,148 +0,0 @@
|
|||||||
import re
|
|
||||||
import wave
|
|
||||||
import torch
|
|
||||||
import warnings
|
|
||||||
import contextlib
|
|
||||||
|
|
||||||
# for type hints only
|
|
||||||
|
|
||||||
|
|
||||||
class TTSModelMulti_v2():
|
|
||||||
def __init__(self, model_path, symbols):
|
|
||||||
self.model = self.init_jit_model(model_path)
|
|
||||||
self.symbols = symbols
|
|
||||||
self.device = torch.device('cpu')
|
|
||||||
speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
|
|
||||||
'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
|
|
||||||
self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}
|
|
||||||
|
|
||||||
def init_jit_model(self, model_path: str):
|
|
||||||
torch.set_grad_enabled(False)
|
|
||||||
model = torch.jit.load(model_path, map_location='cpu')
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
def prepare_text_input(self, text, symbols, symbol_to_id=None):
|
|
||||||
if len(text) > 140:
|
|
||||||
warnings.warn('Text string is longer than 140 symbols.')
|
|
||||||
|
|
||||||
if symbol_to_id is None:
|
|
||||||
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
|
||||||
|
|
||||||
text = text.lower()
|
|
||||||
text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
|
|
||||||
text = re.sub(r'\s+', ' ', text).strip()
|
|
||||||
if text[-1] not in ['.', '!', '?']:
|
|
||||||
text = text + '.'
|
|
||||||
text = text + symbols[1]
|
|
||||||
|
|
||||||
text_ohe = [symbol_to_id[s] for s in text if s in symbols]
|
|
||||||
text_tensor = torch.LongTensor(text_ohe)
|
|
||||||
return text_tensor
|
|
||||||
|
|
||||||
def prepare_tts_model_input(self, text: str or list, symbols: str, speakers: list):
|
|
||||||
assert len(speakers) == len(text) or len(speakers) == 1
|
|
||||||
if type(text) == str:
|
|
||||||
text = [text]
|
|
||||||
symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
|
||||||
if len(text) == 1:
|
|
||||||
return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])
|
|
||||||
|
|
||||||
text_tensors = []
|
|
||||||
for string in text:
|
|
||||||
string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
|
|
||||||
text_tensors.append(string_tensor)
|
|
||||||
input_lengths, ids_sorted_decreasing = torch.sort(
|
|
||||||
torch.LongTensor([len(t) for t in text_tensors]),
|
|
||||||
dim=0, descending=True)
|
|
||||||
max_input_len = input_lengths[0]
|
|
||||||
batch_size = len(text_tensors)
|
|
||||||
|
|
||||||
text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
|
|
||||||
if len(speakers) == 1:
|
|
||||||
speakers = speakers*batch_size
|
|
||||||
speaker_ids = torch.LongTensor(batch_size).zero_()
|
|
||||||
|
|
||||||
for i, idx in enumerate(ids_sorted_decreasing):
|
|
||||||
text_tensor = text_tensors[idx]
|
|
||||||
in_len = text_tensor.size(0)
|
|
||||||
text_padded[i, :in_len] = text_tensor
|
|
||||||
speaker_ids[i] = speakers[idx]
|
|
||||||
|
|
||||||
return text_padded, speaker_ids, ids_sorted_decreasing
|
|
||||||
|
|
||||||
def process_tts_model_output(self, out, out_lens, ids):
|
|
||||||
out = out.to('cpu')
|
|
||||||
out_lens = out_lens.to('cpu')
|
|
||||||
_, orig_ids = ids.sort()
|
|
||||||
|
|
||||||
proc_outs = []
|
|
||||||
orig_out = out.index_select(0, orig_ids)
|
|
||||||
orig_out_lens = out_lens.index_select(0, orig_ids)
|
|
||||||
|
|
||||||
for i, out_len in enumerate(orig_out_lens):
|
|
||||||
proc_outs.append(orig_out[i][:out_len])
|
|
||||||
return proc_outs
|
|
||||||
|
|
||||||
def to(self, device):
|
|
||||||
self.model = self.model.to(device)
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
def get_speakers(self, speakers: str or list):
|
|
||||||
if type(speakers) == str:
|
|
||||||
speakers = [speakers]
|
|
||||||
speaker_ids = []
|
|
||||||
for speaker in speakers:
|
|
||||||
try:
|
|
||||||
speaker_id = self.speaker_to_id[speaker]
|
|
||||||
speaker_ids.append(speaker_id)
|
|
||||||
except Exception:
|
|
||||||
raise ValueError(f'No such speaker: {speaker}')
|
|
||||||
return speaker_ids
|
|
||||||
|
|
||||||
def apply_tts(self, texts: str or list,
|
|
||||||
speakers: str or list,
|
|
||||||
sample_rate: int = 16000):
|
|
||||||
speaker_ids = self.get_speakers(speakers)
|
|
||||||
text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
|
|
||||||
symbols=self.symbols,
|
|
||||||
speakers=speaker_ids)
|
|
||||||
with torch.inference_mode():
|
|
||||||
out, out_lens = self.model(text_padded.to(self.device),
|
|
||||||
speaker_ids.to(self.device),
|
|
||||||
sr=sample_rate)
|
|
||||||
audios = self.process_tts_model_output(out, out_lens, orig_ids)
|
|
||||||
return audios
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def write_wave(path, audio, sample_rate):
|
|
||||||
"""Writes a .wav file.
|
|
||||||
Takes path, PCM audio data, and sample rate.
|
|
||||||
"""
|
|
||||||
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
|
||||||
wf.setnchannels(1)
|
|
||||||
wf.setsampwidth(2)
|
|
||||||
wf.setframerate(sample_rate)
|
|
||||||
wf.writeframes(audio)
|
|
||||||
|
|
||||||
def save_wav(self, texts: str or list,
|
|
||||||
speakers: str or list,
|
|
||||||
audio_pathes: str or list = '',
|
|
||||||
sample_rate: int = 16000):
|
|
||||||
if type(texts) == str:
|
|
||||||
texts = [texts]
|
|
||||||
|
|
||||||
if not audio_pathes:
|
|
||||||
audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
|
|
||||||
if type(audio_pathes) == str:
|
|
||||||
audio_pathes = [audio_pathes]
|
|
||||||
assert len(audio_pathes) == len(texts)
|
|
||||||
|
|
||||||
audio = self.apply_tts(texts=texts,
|
|
||||||
speakers=speakers,
|
|
||||||
sample_rate=sample_rate)
|
|
||||||
for i, _audio in enumerate(audio):
|
|
||||||
self.write_wave(path=audio_pathes[i],
|
|
||||||
audio=(_audio * 32767).numpy().astype('int16'),
|
|
||||||
sample_rate=sample_rate)
|
|
||||||
return audio_pathes
|
|
@ -1,6 +1,6 @@
|
|||||||
from loguru import logger
|
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from discord.ext.commands import Context
|
from discord.ext.commands import Context
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class cogErrorHandlers:
|
class cogErrorHandlers:
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands, tasks
|
from discord.ext import commands, tasks
|
||||||
from discord.ext.commands import Context
|
from discord.ext.commands import Context
|
||||||
import datetime
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
class BotInformation(commands.Cog):
|
class BotInformation(commands.Cog):
|
||||||
@ -15,12 +16,12 @@ class BotInformation(commands.Cog):
|
|||||||
@commands.command('info')
|
@commands.command('info')
|
||||||
async def get_info(self, ctx: Context):
|
async def get_info(self, ctx: Context):
|
||||||
info = f"""
|
info = f"""
|
||||||
Text-To-Speech bot, based on Silero TTS model (<https://github.com/snakers4/silero-models>)
|
Text-To-Speech bot
|
||||||
License: `GNU GENERAL PUBLIC LICENSE Version 3`
|
License: `GNU GENERAL PUBLIC LICENSE Version 3`
|
||||||
Author discord: `a31#6403`
|
Author discord: `@a31`
|
||||||
Author email: `a31@demb.design`
|
Author email: `a31@demb.uk`
|
||||||
Source code on github: <https://github.com/norohind/SileroTTSBot>
|
Source code on github: <https://github.com/norohind/SileroTTSBot>
|
||||||
Source code on gitea's a31 instance: <https://gitea.demb.design/a31/SileroTTSBot>
|
Source code on gitea a31's instance: <https://gitea.demb.uk/a31/SileroTTSBot>
|
||||||
Invite link: https://discord.com/oauth2/authorize?client_id={self.bot.user.id}&scope=bot%20applications.commands
|
Invite link: https://discord.com/oauth2/authorize?client_id={self.bot.user.id}&scope=bot%20applications.commands
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
140
cogs/BotManagement.py
Normal file
140
cogs/BotManagement.py
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
try:
|
||||||
|
import rfoo
|
||||||
|
from rfoo.utils import rconsole
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
print("not importing rfoo", e)
|
||||||
|
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Optional, Coroutine
|
||||||
|
|
||||||
|
from discord.ext import commands
|
||||||
|
from discord.ext.commands import Context
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from formatting import format_table
|
||||||
|
|
||||||
|
|
||||||
|
class SingletonBase:
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
if getattr(cls, 'instance') is None:
|
||||||
|
cls.instance = super().__new__(cls, *args, **kwargs)
|
||||||
|
|
||||||
|
# logger.debug(id(getattr(cls, 'instance')))
|
||||||
|
return getattr(cls, 'instance')
|
||||||
|
|
||||||
|
|
||||||
|
class BotManagement(commands.Cog, SingletonBase):
|
||||||
|
instance: Optional['BotManagement'] = None
|
||||||
|
|
||||||
|
rfoo_server_thread: Optional[Thread] = None
|
||||||
|
rfoo_server: Optional['rfoo.InetServer'] = None
|
||||||
|
|
||||||
|
# def __new__(cls, *args, **kwargs):
|
||||||
|
# if cls.instance is None:
|
||||||
|
# cls.instance = super().__new__(cls, *args, **kwargs)
|
||||||
|
#
|
||||||
|
# return cls.instance
|
||||||
|
|
||||||
|
def __init__(self, bot: commands.Bot):
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
@commands.is_owner()
|
||||||
|
@commands.command('listServers')
|
||||||
|
async def list_servers(self, ctx: Context):
|
||||||
|
text_table = format_table(((server.name, str(server.id)) for server in self.bot.guilds), ('name', 'id'))
|
||||||
|
await ctx.channel.send(text_table)
|
||||||
|
|
||||||
|
@commands.is_owner()
|
||||||
|
@commands.command('listVoice')
|
||||||
|
async def list_voice_connections(self, ctx: Context):
|
||||||
|
text_table = format_table(((it.guild.name,) for it in self.bot.voice_clients))
|
||||||
|
await ctx.channel.send(text_table)
|
||||||
|
|
||||||
|
def start_rfoo(self) -> bool:
|
||||||
|
# True if started, False if already started
|
||||||
|
if self.rfoo_server_thread is None:
|
||||||
|
self.rfoo_server = rfoo.InetServer(rconsole.ConsoleHandler, {'bot': self.bot, 'ct': self.ct})
|
||||||
|
self.rfoo_server_thread = Thread(target=lambda: self.rfoo_server.start(rfoo.LOOPBACK, 54321))
|
||||||
|
self.rfoo_server_thread.daemon = True
|
||||||
|
self.rfoo_server_thread.start()
|
||||||
|
logger.info('Rfoo thread started by msg')
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop_rfoo(self) -> bool:
|
||||||
|
if self.rfoo_server_thread is not None:
|
||||||
|
self.rfoo_server.stop()
|
||||||
|
del self.rfoo_server_thread
|
||||||
|
logger.info('Rfoo thread stopped by msg')
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
@commands.is_owner()
|
||||||
|
@commands.command('rfooStart')
|
||||||
|
async def start(self, ctx: Context):
|
||||||
|
if self.start_rfoo():
|
||||||
|
await ctx.send('Rfoo thread started')
|
||||||
|
|
||||||
|
else:
|
||||||
|
await ctx.send('Rfoo thread already started')
|
||||||
|
|
||||||
|
@commands.is_owner()
|
||||||
|
@commands.command('rfooStop')
|
||||||
|
async def stop(self, ctx: Context):
|
||||||
|
if self.stop_rfoo():
|
||||||
|
await ctx.send('Rfoo server stopped')
|
||||||
|
|
||||||
|
else:
|
||||||
|
await ctx.send('Rfoo server already stopped')
|
||||||
|
|
||||||
|
def ct(self, coro: Coroutine):
|
||||||
|
"""
|
||||||
|
ct - short from create_task
|
||||||
|
execute coroutine and get result
|
||||||
|
"""
|
||||||
|
|
||||||
|
task = self.bot.loop.create_task(coro)
|
||||||
|
while not task.done():
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return task.result()
|
||||||
|
|
||||||
|
except asyncio.exceptions.InvalidStateError:
|
||||||
|
return task.exception()
|
||||||
|
|
||||||
|
@commands.command('shutdown')
|
||||||
|
async def shutdown(self, ctx: Context) -> None:
|
||||||
|
"""Shutdown bot with hope docker daemon will restart it, `@a31` and `@furiz__` has rights for it"""
|
||||||
|
if ctx.author.id in (420130693696323585, # @a31
|
||||||
|
444819880781545472): # @furiz__
|
||||||
|
|
||||||
|
log_msg = f"Got shutdown command by {ctx.author}"
|
||||||
|
logger.info(log_msg)
|
||||||
|
await ctx.reply('Shutting down')
|
||||||
|
|
||||||
|
try:
|
||||||
|
await (await self.bot.fetch_user(420130693696323585)).send(log_msg)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.opt(exception=e).warning(f'Failed to send shutdown message', )
|
||||||
|
|
||||||
|
self.bot.loop.create_task(self.bot.close())
|
||||||
|
|
||||||
|
else:
|
||||||
|
await ctx.reply('No rights for you')
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot: commands.Bot):
|
||||||
|
await bot.add_cog(BotManagement(bot))
|
||||||
|
|
||||||
|
async def teardown(bot):
|
||||||
|
stop_res = BotManagement(bot).stop_rfoo()
|
||||||
|
logger.info(f'Unloaded rfoo with result {stop_res} during BotManagement unload')
|
||||||
|
|
204
cogs/TTSCore.py
204
cogs/TTSCore.py
@ -2,18 +2,21 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from typing import Union, Optional
|
||||||
|
|
||||||
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from discord.ext.commands import Context
|
from discord.ext.commands import Context
|
||||||
import discord
|
|
||||||
import DB
|
|
||||||
from typing import Union
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from TTSSilero import TTSSileroCached
|
|
||||||
from TTSSilero import Speakers
|
import DB
|
||||||
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
|
||||||
import Observ
|
import Observ
|
||||||
|
import formatting
|
||||||
|
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
||||||
|
from SpeakersSettingsAdapterDiscord import SpeakersSettingsAdapterDiscord
|
||||||
|
from TTSServer import TTSServerCached
|
||||||
|
from TTSServer.Speakers import ModelDescription
|
||||||
from cogErrorHandlers import cogErrorHandlers
|
from cogErrorHandlers import cogErrorHandlers
|
||||||
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
|
|
||||||
|
|
||||||
|
|
||||||
class TTSCore(commands.Cog, Observ.Observer):
|
class TTSCore(commands.Cog, Observ.Observer):
|
||||||
@ -21,12 +24,33 @@ class TTSCore(commands.Cog, Observ.Observer):
|
|||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
||||||
self.bot.subscribe(self) # subscribe for messages that aren't commands
|
self.bot.subscribe(self) # subscribe for messages that aren't commands
|
||||||
self.tts = TTSSileroCached()
|
self.tts = TTSServerCached()
|
||||||
self.tts_queues: dict[int, list[discord.AudioSource]] = defaultdict(list)
|
self.tts_queues: dict[int, list[discord.AudioSource]] = defaultdict(list)
|
||||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
|
self.speakers_adapter: SpeakersSettingsAdapterDiscord = SpeakersSettingsAdapterDiscord()
|
||||||
|
|
||||||
|
@commands.command('drop')
|
||||||
|
async def drop_queue(self, ctx: Context):
|
||||||
|
"""
|
||||||
|
Drop tts queue for current server
|
||||||
|
|
||||||
|
:param ctx:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
del self.tts_queues[ctx.guild.id]
|
||||||
|
await ctx.send('Queue dropped')
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
await ctx.send('Failed on dropping queue')
|
||||||
|
|
||||||
@commands.command('exit')
|
@commands.command('exit')
|
||||||
async def leave_voice(self, ctx: Context):
|
async def leave_voice(self, ctx: Context):
|
||||||
|
"""
|
||||||
|
Disconnect bot from current channel, it also drop messages queue
|
||||||
|
|
||||||
|
:param ctx:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
if ctx.guild.voice_client is not None:
|
if ctx.guild.voice_client is not None:
|
||||||
await ctx.guild.voice_client.disconnect(force=False)
|
await ctx.guild.voice_client.disconnect(force=False)
|
||||||
await ctx.channel.send(f"Left voice channel")
|
await ctx.channel.send(f"Left voice channel")
|
||||||
@ -64,43 +88,67 @@ class TTSCore(commands.Cog, Observ.Observer):
|
|||||||
if voice_client is None:
|
if voice_client is None:
|
||||||
voice_client: discord.VoiceClient = await user_voice_state.channel.connect()
|
voice_client: discord.VoiceClient = await user_voice_state.channel.connect()
|
||||||
|
|
||||||
speaker: Speakers = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
|
if user_voice_state.channel.id != voice_client.channel.id:
|
||||||
|
await message.channel.send('We are in different voice channels')
|
||||||
|
return
|
||||||
|
|
||||||
|
speaker_str: str = self.speakers_adapter.get_speaker(message.guild.id, message.author.id)
|
||||||
|
speaker: ModelDescription | None = await self.tts.speaker_to_description(speaker_str)
|
||||||
|
if speaker is None:
|
||||||
|
speaker = self.tts.default_speaker()
|
||||||
|
|
||||||
# check if message will fail on synthesis
|
# check if message will fail on synthesis
|
||||||
if DB.SynthesisErrors.select()\
|
if DB.SynthesisErrors.select() \
|
||||||
.where(DB.SynthesisErrors.speaker == speaker.value)\
|
.where(DB.SynthesisErrors.speaker == speaker_str) \
|
||||||
.where(DB.SynthesisErrors.text == message.content)\
|
.where(DB.SynthesisErrors.text == message.content) \
|
||||||
.count() == 1:
|
.count() == 1:
|
||||||
# Then we will not try to synthesis it
|
# Then we will not try to synthesis it
|
||||||
await message.channel.send(f"I will not synthesis this message due to TTS engine limitations")
|
await message.channel.send(f"I will not synthesis this message due to TTS engine limitations")
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
wav_file_like_object = self.tts.synthesize_text(message.content, speaker=speaker)
|
wav_file_like_object: bytes = await self.tts.synthesize_text(message.content, speaker=speaker)
|
||||||
sound_source = FFmpegPCMAudio(wav_file_like_object, pipe=True, stderr=subprocess.PIPE)
|
sound_source = FFmpegPCMAudio(wav_file_like_object, pipe=True, stderr=subprocess.PIPE)
|
||||||
if voice_client.is_playing():
|
|
||||||
# Then we need to enqueue prepared sound for playing through self.tts_queues mechanism
|
|
||||||
|
|
||||||
self.tts_queues[message.guild.id].append(sound_source)
|
|
||||||
await message.channel.send(f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}")
|
|
||||||
return
|
|
||||||
|
|
||||||
voice_client.play(sound_source, after=lambda e: self.queue_player(message))
|
|
||||||
|
|
||||||
except Exception as synth_exception:
|
except Exception as synth_exception:
|
||||||
logger.opt(exception=True).warning(f'Exception on synthesize {message.content!r}: {synth_exception}')
|
logger.opt(exception=True).warning(f'Exception on synthesize {message.content!r}: {synth_exception}')
|
||||||
await message.channel.send(f'Internal error')
|
await message.channel.send(f'Synthesize error')
|
||||||
DB.SynthesisErrors.create(speaker=speaker.value, text=message.content)
|
DB.SynthesisErrors.create(speaker=speaker, text=message.content)
|
||||||
|
|
||||||
def queue_player(self, message: discord.Message):
|
|
||||||
voice_client: Union[discord.VoiceClient, None] = message.guild.voice_client
|
|
||||||
if voice_client is None:
|
|
||||||
# don't play anything and clear queue for whole guild
|
|
||||||
del self.tts_queues[message.guild.id]
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
if voice_client.is_playing():
|
||||||
|
# Then we need to enqueue prepared sound for playing through self.tts_queues mechanism
|
||||||
|
self.tts_queues[message.guild.id].append(sound_source)
|
||||||
|
await message.channel.send(
|
||||||
|
f"Enqueued for play, queue size: {len(self.tts_queues[message.guild.id])}")
|
||||||
|
return
|
||||||
|
|
||||||
|
voice_client.play(sound_source, after=lambda e: self.queue_player(message))
|
||||||
|
|
||||||
|
except Exception as play_exception:
|
||||||
|
logger.opt(exception=True).warning(
|
||||||
|
f'Exception on playing for: {message.guild.name}[#{message.channel.name}]: {message.author.display_name} / {play_exception}')
|
||||||
|
await message.channel.send(f'Playing error')
|
||||||
|
return
|
||||||
|
|
||||||
|
def queue_player(self, message: discord.Message):
|
||||||
for sound_source in self.tts_queues[message.guild.id]:
|
for sound_source in self.tts_queues[message.guild.id]:
|
||||||
voice_client.play(sound_source)
|
if len(self.tts_queues[message.guild.id]) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
voice_client: Optional[discord.VoiceClient] = message.guild.voice_client
|
||||||
|
if voice_client is None:
|
||||||
|
# don't play anything and clear queue for whole guild
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
voice_client.play(sound_source)
|
||||||
|
|
||||||
|
except discord.errors.ClientException: # Here we expect Not connected to voice
|
||||||
|
break
|
||||||
|
|
||||||
while voice_client.is_playing():
|
while voice_client.is_playing():
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
@ -111,13 +159,105 @@ class TTSCore(commands.Cog, Observ.Observer):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@commands.Cog.listener()
|
@commands.Cog.listener()
|
||||||
async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState):
|
async def on_voice_state_update(self, member: discord.Member, before: discord.VoiceState,
|
||||||
|
after: discord.VoiceState):
|
||||||
if after.channel is None:
|
if after.channel is None:
|
||||||
members = before.channel.members
|
members = before.channel.members
|
||||||
if len(members) == 1:
|
if len(members) == 1:
|
||||||
if members[0].id == self.bot.user.id:
|
if members[0].id == self.bot.user.id:
|
||||||
await before.channel.guild.voice_client.disconnect(force=False)
|
await before.channel.guild.voice_client.disconnect(force=False)
|
||||||
|
|
||||||
|
# TODO: leave voice channel after being moved there alone
|
||||||
|
|
||||||
|
# Ex TTSSettings.py
|
||||||
|
@commands.command('getAllSpeakers')
async def get_speakers(self, ctx: Context):
    """
    Send a table of every speaker reported by the tts backend

    :param ctx: invocation context
    :return:
    """
    models = await self.tts.discovery()

    # First column is the identifier built as "<engine>_<name>"
    rows = tuple(
        (model.engine + '_' + model.name, model.description)
        for model in models
    )
    table = formatting.format_table(rows, ('model', 'description'))

    await ctx.send(table)
|
||||||
|
|
||||||
|
@commands.command('setPersonalSpeaker')
async def set_user_speaker(self, ctx: Context, speaker: str):
    """
    Store a personal speaker for the command author on this server

    :param ctx: invocation context
    :param speaker: speaker identifier, rejected when the tts backend has no description for it
    :return: result of ``ctx.send`` when the speaker is rejected, otherwise nothing
    """
    description = await self.tts.speaker_to_description(speaker)
    if description is None:
        rejection = f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command"
        return await ctx.send(rejection)

    guild_id, user_id = ctx.guild.id, ctx.author.id
    self.speakers_adapter.set_speaker_user(guild_id, user_id, speaker)

    confirmation = f'Successfully set **your personal** speaker to `{speaker}`'
    await ctx.reply(confirmation)
|
||||||
|
|
||||||
|
@commands.command('setServerSpeaker')
async def set_server_speaker(self, ctx: Context, speaker: str):
    """
    Set global server speaker

    The speaker is checked with ``self.tts.speaker_to_description`` before it is stored.

    :param ctx: invocation context, its guild is the server whose speaker is set
    :param speaker: speaker identifier, rejected when the tts backend has no description for it
    :return: result of ``ctx.send`` when the speaker is rejected, otherwise nothing
    """
    if await self.tts.speaker_to_description(speaker) is None:
        return await ctx.send(
            f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")

    # NOTE(review): the author id is passed here while get_speaker_global is called with
    # the guild id only - confirm against the adapter's set_speaker_global signature
    self.speakers_adapter.set_speaker_global(ctx.guild.id, ctx.author.id, speaker)

    # This command changes the server-wide speaker, so the reply must not call it personal
    await ctx.reply(f'Successfully set **server** speaker to `{speaker}`')
|
||||||
|
|
||||||
|
@commands.command('getSpeaker')
async def get_speaker(self, ctx: Context):
    """
    Reply with the speaker the settings adapter resolves for the command author on this server

    :param ctx: invocation context
    :return:
    """
    guild_id, user_id = ctx.guild.id, ctx.author.id
    effective_speaker = self.speakers_adapter.get_speaker(guild_id, user_id)

    answer = f'Your current speaker is `{effective_speaker}`'
    await ctx.reply(answer)
|
||||||
|
|
||||||
|
@commands.command('getPersonalSpeaker')
async def get_personal_speaker(self, ctx: Context):
    """
    Tell the user their personal speaker on this server, or the server speaker when they have none

    :param ctx: invocation context
    :return:
    """
    personal_speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)

    if personal_speaker is not None:
        await ctx.reply(f"Your personal speaker is `{personal_speaker}`")
        return

    # No personal setting stored: report the server-wide speaker instead
    fallback_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
    await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{fallback_speaker}`")
|
||||||
|
|
||||||
|
@commands.command('getServerSpeaker')
async def get_server_speaker(self, ctx: Context):
    """
    Send the speaker configured for the whole server

    :param ctx: invocation context
    :return:
    """
    guild_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)

    announcement = f"Current server speaker is `{guild_speaker}`"
    await ctx.send(announcement)
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
await bot.add_cog(TTSCore(bot))
|
await bot.add_cog(TTSCore(bot))
|
||||||
|
@ -1,102 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
from discord.ext import commands
|
|
||||||
from discord.ext.commands import Context
|
|
||||||
from TTSSilero import Speakers
|
|
||||||
from cogErrorHandlers import cogErrorHandlers
|
|
||||||
from SpeakersSettingsAdapterDiscord import speakers_settings_adapter, SpeakersSettingsAdapterDiscord
|
|
||||||
|
|
||||||
|
|
||||||
class TTSSettings(commands.Cog):
|
|
||||||
def __init__(self, bot: commands.Bot):
|
|
||||||
self.bot = bot
|
|
||||||
self.cog_command_error = cogErrorHandlers.missing_argument_handler
|
|
||||||
self.speakers_adapter: SpeakersSettingsAdapterDiscord = speakers_settings_adapter
|
|
||||||
|
|
||||||
@commands.command('getAllSpeakers')
|
|
||||||
async def get_speakers(self, ctx: Context):
|
|
||||||
"""
|
|
||||||
Enumerate all available to set speakers
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speakers = '\n'.join(self.speakers_adapter.available_speakers)
|
|
||||||
|
|
||||||
await ctx.send(f"```\n{speakers}```")
|
|
||||||
|
|
||||||
@commands.command('setPersonalSpeaker')
|
|
||||||
async def set_user_speaker(self, ctx: Context, speaker: str):
|
|
||||||
"""
|
|
||||||
Set personal speaker on this server
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:param speaker:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
checked_speaker: Speakers = Speakers(speaker)
|
|
||||||
self.speakers_adapter.set_speaker_user(ctx.guild.id, ctx.author.id, checked_speaker)
|
|
||||||
await ctx.reply(f'Successfully set **your personal** speaker to `{checked_speaker.value}`')
|
|
||||||
|
|
||||||
except (KeyError, ValueError):
|
|
||||||
await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
|
||||||
|
|
||||||
@commands.command('setServerSpeaker')
|
|
||||||
async def set_server_speaker(self, ctx: Context, speaker: str):
|
|
||||||
"""
|
|
||||||
Set global server speaker
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:param speaker:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
checked_speaker: Speakers = Speakers(speaker)
|
|
||||||
self.speakers_adapter.set_speaker_global(ctx.guild.id, checked_speaker)
|
|
||||||
await ctx.send(f'Successfully set **server** speaker to `{checked_speaker.value}`')
|
|
||||||
|
|
||||||
except (KeyError, ValueError):
|
|
||||||
await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
|
|
||||||
|
|
||||||
@commands.command('getSpeaker')
|
|
||||||
async def get_speaker(self, ctx: Context):
|
|
||||||
"""
|
|
||||||
Tell first appropriate speaker for a user, it can be user specified, server specified or server default
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speaker = self.speakers_adapter.get_speaker(ctx.guild.id, ctx.author.id)
|
|
||||||
|
|
||||||
await ctx.reply(f'Your current speaker is `{speaker.value}`')
|
|
||||||
|
|
||||||
@commands.command('getPersonalSpeaker')
|
|
||||||
async def get_personal_speaker(self, ctx: Context):
|
|
||||||
"""
|
|
||||||
Tell user his personal speaker on this server, if user don't have personal speaker, tells server default one
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speaker = self.speakers_adapter.get_speaker_user(ctx.guild.id, ctx.author.id)
|
|
||||||
if speaker is None:
|
|
||||||
server_speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id).value
|
|
||||||
await ctx.send(f"You currently don't have a personal speaker, current server speaker is `{server_speaker}`")
|
|
||||||
|
|
||||||
else:
|
|
||||||
await ctx.reply(f"Your personal speaker is `{speaker.value}`")
|
|
||||||
|
|
||||||
@commands.command('getServerSpeaker')
|
|
||||||
async def get_server_speaker(self, ctx: Context):
|
|
||||||
"""
|
|
||||||
Tell server global speaker
|
|
||||||
|
|
||||||
:param ctx:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
speaker = self.speakers_adapter.get_speaker_global(ctx.guild.id)
|
|
||||||
await ctx.send(f"Current server speaker is `{speaker.value}`")
|
|
||||||
|
|
||||||
|
|
||||||
async def setup(bot):
|
|
||||||
await bot.add_cog(TTSSettings(bot))
|
|
@ -1,8 +1,9 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from discord.ext.commands import Context
|
from discord.ext.commands import Context
|
||||||
import DB
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
import DB
|
||||||
import DynamicCommandPrefix
|
import DynamicCommandPrefix
|
||||||
from cogErrorHandlers import cogErrorHandlers
|
from cogErrorHandlers import cogErrorHandlers
|
||||||
|
|
||||||
|
5
config.py
Normal file
5
config.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# SQLite file location: inside the directory named by DATA_DIR, or the working directory when unset
DB_PATH = Path(os.environ.get('DATA_DIR', '.')).joinpath('voice_cache.sqlite')

# Mandatory setting: importing this module raises KeyError when BASE_URL is not defined
BASE_URL = os.environ['BASE_URL']
|
18
formatting.py
Normal file
18
formatting.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
|
||||||
|
class MISSING:
    """Sentinel marking "no header supplied"; the class object itself is the default value."""


def format_table(data: Iterable[Iterable[str]], header: Iterable[str] = MISSING) -> str:
    """
    Render rows as tab-separated lines wrapped in a triple-backtick code fence

    :param data: table rows, every cell is converted with ``str``
    :param header: optional row placed before ``data``, skipped when left as ``MISSING``
    :return: the opening fence line, one line per row (cells joined by a tab, every
        backtick in a cell prefixed with a backslash), then the closing fence
    """
    # Identity check: ``!=`` would dispatch to whatever ``__ne__`` the header object defines
    rows = data if header is MISSING else (header, *data)

    fence = '```'
    lines = [
        '\t'.join(str(item).replace('`', '\\`') for item in row)
        for row in rows
    ]

    # Single join instead of repeated ``+=`` on a growing string
    return '\n'.join([fence, *lines, fence])
|
BIN
logo_200x200.png
Normal file
BIN
logo_200x200.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 4.1 KiB |
26
main.py
26
main.py
@ -1,15 +1,19 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import os
|
|
||||||
from discord.ext import commands
|
|
||||||
import discord
|
|
||||||
import signal
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from DynamicCommandPrefix import dynamic_command_prefix
|
|
||||||
import Observ
|
import Observ
|
||||||
|
from DynamicCommandPrefix import dynamic_command_prefix
|
||||||
|
|
||||||
|
LOG_FILE_ENABLED = os.getenv('LOG_FILE_ENABLED', 'true').lower() == 'true'
|
||||||
|
|
||||||
logger.add('offlineTTSBot.log', backtrace=True, diagnose=False, rotation='5MB')
|
if LOG_FILE_ENABLED:
|
||||||
|
logger.add('offlineTTSBot.log', backtrace=True, diagnose=False, rotation='5MB')
|
||||||
|
|
||||||
"""
|
"""
|
||||||
while msg := input('$ '):
|
while msg := input('$ '):
|
||||||
@ -49,7 +53,7 @@ class DiscordTTSBot(commands.Bot, Observ.Subject):
|
|||||||
await self.notify(ctx.message)
|
await self.notify(ctx.message)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise exception
|
logger.opt(exception=exception).warning(f'Global error caught:')
|
||||||
|
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
@ -66,6 +70,8 @@ async def main():
|
|||||||
|
|
||||||
await discord_client.start(os.environ['DISCORD_TOKEN'])
|
await discord_client.start(os.environ['DISCORD_TOKEN'])
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
loop.run_until_complete(main())
|
if __name__ == '__main__':
|
||||||
logger.debug('Shutdown completed')
|
loop = asyncio.new_event_loop()
|
||||||
|
loop.run_until_complete(main())
|
||||||
|
logger.debug('Shutdown completed')
|
||||||
|
@ -1,14 +1,7 @@
|
|||||||
aiohttp==3.7.4.post0
|
discord-py==2.3.2
|
||||||
async-timeout==3.0.1
|
loguru==0.7.2
|
||||||
attrs==21.4.0
|
peewee==3.17.5
|
||||||
chardet==4.0.0
|
PyNaCl==1.5.0
|
||||||
idna==3.3
|
git+https://github.com/aaiyer/rfoo.git@1555bd4eed204bb6a33a5e313146a6c2813cfe91
|
||||||
multidict==6.0.2
|
# Cython is setup dependency for rfoo
|
||||||
numpy==1.22.3
|
pydantic
|
||||||
torch==1.11.0
|
|
||||||
typing-extensions==4.1.1
|
|
||||||
yarl==1.7.2
|
|
||||||
git+https://github.com/Rapptz/discord.py.git#egg=discord-py
|
|
||||||
loguru~=0.6.0
|
|
||||||
peewee~=3.14.10
|
|
||||||
PyNaCl
|
|
6
utils.py
6
utils.py
@ -1,10 +1,4 @@
|
|||||||
import winsound
|
import winsound
|
||||||
import io
|
|
||||||
|
|
||||||
|
|
||||||
def save_bytes(filename: str, bytes_audio: bytes) -> None:
|
|
||||||
with open(file=filename, mode='wb') as res_file:
|
|
||||||
res_file.write(bytes_audio)
|
|
||||||
|
|
||||||
|
|
||||||
def play_bytes(bytes_sound: bytes) -> None:
|
def play_bytes(bytes_sound: bytes) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user