diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index ed7a4b8c40..f60f1c535f 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -30,6 +30,7 @@ from company.models import Company, SupplierPart from company.serializers import CompanySerializer, SupplierPartSerializer from InvenTree.helpers import str2bool, isNull, extract_serial_numbers +from InvenTree.helpers import DownloadFile from InvenTree.api import AttachmentMixin from InvenTree.filters import InvenTreeOrderingFilter @@ -40,6 +41,7 @@ from order.serializers import POSerializer from part.models import BomItem, Part, PartCategory from part.serializers import PartBriefSerializer +from stock.admin import StockItemResource from stock.models import StockLocation, StockItem from stock.models import StockItemTracking from stock.models import StockItemAttachment @@ -611,6 +613,27 @@ class StockList(generics.ListCreateAPIView): queryset = self.filter_queryset(self.get_queryset()) + params = request.query_params + + # Check if we wish to export the queried data to a file. + # If so, skip pagination! + export_format = params.get('export', None) + + if export_format: + export_format = str(export_format).strip().lower() + + if export_format in ['csv', 'tsv', 'xls', 'xlsx']: + dataset = StockItemResource().export(queryset=queryset) + + filedata = dataset.export(export_format) + + filename = 'InvenTree_Stocktake_{date}.{fmt}'.format( + date=datetime.now().strftime("%d-%b-%Y"), + fmt=export_format + ) + + return DownloadFile(filedata, filename) + page = self.paginate_queryset(queryset) if page is not None: @@ -641,7 +664,7 @@ class StockList(generics.ListCreateAPIView): supplier_part_ids.add(sp) # Do we wish to include Part detail? - if str2bool(request.query_params.get('part_detail', False)): + if str2bool(params.get('part_detail', False)): # Fetch only the required Part objects from the database parts = Part.objects.filter(pk__in=part_ids).prefetch_related( @@ -659,7 +682,7 @@ class StockList(generics.ListCreateAPIView): stock_item['part_detail'] = part_map.get(part_id, None) # Do we wish to include SupplierPart detail? - if str2bool(request.query_params.get('supplier_part_detail', False)): + if str2bool(params.get('supplier_part_detail', False)): supplier_parts = SupplierPart.objects.filter(pk__in=supplier_part_ids) @@ -673,7 +696,7 @@ class StockList(generics.ListCreateAPIView): stock_item['supplier_part_detail'] = supplier_part_map.get(part_id, None) # Do we wish to include StockLocation detail? - if str2bool(request.query_params.get('location_detail', False)): + if str2bool(params.get('location_detail', False)): # Fetch only the required StockLocation objects from the database locations = StockLocation.objects.filter(pk__in=location_ids).prefetch_related(