# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Callable, Tuple, Any, Dict, Union
import time
from dataclasses import dataclass
from scipy import stats
import numpy as np
import nvtx
import hidet
import hidet.cuda
# copied from: https://github.com/openai/triton/blob/main/python/triton/testing.py
def do_bench(fn, warmup=25, rep=100, percentiles=(0.2, 0.5, 0.8)):
"""
    Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
    the 20-th and 80-th performance percentiles.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
    :param percentiles: Performance percentiles to return in addition to the median.
:type percentiles: list[float]
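
    Example (illustrative sketch; assumes a CUDA device is available and uses an element-wise
    add on hidet tensors as a stand-in for the workload being timed)::

        import hidet
        a = hidet.randn([1024, 1024], device='cuda')
        b = hidet.randn([1024, 1024], device='cuda')
        p20, median, p80 = do_bench(lambda: a + b)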
"""
# Estimate the runtime of the function
fn()
hidet.cuda.synchronize()
start_event = hidet.cuda.Event(enable_timing=True)
end_event = hidet.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(5):
fn()
end_event.record()
hidet.cuda.synchronize()
estimate_ms = end_event.elapsed_time(start_event) / 5
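    # derive iteration counts so that warm-up and measurement roughly fit the requested
    # `warmup` and `rep` budgets (both given in milliseconds)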
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
    start_event = [hidet.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    end_event = [hidet.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
fn()
# Benchmark
for i in range(n_repeat):
start_event[i].record()
fn()
end_event[i].record()
# Record clocks
hidet.cuda.synchronize()
times = np.array([e.elapsed_time(s) for s, e in zip(start_event, end_event)])
if percentiles:
percentiles = np.quantile(times, percentiles)
return tuple(percentiles)
else:
return np.mean(times).item()
def benchmark_func(run_func, *args, warmup=1, number=5, repeat=5, median=True) -> Union[List[float], float]:
"""Benchmark given function.
The given function ``run_func`` will be executed :math:`warmup + repeat * number` times. Each :math:`number` times
of execution will be grouped and conducted together.
Parameters
----------
run_func: Callable[[], Any]
Any callable function to be benchmarked.
warmup: int
The number of warm-up executions.
number: int
The number of executions to be grouped for measurement.
repeat: int
The number of repeat times of the group measurement.
    median: bool
        Whether to return the median latency over all repeats, instead of the list of per-repeat latencies.
Returns
-------
ret: Union[float, List[float]]
- When median == True, a single latency number is returned.
- When median == False, the latency of each repeat is returned, as a list of floats.
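
    Examples
    --------
    A minimal, illustrative sketch (assumes a CUDA device is available; the matmul below is just a
    stand-in for the workload being measured)::

        import hidet
        a = hidet.randn([1024, 1024], device='cuda')
        b = hidet.randn([1024, 1024], device='cuda')
        latency_ms = benchmark_func(lambda: hidet.ops.matmul(a, b), warmup=3, number=10, repeat=5)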
"""
results = []
with nvtx.annotate('warmup'):
for _ in range(warmup):
run_func(*args)
hidet.cuda.synchronize()
for i in range(repeat):
with nvtx.annotate(f'repeat {i}'):
hidet.cuda.synchronize()
start_time = time.time_ns()
for _ in range(number):
run_func(*args)
hidet.cuda.synchronize()
end_time = time.time_ns()
results.append((end_time - start_time) / 10**6 / number)
if median:
return float(np.median(results))
else:
return results
@dataclass
class CandidateData:
idx: int
    latencies: Optional[List[float]] = None
median: float = 0.0
in_game: bool = True
def find_best_candidate(candidates: List[Callable[..., None]], *args):
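    """
    Select the fastest candidate among the given callables (each invoked with ``*args``).

    Every candidate still in the game is benchmarked with ``benchmark_func``, and candidates are
    eliminated with one-sided independent t-tests (``scipy.stats.ttest_ind`` with ``alternative='less'``):
    a candidate is dropped as soon as some other candidate is faster with a p-value below
    ``P_VALUE_THRESHOLD``. If no unique winner can be established statistically, the candidate with the
    smallest median latency is chosen.

    Returns a tuple ``(best_idx, medians)`` with the index of the selected candidate and the median
    latency recorded for every candidate.
    """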
P_VALUE_THRESHOLD = 0.01
num_candidates = len(candidates)
candidates_data = [CandidateData(idx=idx) for idx, _ in enumerate(candidates)]
repeats = (7, 31)
for cur_repeat in repeats:
for idx, cand in enumerate(candidates):
if candidates_data[idx].in_game:
lats = benchmark_func(cand, *args, warmup=5, number=1, repeat=cur_repeat, median=False)
candidates_data[idx].latencies = lats
for cand in candidates_data:
if cand.in_game:
cand.median = np.median(cand.latencies)
        # We now have latency samples for every candidate still in the game.
        # First, compare every candidate against the one with the minimum median; this usually
        # eliminates most of the slower candidates quickly. It is only an optimisation:
        # the all-pairs loop below is sufficient for correctness.
min_lat_cand = min((cand for cand in candidates_data if cand.in_game), key=lambda cand: cand.median)
min_idx = min_lat_cand.idx
for i in range(num_candidates):
if i == min_idx or not candidates_data[i].in_game:
continue
_, p_value = stats.ttest_ind(
candidates_data[min_idx].latencies, candidates_data[i].latencies, alternative='less'
)
if p_value < P_VALUE_THRESHOLD:
candidates_data[i].in_game = False
        # If only one candidate is left, we have found the best one.
left_candidates = [cand for cand in candidates_data if cand.in_game]
if len(left_candidates) == 1:
return (left_candidates[0].idx, [cand.median for cand in candidates_data])
        # Compare all remaining candidates pairwise; each comparison uses a one-sided t-test.
for i in range(num_candidates):
if not candidates_data[i].in_game:
continue
for j in range(num_candidates):
if not candidates_data[j].in_game or i == j:
continue
_, p_value = stats.ttest_ind(
candidates_data[i].latencies, candidates_data[j].latencies, alternative='less'
)
if p_value < P_VALUE_THRESHOLD:
                    candidates_data[j].in_game = False
        # If only one candidate is left, we have found the best one.
left_candidates = [cand for cand in candidates_data if cand.in_game]
if len(left_candidates) == 1:
return (left_candidates[0].idx, [cand.median for cand in candidates_data])
    # We could not prove that any single candidate is statistically significantly faster than
    # all of the others, so the remaining candidates cannot be ordered by the t-tests above.
    # Fall back to choosing the one with the smallest median latency.
best = min((cand for cand in candidates_data if cand.in_game), key=lambda cand: cand.median)
best_idx = best.idx
    latencies = [cand.median for cand in candidates_data]
    return (best_idx, latencies)
@dataclass
class BenchData:
x_vals: List[Any]
x_name: str
y_name: str
kwargs: Dict[str, Any]
data: Dict[str, Tuple[List[float], List[float], List[float]]] # [t_min, t_avg, t_max]
def show_plot(self, show=True, save_path=None, figsize=None, title=None):
from matplotlib import pyplot as plt
if all(isinstance(x, (float, int)) for x in self.x_vals):
x_vals = self.x_vals
else:
x_vals = range(1, len(self.x_vals) + 1)
plt.figure(figsize=figsize)
ax = plt.subplot()
for name, (t_min, t_avg, t_max) in self.data.items():
p = ax.plot(x_vals, t_avg, label=name)
color = p[0].get_color()
ax.fill_between(x_vals, t_min, t_max, alpha=0.15, color=color)
ax.legend()
ax.set_xlabel(self.x_name)
ax.set_ylabel(self.y_name)
if title is not None:
ax.set_title(title)
ax.set_xticks(ticks=x_vals, labels=[str(x) for x in self.x_vals])
if show:
plt.show()
if save_path is not None:
plt.savefig(save_path)
return self
def to_dataframe(self):
import pandas as pd
columns = list(self.data.keys())
df = pd.DataFrame(columns=columns, index=self.x_vals)
for n in columns:
df[n] = self.data[n][1] # get t_avg
return df
def print_data(self):
print(self.to_dataframe())
class Bench:
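    """
    Helper for benchmarking one or more implementations over a sweep of input sizes. For every x value
    and every registered function it stores the 20th/50th/80th latency percentiles reported by
    ``do_bench`` (optionally converted to TFLOP/s via ``measure_flops``).

    A minimal usage sketch (``baseline`` and ``optimized`` are hypothetical callables standing in for
    the implementations being compared)::

        bench = Bench(x_vals=[256, 512, 1024], x_name='n')
        bench.bench(lambda n: (lambda: baseline(n)), name='baseline')
        bench.bench(lambda n: (lambda: optimized(n)), name='optimized')
        data = bench.run()
        data.print_data()
    """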
def __init__(self, x_vals: List[Any], x_name: str, **kwargs):
self.x_vals = x_vals
self.x_name = x_name
self.y_name = 'ms'
self.byte_fn = None
self.kwargs: Dict[str, Any] = kwargs
self.bench_fns: List[Tuple[str, Callable]] = []
self.bench_data: Dict[str, Tuple[List[float], List[float], List[float]]] = {}
def measure_flops(self, byte_fn: Callable[[Any], int]):
"""
        Set a function that takes the current x_val (plus the config kwargs) and returns the number of
        floating-point operations the benchmarked function performs; results are then reported in TFLOP/s.
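
        Example (illustrative; assumes the benchmarked operation is a square matmul of size ``n``,
        which performs roughly ``2 * n**3`` floating-point operations)::

            bench.measure_flops(lambda n: 2 * n ** 3)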
"""
self.byte_fn = byte_fn
self.y_name = 'TFLOP/s'
def bench(self, fn: Callable[[Any], Callable[[], Any]], name: Optional[str] = None):
"""
        Add a function to the list of functions to be benchmarked. It takes the current x_val (plus the
        config kwargs) and returns the zero-argument callable that will actually be timed.
        If the name argument is None, the name for this particular line is fn.__name__.
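
        Example (``my_kernel`` is a hypothetical callable standing in for the implementation under test)::

            bench.bench(lambda n: (lambda: my_kernel(n)), name='my_kernel')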
"""
if name is None:
if hasattr(fn, '__name__'):
name = fn.__name__
else:
raise ValueError("cannot get name of function")
self.bench_fns.append((name, fn))
return self
def run(self):
"""
        Run all the functions that need to be benchmarked, returning a BenchData object that holds
        the collected results.
"""
for i in self.x_vals:
for name, fn in self.bench_fns:
if name not in self.bench_data:
self.bench_data[name] = ([], [], [])
t_min, t_avg, t_max = self.bench_data[name]
bench_fn = fn(i, **self.kwargs)
lo, avg, hi = do_bench(bench_fn)
if self.byte_fn is not None:
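                    # convert the measured latencies (ms) into throughput: work_per_call * 1e-12 / seconds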
lo = self.byte_fn(i, **self.kwargs) * 1e-12 / (lo * 1e-3)
avg = self.byte_fn(i, **self.kwargs) * 1e-12 / (avg * 1e-3)
hi = self.byte_fn(i, **self.kwargs) * 1e-12 / (hi * 1e-3)
t_min.append(lo)
t_avg.append(avg)
t_max.append(hi)
return BenchData(self.x_vals, self.x_name, self.y_name, self.kwargs, self.bench_data)