# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel
from __future__ import annotations
from typing import Sequence, Optional, Union, List, Tuple, Callable, Any
from hidet.ir.node import Node
# typing forward declaration
# ``Expr`` is referenced by name only (the real class lives in hidet.ir.expr, which this
# module imports lazily inside functions -- presumably to avoid a circular import).
Expr = 'Expr'
# An extent/index that may be a plain Python int or an IR expression.
Int = Union[int, Expr]
class BaseType(Node):
    """
    Base class of all types in the IR.

    Supports ``~tp`` (pointer to ``tp``), ``tp[n]`` (1-d array of ``tp``) and a family of
    ``is_*`` predicates used for type dispatch.
    """

    def __invert__(self) -> BaseType:
        """
        Get the pointer type that points to the current type (``~tp``).

        Returns
        -------
        ret: BaseType
            A TensorPointerType when this is a TensorType; a PointerType when this is a
            DataType, PointerType or TensorPointerType.

        Raises
        ------
        ValueError
            For any other type (e.g. VoidType, ArrayType, FuncType).
        """
        if isinstance(self, TensorType):
            return TensorPointerType.from_tensor_type(self)
        if isinstance(self, (DataType, PointerType, TensorPointerType)):
            return PointerType(base_type=self)
        raise ValueError('Can not recognize type {}'.format(self))

    def __getitem__(self, item):
        """
        Build a 1-d array type with this type as element type (``tp[n]``).

        Parameters
        ----------
        item: int or a tuple/list holding exactly one int
            The number of elements of the array.

        Returns
        -------
        ret: ArrayType
            The array type.

        Raises
        ------
        ValueError
            If a tuple/list with more (or fewer) than one extent is given.
        """
        if isinstance(item, (tuple, list)):
            # Only 1-d arrays are supported: unwrap ``tp[(n,)]`` / ``tp[[n]]`` and reject the rest.
            if len(item) != 1:
                raise ValueError('Currently, only support 1-d array, but got {}'.format(item))
            item = item[0]
        return array_type(self, int(item))

    def is_void(self):
        return isinstance(self, VoidType)

    def is_tensor(self):
        return isinstance(self, TensorType)

    def is_pointer(self):
        # Both plain pointers and tensor pointers count as pointers.
        return isinstance(self, (PointerType, TensorPointerType))

    def is_data_type(self):
        return isinstance(self, DataType)

    def is_func_type(self):
        return isinstance(self, FuncType)

    def is_string_type(self):
        return isinstance(self, StringType)

    def as_data_type(self) -> Optional[DataType]:
        """Return ``self`` if it is a DataType, otherwise None."""
        return self if isinstance(self, DataType) else None
class DataType(BaseType):
    """
    The data type that defines how to interpret the data in memory.
    """

    def __init__(self, name: str, short_name: str, nbytes: int):
        self._name: str = name
        self._short_name: str = short_name
        self._nbytes: int = nbytes

    def __str__(self):
        return 'hidet.{}'.format(self.name)

    def __eq__(self, other):
        # Data types are identified by their full name.
        return isinstance(other, DataType) and self.name == other.name

    def __hash__(self):
        # Consistent with __eq__: equal names give equal hashes.
        return hash(self.name)

    def __call__(self, value: Any):
        """
        Create a constant of current data type, or convert an existing Expr to current data type with cast expression.

        Parameters
        ----------
        value: Union[int, float, bool, list, tuple, Constant, Expr]
            The value of the constant or the value to be casted.

        Returns
        -------
        ret: Constant or Cast
            The constant or cast expression.

        Raises
        ------
        ValueError
            If ``value`` is none of the supported kinds.
        """
        from hidet.ir import expr

        built_types = (int, float, bool, complex)
        # A Python scalar, or a list/tuple made only of Python scalars, becomes a constant.
        if isinstance(value, built_types) or (
            isinstance(value, (list, tuple)) and all(isinstance(v, built_types) for v in value)
        ):
            return self.constant(value)
        elif isinstance(value, expr.Constant):
            return self.constant(value.value)
        elif isinstance(value, expr.Expr):
            return expr.cast(value, self)
        raise ValueError('Can not convert {} to {}'.format(value, self))

    def __getitem__(self, item):
        """Build a tensor type of this data type: ``dtype[d0, d1, ...]`` or ``dtype[d0]``."""
        if not isinstance(item, (tuple, list)):
            item = (item,)
        return tensor_type(dtype=self, shape=list(item))

    @property
    def name(self) -> str:
        """The full name of the data type (used for equality and hashing)."""
        return self._name

    @property
    def short_name(self) -> str:
        """The short name of the data type."""
        return self._short_name

    @property
    def nbytes(self) -> int:
        """The number of bytes of one value of this data type."""
        return self._nbytes

    # The methods below must be overridden by concrete data types.

    def is_float(self) -> bool:
        raise NotImplementedError()

    def is_integer(self) -> bool:
        raise NotImplementedError()

    def is_complex(self) -> bool:
        raise NotImplementedError()

    def is_vector(self) -> bool:
        raise NotImplementedError()

    def constant(self, value: Any):
        raise NotImplementedError()

    def one(self):
        raise NotImplementedError()

    def zero(self):
        raise NotImplementedError()

    def min_value(self):
        raise NotImplementedError()

    def max_value(self):
        raise NotImplementedError()
class TensorType(BaseType):
    def __init__(self, dtype=None, shape=None, layout=None):
        """
        A tensor type.

        Parameters
        ----------
        dtype: DataType
            The data type of the tensor.
        shape: Tuple[Expr, ...]
            The shape of the tensor.
        layout: hidet.ir.layout.DataLayout
            The layout of the tensor.
        """
        from hidet.ir.layout import DataLayout

        self.dtype: DataType = dtype
        self.shape: Tuple[Expr, ...] = shape
        self.layout: DataLayout = layout

    def __invert__(self):
        # ``~tensor_type`` gives a pointer to the tensor.
        return TensorPointerType.from_tensor_type(self)

    def storage_bytes(self) -> Expr:
        """Bytes needed to store the tensor: the layout size scaled by the element byte width."""
        element_bytes = self.dtype.nbytes
        return self.layout.size * element_bytes

    def const_shape(self) -> List[int]:
        """The shape as Python ints; every extent must be convertible with ``int()``."""
        return list(map(int, self.shape))
class VoidType(BaseType):
    """The ``void`` type."""
class StringType(BaseType):
    """The string type."""
class PointerType(BaseType):
    """
    A pointer to ``base_type``.

    ``base_type`` may be given as a data type name (e.g. 'float32'), which is resolved with
    ``data_type``.
    """

    def __init__(self, base_type, specifiers: Optional[Sequence[str]] = None, use_bracket: bool = False):
        self.base_type: BaseType = data_type(base_type) if isinstance(base_type, str) else base_type
        # todo: move the following attributes to DeclareStmt
        self.specifiers: List[str] = [] if not specifiers else list(specifiers)
        self.use_bracket: bool = use_bracket

    def __call__(self, x):
        """
        Make a constant of this pointer type from an int or Constant, or cast an Expr to it.

        Raises
        ------
        ValueError
            If ``x`` is none of the supported kinds.
        """
        from hidet.ir.expr import Constant, Expr, constant, cast  # pylint: disable=redefined-outer-name

        if isinstance(x, int):
            return constant(x, self)
        if isinstance(x, Constant):
            return constant(x.value, self)
        # Checked after Constant so constants keep their literal value instead of becoming a cast.
        if isinstance(x, Expr):
            return cast(x, self)
        raise ValueError('Can not convert {} to {}'.format(x, self))
class ReferenceType(BaseType):
    """
    A reference to ``base_type``.

    Like PointerType, a data type name (e.g. 'int32') is accepted for ``base_type`` and
    resolved with ``data_type``.
    """

    def __init__(self, base_type):
        if isinstance(base_type, str):
            base_type = data_type(base_type)
        self.base_type: BaseType = base_type
class TensorPointerType(BaseType):
    def __init__(self, ttype: TensorType):
        """
        A pointer type that points to tensor.

        Parameters
        ----------
        ttype: TensorType
            The tensor type being pointed to.
        """
        self.tensor_type: TensorType = ttype

    @staticmethod
    def from_tensor_type(tp: TensorType) -> TensorPointerType:
        """
        Create a TensorPointerType pointing to the given tensor type.

        The instance is allocated with ``object.__new__``, bypassing ``__init__`` (and any
        ``__new__`` override); with the current ``__init__`` the result matches
        ``TensorPointerType(tp)``.
        """
        tpt = object.__new__(TensorPointerType)
        tpt.tensor_type = tp
        return tpt
class ArrayType(BaseType):
    """
    A fixed-size 1-d array of ``base_type`` elements.

    Arrays of arrays and arrays of tensors are rejected; ``size`` must be a non-negative int.
    """

    def __init__(self, base_type, size: int):
        assert isinstance(base_type, BaseType)
        assert not isinstance(base_type, (ArrayType, TensorType))
        assert isinstance(size, int)
        assert size >= 0
        self.base_type: BaseType = base_type
        self.size: int = size
# A type, or the name of a data type (resolved with ``data_type``).
TypeLike = Union[str, BaseType]


class FuncType(BaseType):
    """
    A function type.

    It holds the parameter types and either a static return type or a function that infers
    the return type from the argument types.
    """

    def __init__(
        self,
        param_types: Optional[List[TypeLike]] = None,
        ret_type: Optional[TypeLike] = None,
        type_infer_func: Optional[Callable] = None,  # Callable[[a number of BaseType], BaseType]
    ):
        """
        Parameters
        ----------
        param_types: Optional[List[TypeLike]]
            The parameter types; names are resolved to data types. None keeps it unspecified.
        ret_type: Optional[TypeLike]
            The static return type, if known.
        type_infer_func: Optional[Callable]
            Called with the list of argument types to infer the return type when no static
            ``ret_type`` is given.
        """
        self.param_types: Optional[List[BaseType]] = (
            [self._convert_type(tp) for tp in param_types] if param_types is not None else None
        )
        self.ret_type: Optional[BaseType] = self._convert_type(ret_type) if ret_type is not None else None
        self.type_infer_func: Optional[Callable[[List[BaseType]], BaseType]] = type_infer_func
        msg = 'Please provide either a static type or a type infer func'
        assert not all(v is None for v in [ret_type, type_infer_func]), msg

    def ret_type_on(self, arg_types: List[BaseType]) -> BaseType:
        """Return the static return type if present, otherwise infer it from ``arg_types``."""
        if self.ret_type is not None:
            # todo: add type checking
            assert isinstance(self.ret_type, BaseType)
            return self.ret_type
        return self.type_infer_func(arg_types)

    def _convert_type(self, tp: Union[str, BaseType]):
        # Data type names are resolved; other types pass through unchanged.
        if isinstance(tp, str):
            return data_type(tp)
        return tp

    @staticmethod
    def from_func(func) -> FuncType:
        """
        Build the function type of ``func`` from its ``params`` (each with a ``.type``) and its
        ``ret_type``. ``func`` is presumably a hidet IR Function -- confirm against callers.
        """
        return FuncType([param.type for param in func.params], func.ret_type)
def tensor_type(dtype, shape: Optional[Sequence[Union[int, Expr]]] = None, layout=None):
    """
    Construct a tensor type.

    One of shape and layout must be given.

    Parameters
    ----------
    dtype: str or DataType
        The scalar type of this tensor.
    shape: Sequence[Union[int, Expr]] or None
        The shape of the tensor. If not given, the shape in layout will be used.
    layout: hidet.ir.layout.DataLayout or None
        The layout of the tensor. If not given, the row major layout of given shape will
        be used.

    Returns
    -------
    ret: TensorType
        The constructed tensor type.

    Raises
    ------
    ValueError
        If dtype is neither a str nor a DataType, or if neither shape nor layout is given.
    """
    from hidet.ir.expr import convert
    from hidet.ir.layout import DataLayout, row_major

    if isinstance(dtype, str):
        dtype = data_type(dtype)
    if not isinstance(dtype, DataType):
        raise ValueError('Scalar type expect a "str" or "DataType", but got {}'.format(type(dtype)))
    if shape is None and layout is None:
        raise ValueError('Tensor type must give either shape or layout')
    elif shape is None:
        assert isinstance(layout, DataLayout)
        shape = layout.shape
    elif layout is None:
        layout = row_major(*shape)
    else:
        # Both were given explicitly: they must agree on the rank.
        assert isinstance(layout, DataLayout)
        assert isinstance(shape, (list, tuple))
        assert len(shape) == len(layout.shape), 'shape {} does not match the rank of layout shape {}'.format(
            shape, layout.shape
        )
    # Normalize the extents with hidet.ir.expr.convert (presumably ints -> IR constants).
    shape = convert(shape)
    return TensorType(dtype, shape, layout)
def array_type(base_type: BaseType, size: int):
    """Construct a 1-d array type of ``size`` elements of ``base_type``."""
    return ArrayType(base_type=base_type, size=size)
def pointer_type(base_type):
    """Construct a pointer type to ``base_type`` (a type or a data type name)."""
    return PointerType(base_type=base_type)
def tensor_pointer_type(dtype, shape=None, layout=None):
    """
    Construct a pointer type to a tensor type.

    The arguments are forwarded unchanged to ``tensor_type``.
    """
    pointee = tensor_type(dtype, shape, layout)
    return TensorPointerType(pointee)
def string_type():
    """Return a new StringType instance."""
    result = StringType()
    return result
def func_type(param_types, ret_type) -> FuncType:
    """Construct a function type with the given parameter types and static return type."""
    return FuncType(param_types=param_types, ret_type=ret_type)
def data_type(dtype: Union[str, DataType]) -> DataType:
    """
    Resolve ``dtype`` to a DataType.

    Parameters
    ----------
    dtype: str or DataType
        A DataType (returned as-is), or a full name or short name of a registered data type.

    Returns
    -------
    ret: DataType
        The resolved data type. Full names are looked up before short names.

    Raises
    ------
    ValueError
        If the name is unknown, or ``dtype`` is neither a str nor a DataType.
    """
    from hidet.ir.dtypes import name2dtype, sname2dtype

    if isinstance(dtype, DataType):
        return dtype
    if not isinstance(dtype, str):
        raise ValueError('Expect a string or a DataType, but got {}'.format(type(dtype)))
    for table in (name2dtype, sname2dtype):
        if dtype in table:
            return table[dtype]
    raise ValueError('Unknown data type: {}, candidates:\n{}'.format(dtype, '\n'.join(name2dtype.keys())))
# Commonly used type instances.
void_p = PointerType(base_type=VoidType())  # void*
byte_p = PointerType(base_type=data_type('uint8'))  # uint8*
void = VoidType()