Commit e7b9ffec authored by Ian Craggs's avatar Ian Craggs

fix for Bug 406030 - SSL connection memory leaks

parent f0ff597d
......@@ -392,6 +392,9 @@ void MQTTAsync_terminate(void)
ListFree(commands);
handles = NULL;
Socket_outTerminate();
#if defined(OPENSSL)
SSLSocket_terminate();
#endif
#if defined(HEAP_H)
Heap_terminate();
#endif
......@@ -1567,8 +1570,7 @@ void MQTTProtocol_closeSession(Clients* client, int sendwill)
if (client->connected || client->connect_state)
MQTTPacket_send_disconnect(&client->net, client->clientID);
#if defined(OPENSSL)
if (client->net.ssl)
SSLSocket_close(client->net.ssl);
SSLSocket_close(&client->net);
#endif
Socket_close(client->net.socket);
client->net.socket = 0;
......@@ -2298,8 +2300,11 @@ int MQTTAsync_connecting(MQTTAsyncs* m)
#if defined(OPENSSL)
if (m->ssl)
{
if ((m->c->net.ssl = SSLSocket_setSocketForSSL(m->c->net.socket, m->c->sslopts)) != NULL)
if (SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts) != MQTTASYNC_SUCCESS)
{
if (m->c->session != NULL)
if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1)
Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical");
rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket);
if (rc == -1)
m->c->connect_state = 2;
......@@ -2308,6 +2313,8 @@ int MQTTAsync_connecting(MQTTAsyncs* m)
rc = SOCKET_ERROR;
goto exit;
}
else if (rc == 1 && !m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl);
}
else
{
......@@ -2400,6 +2407,13 @@ MQTTPacket* MQTTAsync_cycle(int* sock, unsigned long timeout, int* rc)
*rc = MQTTAsync_connecting(m);
else
pack = MQTTPacket_Factory(&m->c->net, rc);
if ((m->c->connect_state == 3) && (*rc == SOCKET_ERROR))
{
Log( TRACE_MINIMUM, -1, "CONNECT sent but MQTTPacket_Factory has returned SOCKET_ERROR, calling connect.onFailure");
MQTTProtocol_closeSession(m->c, 0);
if (m->connect.onFailure)
(*(m->connect.onFailure))(m->connect.context, NULL);
}
}
if (pack)
{
......
......@@ -289,6 +289,9 @@ void MQTTClient_terminate(void)
ListFree(handles);
handles = NULL;
Socket_outTerminate();
#if defined(OPENSSL)
SSLSocket_terminate();
#endif
#if defined(HEAP_H)
Heap_terminate();
#endif
......@@ -347,6 +350,10 @@ void MQTTClient_destroy(MQTTClient* handle)
}
if (m->serverURI)
free(m->serverURI);
Thread_destroy_sem(m->connect_sem);
Thread_destroy_sem(m->connack_sem);
Thread_destroy_sem(m->suback_sem);
Thread_destroy_sem(m->unsuback_sem);
if (!ListRemove(handles, m))
Log(LOG_ERROR, -1, "free error");
*handle = NULL;
......@@ -628,8 +635,7 @@ void MQTTProtocol_closeSession(Clients* client, int sendwill)
client->connect_state = 0;
}
#if defined(OPENSSL)
if (client->net.ssl)
SSLSocket_close(client->net.ssl);
SSLSocket_close(&client->net);
#endif
Socket_close(client->net.socket);
client->net.socket = 0;
......@@ -814,7 +820,7 @@ int MQTTClient_connect(MQTTClient handle, MQTTClient_connectOptions* options)
#if defined(OPENSSL)
if (m->ssl)
{
if ((m->c->net.ssl = SSLSocket_setSocketForSSL(m->c->net.socket, m->c->sslopts)) != NULL)
if (SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts) != MQTTCLIENT_SUCCESS)
{
if (m->c->session != NULL)
if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1)
......
......@@ -106,7 +106,7 @@ int MQTTProtocol_connect(char* ip_address, Clients* aClient)
#if defined(OPENSSL)
if (ssl)
{
if((aClient->net.ssl = SSLSocket_setSocketForSSL(aClient->net.socket, aClient->sslopts)) != NULL)
if (SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts) != 1)
{
rc = SSLSocket_connect(aClient->net.ssl, aClient->net.socket);
if (rc == -1)
......
......@@ -302,6 +302,7 @@ int SSLSocket_initialize()
int rc = 0;
/*int prc;*/
int i;
int lockMemSize;
FUNC_ENTRY;
......@@ -317,12 +318,17 @@ int SSLSocket_initialize()
OpenSSL_add_all_algorithms();
sslLocks = calloc(CRYPTO_num_locks(), sizeof(ssl_mutex_type));
lockMemSize = CRYPTO_num_locks() * sizeof(ssl_mutex_type);
sslLocks = malloc(lockMemSize);
if (!sslLocks)
{
rc = -1;
goto exit;
}
else
memset(sslLocks, 0, lockMemSize);
for (i = 0; i < CRYPTO_num_locks(); i++)
{
/* prc = */SSL_create_mutex(&sslLocks[i]);
......@@ -338,24 +344,31 @@ exit:
return rc;
}
SSL_CTX* SSLSocket_createContext(int socket, MQTTClient_SSLOptions* opts)
void SSLSocket_terminate()
{
FUNC_ENTRY;
free(sslLocks);
FUNC_EXIT;
}
int SSLSocket_createContext(networkHandles* net, MQTTClient_SSLOptions* opts)
{
int rc = 1;
SSL_CTX* ctx = NULL;
char* ciphers = NULL;
FUNC_ENTRY;
if ((ctx = SSL_CTX_new(SSLv23_client_method())) == NULL) /* SSLv23 for compatibility with SSLv2, SSLv3 and TLSv1 */
if (net->ctx == NULL)
if ((net->ctx = SSL_CTX_new(SSLv23_client_method())) == NULL) /* SSLv23 for compatibility with SSLv2, SSLv3 and TLSv1 */
{
SSLSocket_error("SSL_CTX_new", NULL, socket, rc);
SSLSocket_error("SSL_CTX_new", NULL, net->socket, rc);
goto exit;
}
if (opts->keyStore)
{
if ((rc = SSL_CTX_use_certificate_chain_file(ctx, opts->keyStore)) != 1)
if ((rc = SSL_CTX_use_certificate_chain_file(net->ctx, opts->keyStore)) != 1)
{
SSLSocket_error("SSL_CTX_use_certificate_chain_file", NULL, socket, rc);
SSLSocket_error("SSL_CTX_use_certificate_chain_file", NULL, net->socket, rc);
goto free_ctx; /*If we can't load the certificate (chain) file then loading the privatekey won't work either as it needs a matching cert already loaded */
}
......@@ -364,29 +377,29 @@ SSL_CTX* SSLSocket_createContext(int socket, MQTTClient_SSLOptions* opts)
if (opts->privateKeyPassword != NULL)
{
SSL_CTX_set_default_passwd_cb(ctx, pem_passwd_cb);
SSL_CTX_set_default_passwd_cb_userdata(ctx, (void*)opts->privateKeyPassword);
SSL_CTX_set_default_passwd_cb(net->ctx, pem_passwd_cb);
SSL_CTX_set_default_passwd_cb_userdata(net->ctx, (void*)opts->privateKeyPassword);
}
/* support for ASN.1 == DER format? DER can contain only one certificate? */
if ((rc = SSL_CTX_use_PrivateKey_file(ctx, opts->privateKey, SSL_FILETYPE_PEM)) != 1)
if ((rc = SSL_CTX_use_PrivateKey_file(net->ctx, opts->privateKey, SSL_FILETYPE_PEM)) != 1)
{
SSLSocket_error("SSL_CTX_use_PrivateKey_file", NULL, socket, rc);
SSLSocket_error("SSL_CTX_use_PrivateKey_file", NULL, net->socket, rc);
goto free_ctx;
}
}
if (opts->trustStore)
{
if ((rc = SSL_CTX_load_verify_locations(ctx, opts->trustStore, NULL)) != 1)
if ((rc = SSL_CTX_load_verify_locations(net->ctx, opts->trustStore, NULL)) != 1)
{
SSLSocket_error("SSL_CTX_load_verify_locations", NULL, socket, rc);
SSLSocket_error("SSL_CTX_load_verify_locations", NULL, net->socket, rc);
goto free_ctx;
}
}
else if ((rc = SSL_CTX_set_default_verify_paths(ctx)) != 1)
else if ((rc = SSL_CTX_set_default_verify_paths(net->ctx)) != 1)
{
SSLSocket_error("SSL_CTX_set_default_verify_paths", NULL, socket, rc);
SSLSocket_error("SSL_CTX_set_default_verify_paths", NULL, net->socket, rc);
goto free_ctx;
}
......@@ -395,56 +408,54 @@ SSL_CTX* SSLSocket_createContext(int socket, MQTTClient_SSLOptions* opts)
else
ciphers = opts->enabledCipherSuites;
if ((rc = SSL_CTX_set_cipher_list(ctx, ciphers)) != 1)
if ((rc = SSL_CTX_set_cipher_list(net->ctx, ciphers)) != 1)
{
SSLSocket_error("SSL_CTX_set_cipher_list", NULL, socket, rc);
SSLSocket_error("SSL_CTX_set_cipher_list", NULL, net->socket, rc);
goto free_ctx;
}
SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
SSL_CTX_set_mode(net->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
goto exit;
free_ctx:
SSL_CTX_free(ctx);
ctx = NULL;
SSL_CTX_free(net->ctx);
net->ctx = NULL;
exit:
FUNC_EXIT;
return ctx;
FUNC_EXIT_RC(rc);
return rc;
}
SSL* SSLSocket_setSocketForSSL(int socket, MQTTClient_SSLOptions* opts)
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts)
{
int rc = 1;
SSL_CTX* ctx = NULL;
SSL* ssl = NULL;
FUNC_ENTRY;
if ((ctx = SSLSocket_createContext(socket, opts)) != NULL)
if (net->ctx != NULL || (rc = SSLSocket_createContext(net, opts)) == 1)
{
int i;
SSL_CTX_set_info_callback(ctx, SSL_CTX_info_callback);
SSL_CTX_set_info_callback(net->ctx, SSL_CTX_info_callback);
if (opts->enableServerCertAuth)
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
SSL_CTX_set_verify(net->ctx, SSL_VERIFY_PEER, NULL);
ssl = SSL_new(ctx);
net->ssl = SSL_new(net->ctx);
/* Log all ciphers available to the SSL sessions (loaded in ctx) */
for (i = 0; ;i++)
{
const char* cipher = SSL_get_cipher_list(ssl, i);
const char* cipher = SSL_get_cipher_list(net->ssl, i);
if (cipher == NULL) break;
Log(TRACE_MIN, 1, "SSL cipher available: %d:%s", i, cipher);
}
if ((rc = SSL_set_fd(ssl, socket)) != 1)
SSLSocket_error("SSL_set_fd", ssl, socket, rc);
if ((rc = SSL_set_fd(net->ssl, net->socket)) != 1)
SSLSocket_error("SSL_set_fd", net->ssl, net->socket, rc);
}
FUNC_EXIT_RC(rc);
return ssl;
return rc;
}
......@@ -565,10 +576,28 @@ exit:
return buf;
}
void SSLSocket_destroyContext(networkHandles* net)
{
FUNC_ENTRY;
if (net->ctx)
SSL_CTX_free(net->ctx);
net->ctx = NULL;
FUNC_EXIT;
}
int SSLSocket_close(SSL* ssl)
int SSLSocket_close(networkHandles* net)
{
return SSL_shutdown(ssl);
int rc = 1;
FUNC_ENTRY;
if (net->ssl) {
rc = SSL_shutdown(net->ssl);
SSL_free(net->ssl);
net->ssl = NULL;
}
SSLSocket_destroyContext(net);
FUNC_EXIT_RC(rc);
return rc;
}
......
......@@ -29,15 +29,17 @@
#include <openssl/ssl.h>
#include "SocketBuffer.h"
#include "Clients.h"
#define URI_SSL "ssl://"
int SSLSocket_initialize();
SSL* SSLSocket_setSocketForSSL(int socket, MQTTClient_SSLOptions* opts);
void SSLSocket_terminate();
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts);
int SSLSocket_getch(SSL* ssl, int socket, char* c);
char *SSLSocket_getdata(SSL* ssl, int socket, int bytes, int* actual_len);
int SSLSocket_close(SSL* ssl);
int SSLSocket_close(networkHandles* net);
int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, int buf0len, int count, char** buffers, int* buflens);
int SSLSocket_connect(SSL* ssl, int socket);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment