From 4a2e1da39ffa5ac08d5f7fda4c8a6e748de7eedf Mon Sep 17 00:00:00 2001
From: Aditya Thyagarajan
Date: Fri, 26 Jan 2024 21:55:45 +0530
Subject: [PATCH] unregister if already registered

---
 .../detectron2_training-kfold.ipynb | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/object_detection/detectron2_training-kfold.ipynb b/object_detection/detectron2_training-kfold.ipynb
index 5e4e8aa..aee0f63 100644
--- a/object_detection/detectron2_training-kfold.ipynb
+++ b/object_detection/detectron2_training-kfold.ipynb
@@ -49,8 +49,7 @@
     "import glob\n",
     "from sklearn.model_selection import KFold\n",
     "import json\n",
-    "from collections import defaultdict\n",
-    "from detectron2.data.datasets import register_coco_instances"
+    "from collections import defaultdict"
    ]
   },
   {
@@ -150,15 +149,24 @@
     "    annotations_count = len(data_dict['annotations'])\n",
     "    print(f\"Number of images: {images_count}, Number of annotations: {annotations_count}\")\n",
     "\n",
+    " \n",
+    "def unregister_coco_instances(name):\n",
+    "    if name in DatasetCatalog.list():\n",
+    "        DatasetCatalog.remove(name)\n",
+    "        MetadataCatalog.remove(name)\n",
+    "\n",
     "# Generate K-Fold cross-validation\n",
     "kf = KFold(n_splits=NUM_FOLDS)\n",
     "pairs = []\n",
     "for fold, (train_indices, test_indices) in enumerate(kf.split(image_ids)):\n",
     "    train_data, test_data = split_data(train_indices, test_indices)\n",
-    "    # Register COCO instances for training and validation. \n",
-    "    # Note: The 'train2017' folder is retained as the base path for images.\n",
     "    train_file = f\"train_coco_{fold}_fold.json\"\n",
     "    test_file = f\"test_coco_{fold}_fold.json\"\n",
+    "    # Unregister instances with the same names only if they exist\n",
+    "    unregister_coco_instances(train_file)\n",
+    "    unregister_coco_instances(test_file)\n",
+    "    # Register COCO instances for training and validation. \n",
+    "    # Note: The 'train2017' folder is retained as the base path for images.\n",
     "    register_coco_instances(train_file, {}, train_file, \"train2017\")\n",
     "    register_coco_instances(test_file, {}, test_file, \"train2017\")\n",
     "    pairs.append([train_file,test_file])\n",
@@ -215,8 +223,6 @@
     "    \n",
     "    def data_loader_mapper(self, batch):\n",
     "        return batch\n",
-    "    \n",
-    "\n",
     "\n",
     "    def run_hooks(self):\n",
     "        val_loss = self.validation()\n",
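Context for the change: detectron2 keeps registered datasets in the global `DatasetCatalog`/`MetadataCatalog` registries, so re-running the k-fold cell in the same kernel makes `register_coco_instances` fail with an "already registered" error; the patch clears any stale entries before registering again. The snippet below is a minimal, self-contained sketch of that unregister-before-register pattern, not the notebook's exact cell: it assumes detectron2 is installed, uses an illustrative `NUM_FOLDS` value, and imports `register_coco_instances` explicitly (the patch removes that import from the notebook's first import cell).

```python
# Minimal sketch of the unregister-before-register pattern added by this patch.
# Assumptions: detectron2 is installed, NUM_FOLDS = 5 is illustrative, and the
# per-fold JSON files / "train2017" image root mirror the notebook's layout.
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances

NUM_FOLDS = 5  # assumed value; the notebook defines its own NUM_FOLDS


def unregister_coco_instances(name):
    """Drop a dataset from both global catalogs if an earlier run registered it."""
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
        MetadataCatalog.remove(name)


for fold in range(NUM_FOLDS):
    train_file = f"train_coco_{fold}_fold.json"
    test_file = f"test_coco_{fold}_fold.json"
    # Clear any registrations left over from a previous execution of this cell,
    # then register the per-fold COCO splits against the shared image root.
    unregister_coco_instances(train_file)
    unregister_coco_instances(test_file)
    register_coco_instances(train_file, {}, train_file, "train2017")
    register_coco_instances(test_file, {}, test_file, "train2017")
```

Because registration is lazy (the catalogs store only a loader callable plus metadata), the cell can be re-executed safely; the fold JSON files are read only when a dataloader actually builds the dataset.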