Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tweak torch parameter registration mechanism #19908

Open
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

haohuanw
Copy link
Contributor

@haohuanw haohuanw commented Jun 23, 2024

this is a follow up from #19885 discussion where i am trying to make torch / keras well played together on tracking parameters.

the solution i ended up with:

  1. since modules are properly tracked with torch module, every torch_params will only safe it's own variables. nested variable resolution will be done by torch with recurse=True
  2. change back to use parameter list instead of dict. i did consider to keep using dict given the readability since now key in torch param could actually be variable.name with just tracking variables the current layer holds. however, current seed generator actually create duplicated variable names. if https://github.com/keras-team/keras/blob/master/keras/src/random/seed_generator.py#L80 can be changed to something like f"{self.name}_generator_state" it will work with ParameterDict approach.
  3. in _post_track/untrack_variables, refresh the entire torch params and it's sublayers. this could be changed to not re-create all sublayers if this function ever becomes too slow.

i also added few torch specific tests to reflect some of the assumptions and usecases that torch user might have. eg. use state_dict.

@codecov-commenter
Copy link

codecov-commenter commented Jun 23, 2024

Codecov Report

Attention: Patch coverage is 79.36508% with 13 lines in your changes missing coverage. Please review.

Project coverage is 79.27%. Comparing base (a226835) to head (37ab7f4).
Report is 175 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/torch/trainer.py 43.75% 3 Missing and 6 partials ⚠️
keras/src/backend/torch/layer.py 90.24% 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #19908   +/-   ##
=======================================
  Coverage   79.27%   79.27%           
=======================================
  Files         501      501           
  Lines       46921    46979   +58     
  Branches     8648     8666   +18     
=======================================
+ Hits        37195    37241   +46     
- Misses       7981     7985    +4     
- Partials     1745     1753    +8     
Flag Coverage Δ
keras 79.12% <79.36%> (-0.01%) ⬇️
keras-jax 62.39% <6.34%> (-0.07%) ⬇️
keras-numpy 57.45% <3.17%> (-0.07%) ⬇️
keras-tensorflow 63.77% <3.17%> (-0.04%) ⬇️
keras-torch 62.54% <79.36%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@haohuanw
Copy link
Contributor Author

the failing pytorch test is actually passing on my env:

(keras-dev-minimum) haohuanw@haohuanw-ThinkPad-X1-Extreme:~/Documents/keras$ KERAS_BACKEND=torch python integration_tests/numerical_test.py 
2024-06-23 16:13:12.028332: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-23 16:13:12.031879: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-23 16:13:12.080855: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-06-23 16:13:12.900432: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-06-23 16:13:14.305362: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-06-23 16:13:14.305867: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
Checking training histories:
accuracy:
[0.20999999344348907]
[0.20999999344348907]
loss:
[2.6727349758148193]
[2.6727302074432373]
mae:
[0.1606517732143402]
[0.16065318882465363]
Training histories match.

Checking trained weights:
Trained weights match.

Checking predict:
Predict results match.

Checking evaluate:
[2.2113966941833496, 0.17798443138599396, 0.3799999952316284]
[2.2114176750183105, 0.17798538506031036, 0.3799999952316284]
Evaluate results match.

@fchollet
Copy link
Member

if https://github.com/keras-team/keras/blob/master/keras/src/random/seed_generator.py#L80 can be changed to something like f"{self.name}_generator_state" it will work with ParameterDict approach.

The uniqueness of the variable path should come from the parent object name, not from the variable name (e.g. "dense_1/kernel"). What paths do you currently see for seed generators?

@haohuanw
Copy link
Contributor Author

if https://github.com/keras-team/keras/blob/master/keras/src/random/seed_generator.py#L80 can be changed to something like f"{self.name}_generator_state" it will work with ParameterDict approach.

The uniqueness of the variable path should come from the parent object name, not from the variable name (e.g. "dense_1/kernel"). What paths do you currently see for seed generators?

for seed generator if using path it will be seed_generator_{idx}/seed_generator_state, using name it will be the seed_generator_state. the spirit of this change is to let torch module to handle the recursive collection so i was planning to use variable name but find out that there are collisions on seed generator state.

@fchollet
Copy link
Member

the spirit of this change is to let torch module to handle the recursive collection so i was planning to use variable name but find out that there are collisions on seed generator state.

Variable names are never unique. For a unique string you can use variable.path.

@haohuanw
Copy link
Contributor Author

haohuanw commented Jun 24, 2024

the spirit of this change is to let torch module to handle the recursive collection so i was planning to use variable name but find out that there are collisions on seed generator state.

Variable names are never unique. For a unique string you can use variable.path.

i thought under same layer the variable name (excluding the variables from its sub layers) should be unique as an implicit requirement since otherwise variable.path would also be non unique?

my original thought is that self.torch_params will only have variables for the layer, excluding variables in the sub-layers since it will automatically get collected when calling named_parameters() since all sub layers are properly recognized as a sub torch module and it will respect recurse option. (eg. https://github.com/pytorch/pytorch/blob/662e9e10766b040bea000e18e54a4f9e69889fc1/torch/nn/modules/module.py#L2496C20-L2496C34 _named_members will include all registered sub layers.)

then i found out that all seed_generator actually can actually create with same variable name if there are multiple seed generator in one layer since seed generator is not a layer.

@haohuanw
Copy link
Contributor Author

the spirit of this change is to let torch module to handle the recursive collection so i was planning to use variable name but find out that there are collisions on seed generator state.

Variable names are never unique. For a unique string you can use variable.path.

i thought under same layer the variable name (excluding the variables from its sub layers) should be unique as an implicit requirement since otherwise variable.path would also be non unique?

my original thought is that self.torch_params will only have variables for the layer, excluding variables in the sub-layers since it will automatically get collected when calling named_parameters() since all sub layers are properly recognized as a sub torch module and it will respect recurse option. (eg. https://github.com/pytorch/pytorch/blob/662e9e10766b040bea000e18e54a4f9e69889fc1/torch/nn/modules/module.py#L2496C20-L2496C34 _named_members will include all registered sub layers.)

then i found out that all seed_generator actually can actually create with same variable name if there are multiple seed generator in one layer since seed generator is not a layer.

and i do notice that i probably want to add a test with nested seed generator. in theory, seed states should be recursively collected by torch since it basically get all module._parameters for all its submodules.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code looks good -- thanks for the changes. I will apply docstring fixes after merging.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jun 24, 2024
@fchollet
Copy link
Member

the failing pytorch test is actually passing on my env:

Works for me locally as well. Might be a fluke.

@fchollet
Copy link
Member

There are actually various tests that reliably fail here: https://btx.cloud.google.com/invocations/c55a2ca4-5df3-411b-bd52-7c9873e839ce/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Ftorch%2Fpresubmit/log (not the numerical integration test)

@haohuanw
Copy link
Contributor Author

haohuanw commented Jun 24, 2024

There are actually various tests that reliably fail here: https://btx.cloud.google.com/invocations/c55a2ca4-5df3-411b-bd52-7c9873e839ce/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Ftorch%2Fpresubmit/log (not the numerical integration test)

i will address those today / tmr 👍 - and is it possible to configure ci to run pytest regardless whether integration test passes or not?

@fchollet
Copy link
Member

is it possible to configure ci to run pytest regardless whether integration test passes or not?

We'd have to move the integration testing to go after the general pytest command in .github/workflows/actions.yml (job name Run tests)

@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Jun 25, 2024
@haohuanw
Copy link
Contributor Author

i am seeing a weird issue on keras/src/dtype_policies/dtype_policy_map_test.py::DTypePolicyMapTest::test_basic_usage with error like this:


error_msgs = {133475339071600: (<Dense name=subclass_dense, built=True>, ValueError("Layer 'subclass_dense' expected 4 variables, but received 3 variables during loading. Expected: ['bias', 'kernel', 'kernel', 'kernel_scale']"))}
warn_only = False

    def _raise_loading_failure(error_msgs, warn_only=False):
        first_key = list(error_msgs.keys())[0]
        ex_saveable, ex_error = error_msgs[first_key]
        msg = (
            f"A total of {len(error_msgs)} objects could not "
            "be loaded. Example error message for "
            f"object {ex_saveable}:\n\n"
            f"{ex_error}\n\n"
            "List of objects that could not be loaded:\n"
            f"{[x[0] for x in error_msgs.values()]}"
        )
        if warn_only:
            warnings.warn(msg)
        else:
>           raise ValueError(msg)
E           ValueError: A total of 1 objects could not be loaded. Example error message for object <Dense name=subclass_dense, built=True>:
E           
E           Layer 'subclass_dense' expected 4 variables, but received 3 variables during loading. Expected: ['bias', 'kernel', 'kernel', 'kernel_scale']
E           
E           List of objects that could not be loaded:
E           [<Dense name=subclass_dense, built=True>]

i am able to isolate that the model json looks good but restored model here: https://github.com/keras-team/keras/blob/master/keras/src/saving/saving_lib.py#L242 have duplicated <KerasVariable shape=(4, 8), dtype=int8, path=subclass/subclass_dense/kernel>. any idea how this change is relevant to this change?

@fchollet
Copy link
Member

I don't understand the connection. You could try pruning things from your change until the test passes, then you'll have a good idea what particular lines are causing the issue.

@haohuanw
Copy link
Contributor Author

@fchollet most of the unit tests are fixed with one issue left. torch basically requires user to use nn.ModuleList to wrap any modules, so for keras3 any modules that directly passing with a list won't work.

there are two options that i think could work, let me know your thoughts:

  1. create keras.layers.LayerList that mimics what nn.ModuleList has so all modules could be properly tracked by torch modules. a check in setattr could be added to automatically wrap with LayerList when a list[Layer] is passed.
  2. specific in torch backend, when a list[Layer] is observed, wrap it in nn.ModuleList() and then wrap it with TorchModuleWrapper(). this might work but i will need to double check the parameter tracking logic.
  3. (doesn't seem to work) specific in torch backend, when a list[Layer] is observed, also call self.register_module() in setattr_hook to double register the layer. i tried this and it works in most of the cases except serialization since setattr_hook is not called during deserialization.

let me know what do you think.

@fchollet fchollet added the keras-team-review-pending Pending review by a Keras team member. label Jun 27, 2024
@fchollet
Copy link
Member

specific in torch backend, when a list[Layer] is observed, wrap it in nn.ModuleList() and then wrap it with TorchModuleWrapper(). this might work but i will need to double check the parameter tracking logic.

I think we could do this, via __setattr__ in TorchLayer. There should not be any downsides?

@haohuanw
Copy link
Contributor Author

specific in torch backend, when a list[Layer] is observed, wrap it in nn.ModuleList() and then wrap it with TorchModuleWrapper(). this might work but i will need to double check the parameter tracking logic.

I think we could do this, via __setattr__ in TorchLayer. There should not be any downsides?

in theory - let me try it

@haohuanw
Copy link
Contributor Author

specific in torch backend, when a list[Layer] is observed, wrap it in nn.ModuleList() and then wrap it with TorchModuleWrapper(). this might work but i will need to double check the parameter tracking logic.

I think we could do this, via __setattr__ in TorchLayer. There should not be any downsides?

it technically works but i think this will be a pretty impact workflow change for pytorch users:

  1. every time referencing a layer list (for example run forward or quantize those layers), it needs to change from:
    for l in self.layers to for l in self.layers.module which is not really ideal and only specific for torch.

  2. another issue is when re-tracking the parameters since currently the idea is to have every layer only track it's own layers by doing it recursively. additional logic needs to be added to handle the special case where keras layer is wrapped into the torch wrapper.

I think supporting a keras.LayerList is actually a cleaner approach (not sure if it introduces any challenges for serialization) to better support pytorch backend without impacting much on tf/jax side. i think what we can do is to make this an opt-in feature where we warn users in TorchLayer that they have to use keras.LayerList to make sure torch params are being properly tracked but other backend users don't need to worry about using it.

@haohuanw
Copy link
Contributor Author

haohuanw commented Jul 1, 2024

^ @fchollet let me know if above proposal sgty or you have a better idea.

@fchollet
Copy link
Member

fchollet commented Jul 1, 2024

I would recommend not adding any new API.

What about this approach:

(doesn't seem to work) specific in torch backend, when a list[Layer] is observed, also call self.register_module() in setattr_hook to double register the layer. i tried this and it works in most of the cases except serialization since setattr_hook is not called during deserialization.

Could it be fixed? What's the issue with deserialization?

@haohuanw
Copy link
Contributor Author

haohuanw commented Jul 2, 2024

I would recommend not adding any new API.

What about this approach:

(doesn't seem to work) specific in torch backend, when a list[Layer] is observed, also call self.register_module() in setattr_hook to double register the layer. i tried this and it works in most of the cases except serialization since setattr_hook is not called during deserialization.

Could it be fixed? What's the issue with deserialization?

@fchollet error message:

================================================================================================== FAILURES ===================================================================================================
______________________________________________________________________________________ SequentialTest.test_serialization ______________________________________________________________________________________

self = <keras.src.models.sequential_test.SequentialTest testMethod=test_serialization>

    def test_serialization(self):
        # Unbuilt deferred
        model = Sequential(name="seq")
        model.add(layers.Dense(4))
        model.add(layers.Dense(5))
        revived = self.run_class_serialization_test(model)
        self.assertLen(revived.layers, 2)
    
        # Built deferred
        model.build((2, 3))
>       revived = self.run_class_serialization_test(model)

keras/src/models/sequential_test.py:231: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras/src/testing/test_case.py:133: in run_class_serialization_test
    self.assertEqual(set(ref_dir), set(dir(revived_instance)))
E   AssertionError: Items in the first set but not the second:
E   '_layers_0'
E   '_layers_2'
E   '_layers_1'

i think the main issue to me is that deserialization don't triggers _setattr_hook (which means __setattr__ won't be triggered) and causes the torch states won't be registered properly. i haven't dig deeper into the actual error, let me know if this is actually something that should be easy to fix and i should dig a little deeper.

@fchollet
Copy link
Member

fchollet commented Jul 2, 2024

E AssertionError: Items in the first set but not the second:

What are those _layer_n attributes?

i think the main issue to me is that deserialization don't triggers _setattr_hook

Why so? I would expect deserialization to trigger the same code path as regular instantiation.

@haohuanw
Copy link
Contributor Author

haohuanw commented Jul 2, 2024

E AssertionError: Items in the first set but not the second:

What are those _layer_n attributes?

i think the main issue to me is that deserialization don't triggers _setattr_hook

Why so? I would expect deserialization to trigger the same code path as regular instantiation.

the _layer_n attribute is done by self.register_module in torch. dir(module) basically loops through attributes.

on deserialization trigger: the _layer_n is the ones created by _setattr_hook. if it is not created that most likely means the function is not triggered. but it sounds like this is not expected i will dig deeper on why this is not working.

@divyashreepathihalli divyashreepathihalli removed the keras-team-review-pending Pending review by a Keras team member. label Jul 11, 2024
@haohuanw
Copy link
Contributor Author

@fchollet sorry for the delay on this pr. i am able to get the final unit test fixed, please TAL.

@james77777778 fyi, since you had couple prs around torch variable tracking. please also take a look and let me know if this will break any of your current use cases.

@fchollet
Copy link
Member

Thanks for the update! There seems to be one failing test: https://github.com/keras-team/keras/actions/runs/10128018370/job/28006468053?pr=19908

@james77777778
Copy link
Contributor

@james77777778 fyi, since you had couple prs around torch variable tracking. please also take a look and let me know if this will break any of your current use cases.

The changes are significant, so it is difficult for me to say if this PR will break anything.
However, if all tests pass, it should be fine.

@haohuanw
Copy link
Contributor Author

@fchollet should be ready to go

Copy link
Contributor

@james77777778 james77777778 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have taken a look since I was mentioned in this PR.

keras/src/layers/attention/grouped_query_attention_test.py Outdated Show resolved Hide resolved
@@ -115,7 +115,8 @@ def add(self, layer, rebuild=True):
f"add a different Input layer to it."
)

self._layers.append(layer)
# append will not trigger __setattr__ for tracking purpose.
self._layers = self._layers + [layer]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this indicate that the tracking will fail if we add a stateful object without __setattr__?
Or is this a special case for Sequential model?

Copy link
Contributor Author

@haohuanw haohuanw Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is like keras thing: basically auto tracking relies on __setattr__ being triggered. if __setattr__ is not triggered then the variables / layers are actually not being tracked in auto tracker. When __setattr__ is not properly triggered, all layers can still properly run but anything relies on autotracking being properly worked will break. (sorry my understanding here is wrong, append should be well supported if just in keras)

this is a slightly bigger thing in torch since torch relies on torch.ModuleList instead of plain dictionary for any collection of layers. with the keras3 logic (where a hook is added in torch layer when __setattr__ is triggered to register modules) it is very important to ensure __setattr__ is properly triggered.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is still unclear to me why this PR breaks keras/src/models/sequential_test.py::SequentialTest::test_serialization when using self._layers.append(layer).

Shouldn't we have tracked these layers in _maybe_rebuild? It doesn't break in master branch.
I am worried that this discrepancy might break some models that use python containers (e.g. list and dict) for layer creation. Or is it already broken this way in the torch backend?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is it already broken this way in the torch backend?

yes - this is already broken on torch backend. you can try running below snippet in main branch v.s. the change

class StateAddingWrong(keras.layers.Layer):

    def build(self, _):
        self.layers = []
        self.layers.append(keras.layers.Dense(4, name="0"))
        self.layers.append(keras.layers.Dense(5, name="1"))
        self.layers.append(keras.layers.Dense(6, name="2"))

    def call(self, x):
        y = x
        for l in self.layers:
            y = l(y)
        return y


class StateAddingCorrect(keras.layers.Layer):

    def build(self, _):
        layers = []
        layers.append(keras.layers.Dense(4, name="0"))
        layers.append(keras.layers.Dense(5, name="1"))
        layers.append(keras.layers.Dense(6, name="2"))
        self.layers = layers

    def call(self, x):
        y = x
        for l in self.layers:
            y = l(y)
        return y


layer = StateAddingWrong(name='wrong')
layer(np.ones((2, 3)))
print(f"wrong tracked layers {len(layer._tracker.stored_ids['layers'])}")
print(f"wrong ", list(layer.named_modules()))

layer = StateAddingCorrect(name="correct")
layer(np.ones((2, 3)))
print(f"correct tracked layers {len(layer._tracker.stored_ids['layers'])}")
print(f"correct ", list(layer.named_modules()))

basically the named_modules() won't return layers appends to layers. my change fixes the case by using the pattern of not appending to layers. the most correct way to tackle this might be handling callbacks here: https://github.com/keras-team/keras/blob/master/keras/src/utils/tracking.py#L144-L147 @fchollet let me know if you have better ideas that could also sync the appending cases for torch backend.

It is still unclear to me why this PR breaks keras/src/models/sequential_test.py::SequentialTest::test_serialization when using self._layers.append(layer)

so a little bit deeper rca:

the main error is on _layers_{0, 1, 2} attributes missing in deserialization test and i have found that the main issue is due to build & deserialization order.

here is the minimum reproducible test:

model = keras.Sequential(name="seq")
model.add(keras.layers.Dense(4, name="0"))
model.add(keras.layers.Dense(5, name="1"))
model.add(keras.layers.Dense(6, name="2"))
model.build((2, 3))

for m in model.named_modules():
    print(f"module {m}")
print("_layers_0" in dir(model))

config = model.get_config()
config_json = json.dumps(config, sort_keys=True, indent=4)
revived_model = keras.Sequential.from_config(config)

for m in model.named_modules():
    print(f"module {m}")
print("_layers_1" in dir(revived_model))

during build phase, self._layers is not set till model.build is called since https://github.com/keras-team/keras/blob/master/keras/src/models/sequential.py#L137-L145 conditions are not met. Thus, the layers will actually being called with all layers in the model with model.build(shape) call here: https://github.com/keras-team/keras/blob/master/keras/src/models/sequential.py#L179-L182. during deserialization phase, InputLayer is actually available so every add is rebuilding the model without resetting layers. the fix i have essentially forced resetting every time.

functionality wise all things are good since Functional in Sequential is actually the model being invoked and it has layers reset all the time here: https://github.com/keras-team/keras/blob/master/keras/src/models/functional.py#L140

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the detailed explanation!

If this is the case, would it be more reasonable to put the logic in _maybe_build, since we might want to track these layers in both add and pop?

Actually, I think your solution of adding the logic in track.py is the best, but it might add some backend-specific code to it. This should wait for @fchollet 's call.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the case, would it be more reasonable to put the logic in _maybe_build, since we might want to track these layers in both add and pop?

Agreed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fchollet what's your thought on adding it to tracking.py? putting logics in _maybe_build can fix sequential specific issue but append/pop still won't be tracked by torch in current integration unless user triggers a __setattr__.

Options are:

  1. fixing this by adding some backend specific callbacks in tracking.py
  2. add a validation step in torch layer and warn user there are layers that not properly registered in torch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ @fchollet let me know if you have any thoughts on above

self._track_variables()

def _track_variables(self):
"""Adaptation layer to make sure keras.layers.Layer works well with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use backticks around all code keywords in docstrings.

torch.nn.Module. Currently, the main modification are on parameter/module
tracking and pointing torch.nn.Module.forward() to the right keras call.

Module tracking:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a section outside of the normal set of sections to be highlighted, use markdown syntax, e.g.

**Module tracking:**

(content)

tracking and pointing torch.nn.Module.forward() to the right keras call.

Module tracking:
All sublayers are tracked as modules in Module._modules. All module level
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Text content does not need to be indented.

non trainable and seed generator states belongs to the current layer.

Few additional points that user should be aware of:
1. When torch backend is enabled KerasVariable.value is torch.nn.Parameter,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure to introduce a line break (blank line) before a list or after a section title.

# as symbolic build is needed to make sure all layers variables are
# initialized before invoke torch.compile(). _symbolic_build()
# should refactored to not require _compile_metrics and optimizer
# is defined.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_symbolic_build() should refactored to not require _compile_metrics and optimizer is defined.

Can you explain?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_symbolic_build requires https://github.com/keras-team/keras/blob/master/keras/src/trainers/trainer.py#L1005-L1014 and these attributes are added during model.compile.

in this change, for torch params to work properly the requirements is that all layers are built before making any torch specific api calls. the easiest way to achieve this inside the trainer is to ensure _symbolic_build is called for all cases that jit_compile is true since torch.compile will call named_parameters() to compile the graph.

during unit test, i have found few unit tests breaks due to not calling torch.compile but requires calling named_parameters and it fails during _symbolic_build due to missing compile_metrics, compile_loss and optimizer attribute. i can try remove these and fixing the unit test instead.

@@ -115,7 +115,8 @@ def add(self, layer, rebuild=True):
f"add a different Input layer to it."
)

self._layers.append(layer)
# append will not trigger __setattr__ for tracking purpose.
self._layers = self._layers + [layer]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the case, would it be more reasonable to put the logic in _maybe_build, since we might want to track these layers in both add and pop?

Agreed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Assigned Reviewer
Development

Successfully merging this pull request may close these issues.

7 participants