Skip to content

Commit

Permalink
[fix] remove register device guard to avoid hang (#219)
Browse files Browse the repository at this point in the history
On latest version of tt-metal, there is an added synchronization when
closing the device. This synchronization doesn't work well with
opening device on libary load (`dl_open`).

There is a deadlock between the main thread (loading the forge _C
library) and the `completion_queue_thread` in tt-metal.

The completion thread is trying to acquire lock on `dl_load_lock`,
which is already being held by the main thread (inside of the `dl_open`).
The reason why the completion thread is acquiring this lock, is to
register the destructor for the TLS variable `dispatch_cmd_and_event`.

Hence, to unblock further uplifts of tt-mlir (tt-metal) - removing
code triggering the opening of the device during library load. This
code is related to torch 2.0 integration, and is not actually used
currently.
  • Loading branch information
pilkicTT authored Sep 4, 2024
1 parent 4a8c56c commit ce8d1d5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions forge/csrc/tt_torch_device/torch_device_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ class TorchDeviceImpl final : public c10::impl::DeviceGuardImplInterface
int next_id = 0;
};

// register backend
c10::impl::DeviceGuardImplRegistrar tt_device_impl_reg(TT, &TorchDeviceImpl::get());
// NOTE: We'll need to rework implementation of DeviceGuard to avoid opening the device
// on library load. This causes a hang in tt-metal during device close.
// c10::impl::DeviceGuardImplRegistrar tt_device_impl_reg(TT, &TorchDeviceImpl::get());

const std::shared_ptr<TTDevice>& get_default_tt_device() { return TorchDeviceImpl::get().getDefaultTTDevice();}
std::vector<std::shared_ptr<TTDevice>> get_available_tt_devices() { return TorchDeviceImpl::get().getTTDevices(); }
Expand Down
2 changes: 1 addition & 1 deletion forge/test/mlir/mnist/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_mnist_inference():
fw_out = framework_model(*inputs)

compiled_model = forge.compile(framework_model, sample_inputs=inputs)
co_out = compiled_model(*[i.to("tt") for i in inputs])
co_out = compiled_model(*inputs)

co_out = [co.to("cpu") for co in co_out]
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)]

0 comments on commit ce8d1d5

Please sign in to comment.