CANN/catlass:Ascend 950 MX FP4矩阵乘示例
MXFP4MatmulTla Example Readme【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass注意社区包暂不支持 950 能力后续支持的版本敬请期待功能介绍演示 Ascend 950 上的MX FP4 矩阵乘左矩阵 A、右矩阵 B 经 MX 缩放float8_e8m0后在 Cube 上完成乘加输出为 FP32。本示例中 A、B 元素类型为float4_e2m1x2_tE2M1 打包缩放因子为float8_e8m0_t。未启用 BiasElementBias为void。默认布局为 ARowMajor、BColumnMajor、CRowMajor与gen_data.py在trans_a0, trans_b1时生成的数据一致。代码组织├── 54_ascend950_fp4_mx_matmul │ ├── CMakeLists.txt # CMake编译文件 │ ├── README.md │ ├── gen_data.py │ └── fp4_mx_matmul.cpp # 主文件使用示例获取代码之后编译相应的算子可执行文件可参考quickstart本用例为 Ascend9503510算子编译时需加-DCATLASS_ARCH3510。L1 分块为 256×256×448、L0 为 256×256×128以满足 512KiB L1 与 L0 容量约束勿随意增大 L1 的 K否则L1TileShape exceeding the L1 space。执行算子# 编译指定用例 bash scripts/build.sh 54_ascend950_fp4_mx_matmul -DCATLASS_ARCH3510 # 生成测试样例在 examples/54_ascend950_fp4_mx_matmul/data 下生成 input/ 与 golden/ python3 examples/54_ascend950_fp4_mx_matmul/gen_data.py 256 512 1024 0 1 # 输入参数分别对应 m, n, k, trans_a, trans_b # trans_a表示A矩阵是否转置0是不转置1是转置 # trans_b表示B矩阵是否转置0是不转置1是转置 # 执行测试样例 ./output/bin/54_ascend950_fp4_mx_matmul 256 512 1024 0 # 可执行文件名 |矩阵m轴|n轴|k轴|Device ID # Device ID可选默认为0执行结果如下说明精度比对成功。Compare success.使用说明1、gen_data.py的输入支持trans_a和trans_b但54_ascend950_fp4_mx_matmul可执行文件不支持仅仅是trans_a为0及trans_b为1的example示例。若要对应转置情况请修改example示例中的layout因为layout隐式表征转置状态即layout::RowMajor表示不转置layout::ColumnMajor表示转置。其对应关系如下表trans_atrans_bLayoutALayoutB00layout::RowMajorlayout::RowMajor01layout::RowMajorlayout::ColumnMajor10layout::ColumnMajorlayout::RowMajor11layout::ColumnMajorlayout::ColumnMajor2、 本example完成mx量化矩阵乘 C (MxScaleA x A) * (MxScaleB x B) Bias A、B支持数据类型为float4_e1m2或float4_e2m1 MxScaleA、MxScaleB支持数据类型为float8_e8m0其中对于MxScaleA、MxScaleB的数据排布要求如下 当A为RowMajor时MxScaleA的shape为m, ceil(k/64), 2 当A为ColumnMajor时MxScaleA的shape为ceil(k/64), m, 2 当B为RowMajor时MxScaleB的shape为ceil(k/64), n, 2 当B为ColumnMajor时MxScaleB的shape为n, ceil(k/64), 23、MxMatmulTla与BlockMmadTla搭配使用的 DispatchPolicy 为Gemm::MmadMx定义见include/catlass/gemm/dispatch_policy.hpp模板参数顺序与默认值如下模板参数默认值参数说明ArchTag无架构标签例如Arch::Ascend950ENABLE_UNIT_FLAGfalse是否开启 UnitFlag当L0C_STAGES 1L0C 多缓冲时必须为falseL1_SCALE_FACTOR_K16GM→L1 的 MX scale 一次驻留所覆盖的L1 K 方向条带个数为1时表示每个 L1 K 条带各搬一次 scale见类型内注释L0C_STAGES1L0C 缓冲段数设为2可开启 L0C 双缓冲需与ENABLE_UNIT_FLAG约束一致ENABLE_L1_RESIDENTfalse是否开启 L1 常驻L1A_STAGES2L1 上加载矩阵 A 的 buffer 数量L1B_STAGES2L1 上加载矩阵 B 的 buffer 数量L0A_STAGES2L0 上加载矩阵 A 的 buffer 数量L0B_STAGES2L0 上加载矩阵 B 的 buffer 数量设矩阵Shape为M N K, L1上的分块大小为m1 n1 k1M方向的分块数量mTiles CeilDiv(M, m1)N方向的分块数量nTiles CeilDiv(N, n1)总任务数为taskBlocks mTiles * nTiles在以下两种情况下可以选择开启enableL1Resident1.mTiles 1且nTiles CoreNum且K 2 * k1。此时还可以设置l0CStages2(需要关闭enableUnitFlag)如果空间不足无法设置l0CStages2则将n1设置为原来的一半。2.nTiles 1且mTiles CoreNum, 且K 2 * k1。此时还可以设置l0CStages2(需要关闭enableUnitFlag)如果空间不足无法设置l0CStages2则将m1设置为原来的一半。【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考