mirror of
https://github.com/inventree/InvenTree
synced 2024-08-30 18:33:04 +00:00
Improve StockItem API speed (#5042)
- Removes child detail fields which cannot be effectively annotated - Prefetch required fields - Add unit test method for checking query count
This commit is contained in:
parent
8d16abcefb
commit
be6ab14c9b
@ -2,13 +2,17 @@
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import Group, Permission
|
||||
from django.db import connections
|
||||
from django.http.response import StreamingHttpResponse
|
||||
from django.test import TestCase
|
||||
from django.test.utils import CaptureQueriesContext
|
||||
|
||||
from djmoney.contrib.exchange.models import ExchangeBackend, Rate
|
||||
from rest_framework.test import APITestCase
|
||||
@ -241,6 +245,30 @@ class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase):
|
||||
class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
|
||||
"""Base class for running InvenTree API tests."""
|
||||
|
||||
@contextmanager
|
||||
def assertNumQueriesLessThan(self, value, using='default', verbose=False, debug=False):
|
||||
"""Context manager to check that the number of queries is less than a certain value.
|
||||
|
||||
Example:
|
||||
with self.assertNumQueriesLessThan(10):
|
||||
# Do some stuff
|
||||
Ref: https://stackoverflow.com/questions/1254170/django-is-there-a-way-to-count-sql-queries-from-an-unit-test/59089020#59089020
|
||||
"""
|
||||
with CaptureQueriesContext(connections[using]) as context:
|
||||
yield # your test will be run here
|
||||
|
||||
if verbose:
|
||||
msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4)
|
||||
else:
|
||||
msg = None
|
||||
|
||||
n = len(context.captured_queries)
|
||||
|
||||
if debug:
|
||||
print(f"Expected less than {value} queries, got {n} queries")
|
||||
|
||||
self.assertLess(n, value, msg=msg)
|
||||
|
||||
def checkResponse(self, url, method, expected_code, response):
|
||||
"""Debug output for an unexpected response"""
|
||||
|
||||
|
@ -334,7 +334,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
|
||||
|
||||
MPN = serializers.CharField(read_only=True)
|
||||
|
||||
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', read_only=True)
|
||||
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', part_detail=False, read_only=True)
|
||||
|
||||
url = serializers.CharField(source='get_absolute_url', read_only=True)
|
||||
|
||||
|
@ -795,15 +795,6 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView):
|
||||
|
||||
queryset = StockSerializers.StockItemSerializer.annotate_queryset(queryset)
|
||||
|
||||
# Also ensure that we pre-fecth all the related items
|
||||
queryset = queryset.prefetch_related(
|
||||
'part',
|
||||
'part__category',
|
||||
'location',
|
||||
'test_results',
|
||||
'tags',
|
||||
)
|
||||
|
||||
return queryset
|
||||
|
||||
def filter_queryset(self, queryset):
|
||||
|
@ -231,10 +231,17 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer):
|
||||
"""Add some extra annotations to the queryset, performing database queries as efficiently as possible."""
|
||||
|
||||
queryset = queryset.prefetch_related(
|
||||
'location',
|
||||
'sales_order',
|
||||
'purchase_order',
|
||||
'part',
|
||||
'part__category',
|
||||
'part__pricing_data',
|
||||
'supplier_part',
|
||||
'supplier_part__manufacturer_part',
|
||||
'supplier_part__tags',
|
||||
'test_results',
|
||||
'tags',
|
||||
)
|
||||
|
||||
# Annotate the queryset with the total allocated to sales orders
|
||||
@ -280,7 +287,7 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer):
|
||||
status_text = serializers.CharField(source='get_status_display', read_only=True)
|
||||
|
||||
# Optional detail fields, which can be appended via query parameters
|
||||
supplier_part_detail = SupplierPartSerializer(source='supplier_part', many=False, read_only=True)
|
||||
supplier_part_detail = SupplierPartSerializer(source='supplier_part', supplier_detail=False, manufacturer_detail=False, part_detail=False, many=False, read_only=True)
|
||||
part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
|
||||
location_detail = LocationBriefSerializer(source='location', many=False, read_only=True)
|
||||
tests = StockItemTestResultSerializer(source='test_results', many=True, read_only=True)
|
||||
|
@ -557,6 +557,42 @@ class StockItemListTest(StockAPITestCase):
|
||||
|
||||
self.assertEqual(len(dataset), 17)
|
||||
|
||||
def test_query_count(self):
|
||||
"""Test that the number of queries required to fetch stock items is reasonable."""
|
||||
|
||||
def get_stock(data):
|
||||
"""Helper function to fetch stock items."""
|
||||
response = self.client.get(self.list_url, data=data)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
return response.data
|
||||
|
||||
# Create a bunch of StockItem objects
|
||||
prt = Part.objects.first()
|
||||
|
||||
StockItem.objects.bulk_create([
|
||||
StockItem(
|
||||
part=prt,
|
||||
quantity=1,
|
||||
level=0, tree_id=0, lft=0, rght=0,
|
||||
) for _ in range(100)
|
||||
])
|
||||
|
||||
# List *all* stock items
|
||||
with self.assertNumQueriesLessThan(25):
|
||||
get_stock({})
|
||||
|
||||
# List all stock items, with part detail
|
||||
with self.assertNumQueriesLessThan(20):
|
||||
get_stock({'part_detail': True})
|
||||
|
||||
# List all stock items, with supplier_part detail
|
||||
with self.assertNumQueriesLessThan(20):
|
||||
get_stock({'supplier_part_detail': True})
|
||||
|
||||
# List all stock items, with 'location' and 'tests' detail
|
||||
with self.assertNumQueriesLessThan(20):
|
||||
get_stock({'location_detail': True, 'tests': True})
|
||||
|
||||
|
||||
class StockItemTest(StockAPITestCase):
|
||||
"""Series of API tests for the StockItem API."""
|
||||
|
Loading…
Reference in New Issue
Block a user