-
Notifications
You must be signed in to change notification settings - Fork 120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Forest-Flow: Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees #69
Conversation
Add the ForestFlow notebook to the forest-flow branch.
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #69 +/- ##
==========================================
+ Coverage 35.63% 35.81% +0.17%
==========================================
Files 67 67
Lines 7417 7419 +2
==========================================
+ Hits 2643 2657 +14
+ Misses 4774 4762 -12 ☔ View full report in Codecov by Sentry. |
@atong01 it seems that only the tests in runner are run and not the new one. Can you check that, please? Thank you |
@kilianFatras I had a similar issue, the directory name should be |
examples/tabular/README.md
Outdated
```bash | ||
cd ../../ | ||
|
||
# install requirements |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't need to do this, it will be done automatically with the next line
requirements.txt
Outdated
@@ -13,3 +13,9 @@ pot | |||
torchdiffeq | |||
absl-py | |||
clean-fid | |||
pytest | |||
|
|||
# Forest-flow example |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need these here as they are in the extra requires
runner-requirements.txt
Outdated
@@ -50,3 +50,6 @@ pot | |||
# --------- notebook reqs -------- # | |||
seaborn>=0.12.2 | |||
pandas | |||
xgboost |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also not here
@@ -36,4 +36,5 @@ | |||
long_description=readme, | |||
long_description_content_type="text/markdown", | |||
packages=find_packages(), | |||
extras_require={"forest-flow": ["xgboost", "scikit-learn", "ForestDiffusion"]}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep perfrect
@@ -164,6 +164,9 @@ def sample_location_and_conditional_flow(self, x0, x1, return_noise=False): | |||
represents the source minibatch | |||
x1 : Tensor, shape (bs, *dim) | |||
represents the target minibatch | |||
(optionally) t : Tensor, shape (bs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
T here has to be batch size, but I think we should allow t as float.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe better input checking would be nice, but not necessary, good for now
What does this PR do?
This PR adds a notebook on the Forest-Flow method to TorchCFM. Forest-Flow's purpose is to generate tabular data with Flow Matching methods! We have added a notebook showing how to use XGBoost to train the vector field of I-CFM and generate tabular data. This has required the addition of a parameter 't' within the 'sample_location_and_conditional_flow' function within each class. As we have modified classes, we have also added our first test function.
Before submitting
pytest
command?pre-commit run -a
command?Did you have fun?
Make sure you had fun coding 🙃