Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
3a9989d0de | |||
bb28efe437 |
12
CLISynth.py
Normal file
12
CLISynth.py
Normal file
@ -0,0 +1,12 @@
|
||||
import TTSSilero
|
||||
from time import time
|
||||
import utils
|
||||
|
||||
tts = TTSSilero.TTSSilero()
|
||||
|
||||
# while msg := input('$ '):
|
||||
start = time()
|
||||
audio = tts.synthesize_text(
|
||||
"""Миша, давай заведем старый жигуль ... ви ви ви ви ви ви ви ви ви ви ви ви ви ви... ви ви ви ви ви ви ви ви ви ви... ви ви пр пр пр ви ви ви пр пр пр пр пр ви ви ви ви... миша, не сиди, помоги дотолкать до гаража, может там придётся перебирать её""")
|
||||
print('synthesize took ', str(time() - start))
|
||||
# utils.play_bytes(audio)
|
@ -1,17 +1,16 @@
|
||||
import contextlib
|
||||
import os
|
||||
import io
|
||||
import wave
|
||||
import torch.package
|
||||
|
||||
from .Speakers import Speakers
|
||||
from .multi_v2_package import TTSModelMulti_v2
|
||||
from multi_acc_v3_package import TTSModelMultiAcc_v3
|
||||
|
||||
|
||||
class TTSSilero:
|
||||
def __init__(self, threads: int = 12):
|
||||
def __init__(self, threads: int = 24):
|
||||
device = torch.device('cpu')
|
||||
torch.set_num_threads(threads)
|
||||
"""
|
||||
local_file = 'model_multi.pt'
|
||||
|
||||
if not os.path.isfile(local_file):
|
||||
@ -22,13 +21,22 @@ class TTSSilero:
|
||||
|
||||
self.model: TTSModelMulti_v2 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
|
||||
self.model.to(device)
|
||||
"""
|
||||
local_file = 'model.pt'
|
||||
|
||||
self.sample_rate = 16000
|
||||
if not os.path.isfile(local_file):
|
||||
torch.hub.download_url_to_file('https://models.silero.ai/models/tts/ru/ru_v3.pt', local_file)
|
||||
|
||||
self.model: TTSModelMultiAcc_v3 = torch.package.PackageImporter(local_file).load_pickle("tts_models", "model")
|
||||
self.model.to(device)
|
||||
# print(self.model.speakers)
|
||||
|
||||
self.sample_rate = 48000
|
||||
|
||||
def synthesize_text(self, text: str, speaker: Speakers = Speakers.kseniya) -> bytes:
|
||||
return self.to_wav(self._synthesize_text(text, speaker))
|
||||
|
||||
def _synthesize_text(self, text: str, speaker: Speakers) -> list[torch.Tensor]:
|
||||
def _synthesize_text(self, text: str, speaker: Speakers) -> torch.Tensor:
|
||||
"""
|
||||
Performs splitting text and synthesizing it
|
||||
|
||||
@ -36,24 +44,17 @@ class TTSSilero:
|
||||
:return:
|
||||
"""
|
||||
|
||||
results_list: list[torch.Tensor] = self.model.apply_tts(
|
||||
texts=[text],
|
||||
speakers=speaker.value,
|
||||
result: torch.Tensor = self.model.apply_tts(
|
||||
text=text,
|
||||
speaker=speaker.value,
|
||||
sample_rate=self.sample_rate
|
||||
)
|
||||
|
||||
return results_list
|
||||
return result
|
||||
|
||||
def to_wav(self, synthesized_text: list[torch.Tensor]) -> bytes:
|
||||
def to_wav(self, synthesized_text: torch.Tensor) -> bytes:
|
||||
res_io_stream = io.BytesIO()
|
||||
|
||||
with contextlib.closing(wave.open(res_io_stream, 'wb')) as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(self.sample_rate)
|
||||
for result in synthesized_text:
|
||||
wf.writeframes((result * 32767).numpy().astype('int16'))
|
||||
|
||||
self.model.write_wave(res_io_stream, (synthesized_text * 32767).numpy().astype('int16'), self.sample_rate)
|
||||
res_io_stream.seek(0)
|
||||
|
||||
return res_io_stream.read()
|
||||
|
381
multi_acc_v3_package.py
Normal file
381
multi_acc_v3_package.py
Normal file
@ -0,0 +1,381 @@
|
||||
import re
|
||||
import wave
|
||||
import torch
|
||||
import warnings
|
||||
import contextlib
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
||||
class TTSModelMultiAcc_v3():
|
||||
def __init__(self, model_path, symbols, speaker_to_id, emb_dim=128):
|
||||
torch.set_grad_enabled(False)
|
||||
self.model = self.init_jit_model(model_path)
|
||||
self.symbols = symbols
|
||||
self.device = torch.device('cpu')
|
||||
self.speaker_to_id = speaker_to_id
|
||||
self.speakers = list(speaker_to_id.keys())
|
||||
assert 'random' in self.speakers
|
||||
self.ru_ascii_dict = {r: asc for r, asc in zip('абвгдеёжзийклмнопрстуфхцчшщъыьэюя–',
|
||||
'abvgde1jzi2klmnoprstufhc4w35y6789=')}
|
||||
# ssml tags
|
||||
self.strength2time = {'x_weak': 25, 'weak': 75, 'medium': 150, 'strong': 300, 'x-strong': 1000}
|
||||
self.rate2value = {'x-slow': 0.5, 'slow': 0.8, 'medium': 1., 'fast': 1.2, 'x-fast': 1.5}
|
||||
self.pitch2value = {'x-low': 0.6, 'low': 0.8, 'medium': 1., 'high': 1.2, 'x-high': 1.4, 'robot': 0.}
|
||||
self.emb_dim = emb_dim
|
||||
|
||||
self.random_emb = None
|
||||
self.debug = False
|
||||
|
||||
self.valid_tags = {'break': {'strength': list(self.strength2time.keys())},
|
||||
'prosody': {'rate': list(self.rate2value.keys()),
|
||||
'pitch': list(self.pitch2value.keys())}}
|
||||
|
||||
def init_jit_model(self, model_path: str):
|
||||
torch.set_grad_enabled(False)
|
||||
model = torch.jit.load(model_path, map_location='cpu')
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def ru_to_ascii(self, sentence):
|
||||
ascii_list = [self.ru_ascii_dict.get(s, s) for s in sentence]
|
||||
ascii_text = ''.join(ascii_list)
|
||||
return ascii_text
|
||||
|
||||
def prepare_text_input(self, text):
|
||||
|
||||
text = text.lower()
|
||||
text = text.replace('—', '–').replace('–', '–').replace('‑', '-')
|
||||
text = re.sub(r'[^{}]'.format(self.symbols[3:]), '', text)
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
sentence = self.ru_to_ascii(text)
|
||||
|
||||
clean_sentence = re.sub(r'[^a-z1-9\- ]', '', sentence)
|
||||
has_text = len(clean_sentence.replace(' ', '')) > 0
|
||||
return sentence, clean_sentence, has_text
|
||||
|
||||
def prepare_tts_model_input(self, text: str,
|
||||
ssml: bool,
|
||||
speaker_ids: list):
|
||||
if ssml:
|
||||
clean_text_list = self.process_ssml(text)
|
||||
else:
|
||||
clean_text_list = self.process_simple_text(text)
|
||||
|
||||
sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches = map(list, zip(*map(dict.values, clean_text_list)))
|
||||
full_text_len = sum([len(s) for s in sentences])
|
||||
if full_text_len > 1000:
|
||||
warnings.warn('Text string is longer than 1000 symbols.')
|
||||
|
||||
speaker_ids = torch.LongTensor(speaker_ids)
|
||||
return sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches, speaker_ids
|
||||
|
||||
def to(self, device):
|
||||
self.model.tts_model = self.model.tts_model.to(device)
|
||||
self.device = device
|
||||
|
||||
def get_speakers(self, speaker: str, voice_path=None):
|
||||
try:
|
||||
if speaker == 'random':
|
||||
self.load_random_voice(voice_path)
|
||||
speaker_id = self.speaker_to_id.get(speaker, None)
|
||||
if speaker_id is None:
|
||||
raise ValueError(f"`speaker` should be in {', '.join(self.speakers)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f'Failed to load speaker: {speaker}, error: {e}')
|
||||
return [speaker_id]
|
||||
|
||||
def process_simple_text(self, text):
|
||||
sentence, clean_sentence, has_text = self.prepare_text_input(text)
|
||||
if not has_text:
|
||||
raise ValueError
|
||||
simple_text_dict = [{'text': sentence,
|
||||
'clean_text': clean_sentence,
|
||||
'break_time': None,
|
||||
'prosody_rate': 1.,
|
||||
'prosody_pitch': 1.}]
|
||||
return simple_text_dict
|
||||
|
||||
def process_ssml(self, ssml_text):
|
||||
ssml_text = re.sub(r'\s+', ' ', ssml_text).strip().replace('\n ', '\n')
|
||||
try:
|
||||
root = ET.fromstring(ssml_text)
|
||||
except Exception:
|
||||
raise ValueError("Invalid XML format")
|
||||
assert root.tag == 'speak', "Invalid SSML format: <speak> tag is essential"
|
||||
try:
|
||||
ssml_parsed = self.process_ssml_element(root)
|
||||
if self.debug:
|
||||
print(ssml_parsed)
|
||||
except AssertionError as ae:
|
||||
raise ae
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse SSML: {e}")
|
||||
try:
|
||||
clean_text_list = self.process_ssml_tag_dict(ssml_parsed)
|
||||
if self.debug:
|
||||
print(clean_text_list)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to process SSML: {e}")
|
||||
return clean_text_list
|
||||
|
||||
def process_ssml_tag_dict(self, text_break_list):
|
||||
proc_text_break_list = []
|
||||
for i, text_break_prosody in enumerate(text_break_list):
|
||||
tbreak = text_break_prosody['break']
|
||||
tprosody = text_break_prosody['prosody']
|
||||
text, clean_text, has_text = self.prepare_text_input(text_break_prosody['text'])
|
||||
break_time = int(tbreak['time']/12.5) if tbreak['time'] is not None else None
|
||||
if has_text or i == 0:
|
||||
text = self.check_text_break(text, tbreak)
|
||||
proc_text_break_list.append({'text': text,
|
||||
'clean_text': clean_text,
|
||||
'break_time': break_time,
|
||||
'prosody_rate': tprosody['rate'],
|
||||
'prosody_pitch': tprosody['pitch']})
|
||||
elif tbreak['strength'] is not None and len(proc_text_break_list) > 0:
|
||||
text = self.check_text_break(proc_text_break_list[-1]['text'], tbreak)
|
||||
proc_text_break_list[-1]['text'] = text
|
||||
if proc_text_break_list[-1]['break_time'] is None:
|
||||
proc_text_break_list[-1]['break_time'] = break_time
|
||||
else:
|
||||
proc_text_break_list[-1]['break_time'] = max(break_time, proc_text_break_list[-1]['break_time'])
|
||||
return proc_text_break_list
|
||||
|
||||
def check_text_break(self, text, tbreak):
|
||||
if len(text) == 0 or tbreak['strength'] is not None and text[-1] not in '!,-.:;?–…': # TODO fx dash
|
||||
text = text + '.'
|
||||
return text
|
||||
|
||||
def process_ssml_element(self, element, def_strength='strong', def_rate=1., def_pitch=1.):
|
||||
parsed = []
|
||||
last_tag = None
|
||||
head_text_parsed = self.process_head_tail_text(element.text, def_rate, def_pitch)
|
||||
parsed.extend(head_text_parsed)
|
||||
|
||||
for child in element:
|
||||
if child.tag == 'break':
|
||||
break_strength, break_ts = self.process_break_attrib(child.attrib)
|
||||
if len(parsed) == 0:
|
||||
parsed.append({'text': '.',
|
||||
'break': {'strength': None,
|
||||
'time': None},
|
||||
'prosody': {'rate': def_rate,
|
||||
'pitch': def_pitch}})
|
||||
parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': break_ts}
|
||||
|
||||
elif child.tag == 'prosody':
|
||||
prosody_rate, prosody_pitch, change_rate, change_pitch = self.process_prosody(child.attrib)
|
||||
child_rate = prosody_rate if change_rate else def_rate
|
||||
child_pitch = prosody_pitch if change_pitch else def_pitch
|
||||
child_parsed = self.process_ssml_element(child, def_strength, child_rate, child_pitch)
|
||||
parsed.extend(child_parsed)
|
||||
|
||||
elif child.tag in ['p', 's']:
|
||||
break_strength = 'strong' if child.tag == 's' else 'x-strong'
|
||||
child_parsed = self.process_ssml_element(child, break_strength, def_rate, def_pitch)
|
||||
if len(parsed) > 0 and (parsed[-1]['text'] or last_tag is not None or last_tag != child.tag):
|
||||
if parsed[-1]['break']['strength'] is None:
|
||||
parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': self.strength2time[break_strength]}
|
||||
else:
|
||||
last_time = parsed[-1]['break']['time']
|
||||
parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': max(last_time, self.strength2time[break_strength])}
|
||||
if len(child_parsed) > 0:
|
||||
if child_parsed[-1]['break']['strength'] is None:
|
||||
child_parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': self.strength2time[break_strength]}
|
||||
else:
|
||||
last_time = child_parsed[-1]['break']['time']
|
||||
child_parsed[-1]['break'] = {'strength': break_strength,
|
||||
'time': max(last_time, self.strength2time[break_strength])}
|
||||
parsed.extend(child_parsed)
|
||||
else:
|
||||
warnings.warn(f"Current model doesn't support SSML tag: {child.tag}")
|
||||
|
||||
last_tag = child.tag
|
||||
|
||||
if child.tail:
|
||||
tail_text = child.tail
|
||||
if tail_text[0] in '.,!?…–;:' and len(parsed) > 0:
|
||||
lost_punct = tail_text[0]
|
||||
parsed[-1]['text'] = parsed[-1]['text'].strip() + lost_punct
|
||||
if len(tail_text) > 1:
|
||||
tail_text = tail_text[1:]
|
||||
tail_text_parsed = self.process_head_tail_text(tail_text, def_rate, def_pitch)
|
||||
parsed.extend(tail_text_parsed)
|
||||
return parsed
|
||||
|
||||
def process_head_tail_text(self, element_text, def_rate, def_pitch):
|
||||
text_parsed = []
|
||||
if element_text is None:
|
||||
return text_parsed
|
||||
proc_text = element_text.replace('\n', '')
|
||||
proc_text = re.sub(r'\s+', ' ', proc_text).strip()
|
||||
text_parsed.append({'text': proc_text,
|
||||
'break': {'strength': None,
|
||||
'time': None},
|
||||
'prosody': {'rate': def_rate,
|
||||
'pitch': def_pitch}})
|
||||
return text_parsed
|
||||
|
||||
def process_break_attrib(self, attrib):
|
||||
for k in attrib.keys():
|
||||
if k not in ['strength', 'time']:
|
||||
warnings.warn(f"Current model doesn't support SSML <break> attrib: {k}")
|
||||
strength = attrib.get('strength', 'medium')
|
||||
break_time = attrib.get('time', None)
|
||||
if break_time is not None:
|
||||
if break_time.endswith('ms'):
|
||||
break_ts = int(break_time[:-2])
|
||||
elif break_time.endswith('s'):
|
||||
break_ts = int(break_time[:-1]) * 1000
|
||||
else:
|
||||
raise AssertionError("Invalid <break> tag, time should end with 'ms' or 's'")
|
||||
|
||||
if break_ts >= self.strength2time['x-strong']:
|
||||
strength = 'x-strong'
|
||||
elif break_ts >= self.strength2time['strong']:
|
||||
strength = 'strong'
|
||||
else:
|
||||
if strength in self.strength2time:
|
||||
break_ts = self.strength2time[strength]
|
||||
else:
|
||||
raise AssertionError(f"Invalid <break> tag, strength should be in {', '.join(self.valid_tags['break']['strength'])}")
|
||||
if break_ts > 5000:
|
||||
warnings.warn('Cuurent model supports pauses less than 5 sec')
|
||||
break_ts = 5000
|
||||
return strength, break_ts
|
||||
|
||||
def process_prosody(self, attrib):
|
||||
for k in attrib.keys():
|
||||
if k not in ['rate', 'pitch']:
|
||||
warnings.warn(f"Current model doesn't support SSML <prosody> attrib: {k}")
|
||||
rate = attrib.get('rate', None)
|
||||
pitch = attrib.get('pitch', None)
|
||||
assert rate is not None or pitch is not None, "Empty <prosody> tag"
|
||||
if rate is not None:
|
||||
change_rate = True
|
||||
if rate.endswith('%'):
|
||||
rate_val = int(rate.replace('%', '')) / 100
|
||||
else:
|
||||
rate_val = self.rate2value.get(rate, None)
|
||||
if rate_val is None:
|
||||
raise AssertionError(f"Invalid <prosody> tag, rate should be in {', '.join(self.valid_tags['prosody']['rate'])}")
|
||||
else:
|
||||
change_rate = False
|
||||
rate_val = 1.
|
||||
if pitch is not None:
|
||||
change_pitch = True
|
||||
if pitch.endswith('%'):
|
||||
pitch_val = int(pitch.replace('%', '')[1:]) / 100
|
||||
if pitch[0] == '+':
|
||||
pitch_val = 1. + pitch_val
|
||||
else:
|
||||
pitch_val = 1. - pitch_val
|
||||
else:
|
||||
pitch_val = self.pitch2value.get(pitch, None)
|
||||
if pitch_val is None:
|
||||
raise AssertionError(f"Invalid <prosody> tag, pitch should be in {', '.join(self.valid_tags['prosody']['pitch'])}")
|
||||
else:
|
||||
change_pitch = False
|
||||
pitch_val = 1.
|
||||
return rate_val, pitch_val, change_rate, change_pitch
|
||||
|
||||
def apply_tts(self, text=None,
|
||||
ssml_text=None,
|
||||
speaker: str = 'xenia',
|
||||
sample_rate: int = 48000,
|
||||
put_accent=True,
|
||||
put_yo=True,
|
||||
voice_path=None):
|
||||
|
||||
assert sample_rate in [8000, 24000, 48000], f"`sample_rate` should be in [8000, 24000, 48000], current value is {sample_rate}"
|
||||
assert speaker in self.speakers, f"`speaker` should be in {', '.join(self.speakers)}"
|
||||
assert text is not None or ssml_text is not None, "Both `text` and `ssml_text` are empty"
|
||||
|
||||
ssml = ssml_text is not None
|
||||
if ssml:
|
||||
input_text = ssml_text
|
||||
else:
|
||||
input_text = text
|
||||
speaker_ids = self.get_speakers(speaker, voice_path)
|
||||
sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches, sp_ids = self.prepare_tts_model_input(input_text,
|
||||
ssml=ssml,
|
||||
speaker_ids=speaker_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
try:
|
||||
out, out_lens = self.model(sentences=sentences,
|
||||
clean_sentences=clean_sentences,
|
||||
break_lens=break_lens,
|
||||
prosody_rates=prosody_rates,
|
||||
prosody_pitches=prosody_pitches,
|
||||
speaker_ids=sp_ids,
|
||||
sr=sample_rate,
|
||||
device=str(self.device),
|
||||
put_yo=put_yo,
|
||||
put_accent=put_accent
|
||||
)
|
||||
except RuntimeError as e:
|
||||
raise Exception("Model couldn't generate your text, probably it's too long")
|
||||
audio = out.to('cpu')[0]
|
||||
return audio
|
||||
|
||||
@staticmethod
|
||||
def write_wave(path, audio, sample_rate):
|
||||
"""Writes a .wav file.
|
||||
Takes path, PCM audio data, and sample rate.
|
||||
"""
|
||||
with contextlib.closing(wave.open(path, 'wb')) as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(sample_rate)
|
||||
wf.writeframes(audio)
|
||||
|
||||
def save_wav(self, text=None,
|
||||
ssml_text=None,
|
||||
speaker: str = 'xenia',
|
||||
audio_path: str = '',
|
||||
sample_rate: int = 48000,
|
||||
put_accent=True,
|
||||
put_yo=True):
|
||||
|
||||
if not audio_path:
|
||||
audio_path = 'test.wav'
|
||||
|
||||
audio = self.apply_tts(text=text,
|
||||
ssml_text=ssml_text,
|
||||
speaker=speaker,
|
||||
sample_rate=sample_rate,
|
||||
put_yo=put_yo,
|
||||
put_accent=put_accent)
|
||||
self.write_wave(path=audio_path,
|
||||
audio=(audio * 32767).numpy().astype('int16'),
|
||||
sample_rate=sample_rate)
|
||||
return audio_path
|
||||
|
||||
def load_random_voice(self, voice_path=None):
|
||||
if voice_path is None:
|
||||
random_emb = torch.randn(2, self.emb_dim, requires_grad=False).to(self.device)
|
||||
self.random_emb = random_emb
|
||||
print("Generated new voice")
|
||||
else:
|
||||
random_emb = torch.load(voice_path, map_location=self.device)
|
||||
print(f"Loaded voice from {voice_path}")
|
||||
if self.random_emb is not None and torch.equal(self.random_emb, random_emb):
|
||||
return
|
||||
mel_weight = random_emb[0]
|
||||
dur_weight = random_emb[1, :self.emb_dim//2]
|
||||
p_weight = random_emb[1, self.emb_dim//2:]
|
||||
|
||||
self.model.tts_model.tacotron.speaker_embedding.weight[-1] = mel_weight
|
||||
self.model.tts_model.dur_predictor.dur_pred.speaker_embedding.weight[-1] = dur_weight
|
||||
self.model.tts_model.pitch_predictor.pitch_pred.speaker_embedding.weight[-1] = p_weight
|
||||
|
||||
def save_random_voice(self, voice_path):
|
||||
assert self.random_emb is not None, "No generated random voice"
|
||||
torch.save(self.random_emb, voice_path)
|
||||
print(f"Saved generated voice to {voice_path}")
|
Loading…
x
Reference in New Issue
Block a user