# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import enum
from typing import Sequence, Tuple, Any, List, Union, Optional

from hidet.ir.node import Node
from hidet.ir.type import DataType, PointerType, TensorPointerType, ReferenceType
from hidet.ir.expr import Var, Expr, convert, Constant
from hidet.ir.compute import TensorNode
from hidet.ir.mapping import TaskMapping

# scope
class DeclareScope(enum.Enum):
    """
    The scope of a tensor variable used in declaration statement.
    """

    Default = 0
    Global = 1
    Shared = 2
    Register = 3

    @staticmethod
    def from_str(name):
        """
        Convert a scope name to a DeclareScope member.

        Parameters
        ----------
        name:
            The scope name. 'global', 'shared' and 'register' map to the corresponding member.

        Returns
        -------
        scope: DeclareScope
            The matching member; any other value (including None) maps to DeclareScope.Default.
        """
        # NOTE: a name->member dict cannot live on the class body, because Enum would turn it into a member.
        if name == 'global':
            return DeclareScope.Global
        elif name == 'shared':
            return DeclareScope.Shared
        elif name == 'register':
            return DeclareScope.Register
        else:
            return DeclareScope.Default


class ForStmtAttr:
    """
    Scheduling attributes of a single for loop: unrolling and parallelization.

    Parameters
    ----------
    unroll: bool
        Whether the loop is marked to be unrolled.
    unroll_factor: Optional[int]
        The unroll factor, or None when no factor was specified.
    unroll_explicit: bool
        Whether the loop is unrolled by hidet itself instead of by the underlying compiler.
    parallel: bool
        Whether the loop is marked to run in parallel.
    parallel_threads: Optional[int]
        The number of threads of a parallel loop, or None to use the available number of threads.
    """

    def __init__(self, unroll=False, unroll_factor=None, unroll_explicit=False, parallel=False, parallel_threads=None):
        self.unroll: bool = unroll
        self.unroll_factor: Optional[int] = unroll_factor
        self.unroll_explicit: bool = unroll_explicit
        self.parallel: bool = parallel
        self.parallel_threads: Optional[int] = parallel_threads

    def __str__(self):
        """Render this attribute in the single-loop syntax accepted by `parse` ('u+', 'uN', 'u', 'pN', 'p', '.')."""
        if self.unroll:
            if self.unroll_explicit:
                return 'u+'
            elif self.unroll_factor:
                return f'u{self.unroll_factor}'
            else:
                return 'u'
        elif self.parallel:
            if self.parallel_threads:
                return f'p{self.parallel_threads}'
            else:
                return 'p'
        else:
            return '.'

    @staticmethod
    def from_extent(extent: Union[int, Expr]):
        """
        Choose a default attribute from a loop extent.

        A loop whose extent is a known constant smaller than 4 is explicitly unrolled. Any other loop,
        including one with a non-constant extent, gets the default attribute.
        """
        if isinstance(extent, Expr):
            if isinstance(extent, Constant):
                extent = int(extent)
            else:
                # Non-constant extent: the trip count is unknown here, so keep the default.
                return ForStmtAttr()
        if extent < 4:
            return ForStmtAttr(unroll=True, unroll_explicit=True)
        else:
            return ForStmtAttr()

    @staticmethod
    def parse(attr: Optional[str], num_loops: int) -> List[ForStmtAttr]:
        """
        Parse the attribute string and return a list of ForStmtAttr.

        Grammar (spaces are ignored)::

            attr-string: attr*
            attr:        unroll | parallel | default
            unroll:      'u'          # unroll
                       | 'u' DIGIT+   # unroll with factor, e.g., u1 u2 u3. u1 indicates factor 1 (i.e., no unroll)
                       | 'u' '+'      # explicit unroll, will be unrolled by hidet instead of underlying compiler
            parallel:    'p'          # parallel with available number of threads
                       | 'p' DIGIT+   # parallel with specified number of threads
            default:     '.'

        An empty (or None) string gives every loop the default attribute. A single attribute is applied
        to every loop. Otherwise there must be exactly one attribute per loop.

        Parameters
        ----------
        attr: Optional[str]
            The attribute string.
        num_loops: int
            The number of loops this attr string describes.

        Returns
        -------
        attrs: List[ForStmtAttr]
            The list of ForStmtAttr, one (distinct) object per loop.

        Raises
        ------
        ValueError
            Raised when:
            - the string is malformed;
            - an unroll factor or thread count is 0;
            - the number of attributes is neither 0, 1 nor num_loops.
        """
        if attr is None:
            attr = ''
        s = attr.replace(' ', '')

        def read_positive_int(start: int) -> Tuple[Optional[int], int]:
            # Read the decimal digits beginning at `start`.
            # Return (value, end index), or (None, start) when there are no digits.
            # isdecimal (not isdigit) guarantees int() accepts the digits (e.g. rejects superscripts).
            end = start
            while end < len(s) and s[end].isdecimal():
                end += 1
            if end == start:
                return None, start
            value = int(s[start:end])
            if value == 0:
                raise ValueError(f"Invalid attribute string: {attr}")
            return value, end

        attrs: List[ForStmtAttr] = []
        idx = 0
        while idx < len(s):
            c = s[idx]
            idx += 1
            if c == '.':
                attrs.append(ForStmtAttr())
            elif c == 'u':
                if idx < len(s) and s[idx] == '+':
                    idx += 1
                    attrs.append(ForStmtAttr(unroll=True, unroll_explicit=True))
                else:
                    # A missing factor leaves unroll_factor as None (plain 'u').
                    unroll_factor, idx = read_positive_int(idx)
                    attrs.append(ForStmtAttr(unroll=True, unroll_factor=unroll_factor))
            elif c == 'p':
                # A missing count leaves parallel_threads as None (plain 'p').
                parallel_threads, idx = read_positive_int(idx)
                attrs.append(ForStmtAttr(parallel=True, parallel_threads=parallel_threads))
            else:
                raise ValueError(f"Invalid attribute string: {attr}")

        if len(attrs) == 0:
            attrs = [ForStmtAttr() for _ in range(num_loops)]
        elif len(attrs) == 1:
            # Broadcast with independent copies (not `attrs * num_loops`).
            # Otherwise mutating one loop's attr would silently change all of them.
            single = attrs[0]
            attrs = [
                ForStmtAttr(
                    unroll=single.unroll,
                    unroll_factor=single.unroll_factor,
                    unroll_explicit=single.unroll_explicit,
                    parallel=single.parallel,
                    parallel_threads=single.parallel_threads,
                )
                for _ in range(num_loops)
            ]
        elif len(attrs) != num_loops:
            raise ValueError("Invalid attribute string: {} for {} loops".format(attr, num_loops))
        return attrs


class Stmt(Node):
    """Base class of all statements."""


class EvaluateStmt(Stmt):
    """A statement that evaluates a single expression."""

    def __init__(self, expr):
        super().__init__()
        self.expr: Expr = convert(expr)


class DeclareStmt(Stmt):
    """Declare `var`, optionally with an initial value, a static qualifier and a declaration scope."""

    def __init__(self, var, init: Optional[Expr] = None, is_static=False, scope: Optional[DeclareScope] = None):
        super().__init__()
        self.var: Var = var
        self.init: Optional[Expr] = convert(init)
        self.is_static: bool = is_static
        # A missing scope is normalized to DeclareScope.Default, so this attribute is never None.
        self.scope: DeclareScope = scope if scope else DeclareScope.Default


class BufferStoreStmt(Stmt):
    """Store `value` into buffer `buf` at position `indices` (a list or tuple of index expressions)."""

    def __init__(self, buf, indices, value, protected=False):
        super().__init__()
        assert isinstance(indices, (list, tuple)), type(indices)
        self.buf: Union[Var, TensorNode] = buf
        self.indices = convert(indices)
        self.value = convert(value)
        # NOTE(review): `protected` is only stored here; presumably it is interpreted by later passes/codegen.
        self.protected = protected


class AssignStmt(Stmt):
    """Assign `value` to variable `var`."""

    def __init__(self, var, value):
        super().__init__()
        self.var: Var = var
        self.value: Expr = convert(value)


class ReturnStmt(Stmt):
    """Return from the current function, optionally with a value."""

    def __init__(self, ret_value: Optional[Expr] = None):
        super().__init__()
        self.ret_value: Optional[Expr] = ret_value


class LetStmt(Stmt):
    """Bind one or more variables to values for the scope of `body`."""

    def __init__(self, bind_vars, bind_values, body=None):
        super().__init__()
        # Accept a single var/value as shorthand for a one-element list.
        if not isinstance(bind_vars, (list, tuple)):
            bind_vars = [bind_vars]
        if not isinstance(bind_values, (list, tuple)):
            bind_values = [bind_values]
        assert len(bind_vars) == len(bind_values)
        assert len(bind_vars) > 0
        bind_values = [convert(bind_value) for bind_value in bind_values]
        self.bind_vars: List[Var] = bind_vars
        self.bind_values: List[Expr] = bind_values
        self.body: Optional[Stmt] = body


class ForStmt(Stmt):
    """A loop over `loop_var` with the given `extent`, carrying a ForStmtAttr with scheduling hints."""

    DEFAULT_UNROLL_LIMIT = 32

    def __init__(self, loop_var, extent, body=None, *, attr: Optional[ForStmtAttr] = None):
        # NOTE(review): module path reconstructed as hidet.ir.tools; the local import presumably avoids an import cycle.
        from hidet.ir.tools import simplify  # pylint: disable=import-outside-toplevel

        super().__init__()
        self.loop_var: Var = loop_var
        self.extent: Expr = simplify(convert(extent))
        self.body: Optional[Stmt] = body
        # Derive the default from the simplified extent, not the raw argument.
        # Then an expression that simplify() folds to a Constant gets the small-constant-loop unroll default.
        self.attr: ForStmtAttr = attr if attr else ForStmtAttr.from_extent(self.extent)


class ForMappingStmt(Stmt):
    """A loop whose iteration variables `loop_vars` are driven by a TaskMapping for the given `worker`."""

    def __init__(self, loop_vars: Sequence[Var], mapping: TaskMapping, worker: Expr, body: Stmt):
        super().__init__()
        self.loop_vars: List[Var] = list(loop_vars)
        self.mapping: TaskMapping = mapping
        self.worker: Expr = worker
        self.body: Stmt = body


class WhileStmt(Stmt):
    """Execute `body` while `cond` holds."""

    def __init__(self, cond: Expr, body: Stmt):
        super().__init__()
        self.cond: Expr = cond
        self.body: Stmt = body


class BreakStmt(Stmt):
    """Break out of the innermost loop."""


class ContinueStmt(Stmt):
    """Continue with the next iteration of the innermost loop."""


class IfStmt(Stmt):
    """Conditional statement with an optional else branch."""

    def __init__(self, cond: Expr, then_body=None, else_body=None):
        super().__init__()
        self.cond: Expr = convert(cond)
        self.then_body: Optional[Stmt] = then_body
        self.else_body: Optional[Stmt] = else_body


class AssertStmt(Stmt):
    """An assertion on `cond` with an optional message."""

    def __init__(self, cond: Union[Expr, bool], msg: Optional[str]):
        super().__init__()
        self.cond: Expr = convert(cond)
        self.msg: Optional[str] = msg


class AsmStmt(Stmt):
    """
    An inline-assembly statement.

    `outputs` and `inputs` are sequences of (constraint label, expression) pairs. They are stored as
    parallel label/expression lists.
    """

    def __init__(
        self,
        template_string: str = "",
        outputs: Sequence[Tuple[str, Expr]] = (),
        inputs: Sequence[Tuple[str, Expr]] = (),
        is_volatile=False,
    ):
        super().__init__()
        # Each sequence is traversed twice below.
        # Materialize first so a one-shot iterator does not silently yield empty expression lists.
        outputs = list(outputs)
        inputs = list(inputs)
        self.template_string = template_string
        self.output_labels = [pr[0] for pr in outputs]
        self.output_exprs = [pr[1] for pr in outputs]
        self.input_labels = [pr[0] for pr in inputs]
        self.input_exprs = [pr[1] for pr in inputs]
        self.is_volatile = is_volatile


class BlackBoxStmt(Stmt):
    """
    A statement given as a template string with one expression (or string) per '{}' placeholder.

    Raises
    ------
    ValueError
        If the number of '{}' placeholders differs from the number of exprs.
    """

    def __init__(self, template_string: str, *exprs: Union[Expr, str]):
        super().__init__()
        self.template_string: str = template_string
        self.exprs: Tuple[Union[str, Expr]] = convert(exprs)
        expect_args_num = self.template_string.count('{}')
        if expect_args_num != len(exprs):
            raise ValueError("Invalid template string: {} for {} args".format(self.template_string, len(exprs)))


class SeqStmt(Stmt):
    """A sequence of statements, stored as a tuple."""

    def __init__(self, seq: List[Stmt]):
        super().__init__()
        self.seq: Tuple[Stmt] = tuple(seq)
        for stmt in seq:
            assert isinstance(stmt, Stmt), str(type(stmt))


class LaunchKernelStmt(Stmt):
    """Launch the kernel referenced by `func_var` with the given args, grid/block dims and shared memory bytes."""

    def __init__(
        self,
        func_var: Var,
        args: Sequence[Expr],
        grid_dim: Tuple[Expr, Expr, Expr],
        block_dim: Tuple[Expr, Expr, Expr],
        shared_mem: Expr,
    ):
        super().__init__()
        self.func_var: Var = func_var
        self.args: List[Expr] = list(args)
        self.grid_dim: Tuple[Expr, Expr, Expr] = grid_dim
        self.block_dim: Tuple[Expr, Expr, Expr] = block_dim
        self.shared_mem_bytes: Expr = shared_mem
        # NOTE(review): the subject of this assertion was missing from the source this edit was based on.
        # The original presumably checked an attribute of func_var (e.g. its name). Confirm upstream.
        assert func_var is not None


# Inline-asm register constraint letter for each supported scalar dtype name (read by `asm`).
_DTYPE_TO_ASM_REG = {
    'float16': 'h',
    'float32': 'f',
    'bfloat16': 'h',
    'float64': 'd',
    'uint8': 'h',
    'uint16': 'h',
    'uint32': 'r',
    'uint64': 'l',
    'int8': 'h',
    'int16': 'h',
    'int32': 'r',
    'int64': 'l',
}


def asm(
    template: str,
    *,
    outputs: Sequence[Any] = (),
    output_inputs: Sequence[Any] = (),
    inputs: Sequence[Any] = (),
    is_volatile=False,
):
    """
    Build an AsmStmt, deriving each operand's constraint string from the operand's type.

    Operands get the following constraints:
    - `outputs`: prefixed with '='.
    - `output_inputs`: prefixed with '+' and appended after `outputs`.
    - `inputs`: no prefix.
    In GCC-style inline asm, '=' marks write-only operands and '+' read-write operands.

    The constraint letter depends on the operand:
    - non-reference scalar constant: 'n';
    - other scalars: looked up by dtype name;
    - pointers and tensor pointers: 'l'.

    Raises
    ------
    NotImplementedError
        If a scalar operand's dtype has no register mapping.
    ValueError
        If an operand's type is neither a scalar data type nor a pointer type.
    """
    # NOTE(review): module path reconstructed as hidet.ir.tools; the local import presumably avoids an import cycle.
    from hidet.ir.tools import infer_type  # pylint: disable=import-outside-toplevel

    def get_register_type(expr: Expr) -> str:
        expr = convert(expr)
        expr_type = infer_type(expr)
        if isinstance(expr_type, ReferenceType):
            expr_type = expr_type.base_type
        if isinstance(expr_type, DataType):
            if isinstance(expr, Constant):
                return 'n'
            if expr_type.name not in _DTYPE_TO_ASM_REG:
                raise NotImplementedError('{}'.format(expr_type))
            return _DTYPE_TO_ASM_REG[expr_type.name]
        elif isinstance(expr_type, (PointerType, TensorPointerType)):
            return 'l'
        else:
            raise ValueError('Can not deal with type {} in asm code.'.format(expr_type))

    updated_outputs = [('=' + get_register_type(output), convert(output)) for output in outputs]
    updated_outputs.extend(('+' + get_register_type(output_input), convert(output_input)) for output_input in output_inputs)
    updated_inputs = [(get_register_type(x), convert(x)) for x in inputs]
    return AsmStmt(template, updated_outputs, updated_inputs, is_volatile)


Int = Union[Expr, int]


def launch_kernel(
    func_var: Var,
    args: Sequence[Expr],
    grid_dim: Union[Sequence[Int], Int],
    block_dim: Union[Sequence[Int], Int],
    shared_mem: Optional[Int] = 0,
) -> LaunchKernelStmt:
    """
    Build a LaunchKernelStmt.

    `grid_dim` and `block_dim` may each be a single value or a list/tuple of at most 3 values. Missing
    trailing dimensions are padded with 1.

    Raises
    ------
    ValueError
        If grid_dim or block_dim has more than 3 dimensions.
    """
    launch_config: List[Tuple[Expr, Expr, Expr]] = []
    for dims in [grid_dim, block_dim]:
        if not isinstance(dims, (list, tuple)):
            dims = [dims]
        dims = list(dims)
        if len(dims) > 3:
            raise ValueError('Grid/Block dimension must be 3 or less.')
        dims.extend([1] * (3 - len(dims)))
        launch_config.append(convert(dims))
    grid_dim, block_dim = launch_config
    return LaunchKernelStmt(func_var, args, grid_dim, block_dim, convert(shared_mem))