Sub-graph Rewrite Pass
- class hidet.graph.transforms.subgraph_rewrite.TensorPattern(is_const=False, is_symbolic=False, trace=None)
The tensor pattern represents a tensor in the pattern graph.
- class hidet.graph.transforms.subgraph_rewrite.OperatorPattern(op_cls, inputs, num_outputs=1)
The operator pattern represents an operator in the pattern graph.
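To make the roles of the two pattern classes concrete, here is a minimal, self-contained sketch that does not use hidet. `ToyTensor`, `ToyOp`, `TensorPat`, `OpPat`, `op_pat`, and `match` are hypothetical stand-ins written for illustration only; they are not hidet's classes or its matching algorithm. The sketch assumes that a tensor pattern with `is_const=True` matches only constant tensors, and that an operator pattern matches an operator of the given kind whose inputs recursively match its input patterns. The result is a dict from pattern nodes to graph nodes, the same shape as the `matched` argument passed to `target()`.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

# --- Toy computation graph (hypothetical, not hidet's Tensor/Operator) ---
@dataclass(eq=False)  # eq=False keeps identity hashing so nodes can be dict keys
class ToyTensor:
    name: str
    is_const: bool = False
    producer: Optional["ToyOp"] = None  # operator that produced this tensor, if any

@dataclass(eq=False)
class ToyOp:
    kind: str
    inputs: List[ToyTensor]
    outputs: List[ToyTensor] = field(default_factory=list)

def apply_op(kind: str, *inputs: ToyTensor) -> ToyTensor:
    """Create a single-output operator and return its output tensor."""
    op = ToyOp(kind, list(inputs))
    out = ToyTensor(f"{kind}_out", producer=op)
    op.outputs.append(out)
    return out

# --- Toy pattern graph (hypothetical stand-ins for TensorPattern / OperatorPattern) ---
@dataclass(eq=False)
class TensorPat:
    is_const: bool = False              # assumed meaning: only match constant tensors
    producer: Optional["OpPat"] = None  # set when the pattern is an operator output

@dataclass(eq=False)
class OpPat:
    kind: str                           # plays the role of op_cls
    inputs: List[TensorPat]

def op_pat(kind: str, *inputs: TensorPat) -> TensorPat:
    """Create an operator pattern and return the pattern of its single output."""
    return TensorPat(producer=OpPat(kind, list(inputs)))

def match(pat: TensorPat, t: ToyTensor, m: Dict[Union[TensorPat, OpPat], object]) -> bool:
    """Recursively match `pat` against `t`, recording bindings in `m`.

    Pass a fresh dict for each attempt: a failed match may leave partial bindings.
    """
    if pat in m:                        # a pattern used twice must bind the same tensor
        return m[pat] is t
    if pat.is_const and not t.is_const:
        return False
    if pat.producer is not None:        # the pattern requires a producing operator
        op = t.producer
        if op is None or op.kind != pat.producer.kind or len(op.inputs) != len(pat.producer.inputs):
            return False
        m[pat.producer] = op
        if not all(match(p, a, m) for p, a in zip(pat.producer.inputs, op.inputs)):
            return False
    m[pat] = t
    return True

# Usage: match relu(matmul(x, w)) where w must be a constant.
x, w = ToyTensor("x"), ToyTensor("w", is_const=True)
y = apply_op("relu", apply_op("matmul", x, w))

px, pw = TensorPat(), TensorPat(is_const=True)
pattern_out = op_pat("relu", op_pat("matmul", px, pw))

matched = {}
print(match(pattern_out, y, matched))      # True
print(matched[px] is x, matched[pw] is w)  # True True
```

Returning the output pattern from `op_pat` mirrors the idea that a rule's source() is expressed in terms of output tensor patterns rather than operators.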
- class hidet.graph.transforms.subgraph_rewrite.SubgraphRewriteRule(name='')
A sub-graph rewrite rule defines a sub-graph pattern (called the source) to match in the computation graph and a target sub-graph constructor that builds the replacement for the matched sub-graph.
To define a new sub-graph rewrite rule, create a subclass of SubgraphRewriteRule and implement its source() and target() methods. The source() method returns a list of output tensors of the sub-graph pattern. The target() method receives the match dict, which maps the tensors/operators in the pattern to the matched tensors/operators in the computation graph, and returns a list of output tensors of the target sub-graph.
After defining the sub-graph rewrite rule, register it in the sub-graph rewrite rule registry via register_rewrite_rule().
- source()
The output tensors in the source template graph to match in the computation graph.
- Return type:
List[TensorPattern]
- target(matched)
The output tensors of the target sub-graph used to replace the matched pattern. Returning None means the target sub-graph could not be generated, in which case the transformation is not applied.
- Parameters:
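matched (Dict[TensorPattern | OperatorPattern, Tensor | Operator]) – The match dict that maps the tensors/operators in the pattern to the matched tensors/operators in the computation graph.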
- Return type:
List[Tensor] | None
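The source()/target() contract, including the "return None to decline" case, can be illustrated with a small self-contained sketch. Everything here is hypothetical and independent of hidet: `ToyRewriteRule` only mirrors the documented method names and return types, `Var`/`Const`/`Apply` stand in for graph tensors, and `PatTensor`/`PatApply` stand in for the pattern classes (for brevity, an operator pattern doubles as its own output pattern). The example rule rewrites div(x, c) into mul(x, 1/c) and declines when the constant is zero. The match dict is written by hand where a real matcher would produce it.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

# Toy graph nodes (hypothetical, not hidet's Tensor/Operator).
@dataclass(eq=False)
class Var:
    name: str

@dataclass(eq=False)
class Const:
    value: float

@dataclass(eq=False)
class Apply:
    op: str
    args: Tuple["Node", ...]

Node = Union[Var, Const, Apply]

# Toy pattern nodes (hypothetical stand-ins for TensorPattern / OperatorPattern).
@dataclass(eq=False)
class PatTensor:
    is_const: bool = False

@dataclass(eq=False)
class PatApply:
    op: str
    args: Tuple[Union[PatTensor, "PatApply"], ...]

class ToyRewriteRule:
    """Mirrors the documented interface: source() and target(matched)."""
    def __init__(self, name: str = ""):
        self.name = name

    def source(self) -> List[object]:
        raise NotImplementedError

    def target(self, matched: Dict[object, Node]) -> Optional[List[Node]]:
        raise NotImplementedError

class DivByConstToMul(ToyRewriteRule):
    """div(x, c) => mul(x, 1 / c), only when c is a non-zero constant."""
    def __init__(self):
        super().__init__("div(x, c) => mul(x, 1/c)")
        self.x = PatTensor()
        self.c = PatTensor(is_const=True)
        self.y = PatApply("div", (self.x, self.c))

    def source(self) -> List[object]:
        return [self.y]  # output patterns to look for in the graph

    def target(self, matched: Dict[object, Node]) -> Optional[List[Node]]:
        x, c = matched[self.x], matched[self.c]
        if c.value == 0:
            return None  # cannot build the target: leave the graph unchanged
        return [Apply("mul", (x, Const(1.0 / c.value)))]

rule = DivByConstToMul()
a, four = Var("a"), Const(4.0)
graph_out = Apply("div", (a, four))

# A matcher (not shown) would produce this dict for the graph div(a, 4.0).
matched = {rule.x: a, rule.c: four, rule.y: graph_out}
new_out = rule.target(matched)[0]
print(new_out.op, new_out.args[1].value)  # mul 0.25

# With c == 0 the rule declines, so the graph would be left unchanged.
print(rule.target({rule.x: a, rule.c: Const(0.0), rule.y: graph_out}))  # None
```

Keeping the pattern nodes as attributes of the rule instance is what lets target() look up the matched graph nodes by the same objects that source() returned.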
- hidet.graph.transforms.subgraph_rewrite.register_rewrite_rule(rule)
Register a sub-graph rewrite rule.
- Parameters:
rule (SubgraphRewriteRule or Type[SubgraphRewriteRule]) – The rule to be registered. If it is a type, it will be instantiated with default arguments. Otherwise, it should be an instance of SubgraphRewriteRule.