# Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.from.availabilityimportavailable,dynamo_available,imported# from . import utilsfrom.dynamo_configimportdynamo_config,DynamoConfig
def from_torch(module, concrete_args=None):
    """
    Convert a torch.nn.Module or torch.fx.GraphModule into a hidet Interpreter.

    Parameters
    ----------
    module: torch.nn.Module or torch.fx.GraphModule
        The torch module to convert. A torch.fx.GraphModule is used as-is; any
        other torch.nn.Module is first traced with torch.fx.symbolic_trace.

    concrete_args: Dict[str, Any] or None
        The concrete arguments to the module. If provided, they are forwarded to
        torch.fx.symbolic_trace to make some arguments concrete during symbolic
        tracing. Ignored when `module` is already a torch.fx.GraphModule.

    Returns
    -------
    ret: Interpreter
        The converted hidet module, constructed from the (traced) graph module.

    Raises
    ------
    RuntimeError
        If torch is not available.
    ValueError
        If `module` is neither a torch.nn.Module nor a torch.fx.GraphModule.
    """
    # Check availability BEFORE importing torch (or the torch-dependent submodules
    # below): otherwise a missing torch surfaces as a bare ImportError and this
    # friendlier error is never reached.
    # NOTE(review): presumes available() reports whether torch is importable, as
    # the error message suggests — confirm against .availability.
    if not available():
        raise RuntimeError('torch is not available.')

    import torch

    # Imported for their side effects (registration); the names are not used here.
    from . import register_functions, register_modules, register_methods  # pylint: disable=unused-import
    from .interpreter import Interpreter

    # GraphModule is itself an nn.Module, so it must be tested first to avoid re-tracing.
    if isinstance(module, torch.fx.GraphModule):
        graph_module = module
    elif isinstance(module, torch.nn.Module):
        graph_module = torch.fx.symbolic_trace(module, concrete_args=concrete_args)
    else:
        raise ValueError(f'Current only support import torch.nn.Module and torch.fx.GraphModule, got {type(module)}.')
    return Interpreter(graph_module)
def register_dynamo_backends():
    """
    Legacy entry point retained so existing callers keep working.

    It registers nothing itself; it only prints a notice that hidet now
    registers as a dynamo backend through the entry_points mechanism, and that
    the call to this function can be removed from user code.
    """
    notice = (
        'Now, hidet will use the entry_points mechanism to register as a dynamo backend. \n'
        'Feel free to remove the line `hidet.frontend.torch.register_dynamo_backends()` in your code.'
    )
    print(notice)