使用taichi 写了一个任意平台任意显卡推理的Linear

news/2024/7/19 11:11:18 标签: windows, 重构, transformer, 神经网络, 前端

这东西就是在于任意的显卡都能加速任意模型
当然如何有人使用taichi写一个卷积那么计算机视觉也能任意显卡加速人工智能
如果还有人写了个深度学习训练框架那么恭喜AMD,ARM 等任何芯片厂商乐疯

import taichi as ti
import numpy as np
ti.init(arch=ti.vulkan)


class Linear():
    def __init__(self, input_size, output_size, weights=None):
        if weights:
            self.weights = weights
        else:
            # self.weights = ti.Matrix([[0] * output_size] * input_size)
            self.weights = ti.Matrix(np.random.random(output_size*input_size).reshape([input_size,output_size]))

    @staticmethod
    def taichi_mul(a, b):
        ar, al = a.to_numpy().shape
        br, bl = b.to_numpy().shape
        assert al == br

        @ti.kernel
        def mlp() -> ti.types.matrix(ar, bl, dtype=ti.float32):
            return a @ b

        return mlp()

    def forward(self, a):
        return self.taichi_mul(a, self.weights)

    def set_weights(self, weights):
        self.weights = ti.Matrix(weights)


l1 = Linear(3, 5)
print(l1.forward(ti.Matrix([[1] * 3])))
这段代码实现了一个简单的线性层(Linear layer)类,通过调用 forward 函数可以对输入进行线性变换。其中 Linear 类中的 weights 表示线性层的权重矩阵,可以通过构造函数的输入或者 set_weights 函数进行设置。forward 函数使用 taichi_mul 函数实现了矩阵乘法,并返回了乘积结果。

在 __init__ 函数中,如果 weights 参数被传入,则将其作为权重矩阵;否则通过 np.random.random 函数生成随机的权重矩阵。其中,输入参数 input_size 和 output_size 分别表示输入和输出的特征数。

在 taichi_mul 函数中,输入参数 a 和 b 分别表示两个矩阵,通过 Taichi 提供的 @ 运算符实现了矩阵乘法,返回乘积结果。值得注意的是,Taichi 不能直接操作 Python 中的数据类型,因此在使用 Taichi 前,需要将 Python 中的数据类型转换为 Taichi 中的数据类型,可以调用 to_numpy() 函数将 Taichi 中的数据类型转换为 NumPy 中的数据类型,然后对其进行操作,最后再调用 ti.Matrix() 函数将其转换为 Taichi 中的数据类型。

在主函数中,首先构造了一个输入矩阵 [[1] * 3],然后通过 l1.forward() 函数将其输入到 l1 线性层中,得到线性变换的结果,最后将结果打印输出。

http://www.niftyadmin.cn/n/5066091.html

相关文章

安装matplotlib_

安装pip 安装matplotlib 安装完毕 导入出现bug......

【Java-LangChain:面向开发者的提示工程-2】编写提示词原则

第二章 编写提示词原则 一、环境配置 开始学习前,我们用最简单的方案使用LLM(这里我们使用ChatGPT最为我们调用的LLM)。所以我们使用 Java版的SDK(第三方库) 你也可以参考 官方文档 ,看看都有哪些内容。 自己的项目…

SketchUp Pro 2023 for Mac——打造你的创意之城

SketchUp Pro 2023 for Mac是一款专业级的3D建模软件,为你提供最佳的设计和创意工具。不论你是建筑师、室内设计师,还是爱好者,SketchUp Pro都能满足你对于创意表达的需求。 SketchUp Pro 2023拥有强大而直观的界面,让你轻松绘制…

【算法训练-数组 三】【数组矩阵】螺旋矩阵、搜索二维矩阵

废话不多说,喊一句号子鼓励自己:程序员永不失业,程序员走向架构!本篇Blog的主题是螺旋矩阵,使用【二维数组】这个基本的数据结构来实现 螺旋矩阵【EASY】 二维数组的结构特性入手 题干 解题思路 根据题目示例 mat…

BOM体系学习

1.BOM体系总体概述 BOM定义 描述产品制造所需的部件、零件、原材料还有数量的完整清单。 BOM作用 表明使用/被使用、数量关系,指导投产和配套;形成材料定额,支持物资采购;定义使用、借用、替代等关系,支持更改影响分…

计算机竞赛 车道线检测(自动驾驶 机器视觉)

0 前言 无人驾驶技术是机器学习为主的一门前沿领域,在无人驾驶领域中机器学习的各种算法随处可见,今天学长给大家介绍无人驾驶技术中的车道线检测。 1 车道线检测 在无人驾驶领域每一个任务都是相当复杂,看上去无从下手。那么面对这样极其…

NPDP产品经理知识(产品设计与开发工具)

1.复习产品创新流程 -- 系统工程 -- 设计思维(DESIGN THINKING)

初步了解nodejs语法和web模块

在此, 第一个Node.js实例_js firstnode-CSDN博客 通过node运行一个简单的server.js,实现了一个http服务器; 但是还没有解析server.js的代码,下面看一下; require 指令 在 Node.js 中,使用 require 指令来…