From be6ab14c9bfc1b5387d6a7bbefdae9d46e5c1720 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 14 Jun 2023 18:33:49 +1000 Subject: [PATCH] Improve StockItem API speed (#5042) - Removes child detail fields which cannot be effectively annotated - Prefetch required fields - Add unit test method for checking query count --- InvenTree/InvenTree/unit_test.py | 28 +++++++++++++++++++++++++ InvenTree/company/serializers.py | 2 +- InvenTree/stock/api.py | 9 -------- InvenTree/stock/serializers.py | 9 +++++++- InvenTree/stock/test_api.py | 36 ++++++++++++++++++++++++++++++++ 5 files changed, 73 insertions(+), 11 deletions(-) diff --git a/InvenTree/InvenTree/unit_test.py b/InvenTree/InvenTree/unit_test.py index d609dc88f8..3590a6f56b 100644 --- a/InvenTree/InvenTree/unit_test.py +++ b/InvenTree/InvenTree/unit_test.py @@ -2,13 +2,17 @@ import csv import io +import json import re +from contextlib import contextmanager from pathlib import Path from django.contrib.auth import get_user_model from django.contrib.auth.models import Group, Permission +from django.db import connections from django.http.response import StreamingHttpResponse from django.test import TestCase +from django.test.utils import CaptureQueriesContext from djmoney.contrib.exchange.models import ExchangeBackend, Rate from rest_framework.test import APITestCase @@ -241,6 +245,30 @@ class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase): class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): """Base class for running InvenTree API tests.""" + @contextmanager + def assertNumQueriesLessThan(self, value, using='default', verbose=False, debug=False): + """Context manager to check that the number of queries is less than a certain value. + + Example: + with self.assertNumQueriesLessThan(10): + # Do some stuff + Ref: https://stackoverflow.com/questions/1254170/django-is-there-a-way-to-count-sql-queries-from-an-unit-test/59089020#59089020 + """ + with CaptureQueriesContext(connections[using]) as context: + yield # your test will be run here + + if verbose: + msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4) + else: + msg = None + + n = len(context.captured_queries) + + if debug: + print(f"Expected less than {value} queries, got {n} queries") + + self.assertLess(n, value, msg=msg) + def checkResponse(self, url, method, expected_code, response): """Debug output for an unexpected response""" diff --git a/InvenTree/company/serializers.py b/InvenTree/company/serializers.py index f0833d18b0..f4c04fde22 100644 --- a/InvenTree/company/serializers.py +++ b/InvenTree/company/serializers.py @@ -334,7 +334,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer): MPN = serializers.CharField(read_only=True) - manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', read_only=True) + manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', part_detail=False, read_only=True) url = serializers.CharField(source='get_absolute_url', read_only=True) diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index 22358cfabb..ba7387885e 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -795,15 +795,6 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): queryset = StockSerializers.StockItemSerializer.annotate_queryset(queryset) - # Also ensure that we pre-fecth all the related items - queryset = queryset.prefetch_related( - 'part', - 'part__category', - 'location', - 'test_results', - 'tags', - ) - return queryset def filter_queryset(self, queryset): diff --git a/InvenTree/stock/serializers.py b/InvenTree/stock/serializers.py index 103622ee73..2518ca3c35 100644 --- a/InvenTree/stock/serializers.py +++ b/InvenTree/stock/serializers.py @@ -231,10 +231,17 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): """Add some extra annotations to the queryset, performing database queries as efficiently as possible.""" queryset = queryset.prefetch_related( + 'location', 'sales_order', 'purchase_order', 'part', + 'part__category', 'part__pricing_data', + 'supplier_part', + 'supplier_part__manufacturer_part', + 'supplier_part__tags', + 'test_results', + 'tags', ) # Annotate the queryset with the total allocated to sales orders @@ -280,7 +287,7 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): status_text = serializers.CharField(source='get_status_display', read_only=True) # Optional detail fields, which can be appended via query parameters - supplier_part_detail = SupplierPartSerializer(source='supplier_part', many=False, read_only=True) + supplier_part_detail = SupplierPartSerializer(source='supplier_part', supplier_detail=False, manufacturer_detail=False, part_detail=False, many=False, read_only=True) part_detail = PartBriefSerializer(source='part', many=False, read_only=True) location_detail = LocationBriefSerializer(source='location', many=False, read_only=True) tests = StockItemTestResultSerializer(source='test_results', many=True, read_only=True) diff --git a/InvenTree/stock/test_api.py b/InvenTree/stock/test_api.py index d8c0a9cf35..b96d5c660d 100644 --- a/InvenTree/stock/test_api.py +++ b/InvenTree/stock/test_api.py @@ -557,6 +557,42 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(dataset), 17) + def test_query_count(self): + """Test that the number of queries required to fetch stock items is reasonable.""" + + def get_stock(data): + """Helper function to fetch stock items.""" + response = self.client.get(self.list_url, data=data) + self.assertEqual(response.status_code, 200) + return response.data + + # Create a bunch of StockItem objects + prt = Part.objects.first() + + StockItem.objects.bulk_create([ + StockItem( + part=prt, + quantity=1, + level=0, tree_id=0, lft=0, rght=0, + ) for _ in range(100) + ]) + + # List *all* stock items + with self.assertNumQueriesLessThan(25): + get_stock({}) + + # List all stock items, with part detail + with self.assertNumQueriesLessThan(20): + get_stock({'part_detail': True}) + + # List all stock items, with supplier_part detail + with self.assertNumQueriesLessThan(20): + get_stock({'supplier_part_detail': True}) + + # List all stock items, with 'location' and 'tests' detail + with self.assertNumQueriesLessThan(20): + get_stock({'location_detail': True, 'tests': True}) + class StockItemTest(StockAPITestCase): """Series of API tests for the StockItem API."""