forked from LLNL/AutoCog
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathautomaton.py
60 lines (51 loc) · 2.53 KB
/
automaton.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from typing import Any, Dict, List, Tuple, Union, Optional, Callable, NamedTuple
from abc import abstractmethod
from pydantic import BaseModel
from .vocab import Token
from .actions import Action, Choose
from ...lm.lm import LM
from pydantic_numpy import NDArrayFp32
class FiniteTokenTree(BaseModel):
    """Node of a tree of candidate tokens.

    The root node carries no token. Every other node records the token it
    stands for and the probability attached to it. Its children are keyed by
    their own token.
    """

    # Token represented by this node; left as None on the root.
    token: Optional[Token] = None
    # Probability associated with ``token`` (1.0 by default, e.g. for the root).
    proba: float = 1.0
    # Child nodes keyed by token (pydantic gives each instance its own copy of the default).
    children: Dict[Token, "FiniteTokenTree"] = {}
class FiniteThoughtAutomaton(BaseModel):
    """Collection of actions that use a language model to expand token trees.

    ``actions`` maps each action's ``uid`` to the action. Actions refer to
    their successors by uid; a ``None`` successor marks the end of the program
    (rendered as an ``<tag>_end`` node by :meth:`toGraphViz`).
    """

    lm: LM
    actions: Dict[str, Action] = {}

    def create(self, cls, **action):
        """Instantiate ``cls`` with this automaton's LM as tokenizer and register it.

        Args:
            cls: Action class to instantiate.
            **action: Extra keyword arguments forwarded to ``cls``.

        Returns:
            The new action, stored in ``self.actions`` under its ``uid``.
        """
        act = cls(tokenizer=self.lm, **action)
        self.actions[act.uid] = act
        return act

    def greedy_rec(
        self,
        prompt: List[Token],
        root: FiniteTokenTree,
        action: Action,
        step: int = 0,
        **kwargs,
    ):
        """Recursively expand ``root`` with the branches proposed by ``action``.

        ``action.step`` returns a mapping from candidate token to probability.
        Each candidate becomes a child of ``root`` and is expanded again with
        ``step + 1``. When no branch is proposed, expansion continues at
        ``step=0`` with the action whose uid ``action.next`` returns. It stops
        if that uid is ``None``.

        Args:
            prompt: Tokens produced so far (the header plus the path to ``root``).
            root: Node that receives the new children; it is mutated in place.
            action: Action currently producing tokens.
            step: Position of the token being produced within ``action``.
            **kwargs: Forwarded unchanged to every ``action.step`` call.
        """
        branches = action.step(lm=self.lm, step=step, prompt=prompt, **kwargs)
        for tok, prob in branches.items():
            child = FiniteTokenTree(token=tok, proba=prob)
            root.children[tok] = child
            self.greedy_rec(
                prompt=prompt + [tok], root=child, action=action, step=step + 1, **kwargs
            )
        if len(branches) == 0:
            aid = action.next(prompt=prompt)
            # Successors may be None (end of program, cf. toGraphViz).
            # NOTE(review): assumes action.next reports the end the same way; looking
            # up None in self.actions would otherwise raise KeyError.
            if aid is None:
                return
            self.greedy_rec(prompt=prompt, root=root, action=self.actions[aid], step=0, **kwargs)

    def greedy(
        self,
        entry: str,
        header: str = '',
        min_branch: int = 2,
        max_branch: int = 5,
        tok_clip: float = .9,
    ):
        """Build the token tree explored from the action ``entry``.

        Args:
            entry: uid of the starting action in ``self.actions``.
            header: Text tokenized with ``self.lm`` to form the initial prompt.
            min_branch: Forwarded to every ``Action.step`` call.
            max_branch: Forwarded to every ``Action.step`` call.
            tok_clip: Forwarded to every ``Action.step`` call.

        Returns:
            The root ``FiniteTokenTree``; its own ``token`` is ``None``.
        """
        prompt = self.lm.tokenize(header)
        root = FiniteTokenTree()
        self.greedy_rec(
            prompt=prompt,
            root=root,
            action=self.actions[entry],
            min_branch=min_branch,
            max_branch=max_branch,
            tok_clip=tok_clip,
        )
        return root

    def toGraphViz(self):
        """Render the action graph as GraphViz DOT statements.

        The output does not include the ``digraph { ... }`` wrapper. It emits
        one node statement per action, then one edge group per action:

        * a single ``*``-labelled edge when all successors are identical;
        * an edge to ``<tag>_end`` when the action has no successors or only
          ``None``;
        * one edge per choice, labelled with the choice text, for a
          ``Choose`` action with distinct successors.

        Every statement is terminated by a newline.
        """
        lines = [act.toGraphVizNode() for act in self.actions.values()]
        for act in self.actions.values():
            src = act.toGraphVizTag()
            end = f'{src}_end'
            # Keep successor order: it pairs positionally with ``act.choices``.
            # Deduplicating through a set scrambled that pairing between runs.
            tags = [end if s is None else self.actions[s].toGraphVizTag() for s in act.successors]
            if len(set(tags)) <= 1:
                target = tags[0] if tags else end
                lines.append(f' {src} -> {target} [label="*"];')
            else:
                # NOTE(review): assumes a Choose has one successor per choice, in order.
                assert isinstance(act, Choose)
                assert len(tags) == len(act.choices)
                for tag, (text, _toks) in zip(tags, act.choices):
                    # Escape double quotes so the DOT quoted string stays well-formed.
                    label = str(text).replace('"', '\\"')
                    lines.append(f' {src} -> {tag} [label="{label}"];')
        return ''.join(line + '\n' for line in lines)