Skip to content

Commit

Permalink
Unbreak image datasets in model prediction component (#3469)
Browse files Browse the repository at this point in the history
* Intermediate state.

* Unit tests for image datasets.

* Build image datasets without using df append.

* Undo changes to download_dependencies.py.

* Cosmetics.

* More cosmetics.

* Even more cosmetics.
  • Loading branch information
rdondera-microsoft authored Oct 9, 2024
1 parent eed062d commit 597d616
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 11 deletions.
34 changes: 23 additions & 11 deletions assets/training/model_evaluation/src/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,20 @@ def get_classification_dataset(
# labels: {test_dataset_wrapper.num_classes}"
)

df = pd.DataFrame(columns=input_column_names + [label_column_name])
# Initialize the rows of the output dataframe to the empty list.
frame_rows = []

for index in range(len(test_dataset_wrapper)):
image_path = test_dataset_wrapper.get_image_full_path(index)
if is_valid_image(image_path):
                # Send image paths instead of base64-encoded strings, since the OSS flavor doesn't accept bytes as input.
df = df.append({
frame_rows.append({
input_column_names[0]: image_path,
label_column_name: test_dataset_wrapper.label_at_index(index)
}, ignore_index=True)
})

# Make the output dataframe.
df = pd.DataFrame(data=frame_rows, columns=input_column_names + [label_column_name])

return df

Expand Down Expand Up @@ -253,7 +258,9 @@ def get_object_detection_dataset(
f"# test images: {len(test_dataset)}, # labels: {test_dataset.num_classes}"
)
test_dataset_wrapper = RuntimeDetectionDatasetAdapter(test_dataset)
df = pd.DataFrame(columns=input_column_names + [label_column_name])

# Initialize the rows of the output dataframe to the empty list.
frame_rows = []

counter = 0
for index in range(len(test_dataset_wrapper)):
Expand All @@ -262,12 +269,15 @@ def get_object_detection_dataset(

if is_valid_image(image_path):
counter += 1
df = df.append({
frame_rows.append({
input_column_names[0]: base64.encodebytes(read_image(image_path)).decode("utf-8"),
input_column_names[1]: image_meta_info,
input_column_names[2]: ". ".join(test_dataset.classes),
label_column_name: label,
}, ignore_index=True)
})

# Make the output dataframe.
df = pd.DataFrame(data=frame_rows, columns=input_column_names + [label_column_name])

logger.info(f"Total number of valid images: {counter}")
return df
Expand Down Expand Up @@ -300,8 +310,8 @@ def get_generation_dataset(
mltable = load(mltable_path)
mltable_dataframe = mltable.to_pandas_dataframe()

# Initialize the output dataframe with the input and label columns.
df = pd.DataFrame(columns=input_column_names + [label_column_name])
# Initialize the rows of the output dataframe to the empty list.
frame_rows = []

# Go through all (image_url, captions) pairs and make a (prompt, image_url) from each pair. The model will generate
# a synthetic image from the prompt and the set of synthetic images will be compared with the set of original ones.
Expand All @@ -310,16 +320,18 @@ def get_generation_dataset(
):
# Go through all captions (split according to special separator).
for caption in captions.split(GenerationLiterals.CAPTION_SEPARATOR):
df = df.append(
frame_rows.append(
{
# The model input is a text prompt.
input_column_names[0]: caption,
# The original image is passed through via the label column.
label_column_name: image_url,
},
ignore_index=True
}
)

# Make the output dataframe.
df = pd.DataFrame(data=frame_rows, columns=input_column_names + [label_column_name])

return df


Expand Down
4 changes: 4 additions & 0 deletions assets/training/model_evaluation/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Init file."""
190 changes: 190 additions & 0 deletions assets/training/model_evaluation/tests/test_image_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Test image dataset implementations."""

import json
import os
import pytest
import sys
import tempfile

from unittest.mock import patch

from azureml.acft.common_components.image.runtime_common.common.dataset_helper import AmlDatasetHelper

MODEL_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.dirname(__file__)), "./src"))
sys.path.append(MODEL_DIR)
from constants import TASK # noqa: E402
from image_dataset import get_image_dataset # noqa: E402


DATASET_PER_TASK = {
TASK.IMAGE_CLASSIFICATION: [
{"image_url": "AmlDatastore://images/a/image1.jpg", "label": 0},
{"image_url": "AmlDatastore://images/a/image2.jpg", "label": 1},
],
TASK.IMAGE_OBJECT_DETECTION: [
{
"image_url": "AmlDatastore://images/b/image1.png",
"label": [{"label": 0, "topX": 0.0, "topY": 0.0, "bottomX": 0.5, "bottomY": 0.5}],
},
{
"image_url": "AmlDatastore://images/b/image2.png",
"label": [{"label": 1, "topX": 0.5, "topY": 0.5, "bottomX": 1.0, "bottomY": 1.0}],
},
],
TASK.IMAGE_GENERATION: [
{"image_url": "example.com/image1.png", "label": "an example"},
{"image_url": "example.com/image2.png", "label": "another example"},
],
}
MLTABLE_CONTENTS_PER_TASK = {
TASK.IMAGE_CLASSIFICATION: (
"paths:\n"
" - file: {file_name}\n"
"transformations:\n"
" - read_json_lines:\n"
" encoding: utf8\n"
" invalid_lines: error\n"
" include_path_column: false\n"
" - convert_column_types:\n"
" - columns: image_url\n"
" column_type: stream_info\n"
"type: mltable\n"
),
TASK.IMAGE_OBJECT_DETECTION: (
"paths:\n"
" - file: {file_name}\n"
"transformations:\n"
" - read_json_lines:\n"
" encoding: utf8\n"
" invalid_lines: error\n"
" include_path_column: false\n"
" - convert_column_types:\n"
" - columns: image_url\n"
" column_type: stream_info\n"
"type: mltable\n"
),
TASK.IMAGE_GENERATION: (
"paths:\n"
"- file: {file_name}\n"
"transformations:\n"
"- read_json_lines:\n"
" encoding: utf8\n"
" include_path_column: false\n"
" invalid_lines: error\n"
" partition_size: 20971520\n"
" path_column: Path\n"
"- convert_column_types:\n"
" - column_type: stream_info\n"
" columns: image_url\n"
"type: mltable\n"
),
}


class MockWorkspace:
    """Minimal stand-in for an AzureML workspace, exposing only identity fields."""

    def __init__(self, subscription_id, resource_group, workspace_name, location, workspace_id):
        """Store the identifying fields the tested code reads off a workspace."""
        self.subscription_id = subscription_id
        self.resource_group = resource_group
        self.location = location
        # The workspace name is exposed under both the private and the public attribute.
        self._workspace_name = workspace_name
        self.name = workspace_name
        # Internal id attribute, mirroring the attribute name used by the real class.
        self._workspace_id_internal = workspace_id


class MockExperiment:
    """Minimal stand-in for an AzureML experiment: just a workspace and an id."""

    def __init__(self, workspace, id):
        """Record the owning workspace and the experiment id."""
        self.workspace, self.id = workspace, id


class MockRun:
    """Minimal stand-in for an AzureML run: carries only the run id."""

    def __init__(self, id):
        """Record the run id."""
        self.id = id


class MockRunContext:
    """Minimal stand-in for the object returned by Run.get_context()."""

    def __init__(self, experiment, run_id, parent_run_id):
        """Record the experiment and run ids; wrap the parent id in a run object."""
        self.experiment = experiment
        # The run id is exposed under both the private and the public attribute.
        self._run_id = run_id
        self.id = run_id
        # The parent is itself a run-like object holding only its id.
        self.parent = MockRun(parent_run_id)


def get_mock_run_context():
    """Make a mock run context with fixed fake identifiers.

    Returns:
        MockRunContext wired to a MockExperiment inside a MockWorkspace, suitable
        for patching azureml.core.Run.get_context in tests.
    """
    TEST_EXPERIMENT_ID = "22222222-2222-2222-2222-222222222222"
    TEST_REGION = "eastus"
    TEST_PARENT_RUN_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
    TEST_RESOURCE_GROUP = "testrg"
    TEST_RUN_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
    TEST_SUBSCRIPTION_ID = "00000000-0000-0000-0000-000000000000"
    # Fixed: the previous value was missing one 4-digit group and was not a valid
    # 8-4-4-4-12 UUID, inconsistent with the other fake ids above.
    TEST_WORKSPACE_ID = "11111111-1111-1111-1111-111111111111"
    TEST_WORKSPACE_NAME = "testws"

    ws = MockWorkspace(
        subscription_id=TEST_SUBSCRIPTION_ID,
        resource_group=TEST_RESOURCE_GROUP,
        workspace_name=TEST_WORKSPACE_NAME,
        location=TEST_REGION,
        workspace_id=TEST_WORKSPACE_ID,
    )
    experiment = MockExperiment(workspace=ws, id=TEST_EXPERIMENT_ID)
    return MockRunContext(experiment, run_id=TEST_RUN_ID, parent_run_id=TEST_PARENT_RUN_ID)


@pytest.mark.parametrize("task_type,input_column_names,label_column_name", [
    (TASK.IMAGE_CLASSIFICATION, ["image_url"], "label"),
    (TASK.IMAGE_OBJECT_DETECTION, ["image_url"], "label"),
    (TASK.IMAGE_GENERATION, ["prompt"], "label"),
])
def test_image_dataset(task_type, input_column_names, label_column_name):
    """Run get_image_dataset end to end on a tiny two-row dataset for each task type.

    Builds a jsonl dataset and an MLTable file in a temporary directory, patches out
    the run context and the image download so no workspace or network is needed, and
    loads the dataset through get_image_dataset.
    """
    with tempfile.TemporaryDirectory() as directory_name:
        # Save the jsonl file.
        dataset = DATASET_PER_TASK[task_type]
        with open(os.path.join(directory_name, "dataset.jsonl"), "wt") as f:
            for r in dataset:
                f.write(json.dumps(r) + "\n")

        # Save the MLTable file (points at the jsonl file written above).
        mltable_str = MLTABLE_CONTENTS_PER_TASK[task_type].format(file_name="dataset.jsonl")
        with open(os.path.join(directory_name, "MLTable"), "wt") as f:
            f.write(mltable_str)

        # Make blank image files for image classification and object detection tasks, to simulate downloading.
        if task_type in [TASK.IMAGE_CLASSIFICATION, TASK.IMAGE_OBJECT_DETECTION]:
            for r in dataset:
                # Strip the datastore scheme to obtain the path relative to the temp directory.
                image_file_name_tokens = r["image_url"].replace("AmlDatastore://", "").split("/")
                os.makedirs(os.path.join(directory_name, *image_file_name_tokens[:-1]), exist_ok=True)
                open(os.path.join(directory_name, *image_file_name_tokens), "wb").close()

        # Load the MLTable with the run context, the image download, and the data
        # directory all patched so the loader stays inside the temp directory.
        with patch("azureml.core.Run.get_context", get_mock_run_context), \
            patch(
                "azureml.acft.common_components.image.runtime_common.common.utils.download_or_mount_image_files"
            ), \
            patch.object(AmlDatasetHelper, "get_data_dir", return_value=directory_name):
            df = get_image_dataset(task_type, directory_name, input_column_names, label_column_name)

        # Compare the loaded dataset with the original.
        # NOTE(review): only the generation task asserts dataframe contents; the
        # classification and detection tasks currently only check that loading does
        # not raise — consider adding content checks for them too.
        if task_type == TASK.IMAGE_GENERATION:
            loaded_dataset = [{k: row[k] for k in ["prompt", "label"]} for _, row in df.iterrows()]
            # Each caption becomes the model prompt; the original image URL is passed
            # through in the label column (see get_generation_dataset).
            for r1, r2 in zip(
                sorted(dataset, key=lambda x: x["label"]), sorted(loaded_dataset, key=lambda x: x["prompt"])
            ):
                assert r2 == {"prompt": r1["label"], "label": r1["image_url"]}

0 comments on commit 597d616

Please sign in to comment.