Skip to content

Commit

Permalink
support config changes from MVAI down to fbgemm (pytorch#2259)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2259

integrate the new rocksdb config into mvai model authoring chain so that we could tune the model config and affect the rocksdb changes

Differential Revision: D59785241

fbshipit-source-id: f689a86e3743b7fb7e939708af0267aaed60370e
  • Loading branch information
duduyi2013 authored and facebook-github-bot committed Jul 31, 2024
1 parent 95a4b71 commit 9eb6b89
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
16 changes: 16 additions & 0 deletions torchrec/distributed/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Union,
)

from fbgemm_gpu.runtime_monitor import TBEStatsReporterConfig
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
BoundsCheckMode,
CacheAlgorithm,
Expand Down Expand Up @@ -588,16 +589,31 @@ class KeyValueParams:
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
and ports. Example: (("::1", 2000), ("::1", 2001), ("::1", 2002)).
Reason for using tuple is we want it hashable.
ssd_rocksdb_write_buffer_size: Optional[int]: rocksdb write buffer size per tbe,
relavant to rocksdb compaction frequency
ssd_rocksdb_shards: Optional[int]: rocksdb shards number
gather_ssd_cache_stats: bool: whether enable ssd stats collection, std reporter and ods reporter
report_interval: int: report interval in train iteration if gather_ssd_cache_stats is enabled
ods_prefix: str: ods prefix for ods reporting
"""

ssd_storage_directory: Optional[str] = None
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
ssd_rocksdb_write_buffer_size: Optional[int] = None
ssd_rocksdb_shards: Optional[int] = None
gather_ssd_cache_stats: Optional[bool] = None
stats_reporter_config: Optional[TBEStatsReporterConfig] = None
use_passed_in_path: bool = True

def __hash__(self) -> int:
return hash(
(
self.ssd_storage_directory,
self.ps_hosts,
self.ssd_rocksdb_write_buffer_size,
self.ssd_rocksdb_shards,
self.gather_ssd_cache_stats,
self.stats_reporter_config,
)
)

Expand Down
7 changes: 6 additions & 1 deletion torchrec/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,15 @@ def add_params_from_parameter_sharding(
parameter_sharding.compute_kernel in {EmbeddingComputeKernel.KEY_VALUE.value}
and parameter_sharding.key_value_params is not None
):
key_value_params_dict = asdict(parameter_sharding.key_value_params)
kv_params = parameter_sharding.key_value_params
key_value_params_dict = asdict(kv_params)
key_value_params_dict = {
k: v for k, v in key_value_params_dict.items() if v is not None
}
if kv_params.stats_reporter_config:
key_value_params_dict["stats_reporter_config"] = (
kv_params.stats_reporter_config
)
fused_params.update(key_value_params_dict)

# print warning if sharding_type is data_parallel or kernel is dense
Expand Down

0 comments on commit 9eb6b89

Please sign in to comment.