init
commit f123cc1f86
.gitignore · vendored · Normal file · 3 lines
@@ -0,0 +1,3 @@
__pycache__
.idea
cache
EngineABC.py · Normal file · 33 lines
@@ -0,0 +1,33 @@
from typing import Literal
from abc import abstractmethod, ABC
from pathlib import Path

from pydantic import BaseModel


class Argument(BaseModel):
    type: Literal['str', 'int', 'float']
    description: str | None = None


class ModelDescription(BaseModel):
    engine: str
    name: str
    arguments: dict[str, Argument]
    description: None | str = None


class EngineABC(ABC):
    def __init__(self, save_path: Path):
        self.save_path = save_path

    @abstractmethod
    def discovery(self) -> tuple[ModelDescription, ...]:
        ...

    @abstractmethod
    def synth(self, text: str, model: str, *args, **kwargs) -> bytes:
        ...
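For reference, a minimal concrete engine satisfying this ABC could look like the sketch below; DummyEngine is hypothetical and not part of this commit, and its synth() returns placeholder bytes rather than real WAV data.

from EngineABC import EngineABC, ModelDescription


class DummyEngine(EngineABC):
    def discovery(self) -> tuple[ModelDescription, ...]:
        # advertise a single fake model with no extra arguments
        return (ModelDescription(engine=self.__class__.__name__, name='beep', arguments={}),)

    def synth(self, text: str, model: str, *args, **kwargs) -> bytes:
        # a real engine would synthesize `text` with `model` here
        return b'RIFF-placeholder'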
EnginesController.py · Normal file · 40 lines
@@ -0,0 +1,40 @@
from loguru import logger
import importlib.machinery
from EngineABC import EngineABC, ModelDescription
from pathlib import Path
from analytics import measure
from preprocessing import preprocess
import config


class EnginesController:
    def __init__(self):
        self.engines: dict[str, EngineABC] = dict()
        engines_dir = Path('engines')
        for module_dir in engines_dir.iterdir():
            init_path = module_dir / '__init__.py'
            if init_path.exists():
                # NOTE: load_module() is deprecated; importlib.util.spec_from_file_location
                # plus exec_module() is the modern equivalent
                module = importlib.machinery.SourceFileLoader(module_dir.name, str(init_path)).load_module()
                logger.info(f'imported {init_path.parent.name!r} engine')
                for obj_name in dir(module):
                    obj = getattr(module, obj_name)
                    if isinstance(obj, type) and EngineABC in obj.__bases__:
                        obj: type[EngineABC]
                        cache_path = config.BASE / obj.__name__
                        cache_path.mkdir(exist_ok=True)
                        instance: EngineABC = obj(save_path=cache_path)
                        self.engines[obj.__name__] = instance

    def discovery(self) -> list[ModelDescription]:
        to_return: list[ModelDescription] = list()
        for engine in self.engines.values():
            to_return += engine.discovery()

        return to_return

    @measure
    def synth(self, engine: str, model: str, text: str, **kwargs) -> bytes:
        return self.engines[engine].synth(preprocess(text), model, **kwargs)

    def synth_all(self, text: str = 'Съешь ещё 15 этих пирожков') -> list[bytes]:
        # the default text is a Russian test phrase ('Eat 15 more of these pies')
        return [self.synth(model.engine, model.name, text=text) for model in self.discovery()]
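Driving the controller directly, assuming the working directory contains the engines/ tree and a config module exposing BASE (config.py is referenced above but is not part of this commit):

from EnginesController import EnginesController

controller = EnginesController()  # imports every engines/<name>/__init__.py it finds
for description in controller.discovery():
    print(description.engine, description.name)

wav_bytes = controller.synth('Silero', 'baya', 'Привет, мир')  # text goes through preprocess() first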
analytics.py · Normal file · 38 lines
@@ -0,0 +1,38 @@
import time
from typing import Callable


def measure(function: Callable, name_to_display: str = ''):
    """
    Decorator that measures and prints the execution time of a function (or method).
    Usage is as simple as:

    @analytics.measure
    def im_function_to_measure():
        ...

    :param name_to_display: optional prefix for the printed line
    :param function: the callable being wrapped
    :return: the wrapped callable
    """
    if name_to_display != '':
        name_to_display = name_to_display + ':'

    def decorated(*args, **kwargs):
        start = time.time()
        result = function(*args, **kwargs)
        end = time.time()

        # convert seconds to milliseconds
        print(f'{name_to_display}{function.__name__}:{args[1:]} {(end - start) * 1000} ms')

        return result

    return decorated


class Measure:
    def __init__(self, name: str):
        self.start = time.time()
        self.name = name

    def record(self) -> None:
        print(f'{self.name}: {(time.time() - self.start) * 1000} ms')
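Since measure is a plain decorator rather than a decorator factory, name_to_display can only be set by wrapping manually; a usage sketch (printed timings are illustrative):

import time


@measure
def slow():
    time.sleep(0.1)


slow()  # prints something like: slow:() 100.4 ms

# name_to_display is passed by calling measure() directly
timed = measure(time.sleep, 'timers')
timed(0.05)  # prints something like: timers:sleep:() 50.2 ms

m = Measure('block')
time.sleep(0.05)
m.record()  # prints something like: block: 50.1 ms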
engines/silero/__init__.py · Normal file · 45 lines
@@ -0,0 +1,45 @@
from EngineABC import EngineABC, ModelDescription
import torch.package
from loguru import logger
from to_wav import tensor2wav
import typing
if typing.TYPE_CHECKING:
    from .multi_v2_package import TTSModelMulti_v2
from pathlib import Path


class Silero(EngineABC):
    def __init__(self, save_path: Path):
        super().__init__(save_path)
        threads = 12
        device = torch.device('cpu')
        torch.set_num_threads(threads)
        dst = save_path / 'model_multi.pt'
        if not dst.exists():
            torch.hub.download_url_to_file(
                'https://models.silero.ai/models/tts/multi/v2_multi.pt',
                str(dst)
            )

        self.model: 'TTSModelMulti_v2' = torch.package.PackageImporter(dst).load_pickle("tts_models", "model")

        self.model.to(device)
        logger.debug("Loaded speaker package: v2_multi")
        self.sample_rate = 16000

    def synth(self, text: str, model: str, *args, **kwargs) -> bytes:
        # `model` is the speaker name here
        return tensor2wav(
            self.model.apply_tts(texts=text, speakers=model, sample_rate=self.sample_rate)[0],
            self.sample_rate
        )

    def discovery(self) -> tuple[ModelDescription, ...]:
        return tuple(
            ModelDescription(
                engine=self.__class__.__name__,
                name=speaker_name,
                arguments=dict()
            )
            for speaker_name in self.model.speaker_to_id.keys()
        )
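A direct smoke test of the Silero engine, assuming the model download from models.silero.ai succeeds and the repository root is on sys.path (file paths are illustrative):

from pathlib import Path

from engines.silero import Silero

cache = Path('cache/Silero')
cache.mkdir(parents=True, exist_ok=True)
engine = Silero(save_path=cache)
wav = engine.synth('Съешь ещё этих мягких французских булок', model='baya')
Path('baya.wav').write_bytes(wav)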
engines/silero/multi_v2_package.py · Normal file · 148 lines
@@ -0,0 +1,148 @@
import re
import wave
import torch
import warnings
import contextlib

# for type hints only


class TTSModelMulti_v2:
    def __init__(self, model_path, symbols):
        self.model = self.init_jit_model(model_path)
        self.symbols = symbols
        self.device = torch.device('cpu')
        speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
                    'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
        self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}

    def init_jit_model(self, model_path: str):
        torch.set_grad_enabled(False)
        model = torch.jit.load(model_path, map_location='cpu')
        model.eval()
        return model

    def prepare_text_input(self, text, symbols, symbol_to_id=None):
        if len(text) > 140:
            warnings.warn('Text string is longer than 140 symbols.')

        if symbol_to_id is None:
            symbol_to_id = {s: i for i, s in enumerate(symbols)}

        text = text.lower()
        text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        if text[-1] not in ['.', '!', '?']:
            text = text + '.'
        text = text + symbols[1]

        text_ohe = [symbol_to_id[s] for s in text if s in symbols]
        text_tensor = torch.LongTensor(text_ohe)
        return text_tensor

    def prepare_tts_model_input(self, text: str | list, symbols: str, speakers: list):
        assert len(speakers) == len(text) or len(speakers) == 1
        if isinstance(text, str):
            text = [text]
        symbol_to_id = {s: i for i, s in enumerate(symbols)}
        if len(text) == 1:
            return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])

        text_tensors = []
        for string in text:
            string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
            text_tensors.append(string_tensor)
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(t) for t in text_tensors]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]
        batch_size = len(text_tensors)

        text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
        if len(speakers) == 1:
            speakers = speakers * batch_size
        speaker_ids = torch.LongTensor(batch_size).zero_()

        for i, idx in enumerate(ids_sorted_decreasing):
            text_tensor = text_tensors[idx]
            in_len = text_tensor.size(0)
            text_padded[i, :in_len] = text_tensor
            speaker_ids[i] = speakers[idx]

        return text_padded, speaker_ids, ids_sorted_decreasing

    def process_tts_model_output(self, out, out_lens, ids):
        out = out.to('cpu')
        out_lens = out_lens.to('cpu')
        _, orig_ids = ids.sort()

        proc_outs = []
        orig_out = out.index_select(0, orig_ids)
        orig_out_lens = out_lens.index_select(0, orig_ids)

        for i, out_len in enumerate(orig_out_lens):
            proc_outs.append(orig_out[i][:out_len])
        return proc_outs

    def to(self, device):
        self.model = self.model.to(device)
        self.device = device

    def get_speakers(self, speakers: str | list):
        if isinstance(speakers, str):
            speakers = [speakers]
        speaker_ids = []
        for speaker in speakers:
            try:
                speaker_id = self.speaker_to_id[speaker]
                speaker_ids.append(speaker_id)
            except KeyError:
                raise ValueError(f'No such speaker: {speaker}')
        return speaker_ids

    def apply_tts(self, texts: str | list,
                  speakers: str | list,
                  sample_rate: int = 16000):
        speaker_ids = self.get_speakers(speakers)
        text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
                                                                          symbols=self.symbols,
                                                                          speakers=speaker_ids)
        with torch.inference_mode():
            out, out_lens = self.model(text_padded.to(self.device),
                                       speaker_ids.to(self.device),
                                       sr=sample_rate)
        audios = self.process_tts_model_output(out, out_lens, orig_ids)
        return audios

    @staticmethod
    def write_wave(path, audio, sample_rate):
        """Writes a .wav file.
        Takes path, PCM audio data, and sample rate.
        """
        with contextlib.closing(wave.open(path, 'wb')) as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sample_rate)
            wf.writeframes(audio)

    def save_wav(self, texts: str | list,
                 speakers: str | list,
                 audio_pathes: str | list = '',
                 sample_rate: int = 16000):
        if isinstance(texts, str):
            texts = [texts]

        if not audio_pathes:
            audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
        if isinstance(audio_pathes, str):
            audio_pathes = [audio_pathes]
        assert len(audio_pathes) == len(texts)

        audio = self.apply_tts(texts=texts,
                               speakers=speakers,
                               sample_rate=sample_rate)
        for i, _audio in enumerate(audio):
            self.write_wave(path=audio_pathes[i],
                            audio=(_audio * 32767).numpy().astype('int16'),
                            sample_rate=sample_rate)
        return audio_pathes
engines/teratts/TeraTTS/__init__.py · Normal file · 2 lines
@@ -0,0 +1,2 @@
from .infer_onnx import TTS
from .tokenizer import TokenizerG2P
engines/teratts/TeraTTS/infer_onnx.py · Normal file · 65 lines
@@ -0,0 +1,65 @@
import os
import onnxruntime
import numpy as np
from huggingface_hub import snapshot_download
from .tokenizer import TokenizerG2P


class TTS:
    def __init__(self, model_name: str, save_path: str = "./model", add_time_to_end: float = 1.0, tokenizer_load_dict=True) -> None:
        os.makedirs(save_path, exist_ok=True)

        model_dir = os.path.join(save_path, model_name)

        if not os.path.exists(model_dir):
            snapshot_download(repo_id=model_name,
                              allow_patterns=["*.txt", "*.onnx", "*.json"],
                              local_dir=model_dir
                              )

        self.model = onnxruntime.InferenceSession(os.path.join(model_dir, "exported/model.onnx"),
                                                  providers=['CPUExecutionProvider'])

        self.tokenizer = TokenizerG2P(os.path.join(model_dir, "exported"), load_dict=tokenizer_load_dict)

        self.add_time_to_end = add_time_to_end

    def _add_silent(self, audio, silence_duration: float = 1.0, sample_rate: int = 22050):
        # pad the end of the waveform with silence_duration seconds of silence
        num_samples_silence = int(sample_rate * silence_duration)
        silence_array = np.zeros(num_samples_silence, dtype=np.float32)
        audio_with_silence = np.concatenate((audio, silence_array), axis=0)
        return audio_with_silence

    def _intersperse(self, lst, item):
        # insert `item` before, between, and after all elements of `lst`
        result = [item] * (len(lst) * 2 + 1)
        result[1::2] = lst
        return result

    def _get_seq(self, text):
        phoneme_ids = self.tokenizer._get_seq(text)
        phoneme_ids_inter = self._intersperse(phoneme_ids, 0)
        return phoneme_ids_inter

    def __call__(self, text: str, play=False, length_scale=1.2):
        phoneme_ids = self._get_seq(text)
        text = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0)
        text_lengths = np.array([text.shape[1]], dtype=np.int64)
        scales = np.array(
            [0.667, length_scale, 0.8],
            dtype=np.float32,
        )
        audio = self.model.run(
            None,
            {
                "input": text,
                "input_lengths": text_lengths,
                "scales": scales,
                "sid": None,
            },
        )[0][0, 0][0]
        audio = self._add_silent(audio, silence_duration=self.add_time_to_end)
        if play:
            # NOTE: play_audio is not defined in this class, so play=True raises AttributeError
            self.play_audio(audio)
        return audio
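A standalone sketch of this wrapper (the model name matches the engine below; the first call downloads the model from the Hugging Face Hub, and the import path assumes the repository root is on sys.path):

from engines.teratts.TeraTTS import TTS

tts = TTS('TeraTTS/natasha-g2p-vits', save_path='./model')
audio = tts('прив+ет мир', length_scale=1.2)  # float32 waveform at 22050 Hz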
engines/teratts/TeraTTS/tokenizer/__init__.py · Normal file · 1 line
@@ -0,0 +1 @@
from .g2p import Tokenizer as TokenizerG2P
engines/teratts/TeraTTS/tokenizer/g2p/__init__.py · Normal file · 1 line
@@ -0,0 +1 @@
from .tokenizer import Tokenizer
engines/teratts/TeraTTS/tokenizer/g2p/g2p.py · Normal file · 94 lines
@@ -0,0 +1,94 @@
softletters = set(u"яёюиье")
startsyl = set(u"#ъьаяоёуюэеиы-")
others = set(["#", "+", "-", u"ь", u"ъ"])

softhard_cons = {
    u"б": u"b",
    u"в": u"v",
    u"г": u"g",
    u"Г": u"g",
    u"д": u"d",
    u"з": u"z",
    u"к": u"k",
    u"л": u"l",
    u"м": u"m",
    u"н": u"n",
    u"п": u"p",
    u"р": u"r",
    u"с": u"s",
    u"т": u"t",
    u"ф": u"f",
    u"х": u"h"
}

other_cons = {
    u"ж": u"zh",
    u"ц": u"c",
    u"ч": u"ch",
    u"ш": u"sh",
    u"щ": u"sch",
    u"й": u"j"
}

vowels = {
    u"а": u"a",
    u"я": u"a",
    u"у": u"u",
    u"ю": u"u",
    u"о": u"o",
    u"ё": u"o",
    u"э": u"e",
    u"е": u"e",
    u"и": u"i",
    u"ы": u"y",
}


def pallatize(phones):
    for i, phone in enumerate(phones[:-1]):
        if phone[0] in softhard_cons:
            if phones[i + 1][0] in softletters:
                phones[i] = (softhard_cons[phone[0]] + "j", 0)
            else:
                phones[i] = (softhard_cons[phone[0]], 0)
        if phone[0] in other_cons:
            phones[i] = (other_cons[phone[0]], 0)


def convert_vowels(phones):
    new_phones = []
    prev = ""
    for phone in phones:
        if prev in startsyl:
            if phone[0] in set(u"яюеё"):
                new_phones.append("j")
        if phone[0] in vowels:
            new_phones.append(vowels[phone[0]] + str(phone[1]))
        else:
            new_phones.append(phone[0])
        prev = phone[0]

    return new_phones


def convert(stressword):
    phones = ("#" + stressword + "#")

    # Assign stress marks
    stress_phones = []
    stress = 0
    for phone in phones:
        if phone == "+":
            stress = 1
        else:
            stress_phones.append((phone, stress))
            stress = 0

    # Palatalize consonants
    pallatize(stress_phones)

    # Convert vowels
    phones = convert_vowels(stress_phones)

    # Filter
    phones = [x for x in phones if x not in others]
    return " ".join(phones)
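Tracing the functions above by hand for the stress-marked word 'прив+ет' ('+' marks the stressed vowel); the import path is illustrative and the expected output is a hand-derived check, not a recorded result:

from engines.teratts.TeraTTS.tokenizer.g2p.g2p import convert

print(convert('прив+ет'))  # expected: p rj i0 vj e1 t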
engines/teratts/TeraTTS/tokenizer/g2p/tokenizer.py · Normal file · 50 lines
@@ -0,0 +1,50 @@
import re
from .g2p import *  # noqa
import json
import os


class Tokenizer:
    def __init__(self, data_path: str, load_dict=True) -> None:
        """data_path: path to the data dir; load_dict: whether to load the pronunciation dictionary (not needed if you use an accent model such as ruaccent)."""
        self.dic = {}
        if load_dict:
            for line in open(os.path.join(data_path, "dictionary.txt")):  # noqa
                items = line.split()
                self.dic[items[0]] = " ".join(items[1:])

        self.config = json.load(open(os.path.join(data_path, "config.json")))  # noqa

    def g2p(self, text):
        text = re.sub("—", "-", text)
        text = re.sub("([!'(),-.:;?])", r' \1 ', text)

        phonemes = []
        for word in text.split():
            if re.match("[!'(),-.:;?]", word):
                phonemes.append(word)
                continue

            word = word.lower()
            if len(phonemes) > 0:
                phonemes.append(' ')

            if word in self.dic:
                phonemes.extend(self.dic[word].split())
            else:
                phonemes.extend(convert(word).split())  # noqa

        phoneme_id_map = self.config["phoneme_id_map"]
        phoneme_ids = []
        phoneme_ids.extend(phoneme_id_map["^"])
        phoneme_ids.extend(phoneme_id_map["_"])
        for p in phonemes:
            if p in phoneme_id_map:
                phoneme_ids.extend(phoneme_id_map[p])
                phoneme_ids.extend(phoneme_id_map["_"])
        phoneme_ids.extend(phoneme_id_map["$"])

        return phoneme_ids, phonemes

    def _get_seq(self, text: str) -> list[int]:
        seq = self.g2p(text)[0]
        return seq
engines/teratts/__init__.py · Normal file · 34 lines
@@ -0,0 +1,34 @@
from to_wav import ndarray2wav
from pathlib import Path
from loguru import logger
from EngineABC import EngineABC, ModelDescription, Argument
from .TeraTTS import TTS


class TeraTTSEngine(EngineABC):
    def discovery(self) -> tuple[ModelDescription, ...]:
        return tuple(
            ModelDescription(
                engine=self.__class__.__name__,
                name=model_name,
                arguments={
                    'length_scale': Argument(
                        type='float',
                        description="'length_scale' can be used to slow the audio down so it sounds better; defaults to 1.2")})
            for model_name in self.speakers.keys()
        )

    def __init__(self, save_path: Path):
        super().__init__(save_path)
        self.speakers: dict[str, TTS] = {}

        for speaker_name in ('natasha-g2p-vits', 'glados2-g2p-vits', 'glados-g2p-vits', 'girl_nice-g2p-vits'):
            logger.debug(f"Loading speaker: {speaker_name}")
            self.speakers[speaker_name] = TTS(f"TeraTTS/{speaker_name}", add_time_to_end=1.0, save_path=str(save_path / 'tts'))

    def synth(self, text: str, model: str, **kwargs) -> bytes:
        tts = self.speakers[model]
        return ndarray2wav(
            tts(text, **kwargs),
            sample_rate=22050
        )
main.py · Normal file · 66 lines
@@ -0,0 +1,66 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, Response
from pydantic import BaseModel

from EnginesController import EnginesController
from fastapi.concurrency import run_in_threadpool
from EngineABC import ModelDescription


def play_bytes(bytes_sound: bytes) -> None:
    import winsound  # Windows-only; imported lazily so the server runs elsewhere
    winsound.PlaySound(bytes_sound, winsound.SND_MEMORY)


engines_controller: EnginesController
preview_cache: dict[str, bytes] = dict()


class BodyText(BaseModel):
    text: str


@asynccontextmanager
async def lifespan(_app: FastAPI):
    # on_startup
    global engines_controller
    engines_controller = EnginesController()
    yield


app = FastAPI(lifespan=lifespan)


@app.get('/discovery')
async def discovery() -> list[ModelDescription]:
    return await run_in_threadpool(engines_controller.discovery)


@app.post('/synth/{engine}/model/{model}')
async def synth(engine: str, model: str, text: BodyText):
    res = await run_in_threadpool(
        engines_controller.synth,
        engine,
        model,
        text.text)

    # preview_cache[engine + model] = res

    return Response(content=res, media_type='audio/wav')


# @app.get('/synth/{engine}/model/{model}')
# async def synth_get(engine: str, model: str):
#     return Response(content=preview_cache[engine+model], media_type='audio/wav')


def main():
    c = EnginesController()
    c.discovery()
    c.synth_all()


if __name__ == '__main__':
    main()
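main() only runs a console smoke test; the FastAPI app itself is served by an ASGI server such as uvicorn (uvicorn main:app). A client sketch against a hypothetical local instance, using only the standard library:

import json
import urllib.request

req = urllib.request.Request(
    'http://127.0.0.1:8000/synth/Silero/model/baya',
    data=json.dumps({'text': 'Привет, мир'}).encode(),
    headers={'Content-Type': 'application/json'},
    method='POST',
)
with urllib.request.urlopen(req) as resp:
    open('out.wav', 'wb').write(resp.read())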
preprocessing.py · Normal file · 41 lines
@@ -0,0 +1,41 @@
import re
from num2words import num2words
from transliterate import translit
from ruaccent import RUAccent
from config import BASE


class Accents:
    def __init__(self):
        self.accentizer = RUAccent()
        self.accentizer.load(
            omograph_model_size='turbo',
            use_dictionary=True,
            workdir=str(BASE / 'preprocessing_accents')
        )

    def __call__(self, text: str) -> str:
        return self.accentizer.process_all(text)


preprocess_accents = Accents()


def preprocess_nums(text: str) -> str:
    def _num2wordsshor(match):
        match = match.group()
        ret = num2words(match, lang='ru')
        return ret

    return re.sub(r'\d+', _num2wordsshor, text)


def preprocess_translit(text: str) -> str:
    return translit(text, 'ru')


def preprocess(text: str) -> str:
    for preprocess_func in (preprocess_accents, preprocess_nums, preprocess_translit):
        text = preprocess_func(text)

    return text
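Note that importing preprocessing instantiates Accents (and loads the ruaccent models) at import time; once that is done, the numeric stage alone behaves like this sketch (the expected output assumes num2words' Russian cardinal forms):

from preprocessing import preprocess_nums

print(preprocess_nums('Съешь ещё 15 пирожков'))
# expected: Съешь ещё пятнадцать пирожков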
requirements.txt · Normal file · 10 lines
@@ -0,0 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cpu
onnxruntime==1.18.0
ruaccent==1.5.6.3
transliterate==1.10.2
num2words==0.5.13
torch==2.3.0+cpu
numpy==1.26.4
loguru~=0.7.2
fastapi==0.111.0
pydantic
to_wav.py · Normal file · 28 lines
@@ -0,0 +1,28 @@
import numpy
import io
import contextlib
import wave
import torch
from collections.abc import Iterable


def frames2wav(resulting_array: Iterable[int], sample_rate: int) -> bytes:
    res_io_stream = io.BytesIO()

    # writeframes() accepts any bytes-like object; int16 numpy arrays work via the buffer protocol
    with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(resulting_array)

    res_io_stream.seek(0)

    return res_io_stream.read()


def ndarray2wav(resulting_array: numpy.ndarray, sample_rate: int) -> bytes:
    return frames2wav((resulting_array * 32767).astype('int16'), sample_rate)


def tensor2wav(resulting_array: torch.Tensor, sample_rate: int) -> bytes:
    return frames2wav((resulting_array * 32767).to(torch.int16).numpy(), sample_rate)
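A quick round-trip check of these helpers using only the standard library and torch:

import io
import wave

import torch

from to_wav import tensor2wav

wav_bytes = tensor2wav(torch.zeros(16000), sample_rate=16000)  # one second of silence
with wave.open(io.BytesIO(wav_bytes)) as wf:
    assert wf.getnchannels() == 1
    assert wf.getframerate() == 16000
    assert wf.getnframes() == 16000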