Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDESK-7472] Add async support to ingest handlers #2818

Merged
merged 3 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/rules/routing_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def on_delete(self, doc):
if self.backend.find_one("ingest_providers", req=None, routing_scheme=doc[ID_FIELD]):
raise SuperdeskApiError.forbiddenError(_("Routing scheme is applied to channel(s). It cannot be deleted."))

def apply_routing_scheme(self, ingest_item, provider, routing_scheme):
async def apply_routing_scheme(self, ingest_item, provider, routing_scheme):
"""Applies routing scheme and applies appropriate action (fetch, publish) to the item

:param item: ingest item to which routing scheme needs to applied.
Expand Down Expand Up @@ -225,7 +225,7 @@ def apply_routing_scheme(self, ingest_item, provider, routing_scheme):
% (item_id, routing_scheme.get("name"), rule.get("name"))
)

rule_handler.apply_rule(rule, ingest_item, routing_scheme)
await rule_handler.apply_rule(rule, ingest_item, routing_scheme)
if rule.get("actions", {}).get("exit", False):
logger.info(
"Exiting routing scheme. Item: %s . Routing Scheme: %s. "
Expand Down
13 changes: 6 additions & 7 deletions apps/rules/rule_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
# AUTHORS and LICENSE files distributed with this source code, or
# at https://www.sourcefabric.org/superdesk/license

from typing import Dict, Any
from typing import Dict, Any, Awaitable
import logging

from quart_babel import lazy_gettext
from quart_babel.speaklater import LazyString
from quart_babel import lazy_gettext, LazyString

from superdesk.core import get_app_config
from superdesk.resource_fields import ID_FIELD
Expand All @@ -30,10 +29,10 @@ class RoutingRuleHandler:
supported_configs: Dict[str, bool]
default_values: Dict[str, Any]

def can_handle(self, rule, ingest_item, routing_scheme) -> bool:
async def can_handle(self, rule, ingest_item, routing_scheme) -> bool:
    """Return whether this handler supports the given rule/item combination.

    Subclasses must override this to inspect the ingest item (typically its
    content type) and decide if :meth:`apply_rule` should be invoked.

    Note: an ``async def`` already wraps its declared return type in a
    coroutine, so the annotation is ``bool`` — annotating ``Awaitable[bool]``
    here would (incorrectly) declare that awaiting yields another awaitable.

    :param rule: routing rule definition being evaluated
    :param ingest_item: the ingested item dict
    :param routing_scheme: the routing scheme the rule belongs to
    :return: True if this handler can process the rule for this item
    :raises NotImplementedError: always, in this abstract base
    """
    raise NotImplementedError()

def apply_rule(self, rule, ingest_item, routing_scheme):
async def apply_rule(self, rule, ingest_item, routing_scheme):
    """Apply the routing rule's action(s) to the ingest item.

    Subclasses must override this with the concrete action (e.g. fetch or
    publish — see callers in routing_rules.py, which ``await`` this method).

    :param rule: routing rule definition to apply
    :param ingest_item: the ingested item dict
    :param routing_scheme: the routing scheme the rule belongs to
    :raises NotImplementedError: always, in this abstract base
    """
    raise NotImplementedError()


Expand Down Expand Up @@ -121,12 +120,12 @@ class DeskFetchPublishRoutingRuleHandler(RoutingRuleHandler):
},
}

def can_handle(self, rule, ingest_item, routing_scheme):
async def can_handle(self, rule, ingest_item, routing_scheme):
    """Return True when the ingest item's content type is one this handler
    supports: any media type, plus text, preformatted and composite items.

    :param rule: routing rule definition (unused in the check)
    :param ingest_item: the ingested item dict
    :param routing_scheme: the routing scheme the rule belongs to (unused)
    """
    supported_types = MEDIA_TYPES + (
        CONTENT_TYPE.TEXT,
        CONTENT_TYPE.PREFORMATTED,
        CONTENT_TYPE.COMPOSITE,
    )
    return ingest_item.get(ITEM_TYPE) in supported_types

def apply_rule(self, rule, ingest_item, routing_scheme):
async def apply_rule(self, rule, ingest_item, routing_scheme):
if rule.get("actions", {}).get("preserve_desk", False) and ingest_item.get("task", {}).get("desk"):
desk = get_resource_service("desks").find_one(req=None, _id=ingest_item["task"]["desk"])
if ingest_item.get("task", {}).get("stage"):
Expand Down
16 changes: 9 additions & 7 deletions superdesk/io/commands/update_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ async def update_provider(provider, rule_set=None, routing_scheme=None, sync=Fal
logger.warning("lock expired while updating provider %s", provider[ID_FIELD])
return
items = generator.send(failed)
failed = ingest_items(items, provider, feeding_service, rule_set, routing_scheme)
failed = await ingest_items(items, provider, feeding_service, rule_set, routing_scheme)
update_last_item_updated(update, items)

if not update.get(LAST_ITEM_ARRIVED) or update[LAST_ITEM_ARRIVED] < datetime.now(tz=pytz.utc):
Expand Down Expand Up @@ -512,7 +512,7 @@ def ingest_cancel(item, feeding_service):
ingest_service.patch(relative["_id"], update)


def ingest_items(items, provider, feeding_service, rule_set=None, routing_scheme=None):
async def ingest_items(items, provider, feeding_service, rule_set=None, routing_scheme=None):
all_items = filter_expired_items(provider, items)
items_dict = {doc[GUID_FIELD]: doc for doc in all_items}
items_in_package = []
Expand All @@ -524,7 +524,7 @@ def ingest_items(items, provider, feeding_service, rule_set=None, routing_scheme
]

for item in [doc for doc in all_items if doc.get(ITEM_TYPE) != CONTENT_TYPE.COMPOSITE]:
ingested, ids = ingest_item(
ingested, ids = await ingest_item(
item,
provider,
feeding_service,
Expand All @@ -550,7 +550,7 @@ def ingest_items(items, provider, feeding_service, rule_set=None, routing_scheme
ref["residRef"] = items_dict.get(ref["residRef"], {}).get(ID_FIELD)
if item[GUID_FIELD] in failed_items:
continue
ingested, ids = ingest_item(item, provider, feeding_service, rule_set, routing_scheme)
ingested, ids = await ingest_item(item, provider, feeding_service, rule_set, routing_scheme)
if ingested:
created_ids = created_ids + ids
else:
Expand All @@ -569,7 +569,7 @@ def ingest_items(items, provider, feeding_service, rule_set=None, routing_scheme
return failed_items


def ingest_item(item, provider, feeding_service, rule_set=None, routing_scheme=None, expiry=None):
async def ingest_item(item, provider, feeding_service, rule_set=None, routing_scheme=None, expiry=None):
items_ids = []
try:
ingest_collection = get_ingest_collection(feeding_service, item)
Expand Down Expand Up @@ -684,7 +684,7 @@ def ingest_item(item, provider, feeding_service, rule_set=None, routing_scheme=N
name=assoc_name,
),
)
status, ids = ingest_item(assoc, provider, feeding_service, rule_set, expiry=item["expiry"])
status, ids = await ingest_item(assoc, provider, feeding_service, rule_set, expiry=item["expiry"])
if status:
assoc["_id"] = ids[0]
items_ids.extend(ids)
Expand Down Expand Up @@ -715,7 +715,9 @@ def ingest_item(item, provider, feeding_service, rule_set=None, routing_scheme=N

if routing_scheme and new_version:
routed = ingest_service.find_one(_id=item[ID_FIELD], req=None)
superdesk.get_resource_service("routing_schemes").apply_routing_scheme(routed, provider, routing_scheme)
await superdesk.get_resource_service("routing_schemes").apply_routing_scheme(
routed, provider, routing_scheme
)

except Exception as ex:
logger.exception(ex)
Expand Down
14 changes: 7 additions & 7 deletions superdesk/tests/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ def step_create_new_macro(context, macro_name):
@async_run_until_complete
async def step_impl_fetch_from_provider_ingest(context, provider_name, guid):
async with context.app.test_request_context(context.app.config["URL_PREFIX"]):
fetch_from_provider(context, provider_name, guid)
await fetch_from_provider(context, provider_name, guid)


@when('we fetch from "{provider_name}" ingest "{guid}" (mocking with "{mock_file}")')
Expand All @@ -639,7 +639,7 @@ async def step_impl_fetch_from_provider_ingest_with_mocking(context, provider_na

with responses.RequestsMock() as rsps:
apply_mock_file(rsps, mock_file, fixture_path=get_provider_file_path(provider))
fetch_from_provider(context, provider_name, guid)
await fetch_from_provider(context, provider_name, guid)


@when('we run update_ingest command for "{provider_name}"')
Expand Down Expand Up @@ -746,7 +746,7 @@ async def step_impl_fetch_from_provider_ingest_using_routing(context, provider_n
_id = apply_placeholders(context, context.text)
routing_scheme = get_resource_service("routing_schemes").find_one(_id=_id, req=None)
embed_routing_scheme_rules(routing_scheme)
fetch_from_provider(context, provider_name, guid, routing_scheme)
await fetch_from_provider(context, provider_name, guid, routing_scheme)


@when('we ingest and fetch "{provider_name}" "{guid}" to desk "{desk}" stage "{stage}" using routing_scheme')
Expand All @@ -758,7 +758,7 @@ async def step_impl_fetch_from_provider_ingest_using_routing_with_desk(context,
stage_id = apply_placeholders(context, stage)
routing_scheme = get_resource_service("routing_schemes").find_one(_id=_id, req=None)
embed_routing_scheme_rules(routing_scheme)
fetch_from_provider(context, provider_name, guid, routing_scheme, desk_id, stage_id)
await fetch_from_provider(context, provider_name, guid, routing_scheme, desk_id, stage_id)


@when('we ingest with routing scheme "{provider_name}" "{guid}"')
Expand All @@ -768,7 +768,7 @@ async def step_impl_ingest_with_routing_scheme(context, provider_name, guid):
_id = apply_placeholders(context, context.text)
routing_scheme = get_resource_service("routing_schemes").find_one(_id=_id, req=None)
embed_routing_scheme_rules(routing_scheme)
fetch_from_provider(context, provider_name, guid, routing_scheme)
await fetch_from_provider(context, provider_name, guid, routing_scheme)


def get_provider_file_path(provider, filename=""):
Expand All @@ -778,7 +778,7 @@ def get_provider_file_path(provider, filename=""):
return os.path.join(provider.get("config", {}).get("path", ""), filename)


def fetch_from_provider(context, provider_name, guid, routing_scheme=None, desk_id=None, stage_id=None):
async def fetch_from_provider(context, provider_name, guid, routing_scheme=None, desk_id=None, stage_id=None):
ingest_provider_service = get_resource_service("ingest_providers")
provider = ingest_provider_service.find_one(name=provider_name, req=None)
provider["routing_scheme"] = routing_scheme
Expand Down Expand Up @@ -824,7 +824,7 @@ def fetch_from_provider(context, provider_name, guid, routing_scheme=None, desk_

item["task"] = {"desk": ObjectId(desk_id), "stage": ObjectId(stage_id)}

failed = context.ingest_items(
failed = await context.ingest_items(
items, provider, provider_service, rule_set=rule_set, routing_scheme=provider.get("routing_scheme")
)
assert len(failed) == 0, failed
Expand Down
44 changes: 22 additions & 22 deletions tests/io/update_ingest_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ async def test_ingest_items(self):
items = provider_service.fetch_ingest(reuters_guid)
items.extend(provider_service.fetch_ingest(reuters_guid))
self.assertEqual(12, len(items))
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)

async def test_ingest_item_expiry(self):
provider, provider_service = self.setup_reuters_provider()
items = provider_service.fetch_ingest(reuters_guid)
self.assertIsNone(items[1].get("expiry"))
items[1]["versioncreated"] = utcnow()
self.ingest_items([items[1]], provider, provider_service)
await self.ingest_items([items[1]], provider, provider_service)
self.assertIsNotNone(items[1].get("expiry"))

async def test_ingest_item_sync_if_missing_from_elastic(self):
Expand Down Expand Up @@ -254,7 +254,7 @@ async def test_expiring_content_with_files(self):
items[5]["versioncreated"] = now + timedelta(minutes=11)

# ingest the items and expire them
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)

# four files in grid fs
current_files = self.app.media.storage().fs("upload").find()
Expand Down Expand Up @@ -317,15 +317,15 @@ async def test_files_dont_duplicate_ingest(self):
item["expiry"] = utcnow() + timedelta(hours=11)

# ingest the items
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)

items = provider_service.fetch_ingest(reuters_guid)
for item in items:
item["ingest_provider"] = provider["_id"]
item["expiry"] = utcnow() + timedelta(hours=11)

# ingest them again
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)

# 12 files in grid fs
current_files = self.app.media.storage().fs("upload").find()
Expand All @@ -348,7 +348,7 @@ async def test_anpa_category_to_subject_derived_ingest(self):
items = [feeding_parser.parse(file_path, provider)]

# ingest the items and check the subject code has been derived
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)
self.assertEqual(items[0]["subject"][0]["qcode"], "15000000")

async def test_anpa_category_to_subject_derived_ingest_ignores_inactive_categories(self):
Expand All @@ -368,7 +368,7 @@ async def test_anpa_category_to_subject_derived_ingest_ignores_inactive_categori
items = [feeding_parser.parse(file_path, provider)]

# ingest the items and check the subject code has been derived
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)
self.assertNotIn("subject", items[0])

async def test_subject_to_anpa_category_derived_ingest(self):
Expand Down Expand Up @@ -408,7 +408,7 @@ async def test_subject_to_anpa_category_derived_ingest(self):
item["language"] = "fr"

# ingest the items and check the subject code has been derived
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)
self.assertEqual(items[0]["anpa_category"][0]["qcode"], "f")
self.assertEqual(items[0]["anpa_category"][0]["name"], "Finance FR")

Expand Down Expand Up @@ -440,7 +440,7 @@ async def test_subject_to_anpa_category_derived_ingest_ignores_inactive_map_entr
item["expiry"] = utcnow() + timedelta(hours=11)

# ingest the items and check the subject code has been derived
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)
self.assertNotIn("anpa_category", items[0])

async def test_ingest_cancellation(self):
Expand All @@ -450,13 +450,13 @@ async def test_ingest_cancellation(self):
for item in items:
item["ingest_provider"] = provider["_id"]
item["expiry"] = utcnow() + timedelta(hours=11)
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)
guid = "tag_reuters.com_2016_newsml_L1N14N0FF:1542761538"
items = provider_service.fetch_ingest(guid)
for item in items:
item["ingest_provider"] = provider["_id"]
item["expiry"] = utcnow() + timedelta(hours=11)
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)
ingest_service = get_resource_service("ingest")
lookup = {"uri": items[0].get("uri")}
family_members = ingest_service.get_from_mongo(req=None, lookup=lookup)
Expand All @@ -471,7 +471,7 @@ async def test_ingest_update(self):
items[0]["ingest_provider"] = provider["_id"]
items[0]["expiry"] = utcnow() + timedelta(hours=11)

self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)

self.assertEqual(items[0]["unique_id"], 1)
original_id = items[0]["_id"]
Expand All @@ -485,7 +485,7 @@ async def test_ingest_update(self):
items[0]["version"] = 11

# ingest the item again
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)

# see the update to the headline and unique_id survives
elastic_item = self.app.data._search_backend("ingest").find_one("ingest", _id=original_id, req=None)
Expand Down Expand Up @@ -523,7 +523,7 @@ async def test_unknown_category_ingested_is_removed(self):

# ingest the items and check the subject code has been derived
items[0]["versioncreated"] = utcnow()
self.ingest_items(items, provider, provider_service)
await self.ingest_items(items, provider, provider_service)
self.assertTrue(len(items[0]["anpa_category"]) == 0)

async def test_ingest_with_routing_keeps_elastic_in_sync(self):
Expand Down Expand Up @@ -598,7 +598,7 @@ async def test_ingest_with_routing_keeps_elastic_in_sync(self):
}

ingest_service = get_resource_service("ingest")
self.ingest_items(items, provider, provider_service, routing_scheme=routing_scheme)
await self.ingest_items(items, provider, provider_service, routing_scheme=routing_scheme)

self.assertEqual(4, ingest_service.get_from_mongo(None, {}).count())
self.assertEqual(4, ingest_service.get(None, {}).count())
Expand Down Expand Up @@ -650,7 +650,7 @@ async def test_ingest_associated_item_renditions(self):
# avoid transfer_renditions call which would store the picture locally
# and it would fetch it using superdesk url which doesn't work in test
with patch("superdesk.io.commands.update_ingest.transfer_renditions"):
status, ids = ingest_item(item, provider, provider_service)
status, ids = await ingest_item(item, provider, provider_service)

self.assertTrue(status)
self.assertEqual(3, len(ids))
Expand All @@ -660,13 +660,13 @@ async def test_ingest_associated_item_renditions(self):
async def test_ingest_profile_if_exists(self):
provider, provider_service = self.setup_reuters_provider()
items = provider_service.fetch_ingest(reuters_guid)
ingest_item(items[0], provider, provider_service)
await ingest_item(items[0], provider, provider_service)
self.assertEqual("composite", items[0].get("profile"))

content_types = [{"_id": "story", "name": "story"}]
self.app.data.insert("content_types", content_types)
items[1]["profile"] = "story"
ingest_item(items[1], provider, provider_service)
await ingest_item(items[1], provider, provider_service)
self.assertEqual("story", items[1].get("profile"))

@markers.requires_async_celery
Expand Down Expand Up @@ -703,7 +703,7 @@ async def test_edited_planning_item_is_not_update(self):
events_post_service = get_resource_service("events_post")

# ingest first version
ingested, ids = ingest_item(item, provider=provider, feeding_service={})
ingested, ids = await ingest_item(item, provider=provider, feeding_service={})
self.assertTrue(ingested)
self.assertIn(item["guid"], ids)

Expand All @@ -718,7 +718,7 @@ async def test_edited_planning_item_is_not_update(self):
self.assertEqual(dest.get("version_creator"), "current_user_id")

# update event
ingested, ids = ingest_item(item, provider=provider, feeding_service={})
ingested, ids = await ingest_item(item, provider=provider, feeding_service={})
self.assertFalse(ingested)
self.assertEqual([], ids)

Expand Down Expand Up @@ -756,7 +756,7 @@ async def test_unpublished_event_is_not_update(self):
events_post_service = get_resource_service("events_post")

# ingest first version
ingested, ids = ingest_item(item, provider=provider, feeding_service={})
ingested, ids = await ingest_item(item, provider=provider, feeding_service={})
self.assertTrue(ingested)
self.assertIn(item["guid"], ids)

Expand Down Expand Up @@ -787,6 +787,6 @@ async def test_unpublished_event_is_not_update(self):
self.assertEqual(dest.get("state"), "killed")

# update an event
ingested, ids = ingest_item(item, provider=provider, feeding_service={})
ingested, ids = await ingest_item(item, provider=provider, feeding_service={})
self.assertFalse(ingested)
self.assertEqual([], ids)
Loading