Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions src/google/adk/cli/trigger_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,22 @@ class TriggerResponse(BaseModel):
)


def _make_trigger_user_id(
raw_value: Optional[str],
*,
default: str,
) -> str:
"""Normalize trigger metadata into a session-safe user_id."""
if not raw_value:
return default

normalized = raw_value.strip().strip("/")
if not normalized:
return default

return normalized.replace("/", "--")


# ---------------------------------------------------------------------------
# Trigger Router
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -411,7 +427,9 @@ def register(self, app: FastAPI) -> None:
async def trigger_pubsub(
app_name: str, req: PubSubTriggerRequest, request: Request
) -> TriggerResponse:
user_id = req.subscription or "pubsub-caller"
user_id = _make_trigger_user_id(
req.subscription, default="pubsub-caller"
)

decoded_data = None
data_payload = None
Expand Down Expand Up @@ -477,8 +495,9 @@ async def trigger_eventarc(
app_name: str, req: EventarcTriggerRequest, request: Request
) -> TriggerResponse:

user_id = (
req.source or request.headers.get("ce-source") or "eventarc-caller"
user_id = _make_trigger_user_id(
req.source or request.headers.get("ce-source"),
default="eventarc-caller",
)

logger.info(
Expand Down
37 changes: 37 additions & 0 deletions tests/unittests/cli/test_trigger_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,24 @@ def test_with_subscription_metadata(self, client):

assert resp.status_code == 200

def test_subscription_user_id_is_path_safe(
self, client, mock_session_service
):
"""Pub/Sub subscription-derived user_id is stored without slashes."""
message_data = base64.b64encode(b"test").decode("utf-8")
payload = {
"message": {"data": message_data},
"subscription": "projects/p/subscriptions/orders-sub",
}

resp = client.post("/apps/test_app/trigger/pubsub", json=payload)

assert resp.status_code == 200
assert (
"projects--p--subscriptions--orders-sub"
in mock_session_service.sessions["test_app"]
)

def test_unknown_app_fails_early(
self, client, mock_agent_loader, mock_session_service
):
Expand Down Expand Up @@ -513,6 +531,25 @@ def test_source_from_ce_header(self, client):
)
assert resp.status_code == 200

def test_eventarc_source_user_id_is_path_safe(
self, client, mock_session_service
):
"""Eventarc ce-source-derived user_id is stored without slashes."""
payload = {
"data": {"key": "value"},
}
resp = client.post(
"/apps/test_app/trigger/eventarc",
json=payload,
headers={"ce-source": "//pubsub.googleapis.com/projects/p/topics/t"},
)

assert resp.status_code == 200
assert (
"pubsub.googleapis.com--projects--p--topics--t"
in mock_session_service.sessions["test_app"]
)

def test_complex_event_data(self, client, monkeypatch):
"""Complex nested event data is serialized as JSON for the agent."""
captured_messages = []
Expand Down
Loading