Skip to content

Commit

Permalink
Support shared memory allocation via Sycl runtime.
Browse files Browse the repository at this point in the history
  • Loading branch information
nbpatel authored and silee2 committed Aug 22, 2023
1 parent 2b30545 commit 4d8dce5
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 15 deletions.
32 changes: 22 additions & 10 deletions lib/ExecutionEngine/LEVELZERORUNTIME/LevelZeroRuntimeWrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,18 +406,30 @@ getKernel(GPUL0QUEUE *queue, ze_module_handle_t module, const char *name) {
return zeKernel;
}

static void
enqueueKernel(ze_command_list_handle_t zeCommandList, ze_kernel_handle_t kernel,
const ze_group_count_t *pLaunchArgs, ParamDesc *params,
ze_event_handle_t waitEvent = nullptr, uint32_t numWaitEvents = 0,
ze_event_handle_t *phWaitEvents = nullptr) {
static void enqueueKernel(ze_command_list_handle_t zeCommandList,
ze_kernel_handle_t kernel,
const ze_group_count_t *pLaunchArgs,
ParamDesc *params, size_t sharedMemBytes,
ze_event_handle_t waitEvent = nullptr,
uint32_t numWaitEvents = 0,
ze_event_handle_t *phWaitEvents = nullptr) {
auto paramsCount = countUntil(params, ParamDesc{nullptr, 0});

if (sharedMemBytes) {
paramsCount = paramsCount - 1;
}

for (size_t i = 0; i < paramsCount; ++i) {
auto param = params[i];
CHECK_ZE_RESULT(zeKernelSetArgumentValue(kernel, static_cast<uint32_t>(i),
param.size, param.data));
}

if (sharedMemBytes) {
CHECK_ZE_RESULT(
zeKernelSetArgumentValue(kernel, paramsCount, sharedMemBytes, nullptr));
}

CHECK_ZE_RESULT(zeCommandListAppendLaunchKernel(zeCommandList, kernel,
pLaunchArgs, waitEvent,
numWaitEvents, phWaitEvents));
Expand Down Expand Up @@ -456,14 +468,14 @@ static void launchKernel(GPUL0QUEUE *queue, ze_kernel_handle_t kernel,

// warmup
for (int r = 0; r < warmups; r++)
enqueueKernel(queue->zeCommandList_, kernel, &launchArgs, params, nullptr,
0, nullptr);
enqueueKernel(queue->zeCommandList_, kernel, &launchArgs, params,
sharedMemBytes, nullptr, 0, nullptr);

// profiling using timestamp event provided by level-zero
for (int r = 0; r < rounds; r++) {
Event event(queue->zeContext_, queue->zeDevice_);
enqueueKernel(queue->zeCommandList_, kernel, &launchArgs, params,
event.zeEvent, 0, nullptr);
sharedMemBytes, event.zeEvent, 0, nullptr);

auto startTime =
event.get_profiling_info<imex::profiling::command_start>();
Expand All @@ -480,8 +492,8 @@ static void launchKernel(GPUL0QUEUE *queue, ze_kernel_handle_t kernel,
"avg: %.4f, min: %.4f, max: %.4f (over %d runs)\n",
executionTime / rounds, minTime, maxTime, rounds);
} else {
enqueueKernel(queue->zeCommandList_, kernel, &launchArgs, params, nullptr,
0, nullptr);
enqueueKernel(queue->zeCommandList_, kernel, &launchArgs, params,
sharedMemBytes, nullptr, 0, nullptr);
}
}

Expand Down
28 changes: 23 additions & 5 deletions lib/ExecutionEngine/SYCLRUNTIME/SyclRuntimeWrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,32 @@ static sycl::kernel *getKernel(GPUSYCLQUEUE *queue, ze_module_handle_t zeModule,
}

// Submits `kernel` to `queue` over `NdRange`, binding entries of `params` as
// kernel arguments, and returns the event produced by queue.submit().
//
// The number of arguments comes from countUntil(params, ParamDesc{nullptr, 0})
// (presumably the count of entries before a {nullptr, 0} terminator — confirm
// against countUntil's definition). When `sharedMemBytes` is non-zero, the last
// counted entry is not bound from `params`; a local-memory accessor sized from
// `sharedMemBytes` is bound at that argument index instead.
//
// NOTE(review): this listing appears to interleave pre- and post-change lines
// of a rendered diff (two parameter-list tails below, and a parallel_for both
// before and inside the sharedMemBytes branch) — confirm against the real file.
static sycl::event enqueueKernel(sycl::queue queue, sycl::kernel *kernel,
sycl::nd_range<3> NdRange, ParamDesc *params) {
sycl::nd_range<3> NdRange, ParamDesc *params,
size_t sharedMemBytes) {
auto paramsCount = countUntil(params, ParamDesc{nullptr, 0});
// The assumption is, if there is a param for the shared local memory,
// then that will always be the last argument.
// NOTE(review): if paramsCount were 0 here, the decrement would wrap for an
// unsigned count — presumably callers always pass the placeholder; verify.
if (sharedMemBytes) {
paramsCount = paramsCount - 1;
}
sycl::event event = queue.submit([&](sycl::handler &cgh) {
// Bind each remaining entry positionally: param.data is read as a void** and
// dereferenced, so the value handed to set_arg is the pointer it points at.
for (size_t i = 0; i < paramsCount; i++) {
auto param = params[i];
cgh.set_arg(static_cast<uint32_t>(i),
*(static_cast<void **>(param.data)));
}
cgh.parallel_for(NdRange, *kernel);
if (sharedMemBytes) {
// TODO: Handle other data types
using share_mem_t =
sycl::accessor<float, 1, sycl::access::mode::read_write,
sycl::access::target::local>;
// Element count is sharedMemBytes / sizeof(float); the integer division
// drops any remainder bytes.
share_mem_t local_buffer =
share_mem_t(sharedMemBytes / sizeof(float), cgh);
// The accessor takes the argument index that was skipped by the loop above.
cgh.set_arg(paramsCount, local_buffer);
cgh.parallel_for(NdRange, *kernel);
} else {
cgh.parallel_for(NdRange, *kernel);
}
});
return event;
}
Expand Down Expand Up @@ -262,11 +279,12 @@ static void launchKernel(GPUSYCLQUEUE *queue, sycl::kernel *kernel,

// warmups
for (int r = 0; r < warmups; r++) {
enqueueKernel(syclQueue, kernel, syclNdRange, params);
enqueueKernel(syclQueue, kernel, syclNdRange, params, sharedMemBytes);
}

for (int r = 0; r < rounds; r++) {
sycl::event event = enqueueKernel(syclQueue, kernel, syclNdRange, params);
sycl::event event =
enqueueKernel(syclQueue, kernel, syclNdRange, params, sharedMemBytes);

auto startTime = event.get_profiling_info<
cl::sycl::info::event_profiling::command_start>();
Expand All @@ -285,7 +303,7 @@ static void launchKernel(GPUSYCLQUEUE *queue, sycl::kernel *kernel,
"avg: %.4f, min: %.4f, max: %.4f (over %d runs)\n",
executionTime / rounds, minTime, maxTime, rounds);
} else {
enqueueKernel(syclQueue, kernel, syclNdRange, params);
enqueueKernel(syclQueue, kernel, syclNdRange, params, sharedMemBytes);
}
}

Expand Down
Loading

0 comments on commit 4d8dce5

Please sign in to comment.