Files
HadTavern/tests/test_while_nodes.py

134 lines
4.5 KiB
Python

import asyncio
from agentui.pipeline.executor import PipelineExecutor
from tests.utils import ctx as _ctx
async def scenario_providercall_while_ignore():
    """ProviderCall node looping via ``while_expr`` with ``ignore_errors`` on.

    No ``base_url`` is provided, to force an ExecutionError inside
    ``node.run()``. The wrapper is expected to catch it and expose
    ``{"error": "..."}`` plus vars instead of raising. The loop condition
    ``cycleindex < 3`` should run iterations 0, 1 and 2, so the last
    recorded cycle index is expected to be 2.
    """
    p = {
        "id": "p_pc_while_ignore",
        "name": "ProviderCall while+ignore",
        "loop_mode": "dag",
        "nodes": [
            {
                "id": "n2",
                "type": "ProviderCall",
                "config": {
                    "provider": "openai",
                    # while: 3 iterations (0, 1, 2)
                    "while_expr": "cycleindex < 3",
                    "while_max_iters": 10,
                    "ignore_errors": True,
                    # no base_url / provider_configs, to trigger the error safely
                },
                "in": {},
            }
        ],
    }
    out = await PipelineExecutor(p).run(_ctx())
    assert isinstance(out, dict), f"expected dict result, got {type(out).__name__}: {out!r}"
    # Per the original note, the executor also merges .vars into STORE; here we
    # assert on the returned out's "vars" map.
    vars_map = out.get("vars") or {}
    assert isinstance(vars_map, dict), f"expected dict 'vars', got {vars_map!r}"
    # The error must have been swallowed and flagged rather than raised...
    assert vars_map.get("WAS_ERROR__n2") is True, f"WAS_ERROR__n2 not set; vars={vars_map!r}"
    # ...and the final iteration index should be 2.
    assert vars_map.get("CYCLEINDEX__n2") == 2, f"unexpected CYCLEINDEX__n2; vars={vars_map!r}"
async def scenario_rawforward_while_ignore():
    """RawForward node looping via ``while_expr`` with ``ignore_errors`` on.

    No ``base_url`` is configured, and ``incoming.json`` is a plain string.
    Per the original notes, vendor detection should then yield "unknown"
    and raise an ExecutionError, which the wrapper catches, returning
    ``{"error": "..."}`` with vars set. The loop condition ``cycleindex < 2``
    should run iterations 0 and 1, so the last recorded cycle index is
    expected to be 1.
    """
    p = {
        "id": "p_rf_while_ignore",
        "name": "RawForward while+ignore",
        "loop_mode": "dag",
        "nodes": [
            {
                "id": "n1",
                "type": "RawForward",
                "config": {
                    "while_expr": "cycleindex < 2",
                    "while_max_iters": 10,
                    "ignore_errors": True,
                    # no base_url; vendor detection is expected to fail on plain text
                },
                "in": {},
            }
        ],
    }
    ctx = _ctx()
    # Provide incoming as a plain text-like body so vendor detection cannot
    # recognise a provider payload.
    ctx["incoming"] = {
        "method": "POST",
        "url": "http://example.local/test",
        "path": "/test",
        "query": "",
        "headers": {"content-type": "text/plain"},
        "json": "raw-plain-body-simulated",
    }
    out = await PipelineExecutor(p).run(ctx)
    assert isinstance(out, dict), f"expected dict result, got {type(out).__name__}: {out!r}"
    vars_map = out.get("vars") or {}
    assert isinstance(vars_map, dict), f"expected dict 'vars', got {vars_map!r}"
    # The error must have been swallowed and flagged rather than raised...
    assert vars_map.get("WAS_ERROR__n1") is True, f"WAS_ERROR__n1 not set; vars={vars_map!r}"
    # ...and the final iteration index should be 1 (iterations 0 and 1).
    assert vars_map.get("CYCLEINDEX__n1") == 1, f"unexpected CYCLEINDEX__n1; vars={vars_map!r}"
async def scenario_providercall_while_with_out_macro():
    """ProviderCall ``while_expr`` that reads another node's output via an OUT macro.

    A SetVars node (n1) sets ``MSG = "abc123xyz"``. The ProviderCall node
    (n2) depends on n1 and loops while
    ``([[OUT:n1.vars.MSG]] contains "123") && (cycleindex < 2)``.
    ``ignore_errors`` is on, so the unconfigured provider call is expected
    to be absorbed instead of making a real HTTP request. Since MSG contains
    "123", iterations 0 and 1 should run, leaving the final cycle index at 1.
    """
    p = {
        "id": "p_pc_while_out_macro",
        "name": "ProviderCall while with OUT macro",
        "loop_mode": "iterative",
        "nodes": [
            {
                "id": "n1",
                "type": "SetVars",
                "config": {
                    "variables": [
                        {"id": "v1", "name": "MSG", "mode": "string", "value": "abc123xyz"},
                    ]
                },
                "in": {},
            },
            {
                "id": "n2",
                "type": "ProviderCall",
                "config": {
                    "provider": "openai",
                    "while_expr": "([[OUT:n1.vars.MSG]] contains \"123\") && (cycleindex < 2)",
                    "while_max_iters": 10,
                    "ignore_errors": True,
                },
                "in": {
                    "depends": "n1.done",
                },
            },
        ],
    }
    out = await PipelineExecutor(p).run(_ctx())
    assert isinstance(out, dict), f"expected dict result, got {type(out).__name__}: {out!r}"
    vars_map = out.get("vars") or {}
    assert isinstance(vars_map, dict), f"expected dict 'vars', got {vars_map!r}"
    # The error must have been swallowed and flagged rather than raised...
    assert vars_map.get("WAS_ERROR__n2") is True, f"WAS_ERROR__n2 not set; vars={vars_map!r}"
    # ...and, because MSG contains "123" and cycleindex < 2, two iterations (0, 1) ran.
    assert vars_map.get("CYCLEINDEX__n2") == 1, f"unexpected CYCLEINDEX__n2; vars={vars_map!r}"
def run_all():
    """Execute every while-loop scenario, in order, inside one ``asyncio.run`` call.

    Any failing assertion propagates out of ``asyncio.run``. A completion
    banner is printed only after all scenarios pass.
    """

    async def _run_scenarios():
        # Same order as declared in this module; each scenario is awaited
        # to completion before the next one starts.
        for scenario in (
            scenario_providercall_while_ignore,
            scenario_rawforward_while_ignore,
            scenario_providercall_while_with_out_macro,
        ):
            await scenario()
        print("\n=== WHILE_NODES: DONE ===")

    asyncio.run(_run_scenarios())


if __name__ == "__main__":
    run_all()