From 380432a0820f589727bb80dd84dbe7bd4d463939 Mon Sep 17 00:00:00 2001 From: Ben Dichter Date: Sun, 31 Mar 2024 08:56:06 -0400 Subject: [PATCH] update style for _add_missing_timezone --- src/pynwb/file.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/pynwb/file.py b/src/pynwb/file.py index 0b294e873..766a71232 100644 --- a/src/pynwb/file.py +++ b/src/pynwb/file.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, date from dateutil.tz import tzlocal from collections.abc import Iterable from warnings import warn @@ -104,7 +104,7 @@ class Subject(NWBContainer): 'doc': ('The weight of the subject, including units. Using kilograms is recommended. e.g., "0.02 kg". ' 'If a float is provided, then the weight will be stored as "[value] kg".'), 'default': None}, - {'name': 'date_of_birth', 'type': datetime, 'default': None, + {'name': 'date_of_birth', 'type': (datetime, date), 'default': None, 'doc': 'The datetime of the date of birth. May be supplied instead of age.'}, {'name': 'strain', 'type': str, 'doc': 'The strain of the subject, e.g., "C57BL/6J"', 'default': None}, ) @@ -142,7 +142,7 @@ def __init__(self, **kwargs): args_to_set["age"] = pd.Timedelta(args_to_set["age"]).isoformat() date_of_birth = args_to_set['date_of_birth'] - if date_of_birth and date_of_birth.tzinfo is None: + if not self._in_construct_mode and date_of_birth and isinstance(date_of_birth, datetime) and date_of_birth.tzinfo is None: args_to_set['date_of_birth'] = _add_missing_timezone(date_of_birth) for key, val in args_to_set.items(): @@ -1155,16 +1155,16 @@ def copy(self): return NWBFile(**kwargs) -def _add_missing_timezone(date): +def _add_missing_timezone(dt: datetime): """ Add local timezone information on a datetime object if it is missing. """ - if not isinstance(date, datetime): - raise ValueError("require datetime object") - if date.tzinfo is None: - warn("Date is missing timezone information. Updating to local timezone.", stacklevel=2) - return date.replace(tzinfo=tzlocal()) - return date + if not isinstance(dt, datetime): + raise TypeError("require datetime object") + if dt.tzinfo is None: + warn("Datetime is missing timezone information. Updating to local timezone.", stacklevel=2) + return dt.replace(tzinfo=tzlocal()) + return dt def _tablefunc(table_name, description, columns):