diff --git a/oodeel/methods/dknn.py b/oodeel/methods/dknn.py index 983316f..e7d1470 100644 --- a/oodeel/methods/dknn.py +++ b/oodeel/methods/dknn.py @@ -38,21 +38,28 @@ class DKNN(OODBaseDetector): Args: nearest: number of nearest neighbors to consider. Defaults to 1. + use_gpu (bool): Flag to enable GPU acceleration for FAISS. Defaults to False. """ - def __init__( - self, - nearest: int = 50, - ): + def __init__(self, nearest: int = 50, use_gpu: bool = False): super().__init__() - self.index = None self.nearest = nearest + self.use_gpu = use_gpu + + if self.use_gpu: + try: + self.res = faiss.StandardGpuResources() + except AttributeError as e: + raise ImportError( + "faiss-gpu is not installed, but use_gpu was set to True." + + "Please install faiss-gpu or set use_gpu to False." + ) from e def _fit_to_dataset(self, fit_dataset: Union[TensorType, DatasetType]) -> None: """ Constructs the index from ID data "fit_dataset", which will be used for - nearest neighbor search. + nearest neighbor search. Can operate on CPU or GPU based on the `use_gpu` flag. Args: fit_dataset: input dataset (ID) to construct the index with. @@ -61,7 +68,13 @@ def _fit_to_dataset(self, fit_dataset: Union[TensorType, DatasetType]) -> None: fit_projected = self.op.convert_to_numpy(fit_projected[0]) fit_projected = fit_projected.reshape(fit_projected.shape[0], -1) norm_fit_projected = self._l2_normalization(fit_projected) - self.index = faiss.IndexFlatL2(norm_fit_projected.shape[1]) + + if self.use_gpu: + cpu_index = faiss.IndexFlatL2(norm_fit_projected.shape[1]) + self.index = faiss.index_cpu_to_gpu(self.res, 0, cpu_index) + else: + self.index = faiss.IndexFlatL2(norm_fit_projected.shape[1]) + self.index.add(norm_fit_projected) def _score_tensor(self, inputs: TensorType) -> Tuple[np.ndarray]: