diff --git a/core/clients/auth_flow.go b/core/clients/auth_flow.go index 1aed5df11..0e442c63f 100644 --- a/core/clients/auth_flow.go +++ b/core/clients/auth_flow.go @@ -85,5 +85,5 @@ func tokenExpired(token string, tokenExpirationLeeway time.Duration) (bool, erro // Pretend to be `tokenExpirationLeeway` into the future to avoid token expiring // between retrieving the token and upstream systems validating it. now := time.Now().Add(tokenExpirationLeeway) - return now.After(expirationTimestampNumeric.Time), nil + return now.After(expirationTimestampNumeric.Time) || now.Equal(expirationTimestampNumeric.Time), nil } diff --git a/core/clients/continuous_refresh_test.go b/core/clients/continuous_refresh_test.go index 110d47ef9..311ac6ba1 100644 --- a/core/clients/continuous_refresh_test.go +++ b/core/clients/continuous_refresh_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "testing" + "testing/synctest" "time" "github.com/golang-jwt/jwt/v5" @@ -93,36 +94,39 @@ func TestContinuousRefreshToken(t *testing.T) { tt := tt t.Run(tt.desc, func(t *testing.T) { t.Parallel() - accessToken, err := signToken(accessTokensTimeToLive) - if err != nil { - t.Fatalf("failed to sign access token: %v", err) - } - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) - defer cancel() - - authFlow := &fakeAuthFlow{ - backgroundTokenRefreshContext: ctx, - doError: tt.doError, - accessTokensTimeToLive: accessTokensTimeToLive, - accessToken: accessToken, - } + synctest.Test(t, func(t *testing.T) { + accessToken, err := signToken(accessTokensTimeToLive) + if err != nil { + t.Fatalf("failed to sign access token: %v", err) + } + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, tt.contextClosesIn) + defer cancel() + + authFlow := &fakeAuthFlow{ + backgroundTokenRefreshContext: ctx, + doError: tt.doError, + accessTokensTimeToLive: accessTokensTimeToLive, + accessToken: accessToken, + } - refresher := &continuousTokenRefresher{ - flow: authFlow, - timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, - timeBetweenContextCheck: timeBetweenContextCheck, - timeBetweenTries: timeBetweenTries, - } + refresher := &continuousTokenRefresher{ + flow: authFlow, + timeStartBeforeTokenExpiration: timeStartBeforeTokenExpiration, + timeBetweenContextCheck: timeBetweenContextCheck, + timeBetweenTries: timeBetweenTries, + } - err = refresher.continuousRefreshToken() - if err == nil { - t.Fatalf("routine finished with non-nil error") - } - numberDoCalls := authFlow.getTokenCalls() - if numberDoCalls != tt.expectedNumberDoCalls { - t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) - } + err = refresher.continuousRefreshToken() + synctest.Wait() + if err == nil { + t.Fatalf("routine finished with non-nil error") + } + numberDoCalls := authFlow.getTokenCalls() + if numberDoCalls != tt.expectedNumberDoCalls { + t.Fatalf("expected %d calls to API to refresh token, got %d", tt.expectedNumberDoCalls, numberDoCalls) + } + }) }) } }