# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TypeVar, Iterable, Tuple, List, Union, Sequence, Optional
import cProfile
import contextlib
import io
import itertools
import os
import pstats
import time
import numpy as np
from tabulate import tabulate
def unique(seq: Sequence) -> List:
    """
    Return the items of ``seq`` with duplicates removed, keeping first-occurrence order.

    Items must be hashable because already-seen items are tracked in a set.
    """
    seen = set()
    result = []
    for element in seq:
        if element not in seen:
            seen.add(element)
            result.append(element)
    return result
def prod(seq: Iterable):
    """
    Multiply all items of ``seq`` together from left to right.

    Returns 1 for an empty iterable. The running product is rebound with ``acc = acc * item``
    (never ``*=``), so the first item itself is not modified in place.
    """
    items = list(seq)
    if not items:
        return 1
    acc = items[0]
    for item in items[1:]:
        acc = acc * item
    return acc
def median(seq: Iterable):
    """
    Return the median item of ``seq``, or None if it is empty.

    For an even number of items this is the upper of the two middle items
    (index ``len // 2`` of the sorted items); no averaging is performed.
    """
    ordered = sorted(seq)
    if not ordered:
        return None
    return ordered[len(ordered) // 2]
def clip(
    x: Union[int, float], low: Optional[Union[int, float]], high: Optional[Union[int, float]]
) -> Union[int, float]:
    """
    Clamp ``x`` into ``[low, high]``; a bound of None leaves that side unbounded.

    The lower bound is applied first, so when ``low > high`` the result is ``high``.
    """
    lower_clamped = x if low is None else max(x, low)
    return lower_clamped if high is None else min(lower_clamped, high)
# Element types of the two sequences paired up by strict_zip.
TypeA = TypeVar('TypeA')
TypeB = TypeVar('TypeB')
def strict_zip(a: Sequence[TypeA], b: Sequence[TypeB]) -> Iterable[Tuple[TypeA, TypeB]]:
    """
    Pair up ``a`` and ``b`` like ``zip``, but refuse sequences of different lengths.

    Raises
    ------
    ValueError
        If ``len(a) != len(b)``.
    """
    len_a, len_b = len(a), len(b)
    if len_a == len_b:
        return zip(a, b)
    message = 'Expect two sequence have the same length in zip, got length {} and {}.'.format(len_a, len_b)
    raise ValueError(message)
class COLORS:
    """ANSI SGR escape sequences used by the text-coloring helpers of this module."""
    # Foreground colors (HEADER and MAGENTA share the same code).
    HEADER = '\x1b[95m'
    OKBLUE = '\x1b[94m'
    OKCYAN = '\x1b[96m'
    OKGREEN = '\x1b[92m'
    MAGENTA = '\x1b[95m'
    WARNING = '\x1b[93m'
    FAIL = '\x1b[91m'
    # Reset and text attributes.
    ENDC = '\x1b[0m'
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'
def green(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap the text between ``COLORS.OKGREEN`` and ``COLORS.ENDC``."""
    text = fmt.format(v)
    return f'{COLORS.OKGREEN}{text}{COLORS.ENDC}'
def cyan(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap the text between ``COLORS.OKCYAN`` and ``COLORS.ENDC``."""
    text = fmt.format(v)
    return f'{COLORS.OKCYAN}{text}{COLORS.ENDC}'
def blue(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap the text between ``COLORS.OKBLUE`` and ``COLORS.ENDC``."""
    text = fmt.format(v)
    return f'{COLORS.OKBLUE}{text}{COLORS.ENDC}'
def yellow(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap the text between ``COLORS.WARNING`` and ``COLORS.ENDC``."""
    text = fmt.format(v)
    return f'{COLORS.WARNING}{text}{COLORS.ENDC}'
def red(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap the text between ``COLORS.FAIL`` and ``COLORS.ENDC``."""
    text = fmt.format(v)
    return f'{COLORS.FAIL}{text}{COLORS.ENDC}'
def magenta(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap the text between ``COLORS.MAGENTA`` and ``COLORS.ENDC``."""
    text = fmt.format(v)
    return f'{COLORS.MAGENTA}{text}{COLORS.ENDC}'
def bold(v, fmt='{}'):
    """Format ``v`` with ``fmt`` and wrap the text between ``COLORS.BOLD`` and ``COLORS.ENDC``."""
    text = fmt.format(v)
    return f'{COLORS.BOLD}{text}{COLORS.ENDC}'
def color(v, fmt='{}', fg='default', bg='default'):
    """
    Format ``v`` with ``fmt`` and wrap it in an ANSI foreground/background color sequence.

    ``fg`` and ``bg`` must each be one of 'black', 'red', 'green', 'yellow', 'blue',
    'magenta', 'cyan', 'white' or 'default'; any other name raises KeyError.
    """
    names = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
    # The eight standard colors use consecutive codes 30-37 (foreground) and 40-47
    # (background); 39 / 49 select the terminal default.
    fg_codes = {name: 30 + offset for offset, name in enumerate(names)}
    fg_codes['default'] = 39
    bg_codes = {name: code + 10 for name, code in fg_codes.items()}
    return '\033[{};{}m{}\033[0m'.format(fg_codes[fg], bg_codes[bg], fmt.format(v))
def color_table():
    """Print a sample line for every foreground/background combination supported by ``color``."""
    names = ["default", "black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"]
    row_fmt = '{:>10} {:>10}   {}'
    print('{:>10} {:>10}   {:<10}'.format('fg', 'bg', 'text'))
    # Background is the outer loop, foreground the inner one.
    for bg, fg in itertools.product(names, names):
        print(row_fmt.format(fg, bg, color('sample text', fg=fg, bg=bg)))
def color_rgb(v, fg, fmt='{}'):
    """
    Format ``v`` with ``fmt`` and color it with an RGB foreground escape (``ESC[38;2;R;G;Bm``).

    Only ``fg[0]``, ``fg[1]`` and ``fg[2]`` (red, green, blue) are used.
    """
    r, g, b = fg[0], fg[1], fg[2]
    prefix = '\033[38;2;{};{};{}m'.format(r, g, b)
    template = ''.join([prefix, fmt, '\033[0m'])
    return template.format(v)
def color_text(v, fmt='{}', idx: int = 0):
    """
    Format ``v`` with ``fmt`` and optionally color it with a predefined RGB color.

    ``idx == 0`` applies no color; 1 and 2 select the RGB triples below via ``color_rgb``.
    Any other ``idx`` raises KeyError.
    """
    palette = {1: (153, 96, 52), 2: (135, 166, 73)}
    if idx != 0:
        return color_rgb(v, palette[idx], fmt=fmt)
    return fmt.format(v)
def nocolor(s: str) -> str:
    """
    Remove ANSI SGR (color/style) escape sequences from ``s``.

    Every sequence of the form ``ESC [ <digits and semicolons> m`` is stripped. This covers
    all ``COLORS`` constants and also the combined codes built by ``color`` (e.g. ``ESC[31;49m``)
    and ``color_rgb`` (``ESC[38;2;R;G;Bm``). Replacing only the exact ``COLORS`` values would
    leave those behind.
    """
    import re
    return re.sub(r'\x1b\[[0-9;]*m', '', s)
def str_indent(msg: str, indent=0) -> str:
    """Prefix every line of ``msg`` (split on newlines) with ``indent`` spaces, empty lines included."""
    padding = ' ' * indent
    return '\n'.join(padding + line for line in msg.split('\n'))
class Timer:
    """
    Context manager measuring the time spent in its ``with`` body using ``time.time()``.

    On exit, when ``msg`` is not None and ``verbose`` is true, the line ``"<msg> <elapsed>"``
    (elapsed time colored green) is reported. It is printed when ``stdout`` is true. When ``file``
    is set, a ``str`` is treated as a path that is opened in ``'w'`` mode and receives the
    uncolored message; any other object receives the colored message plus a newline via ``write``.
    """
    def __init__(self, msg=None, file=None, verbose=True, stdout=True):
        self.msg = msg
        self.file = file
        self.verbose = verbose
        self.stdout = stdout
        # Set by __enter__ / __exit__.
        self.start_time = None
        self.end_time = None
    def __enter__(self):
        self.start_time = time.time()
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_time = time.time()
        if self.msg is None or not self.verbose:
            return
        elapsed_text = self.time2str(self.end_time - self.start_time)
        report = '{} {}'.format(self.msg, green(elapsed_text))
        if self.stdout:
            print(report)
        if not self.file:
            return
        if isinstance(self.file, str):
            with open(self.file, 'w') as f:
                f.write(nocolor(report))
        else:
            self.file.write(report + '\n')
    def elapsed_seconds(self) -> float:
        """Seconds between entering and exiting the ``with`` block."""
        return self.end_time - self.start_time
    def time2str(self, seconds: float) -> str:
        """Render a duration with one decimal in ms, seconds, minutes or hours."""
        if seconds < 1:
            value, unit = seconds * 1000, 'ms'
        elif seconds < 60:
            value, unit = seconds, 'seconds'
        elif seconds < 60 * 60:
            value, unit = seconds / 60, 'minutes'
        else:
            value, unit = seconds / 60 / 60, 'hours'
        return '{:.1f} {}'.format(value, unit)
def repeat_until_converge(func, obj, limit=None):
    """
    Apply ``func`` repeatedly until it returns the very same object it was given.

    Convergence is tested by identity (``is``), not equality. When ``limit`` is not None it caps
    the number of ``func`` calls; ``func`` is always called at least once. Returns the last result.
    """
    for num_calls in itertools.count(1):
        previous = obj
        obj = func(previous)
        converged = obj is previous
        if converged or (limit is not None and num_calls >= limit):
            return obj
def get_next_file_index(dirname: str) -> int:
    """
    Return the smallest non-negative integer not yet used as a file-name prefix in ``dirname``.

    A name's prefix is the text before its first '_' (the whole name if there is none);
    names whose prefix does not parse with ``int()`` are ignored.
    """
    used = set()
    for entry in os.listdir(dirname):
        prefix = entry.split('_')[0]
        try:
            used.add(int(prefix))
        except ValueError:
            continue
    return next(candidate for candidate in itertools.count() if candidate not in used)
def factorize(n):
    """
    Return all positive divisors of ``n`` in ascending order.

    Example: ``factorize(12) => [1, 2, 3, 4, 6, 12]``. Returns ``[]`` for ``n <= 0``.
    Each divisor ``i`` with ``i * i <= n`` is paired with its co-divisor ``n // i``, so only
    about ``sqrt(n)`` candidates are tried.
    """
    small = []  # divisors i with i * i <= n, found in ascending order
    large = []  # matching co-divisors n // i, found in descending order
    i = 1
    while i * i <= n:
        if n % i == 0:
            small.append(i)
            if i * i != n:  # avoid listing the square root twice
                large.append(n // i)
        i += 1
    # every co-divisor exceeds every small divisor, so concatenation is already sorted
    return small + large[::-1]
def _is_immutable(obj):
    """
    Return True if ``obj`` may be compared by value rather than by identity.

    Python ints, floats, strings and tuples qualify, as do hidet ``Device`` objects and
    hidet ``Constant`` values whose type is not a tensor type.
    """
    from hidet.ir.expr import Constant
    from hidet.graph.operator import Device
    if isinstance(obj, (int, float, str, tuple)):
        return True
    if isinstance(obj, Constant):
        # tensor-typed constants are compared by identity
        return not obj.type.is_tensor()
    return isinstance(obj, Device)
def same_list(lhs, rhs, use_equal=False):
    """
    Return True if ``lhs`` and ``rhs`` have the same length and pairwise-matching elements.

    A pair is compared with ``!=`` when ``use_equal`` is true or both elements are immutable
    according to ``_is_immutable``; otherwise both must be the very same object.
    """
    if len(lhs) != len(rhs):
        return False
    for left, right in zip(lhs, rhs):
        by_value = use_equal or _is_immutable(left) and _is_immutable(right)
        mismatch = (left != right) if by_value else (left is not right)
        if mismatch:
            return False
    return True
def index_of(value: object, lst: Sequence, allow_missing=True) -> int:
    """
    Return the index of the first element of ``lst`` that *is* ``value`` (identity, not ``==``).

    If nothing matches, return -1 when ``allow_missing`` is true, otherwise raise ValueError.
    """
    position = next((i for i, item in enumerate(lst) if item is value), None)
    if position is not None:
        return position
    if not allow_missing:
        raise ValueError('value not found in list')
    return -1
class HidetProfiler:
    """
    Context manager that runs ``cProfile`` over its ``with`` body.

    On exit the profiler is disabled and, if ``display_on_exit`` is true, the statistics
    sorted by cumulative time are printed. ``result()`` returns the same report as a string.
    """
    def __init__(self, display_on_exit=True):
        self.pr = cProfile.Profile()
        self.display_on_exit = display_on_exit
    def __enter__(self):
        self.pr.enable()
        # Return the profiler itself so `with HidetProfiler() as p: ...` can later call p.result().
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.pr.disable()
        if self.display_on_exit:
            print(self.result())
    def result(self) -> str:
        """Return the collected profile statistics, sorted by cumulative time, as text."""
        stream = io.StringIO()
        pstats.Stats(self.pr, stream=stream).sort_stats('cumulative').print_stats()
        return stream.getvalue()
class TableRowContext:
    """
    Collect the cells of one table row; on leaving the ``with`` block the row is appended to ``tb.rows``.

    Cells are added with ``append``, ``extend`` or ``+=``. For ``+=``, a tuple/list operand
    is extended cell by cell; any other operand is appended as a single cell.
    """
    def __init__(self, tb):
        self.tb = tb
        self.row = []
    def __iadd__(self, other):
        if isinstance(other, (tuple, list)):
            self.row.extend(other)
        else:
            self.row.append(other)
        # `row += x` rebinds `row` to this return value, so it must be the context itself.
        return self
    def append(self, other):
        self.row.append(other)
    def extend(self, other):
        self.row.extend(other)
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        # The row is appended even if the block exits with an exception.
        self.tb.rows.append(self.row)
class TableBuilder:
    """
    Accumulate table rows and render them with ``tabulate`` when converted to ``str``.

    Rows are added with ``tb += [cell, ...]`` or ``with tb.new_row() as row: ...``;
    ``tablefmt`` and ``floatfmt`` are forwarded to ``tabulate``.
    """
    def __init__(self, headers=tuple(), tablefmt='simple', floatfmt='.3f'):
        self.headers = [*headers]
        self.rows = []
        self.tablefmt, self.floatfmt = tablefmt, floatfmt
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        return None
    def __iadd__(self, row):
        self.rows.append(row)
        return self
    def __str__(self):
        rendered = tabulate(self.rows, self.headers, tablefmt=self.tablefmt, floatfmt=self.floatfmt)
        return str(rendered)
    def new_row(self) -> TableRowContext:
        return TableRowContext(tb=self)
    def extend_header(self, column_names):
        self.headers += column_names
def initialize(*args, **kwargs):
    """
    Decorate an initialization function so that it is called right after its definition.

    Parameters
    ----------
    args:
        The positional arguments passed to the initialization function.
    kwargs:
        The keyword arguments passed to the initialization function.

    Returns
    -------
    ret:
        A decorator that calls the decorated function with ``args`` and ``kwargs`` and returns
        None, so the decorated name is bound to None and the function cannot be called again.
    """
    def decorator(f):
        f(*args, **kwargs)
        # Implicitly returns None on purpose: the decorated name becomes None.
    return decorator
def gcd(a: int, b: int, *args) -> int:
    """
    Get the greatest common divisor of non-negative integers.

    Parameters
    ----------
    a: int
        The lhs operand.
    b: int
        The rhs operand.
    args: int
        Optional further operands; ``gcd(a, b, c, ...)`` equals ``gcd(gcd(a, b), c, ...)``.

    Returns
    -------
    ret: int
        The greatest common divisor. ``gcd(0, 0)`` is 0.
    """
    if len(args) > 0:
        return gcd(gcd(a, b), *args)
    assert a >= 0 and b >= 0
    # Iterative Euclid: gcd(a, b) == gcd(b, a % b), stopping once b reaches 0.
    while b != 0:
        a, b = b, a % b
    return a
def lcm(a: int, b: int) -> int:
    """
    Get the least common multiple of non-negative integers a and b.

    Parameters
    ----------
    a: int
        The lhs operand.
    b: int
        The rhs operand.

    Returns
    -------
    ret: int
        The least common multiple. ``lcm(0, 0)`` is 0.
    """
    if a == 0 and b == 0:
        # gcd(0, 0) == 0, so the formula below would divide by zero.
        return 0
    # Divide before multiplying to keep the intermediate value small.
    return a // gcd(a, b) * b
def is_power_of_two(n: int) -> bool:
    """
    Check whether an integer is a power of two: 1, 2, 4, 8, 16, 32, ...

    Parameters
    ----------
    n: int
        The integer to check.

    Returns
    -------
    ret: bool
        True if n is a power of two, False otherwise (always False for n <= 0).
    """
    is_positive = n > 0
    if not is_positive:
        return is_positive
    # A positive power of two has exactly one set bit; n & (n - 1) clears the lowest set bit.
    return (n & (n - 1)) == 0
def cdiv(n: int, d: int) -> int:
    """
    Get the ceiling of n / d.

    Assumes ``d > 0``. For a negative ``d`` the formula does not give the ceiling
    (e.g. ``cdiv(3, -2) == 0`` while ``ceil(3 / -2) == -1``).

    Parameters
    ----------
    n: int
        The numerator.
    d: int
        The denominator (positive).

    Returns
    -------
    ret: int
        The ceiling of n / d.
    """
    return (n + d - 1) // d
def error_tolerance(a: Union[np.ndarray, 'Tensor'], b: Union[np.ndarray, 'Tensor']) -> float:
    """
    Given two tensors with the same shape and data type, this function finds the minimal e, such that
        abs(a - b) <= abs(b) * e + e

    The search is a 20-step bisection over ``[0, 10]`` using ``np.allclose(a, b, rtol=e, atol=e)``.
    The result is therefore accurate to roughly ``10 / 2**20``. It approaches 10 when no such
    e <= 10 exists, e.g. when NaNs are present (``np.allclose`` treats NaN as unequal by default).

    Parameters
    ----------
    a: Union[np.ndarray, hidet.Tensor]
        The first tensor.
    b: Union[np.ndarray, hidet.Tensor]
        The second tensor.

    Returns
    -------
    ret: float
        The error tolerance between a and b.
    """
    from hidet.graph import Tensor
    if isinstance(a, Tensor):
        a = a.numpy()
    if isinstance(b, Tensor):
        b = b.numpy()
    # `a.dtype` is a np.dtype object, so `isinstance(a.dtype, np.floating)` is always False;
    # np.issubdtype performs the intended "floating-point dtype" test.
    if np.issubdtype(a.dtype, np.floating):
        a = a.astype(np.float32)
        b = b.astype(np.float32)
    lo, hi = 0.0, 10.0
    for _ in range(20):
        mid = (lo + hi) / 2.0
        if np.allclose(a, b, rtol=mid, atol=mid):
            hi = mid
        else:
            lo = mid
    return (lo + hi) / 2.0
def assert_close(actual, expected, rtol=1e-5, atol=1e-5):
    """
    Assert that ``actual`` and ``expected`` are element-wise close via ``numpy.testing.assert_allclose``.

    Each argument may be a numpy array, a hidet Tensor, or an object whose type is named
    ``Tensor`` in module ``torch``. Tensors are converted with ``.cpu().numpy()``; any other
    type raises TypeError. ``actual`` is converted (and validated) before ``expected``.
    """
    from numpy import ndarray
    from hidet import Tensor
    import numpy.testing
    def to_numpy(value):
        if isinstance(value, ndarray):
            return value
        cls = type(value)
        is_torch_tensor = cls.__name__ == 'Tensor' and cls.__module__ == 'torch'
        if isinstance(value, Tensor) or is_torch_tensor:
            return value.cpu().numpy()
        raise TypeError(f'Unsupported type: {type(value)}')
    actual = to_numpy(actual)
    expected = to_numpy(expected)
    numpy.testing.assert_allclose(actual, expected, rtol, atol)
def inspect_live_hidet_objects() -> str:
    """
    Count live instances of the main hidet classes and return the counts as a text table.

    Instances are looked up among the objects returned by ``gc.get_objects()``. Counting uses
    ``isinstance``, so an object is counted under every listed class it derives from.
    """
    import gc
    from hidet import FlowGraph, Tensor, Operator
    from hidet.runtime import CompiledApp, CompiledGraph, CompiledTask, CompiledModule, CompiledFunction
    from hidet.ir import IRModule, Function, Stmt, Expr
    tracked_classes = (
        FlowGraph, Tensor, Operator,
        CompiledApp, CompiledGraph, CompiledTask, CompiledModule, CompiledFunction,
        IRModule, Function, Stmt, Expr,
    )
    live_objects = gc.get_objects()
    rows: List[List[Union[str, int]]] = [
        [cls.__name__, sum(isinstance(obj, cls) for obj in live_objects)] for cls in tracked_classes
    ]
    return tabulate(rows, headers=['Class', 'Live Instances'], tablefmt='simple', stralign='right', numalign='left')
def release_unused_resources():
    """
    Run a full garbage collection so that objects kept alive only by reference cycles
    (i.e., unreachable) are freed.
    """
    from gc import collect
    collect()
if __name__ == '__main__':
    # Demo: print 'sample' in each predefined RGB color of color_text (idx 1 and 2).
    # color_table() previews the basic foreground/background combinations instead.
    for palette_idx in (1, 2):
        print(color_text('sample', idx=palette_idx))