Skip to content

Commit 856eea6

Browse files
committed
fix: improve Instagram OAuth2 token refresh flow
1 parent 8b14aaa commit 856eea6

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

backend/aci/server/oauth2_manager.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import time
44
from typing import Any, cast
55

6+
import httpx
67
from authlib.integrations.httpx_client import AsyncOAuth2Client
78

89
from aci.common.exceptions import OAuth2Error
910
from aci.common.logging_setup import get_logger
1011
from aci.common.schemas.security_scheme import OAuth2SchemeCredentials
12+
from aci.server import config
1113

1214
UNICODE_ASCII_CHARACTER_SET = string.ascii_letters + string.digits
1315
logger = get_logger(__name__)
@@ -214,6 +216,16 @@ async def refresh_token(
214216
access_token: str,
215217
refresh_token: str,
216218
) -> dict[str, Any]:
219+
"""
220+
Refresh OAuth2 access token
221+
222+
Args:
223+
access_token: The current access token used for Instagram refresh
224+
refresh_token: The refresh token used for standard OAuth2 refresh
225+
226+
Returns:
227+
Token response dictionary
228+
"""
217229
try:
218230
if self.app_name == "INSTAGRAM":
219231
response = await self.oauth2_client.get(
@@ -233,11 +245,22 @@ async def refresh_token(
233245
self.refresh_token_url, refresh_token=refresh_token
234246
),
235247
)
236-
return token
248+
249+
except httpx.HTTPStatusError as e:
250+
logger.error(f"Failed to refresh access token, app_name={self.app_name}, error={e}")
251+
if self.app_name == "INSTAGRAM" and e.response.status_code == 400:
252+
raise OAuth2Error(
253+
f"Access token expired. Please re-authorize at: "
254+
f"{config.DEV_PORTAL_URL}/appconfigs/{self.app_name}"
255+
) from e
256+
raise OAuth2Error("Failed to refresh access token") from e
257+
237258
except Exception as e:
238259
logger.error(f"Failed to refresh access token, app_name={self.app_name}, error={e}")
239260
raise OAuth2Error("Failed to refresh access token") from e
240261

262+
return token
263+
241264
async def parse_fetch_token_response(self, token: dict) -> OAuth2SchemeCredentials:
242265
"""
243266
Parse OAuth2SchemeCredentials from token response with app-specific handling.

backend/aci/server/security_credentials_manager.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,20 @@ async def _get_oauth2_credentials(
9797
linked_account.security_credentials
9898
)
9999
if _access_token_is_expired(oauth2_scheme_credentials):
100-
# Instagram's access token only could be refreshed with a valid access token, so we need to re-authorize
100+
# Instagram's access token only could be refreshed with a valid access token, so we need to re-authorize if invalid
101101
if app.name == "INSTAGRAM":
102-
logger.error(
103-
f"Access token expired, please re-authorize, linked_account_id={linked_account.id}, "
104-
f"security_scheme={linked_account.security_scheme}, app={app.name}"
105-
)
106-
# NOTE: this error message could be used by the frontend to guide the user to re-authorize
107-
raise OAuth2Error(
108-
f"Access token expired. Please re-authorize at: "
109-
f"{config.DEV_PORTAL_URL}/appconfigs/{app.name}"
110-
)
102+
# Since _access_token_is_expired returned True, expires_at is guaranteed to be not None
103+
actual_expires_at = oauth2_scheme_credentials.expires_at + 86400 # type: ignore[operator]
104+
if int(time.time()) > actual_expires_at:
105+
logger.error(
106+
f"Access token expired, please re-authorize, linked_account_id={linked_account.id}, "
107+
f"security_scheme={linked_account.security_scheme}, app={app.name}"
108+
)
109+
# NOTE: this error message could be used by the frontend to guide the user to re-authorize
110+
raise OAuth2Error(
111+
f"Access token expired. Please re-authorize at: "
112+
f"{config.DEV_PORTAL_URL}/appconfigs/{app.name}"
113+
)
111114

112115
logger.warning(
113116
f"Access token expired, trying to refresh linked_account_id={linked_account.id}, "
@@ -116,6 +119,7 @@ async def _get_oauth2_credentials(
116119
token_response = await _refresh_oauth2_access_token(
117120
app.name, oauth2_scheme, oauth2_scheme_credentials
118121
)
122+
119123
# TODO: refactor parsing to _refresh_oauth2_access_token
120124
expires_at: int | None = None
121125
if "expires_at" in token_response:

0 commit comments

Comments
 (0)