TypeError: DiT.forward() got an unexpected keyword argument 'cond_drop' #15

Open
banhongjun opened this issue Dec 24, 2024 · 1 comment

Comments

@banhongjun

Could someone please help me with this problem:
Traceback (most recent call last):
File "/root/ban/F5-TTS-main_20241223/F5-TTS-ONNX-main/Export_ONNX/F5_TTS/Export_F5.py", line 330, in
torch.onnx.export(
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
_export(
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/onnx/utils.py", line 1612, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/jit/_trace.py", line 1310, in _get_trace_graph
outs = ONNXTracedModule(
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/jit/_trace.py", line 138, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/jit/_trace.py", line 129, in wrapper
outs.append(self.inner(*trace_inputs))
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
result = self.forward(*input, **kwargs)
File "/root/ban/F5-TTS-main_20241223/F5-TTS-ONNX-main/Export_ONNX/F5_TTS/Export_F5.py", line 168, in forward
pred = self.f5_transformer(x=noise, cond=cat_mel_text, cond_drop=cat_mel_text_drop, time=self.time_expand[:, time_step], rope_cos=rope_cos, rope_sin=rope_sin, qk_rotated_empty=qk_rotated_empty)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/f5-tts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: DiT.forward() got an unexpected keyword argument 'cond_drop'

@DakeQQ
Owner

DakeQQ commented Dec 24, 2024

The error DiT.forward() got an unexpected keyword argument 'cond_drop' is caused by an incomplete copy when using shutil.copyfile().

Here is the copy script:

shutil.copyfile(modified_path + '/vocos/heads.py', python_package_path + '/vocos/heads.py')
shutil.copyfile(modified_path + '/vocos/models.py', python_package_path + '/vocos/models.py')
shutil.copyfile(modified_path + '/vocos/modules.py', python_package_path + '/vocos/modules.py')
shutil.copyfile(modified_path + '/vocos/pretrained.py', python_package_path + '/vocos/pretrained.py')
shutil.copyfile(modified_path + '/F5/modules.py', F5_project_path + '/f5_tts/model/modules.py')
shutil.copyfile(modified_path + '/F5/dit.py', F5_project_path + '/f5_tts/model/backbones/dit.py')
shutil.copyfile(modified_path + '/F5/utils_infer.py', F5_project_path + '/f5_tts/infer/utils_infer.py')

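Sometimes, because of disk latency, the copy may not fully complete. If that happens, click “RUN” and re-run the script. (Please make sure all paths are correct.)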
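One way to catch an incomplete copy before exporting is to compare each destination file with its source byte for byte. The sketch below is only an illustration: `copy_and_verify` is a hypothetical helper, not part of the repository's script, and it assumes the same `modified_path` and `F5_project_path` variables used in the copy script.

```python
import filecmp
import shutil
from pathlib import Path


def copy_and_verify(src, dst):
    """Copy src over dst, then confirm the two files are byte-identical.

    Hypothetical helper for illustration; the original script calls
    shutil.copyfile() directly without a check.
    """
    src, dst = Path(src), Path(dst)
    if not src.is_file():
        raise FileNotFoundError(f"source file not found: {src}")
    if not dst.parent.is_dir():
        raise FileNotFoundError(f"destination directory not found: {dst.parent}")
    shutil.copyfile(src, dst)
    # shallow=False forces a content comparison instead of trusting
    # file size and modification time alone.
    if not filecmp.cmp(src, dst, shallow=False):
        raise RuntimeError(f"copy incomplete: {dst} differs from {src}")
    return dst


# Example call, using the variables from the copy script (assumed to be defined):
# copy_and_verify(modified_path + '/F5/dit.py',
#                 F5_project_path + '/f5_tts/model/backbones/dit.py')
```

If a path is wrong, the helper raises immediately instead of letting the export fail later with a less obvious error.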

We just tested this: if the copy completes normally, the ONNX model should export successfully.
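To confirm that the modified dit.py is the one actually being imported, it may help to inspect the signature of `DiT.forward` before running the export. The following is a minimal sketch; `accepts_keyword` is a hypothetical helper, and the import path `f5_tts.model.backbones.dit` is an assumption derived from the file path in the copy script.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if func can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    param = params.get(name)
    if param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all also accepts arbitrary keyword names.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Example check (assumes the f5_tts package is importable under this path):
# from f5_tts.model.backbones.dit import DiT
# print(accepts_keyword(DiT.forward, "cond_drop"))
```

If the check reports that `cond_drop` is not accepted, the installed dit.py is probably still the unmodified version, and the copy step should be run again.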
