|
2 | 2 | import urlparse |
3 | 3 | import logging |
4 | 4 | import json |
| 5 | +import xmltodict |
5 | 6 | import xml.etree.ElementTree as ET |
6 | 7 | from requests.models import Response |
7 | 8 | from localstack.constants import * |
|
11 | 12 | # mappings for S3 bucket notifications |
12 | 13 | S3_NOTIFICATIONS = {} |
13 | 14 |
|
| 15 | +# mappings for bucket CORS settings |
| 16 | +BUCKET_CORS = {} |
| 17 | + |
14 | 18 | # set up logger |
15 | 19 | LOGGER = logging.getLogger(__name__) |
16 | 20 |
|
@@ -93,27 +97,87 @@ def get_xml_text(node, name, ns=None, default=None): |
93 | 97 | return child.text |
94 | 98 |
|
95 | 99 |
|
| 100 | +def get_cors(bucket_name): |
| 101 | + response = Response() |
| 102 | + cors = BUCKET_CORS.get(bucket_name) |
| 103 | + if not cors: |
| 104 | + # TODO: check if bucket exists, otherwise return 404-like error |
| 105 | + cors = { |
| 106 | + 'CORSConfiguration': [] |
| 107 | + } |
| 108 | + body = xmltodict.unparse(cors) |
| 109 | + response._content = body |
| 110 | + response.status_code = 200 |
| 111 | + return response |
| 112 | + |
| 113 | + |
| 114 | +def set_cors(bucket_name, cors): |
| 115 | + # TODO: check if bucket exists, otherwise return 404-like error |
| 116 | + if isinstance(cors, basestring): |
| 117 | + cors = xmltodict.parse(cors) |
| 118 | + BUCKET_CORS[bucket_name] = cors |
| 119 | + response = Response() |
| 120 | + response.status_code = 200 |
| 121 | + return response |
| 122 | + |
| 123 | + |
| 124 | +def delete_cors(bucket_name): |
| 125 | + # TODO: check if bucket exists, otherwise return 404-like error |
| 126 | + BUCKET_CORS.pop(bucket_name, {}) |
| 127 | + response = Response() |
| 128 | + response.status_code = 200 |
| 129 | + return response |
| 130 | + |
| 131 | + |
| 132 | +def append_cors_headers(bucket_name, request_method, request_headers, response): |
| 133 | + cors = BUCKET_CORS.get(bucket_name) |
| 134 | + if not cors: |
| 135 | + return |
| 136 | + origin = request_headers.get('Origin', '') |
| 137 | + for rule in cors['CORSConfiguration']['CORSRule']: |
| 138 | + allowed_methods = rule.get('AllowedMethod', []) |
| 139 | + if request_method in allowed_methods: |
| 140 | + allowed_origins = rule.get('AllowedOrigin', []) |
| 141 | + for allowed in allowed_origins: |
| 142 | + if origin in allowed or re.match(allowed.replace('*', '.*'), origin): |
| 143 | + response.headers['Access-Control-Allow-Origin'] = origin |
| 144 | + break |
| 145 | + |
| 146 | + |
96 | 147 | def update_s3(method, path, data, headers, response=None, return_forward_info=False): |
97 | 148 | if return_forward_info: |
98 | 149 | parsed = urlparse.urlparse(path) |
99 | 150 | query = parsed.query |
100 | 151 | path = parsed.path |
| 152 | + bucket = path.split('/')[1] |
101 | 153 | query_map = urlparse.parse_qs(query) |
102 | 154 | if method == 'PUT' and (query == 'notification' or 'notification' in query_map): |
103 | 155 | tree = ET.fromstring(data) |
104 | 156 | queue_config = tree.find('{%s}QueueConfiguration' % XMLNS_S3) |
105 | 157 | if len(queue_config): |
106 | | - bucket = path[1:] |
107 | 158 | S3_NOTIFICATIONS[bucket] = { |
108 | 159 | 'Id': get_xml_text(queue_config, 'Id'), |
109 | 160 | 'Event': get_xml_text(queue_config, 'Event', ns=XMLNS_S3), |
110 | 161 | 'Queue': get_xml_text(queue_config, 'Queue', ns=XMLNS_S3), |
111 | 162 | 'Topic': get_xml_text(queue_config, 'Topic', ns=XMLNS_S3), |
112 | 163 | 'CloudFunction': get_xml_text(queue_config, 'CloudFunction', ns=XMLNS_S3) |
113 | 164 | } |
| 165 | + if query == 'cors' or 'cors' in query_map: |
| 166 | + if method == 'GET': |
| 167 | + return get_cors(bucket) |
| 168 | + if method == 'PUT': |
| 169 | + return set_cors(bucket, data) |
| 170 | + if method == 'DELETE': |
| 171 | + return delete_cors(bucket) |
114 | 172 | return True |
| 173 | + # get subscribers and send bucket notifications |
115 | 174 | if method in ('PUT', 'DELETE') and '/' in path[1:]: |
116 | 175 | parts = path[1:].split('/', 1) |
117 | 176 | bucket_name = parts[0] |
118 | 177 | object_path = '/%s' % parts[1] |
119 | 178 | send_notifications(method, bucket_name, object_path) |
| 179 | + # append CORS headers to response |
| 180 | + if response: |
| 181 | + parsed = urlparse.urlparse(path) |
| 182 | + bucket_name = parsed.path.split('/')[0] |
| 183 | + append_cors_headers(bucket_name, request_method=method, request_headers=headers, response=response) |
0 commit comments