diff --git a/InvenTree/report/tests.py b/InvenTree/report/tests.py index a41abfff77..c30f2f4151 100644 --- a/InvenTree/report/tests.py +++ b/InvenTree/report/tests.py @@ -1,7 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals +import os +import shutil + from django.urls import reverse +from django.conf import settings from InvenTree.api_tester import InvenTreeAPITestCase @@ -28,20 +32,89 @@ class ReportTest(InvenTreeAPITestCase): def setUp(self): super().setUp() + def copyReportTemplate(self, filename, description): + """ + Copy the provided report template into the required media directory + """ + + src_dir = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + 'templates', + 'report' + ) + + template_dir = os.path.join( + 'report', + 'inventree', + self.model.getSubdir(), + ) + + dst_dir = os.path.join( + settings.MEDIA_ROOT, + template_dir + ) + + if not os.path.exists(dst_dir): + os.makedirs(dst_dir, exist_ok=True) + + src_file = os.path.join(src_dir, filename) + dst_file = os.path.join(dst_dir, filename) + + if not os.path.exists(dst_file): + shutil.copyfile(src_file, dst_file) + + # Convert to an "internal" filename + db_filename = os.path.join( + template_dir, + filename + ) + + # Create a database entry for this report template! + self.model.objects.create( + name=os.path.splitext(filename)[0], + description=description, + template=db_filename, + enabled=True + ) + def test_list_endpoint(self): """ Test that the LIST endpoint works for each report """ - if self.list_url: - url = reverse(self.list_url) + if not self.list_url: + return - print("URL:", url) - response = self.get(url) - self.assertEqual(response.status_code, 200) - print("Response:") - print(response) - print(response.data) + url = reverse(self.list_url) + + response = self.get(url) + self.assertEqual(response.status_code, 200) + + reports = self.model.objects.all() + + n = len(reports) + + # API endpoint must return correct number of reports + self.assertEqual(len(response.data), n) + + # Filter by "enabled" status + response = self.get(url, {'enabled': True}) + self.assertEqual(len(response.data), n) + + response = self.get(url, {'enabled': False}) + self.assertEqual(len(response.data), 0) + + # Disable each report + for report in reports: + report.enabled = False + report.save() + + # Filter by "enabled" status + response = self.get(url, {'enabled': True}) + self.assertEqual(len(response.data), 0) + + response = self.get(url, {'enabled': False}) + self.assertEqual(len(response.data), n) class TestReportTest(ReportTest): @@ -52,6 +125,12 @@ class TestReportTest(ReportTest): detail_url = 'api-stockitem-testreport-detail' print_url = 'api-stockitem-testreport-print' + def setUp(self): + + self.copyReportTemplate('inventree_test_report.html', 'stock item test report') + + return super().setUp() + class BuildReportTest(ReportTest): @@ -61,6 +140,12 @@ class BuildReportTest(ReportTest): detail_url = 'api-build-report-detail' print_url = 'api-build-report-print' + def setUp(self): + + self.copyReportTemplate('inventree_build_order.html', 'build order template') + + return super().setUp() + class BOMReportTest(ReportTest):