Skip to content

Commit 031e5fe

Browse files
authored
Verifiers (#17)
* Add `qualified` to the type assembly format This is because I noticed the assembly format was printing <10> for polynomials instead of the fully qualified type name. After this commit it will print the whole type See https://mlir.llvm.org/docs/DefiningDialects/Operations/#declarative-assembly-format for more details. * Add SameOperandsAndResultType This removes the flexibility of having mixed poly + tensor ops for the binary operations, but demonstrates how the type inference engine enables a more succinct textual IR. If you were to simplify the assembly format without doing this, you'd get a compile-time error complaining that it can't infer the type of the operands or argument. * add AllTypesMatch to EvalOp * add a custom verifier for evalop * add verifier via trait
1 parent c387ac0 commit 031e5fe

12 files changed

Lines changed: 84 additions & 31 deletions

lib/Dialect/Poly/BUILD

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ td_library(
1313
],
1414
includes = ["@heir//include"],
1515
deps = [
16-
# the base mlir target for defining operations and dialects in tablegen
17-
"@llvm-project//mlir:OpBaseTdFiles",
1816
"@llvm-project//mlir:BuiltinDialectTdFiles",
17+
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
18+
"@llvm-project//mlir:OpBaseTdFiles",
1919
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
2020
],
2121
)
@@ -89,6 +89,7 @@ cc_library(
8989
hdrs = [
9090
"PolyDialect.h",
9191
"PolyOps.h",
92+
"PolyTraits.h",
9293
"PolyTypes.h",
9394
],
9495
deps = [
@@ -97,6 +98,7 @@ cc_library(
9798
":types_inc_gen",
9899
"@llvm-project//mlir:Dialect",
99100
"@llvm-project//mlir:IR",
101+
"@llvm-project//mlir:InferTypeOpInterface",
100102
"@llvm-project//mlir:Support",
101103
],
102104
)

lib/Dialect/Poly/PolyOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
6161
return dyn_cast<DenseIntElementsAttr>(adaptor.getInput());
6262
}
6363

64+
LogicalResult EvalOp::verify() {
65+
return getPoint().getType().isSignlessInteger(32)
66+
? success()
67+
: emitOpError("argument point must be a 32-bit integer");
68+
}
69+
6470
} // namespace poly
6571
} // namespace tutorial
6672
} // namespace mlir

lib/Dialect/Poly/PolyOps.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
#define LIB_DIALECT_POLY_POLYOPS_H_
33

44
#include "lib/Dialect/Poly/PolyDialect.h"
5+
#include "lib/Dialect/Poly/PolyTraits.h"
56
#include "lib/Dialect/Poly/PolyTypes.h"
6-
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
7-
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
8-
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
7+
#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
8+
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
9+
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
10+
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
911

1012
#define GET_OP_CLASSES
1113
#include "lib/Dialect/Poly/PolyOps.h.inc"
1214

13-
#endif // LIB_DIALECT_POLY_POLYOPS_H_
15+
#endif // LIB_DIALECT_POLY_POLYOPS_H_

lib/Dialect/Poly/PolyOps.td

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,22 @@ include "PolyDialect.td"
55
include "PolyTypes.td"
66
include "mlir/IR/BuiltinAttributes.td"
77
include "mlir/IR/OpBase.td"
8+
include "mlir/Interfaces/InferTypeOpInterface.td"
89
include "mlir/Interfaces/SideEffectInterfaces.td"
910

1011
// Type constraint for poly binop arguments: polys, vectors of polys, or
1112
// tensors of polys.
1213
def PolyOrContainer : TypeOrContainer<Polynomial, "poly-or-container">;
1314

14-
class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, ElementwiseMappable, SameOperandsAndResultElementType]> {
15+
// Inject verification that all integer-like arguments are 32-bits
16+
def Has32BitArguments : NativeOpTrait<"Has32BitArguments"> {
17+
let cppNamespace = "::mlir::tutorial::poly";
18+
}
19+
20+
class Poly_BinOp<string mnemonic> : Op<Poly_Dialect, mnemonic, [Pure, ElementwiseMappable, SameOperandsAndResultType]> {
1521
let arguments = (ins PolyOrContainer:$lhs, PolyOrContainer:$rhs);
1622
let results = (outs PolyOrContainer:$output);
17-
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)";
23+
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` qualified(type($output))";
1824
let hasFolder = 1;
1925
}
2026

@@ -34,22 +40,23 @@ def Poly_FromTensorOp : Op<Poly_Dialect, "from_tensor", [Pure]> {
3440
let summary = "Creates a Polynomial from integer coefficients stored in a tensor.";
3541
let arguments = (ins TensorOf<[AnyInteger]>:$input);
3642
let results = (outs Polynomial:$output);
37-
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
43+
let assemblyFormat = "$input attr-dict `:` type($input) `->` qualified(type($output))";
3844
let hasFolder = 1;
3945
}
4046

41-
def Poly_EvalOp : Op<Poly_Dialect, "eval"> {
47+
def Poly_EvalOp : Op<Poly_Dialect, "eval", [AllTypesMatch<["point", "output"]>, Has32BitArguments]> {
4248
let summary = "Evaluates a Polynomial at a given input value.";
4349
let arguments = (ins Polynomial:$input, AnyInteger:$point);
4450
let results = (outs AnyInteger:$output);
45-
let assemblyFormat = "$input `,` $point attr-dict `:` `(` type($input) `,` type($point) `)` `->` type($output)";
51+
let assemblyFormat = "$input `,` $point attr-dict `:` `(` qualified(type($input)) `,` type($point) `)` `->` type($output)";
52+
let hasVerifier = 1;
4653
}
4754

4855
def Poly_ConstantOp : Op<Poly_Dialect, "constant", [Pure, ConstantLike]> {
4956
let summary = "Define a constant polynomial via an attribute.";
5057
let arguments = (ins AnyIntElementsAttr:$coefficients);
5158
let results = (outs Polynomial:$output);
52-
let assemblyFormat = "$coefficients attr-dict `:` type($output)";
59+
let assemblyFormat = "$coefficients attr-dict `:` qualified(type($output))";
5360
let hasFolder = 1;
5461
}
5562

lib/Dialect/Poly/PolyTraits.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#ifndef LIB_DIALECT_POLY_POLYTRAITS_H_
2+
#define LIB_DIALECT_POLY_POLYTRAITS_H_
3+
4+
#include "mlir/include/mlir/IR/OpDefinition.h"
5+
6+
namespace mlir::tutorial::poly {
7+
8+
template <typename ConcreteType>
9+
class Has32BitArguments : public OpTrait::TraitBase<ConcreteType, Has32BitArguments> {
10+
public:
11+
static LogicalResult verifyTrait(Operation *op) {
12+
for (auto type : op->getOperandTypes()) {
13+
// OK to skip non-integer operand types
14+
if (!type.isIntOrIndex()) continue;
15+
16+
if (!type.isInteger(32)) {
17+
return op->emitOpError()
18+
<< "requires each numeric operand to be a 32-bit integer";
19+
}
20+
}
21+
22+
return success();
23+
}
24+
};
25+
26+
}
27+
28+
#endif // LIB_DIALECT_POLY_POLYTRAITS_H_

tests/code_motion.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ module {
1515
%ret_val = affine.for %i = 0 to 100 iter_args(%sum_iter = %p0) -> !poly.poly<10> {
1616
// The poly.mul should be hoisted out of the loop.
1717
// CHECK-NOT: poly.mul
18-
%2 = poly.mul %p0, %p1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
19-
%sum_next = poly.add %sum_iter, %2 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
18+
%2 = poly.mul %p0, %p1 : !poly.poly<10>
19+
%sum_next = poly.add %sum_iter, %2 : !poly.poly<10>
2020
affine.yield %sum_next : !poly.poly<10>
2121
}
2222

tests/control_flow_sink.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ func.func @test_simple_sink(%arg0: i1) -> !poly.poly<10> {
1212
// CHECK: scf.if
1313
%4 = scf.if %arg0 -> (!poly.poly<10>) {
1414
// CHECK: poly.from_tensor
15-
%2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
15+
%2 = poly.mul %p0, %p0 : !poly.poly<10>
1616
scf.yield %2 : !poly.poly<10>
1717
// CHECK: else
1818
} else {
1919
// CHECK: poly.from_tensor
20-
%3 = poly.mul %p1, %p1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
20+
%3 = poly.mul %p1, %p1 : !poly.poly<10>
2121
scf.yield %3 : !poly.poly<10>
2222
}
2323
return %4 : !poly.poly<10>

tests/cse.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ func.func @test_simple_cse() -> !poly.poly<10> {
88
// exactly one mul op
99
// CHECK-NEXT: poly.mul
1010
// CHECK-NEXT: poly.add
11-
%2 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
12-
%3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
13-
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
11+
%2 = poly.mul %p0, %p0 : !poly.poly<10>
12+
%3 = poly.mul %p0, %p0 : !poly.poly<10>
13+
%4 = poly.add %2, %3 : !poly.poly<10>
1414
return %4 : !poly.poly<10>
1515
}

tests/poly_canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ func.func @test_simple() -> !poly.poly<10> {
66
// CHECK-NEXT: return
77
%0 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
88
%p0 = poly.from_tensor %0 : tensor<3xi32> -> !poly.poly<10>
9-
%2 = poly.add %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
10-
%3 = poly.mul %p0, %p0 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
11-
%4 = poly.add %2, %3 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
9+
%2 = poly.add %p0, %p0 : !poly.poly<10>
10+
%3 = poly.mul %p0, %p0 : !poly.poly<10>
11+
%4 = poly.add %2, %3 : !poly.poly<10>
1212
return %2 : !poly.poly<10>
1313
}

tests/poly_syntax.mlir

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ module {
1111
// CHECK-LABEL: test_op_syntax
1212
func.func @test_op_syntax(%arg0: !poly.poly<10>, %arg1: !poly.poly<10>) -> !poly.poly<10> {
1313
// CHECK: poly.add
14-
%0 = poly.add %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
14+
%0 = poly.add %arg0, %arg1 : !poly.poly<10>
1515
// CHECK: poly.sub
16-
%1 = poly.sub %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
16+
%1 = poly.sub %arg0, %arg1 : !poly.poly<10>
1717
// CHECK: poly.mul
18-
%2 = poly.mul %arg0, %arg1 : (!poly.poly<10>, !poly.poly<10>) -> !poly.poly<10>
18+
%2 = poly.mul %arg0, %arg1 : !poly.poly<10>
1919

2020
%3 = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
2121
// CHECK: poly.from_tensor
@@ -27,9 +27,7 @@ module {
2727

2828
%7 = tensor.from_elements %arg0, %arg1 : tensor<2x!poly.poly<10>>
2929
// CHECK: poly.add
30-
%8 = poly.add %7, %7 : (tensor<2x!poly.poly<10>>, tensor<2x!poly.poly<10>>) -> tensor<2x!poly.poly<10>>
31-
// CHECK: poly.add
32-
%9 = poly.add %7, %4 : (tensor<2x!poly.poly<10>>, !poly.poly<10>) -> tensor<2x!poly.poly<10>>
30+
%8 = poly.add %7, %7 : tensor<2x!poly.poly<10>>
3331

3432
// CHECK: poly.constant
3533
%10 = poly.constant dense<[2, 3, 4]> : tensor<3xi32> : !poly.poly<10>

0 commit comments

Comments
 (0)