-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_processing.py
70 lines (43 loc) · 2.04 KB
/
data_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import ants
import tensorflow as tf
import pathlib
import numpy as np
from PIL import Image
# Shared tuning sentinel handed to `Dataset.map(num_parallel_calls=...)` and
# `Dataset.prefetch(...)` in `load_datasets`.
# NOTE(review): newer TensorFlow releases appear to expose the same value as
# `tf.data.AUTOTUNE` - confirm the minimum supported TF version before switching.
AUTOTUNE = tf.data.experimental.AUTOTUNE
def load_png_image(image_path):
    """Load an image file as a float32 grayscale NumPy array.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``numpy.ndarray`` of dtype ``float32`` holding the pixels of the image
        after conversion to Pillow mode ``'L'`` (single-channel grayscale).
    """
    # Use a context manager so the underlying file handle is released as soon
    # as the pixels are decoded, instead of lingering until garbage collection
    # (this function is called once per file across a whole dataset).
    with Image.open(image_path) as image:
        # convert() forces the pixel data to be read and returns a new image
        # that no longer depends on the open file.
        grayscale = image.convert('L')
    return np.array(grayscale, dtype=np.float32)
def preprocess_with_antspy(image_np):
    """Bias-correct and intensity-normalize a NumPy image using ANTsPy.

    Args:
        image_np: Array accepted by ``ants.from_numpy``.

    Returns:
        The ANTs image obtained by running ``ants.n4_bias_field_correction``
        on the input and then ``ants.iMath(..., "Normalize")`` on the result.
    """
    bias_corrected = ants.n4_bias_field_correction(ants.from_numpy(image_np))
    return ants.iMath(bias_corrected, "Normalize")
def resize_image(image, target_size):
    """Resample an ANTs image to ``target_size`` via ``ants.resample_image``.

    Args:
        image: ANTs image to resample.
        target_size: Resample parameters forwarded unchanged; with
            ``use_voxels=True`` these are presumably interpreted as the output
            size in voxels rather than a spacing - confirm against ANTsPy docs.

    Returns:
        The resampled ANTs image.
    """
    # NOTE(review): interp_type=1 is believed to select nearest-neighbour
    # interpolation in ANTsPy - confirm this is intended for intensity images.
    resample_options = {"use_voxels": True, "interp_type": 1}
    return ants.resample_image(image, target_size, **resample_options)
def load_and_preprocess_ants_images(image_path, img_height, img_width):
    """Run the full per-image pipeline: load, preprocess, resize, to NumPy.

    Args:
        image_path: Path of the image file, passed to ``load_png_image``.
        img_height: First entry of the target size given to ``resize_image``.
        img_width: Second entry of the target size given to ``resize_image``.

    Returns:
        The result of calling ``.numpy()`` on the resized ANTs image.
    """
    target_size = [img_height, img_width]
    ants_image = preprocess_with_antspy(load_png_image(image_path))
    return resize_image(ants_image, target_size).numpy()
def load_datasets(data_dir_t1, data_dir_t2, img_height, img_width, batch_size, buffer_size):
    """Build shuffled, batched ``tf.data`` pipelines for two image folders.

    Every file matching ``*/*.png`` under each root directory is loaded and
    preprocessed with ``load_and_preprocess_ants_images``. The two pipelines
    are built and shuffled independently of each other.

    Args:
        data_dir_t1: Root directory of the first image set.
        data_dir_t2: Root directory of the second image set.
        img_height: Target height forwarded to the preprocessing function.
        img_width: Target width forwarded to the preprocessing function.
        batch_size: Number of images per emitted batch.
        buffer_size: Shuffle buffer size, counted in individual images.

    Returns:
        Tuple ``(tr1_train, tr2_train)`` of ``tf.data.Dataset`` objects
        yielding float32 image batches.
    """

    def process_image(image_path):
        # Executed eagerly by tf.py_function: the path arrives as a scalar
        # string tensor and has to be decoded into a Python str first.
        image_path_str = image_path.numpy().decode('utf-8')
        return load_and_preprocess_ants_images(
            image_path_str, img_height, img_width)

    def tf_process_image(image_path):
        image = tf.py_function(
            func=process_image, inp=[image_path], Tout=tf.float32)
        # py_function outputs carry no static shape; restore it so that
        # batching and downstream layers see a known (height, width) shape.
        image.set_shape([img_height, img_width])
        return image

    def build_dataset(data_dir):
        # Sort so the un-shuffled element order does not depend on the
        # filesystem's directory listing order.
        image_paths = sorted(
            str(p) for p in pathlib.Path(data_dir).glob('*/*.png'))
        # Shuffle the cheap path strings *before* decoding and batching:
        # shuffling after batch() would only permute whole batches and keep
        # `buffer_size` batches of decoded images in memory.
        return (
            tf.data.Dataset.from_tensor_slices(image_paths)
            .shuffle(buffer_size)
            .map(tf_process_image, num_parallel_calls=AUTOTUNE)
            .batch(batch_size)
            .prefetch(AUTOTUNE)
        )

    tr1_train = build_dataset(data_dir_t1)
    tr2_train = build_dataset(data_dir_t2)
    return tr1_train, tr2_train