diff --git a/hippynn/databases/SNAPJson.py b/hippynn/databases/SNAPJson.py index 717c1078..c9cb208b 100644 --- a/hippynn/databases/SNAPJson.py +++ b/hippynn/databases/SNAPJson.py @@ -25,6 +25,7 @@ def __init__( transpose_cell=True, allow_unfound=False, quiet=False, + comments=1, **kwargs, ): @@ -34,6 +35,7 @@ def __init__( self.targets = targets self.transpose_cell = transpose_cell self.depth = depth + self.comments = comments arr_dict = self.load_arrays(quiet=quiet, allow_unfound=allow_unfound) super().__init__(arr_dict, inputs, targets, *args, **kwargs, allow_unfound=allow_unfound, quiet=quiet) @@ -96,7 +98,8 @@ def filter_arrays(self, arr_dict, allow_unfound=False, quiet=False): def extract_snap_file(self, file): with open(file, "rt") as jf: - comment = jf.readline() + for i in range(self.comments): + comment = jf.readline() content = jf.read() parsed = json.loads(content) dataset = parsed["Dataset"]