-
Notifications
You must be signed in to change notification settings - Fork 178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Example] Add AI4Chem example IFM(nips2023) #1002
base: develop
Are you sure you want to change the base?
Conversation
- if not change 't' to 'T', run check cannot be passed
- train val func impl
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢提交PR,有一些修改建议如下
examples/ifm/conf/ifm.yaml
Outdated
hydra: | ||
run: | ||
# dynamic output directory according to running time and override name | ||
dir: outputs_ifm/doc_metric #${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | ||
job: | ||
name: ${mode} # name of logfile | ||
chdir: false # keep current working directory unchanged | ||
config: | ||
override_dirname: | ||
exclude_keys: | ||
- TRAIN.checkpoint_path | ||
- TRAIN.pretrained_model_path | ||
- EVAL.pretrained_model_path | ||
- mode | ||
- output_dir | ||
- log_freq | ||
sweep: | ||
# output directory for multirun | ||
dir: ${hydra.run.dir} | ||
subdir: ./ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hydra: | |
run: | |
# dynamic output directory according to running time and override name | |
dir: outputs_ifm/doc_metric #${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | |
job: | |
name: ${mode} # name of logfile | |
chdir: false # keep current working directory unchanged | |
config: | |
override_dirname: | |
exclude_keys: | |
- TRAIN.checkpoint_path | |
- TRAIN.pretrained_model_path | |
- EVAL.pretrained_model_path | |
- mode | |
- output_dir | |
- log_freq | |
sweep: | |
# output directory for multirun | |
dir: ${hydra.run.dir} | |
subdir: ./ | |
defaults: | |
- ppsci_default | |
- TRAIN: train_default | |
- TRAIN/ema: ema_default | |
- TRAIN/swa: swa_default | |
- EVAL: eval_default | |
- INFER: infer_default | |
- hydra/job/config/override_dirname/exclude_keys: exclude_keys_default | |
- _self_ | |
hydra: | |
run: | |
# dynamic output directory according to running time and override name | |
dir: outputs_ifm/doc_metric #${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} | |
job: | |
name: ${mode} # name of logfile | |
chdir: false # keep current working directory unchanged | |
sweep: | |
# output directory for multirun | |
dir: ${hydra.run.dir} | |
subdir: ./ | |
|
||
|
||
class IFMMoeDataset(io.Dataset): | ||
"""Dataset for `MeshAirfoil`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处不是MeshAirfoil
return (num_indices - num_pos) / num_pos | ||
|
||
|
||
class IFMMoeDataset(io.Dataset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
请删除IFMMoeDataset内有关MeshAirfoil的文本,MeshAirfoil是翼型数据集,与化学无关
# super().__init__() | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# super().__init__() | |
pass | |
super().__init__() |
self.mask = ~np.isnan(Ys) * 1.0 | ||
|
||
def __len__(self): | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pass |
examples/ifm/ednn_utils.py
Outdated
return tn, fp, fn, tp, se, sp, acc, mcc, auc_prc, auc_roc | ||
|
||
|
||
class Meter(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class Meter(object): | |
class Meter: |
examples/ifm/ednn_utils.py
Outdated
return self.mcc() | ||
|
||
|
||
class MyDataset(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个类有用到吗?
examples/ifm/ednn_utils.py
Outdated
def collate_fn(data_batch): | ||
Xs, Ys, masks = map(list, zip(*data_batch)) | ||
|
||
Xs = paddle.stack(Xs, axis=0) | ||
Ys = paddle.stack(Ys, axis=0) | ||
masks = paddle.stack(masks, axis=0) | ||
|
||
return Xs, Ys, masks | ||
|
||
|
||
def set_random_seed(seed=0): | ||
random.seed(seed) | ||
np.random.seed(seed) | ||
paddle.manual_seed(seed) # 为CPU设置种子用于生成随机数 | ||
if paddle.cuda.is_available(): | ||
paddle.cuda.manual_seed(seed) # 为当前GPU设置随机种子 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个函数是否与其他文件中的重复了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
本文件是否可以删除?看起来每个函数在其他文件中都有实现
examples/ifm/ifm.py
Outdated
model, | ||
output_dir=cfg.output_dir, | ||
log_freq=cfg.log_freq, | ||
seed=cfg.seed, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seed设置这行可以删除
ppsci/arch/ifm_mlp.py
Outdated
inputs (int): Input dim. | ||
hidden_units (List[int]): Units num in hidden layers. | ||
outputs (int): Output dim. | ||
dp_ratio (float): Dropout ratio. | ||
first_omega_0 (float): Frequency factor used in first layer. | ||
hidden_omega_0 (float): Frequency factor used in hidden layer. | ||
reg (bool): Regularization flag. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参数相比Args,要增加四个空格的缩进
examples/ifm/ednn_utils.py
Outdated
def init_parameter_uniform( | ||
parameter: paddle.base.framework.EagerParamBase, n: int | ||
) -> None: | ||
ppsci.utils.initializer.uniform_(parameter, -1 / np.sqrt(n), 1 / np.sqrt(n)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
examples/ifm/ednn_utils.py里的init_parameter_uniform函数是不是可以删除,另一个文件有这个函数,并且ednn_utils.py里的函数并没有被使用。
PR types
New features
PR changes
Others : Add new AI4Chem example.
Describe
Add new AI4Chem example including code and docs.