Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions docs/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ ctx: ClientRequestContext
server_ctx: ServerRequestContext[LifespanContextT, RequestT]
```

`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus a new `protocol_version: str` field, so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`.
`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus three new ones: `protocol_version: str`, `method: str`, and the raw `params: Mapping[str, Any] | None` (the last two let middleware read and rewrite the inbound message). Handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`.

The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient:

Expand Down Expand Up @@ -961,27 +961,24 @@ server.add_notification_handler("notifications/custom", MyNotifyParams, my_notif
These were private, but some users subclassed `Server` and overrode them to intercept requests. Use middleware instead:

```python
from collections.abc import Mapping
from typing import Any

from mcp.server import Server, ServerRequestContext
from mcp.server.context import CallNext, HandlerResult


async def logging_middleware(
ctx: ServerRequestContext[Any, Any], method: str, params: Mapping[str, Any] | None, call_next: CallNext
) -> HandlerResult:
print(f"handling {method}")
result = await call_next()
print(f"done {method}")
async def logging_middleware(ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult:
print(f"handling {ctx.method}")
result = await call_next(ctx)
print(f"done {ctx.method}")
return result


server = Server("my-server", on_call_tool=...)
server.middleware.append(logging_middleware)
```

Middleware runs before params validation, so `params` is the raw inbound mapping (or `None`), and it also wraps unknown methods.
The method and the raw inbound params are available as `ctx.method` and `ctx.params` (`params` is `None` when the message carries none). Middleware runs before params validation and also wraps unknown methods. To rewrite the method or params before the handler runs, build an adjusted context with `dataclasses.replace` and pass it through: `await call_next(replace(ctx, params=...))`.

### Lowlevel `Server.run(raise_exceptions=True)`: transport errors no longer re-raised

Expand Down
51 changes: 51 additions & 0 deletions src/mcp/server/_otel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Any

from opentelemetry.trace import SpanKind, StatusCode
from pydantic import ValidationError

from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext
from mcp.shared._otel import extract_trace_context, otel_span
from mcp.shared.exceptions import MCPError


class OpenTelemetryMiddleware(ServerMiddleware[Any]):
    """Context-tier middleware that opens one OpenTelemetry server span per inbound message.

    The span is named `"MCP handle <method>"`, with `" <target>"` appended when
    `ctx.params["name"]` is a non-empty string. It always carries the
    `mcp.method.name` attribute; `jsonrpc.request.id` is added only when
    `ctx.request_id` is not `None`. The parent trace context is whatever
    `extract_trace_context(ctx.meta)` returns. Any exception raised by the
    rest of the chain marks the span as ERROR and is re-raised unchanged.
    """

    async def __call__(self, ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult:
        """Run `call_next(ctx)` inside a span describing this message and return its result."""
        # Only a non-empty string `name` param is used as the span target.
        raw_name = ctx.params.get("name") if ctx.params else None

        attributes: dict[str, Any] = {"mcp.method.name": ctx.method}
        if ctx.request_id is not None:
            attributes["jsonrpc.request.id"] = str(ctx.request_id)

        span_name = f"MCP handle {ctx.method}"
        if isinstance(raw_name, str) and raw_name:
            span_name = f"{span_name} {raw_name}"

        # Automatic exception recording/status is switched off so each branch
        # below decides exactly what ends up on the span.
        with otel_span(
            name=span_name,
            kind=SpanKind.SERVER,
            attributes=attributes,
            context=extract_trace_context(ctx.meta),
            record_exception=False,
            set_status_on_exception=False,
        ) as span:
            try:
                return await call_next(ctx)
            except MCPError as exc:
                span.set_status(StatusCode.ERROR, exc.error.message)
                raise
            except ValidationError:
                # Mirror the sanitized wire response; pydantic messages carry client input.
                span.set_status(StatusCode.ERROR, "Invalid request parameters")
                raise
            except Exception as exc:
                span.record_exception(exc)
                span.set_status(StatusCode.ERROR, str(exc))
                raise
37 changes: 21 additions & 16 deletions src/mcp/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mcp.shared.transport_context import TransportContext
from mcp.types import LoggingLevel, RequestId, RequestParamsMeta

# Invariant: parameterizes a mutable dataclass field; dict default matches the default lifespan.
# Invariant: parametrizes a mutable dataclass field; dict default matches the default lifespan.
LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any])
RequestT = TypeVar("RequestT", default=Any)

Expand All @@ -33,6 +33,8 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]):
session: ServerSession
lifespan_context: LifespanContextT
protocol_version: str
method: str
params: Mapping[str, Any] | None = None
request_id: RequestId | None = None
meta: RequestParamsMeta | None = None
request: RequestT | None = None
Expand Down Expand Up @@ -113,39 +115,44 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
"""What a request handler (or middleware) may return. `ServerRunner` serializes
all three to a result dict."""

CallNext = Callable[[], Awaitable[HandlerResult]]
CallNext = Callable[["ServerRequestContext[Any, Any]"], Awaitable[HandlerResult]]
"""Invokes the rest of the chain. Pass the `ctx` through; rewrite `method` or
`params` with `dataclasses.replace(ctx, ...)` to alter what the handler sees."""

_MwLifespanT = TypeVar("_MwLifespanT")


class ServerMiddleware(Protocol[_MwLifespanT]):
"""Context-tier middleware: `(ctx, method, params, call_next) -> result`.
"""Context-tier middleware: `(ctx, call_next) -> result`.

Runs at the top of `ServerRunner._on_request` / `_on_notify` after `ctx`
is built but before any validation, lookup, or handshake. Wraps every
inbound request and notification: `initialize`, the pre-init gate,
`METHOD_NOT_FOUND`, params validation, the handler call, and
`notifications/initialized` all run inside `call_next()`.
`notifications/initialized` all run inside `call_next(ctx)`.
`notifications/cancelled` is observed too; the dispatcher applies the
cancellation itself, then forwards the notification. A request-side
failure reaches the middleware as a raised `MCPError` (or
`ValidationError` for malformed params) so observation/logging middleware
can record it. Listed outermost-first on `Server.middleware`.

The method and the raw inbound params are `ctx.method` and `ctx.params` (no
model validation has happened yet). To rewrite either before the handler
runs, pass an adjusted context: `await call_next(replace(ctx, params=...))`.
`ctx.request_id is None` distinguishes a notification from a request. For
notifications `call_next()` returns `None` (a dropped or unhandled
notifications `call_next(ctx)` returns `None` (a dropped or unhandled
notification also returns `None`) and the middleware's own return value is
discarded.

`params` is the raw inbound mapping (no model validation has happened
yet). For typed inspection, validate against the model the middleware
expects.

Warning: `initialize` is handled inline - the dispatcher does not read
further inbound messages until the middleware chain returns. Awaiting a
server-to-client request (`ctx.session.send_request`, `send_ping`, ...)
while handling `initialize` therefore deadlocks the connection: the
response can never be dequeued. Send-and-forget notifications are safe.
!!! warning
`initialize` is handled inline - the dispatcher does not read
further inbound messages until the middleware chain returns. Awaiting a
server-to-client request (`ctx.session.send_request`, `send_ping`, ...)
while handling `initialize` therefore deadlocks the connection: the
response can never be dequeued. Send-and-forget notifications are safe.
`initialize` is observed but not rewritable: the post-chain handshake
commit reads the wire params, so to veto the handshake raise *before*
calling `call_next(ctx)`.

`Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific
middleware sees `ctx.lifespan_context: L`. While the context is the
Expand All @@ -162,7 +169,5 @@ class ServerMiddleware(Protocol[_MwLifespanT]):
async def __call__(
self,
ctx: ServerRequestContext[_MwLifespanT, Any],
method: str,
params: Mapping[str, Any] | None,
call_next: CallNext,
) -> HandlerResult: ...
2 changes: 1 addition & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __init__(
self._session_manager: StreamableHTTPSessionManager | None = None
# Context-tier middleware: wraps every inbound request (including
# `initialize`, lookup, validation, handler) with
# `(ctx, method, params, call_next)`. Applied in `ServerRunner._on_request`.
# `(ctx, call_next)`. Applied in `ServerRunner._on_request`.
# TODO(L54): provisional - signature and semantics change with the
# Context/middleware rework (covariant `Context[L]`, outbound seam) before
# v2 final.
Expand Down
67 changes: 44 additions & 23 deletions src/mcp/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None:
return None


def otel_middleware(next_on_request: OnRequest) -> OnRequest:
def otel_middleware(call_next: OnRequest) -> OnRequest:
"""Dispatch-tier middleware that wraps each request in an OpenTelemetry span.

Mirrors the span shape of the existing `Server._handle_request`: span name
Expand Down Expand Up @@ -129,7 +129,7 @@ async def wrapped(
set_status_on_exception=False,
) as span:
try:
return await next_on_request(dctx, method, params)
return await call_next(dctx, method, params)
except MCPError as e:
span.set_status(StatusCode.ERROR, e.error.message)
raise
Expand Down Expand Up @@ -200,6 +200,14 @@ async def to_jsonrpc_response(request_id: RequestId, coro: Awaitable[dict[str, A
return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)


def _apply_middleware(
middleware: ServerMiddleware[Any], call_next: CallNext, ctx: ServerRequestContext[Any, Any]
) -> Awaitable[HandlerResult]:
"""Adapt one middleware to the `CallNext` shape: bind `call_next`, take
`ctx` at call time so a rewritten context flows down the chain."""
return middleware(ctx, call_next)


@dataclass
class ServerRunner(Generic[LifespanT]):
"""Per-connection handler kernel. One instance per client connection."""
Expand All @@ -220,7 +228,9 @@ def on_request(self) -> OnRequest:
wraps everything - initialize, METHOD_NOT_FOUND, validation failures
included.
"""
return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request)
return reduce(
lambda handler, middleware: middleware(handler), reversed(self.dispatch_middleware), self._on_request
)

@cached_property
def on_notify(self) -> OnNotify:
Expand All @@ -234,15 +244,18 @@ async def _on_request(
) -> dict[str, Any]:
meta = _extract_meta(params)
version = self.connection.protocol_version
ctx = self._make_context(dctx, meta, version)
ctx = self._make_context(dctx, method, params, meta, version)
is_spec_method = method in _methods.SPEC_CLIENT_METHODS

async def _inner() -> HandlerResult:
async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> HandlerResult:
# Read method/params off `ctx` so a middleware that rewrote them via
# `call_next(replace(ctx, ...))` reaches lookup and the handler.
method, params = ctx.method, ctx.params
# Pinned compat: spec methods are surface-validated before lookup,
# so malformed params are INVALID_PARAMS even with no handler
# registered. Custom methods miss the monolith map and fall through
# to `entry.params_type` exactly as before.
if is_spec_method:
if method in _methods.SPEC_CLIENT_METHODS:
try:
_methods.validate_client_request(method, version, params)
except KeyError:
Expand Down Expand Up @@ -272,8 +285,8 @@ async def _inner() -> HandlerResult:
raise MCPError.from_error_data(result)
return result

call = self._compose_server_middleware(ctx, method, params, _inner)
result = _dump_result(await call())
call = self._compose_server_middleware(_inner)
result = _dump_result(await call(ctx))
# TODO(L56): reject resultType values outside {"complete", "input_required"} unless the
# corresponding extension is in this request's _meta clientCapabilities.extensions; the
# explicit MUST-reject is client-side (basic/index.mdx ResultType), this enforces it proactively.
Comment thread
claude[bot] marked this conversation as resolved.
Expand All @@ -292,6 +305,11 @@ async def _inner() -> HandlerResult:
if method == "initialize":
# Commit only on chain success, so a middleware veto leaves no state.
# Race-free: the read loop is parked until this call returns.
# TODO: this re-reads the wire `params`, so a middleware that rewrote
# `ctx.params` (or `ctx.method`, or short-circuited without `call_next`)
# can leave `connection.protocol_version` out of step with the
# `InitializeResult` `_inner` produced. Resolve when `initialize` becomes
# a built-in handler so commit and result derive from one negotiation.
self.connection.client_params, self.connection.protocol_version = self._negotiate_initialize(params)
return result

Expand All @@ -303,9 +321,10 @@ async def _on_notify(
) -> None:
meta = _extract_meta(params)
version = self.connection.protocol_version
ctx = self._make_context(dctx, meta, version)
ctx = self._make_context(dctx, method, params, meta, version)

async def _inner() -> None:
async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> None:
method, params = ctx.method, ctx.params
if method in _methods.SPEC_CLIENT_NOTIFICATION_METHODS:
try:
_methods.validate_client_notification(method, version, params)
Expand Down Expand Up @@ -335,33 +354,33 @@ async def _inner() -> None:
return
await entry.handler(ctx, typed_params)

call = self._compose_server_middleware(ctx, method, params, _inner)
call = self._compose_server_middleware(_inner)
try:
await call()
await call(ctx)
except Exception:
# A crashing handler must not cancel the dispatcher's task group;
# middleware saw the raise out of call_next(ctx) first.
logger.exception("notification handler for %r raised", method)

def _compose_server_middleware(
self,
ctx: ServerRequestContext[LifespanT, Any],
method: str,
params: Mapping[str, Any] | None,
inner: CallNext,
) -> CallNext:
def _compose_server_middleware(self, inner: CallNext) -> CallNext:
"""Wrap `inner` in `Server.middleware`, outermost-first.

Shared by `_on_request` and `_on_notify` so the same middleware chain
observes every inbound message.
observes every inbound message. The composed callable takes the `ctx`
at call time, so a middleware can rewrite it for the rest of the chain.
"""
call = inner
for mw in reversed(self.server.middleware):
call = partial(mw, ctx, method, params, call)
for middleware in reversed(self.server.middleware):
call = partial(_apply_middleware, middleware, call)
return call

def _make_context(
self, dctx: DispatchContext[TransportContext], meta: RequestParamsMeta | None, protocol_version: str
self,
dctx: DispatchContext[TransportContext],
method: str,
params: Mapping[str, Any] | None,
meta: RequestParamsMeta | None,
protocol_version: str,
) -> ServerRequestContext[LifespanT, Any]:
# TODO(L54): remove for Context rework. Reads the SHTTP per-request
# data off the raw `dctx.message_metadata` carrier; replace with the
Expand All @@ -380,6 +399,8 @@ def _make_context(
return ServerRequestContext(
session=session,
lifespan_context=self.lifespan_state,
method=method,
params=params,
request_id=dctx.request_id,
meta=meta,
protocol_version=protocol_version,
Expand Down
24 changes: 16 additions & 8 deletions src/mcp/shared/_otel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from __future__ import annotations

from collections.abc import Iterator
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from typing import Any

from opentelemetry.context import Context
from opentelemetry.propagate import extract, inject
from opentelemetry.trace import SpanKind, get_tracer
from opentelemetry.trace import SpanKind, get_current_span, get_tracer
from opentelemetry.trace.span import Span

_tracer = get_tracer("mcp-python-sdk")

Expand All @@ -22,7 +23,7 @@ def otel_span(
context: Context | None = None,
record_exception: bool = True,
set_status_on_exception: bool = True,
) -> Iterator[Any]:
) -> Generator[Span]:
"""Create an OTel span."""
with _tracer.start_as_current_span(
name,
Expand All @@ -40,13 +41,20 @@ def inject_trace_context(meta: dict[str, Any]) -> None:
inject(meta)


def extract_trace_context(meta: dict[str, Any]) -> Context | None:
def extract_trace_context(meta: Mapping[str, Any] | None) -> Context | None:
"""Extract W3C trace context from a `_meta` dict.

Returns `None` when the carrier is malformed; telemetry parsing must
never fail the request it annotates.
Returns `None` when the carrier is absent, malformed, or carries no
valid `traceparent`, so callers fall through to ambient parenting; an
explicit empty `Context` would orphan the span instead of nesting under
the current one.
"""
if not meta:
return None
try:
return extract(meta)
except (TypeError, ValueError):
ctx = extract(meta)
except (ValueError, TypeError):
return None
Comment thread
claude[bot] marked this conversation as resolved.
if not get_current_span(ctx).get_span_context().is_valid:
return None
return ctx
Loading
Loading