-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathoracle.py
83 lines (70 loc) · 3.11 KB
/
oracle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python
#coding=utf-8
'''
Definition of Oracle class. Given the gold AMR graph, the alignments and
the current state, it decides which action should be taken next.
@author: Marco Damonte ([email protected])
@since: 03-10-16
'''
from action import Action
from relations import Relations
import copy
from subgraph import Subgraph
class Oracle:
    '''
    Decides which action should be taken next, given the gold relations and
    the current state.

    The oracle works on its own copy of the gold relations (self.gold). Every
    gold edge that is handed out inside an action is deleted from that copy,
    so the remaining content of self.gold is what still has to be produced.
    '''

    def __init__(self, relations):
        '''
        Create the oracle.

        relations is deep-copied before being wrapped in a Relations object,
        so the edge removals performed by the oracle do not touch the
        caller's data.
        '''
        self.gold = Relations(copy.deepcopy(relations))

    def _remove_gold_edge(self, parent, child, label):
        '''Delete the edge parent -label-> child from both gold indexes.'''
        self.gold.children[parent].remove((child, label))
        self.gold.parents[child].remove((parent, label))

    def reentrancy(self, node, found):
        '''
        Look for a gold edge going from node to one of its siblings.

        Siblings are the other children, in found, of the parents that node
        has in found. The first sibling s for which the gold relations hold
        an edge node -> s is used, and that edge is deleted from the gold
        relations.

        Returns [s, label, siblings] for that sibling, or None when the gold
        relations link node to none of its siblings.
        '''
        siblings = [
            item[0]
            for p in found.parents[node]
            for item in found.children[p[0]]
            if item[0] != node
        ]
        for s in siblings:
            label = self.gold.isRel(node, s)
            if label is not None:
                self._remove_gold_edge(node, s, label)
                return [s, label, siblings]
        return None

    def _has_pending_edge(self, top, tokens):
        '''
        Tell whether the gold relations still hold an edge, in either
        direction, between top and any node of the given tokens.
        '''
        return any(
            self.gold.isRel(top, node) is not None
            or self.gold.isRel(node, top) is not None
            for item in tokens
            for node in item.nodes
        )

    def _consume_edges(self, source, target):
        '''
        Delete every gold edge source -> target and return the deleted edges
        as (source, target, label) tuples.
        '''
        taken = []
        # Iterate over a snapshot: the underlying list is edited in the loop.
        for (child, label) in list(self.gold.children[source]):
            if child == target:
                taken.append((source, target, label))
                self._remove_gold_edge(source, child, label)
        return taken

    def _consume_internal_edges(self, nodes):
        '''
        Delete and return all gold edges whose two endpoints both belong to
        nodes, as (parent, child, label) tuples.
        '''
        relations = []
        for n1 in nodes:
            for n2 in nodes:
                if n1 != n2:
                    relations.extend(self._consume_edges(n1, n2))
                    relations.extend(self._consume_edges(n2, n1))
        return relations

    def valid_actions(self, state):
        '''
        Return the next action for state, or None when none applies.

        The checks are tried in this order:
        1. gold edge from the stack top to the element at stack.get(1):
           Action("larc", label), and the edge is deleted from the gold
           relations;
        2. gold edge in the opposite direction: Action("rarc", label), and
           the edge is deleted from the gold relations;
        3. the stack is not empty and no node of the buffer tokens shares a
           gold edge with the stack top: Action("reduce", r), where r is the
           result of reentrancy() for the stack top;
        4. the buffer is not empty: Action("shift", subgraph), where the
           subgraph is built from the nodes of the next buffer token and the
           gold edges among those nodes (which are deleted from the gold
           relations).
        '''
        top = state.stack.top()
        other = state.stack.get(1)

        label = self.gold.isRel(top, other)
        if label is not None:
            self._remove_gold_edge(top, other, label)
            return Action("larc", label)

        label = self.gold.isRel(other, top)
        if label is not None:
            self._remove_gold_edge(other, top, label)
            return Action("rarc", label)

        if not state.stack.isEmpty():
            if not self._has_pending_edge(top, state.buffer.tokens):
                return Action("reduce", self.reentrancy(top, state.stack.relations))

        if not state.buffer.isEmpty():
            nodes = state.buffer.peek().nodes
            relations = self._consume_internal_edges(nodes)
            return Action("shift", Subgraph(nodes, relations))

        return None