diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 0000000..36a4a17 --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,35 @@ +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock + +import pytest + +from trakt.api import TokenAuth +from trakt.config import AuthConfig +from trakt.errors import OAuthException, OAuthRefreshException + + +def test_token_refresh_failure_raises_oauth_refresh_exception(): + config = AuthConfig('missing.json').update( + CLIENT_ID='client-id', + CLIENT_SECRET='client-secret', + OAUTH_TOKEN='stale-token', + OAUTH_REFRESH='refresh-token', + OAUTH_EXPIRES_AT=int((datetime.now(tz=timezone.utc) - timedelta(minutes=1)).timestamp()), + ) + response = Mock() + response.json.return_value = { + 'error': 'invalid_grant', + 'error_description': 'refresh token is invalid', + } + client = Mock() + client.post.side_effect = OAuthException(response=response) + + auth = TokenAuth(client=client, config=config) + + with pytest.raises(OAuthRefreshException) as exc_info: + auth.get_token() + + assert exc_info.value.error == 'invalid_grant' + assert exc_info.value.error_description == 'refresh token is invalid' + assert auth.TOKEN_UNDER_REFRESH is False + assert auth.OAUTH_TOKEN_VALID is False diff --git a/trakt/api.py b/trakt/api.py index 6ce7960..d858e69 100644 --- a/trakt/api.py +++ b/trakt/api.py @@ -11,7 +11,7 @@ from trakt.config import AuthConfig from trakt.core import TIMEOUT from trakt.errors import (BadRequestException, BadResponseException, - OAuthException) + OAuthException, OAuthRefreshException) __author__ = 'Elan Ruusamäe' @@ -166,9 +166,6 @@ class TokenAuth(AuthBase): #: The OAuth2 Redirect URI for your OAuth Application REDIRECT_URI: str = 'urn:ietf:wg:oauth:2.0:oob' - #: How many times to attempt token auth refresh before failing - MAX_RETRIES = 1 - # Time margin before token expiry when refresh should be triggered TOKEN_REFRESH_MARGIN = {'minutes': 10} @@ -180,7 +177,6 @@ def __init__(self, client: HttpClient, config: AuthConfig): self.client = client # OAuth token validity checked self.OAUTH_TOKEN_VALID = None - self.refresh_attempts = 0 self.TOKEN_UNDER_REFRESH = False def __call__(self, r): @@ -223,25 +219,21 @@ def validate_token(self): critical operations while also maximizing the token's useful lifetime. """ - current = datetime.now(tz=timezone.utc) - expires_at = datetime.fromtimestamp(self.config.OAUTH_EXPIRES_AT, tz=timezone.utc) - margin = expires_at - current - if margin > timedelta(**self.TOKEN_REFRESH_MARGIN): - self.OAUTH_TOKEN_VALID = True - else: - self.logger.debug("Token expires in %s, refreshing (margin: %s)", margin, self.TOKEN_REFRESH_MARGIN) - self.refresh_token() - - self.TOKEN_UNDER_REFRESH = False + try: + current = datetime.now(tz=timezone.utc) + expires_at = datetime.fromtimestamp(self.config.OAUTH_EXPIRES_AT, tz=timezone.utc) + margin = expires_at - current + if margin > timedelta(**self.TOKEN_REFRESH_MARGIN): + self.OAUTH_TOKEN_VALID = True + else: + self.logger.debug("Token expires in %s, refreshing (margin: %s)", margin, self.TOKEN_REFRESH_MARGIN) + self.refresh_token() + finally: + self.TOKEN_UNDER_REFRESH = False def refresh_token(self): """Request Trakt API for a new valid OAuth token using refresh_token""" - if self.refresh_attempts >= self.MAX_RETRIES: - self.logger.error("Max token refresh attempts reached. Manual intervention required.") - return - self.refresh_attempts += 1 - self.logger.info("OAuth token has expired, refreshing now...") data = { 'client_id': self.config.CLIENT_ID, @@ -253,24 +245,9 @@ def refresh_token(self): try: response = self.client.post('oauth/token', data) - self.refresh_attempts = 0 except (OAuthException, BadRequestException) as e: - if e.response is not None: - try: - data = e.response.json() - error = data.get("error") - error_description = data.get("error_description") - except JSONDecodeError: - error = "Invalid JSON response" - error_description = e.response.text - else: - error = "No error description" - error_description = "" - self.logger.error( - "%s - Unable to refresh expired OAuth token (%s) %s", - e.http_code, error, error_description - ) - return + self.OAUTH_TOKEN_VALID = False + raise OAuthRefreshException(response=e.response) from e self.config.update( OAUTH_TOKEN=response.get("access_token"),