diff --git a/InvenTree/part/api.py b/InvenTree/part/api.py index e67640ccef..4e9ab8f5ca 100644 --- a/InvenTree/part/api.py +++ b/InvenTree/part/api.py @@ -1118,6 +1118,7 @@ class PartList(APIDownloadMixin, ListCreateAPI): params = self.request.query_params kwargs['parameters'] = str2bool(params.get('parameters', None)) + kwargs['category_detail'] = str2bool(params.get('category_detail', False)) except AttributeError: pass @@ -1156,41 +1157,6 @@ class PartList(APIDownloadMixin, ListCreateAPI): data = serializer.data - # Do we wish to include PartCategory detail? - if str2bool(request.query_params.get('category_detail', False)): - - # Work out which part categories we need to query - category_ids = set() - - for part in data: - cat_id = part['category'] - - if cat_id is not None: - category_ids.add(cat_id) - - # Fetch only the required PartCategory objects from the database - categories = PartCategory.objects.filter(pk__in=category_ids).prefetch_related( - 'parts', - 'parent', - 'children', - ) - - category_map = {} - - # Serialize each PartCategory object - for category in categories: - category_map[category.pk] = part_serializers.CategorySerializer(category).data - - for part in data: - cat_id = part['category'] - - if cat_id is not None and cat_id in category_map.keys(): - detail = category_map[cat_id] - else: - detail = None - - part['category_detail'] = detail - """ Determine the response type based on the request. a) For HTTP requests (e.g. via the browseable API) return a DRF response diff --git a/InvenTree/part/serializers.py b/InvenTree/part/serializers.py index 29d1fa8749..a4b4105b6a 100644 --- a/InvenTree/part/serializers.py +++ b/InvenTree/part/serializers.py @@ -470,22 +470,19 @@ class PartSerializer(RemoteImageMixin, InvenTreeModelSerializer): - Allows us to optionally pass extra fields based on the query. """ self.starred_parts = kwargs.pop('starred_parts', []) - category_detail = kwargs.pop('category_detail', False) - parameters = kwargs.pop('parameters', False) - create = kwargs.pop('create', False) super().__init__(*args, **kwargs) - if category_detail is not True: + if not category_detail: self.fields.pop('category_detail') - if parameters is not True: + if not parameters: self.fields.pop('parameters') - if create is not True: + if not create: # These fields are only used for the LIST API endpoint for f in self.skip_create_fields()[1:]: self.fields.pop(f) diff --git a/InvenTree/part/test_api.py b/InvenTree/part/test_api.py index 76097faa8b..22d689c8d0 100644 --- a/InvenTree/part/test_api.py +++ b/InvenTree/part/test_api.py @@ -5,6 +5,8 @@ from enum import IntEnum from random import randint from django.core.exceptions import ValidationError +from django.db import connection +from django.test.utils import CaptureQueriesContext from django.urls import reverse import PIL @@ -1704,6 +1706,60 @@ class PartDetailTests(PartAPITestBase): self.assertEqual(part.metadata['x'], 'y') +class PartListTests(PartAPITestBase): + """Unit tests for the Part List API endpoint""" + + def test_query_count(self): + """Test that the query count is unchanged, independent of query results""" + + queries = [ + {'limit': 1}, + {'limit': 10}, + {'limit': 50}, + {'category': 1}, + {}, + ] + + url = reverse('api-part-list') + + # Create a bunch of extra parts (efficiently) + parts = [] + + for ii in range(100): + parts.append(Part( + name=f"Extra part {ii}", + description="A new part which will appear via the API", + level=0, tree_id=0, + lft=0, rght=0, + )) + + Part.objects.bulk_create(parts) + + for query in queries: + + with CaptureQueriesContext(connection) as ctx: + self.get(url, query, expected_code=200) + + # No more than 20 database queries + self.assertLess(len(ctx), 20) + + # Test 'category_detail' annotation + for b in [False, True]: + with CaptureQueriesContext(connection) as ctx: + results = self.get( + reverse('api-part-list'), + {'category_detail': b}, + expected_code=200 + ) + + for result in results.data: + if b and result['category'] is not None: + self.assertIn('category_detail', result) + + # No more than 20 DB queries + self.assertLessEqual(len(ctx), 20) + + class PartNotesTests(InvenTreeAPITestCase): """Tests for the 'notes' field (markdown field)"""