|
2 | 2 | Helper functions for performing API unit tests |
3 | 3 | """ |
4 | 4 |
|
| 5 | +import csv |
| 6 | +import io |
| 7 | +import re |
| 8 | + |
| 9 | +from django.http.response import StreamingHttpResponse |
5 | 10 | from django.contrib.auth import get_user_model |
6 | 11 | from django.contrib.auth.models import Group |
7 | 12 | from rest_framework.test import APITestCase |
@@ -165,3 +170,87 @@ def options(self, url, expected_code=None): |
165 | 170 | self.assertEqual(response.status_code, expected_code) |
166 | 171 |
|
167 | 172 | return response |
| 173 | + |
| 174 | + def download_file(self, url, data, expected_code=None, expected_fn=None, decode=True): |
| 175 | + """ |
| 176 | + Download a file from the server, and return an in-memory file |
| 177 | + """ |
| 178 | + |
| 179 | + response = self.client.get(url, data=data, format='json') |
| 180 | + |
| 181 | + if expected_code is not None: |
| 182 | + self.assertEqual(response.status_code, expected_code) |
| 183 | + |
| 184 | + # Check that the response is of the correct type |
| 185 | + if not isinstance(response, StreamingHttpResponse): |
| 186 | + raise ValueError("Response is not a StreamingHttpResponse object as expected") |
| 187 | + |
| 188 | + # Extract filename |
| 189 | + disposition = response.headers['Content-Disposition'] |
| 190 | + |
| 191 | + result = re.search(r'attachment; filename="([\w.]+)"', disposition) |
| 192 | + |
| 193 | + fn = result.groups()[0] |
| 194 | + |
| 195 | + if expected_fn is not None: |
| 196 | + self.assertEqual(expected_fn, fn) |
| 197 | + |
| 198 | + if decode: |
| 199 | + # Decode data and return as StringIO file object |
| 200 | + fo = io.StringIO() |
| 201 | + fo.name = fo |
| 202 | + fo.write(response.getvalue().decode('UTF-8')) |
| 203 | + else: |
| 204 | + # Return a a BytesIO file object |
| 205 | + fo = io.BytesIO() |
| 206 | + fo.name = fn |
| 207 | + fo.write(response.getvalue()) |
| 208 | + |
| 209 | + fo.seek(0) |
| 210 | + |
| 211 | + return fo |
| 212 | + |
| 213 | + def process_csv(self, fo, delimiter=',', required_cols=None, excluded_cols=None, required_rows=None): |
| 214 | + """ |
| 215 | + Helper function to process and validate a downloaded csv file |
| 216 | + """ |
| 217 | + |
| 218 | + # Check that the correct object type has been passed |
| 219 | + self.assertTrue(isinstance(fo, io.StringIO)) |
| 220 | + |
| 221 | + fo.seek(0) |
| 222 | + |
| 223 | + reader = csv.reader(fo, delimiter=delimiter) |
| 224 | + |
| 225 | + headers = [] |
| 226 | + rows = [] |
| 227 | + |
| 228 | + for idx, row in enumerate(reader): |
| 229 | + if idx == 0: |
| 230 | + headers = row |
| 231 | + else: |
| 232 | + rows.append(row) |
| 233 | + |
| 234 | + if required_cols is not None: |
| 235 | + for col in required_cols: |
| 236 | + self.assertIn(col, headers) |
| 237 | + |
| 238 | + if excluded_cols is not None: |
| 239 | + for col in excluded_cols: |
| 240 | + self.assertNotIn(col, headers) |
| 241 | + |
| 242 | + if required_rows is not None: |
| 243 | + self.assertEqual(len(rows), required_rows) |
| 244 | + |
| 245 | + # Return the file data as a list of dict items, based on the headers |
| 246 | + data = [] |
| 247 | + |
| 248 | + for row in rows: |
| 249 | + entry = {} |
| 250 | + |
| 251 | + for idx, col in enumerate(headers): |
| 252 | + entry[col] = row[idx] |
| 253 | + |
| 254 | + data.append(entry) |
| 255 | + |
| 256 | + return data |
0 commit comments