Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/danger.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ concurrency:
jobs:
danger:
name: Danger Review
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
runs-on: ubuntu-latest
timeout-minutes: 10

Expand Down
43 changes: 7 additions & 36 deletions src/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
import hashlib
import hmac
import logging
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Optional

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt

from src.config import settings
from src.database.control_plane_store import control_plane_store
from src.database.api_key_store import APIKeyStore
from src.database.user_store import UserStore
from src.pipelines.ingest import IngestPipeline
Expand Down Expand Up @@ -288,48 +287,20 @@ async def require_user(current_user: Optional[dict] = Depends(get_current_user))


# ═══════════════════════════════════════════════════════════════════════════
# Sliding-window rate limiter (in-process, per-key)
# Sliding-window rate limiter
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _SlidingWindowRateLimiter class and its instance _rate_limiter appear to be redundant now that rate limiting logic has been moved to ControlPlaneStore. Consider removing them and updating the associated tests to avoid maintaining dead code.

# ═══════════════════════════════════════════════════════════════════════════

class _SlidingWindowRateLimiter:
"""Thread-safe sliding-window counter keyed by API identity."""

def __init__(self, max_requests: int, window_seconds: int = 60):
self.max_requests = max_requests
self.window = window_seconds
self._hits: dict[str, list[float]] = defaultdict(list)
self._lock = asyncio.Lock()

async def check(self, key: str) -> tuple[bool, int]:
"""Return (allowed, remaining) for *key*."""
now = time.monotonic()
cutoff = now - self.window

async with self._lock:
timestamps = self._hits[key]
self._hits[key] = [t for t in timestamps if t > cutoff]

if len(self._hits[key]) >= self.max_requests:
return False, 0

self._hits[key].append(now)
remaining = self.max_requests - len(self._hits[key])
return True, remaining


_rate_limiter = _SlidingWindowRateLimiter(
max_requests=settings.rate_limit,
window_seconds=60,
)


async def enforce_rate_limit(
request: Request,
user: dict = Depends(require_api_key),
) -> dict:
"""Raise 429 if the caller has exceeded their per-minute quota."""
identity = user.get("id", "anonymous")
allowed, remaining = await _rate_limiter.check(identity)
allowed, remaining = await control_plane_store.check_rate_limit(
identity,
max_requests=settings.rate_limit,
window_seconds=60,
)

request.state.rate_limit_remaining = remaining

Expand Down
33 changes: 15 additions & 18 deletions src/api/routes/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from src.config import settings
from src.config.analytics import analytics
from src.database.control_plane_store import control_plane_store

logger = logging.getLogger("xmem.api.admin")

Expand All @@ -45,7 +46,6 @@
# ═══════════════════════════════════════════════════════════════════════════

_admin_collection = None
_admin_sessions: Dict[str, Dict[str, Any]] = {} # token → {user, expires}


def _get_admin_collection():
Expand Down Expand Up @@ -80,23 +80,22 @@ class AdminLoginRequest(BaseModel):
password: str


def _verify_admin_token(request: Request) -> Dict[str, Any]:
async def _verify_admin_token(request: Request) -> Dict[str, Any]:
"""Validate admin session token from cookie or Authorization header."""
token = request.cookies.get("xmem_admin_token")
if not token:
auth = request.headers.get("Authorization", "")
if auth.startswith("Bearer "):
token = auth[7:]

if not token or token not in _admin_sessions:
if not token:
raise HTTPException(status_code=401, detail="Not authenticated")

session = _admin_sessions[token]
if datetime.now(timezone.utc) > session["expires"]:
del _admin_sessions[token]
user = await control_plane_store.get_admin_session_async(token)
if not user:
raise HTTPException(status_code=401, detail="Session expired")

return session["user"]
return user


# ═══════════════════════════════════════════════════════════════════════════
Expand All @@ -115,11 +114,11 @@ async def admin_login(req: AdminLoginRequest):
raise HTTPException(status_code=401, detail="Invalid credentials")

# Generate session token
token = hashlib.sha256(f"{req.username}{time.time()}".encode()).hexdigest()
_admin_sessions[token] = {
"user": {"username": user["username"], "role": user.get("role", "admin")},
"expires": datetime.now(timezone.utc) + timedelta(hours=24),
}
session = await control_plane_store.create_admin_session_async(
user={"username": user["username"], "role": user.get("role", "admin")},
ttl_seconds=24 * 60 * 60,
)
token = session["token"]

response = JSONResponse({"status": "ok", "token": token, "username": user["username"]})
response.set_cookie(
Expand All @@ -135,8 +134,8 @@ async def admin_login(req: AdminLoginRequest):
@router.post("/api/logout")
async def admin_logout(request: Request):
token = request.cookies.get("xmem_admin_token")
if token and token in _admin_sessions:
del _admin_sessions[token]
if token:
await control_plane_store.delete_admin_session_async(token)
response = JSONResponse({"status": "ok"})
response.delete_cookie("xmem_admin_token")
return response
Expand Down Expand Up @@ -220,7 +219,7 @@ async def ws_live_logs(websocket: WebSocket):

# Validate auth token from query param
token = websocket.query_params.get("token", "")
if token not in _admin_sessions:
if not token or not (await control_plane_store.get_admin_session_async(token)):
await websocket.close(code=4001, reason="Not authenticated")
return

Expand Down Expand Up @@ -314,7 +313,7 @@ async def _journal_stream():

if not line:
# journalctl exited — send error event and stop
yield f"event: error\ndata: journalctl process exited\n\n"
yield "event: error\ndata: journalctl process exited\n\n"
break

text = line.decode("utf-8", errors="replace").rstrip("\n")
Expand Down Expand Up @@ -386,8 +385,6 @@ async def analytics_summary(request: Request, user: dict = Depends(_verify_admin
now = datetime.now(timezone.utc)
last_24h = now - timedelta(hours=24)
last_7d = now - timedelta(days=7)
last_30d = now - timedelta(days=30)

# API call stats (last 24h)
api_calls_24h = list(collection.aggregate([
{"$match": {"event": "api_call", "ts": {"$gte": last_24h}}},
Expand Down
118 changes: 40 additions & 78 deletions src/api/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Authentication routes for Google OAuth and JWT management."""

import secrets
import string
from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from fastapi import APIRouter, Depends, Form, HTTPException, status
from fastapi.responses import JSONResponse
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token
from jose import JWTError, jwt
Expand All @@ -16,6 +14,7 @@
from src.config import settings
from src.database.user_store import UserStore
from src.database.api_key_store import APIKeyStore
from src.database.control_plane_store import control_plane_store

router = APIRouter(prefix="/auth", tags=["Authentication"])

Expand All @@ -24,86 +23,52 @@
api_key_store = APIKeyStore()

# ═══════════════════════════════════════════════════════════════════════════
# MCP OAuth Temp Token Store (in-memory with TTL)
# MCP OAuth Temp Token Store
# ═══════════════════════════════════════════════════════════════════════════
_mcp_temp_tokens: Dict[str, Dict[str, Any]] = {}
TEMP_TOKEN_PREFIX = "xm-temp-"
TEMP_TOKEN_TTL_MINUTES = 10
TEMP_TOKEN_LENGTH = 32

MCP_TEMP_TOKEN_RECORD = "mcp_temp_token"
OAUTH_AUTH_CODE_RECORD = "oauth_auth_code"

def _generate_mcp_temp_token() -> str:
"""Generate a temporary token for MCP OAuth flow."""
alphabet = string.ascii_letters + string.digits
random_part = "".join(secrets.choice(alphabet) for _ in range(TEMP_TOKEN_LENGTH))
return f"{TEMP_TOKEN_PREFIX}{random_part}"


def _create_mcp_temp_token(user_id: str) -> str:
async def _create_mcp_temp_token(user_id: str) -> dict:
"""Create and store a temporary token for the user."""
token = _generate_mcp_temp_token()
expires_at = datetime.utcnow() + timedelta(minutes=TEMP_TOKEN_TTL_MINUTES)

_mcp_temp_tokens[token] = {
"user_id": user_id,
"created_at": datetime.utcnow(),
"expires_at": expires_at,
"exchanged": False,
}

return token
return await control_plane_store.create_single_use_token_async(
record_type=MCP_TEMP_TOKEN_RECORD,
user_id=user_id,
prefix=TEMP_TOKEN_PREFIX,
ttl_seconds=TEMP_TOKEN_TTL_MINUTES * 60,
)


def _get_and_invalidate_mcp_token(token: str) -> Optional[str]:
async def _get_and_invalidate_mcp_token(token: str) -> Optional[str]:
"""Validate temp token and return user_id if valid, None otherwise."""
if token not in _mcp_temp_tokens:
return None

token_data = _mcp_temp_tokens[token]

# Check expiry
if datetime.utcnow() > token_data["expires_at"]:
del _mcp_temp_tokens[token]
return None

# Check if already exchanged
if token_data["exchanged"]:
return None

# Mark as exchanged and return user_id
user_id = token_data["user_id"]
del _mcp_temp_tokens[token] # Single-use token
return user_id
return await control_plane_store.consume_single_use_token_async(
MCP_TEMP_TOKEN_RECORD,
token,
)


# ═══════════════════════════════════════════════════════════════════════════
# Standard OAuth 2.0 Store (for ChatGPT UI)
# ═══════════════════════════════════════════════════════════════════════════
_oauth_auth_codes: Dict[str, Dict[str, Any]] = {}

def _generate_auth_code(user_id: str) -> str:
async def _generate_auth_code(user_id: str) -> str:
"""Generate a standard OAuth 2.0 authorization code."""
alphabet = string.ascii_letters + string.digits
code = "".join(secrets.choice(alphabet) for _ in range(32))

_oauth_auth_codes[code] = {
"user_id": user_id,
"expires_at": datetime.utcnow() + timedelta(minutes=10)
}
return code
created = await control_plane_store.create_single_use_token_async(
record_type=OAUTH_AUTH_CODE_RECORD,
user_id=user_id,
prefix="",
ttl_seconds=10 * 60,
)
return created["token"]

def _get_and_invalidate_auth_code(code: str) -> Optional[str]:

async def _get_and_invalidate_auth_code(code: str) -> Optional[str]:
"""Validate auth code and return user_id if valid."""
if code not in _oauth_auth_codes:
return None

data = _oauth_auth_codes[code]
del _oauth_auth_codes[code] # Single-use

if datetime.utcnow() > data["expires_at"]:
return None

return data["user_id"]
return await control_plane_store.consume_single_use_token_async(
OAUTH_AUTH_CODE_RECORD,
code,
)


# ═══════════════════════════════════════════════════════════════════════════
Expand Down Expand Up @@ -457,15 +422,15 @@ async def generate_mcp_temp_token(current_user: dict = Depends(require_user)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
)

user_id = str(current_user.get("id"))
temp_token = _create_mcp_temp_token(user_id)
temp_token = await _create_mcp_temp_token(user_id)

return MCPTempTokenResponse(
temp_token=temp_token,
temp_token=temp_token["token"],
expires_in=TEMP_TOKEN_TTL_MINUTES * 60,
expires_at=_mcp_temp_tokens[temp_token]["expires_at"]
expires_at=temp_token["expires_at"],
)


Expand All @@ -480,7 +445,7 @@ async def exchange_mcp_token(request: MCPExchangeRequest):
The temp token is single-use and invalidated after exchange.
"""
# Validate and consume the temp token
user_id = _get_and_invalidate_mcp_token(request.temp_token)
user_id = await _get_and_invalidate_mcp_token(request.temp_token)

if not user_id:
raise HTTPException(
Expand Down Expand Up @@ -531,13 +496,10 @@ async def oauth_approve(request: OAuthApproveRequest, current_user: dict = Depen
raise HTTPException(status_code=401, detail="Authentication required")

user_id = str(current_user.get("id"))
code = _generate_auth_code(user_id)
code = await _generate_auth_code(user_id)
return OAuthApproveResponse(code=code)


from fastapi import Form
from fastapi.responses import JSONResponse

@router.post("/oauth/token")
async def oauth_token(
grant_type: str = Form(...),
Expand All @@ -555,7 +517,7 @@ async def oauth_token(
if not code:
return JSONResponse(status_code=400, content={"error": "invalid_request", "error_description": "code is required"})

user_id = _get_and_invalidate_auth_code(code)
user_id = await _get_and_invalidate_auth_code(code)
if not user_id:
return JSONResponse(status_code=400, content={"error": "invalid_grant", "error_description": "Invalid or expired authorization code"})

Expand Down
Loading
Loading