diff --git a/src/sns-producer.ts b/src/sns-producer.ts index fc0775b..d1e9769 100644 --- a/src/sns-producer.ts +++ b/src/sns-producer.ts @@ -1,5 +1,6 @@ import * as aws from 'aws-sdk'; import { PromiseResult } from 'aws-sdk/lib/request'; +import { MessageAttributeMap } from 'aws-sdk/clients/sns'; import { v4 as uuid } from 'uuid'; import { S3PayloadMeta } from './types'; import { @@ -81,7 +82,7 @@ export class SnsProducer { return new SnsProducer(options); } - async publishJSON(message: unknown): Promise { + async publishJSON(message: unknown, snsMessageAttributes?: MessageAttributeMap): Promise { const messageBody = JSON.stringify(message); const msgSize = Buffer.byteLength(messageBody, 'utf-8'); @@ -104,7 +105,8 @@ export class SnsProducer { Key: s3Response.Key, Location: s3Response.Location, }, - msgSize + msgSize, + snsMessageAttributes ); return { @@ -121,6 +123,7 @@ export class SnsProducer { .publish({ Message: messageBody, TopicArn: this.topicArn, + MessageAttributes: snsMessageAttributes || {}, }) .promise(); @@ -131,11 +134,17 @@ export class SnsProducer { async publishS3Payload( s3PayloadMeta: S3PayloadMeta, - msgSize?: number + msgSize?: number, + snsMessageAttributes?: MessageAttributeMap ): Promise> { - const messageAttributes = this.extendedLibraryCompatibility - ? createExtendedCompatibilityAttributeMap(msgSize) - : {}; + const messageAttributes = { + ...(snsMessageAttributes || {}), + ...(this.extendedLibraryCompatibility + ? createExtendedCompatibilityAttributeMap(msgSize) + : {} + ) + }; + return await this.sns .publish({ Message: this.extendedLibraryCompatibility diff --git a/tests/sns-sqs.spec.ts b/tests/sns-sqs.spec.ts index 766cb6e..c9bc71f 100644 --- a/tests/sns-sqs.spec.ts +++ b/tests/sns-sqs.spec.ts @@ -10,6 +10,7 @@ import { } from '../src'; import * as aws from 'aws-sdk'; +import { MessageAttributeMap } from 'aws-sdk/clients/sns'; import { v4 as uuid } from 'uuid'; import { S3PayloadMeta } from '../src/types'; @@ -138,9 +139,9 @@ const getSnsProducer = (options: Partial = {}) => { }); }; -async function publishMessage(msg: any, options?: Partial) { +async function publishMessage(msg: any, options?: Partial, attributes?: MessageAttributeMap) { const snsProducer = getSnsProducer(options); - await snsProducer.publishJSON(msg); + await snsProducer.publishJSON(msg, attributes); } async function publishS3Payload(s3PayloadMeta: S3PayloadMeta, options?: Partial) { @@ -152,7 +153,7 @@ async function receiveMessages( expectedMsgsCount: number, options: Partial = {}, eventHandlers?: Record void> -): Promise { +): Promise { const { s3 } = getClients(); return new Promise((resolve, rej) => { const messages: SqsMessage[] = []; @@ -603,6 +604,24 @@ describe('sns-sqs-big-payload', () => { }); expect(receivedMessage.payload).toEqual(message); }); + + it('should publish and receive the message with SNS message attributes', async () => { + const message = { it: 'works' }; + const attributes = { + testAttribute: { + DataType: 'String', + StringValue: 'AttrubuteValue', + } + }; + await publishMessage(message, {}, attributes); + const [receivedMessage] = await receiveMessages(1, { + transformMessageBody: (body) => { + const snsMessage = JSON.parse(body); + return snsMessage.Message; + }, + }); + expect(receivedMessage.payload).toEqual(message); + }); }); describe('publishing message through s3', () => { @@ -643,6 +662,26 @@ describe('sns-sqs-big-payload', () => { expect(reReceivedMessage.payload).toEqual(message); expect(reReceivedMessage.s3PayloadMeta).toEqual(receivedMessage.s3PayloadMeta); }); + + it('should send payload though s3 with SNS message attributes', async () => { + const message = { it: 'works' }; + const attributes = { + testAttribute: { + DataType: 'String', + StringValue: 'AttrubuteValue', + } + }; + await publishMessage(message, { allPayloadThoughS3: true, s3Bucket: TEST_BUCKET_NAME }, attributes); + const [receivedMessage] = await receiveMessages(1, { + getPayloadFromS3: true, + // since it's SNS message we need to unwrap sns envelope first + transformMessageBody: (body) => { + const snsMessage = JSON.parse(body); + return snsMessage.Message; + }, + }); + expect(receivedMessage.payload).toEqual(message); + }); }); }); });