Skip to content

Commit

Permalink
enable embedding sparse optimization by default (#19714)
Browse files Browse the repository at this point in the history
  • Loading branch information
guyang3532 authored Mar 5, 2024
1 parent 7e613ee commit cd56ea4
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/ORTModule_Training_Guidelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ to standard outputs.
#### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input
- **Description**: By default, this is enabled (the optimization is only applied on CUDA devices). This env var can be used for enabling or disabling the embedding input
data sparsity based performance optimizations.

```bash
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,11 +681,15 @@ def _enable_conditional_optimizations(
)

if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0:
graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
)
if detected_device.type == "cuda":
# Embedding sparsity optimization is only supported on CUDA devices.
graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys())
self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results)
self._runtime_options.embed_sparsity_ratio = ",".join(
[f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()]
)
else:
self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.")

# If users don't want to print input density, disable the input density observer to avoid overhead
# when looping through inputs during training.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def __init__(self, logger: Logger):
self.enable_sparse_optimizer = True
self.label_sparsity_ratio = ""
self.embed_sparsity_ratio = ""
self.enable_embedding_sparse_optimizer = False # TODO(pengwa): remove once validation on more models are done.
self.enable_embedding_sparse_optimizer = True

# Configuration for memory optimization.
self.memory_optimization_level = (
Expand Down

0 comments on commit cd56ea4

Please sign in to comment.