From 2d8473aa196d92f719f2887f5ec4322fed815463 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 12 May 2026 17:01:12 -0700 Subject: [PATCH 1/2] Adding initializer service --- pyrit/backend/main.py | 3 +- pyrit/backend/models/__init__.py | 11 + pyrit/backend/models/initializers.py | 44 +++ pyrit/backend/models/scenarios.py | 17 +- pyrit/backend/routes/__init__.py | 3 +- pyrit/backend/routes/initializers.py | 75 +++++ pyrit/backend/services/__init__.py | 6 + pyrit/backend/services/initializer_service.py | 141 +++++++++ .../backend/services/scenario_run_service.py | 22 +- pyrit/backend/services/scenario_service.py | 16 +- .../unit/backend/test_initializer_service.py | 291 ++++++++++++++++++ .../unit/backend/test_scenario_run_service.py | 81 +++++ tests/unit/backend/test_scenario_service.py | 110 +++++++ 13 files changed, 802 insertions(+), 18 deletions(-) create mode 100644 pyrit/backend/models/initializers.py create mode 100644 pyrit/backend/routes/initializers.py create mode 100644 pyrit/backend/services/initializer_service.py create mode 100644 tests/unit/backend/test_initializer_service.py diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index a1a9cad0ba..fe19894459 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -18,7 +18,7 @@ import pyrit from pyrit.backend.middleware import RequestIdMiddleware, SecurityHeadersMiddleware, register_error_handlers from pyrit.backend.middleware.auth import EntraAuthMiddleware -from pyrit.backend.routes import attacks, auth, converters, health, labels, media, scenarios, targets, version +from pyrit.backend.routes import attacks, auth, converters, health, initializers, labels, media, scenarios, targets, version from pyrit.memory import CentralMemory # Check for development mode from environment variable @@ -86,6 +86,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.include_router(targets.router, prefix="/api", tags=["targets"]) app.include_router(converters.router, prefix="/api", tags=["converters"]) app.include_router(scenarios.router, prefix="/api", tags=["scenarios"]) +app.include_router(initializers.router, prefix="/api", tags=["initializers"]) app.include_router(labels.router, prefix="/api", tags=["labels"]) app.include_router(health.router, prefix="/api", tags=["health"]) app.include_router(auth.router, prefix="/api", tags=["auth"]) diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index 4c0aad1665..b33901f560 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -47,9 +47,15 @@ CreateConverterResponse, PreviewStep, ) +from pyrit.backend.models.initializers import ( + InitializerParameterSummary, + ListRegisteredInitializersResponse, + RegisteredInitializer, +) from pyrit.backend.models.scenarios import ( ListRegisteredScenariosResponse, RegisteredScenario, + ScenarioParameterSummary, ) from pyrit.backend.models.targets import ( CreateTargetRequest, @@ -99,6 +105,11 @@ # Scenarios "ListRegisteredScenariosResponse", "RegisteredScenario", + "ScenarioParameterSummary", + # Initializers + "InitializerParameterSummary", + "ListRegisteredInitializersResponse", + "RegisteredInitializer", # Targets "CreateTargetRequest", "TargetCapabilitiesInfo", diff --git a/pyrit/backend/models/initializers.py b/pyrit/backend/models/initializers.py new file mode 100644 index 0000000000..4df752f70b --- /dev/null +++ b/pyrit/backend/models/initializers.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Initializer API response models. + +Initializers configure the PyRIT environment (targets, datasets, env vars) +before scenario execution. These models represent initializer metadata. +""" + +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from pyrit.backend.models.common import PaginationInfo + + +class InitializerParameterSummary(BaseModel): + """Summary of an initializer-declared parameter.""" + + name: str = Field(..., description="Parameter name") + description: str = Field(..., description="Human-readable description of the parameter") + default: Optional[list[str]] = Field(None, description="Default value(s), or None if required") + + +class RegisteredInitializer(BaseModel): + """Summary of a registered initializer.""" + + initializer_name: str = Field(..., description="Initializer registry name (e.g., 'target')") + initializer_type: str = Field(..., description="Initializer class name (e.g., 'TargetInitializer')") + description: str = Field("", description="Human-readable description of the initializer") + required_env_vars: list[str] = Field( + default_factory=list, description="Environment variables required by this initializer" + ) + supported_parameters: list[InitializerParameterSummary] = Field( + default_factory=list, description="Parameters accepted by this initializer" + ) + + +class ListRegisteredInitializersResponse(BaseModel): + """Response for listing initializers.""" + + items: list[RegisteredInitializer] = Field(..., description="List of initializer summaries") + pagination: PaginationInfo = Field(..., description="Pagination metadata") diff --git a/pyrit/backend/models/scenarios.py b/pyrit/backend/models/scenarios.py index e628020c2f..7a74fbcb35 100644 --- a/pyrit/backend/models/scenarios.py +++ b/pyrit/backend/models/scenarios.py @@ -18,6 +18,16 @@ from pyrit.backend.models.common import PaginationInfo +class ScenarioParameterSummary(BaseModel): + """Summary of a scenario-declared parameter.""" + + name: str = Field(..., description="Parameter name (e.g., 'max_turns')") + description: str = Field(..., description="Human-readable description of the parameter") + default: str | None = Field(None, description="Default value as a display string, or None if required") + param_type: str = Field(..., description="Type of the parameter as a display string (e.g., 'int', 'str')") + choices: str | None = Field(None, description="Allowed values as a display string, or None if unconstrained") + + class RegisteredScenario(BaseModel): """Summary of a registered scenario.""" @@ -31,6 +41,9 @@ class RegisteredScenario(BaseModel): all_strategies: list[str] = Field(..., description="All available concrete strategy names") default_datasets: list[str] = Field(..., description="Default dataset names used by the scenario") max_dataset_size: Optional[int] = Field(None, description="Maximum items per dataset (None means unlimited)") + supported_parameters: list[ScenarioParameterSummary] = Field( + default_factory=list, description="Scenario-declared custom parameters" + ) class ListRegisteredScenariosResponse(BaseModel): @@ -99,8 +112,8 @@ class ScenarioRunSummary(BaseModel): updated_at: datetime = Field(..., description="When the run status last changed") error: str | None = Field(None, description="Error message if status is FAILED") strategies_used: list[str] = Field(default_factory=list, description="Strategy names that were executed") - total_attacks: int = Field(0, ge=0, description="Total number of atomic attacks") - completed_attacks: int = Field(0, ge=0, description="Number of attacks that completed") + total_attacks: int = Field(0, ge=0, description="Total number of attack results persisted for this run") + completed_attacks: int = Field(0, ge=0, description="Number of attacks that reached a terminal outcome") objective_achieved_rate: int = Field(0, ge=0, le=100, description="Success rate as percentage (0-100)") labels: dict[str, str] = Field(default_factory=dict, description="Labels attached to this run") completed_at: datetime | None = Field(None, description="When the scenario finished") diff --git a/pyrit/backend/routes/__init__.py b/pyrit/backend/routes/__init__.py index ca412238ea..daad0c53e8 100644 --- a/pyrit/backend/routes/__init__.py +++ b/pyrit/backend/routes/__init__.py @@ -5,12 +5,13 @@ API route handlers. """ -from pyrit.backend.routes import attacks, converters, health, labels, media, scenarios, targets, version +from pyrit.backend.routes import attacks, converters, health, initializers, labels, media, scenarios, targets, version __all__ = [ "attacks", "converters", "health", + "initializers", "labels", "media", "scenarios", diff --git a/pyrit/backend/routes/initializers.py b/pyrit/backend/routes/initializers.py new file mode 100644 index 0000000000..7c10d7ad63 --- /dev/null +++ b/pyrit/backend/routes/initializers.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Initializer API routes. + +Provides endpoints for listing available initializers and their metadata. + +Route structure: + /api/initializers — list all initializers + /api/initializers/{name} — get single initializer detail +""" + +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query, status + +from pyrit.backend.models.common import ProblemDetail +from pyrit.backend.models.initializers import ( + ListRegisteredInitializersResponse, + RegisteredInitializer, +) +from pyrit.backend.services.initializer_service import get_initializer_service + +router = APIRouter(prefix="/initializers", tags=["initializers"]) + + +@router.get( + "", + response_model=ListRegisteredInitializersResponse, +) +async def list_initializers( + limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), + cursor: Optional[str] = Query(None, description="Pagination cursor (initializer_name to start after)"), +) -> ListRegisteredInitializersResponse: + """ + List all available initializers. + + Returns initializer metadata including required environment variables, + supported parameters, and descriptions. + + Returns: + ListRegisteredInitializersResponse: Paginated list of initializer summaries. + """ + service = get_initializer_service() + return await service.list_initializers_async(limit=limit, cursor=cursor) + + +@router.get( + "/{initializer_name}", + response_model=RegisteredInitializer, + responses={ + 404: {"model": ProblemDetail, "description": "Initializer not found"}, + }, +) +async def get_initializer(initializer_name: str) -> RegisteredInitializer: + """ + Get details for a specific initializer. + + Args: + initializer_name: Registry name of the initializer (e.g., 'target'). + + Returns: + RegisteredInitializer: Full initializer metadata. + """ + service = get_initializer_service() + + initializer = await service.get_initializer_async(initializer_name=initializer_name) + if not initializer: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Initializer '{initializer_name}' not found", + ) + + return initializer diff --git a/pyrit/backend/services/__init__.py b/pyrit/backend/services/__init__.py index d36f69a830..9b110915ed 100644 --- a/pyrit/backend/services/__init__.py +++ b/pyrit/backend/services/__init__.py @@ -15,6 +15,10 @@ ConverterService, get_converter_service, ) +from pyrit.backend.services.initializer_service import ( + InitializerService, + get_initializer_service, +) from pyrit.backend.services.scenario_run_service import ( ScenarioRunService, get_scenario_run_service, @@ -33,6 +37,8 @@ "get_attack_service", "ConverterService", "get_converter_service", + "InitializerService", + "get_initializer_service", "ScenarioService", "get_scenario_service", "ScenarioRunService", diff --git a/pyrit/backend/services/initializer_service.py b/pyrit/backend/services/initializer_service.py new file mode 100644 index 0000000000..1f542f87d1 --- /dev/null +++ b/pyrit/backend/services/initializer_service.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Initializer service for listing available initializers. + +Provides read-only access to the InitializerRegistry, exposing initializer +metadata through the REST API. +""" + +from functools import lru_cache +from typing import Optional + +from pyrit.backend.models.common import PaginationInfo +from pyrit.backend.models.initializers import ( + InitializerParameterSummary, + ListRegisteredInitializersResponse, + RegisteredInitializer, +) +from pyrit.registry import InitializerMetadata, InitializerRegistry + + +def _metadata_to_registered_initializer(metadata: InitializerMetadata) -> RegisteredInitializer: + """ + Convert an InitializerMetadata dataclass to a RegisteredInitializer Pydantic model. + + Args: + metadata: The registry metadata for an initializer. + + Returns: + RegisteredInitializer Pydantic model. + """ + return RegisteredInitializer( + initializer_name=metadata.registry_name, + initializer_type=metadata.class_name, + description=metadata.class_description, + required_env_vars=list(metadata.required_env_vars), + supported_parameters=[ + InitializerParameterSummary( + name=name, + description=desc, + default=default, + ) + for name, desc, default in metadata.supported_parameters + ], + ) + + +class InitializerService: + """ + Service for listing available initializers. + + Uses InitializerRegistry as the source of truth for initializer metadata. + """ + + def __init__(self) -> None: + """Initialize the initializer service.""" + self._registry = InitializerRegistry.get_registry_singleton() + + async def list_initializers_async( + self, + *, + limit: int = 50, + cursor: Optional[str] = None, + ) -> ListRegisteredInitializersResponse: + """ + List all available initializers with pagination. + + Args: + limit: Maximum items to return per page. + cursor: Pagination cursor (initializer_name to start after). + + Returns: + ListRegisteredInitializersResponse with paginated initializer summaries. + """ + all_metadata = self._registry.list_metadata() + all_summaries = [_metadata_to_registered_initializer(m) for m in all_metadata] + + page, has_more = self._paginate(items=all_summaries, cursor=cursor, limit=limit) + next_cursor = page[-1].initializer_name if has_more and page else None + + return ListRegisteredInitializersResponse( + items=page, + pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), + ) + + async def get_initializer_async(self, *, initializer_name: str) -> Optional[RegisteredInitializer]: + """ + Get a single initializer by registry name. + + Args: + initializer_name: The registry key of the initializer (e.g., 'target'). + + Returns: + RegisteredInitializer if found, None otherwise. + """ + all_metadata = self._registry.list_metadata() + for metadata in all_metadata: + if metadata.registry_name == initializer_name: + return _metadata_to_registered_initializer(metadata) + return None + + @staticmethod + def _paginate( + *, + items: list[RegisteredInitializer], + cursor: Optional[str], + limit: int, + ) -> tuple[list[RegisteredInitializer], bool]: + """ + Apply cursor-based pagination. + + Args: + items: Full list of items. + cursor: Initializer name to start after. + limit: Maximum items per page. + + Returns: + Tuple of (paginated items, has_more flag). + """ + start_idx = 0 + if cursor: + for i, item in enumerate(items): + if item.initializer_name == cursor: + start_idx = i + 1 + break + + page = items[start_idx : start_idx + limit] + has_more = len(items) > start_idx + limit + return page, has_more + + +@lru_cache(maxsize=1) +def get_initializer_service() -> InitializerService: + """ + Get the global initializer service instance. + + Returns: + The singleton InitializerService instance. + """ + return InitializerService() diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index 26f9b21f60..1c3f2c9f86 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -404,19 +404,15 @@ def _build_response_from_db(self, *, scenario_result: ScenarioResult) -> Scenari status = ScenarioRunStatus(scenario_result.scenario_run_state) - # Build result fields for completed runs - strategies_used: list[str] = [] - total_attacks = 0 - completed_attacks = 0 - if status == ScenarioRunStatus.COMPLETED: - completed_attacks = sum( - 1 - for results in scenario_result.attack_results.values() - for ar in results - if ar.outcome in (AttackOutcome.SUCCESS, AttackOutcome.FAILURE) - ) - total_attacks = sum(len(results) for results in scenario_result.attack_results.values()) - strategies_used = scenario_result.get_strategies_used() + # Build result fields from DB (always computed so in-progress runs show progress) + completed_attacks = sum( + 1 + for results in scenario_result.attack_results.values() + for ar in results + if ar.outcome in (AttackOutcome.SUCCESS, AttackOutcome.FAILURE) + ) + total_attacks = sum(len(results) for results in scenario_result.attack_results.values()) + strategies_used = scenario_result.get_strategies_used() return ScenarioRunSummary( scenario_result_id=scenario_result_id, diff --git a/pyrit/backend/services/scenario_service.py b/pyrit/backend/services/scenario_service.py index a1588e21ac..f071f5947d 100644 --- a/pyrit/backend/services/scenario_service.py +++ b/pyrit/backend/services/scenario_service.py @@ -12,7 +12,11 @@ from typing import Optional from pyrit.backend.models.common import PaginationInfo -from pyrit.backend.models.scenarios import ListRegisteredScenariosResponse, RegisteredScenario +from pyrit.backend.models.scenarios import ( + ListRegisteredScenariosResponse, + RegisteredScenario, + ScenarioParameterSummary, +) from pyrit.registry import ScenarioMetadata, ScenarioRegistry @@ -35,6 +39,16 @@ def _metadata_to_registered_scenario(metadata: ScenarioMetadata) -> RegisteredSc all_strategies=list(metadata.all_strategies), default_datasets=list(metadata.default_datasets), max_dataset_size=metadata.max_dataset_size, + supported_parameters=[ + ScenarioParameterSummary( + name=p.name, + description=p.description, + default=repr(p.default) if p.default is not None else None, + param_type=p.param_type, + choices=p.choices, + ) + for p in metadata.supported_parameters + ], ) diff --git a/tests/unit/backend/test_initializer_service.py b/tests/unit/backend/test_initializer_service.py new file mode 100644 index 0000000000..4601ee8678 --- /dev/null +++ b/tests/unit/backend/test_initializer_service.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for backend initializer service and routes. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + +from pyrit.backend.main import app +from pyrit.backend.models.common import PaginationInfo +from pyrit.backend.models.initializers import ( + InitializerParameterSummary, + ListRegisteredInitializersResponse, + RegisteredInitializer, +) +from pyrit.backend.services.initializer_service import InitializerService, get_initializer_service +from pyrit.registry import InitializerMetadata + + +@pytest.fixture +def client() -> TestClient: + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def clear_service_cache(): + """Clear the initializer service singleton cache between tests.""" + get_initializer_service.cache_clear() + yield + get_initializer_service.cache_clear() + + +def _make_initializer_metadata( + *, + registry_name: str = "target", + class_name: str = "TargetInitializer", + description: str = "Registers targets", + required_env_vars: tuple[str, ...] = ("AZURE_OPENAI_ENDPOINT",), + supported_parameters: tuple[tuple[str, str, list[str] | None], ...] = ( + ("tags", "Comma-separated tag filter", ["default"]), + ), +) -> InitializerMetadata: + """Create an InitializerMetadata instance for testing.""" + return InitializerMetadata( + registry_name=registry_name, + class_name=class_name, + class_module="pyrit.setup.initializers.target", + class_description=description, + required_env_vars=required_env_vars, + supported_parameters=supported_parameters, + ) + + +# ============================================================================ +# InitializerService Unit Tests +# ============================================================================ + + +class TestInitializerServiceListInitializers: + """Tests for InitializerService.list_initializers_async.""" + + async def test_list_initializers_returns_empty_when_no_initializers(self) -> None: + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [] + + result = await service.list_initializers_async() + + assert result.items == [] + assert result.pagination.has_more is False + + async def test_list_initializers_returns_initializers_from_registry(self) -> None: + metadata = _make_initializer_metadata() + + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.list_initializers_async() + + assert len(result.items) == 1 + item = result.items[0] + assert item.initializer_name == "target" + assert item.initializer_type == "TargetInitializer" + assert item.description == "Registers targets" + assert item.required_env_vars == ["AZURE_OPENAI_ENDPOINT"] + assert len(item.supported_parameters) == 1 + assert item.supported_parameters[0].name == "tags" + assert item.supported_parameters[0].description == "Comma-separated tag filter" + assert item.supported_parameters[0].default == ["default"] + + async def test_list_initializers_paginates_with_limit(self) -> None: + metadata_list = [ + _make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(5) + ] + + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = metadata_list + + result = await service.list_initializers_async(limit=3) + + assert len(result.items) == 3 + assert result.pagination.has_more is True + assert result.pagination.next_cursor == "init_2" + + async def test_list_initializers_paginates_with_cursor(self) -> None: + metadata_list = [ + _make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(5) + ] + + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = metadata_list + + result = await service.list_initializers_async(limit=2, cursor="init_1") + + assert len(result.items) == 2 + assert result.items[0].initializer_name == "init_2" + assert result.items[1].initializer_name == "init_3" + assert result.pagination.has_more is True + + async def test_list_initializers_last_page_has_more_false(self) -> None: + metadata_list = [ + _make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(3) + ] + + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = metadata_list + + result = await service.list_initializers_async(limit=5) + + assert len(result.items) == 3 + assert result.pagination.has_more is False + assert result.pagination.next_cursor is None + + async def test_list_initializers_with_no_env_vars(self) -> None: + metadata = _make_initializer_metadata(required_env_vars=(), supported_parameters=()) + + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.list_initializers_async() + + assert result.items[0].required_env_vars == [] + assert result.items[0].supported_parameters == [] + + +class TestInitializerServiceGetInitializer: + """Tests for InitializerService.get_initializer_async.""" + + async def test_get_initializer_returns_matching_initializer(self) -> None: + metadata = _make_initializer_metadata(registry_name="target") + + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.get_initializer_async(initializer_name="target") + + assert result is not None + assert result.initializer_name == "target" + + async def test_get_initializer_returns_none_for_missing(self) -> None: + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [] + + result = await service.get_initializer_async(initializer_name="nonexistent") + + assert result is None + + +# ============================================================================ +# Route Tests +# ============================================================================ + + +class TestInitializerRoutes: + """Tests for initializer API routes.""" + + def test_list_initializers_returns_200(self, client: TestClient) -> None: + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_initializers_async = AsyncMock( + return_value=ListRegisteredInitializersResponse( + items=[], + pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/initializers") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["items"] == [] + assert data["pagination"]["has_more"] is False + + def test_list_initializers_with_items(self, client: TestClient) -> None: + summary = RegisteredInitializer( + initializer_name="target", + initializer_type="TargetInitializer", + description="Registers targets", + required_env_vars=["AZURE_OPENAI_ENDPOINT"], + supported_parameters=[ + InitializerParameterSummary(name="tags", description="Tag filter", default=["default"]) + ], + ) + + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_initializers_async = AsyncMock( + return_value=ListRegisteredInitializersResponse( + items=[summary], + pagination=PaginationInfo(limit=50, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/initializers") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert len(data["items"]) == 1 + item = data["items"][0] + assert item["initializer_name"] == "target" + assert item["initializer_type"] == "TargetInitializer" + assert item["required_env_vars"] == ["AZURE_OPENAI_ENDPOINT"] + assert item["supported_parameters"][0]["name"] == "tags" + + def test_list_initializers_passes_pagination_params(self, client: TestClient) -> None: + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.list_initializers_async = AsyncMock( + return_value=ListRegisteredInitializersResponse( + items=[], + pagination=PaginationInfo(limit=10, has_more=False, next_cursor=None, prev_cursor=None), + ) + ) + mock_get_service.return_value = mock_service + + response = client.get("/api/initializers?limit=10&cursor=target") + + assert response.status_code == status.HTTP_200_OK + mock_service.list_initializers_async.assert_called_once_with(limit=10, cursor="target") + + def test_get_initializer_returns_200(self, client: TestClient) -> None: + summary = RegisteredInitializer( + initializer_name="target", + initializer_type="TargetInitializer", + description="Registers targets", + required_env_vars=["AZURE_OPENAI_ENDPOINT"], + ) + + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_initializer_async = AsyncMock(return_value=summary) + mock_get_service.return_value = mock_service + + response = client.get("/api/initializers/target") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["initializer_name"] == "target" + + def test_get_initializer_returns_404_when_not_found(self, client: TestClient) -> None: + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.get_initializer_async = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client.get("/api/initializers/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 26fa81a814..83b511f669 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -490,3 +490,84 @@ def test_get_results_returns_details_for_completed_run(self, mock_memory) -> Non assert detail.attacks[0].success_count == 1 assert detail.attacks[0].results[0].objective == "Extract info" assert detail.attacks[0].results[0].outcome == "success" + + +class TestScenarioRunServiceProgressReporting: + """Tests that in-progress runs expose partial attack counts.""" + + def test_in_progress_run_shows_partial_attack_counts(self, mock_memory) -> None: + """Test that polling an IN_PROGRESS run shows incremental results.""" + from pyrit.models import AttackOutcome + + mock_success = MagicMock() + mock_success.outcome = AttackOutcome.SUCCESS + mock_failure = MagicMock() + mock_failure.outcome = AttackOutcome.FAILURE + mock_undetermined = MagicMock() + mock_undetermined.outcome = AttackOutcome.UNDETERMINED + + db_result = _make_db_scenario_result( + result_id="sr-running", + run_state="IN_PROGRESS", + attack_results={ + "attack_a": [mock_success, mock_failure], + "attack_b": [mock_undetermined], + }, + ) + db_result.get_strategies_used.return_value = ["attack_a", "attack_b"] + db_result.objective_achieved_rate.return_value = 33 + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + fetched = service.get_run(scenario_result_id="sr-running") + + assert fetched is not None + assert fetched.status == ScenarioRunStatus.IN_PROGRESS + assert fetched.total_attacks == 3 + assert fetched.completed_attacks == 2 + assert fetched.strategies_used == ["attack_a", "attack_b"] + assert fetched.objective_achieved_rate == 33 + + def test_created_run_shows_zero_counts(self, mock_memory) -> None: + """Test that a CREATED run with no results shows zero counts.""" + db_result = _make_db_scenario_result( + result_id="sr-new", + run_state="CREATED", + attack_results={}, + ) + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + fetched = service.get_run(scenario_result_id="sr-new") + + assert fetched is not None + assert fetched.status == ScenarioRunStatus.CREATED + assert fetched.total_attacks == 0 + assert fetched.completed_attacks == 0 + assert fetched.strategies_used == [] + + def test_completed_run_still_shows_full_counts(self, mock_memory) -> None: + """Test that COMPLETED runs still show accurate counts after the fix.""" + from pyrit.models import AttackOutcome + + mock_success = MagicMock() + mock_success.outcome = AttackOutcome.SUCCESS + + db_result = _make_db_scenario_result( + result_id="sr-done", + run_state="COMPLETED", + attack_results={"attack_a": [mock_success]}, + ) + db_result.get_strategies_used.return_value = ["attack_a"] + db_result.objective_achieved_rate.return_value = 100 + mock_memory.get_scenario_results.return_value = [db_result] + + service = ScenarioRunService() + fetched = service.get_run(scenario_result_id="sr-done") + + assert fetched is not None + assert fetched.status == ScenarioRunStatus.COMPLETED + assert fetched.total_attacks == 1 + assert fetched.completed_attacks == 1 + assert fetched.strategies_used == ["attack_a"] + assert fetched.objective_achieved_rate == 100 diff --git a/tests/unit/backend/test_scenario_service.py b/tests/unit/backend/test_scenario_service.py index 985148ca0c..aa88ad3881 100644 --- a/tests/unit/backend/test_scenario_service.py +++ b/tests/unit/backend/test_scenario_service.py @@ -16,6 +16,7 @@ from pyrit.backend.models.scenarios import ListRegisteredScenariosResponse, RegisteredScenario from pyrit.backend.services.scenario_service import ScenarioService, get_scenario_service from pyrit.registry import ScenarioMetadata +from pyrit.registry.class_registries.scenario_registry import ScenarioParameterMetadata @pytest.fixture @@ -331,3 +332,112 @@ def test_get_scenario_with_dotted_name(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK mock_service.get_scenario_async.assert_called_once_with(scenario_name="garak.encoding") + + +# ============================================================================ +# Supported Parameters Tests +# ============================================================================ + + +class TestScenarioServiceSupportedParameters: + """Tests for supported_parameters in scenario service responses.""" + + async def test_list_scenarios_includes_supported_parameters(self) -> None: + """Test that supported_parameters are included in scenario listing.""" + metadata = _make_scenario_metadata(registry_name="param.scenario") + metadata = ScenarioMetadata( + registry_name="param.scenario", + class_name="ParamScenario", + class_module="pyrit.scenario.scenarios.param", + class_description="A scenario with params", + default_strategy="default", + all_strategies=("prompt_sending",), + aggregate_strategies=("all",), + default_datasets=("test_dataset",), + max_dataset_size=None, + supported_parameters=( + ScenarioParameterMetadata( + name="max_turns", + description="Maximum number of turns", + default=5, + param_type="int", + choices=None, + ), + ScenarioParameterMetadata( + name="mode", + description="Execution mode", + default="fast", + param_type="str", + choices="'fast', 'slow'", + ), + ), + ) + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.list_scenarios_async() + + assert len(result.items) == 1 + params = result.items[0].supported_parameters + assert len(params) == 2 + + assert params[0].name == "max_turns" + assert params[0].description == "Maximum number of turns" + assert params[0].default == "5" + assert params[0].param_type == "int" + assert params[0].choices is None + + assert params[1].name == "mode" + assert params[1].description == "Execution mode" + assert params[1].default == "'fast'" + assert params[1].param_type == "str" + assert params[1].choices == "'fast', 'slow'" + + async def test_scenario_with_no_parameters_has_empty_list(self) -> None: + """Test that scenarios without parameters have empty supported_parameters.""" + metadata = _make_scenario_metadata() + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.list_scenarios_async() + + assert result.items[0].supported_parameters == [] + + async def test_supported_parameters_with_none_default(self) -> None: + """Test that parameters with None default are serialized correctly.""" + metadata = ScenarioMetadata( + registry_name="test.scenario", + class_name="TestScenario", + class_module="pyrit.scenario.scenarios.test", + class_description="Test", + default_strategy="default", + all_strategies=("all",), + aggregate_strategies=("all",), + default_datasets=(), + max_dataset_size=None, + supported_parameters=( + ScenarioParameterMetadata( + name="optional_param", + description="An optional param", + default=None, + param_type="str", + choices=None, + ), + ), + ) + + with patch.object(ScenarioService, "__init__", lambda self: None): + service = ScenarioService() + service._registry = MagicMock() + service._registry.list_metadata.return_value = [metadata] + + result = await service.list_scenarios_async() + + param = result.items[0].supported_parameters[0] + assert param.default is None From 40cfea22b1b3eb7e3ea43477ff6bc9ec6fd56720 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Tue, 12 May 2026 17:14:58 -0700 Subject: [PATCH 2/2] pre-commit --- pyrit/backend/main.py | 13 ++++++++++++- pyrit/backend/models/initializers.py | 2 +- tests/unit/backend/test_initializer_service.py | 12 +++--------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/pyrit/backend/main.py b/pyrit/backend/main.py index fe19894459..365d2b5656 100644 --- a/pyrit/backend/main.py +++ b/pyrit/backend/main.py @@ -18,7 +18,18 @@ import pyrit from pyrit.backend.middleware import RequestIdMiddleware, SecurityHeadersMiddleware, register_error_handlers from pyrit.backend.middleware.auth import EntraAuthMiddleware -from pyrit.backend.routes import attacks, auth, converters, health, initializers, labels, media, scenarios, targets, version +from pyrit.backend.routes import ( + attacks, + auth, + converters, + health, + initializers, + labels, + media, + scenarios, + targets, + version, +) from pyrit.memory import CentralMemory # Check for development mode from environment variable diff --git a/pyrit/backend/models/initializers.py b/pyrit/backend/models/initializers.py index 4df752f70b..15174dfd53 100644 --- a/pyrit/backend/models/initializers.py +++ b/pyrit/backend/models/initializers.py @@ -8,7 +8,7 @@ before scenario execution. These models represent initializer metadata. """ -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel, Field diff --git a/tests/unit/backend/test_initializer_service.py b/tests/unit/backend/test_initializer_service.py index 4601ee8678..8c3c5977d0 100644 --- a/tests/unit/backend/test_initializer_service.py +++ b/tests/unit/backend/test_initializer_service.py @@ -98,9 +98,7 @@ async def test_list_initializers_returns_initializers_from_registry(self) -> Non assert item.supported_parameters[0].default == ["default"] async def test_list_initializers_paginates_with_limit(self) -> None: - metadata_list = [ - _make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(5) - ] + metadata_list = [_make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(5)] with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() @@ -114,9 +112,7 @@ async def test_list_initializers_paginates_with_limit(self) -> None: assert result.pagination.next_cursor == "init_2" async def test_list_initializers_paginates_with_cursor(self) -> None: - metadata_list = [ - _make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(5) - ] + metadata_list = [_make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(5)] with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService() @@ -131,9 +127,7 @@ async def test_list_initializers_paginates_with_cursor(self) -> None: assert result.pagination.has_more is True async def test_list_initializers_last_page_has_more_false(self) -> None: - metadata_list = [ - _make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(3) - ] + metadata_list = [_make_initializer_metadata(registry_name=f"init_{i}", class_name=f"Init{i}") for i in range(3)] with patch.object(InitializerService, "__init__", lambda self: None): service = InitializerService()