Resolve Operator Pass
- class hidet.graph.transforms.resolve_variant.ResolveRule
A resolve rule defines how to resolve an operator to other operators.
- resolve(op)
To define a resolve rule, subclass this class and override this method.
- Parameters:
op (Operator) – The operator to be resolved.
- Returns:
ret – A list of tensors if the operator can be resolved, otherwise None. When a list is returned, its tensors replace the outputs of the original operator, so it must contain exactly as many tensors as the original operator has outputs.
- Return type:
List[Tensor], optional
- hidet.graph.transforms.resolve_variant.register_resolve_rule(op_cls)
Register a resolve rule for an operator class.
- Parameters:
op_cls (Type[Operator]) – The operator class to be registered.
- Returns:
ret – The decorator function.
- Return type:
Callable[[Type[ResolveRule]], Type[ResolveRule]]
Notes
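In the following example, we define a resolve rule for the operator PowOp that resolves pow(x, 2.0) to square(x).

```python
from typing import List, Optional

from hidet.ir import Tensor
from hidet import ops
from hidet.graph.ops import PowOp
from hidet.graph.transforms import ResolveRule, register_resolve_rule


@register_resolve_rule(PowOp)
class AddResolveRule(ResolveRule):
    def resolve(self, op: PowOp) -> Optional[List[Tensor]]:
        a: Tensor = op.inputs[0]  # base
        b: Tensor = op.inputs[1]  # exponent
        # Rewrite only when the exponent is a non-symbolic scalar
        # (0-dimensional tensor) whose value equals 2.
        if not b.is_symbolic() and len(b.shape) == 0 and b.scalar() == 2:
            # One tensor is returned because PowOp has one output.
            return [ops.square(a)]
        # Returning None leaves the operator unchanged.
        return None
```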