Add ability to filter stock items by part category

- Using the REST API
- Add &category=x to the query string
- Add &include_child_categories to include child categories
This commit is contained in:
Oliver Walters 2019-05-16 18:54:14 +10:00
parent 387711a9a1
commit e1558a7a96

View File

@ -7,16 +7,18 @@ from django_filters import NumberFilter
from django.conf.urls import url, include
from django.urls import reverse
from django.db.models import Q
from .models import StockLocation, StockItem
from .models import StockItemTracking
from part.models import PartCategory
from .serializers import StockItemSerializer, StockQuantitySerializer
from .serializers import LocationSerializer
from .serializers import StockTrackingSerializer
from InvenTree.views import TreeSerializer
from InvenTree.helpers import str2bool
from rest_framework.serializers import ValidationError
from rest_framework.views import APIView
@ -237,6 +239,10 @@ class StockList(generics.ListCreateAPIView):
- GET: Return a list of all StockItem objects (with optional query filters)
- POST: Create a new StockItem
Additional query parameters are available:
- location: Filter stock by location
- category: Filter by parts belonging to a certain category
"""
def get_queryset(self):
@ -246,7 +252,7 @@ class StockList(generics.ListCreateAPIView):
This is set by the optional param 'include_child_categories'
"""
# Does the client wish to filter by category?
# Does the client wish to filter by stock location?
loc_id = self.request.query_params.get('location', None)
# Start with all objects
@ -256,22 +262,29 @@ class StockList(generics.ListCreateAPIView):
try:
location = StockLocation.objects.get(pk=loc_id)
# Filter by the supplied category
flt = Q(location=loc_id)
if self.request.query_params.get('include_child_locations', None):
childs = location.getUniqueChildren()
for child in childs:
# Ignore the top-level category (already filtered!)
if str(child) == str(loc_id):
continue
flt |= Q(location=child)
stock_list = stock_list.filter(flt)
if str2bool(self.request.query_params.get('include_child_locations', None)):
stock_list = stock_list.filter(location__in=location.getUniqueChildren())
else:
stock_list = stock_list.filter(location=location.id)
except StockLocation.DoesNotExist:
pass
# Does the client wish to filter by part category?
cat_id = self.request.query_params.get('category', None)
if cat_id:
try:
category = PartCategory.objects.get(pk=cat_id)
if str2bool(self.request.query_params.get('include_child_categories', None)):
stock_list = stock_list.filter(part__category__in=category.getUniqueChildren())
else:
stock_list = stock_list.filter(category=category.id)
except PartCategory.DoesNotExist:
pass
return stock_list
serializer_class = StockItemSerializer