
Commit

fix unit tests
ajaits committed Oct 17, 2023
1 parent 2af2284 commit a136c25
Showing 17 changed files with 2,037 additions and 1,873 deletions.
2 changes: 2 additions & 0 deletions scripts/earthengine/events_pipeline.py
@@ -195,7 +195,9 @@ def run_stage(self, stage_name: str, input_files: list = []) -> list:
        '''Run a single stage and return the output files generated.'''
        for stage_runner in self.stage_runners:
            if stage_name == stage_runner.get_name():
                logging.info(f'Running stage {stage_name} with {input_files}')
                return stage_runner.run_stage(input_files)
        logging.error(f'No stage runner for {stage_name} with input: {input_files}')
        return []

    def run(self, run_stages: list = []) -> list:
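For context, a minimal sketch of how the dispatcher above might be driven. The pipeline construction here is hypothetical; the actual class name and config layout in events_pipeline.py are not shown in this diff:

    # Hypothetical usage; EventsPipeline and its config shape are assumed.
    pipeline = EventsPipeline(config_dicts=[{'stage': 'download'}])
    output_files = pipeline.run_stage('download', input_files=['s2_cells.csv'])
    # With the added logging, run_stage now records each invocation and
    # logs an error (returning []) when no runner matches the stage name.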
202 changes: 104 additions & 98 deletions scripts/earthengine/pipeline_stage_download.py
@@ -11,8 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class to run the events pipeline stage to download files from a URL.
"""
"""Class to run the events pipeline stage to download files from a URL."""

import os
import re
@@ -26,7 +25,8 @@
sys.path.append(os.path.dirname(_SCRIPTS_DIR))
sys.path.append(os.path.dirname(os.path.dirname(_SCRIPTS_DIR)))
sys.path.append(
    os.path.join(os.path.dirname(os.path.dirname(_SCRIPTS_DIR)), 'util'))
    os.path.join(os.path.dirname(os.path.dirname(_SCRIPTS_DIR)), 'util')
)

import file_util
import utils
@@ -37,101 +37,107 @@


class DownloadRunner(StageRunner):
    '''Class to download data files from URL source.'''

    def __init__(self,
                 config_dicts: list = [],
                 state: dict = {},
                 counters=None):
        self.set_up('download', config_dicts, state, counters)

    def run(self,
            input_files: list = None,
            config_dict: dict = {},
            counters: Counters = None) -> list:
        '''Returns the list of files downloaded from the URL in the config.
        URLs are downloaded for each time period until the current date.'''
        # Download data from start_date up to end_date
        # advancing date by the time_period.
        start_date = self.get_config('start_date', '', config_dict)
        yesterday = utils.date_yesterday()
        end_date = self.get_config('end_date', yesterday, config_dict)
        if end_date > yesterday:
            end_date = yesterday
        data_files = []
        while start_date and start_date <= end_date:
            # Download data for the start_date
            download_files = self.download_file_with_config(self.get_configs())
            if download_files:
                data_files.extend(download_files)

            # Advance start_date to the next date.
            start_date = utils.date_advance_by_period(
                start_date, self.get_config('time_period', 'P1M', config_dict))
            if start_date:
                self.set_config_dates(start_date=start_date)
        return data_files

    def download_file_with_config(self, config_dict: dict = {}) -> list:
        '''Returns list of files downloaded for config.'''
        logging.info(f'Downloading data for config: {config_dict}')
        downloaded_files = []
        urls = config_dict.get('url', [])
        if not isinstance(urls, list):
            urls = [urls]
        for url in urls:
            if not url:
                continue
            url_params = config_dict.get('url_params', {})
            filename = self.get_output_filename(config_dict=config_dict)
            if self.should_skip_existing_output(filename):
                logging.info(f'Skipping download for existing file: {filename}')
                continue

            # Download the URL with retries.
            download_content = ''
            retry_count = 0
            retries = config_dict.get('retry_count', 5)
            retry_secs = config_dict.get('retry_interval', 5)
            while not download_content and retry_count < retries:
                download_content = request_url(
                    url,
                    params=url_params,
                    method=config_dict.get('http_method', 'GET'),
                    output=config_dict.get('response_type', 'text'),
                    timeout=config_dict.get('timeout', 60),
                    retries=config_dict.get('retry_count', 3),
                    retry_secs=retry_secs)
                if download_content:
                    # Check if the downloaded content matches the regex.
                    regex = config_dict.get('successful_response_regex', '')
                    if regex:
                        match = re.search(regex, download_content)
                        if not match:
                            download_content = ''
                            retry_count += 1
                            logging.info(
                                f'Downloaded content for {url} does not match {regex}'
                            )
                            if retry_count < retries:
                                logging.info(
                                    f'retrying {url} #{retry_count} after {retry_secs}'
                                )
                                time.sleep(retry_secs)
            if not download_content:
                logging.error(
                    f'Failed to download {url} after {retries} retries')
                return None

            # Save downloaded content to file.
            with file_util.FileIO(filename, mode='w') as file:
                file.write(download_content)
            logging.info(
                f'Downloaded {len(download_content)} bytes from {url} into file: {filename}'
            )
            downloaded_files.append(filename)

        return downloaded_files
"""Class to download data files from URL source."""

def __init__(self, config_dicts: list = [], state: dict = {}, counters=None):
self.set_up('download', config_dicts, state, counters)

def run(
self,
input_files: list = None,
config_dict: dict = {},
counters: Counters = None,
) -> list:
"""Returns the list of files downloaded from the URL in the config.
URLs are downloaded for each time period until the current date.
"""
# Download data from start_date up to end_date
# advancing date by the time_period.
start_date = self.get_config('start_date', '', config_dict)
yesterday = utils.date_yesterday()
end_date = self.get_config('end_date', yesterday, config_dict)
if end_date > yesterday:
end_date = yesterday
logging.info(
f'Running download with start_date: {start_date}, end_date:{end_date}'
)
data_files = []
while start_date and start_date <= end_date:
# Download data for the start_date
download_files = self.download_file_with_config(self.get_configs())
if download_files:
data_files.extend(download_files)

# Advance start_date to the next date.
start_date = utils.date_advance_by_period(
start_date, self.get_config('time_period', 'P1M', config_dict)
)
if start_date:
self.set_config_dates(start_date=start_date)
return data_files

def download_file_with_config(self, config_dict: dict = {}) -> list:
"""Returns list of files downloaded for config."""
logging.info(f'Downloading data for config: {config_dict}')
downloaded_files = []
urls = config_dict.get('url', [])
if not isinstance(urls, list):
urls = [urls]
for url in urls:
if not url:
continue
url_params = config_dict.get('url_params', {})
filename = self.get_output_filename(config_dict=config_dict)
if self.should_skip_existing_output(filename):
logging.info(f'Skipping download for existing file: {filename}')
continue

# Download the URL with retries.
download_content = ''
retry_count = 0
retries = config_dict.get('retry_count', 5)
retry_secs = config_dict.get('retry_interval', 5)
while not download_content and retry_count < retries:
download_content = request_url(
url,
params=url_params,
method=config_dict.get('http_method', 'GET'),
output=config_dict.get('response_type', 'text'),
timeout=config_dict.get('timeout', 60),
retries=config_dict.get('retry_count', 3),
retry_secs=retry_secs,
)
if download_content:
# Check if the downloaded content matches the regex.
regex = config_dict.get('successful_response_regex', '')
if regex:
match = re.search(regex, download_content)
if not match:
download_content = ''
retry_count += 1
logging.info(
f'Downloaded content for {url} does not match {regex}'
)
if retry_count < retries:
logging.info(
f'retrying {url} #{retry_count} after {retry_secs}'
)
time.sleep(retry_secs)
if not download_content:
logging.error(f'Failed to download {url} after {retries} retries')
return None

# Save downloaded content to file.
with file_util.FileIO(filename, mode='w') as file:
file.write(download_content)
logging.info(
f'Downloaded {len(download_content)} bytes from {url} into file:'
f' {filename}'
)
downloaded_files.append(filename)

return downloaded_files
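
For reference, a sample config for this runner. Every key below is one the code above reads via get_config() or config_dict.get(); the values, including the URL, are purely illustrative:

    download_config = {
        'start_date': '2023-01-01',  # first period to download
        'time_period': 'P1M',  # advance one month per iteration
        'url': 'https://example.com/data',  # a single URL or a list of URLs
        'url_params': {'format': 'csv'},
        'http_method': 'GET',
        'response_type': 'text',
        'timeout': 60,  # seconds per request
        'retry_count': 5,
        'retry_interval': 5,  # seconds between retries
        'successful_response_regex': 'lat,lng',  # content sanity check
    }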


# Register the DownloadRunner
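The date arithmetic in run() relies on utils.date_advance_by_period, which is not part of this diff. A plausible sketch of its semantics, assuming ISO dates and ISO-8601 period strings such as 'P1M' or 'P7D' (the real helper may differ):

    from datetime import datetime

    from dateutil.relativedelta import relativedelta

    def date_advance_by_period(date_str: str, period: str) -> str:
        # Assumed semantics: advance an ISO date by an ISO-8601 period.
        count, unit = int(period[1:-1]), period[-1]
        delta = {
            'D': relativedelta(days=count),
            'M': relativedelta(months=count),
            'Y': relativedelta(years=count),
        }[unit]
        date = datetime.strptime(date_str, '%Y-%m-%d')
        return (date + delta).strftime('%Y-%m-%d')

    # e.g. date_advance_by_period('2023-01-31', 'P1M') == '2023-02-28'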
2 changes: 1 addition & 1 deletion scripts/earthengine/process_events.py
@@ -1544,7 +1544,7 @@ def get_place_id_for_event(self, place_id: str) -> str:
        # Got a location. Convert it to a grid.
        if (output_place_type
                == 'grid_1') and (not utils.is_grid_id(place_id)):
            grid_id = utils.grid_id_from_lat_lng(1, int(lat), int(lng))
            grid_id = utils.grid_id_from_lat_lng(1, lat, lng)
            place_id = grid_id
            self._counters.add_counter(f'place_converted_to_grid_1', 1)
        elif (output_place_type
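The one-line fix above passes fractional coordinates through instead of pre-truncating with int(), which rounds toward zero and can select the wrong cell for negative latitudes and longitudes. A hypothetical sketch of a 1-degree grid id computation (the actual utils.grid_id_from_lat_lng may encode ids differently):

    import math

    def grid_id_from_lat_lng(degrees: int, lat: float, lng: float) -> str:
        # Snap the point to the lower-left corner of its grid cell.
        lat0 = math.floor(lat / degrees) * degrees
        lng0 = math.floor(lng / degrees) * degrees
        return f'grid_{degrees}/{lat0}_{lng0}'

    # int(-33.9) == -33, but -33.9 falls in the cell starting at -34:
    # grid_id_from_lat_lng(1, -33.9, 18.4) == 'grid_1/-34_18'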
28 changes: 12 additions & 16 deletions scripts/earthengine/process_events_test.py
@@ -90,7 +90,6 @@ def compare_csv_files(self,

    def test_process(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = '/tmp/test_events'
            output_prefix = os.path.join(tmp_dir, 'events_test_')
            test_prefix = os.path.join(_TESTDIR, 'sample_floods_')
            # Process flood s2 cells into events.
@@ -99,21 +98,18 @@
                output_path=output_prefix,
                config=self._config)
            # Verify generated events.
            for file in [
                    'events.csv',
                    'events.tmcf',
                    'svobs.csv',
                    'svobs.tmcf',
                    'place_svobs.csv',
                    'place_svobs.tmcf',
            ]:
                if file.endswith('.csv'):
                    # compare csv output without geoJson that is not deterministic
                    self.compare_csv_files(test_prefix + file,
                                           output_prefix + file,
                                           ['geoJsonCoordinatesDP1'])
                else:
                    self.compare_files(test_prefix + file, output_prefix + file)
            self.compare_csv_files(os.path.join(tmp_dir, 'events_test_events.csv'),
                                   test_prefix + 'events.csv')
            self.compare_files(os.path.join(tmp_dir, 'events_test_events.tmcf'),
                               test_prefix + 'events.tmcf')
            self.compare_csv_files(os.path.join(tmp_dir, 'event_svobs', 'events_test_svobs.csv'),
                                   test_prefix + 'svobs.csv')
            self.compare_files(os.path.join(tmp_dir, 'event_svobs', 'events_test_svobs.tmcf'),
                               test_prefix + 'svobs.tmcf')
            self.compare_csv_files(os.path.join(tmp_dir, 'place_svobs', 'events_test_place_svobs.csv'),
                                   test_prefix + 'place_svobs.csv')
            self.compare_files(os.path.join(tmp_dir, 'place_svobs', 'events_test_place_svobs.tmcf'),
                               test_prefix + 'place_svobs.tmcf')

    def test_process_event_data(self):
        '''Verify events can be added by date.'''
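compare_csv_files (defined earlier in this test file, truncated here) accepts a list of columns to skip for non-deterministic output such as geoJson coordinates. A minimal sketch of such a helper under that assumption; the repository's actual implementation is not shown in this diff:

    import csv

    def compare_csv_files(self, expected_file, actual_file, ignore_columns=()):
        # Compare two CSVs row by row, dropping any ignored columns.
        def rows(path):
            with open(path) as fp:
                return [{
                    column: value
                    for column, value in row.items()
                    if column not in ignore_columns
                } for row in csv.DictReader(fp)]

        self.assertEqual(rows(expected_file), rows(actual_file))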
(Diffs for the remaining 13 changed files are not shown.)
