Draft implementation of point estimation #281
base: dev
Conversation
A notebook demonstrating training and inference for such a point estimator can be found here: https://github.com/han-ol/bayesflow/blob/point-estimation/examples/draft-point-est.ipynb
After writing this up, in my opinion the best structure is probably to have:

For class names I would propose:

If we want to support convenient creation of heads by just specifying a loss function, this can be done by subclassing a

Additionally, I would prefer an implementation where `estimate()` returns a dictionary with named estimates corresponding to the individual heads. After collecting some of your thoughts, I would proceed with implementing whatever we converge on. What do you think?
Hans, thanks for the great ideas and the very mature first draft! Here are some initial thoughts from my side. More to come later:
This looks really cool already! Thank you for this PR! There is a lot of content in this thread already, and I may benefit from a call where you show me the current state. This would help me give reasonable feedback. I will contact you offline about it.
Ok cool, Paul! Stefan, thanks for your takes already! Some notes on some of them:
I am not sure about this. Tagging @LarsKue for this question; I think you mentioned a preference for parallel implementation rather than inheritance?
Agreed that some form of naming is necessary. How do the "data classes" you imagine differ from dictionaries of this type?

```python
# assuming a batch_size of two, two quantile levels and 3 inference variables:
dict(
    mean=[
        [1, 2, 3],
        [2, 1, 3],
    ],  # shape=(2, 3)
    quantiles=[
        [[-1, 0, 1], [3, 4, 5]],
        [[1, -1, -2], [3, 2, 5]],
    ],  # shape=(2, 2, 3)
    ...
)
```

For one thing, it might be good to make the quantile levels accessible somewhere close to the estimated quantiles. This could be part of a data class.
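The quantile levels could then travel together with the estimated values. A minimal sketch of such a data class (names are illustrative, not a proposed API):

```python
from dataclasses import dataclass

# Minimal sketch: a data class that keeps quantile levels next to the
# estimated quantiles; names are illustrative, not a settled API.
@dataclass
class QuantileEstimate:
    levels: list   # quantile levels, e.g. [0.25, 0.75]
    values: list   # nested list of shape (batch_size, num_levels, num_variables)

est = QuantileEstimate(
    levels=[0.25, 0.75],
    values=[[[-1, 0, 1], [3, 4, 5]],
            [[1, -1, -2], [3, 2, 5]]],  # batch_size=2, 2 levels, 3 variables
)
print(len(est.levels))  # 2
```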
👍
Ok! Just in their defence, I'd say point estimators are also fully Bayesian ;) They are functionals of the proper Bayesian posterior distribution.
Good point! Using both
Regarding the output, I think we can go with dictionaries for now. I believe some custom data classes will come in handy rather soon. Just a thought: can the heads simply be determined automatically, assuming the scoring rules know their dims?
Yes, they can mostly be inferred, and I'd suggest linear layers followed by a reshape as an overridable default. However, some scoring rules benefit from (e.g. quantiles, monotonically increasing) or need (e.g. covariance matrix, positive semidefinite) a specific architecture.
This (draft) pull request is meant to hold discussion and development of point estimation within BayesFlow.
The functionality per se was also discussed in issue #121.
The implementation should make it easy to
Commit 093785d contains a first example including ONLY quantile estimation.
## Writing down draft specifications and guiding thoughts

### Names for everything

What is "inference", and does it include point estimation? Are we calling the networks discussed below `PointInferenceNetwork`s, `PointRegressor`s, or something else?
For now I stick with `ContinuousPointApproximator` and `PointInferenceNetwork` to make their roles in relation to the existing codebase obvious.

### Components
A `ContinuousPointApproximator` parallels the `ContinuousApproximator` and bundles some feed-forward architecture with an optional summary network, suitable for learning point estimates optimized to minimize some Bayes risk. Thus it serves the same roles (including Adapter calls, summary network, etc.), but instead of a `sample` method it provides an `estimate` method.

A `PointInferenceNetwork` parallels the `InferenceNetwork` by providing a base class to inherit from for feed-forward model classes suitable for point estimation.

### Convenient default estimator
The API for functionality that covers the needs of most users could be something like this, with optional constructor arguments to tweak it a bit:
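As a purely hypothetical sketch (class names, arguments, and defaults below are assumptions, not settled API), this might look like:

```python
# Hypothetical sketch only -- names and signatures are assumptions,
# not the actual BayesFlow API.
approximator = ContinuousPointApproximator(
    summary_network=summary_net,          # optional summary network
    inference_network=PointInferenceNetwork(
        subnet="mlp",                     # shared body, resolved via find_network
        quantile_levels=[0.1, 0.5, 0.9],  # optional tweak of the default heads
    ),
)
estimates = approximator.estimate(data_dict)  # dict of named point estimates
```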
### Choices in the output of `estimate`

For inference, the method `estimate(data_dict)` produces the point estimates for a given input/condition/observation. The default `PointInferenceNetwork` would produce 5 point estimates (mean, std, and three quantile levels) of the marginal distribution for each inference variable.

These estimates need more explanation than samples provided by generative networks typically need. We need to communicate to users, diagnostics code, or other downstream tasks which estimate lands where. It seems to me that it would be helpful if such `estimate` output is structured as a dictionary with the point estimate names as keys, like `dict(mean=tensor, std=tensor, ...)`, rather than as a tensor formed by concatenating all individual estimates.
### Architecture

The architecture has one shared "body" network as well as separate "head" networks for each scoring rule. The `subnet` keyword argument has a default of `"mlp"`; in general the argument is resolved by `find_network` from `bayesflow.utils`, which can take a string for predefined networks or a user-defined custom class. Currently the non-shared networks are just linear layers.
### Extendable design

The first draft only includes a `PointInferenceNetwork` subclass for quantile estimation, currently called `QuantileRegressor` and found in `bayesflow/networks/regressors/quantile_regressor.py`. This is just the first step; we want to support different loss functions (typically called scoring rules in this context) that result in other point estimates.
A more flexible API including custom scoring rules/losses could use a `PointInferenceNetwork` that accepts a single `ScoringRule` or a sequence of them. It can generate the appropriate number of heads in the `build()` method, and pass the respective outputs to the corresponding scoring rules, which compute their actual loss contributions and sum them up.
A `ScoringRule`/`ScoringLoss` has a name (we want to distinguish multiple of them later) and a `score`/`compute` method. A `ScoringRule` could also compute a `prediction_shape` in the constructor, to be accessed by generic code that generates a corresponding head for the scoring rule.

Choices we have here:
- `prediction_shape` could also be a `prediction_size`, an int rather than a tuple.
- The head would have `output_shape=(*batch_shape, *prediction_shape, num_variables)`, e.g. for a batch shape of `(32,)` and a prediction shape of `(1,)` with 4 variables, `(32, 1, 4)` shall be the head `output_shape`, produced by a linear layer and a reshape operation which together form a head.
- Interaction with the choice of output of the `estimate` method: if we choose to allow multidimensional `prediction_shape`s rather than only `prediction_size`s, we cannot concatenate the outputs of different heads, since each can have an individual `prediction_shape`. Thus this decision interacts with whether `estimate` should return a tensor or a dict of tensors (see above).
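A small numpy illustration of this interaction (shapes are illustrative): 1-D prediction shapes can share a concatenated tensor, while a genuinely multidimensional `prediction_shape` forces a dict of tensors:

```python
import numpy as np

batch, num_vars = 32, 4
mean_out = np.zeros((batch, 1, num_vars))   # prediction_shape (1,)
quant_out = np.zeros((batch, 3, num_vars))  # prediction_shape (3,)

# 1-D prediction shapes concatenate fine along the estimate axis:
print(np.concatenate([mean_out, quant_out], axis=1).shape)  # (32, 4, 4)

# A multidimensional prediction_shape, e.g. (2, 2), yields a rank-4 output
# that cannot join the rank-3 concatenation -- hence a dict of tensors:
multi_out = np.zeros((batch, 2, 2, num_vars))
estimates = {"mean": mean_out, "quantiles": quant_out, "matrix": multi_out}
print(estimates["matrix"].shape)  # (32, 2, 2, 4)
```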
Below are some notes on what `ScoringRule` definitions could look like. If we want to support specific activation functions for different scoring rules, we might add a non-parametric activation function to the `ScoringRule`'s definition.
We could also go all the way and have the `ScoringRule` also contain the head itself, taking the last joint embedding and mapping it to estimates. This would then naturally include an optional nonlinearity at the end and simultaneously give users the option to tweak the architecture of the non-shared weights.

Other notes:
- `bf.Workflow`?