# Source code for hidet.graph.transforms.resolve_variant

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Type, List, Optional, Dict
import logging
from hidet.ir.expr import is_constant
from hidet.graph.flow_graph import FlowGraph, Tensor, Operator
from hidet.graph.graph_utils.functors import GraphRewriter
from hidet.utils import strict_zip, same_list, repeat_until_converge
from .base import GraphPass, PassContext


logger = logging.getLogger(__name__)


class ResolveRule:
    """
    A resolve rule defines how to resolve an operator to other operators.
    """

    def resolve(self, op: Operator) -> Optional[List[Tensor]]:
        """
        When defining a resolve rule, the user should subclass this class and override this method.

        Parameters
        ----------
        op: Operator
            The operator to be resolved.

        Returns
        -------
        ret: List[Tensor], optional
            This function should return a list of tensors if the operator can be resolved, otherwise
            return None. In the first case, the returned tensors will be used to replace the outputs
            of the original operator, thus the number of tensors should be the same as the number of
            outputs of the original operator.
        """
        raise NotImplementedError()

    def get_config(self, name, default=None):
        return PassContext.current().configs.get(name, default)
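
# Illustrative sketch (not part of this module): a rule can consult the current
# PassContext configuration via ``get_config`` and decline to resolve by returning
# None. ``MyOp``, ``my_fast_variant`` and the ``'use_fast_variant'`` config name
# are hypothetical, chosen only for illustration:
#
#     @register_resolve_rule(MyOp)
#     class MyOpRule(ResolveRule):
#         def resolve(self, op: Operator) -> Optional[List[Tensor]]:
#             if not self.get_config('use_fast_variant', default=True):
#                 return None  # declining keeps the original operator unchanged
#             return [my_fast_variant(op.inputs[0])]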


class ResolveRuleChain:
    def __init__(self, op_cls: Type[Operator], rules: List[ResolveRule]):
        self.op_cls: Type[Operator] = op_cls
        self.rules: List[ResolveRule] = list(rules)

    def resolve(self, op: Operator) -> Optional[List[Tensor]]:
        # apply rules in reverse order, so that the latest registered rule has the highest priority
        for rule in reversed(self.rules):
            outs = rule.resolve(op)
            if outs is not None:
                return outs
        return None


registered_resolve_rules: Dict[Type[Operator], ResolveRuleChain] = {}
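
# Illustrative sketch: when several rules are registered for one operator class,
# the chain consults the most recently registered rule first and falls back to
# earlier ones only when a rule returns None. ``SomeOp``, ``RuleA`` and ``RuleB``
# are hypothetical names:
#
#     @register_resolve_rule(SomeOp)
#     class RuleA(ResolveRule): ...
#
#     @register_resolve_rule(SomeOp)
#     class RuleB(ResolveRule): ...
#
#     # registered_resolve_rules[SomeOp].resolve(op) tries RuleB.resolve(op) first,
#     # and RuleA.resolve(op) only if RuleB returned None.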


def register_resolve_rule(op_cls: Type[Operator]):
    """
    Register a resolve rule for an operator class.

    Parameters
    ----------
    op_cls: Type[Operator]
        The operator class to be registered.

    Returns
    -------
    ret: Callable[[Type[ResolveRule]], Type[ResolveRule]]
        The decorator function.

    Notes
    -----
    In the following example, we define a resolve rule for operator ``PowOp`` to resolve
    ``pow(x, 2.0)`` to ``square(x)``.

    .. code-block:: python

        from typing import Optional, List
        from hidet import ops
        from hidet.graph import Tensor
        from hidet.graph.ops import PowOp
        from hidet.graph.transforms import ResolveRule, register_resolve_rule

        @register_resolve_rule(PowOp)
        class PowResolveRule(ResolveRule):
            def resolve(self, op: PowOp) -> Optional[List[Tensor]]:
                a: Tensor = op.inputs[0]
                b: Tensor = op.inputs[1]
                if not b.is_symbolic() and len(b.shape) == 0 and b.scalar() == 2:
                    return [ops.square(a)]
                return None
    """
    if not issubclass(op_cls, Operator):
        raise ValueError("Expected a subclass of Operator, got {}".format(op_cls))

    def wrapper(rule_cls: Type[ResolveRule]):
        if not issubclass(rule_cls, ResolveRule):
            raise ValueError("Expected a subclass of ResolveRule, got {}".format(rule_cls))
        if op_cls not in registered_resolve_rules:
            registered_resolve_rules[op_cls] = ResolveRuleChain(op_cls, [])
        chain = registered_resolve_rules[op_cls]
        chain.rules.append(rule_cls())
        return rule_cls

    return wrapper


def get_resolve_chain(op_cls: Type[Operator]) -> Optional[ResolveRuleChain]:
    if op_cls not in registered_resolve_rules:
        return None
    return registered_resolve_rules[op_cls]


class ResolveVariantRewriter(GraphRewriter):
    def __init__(self, op_cls: Type[Operator], rule_chain: ResolveRuleChain):
        super().__init__()
        self.op_cls: Type[Operator] = op_cls
        self.rule_chain: ResolveRuleChain = rule_chain

    def visit_Operator(self, op: Operator):
        if not isinstance(op, self.op_cls):
            GraphRewriter.visit_Operator(self, op)
            return
        inputs = [self(x) for x in op.inputs]
        if same_list(inputs, op.inputs):
            resolve_op = op
        else:
            updated_outputs = op.reforward(inputs)
            resolve_op = updated_outputs[0].op
        outs = self.rule_chain.resolve(resolve_op)
        if outs is None:
            # keep the original operator
            # we still need to update memo in case inputs changed
            assert all(original not in self.memo for original in op.outputs)
            for original, updated in zip(op.outputs, resolve_op.outputs):
                self.memo[original] = updated
        else:
            logger.debug("Resolve operator %s", op.name)
            # update output of resolved operator
            if not isinstance(outs, (list, tuple)):
                raise ValueError(
                    "The resolve rule of operator '{}' should return a list of tensors, but got {}".format(
                        op.name, type(outs)
                    )
                )
            if len(outs) != len(op.outputs):
                raise ValueError(
                    "The resolve rule of operator '{}' should return {} tensors, but got {}".format(
                        op.name, len(op.outputs), len(outs)
                    )
                )
            for i, (original, updated) in enumerate(strict_zip(op.outputs, outs)):
                assert original not in self.memo
                if not self.is_compatible_output(original, updated):
                    raise ValueError(
                        (
                            "The resolve rule of operator '{}' should return tensors with the same dtype and "
                            "shape as the original ones. The {}-th tensor expects {}{} but got {}{}"
                        ).format(op.name, i, original.dtype, list(original.shape), updated.dtype, list(updated.shape))
                    )
            for original, updated in zip(op.outputs, outs):
                self.memo[original] = updated

    @staticmethod
    def is_compatible_output(a: Tensor, b: Tensor):
        if a.dtype != b.dtype:
            return False
        if len(a.shape) != len(b.shape):
            return False
        for va, vb in zip(a.shape, b.shape):
            # only compare dimensions where both extents are constant; symbolic dims are considered compatible
            if is_constant(va, vb) and va != vb:
                return False
        return True


class ResolveVariantPass(GraphPass):
    def process_graph(self, input_graph: FlowGraph) -> FlowGraph:
        def apply_rules(graph: FlowGraph) -> FlowGraph:
            for op_cls, rule_chain in registered_resolve_rules.items():
                rewriter = ResolveVariantRewriter(op_cls, rule_chain)
                graph = rewriter(graph)
            return graph

        # keep applying all registered rules until the graph stops changing
        return repeat_until_converge(apply_rules, input_graph, limit=None)


def resolve_variant_pass() -> GraphPass:
    return ResolveVariantPass()
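
# Illustrative usage sketch (assuming an existing FlowGraph named ``graph``):
# GraphPass instances are callable on a FlowGraph, so the pass can be applied
# directly; this is a sketch rather than the canonical way to run graph passes.
#
#     graph_pass = resolve_variant_pass()
#     graph = graph_pass(graph)  # applies all registered rules until convergence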