Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pull request update/240418 #261

Merged
merged 6 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 37 additions & 24 deletions bulldozer/bulldozer_worker/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,33 @@ def _task2str(task):
def update_task_state(self):
    """Placeholder hook: always raises NotImplementedError.

    NOTE(review): presumably concrete task classes override this -- confirm
    against the subclasses, which are outside this view.
    """
    raise NotImplementedError

def get_run_by_runner(self, runner):
    """Return the id of the first Arcee run reported for a runner.

    Queries ``self.arcee_cl.runs_by_executor`` with the runner's
    ``instance_id`` and a one-element list holding its ``task_id``.

    :param runner: mapping that must contain ``instance_id`` and ``task_id``
    :return: first element of the returned runs, or ``None`` when the
        lookup yields nothing
    """
    executor_id = runner["instance_id"]
    task = runner["task_id"]
    _, found_runs = self.arcee_cl.runs_by_executor(executor_id, [task])
    LOG.info("runs info: %s", str(found_runs))
    # Guard clause: nothing registered for this executor yet.
    if not found_runs:
        return None
    first_run_id = found_runs[0]
    LOG.info("run found! run id: %s", first_run_id)
    return first_run_id

def update_run_info(self, run_id, runner):
    """Push runset details and runner hyperparameters onto a run.

    Fetches the runset named by ``runner["runset_id"]`` to obtain its name,
    reads the run via ``self.arcee_cl.run_get`` and merges the runner's
    hyperparameters over the ones already stored on the run (runner values
    win on key clashes), then sends everything with
    ``self.arcee_cl.run_update``.

    :param run_id: id of the run to update
    :param runner: mapping that must contain ``runset_id`` and
        ``hyperparameters``
    """
    target_runset_id = runner["runset_id"]
    _, runset_info = self.bulldozer_cl.runset_get(target_runset_id)
    target_runset_name = runset_info.get("name", "")
    runner_hp = runner["hyperparameters"]
    _, run_info = self.arcee_cl.run_get(run_id)
    merged_hp = run_info.get("hyperparameters", {})
    merged_hp.update(runner_hp)
    changes = {
        "runset_id": target_runset_id,
        "runset_name": target_runset_name,
        "hyperparameters": merged_hp,
    }
    self.arcee_cl.run_update(run_id, **changes)

def process_infra_tries(self):
retry = False
# check whether this is a spot runner type
Expand Down Expand Up @@ -248,20 +275,25 @@ def update_reason(self):
reason = self.body.get("reason")
runner_id = self.body.get('runner_id')
_, runner = self.bulldozer_cl.get_runner(runner_id)
run_id = runner.get("run_id")
reason = str(reason)

run_id = runner.get("run_id")
if not run_id:
run_id = self.get_run_by_runner(runner)

# update runner reason
LOG.info("updating reason for runner %s, reason: %s",
runner_id, reason)
self.bulldozer_cl.update_runner(runner_id, reason=f"{reason}")
self.bulldozer_cl.update_runner(runner_id, reason=f"{reason}",
run_id=run_id)

# if runner knows about arcee run, need to update it also
if run_id:
LOG.info("getting run info for runner: %s, run: %s",
runner_id, run_id)
# update arcee run reason
try:
self.update_run_info(run_id, runner)
_, run = self.arcee_cl.run_get(run_id)
run_state = run["state"]
# In case of a started run, we need to abort it
Expand Down Expand Up @@ -436,18 +468,12 @@ def _exec(self):
_, runner = self.bulldozer_cl.get_runner(runner_id)
LOG.info("got runner from bulldozer API: %s", runner)
instance_id = runner["instance_id"]
task_id = runner["task_id"]
hp = runner["hyperparameters"]
runset_id = runner["runset_id"]
_, runset = self.bulldozer_cl.runset_get(runset_id)
runset_name = runset.get("name", "")

LOG.info("checking for arcee runs for executor: %s", instance_id)
# try to get run id from Arcee
_, runs = self.arcee_cl.runs_by_executor(instance_id, [task_id])
LOG.info("runs info: %s", str(runs))
run_id = self.get_run_by_runner(runner)

if not runs:
if not run_id:
# check timeout
last_updated = int(self.body.get("updated"))
current_time = int(datetime.datetime.utcnow().timestamp())
Expand All @@ -458,20 +484,7 @@ def _exec(self):
# TODO: Do we need automatically destroy env?
raise ArceeWaitException("Arcee wait exceeded")
else:
run_id = runs[0]
LOG.info("run found! run id: %s", run_id)
LOG.info("updating run %s with runset id %s", run_id, runset_id)
# get run info
_, run = self.arcee_cl.run_get(run_id)
existing_hp = run.get("hyperparameters", dict())
existing_hp.update(hp)
# update run
self.arcee_cl.run_update(
run_id,
runset_id=runset_id,
runset_name=runset_name,
hyperparameters=existing_hp,
)
self.update_run_info(run_id, runner)
self.bulldozer_cl.update_runner(
runner_id,
run_id=run_id,
Expand Down
2 changes: 1 addition & 1 deletion diworker/diworker/importers/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def provider(self):
'aws': 'aws',
'alicloud': 'alibaba',
'azure': 'azure',
'gce': None
'gce': 'gcp'
}
cloud_type = provider_cloud_types.get(cloud_mark)
if cloud_type is None:
Expand Down
60 changes: 58 additions & 2 deletions insider/insider_api/controllers/flavor_price.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tools.cloud_adapter.clouds.alibaba import Alibaba
from tools.cloud_adapter.clouds.aws import Aws
from tools.cloud_adapter.clouds.azure import Azure
from tools.cloud_adapter.clouds.gcp import Gcp
from tools.optscale_exceptions.common_exc import WrongArgumentsException
from botocore.exceptions import ClientError as AwsClientError
from insider.insider_api.utils import handle_credentials_error
Expand Down Expand Up @@ -370,11 +371,66 @@ def _flavor_format(self, price_infos, region, os_type):
return result


class GcpProvider(BaseProvider):
    """Flavor price provider for GCP backed by ``insider.gcp_prices``."""

    @property
    def prices_collection(self):
        """Mongo collection in which GCP flavor prices are stored."""
        return self.mongo_client.insider.gcp_prices

    @property
    def cloud_adapter(self):
        """Gcp adapter built from the '/service_credentials/gcp' branch.

        NOTE(review): the config branch is re-read and a new adapter is
        created on every access -- confirm whether reusing
        ``self._cloud_adapter`` was intended.
        """
        gcp_credentials = self._config_cl.read_branch(
            '/service_credentials/gcp')
        self._cloud_adapter = Gcp(gcp_credentials)
        return self._cloud_adapter

    def _load_flavor_prices(self, region, flavor, os_type='linux',
                            preinstalled=None, billing_method=None,
                            quantity=None, currency='USD'):
        """Return stored price documents for ``flavor`` in ``region``.

        Documents updated within the last 60 days are read from the prices
        collection. When none are found, every priced instance type for the
        region is requested from the cloud adapter, stamped with the current
        time, flavor and region, upserted (keyed by flavor + region) and the
        requested flavor is picked out of that fresh set.

        ``os_type``, ``preinstalled``, ``billing_method``, ``quantity`` and
        ``currency`` are accepted but not used by this implementation.
        """
        now = datetime.utcnow()
        freshness_cutoff = now - timedelta(days=60)
        price_infos = list(self.prices_collection.find({
            'region': region,
            'flavor': flavor,
            'updated_at': {'$gte': freshness_cutoff},
        }))
        if not price_infos:
            cloud_prices = self.cloud_adapter.get_instance_types_priced(
                region)
            upserts = []
            for flavor_name, price_info in cloud_prices.items():
                price_info['updated_at'] = now
                price_info['flavor'] = flavor_name
                price_info['region'] = region
                upserts.append(UpdateOne(
                    filter={'flavor': flavor_name, 'region': region},
                    update={'$set': price_info},
                    upsert=True,
                ))
                price_infos.append(price_info)
            # bulk_write is skipped when the adapter returned no prices
            if upserts:
                self.prices_collection.bulk_write(upserts)
        return [info for info in price_infos if info['flavor'] == flavor]

    def _flavor_format(self, price_infos, region, os_type):
        """Convert price documents into hourly USD price entries."""
        return [
            {
                'price': info['price'],
                'region': region,
                'flavor': info['flavor'],
                'operating_system': os_type,
                'price_unit': '1 hour',
                'currency': 'USD',
            }
            for info in price_infos
        ]


class PricesProvider:
__modules__ = {
'azure': AzureProvider,
'aws': AwsProvider,
'alibaba': AlibabaProvider
'alibaba': AlibabaProvider,
'gcp': GcpProvider
}

@staticmethod
Expand All @@ -391,7 +447,7 @@ def __init__(self, *args, **kwargs):

@property
def supported_cloud_types(self):
return ['alibaba', 'azure', 'aws']
return ['alibaba', 'azure', 'aws', 'gcp']

@property
def required_params(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ const SettingModelVersion = () => (
</Typography>
</li>
</ul>
<CodeBlock text={`arcee.set_model_version("1.2.3-release")`} />
<CodeBlock text={`arcee.model_version("1.2.3-release")`} />
</>
);

Expand Down Expand Up @@ -412,10 +412,67 @@ const AddHyperparameters = () => (
</Typography>
</li>
</ul>
<CodeBlock
text={`arcee.hyperparam("hyperparam_key", hyperparam_value)
`}
/>
<CodeBlock text={`arcee.hyperparam("hyperparam_key", hyperparam_value)`} />
</>
);

const SettingModelVersionAlias = () => (
<>
<SubTitle fontWeight="bold">
<FormattedMessage id="mlProfilingIntegration.settingModelVersionAliasTitle" />
</SubTitle>
<Typography gutterBottom>
<FormattedMessage id="mlProfilingIntegration.settingModelVersionAlias" values={{ ...preFormatMessageValues }} />
<HtmlSymbol symbol="colon" />
</Typography>
<ul>
<li>
<Typography>
<FormattedMessage
id="mlProfilingIntegration.settingModelVersionVersionAliasDescription"
values={{
strong: (chunks) => <strong>{chunks}</strong>
}}
/>
</Typography>
</li>
</ul>
<CodeBlock text={`arcee.model_version_alias("winner")`} />
</>
);

const SettingModelVersionTag = () => (
<>
<SubTitle fontWeight="bold">
<FormattedMessage id="mlProfilingIntegration.settingModelVersionTagTitle" />
</SubTitle>
<Typography gutterBottom>
<FormattedMessage id="mlProfilingIntegration.settingModelVersionTag" values={{ ...preFormatMessageValues }} />
<HtmlSymbol symbol="colon" />
</Typography>
<ul>
<li>
<Typography>
<FormattedMessage
id="mlProfilingIntegration.settingModelVersionTagKeyDescription"
values={{
strong: (chunks) => <strong>{chunks}</strong>
}}
/>
</Typography>
</li>
<li>
<Typography>
<FormattedMessage
id="mlProfilingIntegration.settingModelVersionTagValueDescription"
values={{
strong: (chunks) => <strong>{chunks}</strong>
}}
/>
</Typography>
</li>
</ul>
<CodeBlock text={`arcee.model_version_tag("env", "staging demo")`} />
</>
);

Expand Down Expand Up @@ -502,6 +559,12 @@ const ProfilingIntegration = ({ profilingToken, taskKey, isLoading }) => (
<div>
<SettingModelVersion />
</div>
<div>
<SettingModelVersionAlias />
</div>
<div>
<SettingModelVersionTag />
</div>
<div>
<FinishTaskRun />
</div>
Expand Down
9 changes: 8 additions & 1 deletion ngui/ui/src/translations/en-US/app.json
Original file line number Diff line number Diff line change
Expand Up @@ -1155,8 +1155,15 @@
"mlProfilingIntegration.orInCaseOfError": "or in case of error",
"mlProfilingIntegration.sendMetrics": "To send <link>metrics</link>, use the <pre>send</pre> method with the following metric",
"mlProfilingIntegration.sendMetricsDataDescription": "<strong>data</strong> (dict): a dictionary of metric names and their respective values (note that metric data values should be numeric).",
"mlProfilingIntegration.settingModelVersion": "To set custom model version, use the <pre>set_model_version</pre> method with the following parameter",
"mlProfilingIntegration.settingModelVersion": "To set custom model version, use the <pre>model_version</pre> method with the following parameter",
"mlProfilingIntegration.settingModelVersionAlias": "To set model version alias, use the <pre>model_version_alias</pre> method with the following parameter",
"mlProfilingIntegration.settingModelVersionAliasTitle": "Setting model version alias",
"mlProfilingIntegration.settingModelVersionTag": "To add tags to a model version, use the <pre>model_version_tag</pre> method with the following parameters",
"mlProfilingIntegration.settingModelVersionTagKeyDescription": "<strong>key</strong> (str): tag name",
"mlProfilingIntegration.settingModelVersionTagTitle": "Setting model version tag",
"mlProfilingIntegration.settingModelVersionTagValueDescription": "<strong>value</strong> (str): tag value",
"mlProfilingIntegration.settingModelVersionTitle": "Setting model version",
"mlProfilingIntegration.settingModelVersionVersionAliasDescription": "<strong>alias</strong> (str): alias name",
"mlProfilingIntegration.settingModelVersionVersionDescription": "<strong>version</strong> (str): version name",
"mlProfilingIntegration.someCode": "some code",
"mlProfilingIntegration.taggingTaskRun": "Tagging task run",
Expand Down
Loading