34
34
#include " cudaq/platform/qpu.h"
35
35
#include " cudaq/platform/quantum_platform.h"
36
36
#include " nvqpp_config.h"
37
+ #include " llvm/ADT/SmallSet.h"
37
38
#include " llvm/Bitcode/BitcodeReader.h"
38
39
#include " llvm/Bitcode/BitcodeWriter.h"
39
40
#include " llvm/IR/Module.h"
@@ -444,6 +445,87 @@ class BaseRemoteRESTQPU : public cudaq::QPU {
444
445
auto func = m_module.lookupSymbol <mlir::func::FuncOp>(
445
446
std::string (cudaq::runtime::cudaqGenPrefixName) + kernelName);
446
447
448
+ llvm::SmallVector<mlir::func::FuncOp> newFuncOpsWithDefinitions;
449
+ llvm::SmallSet<std::string, 4 > deviceCallCallees;
450
+ // For every declaration without a definition, we need to try to find the
451
+ // function in the Quake registry and copy the functions into this module.
452
+ m_module.walk ([&](mlir::func::FuncOp funcOp) {
453
+ if (!funcOp.isDeclaration ()) {
454
+ // Skipping function because it already has a definition.
455
+ return mlir::WalkResult::advance ();
456
+ }
457
+ // Definition doesn't exist, so we need to find it in the Quake registry.
458
+ mlir::StringRef fullFuncName = funcOp.getName ();
459
+ mlir::StringRef kernelName = [fullFuncName]() {
460
+ mlir::StringRef retVal = fullFuncName;
461
+ // TODO - clean this up to not have to do this. Considering the module's
462
+ // map, or cudaq::details::getKernelName(). But make sure it works for
463
+ // standard C++ functions.
464
+
465
+ // Only get the portion before the first ".".
466
+ if (auto ix = fullFuncName.find (" ." ); ix != mlir::StringRef::npos)
467
+ retVal = fullFuncName.substr (0 , ix);
468
+ // Also strip out __nvqpp_mlirgen__function_ from the beginning of the
469
+ // function name.
470
+ if (retVal.starts_with (cudaq::runtime::cudaqGenPrefixName)) {
471
+ retVal = retVal.substr (cudaq::runtime::cudaqGenPrefixLength);
472
+ }
473
+ return retVal;
474
+ }();
475
+ std::string quakeCode =
476
+ kernelName.empty ()
477
+ ? " "
478
+ : cudaq::get_quake_by_name (kernelName.str (),
479
+ /* throwException=*/ false );
480
+ if (quakeCode.empty ()) {
481
+ // Skipping function because it does not have a quake code.
482
+ return mlir::WalkResult::advance ();
483
+ }
484
+ auto tmp_module =
485
+ parseSourceString<mlir::ModuleOp>(quakeCode, contextPtr);
486
+ auto tmpFuncOpWithDefinition =
487
+ tmp_module->lookupSymbol <mlir::func::FuncOp>(fullFuncName);
488
+ auto newNameAttr = mlir::StringAttr::get (m_module.getContext (),
489
+ fullFuncName.str () + " .stitch" );
490
+ auto clonedFunc = tmpFuncOpWithDefinition.clone ();
491
+ clonedFunc.setName (newNameAttr);
492
+ mlir::SymbolTable symTable (m_module);
493
+ symTable.insert (clonedFunc);
494
+ newFuncOpsWithDefinitions.push_back (clonedFunc);
495
+
496
+ if (failed (mlir::SymbolTable::replaceAllSymbolUses (
497
+ funcOp.getOperation (), newNameAttr, m_module.getOperation ()))) {
498
+ throw std::runtime_error (
499
+ fmt::format (" Failed to replace symbol uses for function {}" ,
500
+ fullFuncName.str ()));
501
+ }
502
+ return mlir::WalkResult::advance ();
503
+ });
504
+
505
+ // For each one of the added functions, we need to traverse them to find
506
+ // device calls (in order to create declarations for them)
507
+ for (auto &funcOp : newFuncOpsWithDefinitions) {
508
+ mlir::OpBuilder builder (m_module);
509
+ builder.setInsertionPointToStart (m_module.getBody ());
510
+ funcOp.walk ([&](cudaq::cc::DeviceCallOp deviceCall) {
511
+ auto calleeName = deviceCall.getCallee ();
512
+ // If the callee is already in the symbol table, nothing to do.
513
+ if (m_module.lookupSymbol <mlir::func::FuncOp>(calleeName))
514
+ return ;
515
+
516
+ // Otherwise, we need to create a declaration for the callback function.
517
+ auto argTypes = deviceCall.getArgs ().getTypes ();
518
+ auto resTypes = deviceCall.getResultTypes ();
519
+ auto funcType = builder.getFunctionType (argTypes, resTypes);
520
+
521
+ // Create a *declaration* (no body) for the callback function.
522
+ [[maybe_unused]] auto decl = builder.create <mlir::func::FuncOp>(
523
+ deviceCall.getLoc (), calleeName, funcType);
524
+ decl.setPrivate ();
525
+ deviceCallCallees.insert (calleeName.str ());
526
+ });
527
+ }
528
+
447
529
// Create a new Module to clone the function into
448
530
auto location = mlir::FileLineColLoc::get (&context, " <builder>" , 1 , 1 );
449
531
mlir::ImplicitLocOpBuilder builder (location, &context);
@@ -463,7 +545,7 @@ class BaseRemoteRESTQPU : public cudaq::QPU {
463
545
// passes.
464
546
if (auto lfunc = dyn_cast<mlir::func::FuncOp>(op)) {
465
547
bool skip = lfunc.getName ().ends_with (" .thunk" );
466
- if (!skip)
548
+ if (!skip && !deviceCallCallees. contains (lfunc. getName (). str ()) )
467
549
for (auto &entry : mangledNameMap)
468
550
if (lfunc.getName () ==
469
551
cast<mlir::StringAttr>(entry.getValue ()).getValue ()) {
0 commit comments