Skip to content

Commit

Permalink
Fixed log prefix issue when splitting the output
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-drielsma committed Oct 4, 2024
1 parent 208d086 commit d6de450
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions spine/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def __init__(self, cfg, rank=None):
assert self.model is None or self.unwrap, (
"Must unwrap the model output to run analysis scripts.")
self.watch.initialize('ana')
self.ana = AnaManager(ana, log_dir=self.log_dir, prefix=self.prefix)
self.ana = AnaManager(
ana, log_dir=self.log_dir, prefix=self.log_prefix)

def __len__(self):
"""Returns the number of events in the underlying reader object."""
Expand Down Expand Up @@ -362,7 +363,8 @@ def initialize_io(self, loader=None, reader=None, writer=None):
self.iter_per_epoch = len(self.reader)

# Fetch an appropriate common prefix for all input files
self.prefix = self.get_prefix(self.reader.file_paths, self.split_output)
self.log_prefix, self.output_prefix = self.get_prefixes(
self.reader.file_paths, self.split_output)

# Initialize the data writer, if provided
self.writer = None
Expand All @@ -371,7 +373,7 @@ def initialize_io(self, loader=None, reader=None, writer=None):
"Must unwrap the model output to write it to file.")
self.watch.initialize('write')
self.writer = writer_factory(
writer, prefix=self.prefix, split=self.split_output)
writer, prefix=self.output_prefix, split=self.split_output)

# Harmonize the iterations and epochs parameters
assert (self.iterations is None) or (self.epochs is None), (
Expand All @@ -384,7 +386,7 @@ def initialize_io(self, loader=None, reader=None, writer=None):
self.iterations = self.epochs*self.iter_per_epoch

@staticmethod
def get_prefix(file_paths, split_output):
def get_prefixes(file_paths, split_output):
"""Builds an appropriate output prefix based on the list of input files.
Parameters
Expand All @@ -399,10 +401,6 @@ def get_prefix(file_paths, split_output):
Union[str, List[str]]
Shared input summary string to be used to prefix outputs
"""
# If the output is to be split, use the basename of each file
if split_output:
return [os.path.splitext(os.path.basename(f))[0] for f in file_paths]

# Fetch file base names (ignore where they live)
file_names = [os.path.basename(f) for f in file_paths]

Expand All @@ -420,8 +418,14 @@ def get_prefix(file_paths, split_output):
last = last[0] if last[0] and last[0][0] != '.' else ''

suffix = f'{first}--{len(file_names)-2}--{last}'
log_prefix = prefix + suffix

return prefix + suffix
# Always provide a single prefix for the log, adapt output prefix
if not split_output:
return log_prefix, log_prefix
else:
return (log_prefix,
[os.path.splitext(name)[0] for name in file_names])

def initialize_log(self):
"""Initialize the output log for this driver process."""
Expand All @@ -443,7 +447,7 @@ def initialize_log(self):

# If requested, prefix the log name with the input file name
if self.prefix_log:
log_name = f'{self.prefix}_{log_name}'
log_name = f'{self.log_prefix}_{log_name}'

# Initialize the log
log_path = os.path.join(self.log_dir, log_name)
Expand Down

0 comments on commit d6de450

Please sign in to comment.