-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Expanded sharded support for alternative sharding mechanisms #680
base: main
Are you sure you want to change the base?
Conversation
Single-logical-multi-physical sharding allows tensor access between different devices and tighter synchronization on execution. This means that sharding needs to support more than differing device ordinals but also configre multiple queues for the same device. Sharded tensor types are reworked to support tracking both the supported device AND the queue it is enqueued on. To support this each sharded tensor now tracks the DeviceAffinity it is associated with, along with reassigning affinities post construction. This allows pre-sharded models to have their affinities updated with an alternative transfer mechanism. If device affinity is not specified the default arrangement assumes separate device ordinals for each shard.
e486ad4
to
9a061e6
Compare
@@ -279,16 +283,15 @@ def main(): | |||
tensor_parallelism_size=args.tensor_parallelism_size, | |||
fake_quant=args.fake_quant, | |||
) | |||
if config.tensor_parallelism_size > 1: | |||
dataset.root_theta = shard_theta(dataset.root_theta, config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can remove shard_theta
import if unused.
I saw the following error this morning when attempting to validate toy_llama_tp2 from iree-test-suites by exporting, and compiling with intent to then verify with
python -m sharktank.examples.export_paged_llm_v1 --bs=1 --irpa-file assets/toy_llama_tp2.irpa --output-mlir=llama.mlir --output-config=config.json --use-queue-affinities
iree-compile llama.mlir -o llama.vmfb --iree-hip-target=gfx942 --iree-hal-target-device=hip[0] Received the following error:/toy_new/llama.mlir:4027:12: error: op affinity #hal.device.affinity<@__device_0> is not compatible with the partition affinity #hal.device.affinity<@__device_0, [0]>
%153 = torch.prims.convert_element_type %1, %int5_87 : !torch.vtensor<[256,256],f32>, !torch.int -> !torch.vtensor<[256,256],f16>
^
./toy_new/llama.mlir:4027:12: note: see current operation: %190 = "stream.async.transfer"(%189, %10, %10) <{result_affinity = #hal.device.affinity<@__device_0>, source_affinity = #hal.device.affinity<@__device_0, [1]>}> : (!stream.resource<constant>, index, index) -> !stream.resource<constant> Feedback from Rob this morning before sync:
|
self.rope_dimension_count = rope_dimension_count | ||
self.max_seqlen = max_seqlen | ||
self.use_hf = use_hf | ||
self.static_tables = static_tables | ||
self.use_table = use_table | ||
|
||
self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0 | ||
self.tensor_parallelism_size = tensor_parallelism_size | ||
self.devices = devices |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is redundant with L34, can be removed.
Single-logical-multi-physical sharding allows tensor access between
different devices and tighter synchronization on execution. This means
that sharding needs to support more than differing device ordinals but
also configre multiple queues for the same device. Sharded tensor types
are reworked to support tracking both the supported device AND the queue
it is enqueued on.
To support this each sharded tensor now tracks the DeviceAffinity it is
associated with, along with reassigning affinities post construction.
This allows pre-sharded models to have their affinities updated with an
alternative transfer mechanism.
If device affinity is not specified the default arrangement assumes
separate device ordinals for each shard.