From d0aa09337ad85ac01816468d8d1b736c3901acf1 Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 27 Jan 2022 12:08:50 +1100 Subject: [PATCH] Simplify creation of serialized stock via the API - Fixes atomicity issues - Don't create-then-delete stock items --- InvenTree/stock/api.py | 127 ++++++++++++++++++------------------ InvenTree/stock/test_api.py | 61 +++++++++++++++++ 2 files changed, 125 insertions(+), 63 deletions(-) diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index cfd0b45d89..da43d8a5ca 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -465,12 +465,6 @@ class StockList(generics.ListCreateAPIView): user = request.user data = request.data - serializer = self.get_serializer(data=data) - serializer.is_valid(raise_exception=True) - - # Check if a set of serial numbers was provided - serial_numbers = data.get('serial_numbers', '') - quantity = data.get('quantity', None) if quantity is None: @@ -478,77 +472,84 @@ class StockList(generics.ListCreateAPIView): 'quantity': _('Quantity is required'), }) - notes = data.get('notes', '') + try: + part = Part.objects.get(pk=data.get('part', None)) + except (ValueError, Part.DoesNotExist): + raise ValidationError({ + 'part': _('Valid part ID must be supplied'), + }) + + # Set default location (if not provided) + if 'location' not in data: + location = part.get_default_location() + + if location: + data['location'] = location.pk + + # An expiry date was *not* specified - try to infer it! + if 'expiry_date' not in data: + + if part.default_expiry > 0: + data['expiry_date'] = datetime.now().date() + timedelta(days=part.default_expiry) + + # Attempt to extract serial numbers from submitted data + serials = None + + # Check if a set of serial numbers was provided + serial_numbers = data.get('serial_numbers', '') + + # Assign serial numbers for a trackable part + if serial_numbers and part.trackable: + + # If serial numbers are specified, check that they match! + try: + serials = extract_serial_numbers(serial_numbers, quantity, part.getLatestSerialNumberInt()) + except DjangoValidationError as e: + raise ValidationError({ + 'quantity': e.messages, + 'serial_numbers': e.messages, + }) + + if serials is not None: + """ + If the stock item is going to be serialized, set the quantity to 1 + """ + data['quantity'] = 1 + + # De-serialize the provided data + serializer = self.get_serializer(data=data) + serializer.is_valid(raise_exception=True) with transaction.atomic(): - # Create an initial stock item + # Create an initial StockItem object item = serializer.save() - # A location was *not* specified - try to infer it - if 'location' not in data: - item.location = item.part.get_default_location() + if serials: + # Assign the first serial number to the "master" item + item.serial = serials[0] - # An expiry date was *not* specified - try to infer it! - if 'expiry_date' not in data: - - if item.part.default_expiry > 0: - item.expiry_date = datetime.now().date() + timedelta(days=item.part.default_expiry) - - # fetch serial numbers - serials = None - - if serial_numbers: - # If serial numbers are specified, check that they match! - try: - serials = extract_serial_numbers(serial_numbers, quantity, item.part.getLatestSerialNumberInt()) - except DjangoValidationError as e: - raise ValidationError({ - 'quantity': e.messages, - 'serial_numbers': e.messages, - }) - - # Finally, save the item (with user information) + # Save the item (with user information) item.save(user=user) if serials: - """ - Serialize the stock, if required + for serial in serials[1:]: - - Note that the "original" stock item needs to be created first, so it can be serialized - - It is then immediately deleted - """ + # Create a duplicate stock item with the next serial number + item.pk = None + item.serial = serial - try: - item.serializeStock( - quantity, - serials, - user, - notes=notes, - location=item.location, - ) + item.save() - headers = self.get_success_headers(serializer.data) + response_data = { + 'quantity': quantity, + 'serial_numbers': serials, + } - # Delete the original item - item.delete() + else: + response_data = serializer.data - response_data = { - 'quantity': quantity, - 'serial_numbers': serials, - } - - return Response(response_data, status=status.HTTP_201_CREATED, headers=headers) - - except DjangoValidationError as e: - raise ValidationError({ - 'quantity': e.messages, - 'serial_numbers': e.messages, - }) - - # Return a response - headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response(response_data, status=status.HTTP_201_CREATED, headers=self.get_success_headers(serializer.data)) def list(self, request, *args, **kwargs): """ diff --git a/InvenTree/stock/test_api.py b/InvenTree/stock/test_api.py index bd02646a81..9369a6a881 100644 --- a/InvenTree/stock/test_api.py +++ b/InvenTree/stock/test_api.py @@ -380,6 +380,67 @@ class StockItemTest(StockAPITestCase): expected_code=201 ) + def test_creation_with_serials(self): + """ + Test that serialized stock items can be created via the API, + """ + + trackable_part = part.models.Part.objects.create( + name='My part', + description='A trackable part', + trackable=True, + default_location=StockLocation.objects.get(pk=1), + ) + + self.assertEqual(trackable_part.stock_entries().count(), 0) + self.assertEqual(trackable_part.get_stock_count(), 0) + + # This should fail, incorrect serial number count + response = self.post( + self.list_url, + data={ + 'part': trackable_part.pk, + 'quantity': 10, + 'serial_numbers': '1-20', + }, + expected_code=400, + ) + + response = self.post( + self.list_url, + data={ + 'part': trackable_part.pk, + 'quantity': 10, + 'serial_numbers': '1-10', + }, + expected_code=201, + ) + + data = response.data + + self.assertEqual(data['quantity'], 10) + sn = data['serial_numbers'] + + # Check that each serial number was created + for i in range(1, 11): + self.assertTrue(i in sn) + + # Check the unique stock item has been created + + item = StockItem.objects.get( + part=trackable_part, + serial=str(i), + ) + + # Item location should have been set automatically + self.assertIsNotNone(item.location) + + self.assertEqual(str(i), item.serial) + + # There now should be 10 unique stock entries for this part + self.assertEqual(trackable_part.stock_entries().count(), 10) + self.assertEqual(trackable_part.get_stock_count(), 10) + def test_default_expiry(self): """ Test that the "default_expiry" functionality works via the API.