Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/aci/common/schemas/security_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class OAuth2Scheme(BaseModel):
description="The authentication method for the OAuth2 token endpoint, e.g., 'client_secret_post' "
"for some providers that require client_id/client_secret to be sent in the body of the token request, like Hubspot",
)
custom_data: dict | None = Field(
default=None,
description="Custom data for OAuth2 scheme, e.g., additional URLs or configuration parameters specific to the provider",
)
# NOTE: For now this field should not be provided when creating a new OAuth2 App (because the current server redirect URL should be used,
# which is constructed dynamically).
# It only makes sense for user to provide it in OAuth2SchemeOverride if they want whitelabeling.
Expand Down
122 changes: 113 additions & 9 deletions backend/aci/server/oauth2_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import time
from typing import Any, cast

import httpx
from authlib.integrations.httpx_client import AsyncOAuth2Client

from aci.common.exceptions import OAuth2Error
from aci.common.logging_setup import get_logger
from aci.common.schemas.security_scheme import OAuth2SchemeCredentials
from aci.server import config

UNICODE_ASCII_CHARACTER_SET = string.ascii_letters + string.digits
logger = get_logger(__name__)
Expand All @@ -24,6 +26,7 @@ def __init__(
access_token_url: str,
refresh_token_url: str,
token_endpoint_auth_method: str | None = None,
custom_data: dict | None = None,
):
"""
Initialize the OAuth2Manager
Expand All @@ -39,6 +42,7 @@ def __init__(
token_endpoint_auth_method:
client_secret_basic (default) | client_secret_post | none
Additional options can be achieved by registering a custom auth method
custom_data: Custom data for OAuth2 scheme, e.g., additional URLs or configuration
"""
self.app_name = app_name
self.client_id = client_id
Expand All @@ -48,6 +52,7 @@ def __init__(
self.access_token_url = access_token_url
self.refresh_token_url = refresh_token_url
self.token_endpoint_auth_method = token_endpoint_auth_method
self.custom_data = custom_data or {}

# TODO: need to close the client after use
# Add an aclose() helper (or implement __aenter__/__aexit__) and make callers invoke it during shutdown.
Expand Down Expand Up @@ -111,6 +116,48 @@ async def create_authorization_url(

return str(authorization_url)

async def exchange_short_lived_token(self, short_lived_token: str) -> dict[str, Any]:
"""
Exchange short-lived access token for long-lived access token.
This is specific to Instagram's API requirements.

Args:
short_lived_token: The short-lived access token from the initial OAuth flow

Returns:
Token response dictionary with long-lived access token
"""
if self.app_name != "INSTAGRAM":
raise OAuth2Error("Token exchange is only supported for Instagram")

exchange_token_url = self.custom_data.get(
"exchange_token_url", "https://graph.instagram.com/access_token"
)

try:
response = await self.oauth2_client.get(
exchange_token_url,
params={
"grant_type": "ig_exchange_token",
"client_secret": self.client_secret,
"access_token": short_lived_token,
},
timeout=30.0,
)
response.raise_for_status()

token_data = cast(dict[str, Any], response.json())
logger.info(
f"Successfully exchanged short-lived token for long-lived token, app_name={self.app_name}"
)
return token_data

except Exception as e:
logger.error(
f"Failed to exchange short-lived token, app_name={self.app_name}, error={e}"
)
raise OAuth2Error("Failed to exchange short-lived token for long-lived token") from e

# TODO: some app may not support "code_verifier"?
async def fetch_token(
self,
Expand Down Expand Up @@ -140,28 +187,81 @@ async def fetch_token(
scope=self.scope,
),
)
# handle Instagram's special case - exchange short-lived token for long-lived token
if self.app_name == "INSTAGRAM":
if "access_token" in token:
short_lived_token = token["access_token"]
logger.info(
f"Exchanging short-lived token for long-lived token, app_name={self.app_name}"
)
long_lived_token_response = await self.exchange_short_lived_token(
short_lived_token
)
# Update data with long-lived token response: add expires_in and token_type, update access_token
token.update(long_lived_token_response)
else:
logger.error(
f"Missing access_token in Instagram OAuth response, app={self.app_name}"
)
raise OAuth2Error("Missing access_token in Instagram OAuth response")

# return the token response with long-lived access token
return token
except Exception as e:
logger.error(f"Failed to fetch access token, app_name={self.app_name}, error={e}")
raise OAuth2Error("failed to fetch access token") from e

async def refresh_token(
self,
access_token: str,
refresh_token: str,
) -> dict[str, Any]:
"""
Refresh OAuth2 access token

Args:
access_token: The current access token used for Instagram refresh
refresh_token: The refresh token used for standard OAuth2 refresh

Returns:
Token response dictionary
"""
try:
token = cast(
dict[str, Any],
await self.oauth2_client.refresh_token(
self.refresh_token_url, refresh_token=refresh_token
),
)
return token
if self.app_name == "INSTAGRAM":
response = await self.oauth2_client.get(
self.refresh_token_url,
params={
"grant_type": "ig_refresh_token",
"access_token": access_token,
},
timeout=30.0,
)
response.raise_for_status()
token = cast(dict[str, Any], response.json())
else:
token = cast(
dict[str, Any],
await self.oauth2_client.refresh_token(
self.refresh_token_url, refresh_token=refresh_token
),
)

except httpx.HTTPStatusError as e:
logger.error(f"Failed to refresh access token, app_name={self.app_name}, error={e}")
if self.app_name == "INSTAGRAM" and e.response.status_code == 400:
raise OAuth2Error(
f"Access token expired. Please re-authorize at: "
f"{config.DEV_PORTAL_URL}/appconfigs/{self.app_name}"
) from e
raise OAuth2Error("Failed to refresh access token") from e

except Exception as e:
logger.error(f"Failed to refresh access token, app_name={self.app_name}, error={e}")
raise OAuth2Error("Failed to refresh access token") from e

def parse_fetch_token_response(self, token: dict) -> OAuth2SchemeCredentials:
return token

async def parse_fetch_token_response(self, token: dict) -> OAuth2SchemeCredentials:
"""
Parse OAuth2SchemeCredentials from token response with app-specific handling.

Expand Down Expand Up @@ -190,7 +290,11 @@ def parse_fetch_token_response(self, token: dict) -> OAuth2SchemeCredentials:
if "expires_at" in data:
expires_at = int(data["expires_at"])
elif "expires_in" in data:
expires_at = int(time.time()) + int(data["expires_in"])
if self.app_name == "INSTAGRAM":
# Reduce expiration time by 1 day (86400 seconds) for safety margin
expires_at = int(time.time()) + max(0, int(data["expires_in"]) - 86400)
else:
expires_at = int(time.time()) + int(data["expires_in"])

# TODO: if scope is present, check if it matches the scope in the App Configuration

Expand Down
4 changes: 3 additions & 1 deletion backend/aci/server/routes/linked_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ async def link_oauth2_account(
access_token_url=oauth2_scheme.access_token_url,
refresh_token_url=oauth2_scheme.refresh_token_url,
token_endpoint_auth_method=oauth2_scheme.token_endpoint_auth_method,
custom_data=oauth2_scheme.custom_data,
)

path = request.url_for(LINKED_ACCOUNTS_OAUTH2_CALLBACK_ROUTE_NAME).path
Expand Down Expand Up @@ -501,14 +502,15 @@ async def linked_accounts_oauth2_callback(
access_token_url=oauth2_scheme.access_token_url,
refresh_token_url=oauth2_scheme.refresh_token_url,
token_endpoint_auth_method=oauth2_scheme.token_endpoint_auth_method,
custom_data=oauth2_scheme.custom_data,
)

token_response = await oauth2_manager.fetch_token(
redirect_uri=state.redirect_uri,
code=code,
code_verifier=state.code_verifier,
)
security_credentials = oauth2_manager.parse_fetch_token_response(token_response)
security_credentials = await oauth2_manager.parse_fetch_token_response(token_response)

# if the linked account already exists, update it, otherwise create a new one
# TODO: consider separating the logic for updating and creating a linked account or give warning to clients
Expand Down
28 changes: 26 additions & 2 deletions backend/aci/server/security_credentials_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
OAuth2SchemeCredentials,
SecuritySchemeOverrides,
)
from aci.server import config
from aci.server.oauth2_manager import OAuth2Manager

logger = get_logger(__name__)
Expand Down Expand Up @@ -96,19 +97,39 @@ async def _get_oauth2_credentials(
linked_account.security_credentials
)
if _access_token_is_expired(oauth2_scheme_credentials):
# Instagram's access token only could be refreshed with a valid access token, so we need to re-authorize if invalid
if app.name == "INSTAGRAM":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just make sure to test if re-autorize works as expected when you're doing e2e testing

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image image

This is the dev version fyi.

# Since _access_token_is_expired returned True, expires_at is guaranteed to be not None
actual_expires_at = oauth2_scheme_credentials.expires_at + 86400 # type: ignore[operator]
if int(time.time()) > actual_expires_at:
logger.error(
f"Access token expired, please re-authorize, linked_account_id={linked_account.id}, "
f"security_scheme={linked_account.security_scheme}, app={app.name}"
)
# NOTE: this error message could be used by the frontend to guide the user to re-authorize
raise OAuth2Error(
f"Access token expired. Please re-authorize at: "
f"{config.DEV_PORTAL_URL}/appconfigs/{app.name}"
)

logger.warning(
f"Access token expired, trying to refresh linked_account_id={linked_account.id}, "
f"security_scheme={linked_account.security_scheme}, app={app.name}"
)
token_response = await _refresh_oauth2_access_token(
app.name, oauth2_scheme, oauth2_scheme_credentials
)

# TODO: refactor parsing to _refresh_oauth2_access_token
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Centralize token parsing and expiry handling in OAuth2Manager

The TODO is on point. Move provider-specific expiry calculation (e.g., IG’s safety margin and potential absence/presence of expires_in) into OAuth2Manager so callers don’t duplicate logic and drift over time. Return a normalized structure with access_token, refresh_token (optional), and computed expires_at.

I can implement a parse_refresh_response(access_token_response: dict) -> dict[str, Any] in OAuth2Manager and update callers accordingly. Want me to open a follow-up PR?

🤖 Prompt for AI Agents
In backend/aci/server/security_credentials_manager.py around line 123, the
provider-specific token parsing and expiry logic should be moved into
OAuth2Manager: implement a method parse_refresh_response(access_token_response:
dict) -> dict[str, Any] that normalizes and returns { "access_token": str,
"refresh_token"?: str, "expires_at": int }, apply IG-specific safety margin and
robustly handle absent or malformed expires_in values when computing expires_at;
replace duplicated parsing in callers to call this new method so all token
parsing and expiry computation is centralized and consistent.

expires_at: int | None = None
if "expires_at" in token_response:
expires_at = int(token_response["expires_at"])
elif "expires_in" in token_response:
expires_at = int(time.time()) + int(token_response["expires_in"])
if app.name == "INSTAGRAM":
# Reduce expiration time by 1 day (86400 seconds) for safety margin
expires_at = int(time.time()) + max(0, int(token_response["expires_in"]) - 86400)
else:
expires_at = int(time.time()) + int(token_response["expires_in"])

if not token_response.get("access_token") or not expires_at:
logger.error(
Expand Down Expand Up @@ -143,6 +164,8 @@ async def _refresh_oauth2_access_token(
app_name: str, oauth2_scheme: OAuth2Scheme, oauth2_scheme_credentials: OAuth2SchemeCredentials
) -> dict:
refresh_token = oauth2_scheme_credentials.refresh_token
access_token = oauth2_scheme_credentials.access_token

if not refresh_token:
raise OAuth2Error("no refresh token found")

Expand All @@ -157,9 +180,10 @@ async def _refresh_oauth2_access_token(
access_token_url=oauth2_scheme.access_token_url,
refresh_token_url=oauth2_scheme.refresh_token_url,
token_endpoint_auth_method=oauth2_scheme.token_endpoint_auth_method,
custom_data=oauth2_scheme.custom_data,
)

return await oauth2_manager.refresh_token(refresh_token)
return await oauth2_manager.refresh_token(refresh_token, access_token)


def _get_api_key_credentials(
Expand Down
29 changes: 29 additions & 0 deletions backend/apps/instagram/app.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"name": "INSTAGRAM",
"display_name": "Instagram",
"logo": "https://raw.githubusercontent.com/aipotheosis-labs/aipolabs-icons/refs/heads/main/apps/instagram.svg",
"provider": "Meta Platforms, Inc.",
"version": "1.0.0",
"description": "The Instagram API allows developers to access and manage Instagram resources programmatically. It provides functionality for publishing content, retrieving user information, fetching post data, tracking user feeds, and managing direct messages through RESTful HTTP calls.",
"security_schemes": {
"oauth2": {
"location": "header",
"name": "Authorization",
"prefix": "Bearer",
"client_id": "{{ AIPOLABS_INSTAGRAM_APP_CLIENT_ID }}",
"client_secret": "{{ AIPOLABS_INSTAGRAM_APP_CLIENT_SECRET }}",
"scope": "instagram_business_basic instagram_business_content_publish instagram_business_manage_messages instagram_business_manage_comments instagram_business_manage_insights",
"authorize_url": "https://www.instagram.com/oauth/authorize",
"access_token_url": "https://api.instagram.com/oauth/access_token",
"refresh_token_url": "https://graph.instagram.com/refresh_access_token",
"custom_data": {
"exchange_token_url": "https://graph.instagram.com/access_token"
},
"token_endpoint_auth_method": "client_secret_post"
}
},
"default_security_credentials_by_scheme": {},
"categories": ["Social Media"],
"visibility": "public",
"active": true
}
Loading