Skip to content

Commit

Permalink
WIP, need to randomly select fill color
Browse files Browse the repository at this point in the history
  • Loading branch information
eveenhuis committed Oct 11, 2024
1 parent c14a811 commit 68887ee
Showing 1 changed file with 65 additions and 18 deletions.
83 changes: 65 additions & 18 deletions examples/MC_RISE.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -410,31 +410,78 @@
" n: int,\n",
" s: int,\n",
" p1: float,\n",
" k: int,\n",
" fill_colors: Sequence[Sequence[int]],\n",
" seed: Optional[int] = None,\n",
" threads: int = 0,\n",
" ):\n",
" self._k = k\n",
" self._po = PerturbationOcclusion(\n",
" RISEGrid(n=n, s=s, p1=p1, seed=seed, threads=threads),\n",
" MCRISEScoring(k=k, p1=p1),\n",
" threads=threads\n",
" )\n",
"\n",
" @property\n",
" def fill(self) -> Optional[Union[int, Sequence[int]]]:\n",
" return self._po.fill\n",
"\n",
" @fill.setter\n",
" def fill(self, v: Optional[Union[int, Sequence[int]]]) -> None:\n",
" self._po.fill = v\n",
" self._perturber = RISEGrid(n=n, s=s, p1=p1, seed=seed, threads=threads)\n",
" self._generator = MCRISEScoring(k=k, p1=p1)\n",
" self._threads = threads\n",
" self._fill_colors = fill_colors\n",
"\n",
" def _generate(self, ref_image: np.ndarray, blackbox: ClassifyImage) -> np.ndarray:\n",
" return self._po.generate(ref_image, blackbox)\n",
" perturbation_masks = self._perturber(ref_image)\n",
" class_list = blackbox.get_labels()\n",
" # Input one thing so assume output of one thing.\n",
" ref_conf_dict = list(blackbox.classify_images([ref_image]))[0]\n",
" ref_conf_vec = np.asarray([ref_conf_dict[la] for la in class_list])\n",
" pert_conf_mat = np.empty(\n",
" (perturbation_masks.shape[0], ref_conf_vec.shape[0]),\n",
" dtype=ref_conf_vec.dtype\n",
" )\n",
"\n",
" def _occlude_image_streaming(\n",
" ref_image: np.ndarray,\n",
" masks: Iterable[np.ndarray],\n",
" threads: Optional[int] = None,\n",
" ) -> Generator[np.ndarray, None, None]:\n",
" # Just the [H x W] component.\n",
" img_shape = ref_image.shape[:2]\n",
" s: Tuple = (...,)\n",
" if ref_image.ndim > 2:\n",
" s = (..., None) # add channel axis for multiplication\n",
" \n",
" def work_func(i_: int, m: np.ndarray) -> np.ndarray:\n",
" m_shape = m.shape\n",
" if m_shape != img_shape:\n",
" raise ValueError(\n",
" f\"Input mask (position {i_}) did not the shape of the input \"\n",
" f\"image: {m_shape} != {img_shape}\"\n",
" )\n",
" img_m = np.empty_like(ref_image)\n",
" \n",
" # TODO: Randomly select fill colors for masks\n",
" np.add(\n",
" (m[s] * ref_image),\n",
" ((UINT8_ONE - m[s]) * fill),\n",
" out=img_m, casting=\"unsafe\"\n",
" )\n",
" \n",
" return img_m\n",
"\n",
" pert_conf_it = blackbox.classify_images(\n",
" _occlude_image_streaming(\n",
" ref_image, perturbation_masks,\n",
" fill=self.fill,\n",
" threads=self._threads\n",
" )\n",
" )\n",
" for i, pc in enumerate(pert_conf_it):\n",
" pert_conf_mat[i] = [pc[la] for la in class_list]\n",
"\n",
" # Compose classification results into a matrix for the generator\n",
" # algorithm.\n",
" return self._generator(\n",
" ref_conf_vec,\n",
" pert_conf_mat,\n",
" perturbation_masks,\n",
" )\n",
"\n",
" def get_config(self) -> Dict[str, Any]:\n",
" c = self._po._perturber.get_config()\n",
" c['k'] = self._k\n",
" # It turns out that our configuration here is equivalent to that given\n",
" # and retrieved from the RISEGrid implementation\n",
" c = self._perturber.get_config()\n",
" c['fill_colors'] = self._fill_colors\n",
" return c"
]
},
Expand Down

0 comments on commit 68887ee

Please sign in to comment.