-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdistutil.py
96 lines (85 loc) · 2.68 KB
/
distutil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from distutil import *
import tensorflow as tf
def _variable_on_cpu(name, shape, initializer):
    """Obtain a variable through ``tf.get_variable`` with placement pinned to '/cpu:0'.

    :param name: variable name passed to ``tf.get_variable``
    :param shape: variable shape passed to ``tf.get_variable``
    :param initializer: initializer passed to ``tf.get_variable``
    :return: the variable returned by ``tf.get_variable``
    """
    # Creating the variable inside this device scope places it on the CPU.
    with tf.device('/cpu:0'):
        return tf.get_variable(name, shape, initializer=initializer)
def _variable_with_weight_decay(name, shape, stddev, wd):
    """Create a truncated-normal-initialised CPU variable, optionally registering weight decay.

    :param name: variable name
    :param shape: variable shape
    :param stddev: standard deviation of the truncated normal initializer
    :param wd: weight-decay multiplier; when None, no decay term is registered
    :return: the created variable
    """
    init = tf.truncated_normal_initializer(stddev=stddev)
    var = _variable_on_cpu(name, shape, init)
    if wd is None:
        return var
    # The scaled L2 penalty is stored in the 'losses' collection, which
    # calculate_loss later sums into the total loss.
    penalty = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', penalty)
    return var
def print_activations(t):
    """Print a tensor's op name followed by its static shape as a list.

    :param t: object exposing ``op.name`` and ``get_shape().as_list()``
        (e.g. a graph-mode TensorFlow tensor)
    """
    op_name = t.op.name
    dims = t.get_shape().as_list()
    print(op_name, ' ', dims)
def calculate_loss(logits, labels):
    """Return the total loss: mean cross entropy plus everything in the 'losses' collection.

    Adds the batch-mean sparse softmax cross entropy to the 'losses' collection,
    then sums the whole collection. The sum therefore also includes any terms
    added earlier, such as weight-decay penalties from _variable_with_weight_decay.

    :param logits: unscaled scores; presumably (batch, num_classes) -- confirm
    :param labels: class indices; presumably (batch,) -- confirm. Cast to int64 here.
    :return: tensor named 'total_loss'
    """
    int_labels = tf.cast(labels, tf.int64)
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=int_labels, logits=logits, name='cross_entropy_per_example')
    mean_xent = tf.reduce_mean(per_example, name='cross_entropy')
    tf.add_to_collection('losses', mean_xent)
    all_losses = tf.get_collection('losses')
    return tf.add_n(all_losses, name='total_loss')
def arg_parser():
    """
    Build the command-line parser for the distributed training script.

    A custom ``"bool"`` type is also registered, so callers may write
    ``add_argument(..., type="bool")``. It maps the string "true"
    (case-insensitive) to True and any other string to False.

    :return: a configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")
    # Flags for defining the tf.train.ClusterSpec
    parser.add_argument(
        "--ps_hosts",
        type=str,
        default="localhost:2222",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--worker_hosts",
        type=str,
        # No space after the comma. The previous default left a leading space
        # on the second entry (" localhost:2224") when split on ",".
        default="localhost:2223,localhost:2224",
        help="Comma-separated list of hostname:port pairs"
    )
    parser.add_argument(
        "--job_name",
        type=str,
        default="",
        help="One of 'ps', 'worker'"
    )
    # Flags for defining the tf.train.Server
    parser.add_argument(
        "--task_index",
        type=int,
        default=0,
        help="Index of task within the job"
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=5000,
        help="maximum number of iterations"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.01,
        help="Learning rate."
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="/tmp/cifar10_dps",
        help="directory to store log data"
    )
    return parser