Skip to content

Commit

Permalink
Add utility functions to initialize 1-D memrefs.
Browse files Browse the repository at this point in the history
Currently supported type: f16, bf16, f32.
  • Loading branch information
mshahneo committed Aug 7, 2023
1 parent d08b6f3 commit 73ab8ba
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
18 changes: 18 additions & 0 deletions include/imex/ExecutionEngine/ImexRunnerUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@
#include "mlir/ExecutionEngine/Float16bits.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"

template <typename T, int N> struct MemRefDescriptor {
T *allocated;
T *aligned;
int64_t offset;
int64_t sizes[N];
int64_t strides[N];
};

extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_fillResource1DBF16(MemRefDescriptor<bf16, 1> *ptr, // NOLINT
float value);
extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_fillResource1DF16(MemRefDescriptor<f16, 1> *ptr, // NOLINT
float value);
extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_fillResource1DF32(MemRefDescriptor<float, 1> *ptr, // NOLINT
float value);

extern "C" IMEX_RUNNERUTILS_EXPORT void
_mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *m);
extern "C" IMEX_RUNNERUTILS_EXPORT void
Expand Down
23 changes: 23 additions & 0 deletions lib/ExecutionEngine/ImexRunnerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,29 @@

// NOLINTBEGIN(*-identifier-naming)

/// Fills the given 1D bf16 memref with the given float value.
extern "C" void
_mlir_ciface_fillResource1DBF16(MemRefDescriptor<bf16, 1> *ptr, // NOLINT
float value) {
bf16 bf16_val(value);
std::fill_n(ptr->allocated, ptr->sizes[0], bf16_val);
}

/// Fills the given 1D f16 memref with the given float value.
extern "C" void
_mlir_ciface_fillResource1DF16(MemRefDescriptor<f16, 1> *ptr, // NOLINT
float value) {
f16 f16_val(value);
std::fill_n(ptr->allocated, ptr->sizes[0], f16_val);
}

/// Fills the given 1D float (f32) memref with the given float value.
extern "C" void
_mlir_ciface_fillResource1DF32(MemRefDescriptor<float, 1> *ptr, // NOLINT
float value) {
std::fill_n(ptr->allocated, ptr->sizes[0], value);
}

extern "C" void _mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *M) {
impl::printMemRef(*M);
}
Expand Down

0 comments on commit 73ab8ba

Please sign in to comment.