Transformer代码实例中各张量的维度是多少

news/2024/7/19 12:25:08 标签: transformer, 深度学习, 人工智能

一下是一个Transformer代码实例:

def sample(self, batch_size, max_length=140, con_token_list= ['is_JNK3', 'is_GSK3', 'high_QED', 'good_SA']):
        """
               Sample a batch of sequences

               Args:
                   batch_size : Number of sequences to sample
                   max_length:  Maximum length of the sequences

               Outputs:
               seqs: (batch_size, seq_length) The sampled sequences.
               log_probs : (batch_size) Log likelihood for each sequence.
               entropy: (batch_size) The entropies for the sequences. Not
                                       currently used.
       """

        # conditional token
        con_token_list = Variable(self.voc.encode(con_token_list))

        con_tokens = Variable(torch.zeros(batch_size, len(con_token_list)).long()) #形状为 (batch_size, len(con_token_list)),表示条件标记的张量。

        for ind, token in enumerate(con_token_list):
            con_tokens[:, ind] = token

        start_token = Variable(torch.zeros(batch_size, 1).long())  #形状为 (batch_size, 1),表示序列开始标记的张量。
        start_token[:] = self.voc.vocab['GO']
        input_vector = start_token   # 在循环中更新的张量,它的形状与 sequences 相同。
        # print(batch_size)

        sequences = start_token
        log_probs = Variable(torch.zeros(batch_size))
        # log_probs1 = Variable(torch.zeros(batch_size))

        finished = torch.zeros(batch_size).byte()

        finished = finished.to(self.device)

        for step in range(max_length):
            logits = sample_forward_model(self.decodertf, input_vector, con_tokens) #形状为 (batch_size, max_length, vocab_size)。

            logits_step = logits[:, step, :]  #是从 logits 中选择当前时间步的张量,形状为 (batch_size, vocab_size)。

            prob = F.softmax(logits_step, dim=1)
            log_prob = F.log_softmax(logits_step, dim=1)

            input_vector = torch.multinomial(prob, 1)

            # need to concat prior words as the sequences and input 记录下每一步采样
            sequences = torch.cat((sequences, input_vector), 1)  #形状为 (batch_size, seq_length),表示生成的序列。


            log_probs += self._nll_loss(log_prob, input_vector.view(-1))  #形状为 (batch_size),表示每个生成序列的对数似然。
            # log_probs1 += NLLLoss(log_prob, input_vector.view(-1))
            # print(log_probs1==-log_probs)




            EOS_sampled = (input_vector.view(-1) == self.voc.vocab['EOS']).data
            finished = torch.ge(finished + EOS_sampled, 1)  #形状为 (batch_size),是一个二进制张量,表示每个序列是否已经结束。

            if torch.prod(finished) == 1:
                # print('End')
                break

            # because there are no hidden layer in transformer, so we need to append generated word in every step as the input_vector
            input_vector = sequences

        return sequences[:, 1:].data, log_probs


http://www.niftyadmin.cn/n/5228985.html

相关文章

iRDMA流量控制总结 - 3

5.0 Priority Flow Control - Planning and Guidelines优先流量控制 - 规划与指导 This section covers planning, considerations, and general configuration guidelines for enabling PFC on a network. 本节介绍在网络上启用 PFC 的规划、注意事项和一般配置指南。 5.1 S…

在Transformer模型中, Positional Encoding的破坏性分析

在Transformer模型中,Word Embedding 被加上一个Positional Encoding,是否会破坏原来的Word Embedding 的含义 Sinusoidal Positional Encoding的破坏性可以从两个方面来分析:一是对Word Embedding的语义信息的破坏,二是对Word Em…

LeetCode(41)单词规律【哈希表】【简单】

目录 1.题目2.答案3.提交结果截图 链接: 单词规律 1.题目 给定一种规律 pattern 和一个字符串 s ,判断 s 是否遵循相同的规律。 这里的 遵循 指完全匹配,例如, pattern 里的每个字母和字符串 s 中的每个非空单词之间存在着双向连…

指针、数组与函数例题

1、简单数字显示 题目描述 本例要求实现对变量的直接访问和间接访问。输入任意两个整数,先用直接访问的方式输出这两个变量的值,再通过指针变量用间接访问的方式输出这两个变量的值。 输入要求 输入两个整数 输出要求 先用直接访问方式使出这两个整…

JDK 动态代理从入门到掌握

快速入门 本文介绍 JDK 实现的动态代理及其原理,通过 ProxyGenerator 生成的动态代理类字节码文件 环境要求 要求原因JDK 8 及以下在 JDK 9 之后无法使用直接调用 ProxyGenerator 中的方法,不便于将动态代理类对应的字节码文件输出lombok为了使用 Sne…

1.7 java实现License认证信息的加密解密处理

java实现License认证信息的加密解密处理 一、什么是License认证二、确定License文件的格式和内容1. 生成一个存放License信息的ini文件 三、使用RSA非对称加密方式对文件进行加密和解密1. 生成密钥对2. 加密数据3. 解密数据 一、什么是License认证 License认证是一种用于验证软…

上海毅速丨新材料将推动3D打印在压铸行业的应用

压铸是一种应用广泛的制造工艺,它的制造原理是将液态或半液态金属,在高压作用下,以高速度填充压铸模具型腔,并在压力下快速凝固而获得铸件的一种方法。压铸模的设计和制造需要考虑到多方面的因素,如模具材料、结构、冷…

MySQL表连接详解:解析内连接与外连接的使用方法

在MySQL中,表连接是一种将两个或多个表中的行相关联的操作。这是通过在这些表之间共享一个或多个列的值来实现的。表连接通常用于从多个表中检索相关的数据。 有几种不同类型的表连接,其中两个常见的是内连接(INNER JOIN)和外连接…