Add type conversion for unranked memref type. (#862)
Add a ranked dynamic memref test case.
Add an unranked memref test case.
Exclude the unranked memref case for now, as host code lowering has some
issues.
silee2 committed Sep 6, 2024
1 parent 48e567a commit 4fe5852
Showing 4 changed files with 146 additions and 1 deletion.
26 changes: 26 additions & 0 deletions lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -17,6 +17,7 @@

#include "../PassDetail.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include <llvm/ADT/ArrayRef.h>
#include <llvm/Support/Debug.h>
@@ -417,6 +418,31 @@ void GPUXToSPIRVPass::runOnOperation() {
    return isGenericVectorTy(op.getType());
  });

  // The upstream SPIRVTypeConverter does not register a conversion for
  // UnrankedMemRefType. For OpenCL kernels, the conversion logic is the same
  // as for ranked dynamic memref types: the unranked memref is converted to
  // a SPIR-V pointer type carrying the converted SPIR-V scalar element type
  // and the SPIR-V storage class. Only scalar element types are currently
  // supported. Vulkan would need different handling, but that is out of
  // scope here: this conversion pass lowers to OpenCL SPIR-V kernels only.
  typeConverter.addConversion(
      [&](mlir::UnrankedMemRefType type) -> std::optional<mlir::Type> {
        auto attr = mlir::dyn_cast_or_null<mlir::spirv::StorageClassAttr>(
            type.getMemorySpace());
        if (!attr)
          return nullptr;
        mlir::spirv::StorageClass storageClass = attr.getValue();

        mlir::Type elementType = type.getElementType();
        auto scalarType = mlir::dyn_cast<mlir::spirv::ScalarType>(elementType);
        if (!scalarType)
          return nullptr;
        mlir::Type arrayElemType = typeConverter.convertType(scalarType);
        return mlir::spirv::PointerType::get(arrayElemType, storageClass);
      });

  //------- Upstream Conversion------------
  mlir::populateGPUToSPIRVPatterns(typeConverter, patterns);
  mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
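A minimal sketch of the mapping the new conversion implements, with illustrative types (CrossWorkgroup stands in for whatever storage class the memory-space attribute actually carries):

    memref<*xf32, #spirv.storage_class<CrossWorkgroup>>  -->  !spirv.ptr<f32, CrossWorkgroup>

Per the comment above, this is the same pointer type a ranked dynamic memref produces for an OpenCL kernel; the shape and stride information is simply not part of the converted type.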
3 changes: 2 additions & 1 deletion test/Integration/Dialect/XeGPU/lit.local.cfg
@@ -11,7 +11,8 @@ non_pvc_excludes = [
]

local_excludes = [
'gemm_SIMT_1024x1024x1024xf16_f16_f32.mlir'
'gemm_SIMT_1024x1024x1024xf16_f16_f32.mlir',
'unranked_memref.vc.mlir' # Host code lowering has issues; the generated SPIR-V binary is identical to ranked_dynamic_memref.vc.mlir.
]

if(not config.imex_enable_pvc_target):
58 changes: 58 additions & 0 deletions test/Integration/Dialect/XeGPU/ranked_dynamic_memref.vc.mlir
@@ -0,0 +1,58 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
module @gemm attributes {gpu.container_module} {
  func.func @test(%A : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %memref_0 = gpu.alloc host_shared () : memref<8x16xf32>
    memref.copy %A, %memref_0 : memref<8x16xf32> to memref<8x16xf32>
    %memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
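    // Cast the static allocations to dynamically shaped memrefs; the actual
    // shape and strides are passed to the kernel as separate index arguments.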
    %memref_0_cast = memref.cast %memref_0 : memref<8x16xf32> to memref<?x?xf32>
    %memref_1_cast = memref.cast %memref_1 : memref<8x16xf32> to memref<?x?xf32>
    %dim0 = arith.constant 8 : index
    %dim1 = arith.constant 16 : index
    %stride0 = arith.constant 16 : index
    %stride1 = arith.constant 1 : index
    %x = arith.constant 0 : index
    %y = arith.constant 0 : index
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0_cast : memref<?x?xf32>, %memref_1_cast : memref<?x?xf32>, %dim0 : index, %dim1 : index, %stride0 : index, %stride1 : index, %x : index, %y : index)
    gpu.dealloc %memref_0 : memref<8x16xf32>
    return %memref_1 : memref<8x16xf32>
  }
  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @test_kernel(%arg0 : memref<?x?xf32>, %arg1: memref<?x?xf32>, %dim0: index, %dim1: index, %stride0: index, %stride1: index, %x: index, %y: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %1 = xegpu.create_nd_tdesc %arg0[%x, %y], [%dim0, %dim1], [%stride0, %stride1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
      %2 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
      %6 = xegpu.create_nd_tdesc %arg1[%x, %y], [%dim0, %dim1], [%stride0, %stride1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
      xegpu.store_nd %2, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
      gpu.return
    }
  }
  func.func @main() attributes {llvm.emit_c_interface} {
    %A = memref.alloc() : memref<8x16xf32>
    %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32>
    %c_gen_int = arith.constant 0 : i1
    %cf_lower = arith.constant -0.5 : f32
    %cf_upper = arith.constant 0.5 : f32

    call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> ()

    %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32>
    %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32>
    %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32>
    // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> ()
    // CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> ()

    memref.dealloc %A : memref<8x16xf32>
    return
  }
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @fillResource1DRandomF32(memref<*xf32>, f32, f32, i1) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
60 changes: 60 additions & 0 deletions test/Integration/Dialect/XeGPU/unranked_memref.vc.mlir
@@ -0,0 +1,60 @@
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
module @gemm attributes {gpu.container_module} {
  func.func @test(%A : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %memref_0 = gpu.alloc host_shared () : memref<8x16xf32>
    memref.copy %A, %memref_0 : memref<8x16xf32> to memref<8x16xf32>
    %memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
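    // Cast to unranked memrefs; the kernel below takes memref<*xf32>
    // arguments, which exercises the new UnrankedMemRefType conversion.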
    %memref_0_cast = memref.cast %memref_0 : memref<8x16xf32> to memref<*xf32>
    %memref_1_cast = memref.cast %memref_1 : memref<8x16xf32> to memref<*xf32>
    %dim0 = arith.constant 8 : index
    %dim1 = arith.constant 16 : index
    %stride0 = arith.constant 16 : index
    %stride1 = arith.constant 1 : index
    %x = arith.constant 0 : index
    %y = arith.constant 0 : index
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0_cast : memref<*xf32>, %memref_1_cast : memref<*xf32>, %dim0 : index, %dim1 : index, %stride0 : index, %stride1 : index, %x : index, %y : index)
    gpu.dealloc %memref_0 : memref<8x16xf32>
    return %memref_1 : memref<8x16xf32>
  }
  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @test_kernel(%arg0 : memref<*xf32>, %arg1: memref<*xf32>, %dim0: index, %dim1: index, %stride0: index, %stride1: index, %x: index, %y: index) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
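      // XeGPU descriptors require ranked memrefs, so the kernel first casts
      // its unranked arguments back to ranked dynamic memrefs.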
      %ranked0 = memref.cast %arg0 : memref<*xf32> to memref<?x?xf32>
      %ranked1 = memref.cast %arg1 : memref<*xf32> to memref<?x?xf32>
      %1 = xegpu.create_nd_tdesc %ranked0[%x, %y], [%dim0, %dim1], [%stride0, %stride1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
      %2 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
      %6 = xegpu.create_nd_tdesc %ranked1[%x, %y], [%dim0, %dim1], [%stride0, %stride1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
      xegpu.store_nd %2, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
      gpu.return
    }
  }
  func.func @main() attributes {llvm.emit_c_interface} {
    %A = memref.alloc() : memref<8x16xf32>
    %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32>
    %c_gen_int = arith.constant 0 : i1
    %cf_lower = arith.constant -0.5 : f32
    %cf_upper = arith.constant 0.5 : f32

    call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> ()

    %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32>
    %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32>
    %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32>
    // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> ()
    // CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> ()

    memref.dealloc %A : memref<8x16xf32>
    return
  }
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @fillResource1DRandomF32(memref<*xf32>, f32, f32, i1) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
}
