Skip to content

Commit

Permalink
Add query generation based on the mode selected
Browse files Browse the repository at this point in the history
Issue #281
  • Loading branch information
sshivam95 committed Nov 29, 2024
1 parent f325b8a commit 85e3f79
Showing 1 changed file with 52 additions and 20 deletions.
72 changes: 52 additions & 20 deletions dicee/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


class QueryGenerator:
def __init__(self, train_path, val_path: str, test_path: str, ent2id: Dict = None, rel2id: Dict = None,
def __init__(self, train_path, val_path: str = None, test_path: str = None, ent2id: Dict = None, rel2id: Dict = None,
seed: int = 1,
gen_valid: bool = False,
gen_test: bool = True):
gen_test: bool = True,
mode: str = "train"):

self.train_path = train_path
self.val_path = val_path
Expand All @@ -24,7 +25,7 @@ def __init__(self, train_path, val_path: str, test_path: str, ent2id: Dict = Non

self.max_ans_num = 1e6

self.mode = str
self.mode = mode
self.ent2id = ent2id
self.rel2id: Dict = rel2id
self.ent_in: Dict = {}
Expand Down Expand Up @@ -81,7 +82,12 @@ def construct_graph(self, paths: List[str]) -> Tuple[Dict, Dict]:
for path in paths:
with open(path, "r") as f:
for line in f:
h, r, t = map(str, line.strip().split("\t"))
try:
h, r, t = map(str, line.strip().split("\t"))
except:
h, r, t, _ = map(str, line.strip().split(" "))
if t.startswith('"'):
continue # Skip literals
tail_relation_to_heads[self.ent2id[t]][self.rel2id[r]].add(self.ent2id[h])
head_relation_to_tails[self.ent2id[h]][self.rel2id[r]].add(self.ent2id[t])

Expand Down Expand Up @@ -447,28 +453,54 @@ def generate_queries(self, query_struct:List, gen_num: int, query_type: str):
and getting queries and answers in return
@ TODO: create a class for each single query struct
"""

if not self.train_path:
raise ValueError("Training path (train_path) is empty. It must be specified.")

if self.mode == "train":
tail_relation_to_heads, head_relation_to_tails = self.construct_graph(paths=[self.train_path])
val_tail_relation_to_heads, val_head_relation_to_tails = {}, {} # No validation data
elif self.mode == 'valid':
# Check if val_path is not empty
if not self.val_path:
raise ValueError("Validation path (val_path) is empty. It must be specified for 'valid' mode.")

# Use training and validation data
tail_relation_to_heads, head_relation_to_tails = self.construct_graph(paths=[self.train_path, self.val_path])
val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph(paths=[self.val_path])
elif self.mode == 'test':
# Check if val_path and test_path are not empty
if not self.val_path:
raise ValueError("Validation path (val_path) is empty. It must be specified for 'test' mode.")
if not self.test_path:
raise ValueError("Test path (test_path) is empty. It must be specified for 'test' mode.")

# Use all data for constructing the graph
tail_relation_to_heads, head_relation_to_tails = self.construct_graph(
paths=[self.train_path, self.val_path, self.test_path])
val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph(
paths=[self.val_path, self.test_path])
else:
raise ValueError(f"Unknown mode '{self.mode}'. Mode must be 'train', 'valid', or 'test'.")

# Ground the queries using the constructed graphs
queries, tp_answers, fp_answers, fn_answers = self.ground_queries(
query_struct,
tail_relation_to_heads,
head_relation_to_tails,
val_tail_relation_to_heads,
val_head_relation_to_tails,
gen_num,
query_type
)


train_tail_relation_to_heads, train_head_relation_to_tails = self.construct_graph(paths=[self.train_path])
val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph(
paths=[self.train_path, self.val_path])
# ?!
valid_only_ent_in, valid_only_ent_out = self.construct_graph(paths=[self.val_path, self.test_path])

test_tail_relation_to_heads, test_head_relation_to_tails = self.construct_graph(
paths=[self.train_path, self.val_path, self.test_path])
# ?!
test_only_ent_in, test_only_ent_out = self.construct_graph(paths=[self.test_path])
self.mode = 'test'
test_queries, test_tp_answers, test_fp_answers, test_fn_answers = self.ground_queries(
query_struct, test_tail_relation_to_heads, test_head_relation_to_tails, val_tail_relation_to_heads,
val_head_relation_to_tails, gen_num, query_type)
# @TODO: test_queries has keys that are tuple ,e.g. ('e', ('r',))
# Yet, query structure defined as a list ['e', ['r']].
# Fix this inconsistency
print(
f"General structure is {query_struct} with name {query_type}. Number of queries generated: {len(test_tp_answers)}")
return test_queries, test_tp_answers, test_fp_answers, test_fn_answers
f"General structure is {query_struct} with name {query_type}. Number of queries generated: {len(tp_answers)}")
return queries, tp_answers, fp_answers, fn_answers

def save_queries(self, query_type: str, gen_num: int, save_path: str):
"""
Expand Down

0 comments on commit 85e3f79

Please sign in to comment.