Skip to content

Commit

Permalink
Refactor getting Celery app.
Browse files Browse the repository at this point in the history
  • Loading branch information
luhn committed Sep 15, 2021
1 parent 440d35a commit acecadf
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions pyramid_tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def includeme(config):
)
config.action(
("celery", "finalize"),
config.registry["pyramid_tasks.app"].finalize,
_get_app(config.registry).finalize,
order=PHASE2_CONFIG,
)
config.add_request_method(defer_task)
Expand All @@ -50,9 +50,13 @@ def includeme(config):
config.include(".taskderivers")


def _get_app(registry):
return registry["pyramid_tasks.app"]


def make_celery_app(config):
config.commit()
return config.registry["pyramid_tasks.app"]
return _get_app(config.registry)


def register_task(config, func, name=None, **kwargs):
Expand All @@ -61,7 +65,7 @@ def register_task(config, func, name=None, **kwargs):
"""
registry = config.registry
app = registry["pyramid_tasks.app"]
app = _get_app(registry)
name = name or app.gen_task_name(func.__name__, func.__module__)

def register():
Expand Down Expand Up @@ -116,8 +120,7 @@ def _get_task(registry, func_or_name):
"""
task_map = registry["pyramid_tasks.task_map"]
if isinstance(func_or_name, str):
celery_app = registry["pyramid_tasks.app"]
return celery_app.tasks[func_or_name]
return _get_app(registry).tasks[func_or_name]
elif func_or_name in task_map:
return task_map[func_or_name]
else:
Expand Down Expand Up @@ -173,8 +176,13 @@ def add_periodic_task(

def add():
task = _get_task(config.registry, func_or_name)
app = config.registry["pyramid_tasks.app"]
app.add_periodic_task(schedule, task, args, kwargs, **opts)
_get_app(config.registry).add_periodic_task(
schedule,
task,
args,
kwargs,
**opts,
)

config.action(None, add, order=PHASE3_CONFIG)

Expand All @@ -184,14 +192,12 @@ def get_task_result(request, task_id):
Get a result object from celery.
"""
app = request.registry["pyramid_tasks.app"]
return app.AsyncResult(task_id)
return _get_app(request.registry).AsyncResult(task_id)


def current_task(request):
"""
Return the task currently being executed.
"""
app = request.registry["pyramid_tasks.app"]
return app.current_worker_task
return _get_app(request.registry).current_worker_task

0 comments on commit acecadf

Please sign in to comment.