diff --git a/graphcast/autoregressive.py b/graphcast/autoregressive.py index 1cf1324..9330463 100644 --- a/graphcast/autoregressive.py +++ b/graphcast/autoregressive.py @@ -112,7 +112,7 @@ def _validate_targets_and_forcings(self, targets, forcings): f'forcings, which isn\'t allowed: {overlap}') def _update_inputs(self, inputs, next_frame): - num_inputs = inputs.dims['time'] + num_inputs = inputs.sizes['time'] predicted_or_forced_inputs = next_frame[list(inputs.keys())] @@ -199,7 +199,7 @@ def one_step_prediction(inputs, scan_variables): return next_inputs, flat_pred if self._gradient_checkpointing: - scan_length = targets_template.dims['time'] + scan_length = targets_template.sizes['time'] if scan_length <= 1: logging.warning( 'Skipping gradient checkpointing for sequence length of 1') diff --git a/graphcast/rollout.py b/graphcast/rollout.py index b243d0f..6797897 100644 --- a/graphcast/rollout.py +++ b/graphcast/rollout.py @@ -124,7 +124,7 @@ def chunked_prediction_generator( if "datetime" in forcings.coords: del forcings.coords["datetime"] - num_target_steps = targets_template.dims["time"] + num_target_steps = targets_template.sizes["time"] num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk) if remainder != 0: raise ValueError( @@ -202,7 +202,7 @@ def _get_next_inputs( next_inputs = next_frame[next_inputs_keys] # Apply concatenate next frame with inputs, crop what we don't need. - num_inputs = prev_inputs.dims["time"] + num_inputs = prev_inputs.sizes["time"] return ( xarray.concat( [prev_inputs, next_inputs], dim="time", data_vars="different")