diff --git a/InvenTree/InvenTree/api_tester.py b/InvenTree/InvenTree/api_tester.py index c55c3d3ba3..0420e93369 100644 --- a/InvenTree/InvenTree/api_tester.py +++ b/InvenTree/InvenTree/api_tester.py @@ -2,6 +2,10 @@ Helper functions for performing API unit tests """ +import io +import re + +from django.http.response import StreamingHttpResponse from django.contrib.auth import get_user_model from django.contrib.auth.models import Group from rest_framework.test import APITestCase @@ -165,3 +169,33 @@ class InvenTreeAPITestCase(APITestCase): self.assertEqual(response.status_code, expected_code) return response + + def download_file(self, url, data, expected_code=None, expected_fn=None): + """ + Download a file from the server, and return an in-memory file + """ + + response = self.client.get(url, data=data, format='json') + + if expected_code is not None: + self.assertEqual(response.status_code, expected_code) + + # Check that the response is of the correct type + if not isinstance(response, StreamingHttpResponse): + raise ValueError("Response is not a StreamingHttpResponse object as expected") + + # Extract filename + disposition = response.headers['Content-Disposition'] + + result = re.search(r'attachment; filename="([\w.]+)"', disposition) + + fn = result.groups()[0] + + with io.BytesIO() as fo: + fo.name = fn + fo.write(response.getvalue()) + + if expected_fn is not None: + self.assertEqual(expected_fn, fn) + + return fo \ No newline at end of file