From e1558a7a968ba4763206b257df0630068b7d1e61 Mon Sep 17 00:00:00 2001 From: Oliver Walters Date: Thu, 16 May 2019 18:54:14 +1000 Subject: [PATCH] Add ability to filter stock items by part category - Using the REST API - Add &category=x to the query string - Add &include_child_categories to include child categories --- InvenTree/stock/api.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index 9f305b657a..e3fbce8636 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -7,16 +7,18 @@ from django_filters import NumberFilter from django.conf.urls import url, include from django.urls import reverse -from django.db.models import Q from .models import StockLocation, StockItem from .models import StockItemTracking +from part.models import PartCategory + from .serializers import StockItemSerializer, StockQuantitySerializer from .serializers import LocationSerializer from .serializers import StockTrackingSerializer from InvenTree.views import TreeSerializer +from InvenTree.helpers import str2bool from rest_framework.serializers import ValidationError from rest_framework.views import APIView @@ -237,6 +239,10 @@ class StockList(generics.ListCreateAPIView): - GET: Return a list of all StockItem objects (with optional query filters) - POST: Create a new StockItem + + Additional query parameters are available: + - location: Filter stock by location + - category: Filter by parts belonging to a certain category """ def get_queryset(self): @@ -246,7 +252,7 @@ class StockList(generics.ListCreateAPIView): This is set by the optional param 'include_child_categories' """ - # Does the client wish to filter by category? + # Does the client wish to filter by stock location? loc_id = self.request.query_params.get('location', None) # Start with all objects @@ -256,22 +262,29 @@ class StockList(generics.ListCreateAPIView): try: location = StockLocation.objects.get(pk=loc_id) - # Filter by the supplied category - flt = Q(location=loc_id) - - if self.request.query_params.get('include_child_locations', None): - childs = location.getUniqueChildren() - for child in childs: - # Ignore the top-level category (already filtered!) - if str(child) == str(loc_id): - continue - flt |= Q(location=child) - - stock_list = stock_list.filter(flt) + if str2bool(self.request.query_params.get('include_child_locations', None)): + stock_list = stock_list.filter(location__in=location.getUniqueChildren()) + else: + stock_list = stock_list.filter(location=location.id) except StockLocation.DoesNotExist: pass + # Does the client wish to filter by part category? + cat_id = self.request.query_params.get('category', None) + + if cat_id: + try: + category = PartCategory.objects.get(pk=cat_id) + + if str2bool(self.request.query_params.get('include_child_categories', None)): + stock_list = stock_list.filter(part__category__in=category.getUniqueChildren()) + else: + stock_list = stock_list.filter(category=category.id) + + except PartCategory.DoesNotExist: + pass + return stock_list serializer_class = StockItemSerializer