-
Notifications
You must be signed in to change notification settings - Fork 0
feat: Add per-column metrics to summary #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ae98b68
31bcd39
db7735f
a1a854e
d3cc4f1
8261168
584ba71
f0e78b7
025d8ff
5f5e0e5
4d79870
266bebd
f32f793
8bca8d6
019d914
73bb89f
58645d0
36d9ecb
5eb965c
e8aa08a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |||||
|
|
||||||
| from ._compat import typer | ||||||
| from ._utils import ABS_TOL_DEFAULT, ABS_TOL_TEMPORAL_DEFAULT, REL_TOL_DEFAULT | ||||||
| from .metrics import _PRESETS | ||||||
|
|
||||||
| app = typer.Typer() | ||||||
|
|
||||||
|
|
@@ -129,8 +130,24 @@ def main( | |||||
| ) | ||||||
| ), | ||||||
| ] = [], | ||||||
| metric: Annotated[ | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| list[str], | ||||||
| typer.Option( | ||||||
| help=( | ||||||
| "Metric presets to display per numerical column. Repeatable. " | ||||||
| f"Available: {', '.join(_PRESETS)}." | ||||||
| ) | ||||||
| ), | ||||||
| ] = [], | ||||||
| ) -> None: | ||||||
| """Compare two `parquet` files and print the comparison result.""" | ||||||
| for name in metric: | ||||||
| if name not in _PRESETS: | ||||||
| raise typer.BadParameter( | ||||||
| f"Unknown metric: {name!r}. Available: {', '.join(_PRESETS)}." | ||||||
| ) | ||||||
| metrics = {name: _PRESETS[name] for name in metric} | ||||||
|
|
||||||
| comparison = compare_frames( | ||||||
| pl.scan_parquet(left), | ||||||
| pl.scan_parquet(right), | ||||||
|
|
@@ -148,6 +165,7 @@ def main( | |||||
| right_name=right_name, | ||||||
| slim=slim, | ||||||
| hidden_columns=hidden_columns, | ||||||
| metrics=metrics, | ||||||
| ) | ||||||
| if output_json: | ||||||
| typer.echo(summary.to_json()) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Copyright (c) QuantCo 2025-2026 | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
| from dataclasses import dataclass | ||
|
|
||
| import polars as pl | ||
| import polars.selectors as cs | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class _Metric: | ||
| """A metric paired with a column-applicability selector. | ||
|
|
||
| Internal only. | ||
| """ | ||
|
|
||
| fn: Metric | ||
| selector: pl.Expr | ||
|
|
||
|
|
||
| Metric = Callable[[pl.Expr, pl.Expr], pl.Expr] | ||
| """A metric is a callable mapping ``(left_expr, right_expr)`` to a scalar aggregation | ||
| expression. | ||
|
|
||
| The expressions refer to the left-side and right-side values of a single column across | ||
| all joined rows. | ||
| """ | ||
|
|
||
|
|
||
| def _make_numeric_metric(metric: Metric) -> _Metric: | ||
| return _Metric(fn=metric, selector=cs.numeric()) | ||
|
|
||
|
|
||
| def mean(left: pl.Expr, right: pl.Expr) -> pl.Expr: | ||
| """Mean of ``right - left``.""" | ||
| return (right - left).mean() | ||
|
|
||
|
|
||
| def median(left: pl.Expr, right: pl.Expr) -> pl.Expr: | ||
| """Median of ``right - left``.""" | ||
| return (right - left).median() | ||
|
|
||
|
|
||
| def min(left: pl.Expr, right: pl.Expr) -> pl.Expr: | ||
| """Minimum of ``right - left``.""" | ||
| return (right - left).min() | ||
|
|
||
|
|
||
| def max(left: pl.Expr, right: pl.Expr) -> pl.Expr: | ||
| """Maximum of ``right - left``.""" | ||
| return (right - left).max() | ||
|
|
||
|
|
||
| def std(left: pl.Expr, right: pl.Expr) -> pl.Expr: | ||
| """Standard deviation of ``right - left``.""" | ||
| return (right - left).std() | ||
|
|
||
|
|
||
| def mean_absolute_deviation(left: pl.Expr, right: pl.Expr) -> pl.Expr: | ||
| """Mean of ``|right - left|``.""" | ||
| return (right - left).abs().mean() | ||
|
|
||
|
|
||
| def mean_relative_deviation(left: pl.Expr, right: pl.Expr) -> pl.Expr: | ||
| """Mean of ``|(right - left) / left|``. Yields ``inf`` or ``null`` where | ||
| ``left`` is zero.""" | ||
| return ((right - left) / left).abs().mean() | ||
|
|
||
|
|
||
| def quantile(q: float) -> Metric: | ||
| """Factory returning a metric that computes the ``q``-quantile of | ||
| ``right - left``.""" | ||
| if not 0 <= q <= 1: | ||
| raise ValueError(f"q must be in [0, 1], got {q}") | ||
|
|
||
| def _quantile(left: pl.Expr, right: pl.Expr) -> pl.Expr: | ||
| return (right - left).quantile(q) | ||
|
|
||
| return _quantile | ||
|
EgeKaraismailogluQC marked this conversation as resolved.
|
||
|
|
||
|
|
||
| _PRESETS: dict[str, Metric] = { | ||
| "Mean": mean, | ||
| "Median": median, | ||
| "Min": min, | ||
| "Max": max, | ||
| "Std": std, | ||
| "Mean absolute deviation": mean_absolute_deviation, | ||
| "Mean relative deviation": mean_relative_deviation, | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,11 +6,14 @@ | |
| import dataclasses | ||
| import io | ||
| import json | ||
| from collections.abc import Mapping | ||
| from dataclasses import dataclass | ||
| from datetime import date, datetime, timedelta | ||
| from decimal import Decimal | ||
| from typing import TYPE_CHECKING, Any, Literal, cast | ||
|
|
||
| import polars as pl | ||
| import polars.selectors as cs | ||
| from rich import box | ||
| from rich.columns import Columns as RichColumns | ||
| from rich.console import Console, Group, RenderableType | ||
|
|
@@ -20,6 +23,7 @@ | |
| from rich.text import Text | ||
|
|
||
| from ._utils import Side, capitalize_first | ||
| from .metrics import Metric, _make_numeric_metric | ||
|
|
||
| if TYPE_CHECKING: # pragma: no cover | ||
| from .comparison import DataFrameComparison | ||
|
|
@@ -57,6 +61,7 @@ def __init__( | |
| right_name: str, | ||
| slim: bool, | ||
| hidden_columns: list[str] | None, | ||
| metrics: Mapping[str, Metric] | None, | ||
| ): | ||
| self.slim = slim | ||
| self._data = _compute_summary_data( | ||
|
|
@@ -69,6 +74,7 @@ def __init__( | |
| right_name=right_name, | ||
| slim=slim, | ||
| hidden_columns=hidden_columns, | ||
| metrics=metrics, | ||
| ) | ||
|
|
||
| def format(self, pretty: bool | None = None) -> str: | ||
|
|
@@ -565,13 +571,16 @@ def _section_columns(self) -> RenderableType: | |
| elif not columns: | ||
| display_items.append(Text("All columns match perfectly.", style="italic")) | ||
| else: | ||
| matches = Table(show_header=False) | ||
| metric_labels = self._data._metric_labels | ||
| matches = Table(show_header=bool(metric_labels)) | ||
| matches.add_column( | ||
| "Column", | ||
| max_width=COLUMN_SECTION_COLUMN_WIDTH, | ||
| overflow=OVERFLOW, | ||
| ) | ||
| matches.add_column("Match Rate", justify="right") | ||
| for label in metric_labels: | ||
| matches.add_column(label, justify="right") | ||
| has_top_changes_column = any( | ||
| c.changes is not None for c in columns if c.match_rate < 1 | ||
| ) | ||
|
|
@@ -583,6 +592,9 @@ def _section_columns(self) -> RenderableType: | |
| Text(col.name, style="cyan"), | ||
| f"{_format_fraction_as_percentage(col.match_rate)}", | ||
| ] | ||
| for label in metric_labels: | ||
| value = col.metrics.get(label) if col.metrics else None | ||
| row_items.append(_format_metric_value(value)) | ||
| if col.changes is not None: | ||
| change_lines = [] | ||
| for change in col.changes: | ||
|
|
@@ -703,6 +715,7 @@ class SummaryDataColumn: | |
| match_rate: float | ||
| n_total_changes: int | ||
| changes: list[SummaryDataColumnChange] | None | ||
| metrics: dict[str, Any] | None | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -720,6 +733,7 @@ class SummaryData: | |
| _other_common_columns: list[str] | ||
| _truncated_left_name: str | ||
| _truncated_right_name: str | ||
| _metric_labels: list[str] | ||
|
|
||
| def to_dict(self) -> dict[str, Any]: | ||
| def _convert(obj: Any) -> Any: | ||
|
|
@@ -758,6 +772,7 @@ def _compute_summary_data( | |
| right_name: str, | ||
| slim: bool, | ||
| hidden_columns: list[str] | None, | ||
| metrics: Mapping[str, Metric] | None, | ||
| ) -> SummaryData: | ||
| from .comparison import DataFrameComparison | ||
|
|
||
|
|
@@ -821,8 +836,13 @@ def _validate_primary_key_hidden_columns() -> None: | |
| _other_common_columns=comp._other_common_columns, | ||
| _truncated_left_name=truncated_left, | ||
| _truncated_right_name=truncated_right, | ||
| _metric_labels=[], | ||
| ) | ||
|
|
||
| metrics_resolved: dict[str, Metric] = dict(metrics or {}) | ||
| metrics_by_column = _compute_column_metrics(comp, metrics_resolved) | ||
| metric_labels = list(metrics_resolved.keys()) | ||
|
|
||
| schemas = _compute_schemas(comp, slim) | ||
| rows = _compute_rows(comp, slim) | ||
| columns = _compute_columns( | ||
|
|
@@ -831,6 +851,7 @@ def _validate_primary_key_hidden_columns() -> None: | |
| show_perfect_column_matches, | ||
| top_k_changes_by_column, | ||
| show_sample_primary_key_per_change, | ||
| metrics_by_column, | ||
| ) | ||
| sample_rows_left_only, sample_rows_right_only = _compute_sample_rows( | ||
| comp, sample_k_rows_only | ||
|
|
@@ -850,6 +871,7 @@ def _validate_primary_key_hidden_columns() -> None: | |
| _other_common_columns=comp._other_common_columns, | ||
| _truncated_left_name=truncated_left, | ||
| _truncated_right_name=truncated_right, | ||
| _metric_labels=metric_labels, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -911,12 +933,52 @@ def _compute_rows(comp: DataFrameComparison, slim: bool) -> SummaryDataRows | No | |
| ) | ||
|
|
||
|
|
||
| def _compute_column_metrics( | ||
| comp: DataFrameComparison, | ||
| metrics: Mapping[str, Metric], | ||
| ) -> dict[str, dict[str, Any]]: | ||
| if comp.primary_key is None or comp.num_rows_joined() == 0: | ||
| return {} | ||
|
|
||
| _metrics = {label: _make_numeric_metric(m) for label, m in metrics.items()} | ||
|
|
||
| def select_columns(selector: pl.Expr) -> set[str]: | ||
| left = set(cs.expand_selector(comp.left_schema, selector)) | ||
| right = set(cs.expand_selector(comp.right_schema, selector)) | ||
| return (left & right) & set(comp._other_common_columns) | ||
|
|
||
| metric_to_columns = { | ||
| label: select_columns(m.selector) for label, m in _metrics.items() | ||
| } | ||
|
|
||
| all_columns = sorted(set().union(*metric_to_columns.values())) | ||
| if not all_columns: | ||
| return {} | ||
| out: dict[str, dict[str, Any]] = {c: {} for c in all_columns} | ||
|
|
||
| joined = comp.joined(lazy=True) | ||
| agg_exprs = [ | ||
| metric.fn( | ||
| pl.col(f"{column}_{Side.LEFT}"), | ||
| pl.col(f"{column}_{Side.RIGHT}"), | ||
| ).alias(f"{label}__{column}") | ||
| for label, metric in _metrics.items() | ||
| for column in sorted(metric_to_columns[label]) | ||
| ] | ||
| row = joined.select(agg_exprs).collect().row(0, named=True) | ||
| for label, columns in metric_to_columns.items(): | ||
| for column in columns: | ||
| out[column][label] = row[f"{label}__{column}"] | ||
| return out | ||
|
|
||
|
|
||
| def _compute_columns( | ||
| comp: DataFrameComparison, | ||
| slim: bool, | ||
| show_perfect_column_matches: bool, | ||
| top_k_changes_by_column: dict[str, int], | ||
| show_sample_primary_key_per_change: bool, | ||
| metrics_by_column: dict[str, dict[str, Any]], | ||
| ) -> list[SummaryDataColumn] | None: | ||
| # NOTE: We can only compute column matches if there are primary key columns and at | ||
| # least one joined row. | ||
|
|
@@ -963,6 +1025,7 @@ def _compute_columns( | |
| match_rate=rate, | ||
| n_total_changes=n_total_changes, | ||
| changes=changes, | ||
| metrics=metrics_by_column.get(col_name), | ||
| ) | ||
| ) | ||
| return columns | ||
|
|
@@ -1038,9 +1101,17 @@ def _format_fraction_as_percentage(fraction: float) -> str: | |
| return f"{percentage:.2f}%" | ||
|
|
||
|
|
||
| def _format_value(value: Any) -> str: | ||
| def _yellow(raw: Any) -> str: | ||
| return f"[yellow]{raw}[/yellow]" | ||
|
|
||
|
|
||
| def _format_value(value: Any, *, float_format: str | None = None) -> str: | ||
| """Format a raw cell value for display, wrapped in the yellow highlight style. | ||
|
|
||
| Floats are shown with their default repr unless ``float_format`` is set. | ||
| """ | ||
| if isinstance(value, list): | ||
| formatted = [_format_value(x) for x in value] | ||
| formatted = [_format_value(x, float_format=float_format) for x in value] | ||
| if len(formatted) > 5: | ||
| return f"[{', '.join(formatted[:2])}, ..., {', '.join(formatted[-2:])}]" | ||
| return f"[{', '.join(formatted)}]" | ||
|
|
@@ -1052,9 +1123,21 @@ def _format_value(value: Any) -> str: | |
| raw = f'"{value}"' | ||
| elif isinstance(value, date | datetime): | ||
| raw = str(value) | ||
| elif isinstance(value, float) and float_format is not None: | ||
| raw = format(value, float_format) | ||
| else: | ||
| raw = value | ||
| return f"[yellow]{raw}[/yellow]" | ||
| return _yellow(raw) | ||
|
|
||
|
|
||
| def _format_metric_value(value: Any) -> str: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why don't we just add the case if isinstance(value, float):
return _yellow(f"{value:.4g}")to
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for this comment, I was also not very content with these two methods. The drawback of your suggestion is that it would format all floats with |
||
| """Format a metric value for the column summary. | ||
|
|
||
| Blanks out ``None`` and renders floats with ``.4g`` precision. | ||
| """ | ||
| if value is None: | ||
| return "" | ||
| return _format_value(value, float_format=".4g") | ||
|
|
||
|
|
||
| def _trim_whitespaces(s: str) -> str: | ||
|
|
||

Uh oh!
There was an error while loading. Please reload this page.