Skip to content

Commit

Permalink
Plot NAS search progression non-interactively (#2025)
Browse files Browse the repository at this point in the history
### Changes
As stated in the title

### Reason for changes
Windows pre-commits fail because they cannot find tk during the sanity
tests. The plotting in sanity tests doesn't actually need interactivity,
so agg can be used instead of tk as a backend which does not pose
cross-platform problems.

### Related tickets
N/A

### Tests
windows/precommit_torch_cpu
  • Loading branch information
vshampor authored Aug 4, 2023
1 parent ba2a4ff commit e1215cd
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 49 deletions.
26 changes: 12 additions & 14 deletions nncf/common/accuracy_aware_training/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.api.compression import CompressionAlgorithmController
from nncf.api.compression import CompressionStage
from nncf.common.logging import nncf_logger
from nncf.common.plotting import noninteractive_plotting
from nncf.common.utils.helpers import configure_accuracy_aware_paths
from nncf.common.utils.tensorboard import prepare_for_tensorboard
from nncf.config.schemata.defaults import AA_COMPRESSION_RATE_STEP_REDUCTION_FACTOR
Expand Down Expand Up @@ -486,20 +487,17 @@ def update_training_history(self, compression_rate, metric_value):
self._compressed_training_history.append((compression_rate, accuracy_budget))

if IMG_PACKAGES_AVAILABLE:
backend = plt.get_backend()
plt.switch_backend("agg")
plt.ioff()
fig = plt.figure()
plt.plot(self.compressed_training_history.keys(), self.compressed_training_history.values())
buf = io.BytesIO()
plt.savefig(buf, format="jpeg")
buf.seek(0)
image = PIL.Image.open(buf)
self.add_tensorboard_image(
"compression/accuracy_aware/acc_budget_vs_comp_rate", image, len(self.compressed_training_history)
)
plt.close(fig)
plt.switch_backend(backend)
with noninteractive_plotting():
fig = plt.figure()
plt.plot(self.compressed_training_history.keys(), self.compressed_training_history.values())
buf = io.BytesIO()
plt.savefig(buf, format="jpeg")
buf.seek(0)
image = PIL.Image.open(buf)
self.add_tensorboard_image(
"compression/accuracy_aware/acc_budget_vs_comp_rate", image, len(self.compressed_training_history)
)
plt.close(fig)

@property
def compressed_training_history(self):
Expand Down
22 changes: 22 additions & 0 deletions nncf/common/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager

from matplotlib import pyplot as plt


@contextmanager
def noninteractive_plotting():
backend = plt.get_backend()
plt.switch_backend("agg")
plt.ioff()
yield
plt.switch_backend(backend)
72 changes: 37 additions & 35 deletions nncf/experimental/torch/nas/bootstrapNAS/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from nncf import NNCFConfig
from nncf.common.initialization.batchnorm_adaptation import BatchnormAdaptationAlgorithm
from nncf.common.logging import nncf_logger
from nncf.common.plotting import noninteractive_plotting
from nncf.common.utils.decorators import skip_if_dependency_unavailable
from nncf.common.utils.os import safe_open
from nncf.config.extractors import get_bn_adapt_algo_kwargs
Expand Down Expand Up @@ -399,45 +400,46 @@ def visualize_search_progression(self, filename="search_progression") -> NoRetur
"""
import matplotlib.pyplot as plt

plt.figure()
colormap = plt.cm.get_cmap("viridis")
col = range(int(self.search_params.num_evals / self.search_params.population))
for i in range(0, len(self.search_records), self.search_params.population):
c = [col[int(i / self.search_params.population)]] * len(
self.search_records[i : i + self.search_params.population]
)
plt.scatter(
[abs(row[2]) for row in self.search_records][i : i + self.search_params.population],
[abs(row[4]) for row in self.search_records][i : i + self.search_params.population],
s=9,
c=c,
alpha=0.5,
marker="D",
cmap=colormap,
)
plt.scatter(
*tuple(abs(ev.input_model_value) for ev in self.evaluator_handlers),
marker="s",
s=120,
color="blue",
label="Input Model",
edgecolors="black",
)
if None not in self.best_vals:
with noninteractive_plotting():
plt.figure()
colormap = plt.cm.get_cmap("viridis")
col = range(int(self.search_params.num_evals / self.search_params.population))
for i in range(0, len(self.search_records), self.search_params.population):
c = [col[int(i / self.search_params.population)]] * len(
self.search_records[i : i + self.search_params.population]
)
plt.scatter(
[abs(row[2]) for row in self.search_records][i : i + self.search_params.population],
[abs(row[4]) for row in self.search_records][i : i + self.search_params.population],
s=9,
c=c,
alpha=0.5,
marker="D",
cmap=colormap,
)
plt.scatter(
*tuple(abs(val) for val in self.best_vals),
marker="o",
*tuple(abs(ev.input_model_value) for ev in self.evaluator_handlers),
marker="s",
s=120,
color="yellow",
label="BootstrapNAS A",
color="blue",
label="Input Model",
edgecolors="black",
linewidth=2.5,
)
plt.legend()
plt.title("Search Progression")
plt.xlabel(self.efficiency_evaluator_handler.name)
plt.ylabel(self.accuracy_evaluator_handler.name)
plt.savefig(f"{self._log_dir}/{filename}.png")
if None not in self.best_vals:
plt.scatter(
*tuple(abs(val) for val in self.best_vals),
marker="o",
s=120,
color="yellow",
label="BootstrapNAS A",
edgecolors="black",
linewidth=2.5,
)
plt.legend()
plt.title("Search Progression")
plt.xlabel(self.efficiency_evaluator_handler.name)
plt.ylabel(self.accuracy_evaluator_handler.name)
plt.savefig(f"{self._log_dir}/{filename}.png")

def save_evaluators_state(self) -> NoReturn:
"""
Expand Down

0 comments on commit e1215cd

Please sign in to comment.