Skip to content

Commit

Permalink
Make config objects json serializable (#862)
Browse files Browse the repository at this point in the history
Co-authored-by: Jeff Rasley <[email protected]>
  • Loading branch information
tjruwase and jeffra authored Mar 16, 2021
1 parent fa87a73 commit 7bcd72a
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 25 deletions.
10 changes: 8 additions & 2 deletions deepspeed/profiling/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
Licensed under the MIT license.
"""

from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.profiling.constants import *


class DeepSpeedFlopsProfilerConfig(object):
class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
"""
docstring
"""
super(DeepSpeedFlopsProfilerConfig, self).__init__()

self.enabled = None
Expand All @@ -24,6 +27,9 @@ def __init__(self, param_dict):
self._initialize(flops_profiler_dict)

def _initialize(self, flops_profiler_dict):
"""
docstring
"""
self.enabled = get_scalar_param(flops_profiler_dict,
FLOPS_PROFILER_ENABLED,
FLOPS_PROFILER_ENABLED_DEFAULT)
Expand Down
11 changes: 2 additions & 9 deletions deepspeed/runtime/activation_checkpointing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Licensed under the MIT license.
"""

from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject

#########################################
# DeepSpeed Activation Checkpointing
Expand Down Expand Up @@ -56,7 +56,7 @@
}


class DeepSpeedActivationCheckpointingConfig(object):
class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
super(DeepSpeedActivationCheckpointingConfig, self).__init__()

Expand All @@ -74,13 +74,6 @@ def __init__(self, param_dict):

self._initialize(act_chkpt_config_dict)

"""
For json serialization
"""

def repr(self):
return self.__dict__

def _initialize(self, act_chkpt_config_dict):
self.partition_activations = get_scalar_param(
act_chkpt_config_dict,
Expand Down
13 changes: 12 additions & 1 deletion deepspeed/runtime/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,21 @@
"""
Collection of DeepSpeed configuration utilities
"""

import json
from collections import Counter


class DeepSpeedConfigObject(object):
"""
For json serialization
"""
def repr(self):
return self.__dict__

def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)


def get_scalar_param(param_dict, param_name, param_default_value):
return param_dict.get(param_name, param_default_value)

Expand Down
15 changes: 2 additions & 13 deletions deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
Licensed under the MIT license.
"""

from deepspeed.runtime.config_utils import get_scalar_param
from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.utils import logger
from deepspeed.runtime.zero.constants import *
import json


class DeepSpeedZeroConfig(object):
class DeepSpeedZeroConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
super(DeepSpeedZeroConfig, self).__init__()

Expand Down Expand Up @@ -66,16 +65,6 @@ def read_zero_config_deprecated(self, param_dict):
.format(ZERO_FORMAT))
return zero_config_dict

"""
For json serialization
"""

def repr(self):
return self.__dict__

def __repr__(self):
return json.dumps(self.__dict__, sort_keys=True, indent=4)

def _initialize(self, zero_config_dict):
self.stage = get_scalar_param(zero_config_dict,
ZERO_OPTIMIZATION_STAGE,
Expand Down

0 comments on commit 7bcd72a

Please sign in to comment.