diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index 15a8346..1922207 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -2,6 +2,8 @@ import importlib import json +from collections.abc import Mapping +from typing import Any from unittest.mock import Mock, patch import pytest @@ -52,6 +54,11 @@ def create_test_context( ) +def _mock_call_kwargs_by_operation_id(mock: Mock) -> dict[str, Mapping[str, Any]]: + """Return mock call keyword arguments keyed by operation_id.""" + return {call.kwargs["operation_id"]: call.kwargs for call in mock.call_args_list} + + def test_parallel_executor_init(): """Test ParallelExecutor initialization.""" executables = [Executable(index=0, func=lambda x: x)] @@ -825,12 +832,12 @@ def create_id(self, i): ) expected = item_serdes or batch_serdes - assert mock_serialize.call_args_list[0][1]["serdes"] is expected - assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0" - assert mock_serialize.call_args_list[1][1]["serdes"] is expected - assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1" - assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes - assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent" + calls_by_operation_id = _mock_call_kwargs_by_operation_id(mock_serialize) + + assert set(calls_by_operation_id) == {"child-0", "child-1", "parent"} + assert calls_by_operation_id["child-0"]["serdes"] is expected + assert calls_by_operation_id["child-1"]["serdes"] is expected + assert calls_by_operation_id["parent"]["serdes"] is batch_serdes @pytest.mark.parametrize( @@ -886,10 +893,11 @@ def create_id(self, i): ) expected = item_serdes or batch_serdes - assert mock_deserialize.call_args_list[0][1]["serdes"] is expected - assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0" - assert mock_deserialize.call_args_list[1][1]["serdes"] is expected - assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1" + calls_by_operation_id = _mock_call_kwargs_by_operation_id(mock_deserialize) + + assert set(calls_by_operation_id) == {"child-0", "child-1"} + assert calls_by_operation_id["child-0"]["serdes"] is expected + assert calls_by_operation_id["child-1"]["serdes"] is expected def test_parallel_result_serialization_roundtrip():