Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update style for _add_missing_timezone #1877

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/pynwb/file.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, date
from dateutil.tz import tzlocal
from collections.abc import Iterable
from warnings import warn
Expand Down Expand Up @@ -104,7 +104,7 @@ class Subject(NWBContainer):
'doc': ('The weight of the subject, including units. Using kilograms is recommended. e.g., "0.02 kg". '
'If a float is provided, then the weight will be stored as "[value] kg".'),
'default': None},
{'name': 'date_of_birth', 'type': datetime, 'default': None,
{'name': 'date_of_birth', 'type': (datetime, date), 'default': None,
'doc': 'The datetime of the date of birth. May be supplied instead of age.'},
{'name': 'strain', 'type': str, 'doc': 'The strain of the subject, e.g., "C57BL/6J"', 'default': None},
)
Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(self, **kwargs):
args_to_set["age"] = pd.Timedelta(args_to_set["age"]).isoformat()

date_of_birth = args_to_set['date_of_birth']
if date_of_birth and date_of_birth.tzinfo is None:
if not self._in_construct_mode and date_of_birth and isinstance(date_of_birth, datetime) and date_of_birth.tzinfo is None:
args_to_set['date_of_birth'] = _add_missing_timezone(date_of_birth)

for key, val in args_to_set.items():
Expand Down Expand Up @@ -1155,16 +1155,16 @@ def copy(self):
return NWBFile(**kwargs)


def _add_missing_timezone(date):
def _add_missing_timezone(dt: datetime):
"""
Add local timezone information on a datetime object if it is missing.
"""
if not isinstance(date, datetime):
raise ValueError("require datetime object")
if date.tzinfo is None:
warn("Date is missing timezone information. Updating to local timezone.", stacklevel=2)
return date.replace(tzinfo=tzlocal())
return date
if not isinstance(dt, datetime):
raise TypeError("require datetime object")
if dt.tzinfo is None:
warn("Datetime is missing timezone information. Updating to local timezone.", stacklevel=2)
return dt.replace(tzinfo=tzlocal())
return dt


def _tablefunc(table_name, description, columns):
Expand Down
Loading