Added subset_in_dim option
DomInvivo committed Dec 15, 2023
1 parent 7859351 commit f95872d
Showing 2 changed files with 46 additions and 25 deletions.
65 changes: 40 additions & 25 deletions graphium/nn/architectures/global_architectures.py
@@ -413,7 +413,7 @@ def __init__(
hidden_dims: Union[List[int], int],
num_ensemble: int,
reduction: Union[str, Callable],
subset_sample_ratio: float = 1.0,
subset_in_dim: Union[float, int] = 1.0,
depth: Optional[int] = None,
activation: Union[str, Callable] = "relu",
last_activation: Union[str, Callable] = "none",
@@ -462,9 +462,11 @@ def __init__(
- "median": Median reduction
- `Callable`: Any callable function. Must take `dim` as a keyword argument.
subset_sample_ratio:
Ratio of the subset of the ensemble to use.
Must be between 0 and 1. A different subset is used for each ensemble.
subset_in_dim:
If float, ratio of the subset of the ensemble to use. Must be between 0 and 1.
If int, number of elements to subset from in_dim.
If `None`, the subset_in_dim is set to `1.0`.
A different subset is used for each ensemble.
Only valid if the input shape is `[B, Din]`.
depth:
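To make the new option concrete, a small worked sketch (not part of the commit) of how a float versus an int `subset_in_dim` translates into a per-ensemble feature count, following the docstring above and assuming `in_dim=64`:

in_dim = 64

# Float: interpreted as a ratio of the input features
subset_in_dim = 0.25
n_features = max(1, int(in_dim * subset_in_dim))  # 16 features per ensemble member

# Int: interpreted directly as the number of input features (must not exceed in_dim)
subset_in_dim = 10
n_features = subset_in_dim                        # 10 features per ensemble member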
@@ -538,8 +540,8 @@ def __init__(
layer_kwargs = {}
layer_kwargs["num_ensemble"] = self._parse_num_ensemble(num_ensemble, layer_kwargs)

# Parse the sample ratio
self.subset_sample_ratio, self.subset_in_dim, self.subset_idx = self._parse_subset_sample(subset_sample_ratio, num_ensemble)
# Parse the sample input dimension
self.subset_in_dim, self.subset_idx = self._parse_subset_in_dim(in_dim, subset_in_dim, num_ensemble)

super().__init__(
in_dim=in_dim,
@@ -612,68 +614,81 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona
elif reduction == "sum":
return torch.sum
elif reduction == "max":

def max_vals(x, dim):
return torch.max(x, dim=dim).values

return max_vals
elif reduction == "min":

def min_vals(x, dim):
return torch.min(x, dim=dim).values

return min_vals
elif reduction == "median":

def median_vals(x, dim):
return torch.median(x, dim=dim).values

return median_vals
elif callable(reduction):
return reduction
else:
raise ValueError(f"Unknown reduction {reduction}")
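For illustration only (not part of the diff): a minimal runnable sketch of how the closures returned by `_parse_reduction` behave when collapsing the ensemble axis; the tensor layout used here is an assumption, not taken from the commit.

import torch

x = torch.randn(4, 8, 16)  # assumed layout: [num_ensemble, batch, out_dim]

def max_vals(x, dim):
    # Same closure as in the committed code: drop the indices, keep the values
    return torch.max(x, dim=dim).values

reduced_max = max_vals(x, dim=-3)     # -> [8, 16], max over the ensemble axis
reduced_mean = torch.mean(x, dim=-3)  # the "mean" option maps directly to torch.mean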

def _parse_subset_sample(self, in_dim: int, subset_sample_ratio: float, num_ensemble: int) -> Tuple[float, int]:
def _parse_subset_in_dim(
self, in_dim: int, subset_in_dim: Union[float, int], num_ensemble: int
) -> Tuple[float, int]:
r"""
Parse the subset_sample_ratio argument and the subset_in_dim.
Parse the subset_in_dim argument and the subset_in_dim.
The subset_sample_ratio is the ratio of the hidden features to use by each MLP of the ensemble.
The subset_in_dim is the ratio of the hidden features to use by each MLP of the ensemble.
The subset_in_dim is the number of input features to use by each MLP of the ensemble.
Parameters:
in_dim: The number of input features, before subsampling
subset_sample_ratio:
subset_in_dim:
Ratio of the subset of features to use by each MLP of the ensemble.
Must be between 0 and 1. A different subset is used for each ensemble.
Only valid if the input shape is `[B, Din]`.
If None, the subset_sample_ratio is set to 1.0.
If None, the subset_in_dim is set to 1.0.
num_ensemble:
Number of MLPs that run in parallel.
Returns:
subset_sample_ratio: The ratio of the subset of features to use by each MLP of the ensemble.
subset_in_dim: The number of input features to use by each MLP of the ensemble.
subset_in_dim: The ratio of the subset of features to use by each MLP of the ensemble.
subset_idx: The indices of the features to use by each MLP of the ensemble.
"""

# Parse the subset_sample_ratio, make sure value is between 0 and 1
if subset_sample_ratio is None:
subset_sample_ratio = 1.0
assert subset_sample_ratio > 0.0 and subset_sample_ratio <= 1.0, f"subset_sample_ratio={subset_sample_ratio}"
# Parse the subset_in_dim, make sure value is between 0 and 1
if subset_in_dim is None:
subset_in_dim = 1.0
if isinstance(subset_in_dim, int):
assert (
subset_in_dim > 0 and subset_in_dim <= in_dim
), f"subset_in_dim={subset_in_dim}, in_dim={in_dim}"
elif isinstance(subset_in_dim, float):
assert subset_in_dim > 0.0 and subset_in_dim <= 1.0, f"subset_in_dim={subset_in_dim}"

# Parse the subset_in_dim, make sure value is between 0 and in_dim
subset_in_dim = int(torch.ceil(in_dim * subset_sample_ratio).item())
if subset_in_dim == 0:
subset_in_dim = 1
# Convert to integer value
subset_in_dim = int(in_dim * subset_in_dim)
if subset_in_dim == 0:
subset_in_dim = 1

# Create the subset_idx, which is a list of indices to use for each ensemble
if subset_in_dim == in_dim:
subset_idx = None
else:
subset_idx = torch.stack([
torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)])
subset_idx = torch.stack(
[torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)]
).unsqueeze(-2)

return subset_sample_ratio, subset_in_dim, subset_idx
return subset_in_dim, subset_idx
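A standalone sketch of the subsetting logic this commit introduces, using assumed example sizes (`in_dim=8`, `num_ensemble=3`); it mirrors the float branch of `_parse_subset_in_dim` above but is an illustration, not the method itself.

import torch

in_dim, num_ensemble = 8, 3
subset_in_dim = 0.5  # ratio of input features kept per ensemble member

# Ratio -> integer feature count, clamped to at least one feature
n_keep = max(1, int(in_dim * subset_in_dim))  # -> 4

# One random feature subset per ensemble member, with an extra axis for broadcasting
subset_idx = torch.stack(
    [torch.randperm(in_dim)[:n_keep] for _ in range(num_ensemble)]
).unsqueeze(-2)
print(subset_idx.shape)  # torch.Size([3, 1, 4])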

def _parse_layers(self, layer_type, residual_type):
# Parse the layer and residuals
@@ -707,7 +722,7 @@ def forward(self, h: torch.Tensor) -> torch.Tensor:
if self.subset_idx is not None:
if len(h.shape) != 2:
assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}"
h = h[..., self.subset_idx]
h = h[..., self.subset_idx]

# Run the standard forward pass
h = super().forward(h)
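To illustrate the intent of the subsetting in `forward` (each ensemble member sees its own random slice of the input features), a hedged sketch with assumed sizes; it demonstrates the idea with an explicit loop rather than reproducing the exact advanced-indexing call used above.

import torch

B, Din, E, S = 5, 8, 3, 4  # assumed: batch, input dim, ensemble size, subset size
h = torch.randn(B, Din)
subset_idx = torch.stack([torch.randperm(Din)[:S] for _ in range(E)])  # [E, S]

# Each ensemble member gets its own feature subset
h_subsets = torch.stack([h[:, idx] for idx in subset_idx])  # [E, B, S]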
6 changes: 6 additions & 0 deletions graphium/nn/ensemble_layers.py
@@ -391,16 +391,22 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona
elif reduction == "sum":
return torch.sum
elif reduction == "max":

def max_vals(x, dim):
return torch.max(x, dim=dim).values

return max_vals
elif reduction == "min":

def min_vals(x, dim):
return torch.min(x, dim=dim).values

return min_vals
elif reduction == "median":

def median_vals(x, dim):
return torch.median(x, dim=dim).values

return median_vals
elif callable(reduction):
return reduction
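Since the docstring earlier in this commit notes that a callable reduction must take `dim` as a keyword argument, here is a short hedged sketch of passing a custom callable; the function name and tensor layout are illustrative assumptions, not part of the library.

import torch

def spread(x, dim):
    # Custom reduction: max minus min along the given axis
    return torch.max(x, dim=dim).values - torch.min(x, dim=dim).values

x = torch.randn(4, 8, 16)  # assumed layout: [num_ensemble, batch, out_dim]
out = spread(x, dim=-3)    # -> [8, 16]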
