Still working on it

This commit is contained in:
norohind 2022-03-16 01:36:45 +03:00
parent d3eaa8bb70
commit d7536bab7a
Signed by: norohind
GPG Key ID: 01C3BECC26FB59E1
14 changed files with 330 additions and 92 deletions

View File

@ -1,27 +0,0 @@
import io
import sqlite3
class Cache:
    """SQLite-backed key/value store mapping a text key to a binary blob."""

    SCHEMA = "create table if not exists cache (key text primary key, value blob);"
    # "insert or replace": storing an already present key overwrites its value
    # instead of failing the primary-key constraint with IntegrityError.
    SET = "insert or replace into cache (key, value) values (:key, :value);"
    GET = "select value from cache where key = :key;"

    def __init__(self, path: str = 'voice_cache.sqlite'):
        """Open the database at *path* and make sure the cache table exists.

        :param path: SQLite database location; defaults to the file used so far
        """
        self.connection = sqlite3.connect(path)
        self.connection.execute(self.SCHEMA)

    def set(self, key: str, value: bytes) -> None:
        """Store *value* under *key*, replacing a previously stored value.

        :param key: cache key
        :param value: binary payload to store
        """
        # Using the connection as a context manager commits on success and
        # rolls back if the statement raises.
        with self.connection:
            self.connection.execute(self.SET, {'key': key, 'value': value})

    def get(self, key: str):
        """Return the value stored under *key*.

        :param key: cache key
        :return: ``io.BytesIO`` wrapping the stored bytes, or ``None`` when
            the key is not present
        """
        row = self.connection.execute(self.GET, {'key': key}).fetchone()
        if row is None:
            return None
        return io.BytesIO(row[0])
# Module-level instance created at import time; main.py obtains it with
# `from Cache import cache`.
cache = Cache()

22
DB.py Normal file
View File

@ -0,0 +1,22 @@
import peewee
# SQLite database handle that the models in this module are bound to through
# BaseModel.Meta.
database = peewee.SqliteDatabase('voice_cache.sqlite')
class BaseModel(peewee.Model):
    """Common base class of this module's models; binds them to ``database``."""

    class Meta:
        # Every model deriving from BaseModel uses the module-level database.
        database = database
class Prefix(BaseModel):
    """Custom command prefix stored per server."""

    # Id of the server (guild) the prefix belongs to; primary key, so there is
    # at most one row per server.
    server_id = peewee.BigIntegerField(primary_key=True)

    # The prefix text itself, declared with a maximum length of 10.
    prefix_char = peewee.CharField(max_length=10)
class Speaker(BaseModel):
    """Voice (speaker) selected per server."""

    # Id of the server (guild) the choice belongs to; primary key, so there is
    # at most one row per server.
    server_id = peewee.BigIntegerField(primary_key=True)

    # Value of the chosen speaker, stored as text.
    speaker = peewee.CharField()
# Runs at import time: create the table of each model, in declaration order.
for _model in (Prefix, Speaker):
    _model.create_table()

26
DynamicCommandPrefix.py Normal file
View File

@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
import discord
from discord.ext import commands
import peewee
import DB
# from loguru import logger
def dynamic_command_prefix(bot: commands.Bot, message: discord.Message) -> list[str]:
    """Return every prefix that may introduce a command in *message*.

    For a guild message the result is the bot-mention prefixes followed by the
    guild's configured prefix, once with a trailing space and once without.
    For a message without a guild only the mention prefixes are returned.

    :param bot: bot the prefixes are resolved for
    :param message: message being checked for a command
    :return: list of accepted prefixes
    """
    mention_prefixes = commands.bot.when_mentioned(bot, message)

    if message.guild is None:
        # No guild (e.g. a direct message) means there is no per-guild prefix
        # to look up; accessing ``message.guild.id`` here used to raise
        # AttributeError.
        return mention_prefixes

    custom_prefix = get_guild_prefix(message.guild.id)
    # The variant with the trailing space is listed before the bare one.
    # NOTE(review): presumably so that "<prefix> command" matches with the
    # space consumed -- confirm against discord.py's prefix matching order.
    all_prefixes = mention_prefixes + [custom_prefix + ' ', custom_prefix]
    return all_prefixes
def get_guild_prefix(guild_id: int) -> str:
    """Look up the command prefix stored for a guild.

    :param guild_id: id of the guild whose prefix is requested
    :return: the stored prefix, or ``'tts '`` when the guild has no row
    """
    try:
        return DB.Prefix[guild_id].prefix_char
    except peewee.DoesNotExist:
        # Nothing configured for this guild: fall back to the default prefix.
        return 'tts '

8
Observ/Observer.py Normal file
View File

@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
import discord
class Observer:
    """Interface for objects that receive messages from a ``Subject``.

    NOTE(review): the class does not derive from ``ABC``, so ``@abstractmethod``
    is not enforced at instantiation time. Deriving from ``ABC`` would
    presumably clash with the metaclass of ``commands.Cog``, which
    ``TTSCommands`` inherits from together with this class -- confirm before
    changing the bases.
    """

    @abstractmethod
    async def update(self, message: discord.Message) -> None:
        """Handle *message*; subclasses must override this coroutine.

        Declared ``async`` because ``Subject.notify`` awaits the call.

        :param message: message forwarded by the subject
        :raises NotImplementedError: always, in this base implementation
        """
        # ``NotImplemented`` is the binary-operator sentinel, not an exception;
        # raising it fails with TypeError instead of the intended error.
        raise NotImplementedError

20
Observ/Subject.py Normal file
View File

@ -0,0 +1,20 @@
import discord
from .Observer import Observer
class Subject:
    """Keeps a set of observers and forwards messages to each of them."""

    # NOTE: defined on the class, so the same set is shared by every Subject
    # instance and subclass unless an instance rebinds ``observers``.
    observers: set = set()

    def subscribe(self, observer: Observer) -> None:
        """Register *observer*; registering it again has no further effect."""
        self.observers.add(observer)

    def unsubscribe(self, observer: Observer) -> None:
        """Remove *observer*; an observer that is not registered is ignored."""
        self.observers.discard(observer)

    async def notify(self, message: discord.Message) -> None:
        """Await ``update(message)`` on every registered observer, one by one.

        :param message: message handed to each observer
        """
        # Iterate over a snapshot: an observer may subscribe or unsubscribe
        # while it is being awaited, and changing a set during iteration
        # raises RuntimeError.
        for observer in tuple(self.observers):
            await observer.update(message)

2
Observ/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .Observer import Observer
from .Subject import Subject

View File

@ -8,7 +8,7 @@ from .Speakers import Speakers
from .multi_v2_package import TTSModelMulti_v2
class TTS:
class TTSSilero:
def __init__(self, threads: int = 12):
device = torch.device('cpu')
torch.set_num_threads(threads)
@ -25,8 +25,8 @@ class TTS:
self.sample_rate = 16000
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya, seek: int = None) -> io.BytesIO:
return self.to_wav(self._synthesize_text(text, speaker), seek)
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
return self.to_wav(self._synthesize_text(text, speaker))
def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
"""
@ -44,7 +44,7 @@ class TTS:
return results_list
def to_wav(self, synthesized_text: list[torch.Tensor], seek: int = None) -> io.BytesIO:
def to_wav(self, synthesized_text: list[torch.Tensor]) -> bytes:
res_io_stream = io.BytesIO()
with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
@ -54,9 +54,8 @@ class TTS:
for result in synthesized_text:
wf.writeframes((result * 32767).numpy().astype('int16'))
if type(seek) is int:
res_io_stream.seek(seek)
res_io_stream.seek(0)
return res_io_stream
return res_io_stream.read()

View File

@ -0,0 +1,60 @@
import time
from typing import Union
from .TTSSilero import TTSSilero
from .Speakers import Speakers
import DB
import sqlite3
from loguru import logger
class TTSSileroCached(TTSSilero):
    """``TTSSilero`` with an SQLite cache of already synthesized audio.

    Audio is cached per (text, speaker) pair in the ``soundcache`` table; the
    ``usages`` column counts how often an entry was served from the cache.
    """

    _SQLITE_GET = "select audio from soundcache where text = :text and speaker = :speaker;"
    _SQLITE_SET = "insert into soundcache (text, speaker, audio) values (:text, :speaker, :audio);"
    # The speaker condition has to compare the *column* with the parameter.
    # The earlier ":speaker = :speaker" compared the parameter with itself,
    # which is always true, so every speaker's row for the text was bumped.
    _SQLITE_INCREMENT_USAGES = "update soundcache set usages = usages + 1 where text = :text and speaker = :speaker;"
    _SQLITE_SCHEMA = """CREATE TABLE IF NOT EXISTS "soundcache" (
"text" TEXT NOT NULL,
"speaker" VARCHAR(255) NOT NULL,
"audio" BLOB NOT NULL,
"usages" INTEGER NOT NULL default 0,
PRIMARY KEY ("text", "speaker")
);"""

    # Connection taken from the shared ``DB.database`` when the class body is
    # executed, i.e. when this module is imported.
    database: sqlite3.Connection = DB.database.connection()

    def __init__(self):
        """Initialise the parent synthesizer and ensure the cache table exists."""
        super().__init__()
        self.database.execute(self._SQLITE_SCHEMA)

    def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
        """Return audio for *text*, using the cache when possible.

        On a cache miss the parent class synthesizes the audio and the result
        is stored in the cache before it is returned.

        :param text: text to synthesize
        :param speaker: voice to use; its ``value`` is part of the cache key
        :return: audio bytes
        """
        cached = self._cache_get(text, speaker.value)
        if cached is not None:
            return cached

        synthesized = super().synthesize_text(text, speaker)
        self._cache_set(text, speaker.value, synthesized)
        return synthesized

    def _cache_get(self, text: str, speaker: str) -> Union[bytes, None]:
        """Return the cached audio for (*text*, *speaker*) or ``None`` on a miss.

        A hit also increments the entry's ``usages`` counter.
        """
        query_args = {'text': text, 'speaker': speaker}
        row = self.database.execute(self._SQLITE_GET, query_args).fetchone()
        if row is None:
            return None

        # Context manager commits the counter update (or rolls it back on error).
        with self.database:
            self.database.execute(self._SQLITE_INCREMENT_USAGES, query_args)
        return row[0]

    def _cache_set(self, text: str, speaker: str, audio: bytes) -> None:
        """Insert *audio* into the cache under (*text*, *speaker*)."""
        with self.database:
            self.database.execute(self._SQLITE_SET, {'text': text, 'speaker': speaker, 'audio': audio})

View File

@ -1,2 +1,3 @@
from .TTSSilero import TTS
from .TTSSilero import TTSSilero
from .Speakers import Speakers
from .TTSSileroCached import TTSSileroCached

15
cogErrorHandlers.py Normal file
View File

@ -0,0 +1,15 @@
from loguru import logger
from discord.ext import commands
from discord.ext.commands import Context
class cogErrorHandlers:
    """Command error handlers shared between the cogs."""

    @classmethod
    async def missing_argument_handler(cls, ctx: Context, error: Exception) -> None:
        """Answer the user after a command failed.

        A missing required argument is reported with the error's own text;
        any other error is logged and answered with a generic message.

        :param ctx: invocation context, used to send the reply
        :param error: exception raised while the command was handled
        """
        if isinstance(error, commands.errors.MissingRequiredArgument):
            # No argument was specified
            await ctx.reply(str(error))
            return

        # loguru has no stdlib-style ``exc_info=`` keyword; the exception must
        # be attached with ``opt(exception=...)`` for its traceback to appear
        # in the log. The message no longer names one particular cog because
        # this handler is installed by several of them.
        logger.opt(exception=error).error('Error occurred while handling command {}', ctx.command)
        await ctx.reply(f'Internal error occurred')

100
cogs/TTSCommands.py Normal file
View File

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
from discord.ext import commands
from discord.ext.commands import Context
import discord
import DB
from typing import Union
from loguru import logger
from TTSSilero import TTSSileroCached
from TTSSilero import Speakers
from FFmpegPCMAudioModified import FFmpegPCMAudio
import Observ
from cogErrorHandlers import cogErrorHandlers
class TTSCommands(commands.Cog, Observ.Observer):
    """Cog that speaks non-command messages in voice and manages the speaker per server."""

    DEFAULT_SPEAKER = Speakers.kseniya

    def __init__(self, bot: Union[commands.Bot, Observ.Subject]):
        """Attach to *bot*, install the shared error handler and create the synthesizer."""
        self.bot = bot
        self.cog_command_error = cogErrorHandlers.missing_argument_handler
        self.bot.subscribe(self)  # subscribe for messages that aren't commands
        self.tts = TTSSileroCached()

    @commands.command('exit')
    async def leave_voice(self, ctx: Context):
        """Disconnect from the guild's voice channel, if the bot is in one."""
        voice_client = ctx.guild.voice_client
        if voice_client is None:
            await ctx.channel.send("I'm not in any voice channel")
            return

        await voice_client.disconnect(force=False)
        await ctx.channel.send(f"Left voice channel")

    async def update(self, message: discord.Message):
        """
        Like on_message but only for messages which aren't commands

        Synthesizes the message text with the guild's speaker and plays it in
        the author's voice channel, connecting first when necessary.

        :param message: message to speak
        :return:
        """
        if message.author == self.bot.user:
            return

        if not isinstance(message.channel, discord.TextChannel):
            return

        logger.info(f'Message: {message.content}')
        user_voice_state = message.author.voice
        if user_voice_state is None:
            await message.channel.send(f"You're not in a voice channel")
            return

        # noinspection PyTypeChecker
        voice_client: discord.VoiceClient = message.guild.voice_client
        if voice_client is None:
            # Not connected in this guild yet: join the author's channel.
            voice_client: discord.VoiceClient = await user_voice_state.channel.connect()

        speaker: Speakers = await self._get_speaker(message.guild.id)

        # synthesize_text returns the audio as bytes, which are fed to ffmpeg
        # through a pipe.
        wav_bytes = self.tts.synthesize_text(message.content, speaker=speaker)
        sound_source = FFmpegPCMAudio(wav_bytes, pipe=True)
        voice_client.play(sound_source)

    @commands.command('getAllSpeakers')
    async def get_speakers(self, ctx: Context):
        """Send the names of all available speakers."""
        speakers = '\n'.join([speaker.name for speaker in Speakers])
        await ctx.send(f"```\n{speakers}```")

    @commands.command('setSpeaker')
    async def set_speaker(self, ctx: Context, speaker: str):
        """Store *speaker* as the guild's speaker after validating it.

        NOTE(review): validation looks the speaker up by *value* while
        ``getAllSpeakers`` lists member *names* -- confirm that both are
        identical for every member of ``Speakers``.
        """
        try:
            checked_speaker: Speakers = Speakers(speaker)
        except (KeyError, ValueError):
            # A lookup by value signals an unknown value with ValueError;
            # KeyError is still accepted for backward compatibility.
            await ctx.send(f"Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
            return

        DB.Speaker.replace(server_id=ctx.guild.id, speaker=checked_speaker.value).execute()
        await ctx.send(f'Successfully set speaker to `{checked_speaker.value}`')

    @commands.command('getSpeaker')
    async def get_speaker(self, ctx: Context):
        """Send the speaker currently used for the guild."""
        speaker = await self._get_speaker(ctx.guild.id)
        await ctx.send(f'Your current speaker is `{speaker.value}`')

    async def _get_speaker(self, guild_id: int) -> Speakers:
        """Return the speaker stored for *guild_id*, or ``DEFAULT_SPEAKER`` if none is stored."""
        try:
            speaker = Speakers(DB.Speaker[guild_id].speaker)
        except DB.peewee.DoesNotExist:
            speaker = self.DEFAULT_SPEAKER

        return speaker
async def setup(bot):
    """Create the ``TTSCommands`` cog and add it to *bot*."""
    cog = TTSCommands(bot)
    await bot.add_cog(cog)

View File

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
from discord.ext import commands
from discord.ext.commands import Context
import DB
from loguru import logger
import DynamicCommandPrefix
from cogErrorHandlers import cogErrorHandlers
class prefixConfiguration(commands.Cog):
    """
    Cog for managing the command prefix on a per-server basis
    """

    def __init__(self, bot: commands.Bot):
        """Keep a reference to *bot* and install the shared error handler."""
        self.bot = bot
        self.cog_command_error = cogErrorHandlers.missing_argument_handler

    @commands.command('setPrefix')
    async def set_prefix(self, ctx: Context, prefix: str):
        """Store *prefix* as the guild's command prefix.

        Prefixes longer than the ``prefix_char`` column allows are rejected.
        """
        logger.debug(f'Going to set prefix')

        max_length = DB.Prefix.prefix_char.max_length
        if len(prefix) > max_length:
            # The limit comes from the column definition, so the reply quotes
            # it instead of claiming a one-symbol limit that is not enforced.
            await ctx.reply(f'Prefix must be at most {max_length} symbols long')
            return

        DB.Prefix.replace(server_id=ctx.guild.id, prefix_char=prefix).execute()
        logger.debug(f'Set prefix {prefix!r} for guild {ctx.guild.name!r}')
        await ctx.reply(f'Your new prefix is `{prefix}`')

    @commands.command('getPrefix')
    async def get_prefix(self, ctx: Context):
        """Reply with the guild's current prefix and the bot mention."""
        prefix = DynamicCommandPrefix.get_guild_prefix(ctx.guild.id)
        await ctx.reply(f'Your current prefix is `{prefix}` and <@{self.bot.user.id}>')

    @commands.command('resetPrefix')
    async def reset_prefix(self, ctx: Context):
        """Delete the guild's stored prefix."""
        DB.Prefix.delete().where(DB.Prefix.server_id == ctx.guild.id).execute()
        await ctx.reply(f'Your prefix was deleted')
async def setup(bot):
    """Create the ``prefixConfiguration`` cog and add it to *bot*."""
    cog = prefixConfiguration(bot)
    await bot.add_cog(cog)

83
main.py
View File

@ -1,19 +1,14 @@
# -*- coding: utf-8 -*-
import os
import time
from discord.ext import commands
import discord
import signal
import asyncio
from loguru import logger
from DynamicCommandPrefix import dynamic_command_prefix
import Observ
import utils
from TTSSilero import TTS
from TTSSilero import Speakers
from FFmpegPCMAudioModified import FFmpegPCMAudio
from Cache import cache
tts = TTS()
"""
while msg := input('$ '):
start = time.time()
@ -23,7 +18,7 @@ while msg := input('$ '):
"""
class DiscordTTSBot(discord.Client):
class DiscordTTSBot(commands.Bot, Observ.Subject):
def __init__(self, **kwargs):
super().__init__(**kwargs)
signal.signal(signal.SIGTERM, self.shutdown)
@ -37,61 +32,35 @@ class DiscordTTSBot(discord.Client):
async def on_ready(self):
logger.debug('Bot is ready')
async def on_message(self, message: discord.Message):
if message.author == self.user:
return
async def on_message(self, message: discord.Message) -> None:
await super(DiscordTTSBot, self).on_message(message)
if message.author.bot: # because on_command_error will not be called if author is bot
# so it isn't a command, so, pass it next
await self.notify(message)
if not message.content.startswith('-'):
return
async def on_command_error(self, ctx: commands.Context, exception: commands.errors.CommandError) -> None:
if isinstance(exception, commands.errors.CommandNotFound):
ctx.message.content = ctx.message.content[len(ctx.prefix):] # skip prefix
await self.notify(ctx.message)
if isinstance(message.channel, discord.TextChannel):
logger.info(f'Message: {message.content}')
user_voice_state = message.author.voice
if message.content.startswith('/exit'):
if message.guild.voice_client is not None:
logger.debug(f'Disconnecting from voice channel')
await message.guild.voice_client.disconnect(force=False)
await message.channel.send(f"Left voice channel")
return
else:
await message.channel.send("I'm not in any voice channel")
return
if user_voice_state is None:
await message.channel.send(f"You're not in a voice channel")
return
# noinspection PyTypeChecker
voice_client: discord.VoiceClient = message.guild.voice_client
if voice_client is None:
voice_client: discord.VoiceClient = await user_voice_state.channel.connect()
cached = cache.get(message.content)
if cached is not None:
wav_file_like_object = cached
logger.debug(f'Cache lookup for {message.content!r} successful')
else:
synthesis_start = time.time()
wav_file_like_object = tts.synthesize_text(message.content, seek=0)
logger.debug(f'Synthesis took {time.time() - synthesis_start} s')
cache.set(message.content, wav_file_like_object.read())
logger.debug(f'Set cache for {message.content!r}')
wav_file_like_object.seek(0)
sound_source = FFmpegPCMAudio(
wav_file_like_object.read(),
pipe=True
)
voice_client.play(sound_source, after=lambda e: logger.debug(f"Player done, {e=}"))
else:
raise exception
intents = discord.Intents.default()
intents.message_content = True
discord_client = DiscordTTSBot(intents=intents)
discord_client = DiscordTTSBot(command_prefix=dynamic_command_prefix, intents=intents)
async def main():
    """Load every ``.py`` file in ``./cogs`` as an extension, then start the bot.

    The token is read from the ``DISCORD_TOKEN`` environment variable.
    """
    suffix = ".py"
    for filename in os.listdir("./cogs"):
        if not filename.endswith(suffix):
            continue

        logger.debug(f'Loading extension {filename}')
        # Strip the suffix to turn the file name into a module name.
        module_name = filename[:-len(suffix)]
        await discord_client.load_extension(f"cogs.{module_name}")

    await discord_client.start(os.environ['DISCORD_TOKEN'])
loop = asyncio.new_event_loop()
loop.run_until_complete(discord_client.start(os.environ['DISCORD_TOKEN']))
loop.run_until_complete(main())
logger.debug('Shutdown completed')

BIN
requirements.txt Normal file

Binary file not shown.