diff --git a/backend/app/frontend.py b/backend/app/frontend.py
index e485c28..cecf3e8 100644
--- a/backend/app/frontend.py
+++ b/backend/app/frontend.py
@@ -56,10 +56,6 @@ def check_modelLoading():
pass
return is_modelLoading
-def read_image_as_bytes(image_path):
- with open(image_path, "rb") as file:
- image_data = file.read()
- return image_data
## 이미지 리스트에 저장
def append_imgList(uploaded_garment, category):
@@ -89,7 +85,6 @@ def show_garments_and_checkboxes(category):
for i, filename in enumerate(filenames):
im_dir = os.path.join(category_dir, filename)
# garment_img = Image.open(im_dir)
- # garment_byte = read_image_as_bytes(im_dir)
garment_img = gcs.read_image_from_gcs(im_dir)
# st.image(garment_img, caption=filename[:-4], width=100)
@@ -104,7 +99,6 @@ def show_garments_and_checkboxes(category):
filenames_ = [None]
filenames_.extend([f[:-4] for f in filenames])
selected_garment = st.selectbox('입을 옷을 선택해주세요.', filenames_, index=0, key=category)
- print('selected_garment', selected_garment)
im_dir = os.path.join(category_dir, f'{selected_garment}.jpg')
garment_byte = gcs.read_image_from_gcs(im_dir)
@@ -205,7 +199,6 @@ def main():
selected_byte, selected_upper = show_garments_and_checkboxes(category)
if selected_upper :
is_selected_upper = True
- # files[2] = ('files', f'{selected_upper}.jpg')
files[2] = ('files', selected_byte)
print('selected_upper', selected_upper)
@@ -222,25 +215,21 @@ def main():
selected_byte, selected_lower = show_garments_and_checkboxes(category)
if selected_lower :
is_selected_lower = True
- files[3] = ('files', f'{selected_lower}.jpg')
+ files[3] = ('files', selected_byte)
st.write(' ')
st.write(' ')
st.markdown("
", unsafe_allow_html=True)
category = 'dresses'
-
uploaded_garment = st.file_uploader("추가할 드레스를 넣어주세요.", type=["jpg", "jpeg", "png"])
if uploaded_garment :
append_imgList(uploaded_garment, category)
- filenames, selected_dress = show_garments_and_checkboxes(category)
+ selected_byte, selected_dress = show_garments_and_checkboxes(category)
if selected_dress :
is_selected_dress = True
- files[2] = ('files', f'{selected_dress}.jpg')
- print('is_selected_lower', is_selected_lower)
- print('is_selected_dress', is_selected_dress)
-
+ files[2] = ('files', selected_byte)
with col2:
st.markdown("", unsafe_allow_html=True)
@@ -259,11 +248,6 @@ def main():
human_slot.empty()
human_slot.image(target_img)
- # else :
-
- # example_img = Image.open('/opt/ml/level3_cv_finalproject-cv-12/backend/app/utils/example.jpg')
- # human_slot.image(example_img, width=300, use_column_width=True, caption='Example of target image')
-
print('start_button', start_button)
if start_button and uploaded_target:
if is_selected_upper and is_selected_lower :
@@ -305,13 +289,7 @@ def main():
empty_slot.empty()
empty_slot.markdown("Here it is !
", unsafe_allow_html=True)
- output_ladi_buffer_dir = '/opt/ml/user_db/ladi/buffer'
- final_result_dir = output_ladi_buffer_dir
- if category =='upper_lower':
- final_img = Image.open(os.path.join(final_result_dir, 'lower_body.png'))
- else :
- # final_img = Image.open(os.path.join(final_result_dir, f'{category}.png'))
- final_img = response.content
+ final_img = response.content
st.write(' ')
st.write(' ')
diff --git a/backend/app/main.py b/backend/app/main.py
index 8f7cbf9..0d17eab 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -9,21 +9,21 @@
# scp setting
import sys, os
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/Self_Correction_Human_Parsing/')
-from simple_extractor import main_schp, main_schp_fromImageByte
+from simple_extractor import main_schp
# openpose
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/pytorch_openpose/')
-from extract_keypoint import main_openpose, main_openpose_fromImageByte
+from extract_keypoint import main_openpose
# ladi
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton')
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src')
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src/utils')
-from get_clothing_mask import main_mask, main_mask_fromImageByte
+from get_clothing_mask import main_mask
-from inference import main_ladi, main_ladi_fromImageByte
-from face_cut_and_paste import main_cut_and_paste
+from inference import main_ladi
+from face_cut_and_paste import cut_and_paste
import torch
from accelerate import Accelerator
@@ -68,21 +68,6 @@ async def add_garment_to_db(files: List[UploadFile] = File(...)):
gcs.upload_blob(garment_bytes, os.path.join(user_name, 'input/garment', category, f'{garment_name}'))
# garment_image.save(os.path.join(db_dir, 'input/garment', category, f'{garment_name}'))
-def read_image_as_bytes(image_path):
- with open(image_path, "rb") as file:
- image_data = file.read()
- return image_data
-
-@app.get("/get_db/{category}")
-async def get_DB(category: str) :
- category_dir = os.path.join(db_dir, 'input/garment', category)
- garment_db_bytes = {}
- for filename in os.listdir(category_dir):
- garment_id = filename[:-4]
- garment_byte = read_image_as_bytes(os.path.join(category_dir, filename))
- garment_db_bytes[garment_id] = garment_byte
- return garment_db_bytes
-
def load_ladiModels():
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-inpainting"
@@ -133,61 +118,26 @@ async def get_boolean():
global is_modelLoading
return {"is_modelLoading": is_modelLoading}
-def inference_allModels(target_bytes, garment_bytes, category, db_dir):
-
- input_dir = os.path.join(db_dir, 'input')
+def inference_preprocess(target_bytes, garment_bytes, garment_lower_bytes=None):
# schp - (1024, 784), (512, 384)
- target_buffer_dir = os.path.join(input_dir, 'buffer/target')
- # main_schp(target_buffer_dir)
- schp_img = main_schp_fromImageByte(target_bytes)
- schp_img.save('./schp.png')
-
+ schp_img = main_schp(target_bytes)
+
# openpose
- output_openpose_buffer_dir = os.path.join(db_dir, 'openpose/buffer')
- os.makedirs(output_openpose_buffer_dir, exist_ok=True)
- # main_openpose(target_buffer_dir, output_openpose_buffer_dir)
- keypoint_dict = main_openpose_fromImageByte(target_bytes)
- gcs.upload_dict_as_json_to_gcs(keypoint_dict, os.path.join(db_dir, 'openpose/buffer/target.json'))
-
- # /opt/ml/user_db/mask/buffer
- # mask
- garment_dir = os.path.join(input_dir, 'buffer/garment')
- output_mask_dir = os.path.join(db_dir, 'mask/buffer')
- os.makedirs(output_mask_dir, exist_ok=True)
- # main_mask(category, garment_dir, output_mask_dir)
+ keypoint_dict = main_openpose(target_bytes)
## garment_mask 형식 - Image
- garment_mask = main_mask_fromImageByte(garment_bytes)
-
- garment_mask.save('./garment_mask.jpg')
-
- # ladi-vton
- output_ladi_buffer_dir = os.path.join(db_dir, 'ladi/buffer')
- os.makedirs(output_ladi_buffer_dir, exist_ok=True)
-
- # main_ladi(category, db_dir, output_ladi_buffer_dir, ladi_models)
- finalResult_img = main_ladi_fromImageByte(category, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models)
- finalResult_img = main_cut_and_paste(category, target_bytes, finalResult_img, schp_img)
- return finalResult_img
-
-def inference_ladi(category, db_dir, target_name='target.jpg'):
- input_dir = os.path.join(db_dir, 'input')
- garment_dir = os.path.join(input_dir, 'buffer/garment')
- output_mask_dir = os.path.join(db_dir, 'mask/buffer')
- main_mask(category, garment_dir, output_mask_dir)
-
- # ladi-vton
- output_ladi_buffer_dir = os.path.join(db_dir, 'ladi/buffer')
- os.makedirs(output_ladi_buffer_dir, exist_ok=True)
-
- main_ladi(category, db_dir, output_ladi_buffer_dir, ladi_models, target_name)
- main_cut_and_paste(category, db_dir, target_name)
+ garment_mask = main_mask(garment_bytes)
+ if garment_lower_bytes is None :
+ return schp_img, keypoint_dict, garment_mask
+ else :
+ garment_lower_mask = main_mask(garment_lower_bytes)
+
+ return schp_img, keypoint_dict, garment_mask, garment_lower_mask
# post!!
@app.post("/order", description="주문을 요청합니다")
async def make_order(files: List[UploadFile] = File(...)):
- # input_dir = '/opt/ml/user_db/input/'
input_dir = f'{user_name}/input'
# category : files[0], target:files[1], garment:files[2]
@@ -198,67 +148,28 @@ async def make_order(files: List[UploadFile] = File(...)):
## category가 upper & lower일 경우
target_bytes = await files[1].read()
- target_image = Image.open(io.BytesIO(target_bytes))
- target_image = target_image.convert("RGB")
-
- os.makedirs(f'{input_dir}/buffer', exist_ok=True)
- # target_image.save(f'{input_dir}/buffer/target/target.jpg')
-
gcs.upload_blob(target_bytes, f'{input_dir}/buffer/target/target.jpg')
if category == 'upper_lower':
- # garment_upper_bytes = await files[2].read()
- # garment_lower_bytes = await files[3].read()
-
- # garment_upper_image = Image.open(io.BytesIO(garment_upper_bytes))
- # garment_upper_image = garment_upper_image.convert("RGB")
- # garment_lower_image = Image.open(io.BytesIO(garment_lower_bytes))
- # garment_lower_image = garment_lower_image.convert("RGB")
-
- # # garment_upper_image.save(f'{input_dir}/upper_body.jpg')
- # garment_upper_image.save(f'{input_dir}/buffer/garment/upper_body.jpg')
- # # garment_lower_image.save(f'{input_dir}/lower_body.jpg')
- # garment_lower_image.save(f'{input_dir}/buffer/garment/lower_body.jpg')
-
-
- ## string으로 전송됐을 때(filename)
- string_upper_bytes = await files[2].read()
- string_lower_bytes = await files[3].read()
- string_io_upper = io.BytesIO(string_upper_bytes)
- string_io_lower = io.BytesIO(string_lower_bytes)
- filename_upper = string_io_upper.read().decode('utf-8')
- filename_lower = string_io_lower.read().decode('utf-8')
-
- garment_image_upper = Image.open(os.path.join(db_dir, 'input/garment', 'upper_body', filename_upper))
- garment_image_lower = Image.open(os.path.join(db_dir, 'input/garment', 'lower_body', filename_lower))
- garment_image_upper.save(f'{input_dir}/buffer/garment/upper_body.jpg')
- garment_image_lower.save(f'{input_dir}/buffer/garment/lower_body.jpg')
-
+ garment_upper_bytes = await files[2].read()
+ garment_lower_bytes = await files[3].read()
- finalResult_img = inference_allModels('upper_body', db_dir)
- shutil.copy(os.path.join(db_dir, 'ladi/buffer', 'upper_body.png'), f'{input_dir}/buffer/target/upper_body.jpg')
- inference_ladi('lower_body', db_dir, target_name='upper_body.jpg')
+ schp_img, keypoint_dict, garment_upper_mask, garment_lower_mask = inference_preprocess(target_bytes, garment_upper_bytes, garment_lower_bytes)
+ ladi_img = main_ladi('upper_body', target_bytes, schp_img, keypoint_dict, garment_upper_bytes, garment_upper_mask, ladi_models)
+ ladi_bytes = PIL2Byte(ladi_img)
+ ladi_img = main_ladi('lower_body', ladi_bytes, schp_img, keypoint_dict, garment_lower_bytes, garment_lower_mask, ladi_models)
+ finalResult_img = cut_and_paste(target_bytes, ladi_img, schp_img)
else :
## file로 전송됐을 때
-
garment_bytes = await files[2].read()
- garment_image = Image.open(io.BytesIO(garment_bytes))
- garment_image = garment_image.convert("RGB")
-
- ## string으로 전송됐을 때(filename)
- # byte_string = await files[2].read()
- # string_io = io.BytesIO(byte_string)
- # filename = string_io.read().decode('utf-8')
-
- # garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename))
- # garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')
- gcs.upload_blob(garment_bytes, f'{input_dir}/buffer/garment/{category}.jpg')
-
- finalResult_img = inference_allModels(target_bytes, garment_bytes, category, user_name)
+ schp_img, keypoint_dict, garment_mask = inference_preprocess(target_bytes, garment_bytes)
+ ladi_img = main_ladi(category, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models)
+ finalResult_img = cut_and_paste(target_bytes, ladi_img, schp_img)
finalResult_bytes = PIL2Byte(finalResult_img)
gcs.upload_blob(finalResult_bytes, f'{input_dir}/ladi/buffer/final.jpg')
+
return StreamingResponse(io.BytesIO(finalResult_bytes), media_type="image/jpg")
\ No newline at end of file