hidet.ir.compute
Tip
Please refer to the hidet documentation on defining computations for how to use these compute primitives to define a computation task.
Classes:
| ReduceType | An enumeration. |
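Functions:
| scalar_input(name, dtype) | Define an input scalar node. |
| tensor_input(name, dtype, shape, layout=None) | Define an input tensor node. |
| compute(name, shape, fcompute, layout=None) | Define a grid compute node. |
| reduce(shape, fcompute, reduce_type, accumulate_dtype='float32', name=None) | Define a reduction node. |
| arg_reduce(extent, fcompute, reduce_type, index_dtype='int64', name=None) | Define an arg reduction node. |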
- hidet.ir.compute.scalar_input(name, dtype)
Define an input scalar node.
- Parameters:
name (str) – The name of the input scalar.
dtype (str or DataType) – The data type of the input scalar.
- Returns:
ret – The input scalar node.
- Return type:
ScalarInput
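A scalar input node declares only a name and a data type; its value is supplied when the computation runs. The pure-Python sketch below models that idea with a hypothetical `ScalarInputModel` class and a plain dict standing in for the runtime bindings. It is not part of hidet, which builds symbolic IR nodes instead, and the dtype-to-Python-type mapping is an assumption made for illustration.

```python
from dataclasses import dataclass

# Assumed mapping from dtype strings to Python types (illustration only, not hidet's).
_PY_TYPES = {"int32": int, "int64": int, "float32": float, "float64": float}


@dataclass
class ScalarInputModel:
    """Hypothetical stand-in for a scalar input node: a name and a dtype, no value."""

    name: str
    dtype: str

    def bind(self, env):
        """Fetch this input's value from `env`, a dict keyed by input name."""
        value = env[self.name]
        if not isinstance(value, _PY_TYPES[self.dtype]):
            raise TypeError(f"{self.name}: expected {self.dtype}, got {type(value).__name__}")
        return value


alpha = ScalarInputModel("alpha", "float32")
print(alpha.bind({"alpha": 0.5}) * 4)  # prints 2.0
```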
- hidet.ir.compute.tensor_input(name, dtype, shape, layout=None)
Define an input tensor node.
- Parameters:
name (str) – The name of the input tensor.
dtype (str or DataType) – The scalar data type of the tensor elements.
shape (Sequence[Expr or int]) – The shape of the tensor.
layout (DataLayout, optional) – The layout of the tensor.
- Returns:
ret – The input tensor node.
- Return type:
TensorInput
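Similarly, a tensor input node fixes a name, an element type, and a shape up front, and any data bound to it later must match that shape. The sketch below, built on a hypothetical `TensorInputModel` class and nested Python lists, only illustrates that contract: it records `dtype` without checking it, ignores `layout` entirely, and is not how hidet represents tensors.

```python
from dataclasses import dataclass
from typing import Sequence


def nested_shape(data):
    """Shape of a nested list, read along the first element of each axis (assumes it is rectangular)."""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0] if data else None
    return tuple(shape)


@dataclass
class TensorInputModel:
    """Hypothetical stand-in for a tensor input node: name, dtype, and shape, but no data."""

    name: str
    dtype: str  # recorded only; this sketch does not check element types
    shape: Sequence[int]

    def bind(self, env):
        """Fetch this input's data from `env` and check it against the declared shape."""
        data = env[self.name]
        if nested_shape(data) != tuple(self.shape):
            raise ValueError(f"{self.name}: expected shape {tuple(self.shape)}, got {nested_shape(data)}")
        return data


a = TensorInputModel("a", "float32", [2, 3])
a_val = a.bind({"a": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]})
print(a_val[1][2])  # prints 6.0
```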
- hidet.ir.compute.compute(name, shape, fcompute, layout=None)
Define a grid compute node.
- Parameters:
name (str) – The name of the compute node.
shape (Sequence[Union[int, Expr]]) – The shape of the compute node.
fcompute (Callable[[Sequence[Var]], Expr]) – The compute function. It takes the index variables (one per dimension of shape) and returns the value of the output element at those indices.
layout (DataLayout, optional) – The layout of the compute node.
- Returns:
ret – The grid compute node.
- Return type:
TensorNode
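A grid compute node defines every element of the output by evaluating `fcompute` at that element's indices. The hypothetical `eval_compute` helper below evaluates this eagerly on concrete nested lists, which hidet does not do (it records a symbolic expression instead). The sketch passes the indices to `fcompute` as separate positional arguments, as in lambdas of the form `lambda i, j: ...`; treat that calling convention as an assumption and confirm it against the hidet source.

```python
import itertools


def eval_compute(shape, fcompute):
    """Hypothetical eager model of a grid compute: out[idx] = fcompute(*idx) for every idx."""

    def empty(dims):
        # Nested list of the given (non-empty) shape, filled with None.
        if len(dims) == 1:
            return [None] * dims[0]
        return [empty(dims[1:]) for _ in range(dims[0])]

    out = empty(list(shape))
    for idx in itertools.product(*(range(n) for n in shape)):
        # Walk down to the innermost list that holds element idx, then store its value.
        row = out
        for i in idx[:-1]:
            row = row[i]
        row[idx[-1]] = fcompute(*idx)
    return out


a = [[1, 2, 3], [4, 5, 6]]  # concrete data standing in for a 2 x 3 input tensor
b = eval_compute([3, 2], lambda i, j: a[j][i])  # transpose: b[i][j] = a[j][i]
print(b)  # [[1, 4], [2, 5], [3, 6]]
```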
- hidet.ir.compute.reduce(shape, fcompute, reduce_type, accumulate_dtype='float32', name=None)
Define a reduction node.
- Parameters:
shape (Sequence[int or Expr]) – The domain of the reduction.
fcompute (Callable[[Sequence[Var]], Expr]) – The compute function. It takes the reduction variables (one per dimension of shape) and returns the value to be reduced at those indices.
reduce_type (ReduceType or str) – The type of the reduction.
accumulate_dtype (str or DataType) – The data type of the accumulator.
name (Optional[str]) – The name hint for the output. If not specified, the name will be generated automatically.
- Returns:
ret – The reduction node.
- Return type:
ReduceCompute
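A reduction node combines the values returned by `fcompute` over every point of its domain, and it is typically placed inside the `fcompute` of a grid compute. The sketch below models this eagerly with a hypothetical `eval_reduce` helper and uses it for a 2 x 2 matrix product, with a list comprehension standing in for `compute`. The `reduce_type` strings it accepts (`'sum'`, `'max'`, `'min'`) are assumptions, not a complete list of `ReduceType` values, and `accumulate_dtype` is not modeled.

```python
import itertools
import operator
from functools import reduce as fold

# Assumed reduce_type names mapped to (combine function, initial value); hidet's
# ReduceType may name or support these differently.
_REDUCERS = {
    "sum": (operator.add, 0),
    "max": (max, float("-inf")),
    "min": (min, float("inf")),
}


def eval_reduce(shape, fcompute, reduce_type):
    """Hypothetical eager model: combine fcompute(*k) over every k in the domain `shape`."""
    combine, init = _REDUCERS[reduce_type]
    values = (fcompute(*k) for k in itertools.product(*(range(n) for n in shape)))
    return fold(combine, values, init)


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
# c[i][j] = sum over k of a[i][k] * b[k][j]; the comprehension plays the role of compute().
c = [[eval_reduce([2], lambda k: a[i][k] * b[k][j], "sum") for j in range(2)] for i in range(2)]
print(c)  # [[19, 22], [43, 50]]
```

With the hidet primitives themselves, the same matrix product would be written roughly as `compute('c', [m, n], lambda i, j: reduce([k], lambda kk: a[i, kk] * b[kk, j], 'sum'))`, with `a` and `b` defined by `tensor_input`; check the hidet tutorials for the exact form.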
- hidet.ir.compute.arg_reduce(extent, fcompute, reduce_type, index_dtype='int64', name=None)
Define an arg reduction node.
- Parameters:
extent (int or Expr) – The size of the reduction domain.
fcompute (Callable[[Var], Expr]) – The compute function. It takes a reduction variable and returns the value to compare.
reduce_type (str or ReduceType) – The type of the reduction.
index_dtype (str or DataType) – The data type of the index.
name (str, optional) – The name of the output. If not specified, the name will be generated automatically.
- Returns:
ret – The arg reduction node.
- Return type:
ScalarNode
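An arg reduction returns the index at which `fcompute` is extremal rather than the extremal value itself. The hypothetical `eval_arg_reduce` helper below models this for `'max'` and `'min'` (assumed reduce-type names) and computes a per-row argmax, again with a list comprehension standing in for `compute`. Ties resolve to the smallest index here, which is a choice of this sketch and not necessarily hidet's behavior; `index_dtype` is not modeled.

```python
import operator

# Assumed reduce_type names mapped to "is the new value better than the best so far?".
_BETTER = {"max": operator.gt, "min": operator.lt}


def eval_arg_reduce(extent, fcompute, reduce_type):
    """Hypothetical eager model: index k in range(extent) where fcompute(k) is max/min."""
    better = _BETTER[reduce_type]
    best_index, best_value = 0, fcompute(0)  # assumes extent >= 1
    for k in range(1, extent):
        value = fcompute(k)
        if better(value, best_value):  # strict comparison keeps the first index on ties
            best_index, best_value = k, value
    return best_index


x = [[3, 7, 1], [9, 2, 9]]
# One arg-reduction per row; the comprehension plays the role of compute().
idx = [eval_arg_reduce(3, lambda k: x[i][k], "max") for i in range(2)]
print(idx)  # [1, 0]
```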