Skip to content

Commit

Permalink
fix import & assignment issue (#377)
Browse files Browse the repository at this point in the history
* fix import & assignment issue
* Apply suggestions from code review

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Sep 18, 2024
1 parent 7bf1d6a commit 895a829
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/litdata/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

if TYPE_CHECKING:
from PIL.JpegImagePlugin import JpegImageFile

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
_AV_AVAILABLE = RequirementCache("av")
Expand Down Expand Up @@ -70,6 +69,8 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:

@classmethod
def deserialize(cls, data: bytes) -> Any:
if not _PIL_AVAILABLE:
raise ModuleNotFoundError("PIL is required. Run `pip install pillow`")
from PIL import Image

idx = 3 * 4
Expand All @@ -94,6 +95,9 @@ class JPEGSerializer(Serializer):
"""The JPEGSerializer serialize and deserialize JPEG image to and from bytes."""

def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
if not _PIL_AVAILABLE:
raise ModuleNotFoundError("PIL is required. Run `pip install pillow`")

from PIL import Image
from PIL.GifImagePlugin import GifImageFile
from PIL.JpegImagePlugin import JpegImageFile
Expand Down Expand Up @@ -122,7 +126,7 @@ def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]:
buff.seek(0)
return buff.read(), None

raise TypeError(f"The provided item should be of type {JpegImageFile}. Found {item}.")
raise TypeError(f"The provided item should be of type `JpegImageFile`. Found {item}.")

def deserialize(self, data: bytes) -> Union["JpegImageFile", torch.Tensor]:
if _TORCH_VISION_AVAILABLE:
Expand Down Expand Up @@ -184,7 +188,7 @@ def deserialize(self, data: bytes) -> torch.Tensor:
shape = []
for shape_idx in range(shape_size):
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
idx_start = 8 + 4 * (shape_idx + 1)
idx_start = 8 + 4 * shape_size
idx_end = len(data)
if idx_end > idx_start:
tensor = torch.frombuffer(data[idx_start:idx_end], dtype=dtype)
Expand Down Expand Up @@ -250,7 +254,7 @@ def deserialize(self, data: bytes) -> np.ndarray:
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())

# deserialize the numpy array bytes
tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
tensor = np.frombuffer(data[8 + 4 * shape_size : len(data)], dtype=dtype)
if tensor.shape == shape:
return tensor
return np.reshape(tensor, shape)
Expand Down

0 comments on commit 895a829

Please sign in to comment.