Computed results are incorrect with blocked-pointer matrix multiplication in TRITON_INTEL_ADVANCED_PATH #2209

Open
arunjose696 opened this issue on Sep 11, 2024 · 0 comments
Labels: bug (Something isn't working), codegen: gemm

When performing blocked-pointer matrix multiplication in the advanced path (os.environ['TRITON_INTEL_ADVANCED_PATH'] = '1'), the computed results for A * (transpose of B) are incorrect: many of the entries in the result matrix are computed as zero.

In the default path, the result of A * (transpose of B) matches the torch output.

Here is a reproducer that produces the error.

reproducer.py
import torch

import triton
import triton.language as tl
import os

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=3,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
                      num_warps=32),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=2,
                      num_warps=32),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_with_block_pointers(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,  #
        ACCUMULATOR_DTYPE: tl.constexpr,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the matrix multiplication tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create block pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction and accumulate.
    # See above `Make a Block Pointer` section for details.
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                    offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
                                    order=(1, 0))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                                    offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
                                    order=(1, 0))

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)
    for k in range(0, K, BLOCK_SIZE_K):
        # Load with boundary checks, no need to calculate the mask manually.
        # For better performance, you may remove some axis from the boundary
        # check, if you can guarantee that the access is always in-bound in
        # that axis.
        # See above `Load/Store a Block Pointer` section for details.
        a = tl.load(a_block_ptr, boundary_check=(0, 1))
        b = tl.load(b_block_ptr, boundary_check=(0, 1))
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b, out_dtype=ACCUMULATOR_DTYPE)
        # Advance the block pointer to the next K block.
        # See above `Advance a Block Pointer` section for details.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    c = accumulator.to(c_ptr.type.element_ty)
    # ----------------------------------------------------------------
    # Write back the block of the output matrix C with boundary checks.
    # See above `Load/Store a Block Pointer` section for details.
    c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                                    offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
                                    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
    tl.store(c_block_ptr, c, boundary_check=(0, 1))
    
def matmul(X, Y, accum_dtype, res_dtype, transpose_x=False, transpose_y=False):
    if len(X.shape) == 2 and len(Y.shape) == 2:
        assert X.shape[1] == Y.shape[0], "Incompatible dimensions"
        assert X.is_contiguous(), "Matrix A must be contiguous"
        assert Y.is_contiguous(), "Matrix B must be contiguous"
        if transpose_x:
            K, M = X.shape
            Xstride0, Xstride1 = X.stride(1), X.stride(0)
        else:
            M, K = X.shape
            Xstride0, Xstride1 = X.stride(0), X.stride(1)
        if transpose_y:
            N, _ = Y.shape
            Ystride0, Ystride1 = Y.stride(1), Y.stride(0)
        else:
            _, N = Y.shape
            Ystride0, Ystride1 = Y.stride(0), Y.stride(1)
        Z = torch.empty((M, N), device=X.device, dtype=res_dtype)

        # Map accumulator type, e.g. `torch.float16` -> `tl.fp16`
        triton_accum_dtype = tl.dtype(str(accum_dtype)[6:].replace('bfloat', 'bf').replace('float', 'fp'))
        # 1D launch kernel where each block gets its own program.
        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
        matmul_kernel_with_block_pointers[grid](
            X, Y, Z,  #
            M, N, K,  #
            Xstride0, Xstride1 ,  #
            Ystride0, Ystride1,  #
            Z.stride(0), Z.stride(1),  #
            ACCUMULATOR_DTYPE=triton_accum_dtype)
    else:
        assert False, "Input matrices must be 2-dimensional"
    return Z
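
# Note (illustration only, not part of the original failure): the wrapper above
# expresses transposition purely through strides. Passing
# (Y.stride(1), Y.stride(0)) as (stride_bk, stride_bn) makes the kernel walk Y
# as if it were Y.T, without materializing the transpose. A minimal CPU-side
# sketch of that stride identity; `_y` is a throwaway tensor used only here:
_y = torch.arange(6).reshape(2, 3)
assert _y.T.stride() == (_y.stride(1), _y.stride(0))
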
shape = (512, 512)
accum_dtype = res_dtype = dtype = torch.float16
a = torch.randint(low=-127, high=128, size=shape, device='xpu', dtype=dtype)
# b is a tridiagonal matrix of ones (note: symmetric, so b == b.T).
b = (torch.eye(shape[-2], device='xpu', dtype=dtype)
     + torch.diag(torch.ones(shape[-2] - 1, device='xpu', dtype=dtype), diagonal=1)
     + torch.diag(torch.ones(shape[-2] - 1, device='xpu', dtype=dtype), diagonal=-1))
# Default path: C = A @ B should match torch.
triton_output = matmul(a, b, torch.float16, torch.float16)
torch_output = torch.matmul(a.to(device='cpu', dtype=accum_dtype),
                            b.to(device='cpu', dtype=accum_dtype)).to(device='xpu', dtype=res_dtype)
assert torch.equal(torch_output, triton_output), "Torch and Triton don't match in the default path"


# Advanced path: C = A @ B.T (transpose_y=True). This is the assert that fails.
os.environ['TRITON_INTEL_ADVANCED_PATH'] = '1'
triton_output = matmul(a, b, torch.float16, dtype, False, True)
torch_output = torch.matmul(a.to(device='cpu', dtype=accum_dtype),
                            b.T.to(device='cpu', dtype=accum_dtype)).to(device='xpu', dtype=res_dtype)
assert torch.equal(torch_output, triton_output), "Torch and Triton don't match in the advanced path"
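
If it helps with triage, below is a small diagnostic one could run just before (or instead of) the final assert to quantify the mismatch on the advanced path. It only reuses the triton_output and torch_output tensors defined above; the print formatting is my own.

mismatch = torch_output != triton_output
print(f"mismatching elements  : {mismatch.sum().item()} / {mismatch.numel()}")
print(f"max abs difference    : {(torch_output.float() - triton_output.float()).abs().max().item()}")
print(f"zeros in triton_output: {(triton_output == 0).sum().item()}")
print(f"zeros in torch_output : {(torch_output == 0).sum().item()}")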

These are my HW details:

LIBIGC1_VERSION=1.0.17193.16-950
LEVEL_ZERO_VERSION=1.3.30049.10-950
AGAMA_VERSION=950
GPU_DEVICE=Intel(R) Data Center GPU Max 1100
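
As a side note, the GPU can also be confirmed from Python in the same environment. This is a hedged sketch assuming an XPU-enabled PyTorch / IPEX build where torch.xpu.get_device_name is available (the reproducer already relies on the 'xpu' device):

import torch
# Assumes torch.xpu is present, as in the reproducer's environment.
if torch.xpu.is_available():
    print(torch.xpu.get_device_name(0))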

vlad-penkin added the bug (Something isn't working) and codegen: gemm labels on Sep 12, 2024
vlad-penkin added this to the 4.0 [Performance] Core milestone on Sep 12, 2024
Dewei-Wang-sh removed their assignment on Sep 14, 2024