[transform][bf16] Enable bf16 emulation only for non-native bf16 ops. (#874)

Emulate bf16 ops by extending them to f32 and truncating the result back to bf16, for ops whose SPIR-V counterpart is not natively supported.
Showing 15 changed files with 333 additions and 10 deletions.
@@ -0,0 +1,155 @@
//===- EmulateNonNativeBF16.cpp -
// Emulate bf16 for ops that don't support the native bf16 data type pass
// -------------------*- C++ -*-===//
//
// Copyright 2022 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This pass iterates over gpu.func ops starting from the top-level module
/// and emulates bf16 ops whose SPIR-V counterparts are not natively
/// supported, by extending them to f32 and truncating the results back to
/// bf16.
///
//===----------------------------------------------------------------------===//

#include "imex/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include <mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>

#include <algorithm>
#include <unordered_set>
#include <vector>

namespace imex {
#define GEN_PASS_DEF_EMULATENONNATIVEBF16
#include "imex/Transforms/Passes.h.inc"
} // namespace imex

using namespace mlir;
using namespace imex;

namespace {
struct EmulateNonNativeBF16Pass
    : public imex::impl::EmulateNonNativeBF16Base<EmulateNonNativeBF16Pass> {

public:
  void runOnOperation() override {
    auto mod = getOperation();
    SymbolTable symbolTable(mod);
    mlir::OpBuilder builder(mod);
    // Walk every gpu::GPUFuncOp in the module.
    (void)mod.walk<WalkOrder::PreOrder>([&](gpu::GPUFuncOp op) -> WalkResult {
      // Step 1: Collect the ops that need bf16 widening and widen them.
      // Most ops in the arith and math dialects that have a bf16 operand
      // will be widened to use f32 operands.

      // ATTENTION: Please be aware that this pass is specifically intended
      // for OpenCL kernels.

      // Skip widening of ops whose lowered SPIR-V counterpart is natively
      // supported. Keep in mind that native bf16 support is determined by
      // the SPIR-V ops, not the arith or math ops themselves. As a result,
      // the skipped ops are identified based on their current upstream
      // lowering to SPIR-V ops; if that lowering changes in the future,
      // this list may also need to be updated.

      // @TODO: Make this an arch-specific list and move it to XeArch.h/cpp
      std::unordered_set<std::string> nativelySupportedOps{
          "arith.bitcast", "arith.extf", "arith.truncf", "arith.addf",
          "arith.mulf", "arith.subf", "arith.divf", "arith.maximumf",
          "arith.minnumf", "arith.maxnumf", "arith.uitofp", "arith.sitofp",
          "arith.fptoui", "arith.fptosi", "math.absf", "math.fma",
          "math.tanh"};
      SmallVector<Operation *, 8> widenOps;
      (void)op.getRegion().walk<WalkOrder::PreOrder>(
          [&](Operation *lop) -> WalkResult {
            auto oname = lop->getName().getStringRef();
            // Skip the natively supported operations.
            if (auto nativeop = nativelySupportedOps.find(oname.str());
                nativeop != nativelySupportedOps.end())
              return WalkResult::skip();

            // For arith and math ops whose lowered SPIR-V counterpart is not
            // natively supported, emulate them with an f32 upconvert and a
            // bf16 downconvert.
            auto needWidening = false;
            if (oname.starts_with("arith.") || oname.starts_with("math.")) {
              for (const auto &oper : lop->getOperands()) {
                if (auto vecTy = mlir::dyn_cast<VectorType>(oper.getType())) {
                  if (vecTy.getElementType().isBF16()) {
                    needWidening = true;
                  }
                } else if (oper.getType().isBF16()) {
                  needWidening = true;
                }
              }
              if (needWidening) {
                widenOps.push_back(lop);
              }
            }
            return WalkResult::advance();
          });
      for (Operation *o : widenOps) {
        builder.setInsertionPoint(o);
        unsigned int idx = 0;
        // Extend each bf16 operand (scalar or vector) to f32.
        for (const auto &oper : o->getOperands()) {
          if (auto vecTy = mlir::dyn_cast<VectorType>(oper.getType())) {
            if (vecTy.getElementType().isBF16()) {
              auto newTy =
                  VectorType::get(vecTy.getShape(), builder.getF32Type());
              auto newOp =
                  builder.create<arith::ExtFOp>(o->getLoc(), newTy, oper);
              o->setOperand(idx, newOp);
            }
          } else if (oper.getType().isBF16()) {
            auto newOp = builder.create<arith::ExtFOp>(
                o->getLoc(), builder.getF32Type(), oper);
            o->setOperand(idx, newOp);
          }
          idx++;
        }
        // Retype each bf16 result to f32 and truncate it back to bf16.
        for (mlir::OpResult res : o->getResults()) {
          if (auto vecTy = mlir::dyn_cast<VectorType>(res.getType())) {
            if (vecTy.getElementType().isBF16()) {
              auto resTy =
                  VectorType::get(vecTy.getShape(), builder.getF32Type());
              res.setType(resTy);
              builder.setInsertionPointAfter(o);
              auto newTy =
                  VectorType::get(vecTy.getShape(), builder.getBF16Type());
              auto newRes =
                  builder.create<arith::TruncFOp>(o->getLoc(), newTy, res);
              // Reroute all uses of the f32 result to the truncated bf16
              // value, except the truncf op itself.
              res.replaceAllUsesExcept(newRes, newRes);
            }
          } else if (res.getType().isBF16()) {
            res.setType(builder.getF32Type());
            builder.setInsertionPointAfter(o);
            auto newRes = builder.create<arith::TruncFOp>(
                o->getLoc(), builder.getBF16Type(), res);
            res.replaceAllUsesExcept(newRes, newRes);
          }
        }
      }

      return WalkResult::advance();
    });
  }
};
} // namespace

namespace imex {
std::unique_ptr<mlir::Pass> createEmulateNonNativeBF16Pass() {
  return std::make_unique<EmulateNonNativeBF16Pass>();
}
} // namespace imex
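For illustration, a minimal before/after sketch of the rewrite this pass performs, assuming a scalar math.exp op (it is absent from the natively supported list above, so it gets emulated); the SSA value names are hypothetical:

// Before: math.exp has a bf16 operand and is not in the skip list.
%r = math.exp %x : bf16

// After: the operand is extended to f32, the op is retyped to f32, and the
// result is truncated back to bf16 for all downstream uses.
%0 = arith.extf %x : bf16 to f32
%1 = math.exp %0 : f32
%r = arith.truncf %1 : f32 to bf16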
@@ -0,0 +1,19 @@
// RUN: imex-opt %s --imex-emulate-non-native-bf16 | FileCheck %s

module @bf16_constants {
  gpu.module @test_kernel attributes {} {
    gpu.func @test_kernel(%arg0: memref<10x10xbf16>) kernel attributes {} {
      %cst0 = arith.constant 0 : index
      %0 = gpu.block_id x
      %1 = gpu.block_id y
      // CHECK: arith.constant 2.000000e+00 : bf16
      %2 = arith.constant 2.0 : bf16
      // CHECK: arith.constant dense<1.000000e+00> : vector<10xbf16>
      %3 = arith.constant dense<1.0> : vector<10xbf16>
      %4 = arith.addf %2, %2 : bf16
      vector.store %3, %arg0[%1, %cst0] : memref<10x10xbf16>, vector<10xbf16>
      memref.store %4, %arg0[%1, %0] : memref<10x10xbf16>
      gpu.return
    }
  }
}
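Both constants are expected to stay in bf16: arith.constant has no operands to widen, and arith.addf is in the natively supported list, so the pass skips both.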
@@ -0,0 +1,16 @@
// RUN: imex-opt %s --imex-emulate-non-native-bf16 | FileCheck %s

module @bf16_constants {
  gpu.module @test_kernel attributes {} {
    gpu.func @test_kernel(%arg0: memref<10x10xbf16>, %arg1: memref<10x10xf32>) kernel attributes {} {
      %cst0 = arith.constant 0 : index
      // CHECK: %[[LOAD:.*]] = vector.load %arg0[%c0, %c0] : memref<10x10xbf16>, vector<10x10xbf16>
      // CHECK: arith.extf %[[LOAD]] : vector<10x10xbf16> to vector<10x10xf32>
      %2 = vector.load %arg0[%cst0, %cst0] : memref<10x10xbf16>, vector<10x10xbf16>
      %3 = arith.extf %2 : vector<10x10xbf16> to vector<10x10xf32>

      vector.store %3, %arg1[%cst0, %cst0] : memref<10x10xf32>, vector<10x10xf32>
      gpu.return
    }
  }
}
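This test effectively guards against redundant emulation: vector.load is outside the arith and math dialects, and arith.extf is natively supported, so the pass should not stack another extf/truncf pair on top of the existing widening.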
@@ -0,0 +1,55 @@
// RUN: %python_executable %imex_runner -i %s --pass-pipeline-file=%p/imex-emulate-non-native-bf16.pp \
// RUN: --no-mlir-runner --filecheck

module @gemm attributes {gpu.container_module} {
  memref.global "private" constant @__constant_3x3xbf16_1 : memref<3x3xbf16> = dense<1.000000e+00>
  memref.global "private" constant @__constant_3x3xbf16_0 : memref<3x3xbf16> = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [3.000000e+00, 4.000000e+00, 5.000000e-01], [3.000000e+00, 3.000000e+00, 3.000000e+00]]>
  memref.global "private" constant @__constant_3x3xbf16 : memref<3x3xbf16> = dense<[[5.000000e-01, 2.001950e-01, 4.000000e+00], [1.000000e+00, 1.000000e+00, 2.000000e+00], [3.000000e+00, 3.000000e+00, 3.007810e-01]]>
  func.func @main() {
    %0 = memref.get_global @__constant_3x3xbf16 : memref<3x3xbf16>
    %1 = memref.get_global @__constant_3x3xbf16_0 : memref<3x3xbf16>
    %2 = memref.get_global @__constant_3x3xbf16_1 : memref<3x3xbf16>
    %3 = call @test(%0, %1, %2) : (memref<3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>) -> memref<3x3xbf16>
    %cast = memref.cast %3 : memref<3x3xbf16> to memref<*xbf16>
    call @printMemrefBF16(%cast) : (memref<*xbf16>) -> ()
    return
  }
  func.func private @printMemrefBF16(memref<*xbf16>)
  func.func @test(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>, %arg2: memref<3x3xbf16>) -> memref<3x3xbf16> {
    %c3 = arith.constant 3 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %memref = gpu.alloc host_shared () : memref<3x3xbf16>
    memref.copy %arg1, %memref : memref<3x3xbf16> to memref<3x3xbf16>
    %memref_0 = gpu.alloc host_shared () : memref<3x3xbf16>
    memref.copy %arg0, %memref_0 : memref<3x3xbf16> to memref<3x3xbf16>
    %memref_1 = gpu.alloc host_shared () : memref<3x3xbf16>
    memref.copy %arg2, %memref_1 : memref<3x3xbf16> to memref<3x3xbf16>
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c3, %c3, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<3x3xbf16>, %memref : memref<3x3xbf16>, %memref_1 : memref<3x3xbf16>, %c0 : index, %c3 : index, %c1 : index)
    gpu.dealloc %memref_0 : memref<3x3xbf16>
    gpu.dealloc %memref : memref<3x3xbf16>
    return %memref_1 : memref<3x3xbf16>
  }
  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume]>, api=OpenCL, #spirv.resource_limits<>>} {
    // CHECK: gpu.func @test_kernel(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>, %arg2: memref<3x3xbf16>, %arg3: index, %arg4: index, %arg5: index)
    gpu.func @test_kernel(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>, %arg2: memref<3x3xbf16>, %arg3: index, %arg4: index, %arg5: index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 3, 3, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %0 = gpu.block_id x
      %1 = gpu.block_id y
      scf.for %arg6 = %arg3 to %arg4 step %arg5 {
        // CHECK: %[[VAR2:.*]] = memref.load %arg0[%[[VAR0:.*]], %arg6] : memref<3x3xbf16>
        // CHECK: %[[VAR3:.*]] = memref.load %arg1[%arg6, %[[VAR1:.*]]] : memref<3x3xbf16>
        // CHECK: %[[VAR4:.*]] = memref.load %arg2[%[[VAR0]], %[[VAR1]]] : memref<3x3xbf16>
        %2 = memref.load %arg0[%0, %arg6] : memref<3x3xbf16>
        %3 = memref.load %arg1[%arg6, %1] : memref<3x3xbf16>
        %4 = memref.load %arg2[%0, %1] : memref<3x3xbf16>
        // CHECK: %[[VAR5:.*]] = arith.mulf %[[VAR2]], %[[VAR3]] : bf16
        %5 = arith.mulf %2, %3 : bf16
        // CHECK: %[[VAR6:.*]] = arith.addf %[[VAR4]], %[[VAR5]] : bf16
        %6 = arith.addf %4, %5 : bf16
        // CHECK: memref.store %[[VAR6]], %arg2[%[[VAR0]], %[[VAR1]]] : memref<3x3xbf16>
        memref.store %6, %arg2[%0, %1] : memref<3x3xbf16>
      }
      gpu.return
    }
  }
}
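Since arith.mulf and arith.addf are both in the natively supported list, the CHECK lines confirm that the GEMM kernel's bf16 arithmetic passes through the pipeline unchanged.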