Commit 086111f1 authored by Keith Holman's avatar Keith Holman

WebSocket: initial websocket support + utilities

closes: #166

This patch provides an initial implementation for websocket support for
paho. For the websocket specification see RFC 6455.  The purpose of this
patch is to allow connnecting to an MQTT broker listening on a websocket
port (typically 80 [HTTP] or 443 [HTTPS]) to be able to communicate with
a client using the paho library.  Using websockets to communicate increases
the packet overhead both sending and receiving as well as additional setup
and ping packets.  However, using websockets allows for communications on
standard HTTP/HTTPS ports which are generally already configured by
firewalls to allow outside communications.

To use websockets, prefix the connection URI with either: "ws://" or
"wss://" for either websockets or secure websockets, repectfully.
Signed-off-by: 's avatarKeith Holman <keith.holman@windriver.com>
parent af104622
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#include "Base64.h"
#if defined(WIN32) || defined(WIN64)
#pragma comment(lib, "crypt32.lib")
#include <Windows.h>
#include <WinCrypt.h>
b64_size_t Base64_decode( b64_data_t *out, b64_size_t out_len, const char *in, b64_size_t in_len )
{
b64_size_t ret = 0u;
DWORD dw_out_len = (DWORD)out_len;
if ( CryptStringToBinaryA( in, in_len, CRYPT_STRING_BASE64, out, &dw_out_len, NULL, NULL ) )
ret = (b64_size_t)dw_out_len;
return ret;
}
b64_size_t Base64_encode( char *out, b64_size_t out_len, const b64_data_t *in, b64_size_t in_len )
{
b64_size_t ret = 0u;
DWORD dw_out_len = (DWORD)out_len;
if ( CryptBinaryToStringA( in, in_len, CRYPT_STRING_BASE64 | CRYPT_STRING_NOCRLF, out, &dw_out_len ) )
ret = (b64_size_t)dw_out_len;
return ret;
}
#else /* if defined(WIN32) || defined(WIN64) */
#if defined(OPENSSL)
#include <openssl/bio.h>
#include <openssl/evp.h>
static b64_size_t Base64_encodeDecode(
char *out, b64_size_t out_len, const char *in, b64_size_t in_len, int encode )
{
b64_size_t ret = 0u;
if ( in_len > 0u )
{
int rv;
BIO *bio, *b64, *b_in, *b_out;
b64 = BIO_new(BIO_f_base64());
bio = BIO_new(BIO_s_mem());
b64 = BIO_push(b64, bio);
BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL); /* ignore new-lines */
if ( encode )
{
b_in = bio;
b_out = b64;
}
else
{
b_in = b64;
b_out = bio;
}
rv = BIO_write(b_out, in, (int)in_len);
BIO_flush(b_out); /* indicate end of encoding */
if ( rv > 0 )
{
rv = BIO_read(b_in, out, (int)out_len);
if ( rv > 0 )
{
ret = (b64_size_t)rv;
if ( out_len > ret )
out[ret] = '\0';
}
}
BIO_free_all(b64); /* free all used memory */
}
return ret;
}
b64_size_t Base64_decode( b64_data_t *out, b64_size_t out_len, const char *in, b64_size_t in_len )
{
return Base64_encodeDecode( (char*)out, out_len, in, in_len, 0 );
}
b64_size_t Base64_encode( char *out, b64_size_t out_len, const b64_data_t *in, b64_size_t in_len )
{
return Base64_encodeDecode( out, out_len, (const char*)in, in_len, 1 );
}
#else /* if defined(OPENSSL) */
b64_size_t Base64_decode( b64_data_t *out, b64_size_t out_len, const char *in, b64_size_t in_len )
{
#define NV 64
static const unsigned char BASE64_DECODE_TABLE[] =
{
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 0-15 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 16-31 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, 62, NV, NV, NV, 63, /* 32-47 */
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, NV, NV, NV, NV, NV, NV, /* 48-63 */
NV, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, /* 64-79 */
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, NV, NV, NV, NV, NV, /* 80-95 */
NV, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, /* 96-111 */
41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, NV, NV, NV, NV, NV, /* 112-127 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 128-143 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 144-159 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 160-175 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 176-191 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 192-207 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 208-223 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, /* 224-239 */
NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV, NV /* 240-255 */
};
b64_size_t ret = 0u;
b64_size_t out_count = 0u;
/* in valid base64, length must be multiple of 4's: 0, 4, 8, 12, etc */
while ( in_len > 3u && out_count < out_len )
{
int i;
unsigned char c[4];
for ( i = 0; i < 4; ++i, ++in )
c[i] = BASE64_DECODE_TABLE[(int)(*in)];
in_len -= 4u;
/* first byte */
*out = c[0] << 2;
*out |= (c[1] & ~0xF) >> 4;
++out;
++out_count;
if ( out_count < out_len )
{
/* second byte */
*out = (c[1] & 0xF) << 4;
if ( c[2] < NV )
{
*out |= (c[2] & ~0x3) >> 2;
++out;
++out_count;
if ( out_count < out_len )
{
/* third byte */
*out = (c[2] & 0x3) << 6;
if ( c[3] < NV )
{
*out |= c[3];
++out;
++out_count;
}
else
in_len = 0u;
}
}
else
in_len = 0u;
}
}
if ( out_count <= out_len )
{
ret = out_count;
if ( out_count < out_len )
*out = '\0';
}
return ret;
}
b64_size_t Base64_encode( char *out, b64_size_t out_len, const b64_data_t *in, b64_size_t in_len )
{
static const char BASE64_ENCODE_TABLE[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/=";
b64_size_t ret = 0u;
b64_size_t out_count = 0u;
while ( in_len > 0u && out_count < out_len )
{
int i;
unsigned char c[] = { 0, 0, 64, 64 }; /* index of '=' char */
/* first character */
i = *in;
c[0] = (i & ~0x3) >> 2;
/* second character */
c[1] = (i & 0x3) << 4;
--in_len;
if ( in_len > 0u )
{
++in;
i = *in;
c[1] |= (i & ~0xF) >> 4;
/* third character */
c[2] = (i & 0xF) << 2;
--in_len;
if ( in_len > 0u )
{
++in;
i = *in;
c[2] |= (i & ~0x3F) >> 6;
/* fourth character */
c[3] = (i & 0x3F);
--in_len;
++in;
}
}
/* encode the characters */
out_count += 4u;
for ( i = 0; i < 4 && out_count <= out_len; ++i, ++out )
*out = BASE64_ENCODE_TABLE[c[i]];
}
if ( out_count <= out_len )
{
if ( out_count < out_len )
*out = '\0';
ret = out_count;
}
return ret;
}
#endif /* else if defined(OPENSSL) */
#endif /* if else defined(WIN32) || defined(WIN64) */
b64_size_t Base64_decodeLength( const char *in, b64_size_t in_len )
{
b64_size_t pad = 0u;
if ( in && in_len > 1u )
pad += ( in[in_len - 2u] == '=' ? 1u : 0u );
if ( in && in_len > 0u )
pad += ( in[in_len - 1u] == '=' ? 1u : 0u );
return (in_len / 4u * 3u) - pad;
}
b64_size_t Base64_encodeLength( const b64_data_t *in, b64_size_t in_len )
{
return ((4u * in_len / 3u) + 3u) & ~0x3;
}
#if defined(BASE64_TEST)
#include <stdio.h>
#include <string.h>
#define TEST_EXPECT(i,x) if (!(x)) {fprintf( stderr, "failed test: %s (for i == %d)\n", #x, i ); ++fails;}
int main(int argc, char *argv[])
{
struct _td
{
const char *in;
const char *out;
};
int i;
unsigned int fails = 0u;
struct _td test_data[] = {
{ "", "" },
{ "p", "cA==" },
{ "pa", "cGE=" },
{ "pah", "cGFo" },
{ "paho", "cGFobw==" },
{ "paho ", "cGFobyA=" },
{ "paho w", "cGFobyB3" },
{ "paho wi", "cGFobyB3aQ==" },
{ "paho wit", "cGFobyB3aXQ=" },
{ "paho with", "cGFobyB3aXRo" },
{ "paho with ", "cGFobyB3aXRoIA==" },
{ "paho with w", "cGFobyB3aXRoIHc=" },
{ "paho with we", "cGFobyB3aXRoIHdl" },
{ "paho with web", "cGFobyB3aXRoIHdlYg==" },
{ "paho with webs", "cGFobyB3aXRoIHdlYnM=" },
{ "paho with webso", "cGFobyB3aXRoIHdlYnNv" },
{ "paho with websoc", "cGFobyB3aXRoIHdlYnNvYw==" },
{ "paho with websock", "cGFobyB3aXRoIHdlYnNvY2s=" },
{ "paho with websocke", "cGFobyB3aXRoIHdlYnNvY2tl" },
{ "paho with websocket", "cGFobyB3aXRoIHdlYnNvY2tldA==" },
{ "paho with websockets", "cGFobyB3aXRoIHdlYnNvY2tldHM=" },
{ "paho with websockets.", "cGFobyB3aXRoIHdlYnNvY2tldHMu" },
{ "The quick brown fox jumps over the lazy dog",
"VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wcyBvdmVyIHRoZSBsYXp5IGRvZw==" },
{ "Man is distinguished, not only by his reason, but by this singular passion from\n"
"other animals, which is a lust of the mind, that by a perseverance of delight\n"
"in the continued and indefatigable generation of knowledge, exceeds the short\n"
"vehemence of any carnal pleasure.",
"TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0aGlz"
"IHNpbmd1bGFyIHBhc3Npb24gZnJvbQpvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1c3Qgb2Yg"
"dGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodAppbiB0aGUgY29udGlu"
"dWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdlLCBleGNlZWRzIHRo"
"ZSBzaG9ydAp2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=" },
{ NULL, NULL }
};
/* decode tests */
i = 0;
while ( test_data[i].in != NULL )
{
int r;
char out[512u];
r = Base64_decode( out, sizeof(out), test_data[i].out, strlen(test_data[i].out) );
TEST_EXPECT( i, r == strlen(test_data[i].in) && strncmp(out, test_data[i].in, strlen(test_data[i].in)) == 0 );
++i;
}
/* decode length tests */
i = 0;
while ( test_data[i].in != NULL )
{
TEST_EXPECT( i, Base64_decodeLength(test_data[i].out, strlen(test_data[i].out)) == strlen(test_data[i].in));
++i;
}
/* encode tests */
i = 0;
while ( test_data[i].in != NULL )
{
int r;
char out[512u];
r = Base64_encode( out, sizeof(out), test_data[i].in, strlen(test_data[i].in) );
TEST_EXPECT( i, r == strlen(test_data[i].out) && strncmp(out, test_data[i].out, strlen(test_data[i].out)) == 0 );
++i;
}
/* encode length tests */
i = 0;
while ( test_data[i].in != NULL )
{
TEST_EXPECT( i, Base64_encodeLength(test_data[i].in, strlen(test_data[i].in)) == strlen(test_data[i].out) );
++i;
}
if ( fails )
printf( "%u test failed!\n", fails );
else
printf( "all tests passed\n" );
return fails;
}
#endif /* if defined(BASE64_TEST) */
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#if !defined(BASE64_H)
#define BASE64_H
/** type for size of a buffer, it saves passing around @p size_t (unsigned long long or unsigned long int) */
typedef unsigned int b64_size_t;
/** type for raw base64 data */
typedef unsigned char b64_data_t;
/**
* Decodes base64 data
*
* @param[out] out decoded data
* @param[in] out_len length of output buffer
* @param[in] in base64 string to decode
* @param[in] in_len length of input buffer
*
* @return the amount of data decoded
*
* @see Base64_decodeLength
* @see Base64_encode
*/
b64_size_t Base64_decode( b64_data_t *out, b64_size_t out_len,
const char *in, b64_size_t in_len );
/**
* Size of buffer required to decode base64 data
*
* @param[in] in base64 string to decode
* @param[in] in_len length of input buffer
*
* @return the size of buffer the decoded string would require
*
* @see Base64_decode
* @see Base64_encodeLength
*/
b64_size_t Base64_decodeLength( const char *in, b64_size_t in_len );
/**
* Encodes base64 data
*
* @param[out] out encode base64 string
* @param[in] out_len length of output buffer
* @param[in] in raw data to encode
* @param[in] in_len length of input buffer
*
* @return the amount of data encoded
*
* @see Base64_decode
* @see Base64_encodeLength
*/
b64_size_t Base64_encode( char *out, b64_size_t out_len,
const b64_data_t *in, b64_size_t in_len );
/**
* Size of buffer required to encode base64 data
*
* @param[in] in raw data to encode
* @param[in] in_len length of input buffer
*
* @return the size of buffer the encoded string would require
*
* @see Base64_decodeLength
* @see Base64_encode
*/
b64_size_t Base64_encodeLength( const b64_data_t *in, b64_size_t in_len );
#endif /* BASE64_H */
...@@ -46,6 +46,12 @@ SET(common_src ...@@ -46,6 +46,12 @@ SET(common_src
SocketBuffer.c SocketBuffer.c
Heap.c Heap.c
LinkedList.c LinkedList.c
Base64.c
Base64.h
SHA1.c
SHA1.h
WebSocket.c
WebSocket.h
) )
IF (WIN32) IF (WIN32)
......
...@@ -158,6 +158,8 @@ typedef struct ...@@ -158,6 +158,8 @@ typedef struct
SSL* ssl; SSL* ssl;
SSL_CTX* ctx; SSL_CTX* ctx;
#endif #endif
int websocket; /**< socket has been upgraded to use web sockets */
char *websocket_key;
} networkHandles; } networkHandles;
...@@ -168,8 +170,10 @@ typedef struct ...@@ -168,8 +170,10 @@ typedef struct
#define TCP_IN_PROGRESS 0x1 #define TCP_IN_PROGRESS 0x1
/** SSL connection in progress */ /** SSL connection in progress */
#define SSL_IN_PROGRESS 0x2 #define SSL_IN_PROGRESS 0x2
/** Websocket connection in progress */
#define WEBSOCKET_IN_PROGRESS 0x3
/** TCP completed, waiting for MQTT ACK */ /** TCP completed, waiting for MQTT ACK */
#define WAIT_FOR_CONNACK 0x3 #define WAIT_FOR_CONNACK 0x4
/** Disconnecting */ /** Disconnecting */
#define DISCONNECTING -2 #define DISCONNECTING -2
......
...@@ -62,8 +62,11 @@ ...@@ -62,8 +62,11 @@
#include "StackTrace.h" #include "StackTrace.h"
#include "Heap.h" #include "Heap.h"
#include "OsWrapper.h" #include "OsWrapper.h"
#include "WebSocket.h"
#define URI_TCP "tcp://" #define URI_TCP "tcp://"
#define URI_WS "ws://"
#define URI_WSS "wss://"
#include "VersionInfo.h" #include "VersionInfo.h"
...@@ -294,6 +297,7 @@ typedef struct MQTTAsync_struct ...@@ -294,6 +297,7 @@ typedef struct MQTTAsync_struct
{ {
char* serverURI; char* serverURI;
int ssl; int ssl;
int websocket;
Clients* c; Clients* c;
/* "Global", to the client, callback definitions */ /* "Global", to the client, callback definitions */
...@@ -507,12 +511,23 @@ int MQTTAsync_createWithOptions(MQTTAsync* handle, const char* serverURI, const ...@@ -507,12 +511,23 @@ int MQTTAsync_createWithOptions(MQTTAsync* handle, const char* serverURI, const
memset(m, '\0', sizeof(MQTTAsyncs)); memset(m, '\0', sizeof(MQTTAsyncs));
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0) if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0)
serverURI += strlen(URI_TCP); serverURI += strlen(URI_TCP);
else if (strncmp(URI_WS, serverURI, strlen(URI_WS)) == 0)
{
serverURI += strlen(URI_WS);
m->websocket = 1;
}
#if defined(OPENSSL) #if defined(OPENSSL)
else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0) else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0)
{ {
serverURI += strlen(URI_SSL); serverURI += strlen(URI_SSL);
m->ssl = 1; m->ssl = 1;
} }
else if (strncmp(URI_WSS, serverURI, strlen(URI_WSS)) == 0)
{
serverURI += strlen(URI_WSS);
m->ssl = 1;
m->websocket = 1;
}
#endif #endif
m->serverURI = MQTTStrdup(serverURI); m->serverURI = MQTTStrdup(serverURI);
m->responses = ListInitialize(); m->responses = ListInitialize();
...@@ -575,10 +590,7 @@ static void MQTTAsync_terminate(void) ...@@ -575,10 +590,7 @@ static void MQTTAsync_terminate(void)
MQTTAsync_freeCommand1((MQTTAsync_queuedCommand*)(elem->content)); MQTTAsync_freeCommand1((MQTTAsync_queuedCommand*)(elem->content));
ListFree(commands); ListFree(commands);
handles = NULL; handles = NULL;
Socket_outTerminate(); WebSocket_terminate();
#if defined(OPENSSL)
SSLSocket_terminate();
#endif
#if defined(HEAP_H) #if defined(HEAP_H)
Heap_terminate(); Heap_terminate();
#endif #endif
...@@ -1256,9 +1268,9 @@ static int MQTTAsync_processCommand(void) ...@@ -1256,9 +1268,9 @@ static int MQTTAsync_processCommand(void)
Log(TRACE_MIN, -1, "Connecting to serverURI %s with MQTT version %d", serverURI, command->command.details.conn.MQTTVersion); Log(TRACE_MIN, -1, "Connecting to serverURI %s with MQTT version %d", serverURI, command->command.details.conn.MQTTVersion);
#if defined(OPENSSL) #if defined(OPENSSL)
rc = MQTTProtocol_connect(serverURI, command->client->c, command->client->ssl, command->command.details.conn.MQTTVersion); rc = MQTTProtocol_connect(serverURI, command->client->c, command->client->ssl, command->client->websocket, command->command.details.conn.MQTTVersion);
#else #else
rc = MQTTProtocol_connect(serverURI, command->client->c, command->command.details.conn.MQTTVersion); rc = MQTTProtocol_connect(serverURI, command->client->c, command->client->websocket, command->command.details.conn.MQTTVersion);
#endif #endif
if (command->client->c->connect_state == NOT_IN_PROGRESS) if (command->client->c->connect_state == NOT_IN_PROGRESS)
rc = SOCKET_ERROR; rc = SOCKET_ERROR;
...@@ -2118,6 +2130,7 @@ static void MQTTAsync_closeOnly(Clients* client) ...@@ -2118,6 +2130,7 @@ static void MQTTAsync_closeOnly(Clients* client)
if (client->connected && Socket_noPendingWrites(client->net.socket)) if (client->connected && Socket_noPendingWrites(client->net.socket))
MQTTPacket_send_disconnect(&client->net, client->clientID); MQTTPacket_send_disconnect(&client->net, client->clientID);
Thread_lock_mutex(socket_mutex); Thread_lock_mutex(socket_mutex);
WebSocket_close(&client->net, WebSocket_CLOSE_NORMAL, NULL);
#if defined(OPENSSL) #if defined(OPENSSL)
SSLSocket_close(&client->net); SSLSocket_close(&client->net);
#endif #endif
...@@ -2924,13 +2937,12 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) ...@@ -2924,13 +2937,12 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
if (m->ssl) if (m->ssl)
{ {
int port; int port;
char* hostname; size_t hostname_len;
int setSocketForSSLrc = 0; int setSocketForSSLrc = 0;
hostname = MQTTProtocol_addressPort(m->serverURI, &port); hostname_len = MQTTProtocol_addressPort(m->serverURI, &port, NULL);
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts, hostname); setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts,
if (hostname != m->serverURI) m->serverURI, hostname_len);
free(hostname);
if (setSocketForSSLrc != MQTTASYNC_SUCCESS) if (setSocketForSSLrc != MQTTASYNC_SUCCESS)
{ {
...@@ -2951,12 +2963,21 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) ...@@ -2951,12 +2963,21 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
} }
else if (rc == 1) else if (rc == 1)
{ {
rc = MQTTCLIENT_SUCCESS; if ( m->websocket )
m->c->connect_state = WAIT_FOR_CONNACK;
if (MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion) == SOCKET_ERROR)
{ {
rc = SOCKET_ERROR; m->c->connect_state = WEBSOCKET_IN_PROGRESS;
goto exit; if ((rc = WebSocket_connect(&m->c->net, m->serverURI)) == SOCKET_ERROR )
goto exit;
}
else
{
rc = MQTTCLIENT_SUCCESS;
m->c->connect_state = WAIT_FOR_CONNACK;
if (MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion) == SOCKET_ERROR)
{
rc = SOCKET_ERROR;
goto exit;
}
} }
if (!m->c->cleansession && m->c->session == NULL) if (!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl); m->c->session = SSL_get1_session(m->c->net.ssl);
...@@ -2971,9 +2992,18 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) ...@@ -2971,9 +2992,18 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
else else
{ {
#endif #endif
m->c->connect_state = WAIT_FOR_CONNACK; /* TCP/SSL connect completed, in which case send the MQTT connect packet */ if ( m->websocket )
if ((rc = MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion)) == SOCKET_ERROR) {
goto exit; m->c->connect_state = WEBSOCKET_IN_PROGRESS;
if ((rc = WebSocket_connect(&m->c->net, m->serverURI)) == SOCKET_ERROR )
goto exit;
}
else
{
m->c->connect_state = WAIT_FOR_CONNACK; /* TCP/SSL connect completed, in which case send the MQTT connect packet */
if ((rc = MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion)) == SOCKET_ERROR)
goto exit;
}
#if defined(OPENSSL) #if defined(OPENSSL)
} }
#endif #endif
...@@ -2987,14 +3017,34 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) ...@@ -2987,14 +3017,34 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
if(!m->c->cleansession && m->c->session == NULL) if(!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl); m->c->session = SSL_get1_session(m->c->net.ssl);
m->c->connect_state = WAIT_FOR_CONNACK; /* SSL connect completed, in which case send the MQTT connect packet */ if ( m->websocket )
if ((rc = MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion)) == SOCKET_ERROR) {
goto exit; m->c->connect_state = WEBSOCKET_IN_PROGRESS;
if ((rc = WebSocket_connect(&m->c->net, m->serverURI)) == SOCKET_ERROR )
goto exit;
}
else
{
m->c->connect_state = WAIT_FOR_CONNACK; /* SSL connect completed, in which case send the MQTT connect packet */
if ((rc = MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion)) == SOCKET_ERROR)
goto exit;
}
} }
#endif #endif
else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS) /* Websocket connect sent - wait for completion */
{
if ((rc = WebSocket_upgrade( &m->c->net ) ) == SOCKET_ERROR )
goto exit;
else
{
m->c->connect_state = WAIT_FOR_CONNACK; /* Websocket upgrade completed, in which case send the MQTT connect packet */
if ((rc = MQTTPacket_send_connect(m->c, m->connect.details.conn.MQTTVersion)) == SOCKET_ERROR)
goto exit;
}
}
exit: exit:
if ((rc != 0 && rc != TCPSOCKET_INTERRUPTED && m->c->connect_state != SSL_IN_PROGRESS) || (rc == SSL_FATAL)) if ((rc != 0 && rc != TCPSOCKET_INTERRUPTED && (m->c->connect_state != SSL_IN_PROGRESS && m->c->connect_state != WEBSOCKET_IN_PROGRESS)) || (rc == SSL_FATAL))
nextOrClose(m, MQTTASYNC_FAILURE, "TCP/TLS connect failure"); nextOrClose(m, MQTTASYNC_FAILURE, "TCP/TLS connect failure");
FUNC_EXIT_RC(rc); FUNC_EXIT_RC(rc);
...@@ -3035,7 +3085,7 @@ static MQTTPacket* MQTTAsync_cycle(int* sock, unsigned long timeout, int* rc) ...@@ -3035,7 +3085,7 @@ static MQTTPacket* MQTTAsync_cycle(int* sock, unsigned long timeout, int* rc)
if (m != NULL) if (m != NULL)
{ {
Log(TRACE_MINIMUM, -1, "m->c->connect_state = %d",m->c->connect_state); Log(TRACE_MINIMUM, -1, "m->c->connect_state = %d",m->c->connect_state);
if (m->c->connect_state == TCP_IN_PROGRESS || m->c->connect_state == SSL_IN_PROGRESS) if (m->c->connect_state == TCP_IN_PROGRESS || m->c->connect_state == SSL_IN_PROGRESS || m->c->connect_state == WEBSOCKET_IN_PROGRESS)
*rc = MQTTAsync_connecting(m); *rc = MQTTAsync_connecting(m);
else else
pack = MQTTPacket_Factory(&m->c->net, rc); pack = MQTTPacket_Factory(&m->c->net, rc);
......
...@@ -71,9 +71,11 @@ ...@@ -71,9 +71,11 @@
#include "OsWrapper.h" #include "OsWrapper.h"
#define URI_TCP "tcp://" #define URI_TCP "tcp://"
#define URI_WS "ws://"
#define URI_WSS "wss://"
#include "VersionInfo.h" #include "VersionInfo.h"
#include "WebSocket.h"
const char *client_timestamp_eye = "MQTTClientV3_Timestamp " BUILD_TIMESTAMP; const char *client_timestamp_eye = "MQTTClientV3_Timestamp " BUILD_TIMESTAMP;
const char *client_version_eye = "MQTTClientV3_Version " CLIENT_VERSION; const char *client_version_eye = "MQTTClientV3_Version " CLIENT_VERSION;
...@@ -195,6 +197,7 @@ typedef struct ...@@ -195,6 +197,7 @@ typedef struct
#if defined(OPENSSL) #if defined(OPENSSL)
int ssl; int ssl;
#endif #endif
int websocket;
Clients* c; Clients* c;
MQTTClient_connectionLost* cl; MQTTClient_connectionLost* cl;
MQTTClient_messageArrived* ma; MQTTClient_messageArrived* ma;
...@@ -350,19 +353,36 @@ int MQTTClient_create(MQTTClient* handle, const char* serverURI, const char* cli ...@@ -350,19 +353,36 @@ int MQTTClient_create(MQTTClient* handle, const char* serverURI, const char* cli
#endif #endif
initialized = 1; initialized = 1;
} }
m = malloc(sizeof(MQTTClients)); m = malloc(sizeof(MQTTClients));
*handle = m; *handle = m;
memset(m, '\0', sizeof(MQTTClients)); memset(m, '\0', sizeof(MQTTClients));
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0) if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0)
serverURI += strlen(URI_TCP); serverURI += strlen(URI_TCP);
else if (strncmp(URI_WS, serverURI, strlen(URI_WS)) == 0)
{
serverURI += strlen(URI_WS);
m->websocket = 1;
}
else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0) else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0)
{ {
#if defined(OPENSSL) #if defined(OPENSSL)
serverURI += strlen(URI_SSL); serverURI += strlen(URI_SSL);
m->ssl = 1; m->ssl = 1;
#else #else
rc = MQTTCLIENT_SSL_NOT_SUPPORTED; rc = MQTTCLIENT_SSL_NOT_SUPPORTED;
goto exit; goto exit;
#endif
}
else if (strncmp(URI_WSS, serverURI, strlen(URI_WSS)) == 0)
{
#if defined(OPENSSL)
serverURI += strlen(URI_WSS);
m->ssl = 1;
m->websocket = 1;
#else
rc = MQTTCLIENT_SSL_NOT_SUPPORTED;
goto exit;
#endif #endif
} }
m->serverURI = MQTTStrdup(serverURI); m->serverURI = MQTTStrdup(serverURI);
...@@ -407,10 +427,7 @@ static void MQTTClient_terminate(void) ...@@ -407,10 +427,7 @@ static void MQTTClient_terminate(void)
ListFree(bstate->clients); ListFree(bstate->clients);
ListFree(handles); ListFree(handles);
handles = NULL; handles = NULL;
Socket_outTerminate(); WebSocket_terminate();
#if defined(OPENSSL)
SSLSocket_terminate();
#endif
#if defined(HEAP_H) #if defined(HEAP_H)
Heap_terminate(); Heap_terminate();
#endif #endif
...@@ -681,6 +698,12 @@ static thread_return_type WINAPI MQTTClient_run(void* n) ...@@ -681,6 +698,12 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
} }
} }
#endif #endif
else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS && !Thread_check_sem(m->connect_sem))
{
Log(TRACE_MIN, -1, "Posting websocket handshake for client %s rc %d", m->c->clientID, m->rc);
m->c->connect_state = WAIT_FOR_CONNACK;
Thread_post_sem(m->connect_sem);
}
} }
} }
run_id = 0; run_id = 0;
...@@ -769,6 +792,8 @@ static void MQTTClient_closeSession(Clients* client) ...@@ -769,6 +792,8 @@ static void MQTTClient_closeSession(Clients* client)
if (client->connected) if (client->connected)
MQTTPacket_send_disconnect(&client->net, client->clientID); MQTTPacket_send_disconnect(&client->net, client->clientID);
Thread_lock_mutex(socket_mutex); Thread_lock_mutex(socket_mutex);
WebSocket_close(&client->net, WebSocket_CLOSE_NORMAL, NULL);
#if defined(OPENSSL) #if defined(OPENSSL)
SSLSocket_close(&client->net); SSLSocket_close(&client->net);
#endif #endif
...@@ -870,9 +895,9 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt ...@@ -870,9 +895,9 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
Log(TRACE_MIN, -1, "Connecting to serverURI %s with MQTT version %d", serverURI, MQTTVersion); Log(TRACE_MIN, -1, "Connecting to serverURI %s with MQTT version %d", serverURI, MQTTVersion);
#if defined(OPENSSL) #if defined(OPENSSL)
rc = MQTTProtocol_connect(serverURI, m->c, m->ssl, MQTTVersion); rc = MQTTProtocol_connect(serverURI, m->c, m->ssl, m->websocket, MQTTVersion);
#else #else
rc = MQTTProtocol_connect(serverURI, m->c, MQTTVersion); rc = MQTTProtocol_connect(serverURI, m->c, m->websocket, MQTTVersion);
#endif #endif
if (rc == SOCKET_ERROR) if (rc == SOCKET_ERROR)
goto exit; goto exit;
...@@ -893,18 +918,17 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt ...@@ -893,18 +918,17 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
rc = SOCKET_ERROR; rc = SOCKET_ERROR;
goto exit; goto exit;
} }
#if defined(OPENSSL) #if defined(OPENSSL)
if (m->ssl) if (m->ssl)
{ {
int port; int port;
char* hostname; size_t hostname_len;
const char *topic;
int setSocketForSSLrc = 0; int setSocketForSSLrc = 0;
hostname = MQTTProtocol_addressPort(m->serverURI, &port); hostname_len = MQTTProtocol_addressPort(m->serverURI, &port, &topic);
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts, hostname); setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts,
if (hostname != m->serverURI) m->serverURI, hostname_len);
free(hostname);
if (setSocketForSSLrc != MQTTCLIENT_SUCCESS) if (setSocketForSSLrc != MQTTCLIENT_SUCCESS)
{ {
...@@ -922,15 +946,25 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt ...@@ -922,15 +946,25 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
} }
else if (rc == 1) else if (rc == 1)
{ {
rc = MQTTCLIENT_SUCCESS; if (m->websocket)
m->c->connect_state = WAIT_FOR_CONNACK;
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{ {
rc = SOCKET_ERROR; m->c->connect_state = WEBSOCKET_IN_PROGRESS;
goto exit; rc = WebSocket_connect(&m->c->net,m->serverURI);
if ( rc == SOCKET_ERROR )
goto exit;
}
else if ( rc == 1 )
{
rc = MQTTCLIENT_SUCCESS;
m->c->connect_state = WAIT_FOR_CONNACK;
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{
rc = SOCKET_ERROR;
goto exit;
}
if (!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl);
} }
if (!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl);
} }
} }
else else
...@@ -939,18 +973,25 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt ...@@ -939,18 +973,25 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
goto exit; goto exit;
} }
} }
#endif
else if (m->websocket)
{
m->c->connect_state = WEBSOCKET_IN_PROGRESS;
if ( WebSocket_connect(&m->c->net, m->serverURI) == SOCKET_ERROR )
{
rc = SOCKET_ERROR;
goto exit;
}
}
else else
{ {
#endif
m->c->connect_state = WAIT_FOR_CONNACK; /* TCP connect completed, in which case send the MQTT connect packet */ m->c->connect_state = WAIT_FOR_CONNACK; /* TCP connect completed, in which case send the MQTT connect packet */
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR) if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{ {
rc = SOCKET_ERROR; rc = SOCKET_ERROR;
goto exit; goto exit;
} }
#if defined(OPENSSL)
} }
#endif
} }
#if defined(OPENSSL) #if defined(OPENSSL)
...@@ -966,19 +1007,45 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt ...@@ -966,19 +1007,45 @@ static int MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_connectOpt
} }
if(!m->c->cleansession && m->c->session == NULL) if(!m->c->cleansession && m->c->session == NULL)
m->c->session = SSL_get1_session(m->c->net.ssl); m->c->session = SSL_get1_session(m->c->net.ssl);
m->c->connect_state = WAIT_FOR_CONNACK; /* TCP connect completed, in which case send the MQTT connect packet */ if ( m->websocket )
{
/* wait for websocket connect */
m->c->connect_state = WEBSOCKET_IN_PROGRESS;
rc = WebSocket_connect( &m->c->net, m->serverURI );
if ( rc != 1 )
{
rc = SOCKET_ERROR;
goto exit;
}
}
else
{
m->c->connect_state = WAIT_FOR_CONNACK; /* TCP connect completed, in which case send the MQTT connect packet */
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{
rc = SOCKET_ERROR;
goto exit;
}
}
}
#endif
if (m->c->connect_state == WEBSOCKET_IN_PROGRESS) /* websocket request sent - wait for upgrade */
{
Thread_unlock_mutex(mqttclient_mutex);
MQTTClient_waitfor(handle, CONNECT, &rc, millisecsTimeout - MQTTClient_elapsed(start));
Thread_lock_mutex(mqttclient_mutex);
m->c->connect_state = WAIT_FOR_CONNACK; /* websocket upgrade complete */
if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR) if (MQTTPacket_send_connect(m->c, MQTTVersion) == SOCKET_ERROR)
{ {
rc = SOCKET_ERROR; rc = SOCKET_ERROR;
goto exit; goto exit;
} }
} }
#endif
if (m->c->connect_state == WAIT_FOR_CONNACK) /* MQTT connect sent - wait for CONNACK */ if (m->c->connect_state == WAIT_FOR_CONNACK) /* MQTT connect sent - wait for CONNACK */
{ {
MQTTPacket* pack = NULL; MQTTPacket* pack = NULL;
Thread_unlock_mutex(mqttclient_mutex); Thread_unlock_mutex(mqttclient_mutex);
pack = MQTTClient_waitfor(handle, CONNACK, &rc, millisecsTimeout - MQTTClient_elapsed(start)); pack = MQTTClient_waitfor(handle, CONNACK, &rc, millisecsTimeout - MQTTClient_elapsed(start));
Thread_lock_mutex(mqttclient_mutex); Thread_lock_mutex(mqttclient_mutex);
...@@ -1255,12 +1322,23 @@ int MQTTClient_connect(MQTTClient handle, MQTTClient_connectOptions* options) ...@@ -1255,12 +1322,23 @@ int MQTTClient_connect(MQTTClient handle, MQTTClient_connectOptions* options)
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0) if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0)
serverURI += strlen(URI_TCP); serverURI += strlen(URI_TCP);
else if (strncmp(URI_WS, serverURI, strlen(URI_WS)) == 0)
{
serverURI += strlen(URI_WS);
m->websocket = 1;
}
#if defined(OPENSSL) #if defined(OPENSSL)
else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0) else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0)
{ {
serverURI += strlen(URI_SSL); serverURI += strlen(URI_SSL);
m->ssl = 1; m->ssl = 1;
} }
else if (strncmp(URI_WSS, serverURI, strlen(URI_WSS)) == 0)
{
serverURI += strlen(URI_WSS);
m->ssl = 1;
m->websocket = 1;
}
#endif #endif
if ((rc = MQTTClient_connectURI(handle, options, serverURI)) == MQTTCLIENT_SUCCESS) if ((rc = MQTTClient_connectURI(handle, options, serverURI)) == MQTTCLIENT_SUCCESS)
break; break;
...@@ -1743,6 +1821,8 @@ static MQTTPacket* MQTTClient_cycle(int* sock, unsigned long timeout, int* rc) ...@@ -1743,6 +1821,8 @@ static MQTTPacket* MQTTClient_cycle(int* sock, unsigned long timeout, int* rc)
{ {
if (m->c->connect_state == TCP_IN_PROGRESS || m->c->connect_state == SSL_IN_PROGRESS) if (m->c->connect_state == TCP_IN_PROGRESS || m->c->connect_state == SSL_IN_PROGRESS)
*rc = 0; /* waiting for connect state to clear */ *rc = 0; /* waiting for connect state to clear */
else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS)
*rc = WebSocket_upgrade(&m->c->net);
else else
{ {
pack = MQTTPacket_Factory(&m->c->net, rc); pack = MQTTPacket_Factory(&m->c->net, rc);
...@@ -1857,7 +1937,8 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r ...@@ -1857,7 +1937,8 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
} }
} }
#endif #endif
else if (m->c->connect_state == WAIT_FOR_CONNACK) else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS ||
m->c->connect_state == WAIT_FOR_CONNACK)
{ {
int error; int error;
socklen_t len = sizeof(error); socklen_t len = sizeof(error);
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#endif #endif
#include "Messages.h" #include "Messages.h"
#include "StackTrace.h" #include "StackTrace.h"
#include "WebSocket.h"
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
...@@ -108,11 +109,7 @@ void* MQTTPacket_Factory(networkHandles* net, int* error) ...@@ -108,11 +109,7 @@ void* MQTTPacket_Factory(networkHandles* net, int* error)
*error = SOCKET_ERROR; /* indicate whether an error occurred, or not */ *error = SOCKET_ERROR; /* indicate whether an error occurred, or not */
/* read the packet data from the socket */ /* read the packet data from the socket */
#if defined(OPENSSL) *error = WebSocket_getch(net, &header.byte);
*error = (net->ssl) ? SSLSocket_getch(net->ssl, net->socket, &header.byte) : Socket_getch(net->socket, &header.byte);
#else
*error = Socket_getch(net->socket, &header.byte);
#endif
if (*error != TCPSOCKET_COMPLETE) /* first byte is the header byte */ if (*error != TCPSOCKET_COMPLETE) /* first byte is the header byte */
goto exit; /* packet not read, *error indicates whether SOCKET_ERROR occurred */ goto exit; /* packet not read, *error indicates whether SOCKET_ERROR occurred */
...@@ -121,12 +118,7 @@ void* MQTTPacket_Factory(networkHandles* net, int* error) ...@@ -121,12 +118,7 @@ void* MQTTPacket_Factory(networkHandles* net, int* error)
goto exit; /* packet not read, *error indicates whether SOCKET_ERROR occurred */ goto exit; /* packet not read, *error indicates whether SOCKET_ERROR occurred */
/* now read the rest, the variable header and payload */ /* now read the rest, the variable header and payload */
#if defined(OPENSSL) data = WebSocket_getdata(net, remaining_length, &actual_len);
data = (net->ssl) ? SSLSocket_getdata(net->ssl, net->socket, remaining_length, &actual_len) :
Socket_getdata(net->socket, remaining_length, &actual_len);
#else
data = Socket_getdata(net->socket, remaining_length, &actual_len);
#endif
if (data == NULL) if (data == NULL)
{ {
*error = SOCKET_ERROR; *error = SOCKET_ERROR;
...@@ -178,29 +170,28 @@ int MQTTPacket_send(networkHandles* net, Header header, char* buffer, size_t buf ...@@ -178,29 +170,28 @@ int MQTTPacket_send(networkHandles* net, Header header, char* buffer, size_t buf
{ {
int rc; int rc;
size_t buf0len; size_t buf0len;
size_t ws_header;
char *buf; char *buf;
FUNC_ENTRY; FUNC_ENTRY;
buf = malloc(10); ws_header = WebSocket_calculateFrameHeaderSize(net, buflen + 10);
buf[0] = header.byte; buf = malloc(10 + ws_header);
buf0len = 1 + MQTTPacket_encode(&buf[1], buflen); if ( !buf ) return -1;
buf[ws_header] = header.byte;
buf0len = 1 + MQTTPacket_encode(&buf[ws_header + 1], buflen);
#if !defined(NO_PERSISTENCE) #if !defined(NO_PERSISTENCE)
if (header.bits.type == PUBREL) if (header.bits.type == PUBREL)
{ {
char* ptraux = buffer; char* ptraux = buffer;
int msgId = readInt(&ptraux); int msgId = readInt(&ptraux);
rc = MQTTPersistence_put(net->socket, buf, buf0len, 1, &buffer, &buflen, rc = MQTTPersistence_put(net->socket, &buf[ws_header], buf0len, 1, &buffer, &buflen,
header.bits.type, msgId, 0); header.bits.type, msgId, 0);
} }
#endif #endif
#if defined(OPENSSL) rc = WebSocket_putdatas(net, &buf[ws_header], buf0len, 1, &buffer, &buflen, &freeData);
if (net->ssl)
rc = SSLSocket_putdatas(net->ssl, net->socket, buf, buf0len, 1, &buffer, &buflen, &freeData);
else
#endif
rc = Socket_putdatas(net->socket, buf, buf0len, 1, &buffer, &buflen, &freeData);
if (rc == TCPSOCKET_COMPLETE) if (rc == TCPSOCKET_COMPLETE)
time(&(net->lastSent)); time(&(net->lastSent));
...@@ -225,30 +216,31 @@ int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffe ...@@ -225,30 +216,31 @@ int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffe
{ {
int i, rc; int i, rc;
size_t buf0len, total = 0; size_t buf0len, total = 0;
size_t ws_header;
char *buf; char *buf;
FUNC_ENTRY; FUNC_ENTRY;
buf = malloc(10);
buf[0] = header.byte;
for (i = 0; i < count; i++) for (i = 0; i < count; i++)
total += buflens[i]; total += buflens[i];
buf0len = 1 + MQTTPacket_encode(&buf[1], total);
ws_header = WebSocket_calculateFrameHeaderSize(net, total + 10);
buf = malloc(10 + ws_header);
if ( !buf ) return -1;
buf[ws_header] = header.byte;
buf0len = 1 + MQTTPacket_encode(&buf[ws_header + 1], total);
#if !defined(NO_PERSISTENCE) #if !defined(NO_PERSISTENCE)
if (header.bits.type == PUBLISH && header.bits.qos != 0) if (header.bits.type == PUBLISH && header.bits.qos != 0)
{ /* persist PUBLISH QoS1 and Qo2 */ { /* persist PUBLISH QoS1 and Qo2 */
char *ptraux = buffers[2]; char *ptraux = buffers[2];
int msgId = readInt(&ptraux); int msgId = readInt(&ptraux);
rc = MQTTPersistence_put(net->socket, buf, buf0len, count, buffers, buflens, rc = MQTTPersistence_put(net->socket, &buf[ws_header], buf0len, count, buffers, buflens,
header.bits.type, msgId, 0); header.bits.type, msgId, 0);
} }
#endif #endif
#if defined(OPENSSL) rc = WebSocket_putdatas(net, &buf[ws_header], buf0len, count, buffers, buflens, frees);
if (net->ssl)
rc = SSLSocket_putdatas(net->ssl, net->socket, buf, buf0len, count, buffers, buflens, frees);
else
#endif
rc = Socket_putdatas(net->socket, buf, buf0len, count, buffers, buflens, frees);
if (rc == TCPSOCKET_COMPLETE) if (rc == TCPSOCKET_COMPLETE)
time(&(net->lastSent)); time(&(net->lastSent));
...@@ -307,11 +299,7 @@ int MQTTPacket_decode(networkHandles* net, size_t* value) ...@@ -307,11 +299,7 @@ int MQTTPacket_decode(networkHandles* net, size_t* value)
rc = SOCKET_ERROR; /* bad data */ rc = SOCKET_ERROR; /* bad data */
goto exit; goto exit;
} }
#if defined(OPENSSL) rc = WebSocket_getch(net, &c);
rc = (net->ssl) ? SSLSocket_getch(net->ssl, net->socket, &c) : Socket_getch(net->socket, &c);
#else
rc = Socket_getch(net->socket, &c);
#endif
if (rc != TCPSOCKET_COMPLETE) if (rc != TCPSOCKET_COMPLETE)
goto exit; goto exit;
*value += (c & 127) * multiplier; *value += (c & 127) * multiplier;
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "MQTTProtocolOut.h" #include "MQTTProtocolOut.h"
#include "StackTrace.h" #include "StackTrace.h"
#include "Heap.h" #include "Heap.h"
#include "WebSocket.h"
extern ClientStates* bstate; extern ClientStates* bstate;
...@@ -42,11 +43,12 @@ extern ClientStates* bstate; ...@@ -42,11 +43,12 @@ extern ClientStates* bstate;
/** /**
* Separates an address:port into two separate values * Separates an address:port into two separate values
* @param uri the input string - hostname:port * @param[in] uri the input string - hostname:port
* @param port the returned port integer * @param[out] port the returned port integer
* @param[out] topic optional topic portion of the address starting with '/'
* @return the address string * @return the address string
*/ */
char* MQTTProtocol_addressPort(const char* uri, int* port) size_t MQTTProtocol_addressPort(const char* uri, int* port, const char **topic)
{ {
char* colon_pos = strrchr(uri, ':'); /* reverse find to allow for ':' in IPv6 addresses */ char* colon_pos = strrchr(uri, ':'); /* reverse find to allow for ':' in IPv6 addresses */
char* buf = (char*)uri; char* buf = (char*)uri;
...@@ -61,27 +63,31 @@ char* MQTTProtocol_addressPort(const char* uri, int* port) ...@@ -61,27 +63,31 @@ char* MQTTProtocol_addressPort(const char* uri, int* port)
if (colon_pos) /* have to strip off the port */ if (colon_pos) /* have to strip off the port */
{ {
size_t addr_len = colon_pos - uri; len = colon_pos - uri;
buf = malloc(addr_len + 1);
*port = atoi(colon_pos + 1); *port = atoi(colon_pos + 1);
MQTTStrncpy(buf, uri, addr_len+1);
} }
else else
{
len = strlen(buf);
*port = DEFAULT_PORT; *port = DEFAULT_PORT;
}
/* try and find topic portion */
if ( topic )
{
const char* addr_start = uri;
if ( colon_pos )
addr_start = colon_pos;
*topic = strchr( addr_start, '/' );
}
len = strlen(buf);
if (buf[len - 1] == ']') if (buf[len - 1] == ']')
{ {
if (buf == (char*)uri) /* we are stripping off the final ], so length is 1 shorter */
{ --len;
buf = malloc(len); /* we are stripping off the final ], so length is 1 shorter */
MQTTStrncpy(buf, uri, len);
}
else
buf[len - 1] = '\0';
} }
FUNC_EXIT; FUNC_EXIT;
return buf; return len;
} }
...@@ -94,19 +100,19 @@ char* MQTTProtocol_addressPort(const char* uri, int* port) ...@@ -94,19 +100,19 @@ char* MQTTProtocol_addressPort(const char* uri, int* port)
* @return return code * @return return code
*/ */
#if defined(OPENSSL) #if defined(OPENSSL)
int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int ssl, int MQTTVersion) int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int ssl, int websocket, int MQTTVersion)
#else #else
int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersion) int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int websocket, int MQTTVersion)
#endif #endif
{ {
int rc, port; int rc, port;
char* addr; size_t addr_len;
FUNC_ENTRY; FUNC_ENTRY;
aClient->good = 1; aClient->good = 1;
addr = MQTTProtocol_addressPort(ip_address, &port); addr_len = MQTTProtocol_addressPort(ip_address, &port, NULL);
rc = Socket_new(addr, port, &(aClient->net.socket)); rc = Socket_new(ip_address, addr_len, port, &(aClient->net.socket));
if (rc == EINPROGRESS || rc == EWOULDBLOCK) if (rc == EINPROGRESS || rc == EWOULDBLOCK)
aClient->connect_state = TCP_IN_PROGRESS; /* TCP connect called - wait for connect completion */ aClient->connect_state = TCP_IN_PROGRESS; /* TCP connect called - wait for connect completion */
else if (rc == 0) else if (rc == 0)
...@@ -114,10 +120,10 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersi ...@@ -114,10 +120,10 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersi
#if defined(OPENSSL) #if defined(OPENSSL)
if (ssl) if (ssl)
{ {
if (SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, addr) == 1) if (SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, ip_address, addr_len) == 1)
{ {
rc = SSLSocket_connect(aClient->net.ssl, aClient->net.socket, rc = SSLSocket_connect(aClient->net.ssl, aClient->net.socket,
addr, aClient->sslopts->verify); ip_address, aClient->sslopts->verify);
if (rc == TCPSOCKET_INTERRUPTED) if (rc == TCPSOCKET_INTERRUPTED)
aClient->connect_state = SSL_IN_PROGRESS; /* SSL connect called - wait for completion */ aClient->connect_state = SSL_IN_PROGRESS; /* SSL connect called - wait for completion */
} }
...@@ -125,6 +131,12 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersi ...@@ -125,6 +131,12 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersi
rc = SOCKET_ERROR; rc = SOCKET_ERROR;
} }
#endif #endif
if ( websocket )
{
rc = WebSocket_connect( &aClient->net, ip_address );
if ( rc == TCPSOCKET_INTERRUPTED )
aClient->connect_state = WEBSOCKET_IN_PROGRESS; /* Websocket connect called - wait for completion */
}
if (rc == 0) if (rc == 0)
{ {
/* Now send the MQTT connect packet */ /* Now send the MQTT connect packet */
...@@ -134,8 +146,6 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersi ...@@ -134,8 +146,6 @@ int MQTTProtocol_connect(const char* ip_address, Clients* aClient, int MQTTVersi
aClient->connect_state = NOT_IN_PROGRESS; aClient->connect_state = NOT_IN_PROGRESS;
} }
} }
if (addr != ip_address)
free(addr);
FUNC_EXIT_RC(rc); FUNC_EXIT_RC(rc);
return rc; return rc;
......
...@@ -30,12 +30,12 @@ ...@@ -30,12 +30,12 @@
#define DEFAULT_PORT 1883 #define DEFAULT_PORT 1883
char* MQTTProtocol_addressPort(const char* uri, int* port); size_t MQTTProtocol_addressPort(const char* uri, int* port, const char **topic);
void MQTTProtocol_reconnect(const char* ip_address, Clients* client); void MQTTProtocol_reconnect(const char* ip_address, Clients* client);
#if defined(OPENSSL) #if defined(OPENSSL)
int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int ssl, int MQTTVersion); int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int ssl, int websocket, int MQTTVersion);
#else #else
int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int MQTTVersion); int MQTTProtocol_connect(const char* ip_address, Clients* acClients, int websocket, int MQTTVersion);
#endif #endif
int MQTTProtocol_handlePingresps(void* pack, int sock); int MQTTProtocol_handlePingresps(void* pack, int sock);
int MQTTProtocol_subscribe(Clients* client, List* topics, List* qoss, int msgID); int MQTTProtocol_subscribe(Clients* client, List* topics, List* qoss, int msgID);
......
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#include "SHA1.h"
#if !defined(OPENSSL)
#if defined(WIN32) || defined(WIN64)
#pragma comment(lib, "crypt32.lib")
int SHA1_Init(SHA_CTX *c)
{
if (!CryptAcquireContext(&c->hProv, NULL, NULL,
PROV_RSA_FULL, CRYPT_VERIFYCONTEXT))
return 0;
if (!CryptCreateHash(c->hProv, CALG_SHA1, 0, 0, &c->hHash))
{
CryptReleaseContext(c->hProv, 0);
return 0;
}
return 1;
}
int SHA1_Update(SHA_CTX *c, const void *data, size_t len)
{
int rv = 1;
if (!CryptHashData(c->hHash, data, len, 0))
rv = 0;
return rv;
}
int SHA1_Final(unsigned char *md, SHA_CTX *c)
{
int rv = 0;
DWORD md_len = SHA1_DIGEST_LENGTH;
if (CryptGetHashParam(c->hHash, HP_HASHVAL, md, &md_len, 0))
rv = 1;
CryptDestroyHash(c->hHash);
CryptReleaseContext(c->hProv, 0);
return rv;
}
#else /* if defined(WIN32) || defined(WIN64) */
#if defined(__linux__)
# include <endian.h>
#elif defined(__APPLE__)
# include <libkern/OSByteOrder.h>
# define htobe32(x) OSSwapHostToBigInt32(x)
# define be32toh(x) OSSwapBigToHostInt32(x)
#elif defined(__FreeBSD__) || defined(__NetBSD__)
# include <sys/endian.h>
#endif
#include <string.h>
static unsigned char pad[64] = {
0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
};
int SHA1_Init(SHA_CTX *ctx)
{
int ret = 0;
if ( ctx )
{
ctx->h[0] = 0x67452301;
ctx->h[1] = 0xEFCDAB89;
ctx->h[2] = 0x98BADCFE;
ctx->h[3] = 0x10325476;
ctx->h[4] = 0xC3D2E1F0;
ctx->size = 0u;
ctx->total = 0u;
ret = 1;
}
return ret;
}
#define ROTATE_LEFT32(a, n) (((a) << (n)) | ((a) >> (32 - (n))))
static void SHA1_ProcessBlock(SHA_CTX *ctx)
{
uint32_t blks[5];
uint32_t *w;
int i;
/* initialize */
for ( i = 0; i < 5; ++i )
blks[i] = ctx->h[i];
w = ctx->w;
/* perform SHA-1 hash */
for ( i = 0; i < 16; ++i )
w[i] = be32toh(w[i]);
for( i = 0; i < 80; ++i )
{
int tmp;
if ( i >= 16 )
w[i & 0x0F] = ROTATE_LEFT32( w[(i+13) & 0x0F] ^ w[(i+8) & 0x0F] ^ w[(i+2) & 0x0F] ^ w[i & 0x0F], 1 );
if ( i < 20 )
tmp = ROTATE_LEFT32(blks[0], 5) + ((blks[1] & blks[2]) | (~(blks[1]) & blks[3])) + blks[4] + w[i & 0x0F] + 0x5A827999;
else if ( i < 40 )
tmp = ROTATE_LEFT32(blks[0], 5) + (blks[1]^blks[2]^blks[3]) + blks[4] + w[i & 0x0F] + 0x6ED9EBA1;
else if ( i < 60 )
tmp = ROTATE_LEFT32(blks[0], 5) + ((blks[1] & blks[2]) | (blks[1] & blks[3]) | (blks[2] & blks[3])) + blks[4] + w[i & 0x0F] + 0x8F1BBCDC;
else
tmp = ROTATE_LEFT32(blks[0], 5) + (blks[1]^blks[2]^blks[3]) + blks[4] + w[i & 0x0F] + 0xCA62C1D6;
/* update registers */
blks[4] = blks[3];
blks[3] = blks[2];
blks[2] = ROTATE_LEFT32(blks[1], 30);
blks[1] = blks[0];
blks[0] = tmp;
}
/* update of hash */
for ( i = 0; i < 5; ++i )
ctx->h[i] += blks[i];
}
int SHA1_Final(unsigned char *md, SHA_CTX *ctx)
{
int i;
int ret = 0;
size_t pad_amount;
uint64_t total;
/* length before pad */
total = ctx->total * 8;
if ( ctx->size < 56 )
pad_amount = 56 - ctx->size;
else
pad_amount = 64 + 56 - ctx->size;
SHA1_Update(ctx, pad, pad_amount);
ctx->w[14] = htobe32((uint32_t)(total >> 32));
ctx->w[15] = htobe32((uint32_t)total);
SHA1_ProcessBlock(ctx);
for ( i = 0; i < 5; ++i )
ctx->h[i] = htobe32(ctx->h[i]);
if ( md )
{
memcpy( md, &ctx->h[0], SHA1_DIGEST_LENGTH );
ret = 1;
}
return ret;
}
int SHA1_Update(SHA_CTX *ctx, const void *data, size_t len)
{
while ( len > 0 )
{
unsigned int n = 64 - ctx->size;
if ( len < n )
n = len;
memcpy(ctx->buffer + ctx->size, data, n);
ctx->size += n;
ctx->total += n;
data = (uint8_t *)data + n;
len -= n;
if ( ctx->size == 64 )
{
SHA1_ProcessBlock(ctx);
ctx->size = 0;
}
}
return 1;
}
#endif /* else if defined(WIN32) || defined(WIN64) */
#endif /* elif !defined(OPENSSL) */
#if defined(SHA1_TEST)
#include <stdio.h>
#include <string.h>
#define TEST_EXPECT(i,x) if (!(x)) {fprintf( stderr, "failed test: %s (for i == %d)\n", #x, i ); ++fails;}
int main(int argc, char *argv[])
{
struct _td
{
const char *in;
const char *out;
};
int i;
unsigned int fails = 0u;
struct _td test_data[] = {
{ "", "da39a3ee5e6b4b0d3255bfef95601890afd80709" },
{ "this string", "fda4e74bc7489a18b146abdf23346d166663dab8" },
{ NULL, NULL }
};
/* only 1 update */
i = 0;
while ( test_data[i].in != NULL )
{
int r[3] = { 1, 1, 1 };
unsigned char sha_out[SHA1_DIGEST_LENGTH];
char out[SHA1_DIGEST_LENGTH * 2 + 1];
SHA_CTX c;
int j;
r[0] = SHA1_Init( &c );
r[1] = SHA1_Update( &c, test_data[i].in, strlen(test_data[i].in));
r[2] = SHA1_Final( sha_out, &c );
for ( j = 0u; j < SHA1_DIGEST_LENGTH; ++j )
snprintf( &out[j*2], 3u, "%02x", sha_out[j] );
out[SHA1_DIGEST_LENGTH * 2] = '\0';
TEST_EXPECT( i, r[0] == 1 && r[1] == 1 && r[2] == 1 && strncmp(out, test_data[i].out, strlen(test_data[i].out)) == 0 );
++i;
}
if ( fails )
printf( "%u test failed!\n", fails );
else
printf( "all tests passed\n" );
return fails;
}
#endif /* if defined(SHA1_TEST) */
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#if !defined(SHA1_H)
#define SHA1_H
#if defined(OPENSSL)
#include <openssl/sha.h>
/** SHA-1 Digest Length */
#define SHA1_DIGEST_LENGTH SHA_DIGEST_LENGTH
#else /* if defined(OPENSSL) */
#if defined(WIN32) || defined(WIN64)
#include <Windows.h>
#include <WinCrypt.h>
typedef struct SHA_CTX_S
{
HCRYPTPROV hProv;
HCRYPTHASH hHash;
} SHA_CTX;
#else /* if defined(WIN32) || defined(WIN64) */
#include <stdint.h>
typedef struct SHA_CTX_S {
uint32_t h[5];
union {
uint32_t w[16];
uint8_t buffer[64];
};
unsigned int size;
unsigned int total;
} SHA_CTX;
#endif /* else if defined(WIN32) || defined(WIN64) */
#include <stddef.h>
/** SHA-1 Digest Length (number of bytes in SHA1) */
#define SHA1_DIGEST_LENGTH (160/8)
/**
* Initializes the SHA1 hashing algorithm
*
* @param[in,out] ctx hashing context structure
*
* @see SHA1_Update
* @see SHA1_Final
*/
int SHA1_Init(SHA_CTX *ctx);
/**
* Updates a block to the SHA1 hash
*
* @param[in,out] ctx hashing context structure
* @param[in] data block of data to hash
* @param[in] len length of block to hash
*
* @see SHA1_Init
* @see SHA1_Final
*/
int SHA1_Update(SHA_CTX *ctx, const void *data, size_t len);
/**
* Produce final SHA1 hash
*
* @param[out] md SHA1 hash produced (must be atleast
* @p SHA1_DIGEST_LENGTH in length)
* @param[in,out] ctx hashing context structure
*
* @see SHA1_Init
* @see SHA1_Final
*/
int SHA1_Final(unsigned char *md, SHA_CTX *ctx);
#endif /* if defined(OPENSSL) */
#endif /* SHA1_H */
...@@ -30,11 +30,11 @@ ...@@ -30,11 +30,11 @@
#include "SocketBuffer.h" #include "SocketBuffer.h"
#include "MQTTClient.h" #include "MQTTClient.h"
#include "MQTTProtocolOut.h"
#include "SSLSocket.h" #include "SSLSocket.h"
#include "Log.h" #include "Log.h"
#include "StackTrace.h" #include "StackTrace.h"
#include "Socket.h" #include "Socket.h"
char* MQTTProtocol_addressPort(const char* uri, int* port);
#include "Heap.h" #include "Heap.h"
...@@ -620,7 +620,8 @@ exit: ...@@ -620,7 +620,8 @@ exit:
} }
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, char* hostname) int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
const char* hostname, size_t hostname_len)
{ {
int rc = 1; int rc = 1;
...@@ -628,6 +629,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, ...@@ -628,6 +629,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
if (net->ctx != NULL || (rc = SSLSocket_createContext(net, opts)) == 1) if (net->ctx != NULL || (rc = SSLSocket_createContext(net, opts)) == 1)
{ {
char *hostname_plus_null;
int i; int i;
SSL_CTX_set_info_callback(net->ctx, SSL_CTX_info_callback); SSL_CTX_set_info_callback(net->ctx, SSL_CTX_info_callback);
...@@ -648,8 +650,11 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, ...@@ -648,8 +650,11 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
if ((rc = SSL_set_fd(net->ssl, net->socket)) != 1) if ((rc = SSL_set_fd(net->ssl, net->socket)) != 1)
SSLSocket_error("SSL_set_fd", net->ssl, net->socket, rc); SSLSocket_error("SSL_set_fd", net->ssl, net->socket, rc);
if ((rc = SSL_set_tlsext_host_name(net->ssl, hostname)) != 1) hostname_plus_null = malloc(hostname_len + 1u );
MQTTStrncpy(hostname_plus_null, hostname, hostname_len + 1u);
if ((rc = SSL_set_tlsext_host_name(net->ssl, hostname_plus_null)) != 1)
SSLSocket_error("SSL_set_tlsext_host_name", NULL, net->socket, rc); SSLSocket_error("SSL_set_tlsext_host_name", NULL, net->socket, rc);
free(hostname_plus_null);
} }
FUNC_EXIT_RC(rc); FUNC_EXIT_RC(rc);
...@@ -659,7 +664,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, ...@@ -659,7 +664,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
/* /*
* Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure * Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure
*/ */
int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify) int SSLSocket_connect(SSL* ssl, int sock, const char* hostname, int verify)
{ {
int rc = 0; int rc = 0;
...@@ -680,12 +685,12 @@ int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify) ...@@ -680,12 +685,12 @@ int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify)
{ {
char* peername = NULL; char* peername = NULL;
int port; int port;
char* addr = NULL; size_t hostname_len;
X509* cert = SSL_get_peer_certificate(ssl); X509* cert = SSL_get_peer_certificate(ssl);
addr = MQTTProtocol_addressPort(hostname, &port); hostname_len = MQTTProtocol_addressPort(hostname, &port, NULL);
rc = X509_check_host(cert, addr, strlen(addr), 0, &peername); rc = X509_check_host(cert, hostname, hostname_len, 0, &peername);
if (rc == 0) if (rc == 0)
rc = SOCKET_ERROR; rc = SOCKET_ERROR;
Log(TRACE_MIN, -1, "rc from X509_check_host is %d", rc); Log(TRACE_MIN, -1, "rc from X509_check_host is %d", rc);
...@@ -693,8 +698,6 @@ int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify) ...@@ -693,8 +698,6 @@ int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify)
if (cert) if (cert)
X509_free(cert); X509_free(cert);
if (addr != hostname)
free(addr);
} }
#endif #endif
......
...@@ -37,14 +37,14 @@ void SSLSocket_handleOpensslInit(int bool_value); ...@@ -37,14 +37,14 @@ void SSLSocket_handleOpensslInit(int bool_value);
int SSLSocket_initialize(void); int SSLSocket_initialize(void);
void SSLSocket_terminate(void); void SSLSocket_terminate(void);
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, char* hostname); int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, const char* hostname, size_t hostname_len);
int SSLSocket_getch(SSL* ssl, int socket, char* c); int SSLSocket_getch(SSL* ssl, int socket, char* c);
char *SSLSocket_getdata(SSL* ssl, int socket, size_t bytes, size_t* actual_len); char *SSLSocket_getdata(SSL* ssl, int socket, size_t bytes, size_t* actual_len);
int SSLSocket_close(networkHandles* net); int SSLSocket_close(networkHandles* net);
int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int* frees); int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int* frees);
int SSLSocket_connect(SSL* ssl, int sock, char* hostname, int verify); int SSLSocket_connect(SSL* ssl, int sock, const char* hostname, int verify);
int SSLSocket_getPendingRead(void); int SSLSocket_getPendingRead(void);
int SSLSocket_continueWrite(pending_writes* pw); int SSLSocket_continueWrite(pending_writes* pw);
......
...@@ -650,9 +650,10 @@ void Socket_close(int socket) ...@@ -650,9 +650,10 @@ void Socket_close(int socket)
* @param sock returns the new socket * @param sock returns the new socket
* @return completion code * @return completion code
*/ */
int Socket_new(char* addr, int port, int* sock) int Socket_new(const char* addr, size_t addr_len, int port, int* sock)
{ {
int type = SOCK_STREAM; int type = SOCK_STREAM;
char *addr_mem;
struct sockaddr_in address; struct sockaddr_in address;
#if defined(AF_INET6) #if defined(AF_INET6)
struct sockaddr_in6 address6; struct sockaddr_in6 address6;
...@@ -671,9 +672,16 @@ int Socket_new(char* addr, int port, int* sock) ...@@ -671,9 +672,16 @@ int Socket_new(char* addr, int port, int* sock)
memset(&address6, '\0', sizeof(address6)); memset(&address6, '\0', sizeof(address6));
if (addr[0] == '[') if (addr[0] == '[')
++addr; {
++addr;
--addr_len;
}
addr_mem = malloc( addr_len + 1u );
memcpy( addr_mem, addr, addr_len );
addr_mem[addr_len] = '\0';
if ((rc = getaddrinfo(addr, NULL, &hints, &result)) == 0) if ((rc = getaddrinfo(addr_mem, NULL, &hints, &result)) == 0)
{ {
struct addrinfo* res = result; struct addrinfo* res = result;
...@@ -708,10 +716,10 @@ int Socket_new(char* addr, int port, int* sock) ...@@ -708,10 +716,10 @@ int Socket_new(char* addr, int port, int* sock)
freeaddrinfo(result); freeaddrinfo(result);
} }
else else
Log(LOG_ERROR, -1, "getaddrinfo failed for addr %s with rc %d", addr, rc); Log(LOG_ERROR, -1, "getaddrinfo failed for addr %s with rc %d", addr_mem, rc);
if (rc != 0) if (rc != 0)
Log(LOG_ERROR, -1, "%s is not a valid IP address", addr); Log(LOG_ERROR, -1, "%s is not a valid IP address", addr_mem);
else else
{ {
*sock = (int)socket(family, type, 0); *sock = (int)socket(family, type, 0);
...@@ -771,6 +779,10 @@ int Socket_new(char* addr, int port, int* sock) ...@@ -771,6 +779,10 @@ int Socket_new(char* addr, int port, int* sock)
} }
} }
} }
if (addr_mem)
free(addr_mem);
FUNC_EXIT_RC(rc); FUNC_EXIT_RC(rc);
return rc; return rc;
} }
......
...@@ -131,7 +131,7 @@ int Socket_getch(int socket, char* c); ...@@ -131,7 +131,7 @@ int Socket_getch(int socket, char* c);
char *Socket_getdata(int socket, size_t bytes, size_t* actual_len); char *Socket_getdata(int socket, size_t bytes, size_t* actual_len);
int Socket_putdatas(int socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int* frees); int Socket_putdatas(int socket, char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens, int* frees);
void Socket_close(int socket); void Socket_close(int socket);
int Socket_new(char* addr, int port, int* socket); int Socket_new(const char* addr, size_t addr_len, int port, int* socket);
int Socket_noPendingWrites(int socket); int Socket_noPendingWrites(int socket);
char* Socket_getpeer(int sock); char* Socket_getpeer(int sock);
......
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "WebSocket.h"
#include "Base64.h"
#include "Log.h"
#include "SHA1.h"
#include "LinkedList.h"
#include "MQTTProtocolOut.h"
#if defined(__linux__)
# include <endian.h>
#elif defined(__APPLE__)
# include <libkern/OSByteOrder.h>
# define htobe16(x) OSSwapHostToBigInt16(x)
# define htobe32(x) OSSwapHostToBigInt32(x)
# define htobe64(x) OSSwapHostToBigInt64(x)
# define be16toh(x) OSSwapBigToHostInt16(x)
# define be32toh(x) OSSwapBigToHostInt32(x)
# define be64toh(x) OSSwapBigToHostInt64(x)
#elif defined(__FreeBSD__) || defined(__NetBSD__)
# include <sys/endian.h>
#elif defined(WIN32) || defined(WIN64)
# pragma comment(lib, "rpcrt4.lib")
# include <Rpc.h>
# define strncasecmp(s1,s2,c) _strnicmp(s1,s2,c)
# if BYTE_ORDER == LITTLE_ENDIAN
# define htobe16(x) htons(x)
# define htobe32(x) htonl(x)
# define htobe64(x) htonll(x)
# define be16toh(x) ntohs(x)
# define be32toh(x) ntohl(x)
# define be64toh(x) ntohll(x)
# elif BTYE_ORDER == BIG_ENDIAN
# define htobe16(x) (x)
# define htobe32(x) (x)
# define htobe64(x) (x)
# define be16toh(x) (x)
# define be32toh(x) (x)
# define be64toh(x) (x)
# else
# error "unknown endian"
# endif
/* For Microsoft Visual Studio 2013 */
# if !defined( snprintf )
# define snprintf _snprintf
# endif /* if !defined( snprintf ) */
#endif
#if defined(OPENSSL)
#include "SSLSocket.h"
#endif /* defined(OPENSSL) */
#include "Socket.h"
#if !(defined(WIN32) || defined(WIN64))
#if defined(LIBUUID)
#include <uuid/uuid.h>
#else /* if defined(USE_LIBUUID) */
#include <limits.h>
#include <stdlib.h>
#include <time.h>
#if defined(OPENSSL)
#include <openssl/rand.h>
#endif /* if defined(OPENSSL) */
/** @brief raw uuid type */
typedef unsigned char uuid_t[16];
/**
* @brief generates a uuid, compatible with RFC 4122, version 4 (random)
* @note Uses a very insecure algorithm but no external dependencies
*/
void uuid_generate( uuid_t out )
{
#if defined(OPENSSL)
int rc = RAND_bytes( out, sizeof(uuid_t));
if ( !rc )
#endif /* defined (OPENSSL) */
{
/* very insecure, but generates a random uuid */
srand(time(NULL));
int i;
for ( i = 0; i < 16; ++i )
out[i] = (unsigned char)(rand() % UCHAR_MAX);
out[6] = (out[6] & 0x0f) | 0x40;
out[8] = (out[8] & 0x3F) | 0x80;
}
}
/** @brief converts a uuid to a string */
void uuid_unparse( uuid_t uu, char *out )
{
int i;
for ( i = 0; i < 16; ++i )
{
if ( i == 4 || i == 6 || i == 8 || i == 10 )
{
*out = '-';
++out;
}
out += sprintf( out, "%02x", uu[i] );
}
*out = '\0';
}
#endif /* else if defined(LIBUUID) */
#endif /* if !(defined(WIN32) || defined(WIN64)) */
/** raw websocket frame data */
struct ws_frame
{
size_t len; /**< length of frame */
size_t pos; /**< current position within the buffer */
};
/** Current frame being processed */
struct ws_frame *last_frame = NULL;
/** Holds any received websocket frames, to be process */
static List* in_frames = NULL;
/* static function declarations */
static const char *WebSocket_strcasefind(
const char *buf, const char *str, size_t len);
static char *WebSocket_getRawSocketData(
networkHandles *net, size_t bytes, size_t* actual_len);
static void WebSocket_pong(
networkHandles *net, char *app_data, size_t app_data_len);
static int WebSocket_receiveFrame(networkHandles *net,
size_t bytes, size_t *actual_len );
/**
* @brief builds a websocket frame for data transmission
*
* write a websocket header and will mask the payload in all the passed in
* buffers
*
* @param[in,out] net network connection
* @param[in] opcode websocket opcode for the packet
* @param[in] mask_data whether to maskt he data
* @param[in,out] buf0 first buffer, will write before this
* @param[in] buf0len size of first buffer
* @param[in] count number of payload buffers
* @param[in,out] buffers array of paylaod buffers
* @param[in] buflens array of payload buffer sizes
* @param[in] freeData array indicating to free payload buffers
*
* @return amount of data to write to socket
*/
static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data,
char* buf0, size_t buf0len, int count, char** buffers, size_t* buflens)
{
int i;
int buf_len = 0u;
size_t data_len = buf0len;
for (i = 0; i < count; ++i)
data_len += buflens[i];
buf0 -= WebSocket_calculateFrameHeaderSize(net, data_len);
if ( net->websocket )
{
uint8_t mask[4];
/* genearate mask, since we are a client */
#if defined(OPENSSL)
RAND_bytes( &mask[0], sizeof(mask) );
#else /* if defined(OPENSSL) */
mask[0] = (rand() % UINT8_MAX);
mask[1] = (rand() % UINT8_MAX);
mask[2] = (rand() % UINT8_MAX);
mask[3] = (rand() % UINT8_MAX);
#endif /* else if defined(OPENSSL) */
/* 1st byte */
buf0[buf_len] = (char)(1 << 7); /* final flag */
/* 3 bits reserved for negotiation of protocol */
buf0[buf_len] |= (opcode & 0x0F); /* op code */
++buf_len;
/* 2nd byte */
buf0[buf_len] = (mask_data & 0x1) << 7; /* masking bit */
/* payload length */
if ( data_len < 126u )
buf0[buf_len++] |= data_len & 0x7F;
/* 3rd byte & 4th bytes - extended payload length */
else if ( data_len <= 65536u )
{
uint16_t len = htobe16((uint16_t)data_len);
buf0[buf_len++] |= (126u & 0x7F);
memcpy( &buf0[buf_len], &len, 2u );
buf_len += 2;
}
else if ( data_len < 0xFFFFFFFFFFFFFFFF )
{
uint64_t len = htobe64((uint64_t)data_len);
buf0[buf_len++] |= (127u & 0x7F);
memcpy( &buf0[buf_len], &len, 8 );
buf_len += 8;
}
else
{
Log(TRACE_PROTOCOL, 1, "Data too large for websocket frame" );
buf_len = -1;
}
/* masking key */
if ( (mask_data & 0x1) && buf_len > 0 )
{
memcpy( &buf0[buf_len], &mask, sizeof(uint32_t));
buf_len += sizeof(uint32_t);
}
/* mask data */
if ( mask_data & 0x1 )
{
size_t idx = 0u;
/* packet fixed header */
for (i = 0; i < (int)buf0len; ++i, ++idx)
buf0[buf_len + i] ^= mask[idx % 4];
/* variable data buffers */
for (i = 0; i < count; ++i)
{
size_t j;
for ( j = 0; j < buflens[i]; ++j, ++idx )
buffers[i][j] ^= mask[idx % 4];
}
}
}
return buf_len;
}
/**
* calculates the amount of data required for the websocket header
*
* this function is used to calculate how much offset is required before calling
* @p WebSocket_putdatas, as that function will write data before the passed in
* buffer
*
* @param[in,out] net network connection
* @param[in] data_len amount of data in the payload
*
* @return the size in bytes of the websocket header required
*
* @see WebSocket_putdatas
*/
size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, size_t data_len)
{
int ret = 0;
if ( net && net->websocket )
{
ret += 2; /* header 2 bytes */
if ( data_len >= 126 )
ret += 2; /* for extra 2-bytes for payload length */
else if ( data_len > 65536u )
ret += 8; /* for extra 8-bytes for payload length */
ret += sizeof(uint32_t); /* for mask - all outgoing data is masked */
}
return ret;
}
/**
* sends out a websocket request on the given uri
*
* @param[in] net network connection
* @param[in] uri uri to connect to
*
* @retval SOCKET_ERROR on failure
* @retval 1 on success
*
* @see WebSocket_upgrade
*/
int WebSocket_connect( networkHandles *net, const char *uri )
{
int rc;
char *buf = NULL;
int i, buf_len = 0;
size_t hostname_len;
int port = 80;
const char *topic = NULL;
/* Generate UUID */
net->websocket_key = realloc(net->websocket_key, 25u);
#if defined(WIN32) || defined(WIN64)
UUID uuid;
ZeroMemory( &uuid, sizeof(UUID) );
UuidCreate( &uuid );
Base64_encode( net->websocket_key, 25u, (const b64_data_t*)&uuid, sizeof(UUID) );
#else /* if defined(WIN32) || defined(WIN64) */
uuid_t uuid;
uuid_generate( uuid );
Base64_encode( net->websocket_key, 25u, uuid, sizeof(uuid_t) );
#endif /* else if defined(WIN32) || defined(WIN64) */
hostname_len = MQTTProtocol_addressPort(uri, &port, &topic);
/* if no topic, use default */
if ( !topic )
topic = "/mqtt";
for ( i = 0; i < 2; ++i )
{
buf_len = snprintf( buf, (size_t)buf_len,
"GET %s HTTP/1.1\r\n"
"Host: %.*s:%d\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Origin: http://%.*s:%d\r\n"
"Sec-WebSocket-Key: %s\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Protocol: mqtt\r\n"
"\r\n", topic,
(int)hostname_len, uri, port,
(int)hostname_len, uri, port,
net->websocket_key );
if ( i == 0 && buf_len > 0 )
{
++buf_len; /* need 1 extra byte for ending '\0' */
buf = malloc( buf_len );
}
}
if ( buf )
{
#if defined(OPENSSL)
if (net->ssl)
SSLSocket_putdatas(net->ssl, net->socket,
buf, buf_len, 0, NULL, NULL, NULL );
else
#endif
Socket_putdatas( net->socket, buf, buf_len,
0, NULL, NULL, NULL );
free( buf );
rc = 1;
}
else
{
free(net->websocket_key);
net->websocket_key = NULL;
rc = SOCKET_ERROR;
}
return rc;
}
/**
* closes a websocket connection
*
* @param[in,out] net structure containing network connection
* @param[in] status_code websocket close status code
* @param[in] reason reason for closing connection (optional)
*/
void WebSocket_close(networkHandles *net, int status_code, const char *reason)
{
if ( net->websocket )
{
char *buf0;
size_t buf0len = sizeof(uint16_t);
size_t header_len;
uint16_t status_code_big_endian;
if ( status_code < WebSocket_CLOSE_NORMAL ||
status_code > WebSocket_CLOSE_TLS_FAIL )
status_code = WebSocket_CLOSE_GOING_AWAY;
if ( reason )
buf0len += strlen(reason);
header_len = WebSocket_calculateFrameHeaderSize(net, buf0len);
buf0 = malloc(header_len + buf0len);
if ( !buf0 ) return;
/* encode status code */
status_code_big_endian = (uint16_t)htobe32((uint16_t)status_code);
memcpy( &buf0[header_len], &status_code_big_endian,
sizeof(uint16_t));
/* encode reason, if provided */
if ( reason )
strcpy( &buf0[header_len + 2], reason );
WebSocket_buildFrame( net, WebSocket_OP_CLOSE, 1,
&buf0[header_len], buf0len, 0, NULL, NULL );
#if defined(OPENSSL)
if (net->ssl)
SSLSocket_putdatas(net->ssl, net->socket,
buf0, buf0len, 0, NULL, NULL, NULL);
else
#endif
Socket_putdatas(net->socket, buf0, buf0len, 0,
NULL, NULL, NULL);
/* websocket connection is now closed */
net->websocket = 0;
free( buf0 );
}
if ( net->websocket_key )
free( net->websocket_key );
}
/**
* @brief receives 1 byte from a socket
*
* @param[in,out] net network connection
* @param[out] c byte that was read
*
* @retval SOCKET_ERROR on error
* @retval TCPSOCKET_INTERRUPTED no data available
* @retval TCPSOCKET_COMPLETE on success
*
* @see WebSocket_getdata
*/
int WebSocket_getch(networkHandles *net, char* c)
{
int rc = SOCKET_ERROR;
if ( net->websocket )
{
struct ws_frame *frame = NULL;
if ( in_frames && in_frames->first )
frame = in_frames->first->content;
if ( !frame )
{
size_t actual_len = 0u;
rc = WebSocket_receiveFrame( net, 1u, &actual_len );
if ( rc != TCPSOCKET_COMPLETE )
return rc;
/* we got a frame, let take off the top of queue */
if ( in_frames->first )
frame = in_frames->first->content;
}
/* set current working frame */
if (frame && frame->len > frame->pos)
{
unsigned char *buf =
(unsigned char *)frame + sizeof(struct ws_frame);
*c = buf[frame->pos++];
rc = TCPSOCKET_COMPLETE;
}
}
#if defined(OPENSSL)
else if ( net->ssl )
rc = SSLSocket_getch(net->ssl, net->socket, c);
#endif
else
rc = Socket_getch(net->socket, c);
return rc;
}
/**
* @brief receives data from a socket
*
* @param[in,out] net network connection
* @param[in] bytes amount of data to get (0 if last packet)
* @param[out] actual_len amount of data read
*
* @return a pointer to the read data
*
* @see WebSocket_getch
*/
char *WebSocket_getdata(networkHandles *net, size_t bytes, size_t* actual_len)
{
char *rv = NULL;
if ( net->websocket )
{
struct ws_frame *frame = NULL;
if ( bytes == 0u )
{
/* done with current frame, move it to last frame */
if ( in_frames && in_frames->first )
frame = in_frames->first->content;
/* return the data from the next frame, if we have one */
if ( frame )
{
rv = (char *)frame +
sizeof(struct ws_frame) + frame->pos;
*actual_len = frame->len - frame->pos;
if ( last_frame )
free( last_frame );
last_frame = ListDetachHead(in_frames);
}
return rv;
}
/* no current frame, let's see if there's one in the list */
if ( in_frames && in_frames->first )
frame = in_frames->first->content;
/* no current frame, so let's go receive one for the network */
if ( !frame )
{
const int rc =
WebSocket_receiveFrame( net, bytes, actual_len );
if ( rc == TCPSOCKET_COMPLETE && in_frames && in_frames->first)
frame = in_frames->first->content;
}
if ( frame )
{
rv = (char *)frame + sizeof(struct ws_frame) + frame->pos;
*actual_len = frame->len - frame->pos;
if ( *actual_len == bytes && in_frames)
{
/* set new frame as current frame */
if ( last_frame )
free( last_frame );
last_frame = ListDetachHead(in_frames);
}
}
}
else
rv = WebSocket_getRawSocketData(net, bytes, actual_len);
return rv;
}
/**
* reads raw socket data for underlying layers
*
* @param[in] net network connection
* @param[in] bytes number of bytes to read, 0 to complete packet
* @param[in] actual_len amount of data read
*
* @return a buffer containing raw data
*/
char *WebSocket_getRawSocketData(
networkHandles *net, size_t bytes, size_t* actual_len)
{
char *rv;
#if defined(OPENSSL)
if ( net->ssl )
rv = SSLSocket_getdata(net->ssl, net->socket, bytes, actual_len);
else
#endif
rv = Socket_getdata(net->socket, bytes, actual_len);
return rv;
}
/**
* sends a "websocket pong" message
*
* @param[in] net network connection
* @param[in] app_data application data to put in payload
* @param[in] app_data_len application data length
*/
void WebSocket_pong(networkHandles *net, char *app_data,
size_t app_data_len)
{
if ( net->websocket )
{
char *buf0;
size_t header_len;
int freeData = 0;
header_len = WebSocket_calculateFrameHeaderSize(net,
app_data_len);
buf0 = malloc(header_len);
if ( !buf0 ) return;
WebSocket_buildFrame( net, WebSocket_OP_PONG, 1,
&buf0[header_len], header_len, 1, &app_data,
&app_data_len );
Log(TRACE_PROTOCOL, 1, "Sending WebSocket PONG" );
#if defined(OPENSSL)
if (net->ssl)
SSLSocket_putdatas(net->ssl, net->socket, buf0,
header_len, 1, &app_data, &app_data_len,
&freeData);
else
#endif
Socket_putdatas(net->socket, buf0,
header_len + app_data_len, 1,
&app_data, &app_data_len, &freeData );
/* clean up memory */
free( buf0 );
}
}
/**
* writes data to a socket (websocket header will be prepended if required)
*
* @warning buf0 will be expanded (backwords before @p buf0 buffer, to add a
* websocket frame header to the data if required). So use
* @p WebSocket_calculateFrameHeader, to determine if extra space is needed
* before the @p buf0 pointer.
*
* @param[in,out] net network connection
* @param[in,out] buf0 first buffer
* @param[in] buf0len size of first buffer
* @param[in] count number of payload buffers
* @param[in,out] buffers array of paylaod buffers
* @param[in] buflens array of payload buffer sizes
* @param[in] freeData array indicating to free payload buffers
*
* @return amount of data wrote to socket
*
* @see WebSocket_calculateFrameHeaderSize
*/
int WebSocket_putdatas(networkHandles* net, char* buf0, size_t buf0len,
int count, char** buffers, size_t* buflens, int* freeData)
{
int rc;
/* prepend WebSocket frame */
if ( net->websocket )
{
rc = WebSocket_buildFrame(
net, WebSocket_OP_BINARY, 1, buf0, buf0len,
count, buffers, buflens );
/* header added so adjust buffer */
if ( rc > 0 )
{
buf0 -= rc;
buf0len += rc;
}
}
#if defined(OPENSSL)
if (net->ssl)
rc = SSLSocket_putdatas(net->ssl, net->socket, buf0, buf0len, count, buffers, buflens, freeData);
else
#endif
rc = Socket_putdatas(net->socket, buf0, buf0len, count, buffers, buflens, freeData);
return rc;
}
/**
* receives incoming socket data and parses websocket frames
*
* @param[in] net network connection
* @param[in] bytes amount of data to receive
* @param[out] actual_len amount of data actually read
*
* @retval TCPSOCKET_COMPLETE packet received
* @retval TCPSOCKET_INTERRUPTED incomplete packet received
* @retval SOCKET_ERROR an error was encountered
*/
int WebSocket_receiveFrame(networkHandles *net,
size_t bytes, size_t *actual_len )
{
struct ws_frame *res = NULL;
if ( !in_frames )
in_frames = ListInitialize();
/* see if there is frame acurrently on queue */
if ( in_frames->first )
res = in_frames->first->content;
while( !res )
{
int opcode = WebSocket_OP_BINARY;
do
{
/* obtain all frames in the sequence */
int final = 0;
while ( !final )
{
char *b;
size_t len = 0u;
int tmp_opcode;
int has_mask;
size_t cur_len = 0u;
uint8_t mask[4] = { 0u, 0u, 0u, 0u };
size_t payload_len;
b = WebSocket_getRawSocketData(net, 2u, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
/* 1st byte */
final = (b[0] & 0xFF) >> 7;
tmp_opcode = (b[0] & 0x0F);
if ( tmp_opcode ) /* not a continuation frame */
opcode = tmp_opcode;
/* invalid websocket packet must return error */
if ( opcode < WebSocket_OP_CONTINUE ||
opcode > WebSocket_OP_PONG ||
( opcode > WebSocket_OP_BINARY &&
opcode < WebSocket_OP_CLOSE ) )
return SOCKET_ERROR;
/* 2nd byte */
has_mask = (b[1] & 0xFF) >> 7;
payload_len = (b[1] & 0x7F);
/* determine payload length */
if ( payload_len == 126 )
{
b = WebSocket_getRawSocketData( net,
2u, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
/* convert from big endian 16 to host */
payload_len = be16toh(*(uint16_t*)b);
}
else if ( payload_len == 127 )
{
b = WebSocket_getRawSocketData( net,
8u, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
/* convert from big-endian 64 to host */
payload_len = (size_t)be64toh(*(uint64_t*)b);
}
if ( has_mask )
{
uint8_t mask[4];
b = WebSocket_getRawSocketData(net, 4u, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
memcpy( &mask[0], b, sizeof(uint32_t));
}
b = WebSocket_getRawSocketData(net,
payload_len, &len);
if ( !b || len == 0u )
return TCPSOCKET_INTERRUPTED;
/* unmask data */
if ( has_mask )
{
size_t i;
for ( i = 0u; i < payload_len; ++i )
b[i] ^= mask[i % 4];
}
if ( res )
cur_len = res->len;
res = realloc( res, sizeof(struct ws_frame) + cur_len + len );
memcpy( (unsigned char *)res + sizeof(struct ws_frame) + cur_len, b, len );
res->pos = 0u;
res->len = cur_len + len;
WebSocket_getRawSocketData(net, 0u, &len);
}
if ( opcode == WebSocket_OP_PONG || opcode == WebSocket_OP_PONG )
{
/* respond to a "ping" with a "pong" */
if ( opcode == WebSocket_OP_PING )
WebSocket_pong( net,
(char *)res + sizeof(struct ws_frame),
res->len );
/* discard message */
free( res );
res = NULL;
}
else if ( opcode == WebSocket_OP_CLOSE )
{
/* server end closed websocket connection */
free( res );
WebSocket_close( net, WebSocket_CLOSE_GOING_AWAY, NULL );
return SOCKET_ERROR; /* closes socket */
}
} while ( opcode == WebSocket_OP_PING || opcode == WebSocket_OP_PONG );
}
/* add new frame to end of list */
ListAppend( in_frames, res, sizeof(struct ws_frame) + res->len);
*actual_len = res->len - res->pos;
return TCPSOCKET_COMPLETE;
}
/**
* case-insensitive string search
*
* similar to @p strcase, but takes a maximum length
*
* @param[in] buf buffer to search
* @param[in] str string to find
* @param[in] len length of the buffer
*
* @retval !NULL location of string found
* @retval NULL string not found
*/
const char *WebSocket_strcasefind(const char *buf, const char *str, size_t len)
{
const char *res = NULL;
if ( buf && len > 0u && str )
{
const size_t str_len = strlen( str );
while ( len >= str_len && !res )
{
if ( strncasecmp( buf, str, str_len ) == 0 )
res = buf;
++buf;
--len;
}
}
return res;
}
/**
* releases resources used by the websocket sub-system
*/
void WebSocket_terminate( void )
{
/* clean up and un-processed websocket frames */
if ( in_frames )
{
struct ws_frame *f = ListDetachHead( in_frames );
while ( f )
{
free( f );
f = ListDetachHead( in_frames );
}
ListFree( in_frames );
}
if ( last_frame )
free( last_frame );
Socket_outTerminate();
#if defined(OPENSSL)
SSLSocket_terminate();
#endif
}
/**
* handles the websocket upgrade response
*
* @param[in,out] net network connection to upgrade
*
* @retval SOCKET_ERROR failed to upgrade network connection
* @retval 1 socket upgraded to use websockets
*
* @see WebSocket_connect
*/
int WebSocket_upgrade( networkHandles *net )
{
static const char *const ws_guid =
"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
int rc = SOCKET_ERROR;
if ( net->websocket_key )
{
SHA_CTX ctx;
char ws_key[62u] = { 0 };
unsigned char sha_hash[SHA1_DIGEST_LENGTH];
size_t rcv = 0u;
char *read_buf;
/* calculate the expected websocket key, expected from server */
snprintf( ws_key, sizeof(ws_key), "%s%s", net->websocket_key, ws_guid );
SHA1_Init( &ctx );
SHA1_Update( &ctx, ws_key, strlen(ws_key));
SHA1_Final( sha_hash, &ctx );
Base64_encode( ws_key, sizeof(ws_key), sha_hash, SHA1_DIGEST_LENGTH );
rc = TCPSOCKET_INTERRUPTED;
read_buf = WebSocket_getRawSocketData( net, 12u, &rcv );
if ( rcv > 0 && strncmp( read_buf, "HTTP/1.1 101", 11u ) == 0 )
{
const char *p;
read_buf = WebSocket_getRawSocketData( net, 500u, &rcv );
/* check for upgrade */
p = WebSocket_strcasefind(
read_buf, "Connection", rcv );
if ( p )
{
const char *eol;
eol = memchr( p, '\n', rcv-(read_buf-p) );
if ( eol )
p = WebSocket_strcasefind(
p, "Upgrade", eol - p);
else
p = NULL;
}
/* check key hash */
if ( p )
p = WebSocket_strcasefind( read_buf,
"sec-websocket-accept", rcv );
if ( p )
{
const char *eol;
eol = memchr( p, '\n', rcv-(read_buf-p) );
if ( eol )
{
p = memchr( p, ':', eol-p );
if ( p )
{
size_t hash_len = eol-p-1;
while ( *p == ':' || *p == ' ' )
{
++p;
--hash_len;
}
if ( strncmp( p, ws_key, hash_len ) != 0 )
p = NULL;
}
}
else
p = NULL;
}
if ( p )
{
net->websocket = 1;
Log(TRACE_PROTOCOL, 1, "WebSocket connection upgraded" );
rc = 1;
}
else
{
Log(TRACE_PROTOCOL, 1, "WebSocket failed to upgrade connection" );
rc = SOCKET_ERROR;
}
if ( net->websocket_key )
{
free(net->websocket_key);
net->websocket_key = NULL;
}
/* indicate that we done with the packet */
WebSocket_getRawSocketData( net, 0u, &rcv );
}
}
return rc;
}
/*******************************************************************************
* Copyright (c) 2018 Wind River Systems, Inc. All Rights Reserved.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Keith Holman - initial implementation and documentation
*******************************************************************************/
#if !defined(WEBSOCKET_H)
#define WEBSOCKET_H
#include "Clients.h"
/**
* WebSocket op codes
* @{
*/
#define WebSocket_OP_CONTINUE 0x0 /* 0000 - continue frame */
#define WebSocket_OP_TEXT 0x1 /* 0001 - text frame */
#define WebSocket_OP_BINARY 0x2 /* 0010 - binary frame */
#define WebSocket_OP_CLOSE 0x8 /* 1000 - close frame */
#define WebSocket_OP_PING 0x9 /* 1001 - ping frame */
#define WebSocket_OP_PONG 0xA /* 1010 - pong frame */
/** @} */
/**
* Various close status codes
* @{
*/
#define WebSocket_CLOSE_NORMAL 1000
#define WebSocket_CLOSE_GOING_AWAY 1001
#define WebSocket_CLOSE_PROTOCOL_ERROR 1002
#define WebSocket_CLOSE_UNKNOWN_DATA 1003
#define WebSocket_CLOSE_RESERVED 1004
#define WebSocket_CLOSE_NO_STATUS_CODE 1005 /* reserved: not to be used */
#define WebSocket_CLOSE_ABNORMAL 1006 /* reserved: not to be used */
#define WebSocket_CLOSE_BAD_DATA 1007
#define WebSocket_CLOSE_POLICY 1008
#define WebSocket_CLOSE_MSG_TOO_BIG 1009
#define WebSocket_CLOSE_NO_EXTENSION 1010
#define WebScoket_CLOSE_UNEXPECTED 1011
#define WebSocket_CLOSE_TLS_FAIL 1015 /* reserved: not be used */
/** @} */
/* closes a websocket connection */
void WebSocket_close(networkHandles *net, int status_code, const char *reason);
/* sends upgrade request */
int WebSocket_connect(networkHandles *net, const char *uri);
/* calculates the extra data required in a packet to hold a WebSocket frame header */
size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, size_t data_len);
/* obtain data from network socket */
int WebSocket_getch(networkHandles *net, char* c);
char *WebSocket_getdata(networkHandles *net, size_t bytes, size_t* actual_len);
/* send data out, in websocket format only if required */
int WebSocket_putdatas(networkHandles* net, char* buf0, size_t buf0len,
int count, char** buffers, size_t* buflens, int* freeData);
/* releases any resources used by the websocket system */
void WebSocket_terminate(void);
/* handles websocket upgrade request */
int WebSocket_upgrade(networkHandles *net);
#endif /* WEBSOCKET_H */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment