How to define and train a model with multiple forward methods? #3456
Replies: 1 comment
-
Disclaimer: I'm not a Flax developer, but I've been using it for about a year.

For `apply_fn`: per the docs, you don't have to use TrainState's `apply_fn`. It is only there for convenience, and you can simply set it to `None`. In general, when you want to use one of the encoders as your forward method, call `apply` on your instantiated Flax module and pass the forward method you need through the `method` argument, like so: `Model().apply({"params": params}, inputs, method=Model.e1)`

For the `__call__` error: during initialization, I assume you want to obtain the parameters for all encoders wrapped by the `Model` class, so it is convenient to define `__call__` to run all of them. When you run `init` with any given method (e.g., …

Misc.: …
-
Here is a simple case: one model has multiple forward methods:
class model(nn.Module):
Here is the problem: the model doesn't define `__call__`, so it raises an error. Also, when training the model, TrainState only has one `apply_fn`.