# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional, Sequence, Iterator, Dict, Any, Generic, TypeVar
from collections import OrderedDict
from hidet.graph.tensor import symbol_like
from hidet.graph.flow_graph import FlowGraph, trace_from
from hidet.graph.tensor import Tensor
# forward method return type
R = TypeVar('R')
[docs]class Module(Generic[R]):
def __init__(self):
self.name = None
self._parameters: OrderedDict[str, Optional[Tensor]] = OrderedDict()
self._submodules: OrderedDict[str, Optional[Module]] = OrderedDict()
def __setattr__(self, key, value):
if key in ['name', '_submodules', '_parameters']:
super().__setattr__(key, value)
return
parameters = self.__dict__.get('_parameters')
submodules = self.__dict__.get('_submodules')
if key in parameters:
del self._parameters[key]
elif key in submodules:
del self._submodules[key]
elif key in self.__dict__:
del self.__dict__[key]
if isinstance(value, Tensor):
parameters[key] = value
elif isinstance(value, Module):
submodules[key] = value
else:
self.__dict__[key] = value
cnt = sum(1 for collection in [parameters, submodules, self.__dict__] if collection and key in collection)
assert cnt <= 1, 'duplicated definition of {}'.format(key)
def __getattr__(self, item):
if item == '_parameters':
return super().__getattribute__(item)
if item == '_submodules':
return super().__getattribute__(item)
if item in self._parameters:
return self._parameters[item]
if item in self._submodules:
return self._submodules[item]
raise AttributeError(item)
def __str__(self):
lines = []
args_lines = self.extra_str().split('\n')
lines.extend([line for line in args_lines if len(line) > 0])
for key, submodule in self._submodules.items():
substr = str(submodule)
sub_lines = substr.split('\n')
sub_lines[0] = '({}): {}'.format(key, sub_lines[0])
lines.extend(sub_lines)
indent = 2
name = self.__class__.__name__
if len(lines) <= 1:
return '{}({})'.format(name, '\n'.join(lines))
else:
lines = [' ' * indent + line for line in lines]
return '{}(\n{}\n)'.format(name, '\n'.join(lines))
[docs] def __call__(self, *args, **kwargs) -> R:
return self.forward(*args, **kwargs)
def state_dict(self) -> Dict[str, Any]:
state_dict = OrderedDict()
for name, parameter in self.named_parameters():
state_dict[name] = parameter
return state_dict
def load_state_dict(self, state_dict: Dict[str, Any]):
for name, parameter in self.named_parameters():
parameter.copy_(state_dict[name])
def extra_str(self) -> str:
return ''
def forward(self, *args, **kwargs) -> R:
raise NotImplementedError()
def parameters(self, recursive: bool = True) -> Iterator[Tensor]:
for _, parameter in self.named_parameters(recursive=recursive):
yield parameter
def named_parameters(self, prefix='', recursive=True):
for name, parameter in self._parameters.items():
yield name, parameter
if recursive:
for module_name, submodule in self._submodules.items():
for name, parameter in submodule.named_parameters(prefix, recursive):
param_name = '{}{}.{}'.format(prefix + '.' if prefix else '', module_name, name)
yield param_name, parameter
def flow_graph_for(self, inputs: Sequence[Tensor]) -> FlowGraph:
symbol_inputs = []
for arg in inputs:
if isinstance(arg, Tensor):
symbol_inputs.append(symbol_like(arg))
else:
raise ValueError('Currently only support Tensor as input when automatically creating flow_graph.')
symbol_outputs = self.forward(*symbol_inputs)
return trace_from(symbol_outputs, symbol_inputs)
def cpu(self) -> Module:
for name, submodule in self._submodules.items():
submodule.cpu()
for name, parameter in self._parameters.items():
self._parameters[name] = parameter.cpu()
return self
def cuda(self) -> Module:
for name, submodule in self._submodules.items():
submodule.cuda()
for name, parameter in self._parameters.items():
self._parameters[name] = parameter.cuda()
return self
def hip(self) -> Module:
for name, submodule in self._submodules.items():
submodule.hip()
for name, parameter in self._parameters.items():
self._parameters[name] = parameter.hip()
return self
def to(self, dtype=None, device=None) -> Module:
for name, submodule in self._submodules.items():
submodule.to(dtype, device)
for name, parameter in self._parameters.items():
self._parameters[name] = parameter.to(dtype, device)
return self