# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import onnx models into hidet.

Please refer to https://github.com/onnx/onnx/blob/main/docs/Operators.md for operator definitions.
Please refer to https://github.com/onnx/onnx/blob/main/onnx/onnx.proto for the proto structure of the onnx format.
"""
# pylint: disable=unused-argument
from typing import List, Union, Optional, Dict, Callable, Type, Sequence, Set
from collections import defaultdict
import warnings
import os
import logging
import numpy as np
import onnx
import onnx.numpy_helper
import onnx.external_data_helper
import hidet
from hidet.graph import nn
from hidet.graph import ops
from hidet.graph.tensor import Tensor, from_numpy, randn
from . import utils

log = logging.getLogger(__name__)


class OnnxOperator:
    def __init__(self, node, op_sets: List[int]):
        """
        Parameters
        ----------
        node: onnx.NodeProto
        """
        self.node: onnx.NodeProto = node
        self.op_sets: List[int] = op_sets
        self.input_names: List[str] = [name for name in node.input]
        self.output_names: List[str] = [name for name in node.output]
        self.attrs = {}
        for attr in node.attribute:
            if attr.type == 1:  # float
                v = attr.f
            elif attr.type == 2:  # int
                v = attr.i
            elif attr.type == 3:  # string
                v = attr.s.decode('utf-8')
            elif attr.type == 4:  # tensor
                v = from_numpy(onnx.numpy_helper.to_array(tensor=attr.t)).cuda()
            elif attr.type == 5:  # graph
                v = attr.g
            elif attr.type == 6:  # floats
                v = list(attr.floats)
            elif attr.type == 7:  # ints
                v = list(attr.ints)
            elif attr.type == 8:  # strings
                v = [s.decode('utf-8') for s in attr.strings]
            else:
                raise ValueError('Can not recognize type id {} of attribute {}'.format(attr.type, attr.name))
            self.attrs[attr.name] = v
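        # e.g., a Conv node carrying an ints attribute strides = [2, 2] ends up
        # with self.attrs == {'strides': [2, 2]}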

    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        opset = self.resolve_opset(self.op_sets)
        run_func: Callable[[List[Tensor]], List[Tensor]] = getattr(self, 'run_v{}'.format(opset))
        outs = run_func(inputs)
        return outs

    def resolve_opset(self, op_sets: List[int]) -> int:
        for op_set in op_sets:
            try_op_set = op_set
            while try_op_set >= 1:
                if self.implemented(try_op_set):
                    return try_op_set
                try_op_set -= 1
        raise NotImplementedError(
            'Can not resolve opset for operator {} given opsets {}.'.format(self.node.op_type, op_sets)
        )

    def implemented(self, opset: int):
        func_name = 'run_v{}'.format(opset)
        this_func = getattr(self, func_name)
        base_func = getattr(OnnxOperator, func_name)
        return this_func.__func__ is not base_func
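    # For example, with op_sets=[13], an operator that overrides run_v1 and run_v13
    # resolves to run_v13; with op_sets=[11], the search walks 11, 10, ..., 2 and
    # falls back to run_v1.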

    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v2(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v3(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v4(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v5(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v6(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v7(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v8(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v9(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v10(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v11(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v12(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v13(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v14(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v15(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v16(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v17(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    def run_v18(self, inputs: List[Tensor]) -> List[Tensor]:
        return NotImplemented

    @staticmethod
    def tensor2list(tensor: Tensor) -> Union[List, int, float]:
        ret = tensor.cpu().numpy().tolist()
        assert isinstance(ret, (list, int, float))
        return ret

    @staticmethod
    def tensor2scalar(tensor: Tensor) -> Union[int, float]:
        value = OnnxOperator.tensor2list(tensor)
        if isinstance(value, (list, tuple)):
            if len(value) == 1:
                return value[0]
            else:
                raise ValueError('Expect a scalar, got {}'.format(value))
        else:
            assert isinstance(value, (int, float))
            return value

    @staticmethod
    def optional_inputs(inputs: List[Tensor], requires: List[bool]) -> List[Union[Tensor, None]]:
        diff = len(requires) - len(inputs)
        assert diff >= 0, 'Onnx operator got {} inputs but expects at most {}.'.format(len(inputs), len(requires))
        ret: List[Union[Tensor, None]] = []
        ret += inputs
        ret += [None for _ in range(diff)]
        for i, (t, r) in enumerate(zip(ret, requires)):
            if t is None and r:
                raise ValueError('The {}-th input is required but not given.'.format(i))
        return ret
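    # e.g., optional_inputs([x], requires=[True, False, False]) returns [x, None, None],
    # while optional_inputs([], requires=[True]) raises a ValueError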


dispatch_table: Dict[str, Type[OnnxOperator]] = {}


def register_onnx_operator(cls: Type[OnnxOperator]):
    if not issubclass(cls, OnnxOperator):
        raise ValueError('Can only register a sub-class of OnnxOperator as an onnx operator.')
    cls_name = cls.__name__
    if not cls_name.startswith('Onnx'):
        raise ValueError(
            'Please name the class as OnnxOPNAME such as OnnxConv and OnnxAdd,'
            ' where OPNAME is the same as the operator name used by ONNX. Got {}'.format(cls_name)
        )
    dispatch_table[cls_name[4:]] = cls
    return cls
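

# A minimal sketch of the registration pattern used throughout this file: name the
# class 'Onnx' + op_type, subclass OnnxOperator, implement run (or a versioned
# run_v*), and apply the decorator. This example assumes hidet provides ops.floor.
@register_onnx_operator
class OnnxFloor(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.floor(inputs[0])]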


@register_onnx_operator
class OnnxConv(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        groups = self.attrs.get('group', 1)
        if len(inputs) == 2:
            x, w = inputs
            bias = None
        else:
            x, w, bias = inputs
        if len(x.shape) == 4:
            dilations = self.attrs.get('dilations', [1, 1])
            padding = self.attrs.get('pads', [0, 0, 0, 0])
            strides = self.attrs.get('strides', [1, 1])
            padding = ops.utils.normalize_padding(padding)
            # conv2d only supports symmetric padding (like torch); for asymmetric
            # padding, pad explicitly first and convolve without extra padding
            if not (padding[0] == padding[2] and padding[1] == padding[3]):
                x = ops.pad(x, padding)
                output = ops.conv2d(x, w, stride=strides, dilations=dilations, groups=groups)
            else:
                output = ops.conv2d(
                    x, w, padding=(padding[0], padding[1]), stride=strides, dilations=dilations, groups=groups
                )
            if bias is not None:
                bias = ops.unsqueeze(bias, [0, 2, 3])
                output = output + bias
        elif len(x.shape) == 5:
            dilations = self.attrs.get('dilations', [1, 1, 1])
            padding = self.attrs.get('pads', [0, 0, 0, 0, 0, 0])
            strides = self.attrs.get('strides', [1, 1, 1])
            x = ops.pad(x, ops.utils.normalize_padding(padding, dim=3))
            output = ops.conv3d(x, w, stride=strides, dilations=dilations, groups=groups)
            if bias is not None:
                bias = ops.unsqueeze(bias, [0, 2, 3, 4])
                output = output + bias
        else:
            raise NotImplementedError('Currently only 2D and 3D convolutions are supported, got input with shape {}.'.format(x.shape))
        return [output]

    def run_v11(self, inputs: List[Tensor]) -> List[Tensor]:
        return self.run_v1(inputs)


@register_onnx_operator
class OnnxBatchNormalization(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        epsilon: float = self.attrs.get('epsilon', 1e-5)
        # for inference, we can ignore this momentum attribute
        momentum: float = self.attrs.get('momentum', 0.9)  # pylint: disable=unused-variable
        training_mode: int = self.attrs.get('training_mode', 0)
        assert training_mode == 0, 'BatchNormalization in training mode encountered; hidet does not support training.'

        x, scale, bias, running_mean, running_var = inputs
        if len(x.shape) == 1:
            y = (x - running_mean) * (running_var + epsilon).rsqrt()
            return [y * scale + bias]
        else:
            unsqueeze_dims = [dim for dim in range(len(x.shape)) if dim != 1]
            y = ops.batch_norm_infer(x, running_mean=running_mean, running_var=running_var, epsilon=epsilon, axis=1)
            return [y * scale.unsqueeze(unsqueeze_dims) + bias.unsqueeze(unsqueeze_dims)]


@register_onnx_operator
class OnnxRelu(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.relu(inputs[0])]


@register_onnx_operator
class OnnxSin(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.sin(inputs[0])]


@register_onnx_operator
class OnnxCos(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.cos(inputs[0])]


@register_onnx_operator
class OnnxPow(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        x, y = inputs
        return [ops.pow(x, y)]


@register_onnx_operator
class OnnxDiv(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        x, y = inputs
        return [ops.divide(x, y)]


@register_onnx_operator
class OnnxSqrt(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.sqrt(inputs[0])]


@register_onnx_operator
class OnnxErf(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.erf(inputs[0])]


@register_onnx_operator
class OnnxTanh(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.tanh(inputs[0])]


@register_onnx_operator
class OnnxMaxPool(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        kernel_size = list(self.attrs.get('kernel_shape'))
        x = inputs[0]
        if len(x.shape) == 4:
            padding = list(self.attrs.get('pads', [0, 0, 0, 0]))
            strides = list(self.attrs.get('strides', [1, 1]))
            return [ops.max_pool2d(inputs[0], kernel_size, strides, padding)]
        elif len(x.shape) == 5:
            padding = list(self.attrs.get('pads', [0, 0, 0, 0, 0, 0]))
            strides = list(self.attrs.get('strides', [1, 1, 1]))
            return [ops.max_pool3d(inputs[0], kernel_size, strides, padding)]
        else:
            raise NotImplementedError('Currently only 2D and 3D max pooling are supported')


@register_onnx_operator
class OnnxReduceMean(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        dims = self.attrs.get('axes')
        keep_dim = self.attrs.get('keepdims', 1) == 1
        return [ops.mean(inputs[0], dims, keep_dim)]


@register_onnx_operator
class OnnxSqueeze(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        dims = self.attrs.get('axes', None)
        data = inputs[0]
        if dims is None:
            # squeeze all dimensions with extent 1
            dims = [i for i, dim in enumerate(data.shape) if dim == 1]
        else:
            dims = list(dims)
        return [ops.squeeze(inputs[0], dims)]

    def run_v13(self, inputs: List[Tensor]) -> List[Tensor]:
        data, axes = inputs
        dims = self.tensor2list(axes)
        return [ops.squeeze(data, dims)]
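    # Squeeze illustrates the opset-versioning mechanism above: in opset 1-12 the
    # axes come as an attribute (run_v1), while opset 13+ passes them as a second
    # input tensor (run_v13).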


@register_onnx_operator
class OnnxAdd(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [inputs[0] + inputs[1]]


@register_onnx_operator
class OnnxSub(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [inputs[0] - inputs[1]]


@register_onnx_operator
class OnnxMul(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [inputs[0] * inputs[1]]


@register_onnx_operator
class OnnxMatMul(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        a, b = inputs
        return [ops.matmul(a, b)]


@register_onnx_operator
class OnnxSoftmax(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        axis = self.attrs.get('axis', 1)
        return [ops.softmax(inputs[0], axis)]

    def run_v13(self, inputs: List[Tensor]) -> List[Tensor]:
        axis = self.attrs.get('axis', -1)
        return [ops.softmax(inputs[0], axis)]


@register_onnx_operator
class OnnxGlobalAveragePool(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        (x,) = inputs
        dims = list(range(2, len(x.shape)))
        return [ops.mean(x, dims=dims, keep_dim=True)]


@register_onnx_operator
class OnnxFlatten(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        axis = self.attrs.get('axis', 1)
        x = inputs[0]
        rank = len(x.shape)
        axis = (axis + rank) % rank
        dims = list(range(rank))
        return [ops.rearrange(x, plan=[dims[:axis], dims[axis:]])]


@register_onnx_operator
class OnnxUnsqueeze(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        axes = self.attrs['axes']  # in [-output_rank, output_rank - 1]
        x = inputs[0]
        rank = len(x.shape) + len(axes)
        axes = [(axis + rank) % rank for axis in axes]
        return [ops.unsqueeze(x, axes)]

    def run_v13(self, inputs: List[Tensor]) -> List[Tensor]:
        x, axes = inputs
        axes = self.tensor2list(axes)
        rank = len(x.shape) + len(axes)
        axes = [(axis + rank) % rank for axis in axes]
        return [ops.unsqueeze(x, axes)]


@register_onnx_operator
class OnnxReshape(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        allow_zero = self.attrs.get('allowzero', 0)  # pylint: disable=unused-variable
        x, shape = inputs
        shape = self.tensor2list(shape)
        return [ops.reshape(x, shape)]


@register_onnx_operator
class OnnxTranspose(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        perm = self.attrs.get('perm', None)
        x = inputs[0]
        perm = perm if perm else list(reversed(range(len(x.shape))))
        return [ops.transpose(x, perm)]


@register_onnx_operator
class OnnxConcat(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        axis = self.attrs.get('axis')
        return [ops.concat(inputs, axis)]


@register_onnx_operator
class OnnxArgMax(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        axis = self.attrs.get('axis', 0)
        keepdims = self.attrs.get('keepdims', True)
        select_last_index = self.attrs.get('select_last_index', False)
        if select_last_index:
            raise NotImplementedError()
        return [ops.argmax(inputs[0], dim=axis, keep_dim=keepdims)]


@register_onnx_operator
class OnnxGemm(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        # Gemm computes Y = alpha * A' @ B' + beta * C; onnx defaults both alpha and beta to 1.0
        alpha = self.attrs.get('alpha', 1.0)
        beta = self.attrs.get('beta', 1.0)
        trans_a = self.attrs.get('transA', 0)
        trans_b = self.attrs.get('transB', 0)

        a, b = inputs[:2]
        c = inputs[2] if len(inputs) > 2 else None
        if trans_a == 1:
            a = ops.rearrange(a, plan=[[1], [0]])
        if trans_b == 1:
            b = ops.rearrange(b, plan=[[1], [0]])
        assert a.shape[1] == b.shape[0]
        d = ops.matmul(a, b)
        if alpha != 1.0:
            d = d * alpha
        if c is not None and beta != 0.0:
            d = d + c * beta
        return [d]


@register_onnx_operator
class OnnxCast(OnnxOperator):
    code2dtype = {
        1: 'float32',
        2: 'uint8',
        3: 'int8',
        4: 'uint16',
        5: 'int16',
        6: 'int32',
        7: 'int64',
        8: 'string',
        9: 'bool',
        10: 'float16',
        11: 'float64',
        12: 'uint32',
        13: 'uint64',
        14: 'complex64',
        15: 'complex128',
        16: 'bfloat16',
    }

    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        to = self.attrs.get('to')
        x = inputs[0]
        dtype = self.code2dtype[to]
        return [ops.cast(x, dtype)]


@register_onnx_operator
class OnnxShape(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        start = self.attrs.get('start', 0)
        end: Optional[int] = self.attrs.get('end', None)

        x = inputs[0]
        rank = len(x.shape)
        start = start + rank if start < 0 else start
        if end is not None:
            end = end + rank if end < 0 else end
        else:
            end = rank
        start = max(min(start, rank), 0)
        end = max(min(end, rank), 0)
        return [hidet.asarray(x.shape[start:end]).cuda()]


@register_onnx_operator
class OnnxConstant(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        value: Optional[Tensor] = self.attrs.get('value')
        if value is None:
            raise NotImplementedError('Currently, only tensor-valued constants are supported in the onnx importer')
        assert len(inputs) == 0
        return [value]


@register_onnx_operator
class OnnxGather(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        axis = self.attrs.get('axis', 0)
        data, indices = inputs
        return [ops.take(data, indices, axis)]


@register_onnx_operator
class OnnxSlice(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        data = inputs[0]
        starts = self.attrs['starts']
        ends = self.attrs['ends']
        axes = self.attrs.get('axes', list(range(len(starts))))
        ends = [min(end, data.shape[i]) for i, end in zip(axes, ends)]
        return [ops.strided_slice(data, starts, ends, axes)]

    def run_v10(self, inputs: List[Tensor]) -> List[Tensor]:
        data, starts, ends = inputs[:3]
        axes = inputs[3] if len(inputs) > 3 else None
        steps = inputs[4] if len(inputs) > 4 else None
        starts = self.tensor2list(starts)
        ends = self.tensor2list(ends)
        # when the optional axes input is absent, it defaults to [0, 1, ..., len(starts) - 1]
        axes = self.tensor2list(axes) if axes is not None else list(range(len(starts)))
        steps = self.tensor2list(steps) if steps is not None else None
        ends = [min(end, data.shape[i]) for i, end in zip(axes, ends)]
        return [ops.strided_slice(data, starts, ends, axes, steps)]


@register_onnx_operator
class OnnxSigmoid(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.sigmoid(inputs[0])]


@register_onnx_operator
class OnnxInstanceNormalization(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        epsilon = self.attrs.get('epsilon', 1e-5)

        x, scale, bias = inputs
        rank = len(x.shape)
        dims = [0] + list(range(2, rank))
        scale = ops.unsqueeze(scale, dims)  # [1, C, D1, ...]
        bias = ops.unsqueeze(bias, dims)  # [1, C, D1, ...]
        return [ops.instance_norm(x, epsilon) * scale + bias]


@register_onnx_operator
class OnnxConstantOfShape(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        value = self.attrs.get('value')
        if value is None:
            value = hidet.zeros([1], dtype='float32')

        shape = inputs[0].cpu().numpy().tolist()
        assert all(v >= 0 for v in shape)
        return [ops.broadcast(value, shape)]


@register_onnx_operator
class OnnxPad(OnnxOperator):
    def run_v2(self, inputs: List[Tensor]) -> List[Tensor]:
        data = inputs[0]
        mode = self.attrs.get('mode', 'constant')
        pads = self.attrs.get('pads')
        value = self.attrs.get('value', 0.0)
        return [ops.pad(data, pads, mode, value)]

    def run_v13(self, inputs: List[Tensor]) -> List[Tensor]:
        mode = self.attrs.get('mode', 'constant')
        data, pads = inputs[:2]
        value = self.tensor2list(inputs[2]) if len(inputs) > 2 else 0.0
        pads = self.tensor2list(pads)
        return [ops.pad(data, pads, mode, value)]


@register_onnx_operator
class OnnxResize(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        coordinate_transformation_mode = self.attrs.get('coordinate_transformation_mode', 'half_pixel')
        cubic_coeff_a = self.attrs.get('cubic_coeff_a', -0.75)
        exclude_outside = self.attrs.get('exclude_outside', 0)
        extrapolation_value = self.attrs.get('extrapolation_value', 0.0)
        mode = self.attrs.get('mode', 'nearest')
        nearest_mode = self.attrs.get('nearest_mode', 'round_prefer_floor')

        x, roi, scales, sizes = self.optional_inputs(inputs, requires=[True, False, False, False])
        if roi is not None:
            roi = self.tensor2list(roi)
        target_size = None
        if scales is not None and scales.size > 0:
            scales = self.tensor2list(scales)
            assert len(x.shape) == len(scales)
            target_size = [int(a * b) for a, b in zip(x.shape, scales)]
        elif sizes is not None and sizes.size > 0:
            sizes = self.tensor2list(sizes)
            target_size = [int(v) for v in sizes]
        else:
            raise ValueError('Resize operator in onnx must give either scales or sizes.')
        if len(x.shape) == 4:
            if not (target_size[0] == x.shape[0] and target_size[1] == x.shape[1]):
                raise ValueError('Unsupported resize on batch and channel dimension.')
            return [
                ops.resize2d(
                    x,
                    size=target_size[2:],
                    method=mode,
                    coordinate_transformation_mode=coordinate_transformation_mode,
                    rounding_method=nearest_mode,
                    roi=roi,
                    cubic_alpha=cubic_coeff_a,
                    cubic_exclude=exclude_outside,
                    extrapolation_value=extrapolation_value,
                )
            ]
        else:
            raise NotImplementedError('Currently only 2D resize is supported, got input with shape {}.'.format(x.shape))


@register_onnx_operator
class OnnxExpand(OnnxOperator):
    def run_v8(self, inputs: List[Tensor]) -> List[Tensor]:
        data, new_shape = inputs
        new_shape = self.tensor2list(new_shape)
        new_shape = hidet.graph.ops.arithmetic.broadcast_shape(data.shape, new_shape)
        return [ops.broadcast(data, new_shape)]


@register_onnx_operator
class OnnxRange(OnnxOperator):
    def run_v11(self, inputs: List[Tensor]) -> List[Tensor]:
        start, limit, delta = [self.tensor2list(t) for t in inputs]
        array = np.arange(start=start, stop=limit, step=delta)
        array = hidet.asarray(array).cuda().astype(dtype=inputs[0].dtype)
        return [array]


@register_onnx_operator
class OnnxTile(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        data, repeats = inputs
        repeats = self.tensor2list(repeats)
        return [ops.tile(data, repeats)]


@register_onnx_operator
class OnnxAveragePool(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        auto_pad = self.attrs.get('auto_pad', 'NOTSET')
        ceil_mode = self.attrs.get('ceil_mode', 0)
        count_include_pad = self.attrs.get('count_include_pad', 0)
        kernel_shape = self.attrs.get('kernel_shape')
        if auto_pad != 'NOTSET' or ceil_mode != 0 or count_include_pad != 0:
            raise NotImplementedError('AveragePool with auto_pad, ceil_mode, or count_include_pad is not supported yet.')

        x = inputs[0]
        if len(x.shape) == 4:
            pads = list(self.attrs.get('pads', [0, 0, 0, 0]))
            strides = list(self.attrs.get('strides', [1, 1]))
            x = ops.avg_pool2d(x, kernel_shape, strides, pads)
        elif len(x.shape) == 5:
            pads = list(self.attrs.get('pads', [0, 0, 0, 0, 0, 0]))
            strides = list(self.attrs.get('strides', [1, 1, 1]))
            x = ops.avg_pool3d(x, kernel_shape, strides, pads)
        else:
            raise NotImplementedError('Currently only 2D and 3D average pooling are supported')
        return [x]


@register_onnx_operator
class OnnxClip(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        (x,) = inputs
        min_value = self.attrs.get('min', None)
        max_value = self.attrs.get('max', None)
        x = ops.clip(x, min_value, max_value)
        return [x]

    def run_v11(self, inputs: List[Tensor]) -> List[Tensor]:
        data, min_value, max_value = self.optional_inputs(inputs, requires=[True, False, False])
        if min_value is not None:
            min_value = self.tensor2scalar(min_value)
        if max_value is not None:
            max_value = self.tensor2scalar(max_value)
        return [ops.clip(data, min_value, max_value)]


@register_onnx_operator
class OnnxEqual(OnnxOperator):
    def run_v11(self, inputs: List[Tensor]) -> List[Tensor]:
        a, b = inputs
        return [ops.equal(a, b)]


@register_onnx_operator
class OnnxLess(OnnxOperator):
    def run_v9(self, inputs: List[Tensor]) -> List[Tensor]:
        a, b = inputs
        return [ops.less(a, b)]


@register_onnx_operator
class OnnxGreater(OnnxOperator):
    def run_v7(self, inputs: List[Tensor]) -> List[Tensor]:
        a, b = inputs
        return [ops.greater(a, b)]


@register_onnx_operator
class OnnxGreaterOrEqual(OnnxOperator):
    def run_v12(self, inputs: List[Tensor]) -> List[Tensor]:
        a, b = inputs
        return [ops.greater_equal(a, b)]


@register_onnx_operator
class OnnxLessOrEqual(OnnxOperator):
    def run_v12(self, inputs: List[Tensor]) -> List[Tensor]:
        a, b = inputs
        return [ops.less_equal(a, b)]


@register_onnx_operator
class OnnxWhere(OnnxOperator):
    def run_v9(self, inputs: List[Tensor]) -> List[Tensor]:
        cond, a, b = inputs
        return [ops.where(cond, a, b)]


@register_onnx_operator
class OnnxSplit(OnnxOperator):
    def run_v2(self, inputs: List[Tensor]) -> List[Tensor]:
        axis = self.attrs.get('axis', 0)
        data = inputs[0]
        # the 'split' attribute is optional; when absent, split evenly among the outputs
        num_outputs = len(self.output_names)
        parts = self.attrs.get('split', [data.shape[axis] // num_outputs] * num_outputs)
        return ops.split(data, parts, axis)

    def run_v13(self, inputs: List[Tensor]) -> List[Tensor]:
        data = inputs[0]
        axis = self.attrs.get('axis', 0)
        if len(inputs) == 1:
            num_outputs = len(self.output_names)
            extent = data.shape[axis]
            if extent % num_outputs != 0:
                raise ValueError(
                    'Can not split tensor with shape {} on axis {} into {} parts evenly.'.format(
                        data.shape, axis, num_outputs
                    )
                )
            parts = [extent // num_outputs] * num_outputs
        elif len(inputs) == 2:
            parts = self.tensor2list(inputs[1])
        else:
            raise ValueError(
                'Expect the Split operator to have 1 or 2 inputs, but got {}. See:\n'.format(len(inputs))
                + 'https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split'
            )
        return ops.split(data, parts, axis)


@register_onnx_operator
class OnnxReduceSum(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        axes = self.attrs['axes']
        keepdims = self.attrs.get('keepdims', True)
        data = inputs[0]
        return [ops.sum(data, dims=axes, keep_dim=keepdims)]

    def run_v13(self, inputs: List[Tensor]) -> List[Tensor]:
        keepdims = self.attrs.get('keepdims', True)
        noop_with_empty_axes = self.attrs.get('noop_with_empty_axes', False)
        data = inputs[0]
        if len(inputs) == 1:
            if noop_with_empty_axes:
                axes = []
            else:
                axes = list(range(len(data.shape)))
        else:
            axes = self.tensor2list(inputs[1])
        return [ops.sum(data, dims=axes, keep_dim=keepdims)]


@register_onnx_operator
class OnnxReduceMin(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        axes = self.attrs['axes']
        keepdims = self.attrs.get('keepdims', True)
        data = inputs[0]
        return [ops.min(data, dims=axes, keep_dim=keepdims)]


@register_onnx_operator
class OnnxReduceMax(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        axes = self.attrs['axes']
        keepdims = self.attrs.get('keepdims', True)
        data = inputs[0]
        return [ops.max(data, dims=axes, keep_dim=keepdims)]


@register_onnx_operator
class OnnxMax(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        # Max in early opsets is intentionally left unimplemented
        raise NotImplementedError()

    def run_v6(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.maximum(*inputs)]


@register_onnx_operator
class OnnxMin(OnnxOperator):
    def run_v6(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.minimum(*inputs)]


@register_onnx_operator
class OnnxReciprocal(OnnxOperator):
    def run_v6(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.reciprocal(inputs[0])]


@register_onnx_operator
class OnnxExp(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.exp(inputs[0])]


@register_onnx_operator
class OnnxLog(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.log(inputs[0])]


@register_onnx_operator
class OnnxNeg(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.negative(inputs[0])]


@register_onnx_operator
class OnnxIf(OnnxOperator):
    def __init__(self, node, op_sets: List[int]):
        super().__init__(node, op_sets)
        self.env_tensors: Dict[str, Tensor] = {}

    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        cond = inputs[0]
        if cond.storage is None:
            raise ValueError(
                'Hidet currently does not support dynamic control flow in computation graph'
                ' (If operator with condition that depends on non-const input).'
            )

        cond = cond.numpy().flatten()
        if cond.size > 1:
            raise ValueError('Condition in If operator can only have a single element.')
        if np.all(cond):
            graph = OnnxGraph(self.attrs['then_branch'], self.op_sets, self.env_tensors)
        else:
            graph = OnnxGraph(self.attrs['else_branch'], self.op_sets, self.env_tensors)
        return graph(*inputs[1:])


@register_onnx_operator
class OnnxNot(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.logical_not(inputs[0])]


@register_onnx_operator
class OnnxCumSum(OnnxOperator):
    def run_v11(self, inputs: List[Tensor]) -> List[Tensor]:
        x, axis = inputs
        axis = self.tensor2list(axis)
        exclusive = self.attrs.get('exclusive', False)
        reverse = self.attrs.get('reverse', False)
        return [ops.cumsum(x, axis, exclusive=exclusive, reverse=reverse)]


@register_onnx_operator
class OnnxIdentity(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return inputs


@register_onnx_operator
class OnnxPyFunc(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        warnings.warn(
            'PyFunc operator encountered in the ONNX model; dummy outputs are returned. '
            'If the dummy outputs are actually used, errors will follow.'
        )
        return [randn([1]) for _ in self.output_names]


@register_onnx_operator
class OnnxLeakyRelu(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        alpha = self.attrs.get('alpha', 0.01)
        return [ops.leaky_relu(inputs[0], alpha)]


@register_onnx_operator
class OnnxConvTranspose(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        from hidet.graph.ops.utils import normalize_stride

        data, weight = inputs[:2]
        if len(data.shape) != 4:
            raise ValueError('Currently, only 2D ConvTranspose is supported.')
        auto_pad: str = self.attrs.get('auto_pad', 'NOTSET')
        dilations: Union[int, List[int]] = self.attrs.get('dilations', 1)
        group: int = self.attrs.get('group', 1)
        output_padding: Union[int, List[int]] = self.attrs.get('output_padding', 0)
        output_shape: Optional[List[int]] = self.attrs.get('output_shape', None)
        pads: Union[int, List[int]] = self.attrs.get('pads', [0, 0, 0, 0])  # default: no padding
        strides: Union[int, List[int]] = self.attrs.get('strides', 1)

        if auto_pad != 'NOTSET':
            raise NotImplementedError('auto_pad {} is not supported yet.'.format(auto_pad))
        if output_shape is not None:
            raise NotImplementedError('output_shape is not supported yet.')
        if isinstance(dilations, int):
            dilations = [dilations] * 2
        if any(d != 1 for d in dilations):
            raise NotImplementedError('dilations {} is not supported yet.'.format(dilations))

        output_padding = normalize_stride(output_padding)

        if len(pads) == 4 and any(p < 0 for p in pads[2:]):
            # some upstream frameworks export onnx models with negative pads;
            # as a workaround, clamp them to zero and compensate via output_padding
            for i, p in enumerate(pads[2:]):
                if p < 0:
                    pads[2 + i] = 0
                    output_padding[i] += -p

        output = ops.conv2d_transpose(
            data, weight, stride=strides, padding=pads, groups=group, output_padding=output_padding
        )
        if len(inputs) > 2:
            bias: Tensor = inputs[2]  # 1D tensor added on channel axis
            output = output + ops.unsqueeze(bias, [0, 2, 3])
        return [output]


@register_onnx_operator
class OnnxPRelu(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.prelu(inputs[0], inputs[1])]


@register_onnx_operator
class OnnxAbs(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.abs(inputs[0])]


@register_onnx_operator
class OnnxAnd(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.logical_and(inputs[0], inputs[1])]


@register_onnx_operator
class OnnxBitShift(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        direction = self.attrs.get('direction', 'RIGHT')
        if direction == 'RIGHT':
            return [ops.bitwise_right_shift(inputs[0], inputs[1])]
        else:
            return [ops.bitwise_left_shift(inputs[0], inputs[1])]


@register_onnx_operator
class OnnxBitwiseAnd(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.bitwise_and(inputs[0], inputs[1])]


@register_onnx_operator
class OnnxBitwiseNot(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.bitwise_invert(inputs[0])]


@register_onnx_operator
class OnnxBitwiseOr(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.bitwise_or(inputs[0], inputs[1])]


@register_onnx_operator
class OnnxBitwiseXor(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.bitwise_xor(inputs[0], inputs[1])]


@register_onnx_operator
class OnnxCeil(OnnxOperator):
    def run(self, inputs: List[Tensor]) -> List[Tensor]:
        return [ops.ceil(inputs[0])]


@register_onnx_operator
class OnnxReduceL2(OnnxOperator):
    def run_v1(self, inputs: List[Tensor]) -> List[Tensor]:
        axes: Optional[List[int]] = self.attrs.get('axes', None)
        keepdims: int = self.attrs.get('keepdims', 1)
        assert len(inputs) == 1
        data: Tensor = inputs[0]
        rank = len(data.shape)
        if axes is None:
            axes = list(range(rank))
        axes: List[int] = [ops.utils.normalize_dim(axis, rank) for axis in axes]
        return [ops.sqrt(ops.sum(ops.square(data), axes, keep_dim=bool(keepdims)))]

    def run_v18(self, inputs: List[Tensor]) -> List[Tensor]:
        keepdims: int = self.attrs.get('keepdims', 1)
        noop_with_empty_axes: int = self.attrs.get('noop_with_empty_axes', 0)
        data, axes_tensor = self.optional_inputs(inputs, requires=[True, False])
        if axes_tensor is None:
            if noop_with_empty_axes:
                return [data]
            else:
                axes: List[int] = list(range(len(data.shape)))
        else:
            axes: List[int] = self.tensor2list(axes_tensor)
        return [ops.sqrt(ops.sum(ops.square(data), axes, keep_dim=bool(keepdims)))]


def dispatch(node, op_sets: List[int]) -> OnnxOperator:
    op_type = node.op_type
    if op_type not in dispatch_table:
        raise NotImplementedError(
            "Operator '{}' (in opset {}) from onnx is not supported yet.".format(op_type, op_sets)
        )
    op = dispatch_table[op_type](node, op_sets)
    return op


def dispatch_operators(nodes: Sequence[onnx.NodeProto], op_sets: List[int]) -> List[OnnxOperator]:
    dispatched: List[OnnxOperator] = []
    unsupported: Set[str] = set()

    for node in nodes:
        op_type: str = node.op_type
        if op_type not in dispatch_table:
            unsupported.add(op_type)
        else:
            op_cls: Type[OnnxOperator] = dispatch_table[op_type]
            dispatched.append(op_cls(node, op_sets))
    if len(unsupported) > 0:
        raise NotImplementedError("Operator(s) {} from onnx have not been supported yet.".format(list(unsupported)))
    return dispatched


def run_onnxruntime(node: OnnxOperator, inputs: List[Tensor]) -> List[Tensor]:
    # run a single onnx operator with onnxruntime to cross-check hidet's result
    # pylint: disable=no-member
    import onnxruntime

    hidet_outputs = node.run(inputs)
    inputs_value_info = [
        onnx.helper.make_value_info(
            name=name,
            type_proto=onnx.helper.make_tensor_type_proto(
                elem_type=utils.dtype_to_onnx(tensor.dtype), shape=tensor.shape
            ),
        )
        for name, tensor in zip(node.input_names, inputs)
    ]
    outputs_value_info = [
        onnx.helper.make_value_info(
            name=name,
            type_proto=onnx.helper.make_tensor_type_proto(
                elem_type=utils.dtype_to_onnx(tensor.dtype), shape=tensor.shape
            ),
        )
        for name, tensor in zip(node.output_names, hidet_outputs)
    ]
    graph = onnx.helper.make_graph(nodes=[node.node], name='test', inputs=inputs_value_info, outputs=outputs_value_info)
    model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", opset) for opset in node.op_sets])
    onnx.checker.check_model(model)
    serialized_model = model.SerializeToString()
    session = onnxruntime.InferenceSession(serialized_model, providers=['CPUExecutionProvider'])
    outputs = session.run(
        node.output_names, input_feed={name: tensor.cpu().numpy() for name, tensor in zip(node.input_names, inputs)}
    )
    return [hidet.asarray(output).cuda() for output in outputs]


class OnnxGraph(nn.Module):
    def __init__(self, graph: onnx.GraphProto, op_sets: List[int], env_tensors: Optional[Dict[str, Tensor]] = None):
        super().__init__()
        self.op_sets = op_sets
        self.name: str = graph.name
        for param in graph.initializer:
            numpy_array = onnx.numpy_helper.to_array(tensor=param)
            self._parameters[param.name] = from_numpy(numpy_array).cuda()
        self.input_names: List[str] = [input.name for input in graph.input if input.name not in self._parameters]
        self.output_names: List[str] = [output.name for output in graph.output]
        self.operators: List[OnnxOperator] = dispatch_operators(graph.node, op_sets)
        self.env_tensors: Dict[str, Tensor] = env_tensors if env_tensors else {}
        self.usage_count: Dict[str, int] = self.count_usage()

    def forward(self, *args):
        name2tensor = {"": None}  # onnx uses the empty name for omitted optional inputs
        if self.env_tensors:
            name2tensor.update(self.env_tensors)
        assert len(args) == len(self.input_names)
        # parameters
        for name, param in self._parameters.items():
            name2tensor[name] = param
        # inputs
        for name, inp in zip(self.input_names, args):
            name2tensor[name] = inp
        # run nodes

        log.info('start interpreting the onnx graph')

        usage_count = self.usage_count.copy()
        for operator in self.operators:
            for name in operator.input_names:
                if name not in name2tensor:
                    raise ValueError('Tensor "{}" is used before being produced.'.format(name))
            inputs = [name2tensor[name] for name in operator.input_names]
            if isinstance(operator, OnnxIf):
                operator.env_tensors = name2tensor
            outputs = operator.run(inputs)
            if not isinstance(outputs, (tuple, list)):
                raise ValueError(
                    'Operator "{}" should return a sequence of tensors, got {}.'.format(
                        operator.node.op_type, type(outputs)
                    )
                )

            check = False  # set to True to cross-check every operator against onnxruntime
            if check:
                outputs_ref = run_onnxruntime(operator, inputs)
                for a, b in zip(outputs, outputs_ref):
                    try:
                        np.testing.assert_allclose(a.cpu().numpy(), b.cpu().numpy(), atol=1e-3, rtol=1e-3)
                    except AssertionError as e:
                        print('Operator check failed: {:>20}'.format(operator.node.name))
                        raise e

            assert len(outputs) == len(operator.output_names)
            for name, tensor in zip(operator.output_names, outputs):
                name2tensor[name] = tensor
            for name in operator.input_names:
                if name not in self.env_tensors:
                    usage_count[name] -= 1
                    if usage_count[name] == 0:
                        # free memory
                        del name2tensor[name]

        # put outputs
        results = [name2tensor[name] for name in self.output_names]

        log.info('finished interpreting the onnx graph')

        return results

    def count_usage(self):
        usage_count = defaultdict(int)
        for op in self.operators:
            for input_name in op.input_names:
                usage_count[input_name] += 1
        for graph_output_name in self.output_names:
            usage_count[graph_output_name] += 1
        # todo: add the usage of sub graphs
        return usage_count
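    # For example, a tensor that feeds two operators and is not a graph output has
    # usage_count == 2; forward() frees it right after its second use.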


class OnnxModule(nn.Module):
    """Loaded ONNX model.

    Parameters
    ----------
    model: onnx.ModelProto
        The onnx model to load, in the protobuf format.

    Attributes
    ----------
    op_sets: List[int]
        The operator sets used by the loaded model.

    input_names: List[str]
        The input names of the loaded onnx model.

    output_names: List[str]
        The output names of the loaded onnx model.
    """

    def __init__(self, model: onnx.ModelProto):
        super().__init__()
        op_sets = []
        for opset_import in model.opset_import:
            if opset_import.domain not in ['', 'ai.onnx', 'ai.onnx.ml']:
                # we currently only support standard onnx operator domain
                raise ValueError(
                    'Onnx model imports an unknown operator domain: {}. Currently, only '
                    'the standard onnx operator sets are supported.'.format(repr(opset_import.domain))
                )
            op_sets.append(int(opset_import.version))
        self.op_sets: List[int] = list(reversed(sorted(op_sets)))
        self.graph: OnnxGraph = OnnxGraph(model.graph, op_sets=self.op_sets)
        self.input_names: List[str] = self.graph.input_names
        self.output_names: List[str] = self.graph.output_names

    def forward(self, *args):
        """Run the onnx model with given inputs.

        Parameters
        ----------
        args: Sequence[hidet.Tensor]
            The input tensors. The number and order of the input tensors should match the
            OnnxModule.input_names attribute.

        Returns
        -------
        ret: Union[hidet.Tensor, List[hidet.Tensor]]
            The output tensor(s). If two or more tensors are returned, they are given as a list
            ordered by OnnxModule.output_names; a single output tensor is returned directly
            (instead of a list).
        """
        results = self.graph(*args)
        if len(results) == 1:
            return results[0]
        else:
            return results

    def dict_forward(self, feed_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
        args = []
        for name in self.input_names:
            if name not in feed_dict:
                raise ValueError('Missing input: {}'.format(name))
            args.append(feed_dict[name])
        outputs = self.graph(*args)
        output_dict = {name: value for name, value in zip(self.output_names, outputs)}
        return output_dict


def from_onnx(model: Union[str, 'onnx.ModelProto']) -> OnnxModule:
    """
    Load an onnx model to hidet.graph.nn.Module.

    Parameters
    ----------
    model: Union[str, onnx.ModelProto]
        The path or model proto of the given onnx model.

    Returns
    -------
    ret: OnnxModule
        The loaded model.
    """
    if isinstance(model, str):
        model = os.path.expanduser(model)
        model = onnx.load_model(model, load_external_data=False)
    try:
        onnx.checker.check_model(model, full_check=True)
    except ValueError:
        # ignore 'ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB'
        pass
    except onnx.onnx_cpp2py_export.checker.ValidationError:  # pylint: disable=c-extension-no-member
        warnings.warn('The onnx model has not passed the onnx checker.')
    return OnnxModule(model)
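

# Usage sketch (illustrative; the path and input shape below are placeholders):
#
#     module = from_onnx('~/models/resnet50.onnx')
#     x = hidet.randn([1, 3, 224, 224]).cuda()
#     y = module(x)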