Still working on it
This commit is contained in:
parent
d3eaa8bb70
commit
d7536bab7a
27
Cache.py
27
Cache.py
@ -1,27 +0,0 @@
|
||||
import io
|
||||
import sqlite3
|
||||
|
||||
|
||||
class Cache:
    """Persistent ``key -> bytes`` cache backed by one SQLite table.

    Every key holds at most one value; storing under a key that already
    exists overwrites the previous value.
    """

    SCHEMA = "create table if not exists cache (key text primary key, value blob);"
    # ``insert or replace``: ``key`` is the primary key, so a plain ``insert``
    # raises sqlite3.IntegrityError when the same key is cached a second time.
    SET = "insert or replace into cache (key, value) values (:key, :value);"
    GET = "select value from cache where key = :key;"

    def __init__(self, path: str = 'voice_cache.sqlite'):
        """Open (or create) the database file and make sure the table exists.

        :param path: SQLite database location; ``':memory:'`` gives a
            throw-away cache.
        """
        self.connection = sqlite3.connect(path)
        self.connection.execute(self.SCHEMA)

    def set(self, key: str, value: bytes) -> None:
        """Store *value* under *key*, replacing any earlier value."""
        # The connection context manager commits on success and rolls back
        # if the statement raises.
        with self.connection:
            self.connection.execute(self.SET, {'key': key, 'value': value})

    def get(self, key: str):
        """Return the cached value for *key*.

        :return: a fresh ``io.BytesIO`` positioned at offset 0, or ``None``
            when the key is not cached.
        """
        row = self.connection.execute(self.GET, {'key': key}).fetchone()
        if row is None:
            return None

        return io.BytesIO(row[0])
|
||||
|
||||
|
||||
cache = Cache()
|
22
DB.py
Normal file
22
DB.py
Normal file
@ -0,0 +1,22 @@
|
||||
import peewee
|
||||
|
||||
# Single SQLite file that every model below is bound to.
database = peewee.SqliteDatabase('voice_cache.sqlite')


class BaseModel(peewee.Model):
    """Base class that attaches its subclasses to ``database``."""

    class Meta:
        database = database


class Prefix(BaseModel):
    """Command prefix configured for one server, keyed by the server id."""

    server_id = peewee.BigIntegerField(primary_key=True)
    prefix_char = peewee.CharField(max_length=10)


class Speaker(BaseModel):
    """Speaker name selected for one server, keyed by the server id."""

    server_id = peewee.BigIntegerField(primary_key=True)
    speaker = peewee.CharField()


# Create the tables at import time, in the same order as before.
for _model in (Prefix, Speaker):
    _model.create_table()
|
26
DynamicCommandPrefix.py
Normal file
26
DynamicCommandPrefix.py
Normal file
@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
import peewee
|
||||
import DB
|
||||
# from loguru import logger
|
||||
|
||||
|
||||
def dynamic_command_prefix(bot: commands.Bot, message: discord.Message) -> list[str]:
    """Build the list of prefixes that may introduce a command in *message*.

    :param bot: the bot the prefixes are resolved for
    :param message: the incoming message
    :return: the mention prefixes followed by the guild prefix, with and
        without a trailing space
    """
    mention_prefixes = commands.bot.when_mentioned(bot, message)

    if message.guild is None:
        # Direct messages have no guild, so there is no per-server row to look
        # up. Use the same default that get_guild_prefix falls back to.
        custom_prefix = 'tts '
    else:
        custom_prefix = get_guild_prefix(message.guild.id)

    # The space-suffixed variant is listed first so it is tried before the
    # shorter prefix it starts with.
    return mention_prefixes + [custom_prefix + ' ', custom_prefix]
|
||||
|
||||
|
||||
def get_guild_prefix(guild_id: int) -> str:
    """Look up the command prefix stored for a guild.

    :param guild_id: guild id, used as the primary key of ``DB.Prefix``
    :return: the stored ``prefix_char``, or ``'tts '`` when the guild has no row
    """
    try:
        return DB.Prefix[guild_id].prefix_char

    except peewee.DoesNotExist:
        # No custom prefix configured for this guild.
        return 'tts '
|
8
Observ/Observer.py
Normal file
8
Observ/Observer.py
Normal file
@ -0,0 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import discord
|
||||
|
||||
|
||||
class Observer:
    """Interface for objects that receive messages from a Subject.

    The class does not derive from ``ABC``, so ``@abstractmethod`` is not
    enforced at instantiation time; it only marks ``update`` as the method
    implementations are expected to override.
    """

    @abstractmethod
    def update(self, message: discord.Message) -> None:
        """Handle one message.

        :param message: the message being delivered
        :raises NotImplementedError: always, unless overridden
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it produces a TypeError instead of the intended error.
        raise NotImplementedError
|
20
Observ/Subject.py
Normal file
20
Observ/Subject.py
Normal file
@ -0,0 +1,20 @@
|
||||
import discord
|
||||
from .Observer import Observer
|
||||
|
||||
|
||||
class Subject:
    """Publisher half of the observer pattern.

    Observers are kept in a set, so subscribing the same observer twice has
    no additional effect.
    """

    # Class-level default so that ``observers`` can be read on the class and on
    # instances that never subscribed anything. It is never mutated by the
    # methods below; each instance gets its own set on first use.
    observers: set = set()

    def _own_observers(self) -> set:
        """Return this instance's observer set, creating it on first use."""
        # Done lazily instead of in __init__ because this class is used as a
        # mixin and its __init__ is not guaranteed to be called.
        return vars(self).setdefault('observers', set())

    def subscribe(self, observer: Observer) -> None:
        """Start delivering notifications to *observer*."""
        self._own_observers().add(observer)

    def unsubscribe(self, observer: Observer) -> None:
        """Stop delivering notifications to *observer*; unknown observers are ignored."""
        self._own_observers().discard(observer)

    async def notify(self, message: discord.Message) -> None:
        """Await ``update(message)`` on every subscribed observer, one after another.

        Iterates over a snapshot, so observers may subscribe or unsubscribe
        while a notification is in progress. Such changes take effect from
        the next ``notify`` call.
        """
        for observer in tuple(self._own_observers()):
            await observer.update(message)
|
2
Observ/__init__.py
Normal file
2
Observ/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .Observer import Observer
|
||||
from .Subject import Subject
|
@ -8,7 +8,7 @@ from .Speakers import Speakers
|
||||
from .multi_v2_package import TTSModelMulti_v2
|
||||
|
||||
|
||||
class TTS:
|
||||
class TTSSilero:
|
||||
def __init__(self, threads: int = 12):
|
||||
device = torch.device('cpu')
|
||||
torch.set_num_threads(threads)
|
||||
@ -25,8 +25,8 @@ class TTS:
|
||||
|
||||
self.sample_rate = 16000
|
||||
|
||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya, seek: int = None) -> io.BytesIO:
|
||||
return self.to_wav(self._synthesize_text(text, speaker), seek)
|
||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
||||
return self.to_wav(self._synthesize_text(text, speaker))
|
||||
|
||||
def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
|
||||
"""
|
||||
@ -44,7 +44,7 @@ class TTS:
|
||||
|
||||
return results_list
|
||||
|
||||
def to_wav(self, synthesized_text: list[torch.Tensor], seek: int = None) -> io.BytesIO:
|
||||
def to_wav(self, synthesized_text: list[torch.Tensor]) -> bytes:
|
||||
res_io_stream = io.BytesIO()
|
||||
|
||||
with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
|
||||
@ -54,9 +54,8 @@ class TTS:
|
||||
for result in synthesized_text:
|
||||
wf.writeframes((result * 32767).numpy().astype('int16'))
|
||||
|
||||
if type(seek) is int:
|
||||
res_io_stream.seek(seek)
|
||||
res_io_stream.seek(0)
|
||||
|
||||
return res_io_stream
|
||||
return res_io_stream.read()
|
||||
|
||||
|
||||
|
60
TTSSilero/TTSSileroCached.py
Normal file
60
TTSSilero/TTSSileroCached.py
Normal file
@ -0,0 +1,60 @@
|
||||
import time
|
||||
from typing import Union
|
||||
|
||||
from .TTSSilero import TTSSilero
|
||||
from .Speakers import Speakers
|
||||
import DB
|
||||
import sqlite3
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TTSSileroCached(TTSSilero):
    """TTSSilero variant that caches synthesized audio in the ``soundcache`` table.

    Rows are keyed by ``(text, speaker)``; ``usages`` counts how many times a
    cached row has been served.
    """

    _SQLITE_GET = "select audio from soundcache where text = :text and speaker = :speaker;"
    _SQLITE_SET = "insert into soundcache (text, speaker, audio) values (:text, :speaker, :audio);"
    # The speaker *column* must be compared with the parameter. Comparing the
    # parameter with itself is always true and would bump the counter of every
    # speaker that has a row for this text.
    _SQLITE_INCREMENT_USAGES = "update soundcache set usages = usages + 1 where text = :text and speaker = :speaker;"
    _SQLITE_SCHEMA = """CREATE TABLE IF NOT EXISTS "soundcache" (
        "text" TEXT NOT NULL,
        "speaker" VARCHAR(255) NOT NULL,
        "audio" BLOB NOT NULL,
        "usages" INTEGER NOT NULL default 0,
        PRIMARY KEY ("text", "speaker")
    );"""

    # Raw sqlite3 connection taken from the peewee database object, shared by
    # all instances of this class.
    database: sqlite3.Connection = DB.database.connection()

    def __init__(self):
        """Initialise the parent synthesizer and ensure the cache table exists."""
        super().__init__()
        self.database.execute(self._SQLITE_SCHEMA)

    def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
        """Return audio for *text* spoken by *speaker*, using the cache when possible.

        On a cache miss the audio is produced by ``TTSSilero.synthesize_text``
        and stored before it is returned.

        :param text: text to synthesize
        :param speaker: voice to use; its ``value`` is part of the cache key
        :return: the audio bytes
        """
        cached = self._cache_get(text, speaker.value)
        if cached is not None:
            return cached

        synthesized = super().synthesize_text(text, speaker)
        self._cache_set(text, speaker.value, synthesized)
        return synthesized

    def _cache_get(self, text: str, speaker: str) -> Union[bytes, None]:
        """Return the cached audio for ``(text, speaker)`` or ``None`` on a miss.

        A hit also increments the row's ``usages`` counter.
        """
        query_args = {'text': text, 'speaker': speaker}
        result = self.database.execute(self._SQLITE_GET, query_args).fetchone()
        if result is None:
            return None

        # The connection context manager commits the counter update.
        with self.database:
            self.database.execute(self._SQLITE_INCREMENT_USAGES, query_args)

        return result[0]

    def _cache_set(self, text: str, speaker: str, audio: bytes) -> None:
        """Insert a new cache row for ``(text, speaker)`` and commit it."""
        with self.database:
            self.database.execute(self._SQLITE_SET, {'text': text, 'speaker': speaker, 'audio': audio})
|
@ -1,2 +1,3 @@
|
||||
from .TTSSilero import TTS
|
||||
from .TTSSilero import TTSSilero
|
||||
from .Speakers import Speakers
|
||||
from .TTSSileroCached import TTSSileroCached
|
||||
|
15
cogErrorHandlers.py
Normal file
15
cogErrorHandlers.py
Normal file
@ -0,0 +1,15 @@
|
||||
from loguru import logger
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
|
||||
|
||||
class cogErrorHandlers:
    """Error handlers that cogs install as their ``cog_command_error``."""

    @classmethod
    async def missing_argument_handler(cls, ctx: Context, error: Exception) -> None:
        """Reply to the user about a failed command.

        A missing required argument is reported back verbatim. Any other
        error is logged with its traceback and answered with a generic
        message.

        :param ctx: invocation context of the failed command
        :param error: the exception raised while handling the command
        """
        if isinstance(error, commands.errors.MissingRequiredArgument):
            # No argument was specified
            await ctx.reply(str(error))
            return

        # This handler runs outside an ``except`` block, so the exception has
        # to be handed to loguru explicitly through ``opt(exception=...)``.
        # ``logger.exception`` would find no active exception to attach.
        logger.opt(exception=error).error('Cog command error occurred: ')
        await ctx.reply('Internal error occurred')
|
100
cogs/TTSCommands.py
Normal file
100
cogs/TTSCommands.py
Normal file
@ -0,0 +1,100 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
import discord
|
||||
import DB
|
||||
from typing import Union
|
||||
from loguru import logger
|
||||
from TTSSilero import TTSSileroCached
|
||||
from TTSSilero import Speakers
|
||||
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
||||
import Observ
|
||||
from cogErrorHandlers import cogErrorHandlers
|
||||
|
||||
|
||||
class TTSCommands(commands.Cog, Observ.Observer):
    """Cog that speaks non-command messages in voice chat and manages the speaker."""

    DEFAULT_SPEAKER = Speakers.kseniya

    def __init__(self, bot: Union[commands.Bot, Observ.Subject]):
        """Attach to *bot* and subscribe for the messages it forwards."""
        self.bot = bot
        self.cog_command_error = cogErrorHandlers.missing_argument_handler
        self.bot.subscribe(self)  # subscribe for messages that aren't commands
        self.tts = TTSSileroCached()

    @commands.command('exit')
    async def leave_voice(self, ctx: Context):
        """Disconnect from the guild's voice channel, if connected."""
        if ctx.guild.voice_client is None:
            await ctx.channel.send("I'm not in any voice channel")
            return

        await ctx.guild.voice_client.disconnect(force=False)
        await ctx.channel.send("Left voice channel")

    async def update(self, message: discord.Message):
        """
        Like on_message but only for messages which aren't commands

        Synthesizes the message text with the guild's speaker and plays it in
        the author's voice channel, connecting first when necessary.

        :param message: message forwarded by the bot
        :return: None
        """

        if message.author == self.bot.user:
            return

        if not isinstance(message.channel, discord.TextChannel):
            return

        logger.info(f'Message: {message.content}')
        user_voice_state = message.author.voice
        if user_voice_state is None:
            await message.channel.send("You're not in a voice channel")
            return

        # noinspection PyTypeChecker
        voice_client: discord.VoiceClient = message.guild.voice_client
        if voice_client is None:
            voice_client = await user_voice_state.channel.connect()

        speaker: Speakers = await self._get_speaker(message.guild.id)

        wav_bytes = self.tts.synthesize_text(message.content, speaker=speaker)
        sound_source = FFmpegPCMAudio(wav_bytes, pipe=True)

        voice_client.play(sound_source)

    @commands.command('getAllSpeakers')
    async def get_speakers(self, ctx: Context):
        """Send the names of all available speakers."""
        speakers = '\n'.join([speaker.name for speaker in Speakers])

        await ctx.send(f"```\n{speakers}```")

    @commands.command('setSpeaker')
    async def set_speaker(self, ctx: Context, speaker: str):
        """Store *speaker* as the guild's speaker after validating it.

        :param speaker: value of a ``Speakers`` member
        """
        # NOTE(review): getAllSpeakers prints member *names* while the lookup
        # below is by *value*; confirm that names and values coincide.
        try:
            checked_speaker: Speakers = Speakers(speaker)

        except (ValueError, KeyError):
            # Enum lookup by value raises ValueError for an unknown value.
            # KeyError is still accepted for backward compatibility.
            await ctx.send("Provided speaker is invalid, provided speaker must be from `getAllSpeakers` command")
            return

        DB.Speaker.replace(server_id=ctx.guild.id, speaker=checked_speaker.value).execute()
        await ctx.send(f'Successfully set speaker to `{checked_speaker.value}`')

    @commands.command('getSpeaker')
    async def get_speaker(self, ctx: Context):
        """Send the guild's current speaker."""
        speaker = await self._get_speaker(ctx.guild.id)

        await ctx.send(f'Your current speaker is `{speaker.value}`')

    async def _get_speaker(self, guild_id: int) -> Speakers:
        """Return the speaker stored for *guild_id*, or DEFAULT_SPEAKER if none is stored."""
        try:
            speaker = Speakers(DB.Speaker[guild_id].speaker)

        except DB.peewee.DoesNotExist:
            speaker = self.DEFAULT_SPEAKER

        return speaker
|
||||
|
||||
|
||||
async def setup(bot):
    """Extension entry point: create the TTSCommands cog and add it to *bot*."""
    tts_cog = TTSCommands(bot)
    await bot.add_cog(tts_cog)
|
43
cogs/prefixConfiguration.py
Normal file
43
cogs/prefixConfiguration.py
Normal file
@ -0,0 +1,43 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context
|
||||
import DB
|
||||
from loguru import logger
|
||||
import DynamicCommandPrefix
|
||||
from cogErrorHandlers import cogErrorHandlers
|
||||
|
||||
|
||||
class prefixConfiguration(commands.Cog):
    """
    Cog for managing the command prefix on a per-server basis
    """

    def __init__(self, bot: commands.Bot):
        """Remember *bot* and install the shared command error handler."""
        self.bot = bot
        self.cog_command_error = cogErrorHandlers.missing_argument_handler

    @commands.command('setPrefix')
    async def set_prefix(self, ctx: Context, prefix: str):
        """Store *prefix* as the guild's prefix, replacing any previous one.

        :param prefix: new prefix; rejected when longer than the
            ``prefix_char`` column allows
        """
        logger.debug('Going to set prefix')
        max_length = DB.Prefix.prefix_char.max_length
        if len(prefix) > max_length:
            # Report the limit that is actually enforced by the check above.
            await ctx.reply(f'Prefix must be at most {max_length} characters long')
            return

        DB.Prefix.replace(server_id=ctx.guild.id, prefix_char=prefix).execute()

        logger.debug(f'Set prefix {prefix!r} for guild {ctx.guild.name!r}')
        await ctx.reply(f'Your new prefix is `{prefix}`')

    @commands.command('getPrefix')
    async def get_prefix(self, ctx: Context):
        """Reply with the guild's current prefix and the bot mention."""
        prefix = DynamicCommandPrefix.get_guild_prefix(ctx.guild.id)
        await ctx.reply(f'Your current prefix is `{prefix}` and <@{self.bot.user.id}>')

    @commands.command('resetPrefix')
    async def reset_prefix(self, ctx: Context):
        """Delete the guild's stored prefix so the default applies again."""
        DB.Prefix.delete().where(DB.Prefix.server_id == ctx.guild.id).execute()
        await ctx.reply('Your prefix was deleted')
|
||||
|
||||
|
||||
async def setup(bot):
    """Extension entry point: create the prefixConfiguration cog and add it to *bot*."""
    prefix_cog = prefixConfiguration(bot)
    await bot.add_cog(prefix_cog)
|
83
main.py
83
main.py
@ -1,19 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import time
|
||||
|
||||
from discord.ext import commands
|
||||
import discord
|
||||
import signal
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
from DynamicCommandPrefix import dynamic_command_prefix
|
||||
import Observ
|
||||
|
||||
import utils
|
||||
from TTSSilero import TTS
|
||||
from TTSSilero import Speakers
|
||||
from FFmpegPCMAudioModified import FFmpegPCMAudio
|
||||
from Cache import cache
|
||||
|
||||
tts = TTS()
|
||||
"""
|
||||
while msg := input('$ '):
|
||||
start = time.time()
|
||||
@ -23,7 +18,7 @@ while msg := input('$ '):
|
||||
"""
|
||||
|
||||
|
||||
class DiscordTTSBot(discord.Client):
|
||||
class DiscordTTSBot(commands.Bot, Observ.Subject):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
signal.signal(signal.SIGTERM, self.shutdown)
|
||||
@ -37,61 +32,35 @@ class DiscordTTSBot(discord.Client):
|
||||
async def on_ready(self):
|
||||
logger.debug('Bot is ready')
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
if message.author == self.user:
|
||||
return
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
await super(DiscordTTSBot, self).on_message(message)
|
||||
if message.author.bot: # because on_command_error will not be called if author is bot
|
||||
# so it isn't a command, so, pass it next
|
||||
await self.notify(message)
|
||||
|
||||
if not message.content.startswith('-'):
|
||||
return
|
||||
async def on_command_error(self, ctx: commands.Context, exception: commands.errors.CommandError) -> None:
|
||||
if isinstance(exception, commands.errors.CommandNotFound):
|
||||
ctx.message.content = ctx.message.content[len(ctx.prefix):] # skip prefix
|
||||
await self.notify(ctx.message)
|
||||
|
||||
if isinstance(message.channel, discord.TextChannel):
|
||||
logger.info(f'Message: {message.content}')
|
||||
user_voice_state = message.author.voice
|
||||
if message.content.startswith('/exit'):
|
||||
if message.guild.voice_client is not None:
|
||||
logger.debug(f'Disconnecting from voice channel')
|
||||
await message.guild.voice_client.disconnect(force=False)
|
||||
await message.channel.send(f"Left voice channel")
|
||||
return
|
||||
|
||||
else:
|
||||
await message.channel.send("I'm not in any voice channel")
|
||||
return
|
||||
|
||||
if user_voice_state is None:
|
||||
await message.channel.send(f"You're not in a voice channel")
|
||||
return
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
voice_client: discord.VoiceClient = message.guild.voice_client
|
||||
if voice_client is None:
|
||||
voice_client: discord.VoiceClient = await user_voice_state.channel.connect()
|
||||
|
||||
cached = cache.get(message.content)
|
||||
if cached is not None:
|
||||
wav_file_like_object = cached
|
||||
logger.debug(f'Cache lookup for {message.content!r} successful')
|
||||
|
||||
else:
|
||||
synthesis_start = time.time()
|
||||
wav_file_like_object = tts.synthesize_text(message.content, seek=0)
|
||||
logger.debug(f'Synthesis took {time.time() - synthesis_start} s')
|
||||
cache.set(message.content, wav_file_like_object.read())
|
||||
logger.debug(f'Set cache for {message.content!r}')
|
||||
wav_file_like_object.seek(0)
|
||||
|
||||
sound_source = FFmpegPCMAudio(
|
||||
wav_file_like_object.read(),
|
||||
pipe=True
|
||||
)
|
||||
voice_client.play(sound_source, after=lambda e: logger.debug(f"Player done, {e=}"))
|
||||
else:
|
||||
raise exception
|
||||
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
discord_client = DiscordTTSBot(intents=intents)
|
||||
discord_client = DiscordTTSBot(command_prefix=dynamic_command_prefix, intents=intents)
|
||||
|
||||
|
||||
async def main():
    """Load every ``.py`` file found in ./cogs as an extension, then start the bot."""
    for filename in os.listdir("./cogs"):
        if not filename.endswith(".py"):
            continue

        logger.debug(f'Loading extension {filename}')
        # Strip the ".py" suffix to obtain the module name inside the package.
        extension_name = f"cogs.{filename[:-3]}"
        await discord_client.load_extension(extension_name)

    token = os.environ['DISCORD_TOKEN']
    await discord_client.start(token)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(discord_client.start(os.environ['DISCORD_TOKEN']))
|
||||
loop.run_until_complete(main())
|
||||
logger.debug('Shutdown completed')
|
||||
|
BIN
requirements.txt
Normal file
BIN
requirements.txt
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user