From 6edaa28da61d2858621f926b7bf1a4939061becc Mon Sep 17 00:00:00 2001 From: dmolony3 Date: Wed, 8 Jul 2020 09:47:15 -0400 Subject: [PATCH] Edited directory structure --- LICENSE | 201 ++++++++++++++++++++++++++++++ README.md | 2 +- drn/accuracy.py | 98 +++++++++++++++ drn/data_reader.py | 195 +++++++++++++++++++++++++++++ drn/dice.py | 41 +++++++ drn/drn.py | 283 +++++++++++++++++++++++++++++++++++++++++++ drn/inference.py | 85 +++++++++++++ drn/main.py | 23 ++++ drn/parse_args.py | 47 +++++++ drn/train_network.py | 136 +++++++++++++++++++++ 10 files changed, 1110 insertions(+), 1 deletion(-) create mode 100644 LICENSE create mode 100644 drn/accuracy.py create mode 100644 drn/data_reader.py create mode 100644 drn/dice.py create mode 100644 drn/drn.py create mode 100644 drn/inference.py create mode 100644 drn/main.py create mode 100644 drn/parse_args.py create mode 100644 drn/train_network.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..989e2c5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index a1a4a77..6f9f809 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,6 @@ python main.py --train_file=path/to/train_data.txt --val_file=path/to/val_data
 - The model can be used for inference by setting the mode to predict and providing similar arguments as above
 
 ```
-python main.py --mode=predict --val_file=path_to_val_data.txt --directory=C:\Users\David\Desktop\ivus_images1 --num_classes=5 --image_dims=500 --channels=3
+python main.py --mode=predict --val_file=path_to_val_data.txt --directory=image/directory_path --num_classes=5 --image_dims=500 --channels=3
 ```
 
diff --git a/drn/accuracy.py b/drn/accuracy.py
new file mode 100644
index 0000000..7dcc5eb
--- /dev/null
+++ b/drn/accuracy.py
@@ -0,0 +1,98 @@
+from skimage import measure
+import numpy as np
+
+# get accuracy
+def jaccard(label, pred, num_classes):
+    """Computes the intersection over union (Jaccard index) for each class
+
+    Args:
+      label: array, batch of ground truth labelmaps
+      pred: array, batch of predicted labelmaps
+      num_classes: int, total number of classes
+    Returns:
+      IOU: array, intersection over union for each class
+    """
+
+    assert label.shape == pred.shape, "Size of label {} does not agree with size \
+        of prediction {}".format(label.shape, pred.shape)
+
+    IOU = np.zeros((label.shape[0], num_classes), dtype=np.float32)
+    for i in range(num_classes):
+        inter = np.multiply(label == i, pred == i)
+        inter = np.sum(inter, axis=(1, 2))
+        union = np.subtract(np.add(np.sum(label == i, axis=(1, 2)),
+                                   np.sum(pred == i, axis=(1, 2))), inter)
+        IOU[:, i] = inter/union
+    IOU = np.mean(IOU, 0)
+
+    return IOU
+
+def dice(label, pred, num_classes):
+    """Computes the Dice coefficient for each class
+
+    Args:
+      label: array, batch of ground truth labelmaps
+      pred: array, batch of predicted labelmaps
+      num_classes: int, total number of classes
+    Returns:
+      dice: array, Dice coefficient for each class
+    """
+
+    assert label.shape == pred.shape, "Size of label {} does not agree with size \
+        of prediction {}".format(label.shape, pred.shape)
+
+    dice = np.zeros((label.shape[0], num_classes), dtype=np.float32)
+    for i in range(num_classes):
+        inter = np.multiply(label == i, pred == i)
+        inter = np.sum(inter, axis=(1, 2))
+        union = np.add(np.sum(label == i, axis=(1, 2)), np.sum(pred == i, axis=(1, 2)))
+        dice[:, i] = 2*inter/union
+    dice = np.mean(dice, 0)
+
+    return dice
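+
+# Illustrative sketch (not part of the original module): how jaccard/dice are
+# typically called. Inputs are assumed to be integer labelmaps of shape
+# (batch, rows, cols) with values in [0, num_classes).
+def _example_metric_usage():
+    label = np.random.randint(0, 3, size=(2, 64, 64))
+    pred = np.random.randint(0, 3, size=(2, 64, 64))
+    print(jaccard(label, pred, num_classes=3))  # per-class IoU, shape (3,)
+    print(dice(label, pred, num_classes=3))     # per-class Dice, shape (3,)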
+
+def hausdorff(label, pred, num_classes):
+    """Computes the Hausdorff distance for each class
+
+    Args:
+      label: array, batch of ground truth labelmaps
+      pred: array, batch of predicted labelmaps
+      num_classes: int, total number of classes
+    Returns:
+      hauss: array, Hausdorff distance for each class
+    """
+
+    assert label.shape == pred.shape, "Size of label {} does not agree with size \
+        of prediction {}".format(label.shape, pred.shape)
+
+    hauss = np.zeros((label.shape[0], num_classes), dtype=np.float32)
+
+    # Add 0.5 to create contour around integer labels
+    levels = np.arange(0, num_classes) + 0.5
+
+    # iterate over each contour
+    for i, level in enumerate(levels):
+        # iterate over each image in the batch
+        for j in range(label.shape[0]):
+            label_contour = measure.find_contours(label[j, :, :], level)
+            pred_contour = measure.find_contours(pred[j, :, :], level)
+
+            # convert label and pred to contours where rows are samples and cols are dimensions
+            P = np.asarray(label_contour[0])
+            Q = np.asarray(pred_contour[0])
+
+            lenP = P.shape[0]
+            lenQ = Q.shape[0]
+
+            # pairwise Euclidean distances between the two contours
+            D = np.zeros((lenP, lenQ))
+            for ii in range(lenP):
+                for jj in range(lenQ):
+                    D[ii, jj] = np.sqrt((P[ii, 0] - Q[jj, 0])**2
+                                        + (P[ii, 1] - Q[jj, 1])**2)
+
+            # symmetric Hausdorff distance: max of the two directed distances
+            d1 = np.max(np.min(D, axis=1))
+            d2 = np.max(np.min(D, axis=0))
+            hauss[j, i] = np.maximum(d1, d2)
+
+    return hauss
\ No newline at end of file
diff --git a/drn/data_reader.py b/drn/data_reader.py
new file mode 100644
index 0000000..d93484f
--- /dev/null
+++ b/drn/data_reader.py
@@ -0,0 +1,195 @@
+import tensorflow as tf
+import os
+import math
+
+class DataReader():
+    """Class for loading and batching images, labelmaps and weightmaps.
+
+    Args:
+      directory: string, path to directory containing images
+      image_size: list, [rows, cols, channels] of the input images
+      batch_size: int, number of samples in batch
+      num_epochs: int, number of train/test epochs
+      use_weights: bool, flag indicating whether weightmaps are included
+    """
+
+    def __init__(self, directory, image_size, batch_size, num_epochs, use_weights):
+        self.image_size = image_size
+        self.batch_size = batch_size
+        self.num_epochs = num_epochs
+        self.use_weights = use_weights
+        self.directory = directory
+        self.image_list = []
+        self.IMG_MEAN = tf.constant([60.3486, 60.3486, 60.3486], dtype=tf.float32)
+
+    def read_files(self, data_file):
+        """Reads files and returns list of images, labels (and weights)
+
+        Args:
+          data_file: string, path to file containing rows of image/label paths
+        Returns:
+          image_list: list, full path to each image
+          label_list: list, full path to each labelmap
+          weight_list: list, full path to each weightmap
+        """
+
+        f = open(data_file, 'r')
+        data = f.read()
+        data = data.split('\n')
+        image_list = []
+        label_list = []
+        weight_list = []
+
+        for i in range(len(data)):
+            line = data[i]
+            if line:
+                try:
+                    image, label, weight = line.split(' ')
+                    image_list.append(os.path.join(self.directory, image))
+                    label_list.append(os.path.join(self.directory, label))
+                    weight_list.append(os.path.join(self.directory, weight))
+                except ValueError:
+                    try:
+                        image, label = line.split(' ')
+                        image_list.append(os.path.join(self.directory, image))
+                        label_list.append(os.path.join(self.directory, label))
+                        weight_list.append('')
+                    except ValueError:
+                        image = line
+                        image_list.append(os.path.join(self.directory, image))
+                        label_list.append('')
+                        weight_list.append('')
+
+        # count only non-empty lines so a trailing newline is not included
+        self.num_images = len(image_list)
+
+        return image_list, label_list, weight_list
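+
+    # Expected data_file layout (an assumption inferred from read_files): one
+    # sample per line with space-separated paths relative to self.directory;
+    # the label and weight columns are optional, e.g.
+    #
+    #   images/frame_001.jpg labels/frame_001.png weights/frame_001.png
+    #   images/frame_002.jpg labels/frame_002.png
+    #   images/frame_003.jpg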
+
+    def decode_image(self, image_path, label_path, weight_path):
+        """Reads image, label and weight paths and decodes
+
+        Args:
+          image_path: string, path to image
+          label_path: string, path to labelmap
+          weight_path: string, path to weightmap
+        Returns:
+          image: 3D tensor, single image
+          label: 2D tensor, single labelmap
+          weight: 2D tensor, single weightmap
+        """
+
+        image = tf.read_file(image_path)
+        image = tf.image.decode_jpeg(image)
+        image = tf.cast(image, dtype=tf.float32)
+
+        # decode the labelmap/weightmap if a path was given, otherwise fall
+        # back to an all-zero placeholder; labelmaps are single-channel, so
+        # the placeholder is [rows, cols, 1] rather than the full image size
+        label = tf.cond(tf.cast(tf.strings.length(label_path), dtype=tf.bool),
+                        lambda: self.read_and_decode_png(label_path),
+                        lambda: tf.zeros([self.image_size[0], self.image_size[1], 1], dtype=tf.uint8))
+
+        weight = tf.cond(tf.cast(tf.strings.length(weight_path), dtype=tf.bool),
+                         lambda: self.read_and_decode_png(weight_path),
+                         lambda: tf.zeros([self.image_size[0], self.image_size[1], 1], dtype=tf.uint8))
+
+        label = tf.cast(label, dtype=tf.int32)
+        weight = tf.cast(weight, dtype=tf.float32)
+        image -= self.IMG_MEAN
+
+        return image, label, weight
+
+    def read_and_decode_png(self, file_path):
+        """Reads and decodes png files"""
+
+        image = tf.read_file(file_path)
+        image = tf.image.decode_png(image)
+
+        return image
+
+    def mirror_image(self, image, label, weight):
+        """Performs random flipping of image/labelmap/weightmap"""
+
+        cond = tf.cast(tf.random_uniform([], maxval=2, dtype=tf.int32), tf.bool)
+        image = tf.cond(cond, lambda: tf.image.flip_left_right(image), lambda: tf.identity(image))
+        label = tf.cond(cond, lambda: tf.image.flip_left_right(label), lambda: tf.identity(label))
+        weight = tf.cond(cond, lambda: tf.image.flip_left_right(weight), lambda: tf.identity(weight))
+
+        return image, label, weight
+
+    def rotate_image(self, image, label, weight):
+        """Performs random rotation of image/labelmap/weightmap"""
+
+        # tf.contrib.image.rotate expects an angle in radians, so sample from
+        # [0, 2*pi) rather than [0, 360)
+        rot_angle = tf.random_uniform([], minval=0, maxval=2*math.pi, dtype=tf.float32)
+        image = tf.contrib.image.rotate(image, rot_angle)
+        label = tf.contrib.image.rotate(label, rot_angle)
+        weight = tf.contrib.image.rotate(weight, rot_angle)
+
+        return image, label, weight
+
+    def add_noise(self, image, label, weight):
+        """Adds gaussian noise to input image"""
+
+        noise = tf.random_normal(shape=tf.shape(image), mean=0.0, stddev=1)
+        image += noise
+
+        return image, label, weight
+
+    def train_batch(self, train_file):
+        """Reads and batches images for training
+
+        Args:
+          train_file: string, path to file containing rows of image/label paths
+        Returns:
+          train_data: tensorflow dataset, augmented batch of images/labels/weights
+        """
+
+        image_list, label_list, weight_list = self.read_files(train_file)
+        self.image_list = image_list
+
+        train_data = tf.data.Dataset.from_tensor_slices((image_list, label_list, weight_list))
+
+        # shuffle all files
+        train_data = train_data.shuffle(buffer_size=len(image_list))
+
+        # decode images and subtract image mean
+        train_data = train_data.map(self.decode_image)
+
+        # Data augmentation
+        train_data = train_data.map(self.rotate_image, num_parallel_calls=2)
+        train_data = train_data.map(self.mirror_image, num_parallel_calls=2)
+        train_data = train_data.map(self.add_noise, num_parallel_calls=2)
+        train_data = train_data.repeat()
+
+        train_data = train_data.apply(tf.contrib.data.batch_and_drop_remainder(self.batch_size))
+
+        return train_data
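+
+    # Note (assuming TF >= 1.10): tf.contrib.data.batch_and_drop_remainder is
+    # deprecated and the apply() call above can be written with the built-in
+    #   train_data = train_data.batch(self.batch_size, drop_remainder=True)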
+
+    def test_batch(self, test_file):
+        """Reads and batches images for testing
+
+        Args:
+          test_file: string, path to file containing rows of image/label paths
+        Returns:
+          test_data: tensorflow dataset, batch of images/labels/weights
+        """
+
+        image_list, label_list, weight_list = self.read_files(test_file)
+        self.image_list = image_list
+
+        test_data = tf.data.Dataset.from_tensor_slices((image_list, label_list, weight_list))
+        test_data = test_data.map(self.decode_image)
+
+        test_data = test_data.apply(tf.contrib.data.batch_and_drop_remainder(self.batch_size))
+
+        return test_data
\ No newline at end of file
diff --git a/drn/dice.py b/drn/dice.py
new file mode 100644
index 0000000..3823969
--- /dev/null
+++ b/drn/dice.py
@@ -0,0 +1,41 @@
+import tensorflow as tf
+
+def dice_loss(logits, label, num_classes, use_weights):
+    """Computes the Dice loss
+
+    Args:
+      logits: tensor, output logits/scores from neural network
+      label: tensor, ground truth labelmaps
+      num_classes: int, number of classes
+      use_weights: bool, flag to weight class labels in loss
+    Returns:
+      dice_loss: scalar tensor, loss evaluated as (1 - dice coefficient)
+    """
+
+    # the class count is taken from the logits' channel dimension
+    num_classes = logits.shape[-1]
+
+    logits = tf.nn.softmax(logits)
+
+    label_one_hot = tf.one_hot(label, num_classes)
+
+    # per-class weight: inverse squared class volume when use_weights is set
+    w = tf.reduce_sum(label_one_hot, axis=[0, 1, 2])
+
+    # optionally apply weights
+    use_weights = tf.convert_to_tensor(use_weights)
+    w = tf.cond(use_weights, lambda: 1/(w**2), lambda: tf.ones((num_classes)))
+
+    # sum over batches and images
+    ref_vol = tf.reduce_sum(label_one_hot, axis=[0, 1, 2]) + 0.1
+    intersect = tf.reduce_sum(label_one_hot*logits, axis=[0, 1, 2])
+    seg_vol = tf.reduce_sum(logits, [0, 1, 2]) + 0.1
+
+    # sum over all classes
+    dice_numerator = 2.0*tf.reduce_sum(tf.multiply(w, intersect))
+    dice_denominator = tf.reduce_sum(tf.multiply(w, seg_vol + ref_vol))
+
+    # subtract from 1 as we are trying to maximize the Dice score but the
+    # optimizer will minimize the loss
+    dice_loss = 1.0 - dice_numerator / dice_denominator
+
+    return dice_loss
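+
+# Illustrative sanity check (not part of the original module): with a
+# near-perfect prediction the numerator and denominator match up to the 0.1
+# smoothing terms, so the loss approaches 0.
+def _example_dice_loss():
+    label = tf.constant([[[0, 1], [1, 0]]])      # one 2x2 labelmap
+    logits = tf.one_hot(label, depth=2) * 100.0  # softmax(logits) is ~one-hot
+    return dice_loss(logits, label, num_classes=2, use_weights=False)
\ No newline at end of file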
diff --git a/drn/drn.py b/drn/drn.py
new file mode 100644
index 0000000..80b5a11
--- /dev/null
+++ b/drn/drn.py
@@ -0,0 +1,283 @@
+import tensorflow as tf
+
+class DRN():
+    """Dilated Residual Network for semantic segmentation
+
+    This class creates either an 18 or 26 layer dilated residual network.
+    The is_training flag is used to switch between training and evaluation
+    mode.
+
+    Args:
+      image: 4D tensor, input image
+      image_dims: list, dimensions for input rows, columns and channels
+      batch_size: int, number of images in batch
+      num_classes: int, number of class maps to produce at output
+      is_training: bool, flag to indicate whether model is being trained
+      network: string, choice of network, must be either 'DRN18' or 'DRN26'
+    """
+
+    def __init__(self, image, image_dims, batch_size, num_classes, is_training, network):
+        self.image = image
+        self.batch_size = batch_size
+        self.num_classes = num_classes
+        self.is_training = is_training
+        self.image_dims = image_dims
+
+        if network == 'DRN18':
+            self.build_DRN18()
+        elif network == 'DRN26':
+            self.build_DRN26()
+        else:
+            raise ValueError("Network must be either 'DRN18' or 'DRN26'")
+
+    def batch_norm(self, X, is_training, decay=0.999):
+        """Batch normalization
+
+        The offset (beta) should always be used, but the scale is not
+        necessary for activation functions like ReLU. beta and scale have
+        the same shape as the bias, i.e. the number of features. Code
+        source: https://gist.github.com/tomokishii/0ce3bdac1588b5cca9fa5fbdf6e1c412
+
+        Args:
+          X: 4D tensor, image or feature map
+          is_training: bool, flag to indicate whether model is being trained
+          decay: decay rate for exponential moving average
+        Returns:
+          X: 4D tensor, batch normalized tensor for input X
+        """
+
+        scale = tf.Variable(tf.ones([X.get_shape()[-1]]))
+        beta = tf.Variable(tf.zeros([X.get_shape()[-1]]))
+        epsilon = 1e-6
+
+        batch_mean, batch_var = tf.nn.moments(X, [0, 1, 2])
+        ema = tf.train.ExponentialMovingAverage(decay)
+
+        def mean_var_with_update():
+            ema_apply_op = ema.apply([batch_mean, batch_var])
+            with tf.control_dependencies([ema_apply_op]):
+                return tf.identity(batch_mean), tf.identity(batch_var)
+
+        # use batch statistics during training and the moving averages at test time
+        mean, var = tf.cond(is_training,
+                            mean_var_with_update,
+                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
+
+        X = tf.nn.batch_normalization(X, mean, var, beta, scale, epsilon)
+
+        return X
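+
+    # Note (an alternative, assuming TF >= 1.4): the hand-rolled batch norm
+    # above can usually be replaced by the built-in layer, provided the ops in
+    # tf.GraphKeys.UPDATE_OPS are run with each training step:
+    #   X = tf.layers.batch_normalization(X, training=is_training)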
+
+    def conv_repeat(self, X, strides, dilation, kernel, residual, name):
+        """Performs convolution with residual block
+
+        Args:
+          X: 4D tensor, image or feature map
+          strides: list, stride for each convolution
+          dilation: list/int, factor by which convolution kernel is dilated
+          kernel: list, 3 elements - size of kernel and number of filters
+          residual: bool, flag for residual connection
+          name: string, layer name
+        Returns:
+          X: 4D tensor, output from residual block
+        """
+
+        if residual:
+            shortcut = X
+            if strides[-1] != 1:
+                # strided convolution and double channel depth
+                shortcut = self.conv_2d(shortcut, strides[-1], [1, 1, kernel[-1]], name + '_shortcut')
+            elif shortcut.get_shape()[3] != kernel[-1]:
+                shortcut = self.conv_2d(shortcut, strides[0], kernel, name + '_shortcut')
+
+        for i in range(len(strides)):
+            # dilated or plain convolution, then batch norm and ReLU activation
+            if dilation:
+                X = self.atrous_conv_2d(X, dilation, kernel, name=name + '_' + str(i))
+            else:
+                X = self.conv_2d(X, strides[i], kernel, name=name + '_' + str(i))
+            X = self.batch_norm(X, is_training=self.is_training)
+            X = tf.nn.relu(X)
+
+            if residual and i == len(strides) - 1:
+                # add shortcut on last operation
+                X = tf.add(shortcut, X)
+                X = tf.nn.relu(X)
+
+        return X
+
+    def atrous_conv_2d(self, X, dilation, kernel, name):
+        """2D dilated (atrous) convolution
+
+        The atrous convolution is a convolution with holes. It is used to
+        expand the receptive field without downsizing the activation map.
+        Here a stride of 1 is assumed.
+
+        Args:
+          X: 4D tensor, image or feature map
+          dilation: int, factor by which convolution kernel is dilated
+          kernel: list, 3 elements - size of kernel and number of filters
+          name: string, layer name
+        Returns:
+          X: 4D tensor, output after convolution and bias operations
+        """
+
+        with tf.variable_scope(name) as scope:
+            W = tf.get_variable(shape=[kernel[0], kernel[1], X.shape[3], kernel[2]],
+                                dtype=tf.float32, name=name + '_weights')
+            b = tf.get_variable(shape=[kernel[2]], dtype=tf.float32,
+                                name=name + '_bias')
+
+            X = tf.nn.atrous_conv2d(X, W, dilation, padding='SAME')
+            X = X + b
+
+        return X
+
+    def conv_2d(self, X, strides, kernel, name):
+        """2D convolution
+
+        Args:
+          X: 4D tensor, image or feature map
+          strides: int, stride for convolution
+          kernel: list, 3 elements - size of kernel and number of filters
+          name: string, name of layer
+        Returns:
+          X: 4D tensor, output after convolution and bias operations
+        """
+
+        with tf.variable_scope(name) as scope:
+            W = tf.get_variable(shape=[kernel[0], kernel[1], X.shape[3], kernel[2]],
+                                dtype=tf.float32, name=name + '_weights')
+            b = tf.get_variable(shape=[kernel[2]], dtype=tf.float32,
+                                name=name + '_bias')
+
+            X = tf.nn.conv2d(X, W, [1, strides, strides, 1], padding='SAME', name=None)
+            X = X + b
+
+        return X
+
+    def build_DRN18(self):
+        """Dilated residual network with 18 layers"""
+
+        X = self.image
+        X.set_shape([None, self.image_dims[0], self.image_dims[1], self.image_dims[2]])
+
+        print('Input shape is {}'.format(X.get_shape()))
+
+        kernel = [7, 7, 64]
+        strides = 2
+        dilation = []
+        X = self.conv_2d(X, strides, kernel, 'layer2')
+        print('Layer2 shape is {}'.format(X.get_shape()))
+
+        residual = 1
+        X = tf.nn.max_pool(X, [1, 1, 1, 1], [1, 2, 2, 1], padding='SAME')
+        kernel = [3, 3, 64]
+        strides = [1, 1]
+        X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer3_1')
+        strides = [1, 2]
+        X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer3_2')
+        print('Layer3 shape is {}'.format(X.get_shape()))
+
+        kernel = [3, 3, 128]
+        strides = [1, 1]
+        X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer4_1')
+        strides = [1, 1]
+        X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer4_2')
+        print('Layer4 shape is {}'.format(X.get_shape()))
+
+        kernel = [3, 3, 256]
+        strides = [1, 1, 1, 1]
+        dilation = 2
+        X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer5')
+        print('Layer5 shape is {}'.format(X.get_shape()))
+
+        kernel = [3, 3, 512]
+        strides = [1, 1, 1, 1]
+        dilation = 4
+        X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer6')
+        print('Layer6 shape is {}'.format(X.get_shape()))
+
+        # 1x1 convolution to squash output to number of classes
+        kernel = [1, 1, self.num_classes]
+        strides = [1]
+        dilation = []
+        X = self.conv_repeat(X, strides, dilation, kernel, residual, 'output')
+        print('Output shape is {}'.format(X.get_shape()))
+
+        self.prob = X
+        self.pred = tf.argmax(X, 3)
+
+    def build_DRN26(self):
+        """Dilated residual network with 26 layers"""
+
+        X = self.image
+        X.set_shape([None, self.image_dims[0], self.image_dims[1], self.image_dims[2]])
+
+        print('Input shape is {}'.format(X.get_shape()))
+
+        kernel = [7, 7, 16]
+        strides = 1
+        # assign the output back to X so the first convolution is not discarded
+        X = self.conv_2d(X, strides, kernel, 'layer1_1')
+        residual = 1
+        kernel = [3, 3, 16]
+        strides = [1, 2]
+        dilation = []
+        X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer1_2')
+        print('Layer1 shape is {}'.format(X.get_shape()))
+
+        residual = 1
+        kernel = [3, 3, 32]
strides = [1, 2] + dilation = [] + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer2') + print('Layer2 shape is {}'.format(X.get_shape())) + + residual = 1 + kernel = [3, 3, 64] + strides = [1, 1] + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer3_1') + strides = [1, 2] + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer3_2') + print('Layer3 shape is {}'.format(X.get_shape())) + + kernel = [3, 3, 128] + strides = [1, 1] + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer4_1') + strides = [1, 1] + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer4_2') + print('Layer4 shape is {}'.format(X.get_shape())) + + kernel = [3, 3, 256] + strides = [1, 1, 1, 1] + dilation = 2 + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer5') + print('Layer5 shape is {}'.format(X.get_shape())) + + kernel = [3, 3, 512] + strides = [1, 1, 1, 1] + dilation = 4 + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer6') + print('Layer6 shape is {}'.format(X.get_shape())) + + residual = 0 + kernel = [3, 3, 512] + strides = [1, 1] + dilation = 2 + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer7') + print('Layer7 shape is {}'.format(X.get_shape())) + + kernel = [3, 3, 512] + strides = [1, 1] + dilation = 1 + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'layer8') + print('Layer8 shape is {}'.format(X.get_shape())) + + # 1x1 convolution to squash output to number of classes + kernel = [1, 1, self.num_classes] + strides = [1] + dilation = [] + X = self.conv_repeat(X, strides, dilation, kernel, residual, 'output') + print('Output shape is {}'.format(X.get_shape())) + + self.prob = X + self.pred = tf.argmax(X, 3) \ No newline at end of file diff --git a/drn/inference.py b/drn/inference.py new file mode 100644 index 0000000..e96b855 --- /dev/null +++ b/drn/inference.py @@ -0,0 +1,85 @@ +import os +import tensorflow as tf +import numpy as np +from data_reader import DataReader +from drn import DRN +from PIL import Image as im + +def predict(config): + + tf.reset_default_graph() + + directory = os.getcwd() + + pred_directory = os.path.join(directory, 'Pred') + + data = DataReader(config.directory, config.image_dims, config.batch_size, + config.num_epochs, use_weights=False) + dataset = data.test_batch(config.val_file) + num_images = data.num_images + + # get image filenames + image_list = data.image_list + + # determine number of iterations based on number of images + num_iterations = int(np.floor(num_images/config.batch_size)) + + # create iterator allowing us to switch between datasets + data_iterator = dataset.make_one_shot_iterator() + next_element = data_iterator.get_next() + + # create placeholder for train or test + train_network = tf.placeholder(tf.bool, []) + + # get images and pass into network + image, label, weight = next_element + drn = DRN(image, config.image_dims, config.batch_size, config.num_classes, + train_network, config.network) + + # get predictions and logits + prediction = drn.pred + logits = drn.prob + label = tf.squeeze(label, axis=-1) + + # resize the logits using bilinear interpolation + imsize = tf.constant([config.image_dims[0], config.image_dims[1]], dtype=tf.int32) + logits = tf.image.resize_bilinear(logits, imsize) + prediction = tf.argmax(logits, axis=-1) + print('Resized shape is {}'.format(logits.get_shape())) + + # global step to keep track of iterations + global_step = tf.Variable(0, trainable=False, name='global_step') + + saver = 
tf.train.Saver(max_to_keep=3)
+
+    init = tf.global_variables_initializer()
+
+    with tf.Session() as sess:
+
+        # initialize variables
+        sess.run(init)
+
+        # restore checkpoint if it exists
+        ckpt = tf.train.get_checkpoint_state(config.logs)
+        if ckpt and ckpt.model_checkpoint_path:
+            saver.restore(sess, ckpt.model_checkpoint_path)
+            print('Restoring session at step {}'.format(global_step.eval()))
+
+        iteration = global_step.eval()
+        for i in range(num_iterations):
+            print('step: {} of {}'.format(i, num_iterations))
+            img, pred = sess.run([image, prediction], feed_dict={train_network: False})
+
+            fnames = image_list[config.batch_size*i:config.batch_size*i + config.batch_size]
+            # write images to file
+            for j in range(pred.shape[0]):
+                fname = fnames[j].split('/')[-1]
+
+                # drop file extension
+                fname = fname.split('.')[0]
+
+                if not os.path.isdir(pred_directory):
+                    os.makedirs(pred_directory)
+
+                img_write = im.fromarray(pred[j, :, :], "L")
+                img_write.save(os.path.join(pred_directory, fname + ".png"))
\ No newline at end of file
diff --git a/drn/main.py b/drn/main.py
new file mode 100644
index 0000000..ffc4579
--- /dev/null
+++ b/drn/main.py
@@ -0,0 +1,23 @@
+from parse_args import parse_args
+import os
+from train_network import train
+from inference import predict
+
+def run():
+    """Runs dilated residual network model in either train or predict mode"""
+
+    config = parse_args()
+
+    if not os.path.isdir(config.logs):
+        os.makedirs(config.logs)
+
+    if config.mode == 'train':
+        train(config)
+    elif config.mode == 'predict':
+        predict(config)
+    else:
+        raise ValueError("Mode must be either train or predict")
+
+
+if __name__ == '__main__':
+    run()
\ No newline at end of file
diff --git a/drn/parse_args.py b/drn/parse_args.py
new file mode 100644
index 0000000..a35a12e
--- /dev/null
+++ b/drn/parse_args.py
@@ -0,0 +1,47 @@
+import argparse
+
+class Config:
+    pass
+
+def parse_args():
+    """Parses input arguments into a configuration class
+
+    Returns:
+      config: instance of configuration class
+    """
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--mode', type=str, default='train', help="Mode must be either train or predict")
+    parser.add_argument('--network', type=str, default='DRN18', help="Select either DRN18 or DRN26 for network")
+    parser.add_argument('--image_dims', type=int, default=500, help="Dimension of the input image, assumes square")
+    parser.add_argument('--channels', type=int, default=3, help="Dimension of the image color channel")
+    parser.add_argument('--num_epochs', type=int, default=50, help="Number of training epochs")
+    parser.add_argument('--batch_size', type=int, default=16, help="Batch size")
+    parser.add_argument('--directory', type=str, help="Enter the path to the directory containing all images")
+    parser.add_argument('--train_file', type=str, help="Path to the training data file")
+    parser.add_argument('--val_file', type=str, help="Path to the validation data file")
+    parser.add_argument('--num_classes', type=int, default=2, help="Number of classes to predict")
+    parser.add_argument('--loss', type=str, default='dice', help="Enter the loss function to use; either 'dice' or weighted cross-entropy - 'CE'")
+    parser.add_argument('--logs', type=str, default='logs', help="Enter the path to the log/save directory")
+    # store_true avoids the argparse pitfall where type=bool turns any
+    # non-empty string (including "False") into True
+    parser.add_argument('--use_weights', action='store_true', help="Flag indicating whether to include a weight map")
+    parser.add_argument('--learning_rate', type=float, default=0.001, help="Learning rate during training")
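+
+    # Example invocation (paths are placeholders; mirrors the README):
+    #   python main.py --train_file=path/to/train_data.txt \
+    #       --val_file=path/to/val_data.txt --directory=image/directory_path \
+    #       --num_classes=5 --image_dims=500 --channels=3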
training") + args = parser.parse_args() + + config = Config() + config.mode = args.mode + config.network = args.network + config.directory = args.directory + config.loss = args.loss + config.num_epochs = args.num_epochs + config.image_dims = [args.image_dims, args.image_dims, args.channels] + config.batch_size = args.batch_size + config.num_classes = args.num_classes + config.use_weights = args.use_weights + config.train_file = args.train_file + config.val_file = args.val_file + config.logs = args.logs + config.learning_rate = args.learning_rate + + return config \ No newline at end of file diff --git a/drn/train_network.py b/drn/train_network.py new file mode 100644 index 0000000..2210753 --- /dev/null +++ b/drn/train_network.py @@ -0,0 +1,136 @@ +import os +import tensorflow as tf +import numpy as np +from accuracy import jaccard, dice +from data_reader import DataReader +from drn import DRN +from dice import dice_loss + +def train(config): + """Trains the model based on configuration settings + + Args: + config: configurations for training the model + """ + + tf.reset_default_graph() + + data = DataReader(config.directory, config.image_dims, config.batch_size, + config.num_epochs, config.use_weights) + train_data = data.train_batch(config.train_file) + num_train_images = data.num_images + + test_data = data.test_batch(config.val_file) + num_val_images = data.num_images + + # determine number of iterations based on number of images + training_iterations = int(np.floor(num_train_images/config.batch_size)) + validation_iterations = int(np.floor(num_val_images/config.batch_size)) + + # create iterators allowing us to switch between datasets + handle = tf.placeholder(tf.string, shape=[]) + iterator = tf.data.Iterator.from_string_handle(handle, + train_data.output_types, train_data.output_shapes) + next_element = iterator.get_next() + training_iterator = train_data.make_initializable_iterator() + val_iterator = test_data.make_initializable_iterator() + + # create placeholder for train or test + train_network = tf.placeholder(tf.bool, []) + + # get images and pass into network + image, label, weight = next_element + drn = DRN(image, config.image_dims, config.batch_size, config.num_classes, + train_network, config.network) + + # get predictions and logits + prediction = drn.pred + logits = drn.prob + label = tf.squeeze(label, 3) + + # resize the logits using bilinear interpolation + imsize = tf.constant([config.image_dims[0], config.image_dims[1]], + dtype=tf.int32) + logits = tf.image.resize_bilinear(logits, imsize) + print('Resized shape is {}'.format(logits.get_shape())) + + prediction = tf.argmax(logits, 3) + + if config.loss == 'CE': + if config.use_weights: + label_one_hot = tf.one_hot(label, config.num_classes) + loss = tf.nn.softmax_cross_entropy_with_logits(labels=label_one_hot, + logits=logits) + loss = loss*tf.squeeze(weight, 3) + else: + # use sparse with flattened labelmaps + loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, + logits=logits) + loss = tf.reduce_mean(loss) + elif config.loss == 'dice': + loss = dice_loss(logits, label, config.num_classes, + use_weights=config.use_weights) + else: + NameError("Loss must be specified as CE or DICE") + + # global step to keep track of iterations + global_step = tf.Variable(0, trainable=False, name='global_step') + + # create placeholder for learning rate + learning_rate = tf.placeholder(tf.float32, shape=[]) + + optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step) + + saver = 
tf.train.Saver(max_to_keep=3) + + init = tf.global_variables_initializer() + + with tf.Session() as sess: + training_handle = sess.run(training_iterator.string_handle()) + validation_handle = sess.run(val_iterator.string_handle()) + + sess.run(training_iterator.initializer) + sess.run(init) + + ckpt = tf.train.get_checkpoint_state(config.logs) + if ckpt and ckpt.model_checkpoint_path: + saver.restore(sess, ckpt.model_checkpoint_path) + print('Restoring session at step {}'.format(global_step.eval())) + + # if restoring saved checkpoint get last saved iteration so that correct + # epoch can be restored + iteration = global_step.eval() + current_epoch = int(np.floor(iteration/training_iterations)) + + while current_epoch < config.num_epochs: + + train_loss = 0 + for i in range(training_iterations): + _, l = sess.run([optimizer, loss], feed_dict={handle:training_handle, + learning_rate:config.learning_rate, train_network:True}) + train_loss += l + iteration = global_step.eval() + + sess.run(val_iterator.initializer) + val_loss = 0 + for i in range(validation_iterations): + l, img, lbl, pred = sess.run([loss, image, label, prediction], + feed_dict={handle:validation_handle, train_network:False}) + val_loss += l + + # evaluate accuracy + accuracy = jaccard(lbl, pred, config.num_classes) + dice_score = dice(lbl, pred, config.num_classes) + + print('Train loss Epoch {} step {} :{}'.format(current_epoch, iteration, + train_loss/training_iterations)) + print('Validation loss Epoch {} step {} :{}'.format(current_epoch, iteration, + val_loss/validation_iterations)) + + with open('loss.txt', 'a') as f: + f.write("Epoch: {} Step: {} Loss: {}\n".format(current_epoch, iteration, + train_loss/training_iterations)) + + saver.save(sess, config.logs + '/model.ckpt', global_step) + + current_epoch += 1 \ No newline at end of file