# Source code for hidet.ir.type

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=import-outside-toplevel
from __future__ import annotations
from typing import Sequence, Optional, Union, List, Tuple, Callable, Any

from hidet.ir.node import Node

# Typing forward declaration: `Expr` is only a string placeholder at module
# level so that annotations in this file can mention it. The real class is
# imported from hidet.ir.expr inside the functions that need it.
Expr = 'Expr'
# Alias for a value that is either a Python int or an `Expr` (forward reference).
Int = Union[int, Expr]


class BaseType(Node):
    """
    Common base class of the IR types defined in this module.

    It provides two pieces of syntactic sugar (``~t`` for "pointer to t" and ``t[n]`` for
    "array of n elements of t") and a family of ``is_*`` predicates that test which concrete
    kind of type an instance is.
    """

    def __invert__(self) -> BaseType:
        """
        Get the pointer type that points to the current type.

        Returns
        -------
        ret: BaseType
            A TensorPointerType when the current type is a TensorType; a PointerType whose base
            type is the current type when it is a DataType, PointerType or TensorPointerType.

        Raises
        ------
        ValueError
            If the current type is none of the kinds listed above.
        """
        if isinstance(self, TensorType):
            return TensorPointerType.from_tensor_type(self)
        if isinstance(self, (DataType, PointerType, TensorPointerType)):
            return PointerType(base_type=self)
        raise ValueError('Can not recognize type {}'.format(self))

    def __getitem__(self, item):
        """
        Create a 1-d array type whose element type is the current type.

        Parameters
        ----------
        item: int-convertible, or a tuple/list holding exactly one such value
            The number of elements of the array.

        Raises
        ------
        ValueError
            If a tuple/list whose length is not 1 is given.
        """
        if isinstance(item, (tuple, list)):
            if len(item) != 1:
                raise ValueError('Currently, only support 1-d array, but got {}'.format(item))
            (item,) = item
        return array_type(self, int(item))

    def is_void(self):
        """Whether the current type is a VoidType."""
        return isinstance(self, VoidType)

    def is_tensor(self):
        """Whether the current type is a TensorType."""
        return isinstance(self, TensorType)

    def is_pointer(self):
        """Whether the current type is a PointerType or a TensorPointerType."""
        return isinstance(self, (PointerType, TensorPointerType))

    def is_data_type(self):
        """Whether the current type is a DataType."""
        return isinstance(self, DataType)

    def is_func_type(self):
        """Whether the current type is a FuncType."""
        return isinstance(self, FuncType)

    def is_string_type(self):
        """Whether the current type is a StringType."""
        return isinstance(self, StringType)

    def as_data_type(self) -> Optional[DataType]:
        """Return the current type itself if it is a DataType, otherwise None."""
        return self if isinstance(self, DataType) else None


class DataType(BaseType):
    """
    The data type that defines how to interpret the data in memory.

    This base class stores the name, short name and byte size of the type and defines
    name-based equality/hashing. The classification predicates, the constant factory and
    the special values (one, zero, min_value, max_value) raise NotImplementedError here
    and must be provided by subclasses.

    Parameters
    ----------
    name: str
        The full name of the data type.
    short_name: str
        The short name of the data type.
    nbytes: int
        The number of bytes of the data type.
    """

    def __init__(self, name: str, short_name: str, nbytes: int):
        self._name: str = name
        self._short_name: str = short_name
        self._nbytes: int = nbytes

    def __str__(self):
        return 'hidet.{}'.format(self.name)

    def __eq__(self, other):
        # two data types are the same type if and only if they have the same name
        return isinstance(other, DataType) and self.name == other.name

    def __hash__(self):
        # keep the hash consistent with the name-based __eq__
        return hash(self.name)

    def __call__(self, value: Any):
        """
        Create a constant of current data type, or convert an existing Expr to current data type with
        cast expression.

        Parameters
        ----------
        value: Union[int, float, bool, complex, list, tuple, Constant, Expr]
            The value of the constant or the value to be casted. A list/tuple is accepted only when
            all of its elements are python int, float, bool or complex values.

        Returns
        -------
        ret: Constant or Cast
            The constant or cast expression.

        Raises
        ------
        ValueError
            If the value can not be converted to the current data type.
        """
        from hidet.ir import expr

        built_types = (int, float, bool, complex)
        # `and` binds tighter than `or`; the parentheses only make that grouping explicit
        if isinstance(value, built_types) or (
            isinstance(value, (list, tuple)) and all(isinstance(v, built_types) for v in value)
        ):
            return self.constant(value)
        elif isinstance(value, expr.Constant):
            # Constant is checked before Expr: its python value is re-wrapped instead of being casted
            return self.constant(value.value)
        elif isinstance(value, expr.Expr):
            return expr.cast(value, self)
        else:
            raise ValueError('Can not convert {} to {}'.format(value, self))

    def __getitem__(self, item):
        """
        Create a tensor type with the current data type as its element type.

        Parameters
        ----------
        item: a single extent, or a tuple/list of extents
            The shape of the tensor. A single (non tuple/list) value is treated as a 1-d shape.

        Returns
        -------
        ret: TensorType
            The tensor type constructed by ``tensor_type(dtype=self, shape=list(item))``.
        """
        if not isinstance(item, (tuple, list)):
            item = (item,)
        return tensor_type(dtype=self, shape=list(item))

    @property
    def name(self) -> str:
        """The full name of the data type."""
        return self._name

    @property
    def short_name(self) -> str:
        """The short name of the data type."""
        return self._short_name

    @property
    def nbytes(self) -> int:
        """The number of bytes of the data type."""
        return self._nbytes

    @property
    def nbits(self) -> int:
        """
        Get the bit length of the data type

        Note:
        1. The bit length of the data type itself other than the bit length of its storage.
        2. For regular data types, the nbits can be computed from its nbytes property.
        3. For subbyte data types, the nbits is defined when constructing the data type,
           and this method will also be overridden for subbyte data types.
        4. In addition, we cannot access the nbytes for a subbyte data type, otherwise
           a type error will be raised.
        """
        return self._nbytes * 8

    @property
    def storage(self) -> DataType:
        """
        Get the actual storage type of the data type

        Note:
        1. The storage of a regular data type is the data type itself, while the storage of a subbyte
           type is the type of its actual storage. e.g., the storage of int4b is uint8
        2. The property will be overridden in the subclass of subbyte types.
        """
        return self

    # The members below are not implemented by this base class; subclasses must override them.

    def is_integer_subbyte(self) -> bool:
        raise NotImplementedError()

    def is_float(self) -> bool:
        raise NotImplementedError()

    def is_integer(self) -> bool:
        raise NotImplementedError()

    def is_complex(self) -> bool:
        raise NotImplementedError()

    def is_vector(self) -> bool:
        raise NotImplementedError()

    def is_boolean(self) -> bool:
        raise NotImplementedError()

    def is_any_float16(self) -> bool:
        raise NotImplementedError()

    def constant(self, value: Any):
        raise NotImplementedError()

    @property
    def one(self):
        raise NotImplementedError()

    @property
    def zero(self):
        raise NotImplementedError()

    @property
    def min_value(self):
        raise NotImplementedError()

    @property
    def max_value(self):
        raise NotImplementedError()
class TensorType(BaseType):
    def __init__(self, dtype=None, shape=None, layout=None):
        """
        Describe a tensor by its element type, shape and memory layout.

        Parameters
        ----------
        dtype: DataType
            The element data type of the tensor.

        shape: Tuple[Expr, ...]
            The extents of the tensor.

        layout: hidet.ir.layout.DataLayout
            The data layout of the tensor.
        """
        from hidet.ir.layout import DataLayout

        self.dtype: DataType = dtype
        self.shape: Tuple[Expr, ...] = shape
        self.layout: DataLayout = layout

    def __invert__(self):
        """Return the tensor pointer type that points to this tensor type."""
        return TensorPointerType.from_tensor_type(self)

    def storage_bytes(self) -> Expr:
        """
        Compute the number of bytes occupied by the tensor.

        For integer subbyte element types the result is ``layout.size * nbits // 8``;
        for all other element types it is ``layout.size * nbytes``.
        """
        if self.dtype.is_integer_subbyte():
            return self.layout.size * self.dtype.nbits // 8
        return self.layout.size * self.dtype.nbytes

    def const_shape(self) -> List[int]:
        """Return the shape as a list of python ints (each extent is passed through ``int``)."""
        return [int(extent) for extent in self.shape]


class VoidType(BaseType):
    """The void type."""


class StringType(BaseType):
    """The string type."""


class PointerType(BaseType):
    """
    A pointer to ``base_type``.

    Parameters
    ----------
    base_type: BaseType or str
        The pointed-to type. A string is resolved through ``data_type``.
    specifiers: Sequence[str], optional
        Extra specifiers; stored as a fresh list (empty when not given).
    use_bracket: bool
        Stored as-is on the instance.
    """

    def __init__(self, base_type, specifiers: Optional[Sequence[str]] = None, use_bracket: bool = False):
        super().__init__()
        if isinstance(base_type, str):
            base_type = data_type(base_type)
        self.base_type: BaseType = base_type
        # todo: move the following attributes to DeclareStmt
        self.specifiers: List[str] = list(specifiers) if specifiers else []
        self.use_bracket: bool = use_bracket

    def __call__(self, x):
        """
        Create a constant of this pointer type, or cast an existing expression to it.

        An int, or the value of a Constant, becomes a constant of this type; any other Expr is
        wrapped in a cast. Anything else raises ValueError.
        """
        from hidet.ir.expr import Constant, Expr, constant, cast  # pylint: disable=redefined-outer-name

        if isinstance(x, int):
            return constant(x, self)
        # Constant is tested before Expr so that its python value is re-wrapped rather than casted
        if isinstance(x, Constant):
            return constant(x.value, self)
        if isinstance(x, Expr):
            return cast(x, self)
        raise ValueError('Can not convert {} to {}'.format(x, self))


class ReferenceType(BaseType):
    """A reference to ``base_type``."""

    def __init__(self, base_type):
        super().__init__()
        self.base_type = base_type


class TensorPointerType(BaseType):
    def __init__(self, ttype: TensorType):
        """
        A pointer type that points to tensor.
        """
        self.tensor_type: TensorType = ttype

    @staticmethod
    def from_tensor_type(tp: TensorType) -> TensorPointerType:
        """Build a TensorPointerType for ``tp`` without going through ``__init__``."""
        pointer = object.__new__(TensorPointerType)
        pointer.tensor_type = tp
        return pointer


class ArrayType(BaseType):
    """
    A fixed-size array of ``size`` elements of ``base_type``.

    ``base_type`` must be a BaseType that is neither an ArrayType nor a TensorType, and ``size``
    must be a non-negative int; both conditions are checked with assertions.
    """

    def __init__(self, base_type, size: int):
        super().__init__()
        self.base_type: BaseType = base_type
        self.size: int = size

        assert isinstance(base_type, BaseType) and not isinstance(base_type, (ArrayType, TensorType))
        assert isinstance(size, int) and size >= 0


# Anything accepted where a type is expected: a type object or the name of a data type.
TypeLike = Union[str, BaseType]


class FuncType(BaseType):
    """
    The type of a function.

    The return type is either given statically through ``ret_type`` or inferred from the argument
    types through ``type_infer_func``; at least one of the two must be provided. String entries
    in ``param_types`` and a string ``ret_type`` are resolved through ``data_type``.
    """

    def __init__(
        self,
        param_types: Optional[List[TypeLike]] = None,
        ret_type: Optional[TypeLike] = None,
        type_infer_func: Optional[Callable] = None,  # Callable[[a number of BaseType], BaseType]
    ):
        if param_types is None:
            resolved_params = None
        else:
            resolved_params = [self._convert_type(tp) for tp in param_types]
        self.param_types: Optional[List[BaseType]] = resolved_params

        if ret_type is None:
            resolved_ret = None
        else:
            resolved_ret = self._convert_type(ret_type)
        self.ret_type: Optional[BaseType] = resolved_ret

        self.type_infer_func: Optional[Callable[[List[BaseType]], BaseType]] = type_infer_func

        msg = 'Please provide either a static type or a type infer func'
        assert ret_type is not None or type_infer_func is not None, msg

    def ret_type_on(self, arg_types: List[BaseType]) -> BaseType:
        """Return the static return type if there is one, otherwise infer it from ``arg_types``."""
        if self.ret_type is None:
            return self.type_infer_func(arg_types)
        # todo: add type checking
        assert isinstance(self.ret_type, BaseType)
        return self.ret_type

    def _convert_type(self, tp: Union[str, BaseType]):
        """Resolve a string through ``data_type``; pass any other value through unchanged."""
        return data_type(tp) if isinstance(tp, str) else tp

    @staticmethod
    def from_func(func):
        """Build the FuncType of ``func`` from the types of ``func.params`` and ``func.ret_type``."""
        param_types = [param.type for param in func.params]
        return FuncType(param_types, func.ret_type)


class OpaqueType(BaseType):
    """A type identified only by its C++ name plus optional modifiers."""

    def __init__(self, cpp_name: str, *modifiers: str):
        self.cpp_name: str = cpp_name
        self.modifiers: Sequence[str] = modifiers


def tensor_type(dtype, shape: Optional[Sequence[Union[int, Expr]]] = None, layout=None):
    """
    Build a tensor type. At least one of shape and layout must be given.

    Parameters
    ----------
    dtype: str or DataType
        The element type of the tensor. A string is resolved through ``data_type``.

    shape: Sequence[Union[int, Expr]] or none
        The shape of the tensor. If not given, the shape of the layout is used.

    layout: hidet.ir.layout.DataLayout or none
        The layout of the tensor. If not given, the row major layout of the given shape is used.

    Returns
    -------
    ret: TensorType
        The constructed tensor type.

    Raises
    ------
    ValueError
        If dtype is neither a string nor a DataType, or if both shape and layout are missing.
    """
    from hidet.ir.expr import convert
    from hidet.ir.layout import DataLayout, row_major

    if isinstance(dtype, str):
        dtype = data_type(dtype)
    if not isinstance(dtype, DataType):
        raise ValueError('Scalar type expect a "str" or "ScalarType", but got {}'.format(type(dtype)))

    if shape is None:
        if layout is None:
            raise ValueError('Tensor type must give either shape or layout')
        # only the layout is known: take the shape from it
        assert isinstance(layout, DataLayout)
        shape = layout.shape
    elif layout is None:
        # only the shape is known: default to the row major layout
        layout = row_major(*shape)
    else:
        # both are given: they must agree on the number of dimensions
        assert isinstance(layout, DataLayout)
        assert isinstance(shape, (list, tuple))
        assert len(shape) == len(layout.shape)

    shape = convert(shape)
    return TensorType(dtype, shape, layout)


def array_type(base_type: BaseType, size: int):
    """Create the array type of ``size`` elements of ``base_type``."""
    return ArrayType(base_type, size)


def pointer_type(base_type):
    """Create the pointer type that points to ``base_type``."""
    return PointerType(base_type)


def tensor_pointer_type(dtype, shape=None, layout=None):
    """Create the pointer type to the tensor type built by ``tensor_type(dtype, shape, layout)``."""
    pointee = tensor_type(dtype, shape, layout)
    return TensorPointerType(pointee)


def string_type():
    """Create a string type."""
    return StringType()


def func_type(param_types, ret_type) -> FuncType:
    """Create a function type with the given parameter types and static return type."""
    return FuncType(param_types, ret_type)
def data_type(dtype: Union[str, DataType]) -> DataType:
    """
    Get the data type from its name, or pass an existing data type through.

    Parameters
    ----------
    dtype: str or DataType
        A DataType, which is returned unchanged, or a string that is looked up first among the
        full names (``name2dtype``) and then among the short names (``sname2dtype``) registered
        in ``hidet.ir.dtypes``.

    Returns
    -------
    ret: DataType
        The resolved data type.

    Raises
    ------
    ValueError
        If the string matches no known data type (the message lists the known full names), or if
        ``dtype`` is neither a string nor a DataType.
    """
    from hidet.ir.dtypes import name2dtype, sname2dtype

    if isinstance(dtype, DataType):
        return dtype
    if not isinstance(dtype, str):
        raise ValueError('Expect a string or a DataType, but got {}'.format(type(dtype)))

    # full names take precedence over short names
    if dtype in name2dtype:
        return name2dtype[dtype]
    if dtype in sname2dtype:
        return sname2dtype[dtype]
    raise ValueError('Unknown data type: {}, candidates:\n{}'.format(dtype, '\n'.join(name2dtype.keys())))
# Module-level convenience instances: pointer to void, pointer to uint8, and the void type.
void_p = PointerType(base_type=VoidType())
byte_p = PointerType(base_type=data_type('uint8'))
void = VoidType()