When deserializing a PyTorch model whose Sequential container holds layers with weights of different ranks (e.g., Linear and Conv1d) into a Burn model, the process panics with a dimension mismatch error.
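```rust
use burn::backend::LibTorch;
use burn::module::Module;
use burn::nn::conv::{Conv1d, Conv1dConfig};
use burn::nn::{Linear, LinearConfig};
use burn::record::{FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;
use burn_import::pytorch::PyTorchFileRecorder;

// Sequential-like container: each element can be a different layer type.
#[derive(Module, Debug)]
enum SeqType<B: Backend> {
    Linear(Linear<B>),
    Conv1d(Conv1d<B>),
}

#[derive(Module, Debug)]
struct SeqModel<B: Backend> {
    seq: Vec<SeqType<B>>,
}

impl<B: Backend> SeqModel<B> {
    fn new(device: &B::Device) -> Self {
        let linear = LinearConfig::new(8, 4).init(device);
        let conv = Conv1dConfig::new(4, 4, 3).init(device);
        Self {
            seq: vec![SeqType::Linear(linear), SeqType::Conv1d(conv)],
        }
    }
}

fn main() {
    let device = Default::default();
    // Load the PyTorch state dict (seq.0.* is the Linear, seq.1.* the Conv1d).
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("./test_seq.pt".into(), &device)
        .expect("Failed when decoding state dict");
    let model = SeqModel::<LibTorch>::new(&device);
    let model = model.load_record(record);
    println!("{:#?}", model);
}
```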
This panics with:

```
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
thread 'main' panicked at /home/med/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.14.0/src/tensor/api/base.rs:722:9:
=== Tensor Operation Error ===
Operation: 'From Data'
Reason:
1. Given dimensions differ from the tensor rank. Tensor rank: '2', given dimensions: '[4, 4, 3]'.
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
```
This suggests that the error occurs when trying to deserialize the weight of Conv1d (Tensor<B, 3>) into a Linear weight (Tensor<B, 2>) while attempting to match each variant in the enum.
The solution would be to propagate the dimension mismatch error upwards instead of panicking.
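A minimal sketch of what propagating the error could look like, assuming a hypothetical fallible shape check: instead of asserting that the data's dimensions match the tensor rank `D`, it returns a `Result` the caller can handle. `RankError` and `try_shape` are illustrative names, not Burn's API; the error message simply mirrors the one in the panic log above.

```rust
use std::fmt;

/// Hypothetical error returned when data of one rank is loaded into a tensor of another rank.
#[derive(Debug, PartialEq)]
pub struct RankError {
    pub expected: usize,
    pub given: Vec<usize>,
}

impl fmt::Display for RankError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Given dimensions differ from the tensor rank. Tensor rank: '{}', given dimensions: '{:?}'.",
            self.expected, self.given
        )
    }
}

impl std::error::Error for RankError {}

/// Validates that `dims` has exactly `D` entries and returns them as a fixed-size shape.
/// A mismatch is reported as `Err` so the caller decides what to do, rather than panicking.
pub fn try_shape<const D: usize>(dims: &[usize]) -> Result<[usize; D], RankError> {
    if dims.len() != D {
        return Err(RankError { expected: D, given: dims.to_vec() });
    }
    let mut shape = [0usize; D];
    shape.copy_from_slice(dims);
    Ok(shape)
}

fn main() {
    // Conv1d weight shape [out_channels, in_channels, kernel_size] from the issue.
    let conv_weight: [usize; 3] = [4, 4, 3];
    // Reading it as a rank-2 (Linear) weight yields an Err instead of a panic,
    // which an untagged-enum decoder could treat as "try the next variant".
    match try_shape::<2>(&conv_weight) {
        Ok(shape) => println!("loaded shape {:?}", shape),
        Err(e) => println!("error: {}", e),
    }
}
```

Returning `Result` keeps the decision with the caller: a top-level load can still surface the error, while variant matching can use it to move on.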
Desktop:
OS: Ubuntu 22.04.4 LTS on Windows 10 x86_64
burn = { version = "0.14.0", features = ["tch", "wgpu"] }
burn-import = "0.14.0"
rustc 1.81.0 (eeb90cda1 2024-09-04)
> This suggests that the error occurs when trying to deserialize the weight of Conv1d (Tensor<B, 3>) into a Linear weight (Tensor<B, 2>) while attempting to match each variant in the enum.
The flexibility of sequential blocks has a couple of drawbacks, and you just ran into one 😅. In PyTorch you can insert basically any type of module you want, but replicating this behavior in Burn is a bit "hacky".
We use serde's untagged enum deserialization to try to match the correct module, but in this case both the linear and conv layers have the same fields (`weight` and `bias`), so they technically match.
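To illustrate why the `Linear` variant wins, here is a simplified, std-only model of untagged-enum matching. It is not serde's or Burn's code; it only reproduces the rule that serde's `#[serde(untagged)]` follows: variants are tried in declaration order and the first one that decodes without error is kept. `Variant`, `Record`, `decode_by_fields`, `decode_with_rank` and `resolve` are hypothetical names. Checking field names alone accepts the Conv1d weights as `Linear`; once the decoder also reports a rank mismatch as an error, matching falls through to `Conv1d`.

```rust
use std::collections::HashMap;

/// The two layer types from the issue's `SeqType` enum, in declaration order.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Variant {
    Linear,
    Conv1d,
}

impl Variant {
    /// Rank of the `weight` tensor each layer expects.
    fn weight_rank(self) -> usize {
        match self {
            Variant::Linear => 2,
            Variant::Conv1d => 3,
        }
    }
}

/// Untagged matching tries variants top to bottom and keeps the first success.
const ORDER: [Variant; 2] = [Variant::Linear, Variant::Conv1d];

/// Simplified record: field name -> tensor shape.
type Record = HashMap<&'static str, Vec<usize>>;

/// Decoder that only checks that the expected fields exist.
fn decode_by_fields(_v: Variant, rec: &Record) -> Result<(), String> {
    // Linear and Conv1d both have `weight` and `bias`, so this cannot tell them apart.
    if rec.contains_key("weight") && rec.contains_key("bias") {
        Ok(())
    } else {
        Err("missing field".to_string())
    }
}

/// Decoder that additionally reports a weight-rank mismatch as an error.
fn decode_with_rank(v: Variant, rec: &Record) -> Result<(), String> {
    decode_by_fields(v, rec)?;
    let rank = rec["weight"].len();
    if rank == v.weight_rank() {
        Ok(())
    } else {
        Err(format!("expected rank {}, got {}", v.weight_rank(), rank))
    }
}

/// Returns the first variant whose decoder succeeds.
fn resolve(rec: &Record, decode: fn(Variant, &Record) -> Result<(), String>) -> Option<Variant> {
    ORDER.iter().copied().find(|&v| decode(v, rec).is_ok())
}

fn main() {
    // Conv1d layer from the issue: weight [4, 4, 3], bias [4].
    let conv: Record = vec![("weight", vec![4, 4, 3]), ("bias", vec![4])]
        .into_iter()
        .collect();
    println!("{:?}", resolve(&conv, decode_by_fields)); // Some(Linear): wrong variant
    println!("{:?}", resolve(&conv, decode_with_rank)); // Some(Conv1d)
}
```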
> The solution would be to propagate the dimension mismatch error upwards instead of panicking.
I haven't looked at this part of the codebase in a while, so I'd have to check how this could be handled better (if possible).
In the meantime, you could declare a module that contains a fixed number of fields instead of having a sequential-like type. Then, the fields can simply be remapped. Something like this should work:
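```rust
use burn::module::Module;
use burn::nn::conv::Conv1d;
use burn::nn::Linear;
use burn::record::{FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

#[derive(Module, Debug)]
struct LinearConv<B: Backend> {
    linear: Linear<B>,
    conv: Conv1d<B>,
}

#[derive(Module, Debug)]
struct SeqModel<B: Backend> {
    seq: LinearConv<B>,
}

// `weights_file` is the path to the PyTorch checkpoint (e.g. "./test_seq.pt").
fn load_record<B: Backend>(weights_file: &str, device: &B::Device) -> SeqModelRecord<B> {
    // Remap the sequential fields to your new Burn module
    let load_args = LoadArgs::new(weights_file.into())
        // Map seq.0.* -> seq.linear.*
        .with_key_remap("seq\\.0\\.(.+)", "seq.linear.$1")
        // Map seq.1.* -> seq.conv.*
        .with_key_remap("seq\\.1\\.(.+)", "seq.conv.$1");
    PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(load_args, device)
        .expect("Failed when decoding state dict")
}
```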
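For reference, this is what the two remap rules are meant to do to the state-dict keys. The sketch below is a std-only approximation: it rewrites the `seq.0.` and `seq.1.` prefixes, whereas the real `with_key_remap` takes regex patterns (which, unanchored, could also match inside longer keys). `remap_key`, `RULES` and the example keys are assumptions, based on PyTorch naming `Sequential` children by their index.

```rust
/// Prefix rewrites approximating the regexes `seq\.0\.(.+)` -> `seq.linear.$1`
/// and `seq\.1\.(.+)` -> `seq.conv.$1` for top-level keys (non-empty remainder, like `.+`).
const RULES: [(&str, &str); 2] = [("seq.0.", "seq.linear."), ("seq.1.", "seq.conv.")];

/// Applies the first matching rule; keys that match no rule are kept unchanged.
fn remap_key(key: &str) -> String {
    for &(from, to) in RULES.iter() {
        if let Some(rest) = key.strip_prefix(from) {
            if !rest.is_empty() {
                return format!("{}{}", to, rest);
            }
        }
    }
    key.to_string()
}

fn main() {
    // Keys as PyTorch would store them for `seq = nn.Sequential(linear, conv)`.
    for key in ["seq.0.weight", "seq.0.bias", "seq.1.weight", "seq.1.bias"].iter() {
        println!("{} -> {}", key, remap_key(key));
    }
}
```

After remapping, each key names a field of `LinearConv` directly, so no enum variant has to be guessed.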
Note: I haven't tested the code, but it should work. We did something similar for the ResNet implementation.
Expected behavior
The state dict should decode into the model without panicking.
Debugging info
Inspecting the panic with rust-gdb shows that it happens while the Conv1d weight is being deserialized into the Linear variant, as described above.