diff --git a/src/proxy_server.py b/src/proxy_server.py
index ce00a04..cb30c68 100644
--- a/src/proxy_server.py
+++ b/src/proxy_server.py
@@ -8,6 +8,7 @@
 import asyncio
 import logging
+import random
 import re
 import socket
 import ssl

@@ -177,6 +178,7 @@ class ProxyServer:
         "video/",
         "audio/",
     )
+    _RELAY_RETRY_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}

     def __init__(self, config: dict):
         self.host = config.get("listen_host", "127.0.0.1")
@@ -212,6 +214,22 @@ def __init__(self, config: dict):
         self._download_max_chunks = self._cfg_int(
             config, "chunked_download_max_chunks", 256, minimum=1,
         )
+        self._relay_retry_on_502 = bool(config.get("relay_retry_on_502", True))
+        self._relay_retry_max_attempts = self._cfg_int(
+            # Stronger default for transient Apps Script throttling bursts.
+            config, "relay_retry_max_attempts", 3, minimum=1,
+        )
+        self._relay_retry_backoff = self._cfg_float(
+            config, "relay_retry_backoff_seconds", 1.8, minimum=0.1,
+        )
+        self._relay_retry_cooldown_seconds = self._cfg_float(
+            config, "relay_retry_cooldown_seconds", 3.0, minimum=0.1,
+        )
+        self._relay_retry_cooldown_max_seconds = self._cfg_float(
+            config, "relay_retry_cooldown_max_seconds", 12.0, minimum=0.1,
+        )
+        self._relay_retry_cooldown_by_host: dict[str, float] = {}
+        self._relay_retry_streak_by_host: dict[str, int] = {}
         self._download_extensions, self._download_any_extension = (
             self._normalize_download_extensions(
                 config.get(
@@ -1251,18 +1269,10 @@ async def _relay_http_stream(self, host: str, port: int, reader, writer):
             log.debug("Cache HIT: %s", url[:60])

         if response is None:
-            # Relay through Apps Script
-            try:
-                response = await self._relay_smart(method, url, headers, body)
-            except Exception as e:
-                log.error("Relay error (%s): %s", url[:60], e)
-                err_body = f"Relay error: {e}".encode()
-                response = (
-                    b"HTTP/1.1 502 Bad Gateway\r\n"
-                    b"Content-Type: text/plain\r\n"
-                    b"Content-Length: " + str(len(err_body)).encode() + b"\r\n"
-                    b"\r\n" + err_body
-                )
+            # Relay through Apps Script (+ transient 502 auto-retry)
+            response = await self._relay_smart_with_transient_retry(
+                method, url, headers, body,
+            )

         # Cache successful GET responses
         if self._cache_allowed(method, url, headers, body) and response:
@@ -1393,6 +1403,111 @@ def _is_likely_download(self, url: str, headers: dict) -> bool:
                 return True
         return False

+    @staticmethod
+    def _is_temporary_relay_502(response: bytes | None) -> bool:
+        if not response:
+            return False
+        if not response.startswith(b"HTTP/1.1 502"):
+            return False
+        sep = b"\r\n\r\n"
+        if sep not in response:
+            return False
+        _headers, body = response.split(sep, 1)
+        body_l = body.lower()
+        # Retry only known relay-generated 502s (temporary upstream throttling
+        # and Apps Script error payloads), not arbitrary origin 502 pages.
+        return (
+            b"relay error" in body_l
+            or b"relay response" in body_l
+            or b"bad json" in body_l
+            or b"no json" in body_l
+        )
+
+    @staticmethod
+    def _response_status_code(response: bytes | None) -> int:
+        if not response:
+            return 0
+        try:
+            line = response.split(b"\r\n", 1)[0].decode(errors="replace")
+        except Exception:
+            return 0
+        m = re.search(r"\b(\d{3})\b", line)
+        return int(m.group(1)) if m else 0
+
+    async def _relay_smart_with_transient_retry(self, method: str, url: str,
+                                                headers: dict, body: bytes) -> bytes:
+        method_u = method.upper()
+        host_key = (urlparse(url).hostname or "").lower().rstrip(".")
+        max_attempts = self._relay_retry_max_attempts
+        can_retry = (
+            self._relay_retry_on_502
+            and max_attempts > 1
+            and method_u in self._RELAY_RETRY_SAFE_METHODS
+            and not body
+        )
+        attempts = max_attempts if can_retry else 1
+        last_response: bytes | None = None
+
+        # Global back-pressure per host after temporary 502. This reduces burst
+        # pressure for workloads like sequential image downloads (100+ files)
+        # where immediate retries across concurrent clients can re-trigger the
+        # same Apps Script throttle window.
+        if can_retry and host_key:
+            now = time.time()
+            until = self._relay_retry_cooldown_by_host.get(host_key, 0.0)
+            if until > now:
+                await asyncio.sleep(until - now)
+
+        for attempt in range(1, attempts + 1):
+            try:
+                response = await self._relay_smart(method, url, headers, body)
+            except Exception as e:
+                log.error("Relay error (%s): %s", url[:60], e)
+                err_body = f"Relay error: {e}".encode()
+                response = (
+                    b"HTTP/1.1 502 Bad Gateway\r\n"
+                    b"Content-Type: text/plain\r\n"
+                    b"Content-Length: " + str(len(err_body)).encode() + b"\r\n"
+                    b"\r\n" + err_body
+                )
+
+            last_response = response
+            status_code = self._response_status_code(response)
+            is_temp_502 = self._is_temporary_relay_502(response)
+            is_temp_status = status_code in (429, 503)
+
+            if not (is_temp_502 or is_temp_status):
+                if host_key:
+                    self._relay_retry_streak_by_host.pop(host_key, None)
+                return response
+
+            if attempt >= attempts:
+                return response
+
+            backoff = self._relay_retry_backoff * attempt
+            if host_key:
+                streak = self._relay_retry_streak_by_host.get(host_key, 0) + 1
+                self._relay_retry_streak_by_host[host_key] = streak
+                adaptive_cooldown = min(
+                    self._relay_retry_cooldown_seconds * (2 ** (streak - 1)),
+                    self._relay_retry_cooldown_max_seconds,
+                )
+                # Add jitter so concurrent clients don't retry in lockstep.
+                jitter = random.uniform(0.0, adaptive_cooldown * 0.25)
+                self._relay_retry_cooldown_by_host[host_key] = (
+                    time.time() + adaptive_cooldown + jitter
+                )
+            log.warning(
+                "Transient relay failure status=%s for %s %s — retrying in %.1fs (%d/%d)",
+                status_code or 502, method_u, url[:80], backoff, attempt, attempts,
+            )
+            await asyncio.sleep(backoff)
+
+        return last_response or (
+            b"HTTP/1.1 502 Bad Gateway\r\n"
+            b"Content-Length: 0\r\n\r\n"
+        )
+
     async def _maybe_stream_download(self, method: str, url: str,
                                      headers: dict | None, body: bytes,
                                      writer) -> bool:
@@ -1475,7 +1590,9 @@ async def _do_http(self, header_block: bytes, reader, writer):
             log.debug("Cache HIT (HTTP): %s", url[:60])
         if response is None:
-            response = await self._relay_smart(method, url, headers, body)
+            response = await self._relay_smart_with_transient_retry(
+                method, url, headers, body,
+            )

         # Cache successful GET
         if self._cache_allowed(method, url, headers, body) and response:
             ttl = ResponseCache.parse_ttl(response, url)