clang 22.0.0git
CIRSimplify.cpp
Go to the documentation of this file.
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/Dialect/Func/IR/FuncOps.h"
11#include "mlir/IR/Block.h"
12#include "mlir/IR/Operation.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/IR/Region.h"
15#include "mlir/Support/LogicalResult.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/SmallVector.h"
20
21using namespace mlir;
22using namespace cir;
23
24namespace mlir {
25#define GEN_PASS_DEF_CIRSIMPLIFY
26#include "clang/CIR/Dialect/Passes.h.inc"
27} / namespace mlir
28
//===----------------------------------------------------------------------===//
// Rewrite patterns
//===----------------------------------------------------------------------===//
32
33namespace {
34
35/ Simplify suitable ternary operations into select operations.
36/
37/ For now we only simplify those ternary operations whose true and false
38/ branches directly yield a value or a constant. That is, both of the true and
39/ the false branch must either contain a cir.yield operation as the only
40/ operation in the branch, or contain a cir.const operation followed by a
41/ cir.yield operation that yields the constant value.
42/
43/ For example, we will simplify the following ternary operation:
44/
45/ %0 = ...
46/ %1 = cir.ternary (%condition, true {
47/ %2 = cir.const ...
48/ cir.yield %2
49/ } false {
50/ cir.yield %0
51/
52/ into the following sequence of operations:
53/
54/ %1 = cir.const ...
55/ %0 = cir.select if %condition then %1 else %2
56struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
57 using OpRewritePattern<TernaryOp>::OpRewritePattern;
58
59 LogicalResult matchAndRewrite(TernaryOp op,
60 PatternRewriter &rewriter) const override {
61 if (op->getNumResults() != 1)
62 return mlir::failure();
63
64 if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
65 !isSimpleTernaryBranch(op.getFalseRegion()))
66 return mlir::failure();
67
68 cir::YieldOp trueBranchYieldOp =
69 mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
70 cir::YieldOp falseBranchYieldOp =
71 mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
72 mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
73 mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
74
75 rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
76 rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
77 rewriter.eraseOp(trueBranchYieldOp);
78 rewriter.eraseOp(falseBranchYieldOp);
79 rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
80 falseValue);
81
82 return mlir::success();
83 }
84
85private:
86 bool isSimpleTernaryBranch(mlir::Region &region) const {
87 if (!region.hasOneBlock())
88 return false;
89
90 mlir::Block &onlyBlock = region.front();
91 mlir::Block::OpListType &ops = onlyBlock.getOperations();
92
93 / The region/block could only contain at most 2 operations.
94 if (ops.size() > 2)
95 return false;
96
97 if (ops.size() == 1) {
98 / The region/block only contain a cir.yield operation.
99 return true;
100 }
101
102 / Check whether the region/block contains a cir.const followed by a
103 / cir.yield that yields the value.
104 auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
105 auto yieldValueDefOp =
106 yieldOp.getArgs()[0].getDefiningOp<cir::ConstantOp>();
107 return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
108 }
109};
110
111/ Simplify select operations with boolean constants into simpler forms.
112/
113/ This pattern simplifies select operations where both true and false values
114/ are boolean constants. Two specific cases are handled:
115/
116/ 1. When selecting between true and false based on a condition,
117/ the operation simplifies to just the condition itself:
118/
119/ %0 = cir.select if %condition then true else false
120/ ->
121/ (replaced with %condition directly)
122/
123/ 2. When selecting between false and true based on a condition,
124/ the operation simplifies to the logical negation of the condition:
125/
126/ %0 = cir.select if %condition then false else true
127/ ->
128/ %0 = cir.unary not %condition
129struct SimplifySelect : public OpRewritePattern<SelectOp> {
130 using OpRewritePattern<SelectOp>::OpRewritePattern;
131
132 LogicalResult matchAndRewrite(SelectOp op,
133 PatternRewriter &rewriter) const final {
134 auto trueValueOp = op.getTrueValue().getDefiningOp<cir::ConstantOp>();
135 auto falseValueOp = op.getFalseValue().getDefiningOp<cir::ConstantOp>();
136 if (!trueValueOp || !falseValueOp)
137 return mlir::failure();
138
139 auto trueValue = trueValueOp.getValueAttr<cir::BoolAttr>();
140 auto falseValue = falseValueOp.getValueAttr<cir::BoolAttr>();
141 if (!trueValue || !falseValue)
142 return mlir::failure();
143
144 / cir.select if %0 then #true else #false -> %0
145 if (trueValue.getValue() && !falseValue.getValue()) {
146 rewriter.replaceAllUsesWith(op, op.getCondition());
147 rewriter.eraseOp(op);
148 return mlir::success();
149 }
150
151 / cir.select if %0 then #false else #true -> cir.unary not %0
152 if (!trueValue.getValue() && falseValue.getValue()) {
153 rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
154 op.getCondition());
155 return mlir::success();
156 }
157
158 return mlir::failure();
159 }
160};
161
162/ Simplify `cir.switch` operations by folding cascading cases
163/ into a single `cir.case` with the `anyof` kind.
164/
165/ This pattern identifies cascading cases within a `cir.switch` operation.
166/ Cascading cases are defined as consecutive `cir.case` operations of kind
167/ `equal`, each containing a single `cir.yield` operation in their body.
168/
169/ The pattern merges these cascading cases into a single `cir.case` operation
170/ with kind `anyof`, aggregating all the case values.
171/
172/ The merging process continues until a `cir.case` with a different body
173/ (e.g., containing `cir.break` or compound stmt) is encountered, which
174/ breaks the chain.
175/
176/ Example:
177/
178/ Before:
179/ cir.case equal, [#cir.int<0> : !s32i] {
180/ cir.yield
181/ }
182/ cir.case equal, [#cir.int<1> : !s32i] {
183/ cir.yield
184/ }
185/ cir.case equal, [#cir.int<2> : !s32i] {
186/ cir.break
187/ }
188/
189/ After applying SimplifySwitch:
190/ cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
191/ !s32i] {
192/ cir.break
193/ }
194struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
195 using OpRewritePattern<SwitchOp>::OpRewritePattern;
196 LogicalResult matchAndRewrite(SwitchOp op,
197 PatternRewriter &rewriter) const override {
198
199 LogicalResult changed = mlir::failure();
200 SmallVector<CaseOp, 8> cases;
201 SmallVector<CaseOp, 4> cascadingCases;
202 SmallVector<mlir::Attribute, 4> cascadingCaseValues;
203
204 op.collectCases(cases);
205 if (cases.empty())
206 return mlir::failure();
207
208 auto flushMergedOps = [&]() {
209 for (CaseOp &c : cascadingCases)
210 rewriter.eraseOp(c);
211 cascadingCases.clear();
212 cascadingCaseValues.clear();
213 };
214
215 auto mergeCascadingInto = [&](CaseOp &target) {
216 rewriter.modifyOpInPlace(target, [&]() {
217 target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
218 target.setKind(CaseOpKind::Anyof);
219 });
220 changed = mlir::success();
221 };
222
223 for (CaseOp c : cases) {
224 cir::CaseOpKind kind = c.getKind();
225 if (kind == cir::CaseOpKind::Equal &&
226 isa<YieldOp>(c.getCaseRegion().front().front())) {
227 / If the case contains only a YieldOp, collect it for cascading merge
228 cascadingCases.push_back(c);
229 cascadingCaseValues.push_back(c.getValue()[0]);
230 } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
231 / merge previously collected cascading cases
232 cascadingCaseValues.push_back(c.getValue()[0]);
233 mergeCascadingInto(c);
234 flushMergedOps();
235 } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
236 / If a Default, Anyof or Range case is found and there are previous
237 / cascading cases, merge all of them into the last cascading case.
238 / We don't currently fold case range statements with other case
239 / statements.
241 CaseOp lastCascadingCase = cascadingCases.back();
242 mergeCascadingInto(lastCascadingCase);
243 cascadingCases.pop_back();
244 flushMergedOps();
245 } else {
246 cascadingCases.clear();
247 cascadingCaseValues.clear();
248 }
249 }
250
251 / Edge case: all cases are simple cascading cases
252 if (cascadingCases.size() == cases.size()) {
253 CaseOp lastCascadingCase = cascadingCases.back();
254 mergeCascadingInto(lastCascadingCase);
255 cascadingCases.pop_back();
256 flushMergedOps();
257 }
258
259 return changed;
260 }
261};
262
263struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
264 using OpRewritePattern<VecSplatOp>::OpRewritePattern;
265 LogicalResult matchAndRewrite(VecSplatOp op,
266 PatternRewriter &rewriter) const override {
267 mlir::Value splatValue = op.getValue();
268 auto constant = splatValue.getDefiningOp<cir::ConstantOp>();
269 if (!constant)
270 return mlir::failure();
271
272 auto value = constant.getValue();
273 if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
274 !mlir::isa_and_nonnull<cir::FPAttr>(value))
275 return mlir::failure();
276
277 cir::VectorType resultType = op.getResult().getType();
278 SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
279 auto constVecAttr = cir::ConstVectorAttr::get(
280 resultType, mlir::ArrayAttr::get(getContext(), elements));
281
282 rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
283 return mlir::success();
284 }
285};
286
//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//
290
291struct CIRSimplifyPass : public impl::CIRSimplifyBase<CIRSimplifyPass> {
292 using CIRSimplifyBase::CIRSimplifyBase;
293
294 void runOnOperation() override;
295};
296
297void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
298 / clang-format off
299 patterns.add<
300 SimplifyTernary,
301 SimplifySelect,
302 SimplifySwitch,
303 SimplifyVecSplat
304 >(patterns.getContext());
305 / clang-format on
306}
307
308void CIRSimplifyPass::runOnOperation() {
309 / Collect rewrite patterns.
310 RewritePatternSet patterns(&getContext());
311 populateMergeCleanupPatterns(patterns);
312
313 / Collect operations to apply patterns.
314 llvm::SmallVector<Operation *, 16> ops;
315 getOperation()->walk([&](Operation *op) {
316 if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
317 ops.push_back(op);
318 });
319
320 / Apply patterns.
321 if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
322 signalPassFailure();
323}
324
325} / namespace
326
327std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
328 return std::make_unique<CIRSimplifyPass>();
329}
__device__ __2f16 float c
unsigned kind
All of the diagnostics that can be emitted by the frontend.
std::unique_ptr< Pass > createCIRSimplifyPass()
static bool foldRangeCase()

Follow Lee on X/Twitter - Father, Husband, Serial builder creating AI, crypto, games & web tools. We are friends :) AI Will Come To Life!

Check out: eBank.nz (Art Generator) | Netwrck.com (AI Tools) | Text-Generator.io (AI API) | BitBank.nz (Crypto AI) | ReadingTime (Kids Reading) | RewordGame | BigMultiplayerChess | WebFiddle | How.nz | Helix AI Assistant