Commit 695563fc authored by Ian Craggs's avatar Ian Craggs

Add TLS hostname check and CApath #420 #418

parent 8276dcd2
...@@ -2293,7 +2293,7 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) ...@@ -2293,7 +2293,7 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options)
} }
if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */ if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */
{ {
if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 1) if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 2)
{ {
rc = MQTTASYNC_BAD_STRUCTURE; rc = MQTTASYNC_BAD_STRUCTURE;
goto exit; goto exit;
...@@ -2393,6 +2393,12 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) ...@@ -2393,6 +2393,12 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options)
free((void*)m->c->sslopts->privateKeyPassword); free((void*)m->c->sslopts->privateKeyPassword);
if (m->c->sslopts->enabledCipherSuites) if (m->c->sslopts->enabledCipherSuites)
free((void*)m->c->sslopts->enabledCipherSuites); free((void*)m->c->sslopts->enabledCipherSuites);
if (m->c->sslopts->struct_version >= 2)
{
if (m->c->sslopts->CApath)
free((void*)m->c->sslopts->CApath);
}
free(m->c->sslopts);
free((void*)m->c->sslopts); free((void*)m->c->sslopts);
m->c->sslopts = NULL; m->c->sslopts = NULL;
} }
...@@ -2415,6 +2421,12 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) ...@@ -2415,6 +2421,12 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options)
m->c->sslopts->enableServerCertAuth = options->ssl->enableServerCertAuth; m->c->sslopts->enableServerCertAuth = options->ssl->enableServerCertAuth;
if (m->c->sslopts->struct_version >= 1) if (m->c->sslopts->struct_version >= 1)
m->c->sslopts->sslVersion = options->ssl->sslVersion; m->c->sslopts->sslVersion = options->ssl->sslVersion;
if (m->c->sslopts->struct_version >= 2)
{
m->c->sslopts->verify = options->ssl->verify;
if (m->c->sslopts->CApath)
m->c->sslopts->CApath = MQTTStrdup(options->ssl->CApath);
}
} }
#else #else
if (options->struct_version != 0 && options->ssl) if (options->struct_version != 0 && options->ssl)
...@@ -2893,7 +2905,8 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) ...@@ -2893,7 +2905,8 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
if (m->c->session != NULL) if (m->c->session != NULL)
if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1) if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1)
Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical"); Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical");
rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket); rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket,
m->serverURI, m->c->sslopts->verify);
if (rc == TCPSOCKET_INTERRUPTED) if (rc == TCPSOCKET_INTERRUPTED)
{ {
rc = MQTTCLIENT_SUCCESS; /* the connect is still in progress */ rc = MQTTCLIENT_SUCCESS; /* the connect is still in progress */
...@@ -2936,7 +2949,8 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) ...@@ -2936,7 +2949,8 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
#if defined(OPENSSL) #if defined(OPENSSL)
else if (m->c->connect_state == 2) /* SSL connect sent - wait for completion */ else if (m->c->connect_state == 2) /* SSL connect sent - wait for completion */
{ {
if ((rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket)) != 1) if ((rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket,
m->serverURI, m->c->sslopts->verify)) != 1)
goto exit; goto exit;
if(!m->c->cleansession && m->c->session == NULL) if(!m->c->cleansession && m->c->session == NULL)
......
...@@ -698,9 +698,22 @@ typedef struct ...@@ -698,9 +698,22 @@ typedef struct
*/ */
int sslVersion; int sslVersion;
/**
* Whether to carry out post-connect checks, including that a certificate
* matches the given host name.
* Exists only if struct_version >= 2
*/
int verify;
/**
* From the OpenSSL documentation:
* If CApath is not NULL, it points to a directory containing CA certificates in PEM format.
* Exists only if struct_version >= 2
*/
const char* CApath;
} MQTTAsync_SSLOptions; } MQTTAsync_SSLOptions;
#define MQTTAsync_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 1, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT } #define MQTTAsync_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 2, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL }
/** /**
* MQTTAsync_connectOptions defines several settings that control the way the * MQTTAsync_connectOptions defines several settings that control the way the
......
...@@ -668,7 +668,8 @@ static thread_return_type WINAPI MQTTClient_run(void* n) ...@@ -668,7 +668,8 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
#if defined(OPENSSL) #if defined(OPENSSL)
else if (m->c->connect_state == 2 && !Thread_check_sem(m->connect_sem)) else if (m->c->connect_state == 2 && !Thread_check_sem(m->connect_sem))
{ {
rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket); rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket,
m->serverURI, m->c->sslopts->verify);
if (rc == 1 || rc == SSL_FATAL) if (rc == 1 || rc == SSL_FATAL)
{ {
if (rc == 1 && !m->c->cleansession && m->c->session == NULL) if (rc == 1 && !m->c->cleansession && m->c->session == NULL)
...@@ -909,7 +910,8 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt ...@@ -909,7 +910,8 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
if (m->c->session != NULL) if (m->c->session != NULL)
if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1) if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1)
Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical"); Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical");
rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket); rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket,
m->serverURI, m->c->sslopts->verify);
if (rc == TCPSOCKET_INTERRUPTED) if (rc == TCPSOCKET_INTERRUPTED)
m->c->connect_state = 2; /* the connect is still in progress */ m->c->connect_state = 2; /* the connect is still in progress */
else if (rc == SSL_FATAL) else if (rc == SSL_FATAL)
...@@ -1111,6 +1113,11 @@ static int MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectOptions* o ...@@ -1111,6 +1113,11 @@ static int MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectOptions* o
free((void*)m->c->sslopts->privateKeyPassword); free((void*)m->c->sslopts->privateKeyPassword);
if (m->c->sslopts->enabledCipherSuites) if (m->c->sslopts->enabledCipherSuites)
free((void*)m->c->sslopts->enabledCipherSuites); free((void*)m->c->sslopts->enabledCipherSuites);
if (m->c->sslopts->struct_version >= 2)
{
if (m->c->sslopts->CApath)
free((void*)m->c->sslopts->CApath);
}
free(m->c->sslopts); free(m->c->sslopts);
m->c->sslopts = NULL; m->c->sslopts = NULL;
} }
...@@ -1133,6 +1140,12 @@ static int MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectOptions* o ...@@ -1133,6 +1140,12 @@ static int MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectOptions* o
m->c->sslopts->enableServerCertAuth = options->ssl->enableServerCertAuth; m->c->sslopts->enableServerCertAuth = options->ssl->enableServerCertAuth;
if (m->c->sslopts->struct_version >= 1) if (m->c->sslopts->struct_version >= 1)
m->c->sslopts->sslVersion = options->ssl->sslVersion; m->c->sslopts->sslVersion = options->ssl->sslVersion;
if (m->c->sslopts->struct_version >= 2)
{
m->c->sslopts->verify = options->ssl->verify;
if (m->c->sslopts->CApath)
m->c->sslopts->CApath = MQTTStrdup(options->ssl->CApath);
}
} }
#endif #endif
...@@ -1207,7 +1220,7 @@ int MQTTClient_connect(MQTTClient handle, MQTTClient_connectOptions* options) ...@@ -1207,7 +1220,7 @@ int MQTTClient_connect(MQTTClient handle, MQTTClient_connectOptions* options)
#if defined(OPENSSL) #if defined(OPENSSL)
if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */ if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */
{ {
if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 1) if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 2)
{ {
rc = MQTTCLIENT_BAD_STRUCTURE; rc = MQTTCLIENT_BAD_STRUCTURE;
goto exit; goto exit;
...@@ -1826,7 +1839,8 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r ...@@ -1826,7 +1839,8 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
#if defined(OPENSSL) #if defined(OPENSSL)
else if (m->c->connect_state == 2) else if (m->c->connect_state == 2)
{ {
*rc = SSLSocket_connect(m->c->net.ssl, sock); *rc = SSLSocket_connect(m->c->net.ssl, sock,
m->serverURI, m->c->sslopts->verify);
if (*rc == SSL_FATAL) if (*rc == SSL_FATAL)
break; break;
else if (*rc == 1) /* rc == 1 means SSL connect has finished and succeeded */ else if (*rc == 1) /* rc == 1 means SSL connect has finished and succeeded */
......
...@@ -544,9 +544,23 @@ typedef struct ...@@ -544,9 +544,23 @@ typedef struct
*/ */
int sslVersion; int sslVersion;
/**
* Whether to carry out post-connect checks, including that a certificate
* matches the given host name.
* Exists only if struct_version >= 2
*/
int verify;
/**
* From the OpenSSL documentation:
* If CApath is not NULL, it points to a directory containing CA certificates in PEM format.
* Exists only if struct_version >= 2
*/
const char* CApath;
} MQTTClient_SSLOptions; } MQTTClient_SSLOptions;
#define MQTTClient_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 1, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT } #define MQTTClient_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 2, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL }
/** /**
* MQTTClient_connectOptions defines several settings that control the way the * MQTTClient_connectOptions defines several settings that control the way the
......
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2009, 2017 IBM Corp. * Copyright (c) 2009, 2018 IBM Corp.
* *
* All rights reserved. This program and the accompanying materials * All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0 * are made available under the terms of the Eclipse Public License v1.0
...@@ -694,6 +694,11 @@ void MQTTProtocol_freeClient(Clients* client) ...@@ -694,6 +694,11 @@ void MQTTProtocol_freeClient(Clients* client)
free((void*)client->sslopts->privateKeyPassword); free((void*)client->sslopts->privateKeyPassword);
if (client->sslopts->enabledCipherSuites) if (client->sslopts->enabledCipherSuites)
free((void*)client->sslopts->enabledCipherSuites); free((void*)client->sslopts->enabledCipherSuites);
if (client->sslopts->struct_version >= 2)
{
if (client->sslopts->CApath)
free((void*)client->sslopts->CApath);
}
free(client->sslopts); free(client->sslopts);
} }
#endif #endif
......
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2009, 2017 IBM Corp. * Copyright (c) 2009, 2018 IBM Corp.
* *
* All rights reserved. This program and the accompanying materials * All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0 * are made available under the terms of the Eclipse Public License v1.0
...@@ -116,7 +116,8 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersi ...@@ -116,7 +116,8 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersi
{ {
if (SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, addr) == 1) if (SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, addr) == 1)
{ {
rc = SSLSocket_connect(aClient->net.ssl, aClient->net.socket); rc = SSLSocket_connect(aClient->net.ssl, aClient->net.socket,
addr, aClient->sslopts->verify);
if (rc == TCPSOCKET_INTERRUPTED) if (rc == TCPSOCKET_INTERRUPTED)
aClient->connect_state = 2; /* SSL connect called - wait for completion */ aClient->connect_state = 2; /* SSL connect called - wait for completion */
} }
......
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2009, 2017 IBM Corp. * Copyright (c) 2009, 2018 IBM Corp.
* *
* All rights reserved. This program and the accompanying materials * All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0 * are made available under the terms of the Eclipse Public License v1.0
...@@ -34,12 +34,14 @@ ...@@ -34,12 +34,14 @@
#include "Log.h" #include "Log.h"
#include "StackTrace.h" #include "StackTrace.h"
#include "Socket.h" #include "Socket.h"
char* MQTTProtocol_addressPort(const char* uri, int* port);
#include "Heap.h" #include "Heap.h"
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/err.h> #include <openssl/err.h>
#include <openssl/crypto.h> #include <openssl/crypto.h>
#include <openssl/x509v3.h>
extern Sockets s; extern Sockets s;
...@@ -653,14 +655,18 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, ...@@ -653,14 +655,18 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
return rc; return rc;
} }
/*
int SSLSocket_connect(SSL* ssl, int sock) * Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure
*/
int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify)
{ {
int rc = 0; int rc = 0;
FUNC_ENTRY; FUNC_ENTRY;
printf("SSLSocket_connect verify %d\n", verify);
rc = SSL_connect(ssl); rc = SSL_connect(ssl);
printf("SSLSocket_connect rc %d\n", rc);
if (rc != 1) if (rc != 1)
{ {
int error; int error;
...@@ -670,6 +676,28 @@ int SSLSocket_connect(SSL* ssl, int sock) ...@@ -670,6 +676,28 @@ int SSLSocket_connect(SSL* ssl, int sock)
if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE) if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)
rc = TCPSOCKET_INTERRUPTED; rc = TCPSOCKET_INTERRUPTED;
} }
else if (verify == 1)
{
char* peername = NULL;
int port;
char* addr = NULL;
X509* cert = SSL_get_peer_certificate(ssl);
addr = MQTTProtocol_addressPort(hostname, &port);
rc = X509_check_host(cert, addr, strlen(addr), 0, &peername);
if (rc == 0)
rc = SOCKET_ERROR;
Log(TRACE_MIN, -1, "rc from X509_check_host is %d", rc);
printf("rc from X509_check_host is %d", rc);
Log(TRACE_MIN, -1, "peername from X509_check_host is %s", peername);
printf("Peername %s\n", peername);
if (cert)
X509_free(cert);
if (addr != hostname)
free(addr);
}
FUNC_EXIT_RC(rc); FUNC_EXIT_RC(rc);
return rc; return rc;
......
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2009, 2017 IBM Corp. * Copyright (c) 2009, 2018 IBM Corp.
* *
* All rights reserved. This program and the accompanying materials * All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0 * are made available under the terms of the Eclipse Public License v1.0
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* Contributors: * Contributors:
* Ian Craggs, Allan Stockdill-Mander - initial implementation * Ian Craggs, Allan Stockdill-Mander - initial implementation
* Ian Craggs - SNI support * Ian Craggs - SNI support
* Ian Craggs - post connect checks and CApath
*******************************************************************************/ *******************************************************************************/
#if !defined(SSLSOCKET_H) #if !defined(SSLSOCKET_H)
#define SSLSOCKET_H #define SSLSOCKET_H
...@@ -43,7 +44,7 @@ char *SSLSocket_getdata(SSL* ssl, int socket, size_t bytes, size_t* actual_len); ...@@ -43,7 +44,7 @@ char *SSLSocket_getdata(SSL* ssl, int socket, size_t bytes, size_t* actual_len);
int SSLSocket_close(networkHandles* net); int SSLSocket_close(networkHandles* net);
int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int* frees); int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int* frees);
int SSLSocket_connect(SSL* ssl, int socket); int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify);
int SSLSocket_getPendingRead(void); int SSLSocket_getPendingRead(void);
int SSLSocket_continueWrite(pending_writes* pw); int SSLSocket_continueWrite(pending_writes* pw);
......
...@@ -749,6 +749,9 @@ int test2a(struct Options options) ...@@ -749,6 +749,9 @@ int test2a(struct Options options)
opts.ssl->privateKeyPassword = options.client_key_pass; opts.ssl->privateKeyPassword = options.client_key_pass;
//opts.ssl->enabledCipherSuites = "DEFAULT"; //opts.ssl->enabledCipherSuites = "DEFAULT";
//opts.ssl->enabledServerCertAuth = 1; //opts.ssl->enabledServerCertAuth = 1;
opts.ssl->verify = 1;
printf("enableServerCertAuth %d\n", opts.ssl->enableServerCertAuth);
printf("verify %d\n", opts.ssl->verify);
rc = MQTTAsync_setCallbacks(c, &tc, NULL, asyncTestMessageArrived, rc = MQTTAsync_setCallbacks(c, &tc, NULL, asyncTestMessageArrived,
asyncTestOnDeliveryComplete); asyncTestOnDeliveryComplete);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment