Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MemoryBanking] Support memory banking for GetGlobalOp #8047

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
123 changes: 123 additions & 0 deletions lib/Transforms/MemoryBanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>

namespace circt {
#define GEN_PASS_DEF_MEMORYBANKING
Expand Down Expand Up @@ -88,6 +90,121 @@ MemRefType computeBankedMemRefType(MemRefType originalType,
return newMemRefType;
}

// Decodes the flat index `linIndex` into an n-dimensional index based on the
// given `shape` of the array in row-major order. Returns one coordinate per
// dimension of `shape`; a rank-0 `shape` yields an empty vector. Every
// dimension size in `shape` must be positive.
SmallVector<int64_t> decodeIndex(int64_t linIndex, ArrayRef<int64_t> shape) {
  // Keep the rank signed: with an unsigned rank, `rank - 1` wraps around for a
  // rank-0 shape and the loop below would index far out of bounds.
  const int64_t rank = static_cast<int64_t>(shape.size());
  SmallVector<int64_t> ndIndex(shape.size(), 0);

  // Compute from last dimension to first because we assume row-major.
  for (int64_t d = rank - 1; d >= 0; --d) {
    assert(shape[d] > 0 &&
           "cannot decode an index along a non-positive dimension size");
    ndIndex[d] = linIndex % shape[d];
    linIndex /= shape[d];
  }

  return ndIndex;
}

// Distributes the row-major flattened elements in `allAttrs` (belonging to an
// array of shape `memShape`) over `bankingFactor` sub-blocks, round-robin
// along `bankingDimension`: an element whose coordinate along
// `bankingDimension` is `i` is appended to sub-block `i % bankingFactor`.
// Coordinates along all other dimensions do not influence the choice of
// sub-block. Elements are visited in increasing flat-index order, so each
// sub-block keeps the relative order the elements had in `allAttrs`.
//
// Returns `bankingFactor` flattened attribute lists, one per bank.
// Requires `bankingFactor > 0`, `bankingDimension < memShape.size()`, and
// `allAttrs` to hold exactly one attribute per element of `memShape`.
SmallVector<SmallVector<Attribute>> sliceSubBlock(ArrayRef<Attribute> allAttrs,
                                                  ArrayRef<int64_t> memShape,
                                                  unsigned bankingDimension,
                                                  unsigned bankingFactor) {
  assert(bankingFactor > 0 && "banking factor must be positive");
  assert(bankingDimension < memShape.size() &&
         "banking dimension is out of range for the memory shape");

  // Seed the reduction with a `size_t`: `std::reduce` accumulates in the type
  // of its initial value, so an `int` seed would truncate large element
  // counts.
  const size_t numElements =
      std::reduce(memShape.begin(), memShape.end(), size_t{1},
                  std::multiplies<size_t>());
  assert(allAttrs.size() == numElements &&
         "expected exactly one attribute per element of the memory shape");

  // `bankingFactor` number of flattened attributes that store the information
  // in the original globalOp.
  SmallVector<SmallVector<Attribute>> subBlocks;
  subBlocks.resize(bankingFactor);
  // Each bank receives roughly an equal share; reserve it up front (rounded
  // up) to avoid repeated reallocation while filling.
  const size_t perBankCapacity =
      (numElements + bankingFactor - 1) / bankingFactor;
  for (SmallVector<Attribute> &subBlock : subBlocks)
    subBlock.reserve(perBankCapacity);

  for (size_t linIndex = 0; linIndex < numElements; ++linIndex) {
    SmallVector<int64_t> ndIndex =
        decodeIndex(static_cast<int64_t>(linIndex), memShape);
    const size_t subBlockIndex = ndIndex[bankingDimension] % bankingFactor;
    subBlocks[subBlockIndex].push_back(allAttrs[linIndex]);
  }

  return subBlocks;
}

// Handles the splitting of a GetGlobalOp into multiple banked memories: looks
// up the GlobalOp that `getGlobalOp` refers to, slices its constant initial
// value into `bankingFactor` sub-blocks along `bankingDimension`, and creates
// one new GlobalOp plus one new GetGlobalOp per sub-block. The new GlobalOps
// are inserted after the original GlobalOp and the new GetGlobalOps after
// `getGlobalOp`, both in bank order. The original GlobalOp is erased.
//
// Returns the results of the new GetGlobalOps, one per bank.
//
// The referenced GlobalOp must exist in the enclosing symbol table and must
// carry a dense constant initial value. `bankingFactor` must be non-zero.
// `newMemRefType` is currently unused; the bank type is rebuilt from the
// banked shape and the GlobalOp's element type.
//
// Note that `builder`'s insertion point is left after the last new
// GetGlobalOp when this function returns.
//
// NOTE(review): the GlobalOp is erased here, so a second GetGlobalOp that
// refers to the same symbol would presumably fail the lookup below. Confirm
// whether globals can be referenced more than once.
SmallVector<Value, 4> handleGetGlobalOp(memref::GetGlobalOp getGlobalOp,
                                        uint64_t bankingFactor,
                                        unsigned bankingDimension,
                                        MemRefType newMemRefType,
                                        OpBuilder &builder) {
  SmallVector<Value, 4> banks;
  auto memTy = cast<MemRefType>(getGlobalOp.getType());
  ArrayRef<int64_t> originalShape = memTy.getShape();
  SmallVector<int64_t> newShape(originalShape.begin(), originalShape.end());
  newShape[bankingDimension] = originalShape[bankingDimension] / bankingFactor;

  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  assert(symbolTableOp &&
         "GetGlobalOp should be nested inside a symbol table operation");
  auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
  assert(globalOp && "The corresponding GlobalOp should exist in the module");
  MemRefType globalOpTy = globalOp.getType();
  Type elementType = globalOpTy.getElementType();

  // The initial value is dereferenced right below, so a global without a
  // dense constant initializer must be rejected instead of crashing on a null
  // attribute.
  auto cstAttr =
      dyn_cast_or_null<DenseElementsAttr>(globalOp.getConstantInitValue());
  assert(cstAttr &&
         "The GlobalOp should have a dense constant initial value to bank");
  auto attributes = cstAttr.getValues<Attribute>();
  SmallVector<Attribute, 8> allAttrs(attributes.begin(), attributes.end());

  auto subBlocks =
      sliceSubBlock(allAttrs, originalShape, bankingDimension, bankingFactor);

  // Everything that does not depend on the bank number is built once, outside
  // of the per-bank loop.
  auto newMemRefTy = MemRefType::get(newShape, elementType);
  auto newTypeAttr = TypeAttr::get(newMemRefTy);
  RankedTensorType tensorType = RankedTensorType::get(newShape, elementType);
  const std::string shapeStr = llvm::join(
      llvm::map_range(newShape,
                      [](int64_t dim) { return std::to_string(dim); }),
      "x");

  // Initialize globalOp and getGlobalOp's insertion points. Since
  // bankingFactor is guaranteed to be greater than zero as it would
  // have early exited if not, the loop below will execute at least
  // once. So it's safe to manipulate the insertion points here.
  builder.setInsertionPointAfter(globalOp);
  OpBuilder::InsertPoint globalOpsInsertPt = builder.saveInsertionPoint();
  builder.setInsertionPointAfter(getGlobalOp);
  OpBuilder::InsertPoint getGlobalOpsInsertPt = builder.saveInsertionPoint();

  for (size_t bankCnt = 0; bankCnt < bankingFactor; ++bankCnt) {
    // Symbol name layout: <prefix>_<banked shape>_<element type>_<bank>.
    // NOTE(review): the prefix comes from `getConstantAttrName()`, and the
    // accompanying test expects names such as `constant_2x8_f32_0`, i.e. the
    // original symbol name is not part of the new name. Globals whose banks
    // share shape and element type would then get identical symbol names;
    // confirm whether the original symbol name should be used instead.
    std::string newNameStr =
        llvm::formatv("{0}_{1}_{2}_{3}", globalOp.getConstantAttrName(),
                      shapeStr, elementType, bankCnt)
            .str();
    auto newInitValue = DenseElementsAttr::get(tensorType, subBlocks[bankCnt]);

    // Append the new GlobalOp after the previously created one so the banks
    // appear in order.
    builder.restoreInsertionPoint(globalOpsInsertPt);
    auto newGlobalOp = builder.create<memref::GlobalOp>(
        globalOp.getLoc(), builder.getStringAttr(newNameStr),
        globalOp.getSymVisibilityAttr(), newTypeAttr, newInitValue,
        globalOp.getConstantAttr(), globalOp.getAlignmentAttr());
    builder.setInsertionPointAfter(newGlobalOp);
    globalOpsInsertPt = builder.saveInsertionPoint();

    // Likewise, append the matching GetGlobalOp after the previous one.
    builder.restoreInsertionPoint(getGlobalOpsInsertPt);
    auto newGetGlobalOp = builder.create<memref::GetGlobalOp>(
        getGlobalOp.getLoc(), newMemRefTy, newGlobalOp.getName());
    builder.setInsertionPointAfter(newGetGlobalOp);
    getGlobalOpsInsertPt = builder.saveInsertionPoint();

    banks.push_back(newGetGlobalOp);
  }

  globalOp.erase();
  return banks;
}

SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor,
unsigned bankingDimension) {
MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
Expand Down Expand Up @@ -125,6 +242,12 @@ SmallVector<Value, 4> createBanks(Value originalMem, uint64_t bankingFactor,
banks.push_back(bankAllocaOp);
}
})
.Case<memref::GetGlobalOp>([&](memref::GetGlobalOp getGlobalOp) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How difficult would it be to move all of this code into a separate function and just call it here? The code is quite large, but I am unsure whether extracting it would require too many arguments.

auto newBanks =
handleGetGlobalOp(getGlobalOp, bankingFactor, bankingDimension,
newMemRefType, builder);
banks.append(newBanks.begin(), newBanks.end());
})
.Default([](Operation *) {
llvm_unreachable("Unhandled memory operation type");
});
Expand Down
46 changes: 46 additions & 0 deletions test/Transforms/memory_banking_multi_dim.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2 dimension=1" | FileCheck %s --check-prefix RANK2-BANKDIM1
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix GETGLOBAL
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may have handled this in an earlier PR, but what about corner cases, e.g., banking-factor=0, dim_size % banking-factor != 0, etc.?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in these cases it will either signal failure early:

if (bankingFactor == 0) {

or raise an assertion error:
assert(originalShape[bankingDimension] % bankingFactor == 0 &&


// RANK2-BANKDIM1: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d1 mod 2)>
// RANK2-BANKDIM1: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (d1 floordiv 2)>
Expand Down Expand Up @@ -73,3 +74,48 @@ func.func @rank_two_bank_dim1(%arg0: memref<8x6xf32>, %arg1: memref<8x6xf32>) ->
return %mem : memref<8x6xf32>
}

// -----

// GETGLOBAL-LABEL: memref.global "private" constant @constant_2x8_f32_0 : memref<2x8xf32> = dense<{{\[\[}}8.000000e+00, -2.000000e+00, -2.000000e+00, -1.000000e+00, -3.000000e+00, -2.000000e+00, 3.000000e+00, 6.000000e+00], [9.000000e+00, -1.000000e+00, -2.000000e+00, -2.000000e+00, -2.000000e+00, -2.000000e+00, -1.000000e+00, -2.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @constant_2x8_f32_1 : memref<2x8xf32> = dense<{{\[\[}}1.000000e+00, -3.000000e+00, -2.000000e+00, -1.000000e+00, 5.000000e+00, -3.000000e+00, -1.000000e+00, -2.000000e+00], [2.000000e+00, -7.000000e+00, 3.000000e+00, 1.000000e+00, -2.000000e+00, 2.000000e+00, -9.000000e+00, -1.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @constant_4x6_f32_0 : memref<4x6xf32> = dense<{{\[\[}}2.000000e+00, -2.000000e+00, -4.000000e+00, -1.000000e+00, -3.000000e+00, 3.000000e+00], [2.000000e+00, -2.000000e+00, 1.000000e+00, -1.000000e+00, 1.000000e+00, -8.000000e+00], [3.000000e+00, -3.000000e+00, -4.000000e+00, -3.000000e+00, -2.000000e+00, 1.000000e+00], [2.000000e+00, -9.000000e+00, 2.000000e+00, -3.000000e+00, -2.000000e+00, 1.000000e+00]]>
// GETGLOBAL: memref.global "private" constant @constant_4x6_f32_1 : memref<4x6xf32> = dense<{{\[\[}}1.000000e+00, 1.000000e+00, 1.000000e+00, -7.000000e+00, 3.000000e+00, -2.000000e+00], [3.000000e+00, -2.000000e+00, -2.000000e+00, -2.000000e+00, 3.000000e+00, 1.000000e+00], [1.000000e+00, 3.000000e+00, -2.000000e+00, -2.000000e+00, 2.000000e+00, -1.000000e+00], [8.000000e+00, -1.000000e+00, 2.000000e+00, 2.000000e+00, -2.000000e+00, -2.000000e+00]]>
// Test input for banking memrefs obtained through memref.get_global: two
// constant globals are read inside @main and copied, transposed, into freshly
// allocated buffers.
module {
// 4x8 f32 constant table; read in @main through %2.
memref.global "private" constant @__constant_4x8xf32 : memref<4x8xf32> = dense<[
[8.0, -2.0, -2.0, -1.0, -3.0, -2.0, 3.0, 6.0],
[1.0, -3.0, -2.0, -1.0, 5.0, -3.0, -1.0, -2.0],
[9.0, -1.0, -2.0, -2.0, -2.0, -2.0, -1.0, -2.0],
[2.0, -7.0, 3.0, 1.0, -2.0, 2.0, -9.0, -1.0]
]>
// 8x6 f32 constant table; read in @main through %0.
memref.global "private" constant @__constant_8x6xf32 : memref<8x6xf32> = dense<[
[2.0, -2.0, -4.0, -1.0, -3.0, 3.0],
[1.0, 1.0, 1.0, -7.0, 3.0, -2.0],
[2.0, -2.0, 1.0, -1.0, 1.0, -8.0],
[3.0, -2.0, -2.0, -2.0, 3.0, 1.0],
[3.0, -3.0, -4.0, -3.0, -2.0, 1.0],
[1.0, 3.0, -2.0, -2.0, 2.0, -1.0],
[2.0, -9.0, 2.0, -3.0, -2.0, 1.0],
[8.0, -1.0, 2.0, 2.0, -2.0, -2.0]
]>
func.func @main() {
%cst = arith.constant 0.000000e+00 : f32
// The two get_global results below are the memories under test.
%0 = memref.get_global @__constant_8x6xf32 : memref<8x6xf32>
%2 = memref.get_global @__constant_4x8xf32 : memref<4x8xf32>
// Copy the 8x6 global into a 6x8 buffer with the two indices swapped.
%alloc = memref.alloc() : memref<6x8xf32>
affine.parallel (%arg2) = (0) to (6) {
affine.parallel (%arg3) = (0) to (8) {
%4 = affine.load %0[%arg3, %arg2] : memref<8x6xf32>
affine.store %4, %alloc[%arg2, %arg3] : memref<6x8xf32>
}
}
// Copy the 4x8 global into an 8x4 buffer with the two indices swapped.
%alloc_5 = memref.alloc() : memref<8x4xf32>
affine.parallel (%arg2) = (0) to (8) {
affine.parallel (%arg3) = (0) to (4) {
%4 = affine.load %2[%arg3, %arg2] : memref<4x8xf32>
affine.store %4, %alloc_5[%arg2, %arg3] : memref<8x4xf32>
}
}
return
}
}

Loading