diff --git a/flask_mongorest/resources.py b/flask_mongorest/resources.py index 859857eb..85c6968d 100644 --- a/flask_mongorest/resources.py +++ b/flask_mongorest/resources.py @@ -332,6 +332,110 @@ def _subresource(self, obj): else: return None + def get_field_value(self, obj, field_name, field_instance=None, **kwargs): + """Return a json-serializable field value. + + field_name is the name of the field in `obj` to be serialized. + field_instance is a MongoEngine field definition. + **kwargs are just any options to be passed through to child resources serializers. + """ + has_field_instance = bool(field_instance) + field_instance = (field_instance or + self.document._fields.get(field_name, None) or + getattr(self.document, field_name, None)) + + # Determine the field value + if has_field_instance: + field_value = obj + elif isinstance(obj, dict): + return obj[field_name] + else: + try: + field_value = getattr(obj, field_name) + except AttributeError: + raise UnknownFieldError + + return self.serialize_field_value(obj, field_name, field_instance, field_value, **kwargs) + + def serialize_field_value(self, obj, field_name, field_instance, field_value, **kwargs): + """Select and delegate to an appropriate serializer method based on type of field instance. + + field_value is an actual value to be serialized. + For other fields, see get_field_value method. + """ + if isinstance(field_instance, (ReferenceField, GenericReferenceField, EmbeddedDocumentField)): + return self.serialize_document_field(field_name, field_value, **kwargs) + + elif isinstance(field_instance, ListField): + return self.serialize_list_field(field_instance, field_name, field_value, **kwargs) + + elif isinstance(field_instance, DictField): + return self.serialize_dict_field(field_instance, field_name, field_value, **kwargs) + + elif callable(field_instance): + return self.serialize_callable_field(obj, field_instance, field_name, field_value, **kwargs) + return field_value + + def serialize_callable_field(self, obj, field_instance, field_name, field_value, **kwargs): + """Execute a callable field and return it or serialize + it based on its related resource defined in the `related_resources` map. + """ + if isinstance(field_value, list): + value = field_value + else: + if isbound(field_instance): + value = field_instance() + elif isbound(field_value): + value = field_value() + else: + value = field_instance(obj) + if field_name in self._related_resources: + if isinstance(value, list): + return [self._related_resources[field_name]().serialize_field(o, **kwargs) for o in value] + elif value is None: + return None + else: + return self._related_resources[field_name]().serialize_field(value, **kwargs) + return value + + def serialize_dict_field(self, field_instance, field_name, field_value, **kwargs): + """Serialize each value based on an explicit field type + (e.g. if the schema defines a DictField(IntField), where all + the values in the dict should be ints). + """ + if field_instance.field: + return { + key: self.get_field_value(elem, field_name, field_instance=field_instance.field, **kwargs) + for (key, elem) in field_value.items() + } + # ... or simply return the dict intact, if the field type + # wasn't specified + else: + return field_value + + def serialize_list_field(self, field_instance, field_name, field_value, **kwargs): + """Serialize each item in the list separately.""" + return [val for val in [self.get_field_value(elem, field_name, field_instance=field_instance.field, **kwargs) for elem in field_value] if val] + + def serialize_document_field(self, field_name, field_value, **kwargs): + """If this field is a reference or an embedded document, either return + a DBRef or serialize it using a resource found in `related_resources`. + """ + if field_name in self._related_resources: + return ( + field_value and + not isinstance(field_value, DBRef) and + self._related_resources[field_name]().serialize_field(field_value, **kwargs) + ) + else: + if DocumentProxy and isinstance(field_value, DocumentProxy): + # Don't perform a DBRef isinstance check below since + # it might trigger an extra query. + return field_value.to_dbref() + if isinstance(field_value, DBRef): + return field_value + return field_value and field_value.to_dbref() + def serialize(self, obj, **kwargs): """ Given an object, serialize it, turning it into its JSON @@ -346,89 +450,6 @@ def serialize(self, obj, **kwargs): if subresource: return subresource.serialize(obj, **kwargs) - def get(obj, field_name, field_instance=None): - """ - @TODO needs significant cleanup - """ - - has_field_instance = bool(field_instance) - field_instance = (field_instance or - self.document._fields.get(field_name, None) or - getattr(self.document, field_name, None)) - - # Determine the field value - if has_field_instance: - field_value = obj - elif isinstance(obj, dict): - return obj[field_name] - else: - try: - field_value = getattr(obj, field_name) - except AttributeError: - raise UnknownFieldError - - # If this field is a reference or an embedded document, either - # return a DBRef or serialize it using a resource found in - # `related_resources`. - if isinstance(field_instance, (ReferenceField, GenericReferenceField, EmbeddedDocumentField)): - if field_name in self._related_resources: - return ( - field_value and - not isinstance(field_value, DBRef) and - self._related_resources[field_name]().serialize_field(field_value, **kwargs) - ) - else: - if DocumentProxy and isinstance(field_value, DocumentProxy): - # Don't perform a DBRef isinstance check below since - # it might trigger an extra query. - return field_value.to_dbref() - if isinstance(field_value, DBRef): - return field_value - return field_value and field_value.to_dbref() - - # If this field is a list, serialize each item in the list separately. - elif isinstance(field_instance, ListField): - return [val for val in [get(elem, field_name, field_instance=field_instance.field) for elem in field_value] if val] - - # If this field is a dict... - elif isinstance(field_instance, DictField): - # ... serialize each value based on an explicit field type - # (e.g. if the schema defines a DictField(IntField), where all - # the values in the dict should be ints). - if field_instance.field: - return { - key: get(elem, field_name, field_instance=field_instance.field) - for (key, elem) in field_value.items() - } - # ... or simply return the dict intact, if the field type - # wasn't specified - else: - return field_value - - # If this field is callable, execute it and return it or serialize - # it based on its related resource defined in the - # `related_resources` map. - elif callable(field_instance): - if isinstance(field_value, list): - value = field_value - else: - if isbound(field_instance): - value = field_instance() - elif isbound(field_value): - value = field_value() - else: - value = field_instance(obj) - - if field_name in self._related_resources: - if isinstance(value, list): - return [self._related_resources[field_name]().serialize_field(o, **kwargs) for o in value] - elif value is None: - return None - else: - return self._related_resources[field_name]().serialize_field(value, **kwargs) - return value - return field_value - # Get the requested fields requested_fields = self.get_requested_fields(**kwargs) @@ -465,7 +486,7 @@ def get(obj, field_name, field_instance=None): data[renamed_field] = value else: try: - data[renamed_field] = get(obj, field) + data[renamed_field] = self.get_field_value(obj, field, **kwargs) except UnknownFieldError: try: data[renamed_field] = self.value_for_field(obj, field) @@ -897,6 +918,7 @@ def update_object(self, obj, data=None, save=True, parent_resources=None): def delete_object(self, obj, parent_resources=None): obj.delete() + # Py2/3 compatible way to do metaclasses (or six.add_metaclass) body = vars(Resource).copy() body.pop('__dict__', None) diff --git a/flask_mongorest/utils.py b/flask_mongorest/utils.py index 35ea1d1f..93e1e94b 100644 --- a/flask_mongorest/utils.py +++ b/flask_mongorest/utils.py @@ -28,6 +28,7 @@ def default(self, value, **kwargs): return str(value) return super(MongoEncoder, self).default(value, **kwargs) + try: cmp except NameError: # Python 3 diff --git a/tests/__init__.py b/tests/__init__.py index ff479afe..ee6d5528 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1321,6 +1321,7 @@ def test_serialize_mongoengine_validation_error(self): } }) + if __name__ == '__main__': unittest.main()