-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinvariant_finder.py
executable file
·162 lines (140 loc) · 6.23 KB
/
invariant_finder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#! /usr/bin/env python
from __future__ import print_function
from collections import deque, defaultdict
import itertools
import time
import invariants
import pddl
import timers
class BalanceChecker(object):
    """Per-action lookup tables consulted during invariant balance checks.

    Built once from a task, an instance records

    * ``predicates_to_add_actions``: predicate -> set of actions that have
      at least one non-negated effect on that predicate, and
    * ``action_to_heavy_action``: action -> "heavy" variant of the action,
      in which every effect with quantified parameters is listed twice.

    Effects use the FOND layout: ``action.effects`` is an iterable of
    nondeterministic choices, each of which is an iterable of effects.
    """

    def __init__(self, task, reachable_action_params):
        """Index every action in ``task.actions``.

        ``reachable_action_params`` is either ``None`` or a mapping from
        action to an iterable of parameter tuples (indexed by parameter
        position); it is only passed on to ``add_inequality_preconds``.
        """
        self.predicates_to_add_actions = defaultdict(set)
        self.action_to_heavy_action = {}
        for act in task.actions:
            # All bookkeeping below is keyed on the (possibly rewritten)
            # action returned here, not on the original ``act``.
            action = self.add_inequality_preconds(act, reachable_action_params)
            too_heavy_effects = []
            create_heavy_act = False
            # FOND: handle each nondeterministic choice separately so the
            # heavy action keeps the same choice structure.
            for nondet_choice in action.effects:
                too_heavy_nondet_choice = []
                for eff in nondet_choice:
                    too_heavy_nondet_choice.append(eff)
                    if eff.parameters:  # universal effect
                        create_heavy_act = True
                        too_heavy_nondet_choice.append(eff.copy())
                    if not eff.literal.negated:
                        predicate = eff.literal.predicate
                        self.predicates_to_add_actions[predicate].add(action)
                too_heavy_effects.append(too_heavy_nondet_choice)
            if create_heavy_act:
                # heavy_act: duplicated universal effects and assigned unique
                # names to all quantified variables (implicitly in
                # constructor)
                heavy_act = pddl.Action(action.name, action.parameters,
                                        action.num_external_parameters,
                                        action.precondition,
                                        too_heavy_effects,
                                        action.cost)
            else:
                # Nothing was duplicated, so the action is its own heavy
                # variant.
                heavy_act = action
            self.action_to_heavy_action[action] = heavy_act

    def get_threats(self, predicate):
        """Return the set of actions with an add effect on *predicate*.

        Uses ``.get`` so that querying an unknown predicate returns an empty
        set without inserting a new entry into the defaultdict.
        """
        return self.predicates_to_add_actions.get(predicate, set())

    def get_heavy_action(self, action):
        """Return the heavy variant stored for *action*.

        Raises ``KeyError`` if *action* was not indexed by the constructor.
        """
        return self.action_to_heavy_action[action]

    def add_inequality_preconds(self, action, reachable_action_params):
        """Return *action* strengthened with parameter inequalities.

        For every pair of parameter positions whose values differ in every
        tuple of ``reachable_action_params[action]``, a negated ``=`` atom
        over the two parameter names is conjoined to the precondition and a
        new ``pddl.Action`` is returned.  *action* itself is returned
        unchanged when ``reachable_action_params`` is ``None``, when the
        action has fewer than two parameters, or when no such pair exists.
        """
        if reachable_action_params is None or len(action.parameters) < 2:
            return action

        # Looked up once; the result is re-iterated for every position pair.
        reachable_params = reachable_action_params[action]
        inequal_params = []
        combs = itertools.combinations(range(len(action.parameters)), 2)
        for pos1, pos2 in combs:
            if not any(params[pos1] == params[pos2]
                       for params in reachable_params):
                inequal_params.append((pos1, pos2))

        if not inequal_params:
            return action

        precond_parts = [action.precondition]
        for pos1, pos2 in inequal_params:
            param1 = action.parameters[pos1].name
            param2 = action.parameters[pos2].name
            precond_parts.append(pddl.NegatedAtom("=", (param1, param2)))
        precond = pddl.Conjunction(precond_parts).simplified()
        return pddl.Action(
            action.name, action.parameters, action.num_external_parameters,
            precond, action.effects, action.cost)
def get_fluents(task):
    """Return the predicates of *task* that occur in some action effect.

    The result is the sub-list of ``task.predicates`` (original order kept)
    whose ``name`` equals the ``literal.predicate`` of at least one effect
    of at least one action.
    """
    fluent_names = set()
    for action in task.actions:
        # FOND: ``action.effects`` holds one effect list per
        # nondeterministic choice.
        for nondet_choice in action.effects:
            fluent_names.update(eff.literal.predicate
                                for eff in nondet_choice)
    return [pred for pred in task.predicates if pred.name in fluent_names]
def get_initial_invariants(task):
    """Yield the single-part invariant candidates for every fluent of *task*.

    For each predicate returned by ``get_fluents`` one candidate is produced
    per choice of omitted argument position, plus one with ``omitted_arg``
    set to ``-1``, for which no position is left out of ``order``.
    """
    for predicate in get_fluents(task):
        all_args = list(range(len(predicate.arguments)))
        for omitted_arg in [-1] + all_args:
            # ``order`` lists the argument positions that are kept.
            order = [i for i in all_args if i != omitted_arg]
            part = invariants.InvariantPart(predicate.name, order, omitted_arg)
            yield invariants.Invariant((part,))
# A grounded input file can give rise to a very large number of invariant
# candidates, so both the candidate count and the search time are capped.
# No new candidate is enqueued once this many candidates have been seen.
MAX_CANDIDATES = 10 ** 5
# Time budget in seconds; it is compared against time.time() differences.
MAX_TIME = 5 * 60
def find_invariants(task, reachable_action_params):
    """Yield every invariant candidate for which ``check_balance`` is true.

    Candidates are processed first-in, first-out, starting from those
    produced by ``get_initial_invariants``.  ``check_balance`` receives a
    callback through which it can enqueue further candidates; a candidate is
    only accepted if it has not been seen before and fewer than
    ``MAX_CANDIDATES`` candidates have been seen so far.

    Once more than ``MAX_TIME`` seconds have elapsed, a message is printed
    and the generator stops; the candidate dequeued at that point is not
    checked.
    """
    candidates = deque(get_initial_invariants(task))
    print(len(candidates), "initial candidates")
    seen_candidates = set(candidates)

    balance_checker = BalanceChecker(task, reachable_action_params)

    def enqueue_func(invariant):
        # Closure over the queue and the seen-set of this search.
        if (len(seen_candidates) < MAX_CANDIDATES
                and invariant not in seen_candidates):
            candidates.append(invariant)
            seen_candidates.add(invariant)

    start_time = time.time()
    while candidates:
        candidate = candidates.popleft()
        if time.time() - start_time > MAX_TIME:
            print("Time limit reached, aborting invariant generation")
            return
        if candidate.check_balance(balance_checker, enqueue_func):
            yield candidate
def useful_groups(invariants, initial_facts):
    """Yield instantiated groups matched by exactly one initial fact.

    Every entry of *initial_facts* that is not a ``pddl.Assign`` is matched
    against the invariants mentioning its predicate.  A group is identified
    by the pair ``(invariant, parameters)``, where the parameters come from
    ``invariant.get_parameters(atom)``.  A group matched by a single entry
    is yielded as the list of its parts (in sorted order) instantiated with
    those parameters; groups matched by no entry or by several are skipped.
    """
    predicate_to_invariants = defaultdict(list)
    for invariant in invariants:
        for predicate in invariant.predicates:
            predicate_to_invariants[predicate].append(invariant)

    nonempty_groups = set()
    overcrowded_groups = set()
    for atom in initial_facts:
        if isinstance(atom, pddl.Assign):
            continue
        for invariant in predicate_to_invariants.get(atom.predicate, ()):
            group_key = (invariant, tuple(invariant.get_parameters(atom)))
            if group_key in nonempty_groups:
                # Second hit for the same group: more than one initial fact.
                overcrowded_groups.add(group_key)
            else:
                nonempty_groups.add(group_key)

    for invariant, parameters in nonempty_groups - overcrowded_groups:
        yield [part.instantiate(parameters)
               for part in sorted(invariant.parts)]
def get_groups(task, reachable_action_params=None):
    """Return the list of useful groups for *task*.

    The invariants produced by ``find_invariants`` are sorted and then
    filtered through ``useful_groups`` against ``task.init``.  Both phases
    run inside a ``timers.timing`` context.
    """
    with timers.timing("Finding invariants", block=True):
        # Named so that it does not shadow the imported ``invariants`` module.
        found_invariants = sorted(
            find_invariants(task, reachable_action_params))
    with timers.timing("Checking invariant weight"):
        result = list(useful_groups(found_invariants, task.init))
    return result
if __name__ == "__main__":
    # Stand-alone driver: parse and normalize a task, then print the
    # invariants found and the fact groups derived from them.
    import normalize

    print("Parsing...")
    task = pddl.open()
    print("Normalizing...")
    normalize.normalize(task)
    print("Finding invariants...")
    print("NOTE: not passing in reachable_action_params.")
    print("This means fewer invariants might be found.")
    for invariant in find_invariants(task, None):
        print(invariant)
    print("Finding fact groups...")
    # get_groups runs find_invariants again internally.
    groups = get_groups(task)
    for group in groups:
        print("[%s]" % ", ".join(map(str, group)))