
Error when deserializing Enum with mismatched weight tensor shapes across variants #2332

Open

med1844 opened this issue Oct 2, 2024 · 1 comment

Labels: bug (Something isn't working), enhancement (Enhance existing features)

med1844 (Contributor) commented Oct 2, 2024

Describe the bug

When attempting to deserialize a PyTorch model containing a Sequential layer whose sub-modules have weight tensors of different ranks (e.g., Linear with a rank-2 weight and Conv1d with a rank-3 weight) into a Burn model, the process panics with a dimension mismatch error.

To Reproduce

  1. Create the model in PyTorch and save its state dict:

import pathlib

import torch


def gen_seq_block_test(path: pathlib.Path):
    class SeqModel(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            self.seq = torch.nn.Sequential(
                torch.nn.Linear(8, 4),
                torch.nn.Conv1d(4, 4, 3),
            )

    model = SeqModel()
    torch.save(model.state_dict(), path / "test_seq.pt")
  2. Load it in Burn:

use burn::backend::LibTorch;
use burn::module::Module;
use burn::nn::conv::{Conv1d, Conv1dConfig};
use burn::nn::{Linear, LinearConfig};
use burn::record::{FullPrecisionSettings, Recorder};
use burn::tensor::backend::Backend;
use burn_import::pytorch::PyTorchFileRecorder;

#[derive(Module, Debug)]
enum SeqType<B: Backend> {
    Linear(Linear<B>),
    Conv1d(Conv1d<B>),
}

#[derive(Module, Debug)]
struct SeqModel<B: Backend> {
    seq: Vec<SeqType<B>>,
}

impl<B: Backend> SeqModel<B> {
    fn new(device: &B::Device) -> Self {
        let linear = LinearConfig::new(8, 4).init(device);
        let conv = Conv1dConfig::new(4, 4, 3).init(device);
        Self {
            seq: vec![SeqType::Linear(linear), SeqType::Conv1d(conv)],
        }
    }
}

fn main() {
    let device = Default::default();
    let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load("./test_seq.pt".into(), &device)
        .expect("Failed when decoding state dict");
    let model = SeqModel::<LibTorch>::new(&device);
    let model = model.load_record(record);
    println!("{:#?}", model);
}
  3. Observe the panic:

[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
thread 'main' panicked at /home/med/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-tensor-0.14.0/src/tensor/api/base.rs:722:9:
=== Tensor Operation Error ===
  Operation: 'From Data'
  Reason:
    1. Given dimensions differ from the tensor rank. Tensor rank: '2', given dimensions: '[4, 4, 3]'.

note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

Expected behavior

The record should deserialize successfully, with each enum variant matched to the module whose tensor shapes fit.

Debugging info

rust-gdb shows the following:

#12 0x00005555556a0ff8 in burn_import::pytorch::adapter::{impl#0}::adapt_linear<burn_core::record::settings::FullPrecisionSettings, burn_tch::backend::LibTorch<f32, i8>> (data=...) at /home/med/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-import-0.14.0/src/pytorch/adapter.rs:35
35              let weight: Param<Tensor<B, 2>> = weight
>>> list
30              let weight = map
31                  .remove("weight")
32                  .expect("Failed to find 'weight' key in map");
33
34              // Convert the weight parameter to a tensor (use default device, since it's quick operation).
35              let weight: Param<Tensor<B, 2>> = weight
36                  .try_into_record::<_, PS, DefaultAdapter, B>(&B::Device::default())
37                  .expect("Failed to deserialize weight");
38
39              // Do not capture transpose op when using autodiff backend
>>> p weight
$11 = burn_core::record::serde::data::NestedValue::Map(HashMap(size=2) = {
    ["\211\352\256\002P"] = burn_core::record::serde::data::NestedValue::Map(HashMap(size=3) = {
        ["\371D\256\002P"] = burn_core::record::serde::data::NestedValue::Map(HashMap(size=1) = {
            ["dloia"] = burn_core::record::serde::data::NestedValue::String(<incomplete sequence \356\256>)
          }),
        ["\334\352\256\002P"] = burn_core::record::serde::data::NestedValue::U8s(Vec(size=192) = {[0] = 63, [1] = 0, [2] = 143, [3] = 60, [4] = 157, [5] = 103, [6] = 36, [7] = 62, [8] = 187, [9] = 107, [10] = 73, [11] = 63, [12] = 142, [13] = 112, [14] = 216, [15] = 190, [16] = 178, [17] = 249, [18] = 39, [19] = 63, [20] = 177, [21] = 233, [22] = 175, [23] = 62, [24] = 25, [25] = 106, [26] = 11, [27] = 191, [28] = 33, [29] = 232, [30] = 53, [31] = 190, [32] = 253, [33] = 130, [34] = 53, [35] = 191, [36] = 215, [37] = 113, [38] = 229, [39] = 62, [40] = 151, [41] = 105, [42] = 204, [43] = 60, [44] = 232, [45] = 17, [46] = 145, [47] = 63, [48] = 221, [49] = 231, [50] = 22, [51] = 63, [52] = 207, [53] = 112, [54] = 94, [55] = 190, [56] = 109, [57] = 116, [58] = 49, [59] = 191, [60] = 182, [61] = 162, [62] = 36, [63] = 191, [64] = 205, [65] = 77, [66] = 214, [67] = 190, [68] = 193, [69] = 33, [70] = 236, [71] = 190, [72] = 67, [73] = 217, [74] = 103, [75] = 63, [76] = 34, [77] = 107, [78] = 164, [79] = 61, [80] = 4, [81] = 8, [82] = 60, [83] = 62, [84] = 230, [85] = 158, [86] = 179, [87] = 61, [88] = 241, [89] = 76, [90] = 49, [91] = 63, [92] = 183, [93] = 70, [94] = 100, [95] = 190, [96] = 189, [97] = 233, [98] = 25, [99] = 191, [100] = 47, [101] = 51, [102] = 181, [103] = 62, [104] = 161, [105] = 182, [106] = 9, [107] = 191, [108] = 170, [109] = 32, [110] = 137, [111] = 62, [112] = 115, [113] = 114, [114] = 22, [115] = 63, [116] = 23, [117] = 169, [118] = 143, [119] = 62, [120] = 151, [121] = 206, [122] = 103, [123] = 190, [124] = 60, [125] = 147, [126] = 197, [127] = 190, [128] = 102, [129] = 211, [130] = 148, [131] = 61, [132] = 187, [133] = 152, [134] = 236, [135] = 61, [136] = 118, [137] = 201, [138] = 141, [139] = 59, [140] = 99, [141] = 180, [142] = 145, [143] = 189, [144] = 208, [145] = 204, [146] = 201, [147] = 61, [148] = 12, [149] = 115, [150] = 18, [151] = 191, [152] = 171, [153] = 56, [154] = 45, [155] = 62, [156] = 105, [157] = 126, [158] = 89, [159] = 63, [160] = 5, [161] = 25, [162] = 180, [163] = 63, [164] = 12, [165] = 54, [166] = 56, [167] = 62, [168] = 103, [169] = 237, [170] = 55, [171] = 189, [172] = 161, [173] = 184, [174] = 106, [175] = 62, [176] = 36, [177] = 127, [178] = 144, [179] = 190,
[180] = 216, [181] = 52, [182] = 7, [183] = 63, [184] = 28, [185] = 13, [186] = 60, [187] = 191, [188] = 162, [189] = 113, [190] = 55, [191] = 63}),
        [" nK\340\377"] = burn_core::record::serde::data::NestedValue::Vec(Vec(size=3) = {[0] = <error reading variable: Could not find active enum variant>, [1] = burn_core::record::serde::data::NestedValue::U64(4), [2] = burn_core::record::serde::data::NestedValue::U64(3)})
      }),
    ["\200n"] = burn_core::record::serde::data::NestedValue::String("\t\354\256\002PU\000\000\237\023")
  })
>>> up 1
#13 0x00005555556a1b03 in burn_core::record::serde::adapter::BurnModuleAdapter::adapt<burn_import::pytorch::adapter::PyTorchAdapter<burn_core::record::settings::FullPrecisionSettings, burn_tch::backend::LibTorch<f32, i8>>> (name="Linear", data=...) at /home/med/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-core-0.14.0/src/record/serde/adapter.rs:20
20                  "Linear" => Self::adapt_linear(data),
>>> p data
$12 = burn_core::record::serde::data::NestedValue::Map(HashMap(size=2) = {
    ["bias"] = burn_core::record::serde::data::NestedValue::Map(HashMap(size=2) = {
        ["param"] = burn_core::record::serde::data::NestedValue::Map(HashMap(size=3) = {
            ["bytes"] = burn_core::record::serde::data::NestedValue::U8s(Vec(size=16) = {[0] = 66, [1] = 191, [2] = 117, [3] = 63, [4] = 255, [5] = 203, [6] = 230, [7] = 189, [8] = 249, [9] = 242, [10] = 46, [11] = 190, [12] = 143, [13] = 167, [14] = 27, [15] = 62}),
            ["shape"] = burn_core::record::serde::data::NestedValue::Vec(Vec(size=1) = {[0] = burn_core::record::serde::data::NestedValue::U64(4)}),
            ["dtype"] = burn_core::record::serde::data::NestedValue::Map(HashMap(size=1) = {
                ["DType"] = burn_core::record::serde::data::NestedValue::String("F32")
              })
          }),
        ["id"] = burn_core::record::serde::data::NestedValue::String("5e9vn1d4ks")
      })Traceback (most recent call last):
  File "/home/med/.rustup/toolchains/stable-x86_64-unknown-linux-gnu/lib/rustlib/etc/gdb_providers.py", line 447, in children
    idx = self._valid_indices[index]
IndexError: list index out of range

    ...
  })
>>> up 7
#20 0x00005555556e9453 in burn_core::record::serde::de::{impl#1}::deserialize_enum<burn_import::pytorch::adapter::PyTorchAdapter<burn_core::record::settings::FullPrecisionSettings, burn_tch::backend::LibTorch<f32, i8>>, fish_speech_rs::_::{impl#0}::deserialize::__Visitor<burn_tch::backend::LibTorch<f32, i8>, burn_core::record::settings::FullPrecisionSettings>> (self=..., _name="SeqTypeRecordItem", variants=&[&str](size=2) = {...}, visitor=...) at /home/med/.cargo/registry/src/index.crates.io-6f17d22bba15001f/burn-core-0.14.0/src/record/serde/de.rs:383
383                 let result = cloned_visitor.visit_enum(ProbeEnumAccess::<A>::new(
>>> list
378
379             // Try each variant in order
380             for &variant in variants {
381                 // clone visitor to avoid moving it
382                 let cloned_visitor = clone_unsafely(&visitor);
383                 let result = cloned_visitor.visit_enum(ProbeEnumAccess::<A>::new(
384                     self.value.clone().unwrap(),
385                     variant.to_owned(),
386                     self.default_for_missing_fields,
387                 ));
>>>

This suggests that the error occurs when trying to deserialize the weight of Conv1d (Tensor<B, 3>) into a Linear weight (Tensor<B, 2>) while attempting to match each variant in the enum.

The solution would be to propagate the dimension mismatch error upwards instead of panicking, so the variant-probing loop in deserialize_enum can fall through to the next candidate.
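
For illustration, here is a minimal, self-contained sketch of that behavior (plain Rust with made-up types, not Burn's actual record API): rank validation returns a Result instead of panicking, so probing a mismatched variant fails gracefully and the loop moves on.

// Sketch only: ShapeError, try_as_rank, and deserialize_variant are
// hypothetical stand-ins for Burn's internal deserialization machinery.

#[derive(Debug)]
struct ShapeError(String);

// Attempt to view a dynamically-shaped tensor as one with a fixed rank D.
fn try_as_rank<const D: usize>(shape: &[usize]) -> Result<[usize; D], ShapeError> {
    shape
        .try_into()
        .map_err(|_| ShapeError(format!("expected rank {}, got {}", D, shape.len())))
}

#[derive(Debug)]
enum SeqType {
    Linear { weight_shape: [usize; 2] }, // rank-2 weight
    Conv1d { weight_shape: [usize; 3] }, // rank-3 weight
}

// Probe variants in order; a rank mismatch is an Err, not a panic,
// so deserialization can fall through to the next variant.
fn deserialize_variant(shape: &[usize]) -> Result<SeqType, ShapeError> {
    if let Ok(s) = try_as_rank::<2>(shape) {
        return Ok(SeqType::Linear { weight_shape: s });
    }
    if let Ok(s) = try_as_rank::<3>(shape) {
        return Ok(SeqType::Conv1d { weight_shape: s });
    }
    Err(ShapeError(format!("no variant matches shape {:?}", shape)))
}

fn main() {
    println!("{:?}", deserialize_variant(&[4, 8]));    // Ok(Linear ..)
    println!("{:?}", deserialize_variant(&[4, 4, 3])); // Ok(Conv1d ..)
}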

Desktop (please complete the following information):

  • OS: Ubuntu 22.04.4 LTS on Windows 10 x86_64
  • burn = { version = "0.14.0", features = ["tch", "wgpu"] }
  • burn-import = "0.14.0"
  • rustc 1.81.0 (eeb90cda1 2024-09-04)
laggui added the bug (Something isn't working) and enhancement (Enhance existing features) labels on Oct 2, 2024
laggui (Member) commented Oct 2, 2024

This suggests that the error occurs when trying to deserialize the weight of Conv1d (Tensor<B, 3>) into a Linear weight (Tensor<B, 2>) while attempting to match each variant in the enum.

The flexibility of sequential blocks has a couple of drawbacks, and you just faced one 😅 In PyTorch you can insert basically any type of module, but replicating this behavior in Burn is a bit "hacky".

We use serde's untagged enum deserialization to try to match the correct module, but in this case both the linear and conv layers have the same field names ("weight" and "bias"), so they technically both match.
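
As a minimal illustration of the ambiguity (using plain serde with serde_json and made-up field types, not Burn's record machinery): an untagged enum picks the first variant whose fields match, so identically-named fields cannot be told apart.

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Layer {
    // Both variants have identical field names, so the data alone
    // cannot disambiguate them.
    Linear { weight: Vec<f32>, bias: Vec<f32> },
    Conv1d { weight: Vec<f32>, bias: Vec<f32> },
}

fn main() {
    // This always deserializes as Linear (the first declared variant),
    // even if the values actually came from a Conv1d layer.
    let json = r#"{ "weight": [0.1, 0.2], "bias": [0.0] }"#;
    let layer: Layer = serde_json::from_str(json).unwrap();
    println!("{:?}", layer);
}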

The solution would be to propagate the dimension mismatch error upwards instead of panicking, so the variant-probing loop in deserialize_enum can fall through to the next candidate.

I haven't looked at this part of the codebase in a while, so I'd have to check how this could be handled better (if possible).

In the meantime, you could declare a module that contains a fixed number of fields instead of having a sequential-like type. Then, the fields can simply be remapped. Something like this should work:

#[derive(Module, Debug)]
struct LinearConv<B: Backend> {
    linear: Linear<B>,
    conv: Conv1d<B>,
}

#[derive(Module, Debug)]
struct SeqModel<B: Backend> {
    seq: LinearConv<B>,
}

// Remap the sequential fields to your new Burn module
let load_args = LoadArgs::new(weights_file)
    // Map seq.0.* -> seq.linear.*
    .with_key_remap("seq\\.0\\.(.+)", "seq.linear.$1")
    // Map seq.1.* -> seq.conv.*
    .with_key_remap("seq\\.1\\.(.+)", "seq.conv.$1");
let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
    .load(load_args, &device)
    .expect("Failed when decoding state dict");

Note: I haven't tested the code, but it should work. We did something similar for the ResNet implementation.
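
For completeness, a forward pass for the remapped module could look like the sketch below. This is an untested assumption about the intended data flow (it mirrors the original Sequential(Linear, Conv1d) and assumes a [batch, length, 8] input plus use burn::tensor::Tensor in scope); adjust the shapes to your actual layout.

impl<B: Backend> LinearConv<B> {
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let x = self.linear.forward(x); // [batch, length, 4]
        // Conv1d expects [batch, channels, length], so swap the last two dims.
        self.conv.forward(x.swap_dims(1, 2))
    }
}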
