Skip to content

Commit 373f608

Browse files
authored
Merge branch 'main' into use-zstd-layers
2 parents 7931a92 + ee90fa0 commit 373f608

File tree

30 files changed

+533
-76
lines changed

30 files changed

+533
-76
lines changed

.github/workflows/config/spelling_allowlist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ runtimes
313313
rvalue
314314
scalability
315315
scalable
316+
selectable
316317
sexualized
317318
shifter
318319
shifters

include/cudaq/Optimizer/Builder/Factory.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,23 @@ mlir::Type genArgumentBufferType(mlir::Type ty);
8989
///
9090
/// A kernel signature of
9191
/// ```c++
92-
/// i32_t operator() (i16_t, std::vector<double>, double);
92+
/// i32_t operator() (i16_t, std::vector<double>, double);
9393
/// ```
9494
/// will generate the LLVM struct
9595
/// ```llvm
96-
/// { i16, i64, double, i32 }
96+
/// { i16, i64, double, i32 }
9797
/// ```
9898
/// where the values of the vector argument are passed by value and appended to
9999
/// the end of the struct as a sequence of \e n double values.
100100
///
101101
/// The leading `startingArgIdx` parameters are omitted from the struct.
102+
///
103+
/// NB: It is DEEPLY INCORRECT to add a packed attribute to this data structure
104+
/// and pass it to other APIs, since there is absolutely, positively NO chance
105+
/// that foreign code will be able to decode this buffer correctly. To do so
106+
/// requires information that is erased by the NVQ++ compiler.
102107
cudaq::cc::StructType buildInvokeStructType(mlir::FunctionType funcTy,
103-
std::size_t startingArgIdx = 0,
104-
bool packed = false);
108+
std::size_t startingArgIdx = 0);
105109

106110
/// Return the LLVM-IR dialect type: `[length x i8]`.
107111
inline mlir::Type getStringType(mlir::MLIRContext *ctx, std::size_t length) {

lib/Optimizer/Builder/Factory.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,17 @@ Type factory::genArgumentBufferType(Type ty) {
7878
return genBufferType</*isOutput=*/false>(ty);
7979
}
8080

81-
cudaq::cc::StructType factory::buildInvokeStructType(FunctionType funcTy,
82-
std::size_t startingArgIdx,
83-
bool packed) {
81+
cudaq::cc::StructType
82+
factory::buildInvokeStructType(FunctionType funcTy,
83+
std::size_t startingArgIdx) {
8484
auto *ctx = funcTy.getContext();
8585
SmallVector<Type> eleTys;
8686
for (auto inTy : llvm::enumerate(funcTy.getInputs()))
8787
if (inTy.index() >= startingArgIdx)
8888
eleTys.push_back(genBufferType</*isOutput=*/false>(inTy.value()));
8989
for (auto outTy : funcTy.getResults())
9090
eleTys.push_back(genBufferType</*isOutput=*/true>(outTy));
91-
return cudaq::cc::StructType::get(ctx, eleTys, packed);
91+
return cudaq::cc::StructType::get(ctx, eleTys, /*packed=*/false);
9292
}
9393

9494
Value factory::packIsArrayAndLengthArray(Location loc,

lib/Optimizer/CodeGen/VerifyQIRProfile.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ struct VerifyQIRProfilePass
4646
if (!func->hasAttr(cudaq::entryPointAttrName))
4747
return;
4848
auto *ctx = &getContext();
49-
bool isBaseProfile = convertTo.getValue() == "qir-base";
49+
const bool isBaseProfile = convertTo.getValue() == "qir-base";
5050
func.walk([&](Operation *op) {
5151
if (auto call = dyn_cast<LLVM::CallOp>(op)) {
5252
auto funcNameAttr = call.getCalleeAttr();
5353
if (!funcNameAttr)
5454
return WalkResult::advance();
5555
auto funcName = funcNameAttr.getValue();
56-
if (!funcName.startswith("__quantum_") ||
57-
funcName == cudaq::opt::QIRCustomOp) {
56+
if (isBaseProfile && (!funcName.startswith("__quantum_") ||
57+
funcName.equals(cudaq::opt::QIRCustomOp))) {
5858
call.emitOpError("unexpected call in QIR base profile");
5959
passFailed = true;
6060
return WalkResult::advance();

runtime/common/BaseRemoteRESTQPU.h

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "cudaq/platform/qpu.h"
3535
#include "cudaq/platform/quantum_platform.h"
3636
#include "nvqpp_config.h"
37+
#include "llvm/ADT/SmallSet.h"
3738
#include "llvm/Bitcode/BitcodeReader.h"
3839
#include "llvm/Bitcode/BitcodeWriter.h"
3940
#include "llvm/IR/Module.h"
@@ -444,6 +445,87 @@ class BaseRemoteRESTQPU : public cudaq::QPU {
444445
auto func = m_module.lookupSymbol<mlir::func::FuncOp>(
445446
std::string(cudaq::runtime::cudaqGenPrefixName) + kernelName);
446447

448+
llvm::SmallVector<mlir::func::FuncOp> newFuncOpsWithDefinitions;
449+
llvm::SmallSet<std::string, 4> deviceCallCallees;
450+
// For every declaration without a definition, we need to try to find the
451+
// function in the Quake registry and copy the functions into this module.
452+
m_module.walk([&](mlir::func::FuncOp funcOp) {
453+
if (!funcOp.isDeclaration()) {
454+
// Skipping function because it already has a definition.
455+
return mlir::WalkResult::advance();
456+
}
457+
// Definition doesn't exist, so we need to find it in the Quake registry.
458+
mlir::StringRef fullFuncName = funcOp.getName();
459+
mlir::StringRef kernelName = [fullFuncName]() {
460+
mlir::StringRef retVal = fullFuncName;
461+
// TODO - clean this up to not have to do this. Considering the module's
462+
// map, or cudaq::details::getKernelName(). But make sure it works for
463+
// standard C++ functions.
464+
465+
// Only get the portion before the first ".".
466+
if (auto ix = fullFuncName.find("."); ix != mlir::StringRef::npos)
467+
retVal = fullFuncName.substr(0, ix);
468+
// Also strip out __nvqpp_mlirgen__function_ from the beginning of the
469+
// function name.
470+
if (retVal.starts_with(cudaq::runtime::cudaqGenPrefixName)) {
471+
retVal = retVal.substr(cudaq::runtime::cudaqGenPrefixLength);
472+
}
473+
return retVal;
474+
}();
475+
std::string quakeCode =
476+
kernelName.empty()
477+
? ""
478+
: cudaq::get_quake_by_name(kernelName.str(),
479+
/*throwException=*/false);
480+
if (quakeCode.empty()) {
481+
// Skipping function because no Quake code was found for it.
482+
return mlir::WalkResult::advance();
483+
}
484+
auto tmp_module =
485+
parseSourceString<mlir::ModuleOp>(quakeCode, contextPtr);
486+
auto tmpFuncOpWithDefinition =
487+
tmp_module->lookupSymbol<mlir::func::FuncOp>(fullFuncName);
488+
auto newNameAttr = mlir::StringAttr::get(m_module.getContext(),
489+
fullFuncName.str() + ".stitch");
490+
auto clonedFunc = tmpFuncOpWithDefinition.clone();
491+
clonedFunc.setName(newNameAttr);
492+
mlir::SymbolTable symTable(m_module);
493+
symTable.insert(clonedFunc);
494+
newFuncOpsWithDefinitions.push_back(clonedFunc);
495+
496+
if (failed(mlir::SymbolTable::replaceAllSymbolUses(
497+
funcOp.getOperation(), newNameAttr, m_module.getOperation()))) {
498+
throw std::runtime_error(
499+
fmt::format("Failed to replace symbol uses for function {}",
500+
fullFuncName.str()));
501+
}
502+
return mlir::WalkResult::advance();
503+
});
504+
505+
// Traverse each of the added functions to find its device calls, so that
506+
// declarations can be created for the callees.
507+
for (auto &funcOp : newFuncOpsWithDefinitions) {
508+
mlir::OpBuilder builder(m_module);
509+
builder.setInsertionPointToStart(m_module.getBody());
510+
funcOp.walk([&](cudaq::cc::DeviceCallOp deviceCall) {
511+
auto calleeName = deviceCall.getCallee();
512+
// If the callee is already in the symbol table, nothing to do.
513+
if (m_module.lookupSymbol<mlir::func::FuncOp>(calleeName))
514+
return;
515+
516+
// Otherwise, we need to create a declaration for the callback function.
517+
auto argTypes = deviceCall.getArgs().getTypes();
518+
auto resTypes = deviceCall.getResultTypes();
519+
auto funcType = builder.getFunctionType(argTypes, resTypes);
520+
521+
// Create a *declaration* (no body) for the callback function.
522+
[[maybe_unused]] auto decl = builder.create<mlir::func::FuncOp>(
523+
deviceCall.getLoc(), calleeName, funcType);
524+
decl.setPrivate();
525+
deviceCallCallees.insert(calleeName.str());
526+
});
527+
}
528+
447529
// Create a new Module to clone the function into
448530
auto location = mlir::FileLineColLoc::get(&context, "<builder>", 1, 1);
449531
mlir::ImplicitLocOpBuilder builder(location, &context);
@@ -463,7 +545,7 @@ class BaseRemoteRESTQPU : public cudaq::QPU {
463545
// passes.
464546
if (auto lfunc = dyn_cast<mlir::func::FuncOp>(op)) {
465547
bool skip = lfunc.getName().ends_with(".thunk");
466-
if (!skip)
548+
if (!skip && !deviceCallCallees.contains(lfunc.getName().str()))
467549
for (auto &entry : mangledNameMap)
468550
if (lfunc.getName() ==
469551
cast<mlir::StringAttr>(entry.getValue()).getValue()) {

runtime/common/RuntimeMLIR.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
******************************************************************************/
88

99
#include "RuntimeMLIR.h"
10+
#include "ThunkInterface.h"
1011
#include "cudaq/Optimizer/Builder/Runtime.h"
1112
#include "cudaq/Optimizer/CodeGen/IQMJsonEmitter.h"
1213
#include "cudaq/Optimizer/CodeGen/OpenQASMEmitter.h"
@@ -16,6 +17,7 @@
1617
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
1718
#include "cudaq/Optimizer/InitAllDialects.h"
1819
#include "cudaq/Optimizer/InitAllPasses.h"
20+
#include "cudaq/Support/TargetConfig.h"
1921
#include "llvm/Bitcode/BitcodeWriter.h"
2022
#include "llvm/IR/Instructions.h"
2123
#include "llvm/MC/SubtargetFeature.h"
@@ -93,4 +95,22 @@ std::unique_ptr<MLIRContext> initializeMLIR() {
9395
return context;
9496
}
9597

98+
std::optional<std::string> getEntryPointName(OwningOpRef<ModuleOp> &module) {
99+
std::string name;
100+
// FIXME: avoid using a recursive `walk` (a DFS) to find FuncOps; they appear
101+
// as direct children of the Module.
102+
module->walk([&name](mlir::func::FuncOp op) {
103+
if (op.getName().endswith(".thunk")) {
104+
name = op.getName();
105+
return mlir::WalkResult::interrupt();
106+
}
107+
return mlir::WalkResult::advance();
108+
});
109+
110+
if (!name.empty())
111+
return name;
112+
113+
return std::nullopt;
114+
}
115+
96116
} // namespace cudaq

0 commit comments

Comments
 (0)