From d32dbe3c2b29756b0ea094611396c17c9c9b202a Mon Sep 17 00:00:00 2001 From: MarkLark86 Date: Mon, 11 Nov 2024 09:26:05 +1100 Subject: [PATCH] [NHUB-537] Upgrade Agenda to async resource (#1150) * update types * fix init_data: await elastic_rebuild * new base web search service * new agenda resource service * new agenda search service * update agenda module * update agenda views * update agenda notifications and emails * update agenda featured to use new agenda search * update push to use async agenda * update wire module * update remove_expired_agenda command * user async resources in update_action_list * fix company updates * add more async product helpers * fix update company products * update reports to use async resources * fix template_locale context with no locale nor timezone * improve async topic helpers * remove old utils no longer used * add pydantic to mypy * fix leaking db connectors in pytest * fix tests * use NHUB-537 branch in superdesk-core * fix duplicate import * use async branch from superdesk-core * remove commented code * Replace if/elif block with match case statements --- .../agenda_restrict_coverage_details.feature | 8 +- features/web_api/agenda_search.feature | 18 +- mypy-requirements.txt | 1 + newsroom/agenda/__init__.py | 30 +- newsroom/agenda/agenda.py | 1550 ----------------- newsroom/agenda/agenda_search.py | 178 ++ newsroom/agenda/agenda_service.py | 396 +++++ newsroom/agenda/email.py | 12 +- newsroom/agenda/featured_service.py | 150 ++ newsroom/agenda/filters.py | 639 +++++++ newsroom/agenda/module.py | 49 +- newsroom/agenda/notifications.py | 274 +++ newsroom/agenda/service.py | 136 -- newsroom/agenda/views.py | 468 ++--- newsroom/am_news/views.py | 6 +- newsroom/commands/initialize_data.py | 2 +- newsroom/commands/remove_expired_agenda.py | 165 +- newsroom/companies/companies.py | 4 +- newsroom/companies/utils.py | 3 + newsroom/factcheck/views.py | 6 +- newsroom/gettext.py | 2 +- newsroom/market_place/views.py | 6 +- 
newsroom/media_releases/views.py | 6 +- newsroom/monitoring/views.py | 9 +- newsroom/notifications/commands.py | 159 +- newsroom/notifications/utils.py | 49 +- newsroom/products/__init__.py | 15 +- newsroom/products/products.py | 4 +- newsroom/products/service.py | 6 +- newsroom/products/utils.py | 71 +- newsroom/push/agenda_manager.py | 148 +- newsroom/push/notifications.py | 40 +- newsroom/push/publishing.py | 75 +- newsroom/push/tasks.py | 22 +- newsroom/push/utils.py | 3 +- newsroom/push/views.py | 34 +- newsroom/reports/content_activity.py | 144 +- newsroom/reports/reports.py | 32 +- newsroom/search/base_service.py | 3 +- newsroom/search/base_web_service.py | 219 +++ newsroom/search/filters.py | 63 +- newsroom/search/types.py | 15 +- newsroom/tests/conftest.py | 23 +- newsroom/tests/fixtures.py | 125 +- newsroom/tests/markers.py | 4 + newsroom/tests/web_api/environment.py | 3 +- newsroom/topics/__init__.py | 2 + newsroom/topics/topics_async.py | 23 +- newsroom/types/__init__.py | 18 +- newsroom/types/agenda.py | 282 +++ newsroom/types/company.py | 5 +- newsroom/types/products.py | 11 +- newsroom/types/users.py | 9 +- newsroom/types/wire.py | 18 +- newsroom/users/service.py | 21 +- newsroom/users/users.py | 4 +- newsroom/utils.py | 51 +- newsroom/web/default_settings.py | 3 +- newsroom/wire/__init__.py | 6 +- newsroom/wire/filters.py | 2 +- newsroom/wire/items.py | 4 +- newsroom/wire/service.py | 275 +-- newsroom/wire/utils.py | 41 +- newsroom/wire/views.py | 68 +- setup.cfg | 1 + tests/commands/test_remove_expired_items.py | 5 +- tests/core/test_agenda.py | 88 +- tests/core/test_agenda_events_only.py | 22 +- tests/core/test_api_tokens.py | 4 +- tests/core/test_auth.py | 55 +- tests/core/test_auth_providers.py | 11 +- .../test_command_remove_expired_agenda.py | 79 +- tests/core/test_commands.py | 11 +- tests/core/test_companies.py | 10 +- tests/core/test_copy_items.py | 12 +- tests/core/test_csv_formatter.py | 17 +- tests/core/test_download.py | 3 +- 
tests/core/test_email_templates.py | 13 +- tests/core/test_emails.py | 16 +- tests/core/test_home.py | 6 +- tests/core/test_ical_formatter.py | 4 +- tests/core/test_monitoring.py | 164 +- tests/core/test_navigations.py | 47 +- tests/core/test_products.py | 8 +- tests/core/test_push.py | 72 +- tests/core/test_push_events.py | 155 +- tests/core/test_push_evolved.py | 3 +- tests/core/test_realtime_notifications.py | 31 +- tests/core/test_reports.py | 10 +- tests/core/test_saml.py | 7 +- tests/core/test_search_config.py | 5 +- .../core/test_send_scheduled_notifications.py | 38 +- tests/core/test_signup.py | 17 +- tests/core/test_user_dashboards.py | 22 +- tests/core/test_users.py | 6 +- tests/core/test_wire.py | 88 +- tests/core/utils.py | 83 +- tests/fixtures/item_copy_fixture.json | 1 + tests/news_api/test_api_assets.py | 7 +- tests/news_api/test_api_audit.py | 18 +- tests/search/test_search_fields.py | 8 +- tests/search/test_search_params.py | 6 +- tests/search/test_search_topics.py | 35 +- tests/search/test_user_products.py | 30 +- 104 files changed, 4026 insertions(+), 3410 deletions(-) delete mode 100644 newsroom/agenda/agenda.py create mode 100644 newsroom/agenda/agenda_search.py create mode 100644 newsroom/agenda/agenda_service.py create mode 100644 newsroom/agenda/featured_service.py create mode 100644 newsroom/agenda/filters.py create mode 100644 newsroom/agenda/notifications.py delete mode 100644 newsroom/agenda/service.py create mode 100644 newsroom/search/base_web_service.py create mode 100644 newsroom/types/agenda.py diff --git a/features/web_api/agenda_restrict_coverage_details.feature b/features/web_api/agenda_restrict_coverage_details.feature index a30e1c7d6..b5f2f0356 100644 --- a/features/web_api/agenda_restrict_coverage_details.feature +++ b/features/web_api/agenda_restrict_coverage_details.feature @@ -150,14 +150,14 @@ Feature: Agenda Restricted Coverage Details "coverages": [{ "coverage_id": "plan1_cov1", "coverage_type": "text", - "scheduled": 
"2018-05-28T10:51:52+0000", + "scheduled": "2018-05-28T10:51:52+00:00", "slugline": "Vivid planning item", "workflow_status": "draft", "coverage_status": "coverage intended" }, { "coverage_id": "plan1_cov2", "coverage_type": "text", - "scheduled": "2018-05-28T10:51:52+0000", + "scheduled": "2018-05-28T10:51:52+00:00", "slugline": "Vivid planning item", "workflow_status": "draft", "coverage_status": "coverage intended" @@ -166,7 +166,7 @@ Feature: Agenda Restricted Coverage Details "guid": "plan1", "slugline": "New Press Conference", "name": "Prime minister press conference", - "planning_date": "2018-05-28T05:00:00+0000", + "planning_date": "2018-05-28T05:00:00+00:00", "coverages": [{ "coverage_id": "plan1_cov1", "news_coverage_status": { @@ -283,7 +283,7 @@ Feature: Agenda Restricted Coverage Details }, { "coverage_id": "plan1_cov2", "coverage_type": "text", - "scheduled": "2018-05-28T10:51:52+0000", + "scheduled": "2018-05-28T10:51:52+00:00", "slugline": "Vivid planning item", "workflow_status": "completed", "coverage_status": "coverage intended", diff --git a/features/web_api/agenda_search.feature b/features/web_api/agenda_search.feature index da4a01d05..ae9d1d3e3 100644 --- a/features/web_api/agenda_search.feature +++ b/features/web_api/agenda_search.feature @@ -13,9 +13,9 @@ Feature: Agenda Search }, "calendars": [{"qcode": "cal1", "name": "Calendar1"}], "subject": [ - {"code": "d1", "scheme": "sttdepartment", "name": "Dep1"}, - {"code": "s1", "scheme": "sttsubj", "name": "Sub1"}, - {"code": "e1", "scheme": "event_type", "name": "Sports"} + {"qcode": "d1", "scheme": "sttdepartment", "name": "Dep1"}, + {"qcode": "s1", "scheme": "sttsubj", "name": "Sub1"}, + {"qcode": "e1", "scheme": "event_type", "name": "Sports"} ], "place": [ {"code": "NSW", "name": "New South Wales"} @@ -51,9 +51,9 @@ Feature: Agenda Search {"_id": "test", "name": "Agenda2"} ], "subject": [ - {"code": "d2", "scheme": "sttdepartment", "name": "Dep2"}, - {"code": "s1", "scheme": "sttsubj", 
"name": "Sub1"}, - {"code": "s2", "scheme": "sttsubj", "name": "Sub2"} + {"qcode": "d2", "scheme": "sttdepartment", "name": "Dep2"}, + {"qcode": "s1", "scheme": "sttsubj", "name": "Sub1"}, + {"qcode": "s2", "scheme": "sttsubj", "name": "Sub2"} ], "place": [ {"code": "VIC", "name": "Victoria"} @@ -95,9 +95,9 @@ Feature: Agenda Search {"_id": "test", "name": "Agenda2"} ], "subject": [ - {"code": "d3", "scheme": "sttdepartment", "name": "Dep3"}, - {"code": "s1", "scheme": "sttsubj", "name": "Sub1"}, - {"code": "s2", "scheme": "sttsubj", "name": "Sub2"} + {"qcode": "d3", "scheme": "sttdepartment", "name": "Dep3"}, + {"qcode": "s1", "scheme": "sttsubj", "name": "Sub1"}, + {"qcode": "s2", "scheme": "sttsubj", "name": "Sub2"} ], "place": [ {"code": "VIC", "name": "Victoria"} diff --git a/mypy-requirements.txt b/mypy-requirements.txt index ecd9b696b..7bd469fec 100644 --- a/mypy-requirements.txt +++ b/mypy-requirements.txt @@ -1,4 +1,5 @@ mypy +pydantic types-Flask types-Jinja2 types-python-dateutil diff --git a/newsroom/agenda/__init__.py b/newsroom/agenda/__init__.py index 46517e850..b1b3a3ec8 100644 --- a/newsroom/agenda/__init__.py +++ b/newsroom/agenda/__init__.py @@ -1,11 +1,8 @@ -import superdesk from quart_babel import lazy_gettext -from superdesk.flask import Blueprint from newsroom.utils import url_for_agenda -from .agenda import AgendaResource, AgendaService, aggregations, PRIVATE_FIELDS -from newsroom.search.config import init_nested_aggregation -from . import formatters + +from .formatters import iCalFormatter, CSVFormatter from .utils import ( get_coverage_email_text, get_coverage_content_type_name, @@ -15,20 +12,22 @@ get_coverage_status, get_event_state, ) +from .filters import AgendaSearchRequestArgs +from .agenda_service import AgendaItemService +from .agenda_search import AgendaSearchServiceAsync -blueprint = Blueprint("agenda", __name__) - from . 
import views # noqa from .module import module # noqa - -AGENDA_NESTED_SEARCH_FIELDS = ["subject"] +__all__ = [ + "AgendaSearchServiceAsync", + "AgendaItemService", + "AgendaSearchRequestArgs", +] def init_app(app): - superdesk.register_resource("agenda", AgendaResource, AgendaService, _app=app) - app.section("agenda", app.config["AGENDA_SECTION"], "agenda") app.sidenav(app.config["AGENDA_SECTION"], "agenda.index", "calendar", section="agenda") app.sidenav( @@ -40,8 +39,8 @@ def init_app(app): badge="saved-items-count", ) - app.download_formatter("ical", formatters.iCalFormatter(), "iCalendar", ["agenda"]) - app.download_formatter("Csv", formatters.CSVFormatter(), "CSV", ["agenda"]) + app.download_formatter("ical", iCalFormatter(), "iCalendar", ["agenda"]) + app.download_formatter("Csv", CSVFormatter(), "CSV", ["agenda"]) app.add_template_global(url_for_agenda) app.add_template_global(get_coverage_email_text) app.add_template_global(get_coverage_content_type_name, "get_coverage_content_type") @@ -80,8 +79,3 @@ def init_app(app): "label": lazy_gettext("Place"), }, ] - - init_nested_aggregation("agenda", AGENDA_NESTED_SEARCH_FIELDS, app.config.get("AGENDA_GROUPS", []), aggregations) - - if app.config.get("AGENDA_HIDE_COVERAGE_ASSIGNEES"): - PRIVATE_FIELDS.extend(["*.assigned_desk_*", "*.assigned_user_*"]) diff --git a/newsroom/agenda/agenda.py b/newsroom/agenda/agenda.py deleted file mode 100644 index 8d3004d90..000000000 --- a/newsroom/agenda/agenda.py +++ /dev/null @@ -1,1550 +0,0 @@ -from typing import Dict, Set, Any, Optional, List -import logging -from copy import deepcopy - -from bson import ObjectId -from eve.utils import ParsedRequest -from quart_babel import lazy_gettext - -from superdesk.core import json, get_current_app -from content_api.items.resource import code_mapping -from planning.common import ( - WORKFLOW_STATE_SCHEMA, - ASSIGNMENT_WORKFLOW_STATE, - WORKFLOW_STATE, -) -from planning.events.events_schema import events_schema -from 
planning.planning.planning import planning_schema - -from superdesk.core import get_app_config -from superdesk.resource_fields import ITEMS -from superdesk import get_resource_service -from superdesk.resource import Resource, not_enabled, not_analyzed, not_indexed, string_with_analyzer -from superdesk.metadata.item import metadata_schema - -import newsroom -from newsroom.agenda.email import ( - send_coverage_notification_email, - send_agenda_notification_email, -) -from newsroom.notifications import save_user_notifications -from newsroom.search.types import BoolQuery, BoolQueryParams -from newsroom.template_filters import is_admin_or_internal, is_admin -from newsroom.utils import ( - get_user_dict, - get_company_dict, - get_entity_or_404, - parse_date_str, -) -from newsroom.utils import get_local_date, get_end_date -from datetime import datetime -from newsroom.wire import url_for_wire -from newsroom.search.service import BaseSearchService, SearchQuery, query_string, get_filter_query, strtobool -from newsroom.search.config import is_search_field_nested, get_nested_config -from .utils import get_latest_available_delivery, TO_BE_CONFIRMED_FIELD, push_agenda_item_notification - - -logger = logging.getLogger(__name__) -PRIVATE_FIELDS = ["event.files", "*.internal_note"] -PLANNING_ITEMS_FIELDS = ["planning_items", "coverages", "display_dates"] - - -agenda_notifications = { - "event_updated": { - "message": lazy_gettext("An event you have been watching has been updated"), - "subject": lazy_gettext("Event updated"), - }, - "event_unposted": { - "message": lazy_gettext("An event you have been watching has been cancelled"), - "subject": lazy_gettext("Event cancelled"), - }, - "planning_added": { - "message": lazy_gettext("An event you have been watching has a new planning"), - "subject": lazy_gettext("Planning added"), - }, - "planning_cancelled": { - "message": lazy_gettext("An event you have been watching has a planning cancelled"), - "subject": lazy_gettext("Planning 
cancelled"), - }, - "coverage_added": { - "message": lazy_gettext("An event you have been watching has a new coverage added"), - "subject": lazy_gettext("Coverage added"), - }, -} - -nested_code_mapping = { - "type": "list", - "mapping": { - "type": "nested", - "include_in_parent": True, - "properties": code_mapping["properties"], - }, -} - - -def set_saved_items_query(query, user_id): - query["bool"]["filter"].append( - { - "bool": { - "should": [ - {"term": {"bookmarks": str(user_id)}}, - {"term": {"watches": str(user_id)}}, - { - "nested": { - "path": "coverages", - "query": {"bool": {"should": [{"term": {"coverages.watches": str(user_id)}}]}}, - } - }, - ], - }, - } - ) - - -class AgendaResource(newsroom.Resource): - """ - Agenda schema - """ - - SUPPORTED_NESTED_SEARCH_FIELDS = ["subject"] - - schema = {} - - # identifiers - schema["guid"] = events_schema["guid"] - schema["type"] = { - "type": "string", - "mapping": not_analyzed, - "default": "agenda", - } - schema["event_id"] = events_schema["guid"] - schema["item_type"] = { - "type": "string", - "mapping": not_analyzed, - "nullable": False, - "allowed": ["event", "planning"], - } - schema["recurrence_id"] = { - "type": "string", - "mapping": not_analyzed, - "nullable": True, - } - - # content metadata - schema["name"] = metadata_schema["body_html"].copy() - schema["slugline"] = metadata_schema["body_html"].copy() - schema["definition_short"] = metadata_schema["body_html"].copy() - schema["definition_long"] = metadata_schema["body_html"].copy() - schema["description_text"] = metadata_schema["body_html"].copy() - schema["headline"] = metadata_schema["body_html"].copy() - schema["firstcreated"] = events_schema["firstcreated"] - schema["version"] = events_schema["version"] - schema["versioncreated"] = events_schema["versioncreated"] - schema["ednote"] = events_schema["ednote"] - schema["registration_details"] = {"type": "string"} - schema["invitation_details"] = {"type": "string"} - schema["language"] = {"type": 
"string", "mapping": {"type": "keyword"}} - schema["source"] = {"type": "string", "mapping": {"type": "keyword"}} - - # aggregated fields - schema["urgency"] = {**planning_schema["urgency"], "mapping": {"type": "keyword"}} - schema["priority"] = {**planning_schema["priority"], "mapping": {"type": "keyword"}} - schema["place"] = planning_schema["place"] - schema["service"] = planning_schema["anpa_category"] - schema["state_reason"] = {"type": "string"} - - # Fields supporting Nested Aggregation / Filtering - schema["subject"] = nested_code_mapping - - # dates - schema["dates"] = { - "type": "dict", - "schema": { - "start": {"type": "datetime"}, - "end": {"type": "datetime"}, - "tz": {"type": "string"}, - }, - } - - # additional dates from coverages or planning to be used in searching agenda items - schema["display_dates"] = { - "type": "list", - "nullable": True, - "schema": { - "type": "dict", - "schema": { - "date": {"type": "datetime"}, - }, - }, - } - - # coverages - schema["coverages"] = { - "type": "list", - "mapping": { - "type": "nested", - "include_in_parent": True, # Enabled so advanced search works across multiple fields - "properties": { - "planning_id": not_analyzed, - "coverage_id": not_analyzed, - "scheduled": {"type": "date"}, - "coverage_type": not_analyzed, - "workflow_status": not_analyzed, - "coverage_status": not_analyzed, - "coverage_provider": not_analyzed, - "slugline": string_with_analyzer, - "delivery_id": not_analyzed, # To point ot the latest published item - "delivery_href": not_analyzed, # To point ot the latest published item - TO_BE_CONFIRMED_FIELD: {"type": "boolean"}, - "deliveries": { # All deliveries (incl. 
updates go here) - "type": "object", - "properties": { - "planning_id": not_analyzed, - "coverage_id": not_analyzed, - "assignment_id": not_analyzed, - "sequence_no": not_analyzed, - "publish_time": {"type": "date"}, - "delivery_id": not_analyzed, - "delivery_state": not_analyzed, - }, - }, - "watches": not_analyzed, - "assigned_desk_name": not_analyzed, - "assigned_desk_email": not_indexed, - "assigned_user_name": not_analyzed, - "assigned_user_email": not_indexed, - }, - }, - } - - # attachments - schema["files"] = { - "type": "list", - "mapping": not_enabled, - } - - # state - schema["state"] = WORKFLOW_STATE_SCHEMA - - # other searchable fields needed in UI - schema["calendars"] = events_schema["calendars"] - schema["location"] = events_schema["location"] - - # update location name to allow exact search and term based searching - schema["location"]["mapping"]["properties"]["name"] = {"type": "text", "fields": {"keyword": {"type": "keyword"}}} - - # event details - schema["event"] = { - "type": "dict", - "mapping": not_enabled, - } - - # planning details which can be more than one per event - schema["planning_items"] = { - "type": "list", - "mapping": { - "type": "nested", - "include_in_parent": True, - "properties": { - "_id": not_analyzed, - "guid": not_analyzed, - "slugline": not_analyzed, - "description_text": {"type": "string"}, - "headline": {"type": "string"}, - "abstract": {"type": "string"}, - "subject": nested_code_mapping["mapping"], - "urgency": {"type": "integer"}, - "service": code_mapping, - "planning_date": {"type": "date"}, - "coverages": not_enabled, - "agendas": { - "type": "object", - "properties": { - "name": not_analyzed, - "_id": not_analyzed, - }, - }, - "ednote": {"type": "string"}, - "internal_note": not_indexed, - "place": planning_schema["place"]["mapping"], - "state": not_analyzed, - "state_reason": {"type": "string"}, - "products": { - "type": "object", - "properties": {"code": not_analyzed, "name": not_analyzed}, - }, - }, - }, - } 
- - schema["bookmarks"] = Resource.not_analyzed_field("list") # list of user ids who bookmarked this item - schema["downloads"] = Resource.not_analyzed_field("list") # list of user ids who downloaded this item - schema["shares"] = Resource.not_analyzed_field("list") # list of user ids who shared this item - schema["prints"] = Resource.not_analyzed_field("list") # list of user ids who printed this item - schema["copies"] = Resource.not_analyzed_field("list") # list of user ids who copied this item - schema["watches"] = Resource.not_analyzed_field("list") # list of users following the event - - # matching products from superdesk - schema["products"] = { - "type": "list", - "mapping": { - "type": "object", - "properties": {"code": not_analyzed, "name": not_analyzed}, - }, - } - - resource_methods = ["GET"] - datasource = { - "source": "agenda", - "search_backend": "elastic", - "default_sort": [("dates.start", 1)], - } - - item_methods = ["GET"] - - -def build_agenda_query(): - return { - "bool": { - "filter": [], - "should": [], - "must_not": [{"term": {"state": "killed"}}], - } - } - - -def get_date_filters(args): - date_range = {} - offset = int(args.get("timezone_offset", "0")) - if args.get("date_from"): - date_range["gt"] = get_local_date(args["date_from"], "00:00:00", offset) - if args.get("date_to"): - date_range["lt"] = get_end_date(args["date_to"], get_local_date(args["date_to"], "23:59:59", offset)) - return date_range - - -def gen_date_range_filter(field: str, operator: str, date_str: str, datetime_instance: datetime): - return [ - { - "bool": { - "must_not": {"term": {"dates.all_day": True}}, - "filter": {"range": {field: {operator: datetime_instance}}}, - }, - }, - { - "bool": { - "filter": [ - {"term": {"dates.all_day": True}}, - {"range": {field: {operator: date_str}}}, - ], - }, - }, - ] - - -def _set_event_date_range(search): - """Get events for selected date. - - ATM it should display everything not finished by that date, even starting later. 
- - :param newsroom.search.SearchQuery search: The search query instance - """ - - date_range = get_date_filters(search.args) - date_from = date_range.get("gt") - date_to = date_range.get("lt") - - should = [] - - if date_from and not date_to: - # Filter from a particular date onwards - should = gen_date_range_filter("dates.end", "gte", search.args["date_from"], date_from) - elif not date_from and date_to: - # Filter up to a particular date - should = gen_date_range_filter("dates.end", "lte", search.args["date_to"], date_to) - elif date_from and date_to: - # Filter based on the date range provided - should = [ - { - # Both start/end dates are inside the range - "bool": { - "filter": [ - {"range": {"dates.start": {"gte": date_from}}}, - {"range": {"dates.end": {"lte": date_to}}}, - ], - "must_not": {"term": {"dates.all_day": True}}, - }, - }, - { - # Both start/end dates are inside the range, all day version - "bool": { - "filter": [ - {"range": {"dates.start": {"gte": search.args["date_from"]}}}, - {"range": {"dates.end": {"lte": search.args["date_to"]}}}, - {"term": {"dates.all_day": True}}, - ], - }, - }, - { - # Starts before date_from and finishes after date_to - "bool": { - "filter": [ - {"range": {"dates.start": {"lt": date_from}}}, - {"range": {"dates.end": {"gt": date_to}}}, - ], - "must_not": {"term": {"dates.all_day": True}}, - }, - }, - { - # Starts before date_from and finishes after date_to, all day version - "bool": { - "filter": [ - {"range": {"dates.start": {"lt": search.args["date_from"]}}}, - {"range": {"dates.end": {"gt": search.args["date_to"]}}}, - {"term": {"dates.all_day": True}}, - ], - }, - }, - { - # Start date is within range OR End date is within range - "bool": { - "should": [ - {"range": {"dates.start": {"gte": date_from, "lte": date_to}}}, - {"range": {"dates.end": {"gte": date_from, "lte": date_to}}}, - ], - "must_not": {"term": {"dates.all_day": True}}, - "minimum_should_match": 1, - }, - }, - { - # Start date is within range OR End 
date is within range, all day version - "bool": { - "should": [ - {"range": {"dates.start": {"gte": search.args["date_from"], "lte": search.args["date_to"]}}}, - {"range": {"dates.end": {"gte": search.args["date_from"], "lte": search.args["date_to"]}}}, - ], - "filter": {"term": {"dates.all_day": True}}, - "minimum_should_match": 1, - }, - }, - ] - - # Get events for extra dates for coverages and planning. - should.append({"range": {"display_dates.date": date_range}}) - - if len(should): - search.query["bool"]["filter"].append({"bool": {"should": should, "minimum_should_match": 1}}) - - -aggregations: Dict[str, Dict[str, Any]] = { - "language": {"terms": {"field": "language"}}, - "calendar": {"terms": {"field": "calendars.name", "size": 100}}, - "service": {"terms": {"field": "service.name", "size": 100}}, - "subject": {"terms": {"field": "subject.name", "size": 200}}, - "urgency": {"terms": {"field": "urgency"}}, - "place": {"terms": {"field": "place.name", "size": 50}}, - "coverage": { - "nested": {"path": "coverages"}, - "aggs": {"coverage_type": {"terms": {"field": "coverages.coverage_type", "size": 10}}}, - }, - "planning_items": { - "nested": { - "path": "planning_items", - }, - "aggs": { - "service": {"terms": {"field": "planning_items.service.name", "size": 100}}, - "subject": {"terms": {"field": "planning_items.subject.name", "size": 200}}, - "urgency": {"terms": {"field": "planning_items.urgency"}}, - "place": {"terms": {"field": "planning_items.place.name", "size": 50}}, - }, - }, - "agendas": { - "nested": {"path": "planning_items"}, - "aggs": { - "agenda": {"terms": {"field": "planning_items.agendas.name", "size": 100}}, - }, - }, -} - - -def get_agenda_aggregations(events_only=False): - aggs = deepcopy(aggregations) - if events_only: - aggs.pop("coverage", None) - aggs.pop("planning_items", None) - aggs.pop("urgency", None) - aggs.pop("agendas", None) - return aggs - - -def get_aggregation_field(key: str): - if key == "coverage": - return 
aggregations[key]["aggs"]["coverage_type"]["terms"]["field"] - elif key == "agendas": - return aggregations[key]["aggs"]["agenda"]["terms"]["field"] - elif is_search_field_nested("agenda", key): - return aggregations[key]["aggs"][f"{key}_filtered"]["aggs"][key]["terms"]["field"] - return aggregations[key]["terms"]["field"] - - -def nested_query(path, query, inner_hits=True, name=None): - nested = {"path": path, "query": query} - if inner_hits: - nested["inner_hits"] = {} - if name: - nested["inner_hits"]["name"] = name - - return {"nested": nested} - - -coverage_filters = ["coverage", "coverage_status"] -planning_filters = coverage_filters + ["agendas"] - - -def _filter_terms(filters, item_type, highlights=False): - must_term_filters = [] - must_not_term_filters = [] - for key, val in filters.items(): - if not val: - continue - elif item_type == "events" and key in planning_filters: - continue - elif key == "location": - search_type = val.get("type", "location") - - if search_type == "city": - field = "location.address.city.keyword" - elif search_type == "state": - field = "location.address.state.keyword" - elif search_type == "country": - field = "location.address.country.keyword" - else: - field = "location.name.keyword" - - must_term_filters.append({"term": {field: val.get("name")}}) - elif key == "coverage": - must_term_filters.append( - nested_query( - path="coverages", - query={"terms": {get_aggregation_field(key): val}}, - name="coverage", - ) - ) - elif key == "coverage_status": - if val == ["planned"]: - must_term_filters.append( - nested_query( - path="coverages", - query={"terms": {"coverages.coverage_status": ["coverage intended"]}}, - name="coverage_status", - ) - ) - must_not_term_filters.append( - nested_query( - path="coverages", - query={"exists": {"field": "coverages.delivery_id"}}, - ) - ) - must_not_term_filters.append( - nested_query( - path="coverages", - query={"terms": {"coverages.workflow_status": ["completed"]}}, - name="workflow_status", 
- ) - ) - elif val == ["may be"]: - must_term_filters.append( - nested_query( - path="coverages", - query={ - "terms": { - "coverages.coverage_status": [ - "coverage not decided yet", - "coverage upon request", - ], - }, - }, - name="coverage_status", - ) - ) - elif val == ["not planned"]: - must_not_term_filters.append( - nested_query( - path="coverages", - query={"exists": {"field": "coverages"}}, - name="coverage_status", - ) - ) - elif val == ["completed"]: - should_term_filters = [] - # Check if "delivery_id" is present - should_term_filters.append( - nested_query( - path="coverages", - query={"exists": {"field": "coverages.delivery_id"}}, - ) - ) - - # If "delivery_id" is not present, check "workflow_status" - should_term_filters.append( - nested_query( - path="coverages", - query={"terms": {"coverages.workflow_status": ["completed"]}}, - name="workflow_status", - ) - ) - must_term_filters.append({"bool": {"should": should_term_filters}}) - elif val == ["not intended"]: - must_term_filters.append( - nested_query( - path="coverages", - query={"terms": {"coverages.coverage_status": ["coverage not intended"]}}, - name="coverage_status", - ) - ) - elif key == "agendas": - must_term_filters.append( - nested_query( - path="planning_items", - query={"terms": {get_aggregation_field(key): val}}, - name="agendas", - ) - ) - else: - if item_type != "events": - agg_field = get_aggregation_field(key) - must_term_filters.append( - { - "bool": { - "minimum_should_match": 1, - "should": [ - get_filter_query(key, val, agg_field, get_nested_config("agenda", key)), - nested_query( - path="planning_items", - query={"bool": {"filter": [{"terms": {f"planning_items.{agg_field}": val}}]}}, - name=key, - inner_hits=highlights, - ), - ], - }, - } - ) - else: - must_term_filters.append( - get_filter_query(key, val, get_aggregation_field(key), get_nested_config("agenda", key)) - ) - - return { - "must_term_filters": must_term_filters, - "must_not_term_filters": must_not_term_filters, - 
} - - -def remove_fields(source, fields): - """Add fields to remove the elastic search - - :param dict source: elasticsearch query object - :param fields: list of fields - """ - if not source.get("_source"): - source["_source"] = {} - - if not source.get("_source").get("exclude"): - source["_source"]["exclude"] = [] - - source["_source"]["exclude"].extend(fields) - - -def planning_items_query_string(query, fields=None): - return query_string(query, fields=fields or ["planning_items.*"]) - - -def is_events_only_access(user, company): - if user and company and not is_admin(user): - return company.get("events_only", False) - return False - - -def filter_active_users(user_ids, user_dict, company_dict, events_only=False): - active = [] - for _id in user_ids: - user = user_dict.get(str(_id)) - if user and (not user.get("company") or str(user.get("company", "")) in company_dict): - if ( - events_only - and user.get("company") - and (company_dict.get(str(user.get("company", ""))) or {}).get("events_only") - ): - continue - active.append(_id) - return active - - -class AgendaService(BaseSearchService): - section = "agenda" - limit_days_setting = None - default_sort = [{"dates.start": "asc"}] - - @property - def default_page_size(self) -> int: - return get_app_config("AGENDA_PAGE_SIZE", 250) - - def get_advanced_search_fields(self, search: SearchQuery) -> List[str]: - fields = super().get_advanced_search_fields(search) - - if "slugline" in fields: - # Add ``slugline`` field for Planning & Coverages too - fields.extend(["planning_items.slugline", "coverages.slugline"]) - - if "headline" in fields: - # Add ``headline`` field for Planning items too - fields.append("planning_items.headline") - - if "description" in fields: - # Replace ``description`` alias with appropriate description fields - fields.remove("description") - fields.extend( - ["definition_short", "definition_long", "description_text", "planning_items.description_text"] - ) - - return fields - - def 
on_fetched(self, doc): - self.enhance_items(doc[ITEMS]) - - def on_fetched_item(self, doc): - self.enhance_items([doc]) - - def enhance_items(self, docs): - for doc in docs: - self.enhance_coverages(doc.get("coverages") or []) - doc.setdefault("_hits", {}) - doc["_hits"]["matched_event"] = doc.pop("_search_matched_event", False) - - if not doc.get("planning_items"): - continue - - doc["_hits"]["matched_planning_items"] = [plan["_id"] for plan in doc.get("planning_items") or []] - - # Filter based on _inner_hits - inner_hits = doc.pop("_inner_hits", {}) - - # If the search matched the Event - # then only count Planning based filters when checking ``_inner_hits`` - if doc["_hits"]["matched_event"]: - inner_hits = {key: val for key, val in inner_hits.items() if key in planning_filters} - - if not inner_hits or not doc.get("planning_items"): - continue - - if len([f for f in inner_hits.keys() if f in coverage_filters]) > 0: - # Collect hits for 'coverage' and 'coverage_status' separately to other inner_hits - coverages_by_filter = { - key: [item.get("coverage_id") for item in items] - for key, items in inner_hits.items() - if key in ["coverage", "coverage_status"] - } - unique_coverage_ids = set( - [coverage_id for items in coverages_by_filter.values() for coverage_id in items] - ) - doc["_hits"]["matched_coverages"] = [ - coverage_id - for coverage_id in unique_coverage_ids - if all([coverage_id in items for items in coverages_by_filter.values()]) - ] - - if doc["item_type"] == "planning": - # If this is a Planning item, then ``inner_hits`` should only include the - # fields relevant to the Coverages (as this is the only nested field of a Planning item) - inner_hits = {key: val for key, val in inner_hits.items() if key in planning_filters} - - if len(inner_hits.keys()) > 0: - # Store matched Planning IDs into matched_planning_items - # The Planning IDs must be in all supplied ``_inner_hits`` - # In order to be included (i.e. 
match all nested planning queries) - items_by_filter = { - key: [item.get("guid") or item.get("planning_id") for item in items] - for key, items in inner_hits.items() - } - unique_ids = set([item_id for items in items_by_filter.values() for item_id in items]) - doc["_hits"]["matched_planning_items"] = [ - item_id for item_id in unique_ids if all([item_id in items for items in items_by_filter.values()]) - ] - - def enhance_coverages(self, coverages): - completed_coverages = [ - c - for c in coverages - if c["workflow_status"] == ASSIGNMENT_WORKFLOW_STATE.COMPLETED and len(c.get("deliveries") or []) > 0 - ] - # Enhance completed coverages in general - add story's abstract/headline/slugline - text_delivery_ids = [ - c.get("delivery_id") - for c in completed_coverages - if c.get("delivery_id") and c.get("coverage_type") == "text" - ] - # TODO-ASYNC: Use new WireSearchServiceAsync service when Agenda is migrated to async - wire_search_service = get_resource_service("wire_search") - if text_delivery_ids: - wire_items = wire_search_service.get_items(text_delivery_ids) - if wire_items.count() > 0: - for item in wire_items: - c = [c for c in completed_coverages if c.get("delivery_id") == item.get("_id")][0] - self.enhance_coverage_with_wire_details(c, item) - - media_coverages = [c for c in completed_coverages if c.get("coverage_type") != "text"] - for c in media_coverages: - try: - c["deliveries"][0]["delivery_href"] = c["delivery_href"] = ( - get_current_app().as_any().set_photo_coverage_href(c, None, c["deliveries"]) - ) - except Exception as e: - logger.exception(e) - logger.error("Failed to generate delivery_href for coverage={}".format(c.get("coverage_id"))) - - def enhance_coverage_with_wire_details(self, coverage, wire_item): - coverage["publish_time"] = wire_item.get("publish_schedule") or wire_item.get("firstpublished") - - def get(self, req, lookup): - cursor = super().get(req, lookup) - - args = req.args - if args.get("itemType") is None or 
(args.get("date_from") and args.get("date_to")): - matching_event_ids: Set[str] = ( - set() if args.get("itemType") is not None else self._get_event_ids_matching_query(req, lookup) - ) - date_range = {} if not (args.get("date_from") and args.get("date_to")) else get_date_filters(args) - - for doc in cursor.docs: - if doc["_id"] in matching_event_ids: - doc["_search_matched_event"] = True - if date_range: - # make the items display on the featured day, - # it's used in ui instead of dates.start and dates.end - doc.update( - { - "_display_from": date_range.get("gt"), - "_display_to": date_range.get("lt"), - } - ) - - return cursor - - def _get_event_ids_matching_query(self, req, lookup) -> Set[str]: - """Re-run the query to retrieve the list of Event IDs matching the query - - This is used to show ALL Planning Items for the Event if the search query matched the parent Event - """ - - orig_args = req.args - req.args = {key: val for key, val in dict(req.args).items() if key not in planning_filters} - req.args["itemType"] = "events" - req.args["aggs"] = "false" - req.projection = json.dumps({"_id": 1}) - item_ids = set([item["_id"] for item in super().get(req, lookup)]) - req.args = orig_args - return item_ids - - def prefill_search_query(self, search: SearchQuery, req=None, lookup=None): - """Generate the search query instance - - :param newsroom.search.SearchQuery search: The search query instance - :param ParsedRequest req: The parsed in request instance from the endpoint - :param dict lookup: The parsed in lookup dictionary from the endpoint - """ - - super().prefill_search_query(search, req, lookup) - search.item_type = ( - "events" if is_events_only_access(search.user, search.company) else search.args.get("itemType") - ) - - def prefill_search_items(self, search): - """Prefill the item filters - - :param newsroom.search.SearchQuery search: The search query instance - """ - - pass - - def apply_filters(self, search: SearchQuery, section_filters=None): - """Generate 
and apply the different search filters - - :param newsroom.search.SearchQuery search: the search query instance - """ - # First construct the product query - self.apply_company_filter(search) - - search.planning_items_should = [] - self.apply_products_filter(search) - - if search.planning_items_should: - search.query["bool"]["should"].append( - nested_query( - "planning_items", - { - "bool": { - "should": search.planning_items_should, - "minimum_should_match": 1, - } - }, - name="products", - ) - ) - search.query["bool"]["minimum_should_match"] = 1 - - # Append the product query to the agenda query - agenda_query = build_agenda_query() - agenda_query["bool"]["filter"].append(search.query) - search.query = agenda_query - - # Apply agenda based filters - self.apply_section_filter(search, section_filters) - self.apply_request_filter(search) - - if not is_admin_or_internal(search.user): - remove_fields(search.source, PRIVATE_FIELDS) - - if search.item_type == "events": - # no adhoc planning items and remove planning items and coverages fields - search.query["bool"]["filter"].append( - { - "bool": { - "should": [ - {"term": {"item_type": "event"}}, - { - # Match Events before ``item_type`` field was added - "bool": { - "must_not": [{"exists": {"field": "item_type"}}], - "filter": [{"exists": {"field": "event_id"}}], - }, - }, - ], - "minimum_should_match": 1, - }, - } - ) - remove_fields(search.source, PLANNING_ITEMS_FIELDS) - elif search.item_type == "planning": - search.query["bool"]["filter"].append( - { - "bool": { - "should": [ - {"term": {"item_type": "planning"}}, - { - # Match Planning before ``item_type`` field was added - "bool": { - "must_not": [ - {"exists": {"field": "item_type"}}, - {"exists": {"field": "event_id"}}, - ], - }, - }, - ], - "minimum_should_match": 1, - } - } - ) - elif get_app_config("AGENDA_DEFAULT_FILTER_HIDE_PLANNING"): - search.query["bool"]["filter"].append( - { - "bool": { - "should": [ - {"term": {"item_type": "event"}}, - { - # Match 
Events before ``item_type`` field was added - "bool": { - "must_not": [{"exists": {"field": "item_type"}}], - "filter": [{"exists": {"field": "event_id"}}], - }, - }, - ], - "minimum_should_match": 1, - }, - } - ) - remove_fields(search.source, ["planning_items", "display_dates"]) - else: - # Don't include Planning items that are associated with an Event - search.query["bool"]["filter"].append( - { - "bool": { - "should": [ - {"bool": {"must_not": [{"exists": {"field": "item_type"}}]}}, - {"term": {"item_type": "event"}}, - { - "bool": { - "filter": [{"term": {"item_type": "planning"}}], - "must_not": [{"exists": {"field": "event_id"}}], - }, - }, - ], - "minimum_should_match": 1, - }, - } - ) - - def apply_product_filter(self, search, product): - """Generate the filter for a single product - - :param newsroom.search.SearchQuery search: The search query instance - :param dict product: The product to filter - :return: - """ - if search.args.get("requested_products") and product["_id"] not in search.args["requested_products"]: - return - - if product.get("query"): - search.query["bool"]["should"].append(self.query_string(product["query"])) - - if product.get("planning_item_query") and search.item_type != "events": - search.planning_items_should.append(planning_items_query_string(product.get("planning_item_query"))) - - def apply_request_filter(self, search: SearchQuery, highlights=False): - """Generate the request filters - - :param newsroom.search.SearchQuery search: The search query instance - """ - - if search.args.get("q"): - test_query: BoolQuery = {"bool": {"should": []}} - try: - q = json.loads(search.args.get("q")) - if isinstance(q, dict): - # used for product testing - if q.get("query"): - test_query["bool"]["should"].append(self.query_string(q.get("query"))) - - if q.get("planning_item_query"): - test_query["bool"]["should"].append( - nested_query( - "planning_items", - planning_items_query_string(q.get("planning_item_query")), - name="product_test", - ) - 
) - - if test_query["bool"]["should"]: - search.query["bool"]["filter"].append(test_query) - except Exception: - pass - - if not test_query["bool"]["should"]: - search.query["bool"]["filter"].append( - self.get_agenda_query(search.args["q"], search.item_type == "events") - ) - - if search.args.get("id"): - search.query["bool"]["filter"].append({"ids": {"values": [search.args["id"]]}}) - - if search.args.get("ids"): - search.query["bool"]["filter"].append({"ids": {"values": search.args["ids"]}}) - - if search.args.get("bookmarks"): - set_saved_items_query(search.query, search.args["bookmarks"]) - - if search.args.get("date_from") or search.args.get("date_to"): - _set_event_date_range(search) - - filters = self.parse_filters(search) - if filters: - self.set_bool_query_from_filters(search.query["bool"], filters, search.item_type, highlights=highlights) - - self.apply_request_advanced_search(search) - - def set_post_filter(self, source: Dict[str, Any], req: ParsedRequest, item_type: Optional[str] = None): - filters = req.args.get("filter") - if isinstance(filters, str): - filters = json.loads(filters) - - if not filters: - return - - if get_app_config("FILTER_BY_POST_FILTER", False): - source["post_filter"] = {"bool": {}} - self.set_bool_query_from_filters(source["post_filter"]["bool"], filters, item_type) - - def gen_source_from_search(self, search): - """Generate the eve source object from the search query instance - - :param newsroom.search.SearchQuery search: The search query instance - """ - - super().gen_source_from_search(search) - - self.set_post_filter(search.source, search, search.item_type) - - if ( - not search.source["from"] - and not search.args.get("bookmarks") - and strtobool(search.args.get("aggs", "true")) - ): - # avoid aggregations when handling pagination - search.source["aggs"] = get_agenda_aggregations(search.item_type == "events") - else: - search.source.pop("aggs", None) - - def get_items(self, item_ids): - query = { - "bool": { - "filter": 
[{"terms": {"_id": item_ids}}], - } - } - get_resource_service("section_filters").apply_section_filter(query, self.section) - return self.get_agenda_items_by_query(query, size=len(item_ids)) - - def get_agenda_items_by_query(self, query, size=50): - try: - source = {"query": query} - - if size: - source["size"] = size - - req = ParsedRequest() - req.args = {"source": json.dumps(source)} - - return self.internal_get(req, None) - except Exception as exc: - logger.error( - "Error in get_items for agenda query: {}".format(json.dumps(source)), - exc, - exc_info=True, - ) - - def get_matching_bookmarks(self, item_ids, active_users, active_companies): - """Returns a list of user ids bookmarked any of the given items - - :param item_ids: list of ids of items to be searched - :param active_users: user_id, user dictionary - :param active_companies: company_id, company dictionary - :return: - """ - bookmark_users = [] - - search_results = self.get_items(item_ids) - - if not search_results: - return bookmark_users - - for result in search_results.hits["hits"]["hits"]: - bookmarks = result["_source"].get("watches", []) - for bookmark in bookmarks: - user = active_users.get(bookmark) - if user and str(user.get("company", "")) in active_companies: - bookmark_users.append(bookmark) - - return bookmark_users - - async def set_delivery(self, wire_item): - if not wire_item.get("coverage_id"): - return - - def is_delivery_validated(coverage, item): - latest_delivery = get_latest_available_delivery(coverage) - if not latest_delivery or not item.get("rewrite_sequence"): - return True - - if (item.get("rewrite_sequence") or 0) >= latest_delivery.get("sequence_no", 0) or ( - item.get("publish_schedule") or item.get("firstpublished") - ) >= latest_delivery.get("publish_time"): - return True - - return False - - query = { - "bool": { - "filter": [ - { - "nested": { - "path": "coverages", - "query": { - "bool": {"filter": [{"term": {"coverages.coverage_id": wire_item["coverage_id"]}}]} - }, 
- } - } - ], - } - } - - agenda_items = self.get_agenda_items_by_query(query) - agenda_updated_notification_sent = False - - def update_coverage_details(coverage): - coverage["delivery_id"] = wire_item["guid"] - coverage["delivery_href"] = url_for_wire( - None, - _external=False, - section="wire.item", - item_id=wire_item["guid"], - ) - coverage["workflow_status"] = ASSIGNMENT_WORKFLOW_STATE.COMPLETED - deliveries = coverage.get("deliveries") or [] - d = next( - (d for d in deliveries if d.get("delivery_id") == wire_item["guid"]), - None, - ) - if d and d.get("delivery_state") != "published": - d["delivery_state"] = "published" - d["publish_time"] = parse_date_str(wire_item.get("publish_schedule") or wire_item.get("firstpublished")) - return d - - for item in agenda_items: - self.enhance_coverage_watches(item) - - parent_coverage = next( - (c for c in item.get("coverages") or [] if c["coverage_id"] == wire_item["coverage_id"]), None - ) - if not parent_coverage or not is_delivery_validated(parent_coverage, item): - continue - - delivery = update_coverage_details(parent_coverage) - planning_item = next( - (p for p in item.get("planning_items") or [] if p["_id"] == parent_coverage["planning_id"]), None - ) - planning_updated = False - if planning_item: - coverage = next( - (c for c in planning_item.get("coverages") or [] if c["coverage_id"] == wire_item["coverage_id"]), - None, - ) - if coverage: - planning_updated = True - update_coverage_details(coverage) - - if not planning_updated: - self.system_update(item["_id"], {"coverages": item["coverages"]}, item) - else: - updates = { - "coverages": item["coverages"], - "planning_items": item["planning_items"], - } - self.system_update(item["_id"], updates, item) - - updated_agenda = get_entity_or_404(item.get("_id"), "agenda") - # Notify agenda to update itself with new details of coverage - self.enhance_coverage_with_wire_details(parent_coverage, wire_item) - await push_agenda_item_notification("new_item", item=item) - 
- # If published first time, coverage completion will trigger email - not needed now - if (delivery or {}).get("sequence_no", 0) > 0 and not agenda_updated_notification_sent: - agenda_updated_notification_sent = True - await self.notify_agenda_update(updated_agenda, updated_agenda, None, True, None, parent_coverage) - return agenda_items - - def set_bool_query_from_filters( - self, - bool_query: BoolQueryParams, - filters: Dict[str, Any], - item_type: Optional[str] = None, - highlights=False, - ): - filter_terms = _filter_terms(filters, item_type, highlights=highlights) - bool_query.setdefault("filter", []) - bool_query["filter"] += filter_terms["must_term_filters"] - - bool_query.setdefault("must_not", []) - bool_query["must_not"] += filter_terms["must_not_term_filters"] - - def get_matching_topics(self, item_id, topics, users, companies): - """Returns a list of topic ids matching to the given item_id - - :param item_id: item id to be tested against all topics - :param topics: list of topics - :param users: user_id, user dictionary - :param companies: company_id, company dictionary - :return: - """ - - return self.get_matching_topics_for_item( - topics, - users, - companies, - { - "bool": { - "must_not": [ - {"term": {"state": "killed"}}, - ], - "must": [ - {"term": {"_id": item_id}}, - ], - "should": [], - } - }, - ) - - def enhance_coverage_watches(self, item): - for c in item.get("coverages") or []: - if c.get("watches"): - c["watches"] = [ObjectId(u) for u in c["watches"]] - - async def notify_new_coverage(self, agenda, wire_item): - user_dict = await get_user_dict() - company_dict = await get_company_dict() - notify_user_ids = filter_active_users(agenda.get("watches", []), user_dict, company_dict, events_only=True) - for user_id in notify_user_ids: - user = user_dict[str(user_id)] - await send_coverage_notification_email(user, agenda, wire_item) - - async def notify_agenda_update( - self, - update_agenda, - original_agenda, - item=None, - events_only=False, - 
related_planning_removed=None, - coverage_updated=None, - ): - agenda = deepcopy(update_agenda) - if agenda and original_agenda.get("state") != WORKFLOW_STATE.KILLED: - user_dict = await get_user_dict() - company_dict = await get_company_dict() - coverage_watched = None - for c in original_agenda.get("coverages") or []: - if len(c.get("watches") or []) > 0: - coverage_watched = True - break - - notify_user_ids = filter_active_users( - original_agenda.get("watches", []), user_dict, company_dict, events_only - ) - - users = [user_dict[str(user_id)] for user_id in notify_user_ids] - - if len(notify_user_ids) == 0 and not coverage_watched: - return - - def get_detailed_coverage(cov): - plan = next( - (p for p in (agenda.get("planning_items") or []) if p["guid"] == cov.get("planning_id")), - None, - ) - if plan and plan.get("state") != WORKFLOW_STATE.KILLED: - detail_cov = next( - (c for c in (plan.get("coverages") or []) if c.get("coverage_id") == cov.get("coverage_id")), - None, - ) - if detail_cov: - detail_cov["watches"] = cov.get("watches") - - return detail_cov - - original_cov = next( - (c for c in original_agenda.get("coverages") or [] if c["coverage_id"] == cov["coverage_id"]), - cov, - ) - cov["watches"] = original_cov.get("watches") or [] - return cov - - def fill_all_coverages(skip_coverages=[], cancelled=False, use_original_agenda=False): - fill_list = ( - coverage_updates["unaltered_coverages"] - if not cancelled - else coverage_updates["cancelled_coverages"] - ) - for coverage in (agenda if not use_original_agenda else original_agenda).get("coverages") or []: - if not next( - (s for s in skip_coverages if s.get("coverage_id") == coverage.get("coverage_id")), - None, - ): - detailed_coverage = get_detailed_coverage(coverage) - if detailed_coverage: - fill_list.append(detailed_coverage) - - coverage_updates = { - "modified_coverages": [] if not coverage_updated else [coverage_updated], - "cancelled_coverages": [], - "unaltered_coverages": [], - } - - 
only_new_coverages = len(coverage_updates["modified_coverages"]) == 0 - time_updated = False - state_changed = False - coverage_modified = False - - # Send notification for only these state changes - notify_states = [ - WORKFLOW_STATE.CANCELLED, - WORKFLOW_STATE.RESCHEDULED, - WORKFLOW_STATE.POSTPONED, - WORKFLOW_STATE.KILLED, - WORKFLOW_STATE.SCHEDULED, - ] - - if not coverage_updated: # If not story updates - but from planning side - if not related_planning_removed: - # Send notification if time got updated - if original_agenda.get("dates") and agenda.get("dates"): - time_updated = (original_agenda.get("dates") or {}).get("start").replace(tzinfo=None) != ( - agenda.get("dates") or {} - ).get("start").replace(tzinfo=None) or (original_agenda.get("dates") or {}).get("end").replace( - tzinfo=None - ) != ( - agenda.get("dates") or {} - ).get( - "end" - ).replace( - tzinfo=None - ) - - if agenda.get("state") and agenda.get("state") != original_agenda.get("state"): - state_changed = agenda.get("state") in notify_states - - if not state_changed: - if time_updated: - fill_all_coverages() - else: - for coverage in agenda.get("coverages") or []: - existing_coverage = next( - ( - c - for c in original_agenda.get("coverages") or [] - if c["coverage_id"] == coverage["coverage_id"] - ), - None, - ) - detailed_coverage = get_detailed_coverage(coverage) - if detailed_coverage: - if not existing_coverage: - if coverage["workflow_status"] != WORKFLOW_STATE.CANCELLED: - coverage_updates["modified_coverages"].append(detailed_coverage) - elif coverage.get( - "workflow_status" - ) == WORKFLOW_STATE.CANCELLED and existing_coverage.get( - "workflow_status" - ) != coverage.get( - "workflow_status" - ): - coverage_updates["cancelled_coverages"].append(detailed_coverage) - elif ( - ( - coverage.get("delivery_state") != existing_coverage.get("delivery_state") - and coverage.get("delivery_state") == "published" - ) - or ( - coverage.get("workflow_status") != 
existing_coverage.get("workflow_status") - and coverage.get("workflow_status") == "completed" - ) - or (existing_coverage.get("scheduled") != coverage.get("scheduled")) - ): - coverage_updates["modified_coverages"].append(detailed_coverage) - only_new_coverages = False - elif detailed_coverage["coverage_id"] != (coverage_updated or {}).get( - "coverage_id" - ): - coverage_updates["unaltered_coverages"].append(detailed_coverage) - - # Check for removed coverages - show it as cancelled - if item and item.get("type") == "planning": - for original_cov in original_agenda.get("coverages") or []: - updated_cov = next( - ( - c - for c in (agenda.get("coverages") or []) - if c.get("coverage_id") == original_cov.get("coverage_id") - ), - None, - ) - if not updated_cov: - coverage_updates["cancelled_coverages"].append(original_cov) - else: - fill_all_coverages( - cancelled=False if agenda.get("state") == WORKFLOW_STATE.SCHEDULED else True, - use_original_agenda=True, - ) - else: - fill_all_coverages(related_planning_removed.get("coverages") or []) - # Add removed coverages: - for coverage in related_planning_removed.get("coverages") or []: - detailed_coverage = get_detailed_coverage(coverage) - if detailed_coverage: - coverage_updates["cancelled_coverages"].append(detailed_coverage) - - if len(coverage_updates["modified_coverages"]) > 0 or len(coverage_updates["cancelled_coverages"]) > 0: - coverage_modified = True - - if coverage_updated or related_planning_removed or time_updated or state_changed or coverage_modified: - agenda["name"] = agenda.get("name", original_agenda.get("name")) - agenda["definition_short"] = agenda.get("definition_short", original_agenda.get("definition_short")) - agenda["ednote"] = agenda.get("ednote", original_agenda.get("ednote")) - agenda["state_reason"] = agenda.get("state_reason", original_agenda.get("state_reason")) - action = "been updated." 
- if state_changed: - action = "been {}.".format( - agenda.get("state") - if agenda.get("state") != WORKFLOW_STATE.KILLED - else "removed from the calendar" - ) - - if ( - len(coverage_updates["modified_coverages"]) > 0 - and only_new_coverages - and len(coverage_updates["cancelled_coverages"]) == 0 - ): - action = "new coverage(s)." - - message = "The {} you have been following has {}".format( - "event" if agenda.get("event") else "coverage plan", action - ) - if agenda.get("state_reason"): - reason_prefix = agenda.get("state_reason").find(":") - if reason_prefix > 0: - message = "{} {}".format( - message, - agenda["state_reason"][(reason_prefix + 1) : len(agenda["state_reason"])], - ) - - # append coverage watching users too - except for unaltered_coverages - for c in coverage_updates["cancelled_coverages"] + coverage_updates["modified_coverages"]: - if c.get("watches"): - notify_user_ids = filter_active_users(c["watches"], user_dict, company_dict, events_only) - users = users + [user_dict[str(user_id)] for user_id in notify_user_ids] - - # Send notifications to users - await save_user_notifications( - [ - dict( - user=user["_id"], - item=agenda.get("_id"), - resource="agenda", - action="watched_agenda_updated", - data=None, - ) - for user in users - ] - ) - - for user in users: - await send_agenda_notification_email( - user, - agenda, - message, - original_agenda, - coverage_updates, - related_planning_removed, - coverage_updated, - time_updated, - coverage_modified, - ) - - def get_saved_items_count(self): - search = SearchQuery() - search.query = build_agenda_query() - - self.prefill_search_query(search) - self.apply_filters(search) - set_saved_items_query(search.query, str(search.user["_id"])) - - cursor = self.get_agenda_items_by_query(search.query, size=0) - return cursor.count() if cursor else 0 - - def get_agenda_query(self, query, events_only=False): - if events_only: - return self.query_string(query) - else: - return { - "bool": { - "should": [ - 
self.query_string(query), - nested_query("planning_items", planning_items_query_string(query), name="query"), - ] - }, - } diff --git a/newsroom/agenda/agenda_search.py b/newsroom/agenda/agenda_search.py new file mode 100644 index 000000000..b63c53349 --- /dev/null +++ b/newsroom/agenda/agenda_search.py @@ -0,0 +1,178 @@ +from typing import Any +from copy import deepcopy +import logging + +from superdesk.core.types import Request, Response, SearchRequest, ESQuery + +from newsroom.types import ( + AgendaItem, + AgendaItemType, + SectionEnum, + UserResourceModel, + CompanyResource, +) +from newsroom.auth.utils import get_company_from_request +from newsroom.search.types import NewshubSearchRequest +from newsroom.search.base_web_service import BaseWebSearchService +from newsroom.search.filters import ( + prefill_products, + apply_products_filter, + apply_ids_filter, + apply_advanced_search, + apply_section_filter, + apply_company_filter, +) + +from .filters import ( + AgendaSearchRequestArgs, + default_agenda_filters, + get_date_filters, + aggregations, + filters_without_dates, + prefill_item_type_arg, + apply_item_state_filter, + apply_agenda_query_string, + apply_agenda_filters, + apply_agenda_date_filters, +) +from .agenda_service import AgendaItemService +from .utils import remove_restricted_coverage_info + +logger = logging.getLogger(__name__) + + +def get_agenda_aggregations(events_only: bool = False): + aggs = deepcopy(aggregations) + if events_only: + aggs.pop("coverage", None) + aggs.pop("planning_items", None) + aggs.pop("urgency", None) + aggs.pop("agendas", None) + return aggs + + +class AgendaSearchServiceAsync(BaseWebSearchService[AgendaSearchRequestArgs, AgendaItem]): + search_args_class = AgendaSearchRequestArgs + filters = default_agenda_filters + section = SectionEnum.AGENDA + default_sort = [("dates.start", 1)] + default_page_size = 250 + service: AgendaItemService + + get_items_by_id_filters = [ + apply_item_state_filter, + apply_ids_filter, + ] + 
get_topic_items_query_execute_filters = [ + apply_products_filter, + apply_agenda_query_string, + apply_ids_filter, + apply_agenda_filters, + apply_advanced_search, + apply_agenda_date_filters, + ] + get_topic_items_query_user_filters = [ + apply_section_filter, + apply_item_state_filter, + apply_company_filter, + ] + + def __init__(self): + self.service = AgendaItemService() + + async def process_web_request(self, request: Request) -> Response: + search_request = self.get_search_request_instance(request) + elastic_query = await self.run_filters_and_return_query(search_request) + internal_request = SearchRequest( + sort=search_request.args.sort, + max_results=search_request.args.page_size, + page=search_request.args.page, + aggregations=not search_request.args.page and search_request.args.aggs, + projection=search_request.args.projection, + elastic=elastic_query, + ) + + args = search_request.args + + if not args.page and not args.bookmarks and args.aggs: + internal_request.elastic.aggs = get_agenda_aggregations( + search_request.args.item_type == AgendaItemType.EVENT + ) + + cursor = await self.service.find(internal_request) + response, count = await self.get_search_response(internal_request, cursor) + + if args.item_type is None or (args.start_date and args.end_date): + matching_event_ids: set[str] = ( + set() if args.item_type is not None else await self._get_event_ids_matching_query(args) + ) + date_range = {} if not args.start_date and args.end_date else get_date_filters(args) + for item in response["_items"]: + if item["_id"] in matching_event_ids: + item["_search_matched_event"] = True + if date_range: + # make the items display on the featured day, + # it's used in ui instead of dates.start and dates.end + item.update( + { + "_display_from": date_range.get("gt"), + "_display_to": date_range.get("lt"), + } + ) + + await self.service.enhance_item(item) + else: + for item in response["_items"]: + await self.service.enhance_item(item) + + return 
Response(response, 200, [("X-Total-Count", count)]) + + async def _get_event_ids_matching_query(self, args: AgendaSearchRequestArgs) -> set[str]: + search_request = NewshubSearchRequest[AgendaSearchRequestArgs]( + section=self.section, + web_request=None, + args=args.model_copy(), + search=ESQuery(), + ) + search_request.args.item_type = AgendaItemType.EVENT + search_request.args.aggs = False + search_request.args.projection = {"_id"} + + cursor = await self.search(search_request) + return set([item["_id"] for item in await cursor.to_list_raw()]) + + async def get_saved_items_count(self, user: UserResourceModel, company: CompanyResource | None) -> int: + def set_user_and_company(request: NewshubSearchRequest) -> None: + request.current_user = request.user = user + request.company = company + request.is_admin = user.is_admin() + + cursor = await self.search( + AgendaSearchRequestArgs(bookmarks=[user.id], page_size=0), + filters=[ + set_user_and_company, + prefill_products, + prefill_item_type_arg, + ] + + filters_without_dates, + ) + return await cursor.count() + + async def get_items_for_action(self, item_ids: list[str]) -> list[dict[str, Any]]: + """Searches for item by ID, for use by downloads, sharing etc + + If the current user's company has ``restrict_coverage_info`` config turned on, then + for each item removes the restricted coverage information + + :param item_ids: A list of item IDs to search for + :returns: The list of Agenda items + """ + + cursor = await self.get_items_by_id(item_ids) + items = await cursor.to_list_raw() + + company = get_company_from_request(None) + if company and company.restrict_coverage_info: + remove_restricted_coverage_info(items) + + return items diff --git a/newsroom/agenda/agenda_service.py b/newsroom/agenda/agenda_service.py new file mode 100644 index 000000000..7f72eb171 --- /dev/null +++ b/newsroom/agenda/agenda_service.py @@ -0,0 +1,396 @@ +from typing import cast, Any, Sequence +from datetime import datetime +import 
logging + +from bson import ObjectId + +from superdesk.core.resources.cursor import ElasticsearchResourceCursorAsync +from superdesk.core.resources import AsyncResourceService + +from planning.common import ASSIGNMENT_WORKFLOW_STATE + +from newsroom.types import AgendaItem, AgendaWorkflowState +from newsroom.core import get_current_wsgi_app +from newsroom.utils import parse_date_str, parse_dates + +from newsroom.wire import url_for_wire, WireSearchServiceAsync + +from .filters import planning_filters, coverage_filters +from .notifications import notify_agenda_update +from .utils import get_latest_available_delivery, push_agenda_item_notification, TO_BE_CONFIRMED_FIELD + + +logger = logging.getLogger(__name__) + + +class AgendaItemService(AsyncResourceService[AgendaItem]): + async def _convert_dicts_to_model(self, docs: Sequence[AgendaItem | dict[str, Any]]) -> list[AgendaItem]: + items: list[AgendaItem] = [] + for item in docs: + if isinstance(item, AgendaItem): + items.append(item) + elif item.get("type") == "event": + agenda, _ = self.convert_event_to_agenda_dict({}, item) + items.append(AgendaItem.from_dict(agenda)) + elif item.get("type") == "planning": + agenda, _ = await self.convert_planning_to_agenda_dict({}, item, force_adhoc=True, add_coverages=True) + items.append(AgendaItem.from_dict(agenda)) + else: + items.append(AgendaItem.from_dict(item)) + + return items + + def convert_event_to_agenda_dict( + self, agenda: dict[str, Any], event: dict[str, Any], set_doc_id: bool = True + ) -> tuple[dict[str, Any], list[str]]: + """ + Sets agenda metadata from a given event + """ + from newsroom.push.utils import format_qcode_items, set_dates, get_event_dates + + app = get_current_wsgi_app() + if event.get("files"): + for file_ref in event["files"]: + if file_ref.get("media"): + file_ref.setdefault("href", app.upload_url(file_ref["media"])) + + plan_ids = event.pop("plans", []) + parse_dates(event) + + # setting _id of agenda to be equal to event + guid = 
event.get("guid") or event["_id"] + if set_doc_id: + agenda.setdefault("_id", guid) + + agenda["item_type"] = "event" + agenda["guid"] = guid + agenda["event_id"] = guid + agenda["recurrence_id"] = event.get("recurrence_id") + agenda["name"] = event.get("name") + agenda["slugline"] = event.get("slugline") + agenda["definition_short"] = event.get("definition_short") + agenda["definition_long"] = event.get("definition_long") + agenda["version"] = event.get("version") + agenda["versioncreated"] = event.get("versioncreated") + agenda["calendars"] = event.get("calendars") + agenda["location"] = event.get("location", []) + agenda["ednote"] = event.get("ednote") + agenda["state_reason"] = event.get("state_reason") + agenda["place"] = event.get("place") + agenda["subject"] = format_qcode_items(event.get("subject")) + agenda["products"] = event.get("products") + agenda["service"] = format_qcode_items(event.get("anpa_category")) + agenda["event"] = event + agenda["registration_details"] = event.get("registration_details") + agenda["invitation_details"] = event.get("invitation_details") + agenda["language"] = event.get("language") + agenda["source"] = event.get("source") + + set_dates(agenda) + + agenda["dates"] = get_event_dates(event) + + agenda["state"] = event.get("state") or AgendaWorkflowState.CANCELLED.SCHEDULED + if event.get("pubstatus") == "cancelled": + agenda["state"] = AgendaWorkflowState.CANCELLED + + if event.get("planning_items"): + agenda["planning_items"] = event["planning_items"] + if event.get("coverages"): + agenda["coverages"] = event["coverages"] + + return agenda, plan_ids + + async def convert_planning_to_agenda_dict( + self, + agenda: dict[str, Any], + planning_item: dict[str, Any], + force_adhoc: bool = False, + add_coverages: bool = False, + ) -> tuple[dict[str, Any], bool]: + """Sets agenda metadata from a given planning""" + + from newsroom.push.utils import format_qcode_items, set_dates, get_display_dates + from newsroom.push.agenda_manager 
import AgendaManager + + parse_dates(planning_item) + set_dates(agenda) + + if not planning_item.get("event_item") or force_adhoc: + # adhoc planning item + agenda.setdefault("_id", planning_item["guid"]) + agenda.setdefault("guid", planning_item["guid"]) + agenda["item_type"] = "planning" + + # planning dates is saved as the dates of the new agenda + agenda["dates"] = { + "start": planning_item["planning_date"], + "end": planning_item["planning_date"], + } + if planning_item.get("pubstatus") == "cancelled": + agenda["watches"] = [] + + agenda["name"] = planning_item.get("name") + agenda["headline"] = planning_item.get("headline") + agenda["slugline"] = planning_item.get("slugline") + agenda["ednote"] = planning_item.get("ednote") + agenda["place"] = planning_item.get("place") + agenda["subject"] = format_qcode_items(planning_item.get("subject")) + agenda["products"] = planning_item.get("products") + agenda["urgency"] = planning_item.get("urgency") + agenda["definition_short"] = planning_item.get("description_text") or agenda.get("definition_short") + agenda["definition_long"] = planning_item.get("abstract") or agenda.get("definition_long") + agenda["service"] = format_qcode_items(planning_item.get("anpa_category")) + agenda["state"] = planning_item.get("state") or "scheduled" + agenda["state_reason"] = planning_item.get("state_reason") + agenda["language"] = planning_item.get("language") + agenda["source"] = planning_item.get("source") + + agenda["state"] = planning_item.get("state") or AgendaWorkflowState.CANCELLED.SCHEDULED + if planning_item.get("pubstatus") == "cancelled": + agenda["state"] = AgendaWorkflowState.CANCELLED + + if planning_item.get("event_id"): + agenda["event_id"] = planning_item["event_id"] + elif planning_item.get("event_item") and force_adhoc: + agenda["event_id"] = planning_item["event_item"] + + if not agenda.get("planning_items"): + agenda["planning_items"] = [] + + new_plan = False + plan: dict[str, Any] = next( + (p for p in 
(agenda.get("planning_items") or []) if p.get("guid") == planning_item.get("guid")), + {}, + ) + + if not plan: + new_plan = True + + agenda_versioncreated: datetime = agenda["versioncreated"] + plan_versioncreated: datetime = parse_date_str(planning_item.get("versioncreated") or agenda_versioncreated) + + plan["_id"] = planning_item.get("_id") or planning_item.get("guid") + plan["guid"] = planning_item.get("guid") + plan["slugline"] = planning_item.get("slugline") + plan["description_text"] = planning_item.get("description_text") + plan["headline"] = planning_item.get("headline") + plan["name"] = planning_item.get("name") + plan["abstract"] = planning_item.get("abstract") + plan["place"] = planning_item.get("place") + plan["subject"] = format_qcode_items(planning_item.get("subject")) + plan["service"] = format_qcode_items(planning_item.get("anpa_category")) + plan["urgency"] = planning_item.get("urgency") + plan["planning_date"] = planning_item.get("planning_date") + plan["coverages"] = planning_item.get("coverages") or [] + plan["ednote"] = planning_item.get("ednote") + plan["internal_note"] = planning_item.get("internal_note") + plan["versioncreated"] = plan_versioncreated + plan["firstcreated"] = parse_date_str(planning_item.get("firstcreated") or agenda["firstcreated"]) + plan["state"] = planning_item.get("state") or "scheduled" + plan["state_reason"] = planning_item.get("state_reason") + plan["products"] = planning_item.get("products") + plan["agendas"] = planning_item.get("agendas") + plan[TO_BE_CONFIRMED_FIELD] = planning_item.get(TO_BE_CONFIRMED_FIELD) + plan["language"] = planning_item.get("language") + plan["source"] = planning_item.get("source") + + if new_plan: + agenda["planning_items"].append(plan) + + # Update the versioncreated datetime from Planning item if it's newer than the parent item + try: + if plan_versioncreated > agenda_versioncreated: + agenda["versioncreated"] = plan_versioncreated + except (KeyError, TypeError): + pass + + if 
add_coverages: + agenda["coverages"], _ = await AgendaManager().get_coverages(agenda["planning_items"], [], planning_item) + agenda["display_dates"] = get_display_dates(agenda["planning_items"]) + + return agenda, new_plan + + async def get_by_coverage_id(self, coverage_id: str) -> ElasticsearchResourceCursorAsync[AgendaItem]: + return cast( + ElasticsearchResourceCursorAsync, + await self.search( + { + "query": { + "bool": { + "filter": [ + { + "nested": { + "path": "coverages", + "query": { + "bool": {"filter": [{"term": {"coverages.coverage_id": coverage_id}}]} + }, + } + } + ], + } + } + } + ), + ) + + async def enhance_item(self, doc: dict[str, Any]): + await self.enhance_coverages(doc.get("coverages") or []) + doc.setdefault("_hits", {}) + doc["_hits"]["matched_event"] = doc.pop("_search_matched_event", False) + + if not doc.get("planning_items"): + return + + doc["_hits"]["matched_planning_items"] = [plan["_id"] for plan in doc.get("planning_items") or []] + + # Filter based on _inner_hits + inner_hits = doc.pop("_inner_hits", {}) + + # If the search matched the Event + # then only count Planning based filters when checking ``_inner_hits`` + if doc["_hits"]["matched_event"]: + inner_hits = {key: val for key, val in inner_hits.items() if key in planning_filters} + + if not inner_hits or not doc.get("planning_items"): + return + + if len([f for f in inner_hits.keys() if f in coverage_filters]) > 0: + # Collect hits for 'coverage' and 'coverage_status' separately to other inner_hits + coverages_by_filter = { + key: [item.get("coverage_id") for item in items] + for key, items in inner_hits.items() + if key in ["coverage", "coverage_status"] + } + unique_coverage_ids = set([coverage_id for items in coverages_by_filter.values() for coverage_id in items]) + doc["_hits"]["matched_coverages"] = [ + coverage_id + for coverage_id in unique_coverage_ids + if all([coverage_id in items for items in coverages_by_filter.values()]) + ] + + if doc["item_type"] == "planning": 
+ # If this is a Planning item, then ``inner_hits`` should only include the + # fields relevant to the Coverages (as this is the only nested field of a Planning item) + inner_hits = {key: val for key, val in inner_hits.items() if key in planning_filters} + + if len(inner_hits.keys()) > 0: + # Store matched Planning IDs into matched_planning_items + # The Planning IDs must be in all supplied ``_inner_hits`` + # In order to be included (i.e. match all nested planning queries) + items_by_filter = { + key: [item.get("guid") or item.get("planning_id") for item in items] + for key, items in inner_hits.items() + } + unique_ids = set([item_id for items in items_by_filter.values() for item_id in items]) + doc["_hits"]["matched_planning_items"] = [ + item_id for item_id in unique_ids if all([item_id in items for items in items_by_filter.values()]) + ] + + async def enhance_coverages(self, coverages: list[dict[str, Any]]): + completed_coverages = [ + c + for c in coverages + if c["workflow_status"] == ASSIGNMENT_WORKFLOW_STATE.COMPLETED and len(c.get("deliveries") or []) > 0 + ] + # Enhance completed coverages in general - add story's abstract/headline/slugline + text_delivery_ids: list[str] = [ + c["delivery_id"] for c in completed_coverages if c.get("delivery_id") and c.get("coverage_type") == "text" + ] + if text_delivery_ids: + wire_items = await WireSearchServiceAsync().get_items_by_id(text_delivery_ids) + if await wire_items.count(): + async for item in wire_items: + coverage = [c for c in completed_coverages if c.get("delivery_id") == item.get("_id")][0] + coverage["publish_time"] = item.publish_schedule or item.firstpublished + + async def set_delivery(self, wire_item: dict[str, Any]) -> list[dict[str, Any]]: + coverage_id = wire_item.get("coverage_id") + if not coverage_id: + return [] + + cursor = await self.get_by_coverage_id(coverage_id) + if not await cursor.count(): + return [] + + agenda_items = await cursor.to_list_raw() + agenda_updated_notification_sent = 
False + + def is_delivery_validated(coverage: dict[str, Any]): + latest_delivery = get_latest_available_delivery(coverage) + + return ( + (not latest_delivery or not wire_item.get("rewrite_sequence")) + or ((wire_item.get("rewrite_sequence") or 0) >= latest_delivery.get("sequence_no", 0)) + or ( + (wire_item.get("publish_schedule") or wire_item.get("firstpublished")) + >= latest_delivery.get("publish_time") + ) + ) + + def update_coverage_details(coverage: dict[str, Any]): + coverage["delivery_id"] = wire_item["guid"] + coverage["delivery_href"] = url_for_wire( + None, + _external=False, + section="wire.item", + item_id=wire_item["guid"], + ) + coverage["workflow_status"] = ASSIGNMENT_WORKFLOW_STATE.COMPLETED + deliveries = coverage.get("deliveries") or [] + d = next( + (d for d in deliveries if d.get("delivery_id") == wire_item["guid"]), + None, + ) + if d and d.get("delivery_state") != "published": + d["delivery_state"] = "published" + publish_time_str: str | None = wire_item.get("publish_schedule") or wire_item.get("firstpublished") + d["publish_time"] = parse_date_str(publish_time_str) if publish_time_str else None + return d + + for item in agenda_items: + # Make sure coverage watches are using ObjectIds + for c in item.get("coverages") or []: + if c.get("watches"): + c["watches"] = [ObjectId(u) for u in c["watches"]] + + parent_coverage = next( + (c for c in item.get("coverages") or [] if c["coverage_id"] == wire_item["coverage_id"]), None + ) + if not parent_coverage or not is_delivery_validated(parent_coverage): + continue + + delivery = update_coverage_details(parent_coverage) + planning_item = next( + (p for p in item.get("planning_items") or [] if p["_id"] == parent_coverage["planning_id"]), None + ) + planning_updated = False + if planning_item: + coverage = next( + (c for c in planning_item.get("coverages") or [] if c["coverage_id"] == wire_item["coverage_id"]), + None, + ) + if coverage: + planning_updated = True + update_coverage_details(coverage) + 
+ if not planning_updated: + await self.system_update(item["_id"], {"coverages": item["coverages"]}) + else: + updates = { + "coverages": item["coverages"], + "planning_items": item["planning_items"], + } + await self.system_update(item["_id"], updates) + + updated_agenda = self.find_by_id(item["_id"]) + # Notify agenda to update itself with new details of coverage + parent_coverage["publish_time"] = wire_item.get("publish_schedule") or wire_item.get("firstpublished") + await push_agenda_item_notification("new_item", item=item) + + # If published first time, coverage completion will trigger email - not needed now + if (delivery or {}).get("sequence_no", 0) > 0 and not agenda_updated_notification_sent: + agenda_updated_notification_sent = True + await notify_agenda_update(updated_agenda, updated_agenda, None, True, None, parent_coverage) + + return agenda_items diff --git a/newsroom/agenda/email.py b/newsroom/agenda/email.py index cd3c85775..50c9f231d 100644 --- a/newsroom/agenda/email.py +++ b/newsroom/agenda/email.py @@ -1,6 +1,4 @@ -from typing import Any - -from newsroom.types import UserResourceModel +from newsroom.types import UserResourceModel, AgendaItem from newsroom.email import send_template_email, send_user_email from newsroom.utils import ( get_agenda_dates, @@ -65,7 +63,7 @@ async def send_agenda_notification_email( ) -async def send_coverage_request_email(user: UserResourceModel, message: str, item: dict[str, Any]) -> None: +async def send_coverage_request_email(user: UserResourceModel, message: str, item: AgendaItem) -> None: """ Forms and sends coverage request email :param user: User that makes the request @@ -81,11 +79,11 @@ async def send_coverage_request_email(user: UserResourceModel, message: str, ite recipients = general_settings.get("values").get("coverage_request_recipients").split(",") assert recipients assert isinstance(recipients, list) - url = url_for_agenda({"_id": item["_id"]}, _external=True) + url = url_for_agenda({"_id": item.id}, 
_external=True) name = f"{user.first_name} {user.last_name}" email = user.email - item_name = item.get("name") or item.get("slugline") + item_name = item.name or item.slugline user_company = await user.get_company() company_name = user_company.name if user_company else None @@ -97,7 +95,7 @@ async def send_coverage_request_email(user: UserResourceModel, message: str, ite company=company_name, recipients=recipients, item_name=item_name, - item=item, + item=item.to_dict(), ) await send_template_email( diff --git a/newsroom/agenda/featured_service.py b/newsroom/agenda/featured_service.py new file mode 100644 index 000000000..5cf94685f --- /dev/null +++ b/newsroom/agenda/featured_service.py @@ -0,0 +1,150 @@ +from typing import Any +from datetime import datetime + +from superdesk.core.types import ESQuery, RestGetResponse, RestResponseMeta +from superdesk.core.resources import AsyncResourceService +from superdesk.utc import local_to_utc + +from newsroom.types import FeaturedResourceModel, SectionEnum +from newsroom.utils import get_local_date +from newsroom.search.types import NewshubSearchRequest +from newsroom.search.utils import query_string_for_section + +from .filters import ( + apply_item_state_filter, + apply_section_filter, + apply_agenda_filters, + planning_items_query_string, + nested_query, + aggregations, + AgendaSearchRequestArgs, +) +from .agenda_search import AgendaSearchServiceAsync + + +class FeaturedService(AsyncResourceService[FeaturedResourceModel]): + resource_name = "agenda_featured" + + async def on_create(self, docs: list[FeaturedResourceModel]) -> None: + """ + Add UTC from/to datetimes on save. + Problem is 31.8. in Sydney is from 30.8. 14:00 UTC to 31.8. 13:59 UTC. + And because we query later using UTC, we store those UTC datetimes as + display_from and display_to. 
+ """ + for item in docs: + date = datetime.strptime(item.id, "%Y%m%d") + item.display_from = local_to_utc(item.tz, date.replace(hour=0, minute=0, second=0)) + item.display_to = local_to_utc(item.tz, date.replace(hour=23, minute=59, second=59)) + await super().on_create(docs) + + async def find_one_for_date(self, for_date: datetime) -> FeaturedResourceModel | None: + return await self.find_one(display_from={"$lte": for_date}, display_to={"$gte": for_date}) + + async def get_featured_stories( + self, + date_from: str, + timezone_offset: int = 0, + query_string: str | None = None, + filters: dict[str, Any] | None = None, + from_offset: int = 0, + ) -> RestGetResponse: + for_date = datetime.strptime(date_from, "%d/%m/%Y %H:%M") + local_date = get_local_date( + for_date.strftime("%Y-%m-%d"), + for_date.strftime("%H:%M:%S"), + timezone_offset, + ) + featured_doc = await self.find_one_for_date(local_date) + return await self.featured(featured_doc, query_string, filters, from_offset) + + async def featured( + self, + featured_doc: FeaturedResourceModel | None = None, + query_string: str | None = None, + filters: dict[str, Any] | None = None, + from_offset: int = 0, + ) -> RestGetResponse: + """Return featured items. 
+ + :param Optional[dict] featured_doc: The featured document for the given date + :param Optional[str] query_string: Optional search query to filter the results + :param Optional[str] filter_string: Optional filter query to filter the results + :param int from_offset: Pagination offset for the results + :return: A list of filtered featured items + """ + + if featured_doc is None or featured_doc.items is None or not len(featured_doc.items): + return RestGetResponse( + _items=[], + _meta=RestResponseMeta( + page=from_offset, + max_results=0, + total=0, + ), + ) + + def apply_featured_filters(request: NewshubSearchRequest) -> None: + planning_items_query = nested_query( + "planning_items", + {"bool": {"filter": [{"terms": {"planning_items.guid": featured_doc.items}}]}}, + name="featured", + ) + + if query_string: + request.search.query.filter.append(query_string_for_section(SectionEnum.AGENDA, query_string)) + planning_items_query["nested"]["query"]["bool"]["filter"].append( + planning_items_query_string(query_string) + ) + + request.search.query.filter.append(planning_items_query) + + cursor = await AgendaSearchServiceAsync().search( + NewshubSearchRequest( + args=AgendaSearchRequestArgs( + featured=True, + page_size=len(featured_doc.items), + page=from_offset, + filter=filters, + ), + search=ESQuery(aggs=aggregations if not from_offset else {}), + ), + filters=[ + apply_item_state_filter, + apply_section_filter, + apply_agenda_filters, + apply_featured_filters, + ], + ) + + docs_by_id: dict[str, dict[str, Any]] = {} + for doc in await cursor.to_list_raw(): + for p in doc.get("planning_items") or []: + docs_by_id[p.get("guid")] = doc + + # Update display dates based on the featured document + doc.update( + { + "_display_from": featured_doc.display_from, + "_display_to": featured_doc.display_to, + } + ) + + docs = [] + agenda_ids = set() + for agenda_id in featured_doc.items: + agenda_item = docs_by_id.get(agenda_id) + if agenda_item and agenda_item.get("_id") not in 
agenda_ids: + docs.append(agenda_item) + agenda_ids.add(agenda_item.get("_id")) + + response = RestGetResponse( + _items=docs, + _meta=RestResponseMeta( + page=from_offset, + max_results=len(docs), + total=len(docs), + ), + ) + cursor.extra(response) + return response diff --git a/newsroom/agenda/filters.py b/newsroom/agenda/filters.py new file mode 100644 index 000000000..7a140cdd4 --- /dev/null +++ b/newsroom/agenda/filters.py @@ -0,0 +1,639 @@ +from typing import Any, Annotated +from datetime import datetime + +from pydantic import field_validator, Field, AliasChoices + +from superdesk.core import get_app_config, json +from superdesk.core.types import ESBoolQuery, SortListParam + +from newsroom.types import AgendaItemType, SectionEnum +from newsroom.search.types import BaseSearchRequestArgs, SearchFilterFunction, NewshubSearchRequest, QueryStringQuery +from newsroom.search.utils import query_string, query_string_for_section, get_filter_query +from newsroom.search.config import is_search_field_nested, get_nested_config +from newsroom.utils import get_local_date, get_end_date +from newsroom.search.filters import ( + prefill_user, + prefill_company, + prefill_products, + prefill_args_from_topic, + apply_company_filter, + apply_products_filter, + apply_section_filter, + apply_ids_filter, + apply_advanced_search, + get_apply_highlights, +) + +"""" +company_filter +products filters +product planning filters +agenda items query +section filter +apply_request_filter + +if not is_admin_or_internal: + remove_fields(PRIVATE_FIELDS) + +itemType param filters (events / planning / get_app_config("AGENDA_DEFAULT_FILTER_HIDE_PLANNING") / default) + +custom agenda filters: + q + id/ids + bookmarks + dates + filters + advanced + +""" + +PRIVATE_FIELDS = ["event.files", "*.internal_note"] +PLANNING_ITEMS_FIELDS = ["planning_items", "coverages", "display_dates"] + +aggregations: dict[str, dict[str, Any]] = { + "language": {"terms": {"field": "language"}}, + "calendar": {"terms": 
{"field": "calendars.name", "size": 100}}, + "service": {"terms": {"field": "service.name", "size": 100}}, + "subject": {"terms": {"field": "subject.name", "size": 200}}, + "urgency": {"terms": {"field": "urgency"}}, + "place": {"terms": {"field": "place.name", "size": 50}}, + "coverage": { + "nested": {"path": "coverages"}, + "aggs": {"coverage_type": {"terms": {"field": "coverages.coverage_type", "size": 10}}}, + }, + "planning_items": { + "nested": { + "path": "planning_items", + }, + "aggs": { + "service": {"terms": {"field": "planning_items.service.name", "size": 100}}, + "subject": {"terms": {"field": "planning_items.subject.name", "size": 200}}, + "urgency": {"terms": {"field": "planning_items.urgency"}}, + "place": {"terms": {"field": "planning_items.place.name", "size": 50}}, + }, + }, + "agendas": { + "nested": {"path": "planning_items"}, + "aggs": { + "agenda": {"terms": {"field": "planning_items.agendas.name", "size": 100}}, + }, + }, +} + + +class AgendaSearchRequestArgs(BaseSearchRequestArgs): + #: The sorting that should be applied to this request + sort: SortListParam = [("dates.start", 1)] + + item_type: Annotated[AgendaItemType | None, Field(validation_alias=AliasChoices("item_type", "itemType"))] = None + featured: bool = False + + @field_validator("item_type", mode="before") + def parse_item_type(cls, value: str) -> str: + # Make sure that we use the same value type for item type and search item type + return "event" if value == "events" else value + + +def get_date_filters(args: BaseSearchRequestArgs): + date_range = {} + offset = args.timezone_offset or 0 + if args.start_date: + date_range["gt"] = get_local_date(args.start_date, args.start_time, offset) + if args.end_date: + date_range["lt"] = get_end_date(args.end_date, get_local_date(args.end_date, args.end_time, offset)) + + return date_range + + +def prefill_item_type_arg(request: NewshubSearchRequest[AgendaSearchRequestArgs]) -> None: + if request.user and not request.user.is_admin() and 
request.company and request.company.events_only: + request.args.item_type = AgendaItemType.EVENT + + +def apply_item_state_filter(request: NewshubSearchRequest[AgendaSearchRequestArgs]) -> None: + request.search.query.must_not.append({"term": {"state": "killed"}}) + + if request.user and not request.user.is_admin_or_internal(): + request.search.exclude_fields.extend(PRIVATE_FIELDS) + + +def apply_item_type_filter(request: NewshubSearchRequest[AgendaSearchRequestArgs]) -> None: + item_type = request.args.item_type + if item_type == AgendaItemType.EVENT: + # no adhoc planning items and remove planning items and coverages fields + request.search.query.filter.append( + { + "bool": { + "should": [ + {"term": {"item_type": "event"}}, + { + # Match Events before ``item_type`` field was added + "bool": { + "must_not": [{"exists": {"field": "item_type"}}], + "filter": [{"exists": {"field": "event_id"}}], + }, + }, + ], + "minimum_should_match": 1, + }, + } + ) + request.search.exclude_fields.extend(PLANNING_ITEMS_FIELDS) + elif item_type == AgendaItemType.PLANNING: + request.search.query.filter.append( + { + "bool": { + "should": [ + {"term": {"item_type": "planning"}}, + { + # Match Planning before ``item_type`` field was added + "bool": { + "must_not": [ + {"exists": {"field": "item_type"}}, + {"exists": {"field": "event_id"}}, + ], + }, + }, + ], + "minimum_should_match": 1, + } + } + ) + elif get_app_config("AGENDA_DEFAULT_FILTER_HIDE_PLANNING"): + request.search.query.filter.append( + { + "bool": { + "should": [ + {"term": {"item_type": "event"}}, + { + # Match Events before ``item_type`` field was added + "bool": { + "must_not": [{"exists": {"field": "item_type"}}], + "filter": [{"exists": {"field": "event_id"}}], + }, + }, + ], + "minimum_should_match": 1, + }, + } + ) + request.search.exclude_fields.extend(["planning_items", "display_dates"]) + else: + # Don't include Planning items that are associated with an Event + request.search.query.filter.append( + { + 
"bool": { + "should": [ + {"bool": {"must_not": [{"exists": {"field": "item_type"}}]}}, + {"term": {"item_type": "event"}}, + { + "bool": { + "filter": [{"term": {"item_type": "planning"}}], + "must_not": [{"exists": {"field": "event_id"}}], + }, + }, + ], + "minimum_should_match": 1, + }, + } + ) + + +def planning_items_query_string(query, fields=None): + return query_string(query, fields=fields or ["planning_items.*"]) + + +def nested_query(path, query, inner_hits=True, name=None): + nested = {"path": path, "query": query} + if inner_hits: + nested["inner_hits"] = {} + if name: + nested["inner_hits"]["name"] = name + + return {"nested": nested} + + +def apply_product_planning_filters(request: NewshubSearchRequest[AgendaSearchRequestArgs]) -> None: + planning_items_should = [] + for product in request.products: + if product.planning_item_query and request.args.item_type != AgendaItemType.EVENT: + planning_items_should.append(planning_items_query_string(product.planning_item_query)) + + if len(planning_items_should): + request.search.query.should.append( + nested_query( + "planning_items", + { + "bool": { + "should": planning_items_should, + "minimum_should_match": 1, + }, + }, + name="products", + ) + ) + + +def apply_agenda_query_string(request: NewshubSearchRequest[AgendaSearchRequestArgs]) -> None: + if not request.args.q: + return + + q_dict: dict[str, Any] | None = None + try: + q_dict = json.loads(request.args.q) + except ValueError: + pass + + if q_dict is None: + # Normal query string query + query = query_string_for_section(SectionEnum.AGENDA, request.args.q) + if request.args.item_type == AgendaItemType.EVENT: + # Events Only + request.search.query.filter.append(query) + else: + request.search.query.filter.append( + { + "bool": { + "should": [ + query, + nested_query( + "planning_items", + planning_items_query_string(request.args.q), + name="query", + ), + ], + "minimum_should_match": 1, + }, + } + ) + else: + # Complex Agenda query string + filters: 
list[dict[str, Any] | QueryStringQuery] = [] + if q_dict.get("query"): + filters.append(query_string_for_section(SectionEnum.AGENDA, q_dict["query"])) + + if q_dict.get("planning_item_query"): + filters.append( + nested_query( + "planning_items", + planning_items_query_string(q_dict["planning_item_query"]), + name="product_test", + ), + ) + + if len(filters): + request.search.query.filter.append( + { + "bool": { + "should": filters, + "minimum_should_match": 1, + } + } + ) + + +def apply_agenda_bookmarks_query(request: NewshubSearchRequest[AgendaSearchRequestArgs]) -> None: + if not len(request.args.bookmarks): + return + + user_ids: list[str] = [str(user_id) for user_id in request.args.bookmarks] + request.search.query.filter.append( + { + "bool": { + "should": [ + {"terms": {"bookmarks": user_ids}}, + {"terms": {"watches": user_ids}}, + { + "nested": { + "path": "coverages", + "query": {"bool": {"should": [{"terms": {"coverages.watches": user_ids}}]}}, + } + }, + ], + }, + } + ) + + +def gen_date_range_filter(field: str, operator: str, date_str: str, datetime_instance: datetime): + return [ + { + "bool": { + "must_not": {"term": {"dates.all_day": True}}, + "filter": {"range": {field: {operator: datetime_instance}}}, + }, + }, + { + "bool": { + "filter": [ + {"term": {"dates.all_day": True}}, + {"range": {field: {operator: date_str}}}, + ], + }, + }, + ] + + +def apply_agenda_date_filters(request: NewshubSearchRequest[AgendaSearchRequestArgs]) -> None: + date_range = get_date_filters(request.args) + date_from = date_range.get("gt") + date_to = date_range.get("lt") + should = [] + + if request.args.start_date and date_from and not date_to: + # Filter from a particular date onwards + should = gen_date_range_filter("dates.end", "gte", request.args.start_date, date_from) + elif request.args.end_date and not date_from and date_to: + # Filter up to a particular date + should = gen_date_range_filter("dates.end", "lte", request.args.end_date, date_to) + elif 
request.args.start_date and request.args.end_date and date_from and date_to: + # Filter based on the date range provided + should = [ + { + # Both start/end dates are inside the range + "bool": { + "filter": [ + {"range": {"dates.start": {"gte": date_from}}}, + {"range": {"dates.end": {"lte": date_to}}}, + ], + "must_not": {"term": {"dates.all_day": True}}, + }, + }, + { + # Both start/end dates are inside the range, all day version + "bool": { + "filter": [ + {"range": {"dates.start": {"gte": request.args.start_date}}}, + {"range": {"dates.end": {"lte": request.args.end_date}}}, + {"term": {"dates.all_day": True}}, + ], + }, + }, + { + # Starts before date_from and finishes after date_to + "bool": { + "filter": [ + {"range": {"dates.start": {"lt": date_from}}}, + {"range": {"dates.end": {"gt": date_to}}}, + ], + "must_not": {"term": {"dates.all_day": True}}, + }, + }, + { + # Starts before date_from and finishes after date_to, all day version + "bool": { + "filter": [ + {"range": {"dates.start": {"lt": request.args.start_date}}}, + {"range": {"dates.end": {"gt": request.args.end_date}}}, + {"term": {"dates.all_day": True}}, + ], + }, + }, + { + # Start date is within range OR End date is within range + "bool": { + "should": [ + {"range": {"dates.start": {"gte": date_from, "lte": date_to}}}, + {"range": {"dates.end": {"gte": date_from, "lte": date_to}}}, + ], + "must_not": {"term": {"dates.all_day": True}}, + "minimum_should_match": 1, + }, + }, + { + # Start date is within range OR End date is within range, all day version + "bool": { + "should": [ + {"range": {"dates.start": {"gte": request.args.start_date, "lte": request.args.end_date}}}, + {"range": {"dates.end": {"gte": request.args.start_date, "lte": request.args.end_date}}}, + ], + "filter": {"term": {"dates.all_day": True}}, + "minimum_should_match": 1, + }, + }, + ] + + if date_range: + # Get events for extra dates for coverages and planning. 
+ should.append({"range": {"display_dates.date": date_range}}) + + if len(should): + request.search.query.filter.append({"bool": {"should": should, "minimum_should_match": 1}}) + + +coverage_filters = ["coverage", "coverage_status"] +planning_filters = coverage_filters + ["agendas"] + + +def get_aggregation_field(key: str): + if key == "coverage": + return aggregations[key]["aggs"]["coverage_type"]["terms"]["field"] + elif key == "agendas": + return aggregations[key]["aggs"]["agenda"]["terms"]["field"] + elif is_search_field_nested("agenda", key): + return aggregations[key]["aggs"][f"{key}_filtered"]["aggs"][key]["terms"]["field"] + return aggregations[key]["terms"]["field"] + + +def apply_location_filter(query: ESBoolQuery, val: dict[str, Any]) -> None: + search_type = val.get("type", "location") + + if search_type == "city": + field = "location.address.city.keyword" + elif search_type == "state": + field = "location.address.state.keyword" + elif search_type == "country": + field = "location.address.country.keyword" + else: + field = "location.name.keyword" + + query.filter.append({"term": {field: val.get("name")}}) + + +def apply_coverage_filter(query: ESBoolQuery, val: dict[str, Any]) -> None: + query.filter.append( + nested_query( + path="coverages", + query={"terms": {get_aggregation_field("coverage"): val}}, + name="coverage", + ) + ) + + +def apply_coverage_status_filter(query: ESBoolQuery, val: list[str]) -> None: + if val == ["planned"]: + query.filter.append( + nested_query( + path="coverages", + query={"terms": {"coverages.coverage_status": ["coverage intended"]}}, + name="coverage_status", + ) + ) + query.must_not.extend( + [ + nested_query( + path="coverages", + query={"exists": {"field": "coverages.delivery_id"}}, + ), + nested_query( + path="coverages", + query={"terms": {"coverages.workflow_status": ["completed"]}}, + name="workflow_status", + ), + ] + ) + elif val == ["may be"]: + query.filter.append( + nested_query( + path="coverages", + query={ + 
"terms": { + "coverages.coverage_status": [ + "coverage not decided yet", + "coverage upon request", + ], + }, + }, + name="coverage_status", + ) + ) + elif val == ["not planned"]: + query.must_not.append( + nested_query( + path="coverages", + query={"exists": {"field": "coverages"}}, + name="coverage_status", + ) + ) + elif val == ["completed"]: + should_term_filters = [] + # Check if "delivery_id" is present + should_term_filters.append( + nested_query( + path="coverages", + query={"exists": {"field": "coverages.delivery_id"}}, + ) + ) + + # If "delivery_id" is not present, check "workflow_status" + should_term_filters.append( + nested_query( + path="coverages", + query={"terms": {"coverages.workflow_status": ["completed"]}}, + name="workflow_status", + ) + ) + query.filter.append( + { + "bool": { + "should": should_term_filters, + "minimum_should_match": 1, + }, + } + ) + elif val == ["not intended"]: + query.filter.append( + nested_query( + path="coverages", + query={"terms": {"coverages.coverage_status": ["coverage not intended"]}}, + name="coverage_status", + ) + ) + + +def apply_agendas_filter(query: ESBoolQuery, val) -> None: + query.filter.append( + nested_query( + path="planning_items", + query={"terms": {get_aggregation_field("agendas"): val}}, + name="agendas", + ) + ) + + +def get_apply_agenda_filters(highlights: bool) -> SearchFilterFunction: + def _apply_agenda_filters(request: NewshubSearchRequest[AgendaSearchRequestArgs]) -> None: + if not request.args.filter: + return + + query = request.search.post_filter if get_app_config("FILTER_BY_POST_FILTER", False) else request.search.query + + for key, val in request.args.filter.items(): + is_event_type = request.args.item_type == AgendaItemType.EVENT + if not val or (is_event_type and key in planning_filters): + continue + + match key: + case "location": + apply_location_filter(query, val) + case "coverage": + apply_coverage_filter(query, val) + case "coverage_status": + 
apply_coverage_status_filter(query, val) + case "agendas": + apply_agendas_filter(query, val) + case _: + agg_field = get_aggregation_field(key) + filter_query = get_filter_query(key, val, agg_field, get_nested_config("agenda", key)) + if not is_event_type: + query.filter.append( + { + "bool": { + "minimum_should_match": 1, + "should": [ + filter_query, + nested_query( + path="planning_items", + query={ + "bool": {"filter": [{"terms": {f"planning_items.{agg_field}": val}}]} + }, + name=key, + inner_hits=highlights, + ), + ], + }, + } + ) + else: + query.filter.append(filter_query) + + return _apply_agenda_filters + + +apply_agenda_filters = get_apply_agenda_filters(False) +apply_agenda_highlights_filters = get_apply_agenda_filters(True) + +apply_highlights = get_apply_highlights( + [ + apply_agenda_query_string, + apply_ids_filter, + apply_agenda_highlights_filters, + apply_agenda_date_filters, + apply_advanced_search, + ] +) + +filters_without_dates: list[SearchFilterFunction] = [ + apply_item_state_filter, + apply_item_type_filter, + apply_section_filter, + apply_company_filter, + apply_products_filter, + apply_product_planning_filters, + apply_ids_filter, + apply_agenda_query_string, + apply_agenda_bookmarks_query, + apply_advanced_search, + apply_agenda_filters, +] + +#: Default filters to run for Agenda searches +default_agenda_filters: list[SearchFilterFunction] = [ + # Prefill request variables + prefill_user, + prefill_company, + prefill_products, + prefill_args_from_topic, + prefill_item_type_arg, + apply_agenda_date_filters, +] + filters_without_dates diff --git a/newsroom/agenda/module.py b/newsroom/agenda/module.py index d765057d7..11c31005c 100644 --- a/newsroom/agenda/module.py +++ b/newsroom/agenda/module.py @@ -1,10 +1,27 @@ -from newsroom import MONGO_PREFIX -from newsroom.agenda.service import FeaturedService +from superdesk.core.module import Module, SuperdeskAsyncApp +from superdesk.core.resources import ResourceConfig, MongoResourceConfig, 
ElasticResourceConfig +from superdesk.core.web import EndpointGroup -from newsroom.types import FeaturedResourceModel -from superdesk.core.module import Module -from superdesk.core.resources import ResourceConfig, MongoResourceConfig +from newsroom.types import FeaturedResourceModel, AgendaItem +from newsroom import MONGO_PREFIX, ELASTIC_PREFIX +from newsroom.search.config import init_nested_aggregation +from .agenda_search import AgendaItemService, AgendaSearchServiceAsync +from .filters import PRIVATE_FIELDS, aggregations +from .featured_service import FeaturedService + + +agenda_endpoints = EndpointGroup("agenda", __name__) + + +agenda_items_resource_config = ResourceConfig( + name="agenda", + data_class=AgendaItem, + service=AgendaItemService, + default_sort=[("dates.start", 1)], + mongo=MongoResourceConfig(prefix=MONGO_PREFIX), + elastic=ElasticResourceConfig(prefix=ELASTIC_PREFIX), +) featured_resource_config = ResourceConfig( name="agenda_featured", @@ -13,4 +30,24 @@ mongo=MongoResourceConfig(prefix=MONGO_PREFIX), ) -module = Module(name="newsroom.agenda", resources=[featured_resource_config]) + +def init_module(app: SuperdeskAsyncApp): + configured_page_size = app.wsgi.config.get("AGENDA_PAGE_SIZE") + if configured_page_size is not None: + AgendaSearchServiceAsync.default_page_size = configured_page_size + + if app.wsgi.config.get("AGENDA_HIDE_COVERAGE_ASSIGNEES"): + PRIVATE_FIELDS.extend(["*.assigned_desk_*", "*.assigned_user_*"]) + + init_nested_aggregation("agenda", ["subject"], app.wsgi.config.get("AGENDA_GROUPS", []), aggregations) + + +module = Module( + name="newsroom.agenda", + init=init_module, + endpoints=[agenda_endpoints], + resources=[ + featured_resource_config, + agenda_items_resource_config, + ], +) diff --git a/newsroom/agenda/notifications.py b/newsroom/agenda/notifications.py new file mode 100644 index 000000000..63492facc --- /dev/null +++ b/newsroom/agenda/notifications.py @@ -0,0 +1,274 @@ +from copy import deepcopy +from bson import 
ObjectId + +from newsroom.agenda.email import send_agenda_notification_email +from newsroom.notifications import save_user_notifications +from planning.common import WORKFLOW_STATE + +from newsroom.types import UserResourceModel, CompanyResource +from newsroom.utils import get_user_dict_async, get_company_dict_async + + +def _filter_active_users( + user_ids: list[str], + user_dict: dict[ObjectId, UserResourceModel], + company_dict: dict[ObjectId, CompanyResource], + events_only: bool = False, +) -> list[ObjectId]: + active: list[ObjectId] = [] + for user_id_str in user_ids: + user_id = ObjectId(user_id_str) + user = user_dict.get(user_id) + if not user: + continue + + company = company_dict.get(user.company) if user.company else None + + if user and (not user.company or company): + if events_only and company and company.events_only: + continue + active.append(user_id) + return active + + +def _get_detailed_coverage(agenda, original_agenda, cov): + plan = next( + (p for p in (agenda.get("planning_items") or []) if p["guid"] == cov.get("planning_id")), + None, + ) + if plan and plan.get("state") != WORKFLOW_STATE.KILLED: + detail_cov = next( + (c for c in (plan.get("coverages") or []) if c.get("coverage_id") == cov.get("coverage_id")), + None, + ) + if detail_cov: + detail_cov["watches"] = cov.get("watches") + + return detail_cov + + original_cov = next( + (c for c in original_agenda.get("coverages") or [] if c["coverage_id"] == cov["coverage_id"]), + cov, + ) + cov["watches"] = original_cov.get("watches") or [] + return cov + + +def _fill_all_coverages( + agenda, original_agenda, coverage_updates, skip_coverages=None, cancelled=False, use_original_agenda=False +): + if skip_coverages is None: + skip_coverages = [] + + fill_list = coverage_updates["unaltered_coverages"] if not cancelled else coverage_updates["cancelled_coverages"] + for coverage in (agenda if not use_original_agenda else original_agenda).get("coverages") or []: + if not next( + (s for s in 
skip_coverages if s.get("coverage_id") == coverage.get("coverage_id")), + None, + ): + detailed_coverage = _get_detailed_coverage(agenda, original_agenda, coverage) + if detailed_coverage: + fill_list.append(detailed_coverage) + + +async def notify_agenda_update( + update_agenda, + original_agenda, + item=None, + events_only=False, + related_planning_removed=None, + coverage_updated=None, +): + if not update_agenda or original_agenda.get("state") == WORKFLOW_STATE.KILLED: + return + + agenda = deepcopy(update_agenda) + user_dict = await get_user_dict_async() + company_dict = await get_company_dict_async() + coverage_watched = False + + for c in original_agenda.get("coverages") or []: + if len(c.get("watches") or []) > 0: + coverage_watched = True + break + + notify_user_ids = _filter_active_users(original_agenda.get("watches", []), user_dict, company_dict, events_only) + if len(notify_user_ids) == 0 and not coverage_watched: + return + + users: list[UserResourceModel] = list(filter(None, [user_dict.get(user_id) for user_id in notify_user_ids])) + coverage_updates = { + "modified_coverages": [] if not coverage_updated else [coverage_updated], + "cancelled_coverages": [], + "unaltered_coverages": [], + } + + only_new_coverages = len(coverage_updates["modified_coverages"]) == 0 + time_updated = False + state_changed = False + coverage_modified = False + + # Send notification for only these state changes + notify_states = [ + WORKFLOW_STATE.CANCELLED, + WORKFLOW_STATE.RESCHEDULED, + WORKFLOW_STATE.POSTPONED, + WORKFLOW_STATE.KILLED, + WORKFLOW_STATE.SCHEDULED, + ] + + if not coverage_updated: # If not story updates - but from planning side + if related_planning_removed: + _fill_all_coverages( + agenda, original_agenda, coverage_updates, related_planning_removed.get("coverages") or [] + ) + # Add removed coverages: + for coverage in related_planning_removed.get("coverages") or []: + detailed_coverage = _get_detailed_coverage(agenda, original_agenda, coverage) + if 
detailed_coverage: + coverage_updates["cancelled_coverages"].append(detailed_coverage) + else: + # Send notification if time got updated + if original_agenda.get("dates") and agenda.get("dates"): + time_updated = ( + (original_agenda.get("dates") or {}).get("start").replace(tzinfo=None) + != (agenda.get("dates") or {}).get("start").replace(tzinfo=None) + ) or ( + (original_agenda.get("dates") or {}).get("end").replace(tzinfo=None) + != (agenda.get("dates") or {}).get("end").replace(tzinfo=None) + ) + + if agenda.get("state") and agenda.get("state") != original_agenda.get("state"): + state_changed = agenda.get("state") in notify_states + + if state_changed: + _fill_all_coverages( + agenda, + original_agenda, + coverage_updates, + cancelled=False if agenda.get("state") == WORKFLOW_STATE.SCHEDULED else True, + use_original_agenda=True, + ) + else: + if time_updated: + _fill_all_coverages(agenda, original_agenda, coverage_updates) + else: + for coverage in agenda.get("coverages") or []: + existing_coverage = next( + ( + c + for c in original_agenda.get("coverages") or [] + if c["coverage_id"] == coverage["coverage_id"] + ), + None, + ) + detailed_coverage = _get_detailed_coverage(agenda, original_agenda, coverage) + if detailed_coverage: + if not existing_coverage: + if coverage["workflow_status"] != WORKFLOW_STATE.CANCELLED: + coverage_updates["modified_coverages"].append(detailed_coverage) + elif coverage.get( + "workflow_status" + ) == WORKFLOW_STATE.CANCELLED and existing_coverage.get( + "workflow_status" + ) != coverage.get( + "workflow_status" + ): + coverage_updates["cancelled_coverages"].append(detailed_coverage) + elif ( + ( + coverage.get("delivery_state") != existing_coverage.get("delivery_state") + and coverage.get("delivery_state") == "published" + ) + or ( + coverage.get("workflow_status") != existing_coverage.get("workflow_status") + and coverage.get("workflow_status") == "completed" + ) + or (existing_coverage.get("scheduled") != 
coverage.get("scheduled")) + ): + coverage_updates["modified_coverages"].append(detailed_coverage) + only_new_coverages = False + elif detailed_coverage["coverage_id"] != (coverage_updated or {}).get("coverage_id"): + coverage_updates["unaltered_coverages"].append(detailed_coverage) + + # Check for removed coverages - show it as cancelled + if item and item.get("type") == "planning": + for original_cov in original_agenda.get("coverages") or []: + updated_cov = next( + ( + c + for c in (agenda.get("coverages") or []) + if c.get("coverage_id") == original_cov.get("coverage_id") + ), + None, + ) + if not updated_cov: + coverage_updates["cancelled_coverages"].append(original_cov) + + if len(coverage_updates["modified_coverages"]) > 0 or len(coverage_updates["cancelled_coverages"]) > 0: + coverage_modified = True + + if not (coverage_updated or related_planning_removed or time_updated or state_changed or coverage_modified): + return + + agenda["name"] = agenda.get("name", original_agenda.get("name")) + agenda["definition_short"] = agenda.get("definition_short", original_agenda.get("definition_short")) + agenda["ednote"] = agenda.get("ednote", original_agenda.get("ednote")) + agenda["state_reason"] = agenda.get("state_reason", original_agenda.get("state_reason")) + action = "been updated." + if state_changed: + action = "been {}.".format( + agenda.get("state") if agenda.get("state") != WORKFLOW_STATE.KILLED else "removed from the calendar" + ) + + if ( + len(coverage_updates["modified_coverages"]) > 0 + and only_new_coverages + and len(coverage_updates["cancelled_coverages"]) == 0 + ): + action = "new coverage(s)." 
+ + message = "The {} you have been following has {}".format( + "event" if agenda.get("event") else "coverage plan", action + ) + + if agenda.get("state_reason"): + reason_prefix = agenda.get("state_reason").find(":") + if reason_prefix > 0: + message = "{} {}".format( + message, + agenda["state_reason"][(reason_prefix + 1) : len(agenda["state_reason"])], + ) + + # append coverage watching users too - except for unaltered_coverages + for c in coverage_updates["cancelled_coverages"] + coverage_updates["modified_coverages"]: + if c.get("watches"): + notify_user_ids = _filter_active_users(c["watches"], user_dict, company_dict, events_only) + users = users + [user_dict[user_id] for user_id in notify_user_ids] + + # Send notifications to users + await save_user_notifications( + [ + dict( + user=user.id, + item=agenda.get("_id"), + resource="agenda", + action="watched_agenda_updated", + data=None, + ) + for user in users + ] + ) + + for user in users: + await send_agenda_notification_email( + user.to_dict(), + agenda, + message, + original_agenda, + coverage_updates, + related_planning_removed, + coverage_updated, + time_updated, + coverage_modified, + ) diff --git a/newsroom/agenda/service.py b/newsroom/agenda/service.py deleted file mode 100644 index 7bb8dc28c..000000000 --- a/newsroom/agenda/service.py +++ /dev/null @@ -1,136 +0,0 @@ -from datetime import datetime -from eve.utils import ParsedRequest -from typing import Any - -from newsroom.auth.utils import get_user_from_request, get_company_from_request -from newsroom.agenda.agenda import ( - is_events_only_access, - build_agenda_query, - nested_query, - planning_items_query_string, - aggregations, - remove_fields, - PLANNING_ITEMS_FIELDS, -) -from newsroom.types import FeaturedResourceModel -from newsroom.utils import get_local_date -from newsroom.template_filters import is_admin - -from superdesk import get_resource_service -from superdesk.flask import abort -from superdesk.core.resources import 
AsyncResourceService -from superdesk.utc import local_to_utc - - -class FeaturedService(AsyncResourceService[FeaturedResourceModel]): - resource_name = "agenda_featured" - - async def on_create(self, docs: list[FeaturedResourceModel]) -> None: - """ - Add UTC from/to datetimes on save. - Problem is 31.8. in Sydney is from 30.8. 14:00 UTC to 31.8. 13:59 UTC. - And because we query later using UTC, we store those UTC datetimes as - display_from and display_to. - """ - for item in docs: - date = datetime.strptime(item.id, "%Y%m%d") - item.display_from = local_to_utc(item.tz, date.replace(hour=0, minute=0, second=0)) - item.display_to = local_to_utc(item.tz, date.replace(hour=23, minute=59, second=59)) - await super().on_create(docs) - - async def find_one_for_date(self, for_date: datetime) -> FeaturedResourceModel | None: - return await self.find_one(display_from={"$lte": for_date}, display_to={"$gte": for_date}) - - async def get_featured_stories( - self, - date_from: str, - timezone_offset: int = 0, - query_string: str | None = None, - filter_string: str | None = None, - from_offset: int = 0, - ) -> Any: - for_date = datetime.strptime(date_from, "%d/%m/%Y %H:%M") - local_date = get_local_date( - for_date.strftime("%Y-%m-%d"), - for_date.strftime("%H:%M:%S"), - timezone_offset, - ) - featured_doc = await self.find_one_for_date(local_date) - return await self.featured(featured_doc, query_string, filter_string, from_offset) - - async def featured( - self, - featured_doc: dict | None = None, - query_string: str | None = None, - filter_string: str | None = None, - from_offset: int = 0, - ) -> Any: - """Return featured items. 
- - :param Optional[dict] featured_doc: The featured document for the given date - :param Optional[str] query_string: Optional search query to filter the results - :param Optional[str] filter_string: Optional filter query to filter the results - :param int from_offset: Pagination offset for the results - :return: A list of filtered featured items - """ - from newsroom.section_filters.service import SectionFiltersService - - user = get_user_from_request(None) - company = get_company_from_request(None) - if is_events_only_access(user.to_dict(), company.to_dict()): # type: ignore - abort(403) - - if not featured_doc or not featured_doc.get("items"): - return ({"_items": [], "_meta": {"total": 0}},) - - query = build_agenda_query() - await SectionFiltersService().apply_section_filter(query, self.section) - - planning_items_query = nested_query( - "planning_items", - {"bool": {"filter": [{"terms": {"planning_items.guid": featured_doc["items"]}}]}}, - name="featured", - ) - - if query_string: - query["bool"]["filter"].append(self.query_string(query_string)) - planning_items_query["nested"]["query"]["bool"]["filter"].append(planning_items_query_string(query_string)) - - query["bool"]["filter"].append(planning_items_query) - - source = {"query": query, "size": len(featured_doc["items"]), "from": from_offset} - req = ParsedRequest() - req.args = {"filter": filter_string} - get_resource_service("agenda").set_post_filter(source, req) - if not from_offset: - source["aggs"] = aggregations - - if company and not is_admin(user) and company.events_only: - query["bool"]["filter"].append({"exists": {"field": "event"}}) - remove_fields(source, PLANNING_ITEMS_FIELDS) - - cursor = await self.search(source) - - docs_by_id = {} - for doc in cursor.docs: - for p in doc.get("planning_items") or []: - docs_by_id[p.get("guid")] = doc - - # Update display dates based on the featured document - doc.update( - { - "_display_from": featured_doc["display_from"], - "_display_to": 
featured_doc["display_to"], - } - ) - - docs = [] - agenda_ids = set() - for _id in featured_doc["items"]: - if docs_by_id.get(_id) and docs_by_id.get(_id).get("_id") not in agenda_ids: # type: ignore - docs.append(docs_by_id.get(_id)) - agenda_ids.add(docs_by_id.get(_id).get("_id")) # type: ignore - - # TODO-ASYNC: Temporary implementation to ensure the response includes _meta details - # for total count as the front end relies on data._meta.total - return ({"_items": docs, "_meta": {"total": len(docs)}},) diff --git a/newsroom/agenda/views.py b/newsroom/agenda/views.py index 5644e9141..e17a2feab 100644 --- a/newsroom/agenda/views.py +++ b/newsroom/agenda/views.py @@ -1,334 +1,403 @@ -import json -from typing import Dict +from typing import Annotated, Any, cast +from asyncio import gather from bson import ObjectId +from pydantic import Field, field_validator from quart_babel import gettext -from eve.methods.get import get_internal -from eve.render import send_response -from eve.utils import ParsedRequest -from newsroom.search.types import NewshubSearchRequest -from newsroom.wire.filters import WireSearchRequestArgs from superdesk.core import get_app_config, get_current_app -from superdesk.core.types import ESQuery -from superdesk.flask import request, render_template, abort, jsonify -from superdesk import get_resource_service +from superdesk.core.types import ESQuery, BaseModel, Request, Response, RestGetResponse +from superdesk.core.resources.cursor import ElasticsearchResourceCursorAsync +from superdesk.flask import render_template +from newsroom.types import AgendaItem, AgendaItemType, SectionEnum +from newsroom.auth import auth_rules from newsroom.auth.utils import ( get_user_from_request, - get_user_id_from_request, get_company_from_request, check_user_has_products, ) -from newsroom.agenda import blueprint +from newsroom.ui_config_async import UiConfigResourceService +from newsroom.users import get_user_profile_data from newsroom.products import 
get_products_by_company -from newsroom.topics import get_user_topics +from newsroom.topics import get_user_topics_async from newsroom.topics_folders import get_company_folders, get_user_folders from newsroom.navigations import get_navigations -from newsroom.decorator import login_required, section +from newsroom.notifications import push_user_notification +from newsroom.history_async import HistoryService +from newsroom.search.types import NewshubSearchRequest +from newsroom.search.config import merge_planning_aggs + +from newsroom.wire import WireSearchServiceAsync +from newsroom.wire.utils import update_action_list +from newsroom.wire.views import set_item_permission +from newsroom.wire.filters import WireSearchRequestArgs + from newsroom.utils import ( - get_entity_or_404, is_json_request, get_json_or_400, get_agenda_dates, get_location_string, get_public_contacts, get_links, - get_vocabulary, + get_vocabulary_async, get_groups, ) -from newsroom.wire import WireSearchServiceAsync -from newsroom.wire.utils import update_action_list -from newsroom.wire.views import set_item_permission -from newsroom.agenda.email import send_coverage_request_email -from newsroom.agenda.service import FeaturedService -from newsroom.agenda.utils import remove_fields_for_public_user, remove_restricted_coverage_info -from newsroom.notifications import push_user_notification -from newsroom.search.config import merge_planning_aggs -from newsroom.ui_config_async import UiConfigResourceService -from newsroom.users import get_user_profile_data -from newsroom.history_async import HistoryService +from .email import send_coverage_request_email +from .featured_service import FeaturedService +from .utils import remove_fields_for_public_user, remove_restricted_coverage_info +from .module import agenda_endpoints +from .agenda_service import AgendaItemService +from .agenda_search import AgendaSearchServiceAsync +from .filters import AgendaSearchRequestArgs -@blueprint.route("/agenda") 
-@login_required -@section("agenda") -async def index(): + +@agenda_endpoints.endpoint("/agenda", auth=[auth_rules.section_required("agenda")]) +async def index() -> str: user_profile_data = await get_user_profile_data() data = await get_view_data() return await render_template("agenda_index.html", data=data, user_profile_data=user_profile_data) -@blueprint.route("/bookmarks_agenda") -@login_required -async def bookmarks(): +@agenda_endpoints.endpoint("/bookmarks_agenda") +async def bookmarks() -> str: data = await get_view_data() user_profile_data = await get_user_profile_data() data["bookmarks"] = True return await render_template("agenda_bookmarks.html", data=data, user_profile_data=user_profile_data) -@blueprint.route("/agenda/<_id>") -@login_required -async def item(_id): - item = get_entity_or_404(_id, "agenda") - user_profile_data = await get_user_profile_data() +class AgendaItemViewArgs(BaseModel): + item_id: Annotated[str, Field(alias="_id")] + + +class AgendaItemParams(BaseModel): + print: bool = False + map: str | None = None + type: str = "agenda" + + @field_validator("print", mode="before") + def parse_print(cls, value: str | bool | None) -> bool | str | None: + # Support this URL param as a toggle, if `print` is provided in the URL then it is `True` + return True if value == "" else value + + +@agenda_endpoints.endpoint("/agenda/<_id>") +async def item(args: AgendaItemViewArgs, params: AgendaItemParams, request: Request) -> Response | str: + agenda_item = await AgendaItemService().find_by_id(args.item_id) + if not agenda_item: + await request.abort(404) + + agenda_item_dict = agenda_item.to_dict() user = get_user_from_request(None) company = get_company_from_request(None) if not user.is_admin_or_internal(): - remove_fields_for_public_user(item) + remove_fields_for_public_user(agenda_item_dict) if company and not user.is_admin() and company.events_only: # if the company has permission events only permission then # remove planning items and coverages. 
- if not item.get("event"): + if not agenda_item_dict.get("event"): # for adhoc planning items abort the request - abort(403) + return await request.abort(403) - item.pop("planning_items", None) - item.pop("coverages", None) + agenda_item_dict.pop("planning_items", None) + agenda_item_dict.pop("coverages", None) if company and company.restrict_coverage_info: - remove_restricted_coverage_info([item]) + remove_restricted_coverage_info([agenda_item_dict]) if is_json_request(request): - return jsonify(item) + return Response(agenda_item_dict) - if "print" in request.args: - map = request.args.get("map") + user_profile_data = await get_user_profile_data() + if params.print: template = "agenda_item_print.html" - update_action_list([_id], "prints", force_insert=True) - await HistoryService().create_history_record( - [item], "print", user.id, user.company, request.args.get("type", "agenda") - ) + await update_action_list([args.item_id], "prints", force_insert=True) + await HistoryService().create_history_record([agenda_item_dict], "print", user.id, user.company, params.type) return await render_template( template, - item=item, - map=map, - dateString=get_agenda_dates(item), - location=get_location_string(item), - contacts=get_public_contacts(item), - links=get_links(item), + item=agenda_item_dict, + map=params.map, + dateString=get_agenda_dates(agenda_item_dict), + location=get_location_string(agenda_item_dict), + contacts=get_public_contacts(agenda_item_dict), + links=get_links(agenda_item_dict), is_admin=user.is_admin_or_internal(), user_profile_data=user_profile_data, ) data = await get_view_data() - data["item"] = item + data["item"] = agenda_item_dict return await render_template( "agenda_index.html", data=data, - title=item.get("name", item.get("headline")), + title=agenda_item_dict.get("name", agenda_item_dict.get("headline")), user_profile_data=user_profile_data, ) -@blueprint.route("/agenda/search") -@login_required -@section("agenda") -async def search(): - if 
request.args.get("featured"): - date_from = request.args.get("date_from") - timezone_offset = int(request.args.get("timezone_offset", 0)) - query_string = request.args.get("q") - filter_string = request.args.get("filter") - from_offset = int(request.args.get("from", 0)) +@agenda_endpoints.endpoint("/agenda/search", auth=[auth_rules.section_required("agenda")]) +async def search(args: None, params: AgendaSearchRequestArgs, request: Request) -> Response: + user = get_user_from_request(request) + company = get_company_from_request(None) + if params.featured: + if user.is_events_only_access(company): + return await request.abort(403) + elif params.start_date is None: + return await request.abort(400, gettext("No date specified.")) + response = await FeaturedService().get_featured_stories( - date_from, timezone_offset, query_string, filter_string, from_offset + params.start_date, + params.timezone_offset or 0, + params.q, + params.filter, + params.page or 0, ) - return await send_response("agenda", response) + return Response(response) - response = await get_internal("agenda") - if len(response): - company = get_company_from_request(None) - if company and company.restrict_coverage_info: - remove_restricted_coverage_info(response[0].get("_items") or []) - if response[0].get("_aggregations"): - merge_planning_aggs(response[0]["_aggregations"]) - return await send_response("agenda", response) + response = await AgendaSearchServiceAsync().process_web_request(request) + body: RestGetResponse = response.body + if len(body.get("_items") or []) and company and company.restrict_coverage_info: + remove_restricted_coverage_info(body["_items"]) -async def get_view_data() -> Dict: + if body.get("_aggregations"): + merge_planning_aggs(body["_aggregations"]) + + return response + + +async def get_view_data() -> dict: user = get_user_from_request(None) user_dict = None if not user else user.to_dict() company = get_company_from_request(None) company_dict = None if not company else 
company.to_dict() - topics = await get_user_topics(user.id) if user else [] - products = await get_products_by_company(company_dict, product_type="agenda") if company else [] + # Helper function to provide an async function, otherwise ``gather`` fails with + # TypeError('An asyncio.Future, a coroutine or an awaitable is required') + async def empty_array(): + return [] + + ( + topics, + products, + navigations, + ui_config, + featured_count, + user_folders, + company_folders, + saved_items, + locators, + ) = await gather( + get_user_topics_async(user) if user else empty_array(), + get_products_by_company(company_dict, product_type=SectionEnum.AGENDA) if company else empty_array(), + get_navigations(user_dict, company_dict, "agenda"), + UiConfigResourceService().get_section_config("agenda"), + FeaturedService().count(), + get_user_folders(user, "agenda") if user else empty_array(), + get_company_folders(company, "agenda") if company else empty_array(), + AgendaSearchServiceAsync().get_saved_items_count(user, company), + get_vocabulary_async("locators"), + ) check_user_has_products(user, products) - ui_config_service = UiConfigResourceService() return { "user": user_dict or {}, "company": company.id if company else None, - "topics": [t for t in topics if t.get("topic_type") == "agenda"], + "topics": [t.to_dict() for t in topics if t.topic_type == "agenda"], "formats": [ {"format": f["format"], "name": f["name"]} for f in get_current_app().as_any().download_formatters.values() if "agenda" in f["types"] ], - "navigations": await get_navigations(user_dict, company_dict, "agenda"), - "saved_items": get_resource_service("agenda").get_saved_items_count(), + "navigations": navigations, + "saved_items": saved_items, "events_only": company.events_only if company else False, "restrict_coverage_info": company.restrict_coverage_info if company else False, - "locators": get_vocabulary("locators"), - "ui_config": await ui_config_service.get_section_config("agenda"), + "locators": 
locators, + "ui_config": ui_config, "groups": get_groups(get_app_config("AGENDA_GROUPS", []), company_dict), - "has_agenda_featured_items": await FeaturedService().count() > 0, - "user_folders": await get_user_folders(user, "agenda") if user else [], - "company_folders": await get_company_folders(company, "agenda") if company else [], + "has_agenda_featured_items": featured_count > 0, + "user_folders": user_folders, + "company_folders": company_folders, "date_filters": get_app_config("AGENDA_TIME_FILTERS", []), } -@blueprint.route("/agenda/request_coverage", methods=["POST"]) -@login_required -async def request_coverage(): +@agenda_endpoints.endpoint("/agenda/request_coverage", methods=["POST"]) +async def request_coverage(request: Request) -> Response: user = get_user_from_request(None) data = await get_json_or_400() assert data.get("item") assert data.get("message") - item = get_entity_or_404(data.get("item"), "agenda") - await send_coverage_request_email(user, data.get("message"), item) - return jsonify(), 201 + agenda_item = await AgendaItemService().find_by_id(data.get("item")) + if not agenda_item: + return await request.abort(404) + await send_coverage_request_email(user, data.get("message"), agenda_item) + return Response("", 201) -@blueprint.route("/agenda_bookmark", methods=["POST", "DELETE"]) -@login_required -async def bookmark(): +@agenda_endpoints.endpoint("/agenda_bookmark", methods=["POST", "DELETE"]) +async def bookmark(request: Request) -> Response: data = await get_json_or_400() assert data.get("items") - update_action_list(data.get("items"), "bookmarks", item_type="agenda") - push_user_notification("saved_items", count=get_resource_service("agenda").get_saved_items_count()) - return jsonify(), 200 + await update_action_list(data.get("items"), "bookmarks", item_type="agenda") + item_count = await AgendaSearchServiceAsync().get_saved_items_count( + get_user_from_request(request), + get_company_from_request(request), + ) + 
push_user_notification("saved_items", count=item_count) + return Response("") + +class WatchAgendaParams(BaseModel): + bookmarks: bool = False -@blueprint.route("/agenda_watch", methods=["POST", "DELETE"]) -@login_required -async def follow(): + +@agenda_endpoints.endpoint("/agenda_watch", methods=["POST", "DELETE"]) +async def follow(args: None, params: WatchAgendaParams, request: Request) -> Response: data = await get_json_or_400() assert data.get("items") - user_id = get_user_id_from_request(None) + user = get_user_from_request(request) + company = get_company_from_request(request) + + agenda_service = AgendaItemService() + cursor = await agenda_service.search({"_id": {"$in": data.get("items")}}, use_mongo=True) + agenda_items: dict[str, AgendaItem] = {agenda_item.id: agenda_item async for agenda_item in cursor} + for item_id in data.get("items"): - item = get_entity_or_404(item_id, "agenda") - coverage_updates = {"coverages": item.get("coverages") or []} + agenda_item = agenda_items.get(item_id) + if not agenda_item: + return await request.abort(404) + coverage_updates = {"coverages": agenda_item.coverages or []} + for c in coverage_updates["coverages"]: - if c.get("watches") and user_id in c["watches"]: - c["watches"].remove(user_id) + if c.watches and user.id in c.watches: + c.watches.remove(user.id) if request.method == "POST": - updates = {"watches": list(set((item.get("watches") or []) + [user_id]))} - if item.get("coverages"): + updates = {"watches": list(set((agenda_item.watches or []) + [user.id]))} + + if agenda_item.coverages: updates.update(coverage_updates) - get_resource_service("agenda").patch(item_id, updates) + await agenda_service.update(agenda_item.id, updates) else: - if request.args.get("bookmarks"): - user_item_watches = [u for u in (item.get("watches") or []) if str(u) == str(user_id)] + if params.bookmarks: + user_item_watches = [user_id for user_id in (agenda_item.watches or []) if user_id == user.id] if not user_item_watches: # delete 
user watches of all coverages - get_resource_service("agenda").patch(item_id, coverage_updates) - return jsonify(), 200 + await agenda_service.update(agenda_item.id, coverage_updates) + return Response("") - update_action_list(data.get("items"), "watches", item_type="agenda") + await update_action_list(data.get("items"), "watches", item_type="agenda") - push_user_notification("saved_items", count=get_resource_service("agenda").get_saved_items_count()) - return jsonify(), 200 + item_count = await AgendaSearchServiceAsync().get_saved_items_count(user, company) + push_user_notification("saved_items", count=item_count) + return Response("") -@blueprint.route("/agenda_coverage_watch", methods=["POST", "DELETE"]) -@login_required -async def watch_coverage(): - user_id = get_user_id_from_request(None) +@agenda_endpoints.endpoint("/agenda_coverage_watch", methods=["POST", "DELETE"]) +async def watch_coverage(request: Request) -> Response: + user = get_user_from_request(request) + company = get_company_from_request(request) data = await get_json_or_400() - assert data.get("item_id") + item_id = data.get("item_id") + assert item_id assert data.get("coverage_id") - response = update_coverage_watch(data["item_id"], data["coverage_id"], user_id, add=request.method == "POST") - push_user_notification("saved_items", count=get_resource_service("agenda").get_saved_items_count()) - return response + agenda_item = await AgendaItemService().find_by_id(item_id) + if not agenda_item: + return Response({"error": gettext(f"Agenda item '{item_id}' not found")}, 404) + body, return_code = await _update_coverage_watch( + agenda_item, data["coverage_id"], user.id, add=request.method == "POST" + ) + if return_code == 404: + return Response(body, 404) -def update_coverage_watch(item_id: str, coverage_id: str, user_id: ObjectId, add: bool, skip_associated: bool = False): - item = get_entity_or_404(item_id, "agenda") + item_count = await AgendaSearchServiceAsync().get_saved_items_count(user, 
company) + push_user_notification("saved_items", count=item_count) + return Response(body or "", return_code) - if user_id in item.get("watches", []): - return ( - jsonify({"error": gettext("Cannot edit coverage watch when watching parent item")}), - 403, - ) + +async def _update_coverage_watch( + agenda_item: AgendaItem, coverage_id: str, user_id: ObjectId, add: bool, skip_associated: bool = False +) -> tuple[None, int] | tuple[dict[str, str], int]: + agenda_service = AgendaItemService() + + if user_id in (agenda_item.watches or []): + return {"error": gettext("Cannot edit coverage watch when watching parent item")}, 403 try: - coverage_index = [c["coverage_id"] for c in (item.get("coverages") or [])].index(coverage_id) + coverage_index = [c.coverage_id for c in (agenda_item.coverages or [])].index(coverage_id) except ValueError: - return jsonify({"error": gettext("Coverage not found")}), 404 - - updates = {"coverages": item["coverages"]} + return {"error": gettext(f"Coverage '{coverage_id}' not found on agenda item '{agenda_item.id}'")}, 404 + updates = {"coverages": agenda_item.coverages} if add: - updates["coverages"][coverage_index]["watches"] = list( - set((updates["coverages"][coverage_index].get("watches") or []) + [user_id]) + updates["coverages"][coverage_index].watches = list( + set((updates["coverages"][coverage_index].watches or []) + [user_id]) ) else: try: - updates["coverages"][coverage_index]["watches"].remove(user_id) + updates["coverages"][coverage_index].watches.remove(user_id) except Exception: - return jsonify({"error": gettext("Error removing watch.")}), 404 + return {"error": gettext("Error removing watch.")}, 404 - get_resource_service("agenda").patch(item_id, updates) + await agenda_service.update(agenda_item.id, updates) if skip_associated: - return jsonify(), 200 - elif item.get("item_type") == "planning" and item.get("event_id"): + return None, 200 + elif agenda_item.item_type == AgendaItemType.PLANNING and agenda_item.event_id: # Need 
to also update the parent Event's list of coverage watches - return update_coverage_watch(item["event_id"], coverage_id, user_id, add, skip_associated=True) - elif item.get("item_type") == "event": + event_item = await agenda_service.find_by_id(agenda_item.event_id) + if event_item: + return await _update_coverage_watch(event_item, coverage_id, user_id, add, skip_associated=True) + + # return await _update_coverage_watch(agenda_item.event_id, coverage_id, user_id, add, skip_associated=True) + elif agenda_item.item_type == AgendaItemType.EVENT: # Need to also update the Planning item's list of coverage watches - return update_coverage_watch( - item["coverages"][coverage_index]["planning_id"], - coverage_id, - user_id, - add, - skip_associated=True, - ) + planning_item = await agenda_service.find_by_id(agenda_item.coverages[coverage_index].planning_id) + if planning_item: + return await _update_coverage_watch( + planning_item, + coverage_id, + user_id, + add, + skip_associated=True, + ) - return jsonify(), 200 + return None, 200 -@blueprint.route("/agenda/wire_items/") -@login_required -async def related_wire_items(wire_id): - elastic = get_current_app().data._search_backend("agenda") - source = {} - must_terms = [{"term": {"coverages.delivery_id": {"value": wire_id}}}] - query = { - "bool": {"filter": must_terms}, - } +class RelatedWireUrlArgs(BaseModel): + wire_id: str - source.update({"query": {"nested": {"path": "coverages", "query": query}}}) - internal_req = ParsedRequest() - internal_req.args = {"source": json.dumps(source)} - agenda_result, _ = elastic.find("agenda", internal_req, None) - if len(agenda_result.docs) == 0: - return ( - jsonify({"error": gettext("%(section)s item not found", section=get_app_config("AGENDA_SECTION"))}), - 404, - ) +@agenda_endpoints.endpoint("/agenda/wire_items/") +async def related_wire_items(args: RelatedWireUrlArgs, params: None, request: Request) -> Response: + agenda_service = AgendaItemService() + query = {"bool": 
{"filter": [{"term": {"coverages.delivery_id": args.wire_id}}]}} + cursor = await agenda_service.search({"query": {"nested": {"path": "coverages", "query": query}}}) + agenda_item = await cursor.next_raw() + + if agenda_item is None: + return Response({"error": gettext("%(section)s item not found", section=get_app_config("AGENDA_SECTION"))}, 404) company = get_company_from_request(None) if company and company.restrict_coverage_info: - remove_restricted_coverage_info([agenda_result.docs[0]]) + remove_restricted_coverage_info([agenda_item]) wire_ids = [] - for cov in agenda_result.docs[0].get("coverages") or []: + for cov in agenda_item.get("coverages") or []: if cov.get("coverage_type") == "text" and cov.get("delivery_id"): wire_ids.append(cov["delivery_id"]) wire_search = WireSearchServiceAsync() - cursor = await wire_search.service.search({"bool": {"query": {"must": [{"terms": {"_id": wire_ids}}]}}}) - wire_items = await cursor.to_list() + cursor = await wire_search.service.search({"query": {"bool": {"must": [{"terms": {"_id": wire_ids}}]}}}) permissioned_result = await wire_search.search( NewshubSearchRequest( @@ -346,24 +415,27 @@ async def related_wire_items(wire_id): for b in buckets: permissioned_ids.append(b["key"]) - for wire_item in wire_items: + wire_items = [] + async for wire_item in cursor: set_item_permission(wire_item, wire_item.id in permissioned_ids) + wire_items.append(wire_item.to_dict()) - return ( - jsonify( - { - "agenda_item": agenda_result.docs[0], - "wire_items": [item.to_dict() for item in wire_items], - } - ), + return Response( + { + "agenda_item": agenda_item, + "wire_items": wire_items, + }, 200, ) -@blueprint.route("/agenda/search_locations") -@login_required -async def search_locations(): - query = request.args.get("q") or "" +class SearchLocationsParams(BaseModel): + q: str = "" + + +@agenda_endpoints.endpoint("/agenda/search_locations") +async def search_locations(args: None, params: SearchLocationsParams, request: Request) -> 
Response: + query = params.q apply_filters = len(query) > 0 if apply_filters and not query.startswith("*") and not query.endswith("*"): @@ -389,7 +461,7 @@ def gen_agg_terms(field: str): "size": 1000, } - es_query = { + es_query: dict[str, Any] = { "size": 0, "aggs": { "city_search_country": { @@ -456,10 +528,7 @@ def gen_agg_terms(field: str): "aggs": {"places": es_query["aggs"].pop("places")}, } - req = ParsedRequest() - req.args = {"source": json.dumps(es_query)} - service = get_resource_service("agenda") - cursor = service.internal_get(req, {}) + cursor = cast(ElasticsearchResourceCursorAsync, await AgendaItemService().search(es_query)) aggs = cursor.hits.get("aggregations") or {} regions = [] @@ -491,10 +560,9 @@ def gen_agg_terms(field: str): } ) - return ( + return Response( { "regions": regions, "places": [bucket["key"] for bucket in (aggs.get("places") or aggs["place_search"]["places"])["buckets"]], - }, - 200, + } ) diff --git a/newsroom/am_news/views.py b/newsroom/am_news/views.py index 5d55186d7..893e5854f 100644 --- a/newsroom/am_news/views.py +++ b/newsroom/am_news/views.py @@ -84,7 +84,7 @@ async def bookmark(): """ data = await get_json_or_400() assert data.get("items") - update_action_list(data.get("items"), "bookmarks", item_type="items") + await update_action_list(data.get("items"), "bookmarks", item_type="items") push_user_notification( "saved_items", count=await WireSearchServiceAsync().get_current_user_bookmarks_count(SectionEnum.AM_NEWS), @@ -97,7 +97,7 @@ async def bookmark(): async def copy(_id): item_type = get_type() get_entity_or_404(_id, item_type) - update_action_list([_id], "copies", item_type=item_type) + await update_action_list([_id], "copies", item_type=item_type) return jsonify(), 200 @@ -125,7 +125,7 @@ async def item(_id): previous_versions = get_previous_versions(item) if "print" in request.args: template = "wire_item_print.html" - update_action_list([_id], "prints", force_insert=True) + await update_action_list([_id], 
"prints", force_insert=True) else: template = "wire_item.html" return await render_template( diff --git a/newsroom/commands/initialize_data.py b/newsroom/commands/initialize_data.py index aee27f721..1f0c5fa80 100644 --- a/newsroom/commands/initialize_data.py +++ b/newsroom/commands/initialize_data.py @@ -48,7 +48,7 @@ async def run(self, entity_name=None, force=False, init_index_only=False): if app.config.get("REBUILD_ELASTIC_ON_INIT_DATA_ERROR"): logger.warning("Can't update the mapping, running elastic_rebuild command now.") - elastic_rebuild() + await elastic_rebuild() else: logger.warning("Can't update the mapping, please run elastic_rebuild command.") diff --git a/newsroom/commands/remove_expired_agenda.py b/newsroom/commands/remove_expired_agenda.py index 99da58ce6..7937c7dca 100644 --- a/newsroom/commands/remove_expired_agenda.py +++ b/newsroom/commands/remove_expired_agenda.py @@ -1,18 +1,17 @@ -import click +from typing import AsyncGenerator import logging - -from typing import List, Set, Dict, Any, Generator from datetime import datetime, timedelta -from eve.utils import ParsedRequest, date_to_str -from superdesk.core import json, get_app_config +import click + +from superdesk.core import get_app_config +from superdesk.core.utils import date_to_str from superdesk.resource_fields import ID_FIELD -from superdesk import get_resource_service from superdesk.lock import lock, unlock from superdesk.utc import utcnow -from newsroom.utils import parse_date_str -from newsroom.agenda.utils import get_item_type, AGENDA_ITEM_TYPE +from newsroom.types import AgendaItem, AgendaItemType +from newsroom.agenda import AgendaItemService from .cli import newsroom_cli logger = logging.getLogger(__name__) @@ -20,7 +19,7 @@ @newsroom_cli.command("remove_expired_agenda") @click.option("-m", "--expiry", "expiry_days", required=False, help="Number of days to determine expiry") -def remove_expired_agenda(expiry_days=None): +async def remove_expired_agenda_command(expiry_days=None): 
"""Remove expired Agenda items By default, no Agenda items expire, you can change this with the ``AGENDA_EXPIRY_DAYS`` config. @@ -32,8 +31,11 @@ def remove_expired_agenda(expiry_days=None): $ python manage.py remove_expired_agenda -m 60 $ python manage.py remove_expired_agenda --expiry 60 """ + await remove_expired_agenda(expiry_days=expiry_days) - num_of_days = int(expiry_days) if expiry_days is not None else get_app_config("AGENDA_EXPIRY_DAYS", 0) + +async def remove_expired_agenda(expiry_days=None): + num_of_days = int(expiry_days) if expiry_days is not None else int(get_app_config("AGENDA_EXPIRY_DAYS", 0)) if num_of_days == 0: logger.info("Expiry days is set to 0, therefor no items will be removed") @@ -45,7 +47,7 @@ def remove_expired_agenda(expiry_days=None): return try: - num_items_removed = _remove_expired_items(utcnow(), num_of_days) + num_items_removed = await _remove_expired_items(utcnow(), num_of_days) finally: unlock(lock_name) @@ -55,90 +57,81 @@ def remove_expired_agenda(expiry_days=None): logger.info(f"Completed removing {num_items_removed} expired agenda items") -def _remove_expired_items(now: datetime, expiry_days: int): +async def _remove_expired_items(now: datetime, expiry_days: int): """Remove expired Event and/or Planning items from the Agenda collection""" logger.info("Starting to remove expired items") - # TODO-ASYNC: revisit once `agenda` module has been migrate to async - agenda_service = get_resource_service("agenda") + agenda_service = AgendaItemService() expiry_datetime = now - timedelta(days=expiry_days) num_items_removed = 0 - for expired_items in _get_expired_items(expiry_datetime): - items_to_remove: Set[str] = set() + async for expired_items in _get_expired_items(expiry_datetime): + items_to_remove: set[str] = set() for item in expired_items: - item_id = item[ID_FIELD] - logger.info(f"Processing expired item {item_id}") - for child_id in _get_expired_chain_ids(item, expiry_datetime): - items_to_remove.add(child_id) + 
logger.info(f"Processing expired item {item.id}") + items_to_remove |= await _get_expired_chain_ids(item, expiry_datetime) if len(items_to_remove): logger.info(f"Deleting items: {items_to_remove}") num_items_removed += len(items_to_remove) - agenda_service.delete_action(lookup={ID_FIELD: {"$in": list(items_to_remove)}}) + await agenda_service.delete_many({ID_FIELD: {"$in": list(items_to_remove)}}) logger.info("Finished removing expired items from agenda collection") return num_items_removed -def _get_expired_items( - expiry_datetime: datetime, -) -> Generator[List[Dict[str, Any]], datetime, None]: +async def _get_expired_items(expiry_datetime: datetime) -> AsyncGenerator[list[AgendaItem], None]: """Get the expired items, based on ``expiry_datetime``""" - agenda_service = get_resource_service("agenda") + agenda_service = AgendaItemService() + expiry_datetime_str = date_to_str(expiry_datetime) max_loops = get_app_config("MAX_EXPIRY_LOOPS", 50) - for i in range(max_loops): # avoid blocking forever just in case - req = ParsedRequest() - expiry_datetime_str = date_to_str(expiry_datetime) - - # Filters out Planning items with coverages that have not yet expired - coverage_scheduled_query = { - "nested": { - "path": "coverages", - "query": {"range": {"coverages.scheduled": {"gt": expiry_datetime_str}}}, - }, - } - req.args = { - "source": json.dumps( - { - "query": { + + # Filters out Planning items with coverages that have not yet expired + coverage_scheduled_query = { + "nested": { + "path": "coverages", + "query": {"range": {"coverages.scheduled": {"gt": expiry_datetime_str}}}, + }, + } + query = { + "query": { + "bool": { + "filter": [{"range": {"dates.end": {"lte": expiry_datetime_str}}}], + "should": [ + # Match Events directly (stored from v2.3+) + # No more filters required, as we'll query & check planning items separately + {"term": {"item_type": "event"}}, + # Match Planning directly with no associated Event (stored from v2.3+) + { + "bool": { + "filter": 
[{"term": {"item_type": "planning"}}], + "must_not": [ + {"exists": {"field": "event_id"}}, + coverage_scheduled_query, + ], + } + }, + # Match Event and/or Planning items (stored before v2.3 changes to storage) + { "bool": { - "filter": [{"range": {"dates.end": {"lte": expiry_datetime_str}}}], - "should": [ - # Match Events directly (stored from v2.3+) - # No more filters required, as we'll query & check planning items separately - {"term": {"item_type": "event"}}, - # Match Planning directly with no associated Event (stored from v2.3+) - { - "bool": { - "filter": [{"term": {"item_type": "planning"}}], - "must_not": [ - {"exists": {"field": "event_id"}}, - coverage_scheduled_query, - ], - } - }, - # Match Event and/or Planning items (stored before v2.3 changes to storage) - { - "bool": { - "must_not": [ - {"exists": {"field": "item_type"}}, - coverage_scheduled_query, - ], - } - }, + "must_not": [ + {"exists": {"field": "item_type"}}, + coverage_scheduled_query, ], - "minimum_should_match": 1, - }, + } }, - "sort": [{"dates.start": "asc"}], - "size": get_app_config("MAX_EXPIRY_QUERY_LIMIT", 100), - } - ), - } + ], + "minimum_should_match": 1, + }, + }, + "sort": [{"dates.start": "asc"}], + "size": get_app_config("MAX_EXPIRY_QUERY_LIMIT", 100), + } - items = list(agenda_service.internal_get(req=req, lookup=None)) + for i in range(max_loops): # avoid blocking forever just in case + cursor = await agenda_service.search(query) + items = await cursor.to_list() if not len(items): break @@ -148,36 +141,32 @@ def _get_expired_items( logger.warning(f"_get_expired_items did not finish in {max_loops} loops") -def has_plan_expired(item: Dict[str, Any], expiry_datetime: datetime) -> bool: +def has_plan_expired(item: AgendaItem, expiry_datetime: datetime) -> bool: """Returns ``True`` if the maximum planning/coverage time is before or equal to ``expiry_datetime``""" - max_schedule_datetime = max( - [parse_date_str(coverage["scheduled"]) for coverage in (item.get("coverages") or 
[])] - + [parse_date_str(item["dates"]["end"])] - ) + max_schedule_datetime = max([coverage.scheduled for coverage in (item.coverages or [])] + [item.dates.end]) return max_schedule_datetime <= expiry_datetime -def _get_expired_chain_ids(parent: Dict[str, Any], expiry_datetime: datetime) -> List[str]: +async def _get_expired_chain_ids(parent: AgendaItem, expiry_datetime: datetime) -> set[str]: """Returns the list of IDs to expire from ``parent`` and it's associated planning items If any one item in the chain has not expired, then this function returns an empty array, otherwise the list of IDs from the parent and any associated items are returned for purging. """ - item_type = get_item_type(parent) - plan_ids = [plan.get(ID_FIELD) for plan in (parent.get("planning_items") or [])] + plan_ids = [plan._id for plan in (parent.planning_items or [])] - if item_type == AGENDA_ITEM_TYPE.PLANNING: - return [] if not has_plan_expired(parent, expiry_datetime) else [parent[ID_FIELD]] + if parent.item_type == AgendaItemType.PLANNING: + return set() if not has_plan_expired(parent, expiry_datetime) else {parent.id} elif not len(plan_ids): - return [parent[ID_FIELD]] + return {parent.id} - agenda_service = get_resource_service("agenda") - items: List[str] = [parent[ID_FIELD]] - for plan in agenda_service.find(where={ID_FIELD: {"$in": plan_ids}}): + cursor = await AgendaItemService().search({ID_FIELD: {"$in": plan_ids}}, use_mongo=True) + items: set[str] = {parent.id} + async for plan in cursor: if not has_plan_expired(plan, expiry_datetime): - return [] - items.append(plan[ID_FIELD]) + return set() + items.add(plan.id) return items diff --git a/newsroom/companies/companies.py b/newsroom/companies/companies.py index 78d8fbe0b..6d54fd2a1 100644 --- a/newsroom/companies/companies.py +++ b/newsroom/companies/companies.py @@ -7,7 +7,7 @@ from superdesk import get_resource_service import newsroom -from newsroom.types import PRODUCT_TYPES +from newsroom.types import SectionEnum from 
newsroom.signals import company_create from .utils import get_company_section_names, get_company_product_ids @@ -64,7 +64,7 @@ class CompaniesResource(newsroom.Resource): "schema": { "_id": newsroom.Resource.rel("products"), "seats": {"type": "number", "default": 0}, - "section": {"type": "string", "required": True, "allowed": PRODUCT_TYPES}, + "section": {"type": "string", "required": True, "allowed": [t.value for t in SectionEnum]}, }, }, }, diff --git a/newsroom/companies/utils.py b/newsroom/companies/utils.py index 69383d2f8..5b294cf78 100644 --- a/newsroom/companies/utils.py +++ b/newsroom/companies/utils.py @@ -49,6 +49,9 @@ def get_updated_products(updates, original, company: Optional[CompanyResource]) elif "products" in original: products = original["products"] or [] + # Make sure the products are of the correct type + products = [CompanyProduct(**product) if isinstance(product, dict) else product for product in products] + if not company: return products diff --git a/newsroom/factcheck/views.py b/newsroom/factcheck/views.py index 6c07257fd..2eb42f113 100644 --- a/newsroom/factcheck/views.py +++ b/newsroom/factcheck/views.py @@ -79,7 +79,7 @@ async def bookmark(): """ data = await get_json_or_400() assert data.get("items") - update_action_list(data.get("items"), "bookmarks", item_type="items") + await update_action_list(data.get("items"), "bookmarks", item_type="items") push_user_notification( "saved_items", count=await WireSearchServiceAsync().get_current_user_bookmarks_count(SectionEnum.FACTCHECK) ) @@ -91,7 +91,7 @@ async def bookmark(): async def copy(_id): item_type = get_type() get_entity_or_404(_id, item_type) - update_action_list([_id], "copies", item_type=item_type) + await update_action_list([_id], "copies", item_type=item_type) return jsonify(), 200 @@ -119,7 +119,7 @@ async def item(_id): previous_versions = get_previous_versions(item) if "print" in request.args: template = "wire_item_print.html" - update_action_list([_id], "prints", 
force_insert=True) + await update_action_list([_id], "prints", force_insert=True) else: template = "wire_item.html" return await render_template( diff --git a/newsroom/gettext.py b/newsroom/gettext.py index c54d1e0c7..bb62d6e55 100644 --- a/newsroom/gettext.py +++ b/newsroom/gettext.py @@ -61,7 +61,7 @@ def get_user_timezone(user: User) -> str: return get_app_config("BABEL_DEFAULT_TIMEZONE") or get_app_config("DEFAULT_TIMEZONE") -def get_session_timezone(): +def get_session_timezone() -> str: from newsroom.auth.utils import get_user_or_none_from_request try: diff --git a/newsroom/market_place/views.py b/newsroom/market_place/views.py index 5a105c6c3..df40667cb 100644 --- a/newsroom/market_place/views.py +++ b/newsroom/market_place/views.py @@ -134,7 +134,7 @@ async def bookmark(): """ data = await get_json_or_400() assert data.get("items") - update_action_list(data.get("items"), "bookmarks", item_type="items") + await update_action_list(data.get("items"), "bookmarks", item_type="items") push_user_notification( "saved_items", count=await WireSearchServiceAsync().get_current_user_bookmarks_count(SectionEnum.MARKET_PLACE), @@ -147,7 +147,7 @@ async def bookmark(): async def copy(_id): item_type = get_type() get_entity_or_404(_id, item_type) - update_action_list([_id], "copies", item_type=item_type) + await update_action_list([_id], "copies", item_type=item_type) return jsonify(), 200 @@ -175,7 +175,7 @@ async def item(_id): previous_versions = get_previous_versions(item) if "print" in request.args: template = "wire_item_print.html" - update_action_list([_id], "prints", force_insert=True) + await update_action_list([_id], "prints", force_insert=True) else: template = "wire_item.html" return await render_template( diff --git a/newsroom/media_releases/views.py b/newsroom/media_releases/views.py index 8b23ba22b..4fb40f97b 100644 --- a/newsroom/media_releases/views.py +++ b/newsroom/media_releases/views.py @@ -79,7 +79,7 @@ async def bookmark(): """ data = await 
get_json_or_400() assert data.get("items") - update_action_list(data.get("items"), "bookmarks", item_type="items") + await update_action_list(data.get("items"), "bookmarks", item_type="items") push_user_notification( "saved_items", count=await WireSearchServiceAsync().get_current_user_bookmarks_count(SectionEnum.MEDIA_RELEASES) ) @@ -91,7 +91,7 @@ async def bookmark(): async def copy(_id): item_type = get_type() get_entity_or_404(_id, item_type) - update_action_list([_id], "copies", item_type=item_type) + await update_action_list([_id], "copies", item_type=item_type) return jsonify(), 200 @@ -119,7 +119,7 @@ async def item(_id): previous_versions = get_previous_versions(item) if "print" in request.args: template = "wire_item_print.html" - update_action_list([_id], "prints", force_insert=True) + await update_action_list([_id], "prints", force_insert=True) else: template = "wire_item.html" return await render_template( diff --git a/newsroom/monitoring/views.py b/newsroom/monitoring/views.py index d6089cb96..71b799f14 100644 --- a/newsroom/monitoring/views.py +++ b/newsroom/monitoring/views.py @@ -156,8 +156,7 @@ async def create(): @login_required async def edit(_id): if request.args.get("context", "") == "wire": - cursor = await WireSearchServiceAsync().get_items_for_action([_id]) - items = await cursor.to_list_raw() + items = await WireSearchServiceAsync().get_items_for_action([_id]) if not len(items): return @@ -243,7 +242,7 @@ async def export(_ids): return jsonify({"message": "Error exporting items to file"}), 400 if _file: - update_action_list(_ids.split(","), "export", force_insert=True) + await update_action_list(_ids.split(","), "export", force_insert=True) await HistoryService().create_history_record(items, "export", user.id, user.company, "monitoring") return send_file( @@ -296,7 +295,7 @@ async def share(): ], ) - update_action_list(data.get("items"), "shares") + await update_action_list(data.get("items"), "shares") await 
HistoryService().create_history_record(items, "share", current_user.id, current_user.company, "monitoring") return jsonify({"success": True}), 200 @@ -311,7 +310,7 @@ async def bookmark(): """ data = await get_json_or_400() assert data.get("items") - update_action_list(data.get("items"), "bookmarks", item_type="items") + await update_action_list(data.get("items"), "bookmarks", item_type="items") push_user_notification( "saved_items", count=await WireSearchServiceAsync().get_current_user_bookmarks_count(SectionEnum.MONITORING), diff --git a/newsroom/notifications/commands.py b/newsroom/notifications/commands.py index 3ad038656..b94129695 100644 --- a/newsroom/notifications/commands.py +++ b/newsroom/notifications/commands.py @@ -1,30 +1,34 @@ import logging from bson import ObjectId from datetime import datetime, timedelta -from typing import List, Dict, Any, Optional, TypedDict, Tuple, Set, cast +from typing import Any, TypedDict from newsroom.users.service import UsersService from superdesk.core import get_app_config -from superdesk import get_resource_service +from superdesk.core.types import SearchRequest from superdesk.utc import utcnow, utc_to_local from superdesk.celery_task_utils import get_lock_id from superdesk.lock import lock, unlock from newsroom.types import ( - User, - NotificationSchedule, - Company, + UserResourceModel, + NotificationScheduleModel, + CompanyResource, NotificationQueue, + TopicResourceModel, Topic, NotificationType, NotificationTopic, ) -from newsroom.utils import get_user_dict, get_company_dict +from newsroom.utils import get_user_dict_async, get_company_dict_async from newsroom.email import send_user_email from newsroom.celery_app import celery from newsroom.topics.topics_async import get_user_id_to_topic_for_subscribers from newsroom.gettext import get_session_timezone, set_session_timezone +from newsroom.wire import WireSearchServiceAsync, WireSearchRequestArgs +from newsroom.agenda import AgendaSearchServiceAsync, 
AgendaSearchRequestArgs + from .services import NotificationQueueService logger = logging.getLogger(__name__) @@ -32,12 +36,12 @@ class NotificationEmailTopicEntry(TypedDict): topic: Topic - item: Dict[str, Any] + item: dict[str, Any] -TopicEntriesDict = Dict[str, List[NotificationEmailTopicEntry]] +TopicEntriesDict = dict[str, list[NotificationEmailTopicEntry]] -TopicMatchTable = Dict[str, List[Tuple[str, int]]] +TopicMatchTable = dict[str, list[tuple[str, int]]] class SendScheduledNotificationEmails: @@ -74,8 +78,8 @@ async def run_schedules(self, force: bool): try: now_utc = utcnow().replace(second=0, microsecond=0) - companies = await get_company_dict(False) - users = await get_user_dict(False) + companies = await get_company_dict_async(False) + users = await get_user_dict_async(False) user_topic_map = await get_user_id_to_topic_for_subscribers(NotificationType.SCHEDULED) schedules_cursor = await notification_queue_service.search({}) @@ -86,9 +90,8 @@ async def run_schedules(self, force: bool): return for schedule in schedules: - user_id = schedule.user try: - user = users.get(str(user_id)) + user = users.get(schedule.user) if not user: # User not found, this account might be disabled @@ -96,46 +99,58 @@ async def run_schedules(self, force: bool): await notification_queue_service.reset_queue(schedule.user) continue - if not user.get("notification_schedule"): - user["notification_schedule"] = {} + if not user.notification_schedule: + # MyPy fails here with call-arg error, not sure why. 
Ignore it for now + user.notification_schedule = NotificationScheduleModel( # type: ignore[call-arg] + timezone=get_session_timezone(), + times=[], + ) - user["notification_schedule"].setdefault("timezone", get_session_timezone()) - user["notification_schedule"].setdefault( - "times", get_app_config("DEFAULT_SCHEDULED_NOTIFICATION_TIMES") - ) + if not user.notification_schedule.timezone: + user.notification_schedule.timezone = get_session_timezone() + if not user.notification_schedule.times: + user.notification_schedule.times = get_app_config("DEFAULT_SCHEDULED_NOTIFICATION_TIMES") - company = companies.get(str(user.get("company", ""))) + company = companies.get(user.company) await self.process_schedule( schedule, user, company, now_utc, user_topic_map.get(user["_id"]) or {}, force ) except Exception as e: logger.exception(e) - logger.error("Failed to run schedule for user %s", user_id) + logger.error("Failed to run schedule for user %s", schedule.user) async def process_schedule( self, schedule: NotificationQueue, - user: User, - company: Optional[Company], + user: UserResourceModel, + company: CompanyResource | None, now_utc: datetime, - user_topics: Dict[ObjectId, Topic], + user_topics: dict[ObjectId, TopicResourceModel], force: bool, ): """ Processes a user's notification schedule. Sends an email based on whether topics matched or not for the user's scheduled notification period. 
""" + timezone: str = ( + user.notification_schedule.timezone + if user.notification_schedule and user.notification_schedule.timezone + else get_session_timezone() + ) + now_local = utc_to_local(timezone, now_utc) - now_local = utc_to_local(user["notification_schedule"]["timezone"], now_utc) - - if not self.is_scheduled_to_run_for_user(user["notification_schedule"], now_local, force): + if not user.notification_schedule or not self.is_scheduled_to_run_for_user( + user.notification_schedule, now_local, force + ): return # Set the timezone on the session, so Babel is able to get the timezone for this user # when rendering the email, otherwise it uses the system default - set_session_timezone(user["notification_schedule"]["timezone"]) + set_session_timezone(timezone) - topic_entries, topic_match_table = self.get_topic_entries_and_match_table(schedule, user, company, user_topics) + topic_entries, topic_match_table = await self.get_topic_entries_and_match_table( + schedule, user, company, user_topics + ) template_kwargs = dict( app_name=get_app_config("SITE_NAME"), @@ -162,22 +177,24 @@ async def process_schedule( # Now clear the topic match queue await self._clear_user_notification_queue(user) - def is_scheduled_to_run_for_user(self, schedule: NotificationSchedule, now_local: datetime, force: bool): + def is_scheduled_to_run_for_user( + self, schedule: NotificationScheduleModel, now_local: datetime, force: bool + ) -> bool: """ Determines if the notification schedule should run for the user based on their scheduled times and the current time. 
""" last_run_time_local = ( - utc_to_local(schedule["timezone"], schedule.get("last_run_time")).replace(second=0, microsecond=0) - if schedule.get("last_run_time") is not None + utc_to_local(schedule.timezone, schedule.last_run_time).replace(second=0, microsecond=0) + if schedule.last_run_time is not None else None ) if last_run_time_local is None and force: return True - for schedule_datetime in self._convert_schedule_times(now_local, schedule["times"]): + for schedule_datetime in self._convert_schedule_times(now_local, schedule.times): schedule_within_time = timedelta() <= now_local - schedule_datetime < timedelta(minutes=5) if last_run_time_local is None and schedule_within_time: @@ -187,12 +204,12 @@ def is_scheduled_to_run_for_user(self, schedule: NotificationSchedule, now_local return False - async def _clear_user_notification_queue(self, user: User): - await NotificationQueueService().reset_queue(user["_id"]) - await UsersService().update_notification_schedule_run_time(cast(dict, user), utcnow()) + async def _clear_user_notification_queue(self, user: UserResourceModel): + await NotificationQueueService().reset_queue(user.id) + await UsersService().update_notification_schedule_run_time(user, utcnow()) - def _convert_schedule_times(self, now_local: datetime, times: List[str]) -> List[datetime]: - schedule_datetimes: List[datetime] = [] + def _convert_schedule_times(self, now_local: datetime, times: list[str]) -> list[datetime]: + schedule_datetimes: list[datetime] = [] for time_str in times: time_parts = time_str.split(":") @@ -200,7 +217,7 @@ def _convert_schedule_times(self, now_local: datetime, times: List[str]) -> List return schedule_datetimes - def get_queue_entries_for_section(self, queue: NotificationQueue, section: str) -> List[NotificationTopic]: + def get_queue_entries_for_section(self, queue: NotificationQueue, section: str) -> list[NotificationTopic]: """ Return the entries in the queue for a given section sorted by `last_item_arrived` attribute 
""" @@ -211,14 +228,14 @@ def get_queue_entries_for_section(self, queue: NotificationQueue, section: str) reverse=True, ) - def get_latest_item_from_topic_queue( + async def get_latest_item_from_topic_queue( self, topic_queue: NotificationTopic, - topic: Topic, - user: User, - company: Optional[Company], - exclude_items: Set[str], - ) -> Optional[Dict[str, Any]]: + topic: TopicResourceModel, + user: UserResourceModel, + company: CompanyResource | None, + exclude_items: set[str], + ) -> dict[str, Any] | None: """ Retrieves the latest item from a topic queue for the user and company, excluding specific items from the result. """ @@ -227,28 +244,34 @@ def get_latest_item_from_topic_queue( if item_id in exclude_items: continue - # TODO-ASYNC: update when `wire_search` and `agenda_search` are migrated to async - search_service = get_resource_service("wire_search" if topic["topic_type"] == "wire" else "agenda") - - query = search_service.get_topic_query(topic, user, company, args={"es_highlight": 1, "ids": [item_id]}) + search_service: AgendaSearchServiceAsync | WireSearchServiceAsync + if topic.topic_type == "agenda": + search_service = AgendaSearchServiceAsync() + query = await search_service.get_topic_items_query( + topic, user, company, args=AgendaSearchRequestArgs(es_highlight=True, ids=[item_id], page_size=1) + ) + else: + search_service = WireSearchServiceAsync() + query = await search_service.get_topic_items_query( + topic, user, company, args=WireSearchRequestArgs(es_highlight=True, ids=[item_id], page_size=1) + ) if not query: # user might not have access to section anymore return None - items = search_service.get_items_by_query(query, size=1) - - if items.count(): - return items[0] + cursor = await search_service.service.find(SearchRequest(elastic=query)) + if await cursor.count(): + return await cursor.next_raw() return None - def get_topic_entries_and_match_table( + async def get_topic_entries_and_match_table( self, schedule: NotificationQueue, - user: User, 
- company: Optional[Company], - user_topics: Dict[ObjectId, Topic], - ) -> Tuple[TopicEntriesDict, TopicMatchTable]: + user: UserResourceModel, + company: CompanyResource | None, + user_topics: dict[ObjectId, TopicResourceModel], + ) -> tuple[TopicEntriesDict, TopicMatchTable]: """ Generates the topic entries and a match table for a user's scheduled notifications. @@ -265,7 +288,7 @@ def get_topic_entries_and_match_table( "agenda": [], } - topics_matched: List[ObjectId] = [] + topics_matched: list[ObjectId] = [] topic_match_table: TopicMatchTable = { "wire": [], "agenda": [], @@ -275,7 +298,7 @@ def get_topic_entries_and_match_table( return topic_entries, topic_match_table for section in ["wire", "agenda"]: - items_in_entries: Set[str] = set() + items_in_entries: set[str] = set() for topic_queue in self.get_queue_entries_for_section(schedule, section): if not len(topic_queue.items): @@ -287,10 +310,12 @@ def get_topic_entries_and_match_table( # Topic was not found for some reason continue - topic_match_table[section].append((topic["label"], len(topic_queue.items))) - topics_matched.append(topic["_id"]) + topic_match_table[section].append((topic.label, len(topic_queue.items))) + topics_matched.append(topic.id) - latest_item = self.get_latest_item_from_topic_queue(topic_queue, topic, user, company, items_in_entries) + latest_item = await self.get_latest_item_from_topic_queue( + topic_queue, topic, user, company, items_in_entries + ) if latest_item is None: # Latest item was not found. 
It may have matched multiple topics @@ -299,14 +324,14 @@ def get_topic_entries_and_match_table( items_in_entries.add(latest_item["_id"]) topic_entries[section].append( NotificationEmailTopicEntry( - topic=topic, + topic=topic.to_dict(), item=latest_item, ) ) - for _, topic in user_topics.items(): - if topic["_id"] not in topics_matched: - topic_match_table[topic["topic_type"]].append((topic["label"], 0)) + for topic in user_topics.values(): + if topic.id not in topics_matched: + topic_match_table[topic.topic_type].append((topic.label, 0)) return topic_entries, topic_match_table diff --git a/newsroom/notifications/utils.py b/newsroom/notifications/utils.py index 65e7c1fda..33158769c 100644 --- a/newsroom/notifications/utils.py +++ b/newsroom/notifications/utils.py @@ -4,21 +4,18 @@ from bson import ObjectId from typing import Any -import superdesk from superdesk.utc import utcnow -from superdesk.flask import session from superdesk.core import get_app_config from superdesk.notification import push_notification +from newsroom.exceptions import AuthorizationError +from newsroom.auth.utils import get_user_id_from_request from newsroom.wire import WireSearchServiceAsync from newsroom.topics.topics_async import TopicService from .services import NotificationsService -def user_notifications_lookup(user_id: str | ObjectId) -> dict[str, Any]: - if isinstance(user_id, str): - user_id = ObjectId(user_id) - +def user_notifications_lookup(user_id: ObjectId) -> dict[str, Any]: ttl = get_app_config("NOTIFICATIONS_TTL", 1) return { "user": user_id, @@ -26,7 +23,7 @@ def user_notifications_lookup(user_id: str | ObjectId) -> dict[str, Any]: } -async def get_user_notifications(user_id: str) -> list[dict[str, Any]]: +async def get_user_notifications(user_id: ObjectId) -> list[dict[str, Any]]: """ Returns the notification entries for the given user """ @@ -40,10 +37,12 @@ async def get_initial_notifications() -> dict[str, Any] | None: Returns the stories that user has notifications for 
:return: List of stories. None if there is not user session. """ - if not session.get("user"): + + try: + user_id = get_user_id_from_request(None) + except AuthorizationError: return None - user_id = session["user"] lookup = user_notifications_lookup(user_id) notifications = await NotificationsService().search(lookup) @@ -58,30 +57,30 @@ async def get_notifications_with_items() -> dict[str, Any] | None: Returns the stories that user has notifications for :return: List of stories. None if there is not user session. """ - if not session.get("user"): + + from newsroom.agenda import AgendaSearchServiceAsync + + try: + user_id = get_user_id_from_request(None) + except AuthorizationError: return None - saved_notifications = await get_user_notifications(session["user"]) + saved_notifications = await get_user_notifications(user_id) item_ids = [n["item"] for n in saved_notifications] - items = [] - wire_cursor, topics_cursor = await gather( + wire_cursor, agenda_cursor, topic_items = await gather( WireSearchServiceAsync().get_items_by_id(item_ids), - TopicService().search({"_id": {"$in": item_ids}}), + AgendaSearchServiceAsync().get_items_by_id(item_ids), + TopicService().find_by_ids_raw(item_ids), + ) + wire_items, agenda_items = await gather( + wire_cursor.to_list_raw(), + agenda_cursor.to_list_raw(), ) - wire_items, topic_items = await gather(wire_cursor.to_list_raw(), topics_cursor.to_list_raw()) - - items.extend(wire_items) - items.extend(topic_items) - - try: - items.extend(superdesk.get_resource_service("agenda").get_items(item_ids)) - except (KeyError, TypeError): # agenda disabled - pass return { - "user": session["user"], - "items": list(items), + "user": str(user_id), + "items": wire_items + agenda_items + topic_items, "notifications": saved_notifications, } diff --git a/newsroom/products/__init__.py b/newsroom/products/__init__.py index fcbe91356..0ebbe8e4a 100644 --- a/newsroom/products/__init__.py +++ b/newsroom/products/__init__.py @@ -12,9 +12,20 @@ from . 
import products from .service import ProductsService from .views import get_settings_data, products_endpoints -from .utils import get_products_by_company +from .utils import ( + get_products_by_company, + get_products_by_company_async, + get_products_by_user_async, + get_products_by_navigation_async, +) -__all__ = ["get_products_by_company", "ProducsService"] +__all__ = [ + "get_products_by_company", + "get_products_by_company_async", + "get_products_by_user_async", + "get_products_by_navigation_async", + "ProductsService", +] def init_module(app: SuperdeskAsyncApp): diff --git a/newsroom/products/products.py b/newsroom/products/products.py index 76d4e6df1..710ef90a7 100644 --- a/newsroom/products/products.py +++ b/newsroom/products/products.py @@ -7,7 +7,7 @@ import superdesk from superdesk.services import CacheableService -from newsroom.types import Company, Product, User, NavigationIds, PRODUCT_TYPES +from newsroom.types import Company, Product, User, NavigationIds, SectionEnum from newsroom.utils import any_objectid_in_list, parse_objectid IdsList = NavigationIds @@ -35,7 +35,7 @@ class ProductsResource(newsroom.Resource): "schema": newsroom.Resource.rel("companies"), "nullable": True, }, - "product_type": {"type": "string", "default": "wire", "allowed": PRODUCT_TYPES}, + "product_type": {"type": "string", "default": "wire", "allowed": [t.value for t in SectionEnum]}, "original_creator": newsroom.Resource.rel("users"), "version_creator": newsroom.Resource.rel("users"), } diff --git a/newsroom/products/service.py b/newsroom/products/service.py index 86f4a2426..f2f9d5408 100644 --- a/newsroom/products/service.py +++ b/newsroom/products/service.py @@ -32,15 +32,15 @@ async def create(self, docs: Sequence[ProductResourceModel | dict[str, Any]]) -> company_service = CompanyService() for company_id, products in company_products.items(): - company = await company_service.find_by_id(company_id) + company = await company_service.find_by_id_raw(company_id) if company: 
updates = { - "products": company.products or [], + "products": company.get("products") or [], } for product in products: updates["products"].append({"_id": product["_id"], "section": product["product_type"], "seats": 0}) - await company_service.system_update(company.id, updates) + await company_service.system_update(company["_id"], updates) return res diff --git a/newsroom/products/utils.py b/newsroom/products/utils.py index e1ea07fd9..a465dd38c 100644 --- a/newsroom/products/utils.py +++ b/newsroom/products/utils.py @@ -1,8 +1,15 @@ from typing import Any -from newsroom.types import Company, Product, NavigationIds -from newsroom.utils import parse_objectid +from newsroom.types import ( + Company, + Product, + NavigationIds, + ProductResourceModel, + CompanyResource, + SectionEnum, + UserResourceModel, +) from .service import ProductsService - +from ..utils import any_objectid_in_list IdsList = NavigationIds @@ -10,11 +17,33 @@ async def get_products_by_company( company: Company | None, navigation_ids: NavigationIds | None = None, - product_type: str | None = None, + product_type: SectionEnum | None = None, unlimited_only: bool = False, ) -> list[Product]: """Get the list of products for a company + :param company: Company + :param navigation_ids: List of Navigation Ids + :param product_type: Type of the product + :param unlimited_only: Include unlimited only products + """ + + return [ + product.to_dict() + for product in await get_products_by_company_async( + CompanyResource.from_dict(company), navigation_ids, product_type, unlimited_only + ) + ] + + +async def get_products_by_company_async( + company: CompanyResource | None, + navigation_ids: NavigationIds | None = None, + product_type: SectionEnum | None = None, + unlimited_only: bool = False, +) -> list[ProductResourceModel]: + """Get the list of products for a company + :param company: Company :param navigation_ids: List of Navigation Ids :param product_type: Type of the product @@ -24,20 +53,44 @@ async def 
get_products_by_company( return [] company_product_ids = [ - parse_objectid(product["_id"]) - for product in company.get("products") or [] - if (product_type is None or product["section"] == product_type) - and (not unlimited_only or not product.get("seats")) + product._id + for product in company.products or [] + if (product_type is None or product.section == product_type) and (not unlimited_only or not product.seats) ] if company_product_ids: lookup = get_products_lookup(company_product_ids, navigation_ids) cursor = await ProductsService().search(lookup) - return await cursor.to_list_raw() + return await cursor.to_list() return [] +async def get_products_by_user_async( + user: UserResourceModel, section: SectionEnum, navigation_ids: NavigationIds | None +) -> list[ProductResourceModel]: + ids = [p._id for p in user.products or [] if p.section == section] + if ids: + lookup = get_products_lookup(ids, navigation_ids) + cursor = await ProductsService().search(lookup) + return await cursor.to_list() + + return [] + + +async def get_products_by_navigation_async( + navigation_ids: NavigationIds, product_type: SectionEnum | None = None +) -> list[ProductResourceModel]: + return [ + product + async for product in ProductsService().get_all() + if ( + any_objectid_in_list(navigation_ids, product.navigations or []) + and (product_type is None or product.product_type == product_type) + ) + ] + + def get_products_lookup(product_ids: IdsList, navigation_ids: IdsList | None) -> dict[str, Any]: lookup = {"_id": {"$in": product_ids}} diff --git a/newsroom/push/agenda_manager.py b/newsroom/push/agenda_manager.py index 1b212e987..95b076a33 100644 --- a/newsroom/push/agenda_manager.py +++ b/newsroom/push/agenda_manager.py @@ -1,18 +1,17 @@ import logging -from datetime import datetime from superdesk.utc import utcnow -from superdesk import get_resource_service from planning.common import WORKFLOW_STATE from newsroom.wire import url_for_wire, WireSearchServiceAsync from 
newsroom.utils import parse_date_str from newsroom.core import get_current_wsgi_app + from newsroom.agenda.utils import get_latest_available_delivery, TO_BE_CONFIRMED_FIELD +from newsroom.agenda.notifications import notify_agenda_update from .tasks import notify_new_wire_item -from .utils import format_qcode_items, get_display_dates, parse_dates, set_dates - +from .utils import get_display_dates logger = logging.getLogger(__name__) @@ -21,141 +20,6 @@ class AgendaManager: - def init_adhoc_agenda(self, planning, agenda): - """ - Inits an adhoc agenda item - """ - # check if there's an existing ad-hoc - agenda["item_type"] = "planning" - - # planning dates is saved as the dates of the new agenda - agenda["dates"] = { - "start": planning["planning_date"], - "end": planning["planning_date"], - } - - agenda["state"] = planning["state"] - if planning.get("pubstatus") == "cancelled": - agenda["watches"] = [] - - return agenda - - def set_metadata_from_event(self, agenda, event, set_doc_id=True): - """ - Sets agenda metadata from a given event - """ - parse_dates(event) - - # setting _id of agenda to be equal to event - if set_doc_id: - agenda.setdefault("_id", event["guid"]) - - agenda["item_type"] = "event" - agenda["guid"] = event["guid"] - agenda["event_id"] = event["guid"] - agenda["recurrence_id"] = event.get("recurrence_id") - agenda["name"] = event.get("name") - agenda["slugline"] = event.get("slugline") - agenda["definition_short"] = event.get("definition_short") - agenda["definition_long"] = event.get("definition_long") - agenda["version"] = event.get("version") - agenda["versioncreated"] = event.get("versioncreated") - agenda["calendars"] = event.get("calendars") - agenda["location"] = event.get("location") - agenda["ednote"] = event.get("ednote") - agenda["state"] = event.get("state") - agenda["state_reason"] = event.get("state_reason") - agenda["place"] = event.get("place") - agenda["subject"] = format_qcode_items(event.get("subject")) - agenda["products"] = 
event.get("products") - agenda["service"] = format_qcode_items(event.get("anpa_category")) - agenda["event"] = event - agenda["registration_details"] = event.get("registration_details") - agenda["invitation_details"] = event.get("invitation_details") - agenda["language"] = event.get("language") - agenda["source"] = event.get("source") - - set_dates(agenda) - - def set_metadata_from_planning(self, agenda, planning_item, force_adhoc=False): - """Sets agenda metadata from a given planning""" - - parse_dates(planning_item) - set_dates(agenda) - - if not planning_item.get("event_item") or force_adhoc: - # adhoc planning item - agenda["name"] = planning_item.get("name") - agenda["headline"] = planning_item.get("headline") - agenda["slugline"] = planning_item.get("slugline") - agenda["ednote"] = planning_item.get("ednote") - agenda["place"] = planning_item.get("place") - agenda["subject"] = format_qcode_items(planning_item.get("subject")) - agenda["products"] = planning_item.get("products") - agenda["urgency"] = planning_item.get("urgency") - agenda["definition_short"] = planning_item.get("description_text") or agenda.get("definition_short") - agenda["definition_long"] = planning_item.get("abstract") or agenda.get("definition_long") - agenda["service"] = format_qcode_items(planning_item.get("anpa_category")) - agenda["state"] = planning_item.get("state") - agenda["state_reason"] = planning_item.get("state_reason") - agenda["language"] = planning_item.get("language") - agenda["source"] = planning_item.get("source") - - if planning_item.get("event_item") and force_adhoc: - agenda["event_id"] = planning_item["event_item"] - - if not agenda.get("planning_items"): - agenda["planning_items"] = [] - - new_plan = False - plan = next( - (p for p in (agenda.get("planning_items")) if p.get("guid") == planning_item.get("guid")), - {}, - ) - - if not plan: - new_plan = True - - agenda_versioncreated: datetime = agenda["versioncreated"] - plan_versioncreated: datetime = 
parse_date_str(planning_item.get("versioncreated")) or agenda_versioncreated - - plan["_id"] = planning_item.get("_id") or planning_item.get("guid") - plan["guid"] = planning_item.get("guid") - plan["slugline"] = planning_item.get("slugline") - plan["description_text"] = planning_item.get("description_text") - plan["headline"] = planning_item.get("headline") - plan["name"] = planning_item.get("name") - plan["abstract"] = planning_item.get("abstract") - plan["place"] = planning_item.get("place") - plan["subject"] = format_qcode_items(planning_item.get("subject")) - plan["service"] = format_qcode_items(planning_item.get("anpa_category")) - plan["urgency"] = planning_item.get("urgency") - plan["planning_date"] = planning_item.get("planning_date") - plan["coverages"] = planning_item.get("coverages") or [] - plan["ednote"] = planning_item.get("ednote") - plan["internal_note"] = planning_item.get("internal_note") - plan["versioncreated"] = plan_versioncreated - plan["firstcreated"] = parse_date_str(planning_item.get("firstcreated")) or agenda["firstcreated"] - plan["state"] = planning_item.get("state") - plan["state_reason"] = planning_item.get("state_reason") - plan["products"] = planning_item.get("products") - plan["agendas"] = planning_item.get("agendas") - plan[TO_BE_CONFIRMED_FIELD] = planning_item.get(TO_BE_CONFIRMED_FIELD) - plan["language"] = planning_item.get("language") - plan["source"] = planning_item.get("source") - - if new_plan: - agenda["planning_items"].append(plan) - - # Update the versioncreated datetime from Planning item if it's newer than the parent item - try: - if plan_versioncreated > agenda_versioncreated: - agenda["versioncreated"] = plan_versioncreated - except (KeyError, TypeError): - pass - - return new_plan - async def set_agenda_planning_items(self, agenda, orig_agenda, planning_item, action="add", send_notification=True): """ Updates the list of planning items of agenda. If action is 'add' then adds the new one. 
@@ -169,9 +33,7 @@ async def set_agenda_planning_items(self, agenda, orig_agenda, planning_item, ac len(agenda["planning_items"]) < len(existing_planning_items) and len(planning_item.get("coverages") or []) > 0 ): - await get_resource_service("agenda").notify_agenda_update( - agenda, orig_agenda, planning_item, True, planning_item - ) + await notify_agenda_update(agenda, orig_agenda, planning_item, True, planning_item) agenda["coverages"], coverage_changes = await self.get_coverages( agenda["planning_items"], @@ -188,7 +50,7 @@ async def set_agenda_planning_items(self, agenda, orig_agenda, planning_item, ac or coverage_changes.get("coverage_modified") ) ): - await get_resource_service("agenda").notify_agenda_update(agenda, orig_agenda, planning_item, True) + await notify_agenda_update(agenda, orig_agenda, planning_item, True) agenda["display_dates"] = get_display_dates(agenda["planning_items"]) agenda.pop("_updated", None) diff --git a/newsroom/push/notifications.py b/newsroom/push/notifications.py index 7ce234670..f8ae4baa5 100644 --- a/newsroom/push/notifications.py +++ b/newsroom/push/notifications.py @@ -5,12 +5,10 @@ from superdesk.core.types import SearchRequest -from superdesk import get_resource_service from superdesk.core import get_app_config from newsroom.core import get_current_wsgi_app -from newsroom.types import Company, Topic, User, UserResourceModel, CompanyResource, TopicResourceModel, SectionEnum -from newsroom.wire.filters import WireSearchRequestArgs +from newsroom.types import User, UserResourceModel, CompanyResource, TopicResourceModel, SectionEnum from newsroom.history_async import get_history_users from newsroom.utils import get_user_dict_async, get_company_dict_async from newsroom.agenda.utils import push_agenda_item_notification @@ -25,13 +23,13 @@ ) from newsroom.notifications import push_notification, save_user_notifications, NotificationQueueService from newsroom.wire import WireSearchServiceAsync +from newsroom.wire.filters import 
WireSearchRequestArgs +from newsroom.agenda import AgendaSearchServiceAsync +from newsroom.agenda.filters import AgendaSearchRequestArgs logger = logging.getLogger(__name__) -# TODO-ASYNC: revisit when agenda and wire_search are async - - def is_canceled(item: dict[str, Any]) -> bool: return item.get("pubstatus", item.get("state")) in ["canceled", "cancelled"] @@ -82,7 +80,7 @@ async def notify_wire_topic_matches( self, item: dict[str, Any], users: dict[ObjectId, UserResourceModel], companies: dict[ObjectId, CompanyResource] ) -> set[ObjectId]: topics = await get_topics_with_subscribers_async("wire") - topic_matches = await WireSearchServiceAsync().get_mathing_topics_for_item( + topic_matches = await WireSearchServiceAsync().get_matching_topics_for_item( item["_id"], topics, list(users.values()), companies ) @@ -157,12 +155,19 @@ async def send_topic_notification_emails( cursor = await wire_service.service.find(SearchRequest(elastic=query)) items = await cursor.to_list_raw() else: - search_service = get_resource_service("agenda") - query = search_service.get_topic_query( - topic, user, company, args={"es_highlight": 1, "ids": [item["_id"]]} + agenda_service = AgendaSearchServiceAsync() + query = await agenda_service.get_topic_items_query( + topic, + user, + company, + args=AgendaSearchRequestArgs( + page_size=1, + ids=[item["_id"]], + es_highlight=True, + ), ) - - items = list(search_service.get_items_by_query(query, size=1)) + cursor = await agenda_service.service.find(SearchRequest(elastic=query)) + items = await cursor.to_list_raw() highlighted_item = item if len(items) > 0: @@ -278,16 +283,13 @@ async def notify_agenda_topic_matches( ) -> set[ObjectId]: topics = await get_topics_with_subscribers_async("agenda") - # TODO-ASYNC: Remove these conversions once Agenda is updated to async - users_dict: dict[str, User] = {str(user.id): user.to_dict() for user in users.values()} - companies_dict: dict[str, Company] = {str(company.id): company.to_dict() for company in 
companies.values()} - topics_list: list[Topic] = [topic.to_dict() for topic in topics] - topic_matches = get_resource_service("agenda").get_matching_topics( - item["_id"], topics_list, users_dict, companies_dict + topic_matches = await AgendaSearchServiceAsync().get_matching_topics_for_item( + item["_id"], topics, list(users.values()), companies ) # Include topics where the ``query`` is ``item["_id"]`` - topic_matches.extend( + users_dict: dict[str, User] = {str(user.id): user.to_dict() for user in users.values()} + topic_matches |= set( [ topic for topic in await get_agenda_notification_topics_for_query_by_id(item, users_dict) diff --git a/newsroom/push/publishing.py b/newsroom/push/publishing.py index 4fe69284d..b0c44f058 100644 --- a/newsroom/push/publishing.py +++ b/newsroom/push/publishing.py @@ -6,7 +6,6 @@ from superdesk.types import Item from superdesk.utc import utcnow from superdesk.core import get_app_config -from superdesk import get_resource_service from superdesk.resource_fields import VERSION, ID_FIELD, GUID_FIELD from superdesk.text_utils import get_word_count, get_char_count @@ -19,19 +18,18 @@ from newsroom.types import WireItem from newsroom.utils import parse_date_str from newsroom.wire import WireSearchServiceAsync +from newsroom.agenda import AgendaItemService +from newsroom.agenda.notifications import notify_agenda_update from .tasks import notify_new_agenda_item from .agenda_manager import AgendaManager -from .utils import fix_hrefs, fix_updates, get_event_dates, set_dates, validate_event_push +from .utils import fix_hrefs, fix_updates, set_dates, validate_event_push logger = logging.getLogger(__name__) agenda_manager = AgendaManager() -# TODO-ASYNC: Revisit when agenda and content_api are async - - class Publisher: async def publish_item(self, doc: Item, original: Item): """Duplicating the logic from content_api.publish service.""" @@ -91,7 +89,7 @@ async def publish_item(self, doc: Item, original: Item): try: if doc.get("coverage_id"): - 
agenda_items = await get_resource_service("agenda").set_delivery(doc) + agenda_items = await AgendaItemService().set_delivery(doc) if agenda_items: [notify_new_agenda_item.delay(item["_id"], check_topics=False) for item in agenda_items] except Exception as ex: @@ -142,19 +140,16 @@ async def publish_event(self, event: dict[str, Any], orig: dict[str, Any]): if file_ref.get("media"): file_ref.setdefault("href", app.upload_url(file_ref["media"])) - _id = event["guid"] - service = get_resource_service("agenda") - plan_ids = event.pop("plans", []) + agenda_id = event["guid"] + service = AgendaItemService() if not orig: # new event - agenda: dict[str, Any] = {} - agenda_manager.set_metadata_from_event(agenda, event) - agenda["dates"] = get_event_dates(event) + agenda, plan_ids = service.convert_event_to_agenda_dict({}, event) # Retrieve all current Planning items and add them into this Event agenda.setdefault("planning_items", []) - for plan in service.find(where={"_id": {"$in": plan_ids}}): + for plan in await (await service.search({"_id": {"$in": plan_ids}}, use_mongo=True)).to_list_raw(): planning_item = plan["planning_items"][0] agenda["planning_items"].append(planning_item) await agenda_manager.set_agenda_planning_items( @@ -164,10 +159,10 @@ async def publish_event(self, event: dict[str, Any], orig: dict[str, Any]): if not plan.get("event_id"): # Make sure the Planning item has an ``event_id`` defined # This can happen when pushing a Planning item before linking to an Event - service.system_update(plan["_id"], {"event_id": _id}, plan) + await service.system_update(plan["_id"], {"event_id": agenda_id}) signals.publish_event.send(app.as_any(), item=agenda, is_new=True) - _id = service.post([agenda])[0] + agenda_id = (await service.create([agenda]))[0] else: # replace the original document updates = None @@ -195,9 +190,7 @@ async def publish_event(self, event: dict[str, Any], orig: dict[str, Any]): WORKFLOW_STATE.POSTPONED, ]: # schedule is changed, recalculate the 
dates, planning id and coverages from dates will be removed - updates = {} - agenda_manager.set_metadata_from_event(updates, event, False) - updates["dates"] = get_event_dates(event) + updates, _ = service.convert_event_to_agenda_dict({}, event, set_doc_id=False) updates["coverages"] = None updates["planning_items"] = None @@ -208,12 +201,12 @@ async def publish_event(self, event: dict[str, Any], orig: dict[str, Any]): "event": event, "version": event.get("version", event.get(VERSION)), "state": event["state"], - "dates": get_event_dates(event), + # "dates": get_event_dates(event), "planning_items": orig.get("planning_items"), "coverages": orig.get("coverages"), } - agenda_manager.set_metadata_from_event(updates, event, False) + service.convert_event_to_agenda_dict(updates, event, set_doc_id=False) else: logger.info("Ignoring event %s", orig["_id"]) @@ -222,53 +215,56 @@ async def publish_event(self, event: dict[str, Any], orig: dict[str, Any]): updated = orig.copy() updated.update(updates) signals.publish_event.send(app.as_any(), item=updated, updates=updates, orig=orig, is_new=False) - service.patch(orig["_id"], updates) + await service.update(orig["_id"], updates) updates["_id"] = orig["_id"] - await get_resource_service("agenda").notify_agenda_update(updates, orig) + await notify_agenda_update(updates, orig) - return _id + return agenda_id async def publish_planning_item(self, planning: dict[str, Any], orig: dict[str, Any]): - service = get_resource_service("agenda") + service = AgendaItemService() agenda = deepcopy(orig) - agenda_manager.init_adhoc_agenda(planning, agenda) - # Update agenda metadata - new_plan = agenda_manager.set_metadata_from_planning(agenda, planning, force_adhoc=True) + _, new_plan = await service.convert_planning_to_agenda_dict(agenda, planning, force_adhoc=True) # Add the planning item to the list await agenda_manager.set_agenda_planning_items(agenda, orig, planning, action="add" if new_plan else "update") app = get_current_wsgi_app() - 
if not agenda.get("_id"): + if not orig.get("_id"): # Setting ``_id`` of Agenda to be equal to the Planning item if there's no Event ID agenda.setdefault("_id", planning["guid"]) agenda.setdefault("guid", planning["guid"]) signals.publish_planning.send(app.as_any(), item=agenda, is_new=new_plan) - return service.post([agenda])[0] + return (await service.create([agenda]))[0] else: # Replace the original signals.publish_planning.send(app.as_any(), item=agenda, is_new=new_plan) - service.patch(agenda["_id"], agenda) + await service.update(agenda["_id"], agenda) return agenda["_id"] async def publish_planning_into_event(self, planning: dict[str, Any]) -> Optional[str]: if not planning.get("event_item"): return None - service = get_resource_service("agenda") + service = AgendaItemService() event_id = planning["event_item"] plan_id = planning["guid"] - orig_agenda = service.find_one(req=None, _id=event_id) - if not orig_agenda: + orig_agenda: dict[str, Any] | None = None + event = await service.find_by_id(event_id) + if event: + orig_agenda = event.to_dict() + else: # Item not found using ``event_item`` attribute # Try again using ``guid`` attribute - orig_agenda = service.find_one(req=None, _id=plan_id) + plan = await service.find_by_id(plan_id) + if plan: + orig_agenda = plan.to_dict() - if (orig_agenda or {}).get("item_type") != "event": + if orig_agenda is None or (orig_agenda or {}).get("item_type") != "event": # event id exists in planning item but event is not in the system logger.warning(f"Event '{event_id}' for Planning '{plan_id}' not found") return None @@ -281,11 +277,12 @@ async def publish_planning_into_event(self, planning: dict[str, Any]) -> Optiona ): # Remove the Planning item from the list await agenda_manager.set_agenda_planning_items(agenda, orig_agenda, planning, action="remove") - service.patch(agenda["_id"], agenda) + await service.update(agenda["_id"], agenda) return None # Update agenda metadata - new_plan = 
agenda_manager.set_metadata_from_planning(agenda, planning) + _, new_plan = await service.convert_planning_to_agenda_dict(agenda, planning) + # new_plan = agenda_manager.set_metadata_from_planning(agenda, planning) # Add the Planning item to the list await agenda_manager.set_agenda_planning_items( @@ -296,8 +293,8 @@ async def publish_planning_into_event(self, planning: dict[str, Any]) -> Optiona # setting _id of agenda to be equal to planning if there's no event id agenda.setdefault("_id", planning.get("event_item", planning["guid"]) or planning["guid"]) agenda.setdefault("guid", planning.get("event_item", planning["guid"]) or planning["guid"]) - return service.post([agenda])[0] + return (await service.create([agenda]))[0] else: # Replace the original document - service.patch(agenda["_id"], agenda) + await service.update(agenda["_id"], agenda) return agenda["_id"] diff --git a/newsroom/push/tasks.py b/newsroom/push/tasks.py index f9c592707..2ecf57d7b 100644 --- a/newsroom/push/tasks.py +++ b/newsroom/push/tasks.py @@ -2,11 +2,10 @@ from contextlib import contextmanager from superdesk.lock import lock, unlock -from superdesk import get_resource_service from newsroom.celery_app import celery -from newsroom.core import get_current_wsgi_app from newsroom.wire import WireSearchServiceAsync +from newsroom.agenda import AgendaItemService from .notifications import NotificationManager @@ -40,13 +39,16 @@ async def notify_new_wire_item(_id, check_topics=True): @celery.task async def notify_new_agenda_item(_id, check_topics=True, is_new=False): with locked(_id, "agenda"): - app = get_current_wsgi_app() - agenda = app.data.find_one("agenda", req=None, _id=_id) + service = AgendaItemService() + agenda = await service.find_by_id(_id) - if agenda: - if agenda.get("recurrence_id") and agenda.get("recurrence_id") != _id and is_new: - logger.info("Ignoring recurring event %s", _id) - return + if not agenda: + return - get_resource_service("agenda").enhance_items([agenda]) - await 
notifier.notify_new_item(agenda, check_topics=check_topics) + if agenda.recurrence_id and agenda.recurrence_id != _id and is_new: + logger.info("Ignoring recurring event %s", _id) + return + + agenda_dict = agenda.to_dict() + await AgendaItemService().enhance_item(agenda_dict) + await notifier.notify_new_item(agenda_dict, check_topics=check_topics) diff --git a/newsroom/push/utils.py b/newsroom/push/utils.py index 40f5c60c2..ee8f01d56 100644 --- a/newsroom/push/utils.py +++ b/newsroom/push/utils.py @@ -50,7 +50,8 @@ def format_qcode_items(items: list[dict[str, Any]] | None = None): return [] for item in items: - item["code"] = item.get("qcode") + if not item.get("code"): + item["code"] = item.get("qcode") return items diff --git a/newsroom/push/views.py b/newsroom/push/views.py index 00611cef8..712cbbb07 100644 --- a/newsroom/push/views.py +++ b/newsroom/push/views.py @@ -10,14 +10,15 @@ from superdesk.core.web import EndpointGroup from newsroom import signals -from newsroom.agenda.service import FeaturedService +from newsroom.core import get_current_wsgi_app from newsroom.utils import parse_date_str from newsroom.assets import ASSETS_RESOURCE -from newsroom.core import get_current_wsgi_app -from newsroom.web.factory import NewsroomWebApp from newsroom.flask import get_file_from_request + from newsroom.wire import WireSearchServiceAsync from newsroom.wire.views import delete_dashboard_caches +from newsroom.agenda.featured_service import FeaturedService +from newsroom.agenda.agenda_service import AgendaItemService from .publishing import Publisher from .utils import assert_test_signature @@ -30,23 +31,20 @@ notifier = NotificationManager() publisher = Publisher() -PublishHandlerFunc = Callable[[NewsroomWebApp, dict[str, Any]], Awaitable[None]] - - -# TODO-ASYNC: Revisit this module when agenda, items and agenda_featured are async +PublishHandlerFunc = Callable[[dict[str, Any]], Awaitable[None]] -async def handle_publish_event(app: NewsroomWebApp, item): - orig = 
app.data.find_one("agenda", req=None, _id=item["guid"]) - event_id = await publisher.publish_event(item, orig) +async def handle_publish_event(item): + orig = await AgendaItemService().find_by_id(item["guid"]) + event_id = await publisher.publish_event(item, orig.to_dict() if orig else None) notify_new_agenda_item.delay(event_id, check_topics=True, is_new=orig is None) -async def handle_publish_planning(app: NewsroomWebApp, item): - orig = app.data.find_one("agenda", req=None, _id=item["guid"]) or {} +async def handle_publish_planning(item): + orig = await AgendaItemService().find_by_id(item["guid"]) item["planning_date"] = parse_date_str(item["planning_date"]) - plan_id = await publisher.publish_planning_item(item, orig) + plan_id = await publisher.publish_planning_item(item, orig.to_dict() if orig else {}) event_id = await publisher.publish_planning_into_event(item) # Prefer parent Event when sending notificaitons @@ -54,7 +52,7 @@ async def handle_publish_planning(app: NewsroomWebApp, item): notify_new_agenda_item.delay(_id, check_topics=True, is_new=orig is None) -async def handle_publish_text_item(_, item): +async def handle_publish_text_item(item): orig = await WireSearchServiceAsync().service.find_by_id(item["guid"]) item["_id"] = await publisher.publish_item(item, orig.to_dict() if orig else None) @@ -64,18 +62,18 @@ async def handle_publish_text_item(_, item): ) -async def handle_publish_planning_featured(_, item): +async def handle_publish_planning_featured(item): assert item.get("_id"), {"_id": 1} service = FeaturedService() orig = await service.find_by_id(item["_id"]) if orig: - service.update(orig["_id"], {"items": item.get("items") or []}, orig) + await service.update(orig["_id"], {"items": item.get("items") or []}) else: # Assert `tz` and `items` in initial push only assert item.get("tz"), {"tz": 1} assert item.get("items"), {"items": 1} - service.create([item]) + await service.create([item]) def get_publish_handler( @@ -115,7 +113,7 @@ async def 
push(request: Request): publish_fn = get_publish_handler(item_type) if publish_fn: - await publish_fn(app, item) + await publish_fn(item) else: await request.abort(400, gettext("Unknown type {}".format(item.get("type")))) diff --git a/newsroom/reports/content_activity.py b/newsroom/reports/content_activity.py index 499c43e42..e64cf81c0 100644 --- a/newsroom/reports/content_activity.py +++ b/newsroom/reports/content_activity.py @@ -1,20 +1,57 @@ +from typing import Any, cast from copy import deepcopy from quart_babel import gettext +from superdesk.core.resources.cursor import ElasticsearchResourceCursorAsync from superdesk.flask import abort, request from superdesk import get_resource_service from superdesk.utc import utc_to_local -from newsroom.wire.search import items_query -from newsroom.agenda.agenda import get_date_filters +from newsroom.types import SectionEnum +from newsroom.search.types import BaseSearchRequestArgs, NewshubSearchRequest +from newsroom.search.filters import apply_section_filter +from newsroom.wire.filters import apply_item_type_filter as apply_wire_type_filter +from newsroom.wire import WireItemService +from newsroom.agenda.filters import get_date_filters, apply_item_type_filter as apply_agenda_type_filter +from newsroom.agenda import AgendaItemService + from newsroom.utils import query_resource, MAX_TERMS_SIZE CHUNK_SIZE = 100 -def get_items(args): +async def get_query_source(args: dict[str, Any], source: dict[str, Any]) -> dict[str, Any]: + search_request = NewshubSearchRequest[BaseSearchRequestArgs](section=args["section"]) + query = search_request.search.query + + if args.get("genre"): + query.filter.append({"terms": {"genre.code": [genre for genre in args["genre"]]}}) + + await apply_section_filter(search_request) + + if args["section"] == SectionEnum.AGENDA: + # Set ``featured`` to True, so we don't add filters for Event, Planning, Combined type filter + search_request.args.featured = True + apply_agenda_type_filter(search_request) + 
else: + apply_wire_type_filter(search_request) + + date_range = get_date_filters( + BaseSearchRequestArgs( + start_date=args["date_from"], + end_date=args["date_from"], + timezone_offset=args.get("timezone_offset"), + ) + ) + if date_range.get("gt") or date_range.get("lt"): + query.filter.append({"range": {"versioncreated": date_range}}) + + return search_request.search.generate_query_dict(source) + + +async def get_items(args): """Get all the news items for the date and filters provided For performance reasons, returns an iterator that yields an array of CHUNK_SIZE @@ -24,48 +61,35 @@ def get_items(args): if not args.get("section"): abort(400, gettext("Must provide a section for this report")) - source = { - "query": items_query(True), - "size": CHUNK_SIZE, - "from": 0, - "sort": [{"versioncreated": "asc"}], - "_source": [ - "_resource", - "headline", - "place", - "subject", - "service", - "versioncreated", - "anpa_take_key", - "source", - ], - } - - must_terms = [] - if args.get("genre"): - must_terms.append({"terms": {"genre.code": [genre for genre in args["genre"]]}}) - - args["date_to"] = args["date_from"] - date_range = get_date_filters(args) - if date_range.get("gt") or date_range.get("lt"): - must_terms.append({"range": {"versioncreated": date_range}}) - - if len(must_terms) > 0: - source["query"]["bool"]["filter"] += must_terms + source = await get_query_source( + args, + { + "size": CHUNK_SIZE, + "from": 0, + "sort": [{"versioncreated": "asc"}], + "_source": [ + "_resource", + "headline", + "place", + "subject", + "service", + "versioncreated", + "anpa_take_key", + "source", + ], + }, + ) - # Apply the section filters - section = args["section"] - get_resource_service("section_filters").apply_section_filter(source["query"], section) + service = AgendaItemService() if args["section"] == SectionEnum.AGENDA else WireItemService() while True: - results = get_resource_service(section if section == "agenda" else f"{section}_search").search(source) - items = 
list(results) + cursor = await service.search(source) + items = await cursor.to_list_raw() if not len(items): break source["from"] += CHUNK_SIZE - yield items @@ -114,42 +138,34 @@ def get_aggregations(args, ids): } -def get_facets(args): +async def get_facets(args): """Get aggregations for genre and companies using the date range and section This is used to populate the dropdown filters in the front-end """ - args["date_to"] = args["date_from"] - date_range = get_date_filters(args) + section = args["section"] + date_range = get_date_filters( + BaseSearchRequestArgs( + start_date=args["date_from"], + end_date=args["date_from"], + timezone_offset=args.get("timezone_offset"), + ) + ) - def get_genres(): + async def get_genres(): """Get the list of genres from the news items""" - query = items_query(True) - must_terms = [] - source = {} - - if date_range.get("gt") or date_range.get("lt"): - must_terms.append({"range": {"versioncreated": date_range}}) - - if len(must_terms) > 0: - query["bool"]["filter"] += must_terms - - source.update( + source = await get_query_source( + args, { - "query": query, "size": 0, "aggs": {"genres": {"terms": {"field": "genre.code", "size": MAX_TERMS_SIZE}}}, - } + }, ) - # Apply the section filters - section = args["section"] - get_resource_service("section_filters").apply_section_filter(source["query"], section) - - results = get_resource_service(section if section == "agenda" else f"{section}_search").search(source) - + service = AgendaItemService() if args["section"] == SectionEnum.AGENDA else WireItemService() + results = cast(ElasticsearchResourceCursorAsync, await service.search(source)) buckets = ((results.hits.get("aggregations") or {}).get("genres") or {}).get("buckets") or [] return [genre["key"] for genre in buckets] @@ -157,7 +173,7 @@ def get_genres(): def get_companies(): """Get the list of companies from the action history""" - must_terms = [{"term": {"section": args["section"]}}] + must_terms = [{"term": {"section": 
section}}] if date_range.get("gt") or date_range.get("lt"): must_terms.append({"range": {"_created": date_range}}) @@ -174,7 +190,7 @@ def get_companies(): return [company["key"] for company in buckets] - return {"genres": get_genres(), "companies": get_companies()} + return {"genres": await get_genres(), "companies": get_companies()} def export_csv(args, results): @@ -280,11 +296,11 @@ async def get_content_activity_report(): if args.get("aggregations"): # This request is for populating the dropdown filters # for genre and companies - return get_facets(args) + return await get_facets(args) response = {"results": [], "name": gettext("Content activity")} - for items in get_items(args): + async for items in get_items(args): item_ids = [item.get("_id") for item in items] aggs = get_aggregations(args, item_ids) diff --git a/newsroom/reports/reports.py b/newsroom/reports/reports.py index 36d02ae8c..dcfce9783 100644 --- a/newsroom/reports/reports.py +++ b/newsroom/reports/reports.py @@ -20,7 +20,8 @@ get_items_by_id, MAX_TERMS_SIZE, ) -from newsroom.agenda.agenda import get_date_filters +from newsroom.search.types import BaseSearchRequestArgs +from newsroom.agenda.filters import get_date_filters from newsroom.news_api.api_tokens import API_TOKENS from newsroom.news_api.utils import format_report_results from newsroom.companies.utils import get_companies_id_by_product @@ -156,16 +157,11 @@ async def get_product_stories(): results = [] - # TODO-ASYNC: Use Async ProductsService when get_product_item_report picks ProductResourceModel - products = query_resource("products") - # cursor = await ProductsService().find({}) - # products = await cursor.to_list_raw() - - for product in products: + async for product in ProductsService().get_all(): product_stories = { - "_id": product["_id"], - "name": product.get("name"), - "is_enabled": product.get("is_enabled"), + "_id": product.id, + "name": product.name, + "is_enabled": product.is_enabled, } counts = await 
WireSearchServiceAsync().get_product_item_report(product) for key, value in counts.hits["aggregations"].items(): @@ -221,7 +217,13 @@ async def get_subscriber_activity_report(): if args.get("section"): must_terms.append({"term": {"section": args.get("section")}}) - date_range = get_date_filters(args) + date_range = get_date_filters( + BaseSearchRequestArgs( + start_date=args["date_from"], + end_date=args["date_from"], + timezone_offset=args.get("timezone_offset"), + ) + ) if date_range.get("gt") or date_range.get("lt"): must_terms.append({"range": {"versioncreated": date_range}}) @@ -334,7 +336,13 @@ def get_section_name(s): async def get_company_api_usage(): args = deepcopy(request.args.to_dict()) - date_range = get_date_filters(args) + date_range = get_date_filters( + BaseSearchRequestArgs( + start_date=args["date_from"], + end_date=args["date_from"], + timezone_offset=args.get("timezone_offset"), + ) + ) if not date_range.get("gt") and date_range.get("lt"): abort(400, "No date range specified.") diff --git a/newsroom/search/base_service.py b/newsroom/search/base_service.py index aefb3fbac..4d2daf8bb 100644 --- a/newsroom/search/base_service.py +++ b/newsroom/search/base_service.py @@ -139,8 +139,7 @@ async def get_search_response( total=count, ), ) - if hasattr(cursor, "extra"): - getattr(cursor, "extra")(response) + cursor.extra(response) return response, count diff --git a/newsroom/search/base_web_service.py b/newsroom/search/base_web_service.py new file mode 100644 index 000000000..272f16af0 --- /dev/null +++ b/newsroom/search/base_web_service.py @@ -0,0 +1,219 @@ +from typing import Generic, Any +import logging + +from bson import ObjectId + +from superdesk.core.types import ESQuery, ESBoolQuery, SearchRequest +from superdesk.core.resources.cursor import ElasticsearchResourceCursorAsync + +from newsroom.exceptions import AuthorizationError +from newsroom.types import TopicResourceModel, UserResourceModel, CompanyResource +from newsroom.auth.utils import 
get_user_sections +from newsroom.products import get_products_by_navigation_async + +from .base_service import BaseNewshubSearchService, SearchArgsType, SearchItemType +from .filters import prefill_products, validate_request, prefill_args_from_topic +from .types import SearchFilterFunction, NewshubSearchRequest + +logger = logging.getLogger(__name__) + + +class BaseWebSearchService( + BaseNewshubSearchService[SearchArgsType, SearchItemType], Generic[SearchArgsType, SearchItemType] +): + get_items_by_id_filters: list[SearchFilterFunction] + + get_topic_items_query_execute_filters: list[SearchFilterFunction] + get_topic_items_query_user_filters: list[SearchFilterFunction] + + async def get_items_by_id( + self, + item_ids: list[str], + args: SearchArgsType | None = None, + apply_permissions: bool = False, + ) -> ElasticsearchResourceCursorAsync[SearchItemType]: + """Searches for items by ID, optionally applying user/company permissions + + :param item_ids: A list of item IDs to search for + :param args: Optional set of request arguments to apply + :param apply_permissions: Whether to apply user/company permissions or not + :returns: Elasticsearch cursor with the results + """ + + if args is None: + args = self.search_args_class() + + args.ids = item_ids + return await self.search( + args, + filters=None if apply_permissions else self.get_items_by_id_filters, + ) + + async def get_items_for_action(self, item_ids: list[str]) -> list[dict[str, Any]]: + """Searches for item by ID, for use by downloads, sharing etc + + For each item, appends the ``anpa_take_key`` to the slugline if defined + + :param item_ids: A list of item IDs to search for + :returns: The list of WIre items + """ + + raise NotImplementedError() + + async def get_topic_items_query( + self, + topic: TopicResourceModel | None, + user: UserResourceModel | None, + company: CompanyResource | None, + query: ESQuery | None = None, + args: SearchArgsType | None = None, + ) -> ESQuery | None: + """Generate an 
elasticsearch query, based on topic, user and company + + :param topic: An optional Topic to be added to the request args + :param user: An optional User to be added to the request args + :param company: An optional Company to be added to the request args + :param query: An optional Elasticsearch query to start with + :param args: An optional request args to start with + :returns: The generated Elasticsearch query, or None if the supplied User does not have permission + """ + + async def prefill_request(request: NewshubSearchRequest): + if topic: + request.topic = topic + if user: + request.user = request.current_user = user + request.is_admin = request.user.is_admin() + else: + request.is_admin = False + + if company: + request.company = company + + if user is None and topic is not None and topic.navigation is not None: + request.products = await get_products_by_navigation_async(topic.navigation) + + search_request = NewshubSearchRequest( + section=self.section, web_request=None, args=args or self.search_args_class(), search=query or ESQuery() + ) + + prefill_filter_params: list[SearchFilterFunction] = [ + prefill_request, + prefill_args_from_topic, + ] + execute_filters = self.get_topic_items_query_execute_filters.copy() + + if user is not None: + # If this query is from a User's perspective, then add + # validation and section/company filters + prefill_filter_params.extend([prefill_products, validate_request]) + execute_filters.extend(self.get_topic_items_query_user_filters) + + try: + return await self.run_filters_and_return_query(search_request, prefill_filter_params + execute_filters) + except AuthorizationError: + if user and topic: + logger.info(f"Notification for user:{user.id} and topic:{topic.id} is skipped") + pass + + return None + + async def get_matching_topics_for_item( + self, + item_id: str, + topics: list[TopicResourceModel], + users: list[UserResourceModel], + companies: dict[ObjectId, CompanyResource], + ) -> set[ObjectId]: + """Get a set of 
Topic IDs that match the supplied item + + :param item_id: The ID of the item to match topics against + :param topics: The list of Topics to match the item against + :param users: The list of Users to match the item against + :param companies: The list of Companies to match the item against + :returns: A set of Topic IDs that the wire item matches + """ + + return await self.get_matching_topics_for_query( + topics, + users, + companies, + ESQuery(query=ESBoolQuery(must=[{"term": {"_id": item_id}}])), + ) + + async def get_matching_topics_for_query( + self, + topics: list[TopicResourceModel], + users: list[UserResourceModel], + companies: dict[ObjectId, CompanyResource], + query: ESQuery | None = None, + ) -> set[ObjectId]: + """Get a set of Topic IDs that match the supplied query + + :param topics: The list of Topics to match the item against + :param users: The list of Users to match the item against + :param companies: The list of Companies to match the item against + :param query: The Elasticsearch query to match topics for + :returns: A set of Topic IDs that the wire item matches + """ + + topic_matches: set[ObjectId] = set() + topics_checked: set[ObjectId] = set() + + for user in users: + company = companies.get(user.company) if user.company else None + user_sections = get_user_sections(user, company) + if not user_sections.get(self.section): + continue + + if user.has_paused_notifications(): + continue + + aggs: dict[str, Any] = {"topics": {"filters": {"filters": {}}}} + + # There will be one base search for a user with aggs for user topics + search = await self.get_topic_items_query(None, user, company, query=query) + if not search: + continue + + queried_topics: list[TopicResourceModel] = [] + for topic in topics: + if topic.user is None or topic.user != user.id: + continue + elif topic.id in topics_checked: + continue + topics_checked.add(topic.id) + + topic_query = await self.get_topic_items_query(topic, None, None) + if not topic_query: + continue + + 
try: + aggs["topics"]["filters"]["filters"][str(topic.id)] = topic_query.generate_query_dict()["query"] + queried_topics.append(topic) + except (KeyError, TypeError, IndexError): + continue + + if not len(queried_topics): + continue + + search.aggs = aggs + search_request = SearchRequest( + max_results=0, + aggregations=True, + elastic=search, + ) + + try: + search_results: ElasticsearchResourceCursorAsync[SearchItemType] = await self.service.find( + search_request + ) + for topic in queried_topics: + try: + if search_results.hits["aggregations"]["topics"]["buckets"][str(topic.id)]["doc_count"] > 0: + topic_matches.add(topic.id) + except (KeyError, IndexError, TypeError): + logger.warning(f"Failed to find aggregation result for topic {topic.id}") + except Exception: + logger.exception("Error in get_matching_topics", extra=dict(query=search_request, user=user.id)) + return topic_matches diff --git a/newsroom/search/filters.py b/newsroom/search/filters.py index ff8e3385a..d0fcedec7 100644 --- a/newsroom/search/filters.py +++ b/newsroom/search/filters.py @@ -5,7 +5,6 @@ from superdesk.core import get_app_config from superdesk.core.types import ESQuery -from superdesk import get_resource_service from content_api.errors import BadParameterValueError from newsroom.types import SectionEnum @@ -15,10 +14,11 @@ from newsroom.settings import get_setting from newsroom.users import UsersService -from newsroom.products.products import ( - get_products_by_navigation, - get_products_by_company, - get_products_by_user, +from newsroom.products import ( + ProductsService, + get_products_by_company_async, + get_products_by_user_async, + get_products_by_navigation_async, ) from .types import ( @@ -76,18 +76,15 @@ async def prefill_products(request: NewshubSearchRequest) -> None: # This should not happen, as it's prefilled by search service return - products_service = get_resource_service("products") + products_service = ProductsService() request.products = [] if request.is_admin: if 
len(request.args.navigation_ids): - request.products = get_products_by_navigation( - request.args.navigation_ids, product_type=request.section.value + request.products = await get_products_by_navigation_async( + request.args.navigation_ids, product_type=request.section ) elif len(request.args.product_ids): - # TODO-ASYNC: Convert to Async service when it's available - request.products = list( - products_service.get_from_mongo(req=None, lookup={"_id": {"$in": request.args.product_ids}}) - ) + request.products = await products_service.find_by_ids(request.args.product_ids) elif request.company is not None: if request.args.product_ids: allowed_product_ids = ( @@ -99,23 +96,20 @@ async def prefill_products(request: NewshubSearchRequest) -> None: if request.company is not None and request.company.products is not None else [] ) - # TODO-ASYNC: Convert to Async service when it's available - request.products = list( - products_service.get_from_mongo(req=None, lookup={"_id": {"$in": allowed_product_ids}}) - ) + request.products = await products_service.find_by_ids(allowed_product_ids) else: if request.user and request.user.products: - request.products = get_products_by_user( - request.user.to_dict(context={"use_objectid": True}), - str(request.section.value), + request.products = await get_products_by_user_async( + request.user, + request.section, request.args.navigation_ids, ) # add unlimited (seats=0) company products - company_products = get_products_by_company( - request.company.to_dict(context={"use_objectid": True}), + company_products = await get_products_by_company_async( + request.company, request.args.navigation_ids, - product_type=request.section.value, + product_type=request.section, unlimited_only=True, ) if company_products: @@ -167,7 +161,7 @@ def apply_ids_filter(request: NewshubSearchRequest) -> None: if not len(request.args.ids): return - request.search.query.must.append({"ids": {"values": request.args.ids}}) + request.search.query.filter.append({"ids": 
{"values": request.args.ids}}) def get_apply_filters(get_aggregation_field: Callable[[str], str]) -> SearchFilterFunction: @@ -273,14 +267,13 @@ def apply_products_filter(request: NewshubSearchRequest) -> None: # This should not happen, as it's prefilled by search service return - # TODO-ASYNC: Convert to async Product model when available - sdesk_product_ids = [product["sd_product_id"] for product in request.products if product.get("sd_product_id")] + sdesk_product_ids = [product.sd_product_id for product in request.products if product.sd_product_id] if sdesk_product_ids: request.search.query.should.append({"terms": {"products.code": sdesk_product_ids}}) for product in request.products: - if product.get("query"): - request.search.query.should.append(query_string_for_section(request.section, product["query"])) + if product.query: + request.search.query.should.append(query_string_for_section(request.section, product.query)) def prefill_args_from_topic(request: NewshubSearchRequest) -> None: @@ -325,6 +318,22 @@ def apply_advanced_search(request: NewshubSearchRequest) -> None: if not advanced["fields"]: return + if request.section is SectionEnum.AGENDA: + if "slugline" in advanced["fields"]: + # Add ``slugline`` field for Planning & Coverages too + advanced["fields"].extend(["planning_items.slugline", "coverages.slugline"]) + + if "headline" in advanced["fields"]: + # Add ``headline`` field for Planning items too + advanced["fields"].append("planning_items.headline") + + if "description" in advanced["fields"]: + # Replace ``description`` alias with appropriate description fields + advanced["fields"].remove("description") + advanced["fields"].extend( + ["definition_short", "definition_long", "description_text", "planning_items.description_text"] + ) + if advanced.get("all"): request.search.query.must.append( query_string( diff --git a/newsroom/search/types.py b/newsroom/search/types.py index 94c5f01ff..dd9755789 100644 --- a/newsroom/search/types.py +++ 
b/newsroom/search/types.py @@ -16,7 +16,7 @@ UserResourceModel, CompanyResource, SectionEnum, - Product, + ProductResourceModel, TopicResourceModel, AdvancedSearchParams, ) @@ -164,13 +164,15 @@ class BaseSearchRequestArgs(BaseModel): user_id: ObjectId | None = Field(validation_alias=AliasChoices("user_id", "user"), default=None) #: Filter items from this date onwards (an alias for ``created_from``) - start_date: str | None = Field(validation_alias=AliasChoices("start_date", "created_from"), default=None) + start_date: str | None = Field( + validation_alias=AliasChoices("start_date", "created_from", "date_from"), default=None + ) #: Start time to use with the ``start_date`` argument (defaults to ``00.00:00``) start_time: str = Field(validation_alias=AliasChoices("start_time", "created_from_time"), default="00:00:00") #: Filter items up to this date (an alias for ``created_to``) - end_date: str | None = Field(validation_alias=AliasChoices("end_date", "created_to"), default=None) + end_date: str | None = Field(validation_alias=AliasChoices("end_date", "created_to", "date_to"), default=None) #: End time to use with the ``end_date`` argument (defaults to ``23:59:59``) end_time: str = Field(validation_alias=AliasChoices("end_time", "created_to_time"), default="23:59:59") @@ -191,7 +193,7 @@ class BaseSearchRequestArgs(BaseModel): filter: dict[str, Any] | None = None #: List of item IDs to search for - ids: list[str] = Field(default_factory=list) + ids: list[str] = Field(validation_alias=AliasChoices("ids", "id"), default_factory=list) #: The timezone offest, used when constructing the date queries timezone_offset: int | None = None @@ -230,7 +232,7 @@ def parse_filter(cls, value: dict[str, Any] | str | None) -> dict[str, Any] | No except (ValueError, TypeError): raise BadParameterValueError(gettext("Incorrect type supplied for filter parameter")) - @field_validator("product_ids", "bookmarks", "navigation_ids", mode="before") + @field_validator("product_ids", 
"bookmarks", "navigation_ids", "ids", mode="before") def parse_list_ids(cls, value: list[str] | list[ObjectId] | str | ObjectId | None) -> list[str]: """If value is not a list, then convert it to a list here @@ -342,9 +344,8 @@ class NewshubSearchRequest(Generic[SearchArgsType]): search: ESQuery = field(default_factory=ESQuery) - # TODO-ASYNC: Convert to Async resource model when it's available #: The list of pre-filled products to use when constructing the elasticsearch query - products: list[Product] = field(default_factory=list) + products: list[ProductResourceModel] = field(default_factory=list) #: The list of topics to use when constructing the elasticsearch query topic: TopicResourceModel | None = None diff --git a/newsroom/tests/conftest.py b/newsroom/tests/conftest.py index 9b774b925..cfec2c5e1 100644 --- a/newsroom/tests/conftest.py +++ b/newsroom/tests/conftest.py @@ -114,6 +114,16 @@ def get_mongo_uri(key, dbname): @fixture async def app(request): + # Make sure old DB connections are closed + prev_instance = getattr(app, "instance", None) + if prev_instance: + # Close all PyMongo Connections (new ones will be created with ``app_factory`` call) + for key, val in prev_instance.extensions["pymongo"].items(): + val[0].close() + + prev_instance.async_app.stop() + await prev_instance.async_app.elastic.stop() + cfg = Config(root) update_config(cfg) @@ -128,7 +138,8 @@ async def app(request): # drop mongodb now, indexes will be created during app init drop_mongo(cfg) - app = get_app(config=cfg, testing=True) + app_instance = get_app(config=cfg, testing=True) + setattr(app, "instance", app_instance) limiter_key = str(ObjectId()) async def limiter_key_function(): @@ -136,16 +147,16 @@ async def limiter_key_function(): limiter.key_function = limiter_key_function - async with app.app_context(): - await reset_elastic(app) + async with app_instance.app_context(): + await reset_elastic(app_instance) cache.clean() - app.init_indexes() - yield app + 
app_instance.init_indexes() + yield app_instance # Clean up blueprints, so they can be re-registered import importlib - for name in app.config["BLUEPRINTS"]: + for name in app_instance.config["BLUEPRINTS"]: mod = importlib.import_module(name) if getattr(mod, "blueprint"): mod.blueprint._got_registered_once = False diff --git a/newsroom/tests/fixtures.py b/newsroom/tests/fixtures.py index 818fa3477..8fbcb81ea 100644 --- a/newsroom/tests/fixtures.py +++ b/newsroom/tests/fixtures.py @@ -10,6 +10,8 @@ from newsroom.tests.users import ADMIN_USER_ID, test_login_succeeds_for_admin from tests.core.utils import create_entries_for +from . import markers + PUBLIC_USER_ID = ObjectId("59b4c5c61d41c8d736852fbf") PUBLIC_USER_FIRSTNAME = "Foo" @@ -26,6 +28,11 @@ ADMIN_USER_EMAIL = "admin@sourcefabric.org" + +def get_markers(request): + return [mark.name for mark in request.node.own_markers] + + items = [ { "_id": "tag:foo", @@ -91,8 +98,89 @@ agenda_items = [ { - "type": "agenda", - "_id": "urn:conference", + "type": "event", + "guid": "urn:conference", + "event_id": "urn:conference", + "versioncreated": datetime(2018, 6, 27, 11, 12, 4, tzinfo=utc), + "name": "Conference Planning", + "slugline": "Prime Conference", + "internal_note": "Internal message for event", + "_created": "2018-06-27T11:12:07+0000", + "dates": { + "end": datetime(2018, 7, 20, 4, 0, 0, tzinfo=utc), + "start": datetime(2018, 7, 20, 4, 0, 0, tzinfo=utc), + }, + "event": { + "definition_short": "Blah Blah", + "pubstatus": "usable", + "files": [{"media": "media", "name": "test.txt", "mimetype": "text/plain"}], + "internal_note": "Internal message for event", + "state": "scheduled", + }, + "firstcreated": "2018-06-27T11:12:04+0000", + "_current_version": 1, + "files": [{"media": "media", "name": "test.txt", "mimetype": "text/plain"}], + "definition_short": "Blah Blah", + "state": "scheduled", + "pubstatus": "usable", + "planning_items": [ + { + "versioncreated": "2018-06-27T11:07:17+0000", + "planning_date": 
"2018-07-20T04:00:00+0000", + # "expired": False, + # "flags": {"marked_for_not_publication": False}, + "slugline": "Prime Conference", + # "item_class": "plinat:newscoverage", + "pubstatus": "usable", + # "item_id": "urn:planning", + "name": "Conference Planning", + "_id": "urn:planning", + "firstcreated": "2018-06-27T11:07:17+0000", + "state": "draft", + "guid": "urn:planning", + "agendas": [], + # "_current_version": 1, + # "type": "planning", + "internal_note": "Internal message for planning", + "coverages": [ + { + "firstcreated": "2018-06-27T11:07:17+0000", + "planning": { + "g2_content_type": "text", + "genre": [{"name": "Article", "qcode": "Article"}], + "ednote": "An editorial Note", + "keyword": ["Motoring"], + "scheduled": "2018-04-09T14:00:53.000Z", + "slugline": "Raiders", + "internal_note": "Internal message for coverage", + }, + "workflow_status": "active", + "coverage_id": "urn:coverage", + "news_coverage_status": { + "label": "Planned", + "name": "coverage intended", + "qcode": "ncostat:int", + }, + } + ], + } + ], + "coverages": [ + { + "planning_id": "urn:planning", + "coverage_id": "urn:coverage", + "scheduled": "2018-04-09T14:00:53.000Z", + "coverage_type": "text", + "workflow_status": "active", + "coverage_status": "coverage intended", + "slugline": "Raiders", + "genre": [{"name": "Article", "qcode": "Article"}], + } + ], + }, + { + "type": "planning", + "guid": "urn:planning", "event_id": "urn:conference", "versioncreated": datetime(2018, 6, 27, 11, 12, 4, tzinfo=utc), "name": "Conference Planning", @@ -163,33 +251,24 @@ }, } ], - "dates": { - "end": datetime(2018, 7, 20, 4, 0, 0, tzinfo=utc), - "start": datetime(2018, 7, 20, 4, 0, 0, tzinfo=utc), - }, - "event": { - "definition_short": "Blah Blah", - "pubstatus": "usable", - "files": [{"media": "media", "name": "test.txt", "mimetype": "text/plain"}], - "internal_note": "Internal message for event", - "state": "scheduled", - }, + "planning_date": "2018-07-20T04:00:00+0000", "firstcreated": 
"2018-06-27T11:12:04+0000", "_current_version": 1, "headline": "test headline", - } + }, ] @fixture(autouse=True) -async def init_items(app): - app.data.insert("items", items) +async def init_items(request, app): + if markers.skip_auto_wire_items.name not in get_markers(request): + await create_entries_for("items", items) @fixture(autouse=True) -async def init_agenda_items(app): - async with app.app_context(): - app.data.insert("agenda", agenda_items) +async def init_agenda_items(request, app): + if markers.skip_auto_agenda_items.name not in get_markers(request): + await create_entries_for("agenda", agenda_items) @fixture() @@ -230,7 +309,7 @@ async def init_auth(app, auth_users): async def setup_user_company(app): - app.data.insert( + await create_entries_for( "companies", [ { @@ -258,8 +337,8 @@ async def setup_user_company(app): ], ) - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": PUBLIC_USER_ID, @@ -348,6 +427,6 @@ async def company_products(app): }, ] - app.data.insert("products", _products) + await create_entries_for("products", _products) return _products diff --git a/newsroom/tests/markers.py b/newsroom/tests/markers.py index 380d080a0..a798a12ec 100644 --- a/newsroom/tests/markers.py +++ b/newsroom/tests/markers.py @@ -14,3 +14,7 @@ # Skips the test, due to known issue with async changes # Change this to `requires_async_celery = mark.requires_async_celery` to run these tests requires_async_celery = mark.skip(reason="Requires celery to support async tasks") + + +skip_auto_wire_items = mark.skip_auto_wire_items +skip_auto_agenda_items = mark.skip_auto_agenda_items diff --git a/newsroom/tests/web_api/environment.py b/newsroom/tests/web_api/environment.py index a4fcf1c0c..973d8874a 100644 --- a/newsroom/tests/web_api/environment.py +++ b/newsroom/tests/web_api/environment.py @@ -8,7 +8,7 @@ from newsroom.web.factory import get_app from newsroom.web.default_settings import CORE_APPS, BLUEPRINTS -from newsroom.agenda.agenda 
import aggregations as agenda_aggs +from newsroom.agenda.filters import aggregations as agenda_aggs from tests.search.fixtures import USERS, COMPANIES @@ -95,6 +95,7 @@ async def before_scenario_async(context, scenario): context.app = get_app(config=config, testing=True) async with context.app.app_context(): await reset_elastic(context.app) + context.app.cache.clear() context.headers = [("Content-Type", "application/json"), ("Origin", "localhost")] context.client = context.app.test_client() diff --git a/newsroom/topics/__init__.py b/newsroom/topics/__init__.py index 10b475153..0cc9039bb 100644 --- a/newsroom/topics/__init__.py +++ b/newsroom/topics/__init__.py @@ -4,6 +4,7 @@ topic_resource_config, topic_endpoints, get_user_topics, + get_user_topics_async, get_agenda_notification_topics_for_query_by_id, get_topics_with_subscribers, get_topics_with_subscribers_async, @@ -12,6 +13,7 @@ __all__ = [ "get_user_topics", + "get_user_topics_async", "topic_endpoints", "topic_resource_config", "get_agenda_notification_topics_for_query_by_id", diff --git a/newsroom/topics/topics_async.py b/newsroom/topics/topics_async.py index 1c6c96066..29c71ddac 100644 --- a/newsroom/topics/topics_async.py +++ b/newsroom/topics/topics_async.py @@ -1,8 +1,10 @@ from bson import ObjectId from typing import Optional, List, Dict, Any, Union +from superdesk.core.resources import ResourceConfig, MongoResourceConfig, RestEndpointConfig, RestParentLink + from newsroom import MONGO_PREFIX -from newsroom.types import TopicResourceModel, UserResourceModel +from newsroom.types import TopicResourceModel, UserResourceModel, NotificationType from newsroom.exceptions import AuthorizationError from newsroom.auth.utils import get_user_from_request @@ -15,7 +17,6 @@ from superdesk.core.web import EndpointGroup # from superdesk.core.module import SuperdeskAsyncApp -from superdesk.core.resources import ResourceConfig, MongoResourceConfig, RestEndpointConfig, RestParentLink class 
TopicService(NewshubAsyncResourceService[TopicResourceModel]): @@ -146,7 +147,7 @@ async def get_topics_with_subscribers(topic_type: Optional[str] = None) -> list[ # TODO-ASYNC: Remove the above function, replace with this one and remove `_async` suffix -async def get_topics_with_subscribers_async(topic_type: Optional[str] = None) -> list[TopicResourceModel]: +async def get_topics_with_subscribers_async(topic_type: str | None = None) -> list[TopicResourceModel]: lookup: Dict[str, Any] = ( {"subscribers": {"$exists": True, "$ne": []}} if topic_type is None @@ -164,15 +165,15 @@ async def get_topics_with_subscribers_async(topic_type: Optional[str] = None) -> async def get_user_id_to_topic_for_subscribers( - notification_type: Optional[str] = None, -) -> Dict[ObjectId, Dict[ObjectId, Topic]]: - user_topic_map: Dict[ObjectId, Dict[ObjectId, Topic]] = {} - for topic in await get_topics_with_subscribers(): - for subscriber in topic.get("subscribers") or []: - if notification_type is not None and subscriber.get("notification_type") != notification_type: + notification_type: NotificationType | None = None, +) -> dict[ObjectId, dict[ObjectId, TopicResourceModel]]: + user_topic_map: dict[ObjectId, dict[ObjectId, TopicResourceModel]] = {} + for topic in await get_topics_with_subscribers_async(): + for subscriber in topic.subscribers or []: + if notification_type is not None and subscriber.notification_type != notification_type: continue - user_topic_map.setdefault(subscriber["user_id"], {}) - user_topic_map[subscriber["user_id"]][topic["_id"]] = topic + user_topic_map.setdefault(subscriber.user_id, {}) + user_topic_map[subscriber.user_id][topic.id] = topic return user_topic_map diff --git a/newsroom/types/__init__.py b/newsroom/types/__init__.py index 644a7bafb..a5fdcf99c 100644 --- a/newsroom/types/__init__.py +++ b/newsroom/types/__init__.py @@ -7,7 +7,7 @@ from .common import SectionEnum from .search import AdvancedSearchParams from .user_roles import UserRole -from 
.products import ProductType, PRODUCT_TYPES, ProductResourceModel +from .products import ProductResourceModel from .cards import CardResourceModel, DashboardCardConfig, DashboardCardType, DashboardCardDict from .company import CompanyProduct, CompanyResource from .navigation import NavigationModel @@ -27,13 +27,19 @@ from .history import HistoryResourceModel from .wire import WireItem from .featured import FeaturedResourceModel +from .agenda import ( + AgendaItem, + AgendaItemType, + AgendaCoverage, + AgendaCoverageDelivery, + AgendaPlanningItem, + AgendaWorkflowState, +) __all__ = [ "SectionEnum", "AdvancedSearchParams", "UserRole", - "ProductType", - "PRODUCT_TYPES", "CardResourceModel", "DashboardCardConfig", "DashboardCardType", @@ -61,6 +67,12 @@ "WireItem", "FeaturedResourceModel", "ProductResourceModel", + "AgendaItem", + "AgendaItemType", + "AgendaPlanningItem", + "AgendaCoverage", + "AgendaCoverageDelivery", + "AgendaWorkflowState", ] diff --git a/newsroom/types/agenda.py b/newsroom/types/agenda.py new file mode 100644 index 000000000..d53c51fd5 --- /dev/null +++ b/newsroom/types/agenda.py @@ -0,0 +1,282 @@ +from typing import Annotated, Any +from datetime import datetime, timezone +from enum import Enum, unique + +from pydantic import Field, field_validator, model_validator + +from superdesk.core.resources import ResourceModel, fields, dataclass, ModelWithVersions +from content_api.items.model import Place, CVItemWithCode, CVItem, PubStatusType + + +def convert_value_to_bool(value: bool | None) -> bool: + """This allows to support None values for fields that require a boolean""" + + return bool(value) + + +def convert_none_to_utcnow(value: datetime | None) -> datetime: + """This allows to support None values for fields that require a datetime""" + + return datetime.now(timezone.utc) if value is None else value + + +def convert_none_to_list(value: list | None) -> list: + """This allows to support None values for fields that require a list""" + + return [] 
if value is None else value + + +@unique +class AgendaItemType(str, Enum): + EVENT = "event" + PLANNING = "planning" + + +@unique +class AgendaWorkflowState(str, Enum): + SCHEDULED = "scheduled" + KILLED = "killed" + CANCELLED = "cancelled" + RESCHEDULED = "rescheduled" + POSTPONED = "postponed" + + +@dataclass +class AgendaCVItem: + code: fields.Keyword + name: fields.Keyword + qcode: fields.Keyword | None = None + scheme: fields.Keyword | None = None + parent: fields.Keyword | None = None + translations: dict[str, dict[str, str | None]] | None = None + + +@dataclass +class EventRecurringRule: + frequency: str | None = None + interval: int | None = None + endRepeatMode: str | None = None + until: datetime | None = None + count: int | None = None + _created_externally: bool | None = None + + +@dataclass +class AgendaDates: + start: datetime + end: datetime + tz: str | None = None + all_day: bool = False + no_end_time: bool = False + recurring_rule: EventRecurringRule | None = None + + # Field validators + _parse_no_end_time = field_validator("all_day", "no_end_time", mode="before")(convert_value_to_bool) + + +@dataclass +class AgendaDisplayDates: + date: datetime + + +@dataclass +class AgendaCoverageDelivery: + delivery_id: fields.Keyword | None = None + delivery_href: fields.Keyword | None = None + sequence_no: Annotated[int, fields.keyword_mapping()] = 0 + publish_time: datetime | None = None + delivery_state: fields.Keyword | None = None + + +@dataclass +class AgendaCoverage: + planning_id: fields.Keyword + coverage_id: fields.Keyword + scheduled: datetime + coverage_type: fields.Keyword + workflow_status: fields.Keyword + coverage_status: fields.Keyword + coverage_provider: fields.Keyword | None = None + slugline: fields.HTML | None = None + + delivery_id: fields.Keyword | None = None + delivery_href: fields.Keyword | None = None + publish_time: datetime | None = None + + time_to_be_confirmed: Annotated[bool, Field(alias="_time_to_be_confirmed")] = False + 
deliveries: list[AgendaCoverageDelivery] = Field(default_factory=list) + watches: Annotated[list[fields.ObjectId], fields.keyword_mapping()] = Field(default_factory=list) + + genre: list[CVItem] = Field(default_factory=list) + + assigned_desk_name: fields.Keyword | None = None + assigned_desk_email: Annotated[str | None, fields.not_indexed()] = None + assigned_user_name: fields.Keyword | None = None + assigned_user_email: Annotated[str | None, fields.not_indexed()] = None + + # Field validators + _parse_time_to_be_confirmed = field_validator("time_to_be_confirmed", mode="before")(convert_value_to_bool) + _parse_list_fields = field_validator( + "deliveries", + "watches", + "genre", + mode="before", + )(convert_none_to_list) + + +@dataclass +class EventLocation: + name: fields.TextWithKeyword + address: Annotated[dict | None, fields.dynamic_mapping()] = None + location: fields.Geopoint | None = None + qcode: fields.Keyword | None = None + geo: str | None = None + formatted_address: str | None = None + details: list[str] | None = None + + +@dataclass +class PlanningItemAgenda: + _id: fields.Keyword + name: fields.Keyword + + +@dataclass +class AgendaPlanningItem: + _id: fields.Keyword + guid: fields.Keyword + planning_date: datetime + state: fields.Keyword + pubstatus: Annotated[PubStatusType | None, fields.keyword_mapping()] = None + time_to_be_confirmed: Annotated[bool, Field(alias="_time_to_be_confirmed")] = False + firstcreated: datetime = Field(default_factory=datetime.now) + versioncreated: datetime = Field(default_factory=datetime.now) + language: fields.Keyword | None = None + source: fields.Keyword | None = None + name: fields.Keyword | None = None + slugline: fields.Keyword | None = None + description_text: str | None = None + headline: str | None = None + abstract: str | None = None + subject: Annotated[list[AgendaCVItem], fields.nested_list(include_in_parent=True)] = Field(default_factory=list) + urgency: int | None = None + service: list[AgendaCVItem] = 
Field(default_factory=list) + coverages: Annotated[list[dict[str, Any]], fields.mapping_disabled("object")] = Field(default_factory=list) + agendas: list[PlanningItemAgenda] = Field(default_factory=list) + ednote: str | None = None + internal_note: Annotated[str | None, fields.not_indexed()] = None + place: list[Place] = Field(default_factory=list) + state_reason: str | None = None + products: list[CVItemWithCode] = Field(default_factory=list) + + # Field validators + _parse_time_to_be_confirmed = field_validator("time_to_be_confirmed", mode="before")(convert_value_to_bool) + _parse_datetime_fields = field_validator("firstcreated", "versioncreated", mode="before")(convert_none_to_utcnow) + _parse_list_fields = field_validator( + "subject", + "service", + "coverages", + "agendas", + "place", + "products", + mode="before", + )(convert_none_to_list) + + +@dataclass +class CalendarItem: + qcode: fields.Keyword + name: fields.Keyword + schema: fields.Keyword | None = None + is_active: bool = True + translations: dict[str, Any] | None = None + + +class AgendaItem(ResourceModel, ModelWithVersions): + id: Annotated[str, Field(alias="_id")] + guid: fields.Keyword + content_type: Annotated[fields.Keyword, Field(alias="type")] = "agenda" + event_id: fields.Keyword | None = None + item_type: Annotated[AgendaItemType, fields.keyword_mapping()] + recurrence_id: fields.Keyword | None = None + name: fields.HTML | None = None + slugline: fields.HTML | None = None + definition_short: fields.HTML | None = None + definition_long: fields.HTML | None = None + description_text: fields.HTML | None = None + headline: fields.HTML | None = None + firstcreated: datetime = Field(default_factory=datetime.now) + versioncreated: datetime = Field(default_factory=datetime.now) + version: int | None = None + ednote: fields.HTML | None = None + registration_details: str | None = None + invitation_details: str | None = None + language: fields.Keyword | None = None + source: fields.Keyword | None = 
None + + urgency: Annotated[int | None, fields.keyword_mapping()] = None + priority: Annotated[int | None, fields.keyword_mapping()] = None + + place: list[Place] = Field(default_factory=list) + service: list[AgendaCVItem] = Field(default_factory=list) + + state_reason: str | None = None + subject: Annotated[list[AgendaCVItem], fields.nested_list(include_in_parent=True), Field(default_factory=list)] + dates: AgendaDates + display_dates: list[AgendaDisplayDates] = Field(default_factory=list) + + coverages: Annotated[list[AgendaCoverage], fields.nested_list(include_in_parent=True), Field(default_factory=list)] + + files: Annotated[list[dict], fields.mapping_disabled("object"), Field(default_factory=list)] + + state: AgendaWorkflowState + pubstatus: Annotated[PubStatusType | None, fields.keyword_mapping()] = None + calendars: list[CalendarItem] = Field(default_factory=list) + location: list[EventLocation] = Field(default_factory=list) + event: dict | None = None + + bookmarks: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + downloads: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + shares: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + prints: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + copies: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + watches: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + + products: Annotated[list[CVItemWithCode], Field(default_factory=list)] + planning_items: Annotated[ + list[AgendaPlanningItem], fields.nested_list(include_in_parent=True), Field(default_factory=list) + ] + + # Field/Model validators + _parse_datetime_fields = field_validator("firstcreated", "versioncreated", mode="before")(convert_none_to_utcnow) + _parse_list_fields = field_validator( + "place", + "service", + 
"subject", + "display_dates", + "coverages", + "files", + "calendars", + "location", + "bookmarks", + "downloads", + "shares", + "prints", + "copies", + "watches", + "products", + "planning_items", + mode="before", + )(convert_none_to_list) + + @model_validator(mode="before") + @classmethod + def parse_dict(cls, values) -> dict[str, Any]: + if not values.get("guid") and values.get("_id"): + # Make sure there is a ``guid`` + values["guid"] = values["_id"] + elif not values.get("_id") and values.get("guid"): + # Make sure there is a ``_id`` + values["_id"] = values["guid"] + + return values diff --git a/newsroom/types/company.py b/newsroom/types/company.py index b70352ded..a05f3a99b 100644 --- a/newsroom/types/company.py +++ b/newsroom/types/company.py @@ -8,13 +8,13 @@ from newsroom.core.resources import NewshubResourceModel, validate_ip_address, validate_auth_provider -from .products import ProductType +from .common import SectionEnum @dataclass class CompanyProduct: _id: Annotated[ObjectId, validate_data_relation_async("products")] - section: ProductType + section: SectionEnum seats: int = 0 def to_dict(self): @@ -51,3 +51,4 @@ class CompanyResource(NewshubResourceModel): auth_provider: Annotated[Optional[str], validate_auth_provider()] = None company_size: Optional[str] = None referred_by: Optional[str] = None + internal: bool = False diff --git a/newsroom/types/products.py b/newsroom/types/products.py index 0ae8e9c1b..45a3fc3bc 100644 --- a/newsroom/types/products.py +++ b/newsroom/types/products.py @@ -1,4 +1,3 @@ -from enum import Enum from bson import ObjectId from typing import Annotated @@ -7,13 +6,7 @@ from newsroom.core.resources.model import NewshubResourceModel from newsroom.core.resources.validators import validate_multi_field_iunique_value_async, validate_valid_objectid -PRODUCT_TYPES = ["wire", "agenda", "news_api"] - - -class ProductType(str, Enum): - WIRE = "wire" - AGENDA = "agenda" - NEWS_API = "news_api" +from .common import SectionEnum class 
ProductResourceModel(NewshubResourceModel): @@ -26,7 +19,7 @@ class ProductResourceModel(NewshubResourceModel): query: str | None = None planning_item_query: str | None = None is_enabled: bool = True - product_type: ProductType = ProductType.WIRE + product_type: SectionEnum = SectionEnum.WIRE navigations: list[ Annotated[ ObjectId, diff --git a/newsroom/types/users.py b/newsroom/types/users.py index dc8ba0cbf..fc66c0e4d 100644 --- a/newsroom/types/users.py +++ b/newsroom/types/users.py @@ -2,6 +2,8 @@ import pytz from pydantic import Field + +# from pydantic.dataclasses import dataclass from typing import Annotated, List, Optional from dataclasses import asdict from quart_babel import lazy_gettext @@ -34,8 +36,8 @@ def to_dict(self): @dataclass class NotificationScheduleModel: - timezone: str - times: List[str] + timezone: str | None = None + times: list[str] = Field(default_factory=list) last_run_time: Optional[datetime] = None pause_from: Optional[str] = None pause_to: Optional[str] = None @@ -117,6 +119,9 @@ def has_paused_notifications(self) -> bool: return False + def is_events_only_access(self, company: CompanyResource | None) -> bool: + return company.events_only if company and not self.is_admin() else False + class UserAuthResourceModel(UserResourceModel): password: Optional[Annotated[str, validate_minlength(8)]] = None diff --git a/newsroom/types/wire.py b/newsroom/types/wire.py index a05cff59c..729e7b3ff 100644 --- a/newsroom/types/wire.py +++ b/newsroom/types/wire.py @@ -2,18 +2,24 @@ from pydantic import Field -from superdesk.core.resources import fields +from superdesk.core.resources import fields, dataclass from content_api.items.model import ContentAPIItem, CVItemWithCode +@dataclass +class PublishedProduct: + code: fields.Keyword + name: fields.Keyword | None = None + + class WireItem(ContentAPIItem): - products: Annotated[list[CVItemWithCode], Field(default_factory=list)] + products: Annotated[list[PublishedProduct], Field(default_factory=list)] 
bookmarks: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] - downloads: Annotated[list[fields.Keyword], fields.keyword_mapping(), Field(default_factory=list)] - shares: Annotated[list[fields.Keyword], fields.keyword_mapping(), Field(default_factory=list)] - prints: Annotated[list[fields.Keyword], fields.keyword_mapping(), Field(default_factory=list)] - copies: Annotated[list[fields.Keyword], fields.keyword_mapping(), Field(default_factory=list)] + downloads: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + shares: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + prints: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] + copies: Annotated[list[fields.ObjectId], fields.keyword_mapping(), Field(default_factory=list)] # Overrides from ContentAPI Schema subject: Annotated[list[CVItemWithCode], fields.nested_list(include_in_parent=True), Field(default_factory=list)] diff --git a/newsroom/users/service.py b/newsroom/users/service.py index 6e856d7d1..281ccb57c 100644 --- a/newsroom/users/service.py +++ b/newsroom/users/service.py @@ -7,13 +7,12 @@ from quart_babel import gettext from werkzeug.exceptions import BadRequest -from newsroom.types import User from superdesk.core import get_app_config from superdesk.core.types import HTTP_METHOD, Request from superdesk.flask import abort from superdesk.utils import is_hashed, get_hash -from newsroom.types import UserResourceModel, UserAuthResourceModel +from newsroom.types import User, UserResourceModel, UserAuthResourceModel from newsroom.exceptions import AuthorizationError from newsroom.settings import get_setting from newsroom.auth.utils import ( @@ -156,7 +155,7 @@ async def get_by_email(self, email: str) -> UserResourceModel | None: lookup = {"$regex": re.compile("^{}$".format(re.escape(email)), re.IGNORECASE)} return await self.find_one(email=lookup) - async def 
update_notification_schedule_run_time(self, user: dict[str, Any], run_time: datetime): + async def update_notification_schedule_run_time(self, user: UserResourceModel, run_time: datetime): """ Updates the user's notification schedule with the provided run time and clears related cache. @@ -164,13 +163,19 @@ async def update_notification_schedule_run_time(self, user: dict[str, Any], run_ user: The user object containing the current notification schedule. run_time: The new run time to be updated in the notification schedule. """ - notification_schedule = deepcopy(user["notification_schedule"]) - notification_schedule["last_run_time"] = run_time - await self.update(user["_id"], {"notification_schedule": notification_schedule}) + + if not user.notification_schedule: + # No need to update the schedule if None is set + # A default is populated in the ``SendScheduledNotificationEmails`` class anyway + return + + notification_schedule = deepcopy(user.notification_schedule) + notification_schedule.last_run_time = run_time + await self.update(user.id, {"notification_schedule": notification_schedule}) app = self.app.wsgi - app.cache.delete(str(user["_id"])) - app.cache.delete(user["email"]) + app.cache.delete(str(user.id)) + app.cache.delete(user.email) @staticmethod def user_has_paused_notifications(user: User) -> bool: diff --git a/newsroom/users/users.py b/newsroom/users/users.py index eb70d0a25..bccd766d2 100644 --- a/newsroom/users/users.py +++ b/newsroom/users/users.py @@ -5,7 +5,7 @@ from superdesk.flask import request -from newsroom.types import PRODUCT_TYPES, UserRole +from newsroom.types import SectionEnum, UserRole from newsroom.auth.eve_auth import SessionAuth @@ -96,7 +96,7 @@ class UsersResource(newsroom.Resource): "type": "dict", "schema": { "_id": newsroom.Resource.rel("products", required=True), - "section": {"type": "string", "required": True, "allowed": PRODUCT_TYPES}, + "section": {"type": "string", "required": True, "allowed": [t.value for t in 
SectionEnum]}, }, }, }, diff --git a/newsroom/utils.py b/newsroom/utils.py index fd9da4492..8aeaa88ca 100644 --- a/newsroom/utils.py +++ b/newsroom/utils.py @@ -17,7 +17,7 @@ from quart_babel import gettext, format_date as _format_date from superdesk.core.types import Request -from superdesk.core import json, get_current_app, get_app_config +from superdesk.core import json, get_current_app, get_app_config, get_current_async_app from superdesk.flask import abort, request, g, url_for, Request as FlaskRequest from superdesk.json_utils import try_cast from superdesk.etree import parse_html @@ -105,24 +105,6 @@ def get_entity_or_404(_id, resource): return item -# TODO-ASYNC: Remove this once Agenda is migrated to async -def get_entities_elastic_or_mongo_or_404(_ids, resource): - """Finds item in elastic search as fist preference. If not configured, finds from mongo""" - elastic = get_current_app().data._search_backend(resource) - items = [] - if elastic: - for id in _ids: - item = elastic.find_one("items", req=None, _id=id) - if not item: - item = get_entity_or_404(id, resource) - - items.append(item) - else: - items = [get_entity_or_404(i, resource) for i in _ids] - - return items - - async def get_json_or_400(): data = await request.get_json() if not isinstance(data, dict): @@ -454,6 +436,13 @@ def get_vocabulary(id): return None +async def get_vocabulary_async(cv_id) -> dict[str, Any] | None: + vocabularies = get_current_async_app().mongo.get_db_async("items").get_collection("vocabularies") + if vocabularies is not None: + return await vocabularies.find_one({"_id": cv_id}) + return None + + def url_for_agenda(item, _external=True): """Get url for agenda item.""" return url_for("agenda.index", item=item["_id"], _external=_external) @@ -471,30 +460,6 @@ def set_version_creator(doc): doc["version_creator"] = get_user_id_from_request(None) -# TODO-ASYNC: Remove this once Agenda is upgraded to async -def get_items_for_user_action(_ids, item_type): - # Getting entities 
from elastic first so that we get all fields - # even those which are not a part of ItemsResource(content_api) schema. - items = get_entities_elastic_or_mongo_or_404(_ids, item_type) - - if not items: - return items - elif items[0].get("type") == "text": - for item in items: - if item.get("slugline") and item.get("anpa_take_key"): - item["slugline"] = "{0} | {1}".format(item["slugline"], item["anpa_take_key"]) - elif items[0].get("type") == "agenda": - # Import here to prevent circular imports - from newsroom.auth.utils import get_company_from_request - from newsroom.agenda.utils import remove_restricted_coverage_info - - company = get_company_from_request(None) - if company and company.restrict_coverage_info: - remove_restricted_coverage_info(items) - - return items - - def get_utcnow(): """added for unit tests""" return datetime.utcnow() diff --git a/newsroom/web/default_settings.py b/newsroom/web/default_settings.py index f62e19c6d..8d0e93370 100644 --- a/newsroom/web/default_settings.py +++ b/newsroom/web/default_settings.py @@ -115,7 +115,6 @@ BABEL_DEFAULT_TIMEZONE = DEFAULT_TIMEZONE BLUEPRINTS = [ - "newsroom.agenda", "newsroom.news_api.api_tokens", "newsroom.monitoring", ] @@ -170,7 +169,7 @@ "newsroom.wire.module", "newsroom.company_admin", "newsroom.public", - "newsroom.agenda", + "newsroom.agenda.module", "newsroom.products", "newsroom.design", "newsroom.auth_server.client", diff --git a/newsroom/wire/__init__.py b/newsroom/wire/__init__.py index ddf77efd4..0d83e655b 100644 --- a/newsroom/wire/__init__.py +++ b/newsroom/wire/__init__.py @@ -7,11 +7,13 @@ from newsroom.search.config import init_nested_aggregation from . 
import utils - -from .service import WireSearchServiceAsync +from .filters import WireSearchRequestArgs +from .service import WireSearchServiceAsync, WireItemService __all__ = [ "WireSearchServiceAsync", + "WireItemService", + "WireSearchRequestArgs", ] diff --git a/newsroom/wire/filters.py b/newsroom/wire/filters.py index 57db7b545..19224b28b 100644 --- a/newsroom/wire/filters.py +++ b/newsroom/wire/filters.py @@ -153,7 +153,7 @@ def _get_aggregation_field(key: str) -> str: def apply_aggs(request: NewshubSearchRequest) -> None: """Adds elasticsearch aggregations to the query, based on the request args""" - if request.args.page > 0 or not request.args.aggs: + if request.args.page > 0 or not request.args.aggs or request.search.aggs: return request.search.aggs = _get_wire_aggregations() diff --git a/newsroom/wire/items.py b/newsroom/wire/items.py index a88cbcf7a..b98ad794f 100644 --- a/newsroom/wire/items.py +++ b/newsroom/wire/items.py @@ -15,6 +15,7 @@ from superdesk.metadata.item import metadata_schema from newsroom.types import Article, CardResourceModel +from newsroom.products import ProductsService from newsroom.cards import get_card_size from .service import WireSearchServiceAsync @@ -99,8 +100,7 @@ def filter_fields(item: Article) -> Article: items_by_card = {} wire_search = WireSearchServiceAsync() - # TODO-ASYNC: Convert to async Product model when available - all_products = {product["_id"]: product for product in superdesk.get_resource_service("products").get_cached()} + all_products = {product.id: product async for product in ProductsService().get_all()} for card in cards: if card.config is not None and card.config.get("product"): items_by_card[card.label] = [ diff --git a/newsroom/wire/service.py b/newsroom/wire/service.py index f722499c5..8e19a89c0 100644 --- a/newsroom/wire/service.py +++ b/newsroom/wire/service.py @@ -5,31 +5,33 @@ from bson import ObjectId from content_api.items.model import PubStatusType -from superdesk.core.types import Request, 
Response, SearchRequest, ESQuery, ESBoolQuery +from superdesk.core.types import Request, Response, SearchRequest, ESQuery from superdesk.core import get_app_config from superdesk.core.resources import AsyncResourceService from superdesk.core.resources.cursor import ElasticsearchResourceCursorAsync -from newsroom.exceptions import AuthorizationError -from newsroom.types import SectionEnum, Product, TopicResourceModel, UserResourceModel, CompanyResource, WireItem -from newsroom.auth.utils import get_user_or_none_from_request, get_user_sections +from newsroom.types import ( + SectionEnum, + ProductResourceModel, + UserResourceModel, + CompanyResource, + WireItem, +) +from newsroom.auth.utils import get_user_or_none_from_request from newsroom.search.types import NewshubSearchRequest, SearchFilterFunction -from newsroom.search.base_service import BaseNewshubSearchService +from newsroom.search.base_web_service import BaseWebSearchService from newsroom.search.filters import ( apply_query_string, apply_date_range, apply_advanced_search, prefill_user, prefill_company, - prefill_products, - prefill_args_from_topic, apply_section_filter, apply_company_filter, apply_products_filter, validate_request, apply_ids_filter, ) -from newsroom.products.products import get_products_by_navigation from .filters import ( WireSearchRequestArgs, @@ -39,6 +41,7 @@ apply_embargoed_filters, apply_not_canceled_filter, apply_time_limit_filter, + apply_highlights, ) @@ -50,17 +53,18 @@ async def insert_versioned_document(self, doc_dict: dict[str, Any]): await super().insert_versioned_document(doc_dict) if doc_dict.get("pubstatus") == PubStatusType.CANCELLED.value: + # versioned_document = self._get_versioned_document(doc_dict) # If the update is a cancel, we need to cancel all versions await self.mongo_versioned_async.update_many( { - "_id_document": doc_dict["_id_document"], + "_id_document": doc_dict["_id"], "pubstatus": {"$ne": PubStatusType.CANCELLED.value}, }, {"$set": {"pubstatus": 
PubStatusType.CANCELLED.value}}, ) -class WireSearchServiceAsync(BaseNewshubSearchService[WireSearchRequestArgs, WireItem]): +class WireSearchServiceAsync(BaseWebSearchService[WireSearchRequestArgs, WireItem]): """Wire search service class, for searching for items while applying permissions and API request args""" search_args_class = WireSearchRequestArgs @@ -68,6 +72,28 @@ class WireSearchServiceAsync(BaseNewshubSearchService[WireSearchRequestArgs, Wir section = SectionEnum.WIRE default_sort = [{"versioncreated", -1}] default_page_size = 25 + service: WireItemService + + get_items_by_id_filters = [ + apply_item_type_filter, + apply_ids_filter, + ] + get_topic_items_query_execute_filters = [ + apply_products_filter, + apply_embargoed_filters, + apply_query_string, + apply_ids_filter, + apply_filters, + apply_advanced_search, + apply_date_range, + apply_highlights, + ] + get_topic_items_query_user_filters = [ + apply_section_filter, + apply_item_type_filter, + apply_company_filter, + apply_time_limit_filter, + ] def __init__(self): self.service = WireItemService() @@ -91,46 +117,22 @@ async def get_current_user_bookmarks_count(self, section: SectionEnum | None = N ) return await cursor.count() - async def get_items_by_id( - self, item_ids: list[str], args: WireSearchRequestArgs | None = None, apply_permissions: bool = False - ) -> ElasticsearchResourceCursorAsync[WireItem]: - """Searches for items by ID, optionally applying user/company permissions - - :param item_ids: A list of item IDs to search for - :param args: Optional set of request arguments to apply - :param apply_permissions: Whether to apply user/company permissions or not - :returns: Elasticsearch cursor with the results - """ - - if args is None: - args = WireSearchRequestArgs() - - args.ids = item_ids - return await self.search( - args, - filters=None - if apply_permissions - else [ - apply_item_type_filter, - apply_ids_filter, - ], - ) - - async def get_items_for_action(self, item_ids: list[str]) -> 
ElasticsearchResourceCursorAsync[WireItem]: + async def get_items_for_action(self, item_ids: list[str]) -> list[dict[str, Any]]: """Searches for item by ID, for use by downloads, sharing etc For each item, appends the ``anpa_take_key`` to the slugline if defined :param item_ids: A list of item IDs to search for - :returns: Elasticsearch cursor with the results + :returns: The list of Wire items """ cursor = await self.get_items_by_id(item_ids, args=WireSearchRequestArgs(ignore_latest=True)) - async for item in cursor: - if item.slugline and item.anpa_take_key: - item.slugline = f"{item.slugline} | {item.anpa_take_key}" + items = await cursor.to_list_raw() + for item in items: + if item.get("slugline") and item.get("anpa_take_key"): + item["slugline"] = item["slugline"] + " | " + item["anpa_take_key"] - return cursor + return items async def process_web_request(self, request: Request) -> Response: """Process q request from the WebAPI @@ -311,9 +313,8 @@ async def prepend_embargoed_items_to_response( cursor.hits["hits"]["hits"] = embargoed_cursor.hits["hits"]["hits"] + cursor.hits["hits"]["hits"] cursor.hits["hits"]["total"] = await embargoed_cursor.count() + await cursor.count() - # TODO-ASYNC: Convert to async Product model when available async def get_product_items_for_dashboard( - self, product: Product, size: int, exclude_embargoed: bool = False + self, product: ProductResourceModel, size: int, exclude_embargoed: bool = False ) -> list[WireItem]: """Return items for the provided product for use with the dashboard @@ -328,7 +329,7 @@ def prefill_requested_product(request: NewshubSearchRequest): cursor = await self.search( WireSearchRequestArgs( - product_ids=[product["_id"]], + product_ids=[product.id], page_size=size, exclude_embargoed=not get_app_config("DASHBOARD_EMBARGOED") or exclude_embargoed, ), @@ -350,186 +351,6 @@ def prefill_requested_product(request: NewshubSearchRequest): ) return await cursor.to_list() - async def get_topic_items_query( - self, 
topic: TopicResourceModel | None, - user: UserResourceModel | None, - company: CompanyResource | None, - query: ESQuery | None = None, - args: WireSearchRequestArgs | None = None, - ) -> ESQuery | None: - """Generate an elasticsearch query, based on topic, user and company - - :param topic: An optional Topic to be added to the request args - :param user: An optional User to be added to the request args - :param company: An optional Company to be added to the request args - :param query: An optional Elasticsearch query to start with - :param args: An optional request args to start with - :returns: The generated Elasticsearch query, or None if the supplied User does not have permission - """ - - def prefill_request(request: NewshubSearchRequest): - if topic: - request.topic = topic - if user: - request.user = request.current_user = user - request.is_admin = request.user.is_admin() - else: - request.is_admin = False - - if company: - request.company = company - if query: - request.search = query - - if user is None and topic is not None and topic.navigation is not None: - # TODO-ASYNC: Convert to Async service when it's available - request.products = get_products_by_navigation(topic.navigation) - - search_request = NewshubSearchRequest( - section=self.section, - web_request=None, - args=args or WireSearchRequestArgs(), - search=query or ESQuery(), - ) - - filters: list[SearchFilterFunction] = [ - # Pre-fill the request arguments - prefill_request, - prefill_args_from_topic, - # Apply standard filters used to match a topic - apply_products_filter, - apply_embargoed_filters, - apply_query_string, - apply_ids_filter, - apply_filters, - apply_advanced_search, - apply_date_range, - ] - - if user is not None: - # If this query is from a User's perspective, then add - # validation and section/company filters - filters.extend( - [ - prefill_products, - # Make sure the request has been validated - validate_request, - # Base topics: - apply_section_filter, - 
apply_item_type_filter, - apply_company_filter, - apply_time_limit_filter, - ] - ) - - try: - return await self.run_filters_and_return_query(search_request, filters) - except AuthorizationError: - if user and topic: - logger.info(f"Notification for user:{user.id} and topic:{topic.id} is skipped") - pass - - return None - - async def get_mathing_topics_for_item( - self, - item_id: str, - topics: list[TopicResourceModel], - users: list[UserResourceModel], - companies: dict[ObjectId, CompanyResource], - ) -> set[ObjectId]: - """Get a set of Topic IDs that match the supplied item - - :param item_id: The ID of the item to match topics against - :param topics: The list of Topics to match the item against - :param users: The list of Users to match the item against - :param companies: The list of Companies to match the item against - :returns: A set of Topic IDs that the wire item matches - """ - - return await self.get_matching_topics_for_query( - topics, - users, - companies, - ESQuery(query=ESBoolQuery(must=[{"term": {"_id": item_id}}])), - ) - - async def get_matching_topics_for_query( - self, - topics: list[TopicResourceModel], - users: list[UserResourceModel], - companies: dict[ObjectId, CompanyResource], - query: ESQuery | None = None, - ) -> set[ObjectId]: - """Get a set of Topic IDs that match the supplied query - - :param topics: The list of Topics to match the item against - :param users: The list of Users to match the item against - :param companies: The list of Companies to match the item against - :param query: The Elasticsearch query to match topics for - :returns: A set of Topic IDs that the wire item matches - """ - - topic_matches: set[ObjectId] = set() - topics_checked: set[ObjectId] = set() - - for user in users: - company = companies.get(user.company) if user.company else None - user_sections = get_user_sections(user, company) - if not user_sections.get(self.section): - continue - - if user.has_paused_notifications(): - continue - - aggs: dict[str, 
Any] = {"topics": {"filters": {"filters": {}}}} - - # There will be one base search for a user with aggs for user topics - search = await self.get_topic_items_query(None, user, company, query=query) - if not search: - continue - queried_topics: list[TopicResourceModel] = [] - for topic in topics: - if topic.user is None or topic.user != user.id: - continue - elif topic.id in topics_checked: - continue - topics_checked.add(topic.id) - - topic_query = await self.get_topic_items_query(topic, None, None) - if not topic_query: - continue - - try: - aggs["topics"]["filters"]["filters"][str(topic.id)] = topic_query.generate_query_dict()["query"] - queried_topics.append(topic) - except (KeyError, TypeError, IndexError): - continue - - if not len(queried_topics): - continue - - search.aggs = aggs - search_request = SearchRequest( - max_results=0, - aggregations=True, - elastic=search, - ) - - try: - search_results: ElasticsearchResourceCursorAsync[WireItem] = await self.service.find(search_request) - for topic in queried_topics: - try: - if search_results.hits["aggregations"]["topics"]["buckets"][str(topic.id)]["doc_count"] > 0: - topic_matches.add(topic.id) - except (KeyError, IndexError, TypeError): - logger.warning(f"Failed to find aggregation result for topic {topic.id}") - except Exception: - logger.exception("Error in get_matching_topics", extra=dict(query=search_request, user=user.id)) - - return topic_matches - async def get_matching_item_bookmarks( self, item_ids: list[str], users: dict[ObjectId, UserResourceModel], companies: dict[ObjectId, CompanyResource] ) -> set[ObjectId]: @@ -569,7 +390,9 @@ async def get_matching_item_bookmarks( return bookmark_users - async def get_product_item_report(self, product: Product) -> ElasticsearchResourceCursorAsync[WireItem]: + async def get_product_item_report( + self, product: ProductResourceModel + ) -> ElasticsearchResourceCursorAsync[WireItem]: """Returns aggregation results for items for the supplied product, grouped by 
date range :param product: The product to get the items for the report @@ -629,7 +452,7 @@ async def get_product_item_report(self, product: Product) -> ElasticsearchResour return await self.search( NewshubSearchRequest( - section=cast(SectionEnum | None, product.get("product_type")) or self.section or SectionEnum.WIRE, + section=cast(SectionEnum | None, product.product_type) or self.section or SectionEnum.WIRE, products=[product], args=WireSearchRequestArgs(page_size=0), search=ESQuery(aggs=aggs), diff --git a/newsroom/wire/utils.py b/newsroom/wire/utils.py index 9d74b3321..d1d73305d 100644 --- a/newsroom/wire/utils.py +++ b/newsroom/wire/utils.py @@ -1,7 +1,6 @@ -from superdesk.core import get_current_app -from superdesk.flask import request +from superdesk.core import get_current_async_app -from newsroom.auth.utils import get_user_id_from_request +from newsroom.auth.utils import get_user_id_from_request, is_from_request, get_current_request def get_picture(item): @@ -21,8 +20,7 @@ def get_caption(picture): return picture.get("description_text") or picture.get("body_text") -# TODO-ASYNC: Update this once Agenda is migrated to async -def update_action_list(items, action_list, force_insert=False, item_type="items"): +async def update_action_list(items, action_list, force_insert=False, item_type="items"): """ Stores user id into array of action_list of an item :param items: items to be updated @@ -32,16 +30,23 @@ def update_action_list(items, action_list, force_insert=False, item_type="items" :return: """ user_id = get_user_id_from_request(None) - if user_id: - app = get_current_app() - db = app.data.get_mongo_collection(item_type) - elastic = app.data._search_backend(item_type) - if request.method == "POST" or force_insert: - updates = {"$addToSet": {action_list: user_id}} - else: - updates = {"$pull": {action_list: user_id}} - for item_id in items: - result = db.update_one({"_id": item_id}, updates) - if result.modified_count: - modified = db.find_one({"_id": 
item_id}) - elastic.update(item_type, item_id, {action_list: modified[action_list]}) + if not user_id: + return + + app = get_current_async_app() + db = app.mongo.get_collection_async(item_type) + elastic = app.elastic.get_client_async(item_type) + + insert = force_insert + if not insert and is_from_request(): + insert = get_current_request().method == "POST" + + if insert: + updates = {"$addToSet": {action_list: user_id}} + else: + updates = {"$pull": {action_list: user_id}} + for item_id in items: + result = await db.update_one({"_id": item_id}, updates) + if result.modified_count: + modified = await db.find_one({"_id": item_id}) + await elastic.update(item_id, {action_list: modified[action_list]}) diff --git a/newsroom/wire/views.py b/newsroom/wire/views.py index eedc32463..a496cce11 100644 --- a/newsroom/wire/views.py +++ b/newsroom/wire/views.py @@ -11,7 +11,6 @@ from superdesk.core.types import Request, Response from superdesk.core import get_app_config, get_current_app -from superdesk import get_resource_service from superdesk.flask import render_template, send_file from superdesk.utc import utcnow @@ -55,7 +54,6 @@ get_location_string, get_public_contacts, get_links, - get_items_for_user_action, ) from newsroom.notifications import push_user_notification, push_notification, save_user_notifications from newsroom.template_filters import is_admin_or_internal @@ -73,7 +71,7 @@ from newsroom.history_async import HistoryService from .items import get_items_for_dashboard -from .service import WireSearchServiceAsync +from .service import WireSearchServiceAsync, WireItemService HOME_ITEMS_CACHE_KEY = "home_items" HOME_EXTERNAL_ITEMS_CACHE_KEY = "home_external_items" @@ -117,7 +115,7 @@ async def get_view_data() -> dict: topics = await get_user_topics_async(user) user_folders = await get_user_folders(user, "wire") if user else [] company_folders = await get_company_folders(company, "wire") if company else [] - products = await get_products_by_company(company_dict, 
product_type="wire") if company_dict else [] + products = await get_products_by_company(company_dict, product_type=SectionEnum.WIRE) if company_dict else [] ui_config_service = UiConfigResourceService() check_user_has_products(user, products) @@ -265,7 +263,6 @@ async def get_previous_versions(wire_item: WireItem) -> list[dict]: wire_item.ancestors, args=WireSearchRequestArgs(ignore_latest=True) ) ancestors = await cursor.to_list_raw() - # ancestors = await (await WireSearchServiceAsync().get_items_by_id(wire_item.ancestors)).to_list_raw() return sorted(ancestors, key=itemgetter("versioncreated"), reverse=True) return [] @@ -344,12 +341,12 @@ async def download(args: None, params: ItemActionUrlParams, request: Request): if item_type == "agenda": # Getting Event and/or Planning items - # TODO-ASYNC: Update when Agenda is migrated to async - items = get_items_for_user_action(data["items"], item_type) + from newsroom.agenda import AgendaSearchServiceAsync + + items = await AgendaSearchServiceAsync().get_items_for_action(data["items"]) else: # Getting Wire items - cursor = await WireSearchServiceAsync().get_items_for_action(data["items"]) - items = await cursor.to_list_raw() + items = await WireSearchServiceAsync().get_items_for_action(data["items"]) _file = io.BytesIO() formatter = get_current_app().as_any().download_formatters[_format]["formatter"] @@ -406,7 +403,7 @@ async def download(args: None, params: ItemActionUrlParams, request: Request): ) _file.seek(0) - update_action_list(data["items"], "downloads", force_insert=True) + await update_action_list(data["items"], "downloads", force_insert=True) await HistoryService().create_history_record(items, "download", user.id, user.company, params.type.value) return await send_file( _file, @@ -431,12 +428,12 @@ async def share(args: None, params: ItemActionUrlParams, request: Request) -> Re users_service = UsersService() if item_type == "agenda": # Getting Event and/or Planning items - # TODO-ASYNC: Update when Agenda is 
migrated to async - items = get_items_for_user_action(data.get("items"), item_type) + from newsroom.agenda import AgendaSearchServiceAsync + + items = await AgendaSearchServiceAsync().get_items_for_action(data.get("items")) else: # Getting Wire items - cursor = await WireSearchServiceAsync().get_items_for_action(data.get("items")) - items = await cursor.to_list_raw() + items = await WireSearchServiceAsync().get_items_for_action(data.get("items")) for user_id in data["users"]: user = await users_service.find_by_id(user_id) @@ -489,7 +486,7 @@ async def share(args: None, params: ItemActionUrlParams, request: Request) -> Re template=f"share_{item_type}", template_kwargs=template_kwargs, ) - update_action_list(data.get("items"), "shares", item_type=item_type) + await update_action_list(data.get("items"), "shares", item_type=item_type) await HistoryService().create_history_record( items, "share", current_user.id, current_user.company, params.type.value ) @@ -528,7 +525,7 @@ async def bookmark() -> Response: """ data = await get_json_or_400() assert data.get("items") - update_action_list(data.get("items"), "bookmarks", item_type="items") + await update_action_list(data.get("items"), "bookmarks", item_type="items") push_user_notification("saved_items", count=await WireSearchServiceAsync().get_current_user_bookmarks_count()) return Response("") @@ -541,34 +538,34 @@ class WireItemRouteArgs(BaseModel): async def copy(args: WireItemRouteArgs, params: ItemActionUrlParams, request: Request) -> Response: """Endpoint to copy Wire OR Agenda item(s)""" + from newsroom.agenda import AgendaItemService + item_type = get_type() - if item_type == "agenda": - item = get_resource_service("agenda").find_one(req=None, _id=args.item_id) - else: - item = (await WireSearchServiceAsync().service.find_by_id(args.item_id)).to_dict() + service = AgendaItemService() if item_type == "agenda" else WireItemService() + item_to_copy = (await service.find_by_id(args.item_id)).to_dict() - if not item: + 
if not item_to_copy: await request.abort(404) template_filename = "copy_agenda_item" if item_type == "agenda" else "copy_wire_item" locale = (get_session_locale() or "en").lower() template_name = get_language_template_name(template_filename, locale, "txt") - template_kwargs = {"item": item} + template_kwargs = {"item": item_to_copy} if item_type == "agenda": template_kwargs.update( { - "location": "" if item_type != "agenda" else get_location_string(item), - "contacts": get_public_contacts(item), - "calendars": ", ".join([calendar.get("name") for calendar in item.get("calendars") or []]), + "location": "" if item_type != "agenda" else get_location_string(item_to_copy), + "contacts": get_public_contacts(item_to_copy), + "calendars": ", ".join([calendar.get("name") for calendar in item_to_copy.get("calendars") or []]), "user_profile_data": await get_user_profile_data(), } ) copy_data = (await render_template(template_name, **template_kwargs)).strip() - update_action_list([args.item_id], "copies", item_type=item_type) + await update_action_list([args.item_id], "copies", item_type=item_type) user = get_user_from_request(request) - await HistoryService().create_history_record([item], "copy", user.id, user.company, params.type.value) + await HistoryService().create_history_record([item_to_copy], "copy", user.id, user.company, params.type.value) return Response({"data": copy_data}) @@ -587,6 +584,11 @@ class WireItemUrlParams(BaseModel): monitoring_profile: str | None = None type: SectionEnum = SectionEnum.WIRE + @field_validator("print", mode="before") + def parse_print(cls, value: str | bool | None) -> bool | str | None: + # Support this URL param as a toggle, if `print` is provided in the URL then it is `True` + return True if value == "" else value + @wire_endpoints.endpoint("/wire/") async def item(args: WireItemRouteArgs, params: WireItemUrlParams, request: Request) -> Response | str: @@ -620,7 +622,7 @@ async def item(args: WireItemRouteArgs, params: 
WireItemUrlParams, request: Requ else: template = "wire_item_print.html" - update_action_list([wire_item.id], "prints", force_insert=True) + await update_action_list([wire_item.id], "prints", force_insert=True) user = get_user_from_request(request) await HistoryService().create_history_record( [wire_item.to_dict()], "print", user.id, user.company, params.type.value @@ -648,7 +650,7 @@ async def items(args: WireItemsRouteArgs, params: WireItemUrlParams, request: Re wire_search = WireSearchServiceAsync() # First get the items directly from the resource service - items_cursor = await wire_search.service.search({"bool": {"query": {"must": [{"terms": {"_id": args.item_ids}}]}}}) + items_cursor = await wire_search.service.search({"_id": {"$in": args.item_ids}}, use_mongo=True) if not await items_cursor.count(): return Response([]) @@ -659,7 +661,9 @@ async def items(args: WireItemsRouteArgs, params: WireItemUrlParams, request: Re allowed_ids = {item.id async for item in allowed_items_cursor} # And set the item permissions for each item - async for item in items_cursor: - set_item_permission(item, item.id in allowed_ids) + response = [] + async for wire_item in items_cursor: + set_item_permission(wire_item, wire_item.id in allowed_ids) + response.append(wire_item.to_dict()) - return Response(await items_cursor.to_list_raw()) + return Response(response) diff --git a/setup.cfg b/setup.cfg index 744604b7f..d5cb7098f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,6 +9,7 @@ per-file-ignores = python_version = 3.10 allow_untyped_globals = True ignore_missing_imports = True +plugins = pydantic.mypy exclude = (env|e2e|node_modules|scripts|docker|dist|assets|docs) [mypy-newsroom.core] diff --git a/tests/commands/test_remove_expired_items.py b/tests/commands/test_remove_expired_items.py index 177e0ae07..d05dbca23 100644 --- a/tests/commands/test_remove_expired_items.py +++ b/tests/commands/test_remove_expired_items.py @@ -1,14 +1,15 @@ from datetime import datetime, timedelta from 
newsroom.commands.remove_expired import remove_expired from newsroom.utils import find_one +from tests.core.utils import create_entries_for async def test_remove_expired_items(app): items = [ - {"_id": "expired", "versioncreated": datetime(2020, 10, 1), "expiry": datetime.utcnow() - timedelta(days=1)}, + {"_id": "expired", "versioncreated": datetime(2020, 10, 1), "expiry": datetime.now() - timedelta(days=1)}, ] - app.data.insert("items", items) + await create_entries_for("items", items) await remove_expired(1) diff --git a/tests/core/test_agenda.py b/tests/core/test_agenda.py index 7ab9a2439..ae05c1f23 100644 --- a/tests/core/test_agenda.py +++ b/tests/core/test_agenda.py @@ -9,10 +9,10 @@ get_location_string, get_agenda_dates, get_public_contacts, - get_entity_or_404, get_local_date, get_end_date, ) +from newsroom.tests import markers from tests.utils import ( post_json, delete_json, @@ -22,7 +22,7 @@ ) from tests.utils import login from tests.fixtures import ADMIN_USER_ID, PUBLIC_USER_ID, PUBLIC_USER_EMAIL, COMPANY_1_ID -from .utils import add_company_products, create_entries_for +from .utils import add_company_products, create_entries_for, find_one_by_id from copy import deepcopy from bson import ObjectId @@ -34,7 +34,7 @@ "abstract": "abstract text", "_current_version": 1, "agendas": [], - "anpa_category": [{"name": "Entertainment", "subject": "01000000", "qcode": "e"}], + "anpa_category": [{"name": "Entertainment", "qcode": "e"}], "item_id": "foo", "ednote": "ed note here", "slugline": "Vivid planning item", @@ -66,7 +66,7 @@ "urgency": 3, "guid": "foo", "name": "This is the name of the vivid planning item", - "subject": [{"name": "library and museum", "qcode": "01009000", "parent": "01000000"}], + "subject": [{"name": "library and museum", "qcode": "01009000"}], "pubstatus": "usable", "type": "planning", } @@ -74,7 +74,7 @@ @fixture async def agenda_user(client, app): - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -112,21 
+112,30 @@ async def test_item_detail(client): async def test_item_json(client): resp = await client.get("/agenda/urn:conference?format=json") data = json.loads(await resp.get_data()) - assert "headline" in data + assert "name" in data assert "files" in data["event"] assert "internal_note" in data["event"] + + resp = await client.get("/agenda/urn:planning?format=json") + data = json.loads(await resp.get_data()) + assert "headline" in data assert "internal_note" in data["planning_items"][0] - assert "internal_note" in data["coverages"][0]["planning"] + assert "internal_note" in data["planning_items"][0]["coverages"][0]["planning"] async def test_item_json_does_not_return_files(client, app): await login(client, {"email": PUBLIC_USER_EMAIL}) data = await get_json(client, "/agenda/urn:conference?format=json") - assert "headline" in data + assert "name" in data assert "files" not in data["event"] assert "internal_note" not in data["event"] + + data = await get_json(client, "/agenda/urn:planning?format=json") + assert "headline" in data + # assert "files" not in data["event"] + # assert "internal_note" not in data["event"] assert "internal_note" not in data["planning_items"][0] - assert "internal_note" not in data["coverages"][0]["planning"] + assert "internal_note" not in data["planning_items"][0]["coverages"][0]["planning"] async def get_bookmarks_count(client, user): @@ -137,14 +146,14 @@ async def get_bookmarks_count(client, user): async def test_basic_search(client, agenda_user): - resp = await client.get("/agenda/search?q=headline") + resp = await client.get("/agenda/search?itemType=planning&q=headline") assert resp.status_code == 200, await resp.get_data(as_text=True) data = json.loads(await resp.get_data()) assert data["_meta"]["total"] async def test_search_with_accents(client, agenda_user): - resp = await client.get("/agenda/search?q=héadlíne") + resp = await client.get("/agenda/search?itemType=planning&q=héadlíne") assert resp.status_code == 200 data = 
json.loads(await resp.get_data()) assert data["_meta"]["total"] @@ -183,7 +192,7 @@ async def test_item_copy(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_share_items(client, app, mocker): - user_ids = app.data.insert( + user_ids = await create_entries_for( "users", [ { @@ -199,7 +208,7 @@ async def test_share_items(client, app, mocker): resp = await client.post( "/wire_share?type=agenda", json={ - "items": ["urn:conference"], + "items": ["urn:planning"], "users": [str(user_ids[0])], "message": "Some info message", }, @@ -212,10 +221,10 @@ async def test_share_items(client, app, mocker): assert "Hi Foo Bar" in outbox[0].body assert "admin admin (admin@sourcefabric.org) shared " in outbox[0].body assert "Conference Planning" in outbox[0].body - assert "http://localhost:5050/agenda?item=urn:conference" in outbox[0].body + assert "http://localhost:5050/agenda?item=urn:planning" in outbox[0].body assert "Some info message" in outbox[0].body - resp = await client.get("/agenda/{}?format=json".format("urn:conference")) + resp = await client.get("/agenda/{}?format=json".format("urn:planning")) data = json.loads(await resp.get_data()) assert "shares" in data @@ -245,7 +254,7 @@ async def test_agenda_search_filtered_by_query_product(client, app, public_compa ], ) - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -271,12 +280,7 @@ async def test_agenda_search_filtered_by_query_product(client, app, public_compa data = json.loads(await resp.get_data()) assert 1 == len(data["_items"]) assert "_aggregations" in data - assert "files" not in data["_items"][0]["event"] - assert "internal_note" not in data["_items"][0]["event"] - assert "internal_note" not in data["_items"][0]["planning_items"][0] - assert "internal_note" not in data["_items"][0]["planning_items"][0]["coverages"][0]["planning"] - assert "internal_note" not in data["_items"][0]["coverages"][0]["planning"] - resp = await 
client.get(f"/agenda/search?navigation={NAV_1}") + resp = await client.get(f"/agenda/search?itemType=planning&navigation={NAV_1}") data = json.loads(await resp.get_data()) assert 1 == len(data["_items"]) assert "_aggregations" in data @@ -320,19 +324,17 @@ async def test_watch_event(client, app): async def test_watch_coverages(client, app): - user_id = get_admin_user_id(app) - await post_json( client, "/agenda_coverage_watch", { "coverage_id": "urn:coverage", - "item_id": "urn:conference", + "item_id": "urn:planning", }, ) - after_watch_item = get_entity_or_404("urn:conference", "agenda") - assert after_watch_item["coverages"][0]["watches"] == [user_id] + after_watch_item = await find_one_by_id("agenda", "urn:planning") + assert after_watch_item["coverages"][0]["watches"] == [ObjectId(ADMIN_USER_ID)] async def test_unwatch_coverages(client, app): @@ -347,7 +349,7 @@ async def test_unwatch_coverages(client, app): }, ) - after_watch_item = get_entity_or_404("urn:conference", "agenda") + after_watch_item = await find_one_by_id("agenda", "urn:conference") assert after_watch_item["coverages"][0]["watches"] == [user_id] await delete_json( @@ -359,16 +361,13 @@ async def test_unwatch_coverages(client, app): }, ) - after_watch_item = get_entity_or_404("urn:conference", "agenda") + after_watch_item = await find_one_by_id("agenda", "urn:conference") assert after_watch_item["coverages"][0]["watches"] == [] async def test_remove_watch_coverages_on_watch_item(client, app): - user_id = ObjectId(get_admin_user_id(app)) - other_user_id = PUBLIC_USER_ID - test_planning_coverage_watches = deepcopy(test_planning) - test_planning_coverage_watches["coverages"][0]["watches"] = [other_user_id] + test_planning_coverage_watches["coverages"][0]["watches"] = [PUBLIC_USER_ID] await client.post( "/push", json=test_planning_coverage_watches, @@ -383,20 +382,19 @@ async def test_remove_watch_coverages_on_watch_item(client, app): }, ) - after_watch_coverage_item = 
get_entity_or_404(test_planning_coverage_watches["_id"], "agenda") - assert str(other_user_id) in after_watch_coverage_item["coverages"][0]["watches"] - assert user_id in after_watch_coverage_item["coverages"][0]["watches"] + after_watch_coverage_item = await find_one_by_id("agenda", test_planning_coverage_watches["_id"]) + assert PUBLIC_USER_ID in after_watch_coverage_item["coverages"][0]["watches"] + assert ObjectId(ADMIN_USER_ID) in after_watch_coverage_item["coverages"][0]["watches"] await post_json(client, "/agenda_watch", {"items": [test_planning_coverage_watches["_id"]]}) - after_watch_item = get_entity_or_404(test_planning_coverage_watches["_id"], "agenda") - assert after_watch_item["coverages"][0]["watches"] == [str(other_user_id)] - assert after_watch_item["watches"] == [user_id] + after_watch_item = await find_one_by_id("agenda", test_planning_coverage_watches["_id"]) + assert after_watch_item["coverages"][0]["watches"] == [PUBLIC_USER_ID] + assert after_watch_item["watches"] == [ObjectId(ADMIN_USER_ID)] async def test_fail_watch_coverages(client, app): await post_json(client, "/agenda_watch", {"items": ["urn:conference"]}) - after_watch_item = get_entity_or_404("urn:conference", "agenda") - print(after_watch_item["watches"]) + after_watch_item = await find_one_by_id("agenda", "urn:conference") assert after_watch_item["watches"] == [ObjectId(ADMIN_USER_ID)] request = { @@ -571,6 +569,8 @@ async def test_get_agenda_dates(): assert get_agenda_dates(agenda) == "May 27, 2018 - May 30, 2018" +@markers.skip_auto_agenda_items +@markers.skip_auto_wire_items async def test_filter_agenda_by_coverage_status(client): await client.post("/push", json=test_planning) @@ -583,7 +583,7 @@ async def test_filter_agenda_by_coverage_status(client): await client.post("/push", json=test_planning) test_planning["guid"] = "baz" - test_planning["planning_date"] = ("2018-05-28T10:45:52+0000",) + test_planning["planning_date"] = "2018-05-28T10:45:52+0000" test_planning["coverages"] 
= [] await client.post("/push", json=test_planning) @@ -657,7 +657,7 @@ async def test_filter_events_only(client): "abstract": "abstract text", "_current_version": 1, "agendas": [], - "anpa_category": [{"name": "Entertainment", "subject": "01000000", "qcode": "e"}], + "anpa_category": [{"name": "Entertainment", "qcode": "e"}], "item_id": "foo", "ednote": "ed note here", "slugline": "Vivid planning item", @@ -689,7 +689,7 @@ async def test_filter_events_only(client): "urgency": 3, "guid": "foo", "name": "This is the name of the vivid planning item", - "subject": [{"name": "library and museum", "qcode": "01009000", "parent": "01000000"}], + "subject": [{"name": "library and museum", "qcode": "01009000"}], "pubstatus": "usable", "type": "planning", } diff --git a/tests/core/test_agenda_events_only.py b/tests/core/test_agenda_events_only.py index bc3f64a78..efd1517fa 100644 --- a/tests/core/test_agenda_events_only.py +++ b/tests/core/test_agenda_events_only.py @@ -5,7 +5,7 @@ from newsroom.notifications import get_user_notifications from newsroom.tests import markers -from tests.core.utils import add_company_products, create_entries_for +from tests.core.utils import add_company_products, create_entries_for, update_entries_for, find_one_by_id from tests.fixtures import ( # noqa: F401 items, init_items, @@ -26,19 +26,19 @@ @fixture(autouse=True) async def set_events_only_company(app): - company = app.data.find_one("companies", None, _id=COMPANY_1_ID) + company = await find_one_by_id("companies", COMPANY_1_ID) assert company is not None updates = { "events_only": True, "sections": {"wire": True, "agenda": True}, "is_enabled": True, } - app.data.update("companies", COMPANY_1_ID, updates, company) - company = app.data.find_one("companies", None, _id=COMPANY_1_ID) + await update_entries_for("companies", COMPANY_1_ID, updates, company) + company = await find_one_by_id("companies", COMPANY_1_ID) assert company.get("events_only") is True - user = app.data.find_one("users", 
None, _id=PUBLIC_USER_ID) + user = await find_one_by_id("users", PUBLIC_USER_ID) assert user is not None - app.data.update("users", PUBLIC_USER_ID, {"is_enabled": True, "receive_email": True}, user) + await update_entries_for("users", PUBLIC_USER_ID, {"is_enabled": True, "receive_email": True}, user) @fixture @@ -61,7 +61,7 @@ async def agenda_products(app): ], ) - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -87,7 +87,7 @@ async def test_item_json(client): await login_public(client) resp = await client.get("/agenda/urn:conference?format=json") data = json.loads(await resp.get_data()) - assert "headline" in data + assert "slugline" in data assert "planning_items" not in data assert "coverages" not in data @@ -104,7 +104,7 @@ async def test_search(client, app, agenda_products): assert "planning_items" not in data["_items"][0] assert "coverages" not in data["_items"][0] - resp = await client.get(f"/agenda/search?navigation={NAV_1}") + resp = await client.get(f"/agenda/search?navigation={NAV_2}") data = json.loads(await resp.get_data()) assert 1 == len(data["_items"]) assert "_aggregations" in data @@ -128,7 +128,7 @@ async def set_watch_products(app): ], ) - app.data.insert( + await create_entries_for( "products", [ { @@ -289,7 +289,7 @@ async def test_watched_event_sends_notification_for_added_coverage(client, app, "ednote": "ed note here", "scheduled": "2018-05-29T10:51:52+0000", }, - "coverage_status": { + "news_coverage_status": { "name": "coverage intended", "label": "Planned", "qcode": "ncostat:int", diff --git a/tests/core/test_api_tokens.py b/tests/core/test_api_tokens.py index 421a28e55..9e804a689 100644 --- a/tests/core/test_api_tokens.py +++ b/tests/core/test_api_tokens.py @@ -2,6 +2,8 @@ from bson import ObjectId import urllib.parse +from tests.core.utils import create_entries_for + async def test_api_tokens_create(client): response = await client.post( @@ -44,7 +46,7 @@ async def 
test_api_tokens_create_only_one_per_company(client): async def test_api_tokens_patch(client, app): - data = app.data.insert( + data = await create_entries_for( "news_api_tokens", [ { diff --git a/tests/core/test_auth.py b/tests/core/test_auth.py index a6ac61070..4139f9883 100644 --- a/tests/core/test_auth.py +++ b/tests/core/test_auth.py @@ -12,6 +12,7 @@ from newsroom.tests.users import ADMIN_USER_EMAIL from newsroom.companies import CompanyServiceAsync from tests.utils import get_user_by_email, login, logout +from tests.core.utils import create_entries_for disabled_company = ObjectId() expired_company = ObjectId() @@ -20,7 +21,7 @@ @fixture(autouse=True) async def init(app): - app.data.insert( + await create_entries_for( "companies", [ { @@ -46,8 +47,8 @@ async def test_login_fails_for_wrong_username_or_password(client): async def test_login_fails_for_disabled_user(app, client): # Register a new account - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -71,8 +72,8 @@ async def test_login_fails_for_disabled_user(app, client): async def test_login_fails_for_user_with_disabled_company(app, client): # Register a new account - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -96,8 +97,8 @@ async def test_login_fails_for_user_with_disabled_company(app, client): async def test_login_succesfull_for_user_with_expired_company(app, client): # Register a new account - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -124,8 +125,8 @@ async def test_login_succesfull_for_user_with_expired_company(app, client): async def test_login_for_user_with_enabled_company_succeeds(app, client): # Register a new account - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -150,8 +151,8 @@ async def test_login_for_user_with_enabled_company_succeeds(app, client): async def 
test_login_fails_for_not_approved_user(app, client): # If user is created more than 14 days ago login fails - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -189,8 +190,8 @@ async def test_login_fails_for_many_times_gets_limited(client, app): async def test_account_is_locked_after_5_wrong_passwords(app, client): await logout(client) # Register a new account - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -229,8 +230,8 @@ async def test_account_is_locked_after_5_wrong_passwords(app, client): async def test_account_stays_unlocked_after_few_wrong_attempts(app, client): await logout(client) # Register a new account - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -297,8 +298,8 @@ async def test_account_appears_locked_for_non_existing_user(client): async def test_login_with_remember_me_selected_creates_permanent_session(app, client): # Register a new account - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -396,8 +397,8 @@ async def test_is_user_valid_empty_password(): async def test_login_for_public_user_if_company_not_assigned(client, app): - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -419,8 +420,8 @@ async def test_login_for_public_user_if_company_not_assigned(client, app): async def test_login_for_internal_user_if_company_not_assigned(client, app): - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId(), @@ -444,8 +445,8 @@ async def test_login_for_internal_user_if_company_not_assigned(client, app): async def test_access_for_disabled_user(app, client): # Register a new account user_id = ObjectId() - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": user_id, @@ -506,8 +507,8 @@ async def test_access_for_disabled_user(app, client): async def 
test_access_for_disabled_company(app, client): # Register a new account user_id = ObjectId() - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": user_id, @@ -533,8 +534,8 @@ async def test_access_for_disabled_company(app, client): async def test_access_for_not_approved_user(client, app): - user_ids = app.data.insert( - "users", + user_ids = await create_entries_for( + "auth_user", [ { "email": "foo2@bar.com", diff --git a/tests/core/test_auth_providers.py b/tests/core/test_auth_providers.py index 790ed9cee..8de1800e7 100644 --- a/tests/core/test_auth_providers.py +++ b/tests/core/test_auth_providers.py @@ -10,6 +10,7 @@ from newsroom.types import AuthProviderType from newsroom.tests import markers from tests.utils import get_user_by_email, logout +from tests.core.utils import create_entries_for companies = { @@ -28,7 +29,7 @@ async def init(app): {"_id": "saml", "name": "Azure", "auth_type": AuthProviderType.SAML}, ] ) - app.data.insert( + await create_entries_for( "companies", [ { @@ -66,8 +67,8 @@ async def test_password_auth_denies_other_auth_types(app, client): users_service = UsersService() user_id = ObjectId() - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": user_id, @@ -210,8 +211,8 @@ async def test_google_oauth_denies_other_auth_types(app, client): await logout(client) companies_service = get_resource_service("companies") user_id = ObjectId() - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": user_id, diff --git a/tests/core/test_command_remove_expired_agenda.py b/tests/core/test_command_remove_expired_agenda.py index 911a29531..68c461069 100644 --- a/tests/core/test_command_remove_expired_agenda.py +++ b/tests/core/test_command_remove_expired_agenda.py @@ -1,50 +1,73 @@ from typing import Dict, List, Optional, Any from datetime import timedelta -from superdesk import get_resource_service -from superdesk.utc import utcnow from eve.utils import date_to_str 
+from bson import ObjectId + +from superdesk.utc import utcnow + +from newsroom.types import AgendaItem from newsroom.commands.remove_expired_agenda import ( has_plan_expired, _get_expired_chain_ids, remove_expired_agenda, ) +from tests.core.utils import create_entries_for, get_all def now_minus_days(days: int): return utcnow() - timedelta(days=days) -def gen_plans(item: Dict[str, Any], plan_date: int, coverage_dates: Optional[List[int]] = None) -> Dict[str, Any]: +def gen_plans(item: Dict[str, Any], plan_date: int, coverage_dates: Optional[List[int]] = None) -> AgendaItem: + plan_id = str(ObjectId()) item.update( { "item_type": "planning", + "state": "scheduled", "dates": { "start": now_minus_days(plan_date) - timedelta(hours=1), "end": now_minus_days(plan_date), }, "coverages": [ - {"scheduled": date_to_str(now_minus_days(coverage_date))} for coverage_date in coverage_dates or [] + { + "scheduled": date_to_str(now_minus_days(coverage_date)), + "coverage_id": str(ObjectId()), + "planning_id": plan_id, + "coverage_type": "text", + "workflow_status": "active", + "coverage_status": "coverage intended", + } + for coverage_date in coverage_dates or [] ], } ) - return item + return AgendaItem.from_dict(item) -def gen_event(item: Dict[str, Any], event_end_date: int, plan_ids: Optional[List[str]] = None) -> Dict[str, Any]: +def gen_event(item: Dict[str, Any], event_end_date: int, plan_ids: Optional[List[str]] = None) -> AgendaItem: item.update( { "item_type": "event", + "state": "scheduled", "dates": { "start": now_minus_days(event_end_date) - timedelta(hours=1), "end": now_minus_days(event_end_date), }, - "planning_items": [{"_id": plan_id} for plan_id in plan_ids or []], + "planning_items": [ + { + "_id": plan_id, + "guid": plan_id, + "planning_date": now_minus_days(event_end_date) - timedelta(hours=1), + "state": "scheduled", + } + for plan_id in plan_ids or [] + ], } ) - return item + return AgendaItem.from_dict(item) async def test_has_plan_expired(): @@ -67,7 +90,7 
@@ async def test_get_expired_chain_ids(app): event2 = gen_event({"_id": "event2"}, 61, ["plan3", "plan4"]) event3 = gen_event({"_id": "event3"}, 61, ["plan5"]) event4 = gen_event({"_id": "event4"}, 61, ["plan6", "plan7"]) - app.data.insert( + await create_entries_for( "agenda", [ plan1, @@ -84,20 +107,20 @@ async def test_get_expired_chain_ids(app): ], ) - assert _get_expired_chain_ids(plan1, expiry_datetime) == [] - assert _get_expired_chain_ids(plan2, expiry_datetime) == ["plan2"] - assert _get_expired_chain_ids(event1, expiry_datetime) == ["event1"] - assert _get_expired_chain_ids(event2, expiry_datetime) == [ + assert await _get_expired_chain_ids(plan1, expiry_datetime) == set() + assert await _get_expired_chain_ids(plan2, expiry_datetime) == {"plan2"} + assert await _get_expired_chain_ids(event1, expiry_datetime) == {"event1"} + assert await _get_expired_chain_ids(event2, expiry_datetime) == { "event2", "plan3", "plan4", - ] - assert _get_expired_chain_ids(event3, expiry_datetime) == [] - assert _get_expired_chain_ids(event4, expiry_datetime) == [] + } + assert await _get_expired_chain_ids(event3, expiry_datetime) == set() + assert await _get_expired_chain_ids(event4, expiry_datetime) == set() async def test_remove_expired_agenda(runner, app): - app.data.insert( + ids = await create_entries_for( "agenda", [ # Items to keep (not yet expired) @@ -115,22 +138,24 @@ async def test_remove_expired_agenda(runner, app): gen_plans({"_id": "plan4", "event_id": "event2"}, 61, [61]), ], ) + print(ids) ids_to_keep = ["plan1", "plan5", "plan6", "plan7", "event3", "event4"] ids_to_purge = ["plan2", "plan3", "plan4", "event1", "event2"] # Test with default ``AGENDA_EXPIRY_DAYS=0`` (disable purging) - agenda_service = get_resource_service("agenda") - runner.invoke(remove_expired_agenda) - item_ids = [item["_id"] for item in agenda_service.find({})] + await remove_expired_agenda() + + item_ids = [item["_id"] for item in await get_all("agenda")] for item_id in ids_to_keep + 
ids_to_purge: - assert item_id in item_ids + assert item_id in item_ids, item_ids # Test with setting ``AGENDA_EXPIRY_DAYS=60`` (as a string) app.config["AGENDA_EXPIRY_DAYS"] = "60" - # remove_expired_agenda(60) - # item_ids = [item["_id"] for item in agenda_service.find({})] - # for item_id in ids_to_keep: - # assert item_id in item_ids - # for item_id in ids_to_purge: - # assert item_id not in item_ids + await remove_expired_agenda() + + item_ids = [item["_id"] for item in await get_all("agenda")] + for item_id in ids_to_keep: + assert item_id in item_ids + for item_id in ids_to_purge: + assert item_id not in item_ids diff --git a/tests/core/test_commands.py b/tests/core/test_commands.py index b5e183283..3acfd2070 100644 --- a/tests/core/test_commands.py +++ b/tests/core/test_commands.py @@ -16,6 +16,7 @@ from newsroom.tests.conftest import reset_elastic from ..fixtures import items, init_items, init_auth, init_company # noqa +from tests.core.utils import create_entries_for, delete_entries_for, find_one_by_id async def test_item_detail(app, client): @@ -66,7 +67,7 @@ async def test_index_from_mongo_collection(app, client): async def test_index_from_mongo_from_timestamp(app, client): - app.data.remove("items") + await delete_entries_for("items") sorted_items = [ { "_id": "tag:foo-1", @@ -79,7 +80,7 @@ async def test_index_from_mongo_from_timestamp(app, client): {"_id": "urn:bar-3", "_created": datetime.now() - timedelta(hours=3)}, ] - app.data.insert("items", sorted_items) + await create_entries_for("items", sorted_items) await reset_elastic(app) assert 0 == app.data.elastic.find("items", ParsedRequest(), {})[1] @@ -122,7 +123,7 @@ async def test_fix_topic_nested_filters(app, admin): init_nested_aggregation("items", WIRE_NESTED_SEARCH_FIELDS, app.config["WIRE_GROUPS"], _get_wire_aggregations()) await reset_elastic(app) - app.data.insert( + await create_entries_for( "items", [ { @@ -149,7 +150,7 @@ async def test_fix_topic_nested_filters(app, admin): ], ) topic_id 
= ObjectId() - app.data.insert( + await create_entries_for( "topics", [ { @@ -169,7 +170,7 @@ async def test_fix_topic_nested_filters(app, admin): await fix_topic_nested_filters() - updated_topic = app.data.find_one("topics", None, topic_id) + updated_topic = await find_one_by_id("topics", topic_id) assert "subject" not in updated_topic["filter"] assert len(updated_topic["filter"]["distribution"]) == 2 diff --git a/tests/core/test_companies.py b/tests/core/test_companies.py index ac572f0dd..da7532715 100644 --- a/tests/core/test_companies.py +++ b/tests/core/test_companies.py @@ -8,6 +8,7 @@ from newsroom.users.service import UsersService from tests.utils import logout +from tests.core.utils import create_entries_for, update_entries_for, find_one_by_id, find_one_for async def test_delete_company_deletes_company_and_users(client): @@ -106,7 +107,7 @@ async def test_get_company_users(client): async def test_save_company_permissions(client, app): await logout(client) sports_id = ObjectId() - app.data.insert( + await create_entries_for( "products", [ { @@ -123,7 +124,6 @@ async def test_save_company_permissions(client, app): "description": "news product", "is_enabled": True, "product_type": "wire", - "product_type": "wire", }, ], ) @@ -138,7 +138,7 @@ async def test_save_company_permissions(client, app): }, ) - updated = app.data.find_one("companies", req=None, _id=COMPANY_1_ID) + updated = await find_one_by_id("companies", COMPANY_1_ID) assert updated["sections"]["wire"] assert not updated["sections"].get("agenda") assert updated["archive_access"] @@ -149,9 +149,9 @@ async def test_save_company_permissions(client, app): assert resp.status_code == 200 # set company with wire only - user = app.data.find_one("users", req=None, first_name="admin") + user = await find_one_for("users", first_name="admin") assert user - app.data.update("users", user["_id"], {"company": COMPANY_1_ID, "user_type": UserRole.PUBLIC.value}, user) + await update_entries_for("users", user["_id"], 
{"company": COMPANY_1_ID, "user_type": UserRole.PUBLIC.value}, user) # refresh session with new type await logout(client) diff --git a/tests/core/test_copy_items.py b/tests/core/test_copy_items.py index ef80b87fd..f2038b3f4 100644 --- a/tests/core/test_copy_items.py +++ b/tests/core/test_copy_items.py @@ -1,5 +1,7 @@ from quart import json -from ..utils import load_fixture, add_fixture_to_db + +from .utils import create_entries_for +from ..utils import load_fixture, load_json_fixture def fix_spaces(input): @@ -7,7 +9,9 @@ def fix_spaces(input): async def test_copy_agenda(client, app): - item = add_fixture_to_db("agenda", "agenda_copy_fixture.json") + item = load_json_fixture("agenda_copy_fixture.json") + await create_entries_for("agenda", [item]) + # item = add_fixture_to_db("agenda", "agenda_copy_fixture.json") item_id = item["_id"] resp = await client.post(f"/wire/{item_id}/copy?type=agenda") @@ -19,7 +23,9 @@ async def test_copy_agenda(client, app): async def test_copy_wire(client, app): - item = add_fixture_to_db("items", "item_copy_fixture.json") + item = load_json_fixture("item_copy_fixture.json") + await create_entries_for("items", [item]) + # item = add_fixture_to_db("items", "item_copy_fixture.json") item_id = item["_id"] resp = await client.post(f"/wire/{item_id}/copy?type=wire") diff --git a/tests/core/test_csv_formatter.py b/tests/core/test_csv_formatter.py index 670fb1667..1fb56234b 100644 --- a/tests/core/test_csv_formatter.py +++ b/tests/core/test_csv_formatter.py @@ -1,7 +1,7 @@ from .test_push_events import test_event import copy -from newsroom.utils import get_entity_or_404 from newsroom.agenda.formatters import CSVFormatter +from tests.core.utils import find_one_by_id import csv @@ -31,12 +31,12 @@ {"name": "Statistics & Economic Indicators", "qcode": "150", "scheme": "event_types", "code": "150"}, {"name": "Economic Indicators", "qcode": "3", "scheme": "categories", "code": "3"}, { - "name": None, + "name": "custom something", "qcode": 
"20001237", "parent": "08000000", - "iptc_subject": None, - "ap_subject": None, - "in_jimi": False, + # "iptc_subject": None, + # "ap_subject": None, + # "in_jimi": False, "translations": {"name": {"en-CA": "anniversary", "fr-CA": None}}, "scheme": "subject_custom", "code": "20001237", @@ -58,8 +58,9 @@ def read_csv(data): async def test_csv_formatter_item(client, app): - await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + response = await client.post("/push", json=event) + print(await response.get_data(as_text=True)) + parsed = await find_one_by_id("agenda", event["guid"]) assert formatter.format_filename(parsed).endswith("new-press-conference.csv") @@ -139,7 +140,7 @@ async def test_csv_formatter_item(client, app): ], } await client.post("/push", json=event2) - parsed = get_entity_or_404(event2["guid"], "agenda") + parsed = await find_one_by_id("agenda", event2["guid"]) assert formatter.format_filename(parsed).endswith("latest-press-conference.csv") # update config diff --git a/tests/core/test_download.py b/tests/core/test_download.py index 56b04ac67..e07832eff 100644 --- a/tests/core/test_download.py +++ b/tests/core/test_download.py @@ -18,6 +18,7 @@ ) from .test_push import upload_binary from newsroom.history_async import HistoryService +from tests.core.utils import update_entries_for items_ids = [item["_id"] for item in items[:2]] item = items[:2][0] @@ -138,7 +139,7 @@ async def setup_image(client, app): }, } } - app.data.update("items", item["_id"], {"associations": associations}, item) + await update_entries_for("items", item["_id"], {"associations": associations}, item) async def test_download_single(client, app): diff --git a/tests/core/test_email_templates.py b/tests/core/test_email_templates.py index 56c86fd09..f95bc5702 100644 --- a/tests/core/test_email_templates.py +++ b/tests/core/test_email_templates.py @@ -3,6 +3,7 @@ from superdesk import get_resource_service from newsroom.email_templates import 
RESOURCE +from tests.core.utils import create_entries_for async def test_email_template_find_one(app): @@ -21,7 +22,7 @@ async def test_email_template_find_one(app): async def test_default_subjects(app): - app.data.insert( + await create_entries_for( RESOURCE, [ { @@ -39,7 +40,7 @@ async def test_default_subjects(app): async def test_get_subject_translation(app): - app.data.insert( + await create_entries_for( RESOURCE, [ { @@ -67,7 +68,7 @@ async def test_get_subject_translation(app): async def test_subject_translation_falls_back_to_default(app): - app.data.insert( + await create_entries_for( RESOURCE, [ { @@ -92,7 +93,7 @@ async def test_subject_translation_falls_back_to_default(app): async def test_get_subject_translation_with_template_variables(app): - app.data.insert( + await create_entries_for( RESOURCE, [ { @@ -120,7 +121,7 @@ async def test_get_subject_translation_with_template_variables(app): async def test_get_subject_falls_back_to_default_on_render_error(app): - app.data.insert( + await create_entries_for( RESOURCE, [ { @@ -142,7 +143,7 @@ async def test_get_subject_falls_back_to_default_on_render_error(app): async def test_get_from_mongo_returns_working_cursor(app): - app.data.insert( + await create_entries_for( RESOURCE, [ { diff --git a/tests/core/test_emails.py b/tests/core/test_emails.py index 42b61831d..b017e5958 100644 --- a/tests/core/test_emails.py +++ b/tests/core/test_emails.py @@ -1,4 +1,5 @@ import pathlib +from copy import deepcopy from quart import render_template_string, json, url_for from jinja2 import TemplateNotFound @@ -18,6 +19,7 @@ from newsroom.email import send_user_email from tests.fixtures import agenda_items from newsroom.tests import markers +from tests.core.utils import create_entries_for async def test_item_notification_template(client, app, mocker): @@ -127,7 +129,7 @@ def mock_get_template_include_fr_ca(template_name_or_list): @mock.patch("flask.current_app.jinja_env.get_or_select_template", mock_get_template_always_pass) 
async def test_map_email_recipients_by_language(client, app): - app.data.insert("users", MOCK_USERS) + await create_entries_for("users", MOCK_USERS) async with app.app_context(): email_groups = map_email_recipients_by_language(EMAILS, "test_template") @@ -159,7 +161,7 @@ async def test_map_email_recipients_by_language(client, app): mock_get_template_include_fr_ca, ) async def test_map_email_recipients_by_language_fallback(client, app): - app.data.insert("users", MOCK_USERS) + await create_entries_for("users", MOCK_USERS) async with app.app_context(): email_groups = map_email_recipients_by_language(EMAILS, "test_template") @@ -276,7 +278,10 @@ async def test_item_killed_notification_email(app): async def test_send_user_email_on_locale_changed(): - event_item = agenda_items[0] + event_item = deepcopy(agenda_items[0]) + event_item["coverages"] = deepcopy(agenda_items[1]["coverages"]) + event_item["planning_items"] = deepcopy(agenda_items[1]["planning_items"]) + # event_item = agenda_items[1] user = User( email="foo@example.com", @@ -289,14 +294,15 @@ async def test_send_user_email_on_locale_changed(): user["locale"] = "fr_CA" with mock.patch("newsroom.email.send_email") as send_email_mock: - template_kwargs = dict(item=agenda_items[0], planning_item=agenda_items[0]["planning_items"][0]) + template_kwargs = dict(item=event_item, planning_item=event_item["planning_items"][0]) await send_user_email(user, "test_template", template_kwargs=template_kwargs) + print(send_email_mock.call_args[1]["text_body"]) assert "Event status : Planifiée" in send_email_mock.call_args[1]["text_body"] assert "Coverage status: Planifiée" in send_email_mock.call_args[1]["text_body"] user["locale"] = "en" with mock.patch("newsroom.email.send_email") as send_email_mock: - template_kwargs = dict(item=agenda_items[0], planning_item=agenda_items[0]["planning_items"][0]) + template_kwargs = dict(item=event_item, planning_item=event_item["planning_items"][0]) await send_user_email(user, 
"test_template", template_kwargs=template_kwargs) assert "Event status : Planned" in send_email_mock.call_args[1]["text_body"] assert "Coverage status: Planned" in send_email_mock.call_args[1]["text_body"] diff --git a/tests/core/test_home.py b/tests/core/test_home.py index af9e1a93c..f44c028a5 100644 --- a/tests/core/test_home.py +++ b/tests/core/test_home.py @@ -3,11 +3,11 @@ from newsroom.wire.views import get_home_data from newsroom.tests.fixtures import PUBLIC_USER_ID -from tests.core.utils import create_entries_for +from tests.core.utils import create_entries_for, update_entries_for, find_one_by_id async def test_personal_dashboard_data(client, app, company_products): - user = app.data.find_one("users", req=None, _id=PUBLIC_USER_ID) + user = await find_one_by_id("users", PUBLIC_USER_ID) assert user topics = [ @@ -16,7 +16,7 @@ async def test_personal_dashboard_data(client, app, company_products): await create_entries_for("topics", topics) - app.data.update( + await update_entries_for( "users", PUBLIC_USER_ID, { diff --git a/tests/core/test_ical_formatter.py b/tests/core/test_ical_formatter.py index 54b0cfba0..ee9084363 100644 --- a/tests/core/test_ical_formatter.py +++ b/tests/core/test_ical_formatter.py @@ -6,9 +6,9 @@ from quart import json import newsroom.auth # noqa - Fix cyclic import when running single test file -from newsroom.utils import get_entity_or_404 from newsroom.agenda.formatters import iCalFormatter from .test_push_events import test_event +from tests.core.utils import find_one_by_id event = copy.deepcopy(test_event) event["ednote"] = "ed note" @@ -24,7 +24,7 @@ async def test_ical_formatter_item(client, app, mocker): await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) formatter = iCalFormatter() assert formatter.format_filename(parsed).endswith("new-press-conference.ical") diff --git a/tests/core/test_monitoring.py 
b/tests/core/test_monitoring.py index 786a5ea87..593330cbc 100644 --- a/tests/core/test_monitoring.py +++ b/tests/core/test_monitoring.py @@ -4,7 +4,7 @@ from pytest import fixture from bson import ObjectId -from tests.core.utils import create_entries_for +from tests.core.utils import create_entries_for, update_entries_for, find_one_by_id from newsroom.monitoring.email_alerts import MonitoringEmailAlerts from unittest import mock from tests.utils import mock_send_email, post_json, login_public @@ -63,7 +63,7 @@ async def init(app): ], ) - app.data.insert( + await create_entries_for( "monitoring", [ { @@ -235,7 +235,7 @@ async def test_get_companies_with_monitoring_schedules(client): @mock.patch("newsroom.monitoring.email_alerts.utcnow", mock_utcnow) @mock.patch("newsroom.email.send_email", mock_send_email) async def test_send_immediate_alerts(client, app): - app.data.insert( + await create_entries_for( "items", [ { @@ -275,15 +275,15 @@ def assert_recipients(outbox, recipients: List[str]): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_send_one_hour_alerts(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "one_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -294,7 +294,7 @@ async def test_send_one_hour_alerts(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -323,15 +323,15 @@ async def test_send_one_hour_alerts(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_send_two_hour_alerts(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await 
update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -342,7 +342,7 @@ async def test_send_two_hour_alerts(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -371,15 +371,15 @@ async def test_send_two_hour_alerts(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_send_four_hour_alerts(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "four_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -390,7 +390,7 @@ async def test_send_four_hour_alerts(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -421,9 +421,9 @@ async def test_send_four_hour_alerts(client, app): async def test_send_daily_alerts(client, app): now = utcnow() now = utc_to_local(app.config["DEFAULT_TIMEZONE"], now) - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), { @@ -434,7 +434,7 @@ async def test_send_daily_alerts(client, app): }, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -445,7 +445,7 @@ async def test_send_daily_alerts(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -456,7 +456,7 @@ async def test_send_daily_alerts(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -487,9 +487,9 @@ async def test_send_daily_alerts(client, app): async def test_send_weekly_alerts(client, app): now = utcnow() now = 
utc_to_local(app.config["DEFAULT_TIMEZONE"], now) - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), { @@ -501,7 +501,7 @@ async def test_send_weekly_alerts(client, app): }, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -512,7 +512,7 @@ async def test_send_weekly_alerts(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -523,7 +523,7 @@ async def test_send_weekly_alerts(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -552,15 +552,15 @@ async def test_send_weekly_alerts(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_send_alerts_respects_last_run_time(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -571,7 +571,7 @@ async def test_send_alerts_respects_last_run_time(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -599,7 +599,7 @@ async def test_send_alerts_respects_last_run_time(client, app): with app.mail.record_messages() as newoutbox: # async with app.app_context(): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None assert w.get("last_run_time") is not None last_run_time = local_to_utc(app.config["DEFAULT_TIMEZONE"], even_now) @@ -612,7 +612,7 @@ async def test_send_alerts_respects_last_run_time(client, app): 
@mock.patch("newsroom.email.send_email", mock_send_email) async def test_disabled_profile_wont_send_immediate_alerts(client, app): get_resource_service("monitoring").patch(ObjectId("5db11ec55f627d8aa0b545fb"), {"is_enabled": False}) - app.data.insert( + await create_entries_for( "items", [ { @@ -630,15 +630,15 @@ async def test_disabled_profile_wont_send_immediate_alerts(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_disabled_profile_wont_send_scheduled_alerts(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}, "is_enabled": False}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -649,7 +649,7 @@ async def test_disabled_profile_wont_send_scheduled_alerts(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -669,7 +669,7 @@ async def test_disabled_profile_wont_send_scheduled_alerts(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_always_send_immediate_alerts_wiont_send_default_email(client, app): get_resource_service("monitoring").patch(ObjectId("5db11ec55f627d8aa0b545fb"), {"always_send": True}) - app.data.insert( + await create_entries_for( "items", [ { @@ -687,14 +687,14 @@ async def test_always_send_immediate_alerts_wiont_send_default_email(client, app @mock.patch("newsroom.email.send_email", mock_send_email) async def test_always_send_schedule_alerts(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") - app.data.update( + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}, "always_send": True}, w, ) - 
app.data.insert( + await create_entries_for( "items", [ { @@ -713,14 +713,14 @@ async def test_always_send_schedule_alerts(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_disable_always_send_schedule_alerts(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") - app.data.update( + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}, "always_send": False}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -740,7 +740,7 @@ async def test_disable_always_send_schedule_alerts(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_always_send_immediate_alerts(client, app): get_resource_service("monitoring").patch(ObjectId("5db11ec55f627d8aa0b545fb"), {"always_send": False}) - app.data.insert( + await create_entries_for( "items", [ { @@ -759,7 +759,7 @@ async def test_always_send_immediate_alerts(client, app): @mock.patch("newsroom.monitoring.email_alerts.utcnow", mock_utcnow) @mock.patch("newsroom.email.send_email", mock_send_email) async def test_last_run_time_always_updated_with_matching_content_immediate(client, app): - app.data.insert( + await create_entries_for( "items", [ { @@ -783,7 +783,7 @@ async def test_last_run_time_always_updated_with_matching_content_immediate(clie assert outbox[0].subject == "Monitoring Subject" assert "Newsroom Monitoring: W1" in outbox[0].body assert "monitoring-export.pdf" in outbox[0].attachments[0] - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None assert w.get("last_run_time") is not None assert w["last_run_time"] > (mock_utcnow() - timedelta(minutes=5)) @@ -791,15 +791,15 @@ async def test_last_run_time_always_updated_with_matching_content_immediate(clie 
@mock.patch("newsroom.email.send_email", mock_send_email) async def test_last_run_time_always_updated_with_matching_content_scheduled(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -810,7 +810,7 @@ async def test_last_run_time_always_updated_with_matching_content_scheduled(clie } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -834,7 +834,7 @@ async def test_last_run_time_always_updated_with_matching_content_scheduled(clie assert outbox[0].subject == "Monitoring Subject" assert "Newsroom Monitoring: W1" in outbox[0].body assert "monitoring-export.pdf" in outbox[0].attachments[0] - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None assert w.get("last_run_time") is not None last_run_time = local_to_utc(app.config["DEFAULT_TIMEZONE"], even_now) @@ -843,7 +843,7 @@ async def test_last_run_time_always_updated_with_matching_content_scheduled(clie @mock.patch("newsroom.email.send_email", mock_send_email) async def test_last_run_time_always_updated_with_no_matching_content_immediate(client, app): - app.data.insert( + await create_entries_for( "items", [ { @@ -857,7 +857,7 @@ async def test_last_run_time_always_updated_with_no_matching_content_immediate(c with app.mail.record_messages() as outbox: await MonitoringEmailAlerts().run(immediate=True) assert len(outbox) == 0 - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None assert w.get("last_run_time") is not None assert w["last_run_time"] > 
(mock_utcnow() - timedelta(minutes=5)) @@ -865,15 +865,15 @@ async def test_last_run_time_always_updated_with_no_matching_content_immediate(c @mock.patch("newsroom.email.send_email", mock_send_email) async def test_last_run_time_always_updated_with_no_matching_content_scheduled(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -887,7 +887,7 @@ async def test_last_run_time_always_updated_with_no_matching_content_scheduled(c with app.mail.record_messages() as outbox: await MonitoringEmailAlerts().scheduled_worker(even_now) assert len(outbox) == 0 - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None assert w.get("last_run_time") is not None last_run_time = local_to_utc(app.config["DEFAULT_TIMEZONE"], even_now) @@ -896,10 +896,10 @@ async def test_last_run_time_always_updated_with_no_matching_content_scheduled(c @mock.patch("newsroom.email.send_email", mock_send_email) async def test_last_run_time_always_updated_with_no_users_immediate(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") - app.data.update("monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"users": []}, w) + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") + await update_entries_for("monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"users": []}, w) - app.data.insert( + await create_entries_for( "items", [ { @@ -913,7 +913,7 @@ async def test_last_run_time_always_updated_with_no_users_immediate(client, app) with app.mail.record_messages() as outbox: await MonitoringEmailAlerts().run(immediate=True) 
assert len(outbox) == 0 - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None assert w.get("last_run_time") is not None assert w["last_run_time"] > (mock_utcnow() - timedelta(minutes=5)) @@ -921,15 +921,15 @@ async def test_last_run_time_always_updated_with_no_users_immediate(client, app) @mock.patch("newsroom.email.send_email", mock_send_email) async def test_last_run_time_always_updated_with_no_users_scheduled(client, app): - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}, "users": []}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -943,7 +943,7 @@ async def test_last_run_time_always_updated_with_no_users_scheduled(client, app) with app.mail.record_messages() as outbox: await MonitoringEmailAlerts().scheduled_worker(even_now) assert len(outbox) == 0 - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None assert w.get("last_run_time") is not None last_run_time = local_to_utc(app.config["DEFAULT_TIMEZONE"], even_now) @@ -953,15 +953,15 @@ async def test_last_run_time_always_updated_with_no_users_scheduled(client, app) @mock.patch("newsroom.email.send_email", mock_send_email) async def test_will_send_one_hour_alerts_on_odd_hours(client, app): now = even_now.replace(hour=3, minute=0) - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": 
"one_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -972,7 +972,7 @@ async def test_will_send_one_hour_alerts_on_odd_hours(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -991,15 +991,15 @@ async def test_will_send_one_hour_alerts_on_odd_hours(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_wont_send_two_hour_alerts_on_odd_hours(client, app): now = even_now.replace(hour=3, minute=0) - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "two_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -1010,7 +1010,7 @@ async def test_wont_send_two_hour_alerts_on_odd_hours(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -1029,15 +1029,15 @@ async def test_wont_send_two_hour_alerts_on_odd_hours(client, app): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_wont_send_four_hour_alerts_on_odd_hours(client, app): now = even_now.replace(hour=3, minute=0) - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"schedule": {"interval": "four_hour"}}, w, ) - app.data.insert( + await create_entries_for( "items", [ { @@ -1048,7 +1048,7 @@ async def test_wont_send_four_hour_alerts_on_odd_hours(client, app): } ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -1072,7 +1072,7 @@ async def test_send_immediate_rtf_attachment_alerts(client, app): "/settings/general_settings", {"monitoring_report_logo_path": get_fixture_path("thumbnail.jpg")}, ) - 
app.data.insert( + await create_entries_for( "items", [ { @@ -1086,9 +1086,9 @@ async def test_send_immediate_rtf_attachment_alerts(client, app): } ], ) - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), { @@ -1116,7 +1116,7 @@ async def test_send_immediate_rtf_attachment_alerts(client, app): @mock.patch("newsroom.monitoring.email_alerts.utcnow", mock_utcnow) @mock.patch("newsroom.email.send_email", mock_send_email) async def test_send_immediate_headline_subject_alerts(client, app): - app.data.insert( + await create_entries_for( "items", [ { @@ -1127,9 +1127,9 @@ async def test_send_immediate_headline_subject_alerts(client, app): } ], ) - w = app.data.find_one("monitoring", None, _id="5db11ec55f627d8aa0b545fb") + w = await find_one_by_id("monitoring", "5db11ec55f627d8aa0b545fb") assert w is not None - app.data.update( + await update_entries_for( "monitoring", ObjectId("5db11ec55f627d8aa0b545fb"), {"headline_subject": True}, diff --git a/tests/core/test_navigations.py b/tests/core/test_navigations.py index 6d1dd042c..5f18202cd 100644 --- a/tests/core/test_navigations.py +++ b/tests/core/test_navigations.py @@ -9,8 +9,7 @@ from newsroom.tests.fixtures import COMPANY_1_ID from newsroom.navigations import get_navigations from newsroom.types import Product -from tests.core.utils import add_company_products, create_entries_for - +from tests.core.utils import add_company_products, create_entries_for, find_one_by_id NAV_ID = ObjectId("59b4c5c61d41c8d736852fbf") AGENDA_NAV_ID = ObjectId() @@ -111,11 +110,15 @@ async def test_delete_navigation_removes_references(client): async def test_create_navigation_with_products(client, app): - app.data.insert( + product_ids = [ + ObjectId(), + ObjectId(), + ] + await create_entries_for( "products", [ { - "_id": "p-1", + 
"_id": product_ids[0], "name": "Sport", "description": "sport product", "navigations": [], @@ -123,7 +126,7 @@ async def test_create_navigation_with_products(client, app): "product_type": "wire", }, { - "_id": "p-2", + "_id": product_ids[1], "name": "News", "description": "news product", "navigations": [], @@ -143,7 +146,7 @@ async def test_create_navigation_with_products(client, app): "description": "Breaking news", "product_type": "wire", "is_enabled": True, - "products": ["p-2"], + "products": [product_ids[1]], } ) }, @@ -157,16 +160,21 @@ async def test_create_navigation_with_products(client, app): response = await client.get("/products") data = json.loads(await response.get_data()) - assert [p for p in data if p["_id"] == "p-1"][0]["navigations"] == [] - assert [p for p in data if p["_id"] == "p-2"][0]["navigations"] == [nav_id] + print(data) + assert [p for p in data if p["_id"] == str(product_ids[0])][0]["navigations"] == [] + assert [p for p in data if p["_id"] == str(product_ids[1])][0]["navigations"] == [nav_id] async def test_update_navigation_with_products(client, app): - app.data.insert( + product_ids = [ + ObjectId(), + ObjectId(), + ] + await create_entries_for( "products", [ { - "_id": "p-1", + "_id": product_ids[0], "name": "Sport", "description": "sport product", "navigations": [], @@ -174,7 +182,7 @@ async def test_update_navigation_with_products(client, app): "product_type": "wire", }, { - "_id": "p-2", + "_id": product_ids[1], "name": "News", "description": "news product", "navigations": [NAV_ID], @@ -187,13 +195,13 @@ async def test_update_navigation_with_products(client, app): await test_login_succeeds_for_admin(client) await client.post( f"navigations/{NAV_ID}", - form={"navigation": json.dumps({"name": "Sports 2", "is_enabled": True, "products": ["p-1"]})}, + form={"navigation": json.dumps({"name": "Sports 2", "is_enabled": True, "products": [product_ids[0]]})}, ) response = await client.get("/products") data = json.loads(await 
response.get_data()) - assert [p for p in data if p["_id"] == "p-1"][0]["navigations"] == [str(NAV_ID)] - assert [p for p in data if p["_id"] == "p-2"][0]["navigations"] == [] + assert [p for p in data if p["_id"] == str(product_ids[0])][0]["navigations"] == [str(NAV_ID)] + assert [p for p in data if p["_id"] == str(product_ids[1])][0]["navigations"] == [] async def test_get_agenda_navigations_by_company_returns_ordered(client, app): @@ -209,7 +217,7 @@ async def test_get_agenda_navigations_by_company_returns_ordered(client, app): ], ) - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -232,7 +240,7 @@ async def test_get_agenda_navigations_by_company_returns_ordered(client, app): ) await test_login_succeeds_for_admin(client) - company = app.data.find_one("companies", req=None, _id=COMPANY_1_ID) + company = await find_one_by_id("companies", COMPANY_1_ID) navigations = await get_navigations(None, company, "agenda") assert navigations[0].get("name") == "Uber" navigations = await get_navigations(None, company, "wire") @@ -241,6 +249,7 @@ async def test_get_agenda_navigations_by_company_returns_ordered(client, app): async def test_get_products_by_navigation_caching(app): nav_id = ObjectId() + product_id = ObjectId() await create_entries_for( "navigations", [ @@ -253,11 +262,11 @@ async def test_get_products_by_navigation_caching(app): ], ) - app.data.insert( + await create_entries_for( "products", [ { - "_id": "p-2", + "_id": product_id, "name": "A News", "navigations": [nav_id], "description": "news product", @@ -312,7 +321,7 @@ async def test_get_navigations_for_user(public_user, public_company, app): ), ] - app.data.insert("products", products) + await create_entries_for("products", products) public_user["products"] = [get_product_ref(products[0]), get_product_ref(products[1])] navigations = await get_navigations(public_user, public_company, "wire") diff --git a/tests/core/test_products.py b/tests/core/test_products.py index 
344679939..36e34d17c 100644 --- a/tests/core/test_products.py +++ b/tests/core/test_products.py @@ -37,7 +37,7 @@ async def companies(app): {"name": "test3"}, ] - app.data.insert("companies", _companies) + await create_entries_for("companies", _companies) return _companies @@ -131,7 +131,7 @@ async def test_gets_all_products(client, app): await test_login_succeeds_for_admin(client) for i in range(250): - app.data.insert( + await create_entries_for( "products", [ { @@ -216,8 +216,8 @@ async def test_company_and_user_products(client, app, public_company, public_use "product_type": "wire", } - app.data.insert("products", [test_product]) - app.data.insert( + await create_entries_for("products", [test_product]) + await create_entries_for( "items", [ {"headline": "finance item", "type": "text", "versioncreated": datetime.utcnow()}, diff --git a/tests/core/test_push.py b/tests/core/test_push.py index b4253a48d..d2e081395 100644 --- a/tests/core/test_push.py +++ b/tests/core/test_push.py @@ -10,13 +10,13 @@ from quart.datastructures import FileStorage from newsroom.types import UserResourceModel, CompanyResource, UserRole, TopicResourceModel, SectionEnum -from newsroom.utils import get_company_dict_async, get_entity_or_404, get_user_dict_async +from newsroom.utils import get_company_dict_async, get_user_dict_async from newsroom.wire import WireSearchServiceAsync from newsroom.notifications import NotificationsService from newsroom.tests.fixtures import TEST_USER_ID # noqa - Fix cyclic import when running single test file from newsroom.tests import markers -from tests.core.utils import add_company_products, create_entries_for +from tests.core.utils import add_company_products, create_entries_for, update_entries_for, find_one_by_id from ..fixtures import COMPANY_1_ID, PUBLIC_USER_ID from ..utils import mock_send_email @@ -276,8 +276,8 @@ async def test_push_binary_invalid_signature(client, app): @markers.requires_async_celery async def 
test_notify_topic_matches_for_new_item(client, app, mocker): - user_ids = app.data.insert( - "users", + user_ids = await create_entries_for( + "auth_user", [ { "email": "foo2@bar.com", @@ -344,7 +344,7 @@ async def test_notify_topic_matches_for_new_item(client, app, mocker): @markers.requires_async_celery @mock.patch("newsroom.email.send_email", mock_send_email) async def test_notify_user_matches_for_new_item_in_history(client, app, mocker): - company_ids = app.data.insert( + company_ids = await create_entries_for( "companies", [ { @@ -363,9 +363,10 @@ async def test_notify_user_matches_for_new_item_in_history(client, app, mocker): "company": company_ids[0], } - user_ids = app.data.insert("users", [user]) + user_ids = await create_entries_for("auth_user", [user]) user["_id"] = user_ids[0] + # TODO-ASYNC-AGENDA: Replace this with proper function call app.data.insert( "history", docs=[ @@ -429,7 +430,7 @@ async def test_notify_user_matches_for_new_item_in_history(client, app, mocker): @markers.requires_async_celery @mock.patch("newsroom.email.send_email", mock_send_email) async def test_notify_user_matches_for_killed_item_in_history(client, app, mocker): - company_ids = app.data.insert( + company_ids = await create_entries_for( "companies", [ { @@ -448,9 +449,10 @@ async def test_notify_user_matches_for_killed_item_in_history(client, app, mocke "company": company_ids[0], } - user_ids = app.data.insert("users", [user]) + user_ids = await create_entries_for("auth_user", [user]) user["_id"] = user_ids[0] + # TODO-ASYNC-AGENDA: Replace this with proper function call app.data.insert( "history", docs=[ @@ -505,10 +507,10 @@ async def test_notify_user_matches_for_new_item_in_bookmarks(client, app, mocker "company": COMPANY_1_ID, } - user_ids = app.data.insert("users", [user]) + user_ids = await create_entries_for("auth_user", [user]) user["_id"] = user_ids[0] - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -523,7 +525,7 @@ async def 
test_notify_user_matches_for_new_item_in_bookmarks(client, app, mocker ], ) - app.data.insert( + await create_entries_for( "items", [ { @@ -571,7 +573,7 @@ async def test_notify_user_matches_for_new_item_in_bookmarks(client, app, mocker @markers.requires_async_celery async def test_do_not_notify_disabled_user(client, app, mocker): - app.data.insert( + await create_entries_for( "companies", [ { @@ -582,8 +584,8 @@ async def test_do_not_notify_disabled_user(client, app, mocker): ], ) - user_ids = app.data.insert( - "users", + user_ids = await create_entries_for( + "auth_user", [ { "email": "foo2@bar.com", @@ -605,8 +607,8 @@ async def test_do_not_notify_disabled_user(client, app, mocker): assert 201 == resp.status_code # disable user - user = app.data.find_one("users", req=None, _id=user_ids[0]) - app.data.update("users", user_ids[0], {"is_enabled": False}, user) + user = await find_one_by_id("users", user_ids[0]) + await update_entries_for("users", user_ids[0], {"is_enabled": False}, user) # clean cache app.cache.delete(str(user_ids[0])) @@ -624,7 +626,7 @@ async def test_do_not_notify_disabled_user(client, app, mocker): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_notify_checks_service_subscriptions(client, app, mocker): company_id = ObjectId() - app.data.insert( + await create_entries_for( "companies", [ { @@ -635,7 +637,7 @@ async def test_notify_checks_service_subscriptions(client, app, mocker): ], ) - user_ids = app.data.insert( + user_ids = await create_entries_for( "auth_user", [ { @@ -679,8 +681,8 @@ async def test_notify_checks_service_subscriptions(client, app, mocker): @markers.requires_async_celery @mock.patch("newsroom.email.send_email", mock_send_email) async def test_send_notification_emails(client, app): - user_ids = app.data.insert( - "users", + user_ids = await create_entries_for( + "auth_user", [ { "email": "foo2@bar.com", @@ -692,7 +694,7 @@ async def test_send_notification_emails(client, app): ], ) - app.data.insert( 
+ await create_entries_for( "topics", [ { @@ -810,13 +812,13 @@ async def test_matching_topics(client, app): ) ), ] - matching = await WireSearchServiceAsync().get_mathing_topics_for_item(item["guid"], topics, users, companies) + matching = await WireSearchServiceAsync().get_matching_topics_for_item(item["guid"], topics, users, companies) assert {topic_ids["created_from_future"], topic_ids["query"]} == matching async def test_matching_topics_for_public_user(client, app): app.config["WIRE_AGGS"]["genre"] = {"terms": {"field": "genre.name", "size": 50}} - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -885,7 +887,7 @@ async def test_matching_topics_for_public_user(client, app): ) ), ] - matching = await WireSearchServiceAsync().get_mathing_topics_for_item( + matching = await WireSearchServiceAsync().get_matching_topics_for_item( item["guid"], topics, list(users.values()), companies ) assert {topic_ids["created_from_future"], topic_ids["query"]} == matching @@ -893,7 +895,7 @@ async def test_matching_topics_for_public_user(client, app): async def test_matching_topics_for_user_with_inactive_company(client, app): app.config["WIRE_AGGS"]["genre"] = {"terms": {"field": "genre.name", "size": 50}} - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -961,7 +963,7 @@ async def test_matching_topics_for_user_with_inactive_company(client, app): ) ), ] - matching = await WireSearchServiceAsync().get_mathing_topics_for_item( + matching = await WireSearchServiceAsync().get_matching_topics_for_item( item["guid"], topics, list(users.values()), companies ) assert {topic_ids["created_from_future"], topic_ids["query"]} == matching @@ -969,7 +971,7 @@ async def test_matching_topics_for_user_with_inactive_company(client, app): async def test_push_parsed_item(client, app): await client.post("/push", json=item) - parsed = get_entity_or_404(item["guid"], "wire_search") + parsed = await find_one_by_id("items", item["guid"]) assert 
isinstance(parsed["firstcreated"], datetime) assert 2 == parsed["wordcount"] assert 7 == parsed["charcount"] @@ -979,7 +981,7 @@ async def test_push_parsed_dates(client, app): payload = item.copy() payload["embargoed"] = "2019-01-31T00:01:00+00:00" await client.post("/push", json=payload) - parsed = get_entity_or_404(item["guid"], "items") + parsed = await find_one_by_id("items", item["guid"]) assert isinstance(parsed["firstcreated"], datetime) assert isinstance(parsed["versioncreated"], datetime) assert isinstance(parsed["embargoed"], datetime) @@ -987,7 +989,7 @@ async def test_push_parsed_dates(client, app): async def test_push_event_coverage_info(client, app): await client.post("/push", json=item) - parsed = get_entity_or_404(item["guid"], "items") + parsed = await find_one_by_id("items", item["guid"]) assert parsed["event_id"] == "urn:event/1" assert parsed["coverage_id"] == "urn:coverage/1" @@ -995,7 +997,7 @@ async def test_push_event_coverage_info(client, app): async def test_push_wire_subject_whitelist(client, app): app.config["WIRE_SUBJECT_SCHEME_WHITELIST"] = ["b"] await client.post("/push", json=item) - parsed = get_entity_or_404(item["guid"], "items") + parsed = await find_one_by_id("items", item["guid"]) assert 1 == len(parsed["subject"]) assert "b" == parsed["subject"][0]["name"] @@ -1005,14 +1007,14 @@ async def test_push_custom_expiry(client, app): updated = item.copy() updated["source"] = "foo" await client.post("/push", json=updated) - parsed = get_entity_or_404(item["guid"], "items") + parsed = await find_one_by_id("items", item["guid"]) now = datetime.utcnow().replace(second=0, microsecond=0) expiry: datetime = parsed["expiry"].replace(tzinfo=None) assert now + timedelta(days=49) < expiry < now + timedelta(days=51) async def test_matching_topics_with_mallformed_query(client, app): - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -1058,14 +1060,14 @@ async def test_matching_topics_with_mallformed_query(client, app): 
), ] - matching = await WireSearchServiceAsync().get_mathing_topics_for_item( + matching = await WireSearchServiceAsync().get_matching_topics_for_item( item["guid"], topics, list(users.values()), companies ) assert {topic_ids["good"]} == matching async def test_matching_topics_when_disabling_section(client, app): - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -1109,7 +1111,7 @@ async def test_matching_topics_when_disabling_section(client, app): ), ] users[TEST_USER_ID].sections = {"wire": False, "agenda": True} - matching = await WireSearchServiceAsync().get_mathing_topics_for_item( + matching = await WireSearchServiceAsync().get_matching_topics_for_item( item["guid"], topics, list(users.values()), companies ) assert set() == matching diff --git a/tests/core/test_push_events.py b/tests/core/test_push_events.py index acb1236e6..be453ad81 100644 --- a/tests/core/test_push_events.py +++ b/tests/core/test_push_events.py @@ -1,5 +1,4 @@ import io -import pytz import pytest from datetime import datetime @@ -16,12 +15,12 @@ from newsroom.tests.users import ( ADMIN_USER_ID, ) # noqa - Fix cyclic import when running single test file -from newsroom.utils import get_entity_or_404 from newsroom.notifications import get_user_notifications from newsroom.tests import markers from .test_push import get_signature_headers from tests.utils import post_json, get_json, mock_send_email +from tests.core.utils import create_entries_for, update_entries_for, find_one_by_id @pytest.fixture @@ -94,7 +93,7 @@ def init_agenda_items(): "abstract": "abstract text", "_current_version": 1, "agendas": [], - "anpa_category": [{"name": "Entertainment", "subject": "01000000", "qcode": "e"}], + "anpa_category": [{"name": "Entertainment", "qcode": "e"}], "item_id": "bar", "ednote": "ed note here", "slugline": "Vivid planning item", @@ -112,7 +111,7 @@ def init_agenda_items(): "ednote": "ed note here", "scheduled": "2018-05-28T10:51:52+0000", }, - "coverage_status": { + 
"news_coverage_status": { "name": "coverage intended", "label": "Planned", "qcode": "ncostat:int", @@ -129,7 +128,7 @@ def init_agenda_items(): "ednote": "ed note here", "scheduled": "2018-05-28T10:51:52+0000", }, - "coverage_status": { + "news_coverage_status": { "name": "coverage intended", "label": "Planned", "qcode": "ncostat:int", @@ -183,12 +182,10 @@ def init_agenda_items(): async def test_push_parsed_event(client, app): event = deepcopy(test_event) await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert isinstance(parsed["firstcreated"], datetime) assert parsed["dates"]["tz"] == "Australia/Sydney" - assert parsed["dates"]["end"] == datetime.strptime("2018-05-28T05:00:00+0000", "%Y-%m-%dT%H:%M:%S+0000").replace( - tzinfo=pytz.UTC - ) + assert parsed["dates"]["end"] == datetime.strptime("2018-05-28T05:00:00+0000", "%Y-%m-%dT%H:%M:%S+0000") assert 1 == len(parsed["event"]["event_contact_info"]) assert 1 == len(parsed["location"]) assert 1 == len(parsed["service"]) @@ -217,7 +214,7 @@ async def test_push_cancelled_event(client, app): resp = await client.post("/push", json=event) assert resp.status_code == 200 - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert isinstance(parsed["firstcreated"], datetime) assert 1 == len(parsed["event"]["event_contact_info"]) assert 1 == len(parsed["location"]) @@ -238,7 +235,7 @@ async def test_push_updated_event(client, app): "tz": "Australia/Sydney", } await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert isinstance(parsed["firstcreated"], datetime) assert 1 == len(parsed["event"]["event_contact_info"]) assert 1 == len(parsed["location"]) @@ -270,7 +267,7 @@ async def test_push_parsed_planning_for_an_existing_event(client, app): event = deepcopy(test_event) 
event["guid"] = "foo4" await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert isinstance(parsed["firstcreated"], datetime) assert 1 == len(parsed["event"]["event_contact_info"]) assert 1 == len(parsed["location"]) @@ -279,13 +276,13 @@ async def test_push_parsed_planning_for_an_existing_event(client, app): planning["guid"] = "bar1" planning["event_item"] = "foo4" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo4", "agenda") + parsed = await find_one_by_id("agenda", "foo4") assert parsed["name"] == test_event["name"] assert parsed["definition_short"] == test_event["definition_short"] assert parsed["slugline"] == test_event["slugline"] assert parsed["definition_long"] == test_event["definition_long"] - assert parsed["dates"]["start"].isoformat() == test_event["dates"]["start"].replace("0000", "00:00") - assert parsed["dates"]["end"].isoformat() == test_event["dates"]["end"].replace("0000", "00:00") + assert parsed["dates"]["start"].isoformat() == test_event["dates"]["start"].replace("+0000", "") + assert parsed["dates"]["end"].isoformat() == test_event["dates"]["end"].replace("+0000", "") assert parsed["ednote"] == event["ednote"] assert 2 == len(parsed["coverages"]) @@ -312,7 +309,7 @@ async def test_push_coverages_with_different_dates_for_an_existing_event(client, event = deepcopy(test_event) event["guid"] = "foo4" await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert isinstance(parsed["firstcreated"], datetime) assert 1 == len(parsed["event"]["event_contact_info"]) assert 1 == len(parsed["location"]) @@ -325,7 +322,7 @@ async def test_push_coverages_with_different_dates_for_an_existing_event(client, # planning['planning_date'] = "2018-05-28T10:51:52+0000" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo4", 
"agenda") + parsed = await find_one_by_id("agenda", "foo4") assert parsed["name"] == test_event["name"] assert parsed["definition_short"] == test_event["definition_short"] assert parsed["slugline"] == test_event["slugline"] @@ -334,8 +331,8 @@ async def test_push_coverages_with_different_dates_for_an_existing_event(client, assert parsed_planning["description_text"] == planning["description_text"] assert 2 == len(parsed["coverages"]) - assert parsed["dates"]["start"].isoformat() == event["dates"]["start"].replace("0000", "00:00") - assert parsed["dates"]["end"].isoformat() == event["dates"]["end"].replace("0000", "00:00") + assert parsed["dates"]["start"].isoformat() == event["dates"]["start"].replace("+0000", "") + assert parsed["dates"]["end"].isoformat() == event["dates"]["end"].replace("+0000", "") assert 2 == len(parsed["display_dates"]) assert parsed["display_dates"][0]["date"].isoformat() == planning["coverages"][0]["planning"]["scheduled"].replace( "0000", "00:00" @@ -349,7 +346,7 @@ async def test_push_planning_with_different_dates_for_an_existing_event(client, event = deepcopy(test_event) event["guid"] = "foo4" await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert isinstance(parsed["firstcreated"], datetime) assert 1 == len(parsed["event"]["event_contact_info"]) assert 1 == len(parsed["location"]) @@ -362,12 +359,12 @@ async def test_push_planning_with_different_dates_for_an_existing_event(client, planning["planning_date"] = "2018-07-28T10:51:52+0000" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo4", "agenda") + parsed = await find_one_by_id("agenda", "foo4") assert parsed["name"] == test_event["name"] assert parsed["definition_short"] == test_event["definition_short"] assert parsed["slugline"] == test_event["slugline"] - assert parsed["dates"]["start"].isoformat() == event["dates"]["start"].replace("0000", "00:00") - assert 
parsed["dates"]["end"].isoformat() == event["dates"]["end"].replace("0000", "00:00") + assert parsed["dates"]["start"].isoformat() == event["dates"]["start"].replace("+0000", "") + assert parsed["dates"]["end"].isoformat() == event["dates"]["end"].replace("+0000", "") assert 1 == len(parsed["display_dates"]) assert parsed["display_dates"][0]["date"].isoformat() == planning["planning_date"].replace("0000", "00:00") @@ -380,7 +377,7 @@ async def test_push_cancelled_planning_for_an_existing_event(client, app): event = deepcopy(test_event) event["guid"] = "foo5" await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert isinstance(parsed["firstcreated"], datetime) assert 1 == len(parsed["event"]["event_contact_info"]) assert 1 == len(parsed["location"]) @@ -390,7 +387,7 @@ async def test_push_cancelled_planning_for_an_existing_event(client, app): planning["guid"] = "bar2" planning["event_item"] = "foo5" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo5", "agenda") + parsed = await find_one_by_id("agenda", "foo5") assert len(parsed["coverages"]) == 2 assert len(parsed["planning_items"]) == 1 @@ -400,7 +397,7 @@ async def test_push_cancelled_planning_for_an_existing_event(client, app): # second push await client.post("/push", json=planning) - parsed = get_entity_or_404("foo5", "agenda") + parsed = await find_one_by_id("agenda", "foo5") assert len(parsed["coverages"]) == 0 assert len(parsed["planning_items"]) == 0 @@ -417,7 +414,7 @@ async def test_push_parsed_adhoc_planning_for_an_non_existing_event(client, app) planning["event_item"] = None await client.post("/push", json=planning) - parsed = get_entity_or_404("bar3", "agenda") + parsed = await find_one_by_id("agenda", "bar3") assert isinstance(parsed["firstcreated"], datetime) assert 2 == len(parsed["coverages"]) assert 1 == len(parsed["planning_items"]) @@ -432,8 +429,8 @@ async def 
test_notify_topic_matches_for_new_event_item(client, app, mocker): event = deepcopy(test_event) await client.post("/push", json=event) - user_ids = app.data.insert( - "users", + user_ids = await create_entries_for( + "auth_user", [ { "email": "foo2@bar.com", @@ -478,8 +475,8 @@ async def test_notify_topic_matches_for_new_planning_item(client, app, mocker): event = deepcopy(test_event) await client.post("/push", json=event) - user_ids = app.data.insert( - "users", + user_ids = await create_entries_for( + "auth_user", [ { "email": "foo2@bar.com", @@ -531,8 +528,8 @@ async def test_notify_topic_matches_for_ad_hoc_planning_item(client, app, mocker planning["event_item"] = None client.post("/push", json=planning) - user_ids = app.data.insert( - "users", + user_ids = await create_entries_for( + "auth_user", [ { "email": "foo2@bar.com", @@ -576,7 +573,7 @@ async def test_notify_topic_matches_for_ad_hoc_planning_item(client, app, mocker @markers.requires_async_celery @mock.patch("newsroom.email.send_email", mock_send_email) async def test_notify_user_matches_for_ad_hoc_agenda_in_history(client, app, mocker): - company_ids = app.data.insert( + company_ids = await create_entries_for( "companies", [ { @@ -595,9 +592,10 @@ async def test_notify_user_matches_for_ad_hoc_agenda_in_history(client, app, moc "company": company_ids[0], } - user_ids = app.data.insert("users", [user]) + user_ids = await create_entries_for("auth_user", [user]) user["_id"] = user_ids[0] + # TODO-ASYNC-AGENDA: Replace this with the proper function call app.data.insert( "history", docs=[ @@ -638,7 +636,7 @@ async def test_notify_user_matches_for_ad_hoc_agenda_in_history(client, app, moc @markers.requires_async_celery @mock.patch("newsroom.email.send_email", mock_send_email) async def test_notify_user_matches_for_new_agenda_in_history(client, app, mocker): - company_ids = app.data.insert( + company_ids = await create_entries_for( "companies", [ { @@ -657,9 +655,10 @@ async def 
test_notify_user_matches_for_new_agenda_in_history(client, app, mocker "company": company_ids[0], } - user_ids = app.data.insert("users", [user]) + user_ids = await create_entries_for("auth_user", [user]) user["_id"] = user_ids[0] + # TODO-ASYNC-AGENDA: Replace this with the proper function call app.data.insert( "history", docs=[ @@ -698,7 +697,7 @@ async def test_notify_user_matches_for_new_planning_in_history(client, app, mock event = deepcopy(test_event) await client.post("/push", json=event) - company_ids = app.data.insert( + company_ids = await create_entries_for( "companies", [ { @@ -717,9 +716,10 @@ async def test_notify_user_matches_for_new_planning_in_history(client, app, mock "company": company_ids[0], } - user_ids = app.data.insert("users", [user]) + user_ids = await create_entries_for("auth_user", [user]) user["_id"] = user_ids[0] + # TODO-ASYNC-AGENDA: Replace this with the proper function call app.data.insert( "history", docs=[ @@ -761,7 +761,7 @@ async def test_notify_user_matches_for_killed_item_in_history(client, app, mocke event = deepcopy(test_event) await client.post("/push", json=event) - company_ids = app.data.insert( + company_ids = await create_entries_for( "companies", [ { @@ -780,9 +780,10 @@ async def test_notify_user_matches_for_killed_item_in_history(client, app, mocke "company": company_ids[0], } - user_ids = app.data.insert("users", [user]) + user_ids = await create_entries_for("auth_user", [user]) user["_id"] = user_ids[0] + # TODO-ASYNC-AGENDA: Replace this with the proper function call app.data.insert( "history", docs=[ @@ -875,8 +876,8 @@ async def test_push_story_wont_notify_for_first_publish(client, app, mocker): assert len(outbox) == 0 -def assign_active_company(app): - company_ids = app.data.insert( +async def assign_active_company(app): + company_ids = await create_entries_for( "companies", [ { @@ -886,8 +887,8 @@ def assign_active_company(app): ], ) - current_user = app.data.find_one("users", req=None, _id=ADMIN_USER_ID) - 
app.data.update("users", current_user["_id"], {"company": company_ids[0]}, current_user) + current_user = await find_one_by_id("users", ADMIN_USER_ID) + await update_entries_for("users", current_user["_id"], {"company": company_ids[0]}, current_user) return current_user["_id"] @@ -896,7 +897,7 @@ def assign_active_company(app): async def test_watched_event_sends_notification_for_event_update(client, app, mocker): event = deepcopy(test_event) await post_json(client, "/push", event) - user_id = assign_active_company(app) + user_id = await assign_active_company(app) await post_json(client, "/agenda_watch", {"items": [event["guid"]]}) # update comes in @@ -935,7 +936,7 @@ async def test_watched_event_sends_notification_for_unpost_event(client, app, mo planning = deepcopy(test_planning) await post_json(client, "/push", event) await post_json(client, "/push", planning) - user_id = assign_active_company(app) + user_id = await assign_active_company(app) await post_json(client, "/agenda_watch", {"items": [event["guid"]]}) # update the event for unpost @@ -968,7 +969,7 @@ async def test_watched_event_sends_notification_for_unpost_event(client, app, mo async def test_watched_event_sends_notification_for_added_planning(client, app, mocker): event = deepcopy(test_event) await post_json(client, "/push", event) - user_id = assign_active_company(app) + user_id = await assign_active_company(app) await post_json(client, "/agenda_watch", {"items": [event["guid"]]}) # planning comes in @@ -1004,7 +1005,7 @@ async def test_watched_event_sends_notification_for_cancelled_planning(client, a planning = deepcopy(test_planning) await post_json(client, "/push", event) await post_json(client, "/push", planning) - user_id = assign_active_company(app) + user_id = await assign_active_company(app) await post_json(client, "/agenda_watch", {"items": [event["guid"]]}) # update the planning for cancel @@ -1040,7 +1041,7 @@ async def test_watched_event_sends_notification_for_added_coverage(client, app, 
planning = deepcopy(test_planning) await post_json(client, "/push", event) await post_json(client, "/push", planning) - user_id = assign_active_company(app) + user_id = await assign_active_company(app) await post_json(client, "/agenda_watch", {"items": [event["guid"]]}) # update the planning with an added coverage @@ -1119,17 +1120,17 @@ async def test_push_cancelled_planning_cancels_adhoc_planning(client, app): planning["event_item"] = None await client.post("/push", json=planning) - parsed = get_entity_or_404("bar3", "agenda") + parsed = await find_one_by_id("agenda", "bar3") assert parsed["state"] == "scheduled" # cancel planning item planning["state"] = "cancelled" - planning["ednote"] = ("-------------------------\nPlanning cancelled\nReason: Test\n",) + planning["ednote"] = "-------------------------\nPlanning cancelled\nReason: Test\n" await client.post("/push", json=planning) - parsed = get_entity_or_404("bar3", "agenda") + parsed = await find_one_by_id("agenda", "bar3") assert parsed["state"] == "cancelled" - assert "Reason" in parsed["ednote"][0] + assert "Reason" in parsed["ednote"] async def test_push_update_for_an_item_with_coverage(client, app, mocker): @@ -1189,7 +1190,7 @@ async def test_push_coverages_with_linked_stories(client, app): planning["coverages"][0]["workflow_status"] = "completed" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) assert parsed["coverages"][0]["delivery_id"] == "item7" assert parsed["coverages"][0]["delivery_href"] == "/wire/item7" @@ -1197,7 +1198,7 @@ async def test_push_coverages_with_linked_stories(client, app): planning["coverages"][0]["deliveries"] = [] planning["coverages"][0]["workflow_status"] = "active" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) assert 
parsed["coverages"][0]["delivery_id"] is None assert parsed["coverages"][0]["delivery_href"] is None @@ -1220,7 +1221,7 @@ async def test_push_coverages_with_updates_to_linked_stories(client, app): planning["coverages"][0]["workflow_status"] = "completed" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) assert parsed["coverages"][0]["delivery_id"] == "item7" assert parsed["coverages"][0]["delivery_href"] == "/wire/item7" @@ -1234,7 +1235,7 @@ async def test_push_coverages_with_updates_to_linked_stories(client, app): ) await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) assert parsed["coverages"][0]["delivery_id"] == "item7" assert parsed["coverages"][0]["delivery_href"] == "/wire/item7" @@ -1248,7 +1249,7 @@ async def test_push_coverages_with_updates_to_linked_stories(client, app): ) await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) assert parsed["coverages"][0]["delivery_id"] == "item8" assert parsed["coverages"][0]["delivery_href"] == "/wire/item8" @@ -1271,7 +1272,7 @@ async def test_push_coverages_with_correction_to_linked_stories(client, app): planning["coverages"][0]["workflow_status"] = "completed" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) assert parsed["coverages"][0]["delivery_id"] == "item7" assert parsed["coverages"][0]["delivery_href"] == "/wire/item7" @@ -1286,7 +1287,7 @@ async def test_push_coverages_with_correction_to_linked_stories(client, app): ) await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = 
await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) # Coverage should point to the latest version assert parsed["coverages"][0]["delivery_id"] == "item8" @@ -1302,7 +1303,7 @@ async def test_push_coverages_with_correction_to_linked_stories(client, app): ) await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) # Coverage should still point to the latest version assert parsed["coverages"][0]["delivery_id"] == "item8" @@ -1315,7 +1316,7 @@ async def test_push_event_from_planning(client, app): plan["planning_date"] = "2018-05-29T00:00:00+0000" plan.pop("event_item", None) await post_json(client, "/push", plan) - parsed = get_entity_or_404(plan["guid"], "agenda") + parsed = await find_one_by_id("agenda", plan["guid"]) assert parsed["slugline"] == test_planning["slugline"] assert parsed["headline"] == test_planning["headline"] @@ -1333,7 +1334,7 @@ async def test_push_event_from_planning(client, app): event["guid"] = "retrospective_event" event["plans"] = ["adhoc_plan"] await post_json(client, "/push", event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert parsed["slugline"] == test_event["slugline"] assert parsed["definition_short"] == test_event["definition_short"] @@ -1343,8 +1344,8 @@ async def test_push_event_from_planning(client, app): assert "a" == parsed["service"][0]["code"] assert 1 == len(parsed["subject"]) assert "06002002" == parsed["subject"][0]["code"] - assert parsed["dates"]["start"].isoformat() == event["dates"]["start"].replace("0000", "00:00") - assert parsed["dates"]["end"].isoformat() == event["dates"]["end"].replace("0000", "00:00") + assert parsed["dates"]["start"].isoformat() == event["dates"]["start"].replace("+0000", "") + assert parsed["dates"]["end"].isoformat() == event["dates"]["end"].replace("+0000", "") async def 
test_coverages_delivery_sequence_has_default(client, app): @@ -1360,7 +1361,7 @@ async def test_coverages_delivery_sequence_has_default(client, app): planning["coverages"][0]["coverage_type"] = "text" await client.post("/push", json=planning) - parsed = get_entity_or_404("foo7", "agenda") + parsed = await find_one_by_id("agenda", "foo7") assert 2 == len(parsed["coverages"]) assert parsed["coverages"][0]["delivery_id"] == "item7" assert parsed["coverages"][0]["delivery_href"] == "/wire/item7" @@ -1389,9 +1390,9 @@ async def test_item_planning_reference_set_on_fulfill(client, app): await post_json(client, "/push", planning) - parsed = get_entity_or_404( + parsed = await find_one_by_id( + "items", "urn:newsml:localhost:2020-08-06T15:59:39.183090:1f02e9bb-3007-48f3-bfad-ffa6107f87bd", - "content_api", ) assert parsed["planning_id"] == "bar1" assert ( @@ -1406,17 +1407,17 @@ async def test_push_plan_with_date_before_event_start(client, app): event["plans"] = [planning["guid"]] await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert 0 == len(parsed["planning_items"]) # Change the Planning Date to before the Event's start date planning["planning_date"] = "2018-05-28T04:00:00+0000" await client.post("/push", json=planning) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert 1 == len(parsed["planning_items"]) assert 2 == len(parsed["coverages"]) - parsed = get_entity_or_404(planning["guid"], "agenda") + parsed = await find_one_by_id("agenda", planning["guid"]) assert 1 == len(parsed["planning_items"]) assert 2 == len(parsed["coverages"]) @@ -1431,7 +1432,7 @@ def on_push_planning(sender, item, is_new, **kwargs): planning = deepcopy(test_planning) await client.post("/push", json=planning) - parsed = get_entity_or_404(planning["guid"], "agenda") + parsed = await find_one_by_id("agenda", planning["guid"]) assert 
parsed and parsed["dates"]["all_day"] @@ -1445,7 +1446,7 @@ def on_push_event(sender, item, is_new, **kwargs): event = deepcopy(test_event) await client.post("/push", json=event) - parsed = get_entity_or_404(event["guid"], "agenda") + parsed = await find_one_by_id("agenda", event["guid"]) assert parsed and parsed["dates"]["all_day"] @@ -1462,9 +1463,11 @@ async def test_push_planning_coverages_assignemnt_info(client, app): "email": "john@example.com", } - await client.post("/push", json=planning) + response = await client.post("/push", json=planning) + print(response.status_code) + print(await response.get_data(as_text=True)) - parsed = get_entity_or_404(planning["guid"], "agenda") + parsed = await find_one_by_id("agenda", planning["guid"]) assert parsed coverage = parsed["coverages"][0] assert "Sports" == coverage["assigned_desk_name"] diff --git a/tests/core/test_push_evolved.py b/tests/core/test_push_evolved.py index 3c7076c32..eebcf887a 100644 --- a/tests/core/test_push_evolved.py +++ b/tests/core/test_push_evolved.py @@ -1,4 +1,5 @@ from unittest.mock import patch +from tests.core.utils import delete_entries_for def ids(items): @@ -6,7 +7,7 @@ def ids(items): async def test_evolved_from_order(client, app): - app.data.remove("items") + await delete_entries_for("items") async def push_item(data): resp = await client.post("/push", json=data) diff --git a/tests/core/test_realtime_notifications.py b/tests/core/test_realtime_notifications.py index 15b268ebc..21fa13986 100644 --- a/tests/core/test_realtime_notifications.py +++ b/tests/core/test_realtime_notifications.py @@ -8,13 +8,12 @@ from newsroom.push.tasks import notify_new_agenda_item, notify_new_wire_item from superdesk.utc import utcnow -from newsroom.tests import markers from newsroom.users import UsersService from newsroom.companies import CompanyServiceAsync from newsroom.notifications import NotificationsService from newsroom.tests.fixtures import COMPANY_1_ID, PUBLIC_USER_ID from newsroom.tests.users 
import ADMIN_USER_ID -from tests.core.utils import create_entries_for +from tests.core.utils import create_entries_for, update_entries_for, find_one_by_id from ..utils import mock_send_email @@ -45,7 +44,7 @@ async def test_realtime_notifications_wire(app, mocker, company_products): # we want only products which will filter out everything continue updates = {"navigations": [navigations[0]["_id"]]} - app.data.update("products", product["_id"], updates, product) + await update_entries_for("products", product["_id"], updates, product) await create_entries_for( "topics", @@ -214,7 +213,7 @@ async def test_realtime_notifications_agenda(app, mocker): topic_id = ObjectId() - app.data.insert( + await create_entries_for( "products", [ { @@ -227,13 +226,13 @@ async def test_realtime_notifications_agenda(app, mocker): ], ) - company = app.data.find_one("companies", req=None, _id=COMPANY_1_ID) + company = await find_one_by_id("companies", COMPANY_1_ID) assert company - app.data.update( + await update_entries_for( "companies", company["_id"], {"products": [{"_id": topic_id, "seats": 0, "section": "agenda"}]}, company ) - app.data.insert( + await create_entries_for( "agenda", [ { @@ -279,12 +278,14 @@ async def test_realtime_notifications_agenda(app, mocker): async def test_realtime_notifications_agenda_reccuring_event(app): - app.data.insert( + await create_entries_for( "agenda", [ { "_id": "event_id_1", "type": "agenda", + "item_type": "event", + "state": "scheduled", "versioncreated": utcnow(), "name": "cheese event", "dates": { @@ -296,6 +297,8 @@ async def test_realtime_notifications_agenda_reccuring_event(app): { "_id": "event_id_2", "type": "agenda", + "item_type": "event", + "state": "scheduled", "versioncreated": utcnow(), "name": "another event", "dates": { @@ -320,24 +323,26 @@ async def test_realtime_notifications_agenda_reccuring_event(app): assert notifier.notify_new_item.call_count == 2 -@markers.requires_async_celery +# @markers.requires_async_celery 
@mock.patch("newsroom.email.send_email", mock_send_email) async def test_pause_notifications(app, mocker, company_products): - user = app.data.find_one("users", req=None, _id=PUBLIC_USER_ID) + user = await find_one_by_id("users", PUBLIC_USER_ID) updates = { "notification_schedule": dict( pause_from=(datetime.now() - timedelta(days=1)).date().isoformat(), pause_to=(datetime.now() + timedelta(days=1)).date().isoformat(), ) } - app.data.update("users", user["_id"], updates, user) + await update_entries_for("users", user["_id"], updates, user) - app.data.insert( + await create_entries_for( "agenda", [ { "_id": "event_id_1", "type": "agenda", + "item_type": "event", + "state": "scheduled", "versioncreated": utcnow(), "name": "cheese event", "dates": { @@ -348,7 +353,7 @@ async def test_pause_notifications(app, mocker, company_products): ], ) - app.data.insert( + await create_entries_for( "items", [ { diff --git a/tests/core/test_reports.py b/tests/core/test_reports.py index 2ca1a8d6c..317220658 100644 --- a/tests/core/test_reports.py +++ b/tests/core/test_reports.py @@ -8,8 +8,8 @@ @fixture(autouse=True) async def init(app): - app.data.insert( - "users", + await create_entries_for( + "auth_user", [ { "_id": ObjectId("5cc94454bc43165c045ffec0"), @@ -36,7 +36,7 @@ async def init(app): }, ], ) - app.data.insert( + await create_entries_for( "products", [ { @@ -55,7 +55,7 @@ async def init(app): }, ], ) - app.data.insert( + await create_entries_for( "companies", [ { @@ -174,7 +174,7 @@ async def test_product_companies(client): async def test_expired_companies(client, app): - app.data.insert( + await create_entries_for( "companies", [ { diff --git a/tests/core/test_saml.py b/tests/core/test_saml.py index 81bdc53d2..92b0381ec 100644 --- a/tests/core/test_saml.py +++ b/tests/core/test_saml.py @@ -5,6 +5,7 @@ from newsroom.auth.saml import get_userdata from newsroom.companies import CompanyServiceAsync, CompanyResource +from tests.core.utils import create_entries_for, 
update_entries_for async def test_user_data_with_matching_company(app): @@ -12,7 +13,7 @@ async def test_user_data_with_matching_company(app): "name": "test", "auth_domains": ["example.com"], } - app.data.insert("companies", [company]) + await create_entries_for("companies", [company]) saml_data = { "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname": ["Foo"], @@ -35,7 +36,7 @@ async def test_user_data_with_matching_preconfigured_client(app, client): "auth_domains": ["samplecomp"], } - app.data.insert("companies", [company]) + await create_entries_for("companies", [company]) saml_data = { "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/givenname": ["Foo"], @@ -59,7 +60,7 @@ async def test_user_data_with_matching_preconfigured_client(app, client): assert user_data.get("company") == company["_id"] assert user_data.get("user_type") == "public" - app.data.update("companies", company["_id"], {"internal": True}, company) + await update_entries_for("companies", company["_id"], {"internal": True}, company) async with app.test_client() as c: resp = await c.get("/login/samplecomp") diff --git a/tests/core/test_search_config.py b/tests/core/test_search_config.py index 950a36f45..222fa89b4 100644 --- a/tests/core/test_search_config.py +++ b/tests/core/test_search_config.py @@ -3,8 +3,7 @@ from quart.testing import QuartClient from newsroom.factory.app import BaseNewsroomApp -from newsroom.agenda import AGENDA_NESTED_SEARCH_FIELDS -from newsroom.agenda.agenda import aggregations as agenda_aggregations +from newsroom.agenda.filters import aggregations as agenda_aggregations from newsroom.wire import WIRE_NESTED_SEARCH_FIELDS from newsroom.wire.filters import _get_wire_aggregations from newsroom.search.config import init_nested_aggregation @@ -131,7 +130,7 @@ async def test_custom_agenda_groups_config(app: BaseNewsroomApp, client: QuartCl }, } ) - init_nested_aggregation("agenda", AGENDA_NESTED_SEARCH_FIELDS, app.config["AGENDA_GROUPS"], agenda_aggregations) 
+ init_nested_aggregation("agenda", ["subject"], app.config["AGENDA_GROUPS"], agenda_aggregations) await reset_elastic(app) # Test if the Eve & agenda_aggregations config has been updated diff --git a/tests/core/test_send_scheduled_notifications.py b/tests/core/test_send_scheduled_notifications.py index c829be35d..4b44bd572 100644 --- a/tests/core/test_send_scheduled_notifications.py +++ b/tests/core/test_send_scheduled_notifications.py @@ -4,11 +4,16 @@ from newsroom.topics.topics_async import TopicService from superdesk.utc import utcnow, utc_to_local -from newsroom.types import NotificationSchedule, NotificationQueue, NotificationTopic +from newsroom.types import ( + NotificationQueue, + NotificationTopic, + UserResourceModel, + TopicResourceModel, + NotificationScheduleModel, +) from newsroom.notifications import NotificationQueueService from newsroom.notifications.commands import SendScheduledNotificationEmails - -from tests.core.utils import create_entries_for +from tests.core.utils import create_entries_for, get_all def test_convert_schedule_times(): @@ -110,7 +115,7 @@ async def test_get_latest_item_from_topic_queue(app, user): ], ) )[0] - topic = (await TopicService().find_by_id(topic_id)).to_dict() + topic = await TopicService().find_by_id(topic_id) await create_entries_for( "items", @@ -133,9 +138,14 @@ async def test_get_latest_item_from_topic_queue(app, user): ) command = SendScheduledNotificationEmails() - item = command.get_latest_item_from_topic_queue(topic_queue, topic, user, None, set()) + item = await command.get_latest_item_from_topic_queue( + topic_queue, topic, UserResourceModel.from_dict(user), None, set() + ) assert item["_id"] == "topic1_item1" + from superdesk.core import json + + print(json.dumps(item)) assert 'cheese' in item["es_highlight"]["body_html"][0] assert 'cheese' in item["es_highlight"]["slugline"][0] @@ -158,7 +168,7 @@ async def test_get_topic_entries_and_match_table(app, user): }, ], ) - user_topics = {topic["_id"]: topic 
for topic in app.data.find_all("topics")} + user_topics = {topic["_id"]: TopicResourceModel.from_dict(topic) for topic in await get_all("topics")} await create_entries_for( "items", [ @@ -186,7 +196,9 @@ async def test_get_topic_entries_and_match_table(app, user): ) command = SendScheduledNotificationEmails() - topic_entries, topic_match_table = command.get_topic_entries_and_match_table(schedule, user, None, user_topics) + topic_entries, topic_match_table = await command.get_topic_entries_and_match_table( + schedule, UserResourceModel.from_dict(user), None, user_topics + ) assert len(topic_entries["wire"]) == 1 assert topic_entries["wire"][0]["topic"]["label"] == "Cheesy Stuff" @@ -202,7 +214,7 @@ async def test_is_scheduled_to_run_for_user(): timezone = "Australia/Sydney" # Run schedule if ``last_run_time`` is not defined and ``force=True`` - assert command.is_scheduled_to_run_for_user({"timezone": timezone}, utcnow(), True) is True + assert command.is_scheduled_to_run_for_user(NotificationScheduleModel(timezone=timezone), utcnow(), True) is True times = ["07:00", "15:00", "20:00"] tests = [ @@ -222,12 +234,12 @@ def create_datetime_instance(hour: int, minute: int) -> datetime: return utc_to_local(timezone, utcnow()).replace(hour=hour, minute=minute, second=0, microsecond=0) for test in tests: - schedule: NotificationSchedule = { - "timezone": timezone, - "times": times, - } + schedule = NotificationScheduleModel( + timezone=timezone, + times=times, + ) if test.get("last_run"): - schedule["last_run_time"] = create_datetime_instance(test["last_run"][0], test["last_run"][1]) + schedule.last_run_time = create_datetime_instance(test["last_run"][0], test["last_run"][1]) now = create_datetime_instance(test["now"][0], test["now"][1]) diff --git a/tests/core/test_signup.py b/tests/core/test_signup.py index a548ddc27..0720828aa 100644 --- a/tests/core/test_signup.py +++ b/tests/core/test_signup.py @@ -2,11 +2,12 @@ from quart import url_for -from newsroom.types import 
CompanyType, Country, ProductType, CompanyProduct +from newsroom.types import CompanyType, Country, SectionEnum, CompanyProduct from newsroom.users import UsersAuthService from newsroom.companies.companies_async import CompanyService from tests.utils import get_user_by_email, mock_send_email +from tests.core.utils import create_entries_for, find_one_for @mock.patch("newsroom.email.send_email", mock_send_email) @@ -15,7 +16,7 @@ async def test_new_user_signup_sends_email(app, client): app.countries = [Country(value="AUS", text="Australia")] app.config["SIGNUP_EMAIL_RECIPIENTS"] = "admin@bar.com" app.config["COMPANY_TYPES"] = [CompanyType(id="news_media", name="News Media")] - product_ids = app.data.insert( + product_ids = await create_entries_for( "products", [{"name": "test", "query": "foo", "is_enabled": True, "product_type": "wire"}] ) with app.mail.record_messages() as outbox: @@ -66,13 +67,13 @@ async def test_new_user_signup_sends_email(app, client): assert new_company.products == [ CompanyProduct( _id=product_ids[0], - section=ProductType.WIRE, + section=SectionEnum.WIRE, seats=0, ) ] # Test that the new User has been created - new_user = app.data.find_one("users", req=None, email="newuser@abc.org") + new_user = await find_one_for("users", email="newuser@abc.org") assert new_user is not None assert new_user["first_name"] == "John" assert new_user["last_name"] == "Doe" @@ -138,7 +139,7 @@ async def test_approve_company_and_users(app, client): assert response.status_code == 200 # Test the Company & User are not enabled nor approved - new_company = app.data.find_one("companies", req=None, name="Doe Press Co.") + new_company = await find_one_for("companies", name="Doe Press Co.") assert new_company["is_enabled"] is False assert new_company["is_approved"] is False @@ -153,7 +154,7 @@ async def test_approve_company_and_users(app, client): assert response.status_code == 200 # Test the Company & User are now enabled and approved, but not yet validated - new_company = 
app.data.find_one("companies", req=None, name="Doe Press Co.") + new_company = await find_one_for("companies", name="Doe Press Co.") assert new_company["is_enabled"] is True assert new_company["is_approved"] is True @@ -187,10 +188,10 @@ async def test_approve_company_and_users(app, client): assert response.status_code == 200 # Test the Company is enabled and approved, but new User is not - new_company = app.data.find_one("companies", req=None, name="Doe Press Co.") + new_company = await find_one_for("companies", name="Doe Press Co.") assert new_company["is_enabled"] is True assert new_company["is_approved"] is True - new_user = app.data.find_one("users", req=None, email="jane@doe.org") + new_user = await find_one_for("users", email="jane@doe.org") assert new_user["is_enabled"] is False assert new_user["is_approved"] is False assert new_user["is_validated"] is False diff --git a/tests/core/test_user_dashboards.py b/tests/core/test_user_dashboards.py index 9664e1e61..c0fccebe8 100644 --- a/tests/core/test_user_dashboards.py +++ b/tests/core/test_user_dashboards.py @@ -8,7 +8,7 @@ from newsroom.topics.topics_async import TopicService from datetime import datetime -from tests.core.utils import create_entries_for +from tests.core.utils import create_entries_for, delete_entries_for, update_entries_for async def test_user_dashboards(app, client, public_user, public_company, company_products): @@ -16,11 +16,11 @@ async def test_user_dashboards(app, client, public_user, public_company, company topic_id = (await create_entries_for("topics", topics))[0] topic = await TopicService().find_by_id(topic_id) - app.data.remove("products") + await delete_entries_for("products") products = [{"name": "test", "query": "foo", "is_enabled": True, "product_type": "wire"}] - app.data.insert("products", products) + await create_entries_for("products", products) - assert app.data.update( + await update_entries_for( "companies", public_company["_id"], { @@ -32,13 +32,13 @@ async def 
test_user_dashboards(app, client, public_user, public_company, company public_company_instance = await CompanyServiceAsync().find_by_id(public_company["_id"]) assert 1 == len(public_company_instance.products) - app.data.insert( + await create_entries_for( "items", [ - {"guid": "test1", "headline": "foo", "versioncreated": datetime.utcnow()}, - {"guid": "test2", "headline": "bar", "versioncreated": datetime.utcnow()}, - {"guid": "test3", "headline": "baz", "versioncreated": datetime.utcnow()}, - {"guid": "test4", "headline": "foo bar", "versioncreated": datetime.utcnow()}, + {"_id": "test1", "guid": "test1", "headline": "foo", "versioncreated": datetime.utcnow()}, + {"_id": "test2", "guid": "test2", "headline": "bar", "versioncreated": datetime.utcnow()}, + {"_id": "test3", "guid": "test3", "headline": "baz", "versioncreated": datetime.utcnow()}, + {"_id": "test4", "guid": "test4", "headline": "foo bar", "versioncreated": datetime.utcnow()}, ], ) @@ -84,10 +84,10 @@ async def test_user_dashboards(app, client, public_user, public_company, company async def test_dashboard_data_for_user_without_wire_section(app): products = [ - {"product_type": "wire"}, + {"name": "Sports", "product_type": "wire"}, ] - app.data.insert("products", products) + await create_entries_for("products", products) topic = TopicResourceModel.from_dict( { diff --git a/tests/core/test_users.py b/tests/core/test_users.py index e3ad8601b..03156391f 100644 --- a/tests/core/test_users.py +++ b/tests/core/test_users.py @@ -18,7 +18,7 @@ from newsroom.signals import user_created, user_updated, user_deleted from unittest import mock -from tests.core.utils import create_entries_for +from tests.core.utils import create_entries_for, find_one_by_id from tests.utils import get_user_by_email, mock_send_email, login, login_public, logout @@ -527,7 +527,7 @@ async def test_signals(client, app): assert user["email"] == updated_user["email"] updated_listener.reset_mock() - token = app.data.find_one("auth_user", 
req=None, _id=user_id)["token"] + token = (await find_one_by_id("auth_user", user_id))["token"] resp = await client.get(f"/validate/{token}") assert 302 == resp.status_code, await resp.get_data(as_text=True) updated_listener.assert_called_once() @@ -580,7 +580,7 @@ async def update_user_schedule(data): # Update the schedules ``last_run_time`` now = utcnow() - await UsersService().update_notification_schedule_run_time(user, now) + await UsersService().update_notification_schedule_run_time(UserResourceModel.from_dict(user), now) user = await (await client.get(f"/users/{ADMIN_USER_ID}")).get_json() assert user["notification_schedule"]["timezone"] == "Australia/Sydney" diff --git a/tests/core/test_wire.py b/tests/core/test_wire.py index f22cfe634..0fed25eb9 100644 --- a/tests/core/test_wire.py +++ b/tests/core/test_wire.py @@ -8,14 +8,21 @@ from bson import ObjectId from superdesk.core import json +from superdesk.cache import cache -from newsroom.types import Product +from newsroom.types import ProductResourceModel, SectionEnum from newsroom.companies import CompanyServiceAsync from newsroom.search.types import NewshubSearchRequest from newsroom.wire import WireSearchServiceAsync from newsroom.wire.filters import WireSearchRequestArgs, apply_date_filters, apply_date_range -from tests.core.utils import add_company_products, create_entries_for +from tests.core.utils import ( + add_company_products, + create_entries_for, + delete_entries_for, + update_entries_for, + find_one_by_id, +) from ..fixtures import ( # noqa: F401 items, init_items, @@ -57,14 +64,14 @@ async def setup_products(app): ], ) - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ { "_id": PROD_1, "name": "product test", - "sd_product_id": 1, + "sd_product_id": "1", "navigations": [NAV_1], "product_type": "wire", "is_enabled": True, @@ -72,7 +79,7 @@ async def setup_products(app): { "_id": PROD_2, "name": "product test 2", - "sd_product_id": 2, + "sd_product_id": "2", "navigations": 
[NAV_2], "product_type": "wire", "is_enabled": True, @@ -96,8 +103,8 @@ async def test_item_json(client): @mock.patch("newsroom.email.send_email", mock_send_email) async def test_share_items(client, app): - user_ids = app.data.insert( - "users", + user_ids = await create_entries_for( + "auth_user", [ { "email": "foo2@bar.com", @@ -165,7 +172,7 @@ async def test_bookmarks(client, app): async def test_bookmarks_by_section(client, app): - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -229,7 +236,7 @@ async def test_search_filters_items_with_updates(client, app): async def test_search_includes_killed_items(client, app): - app.data.insert( + await create_entries_for( "items", [{"_id": "foo", "pubstatus": "canceled", "headline": "killed", "versioncreated": datetime.utcnow()}] ) resp = await client.get("/wire/search?q=headline:killed") @@ -238,7 +245,7 @@ async def test_search_includes_killed_items(client, app): async def test_search_by_products_id(client, app): - app.data.insert( + await create_entries_for( "items", [ { @@ -308,13 +315,13 @@ async def test_administrator_gets_all_results(client, app): async def test_search_filtered_by_users_products(client, app, public_user): - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ { "name": "product test", - "sd_product_id": 1, + "sd_product_id": "1", "is_enabled": True, "product_type": "wire", } @@ -369,7 +376,7 @@ async def test_search_filtered_by_query_product(client, app, public_user): ], ) - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -454,7 +461,7 @@ async def test_item_detail_access(client, app, public_user): assert not data.get("body_html") # add product - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -492,7 +499,7 @@ async def test_search_using_section_filter_for_public_user(client, app, public_u ], ) - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -513,7 +520,9 @@ async def 
test_search_using_section_filter_for_public_user(client, app, public_u ], ) + # Remove cached data g.pop("cached:navigations", None) + cache.clean() await login(client, public_user) resp = await client.get("/wire/search") @@ -525,11 +534,10 @@ async def test_search_using_section_filter_for_public_user(client, app, public_u assert 1 == len(data["_items"]) assert "_aggregations" in data - app.data.insert( + await create_entries_for( "section_filters", [ { - "_id": ObjectId(), "name": "product test 2", "query": "headline:Weather", "is_enabled": True, @@ -538,8 +546,9 @@ async def test_search_using_section_filter_for_public_user(client, app, public_u ], ) - g.section_filters = None + # Remove cached data g.pop("cached:section_filters", None) + cache.clean() resp = await client.get("/wire/search") data = json.loads(await resp.get_data()) @@ -556,7 +565,7 @@ async def test_search_using_section_filter_for_public_user(client, app, public_u async def test_administrator_gets_results_based_on_section_filter(client, app): await login(client, {"email": ADMIN_USER_EMAIL}) - app.data.insert( + await create_entries_for( "section_filters", [ { @@ -575,7 +584,7 @@ async def test_administrator_gets_results_based_on_section_filter(client, app): async def test_time_limited_access(client, app, public_user): - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -606,15 +615,15 @@ async def test_time_limited_access(client, app, public_user): assert 2 == len(data["_items"]) g.settings["wire_time_limit_days"]["value"] = 1 - company = app.data.find_one("companies", req=None, _id=COMPANY_1_ID) - app.data.update("companies", COMPANY_1_ID, {"archive_access": True}, company) + company = await find_one_by_id("companies", COMPANY_1_ID) + await update_entries_for("companies", COMPANY_1_ID, {"archive_access": True}, company) resp = await client.get("/wire/search") data = json.loads(await resp.get_data()) assert 2 == len(data["_items"]) async def 
test_company_type_filter(client, app, public_user): - add_company_products( + await add_company_products( app, COMPANY_1_ID, [ @@ -636,8 +645,8 @@ async def test_company_type_filter(client, app, public_user): dict(id="test", wire_must={"term": {"service.code": "b"}}), ] - company = app.data.find_one("companies", req=None, _id=COMPANY_1_ID) - app.data.update("companies", COMPANY_1_ID, {"company_type": "test"}, company) + company = await find_one_by_id("companies", COMPANY_1_ID) + await update_entries_for("companies", COMPANY_1_ID, {"company_type": "test"}, company) resp = await client.get("/wire/search") data = json.loads(await resp.get_data()) @@ -656,14 +665,14 @@ async def test_company_type_filter(client, app, public_user): async def test_search_by_products_and_filtered_by_embargoe(app, public_user): product_id = ObjectId() - product = Product( - _id=product_id, + product = ProductResourceModel( + id=product_id, name="product test", query="headline:china", is_enabled=True, - product_type="wire", + product_type=SectionEnum.WIRE, ) - add_company_products(app, COMPANY_1_ID, [product]) + await add_company_products(app, COMPANY_1_ID, [product.to_dict()]) # embargoed item is not fetched await create_entries_for( @@ -694,7 +703,7 @@ async def test_search_by_products_and_filtered_by_embargoe(app, public_user): assert 0 == len(items) # ex-embargoed item is fetched - app.data.insert( + await create_entries_for( "items", [ { @@ -731,7 +740,7 @@ async def test_wire_delete(client, app): async def test_highlighting(client, app): - app.data.insert( + await create_entries_for( "items", [ { @@ -770,7 +779,7 @@ async def test_highlighting(client, app): async def test_highlighting_with_advanced_search(client, app): - app.data.insert( + await create_entries_for( "items", [ { @@ -803,7 +812,7 @@ async def test_highlighting_with_advanced_search(client, app): async def test_french_accents_search(client, app): - app.data.insert( + await create_entries_for( "items", [{"_id": "foo", 
"body_html": "Story that involves élection", "versioncreated": datetime.utcnow()}] ) resp = await client.get("/wire/search?q=election") @@ -813,14 +822,14 @@ async def test_french_accents_search(client, app): async def test_navigation_for_public_users(client, app, setup_products): - user = app.data.find_one("users", req=None, _id=PUBLIC_USER_ID) + user = await find_one_by_id("users", PUBLIC_USER_ID) assert user - company = app.data.find_one("companies", req=None, _id=COMPANY_1_ID) + company = await find_one_by_id("companies", COMPANY_1_ID) assert company # add products to user - app.data.update( + await update_entries_for( "users", PUBLIC_USER_ID, {"products": [{"section": "wire", "_id": PROD_1}, {"section": "wire", "_id": PROD_2}]}, @@ -828,7 +837,7 @@ async def test_navigation_for_public_users(client, app, setup_products): ) # and remove those from company - app.data.update( + await update_entries_for( "companies", COMPANY_1_ID, {"products": [{"section": "wire", "_id": PROD_1, "seats": 1}, {"section": "wire", "_id": PROD_2, "seats": 1}]}, @@ -840,6 +849,7 @@ async def test_navigation_for_public_users(client, app, setup_products): # make sure user gets the products resp = await client.get("/wire/search") data = json.loads(await resp.get_data()) + print(data) assert 2 == len(data["_items"]) # test navigation @@ -850,10 +860,10 @@ async def test_navigation_for_public_users(client, app, setup_products): async def test_date_filters(client, app): # remove all other's item - app.data.remove("items") + await delete_entries_for("items") now = datetime.utcnow() app.config["DEFAULT_TIMEZONE"] = "Europe/Berlin" - app.data.insert( + await create_entries_for( "items", [ { diff --git a/tests/core/utils.py b/tests/core/utils.py index e5d38f1c0..4760e9d5f 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,22 +1,63 @@ from typing import Any +from bson import ObjectId from superdesk import get_resource_service from newsroom.core import get_current_wsgi_app -def 
add_company_products(app, company_id, products): - company = app.data.find_one("companies", req=None, _id=company_id) - app.data.insert("products", products) +async def find_one_by_id(resource: str, item_id: ObjectId | str) -> Any | None: + """ + Attempts to create a new resource entries. First tries with async, otherwise it falls back to + sync resources. + """ + app = get_current_wsgi_app() + async_app = app.async_app + + try: + model_instance = await async_app.resources.get_resource_service(resource).find_by_id(item_id) + return model_instance.to_dict(context={"use_objectid": True}) if model_instance else None + except KeyError: + return app.data.find_one(resource, req=None, _id=item_id) + + +async def find_one_for(resource: str, **kwargs: Any) -> Any | None: + app = get_current_wsgi_app() + async_app = app.async_app + + try: + model_instance = await async_app.resources.get_resource_service(resource).find_one(**kwargs) + return model_instance.to_dict(context={"use_objectid": True}) if model_instance else None + except KeyError: + return app.data.find_one(resource, req=None, **kwargs) + + +async def get_all(resource: str) -> list[Any]: + app = get_current_wsgi_app() + async_app = app.async_app + + try: + return [ + # item.to_dict(context={"use_objectid": True}) + item + async for item in async_app.resources.get_resource_service(resource).get_all_raw() + ] + except KeyError: + return [item for item in app.data.find(resource, {})] + + +async def add_company_products(app, company_id, products): + company = await find_one_by_id("companies", company_id) + await create_entries_for("products", products) company_products = company["products"] or [] for product in products: company_products.append({"_id": product["_id"], "section": product["product_type"], "seats": 0}) - app.data.update("companies", company["_id"], {"products": company_products}, company) + await update_entries_for("companies", company["_id"], {"products": company_products}, company) -async def 
create_entries_for(resource: str, items: list[Any]): +async def create_entries_for(resource: str, items: list[Any]) -> list[ObjectId | str]: """ - Attemps create a new resource entries. First tries with async, otherwise it falls back to + Attempts to create a new resource entries. First tries with async, otherwise it falls back to sync resources. """ app = get_current_wsgi_app() @@ -25,8 +66,38 @@ async def create_entries_for(resource: str, items: list[Any]): try: return await async_app.resources.get_resource_service(resource).create(items) except KeyError: + print(f"Failed to find async service for {resource}") ids = [] for item in items: app.data.mongo._mongotize(item, resource) ids.extend(get_resource_service(resource).post([item])) return ids + + +async def update_entries_for(resource: str, item_id: str | ObjectId, updates: dict[str, Any], original: Any): + """ + Attempts to update existing resource entries. First tries with async, otherwise it falls back to + sync resources. + """ + app = get_current_wsgi_app() + async_app = app.async_app + + try: + await async_app.resources.get_resource_service(resource).update(item_id, updates) + except KeyError: + app.data.update(resource, item_id, updates, original) + + +async def delete_entries_for(resource: str) -> None: + """ + Attempts to remove all items from the resources MongoDB and/or Elastic. + First tries with async, otherwise it falls back to sync resources. + """ + + app = get_current_wsgi_app() + async_app = app.async_app + + try: + await async_app.resources.get_resource_service(resource).delete_many({}) + except KeyError: + app.data.remove(resource) diff --git a/tests/fixtures/item_copy_fixture.json b/tests/fixtures/item_copy_fixture.json index 65107f9c2..b4ffae255 100644 --- a/tests/fixtures/item_copy_fixture.json +++ b/tests/fixtures/item_copy_fixture.json @@ -1,4 +1,5 @@ { + "_id": "ea3ed480b5db145e60d54b58152f1c23", "version": "2", "type": "text", "byline": "Ralph D. 
Russo", diff --git a/tests/news_api/test_api_assets.py b/tests/news_api/test_api_assets.py index 5034cbbc9..123fb5ecf 100644 --- a/tests/news_api/test_api_assets.py +++ b/tests/news_api/test_api_assets.py @@ -1,6 +1,7 @@ import os from bson import ObjectId from tests.news_api.test_api_audit import audit_check +from tests.core.utils import create_entries_for, find_one_for def get_fixture_path(fixture): @@ -18,12 +19,12 @@ async def setup_image(app): async def test_get_asset(client, app): company_id = ObjectId() - app.data.insert( + await create_entries_for( "companies", [{"_id": company_id, "name": "Test Company", "is_enabled": True}], ) - app.data.insert("news_api_tokens", [{"company": company_id, "enabled": True}]) - token = app.data.find_one("news_api_tokens", req=None, company=company_id) + await create_entries_for("news_api_tokens", [{"company": company_id, "enabled": True}]) + token = await find_one_for("news_api_tokens", company=company_id) image_id = await setup_image(app) response = await client.get("api/v1/assets/{}".format(image_id), headers={"Authorization": token.get("token")}) diff --git a/tests/news_api/test_api_audit.py b/tests/news_api/test_api_audit.py index 5948ea3ec..44cb13418 100644 --- a/tests/news_api/test_api_audit.py +++ b/tests/news_api/test_api_audit.py @@ -7,6 +7,7 @@ from newsroom.companies import CompanyServiceAsync from newsroom.tests.fixtures import COMPANY_1_ID, COMPANY_2_ID +from tests.core.utils import create_entries_for, find_one_for company_id = "5c3eb6975f627db90c84093c" @@ -19,11 +20,11 @@ def audit_check(item_id): @fixture(autouse=True) async def init(app): - app.data.insert( + await create_entries_for( "companies", [{"_id": ObjectId(company_id), "name": "Test Company", "is_enabled": True}], ) - app.data.insert( + await create_entries_for( "products", [ { @@ -45,12 +46,12 @@ async def init(app): async def test_get_item_audit_creation(client, app): - app.data.insert( + await create_entries_for( "items", [{"_id": "111", 
"pubstatus": "usable", "headline": "Headline of the story"}], ) - app.data.insert("news_api_tokens", [{"company": ObjectId(company_id), "enabled": True}]) - token = app.data.find_one("news_api_tokens", req=None, company=ObjectId(company_id)) + await create_entries_for("news_api_tokens", [{"company": ObjectId(company_id), "enabled": True}]) + token = await find_one_for("news_api_tokens", company=ObjectId(company_id)) response = await client.get( "api/v1/news/item/111?format=NINJSFormatter", headers={"Authorization": token.get("token")}, @@ -76,14 +77,17 @@ async def test_get_single_product_audit_creation(client, app): async def test_search_audit_creation(client, app): - app.data.insert( + await create_entries_for( "items", [ { "_id": "5ab03a87bdd78169bb6d0785", "body_html": "Once upon a time there was a fish who could swim", }, - {"body_html": "Once upon a time there was a aardvark that could not swim"}, + { + "_id": "5ab03a87bdd78169bb6d0786", + "body_html": "Once upon a time there was a aardvark that could not swim", + }, ], ) diff --git a/tests/search/test_search_fields.py b/tests/search/test_search_fields.py index e753688c5..4075deb15 100644 --- a/tests/search/test_search_fields.py +++ b/tests/search/test_search_fields.py @@ -57,7 +57,13 @@ async def test_agenda_search_fields(client, app): await create_entries_for( "agenda", [ - {"name": "foo", "ednote": "bar", "guid": "test"}, + { + "name": "foo", + "ednote": "bar", + "guid": "test", + "type": "event", + "dates": {"start": datetime.utcnow(), "end": datetime.utcnow()}, + }, ], ) diff --git a/tests/search/test_search_params.py b/tests/search/test_search_params.py index bf2f33e71..6e3021ffb 100644 --- a/tests/search/test_search_params.py +++ b/tests/search/test_search_params.py @@ -39,10 +39,10 @@ async def init(app): global service service = BaseSearchService() - app.data.insert("users", USERS) - app.data.insert("companies", COMPANIES) + await create_entries_for("companies", COMPANIES) + await 
create_entries_for("auth_user", USERS) await create_entries_for("navigations", NAVIGATIONS) - app.data.insert("products", PRODUCTS) + await create_entries_for("products", PRODUCTS) def get_search_instance(args=None, lookup=None): diff --git a/tests/search/test_search_topics.py b/tests/search/test_search_topics.py index 039fffaa1..05ee7837a 100644 --- a/tests/search/test_search_topics.py +++ b/tests/search/test_search_topics.py @@ -1,6 +1,7 @@ from quart import json from unittest.mock import patch -from newsroom.agenda.agenda import AgendaService + +# from newsroom.agenda.agenda import AgendaService from newsroom.wire.search import WireSearchService @@ -31,19 +32,19 @@ async def test_get_topic_query_wire(app): assert_no_aggs(service, search) -async def test_get_topic_query_agenda(app): - service = AgendaService() - search = service.get_topic_query( - {"query": "topic-query"}, - {"user_type": "administrator"}, - None, - {}, - args={"es_highlight": 1, "ids": ["item-id"]}, - ) - - assert search is not None - assert "bool" in search.query - assert {"ids": {"values": ["item-id"]}} in search.query["bool"]["filter"] - assert "aggs" not in search.query - - assert_no_aggs(service, search) +# async def test_get_topic_query_agenda(app): +# service = AgendaService() +# search = service.get_topic_query( +# {"query": "topic-query"}, +# {"user_type": "administrator"}, +# None, +# {}, +# args={"es_highlight": 1, "ids": ["item-id"]}, +# ) +# +# assert search is not None +# assert "bool" in search.query +# assert {"ids": {"values": ["item-id"]}} in search.query["bool"]["filter"] +# assert "aggs" not in search.query +# +# assert_no_aggs(service, search) diff --git a/tests/search/test_user_products.py b/tests/search/test_user_products.py index 47065ade9..8f43f6d5f 100644 --- a/tests/search/test_user_products.py +++ b/tests/search/test_user_products.py @@ -4,6 +4,7 @@ from quart import g from newsroom.users.users import UserRole +from tests.core.utils import create_entries_for, 
update_entries_for, find_one_by_id from .fixtures import ( USERS, @@ -15,9 +16,9 @@ @pytest.fixture(autouse=True) async def init(app): - app.data.insert("users", USERS) - app.data.insert("companies", COMPANIES) - app.data.insert("products", PRODUCTS) + await create_entries_for("companies", COMPANIES) + await create_entries_for("users", USERS) + await create_entries_for("products", PRODUCTS) @pytest.fixture @@ -28,7 +29,7 @@ async def product(app): "is_enabled": True, "product_type": "wire", } - app.data.insert("products", [product]) + await create_entries_for("products", [product]) return product @@ -43,7 +44,7 @@ async def company(app, product): } ] company.pop("_id") - app.data.insert("companies", [company]) + await create_entries_for("companies", [company]) return company @@ -55,7 +56,7 @@ async def manager(app, client, product, company): manager["user_type"] = UserRole.COMPANY_ADMIN.value manager.pop("_id") - app.data.insert("users", [manager]) + await create_entries_for("auth_user", [manager]) manager.pop("password") await utils.login(client, manager) @@ -82,7 +83,7 @@ async def test_user_products(app, client, manager, product, company): data = await utils.get_json(client, "/wire/search?q=weather") assert 0 == len(data["_items"]) - app.data.update("products", product["_id"], {"query": "headline:WEATHER"}, product) + await update_entries_for("products", product["_id"], {"query": "headline:WEATHER"}, product) g.pop("cached:products", None) data = await utils.get_json(client, "/wire/search") @@ -100,7 +101,7 @@ async def test_user_products_after_company_update(app, client, manager, product, }, ) - user = app.data.find_one("users", req=None, _id=manager["_id"]) + user = await find_one_by_id("users", manager["_id"]) assert user["products"] @@ -155,9 +156,9 @@ async def test_user_sections(app, client, manager, product): data = utils.get_json(client, "/agenda/search") assert data - company = app.data.find_one("companies", req=None, _id=manager["company"]) + company = 
await find_one_by_id("companies", manager["company"]) assert company - app.data.update("companies", manager["company"], {"sections": {"agenda": True}}, company) + await update_entries_for("companies", manager["company"], {"sections": {"agenda": True}}, company) with pytest.raises(AssertionError) as err: await utils.get_json(client, "/wire/search") @@ -181,8 +182,7 @@ async def test_other_company_user_changes_blocked(client, manager): assert "401" in str(err) -async def test_public_user_can_edit_his_dashboard(client, manager): - public_user = next((user for user in USERS if user["_id"] == PUBLIC_USER_ID)) - public_user.pop("password") - await utils.login(client, public_user) - await utils.patch_json(client, f"/api/_users/{PUBLIC_USER_ID}", {"dashboards": []}) +async def test_public_user_can_edit_his_dashboard(app, client, public_user): + async with app.test_request_context("/") as request: + request.session["user"] = str(PUBLIC_USER_ID) + await utils.patch_json(client, f"/api/_users/{PUBLIC_USER_ID}", {"dashboards": []})