Floating point API bug (#3877)

* Add unit tests for internalpricebreak

- Exposes an existing bug

* Ensure that rounding-decimal and prices are rounded correctly

- Force remove trailing digits / reduce precision
This commit is contained in:
Oliver 2022-10-29 14:18:19 +11:00 committed by GitHub
parent 5263ccdca3
commit 55c8b73b0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 11 deletions

View File

@ -91,6 +91,18 @@ class InvenTreeModelMoneyField(ModelMoneyField):
kwargs['form_class'] = InvenTreeMoneyField
return super().formfield(**kwargs)
def to_python(self, value):
"""Convert value to python type."""
value = super().to_python(value)
return round_decimal(value, self.decimal_places)
def prepare_value(self, value):
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice!
"""
return round_decimal(value, self.decimal_places, normalize=True)
class InvenTreeMoneyField(MoneyField):
"""Custom MoneyField for clean migrations while using dynamic currency settings."""
@ -126,11 +138,16 @@ class DatePickerFormField(forms.DateField):
)
def round_decimal(value, places):
def round_decimal(value, places, normalize=False):
"""Round value to the specified number of places."""
if value is not None:
# see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options
return value.quantize(Decimal(10) ** -places)
if type(value) in [Decimal, float]:
value = round(value, places)
if normalize:
# Remove any trailing zeroes
value = InvenTree.helpers.normalize(value)
return value
@ -140,18 +157,14 @@ class RoundingDecimalFormField(forms.DecimalField):
def to_python(self, value):
"""Convert value to python type."""
value = super().to_python(value)
value = round_decimal(value, self.decimal_places)
return value
return round_decimal(value, self.decimal_places)
def prepare_value(self, value):
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice!
"""
if type(value) == Decimal:
return InvenTree.helpers.normalize(value)
else:
return value
return round_decimal(value, self.decimal_places, normalize=True)
class RoundingDecimalField(models.DecimalField):

View File

@ -34,13 +34,14 @@ class InvenTreeMoneySerializer(MoneyField):
def __init__(self, *args, **kwargs):
"""Overrite default values."""
kwargs["max_digits"] = kwargs.get("max_digits", 19)
kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
kwargs["required"] = kwargs.get("required", False)
super().__init__(*args, **kwargs)
def get_value(self, data):
"""Test that the returned amount is a valid Decimal."""
amount = super(DecimalField, self).get_value(data)
# Convert an empty string to None
@ -49,7 +50,9 @@ class InvenTreeMoneySerializer(MoneyField):
try:
if amount is not None and amount is not empty:
# Convert to a Decimal instance, and round to maximum allowed decimal places
amount = Decimal(amount)
amount = round(amount, self.decimal_places)
except Exception:
raise ValidationError({
self.field_name: [_("Must be a valid number")],

View File

@ -1,5 +1,6 @@
"""Unit tests for the various part API endpoints"""
from decimal import Decimal
from random import randint
from django.urls import reverse
@ -2430,3 +2431,65 @@ class PartAttachmentTest(InvenTreeAPITestCase):
self.assertEqual(data['part'], 1)
self.assertEqual(data['link'], link)
self.assertEqual(data['comment'], 'Hello world')
class PartInternalPriceBreakTest(InvenTreeAPITestCase):
"""Unit tests for the PartInternalPrice API endpoints"""
fixtures = [
'category',
'part',
'params',
'location',
'bom',
'company',
'test_templates',
'manufacturer_part',
'supplier_part',
'order',
'stock',
]
roles = [
'part.change',
'part.add',
'part.delete',
'part_category.change',
'part_category.add',
'part_category.delete',
]
def test_create_price_breaks(self):
"""Test we can create price breaks at various quantities"""
url = reverse('api-part-internal-price-list')
breaks = [
(1.0, 101),
(1.1, 92.555555555),
(1.5, 90.999999999),
(1.756, 89),
(2, 86),
(25, 80)
]
for q, p in breaks:
data = self.post(
url,
{
'part': 1,
'quantity': q,
'price': p,
},
expected_code=201
).data
self.assertEqual(data['part'], 1)
self.assertEqual(
round(Decimal(data['quantity']), 4),
round(Decimal(q), 4)
)
self.assertEqual(
round(Decimal(data['price']), 4),
round(Decimal(p), 4)
)