From 1b70baa1b50e80fc6eebf8eb6fc8997bdb37659c Mon Sep 17 00:00:00 2001 From: Olivier Dulcy Date: Sat, 6 Jan 2024 16:33:12 +0100 Subject: [PATCH] fix in collate_fn --- doctr/datasets/datasets/tensorflow.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/doctr/datasets/datasets/tensorflow.py b/doctr/datasets/datasets/tensorflow.py index 86b7b7928..da7890f97 100644 --- a/doctr/datasets/datasets/tensorflow.py +++ b/doctr/datasets/datasets/tensorflow.py @@ -49,10 +49,18 @@ def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]: @staticmethod def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]: - images, targets = zip(*samples) + # FIXME + # problems with some shape != 3 + images, targets = [], [] + for sample in samples: + if sample[0].shape[-1] == 3: + images.append(sample[0]) + targets.append(sample[1]) + + # images, targets = zip(*samples) images = tf.stack(images, axis=0) - return images, list(targets) + return images, targets class VisionDataset(AbstractDataset, _VisionDataset): # noqa: D101