[TEST] Ray 2.6 support #3615

Closed
wants to merge 11 commits into from
6 changes: 3 additions & 3 deletions .github/workflows/pytest.yml
@@ -28,15 +28,15 @@ jobs:
       - python-version: "3.8"
         pytorch-version: 1.13.0
         torchscript-version: 1.10.2
-        ray-version: 2.2.0
+        ray-version: 2.6.3
       - python-version: "3.9"
         pytorch-version: 2.0.0
         torchscript-version: 1.10.2
-        ray-version: 2.3.0
+        ray-version: 2.6.3
       - python-version: "3.10"
         pytorch-version: nightly
         torchscript-version: 1.10.2
-        ray-version: 2.3.1
+        ray-version: 2.6.3
     env:
       PYTORCH: ${{ matrix.pytorch-version }}
       MARKERS: ${{ matrix.test-markers }}
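The CI matrix now pins a single Ray release (2.6.3) across all Python/PyTorch combinations. For reference, this is the release that the new version gates in ludwig/data/dataset/ray.py (next file) compare against. A minimal sketch of how those comparisons resolve, using version.parse as ray.py does; the printed values are illustrative and not part of the PR:

```python
from packaging import version

ray_version = "2.6.3"  # the release pinned by this PR

# These mirror the _ray_240 / _ray_230 gates added in ludwig/data/dataset/ray.py.
print(version.parse(ray_version) >= version.parse("2.4.0"))  # True -> _ray_240
print(version.parse(ray_version) >= version.parse("2.3.0"))  # True -> _ray_230
```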
28 changes: 23 additions & 5 deletions ludwig/data/dataset/ray.py
@@ -29,6 +29,7 @@
 from pyarrow.fs import FSSpecHandler, PyFileSystem
 from pyarrow.lib import ArrowInvalid
 from ray.data import read_parquet
+from ray.data.dataset import Dataset as _Dataset
 from ray.data.dataset_pipeline import DatasetPipeline

 from ludwig.api_annotations import DeveloperAPI
@@ -49,6 +50,7 @@

 logger = logging.getLogger(__name__)

+_ray_240 = version.parse(ray.__version__) >= version.parse("2.4.0")
 _ray_230 = version.parse(ray.__version__) >= version.parse("2.3.0")

@@ -234,7 +236,7 @@ def data_format(self):
 class RayDatasetShard(Dataset):
     def __init__(
         self,
-        dataset_shard: DatasetPipeline,
+        dataset_shard: _Dataset,
         features: Dict[str, FeatureConfigDict],
         training_set_metadata: TrainingSetMetadataDict,
     ):
@@ -244,6 +246,16 @@ def __init__(
         self.create_epoch_iter()

     def create_epoch_iter(self) -> None:
+        if _ray_240:
+            print("DATASET SHARD", type(self.dataset_shard))
+            self.epoch_iter = self.dataset_shard
+            # if isinstance(self.dataset_shard, DatasetPipeline):
+            #     self.epoch_iter = self.dataset_shard.repeat().iter_epochs()
+            # else:
+            #     self.epoch_iter = self.dataset_shard.repeat()
+            # print("EPOCH ITER", type(self.epoch_iter))
+            return
+
         if _ray_230:
             # In Ray >= 2.3, session.get_dataset_shard() returns a DatasetIterator object.
             if isinstance(self.dataset_shard, ray.data.DatasetIterator):
@@ -289,7 +301,12 @@ def initialize_batcher(

     @lru_cache(1)
     def __len__(self):
-        return next(self.epoch_iter).count()
+        print("TYPE", type(self.epoch_iter))
+        num_rows = 0
+        for block, meta in self.epoch_iter._to_block_iterator()[0]:
+            num_rows += meta.num_rows
+        print("NUM ROWS", num_rows)
+        return num_rows

     @property
     def size(self):
@@ -306,7 +323,7 @@ def to_scalar_df(self, features: Optional[Iterable[BaseFeature]] = None) -> Data
 class RayDatasetBatcher(Batcher):
     def __init__(
         self,
-        dataset_epoch_iterator: Iterator[DatasetPipeline],
+        dataset_epoch_iterator: _Dataset,
         features: Dict[str, Dict],
         training_set_metadata: TrainingSetMetadataDict,
         batch_size: int,
@@ -364,7 +381,8 @@ def steps_per_epoch(self):
         return math.ceil(self.samples_per_epoch / self.batch_size)

     def _fetch_next_epoch(self):
-        pipeline = next(self.dataset_epoch_iterator)
+        # pipeline = next(self.dataset_epoch_iterator)
+        pipeline = self.dataset_epoch_iterator

         read_parallelism = 1
         if read_parallelism == 1:
@@ -438,7 +456,7 @@ def sync_read():

         return sync_read()

-    def _create_async_reader(self, pipeline: DatasetPipeline):
+    def _create_async_reader(self, pipeline: _Dataset):
         q = queue.Queue(maxsize=100)
         batch_size = self.batch_size
         augment_batch = self._augment_batch_fn()
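Note that the new `__len__` sums `meta.num_rows` through the private `_to_block_iterator()` API, and the old pipeline path is left commented out. A minimal sketch of the same row count against public Ray 2.6 APIs, assuming the shard is the DataIterator returned by `session.get_dataset_shard()` inside a Train worker (`count_shard_rows` is a hypothetical helper, not part of this PR):

```python
from ray.air import session


def count_shard_rows(batch_size: int = 4096) -> int:
    """Count this worker's shard rows by iterating batches via public API."""
    shard = session.get_dataset_shard("train")  # DataIterator in Ray >= 2.3
    num_rows = 0
    for batch in shard.iter_batches(batch_size=batch_size, batch_format="numpy"):
        # With batch_format="numpy", each batch is a dict of column -> ndarray.
        num_rows += len(next(iter(batch.values())))
    return num_rows
```

Unlike the metadata-based count in the diff, this materializes every batch, so it trades speed for staying on supported API surface.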
2 changes: 1 addition & 1 deletion requirements_distributed.txt
@@ -3,7 +3,7 @@ dask[dataframe]<2023.4.0
 pyarrow

 # requirements for ray
-ray[default,data,serve,tune]>=2.2.0,<2.5
+ray[default,data,serve,tune]==2.6.3
 tensorboardX<2.3
 GPUtil
 tblib