From 68887ee7a9a56228b321161fcf27b5a6975faf89 Mon Sep 17 00:00:00 2001 From: Emily Veenhuis Date: Fri, 11 Oct 2024 11:04:20 -0400 Subject: [PATCH] WIP, need to randomly select fill color --- examples/MC_RISE.ipynb | 83 +++++++++++++++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 18 deletions(-) diff --git a/examples/MC_RISE.ipynb b/examples/MC_RISE.ipynb index c76467fc..3ef90fad 100644 --- a/examples/MC_RISE.ipynb +++ b/examples/MC_RISE.ipynb @@ -410,31 +410,78 @@ " n: int,\n", " s: int,\n", " p1: float,\n", - " k: int,\n", + " fill_colors: Sequence[Sequence[int]],\n", " seed: Optional[int] = None,\n", " threads: int = 0,\n", " ):\n", - " self._k = k\n", - " self._po = PerturbationOcclusion(\n", - " RISEGrid(n=n, s=s, p1=p1, seed=seed, threads=threads),\n", - " MCRISEScoring(k=k, p1=p1),\n", - " threads=threads\n", - " )\n", - "\n", - " @property\n", - " def fill(self) -> Optional[Union[int, Sequence[int]]]:\n", - " return self._po.fill\n", - "\n", - " @fill.setter\n", - " def fill(self, v: Optional[Union[int, Sequence[int]]]) -> None:\n", - " self._po.fill = v\n", + " self._perturber = RISEGrid(n=n, s=s, p1=p1, seed=seed, threads=threads)\n", + " self._generator = MCRISEScoring(k=k, p1=p1)\n", + " self._threads = threads\n", + " self._fill_colors = fill_colors\n", "\n", " def _generate(self, ref_image: np.ndarray, blackbox: ClassifyImage) -> np.ndarray:\n", - " return self._po.generate(ref_image, blackbox)\n", + " perturbation_masks = self._perturber(ref_image)\n", + " class_list = blackbox.get_labels()\n", + " # Input one thing so assume output of one thing.\n", + " ref_conf_dict = list(blackbox.classify_images([ref_image]))[0]\n", + " ref_conf_vec = np.asarray([ref_conf_dict[la] for la in class_list])\n", + " pert_conf_mat = np.empty(\n", + " (perturbation_masks.shape[0], ref_conf_vec.shape[0]),\n", + " dtype=ref_conf_vec.dtype\n", + " )\n", + "\n", + " def _occlude_image_streaming(\n", + " ref_image: np.ndarray,\n", + " masks: Iterable[np.ndarray],\n", + " threads: Optional[int] = None,\n", + " ) -> Generator[np.ndarray, None, None]:\n", + " # Just the [H x W] component.\n", + " img_shape = ref_image.shape[:2]\n", + " s: Tuple = (...,)\n", + " if ref_image.ndim > 2:\n", + " s = (..., None) # add channel axis for multiplication\n", + " \n", + " def work_func(i_: int, m: np.ndarray) -> np.ndarray:\n", + " m_shape = m.shape\n", + " if m_shape != img_shape:\n", + " raise ValueError(\n", + " f\"Input mask (position {i_}) did not the shape of the input \"\n", + " f\"image: {m_shape} != {img_shape}\"\n", + " )\n", + " img_m = np.empty_like(ref_image)\n", + " \n", + " # TODO: Randomly select fill colors for masks\n", + " np.add(\n", + " (m[s] * ref_image),\n", + " ((UINT8_ONE - m[s]) * fill),\n", + " out=img_m, casting=\"unsafe\"\n", + " )\n", + " \n", + " return img_m\n", + "\n", + " pert_conf_it = blackbox.classify_images(\n", + " _occlude_image_streaming(\n", + " ref_image, perturbation_masks,\n", + " fill=self.fill,\n", + " threads=self._threads\n", + " )\n", + " )\n", + " for i, pc in enumerate(pert_conf_it):\n", + " pert_conf_mat[i] = [pc[la] for la in class_list]\n", + "\n", + " # Compose classification results into a matrix for the generator\n", + " # algorithm.\n", + " return self._generator(\n", + " ref_conf_vec,\n", + " pert_conf_mat,\n", + " perturbation_masks,\n", + " )\n", "\n", " def get_config(self) -> Dict[str, Any]:\n", - " c = self._po._perturber.get_config()\n", - " c['k'] = self._k\n", + " # It turns out that our configuration here is equivalent to that given\n", + " # and retrieved from the RISEGrid implementation\n", + " c = self._perturber.get_config()\n", + " c['fill_colors'] = self._fill_colors\n", " return c" ] },