diff --git a/README.md b/README.md index 34133a796..4873876a6 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ For comprehensive guides and SDK API documentation - [Features](https://modelcontextprotocol.github.io/java-sdk/#features) - Overview the features provided by the Java MCP SDK - [Architecture](https://modelcontextprotocol.github.io/java-sdk/#architecture) - Java MCP SDK architecture overview. -- [Java Dependencies / BOM](https://modelcontextprotocol.github.io/java-sdk/quickstart/#dependencies) - Java dependencies and BOM. -- [Java MCP Client](https://modelcontextprotocol.github.io/java-sdk/client/) - Learn how to use the MCP client to interact with MCP servers. -- [Java MCP Server](https://modelcontextprotocol.github.io/java-sdk/server/) - Learn how to implement and configure a MCP servers. +- [Java Dependencies / BOM](https://java.sdk.modelcontextprotocol.io/latest/quickstart/#dependencies) - Java dependencies and BOM. +- [Java MCP Client](https://java.sdk.modelcontextprotocol.io/latest/client/) - Learn how to use the MCP client to interact with MCP servers. +- [Java MCP Server](https://java.sdk.modelcontextprotocol.io/latest/server/) - Learn how to implement and configure a MCP servers. #### Spring AI MCP documentation [Spring AI MCP](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-overview.html) extends the MCP Java SDK with Spring Boot integration, providing both [client](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-client-boot-starter-docs.html) and [server](https://docs.spring.io/spring-ai/reference/2.0-SNAPSHOT/api/mcp/mcp-server-boot-starter-docs.html) starters. @@ -40,6 +40,41 @@ To run the tests you have to pre-install `Docker` and `npx`. ```bash ./mvnw test ``` +### Conformance Tests + +The SDK is validated against the [MCP conformance test suite](https://github.com/modelcontextprotocol/conformance) at 0.1.15 version. +Full details and instructions are in [`conformance-tests/VALIDATION_RESULTS.md`](conformance-tests/VALIDATION_RESULTS.md). + +**Latest results:** + +| Suite | Result | +|---------------|-----------------------------------------------------| +| Server | ✅ 40/40 passed (100%) | +| Client | 🟡 3/4 scenarios, 9/10 checks passed | +| Auth (Spring) | 🟡 12/14 scenarios fully passing (98.9% checks) | + +To run the conformance tests locally you need `npx` installed. + +```bash +# Server conformance +./mvnw compile -pl conformance-tests/server-servlet -am exec:java +npx @modelcontextprotocol/conformance server --url http://localhost:8080/mcp --suite active + +# Client conformance +./mvnw clean package -DskipTests -pl conformance-tests/client-jdk-http-client -am +for scenario in initialize tools_call elicitation-sep1034-client-defaults sse-retry; do + npx @modelcontextprotocol/conformance client \ + --command "java -jar conformance-tests/client-jdk-http-client/target/client-jdk-http-client-2.0.0-SNAPSHOT.jar" \ + --scenario $scenario +done + +# Auth conformance (Spring HTTP Client) +./mvnw clean package -DskipTests -pl conformance-tests/client-spring-http-client -am +npx @modelcontextprotocol/conformance@0.1.15 client \ + --spec-version 2025-11-25 \ + --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-2.0.0-SNAPSHOT.jar" \ + --suite auth +``` ## Contributing diff --git a/conformance-tests/VALIDATION_RESULTS.md b/conformance-tests/VALIDATION_RESULTS.md index 8edc7ad71..e4ce396bc 100644 --- a/conformance-tests/VALIDATION_RESULTS.md +++ b/conformance-tests/VALIDATION_RESULTS.md @@ -4,7 +4,7 @@ **Server Tests:** 40/40 passed (100%) **Client Tests:** 3/4 scenarios passed (9/10 checks passed) -**Auth Tests:** 12/14 scenarios fully passing (178 passed, 1 failed, 1 warning, 85.7% scenarios, 98.9% checks) +**Auth Tests:** 14/15 scenarios fully passing (196 passed, 0 failed, 1 warning, 93.3% scenarios, 99.5% checks) ## Server Test Results @@ -37,35 +37,35 @@ ## Auth Test Results (Spring HTTP Client) -**Status: 178 passed, 1 failed, 1 warning across 14 scenarios** +**Status: 196 passed, 0 failed, 1 warning across 15 scenarios** Uses the `client-spring-http-client` module with Spring Security OAuth2 and the [mcp-client-security](https://github.com/springaicommunity/mcp-client-security) library. -### Fully Passing (12/14 scenarios) +### Fully Passing (14/15 scenarios) -- **auth/metadata-default (12/12):** Default metadata discovery -- **auth/metadata-var1 (12/12):** Metadata discovery variant 1 -- **auth/metadata-var2 (12/12):** Metadata discovery variant 2 -- **auth/metadata-var3 (12/12):** Metadata discovery variant 3 -- **auth/scope-from-www-authenticate (13/13):** Scope extraction from WWW-Authenticate header -- **auth/scope-from-scopes-supported (13/13):** Scope extraction from scopes_supported -- **auth/scope-omitted-when-undefined (13/13):** Scope omitted when not defined +- **auth/metadata-default (13/13):** Default metadata discovery +- **auth/metadata-var1 (13/13):** Metadata discovery variant 1 +- **auth/metadata-var2 (13/13):** Metadata discovery variant 2 +- **auth/metadata-var3 (13/13):** Metadata discovery variant 3 +- **auth/scope-from-www-authenticate (14/14):** Scope extraction from WWW-Authenticate header +- **auth/scope-from-scopes-supported (14/14):** Scope extraction from scopes_supported +- **auth/scope-omitted-when-undefined (14/14):** Scope omitted when not defined +- **auth/scope-step-up (16/16):** Scope step-up challenge - **auth/scope-retry-limit (11/11):** Scope retry limit handling -- **auth/token-endpoint-auth-basic (17/17):** Token endpoint with HTTP Basic auth -- **auth/token-endpoint-auth-post (17/17):** Token endpoint with POST body auth -- **auth/token-endpoint-auth-none (17/17):** Token endpoint with no client auth +- **auth/token-endpoint-auth-basic (18/18):** Token endpoint with HTTP Basic auth +- **auth/token-endpoint-auth-post (18/18):** Token endpoint with POST body auth +- **auth/token-endpoint-auth-none (18/18):** Token endpoint with no client auth +- **auth/resource-mismatch (2/2):** Resource mismatch handling - **auth/pre-registration (6/6):** Pre-registered client credentials flow -### Partially Passing (2/14 scenarios) +### Partially Passing (1/15 scenarios) -- **auth/basic-cimd (12/12 + 1 warning):** Basic Client-Initiated Metadata Discovery — all checks pass, minor warning -- **auth/scope-step-up (11/12):** Scope step-up challenge — 1 failure, client does not fully handle scope escalation after initial authorization +- **auth/basic-cimd (13/13 + 1 warning):** Basic Client-Initiated Metadata Discovery — all checks pass, minor warning ## Known Limitations 1. **Client SSE Retry:** Client doesn't parse or respect the `retry:` field, reconnects immediately, and doesn't send Last-Event-ID header -2. **Auth Scope Step-Up:** Client does not fully handle scope step-up challenges where the server requests additional scopes after initial authorization -3. **Auth Basic CIMD:** Minor conformance warning in the basic Client-Initiated Metadata Discovery flow +2. **Auth Basic CIMD:** Minor conformance warning in the basic Client-Initiated Metadata Discovery flow ## Running Tests @@ -113,4 +113,3 @@ npx @modelcontextprotocol/conformance@0.1.15 client \ ### High Priority 1. Fix client SSE retry field handling in `HttpClientStreamableHttpTransport` 2. Implement CIMD -3. Implement scope step up diff --git a/conformance-tests/client-jdk-http-client/pom.xml b/conformance-tests/client-jdk-http-client/pom.xml index f30361438..f939cfa6c 100644 --- a/conformance-tests/client-jdk-http-client/pom.xml +++ b/conformance-tests/client-jdk-http-client/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk conformance-tests - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT client-jdk-http-client jar @@ -16,19 +16,19 @@ https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git true - + io.modelcontextprotocol.sdk mcp - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT @@ -57,7 +57,8 @@ - io.modelcontextprotocol.conformance.client.ConformanceJdkClientMcpClient + + io.modelcontextprotocol.conformance.client.ConformanceJdkClientMcpClient @@ -79,4 +80,4 @@ - + \ No newline at end of file diff --git a/conformance-tests/client-spring-http-client/README.md b/conformance-tests/client-spring-http-client/README.md index afbf64773..e5ed016c3 100644 --- a/conformance-tests/client-spring-http-client/README.md +++ b/conformance-tests/client-spring-http-client/README.md @@ -26,7 +26,7 @@ Test with @modelcontextprotocol/conformance@0.1.15. | auth/scope-from-www-authenticate | ✅ Pass | 13/13 | | auth/scope-from-scopes-supported | ✅ Pass | 13/13 | | auth/scope-omitted-when-undefined | ✅ Pass | 13/13 | -| auth/scope-step-up | ❌ Fail | 11/12 (1 failed) | +| auth/scope-step-up | ✅ Pass | 12/12 | | auth/scope-retry-limit | ✅ Pass | 11/11 | | auth/token-endpoint-auth-basic | ✅ Pass | 17/17 | | auth/token-endpoint-auth-post | ✅ Pass | 17/17 | @@ -67,7 +67,7 @@ cd conformance-tests/client-spring-http-client This creates an executable JAR at: ``` -target/client-spring-http-client-1.1.0-SNAPSHOT.jar +target/client-spring-http-client-2.0.0-SNAPSHOT.jar ``` ## Running Tests @@ -79,7 +79,7 @@ Run the full auth suite: ```bash npx @modelcontextprotocol/conformance@0.1.15 client \ --spec-version 2025-11-25 \ - --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-1.1.0-SNAPSHOT.jar" \ + --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-2.0.0-SNAPSHOT.jar" \ --suite auth ``` @@ -88,7 +88,7 @@ Run a single scenario: ```bash npx @modelcontextprotocol/conformance@0.1.15 client \ --spec-version 2025-11-25 \ - --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-1.1.0-SNAPSHOT.jar" \ + --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-2.0.0-SNAPSHOT.jar" \ --scenario auth/metadata-default ``` @@ -97,7 +97,7 @@ Run with verbose output: ```bash npx @modelcontextprotocol/conformance@0.1.15 client \ --spec-version 2025-11-25 \ - --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-1.1.0-SNAPSHOT.jar" \ + --command "java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-2.0.0-SNAPSHOT.jar" \ --scenario auth/metadata-default \ --verbose ``` @@ -108,7 +108,7 @@ You can also run the client manually if you have a test server: ```bash export MCP_CONFORMANCE_SCENARIO=auth/metadata-default -java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-1.1.0-SNAPSHOT.jar http://localhost:3000/mcp +java -jar conformance-tests/client-spring-http-client/target/client-spring-http-client-2.0.0-SNAPSHOT.jar http://localhost:3000/mcp ``` ## Known Issues diff --git a/conformance-tests/client-spring-http-client/pom.xml b/conformance-tests/client-spring-http-client/pom.xml index 46dae68ef..44aa7f925 100644 --- a/conformance-tests/client-spring-http-client/pom.xml +++ b/conformance-tests/client-spring-http-client/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk conformance-tests - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT client-spring-http-client jar @@ -16,14 +16,15 @@ https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git 17 - 4.0.2 - 2.0.0-M2 + 4.0.5 + 2.0.0-M4 + 0.1.5 true @@ -64,7 +65,12 @@ org.springaicommunity mcp-client-security - 0.1.2 + ${spring-ai-mcp-security.version} + + + io.modelcontextprotocol.sdk + mcp-core + ${project.version} diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceSpringClientApplication.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceSpringClientApplication.java index 00582c9f2..63c3601f0 100644 --- a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceSpringClientApplication.java +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/ConformanceSpringClientApplication.java @@ -8,8 +8,11 @@ import io.modelcontextprotocol.conformance.client.scenario.Scenario; import org.springaicommunity.mcp.security.client.sync.oauth2.metadata.McpMetadataDiscoveryService; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.DefaultMcpOAuth2ClientManager; import org.springaicommunity.mcp.security.client.sync.oauth2.registration.DynamicClientRegistrationService; import org.springaicommunity.mcp.security.client.sync.oauth2.registration.InMemoryMcpClientRegistrationRepository; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpClientRegistrationRepository; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpOAuth2ClientManager; import org.springframework.boot.ApplicationArguments; import org.springframework.boot.ApplicationRunner; @@ -49,8 +52,15 @@ McpMetadataDiscoveryService discovery() { } @Bean - InMemoryMcpClientRegistrationRepository clientRegistrationRepository(McpMetadataDiscoveryService discovery) { - return new InMemoryMcpClientRegistrationRepository(new DynamicClientRegistrationService(), discovery); + McpClientRegistrationRepository clientRegistrationRepository() { + return new InMemoryMcpClientRegistrationRepository(); + } + + @Bean + McpOAuth2ClientManager mcpOAuth2ClientManager(McpClientRegistrationRepository mcpClientRegistrationRepository, + McpMetadataDiscoveryService mcpMetadataDiscoveryService) { + return new DefaultMcpOAuth2ClientManager(mcpClientRegistrationRepository, + new DynamicClientRegistrationService(), mcpMetadataDiscoveryService); } @Bean diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/McpClientController.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/McpClientController.java index e02cfd416..1b1910298 100644 --- a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/McpClientController.java +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/McpClientController.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.conformance.client; import io.modelcontextprotocol.conformance.client.scenario.Scenario; +import io.modelcontextprotocol.spec.McpSchema; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RestController; @@ -27,4 +28,15 @@ public String execute() { return "OK"; } + @GetMapping("/tools-list") + public String toolsList() { + return "OK"; + } + + @GetMapping("/tools-call") + public String toolsCall() { + this.scenario.getMcpClient().callTool(McpSchema.CallToolRequest.builder().name("test-tool").build()); + return "OK"; + } + } diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/DefaultConfiguration.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/DefaultConfiguration.java index 12a9c4a5c..febd0f461 100644 --- a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/DefaultConfiguration.java +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/configuration/DefaultConfiguration.java @@ -8,15 +8,16 @@ import io.modelcontextprotocol.conformance.client.scenario.DefaultScenario; import org.springaicommunity.mcp.security.client.sync.config.McpClientOAuth2Configurer; import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpClientRegistrationRepository; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpOAuth2ClientManager; import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression; import org.springframework.boot.web.server.servlet.context.ServletWebServerApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.web.SecurityFilterChain; -import static io.modelcontextprotocol.conformance.client.ConformanceSpringClientApplication.REGISTRATION_ID; @Configuration @ConditionalOnExpression("#{environment['MCP_CONFORMANCE_SCENARIO'] != 'auth/pre-registration'}") @@ -25,15 +26,16 @@ public class DefaultConfiguration { @Bean DefaultScenario defaultScenario(McpClientRegistrationRepository clientRegistrationRepository, ServletWebServerApplicationContext serverCtx, - OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository) { - return new DefaultScenario(clientRegistrationRepository, serverCtx, oAuth2AuthorizedClientRepository); + OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository, + McpOAuth2ClientManager mcpOAuth2ClientManager) { + return new DefaultScenario(clientRegistrationRepository, serverCtx, oAuth2AuthorizedClientRepository, + mcpOAuth2ClientManager); } @Bean SecurityFilterChain securityFilterChain(HttpSecurity http, ConformanceSpringClientApplication.ServerUrl serverUrl) { return http.authorizeHttpRequests(authz -> authz.anyRequest().permitAll()) - .with(new McpClientOAuth2Configurer(), - mcp -> mcp.registerMcpOAuth2Client(REGISTRATION_ID, serverUrl.value())) + .with(new McpClientOAuth2Configurer(), Customizer.withDefaults()) .build(); } diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/DefaultScenario.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/DefaultScenario.java index 907cea10d..b1fb78a14 100644 --- a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/DefaultScenario.java +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/DefaultScenario.java @@ -17,15 +17,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springaicommunity.mcp.security.client.sync.AuthenticationMcpTransportContextProvider; -import org.springaicommunity.mcp.security.client.sync.oauth2.http.client.OAuth2AuthorizationCodeSyncHttpRequestCustomizer; +import org.springaicommunity.mcp.security.client.sync.oauth2.http.client.OAuth2HttpClientTransportCustomizer; import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpClientRegistrationRepository; +import org.springaicommunity.mcp.security.client.sync.oauth2.registration.McpOAuth2ClientManager; import org.springframework.boot.web.server.servlet.context.ServletWebServerApplicationContext; import org.springframework.http.client.JdkClientHttpRequestFactory; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.web.client.RestClient; -import static io.modelcontextprotocol.conformance.client.ConformanceSpringClientApplication.REGISTRATION_ID; +import org.springframework.web.util.UriComponentsBuilder; public class DefaultScenario implements Scenario { @@ -35,12 +36,19 @@ public class DefaultScenario implements Scenario { private final DefaultOAuth2AuthorizedClientManager authorizedClientManager; + private final McpClientRegistrationRepository clientRegistrationRepository; + + private final McpOAuth2ClientManager mcpOAuth2ClientManager; + private McpSyncClient client; public DefaultScenario(McpClientRegistrationRepository clientRegistrationRepository, ServletWebServerApplicationContext serverCtx, - OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository) { + OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository, + McpOAuth2ClientManager mcpOAuth2ClientManager) { this.serverCtx = serverCtx; + this.clientRegistrationRepository = clientRegistrationRepository; + this.mcpOAuth2ClientManager = mcpOAuth2ClientManager; this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(clientRegistrationRepository, oAuth2AuthorizedClientRepository); } @@ -51,10 +59,13 @@ public void execute(String serverUrl) { var testServerUrl = "http://localhost:" + serverCtx.getWebServer().getPort(); var testClient = buildTestClient(testServerUrl); - var customizer = new OAuth2AuthorizationCodeSyncHttpRequestCustomizer(authorizedClientManager, REGISTRATION_ID); - HttpClientStreamableHttpTransport transport = HttpClientStreamableHttpTransport.builder(serverUrl) - .httpRequestCustomizer(customizer) - .build(); + var customizer = new OAuth2HttpClientTransportCustomizer(authorizedClientManager, clientRegistrationRepository, + mcpOAuth2ClientManager); + var baseUri = UriComponentsBuilder.fromUriString(serverUrl).replacePath(null).toUriString(); + var path = UriComponentsBuilder.fromUriString(serverUrl).build().getPath(); + var transportBuilder = HttpClientStreamableHttpTransport.builder(baseUri).endpoint(path); + customizer.customize("default-transport", transportBuilder); + HttpClientStreamableHttpTransport transport = transportBuilder.build(); this.client = McpClient.sync(transport) .transportContextProvider(new AuthenticationMcpTransportContextProvider()) @@ -64,6 +75,8 @@ public void execute(String serverUrl) { try { testClient.get().uri("/initialize-mcp-client").retrieve().toBodilessEntity(); + testClient.get().uri("/tools-list").retrieve().toBodilessEntity(); + testClient.get().uri("/tools-call").retrieve().toBodilessEntity(); } finally { // Close the client (which will close the transport) diff --git a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/PreRegistrationScenario.java b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/PreRegistrationScenario.java index 8e6bbe228..accb7862a 100644 --- a/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/PreRegistrationScenario.java +++ b/conformance-tests/client-spring-http-client/src/main/java/io/modelcontextprotocol/conformance/client/scenario/PreRegistrationScenario.java @@ -87,7 +87,7 @@ private void setClientRegistration(String mcpServerUrl, PreRegistrationContext o .clientSecret(oauthCredentials.clientSecret()) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .build(); - clientRegistrationRepository.addPreRegisteredClient(registration, + clientRegistrationRepository.addClientRegistration(registration, metadata.protectedResourceMetadata().resource()); } diff --git a/conformance-tests/conformance-baseline.yml b/conformance-tests/conformance-baseline.yml index d2990c155..37cdb3110 100644 --- a/conformance-tests/conformance-baseline.yml +++ b/conformance-tests/conformance-baseline.yml @@ -9,5 +9,3 @@ client: - sse-retry # CIMD not implemented yet - auth/basic-cimd - # Scope step up beyond initial authorization request not implemented - - auth/scope-step-up diff --git a/conformance-tests/pom.xml b/conformance-tests/pom.xml index d1bef2a24..88ab7c4b0 100644 --- a/conformance-tests/pom.xml +++ b/conformance-tests/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT conformance-tests pom @@ -16,18 +16,18 @@ https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git true - + client-jdk-http-client client-spring-http-client server-servlet - + \ No newline at end of file diff --git a/conformance-tests/server-servlet/pom.xml b/conformance-tests/server-servlet/pom.xml index 66acea835..a80c7c4ec 100644 --- a/conformance-tests/server-servlet/pom.xml +++ b/conformance-tests/server-servlet/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk conformance-tests - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT server-servlet jar @@ -16,8 +16,8 @@ https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git @@ -28,7 +28,7 @@ io.modelcontextprotocol.sdk mcp - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT diff --git a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java index 3d162a5de..25ec2c106 100644 --- a/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java +++ b/conformance-tests/server-servlet/src/main/java/io/modelcontextprotocol/conformance/server/ConformanceServlet.java @@ -20,7 +20,6 @@ import io.modelcontextprotocol.spec.McpSchema.EmbeddedResource; import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; import io.modelcontextprotocol.spec.McpSchema.ImageContent; -import io.modelcontextprotocol.spec.McpSchema.JsonSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; @@ -51,8 +50,8 @@ public class ConformanceServlet { private static final String MCP_ENDPOINT = "/mcp"; - private static final JsonSchema EMPTY_JSON_SCHEMA = new JsonSchema("object", Collections.emptyMap(), null, null, - null, null); + private static final Map EMPTY_JSON_SCHEMA = Map.of("type", "object", "properties", + Collections.emptyMap()); // Minimal 1x1 red pixel PNG (base64 encoded) private static final String RED_PIXEL_PNG = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="; @@ -326,10 +325,10 @@ private static List createToolSpecs() { .tool(Tool.builder() .name("test_sampling") .description("Tool that requests LLM sampling from client") - .inputSchema(new JsonSchema("object", + .inputSchema(Map.of("type", "object", "properties", Map.of("prompt", Map.of("type", "string", "description", "The prompt to send to the LLM")), - List.of("prompt"), null, null, null)) + "required", List.of("prompt"))) .build()) .callHandler((exchange, request) -> { logger.info("Tool 'test_sampling' called"); @@ -355,10 +354,10 @@ private static List createToolSpecs() { .tool(Tool.builder() .name("test_elicitation") .description("Tool that requests user input from client") - .inputSchema(new JsonSchema("object", + .inputSchema(Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string", "description", "The message to show the user")), - List.of("message"), null, null, null)) + "required", List.of("message"))) .build()) .callHandler((exchange, request) -> { logger.info("Tool 'test_elicitation' called"); diff --git a/docs/server.md b/docs/server.md index f9f3aa683..378de6975 100644 --- a/docs/server.md +++ b/docs/server.md @@ -795,3 +795,42 @@ Supported logging levels (in order of increasing severity): DEBUG (0), INFO (1), ## Error Handling The SDK provides comprehensive error handling through the McpError class, covering protocol compatibility, transport communication, JSON-RPC messaging, tool execution, resource management, prompt handling, timeouts, and connection issues. This unified error handling approach ensures consistent and reliable error management across both synchronous and asynchronous operations. + +### Error Handling in Tool Implementations + +#### Two Tiers of Errors + +MCP distinguishes between two categories of errors in tool execution: + +**1. Tool-Level Errors (Recoverable by the LLM)** + +Use `CallToolResult` with `isError(true)` for validation failures, missing arguments, or domain errors the LLM can act on and retry. + +```java +// Example: Domain validation failure (e.g., invalid email format) +if (!emailAddress.matches("^[A-Za-z0-9+_.-]+@(.+)$")) { + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent("Invalid argument: 'email' must be a valid email address."))) + .isError(true) + .build(); +} +``` + +The LLM receives this as part of the normal tool response and can self-correct in a subsequent interaction. + +**2. Protocol-Level Errors (Unrecoverable)** + +Uncaught exceptions from a tool handler are mapped to a JSON-RPC error response. Use this only for truly unexpected failures (e.g., infrastructure errors such as DB timeout), not for input validation. + +```java +// This propagates as a JSON-RPC error — use sparingly +throw new McpError(McpSchema.ErrorCodes.INTERNAL_ERROR, "Unexpected failure"); +``` + +#### Decision Guide + +| Situation | Approach | +|------------------------------------|---------------------------------------| +| Domain validation failure | `CallToolResult` with `isError=true` | +| Infrastructure / unexpected error | Throw `McpError` or let it propagate | +| Partial success with a warning | `CallToolResult` with warning in text | diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index fb6f3a32a..303520517 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT mcp-bom @@ -16,13 +16,13 @@ Java SDK MCP BOM Java SDK MCP Bill of Materials - https://github.com/modelcontextprotocol/java-sdk + https://github.com/modelcontextprotocol/java-sdk - - https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git - + + https://github.com/modelcontextprotocol/java-sdk + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git + @@ -47,6 +47,13 @@ ${project.version} + + + io.modelcontextprotocol.sdk + mcp-json-jackson3 + ${project.version} + + io.modelcontextprotocol.sdk diff --git a/mcp-core/pom.xml b/mcp-core/pom.xml index 4de0fba2b..d622df0d1 100644 --- a/mcp-core/pom.xml +++ b/mcp-core/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT mcp-core jar @@ -16,8 +16,8 @@ https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git @@ -164,14 +164,14 @@ test - - - com.google.code.gson - gson - 2.10.1 - test - + + + com.google.code.gson + gson + 2.10.1 + test + - + \ No newline at end of file diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 93fcc332a..434c07a1b 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -303,7 +303,7 @@ public class McpAsyncClient { return Mono.empty(); } - return this.listToolsInternal(init, McpSchema.FIRST_PAGE).doOnNext(listToolsResult -> { + return this.listToolsInternal(init, McpSchema.FIRST_PAGE, null).doOnNext(listToolsResult -> { listToolsResult.tools() .forEach(tool -> logger.debug("Tool {} schema: {}", tool.name(), tool.outputSchema())); if (enableCallToolSchemaCaching && listToolsResult.tools() != null) { @@ -645,16 +645,27 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.initializer.withInitialization("listing tools", init -> this.listToolsInternal(init, cursor)); + return this.initializer.withInitialization("listing tools", init -> this.listToolsInternal(init, cursor, null)); } - private Mono listToolsInternal(Initialization init, String cursor) { + /** + * Retrieves a paginated list of tools with optional metadata. + * @param cursor Optional pagination cursor from a previous list request + * @param meta Optional metadata to include in the request (_meta field) + * @return A Mono that emits the list of tools result + */ + public Mono listTools(String cursor, Map meta) { + return this.initializer.withInitialization("listing tools", init -> this.listToolsInternal(init, cursor, meta)); + } + + private Mono listToolsInternal(Initialization init, String cursor, + Map meta) { if (init.initializeResult().capabilities().tools() == null) { return Mono.error(new IllegalStateException("Server does not provide tools capability")); } return init.mcpSession() - .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor), + .sendRequest(McpSchema.METHOD_TOOLS_LIST, new McpSchema.PaginatedRequest(cursor, meta), LIST_TOOLS_RESULT_TYPE_REF) .doOnNext(result -> { // Validate tool names (warn only) @@ -725,12 +736,30 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { + return this.listResourcesInternal(cursor, null); + } + + /** + * Retrieves a paginated list of resources provided by the server. Resources represent + * any kind of UTF-8 encoded data that an MCP server makes available to clients, such + * as database records, API responses, log files, and more. + * @param cursor Optional pagination cursor from a previous list request + * @param meta Optional metadata to include in the request (_meta field) + * @return A Mono that completes with the list of resources result. + * @see McpSchema.ListResourcesResult + * @see #readResource(McpSchema.Resource) + */ + public Mono listResources(String cursor, Map meta) { + return this.listResourcesInternal(cursor, meta); + } + + private Mono listResourcesInternal(String cursor, Map meta) { return this.initializer.withInitialization("listing resources", init -> { if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } return init.mcpSession() - .sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor), + .sendRequest(McpSchema.METHOD_RESOURCES_LIST, new McpSchema.PaginatedRequest(cursor, meta), LIST_RESOURCES_RESULT_TYPE_REF); }); } @@ -795,12 +824,30 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { + return this.listResourceTemplatesInternal(cursor, null); + } + + /** + * Retrieves a paginated list of resource templates provided by the server. Resource + * templates allow servers to expose parameterized resources using URI templates, + * enabling dynamic resource access based on variable parameters. + * @param cursor Optional pagination cursor from a previous list request + * @param meta Optional metadata to include in the request (_meta field) + * @return A Mono that completes with the list of resource templates result. + * @see McpSchema.ListResourceTemplatesResult + */ + public Mono listResourceTemplates(String cursor, Map meta) { + return this.listResourceTemplatesInternal(cursor, meta); + } + + private Mono listResourceTemplatesInternal(String cursor, + Map meta) { return this.initializer.withInitialization("listing resource templates", init -> { if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new IllegalStateException("Server does not provide the resources capability")); } return init.mcpSession() - .sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new McpSchema.PaginatedRequest(cursor), + .sendRequest(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, new McpSchema.PaginatedRequest(cursor, meta), LIST_RESOURCE_TEMPLATES_RESULT_TYPE_REF); }); } @@ -895,8 +942,26 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.initializer.withInitialization("listing prompts", init -> init.mcpSession() - .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); + return this.listPromptsInternal(cursor, null); + } + + /** + * Retrieves a paginated list of prompts with optional metadata. + * @param cursor Optional pagination cursor from a previous list request + * @param meta Optional metadata to include in the request (_meta field) + * @return A Mono that completes with the list of prompts result. + * @see McpSchema.ListPromptsResult + * @see #getPrompt(GetPromptRequest) + */ + public Mono listPrompts(String cursor, Map meta) { + return this.listPromptsInternal(cursor, meta); + } + + private Mono listPromptsInternal(String cursor, Map meta) { + return this.initializer.withInitialization("listing prompts", + init -> init.mcpSession() + .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor, meta), + LIST_PROMPTS_RESULT_TYPE_REF)); } /** diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 7fdaa8941..7e08f83a0 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.util.Map; import java.util.function.Supplier; import org.slf4j.Logger; @@ -259,6 +260,18 @@ public McpSchema.ListToolsResult listTools(String cursor) { } + /** + * Retrieves a paginated list of tools provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @param meta Optional metadata to include in the request (_meta field) + * @return The list of tools result containing: - tools: List of available tools, each + * with a name, description, and input schema - nextCursor: Optional cursor for + * pagination if more tools are available + */ + public McpSchema.ListToolsResult listTools(String cursor, Map meta) { + return withProvidedContext(this.delegate.listTools(cursor, meta)).block(); + } + // -------------------------- // Resources // -------------------------- @@ -282,6 +295,17 @@ public McpSchema.ListResourcesResult listResources(String cursor) { } + /** + * Retrieves a paginated list of resources with optional metadata. + * @param cursor Optional pagination cursor from a previous list request + * @param meta Optional metadata to include in the request (_meta field) + * @return The list of resources result + */ + public McpSchema.ListResourcesResult listResources(String cursor, Map meta) { + return withProvidedContext(this.delegate.listResources(cursor, meta)).block(); + + } + /** * Send a resources/read request. * @param resource the resource to read @@ -324,6 +348,20 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor } + /** + * Resource templates allow servers to expose parameterized resources using URI + * templates. Arguments may be auto-completed through the completion API. + * + * Retrieves a paginated list of resource templates provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @param meta Optional metadata to include in the request (_meta field) + * @return The list of resource templates result. + */ + public McpSchema.ListResourceTemplatesResult listResourceTemplates(String cursor, Map meta) { + return withProvidedContext(this.delegate.listResourceTemplates(cursor, meta)).block(); + + } + /** * Subscriptions. The protocol supports optional subscriptions to resource changes. * Clients can subscribe to specific resources and receive notifications when they @@ -370,6 +408,17 @@ public ListPromptsResult listPrompts(String cursor) { } + /** + * Retrieves a paginated list of prompts provided by the server. + * @param cursor Optional pagination cursor from a previous list request + * @param meta Optional metadata to include in the request (_meta field) + * @return The list of prompts result. + */ + public ListPromptsResult listPrompts(String cursor, Map meta) { + return withProvidedContext(this.delegate.listPrompts(cursor, meta)).block(); + + } + public GetPromptResult getPrompt(GetPromptRequest getPromptRequest) { return withProvidedContext(this.delegate.getPrompt(getPromptRequest)).block(); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index be4e4cf97..70d8b68e3 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -240,17 +240,6 @@ public Builder requestBuilder(HttpRequest.Builder requestBuilder) { return this; } - /** - * Customizes the HTTP client builder. - * @param requestCustomizer the consumer to customize the HTTP request builder - * @return this builder - */ - public Builder customizeRequest(final Consumer requestCustomizer) { - Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); - requestCustomizer.accept(requestBuilder); - return this; - } - /** * Sets the JSON mapper implementation to use for serialization/deserialization. * @param jsonMapper the JSON mapper @@ -456,7 +445,7 @@ private Mono> sendHttpPost(final String endpoint, final Str return Mono.deferContextual(ctx -> { var builder = this.requestBuilder.copy() .uri(requestUri) - .header(HttpHeaders.CONTENT_TYPE, "application/json") + .header(HttpHeaders.CONTENT_TYPE, "application/json; charset=utf-8") .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) .POST(HttpRequest.BodyPublishers.ofString(body)); var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index d6b01e17f..142c0302c 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2025 the original author or authors. + * Copyright 2024-2026 the original author or authors. */ package io.modelcontextprotocol.client.transport; @@ -23,6 +23,7 @@ import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.client.transport.customizer.McpHttpClientAuthorizationErrorHandler; import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.json.McpJsonDefaults; @@ -50,6 +51,7 @@ import reactor.core.publisher.Mono; import reactor.util.function.Tuple2; import reactor.util.function.Tuples; +import reactor.util.retry.Retry; /** * An implementation of the Streamable HTTP protocol as defined by the @@ -72,6 +74,7 @@ *

* * @author Christian Tzolov + * @author Daniel Garnier-Moiroux * @see Streamable * HTTP transport specification @@ -99,6 +102,8 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private static final String APPLICATION_JSON = "application/json"; + private static final String APPLICATION_JSON_UTF8 = "application/json; charset=utf-8"; + private static final String TEXT_EVENT_STREAM = "text/event-stream"; public static int NOT_FOUND = 404; @@ -115,6 +120,8 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final boolean openConnectionOnStartup; + private final McpHttpClientAuthorizationErrorHandler authorizationErrorHandler; + private final boolean resumableStreams; private final McpAsyncHttpClientRequestCustomizer httpRequestCustomizer; @@ -132,7 +139,7 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, boolean openConnectionOnStartup, McpAsyncHttpClientRequestCustomizer httpRequestCustomizer, - List supportedProtocolVersions) { + McpHttpClientAuthorizationErrorHandler authorizationErrorHandler, List supportedProtocolVersions) { this.jsonMapper = jsonMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; @@ -140,6 +147,7 @@ private HttpClientStreamableHttpTransport(McpJsonMapper jsonMapper, HttpClient h this.endpoint = endpoint; this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; + this.authorizationErrorHandler = authorizationErrorHandler; this.activeSession.set(createTransportSession()); this.httpRequestCustomizer = httpRequestCustomizer; this.supportedProtocolVersions = Collections.unmodifiableList(supportedProtocolVersions); @@ -239,7 +247,6 @@ public Mono closeGracefully() { } private Mono reconnect(McpTransportStream stream) { - return Mono.deferContextual(ctx -> { if (stream != null) { @@ -275,121 +282,120 @@ private Mono reconnect(McpTransportStream stream) { var transportContext = connectionCtx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null, transportContext)); }) - .flatMapMany( - requestBuilder -> Flux.create( - sseSink -> this.httpClient - .sendAsync(requestBuilder.build(), - responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, - sseSink)) - .whenComplete((response, throwable) -> { - if (throwable != null) { - sseSink.error(throwable); - } - else { - logger.debug("SSE connection established successfully"); - } - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { - - if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - String data = responseEvent.sseEvent().data(); - // Per 2025-11-25 spec (SEP-1699), servers may - // send SSE events - // with empty data to prime the client for - // reconnection. - // Skip these events as they contain no JSON-RPC - // message. - if (data == null || data.isBlank()) { - logger.debug("Skipping SSE event with empty data (stream primer)"); - return Flux.empty(); - } - try { - // We don't support batching ATM and probably - // won't since the next version considers - // removing it. - McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.jsonMapper, data); - - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(responseEvent.sseEvent().id()), - List.of(message)); - - McpTransportStream sessionStream = stream != null ? stream - : new DefaultMcpTransportStream<>(this.resumableStreams, - this::reconnect); - logger.debug("Connected stream {}", sessionStream.streamId()); - - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); - - } - catch (IOException ioException) { - return Flux.error(new McpTransportException( - "Error parsing JSON-RPC message: " + responseEvent, ioException)); - } - } - else { - logger.debug("Received SSE event with type: {}", responseEvent.sseEvent()); - return Flux.empty(); - } - } - else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed - logger - .debug("The server does not support SSE streams, using request-response mode."); + .flatMapMany(requestBuilder -> Flux.create(sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), this.toSendMessageBodySubscriber(sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })).flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + if (statusCode == 401 || statusCode == 403) { + logger.debug("Authorization error in reconnect with code {}", statusCode); + return Mono.error( + new McpHttpClientTransportAuthorizationException( + "Authorization error connecting to SSE stream", + responseEvent.responseInfo())); + } + else if (statusCode == METHOD_NOT_ALLOWED) { + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + + if (!(responseEvent instanceof ResponseSubscribers.SseResponseEvent sseResponseEvent)) { + return Flux.error(new McpTransportException( + "Unrecognized server error when connecting to SSE stream, status code: " + + statusCode)); + } + else if (statusCode >= 200 && statusCode < 300) { + if (MESSAGE_EVENT_TYPE.equals(sseResponseEvent.sseEvent().event())) { + String data = sseResponseEvent.sseEvent().data(); + // Per 2025-11-25 spec (SEP-1699), servers may + // send SSE events + // with empty data to prime the client for + // reconnection. + // Skip these events as they contain no JSON-RPC + // message. + if (data == null || data.isBlank()) { + logger.debug("Skipping SSE event with empty data (stream primer)"); return Flux.empty(); } - else if (statusCode == NOT_FOUND) { - - if (transportSession != null && transportSession.sessionId().isPresent()) { - // only if the request was sent with a session id - // and the response is 404, we consider it a - // session not found error. - logger.debug("Session not found for session ID: {}", - transportSession.sessionId().get()); - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - return Flux.error( - new McpTransportException("Server Not Found. Status code:" + statusCode - + ", response-event:" + responseEvent)); - } - else if (statusCode == BAD_REQUEST) { - if (transportSession != null && transportSession.sessionId().isPresent()) { - // only if the request was sent with a session id - // and thre response is 404, we consider it a - // session not found error. - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - return Flux.error( - new McpTransportException("Bad Request. Status code:" + statusCode - + ", response-event:" + responseEvent)); + try { + // We don't support batching ATM and probably + // won't since the next version considers + // removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.jsonMapper, data); - } + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(sseResponseEvent.sseEvent().id()), List.of(message)); - return Flux.error(new McpTransportException( - "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); - }).flatMap( - jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) - .onErrorMap(CompletionException.class, t -> t.getCause()) - .onErrorComplete(t -> { - this.handleException(t); - return true; - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error(new McpTransportException( + "Error parsing JSON-RPC message: " + responseEvent, ioException)); } - })) + } + else { + logger.debug("Received SSE event with type: {}", sseResponseEvent.sseEvent()); + return Flux.empty(); + } + } + else if (statusCode == NOT_FOUND) { + + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and the response is 404, we consider it a + // session not found error. + logger.debug("Session not found for session ID: {}", + transportSession.sessionId().get()); + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error( + new McpTransportException("Server Not Found. Status code:" + statusCode + + ", response-event:" + responseEvent)); + } + else if (statusCode == BAD_REQUEST) { + if (transportSession != null && transportSession.sessionId().isPresent()) { + // only if the request was sent with a session id + // and thre response is 404, we consider it a + // session not found error. + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + return Flux.error(new McpTransportException( + "Bad Request. Status code:" + statusCode + ", response-event:" + responseEvent)); + } + return Flux.error(new McpTransportException( + "Received unrecognized SSE event type: " + sseResponseEvent.sseEvent().event())); + }) + .retryWhen(authorizationErrorRetrySpec()) + .flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + })) .contextWrite(ctx) .subscribe(); @@ -400,6 +406,25 @@ else if (statusCode == BAD_REQUEST) { } + private Retry authorizationErrorRetrySpec() { + return Retry.from(companion -> companion.flatMap(retrySignal -> { + if (!(retrySignal.failure() instanceof McpHttpClientTransportAuthorizationException authException)) { + return Mono.error(retrySignal.failure()); + } + if (retrySignal.totalRetriesInARow() >= this.authorizationErrorHandler.maxRetries()) { + return Mono.error(retrySignal.failure()); + } + return Mono.deferContextual(ctx -> { + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + return Mono + .from(this.authorizationErrorHandler.handle(authException.getResponseInfo(), transportContext)) + .switchIfEmpty(Mono.just(false)) + .flatMap(shouldRetry -> shouldRetry ? Mono.just(retrySignal.totalRetries()) + : Mono.error(retrySignal.failure())); + }); + })); + } + private BodyHandler toSendMessageBodySubscriber(FluxSink sink) { BodyHandler responseBodyHandler = responseInfo -> { @@ -454,7 +479,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { var builder = requestBuilder.uri(uri) .header(HttpHeaders.ACCEPT, APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) - .header(HttpHeaders.CONTENT_TYPE, APPLICATION_JSON) + .header(HttpHeaders.CONTENT_TYPE, APPLICATION_JSON_UTF8) .header(HttpHeaders.CACHE_CONTROL, "no-cache") .header(HttpHeaders.PROTOCOL_VERSION, ctx.getOrDefault(McpAsyncClient.NEGOTIATED_PROTOCOL_VERSION, @@ -478,6 +503,13 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); })).flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + if (statusCode == 401 || statusCode == 403) { + logger.debug("Authorization error in sendMessage with code {}", statusCode); + return Mono.error(new McpHttpClientTransportAuthorizationException( + "Authorization error when sending message", responseEvent.responseInfo())); + } + if (transportSession.markInitialized( responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { // Once we have a session, we try to open an async stream for @@ -488,8 +520,6 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { String sessionRepresentation = sessionIdOrPlaceholder(transportSession); - int statusCode = responseEvent.responseInfo().statusCode(); - if (statusCode >= 200 && statusCode < 300) { String contentType = responseEvent.responseInfo() @@ -605,6 +635,7 @@ else if (statusCode == BAD_REQUEST) { return Flux.error( new RuntimeException("Failed to send message: " + responseEvent)); }) + .retryWhen(authorizationErrorRetrySpec()) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) .onErrorMap(CompletionException.class, t -> t.getCause()) .onErrorComplete(t -> { @@ -664,6 +695,8 @@ public static class Builder { private List supportedProtocolVersions = List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26, ProtocolVersions.MCP_2025_06_18, ProtocolVersions.MCP_2025_11_25); + private McpHttpClientAuthorizationErrorHandler authorizationErrorHandler = McpHttpClientAuthorizationErrorHandler.NOOP; + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -706,17 +739,6 @@ public Builder requestBuilder(HttpRequest.Builder requestBuilder) { return this; } - /** - * Customizes the HTTP client builder. - * @param requestCustomizer the consumer to customize the HTTP request builder - * @return this builder - */ - public Builder customizeRequest(final Consumer requestCustomizer) { - Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); - requestCustomizer.accept(requestBuilder); - return this; - } - /** * Configure a custom {@link McpJsonMapper} implementation to use. * @param jsonMapper instance to use @@ -801,6 +823,17 @@ public Builder asyncHttpRequestCustomizer(McpAsyncHttpClientRequestCustomizer as return this; } + /** + * Sets the handler to be used when the server responds with HTTP 401 or HTTP 403 + * when sending a message. + * @param authorizationErrorHandler the handler + * @return this builder + */ + public Builder authorizationErrorHandler(McpHttpClientAuthorizationErrorHandler authorizationErrorHandler) { + this.authorizationErrorHandler = authorizationErrorHandler; + return this; + } + /** * Sets the connection timeout for the HTTP client. * @param connectTimeout the connection timeout duration @@ -845,7 +878,7 @@ public HttpClientStreamableHttpTransport build() { HttpClient httpClient = this.clientBuilder.connectTimeout(this.connectTimeout).build(); return new HttpClientStreamableHttpTransport(jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, httpClient, requestBuilder, baseUri, endpoint, resumableStreams, openConnectionOnStartup, - httpRequestCustomizer, supportedProtocolVersions); + httpRequestCustomizer, authorizationErrorHandler, supportedProtocolVersions); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java new file mode 100644 index 000000000..31e5ae95e --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/McpHttpClientTransportAuthorizationException.java @@ -0,0 +1,31 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.spec.McpTransportException; + +/** + * Thrown when the MCP server responds with an authorization error (HTTP 401 or HTTP 403). + * Subclass of {@link McpTransportException} for targeted retry handling in + * {@link HttpClientStreamableHttpTransport}. + * + * @author Daniel Garnier-Moiroux + */ +public class McpHttpClientTransportAuthorizationException extends McpTransportException { + + private final HttpResponse.ResponseInfo responseInfo; + + public McpHttpClientTransportAuthorizationException(String message, HttpResponse.ResponseInfo responseInfo) { + super(message); + this.responseInfo = responseInfo; + } + + public HttpResponse.ResponseInfo getResponseInfo() { + return responseInfo; + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java new file mode 100644 index 000000000..c98fac61d --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandler.java @@ -0,0 +1,104 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.client.transport.McpHttpClientTransportAuthorizationException; +import io.modelcontextprotocol.common.McpTransportContext; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * Handle security-related errors in HTTP-client based transports. This class handles MCP + * server responses with status code 401 and 403. + * + * @see MCP + * Specification: Authorization + * @author Daniel Garnier-Moiroux + */ +public interface McpHttpClientAuthorizationErrorHandler { + + /** + * Handle authorization error (HTTP 401 or 403), and signal whether the HTTP request + * should be retried or not. If the publisher returns true, the original transport + * method (connect, sendMessage) will be replayed with the original arguments. + * Otherwise, the transport will throw an + * {@link McpHttpClientTransportAuthorizationException}, indicating the error status. + *

+ * If the returned {@link Publisher} errors, the error will be propagated to the + * calling method, to be handled by the caller. + *

+ * The number of retries is bounded by {@link #maxRetries()}. + * @param responseInfo the HTTP response information + * @param context the MCP client transport context + * @return {@link Publisher} emitting true if the original request should be replayed, + * false otherwise. + */ + Publisher handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context); + + /** + * Maximum number of authorization error retries the transport will attempt. When the + * handler signals a retry via {@link #handle}, the transport will replay the original + * request at most this many times. If the authorization error persists after + * exhausting all retries, the transport will propagate the + * {@link McpHttpClientTransportAuthorizationException}. + *

+ * Defaults to {@code 1}. + * @return the maximum number of retries + */ + default int maxRetries() { + return 1; + } + + /** + * A no-op handler, used in the default use-case. + */ + McpHttpClientAuthorizationErrorHandler NOOP = new Noop(); + + /** + * Create a {@link McpHttpClientAuthorizationErrorHandler} from a synchronous handler. + * Will be subscribed on {@link Schedulers#boundedElastic()}. The handler may be + * blocking. + * @param handler the synchronous handler + * @return an async handler + */ + static McpHttpClientAuthorizationErrorHandler fromSync(Sync handler) { + return (info, context) -> Mono.fromCallable(() -> handler.handle(info, context)) + .subscribeOn(Schedulers.boundedElastic()); + } + + /** + * Synchronous authorization error handler. + */ + interface Sync { + + /** + * Handle authorization error (HTTP 401 or 403), and signal whether the HTTP + * request should be retried or not. If the return value is true, the original + * transport method (connect, sendMessage) will be replayed with the original + * arguments. Otherwise, the transport will throw an + * {@link McpHttpClientTransportAuthorizationException}, indicating the error + * status. + * @param responseInfo the HTTP response information + * @param context the MCP client transport context + * @return true if the original request should be replayed, false otherwise. + */ + boolean handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context); + + } + + class Noop implements McpHttpClientAuthorizationErrorHandler { + + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, McpTransportContext context) { + return Mono.just(false); + } + + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index b078493ef..30a3146a7 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -38,6 +38,7 @@ import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.ToolInputValidator; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -98,6 +99,8 @@ public class McpAsyncServer { private final JsonSchemaValidator jsonSchemaValidator; + private final boolean validateToolInputs; + private final McpSchema.ServerCapabilities serverCapabilities; private final McpSchema.Implementation serverInfo; @@ -129,7 +132,8 @@ public class McpAsyncServer { */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator, + boolean validateToolInputs) { this.mcpTransportProvider = mcpTransportProvider; this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); @@ -142,6 +146,7 @@ public class McpAsyncServer { this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; + this.validateToolInputs = validateToolInputs; Map> requestHandlers = prepareRequestHandlers(); Map notificationHandlers = prepareNotificationHandlers(features); @@ -157,7 +162,8 @@ public class McpAsyncServer { McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator, + boolean validateToolInputs) { this.mcpTransportProvider = mcpTransportProvider; this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); @@ -170,6 +176,7 @@ public class McpAsyncServer { this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; + this.validateToolInputs = validateToolInputs; Map> requestHandlers = prepareRequestHandlers(); Map notificationHandlers = prepareNotificationHandlers(features); @@ -543,6 +550,13 @@ private McpRequestHandler toolsCallRequestHandler() { .build()); } + McpSchema.Tool tool = toolSpecification.get().tool(); + CallToolResult validationError = ToolInputValidator.validate(tool, callToolRequest.arguments(), + this.validateToolInputs, this.jsonSchemaValidator); + if (validationError != null) { + return Mono.just(validationError); + } + return toolSpecification.get().callHandler().apply(exchange, callToolRequest); }; } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java index 360eb607d..bef5a5c73 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -243,7 +243,7 @@ public McpAsyncServer build() { : McpJsonDefaults.getSchemaValidator(); return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, - features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs); } } @@ -269,7 +269,7 @@ public McpAsyncServer build() { var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(); return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, - features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs); } } @@ -293,6 +293,8 @@ abstract class AsyncSpecification> { boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + boolean validateToolInputs = true; + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -421,6 +423,17 @@ public AsyncSpecification strictToolNameValidation(boolean strict) { return this; } + /** + * Sets whether to validate tool inputs against the tool's input schema. + * @param validate true to validate inputs and return error on validation failure, + * false to skip validation. Defaults to true. + * @return This builder instance for method chaining + */ + public AsyncSpecification validateToolInputs(boolean validate) { + this.validateToolInputs = validate; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -818,7 +831,8 @@ public McpSyncServer build() { var asyncServer = new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, requestTimeout, uriTemplateManagerFactory, - jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); + jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(), + validateToolInputs); return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -849,7 +863,7 @@ public McpSyncServer build() { : McpJsonDefaults.getSchemaValidator(); var asyncServer = new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + this.uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs); return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -872,6 +886,8 @@ abstract class SyncSpecification> { boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + boolean validateToolInputs = true; + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -1004,6 +1020,17 @@ public SyncSpecification strictToolNameValidation(boolean strict) { return this; } + /** + * Sets whether to validate tool inputs against the tool's input schema. + * @param validate true to validate inputs and return error on validation failure, + * false to skip validation. Defaults to true. + * @return This builder instance for method chaining + */ + public SyncSpecification validateToolInputs(boolean validate) { + this.validateToolInputs = validate; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -1401,6 +1428,8 @@ class StatelessAsyncSpecification { boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + boolean validateToolInputs = true; + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -1530,6 +1559,17 @@ public StatelessAsyncSpecification strictToolNameValidation(boolean strict) { return this; } + /** + * Sets whether to validate tool inputs against the tool's input schema. + * @param validate true to validate inputs and return error on validation failure, + * false to skip validation. Defaults to true. + * @return This builder instance for method chaining + */ + public StatelessAsyncSpecification validateToolInputs(boolean validate) { + this.validateToolInputs = validate; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -1859,7 +1899,8 @@ public McpStatelessAsyncServer build() { this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); return new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, features, requestTimeout, uriTemplateManagerFactory, - jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); + jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(), + validateToolInputs); } } @@ -1884,6 +1925,8 @@ class StatelessSyncSpecification { boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + boolean validateToolInputs = true; + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -2013,6 +2056,17 @@ public StatelessSyncSpecification strictToolNameValidation(boolean strict) { return this; } + /** + * Sets whether to validate tool inputs against the tool's input schema. + * @param validate true to validate inputs and return error on validation failure, + * false to skip validation. Defaults to true. + * @return This builder instance for method chaining + */ + public StatelessSyncSpecification validateToolInputs(boolean validate) { + this.validateToolInputs = validate; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -2360,7 +2414,8 @@ public McpStatelessSyncServer build() { var asyncServer = new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, requestTimeout, uriTemplateManagerFactory, - this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); + this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(), + validateToolInputs); return new McpStatelessSyncServer(asyncServer, this.immediateExecution); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java index c7a1fd0d7..e85451af9 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -21,6 +21,7 @@ import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.ToolInputValidator; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -77,9 +78,12 @@ public class McpStatelessAsyncServer { private final JsonSchemaValidator jsonSchemaValidator; + private final boolean validateToolInputs; + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, McpJsonMapper jsonMapper, McpStatelessServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator, + boolean validateToolInputs) { this.mcpTransportProvider = mcpTransport; this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); @@ -92,6 +96,7 @@ public class McpStatelessAsyncServer { this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; + this.validateToolInputs = validateToolInputs; Map> requestHandlers = new HashMap<>(); @@ -409,6 +414,13 @@ private McpStatelessRequestHandler toolsCallRequestHandler() { .build()); } + McpSchema.Tool tool = toolSpecification.get().tool(); + CallToolResult validationError = ToolInputValidator.validate(tool, callToolRequest.arguments(), + this.validateToolInputs, this.jsonSchemaValidator); + if (validationError != null) { + return Mono.just(validationError); + } + return toolSpecification.get().callHandler().apply(ctx, callToolRequest); }; } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index d3648a06f..0fb2fa778 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -286,7 +286,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); String sessionId = UUID.randomUUID().toString(); AsyncContext asyncContext = request.startAsync(); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 95edb63a0..fe38b2589 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -315,7 +315,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); AsyncContext asyncContext = request.startAsync(); asyncContext.setTimeout(0); @@ -522,7 +521,6 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); response.setHeader("Connection", "keep-alive"); - response.setHeader("Access-Control-Allow-Origin", "*"); AsyncContext asyncContext = request.startAsync(); asyncContext.setTimeout(0); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index bb9cead7e..2e7f73b72 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1307,7 +1307,9 @@ public ListToolsResult(List tools, String nextCursor) { * @param additionalProperties Whether additional properties are allowed * @param defs Schema definitions using the newer $defs keyword * @param definitions Schema definitions using the legacy definitions keyword + * @deprecated use {@link Map} instead. */ + @Deprecated @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record JsonSchema( // @formatter:off @@ -1363,7 +1365,7 @@ public record Tool( // @formatter:off @JsonProperty("name") String name, @JsonProperty("title") String title, @JsonProperty("description") String description, - @JsonProperty("inputSchema") JsonSchema inputSchema, + @JsonProperty("inputSchema") Map inputSchema, @JsonProperty("outputSchema") Map outputSchema, @JsonProperty("annotations") ToolAnnotations annotations, @JsonProperty("_meta") Map meta) { // @formatter:on @@ -1380,7 +1382,7 @@ public static class Builder { private String description; - private JsonSchema inputSchema; + private Map inputSchema; private Map outputSchema; @@ -1403,13 +1405,34 @@ public Builder description(String description) { return this; } + /** + * @deprecated use {@link #inputSchema(Map)} instead. + */ + @Deprecated public Builder inputSchema(JsonSchema inputSchema) { + Map schema = new HashMap<>(); + if (inputSchema.type() != null) + schema.put("type", inputSchema.type()); + if (inputSchema.properties() != null) + schema.put("properties", inputSchema.properties()); + if (inputSchema.required() != null) + schema.put("required", inputSchema.required()); + if (inputSchema.additionalProperties() != null) + schema.put("additionalProperties", inputSchema.additionalProperties()); + if (inputSchema.defs() != null) + schema.put("$defs", inputSchema.defs()); + if (inputSchema.definitions() != null) + schema.put("definitions", inputSchema.definitions()); + return inputSchema(schema); + } + + public Builder inputSchema(Map inputSchema) { this.inputSchema = inputSchema; return this; } public Builder inputSchema(McpJsonMapper jsonMapper, String inputSchema) { - this.inputSchema = parseSchema(jsonMapper, inputSchema); + this.inputSchema = schemaToMap(jsonMapper, inputSchema); return this; } @@ -1450,15 +1473,6 @@ private static Map schemaToMap(McpJsonMapper jsonMapper, String } } - private static JsonSchema parseSchema(McpJsonMapper jsonMapper, String schema) { - try { - return jsonMapper.readValue(schema, JsonSchema.class); - } - catch (IOException e) { - throw new IllegalArgumentException("Invalid schema: " + schema, e); - } - } - /** * Used by the client to call a tool provided by the server. * diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolInputValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolInputValidator.java new file mode 100644 index 000000000..d3db7fb4b --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolInputValidator.java @@ -0,0 +1,54 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Validates tool input arguments against JSON schema. + * + * @author Andrei Shakirin + */ +public final class ToolInputValidator { + + private static final Logger logger = LoggerFactory.getLogger(ToolInputValidator.class); + + private ToolInputValidator() { + } + + /** + * Validates tool arguments against the tool's input schema. + * @param tool the tool definition containing the input schema + * @param arguments the arguments to validate + * @param validateToolInputs whether validation is enabled + * @param validator the JSON schema validator (may be null) + * @return CallToolResult with isError=true if validation fails, null if valid or + * validation skipped + */ + public static CallToolResult validate(McpSchema.Tool tool, Map arguments, + boolean validateToolInputs, JsonSchemaValidator validator) { + if (!validateToolInputs || tool.inputSchema() == null || validator == null) { + return null; + } + Map args = arguments != null ? arguments : Map.of(); + var validation = validator.validate(tool.inputSchema(), args); + if (!validation.valid()) { + logger.warn("Tool '{}' input validation failed: {}", tool.name(), validation.errorMessage()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); + } + return null; + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java new file mode 100644 index 000000000..2812522f5 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/customizer/McpHttpClientAuthorizationErrorHandlerTest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ +package io.modelcontextprotocol.client.transport.customizer; + +import java.net.http.HttpResponse; + +import io.modelcontextprotocol.common.McpTransportContext; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +import static org.mockito.Mockito.mock; + +/** + * @author Daniel Garnier-Moiroux + */ +class McpHttpClientAuthorizationErrorHandlerTest { + + private final HttpResponse.ResponseInfo responseInfo = mock(HttpResponse.ResponseInfo.class); + + private final McpTransportContext context = McpTransportContext.EMPTY; + + @Test + void whenTrueThenRetry() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> true); + StepVerifier.create(handler.handle(responseInfo, context)).expectNext(true).verifyComplete(); + } + + @Test + void whenFalseThenError() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> false); + StepVerifier.create(handler.handle(responseInfo, context)).expectNext(false).verifyComplete(); + } + + @Test + void whenExceptionThenPropagate() { + McpHttpClientAuthorizationErrorHandler handler = McpHttpClientAuthorizationErrorHandler + .fromSync((info, ctx) -> { + throw new IllegalStateException("sync handler error"); + }); + StepVerifier.create(handler.handle(responseInfo, context)) + .expectErrorMatches(t -> t instanceof IllegalStateException && t.getMessage().equals("sync handler error")) + .verify(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java index 897ae2ccc..ee8c70ffe 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/AsyncToolSpecificationBuilderTest.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.server; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + import java.util.List; import java.util.Map; @@ -25,7 +27,6 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java index 54c45e561..f7364be2d 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/SyncToolSpecificationBuilderTest.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.server; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + import java.util.List; import java.util.Map; @@ -22,7 +24,6 @@ import org.junit.jupiter.api.Test; import org.slf4j.LoggerFactory; -import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolInputValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolInputValidatorTests.java new file mode 100644 index 000000000..4d073d1a7 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolInputValidatorTests.java @@ -0,0 +1,98 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator.ValidationResponse; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link ToolInputValidator}. + * + * @author Andrei Shakirin + */ +class ToolInputValidatorTests { + + private final JsonSchemaValidator validator = mock(JsonSchemaValidator.class); + + private final Map inputSchema = Map.of("type", "object", "properties", + Map.of("name", Map.of("type", "string")), "required", List.of("name")); + + private final Tool toolWithSchema = Tool.builder() + .name("test-tool") + .description("Test tool") + .inputSchema(inputSchema) + .build(); + + private final Tool toolWithoutSchema = Tool.builder().name("test-tool").description("Test tool").build(); + + @Test + void validate_whenDisabled_returnsNull() { + CallToolResult result = ToolInputValidator.validate(toolWithSchema, Map.of("name", "test"), false, validator); + + assertThat(result).isNull(); + verify(validator, never()).validate(any(), any()); + } + + @Test + void validate_whenNoSchema_returnsNull() { + CallToolResult result = ToolInputValidator.validate(toolWithoutSchema, Map.of("name", "test"), true, validator); + + assertThat(result).isNull(); + verify(validator, never()).validate(any(), any()); + } + + @Test + void validate_whenNoValidator_returnsNull() { + CallToolResult result = ToolInputValidator.validate(toolWithSchema, Map.of("name", "test"), true, null); + + assertThat(result).isNull(); + } + + @Test + void validate_withValidInput_returnsNull() { + when(validator.validate(any(), any())).thenReturn(ValidationResponse.asValid(null)); + + CallToolResult result = ToolInputValidator.validate(toolWithSchema, Map.of("name", "test"), true, validator); + + assertThat(result).isNull(); + } + + @Test + void validate_withInvalidInput_returnsErrorResult() { + when(validator.validate(any(), any())).thenReturn(ValidationResponse.asInvalid("missing required: 'name'")); + + CallToolResult result = ToolInputValidator.validate(toolWithSchema, Map.of(), true, validator); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(((TextContent) result.content().get(0)).text()).contains("missing required: 'name'"); + verify(validator).validate(any(), any()); + } + + @Test + void validate_withNullArguments_usesEmptyMap() { + when(validator.validate(any(), any())).thenReturn(ValidationResponse.asValid(null)); + + CallToolResult result = ToolInputValidator.validate(toolWithSchema, null, true, validator); + + assertThat(result).isNull(); + verify(validator).validate(any(), any()); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java index ce8755223..a1cafa2e1 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -1,15 +1,14 @@ package io.modelcontextprotocol.util; -import io.modelcontextprotocol.spec.McpSchema; - import java.util.Collections; +import java.util.Map; public final class ToolsUtils { private ToolsUtils() { } - public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", - Collections.emptyMap(), null, null, null, null); + public static final Map EMPTY_JSON_SCHEMA = Map.of("type", "object", "properties", + Collections.emptyMap()); } diff --git a/mcp-json-jackson2/pom.xml b/mcp-json-jackson2/pom.xml index f25877cd3..5dd9a5ac1 100644 --- a/mcp-json-jackson2/pom.xml +++ b/mcp-json-jackson2/pom.xml @@ -6,18 +6,20 @@ io.modelcontextprotocol.sdk mcp-parent - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT mcp-json-jackson2 jar Java MCP SDK JSON Jackson 2 Java MCP SDK JSON implementation based on Jackson 2 https://github.com/modelcontextprotocol/java-sdk + https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git + @@ -62,21 +64,21 @@ - - com.fasterxml.jackson.core - jackson-databind - ${jackson2.version} - - - io.modelcontextprotocol.sdk - mcp-core - 1.1.0-SNAPSHOT - - - com.networknt - json-schema-validator - ${json-schema-validator-jackson2.version} - + + com.fasterxml.jackson.core + jackson-databind + ${jackson2.version} + + + io.modelcontextprotocol.sdk + mcp-core + 2.0.0-SNAPSHOT + + + com.networknt + json-schema-validator + ${json-schema-validator-jackson2.version} + org.assertj @@ -104,4 +106,4 @@ - + \ No newline at end of file diff --git a/mcp-json-jackson3/pom.xml b/mcp-json-jackson3/pom.xml index 99baf14e1..2afd474f6 100644 --- a/mcp-json-jackson3/pom.xml +++ b/mcp-json-jackson3/pom.xml @@ -6,18 +6,20 @@ io.modelcontextprotocol.sdk mcp-parent - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT mcp-json-jackson3 jar Java MCP SDK JSON Jackson 3 Java MCP SDK JSON implementation based on Jackson 3 https://github.com/modelcontextprotocol/java-sdk + https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git + @@ -61,21 +63,21 @@ - - io.modelcontextprotocol.sdk - mcp-core - 1.1.0-SNAPSHOT - - - tools.jackson.core - jackson-databind - ${jackson3.version} - - - com.networknt - json-schema-validator - ${json-schema-validator-jackson3.version} - + + io.modelcontextprotocol.sdk + mcp-core + 2.0.0-SNAPSHOT + + + tools.jackson.core + jackson-databind + ${jackson3.version} + + + com.networknt + json-schema-validator + ${json-schema-validator-jackson3.version} + org.assertj @@ -103,4 +105,4 @@ - + \ No newline at end of file diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index 531c0bbc5..45e74717c 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -1,12 +1,12 @@ + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 io.modelcontextprotocol.sdk mcp-parent - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT mcp-test jar @@ -16,15 +16,15 @@ https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git io.modelcontextprotocol.sdk mcp-core - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT @@ -159,7 +159,7 @@ io.modelcontextprotocol.sdk mcp-json-jackson3 - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT test @@ -170,7 +170,7 @@ io.modelcontextprotocol.sdk mcp-json-jackson2 - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT test diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 1ed9b270a..beec006ba 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -23,6 +25,7 @@ import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.McpJsonDefaults; import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpSyncServer; @@ -47,15 +50,14 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.util.McpJsonMapperUtils; import io.modelcontextprotocol.util.Utils; import net.javacrumbs.jsonunit.core.Option; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; -import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; @@ -404,6 +406,8 @@ void testCreateElicitationSuccess(String clientType) { .addContent(new McpSchema.TextContent("CALL RESPONSE")) .build(); + AtomicReference elicitResultRef = new AtomicReference<>(); + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(EMPTY_JSON_SCHEMA).build()) .callHandler((exchange, request) -> { @@ -414,13 +418,9 @@ void testCreateElicitationSuccess(String clientType) { Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) .build(); - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); + return exchange.createElicitation(elicitationRequest) + .doOnNext(elicitResultRef::set) + .thenReturn(callResponse); }) .build(); @@ -438,6 +438,11 @@ void testCreateElicitationSuccess(String clientType) { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); + assertWith(elicitResultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); } finally { mcpServer.closeGracefully().block(); @@ -912,6 +917,62 @@ void testToolCallSuccessWithTranportContextExtraction(String clientType) { } } + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testToolWithNonAsciiCharacters(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + String inputSchema = """ + { + "type": "object", + "properties": { + "username": { "type": "string" } + }, + "required": ["username"] + } + """; + + McpServerFeatures.SyncToolSpecification nonAsciiTool = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("greeter") + .description("打招呼") + .inputSchema(McpJsonDefaults.getMapper(), inputSchema) + .build()) + .callHandler((exchange, request) -> { + String username = (String) request.arguments().get("username"); + return McpSchema.CallToolResult.builder() + .addContent(new McpSchema.TextContent("Hello " + username)) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(nonAsciiTool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var tools = mcpClient.listTools().tools(); + assertThat(tools).hasSize(1); + assertThat(tools.get(0).name()).isEqualTo("greeter"); + assertThat(tools.get(0).description()).isEqualTo("打招呼"); + + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("greeter", Map.of("username", "测试用户"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()).isEqualTo("Hello 测试用户"); + } + finally { + mcpServer.closeGracefully(); + } + } + @ParameterizedTest(name = "{0} : {displayName} ") @MethodSource("clientsForTesting") void testToolListChangeHandlingSuccess(String clientType) { diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java index 7755ce456..24cc9c3d0 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -32,7 +34,6 @@ import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; -import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index bee8f4f16..2ef45a1e0 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -610,22 +610,17 @@ void testListAllResourceTemplatesReturnsImmutableList() { }); } - // @Test + @Test void testResourceSubscription() { withClient(createMcpTransport(), mcpAsyncClient -> { - StepVerifier.create(mcpAsyncClient.listResources()).consumeNextWith(resources -> { - if (!resources.resources().isEmpty()) { - Resource firstResource = resources.resources().get(0); - - // Test subscribe - StepVerifier.create(mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri()))) - .verifyComplete(); - - // Test unsubscribe - StepVerifier.create(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))) - .verifyComplete(); + StepVerifier.create(mcpAsyncClient.listResources().flatMap(resources -> { + if (resources.resources().isEmpty()) { + return Mono.empty(); } - }).verifyComplete(); + Resource firstResource = resources.resources().get(0); + return mcpAsyncClient.subscribeResource(new SubscribeRequest(firstResource.uri())) + .then(mcpAsyncClient.unsubscribeResource(new UnsubscribeRequest(firstResource.uri()))); + })).verifyComplete(); }); } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java index 26d60568a..7fe7bd657 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpSyncClientTests.java @@ -154,6 +154,19 @@ void testListTools() { }); } + @Test + void testListToolsWithMeta() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + Map meta = java.util.Map.of("requestId", "test-123"); + ListToolsResult tools = mcpSyncClient.listTools(McpSchema.FIRST_PAGE, meta); + + assertThat(tools).isNotNull().satisfies(result -> { + assertThat(result.tools()).isNotNull().isNotEmpty(); + }); + }); + } + @Test void testListAllTools() { withClient(createMcpTransport(), mcpSyncClient -> { @@ -678,4 +691,43 @@ void testProgressConsumer() { }); } + @Test + void testListResourcesWithMeta() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + Map meta = java.util.Map.of("requestId", "test-123"); + ListResourcesResult resources = mcpSyncClient.listResources(McpSchema.FIRST_PAGE, meta); + + assertThat(resources).isNotNull().satisfies(result -> { + assertThat(result.resources()).isNotNull(); + }); + }); + } + + @Test + void testListResourceTemplatesWithMeta() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + Map meta = java.util.Map.of("requestId", "test-123"); + ListResourceTemplatesResult result = mcpSyncClient.listResourceTemplates(McpSchema.FIRST_PAGE, meta); + + assertThat(result).isNotNull().satisfies(r -> { + assertThat(r.resourceTemplates()).isNotNull(); + }); + }); + } + + @Test + void testListPromptsWithMeta() { + withClient(createMcpTransport(), mcpSyncClient -> { + mcpSyncClient.initialize(); + Map meta = java.util.Map.of("requestId", "test-123"); + McpSchema.ListPromptsResult result = mcpSyncClient.listPrompts(McpSchema.FIRST_PAGE, meta); + + assertThat(result).isNotNull().satisfies(r -> { + assertThat(r.prompts()).isNotNull(); + }); + }); + } + } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 9cd1191d1..731f763a3 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -4,8 +4,11 @@ package io.modelcontextprotocol.server; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + import java.time.Duration; import java.util.List; +import java.util.Map; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -25,7 +28,6 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index eee5f1a4d..d8d036dc0 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -4,7 +4,10 @@ package io.modelcontextprotocol.server; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + import java.util.List; +import java.util.Map; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -20,7 +23,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java index ce8755223..a1cafa2e1 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/util/ToolsUtils.java @@ -1,15 +1,14 @@ package io.modelcontextprotocol.util; -import io.modelcontextprotocol.spec.McpSchema; - import java.util.Collections; +import java.util.Map; public final class ToolsUtils { private ToolsUtils() { } - public static final McpSchema.JsonSchema EMPTY_JSON_SCHEMA = new McpSchema.JsonSchema("object", - Collections.emptyMap(), null, null, null, null); + public static final Map EMPTY_JSON_SCHEMA = Map.of("type", "object", "properties", + Collections.emptyMap()); } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java index 48bf1da5b..732f82926 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java @@ -44,11 +44,10 @@ private McpClientTransport createMockTransportForToolValidation(boolean hasOutpu Map inputSchemaMap = Map.of("type", "object", "properties", Map.of("expression", Map.of("type", "string")), "required", List.of("expression")); - McpSchema.JsonSchema inputSchema = new McpSchema.JsonSchema("object", inputSchemaMap, null, null, null, null); McpSchema.Tool.Builder toolBuilder = McpSchema.Tool.builder() .name("calculator") .description("Performs mathematical calculations") - .inputSchema(inputSchema); + .inputSchema(inputSchemaMap); if (hasOutputSchema) { Map outputSchema = Map.of("type", "object", "properties", @@ -239,72 +238,159 @@ void testCallToolWithOutputSchemaValidationFailure() { } @Test - void testListToolsWithEmptyCursor() { - McpSchema.Tool addTool = McpSchema.Tool.builder().name("add").description("calculate add").build(); - McpSchema.Tool subtractTool = McpSchema.Tool.builder() - .name("subtract") - .description("calculate subtract") - .build(); - McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(addTool, subtractTool), ""); + void testListToolsWithCursorAndMeta() { + var transport = new TestMcpClientTransport(); + McpAsyncClient client = McpClient.async(transport).build(); + + Map meta = Map.of("customKey", "customValue"); + McpSchema.ListToolsResult result = client.listTools("cursor-1", meta).block(); + assertThat(result).isNotNull(); + assertThat(result.tools()).hasSize(1); + assertThat(transport.getCapturedRequest()).isNotNull(); + assertThat(transport.getCapturedRequest().cursor()).isEqualTo("cursor-1"); + assertThat(transport.getCapturedRequest().meta()).containsEntry("customKey", "customValue"); + } - McpClientTransport transport = new McpClientTransport() { - Function, Mono> handler; + @Test + void testListResourcesWithCursorAndMeta() { + var transport = new TestMcpClientTransport(); + McpAsyncClient client = McpClient.async(transport).build(); + + Map meta = Map.of("customKey", "customValue"); + McpSchema.ListResourcesResult result = client.listResources("cursor-1", meta).block(); + assertThat(result).isNotNull(); + assertThat(result.resources()).hasSize(1); + assertThat(transport.getCapturedRequest()).isNotNull(); + assertThat(transport.getCapturedRequest().cursor()).isEqualTo("cursor-1"); + assertThat(transport.getCapturedRequest().meta()).containsEntry("customKey", "customValue"); + } - @Override - public Mono connect( - Function, Mono> handler) { - return Mono.deferContextual(ctx -> { - this.handler = handler; - return Mono.empty(); - }); - } + @Test + void testListResourceTemplatesWithCursorAndMeta() { + var transport = new TestMcpClientTransport(); + McpAsyncClient client = McpClient.async(transport).build(); + + Map meta = Map.of("customKey", "customValue"); + McpSchema.ListResourceTemplatesResult result = client.listResourceTemplates("cursor-1", meta).block(); + assertThat(result).isNotNull(); + assertThat(result.resourceTemplates()).hasSize(1); + assertThat(transport.getCapturedRequest()).isNotNull(); + assertThat(transport.getCapturedRequest().cursor()).isEqualTo("cursor-1"); + assertThat(transport.getCapturedRequest().meta()).containsEntry("customKey", "customValue"); + } - @Override - public Mono closeGracefully() { - return Mono.empty(); - } + @Test + void testListPromptsWithCursorAndMeta() { + var transport = new TestMcpClientTransport(); + McpAsyncClient client = McpClient.async(transport).build(); + + Map meta = Map.of("customKey", "customValue"); + McpSchema.ListPromptsResult result = client.listPrompts("cursor-1", meta).block(); + assertThat(result).isNotNull(); + assertThat(result.prompts()).hasSize(1); + assertThat(transport.getCapturedRequest()).isNotNull(); + assertThat(transport.getCapturedRequest().cursor()).isEqualTo("cursor-1"); + assertThat(transport.getCapturedRequest().meta()).containsEntry("customKey", "customValue"); - @Override - public Mono sendMessage(McpSchema.JSONRPCMessage message) { - if (!(message instanceof McpSchema.JSONRPCRequest request)) { - return Mono.empty(); - } + } - McpSchema.JSONRPCResponse response; - if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { - response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), MOCK_INIT_RESULT, - null); - } - else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { - response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, - null); - } - else { - return Mono.empty(); - } + static class TestMcpClientTransport implements McpClientTransport { - return handler.apply(Mono.just(response)).then(); + private Function, Mono> handler; + + private McpSchema.PaginatedRequest capturedRequest = null; + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!(message instanceof McpSchema.JSONRPCRequest request)) { + return Mono.empty(); + } + McpSchema.JSONRPCResponse response; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + McpSchema.ServerCapabilities caps = McpSchema.ServerCapabilities.builder() + .prompts(false) + .resources(false, false) + .tools(false) + .build(); + + McpSchema.InitializeResult initResult = new McpSchema.InitializeResult(ProtocolVersions.MCP_2024_11_05, + caps, MOCK_SERVER_INFO, null); + + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), initResult, null); + } + else if (McpSchema.METHOD_PROMPT_LIST.equals(request.method())) { + capturedRequest = JSON_MAPPER.convertValue(request.params(), McpSchema.PaginatedRequest.class); + + McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "A test prompt", List.of()); + McpSchema.ListPromptsResult mockPromptResult = new McpSchema.ListPromptsResult(List.of(mockPrompt), + null); + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockPromptResult, + null); + } + else if (McpSchema.METHOD_RESOURCES_TEMPLATES_LIST.equals(request.method())) { + capturedRequest = JSON_MAPPER.convertValue(request.params(), McpSchema.PaginatedRequest.class); + + McpSchema.ResourceTemplate mockTemplate = new McpSchema.ResourceTemplate("file:///{name}", "template", + null, null, null); + McpSchema.ListResourceTemplatesResult mockResourceTemplateResult = new McpSchema.ListResourceTemplatesResult( + List.of(mockTemplate), null); + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + mockResourceTemplateResult, null); } + else if (McpSchema.METHOD_RESOURCES_LIST.equals(request.method())) { + capturedRequest = JSON_MAPPER.convertValue(request.params(), McpSchema.PaginatedRequest.class); + + McpSchema.Resource mockResource = McpSchema.Resource.builder() + .uri("file:///test.txt") + .name("test.txt") + .build(); + McpSchema.ListResourcesResult mockResourceResult = new McpSchema.ListResourcesResult( + List.of(mockResource), null); + + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockResourceResult, + null); + } + else if (McpSchema.METHOD_TOOLS_LIST.equals(request.method())) { + capturedRequest = JSON_MAPPER.convertValue(request.params(), McpSchema.PaginatedRequest.class); - @Override - public T unmarshalFrom(Object data, TypeRef typeRef) { - return JSON_MAPPER.convertValue(data, new TypeRef<>() { - @Override - public java.lang.reflect.Type getType() { - return typeRef.getType(); - } - }); + McpSchema.Tool addTool = McpSchema.Tool.builder().name("add").description("calculate add").build(); + McpSchema.ListToolsResult mockToolsResult = new McpSchema.ListToolsResult(List.of(addTool), null); + response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), mockToolsResult, + null); } - }; + else { + return Mono.empty(); + } + return handler.apply(Mono.just(response)).then(); + } - McpAsyncClient client = McpClient.async(transport).enableCallToolSchemaCaching(true).build(); + @Override + public T unmarshalFrom(Object data, TypeRef typeRef) { + return JSON_MAPPER.convertValue(data, new TypeRef<>() { + @Override + public java.lang.reflect.Type getType() { + return typeRef.getType(); + } + }); + } - Mono mono = client.listTools(); - McpSchema.ListToolsResult toolsResult = mono.block(); - assertThat(toolsResult).isNotNull(); + public McpSchema.PaginatedRequest getCapturedRequest() { + return capturedRequest; + } - Set names = toolsResult.tools().stream().map(McpSchema.Tool::name).collect(Collectors.toSet()); - assertThat(names).containsExactlyInAnyOrder("subtract", "add"); } } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index a24805a30..f3bc17f5b 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -333,66 +333,6 @@ void testCustomizeClient() { customizedTransport.closeGracefully().block(); } - @Test - void testCustomizeRequest() { - // Create an atomic boolean to verify the customizer was called - AtomicBoolean customizerCalled = new AtomicBoolean(false); - - // Create a reference to store the custom header value - AtomicReference headerName = new AtomicReference<>(); - AtomicReference headerValue = new AtomicReference<>(); - - // Create a transport with the customizer - HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) - // Create a request customizer that adds a custom header - .customizeRequest(builder -> { - builder.header("X-Custom-Header", "test-value"); - customizerCalled.set(true); - - // Create a new request to verify the header was set - HttpRequest request = builder.uri(URI.create("http://example.com")).build(); - headerName.set("X-Custom-Header"); - headerValue.set(request.headers().firstValue("X-Custom-Header").orElse(null)); - }) - .build(); - - // Verify the customizer was called - assertThat(customizerCalled.get()).isTrue(); - - // Verify the header was set correctly - assertThat(headerName.get()).isEqualTo("X-Custom-Header"); - assertThat(headerValue.get()).isEqualTo("test-value"); - - // Clean up - customizedTransport.closeGracefully().block(); - } - - @Test - void testChainedCustomizations() { - // Create atomic booleans to verify both customizers were called - AtomicBoolean clientCustomizerCalled = new AtomicBoolean(false); - AtomicBoolean requestCustomizerCalled = new AtomicBoolean(false); - - // Create a transport with both customizers chained - HttpClientSseClientTransport customizedTransport = HttpClientSseClientTransport.builder(host) - .customizeClient(builder -> { - builder.connectTimeout(Duration.ofSeconds(30)); - clientCustomizerCalled.set(true); - }) - .customizeRequest(builder -> { - builder.header("X-Api-Key", "test-api-key"); - requestCustomizerCalled.set(true); - }) - .build(); - - // Verify both customizers were called - assertThat(clientCustomizerCalled.get()).isTrue(); - assertThat(requestCustomizerCalled.get()).isTrue(); - - // Clean up - customizedTransport.closeGracefully().block(); - } - @Test void testRequestCustomizer() { var mockCustomizer = mock(McpSyncHttpClientRequestCustomizer.class); diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java index b82d6eb2c..d3793ca01 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportErrorHandlingTest.java @@ -1,26 +1,23 @@ /* - * Copyright 2025-2025 the original author or authors. + * Copyright 2025-2026 the original author or authors. */ package io.modelcontextprotocol.client.transport; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; - import java.io.IOException; import java.net.InetSocketAddress; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; - -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; +import java.util.function.Predicate; import com.sun.net.httpserver.HttpServer; - +import io.modelcontextprotocol.client.transport.customizer.McpHttpClientAuthorizationErrorHandler; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.transport.TomcatTestUtil; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; @@ -28,14 +25,31 @@ import io.modelcontextprotocol.spec.McpTransportException; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.ProtocolVersions; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.type; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + /** * Tests for error handling changes in HttpClientStreamableHttpTransport. Specifically * tests the distinction between session-related errors and general transport errors for * 404 and 400 status codes. * * @author Christian Tzolov + * @author Daniel Garnier-Moiroux */ @Timeout(15) public class HttpClientStreamableHttpTransportErrorHandlingTest { @@ -46,11 +60,17 @@ public class HttpClientStreamableHttpTransportErrorHandlingTest { private HttpServer server; - private AtomicReference serverResponseStatus = new AtomicReference<>(200); + private final AtomicInteger serverResponseStatus = new AtomicInteger(200); + + private final AtomicInteger serverSseResponseStatus = new AtomicInteger(200); + + private final AtomicReference currentServerSessionId = new AtomicReference<>(null); + + private final AtomicReference lastReceivedSessionId = new AtomicReference<>(null); - private AtomicReference currentServerSessionId = new AtomicReference<>(null); + private final AtomicInteger processedMessagesCount = new AtomicInteger(0); - private AtomicReference lastReceivedSessionId = new AtomicReference<>(null); + private final AtomicInteger processedSseConnectCount = new AtomicInteger(0); private McpClientTransport transport; @@ -88,6 +108,20 @@ else if ("POST".equals(httpExchange.getRequestMethod())) { else { httpExchange.sendResponseHeaders(status, 0); } + processedMessagesCount.incrementAndGet(); + } + else if ("GET".equals(httpExchange.getRequestMethod())) { + int status = serverSseResponseStatus.get(); + if (status == 200) { + httpExchange.getResponseHeaders().set("Content-Type", "text/event-stream"); + httpExchange.sendResponseHeaders(200, 0); + String sseData = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{}}\n\n"; + httpExchange.getResponseBody().write(sseData.getBytes()); + } + else { + httpExchange.sendResponseHeaders(status, 0); + } + processedSseConnectCount.incrementAndGet(); } httpExchange.close(); }); @@ -103,6 +137,7 @@ void stopServer() { if (server != null) { server.stop(0); } + StepVerifier.create(transport.closeGracefully()).verifyComplete(); } /** @@ -334,6 +369,406 @@ else if (status == 404) { StepVerifier.create(transport.closeGracefully()).verifyComplete(); } + @Test + void test405OnConnectReturnsEmptyFlux() { + serverSseResponseStatus.set(405); + AtomicReference capturedException = new AtomicReference<>(); + var transport = HttpClientStreamableHttpTransport.builder(HOST).openConnectionOnStartup(true).build(); + transport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(transport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isNull(); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + + @Nested + class AuthorizationError { + + @Nested + class SendMessage { + + @ParameterizedTest + @ValueSource(ints = { 401, 403 }) + void invokeHandler(int httpStatus) { + serverResponseStatus.set(httpStatus); + + AtomicReference capturedResponseInfo = new AtomicReference<>(); + AtomicReference capturedContext = new AtomicReference<>(); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + capturedResponseInfo.set(responseInfo); + capturedContext.set(context); + return Mono.just(false); + }) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(httpStatus)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + assertThat(capturedResponseInfo.get()).isNotNull(); + assertThat(capturedResponseInfo.get().statusCode()).isEqualTo(httpStatus); + assertThat(capturedContext.get()).isNotNull(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void defaultHandler() { + serverResponseStatus.set(401); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST).build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retry() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + serverResponseStatus.set(200); + return Mono.just(true); + }) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())).verifyComplete(); + // initial request + retry + assertThat(processedMessagesCount.get()).isEqualTo(2); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retryAtMostOnce() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.just(true)) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + // initial request + 1 retry (maxRetries default is 1) + assertThat(processedMessagesCount.get()).isEqualTo(2); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void customMaxRetries() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler(new McpHttpClientAuthorizationErrorHandler() { + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, + McpTransportContext context) { + return Mono.just(true); + } + + @Override + public int maxRetries() { + return 3; + } + }) + .build(); + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + // initial request + 3 retries + assertThat(processedMessagesCount.get()).isEqualTo(4); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void noRetry() { + serverResponseStatus.set(401); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.just(false)) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + assertThat(processedMessagesCount.get()).isEqualTo(1); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void propagateHandlerError() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler( + (responseInfo, context) -> Mono.error(new IllegalStateException("handler error"))) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(throwable -> throwable instanceof IllegalStateException + && throwable.getMessage().equals("handler error")) + .verify(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void emptyHandler() { + serverResponseStatus.set(401); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> Mono.empty()) + .build(); + + StepVerifier.create(authTransport.sendMessage(createTestRequestMessage())) + .expectErrorMatches(authorizationError(401)) + .verify(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + } + + @Nested + class Connect { + + @ParameterizedTest + @ValueSource(ints = { 401, 403 }) + void invokeHandler(int httpStatus) { + serverSseResponseStatus.set(httpStatus); + @SuppressWarnings("unchecked") + AtomicReference capturedException = new AtomicReference<>(); + + AtomicReference capturedResponseInfo = new AtomicReference<>(); + AtomicReference capturedContext = new AtomicReference<>(); + + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .authorizationErrorHandler((responseInfo, context) -> { + capturedResponseInfo.set(responseInfo); + capturedContext.set(context); + return Mono.just(false); + }) + .openConnectionOnStartup(true) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedResponseInfo.get()).isNotNull(); + assertThat(capturedResponseInfo.get().statusCode()).isEqualTo(httpStatus); + assertThat(capturedContext.get()).isNotNull(); + assertThat(capturedException.get()).hasMessage("Authorization error connecting to SSE stream") + .asInstanceOf(type(McpHttpClientTransportAuthorizationException.class)) + .extracting(McpHttpClientTransportAuthorizationException::getResponseInfo) + .extracting(HttpResponse.ResponseInfo::statusCode) + .isEqualTo(httpStatus); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void defaultHandler() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + StepVerifier.create(authTransport.connect(msg -> msg)).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retry() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + serverSseResponseStatus.set(200); + return Mono.just(true); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + var messageHandlerClosed = new AtomicBoolean(false); + StepVerifier + .create(authTransport + .connect(msg -> msg.doOnNext(messages::add).doFinally(s -> messageHandlerClosed.set(true)))) + .verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(messageHandlerClosed).isTrue()); + assertThat(processedSseConnectCount.get()).isEqualTo(2); + assertThat(messages).hasSize(1); + assertThat(capturedException.get()).isNull(); + assertThat(messageHandlerClosed.get()).isTrue(); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void retryAtMostOnce() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + return Mono.just(true); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(capturedException.get()).isNotNull()); + // initial request + 1 retry (maxRetries default is 1) + assertThat(processedSseConnectCount.get()).isEqualTo(2); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void customMaxRetries() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler(new McpHttpClientAuthorizationErrorHandler() { + @Override + public Publisher handle(HttpResponse.ResponseInfo responseInfo, + McpTransportContext context) { + return Mono.just(true); + } + + @Override + public int maxRetries() { + return 3; + } + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(capturedException.get()).isNotNull()); + // initial request + 3 retries + assertThat(processedSseConnectCount.get()).isEqualTo(4); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void noRetry() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> { + // if there was a retry, the request would succeed. + serverSseResponseStatus.set(200); + return Mono.just(false); + }) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void emptyHandler() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler((responseInfo, context) -> Mono.empty()) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(McpHttpClientTransportAuthorizationException.class); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + @Test + void propagateHandlerError() { + serverSseResponseStatus.set(401); + AtomicReference capturedException = new AtomicReference<>(); + var authTransport = HttpClientStreamableHttpTransport.builder(HOST) + .openConnectionOnStartup(true) + .authorizationErrorHandler( + (responseInfo, context) -> Mono.error(new IllegalStateException("handler error"))) + .build(); + authTransport.setExceptionHandler(capturedException::set); + + var messages = new ArrayList(); + StepVerifier.create(authTransport.connect(msg -> msg.doOnNext(messages::add))).verifyComplete(); + Awaitility.await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertThat(processedSseConnectCount.get()).isEqualTo(1)); + assertThat(messages).isEmpty(); + assertThat(capturedException.get()).isInstanceOf(IllegalStateException.class) + .hasMessage("handler error"); + + StepVerifier.create(authTransport.closeGracefully()).verifyComplete(); + } + + } + + private static Predicate authorizationError(int httpStatus) { + return throwable -> throwable instanceof McpHttpClientTransportAuthorizationException + && throwable.getMessage().contains("Authorization error") + && ((McpHttpClientTransportAuthorizationException) throwable).getResponseInfo() + .statusCode() == httpStatus; + } + + } + private McpSchema.JSONRPCRequest createTestRequestMessage() { var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, McpSchema.ClientCapabilities.builder().roots(true).build(), diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java index 491c2d4ed..3d40453a3 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.server; +import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; + import java.time.Duration; import java.util.List; import java.util.Map; @@ -48,7 +50,6 @@ import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; -import static io.modelcontextprotocol.util.ToolsUtils.EMPTY_JSON_SCHEMA; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import static org.assertj.core.api.Assertions.assertThat; diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/ToolInputValidationIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/ToolInputValidationIntegrationTests.java new file mode 100644 index 000000000..13bcbc571 --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/ToolInputValidationIntegrationTests.java @@ -0,0 +1,254 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for tool input validation against JSON schema. Validates that input validation + * errors are returned as Tool Execution Errors (isError=true) rather than Protocol + * Errors, per MCP specification. + * + * @author Andrei Shakirin + */ +@Timeout(15) +class ToolInputValidationIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final String TOOL_NAME = "test-tool"; + + private static final McpSchema.JsonSchema INPUT_SCHEMA = new McpSchema.JsonSchema("object", + Map.of("name", Map.of("type", "string"), "age", Map.of("type", "integer", "minimum", 0)), + List.of("name", "age"), null, null, null); + + private static final McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = ( + r) -> McpTransportContext.create(Map.of("important", "value")); + + private HttpServletStreamableServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + static Stream validInputTestCases() { + return Stream.of( + // serverType, validationEnabled, inputArgs, expectedOutput + Arguments.of("sync", true, Map.of("name", "Alice", "age", 30), "Hello Alice, age 30"), + Arguments.of("async", true, Map.of("name", "Bob", "age", 25), "Hello Bob, age 25"), + Arguments.of("sync", false, Map.of("name", "Alice", "age", 30), "Hello Alice, age 30"), + Arguments.of("async", false, Map.of("name", "Bob", "age", 25), "Hello Bob, age 25")); + } + + static Stream invalidInputTestCases() { + return Stream.of( + // serverType, inputArgs, expectedErrorSubstring + Arguments.of("sync", Map.of("name", "Alice"), "age"), // missing required + Arguments.of("async", Map.of("name", "Bob", "age", -10), "minimum")); // invalid + // value + } + + private final McpClient.SyncSpec clientBuilder = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT).endpoint(MESSAGE_ENDPOINT).build()) + .requestTimeout(Duration.ofSeconds(10)); + + @BeforeEach + public void before() { + mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() + .mcpEndpoint(MESSAGE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + private McpServerFeatures.SyncToolSpecification createSyncTool() { + Tool tool = Tool.builder() + .name(TOOL_NAME) + .description("Test tool with schema") + .inputSchema(INPUT_SCHEMA) + .build(); + + return McpServerFeatures.SyncToolSpecification.builder().tool(tool).callHandler((exchange, request) -> { + String name = (String) request.arguments().get("name"); + Integer age = ((Number) request.arguments().get("age")).intValue(); + return CallToolResult.builder() + .content(List.of(new TextContent("Hello " + name + ", age " + age))) + .isError(false) + .build(); + }).build(); + } + + private McpServerFeatures.AsyncToolSpecification createAsyncTool() { + Tool tool = Tool.builder() + .name(TOOL_NAME) + .description("Test tool with schema") + .inputSchema(INPUT_SCHEMA) + .build(); + + return McpServerFeatures.AsyncToolSpecification.builder().tool(tool).callHandler((exchange, request) -> { + String name = (String) request.arguments().get("name"); + Integer age = ((Number) request.arguments().get("age")).intValue(); + return Mono.just(CallToolResult.builder() + .content(List.of(new TextContent("Hello " + name + ", age " + age))) + .isError(false) + .build()); + }).build(); + } + + @ParameterizedTest(name = "{0} server, validation={1}") + @MethodSource("validInputTestCases") + void validInput_shouldSucceed(String serverType, boolean validationEnabled, Map input, + String expectedOutput) { + Object server = createServer(serverType, validationEnabled); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("test-client", "1.0.0")).build()) { + client.initialize(); + CallToolResult result = client.callTool(new CallToolRequest(TOOL_NAME, input)); + + assertThat(result.isError()).isFalse(); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedOutput); + } + finally { + closeServer(server, serverType); + } + } + + @ParameterizedTest(name = "{0} server, input={1}") + @MethodSource("invalidInputTestCases") + void invalidInput_withDefaultValidation_shouldReturnToolError(String serverType, Map input, + String expectedErrorSubstring) { + Object server = createServerWithDefaultValidation(serverType); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("test-client", "1.0.0")).build()) { + client.initialize(); + CallToolResult result = client.callTool(new CallToolRequest(TOOL_NAME, input)); + + assertThat(result.isError()).isTrue(); + String errorMessage = ((TextContent) result.content().get(0)).text(); + assertThat(errorMessage).containsIgnoringCase(expectedErrorSubstring); + } + finally { + closeServer(server, serverType); + } + } + + @ParameterizedTest(name = "{0} server, input={1}") + @MethodSource("invalidInputTestCases") + void invalidInput_withValidationDisabled_shouldSucceed(String serverType, Map input, + String ignored) { + Object server = createServer(serverType, false); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("test-client", "1.0.0")).build()) { + client.initialize(); + // Invalid input should pass through when validation is disabled + // The tool handler will fail, but that's expected - we're testing validation + // is skipped + try { + client.callTool(new CallToolRequest(TOOL_NAME, input)); + } + catch (Exception e) { + // Expected - tool handler fails on invalid input, but validation didn't + // block it + assertThat(e.getMessage()).doesNotContainIgnoringCase("validation"); + } + } + finally { + closeServer(server, serverType); + } + } + + private Object createServerWithDefaultValidation(String serverType) { + if ("sync".equals(serverType)) { + return prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").tools(createSyncTool()).build(); + } + else { + return prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(createAsyncTool()).build(); + } + } + + private Object createServer(String serverType, boolean validationEnabled) { + if ("sync".equals(serverType)) { + return prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .validateToolInputs(validationEnabled) + .tools(createSyncTool()) + .build(); + } + else { + return prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .validateToolInputs(validationEnabled) + .tools(createAsyncTool()) + .build(); + } + } + + private void closeServer(Object server, String serverType) { + if ("async".equals(serverType)) { + ((McpAsyncServer) server).closeGracefully().block(); + } + else { + ((McpSyncServer) server).close(); + } + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index 942e0a6e2..09529f2e0 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -21,6 +21,7 @@ import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; +import io.modelcontextprotocol.json.TypeRef; import net.javacrumbs.jsonunit.core.Option; /** @@ -713,13 +714,15 @@ void testJsonSchema() throws Exception { """; // Deserialize the original string to a JsonSchema object - McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + Map schema = JSON_MAPPER.readValue(schemaJson, new TypeRef>() { + }); // Serialize the object back to a string String serialized = JSON_MAPPER.writeValueAsString(schema); // Deserialize again - McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); + Map deserialized = JSON_MAPPER.readValue(serialized, new TypeRef>() { + }); // Serialize one more time and compare with the first serialization String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); @@ -756,13 +759,15 @@ void testJsonSchemaWithDefinitions() throws Exception { """; // Deserialize the original string to a JsonSchema object - McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + Map schema = JSON_MAPPER.readValue(schemaJson, new TypeRef>() { + }); // Serialize the object back to a string String serialized = JSON_MAPPER.writeValueAsString(schema); // Deserialize again - McpSchema.JsonSchema deserialized = JSON_MAPPER.readValue(serialized, McpSchema.JsonSchema.class); + Map deserialized = JSON_MAPPER.readValue(serialized, new TypeRef>() { + }); // Serialize one more time and compare with the first serialization String serializedAgain = JSON_MAPPER.writeValueAsString(deserialized); @@ -845,8 +850,11 @@ void testToolWithComplexSchema() throws Exception { assertThatJson(serializedAgain).when(Option.IGNORING_ARRAY_ORDER).isEqualTo(json(serialized)); // Just verify the basic structure was preserved - assertThat(deserializedTool.inputSchema().defs()).isNotNull(); - assertThat(deserializedTool.inputSchema().defs()).containsKey("Address"); + assertThat(deserializedTool.inputSchema()).containsKey("$defs") + .extractingByKey("$defs") + .isNotNull() + .asInstanceOf(InstanceOfAssertFactories.MAP) + .containsKey("Address"); } @Test @@ -866,14 +874,14 @@ void testToolWithMeta() throws Exception { } """; - McpSchema.JsonSchema schema = JSON_MAPPER.readValue(schemaJson, McpSchema.JsonSchema.class); + Map inputSchema = Map.of("inputSchema", schemaJson); Map meta = Map.of("metaKey", "metaValue"); McpSchema.Tool tool = McpSchema.Tool.builder() .name("addressTool") .title("addressTool") .description("Handles addresses") - .inputSchema(schema) + .inputSchema(inputSchema) .meta(meta) .build(); @@ -1114,7 +1122,7 @@ void testToolDeserialization() throws Exception { assertThat(tool.name()).isEqualTo("test-tool"); assertThat(tool.description()).isEqualTo("A test tool"); assertThat(tool.inputSchema()).isNotNull(); - assertThat(tool.inputSchema().type()).isEqualTo("object"); + assertThat(tool.inputSchema().get("type")).isEqualTo("object"); assertThat(tool.outputSchema()).isNotNull(); assertThat(tool.outputSchema()).containsKey("type"); assertThat(tool.outputSchema().get("type")).isEqualTo("object"); diff --git a/mcp/pom.xml b/mcp/pom.xml index 937974228..16fca0ba4 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT mcp jar @@ -16,8 +16,8 @@ https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git @@ -25,15 +25,15 @@ io.modelcontextprotocol.sdk mcp-json-jackson3 - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-core - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT - + \ No newline at end of file diff --git a/pom.xml b/pom.xml index b1eedd38e..d738e26e6 100644 --- a/pom.xml +++ b/pom.xml @@ -6,15 +6,15 @@ io.modelcontextprotocol.sdk mcp-parent - 1.1.0-SNAPSHOT + 2.0.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk https://github.com/modelcontextprotocol/java-sdk - git://github.com/modelcontextprotocol/java-sdk.git - git@github.com/modelcontextprotocol/java-sdk.git + scm:git:git://github.com/modelcontextprotocol/java-sdk.git + scm:git:ssh://git@github.com/modelcontextprotocol/java-sdk.git Java SDK MCP Parent @@ -29,7 +29,7 @@ MIT License - http://www.opensource.org/licenses/mit-license.php + https://www.opensource.org/licenses/mit-license.php @@ -57,7 +57,7 @@ 17 17 17 - + 3.27.6 6.0.2 @@ -105,11 +105,11 @@ mcp-bom mcp - mcp-core - mcp-json-jackson2 - mcp-json-jackson3 + mcp-core + mcp-json-jackson2 + mcp-json-jackson3 mcp-test - conformance-tests + conformance-tests @@ -329,9 +329,9 @@ true central - - mcp-parent,conformance-tests,client-jdk-http-client,client-spring-http-client,server-servlet - + + mcp-parent,conformance-tests,client-jdk-http-client,client-spring-http-client,server-servlet + true @@ -387,4 +387,4 @@ - + \ No newline at end of file