Skip to content

Commit

Permalink
ocl: improved tuning script
Browse files Browse the repository at this point in the history
* Backup dotfile if it exists stemming from an earlier session.
* Adjusted default DB-location for OpenTuner.
  • Loading branch information
hfp committed Oct 27, 2023
1 parent 87ff3c8 commit 37640ad
Showing 1 changed file with 43 additions and 24 deletions.
67 changes: 43 additions & 24 deletions src/acc/opencl/smm/tune_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from opentuner import MeasurementInterface
from opentuner import Result
from signal import signal, SIGINT
import tempfile
import socket
import shutil
import copy
import json
import glob
Expand Down Expand Up @@ -43,6 +46,7 @@ def ilog2(n):
class SmmTuner(MeasurementInterface):
def manipulator(self):
"""Setup common state and define search space"""
dbdir = os.path.join(tempfile.gettempdir(), "opentuner.db")
manipulator = ConfigurationManipulator()
# parse and sanitize kernel shape argument
if not self.args.mnk:
Expand Down Expand Up @@ -167,8 +171,15 @@ def manipulator(self):
and (self.device and "" != self.device)
and (self.size and 0 < self.size)
and self.typeid
): # construct label used for the database session
if not self.args.label: # consider to include self.device
): # setup database (DB)
if args.database is None: # adjust DB-location
if os.path.isdir(dbdir):
shutil.rmtree(dbdir)
os.mkdir(dbdir)
self.args.database = "sqlite:///" + os.path.join(
dbdir, socket.gethostname() + ".db"
)
if not self.args.label: # label for DB-session
self.args.label = "{}-{}-{}x{}x{}-s{}".format(
default_basename,
self.typename,
Expand Down Expand Up @@ -450,7 +461,21 @@ def save_final_config(self, configuration, final=True):
)
filedot = os.path.join(self.args.jsondir, ".{}.json".format(self.args.label))
if self.gfsave < self.gflops: # save intermediate result
self.gfsave = self.gflops
if 0 == self.gfsave and os.path.exists(filedot): # backup
data = None
try:
with open(filedot, "r") as file:
data = json.load(file)
except: # noqa: E722
pass
gflops = data["GFLOPS"] if data and "GFLOPS" in data else 0
filename = os.path.join(
self.args.jsondir,
"{}-{}gflops.json".format(self.args.label, round(gflops))
if 0 < gflops
else "{}.json".format(self.args.label),
)
os.rename(filedot, filename)
# self.manipulator().save_to_file(config, filename)
with open(filedot, "w") as file:
cfg = config
Expand All @@ -459,41 +484,35 @@ def save_final_config(self, configuration, final=True):
del cfg["XF"]
json.dump(cfg, file, sort_keys=True)
file.write("\n") # append newline at EOF
self.gfsave = self.gflops
# check return code (consider not saving parameters)
if 0 != result and not final: # incorrect result
failed = " ".join(map(str, cfgenv)).replace("OPENCL_LIBSMM_SMM_", "")
print("FAILED: {}".format(failed))
return
if final:
if final and os.path.exists(filedot):
if not filenames and glob.glob(self.args.csvfile):
print(
"WARNING: no JSON file found but {} exists.".format(
"WARNING: no JSON file found but {} will be overwritten.".format(
self.args.csvfile
)
)
filename = os.path.normpath(
os.path.join(
self.args.jsondir,
"{}-{}gflops.json".format(self.args.label, round(self.gfsave)),
"{}-{}gflops.json".format(self.args.label, round(self.gflops)),
)
)
if os.path.exists(filedot):
os.rename(filedot, filename)
if filename not in filenames:
filenames.append(filename)
self.merge_jsons(filenames)
speedup = round(
(self.gfsave / self.gfbase) if 0 < self.gfbase else 0, 1
)
print(
"Result{} was written to {}".format(
" ({}x over seed)".format(speedup) if 1 < speedup else "",
filename,
)
)
else:
print("WARNING: tuned result seems to be incorrect!")
exit(0)
os.rename(filedot, filename)
if filename not in filenames: # rebuild CSV-file
filenames.append(filename)
self.merge_jsons(filenames)
speedup = round((self.gflops / self.gfbase) if 0 < self.gfbase else 0, 1)
msg = " ({}x over seed)".format(speedup) if 1 < speedup else ""
print("Result{} was written to {}".format(msg, filename))
else:
print("WARNING: tuned result seems to be incorrect!")
exit(0)

def handle_sigint(self, signum, frame):
"""Handle SIGINT or CTRL-C"""
Expand Down Expand Up @@ -780,6 +799,6 @@ def handle_sigint(self, signum, frame):
try:
SmmTuner.main(args)
except Exception as e:
print("ERROR {}: {}!".format(type(e).__name__, e))
print("{}: {}".format(type(e).__name__, e))
print("WARNING: ignored above error!")
pass

0 comments on commit 37640ad

Please sign in to comment.