[transform][bf16] Enable bf16 emulation only for non-native bf16 ops. (#874)


Emulate bf16 ops whose SPIR-V counterparts are not natively supported by
extending them to f32 and truncating the results back to bf16.
mshahneo committed Sep 11, 2024
1 parent d9e0bff commit c28dab2
Showing 15 changed files with 333 additions and 10 deletions.
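For illustration, the emulation introduced by this change has the following shape; math.exp is used as a hypothetical non-native op here, since the exact op set is determined by the upstream SPIR-V lowering:

// Before: the op has no natively supported bf16 SPIR-V counterpart.
%r = math.exp %x : bf16

// After: extend to f32, compute in f32, truncate back to bf16.
%x_f32 = arith.extf %x : bf16 to f32
%r_f32 = math.exp %x_f32 : f32
%r = arith.truncf %r_f32 : f32 to bf16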
1 change: 1 addition & 0 deletions include/imex/Transforms/Passes.h
@@ -39,6 +39,7 @@ std::unique_ptr<mlir::Pass> createRemoveSingleElemVectorPass();
std::unique_ptr<mlir::Pass> createOptimizeTransposePass();
std::unique_ptr<mlir::Pass> createHoistTransposePass();
std::unique_ptr<mlir::Pass> createVnniTransformationPass();
std::unique_ptr<mlir::Pass> createEmulateNonNativeBF16Pass();

#define GEN_PASS_DECL
#include "imex/Transforms/Passes.h.inc"
18 changes: 17 additions & 1 deletion include/imex/Transforms/Passes.td
@@ -100,7 +100,7 @@ def BF16ToGPU : Pass<"bf16-to-gpu", "::mlir::ModuleOp"> {
and f32 dtype that can be lowered to spirv dialect with Intel spirv extension ops.
bf16 is bitcast to a bitwidth equal type i16 as bf16 is not a supported type
in spirv.
- Computation is replace by first extending bf16 to f32, do the compute in f32
+ Computation is replaced by first extending bf16 to f32, do the compute in f32
and truncate result back to bf16.
}];
let constructor = "imex::createBF16ToGPUPass()";
@@ -125,6 +125,22 @@ def CastIndex : Pass<"cast-index", "::mlir::ModuleOp"> {
];
}

def EmulateNonNativeBF16 : Pass<"imex-emulate-non-native-bf16", "::mlir::gpu::GPUModuleOp"> {
let summary = "transform gpu.func with bf16 emulation (upconvert and downconvert) for ops that are not natively supported";
let description = [{
This pass transforms a set of operations inside gpu.func
whose lowered SPIR-V counterparts do not natively support the bf16 data type.
For the unsupported ops, the computation is replaced by first extending
bf16 operands to f32, performing the computation in f32, and truncating
the result back to bf16 when appropriate.
}];
let constructor = "imex::createEmulateNonNativeBF16Pass()";
let dependentDialects = [
"::mlir::gpu::GPUDialect",
"::mlir::memref::MemRefDialect",
"::mlir::arith::ArithDialect"
];
}

def RemoveTemporaries : Pass<"imex-remove-temporaries"> {
let summary = "Remove redundant memref.alloc and memref.copy operations";
let description = [{
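As a usage sketch mirroring the RUN lines of the new tests, the pass can be exercised standalone through imex-opt (input.mlir is a placeholder file name):

imex-opt input.mlir --imex-emulate-non-native-bf16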
1 change: 1 addition & 0 deletions lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_library(IMEXTransforms
AddOuterParallelLoop.cpp
BF16ToGPU.cpp
CastIndex.cpp
EmulateNonNativeBF16.cpp
InsertGPUAllocs.cpp
LowerMemRefCopy.cpp
PropagatePackedLayout.cpp
155 changes: 155 additions & 0 deletions lib/Transforms/EmulateNonNativeBF16.cpp
@@ -0,0 +1,155 @@
//===- EmulateNonNativeBF16.cpp -
// Emulate bf16 for ops that don't natively support the bf16 data type
// -------------------*- C++ -*-===//
//
// Copyright 2022 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This pass iterates over gpu.func ops starting from the top-level module
/// and emulates bf16 ops whose SPIR-V counterparts are not natively
/// supported by extending them to f32 and truncating the results back to
/// bf16.
///
//===----------------------------------------------------------------------===//

#include "imex/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include <mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>

#include <algorithm>
#include <unordered_set>
#include <vector>

namespace imex {
#define GEN_PASS_DEF_EMULATENONNATIVEBF16
#include "imex/Transforms/Passes.h.inc"
} // namespace imex

using namespace mlir;
using namespace imex;

namespace {
struct EmulateNonNativeBF16Pass
: public imex::impl::EmulateNonNativeBF16Base<EmulateNonNativeBF16Pass> {

public:
void runOnOperation() override {
auto mod = getOperation();
SymbolTable symbolTable(mod);
mlir::OpBuilder builder(mod);
// gpu::GPUFuncOp
(void)mod.walk<WalkOrder::PreOrder>([&](gpu::GPUFuncOp op) -> WalkResult {
// 1: Collect the ops that need bf16 widening and widen them.
// Most ops in the arith and math dialects that have a bf16 operand
// will be widened to use f32 operands.

// ATTENTION: Please be aware that this pass is specifically intended
// for OpenCL kernels.

// Skip widening of ops whose lowered SPIR-V counterparts are natively
// supported.
// Keep in mind that native bf16 support is defined in terms of SPIR-V
// ops, not arith or math ops. The skipped ops are therefore identified
// based on their current upstream lowering to SPIR-V, so if that
// lowering changes in the future, this list may need to be updated.

// @TODO: Make this an arch-specific list and move it to XeArch.h/cpp
std::unordered_set<std::string> nativelySupportedOps{
"arith.bitcast", "arith.extf", "arith.truncf", "arith.addf",
"arith.mulf", "arith.subf", "arith.divf", "arith.maximumf",
"arith.minnumf", "arith.maxnumf", "arith.uitofp", "arith.sitofp",
"arith.fptoui", "arith.fptosi", "math.absf", "math.fma",
"math.tanh"};
SmallVector<Operation *, 8> widenOps;
(void)op.getRegion().walk<WalkOrder::PreOrder>(
[&](Operation *lop) -> WalkResult {
auto oname = lop->getName().getStringRef();
// Skip the natively supported operations
if (auto nativeop = nativelySupportedOps.find(oname.str());
nativeop != nativelySupportedOps.end())
return WalkResult::skip();

// For arith and math ops whose lowered SPIR-V counterparts are not
// natively supported, emulate them with an f32 upconvert and a bf16
// downconvert
auto needWidening = false;
if (oname.starts_with("arith.") || oname.starts_with("math.")) {
for (const auto &oper : lop->getOperands()) {
if (auto vecTy = mlir::dyn_cast<VectorType>(oper.getType())) {
if (vecTy.getElementType().isBF16()) {
needWidening = true;
}
} else if (oper.getType().isBF16()) {
needWidening = true;
}
}
if (needWidening) {
widenOps.push_back(lop);
}
}
return WalkResult::advance();
});
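// 2: Widen the collected ops in place: extend each bf16 operand to f32
// with an arith.extf inserted right before the op.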
for (Operation *o : widenOps) {
builder.setInsertionPoint(o);
unsigned int idx = 0;
for (const auto &oper : o->getOperands()) {
if (auto vecTy = mlir::dyn_cast<VectorType>(oper.getType())) {
if (vecTy.getElementType().isBF16()) {
auto newTy =
VectorType::get(vecTy.getShape(), builder.getF32Type());
auto newOp =
builder.create<arith::ExtFOp>(o->getLoc(), newTy, oper);
o->setOperand(idx, newOp);
}
} else if (oper.getType().isBF16()) {
auto newOp = builder.create<arith::ExtFOp>(
o->getLoc(), builder.getF32Type(), oper);
o->setOperand(idx, newOp);
}
idx++;
}
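// Retype each bf16 result to f32 and insert an arith.truncf right after
// the op; all remaining users of the result are redirected to the
// truncated bf16 value.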
for (mlir::OpResult res : o->getResults()) {
if (auto vecTy = mlir::dyn_cast<VectorType>(res.getType())) {
if (vecTy.getElementType().isBF16()) {
auto resTy =
VectorType::get(vecTy.getShape(), builder.getF32Type());
res.setType(resTy);
builder.setInsertionPointAfter(o);
auto newTy =
VectorType::get(vecTy.getShape(), builder.getBF16Type());
auto newRes =
builder.create<arith::TruncFOp>(o->getLoc(), newTy, res);
res.replaceAllUsesExcept(newRes, newRes);
}
} else if (res.getType().isBF16()) {
res.setType(builder.getF32Type());
builder.setInsertionPointAfter(o);
auto newRes = builder.create<arith::TruncFOp>(
o->getLoc(), builder.getBF16Type(), res);
res.replaceAllUsesExcept(newRes, newRes);
}
}
}

return WalkResult::advance();
});
}
};
} // namespace

namespace imex {
std::unique_ptr<mlir::Pass> createEmulateNonNativeBF16Pass() {
return std::make_unique<EmulateNonNativeBF16Pass>();
}
} // namespace imex
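As an illustrative sketch of the vector path above (math.sqrt stands in as a hypothetical non-native op), the pass rewrites

%r = math.sqrt %v : vector<8xbf16>

into

%v_f32 = arith.extf %v : vector<8xbf16> to vector<8xf32>
%r_f32 = math.sqrt %v_f32 : vector<8xf32>
%r = arith.truncf %r_f32 : vector<8xf32> to vector<8xbf16>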
2 changes: 1 addition & 1 deletion test/Integration/Dialect/Gpu/EltwiseAdd_BF16.mlir
@@ -26,7 +26,7 @@ module @eltwise_add attributes {gpu.container_module} {
gpu.dealloc %memref : memref<10x20xbf16>
return %alloc : memref<10x20xbf16>
}
- gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Bfloat16ConversionINTEL, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, VectorAnyINTEL, BFloat16TypeKHR], [SPV_INTEL_bfloat16_conversion, SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_KHR_bfloat16]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 10, 20, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%block_id_x = gpu.block_id x
%block_id_y = gpu.block_id y
1 change: 1 addition & 0 deletions test/Integration/Dialect/Gpu/gpu-to-llvm.pp
@@ -5,6 +5,7 @@
builtin.module(
imex-vector-linearize
reconcile-unrealized-casts
bf16-to-gpu
imex-convert-gpu-to-spirv
spirv.module(spirv-lower-abi-attrs
spirv-update-vce)
2 changes: 1 addition & 1 deletion test/Integration/Dialect/XeGPU/gemm_1024x1024xbf16.mlir
@@ -24,7 +24,7 @@ module @gemm attributes {gpu.container_module} {
gpu.dealloc %memref_0 : memref<1024x1024xbf16>
return %memref_1 : memref<1024x1024xf32>
}
- gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 128, 64, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
@@ -29,7 +29,7 @@ module @gemm attributes {gpu.container_module} {
return %C_gpu : memref<256x256xf32>
}

- gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, BFloat16TypeKHR, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_bfloat16, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%A: memref<256x256xbf16>, %B: memref<256x256xbf16>, %C: memref<256x256xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
// constants
%c256 = arith.constant 256 : index
@@ -35,7 +35,7 @@ module @gemm attributes {gpu.container_module} {

}

- gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, BFloat16TypeKHR, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_bfloat16, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
gpu.func @test_kernel(%A: memref<4096x4096xbf16>, %B: memref<4096x4096xbf16>, %C: memref<4096x4096xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
// constants
%c256 = arith.constant 256 : index
5 changes: 0 additions & 5 deletions test/Transforms/BF16ToGPU/lit.local.cfg

This file was deleted.

19 changes: 19 additions & 0 deletions test/Transforms/EmulateNonNativeBF16/Constants.bf16.mlir
@@ -0,0 +1,19 @@
// RUN: imex-opt %s --imex-emulate-non-native-bf16 | FileCheck %s

module @bf16_constants {
gpu.module @test_kernel attributes {} {
gpu.func @test_kernel(%arg0: memref<10x10xbf16>) kernel attributes {} {
%cst0 = arith.constant 0 : index
%0 = gpu.block_id x
%1 = gpu.block_id y
// CHECK: arith.constant 2.000000e+00 : bf16
%2 = arith.constant 2.0 : bf16
// CHECK: arith.constant dense<1.000000e+00> : vector<10xbf16>
%3 = arith.constant dense<1.0> : vector<10xbf16>
%4 = arith.addf %2, %2 : bf16
vector.store %3, %arg0[%1, %cst0] : memref<10x10xbf16>, vector<10xbf16>
memref.store %4, %arg0[%1, %0] : memref<10x10xbf16>
gpu.return
}
}
}
16 changes: 16 additions & 0 deletions test/Transforms/EmulateNonNativeBF16/Extf.bf16.mlir
@@ -0,0 +1,16 @@
// RUN: imex-opt %s --imex-emulate-non-native-bf16 | FileCheck %s

module @bf16_constants {
gpu.module @test_kernel attributes {} {
gpu.func @test_kernel(%arg0: memref<10x10xbf16>, %arg1: memref<10x10xf32>) kernel attributes {} {
%cst0 = arith.constant 0 : index
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
// CHECK: arith.extf %[[LOAD]] : vector<10x10xbf16> to vector<10x10xf32>
%2 = vector.load %arg0[%cst0, %cst0] : memref<10x10xbf16>, vector<10x10xbf16>
%3 = arith.extf %2 : vector<10x10xbf16> to vector<10x10xf32>

vector.store %3, %arg1[%cst0, %cst0] : memref<10x10xf32>, vector<10x10xf32>
gpu.return
}
}
}
55 changes: 55 additions & 0 deletions test/Transforms/EmulateNonNativeBF16/GEMM.bf16.mlir
@@ -0,0 +1,55 @@
// RUN: %python_executable %imex_runner -i %s --pass-pipeline-file=%p/imex-emulate-non-native-bf16.pp \
// RUN: --no-mlir-runner --filecheck

module @gemm attributes {gpu.container_module} {
memref.global "private" constant @__constant_3x3xbf16_1 : memref<3x3xbf16> = dense<1.000000e+00>
memref.global "private" constant @__constant_3x3xbf16_0 : memref<3x3xbf16> = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e-01], [3.000000e+00, 3.000000e+00, 3.000000e+00]]>
memref.global "private" constant @__constant_3x3xbf16 : memref<3x3xbf16> = dense<[[5.000000e-01, 2.001950e-01, 4.000000e+00], [1.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 3.000000e+00, 3.007810e-01]]>
func.func @main() {
%0 = memref.get_global @__constant_3x3xbf16 : memref<3x3xbf16>
%1 = memref.get_global @__constant_3x3xbf16_0 : memref<3x3xbf16>
%2 = memref.get_global @__constant_3x3xbf16_1 : memref<3x3xbf16>
%3 = call @test(%0, %1, %2) : (memref<3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>) -> memref<3x3xbf16>
%cast = memref.cast %3 : memref<3x3xbf16> to memref<*xbf16>
call @printMemrefBF16(%cast) : (memref<*xbf16>) -> ()
return
}
func.func private @printMemrefBF16(memref<*xbf16>)
func.func @test(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>, %arg2: memref<3x3xbf16>) -> memref<3x3xbf16> {
%c3 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%memref = gpu.alloc host_shared () : memref<3x3xbf16>
memref.copy %arg1, %memref : memref<3x3xbf16> to memref<3x3xbf16>
%memref_0 = gpu.alloc host_shared () : memref<3x3xbf16>
memref.copy %arg0, %memref_0 : memref<3x3xbf16> to memref<3x3xbf16>
%memref_1 = gpu.alloc host_shared () : memref<3x3xbf16>
memref.copy %arg2, %memref_1 : memref<3x3xbf16> to memref<3x3xbf16>
gpu.launch_func @test_kernel::@test_kernel blocks in (%c3, %c3, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<3x3xbf16>, %memref : memref<3x3xbf16>, %memref_1 : memref<3x3xbf16>, %c0 : index, %c3 : index, %c1 : index)
gpu.dealloc %memref_0 : memref<3x3xbf16>
gpu.dealloc %memref : memref<3x3xbf16>
return %memref_1 : memref<3x3xbf16>
}
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>} {
// CHECK: gpu.func @test_kernel(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>, %arg2: memref<3x3xbf16>, %arg3: index, %arg4: index, %arg5: index)
gpu.func @test_kernel(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>, %arg2: memref<3x3xbf16>, %arg3: index, %arg4: index, %arg5: index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 3, 3, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%0 = gpu.block_id x
%1 = gpu.block_id y
scf.for %arg6 = %arg3 to %arg4 step %arg5 {
// CHECK: %[[VAR2:.*]] = memref.load %arg0[%[[VAR0:.*]], %arg6] : memref<3x3xbf16>
// CHECK: %[[VAR3:.*]] = memref.load %arg1[%arg6, %[[VAR1:.*]]] : memref<3x3xbf16>
// CHECK: %[[VAR4:.*]] = memref.load %arg2[%[[VAR0]], %[[VAR1]]] : memref<3x3xbf16>
%2 = memref.load %arg0[%0, %arg6] : memref<3x3xbf16>
%3 = memref.load %arg1[%arg6, %1] : memref<3x3xbf16>
%4 = memref.load %arg2[%0, %1] : memref<3x3xbf16>
// CHECK: %[[VAR5:.*]] = arith.mulf %[[VAR2]], %[[VAR3]] : bf16
%5 = arith.mulf %2, %3 : bf16
// CHECK: %[[VAR6:.*]] = arith.addf %[[VAR4]], %[[VAR5]] : bf16
%6 = arith.addf %4, %5 : bf16
// CHECK: memref.store %[[VAR6]], %arg2[%[[VAR0]], %[[VAR1]]] : memref<3x3xbf16>
memref.store %6, %arg2[%0, %1] : memref<3x3xbf16>
}
gpu.return
}
}
}