Skip to content

Commit

Permalink
Cleaned up notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
dead-water committed May 3, 2024
1 parent a70a9ec commit 863c29c
Showing 1 changed file with 1 addition and 111 deletions.
112 changes: 1 addition & 111 deletions notebooks/pretrain_mae_tensorflow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,95 +48,6 @@
"json_file.close()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(177481, 9)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"aligndata.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class FrameGenerator:\n",
" def __init__(self, zarr_file, aligndata, normalisations, n_frames, training = False, inst = None):\n",
" \"\"\" Returns a set of frames with their associated label. \n",
"\n",
" Args:\n",
" path: Video file paths.\n",
" n_frames: Number of frames. \n",
" training: Boolean to determine if training dataset is being created.\n",
" \"\"\"\n",
" self.inst = inst\n",
" self.zarr_file = zarr_file\n",
" self.aligndata = aligndata\n",
" self.normalizations = normalisations\n",
" self.n_frames = n_frames\n",
" self.training = training\n",
" self.idx = 0\n",
" # self.class_names = sorted(set(p.name for p in self.path.iterdir() if p.is_dir()))\n",
" # self.class_ids_for_name = dict((name, idx) for idx, name in enumerate(self.class_names))\n",
"\n",
" def get_aia_frames(self, idx):\n",
" \"\"\"Get AIA image for a given index.\n",
" Returns a numpy array of shape (num_wavelengths, num_frames, height, width).\n",
" \"\"\"\n",
" aia_image_dict = {}\n",
" for wavelength in self.wavelengths:\n",
" aia_image_dict[wavelength] = []\n",
" for frame in range(self.num_frames):\n",
" idx_row_element = self.aligndata.iloc[idx + frame]\n",
" idx_wavelength = idx_row_element[f\"idx_{wavelength}\"]\n",
" year = str(idx_row_element.name.year)\n",
" img = self.zarr_file[year][wavelength][idx_wavelength, :, :]\n",
"\n",
" if self.mask is not None:\n",
" img = img * self.mask\n",
"\n",
" aia_image_dict[wavelength].append(img)\n",
"\n",
" if self.normalizations:\n",
" aia_image_dict[wavelength][-1] -= self.normalizations[\"AIA\"][\n",
" wavelength\n",
" ][\"mean\"]\n",
" aia_image_dict[wavelength][-1] /= self.normalizations[\"AIA\"][\n",
" wavelength\n",
" ][\"std\"]\n",
"\n",
" aia_image = np.array(list(aia_image_dict.values()))\n",
"\n",
" return aia_image\n",
"\n",
" def __call__(self):\n",
" # video_paths, classes = self.get_files_and_class_names()\n",
"\n",
" # pairs = list(zip(video_paths, classes))\n",
"\n",
" # if self.training:\n",
" # random.shuffle(pairs)\n",
"\n",
" for idx in range(self.aligndata.shape[0]):\n",
" # video_frames = frames_from_video_file(path, self.n_frames) \n",
" # label = self.class_ids_for_name[name] # Encode labels \n",
" yield self.get_aia_frames(idx)"
]
},
{
"cell_type": "code",
"execution_count": 16,
Expand Down Expand Up @@ -221,27 +132,6 @@
"aiadl = AIALoader(aia_data, aligndata, normalisations, 1, 2)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'tuple' object has no attribute 'shape'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/walsh/repos/SDO-FM/notebooks/pretrain_mae_tensorflow.ipynb Cell 8\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> <a href='vscode-notebook-cell://ssh-remote%2Bsdofm-workbench-n1-16cpu-60ram-t4x2.us-central1-a.sdo-fm-2024/home/walsh/repos/SDO-FM/notebooks/pretrain_mae_tensorflow.ipynb#X10sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m aiadl\u001b[39m.\u001b[39;49m\u001b[39m__getitem__\u001b[39;49m(\u001b[39m0\u001b[39;49m)\u001b[39m.\u001b[39;49mshape\n",
"\u001b[0;31mAttributeError\u001b[0m: 'tuple' object has no attribute 'shape'"
]
}
],
"source": [
"aiadl.__getitem__(0).shape"
]
},
{
"cell_type": "code",
"execution_count": 22,
Expand Down Expand Up @@ -288,7 +178,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
" 53/88740 [..............................] - ETA: 42:27:25 - loss: 2.0440"
" 118/88740 [..............................] - ETA: 42:51:56 - loss: 2.0085"
]
}
],
Expand Down

0 comments on commit 863c29c

Please sign in to comment.