Tree query improvements (#3443)

* Allow part category table to be ordered by part count

* Add queryset annotation for part-category part-count

- Uses subquery to annotate the part-count for sub-categories
- Huge reduction in number of queries

* Update 'pathstring' property of PartCategory and StockLocation

- No longer a dynamically calculated value
- Constructed when the model is saved, and then written to the database
- Limited to 250 characters

* Data migration to re-construct pathstring for PartCategory objects

* Fix for tree model save() method

* Add unit tests for pathstring construction

* Data migration for StockLocation pathstring values

* Update part API

- Add new annotation to PartLocationDetail view

* Update API version

* Apply similar annotation to StockLocation API endpoints

* Extra tests for PartCategory API

* Unit test fixes

* Allow PartCategory and StockLocation lists to be sorted by 'pathstring'

* Further unit test fixes
This commit is contained in:
Oliver 2022-08-01 13:43:27 +10:00 committed by GitHub
parent 1306db74b2
commit 175d9555b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 478 additions and 21 deletions

View File

@ -2,11 +2,18 @@
# InvenTree API version
INVENTREE_API_VERSION = 68
INVENTREE_API_VERSION = 69
"""
Increment this API version number whenever there is a significant change to the API that any clients need to know about
v69 -> 2022-08-01 : https://github.com/inventree/InvenTree/pull/3443
- Updates the PartCategory list API:
- Improve query efficiency: O(n) becomes O(1)
- Rename 'parts' field to 'part_count'
- Updates the StockLocation list API:
- Improve query efficiency: O(n) becomes O(1)
v68 -> 2022-07-27 : https://github.com/inventree/InvenTree/pull/3417
- Allows SupplierPart list to be filtered by SKU value
- Allows SupplierPart list to be filtered by MPN value

View File

@ -54,6 +54,42 @@ def generateTestKey(test_name):
return key
def constructPathString(path, max_chars=250):
"""Construct a 'path string' for the given path.
Arguments:
path: A list of strings e.g. ['path', 'to', 'location']
max_chars: Maximum number of characters
"""
pathstring = '/'.join(path)
idx = 0
# Replace middle elements to limit the pathstring
if len(pathstring) > max_chars:
mid = len(path) // 2
path_l = path[0:mid]
path_r = path[mid:]
# Ensure the pathstring length is limited
while len(pathstring) > max_chars:
# Remove an element from the list
if idx % 2 == 0:
path_l = path_l[:-1]
else:
path_r = path_r[1:]
subpath = path_l + ['...'] + path_r
pathstring = '/'.join(subpath)
idx += 1
return pathstring
def getMediaUrl(filename):
"""Return the qualified access path for the given file, under the media directory."""
return os.path.join(MEDIA_URL, str(filename))

View File

@ -500,14 +500,25 @@ class InvenTreeTree(MPTTModel):
}
def save(self, *args, **kwargs):
"""Provide better error for invalid moves."""
"""Custom save method for InvenTreeTree abstract model"""
try:
super().save(*args, **kwargs)
except InvalidMove:
# Provide better error for parent selection
raise ValidationError({
'parent': _("Invalid choice"),
})
# Re-calculate the 'pathstring' field
pathstring = InvenTree.helpers.constructPathString(
[item.name for item in self.path]
)
if pathstring != self.pathstring:
self.pathstring = pathstring
super().save(force_update=True)
class Meta:
"""Metaclass defines extra model properties."""
@ -542,6 +553,14 @@ class InvenTreeTree(MPTTModel):
verbose_name=_("parent"),
related_name='children')
# The 'pathstring' field is calculated each time the model is saved
pathstring = models.CharField(
blank=True,
max_length=250,
verbose_name=_('Path'),
help_text=_('Path')
)
@property
def item_count(self):
"""Return the number of items which exist *under* this node in the tree.
@ -612,14 +631,6 @@ class InvenTreeTree(MPTTModel):
"""
return self.parentpath + [self]
@property
def pathstring(self):
"""Get a string representation for the path of this item.
e.g. "Top/Second/Third/This"
"""
return '/'.join([item.name for item in self.path])
def __str__(self):
"""String representation of a category is the full path to that category."""
return "{path} - {desc}".format(path=self.pathstring, desc=self.description)

View File

@ -53,6 +53,13 @@ class CategoryList(ListCreateAPI):
queryset = PartCategory.objects.all()
serializer_class = part_serializers.CategorySerializer
def get_queryset(self, *args, **kwargs):
"""Return an annotated queryset for the CategoryList endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
return queryset
def get_serializer_context(self):
"""Add extra context data to the serializer for the PartCategoryList endpoint"""
ctx = super().get_serializer_context()
@ -141,9 +148,11 @@ class CategoryList(ListCreateAPI):
ordering_fields = [
'name',
'pathstring',
'level',
'tree_id',
'lft',
'part_count',
]
# Use hierarchical ordering by default
@ -165,6 +174,13 @@ class CategoryDetail(RetrieveUpdateDestroyAPI):
serializer_class = part_serializers.CategorySerializer
queryset = PartCategory.objects.all()
def get_queryset(self, *args, **kwargs):
"""Return an annotated queryset for the CategoryDetail endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
return queryset
def get_serializer_context(self):
"""Add extra context to the serializer for the CategoryDetail endpoint"""
ctx = super().get_serializer_context()

View File

@ -1,4 +1,4 @@
"""Custom query filters for the Part model
"""Custom query filters for the Part models
The code here makes heavy use of subquery annotations!
@ -19,11 +19,13 @@ Relevant PRs:
from decimal import Decimal
from django.db import models
from django.db.models import F, FloatField, Func, OuterRef, Q, Subquery
from django.db.models import (F, FloatField, Func, IntegerField, OuterRef, Q,
Subquery)
from django.db.models.functions import Coalesce
from sql_util.utils import SubquerySum
import part.models
import stock.models
from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatus,
SalesOrderStatus)
@ -158,3 +160,29 @@ def annotate_variant_quantity(subquery: Q, reference: str = 'quantity'):
0,
output_field=FloatField(),
)
def annotate_category_parts():
"""Construct a queryset annotation which returns the number of parts in a particular category.
- Includes parts in subcategories also
- Requires subquery to perform annotation
"""
# Construct a subquery to provide all parts in this category and any subcategories:
subquery = part.models.Part.objects.exclude(category=None).filter(
category__tree_id=OuterRef('tree_id'),
category__lft__gte=OuterRef('lft'),
category__rght__lte=OuterRef('rght'),
category__level__gte=OuterRef('level'),
)
return Coalesce(
Subquery(
subquery.annotate(
total=Func(F('pk'), function='COUNT', output_field=IntegerField())
).values('total'),
),
0,
output_field=IntegerField()
)

View File

@ -0,0 +1,18 @@
# Generated by Django 3.2.14 on 2022-07-31 23:54
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('part', '0081_alter_partcategory_name'),
]
operations = [
migrations.AddField(
model_name='partcategory',
name='pathstring',
field=models.CharField(blank=True, help_text='Path', max_length=250, verbose_name='Path'),
),
]

View File

@ -0,0 +1,54 @@
# Generated by Django 3.2.14 on 2022-07-31 23:57
from django.db import migrations
from InvenTree.helpers import constructPathString
def update_pathstring(apps, schema_editor):
"""Construct pathstring for all existing PartCategory objects"""
PartCategory = apps.get_model('part', 'partcategory')
n = PartCategory.objects.count()
if n > 0:
for cat in PartCategory.objects.all():
# Construct complete path for category
path = [cat.name]
parent = cat.parent
# Iterate up the tree
while parent is not None:
path = [parent.name] + path
parent = parent.parent
pathstring = constructPathString(path)
cat.pathstring = pathstring
cat.save()
print(f"\n--- Updated 'pathstring' for {n} PartCategory objects ---\n")
def nupdate_pathstring(apps, schema_editor):
"""Empty function for reverse migration compatibility"""
pass
class Migration(migrations.Migration):
dependencies = [
('part', '0082_partcategory_pathstring'),
]
operations = [
migrations.RunPython(
update_pathstring,
reverse_code=nupdate_pathstring
)
]

View File

@ -41,9 +41,20 @@ class CategorySerializer(InvenTreeModelSerializer):
"""Return True if the category is directly "starred" by the current user."""
return category in self.context.get('starred_categories', [])
@staticmethod
def annotate_queryset(queryset):
"""Annotate extra information to the queryset"""
# Annotate the number of 'parts' which exist in each category (including subcategories!)
queryset = queryset.annotate(
part_count=part.filters.annotate_category_parts()
)
return queryset
url = serializers.CharField(source='get_absolute_url', read_only=True)
parts = serializers.IntegerField(source='item_count', read_only=True)
part_count = serializers.IntegerField(read_only=True)
level = serializers.IntegerField(read_only=True)
@ -60,7 +71,7 @@ class CategorySerializer(InvenTreeModelSerializer):
'default_keywords',
'level',
'parent',
'parts',
'part_count',
'pathstring',
'starred',
'url',

View File

@ -77,6 +77,76 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 5)
# Check that the required fields are present
fields = [
'pk',
'name',
'description',
'default_location',
'level',
'parent',
'part_count',
'pathstring',
'url'
]
for result in response.data:
for f in fields:
self.assertIn(f, result)
def test_part_count(self):
"""Test that the 'part_count' field is annotated correctly"""
url = reverse('api-part-category-list')
# Create a parent category
cat = PartCategory.objects.create(
name='Parent Cat',
description='Some name',
parent=None
)
# Create child categories
for ii in range(10):
child = PartCategory.objects.create(
name=f"Child cat {ii}",
description="A child category",
parent=cat
)
# Create parts in this category
for jj in range(10):
Part.objects.create(
name=f"Part xyz {jj}",
description="A test part",
category=child
)
# Filter by parent category
response = self.get(
url,
{
'parent': cat.pk,
},
expected_code=200
)
# 10 child categories
self.assertEqual(len(response.data), 10)
for result in response.data:
self.assertEqual(result['parent'], cat.pk)
self.assertEqual(result['part_count'], 10)
# Detail view for parent category
response = self.get(
f'/api/part/category/{cat.pk}/',
expected_code=200
)
# Annotation should include parts from all sub-categories
self.assertEqual(response.data['part_count'], 100)
def test_category_metadata(self):
"""Test metadata endpoint for the PartCategory."""
cat = PartCategory.objects.get(pk=1)

View File

@ -1,5 +1,6 @@
"""Unit tests for the PartCategory model"""
from django.core.exceptions import ValidationError
from django.test import TestCase
from .models import Part, PartCategory, PartParameter, PartParameterTemplate
@ -63,9 +64,69 @@ class CategoryTest(TestCase):
def test_path_string(self):
"""Test that the category path string works correctly."""
# Note that due to data migrations, these fields need to be saved first
self.resistors.save()
self.transceivers.save()
self.assertEqual(str(self.resistors), 'Electronics/Resistors - Resistors')
self.assertEqual(str(self.transceivers.pathstring), 'Electronics/IC/Transceivers')
# Create a new subcategory
subcat = PartCategory.objects.create(
name='Subcategory',
description='My little sub category',
parent=self.transceivers
)
# Pathstring should have been updated correctly
self.assertEqual(subcat.pathstring, 'Electronics/IC/Transceivers/Subcategory')
self.assertEqual(len(subcat.path), 4)
# Move to a new parent location
subcat.parent = self.resistors
subcat.save()
self.assertEqual(subcat.pathstring, 'Electronics/Resistors/Subcategory')
self.assertEqual(len(subcat.path), 3)
# Move to top-level
subcat.parent = None
subcat.save()
self.assertEqual(subcat.pathstring, 'Subcategory')
self.assertEqual(len(subcat.path), 1)
# Construct a very long pathstring and ensure it gets updated correctly
cat = PartCategory.objects.create(
name='Cat',
description='A long running category',
parent=None
)
parent = cat
for idx in range(26):
letter = chr(ord('A') + idx)
child = PartCategory.objects.create(
name=letter * 10,
description=f"Subcategory {letter}",
parent=parent
)
parent = child
self.assertTrue(len(child.path), 26)
self.assertEqual(
child.pathstring,
"Cat/AAAAAAAAAA/BBBBBBBBBB/CCCCCCCCCC/DDDDDDDDDD/EEEEEEEEEE/FFFFFFFFFF/GGGGGGGGGG/HHHHHHHHHH/IIIIIIIIII/JJJJJJJJJJ/.../OOOOOOOOOO/PPPPPPPPPP/QQQQQQQQQQ/RRRRRRRRRR/SSSSSSSSSS/TTTTTTTTTT/UUUUUUUUUU/VVVVVVVVVV/WWWWWWWWWW/XXXXXXXXXX/YYYYYYYYYY/ZZZZZZZZZZ"
)
self.assertTrue(len(child.pathstring) <= 250)
# Attempt an invalid move
with self.assertRaises(ValidationError):
cat.parent = child
cat.save()
def test_url(self):
"""Test that the PartCategory URL works."""
self.assertEqual(self.capacitors.get_absolute_url(), '/part/category/3/')
@ -130,6 +191,9 @@ class CategoryTest(TestCase):
def test_default_locations(self):
"""Test traversal for default locations."""
self.assertIsNotNone(self.fasteners.default_location)
self.fasteners.default_location.save()
self.assertEqual(str(self.fasteners.default_location), 'Office/Drawer_1 - In my desk')
# Any part under electronics should default to 'Home'

View File

@ -220,6 +220,7 @@ class PartTest(TestCase):
def test_category(self):
"""Test PartCategory path"""
self.c1.category.save()
self.assertEqual(str(self.c1.category), 'Electronics/Capacitors - Capacitors')
orphan = Part.objects.get(name='Orphan')

View File

@ -224,6 +224,13 @@ class StockLocationList(ListCreateAPI):
queryset = StockLocation.objects.all()
serializer_class = StockSerializers.LocationSerializer
def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for the StockLocationList endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = StockSerializers.LocationSerializer.annotate_queryset(queryset)
return queryset
def filter_queryset(self, queryset):
"""Custom filtering: - Allow filtering by "null" parent to retrieve top-level stock locations."""
queryset = super().filter_queryset(queryset)
@ -293,6 +300,7 @@ class StockLocationList(ListCreateAPI):
ordering_fields = [
'name',
'pathstring',
'items',
'level',
'tree_id',
@ -1340,6 +1348,13 @@ class LocationDetail(RetrieveUpdateDestroyAPI):
queryset = StockLocation.objects.all()
serializer_class = StockSerializers.LocationSerializer
def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for the StockLocationList endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = StockSerializers.LocationSerializer.annotate_queryset(queryset)
return queryset
stock_api_urls = [
re_path(r'^location/', include([

View File

@ -0,0 +1,36 @@
"""Custom query filters for the Stock models"""
from django.db.models import F, Func, IntegerField, OuterRef, Q, Subquery
from django.db.models.functions import Coalesce
import stock.models
def annotate_location_items(filter: Q = None):
"""Construct a queryset annotation which returns the number of stock items in a particular location.
- Includes items in subcategories also
- Requires subquery to perform annotation
"""
# Construct a subquery to provide all items in this location and any sublocations
subquery = stock.models.StockItem.objects.exclude(location=None).filter(
location__tree_id=OuterRef('tree_id'),
location__lft__gte=OuterRef('lft'),
location__rght__lte=OuterRef('rght'),
location__level__gte=OuterRef('level'),
)
# Optionally apply extra filter to returned results
if filter is not None:
subquery = subquery.filter(filter)
return Coalesce(
Subquery(
subquery.annotate(
total=Func(F('pk'), function='COUNT', output_field=IntegerField())
).values('total')
),
0,
output_field=IntegerField()
)

View File

@ -0,0 +1,18 @@
# Generated by Django 3.2.14 on 2022-07-31 23:54
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('stock', '0079_alter_stocklocation_name'),
]
operations = [
migrations.AddField(
model_name='stocklocation',
name='pathstring',
field=models.CharField(blank=True, help_text='Path', max_length=250, verbose_name='Path'),
),
]

View File

@ -0,0 +1,55 @@
# Generated by Django 3.2.14 on 2022-08-01 00:44
from django.db import migrations
from InvenTree.helpers import constructPathString
def update_pathstring(apps, schema_editor):
"""Construct pathstring for all existing StockLocation objects"""
StockLocation = apps.get_model('stock', 'stocklocation')
n = StockLocation.objects.count()
if n > 0:
for loc in StockLocation.objects.all():
# Construct complete path for category
path = [loc.name]
parent = loc.parent
# Iterate up the tree
while parent is not None:
path = [parent.name] + path
parent = parent.parent
pathstring = constructPathString(path)
loc.pathstring = pathstring
loc.save()
print(f"\n--- Updated 'pathstring' for {n} StockLocation objects ---\n")
def nupdate_pathstring(apps, schema_editor):
"""Empty function for reverse migration compatibility"""
pass
class Migration(migrations.Migration):
dependencies = [
('stock', '0080_stocklocation_pathstring'),
]
operations = [
migrations.RunPython(
update_pathstring,
reverse_code=nupdate_pathstring
)
]

View File

@ -18,6 +18,7 @@ import company.models
import InvenTree.helpers
import InvenTree.serializers
import part.models as part_models
import stock.filters
from common.settings import currency_code_default, currency_code_mappings
from company.serializers import SupplierPartSerializer
from InvenTree.models import extract_int
@ -575,9 +576,20 @@ class LocationTreeSerializer(InvenTree.serializers.InvenTreeModelSerializer):
class LocationSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Detailed information about a stock location."""
@staticmethod
def annotate_queryset(queryset):
"""Annotate extra information to the queryset"""
# Annotate the number of stock items which exist in this category (including subcategories)
queryset = queryset.annotate(
items=stock.filters.annotate_location_items()
)
return queryset
url = serializers.CharField(source='get_absolute_url', read_only=True)
items = serializers.IntegerField(source='item_count', read_only=True)
items = serializers.IntegerField(read_only=True)
level = serializers.IntegerField(read_only=True)

View File

@ -125,6 +125,10 @@ class StockTest(InvenTreeTestCase):
def test_parent_locations(self):
"""Test parent."""
# Ensure pathstring gets updated
self.drawer3.save()
self.assertEqual(self.office.parent, None)
self.assertEqual(self.drawer1.parent, self.office)
self.assertEqual(self.drawer2.parent, self.office)

View File

@ -1778,6 +1778,7 @@ function loadPartCategoryTable(table, options) {
name: 'category',
original: original,
showColumns: true,
sortable: true,
buttons: options.allowTreeView ? [
{
icon: 'fas fa-bars',
@ -1882,13 +1883,13 @@ function loadPartCategoryTable(table, options) {
title: '{% trans "Path" %}',
switchable: !tree_view,
visible: !tree_view,
sortable: false,
sortable: true,
},
{
field: 'parts',
field: 'part_count',
title: '{% trans "Parts" %}',
switchable: true,
sortable: false,
sortable: true,
}
]
});

View File

@ -2264,6 +2264,7 @@ function loadStockLocationTable(table, options) {
queryParams: filters,
name: 'location',
original: original,
sortable: true,
showColumns: true,
onPostBody: function() {
@ -2360,14 +2361,13 @@ function loadStockLocationTable(table, options) {
field: 'pathstring',
title: '{% trans "Path" %}',
switchable: true,
sortable: false,
sortable: true,
},
{
field: 'items',
title: '{% trans "Stock Items" %}',
switchable: true,
sortable: false,
sortName: 'item_count',
sortable: true,
}
]
});