Skip to content

Commit

Permalink
Merge pull request #105 from chairc/dev
Browse files Browse the repository at this point in the history
Fix MEAN and STD bug and update organization logo.
  • Loading branch information
chairc authored Nov 18, 2024
2 parents d1d80b7 + 4969000 commit f16a127
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -493,4 +493,4 @@ If this project is used for experiments in an academic paper, where possible ple

[@Pytorch](https://pytorch.org/)

<img src="assets/Pycharm.png" alt="Pycharm" style="zoom:25%;" /><img src="assets/Python.png" alt="Python" style="zoom:25%;" /><img src="assets/pytorch.png" alt="pytorch" style="zoom:25%;" />
![JetBrains logo](assets/jetbrains.svg)
2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -494,5 +494,5 @@ Integrated Design Diffusion Model

[@Pytorch](https://pytorch.org/)

<img src="assets/Pycharm.png" alt="Pycharm" style="zoom:25%;" /><img src="assets/Python.png" alt="Python" style="zoom:25%;" /><img src="assets/pytorch.png" alt="pytorch" style="zoom:25%;" />
![JetBrains logo](assets/jetbrains.svg)

Binary file removed assets/Pycharm.png
Binary file not shown.
Binary file removed assets/Python.png
Binary file not shown.
13 changes: 13 additions & 0 deletions assets/jetbrains.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed assets/pytorch.png
Binary file not shown.
12 changes: 6 additions & 6 deletions sr/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ def prepare_image(images):
return images


def post_image(images):
def post_image(images, device="cpu"):
"""
Post images
:param images: Images
:param device: CPU or GPU
:return: new_images
"""
new_images = torch.empty(size=images.shape, dtype=torch.uint8)
for i in range(images.shape[0]):
new_image = (images[i].clamp(-1, 1) + 1) / 2
new_image = (new_image * 255).to(torch.uint8)
new_images[i] = new_image
mean_tensor = torch.tensor(data=MEAN).view(1, -1, 1, 1).to(device)
std_tensor = torch.tensor(data=STD).view(1, -1, 1, 1).to(device)
new_images = images * std_tensor + mean_tensor
new_images = (new_images * 255).to(torch.uint8)
return new_images


Expand Down
6 changes: 3 additions & 3 deletions sr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,9 @@ def train(rank=None, args=None):
global_step=epoch * len_val_dataloader + i)
val_loss_list.append(val_loss.item())
# Save super resolution image and high resolution image
lr_images = post_image(lr_images)
sr_images = post_image(output)
hr_images = post_image(hr_images)
lr_images = post_image(lr_images, device=device)
sr_images = post_image(output, device=device)
hr_images = post_image(hr_images, device=device)
image_name = time.time()
for lr_index, lr_image in enumerate(lr_images):
save_images(images=lr_image, path=os.path.join(save_val_vis_dir, f"{i}_{image_name}_{lr_index}_lr.jpg"))
Expand Down

0 comments on commit f16a127

Please sign in to comment.