Minor inference optimizations and refactoring
deiteris committed Jul 28, 2024
1 parent 2446f20 commit 2a80282
Showing 6 changed files with 69 additions and 103 deletions.
@@ -1,13 +1,11 @@
import copy
import math
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from voice_changer.RVC.inferencer.rvc_models.infer_pack import commons, modules
from voice_changer.RVC.inferencer.rvc_models.infer_pack import commons
from voice_changer.RVC.inferencer.rvc_models.infer_pack.modules import LayerNorm


@@ -142,8 +140,9 @@ def forward(self, x, x_mask, h, h_mask):
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
m_size = x_mask.size(2)
self_attn_mask = commons.subsequent_mask(
torch.ones(m_size, m_size, device=x.device, dtype=x.dtype),
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
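Note on the change above: it pairs with the new commons.subsequent_mask later in this commit. The caller now allocates the ones buffer directly on the target device and dtype, and tril is applied in place, instead of building a CPU mask and transferring it on every call. A minimal equivalence sketch (illustrative helper names, not the committed code):

import torch

def subsequent_mask_old(length):
    # old style: CPU float32 mask, converted with .to(device, dtype) at the call site
    return torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)

def subsequent_mask_new(mask: torch.Tensor):
    # new style: caller passes torch.ones(m, m, device=..., dtype=...), tril runs in place
    return torch.tril(mask, out=mask).unsqueeze(0).unsqueeze(0)

x = torch.zeros(1, 1, 4)  # stand-in for the decoder input
m_size = x.size(2)
old = subsequent_mask_old(m_size).to(device=x.device, dtype=x.dtype)
new = subsequent_mask_new(torch.ones(m_size, m_size, device=x.device, dtype=x.dtype))
assert torch.equal(old, new)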
@@ -191,6 +190,7 @@ def __init__(
self.attn = None

self.k_channels = channels // n_heads
self.k_channels_sqrt = math.sqrt(self.k_channels)
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
@@ -243,22 +243,21 @@ def attention(
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
query /= self.k_channels_sqrt

scores = torch.matmul(query, key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
r = torch.arange(t_s, dtype=scores.dtype, device=scores.device)
scores = scores + self._attention_bias_proximal(r)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
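The scaling refactor in this hunk precomputes sqrt(k_channels) once at construction time and scales the query a single time (in place in the commit), so the same scaled tensor feeds both the content-content product and the relative-position product. A rough sketch of the equivalence with made-up shapes; the out-of-place divide is used here only for clarity:

import math
import torch

n_heads, k_channels, t = 2, 8, 5
scale = math.sqrt(k_channels)  # corresponds to self.k_channels_sqrt computed in __init__
query = torch.randn(1, n_heads, t, k_channels)
key = torch.randn(1, n_heads, t, k_channels)
rel_k = torch.randn(1, t, k_channels)  # stand-in for the relative key embeddings

# old: divide by math.sqrt(k_channels) in two places on every call
scores_old = torch.matmul(query / math.sqrt(k_channels), key.transpose(-2, -1))
rel_old = torch.matmul(query / math.sqrt(k_channels), rel_k.unsqueeze(0).transpose(-2, -1))

# new: scale the query once and reuse it
q = query / scale
scores_new = torch.matmul(q, key.transpose(-2, -1))
rel_new = torch.matmul(q, rel_k.unsqueeze(0).transpose(-2, -1))
assert torch.allclose(scores_old, scores_new) and torch.allclose(rel_old, rel_new)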
@@ -373,14 +372,13 @@ def _absolute_position_to_relative_position(self, x):
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final

def _attention_bias_proximal(self, length: int):
def _attention_bias_proximal(self, r: torch.Tensor):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
r: torch.Tensor
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
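The proximal-bias change follows the same pattern as the hunk above: the caller now builds the position range once with the scores tensor's device and dtype and passes it in, so the bias no longer needs a separate .to(device, dtype) copy. A small sketch (hypothetical helper names):

import torch

def bias_old(length: int):
    r = torch.arange(length, dtype=torch.float32)  # always CPU float32, converted later
    diff = r.unsqueeze(0) - r.unsqueeze(1)
    return (-torch.log1p(diff.abs())).unsqueeze(0).unsqueeze(0)

def bias_new(r: torch.Tensor):  # r is precomputed on scores.device with scores.dtype
    diff = r.unsqueeze(0) - r.unsqueeze(1)
    return (-torch.log1p(diff.abs())).unsqueeze(0).unsqueeze(0)

scores = torch.zeros(1, 2, 6, 6)
r = torch.arange(6, dtype=scores.dtype, device=scores.device)
assert torch.allclose(bias_old(6).to(dtype=scores.dtype, device=scores.device), bias_new(r))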

@@ -1,8 +1,6 @@
import copy
import math
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
@@ -142,8 +140,9 @@ def forward(self, x, x_mask, h, h_mask):
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
m_size = x_mask.size(2)
self_attn_mask = commons.subsequent_mask(
torch.ones(m_size, m_size, device=x.device, dtype=x.dtype),
)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
@@ -191,6 +190,7 @@ def __init__(
self.attn = None

self.k_channels = channels // n_heads
self.k_channels_sqrt = math.sqrt(self.k_channels)
self.conv_q = nn.Conv1d(channels, channels, 1)
self.conv_k = nn.Conv1d(channels, channels, 1)
self.conv_v = nn.Conv1d(channels, channels, 1)
@@ -243,22 +243,21 @@ def attention(
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
query /= self.k_channels_sqrt

scores = torch.matmul(query, key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
r = torch.arange(t_s, dtype=scores.dtype, device=scores.device)
scores = scores + self._attention_bias_proximal(r)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
@@ -373,14 +372,13 @@ def _absolute_position_to_relative_position(self, x):
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final

def _attention_bias_proximal(self, length: int):
def _attention_bias_proximal(self, r: torch.Tensor):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
r: torch.Tensor
Returns:
a Tensor with shape [1, 1, length, length]
"""
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)

@@ -1,9 +1,7 @@
from typing import List, Optional
import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


@@ -66,7 +64,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str

@@ -99,9 +97,8 @@ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
def subsequent_mask(mask: torch.Tensor):
return torch.tril(mask, out=mask).unsqueeze(0).unsqueeze(0)


@torch.jit.script
server/voice_changer/RVC/inferencer/rvc_models/infer_pack/models.py (46 changes: 17 additions & 29 deletions)
@@ -261,7 +261,7 @@ def forward(
x = x + self.cond(g)

for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = F.leaky_relu(x, modules.LRELU_SLOPE, inplace=True)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
@@ -270,9 +270,9 @@ def forward(
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = F.leaky_relu(x, inplace=True)
x = self.conv_post(x)
x = torch.tanh(x)
x = torch.tanh(x, out=x)

return x
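The inplace=True and out=x forms above produce the same values as the original calls but write into the existing activation buffer instead of allocating a new tensor at every stage, which is safe in this inference-only path because the pre-activation values are not needed again. A rough sketch; the 0.1 slope is an assumed value for modules.LRELU_SLOPE:

import torch
import torch.nn.functional as F

LRELU_SLOPE = 0.1  # assumption for modules.LRELU_SLOPE

x = torch.randn(1, 128, 256)
y = x.clone()
out_old = torch.tanh(F.leaky_relu(x, LRELU_SLOPE))                       # out-of-place
out_new = torch.tanh(F.leaky_relu(y, LRELU_SLOPE, inplace=True), out=y)  # in-place, reuses y
assert torch.allclose(out_old, out_new)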

@@ -338,15 +338,12 @@ def __init__(
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold

def _f02uv(self, f0):
def _f02uv(self, f0: torch.Tensor):
# generate uv signal
uv = torch.ones_like(f0)
uv = uv * (f0 > self.voiced_threshold)
if uv.device.type == "privateuseone": # for DirectML
uv = uv.float()
return uv
return uv * (f0 > self.voiced_threshold)
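The simplified _f02uv collapses the mask construction into one expression: multiplying the ones tensor by the boolean comparison yields the 1.0/0.0 voiced mask directly, and the dtype handling (including the old DirectML special case) moves to the call site below. A tiny sketch with made-up f0 values:

import torch

f0 = torch.tensor([[0.0], [110.0], [0.0], [220.0]]).unsqueeze(0)  # (1, length, 1), made-up values
voiced_threshold = 0.0
uv = torch.ones_like(f0) * (f0 > voiced_threshold)  # broadcasted float mask, no explicit cast
print(uv.squeeze(-1))  # tensor([[0., 1., 0., 1.]])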

def forward(self, f0: torch.Tensor, upp: int):
def forward(self, f0: torch.Tensor, upp: float):
"""sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
@@ -376,12 +373,12 @@ def forward(self, f0: torch.Tensor, upp: int):
tmp_over_one *= upp
tmp_over_one = F.interpolate(
tmp_over_one.transpose(2, 1),
scale_factor=float(upp),
scale_factor=upp,
mode="linear",
align_corners=True,
).transpose(2, 1)
rad_values = F.interpolate(
rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest"
rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
).transpose(
2, 1
) #######
@@ -393,9 +390,9 @@ def forward(self, f0: torch.Tensor, upp: int):
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi
)
sine_waves = sine_waves * self.sine_amp
uv = self._f02uv(f0)
uv = self._f02uv(f0).to(f0.dtype)
uv = F.interpolate(
uv.transpose(2, 1), scale_factor=float(upp), mode="nearest"
uv.transpose(2, 1), scale_factor=upp, mode="nearest"
).transpose(2, 1)
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
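Typing upp as a float lets it feed F.interpolate's scale_factor directly (the caller now passes float(self.upp)), removing the repeated float(...) conversions inside the loop body. A short usage sketch with an assumed factor and shape:

import torch
import torch.nn.functional as F

upp = 4.0                          # upsampling factor, already a float
rad_values = torch.rand(1, 32, 1)  # (batch, frames, 1), made-up shape
up = F.interpolate(
    rad_values.transpose(2, 1),    # F.interpolate expects (batch, channels, frames)
    scale_factor=upp,
    mode="nearest",
).transpose(2, 1)
print(up.shape)                    # torch.Size([1, 128, 1])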
@@ -443,18 +440,9 @@ def __init__(
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
# self.ddtype:int = -1

def forward(self, x: torch.Tensor, upp: int = 1):
# if self.ddtype ==-1:
# self.ddtype = self.l_linear.weight.dtype
def forward(self, x: torch.Tensor, upp: float = 1.):
sine_wavs, uv, _ = self.l_sin_gen(x, upp)
# print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype)
# if self.is_half:
# sine_wavs = sine_wavs.half()
# sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x)))
# print(sine_wavs.dtype,self.ddtype)
# if sine_wavs.dtype != self.l_linear.weight.dtype:
sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
return sine_merge, None, None # noise, uv
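The cleaned-up forward keeps only the unconditional dtype alignment and drops the commented-out half-precision branches: whatever dtype the sine generator returns, it is cast to the linear layer's weight dtype before the merge. A minimal sketch with assumed sizes (harmonic_num + 1 = 9 is an assumption):

import torch

l_linear = torch.nn.Linear(9, 1)
l_tanh = torch.nn.Tanh()
sine_wavs = torch.randn(1, 100, 9, dtype=torch.float64)  # stand-in for the sine generator output
sine_wavs = sine_wavs.to(dtype=l_linear.weight.dtype)    # align dtype before the matmul
sine_merge = l_tanh(l_linear(sine_wavs))
print(sine_merge.shape)                                   # torch.Size([1, 100, 1])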
@@ -540,7 +528,7 @@ def forward(
g: Optional[torch.Tensor] = None,
n_res: Optional[int] = None,
):
har_source, noi_source, uv = self.m_source(f0, self.upp)
har_source, noi_source, uv = self.m_source(f0, float(self.upp))
har_source = har_source.transpose(1, 2)
if n_res is not None:
if (n := n_res * self.upp) != har_source.shape[-1]:
@@ -554,7 +542,7 @@ def forward(
# That's why I wrote this
for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
if i < self.num_upsamples:
x = F.leaky_relu(x, self.lrelu_slope)
x = F.leaky_relu(x, self.lrelu_slope, inplace=True)
x = ups(x)
x_source = noise_convs(har_source)
x = x + x_source
@@ -570,9 +558,9 @@ def forward(
# If ignored, it will cause torch.jit.script() compilation errors
assert isinstance(xs, torch.Tensor)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = F.leaky_relu(x, inplace=True)
x = self.conv_post(x)
x = torch.tanh(x)
x = torch.tanh(x, out=x)
return x

def remove_weight_norm(self):
@@ -1121,7 +1109,7 @@ def forward(self, x):

for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = F.leaky_relu(x, modules.LRELU_SLOPE, inplace=True)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
@@ -1205,7 +1193,7 @@ def forward(self, x):

for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = F.leaky_relu(x, modules.LRELU_SLOPE, inplace=True)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)