Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,15 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool:
content = await response.aread()
token_response = OAuthToken.model_validate_json(content)

# RFC 6749 §6: an omitted scope on refresh means the scope is unchanged from
# the prior access token. Carry it forward so the persisted token stays
# self-describing for the SEP-2350 step-up union after a restart.
if token_response.scope is None and self.context.current_tokens is not None:
token_response.scope = self.context.current_tokens.scope
# RFC 6749 §6: a refresh response may omit scope (unchanged) and refresh_token
# (the AS does not rotate). Carry both forward so the persisted token stays
# self-describing for the SEP-2350 step-up union and the next expiry can
# still refresh instead of forcing a full re-authorization.
prior = self.context.current_tokens
if token_response.scope is None and prior is not None:
token_response.scope = prior.scope
if token_response.refresh_token is None and prior is not None:
token_response.refresh_token = prior.refresh_token

self.context.current_tokens = token_response
self.context.update_token_expiry(token_response)
Expand Down Expand Up @@ -663,21 +667,25 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
await self.context.storage.set_client_info(client_information)
else:
# Fallback to Dynamic Client Registration
fallback_base = self.context.get_authorization_base_url(self.context.server_url)
registration_request = create_client_registration_request(
self.context.oauth_metadata,
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
self.context.oauth_metadata, self.context.client_metadata, fallback_base
)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
# Only record the issuer when the registration above actually targeted
# the discovered AS's registration_endpoint. With no metadata, or
# metadata that omits registration_endpoint, DCR fell back to the
# resource-server origin's /register — recording that as bound to a
# PRM-advertised AS would persist a binding that was never established.
# the discovered AS — either via its published registration_endpoint,
# or because the resource-origin /register fallback is on the issuer's
# own host (legacy same-origin embedded AS). Otherwise the fallback hit
# a different server and recording a binding to the PRM-advertised AS
# would persist a binding that was never established.
if (
self.context.oauth_metadata is not None
and self.context.oauth_metadata.registration_endpoint is not None
and discovered_issuer is not None
and (
self.context.oauth_metadata.registration_endpoint is not None
or self.context.get_authorization_base_url(discovered_issuer) == fallback_base
)
):
client_information.issuer = discovered_issuer
self.context.client_info = client_information
Expand Down
114 changes: 111 additions & 3 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2861,11 +2861,18 @@ async def test_handle_token_response_backfills_omitted_scope_from_request(


@pytest.mark.anyio
async def test_handle_refresh_response_carries_prior_scope_when_response_omits_it(
async def test_handle_refresh_response_carries_prior_scope_and_refresh_token_when_omitted(
oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
):
"""RFC 6749 §6: an omitted refresh-response scope means scope is unchanged from the prior token."""
oauth_provider.context.current_tokens = OAuthToken(access_token="old", scope="read write")
"""RFC 6749 §6: omitted refresh-response scope and refresh_token are carried forward.

Omitted scope means it is unchanged from the prior access token. Omitted refresh_token
means the AS does not rotate refresh tokens; the client keeps using the previously
issued one so the next expiry can refresh instead of forcing a full re-authorization.
"""
oauth_provider.context.current_tokens = OAuthToken(
access_token="old", scope="read write", refresh_token="prior-refresh"
)
response = httpx.Response(
200,
json={"access_token": "new", "token_type": "Bearer", "expires_in": 3600},
Expand All @@ -2877,9 +2884,32 @@ async def test_handle_refresh_response_carries_prior_scope_when_response_omits_i
assert oauth_provider.context.current_tokens is not None
assert oauth_provider.context.current_tokens.access_token == "new"
assert oauth_provider.context.current_tokens.scope == "read write"
assert oauth_provider.context.current_tokens.refresh_token == "prior-refresh"
stored = await mock_storage.get_tokens()
assert stored is not None
assert stored.scope == "read write"
assert stored.refresh_token == "prior-refresh"


@pytest.mark.anyio
async def test_handle_refresh_response_adopts_rotated_refresh_token_when_returned(
oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
):
"""A refresh response that includes ``refresh_token`` replaces the prior one (rotation)."""
oauth_provider.context.current_tokens = OAuthToken(
access_token="old", scope="read write", refresh_token="prior-refresh"
)
response = httpx.Response(
200,
json={"access_token": "new", "token_type": "Bearer", "expires_in": 3600, "refresh_token": "rotated"},
request=httpx.Request("POST", "https://auth.example.com/token"),
)
ok = await oauth_provider._handle_refresh_response(response)

assert ok is True
stored = await mock_storage.get_tokens()
assert stored is not None
assert stored.refresh_token == "rotated"


@pytest.mark.anyio
Expand Down Expand Up @@ -3050,3 +3080,81 @@ async def echo_callback() -> AuthorizationCodeResult:
await auth_flow.asend(httpx.Response(200, request=final_req))
except StopAsyncIteration:
pass


@pytest.mark.anyio
async def test_issuer_is_stamped_when_same_origin_fallback_register_is_on_the_discovered_issuer(
oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
):
"""SEP-2352: a fallback registration on the discovered issuer's own host is still bound.

Legacy same-origin embedded AS: PRM is absent, root ASM discovery succeeds with
``issuer`` equal to the resource origin and no ``registration_endpoint``. DCR falls
back to ``<resource-origin>/register`` — the issuer's own host — so the binding was
established and is recorded, preserving auto-recovery on a later AS migration.
"""
oauth_provider.context.current_tokens = None
oauth_provider.context.token_expiry_time = None
oauth_provider._initialized = True
oauth_provider.context.client_info = None

captured_state: str | None = None

async def capture_redirect(url: str) -> None:
nonlocal captured_state
captured_state = parse_qs(urlparse(url).query).get("state", [None])[0]

async def echo_callback() -> AuthorizationCodeResult:
return AuthorizationCodeResult(code="auth_code", state=captured_state)

oauth_provider.context.redirect_handler = capture_redirect
oauth_provider.context.callback_handler = echo_callback

auth_flow = oauth_provider.async_auth_flow(httpx.Request("GET", "https://api.example.com/v1/mcp"))
request = await auth_flow.__anext__()

# PRM discovery 404s on both well-known URLs.
prm_req = await auth_flow.asend(httpx.Response(401, request=request))
assert str(prm_req.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
prm_req = await auth_flow.asend(httpx.Response(404, request=prm_req))
assert str(prm_req.url) == "https://api.example.com/.well-known/oauth-protected-resource"

# Root ASM discovery succeeds with the resource origin as issuer and no registration_endpoint.
asm_req = await auth_flow.asend(httpx.Response(404, request=prm_req))
assert str(asm_req.url) == "https://api.example.com/.well-known/oauth-authorization-server"
asm_response = httpx.Response(
200,
content=(
b'{"issuer": "https://api.example.com", '
b'"authorization_endpoint": "https://api.example.com/authorize", '
b'"token_endpoint": "https://api.example.com/token"}'
),
request=asm_req,
)

# DCR falls back to the resource origin's /register — the issuer's own host.
dcr_req = await auth_flow.asend(asm_response)
assert dcr_req.method == "POST"
assert str(dcr_req.url) == "https://api.example.com/register"
dcr_response = httpx.Response(
201,
json={"client_id": "embedded-client", "redirect_uris": ["http://localhost:3030/callback"]},
request=dcr_req,
)
token_req = await auth_flow.asend(dcr_response)

stored = await mock_storage.get_client_info()
assert stored is not None
assert oauth_provider.context.oauth_metadata is not None
assert stored.client_id == "embedded-client"
assert stored.issuer == str(oauth_provider.context.oauth_metadata.issuer)
assert urlparse(stored.issuer).netloc == "api.example.com"

token_response = httpx.Response(
200, json={"access_token": "t", "token_type": "Bearer", "expires_in": 3600}, request=token_req
)
final_req = await auth_flow.asend(token_response)
try:
await auth_flow.asend(httpx.Response(200, request=final_req))
except StopAsyncIteration:
pass
Loading