mirror of
https://github.com/inventree/InvenTree
synced 2024-08-30 18:33:04 +00:00
Floating point API bug (#3877)
* Add unit tests for internalpricebreak - Exposes an existing bug * Ensure that rounding-decimal and prices are rounded correctly - Force remove trailing digits / reduce precision
This commit is contained in:
parent
5263ccdca3
commit
55c8b73b0a
@ -91,6 +91,18 @@ class InvenTreeModelMoneyField(ModelMoneyField):
|
||||
kwargs['form_class'] = InvenTreeMoneyField
|
||||
return super().formfield(**kwargs)
|
||||
|
||||
def to_python(self, value):
|
||||
"""Convert value to python type."""
|
||||
value = super().to_python(value)
|
||||
return round_decimal(value, self.decimal_places)
|
||||
|
||||
def prepare_value(self, value):
|
||||
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
|
||||
|
||||
Why? It looks nice!
|
||||
"""
|
||||
return round_decimal(value, self.decimal_places, normalize=True)
|
||||
|
||||
|
||||
class InvenTreeMoneyField(MoneyField):
|
||||
"""Custom MoneyField for clean migrations while using dynamic currency settings."""
|
||||
@ -126,11 +138,16 @@ class DatePickerFormField(forms.DateField):
|
||||
)
|
||||
|
||||
|
||||
def round_decimal(value, places):
|
||||
def round_decimal(value, places, normalize=False):
|
||||
"""Round value to the specified number of places."""
|
||||
if value is not None:
|
||||
# see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options
|
||||
return value.quantize(Decimal(10) ** -places)
|
||||
|
||||
if type(value) in [Decimal, float]:
|
||||
value = round(value, places)
|
||||
|
||||
if normalize:
|
||||
# Remove any trailing zeroes
|
||||
value = InvenTree.helpers.normalize(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@ -140,18 +157,14 @@ class RoundingDecimalFormField(forms.DecimalField):
|
||||
def to_python(self, value):
|
||||
"""Convert value to python type."""
|
||||
value = super().to_python(value)
|
||||
value = round_decimal(value, self.decimal_places)
|
||||
return value
|
||||
return round_decimal(value, self.decimal_places)
|
||||
|
||||
def prepare_value(self, value):
|
||||
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
|
||||
|
||||
Why? It looks nice!
|
||||
"""
|
||||
if type(value) == Decimal:
|
||||
return InvenTree.helpers.normalize(value)
|
||||
else:
|
||||
return value
|
||||
return round_decimal(value, self.decimal_places, normalize=True)
|
||||
|
||||
|
||||
class RoundingDecimalField(models.DecimalField):
|
||||
|
@ -34,13 +34,14 @@ class InvenTreeMoneySerializer(MoneyField):
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Overrite default values."""
|
||||
kwargs["max_digits"] = kwargs.get("max_digits", 19)
|
||||
kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
|
||||
self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
|
||||
kwargs["required"] = kwargs.get("required", False)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_value(self, data):
|
||||
"""Test that the returned amount is a valid Decimal."""
|
||||
|
||||
amount = super(DecimalField, self).get_value(data)
|
||||
|
||||
# Convert an empty string to None
|
||||
@ -49,7 +50,9 @@ class InvenTreeMoneySerializer(MoneyField):
|
||||
|
||||
try:
|
||||
if amount is not None and amount is not empty:
|
||||
# Convert to a Decimal instance, and round to maximum allowed decimal places
|
||||
amount = Decimal(amount)
|
||||
amount = round(amount, self.decimal_places)
|
||||
except Exception:
|
||||
raise ValidationError({
|
||||
self.field_name: [_("Must be a valid number")],
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Unit tests for the various part API endpoints"""
|
||||
|
||||
from decimal import Decimal
|
||||
from random import randint
|
||||
|
||||
from django.urls import reverse
|
||||
@ -2430,3 +2431,65 @@ class PartAttachmentTest(InvenTreeAPITestCase):
|
||||
self.assertEqual(data['part'], 1)
|
||||
self.assertEqual(data['link'], link)
|
||||
self.assertEqual(data['comment'], 'Hello world')
|
||||
|
||||
|
||||
class PartInternalPriceBreakTest(InvenTreeAPITestCase):
|
||||
"""Unit tests for the PartInternalPrice API endpoints"""
|
||||
|
||||
fixtures = [
|
||||
'category',
|
||||
'part',
|
||||
'params',
|
||||
'location',
|
||||
'bom',
|
||||
'company',
|
||||
'test_templates',
|
||||
'manufacturer_part',
|
||||
'supplier_part',
|
||||
'order',
|
||||
'stock',
|
||||
]
|
||||
|
||||
roles = [
|
||||
'part.change',
|
||||
'part.add',
|
||||
'part.delete',
|
||||
'part_category.change',
|
||||
'part_category.add',
|
||||
'part_category.delete',
|
||||
]
|
||||
|
||||
def test_create_price_breaks(self):
|
||||
"""Test we can create price breaks at various quantities"""
|
||||
|
||||
url = reverse('api-part-internal-price-list')
|
||||
|
||||
breaks = [
|
||||
(1.0, 101),
|
||||
(1.1, 92.555555555),
|
||||
(1.5, 90.999999999),
|
||||
(1.756, 89),
|
||||
(2, 86),
|
||||
(25, 80)
|
||||
]
|
||||
|
||||
for q, p in breaks:
|
||||
data = self.post(
|
||||
url,
|
||||
{
|
||||
'part': 1,
|
||||
'quantity': q,
|
||||
'price': p,
|
||||
},
|
||||
expected_code=201
|
||||
).data
|
||||
|
||||
self.assertEqual(data['part'], 1)
|
||||
self.assertEqual(
|
||||
round(Decimal(data['quantity']), 4),
|
||||
round(Decimal(q), 4)
|
||||
)
|
||||
self.assertEqual(
|
||||
round(Decimal(data['price']), 4),
|
||||
round(Decimal(p), 4)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user