fix(import): enable import of one entity (#3270)
* feat(fieldmapping): enable importing a single entity when controlling the mapping
* fix(model): fix the validate function of fieldmapping

---------

Co-authored-by: jacquesfize <[email protected]>
Pierre-Narcisi and jacquesfize authored Nov 27, 2024
1 parent 9ed127a commit 2c30b43
Showing 4 changed files with 115 additions and 29 deletions.
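The change targets multi-entity destinations: before this commit, validate_values built its JSON Schema from every displayed field of the destination, so a field mapping covering only one entity was rejected. Below is a minimal sketch of the kind of mapping the fix is meant to accept; the entity and field names are purely illustrative, not GeoNature's actual field names.

# Illustrative only: a destination with two entities (say "station" and "habitat").
# With this commit, a mapping that feeds a single entity can be validated and imported;
# rows then keep validity = None for the entity that received no data.
single_entity_mapping = {
    "station_date": "date_releve",   # hypothetical field of the "station" entity
    "observers": "observateurs",     # hypothetical field of the "station" entity
    # no field of the "habitat" entity is mapped
}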
22 changes: 14 additions & 8 deletions backend/geonature/core/imports/checks/sql/core.py
@@ -8,13 +8,14 @@
     Entity,
     EntityField,
     BibFields,
+    TImports,
 )
 
 
 __all__ = ["init_rows_validity", "check_orphan_rows"]
 
 
-def init_rows_validity(imprt):
+def init_rows_validity(imprt: TImports, dataset_name_field: str = "id_dataset"):
     """
     Validity columns are three-states:
     - None: the row does not contains data for the given entity
@@ -46,16 +47,21 @@ def init_rows_validity(imprt):
             .where(BibFields.name_field.in_(selected_fields_names))
             .where(BibFields.entities.any(EntityField.entity == entity))
             .where(~BibFields.entities.any(EntityField.entity != entity))
+            .where(BibFields.name_field != dataset_name_field)
             .all()
         )
-        db.session.execute(
-            sa.update(transient_table)
-            .where(transient_table.c.id_import == imprt.id_import)
-            .where(
-                sa.or_(*[transient_table.c[field.source_column].isnot(None) for field in fields])
-            )
-            .values({entity.validity_column: True})
-        )
+
+        if fields:
+            db.session.execute(
+                sa.update(transient_table)
+                .where(transient_table.c.id_import == imprt.id_import)
+                .where(
+                    sa.or_(
+                        *[transient_table.c[field.source_column].isnot(None) for field in fields]
+                    )
+                )
+                .values({entity.validity_column: True})
+            )
 
 
 def check_orphan_rows(imprt):
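In plain Python, the new behaviour of init_rows_validity can be sketched as follows. This is an illustrative rewrite, not the GeoNature implementation, and it leaves aside the dataset column excluded via the new dataset_name_field parameter.

# Illustrative plain-Python sketch of the new guard in init_rows_validity (not GeoNature code).
# A row is flagged valid for an entity only when at least one mapped column that belongs
# exclusively to that entity carries data; if the mapping selects no such column, the
# UPDATE is skipped entirely and every row keeps validity = None for that entity.
def mark_rows_validity(rows, entity_specific_columns):
    if not entity_specific_columns:  # nothing mapped for this entity
        return rows
    for row in rows:
        if any(row.get(col) is not None for col in entity_specific_columns):
            row["validity"] = True
    return rows

rows = [{"station_name": "A", "validity": None}, {"station_name": None, "validity": None}]
mark_rows_validity(rows, ["station_name"])  # first row flagged True, second left at None
mark_rows_validity(rows, [])                # entity absent from the mapping: nothing changes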
105 changes: 91 additions & 14 deletions backend/geonature/core/imports/models.py
@@ -1,7 +1,7 @@
 from datetime import datetime
 from collections.abc import Mapping
 import re
-from typing import Iterable, List, Optional
+from typing import Any, Iterable, List, Optional
 from packaging import version
 
 from flask import g
@@ -618,6 +618,42 @@ def optional_conditions_to_jsonschema(name_field: str, optional_conditions: Iter
     }
 
 
+# TODO move to utils lib
+def get_fields_of_an_entity(
+    entity: "Entity",
+    columns: Optional[List[str]] = None,
+    optional_where_clause: Optional[Any] = None,
+) -> List["BibFields"]:
+    """
+    Get all BibFields associated with a given entity.
+
+    Parameters
+    ----------
+    entity : Entity
+        The entity to get the fields for.
+    columns : Optional[List[str]], optional
+        The columns to retrieve. If None, all columns are retrieved.
+    optional_where_clause : Optional[Any], optional
+        An optional where clause to apply to the query.
+
+    Returns
+    -------
+    List[BibFields]
+        The BibFields associated with the given entity.
+    """
+    select_args = [BibFields]
+    query = sa.select(BibFields).where(
+        BibFields.entities.any(EntityField.entity == entity),
+    )
+    if columns:
+        select_args = [getattr(BibFields, col) for col in columns]
+        query.with_only_columns(*select_args)
+    if optional_where_clause is not None:
+        query = query.where(optional_where_clause)
+
+    return db.session.scalars(query).all()
+
+
 @serializable
 class FieldMapping(MappingTemplate):
     __tablename__ = "t_fieldmappings"
@@ -631,19 +667,60 @@ class FieldMapping(MappingTemplate):
     }
 
     @staticmethod
-    def validate_values(values):
-        fields = (
-            BibFields.query.filter_by(destination=g.destination, display=True)
-            .with_entities(
-                BibFields.name_field,
-                BibFields.autogenerated,
-                BibFields.mandatory,
-                BibFields.multi,
-                BibFields.optional_conditions,
-                BibFields.mandatory_conditions,
-            )
-            .all()
+    def validate_values(field_mapping_json):
+        """
+        Validate the field mapping values returned by the client form.
+
+        Parameters
+        ----------
+        field_mapping_json : dict
+            The field mapping values.
+
+        Raises
+        ------
+        ValueError
+            If the field mapping values are invalid.
+        """
+        bib_fields_col = [
+            "name_field",
+            "autogenerated",
+            "mandatory",
+            "multi",
+            "optional_conditions",
+            "mandatory_conditions",
+        ]
+        entities_for_destination: List[Entity] = (
+            Entity.query.filter_by(destination=g.destination).order_by(sa.desc(Entity.order)).all()
         )
+        fields = []
+        for entity in entities_for_destination:
+            # Get fields associated to this entity and exists in the given field mapping
+            fields_of_ent = get_fields_of_an_entity(
+                entity,
+                columns=bib_fields_col,
+                optional_where_clause=sa.and_(
+                    sa.or_(
+                        ~BibFields.entities.any(EntityField.entity != entity),
+                        BibFields.name_field == entity.unique_column.name_field,
+                    ),
+                    BibFields.name_field.in_(field_mapping_json.keys()),
+                ),
+            )
+
+            # if the only column corresponds to id_columns, we only do the validation on the latter
+            if [entity.unique_column.name_field] == [f.name_field for f in fields_of_ent]:
+                fields.extend(fields_of_ent)
+            else:
+                # if other columns than the id_columns are used, we need to check every fields of this entity
+                fields.extend(
+                    get_fields_of_an_entity(
+                        entity,
+                        columns=bib_fields_col,
+                        optional_where_clause=sa.and_(
+                            BibFields.destination == g.destination, BibFields.display == True
+                        ),
+                    )
+                )
 
         schema = {
             "type": "object",
@@ -676,7 +753,7 @@ def validate_values(values):
             schema["allOf"] = optional_conditions
 
         try:
-            validate_json(values, schema)
+            validate_json(field_mapping_json, schema)
         except JSONValidationError as e:
             raise ValueError(e.message)
 
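To make the per-entity rule concrete, here is a loose, self-contained sketch of the idea with jsonschema, outside GeoNature and with made-up field names: when a mapping touches only an entity's unique id column, nothing else is required for that entity; as soon as another of its fields is mapped, the entity's mandatory fields become required again.

# Loose sketch of the rule enforced by the new validate_values, using jsonschema directly.
# "id_station_source", "station_date", "observers" and "altitude" are made-up field names.
from jsonschema import ValidationError, validate

def build_schema(mapping_keys, id_column, mandatory_fields):
    # mandatory fields only become required once something other than the id column is mapped
    touches_more_than_id = bool(set(mapping_keys) - {id_column})
    return {"type": "object", "required": mandatory_fields if touches_more_than_id else []}

mandatory = ["station_date", "observers"]
id_only = {"id_station_source": "col_a"}                      # only the entity's id column
fuller = {"id_station_source": "col_a", "altitude": "col_b"}  # another field of the entity

validate(id_only, build_schema(id_only.keys(), "id_station_source", mandatory))  # passes
try:
    validate(fuller, build_schema(fuller.keys(), "id_station_source", mandatory))
except ValidationError as err:
    print("rejected:", err.message)  # 'station_date' is a required property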
2 changes: 2 additions & 0 deletions backend/geonature/core/imports/utils.py
@@ -384,6 +384,8 @@ def update_transient_data_from_dataframe(
     updated_cols = ["id_import", "line_no"] + list(updated_cols)
     dataframe.replace({np.nan: None}, inplace=True)
     records = dataframe[updated_cols].to_dict(orient="records")
+    if not records:
+        return
     insert_stmt = pg_insert(transient_table)
     insert_stmt = insert_stmt.values(records).on_conflict_do_update(
         index_elements=updated_cols[:2],
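The two added lines in utils.py cover the case where the dataframe holds no rows to write back, presumably the situation that arises when an entity is absent from the import. A standalone illustration, using pandas only:

# Standalone illustration: an empty dataframe produces an empty record list, so the
# function now returns before issuing an INSERT with no rows to insert.
import pandas as pd

df = pd.DataFrame(columns=["id_import", "line_no", "date_min"])
records = df.to_dict(orient="records")
print(records)  # [] -> update_transient_data_from_dataframe returns early at this point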
(fourth changed file; file path not shown)
@@ -91,13 +91,14 @@ def preprocess_transient_data(imprt: TImports, df) -> set:
             .where(BibFields.destination == imprt.destination)
             .where(BibFields.name_field == "date_max")
         ).scalar_one()
-        updated_cols |= concat_dates(
-            df,
-            datetime_min_col=date_min_field.source_field,
-            datetime_max_col=date_max_field.source_field,
-            date_min_col=date_min_field.source_field,
-            date_max_col=date_max_field.source_field,
-        )
+        if date_min_field.source_field in df and date_max_field.source_field in df:
+            updated_cols |= concat_dates(
+                df,
+                datetime_min_col=date_min_field.source_field,
+                datetime_max_col=date_max_field.source_field,
+                date_min_col=date_min_field.source_field,
+                date_max_col=date_max_field.source_field,
+            )
         return updated_cols
 
     @staticmethod
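The last hunk makes the date concatenation conditional on the mapped date columns actually being present in the dataframe. A small hedged sketch of the same membership test in pandas; maybe_combine stands in for concat_dates, whose real signature lives in GeoNature's utils.

# Hedged sketch of the column-presence guard; "combine" stands in for concat_dates.
import pandas as pd

def maybe_combine(df, date_min_col, date_max_col, combine):
    if date_min_col in df and date_max_col in df:  # same membership test as the diff
        return combine(df, date_min_col, date_max_col)
    return set()  # no column updated when the date columns were not mapped

df = pd.DataFrame({"some_other_field": ["x"]})  # date columns absent from this import
updated = maybe_combine(df, "date_min", "date_max", lambda *args: {"date_min", "date_max"})
print(updated)  # set(): concat_dates is skipped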
