diff --git a/InvenTree/part/api.py b/InvenTree/part/api.py index e13d493fc3..5601850e1c 100644 --- a/InvenTree/part/api.py +++ b/InvenTree/part/api.py @@ -87,13 +87,12 @@ class PartList(generics.ListCreateAPIView): childs = category.getUniqueChildren() for child in childs: # Ignore the top-level category (already filtered) - if child == cat_id: + if str(child) == str(cat_id): continue flt |= Q(category=child) parts_list = parts_list.filter(flt) - # Default - return all parts return parts_list permission_classes = [ diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index e8ab68acfe..32e3807ef7 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -2,6 +2,8 @@ from django_filters.rest_framework import FilterSet, DjangoFilterBackend from django_filters import NumberFilter from django.conf.urls import url, include +from django.db.models import Q +from django.shortcuts import get_object_or_404 from .models import StockLocation, StockItem from .models import StockItemTracking @@ -202,7 +204,36 @@ class StockList(generics.ListCreateAPIView): Create a new StockItem """ - queryset = StockItem.objects.all() + def get_queryset(self): + """ + If the query includes a particular location, + we may wish to also request stock items from all child locations. + This is set by the optional param 'include_child_categories' + """ + + # Does the client wish to filter by category? + loc_id = self.request.query_params.get('location', None) + + # Start with all objects + stock_list = StockItem.objects.all() + + if loc_id: + location = get_object_or_404(StockLocation, pk=loc_id) + + # Filter by the supplied category + flt = Q(location=loc_id) + + if self.request.query_params.get('include_child_locations', None): + childs = location.getUniqueChildren() + for child in childs: + # Ignore the top-level category (already filtered!) + if str(child) == str(loc_id): + continue + flt |= Q(location=child) + + stock_list = stock_list.filter(flt) + + return stock_list serializer_class = StockItemSerializer @@ -219,7 +250,6 @@ class StockList(generics.ListCreateAPIView): filter_fields = [ 'part', 'uuid', - 'location', 'supplier_part', 'customer', 'belongs_to', diff --git a/InvenTree/stock/templates/stock/location.html b/InvenTree/stock/templates/stock/location.html index 79ae9ecd15..79a93ccda1 100644 --- a/InvenTree/stock/templates/stock/location.html +++ b/InvenTree/stock/templates/stock/location.html @@ -160,7 +160,8 @@ loadStockTable($("#stock-table"), { params: { {% if location %} - location: {{ location.id }} + location: {{ location.id }}, + include_child_locations: true, {% endif %} }, url: "{% url 'api-stock-list' %}",