Source code for hidet.lang.script

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import ast as py_ast
import inspect
from types import FunctionType
from typing import Tuple, Optional, List, Any, Dict

from hidet.ir.expr import Var
from hidet.ir.func import Function
from hidet.ir.module import IRModule
from hidet.ir.type import FuncType, BaseType, func_type
from hidet.lang.transpiler import PythonToHidetTranslator
from hidet.runtime.compiled_module import CompiledModule


def eliminate_indent(source: str) -> Tuple[str, int]:
    """
    Remove the indentation shared by all non-blank lines of a piece of source code.

    Indentation is measured in characters of leading whitespace (a tab counts as one character).

    Parameters
    ----------
    source: str
        The source code, possibly indented (e.g. a method or a nested function).

    Returns
    -------
    ret: Tuple[str, int]
        The source with the common indentation sliced off every line, and the number of characters
        that were removed. If every line is blank, the whole length of ``source`` is used as the indentation.
    """
    lines = source.split('\n')
    # Leading-whitespace width of each line that has visible content; blank lines do not constrain the indent.
    widths = [len(line) - len(line.lstrip()) for line in lines if line.strip()]
    indent = min(widths, default=len(source))
    dedented = '\n'.join(line[indent:] for line in lines)
    return dedented, indent


def eliminate_decorators(source: str) -> Tuple[str, int]:
    """
    Remove the decorator lines that precede a (dedented) function definition.

    Parameters
    ----------
    source: str
        The dedented source code of a function, as returned by ``inspect.getsourcelines`` (which starts at
        the first decorator for decorated functions).

    Returns
    -------
    ret: Tuple[str, int]
        The source without the leading decorator lines, and the number of lines removed. The count is
        used by the caller to shift line numbers back to the original file.
    """
    lines = source.split('\n')

    # Fallback / lower bound: the leading lines that start with '@'. This covers simple one-line
    # decorators and is used as-is when the source cannot be parsed.
    num_lines = 0
    for line in lines:
        if line.startswith('@'):
            num_lines += 1
        else:
            break

    # Prefer the AST: since Python 3.8 the lineno of a (Async)FunctionDef is the line of the `def`
    # keyword, so everything before it can be dropped. That includes continuation lines of multi-line
    # decorator calls and comments between decorators, which the prefix scan above misses. On older
    # interpreters lineno points at the first decorator, so max() falls back to the prefix scan.
    try:
        module = py_ast.parse(source)
    except SyntaxError:
        module = None
    if module is not None and module.body:
        node = module.body[0]
        if isinstance(node, (py_ast.FunctionDef, py_ast.AsyncFunctionDef)) and node.decorator_list:
            num_lines = max(num_lines, node.lineno - 1)

    return '\n'.join(lines[num_lines:]), num_lines


def _extract_source(func: FunctionType) -> Tuple[str, Optional[str], int, int]:
    """
    Extract the source code of a Python function in a form that ``ast.parse`` accepts on its own.

    Parameters
    ----------
    func: FunctionType
        The python function whose source is extracted.

    Returns
    -------
    ret: Tuple[str, Optional[str], int, int]
        The dedented source with leading decorator lines removed, the path of the source file (or None if
        it cannot be determined), the line number in that file of the first remaining source line, and the
        number of indentation columns that were removed.
    """
    lines, start_line = inspect.getsourcelines(func)
    file = inspect.getsourcefile(func)
    source, col_offset = eliminate_indent(''.join(lines))
    source, inc_lineno = eliminate_decorators(source)
    # The removed decorator lines shift the start of the parsed source further down the file.
    return source, file, start_line + inc_lineno, col_offset


def _build_env(func: FunctionType) -> Dict[str, Any]:
    """
    Build the name environment seen by the body of ``func``: a copy of its module globals, overridden by
    the current values of its free (closure) variables.

    See the data model of python for the details of func.__globals__, func.__closure__ and func.__code__:
    https://docs.python.org/3/reference/datamodel.html
    """
    env: Dict[str, Any] = func.__globals__.copy()
    freevar_names: Tuple[str, ...] = func.__code__.co_freevars
    cells: Tuple[Any, ...] = func.__closure__ or ()
    assert len(freevar_names) == len(cells)
    for name, cell in zip(freevar_names, cells):
        try:
            env[name] = cell.cell_contents
        except ValueError:
            # The cell is empty: the enclosing scope has not assigned this name yet (e.g. a function that is
            # defined later in the same enclosing function). Python never resolves a free variable to a
            # global, so do not let a global with the same name stand in for it either.
            env.pop(name, None)
    return env


def script(func: FunctionType) -> Function:
    """
    Decorator to convert a Python function to a Hidet function.

    When called while a script module context is active (see ``script_module``), the converted function is
    also appended to that context.

    Parameters
    ----------
    func: FunctionType
        The python function to be converted to a Hidet function.

    Returns
    -------
    ret: Function
        The hidet.ir.Function that is converted from the given Python function.

    Raises
    ------
    OSError
        If the source code of ``func`` cannot be retrieved (raised by ``inspect.getsourcelines``).
    """
    # Extract the source code of the given function and parse it.
    source, file, start_line, col_offset = _extract_source(func)
    parsed: py_ast.AST = py_ast.parse(source=source)

    # Get the environment (globals and binding of free variables).
    env = _build_env(func)

    # Get the annotations of the function (parameters and return).
    func_annotations: Dict[str, Any] = func.__annotations__

    # Translate the Python function into Hidet function.
    translator = PythonToHidetTranslator(
        file=file, start_lineno=start_line, start_column=col_offset, env=env, func_annotations=func_annotations
    )
    hidet_function = translator(parsed)
    # Validate the translator's result before registering it, so a bad result never enters a script module.
    assert isinstance(hidet_function, Function)

    # Add function to current script module if we are in a script module context.
    ctx = ScriptModuleContext.current_context()
    if ctx is not None:
        ctx.append_function(hidet_function)
    return hidet_function


class ScriptModuleContext:
    """
    Collects Hidet functions, global variables and extern function declarations created inside a ``with``
    block so they can be assembled into a single IRModule.

    Active contexts are kept on a class-level stack, so contexts can be nested; the innermost one is returned
    by ``current_context``.
    """

    # Stack of active contexts shared by all instances; the last element is the innermost context.
    contexts: List[ScriptModuleContext] = []

    def __init__(self):
        # name -> Var for every appended function and every defined global variable.
        self.name2var: Dict[str, Var] = {}
        # Appended functions, in insertion order.
        self.functions: List[Function] = []
        # name -> Var (with a function type) for every declared extern function.
        self.extern_functions: Dict[str, Var] = {}

    def __enter__(self):
        self.contexts.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised inside the ``with`` block keeps propagating.
        self.contexts.pop()

    @staticmethod
    def current_context() -> Optional[ScriptModuleContext]:
        """Return the innermost active context, or None when no context is active."""
        contexts = ScriptModuleContext.contexts
        return contexts[-1] if contexts else None

    def append_function(self, function: Function):
        """Record ``function`` and create a function-typed Var under its name."""
        self.functions.append(function)
        # NOTE(review): unlike define_global_var, an existing entry with the same name is silently replaced.
        self.name2var[function.name] = Var(hint=None, type=FuncType.from_func(function), name=function.name)

    def lookup(self, name: str) -> Optional[Var]:
        """Return the Var registered under ``name`` (function or global variable), or None if there is none."""
        return self.name2var.get(name)

    def define_global_var(self, name: str, var_type: BaseType) -> Var:
        """
        Define a global variable named ``name`` with type ``var_type`` and return its Var.

        Raises
        ------
        ValueError
            If ``name`` is already used by a global variable or an appended function in this context.
        """
        if name in self.name2var:
            raise ValueError(f'Global variable {name} is already defined.')
        var = Var(hint=None, type=var_type, name=name)
        self.name2var[name] = var
        return var

    def declare_extern_func(self, name: str, param_types, ret_type):
        """
        Declare an extern function with the given parameter and return types and return its Var.

        Raises
        ------
        ValueError
            If an extern function named ``name`` has already been declared in this context.
        """
        if name in self.extern_functions:
            raise ValueError(f'Extern function {name} is already declared.')
        var = Var(hint=None, name=name, type=func_type(param_types, ret_type))
        self.extern_functions[name] = var
        return var

    def ir_module(self) -> IRModule:
        """Assemble the collected functions, global variables and extern functions into an IRModule."""
        return IRModule(
            functions={func.name: func for func in self.functions},
            # name2var also holds the Vars that append_function created for the functions.
            global_vars=self.name2var,
            extern_functions=self.extern_functions,
        )

    def build(self) -> CompiledModule:
        """Build the IRModule of this context into a CompiledModule."""
        return self.ir_module().build()


def script_module() -> ScriptModuleContext:
    """Create a new ScriptModuleContext, intended to be used as ``with script_module() as module: ...``."""
    return ScriptModuleContext()