-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLSTMCell.py
157 lines (138 loc) · 6.94 KB
/
LSTMCell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import tensorflow as tf
import tensorflow.contrib.layers as layers
from tensorflow.python.ops import array_ops, tensor_array_ops, io_ops
from tensorflow.python.framework import tensor_shape
import contextlib
from tensorflow.contrib import rnn
#from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh
from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
# Conventional variable names for RNN cell parameters.
# NOTE(review): _BIAS_VARIABLE_NAME is not referenced by the code visible in
# this file; _WEIGHTS_VARIABLE_NAME is the name `_checked_scope` probes for
# when checking whether a variable scope already holds cell weights.
_BIAS_VARIABLE_NAME = "biases"
_WEIGHTS_VARIABLE_NAME = "weights"
class BasicLSTMCell2(RNNCell):
    """Basic LSTM recurrent network cell with optionally injected weights.

    The implementation is based on: http://arxiv.org/abs/1409.2329.

    We add forget_bias (default: 1) to the biases of the forget gate in order
    to reduce the scale of forgetting in the beginning of the training.

    It does not allow cell clipping, a projection layer, and does not
    use peep-hole connections: it is the basic baseline.

    The gate pre-activations are computed by `linear` from three tensors
    (input-to-hidden weights, hidden-to-hidden weights and a bias). They are
    either supplied by the caller through `weights` or created as variables
    in the current variable scope.

    For advanced models, please use a full LSTM cell such as
    `tf.contrib.rnn.LSTMCell`.
    """

    def __init__(self, in_dim, num_units, weights=None, forget_bias=1.0,
                 input_size=None, state_is_tuple=True, activation=tanh,
                 reuse=None):
        """Initialize the basic LSTM cell.

        Args:
          in_dim: int, size of the last dimension of the cell input; used as
            the row count of the input-to-hidden weight matrix.
          num_units: int, The number of units in the LSTM cell.
          weights: (optional) sequence `(w_i2h, w_h2h, w_b)` used instead of
            creating variables. When variables are created they have shapes
            `(in_dim, 4 * num_units)`, `(num_units, 4 * num_units)` and
            `(1, 4 * num_units)`; supplied tensors are presumably expected to
            be compatible with those shapes.
          forget_bias: float, The bias added to forget gates (see above).
          input_size: Deprecated and unused.
          state_is_tuple: If True, accepted and returned states are 2-tuples
            of the `c_state` and `m_state`. If False, they are concatenated
            along the column axis. The latter behavior will soon be
            deprecated.
          activation: Activation function of the inner states.
          reuse: (optional) Python boolean describing whether to reuse
            variables in an existing scope. If not `True`, and the existing
            scope already has the given variables, an error is raised.
        """
        # `logging` is not imported by this module; use TensorFlow's logger,
        # which is reachable through the already-imported `tf`.
        if not state_is_tuple:
            tf.logging.warn(
                "%s: Using a concatenated state is slower and will soon be "
                "deprecated. Use state_is_tuple=True.", self)
        if input_size is not None:
            tf.logging.warn(
                "%s: The input_size parameter is deprecated.", self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation
        self._reuse = reuse
        self._in_dim = in_dim
        self.weights = weights

    @property
    def state_size(self):
        """State size: an `LSTMStateTuple` of two `num_units`, or their sum."""
        return (rnn.LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)

    @property
    def output_size(self):
        """Size of the cell output, equal to `num_units`."""
        return self._num_units

    def linear(self, arys):
        """Compute the concatenated gate pre-activations.

        Args:
          arys: indexable pair `[inputs, h]`; `arys[0]` is multiplied by the
            input-to-hidden matrix and `arys[1]` by the hidden-to-hidden
            matrix.

        Returns:
          `inputs @ w_i2h + h @ w_h2h + w_b`, whose last dimension holds the
          four gates side by side (`4 * num_units` wide when the variables
          are created here).
        """
        if self.weights is not None:
            # Caller-provided parameters: no variables are created.
            w_i2h, w_h2h, w_b = self.weights
        else:
            # Create (or fetch, under reuse) the parameters in the scope that
            # is current at call time.
            with tf.variable_scope(tf.get_variable_scope()):
                with tf.device('/cpu:0'):
                    w_i2h = tf.get_variable(
                        'w_i2h', (self._in_dim, 4 * self._num_units),
                        dtype=tf.float32,
                        initializer=tf.contrib.layers.xavier_initializer(),
                        trainable=True)
                    w_h2h = tf.get_variable(
                        'w_h2h', (self._num_units, 4 * self._num_units),
                        dtype=tf.float32,
                        initializer=tf.contrib.layers.xavier_initializer(),
                        trainable=True)
                    # NOTE(review): the bias also uses the Xavier
                    # initializer rather than zeros; kept as in the original.
                    w_b = tf.get_variable(
                        'w_b', (1, 4 * self._num_units),
                        dtype=tf.float32,
                        initializer=tf.contrib.layers.xavier_initializer(),
                        trainable=True)
        i2h = tf.matmul(arys[0], w_i2h)
        h2h = tf.matmul(arys[1], w_h2h)
        return i2h + h2h + w_b

    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM).

        Args:
          inputs: input tensor for this step.
          state: `(c, h)` tuple when `state_is_tuple` is True, otherwise a
            single tensor that is split in two along axis 1.
          scope: (optional) variable scope; defaults to "basic_lstm_cell".

        Returns:
          A pair `(new_h, new_state)` where `new_state` is an
          `LSTMStateTuple(new_c, new_h)` or the two concatenated along
          axis 1, depending on `state_is_tuple`.
        """
        with _checked_scope(self, scope or "basic_lstm_cell",
                            reuse=self._reuse):
            if self._state_is_tuple:
                c, h = state
            else:
                c, h = array_ops.split(
                    value=state, num_or_size_splits=2, axis=1)
            # Parameters of gates are concatenated into one multiply for
            # efficiency.
            concat = self.linear([inputs, h])

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = array_ops.split(
                value=concat, num_or_size_splits=4, axis=1)

            new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                     self._activation(j))
            new_h = self._activation(new_c) * sigmoid(o)

            if self._state_is_tuple:
                new_state = rnn.LSTMStateTuple(new_c, new_h)
            else:
                new_state = array_ops.concat([new_c, new_h], 1)
            return new_h, new_state
@contextlib.contextmanager
def _checked_scope(cell, scope, reuse=None, **kwargs):
    """Enter `scope` for `cell`, refusing inconsistent scope or weight reuse.

    Args:
      cell: the cell object; its `_scope` attribute (if present) records the
        variable scope of its first use and is (re)assigned here.
      scope: scope name or object handed to `vs.variable_scope`.
      reuse: forwarded to `vs.variable_scope` as `reuse` when not None.
      **kwargs: extra keyword arguments for `vs.variable_scope`.

    Yields:
      The entered variable scope.

    Raises:
      ValueError: if the cell was previously used under a scope with a
        different name, or if the cell has no recorded scope, `reuse` is
        None, and a variable named `_WEIGHTS_VARIABLE_NAME` already exists
        in the scope.
    """
    if reuse is not None:
        kwargs["reuse"] = reuse
    with vs.variable_scope(scope, **kwargs) as checking_scope:
        scope_name = checking_scope.name
        if not hasattr(cell, "_scope"):
            # First use of this cell: probe (in reuse mode) whether the scope
            # already owns a weights variable created by someone else.
            try:
                with vs.variable_scope(checking_scope, reuse=True):
                    vs.get_variable(_WEIGHTS_VARIABLE_NAME)
            except ValueError:
                weights_found = False
            else:
                weights_found = True
            if reuse is None and weights_found:
                raise ValueError(
                    "Attempt to have a second RNNCell use the weights of a variable "
                    "scope that already has weights: '%s'; and the cell was not "
                    "constructed as %s(..., reuse=True). "
                    "To share the weights of an RNNCell, simply "
                    "reuse it in your second calculation, or create a new one with "
                    "the argument reuse=True." % (scope_name, type(cell).__name__))
        else:
            # The cell has been used before: it must come back under the
            # same scope name.
            cell_scope = cell._scope  # pylint: disable=protected-access
            if cell_scope.name != scope_name:
                raise ValueError(
                    "Attempt to reuse RNNCell %s with a different variable scope than "
                    "its first use. First use of cell was with scope '%s', this "
                    "attempt is with scope '%s'. Please create a new instance of the "
                    "cell if you would like it to use a different set of weights. "
                    "If before you were using: MultiRNNCell([%s(...)] * num_layers), "
                    "change to: MultiRNNCell([%s(...) for _ in range(num_layers)]). "
                    "If before you were using the same cell instance as both the "
                    "forward and reverse cell of a bidirectional RNN, simply create "
                    "two instances (one for forward, one for reverse). "
                    "In May 2017, we will start transitioning this cell's behavior "
                    "to use existing stored weights, if any, when it is called "
                    "with scope=None (which can lead to silent model degradation, so "
                    "this error will remain until then.)"
                    % (cell, cell_scope.name, scope_name, type(cell).__name__,
                       type(cell).__name__))

        # All checks passed: remember the scope on the cell and hand it out.
        cell._scope = checking_scope  # pylint: disable=protected-access
        yield checking_scope