Skip to content

Commit

Permalink
unregister if already registered
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya1503 committed Jan 26, 2024
1 parent 43794dc commit 4a2e1da
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions object_detection/detectron2_training-kfold.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@
"import glob\n",
"from sklearn.model_selection import KFold\n",
"import json\n",
"from collections import defaultdict\n",
"from detectron2.data.datasets import register_coco_instances"
"from collections import defaultdict"
]
},
{
Expand Down Expand Up @@ -150,15 +149,24 @@
" annotations_count = len(data_dict['annotations'])\n",
" print(f\"Number of images: {images_count}, Number of annotations: {annotations_count}\")\n",
"\n",
" \n",
"def unregister_coco_instances(name):\n",
" if name in DatasetCatalog.list():\n",
" DatasetCatalog.remove(name)\n",
" MetadataCatalog.remove(name)\n",
"\n",
"# Generate K-Fold cross-validation\n",
"kf = KFold(n_splits=NUM_FOLDS)\n",
"pairs = []\n",
"for fold, (train_indices, test_indices) in enumerate(kf.split(image_ids)):\n",
" train_data, test_data = split_data(train_indices, test_indices)\n",
" # Register COCO instances for training and validation. \n",
" # Note: The 'train2017' folder is retained as the base path for images.\n",
" train_file = f\"train_coco_{fold}_fold.json\"\n",
" test_file = f\"test_coco_{fold}_fold.json\"\n",
" # Unregister instances with the same names only if they exist\n",
" unregister_coco_instances(train_file)\n",
" unregister_coco_instances(test_file)\n",
" # Register COCO instances for training and validation. \n",
" # Note: The 'train2017' folder is retained as the base path for images.\n",
" register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
" register_coco_instances(test_file, {}, test_file, \"train2017\")\n",
" pairs.append([train_file,test_file])\n",
Expand Down Expand Up @@ -215,8 +223,6 @@
" \n",
" def data_loader_mapper(self, batch):\n",
" return batch\n",
" \n",
"\n",
"\n",
" def run_hooks(self):\n",
" val_loss = self.validation()\n",
Expand Down

0 comments on commit 4a2e1da

Please sign in to comment.