63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
import os
|
|
import io
|
|
import torch.package
|
|
|
|
from .Speakers import Speakers
|
|
from multi_acc_v3_package import TTSModelMultiAcc_v3
|
|
|
|
|
|
class TTSSilero:
    """CPU text-to-speech wrapper around a Silero model loaded with ``torch.package``.

    On construction the model package is downloaded to the current working
    directory (only if the file is not already there), unpickled and moved to
    the CPU device. ``synthesize_text`` is the main entry point and returns
    the synthesized speech as WAV-encoded bytes.
    """

    # Remote location of the model package and the file name it is cached under
    # (relative to the current working directory).
    MODEL_URL = 'https://models.silero.ai/models/tts/ru/ru_v3.pt'
    MODEL_FILE = 'model.pt'

    def __init__(self, threads: int = 24):
        """Load the model, downloading it first when it is not cached locally.

        :param threads: value passed to ``torch.set_num_threads``; note that
            this is a process-wide torch setting, not a per-instance one.
        """
        device = torch.device('cpu')
        torch.set_num_threads(threads)

        local_file = self.MODEL_FILE
        if not os.path.isfile(local_file):
            torch.hub.download_url_to_file(self.MODEL_URL, local_file)

        # NOTE(security): load_pickle executes code shipped inside the package,
        # so the file at MODEL_FILE / MODEL_URL must come from a trusted source.
        self.model: TTSModelMultiAcc_v3 = torch.package.PackageImporter(local_file).load_pickle(
            "tts_models", "model"
        )
        self.model.to(device)

        # Sample rate requested from the model and written into the WAV output.
        self.sample_rate = 48000

    def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
        """Synthesize ``text`` with the given speaker and return it as WAV bytes.

        :param text: text to synthesize.
        :param speaker: voice to use; its ``.value`` is passed to the model.
        :return: the WAV data produced by ``to_wav``.
        """
        return self.to_wav(self._synthesize_text(text, speaker))

    def _synthesize_text(self, text: str, speaker: Speakers) -> torch.Tensor:
        """Run the model on ``text`` and return the raw synthesized audio.

        The text is passed to ``self.model.apply_tts`` as-is, in a single
        call; no splitting is done here.

        :param text: text to synthesize.
        :param speaker: voice to use; its ``.value`` is passed to the model.
        :return: the tensor returned by ``apply_tts`` at ``self.sample_rate``.
        """
        return self.model.apply_tts(
            text=text,
            speaker=speaker.value,
            sample_rate=self.sample_rate,
        )

    def to_wav(self, synthesized_text: torch.Tensor) -> bytes:
        """Encode a synthesized audio tensor as 16-bit PCM WAV bytes.

        :param synthesized_text: audio samples as returned by
            ``_synthesize_text`` (presumably floats in [-1, 1] - confirm
            against the model's ``apply_tts`` contract).
        :return: the bytes written by ``self.model.write_wave``.
        """
        # Clamp before scaling: a sample outside [-1, 1] would otherwise
        # overflow the int16 range and wrap around into a loud artefact.
        # detach()/cpu() keep .numpy() valid for any tensor the model returns.
        samples = synthesized_text.detach().cpu().clamp(-1.0, 1.0)
        pcm = (samples * 32767).numpy().astype('int16')

        buffer = io.BytesIO()
        self.model.write_wave(buffer, pcm, self.sample_rate)
        return buffer.getvalue()
|
|
|
|
|