Skip to content

Commit

Permalink
[Refactor] upper & lower까지 동작 #33
Browse files Browse the repository at this point in the history
- main.py의 inference 함수 다시 구현(inference_preprocess 함수)

related to : #31
  • Loading branch information
Hyunmin-H committed Aug 17, 2023
1 parent b0c9d12 commit 800a2bb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 142 deletions.
30 changes: 4 additions & 26 deletions backend/app/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ def check_modelLoading():
pass
return is_modelLoading

def read_image_as_bytes(image_path):
with open(image_path, "rb") as file:
image_data = file.read()
return image_data
## 이미지 리스트에 저장
def append_imgList(uploaded_garment, category):

Expand Down Expand Up @@ -89,7 +85,6 @@ def show_garments_and_checkboxes(category):
for i, filename in enumerate(filenames):
im_dir = os.path.join(category_dir, filename)
# garment_img = Image.open(im_dir)
# garment_byte = read_image_as_bytes(im_dir)
garment_img = gcs.read_image_from_gcs(im_dir)
# st.image(garment_img, caption=filename[:-4], width=100)

Expand All @@ -104,7 +99,6 @@ def show_garments_and_checkboxes(category):
filenames_ = [None]
filenames_.extend([f[:-4] for f in filenames])
selected_garment = st.selectbox('입을 옷을 선택해주세요.', filenames_, index=0, key=category)
print('selected_garment', selected_garment)

im_dir = os.path.join(category_dir, f'{selected_garment}.jpg')
garment_byte = gcs.read_image_from_gcs(im_dir)
Expand Down Expand Up @@ -205,7 +199,6 @@ def main():
selected_byte, selected_upper = show_garments_and_checkboxes(category)
if selected_upper :
is_selected_upper = True
# files[2] = ('files', f'{selected_upper}.jpg')
files[2] = ('files', selected_byte)
print('selected_upper', selected_upper)

Expand All @@ -222,25 +215,21 @@ def main():
selected_byte, selected_lower = show_garments_and_checkboxes(category)
if selected_lower :
is_selected_lower = True
files[3] = ('files', f'{selected_lower}.jpg')
files[3] = ('files', selected_byte)

st.write(' ')
st.write(' ')
st.markdown("<h3 class='center-aligned-header'>드레스👗</h3>", unsafe_allow_html=True)
category = 'dresses'

uploaded_garment = st.file_uploader("추가할 드레스를 넣어주세요.", type=["jpg", "jpeg", "png"])

if uploaded_garment :
append_imgList(uploaded_garment, category)

filenames, selected_dress = show_garments_and_checkboxes(category)
selected_byte, selected_dress = show_garments_and_checkboxes(category)
if selected_dress :
is_selected_dress = True
files[2] = ('files', f'{selected_dress}.jpg')
print('is_selected_lower', is_selected_lower)
print('is_selected_dress', is_selected_dress)

files[2] = ('files', selected_byte)

with col2:
st.markdown("<h3 class='center-aligned-header'>드레스룸🚪</h3>", unsafe_allow_html=True)
Expand All @@ -259,11 +248,6 @@ def main():
human_slot.empty()
human_slot.image(target_img)

# else :

# example_img = Image.open('/opt/ml/level3_cv_finalproject-cv-12/backend/app/utils/example.jpg')
# human_slot.image(example_img, width=300, use_column_width=True, caption='Example of target image')

print('start_button', start_button)
if start_button and uploaded_target:
if is_selected_upper and is_selected_lower :
Expand Down Expand Up @@ -305,13 +289,7 @@ def main():
empty_slot.empty()
empty_slot.markdown("<h2 style='text-align: center;'>Here it is !</h2>", unsafe_allow_html=True)

output_ladi_buffer_dir = '/opt/ml/user_db/ladi/buffer'
final_result_dir = output_ladi_buffer_dir
if category =='upper_lower':
final_img = Image.open(os.path.join(final_result_dir, 'lower_body.png'))
else :
# final_img = Image.open(os.path.join(final_result_dir, f'{category}.png'))
final_img = response.content
final_img = response.content

st.write(' ')
st.write(' ')
Expand Down
143 changes: 27 additions & 116 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
# scp setting
import sys, os
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/Self_Correction_Human_Parsing/')
from simple_extractor import main_schp, main_schp_fromImageByte
from simple_extractor import main_schp

# openpose
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/pytorch_openpose/')
from extract_keypoint import main_openpose, main_openpose_fromImageByte
from extract_keypoint import main_openpose

# ladi
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton')
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src')
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/ladi_vton/src/utils')

from get_clothing_mask import main_mask, main_mask_fromImageByte
from get_clothing_mask import main_mask

from inference import main_ladi, main_ladi_fromImageByte
from face_cut_and_paste import main_cut_and_paste
from inference import main_ladi
from face_cut_and_paste import cut_and_paste

import torch
from accelerate import Accelerator
Expand Down Expand Up @@ -68,21 +68,6 @@ async def add_garment_to_db(files: List[UploadFile] = File(...)):
gcs.upload_blob(garment_bytes, os.path.join(user_name, 'input/garment', category, f'{garment_name}'))
# garment_image.save(os.path.join(db_dir, 'input/garment', category, f'{garment_name}'))

def read_image_as_bytes(image_path):
with open(image_path, "rb") as file:
image_data = file.read()
return image_data

@app.get("/get_db/{category}")
async def get_DB(category: str) :
category_dir = os.path.join(db_dir, 'input/garment', category)
garment_db_bytes = {}
for filename in os.listdir(category_dir):
garment_id = filename[:-4]
garment_byte = read_image_as_bytes(os.path.join(category_dir, filename))
garment_db_bytes[garment_id] = garment_byte
return garment_db_bytes

def load_ladiModels():
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-inpainting"

Expand Down Expand Up @@ -133,61 +118,26 @@ async def get_boolean():
global is_modelLoading
return {"is_modelLoading": is_modelLoading}

def inference_allModels(target_bytes, garment_bytes, category, db_dir):

input_dir = os.path.join(db_dir, 'input')
def inference_preprocess(target_bytes, garment_bytes, garment_lower_bytes=None):
# schp - (1024, 784), (512, 384)
target_buffer_dir = os.path.join(input_dir, 'buffer/target')
# main_schp(target_buffer_dir)
schp_img = main_schp_fromImageByte(target_bytes)
schp_img.save('./schp.png')

schp_img = main_schp(target_bytes)

# openpose
output_openpose_buffer_dir = os.path.join(db_dir, 'openpose/buffer')
os.makedirs(output_openpose_buffer_dir, exist_ok=True)
# main_openpose(target_buffer_dir, output_openpose_buffer_dir)
keypoint_dict = main_openpose_fromImageByte(target_bytes)
gcs.upload_dict_as_json_to_gcs(keypoint_dict, os.path.join(db_dir, 'openpose/buffer/target.json'))

# /opt/ml/user_db/mask/buffer
# mask
garment_dir = os.path.join(input_dir, 'buffer/garment')
output_mask_dir = os.path.join(db_dir, 'mask/buffer')
os.makedirs(output_mask_dir, exist_ok=True)
# main_mask(category, garment_dir, output_mask_dir)
keypoint_dict = main_openpose(target_bytes)

## garment_mask 형식 - Image
garment_mask = main_mask_fromImageByte(garment_bytes)

garment_mask.save('./garment_mask.jpg')

# ladi-vton
output_ladi_buffer_dir = os.path.join(db_dir, 'ladi/buffer')
os.makedirs(output_ladi_buffer_dir, exist_ok=True)

# main_ladi(category, db_dir, output_ladi_buffer_dir, ladi_models)
finalResult_img = main_ladi_fromImageByte(category, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models)
finalResult_img = main_cut_and_paste(category, target_bytes, finalResult_img, schp_img)
return finalResult_img

def inference_ladi(category, db_dir, target_name='target.jpg'):
input_dir = os.path.join(db_dir, 'input')
garment_dir = os.path.join(input_dir, 'buffer/garment')
output_mask_dir = os.path.join(db_dir, 'mask/buffer')
main_mask(category, garment_dir, output_mask_dir)

# ladi-vton
output_ladi_buffer_dir = os.path.join(db_dir, 'ladi/buffer')
os.makedirs(output_ladi_buffer_dir, exist_ok=True)

main_ladi(category, db_dir, output_ladi_buffer_dir, ladi_models, target_name)
main_cut_and_paste(category, db_dir, target_name)
garment_mask = main_mask(garment_bytes)
if garment_lower_bytes is None :
return schp_img, keypoint_dict, garment_mask
else :
garment_lower_mask = main_mask(garment_lower_bytes)

return schp_img, keypoint_dict, garment_mask, garment_lower_mask

# post!!
@app.post("/order", description="주문을 요청합니다")
async def make_order(files: List[UploadFile] = File(...)):

# input_dir = '/opt/ml/user_db/input/'
input_dir = f'{user_name}/input'

# category : files[0], target:files[1], garment:files[2]
Expand All @@ -198,67 +148,28 @@ async def make_order(files: List[UploadFile] = File(...)):
## category가 upper & lower일 경우
target_bytes = await files[1].read()

target_image = Image.open(io.BytesIO(target_bytes))
target_image = target_image.convert("RGB")

os.makedirs(f'{input_dir}/buffer', exist_ok=True)
# target_image.save(f'{input_dir}/buffer/target/target.jpg')

gcs.upload_blob(target_bytes, f'{input_dir}/buffer/target/target.jpg')

if category == 'upper_lower':
# garment_upper_bytes = await files[2].read()
# garment_lower_bytes = await files[3].read()

# garment_upper_image = Image.open(io.BytesIO(garment_upper_bytes))
# garment_upper_image = garment_upper_image.convert("RGB")
# garment_lower_image = Image.open(io.BytesIO(garment_lower_bytes))
# garment_lower_image = garment_lower_image.convert("RGB")

# # garment_upper_image.save(f'{input_dir}/upper_body.jpg')
# garment_upper_image.save(f'{input_dir}/buffer/garment/upper_body.jpg')
# # garment_lower_image.save(f'{input_dir}/lower_body.jpg')
# garment_lower_image.save(f'{input_dir}/buffer/garment/lower_body.jpg')


## string으로 전송됐을 때(filename)
string_upper_bytes = await files[2].read()
string_lower_bytes = await files[3].read()
string_io_upper = io.BytesIO(string_upper_bytes)
string_io_lower = io.BytesIO(string_lower_bytes)
filename_upper = string_io_upper.read().decode('utf-8')
filename_lower = string_io_lower.read().decode('utf-8')

garment_image_upper = Image.open(os.path.join(db_dir, 'input/garment', 'upper_body', filename_upper))
garment_image_lower = Image.open(os.path.join(db_dir, 'input/garment', 'lower_body', filename_lower))
garment_image_upper.save(f'{input_dir}/buffer/garment/upper_body.jpg')
garment_image_lower.save(f'{input_dir}/buffer/garment/lower_body.jpg')

garment_upper_bytes = await files[2].read()
garment_lower_bytes = await files[3].read()

finalResult_img = inference_allModels('upper_body', db_dir)
shutil.copy(os.path.join(db_dir, 'ladi/buffer', 'upper_body.png'), f'{input_dir}/buffer/target/upper_body.jpg')
inference_ladi('lower_body', db_dir, target_name='upper_body.jpg')
schp_img, keypoint_dict, garment_upper_mask, garment_lower_mask = inference_preprocess(target_bytes, garment_upper_bytes, garment_lower_bytes)
ladi_img = main_ladi('upper_body', target_bytes, schp_img, keypoint_dict, garment_upper_bytes, garment_upper_mask, ladi_models)
ladi_bytes = PIL2Byte(ladi_img)
ladi_img = main_ladi('lower_body', ladi_bytes, schp_img, keypoint_dict, garment_lower_bytes, garment_lower_mask, ladi_models)
finalResult_img = cut_and_paste(target_bytes, ladi_img, schp_img)


else :
## file로 전송됐을 때

garment_bytes = await files[2].read()
garment_image = Image.open(io.BytesIO(garment_bytes))
garment_image = garment_image.convert("RGB")

## string으로 전송됐을 때(filename)
# byte_string = await files[2].read()
# string_io = io.BytesIO(byte_string)
# filename = string_io.read().decode('utf-8')

# garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename))
# garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')

gcs.upload_blob(garment_bytes, f'{input_dir}/buffer/garment/{category}.jpg')

finalResult_img = inference_allModels(target_bytes, garment_bytes, category, user_name)
schp_img, keypoint_dict, garment_mask = inference_preprocess(target_bytes, garment_bytes)
ladi_img = main_ladi(category, target_bytes, schp_img, keypoint_dict, garment_bytes, garment_mask, ladi_models)
finalResult_img = cut_and_paste(target_bytes, ladi_img, schp_img)

finalResult_bytes = PIL2Byte(finalResult_img)
gcs.upload_blob(finalResult_bytes, f'{input_dir}/ladi/buffer/final.jpg')

return StreamingResponse(io.BytesIO(finalResult_bytes), media_type="image/jpg")

0 comments on commit 800a2bb

Please sign in to comment.