#### Note

**Difference between sub-graph rewrite and operator resolving**. Although\n operator resolving can conceptually be considered a special case of\n sub-graph rewrite, the two are implemented with different mechanisms and run in a different order. First, operator\n resolving can be performed efficiently, so we can add an arbitrary number of operator resolve rules, whereas sub-graph\n rewrite is usually more expensive. Second, we run the sub-graph rewrite pass before the operator resolving pass, so\n that we can use the generic operators in the sub-graph patterns without worrying about operator resolving.

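The ordering argument in the note can be illustrated with a small toy model. This is only a sketch under stated assumptions: it uses plain Python tuples instead of Hidet graphs, and the names `rewrite_fuse_matmul`, `resolve_operators`, and `matmul_resolved` are made up for illustration and are not part of Hidet.

```python
# Toy expression graph: a node is a tuple (op_name, *inputs); leaves are strings.
# None of these names are Hidet APIs; they only illustrate the pass ordering.


def rewrite_fuse_matmul(node):
    '''Sub-graph rewrite: concat(matmul(x, a), matmul(x, b)) -> matmul(x, concat(a, b)).'''
    if isinstance(node, str):
        return node
    op, *args = node
    args = [rewrite_fuse_matmul(a) for a in args]
    if op == 'concat' and len(args) == 2:
        lhs, rhs = args
        if (
            isinstance(lhs, tuple)
            and isinstance(rhs, tuple)
            and lhs[0] == 'matmul'
            and rhs[0] == 'matmul'
            and lhs[1] == rhs[1]  # both matmuls read the same input
        ):
            return ('matmul', lhs[1], ('concat', lhs[2], rhs[2]))
    return (op, *args)


def resolve_operators(node):
    '''Operator resolving: replace the generic matmul with a made-up concrete variant.'''
    if isinstance(node, str):
        return node
    op, *args = node
    args = [resolve_operators(a) for a in args]
    if op == 'matmul':
        op = 'matmul_resolved'
    return (op, *args)


graph = ('concat', ('matmul', 'x', 'w1'), ('matmul', 'x', 'w2'))

# Rewrite first, then resolve: the pattern sees generic matmul nodes and fuses them.
rewrite_then_resolve = resolve_operators(rewrite_fuse_matmul(graph))

# Resolve first, then rewrite: the generic pattern no longer matches anything.
resolve_then_rewrite = rewrite_fuse_matmul(resolve_operators(graph))

print(rewrite_then_resolve)
print(resolve_then_rewrite)
```

In this toy setting, running the rewrite first lets a single pattern written against the generic `matmul` fire, while running the resolver first would require one pattern per resolved variant.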
\n\n\n## Add a sub-graph rewrite rule\n\nLet's implement the sub-graph rewrite rule shown in the above example. Before we start, we first create a new model\nthat contains the sub-graph we want to rewrite:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from typing import Optional, List\n\nimport hidet\nfrom hidet import Tensor, FlowGraph, Operator\nfrom hidet import ops\nfrom hidet.graph.transforms.graph_patterns import MatchDict\n\n\ndef example_model(x: Tensor, w0: Tensor, w1: Tensor, w2: Tensor):\n    x = ops.matmul(x, w0)\n    y1 = ops.matmul(x, w1)\n    y2 = ops.matmul(x, w2)\n    y = ops.relu(ops.concat([y1, y2], axis=1))\n    return y\n\n\nx = hidet.symbol([3, 3])\nw0, w1, w2 = hidet.randn([3, 3]), hidet.randn([3, 3]), hidet.randn([3, 3])\ny = example_model(x, w0, w1, w2)\ngraph: FlowGraph = hidet.trace_from(y, inputs=[x])\nprint(graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we define and register the sub-graph rewrite rule.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from hidet.graph.ops.matmul import MatmulOp\nfrom hidet.graph.ops.transform import ConcatOp\nfrom hidet.graph.transforms import TensorPattern, SubgraphRewriteRule\nfrom hidet.graph.transforms import op_pattern, register_rewrite_rule\nfrom hidet.utils import same_list\n\n\n# register the rewrite rule, only registered rewrite rules will be applied\n@register_rewrite_rule\nclass FuseTwoMatmulRewriteRule(SubgraphRewriteRule):\n    def __init__(self):\n        super().__init__(name=\"new: [matmul(x, c1), matmul(x,c2)] => matmul(x, [c1, c2])\")\n        self.x = TensorPattern()  # x can match either a symbolic or concrete tensor\n        self.c1 = TensorPattern(is_const=True)  # c1 can only match a concrete tensor\n        self.c2 = TensorPattern(is_const=True)  # c2 can only match a concrete tensor\n        self.y1 = op_pattern(MatmulOp, [self.x, self.c1])  # pattern: y1 = matmul(x, c1)\n        self.y2 = op_pattern(MatmulOp, [self.x, self.c2])  # pattern: y2 = matmul(x, c2)\n        self.y = op_pattern(ConcatOp, [self.y1, self.y2])  # pattern: y = concat(y1, y2)\n\n    def source(self) -> List[TensorPattern]:\n        # Return the output tensors of the source sub-graph pattern.\n        return [self.y]\n\n    def target(self, matched: MatchDict) -> Optional[List[Tensor]]:\n        # The target sub-graph constructor\n        # The matched dictionary has type Dict[TensorPattern, Tensor]\n        # that maps the patterns to the matched tensors.\n        x, c1, c2, y = [matched[t] for t in [self.x, self.c1, self.c2, self.y]]\n        concat: Operator = y.op\n\n        # We can use the matched tensors to determine whether the rewrite is applicable.\n        # For example, in this case, we check whether the two weight matrices share the\n        # same shape except the last dimension.\n        if (\n            2 <= len(c1.shape) == len(c2.shape)\n            and same_list(c1.shape[:-1], c2.shape[:-1])\n            and concat.attrs[\"axis\"] == len(y.shape) - 1\n        ):\n            # If applicable, we construct the target sub-graph and return the output tensors.\n            c = ops.concat([c1, c2], axis=-1)\n            y = ops.matmul(x, c)\n            return [y]\n        else:\n            # If not, we return None to indicate that the rewrite is not applicable.\n            return None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check that the rewrite rule has been registered:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from hidet.graph.transforms import registered_rewrite_rules, clear_registered_rewrite_rules\n\nprint('Registered rewrite rules:')\nfor rule in registered_rewrite_rules():\n    assert isinstance(rule, SubgraphRewriteRule)\n    print(rule.name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Apply the rewrite rule\nIn the output above, the rewrite rule we just registered appears on the last line, after the predefined rewrite\nrules. To prevent the predefined rewrite rules from being applied in this tutorial, we first clear the registered\nrewrite rules and then register the rule we just defined again:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"clear_registered_rewrite_rules()\nregister_rewrite_rule(FuseTwoMatmulRewriteRule()) # a second way to register the rewrite rule"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rewrite process is done in a graph optimization pass called `subgraph_rewrite_pass`.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from hidet.graph.transforms import subgraph_rewrite_pass\n\nrewrite_pass = subgraph_rewrite_pass()\nrewritten_graph: FlowGraph = rewrite_pass(graph)\nprint(rewritten_graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that the rewritten graph contains 2 matmul operators instead of 3. There is no concat operator in the\nrewritten graph, because the inputs of the concat operator are constant tensors and thus have been folded.\n\nWe do not need to call the rewrite pass explicitly. It is called automatically when we call\n`hidet.graph.optimize`, together with the other graph optimization passes.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"graph_opt: FlowGraph = hidet.graph.optimize(graph)\nprint(graph_opt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Summary\nIn this tutorial, we have learned how to define and register a sub-graph rewrite rule. It is an important\ncomponent of the graph optimization framework. Hidet uses it to implement some horizontal fusion rules.\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
}
},
"nbformat": 4,
"nbformat_minor": 0
}