//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements pass that inlines CIR operations regions into the parent
// function region.
//
//===----------------------------------------------------------------------===//

14#include "PassDetail.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Block.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Support/LogicalResult.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26using namespace mlir;
27using namespace cir;
28
namespace mlir {
// Pull in the tablegen-generated pass base class (impl::CIRFlattenCFGBase)
// for this pass.
#define GEN_PASS_DEF_CIRFLATTENCFG
#include "clang/CIR/Dialect/Passes.h.inc"
} // namespace mlir
33
34namespace {
35
36/ Lowers operations with the terminator trait that have a single successor.
37void lowerTerminator(mlir::Operation *op, mlir::Block *dest,
38 mlir::PatternRewriter &rewriter) {
39 assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");
40 mlir::OpBuilder::InsertionGuard guard(rewriter);
41 rewriter.setInsertionPoint(op);
42 rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);
43}
44
45/ Walks a region while skipping operations of type `Ops`. This ensures the
46/ callback is not applied to said operations and its children.
47template <typename... Ops>
48void walkRegionSkipping(
49 mlir::Region &region,
50 mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {
51 region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
52 if (isa<Ops...>(op))
53 return mlir::WalkResult::skip();
54 return callback(op);
55 });
56}
57
/// Pass that flattens CIR's structured control-flow operations (if, scope,
/// switch, loops, ternary, try) into an explicit CFG by applying the rewrite
/// patterns defined in this file. See runOnOperation() below.
struct CIRFlattenCFGPass : public impl::CIRFlattenCFGBase<CIRFlattenCFGPass> {

  CIRFlattenCFGPass() = default;
  void runOnOperation() override;
};
63
/// Flattens `cir.if` into explicit control flow: the block containing the if
/// ends with a `cir.brcond` into the inlined then/else regions, and both
/// regions' yields become branches to a shared continuation block.
struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::IfOp ifOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = ifOp.getLoc();
    bool emptyElse = ifOp.getElseRegion().empty();
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    // Everything after the if becomes the continuation block.
    mlir::Block *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    mlir::Block *continueBlock;
    if (ifOp->getResults().empty())
      continueBlock = remainingOpsBlock;
    else
      // If-with-results would need block arguments on the continuation
      // block; not implemented yet.
      llvm_unreachable("NYI");

    // Inline the then region before the continuation block.
    mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();
    mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();
    rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);

    // Rewrite the then-region yield (if any) into a branch to the
    // continuation block.
    rewriter.setInsertionPointToEnd(thenAfterBody);
    if (auto thenYieldOp =
            dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),
                                             continueBlock);
    }

    rewriter.setInsertionPointToEnd(continueBlock);

    // Has else region: inline it. Otherwise the false edge of the brcond
    // goes straight to the continuation block.
    mlir::Block *elseBeforeBody = nullptr;
    mlir::Block *elseAfterBody = nullptr;
    if (!emptyElse) {
      elseBeforeBody = &ifOp.getElseRegion().front();
      elseAfterBody = &ifOp.getElseRegion().back();
      rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);
    } else {
      elseBeforeBody = elseAfterBody = continueBlock;
    }

    // Terminate the original block with the conditional branch.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrCondOp::create(rewriter, loc, ifOp.getCondition(), thenBeforeBody,
                          elseBeforeBody);

    // Rewrite the else-region yield (if any) into a branch to the
    // continuation block.
    if (!emptyElse) {
      rewriter.setInsertionPointToEnd(elseAfterBody);
      if (auto elseYieldOP =
              dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {
        rewriter.replaceOpWithNewOp<cir::BrOp>(
            elseYieldOP, elseYieldOP.getArgs(), continueBlock);
      }
    }

    rewriter.replaceOp(ifOp, continueBlock->getArguments());
    return mlir::success();
  }
};
124
/// Flattens `cir.scope` by inlining its body region into the enclosing region
/// and replacing the scope's yield with a branch to a continuation block.
class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {
public:
  using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::ScopeOp scopeOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Location loc = scopeOp.getLoc();

    // Empty scope: just remove it.
    // TODO: Remove this logic once CIR uses MLIR infrastructure to remove
    // trivially dead operations. MLIR canonicalizer is too aggressive and we
    // need to either (a) make sure all our ops model all side-effects and/or
    // (b) have more options in the canonicalizer in MLIR to temper
    // aggressiveness level.
    if (scopeOp.isEmpty()) {
      rewriter.eraseOp(scopeOp);
      return mlir::success();
    }

    // Split the current block before the ScopeOp to create the inlining
    // point.
    mlir::Block *currentBlock = rewriter.getInsertionBlock();
    mlir::Block *continueBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    // Scope results become block arguments of the continuation block.
    if (scopeOp.getNumResults() > 0)
      continueBlock->addArguments(scopeOp.getResultTypes(), loc);

    // Inline body region.
    mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();
    mlir::Block *afterBody = &scopeOp.getScopeRegion().back();
    rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    cir::BrOp::create(rewriter, loc, mlir::ValueRange(), beforeBody);

    // Replace the scopeop return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),
                                             continueBlock);
    }

    // Replace the op with values return from the body region.
    rewriter.replaceOp(scopeOp, continueBlock->getArguments());

    return mlir::success();
  }
};
178
/// Flattens `cir.switch` into a `cir.switchflat` plus explicit CFG for the
/// case bodies. Range cases are either expanded into individual case values
/// (small ranges) or lowered into a chained range-check block that becomes
/// the new default destination.
class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {
public:
  using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;

  // Replaces a case/switch yield with a branch to `destination`, forwarding
  // the yield operands as block arguments.
  inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,
                             cir::YieldOp yieldOp,
                             mlir::Block *destination) const {
    rewriter.setInsertionPoint(yieldOp);
    rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(),
                                           destination);
  }

  // Emits a block that tests `lowerBound <= cond <= upperBound` (via an
  // unsigned compare of `cond - lowerBound` against the range length) and
  // branches to `rangeDestination` on success or `defaultDestination`
  // otherwise. Return the new defaultDestination block.
  Block *condBrToRangeDestination(cir::SwitchOp op,
                                  mlir::PatternRewriter &rewriter,
                                  mlir::Block *rangeDestination,
                                  mlir::Block *defaultDestination,
                                  const APInt &lowerBound,
                                  const APInt &upperBound) const {
    assert(lowerBound.sle(upperBound) && "Invalid range");
    mlir::Block *resBlock = rewriter.createBlock(defaultDestination);
    cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);
    cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);

    cir::ConstantOp rangeLength = cir::ConstantOp::create(
        rewriter, op.getLoc(),
        cir::IntAttr::get(sIntType, upperBound - lowerBound));

    cir::ConstantOp lowerBoundValue = cir::ConstantOp::create(
        rewriter, op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));
    cir::BinOp diffValue =
        cir::BinOp::create(rewriter, op.getLoc(), sIntType, cir::BinOpKind::Sub,
                           op.getCondition(), lowerBoundValue);

    // Use unsigned comparison to check if the condition is in the range.
    cir::CastOp uDiffValue = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, diffValue);
    cir::CastOp uRangeLength = cir::CastOp::create(
        rewriter, op.getLoc(), uIntType, CastKind::integral, rangeLength);

    cir::CmpOp cmpResult = cir::CmpOp::create(
        rewriter, op.getLoc(), cir::CmpOpKind::le, uDiffValue, uRangeLength);
    cir::BrCondOp::create(rewriter, op.getLoc(), cmpResult, rangeDestination,
                          defaultDestination);
    return resBlock;
  }

  mlir::LogicalResult
  matchAndRewrite(cir::SwitchOp op,
                  mlir::PatternRewriter &rewriter) const override {
    llvm::SmallVector<CaseOp> cases;
    op.collectCases(cases);

    // Empty switch statement: just erase it.
    if (cases.empty()) {
      rewriter.eraseOp(op);
      return mlir::success();
    }

    // Create exit block from the next node of cir.switch op.
    mlir::Block *exitBlock = rewriter.splitBlock(
        rewriter.getBlock(), op->getNextNode()->getIterator());

    // We lower cir.switch op in the following process:
    // 1. Inline the region from the switch op after switch op.
    // 2. Traverse each cir.case op:
    //    a. Record the entry block, block arguments and condition for every
    //       case.
    //    b. Inline the case region after the case op.
    // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the
    //    recorded block and conditions.

    // Inline everything from switch body between the switch op and the exit
    // block.
    {
      cir::YieldOp switchYield = nullptr;
      // Find the yield terminating the switch body (if any); it must be
      // redirected to the exit block after inlining.
      for (mlir::Block &block :
           llvm::make_early_inc_range(op.getBody().getBlocks()))
        if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))
          switchYield = yieldOp;

      assert(!op.getBody().empty());
      mlir::Block *originalBlock = op->getBlock();
      mlir::Block *swopBlock =
          rewriter.splitBlock(originalBlock, op->getIterator());
      rewriter.inlineRegionBefore(op.getBody(), exitBlock);

      if (switchYield)
        rewriteYieldOp(rewriter, switchYield, exitBlock);

      rewriter.setInsertionPointToEnd(originalBlock);
      cir::BrOp::create(rewriter, op.getLoc(), swopBlock);
    }

    // Allocate required data structures (disconsider default case in
    // vectors).
    llvm::SmallVector<mlir::APInt, 8> caseValues;
    llvm::SmallVector<mlir::Block *, 8> caseDestinations;
    llvm::SmallVector<mlir::ValueRange, 8> caseOperands;

    llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;
    llvm::SmallVector<mlir::Block *> rangeDestinations;
    llvm::SmallVector<mlir::ValueRange> rangeOperands;

    // Initialize default case as optional: until a Default case is found,
    // the exit block serves as the default destination.
    mlir::Block *defaultDestination = exitBlock;
    mlir::ValueRange defaultOperands = exitBlock->getArguments();

    // Digest the case statements values and bodies.
    for (cir::CaseOp caseOp : cases) {
      mlir::Region &region = caseOp.getCaseRegion();

      // Found default case: save destination and operands.
      switch (caseOp.getKind()) {
      case cir::CaseOpKind::Default:
        defaultDestination = &region.front();
        defaultOperands = defaultDestination->getArguments();
        break;
      case cir::CaseOpKind::Range:
        assert(caseOp.getValue().size() == 2 &&
               "Case range should have 2 case value");
        rangeValues.push_back(
            {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),
             cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});
        rangeDestinations.push_back(&region.front());
        rangeOperands.push_back(rangeDestinations.back()->getArguments());
        break;
      case cir::CaseOpKind::Anyof:
      case cir::CaseOpKind::Equal:
        // AnyOf cases kind can have multiple values, hence the loop below.
        for (const mlir::Attribute &value : caseOp.getValue()) {
          caseValues.push_back(cast<cir::IntAttr>(value).getValue());
          caseDestinations.push_back(&region.front());
          caseOperands.push_back(caseDestinations.back()->getArguments());
        }
        break;
      }

      // Handle break statements: a break inside this case (but not inside a
      // nested loop or switch) jumps to the exit block.
      walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(
          region, [&](mlir::Operation *op) {
            if (!isa<cir::BreakOp>(op))
              return mlir::WalkResult::advance();

            lowerTerminator(op, exitBlock, rewriter);
            return mlir::WalkResult::skip();
          });

      // Track fallthrough in cases: a yield at the end of a case body must
      // branch to the start of whatever follows the case op.
      for (mlir::Block &blk : region.getBlocks()) {
        if (blk.getNumSuccessors())
          continue;

        if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {
          mlir::Operation *nextOp = caseOp->getNextNode();
          assert(nextOp && "caseOp is not expected to be the last op");
          mlir::Block *oldBlock = nextOp->getBlock();
          mlir::Block *newBlock =
              rewriter.splitBlock(oldBlock, nextOp->getIterator());
          rewriter.setInsertionPointToEnd(oldBlock);
          cir::BrOp::create(rewriter, nextOp->getLoc(), mlir::ValueRange(),
                            newBlock);
          rewriteYieldOp(rewriter, yieldOp, newBlock);
        }
      }

      // Split at the case op and inline its region in between, branching
      // from the block that held the case op into the inlined entry block.
      mlir::Block *oldBlock = caseOp->getBlock();
      mlir::Block *newBlock =
          rewriter.splitBlock(oldBlock, caseOp->getIterator());

      mlir::Block &entryBlock = caseOp.getCaseRegion().front();
      rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);

      // Create a branch to the entry of the inlined region.
      rewriter.setInsertionPointToEnd(oldBlock);
      cir::BrOp::create(rewriter, caseOp.getLoc(), &entryBlock);
    }

    // Remove all cases since we've inlined the regions.
    for (cir::CaseOp caseOp : cases) {
      mlir::Block *caseBlock = caseOp->getBlock();
      // Erase the block with no predecessors here to make the generated code
      // simpler a little bit.
      if (caseBlock->hasNoPredecessors())
        rewriter.eraseBlock(caseBlock);
      else
        rewriter.eraseOp(caseOp);
    }

    for (auto [rangeVal, operand, destination] :
         llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {
      APInt lowerBound = rangeVal.first;
      APInt upperBound = rangeVal.second;

      // The case range is unreachable, skip it.
      if (lowerBound.sgt(upperBound))
        continue;

      // If range is small, add multiple switch instruction cases.
      // This magical number is from the original CGStmt code.
      constexpr int kSmallRangeThreshold = 64;
      if ((upperBound - lowerBound)
              .ult(llvm::APInt(32, kSmallRangeThreshold))) {
        for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {
          caseValues.push_back(iValue);
          caseOperands.push_back(operand);
          caseDestinations.push_back(destination);
        }
        continue;
      }

      // Large range: chain a range-check block in front of the current
      // default destination.
      defaultDestination =
          condBrToRangeDestination(op, rewriter, destination,
                                   defaultDestination, lowerBound, upperBound);
      defaultOperands = operand;
    }

    // Set switch op to branch to the newly created blocks.
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(
        op, op.getCondition(), defaultDestination, defaultOperands, caseValues,
        caseDestinations, caseOperands);

    return mlir::success();
  }
};
405
/// Flattens any op implementing `cir::LoopOpInterface` (for/while/do-while)
/// into explicit CFG: entry branches into the loop's entry block, the
/// condition region's `cir.condition` becomes a `cir.brcond` to body/exit,
/// continue/break are lowered to branches, and the cond/body/step regions are
/// inlined before the exit block.
class CIRLoopOpInterfaceFlattening
    : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {
public:
  using mlir::OpInterfaceRewritePattern<
      cir::LoopOpInterface>::OpInterfaceRewritePattern;

  // Replaces a `cir.condition` with a conditional branch to `body` / `exit`.
  inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,
                               mlir::Block *exit,
                               mlir::PatternRewriter &rewriter) const {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,
                                               exit);
  }

  mlir::LogicalResult
  matchAndRewrite(cir::LoopOpInterface op,
                  mlir::PatternRewriter &rewriter) const final {
    // Setup CFG blocks.
    mlir::Block *entry = rewriter.getInsertionBlock();
    mlir::Block *exit =
        rewriter.splitBlock(entry, rewriter.getInsertionPoint());
    mlir::Block *cond = &op.getCond().front();
    mlir::Block *body = &op.getBody().front();
    // The step region is only present on for-loops.
    mlir::Block *step =
        (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);

    // Setup loop entry branch.
    rewriter.setInsertionPointToEnd(entry);
    cir::BrOp::create(rewriter, op.getLoc(), &op.getEntry().front());

    // Branch from condition region to body or exit.
    auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator());
    lowerConditionOp(conditionOp, body, exit, rewriter);

    // TODO(cir): Remove the walks below. It visits operations unnecessarily.
    // However, to solve this we would likely need a custom DialectConversion
    // driver to customize the order that operations are visited.

    // Lower continue statements: continue jumps to the step block if there is
    // one, otherwise back to the condition block.
    mlir::Block *dest = (step ? step : cond);
    op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {
      if (!isa<cir::ContinueOp>(op))
        return mlir::WalkResult::advance();

      lowerTerminator(op, dest, rewriter);
      return mlir::WalkResult::skip();
    });

    // Lower break statements (skipping nested loops, which own their own
    // breaks).
    walkRegionSkipping<cir::LoopOpInterface>(
        op.getBody(), [&](mlir::Operation *op) {
          if (!isa<cir::BreakOp>(op))
            return mlir::WalkResult::advance();

          lowerTerminator(op, exit, rewriter);
          return mlir::WalkResult::skip();
        });

    // Lower optional body region yield.
    for (mlir::Block &blk : op.getBody().getBlocks()) {
      auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());
      if (bodyYield)
        lowerTerminator(bodyYield, (step ? step : cond), rewriter);
    }

    // Lower mandatory step region yield.
    if (step)
      lowerTerminator(cast<cir::YieldOp>(step->getTerminator()), cond,
                      rewriter);

    // Move region contents out of the loop op.
    rewriter.inlineRegionBefore(op.getCond(), exit);
    rewriter.inlineRegionBefore(op.getBody(), exit);
    if (step)
      rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);

    rewriter.eraseOp(op);
    return mlir::success();
  }
};
488
/// Flattens `cir.ternary` into a `cir.brcond` that selects between the
/// inlined true/false regions; both regions yield into a continuation block
/// whose block arguments carry the ternary's results.
class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
public:
  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(cir::TernaryOp op,
                  mlir::PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Block *condBlock = rewriter.getInsertionBlock();
    Block::iterator opPosition = rewriter.getInsertionPoint();
    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
    llvm::SmallVector<mlir::Location, 2> locs;
    // Ternary result is optional, make sure to populate the location only
    // when relevant.
    if (op->getResultTypes().size())
      locs.push_back(loc);
    // Continuation block receives the ternary results as block arguments and
    // falls through to the rest of the original block.
    Block *continueBlock =
        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
    cir::BrOp::create(rewriter, loc, remainingOpsBlock);

    Region &trueRegion = op.getTrueRegion();
    Block *trueBlock = &trueRegion.front();
    mlir::Operation *trueTerminator = trueRegion.back().getTerminator();
    rewriter.setInsertionPointToEnd(&trueRegion.back());

    // Handle both yield and unreachable terminators (throw expressions)
    if (auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator)) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
                                             continueBlock);
    } else if (isa<cir::UnreachableOp>(trueTerminator)) {
      // Terminator is unreachable (e.g., from throw), just keep it
    } else {
      trueTerminator->emitError("unexpected terminator in ternary true region, "
                                "expected yield or unreachable, got: ")
          << trueTerminator->getName();
      return mlir::failure();
    }
    rewriter.inlineRegionBefore(trueRegion, continueBlock);

    Block *falseBlock = continueBlock;
    Region &falseRegion = op.getFalseRegion();

    falseBlock = &falseRegion.front();
    mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
    rewriter.setInsertionPointToEnd(&falseRegion.back());

    // Handle both yield and unreachable terminators (throw expressions)
    if (auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator)) {
      rewriter.replaceOpWithNewOp<cir::BrOp>(
          falseYieldOp, falseYieldOp.getArgs(), continueBlock);
    } else if (isa<cir::UnreachableOp>(falseTerminator)) {
      // Terminator is unreachable (e.g., from throw), just keep it
    } else {
      falseTerminator->emitError("unexpected terminator in ternary false "
                                 "region, expected yield or unreachable, got: ")
          << falseTerminator->getName();
      return mlir::failure();
    }
    rewriter.inlineRegionBefore(falseRegion, continueBlock);

    // Dispatch on the condition from the original block.
    rewriter.setInsertionPointToEnd(condBlock);
    cir::BrCondOp::create(rewriter, loc, op.getCond(), trueBlock, falseBlock);

    rewriter.replaceOp(op, continueBlock->getArguments());

    // Ok, we're done!
    return mlir::success();
  }
};
558
/// Flattens `cir.try` by inlining its try region into the parent. Exception
/// handlers and `cir.call exception` rewriting to `cir.try_call` are not yet
/// implemented (NYI paths below).
class CIRTryOpFlattening : public mlir::OpRewritePattern<cir::TryOp> {
public:
  using OpRewritePattern<cir::TryOp>::OpRewritePattern;

  // Inlines the try region before the block that follows the try op and
  // branches into it. Returns the block holding everything after the try.
  mlir::Block *buildTryBody(cir::TryOp tryOp,
                            mlir::PatternRewriter &rewriter) const {
    // Split the current block before the TryOp to create the inlining
    // point.
    mlir::Block *beforeTryScopeBlock = rewriter.getInsertionBlock();
    mlir::Block *afterTry =
        rewriter.splitBlock(beforeTryScopeBlock, rewriter.getInsertionPoint());

    // Inline body region.
    mlir::Block *beforeBody = &tryOp.getTryRegion().front();
    rewriter.inlineRegionBefore(tryOp.getTryRegion(), afterTry);

    // Branch into the body of the region.
    rewriter.setInsertionPointToEnd(beforeTryScopeBlock);
    cir::BrOp::create(rewriter, tryOp.getLoc(), mlir::ValueRange(), beforeBody);
    return afterTry;
  }

  // Redirects the try-body yield to `afterTry` and (eventually) builds the
  // catch handlers/landing pads. Handler lowering itself is NYI.
  void buildHandlers(cir::TryOp tryOp, mlir::PatternRewriter &rewriter,
                     mlir::Block *afterBody, mlir::Block *afterTry,
                     SmallVectorImpl<cir::CallOp> &callsToRewrite,
                     SmallVectorImpl<mlir::Block *> &landingPads) const {
    // Replace the tryOp return with a branch that jumps out of the body.
    rewriter.setInsertionPointToEnd(afterBody);

    mlir::Block *beforeCatch = rewriter.getInsertionBlock();
    rewriter.setInsertionPointToEnd(beforeCatch);

    // Check if the terminator is a YieldOp because there could be another
    // terminator, e.g. unreachable
    if (auto tryBodyYield = dyn_cast<cir::YieldOp>(afterBody->getTerminator()))
      rewriter.replaceOpWithNewOp<cir::BrOp>(tryBodyYield, afterTry);

    mlir::ArrayAttr handlers = tryOp.getHandlerTypesAttr();
    if (!handlers || handlers.empty())
      return;

    llvm_unreachable("TryOpFlattening buildHandlers with CallsOp is NYI");
  }

  mlir::LogicalResult
  matchAndRewrite(cir::TryOp tryOp,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    mlir::Block *afterBody = &tryOp.getTryRegion().back();

    // Grab the collection of `cir.call exception`s to rewrite to
    // `cir.try_call`.
    llvm::SmallVector<cir::CallOp, 4> callsToRewrite;
    tryOp.getTryRegion().walk([&](CallOp op) {
      if (op.getNothrow())
        return;

      // Only grab calls within immediate closest TryOp scope.
      if (op->getParentOfType<cir::TryOp>() != tryOp)
        return;
      callsToRewrite.push_back(op);
    });

    if (!callsToRewrite.empty())
      llvm_unreachable(
          "TryOpFlattening with try block that contains CallOps is NYI");

    // Build try body.
    mlir::Block *afterTry = buildTryBody(tryOp, rewriter);

    // Build handlers.
    llvm::SmallVector<mlir::Block *, 4> landingPads;
    buildHandlers(tryOp, rewriter, afterBody, afterTry, callsToRewrite,
                  landingPads);

    rewriter.eraseOp(tryOp);

    assert((landingPads.size() == callsToRewrite.size()) &&
           "expected matching number of entries");

    // Quick block cleanup: no indirection to the post try block.
    auto brOp = dyn_cast<cir::BrOp>(afterTry->getTerminator());
    if (brOp && brOp.getDest()->hasNoPredecessors()) {
      mlir::Block *srcBlock = brOp.getDest();
      rewriter.eraseOp(brOp);
      rewriter.mergeBlocks(srcBlock, afterTry);
    }

    return mlir::success();
  }
};
650
651void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
652 patterns
653 .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,
654 CIRSwitchOpFlattening, CIRTernaryOpFlattening, CIRTryOpFlattening>(
655 patterns.getContext());
656}
657
658void CIRFlattenCFGPass::runOnOperation() {
659 RewritePatternSet patterns(&getContext());
660 populateFlattenCFGPatterns(patterns);
661
662 / Collect operations to apply patterns.
663 llvm::SmallVector<Operation *, 16> ops;
664 getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
668 if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp, TryOp>(op))
669 ops.push_back(op);
670 });
671
672 / Apply patterns.
673 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
674 signalPassFailure();
675}
676
} // namespace
678
namespace mlir {

/// Public factory used to register/schedule the CIR flatten-CFG pass.
std::unique_ptr<Pass> createCIRFlattenCFGPass() {
  return std::make_unique<CIRFlattenCFGPass>();
}

} // namespace mlir