-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgreedy_join.py
executable file
·112 lines (103 loc) · 4.3 KB
/
greedy_join.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from __future__ import print_function
import sys
import pddl
import pddl_to_prolog
class OccurrencesTracker(object):
    """Count how many times each variable occurs in a rule's atoms.

    An argument counts as a variable when its first character is "?".
    Counting starts from the rule's effect plus all of its conditions, and a
    variable is removed from the table as soon as its count drops to zero.
    """

    def __init__(self, rule):
        self.occurrences = {}  # variable name -> current occurrence count
        tracked_atoms = [rule.effect]
        tracked_atoms.extend(rule.conditions)
        for atom in tracked_atoms:
            self.update(atom, +1)

    def update(self, symatom, delta):
        """Add *delta* to the count of every variable argument of *symatom*.

        A variable that appears several times in *symatom* is adjusted once
        per appearance.
        """
        counts = self.occurrences
        for arg in symatom.args:
            if arg[0] != "?":
                continue  # non-variable arguments are not tracked
            new_count = counts.get(arg, 0) + delta
            counts[arg] = new_count
            # Removing more occurrences than were added is a caller error.
            assert new_count >= 0
            if new_count == 0:
                del counts[arg]

    def variables(self):
        """Return the set of variables that currently have a nonzero count."""
        return set(self.occurrences.keys())
class CostMatrix(object):
    """Lower-triangular matrix of pairwise join costs between joinees.

    Row ``i`` of ``cost_matrix`` holds the costs of joining ``joinees[i]``
    with each of ``joinees[0]`` .. ``joinees[i - 1]``, so every stored entry
    ``cost_matrix[i][j]`` has ``i > j`` and the diagonal is not stored.
    """

    def __init__(self, joinees):
        self.joinees = []
        self.cost_matrix = []
        for joinee in joinees:
            self.add_entry(joinee)

    def add_entry(self, joinee):
        """Append *joinee* with a new row holding its cost against every
        joinee already in the matrix."""
        new_row = [self.compute_join_cost(joinee, other)
                   for other in self.joinees]
        self.cost_matrix.append(new_row)
        self.joinees.append(joinee)

    def delete_entry(self, index):
        """Remove the joinee at *index* together with its row and column."""
        # Only rows after *index* have a column for it (lower-triangular).
        for row in self.cost_matrix[index + 1:]:
            del row[index]
        del self.cost_matrix[index]
        del self.joinees[index]

    def find_min_pair(self):
        """Return ``(i, j)`` (with ``i > j``) of the cheapest pair to join.

        Among pairs of equal cost, the one appearing first in row-major order
        is chosen.
        """
        assert len(self.joinees) >= 2
        # Comparing (cost, i, j) triples picks the minimal cost and, on ties,
        # the lexicographically smallest (i, j) -- i.e. the first pair in
        # row-major order. This replaces a magic sentinel that had to stay
        # comparable with the cost tuples and could leave the indices unbound.
        _, left_index, right_index = min(
            (cost, i, j)
            for i, row in enumerate(self.cost_matrix)
            for j, cost in enumerate(row))
        return left_index, right_index

    def remove_min_pair(self):
        """Remove the cheapest pair from the matrix and return it as
        ``(left, right)``, where ``left`` had the larger index."""
        left_index, right_index = self.find_min_pair()
        left, right = self.joinees[left_index], self.joinees[right_index]
        assert left_index > right_index
        # Delete the larger index first so the smaller one is not shifted.
        self.delete_entry(left_index)
        self.delete_entry(right_index)
        return (left, right)

    def compute_join_cost(self, left_joinee, right_joinee):
        """Return the cost tuple for joining two joinees.

        With ``small``/``large`` being the variable sets (as returned by
        ``pddl_to_prolog.get_variables``) of the joinee with fewer and more
        variables respectively, the result is
        ``(|small - shared|, |large - shared|, -|shared|)``.
        ``find_min_pair`` picks the lexicographically smallest tuple, so
        pairs adding few unshared variables and sharing many are preferred.
        """
        left_vars = pddl_to_prolog.get_variables([left_joinee])
        right_vars = pddl_to_prolog.get_variables([right_joinee])
        if len(left_vars) > len(right_vars):
            left_vars, right_vars = right_vars, left_vars
        common_vars = left_vars & right_vars
        return (len(left_vars) - len(common_vars),
                len(right_vars) - len(common_vars),
                -len(common_vars))

    def can_join(self):
        """Return True while at least two joinees remain."""
        return len(self.joinees) >= 2
class ResultList(object):
    """Collects the rules emitted while decomposing a single rule.

    The most recently added rule is treated as the final one: get_result()
    gives it the effect of the rule that was decomposed.
    """

    def __init__(self, rule, name_generator):
        self.final_effect = rule.effect
        self.name_generator = name_generator
        self.result = []

    def get_result(self):
        """Return the collected rules after giving the last one the original
        rule's effect."""
        last_rule = self.result[-1]
        last_rule.effect = self.final_effect
        return self.result

    def add_rule(self, type, conditions, effect_vars):
        """Append a rule tagged with *type* and return its effect atom.

        The effect atom's predicate name is drawn from ``name_generator`` via
        ``next()``; its arguments are *effect_vars*.
        """
        new_rule = pddl_to_prolog.Rule(
            conditions, pddl.Atom(next(self.name_generator), effect_vars))
        new_rule.type = type
        self.result.append(new_rule)
        return new_rule.effect
def greedy_join(rule, name_generator):
    """Split *rule* into a sequence of binary "join" and unary "project" rules.

    Each round removes the cheapest pair of joinees (per CostMatrix) and joins
    them into a new intermediate atom. Before the join, a joinee is projected
    onto only the variables that are either shared with its partner or still
    occur elsewhere (in the rule's effect or in a remaining joinee). Those
    remaining-use variables become the join's effect arguments. The new atom
    re-enters the cost matrix until only one joinee is left.

    Intermediate atom names are drawn from *name_generator* via ``next()``.
    Returns the list of generated rules; the last one is given *rule*'s
    original effect. *rule* must have at least two conditions.
    """
    assert len(rule.conditions) >= 2
    matrix = CostMatrix(rule.conditions)
    tracker = OccurrencesTracker(rule)
    results = ResultList(rule, name_generator)

    def project(atom, keep):
        # Emit a "project" rule only when some of atom's arguments are dropped.
        atom_args = set(atom.args)
        kept = atom_args & keep
        if kept == atom_args:
            return atom
        return results.add_rule("project", [atom], sorted(kept))

    while matrix.can_join():
        first, second = matrix.remove_min_pair()
        # Discount the pair so the tracker only reflects uses outside of it.
        tracker.update(first, -1)
        tracker.update(second, -1)

        first_args = set(first.args)
        second_args = set(second.args)
        shared_args = first_args & second_args
        join_vars = tracker.variables() & (first_args | second_args)
        keep = join_vars | shared_args

        projected = [project(atom, keep) for atom in (first, second)]
        joined = results.add_rule("join", projected, sorted(join_vars))
        matrix.add_entry(joined)
        tracker.update(joined, +1)

    return results.get_result()