Skip to content

Commit

Permalink
bugbear, simplicity
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Dec 29, 2024
1 parent 5130b0b commit a157664
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ select = [
# flake8-bugbear
# "B",
# flake8-simplify
# "SIM",
"SIM",
# isort
"I",
]
Expand Down
2 changes: 1 addition & 1 deletion tensorboardX/comet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def wrapper(*args, **kwargs):
if self._logging is None and comet_installed:
self._logging = False
try:
if 'api_key' not in self._comet_config.keys():
if 'api_key' not in self._comet_config:
comet_ml.init()
if comet_ml.get_global_experiment() is not None:
logger.warning("You have already created a comet \
Expand Down
44 changes: 22 additions & 22 deletions tensorboardX/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def hparams(hparam_dict=None, metric_dict=None):
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_BOOL")))
continue

if isinstance(v, int) or isinstance(v, float):
if isinstance(v, (int, float)):
v = make_np(v)[0]
ssi.hparams[k].number_value = v
hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
Expand All @@ -126,7 +126,7 @@ def hparams(hparam_dict=None, metric_dict=None):
content=content.SerializeToString()))
ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)])

mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()]
mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]

exp = Experiment(hparam_infos=hps, metric_infos=mts)
content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION)
Expand Down Expand Up @@ -388,26 +388,26 @@ def make_video(tensor, fps):

# encode sequence of images into gif string
clip = mpy.ImageSequenceClip(list(tensor), fps=fps)

filename = tempfile.NamedTemporaryFile(suffix='.gif', delete=False).name

if moviepy.version.__version__.startswith("0."):
logger.warning('Upgrade to moviepy >= 1.0.0 to supress the progress bar.')
clip.write_gif(filename, verbose=False)
elif moviepy.version.__version__.startswith("1."):
# moviepy >= 1.0.0 use logger=None to suppress output.
clip.write_gif(filename, verbose=False, logger=None)
else:
# Moviepy >= 2.0.0.dev1 removed the verbose argument
clip.write_gif(filename, logger=None)

with open(filename, 'rb') as f:
tensor_string = f.read()

try:
os.remove(filename)
except OSError:
logger.warning('The temporary file used by moviepy cannot be deleted.')
with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as fp:
filename = fp.name

if moviepy.version.__version__.startswith("0."):
logger.warning('Upgrade to moviepy >= 1.0.0 to supress the progress bar.')
clip.write_gif(filename, verbose=False)
elif moviepy.version.__version__.startswith("1."):
# moviepy >= 1.0.0 use logger=None to suppress output.
clip.write_gif(filename, verbose=False, logger=None)
else:
# Moviepy >= 2.0.0.dev1 removed the verbose argument
clip.write_gif(filename, logger=None)

with open(filename, 'rb') as f:
tensor_string = f.read()

try:
os.remove(filename)
except OSError:
logger.warning('The temporary file used by moviepy cannot be deleted.')

return Summary.Image(height=h, width=w, colorspace=c, encoded_image_string=tensor_string)

Expand Down
4 changes: 2 additions & 2 deletions tensorboardX/torchvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def __init__(self, *args, **init_kwargs):

def register(self, *args, **init_kwargs):
# Sets tensorboard as the default visualization format if not specified
formats = ['tensorboard'] if not args else args
formats = args if args else ['tensorboard']
for format in formats:
if self.subscribers.get(format) is None and format in vis_formats.keys():
if self.subscribers.get(format) is None and format in vis_formats:
self.subscribers[format] = vis_formats[format](**init_kwargs.get(format, {}))

def unregister(self, *args):
Expand Down
5 changes: 2 additions & 3 deletions tensorboardX/visdom_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def add_scalar(self, tag, scalar_value, global_step=None, main_tag='default'):
[scalar_value] if exists else [scalar_value]
plot_name = f'{main_tag}-{tag}'
# If there is no global_step provided, follow sequential order
x_val = len(self.scalar_dict[main_tag][tag]
) if not global_step else global_step
x_val = global_step if global_step else len(self.scalar_dict[main_tag][tag])
if exists:
# Update our existing Visdom window
self.vis.line(
Expand Down Expand Up @@ -110,7 +109,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None):
'run_14h-arctanx'
with the corresponding values.
"""
for key in tag_scalar_dict.keys():
for key in tag_scalar_dict:
self.add_scalar(key, tag_scalar_dict[key], global_step, main_tag)

@_check_connection
Expand Down
10 changes: 3 additions & 7 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def __append_to_scalar_dict(self, tag, scalar_value, global_step,
{writer_id : [[timestamp, step, value], ...], ...}.
"""
from .x2num import make_np
if tag not in self.scalar_dict.keys():
if tag not in self.scalar_dict:
self.scalar_dict[tag] = []
self.scalar_dict[tag].append(
[timestamp, global_step, float(make_np(scalar_value).squeeze())])
Expand Down Expand Up @@ -483,7 +483,7 @@ def add_scalars(
fw_logdir = self._get_file_writer().get_logdir()
for tag, scalar_value in tag_scalar_dict.items():
fw_tag = os.path.join(str(fw_logdir), main_tag, tag)
if fw_tag in self.all_writers.keys():
if fw_tag in self.all_writers:
fw = self.all_writers[fw_tag]
else:
fw = FileWriter(logdir=fw_tag)
Expand Down Expand Up @@ -1001,11 +1001,7 @@ def add_embedding(
# new funcion to append to the config file a new embedding
append_pbtxt(metadata, label_img,
self._get_file_writer().get_logdir(), subdir, global_step, tag)
if tag is not None:
template_filename = f"{tag}.json"

else:
template_filename = None
template_filename = f'{tag}.json' if tag is not None else None

self._get_comet_logger().log_embedding(mat, metadata, label_img, template_filename=template_filename)

Expand Down

0 comments on commit a157664

Please sign in to comment.