Compatibility with Matcha TTS #39

@mush42

Hi

The issue

I trained a model based on Matcha TTS and tried to use Vocos with it. Unfortunately, vocoding its output with a checkpoint trained using Vocos's default config gives robotic audio at very low volume.

The only config values I changed are sample_rate (=22050) and n_mels (=80).

I assumed there is a parameter mismatch between the mel spectrogram Matcha TTS generates and the mel spectrogram Vocos expects.
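
To make the suspected mismatch concrete, this is roughly how I would compare the two on the same clip. It is only a sketch: I am assuming the stock MelSpectrogramFeatures accepts these init args; adjust if the signature differs in your version.

import torch
from vocos.feature_extractors import MelSpectrogramFeatures

# Vocos's stock extractor (torchaudio-based log-magnitude mel); the init args
# mirror my config changes and assume this signature exists unchanged.
vocos_fe = MelSpectrogramFeatures(sample_rate=22050, n_fft=1024, hop_length=256, n_mels=80)

audio = torch.randn(1, 22050)  # stand-in for a real 22.05 kHz clip
vocos_mel = vocos_fe(audio)

# Matcha TTS computes its mels with center=False and f_max=8000 and then applies
# a dataset mean/std normalization, so both the framing and the value range
# differ from what a default-config Vocos checkpoint was trained on.
print(vocos_mel.shape, float(vocos_mel.min()), float(vocos_mel.max()))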

A new feature extractor

I wrote a feature extractor class that generates mel spectrograms using the same parameters as Matcha TTS. Most of the code is copied directly from Matcha's source code.

MatchaMelSpectrogramFeatures:
import numpy as np
import torch
from librosa.filters import mel as librosa_mel_fn

from vocos.feature_extractors import FeatureExtractor


class MatchaMelSpectrogramFeatures(FeatureExtractor):
    """
    Generate MelSpectrogram from audio using same params
    as Matcha TTS (https://github.com/shivammehta25/Matcha-TTS)
    This is also useful with tacatron, waveglow..etc.
    """

    def __init__(
        self,
        *,
        mel_mean,
        mel_std,
        sample_rate=22050,
        n_fft=1024,
        win_length=1024,
        n_mels=80,
        hop_length=256,
        center=False,
        f_min=0,
        f_max=8000,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.f_min = f_min
        self.f_max = f_max
        # Data-dependent
        self.mel_mean = mel_mean
        self.mel_std = mel_std
        # Cache
        self._mel_basis = {}
        self._hann_window = {}

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        # audio: (batch, num_samples) -> mel: (batch, n_mels, frames)
        mel = self.mel_spectrogram(audio)
        return normalize(mel, self.mel_mean, self.mel_std)

    def mel_spectrogram(self, y):
        mel_basis_key = str(self.f_max) + "_" + str(y.device)
        han_window_key = str(y.device)
        if mel_basis_key not in self._mel_basis:
            mel = librosa_mel_fn(
                sr=self.sample_rate,
                n_fft=self.n_fft,
                n_mels=self.n_mels,
                fmin=self.f_min,
                fmax=self.f_max
            )
            self._mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
            self._hann_window[han_window_key] = torch.hann_window(self.win_length).to(y.device)
        pad_vals = (
            (self.n_fft - self.hop_length) // 2,
            (self.n_fft - self.hop_length) // 2,
        )
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            pad_vals,
            mode="reflect"
        )
        y = y.squeeze(1)
        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self._hann_window[han_window_key],
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
        spec = torch.matmul(self._mel_basis[mel_basis_key], spec)
        spec = spectral_normalize_torch(spec)
        return spec


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)

def normalize(data, mu, std):
    if not isinstance(mu, (float, int)):
        if isinstance(mu, list):
            mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
        elif isinstance(mu, torch.Tensor):
            mu = mu.to(data.device)
        elif isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).to(data.device)
        mu = mu.unsqueeze(-1)

    if not isinstance(std, (float, int)):
        if isinstance(std, list):
            std = torch.tensor(std, dtype=data.dtype, device=data.device)
        elif isinstance(std, torch.Tensor):
            std = std.to(data.device)
        elif isinstance(std, np.ndarray):
            std = torch.from_numpy(std).to(data.device)
        std = std.unsqueeze(-1)

    return (data - mu) / std
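
A quick sanity check of the extractor on dummy audio (just a sketch; the mel_mean/mel_std values are the same data-dependent statistics I pass in the config below):

import torch

fe = MatchaMelSpectrogramFeatures(mel_mean=-6.38385, mel_std=2.541796)
audio = torch.randn(1, 22050)  # one second of dummy audio at 22.05 kHz
mel = fe(audio)
print(mel.shape)  # (1, 80, 86), roughly num_samples // hop_length frames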

And I used it with the following config:

vocos-matcha.yaml:
# pytorch_lightning==1.8.6
seed_everything: 4444

data:
  class_path: vocos.dataset.VocosDataModule
  init_args:
    train_params:
      filelist_path: ./datasets/train.txt
      sampling_rate: 22050
      num_samples: 16384
      batch_size: 16
      num_workers: 4

    val_params:
      filelist_path: ./datasets/val.txt
      sampling_rate: 22050
      num_samples: 48384
      batch_size: 16
      num_workers: 4

model:
  class_path: vocos.experiment.VocosExp
  init_args:
    sample_rate: 22050
    initial_learning_rate: 5e-4
    mel_loss_coeff: 45
    mrd_loss_coeff: 0.1
    num_warmup_steps: 0 # Optimizers warmup steps
    pretrain_mel_steps: 0  # 0 means GAN objective from the first iteration
    # automatic evaluation
    evaluate_utmos: true
    evaluate_pesq: true
    evaluate_periodicty: true

    feature_extractor:
      class_path: matcha_feature_extractor.MatchaMelSpectrogramFeatures
      init_args:
        sample_rate: 22050
        n_fft: 1024
        n_mels: 80
        hop_length: 256
        win_length: 1024
        f_min: 0
        f_max: 8000
        center: False
        mel_mean: -6.38385
        mel_std: 2.541796

    backbone:
      class_path: vocos.models.VocosBackbone
      init_args:
        input_channels: 80
        dim: 512
        intermediate_dim: 1536
        num_layers: 8

    head:
      class_path: vocos.heads.ISTFTHead
      init_args:
        dim: 512
        n_fft: 1024
        hop_length: 256
        padding: same

trainer:
  logger:
    class_path: pytorch_lightning.loggers.TensorBoardLogger
    init_args:
      save_dir: /content/drive/MyDrive/vocos/logs
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss
        filename: vocos_checkpoint_{epoch}_{step}_{val_loss:.4f}
        save_top_k: 2
        save_last: true
    - class_path: vocos.helpers.GradNormCallback

  # Lightning calculates max_steps across all optimizer steps (rather than number of batches)
  # This equals to 1M steps per generator and 1M per discriminator
  max_steps: 2000000
  # You might want to limit val batches when evaluating all the metrics, as they are time-consuming
  limit_val_batches: 128
  accelerator: gpu
  strategy: ddp
  devices: [0]
  log_every_n_steps: 100

Results

I trained Vocos with the above feature extractor and config, but this also fails: the vocoding quality is even worse and the volume even lower.
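
For reference, this is roughly how I run inference with the resulting checkpoint. It is only a sketch: the Vocos(feature_extractor, backbone, head) constructor and the decode() call are my assumptions based on the published package, and the checkpoint key handling may need adapting.

import torch
from vocos import Vocos
from vocos.models import VocosBackbone
from vocos.heads import ISTFTHead
from matcha_feature_extractor import MatchaMelSpectrogramFeatures

# Rebuild the generator with the same hyperparameters as in the config above.
model = Vocos(
    feature_extractor=MatchaMelSpectrogramFeatures(mel_mean=-6.38385, mel_std=2.541796),
    backbone=VocosBackbone(input_channels=80, dim=512, intermediate_dim=1536, num_layers=8),
    head=ISTFTHead(dim=512, n_fft=1024, hop_length=256, padding="same"),
)

# Keep only the generator weights from the Lightning checkpoint saved by
# save_last; the key prefixes below are a guess and may differ in practice.
state = torch.load("last.ckpt", map_location="cpu")["state_dict"]
state = {k: v for k, v in state.items()
         if k.startswith(("feature_extractor.", "backbone.", "head."))}
model.load_state_dict(state, strict=False)
model.eval()

# mel: (1, 80, frames) output of Matcha TTS, already mean/std-normalized
mel = torch.randn(1, 80, 200)
with torch.no_grad():
    audio = model.decode(mel)  # (1, num_samples) waveform at 22050 Hz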

Questions

  • Did I miss something in the above feature extractor?
  • Does the default Vocos head expect mel spectrograms generated with specific parameters?
  • Any suggestions to resolve this?

Additional notes

I believe many open-source TTS models use the same code to extract mel spectrograms, so resolving this would help with training Vocos for use with those models as well.

Best
