Add PyTorch Operator Mapping

This guide describes how to add an operator mapping for PyTorch.

PyTorch nn.Module --(Step 1: PyTorch Dynamo)--> torch.fx.Graph --(Step 2: Operator mapping)--> hidet.FlowGraph --(Step 3: FlowGraph building)--> hidet.runtime.CompiledGraph

The workflow of the hidet backend for torch.compile(..., backend='hidet').

During step 2, we convert each PyTorch operator to a hidet operator. In a torch.fx.Graph, there are three kinds of operators that need to be converted:

  1. functions (e.g., torch.nn.functional.relu, torch.relu, operator.add, etc.)

  2. modules (e.g., torch.nn.ReLU, torch.nn.Linear, etc.)

  3. tensor methods (e.g., torch.Tensor.squeeze, torch.Tensor.to, etc.)

In this guide, we show how to add the operator mapping for all three kinds of operators.
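
To see where these three kinds come from, the short sketch below (not part of the original example; the Tiny module is defined only for illustration) traces a module with torch.fx and prints each node's kind and target. torch.compile builds its graphs with Dynamo rather than symbolic tracing, but the resulting torch.fx.Graph contains the same kinds of nodes.

import torch
from torch import nn


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)       # becomes a 'call_module' node

    def forward(self, x):
        x = self.linear(x)
        x = torch.nn.functional.relu(x)     # becomes a 'call_function' node
        return x.flatten()                  # becomes a 'call_method' node


for node in torch.fx.symbolic_trace(Tiny()).graph.nodes:
    print(node.op, node.target)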

1. Prepare Environment

First, we remove some existing operator mapping (i.e., conversion) rules for demonstration purposes, and define an example model.

import operator
import torch
from torch import nn

# hidet employs an interpreter to convert a fx.Graph to FlowGraph
from hidet.graph.frontend.torch.registry import Registry

# the following three modules register the conversion rules
import hidet.graph.frontend.torch.register_functions
import hidet.graph.frontend.torch.register_modules
import hidet.graph.frontend.torch.register_methods

# Before removing registered functions, make sure to
# call allow_in_graph_registered_funcs_only() by importing dynamo_backends
import hidet.graph.frontend.torch.dynamo_backends

# we remove the rules for the following operators for demonstration purposes
# we will add them back later
del Registry.registered_functions[torch.nn.functional.relu]
del Registry.registered_functions[operator.add]
del Registry.registered_modules[torch.nn.Linear]
del Registry.registered_methods[torch.Tensor.flatten]


class Model(nn.Module):
    """a model used nn.Linear, nn.functional.relu, operator.add and Tensor.flatten"""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        x = self.linear(x)
        x = torch.nn.functional.relu(x)
        x = x + x
        return x.flatten()
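
As a quick sanity check (a small sketch, not part of the original script), we can confirm that the rules were removed by probing the Registry tables imported above, which behave like dictionaries keyed by the torch callables:

print(torch.nn.functional.relu in Registry.registered_functions)  # False after the del above
print(operator.add in Registry.registered_functions)              # False after the del above
print(torch.nn.Linear in Registry.registered_modules)             # False after the del above
print(torch.Tensor.flatten in Registry.registered_methods)        # False after the del above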

2. Compile and Run the Model

If we compile and run the model, we will get an error that complains about the missing conversion rules for torch.nn.Linear, torch.nn.functional.relu and operator.add.

def run_model():
    model = Model().cuda()
    model_opt = torch.compile(model, backend='hidet', mode='max-autotune')

    x = torch.randn(10, 10, device='cuda')
    y1 = model_opt(x)
    y2 = model(x)
    torch.testing.assert_close(actual=y1, expected=y2, atol=3e-3, rtol=3e-3)
    print('success!')


try:
    run_model()
except Exception as e:
    print(e)
backend='hidet' raised:
NotImplementedError: The following operators are not supported or mapped by hidet yet:
  operator.add
  torch.nn.functional.relu
Please see the following guide to add the conversion rules:
  https://docs.hidet.org/stable/gallery/developer-guides/add-torch-operator-mapping.html
You are also welcome to submit a PR or an issue with reproducible script to:
  https://github.com/hidet-org/hidet
Thanks for your contribution!

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

3. Add Operator Mappings

from typing import Optional
from hidet import ops
from hidet import Tensor
from hidet.graph.frontend.torch.registry import (
    register_function,
    register_module,
    register_method,
    HidetModule,
)


# register the conversion rule for torch.nn.functional.relu
@register_function(torch.nn.functional.relu)
def torch_relu(x: Tensor, inplace: bool = False):  # the signature must match the original function
    # the parameter `x` is hidet.Tensor instead of torch.Tensor
    # we also need to return a hidet.Tensor instead of torch.Tensor
    _ = inplace  # ignore inplace
    return ops.relu(x)


@register_function(operator.add)
def operator_add(x: Tensor, y: Tensor):
    return ops.add(x, y)


@register_module(torch.nn.Linear)
class HidetLinear(
    HidetModule
):  # HidetModule is a tool class that helps us to convert a torch.nn.Module
    def __init__(self, torch_module: torch.nn.Module):
        super().__init__(torch_module)
        # inside the class, we can access the parameter of the torch module via
        # `self.param(name: str, optional: bool = False) -> Tensor`
        # and the returned tensor is a hidet.Tensor
        self.transposed_weight: Tensor = ops.transpose(self.param('weight'), [1, 0])
        self.bias: Optional[Tensor] = self.param('bias', optional=True)

    def __call__(self, x: Tensor) -> Tensor:
        # similarly, the parameter `x` is hidet.Tensor instead of torch.Tensor
        y = ops.matmul(x, self.transposed_weight)
        if self.bias is not None:
            y = y + self.bias
        return y
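
Before recompiling the whole model, we can exercise the newly registered function rules directly on a hidet tensor. This is only a sketch: it assumes a CUDA device is available (as in the rest of this guide) and that register_function returns the decorated function unchanged, the usual decorator pattern; if it does not, the rules can instead be looked up in Registry.registered_functions.

import hidet

xh = hidet.randn([2, 3], device='cuda')  # a hidet.Tensor, not a torch.Tensor
print(torch_relu(xh))                    # the rule registered for torch.nn.functional.relu
print(operator_add(xh, xh))              # the rule registered for operator.add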

If we run the model again, it will complain about the missing conversion rule for torch.Tensor.flatten. It did not complain about torch.Tensor.flatten before because the class that owns a tensor method (i.e., torch.Tensor) cannot be determined until the model is actually run.

try:
    run_model()
except Exception as e:
    print(e)
backend='hidet' raised:
NotImplementedError: The following operators are not supported or mapped by hidet yet:
  torch.Tensor.flatten
Please see the following guide to add the conversion rules:
  https://docs.hidet.org/stable/gallery/developer-guides/add-torch-operator-mapping.html
You are also welcome to submit a PR or an issue with reproducible script to:
  https://github.com/hidet-org/hidet
Thanks for your contribution!

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Thus, we also need to add the conversion rule for torch.Tensor.flatten.

@register_method(torch.Tensor.flatten)
def tensor_flatten(self: Tensor, start_dim=0, end_dim=-1):
    return ops.flatten(self, start_dim=start_dim, end_dim=end_dim)


run_model()
Gen IR: 100%|██████████████████████████| 36288/36288 [00:02<00:00, 14749.86it/s]

Appling fusing: 100%|████████████████████████| 701/701 [00:06<00:00, 108.43it/s]

Compiling: 100%|████████████████████████████████| 32/32 [01:09<00:00,  2.18s/it]
success!

We put all the registration code in the following three modules:

  1. hidet.graph.frontend.torch.register_functions (all the functions in torch.nn.functional.* and operator.*)

  2. hidet.graph.frontend.torch.register_modules (all the modules in torch.nn.*)

  3. hidet.graph.frontend.torch.register_methods (all the methods in torch.Tensor.*)

Many operators have already been registered in these three modules, and they are good examples to learn from when adding new operator mappings.
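
For example, assuming the Registry tables are ordinary dictionaries (as the del statements in Section 1 suggest), we can get a rough idea of the coverage and look up the rule we registered above:

print(len(Registry.registered_functions))   # number of registered function rules
print(len(Registry.registered_modules))     # number of registered module rules
print(len(Registry.registered_methods))     # number of registered method rules
print(Registry.registered_functions[torch.nn.functional.relu])  # the entry added in Section 3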

Usually, we use the existing operators in hidet (defined in hidet.ops.*) to implement the PyTorch operators. If there is no corresponding operator in hidet, we can add the missing operator to hidet.ops.* by following the guide Add New Operator.
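
For instance, the functional counterpart of the nn.Linear module above can be mapped by composing existing hidet operators. hidet almost certainly already ships a rule for torch.nn.functional.linear, so the sketch below is purely illustrative of the composition pattern; re-registering it would simply override the built-in rule (see the note below).

@register_function(torch.nn.functional.linear)
def torch_linear(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    # F.linear(x, weight, bias) computes x @ weight^T + bias
    y = ops.matmul(x, ops.transpose(weight, [1, 0]))
    if bias is not None:
        y = y + bias
    return y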

Note

The operator mapping rules are registered in the global registry. Thus, if we register the same operator mapping rules multiple times, only the last registration will take effect.
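
For example (a minimal sketch), registering operator.add a second time simply replaces the rule defined in Section 3:

@register_function(operator.add)
def operator_add_override(x: Tensor, y: Tensor):
    # this registration replaces the earlier rule for operator.add
    return ops.add(x, y)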

4. Summary

In this guide, we showed how to add operator mappings for PyTorch. We first removed some existing operator mapping rules for demonstration purposes and then added them back, covering all three kinds of operators: functions, modules, and tensor methods.

Total running time of the script: (2 minutes 27.685 seconds)
