HadTavern 0.01: Gemini/Claude fixes; UI _origId reuse; docs; .bat open
This commit is contained in:
@@ -3,13 +3,14 @@ import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
import json
|
||||
from urllib.parse import urlsplit, urlunsplit, parse_qsl, urlencode, unquote
|
||||
from fastapi.responses import JSONResponse, HTMLResponse
|
||||
from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from agentui.pipeline.executor import PipelineExecutor
|
||||
from agentui.pipeline.defaults import default_pipeline
|
||||
from agentui.pipeline.storage import load_pipeline, save_pipeline, list_presets, load_preset, save_preset
|
||||
from agentui.common.vendors import detect_vendor
|
||||
|
||||
|
||||
class UnifiedParams(BaseModel):
|
||||
@@ -38,17 +39,6 @@ class UnifiedChatRequest(BaseModel):
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def detect_vendor(payload: Dict[str, Any]) -> str:
|
||||
if "anthropic_version" in payload or payload.get("provider") == "anthropic":
|
||||
return "claude"
|
||||
# Gemini typical payload keys
|
||||
if "contents" in payload or "generationConfig" in payload:
|
||||
return "gemini"
|
||||
# OpenAI typical keys
|
||||
if "messages" in payload or "model" in payload:
|
||||
return "openai"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def normalize_to_unified(payload: Dict[str, Any]) -> UnifiedChatRequest:
|
||||
vendor = detect_vendor(payload)
|
||||
@@ -278,6 +268,34 @@ def create_app() -> FastAPI:
|
||||
logger.addHandler(stream_handler)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# --- Simple in-process SSE hub (subscriptions per browser tab) ---
|
||||
import asyncio as _asyncio
|
||||
|
||||
class _SSEHub:
|
||||
def __init__(self) -> None:
|
||||
self._subs: List[_asyncio.Queue] = []
|
||||
|
||||
def subscribe(self) -> _asyncio.Queue:
|
||||
q: _asyncio.Queue = _asyncio.Queue()
|
||||
self._subs.append(q)
|
||||
return q
|
||||
|
||||
def unsubscribe(self, q: _asyncio.Queue) -> None:
|
||||
try:
|
||||
self._subs.remove(q)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def publish(self, event: Dict[str, Any]) -> None:
|
||||
# Fan-out to all subscribers; drop if queue is full
|
||||
for q in list(self._subs):
|
||||
try:
|
||||
await q.put(event)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_trace_hub = _SSEHub()
|
||||
|
||||
def _mask_headers(h: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Временно отключаем маскировку Authorization для отладки
|
||||
hidden = {"x-api-key", "cookie"}
|
||||
@@ -369,7 +387,15 @@ def create_app() -> FastAPI:
|
||||
macro_ctx = build_macro_context(unified, incoming=incoming)
|
||||
pipeline = load_pipeline()
|
||||
executor = PipelineExecutor(pipeline)
|
||||
last = await executor.run(macro_ctx)
|
||||
|
||||
async def _trace(evt: Dict[str, Any]) -> None:
|
||||
try:
|
||||
base = {"pipeline_id": pipeline.get("id", "pipeline_editor")}
|
||||
await _trace_hub.publish({**base, **evt})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
last = await executor.run(macro_ctx, trace=_trace)
|
||||
result = last.get("result") or await execute_pipeline_echo(unified)
|
||||
await _log_response(request, 200, result)
|
||||
return JSONResponse(result)
|
||||
@@ -402,7 +428,13 @@ def create_app() -> FastAPI:
|
||||
macro_ctx = build_macro_context(unified, incoming=incoming)
|
||||
pipeline = load_pipeline()
|
||||
executor = PipelineExecutor(pipeline)
|
||||
last = await executor.run(macro_ctx)
|
||||
async def _trace(evt: Dict[str, Any]) -> None:
|
||||
try:
|
||||
base = {"pipeline_id": pipeline.get("id", "pipeline_editor")}
|
||||
await _trace_hub.publish({**base, **evt})
|
||||
except Exception:
|
||||
pass
|
||||
last = await executor.run(macro_ctx, trace=_trace)
|
||||
result = last.get("result") or await execute_pipeline_echo(unified)
|
||||
await _log_response(request, 200, result)
|
||||
return JSONResponse(result)
|
||||
@@ -431,7 +463,13 @@ def create_app() -> FastAPI:
|
||||
macro_ctx = build_macro_context(unified, incoming=incoming)
|
||||
pipeline = load_pipeline()
|
||||
executor = PipelineExecutor(pipeline)
|
||||
last = await executor.run(macro_ctx)
|
||||
async def _trace(evt: Dict[str, Any]) -> None:
|
||||
try:
|
||||
base = {"pipeline_id": pipeline.get("id", "pipeline_editor")}
|
||||
await _trace_hub.publish({**base, **evt})
|
||||
except Exception:
|
||||
pass
|
||||
last = await executor.run(macro_ctx, trace=_trace)
|
||||
result = last.get("result") or await execute_pipeline_echo(unified)
|
||||
await _log_response(request, 200, result)
|
||||
return JSONResponse(result)
|
||||
@@ -465,7 +503,13 @@ def create_app() -> FastAPI:
|
||||
macro_ctx = build_macro_context(unified, incoming=incoming)
|
||||
pipeline = load_pipeline()
|
||||
executor = PipelineExecutor(pipeline)
|
||||
last = await executor.run(macro_ctx)
|
||||
async def _trace(evt: Dict[str, Any]) -> None:
|
||||
try:
|
||||
base = {"pipeline_id": pipeline.get("id", "pipeline_editor")}
|
||||
await _trace_hub.publish({**base, **evt})
|
||||
except Exception:
|
||||
pass
|
||||
last = await executor.run(macro_ctx, trace=_trace)
|
||||
result = last.get("result") or await execute_pipeline_echo(unified)
|
||||
await _log_response(request, 200, result)
|
||||
return JSONResponse(result)
|
||||
@@ -498,7 +542,13 @@ def create_app() -> FastAPI:
|
||||
macro_ctx = build_macro_context(unified, incoming=incoming)
|
||||
pipeline = load_pipeline()
|
||||
executor = PipelineExecutor(pipeline)
|
||||
last = await executor.run(macro_ctx)
|
||||
async def _trace(evt: Dict[str, Any]) -> None:
|
||||
try:
|
||||
base = {"pipeline_id": pipeline.get("id", "pipeline_editor")}
|
||||
await _trace_hub.publish({**base, **evt})
|
||||
except Exception:
|
||||
pass
|
||||
last = await executor.run(macro_ctx, trace=_trace)
|
||||
result = last.get("result") or await execute_pipeline_echo(unified)
|
||||
await _log_response(request, 200, result)
|
||||
return JSONResponse(result)
|
||||
@@ -532,11 +582,16 @@ def create_app() -> FastAPI:
|
||||
macro_ctx = build_macro_context(unified, incoming=incoming)
|
||||
pipeline = load_pipeline()
|
||||
executor = PipelineExecutor(pipeline)
|
||||
last = await executor.run(macro_ctx)
|
||||
async def _trace(evt: Dict[str, Any]) -> None:
|
||||
try:
|
||||
base = {"pipeline_id": pipeline.get("id", "pipeline_editor")}
|
||||
await _trace_hub.publish({**base, **evt})
|
||||
except Exception:
|
||||
pass
|
||||
last = await executor.run(macro_ctx, trace=_trace)
|
||||
result = last.get("result") or await execute_pipeline_echo(unified)
|
||||
await _log_response(request, 200, result)
|
||||
return JSONResponse(result)
|
||||
|
||||
app.mount("/ui", StaticFiles(directory="static", html=True), name="ui")
|
||||
|
||||
# Admin API для пайплайна
|
||||
@@ -580,6 +635,30 @@ def create_app() -> FastAPI:
|
||||
raise HTTPException(status_code=400, detail="Invalid pipeline format")
|
||||
save_preset(name, payload)
|
||||
return JSONResponse({"ok": True})
|
||||
# --- SSE endpoint for live pipeline trace ---
|
||||
@app.get("/admin/trace/stream")
|
||||
async def sse_trace() -> StreamingResponse:
|
||||
loop = _asyncio.get_event_loop()
|
||||
q = _trace_hub.subscribe()
|
||||
|
||||
async def _gen():
|
||||
try:
|
||||
# warm-up: send a comment to keep connection open
|
||||
yield ":ok\n\n"
|
||||
while True:
|
||||
evt = await q.get()
|
||||
try:
|
||||
line = f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
|
||||
except Exception:
|
||||
line = "data: {}\n\n"
|
||||
yield line
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
_trace_hub.unsubscribe(q)
|
||||
|
||||
return StreamingResponse(_gen(), media_type="text/event-stream")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user