From 37640ad85cf6adfc07a3bd384f4fa82e1f277f69 Mon Sep 17 00:00:00 2001 From: Hans Pabst Date: Fri, 27 Oct 2023 14:30:51 +0200 Subject: [PATCH] ocl: improved tuning script * Backup dotfile if it exists stemming from an earlier session. * Adjusted default DB-location for OpenTuner. --- src/acc/opencl/smm/tune_multiply.py | 67 ++++++++++++++++++----------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/src/acc/opencl/smm/tune_multiply.py b/src/acc/opencl/smm/tune_multiply.py index 48fcbc1843d..235ea577a17 100755 --- a/src/acc/opencl/smm/tune_multiply.py +++ b/src/acc/opencl/smm/tune_multiply.py @@ -13,6 +13,9 @@ from opentuner import MeasurementInterface from opentuner import Result from signal import signal, SIGINT +import tempfile +import socket +import shutil import copy import json import glob @@ -43,6 +46,7 @@ def ilog2(n): class SmmTuner(MeasurementInterface): def manipulator(self): """Setup common state and define search space""" + dbdir = os.path.join(tempfile.gettempdir(), "opentuner.db") manipulator = ConfigurationManipulator() # parse and sanitize kernel shape argument if not self.args.mnk: @@ -167,8 +171,15 @@ def manipulator(self): and (self.device and "" != self.device) and (self.size and 0 < self.size) and self.typeid - ): # construct label used for the database session - if not self.args.label: # consider to include self.device + ): # setup database (DB) + if args.database is None: # adjust DB-location + if os.path.isdir(dbdir): + shutil.rmtree(dbdir) + os.mkdir(dbdir) + self.args.database = "sqlite:///" + os.path.join( + dbdir, socket.gethostname() + ".db" + ) + if not self.args.label: # label for DB-session self.args.label = "{}-{}-{}x{}x{}-s{}".format( default_basename, self.typename, @@ -450,7 +461,21 @@ def save_final_config(self, configuration, final=True): ) filedot = os.path.join(self.args.jsondir, ".{}.json".format(self.args.label)) if self.gfsave < self.gflops: # save intermediate result - self.gfsave = self.gflops + if 0 == self.gfsave and os.path.exists(filedot): # backup + data = None + try: + with open(filedot, "r") as file: + data = json.load(file) + except: # noqa: E722 + pass + gflops = data["GFLOPS"] if data and "GFLOPS" in data else 0 + filename = os.path.join( + self.args.jsondir, + "{}-{}gflops.json".format(self.args.label, round(gflops)) + if 0 < gflops + else "{}.json".format(self.args.label), + ) + os.rename(filedot, filename) # self.manipulator().save_to_file(config, filename) with open(filedot, "w") as file: cfg = config @@ -459,41 +484,35 @@ def save_final_config(self, configuration, final=True): del cfg["XF"] json.dump(cfg, file, sort_keys=True) file.write("\n") # append newline at EOF + self.gfsave = self.gflops # check return code (consider not saving parameters) if 0 != result and not final: # incorrect result failed = " ".join(map(str, cfgenv)).replace("OPENCL_LIBSMM_SMM_", "") print("FAILED: {}".format(failed)) return - if final: + if final and os.path.exists(filedot): if not filenames and glob.glob(self.args.csvfile): print( - "WARNING: no JSON file found but {} exists.".format( + "WARNING: no JSON file found but {} will be overwritten.".format( self.args.csvfile ) ) filename = os.path.normpath( os.path.join( self.args.jsondir, - "{}-{}gflops.json".format(self.args.label, round(self.gfsave)), + "{}-{}gflops.json".format(self.args.label, round(self.gflops)), ) ) - if os.path.exists(filedot): - os.rename(filedot, filename) - if filename not in filenames: - filenames.append(filename) - self.merge_jsons(filenames) - speedup = round( - (self.gfsave / self.gfbase) if 0 < self.gfbase else 0, 1 - ) - print( - "Result{} was written to {}".format( - " ({}x over seed)".format(speedup) if 1 < speedup else "", - filename, - ) - ) - else: - print("WARNING: tuned result seems to be incorrect!") - exit(0) + os.rename(filedot, filename) + if filename not in filenames: # rebuild CSV-file + filenames.append(filename) + self.merge_jsons(filenames) + speedup = round((self.gflops / self.gfbase) if 0 < self.gfbase else 0, 1) + msg = " ({}x over seed)".format(speedup) if 1 < speedup else "" + print("Result{} was written to {}".format(msg, filename)) + else: + print("WARNING: tuned result seems to be incorrect!") + exit(0) def handle_sigint(self, signum, frame): """Handle SIGINT or CTRL-C""" @@ -780,6 +799,6 @@ def handle_sigint(self, signum, frame): try: SmmTuner.main(args) except Exception as e: - print("ERROR {}: {}!".format(type(e).__name__, e)) + print("{}: {}".format(type(e).__name__, e)) print("WARNING: ignored above error!") pass