Skip to content

Commit

Permalink
Add a few state-related cc ops (#2354)
Browse files Browse the repository at this point in the history
* Add a few state-related cc ops

Signed-off-by: Anna Gringauze <[email protected]>

* Fix test_argument_conversion

Signed-off-by: Anna Gringauze <[email protected]>

* Add printing in failing tests

Signed-off-by: Anna Gringauze <[email protected]>

* Add printing in failing tests

Signed-off-by: Anna Gringauze <[email protected]>

* Fix failing tests

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

* Fix comment

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

---------

Signed-off-by: Anna Gringauze <[email protected]>
  • Loading branch information
annagrin authored Jan 8, 2025
1 parent c5c9361 commit d2bed4c
Show file tree
Hide file tree
Showing 18 changed files with 323 additions and 211 deletions.
1 change: 1 addition & 0 deletions .github/workflows/config/spelling_allowlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ SLURM
SVD
Stim
Superpositions
Superstaq
TBI
TCP
TableGen
Expand Down
14 changes: 7 additions & 7 deletions docs/sphinx/targets/cpp/infleqtion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// ```
// nvq++ --target infleqtion infleqtion.cpp -o out.x && ./out.x
// ```
// This will submit the job to the Infleqtion's ideal simulator,
// cq_sqale_simulator (default). Alternatively, we can enable hardware noise
// This will submit the job to the ideal simulator for Infleqtion,
// `cq_sqale_simulator` (default). Alternatively, we can enable hardware noise
// model simulation by specifying `noise-sim` to the flag `--infleqtion-method`,
// e.g.,
// ```
Expand All @@ -17,9 +17,9 @@
// nvq++ --target infleqtion --infleqtion-machine cq_sqale_qpu
// --infleqtion-method dry-run infleqtion.cpp -o out.x && ./out.x
// ```
// Note: If targeting ideal cloud simulation, `--infleqtion-machine
// cq_sqale_simulator` is optional since it is the default configuration if not
// provided.
// Note: If targeting ideal cloud simulation,
// `--infleqtion-machine cq_sqale_simulator` is optional since it is the
// default configuration if not provided.

#include <cudaq.h>
#include <fstream>
Expand All @@ -38,7 +38,7 @@ struct ghz {
};

int main() {
// Submit to infleqtion asynchronously (e.g., continue executing
// Submit to Infleqtion asynchronously (e.g., continue executing
// code in the file until the job has been returned).
auto future = cudaq::sample_async(ghz{});
// ... classical code to execute in the meantime ...
Expand All @@ -58,7 +58,7 @@ int main() {
auto async_counts = readIn.get();
async_counts.dump();

// OR: Submit to infleqtion synchronously (e.g., wait for the job
// OR: Submit to Infleqtion synchronously (e.g., wait for the job
// result to be returned before proceeding).
auto counts = cudaq::sample(ghz{});
counts.dump();
Expand Down
86 changes: 86 additions & 0 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1397,4 +1397,90 @@ def CustomUnitarySymbolOp :
}];
}

//===----------------------------------------------------------------------===//
// Quantum states
//===----------------------------------------------------------------------===//

def quake_CreateStateOp : QuakeOp<"create_state", [Pure]> {
let summary = "Create state from data";
let description = [{
This operation takes a pointer to state data and creates a quantum state,
where state data is a pointer to an array of float or complex numbers.
The operation can be optimized away in DeleteStates pass, or replaced
by an intrinsic runtime call on simulators.

```mlir
%0 = quake.create_state %data %len : (!cc.ptr<!cc.array<complex<f64> x 8>>, i64) -> !cc.ptr<!cc.state>
```
}];

let arguments = (ins
cc_PointerType:$data,
AnySignlessInteger:$length
);
let results = (outs cc_PointerType:$result);
let assemblyFormat = [{
$data `,` $length `:` functional-type(operands, results) attr-dict
}];
}

def QuakeOp_DeleteStateOp : QuakeOp<"delete_state", [] > {
let summary = "Delete quantum state";
let description = [{
This operation takes a pointer to the state and deletes the state object.
The operation can be created in in DeleteStates pass, and replaced later
by an intrinsic runtime call on simulators.

```mlir
quake.delete_state %state : !cc.ptr<!cc.state>
```
}];

let arguments = (ins cc_PointerType:$state);
let results = (outs);
let assemblyFormat = [{
$state `:` type(operands) attr-dict
}];
}

def quake_GetNumberOfQubitsOp : QuakeOp<"get_number_of_qubits", [Pure] > {
let summary = "Get number of qubits from a quantum state";
let description = [{
This operation takes a pointer to the state as an argument and returns
a number of qubits in the state. The operation can be optimized away in
some passes like ReplaceStateByKernel or DeleteStates, or replaced by an
intrinsic runtime call when the target is one of the simulators.

```mlir
%0 = quake.get_number_of_qubits %state : (!cc.ptr<!cc.state>) -> i64
```
}];

let arguments = (ins cc_PointerType:$state);
let results = (outs AnySignlessInteger:$result);
let assemblyFormat = [{
$state `:` functional-type(operands, results) attr-dict
}];
}

def QuakeOp_GetStateOp : QuakeOp<"get_state", [Pure] > {
let summary = "Get state from kernel with the provided name.";
let description = [{
This operation is created by argument synthesis of state pointer arguments
for quantum devices. It takes a kernel name as ASCIIZ string literal value
and returns the kernel's quantum state. The operation is replaced by a call
to the kernel with the provided name in ReplaceStateByKernel pass.

```mlir
%0 = quake.get_state "callee" : !cc.ptr<!cc.state>
```
}];

let arguments = (ins StrAttr:$calleeName);
let results = (outs cc_PointerType:$result);
let assemblyFormat = [{
$calleeName `:` qualified(type(results)) attr-dict
}];
}

#endif // CUDAQ_OPTIMIZER_DIALECT_QUAKE_OPS
5 changes: 2 additions & 3 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -778,9 +778,8 @@ def DeleteStates : Pass<"delete-states", "mlir::ModuleOp"> {
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
%c8_i64 = arith.constant 8 : i64
%0 = cc.address_of @foo.rodata_synth_0 : !cc.ptr<!cc.array<complex<f32> x 8>>
%3 = cc.cast %0 : (!cc.ptr<!cc.array<complex<f32> x 8>>) -> !cc.ptr<i8>
%4 = call @__nvqpp_cudaq_state_createFromData_fp32(%3, %c8_i64) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
%5 = call @__nvqpp_cudaq_state_numberOfQubits(%4) : (!cc.ptr<!cc.state>) -> i64
%4 = cc.create_state %3, %c8_i64 : (!cc.ptr<!cc.array<complex<f32> x 8>>, i64) -> !cc.ptr<!cc.state>
%5 = cc.get_number_of_qubits %4 : (!cc.ptr<!cc.state>) -> i64
%6 = quake.alloca !quake.veq<?>[%5 : i64]
%7 = quake.init_state %6, %4 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>

Expand Down
13 changes: 3 additions & 10 deletions lib/Frontend/nvqpp/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2698,19 +2698,12 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
initials = load.getPtrvalue();
}
if (isStateType(initials.getType())) {
IRBuilder irBuilder(builder.getContext());
auto mod =
builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
auto result =
irBuilder.loadIntrinsic(mod, getNumQubitsFromCudaqState);
assert(succeeded(result) && "loading intrinsic should never fail");
Value state = initials;
auto i64Ty = builder.getI64Type();
auto numQubits = builder.create<func::CallOp>(
loc, i64Ty, getNumQubitsFromCudaqState, ValueRange{state});
auto numQubits =
builder.create<quake::GetNumberOfQubitsOp>(loc, i64Ty, state);
auto veqTy = quake::VeqType::getUnsized(ctx);
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy,
numQubits.getResult(0));
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
return pushValue(builder.create<quake::InitializeStateOp>(
loc, veqTy, alloc, state));
}
Expand Down
90 changes: 89 additions & 1 deletion lib/Optimizer/CodeGen/QuakeToCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#include "QuakeToCodegen.h"
#include "CodeGenOps.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/CodeGen/Passes.h"
#include "cudaq/Optimizer/CodeGen/QIRFunctionNames.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
Expand Down Expand Up @@ -62,10 +65,95 @@ class ExpandComplexCast : public OpRewritePattern<cudaq::cc::CastOp> {
return success();
}
};

class CreateStateOpPattern : public OpRewritePattern<quake::CreateStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::CreateStateOp createStateOp,
PatternRewriter &rewriter) const override {
auto module = createStateOp->getParentOfType<ModuleOp>();
auto loc = createStateOp.getLoc();
auto ctx = createStateOp.getContext();
auto buffer = createStateOp.getOperand(0);
auto size = createStateOp.getOperand(1);

auto bufferTy = buffer.getType();
auto ptrTy = cast<cudaq::cc::PointerType>(bufferTy);
auto arrTy = cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
auto eleTy = arrTy.getElementType();
auto is64Bit = isa<Float64Type>(eleTy);

if (auto cTy = dyn_cast<ComplexType>(eleTy))
is64Bit = isa<Float64Type>(cTy.getElementType());

auto createStateFunc = is64Bit ? cudaq::createCudaqStateFromDataFP64
: cudaq::createCudaqStateFromDataFP32;
cudaq::IRBuilder irBuilder(ctx);
auto result = irBuilder.loadIntrinsic(module, createStateFunc);
assert(succeeded(result) && "loading intrinsic should never fail");

auto stateTy = cudaq::cc::StateType::get(ctx);
auto statePtrTy = cudaq::cc::PointerType::get(stateTy);
auto i8PtrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
auto cast = rewriter.create<cudaq::cc::CastOp>(loc, i8PtrTy, buffer);

rewriter.replaceOpWithNewOp<func::CallOp>(
createStateOp, statePtrTy, createStateFunc, ValueRange{cast, size});
return success();
}
};

class DeleteStateOpPattern : public OpRewritePattern<quake::DeleteStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::DeleteStateOp deleteStateOp,
PatternRewriter &rewriter) const override {
auto module = deleteStateOp->getParentOfType<ModuleOp>();
auto ctx = deleteStateOp.getContext();
auto state = deleteStateOp.getOperand();

cudaq::IRBuilder irBuilder(ctx);
auto result = irBuilder.loadIntrinsic(module, cudaq::deleteCudaqState);
assert(succeeded(result) && "loading intrinsic should never fail");

rewriter.replaceOpWithNewOp<func::CallOp>(deleteStateOp, std::nullopt,
cudaq::deleteCudaqState,
mlir::ValueRange{state});
return success();
}
};

class GetNumberOfQubitsOpPattern
: public OpRewritePattern<quake::GetNumberOfQubitsOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::GetNumberOfQubitsOp getNumQubitsOp,
PatternRewriter &rewriter) const override {
auto module = getNumQubitsOp->getParentOfType<ModuleOp>();
auto ctx = getNumQubitsOp.getContext();
auto state = getNumQubitsOp.getOperand();

cudaq::IRBuilder irBuilder(ctx);
auto result =
irBuilder.loadIntrinsic(module, cudaq::getNumQubitsFromCudaqState);
assert(succeeded(result) && "loading intrinsic should never fail");

rewriter.replaceOpWithNewOp<func::CallOp>(
getNumQubitsOp, rewriter.getI64Type(),
cudaq::getNumQubitsFromCudaqState, state);
return success();
}
};

} // namespace

void cudaq::codegen::populateQuakeToCodegenPatterns(
mlir::RewritePatternSet &patterns) {
auto *ctx = patterns.getContext();
patterns.insert<CodeGenRAIIPattern, ExpandComplexCast>(ctx);
patterns
.insert<CodeGenRAIIPattern, CreateStateOpPattern, DeleteStateOpPattern,
ExpandComplexCast, GetNumberOfQubitsOpPattern>(ctx);
}
Loading

0 comments on commit d2bed4c

Please sign in to comment.