diff --git a/rvc/lib/train/data_utils.py b/rvc/lib/train/data_utils.py
new file mode 100644
index 0000000..8d23c31
--- /dev/null
+++ b/rvc/lib/train/data_utils.py
@@ -0,0 +1,517 @@
+import logging
+import os
+import traceback
+
+logger = logging.getLogger(__name__)
+
+import numpy as np
+import torch
+import torch.utils.data
+
+from rvc.lib.train.mel_processing import spectrogram_torch
+from rvc.lib.train.utils import load_filepaths_and_text, load_wav_to_torch
+
+
+class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
+    """
+    1) loads audio, phone-feature, pitch and speaker-id tuples
+    2) converts the loaded features to tensors
+    3) computes spectrograms from audio files.
+    """
+
+    def __init__(self, audiopaths_and_text, hparams):
+        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
+        self.max_wav_value = hparams.max_wav_value
+        self.sampling_rate = hparams.sampling_rate
+        self.filter_length = hparams.filter_length
+        self.hop_length = hparams.hop_length
+        self.win_length = hparams.win_length
+        self.min_text_len = getattr(hparams, "min_text_len", 1)
+        self.max_text_len = getattr(hparams, "max_text_len", 5000)
+        self._filter()
+
+    def _filter(self):
+        """
+        Filter text & store spec lengths
+        """
+        # Store spectrogram lengths for bucketing
+        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
+        # spec_length = wav_length // hop_length
+        audiopaths_and_text_new = []
+        lengths = []
+        for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
+            if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
+                audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
+                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
+        self.audiopaths_and_text = audiopaths_and_text_new
+        self.lengths = lengths
+
+    def get_sid(self, sid):
+        sid = torch.LongTensor([int(sid)])
+        return sid
+
+    def get_audio_text_pair(self, audiopath_and_text):
+        # separate filename and text
+        file = audiopath_and_text[0]
+        phone = audiopath_and_text[1]
+        pitch = audiopath_and_text[2]
+        pitchf = audiopath_and_text[3]
+        dv = audiopath_and_text[4]
+
+        phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
+        spec, wav = self.get_audio(file)
+        dv = self.get_sid(dv)
+
+        len_phone = phone.size()[0]
+        len_spec = spec.size()[-1]
+        if len_phone != len_spec:
+            len_min = min(len_phone, len_spec)
+            len_wav = len_min * self.hop_length
+
+            spec = spec[:, :len_min]
+            wav = wav[:, :len_wav]
+
+            phone = phone[:len_min, :]
+            pitch = pitch[:len_min]
+            pitchf = pitchf[:len_min]
+
+        return (spec, wav, phone, pitch, pitchf, dv)
+
+    def get_labels(self, phone, pitch, pitchf):
+        phone = np.load(phone)
+        phone = np.repeat(phone, 2, axis=0)
+        pitch = np.load(pitch)
+        pitchf = np.load(pitchf)
+        n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
+        phone = phone[:n_num, :]
+        pitch = pitch[:n_num]
+        pitchf = pitchf[:n_num]
+        phone = torch.FloatTensor(phone)
+        pitch = torch.LongTensor(pitch)
+        pitchf = torch.FloatTensor(pitchf)
+        return phone, pitch, pitchf
+
+    def get_audio(self, filename):
+        audio, sampling_rate = load_wav_to_torch(filename)
+        if sampling_rate != self.sampling_rate:
+            raise ValueError(
+                "{} SR doesn't match target {} SR".format(
+                    sampling_rate, self.sampling_rate
+                )
+            )
+        audio_norm = audio
+        # audio_norm = audio / self.max_wav_value
+        # audio_norm = audio / np.abs(audio).max()
+
+        audio_norm = audio_norm.unsqueeze(0)
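+        # cache the spectrogram next to the wav (".spec.pt") so later epochs can skip the STFT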
+        spec_filename = filename.replace(".wav", ".spec.pt")
+        if os.path.exists(spec_filename):
+            try:
+                spec = torch.load(spec_filename)
+            except Exception:
+                logger.warning("%s %s", spec_filename, traceback.format_exc())
+                spec = spectrogram_torch(
+                    audio_norm,
+                    self.filter_length,
+                    self.sampling_rate,
+                    self.hop_length,
+                    self.win_length,
+                    center=False,
+                )
+                spec = torch.squeeze(spec, 0)
+                torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
+        else:
+            spec = spectrogram_torch(
+                audio_norm,
+                self.filter_length,
+                self.sampling_rate,
+                self.hop_length,
+                self.win_length,
+                center=False,
+            )
+            spec = torch.squeeze(spec, 0)
+            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
+        return spec, audio_norm
+
+    def __getitem__(self, index):
+        return self.get_audio_text_pair(self.audiopaths_and_text[index])
+
+    def __len__(self):
+        return len(self.audiopaths_and_text)
+
+
+class TextAudioCollateMultiNSFsid:
+    """Zero-pads model inputs and targets"""
+
+    def __init__(self, return_ids=False):
+        self.return_ids = return_ids
+
+    def __call__(self, batch):
+        """Collates a training batch from normalized text and audio
+        PARAMS
+        ------
+        batch: [text_normalized, spec_normalized, wav_normalized]
+        """
+        # Right zero-pad all one-hot text sequences to max input length
+        _, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
+        )
+
+        max_spec_len = max([x[0].size(1) for x in batch])
+        max_wave_len = max([x[1].size(1) for x in batch])
+        spec_lengths = torch.LongTensor(len(batch))
+        wave_lengths = torch.LongTensor(len(batch))
+        spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
+        wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
+        spec_padded.zero_()
+        wave_padded.zero_()
+
+        max_phone_len = max([x[2].size(0) for x in batch])
+        phone_lengths = torch.LongTensor(len(batch))
+        phone_padded = torch.FloatTensor(
+            len(batch), max_phone_len, batch[0][2].shape[1]
+        )  # (spec, wav, phone, pitch)
+        pitch_padded = torch.LongTensor(len(batch), max_phone_len)
+        pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
+        phone_padded.zero_()
+        pitch_padded.zero_()
+        pitchf_padded.zero_()
+        # dv = torch.FloatTensor(len(batch), 256)  # gin=256
+        sid = torch.LongTensor(len(batch))
+
+        for i in range(len(ids_sorted_decreasing)):
+            row = batch[ids_sorted_decreasing[i]]
+
+            spec = row[0]
+            spec_padded[i, :, : spec.size(1)] = spec
+            spec_lengths[i] = spec.size(1)
+
+            wave = row[1]
+            wave_padded[i, :, : wave.size(1)] = wave
+            wave_lengths[i] = wave.size(1)
+
+            phone = row[2]
+            phone_padded[i, : phone.size(0), :] = phone
+            phone_lengths[i] = phone.size(0)
+
+            pitch = row[3]
+            pitch_padded[i, : pitch.size(0)] = pitch
+            pitchf = row[4]
+            pitchf_padded[i, : pitchf.size(0)] = pitchf
+
+            # dv[i] = row[5]
+            sid[i] = row[5]
+
+        return (
+            phone_padded,
+            phone_lengths,
+            pitch_padded,
+            pitchf_padded,
+            spec_padded,
+            spec_lengths,
+            wave_padded,
+            wave_lengths,
+            # dv
+            sid,
+        )
+
+
+class TextAudioLoader(torch.utils.data.Dataset):
+    """
+    1) loads audio, phone-feature and speaker-id tuples
+    2) converts the loaded features to tensors
+    3) computes spectrograms from audio files.
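+    (same as TextAudioLoaderMultiNSFsid, minus the pitch features)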
+ """ + + def __init__(self, audiopaths_and_text, hparams): + self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + self.min_text_len = getattr(hparams, "min_text_len", 1) + self.max_text_len = getattr(hparams, "max_text_len", 5000) + self._filter() + + def _filter(self): + """ + Filter text & store spec lengths + """ + # Store spectrogram lengths for Bucketing + # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) + # spec_length = wav_length // hop_length + audiopaths_and_text_new = [] + lengths = [] + for audiopath, text, dv in self.audiopaths_and_text: + if self.min_text_len <= len(text) and len(text) <= self.max_text_len: + audiopaths_and_text_new.append([audiopath, text, dv]) + lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length)) + self.audiopaths_and_text = audiopaths_and_text_new + self.lengths = lengths + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def get_audio_text_pair(self, audiopath_and_text): + # separate filename and text + file = audiopath_and_text[0] + phone = audiopath_and_text[1] + dv = audiopath_and_text[2] + + phone = self.get_labels(phone) + spec, wav = self.get_audio(file) + dv = self.get_sid(dv) + + len_phone = phone.size()[0] + len_spec = spec.size()[-1] + if len_phone != len_spec: + len_min = min(len_phone, len_spec) + len_wav = len_min * self.hop_length + spec = spec[:, :len_min] + wav = wav[:, :len_wav] + phone = phone[:len_min, :] + return (spec, wav, phone, dv) + + def get_labels(self, phone): + phone = np.load(phone) + phone = np.repeat(phone, 2, axis=0) + n_num = min(phone.shape[0], 900) # DistributedBucketSampler + phone = phone[:n_num, :] + phone = torch.FloatTensor(phone) + return phone + + def get_audio(self, filename): + audio, sampling_rate = load_wav_to_torch(filename) + if sampling_rate != self.sampling_rate: + raise ValueError( + "{} SR doesn't match target {} SR".format( + sampling_rate, self.sampling_rate + ) + ) + audio_norm = audio + # audio_norm = audio / self.max_wav_value + # audio_norm = audio / np.abs(audio).max() + + audio_norm = audio_norm.unsqueeze(0) + spec_filename = filename.replace(".wav", ".spec.pt") + if os.path.exists(spec_filename): + try: + spec = torch.load(spec_filename) + except: + logger.warning("%s %s", spec_filename, traceback.format_exc()) + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_filename, _use_new_zipfile_serialization=False) + else: + spec = spectrogram_torch( + audio_norm, + self.filter_length, + self.sampling_rate, + self.hop_length, + self.win_length, + center=False, + ) + spec = torch.squeeze(spec, 0) + torch.save(spec, spec_filename, _use_new_zipfile_serialization=False) + return spec, audio_norm + + def __getitem__(self, index): + return self.get_audio_text_pair(self.audiopaths_and_text[index]) + + def __len__(self): + return len(self.audiopaths_and_text) + + +class TextAudioCollate: + """Zero-pads model inputs and targets""" + + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text and aduio + PARAMS + ------ + 
+        batch: [text_normalized, spec_normalized, wav_normalized]
+        """
+        # Right zero-pad all one-hot text sequences to max input length
+        _, ids_sorted_decreasing = torch.sort(
+            torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True
+        )
+
+        max_spec_len = max([x[0].size(1) for x in batch])
+        max_wave_len = max([x[1].size(1) for x in batch])
+        spec_lengths = torch.LongTensor(len(batch))
+        wave_lengths = torch.LongTensor(len(batch))
+        spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
+        wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
+        spec_padded.zero_()
+        wave_padded.zero_()
+
+        max_phone_len = max([x[2].size(0) for x in batch])
+        phone_lengths = torch.LongTensor(len(batch))
+        phone_padded = torch.FloatTensor(
+            len(batch), max_phone_len, batch[0][2].shape[1]
+        )
+        phone_padded.zero_()
+        sid = torch.LongTensor(len(batch))
+
+        for i in range(len(ids_sorted_decreasing)):
+            row = batch[ids_sorted_decreasing[i]]
+
+            spec = row[0]
+            spec_padded[i, :, : spec.size(1)] = spec
+            spec_lengths[i] = spec.size(1)
+
+            wave = row[1]
+            wave_padded[i, :, : wave.size(1)] = wave
+            wave_lengths[i] = wave.size(1)
+
+            phone = row[2]
+            phone_padded[i, : phone.size(0), :] = phone
+            phone_lengths[i] = phone.size(0)
+
+            sid[i] = row[3]
+
+        return (
+            phone_padded,
+            phone_lengths,
+            spec_padded,
+            spec_lengths,
+            wave_padded,
+            wave_lengths,
+            sid,
+        )
+
+
+class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+    """
+    Maintain similar input lengths in a batch.
+    Length groups are specified by boundaries.
+    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
+
+    It removes samples which are not included in the boundaries.
+    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
+ """ + + def __init__( + self, + dataset, + batch_size, + boundaries, + num_replicas=None, + rank=None, + shuffle=True, + ): + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + self.lengths = dataset.lengths + self.batch_size = batch_size + self.boundaries = boundaries + + self.buckets, self.num_samples_per_bucket = self._create_buckets() + self.total_size = sum(self.num_samples_per_bucket) + self.num_samples = self.total_size // self.num_replicas + + def _create_buckets(self): + buckets = [[] for _ in range(len(self.boundaries) - 1)] + for i in range(len(self.lengths)): + length = self.lengths[i] + idx_bucket = self._bisect(length) + if idx_bucket != -1: + buckets[idx_bucket].append(i) + + for i in range(len(buckets) - 1, -1, -1): # + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + + num_samples_per_bucket = [] + for i in range(len(buckets)): + len_bucket = len(buckets[i]) + total_batch_size = self.num_replicas * self.batch_size + rem = ( + total_batch_size - (len_bucket % total_batch_size) + ) % total_batch_size + num_samples_per_bucket.append(len_bucket + rem) + return buckets, num_samples_per_bucket + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ( + ids_bucket + + ids_bucket * (rem // len_bucket) + + ids_bucket[: (rem % len_bucket)] + ) + + # subsample + ids_bucket = ids_bucket[self.rank :: self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [ + bucket[idx] + for idx in ids_bucket[ + j * self.batch_size : (j + 1) * self.batch_size + ] + ] + batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + + def _bisect(self, x, lo=0, hi=None): + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 + + def __len__(self): + return self.num_samples // self.batch_size diff --git a/rvc/lib/train/losses.py b/rvc/lib/train/losses.py new file mode 100644 index 0000000..aa7bd81 --- /dev/null +++ b/rvc/lib/train/losses.py @@ -0,0 +1,58 @@ +import torch + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + 
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+    loss = 0
+    gen_losses = []
+    for dg in disc_outputs:
+        dg = dg.float()
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
+
+
+def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
+    """
+    z_p, logs_q: [b, h, t_t]
+    m_p, logs_p: [b, h, t_t]
+    """
+    z_p = z_p.float()
+    logs_q = logs_q.float()
+    m_p = m_p.float()
+    logs_p = logs_p.float()
+    z_mask = z_mask.float()
+
+    kl = logs_p - logs_q - 0.5
+    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+    kl = torch.sum(kl * z_mask)
+    l = kl / torch.sum(z_mask)
+    return l
diff --git a/rvc/lib/train/mel_processing.py b/rvc/lib/train/mel_processing.py
new file mode 100644
index 0000000..3be443f
--- /dev/null
+++ b/rvc/lib/train/mel_processing.py
@@ -0,0 +1,133 @@
+import logging
+
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+
+logger = logging.getLogger(__name__)
+
+MAX_WAV_VALUE = 32768.0
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    return dynamic_range_compression_torch(magnitudes)
+
+
+def spectral_de_normalize_torch(magnitudes):
+    return dynamic_range_decompression_torch(magnitudes)
+
+
+# Reusable banks
+mel_basis = {}
+hann_window = {}
+
+
+def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
+    """Convert waveform into Linear-frequency Linear-amplitude spectrogram.
+
+    Args:
+        y :: (B, T) - Audio waveforms
+        n_fft
+        sampling_rate
+        hop_size
+        win_size
+        center
+    Returns:
+        :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
+    """
+    # Validation
+    if torch.min(y) < -1.07:
+        logger.debug("min value is %s", str(torch.min(y)))
+    if torch.max(y) > 1.07:
+        logger.debug("max value is %s", str(torch.max(y)))
+
+    # Window - Cache if needed
+    global hann_window
+    dtype_device = str(y.dtype) + "_" + str(y.device)
+    wnsize_dtype_device = str(win_size) + "_" + dtype_device
+    if wnsize_dtype_device not in hann_window:
+        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
+            dtype=y.dtype, device=y.device
+        )
+
+    # Padding
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    # Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2)
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window[wnsize_dtype_device],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=False,
+    )
+
+    # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+    return spec
+
+
+def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
+    # MelBasis - Cache if needed
+    global mel_basis
+    dtype_device = str(spec.dtype) + "_" + str(spec.device)
+    fmax_dtype_device = str(fmax) + "_" + dtype_device
+    if fmax_dtype_device not in mel_basis:
+        mel = librosa_mel_fn(
+            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+        )
+        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
+            dtype=spec.dtype, device=spec.device
+        )
+
+    # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
+    melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
+    melspec = spectral_normalize_torch(melspec)
+    return melspec
+
+
+def mel_spectrogram_torch(
+    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
+    """Convert waveform into Mel-frequency Log-amplitude spectrogram.
+
+    Args:
+        y :: (B, T) - Waveforms
+    Returns:
+        melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
+    """
+    # Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame)
+    spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
+
+    # Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame)
+    melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
+
+    return melspec
diff --git a/rvc/lib/train/process_ckpt.py b/rvc/lib/train/process_ckpt.py
new file mode 100644
index 0000000..2a05154
--- /dev/null
+++ b/rvc/lib/train/process_ckpt.py
@@ -0,0 +1,260 @@
+import os
+import traceback
+from collections import OrderedDict
+
+import torch
+from i18n.i18n import I18nAuto
+
+i18n = I18nAuto()
+
+
+def savee(ckpt, sr, if_f0, name, epoch, version, hps):
+    try:
+        opt = OrderedDict()
+        opt["weight"] = {}
+        for key in ckpt.keys():
+            if "enc_q" in key:
+                continue
+            opt["weight"][key] = ckpt[key].half()
+        opt["config"] = [
+            hps.data.filter_length // 2 + 1,
+            32,
+            hps.model.inter_channels,
+            hps.model.hidden_channels,
+            hps.model.filter_channels,
+            hps.model.n_heads,
+            hps.model.n_layers,
+            hps.model.kernel_size,
+            hps.model.p_dropout,
+            hps.model.resblock,
+            hps.model.resblock_kernel_sizes,
+            hps.model.resblock_dilation_sizes,
+            hps.model.upsample_rates,
+            hps.model.upsample_initial_channel,
+            hps.model.upsample_kernel_sizes,
+            hps.model.spk_embed_dim,
+            hps.model.gin_channels,
+            hps.data.sampling_rate,
+        ]
+        opt["info"] = "%sepoch" % epoch
+        opt["sr"] = sr
+        opt["f0"] = if_f0
+        opt["version"] = version
+        torch.save(opt, "assets/weights/%s.pth" % name)
+        return "Success."
+    except Exception:
+        return traceback.format_exc()
+
+
+def show_info(path):
+    try:
+        a = torch.load(path, map_location="cpu")
+        return "Model info: %s\nSampling rate: %s\nPitch guidance (f0): %s\nVersion: %s" % (
+            a.get("info", "None"),
+            a.get("sr", "None"),
+            a.get("f0", "None"),
+            a.get("version", "None"),
+        )
+    except Exception:
+        return traceback.format_exc()
+
+
+def extract_small_model(path, name, sr, if_f0, info, version):
+    try:
+        ckpt = torch.load(path, map_location="cpu")
+        if "model" in ckpt:
+            ckpt = ckpt["model"]
+        opt = OrderedDict()
+        opt["weight"] = {}
+        for key in ckpt.keys():
+            if "enc_q" in key:
+                continue
+            opt["weight"][key] = ckpt[key].half()
+        if sr == "40k":
+            opt["config"] = [
+                1025, 32, 192, 192, 768, 2, 6, 3, 0, "1",
+                [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000,
+            ]
+        elif sr == "48k":
+            if version == "v1":
+                opt["config"] = [
+                    1025, 32, 192, 192, 768, 2, 6, 3, 0, "1",
+                    [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                    [10, 6, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 48000,
+                ]
+            else:
+                opt["config"] = [
+                    1025, 32, 192, 192, 768, 2, 6, 3, 0, "1",
+                    [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                    [12, 10, 2, 2], 512, [24, 20, 4, 4], 109, 256, 48000,
+                ]
+        elif sr == "32k":
+            if version == "v1":
+                opt["config"] = [
+                    513, 32, 192, 192, 768, 2, 6, 3, 0, "1",
+                    [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                    [10, 4, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 32000,
+                ]
+            else:
+                opt["config"] = [
+                    513, 32, 192, 192, 768, 2, 6, 3, 0, "1",
+                    [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                    [10, 8, 2, 2], 512, [20, 16, 4, 4], 109, 256, 32000,
+                ]
+        if info == "":
+            info = "Extracted model."
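+        # Note: opt["config"] follows the same field order as in savee() above
+        # (filter_length // 2 + 1, 32, inter_channels, hidden_channels,
+        # filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock,
+        # resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
+        # upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim,
+        # gin_channels, sampling_rate), with the values hard-coded per sample
+        # rate and version instead of read from hps.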
+ opt["info"] = info + opt["version"] = version + opt["sr"] = sr + opt["f0"] = int(if_f0) + torch.save(opt, "assets/weights/%s.pth" % name) + return "Success." + except: + return traceback.format_exc() + + +def change_info(path, info, name): + try: + ckpt = torch.load(path, map_location="cpu") + ckpt["info"] = info + if name == "": + name = os.path.basename(path) + torch.save(ckpt, "assets/weights/%s" % name) + return "Success." + except: + return traceback.format_exc() + + +def merge(path1, path2, alpha1, sr, f0, info, name, version): + try: + + def extract(ckpt): + a = ckpt["model"] + opt = OrderedDict() + opt["weight"] = {} + for key in a.keys(): + if "enc_q" in key: + continue + opt["weight"][key] = a[key] + return opt + + ckpt1 = torch.load(path1, map_location="cpu") + ckpt2 = torch.load(path2, map_location="cpu") + cfg = ckpt1["config"] + if "model" in ckpt1: + ckpt1 = extract(ckpt1) + else: + ckpt1 = ckpt1["weight"] + if "model" in ckpt2: + ckpt2 = extract(ckpt2) + else: + ckpt2 = ckpt2["weight"] + if sorted(list(ckpt1.keys())) != sorted(list(ckpt2.keys())): + return "Fail to merge the models. The model architectures are not the same." + opt = OrderedDict() + opt["weight"] = {} + for key in ckpt1.keys(): + # try: + if key == "emb_g.weight" and ckpt1[key].shape != ckpt2[key].shape: + min_shape0 = min(ckpt1[key].shape[0], ckpt2[key].shape[0]) + opt["weight"][key] = ( + alpha1 * (ckpt1[key][:min_shape0].float()) + + (1 - alpha1) * (ckpt2[key][:min_shape0].float()) + ).half() + else: + opt["weight"][key] = ( + alpha1 * (ckpt1[key].float()) + (1 - alpha1) * (ckpt2[key].float()) + ).half() + # except: + # pdb.set_trace() + opt["config"] = cfg + """ + if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000] + elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000] + elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000] + """ + opt["sr"] = sr + opt["f0"] = 1 if f0 == i18n("是") else 0 + opt["version"] = version + opt["info"] = info + torch.save(opt, "assets/weights/%s.pth" % name) + return "Success." + except: + return traceback.format_exc() diff --git a/rvc/lib/train/utils.py b/rvc/lib/train/utils.py new file mode 100644 index 0000000..9785097 --- /dev/null +++ b/rvc/lib/train/utils.py @@ -0,0 +1,478 @@ +import argparse +import glob +import json +import logging +import os +import shutil +import subprocess +import sys + +import numpy as np +import torch +from scipy.io.wavfile import read + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logger = logging + + +def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + + ################## + def go(model, bkey): + saved_state_dict = checkpoint_dict[bkey] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): # 模型需要的shape + try: + new_state_dict[k] = saved_state_dict[k] + if saved_state_dict[k].shape != state_dict[k].shape: + logger.warning( + "shape-%s-mismatch. 
+                    raise KeyError
+            except Exception:
+                # logger.info(traceback.format_exc())
+                logger.info("%s is not in the checkpoint", k)  # missing from the pretrained checkpoint
+                new_state_dict[k] = v  # keep the model's own randomly initialized value
+        if hasattr(model, "module"):
+            model.module.load_state_dict(new_state_dict, strict=False)
+        else:
+            model.load_state_dict(new_state_dict, strict=False)
+        return model
+
+    go(combd, "combd")
+    model = go(sbd, "sbd")
+    #############
+    logger.info("Loaded model weights")
+
+    iteration = checkpoint_dict["iteration"]
+    learning_rate = checkpoint_dict["learning_rate"]
+    if (
+        optimizer is not None and load_opt == 1
+    ):  # if the optimizer state cannot be loaded (e.g. it is empty), it gets
+        # reinitialized, which may also affect the LR schedule, so catch this
+        # at the outermost level of the train script
+        # try:
+        optimizer.load_state_dict(checkpoint_dict["optimizer"])
+        # except:
+        #     traceback.print_exc()
+    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
+    return model, optimizer, learning_rate, iteration
+
+
+def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
+    assert os.path.isfile(checkpoint_path)
+    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
+
+    saved_state_dict = checkpoint_dict["model"]
+    if hasattr(model, "module"):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    new_state_dict = {}
+    for k, v in state_dict.items():  # iterate over the shapes the model expects
+        try:
+            new_state_dict[k] = saved_state_dict[k]
+            if saved_state_dict[k].shape != state_dict[k].shape:
+                logger.warning(
+                    "shape-%s-mismatch|need-%s|get-%s",
+                    k,
+                    state_dict[k].shape,
+                    saved_state_dict[k].shape,
+                )
+                raise KeyError
+        except Exception:
+            # logger.info(traceback.format_exc())
+            logger.info("%s is not in the checkpoint", k)  # missing from the pretrained checkpoint
+            new_state_dict[k] = v  # keep the model's own randomly initialized value
+    if hasattr(model, "module"):
+        model.module.load_state_dict(new_state_dict, strict=False)
+    else:
+        model.load_state_dict(new_state_dict, strict=False)
+    logger.info("Loaded model weights")
+
+    iteration = checkpoint_dict["iteration"]
+    learning_rate = checkpoint_dict["learning_rate"]
+    if (
+        optimizer is not None and load_opt == 1
+    ):  # if the optimizer state cannot be loaded (e.g. it is empty), it gets
+        # reinitialized, which may also affect the LR schedule, so catch this
+        # at the outermost level of the train script
+        # try:
+        optimizer.load_state_dict(checkpoint_dict["optimizer"])
+        # except:
+        #     traceback.print_exc()
+    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
+    return model, optimizer, learning_rate, iteration
+
+
+def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
+    logger.info(
+        "Saving model and optimizer state at epoch {} to {}".format(
+            iteration, checkpoint_path
+        )
+    )
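+    # unwrap DistributedDataParallel (the "module" attribute) before saving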
+    if hasattr(model, "module"):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    torch.save(
+        {
+            "model": state_dict,
+            "iteration": iteration,
+            "optimizer": optimizer.state_dict(),
+            "learning_rate": learning_rate,
+        },
+        checkpoint_path,
+    )
+
+
+def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
+    logger.info(
+        "Saving model and optimizer state at epoch {} to {}".format(
+            iteration, checkpoint_path
+        )
+    )
+    if hasattr(combd, "module"):
+        state_dict_combd = combd.module.state_dict()
+    else:
+        state_dict_combd = combd.state_dict()
+    if hasattr(sbd, "module"):
+        state_dict_sbd = sbd.module.state_dict()
+    else:
+        state_dict_sbd = sbd.state_dict()
+    torch.save(
+        {
+            "combd": state_dict_combd,
+            "sbd": state_dict_sbd,
+            "iteration": iteration,
+            "optimizer": optimizer.state_dict(),
+            "learning_rate": learning_rate,
+        },
+        checkpoint_path,
+    )
+
+
+def summarize(
+    writer,
+    global_step,
+    scalars={},
+    histograms={},
+    images={},
+    audios={},
+    audio_sampling_rate=22050,
+):
+    for k, v in scalars.items():
+        writer.add_scalar(k, v, global_step)
+    for k, v in histograms.items():
+        writer.add_histogram(k, v, global_step)
+    for k, v in images.items():
+        writer.add_image(k, v, global_step, dataformats="HWC")
+    for k, v in audios.items():
+        writer.add_audio(k, v, global_step, audio_sampling_rate)
+
+
+def latest_checkpoint_path(dir_path, regex="G_*.pth"):
+    f_list = glob.glob(os.path.join(dir_path, regex))
+    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
+    x = f_list[-1]
+    logger.debug(x)
+    return x
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    # np.fromstring is removed for binary data in recent NumPy; frombuffer is the drop-in replacement
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def plot_alignment_to_numpy(alignment, info=None):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    im = ax.imshow(
+        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
+    )
+    fig.colorbar(im, ax=ax)
+    xlabel = "Decoder timestep"
+    if info is not None:
+        xlabel += "\n\n" + info
+    plt.xlabel(xlabel)
+    plt.ylabel("Encoder timestep")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def load_wav_to_torch(full_path):
+    sampling_rate, data = read(full_path)
+    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
+
+
+def load_filepaths_and_text(filename, split="|"):
+    with open(filename, encoding="utf-8") as f:
+        filepaths_and_text = [line.strip().split(split) for line in f]
+    return filepaths_and_text
+
+
+def get_hparams(init=True):
+    """
+    todo:
+    the trailing group of seven args:
+        save frequency, total epochs: done
+        batch size: done
+        pretrainG, pretrainD: done
+        GPU ids (os.environ["CUDA_VISIBLE_DEVICES"]): done
+        if_latest: done
+        model: if_f0: done
+        sample rate: pick the config automatically: done
+        whether to cache the dataset in GPU memory (if_cache_data_in_gpu): done
+
+    -m:
+        derive the training_files path automatically and replace
+        hps.data.training_files in train_nsf_load_pretrain.py: done
+    -c: no longer needed
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-se",
+        "--save_every_epoch",
+        type=int,
+        required=True,
+        help="checkpoint save frequency (epoch)",
+    )
+    parser.add_argument(
+        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
+    )
+    parser.add_argument(
+        "-pg", "--pretrainG", type=str, default="", help="Pretrained Generator path"
+    )
+    parser.add_argument(
+        "-pd", "--pretrainD", type=str, default="", help="Pretrained Discriminator path"
+    )
+    parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
+    parser.add_argument(
+        "-bs", "--batch_size", type=int, required=True, help="batch size"
+    )
+    parser.add_argument(
+        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
+    )  # -m
+    parser.add_argument(
+        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
+    )
+    parser.add_argument(
+        "-sw",
+        "--save_every_weights",
+        type=str,
+        default="0",
+        help="save the extracted model in weights directory when saving checkpoints",
+    )
+    parser.add_argument(
+        "-v", "--version", type=str, required=True, help="model version"
+    )
+    parser.add_argument(
+        "-f0",
+        "--if_f0",
+        type=int,
+        required=True,
+        help="use f0 as one of the inputs of the model, 1 or 0",
+    )
+    parser.add_argument(
+        "-l",
+        "--if_latest",
+        type=int,
+        required=True,
+        help="if only save the latest G/D pth file, 1 or 0",
+    )
+    parser.add_argument(
+        "-c",
+        "--if_cache_data_in_gpu",
+        type=int,
+        required=True,
+        help="if caching the dataset in GPU memory, 1 or 0",
+    )
+
+    args = parser.parse_args()
+    name = args.experiment_dir
+    experiment_dir = os.path.join("./logs", args.experiment_dir)
+
+    config_save_path = os.path.join(experiment_dir, "config.json")
+    with open(config_save_path, "r") as f:
+        config = json.load(f)
+
+    hparams = HParams(**config)
+    hparams.model_dir = hparams.experiment_dir = experiment_dir
+    hparams.save_every_epoch = args.save_every_epoch
+    hparams.name = name
+    hparams.total_epoch = args.total_epoch
+    hparams.pretrainG = args.pretrainG
+    hparams.pretrainD = args.pretrainD
+    hparams.version = args.version
+    hparams.gpus = args.gpus
+    hparams.train.batch_size = args.batch_size
+    hparams.sample_rate = args.sample_rate
+    hparams.if_f0 = args.if_f0
+    hparams.if_latest = args.if_latest
+    hparams.save_every_weights = args.save_every_weights
+    hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
+    hparams.data.training_files = "%s/filelist.txt" % experiment_dir
+    return hparams
+
+
+def get_hparams_from_dir(model_dir):
+    config_save_path = os.path.join(model_dir, "config.json")
+    with open(config_save_path, "r") as f:
+        data = f.read()
+    config = json.loads(data)
+
+    hparams = HParams(**config)
+    hparams.model_dir = model_dir
+    return hparams
+
+
+def get_hparams_from_file(config_path):
+    with open(config_path, "r") as f:
+        data = f.read()
+    config = json.loads(data)
+
+    hparams = HParams(**config)
+    return hparams
+
+
+def check_git_hash(model_dir):
+    source_dir = os.path.dirname(os.path.realpath(__file__))
+    if not os.path.exists(os.path.join(source_dir, ".git")):
+        logger.warning(
+            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
+                source_dir
+            )
+        )
+        return
+
+    cur_hash = subprocess.getoutput("git rev-parse HEAD")
+
+    path = os.path.join(model_dir, "githash")
+    if os.path.exists(path):
+        saved_hash = open(path).read()
+        if saved_hash != cur_hash:
+            logger.warning(
+                "git hash values are different. {}(saved) != {}(current)".format(
+                    saved_hash[:8], cur_hash[:8]
+                )
+            )
+    else:
+        open(path, "w").write(cur_hash)
+
+
+def get_logger(model_dir, filename="train.log"):
+    global logger
+    logger = logging.getLogger(os.path.basename(model_dir))
+    logger.setLevel(logging.DEBUG)
+
+    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    h = logging.FileHandler(os.path.join(model_dir, filename))
+    h.setLevel(logging.DEBUG)
+    h.setFormatter(formatter)
+    logger.addHandler(h)
+    return logger
+
+
+class HParams:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, dict):
+                v = HParams(**v)
+            self[k] = v
+
+    def keys(self):
+        return self.__dict__.keys()
+
+    def items(self):
+        return self.__dict__.items()
+
+    def values(self):
+        return self.__dict__.values()
+
+    def __len__(self):
+        return len(self.__dict__)
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
+
+    def __contains__(self, key):
+        return key in self.__dict__
+
+    def __repr__(self):
+        return self.__dict__.__repr__()
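
Usage sketch (not part of the diff): a minimal example of how these pieces fit together, based only on the signatures above. The config path, filelist location, bucket boundaries, and single-process sampler parameters (num_replicas=1, rank=0) are illustrative assumptions, not values mandated by the code.

    import torch
    from torch.utils.data import DataLoader

    from rvc.lib.train.data_utils import (
        DistributedBucketSampler,
        TextAudioCollateMultiNSFsid,
        TextAudioLoaderMultiNSFsid,
    )
    from rvc.lib.train.utils import get_hparams_from_file

    hps = get_hparams_from_file("./logs/my-exp/config.json")  # assumed path
    hps.data.training_files = "./logs/my-exp/filelist.txt"    # assumed path

    # Each filelist line: wav_path|phone.npy|pitch.npy|pitchf.npy|speaker_id
    dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)

    # Buckets cap at 900 frames, matching the n_num = min(..., 900) truncation
    sampler = DistributedBucketSampler(
        dataset,
        hps.train.batch_size,
        [100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=1,  # assumption: single process
        rank=0,
        shuffle=True,
    )
    loader = DataLoader(
        dataset,
        num_workers=4,
        shuffle=False,  # the batch sampler already shuffles
        pin_memory=True,
        collate_fn=TextAudioCollateMultiNSFsid(),
        batch_sampler=sampler,
    )

    for epoch in range(1, 11):  # epoch count is arbitrary here
        sampler.set_epoch(epoch)  # drives the deterministic shuffle in __iter__
        for (phone, phone_lengths, pitch, pitchf, spec, spec_lengths,
             wave, wave_lengths, sid) in loader:
            ...  # forward/backward through the synthesizer and discriminators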