Skip to content

Commit

Permalink
add arguments to superseed the config files
Browse files Browse the repository at this point in the history
  • Loading branch information
agrouaze committed Feb 5, 2024
1 parent 72084a2 commit 63eb4a5
Showing 1 changed file with 57 additions and 22 deletions.
79 changes: 57 additions & 22 deletions l2awinddirection/l2awinddirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from l2awinddirection.M64RN4 import M64RN4_distribution, M64RN4_regression

conf = get_conf()


def get_memory_usage():
try:
import resource
Expand All @@ -46,9 +48,7 @@ def main():
for handler in root.handlers:
root.removeHandler(handler)
time.sleep(np.random.rand(1, 1)[0][0]) # to avoid issue with mkdir
parser = argparse.ArgumentParser(
description="winddirection_prediction->L2Awinddir"
)
parser = argparse.ArgumentParser(description="winddirection_prediction->L2Awinddir")
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument(
"--overwrite",
Expand All @@ -73,8 +73,28 @@ def main():
required=True,
help="directory where to move and store output netCDF files",
)
parser.add_argument('--remove-tiles',help='remove the tiles files in the workspace directory [default=True]',
required=False,default=True,action='store_false')
parser.add_argument(
"--workspace",
help="specify a workspace for tiles I/O [default is the path in localconfig.yml or config.yml]",
required=False,
)
parser.add_argument(
"--modelpdf",
help="specify the path of the NN model to predict PDF result [default is localconfig.yml or config.yml]",
required=False,
)
parser.add_argument(
"--modelregression",
help="specify the path of the NN model to predict float (regression) result [default is localconfig.yml or config.yml]",
required=False,
)
parser.add_argument(
"--remove-tiles",
help="remove the tiles files in the workspace directory [default=True]",
required=False,
default=True,
action="store_false",
)
args = parser.parse_args()
fmt = "%(asctime)s %(levelname)s %(filename)s(%(lineno)d) %(message)s"
if args.verbose:
Expand All @@ -93,19 +113,27 @@ def main():
input_shape = (44, 44, 1)
data_augmentation = True
learning_rate = 1e-3
logging.info('the mode (model) chosen is: %s',args.mode)
logging.info("the mode (model) chosen is: %s", args.mode)
if args.modelpdf is None:
path_model_pdf = conf["model_path_pdf"]
else:
path_model_pdf = args.modelpdf
if args.modelregression is None:
dirmodelsreg = conf["models_path_regression"]
else:
dirmodelsreg = args.modelregression
if args.mode == "pdf":
n_classes = 36
model_m64rn4 = M64RN4_distribution(input_shape, data_augmentation, n_classes)
model_m64rn4.create_and_compile(learning_rate)
path_model = conf["model_path_pdf"]
model_m64rn4.model.load_weights(path_model)

model_m64rn4.model.load_weights(path_model_pdf)
elif args.mode == "regression":
model_m64rn4 = []
# path_best_models = glob.glob(
# ".../analysis/s1_data_analysis/project_rmarquar/wsat/trained_models/iw/*.hdf5"
# )
path_best_models = glob.glob(os.path.join(conf['models_path_regression'],"*.hdf5"))
path_best_models = glob.glob(os.path.join(dirmodelsreg, "*.hdf5"))
for path in path_best_models:

m64rn4_reg = M64RN4_regression(input_shape, data_augmentation)
Expand All @@ -117,13 +145,20 @@ def main():
else:
raise Exception("not handled case")
# files = glob.glob("/raid/localscratch/agrouaze/tiles_iw_4_wdir/3.1/*SAFE/*.nc")
logging.info('workspace where the tiles will be temporarily moved is: %s',conf['workspace_prediction'])
safefile = os.path.join(conf['workspace_prediction'],os.path.basename(args.l2awindirtilessafe))
logging.info(' step 1: move %s -> %s',args.l2awindirtilessafe,safefile)
shutil.move(args.l2awindirtilessafe,safefile)
if args.workspace is None:
workspace = conf["workspace_prediction"]
else:
workspace = args.workspace
logging.info(
"workspace where the tiles will be temporarily moved is: %s",
workspace,
)
safefile = os.path.join(workspace, os.path.basename(args.l2awindirtilessafe))
logging.info(" step 1: move %s -> %s", args.l2awindirtilessafe, safefile)
shutil.move(args.l2awindirtilessafe, safefile)
files = glob.glob(os.path.join(safefile, "*.nc"))
logging.info("Number of files to process: %d" % len(files))
logging.info('step 2: predictions')
logging.info("step 2: predictions")
for ii in tqdm(range(len(files))):
file = files[ii]
# for file in files:
Expand All @@ -144,17 +179,17 @@ def main():
# if you want to remove the file containing the tiles used for inference (data kept in final product)
del tiles
if args.remove_tiles:
logging.info('remove temporary tiles file in the workspace: %s',file)
logging.info("remove temporary tiles file in the workspace: %s", file)
os.remove(file)
final_safe_path = os.path.join(args.outpurdir,os.path.basename(args.l2awindirtilessafe))
final_safe_path = os.path.join(
args.outpurdir, os.path.basename(args.l2awindirtilessafe)
)
if not os.path.exists(os.path.dirname(final_safe_path)):
logging.info('mkdir %s',os.path.dirname(final_safe_path))
logging.info("mkdir %s", os.path.dirname(final_safe_path))
os.makedirs(os.path.dirname(final_safe_path))
logging.info('step 3: move %s -> %s',safefile,final_safe_path)
shutil.move(safefile,final_safe_path)
logging.info(
"Ifremer Level-2A wind direction SAFE path: %s", final_safe_path
)
logging.info("step 3: move %s -> %s", safefile, final_safe_path)
shutil.move(safefile, final_safe_path)
logging.info("Ifremer Level-2A wind direction SAFE path: %s", final_safe_path)
logging.info("successful SAFE processing")
logging.info("peak memory usage: %s ", get_memory_usage())
logging.info("done in %1.3f min", (time.time() - t0) / 60.0)

0 comments on commit 63eb4a5

Please sign in to comment.