Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/apps/src/microsoft/teams/apps/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Credentials,
JsonWebToken,
MessageActivityInput,
TokenCredentials,
)
from microsoft.teams.cards import AdaptiveCard
from microsoft.teams.common import Client, ClientOptions, ConsoleLogger, EventEmitter, LocalStorage
Expand Down Expand Up @@ -288,16 +289,22 @@ def _init_credentials(self) -> Optional[Credentials]:
client_id = self.options.client_id or os.getenv("CLIENT_ID")
client_secret = self.options.client_secret or os.getenv("CLIENT_SECRET")
tenant_id = self.options.tenant_id or os.getenv("TENANT_ID")
token = self.options.token

self.log.debug(f"Using CLIENT_ID: {client_id}")
if not tenant_id:
self.log.warning("TENANT_ID is not set, assuming multi-tenant app")
else:
self.log.debug(f"Using TENANT_ID: {tenant_id} (assuming single-tenant app)")

# - If client_id + client_secret : use ClientCredentials (standard client auth)
if client_id and client_secret:
return ClientCredentials(client_id=client_id, client_secret=client_secret, tenant_id=tenant_id)

# - If client_id + token callable : use TokenCredentials (where token is a custom token provider)
if client_id and token:
return TokenCredentials(client_id=client_id, tenant_id=tenant_id, token=token)

return None

async def _refresh_tokens(self, force: bool = False) -> None:
Expand Down
5 changes: 4 additions & 1 deletion packages/apps/src/microsoft/teams/apps/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dataclasses import dataclass, field
from logging import Logger
from typing import Any, List, Optional, TypedDict, cast
from typing import Any, Awaitable, Callable, List, Optional, TypedDict, Union, cast

from microsoft.teams.common.storage import Storage
from typing_extensions import Unpack
Expand All @@ -20,6 +20,8 @@ class AppOptions(TypedDict, total=False):
client_id: Optional[str]
client_secret: Optional[str]
tenant_id: Optional[str]
# Custom token provider function
token: Optional[Callable[[Union[str, list[str]], Optional[str]], Union[str, Awaitable[str]]]]

# Infrastructure
logger: Optional[Logger]
Expand All @@ -44,6 +46,7 @@ class InternalAppOptions:
client_id: Optional[str] = None
client_secret: Optional[str] = None
tenant_id: Optional[str] = None
token: Optional[Callable[[Union[str, list[str]], Optional[str]], Union[str, Awaitable[str]]]] = None
logger: Optional[Logger] = None
storage: Optional[Storage[str, Any]] = None

Expand Down
42 changes: 33 additions & 9 deletions packages/apps/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from microsoft.teams.api import TokenProtocol
from microsoft.teams.api.activities import InvokeActivity
from microsoft.teams.api.activities.message import MessageActivity
from microsoft.teams.api.activities.typing import TypingActivity
from microsoft.teams.api.models import Account, ConversationAccount
from microsoft.teams.apps.app import App
from microsoft.teams.apps.events import ActivityEvent
from microsoft.teams.apps.options import AppOptions
from microsoft.teams.apps.routing.activity_context import ActivityContext
from microsoft.teams.api import (
Account,
ConversationAccount,
InvokeActivity,
MessageActivity,
TokenCredentials,
TokenProtocol,
TypingActivity,
)
from microsoft.teams.apps import ActivityContext, ActivityEvent, App, AppOptions


class FakeToken(TokenProtocol):
Expand Down Expand Up @@ -493,3 +494,26 @@ async def handle_hello_pattern(ctx: ActivityContext[MessageActivity]) -> None:
# Verify non-matching activity doesn't match
non_matching_handlers = app_with_options.router.select_handlers(non_matching_activity)
assert len(non_matching_handlers) == 0

@pytest.mark.asyncio
async def test_app_with_callable_token(self):
"""Test that app initializes with callable token."""
token_called = False

def get_token(scope, tenant_id=None):
nonlocal token_called
token_called = True
return "test.jwt.token"

options = AppOptions(client_id="test-client-123", token=get_token)

app = App(**options)

assert app.credentials is not None
assert type(app.credentials) is TokenCredentials
assert app.credentials.client_id == "test-client-123"
assert callable(app.credentials.token)

res = await app.api.bots.token.get(app.credentials)
assert token_called is True
assert res.access_token == "test.jwt.token"