# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import ast as py_ast
import inspect
from types import FunctionType
from typing import Tuple, Optional, List, Any, Dict
from hidet.ir.expr import Var
from hidet.ir.func import Function
from hidet.ir.module import IRModule
from hidet.ir.type import FuncType, BaseType, func_type
from hidet.lang.transpiler import PythonToHidetTranslator
from hidet.runtime.compiled_module import CompiledModule
def eliminate_indent(source: str) -> Tuple[str, int]:
    """
    Remove the leading indentation that is common to the lines of a source snippet.

    Lines that are empty or contain only whitespace are ignored when the common indentation is measured.

    Parameters
    ----------
    source: str
        The source code, with its lines separated by a newline character.

    Returns
    -------
    ret: Tuple[str, int]
        The source with the first ``indent`` characters cut from every line, and ``indent`` itself.
    """
    rows = source.split('\n')
    # Width of the leading whitespace of every row that holds something besides whitespace.
    margins = [len(row) - len(row.lstrip()) for row in rows if row.strip()]
    # Without any such row nothing bounds the cut, so the length of the whole text is used.
    indent = min(margins, default=len(source))
    dedented = '\n'.join(row[indent:] for row in rows)
    return dedented, indent
def eliminate_decorators(source: str) -> Tuple[str, int]:
    """
    Remove the decorators at the beginning of a dedented function source.

    A decorator is recognized by an ``@`` in the first column of a statement. The decorators are delimited with the
    tokenizer rather than by looking at line prefixes only, so a decorator whose arguments span several physical
    lines is removed as a whole, and blank or comment lines between two decorators are removed together with them.

    Parameters
    ----------
    source: str
        The source code, with its lines separated by a newline character.

    Returns
    -------
    ret: Tuple[str, int]
        The source without its leading decorators, and the number of physical lines that were removed, i.e., the
        line offset of the returned source inside the given one.
    """
    # Imported locally because the tokenizer is needed by this function only.
    import io
    import tokenize

    lines = source.split('\n')
    num_removed = 0  # physical lines covered by the decorators recognized so far
    at_statement_start = True
    try:
        for token in tokenize.generate_tokens(io.StringIO(source).readline):
            if token.type in (tokenize.NL, tokenize.COMMENT):
                # Blank lines, comments and line breaks inside brackets do not terminate a statement.
                continue
            if at_statement_start:
                if token.type != tokenize.OP or token.string != '@' or token.start[1] != 0:
                    # Reached the first statement that is not a decorator.
                    break
                at_statement_start = False
            elif token.type == tokenize.NEWLINE:
                # End of the decorator's logical line: its row is the last physical line of the decorator.
                num_removed = token.start[0]
                at_statement_start = True
    except (tokenize.TokenError, SyntaxError):
        # The text can not be tokenized (e.g., unbalanced brackets). Fall back to counting the leading lines that
        # start with '@', one line per decorator.
        num_removed = 0
        for line in lines:
            if not line.startswith('@'):
                break
            num_removed += 1

    return '\n'.join(lines[num_removed:]), num_removed
def script(func: FunctionType) -> Function:
    """
    Decorator to convert a Python function to a Hidet function.

    The source code of ``func`` is retrieved with :mod:`inspect`, dedented, stripped of its leading decorators and
    parsed into a Python AST, which is handed to :class:`PythonToHidetTranslator`. The translator also receives
    the file and position of the function, the annotations of the function, and an environment made of a copy of
    the function's globals updated with the values of its closure variables. If a script module context is
    active, the translated function is appended to it.

    Parameters
    ----------
    func: FunctionType
        The python function to be converted to a Hidet function.

    Returns
    -------
    ret: Function
        The hidet.ir.Function that is converted from the given Python function.
    """
    # Extract the source code of given function
    lines, start_line = inspect.getsourcelines(func)
    file = inspect.getsourcefile(func)
    source = ''.join(lines)
    # Remember how much was cut from the left and from the top, so that the translator is given the position
    # of the parsed source inside the original file.
    source, col_offset = eliminate_indent(source)
    source, inc_lineno = eliminate_decorators(source)
    start_line += inc_lineno
    parsed: py_ast.AST = py_ast.parse(source=source)

    # Get the environment (globals and binding of free variables)
    # See the data model of python for the details of func.__globals__, func.__closure__ and func.__code__:
    # https://docs.python.org/3/reference/datamodel.html
    env: Dict[str, Any] = func.__globals__.copy()
    func_freevar_names: List[str] = list(func.__code__.co_freevars)
    func_freevar_cells: List[Any] = [v.cell_contents for v in func.__closure__] if func.__closure__ else []
    assert len(func_freevar_names) == len(func_freevar_cells)
    env.update(dict(zip(func_freevar_names, func_freevar_cells)))

    # get the type annotations of function parameters.
    func_annotations: Dict[str, Any] = func.__annotations__

    # Translate the Python function into Hidet function
    translator = PythonToHidetTranslator(
        file=file, start_lineno=start_line, start_column=col_offset, env=env, func_annotations=func_annotations
    )
    hidet_function = translator(parsed)
    # Check the result before it is registered anywhere.
    assert isinstance(hidet_function, Function)

    # add function to current script module if we are in a script module context
    ctx = ScriptModuleContext.current_context()
    if ctx is not None:
        ctx.append_function(hidet_function)
    return hidet_function
class ScriptModuleContext:
    """
    Collects the functions, global variables and extern function declarations of a script module.

    An instance is used as a context manager: entering it pushes the instance onto the ``contexts`` stack and
    leaving it pops the top of that stack. The innermost entered instance is available through
    :meth:`current_context`.
    """

    # stack of the contexts that have been entered and not left yet
    contexts: List[ScriptModuleContext] = []

    def __init__(self):
        # variables of the appended functions and of the defined global variables, indexed by name
        self.name2var: Dict[str, Var] = {}
        # the appended functions, in the order they were appended
        self.functions: List[Function] = []
        # variables of the declared extern functions, indexed by name
        self.extern_functions: Dict[str, Var] = {}

    def __enter__(self):
        """Push this context onto the stack of entered contexts and return it."""
        self.contexts.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Pop the top of the stack of entered contexts."""
        self.contexts.pop()

    @staticmethod
    def current_context() -> Optional[ScriptModuleContext]:
        """Return the most recently entered context, or None when no context has been entered."""
        stack = ScriptModuleContext.contexts
        if not stack:
            return None
        return stack[-1]

    def append_function(self, function: Function):
        """Add a function to this context and bind its name to a variable with the function's type."""
        self.functions.append(function)
        signature = FuncType.from_func(function)
        handle = Var(hint=None, type=signature, name=function.name)
        self.name2var[function.name] = handle

    def lookup(self, name: str) -> Optional[Var]:
        """Return the variable registered under the given name, or None if there is no such variable."""
        return self.name2var.get(name)

    def define_global_var(self, name: str, var_type: BaseType) -> Var:
        """
        Create a variable with the given name and type, register it under that name and return it.

        Raises
        ------
        ValueError
            If the name has been registered in this context.
        """
        if name in self.name2var:
            raise ValueError(f'Global variable {name} is already defined.')
        global_var = Var(hint=None, type=var_type, name=name)
        self.name2var[name] = global_var
        return global_var

    def declare_extern_func(self, name: str, param_types, ret_type):
        """
        Create a variable of function type for an extern function, register it under the given name and return it.

        Raises
        ------
        ValueError
            If an extern function with the given name has been declared in this context.
        """
        if name in self.extern_functions:
            raise ValueError(f'Extern function {name} is already declared.')
        declaration = Var(hint=None, name=name, type=func_type(param_types, ret_type))
        self.extern_functions[name] = declaration
        return declaration

    def ir_module(self) -> IRModule:
        """Create an IRModule from the functions, variables and extern functions collected in this context."""
        functions = {}
        for function in self.functions:
            functions[function.name] = function
        return IRModule(functions=functions, global_vars=self.name2var, extern_functions=self.extern_functions)

    def build(self) -> CompiledModule:
        """Create the IRModule of this context and return the result of its ``build`` method."""
        module = self.ir_module()
        return module.build()
def script_module() -> ScriptModuleContext:
    """
    Create a new script module context.

    Returns
    -------
    ret: ScriptModuleContext
        A newly constructed context; it has not been entered yet.
    """
    context = ScriptModuleContext()
    return context