Commit 73cc3ea9 authored by Keith Holman's avatar Keith Holman

websocket: fixes to work with paho test suite

This patch contains some fixes to work with the paho test suite.
Signed-off-by: 's avatarKeith Holman <keith.holman@windriver.com>
parent 8c988485
...@@ -473,9 +473,10 @@ int MQTTAsync_createWithOptions(MQTTAsync* handle, const char* serverURI, const ...@@ -473,9 +473,10 @@ int MQTTAsync_createWithOptions(MQTTAsync* handle, const char* serverURI, const
if (strstr(serverURI, "://") != NULL) if (strstr(serverURI, "://") != NULL)
{ {
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) != 0 if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) != 0
&& strncmp(URI_WS, serverURI, strlen(URI_WS)) != 0
#if defined(OPENSSL) #if defined(OPENSSL)
&& strncmp(URI_SSL, serverURI, strlen(URI_SSL)) != 0 && strncmp(URI_SSL, serverURI, strlen(URI_SSL)) != 0
&& strncmp(URI_WSS, serverURI, strlen(URI_WSS)) != 0
#endif #endif
) )
{ {
...@@ -1247,12 +1248,23 @@ static int MQTTAsync_processCommand(void) ...@@ -1247,12 +1248,23 @@ static int MQTTAsync_processCommand(void)
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0) if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) == 0)
serverURI += strlen(URI_TCP); serverURI += strlen(URI_TCP);
else if (strncmp(URI_WS, serverURI, strlen(URI_WS)) == 0)
{
serverURI += strlen(URI_WS);
command->client->websocket = 1;
}
#if defined(OPENSSL) #if defined(OPENSSL)
else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0) else if (strncmp(URI_SSL, serverURI, strlen(URI_SSL)) == 0)
{ {
serverURI += strlen(URI_SSL); serverURI += strlen(URI_SSL);
command->client->ssl = 1; command->client->ssl = 1;
} }
else if (strncmp(URI_WSS, serverURI, strlen(URI_WSS)) == 0)
{
serverURI += strlen(URI_WSS);
command->client->ssl = 1;
command->client->websocket = 1;
}
#endif #endif
} }
...@@ -2081,7 +2093,7 @@ int MQTTAsync_setCallbacks(MQTTAsync handle, void* context, ...@@ -2081,7 +2093,7 @@ int MQTTAsync_setCallbacks(MQTTAsync handle, void* context,
FUNC_ENTRY; FUNC_ENTRY;
MQTTAsync_lock_mutex(mqttasync_mutex); MQTTAsync_lock_mutex(mqttasync_mutex);
if (m == NULL || ma == NULL || m->c->connect_state != NOT_IN_PROGRESS) if (m == NULL || ma == NULL || m->c == NULL || m->c->connect_state != NOT_IN_PROGRESS)
rc = MQTTASYNC_FAILURE; rc = MQTTASYNC_FAILURE;
else else
{ {
......
...@@ -327,9 +327,10 @@ int MQTTClient_create(MQTTClient* handle, const char* serverURI, const char* cli ...@@ -327,9 +327,10 @@ int MQTTClient_create(MQTTClient* handle, const char* serverURI, const char* cli
if (strstr(serverURI, "://") != NULL) if (strstr(serverURI, "://") != NULL)
{ {
if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) != 0 if (strncmp(URI_TCP, serverURI, strlen(URI_TCP)) != 0
&& strncmp(URI_WS, serverURI, strlen(URI_WS)) != 0
#if defined(OPENSSL) #if defined(OPENSSL)
&& strncmp(URI_SSL, serverURI, strlen(URI_SSL)) != 0 && strncmp(URI_SSL, serverURI, strlen(URI_SSL)) != 0
&& strncmp(URI_WSS, serverURI, strlen(URI_WSS)) != 0
#endif #endif
) )
{ {
...@@ -1699,11 +1700,16 @@ int MQTTClient_publish(MQTTClient handle, const char* topicName, int payloadlen, ...@@ -1699,11 +1700,16 @@ int MQTTClient_publish(MQTTClient handle, const char* topicName, int payloadlen,
goto exit; goto exit;
} }
p = malloc(sizeof(Publish)); p = malloc(sizeof(Publish) + payloadlen);
p->payload = payload; p->payload = payload;
p->payloadlen = payloadlen; p->payloadlen = payloadlen;
p->topic = (char*)topicName; if (payloadlen > 0)
{
p->payload = (char*)p + sizeof(Publish);
memcpy(p->payload, payload, payloadlen);
p->payloadlen = payloadlen;
}
p->topic = MQTTStrdup(topicName);
p->msgId = msgid; p->msgId = msgid;
rc = MQTTProtocol_startPublish(m->c, p, qos, retained, &msg); rc = MQTTProtocol_startPublish(m->c, p, qos, retained, &msg);
...@@ -1727,6 +1733,7 @@ int MQTTClient_publish(MQTTClient handle, const char* topicName, int payloadlen, ...@@ -1727,6 +1733,7 @@ int MQTTClient_publish(MQTTClient handle, const char* topicName, int payloadlen,
if (deliveryToken && qos > 0) if (deliveryToken && qos > 0)
*deliveryToken = msg->msgid; *deliveryToken = msg->msgid;
if (p->topic) free(p->topic);
free(p); free(p);
if (rc == SOCKET_ERROR) if (rc == SOCKET_ERROR)
...@@ -1830,6 +1837,7 @@ static MQTTPacket* MQTTClient_cycle(int* sock, unsigned long timeout, int* rc) ...@@ -1830,6 +1837,7 @@ static MQTTPacket* MQTTClient_cycle(int* sock, unsigned long timeout, int* rc)
*rc = 0; *rc = 0;
} }
} }
if (pack) if (pack)
{ {
int freed = 1; int freed = 1;
...@@ -1937,12 +1945,15 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r ...@@ -1937,12 +1945,15 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
} }
} }
#endif #endif
else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS || else if (m->c->connect_state == WEBSOCKET_IN_PROGRESS )
m->c->connect_state == WAIT_FOR_CONNACK) {
*rc = 1;
break;
}
else if (m->c->connect_state == WAIT_FOR_CONNACK)
{ {
int error; int error;
socklen_t len = sizeof(error); socklen_t len = sizeof(error);
if (getsockopt(m->c->net.socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len) == 0) if (getsockopt(m->c->net.socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len) == 0)
{ {
if (error) if (error)
......
...@@ -174,7 +174,7 @@ int MQTTPacket_send(networkHandles* net, Header header, char* buffer, size_t buf ...@@ -174,7 +174,7 @@ int MQTTPacket_send(networkHandles* net, Header header, char* buffer, size_t buf
char *buf; char *buf;
FUNC_ENTRY; FUNC_ENTRY;
ws_header = WebSocket_calculateFrameHeaderSize(net, buflen + 10); ws_header = WebSocket_calculateFrameHeaderSize(net, 1, buflen + 10);
buf = malloc(10 + ws_header); buf = malloc(10 + ws_header);
if ( !buf ) return -1; if ( !buf ) return -1;
...@@ -224,7 +224,7 @@ int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffe ...@@ -224,7 +224,7 @@ int MQTTPacket_sends(networkHandles* net, Header header, int count, char** buffe
for (i = 0; i < count; i++) for (i = 0; i < count; i++)
total += buflens[i]; total += buflens[i];
ws_header = WebSocket_calculateFrameHeaderSize(net, total + 10); ws_header = WebSocket_calculateFrameHeaderSize(net, 1, total + 10);
buf = malloc(10 + ws_header); buf = malloc(10 + ws_header);
if ( !buf ) return -1; if ( !buf ) return -1;
......
...@@ -219,13 +219,14 @@ Publications* MQTTProtocol_storePublication(Publish* publish, int* len) ...@@ -219,13 +219,14 @@ Publications* MQTTProtocol_storePublication(Publish* publish, int* len)
p->refcount = 1; p->refcount = 1;
*len = (int)strlen(publish->topic)+1; *len = (int)strlen(publish->topic)+1;
p->topic = malloc(*len);
strcpy(p->topic, publish->topic);
if (Heap_findItem(publish->topic)) if (Heap_findItem(publish->topic))
p->topic = publish->topic;
else
{ {
p->topic = malloc(*len); free(publish->topic);
strcpy(p->topic, publish->topic); publish->topic = NULL;
} }
*len += sizeof(Publications); *len += sizeof(Publications);
p->topiclen = publish->topiclen; p->topiclen = publish->topiclen;
......
...@@ -180,7 +180,7 @@ static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data, ...@@ -180,7 +180,7 @@ static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data,
for (i = 0; i < count; ++i) for (i = 0; i < count; ++i)
data_len += buflens[i]; data_len += buflens[i];
buf0 -= WebSocket_calculateFrameHeaderSize(net, data_len); buf0 -= WebSocket_calculateFrameHeaderSize(net, mask_data, data_len);
if ( net->websocket ) if ( net->websocket )
{ {
uint8_t mask[4]; uint8_t mask[4];
...@@ -197,11 +197,11 @@ static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data, ...@@ -197,11 +197,11 @@ static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data,
/* 1st byte */ /* 1st byte */
buf0[buf_len] = (char)(1 << 7); /* final flag */ buf0[buf_len] = (char)(1 << 7); /* final flag */
/* 3 bits reserved for negotiation of protocol */ /* 3 bits reserved for negotiation of protocol */
buf0[buf_len] |= (opcode & 0x0F); /* op code */ buf0[buf_len] |= (char)(opcode & 0x0F); /* op code */
++buf_len; ++buf_len;
/* 2nd byte */ /* 2nd byte */
buf0[buf_len] = (mask_data & 0x1) << 7; /* masking bit */ buf0[buf_len] = (char)((mask_data & 0x1) << 7); /* masking bit */
/* payload length */ /* payload length */
if ( data_len < 126u ) if ( data_len < 126u )
...@@ -247,7 +247,7 @@ static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data, ...@@ -247,7 +247,7 @@ static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data,
for (i = 0; i < count; ++i) for (i = 0; i < count; ++i)
{ {
size_t j; size_t j;
for ( j = 0; j < buflens[i]; ++j, ++idx ) for ( j = 0u; j < buflens[i]; ++j, ++idx )
buffers[i][j] ^= mask[idx % 4]; buffers[i][j] ^= mask[idx % 4];
} }
} }
...@@ -263,13 +263,15 @@ static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data, ...@@ -263,13 +263,15 @@ static int WebSocket_buildFrame(networkHandles* net, int opcode, int mask_data,
* buffer * buffer
* *
* @param[in,out] net network connection * @param[in,out] net network connection
* @param[in] mask_data whether to mask the data
* @param[in] data_len amount of data in the payload * @param[in] data_len amount of data in the payload
* *
* @return the size in bytes of the websocket header required * @return the size in bytes of the websocket header required
* *
* @see WebSocket_putdatas * @see WebSocket_putdatas
*/ */
size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, size_t data_len) size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, int mask_data,
size_t data_len)
{ {
int ret = 0; int ret = 0;
if ( net && net->websocket ) if ( net && net->websocket )
...@@ -279,7 +281,8 @@ size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, size_t data_len) ...@@ -279,7 +281,8 @@ size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, size_t data_len)
ret += 2; /* for extra 2-bytes for payload length */ ret += 2; /* for extra 2-bytes for payload length */
else if ( data_len > 65536u ) else if ( data_len > 65536u )
ret += 8; /* for extra 8-bytes for payload length */ ret += 8; /* for extra 8-bytes for payload length */
ret += sizeof(uint32_t); /* for mask - all outgoing data is masked */ if ( mask_data & 0x1 )
ret += sizeof(uint32_t); /* for mask */
} }
return ret; return ret;
} }
...@@ -382,7 +385,8 @@ void WebSocket_close(networkHandles *net, int status_code, const char *reason) ...@@ -382,7 +385,8 @@ void WebSocket_close(networkHandles *net, int status_code, const char *reason)
char *buf0; char *buf0;
size_t buf0len = sizeof(uint16_t); size_t buf0len = sizeof(uint16_t);
size_t header_len; size_t header_len;
uint16_t status_code_big_endian; uint16_t status_code_be;
const int mask_data = 0;
if ( status_code < WebSocket_CLOSE_NORMAL || if ( status_code < WebSocket_CLOSE_NORMAL ||
status_code > WebSocket_CLOSE_TLS_FAIL ) status_code > WebSocket_CLOSE_TLS_FAIL )
...@@ -391,22 +395,23 @@ void WebSocket_close(networkHandles *net, int status_code, const char *reason) ...@@ -391,22 +395,23 @@ void WebSocket_close(networkHandles *net, int status_code, const char *reason)
if ( reason ) if ( reason )
buf0len += strlen(reason); buf0len += strlen(reason);
header_len = WebSocket_calculateFrameHeaderSize(net, buf0len); header_len = WebSocket_calculateFrameHeaderSize(net,
mask_data, buf0len);
buf0 = malloc(header_len + buf0len); buf0 = malloc(header_len + buf0len);
if ( !buf0 ) return; if ( !buf0 ) return;
/* encode status code */ /* encode status code */
status_code_big_endian = (uint16_t)htobe32((uint16_t)status_code); status_code_be = htobe16((uint16_t)status_code);
memcpy( &buf0[header_len], &status_code_big_endian, memcpy( &buf0[header_len], &status_code_be, sizeof(uint16_t));
sizeof(uint16_t));
/* encode reason, if provided */ /* encode reason, if provided */
if ( reason ) if ( reason )
strcpy( &buf0[header_len + 2], reason ); strcpy( &buf0[header_len + sizeof(uint16_t)], reason );
WebSocket_buildFrame( net, WebSocket_OP_CLOSE, 1, WebSocket_buildFrame( net, WebSocket_OP_CLOSE, mask_data,
&buf0[header_len], buf0len, 0, NULL, NULL ); &buf0[header_len], buf0len, 0, NULL, NULL );
buf0len += header_len;
#if defined(OPENSSL) #if defined(OPENSSL)
if (net->ssl) if (net->ssl)
SSLSocket_putdatas(net->ssl, net->socket, SSLSocket_putdatas(net->ssl, net->socket,
...@@ -584,14 +589,15 @@ void WebSocket_pong(networkHandles *net, char *app_data, ...@@ -584,14 +589,15 @@ void WebSocket_pong(networkHandles *net, char *app_data,
char *buf0; char *buf0;
size_t header_len; size_t header_len;
int freeData = 0; int freeData = 0;
const int mask_data = 0;
header_len = WebSocket_calculateFrameHeaderSize(net, header_len = WebSocket_calculateFrameHeaderSize(net, mask_data,
app_data_len); app_data_len);
buf0 = malloc(header_len); buf0 = malloc(header_len);
if ( !buf0 ) return; if ( !buf0 ) return;
WebSocket_buildFrame( net, WebSocket_OP_PONG, 1, WebSocket_buildFrame( net, WebSocket_OP_PONG, 1,
&buf0[header_len], header_len, 1, &app_data, &buf0[header_len], header_len, mask_data, &app_data,
&app_data_len ); &app_data_len );
Log(TRACE_PROTOCOL, 1, "Sending WebSocket PONG" ); Log(TRACE_PROTOCOL, 1, "Sending WebSocket PONG" );
...@@ -599,8 +605,8 @@ void WebSocket_pong(networkHandles *net, char *app_data, ...@@ -599,8 +605,8 @@ void WebSocket_pong(networkHandles *net, char *app_data,
#if defined(OPENSSL) #if defined(OPENSSL)
if (net->ssl) if (net->ssl)
SSLSocket_putdatas(net->ssl, net->socket, buf0, SSLSocket_putdatas(net->ssl, net->socket, buf0,
header_len, 1, &app_data, &app_data_len, header_len + app_data_len, 1,
&freeData); &app_data, &app_data_len, &freeData);
else else
#endif #endif
Socket_putdatas(net->socket, buf0, Socket_putdatas(net->socket, buf0,
...@@ -640,15 +646,24 @@ int WebSocket_putdatas(networkHandles* net, char* buf0, size_t buf0len, ...@@ -640,15 +646,24 @@ int WebSocket_putdatas(networkHandles* net, char* buf0, size_t buf0len,
/* prepend WebSocket frame */ /* prepend WebSocket frame */
if ( net->websocket ) if ( net->websocket )
{ {
size_t data_len = buf0len + 4u;
size_t header_len;
const int mask_data = 1;
for (rc = 0; rc < count; ++rc)
data_len += buflens[rc];
header_len = WebSocket_calculateFrameHeaderSize(
net, mask_data, data_len);
rc = WebSocket_buildFrame( rc = WebSocket_buildFrame(
net, WebSocket_OP_BINARY, 1, buf0, buf0len, net, WebSocket_OP_BINARY, mask_data, buf0, buf0len,
count, buffers, buflens ); count, buffers, buflens );
/* header added so adjust buffer */ /* header added so adjust buffer */
if ( rc > 0 ) if ( rc > 0 )
{ {
buf0 -= rc; buf0 -= header_len;
buf0len += rc; buf0len += header_len;
} }
} }
...@@ -849,9 +864,13 @@ void WebSocket_terminate( void ) ...@@ -849,9 +864,13 @@ void WebSocket_terminate( void )
f = ListDetachHead( in_frames ); f = ListDetachHead( in_frames );
} }
ListFree( in_frames ); ListFree( in_frames );
in_frames = NULL;
} }
if ( last_frame ) if ( last_frame )
{
free( last_frame ); free( last_frame );
last_frame = NULL;
}
Socket_outTerminate(); Socket_outTerminate();
#if defined(OPENSSL) #if defined(OPENSSL)
SSLSocket_terminate(); SSLSocket_terminate();
......
...@@ -57,7 +57,8 @@ void WebSocket_close(networkHandles *net, int status_code, const char *reason); ...@@ -57,7 +57,8 @@ void WebSocket_close(networkHandles *net, int status_code, const char *reason);
int WebSocket_connect(networkHandles *net, const char *uri); int WebSocket_connect(networkHandles *net, const char *uri);
/* calculates the extra data required in a packet to hold a WebSocket frame header */ /* calculates the extra data required in a packet to hold a WebSocket frame header */
size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, size_t data_len); size_t WebSocket_calculateFrameHeaderSize(networkHandles *net, int mask_data,
size_t data_len);
/* obtain data from network socket */ /* obtain data from network socket */
int WebSocket_getch(networkHandles *net, char* c); int WebSocket_getch(networkHandles *net, char* c);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment