199 lines
7.0 KiB
Python
199 lines
7.0 KiB
Python
import asyncio
|
|
import json
|
|
from typing import Any, Dict
|
|
|
|
from agentui.pipeline.executor import PipelineExecutor, ExecutionError
|
|
from agentui.common.cancel import request_cancel, clear_cancel
|
|
import agentui.providers.http_client as hc
|
|
import agentui.pipeline.executor as ex
|
|
from tests.utils import ctx as _ctx
|
|
|
|
|
|
class DummyResponse:
    """
    Canned HTTP response object handed back by ``DummyClient``.

    Exposes response-like attributes: ``status_code``, ``headers`` (always an
    empty dict), ``content`` (UTF-8 bytes), ``text`` and ``json()``.
    When ``json_obj`` cannot be serialized, both ``text`` and ``content``
    fall back to an empty JSON object ("{}").
    """

    def __init__(self, status: int, json_obj: Dict[str, Any]) -> None:
        self.status_code = status
        self._json = json_obj
        self.headers = {}

        try:
            serialized = json.dumps(json_obj, ensure_ascii=False)
        except Exception:
            serialized = None

        if serialized is None:
            # Not JSON-serializable: both textual views degrade to "{}".
            self.content = b"{}"
            self.text = "{}"
            return

        # Encoding is guarded separately: a lone surrogate passes through
        # dumps(ensure_ascii=False) but cannot be UTF-8 encoded, in which case
        # only ``content`` falls back while ``text`` keeps the serialized form.
        try:
            self.content = serialized.encode("utf-8")
        except Exception:
            self.content = b"{}"
        self.text = serialized

    def json(self) -> Dict[str, Any]:
        """Return the payload object given to the constructor (the same object, not a copy)."""
        return self._json
|
|
|
|
|
|
class DummyClient:
    """
    Async stand-in for an HTTP client whose calls are delayed artificially,
    so a test can cancel/abort a pipeline while a request is "in flight".

    Provides ``post()`` and ``request()`` for the executor and supports
    ``async with``. Every call answers with a ``DummyResponse`` whose JSON
    body is ``{"echo": <parsed request body>}``.
    """

    def __init__(self, delay: float = 0.3, status_code: int = 200) -> None:
        # Seconds each call sleeps before answering (the cancellation window).
        self._delay = delay
        # HTTP status reported by every response this client produces.
        self._status = status_code

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Returning False lets exceptions raised inside ``async with`` propagate.
        return False

    async def post(
        self,
        url: str,
        content: bytes | str,
        headers: Dict[str, str] | None = None,
    ):
        """
        Sleep for the configured delay, then echo the request body.

        Args:
            url: Ignored; accepted for call compatibility.
            content: Request body as bytes or str (httpx accepts both).
            headers: Ignored; optional for call compatibility.

        Returns:
            DummyResponse with body ``{"echo": payload}``. ``payload`` is the
            parsed JSON when the body is valid (UTF-8) JSON. Otherwise it is
            ``{"_raw": <body text>}``; for bytes bodies, undecodable bytes are
            dropped.
        """
        # Suspend first so a cancel/abort can land while the "request" is pending.
        await asyncio.sleep(self._delay)
        try:
            text = content if isinstance(content, str) else content.decode("utf-8")
            payload = json.loads(text)
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError both derive from ValueError.
            if isinstance(content, str):
                raw_text = content
            else:
                raw_text = content.decode("utf-8", errors="ignore")
            payload = {"_raw": raw_text}
        return DummyResponse(self._status, {"echo": payload})

    async def request(
        self,
        method: str,
        url: str,
        headers: Dict[str, str] | None = None,
        content: bytes | str | None = None,
    ):
        """Route any method through ``post()``; a missing or empty body is sent as ``b"{}"``."""
        return await self.post(url, content or b"{}", headers)
|
|
|
|
|
|
def _patch_http_client(delay: float = 0.3):
    """
    Replace ``build_client`` in both ``agentui.providers.http_client`` and
    ``agentui.pipeline.executor`` with a factory producing ``DummyClient(delay=delay)``.

    Returns:
        ``(orig_hc, orig_ex)``: the original callables, to be passed to
        ``_restore_http_client`` afterwards.
    """
    saved = (hc.build_client, ex.build_client)

    def _fake_build_client(timeout=60.0):
        # ``timeout`` only mirrors the patched signature; it is not used.
        return DummyClient(delay=delay)

    hc.build_client = _fake_build_client  # type: ignore[assignment]
    ex.build_client = _fake_build_client  # type: ignore[assignment]
    return saved
|
|
|
|
|
|
def _restore_http_client(orig_hc, orig_ex) -> None:
    """Reinstall the ``build_client`` callables saved by ``_patch_http_client``."""
    hc.build_client, ex.build_client = orig_hc, orig_ex
|
|
|
|
|
|
def test_graceful_cancel_while_providercall():
    """
    Graceful cancel requested while a ProviderCall HTTP request is in flight.

    Expectation:
    - Cancel(mode=graceful) during in-flight HTTP should NOT interrupt the current request.
    - While-wrapper should stop before starting next iteration.
    - Final CYCLEINDEX__n2 == 0 (only first iteration finished), WAS_ERROR__n2 is False/absent.
    """

    async def _scenario():
        openai_cfg = {
            "base_url": "http://dummy.local",
            "headers": "{}",
            "template": "{}",
        }
        node_config = {
            "provider": "openai",
            "while_expr": "cycleindex < 5",
            "while_max_iters": 10,
            # ignore_errors is not needed: a graceful cancel does not interrupt the in-flight call.
            "provider_configs": {"openai": openai_cfg},
        }
        pipeline = {
            "id": "p_cancel_soft",
            "name": "ProviderCall graceful cancel",
            "loop_mode": "dag",
            "nodes": [
                {"id": "n2", "type": "ProviderCall", "config": node_config, "in": {}},
            ],
        }
        pipeline_id = pipeline["id"]

        saved_clients = _patch_http_client(delay=0.3)
        try:
            run_ctx = _ctx()
            executor = PipelineExecutor(pipeline)
            run_task = asyncio.create_task(executor.run(run_ctx))
            # Let the node get into its (delayed) HTTP call, then ask for a graceful stop.
            await asyncio.sleep(0.05)
            request_cancel(pipeline_id, mode="graceful")
            result = await run_task
        finally:
            _restore_http_client(*saved_clients)
            try:
                clear_cancel(pipeline_id)
            except Exception:
                pass  # best-effort cleanup; must not mask the test outcome

        assert isinstance(result, dict)
        variables = result.get("vars") or {}
        assert isinstance(variables, dict)
        # Only the first iteration completes, so the last cycle index is 0.
        assert variables.get("CYCLEINDEX__n2") == 0
        # The in-flight HTTP call was not interrupted, so no error is recorded.
        assert variables.get("WAS_ERROR__n2") in (False, None)

    asyncio.run(_scenario())
|
|
|
|
|
|
def test_abort_cancel_inflight_providercall():
    """
    Hard (abort) cancel requested while a ProviderCall HTTP request is in flight.

    Expectation:
    - Cancel(mode=abort) during in-flight HTTP cancels the await with ExecutionError.
    - While-wrapper with ignore_errors=True converts it into {"result":{"error":...}}.
    - Final CYCLEINDEX__n2 == 0 and WAS_ERROR__n2 == True; error mentions 'Cancelled by user (abort)'.
    """

    async def _scenario():
        openai_cfg = {
            "base_url": "http://dummy.local",
            "headers": "{}",
            "template": "{}",
        }
        node_config = {
            "provider": "openai",
            "while_expr": "cycleindex < 5",
            "while_max_iters": 10,
            # Needed here: turns the abort-induced exception into an error payload.
            "ignore_errors": True,
            "provider_configs": {"openai": openai_cfg},
        }
        pipeline = {
            "id": "p_cancel_abort",
            "name": "ProviderCall abort cancel",
            "loop_mode": "dag",
            "nodes": [
                {"id": "n2", "type": "ProviderCall", "config": node_config, "in": {}},
            ],
        }
        pipeline_id = pipeline["id"]

        saved_clients = _patch_http_client(delay=0.3)
        try:
            run_ctx = _ctx()
            executor = PipelineExecutor(pipeline)
            run_task = asyncio.create_task(executor.run(run_ctx))
            # Let the node get into its (delayed) HTTP call, then abort it outright.
            await asyncio.sleep(0.05)
            request_cancel(pipeline_id, mode="abort")
            result = await run_task
        finally:
            _restore_http_client(*saved_clients)
            try:
                clear_cancel(pipeline_id)
            except Exception:
                pass  # best-effort cleanup; must not mask the test outcome

        assert isinstance(result, dict)
        variables = result.get("vars") or {}
        assert isinstance(variables, dict)
        # The first iteration started; the abort marks it as errored and the loop stops.
        assert variables.get("CYCLEINDEX__n2") == 0
        assert variables.get("WAS_ERROR__n2") is True

        # With ignore_errors=True the error is surfaced in the node's result instead of raising.
        node_result = result.get("result") or {}
        assert isinstance(node_result, dict)
        error_text = node_result.get("error")
        assert isinstance(error_text, str) and "Cancelled by user (abort)" in error_text

    asyncio.run(_scenario())