Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 143 additions & 7 deletions backend/pg_queue/executor_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from pg_queue.flags import PG_QUEUE_FLAG_KEY
from pg_queue.models import PgTaskResult
from pg_queue.producer import enqueue_task
from unstract.core.data_models import PgTaskStatus
from unstract.core.data_models import ContinuationSpec, PgTaskStatus
from unstract.flags.feature_flag import check_feature_flag_status
from unstract.sdk1.execution.dispatcher import ExecutionDispatcher
from unstract.sdk1.execution.result import ExecutionResult
Expand All @@ -65,6 +65,46 @@
_POLL_MAX_SECONDS = 2.0


class _DispatchHandle:
"""Minimal duck-type of Celery ``AsyncResult`` for the PG callback path.

``dispatch_with_callback`` callers read only ``.id`` (to return the task id
in the HTTP 202 response); they must NOT call ``.get()`` — the result arrives
via the self-chained callback (WebSocket), not by polling here. Exposing just
``.id`` lets a PG dispatch return the same shape the call sites already use.
"""

__slots__ = ("id",)

def __init__(self, task_id: str) -> None:
self.id = task_id


def _signature_to_spec(sig: Any | None) -> ContinuationSpec | None:
"""Translate a Celery ``Signature`` to a serialisable continuation spec.

Reads only the three attributes PG self-chaining needs — task name, kwargs,
target queue — so the prompt-studio call sites keep passing
``signature(name, kwargs=..., queue=...)`` unchanged; only the PG branch
translates. ``None`` (no callback for that outcome) passes through. A
signature without a queue is a configuration error: PG routes by the row's
queue and must not silently default it, so we fail fast.
"""
if sig is None:
return None
queue = (getattr(sig, "options", None) or {}).get("queue")
if not queue:
raise ValueError(
f"callback signature {getattr(sig, 'task', sig)!r} has no queue; "
"PG self-chaining routes by the row's queue and cannot default it"
)
return ContinuationSpec(
task_name=sig.task,
kwargs=dict(getattr(sig, "kwargs", None) or {}),
queue=queue,
)


def resolve_executor_transport(context: ExecutionContext) -> bool:
"""True → route this executor dispatch over PG; False → Celery (default).

Expand Down Expand Up @@ -186,6 +226,80 @@
)
return ExecutionResult.failure(error=row.error or "executor task failed")

def dispatch_async(
    self, context: ExecutionContext, headers: dict[str, Any] | None = None
) -> str:
    """Fire-and-forget enqueue of ``execute_extraction``; returns the task id.

    The PG analogue of the SDK ``dispatch_async``: no ``reply_key``, no
    callback, no blocking. There is no PG ``AsyncResult`` backend, so a caller
    that needs the outcome uses :meth:`dispatch_with_callback` (a self-chained
    continuation), not polling on this id. Enqueue failures propagate — parity
    with the SDK, which lets a broker error out of ``dispatch_async``.

    Args:
        context: Execution context; its ``executor_name`` selects the queue
            and ``to_dict()`` becomes the task's single positional argument.
        headers: Accepted for signature parity with the SDK dispatcher and
            ignored (PG carries routing in the payload).

    Returns:
        The freshly generated UUID string the task was enqueued under.
    """
    # Kept in the signature for duck-typing against the SDK dispatcher;
    # discarded explicitly so the unused parameter is visibly intentional.
    del headers

    task_id = str(uuid.uuid4())
    queue = f"{_QUEUE_PREFIX}{context.executor_name}"
    # A missing or None organization id is normalised to the empty string.
    org = getattr(context, "organization_id", "") or ""
    enqueue_task(
        task_name=_EXECUTE_TASK,
        queue=queue,
        args=[context.to_dict()],
        org_id=str(org),
        task_id=task_id,
    )
    logger.info(
        "PG executor dispatch_async: enqueued task_id=%s queue=%s run_id=%s",
        task_id,
        queue,
        context.run_id,
    )
    return task_id

def dispatch_with_callback(
    self,
    context: ExecutionContext,
    on_success: Any | None = None,
    on_error: Any | None = None,
    task_id: str | None = None,
    headers: dict[str, Any] | None = None,
) -> _DispatchHandle:
    """Fire-and-forget enqueue with self-chained callbacks (§5 model).

    The PG analogue of the SDK ``dispatch_with_callback``: instead of Celery
    ``link`` / ``link_error`` (which the broker fires), the on-success /
    on-error Celery ``Signature``s are translated to serialisable
    :class:`ContinuationSpec`s and carried in the payload. After the executor
    consumer runs ``execute_extraction`` it self-chains the matching
    continuation onto the callback queue.

    Args:
        context: Execution context; its ``executor_name`` selects the queue
            and ``to_dict()`` becomes the task's single positional argument.
        on_success: Signature to chain when the task succeeds, or ``None``.
        on_error: Signature to chain when the task fails, or ``None``.
        task_id: Id to enqueue under; a new UUID string is generated when
            omitted or empty.
        headers: Accepted for signature parity with the SDK dispatcher and
            ignored.

    Returns:
        A :class:`_DispatchHandle` exposing ``.id`` (== ``task_id``) so call
        sites read the task id exactly as on the Celery path.

    Raises:
        ValueError: If a supplied signature has no queue (raised by
            ``_signature_to_spec`` before anything is enqueued).
    """
    # Kept in the signature for duck-typing against the SDK dispatcher;
    # discarded explicitly so the unused parameter is visibly intentional.
    del headers

    task_id = task_id or str(uuid.uuid4())
    queue = f"{_QUEUE_PREFIX}{context.executor_name}"
    # A missing or None organization id is normalised to the empty string.
    org = getattr(context, "organization_id", "") or ""
    # Translate both callbacks up front so a bad signature fails before enqueue.
    success_spec = _signature_to_spec(on_success)
    error_spec = _signature_to_spec(on_error)
    enqueue_task(
        task_name=_EXECUTE_TASK,
        queue=queue,
        args=[context.to_dict()],
        org_id=str(org),
        on_success=success_spec,
        on_error=error_spec,
        task_id=task_id,
    )
    logger.info(
        "PG executor dispatch_with_callback: enqueued task_id=%s queue=%s "
        "run_id=%s on_success=%s on_error=%s",
        task_id,
        queue,
        context.run_id,
        success_spec["task_name"] if success_spec else None,
        error_spec["task_name"] if error_spec else None,
    )
    return _DispatchHandle(task_id)

@staticmethod
def _wait_for_result(reply_key: str, timeout: float) -> PgTaskResult | None:
"""Poll ``pg_task_result`` until the row appears or *timeout* elapses.
Expand Down Expand Up @@ -216,10 +330,10 @@
class RoutingExecutionDispatcher:
"""Gate-routed executor dispatcher returned by ``_get_dispatcher()``.

Every mode chooses PG vs Celery per call (instant rollout/rollback):
``dispatch()`` (request-reply), ``dispatch_async`` (fire-and-forget) and
``dispatch_with_callback`` (self-chained callbacks). Duck-typed against the SDK
``ExecutionDispatcher`` so call sites are unchanged.
"""

def __init__(self, celery_app: object | None = None) -> None:
Expand All @@ -246,10 +360,32 @@
def dispatch_async(
    self, context: ExecutionContext, headers: dict[str, Any] | None = None
) -> str:
    """Fire-and-forget dispatch, routed to PG or Celery for this call.

    ``resolve_executor_transport(context)`` decides the transport. The Celery
    path receives ``headers``; the PG path is called with the context only.
    Returns whatever the chosen dispatcher's ``dispatch_async`` returns.
    """
    use_pg = resolve_executor_transport(context)
    if not use_pg:
        return self._celery.dispatch_async(context, headers=headers)
    # ``headers`` is deliberately not forwarded on the PG path.
    return self._pg.dispatch_async(context)

def dispatch_with_callback(
    self,
    context: ExecutionContext,
    on_success: Any | None = None,
    on_error: Any | None = None,
    task_id: str | None = None,
    headers: dict[str, Any] | None = None,
) -> Any:
    """Callback dispatch, routed to PG or Celery for this call.

    ``resolve_executor_transport(context)`` decides the transport. Both paths
    receive ``on_success``, ``on_error`` and ``task_id``; only the Celery path
    is additionally given ``headers``.

    Args:
        context: Execution context used both for routing and for the dispatch.
        on_success: Callback signature for the success outcome, or ``None``.
        on_error: Callback signature for the failure outcome, or ``None``.
        task_id: Optional id to dispatch under.
        headers: Optional headers, forwarded on the Celery path only.

    Returns:
        Whatever the chosen dispatcher's ``dispatch_with_callback`` returns.
    """
    if resolve_executor_transport(context):
        # ``headers`` is deliberately not forwarded on the PG path.
        return self._pg.dispatch_with_callback(
            context,
            on_success=on_success,
            on_error=on_error,
            task_id=task_id,
        )
    return self._celery.dispatch_with_callback(
        context,
        on_success=on_success,
        on_error=on_error,
        task_id=task_id,
        headers=headers,
    )


def get_executor_dispatcher(
Expand Down
20 changes: 18 additions & 2 deletions backend/pg_queue/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
FAIRNESS_DEFAULT_PRIORITY,
FAIRNESS_MAX_PRIORITY,
FAIRNESS_MIN_PRIORITY,
ContinuationSpec,
FairnessPayload,
TaskPayload,
)
Expand Down Expand Up @@ -63,6 +64,9 @@ def enqueue_task(
priority: int = DEFAULT_PRIORITY,
fairness: FairnessPayload | None = None,
reply_key: str | None = None,
on_success: ContinuationSpec | None = None,
on_error: ContinuationSpec | None = None,
task_id: str | None = None,
) -> int:
"""Enqueue a task onto the PG queue; returns the new ``msg_id``.

Expand All @@ -74,6 +78,12 @@ def enqueue_task(
``reply_key`` marks a **request-reply** dispatch (the executor RPC on PG):
the executor consumer writes the task's result/error to ``pg_task_result``
under it for the blocking caller to poll. Omitted = fire-and-forget.

``on_success`` / ``on_error`` mark an **async/callback** dispatch
(``dispatch_with_callback``): the executor consumer self-chains the matching
continuation after the task runs. ``task_id`` is the dispatch id prepended to
``on_error`` as the failed id (Celery ``link_error`` parity). Mutually
exclusive with ``reply_key``.
"""
if not FAIRNESS_MIN_PRIORITY <= priority <= FAIRNESS_MAX_PRIORITY:
raise ValueError(
Expand All @@ -88,10 +98,16 @@ def enqueue_task(
"queue": pg_queue,
"fairness": fairness,
}
# Only set for request-reply dispatches — keeps fire-and-forget rows
# byte-identical to before this field existed.
# Each optional key is set only when present — keeps fire-and-forget rows
# byte-identical to before these fields existed.
if reply_key is not None:
message["reply_key"] = reply_key
if on_success is not None:
message["on_success"] = on_success

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important — on_success/on_error bypass _json_safe. args/kwargs go through _json_safe(...) above (lines 96-97), but the continuation specs (each with a nested kwargs dict) are written into the JSONField verbatim. If a callback's kwargs carry a UUID/datetime, PgQueueMessage.objects.create raises at insert time — and unlike the worker path this is caller-visible at dispatch. Fix: message["on_success"] = _json_safe(on_success) (and likewise on_error).

if on_error is not None:
message["on_error"] = on_error
if task_id is not None:
message["task_id"] = task_id
# Mirror the worker _enqueue_pg path: log the failure with breadcrumbs before
# it propagates, so a DB/constraint/serialization error isn't mislabeled by
# the caller's broad handler.
Expand Down
105 changes: 100 additions & 5 deletions backend/pg_queue/tests/test_executor_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

from unittest.mock import MagicMock, patch

import pytest
from pg_queue.executor_rpc import (
PgExecutionDispatcher,
RoutingExecutionDispatcher,
_signature_to_spec,
resolve_executor_transport,
)

Expand Down Expand Up @@ -221,11 +223,104 @@ def test_gate_on_dispatch_uses_pg(self):
pg.dispatch.assert_called_once()
celery.dispatch.assert_not_called()

def test_async_and_callback_stay_celery_when_gate_off(self):
    """Zero-regression: gate off → async/callback delegate to Celery unchanged."""
    dispatcher, celery, pg = self._build()
    # Each mode is dispatched exactly once, inside the patched gate, so the
    # ``assert_called_once`` checks below count only the gated-off calls.
    with patch(f"{_MOD}.resolve_executor_transport", return_value=False):
        dispatcher.dispatch_async(_ctx(), headers={"h": 1})
        dispatcher.dispatch_with_callback(_ctx(), on_success="s", on_error="e")
    celery.dispatch_async.assert_called_once()
    celery.dispatch_with_callback.assert_called_once()
    # Nothing may reach the PG dispatcher while the gate is off.
    pg.dispatch.assert_not_called()
    pg.dispatch_async.assert_not_called()
    pg.dispatch_with_callback.assert_not_called()

def test_async_and_callback_route_to_pg_when_gated(self):
    """Gate on (③c) → async/callback take the PG self-chained path."""
    routed, celery_side, pg_side = self._build()
    gate = patch(f"{_MOD}.resolve_executor_transport", return_value=True)
    with gate:
        routed.dispatch_async(_ctx())
        routed.dispatch_with_callback(
            _ctx(), on_success="s", on_error="e", task_id="t"
        )

    # PG saw exactly one call per mode, and was never handed ``headers``.
    pg_side.dispatch_async.assert_called_once()
    pg_side.dispatch_with_callback.assert_called_once()
    forwarded = pg_side.dispatch_with_callback.call_args.kwargs
    assert "headers" not in forwarded

    # Celery stayed untouched while the gate was on.
    celery_side.dispatch_async.assert_not_called()
    celery_side.dispatch_with_callback.assert_not_called()


class TestSignatureToSpec:
    """Celery ``Signature`` → serialisable continuation spec (the §5 wire-form)."""

    def test_none_passes_through(self):
        # "No callback for this outcome" must stay ``None`` after translation.
        spec = _signature_to_spec(None)
        assert spec is None

    def test_translates_task_kwargs_and_queue(self):
        # Only task, kwargs and options["queue"] are read off the signature.
        signature = MagicMock(
            task="ide_prompt_complete",
            kwargs={"callback_kwargs": {"room": "r1"}},
            options={"queue": "ide_callback"},
        )
        expected = {
            "task_name": "ide_prompt_complete",
            "kwargs": {"callback_kwargs": {"room": "r1"}},
            "queue": "ide_callback",
        }
        assert _signature_to_spec(signature) == expected

    def test_missing_queue_fails_fast(self):
        # Empty options means no queue, which is a configuration error.
        queueless = MagicMock(task="ide_prompt_complete", kwargs={}, options={})
        with pytest.raises(ValueError, match="no queue"):
            _signature_to_spec(queueless)


class TestPgAsyncCallbackWiring:
    """PG fire-and-forget + self-chained-callback enqueue shapes (``enqueue_task``
    mocked). Pins that the async path carries NO reply_key and the callback path
    carries the translated continuations + the tracking task_id.
    """

    @staticmethod
    def _ctx():
        # Stand-in execution context exposing only what the dispatcher reads.
        ctx = MagicMock(executor_name="legacy", run_id="r", organization_id="org9")
        ctx.to_dict.return_value = {"run_id": "r", "organization_id": "org9"}
        return ctx

    def test_dispatch_async_is_fire_and_forget(self):
        with patch(f"{_MOD}.enqueue_task") as enqueue:
            returned_id = PgExecutionDispatcher().dispatch_async(self._ctx())
        sent = enqueue.call_args.kwargs
        assert sent["task_name"] == "execute_extraction"
        assert sent["queue"] == "celery_executor_legacy"
        assert sent["org_id"] == "org9"
        assert sent["task_id"] == returned_id
        # Fire-and-forget: neither a reply channel nor a continuation is sent.
        assert "reply_key" not in sent
        assert "on_success" not in sent

    def test_dispatch_with_callback_carries_continuations(self):
        def _callback(task):
            # Both callbacks share kwargs and queue; only the task name differs.
            return MagicMock(
                task=task,
                kwargs={"callback_kwargs": {"room": "r1"}},
                options={"queue": "ide_callback"},
            )

        success_sig = _callback("ide_prompt_complete")
        error_sig = _callback("ide_prompt_error")
        with patch(f"{_MOD}.enqueue_task") as enqueue:
            handle = PgExecutionDispatcher().dispatch_with_callback(
                self._ctx(),
                on_success=success_sig,
                on_error=error_sig,
                task_id="tid-7",
            )
        assert handle.id == "tid-7"  # call sites read .id off the handle
        sent = enqueue.call_args.kwargs
        assert sent["on_success"] == {
            "task_name": "ide_prompt_complete",
            "kwargs": {"callback_kwargs": {"room": "r1"}},
            "queue": "ide_callback",
        }
        assert sent["on_error"]["task_name"] == "ide_prompt_error"
        assert sent["task_id"] == "tid-7"
        assert "reply_key" not in sent
Loading