diff --git a/aiida_pseudo/cli/install.py b/aiida_pseudo/cli/install.py index b2ed365..d07192d 100644 --- a/aiida_pseudo/cli/install.py +++ b/aiida_pseudo/cli/install.py @@ -26,10 +26,13 @@ def cmd_install(): @click.argument('label', type=click.STRING) @options_core.DESCRIPTION(help='Description for the family.') @options.ARCHIVE_FORMAT() +@options.FAMILY_TYPE( + type=types.PseudoPotentialFamilyTypeParam(exclude=('pseudo.family.sssp', 'pseudo.family.pseudo_dojo')) +) @options.PSEUDO_TYPE() @options.TRACEBACK() @decorators.with_dbenv() -def cmd_install_family(archive, label, description, archive_format, pseudo_type, traceback): # pylint: disable=too-many-arguments +def cmd_install_family(archive, label, description, archive_format, family_type, pseudo_type, traceback): # pylint: disable=too-many-arguments """Install a standard pseudopotential family from an ARCHIVE. The ARCHIVE can be a (compressed) archive of a directory containing the pseudopotentials on the local file system or @@ -51,16 +54,15 @@ def cmd_install_family(archive, label, description, archive_format, pseudo_type, pseudopotential type, the format of the file is unknown and the family requires the element to be known, which in this case can then only be parsed from the filename. """ - from aiida_pseudo.groups.family.pseudo import PseudoPotentialFamily from .utils import attempt, create_family_from_archive if isinstance(archive, pathlib.Path) and archive.is_dir(): with attempt(f'creating a pseudopotential family from directory `{archive}`...', include_traceback=traceback): - family = PseudoPotentialFamily.create_from_folder(archive, label, pseudo_type=pseudo_type) + family = family_type.create_from_folder(archive, label, pseudo_type=pseudo_type) elif isinstance(archive, pathlib.Path) and archive.is_file(): with attempt('unpacking archive and parsing pseudos... ', include_traceback=traceback): family = create_family_from_archive( - PseudoPotentialFamily, label, archive, fmt=archive_format, pseudo_type=pseudo_type + family_type, label, archive, fmt=archive_format, pseudo_type=pseudo_type ) else: # At this point, we can assume that it is not a valid filepath on disk, but rather a URL and the ``archive`` @@ -79,11 +81,7 @@ def cmd_install_family(archive, label, description, archive_format, pseudo_type, with attempt('unpacking archive and parsing pseudos... ', include_traceback=traceback): family = create_family_from_archive( - PseudoPotentialFamily, - label, - pathlib.Path(handle.name), - fmt=archive_format, - pseudo_type=pseudo_type + family_type, label, pathlib.Path(handle.name), fmt=archive_format, pseudo_type=pseudo_type ) family.description = description diff --git a/aiida_pseudo/cli/params/types.py b/aiida_pseudo/cli/params/types.py index afebb95..2f0b4c5 100644 --- a/aiida_pseudo/cli/params/types.py +++ b/aiida_pseudo/cli/params/types.py @@ -62,6 +62,14 @@ class PseudoPotentialFamilyTypeParam(click.ParamType): name = 'pseudo_family_type' + def __init__(self, exclude: typing.Optional[typing.List[str]] = None, **kwargs): + """Construct the parameter. + + :param exclude: an optional list of values that should be considered invalid and will raise ``BadParameter``. + """ + super().__init__(**kwargs) + self.exclude = exclude + def convert(self, value, _, __): """Convert the entry point name to the corresponding class. @@ -78,6 +86,9 @@ def convert(self, value, _, __): except exceptions.EntryPointError as exception: raise click.BadParameter(f'`{value}` is not an existing group plugin.') from exception + if self.exclude and value in self.exclude: + raise click.BadParameter(f'`{value}` is not an accepted value for this option.') + if not issubclass(family_type, PseudoPotentialFamily): raise click.BadParameter(f'`{value}` entry point is not a subclass of `PseudoPotentialFamily`.')