From 57e7435acb10d829989cbe73be40c5979b1d4ff7 Mon Sep 17 00:00:00 2001 From: dead-water Date: Mon, 6 May 2024 11:48:28 +0000 Subject: [PATCH] Further tensorflow tests for dist --- notebooks/pretrain_mae_tensorflow.ipynb | 203 ++++++++++++++++++++---- scripts/pretrain_mae_tensorflow.py | 25 +-- tpu_setup_tf.sh | 8 +- 3 files changed, 192 insertions(+), 44 deletions(-) diff --git a/notebooks/pretrain_mae_tensorflow.ipynb b/notebooks/pretrain_mae_tensorflow.ipynb index e90b5a4..abbdaee 100644 --- a/notebooks/pretrain_mae_tensorflow.ipynb +++ b/notebooks/pretrain_mae_tensorflow.ipynb @@ -9,17 +9,19 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/opt/conda/envs/sdofm/lib/python3.10/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", - " torch.utils._pytree._register_pytree_node(\n", - "/opt/conda/envs/sdofm/lib/python3.10/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. 
Please use torch.utils._pytree.register_pytree_node instead.\n", - " torch.utils._pytree._register_pytree_node(\n" + "2024-05-06 11:45:57.705064: I tensorflow/core/tpu/tpu_api_dlsym_initializer.cc:95] Opening library: /usr/local/lib/python3.10/dist-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2\n", + "2024-05-06 11:45:57.705188: I tensorflow/core/tpu/tpu_api_dlsym_initializer.cc:119] Libtpu path is: libtpu.so\n", + "2024-05-06 11:45:57.804402: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "/home/walsh/.local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -35,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -50,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -125,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -134,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -143,9 +145,125 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "D0506 11:46:12.190437091 35513 config.cc:183] gRPC EXPERIMENT block_excessive_requests_before_settings_ack OFF (default:OFF)\n", + "D0506 11:46:12.190456321 35513 config.cc:183] gRPC EXPERIMENT 
call_status_override_on_cancellation OFF (default:OFF)\n", + "D0506 11:46:12.190459501 35513 config.cc:183] gRPC EXPERIMENT canary_client_privacy ON (default:ON)\n", + "D0506 11:46:12.190461761 35513 config.cc:183] gRPC EXPERIMENT chttp2_batch_requests OFF (default:OFF)\n", + "D0506 11:46:12.190463781 35513 config.cc:183] gRPC EXPERIMENT chttp2_offload_on_rst_stream OFF (default:OFF)\n", + "D0506 11:46:12.190465751 35513 config.cc:183] gRPC EXPERIMENT client_privacy ON (default:ON)\n", + "D0506 11:46:12.190467691 35513 config.cc:183] gRPC EXPERIMENT combiner_offload_to_event_engine OFF (default:OFF)\n", + "D0506 11:46:12.190469671 35513 config.cc:183] gRPC EXPERIMENT event_engine_client OFF (default:OFF)\n", + "D0506 11:46:12.190471901 35513 config.cc:183] gRPC EXPERIMENT event_engine_dns OFF (default:OFF)\n", + "D0506 11:46:12.190473681 35513 config.cc:183] gRPC EXPERIMENT event_engine_listener OFF (default:OFF)\n", + "D0506 11:46:12.190475611 35513 config.cc:183] gRPC EXPERIMENT free_large_allocator OFF (default:OFF)\n", + "D0506 11:46:12.190477521 35513 config.cc:183] gRPC EXPERIMENT jitter_max_idle ON (default:ON)\n", + "D0506 11:46:12.190479481 35513 config.cc:183] gRPC EXPERIMENT keepalive_fix OFF (default:OFF)\n", + "D0506 11:46:12.190481301 35513 config.cc:183] gRPC EXPERIMENT keepalive_server_fix ON (default:ON)\n", + "D0506 11:46:12.190483001 35513 config.cc:183] gRPC EXPERIMENT lazier_stream_updates OFF (default:OFF)\n", + "D0506 11:46:12.190484601 35513 config.cc:183] gRPC EXPERIMENT memory_pressure_controller OFF (default:OFF)\n", + "D0506 11:46:12.190486141 35513 config.cc:183] gRPC EXPERIMENT monitoring_experiment ON (default:ON)\n", + "D0506 11:46:12.190487801 35513 config.cc:183] gRPC EXPERIMENT multiping OFF (default:OFF)\n", + "D0506 11:46:12.190489581 35513 config.cc:183] gRPC EXPERIMENT peer_state_based_framing OFF (default:OFF)\n", + "D0506 11:46:12.190491221 35513 config.cc:183] gRPC EXPERIMENT pick_first_happy_eyeballs OFF (default:OFF)\n", 
+ "D0506 11:46:12.190492791 35513 config.cc:183] gRPC EXPERIMENT ping_on_rst_stream OFF (default:OFF)\n", + "D0506 11:46:12.190494381 35513 config.cc:183] gRPC EXPERIMENT promise_based_client_call OFF (default:OFF)\n", + "D0506 11:46:12.190495991 35513 config.cc:183] gRPC EXPERIMENT promise_based_server_call OFF (default:OFF)\n", + "D0506 11:46:12.190497601 35513 config.cc:183] gRPC EXPERIMENT red_max_concurrent_streams OFF (default:OFF)\n", + "D0506 11:46:12.190499271 35513 config.cc:183] gRPC EXPERIMENT registered_method_lookup_in_transport OFF (default:OFF)\n", + "D0506 11:46:12.190500891 35513 config.cc:183] gRPC EXPERIMENT round_robin_delegate_to_pick_first OFF (default:OFF)\n", + "D0506 11:46:12.190502521 35513 config.cc:183] gRPC EXPERIMENT rstpit OFF (default:OFF)\n", + "D0506 11:46:12.190504141 35513 config.cc:183] gRPC EXPERIMENT schedule_cancellation_over_write OFF (default:OFF)\n", + "D0506 11:46:12.190505761 35513 config.cc:183] gRPC EXPERIMENT server_privacy ON (default:ON)\n", + "D0506 11:46:12.190508511 35513 config.cc:183] gRPC EXPERIMENT settings_timeout OFF (default:OFF)\n", + "D0506 11:46:12.190510301 35513 config.cc:183] gRPC EXPERIMENT tarpit OFF (default:OFF)\n", + "D0506 11:46:12.190511971 35513 config.cc:183] gRPC EXPERIMENT tcp_frame_size_tuning OFF (default:OFF)\n", + "D0506 11:46:12.190513621 35513 config.cc:183] gRPC EXPERIMENT tcp_rcv_lowat OFF (default:OFF)\n", + "D0506 11:46:12.190515241 35513 config.cc:183] gRPC EXPERIMENT trace_record_callops OFF (default:OFF)\n", + "D0506 11:46:12.190516881 35513 config.cc:183] gRPC EXPERIMENT unconstrained_max_quota_buffer_size OFF (default:OFF)\n", + "D0506 11:46:12.190518491 35513 config.cc:183] gRPC EXPERIMENT work_serializer_clears_time_cache OFF (default:OFF)\n", + "D0506 11:46:12.190520101 35513 config.cc:183] gRPC EXPERIMENT work_serializer_dispatch OFF (default:OFF)\n", + "D0506 11:46:12.190521761 35513 config.cc:183] gRPC EXPERIMENT wrr_delegate_to_pick_first OFF (default:OFF)\n", + 
"I0506 11:46:12.190651981 35513 ev_epoll1_linux.cc:123] grpc epoll fd: 879\n", + "D0506 11:46:12.190664491 35513 ev_posix.cc:113] Using polling engine: epoll1\n", + "I0506 11:46:12.190736131 35513 server_builder.cc:353] Synchronous server. Num CQs: 1, Min pollers: 1, Max Pollers: 2, CQ timeout (msec): 10000\n", + "D0506 11:46:12.190775011 35513 lb_policy_registry.cc:47] registering LB policy factory for \"priority_experimental\"\n", + "D0506 11:46:12.190778661 35513 lb_policy_registry.cc:47] registering LB policy factory for \"outlier_detection_experimental\"\n", + "D0506 11:46:12.190780941 35513 lb_policy_registry.cc:47] registering LB policy factory for \"weighted_target_experimental\"\n", + "D0506 11:46:12.190788601 35513 lb_policy_registry.cc:47] registering LB policy factory for \"pick_first\"\n", + "D0506 11:46:12.190793241 35513 lb_policy_registry.cc:47] registering LB policy factory for \"round_robin\"\n", + "D0506 11:46:12.190795751 35513 lb_policy_registry.cc:47] registering LB policy factory for \"weighted_round_robin\"\n", + "D0506 11:46:12.190804551 35513 lb_policy_registry.cc:47] registering LB policy factory for \"grpclb\"\n", + "D0506 11:46:12.190816141 35513 dns_resolver_plugin.cc:50] Using ares dns resolver\n", + "D0506 11:46:12.190827371 35513 lb_policy_registry.cc:47] registering LB policy factory for \"rls_experimental\"\n", + "D0506 11:46:12.190843191 35513 lb_policy_registry.cc:47] registering LB policy factory for \"xds_cluster_manager_experimental\"\n", + "D0506 11:46:12.190846101 35513 lb_policy_registry.cc:47] registering LB policy factory for \"xds_cluster_impl_experimental\"\n", + "D0506 11:46:12.190848521 35513 lb_policy_registry.cc:47] registering LB policy factory for \"cds_experimental\"\n", + "D0506 11:46:12.190850831 35513 lb_policy_registry.cc:47] registering LB policy factory for \"xds_cluster_resolver_experimental\"\n", + "D0506 11:46:12.190853151 35513 lb_policy_registry.cc:47] registering LB policy factory for 
\"xds_override_host_experimental\"\n", + "D0506 11:46:12.190855641 35513 lb_policy_registry.cc:47] registering LB policy factory for \"xds_wrr_locality_experimental\"\n", + "D0506 11:46:12.190858871 35513 lb_policy_registry.cc:47] registering LB policy factory for \"ring_hash_experimental\"\n", + "D0506 11:46:12.190861811 35513 certificate_provider_registry.cc:33] registering certificate provider factory for \"file_watcher\"\n", + "I0506 11:46:12.191746541 35513 ev_epoll1_linux.cc:360] grpc epoll fd: 881\n", + "I0506 11:46:12.191833441 35513 socket_utils_common_posix.cc:366] TCP_USER_TIMEOUT is available. TCP_USER_TIMEOUT will be used thereafter\n", + "2024-05-06 11:46:12.228734: I external/local_xla/xla/service/service.cc:168] XLA service 0x5570a1f661a0 initialized for platform TPU (this does not guarantee that XLA will be used). Devices:\n", + "2024-05-06 11:46:12.228754: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (0): TPU, 2a886c8\n", + "2024-05-06 11:46:12.228760: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (1): TPU, 2a886c8\n", + "2024-05-06 11:46:12.228764: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (2): TPU, 2a886c8\n", + "2024-05-06 11:46:12.228768: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (3): TPU, 2a886c8\n", + "2024-05-06 11:46:12.228772: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (4): TPU, 2a886c8\n", + "2024-05-06 11:46:12.228776: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (5): TPU, 2a886c8\n", + "2024-05-06 11:46:12.228780: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (6): TPU, 2a886c8\n", + "2024-05-06 11:46:12.228784: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (7): TPU, 2a886c8\n", + "2024-05-06 11:46:12.228915: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", 
+ "2024-05-06 11:46:12.228964: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229004: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229044: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229087: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229219: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229304: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229377: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229427: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229473: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229624: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229673: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229717: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229779: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229824: E 
external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.229980: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230024: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230071: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230121: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230165: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230308: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230360: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230410: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230457: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230512: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230663: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230719: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230776: E 
external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230824: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.230877: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231032: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231080: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231125: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231169: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231213: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231361: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231420: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231470: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231520: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n", + "2024-05-06 11:46:12.231564: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.\n" + ] + } + ], "source": [ "model = 
TFViTMAEForPreTraining(config=configuration)\n", "model.compile()" @@ -153,32 +271,20 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1/4\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-05-03 20:42:56.136685: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f92e800c5b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n", - "2024-05-03 20:42:56.136721: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (0): Tesla T4, Compute Capability 7.5\n", - "2024-05-03 20:42:56.144134: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n", - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "I0000 00:00:1714768976.196050 207819 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - " 118/88740 [..............................] 
- ETA: 42:51:56 - loss: 2.0085" + "Epoch 1/4\n", + "WARNING:tensorflow:AutoGraph could not transform and will run it as-is.\n", + "Cause: for/else statement not yet supported\n", + "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n", + "WARNING: AutoGraph could not transform and will run it as-is.\n", + "Cause: for/else statement not yet supported\n", + "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n" ] } ], @@ -192,7 +298,38 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "import datetime\n", + "tpu = tf.distribute.cluster_resolver.TPUClusterResolver()\n", + "\n", + "tf.config.experimental_connect_to_cluster(tpu)\n", + "tf.tpu.experimental.initialize_tpu_system(tpu)\n", + "\n", + "strategy = tf.distribute.TPUStrategy(tpu)\n", + "\n", + "aia_data = zarr.group(zarr.DirectoryStore(\"/mnt/sdoml/AIA.zarr\"))\n", + "aligndata = pd.read_csv(\"/mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv\")\n", + "aligndata[\"Time\"] = pd.to_datetime(aligndata[\"Time\"])\n", + "aligndata.set_index(\"Time\", inplace=True)\n", + "with open(\"/mnt/sdoml/cache/normalizations_AIA_FULL_12min.json\", \"r\") as json_file:\n", + " normalisations = json.load(json_file) \n", + "json_file.close()\n", + "\n", + "aiadl = AIALoader(aia_data, aligndata, normalisations, batch_size=16)\n", + "# dataset = DatasetFromSequenceClass(aiadl, aiadl.datalen, N_EPOCHS, batchSize=16)\n", + "\n", + "configuration = ViTMAEConfig(image_size=512, num_channels=9, num_attention_heads=16, training=True)\n", + "\n", + "log_dir = \"~/logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", + "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n", + "\n", + "with strategy.scope():\n", + " model = TFViTMAEForPreTraining(config=configuration)\n", + " model.compile()\n", + "\n", + " history = model.fit(aiadl, epochs=int(4), 
callbacks=[tensorboard_callback])\n", + " eval_metrics = {key: val[-1] for key, val in history.history.items()}" + ] } ], "metadata": { @@ -211,7 +348,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.12" }, "orig_nbformat": 4 }, diff --git a/scripts/pretrain_mae_tensorflow.py b/scripts/pretrain_mae_tensorflow.py index 2e37cab..33365b1 100644 --- a/scripts/pretrain_mae_tensorflow.py +++ b/scripts/pretrain_mae_tensorflow.py @@ -6,6 +6,7 @@ import zarr import numpy as np import time +import datetime class AIALoader(tf.keras.utils.Sequence): def __init__(self, zarr_file, aligndata, normalisations, batch_size, n_frames=1, shuffle = False, inst = None): @@ -119,13 +120,12 @@ def LoadBatchFromSequenceClass(batchIndexTensor): def main(): N_EPOCHS = 4 + tpu = tf.distribute.cluster_resolver.TPUClusterResolver() - # tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='t1v-n-a6a7febf-w-0') + tf.config.experimental_connect_to_cluster(tpu) + tf.tpu.experimental.initialize_tpu_system(tpu) - # tf.config.experimental_connect_to_cluster(tpu) - # tf.tpu.experimental.initialize_tpu_system(tpu) - - # strategy = tf.distribute.TPUStrategy(tpu) + strategy = tf.distribute.TPUStrategy(tpu) aia_data = zarr.group(zarr.DirectoryStore("/mnt/sdoml/AIA.zarr")) aligndata = pd.read_csv("/mnt/sdoml/cache/aligndata_AIA_FULL_12min.csv") @@ -136,14 +136,19 @@ def main(): json_file.close() aiadl = AIALoader(aia_data, aligndata, normalisations, batch_size=16) - dataset = DatasetFromSequenceClass(aiadl, aiadl.datalen, N_EPOCHS, batchSize=16) + # dataset = DatasetFromSequenceClass(aiadl, aiadl.datalen, N_EPOCHS, batchSize=16) configuration = ViTMAEConfig(image_size=512, num_channels=9, num_attention_heads=16, training=True) - model = TFViTMAEForPreTraining(config=configuration) - model.compile() - history = model.fit(dataset, epochs=int(4))#, callbacks=callbacks) - eval_metrics = {key: val[-1] for key, val in 
history.history.items()} + log_dir = "~/logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) + + with strategy.scope(): + model = TFViTMAEForPreTraining(config=configuration) + model.compile() + + history = model.fit(aiadl, epochs=int(4), callbacks=[tensorboard_callback]) + eval_metrics = {key: val[-1] for key, val in history.history.items()} if __name__ == "__main__": diff --git a/tpu_setup_tf.sh b/tpu_setup_tf.sh index e4e9191..7e3fde0 100755 --- a/tpu_setup_tf.sh +++ b/tpu_setup_tf.sh @@ -1,3 +1,9 @@ +sudo apt install nfs-common -y +sudo mkdir /mnt/sdoml +sudo mount 10.14.32.66:/sdoml_hdd /mnt/sdoml -o ro,hard,timeo=600,retrans=3,rsize=262144,wsize=1048576,resvport,async,nconnect=7,_netdev + export TPU_NAME=local export NEXT_PLUGGABLE_DEVICE_USE_C_API=true -export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so \ No newline at end of file +export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so + +pip install zarr transformers numpy pandas \ No newline at end of file