HadTavern/tests/test_prompt_combine.py

import asyncio
import json
from typing import Any, Dict, List
from agentui.pipeline.executor import PipelineExecutor
import agentui.providers.http_client as hc
from tests.utils import ctx as _ctx, pp as _pp
# Capture of all outbound ProviderCall HTTP requests (one per run)
CAPTURED: List[Dict[str, Any]] = []

class DummyResponse:
    def __init__(self, status_code: int = 200, body: Dict[str, Any] | None = None):
        self.status_code = status_code
        self._json = body if body is not None else {"ok": True}
        self.headers = {}
        try:
            self.content = json.dumps(self._json, ensure_ascii=False).encode("utf-8")
        except Exception:
            self.content = b"{}"
        try:
            self.text = json.dumps(self._json, ensure_ascii=False)
        except Exception:
            self.text = "{}"

    def json(self) -> Any:
        return self._json
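
# DummyResponse mimics only the response surface this test suite needs
# (.status_code, .headers, .content, .text, .json()); presumably this is the
# subset of the real HTTP client's response that ProviderCall's extractor reads.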

class DummyClient:
    def __init__(self, capture: List[Dict[str, Any]], status_code: int = 200):
        self._capture = capture
        self._status = status_code

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False

    async def post(self, url: str, content: bytes, headers: Dict[str, str]):
        try:
            payload = json.loads(content.decode("utf-8"))
        except Exception:
            payload = {"_raw": content.decode("utf-8", errors="ignore")}
        rec = {"url": url, "headers": headers, "payload": payload}
        self._capture.append(rec)
        # Echo the payload back to keep the extractor happy without tying the test to vendor formats
        return DummyResponse(self._status, {"echo": rec})

    # RawForward may use .request, but we don't need it here
    async def request(self, method: str, url: str, headers: Dict[str, str], content: bytes | None):
        return await self.post(url, content or b"{}", headers)
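
# DummyClient covers just the async context-manager plus post/request surface the
# executor appears to rely on, e.g. (sketch of the assumed call shape):
#
#   async with build_client(timeout=60.0) as client:
#       resp = await client.post(url, content=b"{}", headers={})
#
# Any real client exposing the same shape (httpx.AsyncClient, for instance) would
# slot in without changes to the tests.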

def _patch_http_client():
    """Monkeypatch build_client used by ProviderCall to our dummy."""
    hc.build_client = lambda timeout=60.0: DummyClient(CAPTURED, 200)  # type: ignore[assignment]
    # Also patch the symbol imported inside executor so ProviderCall uses DummyClient
    import agentui.pipeline.executor as ex  # type: ignore
    ex.build_client = lambda timeout=60.0: DummyClient(CAPTURED, 200)  # type: ignore
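
# Why two patches: if executor.py did `from agentui.providers.http_client import
# build_client`, that from-import bound its own copy of the name, and rebinding
# hc.build_client alone would not affect it. Minimal illustration (hypothetical
# module names):
#
#   # mod_a.py: def f(): return "real"
#   # mod_b.py: from mod_a import f
#   mod_a.f = lambda: "patched"   # mod_b.f still returns "real"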

def _mk_pipeline(provider: str, prompt_combine: str) -> Dict[str, Any]:
    """Build a minimal ProviderCall-only pipeline for a given provider and combine spec."""
    provider = provider.lower().strip()
    if provider not in {"openai", "gemini", "claude"}:
        raise AssertionError(f"Unsupported provider in test: {provider}")
    base_url = "http://mock.local"
    if provider == "openai":
        endpoint = "/v1/chat/completions"
        template = '{ "model": "{{ model }}", [[PROMPT]] }'
    elif provider == "gemini":
        endpoint = "/v1beta/models/{{ model }}:generateContent"
        template = '{ "model": "{{ model }}", [[PROMPT]] }'
    else:  # claude
        endpoint = "/v1/messages"
        template = '{ "model": "{{ model }}", [[PROMPT]] }'
    p = {
        "id": f"p_prompt_combine_{provider}",
        "name": f"prompt_combine to {provider}",
        "loop_mode": "dag",
        "nodes": [
            {
                "id": "n1",
                "type": "ProviderCall",
                "config": {
                    "provider": provider,
                    "provider_configs": {
                        provider: {
                            "base_url": base_url,
                            "endpoint": endpoint,
                            "headers": "{}",
                            "template": template,
                        }
                    },
                    # Key under test:
                    "prompt_combine": prompt_combine,
                    # Prompt Blocks (PROMPT)
                    "blocks": [
                        {"id": "b1", "name": "sys", "role": "system", "prompt": "You are Narrator-chan.", "enabled": True, "order": 0},
                        {"id": "b2", "name": "user", "role": "user", "prompt": "how are you", "enabled": True, "order": 1},
                    ],
                },
                "in": {},
            }
        ],
    }
    return p
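
# The prompt_combine mini-spec as exercised below (inferred from these tests, not
# a normative grammar): sources are joined with "&" and evaluated left to right.
#
#   spec   := source (" & " source)*
#   source := "[[PROMPT]]" | "[[VAR:" path "]]"   (optionally suffixed "@pos=prepend")
#
# Examples used in this file:
#   "[[VAR:incoming.json.contents]] & [[PROMPT]]"
#   "[[PROMPT]]@pos=prepend & [[VAR:incoming.json.contents]]"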

def _ctx_with_incoming(incoming_json: Dict[str, Any], vendor: str = "openai") -> Dict[str, Any]:
    base = _ctx(vendor=vendor)
    inc = dict(base["incoming"])
    inc["json"] = incoming_json
    base["incoming"] = inc
    return base

async def scenario_openai_target_from_gemini_contents():
    print("\n=== PROMPT_COMBINE 1: target=openai, incoming=gemini.contents & PROMPT ===")
    _patch_http_client()
    CAPTURED.clear()
    # Incoming JSON in Gemini shape
    incoming_json = {
        "contents": [
            {"role": "user", "parts": [{"text": "Hi"}]},
            {"role": "model", "parts": [{"text": "Hello to you too!"}]},
        ]
    }
    p = _mk_pipeline("openai", "[[VAR:incoming.json.contents]] & [[PROMPT]]")
    out = await PipelineExecutor(p).run(_ctx_with_incoming(incoming_json, vendor="gemini"))
    print("PIPE OUT:", _pp(out))
    assert CAPTURED, "No HTTP request captured"
    req = CAPTURED[-1]
    payload = req["payload"]
    # Validate OpenAI body
    assert "messages" in payload, "OpenAI payload must contain messages"
    msgs = payload["messages"]
    # Expected: 2 (converted Gemini) + 2 (PROMPT blocks system+user) = 4
    assert isinstance(msgs, list) and len(msgs) == 4
    roles = [m.get("role") for m in msgs]
    # Gemini "model" role maps to OpenAI "assistant"
    assert "assistant" in roles and "user" in roles
    # PROMPT system+user are present (system may not be first without @pos; we only check existence)
    assert any(m.get("role") == "system" for m in msgs), "System message from PROMPT must be present"
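
# For reference, the OpenAI-shaped body this scenario expects looks roughly like
# the sketch below (the relative order of the four messages is deliberately not
# asserted above, only role presence):
#
#   {
#     "model": "...",
#     "messages": [
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello to you too!"},
#       {"role": "system", "content": "You are Narrator-chan."},
#       {"role": "user", "content": "how are you"}
#     ]
#   }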

async def scenario_gemini_target_from_openai_messages():
    print("\n=== PROMPT_COMBINE 2: target=gemini, incoming=openai.messages & PROMPT ===")
    _patch_http_client()
    CAPTURED.clear()
    incoming_json = {
        "messages": [
            {"role": "system", "content": "System test from incoming"},
            {"role": "user", "content": "It's just me.."},
            {"role": "assistant", "content": "Reply from model"},
        ]
    }
    p = _mk_pipeline("gemini", "[[VAR:incoming.json.messages]] & [[PROMPT]]")
    out = await PipelineExecutor(p).run(_ctx_with_incoming(incoming_json, vendor="openai"))
    print("PIPE OUT:", _pp(out))
    assert CAPTURED, "No HTTP request captured"
    payload = CAPTURED[-1]["payload"]
    # Validate Gemini body
    assert "contents" in payload, "Gemini payload must contain contents"
    cnts = payload["contents"]
    assert isinstance(cnts, list)
    # The PROMPT system block goes to systemInstruction; the user block goes to contents
    assert "systemInstruction" in payload, "Gemini payload must contain systemInstruction when system text exists"
    si = payload["systemInstruction"]
    # systemInstruction.parts[].text must include both the incoming system and the PROMPT system, merged
    si_texts = []
    try:
        for prt in si.get("parts", []):
            t = prt.get("text")
            if isinstance(t, str) and t.strip():
                si_texts.append(t.strip())
    except Exception:
        pass
    joined = "\n".join(si_texts)
    assert "System test from incoming" in joined, "Incoming system must be merged into systemInstruction"
    assert "Narrator-chan" in joined, "PROMPT system must be merged into systemInstruction"

async def scenario_claude_target_from_openai_messages():
    print("\n=== PROMPT_COMBINE 3: target=claude, incoming=openai.messages & PROMPT ===")
    _patch_http_client()
    CAPTURED.clear()
    incoming_json = {
        "messages": [
            {"role": "system", "content": "CLAUDE system test"},
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
        ]
    }
    p = _mk_pipeline("claude", "[[VAR:incoming.json.messages]] & [[PROMPT]]")
    out = await PipelineExecutor(p).run(_ctx_with_incoming(incoming_json, vendor="openai"))
    print("PIPE OUT:", _pp(out))
    assert CAPTURED, "No HTTP request captured"
    payload = CAPTURED[-1]["payload"]
    # Validate Claude body
    assert "messages" in payload, "Claude payload must contain messages"
    assert "system" in payload, "Claude payload must contain system blocks"
    sys_blocks = payload["system"]
    # system must be a list of blocks with type="text"
    assert isinstance(sys_blocks, list) and any(isinstance(b, dict) and b.get("type") == "text" for b in sys_blocks)
    sys_text_join = "\n".join([b.get("text") for b in sys_blocks if isinstance(b, dict) and isinstance(b.get("text"), str)])
    assert "CLAUDE system test" in sys_text_join, "Incoming system should be present"
    assert "Narrator-chan" in sys_text_join, "PROMPT system should be present"

async def scenario_prepend_positioning_openai():
    print("\n=== PROMPT_COMBINE 4: target=openai, PROMPT@pos=prepend & incoming.contents ===")
    _patch_http_client()
    CAPTURED.clear()
    incoming_json = {
        "contents": [
            {"role": "user", "parts": [{"text": "A"}]},
            {"role": "model", "parts": [{"text": "B"}]},
        ]
    }
    # Put PROMPT first; the system message should become the first entry in messages
    p = _mk_pipeline("openai", "[[PROMPT]]@pos=prepend & [[VAR:incoming.json.contents]]")
    out = await PipelineExecutor(p).run(_ctx_with_incoming(incoming_json, vendor="gemini"))
    print("PIPE OUT:", _pp(out))
    assert CAPTURED, "No HTTP request captured"
    payload = CAPTURED[-1]["payload"]
    msgs = payload.get("messages", [])
    assert isinstance(msgs, list) and len(msgs) >= 2
    first = msgs[0]
    # Expect the first message to be system (from PROMPT) due to prepend
    assert first.get("role") == "system", f"Expected system as first message, got {first}"

def test_prompt_combine_all():
    async def main():
        await scenario_openai_target_from_gemini_contents()
        await scenario_gemini_target_from_openai_messages()
        await scenario_claude_target_from_openai_messages()
        await scenario_prepend_positioning_openai()
        print("\n=== PROMPT_COMBINE: DONE ===")

    asyncio.run(main())
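
# Run through pytest; -s keeps the printed payload dumps visible, e.g. (assuming
# the working directory is the project root, HadTavern/):
#   pytest -s tests/test_prompt_combine.py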