Browse Source

working on MultiTenantMode support, removing non-TLS mode

forward-proxy
forest 10 months ago
parent
commit
7909b24089
7 changed files with 204 additions and 112 deletions
  1. +0
    -2
      README.md
  2. +184
    -80
      main.go
  3. +0
    -3
      tunnel-lib/client.go
  4. +0
    -3
      tunnel-lib/proto/proto.go
  5. +20
    -22
      tunnel-lib/server.go
  6. +0
    -1
      usage-example/client-config.json
  7. +0
    -1
      usage-example/server-config.json

+ 0
- 2
README.md View File

@ -36,7 +36,6 @@ Starting the "listener" test app. It listens on port 9001. This would be your w
"DebugLog": false,
"TunnelControlPort": 9056,
"ManagementPort": 9057,
"UseTls": true,
"CaCertificateFile": "InternalCA+chain.crt",
"ServerTlsKeyFile": "localhost.key",
"ServerTlsCertificateFile": "localhost+chain.crt"
@ -52,7 +51,6 @@ Starting the tunnel client. Client Identifier: TestClient1
"ServerHost": "localhost",
"ServerTunnelControlPort": 9056,
"ServerManagementPort": 9057,
"UseTls": true,
"ServiceToLocalAddrMap": {
"fooService": "127.0.0.1:9001"
},


+ 184
- 80
main.go View File

@ -13,9 +13,11 @@ import (
"log"
"net"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
@ -33,7 +35,18 @@ type ServerConfig struct {
// based on domain users/email addresses.
Domain string
UseTls bool
// MultiTenantMode ON:
// tenantId is required. ClientId must be formatted `<tenantId>.<nodeId>`
// clients will not be allowed to register listeners capturing all packets on a given port,
// they must specify a hostname, and they must prove that they own it (via a TXT record for example).
// Exception: Each client will get a few allocated ports for SSH & maybe etc???
//
// MultiTenantMode OFF:
// tenantId is N/A. ClientId must be formatted `<nodeId>`
// clients can register listeners with any hostname including null, on any open port.
//
MultiTenantMode bool
CaCertificateFilesGlob string
ServerTlsKeyFile string
ServerTlsCertificateFile string
@ -43,7 +56,6 @@ type ClientConfig struct {
DebugLog bool
ClientIdentifier string
ServerAddr string
UseTls bool
ServiceToLocalAddrMap *map[string]string
CaCertificateFilesGlob string
ClientTlsKeyFile string
@ -66,6 +78,7 @@ type ClientState struct {
}
type ManagementHttpHandler struct {
Domain string
ControlHandler http.Handler
}
@ -77,15 +90,15 @@ type LiveConfigUpdate struct {
type adminAPI struct{}
// Server State
var listeners []ListenerConfig
var listenersByTenant map[string][]ListenerConfig
var clientStatesMutex = &sync.Mutex{}
var clientStates map[string]ClientState
var clientStatesByTenant map[string]map[string]ClientState
var server *tunnel.Server
// Client State
var client *tunnel.Client
var tlsClientConfig *tls.Config
var serverURL *string
var serverHostPort *string
var serviceToLocalAddrMap *map[string]string
func main() {
@ -107,7 +120,7 @@ func main() {
}
// admin api handler for /liveconfig over unix socket
// client admin api handler for /liveconfig over unix socket
func (handler adminAPI) ServeHTTP(response http.ResponseWriter, request *http.Request) {
switch path.Clean(request.URL.Path) {
case "/liveconfig":
@ -132,7 +145,7 @@ func (handler adminAPI) ServeHTTP(response http.ResponseWriter, request *http.Re
http.Error(response, "500 Listeners json serialization failed", http.StatusInternalServerError)
return
}
apiURL := fmt.Sprintf("https://%s/tunnels", *serverURL)
apiURL := fmt.Sprintf("https://%s/tunnels", *serverHostPort)
tunnelsRequest, err := http.NewRequest("PUT", apiURL, bytes.NewReader(sendBytes))
if err != nil {
log.Printf("adminAPI: error creating tunnels request: %+v\n\n", err)
@ -198,44 +211,69 @@ func runClient(configFileName *string) {
if err != nil {
log.Fatalf("runClient(): can't json.Unmarshal(configBytes, &config) because %s \n", err)
}
serviceToLocalAddrMap = config.ServiceToLocalAddrMap
serverURL = &config.ServerAddr
serverHostPort = &config.ServerAddr
serverURLString := fmt.Sprintf("https://%s", *serverHostPort)
serverURL, err := url.Parse(serverURLString)
if err != nil {
log.Fatal(fmt.Errorf("failed to parse the ServerAddr (prefixed with https://) '%s' as a url", serverURLString))
}
configToLog, _ := json.MarshalIndent(config, "", " ")
log.Printf("theshold client is starting up using config:\n%s\n", string(configToLog))
dialFunction := net.Dial
if config.UseTls {
cert, err := tls.LoadX509KeyPair(config.ClientTlsCertificateFile, config.ClientTlsKeyFile)
if err != nil {
log.Fatal(err)
}
cert, err := tls.LoadX509KeyPair(config.ClientTlsCertificateFile, config.ClientTlsKeyFile)
if err != nil {
log.Fatal(err)
}
commonName := cert.Leaf.Subject.CommonName
clientIdDomain := strings.Split(commonName, "@")
if len(clientIdDomain) != 2 {
log.Fatal(fmt.Errorf(
"expected TLS client certificate common name '%s' to match format '<clientId>@<domain>'", commonName,
))
}
if clientIdDomain[1] != serverURL.Hostname() {
log.Fatal(fmt.Errorf(
"expected TLS client certificate common name domain '%s' to match ServerAddr domain '%s'",
clientIdDomain[1], serverURL.Hostname(),
))
}
if clientIdDomain[0] != config.ClientIdentifier {
log.Fatal(fmt.Errorf(
"expected TLS client certificate common name clientId '%s' to match ClientIdentifier '%s'",
clientIdDomain[0], config.ClientIdentifier,
))
}
certificates, err := filepath.Glob(config.CaCertificateFilesGlob)
if err != nil {
log.Fatal(err)
}
certificates, err := filepath.Glob(config.CaCertificateFilesGlob)
caCertPool := x509.NewCertPool()
for _, filename := range certificates {
caCert, err := ioutil.ReadFile(filename)
if err != nil {
log.Fatal(err)
}
caCertPool.AppendCertsFromPEM(caCert)
}
caCertPool := x509.NewCertPool()
for _, filename := range certificates {
caCert, err := ioutil.ReadFile(filename)
if err != nil {
log.Fatal(err)
}
caCertPool.AppendCertsFromPEM(caCert)
}
tlsClientConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
}
tlsClientConfig.BuildNameToCertificate()
tlsClientConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caCertPool,
}
tlsClientConfig.BuildNameToCertificate()
dialFunction = func(network, address string) (net.Conn, error) {
return tls.Dial(network, address, tlsClientConfig)
}
dialFunction = func(network, address string) (net.Conn, error) {
return tls.Dial(network, address, tlsClientConfig)
}
clientStateChanges := make(chan *tunnel.ClientStateChange)
@ -302,6 +340,50 @@ func runClientAdminApi(config ClientConfig) {
}
}
func validateCertificate(domain string, request *http.Request) (identifier string, tenantId string, err error) {
if len(request.TLS.PeerCertificates) != 1 {
return "", "", fmt.Errorf("expected exactly 1 TLS client certificate, got %d", len(request.TLS.PeerCertificates))
}
certCommonName := request.TLS.PeerCertificates[0].Subject.CommonName
clientIdDomain := strings.Split(certCommonName, "@")
if len(clientIdDomain) != 2 {
return "", "", fmt.Errorf(
"expected TLS client certificate common name '%s' to match format '<clientId>@<domain>'", certCommonName,
)
}
if clientIdDomain[1] != domain {
return "", "", fmt.Errorf(
"expected TLS client certificate common name domain '%s' to match server domain '%s'",
clientIdDomain[1], domain,
)
}
identifier = clientIdDomain[0]
nodeId := identifier
if strings.Contains(identifier, ".") {
tenantIdNodeId := strings.Split(identifier, ".")
if len(tenantIdNodeId) != 2 {
return "", "", fmt.Errorf(
"expected TLS client certificate common name '%s' to match format '<tenantId>.<nodeId>@<domain>'", certCommonName,
)
}
tenantId = tenantIdNodeId[0]
nodeId = tenantIdNodeId[1]
}
mustMatchRegexp := regexp.MustCompile("(?i)^[a-z0-9]+([a-z0-9-_]*[a-z0-9]+)?$")
if !mustMatchRegexp.MatchString(nodeId) {
return "", "", fmt.Errorf("expected TLS client certificate common name nodeId '%s' to be a valid subdomain", nodeId)
}
if tenantId != "" && !mustMatchRegexp.MatchString(tenantId) {
return "", "", fmt.Errorf("expected TLS client certificate common name tenantId '%s' to be a valid subdomain", tenantId)
}
return identifier, tenantId, nil
}
func runServer(configFileName *string) {
configBytes := getConfigBytes(configFileName)
@ -319,9 +401,10 @@ func runServer(configFileName *string) {
clientStateChangeChannel := make(chan *tunnel.ClientStateChange)
tunnelServerConfig := &tunnel.ServerConfig{
StateChanges: clientStateChangeChannel,
Domain: config.Domain,
DebugLog: config.DebugLog,
StateChanges: clientStateChangeChannel,
ValidateCertificate: validateCertificate,
Domain: config.Domain,
DebugLog: config.DebugLog,
}
server, err = tunnel.NewServer(tunnelServerConfig)
if err != nil {
@ -329,14 +412,27 @@ func runServer(configFileName *string) {
os.Exit(1)
}
clientStates = make(map[string]ClientState)
clientStatesByTenant = make(map[string]map[string]ClientState)
go (func() {
for {
clientStateChange := <-clientStateChangeChannel
clientStatesMutex.Lock()
previousState := ""
currentState := clientStateChange.Current.String()
fromMap, wasInMap := clientStates[clientStateChange.Identifier]
tenantId := ""
if strings.Contains(clientStateChange.Identifier, ".") {
tenantIdNodeId := strings.Split(clientStateChange.Identifier, ".")
if len(tenantIdNodeId) != 2 {
fmt.Printf("runServer(): go func(): can't handle clientStateChange with malformed Identifier '%s' \n", clientStateChange.Identifier)
break
}
tenantId = tenantIdNodeId[0]
}
if _, hasTenant := clientStatesByTenant[tenantId]; !hasTenant {
clientStatesByTenant[tenantId] = map[string]ClientState{}
}
fromMap, wasInMap := clientStatesByTenant[tenantId][clientStateChange.Identifier]
if wasInMap {
previousState = fromMap.CurrentState
} else {
@ -346,7 +442,7 @@ func runServer(configFileName *string) {
log.Printf("runServer(): recieved a client state change with an error: %s \n", clientStateChange.Error)
currentState = "ClientError"
}
clientStates[clientStateChange.Identifier] = ClientState{
clientStatesByTenant[tenantId][clientStateChange.Identifier] = ClientState{
CurrentState: currentState,
LastState: previousState,
}
@ -354,53 +450,48 @@ func runServer(configFileName *string) {
}
})()
if config.UseTls {
certificates, err := filepath.Glob(config.CaCertificateFilesGlob)
if err != nil {
log.Fatal(err)
}
certificates, err := filepath.Glob(config.CaCertificateFilesGlob)
caCertPool := x509.NewCertPool()
for _, filename := range certificates {
log.Printf("loading certificate %s, clients who have a key signed by this certificat will be allowed to connect", filename)
caCert, err := ioutil.ReadFile(filename)
if err != nil {
log.Fatal(err)
}
caCertPool.AppendCertsFromPEM(caCert)
}
caCertPool := x509.NewCertPool()
for _, filename := range certificates {
log.Printf("loading certificate %s, clients who have a key signed by this certificat will be allowed to connect", filename)
caCert, err := ioutil.ReadFile(filename)
if err != nil {
log.Fatal(err)
}
caCertPool.AppendCertsFromPEM(caCert)
}
tlsConfig := &tls.Config{
ClientCAs: caCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
}
tlsConfig.BuildNameToCertificate()
httpsManagementServer := &http.Server{
Addr: fmt.Sprintf(":%d", config.ListenPort),
TLSConfig: tlsConfig,
Handler: &(ManagementHttpHandler{ControlHandler: server}),
}
log.Print("runServer(): the server should be running now\n")
err = httpsManagementServer.ListenAndServeTLS(config.ServerTlsCertificateFile, config.ServerTlsKeyFile)
panic(err)
} else {
log.Print("runServer(): the server should be running now\n")
err = http.ListenAndServe(fmt.Sprintf(":%d", config.ListenPort), &(ManagementHttpHandler{ControlHandler: server}))
panic(err)
tlsConfig := &tls.Config{
ClientCAs: caCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
}
tlsConfig.BuildNameToCertificate()
httpsManagementServer := &http.Server{
Addr: fmt.Sprintf(":%d", config.ListenPort),
TLSConfig: tlsConfig,
Handler: &(ManagementHttpHandler{
Domain: config.Domain,
ControlHandler: server,
}),
}
log.Print("runServer(): the server should be running now\n")
err = httpsManagementServer.ListenAndServeTLS(config.ServerTlsCertificateFile, config.ServerTlsKeyFile)
panic(err)
}
func setListeners(listenerConfigs []ListenerConfig) (int, string) {
func setListeners(tenantId string, listenerConfigs []ListenerConfig) (int, string) {
currentListenersThatCanKeepRunning := make([]ListenerConfig, 0)
newListenersThatHaveToBeAdded := make([]ListenerConfig, 0)
for _, newListenerConfig := range listenerConfigs {
clientState, everHeardOfClientBefore := clientStates[newListenerConfig.ClientIdentifier]
clientState, everHeardOfClientBefore := clientStatesByTenant[tenantId][newListenerConfig.ClientIdentifier]
if !everHeardOfClientBefore {
return http.StatusNotFound, fmt.Sprintf("Client %s Not Found", newListenerConfig.ClientIdentifier)
}
@ -409,7 +500,7 @@ func setListeners(listenerConfigs []ListenerConfig) (int, string) {
}
}
for _, existingListener := range listeners {
for _, existingListener := range listenersByTenant[tenantId] {
canKeepRunning := false
for _, newListenerConfig := range listenerConfigs {
if compareListenerConfigs(existingListener, newListenerConfig) {
@ -432,7 +523,7 @@ func setListeners(listenerConfigs []ListenerConfig) (int, string) {
for _, newListenerConfig := range listenerConfigs {
hasToBeAdded := true
for _, existingListener := range listeners {
for _, existingListener := range listenersByTenant[tenantId] {
if compareListenerConfigs(existingListener, newListenerConfig) {
hasToBeAdded = false
}
@ -466,7 +557,7 @@ func setListeners(listenerConfigs []ListenerConfig) (int, string) {
}
}
listeners = append(currentListenersThatCanKeepRunning, newListenersThatHaveToBeAdded...)
listenersByTenant[tenantId] = append(currentListenersThatCanKeepRunning, newListenersThatHaveToBeAdded...)
return http.StatusOK, "ok"
@ -483,11 +574,24 @@ func compareListenerConfigs(a, b ListenerConfig) bool {
func (s *ManagementHttpHandler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
_, tenantId, err := validateCertificate(s.Domain, request)
if err != nil {
http.Error(responseWriter, fmt.Sprintf("400 bad request: %s", err.Error()), http.StatusBadRequest)
return
}
if _, hasTenant := clientStatesByTenant[tenantId]; !hasTenant {
clientStatesByTenant[tenantId] = map[string]ClientState{}
}
if _, hasTenant := listenersByTenant[tenantId]; !hasTenant {
listenersByTenant[tenantId] = []ListenerConfig{}
}
switch path.Clean(request.URL.Path) {
case "/clients":
if request.Method == "GET" {
clientStatesMutex.Lock()
bytes, err := json.Marshal(clientStates)
bytes, err := json.Marshal(clientStatesByTenant[tenantId])
clientStatesMutex.Unlock()
if err != nil {
http.Error(responseWriter, "500 JSON Marshal Error", http.StatusInternalServerError)
@ -497,8 +601,8 @@ func (s *ManagementHttpHandler) ServeHTTP(responseWriter http.ResponseWriter, re
responseWriter.Write(bytes)
} else {
responseWriter.Header().Set("Allow", "PUT")
http.Error(responseWriter, "405 Method Not Allowed", http.StatusMethodNotAllowed)
responseWriter.Header().Set("Allow", "GET")
http.Error(responseWriter, "405 Method Not Allowed, try GET", http.StatusMethodNotAllowed)
}
case "/tunnels":
if request.Method == "PUT" {
@ -517,7 +621,7 @@ func (s *ManagementHttpHandler) ServeHTTP(responseWriter http.ResponseWriter, re
return
}
statusCode, errorMessage := setListeners(listenerConfigs)
statusCode, errorMessage := setListeners(tenantId, listenerConfigs)
if statusCode != 200 {
http.Error(responseWriter, errorMessage, statusCode)
@ -535,14 +639,14 @@ func (s *ManagementHttpHandler) ServeHTTP(responseWriter http.ResponseWriter, re
}
} else {
responseWriter.Header().Set("Allow", "PUT")
http.Error(responseWriter, "405 Method Not Allowed", http.StatusMethodNotAllowed)
http.Error(responseWriter, "405 Method Not Allowed, try PUT", http.StatusMethodNotAllowed)
}
case "/ping":
if request.Method == "GET" {
fmt.Fprint(responseWriter, "pong")
} else {
responseWriter.Header().Set("Allow", "GET")
http.Error(responseWriter, "405 method not allowed", http.StatusMethodNotAllowed)
http.Error(responseWriter, "405 method not allowed, try GET", http.StatusMethodNotAllowed)
}
default:
s.ControlHandler.ServeHTTP(responseWriter, request)


+ 0
- 3
tunnel-lib/client.go View File

@ -431,9 +431,6 @@ func (c *Client) connect(identifier, serverAddr string) error {
if err != nil {
return fmt.Errorf("error creating request to %s: %s", remoteURL, err)
}
req.Header.Set(proto.ClientIdentifierHeader, identifier)
if c.config.DebugLog {
log.Printf("Client.connect(): Writing request to TCP: %+v\n", req)
}


+ 0
- 3
tunnel-lib/proto/proto.go View File

@ -5,9 +5,6 @@ const (
// ControlPath is http.Handler url path for control connection.
ControlPath = "/_controlPath/"
// ClientIdentifierHeader is header carrying information about tunnel identifier.
ClientIdentifierHeader = "X-Threshold-ClientId"
// control messages
// Connected is message sent by server to client when control connection was established.


+ 20
- 22
tunnel-lib/server.go View File

@ -70,6 +70,9 @@ type Server struct {
// the domain of the server, used for validating clientIds
domain string
// see ServerConfig.ValidateCertificate comment
validateCertificate func(domain string, request *http.Request) (identifier string, tenantId string, err error)
// yamuxConfig is passed to new yamux.Session's
yamuxConfig *yamux.Config
@ -91,6 +94,15 @@ type ServerConfig struct {
// the domain of the server, used for validating clientIds
Domain string
// function that analyzes the TLS client certificate of the request.
// this is based on the CommonName attribute of the TLS certificate.
// If we are in multi-tenant mode, it must be formatted like `<tenantId>.<nodeId>@<domain>`
// otherwise, it must be formatted like `<nodeId>@<domain>`
// <domain> must match the configured Domain of this Threshold server
// the identifier it returns will be `<tenantId>.<nodeId>` or `<nodeId>`.
// the tenantId it returns will be `<tenantId>` or ""
ValidateCertificate func(domain string, request *http.Request) (identifier string, tenantId string, err error)
// YamuxConfig defines the config which passed to every new yamux.Session. If nil
// yamux.DefaultConfig() is used.
YamuxConfig *yamux.Config
@ -285,28 +297,13 @@ func (s *Server) dial(identifier string, service string) (net.Conn, error) {
// controlHandler is used to capture incoming tunnel connect requests into raw
// tunnel TCP connections.
func (s *Server) controlHandler(w http.ResponseWriter, r *http.Request) (ctErr error) {
identifier := r.Header.Get(proto.ClientIdentifierHeader)
// When TLS is turned on, the Client Authentication certificate is required, so in that case
// if we got to this point, we should make sure
// the ClientIdentifier header matches the CommonName on the client cert.
// https://stackoverflow.com/questions/31751764/get-remote-ssl-certificate-in-golang
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
cn := r.TLS.PeerCertificates[0].Subject.CommonName
if fmt.Sprintf("%s@%s", identifier, s.domain) != cn {
return fmt.Errorf(
"\"%s: %s\" does not match TLS certificate CommonName %s",
proto.ClientIdentifierHeader, identifier, cn,
)
}
}
// We will allow clients to connect even if they are not configured to be used yet.
// In this case they have an empty set of listening front-end ports.
// ok := s.hasIdentifier(identifier)
// if !ok {
// return fmt.Errorf("no host associated for identifier %s. please use server.AddAddr()", identifier)
// }
clientId, tenantId, err := s.validateCertificate(s.domain, r)
fmt.Println(tenantId)
if err != nil {
return err
}
identifier := clientId
ct, ok := s.getControl(identifier)
if ok {
@ -613,7 +610,8 @@ func (s *Server) checkConnect(fn func(w http.ResponseWriter, r *http.Request) er
if err := fn(w, r); err != nil {
log.Printf("Server.checkConnect(): Handler err: %v\n", err.Error())
if identifier := r.Header.Get(proto.ClientIdentifierHeader); identifier != "" {
identifier, _, err := s.validateCertificate(s.domain, r)
if err == nil {
s.onDisconnect(identifier, err)
}


+ 0
- 1
usage-example/client-config.json View File

@ -2,7 +2,6 @@
"DebugLog": false,
"ClientIdentifier": "TestClient1",
"ServerAddr": "localhost:9056",
"UseTls": true,
"ServiceToLocalAddrMap": {
"fooService": "127.0.0.1:9001"
},


+ 0
- 1
usage-example/server-config.json View File

@ -3,7 +3,6 @@
"DebugLog": false,
"Domain": "example.com",
"ListenPort": 9056,
"UseTls": true,
"CaCertificateFilesGlob": "InternalCA+chain.crt",
"ServerTlsKeyFile": "localhost.key",
"ServerTlsCertificateFile": "localhost+chain.crt"

Loading…
Cancel
Save