SileroTTSBot/multi_acc_v3_package.py
2022-04-21 16:17:10 +03:00

382 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import wave
import torch
import warnings
import contextlib
import xml.etree.ElementTree as ET
class TTSModelMultiAcc_v3():
    """Wrapper around a TorchScript multi-speaker TTS model with basic SSML support.

    The class normalises input text (lower-casing, filtering by ``symbols`` and
    transliterating Cyrillic letters to an ASCII alphabet), parses a subset of
    SSML (``<speak>``, ``<break>``, ``<prosody>``, ``<p>``, ``<s>``) into a flat
    list of text segments, and feeds those segments to the loaded model.
    """

    def __init__(self, model_path, symbols, speaker_to_id, emb_dim=128):
        """Load the model and set up the lookup tables.

        Args:
            model_path: Path passed to ``torch.jit.load``.
            symbols: String of symbols; everything from index 3 onwards is the
                set of characters kept by ``prepare_text_input``.
            speaker_to_id: Mapping of speaker name to speaker id. It must
                contain the ``'random'`` speaker.
            emb_dim: Width of the randomly generated speaker embedding.

        Raises:
            AssertionError: If ``'random'`` is not a key of ``speaker_to_id``.
        """
        # Gradients are disabled globally; load_random_voice relies on this to
        # assign into the embedding weights in place.
        torch.set_grad_enabled(False)
        self.model = self.init_jit_model(model_path)
        self.symbols = symbols
        self.device = torch.device('cpu')
        self.speaker_to_id = speaker_to_id
        self.speakers = list(speaker_to_id.keys())
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if 'random' not in self.speakers:
            raise AssertionError("`speaker_to_id` must contain the 'random' speaker")
        self._init_tables()
        self.emb_dim = emb_dim
        self.random_emb = None
        self.debug = False

    def _init_tables(self):
        """Create the transliteration table and the SSML attribute tables."""
        self.ru_ascii_dict = dict(zip('абвгдеёжзийклмнопрстуфхцчшщъыьэюя–',
                                      'abvgde1jzi2klmnoprstufhc4w35y6789='))
        # ssml tags
        # 'x-weak' is the SSML spelling; 'x_weak' is kept as a legacy alias so
        # existing callers that used the old key keep working.
        self.strength2time = {'x-weak': 25, 'weak': 75, 'medium': 150,
                              'strong': 300, 'x-strong': 1000, 'x_weak': 25}
        self.rate2value = {'x-slow': 0.5, 'slow': 0.8, 'medium': 1., 'fast': 1.2, 'x-fast': 1.5}
        self.pitch2value = {'x-low': 0.6, 'low': 0.8, 'medium': 1., 'high': 1.2, 'x-high': 1.4, 'robot': 0.}
        self.valid_tags = {'break': {'strength': list(self.strength2time.keys())},
                           'prosody': {'rate': list(self.rate2value.keys()),
                                       'pitch': list(self.pitch2value.keys())}}

    def init_jit_model(self, model_path: str):
        """Load a TorchScript model on CPU and switch it to eval mode."""
        torch.set_grad_enabled(False)
        model = torch.jit.load(model_path, map_location='cpu')
        model.eval()
        return model

    def ru_to_ascii(self, sentence):
        """Transliterate ``sentence`` character by character via ``ru_ascii_dict``.

        Characters without a mapping are kept unchanged.
        """
        return ''.join(self.ru_ascii_dict.get(s, s) for s in sentence)

    def prepare_text_input(self, text):
        """Normalise raw text for the model.

        Returns:
            A tuple ``(sentence, clean_sentence, has_text)`` where ``sentence``
            is the filtered, transliterated text, ``clean_sentence`` is
            ``sentence`` reduced to ``[a-z1-9\\- ]`` and ``has_text`` tells
            whether ``clean_sentence`` contains any non-space character.
        """
        text = text.lower()
        # NOTE(review): the three literals replaced here were invisible/ambiguous
        # Unicode characters that did not survive extraction of this file. An
        # empty pattern must not be used: str.replace('', '-') inserts '-'
        # between every character. The code points below (soft hyphen,
        # zero-width space, non-breaking hyphen) are a best guess - confirm
        # against the upstream file.
        text = text.replace('\u00ad', '').replace('\u200b', '').replace('\u2011', '-')
        # Escape the symbol set so that ']', '\\', '^' or an interior '-' cannot
        # change the meaning of the character class.
        text = re.sub('[^{}]'.format(re.escape(self.symbols[3:])), '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        sentence = self.ru_to_ascii(text)
        clean_sentence = re.sub(r'[^a-z1-9\- ]', '', sentence)
        has_text = len(clean_sentence.replace(' ', '')) > 0
        return sentence, clean_sentence, has_text

    def prepare_tts_model_input(self, text: str,
                                ssml: bool,
                                speaker_ids: list):
        """Turn text (plain or SSML) into the parallel lists the model expects.

        Returns:
            ``(sentences, clean_sentences, break_lens, prosody_rates,
            prosody_pitches, speaker_ids)``; the first five are lists with one
            item per text segment, the last is a ``torch.LongTensor``.

        Raises:
            ValueError: If the input is invalid or yields no segment at all.
        """
        if ssml:
            clean_text_list = self.process_ssml(text)
        else:
            clean_text_list = self.process_simple_text(text)
        # An SSML document such as "<speak></speak>" produces no segments;
        # report that clearly instead of failing on tuple unpacking.
        if not clean_text_list:
            raise ValueError("Input contains no text to synthesize")

        # Extract by key rather than relying on dict insertion order.
        sentences = [item['text'] for item in clean_text_list]
        clean_sentences = [item['clean_text'] for item in clean_text_list]
        break_lens = [item['break_time'] for item in clean_text_list]
        prosody_rates = [item['prosody_rate'] for item in clean_text_list]
        prosody_pitches = [item['prosody_pitch'] for item in clean_text_list]

        full_text_len = sum(len(s) for s in sentences)
        if full_text_len > 1000:
            warnings.warn('Text string is longer than 1000 symbols.')
        speaker_ids = torch.LongTensor(speaker_ids)
        return sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches, speaker_ids

    def to(self, device):
        """Move the inner ``tts_model`` to ``device`` and remember the device."""
        self.model.tts_model = self.model.tts_model.to(device)
        self.device = device

    def get_speakers(self, speaker: str, voice_path=None):
        """Resolve ``speaker`` to a one-element list with its id.

        For the ``'random'`` speaker a voice is generated, or loaded from
        ``voice_path`` when given, before the id is looked up.

        Raises:
            ValueError: If the speaker is unknown or the voice fails to load.
        """
        try:
            if speaker == 'random':
                self.load_random_voice(voice_path)
            speaker_id = self.speaker_to_id.get(speaker, None)
            if speaker_id is None:
                raise ValueError(f"`speaker` should be in {', '.join(self.speakers)}")
        except Exception as e:
            raise ValueError(f'Failed to load speaker: {speaker}, error: {e}') from e
        return [speaker_id]

    def process_simple_text(self, text):
        """Wrap plain text into a single-segment list with neutral prosody.

        Raises:
            ValueError: If nothing pronounceable is left after normalisation.
        """
        sentence, clean_sentence, has_text = self.prepare_text_input(text)
        if not has_text:
            raise ValueError("Input contains no text to synthesize")
        return [{'text': sentence,
                 'clean_text': clean_sentence,
                 'break_time': None,
                 'prosody_rate': 1.,
                 'prosody_pitch': 1.}]

    def process_ssml(self, ssml_text):
        """Parse an SSML document into the list of model-ready segments.

        Raises:
            ValueError: If the XML is malformed or processing fails.
            AssertionError: If the root tag is not ``<speak>`` or a tag has
                invalid attributes.
        """
        # Collapsing all whitespace also removes every newline.
        ssml_text = re.sub(r'\s+', ' ', ssml_text).strip()
        try:
            root = ET.fromstring(ssml_text)
        except Exception as e:
            raise ValueError("Invalid XML format") from e
        if root.tag != 'speak':
            raise AssertionError("Invalid SSML format: <speak> tag is essential")

        try:
            ssml_parsed = self.process_ssml_element(root)
            if self.debug:
                print(ssml_parsed)
        except AssertionError:
            # Attribute validation errors are reported to the caller unchanged.
            raise
        except Exception as e:
            raise ValueError(f"Failed to parse SSML: {e}") from e

        try:
            clean_text_list = self.process_ssml_tag_dict(ssml_parsed)
            if self.debug:
                print(clean_text_list)
        except Exception as e:
            raise ValueError(f"Failed to process SSML: {e}") from e
        return clean_text_list

    def process_ssml_tag_dict(self, text_break_list):
        """Convert parsed SSML items into model-ready segments.

        Items without pronounceable text are dropped; when such an item
        carries a break, the break is merged into the previous segment. The
        first item is always kept.
        """
        proc_text_break_list = []
        for i, text_break_prosody in enumerate(text_break_list):
            tbreak = text_break_prosody['break']
            tprosody = text_break_prosody['prosody']
            text, clean_text, has_text = self.prepare_text_input(text_break_prosody['text'])
            # Milliseconds are divided by 12.5 - presumably the model's frame
            # length in ms; confirm against the model.
            break_time = int(tbreak['time'] / 12.5) if tbreak['time'] is not None else None
            if has_text or i == 0:
                text = self.check_text_break(text, tbreak)
                proc_text_break_list.append({'text': text,
                                             'clean_text': clean_text,
                                             'break_time': break_time,
                                             'prosody_rate': tprosody['rate'],
                                             'prosody_pitch': tprosody['pitch']})
            elif tbreak['strength'] is not None and len(proc_text_break_list) > 0:
                previous = proc_text_break_list[-1]
                previous['text'] = self.check_text_break(previous['text'], tbreak)
                if previous['break_time'] is None:
                    previous['break_time'] = break_time
                else:
                    previous['break_time'] = max(break_time, previous['break_time'])
        return proc_text_break_list

    def check_text_break(self, text, tbreak):
        """Append ``'.'`` to empty text, or to text followed by a break that
        does not already end with a punctuation mark."""
        needs_stop = len(text) == 0 or (tbreak['strength'] is not None
                                        and text[-1] not in '!,-.:;?–…')  # TODO fx dash
        if needs_stop:
            text = text + '.'
        return text

    def _apply_min_break(self, item, strength):
        """Give ``item`` a break of ``strength`` lasting at least that
        strength's default time, keeping a longer break already set."""
        pause = self.strength2time[strength]
        current = item['break']
        if current['strength'] is not None:
            pause = max(current['time'], pause)
        item['break'] = {'strength': strength, 'time': pause}

    def process_ssml_element(self, element, def_strength='strong', def_rate=1., def_pitch=1.):
        """Recursively flatten an SSML element into a list of parsed items.

        Each item is a dict with keys ``'text'``, ``'break'`` (``strength`` and
        ``time``) and ``'prosody'`` (``rate`` and ``pitch``). ``def_rate`` and
        ``def_pitch`` are the prosody values inherited from enclosing tags.
        """
        parsed = []
        last_tag = None
        parsed.extend(self.process_head_tail_text(element.text, def_rate, def_pitch))
        for child in element:
            if child.tag == 'break':
                break_strength, break_ts = self.process_break_attrib(child.attrib)
                if len(parsed) == 0:
                    # A break needs an item to attach to.
                    parsed.append({'text': '.',
                                   'break': {'strength': None,
                                             'time': None},
                                   'prosody': {'rate': def_rate,
                                               'pitch': def_pitch}})
                parsed[-1]['break'] = {'strength': break_strength,
                                       'time': break_ts}
            elif child.tag == 'prosody':
                prosody_rate, prosody_pitch, change_rate, change_pitch = self.process_prosody(child.attrib)
                child_rate = prosody_rate if change_rate else def_rate
                child_pitch = prosody_pitch if change_pitch else def_pitch
                child_parsed = self.process_ssml_element(child, def_strength, child_rate, child_pitch)
                parsed.extend(child_parsed)
            elif child.tag in ['p', 's']:
                break_strength = 'strong' if child.tag == 's' else 'x-strong'
                child_parsed = self.process_ssml_element(child, break_strength, def_rate, def_pitch)
                # NOTE(review): `last_tag is not None or last_tag != child.tag`
                # is always true, so this reduces to `len(parsed) > 0`. Kept
                # as written because the intended condition is unclear.
                if len(parsed) > 0 and (parsed[-1]['text'] or last_tag is not None or last_tag != child.tag):
                    self._apply_min_break(parsed[-1], break_strength)
                if len(child_parsed) > 0:
                    self._apply_min_break(child_parsed[-1], break_strength)
                parsed.extend(child_parsed)
            else:
                warnings.warn(f"Current model doesn't support SSML tag: {child.tag}")
            last_tag = child.tag

            if child.tail:
                tail_text = child.tail
                if tail_text[0] in '.,!?…–;:' and len(parsed) > 0:
                    # Punctuation right after a closing tag belongs to the
                    # text before the tag.
                    lost_punct = tail_text[0]
                    parsed[-1]['text'] = parsed[-1]['text'].strip() + lost_punct
                    if len(tail_text) > 1:
                        tail_text = tail_text[1:]
                parsed.extend(self.process_head_tail_text(tail_text, def_rate, def_pitch))
        return parsed

    def process_head_tail_text(self, element_text, def_rate, def_pitch):
        """Wrap an element's head or tail text into a one-item parsed list.

        Returns an empty list when ``element_text`` is ``None``.
        """
        if element_text is None:
            return []
        proc_text = element_text.replace('\n', '')
        proc_text = re.sub(r'\s+', ' ', proc_text).strip()
        return [{'text': proc_text,
                 'break': {'strength': None,
                           'time': None},
                 'prosody': {'rate': def_rate,
                             'pitch': def_pitch}}]

    def process_break_attrib(self, attrib):
        """Read the attributes of a ``<break>`` tag.

        ``time`` (``'<number>ms'`` or ``'<number>s'``, fractions allowed) takes
        precedence over ``strength``. The pause is capped at 5000 ms.

        Returns:
            ``(strength, break_ts)`` with ``break_ts`` in milliseconds.

        Raises:
            AssertionError: On an invalid ``time`` suffix or unknown ``strength``.
        """
        for k in attrib.keys():
            if k not in ['strength', 'time']:
                warnings.warn(f"Current model doesn't support SSML <break> attrib: {k}")
        strength = attrib.get('strength', 'medium')
        break_time = attrib.get('time', None)
        if break_time is not None:
            # float() accepts fractional values such as "1.5s".
            if break_time.endswith('ms'):
                break_ts = int(round(float(break_time[:-2])))
            elif break_time.endswith('s'):
                break_ts = int(round(float(break_time[:-1]) * 1000))
            else:
                raise AssertionError("Invalid <break> tag, time should end with 'ms' or 's'")
            # A long explicit pause upgrades the reported strength.
            if break_ts >= self.strength2time['x-strong']:
                strength = 'x-strong'
            elif break_ts >= self.strength2time['strong']:
                strength = 'strong'
        else:
            if strength in self.strength2time:
                break_ts = self.strength2time[strength]
            else:
                raise AssertionError(f"Invalid <break> tag, strength should be in {', '.join(self.valid_tags['break']['strength'])}")
        if break_ts > 5000:
            warnings.warn('Current model supports pauses less than 5 sec')
            break_ts = 5000
        return strength, break_ts

    def process_prosody(self, attrib):
        """Read the attributes of a ``<prosody>`` tag.

        ``rate`` is a named value or a percentage used as an absolute scale
        (``'120%'`` -> 1.2). ``pitch`` is a named value, a signed percentage
        applied relative to 1.0 (``'+20%'`` -> 1.2, ``'-20%'`` -> 0.8) or an
        unsigned percentage used as an absolute scale, like ``rate``.

        Returns:
            ``(rate_val, pitch_val, change_rate, change_pitch)``; the flags
            tell which of the two values were actually given.

        Raises:
            AssertionError: If both attributes are missing or a named value is
                unknown.
        """
        for k in attrib.keys():
            if k not in ['rate', 'pitch']:
                warnings.warn(f"Current model doesn't support SSML <prosody> attrib: {k}")
        rate = attrib.get('rate', None)
        pitch = attrib.get('pitch', None)
        if rate is None and pitch is None:
            raise AssertionError("Empty <prosody> tag")

        change_rate = rate is not None
        if not change_rate:
            rate_val = 1.
        elif rate.endswith('%'):
            rate_val = float(rate.replace('%', '')) / 100
        else:
            rate_val = self.rate2value.get(rate, None)
            if rate_val is None:
                raise AssertionError(f"Invalid <prosody> tag, rate should be in {', '.join(self.valid_tags['prosody']['rate'])}")

        change_pitch = pitch is not None
        if not change_pitch:
            pitch_val = 1.
        elif pitch.endswith('%'):
            number = pitch.replace('%', '')
            if number[:1] in ('+', '-'):
                delta = float(number[1:]) / 100
                pitch_val = 1. + delta if number[0] == '+' else 1. - delta
            else:
                # Without a sign there is no leading character to strip.
                pitch_val = float(number) / 100
        else:
            pitch_val = self.pitch2value.get(pitch, None)
            if pitch_val is None:
                raise AssertionError(f"Invalid <prosody> tag, pitch should be in {', '.join(self.valid_tags['prosody']['pitch'])}")
        return rate_val, pitch_val, change_rate, change_pitch

    def apply_tts(self, text=None,
                  ssml_text=None,
                  speaker: str = 'xenia',
                  sample_rate: int = 48000,
                  put_accent=True,
                  put_yo=True,
                  voice_path=None):
        """Synthesize speech for ``text`` or ``ssml_text``.

        ``ssml_text`` is used whenever it is not ``None``. ``voice_path`` is
        only relevant for the ``'random'`` speaker.

        Returns:
            The first item of the model output, moved to CPU.

        Raises:
            AssertionError: On an unsupported sample rate, unknown speaker, or
                when neither ``text`` nor ``ssml_text`` is given.
            ValueError: If the input text or the speaker cannot be processed.
            Exception: If the model fails with a ``RuntimeError``.
        """
        if sample_rate not in [8000, 24000, 48000]:
            raise AssertionError(f"`sample_rate` should be in [8000, 24000, 48000], current value is {sample_rate}")
        if speaker not in self.speakers:
            raise AssertionError(f"`speaker` should be in {', '.join(self.speakers)}")
        if text is None and ssml_text is None:
            raise AssertionError("Both `text` and `ssml_text` are empty")

        ssml = ssml_text is not None
        input_text = ssml_text if ssml else text
        speaker_ids = self.get_speakers(speaker, voice_path)
        sentences, clean_sentences, break_lens, prosody_rates, prosody_pitches, sp_ids = self.prepare_tts_model_input(
            input_text,
            ssml=ssml,
            speaker_ids=speaker_ids)

        with torch.no_grad():
            try:
                out, out_lens = self.model(sentences=sentences,
                                           clean_sentences=clean_sentences,
                                           break_lens=break_lens,
                                           prosody_rates=prosody_rates,
                                           prosody_pitches=prosody_pitches,
                                           speaker_ids=sp_ids,
                                           sr=sample_rate,
                                           device=str(self.device),
                                           put_yo=put_yo,
                                           put_accent=put_accent
                                           )
            except RuntimeError as e:
                # Keep the original error as the cause for easier debugging.
                raise Exception("Model couldn't generate your text, probably it's too long") from e
        audio = out.to('cpu')[0]
        return audio

    @staticmethod
    def write_wave(path, audio, sample_rate):
        """Writes a .wav file.
        Takes path, PCM audio data, and sample rate.
        The file is written as mono with a sample width of 2 bytes.
        """
        with wave.open(path, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(sample_rate)
            wf.writeframes(audio)

    def save_wav(self, text=None,
                 ssml_text=None,
                 speaker: str = 'xenia',
                 audio_path: str = '',
                 sample_rate: int = 48000,
                 put_accent=True,
                 put_yo=True,
                 voice_path=None):
        """Synthesize speech and write it to ``audio_path`` as a wav file.

        An empty ``audio_path`` falls back to ``'test.wav'``. ``voice_path`` is
        forwarded to ``apply_tts`` for the ``'random'`` speaker.

        Returns:
            The path the audio was written to.
        """
        if not audio_path:
            audio_path = 'test.wav'
        audio = self.apply_tts(text=text,
                               ssml_text=ssml_text,
                               speaker=speaker,
                               sample_rate=sample_rate,
                               put_yo=put_yo,
                               put_accent=put_accent,
                               voice_path=voice_path)
        # Scale to the int16 range expected for 2-byte samples.
        self.write_wave(path=audio_path,
                        audio=(audio * 32767).numpy().astype('int16'),
                        sample_rate=sample_rate)
        return audio_path

    def load_random_voice(self, voice_path=None):
        """Generate or load a voice embedding and write it into the model.

        Without ``voice_path`` a new embedding of shape ``(2, emb_dim)`` is
        drawn and always applied. With ``voice_path`` the embedding is loaded;
        if it equals the currently active one, nothing is rewritten. The
        active embedding is stored in ``self.random_emb``, so
        ``save_random_voice`` saves the voice that is actually in use.
        """
        if voice_path is None:
            random_emb = torch.randn(2, self.emb_dim, requires_grad=False).to(self.device)
            print("Generated new voice")
        else:
            # NOTE: torch.load unpickles the file - only load trusted voices.
            random_emb = torch.load(voice_path, map_location=self.device)
            print(f"Loaded voice from {voice_path}")
            if self.random_emb is not None and torch.equal(self.random_emb, random_emb):
                return
        self.random_emb = random_emb

        # Row 0 goes to the tacotron embedding; row 1 is split in half between
        # the duration and pitch predictors.
        mel_weight = random_emb[0]
        dur_weight = random_emb[1, :self.emb_dim//2]
        p_weight = random_emb[1, self.emb_dim//2:]
        self.model.tts_model.tacotron.speaker_embedding.weight[-1] = mel_weight
        self.model.tts_model.dur_predictor.dur_pred.speaker_embedding.weight[-1] = dur_weight
        self.model.tts_model.pitch_predictor.pitch_pred.speaker_embedding.weight[-1] = p_weight

    def save_random_voice(self, voice_path):
        """Save the active random voice embedding to ``voice_path``.

        Raises:
            AssertionError: If no random voice has been generated or loaded.
        """
        if self.random_emb is None:
            raise AssertionError("No generated random voice")
        torch.save(self.random_emb, voice_path)
        print(f"Saved generated voice to {voice_path}")