-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1088 from mrapp-ke/output-sampling-options
Unify options for configuring sampling methods
- Loading branch information
Showing
8 changed files
with
179 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 56 additions & 12 deletions
68
python/subprojects/common/mlrl/common/cython/output_sampling.pyx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,74 @@ | ||
""" | ||
@author: Michael Rapp ([email protected]) | ||
""" | ||
from mlrl.common.cython.validation import assert_greater_or_equal | ||
from mlrl.common.cython.validation import assert_greater, assert_greater_or_equal, assert_less | ||
|
||
|
||
cdef class OutputSamplingWithoutReplacementConfig: | ||
""" | ||
Allows to configure a method for sampling outputs without replacement. | ||
""" | ||
|
||
def get_num_samples(self) -> int: | ||
def get_sample_size(self) -> float: | ||
""" | ||
Returns the number of outputs that are included in a sample. | ||
Returns the fraction of outputs that are included in a sample. | ||
|
||
:return: The number of outputs that are included in a sample | ||
:return: The fraction of outputs that are included in a sample | ||
""" | ||
return self.config_ptr.getNumSamples() | ||
return self.config_ptr.getSampleSize() | ||
|
||
def set_num_samples(self, num_samples: int) -> OutputSamplingWithoutReplacementConfig: | ||
def set_sample_size(self, sample_size: float) -> OutputSamplingWithoutReplacementConfig: | ||
""" | ||
Sets the number of outputs that should be included in a sample. | ||
Sets the fraction of outputs that should be included in a sample. | ||
|
||
:param num_samples: The number of outputs that should be included in a sample. Must be at least 1 | ||
:return: An `OutputSamplingWithoutReplacementConfig` that allows further configuration of the | ||
sampling method | ||
:param sample_size: The fraction of outputs that should be included in a sample, e.g., a value of 0.6 | ||
corresponds to 60 % of the available outputs. Must be in (0, 1) | ||
:return: An `OutputSamplingWithoutReplacementConfig` that allows further configuration of the method | ||
for sampling outputs | ||
""" | ||
assert_greater_or_equal('num_samples', num_samples, 1) | ||
self.config_ptr.setNumSamples(num_samples) | ||
assert_greater('sample_size', sample_size, 0) | ||
assert_less('sample_size', sample_size, 1) | ||
self.config_ptr.setSampleSize(sample_size) | ||
return self | ||
|
||
def get_min_samples(self) -> int: | ||
""" | ||
Returns the minimum number of outputs that are included in a sample. | ||
|
||
:return: The minimum number of outputs that are included in a sample | ||
""" | ||
return self.config_ptr.getMinSamples() | ||
|
||
def set_min_samples(self, min_samples: int) -> OutputSamplingWithoutReplacementConfig: | ||
""" | ||
Sets the minimum number of outputs that should be included in a sample. | ||
|
||
:param min_samples: The minimum number of outputs that should be included in a sample. Must be at least 1 | ||
:return: An `OutputSamplingWithoutReplacementConfig` that allows further configuration of the method | ||
for sampling outputs | ||
""" | ||
assert_greater_or_equal('min_samples', min_samples, 1) | ||
self.config_ptr.setMinSamples(min_samples) | ||
return self | ||
|
||
def get_max_samples(self) -> int: | ||
""" | ||
Returns the maximum number of outputs that are included in a sample. | ||
|
||
:return: The maximum number of outputs that are included in a sample | ||
""" | ||
return self.config_ptr.getMaxSamples() | ||
|
||
def set_max_samples(self, max_samples: int) -> OutputSamplingWithoutReplacementConfig: | ||
""" | ||
Sets the maximum number of outputs that should be included in a sample. | ||
|
||
:param max_samples: The maximum number of outputs that should be included in a sample. Must be at least | ||
`get_min_samples()` or 0, if the number of outputs should not be restricted | ||
:return: An `OutputSamplingWithoutReplacementConfig` that allows further configuration of the method | ||
for sampling outputs | ||
""" | ||
if max_samples != 0: | ||
assert_greater_or_equal('max_samples', max_samples, self.get_min_samples()) | ||
self.config_ptr.setMaxSamples(max_samples) | ||
return self |