From e47e092c63181c0ffcf203b09921b4d12480fec7 Mon Sep 17 00:00:00 2001
From: Hyunmin-H
Date: Wed, 16 Aug 2023 11:29:27 +0000
Subject: [PATCH] [Feat] Change the SCHP input format to bytes #33
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Implement an inference path in simple_extractor.py for a single-image input (main_schp_from_image_byte)
- Pass garment/target images between the frontend and backend as raw bytes as well
---
 backend/app/frontend.py                            |  29 +++---
 backend/app/main.py                                |  44 +++++----
 .../datasets/simple_extractor_dataset.py           |  55 ++++++++++++
 .../simple_extractor.py                            | 100 +++++++++++++++++-
 4 files changed, 195 insertions(+), 33 deletions(-)

diff --git a/backend/app/frontend.py b/backend/app/frontend.py
index 2c3786a..c7e58da 100644
--- a/backend/app/frontend.py
+++ b/backend/app/frontend.py
@@ -15,13 +15,12 @@
 ASSETS_DIR_PATH = os.path.join(Path(__file__).parent.parent.parent.parent, "assets")
 
 st.set_page_config(layout="wide")
-root_password = 'a'
 
 category_pair = {'Upper':'upper_body', 'Lower':'lower_body', 'Upper & Lower':'upper_lower', 'Dress':'dresses'}
 db_dir = '/opt/ml/user_db'
 
 gcp_config = load_gcp_config_from_yaml("/opt/ml/level3_cv_finalproject-cv-12/backend/config/gcs.yaml")
-gcs_uploader = GCSUploader(gcp_config)
+gcs = GCSUploader(gcp_config)
 
 user_name = 'hi'
 
@@ -76,7 +75,7 @@ def append_imgList(uploaded_garment, category):
 
 def show_garments_and_checkboxes(category):
     category_dir = os.path.join(user_name, 'input/garment', category)
-    filenames = gcs_uploader.list_images_in_folder(category_dir)
+    filenames = gcs.list_images_in_folder(category_dir)
     # filenames = os.listdir(category_dir)
 
     # category_dir = os.path.join(db_dir, 'input/garment', category)
@@ -91,8 +90,12 @@ def show_garments_and_checkboxes(category):
         im_dir = os.path.join(category_dir, filename)
         # garment_img = Image.open(im_dir)
         # garment_byte = read_image_as_bytes(im_dir)
-        garment_img = gcs_uploader.read_image_from_gcs(im_dir)
+        garment_img = gcs.read_image_from_gcs(im_dir)
         # st.image(garment_img, caption=filename[:-4], width=100)
+
+        from PIL import Image
+        from io import BytesIO
+        garment_img = Image.open(BytesIO(garment_img))  # decode the raw bytes fetched from GCS into a PIL image
         cols[i % num_columns].image(garment_img, width=100, use_column_width=True, caption=filename[:-4])
         # if st.checkbox(filename[:-4]) :
         #     return True, garment_byte
@@ -102,8 +105,11 @@ def show_garments_and_checkboxes(category):
     filenames_.extend([f[:-4] for f in filenames])
     selected_garment = st.selectbox('입을 옷을 선택해주세요.', filenames_, index=0, key=category)
     print('selected_garment', selected_garment)
+
+    im_dir = os.path.join(category_dir, f'{selected_garment}.jpg')
+    garment_byte = gcs.read_image_from_gcs(im_dir)
 
-    return filenames, selected_garment
+    return garment_byte, selected_garment
 
 def md_style():
     st.markdown(
@@ -196,10 +202,11 @@ def main():
             if uploaded_garment :
                 append_imgList(uploaded_garment, category)
 
-            filenames, selected_upper = show_garments_and_checkboxes(category)
+            selected_byte, selected_upper = show_garments_and_checkboxes(category)
             if selected_upper :
                 is_selected_upper = True
-                files[2] = ('files', f'{selected_upper}.jpg')
+                # files[2] = ('files', f'{selected_upper}.jpg')
+                files[2] = ('files', selected_byte)
                 print('selected_upper', selected_upper)
 
     with col3:
@@ -212,7 +219,7 @@ def main():
             if uploaded_garment :
                 append_imgList(uploaded_garment, category)
 
-            filenames, selected_lower = show_garments_and_checkboxes(category)
+            selected_byte, selected_lower = show_garments_and_checkboxes(category)
             if selected_lower :
                 is_selected_lower = True
-                files[3] = ('files', f'{selected_lower}.jpg')
+                files[3] = ('files', selected_byte)
@@ -264,7 +271,6 @@ def main():
             category = 'upper_lower'
 
         elif is_selected_upper :
-            print('catogory upperrr')
             gen_start = True
             category = 'upper_body'
         elif is_selected_lower :
@@ -280,9 +286,6 @@ def main():
 
         files[0] = ('files', category)
         files[1] = ('files', (uploaded_target.name, target_bytes, uploaded_target.type))
-        print('category', category)
-        print('files2', files[2])
-        print('files3', files[3])
 
 
     if gen_start :
diff --git a/backend/app/main.py b/backend/app/main.py
index 0fce236..aa7ae1e 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -10,7 +10,7 @@
 # schp setting
 import sys, os
 sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/Self_Correction_Human_Parsing/')
-from simple_extractor import main_schp
+from simple_extractor import main_schp, main_schp_from_image_byte
 
 # openpose
 sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/pytorch_openpose/')
@@ -45,7 +45,7 @@
 db_dir = '/opt/ml/user_db'
 
 gcp_config = load_gcp_config_from_yaml("/opt/ml/level3_cv_finalproject-cv-12/backend/config/gcs.yaml")
-gcs_uploader = GCSUploader(gcp_config)
+gcs = GCSUploader(gcp_config)
 user_name = 'hi'
 
 @app.post("/add_data", description="데이터 저장")
@@ -62,7 +62,7 @@ async def add_garment_to_db(files: List[UploadFile] = File(...)):
         garment_image = Image.open(io.BytesIO(garment_bytes))
         garment_image = garment_image.convert("RGB")
 
-        gcs_uploader.upload_blob(garment_bytes, os.path.join(user_name, 'input/garment', category, f'{garment_name}'))
+        gcs.upload_blob(garment_bytes, os.path.join(user_name, 'input/garment', category, f'{garment_name}'))
         # garment_image.save(os.path.join(db_dir, 'input/garment', category, f'{garment_name}'))
 
 def read_image_as_bytes(image_path):
@@ -130,13 +130,16 @@ async def get_boolean():
     global is_modelLoading
     return {"is_modelLoading": is_modelLoading}
 
-def inference_allModels(category, db_dir):
+def inference_allModels(target_bytes, garment_bytes, category, db_dir):
     input_dir = os.path.join(db_dir, 'input')
 
     # schp - (1024, 784), (512, 384)
     target_buffer_dir = os.path.join(input_dir, 'buffer/target')
-    main_schp(target_buffer_dir)
-
+    # main_schp(target_buffer_dir)
+    schp_img = main_schp_from_image_byte(target_bytes)
+    schp_img.save('./schp.png')  # debug: dump the parsing map locally for inspection
+
+    exit()  # NOTE: temporary early-exit while the byte pipeline is verified; remove to run the full pipeline
    # openpose
     output_openpose_buffer_dir = os.path.join(db_dir, 'openpose/buffer')
     os.makedirs(output_openpose_buffer_dir, exist_ok=True)
@@ -172,7 +175,8 @@ def inference_ladi(category, db_dir, target_name='target.jpg'):
 
 @app.post("/order", description="주문을 요청합니다")
 async def make_order(files: List[UploadFile] = File(...)):
-    input_dir = '/opt/ml/user_db/input/'
+    # input_dir = '/opt/ml/user_db/input/'
+    input_dir = f'{user_name}/input'
 
     # category : files[0], target:files[1], garment:files[2]
     byte_string = await files[0].read()
@@ -186,9 +190,9 @@ async def make_order(files: List[UploadFile] = File(...)):
     target_image = target_image.convert("RGB")
 
     os.makedirs(f'{input_dir}/buffer', exist_ok=True)
+    # target_image.save(f'{input_dir}/buffer/target/target.jpg')
 
-    # target_image.save(f'{input_dir}/target.jpg')
-    target_image.save(f'{input_dir}/buffer/target/target.jpg')
+    gcs.upload_blob(target_bytes, f'{input_dir}/buffer/target/target.jpg')
 
     if category == 'upper_lower':
         # garment_upper_bytes = await files[2].read()
@@ -226,19 +230,21 @@ async def make_order(files: List[UploadFile] = File(...)):
 
     else :
         ## when the garment was sent as a file
-        # garment_bytes = await files[2].read()
-        # garment_image = Image.open(io.BytesIO(garment_bytes))
-        # garment_image = garment_image.convert("RGB")
-        # garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')
+
+        garment_bytes = await files[2].read()
+        garment_image = Image.open(io.BytesIO(garment_bytes))
+        garment_image = garment_image.convert("RGB")
 
         ## when it was sent as a filename string
-        byte_string = await files[2].read()
-        string_io = io.BytesIO(byte_string)
-        filename = string_io.read().decode('utf-8')
+        # byte_string = await files[2].read()
+        # string_io = io.BytesIO(byte_string)
+        # filename = string_io.read().decode('utf-8')
+
+        # garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename))
+        # garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')
 
-        garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename))
-        garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')
+        gcs.upload_blob(garment_bytes, f'{input_dir}/buffer/garment/{category}.jpg')
 
-    inference_allModels(category, db_dir)
+    inference_allModels(target_bytes, garment_bytes, category, user_name)  # user_name doubles as the GCS-relative db root
 
     return None
\ No newline at end of file
diff --git a/model/Self_Correction_Human_Parsing/datasets/simple_extractor_dataset.py b/model/Self_Correction_Human_Parsing/datasets/simple_extractor_dataset.py
index 3a00204..957d7a2 100644
--- a/model/Self_Correction_Human_Parsing/datasets/simple_extractor_dataset.py
+++ b/model/Self_Correction_Human_Parsing/datasets/simple_extractor_dataset.py
@@ -76,3 +76,58 @@ def __getitem__(self, index):
         }
 
         return input, meta
+
+class SimpleImageData(data.Dataset):
+    """Dataset wrapper around a single image held in memory as raw bytes."""
+    def __init__(self, img_byte, input_size=[512, 512], transform=None):
+        self.img_byte = img_byte
+        self.transform = transform
+        self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
+        self.input_size = np.asarray(input_size)
+
+    def __len__(self):
+        return 1
+
+    def _box2cs(self, box):
+        x, y, w, h = box[:4]
+        return self._xywh2cs(x, y, w, h)
+
+    def _xywh2cs(self, x, y, w, h):
+        center = np.zeros((2), dtype=np.float32)
+        center[0] = x + w * 0.5
+        center[1] = y + h * 0.5
+        if w > self.aspect_ratio * h:
+            h = w * 1.0 / self.aspect_ratio
+        elif w < self.aspect_ratio * h:
+            w = h * self.aspect_ratio
+        scale = np.array([w, h], dtype=np.float32)
+        return center, scale
+
+    def __getitem__(self, index):
+        nparr = np.frombuffer(self.img_byte, np.uint8)
+        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)  # decode the raw bytes into a BGR image
+        h, w, _ = img.shape
+
+        # Get person center and scale
+        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
+        r = 0
+        trans = get_affine_transform(person_center, s, r, self.input_size)
+        input = cv2.warpAffine(
+            img,
+            trans,
+            (int(self.input_size[1]), int(self.input_size[0])),
+            flags=cv2.INTER_LINEAR,
+            borderMode=cv2.BORDER_CONSTANT,
+            borderValue=(0, 0, 0))
+
+        input = self.transform(input)
+        meta = {
+            'name': 'target.jpg',  # fixed name for the single-image case
+            'center': person_center,
+            'height': h,
+            'width': w,
+            'scale': s,
+            'rotation': r
+        }
+
+        return input, meta
\ No newline at end of file
diff --git a/model/Self_Correction_Human_Parsing/simple_extractor.py b/model/Self_Correction_Human_Parsing/simple_extractor.py
index 436a4e2..4a64c67 100644
--- a/model/Self_Correction_Human_Parsing/simple_extractor.py
+++ b/model/Self_Correction_Human_Parsing/simple_extractor.py
@@ -23,7 +23,7 @@
 
 import networks
 from utils.transforms import transform_logits
-from datasets.simple_extractor_dataset import SimpleFolderDataset
+from datasets.simple_extractor_dataset import SimpleFolderDataset, SimpleImageData
 
 dataset_settings = {
     'lip': {
@@ -87,6 +87,43 @@ def get_palette(num_cls):
             lab >>= 3
     return palette
+def get_metaData(image, input_size):
+    """Warp a single image to the model input size and return it with its metadata."""
+    img = image
+    h, w, _ = img.shape
+
+    # Get person center and scale from the full-image box (same logic as SimpleImageData._xywh2cs)
+    aspect_ratio = input_size[1] * 1.0 / input_size[0]
+    box_w, box_h = w - 1, h - 1
+    person_center = np.array([box_w * 0.5, box_h * 0.5], dtype=np.float32)
+    if box_w > aspect_ratio * box_h:
+        box_h = box_w * 1.0 / aspect_ratio
+    elif box_w < aspect_ratio * box_h:
+        box_w = box_h * aspect_ratio
+    s = np.array([box_w, box_h], dtype=np.float32)
+    r = 0
+
+    import cv2
+    from utils.transforms import get_affine_transform
+    trans = get_affine_transform(person_center, s, r, input_size)
+    input = cv2.warpAffine(
+        img,
+        trans,
+        (int(input_size[1]), int(input_size[0])),
+        flags=cv2.INTER_LINEAR,
+        borderMode=cv2.BORDER_CONSTANT,
+        borderValue=(0, 0, 0))
+
+    meta = {
+        'center': person_center,
+        'height': h,
+        'width': w,
+        'scale': s,
+        'rotation': r
+    }
+
+    return input, meta
+
 
 
 def main_schp(target_buffer_dir):
 
@@ -152,6 +179,67 @@ def main_schp(target_buffer_dir):
                 np.save(logits_result_path, logits_result)
     return
 
+def main_schp_from_image_byte(image_byte, dataset='atr'):
+    args = get_arguments()
+
+    gpus = [int(i) for i in args.gpu.split(',')]
+    assert len(gpus) == 1
+    if not args.gpu == 'None':
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
+
+    num_classes = dataset_settings[dataset]['num_classes']
+    input_size = dataset_settings[dataset]['input_size']
+    label = dataset_settings[dataset]['label']
+    print("Evaluating total class number {} with {}".format(num_classes, label))
+
+    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
+
+    state_dict = torch.load(args.model_restore)['state_dict']
+    from collections import OrderedDict
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = k[7:]  # remove `module.`
+        new_state_dict[name] = v
+    model.load_state_dict(new_state_dict)
+    model.cuda()
+    model.eval()
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
+    ])
+    ds = SimpleImageData(img_byte=image_byte, input_size=input_size, transform=transform)
+    dataloader = DataLoader(ds)
+
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+
+    palette = get_palette(num_classes)
+    with torch.no_grad():
+        for idx, batch in enumerate(tqdm(dataloader)):
+            image, meta = batch
+            img_name = meta['name'][0]
+            c = meta['center'].numpy()[0]
+            s = meta['scale'].numpy()[0]
+            w = meta['width'].numpy()[0]
+            h = meta['height'].numpy()[0]
+
+            output = model(image.cuda())
+            upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
+            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
+            upsample_output = upsample_output.squeeze()
+            upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC
+
+            logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=input_size)
+            parsing_result = np.argmax(logits_result, axis=2)
+            parsing_result_path = os.path.join(args.output_dir, img_name[:-4] + '.png')
+            output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
+            output_img.putpalette(palette)
+            output_img.save(parsing_result_path)
+            if args.logits:
+                logits_result_path = os.path.join(args.output_dir, img_name[:-4] + '.npy')
+                np.save(logits_result_path, logits_result)
+    return output_img
 
 
 # if __name__ == '__main__':
 #     main(target_buffer_dir)
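
Usage sketch for the new entry point: a minimal, hedged example of how main_schp_from_image_byte is expected to be called, assuming the SCHP checkpoint and output directory are resolved through get_arguments() defaults (as in main_schp); 'target.jpg' below is an illustrative local file, not part of the patch.

# Run SCHP parsing on a single in-memory image.
import sys
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/Self_Correction_Human_Parsing/')
from simple_extractor import main_schp_from_image_byte

with open('target.jpg', 'rb') as f:  # any photo of a person
    target_bytes = f.read()

# Returns a palettized PIL.Image; the map is also written to the configured output dir.
parsing_map = main_schp_from_image_byte(target_bytes, dataset='atr')
parsing_map.save('./schp.png')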