Adding Forest-Flow: Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees #69

Merged
kilianFatras merged 30 commits into main from forest_flow on Nov 26, 2023

Conversation

kilianFatras
Collaborator

What does this PR do?

This PR adds a notebook on the Forest-Flow method to TorchCFM. Forest-Flow generates tabular data with Flow Matching methods. The notebook shows how to use XGBoost to train the vector field of I-CFM and generate tabular data. This required adding an optional parameter 't' to the 'sample_location_and_conditional_flow' function of each class. Because we modified these classes, we have also added our first test function.
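A minimal sketch of what the optional 't' argument enables, assuming the usual I-CFM definitions: location x_t = (1 - t) * x0 + t * x1 + sigma * eps and conditional flow x1 - x0. The function name icfm_sample, its sigma and rng arguments, and the use of plain Python lists instead of tensors are assumptions made for illustration; this is not the TorchCFM implementation.

```python
import random


def icfm_sample(x0, x1, t=None, sigma=0.0, rng=None):
    """Hypothetical stand-in for sample_location_and_conditional_flow (I-CFM).

    x0, x1: lists of equal-length float lists (source and target minibatches).
    t: optional list of per-sample times in [0, 1]; drawn uniformly if None.
    """
    rng = rng or random.Random()
    if len(x0) != len(x1):
        raise ValueError("x0 and x1 must have the same batch size")
    if t is None:
        t = [rng.random() for _ in x0]
    if len(t) != len(x0):
        raise ValueError("t must have one entry per sample")
    xt, ut = [], []
    for ti, a, b in zip(t, x0, x1):
        # Location: linear interpolation plus Gaussian noise of scale sigma.
        xt.append([(1 - ti) * ai + ti * bi + sigma * rng.gauss(0.0, 1.0)
                   for ai, bi in zip(a, b)])
        # Conditional flow (the regression target): constant x1 - x0.
        ut.append([bi - ai for ai, bi in zip(a, b)])
    return t, xt, ut


x0 = [[0.0, 0.0], [1.0, 1.0]]
x1 = [[2.0, 4.0], [3.0, 5.0]]
# Fix one time value for the whole batch instead of sampling it.
t, xt, ut = icfm_sample(x0, x1, t=[0.5, 0.5], sigma=0.0)
```

Passing the same t for every sample gives training pairs (xt, ut) at a single, known time value, which is convenient when a separate regressor such as an XGBoost model is fit per time level. That reading of the notebook is an assumption here.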

Before submitting

  • Did you make sure the title is self-explanatory and the description concisely explains the PR?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you test your PR locally with the pytest command?
  • Did you run pre-commit hooks with the pre-commit run -a command?

Did you have fun?

Make sure you had fun coding 🙃


@kilianFatras kilianFatras changed the title Forest flow Forest flow (TorchCFM 1.0.5) Nov 10, 2023
@codecov-commenter

codecov-commenter commented Nov 10, 2023

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (81fcb8d) 35.63% compared to head (cac49f1) 35.81%.

| Files | Patch % | Lines |
|---|---|---|
| torchcfm/conditional_flow_matching.py | 83.33% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #69      +/-   ##
==========================================
+ Coverage   35.63%   35.81%   +0.17%     
==========================================
  Files          67       67              
  Lines        7417     7419       +2     
==========================================
+ Hits         2643     2657      +14     
+ Misses       4774     4762      -12     


@kilianFatras kilianFatras requested a review from atong01 November 10, 2023 17:10
@kilianFatras kilianFatras changed the title Forest flow (TorchCFM 1.0.5) Adding Forest-Flow: Generating and Imputing Tabular Data via Diffusion and Flow-based Gradient-Boosted Trees Nov 10, 2023
@kilianFatras
Collaborator Author

@atong01 it seems that only the tests in runner are run, not the new one. Could you check that, please? Thank you.

@guillaumehu
Contributor

@kilianFatras I had a similar issue: the directory should be named tests instead of test. Hope it works!

```bash
cd ../../

# install requirements
```
Owner

No need to do this; the next line does it automatically.

requirements.txt Outdated
@@ -13,3 +13,9 @@ pot
torchdiffeq
absl-py
clean-fid
pytest

# Forest-flow example
Owner

I don't think we need these here, as they are already in extras_require.

@@ -50,3 +50,6 @@ pot
# --------- notebook reqs -------- #
seaborn>=0.12.2
pandas
xgboost
Owner

Also not here

@@ -36,4 +36,5 @@
long_description=readme,
long_description_content_type="text/markdown",
packages=find_packages(),
extras_require={"forest-flow": ["xgboost", "scikit-learn", "ForestDiffusion"]},
Owner

Yep, perfect.

@@ -164,6 +164,9 @@ def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
represents the source minibatch
x1 : Tensor, shape (bs, *dim)
represents the target minibatch
(optionally) t : Tensor, shape (bs)
Owner

Here t has to have shape (bs), but I think we should also allow t to be a float.

Owner

Better input checking would be nice but is not necessary; this is good for now.
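One possible way to act on both review comments, sketched in plain Python: accept either a float or a per-sample sequence for t and validate it up front. The helper name normalize_t and the [0, 1] range check are assumptions made for illustration, not part of this PR.

```python
from numbers import Real


def normalize_t(t, batch_size):
    """Hypothetical helper: return t as a list with one time per sample."""
    if isinstance(t, Real):
        # A single float is broadcast to the whole batch.
        values = [float(t)] * batch_size
    else:
        values = [float(v) for v in t]
        if len(values) != batch_size:
            raise ValueError(
                f"t has {len(values)} entries, expected {batch_size}"
            )
    # Assumed constraint: times live on the unit interval.
    if any(not 0.0 <= v <= 1.0 for v in values):
        raise ValueError("t must lie in [0, 1]")
    return values


shared = normalize_t(0.25, 3)            # one float for the whole batch
per_sample = normalize_t([0.0, 1.0], 2)  # one value per sample
```

A tensor-based version would follow the same two branches, broadcasting the float to shape (bs) before the existing computation runs.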

@kilianFatras kilianFatras merged commit 1440a98 into main Nov 26, 2023
33 checks passed
@kilianFatras kilianFatras deleted the forest_flow branch December 14, 2023 15:36
5 participants