# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel
from __future__ import annotations

from typing import Any, Dict, List, Union, Callable, Optional, Tuple
import os
import enum
import pickle
from hashlib import sha256

from hidet.ir.node import Node
from hidet.ir.type import FuncType, VoidType
from hidet.ir.expr import Expr, Var, SymbolVar, var, is_constant
from hidet.ir.module import IRModule
from hidet.ir.compute import ComputeNode, TensorNode, TensorInput, ScalarInput, GridCompute
from hidet.ir.target import Target
from hidet.ir.cute import TensorLayout


class InverseMap(Node):
def __init__(self, axes: List[Var], indices: List[Expr]):
from hidet.ir.tools import simplify
self.axes: List[Var] = axes
self.indices: List[Expr] = [simplify(e) for e in indices]
self.tile_mapping: Optional[TensorLayout] = None
@staticmethod
def from_obj(obj: Union[InverseMap, Callable[[Any], Any]]):
if isinstance(obj, InverseMap):
return obj
else:
return InverseMap.from_lambda(lambda *args: obj(*args)) # pylint: disable=unnecessary-lambda
@staticmethod
def from_lambda(func, num_args=None) -> InverseMap:
from hidet.ir.utils import as_expr
num_args = num_args if num_args is not None else func.__code__.co_argcount
        axes = [var('v') for _ in range(num_args)]
indices = [as_expr(index_expr) for index_expr in func(*axes)]
return InverseMap(axes, indices)
@staticmethod
def identity(num_args: int) -> InverseMap:
return InverseMap.from_lambda(lambda *indices: list(indices), num_args=num_args)
def __add__(self, other) -> InverseMap:
from hidet.ir.tools import rewrite
if not isinstance(other, InverseMap):
            raise ValueError('Cannot concatenate InverseMap with {}'.format(type(other)))
lhs, rhs = self, other
if len(lhs.indices) != len(rhs.axes):
raise ValueError(
                'Cannot concatenate InverseMap a and b, '
                'where a has {} indices and b has {} axes'.format(len(lhs.indices), len(rhs.axes))
)
rmap = dict(zip(rhs.axes, lhs.indices))
indices = [rewrite(index_expr, rmap) for index_expr in rhs.indices]
return InverseMap(lhs.axes, indices)
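
# How an InverseMap is used, as a minimal sketch. If a computation reads
# input[j, i] to produce output[i, j] (a transpose), then the input element at
# (p, q) contributes to the output element at (q, p), so its inverse map is:
#
#     im = InverseMap.from_lambda(lambda p, q: [q, p])
#
# Inverse maps compose with `+`: `im + InverseMap.identity(2)` rewrites the
# identity map's indices in terms of `im`, yielding the transpose map again.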


class Task(Node):
"""
A task defines the operator computation.
Attributes
----------
name: str
The name of the task.
inputs: List[TensorInput]
The input tensors of this task.
outputs: List[TensorNode]
The output tensors of this task. They are derived by applying computation on the input tensors.
inverse_map: Dict[TensorInput, InverseMap]
The inverse map. It records how the input tensors are derived from the output tensor (when there is only one
output tensor). This is used to support epilogue fusion.
attrs: Dict[str, Union[str, float, int, bool]]
The attributes of this task.
assertions: List[Tuple[Expr, Optional[str]]]
The assertions of this task. Each assertion is a tuple of an expression and an optional message. The
expression is evaluated at runtime. If the expression is not true, the program will abort and print the
message.
share_map: Dict[int, int]
The share map. If one output tensor shares memory with one input tensor, it is specified in this map. For
example, `share_map = {0: 0, 1: 2}` means that the output tensor 0 shares the memory with input tensor 0, and
output tensor 1 shares the memory with input tensor 2.
symbols: List[SymbolVar]
The list of symbols used in this task.
"""

    def __init__(self, name, inputs, outputs, *, inverse_map=None, attributes=None, share_map=None):
inverse_map = inverse_map if inverse_map else {}
attributes = attributes if attributes else {}
share_map = share_map if share_map else {}
self.name: str = name
self.inputs: List[TensorInput] = list(inputs)
self.outputs: List[TensorNode] = list(outputs)
self.inverse_map: Dict[TensorInput, InverseMap] = {a: InverseMap.from_obj(b) for a, b in inverse_map.items()}
self.attrs: Dict[str, Union[str, float, int, bool]] = attributes
self.assertions: List[Tuple[Expr, Optional[str]]] = getattr(self, 'assertions', [])
self.share_map: Dict[int, int] = share_map
        self.str: Optional[str] = None
from hidet.ir.tools import collect
self.symbols: List[SymbolVar] = list(collect(self.outputs, SymbolVar))
self._sanity_check()

    def _assert(self, expr: Union[Expr, bool], msg: Optional[str] = None):
        from hidet.ir.tools import simplify

        simplified = simplify(expr)
        if is_constant(simplified):
            # a constant condition is checked immediately
            assert simplified, msg
        else:
            # a symbolic condition is recorded and checked at runtime; note that
            # subclasses may call _assert before Task.__init__ has initialized self.assertions
            if hasattr(self, 'assertions'):
                self.assertions.append((expr, msg))
            else:
                self.assertions = [(expr, msg)]
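
    # For example, an illustrative sketch inside a Task subclass's __init__
    # (assuming hidet.symbol_var for creating a symbolic dimension):
    #
    #     n = hidet.symbol_var('n')
    #     self._assert(n % 8 == 0, msg='n must be a multiple of 8')
    #
    # A constant condition is asserted right away; the symbolic one above is
    # deferred and checked when the compiled task runs.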

    @property
def params(self) -> List[TensorNode]:
return [*self.inputs, *self.outputs]
def _sanity_check(self):
from hidet.ir.tools import collect_free_vars, collect
for tn, im in self.inverse_map.items():
if len(im.axes) != tn.ndim:
raise ValueError(
'InverseMap for tensor {} has {} input axes, but input tensor has {} axes'.format(
tn.name, len(im.axes), tn.ndim
)
)
if len(im.indices) != self.outputs[0].ndim:
raise ValueError(
'InverseMap for tensor {} has {} output indices, but output tensor has {} axes'.format(
tn.name, len(im.indices), self.outputs[0].ndim
)
)
free_vars: List[Var] = collect_free_vars(self.outputs)
if any(
v not in self.params and not isinstance(v.type, FuncType) and not isinstance(v, SymbolVar)
for v in free_vars
):
raise ValueError('Some free variables are not in params: {}'.format(free_vars))
        # check that every TensorInput used in the outputs appears among the task parameters
        used_inputs = collect(self.outputs, TensorInput)
        if any(x not in self.inputs + self.outputs for x in used_inputs):
            raise ValueError('Some TensorInput used in outputs are not listed in params: {}'.format(used_inputs))
        # check that every symbol appearing in the assertions is among the task's defined symbols
assert_symbols: List[SymbolVar] = list(collect([cond for cond, _ in self.assertions], SymbolVar))
for sym in assert_symbols:
assert sym in self.symbols, f"encountered {sym} in assertions, but not in list of defined symbols"
def has_symbolic_shape(self) -> bool:
from hidet.ir.tools import collect
return len(collect(self.outputs, SymbolVar)) > 0
def signature(self) -> str:
params = []
for tensor in self.tensor_params:
name = tensor.name
dtype = tensor.type.dtype.name
params.append('{}={}{}'.format(name, dtype, tensor.type.shape))
for name, value in self.attrs.items():
if isinstance(value, enum.Enum):
value_str = value.name
else:
value_str = repr(value)
params.append('{}={}'.format(name, value_str))
param_doc = ', '.join(params)
fuse_doc = ''
return ''.join([self.name, '(', param_doc, ')', fuse_doc])

    def generate_arguments(self, inputs, outputs):
"""
Generate arguments for the compiled function of this task given the tensor parameters.
Parameters
----------
inputs: Sequence[Tensor]
The input tensors.
outputs: Sequence[Tensor]
The output tensors.
Returns
-------
args: Sequence[Tensor or int]
The arguments for the compiled function.
"""
remap = {a: b for a, b in zip(self.inputs, inputs)}
remap.update({a: b for a, b in zip(self.outputs, outputs)})
return [remap[arg] for arg in self.params]
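
    # For example: for a task with inputs [a, b] and outputs [c] (so params are
    # [a, b, c]), task.generate_arguments([x, y], [z]) returns [x, y, z]: the
    # runtime tensors reordered to match the task's parameter order.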

    @property
def tensor_params(self) -> List[TensorNode]:
ret: List[TensorNode] = []
ret.extend(self.inputs)
ret.extend(self.outputs)
return ret

    def dummy_arguments(self, device: str):
"""
Generate dummy arguments for the compiled function of this task.
Parameters
----------
device: str
The target device.
Returns
-------
args: Sequence[Tensor or int]
The arguments for the compiled function.
"""
import hidet
from hidet.graph.tensor import Tensor
arguments: List[Union[Tensor, int]] = []
for param in self.params:
if isinstance(param, Var):
arguments.append(10)
elif isinstance(param, TensorNode):
if param.type.dtype.is_integer():
arguments.append(hidet.zeros(param.const_shape, dtype=param.type.dtype, device=device))
elif param.type.dtype.is_float():
arguments.append(hidet.randn(param.const_shape, dtype=param.type.dtype, device=device))
else:
raise ValueError('Unknown dtype: {}'.format(param.type.dtype))
else:
raise ValueError('Unknown parameter type: {}'.format(type(param)))
return arguments

    def build(self, target: Union[str, Target], load: bool = True):
"""
Build the task for the given target to a callable function.
Parameters
----------
target: Union[str, Target]
The target device.
load: bool
Whether to load the task
Returns
-------
func: hidet.runtime.CompiledTask
The compiled module.
"""
from hidet.drivers import build_task
if isinstance(target, Target):
target = target.name
return build_task(self, target=target, load=load)
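
    # A minimal usage sketch of build (assuming a CUDA toolchain is available):
    #
    #     compiled = task.build(target='cuda')   # hidet.runtime.CompiledTask
    #
    # The compiled task can then be invoked with hidet tensors matching the
    # task's parameters.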

    def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]:
from hidet.ir.schedulers import GpuAutoScheduler, CpuAutoScheduler
if isinstance(target, str):
target = Target.from_string(target)
implement_target, scheduler = {
'cuda': (self.implement_cuda, GpuAutoScheduler('cuda')),
'hip': (self.implement_hip, GpuAutoScheduler('hip')),
'cpu': (self.implement_cpu, CpuAutoScheduler()),
}[target.name]
ir_modules: Union[IRModule, List[IRModule]] = implement_target(working_dir)
if ir_modules is NotImplemented:
ir_modules = [scheduler.schedule_task(self, target.name)]
elif isinstance(ir_modules, IRModule):
ir_modules = [ir_modules]
elif isinstance(ir_modules, (list, tuple)) and all(isinstance(x, IRModule) for x in ir_modules):
ir_modules = list(ir_modules)
else:
raise ValueError(
'Expect the `implement` method to return an IRModule or List[IRModule], got {}'.format(ir_modules)
)
for ir_module in ir_modules:
ir_module.task = self
return ir_modules
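
    # Subclasses typically override one of the implement_* hooks below; a
    # sketch (the class and helper names here are hypothetical):
    #
    #     class MatmulTask(Task):
    #         def implement_cuda(self, working_dir: str):
    #             if self.has_symbolic_shape():
    #                 return NotImplemented      # fall back to the auto-scheduler
    #             return self._hand_written_ir_module()   # -> IRModule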

    def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        return NotImplemented

    def implement_hip(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        return NotImplemented

    def implement_cpu(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
        return NotImplemented
def allow_prologue(self) -> bool:
return True
def allow_epilogue(self) -> bool:
return True
def is_injective(self) -> bool:
from hidet.ir.tools import collect
allowed_nodes = (ScalarInput, TensorInput, GridCompute)
        # if any other node (e.g., ReduceCompute or ArgReduceCompute) is found, the task is not injective
found_nodes = collect(self.outputs, ComputeNode, stop_when_found=False)
return all(isinstance(node, allowed_nodes) for node in found_nodes)
def is_bijective(self) -> bool:
return self.is_injective() and len(self.inverse_map) > 0
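
    # Sketch of the distinction (assuming the tensor_input, compute, and reduce
    # helpers in hidet.ir.compute):
    #
    #     x = tensor_input('x', 'float32', [16])
    #     y = compute('y', [16], lambda i: x[i] + 1.0)    # only GridCompute nodes: injective
    #     s = compute('s', [1], lambda _: reduce([16], lambda k: x[k], 'sum'))
    #     # `s` contains a ReduceCompute node, so a task computing it is not injective
    #
    # is_bijective() additionally requires an inverse map to be provided.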

    def save(self, fname: str):
        dirname = os.path.dirname(fname)
        if dirname:
            # avoid calling os.makedirs('') when fname has no directory component
            os.makedirs(dirname, exist_ok=True)
        with open(fname, 'wb') as f:
            pickle.dump(self, f)
@staticmethod
def load(fname: str) -> Task:
with open(fname, 'rb') as f:
return pickle.load(f)

    def calculate_hash(self, len: int = 16) -> str:  # pylint: disable=redefined-builtin
        return sha256(str(self).encode()).hexdigest()[:len]
def __str__(self):
if self.str is None:
self.str = super().__str__()
return self.str


def save_task(task: Task, fname: str):
    task.save(fname)


def load_task(fname: str) -> Task:
    return Task.load(fname)


def task_compiled_func_type(task: Task) -> FuncType:
    from hidet.ir.tools import infer_type

    return FuncType(param_types=[infer_type(t) for t in task.params], ret_type=VoidType())