Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 130 additions & 13 deletions src/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import asyncio
import logging
import random
import re
import socket
import ssl
Expand Down Expand Up @@ -177,6 +178,7 @@ class ProxyServer:
"video/",
"audio/",
)
_RELAY_RETRY_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}

def __init__(self, config: dict):
self.host = config.get("listen_host", "127.0.0.1")
Expand Down Expand Up @@ -212,6 +214,22 @@ def __init__(self, config: dict):
self._download_max_chunks = self._cfg_int(
config, "chunked_download_max_chunks", 256, minimum=1,
)
self._relay_retry_on_502 = bool(config.get("relay_retry_on_502", True))
self._relay_retry_max_attempts = self._cfg_int(
# Stronger default for transient Apps Script throttling bursts.
config, "relay_retry_max_attempts", 3, minimum=1,
)
self._relay_retry_backoff = self._cfg_float(
config, "relay_retry_backoff_seconds", 1.8, minimum=0.1,
)
self._relay_retry_cooldown_seconds = self._cfg_float(
config, "relay_retry_cooldown_seconds", 3.0, minimum=0.1,
)
self._relay_retry_cooldown_max_seconds = self._cfg_float(
config, "relay_retry_cooldown_max_seconds", 12.0, minimum=0.1,
)
self._relay_retry_cooldown_by_host: dict[str, float] = {}
self._relay_retry_streak_by_host: dict[str, int] = {}
self._download_extensions, self._download_any_extension = (
self._normalize_download_extensions(
config.get(
Expand Down Expand Up @@ -1251,18 +1269,10 @@ async def _relay_http_stream(self, host: str, port: int, reader, writer):
log.debug("Cache HIT: %s", url[:60])

if response is None:
# Relay through Apps Script
try:
response = await self._relay_smart(method, url, headers, body)
except Exception as e:
log.error("Relay error (%s): %s", url[:60], e)
err_body = f"Relay error: {e}".encode()
response = (
b"HTTP/1.1 502 Bad Gateway\r\n"
b"Content-Type: text/plain\r\n"
b"Content-Length: " + str(len(err_body)).encode() + b"\r\n"
b"\r\n" + err_body
)
# Relay through Apps Script (+ transient 502 auto-retry)
response = await self._relay_smart_with_transient_retry(
method, url, headers, body,
)

# Cache successful GET responses
if self._cache_allowed(method, url, headers, body) and response:
Expand Down Expand Up @@ -1393,6 +1403,111 @@ def _is_likely_download(self, url: str, headers: dict) -> bool:
return True
return False

@staticmethod
def _is_temporary_relay_502(response: bytes | None) -> bool:
if not response:
return False
if not response.startswith(b"HTTP/1.1 502"):
return False
sep = b"\r\n\r\n"
if sep not in response:
return False
_headers, body = response.split(sep, 1)
body_l = body.lower()
# Retry only known relay-generated 502s (temporary upstream throttling
# and Apps Script error payloads), not arbitrary origin 502 pages.
return (
b"relay error" in body_l
or b"relay response" in body_l
or b"bad json" in body_l
or b"no json" in body_l
)

@staticmethod
def _response_status_code(response: bytes | None) -> int:
if not response:
return 0
try:
line = response.split(b"\r\n", 1)[0].decode(errors="replace")
except Exception:
return 0
m = re.search(r"\b(\d{3})\b", line)
return int(m.group(1)) if m else 0

async def _relay_smart_with_transient_retry(self, method: str, url: str,
                                            headers: dict, body: bytes) -> bytes:
    """Relay a request via ``_relay_smart``, retrying transient failures.

    Retries are attempted only for safe, body-less requests
    (GET/HEAD/OPTIONS) when ``relay_retry_on_502`` is enabled, and only
    for failures classified as transient: relay-generated 502s (see
    ``_is_temporary_relay_502``) or 429/503 status codes. Exceptions from
    ``_relay_smart`` are converted into a synthetic 502 response instead
    of propagating, so callers always receive raw HTTP response bytes.

    Returns:
        Raw HTTP response bytes — the first non-transient response, or
        the last transient response once attempts are exhausted.
    """
    method_u = method.upper()
    # Normalized host key (lowercased, trailing dot stripped) used for the
    # shared per-host cooldown/streak bookkeeping.
    host_key = (urlparse(url).hostname or "").lower().rstrip(".")
    max_attempts = self._relay_retry_max_attempts
    can_retry = (
        self._relay_retry_on_502
        and max_attempts > 1
        and method_u in self._RELAY_RETRY_SAFE_METHODS
        # Requests with a body are never retried — presumably to avoid
        # replaying non-idempotent payloads. NOTE(review): confirm intent.
        and not body
    )
    attempts = max_attempts if can_retry else 1
    last_response: bytes | None = None

    # Global back-pressure per host after temporary 502. This reduces burst
    # pressure for workloads like sequential image downloads (100+ files)
    # where immediate retries across concurrent clients can re-trigger the
    # same Apps Script throttle window.
    # NOTE(review): the cooldown is honored only once, before the first
    # attempt — cooldowns stamped by other coroutines mid-retry are ignored.
    if can_retry and host_key:
        now = time.time()
        until = self._relay_retry_cooldown_by_host.get(host_key, 0.0)
        if until > now:
            await asyncio.sleep(until - now)

    for attempt in range(1, attempts + 1):
        try:
            response = await self._relay_smart(method, url, headers, body)
        except Exception as e:
            # Convert any relay failure into a synthetic 502 so the retry
            # classification below can treat it uniformly.
            log.error("Relay error (%s): %s", url[:60], e)
            err_body = f"Relay error: {e}".encode()
            response = (
                b"HTTP/1.1 502 Bad Gateway\r\n"
                b"Content-Type: text/plain\r\n"
                b"Content-Length: " + str(len(err_body)).encode() + b"\r\n"
                b"\r\n" + err_body
            )

        last_response = response
        status_code = self._response_status_code(response)
        is_temp_502 = self._is_temporary_relay_502(response)
        is_temp_status = status_code in (429, 503)

        if not (is_temp_502 or is_temp_status):
            # Success (or a non-transient failure): reset the host's
            # failure streak and hand the response back unchanged.
            # NOTE(review): the cooldown entry is intentionally(?) left in
            # place here and only expires by timestamp comparison.
            if host_key:
                self._relay_retry_streak_by_host.pop(host_key, None)
            return response

        if attempt >= attempts:
            # Out of attempts — return the last transient response as-is.
            return response

        # Linear per-attempt backoff for *this* request.
        backoff = self._relay_retry_backoff * attempt
        if host_key:
            # Exponential, capped per-host cooldown shared with other
            # coroutines, growing with the consecutive-failure streak.
            streak = self._relay_retry_streak_by_host.get(host_key, 0) + 1
            self._relay_retry_streak_by_host[host_key] = streak
            adaptive_cooldown = min(
                self._relay_retry_cooldown_seconds * (2 ** (streak - 1)),
                self._relay_retry_cooldown_max_seconds,
            )
            # Add jitter so concurrent clients don't retry in lockstep.
            jitter = random.uniform(0.0, adaptive_cooldown * 0.25)
            self._relay_retry_cooldown_by_host[host_key] = (
                time.time() + adaptive_cooldown + jitter
            )
        log.warning(
            "Transient relay failure status=%s for %s %s — retrying in %.1fs (%d/%d)",
            status_code or 502, method_u, url[:80], backoff, attempt, attempts,
        )
        await asyncio.sleep(backoff)

    # NOTE(review): defensive fallback — every loop iteration returns
    # (non-transient → early return; final attempt → return), so this
    # path should be unreachable.
    return last_response or (
        b"HTTP/1.1 502 Bad Gateway\r\n"
        b"Content-Length: 0\r\n\r\n"
    )

async def _maybe_stream_download(self, method: str, url: str,
headers: dict | None, body: bytes,
writer) -> bool:
Expand Down Expand Up @@ -1475,7 +1590,9 @@ async def _do_http(self, header_block: bytes, reader, writer):
log.debug("Cache HIT (HTTP): %s", url[:60])

if response is None:
response = await self._relay_smart(method, url, headers, body)
response = await self._relay_smart_with_transient_retry(
method, url, headers, body,
)
# Cache successful GET
if self._cache_allowed(method, url, headers, body) and response:
ttl = ResponseCache.parse_ttl(response, url)
Expand Down