diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 00a0b88b4..39858cba4 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -477,11 +477,15 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: content = await response.aread() token_response = OAuthToken.model_validate_json(content) - # RFC 6749 §6: an omitted scope on refresh means the scope is unchanged from - # the prior access token. Carry it forward so the persisted token stays - # self-describing for the SEP-2350 step-up union after a restart. - if token_response.scope is None and self.context.current_tokens is not None: - token_response.scope = self.context.current_tokens.scope + # RFC 6749 §6: a refresh response may omit scope (unchanged) and refresh_token + # (the AS does not rotate). Carry both forward so the persisted token stays + # self-describing for the SEP-2350 step-up union and the next expiry can + # still refresh instead of forcing a full re-authorization. + prior = self.context.current_tokens + if token_response.scope is None and prior is not None: + token_response.scope = prior.scope + if token_response.refresh_token is None and prior is not None: + token_response.refresh_token = prior.refresh_token self.context.current_tokens = token_response self.context.update_token_expiry(token_response) @@ -663,21 +667,25 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. await self.context.storage.set_client_info(client_information) else: # Fallback to Dynamic Client Registration + fallback_base = self.context.get_authorization_base_url(self.context.server_url) registration_request = create_client_registration_request( - self.context.oauth_metadata, - self.context.client_metadata, - self.context.get_authorization_base_url(self.context.server_url), + self.context.oauth_metadata, self.context.client_metadata, fallback_base ) registration_response = yield registration_request client_information = await handle_registration_response(registration_response) # Only record the issuer when the registration above actually targeted - # the discovered AS's registration_endpoint. With no metadata, or - # metadata that omits registration_endpoint, DCR fell back to the - # resource-server origin's /register — recording that as bound to a - # PRM-advertised AS would persist a binding that was never established. + # the discovered AS — either via its published registration_endpoint, + # or because the resource-origin /register fallback is on the issuer's + # own host (legacy same-origin embedded AS). Otherwise the fallback hit + # a different server and recording a binding to the PRM-advertised AS + # would persist a binding that was never established. if ( self.context.oauth_metadata is not None - and self.context.oauth_metadata.registration_endpoint is not None + and discovered_issuer is not None + and ( + self.context.oauth_metadata.registration_endpoint is not None + or self.context.get_authorization_base_url(discovered_issuer) == fallback_base + ) ): client_information.issuer = discovered_issuer self.context.client_info = client_information diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 06f7b8076..cdbba1b58 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2861,11 +2861,18 @@ async def test_handle_token_response_backfills_omitted_scope_from_request( @pytest.mark.anyio -async def test_handle_refresh_response_carries_prior_scope_when_response_omits_it( +async def test_handle_refresh_response_carries_prior_scope_and_refresh_token_when_omitted( oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage ): - """RFC 6749 §6: an omitted refresh-response scope means scope is unchanged from the prior token.""" - oauth_provider.context.current_tokens = OAuthToken(access_token="old", scope="read write") + """RFC 6749 §6: omitted refresh-response scope and refresh_token are carried forward. + + Omitted scope means it is unchanged from the prior access token. Omitted refresh_token + means the AS does not rotate refresh tokens; the client keeps using the previously + issued one so the next expiry can refresh instead of forcing a full re-authorization. + """ + oauth_provider.context.current_tokens = OAuthToken( + access_token="old", scope="read write", refresh_token="prior-refresh" + ) response = httpx.Response( 200, json={"access_token": "new", "token_type": "Bearer", "expires_in": 3600}, @@ -2877,9 +2884,32 @@ async def test_handle_refresh_response_carries_prior_scope_when_response_omits_i assert oauth_provider.context.current_tokens is not None assert oauth_provider.context.current_tokens.access_token == "new" assert oauth_provider.context.current_tokens.scope == "read write" + assert oauth_provider.context.current_tokens.refresh_token == "prior-refresh" stored = await mock_storage.get_tokens() assert stored is not None assert stored.scope == "read write" + assert stored.refresh_token == "prior-refresh" + + +@pytest.mark.anyio +async def test_handle_refresh_response_adopts_rotated_refresh_token_when_returned( + oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage +): + """A refresh response that includes ``refresh_token`` replaces the prior one (rotation).""" + oauth_provider.context.current_tokens = OAuthToken( + access_token="old", scope="read write", refresh_token="prior-refresh" + ) + response = httpx.Response( + 200, + json={"access_token": "new", "token_type": "Bearer", "expires_in": 3600, "refresh_token": "rotated"}, + request=httpx.Request("POST", "https://auth.example.com/token"), + ) + ok = await oauth_provider._handle_refresh_response(response) + + assert ok is True + stored = await mock_storage.get_tokens() + assert stored is not None + assert stored.refresh_token == "rotated" @pytest.mark.anyio @@ -3050,3 +3080,81 @@ async def echo_callback() -> AuthorizationCodeResult: await auth_flow.asend(httpx.Response(200, request=final_req)) except StopAsyncIteration: pass + + +@pytest.mark.anyio +async def test_issuer_is_stamped_when_same_origin_fallback_register_is_on_the_discovered_issuer( + oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage +): + """SEP-2352: a fallback registration on the discovered issuer's own host is still bound. + + Legacy same-origin embedded AS: PRM is absent, root ASM discovery succeeds with + ``issuer`` equal to the resource origin and no ``registration_endpoint``. DCR falls + back to ``/register`` — the issuer's own host — so the binding was + established and is recorded, preserving auto-recovery on a later AS migration. + """ + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True + oauth_provider.context.client_info = None + + captured_state: str | None = None + + async def capture_redirect(url: str) -> None: + nonlocal captured_state + captured_state = parse_qs(urlparse(url).query).get("state", [None])[0] + + async def echo_callback() -> AuthorizationCodeResult: + return AuthorizationCodeResult(code="auth_code", state=captured_state) + + oauth_provider.context.redirect_handler = capture_redirect + oauth_provider.context.callback_handler = echo_callback + + auth_flow = oauth_provider.async_auth_flow(httpx.Request("GET", "https://api.example.com/v1/mcp")) + request = await auth_flow.__anext__() + + # PRM discovery 404s on both well-known URLs. + prm_req = await auth_flow.asend(httpx.Response(401, request=request)) + assert str(prm_req.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + prm_req = await auth_flow.asend(httpx.Response(404, request=prm_req)) + assert str(prm_req.url) == "https://api.example.com/.well-known/oauth-protected-resource" + + # Root ASM discovery succeeds with the resource origin as issuer and no registration_endpoint. + asm_req = await auth_flow.asend(httpx.Response(404, request=prm_req)) + assert str(asm_req.url) == "https://api.example.com/.well-known/oauth-authorization-server" + asm_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://api.example.com", ' + b'"authorization_endpoint": "https://api.example.com/authorize", ' + b'"token_endpoint": "https://api.example.com/token"}' + ), + request=asm_req, + ) + + # DCR falls back to the resource origin's /register — the issuer's own host. + dcr_req = await auth_flow.asend(asm_response) + assert dcr_req.method == "POST" + assert str(dcr_req.url) == "https://api.example.com/register" + dcr_response = httpx.Response( + 201, + json={"client_id": "embedded-client", "redirect_uris": ["http://localhost:3030/callback"]}, + request=dcr_req, + ) + token_req = await auth_flow.asend(dcr_response) + + stored = await mock_storage.get_client_info() + assert stored is not None + assert oauth_provider.context.oauth_metadata is not None + assert stored.client_id == "embedded-client" + assert stored.issuer == str(oauth_provider.context.oauth_metadata.issuer) + assert urlparse(stored.issuer).netloc == "api.example.com" + + token_response = httpx.Response( + 200, json={"access_token": "t", "token_type": "Bearer", "expires_in": 3600}, request=token_req + ) + final_req = await auth_flow.asend(token_response) + try: + await auth_flow.asend(httpx.Response(200, request=final_req)) + except StopAsyncIteration: + pass