Improve StockItem API speed (#5042)

- Removes child detail fields which cannot be effectively annotated
- Prefetch required fields
- Add unit test method for checking query count
This commit is contained in:
Oliver 2023-06-14 18:33:49 +10:00 committed by GitHub
parent 8d16abcefb
commit be6ab14c9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 11 deletions

View File

@ -2,13 +2,17 @@
import csv
import io
import json
import re
from contextlib import contextmanager
from pathlib import Path
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group, Permission
from django.db import connections
from django.http.response import StreamingHttpResponse
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
from djmoney.contrib.exchange.models import ExchangeBackend, Rate
from rest_framework.test import APITestCase
@ -241,6 +245,30 @@ class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase):
class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
"""Base class for running InvenTree API tests."""
@contextmanager
def assertNumQueriesLessThan(self, value, using='default', verbose=False, debug=False):
"""Context manager to check that the number of queries is less than a certain value.
Example:
with self.assertNumQueriesLessThan(10):
# Do some stuff
Ref: https://stackoverflow.com/questions/1254170/django-is-there-a-way-to-count-sql-queries-from-an-unit-test/59089020#59089020
"""
with CaptureQueriesContext(connections[using]) as context:
yield # your test will be run here
if verbose:
msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4)
else:
msg = None
n = len(context.captured_queries)
if debug:
print(f"Expected less than {value} queries, got {n} queries")
self.assertLess(n, value, msg=msg)
def checkResponse(self, url, method, expected_code, response):
"""Debug output for an unexpected response"""

View File

@ -334,7 +334,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
MPN = serializers.CharField(read_only=True)
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', read_only=True)
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', part_detail=False, read_only=True)
url = serializers.CharField(source='get_absolute_url', read_only=True)

View File

@ -795,15 +795,6 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView):
queryset = StockSerializers.StockItemSerializer.annotate_queryset(queryset)
# Also ensure that we pre-fecth all the related items
queryset = queryset.prefetch_related(
'part',
'part__category',
'location',
'test_results',
'tags',
)
return queryset
def filter_queryset(self, queryset):

View File

@ -231,10 +231,17 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer):
"""Add some extra annotations to the queryset, performing database queries as efficiently as possible."""
queryset = queryset.prefetch_related(
'location',
'sales_order',
'purchase_order',
'part',
'part__category',
'part__pricing_data',
'supplier_part',
'supplier_part__manufacturer_part',
'supplier_part__tags',
'test_results',
'tags',
)
# Annotate the queryset with the total allocated to sales orders
@ -280,7 +287,7 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer):
status_text = serializers.CharField(source='get_status_display', read_only=True)
# Optional detail fields, which can be appended via query parameters
supplier_part_detail = SupplierPartSerializer(source='supplier_part', many=False, read_only=True)
supplier_part_detail = SupplierPartSerializer(source='supplier_part', supplier_detail=False, manufacturer_detail=False, part_detail=False, many=False, read_only=True)
part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
location_detail = LocationBriefSerializer(source='location', many=False, read_only=True)
tests = StockItemTestResultSerializer(source='test_results', many=True, read_only=True)

View File

@ -557,6 +557,42 @@ class StockItemListTest(StockAPITestCase):
self.assertEqual(len(dataset), 17)
def test_query_count(self):
"""Test that the number of queries required to fetch stock items is reasonable."""
def get_stock(data):
"""Helper function to fetch stock items."""
response = self.client.get(self.list_url, data=data)
self.assertEqual(response.status_code, 200)
return response.data
# Create a bunch of StockItem objects
prt = Part.objects.first()
StockItem.objects.bulk_create([
StockItem(
part=prt,
quantity=1,
level=0, tree_id=0, lft=0, rght=0,
) for _ in range(100)
])
# List *all* stock items
with self.assertNumQueriesLessThan(25):
get_stock({})
# List all stock items, with part detail
with self.assertNumQueriesLessThan(20):
get_stock({'part_detail': True})
# List all stock items, with supplier_part detail
with self.assertNumQueriesLessThan(20):
get_stock({'supplier_part_detail': True})
# List all stock items, with 'location' and 'tests' detail
with self.assertNumQueriesLessThan(20):
get_stock({'location_detail': True, 'tests': True})
class StockItemTest(StockAPITestCase):
"""Series of API tests for the StockItem API."""