This commit is contained in:
norohind 2024-05-28 22:50:48 +03:00
commit f123cc1f86
Signed by: norohind
GPG Key ID: 01C3BECC26FB59E1
18 changed files with 702 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
__pycache__
.idea
cache

33
EngineABC.py Normal file
View File

@@ -0,0 +1,33 @@
from typing import Literal
from abc import abstractmethod, ABC
from dataclasses import dataclass
from pathlib import Path
from pydantic import BaseModel
class Argument(BaseModel):
type: Literal['str', 'int', 'float']
description: str | None = None
class ModelDescription(BaseModel):
engine: str
name: str
arguments: dict[str, Argument]
description: None | str = None
class EngineABC(ABC):
def __init__(self, save_path: 'Path'):
self.save_path = save_path
@abstractmethod
def discovery(self) -> tuple[ModelDescription, ...]:
...
@abstractmethod
def synth(self, text: str, model: str, *args, **kwargs) -> bytes:
...

40
EnginesController.py Normal file
View File

@@ -0,0 +1,40 @@
from loguru import logger
import importlib.machinery  # explicit submodule import so importlib.machinery is guaranteed to be available
from EngineABC import EngineABC, ModelDescription
from pathlib import Path
from analytics import measure
from preprocessing import preprocess
import config
class EnginesController:
def __init__(self):
self.engines: dict[str, EngineABC] = dict()
engines_dir = Path('engines')
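        # every subdirectory of engines/ that ships an __init__.py is treated as a pluggable engine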
for module_dir in engines_dir.iterdir():
init_path = module_dir / '__init__.py'
if init_path.exists():
module = importlib.machinery.SourceFileLoader(module_dir.name, str(init_path)).load_module()
logger.info(f'imported {init_path.parent.name!r} engine')
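                # instantiate every class defined in the module that directly subclasses EngineABC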
for obj_name in dir(module):
obj = getattr(module, obj_name)
if isinstance(obj, type) and EngineABC in obj.__bases__:
obj: type[EngineABC]
cache_path = config.BASE / obj.__name__
cache_path.mkdir(exist_ok=True)
instance: EngineABC = obj(save_path=cache_path)
self.engines[obj.__name__] = instance
def discovery(self) -> list[ModelDescription]:
to_return: list[ModelDescription] = list()
for name, engine in self.engines.items():
to_return += engine.discovery()
return to_return
@measure
def synth(self, engine: str, model: str, text: str, **kwargs) -> bytes:
return self.engines[engine].synth(preprocess(text), model, **kwargs)
def synth_all(self, text: str = 'Съешь ещё 15 этих пирожков') -> list[bytes]:
return [self.synth(model.engine, model.name, text=text) for model in self.discovery()]

38
analytics.py Normal file
View File

@@ -0,0 +1,38 @@
import time
def measure(function: callable, name_to_display: str = ''):
"""
Decorator to measure function (method) execution time
Use as easy as
@utils.measure
def im_function_to_measure():
....
:param name_to_display:
:param function:
:return:
"""
if name_to_display != '':
name_to_display = name_to_display + ':'
def decorated(*args, **kwargs):
start = time.time()
result = function(*args, **kwargs)
end = time.time()
            # time.time() is in seconds, so multiply by 1000 to report milliseconds
            print(f'{name_to_display}{function.__name__}:{args[1:]} {(end - start) * 1000} ms')
return result
return decorated
class Measure:
def __init__(self, name: str):
self.start = time.time()
self.name = name
def record(self) -> None:
        print(f'{self.name}: {(time.time() - self.start) * 1000} ms')  # seconds -> milliseconds

3
config.py Normal file
View File

@@ -0,0 +1,3 @@
from pathlib import Path
BASE = Path('cache')

View File

@@ -0,0 +1,45 @@
from EngineABC import EngineABC, ModelDescription
import torch.package
from loguru import logger
from to_wav import tensor2wav
import typing
if typing.TYPE_CHECKING:
from .multi_v2_package import TTSModelMulti_v2
from pathlib import Path
class Silero(EngineABC):
def __init__(self, save_path: 'Path'):
super().__init__(save_path)
threads = 12
device = torch.device('cpu')
torch.set_num_threads(threads)
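        # download the multi-speaker v2 model on first use and cache it under save_path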
dst = save_path / 'model_multi.pt'
if not dst.exists():
torch.hub.download_url_to_file(
'https://models.silero.ai/models/tts/multi/v2_multi.pt',
str(dst)
)
self.model: TTSModelMulti_v2 = torch.package.PackageImporter(dst).load_pickle("tts_models", "model")
self.model.to(device)
logger.debug(f"Loading speaker: v2_multi")
self.sample_rate = 16000
def synth(self, text: str, model: str, *args, **kwargs) -> bytes:
# model is speaker
return tensor2wav(
self.model.apply_tts(texts=text, speakers=model, sample_rate=self.sample_rate)[0],
self.sample_rate
)
def discovery(self) -> tuple[ModelDescription, ...]:
return tuple(
ModelDescription(
engine=self.__class__.__name__,
name=speaker_name,
arguments=dict()
)
for speaker_name in self.model.speaker_to_id.keys()
)

View File

@@ -0,0 +1,148 @@
import re
import wave
import torch
import warnings
import contextlib
# for type hints only
class TTSModelMulti_v2:
def __init__(self, model_path, symbols):
self.model = self.init_jit_model(model_path)
self.symbols = symbols
self.device = torch.device('cpu')
speakers = ['aidar', 'baya', 'kseniya', 'irina', 'ruslan', 'natasha',
'thorsten', 'tux', 'gilles', 'lj', 'dilyara']
self.speaker_to_id = {sp: i for i, sp in enumerate(speakers)}
def init_jit_model(self, model_path: str):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location='cpu')
model.eval()
return model
def prepare_text_input(self, text, symbols, symbol_to_id=None):
if len(text) > 140:
warnings.warn('Text string is longer than 140 symbols.')
if symbol_to_id is None:
symbol_to_id = {s: i for i, s in enumerate(symbols)}
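        # lowercase the input and drop any character that is not in the model's symbol set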
text = text.lower()
text = re.sub(r'[^{}]'.format(symbols[2:]), '', text)
text = re.sub(r'\s+', ' ', text).strip()
if text[-1] not in ['.', '!', '?']:
text = text + '.'
text = text + symbols[1]
text_ohe = [symbol_to_id[s] for s in text if s in symbols]
text_tensor = torch.LongTensor(text_ohe)
return text_tensor
def prepare_tts_model_input(self, text: str or list, symbols: str, speakers: list):
assert len(speakers) == len(text) or len(speakers) == 1
if type(text) == str:
text = [text]
symbol_to_id = {s: i for i, s in enumerate(symbols)}
if len(text) == 1:
return self.prepare_text_input(text[0], symbols, symbol_to_id).unsqueeze(0), torch.LongTensor(speakers), torch.LongTensor([0])
text_tensors = []
for string in text:
string_tensor = self.prepare_text_input(string, symbols, symbol_to_id)
text_tensors.append(string_tensor)
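        # sort utterances by descending length so the whole batch can be padded and fed at once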
input_lengths, ids_sorted_decreasing = torch.sort(
torch.LongTensor([len(t) for t in text_tensors]),
dim=0, descending=True)
max_input_len = input_lengths[0]
batch_size = len(text_tensors)
text_padded = torch.ones(batch_size, max_input_len, dtype=torch.int32)
if len(speakers) == 1:
speakers = speakers*batch_size
speaker_ids = torch.LongTensor(batch_size).zero_()
for i, idx in enumerate(ids_sorted_decreasing):
text_tensor = text_tensors[idx]
in_len = text_tensor.size(0)
text_padded[i, :in_len] = text_tensor
speaker_ids[i] = speakers[idx]
return text_padded, speaker_ids, ids_sorted_decreasing
def process_tts_model_output(self, out, out_lens, ids):
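        # restore the original utterance order and trim each output to its real length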
out = out.to('cpu')
out_lens = out_lens.to('cpu')
_, orig_ids = ids.sort()
proc_outs = []
orig_out = out.index_select(0, orig_ids)
orig_out_lens = out_lens.index_select(0, orig_ids)
for i, out_len in enumerate(orig_out_lens):
proc_outs.append(orig_out[i][:out_len])
return proc_outs
def to(self, device):
self.model = self.model.to(device)
self.device = device
def get_speakers(self, speakers: str or list):
if type(speakers) == str:
speakers = [speakers]
speaker_ids = []
for speaker in speakers:
try:
speaker_id = self.speaker_to_id[speaker]
speaker_ids.append(speaker_id)
except Exception:
raise ValueError(f'No such speaker: {speaker}')
return speaker_ids
def apply_tts(self, texts: str or list,
speakers: str or list,
sample_rate: int = 16000):
speaker_ids = self.get_speakers(speakers)
text_padded, speaker_ids, orig_ids = self.prepare_tts_model_input(texts,
symbols=self.symbols,
speakers=speaker_ids)
with torch.inference_mode():
out, out_lens = self.model(text_padded.to(self.device),
speaker_ids.to(self.device),
sr=sample_rate)
audios = self.process_tts_model_output(out, out_lens, orig_ids)
return audios
@staticmethod
def write_wave(path, audio, sample_rate):
"""Writes a .wav file.
Takes path, PCM audio data, and sample rate.
"""
with contextlib.closing(wave.open(path, 'wb')) as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio)
def save_wav(self, texts: str or list,
speakers: str or list,
audio_pathes: str or list = '',
sample_rate: int = 16000):
if type(texts) == str:
texts = [texts]
if not audio_pathes:
audio_pathes = [f'test_{str(i).zfill(3)}.wav' for i in range(len(texts))]
if type(audio_pathes) == str:
audio_pathes = [audio_pathes]
assert len(audio_pathes) == len(texts)
audio = self.apply_tts(texts=texts,
speakers=speakers,
sample_rate=sample_rate)
for i, _audio in enumerate(audio):
self.write_wave(path=audio_pathes[i],
audio=(_audio * 32767).numpy().astype('int16'),
sample_rate=sample_rate)
return audio_pathes

View File

@@ -0,0 +1,2 @@
from .infer_onnx import TTS
from .tokenizer import TokenizerG2P

View File

@@ -0,0 +1,65 @@
import os
import onnxruntime
import numpy as np
from huggingface_hub import snapshot_download
from .tokenizer import TokenizerG2P
class TTS:
def __init__(self, model_name: str, save_path: str = "./model", add_time_to_end: float = 1.0, tokenizer_load_dict=True) -> None:
if not os.path.exists(save_path):
os.mkdir(save_path)
model_dir = os.path.join(save_path, model_name)
if not os.path.exists(model_dir):
snapshot_download(repo_id=model_name,
allow_patterns=["*.txt", "*.onnx", "*.json"],
local_dir=model_dir
)
self.model = onnxruntime.InferenceSession(os.path.join(model_dir, "exported/model.onnx"),
providers=['CPUExecutionProvider'])
self.tokenizer = TokenizerG2P(os.path.join(model_dir, "exported"), load_dict=tokenizer_load_dict)
self.add_time_to_end = add_time_to_end
def _add_silent(self, audio, silence_duration: float = 1.0, sample_rate: int = 22050):
num_samples_silence = int(sample_rate * silence_duration)
silence_array = np.zeros(num_samples_silence, dtype=np.float32)
audio_with_silence = np.concatenate((audio, silence_array), axis=0)
return audio_with_silence
def _intersperse(self, lst, item):
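        # interleave `item` between all elements (and at both ends); phoneme ids are padded this way before inference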
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def _get_seq(self, text):
phoneme_ids = self.tokenizer._get_seq(text)
phoneme_ids_inter = self._intersperse(phoneme_ids, 0)
return phoneme_ids_inter
def __call__(self, text: str, play=False, lenght_scale=1.2):
phoneme_ids = self._get_seq(text)
text = np.expand_dims(np.array(phoneme_ids, dtype=np.int64), 0)
text_lengths = np.array([text.shape[1]], dtype=np.int64)
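        # inference scales for the exported VITS-style model: likely noise scale, length scale (speech rate) and noise width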
scales = np.array(
[0.667, lenght_scale, 0.8],
dtype=np.float32,
)
audio = self.model.run(
None,
{
"input": text,
"input_lengths": text_lengths,
"scales": scales,
"sid": None,
},
)[0][0, 0][0]
audio = self._add_silent(audio, silence_duration=self.add_time_to_end)
if play:
self.play_audio(audio)
return audio

View File

@@ -0,0 +1 @@
from .g2p import Tokenizer as TokenizerG2P

View File

@@ -0,0 +1 @@
from .tokenizer import Tokenizer

View File

@@ -0,0 +1,94 @@
softletters=set(u"яёюиье")
startsyl=set(u"#ъьаяоёуюэеиы-")
others = set(["#", "+", "-", u"ь", u"ъ"])
softhard_cons = {
u"б" : u"b",
u"в" : u"v",
u"г" : u"g",
u"Г" : u"g",
u"д" : u"d",
u"з" : u"z",
u"к" : u"k",
u"л" : u"l",
u"м" : u"m",
u"н" : u"n",
u"п" : u"p",
u"р" : u"r",
u"с" : u"s",
u"т" : u"t",
u"ф" : u"f",
u"х" : u"h"
}
other_cons = {
u"ж" : u"zh",
u"ц" : u"c",
u"ч" : u"ch",
u"ш" : u"sh",
u"щ" : u"sch",
u"й" : u"j"
}
vowels = {
u"а" : u"a",
u"я" : u"a",
u"у" : u"u",
u"ю" : u"u",
u"о" : u"o",
u"ё" : u"o",
u"э" : u"e",
u"е" : u"e",
u"и" : u"i",
u"ы" : u"y",
}
def pallatize(phones):
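    # transliterate consonants, appending "j" when the following letter is a softening vowel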
for i, phone in enumerate(phones[:-1]):
if phone[0] in softhard_cons:
if phones[i+1][0] in softletters:
phones[i] = (softhard_cons[phone[0]] + "j", 0)
else:
phones[i] = (softhard_cons[phone[0]], 0)
if phone[0] in other_cons:
phones[i] = (other_cons[phone[0]], 0)
def convert_vowels(phones):
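    # insert a "j" glide before iotated vowels at syllable starts and map vowels to Latin symbols plus their stress mark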
new_phones = []
prev = ""
for phone in phones:
if prev in startsyl:
if phone[0] in set(u"яюеё"):
new_phones.append("j")
if phone[0] in vowels:
new_phones.append(vowels[phone[0]] + str(phone[1]))
else:
new_phones.append(phone[0])
prev = phone[0]
return new_phones
def convert(stressword):
phones = ("#" + stressword + "#")
# Assign stress marks
stress_phones = []
stress = 0
for phone in phones:
if phone == "+":
stress = 1
else:
stress_phones.append((phone, stress))
stress = 0
# Pallatize
pallatize(stress_phones)
# Assign stress
phones = convert_vowels(stress_phones)
# Filter
phones = [x for x in phones if x not in others]
return " ".join(phones)

View File

@@ -0,0 +1,50 @@
import re
from .g2p import * #noqa
import json
import os
class Tokenizer():
def __init__(self, data_path: str, load_dict=True) -> None:
        '''data_path - path to the data dir; load_dict - whether to load the stress dictionary (not needed if you use an accent model such as ruaccent)'''
self.dic = {}
if load_dict:
for line in open(os.path.join(data_path, "dictionary.txt")): #noqa
items = line.split()
self.dic[items[0]] = " ".join(items[1:])
self.config = json.load(open(os.path.join(data_path, "config.json"))) #noqa
def g2p(self, text):
text = re.sub("", "-", text)
text = re.sub("([!'(),-.:;?])", r' \1 ', text)
phonemes = []
for word in text.split():
if re.match("[!'(),-.:;?]", word):
phonemes.append(word)
continue
word = word.lower()
if len(phonemes) > 0:
phonemes.append(' ')
if word in self.dic:
phonemes.extend(self.dic[word].split())
else:
phonemes.extend(convert(word).split()) #noqa
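        # map phonemes to ids; '^' and '$' mark utterance start/end, '_' separates phonemes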
phoneme_id_map = self.config["phoneme_id_map"]
phoneme_ids = []
phoneme_ids.extend(phoneme_id_map["^"])
phoneme_ids.extend(phoneme_id_map["_"])
for p in phonemes:
if p in phoneme_id_map:
phoneme_ids.extend(phoneme_id_map[p])
phoneme_ids.extend(phoneme_id_map["_"])
phoneme_ids.extend(phoneme_id_map["$"])
return phoneme_ids, phonemes
def _get_seq(self, text: str) -> list[int]:
seq = self.g2p(text)[0]
return seq

View File

@@ -0,0 +1,34 @@
from to_wav import ndarray2wav
from pathlib import Path
from loguru import logger
from EngineABC import EngineABC, ModelDescription, Argument
from .TeraTTS import TTS
class TeraTTSEngine(EngineABC):
def discovery(self) -> tuple[ModelDescription, ...]:
return tuple(
ModelDescription(
engine=self.__class__.__name__,
name=model_name,
arguments={
'lenght_scale': Argument(
type='float',
description="'length_scale' можно использовать для замедления аудио для лучшего звучания, по умолчанию 1.1")})
for model_name in self.speakers.keys()
)
def __init__(self, save_path: 'Path'):
super().__init__(save_path)
self.speakers: dict[str, TTS] = {}
for speaker_name in ('natasha-g2p-vits', 'glados2-g2p-vits', 'glados-g2p-vits', 'girl_nice-g2p-vits'):
logger.debug(f"Loading speaker: {speaker_name}")
self.speakers[speaker_name] = TTS(f"TeraTTS/{speaker_name}", add_time_to_end=1.0, save_path=str(save_path / 'tts'))
def synth(self, text: str, model: str, **kwargs) -> bytes:
tts = self.speakers[model]
return ndarray2wav(
tts(text, **kwargs),
sample_rate=22050
)

66
main.py Normal file
View File

@@ -0,0 +1,66 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, Response
from pprint import pprint
from loguru import logger
from pydantic import BaseModel
from EnginesController import EnginesController
from fastapi.concurrency import run_in_threadpool
from EngineABC import ModelDescription
def play_bytes(bytes_sound: bytes) -> None:
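    # local debugging helper; winsound is available on Windows only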
import winsound
winsound.PlaySound(bytes_sound, winsound.SND_MEMORY)
engines_controller: EnginesController
preview_cache: dict[str, bytes] = dict()
class BodyText(BaseModel):
text: str
@asynccontextmanager
async def lifespan(_app: FastAPI):
# on_startup
global engines_controller
engines_controller = EnginesController()
yield
app = FastAPI(lifespan=lifespan)
@app.get('/discovery')
async def discovery() -> list[ModelDescription]:
return await run_in_threadpool(engines_controller.discovery)
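# run blocking synthesis in a thread pool and return the raw WAV bytes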
@app.post('/synth/{engine}/model/{model}')
async def synth(engine: str, model: str, text: BodyText):
res = await run_in_threadpool(
engines_controller.synth,
engine,
model,
text.text)
# preview_cache[engine + model] = res
return Response(content=res, media_type='audio/wav')
# @app.get('/synth/{engine}/model/{model}')
# async def synth_get(engine: str, model: str):
# return Response(content=preview_cache[engine+model], media_type='audio/wav')
def main():
c = EnginesController()
c.discovery()
c.synth_all()
if __name__ == '__main__':
main()

41
preprocessing.py Normal file
View File

@@ -0,0 +1,41 @@
import re
from num2words import num2words
from transliterate import translit
from ruaccent import RUAccent
from config import BASE
class Accents:
def __init__(self):
self.accentizer = RUAccent()
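        # load the stress-placement (accent) models, cached under the shared BASE directory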
self.accentizer.load(
omograph_model_size='turbo',
use_dictionary=True,
workdir=str(BASE / 'preprocessing_accents')
)
def __call__(self, text: str) -> str:
return self.accentizer.process_all(text)
preprocess_accents = Accents()
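# spell out every digit sequence as Russian words so the TTS models can pronounce them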
def preprocess_nums(text: str) -> str:
def _num2wordsshor(match):
match = match.group()
ret = num2words(match, lang='ru')
return ret
return re.sub(r'\d+', _num2wordsshor, text)
def preprocess_translit(text: str) -> str:
return translit(text, 'ru')
def preprocess(text: str) -> str:
for preprocess_func in (preprocess_accents, preprocess_nums, preprocess_translit):
text = preprocess_func(text)
return text

10
requirements.txt Normal file
View File

@@ -0,0 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cpu
onnxruntime==1.18.0
ruaccent==1.5.6.3
transliterate==1.10.2
num2words==0.5.13
torch==2.3.0+cpu
numpy==1.26.4
loguru~=0.7.2
fastapi==0.111.0
pydantic

28
to_wav.py Normal file
View File

@@ -0,0 +1,28 @@
import numpy
import io
import contextlib
import wave
import torch
from collections.abc import Iterable
def frames2wav(resulting_array: Iterable[int], sample_rate: int) -> bytes:
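    # wrap raw 16-bit mono PCM frames into an in-memory WAV container and return its bytes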
res_io_stream = io.BytesIO()
with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(resulting_array)
res_io_stream.seek(0)
return res_io_stream.read()
def ndarray2wav(resulting_array: numpy.ndarray, sample_rate: int) -> bytes:
return frames2wav((resulting_array * 32767).astype('int16'), sample_rate)
def tensor2wav(resulting_array: torch.Tensor, sample_rate: int) -> bytes:
return frames2wav((resulting_array * 32767).to(torch.int16).numpy(), sample_rate)