diff --git a/InvenTree/InvenTree/models.py b/InvenTree/InvenTree/models.py index 4eb8dd8daa..239967e32a 100644 --- a/InvenTree/InvenTree/models.py +++ b/InvenTree/InvenTree/models.py @@ -173,3 +173,19 @@ class InvenTreeTree(models.Model): """ return self.path + + +def FilterChildren(queryset, parent): + """ Filter a queryset, limit to only objects that are a child of the given parent + """ + + if not parent: + return queryset + elif isinstance(parent,str) and parent.lower() in ['none', 'false', 'null', 'top', '0']: + return queryset.filter(parent=None) + else: + try: + parent_id = int(parent) + return queryset.filter(parent=parent_id) + except: + return queryset diff --git a/InvenTree/part/serializers.py b/InvenTree/part/serializers.py index 06ad724b50..7e5cfd3aaa 100644 --- a/InvenTree/part/serializers.py +++ b/InvenTree/part/serializers.py @@ -43,21 +43,13 @@ class PartCategoryBriefSerializer(serializers.ModelSerializer): class PartCategoryDetailSerializer(serializers.ModelSerializer): - # List of parts in this category - parts = PartSerializer(many=True, read_only=True) - - # List of child categories under this one - children = PartCategoryBriefSerializer(many=True, read_only=True) - class Meta: model = PartCategory fields = ('pk', 'name', 'description', 'parent', - 'path', - 'children', - 'parts') + 'path') class PartTemplateSerializer(serializers.ModelSerializer): diff --git a/InvenTree/part/urls.py b/InvenTree/part/urls.py index b351ce3eda..e7b1004ec5 100644 --- a/InvenTree/part/urls.py +++ b/InvenTree/part/urls.py @@ -10,10 +10,10 @@ from . import views categorypatterns = [ # Part category detail - url(r'^category/(?P[0-9]+)/$', views.PartCategoryDetail.as_view()), + url(r'^(?P[0-9]+)/?$', views.PartCategoryDetail.as_view()), # List of top-level categories - url(r'^$', views.PartCategoryList.as_view()) + url(r'^\?*[^/]*/?$', views.PartCategoryList.as_view()) ] partparampatterns = [ @@ -47,13 +47,13 @@ urlpatterns = [ url(r'^(?P[0-9]+)/$', views.PartDetail.as_view()), # Part categories - url(r'^category/', views.PartCategoryList.as_view()), + url(r'^category/?', include(categorypatterns)), # Part parameters - url(r'^parameters/', include(partparampatterns)), + url(r'^parameters/?', include(partparampatterns)), # Part templates - url(r'^templates/', include(parttemplatepatterns)), + url(r'^templates/?', include(parttemplatepatterns)), # List parts with optional filters url(r'^\?*[^/]*/?$', views.PartList.as_view()), diff --git a/InvenTree/part/views.py b/InvenTree/part/views.py index c4646b6a6e..9ad987d2fb 100644 --- a/InvenTree/part/views.py +++ b/InvenTree/part/views.py @@ -2,6 +2,7 @@ from rest_framework import generics, permissions +from InvenTree.models import FilterChildren from .models import PartCategory, Part, PartParameter, PartParameterTemplate from .serializers import PartSerializer from .serializers import PartCategoryDetailSerializer @@ -91,6 +92,16 @@ class PartCategoryList(generics.ListCreateAPIView): """ Return a list of all top-level part categories. Categories are considered "top-level" if they do not have a parent """ + + def get_queryset(self): + params = self.request.query_params + + categories = PartCategory.objects.all() + + categories = FilterChildren(categories, params.get('parent', None)) + + return categories + queryset = PartCategory.objects.filter(parent=None) serializer_class = PartCategoryDetailSerializer permission_classes = (permissions.IsAuthenticatedOrReadOnly,) diff --git a/InvenTree/stock/views.py b/InvenTree/stock/views.py index 779bdc9695..12de5ac8d2 100644 --- a/InvenTree/stock/views.py +++ b/InvenTree/stock/views.py @@ -1,6 +1,7 @@ from rest_framework import generics, permissions import django_filters +from InvenTree.models import FilterChildren from .models import StockLocation, StockItem from .serializers import StockItemSerializer, LocationDetailSerializer @@ -70,16 +71,7 @@ class LocationList(generics.ListCreateAPIView): locations = StockLocation.objects.all() - parent_id = params.get('parent', None) - - if parent_id and parent_id.lower() in ['none', 'false', 'null', 'top']: - locations = locations.filter(parent=None) - else: - try: - parent_id_num = int(parent_id) - locations = locations.filter(parent=parent_id_num) - except: - pass + locations = FilterChildren(locations, params.get('parent', None)) return locations