forked from localstack/localstack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtestutil.py
More file actions
258 lines (215 loc) · 8.32 KB
/
testutil.py
File metadata and controls
258 lines (215 loc) · 8.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import os
import json
import glob
import tempfile
import requests
import shutil
import zipfile
from six import iteritems
from localstack.config import DEFAULT_REGION
from localstack.utils.aws import aws_stack
from localstack.constants import (LOCALSTACK_ROOT_FOLDER, LOCALSTACK_VENV_FOLDER,
LAMBDA_TEST_ROLE, TEST_AWS_ACCOUNT_ID)
from localstack.utils.common import mkdir, to_str, save_file, TMP_FILES
from localstack.services.awslambda.lambda_api import (get_handler_file_from_name, LAMBDA_DEFAULT_HANDLER,
LAMBDA_DEFAULT_RUNTIME, LAMBDA_DEFAULT_STARTING_POSITION, LAMBDA_DEFAULT_TIMEOUT)
# Name prefix passed to tempfile.mkdtemp() for the temporary directories
# created while building Lambda archives in this module.
ARCHIVE_DIR_PREFIX = 'lambda.archive.'
def create_lambda_archive(script, get_content=False, libs=None, runtime=None, file_name=None):
    """Utility method to create a Lambda function archive.

    The handler script is written into a fresh temporary directory (which is
    registered in TMP_FILES), the requested libraries are copied next to it,
    and the directory is zipped via create_zip_file().

    :param script: content of the handler file to store in the archive
    :param get_content: forwarded to create_zip_file()
    :param libs: optional iterable of library names to bundle. 'localstack' is
        copied from LOCALSTACK_ROOT_FOLDER; any other name is looked up under
        the virtualenv's site-packages folder. Defaults to no libraries.
    :param runtime: forwarded to get_handler_file_from_name() when no
        file_name is given
    :param file_name: path of the handler file relative to the archive root
    :return: the result of create_zip_file() for the assembled directory
    """
    tmp_dir = tempfile.mkdtemp(prefix=ARCHIVE_DIR_PREFIX)
    TMP_FILES.append(tmp_dir)
    file_name = file_name or get_handler_file_from_name(LAMBDA_DEFAULT_HANDLER, runtime=runtime)
    script_file = os.path.join(tmp_dir, file_name)
    if os.path.sep in script_file:
        # the handler may live in a sub-folder of the archive
        mkdir(os.path.dirname(script_file))
    save_file(script_file, script)

    # copy libs (None means "no libraries"; avoids a shared mutable default)
    for lib in libs or []:
        paths = [lib, '%s.py' % lib]
        target_dir = tmp_dir
        root_folder = os.path.join(LOCALSTACK_VENV_FOLDER, 'lib/python*/site-packages')
        if lib == 'localstack':
            # localstack itself is copied from the source tree into its own sub-folder
            paths = ['localstack/*.py', 'localstack/utils']
            root_folder = LOCALSTACK_ROOT_FOLDER
            target_dir = os.path.join(tmp_dir, lib)
            mkdir(target_dir)
        for path in paths:
            file_to_copy = os.path.join(root_folder, path)
            # the pattern may contain wildcards, hence the glob expansion
            for file_path in glob.glob(file_to_copy):
                name = os.path.join(target_dir, file_path.split(os.path.sep)[-1])
                if os.path.isdir(file_path):
                    shutil.copytree(file_path, name)
                else:
                    shutil.copyfile(file_path, name)

    # create zip file
    return create_zip_file(tmp_dir, get_content=get_content)
def delete_lambda_function(name):
    """Delete the Lambda function called `name` through the Lambda API client."""
    aws_stack.connect_to_service('lambda').delete_function(FunctionName=name)
def create_zip_file(file_path, get_content=False):
    """Create a zip archive from a directory or a single file.

    :param file_path: directory whose files are added recursively (paths in
        the archive are relative to it), or a single file, which is first
        copied into a temporary directory
    :param get_content: if True, return the archive bytes and remove the
        temporary archive directory; if False, return the path of the archive
    :return: path of 'archive.zip' (str) or its content (bytes)
    """
    base_dir = file_path
    if not os.path.isdir(file_path):
        # single file: zip a temporary folder that contains just this file
        base_dir = tempfile.mkdtemp(prefix=ARCHIVE_DIR_PREFIX)
        shutil.copy(file_path, base_dir)
        TMP_FILES.append(base_dir)
    tmp_dir = tempfile.mkdtemp(prefix=ARCHIVE_DIR_PREFIX)
    full_zip_file = os.path.join(tmp_dir, 'archive.zip')

    # create zip file
    with zipfile.ZipFile(full_zip_file, 'w') as zip_file:
        for root, _dirs, files in os.walk(base_dir):
            # path of the current folder relative to the archive root
            relative = root[len(base_dir):].lstrip(os.path.sep)
            for name in files:
                zip_file.write(os.path.join(root, name), os.path.join(relative, name))

    if not get_content:
        # The caller receives a path, so the archive must stay on disk. The
        # folder is only registered in TMP_FILES (presumably cleaned up later
        # by whoever owns that list) and must NOT be removed here.
        TMP_FILES.append(tmp_dir)
        return full_zip_file

    try:
        with open(full_zip_file, 'rb') as file_obj:
            return file_obj.read()
    finally:
        # content is returned in memory, so the archive folder can go right away
        shutil.rmtree(tmp_dir)
def create_lambda_function(func_name, zip_file, event_source_arn=None, handler=LAMBDA_DEFAULT_HANDLER,
        starting_position=LAMBDA_DEFAULT_STARTING_POSITION, runtime=LAMBDA_DEFAULT_RUNTIME,
        envvars=None, tags=None, delete=False, layers=None):
    """Utility method to create a new function via the Lambda API.

    :param func_name: name of the function to create
    :param zip_file: value sent as Code.ZipFile in the create request
    :param event_source_arn: if set, an event source mapping from this ARN to
        the new function is created as well
    :param handler: value sent as Handler
    :param starting_position: StartingPosition of the event source mapping
    :param runtime: value sent as Runtime
    :param envvars: dict sent as Environment.Variables (defaults to empty)
    :param tags: dict sent as Tags (defaults to empty)
    :param delete: if True, try to delete an existing function with the same
        name first; failures of that deletion are ignored
    :param layers: if truthy, sent as Layers
    """
    # normalise here instead of using shared mutable dicts as defaults
    envvars = {} if envvars is None else envvars
    tags = {} if tags is None else tags

    client = aws_stack.connect_to_service('lambda')
    if delete:
        try:
            # Delete function if one already exists
            client.delete_function(FunctionName=func_name)
        except Exception:
            # best effort only: deletion may fail, e.g. if there is nothing to delete
            pass

    # create function
    kwargs = {
        'FunctionName': func_name,
        'Runtime': runtime,
        'Handler': handler,
        'Role': LAMBDA_TEST_ROLE,
        'Code': {
            'ZipFile': zip_file
        },
        'Timeout': LAMBDA_DEFAULT_TIMEOUT,
        'Environment': dict(Variables=envvars),
        'Tags': tags
    }
    if layers:
        kwargs['Layers'] = layers
    client.create_function(**kwargs)

    # create event source mapping
    if event_source_arn:
        client.create_event_source_mapping(
            FunctionName=func_name,
            EventSourceArn=event_source_arn,
            StartingPosition=starting_position
        )
def assert_objects(asserts, all_objects):
    """Run assert_object() for every expected object in `asserts`.

    A value that is not exactly a list is treated as a single expected object.
    """
    expected_items = asserts if type(asserts) is list else [asserts]
    for expected in expected_items:
        assert_object(expected, all_objects)
def assert_object(expected_object, all_objects):
    """Raise an Exception unless find_object() locates `expected_object` in `all_objects`.

    `all_objects` may be a list, a dict values view, or a single item.
    """
    # for Python 3 compatibility: dict.values() returns a view, not a list
    values_view_type = type({}.values())
    if isinstance(all_objects, values_view_type):
        all_objects = list(all_objects)
    # wrap single item in an array
    if type(all_objects) is not list:
        all_objects = [all_objects]
    if find_object(expected_object, all_objects):
        return
    raise Exception('Expected object not found: %s in list %s' % (expected_object, all_objects))
def find_object(expected_object, object_list):
    """Return the first entry of `object_list` that matches `expected_object`, else None.

    Nested lists are searched recursively first. An entry matches when it is
    equal to `expected_object`, or when `expected_object` is a dict and each of
    its key/value pairs is found somewhere inside the entry by find_recursive().
    """
    for candidate in object_list:
        if isinstance(candidate, list):
            nested_match = find_object(expected_object, candidate)
            if nested_match:
                return nested_match

        if candidate != expected_object:
            # not equal: only a dict expectation can still match partially
            if not isinstance(expected_object, dict):
                continue
            if not all(find_recursive(k, v, candidate) for k, v in iteritems(expected_object)):
                continue
        return candidate
    return None
def find_recursive(key, value, obj):
    """Return True if the pair `key`/`value` occurs anywhere inside `obj`.

    Dicts are searched by their items and then recursively through their
    values; lists are searched element by element. Any other type never
    matches. The result is always a bool (never None), also when a dict or
    list contains no match.
    """
    if isinstance(obj, dict):
        for k, v in obj.items():
            if (k == key and v == value) or find_recursive(key, value, v):
                return True
        return False
    if isinstance(obj, list):
        return any(find_recursive(key, value, item) for item in obj)
    return False
def list_all_s3_objects():
    """Return the values of the mapping built by map_all_s3_objects()."""
    objects_by_path = map_all_s3_objects()
    return objects_by_path.values()
def download_s3_object(s3, bucket, path):
    """Download the object at `path` from `bucket` and return its content.

    The content is converted with to_str() when possible; if that conversion
    raises, the raw data read from the download buffer is returned instead.
    """
    with tempfile.SpooledTemporaryFile() as buffer:
        s3.Bucket(bucket).download_fileobj(path, buffer)
        buffer.seek(0)
        content = buffer.read()
    try:
        return to_str(content)
    except Exception:
        # keep the raw content if it cannot be converted to a string
        return content
def map_all_s3_objects(to_json=True):
    """Map '<bucket>/<key>' to the content of every object in every bucket.

    With `to_json` set, each object is parsed as JSON and objects that fail to
    parse are left out of the result; otherwise the downloaded content is
    stored as is.
    """
    s3 = aws_stack.get_s3_client()
    objects_by_path = {}
    for bucket in s3.buckets.all():
        for summary in bucket.objects.all():
            content = download_s3_object(s3, summary.bucket_name, summary.key)
            try:
                parsed = json.loads(content) if to_json else content
                objects_by_path['%s/%s' % (summary.bucket_name, summary.key)] = parsed
            except Exception:
                # skip non-JSON or binary objects
                continue
    return objects_by_path
def get_sample_arn(service, resource):
    """Build an ARN for `service`/`resource` using DEFAULT_REGION and TEST_AWS_ACCOUNT_ID."""
    arn_fields = (service, DEFAULT_REGION, TEST_AWS_ACCOUNT_ID, resource)
    return 'arn:aws:%s:%s:%s:%s' % arn_fields
def send_describe_dynamodb_ttl_request(table_name):
    """Send a DescribeTimeToLive request for `table_name` and return the response."""
    payload = json.dumps({'TableName': table_name})
    return send_dynamodb_request('', 'DescribeTimeToLive', payload)
def send_update_dynamodb_ttl_request(table_name, ttl_status):
    """Send an UpdateTimeToLive request for `table_name` and return the response.

    The TTL attribute name is fixed to 'ExpireItem'; `ttl_status` is sent as
    the 'Enabled' flag of the specification.
    """
    ttl_specification = {
        'AttributeName': 'ExpireItem',
        'Enabled': ttl_status
    }
    payload = json.dumps({
        'TableName': table_name,
        'TimeToLiveSpecification': ttl_specification
    })
    return send_dynamodb_request('', 'UpdateTimeToLive', payload)
def send_dynamodb_request(path, action, request_body):
    """PUT `request_body` to the DynamoDB endpoint and return the response.

    The endpoint is read from the TEST_DYNAMODB_URL environment variable and
    `path` is appended to it; `action` selects the x-amz-target operation.
    Certificate verification is disabled for the request.
    """
    base_url = os.getenv('TEST_DYNAMODB_URL')
    url = '{}/{}'.format(base_url, path)
    target = 'DynamoDB_20120810.{}'.format(action)
    headers = {
        'Host': 'dynamodb.amazonaws.com',
        'x-amz-target': target,
        'authorization': 'some_token'
    }
    return requests.put(url, data=request_body, headers=headers, verify=False)
def create_sqs_queue(queue_name):
    """Utility method to create a new queue via SQS API.

    Returns a dict with the 'QueueUrl' and 'QueueArn' of the created queue.
    """
    sqs = aws_stack.connect_to_service('sqs')

    # create queue
    create_response = sqs.create_queue(QueueName=queue_name)
    queue_url = create_response['QueueUrl']

    # get the queue arn
    attributes_response = sqs.get_queue_attributes(
        QueueUrl=queue_url,
        AttributeNames=['QueueArn'],
    )
    queue_arn = attributes_response['Attributes']['QueueArn']

    return {
        'QueueUrl': queue_url,
        'QueueArn': queue_arn,
    }