-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrelation_to_adj_matrix.py
119 lines (98 loc) · 2.79 KB
/
relation_to_adj_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import print_function
import numpy as np
"""
Basic operations on dependency trees.
"""
class Tree(object):
    """
    Reused tree object from stanfordnlp/treelstm.

    A node in a rooted tree: it holds a reference to its parent and an
    ordered list of its children. Extra per-node data (e.g. the ``idx`` and
    ``dist`` attributes set by ``head_to_tree``) is attached directly to
    instances.
    """

    def __init__(self):
        self.parent = None
        self.num_children = 0
        self.children = list()
        # Memoized results of size() and depth(). None means "not computed";
        # a truthiness test would not work because a leaf's depth is 0.
        self._size = None
        self._depth = None

    def _invalidate_cache(self):
        """Forget the memoized size/depth of this node and all its ancestors."""
        node = self
        seen = set()
        # `seen` guards against parent cycles built from malformed input,
        # which would otherwise make this walk loop forever.
        while node is not None and id(node) not in seen:
            seen.add(id(node))
            node._size = None
            node._depth = None
            node = node.parent

    def add_child(self, child):
        """Append ``child`` to this node's children and make this its parent."""
        child.parent = self
        self.num_children += 1
        self.children.append(child)
        # The new subtree changes the size/depth of this node and of every
        # ancestor, so drop any memoized values along that path.
        self._invalidate_cache()

    def size(self):
        """Return the number of nodes in the subtree rooted here (including self)."""
        # getattr with a default: the original call had none and raised
        # AttributeError whenever the cache had not been populated yet.
        cached = getattr(self, '_size', None)
        if cached is not None:
            return cached
        count = 1
        for child in self.children:
            count += child.size()
        self._size = count
        return count

    def depth(self):
        """Return the height of the subtree rooted here: 0 for a leaf,
        otherwise one more than the deepest child."""
        cached = getattr(self, '_depth', None)
        if cached is not None:
            return cached
        if self.children:
            result = 1 + max(child.depth() for child in self.children)
        else:
            result = 0
        self._depth = result
        return result

    def __iter__(self):
        """Yield every node of the subtree rooted here in pre-order."""
        yield self
        for child in self.children:
            for node in child:
                yield node
def head_to_tree(head):
    """
    Convert a sequence of head indexes into a tree object.

    ``head`` is a sequence of indexable records. The records are first sorted
    by their element ``[2]``. After sorting, the record at position ``i``
    becomes node ``i``, and its element ``[1]`` is read as the 1-based index
    of that node's parent, with ``0`` marking the root.
    NOTE(review): presumably ``[2]`` is the token position and ``[1]`` the
    dependency head of each token -- confirm against the caller.

    Every node gets ``idx`` (its position after sorting) and ``dist``
    (always -1, a placeholder).

    Returns:
        The root ``Tree`` node. If several records have head ``0``, the last
        one in sorted order is returned and the others stay detached.

    Raises:
        ValueError: if a head index lies outside ``0..len(head)``, or if no
            record has head ``0``.
    """
    heads = [record[1] for record in sorted(head, key=lambda record: record[2])]
    num_nodes = len(heads)
    nodes = [Tree() for _ in heads]
    root = None
    for i, (node, h) in enumerate(zip(nodes, heads)):
        node.idx = i
        node.dist = -1  # just a filler
        if h == 0:
            root = node
        elif 1 <= h <= num_nodes:
            nodes[h - 1].add_child(node)
        else:
            # Rejected explicitly: a negative index would otherwise silently
            # wrap around through Python's negative indexing.
            raise ValueError(
                'head index %r at position %d is outside 0..%d'
                % (h, i, num_nodes))
    if root is None:
        # Raised explicitly (not via assert) so the check survives `python -O`.
        raise ValueError('no root (head index 0) found in head sequence')
    return root
def tree_to_adj(sent_len, tree, sent, not_directed=True):
    """
    Convert a tree object to an (numpy) adjacency matrix.

    Args:
        sent_len: size of the square output matrix. Every node ``idx`` in
            ``tree`` must be a valid index into it.
        tree: root node. Each node needs an integer ``idx`` and a
            ``children`` list (as built by ``head_to_tree``).
        sent: unused; kept so existing call sites keep working.
        not_directed: if true, every parent->child edge is mirrored
            (``ret + ret.T``).

    Returns:
        A ``(sent_len, sent_len)`` array with 1 at ``[parent, child]`` (and
        also at ``[child, parent]`` when ``not_directed``), plus 1 added on the
        diagonal.
        NOTE(review): the matrix is allocated as float32, but ``np.eye``
        defaults to float64, so the returned array is float64. This is kept
        for compatibility; confirm whether float32 was intended.
    """
    ret = np.zeros((sent_len, sent_len), dtype=np.float32)
    # Walk the tree using a list as a stack: pop() is O(1), whereas the old
    # `queue[1:]` re-slice copied the whole queue on every step (O(n^2)).
    # Visit order is irrelevant because each edge only sets a cell to 1.
    stack = [tree]
    while stack:
        node = stack.pop()
        for child in node.children:
            ret[node.idx, child.idx] = 1
        stack.extend(node.children)
    if not_directed:
        ret = ret + ret.T
    ret = ret + np.eye(sent_len)
    return ret