diff --git a/InvenTree/part/api.py b/InvenTree/part/api.py index f776c208b5..0d61025a8e 100644 --- a/InvenTree/part/api.py +++ b/InvenTree/part/api.py @@ -293,10 +293,10 @@ class PartList(generics.ListCreateAPIView): def filter_queryset(self, queryset): """ - Perform custom filtering of the queryset + Perform custom filtering of the queryset. + We overide the DRF filter_fields here because """ - # Perform basic filtering queryset = super().filter_queryset(queryset) # Filter by 'starred' parts? diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index d1f6c2d944..4816966718 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -7,12 +7,17 @@ from django_filters import NumberFilter from django.conf.urls import url, include from django.urls import reverse +from django.http import JsonResponse from django.db.models import Q from .models import StockLocation, StockItem from .models import StockItemTracking from part.models import Part, PartCategory +from part.serializers import PartBriefSerializer + +from company.models import SupplierPart +from company.serializers import SupplierPartSerializer from .serializers import StockItemSerializer from .serializers import LocationSerializer, LocationBriefSerializer @@ -322,26 +327,8 @@ class StockList(generics.ListCreateAPIView): """ serializer_class = StockItemSerializer - queryset = StockItem.objects.all() - def get_serializer(self, *args, **kwargs): - - try: - kwargs['part_detail'] = str2bool(self.request.query_params.get('part_detail', None)) - except AttributeError: - pass - - try: - kwargs['supplier_part_detail'] = str2bool(self.request.query_params.get('supplier_part_detail', None)) - except AttributeError: - pass - - # Ensure the request context is passed through - kwargs['context'] = self.get_serializer_context() - - return self.serializer_class(*args, **kwargs) - # TODO - Override the 'create' method for this view, # to allow the user to be recorded when a new StockItem object is created @@ -364,17 +351,59 @@ class StockList(generics.ListCreateAPIView): data = serializer.data - # Do we wish to include StockLocation detail? - if str2bool(request.query_params.get('location_detail', False)): + # Keep track of which related models we need to query + location_ids = set() + part_ids = set() + supplier_part_ids = set() - # Work out which locations we need to query - location_ids = set() + # Iterate through each StockItem and grab some data + for item in data: + loc = item['location'] + if loc: + location_ids.add(loc) + + part = item['part'] + if part: + part_ids.add(part) + + sp = item['supplier_part'] + if sp: + supplier_part_ids.add(sp) + + # Do we wish to include Part detail? + if str2bool(request.query_params.get('part_detail', False)): + + # Fetch only the required Part objects from the database + parts = Part.objects.filter(pk__in=part_ids).prefetch_related( + 'category', + ) + + part_map = {} + + for part in parts: + part_map[part.pk] = PartBriefSerializer(part).data + + # Now update each StockItem with the related Part data + for stock_item in data: + part_id = stock_item['part'] + stock_item['part_detail'] = part_map.get(part_id, None) + + # Do we wish to include SupplierPart detail? + if str2bool(request.query_params.get('supplier_part_detail', False)): + + supplier_parts = SupplierPart.objects.filter(pk__in=supplier_part_ids) + + supplier_part_map = {} + + for part in supplier_parts: + supplier_part_map[part.pk] = SupplierPartSerializer(part).data for stock_item in data: - loc_id = stock_item['location'] + part_id = stock_item['supplier_part'] + stock_item['supplier_part_detail'] = supplier_part_map.get(part_id, None) - if loc_id is not None: - location_ids.add(loc_id) + # Do we wish to include StockLocation detail? + if str2bool(request.query_params.get('location_detail', False)): # Fetch only the required StockLocation objects from the database locations = StockLocation.objects.filter(pk__in=location_ids).prefetch_related( @@ -391,27 +420,60 @@ class StockList(generics.ListCreateAPIView): # Now update each StockItem with the related StockLocation data for stock_item in data: loc_id = stock_item['location'] + stock_item['supplier_detail'] = location_map.get(loc_id, None) - if loc_id is not None and loc_id in location_map.keys(): - detail = location_map[loc_id] - else: - detail = None + """ + Determine the response type based on the request. + a) For HTTP requests (e.g. via the browseable API) return a DRF response + b) For AJAX requests, simply return a JSON rendered response. - stock_item['location_detail'] = detail + Note: b) is about 100x quicker than a), because the DRF framework adds a lot of cruft + """ - return Response(data) + if request.is_ajax(): + return JsonResponse(data, safe=False) + else: + return Response(data) def get_queryset(self, *args, **kwargs): queryset = super().get_queryset(*args, **kwargs) queryset = StockItemSerializer.prefetch_queryset(queryset) + queryset = StockItemSerializer.annotate_queryset(queryset) return queryset def filter_queryset(self, queryset): - # Start with all objects - stock_list = super().filter_queryset(queryset) + params = self.request.query_params + + # Perform basic filtering: + # Note: We do not let DRF filter here, it be slow AF + + supplier_part = params.get('supplier_part', None) + + if supplier_part: + queryset = queryset.filter(supplier_part=supplier_part) + + belongs_to = params.get('belongs_to', None) + + if belongs_to: + queryset = queryset.filter(belongs_to=belongs_to) + + build = params.get('build', None) + + if build: + queryset = queryset.filter(build=build) + + build_order = params.get('build_order', None) + + if build_order: + queryset = queryset.filter(build_order=build_order) + + sales_order = params.get('sales_order', None) + + if sales_order: + queryset = queryset.filter(sales_order=sales_order) in_stock = self.request.query_params.get('in_stock', None) @@ -420,10 +482,10 @@ class StockList(generics.ListCreateAPIView): if in_stock: # Filter out parts which are not actually "in stock" - stock_list = stock_list.filter(StockItem.IN_STOCK_FILTER) + queryset = queryset.filter(StockItem.IN_STOCK_FILTER) else: # Only show parts which are not in stock - stock_list = stock_list.exclude(StockItem.IN_STOCK_FILTER) + queryset = queryset.exclude(StockItem.IN_STOCK_FILTER) # Filter by 'allocated' patrs? allocated = self.request.query_params.get('allocated', None) @@ -433,17 +495,17 @@ class StockList(generics.ListCreateAPIView): if allocated: # Filter StockItem with either build allocations or sales order allocations - stock_list = stock_list.filter(Q(sales_order_allocations__isnull=False) | Q(allocations__isnull=False)) + queryset = queryset.filter(Q(sales_order_allocations__isnull=False) | Q(allocations__isnull=False)) else: # Filter StockItem without build allocations or sales order allocations - stock_list = stock_list.filter(Q(sales_order_allocations__isnull=True) & Q(allocations__isnull=True)) + queryset = queryset.filter(Q(sales_order_allocations__isnull=True) & Q(allocations__isnull=True)) # Do we wish to filter by "active parts" active = self.request.query_params.get('active', None) if active is not None: active = str2bool(active) - stock_list = stock_list.filter(part__active=active) + queryset = queryset.filter(part__active=active) # Does the client wish to filter by the Part ID? part_id = self.request.query_params.get('part', None) @@ -454,9 +516,9 @@ class StockList(generics.ListCreateAPIView): # If the part is a Template part, select stock items for any "variant" parts under that template if part.is_template: - stock_list = stock_list.filter(part__in=[part.id for part in Part.objects.filter(variant_of=part_id)]) + queryset = queryset.filter(part__in=[part.id for part in Part.objects.filter(variant_of=part_id)]) else: - stock_list = stock_list.filter(part=part_id) + queryset = queryset.filter(part=part_id) except (ValueError, Part.DoesNotExist): raise ValidationError({"part": "Invalid Part ID specified"}) @@ -469,7 +531,7 @@ class StockList(generics.ListCreateAPIView): ancestor = StockItem.objects.get(pk=anc_id) # Only allow items which are descendants of the specified StockItem - stock_list = stock_list.filter(id__in=[item.pk for item in ancestor.children.all()]) + queryset = queryset.filter(id__in=[item.pk for item in ancestor.children.all()]) except (ValueError, Part.DoesNotExist): raise ValidationError({"ancestor": "Invalid ancestor ID specified"}) @@ -483,15 +545,15 @@ class StockList(generics.ListCreateAPIView): # Filter by 'null' location (i.e. top-level items) if isNull(loc_id): - stock_list = stock_list.filter(location=None) + queryset = queryset.filter(location=None) else: try: # If '?cascade=true' then include items which exist in sub-locations if cascade: location = StockLocation.objects.get(pk=loc_id) - stock_list = stock_list.filter(location__in=location.getUniqueChildren()) + queryset = queryset.filter(location__in=location.getUniqueChildren()) else: - stock_list = stock_list.filter(location=loc_id) + queryset = queryset.filter(location=loc_id) except (ValueError, StockLocation.DoesNotExist): pass @@ -502,7 +564,7 @@ class StockList(generics.ListCreateAPIView): if cat_id: try: category = PartCategory.objects.get(pk=cat_id) - stock_list = stock_list.filter(part__category__in=category.getUniqueChildren()) + queryset = queryset.filter(part__category__in=category.getUniqueChildren()) except (ValueError, PartCategory.DoesNotExist): raise ValidationError({"category": "Invalid category id specified"}) @@ -511,44 +573,42 @@ class StockList(generics.ListCreateAPIView): status = self.request.query_params.get('status', None) if status: - stock_list = stock_list.filter(status=status) + queryset = queryset.filter(status=status) # Filter by supplier_part ID supplier_part_id = self.request.query_params.get('supplier_part', None) if supplier_part_id: - stock_list = stock_list.filter(supplier_part=supplier_part_id) + queryset = queryset.filter(supplier_part=supplier_part_id) # Filter by company (either manufacturer or supplier) company = self.request.query_params.get('company', None) if company is not None: - stock_list = stock_list.filter(Q(supplier_part__supplier=company) | Q(supplier_part__manufacturer=company)) + queryset = queryset.filter(Q(supplier_part__supplier=company) | Q(supplier_part__manufacturer=company)) # Filter by supplier supplier = self.request.query_params.get('supplier', None) if supplier is not None: - stock_list = stock_list.filter(supplier_part__supplier=supplier) + queryset = queryset.filter(supplier_part__supplier=supplier) # Filter by manufacturer manufacturer = self.request.query_params.get('manufacturer', None) if manufacturer is not None: - stock_list = stock_list.filter(supplier_part__manufacturer=manufacturer) + queryset = queryset.filter(supplier_part__manufacturer=manufacturer) # Also ensure that we pre-fecth all the related items - stock_list = stock_list.prefetch_related( + queryset = queryset.prefetch_related( 'part', 'part__category', 'location' ) - stock_list = stock_list.order_by('part__name') + queryset = queryset.order_by('part__name') - return stock_list - - serializer_class = StockItemSerializer + return queryset permission_classes = [ permissions.IsAuthenticated, @@ -561,12 +621,6 @@ class StockList(generics.ListCreateAPIView): ] filter_fields = [ - 'supplier_part', - 'belongs_to', - 'build', - 'build_order', - 'sales_order', - 'build_order', ] diff --git a/InvenTree/stock/serializers.py b/InvenTree/stock/serializers.py index eeab49eb4b..6b34fbe6ce 100644 --- a/InvenTree/stock/serializers.py +++ b/InvenTree/stock/serializers.py @@ -7,6 +7,9 @@ from rest_framework import serializers from .models import StockItem, StockLocation from .models import StockItemTracking +from django.db.models import Sum, Count +from django.db.models.functions import Coalesce + from company.serializers import SupplierPartSerializer from part.serializers import PartBriefSerializer from InvenTree.serializers import UserSerializerBrief, InvenTreeModelSerializer @@ -62,6 +65,10 @@ class StockItemSerializer(InvenTreeModelSerializer): """ return queryset.prefetch_related( + 'belongs_to', + 'build', + 'build_order', + 'sales_order', 'supplier_part', 'supplier_part__supplier', 'supplier_part__manufacturer', @@ -79,7 +86,13 @@ class StockItemSerializer(InvenTreeModelSerializer): performing database queries as efficiently as possible. """ - # TODO - Add custom annotated fields + queryset = queryset.annotate( + allocated = Coalesce( + Sum('sales_order_allocations__quantity', distinct=True), 0) + Coalesce( + Sum('allocations__quantity', distinct=True), 0), + tracking_items = Count('tracking_info'), + ) + return queryset status_text = serializers.CharField(source='get_status_display', read_only=True) @@ -88,10 +101,10 @@ class StockItemSerializer(InvenTreeModelSerializer): location_detail = LocationBriefSerializer(source='location', many=False, read_only=True) supplier_part_detail = SupplierPartSerializer(source='supplier_part', many=False, read_only=True) - tracking_items = serializers.IntegerField(source='tracking_info_count', read_only=True) + tracking_items = serializers.IntegerField() quantity = serializers.FloatField() - allocated = serializers.FloatField(source='allocation_count', read_only=True) + allocated = serializers.FloatField() def __init__(self, *args, **kwargs): @@ -140,6 +153,7 @@ class StockItemSerializer(InvenTreeModelSerializer): They can be updated by accessing the appropriate API endpoints """ read_only_fields = [ + 'allocated', 'stocktake_date', 'stocktake_user', 'updated',