
Commit

Update benchmark_archs.py
the-database committed Jan 9, 2025
1 parent fa21dae commit cf6e938
Showing 1 changed file with 37 additions and 36 deletions.
73 changes: 37 additions & 36 deletions scripts/benchmarking/benchmark_archs.py
@@ -37,30 +37,7 @@
     if name not in EXCLUDE_BENCHMARK_ARCHS
 ]
 ALL_SCALES = [4, 3, 2, 1]
-LIGHTWEIGHT_ARCHS = {
-    "artcnn_r8f64",
-    "artcnn_r8f48",
-    "cfsr",
-    "realcugan",
-    "span",
-    "compact",
-    "ditn_real",
-    "plksr_tiny",
-    "ultracompact",
-    "realmosr",
-    "rtmosr",
-    "rtmosr_l",
-    "rtmosr_ul",
-    "safmn",
-    "sebica",
-    "sebica_mini",
-    "seemore_t",
-    "superultracompact",
-    "spanplus",
-    "spanplus_s",
-    "spanplus_st",
-    "spanplus_sts",
-}
+
 # For archs that have extra parameters, list all combinations that need to be benchmarked.
 EXTRA_ARCH_PARAMS: dict[str, list[dict[str, Any]]] = {
     k: [] for k, _ in FILTERED_REGISTRY
@@ -95,6 +72,9 @@
     "realcugan": [{"scale": 1, "extra_arch_params": {}}],
 }
 
+LIGHTWEIGHT_THRESHOLD = 1000 / 24
+MIDDLEWEIGHT_THRESHOLD = 500  # 1000 / 2
+
 
 def printfc(text: str, f: TextIOWrapper) -> None:
     print(text)
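(Both thresholds are frame times in milliseconds: 1000 / 24 ≈ 41.7 ms is the frame budget at 24 FPS, and 500 ms corresponds to 2 FPS.)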
@@ -126,6 +106,7 @@ def get_line(
     vram_channels_last: float,
     channels_last_vs_baseline: float,
     best_fps: float,
+    num_runs: int,
     print_markdown: bool = False,
 ) -> str:
     name_separator = "|" if print_markdown else ": "
@@ -172,24 +153,41 @@ def benchmark_model(
     name: str,
     model: nn.Module,
     input_tensor: Tensor,
-    warmup_runs: int = 5,
-    num_runs: int = 10,
-) -> tuple[float, float, Tensor]:
+    warmup_runs: int = 1,
+    # num_runs: int = 5,
+) -> tuple[float, float, int, Tensor]:
     # https://github.com/dslisleedh/PLKSR/blob/main/scripts/test_direct_metrics.py
     with torch.inference_mode():
         torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
+        num_runs = 5
+
         for _ in range(warmup_runs):
             model(input_tensor)
             torch.cuda.synchronize()
 
-        output = None
-
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.synchronize()
         starter, ender = (
             torch.cuda.Event(enable_timing=True),
             torch.cuda.Event(enable_timing=True),
         )
         current_stream = torch.cuda.current_stream()
 
+        # determine num runs based on inference speed
+        starter.record(current_stream)
+        model(input_tensor)
+        ender.record(current_stream)
+        torch.cuda.synchronize()
+        curr_time = starter.elapsed_time(ender)
+        if curr_time < LIGHTWEIGHT_THRESHOLD:
+            num_runs = 500
+        elif curr_time < MIDDLEWEIGHT_THRESHOLD:
+            num_runs = 10
+
+        output = None
+
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
+
+        timings = np.zeros((num_runs, 1))
+
+        # for i in trange(num_runs, desc=name, leave=False):
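(torch.cuda.Event.elapsed_time returns milliseconds, so the single probe inference above compares directly against the thresholds: under ~41.7 ms gets 500 timed runs, under 500 ms gets 10, and anything slower keeps the default of 5.)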
@@ -204,7 +202,7 @@ def benchmark_model(
     avg_time: float = np.sum(timings).item() / (num_runs * 1000)
     vram_usage = torch.cuda.max_memory_allocated(device) / (1024**3)
     assert output is not None
-    return avg_time, vram_usage, output
+    return avg_time, vram_usage, num_runs, output
 
 
 def get_dtype(name: str, use_amp: bool) -> tuple[str, torch.dtype]:
@@ -243,15 +241,14 @@ def benchmark_arch(
     )
 
     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    runs = lightweight_num_runs if name in LIGHTWEIGHT_ARCHS else num_runs
 
     with torch.autocast(
         device_type="cuda",
         dtype=dtype,
         enabled=use_amp,
     ):
-        avg_time, vram_usage, output = benchmark_model(
-            arch_key, model, random_input, warmup_runs, runs
+        avg_time, vram_usage, runs, output = benchmark_model(
+            arch_key, model, random_input, warmup_runs
         )
 
     if not (
@@ -271,6 +268,7 @@ def benchmark_arch(
         total_params,
         scale,
         extra_arch_params,
+        runs,
     )
 
     return row
@@ -303,6 +301,7 @@ def benchmark_arch(
         float,
         float,
         float,
+        int,
     ]
     results_by_scale: dict[
         int,
@@ -400,6 +399,7 @@ def benchmark_arch(
                 row[3]
                 if row[3] > row_channels_last[3]
                 else row_channels_last[3],  # better fps
+                row[8],  # num runs
             )
             results_by_scale[scale].append(new_row)
             results_by_arch[arch_key][scale] = new_row
@@ -471,6 +471,7 @@ def benchmark_arch(
                 float("inf"),
                 float("inf"),
                 float("inf"),
+                0,
             )
             results_by_scale[scale].append(row)
             results_by_arch[arch_key][scale] = row
@@ -498,7 +499,7 @@
 
     for arch_name in sorted(results_by_arch.keys()):
         f.write(f"\n### {arch_name}\n")
-        runs = lightweight_num_runs if arch_name in LIGHTWEIGHT_ARCHS else num_runs
+        runs = results_by_arch[arch_name][ALL_SCALES[0]][-1]
        f.write(
            f"{input_shape} input, {warmup_runs} warmup + {runs} runs averaged\n"
        )
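Taken together, the change swaps the hard-coded LIGHTWEIGHT_ARCHS lookup for a one-shot timing probe. Below is a minimal standalone sketch of that pattern, assuming a CUDA device and a model already moved to it; the helper name pick_num_runs is illustrative, not part of the repository.

import torch
from torch import Tensor, nn

# Thresholds from the commit, in milliseconds per frame.
LIGHTWEIGHT_THRESHOLD = 1000 / 24  # ~41.7 ms, i.e. 24 FPS
MIDDLEWEIGHT_THRESHOLD = 500  # 1000 / 2, i.e. 2 FPS


def pick_num_runs(model: nn.Module, input_tensor: Tensor) -> int:
    # Time a single inference with CUDA events; elapsed_time is in ms.
    with torch.inference_mode():
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        stream = torch.cuda.current_stream()

        starter.record(stream)
        model(input_tensor)
        ender.record(stream)
        torch.cuda.synchronize()  # block until both events are recorded

        curr_time = starter.elapsed_time(ender)

    if curr_time < LIGHTWEIGHT_THRESHOLD:
        return 500  # fast models: many runs for a stable average
    if curr_time < MIDDLEWEIGHT_THRESHOLD:
        return 10
    return 5  # slow models: keep total benchmark time bounded

As in the script, the model should receive a few untimed warmup passes before the probe, since the first CUDA launches carry one-time kernel and allocator overhead.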
