Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ocl: improved tuning script #729

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 43 additions & 24 deletions src/acc/opencl/smm/tune_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from opentuner import MeasurementInterface
from opentuner import Result
from signal import signal, SIGINT
import tempfile
import socket
import shutil
import copy
import json
import glob
Expand Down Expand Up @@ -43,6 +46,7 @@ def ilog2(n):
class SmmTuner(MeasurementInterface):
def manipulator(self):
"""Setup common state and define search space"""
dbdir = os.path.join(tempfile.gettempdir(), "opentuner.db")
manipulator = ConfigurationManipulator()
# parse and sanitize kernel shape argument
if not self.args.mnk:
Expand Down Expand Up @@ -167,8 +171,15 @@ def manipulator(self):
and (self.device and "" != self.device)
and (self.size and 0 < self.size)
and self.typeid
): # construct label used for the database session
if not self.args.label: # consider to include self.device
): # setup database (DB)
if args.database is None: # adjust DB-location
if os.path.isdir(dbdir):
shutil.rmtree(dbdir)
os.mkdir(dbdir)
self.args.database = "sqlite:///" + os.path.join(
dbdir, socket.gethostname() + ".db"
)
if not self.args.label: # label for DB-session
self.args.label = "{}-{}-{}x{}x{}-s{}".format(
default_basename,
self.typename,
Expand Down Expand Up @@ -450,7 +461,21 @@ def save_final_config(self, configuration, final=True):
)
filedot = os.path.join(self.args.jsondir, ".{}.json".format(self.args.label))
if self.gfsave < self.gflops: # save intermediate result
self.gfsave = self.gflops
if 0 == self.gfsave and os.path.exists(filedot): # backup
data = None
try:
with open(filedot, "r") as file:
data = json.load(file)
except: # noqa: E722
pass
gflops = data["GFLOPS"] if data and "GFLOPS" in data else 0
filename = os.path.join(
self.args.jsondir,
"{}-{}gflops.json".format(self.args.label, round(gflops))
if 0 < gflops
else "{}.json".format(self.args.label),
)
os.rename(filedot, filename)
# self.manipulator().save_to_file(config, filename)
with open(filedot, "w") as file:
cfg = config
Expand All @@ -459,41 +484,35 @@ def save_final_config(self, configuration, final=True):
del cfg["XF"]
json.dump(cfg, file, sort_keys=True)
file.write("\n") # append newline at EOF
self.gfsave = self.gflops
# check return code (consider not saving parameters)
if 0 != result and not final: # incorrect result
failed = " ".join(map(str, cfgenv)).replace("OPENCL_LIBSMM_SMM_", "")
print("FAILED: {}".format(failed))
return
if final:
if final and os.path.exists(filedot):
if not filenames and glob.glob(self.args.csvfile):
print(
"WARNING: no JSON file found but {} exists.".format(
"WARNING: no JSON file found but {} will be overwritten.".format(
self.args.csvfile
)
)
filename = os.path.normpath(
os.path.join(
self.args.jsondir,
"{}-{}gflops.json".format(self.args.label, round(self.gfsave)),
"{}-{}gflops.json".format(self.args.label, round(self.gflops)),
)
)
if os.path.exists(filedot):
os.rename(filedot, filename)
if filename not in filenames:
filenames.append(filename)
self.merge_jsons(filenames)
speedup = round(
(self.gfsave / self.gfbase) if 0 < self.gfbase else 0, 1
)
print(
"Result{} was written to {}".format(
" ({}x over seed)".format(speedup) if 1 < speedup else "",
filename,
)
)
else:
print("WARNING: tuned result seems to be incorrect!")
exit(0)
os.rename(filedot, filename)
if filename not in filenames: # rebuild CSV-file
filenames.append(filename)
self.merge_jsons(filenames)
speedup = round((self.gflops / self.gfbase) if 0 < self.gfbase else 0, 1)
msg = " ({}x over seed)".format(speedup) if 1 < speedup else ""
print("Result{} was written to {}".format(msg, filename))
else:
print("WARNING: tuned result seems to be incorrect!")
exit(0)

def handle_sigint(self, signum, frame):
"""Handle SIGINT or CTRL-C"""
Expand Down Expand Up @@ -780,6 +799,6 @@ def handle_sigint(self, signum, frame):
try:
SmmTuner.main(args)
except Exception as e:
print("ERROR {}: {}!".format(type(e).__name__, e))
print("{}: {}".format(type(e).__name__, e))
print("WARNING: ignored above error!")
pass