Define Operator Computation

Each operator takes a list of input tensors and produces a list of output tensors:

inputs: List[Tensor]
outputs: List[Tensor] = operator(inputs)


The precise mathematical definition of each operator in Hidet is expressed in a domain-specific language (DSL). In this tutorial, we show how to define the computation of a new operator in Hidet using this DSL, which is defined in the module.

Our predecessors Halide and Apache TVM use a similar DSL to define the mathematical definition of an operator.

Compute Primitives

This module provides compute primitives to define the mathematical computation of an operator:

tensor_input(name: str, dtype: str, shape: List[int])

The tensor_input() primitive defines a tensor input by specifying the name hint, scalar data type, and shape of the tensor.

a = tensor_input('a', dtype='float32', shape=[10, 10])
b = tensor_input('b', dtype='float32', shape=[])
data = tensor_input('data', dtype='float16', shape=[1, 3, 224, 224])

compute(name: str, shape: List[int], fcompute: Callable[[Var, ...], Expr])

The compute() primitive defines a tensor by specifying

  • the name of the tensor, just a hint for what the tensor represents,

  • the shape of the tensor, and

  • a function that maps an index to the expression that computes the value of the tensor at that index.

Each element of the tensor is computed independently of the others, so all elements can be computed in parallel.

# compute primitive
out = compute(
    name='out',
    shape=[n1, n2, ..., nk],
    fcompute=lambda i1, i2, ..., ik: f(i1, i2, ..., ik)
)

# semantics
for i1 in range(n1):
  for i2 in range(n2):
    ...
      for ik in range(nk):
        out[i1, i2, ..., ik] = f(i1, i2, ..., ik)
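
To make the loop-nest semantics above concrete, here is a minimal pure-Python sketch. It is not how Hidet implements compute(): the helper reference_compute is hypothetical, evaluates fcompute eagerly on concrete indices, and returns a dict keyed by index tuple, whereas the real primitive builds a symbolic tensor node.

```python
from itertools import product


def reference_compute(shape, fcompute):
    """Eagerly evaluate fcompute at every index of the given shape.

    Hypothetical helper for illustration only; it is not part of Hidet.
    """
    # product(range(n1), ..., range(nk)) visits the same indices as the nested loops
    return {idx: fcompute(*idx) for idx in product(*(range(n) for n in shape))}


# a 3x3 "input tensor" stored as nested lists
a = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]

# slice the first column of a
first_column = reference_compute([3], lambda i: a[i][0])

# reverse the rows of a
reversed_rows = reference_compute([3, 3], lambda i, j: a[2 - i][j])

print([first_column[(i,)] for i in range(3)])     # [1.0, 4.0, 7.0]
print([reversed_rows[(0, j)] for j in range(3)])  # [7.0, 8.0, 9.0]
```

Because each entry depends only on its own index, the entries could be evaluated in any order, which is what allows them to be computed in parallel.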


In the third example below, an if_then_else() expression defines a conditional expression.

# define an input tensor
a = tensor_input('a', dtype='float32', shape=[10, 10])

# example 1: slice the first column of a
b = compute('slice', shape=[10], fcompute=lambda i: a[i, 0])

# example 2: reverse the rows of matrix a
c = compute('reverse', shape=[10, 10], fcompute=lambda i, j: a[9 - i, j])

# example 3: add 1 to the diagonal elements of a
from import if_then_else
d = compute(
  name='diag_add',
  shape=[10, 10],
  fcompute=lambda i, j: if_then_else(i == j, then_expr=a[i, j] + 1.0, else_expr=a[i, j])
)

reduce(shape: List[int], fcompute: Callable[[Var, ...], Expr], reduce_type='sum')

The reduce() primitive performs a reduction over a domain with the given shape. It returns a scalar value and can be used inside a compute() primitive.

# reduce primitive
out = reduce(
    shape=[n1, n2, ..., nk],
    fcompute=lambda i1, i2, ..., ik: f(i1, i2, ..., ik),
    reduce_type='sum' | 'max' | 'min' | 'avg'
)

# semantics
values = []
for i1 in range(n1):
  for i2 in range(n2):
    ...
      for ik in range(nk):
        values.append(f(i1, i2, ..., ik))
out = reduce_type(values)

# define an input tensor
a = tensor_input('a', dtype='float32', shape=[10, 10])

# example 1: sum all elements of a
c = reduce(shape=[10, 10], fcompute=lambda i, j: a[i, j], reduce_type='sum')

# example 2: sum the first column of a
d = reduce(shape=[10], fcompute=lambda i: a[i, 0], reduce_type='sum')

# example 3: matrix multiplication
b = tensor_input('b', dtype='float32', shape=[10, 10])
e = compute(
    name='e',
    shape=[10, 10],
    fcompute=lambda i, j: reduce(
        shape=[10],
        fcompute=lambda k: a[i, k] * b[k, j],
        reduce_type='sum'
    )
)

arg_reduce(extent: int, fcompute: Callable[[Var], Expr], reduce_type='max')

Similar to reduce(), the arg_reduce() primitive conducts a reduction operation on a domain with the given extent. The difference is that it returns the index of the element that corresponds to the reduction result, instead of the result itself.

# arg_reduce primitive
out = arg_reduce(extent, fcompute=lambda i: f(i), reduce_type='max' | 'min')

# semantics
values = []
for i in range(extent):
  values.append(f(i))
out = index of the max/min value in values

# define an input tensor
a = tensor_input('a', dtype='float32', shape=[10, 10])

# example: find the index of the max element in each row of a
b = compute('b', [10], lambda i: arg_reduce(10, lambda j: a[i, j], reduce_type='max'))
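
The difference between reduce() and arg_reduce() can be illustrated with a small pure-Python sketch. The helpers reference_reduce and reference_arg_reduce are hypothetical names for this illustration and are not part of Hidet; they evaluate eagerly on plain lists. Returning the first index on ties is an assumption of this sketch, since the text above does not specify how ties are broken.

```python
from itertools import product


def reference_reduce(shape, fcompute, reduce_type='sum'):
    """Collect fcompute over the whole domain and fold it into one scalar."""
    values = [fcompute(*idx) for idx in product(*(range(n) for n in shape))]
    if reduce_type == 'sum':
        return sum(values)
    if reduce_type == 'avg':
        return sum(values) / len(values)
    if reduce_type == 'max':
        return max(values)
    if reduce_type == 'min':
        return min(values)
    raise ValueError(f'unsupported reduce_type: {reduce_type}')


def reference_arg_reduce(extent, fcompute, reduce_type='max'):
    """Return the index of the max/min value instead of the value itself."""
    values = [fcompute(i) for i in range(extent)]
    best = max(values) if reduce_type == 'max' else min(values)
    # assumption of this sketch: the first matching index wins on ties
    return values.index(best)


a = [[0.5, 2.0, -1.0],
     [3.0, 0.0, 1.0]]

# reduce each row to its sum, and find the column of each row's maximum
row_sum = [reference_reduce([3], lambda j: a[i][j], 'sum') for i in range(2)]
row_argmax = [reference_arg_reduce(3, lambda j: a[i][j], 'max') for i in range(2)]

print(row_sum)     # [1.5, 4.0]
print(row_argmax)  # [1, 0]
```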

Define a Computation Task

The computation of each operator can be described as a directed acyclic graph (DAG). The DAG is composed of tensor nodes. Both tensor_input() and compute() primitives create tensor nodes. The edges of the DAG are the dependencies between the tensor nodes. Such a DAG is stored in a Task object.

class Task(name: str, inputs: List[TensorNode], outputs: List[TensorNode])

Each task has a name, a list of inputs, and a list of outputs, corresponding to the inputs and outputs of the operator. The following example shows how to create a task.

def demo_task():
    from import tensor_input, compute
    from import Task

    # define the computation DAG through the compute primitives
    a = tensor_input('a', dtype='float32', shape=[10])
    b = tensor_input('b', dtype='float32', shape=[10])
    c = compute('c', [10], lambda i: a[i] + i)
    d = compute('d', [10], lambda i: c[9 - i])
    e = compute('e', [10], lambda i: a[i] + b[i])

    # create a task object
    task = Task(name='task', inputs=[a, b], outputs=[d, e])

  name: task
    a: tensor(float32, [10])
    b: tensor(float32, [10])
    d: tensor(float32, [10])
    e: tensor(float32, [10])
  inputs: [a, b]
  outputs: [d, e]
    e: float32[10] where e[v] = (a[v] + b[v])
    c: float32[10] where c[v_1] = (a[v_1] + v_1)
    d: float32[10] where d[v_2] = c[(9 - v_2)]
  attributes: {}

Its computation DAG can be visualized as follows.

[Figure: computation DAG. Inputs: A, B. Outputs: D, E. Edges: A -> C -> D, A -> E, B -> E.]

An example of a computation DAG. In this example, there are 5 tensor nodes, where nodes A and B are inputs and nodes D and E are outputs. The computation of node C depends only on node A, node D depends on node C, and node E depends on nodes A and B.
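
The dependency structure of this example can also be written down and checked in plain Python. The sketch below is only an illustration and does not use Hidet: the deps dictionary is copied by hand from the compute definitions in demo_task, and graphlib is the Python standard-library module (Python 3.9 or later).

```python
from graphlib import TopologicalSorter

# node -> set of nodes it reads from, transcribed from demo_task
deps = {
    'a': set(),        # input
    'b': set(),        # input
    'c': {'a'},        # c[i] = a[i] + i
    'd': {'c'},        # d[i] = c[9 - i]
    'e': {'a', 'b'},   # e[i] = a[i] + b[i]
}


def ancestors(node):
    """Return every node that the given node depends on, directly or not."""
    seen = set()
    stack = list(deps[node])
    while stack:
        current = stack.pop()
        if current not in seen:
            seen.add(current)
            stack.extend(deps[current])
    return seen


# one valid evaluation order: every node appears after the nodes it reads
order = list(TopologicalSorter(deps).static_order())
print(order)

print(sorted(ancestors('d')))  # ['a', 'c']
print(sorted(ancestors('e')))  # ['a', 'b']
```

Several evaluation orders are valid for this DAG, so the printed order may differ between runs or Python versions; only the constraint that a node follows its dependencies is guaranteed.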

Build and Run a Task

We provide a driver function hidet.drivers.build_task() to build a task into a callable function. The build_task() function lowers the task into a callable function in the following steps:

  1. Dispatch the task to a scheduler according to the target device and task.

  2. The scheduler lowers the task into a tensor program, defined with IRModule.

  3. Lower and optimize the IRModule.

  4. Code generation that translates the IRModule into the target source code (e.g.,

  5. Call compiler (e.g., nvcc) to compile the source code into a dynamic library (i.e.,

  6. Load the dynamic library and wrap it in a CompiledFunction that can be called directly.

A scheduler is a function that takes a task as input and returns a scheduled tensor program defined in an IRModule.

We can define the following function to build and run a task.

from typing import List
import hidet
from import Task

def run_task(task: Task, inputs: List[hidet.Tensor]):
    """Run given task and print inputs and outputs"""
    from hidet.runtime import CompiledTask

    # build the task
    func: CompiledTask = hidet.drivers.build_task(task, target='cpu')

    # run the compiled task
    outputs = func.run_async(inputs)

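    # print the task name, then every input and output tensor
    print('Task:', task.name)
    for tensor in inputs:
        print(tensor)
    for tensor in outputs:
        print(tensor)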

Please pay attention to the difference between Tensor and TensorNode. The former is a tensor object that can be used to store data and trace the high-level computation graph of a deep learning model. The latter is a tensor node in the domain-specific language that is used to describe the computation of a single operator.

The following code shows how to 1) define the computation, 2) define the task, and 3) build and run the task.

from import tensor_input, reduce, compute, arg_reduce, TensorNode

def add_example():
    a: TensorNode = tensor_input(name='a', dtype='float32', shape=[5])
    b: TensorNode = tensor_input(name='b', dtype='float32', shape=[5])
    c: TensorNode = compute(name='c', shape=[5], fcompute=lambda i: a[i] + b[i])
    task = Task(name='add', inputs=[a, b], outputs=[c])
    run_task(task, [hidet.randn([5]), hidet.randn([5])])

Task: add
Tensor(shape=(5,), dtype='float32', device='cpu')
[ 0.25 -0.46  1.21 -0.01  0.72]
Tensor(shape=(5,), dtype='float32', device='cpu')
[-1.65  0.67  0.47 -1.95  0.14]
Tensor(shape=(5,), dtype='float32', device='cpu')
[-1.4   0.21  1.67 -1.96  0.86]

More Examples


Finally, we show more examples of using the compute primitives to define operator computation.

All Hidet operators are defined in the hidet.graph.ops submodule, and all existing operators are defined through the compute primitives described in this tutorial. Feel free to check the source code to learn more about how to define the computation of different operators.


def reduce_sum_example():
    a = tensor_input('a', dtype='float32', shape=[4, 3])
    b = compute(
        name='b',
        shape=[4],
        fcompute=lambda i: reduce(shape=[3], fcompute=lambda j: a[i, j], reduce_type='sum'),
    )
    task = Task('reduce_sum', inputs=[a], outputs=[b])
    run_task(task, [hidet.randn([4, 3])])

Task: reduce_sum
Tensor(shape=(4, 3), dtype='float32', device='cpu')
[[-1.91 -0.52 -0.36]
 [ 0.11  0.95  0.52]
 [-0.34 -1.59  0.26]
 [-1.04 -1.76 -0.2 ]]
Tensor(shape=(4,), dtype='float32', device='cpu')
[-2.8   1.58 -1.68 -3.  ]


def arg_max_example():
    a = tensor_input('a', dtype='float32', shape=[4, 3])
    b = compute(
        name='b',
        shape=[4],
        fcompute=lambda i: arg_reduce(extent=3, fcompute=lambda j: a[i, j], reduce_type='max'),
    )
    task = Task('arg_max', inputs=[a], outputs=[b])
    run_task(task, [hidet.randn([4, 3])])

Task: arg_max
Tensor(shape=(4, 3), dtype='float32', device='cpu')
[[ 0.41  1.8   0.05]
 [-0.34 -1.74  0.64]
 [-0.44 -0.93  0.13]
 [ 0.25  0.47 -1.07]]
Tensor(shape=(4,), dtype='int64', device='cpu')
[1 2 2 1]


def matmul_example():
    a = tensor_input('a', dtype='float32', shape=[3, 3])
    b = tensor_input('b', dtype='float32', shape=[3, 3])
    c = compute(
        name='c',
        shape=[3, 3],
        fcompute=lambda i, j: reduce(
            shape=[3], fcompute=lambda k: a[i, k] * b[k, j], reduce_type='sum'
        ),
    )
    task = Task('matmul', inputs=[a, b], outputs=[c])
    run_task(task, [hidet.randn([3, 3]), hidet.randn([3, 3])])

Task: matmul
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[-0.85  0.75 -1.04]
 [-0.52  2.25 -0.63]
 [-0.51  0.39  1.08]]
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[-0.8  -0.65  1.15]
 [-0.02  0.25  0.72]
 [-1.47 -1.33 -0.91]]
Tensor(shape=(3, 3), dtype='float32', device='cpu')
[[ 2.19  2.12  0.5 ]
 [ 1.29  1.73  1.59]
 [-1.18 -1.   -1.29]]


def softmax_example():
    from import exp

    a = tensor_input('a', dtype='float32', shape=[3])
    max_val = reduce(shape=[3], fcompute=lambda i: a[i], reduce_type='max')
    b = compute('b', shape=[3], fcompute=lambda i: a[i] - max_val)
    exp_a = compute('exp', shape=[3], fcompute=lambda i: exp(b[i]))
    exp_sum = reduce(shape=[3], fcompute=lambda i: exp_a[i], reduce_type='sum')
    softmax = compute('softmax', shape=[3], fcompute=lambda i: exp_a[i] / exp_sum)

    task = Task('softmax', inputs=[a], outputs=[softmax])
    run_task(task, [hidet.randn([3])])

Task: softmax
Tensor(shape=(3,), dtype='float32', device='cpu')
[ 0.58  0.16 -1.86]
Tensor(shape=(3,), dtype='float32', device='cpu')
[0.57 0.38 0.05]
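
As a sanity check, the same chain of steps can be mirrored in plain Python and applied to the rounded input printed above. This is only a sketch for verifying the arithmetic: reference_softmax is a hypothetical helper, not a Hidet function, and because the printed input is rounded to two decimals the comparison is also made at two decimals.

```python
import math


def reference_softmax(a):
    """Numerically stable softmax, following the same steps as softmax_example."""
    max_val = max(a)                        # reduce(..., reduce_type='max')
    b = [x - max_val for x in a]            # subtract the max for stability
    exp_a = [math.exp(x) for x in b]        # element-wise exp
    exp_sum = sum(exp_a)                    # reduce(..., reduce_type='sum')
    return [x / exp_sum for x in exp_a]     # normalize


result = reference_softmax([0.58, 0.16, -1.86])
print([round(x, 2) for x in result])  # [0.57, 0.38, 0.05]
```

Subtracting the maximum before exponentiating does not change the result mathematically, but it keeps every exponent at or below zero, which avoids overflow for large inputs.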


In this tutorial, we introduced the compute primitives that are used to define the computation of operators in Hidet. After that, we showed how to wrap the computation DAG into a task and build and run the task. In the next step, we will show you how to use these compute primitives to define new operators in Hidet.
