Merge remote-tracking branch 'inventree/master'

This commit is contained in:
Oliver Walters 2022-10-29 22:10:21 +11:00
commit efccc29922
4 changed files with 98 additions and 11 deletions

View File

@ -91,6 +91,18 @@ class InvenTreeModelMoneyField(ModelMoneyField):
kwargs['form_class'] = InvenTreeMoneyField
return super().formfield(**kwargs)
def to_python(self, value):
"""Convert value to python type."""
value = super().to_python(value)
return round_decimal(value, self.decimal_places)
def prepare_value(self, value):
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice!
"""
return round_decimal(value, self.decimal_places, normalize=True)
class InvenTreeMoneyField(MoneyField):
"""Custom MoneyField for clean migrations while using dynamic currency settings."""
@ -126,11 +138,16 @@ class DatePickerFormField(forms.DateField):
)
def round_decimal(value, places):
def round_decimal(value, places, normalize=False):
"""Round value to the specified number of places."""
if value is not None:
# see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options
return value.quantize(Decimal(10) ** -places)
if type(value) in [Decimal, float]:
value = round(value, places)
if normalize:
# Remove any trailing zeroes
value = InvenTree.helpers.normalize(value)
return value
@ -140,18 +157,14 @@ class RoundingDecimalFormField(forms.DecimalField):
def to_python(self, value):
"""Convert value to python type."""
value = super().to_python(value)
value = round_decimal(value, self.decimal_places)
return value
return round_decimal(value, self.decimal_places)
def prepare_value(self, value):
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice!
"""
if type(value) == Decimal:
return InvenTree.helpers.normalize(value)
else:
return value
return round_decimal(value, self.decimal_places, normalize=True)
class RoundingDecimalField(models.DecimalField):

View File

@ -34,13 +34,14 @@ class InvenTreeMoneySerializer(MoneyField):
def __init__(self, *args, **kwargs):
"""Overrite default values."""
kwargs["max_digits"] = kwargs.get("max_digits", 19)
kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
kwargs["required"] = kwargs.get("required", False)
super().__init__(*args, **kwargs)
def get_value(self, data):
"""Test that the returned amount is a valid Decimal."""
amount = super(DecimalField, self).get_value(data)
# Convert an empty string to None
@ -49,7 +50,9 @@ class InvenTreeMoneySerializer(MoneyField):
try:
if amount is not None and amount is not empty:
# Convert to a Decimal instance, and round to maximum allowed decimal places
amount = Decimal(amount)
amount = round(amount, self.decimal_places)
except Exception:
raise ValidationError({
self.field_name: [_("Must be a valid number")],

View File

@ -217,6 +217,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
'part': 1,
'manufacturer': 7,
'MPN': 'PART_NUMBER',
'link': 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
},
expected_code=201
)
@ -229,17 +230,24 @@ class ManufacturerTest(InvenTreeAPITestCase):
'supplier': 1,
'SKU': 'SKU_TEST',
'manufacturer_part': pk,
'link': 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
}
response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
# Check link is not modified
self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E')
# Check manufacturer part
manufacturer_part_id = int(response.data['manufacturer_part_detail']['pk'])
url = reverse('api-manufacturer-part-detail', kwargs={'pk': manufacturer_part_id})
response = self.get(url)
self.assertEqual(response.data['MPN'], 'PART_NUMBER')
# Check link is not modified
self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E')
class SupplierPartTest(InvenTreeAPITestCase):
"""Unit tests for the SupplierPart API endpoints"""

View File

@ -1,5 +1,6 @@
"""Unit tests for the various part API endpoints"""
from decimal import Decimal
from random import randint
from django.urls import reverse
@ -2430,3 +2431,65 @@ class PartAttachmentTest(InvenTreeAPITestCase):
self.assertEqual(data['part'], 1)
self.assertEqual(data['link'], link)
self.assertEqual(data['comment'], 'Hello world')
class PartInternalPriceBreakTest(InvenTreeAPITestCase):
"""Unit tests for the PartInternalPrice API endpoints"""
fixtures = [
'category',
'part',
'params',
'location',
'bom',
'company',
'test_templates',
'manufacturer_part',
'supplier_part',
'order',
'stock',
]
roles = [
'part.change',
'part.add',
'part.delete',
'part_category.change',
'part_category.add',
'part_category.delete',
]
def test_create_price_breaks(self):
"""Test we can create price breaks at various quantities"""
url = reverse('api-part-internal-price-list')
breaks = [
(1.0, 101),
(1.1, 92.555555555),
(1.5, 90.999999999),
(1.756, 89),
(2, 86),
(25, 80)
]
for q, p in breaks:
data = self.post(
url,
{
'part': 1,
'quantity': q,
'price': p,
},
expected_code=201
).data
self.assertEqual(data['part'], 1)
self.assertEqual(
round(Decimal(data['quantity']), 4),
round(Decimal(q), 4)
)
self.assertEqual(
round(Decimal(data['price']), 4),
round(Decimal(p), 4)
)