[Feat] Change the SCHP input format to bytes #33
- Implement an inference function in simple_extractor.py for the case where the input is a single image (the main_schp_from_image_byte function)
- Change the frontend/backend to pass images as bytes as well
Hyunmin-H committed Aug 16, 2023
1 parent b90ef84 commit e47e092
Showing 4 changed files with 184 additions and 32 deletions.
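
In outline, the commit makes images travel as raw bytes end to end: the Streamlit frontend attaches the selected garment's bytes to the multipart payload, the FastAPI backend reads them back with await files[i].read(), and the target bytes go straight into SCHP via main_schp_from_image_byte without an intermediate file on local disk. A minimal sketch of the client-side payload, mirroring the files[] layout built in frontend.py (the image paths, host, and port are illustrative):

import requests

# files[0]: category string, files[1]: target photo, files[2]: garment bytes,
# matching the layout frontend.py builds before posting to /order
with open('target.jpg', 'rb') as t, open('garment.jpg', 'rb') as g:
    files = [
        ('files', 'upper_body'),                            # category
        ('files', ('target.jpg', t.read(), 'image/jpeg')),  # target image with filename/type
        ('files', g.read()),                                # garment image as raw bytes
    ]
requests.post('http://localhost:8000/order', files=files)  # host/port assumed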
27 changes: 15 additions & 12 deletions backend/app/frontend.py
@@ -15,13 +15,12 @@
ASSETS_DIR_PATH = os.path.join(Path(__file__).parent.parent.parent.parent, "assets")
st.set_page_config(layout="wide")

root_password = 'a'
category_pair = {'Upper':'upper_body', 'Lower':'lower_body', 'Upper & Lower':'upper_lower', 'Dress':'dresses'}

db_dir = '/opt/ml/user_db'

gcp_config = load_gcp_config_from_yaml("/opt/ml/level3_cv_finalproject-cv-12/backend/config/gcs.yaml")
gcs_uploader = GCSUploader(gcp_config)
gcs = GCSUploader(gcp_config)

user_name = 'hi'

@@ -76,7 +75,7 @@ def append_imgList(uploaded_garment, category):
def show_garments_and_checkboxes(category):

category_dir = os.path.join(user_name, 'input/garment', category)
filenames = gcs_uploader.list_images_in_folder(category_dir)
filenames = gcs.list_images_in_folder(category_dir)
# filenames = os.listdir(category_dir)

# category_dir = os.path.join(db_dir, 'input/garment', category)
@@ -91,8 +90,12 @@ def show_garments_and_checkboxes(category):
im_dir = os.path.join(category_dir, filename)
# garment_img = Image.open(im_dir)
# garment_byte = read_image_as_bytes(im_dir)
garment_img = gcs_uploader.read_image_from_gcs(im_dir)
garment_img = gcs.read_image_from_gcs(im_dir)
# st.image(garment_img, caption=filename[:-4], width=100)

# decode the raw bytes returned from GCS into a PIL image for display
from PIL import Image
from io import BytesIO
garment_img = Image.open(BytesIO(garment_img))
cols[i % num_columns].image(garment_img, width=100, use_column_width=True, caption=filename[:-4])
# if st.checkbox(filename[:-4]) :
# return True, garment_byte
@@ -102,8 +105,11 @@
filenames_.extend([f[:-4] for f in filenames])
selected_garment = st.selectbox('입을 옷을 선택해주세요.', filenames_, index=0, key=category)
print('selected_garment', selected_garment)

im_dir = os.path.join(category_dir, f'{selected_garment}.jpg')
garment_byte = gcs.read_image_from_gcs(im_dir)

return filenames, selected_garment
return garment_byte, selected_garment

def md_style():
st.markdown(
@@ -196,10 +202,11 @@ def main():
if uploaded_garment :
append_imgList(uploaded_garment, category)

filenames, selected_upper = show_garments_and_checkboxes(category)
selected_byte, selected_upper = show_garments_and_checkboxes(category)
if selected_upper :
is_selected_upper = True
files[2] = ('files', f'{selected_upper}.jpg')
# files[2] = ('files', f'{selected_upper}.jpg')
files[2] = ('files', selected_byte)
print('selected_upper', selected_upper)

with col3:
@@ -212,7 +219,7 @@ def main():
if uploaded_garment :
append_imgList(uploaded_garment, category)

filenames, selected_lower = show_garments_and_checkboxes(category)
selected_byte, selected_lower = show_garments_and_checkboxes(category)
if selected_lower :
is_selected_lower = True
files[3] = ('files', f'{selected_lower}.jpg')
@@ -264,7 +271,6 @@ def main():
category = 'upper_lower'

elif is_selected_upper :
print('catogory upperrr')
gen_start = True
category = 'upper_body'
elif is_selected_lower :
@@ -280,9 +286,6 @@
files[0] = ('files', category)
files[1] = ('files', (uploaded_target.name, target_bytes,
uploaded_target.type))
print('category', category)
print('files2', files[2])
print('files3', files[3])

if gen_start :

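With this change the garment images live in GCS rather than on local disk; a minimal sketch of the round trip through the project's GCSUploader helper, using only the methods that appear in this diff (the import module and bucket layout are not shown here and are assumed):

# GCSUploader / load_gcp_config_from_yaml come from the project's backend
# utilities; their module path is not shown in this diff.
gcp_config = load_gcp_config_from_yaml("/opt/ml/level3_cv_finalproject-cv-12/backend/config/gcs.yaml")
gcs = GCSUploader(gcp_config)

with open('shirt.jpg', 'rb') as f:  # illustrative local file
    gcs.upload_blob(f.read(), 'hi/input/garment/upper_body/shirt.jpg')

names = gcs.list_images_in_folder('hi/input/garment/upper_body')        # e.g. ['shirt.jpg']
raw = gcs.read_image_from_gcs('hi/input/garment/upper_body/shirt.jpg')  # raw bytes back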
44 changes: 25 additions & 19 deletions backend/app/main.py
@@ -10,7 +10,7 @@
# scp setting
import sys, os
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/Self_Correction_Human_Parsing/')
from simple_extractor import main_schp
from simple_extractor import main_schp, main_schp_from_image_byte

# openpose
sys.path.append('/opt/ml/level3_cv_finalproject-cv-12/model/pytorch_openpose/')
@@ -45,7 +45,7 @@
db_dir = '/opt/ml/user_db'

gcp_config = load_gcp_config_from_yaml("/opt/ml/level3_cv_finalproject-cv-12/backend/config/gcs.yaml")
gcs_uploader = GCSUploader(gcp_config)
gcs = GCSUploader(gcp_config)
user_name = 'hi'

@app.post("/add_data", description="데이터 저장")
@@ -62,7 +62,7 @@ async def add_garment_to_db(files: List[UploadFile] = File(...)):
garment_image = Image.open(io.BytesIO(garment_bytes))
garment_image = garment_image.convert("RGB")

gcs_uploader.upload_blob(garment_bytes, os.path.join(user_name, 'input/garment', category, f'{garment_name}'))
gcs.upload_blob(garment_bytes, os.path.join(user_name, 'input/garment', category, f'{garment_name}'))
# garment_image.save(os.path.join(db_dir, 'input/garment', category, f'{garment_name}'))

def read_image_as_bytes(image_path):
@@ -130,13 +130,16 @@ async def get_boolean():
global is_modelLoading
return {"is_modelLoading": is_modelLoading}

def inference_allModels(category, db_dir):
def inference_allModels(target_bytes, garment_bytes, category, db_dir):

input_dir = os.path.join(db_dir, 'input')
# schp - (1024, 784), (512, 384)
target_buffer_dir = os.path.join(input_dir, 'buffer/target')
main_schp(target_buffer_dir)

# main_schp(target_buffer_dir)
schp_img = main_schp_from_image_byte(target_bytes)
schp_img.save('./schp.png')  # debug: dump the parsing map for inspection

exit()  # debug: stop the pipeline here while the byte-input path is being verified
# openpose
output_openpose_buffer_dir = os.path.join(db_dir, 'openpose/buffer')
os.makedirs(output_openpose_buffer_dir, exist_ok=True)
@@ -172,7 +175,8 @@ def inference_ladi(category, db_dir, target_name='target.jpg'):
@app.post("/order", description="주문을 요청합니다")
async def make_order(files: List[UploadFile] = File(...)):

input_dir = '/opt/ml/user_db/input/'
# input_dir = '/opt/ml/user_db/input/'
input_dir = f'{user_name}/input'

# category : files[0], target:files[1], garment:files[2]
byte_string = await files[0].read()
@@ -186,9 +190,9 @@ async def make_order(files: List[UploadFile] = File(...)):
target_image = target_image.convert("RGB")

os.makedirs(f'{input_dir}/buffer', exist_ok=True)
# target_image.save(f'{input_dir}/buffer/target/target.jpg')

# target_image.save(f'{input_dir}/target.jpg')
target_image.save(f'{input_dir}/buffer/target/target.jpg')
gcs.upload_blob(target_bytes, f'{input_dir}/buffer/target/target.jpg')

if category == 'upper_lower':
# garment_upper_bytes = await files[2].read()
@@ -226,19 +230,21 @@ async def make_order(files: List[UploadFile] = File(...)):

else :
## when the garment was sent as a file
# garment_bytes = await files[2].read()
# garment_image = Image.open(io.BytesIO(garment_bytes))
# garment_image = garment_image.convert("RGB")
# garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')

garment_bytes = await files[2].read()
garment_image = Image.open(io.BytesIO(garment_bytes))
garment_image = garment_image.convert("RGB")

## when the garment was sent as a string (filename)
byte_string = await files[2].read()
string_io = io.BytesIO(byte_string)
filename = string_io.read().decode('utf-8')
# byte_string = await files[2].read()
# string_io = io.BytesIO(byte_string)
# filename = string_io.read().decode('utf-8')

# garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename))
# garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')

garment_image = Image.open(os.path.join(db_dir, 'input/garment', category, filename))
garment_image.save(f'{input_dir}/buffer/garment/{category}.jpg')
gcs.upload_blob(garment_bytes, f'{input_dir}/buffer/garment/{category}.jpg')

inference_allModels(category, db_dir)
inference_allModels(target_bytes, garment_bytes, category, user_name)

return None
55 changes: 55 additions & 0 deletions model/Self_Correction_Human_Parsing/datasets/simple_extractor_dataset.py
@@ -76,3 +76,58 @@ def __getitem__(self, index):
}

return input, meta

# Single-image dataset: decodes one raw image byte string entirely in memory
class SimpleImageData(data.Dataset):
def __init__(self, img_byte, input_size=[512, 512], transform=None):
self.img_byte = img_byte
self.input_size = input_size
self.transform = transform
self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
self.input_size = np.asarray(input_size)

def __len__(self):
return 1

def _box2cs(self, box):
x, y, w, h = box[:4]
return self._xywh2cs(x, y, w, h)

def _xywh2cs(self, x, y, w, h):
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
if w > self.aspect_ratio * h:
h = w * 1.0 / self.aspect_ratio
elif w < self.aspect_ratio * h:
w = h * self.aspect_ratio
scale = np.array([w, h], dtype=np.float32)
return center, scale

def __getitem__(self, index):
# decode the byte buffer; cv2.imdecode returns BGR, matching cv2.imread
nparr = np.frombuffer(self.img_byte, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
h, w, _ = img.shape

# Get person center and scale
person_center, s = self._box2cs([0, 0, w - 1, h - 1])
r = 0
trans = get_affine_transform(person_center, s, r, self.input_size)
input = cv2.warpAffine(
img,
trans,
(int(self.input_size[1]), int(self.input_size[0])),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(0, 0, 0))

input = self.transform(input)
meta = {
'name': 'target.jpg',
'center': person_center,
'height': h,
'width': w,
'scale': s,
'rotation': r
}

return input, meta
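
As a sanity check, the new SimpleImageData class can be exercised on its own; a minimal sketch, reusing the BGR-order normalization from simple_extractor.py (the test image path is illustrative):

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets.simple_extractor_dataset import SimpleImageData  # as imported in simple_extractor.py

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]),
])

with open('person.jpg', 'rb') as f:  # illustrative test image
    ds = SimpleImageData(img_byte=f.read(), input_size=[512, 512], transform=transform)

image, meta = next(iter(DataLoader(ds)))           # batch of one
print(image.shape, meta['height'], meta['width'])  # torch.Size([1, 3, 512, 512]), H, W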
90 changes: 89 additions & 1 deletion model/Self_Correction_Human_Parsing/simple_extractor.py
@@ -23,7 +23,7 @@

import networks
from utils.transforms import transform_logits
from datasets.simple_extractor_dataset import SimpleFolderDataset
from datasets.simple_extractor_dataset import SimpleFolderDataset, SimpleImageData

dataset_settings = {
'lip': {
@@ -87,6 +87,33 @@ def get_palette(num_cls):
lab >>= 3
return palette

def get_metaData(image, input_size):
# Build the affine-warped network input and its metadata for a single image.
img = image
h, w, _ = img.shape

# Get person center and scale (whole-image box, as in SimpleImageData._xywh2cs)
x, y, bw, bh = 0, 0, w - 1, h - 1
aspect_ratio = input_size[1] * 1.0 / input_size[0]
person_center = np.array([x + bw * 0.5, y + bh * 0.5], dtype=np.float32)
if bw > aspect_ratio * bh:
bh = bw * 1.0 / aspect_ratio
elif bw < aspect_ratio * bh:
bw = bh * aspect_ratio
s = np.array([bw, bh], dtype=np.float32)
r = 0

from utils.transforms import get_affine_transform
trans = get_affine_transform(person_center, s, r, input_size)
input = cv2.warpAffine(
img,
trans,
(int(input_size[1]), int(input_size[0])),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(0, 0, 0))

meta = {
'center': person_center,
'height': h,
'width': w,
'scale': s,
'rotation': r
}

return input, meta


def main_schp(target_buffer_dir):

@@ -152,6 +179,67 @@ def main_schp(target_buffer_dir):
np.save(logits_result_path, logits_result)
return

def main_schp_from_image_byte(image_byte, dataset='atr'):
args = get_arguments()

gpus = [int(i) for i in args.gpu.split(',')]
assert len(gpus) == 1
if not args.gpu == 'None':
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# use the `dataset` argument rather than args.dataset so callers can pick the label set
num_classes = dataset_settings[dataset]['num_classes']
input_size = dataset_settings[dataset]['input_size']
label = dataset_settings[dataset]['label']
print("Evaluating total class number {} with {}".format(num_classes, label))

model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

state_dict = torch.load(args.model_restore)['state_dict']
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.cuda()
model.eval()

transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
])
eval_dataset = SimpleImageData(img_byte=image_byte, input_size=input_size, transform=transform)
dataloader = DataLoader(eval_dataset)  # batch of one; renamed to avoid shadowing the `dataset` argument

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)

palette = get_palette(num_classes)
with torch.no_grad():
for idx, batch in enumerate(tqdm(dataloader)):
image, meta = batch
img_name = meta['name'][0]
c = meta['center'].numpy()[0]
s = meta['scale'].numpy()[0]
w = meta['width'].numpy()[0]
h = meta['height'].numpy()[0]

output = model(image.cuda())
upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
upsample_output = upsample(output[0][-1][0].unsqueeze(0))
upsample_output = upsample_output.squeeze()
upsample_output = upsample_output.permute(1, 2, 0) # CHW -> HWC

logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=input_size)
parsing_result = np.argmax(logits_result, axis=2)
parsing_result_path = os.path.join(args.output_dir, img_name[:-4] + '.png')
output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
output_img.putpalette(palette)
output_img.save(parsing_result_path)
if args.logits:
logits_result_path = os.path.join(args.output_dir, img_name[:-4] + '.npy')
np.save(logits_result_path, logits_result)
return output_img

# if __name__ == '__main__':
# main(target_buffer_dir)
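
For completeness, the new entry point can be called directly on raw bytes, standing in for the commented-out main above; a minimal sketch (the image path is illustrative; the checkpoint and output_dir still come from get_arguments()):

if __name__ == '__main__':
    with open('target.jpg', 'rb') as f:  # illustrative input image
        parsing = main_schp_from_image_byte(f.read(), dataset='atr')
    parsing.save('parsing.png')  # palette-indexed PIL image, one class id per pixel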
