Skip to content
This repository was archived by the owner on Mar 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions localstack-core/localstack/services/sns/v2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from localstack.aws.api.sns import (
MessageAttributeMap,
PlatformApplication,
PublishBatchRequestEntry,
TopicAttributesMap,
subscriptionARN,
Expand Down Expand Up @@ -39,6 +38,12 @@ class Topic(TypedDict, total=True):
]


class EndpointAttributeNames(StrEnum):
CUSTOM_USER_DATA = "CustomUserData"
Token = "Token"
ENABLED = "Enabled"


SMS_ATTRIBUTE_NAMES = [
"DeliveryStatusIAMRole",
"DeliveryStatusSuccessSamplingRate",
Expand Down Expand Up @@ -143,6 +148,18 @@ def from_batch_entry(cls, entry: PublishBatchRequestEntry, is_fifo=False) -> "Sn
)


class PlatformEndpoint(TypedDict, total=False):
PlatformEndpointArn: str
Attributes: dict[str, str]
PlatformApplicationArn: str
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type exists in the Api as well, but without the PlatformApplicationArn. Having a reference back to the owning object makes delete easier so we don't have to iterate over all applications, however it might not be worth losing the advantage of sticking exactly to the spec. Wdyt?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good point! I believe @giograno is starting to establish a pattern here, where we use a dataclass to wrap the existing spec type, and add more data to it.

See https://github.com/localstack/localstack-pro/pull/5397

Something like:

from localstack.aws.api.sns import PlatformApplication

@dataclasses.dataclass
class PlatformApplicationDetails:
    platform_application: PlatformApplication
    platform_application_arn: str

I hope it's okay, I'll go ahead and push those to get the PR in this week 👍



class SnsPlatformApplication(TypedDict, total=False):
PlatformApplicationArn: str
Attributes: dict[str, str]
PlatformEndpoints: dict[str, str]


class SnsStore(BaseStore):
topics: dict[str, Topic] = LocalAttribute(default=dict)

Expand All @@ -156,7 +173,10 @@ class SnsStore(BaseStore):
subscription_tokens: dict[str, str] = LocalAttribute(default=dict)

# maps platform application arns to platform applications
platform_applications: dict[str, PlatformApplication] = LocalAttribute(default=dict)
platform_applications: dict[str, SnsPlatformApplication] = LocalAttribute(default=dict)

# maps endpoint arns to platform endpoints
platform_endpoints: dict[str, PlatformEndpoint] = LocalAttribute(default=dict)

# topic/subscription independent default values for sending sms messages
sms_attributes: dict[str, str] = LocalAttribute(default=dict)
Expand Down
139 changes: 129 additions & 10 deletions localstack-core/localstack/services/sns/v2/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from localstack.aws.api.sns import (
AmazonResourceName,
ConfirmSubscriptionResponse,
CreateEndpointResponse,
CreatePlatformApplicationResponse,
CreateTopicResponse,
GetEndpointAttributesResponse,
GetPlatformApplicationAttributesResponse,
GetSMSAttributesResponse,
GetSubscriptionAttributesResponse,
Expand All @@ -26,7 +28,6 @@
ListTopicsResponse,
MapStringToString,
NotFoundException,
PlatformApplication,
SetSMSAttributesResponse,
SnsApi,
String,
Expand Down Expand Up @@ -60,14 +61,18 @@
SMS_ATTRIBUTE_NAMES,
SMS_DEFAULT_SENDER_REGEX,
SMS_TYPES,
EndpointAttributeNames,
PlatformEndpoint,
SnsMessage,
SnsMessageType,
SnsPlatformApplication,
SnsStore,
SnsSubscription,
Topic,
sns_stores,
)
from localstack.services.sns.v2.utils import (
create_platform_endpoint_arn,
create_subscription_arn,
encode_subscription_token_with_region,
get_next_page_token_from_arn,
Expand Down Expand Up @@ -237,10 +242,11 @@ def subscribe(
raise InvalidParameterException("Invalid parameter: SQS endpoint ARN")

elif protocol == "application":
# TODO: This needs to be implemented once applications are ported from moto to the new provider
raise NotImplementedError(
"This functionality needs yet to be ported to the new SNS provider"
)
# TODO: Validate exact behaviour
try:
parse_arn(endpoint)
except InvalidArnException:
raise InvalidParameterException("Invalid parameter: ApplicationEndpoint ARN")

if ".fifo" in endpoint and ".fifo" not in topic_arn:
# TODO: move to sqs protocol block if possible
Expand Down Expand Up @@ -591,8 +597,10 @@ def create_platform_application(
account_id=context.account_id,
region_name=context.region,
)
platform_application = PlatformApplication(
PlatformApplicationArn=application_arn, Attributes=_attributes
platform_application = SnsPlatformApplication(
PlatformApplicationArn=application_arn,
Attributes=_attributes,
PlatformEndpoints={},
)
store.platform_applications[application_arn] = platform_application
return CreatePlatformApplicationResponse(**platform_application)
Expand All @@ -602,6 +610,8 @@ def delete_platform_application(
) -> None:
store = self.get_store(context.account_id, context.region)
store.platform_applications.pop(platform_application_arn, None)
# TODO: if the platform had endpoints, should we remove them from the store? There is no way to list
# endpoints without an application, so this is impossible to check the state of AWS here

def list_platform_applications(
self, context: RequestContext, next_token: String | None = None, **kwargs
Expand Down Expand Up @@ -644,15 +654,106 @@ def set_platform_application_attributes(
# Platform Endpoints
#

def create_platform_endpoint(
self,
context: RequestContext,
platform_application_arn: String,
token: String,
custom_user_data: String | None = None,
attributes: MapStringToString | None = None,
**kwargs,
) -> CreateEndpointResponse:
store = self.get_store(context.account_id, context.region)
application = store.platform_applications.get(platform_application_arn)
if not application:
raise NotFoundException("PlatformApplication does not exist")
endpoint_arn = application["PlatformEndpoints"].get(token, {})
attributes = attributes or {}
_validate_endpoint_attributes(attributes, allow_empty=True)
# CustomUserData can be specified both in attributes and as parameter. Attributes take precedence
attributes.setdefault(EndpointAttributeNames.CUSTOM_USER_DATA, custom_user_data)
_attributes = {"Enabled": "true", "Token": token, **attributes}
if endpoint_arn and (endpoint := store.platform_endpoints.get(endpoint_arn)):
# endpoint for that application with that particular token already exists
if not endpoint["Attributes"] == _attributes:
raise InvalidParameterException(
f"Invalid parameter: Token Reason: Endpoint {endpoint_arn} already exists with the same Token, but different attributes."
)
else:
return CreateEndpointResponse(EndpointArn=endpoint_arn)

endpoint_arn = create_platform_endpoint_arn(platform_application_arn)
endpoint = PlatformEndpoint(
Attributes=_attributes,
PlatformEndpointArn=endpoint_arn,
PlatformApplicationArn=platform_application_arn,
)
store.platform_endpoints[endpoint_arn] = endpoint
application["PlatformEndpoints"][token] = endpoint_arn

return CreateEndpointResponse(EndpointArn=endpoint_arn)

def delete_endpoint(self, context: RequestContext, endpoint_arn: String, **kwargs) -> None:
store = self.get_store(context.account_id, context.region)
endpoint = store.platform_endpoints.pop(endpoint_arn, None)
if endpoint:
platform_application = store.platform_applications.get(
endpoint["PlatformApplicationArn"]
)
if platform_application:
platform_application["PlatformEndpoints"].pop(endpoint["Attributes"]["Token"], None)

def list_endpoints_by_platform_application(
self,
context: RequestContext,
platform_application_arn: String,
next_token: String | None = None,
**kwargs,
) -> ListEndpointsByPlatformApplicationResponse:
# TODO: stub so cleanup fixture won't fail
return ListEndpointsByPlatformApplicationResponse(Endpoints=[])
store = self.get_store(context.account_id, context.region)
platform_application = store.platform_applications.get(platform_application_arn, {})
if not platform_application:
raise NotFoundException("PlatformApplication does not exist")
endpoint_arns = platform_application.get("PlatformEndpoints").values()
paginated_endpoint_arns = PaginatedList(endpoint_arns)
page, token = paginated_endpoint_arns.get_page(
token_generator=lambda x: get_next_page_token_from_arn(x),
page_size=100,
next_token=next_token,
)
response = ListEndpointsByPlatformApplicationResponse(
Endpoints=[
{
"EndpointArn": endpoint_arn,
"Attributes": store.platform_endpoints.get(endpoint_arn, {}).get("Attributes"),
}
for endpoint_arn in page
Comment on lines +740 to +742
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is slightly more complex than it needs to be due to the mentioned deviation of the type. A good example of the trade off I mentioned before

]
)
if token:
response["NextToken"] = token
return response

def get_endpoint_attributes(
self, context: RequestContext, endpoint_arn: String, **kwargs
) -> GetEndpointAttributesResponse:
store = self.get_store(context.account_id, context.region)
endpoint = store.platform_endpoints.get(endpoint_arn)
if not endpoint:
raise NotFoundException("Endpoint does not exist")
attributes = endpoint["Attributes"]
return GetEndpointAttributesResponse(Attributes=attributes)

def set_endpoint_attributes(
self, context: RequestContext, endpoint_arn: String, attributes: MapStringToString, **kwargs
) -> None:
store = self.get_store(context.account_id, context.region)
endpoint = store.platform_endpoints.get(endpoint_arn)
if not endpoint:
raise NotFoundException("Endpoint does not exist")
_validate_endpoint_attributes(attributes)
attributes = attributes or {}
endpoint["Attributes"].update(attributes)

#
# Sms operations
Expand Down Expand Up @@ -732,7 +833,7 @@ def _get_topic(arn: str, context: RequestContext) -> Topic:
@staticmethod
def _get_platform_application(
platform_application_arn: str, context: RequestContext
) -> PlatformApplication:
) -> SnsPlatformApplication:
parse_and_validate_platform_application_arn(platform_application_arn)
try:
store = SnsProvider.get_store(context.account_id, context.region)
Expand Down Expand Up @@ -821,6 +922,10 @@ def _validate_platform_application_name(name: str) -> None:


def _validate_platform_application_attributes(attributes: dict) -> None:
_check_empty_attributes(attributes)


def _check_empty_attributes(attributes: dict) -> None:
if not attributes:
raise CommonServiceException(
code="ValidationError",
Expand All @@ -829,6 +934,20 @@ def _validate_platform_application_attributes(attributes: dict) -> None:
)


def _validate_endpoint_attributes(attributes: dict, allow_empty: bool = False) -> None:
if not allow_empty:
_check_empty_attributes(attributes)
for key in attributes:
if key not in EndpointAttributeNames:
raise InvalidParameterException(
f"Invalid parameter: Attributes Reason: Invalid attribute name: {key}"
)
if len(attributes.get(EndpointAttributeNames.CUSTOM_USER_DATA, "")) > 2048:
raise InvalidParameterException(
"Invalid parameter: Attributes Reason: Invalid value for attribute: CustomUserData: must be at most 2048 bytes long in UTF-8 encoding"
)


def _validate_sms_attributes(attributes: dict) -> None:
for k, v in attributes.items():
if k not in SMS_ATTRIBUTE_NAMES:
Expand Down
8 changes: 8 additions & 0 deletions localstack-core/localstack/services/sns/v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def create_subscription_arn(topic_arn: str) -> str:
return f"{topic_arn}:{uuid4()}"


def create_platform_endpoint_arn(
platform_application_arn: str,
) -> str:
# This is the format of an Endpoint Arn
# arn:aws:sns:us-west-2:1234567890:endpoint/GCM/MyApplication/12345678-abcd-9012-efgh-345678901234
return f"{platform_application_arn.replace('app', 'endpoint', 1)}/{uuid4()}"


def encode_subscription_token_with_region(region: str) -> str:
"""
Create a 64 characters Subscription Token with the region encoded
Expand Down
Loading
Loading