diff --git a/README.md b/README.md index e5b5dde..7f22289 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,25 @@ A derived class must either: If you change the keyword arguments to a class which derives from `Serializable` but would like to be able to deserialize older JSON representations then you can define a class-level dictionary called `_KEYWORD_ALIASES` which maps old keywords to new names (or `None` if a keyword was removed). +## `DataclassSerializable` for `@dataclass` subclasses + +If you're using `@dataclass` (e.g. in `vaxrank`, `pyensembl`, or `varcode`), inherit from `DataclassSerializable` instead of `Serializable`. It provides the same serialization surface — `to_dict` / `from_dict` / `to_json` / `from_json` — but leaves `__init__`, `__eq__`, `__repr__`, and `__hash__` to `@dataclass`, so you get dataclass-native equality and repr without conflicts. + +```python +from dataclasses import dataclass +from serializable import DataclassSerializable + +@dataclass +class Point(DataclassSerializable): + x: float + y: float + +p = Point(1.0, 2.0) +assert Point.from_json(p.to_json()) == p +``` + +The on-wire JSON format is identical to `Serializable`, so mixed codebases interoperate: a `DataclassSerializable` instance can reference a legacy `Serializable` object (and vice versa) and still round-trip cleanly. The `_SERIALIZABLE_KEYWORD_ALIASES` hook works the same way for migrating field names across releases. + ## Limitations - Serializable objects must inherit from `Serializable`, be tuples or namedtuples, be serializble primitive types such as dict, list, int, float, or str. diff --git a/serializable/__init__.py b/serializable/__init__.py index 6b0863f..ceff428 100644 --- a/serializable/__init__.py +++ b/serializable/__init__.py @@ -11,6 +11,7 @@ # limitations under the License. 
+from .dataclass_serializable import DataclassSerializable from .helpers import ( from_json, from_serializable_repr, @@ -22,6 +23,7 @@ from .version import __version__ __all__ = [ + "DataclassSerializable", "Serializable", "from_json", "from_serializable_repr", diff --git a/serializable/dataclass_serializable.py b/serializable/dataclass_serializable.py new file mode 100644 index 0000000..9187a73 --- /dev/null +++ b/serializable/dataclass_serializable.py @@ -0,0 +1,93 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Serialization mixin for classes decorated with ``@dataclass``. + +``Serializable`` (the original base class) supplies its own ``__init__`` +introspection plus ``__eq__`` / ``__repr__`` / ``__hash__``, which clashes +with the methods ``@dataclass`` generates. ``DataclassSerializable`` is a +lightweight alternative that contributes only the serialization surface — +``to_dict`` / ``from_dict`` / ``to_json`` / ``from_json`` — and leaves +equality, repr, hashing, and ``__init__`` to ``@dataclass``. + +Wire format parity with ``Serializable`` is preserved: the underlying +``to_serializable_repr`` helper dispatches on ``obj.to_dict()`` regardless +of which base the class inherits from, so a mixed codebase — some classes +migrated, some still on ``Serializable`` — round-trips JSON cleanly. 
+
+Example::
+
+    from dataclasses import dataclass
+    from serializable import DataclassSerializable
+
+    @dataclass
+    class Point(DataclassSerializable):
+        x: float
+        y: float
+
+    p = Point(1.0, 2.0)
+    assert Point.from_json(p.to_json()) == p
+"""
+
+from __future__ import annotations
+
+from dataclasses import fields
+from typing import Any, ClassVar
+
+from .helpers import from_json, from_serializable_repr, to_json, to_serializable_repr
+
+
+class DataclassSerializable:
+    """Mixin providing ``to_dict`` / ``from_dict`` / ``to_json`` / ``from_json``
+    for ``@dataclass``-decorated subclasses, without overriding the dunder
+    methods that ``@dataclass`` generates.
+
+    Subclasses may set ``_SERIALIZABLE_KEYWORD_ALIASES`` to migrate old field
+    names across releases: map an old name to the new name, or to ``None``
+    to drop it on load. This mirrors the same hook on ``Serializable``.
+    """
+
+    _SERIALIZABLE_KEYWORD_ALIASES: ClassVar[dict[str, str | None]] = {}
+
+    def to_dict(self) -> dict[str, Any]:
+        """Return a dict mapping each ``__init__`` field name to its current
+        value. Fields declared with ``field(init=False)`` are left out, so
+        ``cls(**obj.to_dict())`` only passes keywords the generated
+        ``__init__`` accepts."""
+        return {f.name: getattr(self, f.name) for f in fields(self) if f.init}
+
+    @classmethod
+    def from_dict(cls, state_dict: dict[str, Any]):
+        """Reconstruct an instance from a ``to_dict``-shaped dictionary,
+        applying ``_SERIALIZABLE_KEYWORD_ALIASES`` from every class in the MRO."""
+        kwargs = dict(state_dict)  # copy so the caller's dict is left untouched
+        for klass in cls.__mro__:
+            aliases = getattr(klass, "_SERIALIZABLE_KEYWORD_ALIASES", {})
+            for old_name, new_name in aliases.items():
+                if old_name in kwargs:
+                    value = kwargs.pop(old_name)  # old key never reaches __init__
+                    if new_name is not None and new_name not in kwargs:
+                        kwargs[new_name] = value  # only if the new key is absent
+        return cls(**kwargs)
+
+    def to_json(self) -> str:
+        return to_json(self)  # delegates to the module-level ``to_json`` helper
+
+    @classmethod
+    def from_json(cls, json_string: str):
+        """Deserialize with the module-level ``from_json`` helper; ``cls`` is unused."""
+        return from_json(json_string)
+
+    def __reduce__(self):
+        """Pickle via the same serializable representation used for JSON, so
+        pickles survive field-order or internal-layout changes across releases."""
+        return (from_serializable_repr, (to_serializable_repr(self),))
diff --git a/serializable/version.py b/serializable/version.py
index 89fed8f..7c0f2bf 100644
--- a/serializable/version.py
+++ b/serializable/version.py
@@ -1,4 +1,4 @@
-__version__ = "1.0.0"
+__version__ = "1.1.0"
 
 
 def print_version():
diff --git a/tests/test_dataclass_serializable.py b/tests/test_dataclass_serializable.py
new file mode 100644
index 0000000..d5b841d
--- /dev/null
+++ b/tests/test_dataclass_serializable.py
@@ -0,0 +1,184 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import pickle
+from dataclasses import dataclass, field
+from typing import ClassVar
+
+import pytest
+
+from serializable import DataclassSerializable, Serializable, from_json, to_json
+
+from .common import eq_
+
+
+@dataclass
+class Point(DataclassSerializable):
+    x: float
+    y: float
+
+
+@dataclass
+class Person(DataclassSerializable):
+    name: str
+    age: int
+    tags: list[str] = field(default_factory=list)
+
+
+class LegacyTag(Serializable):  # legacy-style class, used by the interop test
+    def __init__(self, name):
+        self.name = name
+
+
+@dataclass
+class Tagged(DataclassSerializable):
+    label: str
+    tag: LegacyTag | None = None
+
+
+@dataclass
+class Inner(DataclassSerializable):
+    n: int
+
+
+@dataclass
+class Outer(DataclassSerializable):
+    name: str
+    inner: Inner
+
+
+@dataclass(frozen=True)
+class FrozenPoint(DataclassSerializable):
+    x: float
+    y: float
+
+
+def test_to_dict_returns_field_values():
+    p = Point(1.0, 2.0)
+    eq_(p.to_dict(), {"x": 1.0, "y": 2.0})
+
+
+def test_from_dict_reconstructs_instance():
+    reconstructed = Point.from_dict({"x": 3.0, "y": 4.0})
+    eq_(reconstructed, Point(3.0, 4.0))
+
+
+def test_json_roundtrip_simple():
+    p = Point(1.0, 2.0)
+    eq_(Point.from_json(p.to_json()), p)
+
+
+def test_json_roundtrip_with_collection_field():
+    person = Person(name="Ada", age=36, tags=["mathematician", "engineer"])
+    eq_(Person.from_json(person.to_json()), person)
+
+
+def test_module_level_to_json_accepts_dataclass_serializable():
+    p = Point(1.0, 2.0)
+    # Calling the module-level helpers directly should still work since
+    # to_serializable_repr dispatches on obj.to_dict().
+    eq_(from_json(to_json(p)), p)
+
+
+def test_pickle_roundtrip():
+    p = Point(1.0, 2.0)
+    eq_(pickle.loads(pickle.dumps(p)), p)
+
+
+def test_dataclass_eq_and_repr_not_overridden_by_mixin():
+    # @dataclass generates __eq__ and __repr__ — the mixin must not shadow them.
+    a = Point(1.0, 2.0)
+    b = Point(1.0, 2.0)
+    c = Point(1.0, 3.0)
+    assert a == b
+    assert a != c
+    assert repr(a) == "Point(x=1.0, y=2.0)"
+
+
+def test_frozen_dataclass_is_hashable():
+    p1 = FrozenPoint(1.0, 2.0)
+    p2 = FrozenPoint(1.0, 2.0)
+    # Equal instances must hash alike and collapse to a single set member.
+    assert p1 == p2
+    assert hash(p1) == hash(p2)
+    assert {p1, p2} == {p1}
+
+
+def test_frozen_dataclass_pickle_roundtrip():
+    # Unpickling is expected to rebuild through cls(**kwargs), which must work on
+    # frozen dataclasses even though their __setattr__ is disabled.
+    p = FrozenPoint(1.0, 2.0)
+    eq_(pickle.loads(pickle.dumps(p)), p)
+
+
+def test_keyword_aliases_rename():
+    @dataclass
+    class Renamed(DataclassSerializable):
+        new_name: str
+        _SERIALIZABLE_KEYWORD_ALIASES: ClassVar[dict[str, str | None]] = {"old_name": "new_name"}
+
+    # A payload that still uses the old key must load into the renamed field.
+    obj = Renamed.from_dict({"old_name": "hello"})
+    eq_(obj, Renamed(new_name="hello"))
+
+
+def test_keyword_aliases_drop():
+    @dataclass
+    class Dropped(DataclassSerializable):
+        kept: int
+        _SERIALIZABLE_KEYWORD_ALIASES: ClassVar[dict[str, str | None]] = {"removed": None}
+
+    # Old wire format with an extra field that has since been dropped.
+    obj = Dropped.from_dict({"kept": 5, "removed": "ignored"})
+    eq_(obj, Dropped(kept=5))
+
+
+def test_keyword_aliases_inherited_from_parent():
+    # Aliases defined on a parent class should apply when loading a child,
+    # so migrations can live on a shared base without every subclass having
+    # to restate them.
+    @dataclass
+    class ParentWithAliases(DataclassSerializable):
+        kept: int
+        _SERIALIZABLE_KEYWORD_ALIASES: ClassVar[dict[str, str | None]] = {
+            "old_kept": "kept",
+            "removed": None,
+        }
+
+    @dataclass
+    class Child(ParentWithAliases):
+        extra: str = ""
+
+    # Rename and drop should both fire via the inherited alias dict.
+    obj = Child.from_dict({"old_kept": 7, "removed": "gone", "extra": "hi"})
+    eq_(obj, Child(kept=7, extra="hi"))
+
+
+def test_from_dict_rejects_unknown_field_without_alias():
+    with pytest.raises(TypeError):
+        Point.from_dict({"x": 1.0, "y": 2.0, "z": 3.0})  # "z": no field, no alias
+
+
+def test_interop_with_legacy_serializable():
+    # A legacy Serializable instance referenced from a DataclassSerializable
+    # field should round-trip through the shared wire format.
+    t = Tagged(label="x", tag=LegacyTag("demo"))
+    restored = Tagged.from_json(t.to_json())
+    eq_(restored.label, "x")
+    eq_(restored.tag.name, "demo")
+
+
+def test_nested_dataclass_serializable_roundtrip():
+    o = Outer(name="parent", inner=Inner(n=7))
+    eq_(Outer.from_json(o.to_json()), o)