-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path1train_model.py
32 lines (29 loc) · 1.32 KB
/
1train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# train_model.py
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Data pipeline: one image directory (one sub-folder per class, as expected by
# flow_from_directory) is split into training and validation subsets.
DATA_DIR = 'downloaded_images'
IMAGE_SIZE = (224, 224)  # must match the MobileNetV2 input_shape used for the model
BATCH_SIZE = 32
# Both generators MUST use the same split fraction: flow_from_directory derives
# the 'training' / 'validation' subsets from this value, so a mismatch would
# make the subsets overlap (validation leakage) or drop images.
VALIDATION_SPLIT = 0.2

# Training images: MobileNetV2 preprocessing plus random augmentation.
train_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    validation_split=VALIDATION_SPLIT,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation images: identical preprocessing but NO augmentation, so validation
# metrics measure the model on unaltered images instead of randomly distorted ones.
val_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    validation_split=VALIDATION_SPLIT,
)

train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
)
validation_generator = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False,  # evaluation does not need shuffling; keeps batch order deterministic
)
# Backbone: MobileNetV2 with ImageNet weights and without its own classifier
# (include_top=False). Its weights are frozen so only the new head is trained.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet',
)
base_model.trainable = False

# Head, stacked on the frozen backbone: pool the feature maps to a vector, one
# ReLU hidden layer, then a softmax over the classes that flow_from_directory
# found in the training directory.
model = tf.keras.Sequential()
model.add(base_model)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(1024, activation='relu'))
model.add(tf.keras.layers.Dense(train_generator.num_classes, activation='softmax'))

# One-hot labels (class_mode='categorical') pair with categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for a fixed number of epochs, reporting validation metrics each epoch,
# then write the trained model to disk.
model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
)
model.save('trained_model.h5')