# Source code for hidet.graph.frontend.torch
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .availability import available, dynamo_available, imported
from . import utils
from .dynamo_config import dynamo_config, DynamoConfig
def from_torch(module, concrete_args=None):
    """
    Convert a torch.nn.Module or torch.fx.GraphModule to a hidet.nn.Module.

    Parameters
    ----------
    module: torch.nn.Module or torch.fx.GraphModule
        The torch module to convert. A torch.fx.GraphModule is used as is; any other torch.nn.Module is first
        symbolically traced with torch.fx.

    concrete_args: Dict[str, Any] or None
        The concrete arguments to the module. If provided, will be used to make some arguments concrete during symbolic
        tracing. Ignored when `module` is already a torch.fx.GraphModule.

    Returns
    -------
    ret: Interpreter
        The converted hidet module, which is a subclass of hidet.nn.Module.

    Raises
    ------
    RuntimeError
        If torch is not available.
    ValueError
        If `module` is neither a torch.nn.Module nor a torch.fx.GraphModule.
    """
    # Check availability *before* importing torch (and the register_* modules that depend on it); otherwise a
    # missing torch surfaces as an ImportError from the import statements and this guard is never reached.
    if not available():
        raise RuntimeError('torch is not available.')

    import torch

    # Imported for their side effects only.
    from . import register_functions, register_modules, register_methods  # pylint: disable=unused-import
    from .interpreter import Interpreter

    # GraphModule must be tested first: the original code does the same, and only the second branch traces.
    if isinstance(module, torch.fx.GraphModule):
        graph_module = module
    elif isinstance(module, torch.nn.Module):
        graph_module = torch.fx.symbolic_trace(module, concrete_args=concrete_args)
    else:
        raise ValueError(f'Current only support import torch.nn.Module and torch.fx.GraphModule, got {type(module)}.')
    return Interpreter(graph_module)
def register_dynamo_backends():
    """
    Print a notice that hidet registers its dynamo backend through the entry_points mechanism.

    This function performs no registration itself; it only tells the user that the explicit call to
    `hidet.frontend.torch.register_dynamo_backends()` can be removed from their code.
    """
    notice_lines = (
        'Now, hidet will use the entry_points mechanism to register as a dynamo backend. ',
        'Feel free to remove the line `hidet.frontend.torch.register_dynamo_backends()` in your code.',
    )
    print('\n'.join(notice_lines))