model arch
S1 model: an autoregressive (AR) model over SSL semantic tokens.
S2 model: VITS-based. The SSL features are already linearly correlated with the mel length; the MRTE(ssl_codes_embs, text, global_mel_emb) module strengthens the conditioning on the text while learning from the reference through a global mel embedding.
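To relate the two stages, here is a minimal inference-order sketch. The wrapper names (tts_pipeline, s1_model.infer, s2_model.decode) and their signatures are assumptions for illustration only, not the repository's actual API.

def tts_pipeline(phoneme_ids, bert_feature, ref_mel, s1_model, s2_model):
    # Stage 1 (S1, AR): phonemes (+ BERT feature) -> SSL semantic tokens,
    # decoded autoregressively until EOS (or the preset max length).
    semantic_tokens = s1_model.infer(phoneme_ids, bert_feature)

    # Stage 2 (S2, VITS-style): semantic tokens + text + reference mel -> waveform.
    # The reference mel provides the global timbre embedding consumed by MRTE.
    wav = s2_model.decode(semantic_tokens, phoneme_ids, refer=ref_mel)
    return wav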
S1 Model
class Text2SemanticDecoder(nn.Module):
    def forward_old(self, x, x_lens, y, y_lens, bert_feature):
        """
        x: phoneme_ids
        y: semantic_ids
        bert_feature: already expanded via word2phn to the same length as x
        train: y + EOS is appended, so the target length is known;
        infer: AR decoding, stops when EOS is predicted; if EOS never appears,
               it stops at the preset maximum length.
        """
        x = self.ar_text_embedding(x)
        x = x + self.bert_proj(bert_feature.transpose(1, 2))
        x = self.ar_text_position(x)
        x_mask = make_pad_mask(x_lens)

        y_mask = make_pad_mask(y_lens)
        y_mask_int = y_mask.type(torch.int64)
        codes = y.type(torch.int64) * (1 - y_mask_int)

        # append EOS to y and build the shifted prediction targets
        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
        x_len = x_lens.max()
        y_len = y_lens.max()
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position(y_emb)

        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
        ar_xy_padding_mask = xy_padding_mask

        # x (text) rows: attend to all of x, never to y
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        # y (semantic) rows: attend to all of x and causally to y
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        # merge the causal mask with the per-head padding mask, then convert to additive -inf form
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_head, -1, -1)
            .reshape(bsz * self.num_head, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        # run the transformer over the concatenated (text, semantic) sequence
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.h(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        # loss / accuracy are computed only on the semantic (y) part of the output
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
        loss = F.cross_entropy(logits, targets, reduction="sum")
        acc = self.ar_accuracy_metric(logits.detach(), targets).item()
        return loss, acc
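To make the mask layout above concrete, the standalone snippet below reproduces the same construction with assumed toy lengths (x_len=3, y_len=2): the text block x attends to all of x but not to y, while the semantic block y attends to all of x and causally to itself.

import torch
import torch.nn.functional as F

x_len, y_len = 3, 2  # toy lengths for illustration

# text rows: all of x visible (False = keep), all of y masked (True = masked)
x_attn_mask = F.pad(
    torch.zeros((x_len, x_len), dtype=torch.bool),
    (0, y_len),
    value=True,
)
# semantic rows: all of x visible, y visible only up to the current step (causal)
y_attn_mask = F.pad(
    torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
    (x_len, 0),
    value=False,
)
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
print(xy_attn_mask.int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])

The True entries are later converted to -inf, so those positions get zero attention weight: the text prefix sees itself bidirectionally, while the semantic tokens are decoded strictly left to right.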
S2 model
class Encoder(nn.Module):
    def forward(self, ssl, y_lengths, text, text_lengths, speed=1, test=None):
        '''
        y_lengths: mel length (the SSL features are already aligned to it)
        ge: output of the reference encoder (global timbre embedding from the reference mel)
        '''
        # ge: global embedding extracted from the reference mel y (y / y_mask come from the caller)
        ge = self.ref_enc(y * y_mask, y_mask)

        # project the SSL features and quantize them into semantic codes
        ssl = self.ssl_proj(ssl)
        quantized, codes, commit_loss, quantized_list = self.quantizer(
            ssl, layers=[0]
        )
        if self.semantic_frame_rate == "25hz":
            # 25 Hz codes are upsampled x2 back to the 50 Hz frame rate
            quantized = F.interpolate(
                quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
            )

        # note: from here on, y is reused for the (quantized) SSL feature sequence
        y = self.encoder_ssl(y * y_mask, y_mask)

        text_mask = torch.unsqueeze(
            commons.sequence_mask(text_lengths, text.size(1)), 1
        ).to(y.dtype)
        if test == 1:
            # zero out the text to test how much the model depends on it
            text[:, :] = 0
        text = self.text_embedding(text).transpose(1, 2)
        text = self.encoder_text(text * text_mask, text_mask)

        # MRTE fuses the SSL sequence with the text features and the global embedding ge
        y = self.mrte(y, y_mask, text, text_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)

        if speed != 1:
            # speed control: resample y and its mask along the time axis
            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")

        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)  # prior mean / log-scale
        return y, m, logs, y_mask
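As a quick sanity check of the two F.interpolate calls above (the tensor shapes below are toy assumptions, not the model's real dimensions): nearest-neighbor interpolation doubles the 25 Hz quantized sequence back to a 50 Hz frame rate, and the speed branch shortens or stretches y and its mask along the time axis.

import torch
import torch.nn.functional as F

quantized = torch.randn(1, 192, 50)   # (batch, channels, frames) at 25 Hz, toy shape
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
print(quantized.shape)                # torch.Size([1, 192, 100]) -> back to 50 Hz

speed = 1.25
y = torch.randn(1, 192, 100)
y_mask = torch.ones(1, 1, 100)
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
print(y.shape, y_mask.shape)          # fewer frames -> faster speech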