From e6499595cfc662f602effd54bd4af2384bf74efb Mon Sep 17 00:00:00 2001 From: HTHou Date: Thu, 25 Jun 2026 10:41:51 +0800 Subject: [PATCH 1/2] Support Thrift client mutual TLS --- .../env/cluster/config/MppCommonConfig.java | 6 + .../cluster/config/MppSharedCommonConfig.java | 7 + .../iotdb/it/env/cluster/env/AbstractEnv.java | 29 ++ .../env/remote/config/RemoteCommonConfig.java | 5 + .../apache/iotdb/itbase/env/CommonConfig.java | 2 + .../session/it/IoTDBClientMutualSSLIT.java | 284 ++++++++++++++++++ .../org/apache/iotdb/cli/AbstractCli.java | 44 +++ .../main/java/org/apache/iotdb/cli/Cli.java | 23 +- .../apache/iotdb/tool/common/Constants.java | 8 + .../apache/iotdb/tool/common/OptionsUtil.java | 20 ++ .../iotdb/tool/data/AbstractDataTool.java | 21 ++ .../iotdb/tool/schema/AbstractSchemaTool.java | 18 ++ .../java/org/apache/iotdb/jdbc/Config.java | 4 + .../apache/iotdb/jdbc/IoTDBConnection.java | 4 +- .../iotdb/jdbc/IoTDBConnectionParams.java | 18 ++ .../java/org/apache/iotdb/jdbc/Utils.java | 6 + .../java/org/apache/iotdb/jdbc/UtilsTest.java | 8 + .../iotdb/rpc/BaseRpcTransportFactory.java | 16 +- .../iotdb/session/AbstractSessionBuilder.java | 2 + .../apache/iotdb/session/NodesSupplier.java | 12 + .../org/apache/iotdb/session/Session.java | 16 + .../iotdb/session/SessionConnection.java | 18 +- .../iotdb/session/TableSessionBuilder.java | 24 ++ .../iotdb/session/ThriftConnection.java | 6 +- .../iotdb/session/pool/SessionPool.java | 22 ++ .../session/pool/TableSessionPoolBuilder.java | 24 ++ .../iotdb/db/service/ExternalRPCService.java | 80 +++-- .../conf/iotdb-system.properties.template | 7 + .../iotdb/commons/conf/CommonConfig.java | 11 + .../iotdb/commons/conf/CommonDescriptor.java | 4 + 30 files changed, 707 insertions(+), 42 deletions(-) create mode 100644 integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBClientMutualSSLIT.java diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppCommonConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppCommonConfig.java index 4dd7f8571627d..bd2fb84e0b600 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppCommonConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppCommonConfig.java @@ -632,6 +632,12 @@ public CommonConfig setEnableThriftClientSSL(boolean enableThriftClientSSL) { return this; } + @Override + public CommonConfig setThriftSSLClientAuth(boolean thriftSSLClientAuth) { + setProperty("thrift_ssl_client_auth", String.valueOf(thriftSSLClientAuth)); + return this; + } + @Override public CommonConfig setEnableInternalSSL(boolean enableInternalSSL) { setProperty("enable_internal_ssl", String.valueOf(enableInternalSSL)); diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppSharedCommonConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppSharedCommonConfig.java index 8e0398de30973..b1c2a4f8d6be5 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppSharedCommonConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppSharedCommonConfig.java @@ -658,6 +658,13 @@ public CommonConfig setEnableThriftClientSSL(boolean enableThriftClientSSL) { return this; } + @Override + public CommonConfig setThriftSSLClientAuth(boolean thriftSSLClientAuth) { + cnConfig.setThriftSSLClientAuth(thriftSSLClientAuth); + dnConfig.setThriftSSLClientAuth(thriftSSLClientAuth); + return this; + } + @Override public CommonConfig setEnableInternalSSL(boolean enableInternalSSL) { cnConfig.setEnableInternalSSL(enableInternalSSL); diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/env/AbstractEnv.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/env/AbstractEnv.java index 342b4aba26d8c..e5d4e19834d95 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/env/AbstractEnv.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/env/AbstractEnv.java @@ -692,6 +692,10 @@ private boolean isThriftClientSSLEnabled() { return Boolean.parseBoolean(getDataNodeCommonConfigProperty("enable_thrift_ssl", "false")); } + private boolean isThriftSSLClientAuthEnabled() { + return Boolean.parseBoolean(getDataNodeCommonConfigProperty("thrift_ssl_client_auth", "false")); + } + private String getDataNodeCommonConfigProperty(final String key, final String defaultValue) { return ((MppCommonConfig) clusterConfig.getDataNodeCommonConfig()) .getProperty(key, defaultValue); @@ -711,6 +715,11 @@ private Properties constructConnectionProperties( putIfPresent( info, Config.TRUST_STORE_PWD, getDataNodeCommonConfigProperty("trust_store_pwd", "")); putIfPresent(info, Config.SSL_PROTOCOL, getClientSSLProtocol()); + if (isThriftSSLClientAuthEnabled()) { + putIfPresent(info, Config.KEY_STORE, getDataNodeCommonConfigProperty("key_store_path", "")); + putIfPresent( + info, Config.KEY_STORE_PWD, getDataNodeCommonConfigProperty("key_store_pwd", "")); + } } return info; } @@ -728,6 +737,11 @@ private Session.Builder configureClientSSL(final Session.Builder builder) { .trustStore(getDataNodeCommonConfigProperty("trust_store_path", "")) .trustStorePwd(getDataNodeCommonConfigProperty("trust_store_pwd", "")) .sslProtocol(getClientSSLProtocol()); + if (isThriftSSLClientAuthEnabled()) { + builder + .keyStore(getDataNodeCommonConfigProperty("key_store_path", "")) + .keyStorePwd(getDataNodeCommonConfigProperty("key_store_pwd", "")); + } } return builder; } @@ -739,6 +753,11 @@ private TableSessionBuilder configureClientSSL(final TableSessionBuilder builder .trustStore(getDataNodeCommonConfigProperty("trust_store_path", "")) .trustStorePwd(getDataNodeCommonConfigProperty("trust_store_pwd", "")) .sslProtocol(getClientSSLProtocol()); + if (isThriftSSLClientAuthEnabled()) { + builder + .keyStore(getDataNodeCommonConfigProperty("key_store_path", "")) + .keyStorePwd(getDataNodeCommonConfigProperty("key_store_pwd", "")); + } } return builder; } @@ -750,6 +769,11 @@ private SessionPool.Builder configureClientSSL(final SessionPool.Builder builder .trustStore(getDataNodeCommonConfigProperty("trust_store_path", "")) .trustStorePwd(getDataNodeCommonConfigProperty("trust_store_pwd", "")) .sslProtocol(getClientSSLProtocol()); + if (isThriftSSLClientAuthEnabled()) { + builder + .keyStore(getDataNodeCommonConfigProperty("key_store_path", "")) + .keyStorePwd(getDataNodeCommonConfigProperty("key_store_pwd", "")); + } } return builder; } @@ -761,6 +785,11 @@ private TableSessionPoolBuilder configureClientSSL(final TableSessionPoolBuilder .trustStore(getDataNodeCommonConfigProperty("trust_store_path", "")) .trustStorePwd(getDataNodeCommonConfigProperty("trust_store_pwd", "")) .sslProtocol(getClientSSLProtocol()); + if (isThriftSSLClientAuthEnabled()) { + builder + .keyStore(getDataNodeCommonConfigProperty("key_store_path", "")) + .keyStorePwd(getDataNodeCommonConfigProperty("key_store_pwd", "")); + } } return builder; } diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteCommonConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteCommonConfig.java index a2d10c01009da..ba1f7106dd647 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteCommonConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteCommonConfig.java @@ -447,6 +447,11 @@ public CommonConfig setEnableThriftClientSSL(boolean enableThriftClientSSL) { return this; } + @Override + public CommonConfig setThriftSSLClientAuth(boolean thriftSSLClientAuth) { + return this; + } + @Override public CommonConfig setSubscriptionPrefetchTsFileBatchMaxDelayInMs( int subscriptionPrefetchTsFileBatchMaxDelayInMs) { diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/env/CommonConfig.java b/integration-test/src/main/java/org/apache/iotdb/itbase/env/CommonConfig.java index 5b51a6a8cf916..5a7a004fa88a2 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/env/CommonConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/env/CommonConfig.java @@ -203,6 +203,8 @@ default CommonConfig setDefaultDatabaseLevel(int defaultDatabaseLevel) { CommonConfig setEnableThriftClientSSL(boolean enableThriftClientSSL); + CommonConfig setThriftSSLClientAuth(boolean thriftSSLClientAuth); + CommonConfig setEnableInternalSSL(boolean enableInternalSSL); CommonConfig setKeyStorePath(String keyStorePath); diff --git a/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBClientMutualSSLIT.java b/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBClientMutualSSLIT.java new file mode 100644 index 0000000000000..0feb9ba6f30d7 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBClientMutualSSLIT.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iotdb.session.it; + +import org.apache.iotdb.isession.ISession; +import org.apache.iotdb.isession.ITableSession; +import org.apache.iotdb.isession.SessionConfig; +import org.apache.iotdb.isession.SessionDataSet; +import org.apache.iotdb.isession.pool.ISessionPool; +import org.apache.iotdb.isession.pool.SessionDataSetWrapper; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.env.cluster.node.DataNodeWrapper; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.ClusterIT; +import org.apache.iotdb.itbase.category.LocalStandaloneIT; +import org.apache.iotdb.itbase.category.TableClusterIT; +import org.apache.iotdb.itbase.category.TableLocalStandaloneIT; +import org.apache.iotdb.jdbc.Config; +import org.apache.iotdb.rpc.IoTDBConnectionException; +import org.apache.iotdb.session.Session; +import org.apache.iotdb.session.TableSessionBuilder; +import org.apache.iotdb.session.pool.SessionPool; + +import org.apache.tsfile.read.common.RowRecord; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.io.File; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Collections; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(IoTDBTestRunner.class) +@Category({ + LocalStandaloneIT.class, + ClusterIT.class, + TableLocalStandaloneIT.class, + TableClusterIT.class +}) +public class IoTDBClientMutualSSLIT { + + private static final String STORE_PASSWORD = "thrift"; + private static String keyDir; + + @BeforeClass + public static void setUp() throws Exception { + keyDir = + System.getProperty("user.dir") + + File.separator + + "target" + + File.separator + + "test-classes" + + File.separator; + + EnvFactory.getEnv() + .getConfig() + .getCommonConfig() + .setEnableThriftClientSSL(true) + .setThriftSSLClientAuth(true) + .setKeyStorePath(keyStorePath()) + .setKeyStorePwd(STORE_PASSWORD) + .setTrustStorePath(trustStorePath()) + .setTrustStorePwd(STORE_PASSWORD) + .setSslProtocol(SessionConfig.DEFAULT_SSL_PROTOCOL); + EnvFactory.getEnv().initClusterEnvironment(); + } + + @After + public void tearDown() { + try (ISession session = newMutualSSLSession()) { + deleteTreeDatabase(session, "root.client_mtls_tree"); + deleteTreeDatabase(session, "root.client_mtls_pool"); + deleteTreeDatabase(session, "root.client_mtls_jdbc"); + } catch (Exception ignored) { + // ignored + } + try (ITableSession session = newMutualSSLTableSession()) { + session.executeNonQueryStatement("DROP DATABASE IF EXISTS client_mtls_table"); + } catch (Exception ignored) { + // ignored + } + } + + @AfterClass + public static void tearDownClass() { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void sslClientWithoutKeyStoreCanNotConnectWhenClientAuthRequired() { + final DataNodeWrapper dataNode = EnvFactory.getEnv().getDataNodeWrapper(0); + final Session session = + new Session.Builder() + .host(dataNode.getIp()) + .port(dataNode.getPort()) + .useSSL(true) + .trustStore(trustStorePath()) + .trustStorePwd(STORE_PASSWORD) + .build(); + + assertThrows(IoTDBConnectionException.class, session::open); + } + + @Test + public void treeSessionCanConnectWithMutualSSL() throws Exception { + try (ISession session = newMutualSSLSession()) { + session.executeNonQueryStatement("CREATE DATABASE root.client_mtls_tree"); + session.executeNonQueryStatement( + "CREATE TIMESERIES root.client_mtls_tree.d1.s1 WITH DATATYPE=INT32, ENCODING=PLAIN"); + session.executeNonQueryStatement( + "INSERT INTO root.client_mtls_tree.d1(time, s1) VALUES (1, 11)"); + + try (SessionDataSet dataSet = + session.executeQueryStatement("SELECT s1 FROM root.client_mtls_tree.d1")) { + assertTrue(dataSet.hasNext()); + final RowRecord record = dataSet.next(); + assertEquals(1L, record.getTimestamp()); + assertEquals(11, record.getFields().get(0).getIntV()); + assertFalse(dataSet.hasNext()); + } + } + } + + @Test + public void sessionPoolCanConnectWithMutualSSL() throws Exception { + final DataNodeWrapper dataNode = EnvFactory.getEnv().getDataNodeWrapper(0); + final ISessionPool pool = + new SessionPool.Builder() + .nodeUrls(Collections.singletonList(dataNode.getIpAndPortString())) + .maxSize(1) + .useSSL(true) + .trustStore(trustStorePath()) + .trustStorePwd(STORE_PASSWORD) + .keyStore(keyStorePath()) + .keyStorePwd(STORE_PASSWORD) + .build(); + try { + pool.executeNonQueryStatement("CREATE DATABASE root.client_mtls_pool"); + pool.executeNonQueryStatement( + "CREATE TIMESERIES root.client_mtls_pool.d1.s1 WITH DATATYPE=INT32, ENCODING=PLAIN"); + pool.executeNonQueryStatement( + "INSERT INTO root.client_mtls_pool.d1(time, s1) VALUES (1, 22)"); + + try (SessionDataSetWrapper dataSet = + pool.executeQueryStatement("SELECT s1 FROM root.client_mtls_pool.d1")) { + assertTrue(dataSet.hasNext()); + final RowRecord record = dataSet.next(); + assertEquals(1L, record.getTimestamp()); + assertEquals(22, record.getFields().get(0).getIntV()); + assertFalse(dataSet.hasNext()); + } + } finally { + pool.close(); + } + } + + @Test + public void tableSessionCanConnectWithMutualSSL() throws Exception { + try (ITableSession session = newMutualSSLTableSession()) { + session.executeNonQueryStatement("CREATE DATABASE IF NOT EXISTS client_mtls_table"); + session.executeNonQueryStatement("USE client_mtls_table"); + session.executeNonQueryStatement( + "CREATE TABLE IF NOT EXISTS mtls_table (tag1 STRING TAG, value INT32 FIELD)"); + session.executeNonQueryStatement( + "INSERT INTO mtls_table(time, tag1, value) VALUES (1, 'tag1', 33)"); + + try (SessionDataSet dataSet = + session.executeQueryStatement("SELECT time, value FROM mtls_table WHERE tag1 = 'tag1'")) { + assertTrue(dataSet.hasNext()); + final RowRecord record = dataSet.next(); + assertEquals(1L, record.getFields().get(0).getLongV()); + assertEquals(33, record.getFields().get(1).getIntV()); + assertFalse(dataSet.hasNext()); + } + } + } + + @Test + public void jdbcCanConnectWithMutualSSL() throws Exception { + final DataNodeWrapper dataNode = EnvFactory.getEnv().getDataNodeWrapper(0); + + try (Connection connection = + DriverManager.getConnection( + Config.IOTDB_URL_PREFIX + dataNode.getIpAndPortString(), mutualSSLProperties()); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE root.client_mtls_jdbc"); + statement.execute( + "CREATE TIMESERIES root.client_mtls_jdbc.d1.s1 WITH DATATYPE=INT32, ENCODING=PLAIN"); + statement.execute("INSERT INTO root.client_mtls_jdbc.d1(time, s1) VALUES (1, 44)"); + + try (ResultSet resultSet = + statement.executeQuery("SELECT s1 FROM root.client_mtls_jdbc.d1")) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(1)); + assertEquals(44, resultSet.getInt(2)); + assertFalse(resultSet.next()); + } + } + } + + private static ISession newMutualSSLSession() throws IoTDBConnectionException { + final DataNodeWrapper dataNode = EnvFactory.getEnv().getDataNodeWrapper(0); + final Session session = + new Session.Builder() + .host(dataNode.getIp()) + .port(dataNode.getPort()) + .useSSL(true) + .trustStore(trustStorePath()) + .trustStorePwd(STORE_PASSWORD) + .keyStore(keyStorePath()) + .keyStorePwd(STORE_PASSWORD) + .build(); + session.open(); + return session; + } + + private static ITableSession newMutualSSLTableSession() throws IoTDBConnectionException { + final DataNodeWrapper dataNode = EnvFactory.getEnv().getDataNodeWrapper(0); + return new TableSessionBuilder() + .nodeUrls(Collections.singletonList(dataNode.getIpAndPortString())) + .useSSL(true) + .trustStore(trustStorePath()) + .trustStorePwd(STORE_PASSWORD) + .keyStore(keyStorePath()) + .keyStorePwd(STORE_PASSWORD) + .build(); + } + + private static Properties mutualSSLProperties() { + final Properties properties = new Properties(); + properties.put("user", SessionConfig.DEFAULT_USER); + properties.put("password", SessionConfig.DEFAULT_PASSWORD); + properties.put(Config.USE_SSL, Boolean.TRUE.toString()); + properties.put(Config.TRUST_STORE, trustStorePath()); + properties.put(Config.TRUST_STORE_PWD, STORE_PASSWORD); + properties.put(Config.KEY_STORE, keyStorePath()); + properties.put(Config.KEY_STORE_PWD, STORE_PASSWORD); + return properties; + } + + private void deleteTreeDatabase(final ISession session, final String database) { + try { + session.executeNonQueryStatement("DELETE DATABASE " + database); + } catch (Exception ignored) { + // ignored + } + } + + private static String keyStorePath() { + return keyDir + "test-keystore"; + } + + private static String trustStorePath() { + return keyDir + "test-truststore"; + } +} diff --git a/iotdb-client/cli/src/main/java/org/apache/iotdb/cli/AbstractCli.java b/iotdb-client/cli/src/main/java/org/apache/iotdb/cli/AbstractCli.java index 017f0a1a6ca2b..088131c257026 100644 --- a/iotdb-client/cli/src/main/java/org/apache/iotdb/cli/AbstractCli.java +++ b/iotdb-client/cli/src/main/java/org/apache/iotdb/cli/AbstractCli.java @@ -74,8 +74,10 @@ public abstract class AbstractCli { static final String USE_SSL_ARGS = "usessl"; static final String TRUST_STORE_ARGS = "ts"; + static final String KEY_STORE_ARGS = "ks"; static final String TRUST_STORE_PWD_ARGS = "tpw"; + static final String KEY_STORE_PWD_ARGS = "kpw"; static final String SSL_PROTOCOL_ARGS = "ssl_protocol"; @@ -83,8 +85,10 @@ public abstract class AbstractCli { private static final String USE_SSL = "use_ssl"; private static final String TRUST_STORE = "trust_store"; + private static final String KEY_STORE = "key_store"; private static final String TRUST_STORE_PWD = "trust_store_pwd"; + private static final String KEY_STORE_PWD = "key_store_pwd"; private static final String SSL_PROTOCOL = "ssl_protocol"; private static final String NULL = "null"; @@ -135,6 +139,8 @@ public abstract class AbstractCli { static String trustStore; // TODO: Make non-static static String trustStorePwd; + static String keyStore; + static String keyStorePwd; static String sslProtocol; static String execute; @@ -160,6 +166,8 @@ static void init() { keywordSet.add("-" + USE_SSL_ARGS); keywordSet.add("-" + TRUST_STORE_ARGS); keywordSet.add("-" + TRUST_STORE_PWD_ARGS); + keywordSet.add("-" + KEY_STORE_ARGS); + keywordSet.add("-" + KEY_STORE_PWD_ARGS); keywordSet.add("-" + SSL_PROTOCOL_ARGS); keywordSet.add("-" + EXECUTE_ARGS); keywordSet.add("-" + ISO8601_ARGS); @@ -219,6 +227,42 @@ static Options createOptions() { .build(); options.addOption(useSSL); + Option trustStore = + Option.builder(TRUST_STORE_ARGS) + .longOpt(TRUST_STORE) + .argName(TRUST_STORE) + .hasArg() + .desc("Trust store. (optional)") + .build(); + options.addOption(trustStore); + + Option trustStorePwd = + Option.builder(TRUST_STORE_PWD_ARGS) + .longOpt(TRUST_STORE_PWD) + .argName(TRUST_STORE_PWD) + .hasArg() + .desc("Trust store password. (optional)") + .build(); + options.addOption(trustStorePwd); + + Option keyStore = + Option.builder(KEY_STORE_ARGS) + .longOpt(KEY_STORE) + .argName(KEY_STORE) + .hasArg() + .desc("Key store for mutual SSL. (optional)") + .build(); + options.addOption(keyStore); + + Option keyStorePwd = + Option.builder(KEY_STORE_PWD_ARGS) + .longOpt(KEY_STORE_PWD) + .argName(KEY_STORE_PWD) + .hasArg() + .desc("Key store password for mutual SSL. (optional)") + .build(); + options.addOption(keyStorePwd); + Option sslProtocol = Option.builder(SSL_PROTOCOL_ARGS) .longOpt(SSL_PROTOCOL) diff --git a/iotdb-client/cli/src/main/java/org/apache/iotdb/cli/Cli.java b/iotdb-client/cli/src/main/java/org/apache/iotdb/cli/Cli.java index a9e911be6edc3..63c3c5d11539d 100644 --- a/iotdb-client/cli/src/main/java/org/apache/iotdb/cli/Cli.java +++ b/iotdb-client/cli/src/main/java/org/apache/iotdb/cli/Cli.java @@ -111,6 +111,12 @@ private static void constructProperties() { info.setProperty("use_ssl", useSsl); info.setProperty("trust_store", trustStore); info.setProperty("trust_store_pwd", trustStorePwd); + if (keyStore != null) { + info.setProperty(Config.KEY_STORE, keyStore); + } + if (keyStorePwd != null) { + info.setProperty(Config.KEY_STORE_PWD, keyStorePwd); + } if (sslProtocol != null) { info.setProperty(Config.SSL_PROTOCOL, sslProtocol); } @@ -164,8 +170,21 @@ private static void serve(CliContext ctx) { useSsl = commandLine.getOptionValue(USE_SSL_ARGS); sslProtocol = commandLine.getOptionValue(SSL_PROTOCOL_ARGS); if (Boolean.parseBoolean(useSsl)) { - trustStore = ctx.getLineReader().readLine("please input your trust_store:", '\0'); - trustStorePwd = ctx.getLineReader().readLine("please input your trust_store_pwd:", '\0'); + trustStore = commandLine.getOptionValue(TRUST_STORE_ARGS); + if (trustStore == null) { + trustStore = ctx.getLineReader().readLine("please input your trust_store:", '\0'); + } + trustStorePwd = commandLine.getOptionValue(TRUST_STORE_PWD_ARGS); + if (trustStorePwd == null) { + trustStorePwd = ctx.getLineReader().readLine("please input your trust_store_pwd:", '\0'); + } + keyStore = commandLine.getOptionValue(KEY_STORE_ARGS); + keyStorePwd = commandLine.getOptionValue(KEY_STORE_PWD_ARGS); + if (keyStore != null && keyStorePwd == null) { + keyStorePwd = ctx.getLineReader().readLine("please input your key_store_pwd:", '\0'); + } else if (keyStore == null && keyStorePwd != null) { + keyStore = ctx.getLineReader().readLine("please input your key_store:", '\0'); + } } password = commandLine.getOptionValue(PW_ARGS); if (password == null) { diff --git a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/Constants.java b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/Constants.java index a106326095756..ed2d96dc8606b 100644 --- a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/Constants.java +++ b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/Constants.java @@ -68,6 +68,14 @@ public class Constants { public static final String TRUST_STORE_PWD_NAME = "trust_store_password"; public static final String TRUST_STORE_PWD_DESC = "Trust store password. (optional)"; + public static final String KEY_STORE_ARGS = "ks"; + public static final String KEY_STORE_NAME = "key_store"; + public static final String KEY_STORE_DESC = "Key store for mutual SSL. (optional)"; + + public static final String KEY_STORE_PWD_ARGS = "kpw"; + public static final String KEY_STORE_PWD_NAME = "key_store_password"; + public static final String KEY_STORE_PWD_DESC = "Key store password for mutual SSL. (optional)"; + public static final String SSL_PROTOCOL_ARGS = "ssl_protocol"; public static final String SSL_PROTOCOL_NAME = "ssl_protocol"; public static final String SSL_PROTOCOL_DESC = "SSL protocol. (optional)"; diff --git a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/OptionsUtil.java b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/OptionsUtil.java index 1f79b303f6e4e..e7ca2317be6b0 100644 --- a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/OptionsUtil.java +++ b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/common/OptionsUtil.java @@ -133,6 +133,26 @@ public static Options createCommonOptions(Options options) { .build(); options.addOption(opTrustStorePwd); + Option opKeyStore = + Option.builder(KEY_STORE_ARGS) + .longOpt(KEY_STORE_NAME) + .optionalArg(true) + .argName(KEY_STORE_NAME) + .hasArg() + .desc(KEY_STORE_DESC) + .build(); + options.addOption(opKeyStore); + + Option opKeyStorePwd = + Option.builder(KEY_STORE_PWD_ARGS) + .longOpt(KEY_STORE_PWD_NAME) + .optionalArg(true) + .argName(KEY_STORE_PWD_NAME) + .hasArg() + .desc(KEY_STORE_PWD_DESC) + .build(); + options.addOption(opKeyStorePwd); + Option opSslProtocol = Option.builder(SSL_PROTOCOL_ARGS) .longOpt(SSL_PROTOCOL_NAME) diff --git a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/data/AbstractDataTool.java b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/data/AbstractDataTool.java index 6ea79a5b58216..07a21109c203e 100644 --- a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/data/AbstractDataTool.java +++ b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/data/AbstractDataTool.java @@ -97,6 +97,8 @@ public abstract class AbstractDataTool { protected static Boolean useSsl; protected static String trustStore; protected static String trustStorePwd; + protected static String keyStore; + protected static String keyStorePwd; protected static String sslProtocol; protected static Boolean aligned; protected static String database; @@ -140,6 +142,9 @@ protected AbstractDataTool() {} protected static Session.Builder configureSsl(Session.Builder builder) { builder.useSSL(true).trustStore(trustStore).trustStorePwd(trustStorePwd); + if (keyStore != null) { + builder.keyStore(keyStore).keyStorePwd(keyStorePwd); + } if (sslProtocol != null) { builder.sslProtocol(sslProtocol); } @@ -148,6 +153,9 @@ protected static Session.Builder configureSsl(Session.Builder builder) { protected static SessionPool.Builder configureSsl(SessionPool.Builder builder) { builder.useSSL(true).trustStore(trustStore).trustStorePwd(trustStorePwd); + if (keyStore != null) { + builder.keyStore(keyStore).keyStorePwd(keyStorePwd); + } if (sslProtocol != null) { builder.sslProtocol(sslProtocol); } @@ -156,6 +164,9 @@ protected static SessionPool.Builder configureSsl(SessionPool.Builder builder) { protected static TableSessionBuilder configureSsl(TableSessionBuilder builder) { builder.useSSL(true).trustStore(trustStore).trustStorePwd(trustStorePwd); + if (keyStore != null) { + builder.keyStore(keyStore).keyStorePwd(keyStorePwd); + } if (sslProtocol != null) { builder.sslProtocol(sslProtocol); } @@ -164,6 +175,9 @@ protected static TableSessionBuilder configureSsl(TableSessionBuilder builder) { protected static TableSessionPoolBuilder configureSsl(TableSessionPoolBuilder builder) { builder.useSSL(true).trustStore(trustStore).trustStorePwd(trustStorePwd); + if (keyStore != null) { + builder.keyStore(keyStore).keyStorePwd(keyStorePwd); + } if (sslProtocol != null) { builder.sslProtocol(sslProtocol); } @@ -219,6 +233,13 @@ protected static void parseBasicParams(CommandLine commandLine) } else { trustStorePwd = cliCtx.getLineReader().readLine("please input your trust_store_pwd:", '\0'); } + keyStore = commandLine.getOptionValue(Constants.KEY_STORE_ARGS); + keyStorePwd = commandLine.getOptionValue(Constants.KEY_STORE_PWD_ARGS); + if (keyStore != null && keyStorePwd == null) { + keyStorePwd = cliCtx.getLineReader().readLine("please input your key_store_pwd:", '\0'); + } else if (keyStore == null && keyStorePwd != null) { + keyStore = cliCtx.getLineReader().readLine("please input your key_store:", '\0'); + } } boolean hasPw = commandLine.hasOption(Constants.PW_ARGS); if (hasPw) { diff --git a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/schema/AbstractSchemaTool.java b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/schema/AbstractSchemaTool.java index 2bee5f5d33151..6f0e533065f4b 100644 --- a/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/schema/AbstractSchemaTool.java +++ b/iotdb-client/cli/src/main/java/org/apache/iotdb/tool/schema/AbstractSchemaTool.java @@ -55,6 +55,8 @@ public abstract class AbstractSchemaTool { protected static Boolean useSsl; protected static String trustStore; protected static String trustStorePwd; + protected static String keyStore; + protected static String keyStorePwd; protected static String sslProtocol; protected static Session session; protected static String queryPath; @@ -78,6 +80,9 @@ protected AbstractSchemaTool() {} protected static Session.Builder configureSsl(Session.Builder builder) { builder.useSSL(true).trustStore(trustStore).trustStorePwd(trustStorePwd); + if (keyStore != null) { + builder.keyStore(keyStore).keyStorePwd(keyStorePwd); + } if (sslProtocol != null) { builder.sslProtocol(sslProtocol); } @@ -86,6 +91,9 @@ protected static Session.Builder configureSsl(Session.Builder builder) { protected static SessionPool.Builder configureSsl(SessionPool.Builder builder) { builder.useSSL(true).trustStore(trustStore).trustStorePwd(trustStorePwd); + if (keyStore != null) { + builder.keyStore(keyStore).keyStorePwd(keyStorePwd); + } if (sslProtocol != null) { builder.sslProtocol(sslProtocol); } @@ -94,6 +102,9 @@ protected static SessionPool.Builder configureSsl(SessionPool.Builder builder) { protected static TableSessionPoolBuilder configureSsl(TableSessionPoolBuilder builder) { builder.useSSL(true).trustStore(trustStore).trustStorePwd(trustStorePwd); + if (keyStore != null) { + builder.keyStore(keyStore).keyStorePwd(keyStorePwd); + } if (sslProtocol != null) { builder.sslProtocol(sslProtocol); } @@ -147,6 +158,13 @@ protected static void parseBasicParams(CommandLine commandLine) } else { trustStorePwd = cliCtx.getLineReader().readLine("please input your trust_store_pwd:", '\0'); } + keyStore = commandLine.getOptionValue(Constants.KEY_STORE_ARGS); + keyStorePwd = commandLine.getOptionValue(Constants.KEY_STORE_PWD_ARGS); + if (keyStore != null && keyStorePwd == null) { + keyStorePwd = cliCtx.getLineReader().readLine("please input your key_store_pwd:", '\0'); + } else if (keyStore == null && keyStorePwd != null) { + keyStore = cliCtx.getLineReader().readLine("please input your key_store:", '\0'); + } } boolean hasPw = commandLine.hasOption(Constants.PW_ARGS); if (hasPw) { diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/Config.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/Config.java index d79656cc8d295..54aca784d15a5 100644 --- a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/Config.java +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/Config.java @@ -82,6 +82,10 @@ private Config() { public static final String TRUST_STORE_PWD = "trust_store_pwd"; + public static final String KEY_STORE = "key_store"; + + public static final String KEY_STORE_PWD = "key_store_pwd"; + public static final String SSL_PROTOCOL = "ssl_protocol"; static final String DEFAULT_SSL_PROTOCOL = "TLS"; diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnection.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnection.java index 0dc81115dd01b..9a6dc91ceb4fe 100644 --- a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnection.java +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnection.java @@ -544,12 +544,14 @@ private void openTransport() throws TTransportException { if (params.isUseSSL()) { transport = - DeepCopyRpcTransportFactory.INSTANCE.getTransportWithSSLConfig( + DeepCopyRpcTransportFactory.INSTANCE.getTransport( params.getHost(), params.getPort(), getNetworkTimeout(), params.getTrustStore(), params.getTrustStorePwd(), + params.getKeyStore(), + params.getKeyStorePwd(), params.getSslProtocol()); } else { transport = diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnectionParams.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnectionParams.java index 8bf51379c5c2a..6c9a148517c0b 100644 --- a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnectionParams.java +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/IoTDBConnectionParams.java @@ -51,6 +51,8 @@ public class IoTDBConnectionParams { private boolean useSSL = false; private String trustStore; private String trustStorePwd; + private String keyStore; + private String keyStorePwd; private String sslProtocol = Config.DEFAULT_SSL_PROTOCOL; private String sqlDialect = TREE; @@ -185,6 +187,22 @@ public void setTrustStorePwd(String trustStorePwd) { this.trustStorePwd = trustStorePwd; } + public String getKeyStore() { + return keyStore; + } + + public void setKeyStore(String keyStore) { + this.keyStore = keyStore; + } + + public String getKeyStorePwd() { + return keyStorePwd; + } + + public void setKeyStorePwd(String keyStorePwd) { + this.keyStorePwd = keyStorePwd; + } + public String getSslProtocol() { return sslProtocol; } diff --git a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/Utils.java b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/Utils.java index ea205ef726757..cf4fcb7d05638 100644 --- a/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/Utils.java +++ b/iotdb-client/jdbc/src/main/java/org/apache/iotdb/jdbc/Utils.java @@ -138,6 +138,12 @@ static IoTDBConnectionParams parseUrl(String url, Properties info) throws IoTDBU if (info.containsKey(Config.TRUST_STORE_PWD)) { params.setTrustStorePwd(info.getProperty(Config.TRUST_STORE_PWD)); } + if (info.containsKey(Config.KEY_STORE)) { + params.setKeyStore(info.getProperty(Config.KEY_STORE)); + } + if (info.containsKey(Config.KEY_STORE_PWD)) { + params.setKeyStorePwd(info.getProperty(Config.KEY_STORE_PWD)); + } if (info.containsKey(Config.SSL_PROTOCOL)) { params.setSslProtocol(RpcSslUtils.normalizeProtocol(info.getProperty(Config.SSL_PROTOCOL))); } diff --git a/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/UtilsTest.java b/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/UtilsTest.java index d201a9b2d4569..0787886c03ac7 100644 --- a/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/UtilsTest.java +++ b/iotdb-client/jdbc/src/test/java/org/apache/iotdb/jdbc/UtilsTest.java @@ -163,11 +163,19 @@ public void testRpcCompress() throws IoTDBURLException { @Test public void testParseSslConfig() throws IoTDBURLException { Properties properties = new Properties(); + properties.setProperty(Config.TRUST_STORE, "/tmp/truststore.p12"); + properties.setProperty(Config.TRUST_STORE_PWD, "trust_pass"); + properties.setProperty(Config.KEY_STORE, "/tmp/keystore.p12"); + properties.setProperty(Config.KEY_STORE_PWD, "key_pass"); IoTDBConnectionParams params = Utils.parseUrl( "jdbc:iotdb://127.0.0.1:6667?use_ssl=true&ssl_protocol=ProviderProtocol", properties); assertTrue(params.isUseSSL()); assertEquals("ProviderProtocol", params.getSslProtocol()); + assertEquals("/tmp/truststore.p12", params.getTrustStore()); + assertEquals("trust_pass", params.getTrustStorePwd()); + assertEquals("/tmp/keystore.p12", params.getKeyStore()); + assertEquals("key_pass", params.getKeyStorePwd()); } } diff --git a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/BaseRpcTransportFactory.java b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/BaseRpcTransportFactory.java index 03bf2070bca98..57f1dd5a8942a 100644 --- a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/BaseRpcTransportFactory.java +++ b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/BaseRpcTransportFactory.java @@ -19,8 +19,6 @@ package org.apache.iotdb.rpc; -import org.apache.iotdb.rpc.i18n.RpcMessages; - import org.apache.thrift.transport.TMemoryInputTransport; import org.apache.thrift.transport.TSSLTransportFactory; import org.apache.thrift.transport.TSocket; @@ -28,10 +26,6 @@ import org.apache.thrift.transport.TTransportException; import org.apache.thrift.transport.TTransportFactory; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Paths; - @SuppressWarnings("java:S1135") // ignore todos public class BaseRpcTransportFactory extends TTransportFactory { @@ -120,11 +114,9 @@ public TTransport getTransport( throws TTransportException { TSSLTransportFactory.TSSLTransportParameters params = RpcSslUtils.createTSSLTransportParameters(sslProtocol); - if (Files.exists(Paths.get(trustStore)) && Files.exists(Paths.get(keyStore))) { - RpcSslUtils.setTrustStore(params, trustStore, trustStorePwd); + RpcSslUtils.setTrustStore(params, trustStore, trustStorePwd); + if (hasText(keyStore)) { RpcSslUtils.setKeyStore(params, keyStore, keyStorePwd); - } else { - throw new TTransportException(new IOException(RpcMessages.COULD_NOT_LOAD_KEYSTORE)); } TTransport transport = TSSLTransportFactory.getClientSocket(ip, port, timeout, params); return inner.getTransport(transport); @@ -150,4 +142,8 @@ public static void setDefaultBufferCapacity(int thriftDefaultBufferSize) { public static void setThriftMaxFrameSize(int thriftMaxFrameSize) { BaseRpcTransportFactory.thriftMaxFrameSize = thriftMaxFrameSize; } + + private static boolean hasText(String value) { + return value != null && !value.trim().isEmpty(); + } } diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/AbstractSessionBuilder.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/AbstractSessionBuilder.java index 873e6dd248d8a..44500ed3f10a6 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/AbstractSessionBuilder.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/AbstractSessionBuilder.java @@ -63,6 +63,8 @@ public abstract class AbstractSessionBuilder { public boolean useSSL = false; public String trustStore; public String trustStorePwd; + public String keyStore; + public String keyStorePwd; public String sslProtocol = SessionConfig.DEFAULT_SSL_PROTOCOL; // max retry count, if set to 0, means that we won't do any retry diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/NodesSupplier.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/NodesSupplier.java index a41db581e9cc7..c1f43fa30d433 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/NodesSupplier.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/NodesSupplier.java @@ -61,6 +61,8 @@ public class NodesSupplier implements INodeSupplier, Runnable { private final boolean useSSL; private final String trustStore; private final String trustStorePwd; + private final String keyStore; + private final String keyStorePwd; private final String sslProtocol; private final boolean enableRPCCompression; private final String userName; @@ -96,6 +98,8 @@ public static NodesSupplier createNodeSupplier( boolean useSSL, String trustStore, String trustStorePwd, + String keyStore, + String keyStorePwd, String sslProtocol, boolean enableRPCCompression, String version) { @@ -112,6 +116,8 @@ public static NodesSupplier createNodeSupplier( useSSL, trustStore, trustStorePwd, + keyStore, + keyStorePwd, sslProtocol, enableRPCCompression, version); @@ -135,6 +141,8 @@ private NodesSupplier( boolean useSSL, String trustStore, String trustStorePwd, + String keyStore, + String keyStorePwd, String sslProtocol, boolean enableRPCCompression, String version) { @@ -144,6 +152,8 @@ private NodesSupplier( this.useSSL = useSSL; this.trustStore = trustStore; this.trustStorePwd = trustStorePwd; + this.keyStore = keyStore; + this.keyStorePwd = keyStorePwd; this.sslProtocol = sslProtocol; this.enableRPCCompression = enableRPCCompression; this.zoneId = zoneId == null ? ZoneId.systemDefault() : zoneId; @@ -193,6 +203,8 @@ private boolean createConnection(TEndPoint endPoint) { useSSL, trustStore, trustStorePwd, + keyStore, + keyStorePwd, sslProtocol, userName, password, diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/Session.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/Session.java index 648a63845b4bf..8d182122de5e5 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/Session.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/Session.java @@ -134,6 +134,8 @@ public class Session implements ISession { protected boolean useSSL; protected String trustStore; protected String trustStorePwd; + protected String keyStore; + protected String keyStorePwd; protected String sslProtocol; /** @@ -475,6 +477,8 @@ public Session(AbstractSessionBuilder builder) { this.useSSL = builder.useSSL; this.trustStore = builder.trustStore; this.trustStorePwd = builder.trustStorePwd; + this.keyStore = builder.keyStore; + this.keyStorePwd = builder.keyStorePwd; this.sslProtocol = builder.sslProtocol; this.enableAutoFetch = builder.enableAutoFetch; this.maxRetryCount = builder.maxRetryCount; @@ -545,6 +549,8 @@ public synchronized void open(boolean enableRPCCompaction, int connectionTimeout useSSL, trustStore, trustStorePwd, + keyStore, + keyStorePwd, sslProtocol, enableRPCCompaction, version.toString()); @@ -4438,6 +4444,16 @@ public Builder trustStorePwd(String trustStorePwd) { return this; } + public Builder keyStore(String keyStore) { + this.keyStore = keyStore; + return this; + } + + public Builder keyStorePwd(String keyStorePwd) { + this.keyStorePwd = keyStorePwd; + return this; + } + public Builder sslProtocol(String sslProtocol) { this.sslProtocol = sslProtocol; return this; diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/SessionConnection.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/SessionConnection.java index 468b59abb0c23..ec91fc9f8c0d1 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/SessionConnection.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/SessionConnection.java @@ -153,7 +153,13 @@ public SessionConnection( this.database = database; try { init( - endPoint, session.useSSL, session.trustStore, session.trustStorePwd, session.sslProtocol); + endPoint, + session.useSSL, + session.trustStore, + session.trustStorePwd, + session.keyStore, + session.keyStorePwd, + session.sslProtocol); } catch (StatementExecutionException e) { throw new IoTDBConnectionException(e.getMessage()); } catch (IoTDBConnectionException e) { @@ -186,6 +192,8 @@ private void init( boolean useSSL, String trustStore, String trustStorePwd, + String keyStore, + String keyStorePwd, String sslProtocol) throws IoTDBConnectionException, StatementExecutionException { DeepCopyRpcTransportFactory.setDefaultBufferCapacity(session.thriftDefaultBufferSize); @@ -196,12 +204,14 @@ private void init( } if (useSSL) { transport = - DeepCopyRpcTransportFactory.INSTANCE.getTransportWithSSLConfig( + DeepCopyRpcTransportFactory.INSTANCE.getTransport( endPoint.getIp(), endPoint.getPort(), session.connectionTimeoutInMs, trustStore, trustStorePwd, + keyStore, + keyStorePwd, sslProtocol); } else { transport = @@ -278,6 +288,8 @@ private void initClusterConn() throws IoTDBConnectionException { session.useSSL, session.trustStore, session.trustStorePwd, + session.keyStore, + session.keyStorePwd, session.sslProtocol); } catch (IoTDBConnectionException e) { if (!reconnect()) { @@ -1100,6 +1112,8 @@ private boolean reconnect() { session.useSSL, session.trustStore, session.trustStorePwd, + session.keyStore, + session.keyStorePwd, session.sslProtocol); connectedSuccess = true; } catch (IoTDBConnectionException e) { diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/TableSessionBuilder.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/TableSessionBuilder.java index e724c0fd53220..2dca0d351dc80 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/TableSessionBuilder.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/TableSessionBuilder.java @@ -239,6 +239,30 @@ public TableSessionBuilder trustStorePwd(String trustStorePwd) { return this; } + /** + * Sets the key store path for mutual SSL connections. + * + * @param keyStore the key store path. + * @return the current {@link TableSessionBuilder} instance. + * @defaultValue null + */ + public TableSessionBuilder keyStore(String keyStore) { + this.keyStore = keyStore; + return this; + } + + /** + * Sets the key store password for mutual SSL connections. + * + * @param keyStorePwd the key store password. + * @return the current {@link TableSessionBuilder} instance. + * @defaultValue null + */ + public TableSessionBuilder keyStorePwd(String keyStorePwd) { + this.keyStorePwd = keyStorePwd; + return this; + } + /** * Sets the SSL protocol for secure connections. * diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/ThriftConnection.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/ThriftConnection.java index 3ab3abd581d9a..2a7f970bd5046 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/ThriftConnection.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/ThriftConnection.java @@ -78,6 +78,8 @@ public void init( boolean useSSL, String trustStore, String trustStorePwd, + String keyStore, + String keyStorePwd, String sslProtocol, String username, String password, @@ -90,12 +92,14 @@ public void init( try { if (useSSL) { transport = - DeepCopyRpcTransportFactory.INSTANCE.getTransportWithSSLConfig( + DeepCopyRpcTransportFactory.INSTANCE.getTransport( endPoint.getIp(), endPoint.getPort(), connectionTimeoutInMs, trustStore, trustStorePwd, + keyStore, + keyStorePwd, sslProtocol); } else { transport = diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/SessionPool.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/SessionPool.java index 8136b2be68fd1..3a9c56a5e50ec 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/SessionPool.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/SessionPool.java @@ -117,6 +117,10 @@ public class SessionPool implements ISessionPool { private String trustStorePwd; + private String keyStore; + + private String keyStorePwd; + private String sslProtocol = SessionConfig.DEFAULT_SSL_PROTOCOL; private ZoneId zoneId; @@ -540,6 +544,8 @@ public SessionPool(AbstractSessionPoolBuilder builder) { this.useSSL = builder.useSSL; this.trustStore = builder.trustStore; this.trustStorePwd = builder.trustStorePwd; + this.keyStore = builder.keyStore; + this.keyStorePwd = builder.keyStorePwd; this.sslProtocol = builder.sslProtocol; this.maxRetryCount = builder.maxRetryCount; this.retryIntervalInMs = builder.retryIntervalInMs; @@ -598,6 +604,8 @@ private Session constructNewSession() { .useSSL(useSSL) .trustStore(trustStore) .trustStorePwd(trustStorePwd) + .keyStore(keyStore) + .keyStorePwd(keyStorePwd) .sslProtocol(sslProtocol) .maxRetryCount(maxRetryCount) .retryIntervalInMs(retryIntervalInMs) @@ -624,6 +632,8 @@ private Session constructNewSession() { .useSSL(useSSL) .trustStore(trustStore) .trustStorePwd(trustStorePwd) + .keyStore(keyStore) + .keyStorePwd(keyStorePwd) .sslProtocol(sslProtocol) .maxRetryCount(maxRetryCount) .retryIntervalInMs(retryIntervalInMs) @@ -669,6 +679,8 @@ private void initAvailableNodes(List endPointList) { useSSL, trustStore, trustStorePwd, + keyStore, + keyStorePwd, sslProtocol, enableThriftCompression, version.toString()); @@ -3645,6 +3657,16 @@ public Builder trustStorePwd(String trustStorePwd) { return this; } + public Builder keyStore(String keyStore) { + this.keyStore = keyStore; + return this; + } + + public Builder keyStorePwd(String keyStorePwd) { + this.keyStorePwd = keyStorePwd; + return this; + } + public Builder sslProtocol(String sslProtocol) { this.sslProtocol = sslProtocol; return this; diff --git a/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/TableSessionPoolBuilder.java b/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/TableSessionPoolBuilder.java index 2c7aba1a45213..ab3e1f1c6a610 100644 --- a/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/TableSessionPoolBuilder.java +++ b/iotdb-client/session/src/main/java/org/apache/iotdb/session/pool/TableSessionPoolBuilder.java @@ -281,6 +281,30 @@ public TableSessionPoolBuilder trustStorePwd(String trustStorePwd) { return this; } + /** + * Sets the key store path for mutual SSL connections. + * + * @param keyStore the key store path. + * @return the current {@link TableSessionPoolBuilder} instance. + * @defaultValue null + */ + public TableSessionPoolBuilder keyStore(String keyStore) { + this.keyStore = keyStore; + return this; + } + + /** + * Sets the key store password for mutual SSL connections. + * + * @param keyStorePwd the key store password. + * @return the current {@link TableSessionPoolBuilder} instance. + * @defaultValue null + */ + public TableSessionPoolBuilder keyStorePwd(String keyStorePwd) { + this.keyStorePwd = keyStorePwd; + return this; + } + /** * Sets the SSL protocol for secure connections. * diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/ExternalRPCService.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/ExternalRPCService.java index dcc8690de8a87..89650657a61d6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/ExternalRPCService.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/ExternalRPCService.java @@ -65,32 +65,56 @@ public void initTProcessor() @Override public void initThriftServiceThread() throws IllegalAccessException { try { - thriftServiceThread = - commonConfig.isEnableThriftClientSSL() - ? new ThriftServiceThread( - processor, - getID().getName(), - ThreadName.CLIENT_RPC_PROCESSOR.getName(), - getBindIP(), - getBindPort(), - config.getRpcMaxConcurrentClientNum(), - config.getThriftServerAwaitTimeForStopService(), - new RPCServiceThriftHandler(impl), - config.isRpcThriftCompressionEnable(), - commonConfig.getKeyStorePath(), - commonConfig.getKeyStorePwd(), - ZeroCopyRpcTransportFactory.INSTANCE) - : new ThriftServiceThread( - processor, - getID().getName(), - ThreadName.CLIENT_RPC_PROCESSOR.getName(), - getBindIP(), - getBindPort(), - config.getRpcMaxConcurrentClientNum(), - config.getThriftServerAwaitTimeForStopService(), - new RPCServiceThriftHandler(impl), - config.isRpcThriftCompressionEnable(), - ZeroCopyRpcTransportFactory.INSTANCE); + if (!commonConfig.isEnableThriftClientSSL()) { + thriftServiceThread = + new ThriftServiceThread( + processor, + getID().getName(), + ThreadName.CLIENT_RPC_PROCESSOR.getName(), + getBindIP(), + getBindPort(), + config.getRpcMaxConcurrentClientNum(), + config.getThriftServerAwaitTimeForStopService(), + new RPCServiceThriftHandler(impl), + config.isRpcThriftCompressionEnable(), + ZeroCopyRpcTransportFactory.INSTANCE); + } else if (commonConfig.isThriftSSLClientAuth()) { + if (!hasText(commonConfig.getTrustStorePath())) { + throw new IllegalAccessException( + "trust_store_path must be set when thrift_ssl_client_auth is true"); + } + thriftServiceThread = + new ThriftServiceThread( + processor, + getID().getName(), + ThreadName.CLIENT_RPC_PROCESSOR.getName(), + getBindIP(), + getBindPort(), + config.getRpcMaxConcurrentClientNum(), + config.getThriftServerAwaitTimeForStopService(), + new RPCServiceThriftHandler(impl), + config.isRpcThriftCompressionEnable(), + commonConfig.getKeyStorePath(), + commonConfig.getKeyStorePwd(), + commonConfig.getTrustStorePath(), + commonConfig.getTrustStorePwd(), + ZeroCopyRpcTransportFactory.INSTANCE); + } else { + thriftServiceThread = + new ThriftServiceThread( + processor, + getID().getName(), + ThreadName.CLIENT_RPC_PROCESSOR.getName(), + getBindIP(), + getBindPort(), + config.getRpcMaxConcurrentClientNum(), + config.getThriftServerAwaitTimeForStopService(), + new RPCServiceThriftHandler(impl), + config.isRpcThriftCompressionEnable(), + commonConfig.getKeyStorePath(), + commonConfig.getKeyStorePwd(), + ZeroCopyRpcTransportFactory.INSTANCE); + } } catch (RPCServiceException e) { throw new IllegalAccessException(e.getMessage()); } @@ -118,6 +142,10 @@ public int getRPCPort() { return getBindPort(); } + private boolean hasText(String value) { + return value != null && !value.trim().isEmpty(); + } + private static class RPCServiceHolder { private static final ExternalRPCService INSTANCE = new ExternalRPCService(); diff --git a/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template b/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template index 1e808d7d75a2c..6d75a72aaab94 100644 --- a/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template +++ b/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template @@ -444,6 +444,13 @@ dn_metric_internal_reporter_type=MEMORY # Privilege: SECURITY enable_thrift_ssl=false +# Whether client authentication is required for Thrift SSL connections. +# This only takes effect when enable_thrift_ssl=true. +# effectiveMode: restart +# Datatype: boolean +# Privilege: SECURITY +thrift_ssl_client_auth=false + # Whether enable SSL for Rest Service # effectiveMode: restart # Datatype: boolean diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java index 53127fa2cdc06..257b2a8ad5176 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java @@ -490,6 +490,9 @@ public class CommonConfig { /** Enable the Thrift Client ssl. */ private boolean enableThriftClientSSL = false; + /** Whether the external Thrift SSL service requires client certificate authentication. */ + private boolean thriftSSLClientAuth = false; + /** Enable the cluster internal connection ssl. */ private boolean enableInternalSSL = false; @@ -2998,6 +3001,14 @@ public void setEnableThriftClientSSL(boolean enableThriftClientSSL) { this.enableThriftClientSSL = enableThriftClientSSL; } + public boolean isThriftSSLClientAuth() { + return thriftSSLClientAuth; + } + + public void setThriftSSLClientAuth(boolean thriftSSLClientAuth) { + this.thriftSSLClientAuth = thriftSSLClientAuth; + } + public boolean isEnableInternalSSL() { return enableInternalSSL; } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonDescriptor.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonDescriptor.java index 824ec639ef685..56e8587eb0764 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonDescriptor.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonDescriptor.java @@ -679,6 +679,10 @@ public void initThriftSSL(TrimProperties properties) { Boolean.parseBoolean( properties.getProperty( "enable_thrift_ssl", Boolean.toString(config.isEnableThriftClientSSL())))); + config.setThriftSSLClientAuth( + Boolean.parseBoolean( + properties.getProperty( + "thrift_ssl_client_auth", Boolean.toString(config.isThriftSSLClientAuth())))); config.setKeyStorePath(properties.getProperty("key_store_path", config.getKeyStorePath())); config.setKeyStorePwd(properties.getProperty("key_store_pwd", config.getKeyStorePwd())); config.setTrustStorePath( From d16174053dfdc7d154670ee71b8b7bbbfd706d07 Mon Sep 17 00:00:00 2001 From: HTHou Date: Fri, 26 Jun 2026 15:34:10 +0800 Subject: [PATCH 2/2] Support Python client mutual TLS --- iotdb-client/client-py/README.md | 36 ++++- iotdb-client/client-py/iotdb/Session.py | 28 +++- iotdb-client/client-py/iotdb/SessionPool.py | 8 + .../client-py/iotdb/dbapi/Connection.py | 35 +++- iotdb-client/client-py/iotdb/table_session.py | 6 + .../client-py/iotdb/table_session_pool.py | 4 + iotdb-client/client-py/session_ssl_example.py | 16 +- .../client-py/tests/unit/test_session_ssl.py | 152 ++++++++++++++++++ 8 files changed, 279 insertions(+), 6 deletions(-) create mode 100644 iotdb-client/client-py/tests/unit/test_session_ssl.py diff --git a/iotdb-client/client-py/README.md b/iotdb-client/client-py/README.md index c7dc1b33cc9f1..2441dfb866170 100644 --- a/iotdb-client/client-py/README.md +++ b/iotdb-client/client-py/README.md @@ -90,6 +90,40 @@ Notice: this RPC compression status of client must comply with that of IoTDB ser session.close() ``` +### SSL and Mutual TLS + +Use `use_ssl=True` and `ca_certs` to enable SSL and verify the server certificate. + +```python +session = Session( + ip, + port_, + username_, + password_, + use_ssl=True, + ca_certs="/path/ca.crt", +) +``` + +When the server enables Thrift mutual TLS, also configure a client certificate and an unencrypted PEM private key. + +```python +session = Session( + ip, + port_, + username_, + password_, + use_ssl=True, + ca_certs="/path/ca.crt", + client_cert="/path/client.crt", + client_key="/path/client.key", +) +``` + +`client_cert` and `client_key` must be set together. `client_key` should be protected by file permissions because encrypted private keys are not supported by this client API. + +The same SSL parameters are available in `PoolConfig`, `TableSessionConfig`, `TableSessionPoolConfig`, and DBAPI `connect`. + ### Data Definition Interface (DDL Interface) #### DATABASE Management @@ -606,5 +640,3 @@ Namely, these are * Release to pypi - - diff --git a/iotdb-client/client-py/iotdb/Session.py b/iotdb-client/client-py/iotdb/Session.py index a3d89d717464b..0a5cb43fd5440 100644 --- a/iotdb-client/client-py/iotdb/Session.py +++ b/iotdb-client/client-py/iotdb/Session.py @@ -87,6 +87,8 @@ def __init__( use_ssl=False, ca_certs=None, connection_timeout_in_ms=None, + client_cert=None, + client_key=None, ): self.__host = host self.__port = port @@ -117,6 +119,8 @@ def __init__( self.database = None self.__use_ssl = use_ssl self.__ca_certs = ca_certs + self.__client_cert = client_cert + self.__client_key = client_key self.__connection_timeout_in_ms = connection_timeout_in_ms self.__time_precision = "ms" @@ -132,6 +136,8 @@ def init_from_node_urls( use_ssl=False, ca_certs=None, connection_timeout_in_ms=None, + client_cert=None, + client_key=None, ): if node_urls is None: raise RuntimeError("node urls is empty") @@ -145,6 +151,8 @@ def init_from_node_urls( enable_redirection, use_ssl=use_ssl, ca_certs=ca_certs, + client_cert=client_cert, + client_key=client_key, connection_timeout_in_ms=connection_timeout_in_ms, ) session.__hosts = [] @@ -259,7 +267,21 @@ def __get_transport(self, endpoint): context = ssl.SSLContext(ssl.PROTOCOL_TLS) context.verify_mode = ssl.CERT_REQUIRED context.check_hostname = True - context.load_verify_locations(cafile=self.__ca_certs) + context.load_default_certs(ssl.Purpose.SERVER_AUTH) + if self.__has_text(self.__ca_certs): + context.load_verify_locations(cafile=self.__ca_certs) + if self.__has_text(self.__client_cert) or self.__has_text( + self.__client_key + ): + if not self.__has_text(self.__client_cert) or not self.__has_text( + self.__client_key + ): + raise TTransport.TTransportException( + message="client_cert and client_key must be set together." + ) + context.load_cert_chain( + certfile=self.__client_cert, keyfile=self.__client_key + ) socket = TSSLSocket.TSSLSocket( host=endpoint.ip, port=endpoint.port, ssl_context=context ) @@ -275,6 +297,10 @@ def __get_transport(self, endpoint): raise IoTDBConnectionException(e) from None return transport + @staticmethod + def __has_text(value): + return value is not None and str(value).strip() != "" + def is_open(self): return not self.__is_close diff --git a/iotdb-client/client-py/iotdb/SessionPool.py b/iotdb-client/client-py/iotdb/SessionPool.py index 1e34186b2e0ca..f52f4765395f7 100644 --- a/iotdb-client/client-py/iotdb/SessionPool.py +++ b/iotdb-client/client-py/iotdb/SessionPool.py @@ -49,6 +49,8 @@ def __init__( use_ssl: bool = False, ca_certs: str = None, connection_timeout_in_ms: int = None, + client_cert: str = None, + client_key: str = None, ): self.host = host self.port = port @@ -68,6 +70,8 @@ def __init__( self.enable_redirection = enable_redirection self.use_ssl = use_ssl self.ca_certs = ca_certs + self.client_cert = client_cert + self.client_key = client_key self.connection_timeout_in_ms = connection_timeout_in_ms @@ -96,6 +100,8 @@ def __construct_session(self) -> Session: enable_redirection=self.__pool_config.enable_redirection, use_ssl=self.__pool_config.use_ssl, ca_certs=self.__pool_config.ca_certs, + client_cert=self.__pool_config.client_cert, + client_key=self.__pool_config.client_key, connection_timeout_in_ms=self.__pool_config.connection_timeout_in_ms, ) session.sql_dialect = self.sql_dialect @@ -112,6 +118,8 @@ def __construct_session(self) -> Session: enable_redirection=self.__pool_config.enable_redirection, use_ssl=self.__pool_config.use_ssl, ca_certs=self.__pool_config.ca_certs, + client_cert=self.__pool_config.client_cert, + client_key=self.__pool_config.client_key, connection_timeout_in_ms=self.__pool_config.connection_timeout_in_ms, ) session.sql_dialect = self.sql_dialect diff --git a/iotdb-client/client-py/iotdb/dbapi/Connection.py b/iotdb-client/client-py/iotdb/dbapi/Connection.py index aee5520e9af92..4bbb4e288cffd 100644 --- a/iotdb-client/client-py/iotdb/dbapi/Connection.py +++ b/iotdb-client/client-py/iotdb/dbapi/Connection.py @@ -37,12 +37,29 @@ def __init__( zone_id=Session.DEFAULT_ZONE_ID, enable_rpc_compression=False, sqlalchemy_mode=False, + use_ssl=False, + ca_certs=None, + connection_timeout_in_ms=None, + client_cert=None, + client_key=None, ): - self.__session = Session(host, port, username, password, fetch_size, zone_id) + self.__session = Session( + host, + port, + username, + password, + fetch_size, + zone_id, + use_ssl=self.__to_bool(use_ssl), + ca_certs=ca_certs, + connection_timeout_in_ms=self.__to_optional_int(connection_timeout_in_ms), + client_cert=client_cert, + client_key=client_key, + ) self.__sqlalchemy_mode = sqlalchemy_mode self.__is_close = True try: - self.__session.open(enable_rpc_compression) + self.__session.open(self.__to_bool(enable_rpc_compression)) self.__is_close = False except Exception as e: raise ConnectionError(e) @@ -89,3 +106,17 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() + + @staticmethod + def __to_bool(value): + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in ("true", "1", "yes", "y") + return bool(value) + + @staticmethod + def __to_optional_int(value): + if value is None or value == "": + return None + return int(value) diff --git a/iotdb-client/client-py/iotdb/table_session.py b/iotdb-client/client-py/iotdb/table_session.py index fdf20e7cfa596..47a81ea5c23ee 100644 --- a/iotdb-client/client-py/iotdb/table_session.py +++ b/iotdb-client/client-py/iotdb/table_session.py @@ -38,6 +38,8 @@ def __init__( use_ssl: bool = False, ca_certs: str = None, connection_timeout_in_ms: int = None, + client_cert: str = None, + client_key: str = None, ): """ Initialize a TableSessionConfig object with the provided parameters. @@ -71,6 +73,8 @@ def __init__( self.enable_compression = enable_compression self.use_ssl = use_ssl self.ca_certs = ca_certs + self.client_cert = client_cert + self.client_key = client_key self.connection_timeout_in_ms = connection_timeout_in_ms @@ -91,6 +95,8 @@ def __init__( table_session_config.use_ssl, table_session_config.ca_certs, table_session_config.connection_timeout_in_ms, + client_cert=table_session_config.client_cert, + client_key=table_session_config.client_key, ) self.__session.sql_dialect = "table" self.__session.database = table_session_config.database diff --git a/iotdb-client/client-py/iotdb/table_session_pool.py b/iotdb-client/client-py/iotdb/table_session_pool.py index f44df3f249b6b..08bfd2fc2298d 100644 --- a/iotdb-client/client-py/iotdb/table_session_pool.py +++ b/iotdb-client/client-py/iotdb/table_session_pool.py @@ -37,6 +37,8 @@ def __init__( use_ssl: bool = False, ca_certs: str = None, connection_timeout_in_ms: int = None, + client_cert: str = None, + client_key: str = None, ): """ Initialize a TableSessionPoolConfig object with the provided parameters. @@ -77,6 +79,8 @@ def __init__( use_ssl=use_ssl, ca_certs=ca_certs, connection_timeout_in_ms=connection_timeout_in_ms, + client_cert=client_cert, + client_key=client_key, ) self.max_pool_size = max_pool_size self.wait_timeout_in_ms = wait_timeout_in_ms diff --git a/iotdb-client/client-py/session_ssl_example.py b/iotdb-client/client-py/session_ssl_example.py index 2d5a557afc449..afded04eb6158 100644 --- a/iotdb-client/client-py/session_ssl_example.py +++ b/iotdb-client/client-py/session_ssl_example.py @@ -28,11 +28,21 @@ use_ssl = True # Configure certificate path ca_certs = "/path/server.crt" +# Configure client certificate and private key when the server enables mTLS. +client_cert = "/path/client.crt" +client_key = "/path/client.key" def get_data(): session = Session( - ip, port_, username_, password_, use_ssl=use_ssl, ca_certs=ca_certs + ip, + port_, + username_, + password_, + use_ssl=use_ssl, + ca_certs=ca_certs, + client_cert=client_cert, + client_key=client_key, ) session.open(False) result = session.execute_query_statement("select * from root.eg.etth") @@ -53,6 +63,8 @@ def get_data2(): max_retry=3, use_ssl=use_ssl, ca_certs=ca_certs, + client_cert=client_cert, + client_key=client_key, ) max_pool_size = 5 wait_timeout_in_ms = 3000 @@ -74,6 +86,8 @@ def get_table_data(): time_zone="Asia/Shanghai", use_ssl=use_ssl, ca_certs=ca_certs, + client_cert=client_cert, + client_key=client_key, ) session = TableSession(pool_config) result = session.execute_query_statement("select * from test") diff --git a/iotdb-client/client-py/tests/unit/test_session_ssl.py b/iotdb-client/client-py/tests/unit/test_session_ssl.py new file mode 100644 index 0000000000000..11541223e1b01 --- /dev/null +++ b/iotdb-client/client-py/tests/unit/test_session_ssl.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import importlib +import ssl +from types import SimpleNamespace + +import pytest +from thrift.transport import TTransport +from thrift.transport import TSSLSocket + + +class FakeSslContext: + def __init__(self): + self.verify_mode = None + self.check_hostname = None + self.verify_locations = [] + self.default_cert_purposes = [] + self.cert_chain = None + + def load_default_certs(self, purpose): + self.default_cert_purposes.append(purpose) + + def load_verify_locations(self, cafile=None): + self.verify_locations.append(cafile) + + def load_cert_chain(self, certfile=None, keyfile=None): + self.cert_chain = (certfile, keyfile) + + +class FakeSocket: + def __init__(self, host, port, ssl_context): + self.host = host + self.port = port + self.ssl_context = ssl_context + self.timeout = None + + def setTimeout(self, timeout): + self.timeout = timeout + + +class FakeTransport: + def __init__(self, socket): + self.socket = socket + self.opened = False + + def isOpen(self): + return self.opened + + def open(self): + self.opened = True + + +def patch_ssl_transport(monkeypatch, context): + monkeypatch.setattr(ssl, "create_default_context", lambda purpose: context) + monkeypatch.setattr(ssl, "SSLContext", lambda protocol: context) + monkeypatch.setattr( + TSSLSocket, + "TSSLSocket", + lambda host, port, ssl_context: FakeSocket(host, port, ssl_context), + ) + monkeypatch.setattr(TTransport, "TFramedTransport", FakeTransport) + + +def test_session_ssl_loads_client_cert_chain(monkeypatch): + session_module = importlib.import_module("iotdb.Session") + context = FakeSslContext() + patch_ssl_transport(monkeypatch, context) + + session = session_module.Session( + "127.0.0.1", + "6667", + use_ssl=True, + ca_certs="ca.crt", + client_cert="client.crt", + client_key="client.key", + ) + + transport = session._Session__get_transport( + SimpleNamespace(ip="127.0.0.1", port=6667) + ) + + assert context.verify_locations == ["ca.crt"] + assert context.cert_chain == ("client.crt", "client.key") + assert transport.socket.ssl_context is context + assert transport.opened + + +def test_session_ssl_requires_client_cert_and_key_together(monkeypatch): + session_module = importlib.import_module("iotdb.Session") + patch_ssl_transport(monkeypatch, FakeSslContext()) + + session = session_module.Session( + "127.0.0.1", "6667", use_ssl=True, client_cert="client.crt" + ) + + with pytest.raises(TTransport.TTransportException) as exc_info: + session._Session__get_transport(SimpleNamespace(ip="127.0.0.1", port=6667)) + + assert "client_cert and client_key must be set together" in str(exc_info.value) + + +def test_dbapi_connection_passes_ssl_options(monkeypatch): + connection_module = importlib.import_module("iotdb.dbapi.Connection") + + class FakeSession: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.open_compression = None + + def open(self, enable_rpc_compression): + self.open_compression = enable_rpc_compression + + def close(self): + pass + + monkeypatch.setattr(connection_module, "Session", FakeSession) + + connection = connection_module.Connection( + "127.0.0.1", + "6667", + enable_rpc_compression="true", + use_ssl="true", + ca_certs="ca.crt", + connection_timeout_in_ms="1000", + client_cert="client.crt", + client_key="client.key", + ) + + session = connection._Connection__session + assert session.kwargs["use_ssl"] is True + assert session.kwargs["ca_certs"] == "ca.crt" + assert session.kwargs["connection_timeout_in_ms"] == 1000 + assert session.kwargs["client_cert"] == "client.crt" + assert session.kwargs["client_key"] == "client.key" + assert session.open_compression is True