From 85e3f79f317b4725ff6f1b5dc0f99cace4b9a5d5 Mon Sep 17 00:00:00 2001 From: Shivam Sharma Date: Fri, 29 Nov 2024 09:52:58 +0100 Subject: [PATCH] Add query generation based on the mode selected Issue #281 --- dicee/query_generator.py | 72 +++++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/dicee/query_generator.py b/dicee/query_generator.py index f9e1cccd..60b96fd1 100644 --- a/dicee/query_generator.py +++ b/dicee/query_generator.py @@ -9,10 +9,11 @@ class QueryGenerator: - def __init__(self, train_path, val_path: str, test_path: str, ent2id: Dict = None, rel2id: Dict = None, + def __init__(self, train_path, val_path: str = None, test_path: str = None, ent2id: Dict = None, rel2id: Dict = None, seed: int = 1, gen_valid: bool = False, - gen_test: bool = True): + gen_test: bool = True, + mode: str = "train"): self.train_path = train_path self.val_path = val_path @@ -24,7 +25,7 @@ def __init__(self, train_path, val_path: str, test_path: str, ent2id: Dict = Non self.max_ans_num = 1e6 - self.mode = str + self.mode = mode self.ent2id = ent2id self.rel2id: Dict = rel2id self.ent_in: Dict = {} @@ -81,7 +82,12 @@ def construct_graph(self, paths: List[str]) -> Tuple[Dict, Dict]: for path in paths: with open(path, "r") as f: for line in f: - h, r, t = map(str, line.strip().split("\t")) + try: + h, r, t = map(str, line.strip().split("\t")) + except: + h, r, t, _ = map(str, line.strip().split(" ")) + if t.startswith('"'): + continue # Skip literals tail_relation_to_heads[self.ent2id[t]][self.rel2id[r]].add(self.ent2id[h]) head_relation_to_tails[self.ent2id[h]][self.rel2id[r]].add(self.ent2id[t]) @@ -447,28 +453,54 @@ def generate_queries(self, query_struct:List, gen_num: int, query_type: str): and getting queries and answers in return @ TODO: create a class for each single query struct """ + + if not self.train_path: + raise ValueError("Training path (train_path) is empty. It must be specified.") + + if self.mode == "train": + tail_relation_to_heads, head_relation_to_tails = self.construct_graph(paths=[self.train_path]) + val_tail_relation_to_heads, val_head_relation_to_tails = {}, {} # No validation data + elif self.mode == 'valid': + # Check if val_path is not empty + if not self.val_path: + raise ValueError("Validation path (val_path) is empty. It must be specified for 'valid' mode.") + + # Use training and validation data + tail_relation_to_heads, head_relation_to_tails = self.construct_graph(paths=[self.train_path, self.val_path]) + val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph(paths=[self.val_path]) + elif self.mode == 'test': + # Check if val_path and test_path are not empty + if not self.val_path: + raise ValueError("Validation path (val_path) is empty. It must be specified for 'test' mode.") + if not self.test_path: + raise ValueError("Test path (test_path) is empty. It must be specified for 'test' mode.") + + # Use all data for constructing the graph + tail_relation_to_heads, head_relation_to_tails = self.construct_graph( + paths=[self.train_path, self.val_path, self.test_path]) + val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph( + paths=[self.val_path, self.test_path]) + else: + raise ValueError(f"Unknown mode '{self.mode}'. Mode must be 'train', 'valid', or 'test'.") + + # Ground the queries using the constructed graphs + queries, tp_answers, fp_answers, fn_answers = self.ground_queries( + query_struct, + tail_relation_to_heads, + head_relation_to_tails, + val_tail_relation_to_heads, + val_head_relation_to_tails, + gen_num, + query_type + ) - train_tail_relation_to_heads, train_head_relation_to_tails = self.construct_graph(paths=[self.train_path]) - val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph( - paths=[self.train_path, self.val_path]) - # ?! - valid_only_ent_in, valid_only_ent_out = self.construct_graph(paths=[self.val_path, self.test_path]) - - test_tail_relation_to_heads, test_head_relation_to_tails = self.construct_graph( - paths=[self.train_path, self.val_path, self.test_path]) - # ?! - test_only_ent_in, test_only_ent_out = self.construct_graph(paths=[self.test_path]) - self.mode = 'test' - test_queries, test_tp_answers, test_fp_answers, test_fn_answers = self.ground_queries( - query_struct, test_tail_relation_to_heads, test_head_relation_to_tails, val_tail_relation_to_heads, - val_head_relation_to_tails, gen_num, query_type) # @TODO: test_queries has keys that are tuple ,e.g. ('e', ('r',)) # Yet, query structure defined as a list ['e', ['r']]. # Fix this inconsistency print( - f"General structure is {query_struct} with name {query_type}. Number of queries generated: {len(test_tp_answers)}") - return test_queries, test_tp_answers, test_fp_answers, test_fn_answers + f"General structure is {query_struct} with name {query_type}. Number of queries generated: {len(tp_answers)}") + return queries, tp_answers, fp_answers, fn_answers def save_queries(self, query_type: str, gen_num: int, save_path: str): """