-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathminiflow.py
118 lines (89 loc) · 2.96 KB
/
miniflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Modify Linear#forward so that it linearly transforms
input matrices, weight matrices and a bias vector into
an output.
"""
import numpy as np
class Node(object):
    """Base class for a vertex in the computation graph.

    Each node records the nodes feeding into it (``inbound_nodes``), the
    nodes it feeds (``outbound_nodes``) and its current ``value``.
    """

    def __init__(self, inbound_nodes=None):
        """Create a node and register it as an outbound node of each input.

        Args:
            inbound_nodes: Nodes whose values this node consumes, in the
                order the subclass expects them. Defaults to no inputs.
        """
        # Use a fresh list per instance: a ``[]`` default would be a single
        # list object shared by every node constructed without inputs.
        if inbound_nodes is None:
            inbound_nodes = []
        self.inbound_nodes = inbound_nodes
        # None until a value is computed or assigned.
        self.value = None
        # Filled in by downstream nodes as they are constructed.
        self.outbound_nodes = []
        for node in inbound_nodes:
            node.outbound_nodes.append(self)

    def forward(self):
        """Compute ``self.value`` from the inbound nodes' values.

        Subclasses must override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
class Input(Node):
    """A source node whose value is supplied from outside the graph.

    An input may conceptually hold many individual values (for example a
    whole matrix), but for simplicity it is still modelled as a single Node.
    It has no inbound nodes and computes nothing itself; its ``value`` is
    assigned externally rather than by ``forward``.
    """

    def __init__(self):
        # No inbound nodes: an Input is always a root of the graph, so the
        # base-class default (no inputs) is exactly what we want.
        super(Input, self).__init__()

    def forward(self):
        """No-op: an Input's value is provided, not calculated."""
        return None
class Linear(Node):
    """Node computing the linear transform ``np.dot(X, W) + b``.

    Inbound nodes, in order: inputs ``X``, weights ``W`` and bias ``b``.
    """

    def __init__(self, X, W, b):
        # forward() reads the inbound values positionally, so this order
        # (inputs, weights, bias) is part of the node's contract.
        super(Linear, self).__init__([X, W, b])

    def forward(self):
        """Set ``self.value`` to ``np.dot(X, W) + b`` from the inbound values."""
        inputs, weights, bias = [self.inbound_nodes[i].value for i in range(3)]
        # np.dot performs the matrix product; the bias is added by
        # NumPy broadcasting.
        self.value = np.dot(inputs, weights) + bias
def topological_sort(feed_dict):
    """
    Sort the nodes in topological order using Kahn's Algorithm.

    As a side effect, each `Input` node reached is assigned its value from
    `feed_dict`.

    `feed_dict`: A dictionary where the key is a `Input` Node and the value is
        the respective value to feed to that Node.

    Returns a list of sorted nodes: every node appears after all of its
    inbound nodes.
    """
    input_nodes = list(feed_dict.keys())

    # Build the adjacency map {node: {'in': set, 'out': set}} of every node
    # reachable from the inputs.
    G = {}
    visited = set()
    stack = list(input_nodes)
    while stack:
        n = stack.pop()
        # Expand each node once. Without this check a node is re-expanded
        # once per path that reaches it, which is exponential on graphs with
        # shared sub-graphs. (Traversal order does not affect G's contents.)
        if n in visited:
            continue
        visited.add(n)
        G.setdefault(n, {'in': set(), 'out': set()})
        for m in n.outbound_nodes:
            G.setdefault(m, {'in': set(), 'out': set()})
            G[n]['out'].add(m)
            G[m]['in'].add(n)
            stack.append(m)

    # Kahn's algorithm: repeatedly emit a node with no remaining inbound edges.
    L = []
    S = set(input_nodes)
    while S:
        n = S.pop()
        if isinstance(n, Input):
            n.value = feed_dict[n]
        L.append(n)
        # A node that consumes n more than once appears repeatedly in
        # n.outbound_nodes but has only one edge in the set-based G, so
        # de-duplicate (keeping order) to avoid a KeyError on a second removal.
        for m in dict.fromkeys(n.outbound_nodes):
            G[n]['out'].remove(m)
            G[m]['in'].remove(n)
            # if no other incoming edges add to S
            if not G[m]['in']:
                S.add(m)
    return L
def forward_pass(output_node, sorted_nodes):
    """
    Evaluate the graph by calling ``forward`` on each node in order.

    Arguments:

        `output_node`: The node whose value is returned; it should be the
            graph's output node (one with no outgoing edges).
        `sorted_nodes`: Nodes in topological order, so that every node runs
            after the nodes it depends on.

    Returns the value of `output_node` once every node has run.
    """
    for current in sorted_nodes:
        current.forward()
    result = output_node.value
    return result