total_price for orders (#4447)

* Adds unit test for counting queries on PurchaseOrderList API endpoint

- We will work to make this queryset more efficient

* PEP fixes

* Add 'total_price' fields to SalesOrder and PurchaseOrder models

* PurchaseOrder list API now has constant query count

* Data migration for updating existing PurchaseOrder and SalesOrder instances

- Calculate total_price for any existing order
- Will fail if exchange rates are not available

* Add total_price_currency to API serializers

* Render total_price in SalesOrder table

* Add ability to filter both lists by total_price field

* Update total_price whenever an order is updated

* Update total price whenever a lineitem is saved or deleted

* Add query-counting unit test for SalesOrder list API

* Calling refresh_from_db inside a save() method is *not* a good idea
This commit is contained in:
Oliver 2023-03-05 22:22:18 +11:00 committed by GitHub
parent c0f405243a
commit 5ba75c868d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 406 additions and 37 deletions

View File

@ -8,6 +8,7 @@ from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group from django.contrib.auth.models import Group
from django.http.response import StreamingHttpResponse from django.http.response import StreamingHttpResponse
from djmoney.contrib.exchange.models import ExchangeBackend, Rate
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from plugin import registry from plugin import registry
@ -113,7 +114,42 @@ class PluginMixin:
self.plugin_confs = PluginConfig.objects.all() self.plugin_confs = PluginConfig.objects.all()
class InvenTreeAPITestCase(UserMixin, APITestCase): class ExchangeRateMixin:
"""Mixin class for generating exchange rate data"""
def generate_exchange_rates(self):
"""Helper function which generates some exchange rates to work with"""
rates = {
'AUD': 1.5,
'CAD': 1.7,
'GBP': 0.9,
'USD': 1.0,
}
# Create a dummy backend
ExchangeBackend.objects.create(
name='InvenTreeExchange',
base_currency='USD',
)
backend = ExchangeBackend.objects.get(name='InvenTreeExchange')
items = []
for currency, rate in rates.items():
items.append(
Rate(
currency=currency,
value=rate,
backend=backend,
)
)
Rate.objects.bulk_create(items)
class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
"""Base class for running InvenTree API tests.""" """Base class for running InvenTree API tests."""
def getActions(self, url): def getActions(self, url):

View File

@ -33,7 +33,7 @@ from common.notifications import (InvenTreeNotificationBodies,
NotificationBody, trigger_notification) NotificationBody, trigger_notification)
from common.settings import currency_code_default from common.settings import currency_code_default
from .api_tester import UserMixin from .api_tester import ExchangeRateMixin, UserMixin
from .settings import MEDIA_URL, STATIC_URL from .settings import MEDIA_URL, STATIC_URL
logger = logging.getLogger('inventree') logger = logging.getLogger('inventree')
@ -1059,7 +1059,7 @@ def inheritors(cls):
return subcls return subcls
class InvenTreeTestCase(UserMixin, TestCase): class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase):
"""Testcase with user setup buildin.""" """Testcase with user setup buildin."""
pass pass

View File

@ -322,6 +322,7 @@ class PurchaseOrderList(APIDownloadMixin, ListCreateAPI):
'line_items', 'line_items',
'status', 'status',
'responsible', 'responsible',
'total_price',
] ]
ordering = '-reference' ordering = '-reference'
@ -783,6 +784,7 @@ class SalesOrderList(APIDownloadMixin, ListCreateAPI):
'target_date', 'target_date',
'line_items', 'line_items',
'shipment_date', 'shipment_date',
'total_price',
] ]
search_fields = [ search_fields = [

View File

@ -0,0 +1,36 @@
# Generated by Django 3.2.18 on 2023-03-04 07:21
import InvenTree.fields
from django.db import migrations
import djmoney.models.fields
import djmoney.models.validators
class Migration(migrations.Migration):
dependencies = [
('order', '0077_auto_20230129_0154'),
]
operations = [
migrations.AddField(
model_name='purchaseorder',
name='total_price',
field=InvenTree.fields.InvenTreeModelMoneyField(blank=True, currency_choices=[], decimal_places=6, default_currency='', help_text='Total price for this order', max_digits=19, null=True, validators=[djmoney.models.validators.MinMoneyValidator(0)], verbose_name='Total Price'),
),
migrations.AddField(
model_name='purchaseorder',
name='total_price_currency',
field=djmoney.models.fields.CurrencyField(choices=[], default='', editable=False, max_length=3),
),
migrations.AddField(
model_name='salesorder',
name='total_price',
field=InvenTree.fields.InvenTreeModelMoneyField(blank=True, currency_choices=[], decimal_places=6, default_currency='', help_text='Total price for this order', max_digits=19, null=True, validators=[djmoney.models.validators.MinMoneyValidator(0)], verbose_name='Total Price'),
),
migrations.AddField(
model_name='salesorder',
name='total_price_currency',
field=djmoney.models.fields.CurrencyField(choices=[], default='', editable=False, max_length=3),
),
]

View File

@ -0,0 +1,131 @@
# Generated by Django 3.2.18 on 2023-03-04 09:04
import logging
from django.db import migrations
from djmoney.contrib.exchange.exceptions import MissingRate
from djmoney.contrib.exchange.models import convert_money
from djmoney.money import Money
from common.settings import currency_code_default
logger = logging.getLogger('inventree')
def update_purchase_order_price(apps, schema_editor):
"""Calculate 'total_price' field for each PurchaseOrder"""
PurchaseOrder = apps.get_model('order', 'purchaseorder')
currency = currency_code_default()
valid_count = 0
invalid_count = 0
for order in PurchaseOrder.objects.all():
valid = True
total_price = Money(0, currency)
for line in order.lines.all():
if line.purchase_price:
try:
total_price += convert_money(line.purchase_price, currency) * line.quantity
except MissingRate:
valid = False
break
for line in order.extra_lines.all():
if line.price:
try:
total_price += convert_money(line.price, currency) * line.quantity
except MissingRate:
valid = False
break
if valid:
order.total_price = total_price
order.save()
valid_count += 1
else:
invalid_count +=1
if valid_count > 0:
logger.info(f"Updated 'total_price' field for {valid_count} PurchaseOrder instances")
if invalid_count > 0:
logger.info(f"'total_price' field could not be updated for {invalid_count} PurchaseOrder instances")
def update_sales_order_price(apps, schema_editor):
"""Calculate 'total_price' field for each SalesOrder"""
SalesOrder = apps.get_model('order', 'salesorder')
currency = currency_code_default()
valid_count = 0
invalid_count = 0
for order in SalesOrder.objects.all():
valid = True
total_price = Money(0, currency)
for line in order.lines.all():
if line.sale_price:
try:
total_price += convert_money(line.sale_price, currency) * line.quantity
except MissingRate:
valid = False
break
for line in order.extra_lines.all():
if line.price:
try:
total_price += convert_money(line.price, currency) * line.quantity
except MissingRate:
valid = False
break
if valid:
order.total_price = total_price
order.save()
valid_count += 1
else:
invalid_count +=1
if valid_count > 0:
logger.info(f"Updated 'total_price' field for {valid_count} SalesOrder instances")
if invalid_count > 0:
logger.info(f"'total_price' field could not be updated for {invalid_count} SalesOrder instances")
def reverse(apps, schema_editor):
"""Reverse migration (does nothing)"""
pass
class Migration(migrations.Migration):
dependencies = [
('order', '0078_auto_20230304_0721'),
]
operations = [
migrations.RunPython(
update_purchase_order_price,
reverse_code=reverse
),
migrations.RunPython(
update_sales_order_price,
reverse_code=reverse,
)
]

View File

@ -51,6 +51,7 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
Instances of this class: Instances of this class:
- PuchaseOrder - PuchaseOrder
- SalesOrder
Attributes: Attributes:
reference: Unique order number / reference / code reference: Unique order number / reference / code
@ -77,6 +78,9 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
if not self.creation_date: if not self.creation_date:
self.creation_date = datetime.now().date() self.creation_date = datetime.now().date()
# Recalculate total_price for this order
self.update_total_price(commit=False)
super().save(*args, **kwargs) super().save(*args, **kwargs)
description = models.CharField(max_length=250, verbose_name=_('Description'), help_text=_('Order description')) description = models.CharField(max_length=250, verbose_name=_('Description'), help_text=_('Order description'))
@ -103,7 +107,22 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
notes = InvenTreeNotesField(help_text=_('Order notes')) notes = InvenTreeNotesField(help_text=_('Order notes'))
def get_total_price(self, target_currency=None): total_price = InvenTreeModelMoneyField(
null=True, blank=True,
allow_negative=False,
verbose_name=_('Total Price'),
help_text=_('Total price for this order')
)
def update_total_price(self, commit=True):
"""Recalculate and save the total_price for this order"""
self.total_price = self.calculate_total_price()
if commit:
self.save()
def calculate_total_price(self, target_currency=None):
"""Calculates the total price of all order lines, and converts to the specified target currency. """Calculates the total price of all order lines, and converts to the specified target currency.
If not specified, the default system currency is used. If not specified, the default system currency is used.
@ -134,7 +153,7 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
# Record the error, try to press on # Record the error, try to press on
kind, info, data = sys.exc_info() kind, info, data = sys.exc_info()
log_error('order.get_total_price') log_error('order.calculate_total_price')
logger.error(f"Missing exchange rate for '{target_currency}'") logger.error(f"Missing exchange rate for '{target_currency}'")
# Return None to indicate the calculated price is invalid # Return None to indicate the calculated price is invalid
@ -151,7 +170,7 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
except MissingRate: except MissingRate:
# Record the error, try to press on # Record the error, try to press on
log_error('order.get_total_price') log_error('order.calculate_total_price')
logger.error(f"Missing exchange rate for '{target_currency}'") logger.error(f"Missing exchange rate for '{target_currency}'")
# Return None to indicate the calculated price is invalid # Return None to indicate the calculated price is invalid
@ -933,6 +952,24 @@ class OrderLineItem(models.Model):
"""Metaclass options. Abstract ensures no database table is created.""" """Metaclass options. Abstract ensures no database table is created."""
abstract = True abstract = True
def save(self, *args, **kwargs):
"""Custom save method for the OrderLineItem model
Calls save method on the linked order
"""
super().save(*args, **kwargs)
self.order.save()
def delete(self, *args, **kwargs):
"""Custom delete method for the OrderLineItem model
Calls save method on the linked order
"""
super().delete(*args, **kwargs)
self.order.save()
quantity = RoundingDecimalField( quantity = RoundingDecimalField(
verbose_name=_('Quantity'), verbose_name=_('Quantity'),
help_text=_('Item quantity'), help_text=_('Item quantity'),

View File

@ -34,11 +34,12 @@ class AbstractOrderSerializer(serializers.Serializer):
"""Abstract field definitions for OrderSerializers.""" """Abstract field definitions for OrderSerializers."""
total_price = InvenTreeMoneySerializer( total_price = InvenTreeMoneySerializer(
source='get_total_price',
allow_null=True, allow_null=True,
read_only=True, read_only=True,
) )
total_price_currency = InvenTreeCurrencySerializer(read_only=True)
class AbstractExtraLineSerializer(serializers.Serializer): class AbstractExtraLineSerializer(serializers.Serializer):
"""Abstract Serializer for a ExtraLine object.""" """Abstract Serializer for a ExtraLine object."""
@ -105,6 +106,7 @@ class PurchaseOrderSerializer(AbstractOrderSerializer, InvenTreeModelSerializer)
'target_date', 'target_date',
'notes', 'notes',
'total_price', 'total_price',
'total_price_currency',
] ]
read_only_fields = [ read_only_fields = [
@ -668,6 +670,7 @@ class SalesOrderSerializer(AbstractOrderSerializer, InvenTreeModelSerializer):
'shipment_date', 'shipment_date',
'target_date', 'target_date',
'total_price', 'total_price',
'total_price_currency',
] ]
read_only_fields = [ read_only_fields = [

View File

@ -191,7 +191,7 @@ src="{% static 'img/blank_image.png' %}"
<td><span class='fas fa-dollar-sign'></span></td> <td><span class='fas fa-dollar-sign'></span></td>
<td>{% trans "Total cost" %}</td> <td>{% trans "Total cost" %}</td>
<td id="poTotalPrice"> <td id="poTotalPrice">
{% with order.get_total_price as tp %} {% with order.total_price as tp %}
{% if tp == None %} {% if tp == None %}
<span class='badge bg-warning'>{% trans "Total cost could not be calculated" %}</span> <span class='badge bg-warning'>{% trans "Total cost could not be calculated" %}</span>
{% else %} {% else %}

View File

@ -189,7 +189,7 @@ src="{% static 'img/blank_image.png' %}"
<td><span class='fas fa-dollar-sign'></span></td> <td><span class='fas fa-dollar-sign'></span></td>
<td>{% trans "Total cost" %}</td> <td>{% trans "Total cost" %}</td>
<td id="soTotalPrice"> <td id="soTotalPrice">
{% with order.get_total_price as tp %} {% with order.total_price as tp %}
{% if tp == None %} {% if tp == None %}
<span class='badge bg-warning'>{% trans "Total cost could not be calculated" %}</span> <span class='badge bg-warning'>{% trans "Total cost could not be calculated" %}</span>
{% else %} {% else %}

View File

@ -5,12 +5,17 @@ import io
from datetime import datetime, timedelta from datetime import datetime, timedelta
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import connection
from django.test.utils import CaptureQueriesContext
from django.urls import reverse from django.urls import reverse
from djmoney.money import Money
from icalendar import Calendar from icalendar import Calendar
from rest_framework import status from rest_framework import status
import order.models as models import order.models as models
from common.settings import currency_codes
from company.models import Company
from InvenTree.api_tester import InvenTreeAPITestCase from InvenTree.api_tester import InvenTreeAPITestCase
from InvenTree.status_codes import PurchaseOrderStatus, SalesOrderStatus from InvenTree.status_codes import PurchaseOrderStatus, SalesOrderStatus
from part.models import Part from part.models import Part
@ -91,6 +96,65 @@ class PurchaseOrderTest(OrderTest):
self.filter({'supplier_part': 3}, 2) self.filter({'supplier_part': 3}, 2)
self.filter({'supplier_part': 4}, 0) self.filter({'supplier_part': 4}, 0)
def test_total_price(self):
"""Unit tests for the 'total_price' field"""
# Ensure we have exchange rate data
self.generate_exchange_rates()
currencies = currency_codes()
n = len(currencies)
idx = 0
new_orders = []
# Let's generate some more orders
for supplier in Company.objects.filter(is_supplier=True):
for _idx in range(10):
new_orders.append(
models.PurchaseOrder(
supplier=supplier,
reference=f'PO-{idx + 100}'
)
)
idx += 1
models.PurchaseOrder.objects.bulk_create(new_orders)
idx = 0
# Create some purchase order line items
lines = []
for po in models.PurchaseOrder.objects.all():
for sp in po.supplier.supplied_parts.all():
lines.append(
models.PurchaseOrderLineItem(
order=po,
part=sp,
quantity=idx + 1,
purchase_price=Money((idx + 1) / 10, currencies[idx % n]),
)
)
idx += 1
models.PurchaseOrderLineItem.objects.bulk_create(lines)
# List all purchase orders
for limit in [1, 5, 10, 100]:
with CaptureQueriesContext(connection) as ctx:
response = self.get(self.LIST_URL, data={'limit': limit}, expected_code=200)
# Total database queries must be below 15, independent of the number of results
self.assertLess(len(ctx), 15)
for result in response.data['results']:
self.assertIn('total_price', result)
self.assertIn('total_price_currency', result)
def test_overdue(self): def test_overdue(self):
"""Test "overdue" status.""" """Test "overdue" status."""
self.filter({'overdue': True}, 0) self.filter({'overdue': True}, 0)
@ -1001,6 +1065,79 @@ class SalesOrderTest(OrderTest):
self.filter({'assigned_to_me': 1}, 0) self.filter({'assigned_to_me': 1}, 0)
self.filter({'assigned_to_me': 0}, 5) self.filter({'assigned_to_me': 0}, 5)
def test_total_price(self):
"""Unit tests for the 'total_price' field"""
# Ensure we have exchange rate data
self.generate_exchange_rates()
currencies = currency_codes()
n = len(currencies)
idx = 0
new_orders = []
# Generate some new SalesOrders
for customer in Company.objects.filter(is_customer=True):
for _idx in range(10):
new_orders.append(
models.SalesOrder(
customer=customer,
reference=f'SO-{idx + 100}',
)
)
idx += 1
models.SalesOrder.objects.bulk_create(new_orders)
idx = 0
# Create some new SalesOrderLineItem objects
lines = []
extra_lines = []
for so in models.SalesOrder.objects.all():
for p in Part.objects.filter(salable=True):
lines.append(
models.SalesOrderLineItem(
order=so,
part=p,
quantity=idx + 1,
sale_price=Money((idx + 1) / 5, currencies[idx % n])
)
)
idx += 1
# Create some extra lines against this order
for ii in range(3):
extra_lines.append(
models.SalesOrderExtraLine(
order=so,
quantity=(idx + 2) % 10,
price=Money(10, 'CAD'),
)
)
models.SalesOrderLineItem.objects.bulk_create(lines)
models.SalesOrderExtraLine.objects.bulk_create(extra_lines)
# List all SalesOrder objects and count queries
for limit in [1, 5, 10, 100]:
with CaptureQueriesContext(connection) as ctx:
response = self.get(self.LIST_URL, data={'limit': limit}, expected_code=200)
# Total database queries must be less than 15
self.assertLess(len(ctx), 15)
n = len(response.data['results'])
for result in response.data['results']:
self.assertIn('total_price', result)
self.assertIn('total_price_currency', result)
def test_overdue(self): def test_overdue(self):
"""Test "overdue" status.""" """Test "overdue" status."""
self.filter({'overdue': True}, 0) self.filter({'overdue': True}, 0)

View File

@ -2,7 +2,6 @@
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from djmoney.contrib.exchange.models import ExchangeBackend, Rate
from djmoney.money import Money from djmoney.money import Money
import common.models import common.models
@ -18,31 +17,6 @@ from InvenTree.status_codes import PurchaseOrderStatus
class PartPricingTests(InvenTreeTestCase): class PartPricingTests(InvenTreeTestCase):
"""Unit tests for part pricing calculations""" """Unit tests for part pricing calculations"""
def generate_exchange_rates(self):
"""Generate some exchange rates to work with"""
rates = {
'AUD': 1.5,
'CAD': 1.7,
'GBP': 0.9,
'USD': 1.0,
}
# Create a dummy backend
ExchangeBackend.objects.create(
name='InvenTreeExchange',
base_currency='USD',
)
backend = ExchangeBackend.objects.get(name='InvenTreeExchange')
for currency, rate in rates.items():
Rate.objects.create(
currency=currency,
value=rate,
backend=backend,
)
def setUp(self): def setUp(self):
"""Setup routines""" """Setup routines"""

View File

@ -2155,9 +2155,11 @@ function loadPurchaseOrderTable(table, options) {
field: 'total_price', field: 'total_price',
title: '{% trans "Total Cost" %}', title: '{% trans "Total Cost" %}',
switchable: true, switchable: true,
sortable: false, sortable: true,
formatter: function(value, row) { formatter: function(value, row) {
return formatCurrency(value); return formatCurrency(value, {
currency: row.total_price_currency,
});
}, },
}, },
{ {
@ -2979,6 +2981,17 @@ function loadSalesOrderTable(table, options) {
field: 'line_items', field: 'line_items',
title: '{% trans "Items" %}' title: '{% trans "Items" %}'
}, },
{
field: 'total_price',
title: '{% trans "Total Cost" %}',
switchable: true,
sortable: true,
formatter: function(value, row) {
return formatCurrency(value, {
currency: row.total_price_currency,
});
}
}
], ],
}); });
} }