From f9a00b7a903011e899b6518886ab2e2541c713e4 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Nov 2021 22:57:49 +1100 Subject: [PATCH] Adds extra subsctiption functionality for Part and PartCategory - Allows variants and templates - Allows categories and sub-categories - Unit testing --- InvenTree/InvenTree/static/css/inventree.css | 4 - InvenTree/part/api.py | 2 +- InvenTree/part/models.py | 121 ++++++++++++++++--- InvenTree/part/test_part.py | 119 +++++++++++++++++- 4 files changed, 224 insertions(+), 22 deletions(-) diff --git a/InvenTree/InvenTree/static/css/inventree.css b/InvenTree/InvenTree/static/css/inventree.css index 670d577497..273f2ec527 100644 --- a/InvenTree/InvenTree/static/css/inventree.css +++ b/InvenTree/InvenTree/static/css/inventree.css @@ -180,10 +180,6 @@ float: right; } -.starred-part { - color: #ffbb00; -} - .red-cell { background-color: #ec7f7f; } diff --git a/InvenTree/part/api.py b/InvenTree/part/api.py index a11bb1b088..0b754dffe8 100644 --- a/InvenTree/part/api.py +++ b/InvenTree/part/api.py @@ -420,7 +420,7 @@ class PartDetail(generics.RetrieveUpdateDestroyAPIView): if 'starred' in request.data: starred = str2bool(request.data.get('starred', None)) - self.get_object().setStarred(request.user, starred) + self.get_object().set_subscription(request.user, starred) response = super().update(request, *args, **kwargs) diff --git a/InvenTree/part/models.py b/InvenTree/part/models.py index fc7382ac62..0b99b8dac5 100644 --- a/InvenTree/part/models.py +++ b/InvenTree/part/models.py @@ -15,7 +15,7 @@ from django.urls import reverse from django.db import models, transaction from django.db.utils import IntegrityError -from django.db.models import Q, Sum, UniqueConstraint +from django.db.models import Q, Sum, UniqueConstraint, query from django.db.models.functions import Coalesce from django.core.validators import MinValueValidator @@ -201,6 +201,60 @@ class PartCategory(InvenTreeTree): return prefetch.filter(category=self.id) + def get_subscribers(self, include_parents=True): + """ + Return a list of users who subscribe to this PartCategory + """ + + cats = self.get_ancestors(include_self=True) + + subscribers = set() + + if include_parents: + queryset = PartCategoryStar.objects.filter( + category__pk__in=[cat.pk for cat in cats] + ) + else: + queryset = PartCategoryStar.objects.filter( + category=self, + ) + + for result in queryset: + subscribers.add(result.user) + + return [s for s in subscribers] + + def is_subscribed_by(self, user, **kwargs): + """ + Returns True if the specified user subscribes to this category + """ + + return user in self.get_subscribers(**kwargs) + + def set_subscription(self, user, status): + """ + Set the "subscription" status of this PartCategory against the specified user + """ + + if not user: + return + + if self.is_subscribed_by(user) == status: + return + + if status: + PartCategoryStar.objects.create( + category=self, + user=user + ) + else: + # Note that this won't actually stop the user being subscribed, + # if the user is subscribed to a parent category + PartCategoryStar.objects.filter( + category=self, + user=user, + ).delete() + @receiver(pre_delete, sender=PartCategory, dispatch_uid='partcategory_delete_log') def before_delete_part_category(sender, instance, using, **kwargs): @@ -332,7 +386,7 @@ class Part(MPTTModel): context = {} - context['starred'] = self.isStarredBy(request.user) + context['starred'] = self.is_subscribed_by(request.user) context['disabled'] = not self.active # Pre-calculate complex queries so they only need to be performed once @@ -1040,30 +1094,65 @@ class Part(MPTTModel): return self.total_stock - self.allocation_count() + self.on_order - def isStarredBy(self, user): - """ Return True if this part has been starred by a particular user """ - - try: - PartStar.objects.get(part=self, user=user) - return True - except PartStar.DoesNotExist: - return False - - def setStarred(self, user, starred): + def get_subscribers(self, include_variants=True, include_categories=True): """ - Set the "starred" status of this Part for the given user + Return a list of users who are 'subscribed' to this part. + + A user may 'subscribe' to this part in the following ways: + + a) Subscribing to the part instance directly + b) Subscribing to a template part "above" this part (if it is a variant) + c) Subscribing to the part category that this part belongs to + d) Subscribing to a parent category of the category in c) + + """ + + subscribers = set() + + # Start by looking at direct subscriptions to a Part model + queryset = PartStar.objects.all() + + if include_variants: + queryset = queryset.filter( + part__pk__in=[part.pk for part in self.get_ancestors(include_self=True)] + ) + else: + queryset = queryset.filter(part=self) + + for star in queryset: + subscribers.add(star.user) + + if include_categories and self.category: + + for sub in self.category.get_subscribers(): + subscribers.add(sub) + + return [s for s in subscribers] + + def is_subscribed_by(self, user, **kwargs): + """ + Return True if the specified user subscribes to this part + """ + + return user in self.get_subscribers(**kwargs) + + def set_subscription(self, user, status): + """ + Set the "subscription" status of this Part against the specified user """ if not user: return - # Do not duplicate efforts - if self.isStarredBy(user) == starred: + # Already subscribed? + if self.is_subscribed_by(user) == status: return - if starred: + if status: PartStar.objects.create(part=self, user=user) else: + # Note that this won't actually stop the user being subscribed, + # if the user is subscribed to a parent part or category PartStar.objects.filter(part=self, user=user).delete() def need_to_restock(self): diff --git a/InvenTree/part/test_part.py b/InvenTree/part/test_part.py index 1bd9fdf87d..39bb6a39af 100644 --- a/InvenTree/part/test_part.py +++ b/InvenTree/part/test_part.py @@ -11,7 +11,7 @@ from django.core.exceptions import ValidationError import os -from .models import Part, PartCategory, PartTestTemplate +from .models import Part, PartCategory, PartCategoryStar, PartStar, PartTestTemplate from .models import rename_part_image from .templatetags import inventree_extras @@ -347,3 +347,120 @@ class PartSettingsTest(TestCase): with self.assertRaises(ValidationError): part = Part(name='Hello', description='A thing', IPN='IPN123', revision='C') part.full_clean() + + +class PartSubscriptionTests(TestCase): + + fixtures = [ + 'location', + 'category', + 'part', + ] + + def setUp(self): + # Create a user for auth + user = get_user_model() + + self.user = user.objects.create_user( + username='testuser', + email='test@testing.com', + password='password', + is_staff=True + ) + + # electronics / IC / MCU + self.category = PartCategory.objects.get(pk=4) + + self.part = Part.objects.create( + category=self.category, + name='STM32F103', + description='Currently worth a lot of money', + is_template=True, + ) + + def test_part_subcription(self): + """ + Test basic subscription against a part + """ + + # First check that the user is *not* subscribed to the part + self.assertFalse(self.part.is_subscribed_by(self.user)) + + # Now, subscribe directly to the part + self.part.set_subscription(self.user, True) + + self.assertEqual(PartStar.objects.count(), 1) + + self.assertTrue(self.part.is_subscribed_by(self.user)) + + # Now, unsubscribe + self.part.set_subscription(self.user, False) + + self.assertFalse(self.part.is_subscribed_by(self.user)) + + def test_variant_subscription(self): + """ + Test subscription against a parent part + """ + + # Construct a sub-part to star against + sub_part = Part.objects.create( + name='sub_part', + description='a sub part', + variant_of=self.part, + ) + + self.assertFalse(sub_part.is_subscribed_by(self.user)) + + # Subscribe to the "parent" part + self.part.set_subscription(self.user, True) + + self.assertTrue(self.part.is_subscribed_by(self.user)) + self.assertTrue(sub_part.is_subscribed_by(self.user)) + + def test_category_subscription(self): + """ + Test subscription against a PartCategory + """ + + self.assertEqual(PartCategoryStar.objects.count(), 0) + + self.assertFalse(self.part.is_subscribed_by(self.user)) + self.assertFalse(self.category.is_subscribed_by(self.user)) + + # Subscribe to the direct parent category + self.category.set_subscription(self.user, True) + + self.assertEqual(PartStar.objects.count(), 0) + self.assertEqual(PartCategoryStar.objects.count(), 1) + + self.assertTrue(self.category.is_subscribed_by(self.user)) + self.assertTrue(self.part.is_subscribed_by(self.user)) + + # Check that the "parent" category is not starred + self.assertFalse(self.category.parent.is_subscribed_by(self.user)) + + # Un-subscribe + self.category.set_subscription(self.user, False) + + self.assertFalse(self.category.is_subscribed_by(self.user)) + self.assertFalse(self.part.is_subscribed_by(self.user)) + + def test_parent_category_subscription(self): + """ + Check that a parent category can be subscribed to + """ + + # Top-level "electronics" category + cat = PartCategory.objects.get(pk=1) + + cat.set_subscription(self.user, True) + + # Check base category + self.assertTrue(cat.is_subscribed_by(self.user)) + + # Check lower level category + self.assertTrue(self.category.is_subscribed_by(self.user)) + + # Check part + self.assertTrue(self.part.is_subscribed_by(self.user))