Skip to content

Commit

Permalink
[NDArray] Bugfix for "ndarray.permute_dims" operation (#854)
Browse files Browse the repository at this point in the history
* fix argument order of distruntime.copy_permute
  • Loading branch information
AllanZyne committed Sep 9, 2024
1 parent 0c9aa46 commit 1271507
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 26 deletions.
2 changes: 1 addition & 1 deletion include/imex/Dialect/DistRuntime/IR/DistRuntimeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ def CopyPermuteOp : DistRuntime_Op<"copy_permute",
AnyType:$lArray,
Variadic<Index>:$gShape,
Variadic<Index>:$lOffsets,
Variadic<Index>:$nlShape,
Variadic<Index>:$nlOffsets,
Variadic<Index>:$nlShape,
DenseI64ArrayAttr:$axes);
let results = (outs DistRuntime_AsyncHandle:$handle, AnyType:$nlArray);
let assemblyFormat = [{
Expand Down
3 changes: 2 additions & 1 deletion include/imex/Dialect/NDArray/IR/NDArrayOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ class NDArray_Op<string mnemonic, list<Trait> traits = []> :
Op<NDArray_Dialect, mnemonic, traits>;


def DeleteOp : NDArray_Op<"delete"> {
def DeleteOp : NDArray_Op<"delete", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Explicitly delete an NDArray, freeing its memory";
let description = [{
Allow explicitly deleting the memory of an NDArray. It is assumed
Expand Down
5 changes: 3 additions & 2 deletions lib/Conversion/DistToStandard/DistToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ struct DeleteOpConverter

// apply DeleteOp to all parts
for (auto p : lParts) {
(void)rewriter.create<::imex::ndarray::DeleteOp>(loc, p);
auto newOp = rewriter.create<::imex::ndarray::DeleteOp>(loc, p);
newOp->setAttrs(adaptor.getAttributes());
}

rewriter.eraseOp(op);
Expand Down Expand Up @@ -1659,7 +1660,7 @@ struct PermuteDimsOpConverter
distLArray.getHandle());
// finally init dist array
rewriter.replaceOp(
op, createDistArray(loc, rewriter, team, srcGShape, dstLOffsets,
op, createDistArray(loc, rewriter, team, dstGShape, dstLOffsets,
::mlir::ValueRange{distLArray.getNlArray()}));

return ::mlir::success();
Expand Down
4 changes: 3 additions & 1 deletion lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,9 @@ struct DeleteLowering
auto inp = adaptor.getInput();
auto inpMR = createToMemRef(op.getLoc(), rewriter, inp,
inpArType.getMemRefType(inp));
rewriter.replaceOpWithNewOp<::mlir::memref::DeallocOp>(op, inpMR);
auto newOp =
rewriter.replaceOpWithNewOp<::mlir::memref::DeallocOp>(op, inpMR);
newOp->setAttrs(op->getAttrs());

return ::mlir::success();
}
Expand Down
23 changes: 2 additions & 21 deletions lib/Dialect/DistRuntime/IR/CopyPermuteOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,9 @@ class CopyPermuteOpResultCanonicalizer final
return ::mlir::failure();
}

auto src = op.getLArray();
auto srcType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType());
if (!srcType) {
return ::mlir::failure();
}
auto srcShape = srcType.getShape();
if (::mlir::ShapedType::isDynamicShape(srcShape)) {
return ::mlir::failure();
}

auto axes = op.getAxes();
if (axes.size() != dstShape.size()) {
return ::mlir::failure();
}

::mlir::SmallVector<int64_t> permutedShape;
for (auto i = 0u; i < srcShape.size(); ++i) {
permutedShape.push_back(srcShape[axes[i]]);
}

auto dstLShape = ::imex::getShapeFromValues(op.getNlShape());
auto elType = dstType.getElementType();
auto nType = dstType.cloneWith(permutedShape, elType);
auto nType = dstType.cloneWith(dstLShape, elType);
auto hType = ::imex::distruntime::AsyncHandleType::get(getContext());

auto newOp = rewriter.create<::imex::distruntime::CopyPermuteOp>(
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/NDArray/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_imex_dialect_library(IMEXNDArrayDialect
EWBinOp.cpp
EWUnyOp.cpp
PermuteDimsOp.cpp
DeleteOp.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/mlir/Dialect/NDArray
Expand Down
20 changes: 20 additions & 0 deletions lib/Dialect/NDArray/IR/DeleteOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- DeleteOp.cpp - NDArray dialect --------------------------*- C++ -*-===//
//
// Copyright 2024 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the DeleteOp of the NDArray dialect.
///
//===----------------------------------------------------------------------===//

#include <imex/Dialect/NDArray/IR/NDArrayOps.h>

void imex::ndarray::DeleteOp::getEffects(
::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
effects.emplace_back(::mlir::MemoryEffects::Free().get());
}

0 comments on commit 1271507

Please sign in to comment.