Skip to content
This repository was archived by the owner on Mar 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion localstack-core/localstack/services/sns/v2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Literal, TypedDict

from localstack.aws.api.sns import (
Endpoint,
MessageAttributeMap,
PlatformApplication,
PublishBatchRequestEntry,
Expand Down Expand Up @@ -39,6 +40,12 @@ class Topic(TypedDict, total=True):
]


class EndpointAttributeNames(StrEnum):
CUSTOM_USER_DATA = "CustomUserData"
Token = "Token"
ENABLED = "Enabled"


SMS_ATTRIBUTE_NAMES = [
"DeliveryStatusIAMRole",
"DeliveryStatusSuccessSamplingRate",
Expand Down Expand Up @@ -143,6 +150,19 @@ def from_batch_entry(cls, entry: PublishBatchRequestEntry, is_fifo=False) -> "Sn
)


@dataclass
class PlatformEndpoint:
platform_application_arn: str
platform_endpoint: Endpoint


@dataclass
class PlatformApplicationDetails:
platform_application: PlatformApplication
# maps all Endpoints of the PlatformApplication, from their Token to their ARN
platform_endpoints: dict[str, str]


class SnsStore(BaseStore):
topics: dict[str, Topic] = LocalAttribute(default=dict)

Expand All @@ -156,7 +176,10 @@ class SnsStore(BaseStore):
subscription_tokens: dict[str, str] = LocalAttribute(default=dict)

# maps platform application arns to platform applications
platform_applications: dict[str, PlatformApplication] = LocalAttribute(default=dict)
platform_applications: dict[str, PlatformApplicationDetails] = LocalAttribute(default=dict)

# maps endpoint arns to platform endpoints
platform_endpoints: dict[str, PlatformEndpoint] = LocalAttribute(default=dict)

# topic/subscription independent default values for sending sms messages
sms_attributes: dict[str, str] = LocalAttribute(default=dict)
Expand Down
156 changes: 144 additions & 12 deletions localstack-core/localstack/services/sns/v2/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from localstack.aws.api.sns import (
AmazonResourceName,
ConfirmSubscriptionResponse,
CreateEndpointResponse,
CreatePlatformApplicationResponse,
CreateTopicResponse,
Endpoint,
GetEndpointAttributesResponse,
GetPlatformApplicationAttributesResponse,
GetSMSAttributesResponse,
GetSubscriptionAttributesResponse,
Expand Down Expand Up @@ -60,6 +63,9 @@
SMS_ATTRIBUTE_NAMES,
SMS_DEFAULT_SENDER_REGEX,
SMS_TYPES,
EndpointAttributeNames,
PlatformApplicationDetails,
PlatformEndpoint,
SnsMessage,
SnsMessageType,
SnsStore,
Expand All @@ -68,6 +74,7 @@
sns_stores,
)
from localstack.services.sns.v2.utils import (
create_platform_endpoint_arn,
create_subscription_arn,
encode_subscription_token_with_region,
get_next_page_token_from_arn,
Expand Down Expand Up @@ -237,10 +244,11 @@ def subscribe(
raise InvalidParameterException("Invalid parameter: SQS endpoint ARN")

elif protocol == "application":
# TODO: This needs to be implemented once applications are ported from moto to the new provider
raise NotImplementedError(
"This functionality needs yet to be ported to the new SNS provider"
)
# TODO: Validate exact behaviour
try:
parse_arn(endpoint)
except InvalidArnException:
raise InvalidParameterException("Invalid parameter: ApplicationEndpoint ARN")

if ".fifo" in endpoint and ".fifo" not in topic_arn:
# TODO: move to sqs protocol block if possible
Expand Down Expand Up @@ -591,17 +599,24 @@ def create_platform_application(
account_id=context.account_id,
region_name=context.region,
)
platform_application = PlatformApplication(
PlatformApplicationArn=application_arn, Attributes=_attributes
platform_application_details = PlatformApplicationDetails(
platform_application=PlatformApplication(
PlatformApplicationArn=application_arn,
Attributes=_attributes,
),
platform_endpoints={},
)
store.platform_applications[application_arn] = platform_application
return CreatePlatformApplicationResponse(**platform_application)
store.platform_applications[application_arn] = platform_application_details

return platform_application_details.platform_application

def delete_platform_application(
self, context: RequestContext, platform_application_arn: String, **kwargs
) -> None:
store = self.get_store(context.account_id, context.region)
store.platform_applications.pop(platform_application_arn, None)
# TODO: if the platform had endpoints, should we remove them from the store? There is no way to list
# endpoints without an application, so this is impossible to check the state of AWS here

def list_platform_applications(
self, context: RequestContext, next_token: String | None = None, **kwargs
Expand All @@ -615,7 +630,9 @@ def list_platform_applications(
next_token=next_token,
)

response = ListPlatformApplicationsResponse(PlatformApplications=page)
response = ListPlatformApplicationsResponse(
PlatformApplications=[platform_app.platform_application for platform_app in page]
)
if token:
response["NextToken"] = token
return response
Expand Down Expand Up @@ -644,15 +661,112 @@ def set_platform_application_attributes(
# Platform Endpoints
#

def create_platform_endpoint(
self,
context: RequestContext,
platform_application_arn: String,
token: String,
custom_user_data: String | None = None,
attributes: MapStringToString | None = None,
**kwargs,
) -> CreateEndpointResponse:
store = self.get_store(context.account_id, context.region)
application = store.platform_applications.get(platform_application_arn)
if not application:
raise NotFoundException("PlatformApplication does not exist")
endpoint_arn = application.platform_endpoints.get(token, {})
attributes = attributes or {}
_validate_endpoint_attributes(attributes, allow_empty=True)
# CustomUserData can be specified both in attributes and as parameter. Attributes take precedence
attributes.setdefault(EndpointAttributeNames.CUSTOM_USER_DATA, custom_user_data)
_attributes = {"Enabled": "true", "Token": token, **attributes}
if endpoint_arn and (
platform_endpoint_details := store.platform_endpoints.get(endpoint_arn)
):
# endpoint for that application with that particular token already exists
if not platform_endpoint_details.platform_endpoint["Attributes"] == _attributes:
raise InvalidParameterException(
f"Invalid parameter: Token Reason: Endpoint {endpoint_arn} already exists with the same Token, but different attributes."
)
else:
return CreateEndpointResponse(EndpointArn=endpoint_arn)

endpoint_arn = create_platform_endpoint_arn(platform_application_arn)
platform_endpoint = PlatformEndpoint(
platform_application_arn=endpoint_arn,
platform_endpoint=Endpoint(
Attributes=_attributes,
EndpointArn=endpoint_arn,
),
)
store.platform_endpoints[endpoint_arn] = platform_endpoint
application.platform_endpoints[token] = endpoint_arn

return CreateEndpointResponse(EndpointArn=endpoint_arn)

def delete_endpoint(self, context: RequestContext, endpoint_arn: String, **kwargs) -> None:
store = self.get_store(context.account_id, context.region)
platform_endpoint_details = store.platform_endpoints.pop(endpoint_arn, None)
if platform_endpoint_details:
platform_application = store.platform_applications.get(
platform_endpoint_details.platform_application_arn
)
if platform_application:
platform_endpoint = platform_endpoint_details.platform_endpoint
platform_application.platform_endpoints.pop(
platform_endpoint["Attributes"]["Token"], None
)

def list_endpoints_by_platform_application(
self,
context: RequestContext,
platform_application_arn: String,
next_token: String | None = None,
**kwargs,
) -> ListEndpointsByPlatformApplicationResponse:
# TODO: stub so cleanup fixture won't fail
return ListEndpointsByPlatformApplicationResponse(Endpoints=[])
store = self.get_store(context.account_id, context.region)
platform_application = store.platform_applications.get(platform_application_arn)
if not platform_application:
raise NotFoundException("PlatformApplication does not exist")
endpoint_arns = platform_application.platform_endpoints.values()
paginated_endpoint_arns = PaginatedList(endpoint_arns)
page, token = paginated_endpoint_arns.get_page(
token_generator=lambda x: get_next_page_token_from_arn(x),
page_size=100,
next_token=next_token,
)

response = ListEndpointsByPlatformApplicationResponse(
Endpoints=[
store.platform_endpoints[endpoint_arn].platform_endpoint
for endpoint_arn in page
Comment on lines +740 to +742
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is slightly more complex than it needs to be due to the mentioned deviation of the type. A good example of the trade off I mentioned before

if endpoint_arn in store.platform_endpoints
]
)
if token:
response["NextToken"] = token
return response

def get_endpoint_attributes(
self, context: RequestContext, endpoint_arn: String, **kwargs
) -> GetEndpointAttributesResponse:
store = self.get_store(context.account_id, context.region)
platform_endpoint_details = store.platform_endpoints.get(endpoint_arn)
if not platform_endpoint_details:
raise NotFoundException("Endpoint does not exist")
attributes = platform_endpoint_details.platform_endpoint["Attributes"]
return GetEndpointAttributesResponse(Attributes=attributes)

def set_endpoint_attributes(
self, context: RequestContext, endpoint_arn: String, attributes: MapStringToString, **kwargs
) -> None:
store = self.get_store(context.account_id, context.region)
platform_endpoint_details = store.platform_endpoints.get(endpoint_arn)
if not platform_endpoint_details:
raise NotFoundException("Endpoint does not exist")
_validate_endpoint_attributes(attributes)
attributes = attributes or {}
platform_endpoint_details.platform_endpoint["Attributes"].update(attributes)

#
# Sms operations
Expand Down Expand Up @@ -736,7 +850,7 @@ def _get_platform_application(
parse_and_validate_platform_application_arn(platform_application_arn)
try:
store = SnsProvider.get_store(context.account_id, context.region)
return store.platform_applications[platform_application_arn]
return store.platform_applications[platform_application_arn].platform_application
except KeyError:
raise NotFoundException("PlatformApplication does not exist")

Expand Down Expand Up @@ -821,6 +935,10 @@ def _validate_platform_application_name(name: str) -> None:


def _validate_platform_application_attributes(attributes: dict) -> None:
_check_empty_attributes(attributes)


def _check_empty_attributes(attributes: dict) -> None:
if not attributes:
raise CommonServiceException(
code="ValidationError",
Expand All @@ -829,6 +947,20 @@ def _validate_platform_application_attributes(attributes: dict) -> None:
)


def _validate_endpoint_attributes(attributes: dict, allow_empty: bool = False) -> None:
if not allow_empty:
_check_empty_attributes(attributes)
for key in attributes:
if key not in EndpointAttributeNames:
raise InvalidParameterException(
f"Invalid parameter: Attributes Reason: Invalid attribute name: {key}"
)
if len(attributes.get(EndpointAttributeNames.CUSTOM_USER_DATA, "")) > 2048:
raise InvalidParameterException(
"Invalid parameter: Attributes Reason: Invalid value for attribute: CustomUserData: must be at most 2048 bytes long in UTF-8 encoding"
)


def _validate_sms_attributes(attributes: dict) -> None:
for k, v in attributes.items():
if k not in SMS_ATTRIBUTE_NAMES:
Expand Down
8 changes: 8 additions & 0 deletions localstack-core/localstack/services/sns/v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def create_subscription_arn(topic_arn: str) -> str:
return f"{topic_arn}:{uuid4()}"


def create_platform_endpoint_arn(
platform_application_arn: str,
) -> str:
# This is the format of an Endpoint Arn
# arn:aws:sns:us-west-2:1234567890:endpoint/GCM/MyApplication/12345678-abcd-9012-efgh-345678901234
return f"{platform_application_arn.replace('app', 'endpoint', 1)}/{uuid4()}"


def encode_subscription_token_with_region(region: str) -> str:
"""
Create a 64 characters Subscription Token with the region encoded
Expand Down
Loading
Loading