Below is a Transformer code example: a `sample` method that autoregressively draws token sequences from a conditional Transformer decoder.
import torch
import torch.nn.functional as F  # imports required by this (class) method


def sample(self, batch_size, max_length=140,
           con_token_list=['is_JNK3', 'is_GSK3', 'high_QED', 'good_SA']):
    """
    Sample a batch of sequences.

    Args:
        batch_size: Number of sequences to sample.
        max_length: Maximum length of the sequences.
        con_token_list: Condition (property) tokens prepended to every sequence.

    Returns:
        sequences: (batch_size, seq_length) The sampled sequences.
        log_probs: (batch_size,) Log likelihood for each sequence.
    """
    # Encode the condition tokens once and tile them over the batch:
    # con_tokens has shape (batch_size, len(con_token_list)).
    con_token_ids = self.voc.encode(con_token_list)
    con_tokens = torch.zeros(batch_size, len(con_token_ids),
                             dtype=torch.long, device=self.device)
    for ind, token in enumerate(con_token_ids):
        con_tokens[:, ind] = token

    # Every sequence starts with the GO token; shape (batch_size, 1).
    start_token = torch.full((batch_size, 1), self.voc.vocab['GO'],
                             dtype=torch.long, device=self.device)
    input_vector = start_token  # updated every step; mirrors `sequences`
    sequences = start_token
    log_probs = torch.zeros(batch_size, device=self.device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
    for step in range(max_length):
        # Logits over the generated prefix: (batch_size, step + 1, vocab_size).
        logits = sample_forward_model(self.decodertf, input_vector, con_tokens)
        # Keep only the current time step: (batch_size, vocab_size).
        logits_step = logits[:, step, :]
        prob = F.softmax(logits_step, dim=1)
        log_prob = F.log_softmax(logits_step, dim=1)
        # Sample the next token from the categorical distribution; (batch_size, 1).
        input_vector = torch.multinomial(prob, 1)
        # Record the token sampled at this step: sequences grows to (batch_size, step + 2).
        sequences = torch.cat((sequences, input_vector), 1)
        # Accumulate the log likelihood of each generated sequence; (batch_size,).
        log_probs += self._nll_loss(log_prob, input_vector.view(-1))
        # Flag sequences that have emitted EOS as finished.
        EOS_sampled = input_vector.view(-1) == self.voc.vocab['EOS']
        finished = finished | EOS_sampled
        if torch.all(finished):
            break
        # The Transformer keeps no recurrent hidden state, so the whole prefix
        # generated so far must be fed back as the next input.
        input_vector = sequences

    # Strip the leading GO token before returning.
    return sequences[:, 1:].data, log_probs
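
The helper `sample_forward_model` is not shown in this excerpt. As a rough sketch only, assuming the condition tokens are simply prepended to the generated prefix before the decoder runs (the real helper may differ), it could look like:

def sample_forward_model(decodertf, input_vector, con_tokens):
    # Hypothetical sketch, not the original helper: prepend the condition
    # tokens to the generated prefix and run the Transformer decoder.
    x = torch.cat((con_tokens, input_vector), dim=1)  # (batch, n_con + cur_len)
    logits = decodertf(x)                             # (batch, n_con + cur_len, vocab)
    # Return logits for the generated positions only, so that
    # logits[:, step, :] in `sample` indexes the prediction at `step`.
    return logits[:, con_tokens.size(1):, :]

A minimal call might then look like the following, where `model` is an object exposing the `sample` method above (the name is assumed for illustration):

seqs, log_probs = model.sample(batch_size=8)
print(seqs.shape)       # (8, L) with L <= 140; EOS tokens included
print(log_probs.shape)  # torch.Size([8])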