Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions tests/operation/parallel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import importlib
import json
from collections.abc import Mapping
from typing import Any
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -52,6 +54,11 @@ def create_test_context(
)


def _mock_call_kwargs_by_operation_id(mock: Mock) -> dict[str, Mapping[str, Any]]:
"""Return mock call keyword arguments keyed by operation_id."""
return {call.kwargs["operation_id"]: call.kwargs for call in mock.call_args_list}


def test_parallel_executor_init():
"""Test ParallelExecutor initialization."""
executables = [Executable(index=0, func=lambda x: x)]
Expand Down Expand Up @@ -825,12 +832,12 @@ def create_id(self, i):
)

expected = item_serdes or batch_serdes
assert mock_serialize.call_args_list[0][1]["serdes"] is expected
assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0"
assert mock_serialize.call_args_list[1][1]["serdes"] is expected
assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1"
assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes
assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent"
calls_by_operation_id = _mock_call_kwargs_by_operation_id(mock_serialize)

assert set(calls_by_operation_id) == {"child-0", "child-1", "parent"}
assert calls_by_operation_id["child-0"]["serdes"] is expected
assert calls_by_operation_id["child-1"]["serdes"] is expected
assert calls_by_operation_id["parent"]["serdes"] is batch_serdes


@pytest.mark.parametrize(
Expand Down Expand Up @@ -886,10 +893,11 @@ def create_id(self, i):
)

expected = item_serdes or batch_serdes
assert mock_deserialize.call_args_list[0][1]["serdes"] is expected
assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0"
assert mock_deserialize.call_args_list[1][1]["serdes"] is expected
assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1"
calls_by_operation_id = _mock_call_kwargs_by_operation_id(mock_deserialize)

assert set(calls_by_operation_id) == {"child-0", "child-1"}
assert calls_by_operation_id["child-0"]["serdes"] is expected
assert calls_by_operation_id["child-1"]["serdes"] is expected


def test_parallel_result_serialization_roundtrip():
Expand Down
Loading