diff --git a/InvenTree/label/api.py b/InvenTree/label/api.py index 9a34a1a2a2..9d8092f551 100644 --- a/InvenTree/label/api.py +++ b/InvenTree/label/api.py @@ -23,9 +23,102 @@ from .serializers import (PartLabelSerializer, StockItemLabelSerializer, StockLocationLabelSerializer) -class LabelListView(ListAPI): +class LabelFilterMixin: + """Mixin for filtering a queryset by a list of object ID values. + + Each implementing class defines a database model to lookup, + and a "key" (query parameter) for providing a list of ID (PK) values. + + This mixin defines a 'get_items' method which provides a generic + implementation to return a list of matching database model instances. + """ + + # Database model for instances to actually be "printed" against this label template + ITEM_MODEL = None + + # Default key for looking up database model instances + ITEM_KEY = 'item' + + def get_items(self): + """Return a list of database objects from query parameter""" + + ids = [] + + # Construct a list of possible query parameter value options + # e.g. if self.ITEM_KEY = 'part' -> ['part', 'part', 'parts', parts[]'] + for k in [self.ITEM_KEY + x for x in ['', '[]', 's', 's[]']]: + if ids := self.request.query_params.getlist(k, []): + # Return the first list of matches + break + + # Next we must validate each provided object ID + valid_ids = [] + + for id in ids: + try: + valid_ids.append(int(id)) + except (ValueError): + pass + + # Filter queryset by matching ID values + return self.ITEM_MODEL.objects.filter(pk__in=valid_ids) + + +class LabelListView(LabelFilterMixin, ListAPI): """Generic API class for label templates.""" + def filter_queryset(self, queryset): + """Filter the queryset based on the provided label ID values. + + As each 'label' instance may optionally define its own filters, + the resulting queryset is the 'union' of the two. + """ + + queryset = super().filter_queryset(queryset) + + items = self.get_items() + + if len(items) > 0: + """ + At this point, we are basically forced to be inefficient, + as we need to compare the 'filters' string of each label, + and see if it matches against each of the requested items. + + TODO: In the future, if this becomes excessively slow, it + will need to be readdressed. + """ + valid_label_ids = set() + + for label in queryset.all(): + matches = True + + try: + filters = InvenTree.helpers.validateFilterString(label.filters) + except ValidationError: + continue + + for item in items: + item_query = self.ITEM_MODEL.objects.filter(pk=item.pk) + + try: + if not item_query.filter(**filters).exists(): + matches = False + break + except FieldError: + matches = False + break + + # Matched all items + if matches: + valid_label_ids.add(label.pk) + else: + continue + + # Reduce queryset to only valid matches + queryset = queryset.filter(pk__in=[pk for pk in valid_label_ids]) + + return queryset + filter_backends = [ DjangoFilterBackend, filters.SearchFilter @@ -41,9 +134,13 @@ class LabelListView(ListAPI): ] -class LabelPrintMixin: +class LabelPrintMixin(LabelFilterMixin): """Mixin for printing labels.""" + def get(self, request, *args, **kwargs): + """Perform a GET request against this endpoint to print labels""" + return self.print(request, self.get_items()) + def get_plugin(self, request): """Return the label printing plugin associated with this request. @@ -173,35 +270,7 @@ class LabelPrintMixin: ) -class StockItemLabelMixin: - """Mixin for extracting stock items from query params.""" - - def get_items(self): - """Return a list of requested stock items.""" - items = [] - - params = self.request.query_params - - for key in ['item', 'item[]', 'items', 'items[]']: - if key in params: - items = params.getlist(key, []) - break - - valid_ids = [] - - for item in items: - try: - valid_ids.append(int(item)) - except (ValueError): - pass - - # List of StockItems which match provided values - valid_items = StockItem.objects.filter(pk__in=valid_ids) - - return valid_items - - -class StockItemLabelList(LabelListView, StockItemLabelMixin): +class StockItemLabelList(LabelListView): """API endpoint for viewing list of StockItemLabel objects. Filterable by: @@ -214,59 +283,8 @@ class StockItemLabelList(LabelListView, StockItemLabelMixin): queryset = StockItemLabel.objects.all() serializer_class = StockItemLabelSerializer - def filter_queryset(self, queryset): - """Filter the StockItem label queryset.""" - queryset = super().filter_queryset(queryset) - - # List of StockItem objects to match against - items = self.get_items() - - # We wish to filter by stock items - if len(items) > 0: - """ - At this point, we are basically forced to be inefficient, - as we need to compare the 'filters' string of each label, - and see if it matches against each of the requested items. - - TODO: In the future, if this becomes excessively slow, it - will need to be readdressed. - """ - - # Keep track of which labels match every specified stockitem - valid_label_ids = set() - - for label in queryset.all(): - - matches = True - - # Filter string defined for the StockItemLabel object - try: - filters = InvenTree.helpers.validateFilterString(label.filters) - except ValidationError: # pragma: no cover - continue - - for item in items: - - item_query = StockItem.objects.filter(pk=item.pk) - - try: - if not item_query.filter(**filters).exists(): - matches = False - break - except FieldError: - matches = False - break - - # Matched all items - if matches: - valid_label_ids.add(label.pk) - else: - continue # pragma: no cover - - # Reduce queryset to only valid matches - queryset = queryset.filter(pk__in=[pk for pk in valid_label_ids]) - - return queryset + ITEM_MODEL = StockItem + ITEM_KEY = 'item' class StockItemLabelDetail(RetrieveUpdateDestroyAPI): @@ -276,48 +294,17 @@ class StockItemLabelDetail(RetrieveUpdateDestroyAPI): serializer_class = StockItemLabelSerializer -class StockItemLabelPrint(RetrieveAPI, StockItemLabelMixin, LabelPrintMixin): +class StockItemLabelPrint(LabelPrintMixin, RetrieveAPI): """API endpoint for printing a StockItemLabel object.""" queryset = StockItemLabel.objects.all() serializer_class = StockItemLabelSerializer - def get(self, request, *args, **kwargs): - """Check if valid stock item(s) have been provided.""" - items = self.get_items() - - return self.print(request, items) + ITEM_MODEL = StockItem + ITEM_KEY = 'item' -class StockLocationLabelMixin: - """Mixin for extracting stock locations from query params.""" - - def get_locations(self): - """Return a list of requested stock locations.""" - locations = [] - - params = self.request.query_params - - for key in ['location', 'location[]', 'locations', 'locations[]']: - - if key in params: - locations = params.getlist(key, []) - - valid_ids = [] - - for loc in locations: - try: - valid_ids.append(int(loc)) - except (ValueError): - pass - - # List of StockLocation objects which match provided values - valid_locations = StockLocation.objects.filter(pk__in=valid_ids) - - return valid_locations - - -class StockLocationLabelList(LabelListView, StockLocationLabelMixin): +class StockLocationLabelList(LabelListView): """API endpoint for viewiing list of StockLocationLabel objects. Filterable by: @@ -330,59 +317,8 @@ class StockLocationLabelList(LabelListView, StockLocationLabelMixin): queryset = StockLocationLabel.objects.all() serializer_class = StockLocationLabelSerializer - def filter_queryset(self, queryset): - """Filter the StockLocationLabel queryset.""" - queryset = super().filter_queryset(queryset) - - # List of StockLocation objects to match against - locations = self.get_locations() - - # We wish to filter by stock location(s) - if len(locations) > 0: - """ - At this point, we are basically forced to be inefficient, - as we need to compare the 'filters' string of each label, - and see if it matches against each of the requested items. - - TODO: In the future, if this becomes excessively slow, it - will need to be readdressed. - """ - - valid_label_ids = set() - - for label in queryset.all(): - - matches = True - - # Filter string defined for the StockLocationLabel object - try: - filters = InvenTree.helpers.validateFilterString(label.filters) - except Exception: # pragma: no cover - # Skip if there was an error validating the filters... - continue - - for loc in locations: - - loc_query = StockLocation.objects.filter(pk=loc.pk) - - try: - if not loc_query.filter(**filters).exists(): - matches = False - break - except FieldError: - matches = False - break - - # Matched all items - if matches: - valid_label_ids.add(label.pk) - else: - continue # pragma: no cover - - # Reduce queryset to only valid matches - queryset = queryset.filter(pk__in=[pk for pk in valid_label_ids]) - - return queryset + ITEM_MODEL = StockLocation + ITEM_KEY = 'location' class StockLocationLabelDetail(RetrieveUpdateDestroyAPI): @@ -392,89 +328,24 @@ class StockLocationLabelDetail(RetrieveUpdateDestroyAPI): serializer_class = StockLocationLabelSerializer -class StockLocationLabelPrint(RetrieveAPI, StockLocationLabelMixin, LabelPrintMixin): +class StockLocationLabelPrint(LabelPrintMixin, RetrieveAPI): """API endpoint for printing a StockLocationLabel object.""" queryset = StockLocationLabel.objects.all() seiralizer_class = StockLocationLabelSerializer - def get(self, request, *args, **kwargs): - """Print labels based on the request parameters""" - locations = self.get_locations() - - return self.print(request, locations) + ITEM_MODEL = StockLocation + ITEM_KEY = 'location' -class PartLabelMixin: - """Mixin for extracting Part objects from query parameters.""" - - def get_parts(self): - """Return a list of requested Part objects.""" - parts = [] - - params = self.request.query_params - - for key in ['part', 'part[]', 'parts', 'parts[]']: - if key in params: - parts = params.getlist(key, []) - break - - valid_ids = [] - - for part in parts: - try: - valid_ids.append(int(part)) - except (ValueError): - pass - - # List of Part objects which match provided values - return Part.objects.filter(pk__in=valid_ids) - - -class PartLabelList(LabelListView, PartLabelMixin): +class PartLabelList(LabelListView): """API endpoint for viewing list of PartLabel objects.""" queryset = PartLabel.objects.all() serializer_class = PartLabelSerializer - def filter_queryset(self, queryset): - """Custom queryset filtering for the PartLabel list""" - queryset = super().filter_queryset(queryset) - - parts = self.get_parts() - - if len(parts) > 0: - - valid_label_ids = set() - - for label in queryset.all(): - - matches = True - - try: - filters = InvenTree.helpers.validateFilterString(label.filters) - except ValidationError: # pragma: no cover - continue - - for part in parts: - - part_query = Part.objects.filter(pk=part.pk) - - try: - if not part_query.filter(**filters).exists(): - matches = False - break - except FieldError: - matches = False - break - - if matches: - valid_label_ids.add(label.pk) - - # Reduce queryset to only valid matches - queryset = queryset.filter(pk__in=[pk for pk in valid_label_ids]) - - return queryset + ITEM_MODEL = Part + ITEM_KEY = 'part' class PartLabelDetail(RetrieveUpdateDestroyAPI): @@ -484,17 +355,14 @@ class PartLabelDetail(RetrieveUpdateDestroyAPI): serializer_class = PartLabelSerializer -class PartLabelPrint(RetrieveAPI, PartLabelMixin, LabelPrintMixin): +class PartLabelPrint(LabelPrintMixin, RetrieveAPI): """API endpoint for printing a PartLabel object.""" queryset = PartLabel.objects.all() serializer_class = PartLabelSerializer - def get(self, request, *args, **kwargs): - """Check if valid part(s) have been provided.""" - parts = self.get_parts() - - return self.print(request, parts) + ITEM_MODEL = Part + ITEM_KEY = 'part' label_api_urls = [