Merge pull request #692 from SchrodingersGat/api-improvements

Api improvements
This commit is contained in:
Oliver 2020-04-03 12:37:03 +11:00 committed by GitHub
commit 7ee94f3574
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 124 additions and 24 deletions

View File

@ -52,6 +52,21 @@ def str2bool(text, test=True):
return str(text).lower() in ['0', 'n', 'no', 'none', 'f', 'false', 'off', ] return str(text).lower() in ['0', 'n', 'no', 'none', 'f', 'false', 'off', ]
def isNull(text):
"""
Test if a string 'looks' like a null value.
This is useful for querying the API against a null key.
Args:
text: Input text
Returns:
True if the text looks like a null value
"""
return str(text).strip().lower() in ['top', 'null', 'none', 'empty', 'false', '-1']
def decimal2string(d): def decimal2string(d):
""" """
Format a Decimal number as a string, Format a Decimal number as a string,

View File

@ -95,6 +95,10 @@ function loadPartTable(table, url, options={}) {
query.active = true; query.active = true;
} }
// Include sub-category search
// TODO - Make this user-configurable!
query.cascade = true;
var columns = [ var columns = [
{ {
field: 'pk', field: 'pk',

View File

@ -42,6 +42,10 @@ function loadStockTable(table, options) {
var params = options.params || {}; var params = options.params || {};
// Enforce 'cascade' option
// TODO - Make this user-configurable?
params.cascade = true;
console.log('load stock table'); console.log('load stock table');
table.inventreeTable({ table.inventreeTable({

View File

@ -27,7 +27,7 @@ from . import serializers as part_serializers
from InvenTree.status_codes import OrderStatus, StockStatus, BuildStatus from InvenTree.status_codes import OrderStatus, StockStatus, BuildStatus
from InvenTree.views import TreeSerializer from InvenTree.views import TreeSerializer
from InvenTree.helpers import str2bool from InvenTree.helpers import str2bool, isNull
class PartCategoryTree(TreeSerializer): class PartCategoryTree(TreeSerializer):
@ -57,6 +57,31 @@ class CategoryList(generics.ListCreateAPIView):
permissions.IsAuthenticated, permissions.IsAuthenticated,
] ]
def get_queryset(self):
"""
Custom filtering:
- Allow filtering by "null" parent to retrieve top-level part categories
"""
cat_id = self.request.query_params.get('parent', None)
queryset = super().get_queryset()
if cat_id is not None:
# Look for top-level categories
if isNull(cat_id):
queryset = queryset.filter(parent=None)
else:
try:
cat_id = int(cat_id)
queryset = queryset.filter(parent=cat_id)
except ValueError:
pass
return queryset
filter_backends = [ filter_backends = [
DjangoFilterBackend, DjangoFilterBackend,
filters.SearchFilter, filters.SearchFilter,
@ -64,7 +89,6 @@ class CategoryList(generics.ListCreateAPIView):
] ]
filter_fields = [ filter_fields = [
'parent',
] ]
ordering_fields = [ ordering_fields = [
@ -219,12 +243,25 @@ class PartList(generics.ListCreateAPIView):
# Start with all objects # Start with all objects
parts_list = Part.objects.all() parts_list = Part.objects.all()
if cat_id: cascade = str2bool(self.request.query_params.get('cascade', False))
try:
category = PartCategory.objects.get(pk=cat_id) if cat_id is not None:
parts_list = parts_list.filter(category__in=category.getUniqueChildren())
except PartCategory.DoesNotExist: if isNull(cat_id):
pass parts_list = parts_list.filter(category=None)
else:
try:
cat_id = int(cat_id)
category = PartCategory.objects.get(pk=cat_id)
# If '?cascade=true' then include parts which exist in sub-categories
if cascade:
parts_list = parts_list.filter(category__in=category.getUniqueChildren())
# Just return parts directly in the requested category
else:
parts_list = parts_list.filter(category=cat_id)
except (ValueError, PartCategory.DoesNotExist):
pass
# Ensure that related models are pre-loaded to reduce DB trips # Ensure that related models are pre-loaded to reduce DB trips
parts_list = self.get_serializer_class().setup_eager_loading(parts_list) parts_list = self.get_serializer_class().setup_eager_loading(parts_list)

View File

@ -18,6 +18,8 @@ class CategorySerializer(InvenTreeModelSerializer):
url = serializers.CharField(source='get_absolute_url', read_only=True) url = serializers.CharField(source='get_absolute_url', read_only=True)
parts = serializers.IntegerField(source='item_count', read_only=True)
class Meta: class Meta:
model = PartCategory model = PartCategory
fields = [ fields = [
@ -27,6 +29,7 @@ class CategorySerializer(InvenTreeModelSerializer):
'pathstring', 'pathstring',
'url', 'url',
'parent', 'parent',
'parts',
] ]

View File

@ -103,7 +103,7 @@ class PartAPITest(APITestCase):
If provided, parts are provided for ANY child category (recursive) If provided, parts are provided for ANY child category (recursive)
""" """
url = reverse('api-part-list') url = reverse('api-part-list')
data = {'category': 1} data = {'category': 1, 'cascade': True}
# Now request to include child categories # Now request to include child categories
response = self.client.get(url, data, format='json') response = self.client.get(url, data, format='json')

View File

@ -19,7 +19,7 @@ from .serializers import LocationSerializer
from .serializers import StockTrackingSerializer from .serializers import StockTrackingSerializer
from InvenTree.views import TreeSerializer from InvenTree.views import TreeSerializer
from InvenTree.helpers import str2bool from InvenTree.helpers import str2bool, isNull
from InvenTree.status_codes import StockStatus from InvenTree.status_codes import StockStatus
import os import os
@ -223,9 +223,33 @@ class StockLocationList(generics.ListCreateAPIView):
""" """
queryset = StockLocation.objects.all() queryset = StockLocation.objects.all()
serializer_class = LocationSerializer serializer_class = LocationSerializer
def get_queryset(self):
"""
Custom filtering:
- Allow filtering by "null" parent to retrieve top-level stock locations
"""
queryset = super().get_queryset()
loc_id = self.request.query_params.get('parent', None)
if loc_id is not None:
# Look for top-level locations
if isNull(loc_id):
queryset = queryset.filter(parent=None)
else:
try:
loc_id = int(loc_id)
queryset = queryset.filter(parent=loc_id)
except ValueError:
pass
return queryset
permission_classes = [ permission_classes = [
permissions.IsAuthenticated, permissions.IsAuthenticated,
] ]
@ -237,7 +261,6 @@ class StockLocationList(generics.ListCreateAPIView):
] ]
filter_fields = [ filter_fields = [
'parent',
] ]
search_fields = [ search_fields = [
@ -373,13 +396,24 @@ class StockList(generics.ListCreateAPIView):
# Does the client wish to filter by stock location? # Does the client wish to filter by stock location?
loc_id = self.request.query_params.get('location', None) loc_id = self.request.query_params.get('location', None)
if loc_id: cascade = str2bool(self.request.query_params.get('cascade', False))
try:
location = StockLocation.objects.get(pk=loc_id) if loc_id is not None:
stock_list = stock_list.filter(location__in=location.getUniqueChildren())
# Filter by 'null' location (i.e. top-level items)
except (ValueError, StockLocation.DoesNotExist): if isNull(loc_id):
pass stock_list = stock_list.filter(location=None)
else:
try:
# If '?cascade=true' then include items which exist in sub-locations
if cascade:
location = StockLocation.objects.get(pk=loc_id)
stock_list = stock_list.filter(location__in=location.getUniqueChildren())
else:
stock_list = stock_list.filter(location=loc_id)
except (ValueError, StockLocation.DoesNotExist):
pass
# Does the client wish to filter by part category? # Does the client wish to filter by part category?
cat_id = self.request.query_params.get('category', None) cat_id = self.request.query_params.get('category', None)
@ -511,13 +545,13 @@ stock_endpoints = [
] ]
location_endpoints = [ location_endpoints = [
url(r'^$', LocationDetail.as_view(), name='api-location-detail'), url(r'^(?P<pk>\d+)/', LocationDetail.as_view(), name='api-location-detail'),
url(r'^.*$', StockLocationList.as_view(), name='api-location-list'),
] ]
stock_api_urls = [ stock_api_urls = [
url(r'location/?', StockLocationList.as_view(), name='api-location-list'), url(r'location/', include(location_endpoints)),
url(r'location/(?P<pk>\d+)/', include(location_endpoints)),
# These JSON endpoints have been replaced (for now) with server-side form rendering - 02/06/2019 # These JSON endpoints have been replaced (for now) with server-side form rendering - 02/06/2019
# url(r'stocktake/?', StockStocktake.as_view(), name='api-stock-stocktake'), # url(r'stocktake/?', StockStocktake.as_view(), name='api-stock-stocktake'),

View File

@ -119,6 +119,8 @@ class LocationSerializer(InvenTreeModelSerializer):
url = serializers.CharField(source='get_absolute_url', read_only=True) url = serializers.CharField(source='get_absolute_url', read_only=True)
items = serializers.IntegerField(source='item_count', read_only=True)
class Meta: class Meta:
model = StockLocation model = StockLocation
fields = [ fields = [
@ -127,7 +129,8 @@ class LocationSerializer(InvenTreeModelSerializer):
'name', 'name',
'description', 'description',
'parent', 'parent',
'pathstring' 'pathstring',
'items',
] ]