Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,29 @@ curl -o session_example.go -L https://github.com/apache/iotdb-client-go/raw/main
go run session_example.go
```

## TLS/mTLS

Set `TLSConfig` on `client.Config`, `client.ClusterConfig`, or `client.PoolConfig` to enable TLS. Add `CertFile` and `KeyFile` when the server requires mTLS client authentication.

```golang
config := &client.Config{
Host: host,
Port: port,
UserName: user,
Password: password,
TLSConfig: &client.TLSConfig{
CAFile: "/path/to/ca.pem",
CertFile: "/path/to/client.pem",
KeyFile: "/path/to/client-key.pem",
},
}
session := client.NewSession(config)
if err := session.Open(false, 0); err != nil {
log.Fatal(err)
}
defer session.Close()
```

## How to Use the SessionPool

SessionPool is a wrapper of a Session Set. Using SessionPool, the user do not need to consider how to reuse a session connection.
Expand Down
23 changes: 23 additions & 0 deletions README_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,29 @@ curl -o session_example.go -L https://github.com/apache/iotdb-client-go/raw/main
go run session_example.go
```

## TLS/mTLS

在 `client.Config`、`client.ClusterConfig` 或 `client.PoolConfig` 上设置 `TLSConfig` 即可启用 TLS。如果服务端要求 mTLS 客户端认证,同时设置 `CertFile` 和 `KeyFile`。

```golang
config := &client.Config{
Host: host,
Port: port,
UserName: user,
Password: password,
TLSConfig: &client.TLSConfig{
CAFile: "/path/to/ca.pem",
CertFile: "/path/to/client.pem",
KeyFile: "/path/to/client-key.pem",
},
}
session := client.NewSession(config)
if err := session.Open(false, 0); err != nil {
log.Fatal(err)
}
defer session.Close()
```

## SessionPool
通过SessionPool管理session,用户不需要考虑如何重用session,当到达pool的最大值时,获取session的请求会阻塞
注意:session使用完成后需要调用PutBack方法
Expand Down
58 changes: 29 additions & 29 deletions client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"errors"
"fmt"
"log"
"net"
"reflect"
"sort"
"strings"
Expand Down Expand Up @@ -68,6 +67,7 @@ type Config struct {
sqlDialect string
Version Version
Database string
TLSConfig *TLSConfig
}

type Session struct {
Expand Down Expand Up @@ -100,13 +100,10 @@ func (s *Session) Open(enableRPCCompression bool, connectionTimeoutInMs int) err

var err error

// In thrift 0.14.1, this func returns two values; in newer versions, it returns one.
s.trans = thrift.NewTSocketConf(net.JoinHostPort(s.config.Host, s.config.Port), &thrift.TConfiguration{
ConnectTimeout: time.Duration(connectionTimeoutInMs) * time.Millisecond, // Use 0 for no timeout
})
// s.trans = thrift.NewTFramedTransport(s.trans) // deprecated
tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf)
s.trans, err = newTransport(s.config.Host, s.config.Port, connectionTimeoutInMs, s.config.TLSConfig)
if err != nil {
return err
}
if !s.trans.IsOpen() {
err = s.trans.Open()
if err != nil {
Expand Down Expand Up @@ -154,6 +151,7 @@ type ClusterConfig struct {
ConnectRetryMax int
sqlDialect string
Database string
TLSConfig *TLSConfig
}

func (s *Session) OpenCluster(enableRPCCompression bool) error {
Expand Down Expand Up @@ -1326,26 +1324,31 @@ func newClusterSessionWithSqlDialect(clusterConfig *ClusterConfig) (Session, err
session.endPointList[i] = node
}
var err error
var lastErr error
for i := range session.endPointList {
ep := session.endPointList[i]
session.trans = thrift.NewTSocketConf(net.JoinHostPort(ep.Host, ep.Port), &thrift.TConfiguration{
ConnectTimeout: time.Duration(0), // Use 0 for no timeout
})
// session.trans = thrift.NewTFramedTransport(session.trans) // deprecated
tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
session.trans = thrift.NewTFramedTransportConf(session.trans, &tmp_conf)
session.trans, err = newTransport(ep.Host, ep.Port, 0, clusterConfig.TLSConfig)
if err != nil {
lastErr = err
log.Println(err)
continue
}
if !session.trans.IsOpen() {
err = session.trans.Open()
if err != nil {
lastErr = err
log.Println(err)
} else {
session.config = getConfig(ep.Host, ep.Port,
clusterConfig.UserName, clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, clusterConfig.ConnectRetryMax, clusterConfig.Database, clusterConfig.sqlDialect)
clusterConfig.UserName, clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone, clusterConfig.ConnectRetryMax, clusterConfig.Database, clusterConfig.sqlDialect, clusterConfig.TLSConfig)
break
}
}
}
if !session.trans.IsOpen() {
if session.trans == nil || !session.trans.IsOpen() {
if lastErr != nil {
return session, fmt.Errorf("no server can connect: %w", lastErr)
}
return session, fmt.Errorf("no server can connect")
}
Comment thread
HTHou marked this conversation as resolved.
return session, nil
Expand All @@ -1354,18 +1357,14 @@ func newClusterSessionWithSqlDialect(clusterConfig *ClusterConfig) (Session, err
func (s *Session) initClusterConn(node endPoint) error {
var err error

s.trans = thrift.NewTSocketConf(net.JoinHostPort(node.Host, node.Port), &thrift.TConfiguration{
ConnectTimeout: time.Duration(0), // Use 0 for no timeout
})
if err == nil {
// s.trans = thrift.NewTFramedTransport(s.trans) // deprecated
tmp_conf := thrift.TConfiguration{MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE}
s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf)
if !s.trans.IsOpen() {
err = s.trans.Open()
if err != nil {
return err
}
s.trans, err = newTransport(node.Host, node.Port, 0, s.config.TLSConfig)
if err != nil {
return err
}
if !s.trans.IsOpen() {
err = s.trans.Open()
if err != nil {
return err
}
}

Expand Down Expand Up @@ -1398,7 +1397,7 @@ func (s *Session) initClusterConn(node endPoint) error {
return err
}

func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int, database string, sqlDialect string) *Config {
func getConfig(host string, port string, userName string, passWord string, fetchSize int32, timeZone string, connectRetryMax int, database string, sqlDialect string, tlsConfig *TLSConfig) *Config {
return &Config{
Host: host,
Port: port,
Expand All @@ -1409,6 +1408,7 @@ func getConfig(host string, port string, userName string, passWord string, fetch
ConnectRetryMax: connectRetryMax,
sqlDialect: sqlDialect,
Database: database,
TLSConfig: tlsConfig,
}
}

Expand Down
3 changes: 3 additions & 0 deletions client/sessionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type PoolConfig struct {
TimeZone string
ConnectRetryMax int
Database string
TLSConfig *TLSConfig
sqlDialect string
}

Expand Down Expand Up @@ -146,6 +147,7 @@ func getSessionConfig(config *PoolConfig) *Config {
ConnectRetryMax: config.ConnectRetryMax,
sqlDialect: config.sqlDialect,
Database: config.Database,
TLSConfig: config.TLSConfig,
}
}

Expand All @@ -159,6 +161,7 @@ func getClusterSessionConfig(config *PoolConfig) *ClusterConfig {
ConnectRetryMax: config.ConnectRetryMax,
sqlDialect: config.sqlDialect,
Database: config.Database,
TLSConfig: config.TLSConfig,
}
}

Expand Down
120 changes: 120 additions & 0 deletions client/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package client

import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"time"

"github.com/apache/thrift/lib/go/thrift"
)

// TLSConfig enables TLS for an IoTDB client connection. Set CertFile and
// KeyFile together to enable mTLS client authentication.
type TLSConfig struct {
// Config is an optional base tls.Config. It is cloned before use.
Config *tls.Config

// CAFile is an optional PEM encoded CA certificate file used to verify the server.
CAFile string

// CertFile and KeyFile are optional PEM encoded client certificate and key files for mTLS.
CertFile string
KeyFile string
}

func newTransport(host string, port string, connectionTimeoutInMs int, tlsConfig *TLSConfig) (thrift.TTransport, error) {
conf := &thrift.TConfiguration{
ConnectTimeout: time.Duration(connectionTimeoutInMs) * time.Millisecond,
MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE,
}
hostPort := net.JoinHostPort(host, port)

var base thrift.TTransport
if tlsConfig == nil {
base = thrift.NewTSocketConf(hostPort, conf)
} else {
cfg, err := buildTLSConfig(tlsConfig)
if err != nil {
return nil, err
}
conf.TLSConfig = cfg
base = thrift.NewTSSLSocketConf(hostPort, conf)
}

return thrift.NewTFramedTransportConf(base, conf), nil
}

func buildTLSConfig(config *TLSConfig) (*tls.Config, error) {
if config == nil {
return nil, nil
}

tlsConfig := &tls.Config{}
if config.Config != nil {
tlsConfig = config.Config.Clone()
}
if config.CAFile != "" {
rootCAs, err := loadCertPool(tlsConfig.RootCAs, config.CAFile)
if err != nil {
return nil, err
}
tlsConfig.RootCAs = rootCAs
}
if config.CertFile != "" || config.KeyFile != "" {
if config.CertFile == "" || config.KeyFile == "" {
return nil, fmt.Errorf("both TLS CertFile and KeyFile must be set")
}
certificate, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
if err != nil {
return nil, fmt.Errorf("load TLS client certificate/key: %w", err)
}
tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
}

return tlsConfig, nil
}

func loadCertPool(base *x509.CertPool, caFile string) (*x509.CertPool, error) {
rootCAs := base
if rootCAs != nil {
rootCAs = rootCAs.Clone()
} else {
systemPool, err := x509.SystemCertPool()
if err == nil && systemPool != nil {
rootCAs = systemPool
} else {
rootCAs = x509.NewCertPool()
}
Comment thread
Copilot marked this conversation as resolved.
}

caCert, err := os.ReadFile(caFile)
if err != nil {
return nil, fmt.Errorf("read TLS CA file %q: %w", caFile, err)
}
if !rootCAs.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("append TLS CA file %q: no certificates found", caFile)
}
return rootCAs, nil
}
Loading
Loading