Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CIRGen] Support Lambda capturing this object #1213

Merged
merged 6 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,11 +943,35 @@ static LValue emitGlobalVarDeclLValue(CIRGenFunction &CGF, const Expr *E,
return LV;
}

static LValue emitCapturedFieldLValue(CIRGenFunction &CGF, const FieldDecl *FD,
mlir::Value ThisValue) {
QualType TagType = CGF.getContext().getTagDeclType(FD->getParent());
LValue LV = CGF.MakeNaturalAlignAddrLValue(ThisValue, TagType);
return CGF.emitLValueForField(LV, FD);
static LValue emitCapturedFieldLValue(CIRGenFunction &cgf, const FieldDecl *fd,
mlir::Value thisValue) {
return cgf.emitLValueForLambdaField(fd, thisValue);
}

/// Given that we are currently emitting a lambda, emit an l-value for
/// one of its members.
///
LValue CIRGenFunction::emitLValueForLambdaField(const FieldDecl *field,
mlir::Value thisValue) {
bool hasExplicitObjectParameter = false;
const auto *methD = dyn_cast_if_present<CXXMethodDecl>(CurCodeDecl);
LValue lambdaLV;
if (methD) {
hasExplicitObjectParameter = methD->isExplicitObjectMemberFunction();
assert(methD->getParent()->isLambda());
assert(methD->getParent() == field->getParent());
}
if (hasExplicitObjectParameter) {
llvm_unreachable("ExplicitObjectMemberFunction NYI");
} else {
QualType lambdaTagType = getContext().getTagDeclType(field->getParent());
lambdaLV = MakeNaturalAlignAddrLValue(thisValue, lambdaTagType);
}
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
return emitLValueForField(lambdaLV, field);
}

LValue CIRGenFunction::emitLValueForLambdaField(const FieldDecl *field) {
return emitLValueForLambdaField(field, CXXABIThisValue);
}

static LValue emitFunctionDeclLValue(CIRGenFunction &CGF, const Expr *E,
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,8 @@ void AggExprEmitter::VisitLambdaExpr(LambdaExpr *E) {
ValueDecl *v = capture.getCapturedVar();
fieldName = v->getName();
CGF.getCIRGenModule().LambdaFieldToName[*CurField] = fieldName;
} else if (capture.capturesThis()) {
CGF.getCIRGenModule().LambdaFieldToName[*CurField] = "this";
} else {
llvm_unreachable("NYI");
}
Expand Down
19 changes: 18 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1279,7 +1279,24 @@ void CIRGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
MD->getParent()->getCaptureFields(LambdaCaptureFields,
LambdaThisCaptureField);
if (LambdaThisCaptureField) {
llvm_unreachable("NYI");
// If the lambda captures the object referred to by '*this' - either by
// value or by reference, make sure CXXThisValue points to the correct
// object.

// Get the lvalue for the field (which is a copy of the enclosing object
// or contains the address of the enclosing object).
LValue thisFieldLValue =
emitLValueForLambdaField(LambdaThisCaptureField);
if (!LambdaThisCaptureField->getType()->isPointerType()) {
// If the enclosing object was captured by value, just use its
// address. Sign this pointer.
CXXThisValue = thisFieldLValue.getPointer();
} else {
// Load the lvalue pointed to by the field, since '*this' was captured
// by reference.
CXXThisValue = emitLoadOfLValue(thisFieldLValue, SourceLocation())
.getScalarVal();
}
}
for (auto *FD : MD->getParent()->fields()) {
if (FD->hasCapturedVLAType()) {
Expand Down
5 changes: 4 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1679,8 +1679,11 @@ class CIRGenFunction : public CIRGenTypeCache {
void initializeVTablePointer(mlir::Location loc, const VPtr &Vptr);

AggValueSlot::Overlap_t getOverlapForFieldInit(const FieldDecl *FD);
LValue emitLValueForField(LValue Base, const clang::FieldDecl *Field);
LValue emitLValueForField(LValue base, const clang::FieldDecl *field);
LValue emitLValueForBitField(LValue base, const FieldDecl *field);
LValue emitLValueForLambdaField(const FieldDecl *field);
LValue emitLValueForLambdaField(const FieldDecl *field,
mlir::Value thisValue);

/// Like emitLValueForField, excpet that if the Field is a reference, this
/// will return the address of the reference and not the address of the value
Expand Down
133 changes: 120 additions & 13 deletions clang/test/CIR/CodeGen/lambda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ void fn() {
a();
}

// CHECK-DAG: !ty_A = !cir.struct<struct "A" {!s32i}>
// CHECK: !ty_anon2E0_ = !cir.struct<class "anon.0" {!u8i}>
// CHECK-DAG: !ty_anon2E7_ = !cir.struct<class "anon.7" {!ty_A}>
// CHECK-DAG: !ty_anon2E8_ = !cir.struct<class "anon.8" {!cir.ptr<!ty_A>}>
// CHECK-DAG: module

// CHECK: cir.func lambda internal private @_ZZ2fnvENK3$_0clEv{{.*}}) extra
Expand All @@ -18,9 +21,8 @@ void fn() {
// CHECK-NEXT: %0 = cir.alloca !ty_anon2E0_, !cir.ptr<!ty_anon2E0_>, ["a"]
// CHECK: cir.call @_ZZ2fnvENK3$_0clEv

// LLVM: {{.*}}void @"_ZZ2fnvENK3$_0clEv"(ptr [[THIS:%.*]])
// FIXME: argument attributes should be emmitted, and lambda's alignment
// COM: LLVM: {{.*}} @"_ZZ2fnvENK3$_0clEv"(ptr noundef nonnull align 1 dereferenceable(1) [[THIS:%.*]]){{%.*}} align 2 {
// LLVM-LABEL: _ZZ2fnvENK3$_0clEv
// LLVM-SAME: (ptr [[THIS:%.*]])
// LLVM: [[THIS_ADDR:%.*]] = alloca ptr, i64 1, align 8
// LLVM: store ptr [[THIS]], ptr [[THIS_ADDR]], align 8
// LLVM: [[THIS1:%.*]] = load ptr, ptr [[THIS_ADDR]], align 8
Expand Down Expand Up @@ -53,9 +55,10 @@ void l0() {
// CHECK: %8 = cir.load %7 : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
// CHECK: cir.store %6, %8 : !s32i, !cir.ptr<!s32i>

// CHECK: cir.func @_Z2l0v()
// CHECK-LABEL: _Z2l0v

// LLVM: {{.* }}void @"_ZZ2l0vENK3$_0clEv"(ptr [[THIS:%.*]])
// LLVM-LABEL: _ZZ2l0vENK3$_0clEv
// LLVM-SAME: (ptr [[THIS:%.*]])
// LLVM: [[THIS_ADDR:%.*]] = alloca ptr, i64 1, align 8
// LLVM: store ptr [[THIS]], ptr [[THIS_ADDR]], align 8
// LLVM: [[THIS1:%.*]] = load ptr, ptr [[THIS_ADDR]], align 8
Expand Down Expand Up @@ -91,7 +94,7 @@ auto g() {
};
}

// CHECK: cir.func @_Z1gv() -> !ty_anon2E3_
// CHECK-LABEL: @_Z1gv()
// CHECK: %0 = cir.alloca !ty_anon2E3_, !cir.ptr<!ty_anon2E3_>, ["__retval"] {alignment = 8 : i64}
// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["i", init] {alignment = 4 : i64}
// CHECK: %2 = cir.const #cir.int<12> : !s32i
Expand Down Expand Up @@ -120,7 +123,7 @@ auto g2() {
}

// Should be same as above because of NRVO
// CHECK: cir.func @_Z2g2v() -> !ty_anon2E4_
// CHECK-LABEL: @_Z2g2v()
// CHECK-NEXT: %0 = cir.alloca !ty_anon2E4_, !cir.ptr<!ty_anon2E4_>, ["__retval", init] {alignment = 8 : i64}
// CHECK-NEXT: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["i", init] {alignment = 4 : i64}
// CHECK-NEXT: %2 = cir.const #cir.int<12> : !s32i
Expand All @@ -143,7 +146,7 @@ int f() {
return g2()();
}

// CHECK: cir.func @_Z1fv() -> !s32i
// CHECK-LABEL: @_Z1fv()
// CHECK-NEXT: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
// CHECK-NEXT: cir.scope {
// CHECK-NEXT: %2 = cir.alloca !ty_anon2E4_, !cir.ptr<!ty_anon2E4_>, ["ref.tmp0"] {alignment = 8 : i64}
Expand All @@ -156,7 +159,8 @@ int f() {
// CHECK-NEXT: cir.return %1 : !s32i
// CHECK-NEXT: }

// LLVM: {{.*}}i32 @"_ZZ2g2vENK3$_0clEv"(ptr [[THIS:%.*]])
// LLVM-LABEL: _ZZ2g2vENK3$_0clEv
// LLVM-SAME: (ptr [[THIS:%.*]])
// LLVM: [[THIS_ADDR:%.*]] = alloca ptr, i64 1, align 8
// LLVM: [[I_SAVE:%.*]] = alloca i32, i64 1, align 4
// LLVM: store ptr [[THIS]], ptr [[THIS_ADDR]], align 8
Expand Down Expand Up @@ -201,7 +205,7 @@ int g3() {
// lambda operator int (*)(int const&)()
// CHECK: cir.func internal private @_ZZ2g3vENK3$_0cvPFiRKiEEv

// CHECK: cir.func @_Z2g3v() -> !s32i
// CHECK-LABEL: @_Z2g3v()
// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
// CHECK: %1 = cir.alloca !cir.ptr<!cir.func<!s32i (!cir.ptr<!s32i>)>>, !cir.ptr<!cir.ptr<!cir.func<!s32i (!cir.ptr<!s32i>)>>>, ["fn", init] {alignment = 8 : i64}
// CHECK: %2 = cir.alloca !s32i, !cir.ptr<!s32i>, ["task", init] {alignment = 4 : i64}
Expand Down Expand Up @@ -230,12 +234,14 @@ int g3() {
// CHECK: }

// lambda operator()
// LLVM-LABEL: _ZZ2g3vENK3$_0clERKi
// FIXME: argument attributes should be emitted
// COM: LLVM: define internal noundef i32 @"_ZZ2g3vENK3$_0clERKi"(ptr noundef nonnull align 1 dereferenceable(1) {{%.*}}, ptr noundef nonnull align 4 dereferenceable(4){{%.*}}) #0 align 2
// LLVM: {{.*}}i32 @"_ZZ2g3vENK3$_0clERKi"(ptr {{%.*}}, ptr {{%.*}})
// COM: LLVM-SAME: (ptr noundef nonnull align 1 dereferenceable(1) {{%.*}},
// COM: LLVM-SAME: ptr noundef nonnull align 4 dereferenceable(4){{%.*}}) #0 align 2

// lambda __invoke()
// LLVM: {{.*}}i32 @"_ZZ2g3vEN3$_08__invokeERKi"(ptr [[i:%.*]])
// LLVM-LABEL: _ZZ2g3vEN3$_08__invokeERKi
// LLVM-SAME: (ptr [[i:%.*]])
// LLVM: [[i_addr:%.*]] = alloca ptr, i64 1, align 8
// LLVM: [[ret_val:%.*]] = alloca i32, i64 1, align 4
// LLVM: [[unused_capture:%.*]] = alloca %class.anon.5, i64 1, align 1
Expand Down Expand Up @@ -285,3 +291,104 @@ int g3() {
// LLVM: store i32 [[tmp2]], ptr [[ret_val]], align 4
// LLVM: [[tmp3:%.*]] = load i32, ptr [[ret_val]], align 4
// LLVM: ret i32 [[tmp3]]

struct A {
int a = 111;
int foo() { return [*this] { return a; }(); }
int bar() { return [this] { return a; }(); }
};
// A's default ctor
// CHECK-LABEL: _ZN1AC1Ev

// lambda operator() in foo()
// CHECK-LABEL: _ZZN1A3fooEvENKUlvE_clEv
// CHECK-SAME: ([[ARG:%.*]]: !cir.ptr<!ty_anon2E7_>
// CHECK: [[ARG_ADDR:%.*]] = cir.alloca !cir.ptr<!ty_anon2E7_>, !cir.ptr<!cir.ptr<!ty_anon2E7_>>, ["this", init] {alignment = 8 : i64}
// CHECK: [[RETVAL_ADDR:%.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
// CHECK: cir.store [[ARG]], [[ARG_ADDR]] : !cir.ptr<!ty_anon2E7_>, !cir.ptr<!cir.ptr<!ty_anon2E7_>>
// CHECK: [[CLS_ANNO7:%.*]] = cir.load [[ARG_ADDR]] : !cir.ptr<!cir.ptr<!ty_anon2E7_>>, !cir.ptr<!ty_anon2E7_>
// CHECK: [[STRUCT_A:%.*]] = cir.get_member [[CLS_ANNO7]][0] {name = "this"} : !cir.ptr<!ty_anon2E7_> -> !cir.ptr<!ty_A>
// CHECK: [[a:%.*]] = cir.get_member [[STRUCT_A]][0] {name = "a"} : !cir.ptr<!ty_A> -> !cir.ptr<!s32i> loc(#loc70)
// CHECK: [[TMP0:%.*]] = cir.load [[a]] : !cir.ptr<!s32i>, !s32i
// CHECK: cir.store [[TMP0]], [[RETVAL_ADDR]] : !s32i, !cir.ptr<!s32i>
// CHECK: [[RET_VAL:%.*]] = cir.load [[RETVAL_ADDR]] : !cir.ptr<!s32i>,
// CHECK: cir.return [[RET_VAL]] : !s32i

// LLVM-LABEL: @_ZZN1A3fooEvENKUlvE_clEv
// LLVM-SAME: (ptr [[ARG:%.*]])
// LLVM: [[ARG_ADDR:%.*]] = alloca ptr, i64 1, align 8
// LLVM: [[RET:%.*]] = alloca i32, i64 1, align 4
// LLVM: store ptr [[ARG]], ptr [[ARG_ADDR]], align 8
// LLVM: [[CLS_ANNO7:%.*]] = load ptr, ptr [[ARG_ADDR]], align 8
// LLVM: [[STRUCT_A:%.*]] = getelementptr %class.anon.7, ptr [[CLS_ANNO7]], i32 0, i32 0
// LLVM: [[a:%.*]] = getelementptr %struct.A, ptr [[STRUCT_A]], i32 0, i32 0
// LLVM: [[TMP0:%.*]] = load i32, ptr [[a]], align 4
// LLVM: store i32 [[TMP0]], ptr [[RET]], align 4
// LLVM: [[TMP1:%.*]] = load i32, ptr [[RET]], align 4
// LLVM: ret i32 [[TMP1]]

// A::foo()
// CHECK-LABEL: @_ZN1A3fooEv
// CHECK: [[THIS_ARG:%.*]] = cir.alloca !ty_anon2E7_, !cir.ptr<!ty_anon2E7_>, ["ref.tmp0"] {alignment = 4 : i64}
// CHECK: cir.call @_ZZN1A3fooEvENKUlvE_clEv([[THIS_ARG]]) : (!cir.ptr<!ty_anon2E7_>) -> !s32i

// LLVM-LABEL: _ZN1A3fooEv
// LLVM: [[this_in_foo:%.*]] = alloca %class.anon.7, i64 1, align 4
// LLVM: call i32 @_ZZN1A3fooEvENKUlvE_clEv(ptr [[this_in_foo]])

// lambda operator() in bar()
// CHECK-LABEL: _ZZN1A3barEvENKUlvE_clEv
// CHECK-SAME: ([[ARG2:%.*]]: !cir.ptr<!ty_anon2E8_>
// CHECK: [[ARG2_ADDR:%.*]] = cir.alloca !cir.ptr<!ty_anon2E8_>, !cir.ptr<!cir.ptr<!ty_anon2E8_>>, ["this", init] {alignment = 8 : i64}
// CHECK: [[RETVAL_ADDR:%.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
// CHECK: cir.store [[ARG2]], [[ARG2_ADDR]] : !cir.ptr<!ty_anon2E8_>, !cir.ptr<!cir.ptr<!ty_anon2E8_>>
// CHECK: [[CLS_ANNO8:%.*]] = cir.load [[ARG2_ADDR]] : !cir.ptr<!cir.ptr<!ty_anon2E8_>>, !cir.ptr<!ty_anon2E8_>
// CHECK: [[STRUCT_A_PTR:%.*]] = cir.get_member [[CLS_ANNO8]][0] {name = "this"} : !cir.ptr<!ty_anon2E8_> -> !cir.ptr<!cir.ptr<!ty_A>>
// CHECK: [[STRUCT_A:%.*]] = cir.load [[STRUCT_A_PTR]] : !cir.ptr<!cir.ptr<!ty_A>>, !cir.ptr<!ty_A>
// CHECK: [[a:%.*]] = cir.get_member [[STRUCT_A]][0] {name = "a"} : !cir.ptr<!ty_A> -> !cir.ptr<!s32i> loc(#loc70)
// CHECK: [[TMP0:%.*]] = cir.load [[a]] : !cir.ptr<!s32i>, !s32i
// CHECK: cir.store [[TMP0]], [[RETVAL_ADDR]] : !s32i, !cir.ptr<!s32i>
// CHECK: [[RET_VAL:%.*]] = cir.load [[RETVAL_ADDR]] : !cir.ptr<!s32i>
// CHECK: cir.return [[RET_VAL]] : !s32i

// LLVM-LABEL: _ZZN1A3barEvENKUlvE_clEv
// LLVM-SAME: (ptr [[ARG2:%.*]])
// LLVM: [[ARG2_ADDR:%.*]] = alloca ptr, i64 1, align 8
// LLVM: [[RET:%.*]] = alloca i32, i64 1, align 4
// LLVM: store ptr [[ARG2]], ptr [[ARG2_ADDR]], align 8
// LLVM: [[CLS_ANNO8:%.*]] = load ptr, ptr [[ARG2_ADDR]], align 8
// LLVM: [[STRUCT_A_PTR:%.*]] = getelementptr %class.anon.8, ptr [[CLS_ANNO8]], i32 0, i32 0
// LLVM: [[STRUCT_A:%.*]] = load ptr, ptr [[STRUCT_A_PTR]], align 8
// LLVM: [[a:%.*]] = getelementptr %struct.A, ptr [[STRUCT_A]], i32
// LLVM: [[TMP0:%.*]] = load i32, ptr [[a]], align 4
// LLVM: store i32 [[TMP0]], ptr [[RET]], align 4
// LLVM: [[TMP1:%.*]] = load i32, ptr [[RET]], align 4
// LLVM: ret i32 [[TMP1]]

// A::bar()
// CHECK-LABEL: _ZN1A3barEv
// CHECK: [[THIS_ARG:%.*]] = cir.alloca !ty_anon2E8_, !cir.ptr<!ty_anon2E8_>, ["ref.tmp0"] {alignment = 8 : i64}
// CHECK: cir.call @_ZZN1A3barEvENKUlvE_clEv([[THIS_ARG]])

// LLVM-LABEL: _ZN1A3barEv
// LLVM: [[this_in_bar:%.*]] = alloca %class.anon.8, i64 1, align 8
// LLVM: call i32 @_ZZN1A3barEvENKUlvE_clEv(ptr [[this_in_bar]])

int test_lambda_this1(){
struct A clsA;
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
int x = clsA.foo();
int y = clsA.bar();
return x+y;
}

// CHECK-LABEL: test_lambda_this1
// Construct A
// CHECK: cir.call @_ZN1AC1Ev([[A_THIS:%.*]]) : (!cir.ptr<!ty_A>) -> ()
// CHECK: cir.call @_ZN1A3fooEv([[A_THIS]]) : (!cir.ptr<!ty_A>) -> !s32i
// CHECK: cir.call @_ZN1A3barEv([[A_THIS]]) : (!cir.ptr<!ty_A>) -> !s32i

// LLVM-LABEL: test_lambda_this1
// LLVM: [[A_THIS:%.*]] = alloca %struct.A, i64 1, align 4
// LLVM: call void @_ZN1AC1Ev(ptr [[A_THIS]])
// LLVM: call i32 @_ZN1A3fooEv(ptr [[A_THIS]])
// LLVM: call i32 @_ZN1A3barEv(ptr [[A_THIS]])
Loading