"""
Modified from the Pyimagesearch 5-part series on Siamese networks: https://pyimg.co/dq1w5
"""
# Standard Library imports
from pathlib import Path
import re


def extract_epoch_and_loss(filename: str | Path | None):
    """
    Parse a checkpoint filename of the form "epoch_<N>-loss_<X.XXXX>" and
    return (initial epoch, initial loss threshold).

    Returns (0, None) when no filename is given.
    """
if filename is None:
return 0, None
if isinstance(filename, Path):
filename = str(filename)
match = re.search(r"epoch_(\d+)-loss_(\d+\.\d+)", filename)
if match:
return int(match[1]), float(match[2])
raise ValueError(f"Incorrect filename format in file '{filename}'")
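# Example with a hypothetical checkpoint name that matches the pattern above:
#   extract_epoch_and_loss("epoch_07-loss_0.1234.keras")  ->  (7, 0.1234)
# A name without the "epoch_..-loss_.." pattern raises ValueError.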


def get_latest_epoch_filename(folder_path: Path):
    """Return the name of the checkpoint in folder_path with the highest epoch number, or None."""
    latest_epoch = -1
    latest_filename = None
    for filename in folder_path.rglob("*.keras"):
        epoch, _ = extract_epoch_and_loss(filename.name)
        if epoch > latest_epoch:
            # Track both the epoch and the filename so files visited later with
            # lower epoch numbers do not overwrite the current best candidate.
            latest_epoch = epoch
            latest_filename = filename.name
    return latest_filename


def get_model_path(filename):
    """Return the full path to a checkpoint file, or None when no filename is given."""
    if filename is not None:
        return OUTPUT_PATH / filename
    return None
########################################################################################################################
# Input paths
########################################################################################################################
ROOT = Path("/content")
DATA = ROOT / "oracle-cards"
TRAIN_DATASET = DATA / "train"
VALID_DATASET = DATA / "val"
DATA_SUBSET = ROOT / "oracle-cards-subset"
QUERY_DATASET = ROOT / "query"
########################################################################################################################
# Output paths
########################################################################################################################
OUTPUT_PATH = Path("siamese_output", "densenet121")
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
# Filename of the checkpoint with the largest epoch number
CKPT_FILENAME = get_latest_epoch_filename(OUTPUT_PATH)
CKPT_FILENAME_PT = "checkpoint.pt"
# Checkpoint to load from, if one exists
LOAD_MODEL_PATH = get_model_path(CKPT_FILENAME)
LOAD_MODEL_PATH_PT = OUTPUT_PATH / CKPT_FILENAME_PT
# New checkpoints will be saved here (epoch and validation loss fill the placeholders)
MODEL_CKPT_PATH = OUTPUT_PATH / "epoch_{epoch:02d}-loss_{val_loss:.4f}.keras"
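# Sketch of how this template is expected to be consumed (an assumption: a standard
# tf.keras ModelCheckpoint callback in the training script; the callback itself is
# not defined in this file):
#
#   checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
#       filepath=str(MODEL_CKPT_PATH),
#       monitor="val_loss",
#       save_best_only=True,
#   )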
# Index and metadata paths
FAISS_INDEX_PATH = OUTPUT_PATH / "index.faiss"
MANUAL_INDEX_PATH = OUTPUT_PATH / "index.pickle"
LOGS_PATH = OUTPUT_PATH / "logs"
IMAGES_DF_PATH = OUTPUT_PATH / "images.csv"
ORACLE_CARDS_CSV = "oracle-cards-20240821210300.csv"
########################################################################################################################
# Other parameters
########################################################################################################################
MOMENTUM = 0.937
# Model input image size
IMAGE_SIZE = (357, 256)
# Number of features in the embedding generated by the backbone
EMBEDDING_SHAPE = 128  # DenseNet121
# Inference parameters
N_RESULTS = 9  # number of results returned per query
# Index parameters
INDEX_TYPE = "faiss"  # one of: "faiss", "dict"
EXTENSIONS = ("*.jpg", "*.jpeg", "*.png")
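# Sketch of how IMAGE_SIZE and EMBEDDING_SHAPE relate to the backbone (an assumption
# for illustration only; the actual model definition lives outside this file):
#
#   backbone = tf.keras.applications.DenseNet121(
#       include_top=False, pooling="avg", input_shape=(*IMAGE_SIZE, 3)
#   )
#   embedding = tf.keras.layers.Dense(EMBEDDING_SHAPE)(backbone.output)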
########################################################################################################################
# Training parameters
########################################################################################################################
TRAIN_BACKBONE = False
LEARNING_RATE = 1e-4
INITIAL_EPOCH, _ = extract_epoch_and_loss(CKPT_FILENAME)
INITIAL_LOSS = None  # None saves checkpoints regardless of the starting loss
EPOCHS = 100
BATCH_SIZE = 4
NUM_TRAIN_SAMPLES = len(list(TRAIN_DATASET.rglob("*.jpg")))
NUM_VALIDATION_SAMPLES = len(list(VALID_DATASET.rglob("*.jpg")))
STEPS_PER_EPOCH = NUM_TRAIN_SAMPLES // BATCH_SIZE
VALIDATION_STEPS = NUM_VALIDATION_SAMPLES // BATCH_SIZE
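# Example with hypothetical counts: 1000 training images and BATCH_SIZE = 4 give
# STEPS_PER_EPOCH = 250; any remainder that does not fill a batch is dropped by the
# floor division.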
NUM_WORKERS = 0
WHITE = (1, 1, 1)  # white in normalised RGB ([0, 1] range)
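

if __name__ == "__main__":
    # Optional sanity check, assuming this module is run directly (e.g. `python config.py`):
    # print the resolved resume state and dataset sizes before launching a training run.
    print(f"Latest checkpoint : {CKPT_FILENAME}")
    print(f"Load model path   : {LOAD_MODEL_PATH}")
    print(f"Initial epoch     : {INITIAL_EPOCH}")
    print(f"Train samples     : {NUM_TRAIN_SAMPLES} ({STEPS_PER_EPOCH} steps/epoch)")
    print(f"Validation samples: {NUM_VALIDATION_SAMPLES} ({VALIDATION_STEPS} steps)")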