Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Add `forget: bool` flag to `POST /v2/memory/ingest`: memories with `forget=true` get a TTL (`expires_at`) and are excluded from retrieval (`_search_summary` + profile catalog) once expired. Read-time enforcement; no sweeper.
- Add `memory_forget_default_ttl_days` setting (env `MEMORY_FORGET_DEFAULT_TTL_DAYS`, default 30). Known limitation: changing this setting does not refresh the TTL of an already-cached (idempotent) forget job; this will be resolved once the TTL becomes a client-supplied field.
- `POST /v2/memory/batch-ingest` rejects `forget=true` with HTTP 400 (per-item forget not yet supported in batch).
- Add modular Razorpay billing, credit wallets, ledger reservations, and v2 memory workflow metering.
- Add durable Temporal-backed v2 memory and scanner workflow APIs with job status, retry, cancel, and dead-letter endpoints.
- Add modular LoCoMo and BEAM benchmark runners for the Python XMem API.
Expand Down
6 changes: 6 additions & 0 deletions src/api/routes/v2/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,14 @@ async def memory_domain_activity(payload: Dict[str, Any]) -> Dict[str, Any]:
billing_account_id=payload.get("billing_account_id"),
user_id=user_id,
):
lifecycle_metadata = payload.get("lifecycle_metadata")

if domain == "profile":
result = await pipeline._node_extract_profile(
{
"profile_queries": payload.get("queries", []),
"user_id": user_id,
"lifecycle_metadata": lifecycle_metadata,
}
)
return {"domain": domain, "result": _domain_payload(result, "profile")}
Expand All @@ -124,6 +127,7 @@ async def memory_domain_activity(payload: Dict[str, Any]) -> Dict[str, Any]:
"temporal_queries": payload.get("queries", []),
"session_datetime": payload.get("session_datetime", ""),
"user_id": user_id,
"lifecycle_metadata": lifecycle_metadata,
}
)
return {"domain": domain, "result": _domain_payload(result, "temporal")}
Expand All @@ -134,6 +138,7 @@ async def memory_domain_activity(payload: Dict[str, Any]) -> Dict[str, Any]:
"user_query": payload.get("user_query", ""),
"agent_response": payload.get("agent_response", ""),
"user_id": user_id,
"lifecycle_metadata": lifecycle_metadata,
}
)
return {"domain": domain, "result": _domain_payload(result, "summary")}
Expand All @@ -144,6 +149,7 @@ async def memory_domain_activity(payload: Dict[str, Any]) -> Dict[str, Any]:
"classifier_output": payload.get("classifier_output", ""),
"image_url": payload.get("image_url", ""),
"user_id": user_id,
"lifecycle_metadata": lifecycle_metadata,
}
)
return {"domain": domain, "result": _domain_payload(result, "image")}
Expand Down
24 changes: 24 additions & 0 deletions src/api/routes/v2/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from datetime import datetime, timezone
import time
from typing import Any, Dict

Expand All @@ -18,6 +19,7 @@
read_user_job,
)
from src.api.routes.v2.temporal_client import start_job_workflow
from src.pipelines.lifecycle import build_lifecycle_metadata
from src.api.schemas import APIResponse, BatchIngestRequest, IngestRequest, ScrapeRequest, StatusEnum
from src.billing import InsufficientCredits, get_default_billing_service
from src.config import settings
Expand Down Expand Up @@ -122,6 +124,15 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De
payload = req.model_dump()
payload["user_id"] = user_id
payload["timeout_seconds"] = float(settings.memory_ingest_timeout_seconds)

# When forget=true, compute lifecycle_metadata and thread it through so the
# weaver stamps the forget flag + TTL on every vector record it writes.
if req.forget:
payload["lifecycle_metadata"] = build_lifecycle_metadata(
now=datetime.now(timezone.utc),
ttl_days=settings.memory_forget_default_ttl_days,
)

idempotency_fields = {
"user_id": user_id,
"org_id": payload.get("org_id", "default"),
Expand All @@ -131,6 +142,12 @@ async def ingest_memory_v2(req: IngestRequest, request: Request, user: dict = De
"session_datetime": req.session_datetime,
"image_url": req.image_url,
"effort_level": req.effort_level,
# forget distinguishes forget vs non-forget of identical content.
# KNOWN LIMITATION (PR #2): server-default TTL is intentionally NOT hashed.
# Idempotency = "same request → same job"; the request didn't change, the
# server config did. Changing MEMORY_FORGET_DEFAULT_TTL_DAYS won't refresh a
# cached forget job's TTL. Resolved when forget_ttl_days becomes a client field.
"forget": req.forget,
}),
}
job_id = _durable_job_id("memory_ingest", idempotency_fields)
Expand Down Expand Up @@ -207,6 +224,13 @@ async def memory_job_status(job_id: str, request: Request, user: dict = Depends(
@router.post("/batch-ingest", response_model=APIResponse, summary="Start an async durable batch memory ingest job")
async def batch_ingest_memory_v2(req: BatchIngestRequest, request: Request, user: dict = Depends(require_api_key)):
start = time.perf_counter()
if any(getattr(item, "forget", False) for item in req.items):
return _error(
request,
"forget=true is not supported in batch ingest yet; use POST /v2/memory/ingest per item.",
400,
elapsed_ms(start),
)
user_id = memory_v1._current_user_id(user)
items = [memory_v1._scoped_ingest_payload(user, item) for item in req.items]
payload = {
Expand Down
8 changes: 8 additions & 0 deletions src/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ class IngestRequest(UserScopedModel):
default="low",
description="'low' (fast, single pass) or 'high' (chunked parallel extraction)",
)
forget: bool = Field(
default=False,
description=(
"When true, the stored memory is tagged with a TTL and will be "
"automatically excluded from retrieval after it expires. "
"Only honoured on the v2 ingest path."
),
)

@field_validator("user_query")
@classmethod
Expand Down
12 changes: 12 additions & 0 deletions src/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,18 @@ class Settings(BaseSettings):
description="Razorpay subscription plan ID for the global USD Pro plan",
)

# =============================================================================
# Memory Lifecycle — Forget / TTL (default OFF — fully backward compatible)
# =============================================================================
memory_forget_default_ttl_days: float = Field(
default=30.0,
gt=0.0,
description=(
"Default TTL in days for memories ingested with forget=true via the v2 API. "
"Expired records are hidden from retrieval (filtered at read time; no sweeper)."
),
)

@field_validator("fallback_order")
@classmethod
def validate_fallback_order(cls, v: List[str]) -> List[str]:
Expand Down
3 changes: 3 additions & 0 deletions src/pipelines/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,7 @@ async def _node_extract_profile(self, state: IngestState) -> Dict[str, Any]:
judge_result=judge_result,
domain=JudgeDomain.PROFILE,
user_id=user_id,
extra_metadata=state.get("lifecycle_metadata"),
)
return {
"profile_result": result,
Expand Down Expand Up @@ -963,6 +964,7 @@ async def _node_extract_image(self, state: IngestState) -> Dict[str, Any]:
judge_result=judge_result,
domain=JudgeDomain.SUMMARY,
user_id=user_id,
extra_metadata=state.get("lifecycle_metadata"),
)

return {
Expand Down Expand Up @@ -1090,6 +1092,7 @@ async def _node_extract_summary(self, state: IngestState) -> Dict[str, Any]:
judge_result=judge_result,
domain=JudgeDomain.SUMMARY,
user_id=state.get("user_id", "default"),
extra_metadata=state.get("lifecycle_metadata"),
)
return {
"summary_result": result,
Expand Down
77 changes: 77 additions & 0 deletions src/pipelines/lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Memory lifecycle — pure, deterministic helper functions for forget/TTL.

is_retrievable() is the single retrieval-time gate. build_lifecycle_metadata()
stamps the lifecycle fields onto a metadata dict for storage.

All functions are side-effect-free so they can be tested without live services.
"""

from __future__ import annotations

from datetime import datetime
from typing import Any, Dict, Mapping, Optional


def is_retrievable(metadata: Optional[Mapping[str, Any]], now: datetime) -> bool:
    """Return True when a record should appear in retrieval results.

    Rules (applied in order):
      1. ``lifecycle_state == "forgotten"`` → hidden (manual soft-forget).
      2. ``forget`` is truthy and ``expires_at`` is present and strictly
         earlier than ``now`` → hidden (TTL expired).
      3. Everything else (including all legacy records with no lifecycle
         keys) → retrievable.

    Missing keys default to the legacy-safe value so records stored before
    lifecycle was introduced are never hidden. An ``expires_at`` value that
    cannot be parsed or compared is treated as "no expiry" (fail-open), so a
    malformed record stays visible rather than silently disappearing.

    Args:
        metadata: The record's metadata mapping; ``None`` or empty means a
            legacy record.
        now: The reference instant. May be timezone-aware or naive; a naive
            value is interpreted as UTC when compared against an aware
            ``expires_at``.

    Returns:
        ``False`` when the record is soft-forgotten or its TTL has expired,
        ``True`` otherwise.
    """
    from datetime import timezone

    if not metadata:
        return True

    if metadata.get("lifecycle_state", "active") == "forgotten":
        return False

    # The TTL only applies to records explicitly ingested with forget=true.
    if not metadata.get("forget"):
        return True

    expires_raw = metadata.get("expires_at")
    if not expires_raw:
        return True

    try:
        if isinstance(expires_raw, datetime):
            expires_at = expires_raw
        else:
            text = str(expires_raw)
            # datetime.fromisoformat() only accepts the "Z" UTC designator
            # from Python 3.11 on; normalise it so older runtimes parse it too.
            if text.endswith(("Z", "z")):
                text = text[:-1] + "+00:00"
            expires_at = datetime.fromisoformat(text)

        # Bring both sides to the same awareness before comparing. Naive
        # timestamps are treated as UTC on either side.
        if expires_at.tzinfo is None and now.tzinfo is not None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        elif expires_at.tzinfo is not None and now.tzinfo is None:
            # Convert to UTC *before* dropping tzinfo; stripping the offset
            # directly would compare foreign wall-clock time against UTC.
            expires_at = expires_at.astimezone(timezone.utc).replace(tzinfo=None)

        return not expires_at < now
    except (ValueError, TypeError, OverflowError):
        # Deliberate best-effort: an unreadable expiry never hides a record.
        return True


def build_lifecycle_metadata(
    now: datetime,
    ttl_days: float,
    reason: Optional[str] = None,
) -> Dict[str, Any]:
    """Build the lifecycle fields to merge onto a ``forget=true`` record.

    The returned mapping carries the forget flag, an ISO-8601 ``expires_at``
    set ``ttl_days`` after ``now``, an ``"active"`` lifecycle state, and
    creation/update stamps equal to ``now``. A truthy ``reason`` is recorded
    under ``forget_reason``; otherwise that key is omitted.

    Args:
        now: Reference instant used for the timestamps and the TTL origin.
        ttl_days: Lifetime of the record in days (fractions allowed).
        reason: Optional free-text explanation for forgetting the record.

    Returns:
        A new dict of lifecycle metadata; the inputs are not modified.
    """
    from datetime import timedelta

    stamp = now.isoformat()
    deadline = (now + timedelta(days=ttl_days)).isoformat()

    lifecycle: Dict[str, Any] = dict(
        forget=True,
        expires_at=deadline,
        lifecycle_state="active",
        created_at=stamp,
        updated_at=stamp,
    )
    if not reason:
        return lifecycle

    lifecycle["forget_reason"] = reason
    return lifecycle
21 changes: 18 additions & 3 deletions src/pipelines/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from __future__ import annotations

import asyncio
from datetime import datetime, timezone
import logging
from typing import Any, Callable, Dict, List, Optional

Expand All @@ -31,6 +32,7 @@

from src.config import settings
from src.graph.neo4j_client import Neo4jClient
from src.pipelines.lifecycle import is_retrievable
from src.prompts.retrieval import ANSWER_PROMPT, build_system_prompt
from src.schemas.retrieval import RetrievalResult, SourceRecord
from src.schemas.code import snippets_namespace
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(
model: Optional[BaseChatModel] = None,
vector_store: Optional[BaseVectorStore] = None,
neo4j_client: Optional[Neo4jClient] = None,
_now: Optional[Callable[[], datetime]] = None,
) -> None:
# ── LLM ───────────────────────────────────────────────────────
if model is None:
Expand Down Expand Up @@ -133,6 +136,7 @@ def __init__(

self.embed_fn = embed_fn
self._snippet_stores: Dict[str, BaseVectorStore] = {}
self._now: Callable[[], datetime] = _now or (lambda: datetime.now(timezone.utc))

logger.info("RetrievalPipeline initialized")

Expand Down Expand Up @@ -413,8 +417,11 @@ async def _search_summary(
user_id: str,
top_k: int = 10,
) -> List[SourceRecord]:
"""Semantic search over summary entries in Pinecone."""
"""Semantic search over summary entries in Pinecone.

Records ingested with ``forget=true`` whose TTL has passed are filtered
out at read time. Legacy records (no lifecycle keys) always pass through.
"""
results = await self.vector_store.search_by_text(
query_text=query,
top_k=top_k,
Expand All @@ -424,8 +431,11 @@ async def _search_summary(
},
)

now = self._now()
records = []
for r in results:
if not is_retrievable(r.metadata, now):
continue
records.append(SourceRecord(
domain="summary",
content=r.content,
Expand Down Expand Up @@ -503,11 +513,16 @@ def _fetch_profile_catalog(self, user_id: str):
logger.warning("Failed to fetch profile catalog: %s", exc)
return [], []

now = self._now()
catalog: List[Dict[str, str]] = []
seen = set()
live_results = []

for r in results:
main_content = r.metadata.get("main_content", "")
if not is_retrievable(r.metadata, now):
continue
live_results.append(r)
main_content = (r.metadata or {}).get("main_content", "")
if not main_content or main_content in seen:
continue
Comment on lines 521 to 527
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If r.metadata is None (which is possible in some backends or mock tests, and is now handled safely by is_retrievable), calling r.metadata.get("main_content", "") on line 525 will raise an AttributeError. We should make this access defensive.

Suggested change
for r in results:
if not is_retrievable(r.metadata, now):
continue
live_results.append(r)
main_content = r.metadata.get("main_content", "")
if not main_content or main_content in seen:
continue
for r in results:
if not is_retrievable(r.metadata, now):
continue
live_results.append(r)
main_content = r.metadata.get("main_content", "") if r.metadata else ""
if not main_content or main_content in seen:
continue

seen.add(main_content)
Expand All @@ -524,7 +539,7 @@ def _fetch_profile_catalog(self, user_id: str):
"sub_topic": "",
})

return catalog, results
return catalog, live_results

def _format_catalog(self, catalog: List[Dict[str, str]]) -> str:
"""Format profile catalog for the system prompt."""
Expand Down
Loading
Loading