You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
When performing blocked-pointer matrix multiplication on the advanced path (`os.environ['TRITON_INTEL_ADVANCED_PATH'] = '1'`), the results computed for A*(Transpose B) are incorrect: many of the entries in the result matrix are zero.
The result of A*(Transpose B) does, however, match the torch output on the default path.
Here is a reproducer that demonstrates the error.
reproducer.py
import torch
import triton
import triton.language as tl
import os
# Autotune over a few tile shapes; Triton benchmarks the candidate configs and
# re-tunes whenever the values of the `key` arguments (M, N, K) change.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=3,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
                      num_warps=32),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_with_block_pointers(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,  #
        # Triton dtype used for the accumulator tile (tl.zeros) and as tl.dot's out_dtype.
        ACCUMULATOR_DTYPE: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)

    Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C,
    looping over K in BLOCK_SIZE_K steps. Operands are addressed only through
    (shape, strides), so a transposed operand can be passed by supplying its
    strides in swapped order.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the matrix multiplication tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Programs are processed in groups covering GROUP_SIZE_M tile-rows of C.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may cover fewer than GROUP_SIZE_M tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    # Within a group, pid_m varies fastest, then pid_n advances.
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create block pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction and accumulate.
    # See the `Make a Block Pointer` section of the Triton block-pointer tutorial.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    # NOTE(review): `order=(1, 0)` is hard-coded even when the host passes a
    # transposed B (swapped strides, so dim 0 becomes the unit-stride one).
    # The reported advanced-path mismatch for A @ B^T may be related — unverified.
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
    for k in range(0, K, BLOCK_SIZE_K):
        # Load with boundary checks, no need to calculate the mask manually.
        # For better performance, you may remove some axis from the boundary
        # check, if you can guarantee that the access is always in-bound in
        # that axis.
        # See the `Load/Store a Block Pointer` section of the Triton block-pointer tutorial.
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b, out_dtype=ACCUMULATOR_DTYPE)
        # Advance the block pointer to the next K block.
        # See the `Advance a Block Pointer` section of the Triton block-pointer tutorial.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    # Cast the accumulator to C's element type before storing.
    c = accumulator.to(c_ptr.type.element_ty)
    # ----------------------------------------------------------------
    # Write back the block of the output matrix C with boundary checks.
    # See the `Load/Store a Block Pointer` section of the Triton block-pointer tutorial.
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
    tl.store(c_block_ptr, c, boundary_check=(0, 1))
def matmul(X, Y, accum_dtype, res_dtype, transpose_x=False, transpose_y=False):
    """Compute ``op(X) @ op(Y)`` with the block-pointer Triton kernel.

    ``op(T)`` is ``T.T`` when the matching ``transpose_*`` flag is set and
    ``T`` otherwise. The transpose is never materialized: the kernel receives
    the operand's strides in swapped order and reads the original storage as
    a transposed view.

    Args:
        X: 2-D, contiguous left operand. Shape (M, K), or (K, M) if
            ``transpose_x``.
        Y: 2-D, contiguous right operand. Shape (K, N), or (N, K) if
            ``transpose_y``.
        accum_dtype: torch dtype of the in-kernel accumulator.
        res_dtype: torch dtype of the returned matrix.
        transpose_x: treat X as transposed.
        transpose_y: treat Y as transposed.

    Returns:
        A new (M, N) tensor of dtype ``res_dtype`` on X's device.

    Raises:
        AssertionError: if an operand is not 2-D, the inner (K) dimensions of
            op(X) and op(Y) disagree, or an operand is not contiguous.
    """
    assert len(X.shape) == 2 and len(Y.shape) == 2, "Input matrixs dimensions mismatch"

    # Logical shape of op(X) and the strides that walk it as (M, K).
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    # Logical shape of op(Y) and the strides that walk it as (K, N).
    if transpose_y:
        N, K_y = Y.shape
        Ystride0, Ystride1 = Y.stride(1), Y.stride(0)
    else:
        K_y, N = Y.shape
        Ystride0, Ystride1 = Y.stride(0), Y.stride(1)

    # Compare inner dimensions *after* accounting for transposition; a raw
    # X.shape[1] == Y.shape[0] check would reject valid transposed inputs.
    assert K == K_y, "Incompatible dimensions"
    assert X.is_contiguous(), "Matrix A must be contiguous"
    assert Y.is_contiguous(), "Matrix B must be contiguous"

    # Allocate the output on the operands' device (not on some global tensor's).
    Z = torch.empty((M, N), device=X.device, dtype=res_dtype)
    # Map accumulator type, e.g. `torch.float16` -> `tl.fp16`
    # (str(torch.float16) == 'torch.float16'; [6:] drops the 'torch.' prefix).
    triton_accum_dtype = tl.dtype(str(accum_dtype)[6:].replace('bfloat', 'bf').replace('float', 'fp'))
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel_with_block_pointers[grid](
        X, Y, Z,  #
        M, N, K,  #
        Xstride0, Xstride1,  #
        Ystride0, Ystride1,  #
        Z.stride(0), Z.stride(1),  #
        ACCUMULATOR_DTYPE=triton_accum_dtype)
    return Z
# Problem size (M, K) and the dtypes shared by every check below.
shape = (512, 512)
dtype = accum_dtype = res_dtype = torch.float16

# A is (M, K) with random integers in [-127, 127]. B is K x K with ones on the
# main and both adjacent diagonals, so B is symmetric and both A @ B and
# A @ B.T are valid (M, K) products. Every output entry is a sum of at most
# three such integers (|value| <= 381), which fp16 represents exactly, so an
# exact torch.equal comparison is meaningful.
a = torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
k = shape[-1]
off_diagonal = torch.ones(k - 1, device='xpu', dtype=dtype)
b = (torch.eye(k, device='xpu', dtype=dtype)
     + torch.diag(off_diagonal, diagonal=1)
     + torch.diag(off_diagonal, diagonal=-1))


def _torch_reference(lhs, rhs):
    """Return ``lhs @ rhs`` computed by torch on the CPU, moved back to the XPU."""
    return torch.matmul(lhs.to(device='cpu', dtype=accum_dtype),
                        rhs.to(device='cpu', dtype=accum_dtype)).to(device='xpu', dtype=res_dtype)


# Default path: A @ B must match torch exactly.
triton_output = matmul(a, b, accum_dtype, res_dtype)
assert torch.equal(_torch_reference(a, b), triton_output), "Torch and triton dont match in default path"

# Advanced path: A @ B^T (B is passed untransposed; matmul swaps its strides).
# NOTE(review): the flag is set after the default-path kernel was already
# compiled in this process. The call below uses different strides for B,
# which presumably forces a fresh compilation that picks the flag up —
# confirm. Adding another default-path call with the same strides before
# this point could let a cached default-path binary be reused and hide the bug.
os.environ['TRITON_INTEL_ADVANCED_PATH'] = '1'
triton_output = matmul(a, b, accum_dtype, res_dtype, transpose_x=False, transpose_y=True)
assert torch.equal(_torch_reference(a, b.T), triton_output), "Torch and triton dont match in advanced path"
These are my hardware details:
LIBIGC1_VERSION=1.0.17193.16-950
LEVEL_ZERO_VERSION=1.3.30049.10-950
AGAMA_VERSION=950
GPU_DEVICE=Intel(R) Data Center GPU Max 1100
The text was updated successfully, but these errors were encountered:
When performing blocked-pointer matrix multiplication on the advanced path (`os.environ['TRITON_INTEL_ADVANCED_PATH'] = '1'`), the results computed for A*(Transpose B) are incorrect: many of the entries in the result matrix are zero. The result of A*(Transpose B) does, however, match the torch output on the default path. Here is a reproducer that demonstrates the error.
reproducer.py
These are my hardware details:
LIBIGC1_VERSION=1.0.17193.16-950
LEVEL_ZERO_VERSION=1.3.30049.10-950
AGAMA_VERSION=950
GPU_DEVICE=Intel(R) Data Center GPU Max 1100
The text was updated successfully, but these errors were encountered: