Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Storage] az storage blob copy start/start-batch: Fix --auth-mode login #29964

Merged
merged 4 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions src/azure-cli/azure/cli/command_modules/storage/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def process_blob_source_uri(cmd, namespace):
if not sas:
prefix = cmd.command_kwargs['resource_type'].value[0]
if is_storagev2(prefix):
sas = create_short_lived_blob_sas_v2(cmd, source_account_name, source_account_key, container,
blob)
sas = create_short_lived_blob_sas_v2(cmd, source_account_name, container,
blob, account_key=source_account_key)
else:
sas = create_short_lived_blob_sas(cmd, source_account_name, source_account_key, container, blob)
query_params = []
Expand Down Expand Up @@ -409,8 +409,8 @@ def validate_source_uri(cmd, namespace): # pylint: disable=too-many-statements
dir_name, file_name)
elif valid_blob_source and (ns.get('share_name', None) or not same_account):
if is_storagev2(prefix):
source_sas = create_short_lived_blob_sas_v2(cmd, source_account_name, source_account_key, container,
blob)
source_sas = create_short_lived_blob_sas_v2(cmd, source_account_name, container,
blob, account_key=source_account_key)
else:
source_sas = create_short_lived_blob_sas(cmd, source_account_name, source_account_key, container, blob)

Expand All @@ -435,7 +435,8 @@ def validate_source_uri(cmd, namespace): # pylint: disable=too-many-statements


def validate_source_url(cmd, namespace): # pylint: disable=too-many-statements, too-many-locals
from .util import create_short_lived_blob_sas, create_short_lived_blob_sas_v2, create_short_lived_file_sas
from .util import create_short_lived_blob_sas, create_short_lived_blob_sas_v2, create_short_lived_file_sas, \
create_short_lived_file_sas_v2
from azure.cli.core.azclierror import InvalidArgumentValueError, RequiredArgumentMissingError, \
MutuallyExclusiveArgumentError
usage_string = \
Expand Down Expand Up @@ -463,6 +464,8 @@ def validate_source_url(cmd, namespace): # pylint: disable=too-many-statements,
source_account_name = ns.pop('source_account_name', None)
source_account_key = ns.pop('source_account_key', None)
source_sas = ns.pop('source_sas', None)
token_credential = ns.get('token_credential')
is_oauth = token_credential is not None

# source in the form of an uri
uri = ns.get('source_url', None)
Expand Down Expand Up @@ -499,7 +502,7 @@ def validate_source_url(cmd, namespace): # pylint: disable=too-many-statements,
# determine if the copy will happen in the same storage account
same_account = False

if not source_account_key and not source_sas:
if not source_account_key and not source_sas and not is_oauth:
if source_account_name == ns.get('account_name', None):
same_account = True
source_account_key = ns.get('account_key', None)
Expand All @@ -511,20 +514,41 @@ def validate_source_url(cmd, namespace): # pylint: disable=too-many-statements,
except ValueError:
raise RequiredArgumentMissingError('Source storage account {} not found.'.format(source_account_name))

# if oauth, use user delegation key to generate sas
source_user_delegation_key = None
if is_oauth:
client_kwargs = {'account_name': source_account_name,
'token_credential': token_credential}
if valid_blob_source:
client = cf_blob_service(cmd.cli_ctx, client_kwargs)

from datetime import datetime, timedelta
start = datetime.utcnow()
expiry = datetime.utcnow() + timedelta(days=1)
source_user_delegation_key = client.get_user_delegation_key(start, expiry)

# Both source account name and either key or sas (or both) are now available
if not source_sas:
prefix = cmd.command_kwargs['resource_type'].value[0]
# generate a sas token even in the same account when the source and destination are not the same kind.
if valid_file_source and (ns.get('container_name', None) or not same_account):
dir_name, file_name = os.path.split(path) if path else (None, '')
source_sas = create_short_lived_file_sas(cmd, source_account_name, source_account_key, share,
dir_name, file_name)
if dir_name == '':
dir_name = None
if is_storagev2(prefix):
source_sas = create_short_lived_file_sas_v2(cmd, source_account_name, source_account_key, share,
dir_name, file_name)
else:
source_sas = create_short_lived_file_sas(cmd, source_account_name, source_account_key, share,
dir_name, file_name)
elif valid_blob_source and (ns.get('share_name', None) or not same_account):
prefix = cmd.command_kwargs['resource_type'].value[0]
# is_storagev2() is used to distinguish if the command is in track2 SDK
# If yes, we will use get_login_credentials() as token credential
if is_storagev2(prefix):
source_sas = create_short_lived_blob_sas_v2(cmd, source_account_name, source_account_key, container,
blob)
source_sas = create_short_lived_blob_sas_v2(cmd, source_account_name, container, blob,
account_key=source_account_key,
user_delegation_key=source_user_delegation_key)
else:
source_sas = create_short_lived_blob_sas(cmd, source_account_name, source_account_key, container, blob)

Expand Down Expand Up @@ -1069,6 +1093,8 @@ def get_source_file_or_blob_service_client_track2(cmd, namespace):
source_sas = ns.get('source_sas', None)
source_container = ns.get('source_container', None)
source_share = ns.get('source_share', None)
token_credential = ns.get('token_credential')
is_oauth = token_credential is not None

if source_uri and source_account:
raise ValueError(usage_string)
Expand All @@ -1090,13 +1116,13 @@ def get_source_file_or_blob_service_client_track2(cmd, namespace):

source_account, source_key, source_sas = ns['account_name'], ns['account_key'], ns['sas_token']

if source_account:
if source_account and not is_oauth:
if not (source_key or source_sas):
# when neither storage account key nor SAS is given, try to fetch the key in the current
# subscription
source_key = _query_account_key(cmd.cli_ctx, source_account)

elif source_uri:
elif source_uri and not is_oauth:
if source_key or source_container or source_share:
raise ValueError(usage_string)

Expand Down Expand Up @@ -1125,7 +1151,7 @@ def get_source_file_or_blob_service_client_track2(cmd, namespace):
ns['source_container'] = source_container
ns['source_share'] = source_share
# get sas token for source
if not source_sas:
if not source_sas and not is_oauth:
from .util import create_short_lived_container_sas_track2, create_short_lived_share_sas_track2
if source_container:
source_sas = create_short_lived_container_sas_track2(cmd, account_name=source_account,
Expand All @@ -1139,6 +1165,8 @@ def get_source_file_or_blob_service_client_track2(cmd, namespace):
client_kwargs = {'account_name': ns['source_account_name'],
'account_key': ns['source_account_key'],
'sas_token': ns['source_sas']}
if is_oauth:
client_kwargs.update({'token_credential': token_credential})
if source_container:
ns['source_client'] = cf_blob_service(cmd.cli_ctx, client_kwargs)
if source_share:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,14 @@ def create_blob_url(client, container_name, blob_name, snapshot, protocol='https
def _copy_blob_to_blob_container(cmd, blob_service, source_blob_service, destination_container, destination_path,
source_container, source_blob_name, source_sas, **kwargs):
t_blob_client = cmd.get_models('_blob_client#BlobClient')
# generate sas for oauth copy source
if not source_sas:
from ..util import create_short_lived_blob_sas_v2
start = datetime.utcnow()
expiry = datetime.utcnow() + timedelta(hours=1)
source_user_delegation_key = source_blob_service.get_user_delegation_key(start, expiry)
source_sas = create_short_lived_blob_sas_v2(cmd, source_blob_service.account_name, source_container,
source_blob_name, user_delegation_key=source_user_delegation_key)
source_client = t_blob_client(account_url=source_blob_service.url, container_name=source_container,
blob_name=source_blob_name, credential=source_sas)
source_blob_url = source_client.url
Expand All @@ -931,7 +939,10 @@ def _copy_file_to_blob_container(blob_service, source_file_service, destination_
source_share, source_sas, source_file_dir, source_file_name):
t_share_client = source_file_service.get_share_client(source_share)
t_file_client = t_share_client.get_file_client(os.path.join(source_file_dir, source_file_name))
source_file_url = '{}?{}'.format(t_file_client.url, source_sas)
if '?' not in t_file_client.url:
source_file_url = '{}?{}'.format(t_file_client.url, source_sas)
else:
source_file_url = t_file_client.url

source_path = os.path.join(source_file_dir, source_file_name) if source_file_dir else source_file_name
destination_blob_name = normalize_blob_file_path(destination_path, source_path)
Expand Down
Loading