15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Support/LogicalResult.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#define GEN_PASS_DEF_CIRFLATTENCFG
31#include "clang/CIR/Dialect/Passes.h.inc"
// Replaces the terminator `op` with an unconditional cir.br to `dest`.
//   op       - operation to replace; asserted to carry the IsTerminator trait.
//   dest     - block the new branch targets.
//   rewriter - pattern rewriter used to perform the replacement.
// An InsertionGuard is taken first so that moving the insertion point to `op`
// does not leak out to the caller.
37void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
38 mlir::PatternRewriter &rewriter) {
39 assert(op->hasTrait<mlir::OpTrait::IsTerminator>() &&
"not a terminator");
40 mlir::OpBuilder::InsertionGuard guard(rewriter);
41 rewriter.setInsertionPoint(op);
42 rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
// Walks `region` in pre-order, visiting each mlir::Operation* with a lambda
// that can return WalkResult::skip().
// NOTE(review): this excerpt elides the `region` parameter line, the test that
// guards the skip() return, and the call to `callback`. Presumably operations
// matching one of `Ops...` are skipped (their nested regions are not visited)
// and every other operation is forwarded to `callback` — confirm against the
// full file.
47template <
typename... Ops>
48void walkRegionSkipping(
50 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
51 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
53 return mlir::WalkResult::skip();
// Pass object for CFG flattening. It derives from impl::CIRFlattenCFGBase
// (CRTP on this struct), has a defaulted constructor, and overrides
// runOnOperation(), which is declared here and defined out of line.
58struct CIRFlattenCFGPass :
public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {
60 CIRFlattenCFGPass() =
default;
61 void runOnOperation()
override;
// Rewrite pattern for cir.if: the then/else regions are inlined into the
// enclosing region and wired together with cir.brcond / cir.br.
// NOTE(review): several original lines (return type, `else` keywords, closing
// braces, trailing call arguments) are not visible in this excerpt; comments
// below describe only what is shown.
64struct CIRIfFlattening :
public mlir::OpRewritePattern<cir::IfOp> {
65 using OpRewritePattern<IfOp>::OpRewritePattern;
68 matchAndRewrite(cir::IfOp ifOp,
69 mlir::PatternRewriter &rewriter)
const override {
70 mlir::OpBuilder::InsertionGuard guard(rewriter);
71 mlir::Location loc = ifOp.getLoc();
72 bool emptyElse = ifOp.getElseRegion().empty();
// Split the current block at the insertion point; the tail becomes
// remainingOpsBlock.
73 mlir::Block *currentBlock = rewriter.getInsertionBlock();
74 mlir::Block *remainingOpsBlock =
75 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
// A result-less if continues directly in remainingOpsBlock. The
// llvm_unreachable("NYI") below is presumably the `else` arm for ifs that
// produce results (the `else` line itself is elided) — confirm.
76 mlir::Block *continueBlock;
77 if (ifOp->getResults().empty())
78 continueBlock = remainingOpsBlock;
80 llvm_unreachable(
"NYI");
// Capture the first and last then-blocks before the region is moved in front
// of continueBlock; the pointers stay valid because blocks are moved, not
// copied.
83 mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
84 mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
85 rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);
// A cir.yield terminating the then-part is replaced by a cir.br that forwards
// the yield's arguments (the destination argument is on an elided line).
87 rewriter.setInsertionPointToEnd(thenAfterBody);
88 if (
auto thenYieldOp =
89 dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
90 rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
94 rewriter.setInsertionPointToEnd(continueBlock);
// Else handling. NOTE(review): the condition choosing between "inline the else
// region" and "use continueBlock as the else target" is elided; presumably it
// is keyed on `emptyElse` — confirm.
97 mlir::Block *elseBeforeBody =
nullptr;
98 mlir::Block *elseAfterBody =
nullptr;
100 elseBeforeBody = &ifOp.getElseRegion().front();
101 elseAfterBody = &ifOp.getElseRegion().back();
102 rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
104 elseBeforeBody = elseAfterBody = continueBlock;
// Terminate the original block with a conditional branch on the if condition
// (the remaining operands are on an elided line).
107 rewriter.setInsertionPointToEnd(currentBlock);
108 cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
// A cir.yield terminating the else-part becomes a cir.br to continueBlock that
// forwards the yield's arguments.
112 rewriter.setInsertionPointToEnd(elseAfterBody);
113 if (
auto elseYieldOP =
114 dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
115 rewriter.replaceOpWithNewOp<cir::BrOp>(
116 elseYieldOP, elseYieldOP.getArgs(), continueBlock);
// Uses of the if's results are redirected to continueBlock's block arguments.
120 rewriter.replaceOp(ifOp, continueBlock->getArguments());
121 return mlir::success();
// Rewrite pattern for cir.scope: the scope region is inlined into the parent
// region between the block that contained the scope and a continuation block.
// NOTE(review): some original lines (return type, closing braces, trailing
// call arguments) are not visible in this excerpt.
125class CIRScopeOpFlattening :
public mlir::OpRewritePattern<cir::ScopeOp> {
127 using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;
130 matchAndRewrite(cir::ScopeOp scopeOp,
131 mlir::PatternRewriter &rewriter)
const override {
132 mlir::OpBuilder::InsertionGuard guard(rewriter);
133 mlir::Location loc = scopeOp.getLoc();
// An empty scope is simply erased.
141 if (scopeOp.isEmpty()) {
142 rewriter.eraseOp(scopeOp);
143 return mlir::success();
// Split the current block at the insertion point; the tail becomes
// continueBlock. If the scope has results, continueBlock receives one block
// argument per result type.
148 mlir::Block *currentBlock = rewriter.getInsertionBlock();
149 mlir::Block *continueBlock =
150 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
151 if (scopeOp.getNumResults() > 0)
152 continueBlock->addArguments(scopeOp.getResultTypes(), loc);
// Capture the first and last scope blocks before moving the region in front
// of continueBlock.
155 mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
156 mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
157 rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);
// Jump from the original block into the first inlined block, with no operands.
160 rewriter.setInsertionPointToEnd(currentBlock);
162 cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);
// A cir.yield terminating the last scope block is replaced by a cir.br that
// forwards the yield's arguments (destination argument is on an elided line).
166 rewriter.setInsertionPointToEnd(afterBody);
167 if (
auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
168 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
// Uses of the scope's results are redirected to continueBlock's arguments.
173 rewriter.replaceOp(scopeOp, continueBlock->getArguments());
175 return mlir::success();
// Rewrite pattern for cir.switch: the switch body and each case region are
// inlined into the parent region and the op is replaced by cir.switch.flat
// (cir::SwitchFlatOp), with large case ranges handled by compare-and-branch
// blocks.
// NOTE(review): many original lines (return types, `break`/`continue`,
// `else`, closing braces, trailing call arguments) are not visible in this
// excerpt; comments below describe only what is shown and flag assumptions.
179class CIRSwitchOpFlattening :
public mlir::OpRewritePattern<cir::SwitchOp> {
181 using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;
// Replaces `yieldOp` with a cir.br that forwards the yield's operands; the
// `destination` argument of the call is presumably on the elided line.
183 inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
184 cir::YieldOp yieldOp,
185 mlir::Block *destination)
const {
186 rewriter.setInsertionPoint(yieldOp);
187 rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
// Creates a new block in front of `defaultDestination` that tests whether the
// switch condition lies in [lowerBound, upperBound] and conditionally branches
// to `rangeDestination`. The test is a single unsigned comparison:
//   (unsigned)(cond - lowerBound) <= (unsigned)(upperBound - lowerBound)
// NOTE(review): the constants and the subtraction use hard-coded 32-bit
// integer types; this looks like it assumes a 32-bit switch condition and
// 32-bit APInt bounds — confirm. The false successor of the cir.brcond and the
// return statement are on elided lines (presumably `defaultDestination` and
// `resBlock`).
192 Block *condBrToRangeDestination(cir::SwitchOp op,
193 mlir::PatternRewriter &rewriter,
194 mlir::Block *rangeDestination,
195 mlir::Block *defaultDestination,
196 const APInt &lowerBound,
197 const APInt &upperBound)
const {
198 assert(lowerBound.sle(upperBound) &&
"Invalid range");
199 mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
200 cir::IntType sIntType = cir::IntType::get(op.getContext(), 32,
true);
201 cir::IntType uIntType = cir::IntType::get(op.getContext(), 32,
false);
// rangeLength = upperBound - lowerBound, as a signed constant.
203 cir::ConstantOp rangeLength = cir::ConstantOp::create(
204 rewriter, op.getLoc(),
205 cir::IntAttr::get(sIntType, upperBound - lowerBound));
// diffValue = condition - lowerBound.
207 cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
208 rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
209 cir::BinOp diffValue =
210 cir::BinOp::create(rewriter, op.getLoc(), sIntType, cir::BinOpKind::Sub,
211 op.getCondition(), lowerBoundValue);
// Both operands are cast to the unsigned type before the `le` comparison.
214 cir::CastOp uDiffValue = cir::CastOp::create(
215 rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
216 cir::CastOp uRangeLength = cir::CastOp::create(
217 rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);
219 cir::CmpOp cmpResult = cir::CmpOp::create(
220 rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
221 cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
// Main rewrite entry point for cir.switch.
227 matchAndRewrite(cir::SwitchOp op,
228 mlir::PatternRewriter &rewriter)
const override {
229 llvm::SmallVector<CaseOp> cases;
230 op.collectCases(cases);
// NOTE(review): the guard for this early erase-and-return is elided;
// presumably it fires when no cases were collected — confirm.
234 rewriter.eraseOp(op);
235 return mlir::success();
// Split the current block at the op that follows the switch; everything after
// the switch ends up in exitBlock.
239 mlir::Block *exitBlock = rewriter.splitBlock(
240 rewriter.getBlock(), op->getNextNode()->getIterator());
// Scan the blocks directly in the switch body for one terminated by cir.yield;
// if several match, the last one seen is kept.
253 cir::YieldOp switchYield =
nullptr;
255 for (mlir::Block &block :
256 llvm::make_early_inc_range(op.getBody().getBlocks()))
257 if (
auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
258 switchYield = yieldOp;
// Split again at the switch itself so it starts swopBlock, then move the
// switch body's blocks in front of exitBlock.
260 assert(!op.getBody().empty());
261 mlir::Block *originalBlock = op->getBlock();
262 mlir::Block *swopBlock =
263 rewriter.splitBlock(originalBlock, op->getIterator());
264 rewriter.inlineRegionBefore(op.getBody(), exitBlock);
// Turn the body's yield into a branch to exitBlock.
// NOTE(review): a null check on switchYield may sit on an elided line.
267 rewriteYieldOp(rewriter, switchYield, exitBlock);
// Branch from the original block into the block that now holds the switch.
269 rewriter.setInsertionPointToEnd(originalBlock);
270 cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
// Accumulators for the flat switch: single-value cases, range cases, and the
// default target (initially exitBlock with its block arguments).
275 llvm::SmallVector<mlir::APInt, 8> caseValues;
276 llvm::SmallVector<mlir::Block *, 8> caseDestinations;
277 llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
279 llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
280 llvm::SmallVector<mlir::Block *> rangeDestinations;
281 llvm::SmallVector<mlir::ValueRange> rangeOperands;
284 mlir::Block *defaultDestination = exitBlock;
285 mlir::ValueRange defaultOperands = exitBlock->getArguments();
// Process every collected case: record its destination by kind, lower its
// break/yield terminators, then inline its region.
288 for (cir::CaseOp caseOp : cases) {
289 mlir::Region ®ion = caseOp.getCaseRegion();
292 switch (caseOp.getKind()) {
// Default: the case region's first block becomes the default target.
293 case cir::CaseOpKind::Default:
294 defaultDestination = ®ion.front();
295 defaultOperands = defaultDestination->getArguments();
// Range: exactly two values (lower, upper) are recorded for later expansion.
297 case cir::CaseOpKind::Range:
298 assert(caseOp.getValue().size() == 2 &&
299 "Case range should have 2 case value");
300 rangeValues.push_back(
301 {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
302 cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
303 rangeDestinations.push_back(®ion.front());
304 rangeOperands.push_back(rangeDestinations.back()->getArguments());
// Anyof / Equal: each listed value maps to the case region's first block.
306 case cir::CaseOpKind::Anyof:
307 case cir::CaseOpKind::Equal:
309 for (
const mlir::Attribute &value : caseOp.getValue()) {
310 caseValues.push_back(cast<cir::IntAttr>(value).getValue());
311 caseDestinations.push_back(®ion.front());
312 caseOperands.push_back(caseDestinations.back()->getArguments());
// Lower each cir.break found by the walk into a branch to exitBlock. The walk
// helper is instantiated with LoopOpInterface and SwitchOp, presumably so that
// breaks belonging to nested loops/switches are left alone.
318 walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
319 region, [&](mlir::Operation *op) {
320 if (!isa<cir::BreakOp>(op))
321 return mlir::WalkResult::advance();
323 lowerTerminator(op, exitBlock, rewriter);
324 return mlir::WalkResult::skip();
// Handle case blocks ending in cir.yield: split the parent block right after
// the case op, branch from the old block (the branch target is on an elided
// line), and redirect the yield to the newly split block.
// NOTE(review): the statement guarded by `blk.getNumSuccessors()` is elided;
// presumably blocks that already have successors are skipped — confirm.
328 for (mlir::Block &blk : region.getBlocks()) {
329 if (blk.getNumSuccessors())
332 if (
auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
333 mlir::Operation *nextOp = caseOp->getNextNode();
334 assert(nextOp &&
"caseOp is not expected to be the last op");
335 mlir::Block *oldBlock = nextOp->getBlock();
336 mlir::Block *newBlock =
337 rewriter.splitBlock(oldBlock, nextOp->getIterator());
338 rewriter.setInsertionPointToEnd(oldBlock);
339 cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
341 rewriteYieldOp(rewriter, yieldOp, newBlock);
// Split at the case op, move the case region's blocks in front of the new
// block, and branch from the old block into the case's entry block.
345 mlir::Block *oldBlock = caseOp->getBlock();
346 mlir::Block *newBlock =
347 rewriter.splitBlock(oldBlock, caseOp->getIterator());
349 mlir::Block &entryBlock = caseOp.getCaseRegion().front();
350 rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);
353 rewriter.setInsertionPointToEnd(oldBlock);
354 cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
// Remove the now-empty case ops. A case op whose block has no predecessors has
// its whole block erased; the eraseOp call is presumably the `else` arm (the
// `else` line is elided) — confirm.
358 for (cir::CaseOp caseOp : cases) {
359 mlir::Block *caseBlock = caseOp->getBlock();
362 if (caseBlock->hasNoPredecessors())
363 rewriter.eraseBlock(caseBlock);
365 rewriter.eraseOp(caseOp);
// Expand the recorded range cases.
368 for (
auto [rangeVal, operand, destination] :
369 llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
370 APInt lowerBound = rangeVal.first;
371 APInt upperBound = rangeVal.second;
// NOTE(review): the statement guarded by this inverted-range check is elided;
// presumably such ranges are skipped — confirm.
374 if (lowerBound.sgt(upperBound))
// Ranges whose span (upper - lower) is below the threshold are expanded into
// one flat-switch case per value.
379 constexpr int kSmallRangeThreshold = 64;
380 if ((upperBound - lowerBound)
381 .ult(llvm::APInt(32, kSmallRangeThreshold))) {
382 for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
383 caseValues.push_back(iValue);
384 caseOperands.push_back(operand);
385 caseDestinations.push_back(destination);
// Otherwise build a compare-and-branch block for the range. The assignment
// target of this call is on an elided line (presumably defaultDestination, so
// the range test is chained in front of the previous default) — confirm.
391 condBrToRangeDestination(op, rewriter, destination,
392 defaultDestination, lowerBound, upperBound);
393 defaultOperands = operand;
// Finally replace the structured switch with the flat switch op.
397 rewriter.setInsertionPoint(op);
398 rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
399 op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
400 caseDestinations, caseOperands);
402 return mlir::success();
// Interface rewrite pattern for ops implementing cir::LoopOpInterface: the
// cond/body/(optional) step regions are inlined into the parent region and
// their structured terminators are turned into branches.
// NOTE(review): the declarations of `exit` and `step`, several guards and
// closing braces are on lines not visible in this excerpt. Presumably `exit`
// receives the result of the splitBlock call and `step` the result of the
// maybeGetStep() expression — confirm against the full file.
406class CIRLoopOpInterfaceFlattening
407 :
public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
409 using mlir::OpInterfaceRewritePattern<
410 cir::LoopOpInterface>::OpInterfaceRewritePattern;
// Replaces a cir.condition with a cir.brcond on the same condition value whose
// first successor is `body`; the remaining operands are on an elided line.
412 inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
414 mlir::PatternRewriter &rewriter)
const {
415 mlir::OpBuilder::InsertionGuard guard(rewriter);
416 rewriter.setInsertionPoint(op);
417 rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
// Main rewrite entry point for loop ops.
422 matchAndRewrite(cir::LoopOpInterface op,
423 mlir::PatternRewriter &rewriter)
const final {
// Split the current block at the insertion point and look up the first block
// of each loop region (step is optional, hence the null fallback).
425 mlir::Block *entry = rewriter.getInsertionBlock();
427 rewriter.splitBlock(entry, rewriter.getInsertionPoint());
428 mlir::Block *cond = &op.getCond().front();
429 mlir::Block *body = &op.getBody().front();
431 (op.maybeGetStep() ? &op.maybeGetStep()->front() :
nullptr);
// Branch from the pre-loop block into the first block of the loop's entry
// region, as reported by getEntry().
434 rewriter.setInsertionPointToEnd(entry);
435 cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());
// The condition region's terminator must be a cir.condition (cast<> asserts);
// it is lowered to a conditional branch between `body` and `exit`.
438 auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator());
439 lowerConditionOp(conditionOp, body, exit, rewriter);
// cir.continue ops found by the body walk branch to the step block when there
// is one, otherwise straight back to the condition block.
446 mlir::Block *dest = (
step ?
step : cond);
447 op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
448 if (!isa<cir::ContinueOp>(op))
449 return mlir::WalkResult::advance();
451 lowerTerminator(op, dest, rewriter);
452 return mlir::WalkResult::skip();
// cir.break ops found in the body (walk instantiated with LoopOpInterface)
// branch to `exit`.
457 walkRegionSkipping<cir::LoopOpInterface>(
458 op.getBody(), [&](mlir::Operation *op) {
459 if (!isa<cir::BreakOp>(op))
460 return mlir::WalkResult::advance();
462 lowerTerminator(op, exit, rewriter);
463 return mlir::WalkResult::skip();
// Body blocks terminated by cir.yield branch to step (if present) or cond.
// NOTE(review): the null check on `bodyYield` is presumably on an elided line.
467 for (mlir::Block &blk : op.getBody().getBlocks()) {
468 auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
470 lowerTerminator(bodyYield, (
step ?
step : cond), rewriter);
// The step block's yield branches back to the condition block.
// NOTE(review): the guard checking that `step` is non-null is presumably on an
// elided line.
475 lowerTerminator(cast<cir::YieldOp>(
step->getTerminator()), cond,
// Move the loop regions into the parent region in front of `exit`: cond first,
// then body, then step (the step call's guard is elided).
479 rewriter.inlineRegionBefore(op.getCond(), exit);
480 rewriter.inlineRegionBefore(op.getBody(), exit);
482 rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);
// The loop op itself is erased.
484 rewriter.eraseOp(op);
485 return mlir::success();
// Rewrite pattern for cir.ternary: the true/false regions are inlined in front
// of a new continuation block whose block arguments stand in for the op's
// results.
// NOTE(review): some original lines (return type, `else {`, closing braces,
// bodies of a few `if`s) are not visible in this excerpt.
489class CIRTernaryOpFlattening :
public mlir::OpRewritePattern<cir::TernaryOp> {
491 using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
494 matchAndRewrite(cir::TernaryOp op,
495 mlir::PatternRewriter &rewriter)
const override {
// Split the current block at the insertion point; the tail becomes
// remainingOpsBlock.
496 Location loc = op->getLoc();
497 Block *condBlock = rewriter.getInsertionBlock();
498 Block::iterator opPosition = rewriter.getInsertionPoint();
499 Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
500 llvm::SmallVector<mlir::Location, 2> locs;
// NOTE(review): the statement guarded by this check is elided; presumably it
// adds a location to `locs` for the result block argument — confirm.
503 if (op->getResultTypes().size())
// continueBlock is created before remainingOpsBlock with the op's result types
// as block arguments, and falls through to remainingOpsBlock via cir.br.
505 Block *continueBlock =
506 rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
507 cir::BrOp::create(rewriter, loc, remainingOpsBlock);
// True region: inspect the terminator of its last block.
509 Region &trueRegion = op.getTrueRegion();
510 Block *trueBlock = &trueRegion.front();
511 mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
512 rewriter.setInsertionPointToEnd(&trueRegion.back());
// cir.yield -> cir.br forwarding the yield's args (destination on an elided
// line); cir.unreachable takes its own branch (body elided); anything else is
// reported with emitError and the pattern fails.
515 if (
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) {
516 rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
518 }
else if (isa<cir::UnreachableOp>(trueTerminator)) {
521 trueTerminator->emitError(
"unexpected terminator in ternary true region, "
522 "expected yield or unreachable, got: ")
523 << trueTerminator->getName();
524 return mlir::failure();
526 rewriter.inlineRegionBefore(trueRegion, continueBlock);
// False region: falseBlock defaults to continueBlock and is overridden with
// the region's first block. NOTE(review): the condition guarding the override
// is elided — confirm.
528 Block *falseBlock = continueBlock;
529 Region &falseRegion = op.getFalseRegion();
531 falseBlock = &falseRegion.front();
532 mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
533 rewriter.setInsertionPointToEnd(&falseRegion.back());
// Same terminator handling as the true region, branching to continueBlock.
536 if (
auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) {
537 rewriter.replaceOpWithNewOp<cir::BrOp>(
538 falseYieldOp, falseYieldOp.getArgs(), continueBlock);
539 }
else if (isa<cir::UnreachableOp>(falseTerminator)) {
542 falseTerminator->emitError(
"unexpected terminator in ternary false "
543 "region, expected yield or unreachable, got: ")
544 << falseTerminator->getName();
545 return mlir::failure();
547 rewriter.inlineRegionBefore(falseRegion, continueBlock);
// Terminate the original block with a conditional branch on the ternary's
// condition, then redirect result uses to continueBlock's arguments.
549 rewriter.setInsertionPointToEnd(condBlock);
550 cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);
552 rewriter.replaceOp(op, continueBlock->getArguments());
555 return mlir::success();
// Rewrite pattern for cir.try. Only the try body is inlined here; the paths
// that involve calls inside the try block or non-empty handler lists are
// marked NYI in the strings below.
// NOTE(review): some original lines (return statements, closing braces, the
// callee of one diagnostic) are not visible in this excerpt.
559class CIRTryOpFlattening :
public mlir::OpRewritePattern<cir::TryOp> {
561 using OpRewritePattern<cir::TryOp>::OpRewritePattern;
// Splits the current block at the insertion point, moves the try region's
// blocks in front of the split-off block (`afterTry`), and branches from the
// original block into the first try block. The return statement is elided;
// the caller stores the result as `afterTry`.
563 mlir::Block *buildTryBody(cir::TryOp tryOp,
564 mlir::PatternRewriter &rewriter)
const {
567 mlir::Block *beforeTryScopeBlock = rewriter.getInsertionBlock();
568 mlir::Block *afterTry =
569 rewriter.splitBlock(beforeTryScopeBlock, rewriter.getInsertionPoint());
572 mlir::Block *beforeBody = &tryOp.getTryRegion().front();
573 rewriter.inlineRegionBefore(tryOp.getTryRegion(), afterTry);
576 rewriter.setInsertionPointToEnd(beforeTryScopeBlock);
577 cir::BrOp::create(rewriter, tryOp.getLoc(), mlir::ValueRange(), beforeBody);
// Rewrites the terminator of the try body and checks the handler list.
// `callsToRewrite` and `landingPads` are not touched in the visible lines.
581 void buildHandlers(cir::TryOp tryOp, mlir::PatternRewriter &rewriter,
582 mlir::Block *afterBody, mlir::Block *afterTry,
583 SmallVectorImpl<cir::CallOp> &callsToRewrite,
584 SmallVectorImpl<mlir::Block *> &landingPads)
const {
586 rewriter.setInsertionPointToEnd(afterBody);
588 mlir::Block *beforeCatch = rewriter.getInsertionBlock();
589 rewriter.setInsertionPointToEnd(beforeCatch);
// A cir.yield ending the try body becomes a plain branch to afterTry.
593 if (
auto tryBodyYield = dyn_cast<cir::YieldOp>(afterBody->getTerminator()))
594 rewriter.replaceOpWithNewOp<cir::BrOp>(tryBodyYield, afterTry);
// NOTE(review): the statement guarded by the empty-handlers check is elided
// (presumably an early return); the remaining path is unimplemented and
// reaches llvm_unreachable.
596 mlir::ArrayAttr handlers = tryOp.getHandlerTypesAttr();
597 if (!handlers || handlers.empty())
600 llvm_unreachable(
"TryOpFlattening buildHandlers with CallsOp is NYI");
// Main rewrite entry point for cir.try.
604 matchAndRewrite(cir::TryOp tryOp,
605 mlir::PatternRewriter &rewriter)
const override {
606 mlir::OpBuilder::InsertionGuard guard(rewriter);
// Remember the last try block before the region is inlined.
607 mlir::Block *afterBody = &tryOp.getTryRegion().back();
// Collect call ops from the try region. The parent check compares each call's
// enclosing cir.try with this op; the statement it guards is elided
// (presumably it skips calls owned by a nested try) — confirm.
611 llvm::SmallVector<cir::CallOp, 4> callsToRewrite;
612 tryOp.getTryRegion().walk([&](CallOp op) {
617 if (op->getParentOfType<cir::TryOp>() != tryOp)
619 callsToRewrite.push_back(op);
// Try blocks containing calls are not implemented yet; the function that
// receives this message is on an elided line.
622 if (!callsToRewrite.empty())
624 "TryOpFlattening with try block that contains CallOps is NYI");
// Inline the try body, then process handlers (trailing argument of the call is
// on an elided line), and erase the original op.
627 mlir::Block *afterTry = buildTryBody(tryOp, rewriter);
630 llvm::SmallVector<mlir::Block *, 4> landingPads;
631 buildHandlers(tryOp, rewriter, afterBody, afterTry, callsToRewrite,
634 rewriter.eraseOp(tryOp);
636 assert((landingPads.size() == callsToRewrite.size()) &&
637 "expected matching number of entries");
// Cleanup: if afterTry ends in a cir.br whose destination has no predecessors,
// drop the branch and merge the destination block into afterTry.
640 auto brOp = dyn_cast<cir::BrOp>(afterTry->getTerminator());
641 if (brOp && brOp.getDest()->hasNoPredecessors()) {
642 mlir::Block *srcBlock = brOp.getDest();
643 rewriter.eraseOp(brOp);
644 rewriter.mergeBlocks(srcBlock, afterTry);
647 return mlir::success();
// Registers the six flattening patterns (if, loop, scope, switch, ternary,
// try), each constructed with the pattern set's MLIRContext. The receiver of
// `.add<...>` is on a line not visible in this excerpt (presumably
// `patterns`).
651void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
653 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
654 CIRSwitchOpFlattening, CIRTernaryOpFlattening, CIRTryOpFlattening>(
655 patterns.getContext());
// Pass entry point: builds the flattening pattern set, walks the root
// operation in post-order looking for if/scope/switch/loop/ternary/try ops,
// and runs applyOpPatternsGreedily over the `ops` list.
// NOTE(review): the statement that records a matching op (presumably
// `ops.push_back(op)`) and the failure handling after the greedy driver are on
// lines not visible in this excerpt — confirm.
658void CIRFlattenCFGPass::runOnOperation() {
659 RewritePatternSet patterns(&getContext());
660 populateFlattenCFGPatterns(patterns);
663 llvm::SmallVector<Operation *, 16> ops;
664 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
668 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp>(op))
673 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
682 return std::make_unique<CIRFlattenCFGPass>();
std::unique_ptr< Pass > createCIRFlattenCFGPass()
float __ovld __cnfn step(float, float)
Returns 0.0 if x < edge, otherwise it returns 1.0.
static bool stackSaveOp()