# Source code for hidet.lang.script

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import ast as py_ast
import inspect
from types import FunctionType
from typing import Tuple, Optional, List, Any, Dict

from hidet.ir.expr import Var
from hidet.ir.func import Function
from hidet.ir.module import IRModule
from hidet.ir.type import FuncType, BaseType, func_type
from hidet.lang.transpiler import PythonToHidetTranslator
from hidet.runtime.compiled_module import CompiledModule

def eliminate_indent(source: str) -> Tuple[str, int]:
    """
    Remove the indentation shared by all non-blank lines of the given source.

    Parameters
    ----------
    source: str
        The source code, e.g. of a function defined inside a class or another function.

    Returns
    -------
    ret: Tuple[str, int]
        The dedented source, and the number of leading characters removed from each line
        (0 when the source contains no non-blank line).
    """
    lines = source.split('\n')
    # Blank (or whitespace-only) lines must not influence the common indentation.
    indents = [len(line) - len(line.lstrip()) for line in lines if line.strip()]
    indent = min(indents, default=0)
    # Whitespace-only lines shorter than `indent` simply become empty.
    source = '\n'.join(line[indent:] for line in lines)
    return source, indent

def eliminate_decorators(source: str) -> Tuple[str, int]:
    """
    Remove the decorator lines that precede a (dedented) function definition.

    Parameters
    ----------
    source: str
        The dedented source code of a single definition, possibly starting with decorators.

    Returns
    -------
    ret: Tuple[str, int]
        The source starting at the definition line, and the number of lines removed from the
        front (callers add it to the starting line number to keep positions accurate).
    """
    lines = source.split('\n')
    num_removed: Optional[int] = None
    try:
        module = py_ast.parse(source)
    except (SyntaxError, ValueError):
        module = None
    if module is not None and module.body:
        first = module.body[0]
        if isinstance(first, (py_ast.FunctionDef, py_ast.AsyncFunctionDef, py_ast.ClassDef)):
            # Since Python 3.8 the ``lineno`` of a decorated definition is the ``def``/``class``
            # line itself, so every line before it belongs to the decorators, including the
            # continuation lines of multi-line decorator calls. Lines inside the body are never
            # counted, even if they start with '@' (e.g. inside a multi-line string).
            num_removed = first.lineno - 1
    if num_removed is None:
        # Fallback for source that does not parse: strip only the leading run of '@' lines.
        num_removed = 0
        while num_removed < len(lines) and lines[num_removed].startswith('@'):
            num_removed += 1
    source = '\n'.join(lines[num_removed:])
    return source, num_removed

def script(func: FunctionType) -> Function:
    """
    Decorator to convert a Python function to a Hidet function.

    Parameters
    ----------
    func: FunctionType
        The python function to be converted to a Hidet function.

    Returns
    -------
    ret: Function
        The hidet function that is converted from the given Python function. If a script module
        context is active, the function is also appended to it.
    """
    # Extract the source code of the given function.
    lines, start_line = inspect.getsourcelines(func)
    file = inspect.getsourcefile(func)
    source = ''.join(lines)
    # Strip the common indentation and the decorator lines so the source parses on its own,
    # keeping the removed column/line counts so positions still refer to the original file.
    source, col_offset = eliminate_indent(source)
    source, inc_lineno = eliminate_decorators(source)
    start_line += inc_lineno
    parsed: py_ast.AST = py_ast.parse(source=source)

    # Get the environment (globals and binding of free variables).
    # See the "Data model" chapter of the Python language reference for the details of
    # func.__globals__, func.__closure__ and func.__code__.
    env: Dict[str, Any] = func.__globals__.copy()
    func_freevar_names: List[str] = list(func.__code__.co_freevars)
    func_freevar_cells: List[Any] = [v.cell_contents for v in func.__closure__] if func.__closure__ else []
    assert len(func_freevar_names) == len(func_freevar_cells)
    # Free variables shadow globals of the same name, matching Python's own name resolution.
    env.update(dict(zip(func_freevar_names, func_freevar_cells)))

    # Get the type annotations of function parameters.
    func_annotations: Dict[str, Any] = func.__annotations__

    # Translate the Python function into a Hidet function.
    translator = PythonToHidetTranslator(
        file=file, start_lineno=start_line, start_column=col_offset, env=env, func_annotations=func_annotations
    )
    hidet_function = translator(parsed)
    # Validate the result before registering it, so a bad translation never enters the module.
    assert isinstance(hidet_function, Function)

    # Add the function to the current script module if we are in a script module context.
    ctx = ScriptModuleContext.current_context()
    if ctx is not None:
        ctx.append_function(hidet_function)
    return hidet_function
class ScriptModuleContext:
    """
    Collects hidet functions, global variables and extern function declarations so they can be
    assembled into a single IRModule.

    Contexts form a stack shared by all instances. Entering a context pushes it, exiting pops it,
    and :meth:`current_context` returns the innermost active one.
    """

    # Stack of active contexts. This is a class attribute, so it is shared by every instance.
    contexts: List[ScriptModuleContext] = []

    def __init__(self):
        # Name -> variable for both appended functions and defined global variables.
        self.name2var: Dict[str, Var] = {}
        self.functions: List[Function] = []
        # Name -> variable (of function type) for declared extern functions.
        self.extern_functions: Dict[str, Var] = {}

    def __enter__(self):
        self.contexts.append(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.contexts.pop()

    @staticmethod
    def current_context() -> Optional[ScriptModuleContext]:
        """Return the innermost active context, or None if no context is active."""
        contexts = ScriptModuleContext.contexts
        return contexts[-1] if len(contexts) > 0 else None

    def append_function(self, function: Function):
        """Add a hidet function, and bind its name to a variable of the function's type."""
        self.functions.append(function)
        # NOTE(review): unlike define_global_var, an existing entry with the same name is
        # silently overwritten here.
        self.name2var[function.name] = Var(hint=None, type=FuncType.from_func(function), name=function.name)

    def lookup(self, name: str) -> Optional[Var]:
        """Return the variable bound to `name`, or None if there is none."""
        return self.name2var.get(name)

    def define_global_var(self, name: str, var_type: BaseType) -> Var:
        """
        Define a global variable named `name` of type `var_type` and return it.

        Raises
        ------
        ValueError
            If `name` is already bound in this context.
        """
        if name in self.name2var:
            raise ValueError(f'Global variable {name} is already defined.')
        self.name2var[name] = Var(hint=None, type=var_type, name=name)
        return self.name2var[name]

    def declare_extern_func(self, name: str, param_types, ret_type):
        """
        Declare an extern function and return the variable that refers to it.

        Raises
        ------
        ValueError
            If an extern function named `name` has already been declared.
        """
        if name in self.extern_functions:
            raise ValueError(f'Extern function {name} is already declared.')
        self.extern_functions[name] = Var(hint=None, name=name, type=func_type(param_types, ret_type))
        return self.extern_functions[name]

    def ir_module(self) -> IRModule:
        """Assemble the collected functions, global variables and extern functions into an IRModule."""
        return IRModule(
            functions={func.name: func for func in self.functions},
            global_vars=self.name2var,
            extern_functions=self.extern_functions,
        )

    def build(self) -> CompiledModule:
        """Build the IRModule of this context into a compiled module."""
        return self.ir_module().build()


def script_module() -> ScriptModuleContext:
    """Create a new ScriptModuleContext. Use it with ``with script_module() as module: ...``."""
    return ScriptModuleContext()