Commit c897ebac authored by Ian Craggs's avatar Ian Craggs

Merge branch 'develop' of https://github.com/lt-holman/paho.mqtt.c into lt-holman-develop

parents 5f68068c 73cc3ea9
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#include "Base64.h"
#if defined(WIN32) || defined(WIN64)
#pragma comment(lib, "crypt32.lib")
#include <Windows.h>
#include <WinCrypt.h>
b64_size_t Base64_decode( b64_data_t *out, b64_size_t out_len, const char *in, b64_size_t in_len )
{
b64_size_t ret = 0u;
DWORD dw_out_len = (DWORD)out_len;
if ( CryptStringToBinaryA( in, in_len, CRYPT_STRING_BASE64, out, &dw_out_len, NULL, NULL ) )
ret = (b64_size_t)dw_out_len;
return ret;
}
b64_size_t Base64_encode( char *out, b64_size_t out_len, const b64_data_t *in, b64_size_t in_len )
{
b64_size_t ret = 0u;
DWORD dw_out_len = (DWORD)out_len;
if ( CryptBinaryToStringA( in, in_len, CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF, out, &dw_out_len ) )
ret = (b64_size_t)dw_out_len;
return ret;
}
#else /* if defined(WIN32) || defined(WIN64) */
#if defined(OPENSSL)
#include <openssl/bio.h>
#include <openssl/evp.h>
static b64_size_t Base64_encodeDecode(
char *out, b64_size_t out_len, const char *in, b64_size_t in_len, int encode )
{
b64_size_t ret = 0u;
if ( in_len > 0u )
{
int rv;
BIO *bio, *b64, *b_in, *b_out;
b64 = BIO_new(BIO_f_base64());
bio = BIO_new(BIO_s_mem());
b64 = BIO_push(b64, bio);
BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL); /* ignore new-lines */
if ( encode )
{
b_in = bio;
b_out = b64;
}
else
{
b_in = b64;
b_out = bio;
}
rv = BIO_write(b_out, in, (int)in_len);
BIO_flush(b_out); /* indicate end of encoding */
if ( rv > 0 )
{
rv = BIO_read(b_in, out, (int)out_len);
if ( rv > 0 )
{
ret = (b64_size_t)rv;
if ( out_len > ret )
out[ret] = '\0';
}
}
BIO_free_all(b64); /* free all used memory */
}
return ret;
}
b64_size_t Base64_decode( b64_data_t *out, b64_size_t out_len, const char *in, b64_size_t in_len )
{
return Base64_encodeDecode( (char*)out, out_len, in, in_len, 0 );
}
b64_size_t Base64_encode( char *out, b64_size_t out_len, const b64_data_t *in, b64_size_t in_len )
{
return Base64_encodeDecode( out, out_len, (const char*)in, in_len, 1 );
}
#else /* if defined(OPENSSL) */
b64_size_t Base64_decode( b64_data_t *out, b64_size_t out_len, const char *in, b64_size_t in_len )
{
#define NV 64
static const unsigned char BASE64_DECODE_TABLE[] =
{
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 0-15 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 16-31 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, 62, NV, NV, NV, 63, /* 32-47 */
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, NV, NV, NV, NV, NV, NV, /* 48-63 */
NV, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, /* 64-79 */
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, NV, NV, NV, NV, NV, /* 80-95 */
NV, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, /* 96-111 */
41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, NV, NV, NV, NV, NV, /* 112-127 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 128-143 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 144-159 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 160-175 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 176-191 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 192-207 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 208-223 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 224-239 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV /* 240-255 */
};
b64_size_t ret = 0u;
b64_size_t out_count = 0u;
/* in valid base64, length must be multiple of 4's: 0, 4, 8, 12, etc */
while ( in_len > 3u && out_count < out_len )
{
int i;
unsigned char c[4];
for ( i = 0; i < 4; ++i, ++in )
c[i] = BASE64_DECODE_TABLE[(int)(*in)];
in_len -= 4u;
/* first byte */
*out = c[0] << 2;
*out |= (c[1] & ~0xF) >> 4;
++out;
++out_count;
if ( out_count < out_len )
{
/* second byte */
*out = (c[1] & 0xF) << 4;
if ( c[2] < NV )
{
*out |= (c[2] & ~0x3) >> 2;
++out;
++out_count;
if ( out_count < out_len )
{
/* third byte */
*out = (c[2] & 0x3) << 6;
if ( c[3] < NV )
{
*out |= c[3];
++out;
++out_count;
}
else
in_len = 0u;
}
}
else
in_len = 0u;
}
}
if ( out_count <= out_len )
{
ret = out_count;
if ( out_count < out_len )
*out = '\0';
}
return ret;
}
b64_size_t Base64_encode( char *out, b64_size_t out_len, const b64_data_t *in, b64_size_t in_len )
{
static const char BASE64_ENCODE_TABLE[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/=";
b64_size_t ret = 0u;
b64_size_t out_count = 0u;
while ( in_len > 0u && out_count < out_len )
{
int i;
unsigned char c[] = { 0, 0, 64, 64 }; /* index of '=' char */
/* first character */
i = *in;
c[0] = (i & ~0x3) >> 2;
/* second character */
c[1] = (i & 0x3) << 4;
--in_len;
if ( in_len > 0u )
{
++in;
i = *in;
c[1] |= (i & ~0xF) >> 4;
/* third character */
c[2] = (i & 0xF) << 2;
--in_len;
if ( in_len > 0u )
{
++in;
i = *in;
c[2] |= (i & ~0x3F) >> 6;
/* fourth character */
c[3] = (i & 0x3F);
--in_len;
++in;
}
}
/* encode the characters */
out_count += 4u;
for ( i = 0; i < 4 && out_count <= out_len; ++i, ++out )
*out = BASE64_ENCODE_TABLE[c[i]];
}
if ( out_count <= out_len )
{
if ( out_count < out_len )
*out = '\0';
ret = out_count;
}
return ret;
}
#endif /* else if defined(OPENSSL) */
#endif /* if else defined(WIN32) || defined(WIN64) */
b64_size_t Base64_decodeLength( const char *in, b64_size_t in_len )
{
b64_size_t pad = 0u;
if ( in && in_len > 1u )
pad += ( in[in_len - 2u] == '=' ? 1u : 0u );
if ( in && in_len > 0u )
pad += ( in[in_len - 1u] == '=' ? 1u : 0u );
return (in_len / 4u * 3u) - pad;
}
b64_size_t Base64_encodeLength( const b64_data_t *in, b64_size_t in_len )
{
return ((4u * in_len / 3u) + 3u) & ~0x3;
}
#if defined(BASE64_TEST)
#include <stdio.h>
#include <string.h>
#define TEST_EXPECT(i,x) if (!(x)) {fprintf( stderr, "failed test: %s (for i == %d)\n", #x, i ); ++fails;}
int main(int argc, char *argv[])
{
struct _td
{
const char *in;
const char *out;
};
int i;
unsigned int fails = 0u;
struct _td test_data[] = {
{ "", "" },
{ "p", "cA==" },
{ "pa", "cGE=" },
{ "pah", "cGFo" },
{ "paho", "cGFobw==" },
{ "paho ", "cGFobyA=" },
{ "paho w", "cGFobyB3" },
{ "paho wi", "cGFobyB3aQ==" },
{ "paho wit", "cGFobyB3aXQ=" },
{ "paho with", "cGFobyB3aXRo" },
{ "paho with ", "cGFobyB3aXRoIA==" },
{ "paho with w", "cGFobyB3aXRoIHc=" },
{ "paho with we", "cGFobyB3aXRoIHdl" },
{ "paho with web", "cGFobyB3aXRoIHdlYg==" },
{ "paho with webs", "cGFobyB3aXRoIHdlYnM=" },
{ "paho with webso", "cGFobyB3aXRoIHdlYnNv" },
{ "paho with websoc", "cGFobyB3aXRoIHdlYnNvYw==" },
{ "paho with websock", "cGFobyB3aXRoIHdlYnNvY2s=" },
{ "paho with websocke", "cGFobyB3aXRoIHdlYnNvY2tl" },
{ "paho with websocket", "cGFobyB3aXRoIHdlYnNvY2tldA==" },
{ "paho with websockets", "cGFobyB3aXRoIHdlYnNvY2tldHM=" },
{ "paho with websockets.", "cGFobyB3aXRoIHdlYnNvY2tldHMu" },
{ "The quick brown fox jumps over the lazy dog",
"VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIHRoZSBsYXp5IGRvZw==" },
{ "Man is distinguished, not only by his reason, but by this singular passion from\n"
"other animals, which is a lust of the mind, that by a perseverance of delight\n"
"in the continued and indefatigable generation of knowledge, exceeds the short\n"
"vehemence of any carnal pleasure.",
"TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0aGlz"
"IHNpbmd1bGFyIHBhc3Npb24gZnJvbQpvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1c3Qgb2Yg"
"dGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodAppbiB0aGUgY29udGlu"
"dWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdlLCBleGNlZWRzIHRo"
"ZSBzaG9ydAp2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=" },
{ NULL, NULL }
};
/* decode tests */
i = 0;
while ( test_data[i].in != NULL )
{
int r;
char out[512u];
r = Base64_decode( out, sizeof(out), test_data[i].out, strlen(test_data[i].out) );
TEST_EXPECT( i, r == strlen(test_data[i].in) && strncmp(out, test_data[i].in, strlen(test_data[i].in)) == 0 );
++i;
}
/* decode length tests */
i = 0;
while ( test_data[i].in != NULL )
{
TEST_EXPECT( i, Base64_decodeLength(test_data[i].out, strlen(test_data[i].out)) == strlen(test_data[i].in));
++i;
}
/* encode tests */
i = 0;
while ( test_data[i].in != NULL )
{
int r;
char out[512u];
r = Base64_encode( out, sizeof(out), test_data[i].in, strlen(test_data[i].in) );
TEST_EXPECT( i, r == strlen(test_data[i].out) && strncmp(out, test_data[i].out, strlen(test_data[i].out)) == 0 );
++i;
}
/* encode length tests */
i = 0;
while ( test_data[i].in != NULL )
{
TEST_EXPECT( i, Base64_encodeLength(test_data[i].in, strlen(test_data[i].in)) == strlen(test_data[i].out) );
++i;
}
if ( fails )
printf( "%u test failed!\n", fails );
else
printf( "all tests passed\n" );
return fails;
}
#endif /* if defined(BASE64_TEST) */
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#if !defined(BASE64_H)
#define BASE64_H
/** type for size of a buffer, it saves passing around @p size_t (unsigned long long or unsigned long int) */
typedef unsigned int b64_size_t;
/** type for raw base64 data */
typedef unsigned char b64_data_t;
/**
* Decodes base64 data
*
* @param[out] out decoded data
* @param[in] out_len length of output buffer
* @param[in] in base64 string to decode
* @param[in] in_len length of input buffer
*
* @return the amount of data decoded
*
* @see Base64_decodeLength
* @see Base64_encode
*/
b64_size_t Base64_decode( b64_data_t *out, b64_size_t out_len,
const char *in, b64_size_t in_len );
/**
* Size of buffer required to decode base64 data
*
* @param[in] in base64 string to decode
* @param[in] in_len length of input buffer
*
* @return the size of buffer the decoded string would require
*
* @see Base64_decode
* @see Base64_encodeLength
*/
b64_size_t Base64_decodeLength( const char *in, b64_size_t in_len );
/**
* Encodes base64 data
*
* @param[out] out encode base64 string
* @param[in] out_len length of output buffer
* @param[in] in raw data to encode
* @param[in] in_len length of input buffer
*
* @return the amount of data encoded
*
* @see Base64_decode
* @see Base64_encodeLength
*/
b64_size_t Base64_encode( char *out, b64_size_t out_len,
const b64_data_t *in, b64_size_t in_len );
/**
* Size of buffer required to encode base64 data
*
* @param[in] in raw data to encode
* @param[in] in_len length of input buffer
*
* @return the size of buffer the encoded string would require
*
* @see Base64_decodeLength
* @see Base64_encode
*/
b64_size_t Base64_encodeLength( const b64_data_t *in, b64_size_t in_len );
#endif /* BASE64_H */
......@@ -46,6 +46,12 @@ SET(common_src
SocketBuffer.c
Heap.c
LinkedList.c
Base64.c
Base64.h
SHA1.c
SHA1.h
WebSocket.c
WebSocket.h
)
IF (WIN32)
......@@ -60,7 +66,6 @@ ELSEIF (UNIX)
ENDIF()
ENDIF()
## common compilation for libpaho-mqtt3c and libpaho-mqtt3a
ADD_LIBRARY(common_obj OBJECT ${common_src})
SET_PROPERTY(TARGET common_obj PROPERTY POSITION_INDEPENDENT_CODE ON)
......@@ -163,3 +168,21 @@ IF (PAHO_WITH_SSL)
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
ENDIF()
ENDIF()
# Base64 test
ADD_EXECUTABLE( Base64Test EXCLUDE_FROM_ALL Base64.c Base64.h )
TARGET_COMPILE_DEFINITIONS( Base64Test PUBLIC "-DBASE64_TEST" )
IF (PAHO_WITH_SSL)
ADD_EXECUTABLE( Base64TestOpenSSL EXCLUDE_FROM_ALL Base64.c Base64.h )
TARGET_LINK_LIBRARIES( Base64TestOpenSSL ${OPENSSL_LIB} ${OPENSSLCRYPTO_LIB} )
TARGET_COMPILE_DEFINITIONS( Base64TestOpenSSL PUBLIC "-DBASE64_TEST -DOPENSSL=1" )
ENDIF (PAHO_WITH_SSL)
# SHA1 test
ADD_EXECUTABLE( Sha1Test EXCLUDE_FROM_ALL SHA1.c SHA1.h )
TARGET_COMPILE_DEFINITIONS( Sha1Test PUBLIC "-DSHA1_TEST" )
IF (PAHO_WITH_SSL)
ADD_EXECUTABLE( Sha1TestOpenSSL EXCLUDE_FROM_ALL SHA1.c SHA1.h )
TARGET_LINK_LIBRARIES( Sha1TestOpenSSL ${OPENSSL_LIB} ${OPENSSLCRYPTO_LIB} )
TARGET_COMPILE_DEFINITIONS( Sha1TestOpenSSL PUBLIC "-DSHA1_TEST -DOPENSSL=1" )
ENDIF (PAHO_WITH_SSL)
......@@ -158,8 +158,25 @@ typedef struct
SSL* ssl;
SSL_CTX* ctx;
#endif
int websocket; /**< socket has been upgraded to use web sockets */
char *websocket_key;
} networkHandles;
/* connection states */
/** no connection in progress, see connected value */
#define NOT_IN_PROGRESS 0x0
/** TCP connection in progress */
#define TCP_IN_PROGRESS 0x1
/** SSL connection in progress */
#define SSL_IN_PROGRESS 0x2
/** Websocket connection in progress */
#define WEBSOCKET_IN_PROGRESS 0x3
/** TCP completed, waiting for MQTT ACK */
#define WAIT_FOR_CONNACK 0x4
/** Disconnecting */
#define DISCONNECTING -2
/**
* Data related to one client
*/
......
......@@ -62,8 +62,11 @@
#include "StackTrace.h"
#include "Heap.h"
#include "OsWrapper.h"
#include "WebSocket.h"
#define URI_TCP "tcp://"
#define URI_WS "ws://"
#define URI_WSS "wss://"
#include "VersionInfo.h"
......@@ -301,6 +304,7 @@ typedef struct MQTTAsync_struct
{
char* serverURI;
int ssl;
int websocket;
Clients* c;
/* "Global", to the client, callback definitions */
......@@ -476,9 +480,10 @@ int MQTTAsync_createWithOptions(MQTTAsync* handle, const char* serverURI, const
if (strstr(serverURI, "://") != NULL)
{
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) != 0
&& strncmp(URI_WS, serverURI, strlen(URI_WS)) != 0
#if defined(OPENSSL)
&& strncmp(URI_SSL, serverURI, strlen(URI_SSL)) != 0
&& strncmp(URI_WSS, serverURI, strlen(URI_WSS)) != 0
#endif
)
{
......@@ -514,12 +519,23 @@ int MQTTAsync_createWithOptions(MQTTAsync* handle, const char* serverURI, const
memset(m, '\0', sizeof(MQTTAsyncs));
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0)
serverURI += strlen(URI_TCP);
else if (strncmp(URI_WS, serverURI, strlen(URI_WS)) == 0)
{
serverURI += strlen(URI_WS);
m->websocket = 1;
}
#if defined(OPENSSL)
else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0)
{
serverURI += strlen(URI_SSL);
m->ssl = 1;
}
else if (strncmp(URI_WSS, serverURI, strlen(URI_WSS)) == 0)
{
serverURI += strlen(URI_WSS);
m->ssl = 1;
m->websocket = 1;
}
#endif
m->serverURI = MQTTStrdup(serverURI);
m->responses = ListInitialize();
......@@ -582,10 +598,7 @@ static void MQTTAsync_terminate(void)
MQTTAsync_freeCommand1((MQTTAsync_queuedCommand*)(elem->content));
ListFree(commands);
handles = NULL;
Socket_outTerminate();
#if defined(OPENSSL)
SSLSocket_terminate();
#endif
WebSocket_terminate();
#if defined(HEAP_H)
Heap_terminate();
#endif
......@@ -885,7 +898,7 @@ static int MQTTAsync_addCommand(MQTTAsync_queuedCommand* command, int command_si
FUNC_ENTRY;
MQTTAsync_lock_mutex(mqttcommand_mutex);
/* Don't set start time if the connect command is already in process #218 */
if ((command->command.type != CONNECT) || (command->client->c->connect_state == 0))
if ((command->command.type != CONNECT) || (command->client->c->connect_state == NOT_IN_PROGRESS))
command->command.start_time = MQTTAsync_start_clock();
if (command->command.type == CONNECT ||
(command->command.type == DISCONNECT && command->command.details.dis.internal))
......@@ -1199,7 +1212,7 @@ static int MQTTAsync_processCommand(void)
continue;
if (cmd->command.type == CONNECT || cmd->command.type == DISCONNECT || (cmd->client->c->connected &&
cmd->client->c->connect_state == 0 && MQTTAsync_Socket_noPendingWrites(cmd->client->c->net.socket)))
cmd->client->c->connect_state == NOT_IN_PROGRESS && MQTTAsync_Socket_noPendingWrites(cmd->client->c->net.socket)))
{
if ((cmd->command.type == PUBLISH || cmd->command.type == SUBSCRIBE || cmd->command.type == UNSUBSCRIBE) &&
cmd->client->c->outboundMsgs->count >= MAX_MSG_ID - 1)
......@@ -1230,7 +1243,7 @@ static int MQTTAsync_processCommand(void)
if (command->command.type == CONNECT)
{
if (command->client->c->connect_state != 0 || command->client->c->connected)
if (command->client->c->connect_state != NOT_IN_PROGRESS || command->client->c->connected)
rc = 0;
else
{
......@@ -1242,12 +1255,23 @@ static int MQTTAsync_processCommand(void)
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0)
serverURI += strlen(URI_TCP);
else if (strncmp(URI_WS, serverURI, strlen(URI_WS)) == 0)
{
serverURI += strlen(URI_WS);
command->client->websocket = 1;
}
#if defined(OPENSSL)
else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0)
{
serverURI += strlen(URI_SSL);
command->client->ssl = 1;
}
else if (strncmp(URI_WSS, serverURI, strlen(URI_WSS)) == 0)
{
serverURI += strlen(URI_WSS);
command->client->ssl = 1;
command->client->websocket = 1;
}
#endif
}
......@@ -1263,11 +1287,11 @@ static int MQTTAsync_processCommand(void)
Log(TRACE_MIN, -1, "Connecting to serverURI %s with MQTT version %d", serverURI, command->command.details.conn.MQTTVersion);
#if defined(OPENSSL)
rc = MQTTProtocol_connect(serverURI, command->client->c, command->client->ssl, command->command.details.conn.MQTTVersion);
rc = MQTTProtocol_connect(serverURI, command->client->c, command->client->ssl, command->client->websocket, command->command.details.conn.MQTTVersion);
#else
rc = MQTTProtocol_connect(serverURI, command->client->c, command->command.details.conn.MQTTVersion);
rc = MQTTProtocol_connect(serverURI, command->client->c, command->client->websocket, command->command.details.conn.MQTTVersion);
#endif
if (command->client->c->connect_state == 0)
if (command->client->c->connect_state == NOT_IN_PROGRESS)
rc = SOCKET_ERROR;
/* if the TCP connect is pending, then we must call select to determine when the connect has completed,
......@@ -1349,11 +1373,11 @@ static int MQTTAsync_processCommand(void)
}
else if (command->command.type == DISCONNECT)
{
if (command->client->c->connect_state != 0 || command->client->c->connected != 0)
if (command->client->c->connect_state != NOT_IN_PROGRESS || command->client->c->connected != 0)
{
if (command->client->c->connect_state != 0)
if (command->client->c->connect_state != NOT_IN_PROGRESS)
{
command->client->c->connect_state = -2;
command->client->c->connect_state = DISCONNECTING;
if (command->client->connect.onFailure)
{
MQTTAsync_failureData data;
......@@ -1509,11 +1533,11 @@ static void MQTTAsync_checkTimeouts(void)
MQTTAsyncs* m = (MQTTAsyncs*)(current->content);
/* check disconnect timeout */
if (m->c->connect_state == -2)
if (m->c->connect_state == DISCONNECTING)
MQTTAsync_checkDisconnect(m, &m->disconnect);
/* check connect timeout */
if (m->c->connect_state != 0 && MQTTAsync_elapsed(m->connect.start_time) > (m->connectTimeout * 1000))
if (m->c->connect_state != NOT_IN_PROGRESS && MQTTAsync_elapsed(m->connect.start_time) > (m->connectTimeout * 1000))
{
nextOrClose(m, MQTTASYNC_FAILURE, "TCP connect timeout");
continue;
......@@ -1763,7 +1787,7 @@ static int MQTTAsync_completeConnection(MQTTAsyncs* m, MQTTPacket* pack)
int rc = MQTTASYNC_FAILURE;
FUNC_ENTRY;
if (m->c->connect_state == 3) /* MQTT connect sent - wait for CONNACK */
if (m->c->connect_state == WAIT_FOR_CONNACK) /* MQTT connect sent - wait for CONNACK */
{
Connack* connack = (Connack*)pack;
Log(LOG_PROTOCOL, 1, NULL, m->c->net.socket, m->c->clientID, connack->rc);
......@@ -1772,7 +1796,7 @@ static int MQTTAsync_completeConnection(MQTTAsyncs* m, MQTTPacket* pack)
m->retrying = 0;
m->c->connected = 1;
m->c->good = 1;
m->c->connect_state = 0;
m->c->connect_state = NOT_IN_PROGRESS;
if (m->c->cleansession)
rc = MQTTAsync_cleanSession(m->c);
if (m->c->outboundMsgs->count > 0)
......@@ -1851,7 +1875,7 @@ static thread_return_type WINAPI MQTTAsync_receiveThread(void* n)
MQTTAsync_disconnect_internal(m, 0);
MQTTAsync_lock_mutex(mqttasync_mutex);
}
else if (m->c->connect_state != 0)
else if (m->c->connect_state != NOT_IN_PROGRESS)
nextOrClose(m, rc, "socket error");
else /* calling disconnect_internal won't have any effect if we're already disconnected */
MQTTAsync_closeOnly(m->c);
......@@ -2039,7 +2063,7 @@ static void MQTTAsync_stop(void)
/* find out how many handles are still connected */
while (ListNextElement(handles, &current))
{
if (((MQTTAsyncs*)(current->content))->c->connect_state > 0 ||
if (((MQTTAsyncs*)(current->content))->c->connect_state > NOT_IN_PROGRESS ||
((MQTTAsyncs*)(current->content))->c->connected)
++conn_count;
}
......@@ -2076,7 +2100,7 @@ int MQTTAsync_setCallbacks(MQTTAsync handle, void* context,
FUNC_ENTRY;
MQTTAsync_lock_mutex(mqttasync_mutex);
if (m == NULL || ma == NULL || m->c->connect_state != 0)
if (m == NULL || ma == NULL || m->c == NULL || m->c->connect_state != NOT_IN_PROGRESS)
rc = MQTTASYNC_FAILURE;
else
{
......@@ -2100,7 +2124,7 @@ int MQTTAsync_setConnected(MQTTAsync handle, void* context, MQTTAsync_connected*
FUNC_ENTRY;
MQTTAsync_lock_mutex(mqttasync_mutex);
if (m == NULL || m->c->connect_state != 0)
if (m == NULL || m->c->connect_state != NOT_IN_PROGRESS)
rc = MQTTASYNC_FAILURE;
else
{
......@@ -2125,6 +2149,7 @@ static void MQTTAsync_closeOnly(Clients* client)
if (client->connected && Socket_noPendingWrites(client->net.socket))
MQTTPacket_send_disconnect(&client->net, client->clientID);
Thread_lock_mutex(socket_mutex);
WebSocket_close(&client->net, WebSocket_CLOSE_NORMAL, NULL);
#if defined(OPENSSL)
SSLSocket_close(&client->net);
#endif
......@@ -2136,7 +2161,7 @@ static void MQTTAsync_closeOnly(Clients* client)
Thread_unlock_mutex(socket_mutex);
}
client->connected = 0;
client->connect_state = 0;
client->connect_state = NOT_IN_PROGRESS;
FUNC_EXIT;
}
......@@ -2914,7 +2939,7 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
int rc = -1;
FUNC_ENTRY;
if (m->c->connect_state == 1) /* TCP connect started - check for completion */
if (m->c->connect_state == TCP_IN_PROGRESS) /* TCP connect started - check for completion */
{
int error;
socklen_t len = sizeof(error);
......@@ -2931,13 +2956,12 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
if (m->ssl)
{
int port;
char* hostname;
size_t hostname_len;
int setSocketForSSLrc = 0;
hostname = MQTTProtocol_addressPort(m->serverURI, &port);
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts, hostname);
if (hostname != m->serverURI)
free(hostname);
hostname_len = MQTTProtocol_addressPort(m->serverURI, &port, NULL);
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts,
m->serverURI, hostname_len);
if (setSocketForSSLrc != MQTTASYNC_SUCCESS)
{
......@@ -2949,7 +2973,7 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
if (rc == TCPSOCKET_INTERRUPTED)
{
rc = MQTTCLIENT_SUCCESS; /* the connect is still in progress */
m->c->connect_state = 2;
m->c->connect_state = SSL_IN_PROGRESS;
}
else if (rc == SSL_FATAL)
{
......@@ -2957,14 +2981,23 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
goto exit;
}
else if (rc == 1)
{
if ( m->websocket )
{
m->c->connect_state = WEBSOCKET_IN_PROGRESS;
if ((rc = WebSocket_connect(&m->c->net, m->serverURI)) == SOCKET_ERROR )
goto exit;
}
else
{
rc = MQTTCLIENT_SUCCESS;
m->c->connect_state = 3;
m->c->connect_state = WAIT_FOR_CONNACK;
if (MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion) == SOCKET_ERROR)
{
rc = SOCKET_ERROR;
goto exit;
}
}
if (!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl);
}
......@@ -2978,15 +3011,24 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
else
{
#endif
m->c->connect_state = 3; /* TCP/SSL connect completed, in which case send the MQTT connect packet */
if ( m->websocket )
{
m->c->connect_state = WEBSOCKET_IN_PROGRESS;
if ((rc = WebSocket_connect(&m->c->net, m->serverURI)) == SOCKET_ERROR )
goto exit;
}
else
{
m->c->connect_state = WAIT_FOR_CONNACK; /* TCP/SSL connect completed, in which case send the MQTT connect packet */
if ((rc = MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion)) == SOCKET_ERROR)
goto exit;
}
#if defined(OPENSSL)
}
#endif
}
#if defined(OPENSSL)
else if (m->c->connect_state == 2) /* SSL connect sent - wait for completion */
else if (m->c->connect_state == SSL_IN_PROGRESS) /* SSL connect sent - wait for completion */
{
if ((rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket,
m->serverURI, m->c->sslopts->verify)) != 1)
......@@ -2994,14 +3036,34 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
if(!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl);
m->c->connect_state = 3; /* SSL connect completed, in which case send the MQTT connect packet */
if ( m->websocket )
{
m->c->connect_state = WEBSOCKET_IN_PROGRESS;
if ((rc = WebSocket_connect(&m->c->net, m->serverURI)) == SOCKET_ERROR )
goto exit;
}
else
{
m->c->connect_state = WAIT_FOR_CONNACK; /* SSL connect completed, in which case send the MQTT connect packet */
if ((rc = MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion)) == SOCKET_ERROR)
goto exit;
}
}
#endif
else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS) /* Websocket connect sent - wait for completion */
{
if ((rc = WebSocket_upgrade( &m->c->net ) ) == SOCKET_ERROR )
goto exit;
else
{
m->c->connect_state = WAIT_FOR_CONNACK; /* Websocket upgrade completed, in which case send the MQTT connect packet */
if ((rc = MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion)) == SOCKET_ERROR)
goto exit;
}
}
exit:
if ((rc != 0 && rc != TCPSOCKET_INTERRUPTED && m->c->connect_state != 2) || (rc == SSL_FATAL))
if ((rc != 0 && rc != TCPSOCKET_INTERRUPTED && (m->c->connect_state != SSL_IN_PROGRESS && m->c->connect_state != WEBSOCKET_IN_PROGRESS)) || (rc == SSL_FATAL))
nextOrClose(m, MQTTASYNC_FAILURE, "TCP/TLS connect failure");
FUNC_EXIT_RC(rc);
......@@ -3042,11 +3104,11 @@ static MQTTPacket* MQTTAsync_cycle(int* sock, unsigned long timeout, int* rc)
if (m != NULL)
{
Log(TRACE_MINIMUM, -1, "m->c->connect_state = %d",m->c->connect_state);
if (m->c->connect_state == 1 || m->c->connect_state == 2)
if (m->c->connect_state == TCP_IN_PROGRESS || m->c->connect_state == SSL_IN_PROGRESS || m->c->connect_state == WEBSOCKET_IN_PROGRESS)
*rc = MQTTAsync_connecting(m);
else
pack = MQTTPacket_Factory(&m->c->net, rc);
if (m->c->connect_state == 3 && *rc == SOCKET_ERROR)
if (m->c->connect_state == WAIT_FOR_CONNACK && *rc == SOCKET_ERROR)
{
Log(TRACE_MINIMUM, -1, "CONNECT sent but MQTTPacket_Factory has returned SOCKET_ERROR");
nextOrClose(m, MQTTASYNC_FAILURE, "TCP connect completion failure");
......
......@@ -71,9 +71,11 @@
#include "OsWrapper.h"
#define URI_TCP "tcp://"
#define URI_WS "ws://"
#define URI_WSS "wss://"
#include "VersionInfo.h"
#include "WebSocket.h"
const char *client_timestamp_eye = "MQTTClientV3_Timestamp " BUILD_TIMESTAMP;
const char *client_version_eye = "MQTTClientV3_Version " CLIENT_VERSION;
......@@ -195,6 +197,7 @@ typedef struct
#if defined(OPENSSL)
int ssl;
#endif
int websocket;
Clients* c;
MQTTClient_connectionLost* cl;
MQTTClient_messageArrived* ma;
......@@ -331,9 +334,10 @@ int MQTTClient_create(MQTTClient* handle, const char* serverURI, const char* cli
if (strstr(serverURI, "://") != NULL)
{
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) != 0
&& strncmp(URI_WS, serverURI, strlen(URI_WS)) != 0
#if defined(OPENSSL)
&& strncmp(URI_SSL, serverURI, strlen(URI_SSL)) != 0
&& strncmp(URI_WSS, serverURI, strlen(URI_WSS)) != 0
#endif
)
{
......@@ -357,11 +361,17 @@ int MQTTClient_create(MQTTClient* handle, const char* serverURI, const char* cli
#endif
initialized = 1;
}
m = malloc(sizeof(MQTTClients));
*handle = m;
memset(m, '\0', sizeof(MQTTClients));
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0)
serverURI += strlen(URI_TCP);
else if (strncmp(URI_WS, serverURI, strlen(URI_WS)) == 0)
{
serverURI += strlen(URI_WS);
m->websocket = 1;
}
else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0)
{
#if defined(OPENSSL)
......@@ -370,6 +380,17 @@ int MQTTClient_create(MQTTClient* handle, const char* serverURI, const char* cli
#else
rc = MQTTCLIENT_SSL_NOT_SUPPORTED;
goto exit;
#endif
}
else if (strncmp(URI_WSS, serverURI, strlen(URI_WSS)) == 0)
{
#if defined(OPENSSL)
serverURI += strlen(URI_WSS);
m->ssl = 1;
m->websocket = 1;
#else
rc = MQTTCLIENT_SSL_NOT_SUPPORTED;
goto exit;
#endif
}
m->serverURI = MQTTStrdup(serverURI);
......@@ -414,10 +435,7 @@ static void MQTTClient_terminate(void)
ListFree(bstate->clients);
ListFree(handles);
handles = NULL;
Socket_outTerminate();
#if defined(OPENSSL)
SSLSocket_terminate();
#endif
WebSocket_terminate();
#if defined(HEAP_H)
Heap_terminate();
#endif
......@@ -599,12 +617,12 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
MQTTClient_disconnect_internal(m, 0);
else
{
if (m->c->connect_state == 2 && !Thread_check_sem(m->connect_sem))
if (m->c->connect_state == SSL_IN_PROGRESS && !Thread_check_sem(m->connect_sem))
{
Log(TRACE_MIN, -1, "Posting connect semaphore for client %s", m->c->clientID);
Thread_post_sem(m->connect_sem);
}
if (m->c->connect_state == 3 && !Thread_check_sem(m->connack_sem))
if (m->c->connect_state == WAIT_FOR_CONNACK && !Thread_check_sem(m->connack_sem))
{
Log(TRACE_MIN, -1, "Posting connack semaphore for client %s", m->c->clientID);
Thread_post_sem(m->connack_sem);
......@@ -663,7 +681,7 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
Thread_post_sem(m->unsuback_sem);
}
}
else if (m->c->connect_state == 1 && !Thread_check_sem(m->connect_sem))
else if (m->c->connect_state == TCP_IN_PROGRESS && !Thread_check_sem(m->connect_sem))
{
int error;
socklen_t len = sizeof(error);
......@@ -674,7 +692,7 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
Thread_post_sem(m->connect_sem);
}
#if defined(OPENSSL)
else if (m->c->connect_state == 2 && !Thread_check_sem(m->connect_sem))
else if (m->c->connect_state == SSL_IN_PROGRESS && !Thread_check_sem(m->connect_sem))
{
rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket,
m->serverURI, m->c->sslopts->verify);
......@@ -688,6 +706,12 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
}
}
#endif
else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS && !Thread_check_sem(m->connect_sem))
{
Log(TRACE_MIN, -1, "Posting websocket handshake for client %s rc %d", m->c->clientID, m->rc);
m->c->connect_state = WAIT_FOR_CONNACK;
Thread_post_sem(m->connect_sem);
}
}
}
run_id = 0;
......@@ -713,7 +737,7 @@ static void MQTTClient_stop(void)
/* find out how many handles are still connected */
while (ListNextElement(handles, &current))
{
if (((MQTTClients*)(current->content))->c->connect_state > 0 ||
if (((MQTTClients*)(current->content))->c->connect_state > NOT_IN_PROGRESS ||
((MQTTClients*)(current->content))->c->connected)
++conn_count;
}
......@@ -750,7 +774,7 @@ int MQTTClient_setCallbacks(MQTTClient handle, void* context, MQTTClient_connect
FUNC_ENTRY;
Thread_lock_mutex(mqttclient_mutex);
if (m == NULL || ma == NULL || m->c->connect_state != 0)
if (m == NULL || ma == NULL || m->c->connect_state != NOT_IN_PROGRESS)
rc = MQTTCLIENT_FAILURE;
else
{
......@@ -776,6 +800,8 @@ static void MQTTClient_closeSession(Clients* client)
if (client->connected)
MQTTPacket_send_disconnect(&client->net, client->clientID);
Thread_lock_mutex(socket_mutex);
WebSocket_close(&client->net, WebSocket_CLOSE_NORMAL, NULL);
#if defined(OPENSSL)
SSLSocket_close(&client->net);
#endif
......@@ -787,7 +813,7 @@ static void MQTTClient_closeSession(Clients* client)
#endif
}
client->connected = 0;
client->connect_state = 0;
client->connect_state = NOT_IN_PROGRESS;
if (client->cleansession)
MQTTClient_cleanSession(client);
......@@ -877,20 +903,20 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
Log(TRACE_MIN, -1, "Connecting to serverURI %s with MQTT version %d", serverURI, MQTTVersion);
#if defined(OPENSSL)
rc = MQTTProtocol_connect(serverURI, m->c, m->ssl, MQTTVersion);
rc = MQTTProtocol_connect(serverURI, m->c, m->ssl, m->websocket, MQTTVersion);
#else
rc = MQTTProtocol_connect(serverURI, m->c, MQTTVersion);
rc = MQTTProtocol_connect(serverURI, m->c, m->websocket, MQTTVersion);
#endif
if (rc == SOCKET_ERROR)
goto exit;
if (m->c->connect_state == 0)
if (m->c->connect_state == NOT_IN_PROGRESS)
{
rc = SOCKET_ERROR;
goto exit;
}
if (m->c->connect_state == 1) /* TCP connect started - wait for completion */
if (m->c->connect_state == TCP_IN_PROGRESS) /* TCP connect started - wait for completion */
{
Thread_unlock_mutex(mqttclient_mutex);
MQTTClient_waitfor(handle, CONNECT, &rc, millisecsTimeout - MQTTClient_elapsed(start));
......@@ -900,18 +926,17 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
rc = SOCKET_ERROR;
goto exit;
}
#if defined(OPENSSL)
if (m->ssl)
{
int port;
char* hostname;
size_t hostname_len;
const char *topic;
int setSocketForSSLrc = 0;
hostname = MQTTProtocol_addressPort(m->serverURI, &port);
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts, hostname);
if (hostname != m->serverURI)
free(hostname);
hostname_len = MQTTProtocol_addressPort(m->serverURI, &port, &topic);
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts,
m->serverURI, hostname_len);
if (setSocketForSSLrc != MQTTCLIENT_SUCCESS)
{
......@@ -921,16 +946,25 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
rc = SSLSocket_connect(m->c->net.ssl, m->c->net.socket,
m->serverURI, m->c->sslopts->verify);
if (rc == TCPSOCKET_INTERRUPTED)
m->c->connect_state = 2; /* the connect is still in progress */
m->c->connect_state = SSL_IN_PROGRESS; /* the connect is still in progress */
else if (rc == SSL_FATAL)
{
rc = SOCKET_ERROR;
goto exit;
}
else if (rc == 1)
{
if (m->websocket)
{
m->c->connect_state = WEBSOCKET_IN_PROGRESS;
rc = WebSocket_connect(&m->c->net,m->serverURI);
if ( rc == SOCKET_ERROR )
goto exit;
}
else if ( rc == 1 )
{
rc = MQTTCLIENT_SUCCESS;
m->c->connect_state = 3;
m->c->connect_state = WAIT_FOR_CONNACK;
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{
rc = SOCKET_ERROR;
......@@ -940,28 +974,36 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
m->c->session = SSL_get1_session(m->c->net.ssl);
}
}
}
else
{
rc = SOCKET_ERROR;
goto exit;
}
}
#endif
else if (m->websocket)
{
m->c->connect_state = WEBSOCKET_IN_PROGRESS;
if ( WebSocket_connect(&m->c->net, m->serverURI) == SOCKET_ERROR )
{
rc = SOCKET_ERROR;
goto exit;
}
}
else
{
#endif
m->c->connect_state = 3; /* TCP connect completed, in which case send the MQTT connect packet */
m->c->connect_state = WAIT_FOR_CONNACK; /* TCP connect completed, in which case send the MQTT connect packet */
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{
rc = SOCKET_ERROR;
goto exit;
}
#if defined(OPENSSL)
}
#endif
}
#if defined(OPENSSL)
if (m->c->connect_state == 2) /* SSL connect sent - wait for completion */
if (m->c->connect_state == SSL_IN_PROGRESS) /* SSL connect sent - wait for completion */
{
Thread_unlock_mutex(mqttclient_mutex);
MQTTClient_waitfor(handle, CONNECT, &rc, millisecsTimeout - MQTTClient_elapsed(start));
......@@ -973,19 +1015,45 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
}
if(!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl);
m->c->connect_state = 3; /* TCP connect completed, in which case send the MQTT connect packet */
if ( m->websocket )
{
/* wait for websocket connect */
m->c->connect_state = WEBSOCKET_IN_PROGRESS;
rc = WebSocket_connect( &m->c->net, m->serverURI );
if ( rc != 1 )
{
rc = SOCKET_ERROR;
goto exit;
}
}
else
{
m->c->connect_state = WAIT_FOR_CONNACK; /* TCP connect completed, in which case send the MQTT connect packet */
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{
rc = SOCKET_ERROR;
goto exit;
}
}
}
#endif
if (m->c->connect_state == 3) /* MQTT connect sent - wait for CONNACK */
if (m->c->connect_state == WEBSOCKET_IN_PROGRESS) /* websocket request sent - wait for upgrade */
{
MQTTPacket* pack = NULL;
Thread_unlock_mutex(mqttclient_mutex);
MQTTClient_waitfor(handle, CONNECT, &rc, millisecsTimeout - MQTTClient_elapsed(start));
Thread_lock_mutex(mqttclient_mutex);
m->c->connect_state = WAIT_FOR_CONNACK; /* websocket upgrade complete */
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{
rc = SOCKET_ERROR;
goto exit;
}
}
if (m->c->connect_state == WAIT_FOR_CONNACK) /* MQTT connect sent - wait for CONNACK */
{
MQTTPacket* pack = NULL;
Thread_unlock_mutex(mqttclient_mutex);
pack = MQTTClient_waitfor(handle, CONNACK, &rc, millisecsTimeout - MQTTClient_elapsed(start));
Thread_lock_mutex(mqttclient_mutex);
......@@ -999,7 +1067,7 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
{
m->c->connected = 1;
m->c->good = 1;
m->c->connect_state = 0;
m->c->connect_state = NOT_IN_PROGRESS;
if (MQTTVersion == 4)
sessionPresent = connack->flags.bits.sessionPresent;
if (m->c->cleansession)
......@@ -1262,12 +1330,23 @@ int MQTTClient_connect(MQTTClient handle, MQTTClient_connectOptions* options)
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0)
serverURI += strlen(URI_TCP);
else if (strncmp(URI_WS, serverURI, strlen(URI_WS)) == 0)
{
serverURI += strlen(URI_WS);
m->websocket = 1;
}
#if defined(OPENSSL)
else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0)
{
serverURI += strlen(URI_SSL);
m->ssl = 1;
}
else if (strncmp(URI_WSS, serverURI, strlen(URI_WSS)) == 0)
{
serverURI += strlen(URI_WSS);
m->ssl = 1;
m->websocket = 1;
}
#endif
if ((rc = MQTTClient_connectURI(handle, options, serverURI)) == MQTTCLIENT_SUCCESS)
break;
......@@ -1307,7 +1386,7 @@ static int MQTTClient_disconnect1(MQTTClient handle, int timeout, int call_conne
rc = MQTTCLIENT_FAILURE;
goto exit;
}
if (m->c->connected == 0 && m->c->connect_state == 0)
if (m->c->connected == 0 && m->c->connect_state == NOT_IN_PROGRESS)
{
rc = MQTTCLIENT_DISCONNECTED;
goto exit;
......@@ -1316,7 +1395,7 @@ static int MQTTClient_disconnect1(MQTTClient handle, int timeout, int call_conne
if (m->c->connected != 0)
{
start = MQTTClient_start_clock();
m->c->connect_state = -2; /* indicate disconnecting */
m->c->connect_state = DISCONNECTING; /* indicate disconnecting */
while (m->c->inboundMsgs->count > 0 || m->c->outboundMsgs->count > 0)
{ /* wait for all inflight message flows to finish, up to timeout */
if (MQTTClient_elapsed(start) >= timeout)
......@@ -1628,11 +1707,16 @@ int MQTTClient_publish(MQTTClient handle, const char* topicName, int payloadlen,
goto exit;
}
p = malloc(sizeof(Publish));
p = malloc(sizeof(Publish) + payloadlen);
p->payload = payload;
p->payloadlen = payloadlen;
p->topic = (char*)topicName;
if (payloadlen > 0)
{
p->payload = (char*)p + sizeof(Publish);
memcpy(p->payload, payload, payloadlen);
p->payloadlen = payloadlen;
}
p->topic = MQTTStrdup(topicName);
p->msgId = msgid;
rc = MQTTProtocol_startPublish(m->c, p, qos, retained, &msg);
......@@ -1656,6 +1740,7 @@ int MQTTClient_publish(MQTTClient handle, const char* topicName, int payloadlen,
if (deliveryToken && qos > 0)
*deliveryToken = msg->msgid;
if (p->topic) free(p->topic);
free(p);
if (rc == SOCKET_ERROR)
......@@ -1748,8 +1833,10 @@ static MQTTPacket* MQTTClient_cycle(int* sock, unsigned long timeout, int* rc)
m = (MQTTClient)(handles->current->content);
if (m != NULL)
{
if (m->c->connect_state == 1 || m->c->connect_state == 2)
if (m->c->connect_state == TCP_IN_PROGRESS || m->c->connect_state == SSL_IN_PROGRESS)
*rc = 0; /* waiting for connect state to clear */
else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS)
*rc = WebSocket_upgrade(&m->c->net);
else
{
pack = MQTTPacket_Factory(&m->c->net, rc);
......@@ -1757,6 +1844,7 @@ static MQTTPacket* MQTTClient_cycle(int* sock, unsigned long timeout, int* rc)
*rc = 0;
}
}
if (pack)
{
int freed = 1;
......@@ -1840,7 +1928,7 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
break;
if (pack && (pack->header.bits.type == packet_type))
break;
if (m->c->connect_state == 1)
if (m->c->connect_state == TCP_IN_PROGRESS)
{
int error;
socklen_t len = sizeof(error);
......@@ -1850,7 +1938,7 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
break;
}
#if defined(OPENSSL)
else if (m->c->connect_state == 2)
else if (m->c->connect_state == SSL_IN_PROGRESS)
{
*rc = SSLSocket_connect(m->c->net.ssl, sock,
m->serverURI, m->c->sslopts->verify);
......@@ -1864,11 +1952,15 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
}
}
#endif
else if (m->c->connect_state == 3)
else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS )
{
*rc = 1;
break;
}
else if (m->c->connect_state == WAIT_FOR_CONNACK)
{
int error;
socklen_t len = sizeof(error);
if (getsockopt(m->c->net.socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len) == 0)
{
if (error)
......@@ -1972,7 +2064,7 @@ void MQTTClient_yield(void)
if (rc == SOCKET_ERROR && ListFindItem(handles, &sock, clientSockCompare))
{
MQTTClients* m = (MQTTClient)(handles->current->content);
if (m->c->connect_state != -2)
if (m->c->connect_state != DISCONNECTING)
MQTTClient_disconnect_internal(m, 0);
}
Thread_unlock_mutex(mqttclient_mutex);
......
......@@ -31,6 +31,7 @@
#endif
#include "Messages.h"
#include "StackTrace.h"
#include "WebSocket.h"
#include <stdlib.h>
#include <string.h>
......@@ -109,11 +110,7 @@ void* MQTTPacket_Factory(networkHandles* net, int* error)
*error = SOCKET_ERROR; /* indicate whether an error occurred, or not */
/* read the packet data from the socket */
#if defined(OPENSSL)
*error = (net->ssl) ? SSLSocket_getch(net->ssl, net->socket, &header.byte) : Socket_getch(net->socket, &header.byte);
#else
*error = Socket_getch(net->socket, &header.byte);
#endif
*error = WebSocket_getch(net, &header.byte);
if (*error != TCPSOCKET_COMPLETE) /* first byte is the header byte */
goto exit; /* packet not read, *error indicates whether SOCKET_ERROR occurred */
......@@ -122,12 +119,7 @@ void* MQTTPacket_Factory(networkHandles* net, int* error)
goto exit; /* packet not read, *error indicates whether SOCKET_ERROR occurred */
/* now read the rest, the variable header and payload */
#if defined(OPENSSL)
data = (net->ssl) ? SSLSocket_getdata(net->ssl, net->socket, remaining_length, &actual_len) :
Socket_getdata(net->socket, remaining_length, &actual_len);
#else
data = Socket_getdata(net->socket, remaining_length, &actual_len);
#endif
data = WebSocket_getdata(net, remaining_length, &actual_len);
if (data == NULL)
{
*error = SOCKET_ERROR;
......@@ -179,13 +171,17 @@ int MQTTPacket_send(networkHandles* net, Header header, char* buffer, size_t buf
{
int rc;
size_t buf0len;
size_t ws_header;
char *buf;
int count = 0;
FUNC_ENTRY;
buf = malloc(10);
buf[0] = header.byte;
buf0len = 1 + MQTTPacket_encode(&buf[1], buflen);
ws_header = WebSocket_calculateFrameHeaderSize(net, 1, buflen + 10);
buf = malloc(10 + ws_header);
if ( !buf ) return -1;
buf[ws_header] = header.byte;
buf0len = 1 + MQTTPacket_encode(&buf[ws_header + 1], buflen);
if (buffer != NULL)
count = 1;
......@@ -195,17 +191,12 @@ int MQTTPacket_send(networkHandles* net, Header header, char* buffer, size_t buf
{
char* ptraux = buffer;
int msgId = readInt(&ptraux);
rc = MQTTPersistence_put(net->socket, buf, buf0len, count, &buffer, &buflen,
rc = MQTTPersistence_put(net->socket, &buf[ws_header], buf0len, count, &buffer, &buflen,
header.bits.type, msgId, 0);
}
#endif
#if defined(OPENSSL)
if (net->ssl)
rc = SSLSocket_putdatas(net->ssl, net->socket, buf, buf0len, count, &buffer, &buflen, &freeData);
else
#endif
rc = Socket_putdatas(net->socket, buf, buf0len, count, &buffer, &buflen, &freeData);
rc = WebSocket_putdatas(net, &buf[ws_header], buf0len, count, &buffer, &buflen, &freeData);
if (rc == TCPSOCKET_COMPLETE)
time(&(net->lastSent));
......@@ -231,29 +222,30 @@ int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffe
{
int i, rc;
size_t buf0len, total = 0;
size_t ws_header;
char *buf;
FUNC_ENTRY;
buf = malloc(10);
buf[0] = header.byte;
for (i = 0; i < count; i++)
total += buflens[i];
buf0len = 1 + MQTTPacket_encode(&buf[1], total);
ws_header = WebSocket_calculateFrameHeaderSize(net, 1, total + 10);
buf = malloc(10 + ws_header);
if ( !buf ) return -1;
buf[ws_header] = header.byte;
buf0len = 1 + MQTTPacket_encode(&buf[ws_header + 1], total);
#if !defined(NO_PERSISTENCE)
if (header.bits.type == PUBLISH && header.bits.qos != 0)
{ /* persist PUBLISH QoS1 and Qo2 */
char *ptraux = buffers[2];
int msgId = readInt(&ptraux);
rc = MQTTPersistence_put(net->socket, buf, buf0len, count, buffers, buflens,
rc = MQTTPersistence_put(net->socket, &buf[ws_header], buf0len, count, buffers, buflens,
header.bits.type, msgId, 0);
}
#endif
#if defined(OPENSSL)
if (net->ssl)
rc = SSLSocket_putdatas(net->ssl, net->socket, buf, buf0len, count, buffers, buflens, frees);
else
#endif
rc = Socket_putdatas(net->socket, buf, buf0len, count, buffers, buflens, frees);
rc = WebSocket_putdatas(net, &buf[ws_header], buf0len, count, buffers, buflens, frees);
if (rc == TCPSOCKET_COMPLETE)
time(&(net->lastSent));
......@@ -313,11 +305,7 @@ int MQTTPacket_decode(networkHandles* net, size_t* value)
rc = SOCKET_ERROR; /* bad data */
goto exit;
}
#if defined(OPENSSL)
rc = (net->ssl) ? SSLSocket_getch(net->ssl, net->socket, &c) : Socket_getch(net->socket, &c);
#else
rc = Socket_getch(net->socket, &c);
#endif
rc = WebSocket_getch(net, &c);
if (rc != TCPSOCKET_COMPLETE)
goto exit;
*value += (c & 127) * multiplier;
......
......@@ -219,13 +219,14 @@ Publications* MQTTProtocol_storePublication(Publish* publish, int* len)
p->refcount = 1;
*len = (int)strlen(publish->topic)+1;
if (Heap_findItem(publish->topic))
p->topic = publish->topic;
else
{
p->topic = malloc(*len);
strcpy(p->topic, publish->topic);
if (Heap_findItem(publish->topic))
{
free(publish->topic);
publish->topic = NULL;
}
*len += sizeof(Publications);
p->topiclen = publish->topiclen;
......
......@@ -35,6 +35,7 @@
#include "MQTTProtocolOut.h"
#include "StackTrace.h"
#include "Heap.h"
#include "WebSocket.h"
extern ClientStates* bstate;
......@@ -42,11 +43,12 @@ extern ClientStates* bstate;
/**
* Separates an address:port into two separate values
* @param uri the input string - hostname:port
* @param port the returned port integer
* @param[in] uri the input string - hostname:port
* @param[out] port the returned port integer
* @param[out] topic optional topic portion of the address starting with '/'
* @return the address string
*/
char* MQTTProtocol_addressPort(const char* uri, int* port)
size_t MQTTProtocol_addressPort(const char* uri, int* port, const char **topic)
{
char* colon_pos = strrchr(uri, ':'); /* reverse find to allow for ':' in IPv6 addresses */
char* buf = (char*)uri;
......@@ -61,27 +63,31 @@ char* MQTTProtocol_addressPort(const char* uri, int* port)
if (colon_pos) /* have to strip off the port */
{
size_t addr_len = colon_pos - uri;
buf = malloc(addr_len + 1);
len = colon_pos - uri;
*port = atoi(colon_pos + 1);
MQTTStrncpy(buf, uri, addr_len+1);
}
else
{
len = strlen(buf);
*port = DEFAULT_PORT;
}
len = strlen(buf);
if (buf[len - 1] == ']')
{
if (buf == (char*)uri)
/* try and find topic portion */
if ( topic )
{
buf = malloc(len); /* we are stripping off the final ], so length is 1 shorter */
MQTTStrncpy(buf, uri, len);
const char* addr_start = uri;
if ( colon_pos )
addr_start = colon_pos;
*topic = strchr( addr_start, '/' );
}
else
buf[len - 1] = '\0';
if (buf[len - 1] == ']')
{
/* we are stripping off the final ], so length is 1 shorter */
--len;
}
FUNC_EXIT;
return buf;
return len;
}
......@@ -94,49 +100,52 @@ char* MQTTProtocol_addressPort(const char* uri, int* port)
* @return return code
*/
#if defined(OPENSSL)
int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int ssl, int MQTTVersion)
int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int ssl, int websocket, int MQTTVersion)
#else
int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersion)
int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int websocket, int MQTTVersion)
#endif
{
int rc, port;
char* addr;
size_t addr_len;
FUNC_ENTRY;
aClient->good = 1;
addr = MQTTProtocol_addressPort(ip_address, &port);
rc = Socket_new(addr, port, &(aClient->net.socket));
addr_len = MQTTProtocol_addressPort(ip_address, &port, NULL);
rc = Socket_new(ip_address, addr_len, port, &(aClient->net.socket));
if (rc == EINPROGRESS || rc == EWOULDBLOCK)
aClient->connect_state = 1; /* TCP connect called - wait for connect completion */
aClient->connect_state = TCP_IN_PROGRESS; /* TCP connect called - wait for connect completion */
else if (rc == 0)
{ /* TCP connect completed. If SSL, send SSL connect */
#if defined(OPENSSL)
if (ssl)
{
if (SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, addr) == 1)
if (SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, ip_address, addr_len) == 1)
{
rc = SSLSocket_connect(aClient->net.ssl, aClient->net.socket,
addr, aClient->sslopts->verify);
ip_address, aClient->sslopts->verify);
if (rc == TCPSOCKET_INTERRUPTED)
aClient->connect_state = 2; /* SSL connect called - wait for completion */
aClient->connect_state = SSL_IN_PROGRESS; /* SSL connect called - wait for completion */
}
else
rc = SOCKET_ERROR;
}
#endif
if ( websocket )
{
rc = WebSocket_connect( &aClient->net, ip_address );
if ( rc == TCPSOCKET_INTERRUPTED )
aClient->connect_state = WEBSOCKET_IN_PROGRESS; /* Websocket connect called - wait for completion */
}
if (rc == 0)
{
/* Now send the MQTT connect packet */
if ((rc = MQTTPacket_send_connect(aClient, MQTTVersion)) == 0)
aClient->connect_state = 3; /* MQTT Connect sent - wait for CONNACK */
aClient->connect_state = WAIT_FOR_CONNACK; /* MQTT Connect sent - wait for CONNACK */
else
aClient->connect_state = 0;
aClient->connect_state = NOT_IN_PROGRESS;
}
}
if (addr != ip_address)
free(addr);
FUNC_EXIT_RC(rc);
return rc;
......
......@@ -30,12 +30,12 @@
#define DEFAULT_PORT 1883
char* MQTTProtocol_addressPort(const char* uri, int* port);
size_t MQTTProtocol_addressPort(const char* uri, int* port, const char **topic);
void MQTTProtocol_reconnect(const char* ip_address, Clients* client);
#if defined(OPENSSL)
int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int ssl, int MQTTVersion);
int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int ssl, int websocket, int MQTTVersion);
#else
int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int MQTTVersion);
int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int websocket, int MQTTVersion);
#endif
int MQTTProtocol_handlePingresps(void* pack, int sock);
int MQTTProtocol_subscribe(Clients* client, List* topics, List* qoss, int msgID);
......
......@@ -60,7 +60,7 @@
char* FindString(char* filename, const char* eyecatcher_input);
int printVersionInfo(MQTTAsync_nameValue* info);
int loadandcall(char* libname);
int loadandcall(const char* libname);
void printEyecatchers(char* filename);
......@@ -133,16 +133,14 @@ int printVersionInfo(MQTTAsync_nameValue* info)
typedef MQTTAsync_nameValue* (*func_type)(void);
int loadandcall(char* libname)
int loadandcall(const char* libname)
{
int rc = 0;
MQTTAsync_nameValue* (*func_address)(void) = NULL;
#if defined(WIN32) || defined(WIN64)
wchar_t wlibname[30];
HMODULE APILibrary;
mbstowcs(wlibname, libname, strlen(libname) + 1);
if ((APILibrary = LoadLibrary(wlibname)) == NULL)
if ((APILibrary = LoadLibraryA(libname)) == NULL)
printf("Error loading library %s, error code %d\n", libname, GetLastError());
else
{
......
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#include "SHA1.h"
#if !defined(OPENSSL)
#if defined(WIN32) || defined(WIN64)
#pragma comment(lib, "crypt32.lib")
int SHA1_Init(SHA_CTX *c)
{
if (!CryptAcquireContext(&c->hProv, NULL, NULL,
PROV_RSA_FULL, CRYPT_VERIFYCONTEXT))
return 0;
if (!CryptCreateHash(c->hProv, CALG_SHA1, 0, 0, &c->hHash))
{
CryptReleaseContext(c->hProv, 0);
return 0;
}
return 1;
}
int SHA1_Update(SHA_CTX *c, const void *data, size_t len)
{
int rv = 1;
if (!CryptHashData(c->hHash, data, len, 0))
rv = 0;
return rv;
}
int SHA1_Final(unsigned char *md, SHA_CTX *c)
{
int rv = 0;
DWORD md_len = SHA1_DIGEST_LENGTH;
if (CryptGetHashParam(c->hHash, HP_HASHVAL, md, &md_len, 0))
rv = 1;
CryptDestroyHash(c->hHash);
CryptReleaseContext(c->hProv, 0);
return rv;
}
#else /* if defined(WIN32) || defined(WIN64) */
#if defined(__linux__)
# include <endian.h>
#elif defined(__APPLE__)
# include <libkern/OSByteOrder.h>
# define htobe32(x) OSSwapHostToBigInt32(x)
# define be32toh(x) OSSwapBigToHostInt32(x)
#elif defined(__FreeBSD__) || defined(__NetBSD__)
# include <sys/endian.h>
#endif
#include <string.h>
static unsigned char pad[64] = {
0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
};
int SHA1_Init(SHA_CTX *ctx)
{
int ret = 0;
if ( ctx )
{
ctx->h[0] = 0x67452301;
ctx->h[1] = 0xEFCDAB89;
ctx->h[2] = 0x98BADCFE;
ctx->h[3] = 0x10325476;
ctx->h[4] = 0xC3D2E1F0;
ctx->size = 0u;
ctx->total = 0u;
ret = 1;
}
return ret;
}
#define ROTATE_LEFT32(a, n) (((a) << (n)) | ((a) >> (32 - (n))))
static void SHA1_ProcessBlock(SHA_CTX *ctx)
{
uint32_t blks[5];
uint32_t *w;
int i;
/* initialize */
for ( i = 0; i < 5; ++i )
blks[i] = ctx->h[i];
w = ctx->w;
/* perform SHA-1 hash */
for ( i = 0; i < 16; ++i )
w[i] = be32toh(w[i]);
for( i = 0; i < 80; ++i )
{
int tmp;
if ( i >= 16 )
w[i & 0x0F] = ROTATE_LEFT32( w[(i+13) & 0x0F] ^ w[(i+8) & 0x0F] ^ w[(i+2) & 0x0F] ^ w[i & 0x0F], 1 );
if ( i < 20 )
tmp = ROTATE_LEFT32(blks[0], 5) + ((blks[1] & blks[2]) | (~(blks[1]) & blks[3])) + blks[4] + w[i & 0x0F] + 0x5A827999;
else if ( i < 40 )
tmp = ROTATE_LEFT32(blks[0], 5) + (blks[1]^blks[2]^blks[3]) + blks[4] + w[i & 0x0F] + 0x6ED9EBA1;
else if ( i < 60 )
tmp = ROTATE_LEFT32(blks[0], 5) + ((blks[1] & blks[2]) | (blks[1] & blks[3]) | (blks[2] & blks[3])) + blks[4] + w[i & 0x0F] + 0x8F1BBCDC;
else
tmp = ROTATE_LEFT32(blks[0], 5) + (blks[1]^blks[2]^blks[3]) + blks[4] + w[i & 0x0F] + 0xCA62C1D6;
/* update registers */
blks[4] = blks[3];
blks[3] = blks[2];
blks[2] = ROTATE_LEFT32(blks[1], 30);
blks[1] = blks[0];
blks[0] = tmp;
}
/* update of hash */
for ( i = 0; i < 5; ++i )
ctx->h[i] += blks[i];
}
int SHA1_Final(unsigned char *md, SHA_CTX *ctx)
{
int i;
int ret = 0;
size_t pad_amount;
uint64_t total;
/* length before pad */
total = ctx->total * 8;
if ( ctx->size < 56 )
pad_amount = 56 - ctx->size;
else
pad_amount = 64 + 56 - ctx->size;
SHA1_Update(ctx, pad, pad_amount);
ctx->w[14] = htobe32((uint32_t)(total >> 32));
ctx->w[15] = htobe32((uint32_t)total);
SHA1_ProcessBlock(ctx);
for ( i = 0; i < 5; ++i )
ctx->h[i] = htobe32(ctx->h[i]);
if ( md )
{
memcpy( md, &ctx->h[0], SHA1_DIGEST_LENGTH );
ret = 1;
}
return ret;
}
int SHA1_Update(SHA_CTX *ctx, const void *data, size_t len)
{
while ( len > 0 )
{
unsigned int n = 64 - ctx->size;
if ( len < n )
n = len;
memcpy(ctx->buffer + ctx->size, data, n);
ctx->size += n;
ctx->total += n;
data = (uint8_t *)data + n;
len -= n;
if ( ctx->size == 64 )
{
SHA1_ProcessBlock(ctx);
ctx->size = 0;
}
}
return 1;
}
#endif /* else if defined(WIN32) || defined(WIN64) */
#endif /* elif !defined(OPENSSL) */
#if defined(SHA1_TEST)
#include <stdio.h>
#include <string.h>
#define TEST_EXPECT(i,x) if (!(x)) {fprintf( stderr, "failed test: %s (for i == %d)\n", #x, i ); ++fails;}
int main(int argc, char *argv[])
{
struct _td
{
const char *in;
const char *out;
};
int i;
unsigned int fails = 0u;
struct _td test_data[] = {
{ "", "da39a3ee5e6b4b0d3255bfef95601890afd80709" },
{ "this string", "fda4e74bc7489a18b146abdf23346d166663dab8" },
{ NULL, NULL }
};
/* only 1 update */
i = 0;
while ( test_data[i].in != NULL )
{
int r[3] = { 1, 1, 1 };
unsigned char sha_out[SHA1_DIGEST_LENGTH];
char out[SHA1_DIGEST_LENGTH * 2 + 1];
SHA_CTX c;
int j;
r[0] = SHA1_Init( &c );
r[1] = SHA1_Update( &c, test_data[i].in, strlen(test_data[i].in));
r[2] = SHA1_Final( sha_out, &c );
for ( j = 0u; j < SHA1_DIGEST_LENGTH; ++j )
snprintf( &out[j*2], 3u, "%02x", sha_out[j] );
out[SHA1_DIGEST_LENGTH * 2] = '\0';
TEST_EXPECT( i, r[0] == 1 && r[1] == 1 && r[2] == 1 && strncmp(out, test_data[i].out, strlen(test_data[i].out)) == 0 );
++i;
}
if ( fails )
printf( "%u test failed!\n", fails );
else
printf( "all tests passed\n" );
return fails;
}
#endif /* if defined(SHA1_TEST) */
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#if !defined(SHA1_H)
#define SHA1_H
#if defined(OPENSSL)
#include <openssl/sha.h>
/** SHA-1 Digest Length */
#define SHA1_DIGEST_LENGTH SHA_DIGEST_LENGTH
#else /* if defined(OPENSSL) */
#if defined(WIN32) || defined(WIN64)
#include <Windows.h>
#include <WinCrypt.h>
typedef struct SHA_CTX_S
{
HCRYPTPROV hProv;
HCRYPTHASH hHash;
} SHA_CTX;
#else /* if defined(WIN32) || defined(WIN64) */
#include <stdint.h>
typedef struct SHA_CTX_S {
uint32_t h[5];
union {
uint32_t w[16];
uint8_t buffer[64];
};
unsigned int size;
unsigned int total;
} SHA_CTX;
#endif /* else if defined(WIN32) || defined(WIN64) */
#include <stddef.h>
/** SHA-1 Digest Length (number of bytes in SHA1) */
#define SHA1_DIGEST_LENGTH (160/8)
/**
* Initializes the SHA1 hashing algorithm
*
* @param[in,out] ctx hashing context structure
*
* @see SHA1_Update
* @see SHA1_Final
*/
int SHA1_Init(SHA_CTX *ctx);
/**
* Updates a block to the SHA1 hash
*
* @param[in,out] ctx hashing context structure
* @param[in] data block of data to hash
* @param[in] len length of block to hash
*
* @see SHA1_Init
* @see SHA1_Final
*/
int SHA1_Update(SHA_CTX *ctx, const void *data, size_t len);
/**
* Produce final SHA1 hash
*
* @param[out] md SHA1 hash produced (must be atleast
* @p SHA1_DIGEST_LENGTH in length)
* @param[in,out] ctx hashing context structure
*
* @see SHA1_Init
* @see SHA1_Final
*/
int SHA1_Final(unsigned char *md, SHA_CTX *ctx);
#endif /* if defined(OPENSSL) */
#endif /* SHA1_H */
......@@ -30,11 +30,11 @@
#include "SocketBuffer.h"
#include "MQTTClient.h"
#include "MQTTProtocolOut.h"
#include "SSLSocket.h"
#include "Log.h"
#include "StackTrace.h"
#include "Socket.h"
char* MQTTProtocol_addressPort(const char* uri, int* port);
#include "Heap.h"
......@@ -620,7 +620,8 @@ exit:
}
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, char* hostname)
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
const char* hostname, size_t hostname_len)
{
int rc = 1;
......@@ -628,6 +629,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
if (net->ctx != NULL || (rc = SSLSocket_createContext(net, opts)) == 1)
{
char *hostname_plus_null;
int i;
SSL_CTX_set_info_callback(net->ctx, SSL_CTX_info_callback);
......@@ -648,8 +650,11 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
if ((rc = SSL_set_fd(net->ssl, net->socket)) != 1)
SSLSocket_error("SSL_set_fd", net->ssl, net->socket, rc);
if ((rc = SSL_set_tlsext_host_name(net->ssl, hostname)) != 1)
hostname_plus_null = malloc(hostname_len + 1u );
MQTTStrncpy(hostname_plus_null, hostname, hostname_len + 1u);
if ((rc = SSL_set_tlsext_host_name(net->ssl, hostname_plus_null)) != 1)
SSLSocket_error("SSL_set_tlsext_host_name", NULL, net->socket, rc);
free(hostname_plus_null);
}
FUNC_EXIT_RC(rc);
......@@ -659,7 +664,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
/*
* Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure
*/
int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify)
int SSLSocket_connect(SSL* ssl, int sock, const char* hostname, int verify)
{
int rc = 0;
......@@ -680,12 +685,12 @@ int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify)
{
char* peername = NULL;
int port;
char* addr = NULL;
size_t hostname_len;
X509* cert = SSL_get_peer_certificate(ssl);
addr = MQTTProtocol_addressPort(hostname, &port);
hostname_len = MQTTProtocol_addressPort(hostname, &port, NULL);
rc = X509_check_host(cert, addr, strlen(addr), 0, &peername);
rc = X509_check_host(cert, hostname, hostname_len, 0, &peername);
if (rc == 0)
rc = SOCKET_ERROR;
Log(TRACE_MIN, -1, "rc from X509_check_host is %d", rc);
......@@ -693,8 +698,6 @@ int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify)
if (cert)
X509_free(cert);
if (addr != hostname)
free(addr);
}
#endif
......
......@@ -37,14 +37,14 @@ void SSLSocket_handleOpensslInit(int bool_value);
int SSLSocket_initialize(void);
void SSLSocket_terminate(void);
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, char* hostname);
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, const char* hostname, size_t hostname_len);
int SSLSocket_getch(SSL* ssl, int socket, char* c);
char *SSLSocket_getdata(SSL* ssl, int socket, size_t bytes, size_t* actual_len);
int SSLSocket_close(networkHandles* net);
int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int* frees);
int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify);
int SSLSocket_connect(SSL* ssl, int sock, const char* hostname, int verify);
int SSLSocket_getPendingRead(void);
int SSLSocket_continueWrite(pending_writes* pw);
......
......@@ -650,9 +650,10 @@ void Socket_close(int socket)
* @param sock returns the new socket
* @return completion code
*/
int Socket_new(char* addr, int port, int* sock)
int Socket_new(const char* addr, size_t addr_len, int port, int* sock)
{
int type = SOCK_STREAM;
char *addr_mem;
struct sockaddr_in address;
#if defined(AF_INET6)
struct sockaddr_in6 address6;
......@@ -671,9 +672,16 @@ int Socket_new(char* addr, int port, int* sock)
memset(&address6, '\0', sizeof(address6));
if (addr[0] == '[')
{
++addr;
--addr_len;
}
addr_mem = malloc( addr_len + 1u );
memcpy( addr_mem, addr, addr_len );
addr_mem[addr_len] = '\0';
if ((rc = getaddrinfo(addr, NULL, &hints, &result)) == 0)
if ((rc = getaddrinfo(addr_mem, NULL, &hints, &result)) == 0)
{
struct addrinfo* res = result;
......@@ -708,10 +716,10 @@ int Socket_new(char* addr, int port, int* sock)
freeaddrinfo(result);
}
else
Log(LOG_ERROR, -1, "getaddrinfo failed for addr %s with rc %d", addr, rc);
Log(LOG_ERROR, -1, "getaddrinfo failed for addr %s with rc %d", addr_mem, rc);
if (rc != 0)
Log(LOG_ERROR, -1, "%s is not a valid IP address", addr);
Log(LOG_ERROR, -1, "%s is not a valid IP address", addr_mem);
else
{
*sock = (int)socket(family, type, 0);
......@@ -771,6 +779,10 @@ int Socket_new(char* addr, int port, int* sock)
}
}
}
if (addr_mem)
free(addr_mem);
FUNC_EXIT_RC(rc);
return rc;
}
......
......@@ -131,7 +131,7 @@ int Socket_getch(int socket, char* c);
char *Socket_getdata(int socket, size_t bytes, size_t* actual_len);
int Socket_putdatas(int socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int* frees);
void Socket_close(int socket);
int Socket_new(char* addr, int port, int* socket);
int Socket_new(const char* addr, size_t addr_len, int port, int* socket);
int Socket_noPendingWrites(int socket);
char* Socket_getpeer(int sock);
......
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "WebSocket.h"
#include "Base64.h"
#include "Log.h"
#include "SHA1.h"
#include "LinkedList.h"
#include "MQTTProtocolOut.h"
#if defined(__linux__)
# include <endian.h>
#elif defined(__APPLE__)
# include <libkern/OSByteOrder.h>
# define htobe16(x) OSSwapHostToBigInt16(x)
# define htobe32(x) OSSwapHostToBigInt32(x)
# define htobe64(x) OSSwapHostToBigInt64(x)
# define be16toh(x) OSSwapBigToHostInt16(x)
# define be32toh(x) OSSwapBigToHostInt32(x)
# define be64toh(x) OSSwapBigToHostInt64(x)
#elif defined(__FreeBSD__) || defined(__NetBSD__)
# include <sys/endian.h>
#elif defined(WIN32) || defined(WIN64)
# pragma comment(lib, "rpcrt4.lib")
# include <Rpc.h>
# define strncasecmp(s1,s2,c) _strnicmp(s1,s2,c)
# if BYTE_ORDER == LITTLE_ENDIAN
# define htobe16(x) htons(x)
# define htobe32(x) htonl(x)
# define htobe64(x) htonll(x)
# define be16toh(x) ntohs(x)
# define be32toh(x) ntohl(x)
# define be64toh(x) ntohll(x)
# elif BTYE_ORDER == BIG_ENDIAN
# define htobe16(x) (x)
# define htobe32(x) (x)
# define htobe64(x) (x)
# define be16toh(x) (x)
# define be32toh(x) (x)
# define be64toh(x) (x)
# else
# error "unknown endian"
# endif
/* For Microsoft Visual Studio 2013 */
# if !defined( snprintf )
# define snprintf _snprintf
# endif /* if !defined( snprintf ) */
#endif
#if defined(OPENSSL)
#include "SSLSocket.h"
#endif /* defined(OPENSSL) */
#include "Socket.h"
#if !(defined(WIN32) || defined(WIN64))
#if defined(LIBUUID)
#include <uuid/uuid.h>
#else /* if defined(USE_LIBUUID) */
#include <limits.h>
#include <stdlib.h>
#include <time.h>
#if defined(OPENSSL)
#include <openssl/rand.h>
#endif /* if defined(OPENSSL) */
/** @brief raw uuid type */
typedef unsigned char uuid_t[16];
/**
* @brief generates a uuid, compatible with RFC 4122, version 4 (random)
* @note Uses a very insecure algorithm but no external dependencies
*/
void uuid_generate( uuid_t out )
{
#if defined(OPENSSL)
int rc = RAND_bytes( out, sizeof(uuid_t));
if ( !rc )
#endif /* defined (OPENSSL) */
{
/* very insecure, but generates a random uuid */
srand(time(NULL));
int i;
for ( i = 0; i < 16; ++i )
out[i] = (unsigned char)(rand() % UCHAR_MAX);
out[6] = (out[6] & 0x0f) | 0x40;
out[8] = (out[8] & 0x3F) | 0x80;
}
}
/** @brief converts a uuid to a string */
void uuid_unparse( uuid_t uu, char *out )
{
int i;
for ( i = 0; i < 16; ++i )
{
if ( i == 4 || i == 6 || i == 8 || i == 10 )
{
*out = '-';
++out;
}
out += sprintf( out, "%02x", uu[i] );
}
*out = '\0';
}
#endif /* else if defined(LIBUUID) */
#endif /* if !(defined(WIN32) || defined(WIN64)) */
/** raw websocket frame data */
struct ws_frame
{
size_t len; /**< length of frame */
size_t pos; /**< current position within the buffer */
};
/** Current frame being processed */
struct ws_frame *last_frame = NULL;
/** Holds any received websocket frames, to be process */
static List* in_frames = NULL;
/* static function declarations */
static const char *WebSocket_strcasefind(
const char *buf, const char *str, size_t len);
static char *WebSocket_getRawSocketData(
networkHandles *net, size_t bytes, size_t* actual_len);
static void WebSocket_pong(
networkHandles *net, char *app_data, size_t app_data_len);
static int WebSocket_receiveFrame(networkHandles *net,
size_t bytes, size_t *actual_len );
/**
* @brief builds a websocket frame for data transmission
*
* write a websocket header and will mask the payload in all the passed in
* buffers
*
* @param[in,out] net network connection
* @param[in] opcode websocket opcode for the packet
* @param[in] mask_data whether to maskt he data
* @param[in,out] buf0 first buffer, will write before this
* @param[in] buf0len size of first buffer
* @param[in] count number of payload buffers
* @param[in,out] buffers array of paylaod buffers
* @param[in] buflens array of payload buffer sizes
* @param[in] freeData array indicating to free payload buffers
*
* @return amount of data to write to socket
*/
static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data,
char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens)
{
int i;
int buf_len = 0u;
size_t data_len = buf0len;
for (i = 0; i < count; ++i)
data_len += buflens[i];
buf0 -= WebSocket_calculateFrameHeaderSize(net, mask_data, data_len);
if ( net->websocket )
{
uint8_t mask[4];
/* genearate mask, since we are a client */
#if defined(OPENSSL)
RAND_bytes( &mask[0], sizeof(mask) );
#else /* if defined(OPENSSL) */
mask[0] = (rand() % UINT8_MAX);
mask[1] = (rand() % UINT8_MAX);
mask[2] = (rand() % UINT8_MAX);
mask[3] = (rand() % UINT8_MAX);
#endif /* else if defined(OPENSSL) */
/* 1st byte */
buf0[buf_len] = (char)(1 << 7); /* final flag */
/* 3 bits reserved for negotiation of protocol */
buf0[buf_len] |= (char)(opcode & 0x0F); /* op code */
++buf_len;
/* 2nd byte */
buf0[buf_len] = (char)((mask_data & 0x1) << 7); /* masking bit */
/* payload length */
if ( data_len < 126u )
buf0[buf_len++] |= data_len & 0x7F;
/* 3rd byte & 4th bytes - extended payload length */
else if ( data_len <= 65536u )
{
uint16_t len = htobe16((uint16_t)data_len);
buf0[buf_len++] |= (126u & 0x7F);
memcpy( &buf0[buf_len], &len, 2u );
buf_len += 2;
}
else if ( data_len < 0xFFFFFFFFFFFFFFFF )
{
uint64_t len = htobe64((uint64_t)data_len);
buf0[buf_len++] |= (127u & 0x7F);
memcpy( &buf0[buf_len], &len, 8 );
buf_len += 8;
}
else
{
Log(TRACE_PROTOCOL, 1, "Data too large for websocket frame" );
buf_len = -1;
}
/* masking key */
if ( (mask_data & 0x1) && buf_len > 0 )
{
memcpy( &buf0[buf_len], &mask, sizeof(uint32_t));
buf_len += sizeof(uint32_t);
}
/* mask data */
if ( mask_data & 0x1 )
{
size_t idx = 0u;
/* packet fixed header */
for (i = 0; i < (int)buf0len; ++i, ++idx)
buf0[buf_len + i] ^= mask[idx % 4];
/* variable data buffers */
for (i = 0; i < count; ++i)
{
size_t j;
for ( j = 0u; j < buflens[i]; ++j, ++idx )
buffers[i][j] ^= mask[idx % 4];
}
}
}
return buf_len;
}
/**
* calculates the amount of data required for the websocket header
*
* this function is used to calculate how much offset is required before calling
* @p WebSocket_putdatas, as that function will write data before the passed in
* buffer
*
* @param[in,out] net network connection
* @param[in] mask_data whether to mask the data
* @param[in] data_len amount of data in the payload
*
* @return the size in bytes of the websocket header required
*
* @see WebSocket_putdatas
*/
size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, int mask_data,
size_t data_len)
{
int ret = 0;
if ( net && net->websocket )
{
ret += 2; /* header 2 bytes */
if ( data_len >= 126 )
ret += 2; /* for extra 2-bytes for payload length */
else if ( data_len > 65536u )
ret += 8; /* for extra 8-bytes for payload length */
if ( mask_data & 0x1 )
ret += sizeof(uint32_t); /* for mask */
}
return ret;
}
/**
* sends out a websocket request on the given uri
*
* @param[in] net network connection
* @param[in] uri uri to connect to
*
* @retval SOCKET_ERROR on failure
* @retval 1 on success
*
* @see WebSocket_upgrade
*/
int WebSocket_connect( networkHandles *net, const char *uri )
{
int rc;
char *buf = NULL;
int i, buf_len = 0;
size_t hostname_len;
int port = 80;
const char *topic = NULL;
/* Generate UUID */
net->websocket_key = realloc(net->websocket_key, 25u);
#if defined(WIN32) || defined(WIN64)
UUID uuid;
ZeroMemory( &uuid, sizeof(UUID) );
UuidCreate( &uuid );
Base64_encode( net->websocket_key, 25u, (const b64_data_t*)&uuid, sizeof(UUID) );
#else /* if defined(WIN32) || defined(WIN64) */
uuid_t uuid;
uuid_generate( uuid );
Base64_encode( net->websocket_key, 25u, uuid, sizeof(uuid_t) );
#endif /* else if defined(WIN32) || defined(WIN64) */
hostname_len = MQTTProtocol_addressPort(uri, &port, &topic);
/* if no topic, use default */
if ( !topic )
topic = "/mqtt";
for ( i = 0; i < 2; ++i )
{
buf_len = snprintf( buf, (size_t)buf_len,
"GET %s HTTP/1.1\r\n"
"Host: %.*s:%d\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Origin: http://%.*s:%d\r\n"
"Sec-WebSocket-Key: %s\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Protocol: mqtt\r\n"
"\r\n", topic,
(int)hostname_len, uri, port,
(int)hostname_len, uri, port,
net->websocket_key );
if ( i == 0 && buf_len > 0 )
{
++buf_len; /* need 1 extra byte for ending '\0' */
buf = malloc( buf_len );
}
}
if ( buf )
{
#if defined(OPENSSL)
if (net->ssl)
SSLSocket_putdatas(net->ssl, net->socket,
buf, buf_len, 0, NULL, NULL, NULL );
else
#endif
Socket_putdatas( net->socket, buf, buf_len,
0, NULL, NULL, NULL );
free( buf );
rc = 1;
}
else
{
free(net->websocket_key);
net->websocket_key = NULL;
rc = SOCKET_ERROR;
}
return rc;
}
/**
* closes a websocket connection
*
* @param[in,out] net structure containing network connection
* @param[in] status_code websocket close status code
* @param[in] reason reason for closing connection (optional)
*/
void WebSocket_close(networkHandles *net, int status_code, const char *reason)
{
if ( net->websocket )
{
char *buf0;
size_t buf0len = sizeof(uint16_t);
size_t header_len;
uint16_t status_code_be;
const int mask_data = 0;
if ( status_code < WebSocket_CLOSE_NORMAL ||
status_code > WebSocket_CLOSE_TLS_FAIL )
status_code = WebSocket_CLOSE_GOING_AWAY;
if ( reason )
buf0len += strlen(reason);
header_len = WebSocket_calculateFrameHeaderSize(net,
mask_data, buf0len);
buf0 = malloc(header_len + buf0len);
if ( !buf0 ) return;
/* encode status code */
status_code_be = htobe16((uint16_t)status_code);
memcpy( &buf0[header_len], &status_code_be, sizeof(uint16_t));
/* encode reason, if provided */
if ( reason )
strcpy( &buf0[header_len + sizeof(uint16_t)], reason );
WebSocket_buildFrame( net, WebSocket_OP_CLOSE, mask_data,
&buf0[header_len], buf0len, 0, NULL, NULL );
buf0len += header_len;
#if defined(OPENSSL)
if (net->ssl)
SSLSocket_putdatas(net->ssl, net->socket,
buf0, buf0len, 0, NULL, NULL, NULL);
else
#endif
Socket_putdatas(net->socket, buf0, buf0len, 0,
NULL, NULL, NULL);
/* websocket connection is now closed */
net->websocket = 0;
free( buf0 );
}
if ( net->websocket_key )
free( net->websocket_key );
}
/**
* @brief receives 1 byte from a socket
*
* @param[in,out] net network connection
* @param[out] c byte that was read
*
* @retval SOCKET_ERROR on error
* @retval TCPSOCKET_INTERRUPTED no data available
* @retval TCPSOCKET_COMPLETE on success
*
* @see WebSocket_getdata
*/
int WebSocket_getch(networkHandles *net, char* c)
{
int rc = SOCKET_ERROR;
if ( net->websocket )
{
struct ws_frame *frame = NULL;
if ( in_frames && in_frames->first )
frame = in_frames->first->content;
if ( !frame )
{
size_t actual_len = 0u;
rc = WebSocket_receiveFrame( net, 1u, &actual_len );
if ( rc != TCPSOCKET_COMPLETE )
return rc;
/* we got a frame, let take off the top of queue */
if ( in_frames->first )
frame = in_frames->first->content;
}
/* set current working frame */
if (frame && frame->len > frame->pos)
{
unsigned char *buf =
(unsigned char *)frame + sizeof(struct ws_frame);
*c = buf[frame->pos++];
rc = TCPSOCKET_COMPLETE;
}
}
#if defined(OPENSSL)
else if ( net->ssl )
rc = SSLSocket_getch(net->ssl, net->socket, c);
#endif
else
rc = Socket_getch(net->socket, c);
return rc;
}
/**
* @brief receives data from a socket
*
* @param[in,out] net network connection
* @param[in] bytes amount of data to get (0 if last packet)
* @param[out] actual_len amount of data read
*
* @return a pointer to the read data
*
* @see WebSocket_getch
*/
char *WebSocket_getdata(networkHandles *net, size_t bytes, size_t* actual_len)
{
char *rv = NULL;
if ( net->websocket )
{
struct ws_frame *frame = NULL;
if ( bytes == 0u )
{
/* done with current frame, move it to last frame */
if ( in_frames && in_frames->first )
frame = in_frames->first->content;
/* return the data from the next frame, if we have one */
if ( frame )
{
rv = (char *)frame +
sizeof(struct ws_frame) + frame->pos;
*actual_len = frame->len - frame->pos;
if ( last_frame )
free( last_frame );
last_frame = ListDetachHead(in_frames);
}
return rv;
}
/* no current frame, let's see if there's one in the list */
if ( in_frames && in_frames->first )
frame = in_frames->first->content;
/* no current frame, so let's go receive one for the network */
if ( !frame )
{
const int rc =
WebSocket_receiveFrame( net, bytes, actual_len );
if ( rc == TCPSOCKET_COMPLETE && in_frames && in_frames->first)
frame = in_frames->first->content;
}
if ( frame )
{
rv = (char *)frame + sizeof(struct ws_frame) + frame->pos;
*actual_len = frame->len - frame->pos;
if ( *actual_len == bytes && in_frames)
{
/* set new frame as current frame */
if ( last_frame )
free( last_frame );
last_frame = ListDetachHead(in_frames);
}
}
}
else
rv = WebSocket_getRawSocketData(net, bytes, actual_len);
return rv;
}
/**
* reads raw socket data for underlying layers
*
* @param[in] net network connection
* @param[in] bytes number of bytes to read, 0 to complete packet
* @param[in] actual_len amount of data read
*
* @return a buffer containing raw data
*/
char *WebSocket_getRawSocketData(
networkHandles *net, size_t bytes, size_t* actual_len)
{
char *rv;
#if defined(OPENSSL)
if ( net->ssl )
rv = SSLSocket_getdata(net->ssl, net->socket, bytes, actual_len);
else
#endif
rv = Socket_getdata(net->socket, bytes, actual_len);
return rv;
}
/**
* sends a "websocket pong" message
*
* @param[in] net network connection
* @param[in] app_data application data to put in payload
* @param[in] app_data_len application data length
*/
void WebSocket_pong(networkHandles *net, char *app_data,
size_t app_data_len)
{
if ( net->websocket )
{
char *buf0;
size_t header_len;
int freeData = 0;
const int mask_data = 0;
header_len = WebSocket_calculateFrameHeaderSize(net, mask_data,
app_data_len);
buf0 = malloc(header_len);
if ( !buf0 ) return;
WebSocket_buildFrame( net, WebSocket_OP_PONG, 1,
&buf0[header_len], header_len, mask_data, &app_data,
&app_data_len );
Log(TRACE_PROTOCOL, 1, "Sending WebSocket PONG" );
#if defined(OPENSSL)
if (net->ssl)
SSLSocket_putdatas(net->ssl, net->socket, buf0,
header_len + app_data_len, 1,
&app_data, &app_data_len, &freeData);
else
#endif
Socket_putdatas(net->socket, buf0,
header_len + app_data_len, 1,
&app_data, &app_data_len, &freeData );
/* clean up memory */
free( buf0 );
}
}
/**
* writes data to a socket (websocket header will be prepended if required)
*
* @warning buf0 will be expanded (backwords before @p buf0 buffer, to add a
* websocket frame header to the data if required). So use
* @p WebSocket_calculateFrameHeader, to determine if extra space is needed
* before the @p buf0 pointer.
*
* @param[in,out] net network connection
* @param[in,out] buf0 first buffer
* @param[in] buf0len size of first buffer
* @param[in] count number of payload buffers
* @param[in,out] buffers array of paylaod buffers
* @param[in] buflens array of payload buffer sizes
* @param[in] freeData array indicating to free payload buffers
*
* @return amount of data wrote to socket
*
* @see WebSocket_calculateFrameHeaderSize
*/
int WebSocket_putdatas(networkHandles* net, char* buf0, size_t buf0len,
int count, char** buffers, size_t* buflens, int* freeData)
{
int rc;
/* prepend WebSocket frame */
if ( net->websocket )
{
size_t data_len = buf0len + 4u;
size_t header_len;
const int mask_data = 1;
for (rc = 0; rc < count; ++rc)
data_len += buflens[rc];
header_len = WebSocket_calculateFrameHeaderSize(
net, mask_data, data_len);
rc = WebSocket_buildFrame(
net, WebSocket_OP_BINARY, mask_data, buf0, buf0len,
count, buffers, buflens );
/* header added so adjust buffer */
if ( rc > 0 )
{
buf0 -= header_len;
buf0len += header_len;
}
}
#if defined(OPENSSL)
if (net->ssl)
rc = SSLSocket_putdatas(net->ssl, net->socket, buf0, buf0len, count, buffers, buflens, freeData);
else
#endif
rc = Socket_putdatas(net->socket, buf0, buf0len, count, buffers, buflens, freeData);
return rc;
}
/**
* receives incoming socket data and parses websocket frames
*
* @param[in] net network connection
* @param[in] bytes amount of data to receive
* @param[out] actual_len amount of data actually read
*
* @retval TCPSOCKET_COMPLETE packet received
* @retval TCPSOCKET_INTERRUPTED incomplete packet received
* @retval SOCKET_ERROR an error was encountered
*/
int WebSocket_receiveFrame(networkHandles *net,
size_t bytes, size_t *actual_len )
{
struct ws_frame *res = NULL;
if ( !in_frames )
in_frames = ListInitialize();
/* see if there is frame acurrently on queue */
if ( in_frames->first )
res = in_frames->first->content;
while( !res )
{
int opcode = WebSocket_OP_BINARY;
do
{
/* obtain all frames in the sequence */
int final = 0;
while ( !final )
{
char *b;
size_t len = 0u;
int tmp_opcode;
int has_mask;
size_t cur_len = 0u;
uint8_t mask[4] = { 0u, 0u, 0u, 0u };
size_t payload_len;
b = WebSocket_getRawSocketData(net, 2u, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
/* 1st byte */
final = (b[0] & 0xFF) >> 7;
tmp_opcode = (b[0] & 0x0F);
if ( tmp_opcode ) /* not a continuation frame */
opcode = tmp_opcode;
/* invalid websocket packet must return error */
if ( opcode < WebSocket_OP_CONTINUE ||
opcode > WebSocket_OP_PONG ||
( opcode > WebSocket_OP_BINARY &&
opcode < WebSocket_OP_CLOSE ) )
return SOCKET_ERROR;
/* 2nd byte */
has_mask = (b[1] & 0xFF) >> 7;
payload_len = (b[1] & 0x7F);
/* determine payload length */
if ( payload_len == 126 )
{
b = WebSocket_getRawSocketData( net,
2u, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
/* convert from big endian 16 to host */
payload_len = be16toh(*(uint16_t*)b);
}
else if ( payload_len == 127 )
{
b = WebSocket_getRawSocketData( net,
8u, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
/* convert from big-endian 64 to host */
payload_len = (size_t)be64toh(*(uint64_t*)b);
}
if ( has_mask )
{
uint8_t mask[4];
b = WebSocket_getRawSocketData(net, 4u, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
memcpy( &mask[0], b, sizeof(uint32_t));
}
b = WebSocket_getRawSocketData(net,
payload_len, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
/* unmask data */
if ( has_mask )
{
size_t i;
for ( i = 0u; i < payload_len; ++i )
b[i] ^= mask[i % 4];
}
if ( res )
cur_len = res->len;
res = realloc( res, sizeof(struct ws_frame) + cur_len + len );
memcpy( (unsigned char *)res + sizeof(struct ws_frame) + cur_len, b, len );
res->pos = 0u;
res->len = cur_len + len;
WebSocket_getRawSocketData(net, 0u, &len);
}
if ( opcode == WebSocket_OP_PONG || opcode == WebSocket_OP_PONG )
{
/* respond to a "ping" with a "pong" */
if ( opcode == WebSocket_OP_PING )
WebSocket_pong( net,
(char *)res + sizeof(struct ws_frame),
res->len );
/* discard message */
free( res );
res = NULL;
}
else if ( opcode == WebSocket_OP_CLOSE )
{
/* server end closed websocket connection */
free( res );
WebSocket_close( net, WebSocket_CLOSE_GOING_AWAY, NULL );
return SOCKET_ERROR; /* closes socket */
}
} while ( opcode == WebSocket_OP_PING || opcode == WebSocket_OP_PONG );
}
/* add new frame to end of list */
ListAppend( in_frames, res, sizeof(struct ws_frame) + res->len);
*actual_len = res->len - res->pos;
return TCPSOCKET_COMPLETE;
}
/**
* case-insensitive string search
*
* similar to @p strcase, but takes a maximum length
*
* @param[in] buf buffer to search
* @param[in] str string to find
* @param[in] len length of the buffer
*
* @retval !NULL location of string found
* @retval NULL string not found
*/
const char *WebSocket_strcasefind(const char *buf, const char *str, size_t len)
{
const char *res = NULL;
if ( buf && len > 0u && str )
{
const size_t str_len = strlen( str );
while ( len >= str_len && !res )
{
if ( strncasecmp( buf, str, str_len ) == 0 )
res = buf;
++buf;
--len;
}
}
return res;
}
/**
* releases resources used by the websocket sub-system
*/
void WebSocket_terminate( void )
{
/* clean up and un-processed websocket frames */
if ( in_frames )
{
struct ws_frame *f = ListDetachHead( in_frames );
while ( f )
{
free( f );
f = ListDetachHead( in_frames );
}
ListFree( in_frames );
in_frames = NULL;
}
if ( last_frame )
{
free( last_frame );
last_frame = NULL;
}
Socket_outTerminate();
#if defined(OPENSSL)
SSLSocket_terminate();
#endif
}
/**
* handles the websocket upgrade response
*
* @param[in,out] net network connection to upgrade
*
* @retval SOCKET_ERROR failed to upgrade network connection
* @retval 1 socket upgraded to use websockets
*
* @see WebSocket_connect
*/
int WebSocket_upgrade( networkHandles *net )
{
static const char *const ws_guid =
"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
int rc = SOCKET_ERROR;
if ( net->websocket_key )
{
SHA_CTX ctx;
char ws_key[62u] = { 0 };
unsigned char sha_hash[SHA1_DIGEST_LENGTH];
size_t rcv = 0u;
char *read_buf;
/* calculate the expected websocket key, expected from server */
snprintf( ws_key, sizeof(ws_key), "%s%s", net->websocket_key, ws_guid );
SHA1_Init( &ctx );
SHA1_Update( &ctx, ws_key, strlen(ws_key));
SHA1_Final( sha_hash, &ctx );
Base64_encode( ws_key, sizeof(ws_key), sha_hash, SHA1_DIGEST_LENGTH );
rc = TCPSOCKET_INTERRUPTED;
read_buf = WebSocket_getRawSocketData( net, 12u, &rcv );
if ( rcv > 0 && strncmp( read_buf, "HTTP/1.1 101", 11u ) == 0 )
{
const char *p;
read_buf = WebSocket_getRawSocketData( net, 500u, &rcv );
/* check for upgrade */
p = WebSocket_strcasefind(
read_buf, "Connection", rcv );
if ( p )
{
const char *eol;
eol = memchr( p, '\n', rcv-(read_buf-p) );
if ( eol )
p = WebSocket_strcasefind(
p, "Upgrade", eol - p);
else
p = NULL;
}
/* check key hash */
if ( p )
p = WebSocket_strcasefind( read_buf,
"sec-websocket-accept", rcv );
if ( p )
{
const char *eol;
eol = memchr( p, '\n', rcv-(read_buf-p) );
if ( eol )
{
p = memchr( p, ':', eol-p );
if ( p )
{
size_t hash_len = eol-p-1;
while ( *p == ':' || *p == ' ' )
{
++p;
--hash_len;
}
if ( strncmp( p, ws_key, hash_len ) != 0 )
p = NULL;
}
}
else
p = NULL;
}
if ( p )
{
net->websocket = 1;
Log(TRACE_PROTOCOL, 1, "WebSocket connection upgraded" );
rc = 1;
}
else
{
Log(TRACE_PROTOCOL, 1, "WebSocket failed to upgrade connection" );
rc = SOCKET_ERROR;
}
if ( net->websocket_key )
{
free(net->websocket_key);
net->websocket_key = NULL;
}
/* indicate that we done with the packet */
WebSocket_getRawSocketData( net, 0u, &rcv );
}
}
return rc;
}
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#if !defined(WEBSOCKET_H)
#define WEBSOCKET_H
#include "Clients.h"
/**
* WebSocket op codes
* @{
*/
#define WebSocket_OP_CONTINUE 0x0 /* 0000 - continue frame */
#define WebSocket_OP_TEXT 0x1 /* 0001 - text frame */
#define WebSocket_OP_BINARY 0x2 /* 0010 - binary frame */
#define WebSocket_OP_CLOSE 0x8 /* 1000 - close frame */
#define WebSocket_OP_PING 0x9 /* 1001 - ping frame */
#define WebSocket_OP_PONG 0xA /* 1010 - pong frame */
/** @} */
/**
* Various close status codes
* @{
*/
#define WebSocket_CLOSE_NORMAL 1000
#define WebSocket_CLOSE_GOING_AWAY 1001
#define WebSocket_CLOSE_PROTOCOL_ERROR 1002
#define WebSocket_CLOSE_UNKNOWN_DATA 1003
#define WebSocket_CLOSE_RESERVED 1004
#define WebSocket_CLOSE_NO_STATUS_CODE 1005 /* reserved: not to be used */
#define WebSocket_CLOSE_ABNORMAL 1006 /* reserved: not to be used */
#define WebSocket_CLOSE_BAD_DATA 1007
#define WebSocket_CLOSE_POLICY 1008
#define WebSocket_CLOSE_MSG_TOO_BIG 1009
#define WebSocket_CLOSE_NO_EXTENSION 1010
#define WebScoket_CLOSE_UNEXPECTED 1011
#define WebSocket_CLOSE_TLS_FAIL 1015 /* reserved: not be used */
/** @} */
/* closes a websocket connection */
void WebSocket_close(networkHandles *net, int status_code, const char *reason);
/* sends upgrade request */
int WebSocket_connect(networkHandles *net, const char *uri);
/* calculates the extra data required in a packet to hold a WebSocket frame header */
size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, int mask_data,
size_t data_len);
/* obtain data from network socket */
int WebSocket_getch(networkHandles *net, char* c);
char *WebSocket_getdata(networkHandles *net, size_t bytes, size_t* actual_len);
/* send data out, in websocket format only if required */
int WebSocket_putdatas(networkHandles* net, char* buf0, size_t buf0len,
int count, char** buffers, size_t* buflens, int* freeData);
/* releases any resources used by the websocket system */
void WebSocket_terminate(void);
/* handles websocket upgrade request */
int WebSocket_upgrade(networkHandles *net);
#endif /* WEBSOCKET_H */
......@@ -106,7 +106,7 @@ void onConnect(void* context, MQTTAsync_successData* response)
opts.context = client;
pubmsg.payload = PAYLOAD;
pubmsg.payloadlen = strlen(PAYLOAD);
pubmsg.payloadlen = (int)strlen(PAYLOAD);
pubmsg.qos = QOS;
pubmsg.retained = 0;
deliveredtoken = 0;
......
......@@ -45,7 +45,7 @@ int main(int argc, char* argv[])
exit(EXIT_FAILURE);
}
pubmsg.payload = PAYLOAD;
pubmsg.payloadlen = strlen(PAYLOAD);
pubmsg.payloadlen = (int)strlen(PAYLOAD);
pubmsg.qos = QOS;
pubmsg.retained = 0;
MQTTClient_publishMessage(client, TOPIC, &pubmsg, &token);
......
......@@ -547,7 +547,7 @@ int test2(struct Options options)
opts.MQTTVersion = options.MQTTVersion;
opts.username = "testuser";
opts.binarypwd.data = "testpassword";
opts.binarypwd.len = strlen(opts.binarypwd.data);
opts.binarypwd.len = (int)strlen(opts.binarypwd.data);
if (options.haconnections != NULL)
{
opts.serverURIs = options.haconnections;
......@@ -1068,7 +1068,7 @@ int test6a(struct Options options)
opts.MQTTVersion = MQTTVERSION_3_1_1;
opts.will = &wopts;
opts.will->payload.data = test6_will_message;
opts.will->payload.len = strlen(test6_will_message) + 1;
opts.will->payload.len = (int)strlen(test6_will_message) + 1;
opts.will->qos = 1;
opts.will->retained = 0;
opts.will->topicName = test6_will_topic;
......
......@@ -471,7 +471,7 @@ int test2(struct Options options)
opts.cleansession = 1;
opts.username = "testuser";
opts.binarypwd.data = "testpassword";
opts.binarypwd.len = strlen(opts.binarypwd.data);
opts.binarypwd.len = (int)strlen(opts.binarypwd.data);
opts.MQTTVersion = options.MQTTVersion;
opts.will = &wopts;
......
......@@ -1726,7 +1726,7 @@ int test6(struct Options options)
/* let client c go: connect, and send disconnect command to proxy */
opts.will = &wopts;
opts.will->payload.data = "will message";
opts.will->payload.len = strlen(opts.will->payload.data) + 1;
opts.will->payload.len = (int)strlen(opts.will->payload.data) + 1;
opts.will->qos = 1;
opts.will->retained = 0;
opts.will->topicName = willTopic;
......@@ -1870,7 +1870,6 @@ void test7cOnConnectSuccess(void* context, MQTTAsync_successData* response)
{
MQTTAsync c = (MQTTAsync)context;
MQTTAsync_message pubmsg = MQTTAsync_message_initializer;
int rc;
MyLog(LOGA_DEBUG, "In connect onSuccess callback for client c, context %p\n", context);
......@@ -1904,7 +1903,6 @@ void test7dOnConnectSuccess(void* context, MQTTAsync_successData* response)
{
MQTTAsync c = (MQTTAsync)context;
MQTTAsync_responseOptions opts = MQTTAsync_responseOptions_initializer;
int rc;
int qoss[2] = {2, 2};
char* topics[2] = {willTopic, test_topic};
......@@ -1933,7 +1931,6 @@ int test7(struct Options options)
char clientidc[50];
char clientidd[50];
int i = 0;
MQTTAsync_token *tokens;
test7_will_message_received = 0;
test7_messages_received = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment