From 6407f946f8e29386c353556ce738266daa958e56 Mon Sep 17 00:00:00 2001 From: Hyunmin-H Date: Wed, 16 Aug 2023 12:54:31 +0000 Subject: [PATCH] =?UTF-8?q?[Feat]=20garment=20masking=20I/O=20=ED=98=95?= =?UTF-8?q?=EC=8B=9D=20=EB=B3=80=EA=B2=BD=20#33?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - garment masking input 형식을 image byte, output 형식을 PIL 형식으로. - gcp class에 json 파일 저장 코드 추가 related to : #31 --- backend/app/main.py | 22 ++++++++----- backend/gcp/cloud_storage.py | 20 +++++------- .../ladi_vton/src/utils/get_clothing_mask.py | 31 ++++++++++++++++--- 3 files changed, 50 insertions(+), 23 deletions(-) diff --git a/backend/app/main.py b/backend/app/main.py index aa7ae1e..3835f6f 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -10,18 +10,18 @@ # scp setting import sys, os sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/Self_Correction_Human_Parsing/') -from simple_extractor import main_schp, main_schp_from_image_byte +from simple_extractor import main_schp, main_schp_fromImageByte # openpose sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/pytorch_openpose/') -from extract_keypoint import main_openpose +from extract_keypoint import main_openpose, main_openpose_fromImageByte # ladi sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton') sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src') sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src/utils') -from get_clothing_mask import main_mask +from get_clothing_mask import main_mask, main_mask_fromImageByte from inference import main_ladi from face_cut_and_paste import main_cut_and_paste @@ -136,21 +136,29 @@ def inference_allModels(target_bytes, garment_bytes, category, db_dir): # schp - (1024, 784), (512, 384) target_buffer_dir = os.path.join(input_dir, 'buffer/target') # main_schp(target_buffer_dir) - schp_img = main_schp_from_image_byte(target_bytes) + schp_img = main_schp_fromImageByte(target_bytes) schp_img.save('./schp.png') - exit() # openpose output_openpose_buffer_dir = os.path.join(db_dir, 'openpose/buffer') os.makedirs(output_openpose_buffer_dir, exist_ok=True) - main_openpose(target_buffer_dir, output_openpose_buffer_dir) + # main_openpose(target_buffer_dir, output_openpose_buffer_dir) + keypoint_dict = main_openpose_fromImageByte(target_bytes) + gcs.upload_dict_as_json_to_gcs(keypoint_dict, os.path.join(db_dir, 'openpose/buffer/target.json')) + # /opt/ml/user_db/mask/buffer # mask garment_dir = os.path.join(input_dir, 'buffer/garment') output_mask_dir = os.path.join(db_dir, 'mask/buffer') os.makedirs(output_mask_dir, exist_ok=True) - main_mask(category, garment_dir, output_mask_dir) + # main_mask(category, garment_dir, output_mask_dir) + ## garment_mask 형식 - Image + garment_mask = main_mask_fromImageByte(garment_bytes) + + garment_mask.save('./garment_mask.jpg') + + exit() # ladi-vton output_ladi_buffer_dir = os.path.join(db_dir, 'ladi/buffer') os.makedirs(output_ladi_buffer_dir, exist_ok=True) diff --git a/backend/gcp/cloud_storage.py b/backend/gcp/cloud_storage.py index e14f4a6..4b3a18d 100644 --- a/backend/gcp/cloud_storage.py +++ b/backend/gcp/cloud_storage.py @@ -8,6 +8,7 @@ # datetime from datetime import datetime, timedelta +import json class GCSUploader: @@ -35,15 +36,13 @@ def upload_blob(self, source_file_data: bytes, destination_blob_name: str) -> st print(f"File uploaded to {destination_blob_name}.") return user_url + + def upload_dict_as_json_to_gcs(self, data_dict, blob_name): + bucket = self.client.get_bucket(self.bucket_name) + blob = bucket.blob(blob_name) - # Uploads image to GCS and returns the URL - # def save_image_to_gcs(self, urls: list) -> str: - # image_urls = [] - # for i, (byte_arr, url_name) in enumerate(urls): - # url = self.upload_blob(byte_arr, url_name) - # image_urls.append(url) - - # return image_urls + json_data = json.dumps(data_dict) + blob.upload_from_string(json_data, content_type='application/json') def list_images_in_folder(self, folder_name): bucket = self.client.get_bucket(self.bucket_name) @@ -62,11 +61,8 @@ def read_image_from_gcs(self, blob_name): return None image_data = blob.download_as_bytes() - from PIL import Image - from io import BytesIO - - return Image.open(BytesIO(image_data)) + return image_data def load_gcp_config_from_yaml(yaml_path): with open(yaml_path, 'r') as yaml_file: diff --git a/model/ladi_vton/src/utils/get_clothing_mask.py b/model/ladi_vton/src/utils/get_clothing_mask.py index 35d2e7a..c7345b3 100644 --- a/model/ladi_vton/src/utils/get_clothing_mask.py +++ b/model/ladi_vton/src/utils/get_clothing_mask.py @@ -1,6 +1,7 @@ from rembg import remove from PIL import Image import os +from io import BytesIO # garment, lower_body, upper_body @@ -52,9 +53,31 @@ def create_mask_from_png(jpg_file, jpg_mask_file): create_mask_from_png(input_path, output_path) + +def main_mask_fromImageByte(garment_bytes): + + # image = Image.open(jpg_file) + image = Image.open(BytesIO(garment_bytes)) -# output = remove(input) -# output.save(os.path.join(target_buffer_dir, output_mask_dir)) -# target_buffer_dir = '/opt/ml/user_db/input/buffer/garment' -# main_mask('lower_body', target_buffer_dir) \ No newline at end of file + image = remove(image) + + mask_image = Image.new("RGB", image.size) + width, height = image.size + + pixel_data = image.load() + + for y in range(height): + for x in range(width): + # 각 픽셀의 RGB 값과 알파 값 가져오기 + r, g, b, alpha = pixel_data[x, y] + + # alpha 값이 0인 경우 검정색으로, 그렇지 않은 경우 흰색으로 설정 + if alpha == 0: + mask_image.putpixel((x, y), (0, 0, 0)) + else: + mask_image.putpixel((x, y), (255, 255, 255)) + + # JPG 파일로 저장 + return mask_image + \ No newline at end of file