Files
HadTavern/tests/test_cancel_modes.py
2025-10-03 21:55:24 +03:00

199 lines
7.0 KiB
Python

import asyncio
import json
from typing import Any, Dict
from agentui.pipeline.executor import PipelineExecutor, ExecutionError
from agentui.common.cancel import request_cancel, clear_cancel
import agentui.providers.http_client as hc
import agentui.pipeline.executor as ex
from tests.utils import ctx as _ctx
class DummyResponse:
    """Minimal stand-in for an HTTP response carrying a fixed JSON payload.

    Exposes ``status_code``, ``headers`` (always an empty dict), ``text``,
    ``content`` (UTF-8 bytes of ``text``) and ``json()``. When the payload
    cannot be turned into bytes, ``text``/``content`` fall back to an empty
    JSON object.
    """

    def __init__(self, status: int, json_obj: Dict[str, Any]) -> None:
        self.status_code = status
        self._json = json_obj
        self.headers = {}

        try:
            serialized = json.dumps(json_obj, ensure_ascii=False)
        except Exception:
            serialized = None

        # ``text`` only needs dumps() to succeed. ``content`` additionally needs
        # the UTF-8 encode to succeed. Because ensure_ascii=False, that encode
        # can still fail, e.g. on lone surrogates.
        self.text = "{}" if serialized is None else serialized
        try:
            self.content = b"{}" if serialized is None else serialized.encode("utf-8")
        except Exception:
            self.content = b"{}"

    def json(self) -> Dict[str, Any]:
        """Return the payload object exactly as it was passed in (no copy)."""
        return self._json
class DummyClient:
"""
Async client with artificial delay to simulate in-flight HTTP that can be cancelled.
Provides .post() and .request() compatible with executor usage.
"""
def __init__(self, delay: float = 0.3, status_code: int = 200) -> None:
self._delay = delay
self._status = status_code
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def post(self, url: str, content: bytes, headers: Dict[str, str]):
# Artificial delay to allow cancel/abort to happen while awaiting
await asyncio.sleep(self._delay)
try:
payload = json.loads(content.decode("utf-8"))
except Exception:
payload = {"_raw": content.decode("utf-8", errors="ignore")}
return DummyResponse(self._status, {"echo": payload})
async def request(self, method: str, url: str, headers: Dict[str, str], content: bytes | None):
return await self.post(url, content or b"{}", headers)
def _patch_http_client(delay: float = 0.3):
    """
    Replace ``build_client`` in both ``agentui.providers.http_client`` and
    ``agentui.pipeline.executor`` with a factory that returns
    ``DummyClient(delay=delay)``.

    Returns the two replaced callables as ``(orig_hc, orig_ex)``; pass them to
    ``_restore_http_client`` to undo the patch.
    """
    saved = (hc.build_client, ex.build_client)

    def _fake_build_client(timeout=60.0):
        # ``timeout`` is accepted so callers passing it keep working; it is ignored.
        return DummyClient(delay=delay)

    hc.build_client = _fake_build_client  # type: ignore[assignment]
    ex.build_client = _fake_build_client  # type: ignore[assignment]
    return saved
def _restore_http_client(orig_hc, orig_ex) -> None:
    """Put back the ``build_client`` callables saved by ``_patch_http_client``."""
    hc.build_client, ex.build_client = orig_hc, orig_ex
def _provider_while_pipeline(pid: str, name: str, *, ignore_errors: bool = False) -> Dict[str, Any]:
    """
    Build a one-node pipeline: ProviderCall node ``n2`` looping while
    ``cycleindex < 5`` (capped at 10 iterations) against a dummy OpenAI endpoint.

    ``ignore_errors=True`` sets the node option that turns a cancellation
    exception into an error payload instead of failing the run.
    """
    config: Dict[str, Any] = {
        "provider": "openai",
        "while_expr": "cycleindex < 5",
        "while_max_iters": 10,
    }
    if ignore_errors:
        config["ignore_errors"] = True
    config["provider_configs"] = {
        "openai": {
            "base_url": "http://dummy.local",
            "headers": "{}",
            "template": "{}",
        }
    }
    return {
        "id": pid,
        "name": name,
        "loop_mode": "dag",
        "nodes": [{"id": "n2", "type": "ProviderCall", "config": config, "in": {}}],
    }


def _clear_cancel_quietly(pid: str) -> None:
    """Best-effort reset of the cancel flag for ``pid``; failures are deliberately ignored."""
    try:
        clear_cancel(pid)
    except Exception:
        pass


async def _run_with_cancel(
    p: Dict[str, Any],
    mode: str,
    *,
    delay: float = 0.3,
    cancel_after: float = 0.05,
) -> Any:
    """
    Run pipeline ``p`` with the stubbed HTTP client. After ``cancel_after``
    seconds, while the ``delay``-second fake request is still pending,
    request cancellation with ``mode``. Return what ``PipelineExecutor.run``
    returns.
    """
    pid = p["id"]
    # A cancel flag left over from an earlier run with the same id could stop
    # this run before the first request starts, so start from a clean state.
    _clear_cancel_quietly(pid)
    orig_hc, orig_ex = _patch_http_client(delay=delay)
    task = None
    try:
        ctx = _ctx()
        exr = PipelineExecutor(p)
        task = asyncio.create_task(exr.run(ctx))
        # Give the node time to start its HTTP call, then request cancellation.
        await asyncio.sleep(cancel_after)
        request_cancel(pid, mode=mode)
        return await task
    finally:
        # If something failed before the run finished, stop the task so it
        # cannot keep executing after the real build_client is restored.
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except (asyncio.CancelledError, Exception):
                pass
        _restore_http_client(orig_hc, orig_ex)
        _clear_cancel_quietly(pid)


def test_graceful_cancel_while_providercall():
    """
    Expectation:
    - Cancel(mode=graceful) during an in-flight HTTP call must NOT interrupt that request.
    - The while-wrapper stops before starting the next iteration.
    - Final CYCLEINDEX__n2 == 0 (only the first iteration finished); WAS_ERROR__n2 is False/absent.
    """
    # ignore_errors is not needed: graceful mode never interrupts the in-flight call.
    p = _provider_while_pipeline("p_cancel_soft", "ProviderCall graceful cancel")
    out = asyncio.run(_run_with_cancel(p, "graceful"))

    assert isinstance(out, dict), out
    vars_map = out.get("vars") or {}
    assert isinstance(vars_map, dict), vars_map
    # Only the first iteration should have finished; last index = 0.
    assert vars_map.get("CYCLEINDEX__n2") == 0, vars_map
    # No error expected: the in-flight HTTP call was allowed to complete.
    assert vars_map.get("WAS_ERROR__n2") in (False, None), vars_map


def test_abort_cancel_inflight_providercall():
    """
    Expectation:
    - Cancel(mode=abort) during an in-flight HTTP call cancels the await with ExecutionError.
    - The while-wrapper with ignore_errors=True converts it into {"result": {"error": ...}}.
    - Final CYCLEINDEX__n2 == 0 and WAS_ERROR__n2 is True; the error mentions 'Cancelled by user (abort)'.
    """
    # ignore_errors=True turns the cancellation exception into an error payload.
    p = _provider_while_pipeline("p_cancel_abort", "ProviderCall abort cancel", ignore_errors=True)
    out = asyncio.run(_run_with_cancel(p, "abort"))

    assert isinstance(out, dict), out
    vars_map = out.get("vars") or {}
    assert isinstance(vars_map, dict), vars_map
    # The first iteration started; after the abort it counts as errored and the loop stops.
    assert vars_map.get("CYCLEINDEX__n2") == 0, vars_map
    assert vars_map.get("WAS_ERROR__n2") is True, vars_map
    # The error is propagated into the node's result (ignore_errors=True path).
    res = out.get("result") or {}
    assert isinstance(res, dict), res
    err = res.get("error")
    assert isinstance(err, str) and "Cancelled by user (abort)" in err, err