# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel
from __future__ import annotations
from typing import Sequence, Optional, Union, List, Tuple, Callable, Any
from hidet.ir.node import Node
# typing forward declaration
# `Expr` is bound to the string 'Expr' so it can be used as a forward reference inside
# typing constructs in this module (presumably to avoid importing hidet.ir.expr at module level).
Expr = 'Expr'
# A value that is either a Python int or an expression.
Int = Union[int, Expr]
class BaseType(Node):
    """Common base class of the type nodes declared in this module."""

    def __invert__(self) -> BaseType:
        """
        Get the pointer type that points to the current type (``~tp``).

        Returns
        -------
        ret: BaseType
            A TensorPointerType when the current type is a tensor type, otherwise a PointerType
            whose base type is the current type.

        Raises
        ------
        ValueError
            If the current type is not a tensor, data, pointer or tensor pointer type.
        """
        if isinstance(self, TensorType):
            return TensorPointerType.from_tensor_type(self)
        # data types and (tensor) pointer types are all wrapped in a plain pointer
        if isinstance(self, (DataType, PointerType, TensorPointerType)):
            return PointerType(base_type=self)
        raise ValueError('Can not recognize type {}'.format(self))

    def __getitem__(self, item):
        """
        Build a 1-d array type of the current type (``tp[size]``).

        Parameters
        ----------
        item: int or a tuple/list holding a single element
            The array size; it is converted with ``int``.

        Raises
        ------
        ValueError
            If a tuple/list that does not hold exactly one element is given.
        """
        if isinstance(item, (tuple, list)):
            if len(item) != 1:
                raise ValueError('Currently, only support 1-d array, but got {}'.format(item))
            item = item[0]
        return array_type(self, int(item))

    def is_void(self):
        """Whether the current type is a VoidType."""
        return isinstance(self, VoidType)

    def is_tensor(self):
        """Whether the current type is a TensorType."""
        return isinstance(self, TensorType)

    def is_pointer(self):
        """Whether the current type is a PointerType or a TensorPointerType."""
        return isinstance(self, (PointerType, TensorPointerType))

    def is_data_type(self):
        """Whether the current type is a DataType."""
        return isinstance(self, DataType)

    def is_func_type(self):
        """Whether the current type is a FuncType."""
        return isinstance(self, FuncType)

    def is_string_type(self):
        """Whether the current type is a StringType."""
        return isinstance(self, StringType)

    def as_data_type(self) -> Optional[DataType]:
        """Return the current type itself if it is a DataType, otherwise None."""
        return self if isinstance(self, DataType) else None
class DataType(BaseType):
    """
    The data type that defines how to interpret the data in memory.

    Parameters
    ----------
    name: str
        The name of the data type. Equality and hashing of data types are based on it.
    short_name: str
        The short name of the data type.
    nbytes: int
        The number of bytes of the data type.
    """

    def __init__(self, name: str, short_name: str, nbytes: int):
        self._name: str = name
        self._short_name: str = short_name
        self._nbytes: int = nbytes

    def __str__(self):
        return 'hidet.{}'.format(self.name)

    def __eq__(self, other):
        # two data types are the same type if and only if they share the same name
        return isinstance(other, DataType) and self.name == other.name

    def __hash__(self):
        # kept consistent with __eq__: hash on the name only
        return hash(self.name)

    def __call__(self, value: Any):
        """
        Create a constant of current data type, or convert an existing Expr to current data type with cast expression.

        Parameters
        ----------
        value: Union[int, float, bool, list, tuple, Constant, Expr]
            The value of the constant or the value to be casted.

        Returns
        -------
        ret: Constant or Cast
            The constant or cast expression.

        Raises
        ------
        ValueError
            If the value is neither a python scalar, a flat list/tuple of python scalars,
            a Constant, nor an Expr.
        """
        from hidet.ir import expr

        built_types = (int, float, bool, complex)

        # `and` binds tighter than `or`; the parentheses only make the grouping explicit
        if isinstance(value, built_types) or (
            isinstance(value, (list, tuple)) and all(isinstance(v, built_types) for v in value)
        ):
            return self.constant(value)
        elif isinstance(value, expr.Constant):
            # re-create the constant with the current data type from the wrapped python value
            return self.constant(value.value)
        elif isinstance(value, expr.Expr):
            return expr.cast(value, self)
        else:
            raise ValueError('Can not convert {} to {}'.format(value, self))

    def __getitem__(self, item):
        """
        Build a tensor type with the current data type as element type (``dtype[d0, d1, ...]``).

        Parameters
        ----------
        item: a single extent, or a tuple/list of extents
            The shape of the tensor type. A single extent is treated as a 1-d shape.
        """
        if not isinstance(item, (tuple, list)):
            item = (item,)
        return tensor_type(dtype=self, shape=list(item))

    @property
    def name(self) -> str:
        """The name of the data type."""
        return self._name

    @property
    def short_name(self) -> str:
        """The short name of the data type."""
        return self._short_name

    @property
    def nbytes(self) -> int:
        """The number of bytes of the data type."""
        return self._nbytes

    @property
    def nbits(self) -> int:
        """
        Get the bit length of the data type

        Note:
          1. The bit length of the data type itself other than the bit length of its storage.
          2. For regular data types, the nbits can be computed from its nbytes property.
          3. For subbyte data types, the nbits is defined when constructing the data type,
             and this method will also be overridden for subbyte data types.
          4. In addition, we cannot access the nbytes for a subbyte data type, otherwise
             a type error will be raised.
        """
        return self._nbytes * 8

    @property
    def storage(self) -> DataType:
        """
        Get the actual storage type of the data type

        Note:
          1. The storage of a regular data type is the data type itself, while the storage
             of a subbyte type is the type of its actual storage. e.g., the storage of int4b is uint8
          2. The property will be overridden in the subclass of subbyte types.
        """
        return self

    # The members below form the interface that concrete data types are expected to provide:
    # every one of them raises NotImplementedError in this base class.

    def is_integer_subbyte(self) -> bool:
        raise NotImplementedError()

    def is_float(self) -> bool:
        raise NotImplementedError()

    def is_integer(self) -> bool:
        raise NotImplementedError()

    def is_complex(self) -> bool:
        raise NotImplementedError()

    def is_vector(self) -> bool:
        raise NotImplementedError()

    def is_boolean(self) -> bool:
        raise NotImplementedError()

    def is_any_float16(self) -> bool:
        raise NotImplementedError()

    def constant(self, value: Any):
        # used by __call__ to build a constant of this data type from a python value
        raise NotImplementedError()

    @property
    def one(self):
        raise NotImplementedError()

    @property
    def zero(self):
        raise NotImplementedError()

    @property
    def min_value(self):
        raise NotImplementedError()

    @property
    def max_value(self):
        raise NotImplementedError()
class TensorType(BaseType):
    """The type of a tensor: an element data type together with a shape and a data layout."""

    def __init__(self, dtype=None, shape=None, layout=None):
        """
        A tensor type.

        Parameters
        ----------
        dtype: DataType
            The data type of the tensor.

        shape: Tuple[Expr, ...]
            The shape of the tensor.

        layout: hidet.ir.layout.DataLayout
            The layout of the tensor.
        """
        from hidet.ir.layout import DataLayout

        # the arguments are stored as given; no validation or conversion happens here
        self.dtype: DataType = dtype
        self.shape: Tuple[Expr, ...] = shape
        self.layout: DataLayout = layout

    def __invert__(self):
        """Get the tensor pointer type that points to this tensor type (``~tp``)."""
        return TensorPointerType.from_tensor_type(self)

    def storage_bytes(self) -> Expr:
        """Get the number of bytes needed to store the tensor, derived from the layout size."""
        dtype = self.dtype
        if not dtype.is_integer_subbyte():
            return self.layout.size * dtype.nbytes
        # sub-byte integer types are sized by their bit length instead of nbytes
        return self.layout.size * dtype.nbits // 8

    def const_shape(self) -> List[int]:
        """Get the shape as a list of python integers, converting each extent with ``int``."""
        return list(map(int, self.shape))
class VoidType(BaseType):
    """The void type. It carries no attributes of its own."""
class StringType(BaseType):
    """The string type. It carries no attributes of its own."""
class PointerType(BaseType):
    """A pointer type that points to values of a base type."""

    def __init__(self, base_type, specifiers: Optional[Sequence[str]] = None, use_bracket: bool = False):
        """
        Parameters
        ----------
        base_type: str or BaseType
            The type being pointed to. A string is resolved to a data type via ``data_type``.
        specifiers: Sequence[str], optional
            The specifiers attached to the pointer; stored as a new list (empty when not given).
        use_bracket: bool
            Stored as given.
        """
        super().__init__()
        resolved = data_type(base_type) if isinstance(base_type, str) else base_type
        self.base_type: BaseType = resolved
        # todo: move the following attributes to DeclareStmt
        self.specifiers: List[str] = [] if not specifiers else list(specifiers)
        self.use_bracket: bool = use_bracket

    def __call__(self, x):
        """
        Create a constant of this pointer type, or cast an existing expression to it.

        Parameters
        ----------
        x: int, Constant or Expr
            A python integer or a Constant (its ``value`` is used) yields a constant of this type;
            any other Expr is wrapped in a cast to this type.

        Raises
        ------
        ValueError
            If ``x`` is none of the supported kinds.
        """
        from hidet.ir.expr import Constant, Expr, constant, cast  # pylint: disable=redefined-outer-name

        if isinstance(x, int):
            return constant(x, self)
        if isinstance(x, Constant):
            return constant(x.value, self)
        if isinstance(x, Expr):
            return cast(x, self)
        raise ValueError('Can not convert {} to {}'.format(x, self))
class ReferenceType(BaseType):
    """A reference type that wraps a base type."""

    def __init__(self, base_type):
        """
        Parameters
        ----------
        base_type
            The type being referenced; stored as given.
        """
        super().__init__()
        self.base_type = base_type
class TensorPointerType(BaseType):
    """A pointer type whose pointee is described by a tensor type."""

    def __init__(self, ttype: TensorType):
        """
        A pointer type that points to tensor.

        Parameters
        ----------
        ttype: TensorType
            The tensor type being pointed to.
        """
        self.tensor_type: TensorType = ttype

    @staticmethod
    def from_tensor_type(tp: TensorType) -> TensorPointerType:
        """
        Create the tensor pointer type that points to the given tensor type.

        The instance is allocated with ``object.__new__`` (``__init__`` is not invoked) and
        its ``tensor_type`` attribute is assigned directly.
        """
        pointer = object.__new__(TensorPointerType)
        pointer.tensor_type = tp
        return pointer
class ArrayType(BaseType):
    """A fixed-size 1-d array type of a base type."""

    def __init__(self, base_type, size: int):
        """
        Parameters
        ----------
        base_type: BaseType
            The element type. It must be a BaseType, but neither an ArrayType nor a TensorType.
        size: int
            The number of elements; a non-negative python integer.
        """
        super().__init__()
        self.base_type: BaseType = base_type
        self.size: int = size

        # the arguments are checked after being stored, with assertions
        is_nested = isinstance(base_type, (ArrayType, TensorType))
        assert isinstance(base_type, BaseType) and not is_nested
        assert isinstance(size, int) and size >= 0
# Anything accepted where a type is expected: a type object, or a string that names a data type
# (FuncType resolves such strings through `data_type`).
TypeLike = Union[str, BaseType]
class FuncType(BaseType):
    """The type of a function: parameter types plus a static return type or a return type infer function."""

    def __init__(
        self,
        param_types: Optional[List[TypeLike]] = None,
        ret_type: Optional[TypeLike] = None,
        type_infer_func: Optional[Callable] = None,  # Callable[[a number of BaseType], BaseType]
    ):
        """
        Parameters
        ----------
        param_types: List[TypeLike], optional
            The parameter types. Strings are resolved to data types.
        ret_type: TypeLike, optional
            The static return type. A string is resolved to a data type.
        type_infer_func: Callable, optional
            A function that computes the return type from the argument types. It is used by
            ``ret_type_on`` when no static return type is given.

        At least one of ``ret_type`` and ``type_infer_func`` must be given.
        """
        if param_types is None:
            converted_params = None
        else:
            converted_params = [self._convert_type(tp) for tp in param_types]
        self.param_types: Optional[List[BaseType]] = converted_params
        self.ret_type: Optional[BaseType] = None if ret_type is None else self._convert_type(ret_type)
        self.type_infer_func: Optional[Callable[[List[BaseType]], BaseType]] = type_infer_func

        msg = 'Please provide either a static type or a type infer func'
        assert ret_type is not None or type_infer_func is not None, msg

    def ret_type_on(self, arg_types: List[BaseType]) -> BaseType:
        """
        Get the return type of the function when called with arguments of the given types.

        The static return type wins when present; otherwise the type infer function is called
        with ``arg_types``.
        """
        if self.ret_type is None:
            return self.type_infer_func(arg_types)
        # todo: add type checking
        assert isinstance(self.ret_type, BaseType)
        return self.ret_type

    def _convert_type(self, tp: Union[str, BaseType]):
        """Resolve a string to a data type; return any other value unchanged."""
        return data_type(tp) if isinstance(tp, str) else tp

    @staticmethod
    def from_func(func):
        """Build the function type from the ``params`` (each with a ``type``) and ``ret_type`` of ``func``."""
        param_types = [param.type for param in func.params]
        return FuncType(param_types, func.ret_type)
class OpaqueType(BaseType):
    """A type identified only by a C++ type name and a sequence of modifiers."""

    def __init__(self, cpp_name: str, *modifiers: str):
        """
        Parameters
        ----------
        cpp_name: str
            The name of the type; stored as given.
        modifiers: str
            The modifiers of the type, given as positional arguments and stored as a tuple.
        """
        self.cpp_name: str = cpp_name
        self.modifiers: Sequence[str] = modifiers
def tensor_type(dtype, shape: Optional[Sequence[Union[int, Expr]]] = None, layout=None):
    """
    Construct a tensor type.

    One of shape and layout must be given.

    Parameters
    ----------
    dtype: str or DataType
        The scalar type of this tensor.

    shape: Sequence[Union[int, Expr]] or none
        The shape of the tensor. If not given, the shape in layout will be used.

    layout: hidet.ir.layout.DataLayout or none
        The layout of the tensor. If not given, the row major layout of given shape will
        be used.

    Returns
    -------
    ret: TensorType
        The constructed tensor type

    Raises
    ------
    ValueError
        If dtype is neither a string nor a DataType, or if both shape and layout are missing.
    """
    from hidet.ir.expr import convert
    from hidet.ir.layout import DataLayout, row_major

    if isinstance(dtype, str):
        dtype = data_type(dtype)
    if not isinstance(dtype, DataType):
        raise ValueError('Scalar type expect a "str" or "ScalarType", but got {}'.format(type(dtype)))

    if shape is None:
        if layout is None:
            raise ValueError('Tensor type must give either shape or layout')
        # only the layout is known: take the shape from it
        assert isinstance(layout, DataLayout)
        shape = layout.shape
    elif layout is None:
        # only the shape is known: default to the row major layout
        layout = row_major(*shape)
    else:
        # both are given: they must agree on the number of dimensions
        assert isinstance(layout, DataLayout)
        assert isinstance(shape, (list, tuple))
        assert len(shape) == len(layout.shape)

    return TensorType(dtype, convert(shape), layout)
def array_type(base_type: BaseType, size: int):
    """Construct the ArrayType with the given element type and size."""
    return ArrayType(base_type=base_type, size=size)
def pointer_type(base_type):
    """Construct the PointerType that points to the given base type."""
    return PointerType(base_type=base_type)
def tensor_pointer_type(dtype, shape=None, layout=None):
    """
    Construct the TensorPointerType that points to a tensor type.

    The arguments are forwarded to ``tensor_type`` to build the pointed-to tensor type.
    """
    pointee = tensor_type(dtype, shape, layout)
    return TensorPointerType(pointee)
def string_type():
    """Construct a new StringType instance."""
    return StringType()
def func_type(param_types, ret_type) -> FuncType:
    """Construct the FuncType with the given parameter types and static return type."""
    return FuncType(param_types=param_types, ret_type=ret_type)
def data_type(dtype: Union[str, DataType]) -> DataType:
    """
    Get the data type from its name, or pass through an existing data type.

    Parameters
    ----------
    dtype: str or DataType
        A DataType is returned unchanged. A string is looked up among the full names of the
        registered data types first, and among their short names second.

    Returns
    -------
    ret: DataType
        The resolved data type.

    Raises
    ------
    ValueError
        If the string matches no registered full or short name, or if ``dtype`` is neither a
        string nor a DataType.
    """
    from hidet.ir.dtypes import name2dtype, sname2dtype

    if isinstance(dtype, DataType):
        return dtype
    elif isinstance(dtype, str):
        if dtype in name2dtype:
            return name2dtype[dtype]
        elif dtype in sname2dtype:
            return sname2dtype[dtype]
        else:
            # list the known full names to help the caller spot a typo
            raise ValueError('Unknown data type: {}, candidates:\n{}'.format(dtype, '\n'.join(name2dtype.keys())))
    else:
        raise ValueError('Expect a string or a DataType, but got {}'.format(type(dtype)))
# Shared type instances, created once when this module is executed.
# pointer to void
void_p = PointerType(VoidType())
# pointer to the 'uint8' data type (the name is resolved through `data_type` at this point)
byte_p = PointerType(data_type('uint8'))
# the void type
void = VoidType()