diff --git a/.gitignore b/.gitignore index 736d07fe..bec8053b 100644 --- a/.gitignore +++ b/.gitignore @@ -80,4 +80,7 @@ example/data test-model # Testing -sample/ \ No newline at end of file +sample/ + +tests/test-sample-app +tests/test-sample-site \ No newline at end of file diff --git a/docs/model.md b/docs/model.md index c9d56181..ce4c713a 100644 --- a/docs/model.md +++ b/docs/model.md @@ -202,6 +202,9 @@ model: oarepo:use: ./common.yaml#person ``` +**Note:** If you reference multiple files via ``oarepo:use: ["a", "b"]`` and both +include the same property, the value from ``a`` will be used. + ## Built-in extensions ### `oarepo:mapping` - elasticsearch definition @@ -253,6 +256,16 @@ This customizes how the marshmallow field will be generated. The section might c generate `from import as generate a nested schema class, otherwise do not generate the class + * `field` is set -> use this field directly as is + * `class` is set -> generate/use class with this name. If not set, infer the class name from the context + * `read=false, write=false` - do not generate the field at all + Example: diff --git a/oarepo_model_builder/builtin_models/invenio.json b/oarepo_model_builder/builtin_models/invenio.json index df6da841..5a1cbfd2 100644 --- a/oarepo_model_builder/builtin_models/invenio.json +++ b/oarepo_model_builder/builtin_models/invenio.json @@ -1,29 +1,46 @@ { "model": { + "oarepo:marshmallow": { + "base-classes": [ + "invenio_records_resources.services.records.schema.BaseRecordSchema" + ] + }, "properties": { "id": { "type": "keyword", "oarepo:sample": { "skip": true + }, + "oarepo:marshmallow": { + "read": false, + "write": false } }, "created": { "type": "date", "oarepo:sample": { "skip": true + }, + "oarepo:marshmallow": { + "write": false, + "read": true } }, "updated": { "type": "date", "oarepo:sample": { "skip": true + }, + "oarepo:marshmallow": { + "write": false, + "read": true } }, "$schema": { "type": "keyword", "oarepo:marshmallow": { - "field_name": "_schema", - "field_args": "data_key='$schema'" + "read": false, + "write": false }, "oarepo:sample": { "skip": true diff --git a/oarepo_model_builder/entrypoints.py b/oarepo_model_builder/entrypoints.py index 69a9bce4..fd579c99 100644 --- a/oarepo_model_builder/entrypoints.py +++ b/oarepo_model_builder/entrypoints.py @@ -1,10 +1,13 @@ import sys +from functools import reduce +from importlib import import_module from pathlib import Path -import pkg_resources +import importlib.metadata +import importlib.resources from oarepo_model_builder.builder import ModelBuilder -from oarepo_model_builder.schema import ModelSchema +from oarepo_model_builder.schema import ModelSchema, remove_star_keys from oarepo_model_builder.utils.hyphen_munch import HyphenMunch @@ -31,40 +34,50 @@ def create_builder_from_entrypoints(**kwargs): def load_entry_points_dict(name): - return {ep.name: ep.load() for ep in pkg_resources.iter_entry_points(group=name)} + return {ep.name: ep.load() for ep in importlib.metadata.entry_points().select(group=name)} def load_entry_points_list(name): - ret = [(ep.name, ep.load()) for ep in pkg_resources.iter_entry_points(group=name)] + ret = [(ep.name, ep.load()) for ep in importlib.metadata.entry_points().select(group=name)] ret.sort() return [x[1] for x in ret] -def load_model_from_entrypoint(ep: pkg_resources.EntryPoint): +def load_model_from_entrypoint(ep: importlib.metadata.EntryPoint): def load(schema): - filename = ".".join(ep.attrs) - data = pkg_resources.resource_string(ep.module_name, filename) - return schema._load(filename, content=data) + try: + loaded_schema = ep.load() + except: + module = import_module(ep.module) + split_attr = ep.attr.split('.') + fn = f'{split_attr[-2]}.{split_attr[-1]}' + if len(split_attr) > 2: + fn = reduce(lambda x, y: Path(x) / Path(y), split_attr[:-2]) / fn + content = importlib.resources.open_text(module, fn, encoding='utf-8').read() + loaded_schema = schema._load(fn, content=content) + + remove_star_keys(loaded_schema) + return loaded_schema return load def load_included_models_from_entry_points(): ret = {} - for ep in pkg_resources.iter_entry_points(group="oarepo.models"): + for ep in importlib.metadata.entry_points().select(group="oarepo.models"): ret[ep.name] = load_model_from_entrypoint(ep) return ret def load_model( - model_filename, - package=None, - configs=(), - black=True, - isort=True, - sets=(), - model_content=None, - extra_included=None, + model_filename, + package=None, + configs=(), + black=True, + isort=True, + sets=(), + model_content=None, + extra_included=None, ): loaders = load_entry_points_dict("oarepo_model_builder.loaders") included_models = load_included_models_from_entry_points() @@ -114,11 +127,11 @@ def check_plugin_packages(schema): unknown_packages = [rp for rp in required_packages if rp not in known_packages] if unknown_packages: if ( - input( - f'Required packages {", ".join(unknown_packages)} are missing. ' - f"Should I install them for you via pip install? (y/n) " - ) - == "y" + input( + f'Required packages {", ".join(unknown_packages)} are missing. ' + f"Should I install them for you via pip install? (y/n) " + ) + == "y" ): if subprocess.call(["pip", "install", *unknown_packages]): sys.exit(1) diff --git a/oarepo_model_builder/invenio/invenio_record.py b/oarepo_model_builder/invenio/invenio_record.py index a9bce7e9..061f1321 100644 --- a/oarepo_model_builder/invenio/invenio_record.py +++ b/oarepo_model_builder/invenio/invenio_record.py @@ -1,7 +1,25 @@ from .invenio_base import InvenioBaseClassPythonBuilder +from ..builders import process class InvenioRecordBuilder(InvenioBaseClassPythonBuilder): TYPE = "invenio_record" class_config = "record-class" template = "record" + + def begin(self, schema, settings): + super().begin(schema, settings) + self.relations = [] + + @process("/model/**", condition=lambda current, stack: stack.schema_valid) + def enter_model_element(self): + self.build_children() + data = self.stack.top.data + if isinstance(data, dict) and 'invenio:relation' in data: + self.relations.append(data['invenio:relation']) + + def process_template(self, python_path, template, **extra_kwargs): + return super().process_template(python_path, template, **{ + **extra_kwargs, + 'invenio_relations': self.relations + }) diff --git a/oarepo_model_builder/invenio/invenio_record_schema.py b/oarepo_model_builder/invenio/invenio_record_schema.py index 4d458a6a..6ae7e7b7 100644 --- a/oarepo_model_builder/invenio/invenio_record_schema.py +++ b/oarepo_model_builder/invenio/invenio_record_schema.py @@ -50,6 +50,12 @@ def enter_model_element(self): definition = None recurse = True + if isinstance(self.stack.top.data, dict): + definition = self.stack.top.data.get(OAREPO_MARSHMALLOW_PROPERTY, {}) + generate_key = definition.get('read', True) or definition.get('write', True) + if not generate_key: + return + if schema_element_type == "properties": parent = self.stack[-2].data definition = parent.get(OAREPO_MARSHMALLOW_PROPERTY, {}) @@ -58,12 +64,18 @@ def enter_model_element(self): if "nested" not in definition: definition["nested"] = True - generate_schema_class = definition.get("generate", True) + generate_schema_class = definition.get('generate') schema_class = None # to make pycharm happy schema_class_base_classes = None + if generate_schema_class: if "class" not in definition: - definition["class"] = self.stack.top.key.title() + for se in reversed(self.stack.stack): + if se.schema_element_type == 'property': + definition["class"] = se.key.title() + break + else: + definition["class"] = self.stack.top.key.title() if "class" in definition: schema_class = definition["class"] if "." not in schema_class: @@ -194,12 +206,18 @@ def create_field(field_type, options=(), validators=(), definition=None): validators = [*validators, *definition.get("validators", [])] nested = definition.get("nested", False) required = definition.get('required', False) + read = definition.get('read', True) + write = definition.get('write', True) list_nested = definition.get("list_nested", False) if validators: opts.append(f'validate=[{",".join(validators)}]') if required: opts.append(f'required=' + str(required)) + if not read and write: + opts.append('load_only=True') + if not write and read: + opts.append('dump_only=True') kwargs = definition.get("field_args", "") if kwargs and opts: kwargs = ", " + kwargs @@ -208,10 +226,20 @@ def create_field(field_type, options=(), validators=(), definition=None): else: ret = f"{field_type}()" if nested: - if opts or kwargs: - ret = f'ma_fields.Nested(lambda: {ret}, {", ".join(opts)}{kwargs})' + if ret.endswith('()'): + ret = ret[:-2] + else: + ret = f'lamda: {ret}' + if isinstance(nested, str): + if opts or kwargs: + ret = f'{nested}({ret}, {", ".join(opts)}{kwargs})' + else: + ret = f"{nested}({ret})" else: - ret = f"ma_fields.Nested(lambda: {ret})" + if opts or kwargs: + ret = f'ma_fields.Nested({ret}, {", ".join(opts)}{kwargs})' + else: + ret = f"ma_fields.Nested({ret})" if list_nested: if opts or kwargs: ret = f'ma_fields.List(ma_fields.Nested(lambda: {ret}, {", ".join(opts)}{kwargs}))' @@ -240,10 +268,12 @@ def marshmallow_boolean_generator(data, definition, schema, imports): validators = [] return create_field("ma_fields.Boolean", [], validators, definition) + def marshmallow_raw_generator(data, definition, schema, imports): validators = [] return create_field("ma_fields.Raw", [], validators, definition) + def marshmallow_generic_number_generator(datatype, data, definition, schema, imports): validators = definition.get('validators', []) if validators != []: diff --git a/oarepo_model_builder/invenio/templates/invenio_record.py.jinja2 b/oarepo_model_builder/invenio/templates/invenio_record.py.jinja2 index aa202fd8..31928a82 100644 --- a/oarepo_model_builder/invenio/templates/invenio_record.py.jinja2 +++ b/oarepo_model_builder/invenio/templates/invenio_record.py.jinja2 @@ -1,5 +1,6 @@ -from invenio_records.systemfields import ConstantField -from invenio_records_resources.records.systemfields import IndexField +from invenio_records.systemfields import ConstantField, RelationsField +from invenio_records_resources.records.systemfields import IndexField, + from invenio_records_resources.records.systemfields.pid import PIDField, PIDFieldContext from invenio_pidstore.providers.recordid_v2 import RecordIdProviderV2 from invenio_records_resources.records.api import Record as InvenioBaseRecord @@ -10,16 +11,41 @@ from {{ b|package_name }} import {{ b|base_name }} from {{ python.record_metadata_class|package_name }} import {{ python.record_metadata_class|base_name }} from {{ python.record_dumper_class|package_name }} import {{ python.record_dumper_class|base_name }} +{% for rel in invenio_relations %} +{% for imp in rel.imports or [] %} +import {{ imp }} +{% endfor %} +{% endfor %} +{% if invenio_relations %} +from invenio_records.dumpers.relations import RelationDumperExt +{% endif %} class {{ python.record_class|base_name }}({% for b in python.record_bases %}{{ b|base_name }}, {% endfor %}InvenioBaseRecord): model_cls = {{ python.record_metadata_class|base_name }} schema = ConstantField("$schema", "{{ settings.schema_server }}{{ settings.schema_name }}") index = IndexField("{{ settings.index_name }}") +{% if python.generate_record_pid_field %} pid = PIDField( create=True, provider=RecordIdProviderV2, context_cls = PIDFieldContext ) - dumper_extensions = [] +{% endif %} + dumper_extensions = [ + {%- for ext in python.record_dumper_extensions %}{{ ext }}{% if not loop.last %}, {% endif %}{% endfor -%} + {% if invenio_relations %} + RelationDumperExt("relations"), + {% endif %} + ] dumper = {{ python.record_dumper_class|base_name }}(extensions=dumper_extensions) + {% if invenio_relations %} + relations = RelationsField( + {% for rel in invenio_relations %} + {{ rel.name }}={{ rel.type }}( + {% for param in rel.params %}{{ param }}, + {% endfor %} + ), + {% endfor %} + ) + {% endif %} \ No newline at end of file diff --git a/oarepo_model_builder/invenio/templates/invenio_record_schema.py.jinja2 b/oarepo_model_builder/invenio/templates/invenio_record_schema.py.jinja2 index a4ed6270..8171fc96 100644 --- a/oarepo_model_builder/invenio/templates/invenio_record_schema.py.jinja2 +++ b/oarepo_model_builder/invenio/templates/invenio_record_schema.py.jinja2 @@ -2,8 +2,8 @@ from invenio_records_resources.services.records.schema import BaseRecordSchema a from marshmallow import ValidationError from marshmallow import validates as ma_validates -{% for b in python.record_schema_bases %} - from {{ b|package_name }} import {{ b|base_name }} +{% for b in schema_bases %} +from {{ b|package_name }} import {{ b|base_name }} {% endfor %} {% include "imports" %} diff --git a/oarepo_model_builder/invenio/templates/invenio_record_search_options.py.jinja2 b/oarepo_model_builder/invenio/templates/invenio_record_search_options.py.jinja2 index 3c433b40..31ad05a8 100644 --- a/oarepo_model_builder/invenio/templates/invenio_record_search_options.py.jinja2 +++ b/oarepo_model_builder/invenio/templates/invenio_record_search_options.py.jinja2 @@ -20,18 +20,11 @@ class {{ python.record_search_options_class|base_name }}({% for b in python.reco {% endfor %} } sort_options = { - "bestmatch": dict( - title=_('Best match'), - fields=['_score'], # ES defaults to desc on `_score` field - ), - "newest": dict( - title=_('Newest'), - fields=['-created'], - ), - "oldest": dict( - title=_('Oldest'), - fields=['created'], - ), + {% if python.record_search_options_bases %} + **{{ python.record_search_options_bases[0]|base_name }}.sort_options, + {% else %} + **InvenioSearchOptions.sort_options, + {% endif %} {% for dict in sort_definition %} {% for key, value in dict.items()%} '{{ key }}': {{ value }}, diff --git a/oarepo_model_builder/invenio/templates/invenio_record_service_config.py.jinja2 b/oarepo_model_builder/invenio/templates/invenio_record_service_config.py.jinja2 index f3666bf8..2721387c 100644 --- a/oarepo_model_builder/invenio/templates/invenio_record_service_config.py.jinja2 +++ b/oarepo_model_builder/invenio/templates/invenio_record_service_config.py.jinja2 @@ -1,11 +1,12 @@ from invenio_records_resources.services import RecordServiceConfig as InvenioRecordServiceConfig from invenio_records_resources.services import RecordLink, pagination_links -from invenio_records_resources.services.records.components import DataComponent, MetadataComponent - {% for b in python.record_service_config_bases %} from {{ b|package_name }} import {{ b|base_name }} {% endfor %} +{% for b in python.record_service_config_components %} +from {{ b|package_name }} import {{ b|base_name }} +{% endfor %} from {{ python.record_class|package_name }} import {{ python.record_class|base_name }} from {{ python.record_permissions_class|package_name }} import {{ python.record_permissions_class|base_name }} from {{ python.record_schema_class|package_name }} import {{ python.record_schema_class|base_name }} @@ -21,13 +22,14 @@ class {{ python.record_service_config_class|base_name }}({% for b in python.reco record_cls = {{ python.record_class|base_name }} {% if python.record_service_config_bases %} - components = {% for b in python.record_service_config_bases[:1] %}{{ b|base_name }}.components{% endfor %} + components = [ *{% for b in python.record_service_config_bases[:1] %}{{ b|base_name }}.components{% endfor %}{% for c in python.record_service_config_components %}, {{ c|base_name }} {% endfor %}] {% else %} - components = [ *InvenioRecordServiceConfig.components ] + components = [ *InvenioRecordServiceConfig.components{% for c in python.record_service_config_components %}, {{ c|base_name }} {% endfor %} ] {% endif %} model = "{{ settings.model_name }}" + {% if python.record_service_config_generate_links %} @property def links_item(self): return { @@ -35,4 +37,4 @@ class {{ python.record_service_config_class|base_name }}({% for b in python.reco } links_search = pagination_links("{{ settings.collection_url }}{?args*}") - + {% endif %} diff --git a/oarepo_model_builder/invenio/templates/invenio_views.py.jinja2 b/oarepo_model_builder/invenio/templates/invenio_views.py.jinja2 index 4fca3354..50de0744 100644 --- a/oarepo_model_builder/invenio/templates/invenio_views.py.jinja2 +++ b/oarepo_model_builder/invenio/templates/invenio_views.py.jinja2 @@ -1,3 +1,19 @@ def {{ python.create_blueprint_from_app|base_name }}(app): """Create {{ python.package_name }} blueprint.""" - return app.extensions["{{ python.flask_extension_name }}"].resource.as_blueprint() + blueprint = app.extensions["{{ python.flask_extension_name }}"].resource.as_blueprint() + blueprint.record_once(init) + return blueprint + + +def init(state): + """Init app.""" + app = state.app + ext = app.extensions["{{ python.flask_extension_name }}"] + + # register service + sregistry = app.extensions["invenio-records-resources"].registry + sregistry.register(ext.service, service_id="{{ python.flask_extension_name }}") + + # Register indexer + iregistry = app.extensions["invenio-indexer"].registry + iregistry.register(ext.service.indexer, indexer_id="{{ python.flask_extension_name }}") diff --git a/oarepo_model_builder/model_preprocessors/default_values.py b/oarepo_model_builder/model_preprocessors/default_values.py index 1bc41d87..a70022f0 100644 --- a/oarepo_model_builder/model_preprocessors/default_values.py +++ b/oarepo_model_builder/model_preprocessors/default_values.py @@ -70,7 +70,7 @@ def c(): "records", "mappings", "v7", - settings.package_base, + settings.package, settings.schema_name, ), ) @@ -81,7 +81,7 @@ def c(): self.set( settings, "index-name", - lambda: settings.package_base + "-" + + lambda: settings.package + "-" + os.path.basename(settings.mapping_file).replace(".json", ""), ) diff --git a/oarepo_model_builder/model_preprocessors/invenio.py b/oarepo_model_builder/model_preprocessors/invenio.py index 29172433..8ba8d1b5 100644 --- a/oarepo_model_builder/model_preprocessors/invenio.py +++ b/oarepo_model_builder/model_preprocessors/invenio.py @@ -15,7 +15,7 @@ def transform(self, schema, settings): settings, { "python": { - "record_prefix": camel_case(settings.package.rsplit(".", maxsplit=1)[-1]), + "record-prefix": camel_case(settings.package.rsplit(".", maxsplit=1)[-1]), # just make sure that the templates is always there "templates": {}, "marshmallow": {"mapping": {}}, @@ -155,6 +155,7 @@ def transform(self, schema, settings): "record-service-config-bases", lambda: [] ) + settings.python.setdefault('record-service-config-generate-links', True) # - schema self.set( settings.python, @@ -232,6 +233,10 @@ def transform(self, schema, settings): "generate": True, }, ) + + settings.python.setdefault("generate-record-pid-field", True) + settings.python.setdefault("record-dumper-extensions", []) + # default import prefixes settings.python.setdefault("always-defined-import-prefixes", []).extend(["ma", "ma_fields", "ma_valid"]) diff --git a/oarepo_model_builder/schema.py b/oarepo_model_builder/schema.py index 8455dd32..ddef75aa 100644 --- a/oarepo_model_builder/schema.py +++ b/oarepo_model_builder/schema.py @@ -15,11 +15,11 @@ class ModelSchema: OAREPO_USE = "oarepo:use" def __init__( - self, - file_path, - content=None, - included_models: Dict[str, Callable] = None, - loaders=None, + self, + file_path, + content=None, + included_models: Dict[str, Callable] = None, + loaders=None, ): """ Creates and parses model schema @@ -41,6 +41,9 @@ def __init__( self._resolve_references(self.schema, []) + # any star keys should be kept + use_star_keys(self.schema) + self.schema.setdefault("settings", {}) self.schema = munch.munchify(self.schema, factory=HyphenMunch) @@ -121,7 +124,7 @@ def _resolve_references(self, element, stack): if not name: raise IncludedFileNotFoundException(f"No file for oarepo:include at path {'/'.join(stack)}") included_data = self._load_included_file(name) - deepmerge(element, included_data, []) + deepmerge(element, included_data, [], listmerge="keep") return self._resolve_references(element, stack) for k, v in element.items(): self._resolve_references(v, stack + [k]) @@ -147,3 +150,28 @@ def resolve_id(json, element_id): ret = resolve_id(k, element_id) if ret is not None: return ret + + +def remove_star_keys(schema): + if isinstance(schema, dict): + for k, v in list(schema.items()): + if k.startswith('*'): + del schema[k] + else: + remove_star_keys(v) + elif isinstance(schema, (list, tuple)): + for k in schema: + remove_star_keys(k) + + +def use_star_keys(schema): + if isinstance(schema, dict): + for k, v in list(schema.items()): + if k.startswith('*'): + del schema[k] + schema[k[1:]] = v + for v in schema.values(): + use_star_keys(v) + elif isinstance(schema, (list, tuple)): + for k in schema: + use_star_keys(k) diff --git a/oarepo_model_builder/stack/schema.py b/oarepo_model_builder/stack/schema.py index f918792c..9d88a453 100644 --- a/oarepo_model_builder/stack/schema.py +++ b/oarepo_model_builder/stack/schema.py @@ -115,7 +115,7 @@ def __repr__(self): Ref.refs["additionalProperties"] = DictValidator(primitives="type") -Ref.refs["propertyNames"] = (DictValidator(primitives="pattern"),) +Ref.refs["propertyNames"] = (DictValidator(primitives="pattern")) Ref.refs["properties"] = AnyKeyDictValidator(Ref("property", "type")) diff --git a/oarepo_model_builder/stack/stack.py b/oarepo_model_builder/stack/stack.py index 5b756d81..12efad2d 100644 --- a/oarepo_model_builder/stack/stack.py +++ b/oarepo_model_builder/stack/stack.py @@ -57,7 +57,10 @@ def push(self, key, el): if not self.stack: entry = ModelBuilderStackEntry(key, el, model_paths) else: - entry = ModelBuilderStackEntry(key, el, self.top.schema.get(key)) + try: + entry = ModelBuilderStackEntry(key, el, self.top.schema.get(key)) + except Exception as e: + print(self.top.schema.get(key)) self.stack.append(entry) def pop(self): diff --git a/oarepo_model_builder/templates/__init__.py b/oarepo_model_builder/templates/__init__.py index 6321f208..3ed53f4c 100644 --- a/oarepo_model_builder/templates/__init__.py +++ b/oarepo_model_builder/templates/__init__.py @@ -1,14 +1,15 @@ +from importlib.metadata import entry_points from pathlib import Path -import pkg_resources class TemplateRegistry: def __init__(self): self.mapping = {} + eps = entry_points() for ep in reversed( sorted( - pkg_resources.iter_entry_points("oarepo_model_builder.templates"), + eps.select(group="oarepo_model_builder.templates"), key=lambda ep: ep.name, ) ): diff --git a/oarepo_model_builder/utils/cst/collections.py b/oarepo_model_builder/utils/cst/collections.py index 5bb462dd..d7895472 100644 --- a/oarepo_model_builder/utils/cst/collections.py +++ b/oarepo_model_builder/utils/cst/collections.py @@ -35,11 +35,17 @@ class DictMerger(MergerBase): def identity(self, context: PythonContext, node): return node + def get_key(self, context, el): + if hasattr(el, 'key') and el.key: + return el.key.value + # TODO: StarredDictElement might contain trailing comma, should remove it + return context.to_source_code(el) + def merge_internal(self, context: PythonContext, existing_node, new_node): ret = [] mergers = expression_mergers - existing_elements = {el.key.value: el for el in existing_node.elements} - new_elements = {el.key.value: el for el in new_node.elements} + existing_elements = {self.get_key(context, el): el for el in existing_node.elements} + new_elements = {self.get_key(context, el): el for el in new_node.elements} for k, el in existing_elements.items(): merger = mergers.get(type(el), IdentityMerger()) if k not in new_elements: diff --git a/oarepo_model_builder/utils/cst/mergers.py b/oarepo_model_builder/utils/cst/mergers.py index 30a5bb60..7e167de6 100644 --- a/oarepo_model_builder/utils/cst/mergers.py +++ b/oarepo_model_builder/utils/cst/mergers.py @@ -17,7 +17,7 @@ Pass, SimpleStatementLine, SimpleString, - Tuple, + Tuple, StarredElement, ) @@ -65,7 +65,7 @@ def expression_mergers(): from oarepo_model_builder.utils.cst.collections import DictMerger, ElementMerger, ListMerger from .call import CallMerger - from .simple_nodes import ExprMerger, IntegerMerger, NameMerger, SimpleStringMerger + from .simple_nodes import ExprMerger, IntegerMerger, NameMerger, SimpleStringMerger, StarredElementMerger return { Call: CallMerger(), @@ -77,4 +77,5 @@ def expression_mergers(): Name: NameMerger(), Expr: ExprMerger(), Dict: DictMerger(), + StarredElement: StarredElementMerger() } diff --git a/oarepo_model_builder/utils/cst/simple_nodes.py b/oarepo_model_builder/utils/cst/simple_nodes.py index 47aa6a20..b628d684 100644 --- a/oarepo_model_builder/utils/cst/simple_nodes.py +++ b/oarepo_model_builder/utils/cst/simple_nodes.py @@ -1,6 +1,6 @@ import logging -from libcst import Assign, Integer +from libcst import Assign, Integer, Name from .common import IdentityBaseMerger, IdentityMerger, MergerBase, PythonContext from .mergers import expression_mergers, simple_line_mergers @@ -46,6 +46,10 @@ class NameMerger(IdentityBaseMerger): pass +class StarredElementMerger(IdentityBaseMerger): + pass + + class FunctionMerger(MergerBase): def merge_internal(self, context: PythonContext, existing_node, new_node): return existing_node diff --git a/oarepo_model_builder/utils/deepmerge.py b/oarepo_model_builder/utils/deepmerge.py index c2afafe4..c84e7b66 100644 --- a/oarepo_model_builder/utils/deepmerge.py +++ b/oarepo_model_builder/utils/deepmerge.py @@ -26,6 +26,9 @@ def deepmerge(target, source, stack=None, listmerge="overwrite"): target.append(source[idx]) elif listmerge == "extend": target.extend(source) + elif listmerge == "keep": + if len(source) > len(target): + target.extend(source[len(target):]) else: - raise AttributeError('listmerge must be one of "overwrite" or "extend"') + raise AttributeError('listmerge must be one of "overwrite", "extend" or "keep"') return target diff --git a/oarepo_model_builder/validation/schemas/common-schema.json5 b/oarepo_model_builder/validation/schemas/common-schema.json5 index 2cb08c25..4f0e3b90 100644 --- a/oarepo_model_builder/validation/schemas/common-schema.json5 +++ b/oarepo_model_builder/validation/schemas/common-schema.json5 @@ -12,6 +12,9 @@ "settings": { "$ref": "#/$defs/settings" }, + "oarepo:mapping": { + "$ref": "#/$defs/oarepo-mapping-root" + }, "model": { "$ref": "#/$defs/model" }, diff --git a/oarepo_model_builder/validation/schemas/mapping.json5 b/oarepo_model_builder/validation/schemas/mapping.json5 index 6c3203a7..d75ef944 100644 --- a/oarepo_model_builder/validation/schemas/mapping.json5 +++ b/oarepo_model_builder/validation/schemas/mapping.json5 @@ -6,5 +6,9 @@ "additionalProperties": true } } + }, + "oarepo-mapping-root": { + "type": "object", + "additionalProperties": true } } \ No newline at end of file diff --git a/oarepo_model_builder/validation/schemas/marshmallow.json5 b/oarepo_model_builder/validation/schemas/marshmallow.json5 index ff56d509..f38da03a 100644 --- a/oarepo_model_builder/validation/schemas/marshmallow.json5 +++ b/oarepo_model_builder/validation/schemas/marshmallow.json5 @@ -20,6 +20,12 @@ "generate": { "type": "boolean" }, + "generate-class": { + "type": "boolean" + }, + "generate-field": { + "type": "boolean" + }, "base-schema": { "type": "string" }, diff --git a/oarepo_model_builder/validation/schemas/modelschema.json5 b/oarepo_model_builder/validation/schemas/modelschema.json5 index a045afa3..e0a7680e 100644 --- a/oarepo_model_builder/validation/schemas/modelschema.json5 +++ b/oarepo_model_builder/validation/schemas/modelschema.json5 @@ -16,7 +16,8 @@ } }, "unevaluatedProperties": false, - "minProperties": 1 // empty object is not allowed in the current schema + "minProperties": 1 + // empty object is not allowed in the current schema }, "jsonschema-property": { "type": "object", @@ -74,6 +75,22 @@ }, "enum": { "type": "array" + }, + "additionalProperties": { + "type": "object", + "properties": { + "type": { + "type": "string" + } + } + }, + "propertyNames": { + "type": "object", + "properties": { + "pattern": { + "type": "string" + } + } } } } diff --git a/oarepo_model_builder/validation/schemas/settings.json5 b/oarepo_model_builder/validation/schemas/settings.json5 index ddc47b8d..c4f61f8d 100644 --- a/oarepo_model_builder/validation/schemas/settings.json5 +++ b/oarepo_model_builder/validation/schemas/settings.json5 @@ -182,7 +182,7 @@ "record-service-config-class": { "type": "string" }, - "record_prefix": { + "record-prefix": { "type": "string" }, "script-import-sample-data-cli": { @@ -257,6 +257,24 @@ "type": "string" } }, + "record-service-config-components": { + "type": "array", + "items": { + "type": "string" + } + }, + "record-dumper-extensions": { + "type": "array", + "items": { + "type": "string" + } + }, + "generate-record-pid-field": { + type: "boolean" + }, + "record-service-config-generate-links": { + type: "boolean" + }, "templates": { "type": "object", "patternProperties": { diff --git a/pyproject.toml b/pyproject.toml index 3f26ee63..961f3ffe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "oarepo-model-builder" -version = "1.0.0.dev7" +version = "1.0.0.dev8" description = "An utility library that generates OARepo required data model files from a JSON specification file" authors = ["Miroslav Bauer ", "Miroslav Simek "] readme = "README.md" diff --git a/tests/test_builder_from_entrypoints.py b/tests/test_builder_from_entrypoints.py index 92551b02..b9a4f681 100644 --- a/tests/test_builder_from_entrypoints.py +++ b/tests/test_builder_from_entrypoints.py @@ -23,10 +23,12 @@ def test_include_invenio(): builder.build(schema, "") data = builder.filesystem.open(os.path.join("test", "services", "schema.py")).read() + print(data) assert re.sub(r"\s", "", data) == re.sub( r"\s", "", """ +from invenio_records_resources.services.records.schema import BaseRecordSchema import marshmallow as ma import marshmallow.fields as ma_fields import marshmallow.validate as ma_valid @@ -34,18 +36,14 @@ def test_include_invenio(): from marshmallow import ValidationError from marshmallow import validates as ma_validates -class TestSchema(ma.Schema, ): +class TestSchema(BaseRecordSchema, ): \"""TestSchema schema.\""" a = ma_fields.String() - id = ma_fields.String() + created = ma_fields.Date(dump_only=True) - created = ma_fields.Date() - - updated = ma_fields.Date() - - _schema = ma_fields.String(data_key='$schema') + updated = ma_fields.Date(dump_only=True) """, ) @@ -89,7 +87,7 @@ def test_incremental_builder(): schema = load_model( "test.yaml", "test", - model_content={"model": {"properties": {"a": {"type": "keyword"}}}}, + model_content={"oarepo:use": "invenio", "model": {"properties": {"a": {"type": "keyword"}}}}, isort=False, black=False, ) @@ -136,4 +134,4 @@ def test_incremental_builder(): for k, iteration_result in snapshot_1.items(): expected_result = snapshot_2[k] # normally handled by black - assert_python_equals(iteration_result.replace(",'_id'", ",\n'_id'"), expected_result) + assert_python_equals(iteration_result.replace(",'_id'", ",\n'_id'"), expected_result, f'File {k}') diff --git a/tests/test_marshmallow_builder.py b/tests/test_marshmallow_builder.py index 561e5638..60f727d9 100644 --- a/tests/test_marshmallow_builder.py +++ b/tests/test_marshmallow_builder.py @@ -107,7 +107,7 @@ def test_generate_nested_schema_same_file(fulltext_builder): """class TestSchema(ma.Schema, ): \"""TestSchema schema.\""" - a = ma_fields.Nested(lambda: B())""", + a = ma_fields.Nested(B)""", ) in re.sub(r"\s", "", data) ) @@ -140,7 +140,7 @@ def test_generate_nested_schema_different_file(fulltext_builder): """class TestSchema(ma.Schema, ): \"""TestSchema schema.\""" - a = ma_fields.Nested(lambda: B())""", + a = ma_fields.Nested(B)""", ) in re.sub(r"\s", "", data) ) @@ -168,7 +168,7 @@ def test_use_nested_schema_same_file(fulltext_builder): with fulltext_builder.filesystem.open(os.path.join("test", "services", "schema.py")) as f: data = f.read() - + print(data) assert ( re.sub( r"\s", @@ -176,7 +176,7 @@ def test_use_nested_schema_same_file(fulltext_builder): """class TestSchema(ma.Schema, ): \"""TestSchema schema.\""" - a = ma_fields.Nested(lambda: B())""", + a = ma_fields.Nested(B)""", ) in re.sub(r"\s", "", data) ) @@ -200,7 +200,7 @@ def test_use_nested_schema_different_file(fulltext_builder): with fulltext_builder.filesystem.open(os.path.join("test", "services", "schema.py")) as f: data = f.read() assert re.sub(r"\s", "", "from c import B") in re.sub(r"\s", "", data) - assert 'classTestSchema(ma.Schema,):"""TestSchemaschema."""a=ma_fields.Nested(lambda:B())' in re.sub( + assert 'classTestSchema(ma.Schema,):"""TestSchemaschema."""a=ma_fields.Nested(B)' in re.sub( r"\s", "", data ) @@ -226,7 +226,7 @@ def test_generate_nested_schema_array(fulltext_builder): data = f.read() assert 'classB(ma.Schema,):"""Bschema."""b=ma_fields.String()' in re.sub(r"\s", "", data) assert ( - 'classTestSchema(ma.Schema,):"""TestSchemaschema."""a=ma_fields.List(ma_fields.Nested(lambda:B()))' + 'classTestSchema(ma.Schema,):"""TestSchemaschema."""a=ma_fields.List(ma_fields.Nested(B))' in re.sub(r"\s", "", data) ) @@ -279,7 +279,7 @@ def test_generate_nested_schema_relative_same_package(fulltext_builder): """class TestSchema(ma.Schema, ): \"""TestSchema schema.\""" - a = ma_fields.Nested(lambda: B())""", + a = ma_fields.Nested(B)""", ) in re.sub(r"\s", "", data) ) @@ -318,7 +318,7 @@ def test_generate_nested_schema_relative_same_file(fulltext_builder): """class TestSchema(ma.Schema, ): \"""TestSchema schema.\""" - a = ma_fields.Nested(lambda: B())""", + a = ma_fields.Nested(B)""", ) in re.sub(r"\s", "", data) ) @@ -351,7 +351,7 @@ def test_generate_nested_schema_relative_same_package(fulltext_builder): """class TestSchema(ma.Schema, ): \"""TestSchema schema.\""" - a = ma_fields.Nested(lambda: B())""", + a = ma_fields.Nested(B)""", ) in re.sub(r"\s", "", data) ) @@ -390,7 +390,7 @@ def test_generate_nested_schema_relative_upper(fulltext_builder): """class TestSchema(ma.Schema, ): \"""TestSchema schema.\""" - a = ma_fields.Nested(lambda: B())""", + a = ma_fields.Nested(B)""", ) in re.sub(r"\s", "", data) ) diff --git a/tests/test_raw.py b/tests/test_raw.py index 95ecaef6..4fdea154 100644 --- a/tests/test_raw.py +++ b/tests/test_raw.py @@ -23,11 +23,12 @@ def test_raw_type(): builder.build(schema, "") data = builder.filesystem.open(os.path.join("test", "services", "schema.py")).read() - + print(data) assert re.sub(r"\s", "", data) == re.sub( r"\s", "", """ +from invenio_records_resources.services.records.schema import BaseRecordSchema import marshmallow as ma import marshmallow.fields as ma_fields import marshmallow.validate as ma_valid @@ -35,18 +36,15 @@ def test_raw_type(): from marshmallow import ValidationError from marshmallow import validates as ma_validates -class TestSchema(ma.Schema, ): +class TestSchema(BaseRecordSchema, ): \"""TestSchema schema.\""" a = ma_fields.Raw() - id = ma_fields.String() - - created = ma_fields.Date() + created = ma_fields.Date(dump_only=True) - updated = ma_fields.Date() - - _schema = ma_fields.String(data_key='$schema') + updated = ma_fields.Date(dump_only=True) + """, ) diff --git a/tests/test_schema_props.py b/tests/test_schema_props.py index 6dee738a..53f942cc 100644 --- a/tests/test_schema_props.py +++ b/tests/test_schema_props.py @@ -30,6 +30,7 @@ def test_enum(): r"\s", "", """ +from invenio_records_resources.services.records.schema import BaseRecordSchema import marshmallow as ma import marshmallow.fields as ma_fields import marshmallow.validate as ma_valid @@ -37,18 +38,14 @@ def test_enum(): from marshmallow import ValidationError from marshmallow import validates as ma_validates -class TestSchema(ma.Schema, ): +class TestSchema(BaseRecordSchema, ): \"""TestSchema schema.\""" a = ma_fields.String(validate=[ma_valid.OneOf(["a", "b", "c"])]) - - id = ma_fields.String() - - created = ma_fields.Date() - updated = ma_fields.Date() + created = ma_fields.Date(dump_only=True) - _schema = ma_fields.String(data_key='$schema') + updated = ma_fields.Date(dump_only=True) """, ) diff --git a/tests/test_search_options.py b/tests/test_search_options.py index d822d01e..08489630 100644 --- a/tests/test_search_options.py +++ b/tests/test_search_options.py @@ -158,62 +158,22 @@ def _(x): \"""Identity function for string extraction.\""" return x - - class TestSearchOptions(InvenioSearchOptions): \"""TestRecord search options.\""" facets = { - - 'a_keyword': facets.a_keyword, - - - 'b': facets.b, - - - '_id': facets._id, - - - 'created': facets.created, - - - 'updated': facets.updated, - - - '_schema': facets._schema, - - } sort_options = { - "bestmatch": dict( - title=_('Best match'), - fields=['_score'], # ES defaults to desc on `_score` field - ), - "newest": dict( - title=_('Newest'), - fields=['-created'], - ), - "oldest": dict( - title=_('Oldest'), - fields=['created'], - ), - - + **InvenioSearchOptions.sort_options, 'a_test': {'fields': ['a']}, - - - 'b_test': {'fields': ['-b']}, - - } - """, ) def test_nested(): @@ -343,6 +303,7 @@ def test_search_class(): builder.build(schema, "") data = builder.filesystem.open(os.path.join("test", "services", "search.py")).read() + print(data) assert re.sub(r"\s", "", data) == re.sub( r"\s", "", @@ -354,52 +315,19 @@ def _(x): \"""Identity function for string extraction.\""" return x - - class TestSearchOptions(InvenioSearchOptions): \"""TestRecord search options.\""" facets = { - - - 'a_keyword': facets.a_keyword, - - - - 'b': facets.b, - - - - '_id': facets._id, - - - - 'created': facets.created, - - - - 'updated': facets.updated, - - - - '_schema': facets._schema, - - + 'a_keyword': facets.a_keyword, + 'b': facets.b, + '_id': facets._id, + 'created': facets.created, + 'updated': facets.updated, + '_schema': facets._schema, } sort_options = { - "bestmatch": dict( - title=_('Best match'), - fields=['_score'], # ES defaults to desc on `_score` field - ), - "newest": dict( - title=_('Newest'), - fields=['-created'], - ), - "oldest": dict( - title=_('Oldest'), - fields=['created'], - ), - + **InvenioSearchOptions.sort_options, } """, ) diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py index c73963d3..ba8074e9 100644 --- a/tests/test_shortcuts.py +++ b/tests/test_shortcuts.py @@ -26,10 +26,12 @@ def test_array_shortcuts(): builder.build(schema, "") data = builder.filesystem.open(os.path.join("test", "services", "schema.py")).read() + print(data) assert re.sub(r"\s", "", data) == re.sub( r"\s", "", """ +from invenio_records_resources.services.records.schema import BaseRecordSchema import marshmallow as ma import marshmallow.fields as ma_fields import marshmallow.validate as ma_valid @@ -37,16 +39,12 @@ def test_array_shortcuts(): from marshmallow import ValidationError from marshmallow import validates as ma_validates -class TestSchema(ma.Schema, ): +class TestSchema(BaseRecordSchema, ): \"""TestSchema schema.\""" - id = ma_fields.String() + created = ma_fields.Date(dump_only=True) - created = ma_fields.Date() - - updated = ma_fields.Date() - - _schema = ma_fields.String(data_key='$schema') + updated = ma_fields.Date(dump_only=True) a = ma_fields.List(ma_fields.String()) """, diff --git a/tests/utils.py b/tests/utils.py index 37b8304a..b3d679c0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,7 @@ from collections import defaultdict -def assert_python_equals(actual, expected): +def assert_python_equals(actual, expected, msg=''): actual_lines = [x.strip() for x in actual.split("\n")] expected_lines = [x.strip() for x in expected.split("\n")] @@ -26,7 +26,7 @@ def assert_python_equals(actual, expected): print("Actual lines:\n") print_lines_around(actual_lines, actual_line_no) - raise AssertionError(f"Actual line {actual_line_no + 1} '{actual_line}' not in expected lines") + raise AssertionError(f"Actual line {actual_line_no + 1} '{actual_line}' not in expected lines. {msg}") def print_lines_around(lines, position):