diff --git a/infer.py b/infer.py new file mode 100644 index 0000000..fc21d47 --- /dev/null +++ b/infer.py @@ -0,0 +1,38 @@ +from model import lowlight_enhance, load_images +import tensorflow as tf +import os + + +def lowlight_test(input_file, lowlight_enhance): + test_low_data_name = [input_file] + test_low_data = [] + test_high_data = [] + for i in range(1): + print('fileload', test_low_data_name[i]) + test_low_im = load_images(test_low_data_name[i]) + print('fileload return', test_low_im) + test_low_data.append(test_low_im) + + lowlight_enhance.test(test_low_data, test_high_data, test_low_data_name, save_dir='test_results', decom_flag=0) + + +def main(input_file, use_gpu=False): + print('called main') + if use_gpu: + print("[*] GPU\n") + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_idx + gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem) + with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: + model = lowlight_enhance(sess) + lowlight_test(input_file, model) + else: + print("[*] CPU\n") + with tf.Session() as sess: + model = lowlight_enhance(sess) + lowlight_test(input_file, model) + + +#if __name__ == '__main__': +# main(input_file, use_gpu=False) + + #tf.app.run() \ No newline at end of file diff --git a/model.py b/model.py index ad784fb..a2fab36 100644 --- a/model.py +++ b/model.py @@ -10,9 +10,11 @@ from utils import * + def concat(layers): return tf.concat(layers, axis=3) + def DecomNet(input_im, layer_num, channel=64, kernel_size=3): input_max = tf.reduce_max(input_im, axis=3, keepdims=True) input_im = concat([input_max, input_im]) @@ -27,6 +29,7 @@ def DecomNet(input_im, layer_num, channel=64, kernel_size=3): return R, L + def RelightNet(input_L, input_R, channel=64, kernel_size=3): input_im = concat([input_R, input_L]) with tf.variable_scope('RelightNet'): @@ -49,6 +52,7 @@ def RelightNet(input_L, input_R, channel=64, kernel_size=3): output = tf.layers.conv2d(feature_fusion, 1, 3, padding='same', activation=None) return output + class lowlight_enhance(object): def __init__(self, sess): self.sess = sess @@ -194,7 +198,7 @@ def train(self, train_low_data, train_high_data, eval_low_data, batch_size, patc iter_num += 1 # evalutate the model and save a checkpoint file for it - if (epoch + 1) % eval_every_epoch == 0: + if (epoch + 1) % int(eval_every_epoch) == 0: self.evaluate(epoch + 1, eval_low_data, sample_dir=sample_dir, train_phase=train_phase) self.save(saver, iter_num, ckpt_dir, "RetinexNet-%s" % train_phase) diff --git a/utils.py b/utils.py index 6fc320c..c0e4fb4 100644 --- a/utils.py +++ b/utils.py @@ -1,5 +1,10 @@ import numpy as np from PIL import Image +import time + +# import storgae library for cloud +from google.cloud import storage + def data_augmentation(image, mode): if mode == 0: @@ -30,10 +35,12 @@ def data_augmentation(image, mode): image = np.rot90(image, k=3) return np.flipud(image) + def load_images(file): im = Image.open(file) return np.array(im, dtype="float32") / 255.0 + def save_images(filepath, result_1, result_2 = None): result_1 = np.squeeze(result_1) result_2 = np.squeeze(result_2) @@ -45,3 +52,21 @@ def save_images(filepath, result_1, result_2 = None): im = Image.fromarray(np.clip(cat_image * 255.0, 0, 255.0).astype('uint8')) im.save(filepath, 'png') + + +def get_epoch_time(): #epoch time get function + epoch_time = int(time.time()) + return epoch_time + + +def upload_blob(bucket_name, source_file_name, destination_blob_name): #GCP Storage upload function + """Uploads a file to the bucket.""" + storage_client = storage.Client() + bucket = storage_client.get_bucket(bucket_name) + blob = bucket.blob(destination_blob_name) + blob.upload_from_filename(source_file_name) + print('File {} uploaded to {}.'.format( + source_file_name, destination_blob_name)) + + +