Skip to content

Commit

Permalink
fix(compiler): Use FHE.zero_tensor instead of bufferization.alloc_ten…
Browse files Browse the repository at this point in the history
…sor as alloc_tensor explictly has a alloc semantic, so it cannot be eliminated by dce
  • Loading branch information
BourgerieQuentin committed Sep 11, 2023
1 parent f409993 commit d71201f
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ struct FHELinalgOpToLinalgGeneric : public mlir::OpRewritePattern<FHELinalgOp> {
mlir::RankedTensorType rhsTy = ((mlir::Type)linalgOp.getRhs().getType())
.cast<mlir::RankedTensorType>();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(
linalgOp.getLoc(), resultTy, mlir::ValueRange{});

// Create the affine #maps_0
Expand Down Expand Up @@ -424,8 +424,8 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
nestedBuilder.create<linalg::YieldOp>(loc, lookup.getResult());
};

auto output = rewriter.create<bufferization::AllocTensorOp>(
loc, resultTy, mlir::ValueRange{});
auto output =
rewriter.create<FHE::ZeroTensorOp>(loc, resultTy, mlir::ValueRange{});

// Create the `linalg.g eneric` op
Types resTys{resultTy};
Expand Down Expand Up @@ -508,7 +508,7 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
mlir::RankedTensorType lutsTy = getRankedTensorType(luts);
auto lutElmtTy = lutsTy.getElementType();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(
fheLinalgLutOp.getLoc(), resultTy, mlir::ValueRange{});

auto lutsShape = lutsTy.getShape();
Expand Down Expand Up @@ -655,7 +655,7 @@ struct FHELinalgApplyLookupTableToLinalgGeneric
((mlir::Type)lutOp.getT().getType()).cast<mlir::RankedTensorType>();

// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(
lutOp.getLoc(), resultTy, mlir::ValueRange{});

// Create the affine #maps_0
Expand Down Expand Up @@ -756,7 +756,7 @@ struct FHELinalgNegEintToLinalgGeneric
.cast<mlir::RankedTensorType>();

// linalg.init_tensor for initial value
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(
negEintOp.getLoc(), resultTy, mlir::ValueRange{});

// Create the affine #maps_0
Expand Down Expand Up @@ -1985,8 +1985,8 @@ struct FHELinalgToSignedToLinalgGeneric
mlir::RankedTensorType resultTy =
op->getResult(0).getType().cast<mlir::RankedTensorType>();

mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), resultTy,
mlir::ValueRange{});

llvm::SmallVector<mlir::AffineMap, 2> maps{
mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
Expand Down Expand Up @@ -2074,8 +2074,8 @@ struct FHELinalgToUnsignedToLinalgGeneric
mlir::RankedTensorType resultTy =
op->getResult(0).getType().cast<mlir::RankedTensorType>();

mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
mlir::Value init = rewriter.create<FHE::ZeroTensorOp>(op.getLoc(), resultTy,
mlir::ValueRange{});

llvm::SmallVector<mlir::AffineMap, 2> maps{
mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
Expand Down Expand Up @@ -2161,8 +2161,8 @@ struct FHELinalgRoundToLinalgGeneric
auto inputTy = op.getInput().getType().cast<mlir::RankedTensorType>();
auto outputTy = op.getOutput().getType().cast<mlir::RankedTensorType>();

auto buffer = rewriter.create<bufferization::AllocTensorOp>(
loc, outputTy, mlir::ValueRange{});
auto buffer =
rewriter.create<FHE::ZeroTensorOp>(loc, outputTy, mlir::ValueRange{});

auto maps = llvm::SmallVector<mlir::AffineMap, 2>{
mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
Expand Down Expand Up @@ -2222,8 +2222,6 @@ void FHETensorOpsToLinalg::runOnOperation() {
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();

target.addLegalOp<bufferization::AllocTensorOp>();

mlir::RewritePatternSet patterns(&getContext());

patterns.insert<DotToLinalgGeneric<mlir::concretelang::FHELinalg::Dot,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @apply_lookup_table(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>, %[[Varg1:.*]]: tensor<4xi64>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = {{\[}}#map, #map{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%[[Varg2:.*]]: !FHE.eint<2>, %[[Varg3:.*]]: !FHE.eint<2>):
// CHECK-NEXT: %[[V2:.*]] = "FHE.apply_lookup_table"(%[[Varg2]], %[[Varg1]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

//CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
//CHECK: func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x4x!FHE.eint<2>>
//CHECK: %[[MEM:.*]] = "FHE.zero_tensor"() : () -> tensor<4x4x!FHE.eint<2>>
//CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x4x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x4x!FHE.eint<2>>) {
//CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>):
//CHECK: %[[INDEXA:.*]] = linalg.index 0 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

//CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
//CHECK: func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x3x!FHE.eint<2>>
//CHECK: %[[MEM:.*]] = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<2>>
//CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x3x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x3x!FHE.eint<2>>) {
//CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>):
//CHECK: %[[INDEX:.*]] = linalg.index 1 : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func.func @main(%arg0: tensor<1x1x8x10x!FHE.eint<5>>) -> tensor<1x1x6x9x!FHE.ein
// CHECK: func.func @main(%[[a0:.*]]: tensor<1x1x6x5x!FHE.esint<6>>) -> tensor<1x1x5x3x!FHE.esint<6>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v1:.*]] = arith.constant dense<16> : tensor<1xi7>
// CHECK-NEXT: %[[v2:.*]] = bufferization.alloc_tensor() : tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v2:.*]] = "FHE.zero_tensor"() : () -> tensor<1x1x5x3x!FHE.esint<6>>
// CHECK-NEXT: %[[v3:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[v0]], %[[v1]] : tensor<1x1x5x3x!FHE.esint<6>>, tensor<1xi7>) outs(%[[v2]] : tensor<1x1x5x3x!FHE.esint<6>>) {
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.esint<6>, %[[aa1:.*]]: i7, %[[aa2:.*]]: !FHE.esint<6>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.sub_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.esint<6>, i7) -> !FHE.esint<6>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @neg_eint(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = {{\[}}#map, #map{{\], iterator}}_types = {{\[}}"parallel", "parallel", "parallel"{{\]}}} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.eint<2>, %[[Varg2:.*]]: !FHE.eint<2>):
// CHECK-NEXT: %[[V2:.*]] = "FHE.neg_eint"(%[[Varg1]]) : (!FHE.eint<2>) -> !FHE.eint<2>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// CHECK: #[[m0:.*]] = affine_map<(d0) -> (d0)>

// CHECK: func.func @main(%[[a0:.*]]: tensor<5x!FHE.eint<8>>) -> tensor<5x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = bufferization.alloc_tensor() : tensor<5x!FHE.eint<6>>
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m0]]], iterator_types = ["parallel"]} ins(%[[a0]] : tensor<5x!FHE.eint<8>>) outs(%[[v0]] : tensor<5x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[i0:.*]]: !FHE.eint<8>, %[[o0:.*]]: !FHE.eint<6>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.round"(%[[i0]]) : (!FHE.eint<8>) -> !FHE.eint<6>
Expand All @@ -23,7 +23,7 @@ func.func @main(%arg0: tensor<5x!FHE.eint<8>>) -> tensor<5x!FHE.eint<6>> {
// CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.eint<8>>) -> tensor<2x3x4x!FHE.eint<6>> {
// CHECK-NEXT: %[[v0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<6>>
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x4x!FHE.eint<8>>) outs(%[[v0]] : tensor<2x3x4x!FHE.eint<6>>) {
// CHECK-NEXT: ^bb0(%[[i0:.*]]: !FHE.eint<8>, %[[o0:.*]]: !FHE.eint<6>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.round"(%[[i0]]) : (!FHE.eint<8>) -> !FHE.eint<6>
Expand All @@ -41,7 +41,7 @@ func.func @main(%arg0: tensor<2x3x4x!FHE.eint<8>>) -> tensor<2x3x4x!FHE.eint<6>>
// CHECK: #[[m0:.*]] = affine_map<(d0) -> (d0)>

// CHECK: func.func @main(%[[a0:.*]]: tensor<5x!FHE.esint<8>>) -> tensor<5x!FHE.esint<6>> {
// CHECK-NEXT: %[[v0:.*]] = bufferization.alloc_tensor() : tensor<5x!FHE.esint<6>>
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<5x!FHE.esint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m0]]], iterator_types = ["parallel"]} ins(%[[a0]] : tensor<5x!FHE.esint<8>>) outs(%[[v0]] : tensor<5x!FHE.esint<6>>) {
// CHECK-NEXT: ^bb0(%[[i0:.*]]: !FHE.esint<8>, %[[o0:.*]]: !FHE.esint<6>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.round"(%[[i0]]) : (!FHE.esint<8>) -> !FHE.esint<6>
Expand All @@ -59,7 +59,7 @@ func.func @main(%arg0: tensor<5x!FHE.esint<8>>) -> tensor<5x!FHE.esint<6>> {
// CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func.func @main(%[[a0:.*]]: tensor<2x3x4x!FHE.esint<8>>) -> tensor<2x3x4x!FHE.esint<6>> {
// CHECK-NEXT: %[[v0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<6>>
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.esint<6>>
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m0]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[a0]] : tensor<2x3x4x!FHE.esint<8>>) outs(%[[v0]] : tensor<2x3x4x!FHE.esint<6>>) {
// CHECK-NEXT: ^bb0(%[[i0:.*]]: !FHE.esint<8>, %[[o0:.*]]: !FHE.esint<6>):
// CHECK-NEXT: %[[vv0:.*]] = "FHE.round"(%[[i0]]) : (!FHE.esint<8>) -> !FHE.esint<6>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[Varg0]] : tensor<2x3x4x!FHE.eint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.esint<2>>) {
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.eint<2>, %[[Varg2:.*]]: !FHE.esint<2>):
// CHECK-NEXT: %[[V2:.*]] = "FHE.to_signed"(%[[Varg1]]) : (!FHE.eint<2>) -> !FHE.esint<2>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @main(%[[Varg0:.*]]: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[Varg0]] : tensor<2x3x4x!FHE.esint<2>>) outs(%[[V0]] : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%[[Varg1:.*]]: !FHE.esint<2>, %[[Varg2:.*]]: !FHE.eint<2>):
// CHECK-NEXT: %[[V2:.*]] = "FHE.to_unsigned"(%[[Varg1]]) : (!FHE.esint<2>) -> !FHE.eint<2>
Expand Down

0 comments on commit d71201f

Please sign in to comment.