Improve deletion behaviour for InvenTreeTree model (#5806)

* Improve deletion behaviour for InvenTreeTree model

- Remove recursive call to function
- Handle database operations as bulk queries
- Ensure child nodes have their pathstring updated correctly
- Remove old @receiver hook
- Refactor StockLocation.delete method
- Refactor PartCategory.delete method
- Atomic transactions potentially problematic here

* Add docstring

* Fix method name

* Use bulk-update instead of recursive save when pathstring changes

* Improvements for tree delete method

- Handle case where item has already been deleted

* Raise exception rather than simply logging

* Update unit tests

* Improvements to unrelated unit test

* Fix urls.md

* Fix typo
This commit is contained in:
Oliver 2023-10-30 06:57:40 +11:00 committed by GitHub
parent a1f9260da6
commit fd0a57c4a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 254 additions and 114 deletions

View File

@ -12,7 +12,7 @@ from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.signals import post_save, pre_delete
from django.db.models.signals import post_save
from django.dispatch import receiver
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
@ -580,6 +580,10 @@ class InvenTreeTree(MPTTModel):
parent: The item immediately above this one. An item with a null parent is a top-level item
"""
# How items (not nodes) are hooked into the tree
# e.g. for StockLocation, this value is 'location'
ITEM_PARENT_KEY = None
class Meta:
"""Metaclass defines extra model properties."""
abstract = True
@ -588,6 +592,106 @@ class InvenTreeTree(MPTTModel):
"""Set insert order."""
order_insertion_by = ['name']
def delete(self, delete_children=False, delete_items=False):
"""Handle the deletion of a tree node.
1. Update nodes and items under the current node
2. Delete this node
3. Rebuild the model tree
4. Rebuild the path for any remaining lower nodes
"""
tree_id = self.tree_id if self.parent else None
# Ensure that we have the latest version of the database object
try:
self.refresh_from_db()
except self.__class__.DoesNotExist:
# If the object no longer exists, raise a ValidationError
raise ValidationError("Object %s of type %s no longer exists", str(self), str(self.__class__))
# Cache node ID values for lower nodes, before we delete this one
lower_nodes = list(self.get_descendants(include_self=False).values_list('pk', flat=True))
# 1. Update nodes and items under the current node
self.handle_tree_delete(delete_children=delete_children, delete_items=delete_items)
# 2. Delete *this* node
super().delete()
# 3. Update the tree structure
if tree_id:
self.__class__.objects.partial_rebuild(tree_id)
else:
self.__class__.objects.rebuild()
# 4. Rebuild the path for any remaining lower nodes
nodes = self.__class__.objects.filter(pk__in=lower_nodes)
nodes_to_update = []
for node in nodes:
new_path = node.construct_pathstring()
if new_path != node.pathstring:
node.pathstring = new_path
nodes_to_update.append(node)
if len(nodes_to_update) > 0:
self.__class__.objects.bulk_update(nodes_to_update, ['pathstring'])
def handle_tree_delete(self, delete_children=False, delete_items=False):
"""Delete a single instance of the tree, based on provided kwargs.
Removing a tree "node" from the database must be considered carefully,
based on what the user intends for any items which exist *under* that node.
- "children" are any nodes which exist *under* this node (e.g. PartCategory)
- "items" are any items which exist *under* this node (e.g. Part)
Arguments:
delete_children: If True, delete all child items
delete_items: If True, delete all items associated with this node
There are multiple scenarios we can consider here:
A) delete_children = True and delete_items = True
B) delete_children = True and delete_items = False
C) delete_children = False and delete_items = True
D) delete_children = False and delete_items = False
"""
# Case A: Delete all child items, and all child nodes.
# - Delete all items at any lower level
# - Delete all descendant nodes
if delete_children and delete_items:
self.get_items(cascade=True).delete()
self.get_descendants(include_self=False).delete()
# Case B: Delete all child nodes, but move all child items up to the parent
# - Move all items at any lower level to the parent of this item
# - Delete all descendant nodes
elif delete_children and not delete_items:
self.get_items(cascade=True).update(**{
self.ITEM_PARENT_KEY: self.parent
})
self.get_descendants(include_self=False).delete()
# Case C: Delete all child items, but keep all child nodes
# - Remove all items directly associated with this node
# - Move any direct child nodes up one level
elif not delete_children and delete_items:
self.get_items(cascade=False).delete()
self.get_children().update(parent=self.parent)
# Case D: Keep all child items, and keep all child nodes
# - Move all items directly associated with this node up one level
# - Move any direct child nodes up one level
elif not delete_children and not delete_items:
self.get_items(cascade=False).update(**{
self.ITEM_PARENT_KEY: self.parent
})
self.get_children().update(parent=self.parent)
def validate_unique(self, exclude=None):
"""Validate that this tree instance satisfies our uniqueness requirements.
@ -614,6 +718,12 @@ class InvenTreeTree(MPTTModel):
}
}
def construct_pathstring(self):
"""Construct the pathstring for this tree node"""
return InvenTree.helpers.constructPathString(
[item.name for item in self.path]
)
def save(self, *args, **kwargs):
"""Custom save method for InvenTreeTree abstract model"""
try:
@ -625,9 +735,7 @@ class InvenTreeTree(MPTTModel):
})
# Re-calculate the 'pathstring' field
pathstring = InvenTree.helpers.constructPathString(
[item.name for item in self.path]
)
pathstring = self.construct_pathstring()
if pathstring != self.pathstring:
@ -639,9 +747,20 @@ class InvenTreeTree(MPTTModel):
self.pathstring = pathstring
super().save(*args, **kwargs)
# Ensure that the pathstring changes are propagated down the tree also
for child in self.get_children():
child.save(*args, **kwargs)
# Update the pathstring for any child nodes
lower_nodes = self.get_descendants(include_self=False)
nodes_to_update = []
for node in lower_nodes:
new_path = node.construct_pathstring()
if new_path != node.pathstring:
node.pathstring = new_path
nodes_to_update.append(node)
if len(nodes_to_update) > 0:
self.__class__.objects.bulk_update(nodes_to_update, ['pathstring'])
name = models.CharField(
blank=False,
@ -673,16 +792,15 @@ class InvenTreeTree(MPTTModel):
help_text=_('Path')
)
@property
def item_count(self):
"""Return the number of items which exist *under* this node in the tree.
def get_items(self, cascade=False):
"""Return a queryset of items which exist *under* this node in the tree.
Here an 'item' is considered to be the 'leaf' at the end of each branch,
and the exact nature here will depend on the class implementation.
- For a StockLocation instance, this would be a queryset of StockItem objects
- For a PartCategory instance, this would be a queryset of Part objects
The default implementation returns zero
The default implementation returns an empty list
"""
return 0
raise NotImplementedError(f"items() method not implemented for {type(self)}")
def getUniqueParents(self):
"""Return a flat set of all parent items that exist above this node.
@ -878,18 +996,6 @@ class InvenTreeBarcodeMixin(models.Model):
self.save()
@receiver(pre_delete, sender=InvenTreeTree, dispatch_uid='tree_pre_delete_log')
def before_delete_tree_item(sender, instance, using, **kwargs):
"""Receives pre_delete signal from InvenTreeTree object.
Before an item is deleted, update each child object to point to the parent of the object being deleted.
"""
# Update each tree item below this one
for child in instance.children.all():
child.parent = instance.parent
child.save()
@receiver(post_save, sender=Error, dispatch_uid='error_post_save_notification')
def after_error_logged(sender, instance: Error, created: bool, **kwargs):
"""Callback when a server error is logged.

View File

@ -72,55 +72,23 @@ class PartCategory(MetadataMixin, InvenTreeTree):
default_keywords: Default keywords for parts created in this category
"""
ITEM_PARENT_KEY = 'category'
class Meta:
"""Metaclass defines extra model properties"""
verbose_name = _("Part Category")
verbose_name_plural = _("Part Categories")
def delete_recursive(self, *args, **kwargs):
"""This function handles the recursive deletion of subcategories depending on kwargs contents"""
delete_parts = kwargs.get('delete_parts', False)
parent_category = kwargs.get('parent_category', None)
if parent_category is None:
# First iteration, (no part_category kwargs passed)
parent_category = self.parent
for child_part in self.parts.all():
if delete_parts:
child_part.delete()
else:
child_part.category = parent_category
child_part.save()
for child_category in self.children.all():
if kwargs.get('delete_child_categories', False):
child_category.delete_recursive(**{
"delete_child_categories": True,
"delete_parts": delete_parts,
"parent_category": parent_category})
else:
child_category.parent = parent_category
child_category.save()
super().delete(*args, **{})
def delete(self, *args, **kwargs):
"""Custom model deletion routine, which updates any child categories or parts.
This must be handled within a transaction.atomic(), otherwise the tree structure is damaged
"""
with transaction.atomic():
self.delete_recursive(**{
"delete_parts": kwargs.get('delete_parts', False),
"delete_child_categories": kwargs.get('delete_child_categories', False),
"parent_category": self.parent})
if self.parent is not None:
# Partially rebuild the tree (cheaper than a complete rebuild)
PartCategory.objects.partial_rebuild(self.tree_id)
else:
PartCategory.objects.rebuild()
super().delete(
delete_children=kwargs.get('delete_child_categories', False),
delete_items=kwargs.get('delete_parts', False),
)
default_location = TreeForeignKey(
'stock.StockLocation', related_name="default_categories",
@ -189,6 +157,10 @@ class PartCategory(MetadataMixin, InvenTreeTree):
"""Return the number of parts contained in this PartCategory"""
return self.partcount()
def get_items(self, cascade=False):
"""Return a queryset containing the parts which exist in this category"""
return self.get_parts(cascade=cascade)
def partcount(self, cascade=True, active=False):
"""Return the total part count under this category (including children of child categories)."""
query = self.get_parts(cascade=cascade)

View File

@ -237,9 +237,9 @@ class BomUploadTest(InvenTreeAPITestCase):
components = Part.objects.filter(component=True)
for idx, _ in enumerate(components):
for component in components:
dataset.append([
f"Component {idx}",
component.name,
10,
])
@ -266,9 +266,9 @@ class BomUploadTest(InvenTreeAPITestCase):
dataset.headers = ['part_ipn', 'quantity']
for idx, _ in enumerate(components):
for component in components:
dataset.append([
f"CMP_{idx}",
component.IPN,
10,
])

View File

@ -248,6 +248,18 @@ class CategoryTest(TestCase):
C32 = PartCategory.objects.create(name='C32', parent=B3)
C33 = PartCategory.objects.create(name='C33', parent=B3)
D31 = PartCategory.objects.create(name='D31', parent=C31)
D32 = PartCategory.objects.create(name='D32', parent=C32)
D33 = PartCategory.objects.create(name='D33', parent=C33)
E33 = PartCategory.objects.create(name='E33', parent=D33)
# Check that pathstrings have been generated correctly
self.assertEqual(B3.pathstring, 'A/B3')
self.assertEqual(C11.pathstring, 'A/B1/C11')
self.assertEqual(C22.pathstring, 'A/B2/C22')
self.assertEqual(C33.pathstring, 'A/B3/C33')
# Check that the tree_id value is correct
for cat in [B1, B2, B3, C11, C22, C33]:
self.assertEqual(cat.tree_id, A.tree_id)
@ -289,6 +301,8 @@ class CategoryTest(TestCase):
self.assertEqual(cat.get_ancestors().count(), 1)
self.assertEqual(cat.get_ancestors()[0], A)
self.assertEqual(cat.pathstring, f'A/{cat.name}')
# Now, delete category A
A.delete()
@ -302,6 +316,13 @@ class CategoryTest(TestCase):
self.assertEqual(loc.level, 0)
self.assertEqual(loc.parent, None)
# Pathstring should be the same as the name
self.assertEqual(loc.pathstring, loc.name)
# Test pathstring for direct children
for child in loc.get_children():
self.assertEqual(child.pathstring, f'{loc.name}/{child.name}')
# Check descendants for B1
descendants = B1.get_descendants()
self.assertEqual(descendants.count(), 3)
@ -321,6 +342,8 @@ class CategoryTest(TestCase):
self.assertEqual(ancestors[0], B1)
self.assertEqual(ancestors[1], loc)
self.assertEqual(loc.pathstring, f'B1/{loc.name}')
# Check category C2x, should be B2 -> C2x
for loc in [C21, C22, C23]:
loc.refresh_from_db()
@ -332,3 +355,65 @@ class CategoryTest(TestCase):
self.assertEqual(ancestors.count(), 2)
self.assertEqual(ancestors[0], B2)
self.assertEqual(ancestors[1], loc)
self.assertEqual(loc.pathstring, f'B2/{loc.name}')
# Check category D3x, should be C3x -> D3x
D31.refresh_from_db()
self.assertEqual(D31.pathstring, 'C31/D31')
D32.refresh_from_db()
self.assertEqual(D32.pathstring, 'C32/D32')
D33.refresh_from_db()
self.assertEqual(D33.pathstring, 'C33/D33')
# Check category E33
E33.refresh_from_db()
self.assertEqual(E33.pathstring, 'C33/D33/E33')
# Change the name of an upper level
C33.name = '-C33-'
C33.save()
D33.refresh_from_db()
self.assertEqual(D33.pathstring, '-C33-/D33')
E33.refresh_from_db()
self.assertEqual(E33.pathstring, '-C33-/D33/E33')
# Test the "delete child categories" functionality
C33.delete(delete_child_categories=True)
# Any child underneath C33 should have been deleted
for cat in [D33, E33]:
with self.assertRaises(PartCategory.DoesNotExist):
cat.refresh_from_db()
Part.objects.all().delete()
# Create some sample parts under D32
for ii in range(10):
Part.objects.create(
name=f'Part D32 {ii}',
description='A test part',
category=D32,
)
self.assertEqual(Part.objects.filter(category=D32).count(), 10)
self.assertEqual(Part.objects.filter(category=C32).count(), 0)
# Delete D32, should move the parts up to C32
D32.delete(delete_child_categories=False, delete_parts=False)
# All parts should have been deleted
self.assertEqual(Part.objects.filter(category=C32).count(), 10)
# Now, delete C32 and delete all parts underneath
C32.delete(delete_parts=True)
# 10 parts should have been deleted from the database
self.assertEqual(Part.objects.count(), 0)
# Finally, try deleting a category which has already been deleted
# should log an exception
with self.assertRaises(ValidationError):
B3.delete()

View File

@ -1418,13 +1418,19 @@ class LocationDetail(CustomRetrieveUpdateDestroyAPI):
def destroy(self, request, *args, **kwargs):
"""Delete a Stock location instance via the API"""
delete_stock_items = 'delete_stock_items' in request.data and request.data['delete_stock_items'] == '1'
delete_sub_locations = 'delete_sub_locations' in request.data and request.data['delete_sub_locations'] == '1'
return super().destroy(request,
*args,
**dict(kwargs,
delete_sub_locations=delete_sub_locations,
delete_stock_items=delete_stock_items))
delete_stock_items = str(request.data.get('delete_stock_items', 0)) == '1'
delete_sub_locations = str(request.data.get('delete_sub_locations', 0)) == '1'
return super().destroy(
request,
*args,
**dict(
kwargs,
delete_sub_locations=delete_sub_locations,
delete_stock_items=delete_stock_items
)
)
stock_api_urls = [

View File

@ -108,6 +108,8 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree):
Stock locations can be hierarchical as required
"""
ITEM_PARENT_KEY = 'location'
objects = StockLocationManager()
class Meta:
@ -118,51 +120,16 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree):
tags = TaggableManager(blank=True)
def delete_recursive(self, *args, **kwargs):
"""This function handles the recursive deletion of sub-locations depending on kwargs contents"""
delete_stock_items = kwargs.get('delete_stock_items', False)
parent_location = kwargs.get('parent_location', None)
if parent_location is None:
# First iteration, (no parent_location kwargs passed)
parent_location = self.parent
for child_item in self.get_stock_items(False):
if delete_stock_items:
child_item.delete()
else:
child_item.location = parent_location
child_item.save()
for child_location in self.children.all():
if kwargs.get('delete_sub_locations', False):
child_location.delete_recursive(**{
"delete_sub_locations": True,
"delete_stock_items": delete_stock_items,
"parent_location": parent_location})
else:
child_location.parent = parent_location
child_location.save()
super().delete(*args, **{})
def delete(self, *args, **kwargs):
"""Custom model deletion routine, which updates any child locations or items.
This must be handled within a transaction.atomic(), otherwise the tree structure is damaged
"""
with transaction.atomic():
self.delete_recursive(**{
"delete_stock_items": kwargs.get('delete_stock_items', False),
"delete_sub_locations": kwargs.get('delete_sub_locations', False),
"parent_category": self.parent})
if self.parent is not None:
# Partially rebuild the tree (cheaper than a complete rebuild)
StockLocation.objects.partial_rebuild(self.tree_id)
else:
StockLocation.objects.rebuild()
super().delete(
delete_children=kwargs.get('delete_sub_locations', False),
delete_items=kwargs.get('delete_stock_items', False),
)
@staticmethod
def get_api_url():
@ -300,6 +267,10 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree):
"""
return self.stock_item_count()
def get_items(self, cascade=False):
"""Return a queryset for all stock items under this category"""
return self.get_stock_items(cascade=cascade)
def generate_batch_code():
"""Generate a default 'batch code' for a new StockItem.