Docstring checks in QC checks (#3089)

* Add pre-commit to the stack

* exclude static

* Add locales to excludes

* fix style errors

* rename pipeline steps

* also wait on precommit

* make template matching simpler

* Use the same code for python setup everywhere

* use step and cache for python setup

* move regular settings up into general envs

* just use full update

* Use invoke instead of static references

* make setup actions more similar

* use python3

* refactor names to be similar

* fix runner version

* fix references

* remove incidential change

* use matrix for os

* Github can't do this right now

* ignore docstyle errors

* Add seperate docstring test

* update flake call

* do not fail on docstring

* refactor setup into workflow

* update reference

* switch to action

* resturcture

* add bash statements

* remove os from cache

* update input checks

* make code cleaner

* fix boolean

* no relative paths

* install wheel by python

* switch to install

* revert back to simple wheel

* refactor import export tests

* move setup keys back to not disturbe tests

* remove docstyle till that is fixed

* update references

* continue on error

* add docstring test

* use relativ action references

* Change step / job docstrings

* update to merge

* reformat comments 1

* fix docstrings 2

* fix docstrings 3

* fix docstrings 4

* fix docstrings 5

* fix docstrings 6

* fix docstrings 7

* fix docstrings 8

* fix docstirns 9

* fix docstrings 10

* docstring adjustments

* update the remaining docstrings

* small docstring changes

* fix function name

* update support files for docstrings

* Add missing args to docstrings

* Remove outdated function

* Add docstrings for the 'build' app

* Make API code cleaner

* add more docstrings for plugin app

* Remove dead code for plugin settings
No idea what that was even intended for

* ignore __init__ files for docstrings

* More docstrings

* Update docstrings for the 'part' directory

* Fixes for related_part functionality

* Fix removed stuff from merge 99676ee

* make more consistent

* Show statistics for docstrings

* add more docstrings

* move specific register statements to make them clearer to understant

* More docstrings for common

* and more docstrings

* and more

* simpler call

* docstrings for notifications

* docstrings for common/tests

* Add docs for common/models

* Revert "move specific register statements to make them clearer to understant"

This reverts commit ca96654622.

* use typing here

* Revert "Make API code cleaner"

This reverts commit 24fb68bd3e.

* docstring updates for the 'users' app

* Add generic Meta info to simple Meta classes

* remove unneeded unique_together statements

* More simple metas

* Remove unnecessary format specifier

* Remove extra json format specifiers

* Add docstrings for the 'plugin' app

* Docstrings for the 'label' app

* Add missing docstrings for the 'report' app

* Fix build test regression

* Fix top-level files

* docstrings for InvenTree/InvenTree

* reduce unneeded code

* add docstrings

* and more docstrings

* more docstrings

* more docstrings for stock

* more docstrings

* docstrings for order/views

* Docstrings for various files in the 'order' app

* Docstrings for order/test_api.py

* Docstrings for order/serializers.py

* Docstrings for order/admin.py

* More docstrings for the order app

* Add docstrings for the 'company' app

* Add unit tests for rebuilding the reference fields

* Prune out some more dead code

* remove more dead code

Co-authored-by: Oliver Walters <oliver.henry.walters@gmail.com>
This commit is contained in:
Matthias Mair 2022-06-01 17:37:39 +02:00 committed by GitHub
parent 66a6915213
commit 0c97a50e47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
223 changed files with 4416 additions and 6980 deletions

View File

@ -130,6 +130,22 @@ jobs:
invoke check-server
coverage run -m unittest discover -s test/
docstyle:
name: Style [Python Docstrings]
runs-on: ubuntu-20.04
needs: pre-commit
continue-on-error: true
steps:
- uses: actions/checkout@v1
- name: Enviroment Setup
uses: ./.github/actions/setup
with:
install: true
- name: Run flake8
run: flake8 InvenTree --statistics
coverage:
name: Tests - DB [SQLite] + Coverage
runs-on: ubuntu-20.04

View File

@ -1,5 +1,4 @@
"""
The InvenTree module provides high-level management and functionality.
"""The InvenTree module provides high-level management and functionality.
It provides a number of helper functions and generic classes which are used by InvenTree apps.
"""

View File

@ -1,6 +1,4 @@
"""
Main JSON interface views
"""
"""Main JSON interface views."""
from django.conf import settings
from django.http import JsonResponse
@ -16,14 +14,15 @@ from .views import AjaxView
class InfoView(AjaxView):
""" Simple JSON endpoint for InvenTree information.
"""Simple JSON endpoint for InvenTree information.
Use to confirm that the server is running, etc.
"""
permission_classes = [permissions.AllowAny]
def get(self, request, *args, **kwargs):
"""Serve current server information."""
data = {
'server': 'InvenTree',
'version': inventreeVersion(),
@ -37,14 +36,12 @@ class InfoView(AjaxView):
class NotFoundView(AjaxView):
"""
Simple JSON view when accessing an invalid API view.
"""
"""Simple JSON view when accessing an invalid API view."""
permission_classes = [permissions.AllowAny]
def get(self, request, *args, **kwargs):
"""Proces an `not found` event on the API."""
data = {
'details': _('API endpoint not found'),
'url': request.build_absolute_uri(),
@ -54,8 +51,7 @@ class NotFoundView(AjaxView):
class APIDownloadMixin:
"""
Mixin for enabling a LIST endpoint to be downloaded a file.
"""Mixin for enabling a LIST endpoint to be downloaded a file.
To download the data, add the ?export=<fmt> to the query string.
@ -76,7 +72,7 @@ class APIDownloadMixin:
"""
def get(self, request, *args, **kwargs):
"""Generic handler for a download request."""
export_format = request.query_params.get('export', None)
if export_format and export_format in ['csv', 'tsv', 'xls', 'xlsx']:
@ -88,14 +84,12 @@ class APIDownloadMixin:
return super().get(request, *args, **kwargs)
def download_queryset(self, queryset, export_format):
"""This function must be implemented to provide a downloadFile request."""
raise NotImplementedError("download_queryset method not implemented!")
class AttachmentMixin:
"""
Mixin for creating attachment objects,
and ensuring the user information is saved correctly.
"""
"""Mixin for creating attachment objects, and ensuring the user information is saved correctly."""
permission_classes = [permissions.IsAuthenticated]
@ -106,8 +100,7 @@ class AttachmentMixin:
]
def perform_create(self, serializer):
""" Save the user information when a file is uploaded """
"""Save the user information when a file is uploaded."""
attachment = serializer.save()
attachment.user = self.request.user
attachment.save()

View File

@ -1,6 +1,4 @@
"""
Helper functions for performing API unit tests
"""
"""Helper functions for performing API unit tests."""
import csv
import io
@ -14,6 +12,10 @@ from rest_framework.test import APITestCase
class UserMixin:
"""Mixin to setup a user and login for tests.
Use parameters to set username, password, email, roles and permissions.
"""
# User information
username = 'testuser'
@ -28,7 +30,7 @@ class UserMixin:
roles = []
def setUp(self):
"""Setup for all tests."""
super().setUp()
# Create a user to log in with
@ -62,10 +64,7 @@ class UserMixin:
self.client.login(username=self.username, password=self.password)
def assignRole(self, role=None, assign_all: bool = False):
"""
Set the user roles for the registered user
"""
"""Set the user roles for the registered user."""
# role is of the format 'rule.permission' e.g. 'part.add'
if not assign_all and role:
@ -89,16 +88,13 @@ class UserMixin:
class InvenTreeAPITestCase(UserMixin, APITestCase):
"""
Base class for running InvenTree API tests
"""
"""Base class for running InvenTree API tests."""
def getActions(self, url):
"""
Return a dict of the 'actions' available at a given endpoint.
"""Return a dict of the 'actions' available at a given endpoint.
Makes use of the HTTP 'OPTIONS' method to request this.
"""
response = self.client.options(url)
self.assertEqual(response.status_code, 200)
@ -110,10 +106,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase):
return actions
def get(self, url, data={}, expected_code=200):
"""
Issue a GET request
"""
"""Issue a GET request."""
response = self.client.get(url, data, format='json')
if expected_code is not None:
@ -127,10 +120,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase):
return response
def post(self, url, data, expected_code=None, format='json'):
"""
Issue a POST request
"""
"""Issue a POST request."""
response = self.client.post(url, data=data, format=format)
if expected_code is not None:
@ -139,10 +129,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase):
return response
def delete(self, url, expected_code=None):
"""
Issue a DELETE request
"""
"""Issue a DELETE request."""
response = self.client.delete(url)
if expected_code is not None:
@ -151,10 +138,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase):
return response
def patch(self, url, data, expected_code=None, format='json'):
"""
Issue a PATCH request
"""
"""Issue a PATCH request."""
response = self.client.patch(url, data=data, format=format)
if expected_code is not None:
@ -163,10 +147,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase):
return response
def put(self, url, data, expected_code=None, format='json'):
"""
Issue a PUT request
"""
"""Issue a PUT request."""
response = self.client.put(url, data=data, format=format)
if expected_code is not None:
@ -175,10 +156,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase):
return response
def options(self, url, expected_code=None):
"""
Issue an OPTIONS request
"""
"""Issue an OPTIONS request."""
response = self.client.options(url, format='json')
if expected_code is not None:
@ -187,10 +165,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase):
return response
def download_file(self, url, data, expected_code=None, expected_fn=None, decode=True):
"""
Download a file from the server, and return an in-memory file
"""
"""Download a file from the server, and return an in-memory file."""
response = self.client.get(url, data=data, format='json')
if expected_code is not None:
@ -226,10 +201,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase):
return fo
def process_csv(self, fo, delimiter=',', required_cols=None, excluded_cols=None, required_rows=None):
"""
Helper function to process and validate a downloaded csv file
"""
"""Helper function to process and validate a downloaded csv file."""
# Check that the correct object type has been passed
self.assertTrue(isinstance(fo, io.StringIO))

View File

@ -1,6 +1,4 @@
"""
InvenTree API version information
"""
"""InvenTree API version information."""
# InvenTree API version

View File

@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""AppConfig for inventree app."""
import logging
@ -18,10 +18,11 @@ logger = logging.getLogger("inventree")
class InvenTreeConfig(AppConfig):
"""AppConfig for inventree app."""
name = 'InvenTree'
def ready(self):
"""Setup background tasks and update exchange rates."""
if canAppAccessDatabase():
self.remove_obsolete_tasks()
@ -37,10 +38,7 @@ class InvenTreeConfig(AppConfig):
self.add_user_on_startup()
def remove_obsolete_tasks(self):
"""
Delete any obsolete scheduled tasks in the database
"""
"""Delete any obsolete scheduled tasks in the database."""
obsolete = [
'InvenTree.tasks.delete_expired_sessions',
'stock.tasks.delete_old_stock_items',
@ -55,7 +53,7 @@ class InvenTreeConfig(AppConfig):
Schedule.objects.filter(func__in=obsolete).delete()
def start_background_tasks(self):
"""Start all background tests for InvenTree."""
try:
from django_q.models import Schedule
except AppRegistryNotReady: # pragma: no cover
@ -101,13 +99,12 @@ class InvenTreeConfig(AppConfig):
)
def update_exchange_rates(self): # pragma: no cover
"""
Update exchange rates each time the server is started, *if*:
"""Update exchange rates each time the server is started.
Only runs *if*:
a) Have not been updated recently (one day or less)
b) The base exchange rate has been altered
"""
try:
from djmoney.contrib.exchange.models import ExchangeBackend
@ -150,7 +147,7 @@ class InvenTreeConfig(AppConfig):
logger.error(f"Error updating exchange rates: {e}")
def add_user_on_startup(self):
"""Add a user on startup"""
"""Add a user on startup."""
# stop if checks were already created
if hasattr(settings, 'USER_ADDED') and settings.USER_ADDED:
return
@ -202,9 +199,7 @@ class InvenTreeConfig(AppConfig):
settings.USER_ADDED = True
def collect_notification_methods(self):
"""
Collect all notification methods
"""
"""Collect all notification methods."""
from common.notifications import storage
storage.collect()

View File

@ -1,6 +1,6 @@
"""
Pull rendered copies of the templated
only used for testing the js files! - This file is omited from coverage
"""Pull rendered copies of the templated.
Only used for testing the js files! - This file is omited from coverage.
"""
import os # pragma: no cover
@ -10,15 +10,14 @@ from InvenTree.helpers import InvenTreeTestCase # pragma: no cover
class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
"""
A unit test to "render" javascript files.
"""A unit test to "render" javascript files.
The server renders templated javascript files,
we need the fully-rendered files for linting and static tests.
"""
def download_file(self, filename, prefix):
"""Function to `download`(copy) a file to a temporay firectory."""
url = os.path.join(prefix, filename)
response = self.client.get(url)
@ -46,6 +45,7 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
output.write(response.content)
def download_files(self, subdir, prefix):
"""Download files in directory."""
here = os.path.abspath(os.path.dirname(__file__))
js_template_dir = os.path.join(
@ -73,10 +73,7 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
return n
def test_render_files(self):
"""
Look for all javascript files
"""
"""Look for all javascript files."""
n = 0
print("Rendering javascript files...")

View File

@ -1,6 +1,4 @@
"""
Helper functions for loading InvenTree configuration options
"""
"""Helper functions for loading InvenTree configuration options."""
import logging
import os
@ -10,17 +8,15 @@ logger = logging.getLogger('inventree')
def get_base_dir():
""" Returns the base (top-level) InvenTree directory """
"""Returns the base (top-level) InvenTree directory."""
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def get_config_file():
"""
Returns the path of the InvenTree configuration file.
"""Returns the path of the InvenTree configuration file.
Note: It will be created it if does not already exist!
"""
base_dir = get_base_dir()
cfg_filename = os.getenv('INVENTREE_CONFIG_FILE')
@ -43,8 +39,7 @@ def get_config_file():
def get_plugin_file():
"""
Returns the path of the InvenTree plugins specification file.
"""Returns the path of the InvenTree plugins specification file.
Note: It will be created if it does not already exist!
"""
@ -70,14 +65,12 @@ def get_plugin_file():
def get_setting(environment_var, backup_val, default_value=None):
"""
Helper function for retrieving a configuration setting value
"""Helper function for retrieving a configuration setting value.
- First preference is to look for the environment variable
- Second preference is to look for the value of the settings file
- Third preference is the default value
"""
val = os.getenv(environment_var)
if val is not None:

View File

@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
"""
Provides extra global data to all templates.
"""
"""Provides extra global data to all templates."""
import InvenTree.status
from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatus,
@ -12,13 +10,11 @@ from users.models import RuleSet
def health_status(request):
"""
Provide system health status information to the global context.
"""Provide system health status information to the global context.
- Not required for AJAX requests
- Do not provide if it is already provided to the context
"""
if request.path.endswith('.js'):
# Do not provide to script requests
return {} # pragma: no cover
@ -53,10 +49,7 @@ def health_status(request):
def status_codes(request):
"""
Provide status code enumerations.
"""
"""Provide status code enumerations."""
if hasattr(request, '_inventree_status_codes'):
# Do not duplicate efforts
return {}
@ -74,8 +67,7 @@ def status_codes(request):
def user_roles(request):
"""
Return a map of the current roles assigned to the user.
"""Return a map of the current roles assigned to the user.
Roles are denoted by their simple names, and then the permission type.
@ -86,7 +78,6 @@ def user_roles(request):
Each value will return a boolean True / False
"""
user = request.user
roles = {

View File

@ -1,6 +1,4 @@
"""
Custom exception handling for the DRF API
"""
"""Custom exception handling for the DRF API."""
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
@ -21,13 +19,11 @@ from rest_framework.response import Response
def exception_handler(exc, context):
"""
Custom exception handler for DRF framework.
Ref: https://www.django-rest-framework.org/api-guide/exceptions/#custom-exception-handling
"""Custom exception handler for DRF framework.
Ref: https://www.django-rest-framework.org/api-guide/exceptions/#custom-exception-handling
Catches any errors not natively handled by DRF, and re-throws as an error DRF can handle
"""
response = None
# Catch any django validation error, and re-throw a DRF validation error

View File

@ -1,3 +1,5 @@
"""Exchangerate backend to use `exchangerate.host` to get rates."""
import ssl
from urllib.error import URLError
from urllib.request import urlopen
@ -11,30 +13,30 @@ from common.settings import currency_code_default, currency_codes
class InvenTreeExchange(SimpleExchangeBackend):
"""
Backend for automatically updating currency exchange rates.
"""Backend for automatically updating currency exchange rates.
Uses the exchangerate.host service API
Uses the `exchangerate.host` service API
"""
name = "InvenTreeExchange"
def __init__(self):
"""Set API url."""
self.url = "https://api.exchangerate.host/latest"
super().__init__()
def get_params(self):
"""Placeholder to set API key. Currently not required by `exchangerate.host`."""
# No API key is required
return {
}
def get_response(self, **kwargs):
"""
Custom code to get response from server.
"""Custom code to get response from server.
Note: Adds a 5-second timeout
"""
url = self.get_url(**kwargs)
try:
@ -46,7 +48,7 @@ class InvenTreeExchange(SimpleExchangeBackend):
return None
def update_rates(self, base_currency=currency_code_default()):
"""Set the requested currency codes and get rates."""
symbols = ','.join(currency_codes())
try:

View File

@ -1,4 +1,4 @@
""" Custom fields used in InvenTree """
"""Custom fields used in InvenTree."""
import sys
from decimal import Decimal
@ -19,24 +19,25 @@ from .validators import allowable_url_schemes
class InvenTreeURLFormField(FormURLField):
""" Custom URL form field with custom scheme validators """
"""Custom URL form field with custom scheme validators."""
default_validators = [validators.URLValidator(schemes=allowable_url_schemes())]
class InvenTreeURLField(models.URLField):
""" Custom URL field which has custom scheme validators """
"""Custom URL field which has custom scheme validators."""
default_validators = [validators.URLValidator(schemes=allowable_url_schemes())]
def formfield(self, **kwargs):
"""Return a Field instance for this field."""
return super().formfield(**{
'form_class': InvenTreeURLFormField
})
def money_kwargs():
""" returns the database settings for MoneyFields """
"""Returns the database settings for MoneyFields."""
from common.settings import currency_code_default, currency_code_mappings
kwargs = {}
@ -46,11 +47,10 @@ def money_kwargs():
class InvenTreeModelMoneyField(ModelMoneyField):
"""
Custom MoneyField for clean migrations while using dynamic currency settings
"""
"""Custom MoneyField for clean migrations while using dynamic currency settings."""
def __init__(self, **kwargs):
"""Overwrite default values and validators."""
# detect if creating migration
if 'migrate' in sys.argv or 'makemigrations' in sys.argv:
# remove currency information for a clean migration
@ -73,26 +73,24 @@ class InvenTreeModelMoneyField(ModelMoneyField):
super().__init__(**kwargs)
def formfield(self, **kwargs):
""" override form class to use own function """
"""Override form class to use own function."""
kwargs['form_class'] = InvenTreeMoneyField
return super().formfield(**kwargs)
class InvenTreeMoneyField(MoneyField):
""" custom MoneyField for clean migrations while using dynamic currency settings """
"""Custom MoneyField for clean migrations while using dynamic currency settings."""
def __init__(self, *args, **kwargs):
# override initial values with the real info from database
"""Override initial values with the real info from database."""
kwargs.update(money_kwargs())
super().__init__(*args, **kwargs)
class DatePickerFormField(forms.DateField):
"""
Custom date-picker field
"""
"""Custom date-picker field."""
def __init__(self, **kwargs):
"""Set up custom values."""
help_text = kwargs.get('help_text', _('Enter date'))
label = kwargs.get('label', None)
required = kwargs.get('required', False)
@ -115,10 +113,7 @@ class DatePickerFormField(forms.DateField):
def round_decimal(value, places):
"""
Round value to the specified number of places.
"""
"""Round value to the specified number of places."""
if value is not None:
# see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options
return value.quantize(Decimal(10) ** -places)
@ -126,17 +121,19 @@ def round_decimal(value, places):
class RoundingDecimalFormField(forms.DecimalField):
"""Custom FormField that automatically rounds inputs."""
def to_python(self, value):
"""Convert value to python type."""
value = super().to_python(value)
value = round_decimal(value, self.decimal_places)
return value
def prepare_value(self, value):
"""
Override the 'prepare_value' method, to remove trailing zeros when displaying.
"""Override the 'prepare_value' method, to remove trailing zeros when displaying.
Why? It looks nice!
"""
if type(value) == Decimal:
return InvenTree.helpers.normalize(value)
else:
@ -144,11 +141,15 @@ class RoundingDecimalFormField(forms.DecimalField):
class RoundingDecimalField(models.DecimalField):
"""Custom Field that automatically rounds inputs."""
def to_python(self, value):
"""Convert value to python type."""
value = super().to_python(value)
return round_decimal(value, self.decimal_places)
def formfield(self, **kwargs):
"""Return a Field instance for this field."""
defaults = {
'form_class': RoundingDecimalFormField
}

View File

@ -1,9 +1,10 @@
"""General filters for InvenTree."""
from rest_framework.filters import OrderingFilter
class InvenTreeOrderingFilter(OrderingFilter):
"""
Custom OrderingFilter class which allows aliased filtering of related fields.
"""Custom OrderingFilter class which allows aliased filtering of related fields.
To use, simply specify this filter in the "filter_backends" section.
@ -20,16 +21,14 @@ class InvenTreeOrderingFilter(OrderingFilter):
"""
def get_ordering(self, request, queryset, view):
"""Override ordering for supporting aliases."""
ordering = super().get_ordering(request, queryset, view)
aliases = getattr(view, 'ordering_field_aliases', None)
# Attempt to map ordering fields based on provided aliases
if ordering is not None and aliases is not None:
"""
Ordering fields should be mapped to separate fields
"""
"""Ordering fields should be mapped to separate fields."""
ordering_initial = ordering
ordering = []

View File

@ -1,6 +1,4 @@
"""
Helper forms which subclass Django forms to provide additional functionality
"""
"""Helper forms which subclass Django forms to provide additional functionality."""
import logging
from urllib.parse import urlencode
@ -30,7 +28,7 @@ logger = logging.getLogger('inventree')
class HelperForm(forms.ModelForm):
""" Provides simple integration of crispy_forms extension. """
"""Provides simple integration of crispy_forms extension."""
# Custom field decorations can be specified here, per form class
field_prefix = {}
@ -38,6 +36,7 @@ class HelperForm(forms.ModelForm):
field_placeholder = {}
def __init__(self, *args, **kwargs):
"""Setup layout."""
super(forms.ModelForm, self).__init__(*args, **kwargs)
self.helper = FormHelper()
@ -54,14 +53,8 @@ class HelperForm(forms.ModelForm):
self.rebuild_layout()
def is_valid(self):
valid = super().is_valid()
return valid
def rebuild_layout(self):
"""Build crispy layout out of current fields."""
layouts = []
for field in self.fields:
@ -117,7 +110,7 @@ class HelperForm(forms.ModelForm):
class ConfirmForm(forms.Form):
""" Generic confirmation form """
"""Generic confirmation form."""
confirm = forms.BooleanField(
required=False, initial=False,
@ -125,14 +118,15 @@ class ConfirmForm(forms.Form):
)
class Meta:
"""Metaclass options."""
fields = [
'confirm'
]
class DeleteForm(forms.Form):
""" Generic deletion form which provides simple user confirmation
"""
"""Generic deletion form which provides simple user confirmation."""
confirm_delete = forms.BooleanField(
required=False,
@ -142,17 +136,19 @@ class DeleteForm(forms.Form):
)
class Meta:
"""Metaclass options."""
fields = [
'confirm_delete'
]
class EditUserForm(HelperForm):
"""
Form for editing user information
"""
"""Form for editing user information."""
class Meta:
"""Metaclass options."""
model = User
fields = [
'first_name',
@ -161,8 +157,7 @@ class EditUserForm(HelperForm):
class SetPasswordForm(HelperForm):
""" Form for setting user password
"""
"""Form for setting user password."""
enter_password = forms.CharField(max_length=100,
min_length=8,
@ -181,6 +176,8 @@ class SetPasswordForm(HelperForm):
help_text=_('Confirm new password'))
class Meta:
"""Metaclass options."""
model = User
fields = [
'enter_password',
@ -189,18 +186,21 @@ class SetPasswordForm(HelperForm):
class SettingCategorySelectForm(forms.ModelForm):
""" Form for setting category settings """
"""Form for setting category settings."""
category = forms.ModelChoiceField(queryset=PartCategory.objects.all())
class Meta:
"""Metaclass options."""
model = PartCategory
fields = [
'category'
]
def __init__(self, *args, **kwargs):
super(SettingCategorySelectForm, self).__init__(*args, **kwargs)
"""Setup form layout."""
super().__init__(*args, **kwargs)
self.helper = FormHelper()
# Form rendering
@ -220,10 +220,10 @@ class SettingCategorySelectForm(forms.ModelForm):
# override allauth
class CustomSignupForm(SignupForm):
"""
Override to use dynamic settings
"""
"""Override to use dynamic settings."""
def __init__(self, *args, **kwargs):
"""Check settings to influence which fields are needed."""
kwargs['email_required'] = InvenTreeSetting.get_setting('LOGIN_MAIL_REQUIRED')
super().__init__(*args, **kwargs)
@ -248,6 +248,7 @@ class CustomSignupForm(SignupForm):
set_form_field_order(self, ["username", "email", "email2", "password1", "password2", ])
def clean(self):
"""Make sure the supllied emails match if enabled in settings."""
cleaned_data = super().clean()
# check for two mail fields
@ -261,15 +262,16 @@ class CustomSignupForm(SignupForm):
class RegistratonMixin:
"""
Mixin to check if registration should be enabled
"""
"""Mixin to check if registration should be enabled."""
def is_open_for_signup(self, request, *args, **kwargs):
"""Check if signup is enabled in settings."""
if settings.EMAIL_HOST and InvenTreeSetting.get_setting('LOGIN_ENABLE_REG', True):
return super().is_open_for_signup(request, *args, **kwargs)
return False
def save_user(self, request, user, form, commit=True):
"""Check if a default group is set in settings."""
user = super().save_user(request, user, form)
start_group = InvenTreeSetting.get_setting('SIGNUP_GROUP')
if start_group:
@ -283,21 +285,19 @@ class RegistratonMixin:
class CustomAccountAdapter(RegistratonMixin, OTPAdapter, DefaultAccountAdapter):
"""
Override of adapter to use dynamic settings
"""
"""Override of adapter to use dynamic settings."""
def send_mail(self, template_prefix, email, context):
"""only send mail if backend configured"""
"""Only send mail if backend configured."""
if settings.EMAIL_HOST:
return super().send_mail(template_prefix, email, context)
return False
class CustomSocialAccountAdapter(RegistratonMixin, DefaultSocialAccountAdapter):
"""
Override of adapter to use dynamic settings
"""
"""Override of adapter to use dynamic settings."""
def is_auto_signup_allowed(self, request, sociallogin):
"""Check if auto signup is enabled in settings."""
if InvenTreeSetting.get_setting('LOGIN_SIGNUP_SSO_AUTO', True):
return super().is_auto_signup_allowed(request, sociallogin)
return False
@ -308,6 +308,7 @@ class CustomSocialAccountAdapter(RegistratonMixin, DefaultSocialAccountAdapter):
return user_has_valid_totp_device(user)
def login(self, request, user):
"""Ensure user is send to 2FA before login if enabled."""
# Require two-factor authentication if it has been configured.
if self.has_2fa_enabled(user):
# Cast to string for the case when this is not a JSON serializable

View File

@ -1,6 +1,4 @@
"""
Provides helper functions used throughout the InvenTree project
"""
"""Provides helper functions used throughout the InvenTree project."""
import io
import json
@ -27,21 +25,15 @@ from .settings import MEDIA_URL, STATIC_URL
def getSetting(key, backup_value=None):
"""
Shortcut for reading a setting value from the database
"""
"""Shortcut for reading a setting value from the database."""
return InvenTreeSetting.get_setting(key, backup_value=backup_value)
def generateTestKey(test_name):
"""
Generate a test 'key' for a given test name.
This must not have illegal chars as it will be used for dict lookup in a template.
"""Generate a test 'key' for a given test name. This must not have illegal chars as it will be used for dict lookup in a template.
Tests must be named such that they will have unique keys.
"""
key = test_name.strip().lower()
key = key.replace(" ", "")
@ -52,33 +44,23 @@ def generateTestKey(test_name):
def getMediaUrl(filename):
"""
Return the qualified access path for the given file,
under the media directory.
"""
"""Return the qualified access path for the given file, under the media directory."""
return os.path.join(MEDIA_URL, str(filename))
def getStaticUrl(filename):
"""
Return the qualified access path for the given file,
under the static media directory.
"""
"""Return the qualified access path for the given file, under the static media directory."""
return os.path.join(STATIC_URL, str(filename))
def construct_absolute_url(*arg):
"""
Construct (or attempt to construct) an absolute URL from a relative URL.
"""Construct (or attempt to construct) an absolute URL from a relative URL.
This is useful when (for example) sending an email to a user with a link
to something in the InvenTree web framework.
This requires the BASE_URL configuration option to be set!
"""
base = str(InvenTreeSetting.get_setting('INVENTREE_BASE_URL'))
url = '/'.join(arg)
@ -99,23 +81,17 @@ def construct_absolute_url(*arg):
def getBlankImage():
"""
Return the qualified path for the 'blank image' placeholder.
"""
"""Return the qualified path for the 'blank image' placeholder."""
return getStaticUrl("img/blank_image.png")
def getBlankThumbnail():
"""
Return the qualified path for the 'blank image' thumbnail placeholder.
"""
"""Return the qualified path for the 'blank image' thumbnail placeholder."""
return getStaticUrl("img/blank_image.thumbnail.png")
def TestIfImage(img):
""" Test if an image file is indeed an image """
"""Test if an image file is indeed an image."""
try:
Image.open(img).verify()
return True
@ -124,7 +100,7 @@ def TestIfImage(img):
def TestIfImageURL(url):
""" Test if an image URL (or filename) looks like a valid image format.
"""Test if an image URL (or filename) looks like a valid image format.
Simply tests the extension against a set of allowed values
"""
@ -137,7 +113,7 @@ def TestIfImageURL(url):
def str2bool(text, test=True):
""" Test if a string 'looks' like a boolean value.
"""Test if a string 'looks' like a boolean value.
Args:
text: Input text
@ -153,10 +129,7 @@ def str2bool(text, test=True):
def is_bool(text):
"""
Determine if a string value 'looks' like a boolean.
"""
"""Determine if a string value 'looks' like a boolean."""
if str2bool(text, True):
return True
elif str2bool(text, False):
@ -166,9 +139,7 @@ def is_bool(text):
def isNull(text):
"""
Test if a string 'looks' like a null value.
This is useful for querying the API against a null key.
"""Test if a string 'looks' like a null value. This is useful for querying the API against a null key.
Args:
text: Input text
@ -176,15 +147,11 @@ def isNull(text):
Returns:
True if the text looks like a null value
"""
return str(text).strip().lower() in ['top', 'null', 'none', 'empty', 'false', '-1', '']
def normalize(d):
"""
Normalize a decimal number, and remove exponential formatting.
"""
"""Normalize a decimal number, and remove exponential formatting."""
if type(d) is not Decimal:
d = Decimal(d)
@ -195,8 +162,7 @@ def normalize(d):
def increment(n):
"""
Attempt to increment an integer (or a string that looks like an integer!)
"""Attempt to increment an integer (or a string that looks like an integer).
e.g.
@ -204,9 +170,7 @@ def increment(n):
2 -> 3
AB01 -> AB02
QQQ -> QQQ
"""
value = str(n).strip()
# Ignore empty strings
@ -248,10 +212,7 @@ def increment(n):
def decimal2string(d):
"""
Format a Decimal number as a string,
stripping out any trailing zeroes or decimal points.
Essentially make it look like a whole number if it is one.
"""Format a Decimal number as a string, stripping out any trailing zeroes or decimal points. Essentially make it look like a whole number if it is one.
Args:
d: A python Decimal object
@ -259,7 +220,6 @@ def decimal2string(d):
Returns:
A string representation of the input number
"""
if type(d) is Decimal:
d = normalize(d)
@ -280,8 +240,7 @@ def decimal2string(d):
def decimal2money(d, currency=None):
"""
Format a Decimal number as Money
"""Format a Decimal number as Money.
Args:
d: A python Decimal object
@ -296,7 +255,7 @@ def decimal2money(d, currency=None):
def WrapWithQuotes(text, quote='"'):
""" Wrap the supplied text with quotes
"""Wrap the supplied text with quotes.
Args:
text: Input text to wrap
@ -305,7 +264,6 @@ def WrapWithQuotes(text, quote='"'):
Returns:
Supplied text wrapped in quote char
"""
if not text.startswith(quote):
text = quote + text
@ -316,7 +274,7 @@ def WrapWithQuotes(text, quote='"'):
def MakeBarcode(object_name, object_pk, object_data=None, **kwargs):
""" Generate a string for a barcode. Adds some global InvenTree parameters.
"""Generate a string for a barcode. Adds some global InvenTree parameters.
Args:
object_type: string describing the object type e.g. 'StockItem'
@ -363,8 +321,7 @@ def MakeBarcode(object_name, object_pk, object_data=None, **kwargs):
def GetExportFormats():
""" Return a list of allowable file formats for exporting data """
"""Return a list of allowable file formats for exporting data."""
return [
'csv',
'tsv',
@ -375,9 +332,8 @@ def GetExportFormats():
]
def DownloadFile(data, filename, content_type='application/text', inline=False):
"""
Create a dynamic file for the user to download.
def DownloadFile(data, filename, content_type='application/text', inline=False) -> StreamingHttpResponse:
"""Create a dynamic file for the user to download.
Args:
data: Raw file data (string or bytes)
@ -388,7 +344,6 @@ def DownloadFile(data, filename, content_type='application/text', inline=False):
Return:
A StreamingHttpResponse object wrapping the supplied data
"""
filename = WrapWithQuotes(filename)
if type(data) == str:
@ -407,8 +362,7 @@ def DownloadFile(data, filename, content_type='application/text', inline=False):
def extract_serial_numbers(serials, expected_quantity, next_number: int):
"""
Attempt to extract serial numbers from an input string:
"""Attempt to extract serial numbers from an input string.
Requirements:
- Serial numbers can be either strings, or integers
@ -423,7 +377,6 @@ def extract_serial_numbers(serials, expected_quantity, next_number: int):
expected_quantity: The number of (unique) serial numbers we expect
next_number(int): the next possible serial number
"""
serials = serials.strip()
# fill in the next serial number into the serial
@ -543,8 +496,7 @@ def extract_serial_numbers(serials, expected_quantity, next_number: int):
def validateFilterString(value, model=None):
"""
Validate that a provided filter string looks like a list of comma-separated key=value pairs
"""Validate that a provided filter string looks like a list of comma-separated key=value pairs.
These should nominally match to a valid database filter based on the model being filtered.
@ -559,7 +511,6 @@ def validateFilterString(value, model=None):
Returns a map of key:value pairs
"""
# Empty results map
results = {}
@ -605,28 +556,19 @@ def validateFilterString(value, model=None):
def addUserPermission(user, permission):
"""
Shortcut function for adding a certain permission to a user.
"""
"""Shortcut function for adding a certain permission to a user."""
perm = Permission.objects.get(codename=permission)
user.user_permissions.add(perm)
def addUserPermissions(user, permissions):
"""
Shortcut function for adding multiple permissions to a user.
"""
"""Shortcut function for adding multiple permissions to a user."""
for permission in permissions:
addUserPermission(user, permission)
def getMigrationFileNames(app):
"""
Return a list of all migration filenames for provided app
"""
"""Return a list of all migration filenames for provided app."""
local_dir = os.path.dirname(os.path.abspath(__file__))
migration_dir = os.path.join(local_dir, '..', app, 'migrations')
@ -646,10 +588,7 @@ def getMigrationFileNames(app):
def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True):
"""
Return the filename associated with the oldest migration
"""
"""Return the filename associated with the oldest migration."""
oldest_num = -1
oldest_file = None
@ -671,10 +610,7 @@ def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True):
def getNewestMigrationFile(app, exclude_extension=True):
"""
Return the filename associated with the newest migration
"""
"""Return the filename associated with the newest migration."""
newest_file = None
newest_num = -1
@ -692,8 +628,7 @@ def getNewestMigrationFile(app, exclude_extension=True):
def clean_decimal(number):
""" Clean-up decimal value """
"""Clean-up decimal value."""
# Check if empty
if number is None or number == '' or number == 0:
return Decimal(0)
@ -729,7 +664,7 @@ def clean_decimal(number):
def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = 'object_id'):
"""lookup method for the GenericForeignKey fields
"""Lookup method for the GenericForeignKey fields.
Attributes:
- obj: object that will be resolved
@ -769,9 +704,7 @@ def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = '
def inheritors(cls):
"""
Return all classes that are subclasses from the supplied cls
"""
"""Return all classes that are subclasses from the supplied cls."""
subcls = set()
work = [cls]
while work:
@ -784,4 +717,5 @@ def inheritors(cls):
class InvenTreeTestCase(UserMixin, TestCase):
"""Testcase with user setup buildin."""
pass

View File

@ -1,6 +1,4 @@
"""
Custom management command to cleanup old settings that are not defined anymore
"""
"""Custom management command to cleanup old settings that are not defined anymore."""
import logging
@ -10,12 +8,10 @@ logger = logging.getLogger('inventree')
class Command(BaseCommand):
"""
Cleanup old (undefined) settings in the database
"""
"""Cleanup old (undefined) settings in the database."""
def handle(self, *args, **kwargs):
"""Cleanup old (undefined) settings in the database."""
logger.info("Collecting settings")
from common.models import InvenTreeSetting, InvenTreeUserSetting

View File

@ -1,6 +1,4 @@
"""
Custom management command to prerender files
"""
"""Custom management command to prerender files."""
import os
@ -13,7 +11,7 @@ from django.utils.translation import override as lang_over
def render_file(file_name, source, target, locales, ctx):
""" renders a file into all provided locales """
"""Renders a file into all provided locales."""
for locale in locales:
target_file = os.path.join(target, locale + '.' + file_name)
with open(target_file, 'w') as localised_file:
@ -23,11 +21,10 @@ def render_file(file_name, source, target, locales, ctx):
class Command(BaseCommand):
"""
django command to prerender files
"""
"""Django command to prerender files."""
def handle(self, *args, **kwargs):
"""Django command to prerender files."""
# static directories
LC_DIR = settings.LOCALE_PATHS[0]
SOURCE_DIR = settings.STATICFILES_I18_SRC

View File

@ -1,5 +1,4 @@
"""
Custom management command to rebuild all MPTT models
"""Custom management command to rebuild all MPTT models.
- This is crucial after importing any fixtures, etc
"""
@ -8,12 +7,10 @@ from django.core.management.base import BaseCommand
class Command(BaseCommand):
"""
Rebuild all database models which leverage the MPTT structure.
"""
"""Rebuild all database models which leverage the MPTT structure."""
def handle(self, *args, **kwargs):
"""Rebuild all database models which leverage the MPTT structure."""
# Part model
try:
print("Rebuilding Part objects")

View File

@ -1,5 +1,4 @@
"""
Custom management command to rebuild thumbnail images
"""Custom management command to rebuild thumbnail images.
- May be required after importing a new dataset, for example
"""
@ -20,15 +19,10 @@ logger = logging.getLogger('inventree')
class Command(BaseCommand):
"""
Rebuild all thumbnail images
"""
"""Rebuild all thumbnail images."""
def rebuild_thumbnail(self, model):
"""
Rebuild the thumbnail specified by the "image" field of the provided model
"""
"""Rebuild the thumbnail specified by the "image" field of the provided model."""
if not model.image:
return
@ -47,7 +41,7 @@ class Command(BaseCommand):
logger.warning(f"Warning: Image file '{img}' is not a valid image")
def handle(self, *args, **kwargs):
"""Rebuild all thumbnail images."""
logger.info("Rebuilding Part thumbnails")
for part in Part.objects.exclude(image=None):

View File

@ -1,21 +1,18 @@
"""
Custom management command to remove MFA for a user
"""
"""Custom management command to remove MFA for a user."""
from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand
class Command(BaseCommand):
"""
Remove MFA for a user
"""
"""Remove MFA for a user."""
def add_arguments(self, parser):
"""Add the arguments."""
parser.add_argument('mail', type=str)
def handle(self, *args, **kwargs):
"""Remove MFA for the supplied user (by mail)."""
# general settings
mail = kwargs.get('mail')
if not mail:

View File

@ -1,6 +1,4 @@
"""
Custom management command, wait for the database to be ready!
"""
"""Custom management command, wait for the database to be ready!"""
import time
@ -10,12 +8,10 @@ from django.db.utils import ImproperlyConfigured, OperationalError
class Command(BaseCommand):
"""
django command to pause execution until the database is ready
"""
"""Django command to pause execution until the database is ready."""
def handle(self, *args, **kwargs):
"""Wait till the database is ready."""
self.stdout.write("Waiting for database...")
connected = False

View File

@ -1,3 +1,5 @@
"""Custom metadata for DRF."""
import logging
from rest_framework import serializers
@ -12,8 +14,7 @@ logger = logging.getLogger('inventree')
class InvenTreeMetadata(SimpleMetadata):
"""
Custom metadata class for the DRF API.
"""Custom metadata class for the DRF API.
This custom metadata class imits the available "actions",
based on the user's role permissions.
@ -23,11 +24,10 @@ class InvenTreeMetadata(SimpleMetadata):
Additionally, we include some extra information about database models,
so we can perform lookup for ForeignKey related fields.
"""
def determine_metadata(self, request, view):
"""Overwrite the metadata to adapt to hte request user."""
self.request = request
self.view = view
@ -106,11 +106,7 @@ class InvenTreeMetadata(SimpleMetadata):
return metadata
def get_serializer_info(self, serializer):
"""
Override get_serializer_info so that we can add 'default' values
to any fields whose Meta.model specifies a default value
"""
"""Override get_serializer_info so that we can add 'default' values to any fields whose Meta.model specifies a default value."""
self.serializer = serializer
serializer_info = super().get_serializer_info(serializer)
@ -208,10 +204,7 @@ class InvenTreeMetadata(SimpleMetadata):
pass
if instance is not None:
"""
If there is an instance associated with this API View,
introspect that instance to find any specific API info.
"""
"""If there is an instance associated with this API View, introspect that instance to find any specific API info."""
if hasattr(instance, 'api_instance_filters'):
@ -233,13 +226,10 @@ class InvenTreeMetadata(SimpleMetadata):
return serializer_info
def get_field_info(self, field):
"""
Given an instance of a serializer field, return a dictionary
of metadata about it.
"""Given an instance of a serializer field, return a dictionary of metadata about it.
We take the regular DRF metadata and add our own unique flavor
"""
# Run super method first
field_info = super().get_field_info(field)

View File

@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Middleware for InvenTree."""
import logging
@ -19,10 +19,17 @@ logger = logging.getLogger("inventree")
class AuthRequiredMiddleware(object):
"""Check for user to be authenticated."""
def __init__(self, get_response):
"""Save response object."""
self.get_response = get_response
def __call__(self, request):
"""Check if user needs to be authenticated and is.
Redirects to login if not authenticated.
"""
# Code to be executed for each request before
# the view (and later middleware) are called.
@ -35,6 +42,7 @@ class AuthRequiredMiddleware(object):
if not request.user.is_authenticated:
"""
Normally, a web-based session would use csrftoken based authentication.
However when running an external application (e.g. the InvenTree app or Python library),
we must validate the user token manually.
"""
@ -105,9 +113,9 @@ url_matcher = re_path('', include(frontendpatterns))
class Check2FAMiddleware(BaseRequire2FAMiddleware):
"""check if user is required to have MFA enabled"""
"""Check if user is required to have MFA enabled."""
def require_2fa(self, request):
# Superusers are require to have 2FA.
"""Use setting to check if MFA should be enforced for frontend page."""
try:
if url_matcher.resolve(request.path[1:]):
return InvenTreeSetting.get_setting('LOGIN_ENFORCE_MFA')
@ -117,8 +125,9 @@ class Check2FAMiddleware(BaseRequire2FAMiddleware):
class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware):
"""This function ensures only frontend code triggers the MFA auth cycle"""
"""This function ensures only frontend code triggers the MFA auth cycle."""
def process_request(self, request):
"""Check if requested url is forntend and enforce MFA check."""
try:
if not url_matcher.resolve(request.path[1:]):
super().process_request(request)
@ -127,12 +136,11 @@ class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware):
class InvenTreeRemoteUserMiddleware(PersistentRemoteUserMiddleware):
"""
Middleware to check if HTTP-header based auth is enabled and to set it up
"""
"""Middleware to check if HTTP-header based auth is enabled and to set it up."""
header = settings.REMOTE_LOGIN_HEADER
def process_request(self, request):
"""Check if proxy login is enabled."""
if not settings.REMOTE_LOGIN:
return

View File

@ -1,6 +1,4 @@
"""
Generic models which provide extra functionality over base Django model types.
"""
"""Generic models which provide extra functionality over base Django model types."""
import logging
import os
@ -25,9 +23,7 @@ logger = logging.getLogger('inventree')
def rename_attachment(instance, filename):
"""
Function for renaming an attachment file.
The subdirectory for the uploaded file is determined by the implementing class.
"""Function for renaming an attachment file. The subdirectory for the uploaded file is determined by the implementing class.
Args:
instance: Instance of a PartAttachment object
@ -36,14 +32,12 @@ def rename_attachment(instance, filename):
Returns:
path to store file, format: '<subdir>/<id>/filename'
"""
# Construct a path to store a file attachment for a given model type
return os.path.join(instance.getSubdir(), filename)
class DataImportMixin(object):
"""
Model mixin class which provides support for 'data import' functionality.
"""Model mixin class which provides support for 'data import' functionality.
Models which implement this mixin should provide information on the fields available for import
"""
@ -53,12 +47,10 @@ class DataImportMixin(object):
@classmethod
def get_import_fields(cls):
"""
Return all available import fields
"""Return all available import fields.
Where information on a particular field is not explicitly provided,
introspect the base model to (attempt to) find that information.
"""
fields = cls.IMPORT_FIELDS
@ -85,7 +77,7 @@ class DataImportMixin(object):
@classmethod
def get_required_import_fields(cls):
""" Return all *required* import fields """
"""Return all *required* import fields."""
fields = {}
for name, field in cls.get_import_fields().items():
@ -98,8 +90,7 @@ class DataImportMixin(object):
class ReferenceIndexingMixin(models.Model):
"""
A mixin for keeping track of numerical copies of the "reference" field.
"""A mixin for keeping track of numerical copies of the "reference" field.
!!DANGER!! always add `ReferenceIndexingSerializerMixin`to all your models serializers to
ensure the reference field is not too big
@ -118,18 +109,20 @@ class ReferenceIndexingMixin(models.Model):
"""
class Meta:
"""Metaclass options. Abstract ensures no database table is created."""
abstract = True
def rebuild_reference_field(self):
"""Extract integer out of reference for sorting."""
reference = getattr(self, 'reference', '')
self.reference_int = extract_int(reference)
reference_int = models.BigIntegerField(default=0)
def extract_int(reference, clip=0x7fffffff):
"""Extract integer out of reference."""
# Default value if we cannot convert to an integer
ref_int = 0
@ -155,7 +148,7 @@ def extract_int(reference, clip=0x7fffffff):
class InvenTreeAttachment(models.Model):
""" Provides an abstracted class for managing file attachments.
"""Provides an abstracted class for managing file attachments.
An attachment can be either an uploaded file, or an external URL
@ -167,14 +160,14 @@ class InvenTreeAttachment(models.Model):
"""
def getSubdir(self):
"""
Return the subdirectory under which attachments should be stored.
"""Return the subdirectory under which attachments should be stored.
Note: Re-implement this for each subclass of InvenTreeAttachment
"""
return "attachments"
def save(self, *args, **kwargs):
"""Provide better validation error."""
# Either 'attachment' or 'link' must be specified!
if not self.attachment and not self.link:
raise ValidationError({
@ -185,6 +178,7 @@ class InvenTreeAttachment(models.Model):
super().save(*args, **kwargs)
def __str__(self):
"""Human name for attachment."""
if self.attachment is not None:
return os.path.basename(self.attachment.name)
else:
@ -215,6 +209,7 @@ class InvenTreeAttachment(models.Model):
@property
def basename(self):
"""Base name/path for attachment."""
if self.attachment:
return os.path.basename(self.attachment.name)
else:
@ -222,15 +217,13 @@ class InvenTreeAttachment(models.Model):
@basename.setter
def basename(self, fn):
"""
Function to rename the attachment file.
"""Function to rename the attachment file.
- Filename cannot be empty
- Filename cannot contain illegal characters
- Filename must specify an extension
- Filename cannot match an existing file
"""
fn = fn.strip()
if len(fn) == 0:
@ -287,11 +280,13 @@ class InvenTreeAttachment(models.Model):
raise ValidationError(_("Error renaming file"))
class Meta:
"""Metaclass options. Abstract ensures no database table is created."""
abstract = True
class InvenTreeTree(MPTTModel):
""" Provides an abstracted self-referencing tree model for data categories.
"""Provides an abstracted self-referencing tree model for data categories.
- Each Category has one parent Category, which can be blank (for a top-level Category).
- Each Category can have zero-or-more child Categor(y/ies)
@ -303,10 +298,7 @@ class InvenTreeTree(MPTTModel):
"""
def api_instance_filters(self):
"""
Instance filters for InvenTreeTree models
"""
"""Instance filters for InvenTreeTree models."""
return {
'parent': {
'exclude_tree': self.pk,
@ -314,7 +306,7 @@ class InvenTreeTree(MPTTModel):
}
def save(self, *args, **kwargs):
"""Provide better error for invalid moves."""
try:
super().save(*args, **kwargs)
except InvalidMove:
@ -323,12 +315,15 @@ class InvenTreeTree(MPTTModel):
})
class Meta:
"""Metaclass defines extra model properties."""
abstract = True
# Names must be unique at any given level in the tree
unique_together = ('name', 'parent')
class MPTTMeta:
"""Set insert order."""
order_insertion_by = ['name']
name = models.CharField(
@ -356,7 +351,7 @@ class InvenTreeTree(MPTTModel):
@property
def item_count(self):
""" Return the number of items which exist *under* this node in the tree.
"""Return the number of items which exist *under* this node in the tree.
Here an 'item' is considered to be the 'leaf' at the end of each branch,
and the exact nature here will depend on the class implementation.
@ -366,30 +361,29 @@ class InvenTreeTree(MPTTModel):
return 0
def getUniqueParents(self):
""" Return a flat set of all parent items that exist above this node.
"""Return a flat set of all parent items that exist above this node.
If any parents are repeated (which would be very bad!), the process is halted
"""
return self.get_ancestors()
def getUniqueChildren(self, include_self=True):
""" Return a flat set of all child items that exist under this node.
"""Return a flat set of all child items that exist under this node.
If any child items are repeated, the repetitions are omitted.
"""
return self.get_descendants(include_self=include_self)
@property
def has_children(self):
""" True if there are any children under this item """
"""True if there are any children under this item."""
return self.getUniqueChildren(include_self=False).count() > 0
def getAcceptableParents(self):
""" Returns a list of acceptable parent items within this model
Acceptable parents are ones which are not underneath this item.
"""Returns a list of acceptable parent items within this model Acceptable parents are ones which are not underneath this item.
Setting the parent of an item to its own child results in recursion.
"""
contents = ContentType.objects.get_for_model(type(self))
available = contents.get_all_objects_for_this_type()
@ -407,17 +401,16 @@ class InvenTreeTree(MPTTModel):
@property
def parentpath(self):
""" Get the parent path of this category
"""Get the parent path of this category.
Returns:
List of category names from the top level to the parent of this category
"""
return [a for a in self.get_ancestors()]
@property
def path(self):
""" Get the complete part of this category.
"""Get the complete part of this category.
e.g. ["Top", "Second", "Third", "This"]
@ -428,25 +421,23 @@ class InvenTreeTree(MPTTModel):
@property
def pathstring(self):
""" Get a string representation for the path of this item.
"""Get a string representation for the path of this item.
e.g. "Top/Second/Third/This"
"""
return '/'.join([item.name for item in self.path])
def __str__(self):
""" String representation of a category is the full path to that category """
"""String representation of a category is the full path to that category."""
return "{path} - {desc}".format(path=self.pathstring, desc=self.description)
@receiver(pre_delete, sender=InvenTreeTree, dispatch_uid='tree_pre_delete_log')
def before_delete_tree_item(sender, instance, using, **kwargs):
""" Receives pre_delete signal from InvenTreeTree object.
"""Receives pre_delete signal from InvenTreeTree object.
Before an item is deleted, update each child object to point to the parent of the object being deleted.
"""
# Update each tree item below this one
for child in instance.children.all():
child.parent = instance.parent

View File

@ -1,12 +1,12 @@
"""Permission set for InvenTree."""
from rest_framework import permissions
import users.models
class RolePermission(permissions.BasePermission):
"""
Role mixin for API endpoints, allowing us to specify the user "role"
which is required for certain operations.
"""Role mixin for API endpoints, allowing us to specify the user "role" which is required for certain operations.
Each endpoint can have one or more of the following actions:
- GET
@ -25,14 +25,10 @@ class RolePermission(permissions.BasePermission):
to perform the specified action.
For example, a DELETE action will be rejected unless the user has the "part.remove" permission
"""
def has_permission(self, request, view):
"""
Determine if the current user has the specified permissions
"""
"""Determine if the current user has the specified permissions."""
user = request.user
# Superuser can do it all

View File

@ -1,31 +1,24 @@
"""Functions to check if certain parts of InvenTree are ready."""
import sys
def isInTestMode():
"""
Returns True if the database is in testing mode
"""
"""Returns True if the database is in testing mode."""
return 'test' in sys.argv
def isImportingData():
"""
Returns True if the database is currently importing data,
e.g. 'loaddata' command is performed
"""
"""Returns True if the database is currently importing data, e.g. 'loaddata' command is performed."""
return 'loaddata' in sys.argv
def canAppAccessDatabase(allow_test=False):
"""
Returns True if the apps.py file can access database records.
"""Returns True if the apps.py file can access database records.
There are some circumstances where we don't want the ready function in apps.py
to touch the database
"""
# If any of the following management commands are being executed,
# prevent custom "on load" code from running!
excluded_commands = [

View File

@ -1,6 +1,4 @@
"""
Serializers used in various InvenTree apps
"""
"""Serializers used in various InvenTree apps."""
import os
from collections import OrderedDict
@ -26,25 +24,20 @@ from .models import extract_int
class InvenTreeMoneySerializer(MoneyField):
"""
Custom serializer for 'MoneyField',
which ensures that passed values are numerically valid
"""Custom serializer for 'MoneyField', which ensures that passed values are numerically valid.
Ref: https://github.com/django-money/django-money/blob/master/djmoney/contrib/django_rest_framework/fields.py
"""
def __init__(self, *args, **kwargs):
"""Overrite default values."""
kwargs["max_digits"] = kwargs.get("max_digits", 19)
kwargs["decimal_places"] = kwargs.get("decimal_places", 4)
super().__init__(*args, **kwargs)
def get_value(self, data):
"""
Test that the returned amount is a valid Decimal
"""
"""Test that the returned amount is a valid Decimal."""
amount = super(DecimalField, self).get_value(data)
# Convert an empty string to None
@ -68,17 +61,21 @@ class InvenTreeMoneySerializer(MoneyField):
class UserSerializer(serializers.ModelSerializer):
""" Serializer for User - provides all fields """
"""Serializer for User - provides all fields."""
class Meta:
"""Metaclass options."""
model = User
fields = 'all'
class UserSerializerBrief(serializers.ModelSerializer):
""" Serializer for User - provides limited information """
"""Serializer for User - provides limited information."""
class Meta:
"""Metaclass options."""
model = User
fields = [
'pk',
@ -87,17 +84,10 @@ class UserSerializerBrief(serializers.ModelSerializer):
class InvenTreeModelSerializer(serializers.ModelSerializer):
"""
Inherits the standard Django ModelSerializer class,
but also ensures that the underlying model class data are checked on validation.
"""
"""Inherits the standard Django ModelSerializer class, but also ensures that the underlying model class data are checked on validation."""
def __init__(self, instance=None, data=empty, **kwargs):
"""
Custom __init__ routine to ensure that *default* values (as specified in the ORM)
are used by the DRF serializers, *if* the values are not provided by the user.
"""
"""Custom __init__ routine to ensure that *default* values (as specified in the ORM) are used by the DRF serializers, *if* the values are not provided by the user."""
# If instance is None, we are creating a new instance
if instance is None and data is not empty:
@ -118,6 +108,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
"""
Update the field IF (and ONLY IF):
- The field has a specified default value
- The field does not already have a value set
"""
@ -137,11 +128,10 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
super().__init__(instance, data, **kwargs)
def get_initial(self):
"""
Construct initial data for the serializer.
"""Construct initial data for the serializer.
Use the 'default' values specified by the django model definition
"""
initials = super().get_initial().copy()
# Are we creating a new instance?
@ -168,11 +158,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
return initials
def save(self, **kwargs):
"""
Catch any django ValidationError thrown at the moment save() is called,
and re-throw as a DRF ValidationError
"""
"""Catch any django ValidationError thrown at the moment `save` is called, and re-throw as a DRF ValidationError."""
try:
super().save(**kwargs)
except (ValidationError, DjangoValidationError) as exc:
@ -181,10 +167,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
return self.instance
def update(self, instance, validated_data):
"""
Catch any django ValidationError, and re-throw as a DRF ValidationError
"""
"""Catch any django ValidationError, and re-throw as a DRF ValidationError."""
try:
instance = super().update(instance, validated_data)
except (ValidationError, DjangoValidationError) as exc:
@ -193,12 +176,11 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
return instance
def run_validation(self, data=empty):
"""
Perform serializer validation.
"""Perform serializer validation.
In addition to running validators on the serializer fields,
this class ensures that the underlying model is also validated.
"""
# Run any native validation checks first (may raise a ValidationError)
data = super().run_validation(data)
@ -237,20 +219,17 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
class ReferenceIndexingSerializerMixin():
"""
This serializer mixin ensures the the reference is not to big / small
for the BigIntegerField
"""
"""This serializer mixin ensures the the reference is not to big / small for the BigIntegerField."""
def validate_reference(self, value):
"""Ensures the reference is not to big / small for the BigIntegerField."""
if extract_int(value) > models.BigIntegerField.MAX_BIGINT:
raise serializers.ValidationError('reference is to to big')
return value
class InvenTreeAttachmentSerializerField(serializers.FileField):
"""
Override the DRF native FileField serializer,
to remove the leading server path.
"""Override the DRF native FileField serializer, to remove the leading server path.
For example, the FileField might supply something like:
@ -269,7 +248,7 @@ class InvenTreeAttachmentSerializerField(serializers.FileField):
"""
def to_representation(self, value):
"""To json-serializable type."""
if not value:
return None
@ -277,8 +256,7 @@ class InvenTreeAttachmentSerializerField(serializers.FileField):
class InvenTreeAttachmentSerializer(InvenTreeModelSerializer):
"""
Special case of an InvenTreeModelSerializer, which handles an "attachment" model.
"""Special case of an InvenTreeModelSerializer, which handles an "attachment" model.
The only real addition here is that we support "renaming" of the attachment file.
"""
@ -298,13 +276,13 @@ class InvenTreeAttachmentSerializer(InvenTreeModelSerializer):
class InvenTreeImageSerializerField(serializers.ImageField):
"""
Custom image serializer.
"""Custom image serializer.
On upload, validate that the file is a valid image file
"""
def to_representation(self, value):
"""To json-serializable type."""
if not value:
return None
@ -312,15 +290,15 @@ class InvenTreeImageSerializerField(serializers.ImageField):
class InvenTreeDecimalField(serializers.FloatField):
"""
Custom serializer for decimal fields. Solves the following issues:
"""Custom serializer for decimal fields.
Solves the following issues:
- The normal DRF DecimalField renders values with trailing zeros
- Using a FloatField can result in rounding issues: https://code.djangoproject.com/ticket/30290
"""
def to_internal_value(self, data):
"""Convert to python type."""
# Convert the value to a string, and then a decimal
try:
return Decimal(str(data))
@ -329,8 +307,7 @@ class InvenTreeDecimalField(serializers.FloatField):
class DataFileUploadSerializer(serializers.Serializer):
"""
Generic serializer for uploading a data file, and extracting a dataset.
"""Generic serializer for uploading a data file, and extracting a dataset.
- Validates uploaded file
- Extracts column names
@ -341,6 +318,8 @@ class DataFileUploadSerializer(serializers.Serializer):
TARGET_MODEL = None
class Meta:
"""Metaclass options."""
fields = [
'data_file',
]
@ -353,10 +332,7 @@ class DataFileUploadSerializer(serializers.Serializer):
)
def validate_data_file(self, data_file):
"""
Perform validation checks on the uploaded data file.
"""
"""Perform validation checks on the uploaded data file."""
self.filename = data_file.name
name, ext = os.path.splitext(data_file.name)
@ -406,15 +382,13 @@ class DataFileUploadSerializer(serializers.Serializer):
return data_file
def match_column(self, column_name, field_names, exact=False):
"""
Attempt to match a column name (from the file) to a field (defined in the model)
"""Attempt to match a column name (from the file) to a field (defined in the model).
Order of matching is:
- Direct match
- Case insensitive match
- Fuzzy match
"""
if not column_name:
return None
@ -439,10 +413,7 @@ class DataFileUploadSerializer(serializers.Serializer):
return None
def extract_data(self):
"""
Returns dataset extracted from the file
"""
"""Returns dataset extracted from the file."""
# Provide a dict of available import fields for the model
model_fields = {}
@ -483,12 +454,12 @@ class DataFileUploadSerializer(serializers.Serializer):
}
def save(self):
"""Empty overwrite for save."""
...
class DataFileExtractSerializer(serializers.Serializer):
"""
Generic serializer for extracting data from an imported dataset.
"""Generic serializer for extracting data from an imported dataset.
- User provides an array of matched headers
- User provides an array of raw data rows
@ -498,6 +469,8 @@ class DataFileExtractSerializer(serializers.Serializer):
TARGET_MODEL = None
class Meta:
"""Metaclass options."""
fields = [
'columns',
'rows',
@ -520,7 +493,7 @@ class DataFileExtractSerializer(serializers.Serializer):
)
def validate(self, data):
"""Clean data."""
data = super().validate(data)
self.columns = data.get('columns', [])
@ -538,7 +511,7 @@ class DataFileExtractSerializer(serializers.Serializer):
@property
def data(self):
"""Returns current data."""
if self.TARGET_MODEL:
try:
model_fields = self.TARGET_MODEL.get_import_fields()
@ -548,9 +521,7 @@ class DataFileExtractSerializer(serializers.Serializer):
rows = []
for row in self.rows:
"""
Optionally pre-process each row, before sending back to the client
"""
"""Optionally pre-process each row, before sending back to the client."""
processed_row = self.process_row(self.row_to_dict(row))
@ -567,22 +538,17 @@ class DataFileExtractSerializer(serializers.Serializer):
}
def process_row(self, row):
"""
Process a 'row' of data, which is a mapped column:value dict
"""Process a 'row' of data, which is a mapped column:value dict.
Returns either a mapped column:value dict, or None.
If the function returns None, the column is ignored!
"""
# Default implementation simply returns the original row data
return row
def row_to_dict(self, row):
"""
Convert a "row" to a named data dict
"""
"""Convert a "row" to a named data dict."""
row_dict = {
'errors': {},
}
@ -598,10 +564,7 @@ class DataFileExtractSerializer(serializers.Serializer):
return row_dict
def validate_extracted_columns(self):
"""
Perform custom validation of header mapping.
"""
"""Perform custom validation of header mapping."""
if self.TARGET_MODEL:
try:
model_fields = self.TARGET_MODEL.get_import_fields()
@ -631,7 +594,5 @@ class DataFileExtractSerializer(serializers.Serializer):
cols_seen.add(col)
def save(self):
"""
No "save" action for this serializer
"""
...
"""No "save" action for this serializer."""
pass

View File

@ -1,5 +1,4 @@
"""
Django settings for InvenTree project.
"""Django settings for InvenTree project.
In practice the settings in this file should not be adjusted,
instead settings can be configured in the config.yaml file
@ -8,7 +7,6 @@ located in the top level project directory.
This allows implementation configuration to be hidden from source control,
as well as separate configuration parameters from the more complex
database setup in this file.
"""
import logging

View File

@ -1,6 +1,4 @@
"""
Provides system status functionality checks.
"""
"""Provides system status functionality checks."""
# -*- coding: utf-8 -*-
import logging
@ -19,10 +17,7 @@ logger = logging.getLogger("inventree")
def is_worker_running(**kwargs):
"""
Return True if the background worker process is oprational
"""
"""Return True if the background worker process is oprational."""
clusters = Stat.get_all()
if len(clusters) > 0:
@ -48,12 +43,10 @@ def is_worker_running(**kwargs):
def is_email_configured():
"""
Check if email backend is configured.
"""Check if email backend is configured.
NOTE: This does not check if the configuration is valid!
"""
configured = True
if InvenTree.ready.isInTestMode():
@ -87,12 +80,10 @@ def is_email_configured():
def check_system_health(**kwargs):
"""
Check that the InvenTree system is running OK.
"""Check that the InvenTree system is running OK.
Returns True if all system checks pass.
"""
result = True
if InvenTree.ready.isInTestMode():

View File

@ -1,9 +1,11 @@
"""Status codes for InvenTree."""
from django.utils.translation import gettext_lazy as _
class StatusCode:
"""
Base class for representing a set of StatusCodes.
"""Base class for representing a set of StatusCodes.
This is used to map a set of integer values to text.
"""
@ -11,10 +13,7 @@ class StatusCode:
@classmethod
def render(cls, key, large=False):
"""
Render the value as a HTML label.
"""
"""Render the value as a HTML label."""
# If the key cannot be found, pass it back
if key not in cls.options.keys():
return key
@ -31,10 +30,7 @@ class StatusCode:
@classmethod
def list(cls):
"""
Return the StatusCode options as a list of mapped key / value items
"""
"""Return the StatusCode options as a list of mapped key / value items."""
codes = []
for key in cls.options.keys():
@ -55,28 +51,32 @@ class StatusCode:
@classmethod
def text(cls, key):
"""Text for supplied status code."""
return cls.options.get(key, None)
@classmethod
def items(cls):
"""All status code items."""
return cls.options.items()
@classmethod
def keys(cls):
"""All status code keys."""
return cls.options.keys()
@classmethod
def labels(cls):
"""All status code labels."""
return cls.options.values()
@classmethod
def label(cls, value):
""" Return the status code label associated with the provided value """
"""Return the status code label associated with the provided value."""
return cls.options.get(value, value)
@classmethod
def value(cls, label):
""" Return the value associated with the provided label """
"""Return the value associated with the provided label."""
for k in cls.options.keys():
if cls.options[k].lower() == label.lower():
return k
@ -85,9 +85,7 @@ class StatusCode:
class PurchaseOrderStatus(StatusCode):
"""
Defines a set of status codes for a PurchaseOrder
"""
"""Defines a set of status codes for a PurchaseOrder."""
# Order status codes
PENDING = 10 # Order is pending (not yet placed)
@ -130,7 +128,7 @@ class PurchaseOrderStatus(StatusCode):
class SalesOrderStatus(StatusCode):
""" Defines a set of status codes for a SalesOrder """
"""Defines a set of status codes for a SalesOrder."""
PENDING = 10 # Order is pending
SHIPPED = 20 # Order has been shipped to customer
@ -166,6 +164,7 @@ class SalesOrderStatus(StatusCode):
class StockStatus(StatusCode):
"""Status codes for Stock."""
OK = 10 # Item is OK
ATTENTION = 50 # Item requires attention
@ -207,6 +206,7 @@ class StockStatus(StatusCode):
class StockHistoryCode(StatusCode):
"""Status codes for StockHistory."""
LEGACY = 0
@ -295,8 +295,8 @@ class StockHistoryCode(StatusCode):
class BuildStatus(StatusCode):
"""Build status codes."""
# Build status codes
PENDING = 10 # Build is pending / active
PRODUCTION = 20 # BuildOrder is in production
CANCELLED = 30 # Build was cancelled

View File

@ -1,3 +1,5 @@
"""Functions for tasks and a few general async tasks."""
import json
import logging
import re
@ -16,11 +18,10 @@ logger = logging.getLogger("inventree")
def schedule_task(taskname, **kwargs):
"""
Create a scheduled task.
"""Create a scheduled task.
If the task has already been scheduled, ignore!
"""
# If unspecified, repeat indefinitely
repeats = kwargs.pop('repeats', -1)
kwargs['repeats'] = repeats
@ -52,7 +53,7 @@ def schedule_task(taskname, **kwargs):
def raise_warning(msg):
"""Log and raise a warning"""
"""Log and raise a warning."""
logger.warning(msg)
# If testing is running raise a warning that can be asserted
@ -61,15 +62,11 @@ def raise_warning(msg):
def offload_task(taskname, *args, force_sync=False, **kwargs):
"""
Create an AsyncTask if workers are running.
This is different to a 'scheduled' task,
in that it only runs once!
"""Create an AsyncTask if workers are running. This is different to a 'scheduled' task, in that it only runs once!
If workers are not running or force_sync flag
is set then the task is ran synchronously.
"""
try:
import importlib
@ -129,14 +126,10 @@ def offload_task(taskname, *args, force_sync=False, **kwargs):
def heartbeat():
"""
Simple task which runs at 5 minute intervals,
so we can determine that the background worker
is actually running.
"""Simple task which runs at 5 minute intervals, so we can determine that the background worker is actually running.
(There is probably a less "hacky" way of achieving this)?
"""
try:
from django_q.models import Success
except AppRegistryNotReady: # pragma: no cover
@ -156,11 +149,7 @@ def heartbeat():
def delete_successful_tasks():
"""
Delete successful task logs
which are more than a month old.
"""
"""Delete successful task logs which are more than a month old."""
try:
from django_q.models import Success
except AppRegistryNotReady: # pragma: no cover
@ -179,10 +168,7 @@ def delete_successful_tasks():
def delete_old_error_logs():
"""
Delete old error logs from the server
"""
"""Delete old error logs from the server."""
try:
from error_report.models import Error
@ -204,10 +190,7 @@ def delete_old_error_logs():
def check_for_updates():
"""
Check if there is an update for InvenTree
"""
"""Check if there is an update for InvenTree."""
try:
import common.models
except AppRegistryNotReady: # pragma: no cover
@ -249,10 +232,7 @@ def check_for_updates():
def update_exchange_rates():
"""
Update currency exchange rates
"""
"""Update currency exchange rates."""
try:
from djmoney.contrib.exchange.models import ExchangeBackend, Rate
@ -293,11 +273,7 @@ def update_exchange_rates():
def send_email(subject, body, recipients, from_email=None, html_message=None):
"""
Send an email with the specified subject and body,
to the specified recipients list.
"""
"""Send an email with the specified subject and body, to the specified recipients list."""
if type(recipients) == str:
recipients = [recipients]

View File

@ -1,4 +1,4 @@
""" Low level tests for the InvenTree API """
"""Low level tests for the InvenTree API."""
from base64 import b64encode
@ -12,8 +12,7 @@ from users.models import RuleSet
class HTMLAPITests(InvenTreeTestCase):
"""
Test that we can access the REST API endpoints via the HTML interface.
"""Test that we can access the REST API endpoints via the HTML interface.
History: Discovered on 2021-06-28 a bug in InvenTreeModelSerializer,
which raised an AssertionError when using the HTML API interface,
@ -22,6 +21,7 @@ class HTMLAPITests(InvenTreeTestCase):
roles = 'all'
def test_part_api(self):
"""Test that part list is working."""
url = reverse('api-part-list')
# Check JSON response
@ -33,6 +33,7 @@ class HTMLAPITests(InvenTreeTestCase):
self.assertEqual(response.status_code, 200)
def test_build_api(self):
"""Test that build list is working."""
url = reverse('api-build-list')
# Check JSON response
@ -44,6 +45,7 @@ class HTMLAPITests(InvenTreeTestCase):
self.assertEqual(response.status_code, 200)
def test_stock_api(self):
"""Test that stock list is working."""
url = reverse('api-stock-list')
# Check JSON response
@ -55,6 +57,7 @@ class HTMLAPITests(InvenTreeTestCase):
self.assertEqual(response.status_code, 200)
def test_company_list(self):
"""Test that company list is working."""
url = reverse('api-company-list')
# Check JSON response
@ -66,14 +69,13 @@ class HTMLAPITests(InvenTreeTestCase):
self.assertEqual(response.status_code, 200)
def test_not_found(self):
"""Test that the NotFoundView is working"""
"""Test that the NotFoundView is working."""
response = self.client.get('/api/anc')
self.assertEqual(response.status_code, 404)
class APITests(InvenTreeAPITestCase):
""" Tests for the InvenTree API """
"""Tests for the InvenTree API."""
fixtures = [
'location',
@ -81,16 +83,11 @@ class APITests(InvenTreeAPITestCase):
'part',
'stock'
]
token = None
auto_login = False
def setUp(self):
super().setUp()
def basicAuth(self):
"""Helper function to use basic auth."""
# Use basic authentication
authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii")
@ -100,7 +97,7 @@ class APITests(InvenTreeAPITestCase):
self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth))
def tokenAuth(self):
"""Helper function to use token auth."""
self.basicAuth()
token_url = reverse('api-token')
response = self.client.get(token_url, format='json', data={})
@ -112,6 +109,7 @@ class APITests(InvenTreeAPITestCase):
self.token = token
def test_token_failure(self):
"""Test token resolve endpoint does not work without basic auth."""
# Test token endpoint without basic auth
url = reverse('api-token')
response = self.client.get(url, format='json')
@ -120,15 +118,12 @@ class APITests(InvenTreeAPITestCase):
self.assertIsNone(self.token)
def test_token_success(self):
"""Test token auth works."""
self.tokenAuth()
self.assertIsNotNone(self.token)
def test_info_view(self):
"""
Test that we can read the 'info-view' endpoint.
"""
"""Test that we can read the 'info-view' endpoint."""
url = reverse('api-inventree-info')
response = self.client.get(url, format='json')
@ -141,12 +136,10 @@ class APITests(InvenTreeAPITestCase):
self.assertEqual('InvenTree', data['server'])
def test_role_view(self):
"""
Test that we can access the 'roles' view for the logged in user.
"""Test that we can access the 'roles' view for the logged in user.
Also tests that it is *not* accessible if the client is not logged in.
"""
url = reverse('api-user-roles')
response = self.client.get(url, format='json')
@ -182,10 +175,7 @@ class APITests(InvenTreeAPITestCase):
self.assertNotIn('delete', roles[rule])
def test_with_superuser(self):
"""
Superuser should have *all* roles assigned
"""
"""Superuser should have *all* roles assigned."""
self.user.is_superuser = True
self.user.save()
@ -202,10 +192,7 @@ class APITests(InvenTreeAPITestCase):
self.assertIn(perm, roles[rule])
def test_with_roles(self):
"""
Assign some roles to the user
"""
"""Assign some roles to the user."""
self.basicAuth()
response = self.get(reverse('api-user-roles'))
@ -220,10 +207,7 @@ class APITests(InvenTreeAPITestCase):
self.assertIn('change', roles['build'])
def test_list_endpoint_actions(self):
"""
Tests for the OPTIONS method for API endpoints.
"""
"""Tests for the OPTIONS method for API endpoints."""
self.basicAuth()
# Without any 'part' permissions, we should not see any available actions
@ -252,10 +236,7 @@ class APITests(InvenTreeAPITestCase):
self.assertIn('GET', actions)
def test_detail_endpoint_actions(self):
"""
Tests for detail API endpoint actions
"""
"""Tests for detail API endpoint actions."""
self.basicAuth()
url = reverse('api-part-detail', kwargs={'pk': 1})

View File

@ -1,4 +1,4 @@
"""Tests for middleware functions"""
"""Tests for middleware functions."""
from django.urls import reverse
@ -6,16 +6,16 @@ from InvenTree.helpers import InvenTreeTestCase
class MiddlewareTests(InvenTreeTestCase):
"""Test for middleware functions"""
"""Test for middleware functions."""
def check_path(self, url, code=200, **kwargs):
"""Helper function to run a request."""
response = self.client.get(url, HTTP_ACCEPT='application/json', **kwargs)
self.assertEqual(response.status_code, code)
return response
def test_AuthRequiredMiddleware(self):
"""Test the auth middleware"""
"""Test the auth middleware."""
# test that /api/ routes go through
self.check_path(reverse('api-inventree-info'))
@ -40,7 +40,7 @@ class MiddlewareTests(InvenTreeTestCase):
self.check_path(reverse('settings.js'), 401)
def test_token_auth(self):
"""Test auth with token auth"""
"""Test auth with token auth."""
# get token
response = self.client.get(reverse('api-token'), format='json', data={})
token = response.data['token']

View File

@ -1,6 +1,4 @@
"""
Unit tests for task management
"""
"""Unit tests for task management."""
from datetime import timedelta
@ -18,19 +16,14 @@ threshold_low = threshold - timedelta(days=1)
class ScheduledTaskTests(TestCase):
"""
Unit tests for scheduled tasks
"""
"""Unit tests for scheduled tasks."""
def get_tasks(self, name):
"""Helper function to get a Schedule object."""
return Schedule.objects.filter(func=name)
def test_add_task(self):
"""
Ensure that duplicate tasks cannot be added.
"""
"""Ensure that duplicate tasks cannot be added."""
task = 'InvenTree.tasks.heartbeat'
self.assertEqual(self.get_tasks(task).count(), 0)
@ -53,16 +46,15 @@ class ScheduledTaskTests(TestCase):
def get_result():
"""Demo function for test_offloading"""
"""Demo function for test_offloading."""
return 'abc'
class InvenTreeTaskTests(TestCase):
"""Unit tests for tasks"""
"""Unit tests for tasks."""
def test_offloading(self):
"""Test task offloading"""
"""Test task offloading."""
# Run with function ref
InvenTree.tasks.offload_task(get_result)
@ -83,11 +75,11 @@ class InvenTreeTaskTests(TestCase):
InvenTree.tasks.offload_task('InvenTree.test_tasks.doesnotexsist')
def test_task_hearbeat(self):
"""Test the task heartbeat"""
"""Test the task heartbeat."""
InvenTree.tasks.offload_task(InvenTree.tasks.heartbeat)
def test_task_delete_successful_tasks(self):
"""Test the task delete_successful_tasks"""
"""Test the task delete_successful_tasks."""
from django_q.models import Success
Success.objects.create(name='abc', func='abc', stopped=threshold, started=threshold_low)
@ -96,8 +88,7 @@ class InvenTreeTaskTests(TestCase):
self.assertEqual(len(results), 0)
def test_task_delete_old_error_logs(self):
"""Test the task delete_old_error_logs"""
"""Test the task delete_old_error_logs."""
# Create error
error_obj = Error.objects.create()
error_obj.when = threshold_low
@ -115,7 +106,7 @@ class InvenTreeTaskTests(TestCase):
self.assertEqual(len(errors), 0)
def test_task_check_for_updates(self):
"""Test the task check_for_updates"""
"""Test the task check_for_updates."""
# Check that setting should be empty
self.assertEqual(InvenTreeSetting.get_setting('INVENTREE_LATEST_VERSION'), '')

View File

@ -1,6 +1,4 @@
"""
Validate that all URLs specified in template files are correct.
"""
"""Validate that all URLs specified in template files are correct."""
import os
import re
@ -11,6 +9,7 @@ from django.urls import reverse
class URLTest(TestCase):
"""Test all files for broken url tags."""
# Need fixture data in the database
fixtures = [
@ -35,11 +34,7 @@ class URLTest(TestCase):
]
def find_files(self, suffix):
"""
Search for all files in the template directories,
which can have URLs rendered
"""
"""Search for all files in the template directories, which can have URLs rendered."""
template_dirs = [
('build', 'templates'),
('common', 'templates'),
@ -71,10 +66,7 @@ class URLTest(TestCase):
return template_files
def find_urls(self, input_file):
"""
Search for all instances of {% url %} in supplied template file
"""
"""Search for all instances of {% url %} in supplied template file."""
urls = []
pattern = "{% url ['\"]([^'\"]+)['\"]([^%]*)%}"
@ -100,10 +92,7 @@ class URLTest(TestCase):
return urls
def reverse_url(self, url_pair):
"""
Perform lookup on the URL
"""
"""Perform lookup on the URL."""
url, pk = url_pair
# Ignore "renaming"
@ -125,24 +114,21 @@ class URLTest(TestCase):
reverse(url)
def check_file(self, f):
"""
Run URL checks for the provided file
"""
"""Run URL checks for the provided file."""
urls = self.find_urls(f)
for url in urls:
self.reverse_url(url)
def test_html_templates(self):
"""Test all HTML templates for broken url tags."""
template_files = self.find_files("*.html")
for f in template_files:
self.check_file(f)
def test_js_templates(self):
"""Test all JS templates for broken url tags."""
template_files = self.find_files("*.js")
for f in template_files:

View File

@ -1,6 +1,4 @@
"""
Unit tests for the main web views
"""
"""Unit tests for the main web views."""
import os
import re
@ -11,33 +9,26 @@ from InvenTree.helpers import InvenTreeTestCase
class ViewTests(InvenTreeTestCase):
""" Tests for various top-level views """
"""Tests for various top-level views."""
username = 'test_user'
password = 'test_pass'
def test_api_doc(self):
""" Test that the api-doc view works """
"""Test that the api-doc view works."""
api_url = os.path.join(reverse('index'), 'api-doc') + '/'
response = self.client.get(api_url)
self.assertEqual(response.status_code, 200)
def test_index_redirect(self):
"""
top-level URL should redirect to "index" page
"""
"""Top-level URL should redirect to "index" page."""
response = self.client.get("/")
self.assertEqual(response.status_code, 302)
def get_index_page(self):
"""
Retrieve the index page (used for subsequent unit tests)
"""
"""Retrieve the index page (used for subsequent unit tests)"""
response = self.client.get("/index/")
self.assertEqual(response.status_code, 200)
@ -45,10 +36,7 @@ class ViewTests(InvenTreeTestCase):
return str(response.content.decode())
def test_panels(self):
"""
Test that the required 'panels' are present
"""
"""Test that the required 'panels' are present."""
content = self.get_index_page()
self.assertIn("<div id='detail-panels'>", content)
@ -56,10 +44,7 @@ class ViewTests(InvenTreeTestCase):
# TODO: In future, run the javascript and ensure that the panels get created!
def test_js_load(self):
"""
Test that the required javascript files are loaded correctly
"""
"""Test that the required javascript files are loaded correctly."""
# Change this number as more javascript files are added to the index page
N_SCRIPT_FILES = 40

View File

@ -1,3 +1,5 @@
"""Test general functions and helpers."""
import json
import os
import time
@ -24,20 +26,17 @@ from .validators import validate_overage, validate_part_name
class ValidatorTest(TestCase):
""" Simple tests for custom field validators """
"""Simple tests for custom field validators."""
def test_part_name(self):
""" Test part name validator """
"""Test part name validator."""
validate_part_name('hello world')
with self.assertRaises(django_exceptions.ValidationError):
validate_part_name('This | name is not } valid')
def test_overage(self):
""" Test overage validator """
"""Test overage validator."""
validate_overage("100%")
validate_overage("10")
validate_overage("45.2 %")
@ -59,11 +58,10 @@ class ValidatorTest(TestCase):
class TestHelpers(TestCase):
""" Tests for InvenTree helper functions """
"""Tests for InvenTree helper functions."""
def test_image_url(self):
""" Test if a filename looks like an image """
"""Test if a filename looks like an image."""
for name in ['ape.png', 'bat.GiF', 'apple.WeBP', 'BiTMap.Bmp']:
self.assertTrue(helpers.TestIfImageURL(name))
@ -71,8 +69,7 @@ class TestHelpers(TestCase):
self.assertFalse(helpers.TestIfImageURL(name))
def test_str2bool(self):
""" Test string to boolean conversion """
"""Test string to boolean conversion."""
for s in ['yes', 'Y', 'ok', '1', 'OK', 'Ok', 'tRuE', 'oN']:
self.assertTrue(helpers.str2bool(s))
self.assertFalse(helpers.str2bool(s, test=False))
@ -86,7 +83,7 @@ class TestHelpers(TestCase):
self.assertFalse(helpers.str2bool(s, test=False))
def test_isnull(self):
"""Test isNull."""
for s in ['null', 'none', '', '-1', 'false']:
self.assertTrue(helpers.isNull(s))
@ -94,35 +91,35 @@ class TestHelpers(TestCase):
self.assertFalse(helpers.isNull(s))
def testStaticUrl(self):
"""Test static url helpers."""
self.assertEqual(helpers.getStaticUrl('test.jpg'), '/static/test.jpg')
self.assertEqual(helpers.getBlankImage(), '/static/img/blank_image.png')
self.assertEqual(helpers.getBlankThumbnail(), '/static/img/blank_image.thumbnail.png')
def testMediaUrl(self):
"""Test getMediaUrl."""
self.assertEqual(helpers.getMediaUrl('xx/yy.png'), '/media/xx/yy.png')
def testDecimal2String(self):
"""Test decimal2string."""
self.assertEqual(helpers.decimal2string(Decimal('1.2345000')), '1.2345')
self.assertEqual(helpers.decimal2string('test'), 'test')
class TestQuoteWrap(TestCase):
""" Tests for string wrapping """
"""Tests for string wrapping."""
def test_single(self):
"""Test WrapWithQuotes."""
self.assertEqual(helpers.WrapWithQuotes('hello'), '"hello"')
self.assertEqual(helpers.WrapWithQuotes('hello"'), '"hello"')
class TestIncrement(TestCase):
"""Tests for increment function."""
def tests(self):
""" Test 'intelligent' incrementing function """
"""Test 'intelligent' incrementing function."""
tests = [
("", ""),
(1, "2"),
@ -142,10 +139,10 @@ class TestIncrement(TestCase):
class TestMakeBarcode(TestCase):
""" Tests for barcode string creation """
"""Tests for barcode string creation."""
def test_barcode_extended(self):
"""Test creation of barcode with extended data."""
bc = helpers.MakeBarcode(
"part",
3,
@ -166,7 +163,7 @@ class TestMakeBarcode(TestCase):
self.assertEqual(data['part']['url'], 'www.google.com')
def test_barcode_brief(self):
"""Test creation of simple barcode."""
bc = helpers.MakeBarcode(
"stockitem",
7,
@ -178,27 +175,29 @@ class TestMakeBarcode(TestCase):
class TestDownloadFile(TestCase):
"""Tests for DownloadFile."""
def test_download(self):
"""Tests for DownloadFile."""
helpers.DownloadFile("hello world", "out.txt")
helpers.DownloadFile(bytes(b"hello world"), "out.bin")
class TestMPTT(TestCase):
""" Tests for the MPTT tree models """
"""Tests for the MPTT tree models."""
fixtures = [
'location',
]
def setUp(self):
"""Setup for all tests."""
super().setUp()
StockLocation.objects.rebuild()
def test_self_as_parent(self):
""" Test that we cannot set self as parent """
"""Test that we cannot set self as parent."""
loc = StockLocation.objects.get(pk=4)
loc.parent = loc
@ -206,8 +205,7 @@ class TestMPTT(TestCase):
loc.save()
def test_child_as_parent(self):
""" Test that we cannot set a child as parent """
"""Test that we cannot set a child as parent."""
parent = StockLocation.objects.get(pk=4)
child = StockLocation.objects.get(pk=5)
@ -217,8 +215,7 @@ class TestMPTT(TestCase):
parent.save()
def test_move(self):
""" Move an item to a different tree """
"""Move an item to a different tree."""
drawer = StockLocation.objects.get(name='Drawer_1')
# Record the tree ID
@ -233,10 +230,10 @@ class TestMPTT(TestCase):
class TestSerialNumberExtraction(TestCase):
""" Tests for serial number extraction code """
"""Tests for serial number extraction code."""
def test_simple(self):
"""Test simple serial numbers."""
e = helpers.extract_serial_numbers
sn = e("1-5", 5, 1)
@ -301,7 +298,7 @@ class TestSerialNumberExtraction(TestCase):
self.assertEqual(sn, [5, 6, 7, 8])
def test_failures(self):
"""Test wron serial numbers."""
e = helpers.extract_serial_numbers
# Test duplicates
@ -332,6 +329,7 @@ class TestSerialNumberExtraction(TestCase):
e("1, 2, 3, E-5", 5, 1)
def test_combinations(self):
"""Test complex serial number combinations."""
e = helpers.extract_serial_numbers
sn = e("1 3-5 9+2", 7, 1)
@ -352,12 +350,10 @@ class TestSerialNumberExtraction(TestCase):
class TestVersionNumber(TestCase):
"""
Unit tests for version number functions
"""
"""Unit tests for version number functions."""
def test_tuple(self):
"""Test inventreeVersionTuple."""
v = version.inventreeVersionTuple()
self.assertEqual(len(v), 3)
@ -366,10 +362,7 @@ class TestVersionNumber(TestCase):
self.assertTrue(s in version.inventreeVersion())
def test_comparison(self):
"""
Test direct comparison of version numbers
"""
"""Test direct comparison of version numbers."""
v_a = version.inventreeVersionTuple('1.2.0')
v_b = version.inventreeVersionTuple('1.2.3')
v_c = version.inventreeVersionTuple('1.2.4')
@ -381,8 +374,7 @@ class TestVersionNumber(TestCase):
self.assertTrue(v_d > v_a)
def test_commit_info(self):
"""Test that the git commit information is extracted successfully"""
"""Test that the git commit information is extracted successfully."""
envs = {
'INVENTREE_COMMIT_HASH': 'abcdef',
'INVENTREE_COMMIT_DATE': '2022-12-31'
@ -406,12 +398,10 @@ class TestVersionNumber(TestCase):
class CurrencyTests(TestCase):
"""
Unit tests for currency / exchange rate functionality
"""
"""Unit tests for currency / exchange rate functionality."""
def test_rates(self):
"""Test exchange rate update."""
# Initially, there will not be any exchange rate information
rates = Rate.objects.all()
@ -459,33 +449,32 @@ class CurrencyTests(TestCase):
class TestStatus(TestCase):
"""
Unit tests for status functions
"""
"""Unit tests for status functions."""
def test_check_system_healt(self):
"""test that the system health check is false in testing -> background worker not running"""
"""Test that the system health check is false in testing -> background worker not running."""
self.assertEqual(status.check_system_health(), False)
def test_TestMode(self):
"""Test isInTestMode check."""
self.assertTrue(ready.isInTestMode())
def test_Importing(self):
"""Test isImportingData check."""
self.assertEqual(ready.isImportingData(), False)
class TestSettings(helpers.InvenTreeTestCase):
"""
Unit tests for settings
"""
"""Unit tests for settings."""
superuser = True
def in_env_context(self, envs={}):
"""Patch the env to include the given dict"""
"""Patch the env to include the given dict."""
return mock.patch.dict(os.environ, envs)
def run_reload(self, envs={}):
"""Helper function to reload InvenTree."""
from plugin import registry
with self.in_env_context(envs):
@ -494,6 +483,7 @@ class TestSettings(helpers.InvenTreeTestCase):
@override_settings(TESTING_ENV=True)
def test_set_user_to_few(self):
"""Test adding an admin user via env variables."""
user_model = get_user_model()
# add shortcut
user_count = user_model.objects.count
@ -537,7 +527,7 @@ class TestSettings(helpers.InvenTreeTestCase):
settings.TESTING_ENV = False
def test_initial_install(self):
"""Test if install of plugins on startup works"""
"""Test if install of plugins on startup works."""
from plugin import registry
# Check an install run
@ -553,6 +543,7 @@ class TestSettings(helpers.InvenTreeTestCase):
self.assertEqual(response, True)
def test_helpers_cfg_file(self):
"""Test get_config_file."""
# normal run - not configured
valid = [
@ -567,6 +558,7 @@ class TestSettings(helpers.InvenTreeTestCase):
self.assertIn('inventree/my_special_conf.yaml', config.get_config_file().lower())
def test_helpers_plugin_file(self):
"""Test get_plugin_file."""
# normal run - not configured
valid = [
@ -581,6 +573,7 @@ class TestSettings(helpers.InvenTreeTestCase):
self.assertIn('my_special_plugins.txt', config.get_plugin_file())
def test_helpers_setting(self):
"""Test get_setting."""
TEST_ENV_NAME = '123TEST'
# check that default gets returned if not present
self.assertEqual(config.get_setting(TEST_ENV_NAME, None, '123!'), '123!')
@ -591,12 +584,10 @@ class TestSettings(helpers.InvenTreeTestCase):
class TestInstanceName(helpers.InvenTreeTestCase):
"""
Unit tests for instance name
"""
"""Unit tests for instance name."""
def test_instance_name(self):
"""Test instance name settings."""
# default setting
self.assertEqual(version.inventreeInstanceTitle(), 'InvenTree')

View File

@ -1,5 +1,4 @@
"""
Top-level URL lookup for InvenTree application.
"""Top-level URL lookup for InvenTree application.
Passes URL lookup downstream to each app as required.
"""

View File

@ -1,6 +1,4 @@
"""
Custom field validators for InvenTree
"""
"""Custom field validators for InvenTree."""
import re
from decimal import Decimal, InvalidOperation
@ -15,20 +13,18 @@ import common.models
def validate_currency_code(code):
"""
Check that a given code is a valid currency code.
"""
"""Check that a given code is a valid currency code."""
if code not in CURRENCIES:
raise ValidationError(_('Not a valid currency code'))
def allowable_url_schemes():
""" Return the list of allowable URL schemes.
"""Return the list of allowable URL schemes.
In addition to the default schemes allowed by Django,
the install configuration file (config.yaml) can specify
extra schemas """
extra schemas
"""
# Default schemes
schemes = ['http', 'https', 'ftp', 'ftps']
@ -42,9 +38,7 @@ def allowable_url_schemes():
def validate_part_name(value):
""" Prevent some illegal characters in part names.
"""
"""Prevent some illegal characters in part names."""
for c in ['|', '#', '$', '{', '}']:
if c in str(value):
raise ValidationError(
@ -53,8 +47,7 @@ def validate_part_name(value):
def validate_part_ipn(value):
""" Validate the Part IPN against regex rule """
"""Validate the Part IPN against regex rule."""
pattern = common.models.InvenTreeSetting.get_setting('PART_IPN_REGEX')
if pattern:
@ -65,10 +58,7 @@ def validate_part_ipn(value):
def validate_build_order_reference(value):
"""
Validate the 'reference' field of a BuildOrder
"""
"""Validate the 'reference' field of a BuildOrder."""
pattern = common.models.InvenTreeSetting.get_setting('BUILDORDER_REFERENCE_REGEX')
if pattern:
@ -79,10 +69,7 @@ def validate_build_order_reference(value):
def validate_purchase_order_reference(value):
"""
Validate the 'reference' field of a PurchaseOrder
"""
"""Validate the 'reference' field of a PurchaseOrder."""
pattern = common.models.InvenTreeSetting.get_setting('PURCHASEORDER_REFERENCE_REGEX')
if pattern:
@ -93,10 +80,7 @@ def validate_purchase_order_reference(value):
def validate_sales_order_reference(value):
"""
Validate the 'reference' field of a SalesOrder
"""
"""Validate the 'reference' field of a SalesOrder."""
pattern = common.models.InvenTreeSetting.get_setting('SALESORDER_REFERENCE_REGEX')
if pattern:
@ -107,16 +91,14 @@ def validate_sales_order_reference(value):
def validate_tree_name(value):
""" Prevent illegal characters in tree item names """
"""Prevent illegal characters in tree item names."""
for c in "!@#$%^&*'\"\\/[]{}<>,|+=~`\"":
if c in str(value):
raise ValidationError(_('Illegal character in name ({x})'.format(x=c)))
def validate_overage(value):
"""
Validate that a BOM overage string is properly formatted.
"""Validate that a BOM overage string is properly formatted.
An overage string can look like:
@ -124,7 +106,6 @@ def validate_overage(value):
- A decimal number ('0.123')
- A percentage ('5%' / '10 %')
"""
value = str(value).lower().strip()
# First look for a simple numerical value
@ -162,11 +143,10 @@ def validate_overage(value):
def validate_part_name_format(self):
"""
Validate part name format.
"""Validate part name format.
Make sure that each template container has a field of Part Model
"""
jinja_template_regex = re.compile('{{.*?}}')
field_name_regex = re.compile('(?<=part\\.)[A-z]+')
for jinja_template in jinja_template_regex.findall(str(self)):

View File

@ -1,5 +1,5 @@
"""
Version information for InvenTree.
"""Version information for InvenTree.
Provides information on the current InvenTree version
"""
@ -17,12 +17,12 @@ INVENTREE_SW_VERSION = "0.8.0 dev"
def inventreeInstanceName():
""" Returns the InstanceName settings for the current database """
"""Returns the InstanceName settings for the current database."""
return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "")
def inventreeInstanceTitle():
""" Returns the InstanceTitle for the current database """
"""Returns the InstanceTitle for the current database."""
if common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE_TITLE", False):
return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "")
else:
@ -30,13 +30,12 @@ def inventreeInstanceTitle():
def inventreeVersion():
""" Returns the InvenTree version string """
"""Returns the InvenTree version string."""
return INVENTREE_SW_VERSION.lower().strip()
def inventreeVersionTuple(version=None):
""" Return the InvenTree version string as (maj, min, sub) tuple """
"""Return the InvenTree version string as (maj, min, sub) tuple."""
if version is None:
version = INVENTREE_SW_VERSION
@ -46,21 +45,16 @@ def inventreeVersionTuple(version=None):
def isInvenTreeDevelopmentVersion():
"""
Return True if current InvenTree version is a "development" version
"""
"""Return True if current InvenTree version is a "development" version."""
return inventreeVersion().endswith('dev')
def inventreeDocsVersion():
"""
Return the version string matching the latest documentation.
"""Return the version string matching the latest documentation.
Development -> "latest"
Release -> "major.minor.sub" e.g. "0.5.2"
"""
if isInvenTreeDevelopmentVersion():
return "latest"
else:
@ -68,13 +62,10 @@ def inventreeDocsVersion():
def isInvenTreeUpToDate():
"""
Test if the InvenTree instance is "up to date" with the latest version.
"""Test if the InvenTree instance is "up to date" with the latest version.
A background task periodically queries GitHub for latest version,
and stores it to the database as INVENTREE_LATEST_VERSION
A background task periodically queries GitHub for latest version, and stores it to the database as INVENTREE_LATEST_VERSION
"""
latest = common.models.InvenTreeSetting.get_setting('INVENTREE_LATEST_VERSION', backup_value=None, create=False)
# No record for "latest" version - we must assume we are up to date!
@ -89,17 +80,17 @@ def isInvenTreeUpToDate():
def inventreeApiVersion():
"""Returns current API version of InvenTree."""
return INVENTREE_API_VERSION
def inventreeDjangoVersion():
""" Return the version of Django library """
"""Returns the version of Django library."""
return django.get_version()
def inventreeCommitHash():
""" Returns the git commit hash for the running codebase """
"""Returns the git commit hash for the running codebase."""
# First look in the environment variables, i.e. if running in docker
commit_hash = os.environ.get('INVENTREE_COMMIT_HASH', '')
@ -113,8 +104,7 @@ def inventreeCommitHash():
def inventreeCommitDate():
""" Returns the git commit date for the running codebase """
"""Returns the git commit date for the running codebase."""
# First look in the environment variables, e.g. if running in docker
commit_date = os.environ.get('INVENTREE_COMMIT_DATE', '')

View File

@ -1,5 +1,4 @@
"""
Various Views which provide extra functionality over base Django Views.
"""Various Views which provide extra functionality over base Django Views.
In particular these views provide base functionality for rendering Django forms
as JSON objects and passing them to modal forms (using jQuery / bootstrap).
@ -41,12 +40,10 @@ from .helpers import str2bool
def auth_request(request):
"""
Simple 'auth' endpoint used to determine if the user is authenticated.
Useful for (for example) redirecting authentication requests through
django's permission framework.
"""
"""Simple 'auth' endpoint used to determine if the user is authenticated.
Useful for (for example) redirecting authentication requests through django's permission framework.
"""
if request.user.is_authenticated:
return HttpResponse(status=200)
else:
@ -54,8 +51,7 @@ def auth_request(request):
class InvenTreeRoleMixin(PermissionRequiredMixin):
"""
Permission class based on user roles, not user 'permissions'.
"""Permission class based on user roles, not user 'permissions'.
There are a number of ways that the permissions can be specified for a view:
@ -97,10 +93,7 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
role_required = None
def has_permission(self):
"""
Determine if the current user has specified permissions
"""
"""Determine if the current user has specified permissions."""
roles_required = []
if type(self.role_required) is str:
@ -163,8 +156,7 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
return True
def get_permission_class(self):
"""
Return the 'permission_class' required for the current View.
"""Return the 'permission_class' required for the current View.
Must be one of:
@ -177,7 +169,6 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
'permission_class' attribute,
or it can be "guessed" by looking at the type of class
"""
perm = getattr(self, 'permission_class', None)
# Permission is specified by the class itself
@ -204,13 +195,10 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
class AjaxMixin(InvenTreeRoleMixin):
""" AjaxMixin provides basic functionality for rendering a Django form to JSON.
Handles jsonResponse rendering, and adds extra data for the modal forms to process
on the client side.
"""AjaxMixin provides basic functionality for rendering a Django form to JSON. Handles jsonResponse rendering, and adds extra data for the modal forms to process on the client side.
Any view which inherits the AjaxMixin will need
correct permissions set using the 'role_required' attribute
"""
# By default, allow *any* role
@ -223,11 +211,11 @@ class AjaxMixin(InvenTreeRoleMixin):
ajax_form_title = ''
def get_form_title(self):
""" Default implementation - return the ajax_form_title variable """
"""Default implementation - return the ajax_form_title variable."""
return self.ajax_form_title
def get_param(self, name, method='GET'):
""" Get a request query parameter value from URL e.g. ?part=3
"""Get a request query parameter value from URL e.g. ?part=3.
Args:
name: Variable name e.g. 'part'
@ -236,14 +224,13 @@ class AjaxMixin(InvenTreeRoleMixin):
Returns:
Value of the supplier parameter or None if parameter is not available
"""
if method == 'POST':
return self.request.POST.get(name, None)
else:
return self.request.GET.get(name, None)
def get_data(self):
""" Get extra context data (default implementation is empty dict)
"""Get extra context data (default implementation is empty dict).
Returns:
dict object (empty)
@ -251,20 +238,18 @@ class AjaxMixin(InvenTreeRoleMixin):
return {}
def validate(self, obj, form, **kwargs):
"""
Hook for performing custom form validation steps.
"""Hook for performing custom form validation steps.
If a form error is detected, add it to the form,
with 'form.add_error()'
Ref: https://docs.djangoproject.com/en/dev/topics/forms/
"""
# Do nothing by default
pass
def renderJsonResponse(self, request, form=None, data=None, context=None):
""" Render a JSON response based on specific class context.
"""Render a JSON response based on specific class context.
Args:
request: HTTP request object (e.g. GET / POST)
@ -318,19 +303,25 @@ class AjaxMixin(InvenTreeRoleMixin):
class AjaxView(AjaxMixin, View):
""" An 'AJAXified' View for displaying an object
"""
"""An 'AJAXified' View for displaying an object."""
def post(self, request, *args, **kwargs):
"""Return a json formatted response.
This renderJsonResponse function must be supplied by your function.
"""
return self.renderJsonResponse(request)
def get(self, request, *args, **kwargs):
"""Return a json formatted response.
This renderJsonResponse function must be supplied by your function.
"""
return self.renderJsonResponse(request)
class QRCodeView(AjaxView):
""" An 'AJAXified' view for displaying a QR code.
"""An 'AJAXified' view for displaying a QR code.
Subclasses should implement the get_qr_data(self) function.
"""
@ -338,22 +329,23 @@ class QRCodeView(AjaxView):
ajax_template_name = "qr_code.html"
def get(self, request, *args, **kwargs):
"""Return json with qr-code data."""
self.request = request
self.pk = self.kwargs['pk']
return self.renderJsonResponse(request, None, context=self.get_context_data())
def get_qr_data(self):
""" Returns the text object to render to a QR code.
The actual rendering will be handled by the template """
"""Returns the text object to render to a QR code.
The actual rendering will be handled by the template
"""
return None
def get_context_data(self):
""" Get context data for passing to the rendering template.
"""Get context data for passing to the rendering template.
Explicity passes the parameter 'qr_data'
"""
context = {}
qr = self.get_qr_data()
@ -367,15 +359,14 @@ class QRCodeView(AjaxView):
class AjaxCreateView(AjaxMixin, CreateView):
"""An 'AJAXified' CreateView for creating a new object in the db.
""" An 'AJAXified' CreateView for creating a new object in the db
- Returns a form in JSON format (for delivery to a modal window)
- Handles form validation via AJAX POST requests
"""
def get(self, request, *args, **kwargs):
""" Creates form with initial data, and renders JSON response """
"""Creates form with initial data, and renders JSON response."""
super(CreateView, self).get(request, *args, **kwargs)
self.request = request
@ -383,18 +374,16 @@ class AjaxCreateView(AjaxMixin, CreateView):
return self.renderJsonResponse(request, form)
def save(self, form):
"""
Method for actually saving the form to the database.
Default implementation is very simple,
but can be overridden if required.
"""
"""Method for actually saving the form to the database.
Default implementation is very simple, but can be overridden if required.
"""
self.object = form.save()
return self.object
def post(self, request, *args, **kwargs):
""" Responds to form POST. Validates POST data and returns status info.
"""Responds to form POST. Validates POST data and returns status info.
- Validate POST form data
- If valid, save form
@ -441,45 +430,44 @@ class AjaxCreateView(AjaxMixin, CreateView):
class AjaxUpdateView(AjaxMixin, UpdateView):
""" An 'AJAXified' UpdateView for updating an object in the db
"""An 'AJAXified' UpdateView for updating an object in the db.
- Returns form in JSON format (for delivery to a modal window)
- Handles repeated form validation (via AJAX) until the form is valid
"""
def get(self, request, *args, **kwargs):
""" Respond to GET request.
"""Respond to GET request.
- Populates form with object data
- Renders form to JSON and returns to client
"""
super(UpdateView, self).get(request, *args, **kwargs)
return self.renderJsonResponse(request, self.get_form(), context=self.get_context_data())
def save(self, object, form, **kwargs):
"""
Method for updating the object in the database.
Default implementation is very simple, but can be overridden if required.
"""Method for updating the object in the database. Default implementation is very simple, but can be overridden if required.
Args:
object - The current object, to be updated
form - The validated form
"""
Returns:
object instance for supplied form
"""
self.object = form.save()
return self.object
def post(self, request, *args, **kwargs):
""" Respond to POST request.
"""Respond to POST request.
- Updates model with POST field data
- Performs form and object validation
- If errors exist, re-render the form
- Otherwise, return sucess status
"""
self.request = request
# Make sure we have an object to point to
@ -524,8 +512,8 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
class AjaxDeleteView(AjaxMixin, UpdateView):
"""An 'AJAXified DeleteView for removing an object from the DB.
""" An 'AJAXified DeleteView for removing an object from the DB
- Returns a HTML object (not a form!) in JSON format (for delivery to a modal window)
- Handles deletion
"""
@ -536,6 +524,7 @@ class AjaxDeleteView(AjaxMixin, UpdateView):
context_object_name = 'item'
def get_object(self):
"""Return object matched to the model of the calling class."""
try:
self.object = self.model.objects.get(pk=self.kwargs['pk'])
except:
@ -543,15 +532,15 @@ class AjaxDeleteView(AjaxMixin, UpdateView):
return self.object
def get_form(self):
"""Returns a form instance for the form_class of the calling class."""
return self.form_class(self.get_form_kwargs())
def get(self, request, *args, **kwargs):
""" Respond to GET request
"""Respond to GET request.
- Render a DELETE confirmation form to JSON
- Return rendered form to client
"""
super(UpdateView, self).get(request, *args, **kwargs)
form = self.get_form()
@ -563,12 +552,11 @@ class AjaxDeleteView(AjaxMixin, UpdateView):
return self.renderJsonResponse(request, form, context=context)
def post(self, request, *args, **kwargs):
""" Respond to POST request
"""Respond to POST request.
- DELETE the object
- Render success message to JSON and return to client
"""
obj = self.get_object()
pk = obj.id
@ -592,28 +580,30 @@ class AjaxDeleteView(AjaxMixin, UpdateView):
class EditUserView(AjaxUpdateView):
""" View for editing user information """
"""View for editing user information."""
ajax_template_name = "modal_form.html"
ajax_form_title = _("Edit User Information")
form_class = EditUserForm
def get_object(self):
"""Set form to edit current user."""
return self.request.user
class SetPasswordView(AjaxUpdateView):
""" View for setting user password """
"""View for setting user password."""
ajax_template_name = "InvenTree/password.html"
ajax_form_title = _("Set Password")
form_class = SetPasswordForm
def get_object(self):
"""Set form to edit current user."""
return self.request.user
def post(self, request, *args, **kwargs):
"""Validate inputs and change password."""
form = self.get_form()
valid = form.is_valid()
@ -645,19 +635,13 @@ class SetPasswordView(AjaxUpdateView):
class IndexView(TemplateView):
""" View for InvenTree index page """
"""View for InvenTree index page."""
template_name = 'InvenTree/index.html'
def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
return context
class SearchView(TemplateView):
""" View for InvenTree search page.
"""View for InvenTree search page.
Displays results of search query
"""
@ -665,11 +649,10 @@ class SearchView(TemplateView):
template_name = 'InvenTree/search.html'
def post(self, request, *args, **kwargs):
""" Handle POST request (which contains search query).
"""Handle POST request (which contains search query).
Pass the search query to the page template
"""
context = self.get_context_data()
query = request.POST.get('search', '')
@ -680,24 +663,19 @@ class SearchView(TemplateView):
class DynamicJsView(TemplateView):
"""
View for returning javacsript files,
which instead of being served dynamically,
are passed through the django translation engine!
"""
"""View for returning javacsript files, which instead of being served dynamically, are passed through the django translation engine!"""
template_name = ""
content_type = 'text/javascript'
class SettingsView(TemplateView):
""" View for configuring User settings
"""
"""View for configuring User settings."""
template_name = "InvenTree/settings/settings.html"
def get_context_data(self, **kwargs):
"""Add data for template."""
ctx = super().get_context_data(**kwargs).copy()
ctx['settings'] = InvenTreeSetting.objects.all().order_by('key')
@ -739,61 +717,53 @@ class SettingsView(TemplateView):
class AllauthOverrides(LoginRequiredMixin):
"""
Override allauths views to always redirect to success_url
"""
"""Override allauths views to always redirect to success_url."""
def get(self, request, *args, **kwargs):
# always redirect to settings
"""Always redirect to success_url (set to settings)."""
return HttpResponseRedirect(self.success_url)
class CustomEmailView(AllauthOverrides, EmailView):
"""
Override of allauths EmailView to always show the settings but leave the functions allow
"""
"""Override of allauths EmailView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("settings")
class CustomConnectionsView(AllauthOverrides, ConnectionsView):
"""
Override of allauths ConnectionsView to always show the settings but leave the functions allow
"""
"""Override of allauths ConnectionsView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("settings")
class CustomPasswordResetFromKeyView(PasswordResetFromKeyView):
"""
Override of allauths PasswordResetFromKeyView to always show the settings but leave the functions allow
"""
"""Override of allauths PasswordResetFromKeyView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("account_login")
class UserSessionOverride():
"""overrides sucessurl to lead to settings"""
"""Overrides sucessurl to lead to settings."""
def get_success_url(self):
"""Revert to settings page after success."""
return str(reverse_lazy('settings'))
class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView):
"""Revert to settings after session delete."""
pass
class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView):
"""Revert to settings after session delete."""
pass
class CurrencyRefreshView(RedirectView):
"""
POST endpoint to refresh / update exchange rates
"""
"""POST endpoint to refresh / update exchange rates."""
url = reverse_lazy("settings-currencies")
def post(self, request, *args, **kwargs):
"""
On a POST request we will attempt to refresh the exchange rates
"""
"""On a POST request we will attempt to refresh the exchange rates."""
from InvenTree.tasks import offload_task, update_exchange_rates
offload_task(update_exchange_rates, force_sync=True)
@ -802,10 +772,10 @@ class CurrencyRefreshView(RedirectView):
class AppearanceSelectView(RedirectView):
""" View for selecting a color theme """
"""View for selecting a color theme."""
def get_user_theme(self):
""" Get current user color theme """
"""Get current user color theme."""
try:
user_theme = ColorTheme.objects.filter(user=self.request.user).get()
except ColorTheme.DoesNotExist:
@ -814,8 +784,7 @@ class AppearanceSelectView(RedirectView):
return user_theme
def post(self, request, *args, **kwargs):
""" Save user color theme selection """
"""Save user color theme selection."""
theme = request.POST.get('theme', None)
# Get current user theme
@ -833,15 +802,14 @@ class AppearanceSelectView(RedirectView):
class SettingCategorySelectView(FormView):
""" View for selecting categories in settings """
"""View for selecting categories in settings."""
form_class = SettingCategorySelectForm
success_url = reverse_lazy('settings-category')
template_name = "InvenTree/settings/category.html"
def get_initial(self):
""" Set category selection """
"""Set category selection."""
initial = super().get_initial()
category = self.request.GET.get('category', None)
@ -851,11 +819,10 @@ class SettingCategorySelectView(FormView):
return initial
def post(self, request, *args, **kwargs):
""" Handle POST request (which contains category selection).
"""Handle POST request (which contains category selection).
Pass the selected category to the page template
"""
form = self.get_form()
if form.is_valid():
@ -869,14 +836,13 @@ class SettingCategorySelectView(FormView):
class DatabaseStatsView(AjaxView):
""" View for displaying database statistics """
"""View for displaying database statistics."""
ajax_template_name = "stats.html"
ajax_form_title = _("System Information")
class NotificationsView(TemplateView):
""" View for showing notifications
"""
"""View for showing notifications."""
template_name = "InvenTree/notifications/notifications.html"

View File

@ -1,5 +1,4 @@
"""
WSGI config for InvenTree project.
"""WSGI config for InvenTree project.
It exposes the WSGI callable as a module-level variable named ``application``.

View File

@ -1,5 +1,4 @@
"""
The Build module is responsible for managing "Build" transactions.
"""The Build module is responsible for managing "Build" transactions.
A Build consumes parts from stock to create new parts
"""

View File

@ -1,3 +1,5 @@
"""Admin functionality for the BuildOrder app"""
from django.contrib import admin
from import_export.admin import ImportExportModelAdmin
@ -11,7 +13,7 @@ import part.models
class BuildResource(ModelResource):
"""Class for managing import/export of Build data"""
"""Class for managing import/export of Build data."""
# For some reason, we need to specify the fields individually for this ModelResource,
# but we don't for other ones.
# TODO: 2022-05-12 - Need to investigate why this is the case!
@ -39,6 +41,7 @@ class BuildResource(ModelResource):
notes = Field(attribute='notes')
class Meta:
"""Metaclass options"""
models = Build
skip_unchanged = True
report_skipped = False
@ -50,6 +53,7 @@ class BuildResource(ModelResource):
class BuildAdmin(ImportExportModelAdmin):
"""Class for managing the Build model via the admin interface"""
exclude = [
'reference_int',
@ -81,6 +85,7 @@ class BuildAdmin(ImportExportModelAdmin):
class BuildItemAdmin(admin.ModelAdmin):
"""Class for managing the BuildItem model via the admin interface"""
list_display = (
'build',

View File

@ -1,6 +1,4 @@
"""
JSON API for the Build app
"""
"""JSON API for the Build app."""
from django.urls import include, re_path
@ -22,16 +20,14 @@ from users.models import Owner
class BuildFilter(rest_filters.FilterSet):
"""
Custom filterset for BuildList API endpoint
"""
"""Custom filterset for BuildList API endpoint."""
status = rest_filters.NumberFilter(label='Status')
active = rest_filters.BooleanFilter(label='Build is active', method='filter_active')
def filter_active(self, queryset, name, value):
"""Filter the queryset to either include or exclude orders which are active."""
if str2bool(value):
queryset = queryset.filter(status__in=BuildStatus.ACTIVE_CODES)
else:
@ -42,7 +38,7 @@ class BuildFilter(rest_filters.FilterSet):
overdue = rest_filters.BooleanFilter(label='Build is overdue', method='filter_overdue')
def filter_overdue(self, queryset, name, value):
"""Filter the queryset to either include or exclude orders which are overdue."""
if str2bool(value):
queryset = queryset.filter(Build.OVERDUE_FILTER)
else:
@ -53,10 +49,7 @@ class BuildFilter(rest_filters.FilterSet):
assigned_to_me = rest_filters.BooleanFilter(label='assigned_to_me', method='filter_assigned_to_me')
def filter_assigned_to_me(self, queryset, name, value):
"""
Filter by orders which are assigned to the current user
"""
"""Filter by orders which are assigned to the current user."""
value = str2bool(value)
# Work out who "me" is!
@ -71,7 +64,7 @@ class BuildFilter(rest_filters.FilterSet):
class BuildList(APIDownloadMixin, generics.ListCreateAPIView):
""" API endpoint for accessing a list of Build objects.
"""API endpoint for accessing a list of Build objects.
- GET: Return list of objects (with filters)
- POST: Create a new Build object
@ -113,11 +106,7 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView):
]
def get_queryset(self):
"""
Override the queryset filtering,
as some of the fields don't natively play nicely with DRF
"""
"""Override the queryset filtering, as some of the fields don't natively play nicely with DRF."""
queryset = super().get_queryset().select_related('part')
queryset = build.serializers.BuildSerializer.annotate_queryset(queryset)
@ -125,6 +114,7 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView):
return queryset
def download_queryset(self, queryset, export_format):
"""Download the queryset data as a file."""
dataset = build.admin.BuildResource().export(queryset=queryset)
filedata = dataset.export(export_format)
@ -133,7 +123,7 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView):
return DownloadFile(filedata, filename)
def filter_queryset(self, queryset):
"""Custom query filtering for the BuildList endpoint."""
queryset = super().filter_queryset(queryset)
params = self.request.query_params
@ -197,7 +187,7 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView):
return queryset
def get_serializer(self, *args, **kwargs):
"""Add extra context information to the endpoint serializer."""
try:
part_detail = str2bool(self.request.GET.get('part_detail', None))
except AttributeError:
@ -209,15 +199,14 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView):
class BuildDetail(generics.RetrieveUpdateAPIView):
""" API endpoint for detail view of a Build object """
"""API endpoint for detail view of a Build object."""
queryset = Build.objects.all()
serializer_class = build.serializers.BuildSerializer
class BuildUnallocate(generics.CreateAPIView):
"""
API endpoint for unallocating stock items from a build order
"""API endpoint for unallocating stock items from a build order.
- The BuildOrder object is specified by the URL
- "output" (StockItem) can optionally be specified
@ -229,7 +218,7 @@ class BuildUnallocate(generics.CreateAPIView):
serializer_class = build.serializers.BuildUnallocationSerializer
def get_serializer_context(self):
"""Add extra context information to the endpoint serializer."""
ctx = super().get_serializer_context()
try:
@ -243,9 +232,10 @@ class BuildUnallocate(generics.CreateAPIView):
class BuildOrderContextMixin:
""" Mixin class which adds build order as serializer context variable """
"""Mixin class which adds build order as serializer context variable."""
def get_serializer_context(self):
"""Add extra context information to the endpoint serializer."""
ctx = super().get_serializer_context()
ctx['request'] = self.request
@ -260,9 +250,7 @@ class BuildOrderContextMixin:
class BuildOutputCreate(BuildOrderContextMixin, generics.CreateAPIView):
"""
API endpoint for creating new build output(s)
"""
"""API endpoint for creating new build output(s)."""
queryset = Build.objects.none()
@ -270,9 +258,7 @@ class BuildOutputCreate(BuildOrderContextMixin, generics.CreateAPIView):
class BuildOutputComplete(BuildOrderContextMixin, generics.CreateAPIView):
"""
API endpoint for completing build outputs
"""
"""API endpoint for completing build outputs."""
queryset = Build.objects.none()
@ -280,11 +266,10 @@ class BuildOutputComplete(BuildOrderContextMixin, generics.CreateAPIView):
class BuildOutputDelete(BuildOrderContextMixin, generics.CreateAPIView):
"""
API endpoint for deleting multiple build outputs
"""
"""API endpoint for deleting multiple build outputs."""
def get_serializer_context(self):
"""Add extra context information to the endpoint serializer."""
ctx = super().get_serializer_context()
ctx['to_complete'] = False
@ -297,9 +282,7 @@ class BuildOutputDelete(BuildOrderContextMixin, generics.CreateAPIView):
class BuildFinish(BuildOrderContextMixin, generics.CreateAPIView):
"""
API endpoint for marking a build as finished (completed)
"""
"""API endpoint for marking a build as finished (completed)."""
queryset = Build.objects.none()
@ -307,8 +290,7 @@ class BuildFinish(BuildOrderContextMixin, generics.CreateAPIView):
class BuildAutoAllocate(BuildOrderContextMixin, generics.CreateAPIView):
"""
API endpoint for 'automatically' allocating stock against a build order.
"""API endpoint for 'automatically' allocating stock against a build order.
- Only looks at 'untracked' parts
- If stock exists in a single location, easy!
@ -322,8 +304,7 @@ class BuildAutoAllocate(BuildOrderContextMixin, generics.CreateAPIView):
class BuildAllocate(BuildOrderContextMixin, generics.CreateAPIView):
"""
API endpoint to allocate stock items to a build order
"""API endpoint to allocate stock items to a build order.
- The BuildOrder object is specified by the URL
- Items to allocate are specified as a list called "items" with the following options:
@ -339,23 +320,21 @@ class BuildAllocate(BuildOrderContextMixin, generics.CreateAPIView):
class BuildCancel(BuildOrderContextMixin, generics.CreateAPIView):
""" API endpoint for cancelling a BuildOrder """
"""API endpoint for cancelling a BuildOrder."""
queryset = Build.objects.all()
serializer_class = build.serializers.BuildCancelSerializer
class BuildItemDetail(generics.RetrieveUpdateDestroyAPIView):
"""
API endpoint for detail view of a BuildItem object
"""
"""API endpoint for detail view of a BuildItem object."""
queryset = BuildItem.objects.all()
serializer_class = build.serializers.BuildItemSerializer
class BuildItemList(generics.ListCreateAPIView):
""" API endpoint for accessing a list of BuildItem objects
"""API endpoint for accessing a list of BuildItem objects.
- GET: Return list of objects
- POST: Create a new BuildItem object
@ -364,7 +343,7 @@ class BuildItemList(generics.ListCreateAPIView):
serializer_class = build.serializers.BuildItemSerializer
def get_serializer(self, *args, **kwargs):
"""Returns a BuildItemSerializer instance based on the request."""
try:
params = self.request.query_params
@ -377,10 +356,7 @@ class BuildItemList(generics.ListCreateAPIView):
return self.serializer_class(*args, **kwargs)
def get_queryset(self):
""" Override the queryset method,
to allow filtering by stock_item.part
"""
"""Override the queryset method, to allow filtering by stock_item.part."""
query = BuildItem.objects.all()
query = query.select_related('stock_item__location')
@ -390,7 +366,7 @@ class BuildItemList(generics.ListCreateAPIView):
return query
def filter_queryset(self, queryset):
"""Customm query filtering for the BuildItem list."""
queryset = super().filter_queryset(queryset)
params = self.request.query_params
@ -438,9 +414,7 @@ class BuildItemList(generics.ListCreateAPIView):
class BuildAttachmentList(generics.ListCreateAPIView, AttachmentMixin):
"""
API endpoint for listing (and creating) BuildOrderAttachment objects
"""
"""API endpoint for listing (and creating) BuildOrderAttachment objects."""
queryset = BuildOrderAttachment.objects.all()
serializer_class = build.serializers.BuildAttachmentSerializer
@ -455,9 +429,7 @@ class BuildAttachmentList(generics.ListCreateAPIView, AttachmentMixin):
class BuildAttachmentDetail(generics.RetrieveUpdateDestroyAPIView, AttachmentMixin):
"""
Detail endpoint for a BuildOrderAttachment object
"""
"""Detail endpoint for a BuildOrderAttachment object."""
queryset = BuildOrderAttachment.objects.all()
serializer_class = build.serializers.BuildAttachmentSerializer

View File

@ -1,5 +1,8 @@
"""Django app for the BuildOrder module"""
from django.apps import AppConfig
class BuildConfig(AppConfig):
"""BuildOrder app config class"""
name = 'build'

View File

@ -1,6 +1,4 @@
"""
Build database model definitions
"""
"""Build database model definitions."""
import decimal
@ -42,10 +40,7 @@ from users import models as UserModels
def get_next_build_number():
"""
Returns the next available BuildOrder reference number
"""
"""Returns the next available BuildOrder reference number."""
if Build.objects.count() == 0:
return '0001'
@ -71,7 +66,7 @@ def get_next_build_number():
class Build(MPTTModel, ReferenceIndexingMixin):
""" A Build object organises the creation of new StockItem objects from other existing StockItem objects.
"""A Build object organises the creation of new StockItem objects from other existing StockItem objects.
Attributes:
part: The part to be built (from component BOM items)
@ -97,10 +92,11 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@staticmethod
def get_api_url():
"""Return the API URL associated with the BuildOrder model"""
return reverse('api-build-list')
def api_instance_filters(self):
"""Returns custom API filters for the particular BuildOrder instance"""
return {
'parent': {
'exclude_tree': self.pk,
@ -109,10 +105,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@classmethod
def api_defaults(cls, request):
"""
Return default values for this model when issuing an API OPTIONS request
"""
"""Return default values for this model when issuing an API OPTIONS request."""
defaults = {
'reference': get_next_build_number(),
}
@ -123,7 +116,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
return defaults
def save(self, *args, **kwargs):
"""Custom save method for the BuildOrder model"""
self.rebuild_reference_field()
try:
@ -134,14 +127,12 @@ class Build(MPTTModel, ReferenceIndexingMixin):
})
class Meta:
"""Metaclass options for the BuildOrder model"""
verbose_name = _("Build Order")
verbose_name_plural = _("Build Orders")
def format_barcode(self, **kwargs):
"""
Return a JSON string to represent this build as a barcode
"""
"""Return a JSON string to represent this build as a barcode."""
return MakeBarcode(
"buildorder",
self.pk,
@ -153,13 +144,11 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@staticmethod
def filterByDate(queryset, min_date, max_date):
"""
Filter by 'minimum and maximum date range'
"""Filter by 'minimum and maximum date range'.
- Specified as min_date, max_date
- Both must be specified for filter to be applied
"""
date_fmt = '%Y-%m-%d' # ISO format date string
# Ensure that both dates are valid
@ -183,12 +172,13 @@ class Build(MPTTModel, ReferenceIndexingMixin):
return queryset
def __str__(self):
"""String representation of a BuildOrder"""
prefix = getSetting("BUILDORDER_REFERENCE_PREFIX")
return f"{prefix}{self.reference}"
def get_absolute_url(self):
"""Return the web URL associated with this BuildOrder"""
return reverse('build-detail', kwargs={'pk': self.id})
reference = models.CharField(
@ -336,10 +326,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
)
def sub_builds(self, cascade=True):
"""
Return all Build Order objects under this one.
"""
"""Return all Build Order objects under this one."""
if cascade:
return Build.objects.filter(parent=self.pk)
else:
@ -347,23 +334,22 @@ class Build(MPTTModel, ReferenceIndexingMixin):
Build.objects.filter(parent__pk__in=[d.pk for d in descendants])
def sub_build_count(self, cascade=True):
"""
Return the number of sub builds under this one.
"""Return the number of sub builds under this one.
Args:
cascade: If True (defualt), include cascading builds under sub builds
"""
return self.sub_builds(cascade=cascade).count()
@property
def is_overdue(self):
"""
Returns true if this build is "overdue":
"""Returns true if this build is "overdue".
Makes use of the OVERDUE_FILTER to avoid code duplication
"""
Returns:
bool: Is the build overdue
"""
query = Build.objects.filter(pk=self.pk)
query = query.filter(Build.OVERDUE_FILTER)
@ -371,80 +357,59 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@property
def active(self):
"""
Return True if this build is active
"""
"""Return True if this build is active."""
return self.status in BuildStatus.ACTIVE_CODES
@property
def bom_items(self):
"""
Returns the BOM items for the part referenced by this BuildOrder
"""
"""Returns the BOM items for the part referenced by this BuildOrder."""
return self.part.get_bom_items()
@property
def tracked_bom_items(self):
"""
Returns the "trackable" BOM items for this BuildOrder
"""
"""Returns the "trackable" BOM items for this BuildOrder."""
items = self.bom_items
items = items.filter(sub_part__trackable=True)
return items
def has_tracked_bom_items(self):
"""
Returns True if this BuildOrder has trackable BomItems
"""
"""Returns True if this BuildOrder has trackable BomItems."""
return self.tracked_bom_items.count() > 0
@property
def untracked_bom_items(self):
"""
Returns the "non trackable" BOM items for this BuildOrder
"""
"""Returns the "non trackable" BOM items for this BuildOrder."""
items = self.bom_items
items = items.filter(sub_part__trackable=False)
return items
def has_untracked_bom_items(self):
"""
Returns True if this BuildOrder has non trackable BomItems
"""
"""Returns True if this BuildOrder has non trackable BomItems."""
return self.untracked_bom_items.count() > 0
@property
def remaining(self):
"""
Return the number of outputs remaining to be completed.
"""
"""Return the number of outputs remaining to be completed."""
return max(0, self.quantity - self.completed)
@property
def output_count(self):
"""Return the number of build outputs (StockItem) associated with this build order"""
return self.build_outputs.count()
def has_build_outputs(self):
"""Returns True if this build has more than zero build outputs"""
return self.output_count > 0
def get_build_outputs(self, **kwargs):
"""
Return a list of build outputs.
"""Return a list of build outputs.
kwargs:
complete = (True / False) - If supplied, filter by completed status
in_stock = (True / False) - If supplied, filter by 'in-stock' status
"""
outputs = self.build_outputs.all()
# Filter by 'in stock' status
@ -469,17 +434,14 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@property
def complete_outputs(self):
"""
Return all the "completed" build outputs
"""
"""Return all the "completed" build outputs."""
outputs = self.get_build_outputs(complete=True)
return outputs
@property
def complete_count(self):
"""Return the total quantity of completed outputs"""
quantity = 0
for output in self.complete_outputs:
@ -489,20 +451,14 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@property
def incomplete_outputs(self):
"""
Return all the "incomplete" build outputs
"""
"""Return all the "incomplete" build outputs."""
outputs = self.get_build_outputs(complete=False)
return outputs
@property
def incomplete_count(self):
"""
Return the total number of "incomplete" outputs
"""
"""Return the total number of "incomplete" outputs."""
quantity = 0
for output in self.incomplete_outputs:
@ -512,10 +468,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@classmethod
def getNextBuildNumber(cls):
"""
Try to predict the next Build Order reference:
"""
"""Try to predict the next Build Order reference."""
if cls.objects.count() == 0:
return None
@ -552,13 +505,11 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@property
def can_complete(self):
"""
Returns True if this build can be "completed"
"""Returns True if this build can be "completed".
- Must not have any outstanding build outputs
- 'completed' value must meet (or exceed) the 'quantity' value
"""
if self.incomplete_count > 0:
return False
@ -573,10 +524,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@transaction.atomic
def complete_build(self, user):
"""
Mark this build as complete
"""
"""Mark this build as complete."""
if self.incomplete_count > 0:
return
@ -597,13 +545,12 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@transaction.atomic
def cancel_build(self, user, **kwargs):
""" Mark the Build as CANCELLED
"""Mark the Build as CANCELLED.
- Delete any pending BuildItem objects (but do not remove items from stock)
- Set build status to CANCELLED
- Save the Build object
"""
remove_allocated_stock = kwargs.get('remove_allocated_stock', False)
remove_incomplete_outputs = kwargs.get('remove_incomplete_outputs', False)
@ -633,14 +580,12 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@transaction.atomic
def unallocateStock(self, bom_item=None, output=None):
"""
Unallocate stock from this Build
"""Unallocate stock from this Build.
arguments:
- bom_item: Specify a particular BomItem to unallocate stock against
- output: Specify a particular StockItem (output) to unallocate stock against
Args:
bom_item: Specify a particular BomItem to unallocate stock against
output: Specify a particular StockItem (output) to unallocate stock against
"""
allocations = BuildItem.objects.filter(
build=self,
install_into=output
@ -653,19 +598,17 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@transaction.atomic
def create_build_output(self, quantity, **kwargs):
"""
Create a new build output against this BuildOrder.
"""Create a new build output against this BuildOrder.
args:
Args:
quantity: The quantity of the item to produce
kwargs:
Kwargs:
batch: Override batch code
serials: Serial numbers
location: Override location
auto_allocate: Automatically allocate stock with matching serial numbers
"""
batch = kwargs.get('batch', self.batch)
location = kwargs.get('location', self.destination)
serials = kwargs.get('serials', None)
@ -687,9 +630,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
multiple = True
if multiple:
"""
Create multiple build outputs with a single quantity of 1
"""
"""Create multiple build outputs with a single quantity of 1."""
# Quantity *must* be an integer at this point!
quantity = int(quantity)
@ -743,9 +684,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
)
else:
"""
Create a single build output of the given quantity
"""
"""Create a single build output of the given quantity."""
StockModels.StockItem.objects.create(
quantity=quantity,
@ -762,13 +701,12 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@transaction.atomic
def delete_output(self, output):
"""
Remove a build output from the database:
"""Remove a build output from the database.
Executes:
- Unallocate any build items against the output
- Delete the output StockItem
"""
if not output:
raise ValidationError(_("No build output specified"))
@ -786,11 +724,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@transaction.atomic
def subtract_allocated_stock(self, user):
"""
Called when the Build is marked as "complete",
this function removes the allocated untracked items from stock.
"""
"""Called when the Build is marked as "complete", this function removes the allocated untracked items from stock."""
items = self.allocated_stock.filter(
stock_item__part__trackable=False
)
@ -804,13 +738,11 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@transaction.atomic
def complete_build_output(self, output, user, **kwargs):
"""
Complete a particular build output
"""Complete a particular build output.
- Remove allocated StockItems
- Mark the output as complete
"""
# Select the location for the build output
location = kwargs.get('location', self.destination)
status = kwargs.get('status', StockStatus.OK)
@ -850,10 +782,9 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@transaction.atomic
def auto_allocate_stock(self, **kwargs):
"""
Automatically allocate stock items against this build order,
following a number of 'guidelines':
"""Automatically allocate stock items against this build order.
Following a number of 'guidelines':
- Only "untracked" BOM items are considered (tracked BOM items must be manually allocated)
- If a particular BOM item is already fully allocated, it is skipped
- Extract all available stock items for the BOM part
@ -863,7 +794,6 @@ class Build(MPTTModel, ReferenceIndexingMixin):
- If multiple stock items are found, we *may* be able to allocate:
- If the calling function has specified that items are interchangeable
"""
location = kwargs.get('location', None)
exclude_location = kwargs.get('exclude_location', None)
interchangeable = kwargs.get('interchangeable', False)
@ -958,14 +888,12 @@ class Build(MPTTModel, ReferenceIndexingMixin):
break
def required_quantity(self, bom_item, output=None):
"""
Get the quantity of a part required to complete the particular build output.
"""Get the quantity of a part required to complete the particular build output.
Args:
part: The Part object
output - The particular build output (StockItem)
bom_item: The Part object
output: The particular build output (StockItem)
"""
quantity = bom_item.quantity
if output:
@ -976,17 +904,15 @@ class Build(MPTTModel, ReferenceIndexingMixin):
return quantity
def allocated_bom_items(self, bom_item, output=None):
"""
Return all BuildItem objects which allocate stock of <bom_item> to <output>
"""Return all BuildItem objects which allocate stock of <bom_item> to <output>.
Note that the bom_item may allow variants, or direct substitutes,
making things difficult.
Args:
bom_item - The BomItem object
output - Build output (StockItem).
bom_item: The BomItem object
output: Build output (StockItem).
"""
allocations = BuildItem.objects.filter(
build=self,
bom_item=bom_item,
@ -996,10 +922,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
return allocations
def allocated_quantity(self, bom_item, output=None):
"""
Return the total quantity of given part allocated to a given build output.
"""
"""Return the total quantity of given part allocated to a given build output."""
allocations = self.allocated_bom_items(bom_item, output)
allocated = allocations.aggregate(
@ -1013,27 +936,18 @@ class Build(MPTTModel, ReferenceIndexingMixin):
return allocated['q']
def unallocated_quantity(self, bom_item, output=None):
"""
Return the total unallocated (remaining) quantity of a part against a particular output.
"""
"""Return the total unallocated (remaining) quantity of a part against a particular output."""
required = self.required_quantity(bom_item, output)
allocated = self.allocated_quantity(bom_item, output)
return max(required - allocated, 0)
def is_bom_item_allocated(self, bom_item, output=None):
"""
Test if the supplied BomItem has been fully allocated!
"""
"""Test if the supplied BomItem has been fully allocated!"""
return self.unallocated_quantity(bom_item, output) == 0
def is_fully_allocated(self, output):
"""
Returns True if the particular build output is fully allocated.
"""
"""Returns True if the particular build output is fully allocated."""
# If output is not specified, we are talking about "untracked" items
if output is None:
bom_items = self.untracked_bom_items
@ -1049,10 +963,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
return True
def is_partially_allocated(self, output):
"""
Returns True if the particular build output is (at least) partially allocated
"""
"""Returns True if the particular build output is (at least) partially allocated."""
# If output is not specified, we are talking about "untracked" items
if output is None:
bom_items = self.untracked_bom_items
@ -1067,17 +978,11 @@ class Build(MPTTModel, ReferenceIndexingMixin):
return False
def are_untracked_parts_allocated(self):
"""
Returns True if the un-tracked parts are fully allocated for this BuildOrder
"""
"""Returns True if the un-tracked parts are fully allocated for this BuildOrder."""
return self.is_fully_allocated(None)
def unallocated_bom_items(self, output):
"""
Return a list of bom items which have *not* been fully allocated against a particular output
"""
"""Return a list of bom items which have *not* been fully allocated against a particular output."""
unallocated = []
# If output is not specified, we are talking about "untracked" items
@ -1095,7 +1000,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@property
def required_parts(self):
""" Returns a list of parts required to build this part (BOM) """
"""Returns a list of parts required to build this part (BOM)."""
parts = []
for item in self.bom_items:
@ -1105,7 +1010,7 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@property
def required_parts_to_complete_build(self):
""" Returns a list of parts required to complete the full build """
"""Returns a list of parts required to complete the full build."""
parts = []
for bom_item in self.bom_items:
@ -1119,26 +1024,23 @@ class Build(MPTTModel, ReferenceIndexingMixin):
@property
def is_active(self):
""" Is this build active? An active build is either:
"""Is this build active?
An active build is either:
- PENDING
- HOLDING
"""
return self.status in BuildStatus.ACTIVE_CODES
@property
def is_complete(self):
""" Returns True if the build status is COMPLETE """
"""Returns True if the build status is COMPLETE."""
return self.status == BuildStatus.COMPLETE
@receiver(post_save, sender=Build, dispatch_uid='build_post_save_log')
def after_save_build(sender, instance: Build, created: bool, **kwargs):
"""
Callback function to be executed after a Build instance is saved
"""
"""Callback function to be executed after a Build instance is saved."""
from . import tasks as build_tasks
if created:
@ -1149,21 +1051,19 @@ def after_save_build(sender, instance: Build, created: bool, **kwargs):
class BuildOrderAttachment(InvenTreeAttachment):
"""
Model for storing file attachments against a BuildOrder object
"""
"""Model for storing file attachments against a BuildOrder object."""
def getSubdir(self):
"""Return the media file subdirectory for storing BuildOrder attachments"""
return os.path.join('bo_files', str(self.build.id))
build = models.ForeignKey(Build, on_delete=models.CASCADE, related_name='attachments')
class BuildItem(models.Model):
""" A BuildItem links multiple StockItem objects to a Build.
These are used to allocate part stock to a build.
Once the Build is completed, the parts are removed from stock and the
BuildItemAllocation objects are removed.
"""A BuildItem links multiple StockItem objects to a Build.
These are used to allocate part stock to a build. Once the Build is completed, the parts are removed from stock and the BuildItemAllocation objects are removed.
Attributes:
build: Link to a Build object
@ -1175,33 +1075,28 @@ class BuildItem(models.Model):
@staticmethod
def get_api_url():
"""Return the API URL used to access this model"""
return reverse('api-build-item-list')
def get_absolute_url(self):
# TODO - Fix!
return '/build/item/{pk}/'.format(pk=self.id)
# return reverse('build-detail', kwargs={'pk': self.id})
class Meta:
"""Serializer metaclass"""
unique_together = [
('build', 'stock_item', 'install_into'),
]
def save(self, *args, **kwargs):
"""Custom save method for the BuildItem model"""
self.clean()
super().save()
def clean(self):
"""
Check validity of this BuildItem instance.
The following checks are performed:
"""Check validity of this BuildItem instance.
The following checks are performed:
- StockItem.part must be in the BOM of the Part object referenced by Build
- Allocation quantity cannot exceed available quantity
"""
self.validate_unique()
super().clean()
@ -1303,13 +1198,11 @@ class BuildItem(models.Model):
@transaction.atomic
def complete_allocation(self, user, notes=''):
"""
Complete the allocation of this BuildItem into the output stock item.
"""Complete the allocation of this BuildItem into the output stock item.
- If the referenced part is trackable, the stock item will be *installed* into the build output
- If the referenced part is *not* trackable, the stock item will be removed from stock
"""
item = self.stock_item
# For a trackable part, special consideration needed!
@ -1344,10 +1237,7 @@ class BuildItem(models.Model):
)
def getStockItemThumbnail(self):
"""
Return qualified URL for part thumbnail image
"""
"""Return qualified URL for part thumbnail image."""
thumb_url = None
if self.stock_item and self.stock_item.part:

View File

@ -1,6 +1,4 @@
"""
JSON serializers for Build API
"""
"""JSON serializers for Build API."""
from django.db import transaction
from django.core.exceptions import ValidationError as DjangoValidationError
@ -31,9 +29,7 @@ from .models import Build, BuildItem, BuildOrderAttachment
class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer):
"""
Serializes a Build object
"""
"""Serializes a Build object."""
url = serializers.CharField(source='get_absolute_url', read_only=True)
status_text = serializers.CharField(source='get_status_display', read_only=True)
@ -50,16 +46,12 @@ class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer
@staticmethod
def annotate_queryset(queryset):
"""
Add custom annotations to the BuildSerializer queryset,
performing database queries as efficiently as possible.
"""Add custom annotations to the BuildSerializer queryset, performing database queries as efficiently as possible.
The following annoted fields are added:
- overdue: True if the build is outstanding *and* the completion date has past
"""
# Annotate a boolean 'overdue' flag
queryset = queryset.annotate(
@ -74,6 +66,7 @@ class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer
return queryset
def __init__(self, *args, **kwargs):
"""Determine if extra serializer fields are required"""
part_detail = kwargs.pop('part_detail', True)
super().__init__(*args, **kwargs)
@ -82,6 +75,7 @@ class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer
self.fields.pop('part_detail')
class Meta:
"""Serializer metaclass"""
model = Build
fields = [
'pk',
@ -121,8 +115,7 @@ class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer
class BuildOutputSerializer(serializers.Serializer):
"""
Serializer for a "BuildOutput"
"""Serializer for a "BuildOutput".
Note that a "BuildOutput" is really just a StockItem which is "in production"!
"""
@ -136,7 +129,7 @@ class BuildOutputSerializer(serializers.Serializer):
)
def validate_output(self, output):
"""Perform validation for the output (StockItem) provided to the serializer"""
build = self.context['build']
# As this serializer can be used in multiple contexts, we need to work out why we are here
@ -168,14 +161,14 @@ class BuildOutputSerializer(serializers.Serializer):
return output
class Meta:
"""Serializer metaclass"""
fields = [
'output',
]
class BuildOutputCreateSerializer(serializers.Serializer):
"""
Serializer for creating a new BuildOutput against a BuildOrder.
"""Serializer for creating a new BuildOutput against a BuildOrder.
URL pattern is "/api/build/<pk>/create-output/", where <pk> is the PK of a Build.
@ -192,13 +185,15 @@ class BuildOutputCreateSerializer(serializers.Serializer):
)
def get_build(self):
"""Return the Build instance associated with this serializer"""
return self.context["build"]
def get_part(self):
"""Return the Part instance associated with the build"""
return self.get_build().part
def validate_quantity(self, quantity):
"""Validate the provided quantity field"""
if quantity <= 0:
raise ValidationError(_("Quantity must be greater than zero"))
@ -229,7 +224,7 @@ class BuildOutputCreateSerializer(serializers.Serializer):
)
def validate_serial_numbers(self, serial_numbers):
"""Clean the provided serial number string"""
serial_numbers = serial_numbers.strip()
return serial_numbers
@ -243,10 +238,7 @@ class BuildOutputCreateSerializer(serializers.Serializer):
)
def validate(self, data):
"""
Perform form validation
"""
"""Perform form validation."""
part = self.get_part()
# Cache a list of serial numbers (to be used in the "save" method)
@ -284,10 +276,7 @@ class BuildOutputCreateSerializer(serializers.Serializer):
return data
def save(self):
"""
Generate the new build output(s)
"""
"""Generate the new build output(s)"""
data = self.validated_data
quantity = data['quantity']
@ -305,11 +294,10 @@ class BuildOutputCreateSerializer(serializers.Serializer):
class BuildOutputDeleteSerializer(serializers.Serializer):
"""
DRF serializer for deleting (cancelling) one or more build outputs
"""
"""DRF serializer for deleting (cancelling) one or more build outputs."""
class Meta:
"""Serializer metaclass"""
fields = [
'outputs',
]
@ -320,7 +308,7 @@ class BuildOutputDeleteSerializer(serializers.Serializer):
)
def validate(self, data):
"""Perform data validation for this serializer"""
data = super().validate(data)
outputs = data.get('outputs', [])
@ -331,10 +319,7 @@ class BuildOutputDeleteSerializer(serializers.Serializer):
return data
def save(self):
"""
'save' the serializer to delete the build outputs
"""
"""'save' the serializer to delete the build outputs."""
data = self.validated_data
outputs = data.get('outputs', [])
@ -347,11 +332,10 @@ class BuildOutputDeleteSerializer(serializers.Serializer):
class BuildOutputCompleteSerializer(serializers.Serializer):
"""
DRF serializer for completing one or more build outputs
"""
"""DRF serializer for completing one or more build outputs."""
class Meta:
"""Serializer metaclass"""
fields = [
'outputs',
'location',
@ -393,7 +377,7 @@ class BuildOutputCompleteSerializer(serializers.Serializer):
)
def validate(self, data):
"""Perform data validation for this serializer"""
super().validate(data)
outputs = data.get('outputs', [])
@ -404,10 +388,7 @@ class BuildOutputCompleteSerializer(serializers.Serializer):
return data
def save(self):
"""
"save" the serializer to complete the build outputs
"""
"""Save the serializer to complete the build outputs."""
build = self.context['build']
request = self.context['request']
@ -435,15 +416,17 @@ class BuildOutputCompleteSerializer(serializers.Serializer):
class BuildCancelSerializer(serializers.Serializer):
"""DRF serializer class for cancelling an active BuildOrder"""
class Meta:
"""Serializer metaclass"""
fields = [
'remove_allocated_stock',
'remove_incomplete_outputs',
]
def get_context_data(self):
"""Retrieve extra context data from this serializer"""
build = self.context['build']
return {
@ -467,7 +450,7 @@ class BuildCancelSerializer(serializers.Serializer):
)
def save(self):
"""Cancel the specified build"""
build = self.context['build']
request = self.context['request']
@ -481,9 +464,7 @@ class BuildCancelSerializer(serializers.Serializer):
class BuildCompleteSerializer(serializers.Serializer):
"""
DRF serializer for marking a BuildOrder as complete
"""
"""DRF serializer for marking a BuildOrder as complete."""
accept_unallocated = serializers.BooleanField(
label=_('Accept Unallocated'),
@ -493,7 +474,7 @@ class BuildCompleteSerializer(serializers.Serializer):
)
def validate_accept_unallocated(self, value):
"""Check if the 'accept_unallocated' field is required"""
build = self.context['build']
if not build.are_untracked_parts_allocated() and not value:
@ -509,7 +490,7 @@ class BuildCompleteSerializer(serializers.Serializer):
)
def validate_accept_incomplete(self, value):
"""Check if the 'accept_incomplete' field is required"""
build = self.context['build']
if build.remaining > 0 and not value:
@ -518,7 +499,7 @@ class BuildCompleteSerializer(serializers.Serializer):
return value
def validate(self, data):
"""Perform validation of this serializer prior to saving"""
build = self.context['build']
if build.incomplete_count > 0:
@ -530,7 +511,7 @@ class BuildCompleteSerializer(serializers.Serializer):
return data
def save(self):
"""Complete the specified build output"""
request = self.context['request']
build = self.context['build']
@ -538,14 +519,12 @@ class BuildCompleteSerializer(serializers.Serializer):
class BuildUnallocationSerializer(serializers.Serializer):
"""
DRF serializer for unallocating stock from a BuildOrder
"""DRF serializer for unallocating stock from a BuildOrder.
Allocated stock can be unallocated with a number of filters:
- output: Filter against a particular build output (blank = untracked stock)
- bom_item: Filter against a particular BOM line item
"""
bom_item = serializers.PrimaryKeyRelatedField(
@ -567,8 +546,7 @@ class BuildUnallocationSerializer(serializers.Serializer):
)
def validate_output(self, stock_item):
# Stock item must point to the same build order!
"""Validation for the output StockItem instance. Stock item must point to the same build order!"""
build = self.context['build']
if stock_item and stock_item.build != build:
@ -577,11 +555,10 @@ class BuildUnallocationSerializer(serializers.Serializer):
return stock_item
def save(self):
"""
'Save' the serializer data.
"""Save the serializer data.
This performs the actual unallocation against the build order
"""
build = self.context['build']
data = self.validated_data
@ -593,9 +570,7 @@ class BuildUnallocationSerializer(serializers.Serializer):
class BuildAllocationItemSerializer(serializers.Serializer):
"""
A serializer for allocating a single stock item against a build order
"""
"""A serializer for allocating a single stock item against a build order."""
bom_item = serializers.PrimaryKeyRelatedField(
queryset=BomItem.objects.all(),
@ -606,10 +581,7 @@ class BuildAllocationItemSerializer(serializers.Serializer):
)
def validate_bom_item(self, bom_item):
"""
Check if the parts match!
"""
"""Check if the parts match"""
build = self.context['build']
# BomItem should point to the same 'part' as the parent build
@ -632,7 +604,7 @@ class BuildAllocationItemSerializer(serializers.Serializer):
)
def validate_stock_item(self, stock_item):
"""Perform validation of the stock_item field"""
if not stock_item.in_stock:
raise ValidationError(_("Item must be in stock"))
@ -646,7 +618,7 @@ class BuildAllocationItemSerializer(serializers.Serializer):
)
def validate_quantity(self, quantity):
"""Perform validation of the 'quantity' field"""
if quantity <= 0:
raise ValidationError(_("Quantity must be greater than zero"))
@ -661,6 +633,7 @@ class BuildAllocationItemSerializer(serializers.Serializer):
)
class Meta:
"""Serializer metaclass"""
fields = [
'bom_item',
'stock_item',
@ -669,7 +642,7 @@ class BuildAllocationItemSerializer(serializers.Serializer):
]
def validate(self, data):
"""Perfofrm data validation for this item"""
super().validate(data)
build = self.context['build']
@ -715,22 +688,18 @@ class BuildAllocationItemSerializer(serializers.Serializer):
class BuildAllocationSerializer(serializers.Serializer):
"""
DRF serializer for allocation stock items against a build order
"""
"""DRF serializer for allocation stock items against a build order."""
items = BuildAllocationItemSerializer(many=True)
class Meta:
"""Serializer metaclass"""
fields = [
'items',
]
def validate(self, data):
"""
Validation
"""
"""Validation."""
data = super().validate(data)
items = data.get('items', [])
@ -741,7 +710,7 @@ class BuildAllocationSerializer(serializers.Serializer):
return data
def save(self):
"""Perform the allocation"""
data = self.validated_data
items = data.get('items', [])
@ -770,11 +739,10 @@ class BuildAllocationSerializer(serializers.Serializer):
class BuildAutoAllocationSerializer(serializers.Serializer):
"""
DRF serializer for auto allocating stock items against a build order
"""
"""DRF serializer for auto allocating stock items against a build order."""
class Meta:
"""Serializer metaclass"""
fields = [
'location',
'exclude_location',
@ -813,7 +781,7 @@ class BuildAutoAllocationSerializer(serializers.Serializer):
)
def save(self):
"""Perform the auto-allocation step"""
data = self.validated_data
build = self.context['build']
@ -827,7 +795,7 @@ class BuildAutoAllocationSerializer(serializers.Serializer):
class BuildItemSerializer(InvenTreeModelSerializer):
""" Serializes a BuildItem object """
"""Serializes a BuildItem object."""
bom_part = serializers.IntegerField(source='bom_item.sub_part.pk', read_only=True)
part = serializers.IntegerField(source='stock_item.part.pk', read_only=True)
@ -842,7 +810,7 @@ class BuildItemSerializer(InvenTreeModelSerializer):
quantity = InvenTreeDecimalField()
def __init__(self, *args, **kwargs):
"""Determine which extra details fields should be included"""
build_detail = kwargs.pop('build_detail', False)
part_detail = kwargs.pop('part_detail', False)
location_detail = kwargs.pop('location_detail', False)
@ -859,6 +827,7 @@ class BuildItemSerializer(InvenTreeModelSerializer):
self.fields.pop('location_detail')
class Meta:
"""Serializer metaclass"""
model = BuildItem
fields = [
'pk',
@ -877,11 +846,10 @@ class BuildItemSerializer(InvenTreeModelSerializer):
class BuildAttachmentSerializer(InvenTreeAttachmentSerializer):
"""
Serializer for a BuildAttachment
"""
"""Serializer for a BuildAttachment."""
class Meta:
"""Serializer metaclass"""
model = BuildOrderAttachment
fields = [

View File

@ -1,3 +1,5 @@
"""Background task definitions for the BuildOrder app"""
from decimal import Decimal
import logging
@ -18,11 +20,10 @@ logger = logging.getLogger('inventree')
def check_build_stock(build: build.models.Build):
"""
Check the required stock for a newly created build order,
and send an email out to any subscribed users if stock is low.
"""
"""Check the required stock for a newly created build order.
Send an email out to any subscribed users if stock is low.
"""
# Do not notify if we are importing data
if isImportingData():
return

View File

@ -1,3 +1,5 @@
"""Unit tests for the BuildOrder API"""
from datetime import datetime, timedelta
from django.urls import reverse
@ -13,8 +15,8 @@ from InvenTree.api_tester import InvenTreeAPITestCase
class TestBuildAPI(InvenTreeAPITestCase):
"""
Series of tests for the Build DRF API
"""Series of tests for the Build DRF API.
- Tests for Build API
- Tests for BuildItem API
"""
@ -33,10 +35,7 @@ class TestBuildAPI(InvenTreeAPITestCase):
]
def test_get_build_list(self):
"""
Test that we can retrieve list of build objects
"""
"""Test that we can retrieve list of build objects."""
url = reverse('api-build-list')
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -65,7 +64,7 @@ class TestBuildAPI(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 0)
def test_get_build_item_list(self):
""" Test that we can retrieve list of BuildItem objects """
"""Test that we can retrieve list of BuildItem objects."""
url = reverse('api-build-item-list')
response = self.client.get(url, format='json')
@ -77,9 +76,7 @@ class TestBuildAPI(InvenTreeAPITestCase):
class BuildAPITest(InvenTreeAPITestCase):
"""
Series of tests for the Build DRF API
"""
"""Series of tests for the Build DRF API."""
fixtures = [
'category',
@ -96,18 +93,12 @@ class BuildAPITest(InvenTreeAPITestCase):
'build.add'
]
def setUp(self):
super().setUp()
class BuildTest(BuildAPITest):
"""
Unit testing for the build complete API endpoint
"""
"""Unit testing for the build complete API endpoint."""
def setUp(self):
"""Basic setup for this test suite"""
super().setUp()
self.build = Build.objects.get(pk=1)
@ -115,10 +106,7 @@ class BuildTest(BuildAPITest):
self.url = reverse('api-build-output-complete', kwargs={'pk': self.build.pk})
def test_invalid(self):
"""
Test with invalid data
"""
"""Test with invalid data."""
# Test with an invalid build ID
self.post(
reverse('api-build-output-complete', kwargs={'pk': 99999}),
@ -199,10 +187,7 @@ class BuildTest(BuildAPITest):
)
def test_complete(self):
"""
Test build order completion
"""
"""Test build order completion."""
# Initially, build should not be able to be completed
self.assertFalse(self.build.can_complete)
@ -270,8 +255,7 @@ class BuildTest(BuildAPITest):
self.assertTrue(self.build.is_complete)
def test_cancel(self):
""" Test that we can cancel a BuildOrder via the API """
"""Test that we can cancel a BuildOrder via the API."""
bo = Build.objects.get(pk=1)
url = reverse('api-build-cancel', kwargs={'pk': bo.pk})
@ -285,10 +269,7 @@ class BuildTest(BuildAPITest):
self.assertEqual(bo.status, BuildStatus.CANCELLED)
def test_create_delete_output(self):
"""
Test that we can create and delete build outputs via the API
"""
"""Test that we can create and delete build outputs via the API."""
bo = Build.objects.get(pk=1)
n_outputs = bo.output_count
@ -494,7 +475,7 @@ class BuildTest(BuildAPITest):
self.assertIn('This build output has already been completed', str(response.data))
def test_download_build_orders(self):
"""Test that we can download a list of build orders via the API"""
required_cols = [
'reference',
'status',
@ -539,19 +520,17 @@ class BuildTest(BuildAPITest):
class BuildAllocationTest(BuildAPITest):
"""
Unit tests for allocation of stock items against a build order.
"""Unit tests for allocation of stock items against a build order.
For this test, we will be using Build ID=1;
- This points to Part 100 (see fixture data in part.yaml)
- This Part already has a BOM with 4 items (see fixture data in bom.yaml)
- There are no BomItem objects yet created for this build
"""
def setUp(self):
"""Basic operation as part of test suite setup"""
super().setUp()
self.assignRole('build.add')
@ -565,10 +544,7 @@ class BuildAllocationTest(BuildAPITest):
self.n = BuildItem.objects.count()
def test_build_data(self):
"""
Check that our assumptions about the particular BuildOrder are correct
"""
"""Check that our assumptions about the particular BuildOrder are correct."""
self.assertEqual(self.build.part.pk, 100)
# There should be 4x BOM items we can use
@ -578,26 +554,17 @@ class BuildAllocationTest(BuildAPITest):
self.assertEqual(self.build.allocated_stock.count(), 0)
def test_get(self):
"""
A GET request to the endpoint should return an error
"""
"""A GET request to the endpoint should return an error."""
self.get(self.url, expected_code=405)
def test_options(self):
"""
An OPTIONS request to the endpoint should return information about the endpoint
"""
"""An OPTIONS request to the endpoint should return information about the endpoint."""
response = self.options(self.url, expected_code=200)
self.assertIn("API endpoint to allocate stock items to a build order", str(response.data))
def test_empty(self):
"""
Test without any POST data
"""
"""Test without any POST data."""
# Initially test with an empty data set
data = self.post(self.url, {}, expected_code=400).data
@ -618,10 +585,7 @@ class BuildAllocationTest(BuildAPITest):
self.assertEqual(self.n, BuildItem.objects.count())
def test_missing(self):
"""
Test with missing data
"""
"""Test with missing data."""
# Missing quantity
data = self.post(
self.url,
@ -674,10 +638,7 @@ class BuildAllocationTest(BuildAPITest):
self.assertEqual(self.n, BuildItem.objects.count())
def test_invalid_bom_item(self):
"""
Test by passing an invalid BOM item
"""
"""Test by passing an invalid BOM item."""
data = self.post(
self.url,
{
@ -695,11 +656,10 @@ class BuildAllocationTest(BuildAPITest):
self.assertIn('must point to the same part', str(data))
def test_valid_data(self):
"""
Test with valid data.
"""Test with valid data.
This should result in creation of a new BuildItem object
"""
self.post(
self.url,
{
@ -725,17 +685,12 @@ class BuildAllocationTest(BuildAPITest):
class BuildListTest(BuildAPITest):
"""
Tests for the BuildOrder LIST API
"""
"""Tests for the BuildOrder LIST API."""
url = reverse('api-build-list')
def test_get_all_builds(self):
"""
Retrieve *all* builds via the API
"""
"""Retrieve *all* builds via the API."""
builds = self.get(self.url)
self.assertEqual(len(builds.data), 5)
@ -753,10 +708,7 @@ class BuildListTest(BuildAPITest):
self.assertEqual(len(builds.data), 0)
def test_overdue(self):
"""
Create a new build, in the past
"""
"""Create a new build, in the past."""
in_the_past = datetime.now().date() - timedelta(days=50)
part = Part.objects.get(pk=50)
@ -776,10 +728,7 @@ class BuildListTest(BuildAPITest):
self.assertEqual(len(builds), 1)
def test_sub_builds(self):
"""
Test the build / sub-build relationship
"""
"""Test the build / sub-build relationship."""
parent = Build.objects.get(pk=5)
part = Part.objects.get(pk=50)

View File

@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Unit tests for the 'build' models"""
from django.test import TestCase
@ -12,13 +12,10 @@ from stock.models import StockItem
class BuildTestBase(TestCase):
"""
Run some tests to ensure that the Build model is working properly.
"""
"""Run some tests to ensure that the Build model is working properly."""
def setUp(self):
"""
Initialize data to use for these tests.
"""Initialize data to use for these tests.
The base Part 'assembly' has a BOM consisting of three parts:
@ -119,11 +116,10 @@ class BuildTestBase(TestCase):
class BuildTest(BuildTestBase):
"""Unit testing class for the Build model"""
def test_ref_int(self):
"""
Test the "integer reference" field used for natural sorting
"""
"""Test the "integer reference" field used for natural sorting"""
for ii in range(10):
build = Build(
@ -141,7 +137,7 @@ class BuildTest(BuildTestBase):
self.assertEqual(build.reference_int, ii)
def test_init(self):
# Perform some basic tests before we start the ball rolling
"""Perform some basic tests before we start the ball rolling"""
self.assertEqual(StockItem.objects.count(), 10)
@ -166,7 +162,7 @@ class BuildTest(BuildTestBase):
self.assertFalse(self.build.is_complete)
def test_build_item_clean(self):
# Ensure that dodgy BuildItem objects cannot be created
"""Ensure that dodgy BuildItem objects cannot be created"""
stock = StockItem.objects.create(part=self.assembly, quantity=99)
@ -193,7 +189,7 @@ class BuildTest(BuildTestBase):
b.save()
def test_duplicate_bom_line(self):
# Try to add a duplicate BOM item - it should be allowed
"""Try to add a duplicate BOM item - it should be allowed"""
BomItem.objects.create(
part=self.assembly,
@ -202,12 +198,11 @@ class BuildTest(BuildTestBase):
)
def allocate_stock(self, output, allocations):
"""
Allocate stock to this build, against a particular output
"""Allocate stock to this build, against a particular output
Args:
output - StockItem object (or None)
allocations - Map of {StockItem: quantity}
output: StockItem object (or None)
allocations: Map of {StockItem: quantity}
"""
for item, quantity in allocations.items():
@ -219,9 +214,7 @@ class BuildTest(BuildTestBase):
)
def test_partial_allocation(self):
"""
Test partial allocation of stock
"""
"""Test partial allocation of stock"""
# Fully allocate tracked stock against build output 1
self.allocate_stock(
@ -294,9 +287,7 @@ class BuildTest(BuildTestBase):
self.assertTrue(self.build.are_untracked_parts_allocated())
def test_cancel(self):
"""
Test cancellation of the build
"""
"""Test cancellation of the build"""
# TODO
@ -309,9 +300,7 @@ class BuildTest(BuildTestBase):
pass
def test_complete(self):
"""
Test completion of a build output
"""
"""Test completion of a build output"""
self.stock_1_1.quantity = 1000
self.stock_1_1.save()
@ -385,12 +374,10 @@ class BuildTest(BuildTestBase):
class AutoAllocationTests(BuildTestBase):
"""
Tests for auto allocating stock against a build order
"""
"""Tests for auto allocating stock against a build order"""
def setUp(self):
"""Init routines for this unit test class"""
super().setUp()
# Add a "substitute" part for bom_item_2
@ -411,8 +398,7 @@ class AutoAllocationTests(BuildTestBase):
)
def test_auto_allocate(self):
"""
Run the 'auto-allocate' function. What do we expect to happen?
"""Run the 'auto-allocate' function. What do we expect to happen?
There are two "untracked" parts:
- sub_part_1 (quantity 5 per BOM = 50 required total) / 103 in stock (2 items)
@ -474,9 +460,7 @@ class AutoAllocationTests(BuildTestBase):
self.assertTrue(self.build.is_bom_item_allocated(self.bom_item_2))
def test_fully_auto(self):
"""
We should be able to auto-allocate against a build in a single go
"""
"""We should be able to auto-allocate against a build in a single go"""
self.build.auto_allocate_stock(
interchangeable=True,

View File

@ -1,6 +1,4 @@
"""
Tests for the build model database migrations
"""
"""Tests for the build model database migrations."""
from django_test_migrations.contrib.unittest_case import MigratorTestCase
@ -8,18 +6,13 @@ from InvenTree import helpers
class TestForwardMigrations(MigratorTestCase):
"""
Test entire schema migration sequence for the build app
"""
"""Test entire schema migration sequence for the build app."""
migrate_from = ('build', helpers.getOldestMigrationFile('build'))
migrate_to = ('build', helpers.getNewestMigrationFile('build'))
def prepare(self):
"""
Create initial data!
"""
"""Create initial data!"""
Part = self.old_state.apps.get_model('part', 'part')
buildable_part = Part.objects.create(
@ -45,7 +38,7 @@ class TestForwardMigrations(MigratorTestCase):
)
def test_items_exist(self):
"""Test to ensure that the 'assembly' field is correctly configured"""
Part = self.new_state.apps.get_model('part', 'part')
self.assertEqual(Part.objects.count(), 1)
@ -63,18 +56,13 @@ class TestForwardMigrations(MigratorTestCase):
class TestReferenceMigration(MigratorTestCase):
"""
Test custom migration which adds 'reference' field to Build model
"""
"""Test custom migration which adds 'reference' field to Build model."""
migrate_from = ('build', helpers.getOldestMigrationFile('build'))
migrate_to = ('build', '0018_build_reference')
def prepare(self):
"""
Create some builds
"""
"""Create some builds."""
Part = self.old_state.apps.get_model('part', 'part')
part = Part.objects.create(
@ -108,7 +96,7 @@ class TestReferenceMigration(MigratorTestCase):
print(build.reference)
def test_build_reference(self):
"""Test that the build reference is correctly assigned to the PK of the Build"""
Build = self.new_state.apps.get_model('build', 'build')
self.assertEqual(Build.objects.count(), 3)

View File

@ -1,3 +1,5 @@
"""Basic unit tests for the BuildOrder app"""
from django.urls import reverse
from datetime import datetime, timedelta
@ -11,6 +13,7 @@ from InvenTree.status_codes import BuildStatus
class BuildTestSimple(InvenTreeTestCase):
"""Basic set of tests for the BuildOrder model functionality"""
fixtures = [
'category',
@ -26,7 +29,7 @@ class BuildTestSimple(InvenTreeTestCase):
]
def test_build_objects(self):
# Ensure the Build objects were correctly created
"""Ensure the Build objects were correctly created"""
self.assertEqual(Build.objects.count(), 5)
b = Build.objects.get(pk=2)
self.assertEqual(b.batch, 'B2')
@ -35,10 +38,12 @@ class BuildTestSimple(InvenTreeTestCase):
self.assertEqual(str(b), 'BO0002')
def test_url(self):
"""Test URL lookup"""
b1 = Build.objects.get(pk=1)
self.assertEqual(b1.get_absolute_url(), '/build/1/')
def test_is_complete(self):
"""Test build completion status"""
b1 = Build.objects.get(pk=1)
b2 = Build.objects.get(pk=2)
@ -48,10 +53,7 @@ class BuildTestSimple(InvenTreeTestCase):
self.assertEqual(b2.status, BuildStatus.COMPLETE)
def test_overdue(self):
"""
Test overdue status functionality
"""
"""Test overdue status functionality."""
today = datetime.now().date()
build = Build.objects.get(pk=1)
@ -66,6 +68,7 @@ class BuildTestSimple(InvenTreeTestCase):
self.assertFalse(build.is_overdue)
def test_is_active(self):
"""Test active / inactive build status"""
b1 = Build.objects.get(pk=1)
b2 = Build.objects.get(pk=2)
@ -73,12 +76,12 @@ class BuildTestSimple(InvenTreeTestCase):
self.assertEqual(b2.is_active, False)
def test_required_parts(self):
# TODO - Generate BOM for test part
pass
"""Test set of required BOM items for the build"""
# TODO: Generate BOM for test part
...
def test_cancel_build(self):
""" Test build cancellation function """
"""Test build cancellation function."""
build = Build.objects.get(id=1)
self.assertEqual(build.status, BuildStatus.PENDING)
@ -89,7 +92,7 @@ class BuildTestSimple(InvenTreeTestCase):
class TestBuildViews(InvenTreeTestCase):
""" Tests for Build app views """
"""Tests for Build app views."""
fixtures = [
'category',
@ -105,6 +108,7 @@ class TestBuildViews(InvenTreeTestCase):
]
def setUp(self):
"""Fixturing for this suite of unit tests"""
super().setUp()
# Create a build output for build # 1
@ -118,14 +122,12 @@ class TestBuildViews(InvenTreeTestCase):
)
def test_build_index(self):
""" test build index view """
"""Test build index view."""
response = self.client.get(reverse('build-index'))
self.assertEqual(response.status_code, 200)
def test_build_detail(self):
""" Test the detail view for a Build object """
"""Test the detail view for a Build object."""
pk = 1
response = self.client.get(reverse('build-detail', args=(pk,)))

View File

@ -1,6 +1,4 @@
"""
URL lookup for Build app
"""
"""URL lookup for Build app."""
from django.urls import include, re_path

View File

@ -1,6 +1,4 @@
"""
Django views for interacting with Build objects
"""
"""Django views for interacting with Build objects."""
from django.utils.translation import gettext_lazy as _
from django.views.generic import DetailView, ListView
@ -15,42 +13,25 @@ from plugin.views import InvenTreePluginViewMixin
class BuildIndex(InvenTreeRoleMixin, ListView):
"""
View for displaying list of Builds
"""
"""View for displaying list of Builds."""
model = Build
template_name = 'build/index.html'
context_object_name = 'builds'
def get_queryset(self):
""" Return all Build objects (order by date, newest first) """
"""Return all Build objects (order by date, newest first)"""
return Build.objects.order_by('status', '-completion_date')
def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
context['BuildStatus'] = BuildStatus
context['active'] = self.get_queryset().filter(status__in=BuildStatus.ACTIVE_CODES)
context['completed'] = self.get_queryset().filter(status=BuildStatus.COMPLETE)
context['cancelled'] = self.get_queryset().filter(status=BuildStatus.CANCELLED)
return context
class BuildDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView):
"""
Detail view of a single Build object.
"""
"""Detail view of a single Build object."""
model = Build
template_name = 'build/detail.html'
context_object_name = 'build'
def get_context_data(self, **kwargs):
"""Return extra context information for the BuildDetail view"""
ctx = super().get_context_data(**kwargs)
build = self.get_object()
@ -71,9 +52,7 @@ class BuildDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView):
class BuildDelete(AjaxDeleteView):
"""
View to delete a build
"""
"""View to delete a build."""
model = Build
ajax_template_name = 'build/delete_build.html'

View File

@ -1,3 +1,5 @@
"""Admin for the common app."""
from django.contrib import admin
from import_export.admin import ImportExportModelAdmin
@ -6,14 +8,12 @@ import common.models
class SettingsAdmin(ImportExportModelAdmin):
"""Admin settings for InvenTreeSetting."""
list_display = ('key', 'value')
def get_readonly_fields(self, request, obj=None): # pragma: no cover
"""
Prevent the 'key' field being edited once the setting is created
"""
"""Prevent the 'key' field being edited once the setting is created."""
if obj:
return ['key']
else:
@ -21,14 +21,12 @@ class SettingsAdmin(ImportExportModelAdmin):
class UserSettingsAdmin(ImportExportModelAdmin):
"""Admin settings for InvenTreeUserSetting."""
list_display = ('key', 'value', 'user', )
def get_readonly_fields(self, request, obj=None): # pragma: no cover
"""
Prevent the 'key' field being edited once the setting is created
"""
"""Prevent the 'key' field being edited once the setting is created."""
if obj:
return ['key']
else:
@ -36,16 +34,19 @@ class UserSettingsAdmin(ImportExportModelAdmin):
class WebhookAdmin(ImportExportModelAdmin):
"""Admin settings for Webhook."""
list_display = ('endpoint_id', 'name', 'active', 'user')
class NotificationEntryAdmin(admin.ModelAdmin):
"""Admin settings for NotificationEntry."""
list_display = ('key', 'uid', 'updated', )
class NotificationMessageAdmin(admin.ModelAdmin):
"""Admin settings for NotificationMessage."""
list_display = ('age_human', 'user', 'category', 'name', 'read', 'target_object', 'source_object', )

View File

@ -1,6 +1,4 @@
"""
Provides a JSON API for common components.
"""
"""Provides a JSON API for common components."""
import json
@ -24,25 +22,23 @@ from plugin.serializers import NotificationUserSettingSerializer
class CsrfExemptMixin(object):
"""
Exempts the view from CSRF requirements.
"""
"""Exempts the view from CSRF requirements."""
@method_decorator(csrf_exempt)
def dispatch(self, *args, **kwargs):
return super(CsrfExemptMixin, self).dispatch(*args, **kwargs)
"""Overwrites dispatch to be extempt from csrf checks."""
return super().dispatch(*args, **kwargs)
class WebhookView(CsrfExemptMixin, APIView):
"""
Endpoint for receiving webhooks.
"""
"""Endpoint for receiving webhooks."""
authentication_classes = []
permission_classes = []
model_class = common.models.WebhookEndpoint
run_async = False
def post(self, request, endpoint, *args, **kwargs):
"""Process incomming webhook."""
# get webhook definition
self._get_webhook(endpoint, request, *args, **kwargs)
@ -101,6 +97,10 @@ class WebhookView(CsrfExemptMixin, APIView):
class SettingsList(generics.ListAPIView):
"""Generic ListView for settings.
This is inheritted by all list views for settings.
"""
filter_backends = [
DjangoFilterBackend,
@ -120,24 +120,17 @@ class SettingsList(generics.ListAPIView):
class GlobalSettingsList(SettingsList):
"""
API endpoint for accessing a list of global settings objects
"""
"""API endpoint for accessing a list of global settings objects."""
queryset = common.models.InvenTreeSetting.objects.all()
serializer_class = common.serializers.GlobalSettingsSerializer
class GlobalSettingsPermissions(permissions.BasePermission):
"""
Special permission class to determine if the user is "staff"
"""
"""Special permission class to determine if the user is "staff"."""
def has_permission(self, request, view):
"""
Check that the requesting user is 'admin'
"""
"""Check that the requesting user is 'admin'."""
try:
user = request.user
@ -152,8 +145,7 @@ class GlobalSettingsPermissions(permissions.BasePermission):
class GlobalSettingsDetail(generics.RetrieveUpdateAPIView):
"""
Detail view for an individual "global setting" object.
"""Detail view for an individual "global setting" object.
- User must have 'staff' status to view / edit
"""
@ -163,10 +155,7 @@ class GlobalSettingsDetail(generics.RetrieveUpdateAPIView):
serializer_class = common.serializers.GlobalSettingsSerializer
def get_object(self):
"""
Attempt to find a global setting object with the provided key.
"""
"""Attempt to find a global setting object with the provided key."""
key = self.kwargs['key']
if key not in common.models.InvenTreeSetting.SETTINGS.keys():
@ -181,18 +170,13 @@ class GlobalSettingsDetail(generics.RetrieveUpdateAPIView):
class UserSettingsList(SettingsList):
"""
API endpoint for accessing a list of user settings objects
"""
"""API endpoint for accessing a list of user settings objects."""
queryset = common.models.InvenTreeUserSetting.objects.all()
serializer_class = common.serializers.UserSettingsSerializer
def filter_queryset(self, queryset):
"""
Only list settings which apply to the current user
"""
"""Only list settings which apply to the current user."""
try:
user = self.request.user
except AttributeError: # pragma: no cover
@ -206,12 +190,10 @@ class UserSettingsList(SettingsList):
class UserSettingsPermissions(permissions.BasePermission):
"""
Special permission class to determine if the user can view / edit a particular setting
"""
"""Special permission class to determine if the user can view / edit a particular setting."""
def has_object_permission(self, request, view, obj):
"""Check if the user that requested is also the object owner."""
try:
user = request.user
except AttributeError: # pragma: no cover
@ -221,8 +203,7 @@ class UserSettingsPermissions(permissions.BasePermission):
class UserSettingsDetail(generics.RetrieveUpdateAPIView):
"""
Detail view for an individual "user setting" object
"""Detail view for an individual "user setting" object.
- User can only view / edit settings their own settings objects
"""
@ -232,10 +213,7 @@ class UserSettingsDetail(generics.RetrieveUpdateAPIView):
serializer_class = common.serializers.UserSettingsSerializer
def get_object(self):
"""
Attempt to find a user setting object with the provided key.
"""
"""Attempt to find a user setting object with the provided key."""
key = self.kwargs['key']
if key not in common.models.InvenTreeUserSetting.SETTINGS.keys():
@ -249,18 +227,13 @@ class UserSettingsDetail(generics.RetrieveUpdateAPIView):
class NotificationUserSettingsList(SettingsList):
"""
API endpoint for accessing a list of notification user settings objects
"""
"""API endpoint for accessing a list of notification user settings objects."""
queryset = NotificationUserSetting.objects.all()
serializer_class = NotificationUserSettingSerializer
def filter_queryset(self, queryset):
"""
Only list settings which apply to the current user
"""
"""Only list settings which apply to the current user."""
try:
user = self.request.user
except AttributeError:
@ -272,8 +245,7 @@ class NotificationUserSettingsList(SettingsList):
class NotificationUserSettingsDetail(generics.RetrieveUpdateAPIView):
"""
Detail view for an individual "notification user setting" object
"""Detail view for an individual "notification user setting" object.
- User can only view / edit settings their own settings objects
"""
@ -287,6 +259,8 @@ class NotificationUserSettingsDetail(generics.RetrieveUpdateAPIView):
class NotificationList(generics.ListAPIView):
"""List view for all notifications of the current user."""
queryset = common.models.NotificationMessage.objects.all()
serializer_class = common.serializers.NotificationMessageSerializer
@ -313,10 +287,7 @@ class NotificationList(generics.ListAPIView):
]
def filter_queryset(self, queryset):
"""
Only list notifications which apply to the current user
"""
"""Only list notifications which apply to the current user."""
try:
user = self.request.user
except AttributeError:
@ -328,8 +299,7 @@ class NotificationList(generics.ListAPIView):
class NotificationDetail(generics.RetrieveUpdateDestroyAPIView):
"""
Detail view for an individual notification object
"""Detail view for an individual notification object.
- User can only view / delete their own notification objects
"""
@ -342,9 +312,7 @@ class NotificationDetail(generics.RetrieveUpdateDestroyAPIView):
class NotificationReadEdit(generics.CreateAPIView):
"""
general API endpoint to manipulate read state of a notification
"""
"""General API endpoint to manipulate read state of a notification."""
queryset = common.models.NotificationMessage.objects.all()
serializer_class = common.serializers.NotificationReadSerializer
@ -354,12 +322,14 @@ class NotificationReadEdit(generics.CreateAPIView):
]
def get_serializer_context(self):
"""Add instance to context so it can be accessed in the serializer."""
context = super().get_serializer_context()
if self.request:
context['instance'] = self.get_object()
return context
def perform_create(self, serializer):
"""Set the `read` status to the target value."""
message = self.get_object()
try:
message.read = self.target
@ -369,23 +339,17 @@ class NotificationReadEdit(generics.CreateAPIView):
class NotificationRead(NotificationReadEdit):
"""
API endpoint to mark a notification as read.
"""
"""API endpoint to mark a notification as read."""
target = True
class NotificationUnread(NotificationReadEdit):
"""
API endpoint to mark a notification as unread.
"""
"""API endpoint to mark a notification as unread."""
target = False
class NotificationReadAll(generics.RetrieveAPIView):
"""
API endpoint to mark all notifications as read.
"""
"""API endpoint to mark all notifications as read."""
queryset = common.models.NotificationMessage.objects.all()
@ -394,6 +358,7 @@ class NotificationReadAll(generics.RetrieveAPIView):
]
def get(self, request, *args, **kwargs):
"""Set all messages for the current user as read."""
try:
self.queryset.filter(user=request.user, read=False).update(read=True)
return Response({'status': 'ok'})

View File

@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""App config for common app."""
import logging
@ -8,17 +8,19 @@ logger = logging.getLogger('inventree')
class CommonConfig(AppConfig):
"""AppConfig for common app.
Clears system wide flags on ready.
"""
name = 'common'
def ready(self):
"""Initialize restart flag clearance on startup."""
self.clear_restart_flag()
def clear_restart_flag(self):
"""
Clear the SERVER_RESTART_REQUIRED setting
"""
"""Clear the SERVER_RESTART_REQUIRED setting."""
try:
import common.models

View File

@ -1,6 +1,4 @@
"""
Files management tools.
"""
"""Files management tools."""
import os
@ -12,7 +10,7 @@ from rapidfuzz import fuzz
class FileManager:
""" Class for managing an uploaded file """
"""Class for managing an uploaded file."""
name = ''
@ -32,8 +30,7 @@ class FileManager:
HEADERS = []
def __init__(self, file, name=None):
""" Initialize the FileManager class with a user-uploaded file object """
"""Initialize the FileManager class with a user-uploaded file object."""
# Set name
if name:
self.name = name
@ -46,8 +43,7 @@ class FileManager:
@classmethod
def validate(cls, file):
""" Validate file extension and data """
"""Validate file extension and data."""
cleaned_data = None
ext = os.path.splitext(file.name)[-1].lower().replace('.', '')
@ -79,21 +75,15 @@ class FileManager:
return cleaned_data
def process(self, file):
""" Process file """
"""Process file."""
self.data = self.__class__.validate(file)
def update_headers(self):
""" Update headers """
"""Update headers."""
self.HEADERS = self.REQUIRED_HEADERS + self.ITEM_MATCH_HEADERS + self.OPTIONAL_MATCH_HEADERS + self.OPTIONAL_HEADERS
def setup(self):
"""
Setup headers
should be overriden in usage to set the Different Headers
"""
"""Setup headers should be overriden in usage to set the Different Headers."""
if not self.name:
return
@ -101,14 +91,15 @@ class FileManager:
self.update_headers()
def guess_header(self, header, threshold=80):
"""
Try to match a header (from the file) to a list of known headers
"""Try to match a header (from the file) to a list of known headers.
Args:
header - Header name to look for
threshold - Match threshold for fuzzy search
"""
header (Any): Header name to look for
threshold (int, optional): Match threshold for fuzzy search. Defaults to 80.
Returns:
Any: Matched headers
"""
# Replace null values with empty string
if header is None:
header = ''
@ -143,7 +134,7 @@ class FileManager:
return None
def columns(self):
""" Return a list of headers for the thingy """
"""Return a list of headers for the thingy."""
headers = []
for header in self.data.headers:
@ -170,21 +161,21 @@ class FileManager:
return headers
def col_count(self):
"""Return the number of columns in the file."""
if self.data is None:
return 0
return len(self.data.headers)
def row_count(self):
""" Return the number of rows in the file. """
"""Return the number of rows in the file."""
if self.data is None:
return 0
return len(self.data)
def rows(self):
""" Return a list of all rows """
"""Return a list of all rows."""
rows = []
for i in range(self.row_count()):
@ -221,15 +212,14 @@ class FileManager:
return rows
def get_row_data(self, index):
""" Retrieve row data at a particular index """
"""Retrieve row data at a particular index."""
if self.data is None or index >= len(self.data):
return None
return self.data[index]
def get_row_dict(self, index):
""" Retrieve a dict object representing the data row at a particular offset """
"""Retrieve a dict object representing the data row at a particular offset."""
if self.data is None or index >= len(self.data):
return None

View File

@ -1,6 +1,4 @@
"""
Django forms for interacting with common objects
"""
"""Django forms for interacting with common objects."""
from django import forms
from django.utils.translation import gettext as _
@ -12,11 +10,11 @@ from .models import InvenTreeSetting
class SettingEditForm(HelperForm):
"""
Form for creating / editing a settings object
"""
"""Form for creating / editing a settings object."""
class Meta:
"""Metaclassoptions for SettingEditForm."""
model = InvenTreeSetting
fields = [
@ -25,7 +23,7 @@ class SettingEditForm(HelperForm):
class UploadFileForm(forms.Form):
""" Step 1 of FileManagementFormView """
"""Step 1 of FileManagementFormView."""
file = forms.FileField(
label=_('File'),
@ -33,8 +31,7 @@ class UploadFileForm(forms.Form):
)
def __init__(self, *args, **kwargs):
""" Update label and help_text """
"""Update label and help_text."""
# Get file name
name = None
if 'name' in kwargs:
@ -48,11 +45,10 @@ class UploadFileForm(forms.Form):
self.fields['file'].help_text = _(f'Select {name} file to upload')
def clean_file(self):
"""
Run tabular file validation.
"""Run tabular file validation.
If anything is wrong with the file, it will raise ValidationError
"""
file = self.cleaned_data['file']
# Validate file using FileManager class - will perform initial data validation
@ -63,10 +59,10 @@ class UploadFileForm(forms.Form):
class MatchFieldForm(forms.Form):
""" Step 2 of FileManagementFormView """
"""Step 2 of FileManagementFormView."""
def __init__(self, *args, **kwargs):
"""Setup filemanager and check columsn."""
# Get FileManager
file_manager = None
if 'file_manager' in kwargs:
@ -96,10 +92,10 @@ class MatchFieldForm(forms.Form):
class MatchItemForm(forms.Form):
""" Step 3 of FileManagementFormView """
"""Step 3 of FileManagementFormView."""
def __init__(self, *args, **kwargs):
"""Setup filemanager and create fields."""
# Get FileManager
file_manager = None
if 'file_manager' in kwargs:
@ -194,6 +190,5 @@ class MatchItemForm(forms.Form):
)
def get_special_field(self, col_guess, row, file_manager):
""" Function to be overriden in inherited forms to add specific form settings """
"""Function to be overriden in inherited forms to add specific form settings."""
return None

View File

@ -1,5 +1,5 @@
"""
Common database model definitions.
"""Common database model definitions.
These models are 'generic' and do not fit a particular business logic object.
"""
@ -42,9 +42,10 @@ logger = logging.getLogger('inventree')
class EmptyURLValidator(URLValidator):
"""Validator for filed with url - that can be empty."""
def __call__(self, value):
"""Make sure empty values pass."""
value = str(value).strip()
if len(value) == 0:
@ -55,21 +56,17 @@ class EmptyURLValidator(URLValidator):
class BaseInvenTreeSetting(models.Model):
"""
An base InvenTreeSetting object is a key:value pair used for storing
single values (e.g. one-off settings values).
"""
"""An base InvenTreeSetting object is a key:value pair used for storing single values (e.g. one-off settings values)."""
SETTINGS = {}
class Meta:
"""Meta options for BaseInvenTreeSetting -> abstract stops creation of database entry."""
abstract = True
def save(self, *args, **kwargs):
"""
Enforce validation and clean before saving
"""
"""Enforce validation and clean before saving."""
self.key = str(self.key).upper()
self.clean(**kwargs)
@ -79,14 +76,12 @@ class BaseInvenTreeSetting(models.Model):
@classmethod
def allValues(cls, user=None, exclude_hidden=False):
"""
Return a dict of "all" defined global settings.
"""Return a dict of "all" defined global settings.
This performs a single database lookup,
and then any settings which are not *in* the database
are assigned their default values
"""
results = cls.objects.all()
# Optionally filter by user
@ -131,28 +126,23 @@ class BaseInvenTreeSetting(models.Model):
return settings
def get_kwargs(self):
"""
Construct kwargs for doing class-based settings lookup,
depending on *which* class we are.
"""Construct kwargs for doing class-based settings lookup, depending on *which* class we are.
This is necessary to abtract the settings object
from the implementing class (e.g plugins)
Subclasses should override this function to ensure the kwargs are correctly set.
"""
return {}
@classmethod
def get_setting_definition(cls, key, **kwargs):
"""
Return the 'definition' of a particular settings value, as a dict object.
"""Return the 'definition' of a particular settings value, as a dict object.
- The 'settings' dict can be passed as a kwarg
- If not passed, look for cls.SETTINGS
- Returns an empty dict if the key is not found
"""
settings = kwargs.get('settings', cls.SETTINGS)
key = str(key).strip().upper()
@ -164,69 +154,56 @@ class BaseInvenTreeSetting(models.Model):
@classmethod
def get_setting_name(cls, key, **kwargs):
"""
Return the name of a particular setting.
"""Return the name of a particular setting.
If it does not exist, return an empty string.
"""
setting = cls.get_setting_definition(key, **kwargs)
return setting.get('name', '')
@classmethod
def get_setting_description(cls, key, **kwargs):
"""
Return the description for a particular setting.
"""Return the description for a particular setting.
If it does not exist, return an empty string.
"""
setting = cls.get_setting_definition(key, **kwargs)
return setting.get('description', '')
@classmethod
def get_setting_units(cls, key, **kwargs):
"""
Return the units for a particular setting.
"""Return the units for a particular setting.
If it does not exist, return an empty string.
"""
setting = cls.get_setting_definition(key, **kwargs)
return setting.get('units', '')
@classmethod
def get_setting_validator(cls, key, **kwargs):
"""
Return the validator for a particular setting.
"""Return the validator for a particular setting.
If it does not exist, return None
"""
setting = cls.get_setting_definition(key, **kwargs)
return setting.get('validator', None)
@classmethod
def get_setting_default(cls, key, **kwargs):
"""
Return the default value for a particular setting.
"""Return the default value for a particular setting.
If it does not exist, return an empty string
"""
setting = cls.get_setting_definition(key, **kwargs)
return setting.get('default', '')
@classmethod
def get_setting_choices(cls, key, **kwargs):
"""
Return the validator choices available for a particular setting.
"""
"""Return the validator choices available for a particular setting."""
setting = cls.get_setting_definition(key, **kwargs)
choices = setting.get('choices', None)
@ -239,13 +216,11 @@ class BaseInvenTreeSetting(models.Model):
@classmethod
def get_setting_object(cls, key, **kwargs):
"""
Return an InvenTreeSetting object matching the given key.
"""Return an InvenTreeSetting object matching the given key.
- Key is case-insensitive
- Returns None if no match is made
"""
key = str(key).strip().upper()
settings = cls.objects.all()
@ -311,11 +286,10 @@ class BaseInvenTreeSetting(models.Model):
@classmethod
def get_setting(cls, key, backup_value=None, **kwargs):
"""
Get the value of a particular setting.
"""Get the value of a particular setting.
If it does not exist, return the backup value (default = None)
"""
# If no backup value is specified, atttempt to retrieve a "default" value
if backup_value is None:
backup_value = cls.get_setting_default(key, **kwargs)
@ -343,9 +317,7 @@ class BaseInvenTreeSetting(models.Model):
@classmethod
def set_setting(cls, key, value, change_user, create=True, **kwargs):
"""
Set the value of a particular setting.
If it does not exist, option to create it.
"""Set the value of a particular setting. If it does not exist, option to create it.
Args:
key: settings key
@ -353,7 +325,6 @@ class BaseInvenTreeSetting(models.Model):
change_user: User object (must be staff member to update a core setting)
create: If True, create a new setting if the specified key does not exist.
"""
if change_user is not None and not change_user.is_staff:
return
@ -397,26 +368,26 @@ class BaseInvenTreeSetting(models.Model):
@property
def name(self):
"""Return name for setting."""
return self.__class__.get_setting_name(self.key, **self.get_kwargs())
@property
def default_value(self):
"""Return default_value for setting."""
return self.__class__.get_setting_default(self.key, **self.get_kwargs())
@property
def description(self):
"""Return description for setting."""
return self.__class__.get_setting_description(self.key, **self.get_kwargs())
@property
def units(self):
"""Return units for setting."""
return self.__class__.get_setting_units(self.key, **self.get_kwargs())
def clean(self, **kwargs):
"""
If a validator (or multiple validators) are defined for a particular setting key,
run them against the 'value' field.
"""
"""If a validator (or multiple validators) are defined for a particular setting key, run them against the 'value' field."""
super().clean()
# Encode as native values
@ -437,10 +408,7 @@ class BaseInvenTreeSetting(models.Model):
raise ValidationError(_("Chosen value is not a valid option"))
def run_validator(self, validator):
"""
Run a validator against the 'value' field for this InvenTreeSetting object.
"""
"""Run a validator against the 'value' field for this InvenTreeSetting object."""
if validator is None:
return
@ -485,15 +453,11 @@ class BaseInvenTreeSetting(models.Model):
validator(value)
def validate_unique(self, exclude=None, **kwargs):
"""
Ensure that the key:value pair is unique.
In addition to the base validators, this ensures that the 'key'
is unique, using a case-insensitive comparison.
"""Ensure that the key:value pair is unique. In addition to the base validators, this ensures that the 'key' is unique, using a case-insensitive comparison.
Note that sub-classes (UserSetting, PluginSetting) use other filters
to determine if the setting is 'unique' or not
"""
super().validate_unique(exclude)
filters = {
@ -520,17 +484,11 @@ class BaseInvenTreeSetting(models.Model):
pass
def choices(self):
"""
Return the available choices for this setting (or None if no choices are defined)
"""
"""Return the available choices for this setting (or None if no choices are defined)."""
return self.__class__.get_setting_choices(self.key, **self.get_kwargs())
def valid_options(self):
"""
Return a list of valid options for this setting
"""
"""Return a list of valid options for this setting."""
choices = self.choices()
if not choices:
@ -539,21 +497,17 @@ class BaseInvenTreeSetting(models.Model):
return [opt[0] for opt in choices]
def is_choice(self):
"""
Check if this setting is a "choice" field
"""
"""Check if this setting is a "choice" field."""
return self.__class__.get_setting_choices(self.key, **self.get_kwargs()) is not None
def as_choice(self):
"""
Render this setting as the "display" value of a choice field,
e.g. if the choices are:
"""Render this setting as the "display" value of a choice field.
E.g. if the choices are:
[('A4', 'A4 paper'), ('A3', 'A3 paper')],
and the value is 'A4',
then display 'A4 paper'
"""
choices = self.get_setting_choices(self.key, **self.get_kwargs())
if not choices:
@ -566,30 +520,23 @@ class BaseInvenTreeSetting(models.Model):
return self.value
def is_model(self):
"""
Check if this setting references a model instance in the database
"""
"""Check if this setting references a model instance in the database."""
return self.model_name() is not None
def model_name(self):
"""
Return the model name associated with this setting
"""
"""Return the model name associated with this setting."""
setting = self.get_setting_definition(self.key, **self.get_kwargs())
return setting.get('model', None)
def model_class(self):
"""
Return the model class associated with this setting, if (and only if):
"""Return the model class associated with this setting.
If (and only if):
- It has a defined 'model' parameter
- The 'model' parameter is of the form app.model
- The 'model' parameter has matches a known app model
"""
model_name = self.model_name()
if not model_name:
@ -617,11 +564,7 @@ class BaseInvenTreeSetting(models.Model):
return model
def api_url(self):
"""
Return the API url associated with the linked model,
if provided, and valid!
"""
"""Return the API url associated with the linked model, if provided, and valid!"""
model_class = self.model_class()
if model_class:
@ -634,28 +577,20 @@ class BaseInvenTreeSetting(models.Model):
return None
def is_bool(self):
"""
Check if this setting is required to be a boolean value
"""
"""Check if this setting is required to be a boolean value."""
validator = self.__class__.get_setting_validator(self.key, **self.get_kwargs())
return self.__class__.validator_is_bool(validator)
def as_bool(self):
"""
Return the value of this setting converted to a boolean value.
"""Return the value of this setting converted to a boolean value.
Warning: Only use on values where is_bool evaluates to true!
"""
return InvenTree.helpers.str2bool(self.value)
def setting_type(self):
"""
Return the field type identifier for this setting object
"""
"""Return the field type identifier for this setting object."""
if self.is_bool():
return 'boolean'
@ -670,7 +605,7 @@ class BaseInvenTreeSetting(models.Model):
@classmethod
def validator_is_bool(cls, validator):
"""Return if validator is for bool."""
if validator == bool:
return True
@ -682,17 +617,14 @@ class BaseInvenTreeSetting(models.Model):
return False
def is_int(self,):
"""
Check if the setting is required to be an integer value:
"""
"""Check if the setting is required to be an integer value."""
validator = self.__class__.get_setting_validator(self.key, **self.get_kwargs())
return self.__class__.validator_is_int(validator)
@classmethod
def validator_is_int(cls, validator):
"""Return if validator is for int."""
if validator == int:
return True
@ -704,12 +636,10 @@ class BaseInvenTreeSetting(models.Model):
return False
def as_int(self):
"""
Return the value of this setting converted to a boolean value.
"""Return the value of this setting converted to a boolean value.
If an error occurs, return the default value
"""
try:
value = int(self.value)
except (ValueError, TypeError):
@ -719,41 +649,34 @@ class BaseInvenTreeSetting(models.Model):
@classmethod
def is_protected(cls, key, **kwargs):
"""
Check if the setting value is protected
"""
"""Check if the setting value is protected."""
setting = cls.get_setting_definition(key, **kwargs)
return setting.get('protected', False)
@property
def protected(self):
"""Returns if setting is protected from rendering."""
return self.__class__.is_protected(self.key, **self.get_kwargs())
def settings_group_options():
"""
Build up group tuple for settings based on your choices
"""
"""Build up group tuple for settings based on your choices."""
return [('', _('No group')), *[(str(a.id), str(a)) for a in Group.objects.all()]]
class InvenTreeSetting(BaseInvenTreeSetting):
"""
An InvenTreeSetting object is a key:value pair used for storing
single values (e.g. one-off settings values).
"""An InvenTreeSetting object is a key:value pair used for storing single values (e.g. one-off settings values).
The class provides a way of retrieving the value for a particular key,
even if that key does not exist.
"""
def save(self, *args, **kwargs):
"""
When saving a global setting, check to see if it requires a server restart.
"""When saving a global setting, check to see if it requires a server restart.
If so, set the "SERVER_RESTART_REQUIRED" setting to True
"""
super().save()
if self.requires_restart():
@ -1235,6 +1158,8 @@ class InvenTreeSetting(BaseInvenTreeSetting):
}
class Meta:
"""Meta options for InvenTreeSetting."""
verbose_name = "InvenTree Setting"
verbose_name_plural = "InvenTree Settings"
@ -1246,18 +1171,11 @@ class InvenTreeSetting(BaseInvenTreeSetting):
)
def to_native_value(self):
"""
Return the "pythonic" value,
e.g. convert "True" to True, and "1" to 1
"""
"""Return the "pythonic" value, e.g. convert "True" to True, and "1" to 1."""
return self.__class__.get_setting(self.key)
def requires_restart(self):
"""
Return True if this setting requires a server restart after changing
"""
"""Return True if this setting requires a server restart after changing."""
options = InvenTreeSetting.SETTINGS.get(self.key, None)
if options:
@ -1267,9 +1185,7 @@ class InvenTreeSetting(BaseInvenTreeSetting):
class InvenTreeUserSetting(BaseInvenTreeSetting):
"""
An InvenTreeSetting object with a usercontext
"""
"""An InvenTreeSetting object with a usercontext."""
SETTINGS = {
'HOMEPAGE_PART_STARRED': {
@ -1561,6 +1477,8 @@ class InvenTreeUserSetting(BaseInvenTreeSetting):
}
class Meta:
"""Meta options for InvenTreeUserSetting."""
verbose_name = "InvenTree User Setting"
verbose_name_plural = "InvenTree User Settings"
constraints = [
@ -1584,36 +1502,30 @@ class InvenTreeUserSetting(BaseInvenTreeSetting):
@classmethod
def get_setting_object(cls, key, user=None):
"""Return setting object for provided user."""
return super().get_setting_object(key, user=user)
def validate_unique(self, exclude=None, **kwargs):
"""Return if the setting (including key) is unique."""
return super().validate_unique(exclude=exclude, user=self.user)
def to_native_value(self):
"""
Return the "pythonic" value,
e.g. convert "True" to True, and "1" to 1
"""
"""Return the "pythonic" value, e.g. convert "True" to True, and "1" to 1."""
return self.__class__.get_setting(self.key, user=self.user)
def get_kwargs(self):
"""
Explicit kwargs required to uniquely identify a particular setting object,
in addition to the 'key' parameter
"""
"""Explicit kwargs required to uniquely identify a particular setting object, in addition to the 'key' parameter."""
return {
'user': self.user,
}
class PriceBreak(models.Model):
"""
Represents a PriceBreak model
"""
"""Represents a PriceBreak model."""
class Meta:
"""Define this as abstract -> no DB entry is created."""
abstract = True
quantity = InvenTree.fields.RoundingDecimalField(
@ -1634,13 +1546,11 @@ class PriceBreak(models.Model):
)
def convert_to(self, currency_code):
"""
Convert the unit-price at this price break to the specified currency code.
"""Convert the unit-price at this price break to the specified currency code.
Args:
currency_code - The currency code to convert to (e.g "USD" or "AUD")
currency_code: The currency code to convert to (e.g "USD" or "AUD")
"""
try:
converted = convert_money(self.price, currency_code)
except MissingRate:
@ -1651,7 +1561,7 @@ class PriceBreak(models.Model):
def get_price(instance, quantity, moq=True, multiples=True, currency=None, break_name: str = 'price_breaks'):
""" Calculate the price based on quantity price breaks.
"""Calculate the price based on quantity price breaks.
- Don't forget to add in flat-fee cost (base_cost field)
- If MOQ (minimum order quantity) is required, bump quantity
@ -1721,7 +1631,7 @@ def get_price(instance, quantity, moq=True, multiples=True, currency=None, break
class ColorTheme(models.Model):
""" Color Theme Setting """
"""Color Theme Setting."""
name = models.CharField(max_length=20,
default='',
blank=True)
@ -1731,7 +1641,7 @@ class ColorTheme(models.Model):
@classmethod
def get_color_themes_choices(cls):
""" Get all color themes from static folder """
"""Get all color themes from static folder."""
if settings.TESTING and not os.path.exists(settings.STATIC_COLOR_THEMES_DIR):
logger.error('Theme directory does not exsist')
return []
@ -1750,7 +1660,7 @@ class ColorTheme(models.Model):
@classmethod
def is_valid_choice(cls, user_color_theme):
""" Check if color theme is valid choice """
"""Check if color theme is valid choice."""
try:
user_color_theme_name = user_color_theme.name
except AttributeError:
@ -1764,13 +1674,15 @@ class ColorTheme(models.Model):
class VerificationMethod:
"""Class to hold method references."""
NONE = 0
TOKEN = 1
HMAC = 2
class WebhookEndpoint(models.Model):
""" Defines a Webhook entdpoint
"""Defines a Webhook entdpoint.
Attributes:
endpoint_id: Path to the webhook,
@ -1835,9 +1747,19 @@ class WebhookEndpoint(models.Model):
# To be overridden
def init(self, request, *args, **kwargs):
"""Set verification method.
Args:
request: Original request object.
"""
self.verify = self.VERIFICATION_METHOD
def process_webhook(self):
"""Process the webhook incomming.
This does not deal with the data itself - that happens in process_payload.
Do not touch or pickle data here - it was not verified to be safe.
"""
if self.token:
self.verify = VerificationMethod.TOKEN
if self.secret:
@ -1845,6 +1767,10 @@ class WebhookEndpoint(models.Model):
return True
def validate_token(self, payload, headers, request):
"""Make sure that the provided token (if any) confirms to the setting for this endpoint.
This can be overridden to create your own token validation method.
"""
token = headers.get(self.TOKEN_NAME, "")
# no token
@ -1866,7 +1792,14 @@ class WebhookEndpoint(models.Model):
return True
def save_data(self, payload, headers=None, request=None):
def save_data(self, payload=None, headers=None, request=None):
"""Safes payload to database.
Args:
payload (optional): Payload that was send along. Defaults to None.
headers (optional): Headers that were send along. Defaults to None.
request (optional): Original request object. Defaults to None.
"""
return WebhookMessage.objects.create(
host=request.get_host(),
header=json.dumps({key: val for key, val in headers.items()}),
@ -1874,15 +1807,35 @@ class WebhookEndpoint(models.Model):
endpoint=self,
)
def process_payload(self, message, payload=None, headers=None):
def process_payload(self, message, payload=None, headers=None) -> bool:
"""Process a payload.
Args:
message: DB entry for this message mm
payload (optional): Payload that was send along. Defaults to None.
headers (optional): Headers that were included. Defaults to None.
Returns:
bool: Was the message processed
"""
return True
def get_return(self, payload, headers=None, request=None):
def get_return(self, payload=None, headers=None, request=None) -> str:
"""Returns the message that should be returned to the endpoint caller.
Args:
payload (optional): Payload that was send along. Defaults to None.
headers (optional): Headers that were send along. Defaults to None.
request (optional): Original request object. Defaults to None.
Returns:
str: Message for caller.
"""
return self.MESSAGE_OK
class WebhookMessage(models.Model):
""" Defines a webhook message
"""Defines a webhook message.
Attributes:
message_id: Unique identifier for this message,
@ -1939,8 +1892,7 @@ class WebhookMessage(models.Model):
class NotificationEntry(models.Model):
"""
A NotificationEntry records the last time a particular notifaction was sent out.
"""A NotificationEntry records the last time a particular notifaction was sent out.
It is recorded to ensure that notifications are not sent out "too often" to users.
@ -1951,6 +1903,8 @@ class NotificationEntry(models.Model):
"""
class Meta:
"""Meta options for NotificationEntry."""
unique_together = [
('key', 'uid'),
]
@ -1970,10 +1924,7 @@ class NotificationEntry(models.Model):
@classmethod
def check_recent(cls, key: str, uid: int, delta: timedelta):
"""
Test if a particular notification has been sent in the specified time period
"""
"""Test if a particular notification has been sent in the specified time period."""
since = datetime.now().date() - delta
entries = cls.objects.filter(
@ -1986,10 +1937,7 @@ class NotificationEntry(models.Model):
@classmethod
def notify(cls, key: str, uid: int):
"""
Notify the database that a particular notification has been sent out
"""
"""Notify the database that a particular notification has been sent out."""
entry, created = cls.objects.get_or_create(
key=key,
uid=uid
@ -1999,8 +1947,7 @@ class NotificationEntry(models.Model):
class NotificationMessage(models.Model):
"""
A NotificationEntry records the last time a particular notifaction was sent out.
"""A NotificationEntry records the last time a particular notifaction was sent out.
It is recorded to ensure that notifications are not sent out "too often" to users.
@ -2073,13 +2020,14 @@ class NotificationMessage(models.Model):
@staticmethod
def get_api_url():
"""Return API endpoint."""
return reverse('api-notifications-list')
def age(self):
"""age of the message in seconds"""
"""Age of the message in seconds."""
delta = now() - self.creation
return delta.seconds
def age_human(self):
"""humanized age"""
"""Humanized age."""
return naturaltime(self.creation)

View File

@ -1,3 +1,5 @@
"""Base classes and functions for notifications."""
import logging
from datetime import timedelta
@ -12,9 +14,7 @@ logger = logging.getLogger('inventree')
# region methods
class NotificationMethod:
"""
Base class for notification methods
"""
"""Base class for notification methods."""
METHOD_NAME = ''
METHOD_ICON = None
@ -24,6 +24,13 @@ class NotificationMethod:
USER_SETTING = None
def __init__(self, obj, category, targets, context) -> None:
"""Check that the method is read.
This checks that:
- All needed functions are implemented
- The method is not disabled via plugin
- All needed contaxt values were provided
"""
# Check if a sending fnc is defined
if (not hasattr(self, 'send')) and (not hasattr(self, 'send_bulk')):
raise NotImplementedError('A NotificationMethod must either define a `send` or a `send_bulk` method')
@ -47,6 +54,7 @@ class NotificationMethod:
self.targets = self.get_targets()
def check_context(self, context):
"""Check that all values defined in the methods CONTEXT were provided in the current context."""
def check(ref, obj):
# the obj is not accesible so we are on the end
if not isinstance(obj, (list, dict, tuple, )):
@ -82,21 +90,33 @@ class NotificationMethod:
return context
def get_targets(self):
"""Returns targets for notifications.
Processes `self.targets` to extract all users that should be notified.
"""
raise NotImplementedError('The `get_targets` method must be implemented!')
def setup(self):
"""Set up context before notifications are send.
This is intended to be overridden in method implementations.
"""
return True
def cleanup(self):
"""Clean up context after all notifications were send.
This is intended to be overridden in method implementations.
"""
return True
# region plugins
def get_plugin(self):
"""Returns plugin class"""
"""Returns plugin class."""
return False
def global_setting_disable(self):
"""Check if the method is defined in a plugin and has a global setting"""
"""Check if the method is defined in a plugin and has a global setting."""
# Check if plugin has a setting
if not self.GLOBAL_SETTING:
return False
@ -115,29 +135,45 @@ class NotificationMethod:
return False
def usersetting(self, target):
"""
Returns setting for this method for a given user
"""
"""Returns setting for this method for a given user."""
return NotificationUserSetting.get_setting(f'NOTIFICATION_METHOD_{self.METHOD_NAME.upper()}', user=target, method=self.METHOD_NAME)
# endregion
class SingleNotificationMethod(NotificationMethod):
"""NotificationMethod that sends notifications one by one."""
def send(self, target):
"""This function must be overriden."""
raise NotImplementedError('The `send` method must be overriden!')
class BulkNotificationMethod(NotificationMethod):
"""NotificationMethod that sends all notifications in bulk."""
def send_bulk(self):
"""This function must be overriden."""
raise NotImplementedError('The `send` method must be overriden!')
# endregion
class MethodStorageClass:
"""Class that works as registry for all available notification methods in InvenTree.
Is initialized on startup as one instance named `storage` in this file.
"""
liste = None
user_settings = {}
def collect(self, selected_classes=None):
"""Collect all classes in the enviroment that are notification methods.
Can be filtered to only include provided classes for testing.
Args:
selected_classes (class, optional): References to the classes that should be registered. Defaults to None.
"""
logger.info('collecting notification methods')
current_method = inheritors(NotificationMethod) - IGNORED_NOTIFICATION_CLS
@ -155,7 +191,17 @@ class MethodStorageClass:
storage.liste = list(filtered_list.values())
logger.info(f'found {len(storage.liste)} notification methods')
def get_usersettings(self, user):
def get_usersettings(self, user) -> list:
"""Returns all user settings for a specific user.
This is needed to show them in the settings UI.
Args:
user (User): User that should be used as a filter.
Returns:
list: All applicablae notification settings.
"""
methods = []
for item in storage.liste:
if item.USER_SETTING:
@ -186,12 +232,16 @@ storage = MethodStorageClass()
class UIMessageNotification(SingleNotificationMethod):
"""Delivery method for sending specific users notifications in the notification pain in the web UI."""
METHOD_NAME = 'ui_message'
def get_targets(self):
"""Just return the targets - no tricks here."""
return self.targets
def send(self, target):
"""Send a UI notification to a user."""
NotificationMessage.objects.create(
target_object=self.obj,
source_object=target,
@ -204,10 +254,7 @@ class UIMessageNotification(SingleNotificationMethod):
def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
"""
Send out a notification
"""
"""Send out a notification."""
targets = kwargs.get('targets', None)
target_fnc = kwargs.get('target_fnc', None)
target_args = kwargs.get('target_args', [])
@ -267,6 +314,15 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
def deliver_notification(cls: NotificationMethod, obj, category: str, targets, context: dict):
"""Send notification with the provided class.
This:
- Intis the method
- Checks that there are valid targets
- Runs the delivery setup
- Sends notifications either via `send_bulk` or send`
- Runs the delivery cleanup
"""
# Init delivery method
method = cls(obj, category, targets, context)

View File

@ -1,6 +1,4 @@
"""
JSON serializers for common components
"""
"""JSON serializers for common components."""
from rest_framework import serializers
@ -11,9 +9,7 @@ from InvenTree.serializers import InvenTreeModelSerializer
class SettingsSerializer(InvenTreeModelSerializer):
"""
Base serializer for a settings object
"""
"""Base serializer for a settings object."""
key = serializers.CharField(read_only=True)
@ -30,10 +26,7 @@ class SettingsSerializer(InvenTreeModelSerializer):
api_url = serializers.CharField(read_only=True)
def get_choices(self, obj):
"""
Returns the choices available for a given item
"""
"""Returns the choices available for a given item."""
results = []
choices = obj.choices()
@ -48,10 +41,7 @@ class SettingsSerializer(InvenTreeModelSerializer):
return results
def get_value(self, obj):
"""
Make sure protected values are not returned
"""
"""Make sure protected values are not returned."""
# never return protected values
if obj.protected:
result = '***'
@ -62,11 +52,11 @@ class SettingsSerializer(InvenTreeModelSerializer):
class GlobalSettingsSerializer(SettingsSerializer):
"""
Serializer for the InvenTreeSetting model
"""
"""Serializer for the InvenTreeSetting model."""
class Meta:
"""Meta options for GlobalSettingsSerializer."""
model = InvenTreeSetting
fields = [
'pk',
@ -82,13 +72,13 @@ class GlobalSettingsSerializer(SettingsSerializer):
class UserSettingsSerializer(SettingsSerializer):
"""
Serializer for the InvenTreeUserSetting model
"""
"""Serializer for the InvenTreeUserSetting model."""
user = serializers.PrimaryKeyRelatedField(read_only=True)
class Meta:
"""Meta options for UserSettingsSerializer."""
model = InvenTreeUserSetting
fields = [
'pk',
@ -105,8 +95,7 @@ class UserSettingsSerializer(SettingsSerializer):
class GenericReferencedSettingSerializer(SettingsSerializer):
"""
Serializer for a GenericReferencedSetting model
"""Serializer for a GenericReferencedSetting model.
Args:
MODEL: model class for the serializer
@ -118,9 +107,9 @@ class GenericReferencedSettingSerializer(SettingsSerializer):
EXTRA_FIELDS = None
def __init__(self, *args, **kwargs):
"""Init overrides the Meta class to make it dynamic"""
"""Init overrides the Meta class to make it dynamic."""
class CustomMeta:
"""Scaffold for custom Meta class"""
"""Scaffold for custom Meta class."""
fields = [
'pk',
'key',
@ -144,9 +133,7 @@ class GenericReferencedSettingSerializer(SettingsSerializer):
class NotificationMessageSerializer(InvenTreeModelSerializer):
"""
Serializer for the InvenTreeUserSetting model
"""
"""Serializer for the InvenTreeUserSetting model."""
target = serializers.SerializerMethodField(read_only=True)
@ -169,12 +156,16 @@ class NotificationMessageSerializer(InvenTreeModelSerializer):
read = serializers.BooleanField(read_only=True)
def get_target(self, obj):
"""Function to resolve generic object reference to target."""
return get_objectreference(obj, 'target_content_type', 'target_object_id')
def get_source(self, obj):
"""Function to resolve generic object reference to source."""
return get_objectreference(obj, 'source_content_type', 'source_object_id')
class Meta:
"""Meta options for NotificationMessageSerializer."""
model = NotificationMessage
fields = [
'pk',
@ -192,8 +183,10 @@ class NotificationMessageSerializer(InvenTreeModelSerializer):
class NotificationReadSerializer(NotificationMessageSerializer):
"""Serializer for reading a notification."""
def is_valid(self, raise_exception=False):
"""Ensure instance data is available for view and let validation pass."""
self.instance = self.context['instance'] # set instance that should be returned
self._validated_data = True
return True

View File

@ -1,6 +1,4 @@
"""
User-configurable settings for the common app
"""
"""User-configurable settings for the common app."""
from django.conf import settings
@ -8,9 +6,7 @@ from moneyed import CURRENCIES
def currency_code_default():
"""
Returns the default currency code (or USD if not specified)
"""
"""Returns the default currency code (or USD if not specified)"""
from django.db.utils import ProgrammingError
from common.models import InvenTreeSetting
@ -28,23 +24,17 @@ def currency_code_default():
def currency_code_mappings():
"""
Returns the current currency choices
"""
"""Returns the current currency choices."""
return [(a, CURRENCIES[a].name) for a in settings.CURRENCIES]
def currency_codes():
"""
Returns the current currency codes
"""
"""Returns the current currency codes."""
return [a for a in settings.CURRENCIES]
def stock_expiry_enabled():
"""
Returns True if the stock expiry feature is enabled
"""
"""Returns True if the stock expiry feature is enabled."""
from common.models import InvenTreeSetting
return InvenTreeSetting.get_setting('STOCK_ENABLE_EXPIRY')

View File

@ -1,3 +1,5 @@
"""Tasks (processes that get offloaded) for common app."""
import logging
from datetime import datetime, timedelta
@ -7,12 +9,10 @@ logger = logging.getLogger('inventree')
def delete_old_notifications():
"""
Remove old notifications from the database.
"""Remove old notifications from the database.
Anything older than ~3 months is removed
"""
try:
from common.models import NotificationEntry
except AppRegistryNotReady: # pragma: no cover

View File

@ -1,3 +1,5 @@
"""Tests for basic notification methods and functions in InvenTree."""
import plugin.templatetags.plugin_extras as plugin_tags
from common.notifications import (BulkNotificationMethod, NotificationMethod,
SingleNotificationMethod, storage)
@ -6,9 +8,10 @@ from plugin.models import NotificationUserSetting
class BaseNotificationTests(BaseNotificationIntegrationTest):
"""Tests for basic NotificationMethod."""
def test_NotificationMethod(self):
"""ensure the implementation requirements are tested"""
"""Ensure the implementation requirements are tested."""
class FalseNotificationMethod(NotificationMethod):
METHOD_NAME = 'FalseNotification'
@ -17,12 +20,12 @@ class BaseNotificationTests(BaseNotificationIntegrationTest):
METHOD_NAME = 'AnotherFalseNotification'
def send(self):
"""a comment so we do not need a pass"""
"""A comment so we do not need a pass."""
class NoNameNotificationMethod(NotificationMethod):
def send(self):
"""a comment so we do not need a pass"""
"""A comment so we do not need a pass."""
class WrongContextNotificationMethod(NotificationMethod):
METHOD_NAME = 'WrongContextNotification'
@ -34,7 +37,7 @@ class BaseNotificationTests(BaseNotificationIntegrationTest):
]
def send(self):
"""a comment so we do not need a pass"""
"""A comment so we do not need a pass."""
# no send / send bulk
with self.assertRaises(NotImplementedError):
@ -53,11 +56,12 @@ class BaseNotificationTests(BaseNotificationIntegrationTest):
AnotherFalseNotificationMethod('', '', '', {'name': 1, 'message': 2, }, )
def test_failing_passing(self):
"""Ensure that an error in one deliverymethod is not blocking all mehthods."""
# cover failing delivery
self._notification_run()
def test_errors_passing(self):
"""ensure that errors do not kill the whole delivery"""
"""Ensure that errors do not kill the whole delivery."""
class ErrorImplementation(SingleNotificationMethod):
METHOD_NAME = 'ErrorImplementation'
@ -72,10 +76,14 @@ class BaseNotificationTests(BaseNotificationIntegrationTest):
class BulkNotificationMethodTests(BaseNotificationIntegrationTest):
"""Tests for BulkNotificationMethod classes specifically.
General tests for NotificationMethods are in BaseNotificationTests.
"""
def test_BulkNotificationMethod(self):
"""
Ensure the implementation requirements are tested.
"""Ensure the implementation requirements are tested.
MixinNotImplementedError needs to raise if the send_bulk() method is not set.
"""
@ -90,10 +98,14 @@ class BulkNotificationMethodTests(BaseNotificationIntegrationTest):
class SingleNotificationMethodTests(BaseNotificationIntegrationTest):
"""Tests for SingleNotificationMethod classes specifically.
General tests for NotificationMethods are in BaseNotificationTests.
"""
def test_SingleNotificationMethod(self):
"""
Ensure the implementation requirements are tested.
"""Ensure the implementation requirements are tested.
MixinNotImplementedError needs to raise if the send() method is not set.
"""
@ -110,14 +122,15 @@ class SingleNotificationMethodTests(BaseNotificationIntegrationTest):
class NotificationUserSettingTests(BaseNotificationIntegrationTest):
""" Tests for NotificationUserSetting """
"""Tests for NotificationUserSetting."""
def setUp(self):
"""Setup for all tests."""
super().setUp()
self.client.login(username=self.user.username, password='password')
def test_setting_attributes(self):
"""check notification method plugin methods: usersettings and tags """
"""Check notification method plugin methods: usersettings and tags."""
class SampleImplementation(BulkNotificationMethod):
METHOD_NAME = 'test'

View File

@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
"""Tests for tasks in app common."""
from django.test import TestCase
from common.models import NotificationEntry
@ -8,12 +9,10 @@ from . import tasks as common_tasks
class TaskTest(TestCase):
"""
Tests for common tasks
"""
"""Tests for common tasks."""
def test_delete(self):
"""Test that the task `delete_old_notifications` runs through without errors."""
# check empty run
self.assertEqual(NotificationEntry.objects.all().count(), 0)
offload_task(common_tasks.delete_old_notifications,)

View File

@ -1,3 +1 @@
"""
Unit tests for the views associated with the 'common' app
"""
"""Unit tests for the views associated with the 'common' app."""

View File

@ -1,3 +1,4 @@
"""Tests for mechanisms in common."""
import json
from datetime import timedelta
@ -19,16 +20,14 @@ CONTENT_TYPE_JSON = 'application/json'
class SettingsTest(InvenTreeTestCase):
"""
Tests for the 'settings' model
"""
"""Tests for the 'settings' model."""
fixtures = [
'settings',
]
def test_settings_objects(self):
"""Test fixture loading and lookup for settings."""
# There should be two settings objects in the database
settings = InvenTreeSetting.objects.all()
@ -42,9 +41,7 @@ class SettingsTest(InvenTreeTestCase):
self.assertEqual(InvenTreeSetting.get_setting_object('iNvEnTrEE_inSTanCE').pk, 1)
def test_settings_functions(self):
"""
Test settings functions and properties
"""
"""Test settings functions and properties."""
# define settings to check
instance_ref = 'INVENTREE_INSTANCE'
instance_obj = InvenTreeSetting.get_setting_object(instance_ref)
@ -90,9 +87,7 @@ class SettingsTest(InvenTreeTestCase):
self.assertEqual(stale_days.to_native_value(), 0)
def test_allValues(self):
"""
Make sure that the allValues functions returns correctly
"""
"""Make sure that the allValues functions returns correctly."""
# define testing settings
# check a few keys
@ -103,7 +98,13 @@ class SettingsTest(InvenTreeTestCase):
self.assertIn('SIGNUP_GROUP', result)
def run_settings_check(self, key, setting):
"""Test that all settings are valid.
- Ensure that a name is set and that it is translated
- Ensure that a description is set
- Ensure that every setting key is valid
- Ensure that a validator is supplied
"""
self.assertTrue(type(setting) is dict)
name = setting.get('name', None)
@ -147,11 +148,11 @@ class SettingsTest(InvenTreeTestCase):
self.assertIn(default, [True, False])
def test_setting_data(self):
"""
"""Test for settings data.
- Ensure that every setting has a name, which is translated
- Ensure that every setting has a description, which is translated
"""
for key, setting in InvenTreeSetting.SETTINGS.items():
try:
@ -168,10 +169,7 @@ class SettingsTest(InvenTreeTestCase):
raise exc
def test_defaults(self):
"""
Populate the settings with default values
"""
"""Populate the settings with default values."""
for key in InvenTreeSetting.SETTINGS.keys():
value = InvenTreeSetting.get_setting_default(key)
@ -192,14 +190,10 @@ class SettingsTest(InvenTreeTestCase):
class GlobalSettingsApiTest(InvenTreeAPITestCase):
"""
Tests for the global settings API
"""
"""Tests for the global settings API."""
def test_global_settings_api_list(self):
"""
Test list URL for global settings
"""
"""Test list URL for global settings."""
url = reverse('api-global-setting-list')
# Read out each of the global settings value, to ensure they are instantiated in the database
@ -212,7 +206,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), len(InvenTreeSetting.SETTINGS.keys()))
def test_company_name(self):
"""Test a settings object lifecyle e2e."""
setting = InvenTreeSetting.get_setting_object('INVENTREE_COMPANY_NAME')
# Check default value
@ -245,8 +239,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
self.assertEqual(setting.value, val)
def test_api_detail(self):
"""Test that we can access the detail view for a setting based on the <key>"""
"""Test that we can access the detail view for a setting based on the <key>."""
# These keys are invalid, and should return 404
for key in ["apple", "carrot", "dog"]:
response = self.get(
@ -287,28 +280,22 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
class UserSettingsApiTest(InvenTreeAPITestCase):
"""
Tests for the user settings API
"""
"""Tests for the user settings API."""
def test_user_settings_api_list(self):
"""
Test list URL for user settings
"""
"""Test list URL for user settings."""
url = reverse('api-user-setting-list')
self.get(url, expected_code=200)
def test_user_setting_invalid(self):
"""Test a user setting with an invalid key"""
"""Test a user setting with an invalid key."""
url = reverse('api-user-setting-detail', kwargs={'key': 'DONKEY'})
self.get(url, expected_code=404)
def test_user_setting_init(self):
"""Test we can retrieve a setting which has not yet been initialized"""
"""Test we can retrieve a setting which has not yet been initialized."""
key = 'HOMEPAGE_PART_LATEST'
# Ensure it does not actually exist in the database
@ -328,10 +315,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
self.assertEqual(setting.to_native_value(), False)
def test_user_setting_boolean(self):
"""
Test a boolean user setting value
"""
"""Test a boolean user setting value."""
# Ensure we have a boolean setting available
setting = InvenTreeUserSetting.get_setting_object(
'SEARCH_PREVIEW_SHOW_PARTS',
@ -395,7 +379,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
self.assertFalse(str2bool(response.data['value']))
def test_user_setting_choice(self):
"""Test a user setting with choices."""
setting = InvenTreeUserSetting.get_setting_object(
'DATE_DISPLAY_FORMAT',
user=self.user
@ -434,7 +418,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
self.assertIn('Chosen value is not a valid option', str(response.data))
def test_user_setting_integer(self):
"""Test a integer user setting value."""
setting = InvenTreeUserSetting.get_setting_object(
'SEARCH_PREVIEW_RESULTS',
user=self.user
@ -480,25 +464,25 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
class NotificationUserSettingsApiTest(InvenTreeAPITestCase):
"""Tests for the notification user settings API"""
"""Tests for the notification user settings API."""
def test_api_list(self):
"""Test list URL"""
"""Test list URL."""
url = reverse('api-notifcation-setting-list')
self.get(url, expected_code=200)
def test_setting(self):
"""Test the string name for NotificationUserSetting"""
"""Test the string name for NotificationUserSetting."""
test_setting = NotificationUserSetting.get_setting_object('NOTIFICATION_METHOD_MAIL', user=self.user)
self.assertEqual(str(test_setting), 'NOTIFICATION_METHOD_MAIL (for testuser): ')
class PluginSettingsApiTest(InvenTreeAPITestCase):
"""Tests for the plugin settings API"""
"""Tests for the plugin settings API."""
def test_plugin_list(self):
"""List installed plugins via API"""
"""List installed plugins via API."""
url = reverse('api-plugin-list')
# Simple request
@ -508,13 +492,13 @@ class PluginSettingsApiTest(InvenTreeAPITestCase):
self.get(url, expected_code=200, data={'mixin': 'settings'})
def test_api_list(self):
"""Test list URL"""
"""Test list URL."""
url = reverse('api-plugin-setting-list')
self.get(url, expected_code=200)
def test_valid_plugin_slug(self):
"""Test that an valid plugin slug runs through"""
"""Test that an valid plugin slug runs through."""
# load plugin configs
fixtures = PluginConfig.objects.all()
if not fixtures:
@ -544,26 +528,30 @@ class PluginSettingsApiTest(InvenTreeAPITestCase):
self.assertIn("Plugin 'sample' has no setting matching 'doesnotexsist'", str(response.data))
def test_invalid_setting_key(self):
"""Test that an invalid setting key returns a 404"""
"""Test that an invalid setting key returns a 404."""
...
def test_uninitialized_setting(self):
"""Test that requesting an uninitialized setting creates the setting"""
"""Test that requesting an uninitialized setting creates the setting."""
...
class WebhookMessageTests(TestCase):
"""Tests for webhooks."""
def setUp(self):
"""Setup for all tests."""
self.endpoint_def = WebhookEndpoint.objects.create()
self.url = f'/api/webhook/{self.endpoint_def.endpoint_id}/'
self.client = Client(enforce_csrf_checks=True)
def test_bad_method(self):
"""Test that a wrong HTTP method does not work."""
response = self.client.get(self.url)
assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED
def test_missing_token(self):
"""Tests that token checks work."""
response = self.client.post(
self.url,
content_type=CONTENT_TYPE_JSON,
@ -575,6 +563,7 @@ class WebhookMessageTests(TestCase):
)
def test_bad_token(self):
"""Test that a wrong token is not working."""
response = self.client.post(
self.url,
content_type=CONTENT_TYPE_JSON,
@ -585,6 +574,7 @@ class WebhookMessageTests(TestCase):
assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR)
def test_bad_url(self):
"""Test that a wrongly formed url is not working."""
response = self.client.post(
'/api/webhook/1234/',
content_type=CONTENT_TYPE_JSON,
@ -593,6 +583,7 @@ class WebhookMessageTests(TestCase):
assert response.status_code == HTTPStatus.NOT_FOUND
def test_bad_json(self):
"""Test that malformed JSON is not accepted."""
response = self.client.post(
self.url,
data="{'this': 123}",
@ -606,6 +597,7 @@ class WebhookMessageTests(TestCase):
)
def test_success_no_token_check(self):
"""Test that a endpoint without a token set does not require one."""
# delete token
self.endpoint_def.token = ''
self.endpoint_def.save()
@ -620,6 +612,7 @@ class WebhookMessageTests(TestCase):
assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK
def test_bad_hmac(self):
"""Test that a malformed HMAC does not pass."""
# delete token
self.endpoint_def.token = ''
self.endpoint_def.secret = '123abc'
@ -635,6 +628,7 @@ class WebhookMessageTests(TestCase):
assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR)
def test_success_hmac(self):
"""Test with a valid HMAC provided."""
# delete token
self.endpoint_def.token = ''
self.endpoint_def.secret = '123abc'
@ -651,6 +645,10 @@ class WebhookMessageTests(TestCase):
assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK
def test_success(self):
"""Test full e2e webhook call.
The message should go through and save the json payload.
"""
response = self.client.post(
self.url,
data={"this": "is a message"},
@ -665,9 +663,10 @@ class WebhookMessageTests(TestCase):
class NotificationTest(InvenTreeAPITestCase):
"""Tests for NotificationEntriy."""
def test_check_notification_entries(self):
"""Test that notification entries can be created."""
# Create some notification entries
self.assertEqual(NotificationEntry.objects.count(), 0)
@ -684,21 +683,16 @@ class NotificationTest(InvenTreeAPITestCase):
self.assertTrue(NotificationEntry.check_recent('test.notification', 1, delta))
def test_api_list(self):
"""Test list URL"""
"""Test list URL."""
url = reverse('api-notifications-list')
self.get(url, expected_code=200)
class LoadingTest(TestCase):
"""
Tests for the common config
"""
"""Tests for the common config."""
def test_restart_flag(self):
"""
Test that the restart flag is reset on start
"""
"""Test that the restart flag is reset on start."""
import common.models
from plugin import registry
@ -713,10 +707,10 @@ class LoadingTest(TestCase):
class ColorThemeTest(TestCase):
"""Tests for ColorTheme"""
"""Tests for ColorTheme."""
def test_choices(self):
"""Test that default choices are returned"""
"""Test that default choices are returned."""
result = ColorTheme.get_color_themes_choices()
# skip
@ -725,7 +719,7 @@ class ColorThemeTest(TestCase):
self.assertIn(('default', 'Default'), result)
def test_valid_choice(self):
"""Check that is_valid_choice works correctly"""
"""Check that is_valid_choice works correctly."""
result = ColorTheme.get_color_themes_choices()
# skip

View File

@ -1,6 +1,4 @@
"""
URL lookup for common views
"""
"""URL lookup for common views."""
common_urls = [
]

View File

@ -1,6 +1,4 @@
"""
Django views for interacting with common models
"""
"""Django views for interacting with common models."""
import os
@ -18,7 +16,7 @@ from .files import FileManager
class MultiStepFormView(SessionWizardView):
""" Setup basic methods of multi-step form
"""Setup basic methods of multi-step form.
form_list: list of forms
form_steps_description: description for each form
@ -31,14 +29,13 @@ class MultiStepFormView(SessionWizardView):
file_storage = FileSystemStorage(settings.MEDIA_ROOT)
def __init__(self, *args, **kwargs):
""" Override init method to set media folder """
"""Override init method to set media folder."""
super().__init__(**kwargs)
self.process_media_folder()
def process_media_folder(self):
""" Process media folder """
"""Process media folder."""
if self.media_folder:
media_folder_abs = os.path.join(settings.MEDIA_ROOT, self.media_folder)
if not os.path.exists(media_folder_abs):
@ -46,8 +43,7 @@ class MultiStepFormView(SessionWizardView):
self.file_storage = FileSystemStorage(location=media_folder_abs)
def get_template_names(self):
""" Select template """
"""Select template."""
try:
# Get template
template = self.form_steps_template[self.steps.index]
@ -57,8 +53,7 @@ class MultiStepFormView(SessionWizardView):
return template
def get_context_data(self, **kwargs):
""" Update context data """
"""Update context data."""
# Retrieve current context
context = super().get_context_data(**kwargs)
@ -74,7 +69,9 @@ class MultiStepFormView(SessionWizardView):
class FileManagementFormView(MultiStepFormView):
""" Setup form wizard to perform the following steps:
"""File management form wizard.
Perform the following steps:
1. Upload tabular data file
2. Match headers to InvenTree fields
3. Edit row data and match InvenTree items
@ -95,8 +92,7 @@ class FileManagementFormView(MultiStepFormView):
extra_context_data = {}
def __init__(self, *args, **kwargs):
""" Initialize the FormView """
"""Initialize the FormView."""
# Perform all checks and inits for MultiStepFormView
super().__init__(self, *args, **kwargs)
@ -105,8 +101,7 @@ class FileManagementFormView(MultiStepFormView):
raise NotImplementedError('A subclass of a file manager class needs to be set!')
def get_context_data(self, form=None, **kwargs):
""" Handle context data """
"""Handle context data."""
if form is None:
form = self.get_form()
@ -136,8 +131,7 @@ class FileManagementFormView(MultiStepFormView):
return context
def get_file_manager(self, step=None, form=None):
""" Get FileManager instance from uploaded file """
"""Get FileManager instance from uploaded file."""
if self.file_manager:
return
@ -151,8 +145,7 @@ class FileManagementFormView(MultiStepFormView):
self.file_manager = self.file_manager_class(file=file, name=self.name)
def get_form_kwargs(self, step=None):
""" Update kwargs to dynamically build forms """
"""Update kwargs to dynamically build forms."""
# Always retrieve FileManager instance from uploaded file
self.get_file_manager(step)
@ -191,7 +184,7 @@ class FileManagementFormView(MultiStepFormView):
return super().get_form_kwargs()
def get_form(self, step=None, data=None, files=None):
""" add crispy-form helper to form """
"""Add crispy-form helper to form."""
form = super().get_form(step=step, data=data, files=files)
form.helper = FormHelper()
@ -200,17 +193,14 @@ class FileManagementFormView(MultiStepFormView):
return form
def get_form_table_data(self, form_data):
""" Extract table cell data from form data and fields.
These data are used to maintain state between sessions.
"""Extract table cell data from form data and fields. These data are used to maintain state between sessions.
Table data keys are as follows:
col_name_<idx> - Column name at idx as provided in the uploaded file
col_guess_<idx> - Column guess at idx as selected
row_<x>_col<y> - Cell data as provided in the uploaded file
"""
# Map the columns
self.column_names = {}
self.column_selections = {}
@ -264,8 +254,7 @@ class FileManagementFormView(MultiStepFormView):
self.row_data[row_id][col_id] = value
def set_form_table_data(self, form=None):
""" Set the form table data """
"""Set the form table data."""
if self.column_names:
# Re-construct the column data
self.columns = []
@ -324,10 +313,10 @@ class FileManagementFormView(MultiStepFormView):
row[field_key] = field_key + '-' + str(row['index'])
def get_column_index(self, name):
""" Return the index of the column with the given name.
"""Return the index of the column with the given name.
It named column is not found, return -1
"""
try:
idx = list(self.column_selections.values()).index(name)
except ValueError:
@ -336,9 +325,7 @@ class FileManagementFormView(MultiStepFormView):
return idx
def get_field_selection(self):
""" Once data columns have been selected, attempt to pre-select the proper data from the database.
This function is called once the field selection has been validated.
The pre-fill data are then passed through to the part selection form.
"""Once data columns have been selected, attempt to pre-select the proper data from the database. This function is called once the field selection has been validated. The pre-fill data are then passed through to the part selection form.
This method is very specific to the type of data found in the file,
therefore overwrite it in the subclass.
@ -346,7 +333,7 @@ class FileManagementFormView(MultiStepFormView):
pass
def get_clean_items(self):
""" returns dict with all cleaned values """
"""Returns dict with all cleaned values."""
items = {}
for form_key, form_value in self.get_all_cleaned_data().items():
@ -373,8 +360,7 @@ class FileManagementFormView(MultiStepFormView):
return items
def check_field_selection(self, form):
""" Check field matching """
"""Check field matching."""
# Are there any missing columns?
missing_columns = []
@ -422,8 +408,7 @@ class FileManagementFormView(MultiStepFormView):
return valid
def validate(self, step, form):
""" Validate forms """
"""Validate forms."""
valid = True
# Get form table data
@ -442,8 +427,7 @@ class FileManagementFormView(MultiStepFormView):
return valid
def post(self, request, *args, **kwargs):
""" Perform validations before posting data """
"""Perform validations before posting data."""
wizard_goto_step = self.request.POST.get('wizard_goto_step', None)
form = self.get_form(data=self.request.POST, files=self.request.FILES)
@ -458,14 +442,21 @@ class FileManagementFormView(MultiStepFormView):
class FileManagementAjaxView(AjaxView):
""" Use a FileManagementFormView as base for a AjaxView
Inherit this class before inheriting the base FileManagementFormView
"""Use a FileManagementFormView as base for a AjaxView Inherit this class before inheriting the base FileManagementFormView.
ajax_form_steps_template: templates for rendering ajax
validate: function to validate the current form -> normally point to the same function in the base FileManagementFormView
"""
def post(self, request):
"""Handle wizard step call.
Possible actions:
- Step back -> render previous step
- Invalid form -> render error
- Valid form and not done -> render next step
- Valid form and done -> render final step
"""
# check if back-step button was selected
wizard_back = self.request.POST.get('act-btn_back', None)
if wizard_back:
@ -497,6 +488,7 @@ class FileManagementAjaxView(AjaxView):
return self.renderJsonResponse(request, data={'form_valid': None})
def get(self, request):
"""Reset storage if flag is set, proceed to render JsonResponse."""
if 'reset' in request.GET:
# reset form
self.storage.reset()
@ -504,11 +496,12 @@ class FileManagementAjaxView(AjaxView):
return self.renderJsonResponse(request)
def renderJsonResponse(self, request, form=None, data={}, context=None):
""" always set the right templates before rendering """
"""Always set the right templates before rendering."""
self.setTemplate()
return super().renderJsonResponse(request, form=form, data=data, context=context)
def get_data(self):
def get_data(self) -> dict:
"""Get extra context data."""
data = super().get_data()
data['hideErrorMessage'] = '1' # hide the error
buttons = [{'name': 'back', 'title': _('Previous Step')}] if self.get_step_index() > 0 else []
@ -516,9 +509,13 @@ class FileManagementAjaxView(AjaxView):
return data
def setTemplate(self):
""" set template name and title """
"""Set template name and title."""
self.ajax_template_name = self.ajax_form_steps_template[self.get_step_index()]
self.ajax_form_title = self.form_steps_description[self.get_step_index()]
def validate(self, obj, form, **kwargs):
"""Generic validate action.
This is the point to process provided userinput.
"""
raise NotImplementedError('This function needs to be overridden!')

View File

@ -1,5 +1,4 @@
"""
The Company module is responsible for managing Company interactions.
"""The Company module is responsible for managing Company interactions.
A company can be either (or both):

View File

@ -1,3 +1,5 @@
"""Admin class for the 'company' app"""
from django.contrib import admin
import import_export.widgets as widgets
@ -13,9 +15,10 @@ from .models import (Company, ManufacturerPart, ManufacturerPartAttachment,
class CompanyResource(ModelResource):
""" Class for managing Company data import/export """
"""Class for managing Company data import/export."""
class Meta:
"""Metaclass defines extra options"""
model = Company
skip_unchanged = True
report_skipped = False
@ -23,6 +26,7 @@ class CompanyResource(ModelResource):
class CompanyAdmin(ImportExportModelAdmin):
"""Admin class for the Company model"""
resource_class = CompanyResource
@ -35,9 +39,7 @@ class CompanyAdmin(ImportExportModelAdmin):
class SupplierPartResource(ModelResource):
"""
Class for managing SupplierPart data import/export
"""
"""Class for managing SupplierPart data import/export."""
part = Field(attribute='part', widget=widgets.ForeignKeyWidget(Part))
@ -48,6 +50,7 @@ class SupplierPartResource(ModelResource):
supplier_name = Field(attribute='supplier__name', readonly=True)
class Meta:
"""Metaclass defines extra admin options"""
model = SupplierPart
skip_unchanged = True
report_skipped = True
@ -55,6 +58,7 @@ class SupplierPartResource(ModelResource):
class SupplierPartAdmin(ImportExportModelAdmin):
"""Admin class for the SupplierPart model"""
resource_class = SupplierPartResource
@ -71,9 +75,7 @@ class SupplierPartAdmin(ImportExportModelAdmin):
class ManufacturerPartResource(ModelResource):
"""
Class for managing ManufacturerPart data import/export
"""
"""Class for managing ManufacturerPart data import/export."""
part = Field(attribute='part', widget=widgets.ForeignKeyWidget(Part))
@ -84,6 +86,7 @@ class ManufacturerPartResource(ModelResource):
manufacturer_name = Field(attribute='manufacturer__name', readonly=True)
class Meta:
"""Metaclass defines extra admin options"""
model = ManufacturerPart
skip_unchanged = True
report_skipped = True
@ -91,9 +94,7 @@ class ManufacturerPartResource(ModelResource):
class ManufacturerPartAdmin(ImportExportModelAdmin):
"""
Admin class for ManufacturerPart model
"""
"""Admin class for ManufacturerPart model."""
resource_class = ManufacturerPartResource
@ -109,9 +110,7 @@ class ManufacturerPartAdmin(ImportExportModelAdmin):
class ManufacturerPartAttachmentAdmin(ImportExportModelAdmin):
"""
Admin class for ManufacturerPartAttachment model
"""
"""Admin class for ManufacturerPartAttachment model."""
list_display = ('manufacturer_part', 'attachment', 'comment')
@ -119,11 +118,10 @@ class ManufacturerPartAttachmentAdmin(ImportExportModelAdmin):
class ManufacturerPartParameterResource(ModelResource):
"""
Class for managing ManufacturerPartParameter data import/export
"""
"""Class for managing ManufacturerPartParameter data import/export."""
class Meta:
"""Metaclass defines extra admin options"""
model = ManufacturerPartParameter
skip_unchanged = True
report_skipped = True
@ -131,9 +129,7 @@ class ManufacturerPartParameterResource(ModelResource):
class ManufacturerPartParameterAdmin(ImportExportModelAdmin):
"""
Admin class for ManufacturerPartParameter model
"""
"""Admin class for ManufacturerPartParameter model."""
resource_class = ManufacturerPartParameterResource
@ -149,7 +145,7 @@ class ManufacturerPartParameterAdmin(ImportExportModelAdmin):
class SupplierPriceBreakResource(ModelResource):
""" Class for managing SupplierPriceBreak data import/export """
"""Class for managing SupplierPriceBreak data import/export."""
part = Field(attribute='part', widget=widgets.ForeignKeyWidget(SupplierPart))
@ -164,6 +160,7 @@ class SupplierPriceBreakResource(ModelResource):
MPN = Field(attribute='part__MPN', readonly=True)
class Meta:
"""Metaclass defines extra admin options"""
model = SupplierPriceBreak
skip_unchanged = True
report_skipped = False
@ -171,6 +168,7 @@ class SupplierPriceBreakResource(ModelResource):
class SupplierPriceBreakAdmin(ImportExportModelAdmin):
"""Admin class for the SupplierPriceBreak model"""
resource_class = SupplierPriceBreakResource

View File

@ -1,6 +1,4 @@
"""
Provides a JSON API for the Company app
"""
"""Provides a JSON API for the Company app."""
from django.db.models import Q
from django.urls import include, re_path
@ -23,7 +21,7 @@ from .serializers import (CompanySerializer,
class CompanyList(generics.ListCreateAPIView):
""" API endpoint for accessing a list of Company objects
"""API endpoint for accessing a list of Company objects.
Provides two methods:
@ -35,7 +33,7 @@ class CompanyList(generics.ListCreateAPIView):
queryset = Company.objects.all()
def get_queryset(self):
"""Return annotated queryset for the company list endpoint"""
queryset = super().get_queryset()
queryset = CompanySerializer.annotate_queryset(queryset)
@ -70,13 +68,13 @@ class CompanyList(generics.ListCreateAPIView):
class CompanyDetail(generics.RetrieveUpdateDestroyAPIView):
""" API endpoint for detail of a single Company object """
"""API endpoint for detail of a single Company object."""
queryset = Company.objects.all()
serializer_class = CompanySerializer
def get_queryset(self):
"""Return annotated queryset for the company detail endpoint"""
queryset = super().get_queryset()
queryset = CompanySerializer.annotate_queryset(queryset)
@ -84,11 +82,11 @@ class CompanyDetail(generics.RetrieveUpdateDestroyAPIView):
class ManufacturerPartFilter(rest_filters.FilterSet):
"""
Custom API filters for the ManufacturerPart list endpoint.
"""
"""Custom API filters for the ManufacturerPart list endpoint."""
class Meta:
"""Metaclass options."""
model = ManufacturerPart
fields = [
'manufacturer',
@ -101,7 +99,7 @@ class ManufacturerPartFilter(rest_filters.FilterSet):
class ManufacturerPartList(generics.ListCreateAPIView):
""" API endpoint for list view of ManufacturerPart object
"""API endpoint for list view of ManufacturerPart object.
- GET: Return list of ManufacturerPart objects
- POST: Create a new ManufacturerPart object
@ -117,7 +115,7 @@ class ManufacturerPartList(generics.ListCreateAPIView):
filterset_class = ManufacturerPartFilter
def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint"""
# Do we wish to include extra detail?
try:
params = self.request.query_params
@ -149,7 +147,7 @@ class ManufacturerPartList(generics.ListCreateAPIView):
class ManufacturerPartDetail(generics.RetrieveUpdateDestroyAPIView):
""" API endpoint for detail view of ManufacturerPart object
"""API endpoint for detail view of ManufacturerPart object.
- GET: Retrieve detail view
- PATCH: Update object
@ -161,9 +159,7 @@ class ManufacturerPartDetail(generics.RetrieveUpdateDestroyAPIView):
class ManufacturerPartAttachmentList(AttachmentMixin, generics.ListCreateAPIView):
"""
API endpoint for listing (and creating) a ManufacturerPartAttachment (file upload).
"""
"""API endpoint for listing (and creating) a ManufacturerPartAttachment (file upload)."""
queryset = ManufacturerPartAttachment.objects.all()
serializer_class = ManufacturerPartAttachmentSerializer
@ -178,24 +174,20 @@ class ManufacturerPartAttachmentList(AttachmentMixin, generics.ListCreateAPIView
class ManufacturerPartAttachmentDetail(AttachmentMixin, generics.RetrieveUpdateDestroyAPIView):
"""
Detail endpooint for ManufacturerPartAttachment model
"""
"""Detail endpooint for ManufacturerPartAttachment model."""
queryset = ManufacturerPartAttachment.objects.all()
serializer_class = ManufacturerPartAttachmentSerializer
class ManufacturerPartParameterList(generics.ListCreateAPIView):
"""
API endpoint for list view of ManufacturerPartParamater model.
"""
"""API endpoint for list view of ManufacturerPartParamater model."""
queryset = ManufacturerPartParameter.objects.all()
serializer_class = ManufacturerPartParameterSerializer
def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint"""
# Do we wish to include any extra detail?
try:
params = self.request.query_params
@ -215,10 +207,7 @@ class ManufacturerPartParameterList(generics.ListCreateAPIView):
return self.serializer_class(*args, **kwargs)
def filter_queryset(self, queryset):
"""
Custom filtering for the queryset
"""
"""Custom filtering for the queryset."""
queryset = super().filter_queryset(queryset)
params = self.request.query_params
@ -258,16 +247,14 @@ class ManufacturerPartParameterList(generics.ListCreateAPIView):
class ManufacturerPartParameterDetail(generics.RetrieveUpdateDestroyAPIView):
"""
API endpoint for detail view of ManufacturerPartParameter model
"""
"""API endpoint for detail view of ManufacturerPartParameter model."""
queryset = ManufacturerPartParameter.objects.all()
serializer_class = ManufacturerPartParameterSerializer
class SupplierPartList(generics.ListCreateAPIView):
""" API endpoint for list view of SupplierPart object
"""API endpoint for list view of SupplierPart object.
- GET: Return list of SupplierPart objects
- POST: Create a new SupplierPart object
@ -275,17 +262,8 @@ class SupplierPartList(generics.ListCreateAPIView):
queryset = SupplierPart.objects.all()
def get_queryset(self):
queryset = super().get_queryset()
return queryset
def filter_queryset(self, queryset):
"""
Custom filtering for the queryset.
"""
"""Custom filtering for the queryset."""
queryset = super().filter_queryset(queryset)
params = self.request.query_params
@ -330,6 +308,7 @@ class SupplierPartList(generics.ListCreateAPIView):
return queryset
def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint"""
# Do we wish to include extra detail?
try:
@ -369,7 +348,7 @@ class SupplierPartList(generics.ListCreateAPIView):
class SupplierPartDetail(generics.RetrieveUpdateDestroyAPIView):
""" API endpoint for detail view of SupplierPart object
"""API endpoint for detail view of SupplierPart object.
- GET: Retrieve detail view
- PATCH: Update object
@ -384,7 +363,7 @@ class SupplierPartDetail(generics.RetrieveUpdateDestroyAPIView):
class SupplierPriceBreakList(generics.ListCreateAPIView):
""" API endpoint for list view of SupplierPriceBreak object
"""API endpoint for list view of SupplierPriceBreak object.
- GET: Retrieve list of SupplierPriceBreak objects
- POST: Create a new SupplierPriceBreak object
@ -403,9 +382,7 @@ class SupplierPriceBreakList(generics.ListCreateAPIView):
class SupplierPriceBreakDetail(generics.RetrieveUpdateDestroyAPIView):
"""
Detail endpoint for SupplierPriceBreak object
"""
"""Detail endpoint for SupplierPriceBreak object."""
queryset = SupplierPriceBreak.objects.all()
serializer_class = SupplierPriceBreakSerializer

View File

@ -1,12 +1,13 @@
"""Config for the 'company' app"""
from django.apps import AppConfig
class CompanyConfig(AppConfig):
"""Config class for the 'company' app"""
name = 'company'
def ready(self):
"""
This function is called whenever the Company app is loaded.
"""
"""This function is called whenever the Company app is loaded."""
pass

View File

@ -1,6 +1,4 @@
"""
Django Forms for interacting with Company app
"""
"""Django Forms for interacting with Company app."""
import django.forms
from django.utils.translation import gettext_lazy as _
@ -12,9 +10,7 @@ from .models import Company, SupplierPriceBreak
class CompanyImageDownloadForm(HelperForm):
"""
Form for downloading an image from a URL
"""
"""Form for downloading an image from a URL."""
url = django.forms.URLField(
label=_('URL'),
@ -23,6 +19,8 @@ class CompanyImageDownloadForm(HelperForm):
)
class Meta:
"""Metaclass options."""
model = Company
fields = [
'url',
@ -30,7 +28,7 @@ class CompanyImageDownloadForm(HelperForm):
class EditPriceBreakForm(HelperForm):
""" Form for creating / editing a supplier price break """
"""Form for creating / editing a supplier price break."""
quantity = RoundingDecimalFormField(
max_digits=10,
@ -40,6 +38,8 @@ class EditPriceBreakForm(HelperForm):
)
class Meta:
"""Metaclass options."""
model = SupplierPriceBreak
fields = [
'part',

View File

@ -1,6 +1,4 @@
"""
Company database model definitions
"""
"""Company database model definitions."""
import os
@ -27,7 +25,7 @@ from InvenTree.status_codes import PurchaseOrderStatus
def rename_company_image(instance, filename):
""" Function to rename a company image after upload
"""Function to rename a company image after upload.
Args:
instance: Company object
@ -36,7 +34,6 @@ def rename_company_image(instance, filename):
Returns:
New image filename
"""
base = 'company_images'
if filename.count('.') > 0:
@ -53,7 +50,8 @@ def rename_company_image(instance, filename):
class Company(models.Model):
""" A Company object represents an external company.
"""A Company object represents an external company.
It may be a supplier or a customer or a manufacturer (or a combination)
- A supplier is a company from which parts can be purchased
@ -79,9 +77,11 @@ class Company(models.Model):
@staticmethod
def get_api_url():
"""Return the API URL associated with the Company model"""
return reverse('api-company-list')
class Meta:
"""Metaclass defines extra model options"""
ordering = ['name', ]
constraints = [
UniqueConstraint(fields=['name', 'email'], name='unique_name_email_pair')
@ -150,13 +150,11 @@ class Company(models.Model):
@property
def currency_code(self):
"""
Return the currency code associated with this company.
"""Return the currency code associated with this company.
- If the currency code is invalid, use the default currency
- If the currency code is not specified, use the default currency
"""
code = self.currency
if code not in CURRENCIES:
@ -165,103 +163,41 @@ class Company(models.Model):
return code
def __str__(self):
""" Get string representation of a Company """
"""Get string representation of a Company."""
return "{n} - {d}".format(n=self.name, d=self.description)
def get_absolute_url(self):
""" Get the web URL for the detail view for this Company """
"""Get the web URL for the detail view for this Company."""
return reverse('company-detail', kwargs={'pk': self.id})
def get_image_url(self):
""" Return the URL of the image for this company """
"""Return the URL of the image for this company."""
if self.image:
return getMediaUrl(self.image.url)
else:
return getBlankImage()
def get_thumbnail_url(self):
""" Return the URL for the thumbnail image for this Company """
"""Return the URL for the thumbnail image for this Company."""
if self.image:
return getMediaUrl(self.image.thumbnail.url)
else:
return getBlankThumbnail()
@property
def manufactured_part_count(self):
""" The number of parts manufactured by this company """
return self.manufactured_parts.count()
@property
def has_manufactured_parts(self):
return self.manufactured_part_count > 0
@property
def supplied_part_count(self):
""" The number of parts supplied by this company """
return self.supplied_parts.count()
@property
def has_supplied_parts(self):
""" Return True if this company supplies any parts """
return self.supplied_part_count > 0
@property
def parts(self):
""" Return SupplierPart objects which are supplied or manufactured by this company """
"""Return SupplierPart objects which are supplied or manufactured by this company."""
return SupplierPart.objects.filter(Q(supplier=self.id) | Q(manufacturer_part__manufacturer=self.id))
@property
def part_count(self):
""" The number of parts manufactured (or supplied) by this Company """
return self.parts.count()
@property
def has_parts(self):
return self.part_count > 0
@property
def stock_items(self):
""" Return a list of all stock items supplied or manufactured by this company """
"""Return a list of all stock items supplied or manufactured by this company."""
stock = apps.get_model('stock', 'StockItem')
return stock.objects.filter(Q(supplier_part__supplier=self.id) | Q(supplier_part__manufacturer_part__manufacturer=self.id)).all()
@property
def stock_count(self):
""" Return the number of stock items supplied or manufactured by this company """
return self.stock_items.count()
def outstanding_purchase_orders(self):
""" Return purchase orders which are 'outstanding' """
return self.purchase_orders.filter(status__in=PurchaseOrderStatus.OPEN)
def pending_purchase_orders(self):
""" Return purchase orders which are PENDING (not yet issued) """
return self.purchase_orders.filter(status=PurchaseOrderStatus.PENDING)
def closed_purchase_orders(self):
""" Return purchase orders which are not 'outstanding'
- Complete
- Failed / lost
- Returned
"""
return self.purchase_orders.exclude(status__in=PurchaseOrderStatus.OPEN)
def complete_purchase_orders(self):
return self.purchase_orders.filter(status=PurchaseOrderStatus.COMPLETE)
def failed_purchase_orders(self):
""" Return any purchase orders which were not successful """
return self.purchase_orders.filter(status__in=PurchaseOrderStatus.FAILED)
class Contact(models.Model):
""" A Contact represents a person who works at a particular company.
A Company may have zero or more associated Contact objects.
"""A Contact represents a person who works at a particular company. A Company may have zero or more associated Contact objects.
Attributes:
company: Company link for this contact
@ -284,10 +220,7 @@ class Contact(models.Model):
class ManufacturerPart(models.Model):
""" Represents a unique part as provided by a Manufacturer
Each ManufacturerPart is identified by a MPN (Manufacturer Part Number)
Each ManufacturerPart is also linked to a Part object.
A Part may be available from multiple manufacturers
"""Represents a unique part as provided by a Manufacturer Each ManufacturerPart is identified by a MPN (Manufacturer Part Number) Each ManufacturerPart is also linked to a Part object. A Part may be available from multiple manufacturers.
Attributes:
part: Link to the master Part
@ -299,9 +232,11 @@ class ManufacturerPart(models.Model):
@staticmethod
def get_api_url():
"""Return the API URL associated with the ManufacturerPart instance"""
return reverse('api-manufacturer-part-list')
class Meta:
"""Metaclass defines extra model options"""
unique_together = ('part', 'manufacturer', 'MPN')
part = models.ForeignKey('part.Part', on_delete=models.CASCADE,
@ -346,10 +281,7 @@ class ManufacturerPart(models.Model):
@classmethod
def create(cls, part, manufacturer, mpn, description, link=None):
""" Check if ManufacturerPart instance does not already exist
then create it
"""
"""Check if ManufacturerPart instance does not already exist then create it."""
manufacturer_part = None
try:
@ -364,6 +296,7 @@ class ManufacturerPart(models.Model):
return manufacturer_part
def __str__(self):
"""Format a string representation of a ManufacturerPart"""
s = ''
if self.manufacturer:
@ -376,15 +309,15 @@ class ManufacturerPart(models.Model):
class ManufacturerPartAttachment(InvenTreeAttachment):
"""
Model for storing file attachments against a ManufacturerPart object
"""
"""Model for storing file attachments against a ManufacturerPart object."""
@staticmethod
def get_api_url():
"""Return the API URL associated with the ManufacturerPartAttachment model"""
return reverse('api-manufacturer-part-attachment-list')
def getSubdir(self):
"""Return the subdirectory where attachment files for the ManufacturerPart model are located"""
return os.path.join("manufacturer_part_files", str(self.manufacturer_part.id))
manufacturer_part = models.ForeignKey(ManufacturerPart, on_delete=models.CASCADE,
@ -392,8 +325,7 @@ class ManufacturerPartAttachment(InvenTreeAttachment):
class ManufacturerPartParameter(models.Model):
"""
A ManufacturerPartParameter represents a key:value parameter for a MnaufacturerPart.
"""A ManufacturerPartParameter represents a key:value parameter for a MnaufacturerPart.
This is used to represent parmeters / properties for a particular manufacturer part.
@ -402,9 +334,11 @@ class ManufacturerPartParameter(models.Model):
@staticmethod
def get_api_url():
"""Return the API URL associated with the ManufacturerPartParameter model"""
return reverse('api-manufacturer-part-parameter-list')
class Meta:
"""Metaclass defines extra model options"""
unique_together = ('manufacturer_part', 'name')
manufacturer_part = models.ForeignKey(
@ -437,13 +371,14 @@ class ManufacturerPartParameter(models.Model):
class SupplierPartManager(models.Manager):
""" Define custom SupplierPart objects manager
"""Define custom SupplierPart objects manager.
The main purpose of this manager is to improve database hit as the
SupplierPart model involves A LOT of foreign keys lookups
"""
def get_queryset(self):
"""Prefetch related fields when querying against the SupplierPart model"""
# Always prefetch related models
return super().get_queryset().prefetch_related(
'part',
@ -453,10 +388,7 @@ class SupplierPartManager(models.Manager):
class SupplierPart(models.Model):
""" Represents a unique part as provided by a Supplier
Each SupplierPart is identified by a SKU (Supplier Part Number)
Each SupplierPart is also linked to a Part or ManufacturerPart object.
A Part may be available from multiple suppliers
"""Represents a unique part as provided by a Supplier Each SupplierPart is identified by a SKU (Supplier Part Number) Each SupplierPart is also linked to a Part or ManufacturerPart object. A Part may be available from multiple suppliers.
Attributes:
part: Link to the master Part (Obsolete)
@ -476,13 +408,15 @@ class SupplierPart(models.Model):
@staticmethod
def get_api_url():
"""Return the API URL associated with the SupplierPart model"""
return reverse('api-supplier-part-list')
def get_absolute_url(self):
"""Return the web URL of the detail view for this SupplierPart"""
return reverse('supplier-part-detail', kwargs={'pk': self.id})
def api_instance_filters(self):
"""Return custom API filters for this particular instance"""
return {
'manufacturer_part': {
'part': self.part.pk
@ -490,13 +424,17 @@ class SupplierPart(models.Model):
}
class Meta:
"""Metaclass defines extra model options"""
unique_together = ('part', 'supplier', 'SKU')
# This model was moved from the 'Part' app
db_table = 'part_supplierpart'
def clean(self):
"""Custom clean action for the SupplierPart model:
- Ensure that manufacturer_part.part and part are the same!
"""
super().clean()
# Ensure that the linked manufacturer_part points to the same part!
@ -508,8 +446,7 @@ class SupplierPart(models.Model):
})
def save(self, *args, **kwargs):
""" Overriding save method to connect an existing ManufacturerPart """
"""Overriding save method to connect an existing ManufacturerPart."""
manufacturer_part = None
if all(key in kwargs for key in ('manufacturer', 'MPN')):
@ -593,10 +530,10 @@ class SupplierPart(models.Model):
@property
def manufacturer_string(self):
""" Format a MPN string for this SupplierPart.
"""Format a MPN string for this SupplierPart.
Concatenates manufacture name and part number.
"""
items = []
if self.manufacturer_part:
@ -609,26 +546,26 @@ class SupplierPart(models.Model):
@property
def has_price_breaks(self):
"""Return True if this SupplierPart has associated price breaks"""
return self.price_breaks.count() > 0
@property
def price_breaks(self):
""" Return the associated price breaks in the correct order """
"""Return the associated price breaks in the correct order."""
return self.pricebreaks.order_by('quantity').all()
@property
def unit_pricing(self):
"""Return the single-quantity pricing for this SupplierPart"""
return self.get_price(1)
def add_price_break(self, quantity, price):
"""
Create a new price break for this part
def add_price_break(self, quantity, price) -> None:
"""Create a new price break for this part.
args:
quantity - Numerical quantity
price - Must be a Money object
Args:
quantity: Numerical quantity
price: Must be a Money object
"""
# Check if a price break at that quantity already exists...
if self.price_breaks.filter(quantity=quantity, part=self.pk).exists():
return
@ -642,18 +579,14 @@ class SupplierPart(models.Model):
get_price = common.models.get_price
def open_orders(self):
""" Return a database query for PurchaseOrder line items for this SupplierPart,
limited to purchase orders that are open / outstanding.
"""
"""Return a database query for PurchaseOrder line items for this SupplierPart, limited to purchase orders that are open / outstanding."""
return self.purchase_order_line_items.prefetch_related('order').filter(order__status__in=PurchaseOrderStatus.OPEN)
def on_order(self):
""" Return the total quantity of items currently on order.
"""Return the total quantity of items currently on order.
Subtract partially received stock as appropriate
"""
totals = self.open_orders().aggregate(Sum('quantity'), Sum('received'))
# Quantity on order
@ -668,15 +601,16 @@ class SupplierPart(models.Model):
return max(q - r, 0)
def purchase_orders(self):
""" Returns a list of purchase orders relating to this supplier part """
"""Returns a list of purchase orders relating to this supplier part."""
return [line.order for line in self.purchase_order_line_items.all().prefetch_related('order')]
@property
def pretty_name(self):
"""Format a 'pretty' name for this SupplierPart"""
return str(self)
def __str__(self):
"""Format a string representation of a SupplierPart"""
s = ''
if self.part.IPN:
@ -692,7 +626,8 @@ class SupplierPart(models.Model):
class SupplierPriceBreak(common.models.PriceBreak):
""" Represents a quantity price break for a SupplierPart.
"""Represents a quantity price break for a SupplierPart.
- Suppliers can offer discounts at larger quantities
- SupplierPart(s) may have zero-or-more associated SupplierPriceBreak(s)
@ -706,6 +641,7 @@ class SupplierPriceBreak(common.models.PriceBreak):
@staticmethod
def get_api_url():
"""Return the API URL associated with the SupplierPriceBreak model"""
return reverse('api-part-supplier-price-list')
part = models.ForeignKey(SupplierPart, on_delete=models.CASCADE, related_name='pricebreaks', verbose_name=_('Part'),)
@ -713,10 +649,12 @@ class SupplierPriceBreak(common.models.PriceBreak):
updated = models.DateTimeField(auto_now=True, null=True, verbose_name=_('last updated'))
class Meta:
"""Metaclass defines extra model options"""
unique_together = ("part", "quantity")
# This model was moved from the 'Part' app
db_table = 'part_supplierpricebreak'
def __str__(self):
"""Format a string representation of a SupplierPriceBreak instance"""
return f'{self.part.SKU} - {self.price} @ {self.quantity}'

View File

@ -1,6 +1,4 @@
"""
JSON serializers for Company app
"""
"""JSON serializers for Company app."""
from django.utils.translation import gettext_lazy as _
@ -21,13 +19,15 @@ from .models import (Company, ManufacturerPart, ManufacturerPartAttachment,
class CompanyBriefSerializer(InvenTreeModelSerializer):
""" Serializer for Company object (limited detail) """
"""Serializer for Company object (limited detail)"""
url = serializers.CharField(source='get_absolute_url', read_only=True)
image = serializers.CharField(source='get_thumbnail_url', read_only=True)
class Meta:
"""Metaclass options."""
model = Company
fields = [
'pk',
@ -39,11 +39,11 @@ class CompanyBriefSerializer(InvenTreeModelSerializer):
class CompanySerializer(InvenTreeModelSerializer):
""" Serializer for Company object (full detail) """
"""Serializer for Company object (full detail)"""
@staticmethod
def annotate_queryset(queryset):
"""Annoate the supplied queryset with aggregated information"""
# Add count of parts manufactured
queryset = queryset.annotate(
parts_manufactured=SubqueryCount('manufactured_parts')
@ -71,6 +71,8 @@ class CompanySerializer(InvenTreeModelSerializer):
)
class Meta:
"""Metaclass options."""
model = Company
fields = [
'pk',
@ -96,9 +98,7 @@ class CompanySerializer(InvenTreeModelSerializer):
class ManufacturerPartSerializer(InvenTreeModelSerializer):
"""
Serializer for ManufacturerPart object
"""
"""Serializer for ManufacturerPart object."""
part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
@ -107,7 +107,7 @@ class ManufacturerPartSerializer(InvenTreeModelSerializer):
pretty_name = serializers.CharField(read_only=True)
def __init__(self, *args, **kwargs):
"""Initialize this serializer with extra detail fields as required"""
part_detail = kwargs.pop('part_detail', True)
manufacturer_detail = kwargs.pop('manufacturer_detail', True)
prettify = kwargs.pop('pretty', False)
@ -126,6 +126,8 @@ class ManufacturerPartSerializer(InvenTreeModelSerializer):
manufacturer = serializers.PrimaryKeyRelatedField(queryset=Company.objects.filter(is_manufacturer=True))
class Meta:
"""Metaclass options."""
model = ManufacturerPart
fields = [
'pk',
@ -141,11 +143,11 @@ class ManufacturerPartSerializer(InvenTreeModelSerializer):
class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer):
"""
Serializer for the ManufacturerPartAttachment class
"""
"""Serializer for the ManufacturerPartAttachment class."""
class Meta:
"""Metaclass options."""
model = ManufacturerPartAttachment
fields = [
@ -164,14 +166,12 @@ class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer):
class ManufacturerPartParameterSerializer(InvenTreeModelSerializer):
"""
Serializer for the ManufacturerPartParameter model
"""
"""Serializer for the ManufacturerPartParameter model."""
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', many=False, read_only=True)
def __init__(self, *args, **kwargs):
"""Initialize this serializer with extra detail fields as required"""
man_detail = kwargs.pop('manufacturer_part_detail', False)
super(ManufacturerPartParameterSerializer, self).__init__(*args, **kwargs)
@ -180,6 +180,8 @@ class ManufacturerPartParameterSerializer(InvenTreeModelSerializer):
self.fields.pop('manufacturer_part_detail')
class Meta:
"""Metaclass options."""
model = ManufacturerPartParameter
fields = [
@ -193,7 +195,7 @@ class ManufacturerPartParameterSerializer(InvenTreeModelSerializer):
class SupplierPartSerializer(InvenTreeModelSerializer):
""" Serializer for SupplierPart object """
"""Serializer for SupplierPart object."""
part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
@ -204,7 +206,7 @@ class SupplierPartSerializer(InvenTreeModelSerializer):
pretty_name = serializers.CharField(read_only=True)
def __init__(self, *args, **kwargs):
"""Initialize this serializer with extra detail fields as required"""
part_detail = kwargs.pop('part_detail', True)
supplier_detail = kwargs.pop('supplier_detail', True)
manufacturer_detail = kwargs.pop('manufacturer_detail', True)
@ -234,6 +236,8 @@ class SupplierPartSerializer(InvenTreeModelSerializer):
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', read_only=True)
class Meta:
"""Metaclass options."""
model = SupplierPart
fields = [
'description',
@ -255,8 +259,7 @@ class SupplierPartSerializer(InvenTreeModelSerializer):
]
def create(self, validated_data):
""" Extract manufacturer data and process ManufacturerPart """
"""Extract manufacturer data and process ManufacturerPart."""
# Create SupplierPart
supplier_part = super().create(validated_data)
@ -275,7 +278,7 @@ class SupplierPartSerializer(InvenTreeModelSerializer):
class SupplierPriceBreakSerializer(InvenTreeModelSerializer):
""" Serializer for SupplierPriceBreak object """
"""Serializer for SupplierPriceBreak object."""
quantity = InvenTreeDecimalField()
@ -292,6 +295,8 @@ class SupplierPriceBreakSerializer(InvenTreeModelSerializer):
)
class Meta:
"""Metaclass options."""
model = SupplierPriceBreak
fields = [
'pk',

View File

@ -1,3 +1,5 @@
"""Unit testing for the company app API functions"""
from django.urls import reverse
from rest_framework import status
@ -8,9 +10,7 @@ from .models import Company
class CompanyTest(InvenTreeAPITestCase):
"""
Series of tests for the Company DRF API
"""
"""Series of tests for the Company DRF API."""
roles = [
'purchase_order.add',
@ -18,7 +18,7 @@ class CompanyTest(InvenTreeAPITestCase):
]
def setUp(self):
"""Perform initialization for the unit test class"""
super().setUp()
self.acme = Company.objects.create(name='ACME', description='Supplier', is_customer=False, is_supplier=True)
@ -26,6 +26,7 @@ class CompanyTest(InvenTreeAPITestCase):
Company.objects.create(name='Sippy Cup Emporium', description='Another supplier')
def test_company_list(self):
"""Test the list API endpoint for the Company model"""
url = reverse('api-company-list')
# There should be three companies
@ -45,10 +46,7 @@ class CompanyTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 2)
def test_company_detail(self):
"""
Tests for the Company detail endpoint
"""
"""Tests for the Company detail endpoint."""
url = reverse('api-company-detail', kwargs={'pk': self.acme.pk})
response = self.get(url)
@ -71,20 +69,14 @@ class CompanyTest(InvenTreeAPITestCase):
self.assertEqual(response.data['currency'], 'NZD')
def test_company_search(self):
"""
Test search functionality in company list
"""
"""Test search functionality in company list."""
url = reverse('api-company-list')
data = {'search': 'cup'}
response = self.get(url, data)
self.assertEqual(len(response.data), 2)
def test_company_create(self):
"""
Test that we can create a company via the API!
"""
"""Test that we can create a company via the API!"""
url = reverse('api-company-list')
# Name is required
@ -146,9 +138,7 @@ class CompanyTest(InvenTreeAPITestCase):
class ManufacturerTest(InvenTreeAPITestCase):
"""
Series of tests for the Manufacturer DRF API
"""
"""Series of tests for the Manufacturer DRF API."""
fixtures = [
'category',
@ -164,6 +154,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
]
def test_manufacturer_part_list(self):
"""Test the ManufacturerPart API list functionality"""
url = reverse('api-manufacturer-part-list')
# There should be three manufacturer parts
@ -191,9 +182,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 2)
def test_manufacturer_part_detail(self):
"""
Tests for the ManufacturerPart detail endpoint
"""
"""Tests for the ManufacturerPart detail endpoint."""
url = reverse('api-manufacturer-part-detail', kwargs={'pk': 1})
response = self.get(url)
@ -210,13 +199,14 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(response.data['MPN'], 'MPN-TEST-123')
def test_manufacturer_part_search(self):
# Test search functionality in manufacturer list
"""Test search functionality in manufacturer list"""
url = reverse('api-manufacturer-part-list')
data = {'search': 'MPN'}
response = self.get(url, data)
self.assertEqual(len(response.data), 3)
def test_supplier_part_create(self):
"""Test a SupplierPart can be created via the API"""
url = reverse('api-supplier-part-list')
# Create a manufacturer part

View File

@ -1,6 +1,4 @@
"""
Tests for the company model database migrations
"""
"""Tests for the company model database migrations."""
from django_test_migrations.contrib.unittest_case import MigratorTestCase
@ -8,15 +6,13 @@ from InvenTree import helpers
class TestForwardMigrations(MigratorTestCase):
"""Unit testing class for testing 'company' app migrations"""
migrate_from = ('company', helpers.getOldestMigrationFile('company'))
migrate_to = ('company', helpers.getNewestMigrationFile('company'))
def prepare(self):
"""
Create some simple Company data, and ensure that it migrates OK
"""
"""Create some simple Company data, and ensure that it migrates OK."""
Company = self.old_state.apps.get_model('company', 'company')
Company.objects.create(
@ -26,29 +22,25 @@ class TestForwardMigrations(MigratorTestCase):
)
def test_migrations(self):
"""Test the database state after applying all migrations"""
Company = self.new_state.apps.get_model('company', 'company')
self.assertEqual(Company.objects.count(), 1)
class TestManufacturerField(MigratorTestCase):
"""
Tests for migration 0019 which migrates from old 'manufacturer_name' field to new 'manufacturer' field
"""
"""Tests for migration 0019 which migrates from old 'manufacturer_name' field to new 'manufacturer' field."""
migrate_from = ('company', '0018_supplierpart_manufacturer')
migrate_to = ('company', '0019_auto_20200413_0642')
def prepare(self):
"""
Prepare the database by adding some test data 'before' the change:
"""Prepare the database by adding some test data 'before' the change:
- Part object
- Company object (supplier)
- SupplierPart object
"""
Part = self.old_state.apps.get_model('part', 'part')
Company = self.old_state.apps.get_model('company', 'company')
SupplierPart = self.old_state.apps.get_model('company', 'supplierpart')
@ -85,10 +77,7 @@ class TestManufacturerField(MigratorTestCase):
self.assertEqual(Company.objects.count(), 1)
def test_company_objects(self):
"""
Test that the new companies have been created successfully
"""
"""Test that the new companies have been created successfully."""
# Two additional company objects should have been created
Company = self.new_state.apps.get_model('company', 'company')
self.assertEqual(Company.objects.count(), 3)
@ -108,22 +97,18 @@ class TestManufacturerField(MigratorTestCase):
class TestManufacturerPart(MigratorTestCase):
"""
Tests for migration 0034-0037 which added and transitioned to the ManufacturerPart model
"""
"""Tests for migration 0034-0037 which added and transitioned to the ManufacturerPart model."""
migrate_from = ('company', '0033_auto_20210410_1528')
migrate_to = ('company', '0037_supplierpart_update_3')
def prepare(self):
"""
Prepare the database by adding some test data 'before' the change:
"""Prepare the database by adding some test data 'before' the change:
- Part object
- Company object (supplier)
- SupplierPart object
"""
Part = self.old_state.apps.get_model('part', 'part')
Company = self.old_state.apps.get_model('company', 'company')
SupplierPart = self.old_state.apps.get_model('company', 'supplierpart')
@ -214,10 +199,7 @@ class TestManufacturerPart(MigratorTestCase):
)
def test_manufacturer_part_objects(self):
"""
Test that the new companies have been created successfully
"""
"""Test that the new companies have been created successfully."""
# Check on the SupplierPart objects
SupplierPart = self.new_state.apps.get_model('company', 'supplierpart')
@ -238,16 +220,13 @@ class TestManufacturerPart(MigratorTestCase):
class TestCurrencyMigration(MigratorTestCase):
"""
Tests for upgrade from basic currency support to django-money
"""
"""Tests for upgrade from basic currency support to django-money."""
migrate_from = ('company', '0025_auto_20201110_1001')
migrate_to = ('company', '0026_auto_20201110_1011')
def prepare(self):
"""
Prepare some data:
"""Prepare some data:
- A part to buy
- A supplier to buy from
@ -255,7 +234,6 @@ class TestCurrencyMigration(MigratorTestCase):
- Multiple currency objects
- Multiple supplier price breaks
"""
Part = self.old_state.apps.get_model('part', 'part')
part = Part.objects.create(
@ -293,7 +271,7 @@ class TestCurrencyMigration(MigratorTestCase):
self.assertIsNone(pb.price)
def test_currency_migration(self):
"""Test database state after applying migrations"""
PB = self.new_state.apps.get_model('company', 'supplierpricebreak')
for pb in PB.objects.all():

View File

@ -1,11 +1,12 @@
""" Unit tests for Company views (see views.py) """
"""Unit tests for Company views (see views.py)"""
from django.urls import reverse
from InvenTree.helpers import InvenTreeTestCase
class CompanyViewTestBase(InvenTreeTestCase):
class CompanyViewTest(InvenTreeTestCase):
"""Tests for various 'Company' views."""
fixtures = [
'category',
@ -18,40 +19,29 @@ class CompanyViewTestBase(InvenTreeTestCase):
roles = 'all'
class CompanyViewTest(CompanyViewTestBase):
"""
Tests for various 'Company' views
"""
def test_company_index(self):
""" Test the company index """
"""Test the company index."""
response = self.client.get(reverse('company-index'))
self.assertEqual(response.status_code, 200)
def test_manufacturer_index(self):
""" Test the manufacturer index """
"""Test the manufacturer index."""
response = self.client.get(reverse('manufacturer-index'))
self.assertEqual(response.status_code, 200)
def test_customer_index(self):
""" Test the customer index """
"""Test the customer index."""
response = self.client.get(reverse('customer-index'))
self.assertEqual(response.status_code, 200)
def test_manufacturer_part_detail_view(self):
""" Test the manufacturer part detail view """
"""Test the manufacturer part detail view."""
response = self.client.get(reverse('manufacturer-part-detail', kwargs={'pk': 1}))
self.assertEqual(response.status_code, 200)
self.assertContains(response, 'MPN123')
def test_supplier_part_detail_view(self):
""" Test the supplier part detail view """
"""Test the supplier part detail view."""
response = self.client.get(reverse('supplier-part-detail', kwargs={'pk': 10}))
self.assertEqual(response.status_code, 200)
self.assertContains(response, 'MPN456-APPEL')

View File

@ -1,3 +1,5 @@
"""Unit tests for the models in the 'company' app"""
import os
from decimal import Decimal
@ -11,6 +13,7 @@ from .models import (Company, Contact, ManufacturerPart, SupplierPart,
class CompanySimpleTest(TestCase):
"""Unit tests for the Company model"""
fixtures = [
'company',
@ -24,6 +27,7 @@ class CompanySimpleTest(TestCase):
]
def setUp(self):
"""Perform initialization for the tests in this class"""
Company.objects.create(name='ABC Co.',
description='Seller of ABC products',
website='www.abc-sales.com',
@ -37,15 +41,18 @@ class CompanySimpleTest(TestCase):
self.zergm312 = SupplierPart.objects.get(SKU='ZERGM312')
def test_company_model(self):
"""Tests for the company model data"""
c = Company.objects.get(name='ABC Co.')
self.assertEqual(c.name, 'ABC Co.')
self.assertEqual(str(c), 'ABC Co. - Seller of ABC products')
def test_company_url(self):
"""Test the detail URL for a company"""
c = Company.objects.get(pk=1)
self.assertEqual(c.get_absolute_url(), '/company/1/')
def test_image_renamer(self):
"""Test the company image upload functionality"""
c = Company.objects.get(pk=1)
rn = rename_company_image(c, 'test.png')
self.assertEqual(rn, 'company_images' + os.path.sep + 'company_1_img.png')
@ -53,23 +60,8 @@ class CompanySimpleTest(TestCase):
rn = rename_company_image(c, 'test2')
self.assertEqual(rn, 'company_images' + os.path.sep + 'company_1_img')
def test_part_count(self):
acme = Company.objects.get(pk=1)
appel = Company.objects.get(pk=2)
zerg = Company.objects.get(pk=3)
self.assertTrue(acme.has_parts)
self.assertEqual(acme.supplied_part_count, 4)
self.assertTrue(appel.has_parts)
self.assertEqual(appel.supplied_part_count, 4)
self.assertTrue(zerg.has_parts)
self.assertEqual(zerg.supplied_part_count, 2)
def test_price_breaks(self):
"""Unit tests for price breaks"""
self.assertTrue(self.acme0001.has_price_breaks)
self.assertTrue(self.acme0002.has_price_breaks)
self.assertTrue(self.zergm312.has_price_breaks)
@ -81,8 +73,7 @@ class CompanySimpleTest(TestCase):
self.assertEqual(self.zergm312.price_breaks.count(), 2)
def test_quantity_pricing(self):
""" Simple test for quantity pricing """
"""Simple test for quantity pricing."""
p = self.acme0001.get_price
self.assertEqual(p(1), 10)
self.assertEqual(p(4), 40)
@ -99,6 +90,7 @@ class CompanySimpleTest(TestCase):
self.assertEqual(p(55), 68.75)
def test_part_pricing(self):
"""Unit tests for supplier part pricing"""
m2x4 = Part.objects.get(name='M2x4 LPHS')
self.assertEqual(m2x4.get_price_info(5.5), "38.5 - 41.25")
@ -116,10 +108,7 @@ class CompanySimpleTest(TestCase):
self.assertIsNotNone(m3x12.get_price_info(50))
def test_currency_validation(self):
"""
Test validation for currency selection
"""
"""Test validation for currency selection."""
# Create a company with a valid currency code (should pass)
company = Company.objects.create(
name='Test',
@ -141,8 +130,10 @@ class CompanySimpleTest(TestCase):
class ContactSimpleTest(TestCase):
"""Unit tests for the Contact model"""
def setUp(self):
"""Initialization for the tests in this class"""
# Create a simple company
self.c = Company.objects.create(name='Test Corp.', description='We make stuff good')
@ -152,15 +143,18 @@ class ContactSimpleTest(TestCase):
Contact.objects.create(name='Sally Smith', company=self.c)
def test_exists(self):
"""Test that contacts exist"""
self.assertEqual(Contact.objects.count(), 3)
def test_delete(self):
"""Test deletion of a Contact instance"""
# Remove the parent company
Company.objects.get(pk=self.c.pk).delete()
self.assertEqual(Contact.objects.count(), 0)
class ManufacturerPartSimpleTest(TestCase):
"""Unit tests for the ManufacturerPart model"""
fixtures = [
'category',
@ -171,6 +165,8 @@ class ManufacturerPartSimpleTest(TestCase):
]
def setUp(self):
"""Initialization for the unit tests in this class"""
# Create a manufacturer part
self.part = Part.objects.get(pk=1)
manufacturer = Company.objects.get(pk=1)
@ -193,6 +189,7 @@ class ManufacturerPartSimpleTest(TestCase):
supplier_part.save()
def test_exists(self):
"""That that a ManufacturerPart has been created"""
self.assertEqual(ManufacturerPart.objects.count(), 4)
# Check that manufacturer part was created from supplier part creation
@ -200,7 +197,7 @@ class ManufacturerPartSimpleTest(TestCase):
self.assertEqual(manufacturer_parts.count(), 1)
def test_delete(self):
# Remove a part
"""Test deletion of a ManufacturerPart"""
Part.objects.get(pk=self.part.id).delete()
# Check that ManufacturerPart was deleted
self.assertEqual(ManufacturerPart.objects.count(), 3)

View File

@ -1,6 +1,4 @@
"""
URL lookup for Company app
"""
"""URL lookup for Company app."""
from django.urls import include, re_path

View File

@ -1,6 +1,4 @@
"""
Django views for interacting with Company app
"""
"""Django views for interacting with Company app."""
import io
@ -20,8 +18,7 @@ from .models import Company, ManufacturerPart, SupplierPart
class CompanyIndex(InvenTreeRoleMixin, ListView):
""" View for displaying list of companies
"""
"""View for displaying list of companies."""
model = Company
template_name = 'company/index.html'
@ -30,6 +27,7 @@ class CompanyIndex(InvenTreeRoleMixin, ListView):
permission_required = 'company.view_company'
def get_context_data(self, **kwargs):
"""Add extra context data to the company index page"""
ctx = super().get_context_data(**kwargs)
@ -80,7 +78,7 @@ class CompanyIndex(InvenTreeRoleMixin, ListView):
return ctx
def get_queryset(self):
""" Retrieve the Company queryset based on HTTP request parameters.
"""Retrieve the Company queryset based on HTTP request parameters.
- supplier: Filter by supplier
- customer: Filter by customer
@ -97,23 +95,16 @@ class CompanyIndex(InvenTreeRoleMixin, ListView):
class CompanyDetail(InvenTreePluginViewMixin, DetailView):
""" Detail view for Company object """
"""Detail view for Company object."""
context_obect_name = 'company'
template_name = 'company/detail.html'
queryset = Company.objects.all()
model = Company
permission_required = 'company.view_company'
def get_context_data(self, **kwargs):
ctx = super().get_context_data(**kwargs)
return ctx
class CompanyImageDownloadFromURL(AjaxUpdateView):
"""
View for downloading an image from a provided URL
"""
"""View for downloading an image from a provided URL."""
model = Company
ajax_template_name = 'image_download.html'
@ -121,9 +112,7 @@ class CompanyImageDownloadFromURL(AjaxUpdateView):
ajax_form_title = _('Download Image')
def validate(self, company, form):
"""
Validate that the image data are correct
"""
"""Validate that the image data are correct."""
# First ensure that the normal validation routines pass
if not form.is_valid():
return
@ -167,9 +156,7 @@ class CompanyImageDownloadFromURL(AjaxUpdateView):
return
def save(self, company, form, **kwargs):
"""
Save the downloaded image to the company
"""
"""Save the downloaded image to the company."""
fmt = self.image.format
if not fmt:
@ -189,28 +176,18 @@ class CompanyImageDownloadFromURL(AjaxUpdateView):
class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView):
""" Detail view for ManufacturerPart """
"""Detail view for ManufacturerPart."""
model = ManufacturerPart
template_name = 'company/manufacturer_part_detail.html'
context_object_name = 'part'
queryset = ManufacturerPart.objects.all()
permission_required = 'purchase_order.view'
def get_context_data(self, **kwargs):
ctx = super().get_context_data(**kwargs)
return ctx
class SupplierPartDetail(InvenTreePluginViewMixin, DetailView):
""" Detail view for SupplierPart """
"""Detail view for SupplierPart."""
model = SupplierPart
template_name = 'company/supplier_part_detail.html'
context_object_name = 'part'
queryset = SupplierPart.objects.all()
permission_required = 'purchase_order.view'
def get_context_data(self, **kwargs):
ctx = super().get_context_data(**kwargs)
return ctx

View File

@ -1,3 +1,5 @@
"""Gunicorn configuration script for InvenTree web server"""
import multiprocessing
bind = "0.0.0.0:8000"

View File

@ -1,10 +1,12 @@
"""Admin functionality for the 'label' app"""
from django.contrib import admin
from .models import PartLabel, StockItemLabel, StockLocationLabel
class LabelAdmin(admin.ModelAdmin):
"""Admin class for the various label models"""
list_display = ('name', 'description', 'label', 'filters', 'enabled')

View File

@ -1,3 +1,5 @@
"""API functionality for the 'label' app"""
from django.conf import settings
from django.core.exceptions import FieldError, ValidationError
from django.http import HttpResponse, JsonResponse
@ -21,9 +23,7 @@ from .serializers import (PartLabelSerializer, StockItemLabelSerializer,
class LabelListView(generics.ListAPIView):
"""
Generic API class for label templates
"""
"""Generic API class for label templates."""
filter_backends = [
DjangoFilterBackend,
@ -41,13 +41,11 @@ class LabelListView(generics.ListAPIView):
class LabelPrintMixin:
"""
Mixin for printing labels
"""
"""Mixin for printing labels."""
def get_plugin(self, request):
"""
Return the label printing plugin associated with this request.
"""Return the label printing plugin associated with this request.
This is provided in the url, e.g. ?plugin=myprinter
Requires:
@ -56,7 +54,6 @@ class LabelPrintMixin:
- matching plugin implements the 'labels' mixin
- matching plugin is enabled
"""
if not settings.PLUGINS_ENABLED:
return None # pragma: no cover
@ -80,10 +77,7 @@ class LabelPrintMixin:
raise NotFound(f"Plugin '{plugin_key}' not found")
def print(self, request, items_to_print):
"""
Print this label template against a number of pre-validated items
"""
"""Print this label template against a number of pre-validated items."""
# Check the request to determine if the user has selected a label printing plugin
plugin = self.get_plugin(request)
@ -119,26 +113,20 @@ class LabelPrintMixin:
label_name += ".pdf"
if plugin is not None:
"""
Label printing is to be handled by a plugin,
rather than being exported to PDF.
"""Label printing is to be handled by a plugin, rather than being exported to PDF.
In this case, we do the following:
- Individually generate each label, exporting as an image file
- Pass all the images through to the label printing plugin
- Return a JSON response indicating that the printing has been offloaded
"""
# Label instance
label_instance = self.get_object()
for idx, output in enumerate(outputs):
"""
For each output, we generate a temporary image file,
which will then get sent to the printer
"""
"""For each output, we generate a temporary image file, which will then get sent to the printer."""
# Generate PDF data for the label
pdf = output.get_document().write_pdf()
@ -159,20 +147,14 @@ class LabelPrintMixin:
})
elif debug_mode:
"""
Contatenate all rendered templates into a single HTML string,
and return the string as a HTML response.
"""
"""Contatenate all rendered templates into a single HTML string, and return the string as a HTML response."""
html = "\n".join(outputs)
return HttpResponse(html)
else:
"""
Concatenate all rendered pages into a single PDF object,
and return the resulting document!
"""
"""Concatenate all rendered pages into a single PDF object, and return the resulting document!"""
pages = []
@ -198,15 +180,10 @@ class LabelPrintMixin:
class StockItemLabelMixin:
"""
Mixin for extracting stock items from query params
"""
"""Mixin for extracting stock items from query params."""
def get_items(self):
"""
Return a list of requested stock items
"""
"""Return a list of requested stock items."""
items = []
params = self.request.query_params
@ -231,25 +208,20 @@ class StockItemLabelMixin:
class StockItemLabelList(LabelListView, StockItemLabelMixin):
"""
API endpoint for viewing list of StockItemLabel objects.
"""API endpoint for viewing list of StockItemLabel objects.
Filterable by:
- enabled: Filter by enabled / disabled status
- item: Filter by single stock item
- items: Filter by list of stock items
"""
queryset = StockItemLabel.objects.all()
serializer_class = StockItemLabelSerializer
def filter_queryset(self, queryset):
"""
Filter the StockItem label queryset.
"""
"""Filter the StockItem label queryset."""
queryset = super().filter_queryset(queryset)
# List of StockItem objects to match against
@ -304,42 +276,30 @@ class StockItemLabelList(LabelListView, StockItemLabelMixin):
class StockItemLabelDetail(generics.RetrieveUpdateDestroyAPIView):
"""
API endpoint for a single StockItemLabel object
"""
"""API endpoint for a single StockItemLabel object."""
queryset = StockItemLabel.objects.all()
serializer_class = StockItemLabelSerializer
class StockItemLabelPrint(generics.RetrieveAPIView, StockItemLabelMixin, LabelPrintMixin):
"""
API endpoint for printing a StockItemLabel object
"""
"""API endpoint for printing a StockItemLabel object."""
queryset = StockItemLabel.objects.all()
serializer_class = StockItemLabelSerializer
def get(self, request, *args, **kwargs):
"""
Check if valid stock item(s) have been provided.
"""
"""Check if valid stock item(s) have been provided."""
items = self.get_items()
return self.print(request, items)
class StockLocationLabelMixin:
"""
Mixin for extracting stock locations from query params
"""
"""Mixin for extracting stock locations from query params."""
def get_locations(self):
"""
Return a list of requested stock locations
"""
"""Return a list of requested stock locations."""
locations = []
params = self.request.query_params
@ -364,8 +324,7 @@ class StockLocationLabelMixin:
class StockLocationLabelList(LabelListView, StockLocationLabelMixin):
"""
API endpoint for viewiing list of StockLocationLabel objects.
"""API endpoint for viewiing list of StockLocationLabel objects.
Filterable by:
@ -378,10 +337,7 @@ class StockLocationLabelList(LabelListView, StockLocationLabelMixin):
serializer_class = StockLocationLabelSerializer
def filter_queryset(self, queryset):
"""
Filter the StockLocationLabel queryset
"""
"""Filter the StockLocationLabel queryset."""
queryset = super().filter_queryset(queryset)
# List of StockLocation objects to match against
@ -436,39 +392,30 @@ class StockLocationLabelList(LabelListView, StockLocationLabelMixin):
class StockLocationLabelDetail(generics.RetrieveUpdateDestroyAPIView):
"""
API endpoint for a single StockLocationLabel object
"""
"""API endpoint for a single StockLocationLabel object."""
queryset = StockLocationLabel.objects.all()
serializer_class = StockLocationLabelSerializer
class StockLocationLabelPrint(generics.RetrieveAPIView, StockLocationLabelMixin, LabelPrintMixin):
"""
API endpoint for printing a StockLocationLabel object
"""
"""API endpoint for printing a StockLocationLabel object."""
queryset = StockLocationLabel.objects.all()
seiralizer_class = StockLocationLabelSerializer
def get(self, request, *args, **kwargs):
"""Print labels based on the request parameters"""
locations = self.get_locations()
return self.print(request, locations)
class PartLabelMixin:
"""
Mixin for extracting Part objects from query parameters
"""
"""Mixin for extracting Part objects from query parameters."""
def get_parts(self):
"""
Return a list of requested Part objects
"""
"""Return a list of requested Part objects."""
parts = []
params = self.request.query_params
@ -491,15 +438,13 @@ class PartLabelMixin:
class PartLabelList(LabelListView, PartLabelMixin):
"""
API endpoint for viewing list of PartLabel objects
"""
"""API endpoint for viewing list of PartLabel objects."""
queryset = PartLabel.objects.all()
serializer_class = PartLabelSerializer
def filter_queryset(self, queryset):
"""Custom queryset filtering for the PartLabel list"""
queryset = super().filter_queryset(queryset)
parts = self.get_parts()
@ -539,27 +484,20 @@ class PartLabelList(LabelListView, PartLabelMixin):
class PartLabelDetail(generics.RetrieveUpdateDestroyAPIView):
"""
API endpoint for a single PartLabel object
"""
"""API endpoint for a single PartLabel object."""
queryset = PartLabel.objects.all()
serializer_class = PartLabelSerializer
class PartLabelPrint(generics.RetrieveAPIView, PartLabelMixin, LabelPrintMixin):
"""
API endpoint for printing a PartLabel object
"""
"""API endpoint for printing a PartLabel object."""
queryset = PartLabel.objects.all()
serializer_class = PartLabelSerializer
def get(self, request, *args, **kwargs):
"""
Check if valid part(s) have been provided
"""
"""Check if valid part(s) have been provided."""
parts = self.get_parts()
return self.print(request, parts)

View File

@ -1,3 +1,5 @@
"""label app specification"""
import hashlib
import logging
import os
@ -14,10 +16,7 @@ logger = logging.getLogger("inventree")
def hashFile(filename):
"""
Calculate the MD5 hash of a file
"""
"""Calculate the MD5 hash of a file."""
md5 = hashlib.md5()
with open(filename, 'rb') as f:
@ -28,20 +27,17 @@ def hashFile(filename):
class LabelConfig(AppConfig):
"""App configuration class for the 'label' app"""
name = 'label'
def ready(self):
"""
This function is called whenever the label app is loaded
"""
"""This function is called whenever the label app is loaded."""
if canAppAccessDatabase():
self.create_labels() # pragma: no cover
def create_labels(self):
"""
Create all default templates
"""
"""Create all default templates."""
# Test if models are ready
try:
from .models import StockLocationLabel
@ -56,11 +52,7 @@ class LabelConfig(AppConfig):
self.create_part_labels()
def create_stock_item_labels(self):
"""
Create database entries for the default StockItemLabel templates,
if they do not already exist
"""
"""Create database entries for the default StockItemLabel templates, if they do not already exist."""
from .models import StockItemLabel
src_dir = os.path.join(
@ -139,11 +131,7 @@ class LabelConfig(AppConfig):
)
def create_stock_location_labels(self):
"""
Create database entries for the default StockItemLocation templates,
if they do not already exist
"""
"""Create database entries for the default StockItemLocation templates, if they do not already exist."""
from .models import StockLocationLabel
src_dir = os.path.join(
@ -229,11 +217,7 @@ class LabelConfig(AppConfig):
)
def create_part_labels(self):
"""
Create database entries for the default PartLabel templates,
if they do not already exist.
"""
"""Create database entries for the default PartLabel templates, if they do not already exist."""
from .models import PartLabel
src_dir = os.path.join(

View File

@ -1,6 +1,4 @@
"""
Label printing models
"""
"""Label printing models."""
import datetime
import logging
@ -32,55 +30,52 @@ logger = logging.getLogger("inventree")
def rename_label(instance, filename):
""" Place the label file into the correct subdirectory """
"""Place the label file into the correct subdirectory."""
filename = os.path.basename(filename)
return os.path.join('label', 'template', instance.SUBDIR, filename)
def validate_stock_item_filters(filters):
"""Validate query filters for the StockItemLabel model"""
filters = validateFilterString(filters, model=stock.models.StockItem)
return filters
def validate_stock_location_filters(filters):
"""Validate query filters for the StockLocationLabel model"""
filters = validateFilterString(filters, model=stock.models.StockLocation)
return filters
def validate_part_filters(filters):
"""Validate query filters for the PartLabel model"""
filters = validateFilterString(filters, model=part.models.Part)
return filters
class WeasyprintLabelMixin(WeasyTemplateResponseMixin):
"""
Class for rendering a label to a PDF
"""
"""Class for rendering a label to a PDF."""
pdf_filename = 'label.pdf'
pdf_attachment = True
def __init__(self, request, template, **kwargs):
"""Initialize a label mixin with certain properties"""
self.request = request
self.template_name = template
self.pdf_filename = kwargs.get('filename', 'label.pdf')
class LabelTemplate(models.Model):
"""
Base class for generic, filterable labels.
"""
"""Base class for generic, filterable labels."""
class Meta:
"""Metaclass options. Abstract ensures no database table is created."""
abstract = True
# Each class of label files will be stored in a separate subdirectory
@ -91,9 +86,11 @@ class LabelTemplate(models.Model):
@property
def template(self):
"""Return the file path of the template associated with this label instance"""
return self.label.path
def __str__(self):
"""Format a string representation of a label instance"""
return "{n} - {d}".format(
n=self.name,
d=self.description
@ -150,11 +147,10 @@ class LabelTemplate(models.Model):
@property
def template_name(self):
"""
Returns the file system path to the template file.
"""Returns the file system path to the template file.
Required for passing the file to an external process
"""
template = self.label.name
template = template.replace('/', os.path.sep)
template = template.replace('\\', os.path.sep)
@ -164,19 +160,14 @@ class LabelTemplate(models.Model):
return template
def get_context_data(self, request):
"""
Supply custom context data to the template for rendering.
"""Supply custom context data to the template for rendering.
Note: Override this in any subclass
"""
return {} # pragma: no cover
def generate_filename(self, request, **kwargs):
"""
Generate a filename for this label
"""
"""Generate a filename for this label."""
template_string = Template(self.filename_pattern)
ctx = self.context(request)
@ -186,10 +177,7 @@ class LabelTemplate(models.Model):
return template_string.render(context)
def context(self, request):
"""
Provides context data to the template.
"""
"""Provides context data to the template."""
context = self.get_context_data(request)
# Add "basic" context data which gets passed to every label
@ -204,21 +192,17 @@ class LabelTemplate(models.Model):
return context
def render_as_string(self, request, **kwargs):
"""
Render the label to a HTML string
"""Render the label to a HTML string.
Useful for debug mode (viewing generated code)
"""
return render_to_string(self.template_name, self.context(request), request)
def render(self, request, **kwargs):
"""
Render the label template to a PDF file
"""Render the label template to a PDF file.
Uses django-weasyprint plugin to render HTML template
"""
wp = WeasyprintLabelMixin(
request,
self.template_name,
@ -235,12 +219,11 @@ class LabelTemplate(models.Model):
class StockItemLabel(LabelTemplate):
"""
Template for printing StockItem labels
"""
"""Template for printing StockItem labels."""
@staticmethod
def get_api_url():
"""Return the API URL associated with the StockItemLabel model"""
return reverse('api-stockitem-label-list') # pragma: no cover
SUBDIR = "stockitem"
@ -255,10 +238,7 @@ class StockItemLabel(LabelTemplate):
)
def get_context_data(self, request):
"""
Generate context data for each provided StockItem
"""
"""Generate context data for each provided StockItem."""
stock_item = self.object_to_print
return {
@ -279,12 +259,11 @@ class StockItemLabel(LabelTemplate):
class StockLocationLabel(LabelTemplate):
"""
Template for printing StockLocation labels
"""
"""Template for printing StockLocation labels."""
@staticmethod
def get_api_url():
"""Return the API URL associated with the StockLocationLabel model"""
return reverse('api-stocklocation-label-list') # pragma: no cover
SUBDIR = "stocklocation"
@ -298,10 +277,7 @@ class StockLocationLabel(LabelTemplate):
)
def get_context_data(self, request):
"""
Generate context data for each provided StockLocation
"""
"""Generate context data for each provided StockLocation."""
location = self.object_to_print
return {
@ -311,12 +287,11 @@ class StockLocationLabel(LabelTemplate):
class PartLabel(LabelTemplate):
"""
Template for printing Part labels
"""
"""Template for printing Part labels."""
@staticmethod
def get_api_url():
"""Return the API url associated with the PartLabel model"""
return reverse('api-part-label-list') # pragma: no cover
SUBDIR = 'part'
@ -331,10 +306,7 @@ class PartLabel(LabelTemplate):
)
def get_context_data(self, request):
"""
Generate context data for each provided Part object
"""
"""Generate context data for each provided Part object."""
part = self.object_to_print
return {

View File

@ -1,3 +1,5 @@
"""API serializers for the label app"""
from InvenTree.serializers import (InvenTreeAttachmentSerializerField,
InvenTreeModelSerializer)
@ -5,13 +7,13 @@ from .models import PartLabel, StockItemLabel, StockLocationLabel
class StockItemLabelSerializer(InvenTreeModelSerializer):
"""
Serializes a StockItemLabel object.
"""
"""Serializes a StockItemLabel object."""
label = InvenTreeAttachmentSerializerField(required=True)
class Meta:
"""Metaclass options."""
model = StockItemLabel
fields = [
'pk',
@ -24,13 +26,13 @@ class StockItemLabelSerializer(InvenTreeModelSerializer):
class StockLocationLabelSerializer(InvenTreeModelSerializer):
"""
Serializes a StockLocationLabel object
"""
"""Serializes a StockLocationLabel object."""
label = InvenTreeAttachmentSerializerField(required=True)
class Meta:
"""Metaclass options."""
model = StockLocationLabel
fields = [
'pk',
@ -43,13 +45,13 @@ class StockLocationLabelSerializer(InvenTreeModelSerializer):
class PartLabelSerializer(InvenTreeModelSerializer):
"""
Serializes a PartLabel object
"""
"""Serializes a PartLabel object."""
label = InvenTreeAttachmentSerializerField(required=True)
class Meta:
"""Metaclass options."""
model = PartLabel
fields = [
'pk',

View File

@ -1,4 +1,4 @@
# Tests for labels
"""Unit tests for label API"""
from django.urls import reverse
@ -6,9 +6,7 @@ from InvenTree.api_tester import InvenTreeAPITestCase
class TestReportTests(InvenTreeAPITestCase):
"""
Tests for the StockItem TestReport templates
"""
"""Tests for the StockItem TestReport templates."""
fixtures = [
'category',
@ -24,12 +22,8 @@ class TestReportTests(InvenTreeAPITestCase):
list_url = reverse('api-stockitem-testreport-list')
def setUp(self):
super().setUp()
def do_list(self, filters={}):
"""Helper function to request list of labels with provided filters"""
response = self.client.get(self.list_url, filters, format='json')
self.assertEqual(response.status_code, 200)
@ -37,7 +31,7 @@ class TestReportTests(InvenTreeAPITestCase):
return response.data
def test_list(self):
"""Test the API list endpoint"""
response = self.do_list()
# TODO - Add some report templates to the fixtures

View File

@ -1,4 +1,4 @@
# Tests for labels
"""Tests for labels"""
import os
@ -16,6 +16,7 @@ from .models import PartLabel, StockItemLabel, StockLocationLabel
class LabelTest(InvenTreeAPITestCase):
"""Unit test class for label models"""
fixtures = [
'category',
@ -25,15 +26,12 @@ class LabelTest(InvenTreeAPITestCase):
]
def setUp(self) -> None:
"""Ensure that some label instances exist as part of init routine"""
super().setUp()
# ensure the labels were created
apps.get_app_config('label').create_labels()
def test_default_labels(self):
"""
Test that the default label templates are copied across
"""
"""Test that the default label templates are copied across."""
labels = StockItemLabel.objects.all()
self.assertTrue(labels.count() > 0)
@ -43,10 +41,7 @@ class LabelTest(InvenTreeAPITestCase):
self.assertTrue(labels.count() > 0)
def test_default_files(self):
"""
Test that label files exist in the MEDIA directory
"""
"""Test that label files exist in the MEDIA directory."""
item_dir = os.path.join(
settings.MEDIA_ROOT,
'label',
@ -70,10 +65,7 @@ class LabelTest(InvenTreeAPITestCase):
self.assertTrue(len(files) > 0)
def test_filters(self):
"""
Test the label filters
"""
"""Test the label filters."""
filter_string = "part__pk=10"
filters = validateFilterString(filter_string, model=StockItem)
@ -86,8 +78,7 @@ class LabelTest(InvenTreeAPITestCase):
validateFilterString(bad_filter_string, model=StockItem)
def test_label_rendering(self):
"""Test label rendering"""
"""Test label rendering."""
labels = PartLabel.objects.all()
part = Part.objects.first()

View File

@ -1 +0,0 @@
# Create your views here.

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python
"""InvenTree / django management commands"""
import os
import sys

View File

@ -1,3 +1 @@
"""
The Order module is responsible for managing Orders
"""
"""The Order module is responsible for managing Orders."""

View File

@ -1,3 +1,5 @@
"""Admin functionality for the 'order' app"""
from django.contrib import admin
import import_export.widgets as widgets
@ -13,6 +15,7 @@ from .models import (PurchaseOrder, PurchaseOrderExtraLine,
# region general classes
class GeneralExtraLineAdmin:
"""Admin class template for the 'ExtraLineItem' models"""
list_display = (
'order',
'quantity',
@ -29,6 +32,7 @@ class GeneralExtraLineAdmin:
class GeneralExtraLineMeta:
"""Metaclass template for the 'ExtraLineItem' models"""
skip_unchanged = True
report_skipped = False
clean_model_instances = True
@ -36,11 +40,13 @@ class GeneralExtraLineMeta:
class PurchaseOrderLineItemInlineAdmin(admin.StackedInline):
"""Inline admin class for the PurchaseOrderLineItem model"""
model = PurchaseOrderLineItem
extra = 0
class PurchaseOrderAdmin(ImportExportModelAdmin):
"""Admin class for the PurchaseOrder model"""
exclude = [
'reference_int',
@ -68,6 +74,7 @@ class PurchaseOrderAdmin(ImportExportModelAdmin):
class SalesOrderAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrder model"""
exclude = [
'reference_int',
@ -91,9 +98,7 @@ class SalesOrderAdmin(ImportExportModelAdmin):
class PurchaseOrderResource(ModelResource):
"""
Class for managing import / export of PurchaseOrder data
"""
"""Class for managing import / export of PurchaseOrder data."""
# Add number of line items
line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True)
@ -102,6 +107,7 @@ class PurchaseOrderResource(ModelResource):
overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True)
class Meta:
"""Metaclass"""
model = PurchaseOrder
skip_unchanged = True
clean_model_instances = True
@ -111,7 +117,7 @@ class PurchaseOrderResource(ModelResource):
class PurchaseOrderLineItemResource(ModelResource):
""" Class for managing import / export of PurchaseOrderLineItem data """
"""Class for managing import / export of PurchaseOrderLineItem data."""
part_name = Field(attribute='part__part__name', readonly=True)
@ -122,6 +128,7 @@ class PurchaseOrderLineItemResource(ModelResource):
SKU = Field(attribute='part__SKU', readonly=True)
class Meta:
"""Metaclass"""
model = PurchaseOrderLineItem
skip_unchanged = True
report_skipped = False
@ -129,16 +136,16 @@ class PurchaseOrderLineItemResource(ModelResource):
class PurchaseOrderExtraLineResource(ModelResource):
""" Class for managing import / export of PurchaseOrderExtraLine data """
"""Class for managing import / export of PurchaseOrderExtraLine data."""
class Meta(GeneralExtraLineMeta):
"""Metaclass options."""
model = PurchaseOrderExtraLine
class SalesOrderResource(ModelResource):
"""
Class for managing import / export of SalesOrder data
"""
"""Class for managing import / export of SalesOrder data."""
# Add number of line items
line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True)
@ -147,6 +154,7 @@ class SalesOrderResource(ModelResource):
overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True)
class Meta:
"""Metaclass options"""
model = SalesOrder
skip_unchanged = True
clean_model_instances = True
@ -156,9 +164,7 @@ class SalesOrderResource(ModelResource):
class SalesOrderLineItemResource(ModelResource):
"""
Class for managing import / export of SalesOrderLineItem data
"""
"""Class for managing import / export of SalesOrderLineItem data."""
part_name = Field(attribute='part__name', readonly=True)
@ -169,17 +175,17 @@ class SalesOrderLineItemResource(ModelResource):
fulfilled = Field(attribute='fulfilled_quantity', readonly=True)
def dehydrate_sale_price(self, item):
"""
Return a string value of the 'sale_price' field, rather than the 'Money' object.
"""Return a string value of the 'sale_price' field, rather than the 'Money' object.
Ref: https://github.com/inventree/InvenTree/issues/2207
"""
if item.sale_price:
return str(item.sale_price)
else:
return ''
class Meta:
"""Metaclass options"""
model = SalesOrderLineItem
skip_unchanged = True
report_skipped = False
@ -187,13 +193,16 @@ class SalesOrderLineItemResource(ModelResource):
class SalesOrderExtraLineResource(ModelResource):
""" Class for managing import / export of SalesOrderExtraLine data """
"""Class for managing import / export of SalesOrderExtraLine data."""
class Meta(GeneralExtraLineMeta):
"""Metaclass options."""
model = SalesOrderExtraLine
class PurchaseOrderLineItemAdmin(ImportExportModelAdmin):
"""Admin class for the PurchaseOrderLine model"""
resource_class = PurchaseOrderLineItemResource
@ -210,11 +219,12 @@ class PurchaseOrderLineItemAdmin(ImportExportModelAdmin):
class PurchaseOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the PurchaseOrderExtraLine model"""
resource_class = PurchaseOrderExtraLineResource
class SalesOrderLineItemAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderLine model"""
resource_class = SalesOrderLineItemResource
@ -236,11 +246,12 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin):
class SalesOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the SalesOrderExtraLine model"""
resource_class = SalesOrderExtraLineResource
class SalesOrderShipmentAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderShipment model"""
list_display = [
'order',
@ -258,6 +269,7 @@ class SalesOrderShipmentAdmin(ImportExportModelAdmin):
class SalesOrderAllocationAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderAllocation model"""
list_display = (
'line',

View File

@ -1,6 +1,4 @@
"""
JSON API for the Order app
"""
"""JSON API for the Order app."""
from django.db.models import F, Q
from django.urls import include, path, re_path
@ -24,11 +22,10 @@ from users.models import Owner
class GeneralExtraLineList:
"""
General template for ExtraLine API classes
"""
"""General template for ExtraLine API classes."""
def get_serializer(self, *args, **kwargs):
"""Return the serializer instance for this endpoint"""
try:
params = self.request.query_params
@ -41,7 +38,7 @@ class GeneralExtraLineList:
return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs):
"""Return the annotated queryset for this endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related(
@ -76,17 +73,12 @@ class GeneralExtraLineList:
class PurchaseOrderFilter(rest_filters.FilterSet):
"""
Custom API filters for the PurchaseOrderList endpoint
"""
"""Custom API filters for the PurchaseOrderList endpoint."""
assigned_to_me = rest_filters.BooleanFilter(label='assigned_to_me', method='filter_assigned_to_me')
def filter_assigned_to_me(self, queryset, name, value):
"""
Filter by orders which are assigned to the current user
"""
"""Filter by orders which are assigned to the current user."""
value = str2bool(value)
# Work out who "me" is!
@ -100,6 +92,8 @@ class PurchaseOrderFilter(rest_filters.FilterSet):
return queryset
class Meta:
"""Metaclass options."""
model = models.PurchaseOrder
fields = [
'supplier',
@ -107,7 +101,7 @@ class PurchaseOrderFilter(rest_filters.FilterSet):
class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView):
""" API endpoint for accessing a list of PurchaseOrder objects
"""API endpoint for accessing a list of PurchaseOrder objects.
- GET: Return list of PurchaseOrder objects (with filters)
- POST: Create a new PurchaseOrder object
@ -118,9 +112,7 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView):
filterset_class = PurchaseOrderFilter
def create(self, request, *args, **kwargs):
"""
Save user information on create
"""
"""Save user information on create."""
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@ -132,7 +124,7 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView):
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def get_serializer(self, *args, **kwargs):
"""Return the serializer instance for this endpoint"""
try:
kwargs['supplier_detail'] = str2bool(self.request.query_params.get('supplier_detail', False))
except AttributeError:
@ -144,7 +136,7 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView):
return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs):
"""Return the annotated queryset for this endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related(
@ -157,6 +149,8 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView):
return queryset
def download_queryset(self, queryset, export_format):
"""Download the filtered queryset as a file"""
dataset = PurchaseOrderResource().export(queryset=queryset)
filedata = dataset.export(export_format)
@ -166,7 +160,7 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView):
return DownloadFile(filedata, filename)
def filter_queryset(self, queryset):
"""Custom queryset filtering"""
# Perform basic filtering
queryset = super().filter_queryset(queryset)
@ -260,13 +254,13 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView):
class PurchaseOrderDetail(generics.RetrieveUpdateDestroyAPIView):
""" API endpoint for detail view of a PurchaseOrder object """
"""API endpoint for detail view of a PurchaseOrder object."""
queryset = models.PurchaseOrder.objects.all()
serializer_class = serializers.PurchaseOrderSerializer
def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint"""
try:
kwargs['supplier_detail'] = str2bool(self.request.query_params.get('supplier_detail', False))
except AttributeError:
@ -278,7 +272,7 @@ class PurchaseOrderDetail(generics.RetrieveUpdateDestroyAPIView):
return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related(
@ -292,11 +286,10 @@ class PurchaseOrderDetail(generics.RetrieveUpdateDestroyAPIView):
class PurchaseOrderContextMixin:
""" Mixin to add purchase order object as serializer context variable """
"""Mixin to add purchase order object as serializer context variable."""
def get_serializer_context(self):
""" Add the PurchaseOrder object to the serializer context """
"""Add the PurchaseOrder object to the serializer context."""
context = super().get_serializer_context()
# Pass the purchase order through to the serializer for validation
@ -311,8 +304,7 @@ class PurchaseOrderContextMixin:
class PurchaseOrderCancel(PurchaseOrderContextMixin, generics.CreateAPIView):
"""
API endpoint to 'cancel' a purchase order.
"""API endpoint to 'cancel' a purchase order.
The purchase order must be in a state which can be cancelled
"""
@ -323,9 +315,7 @@ class PurchaseOrderCancel(PurchaseOrderContextMixin, generics.CreateAPIView):
class PurchaseOrderComplete(PurchaseOrderContextMixin, generics.CreateAPIView):
"""
API endpoint to 'complete' a purchase order
"""
"""API endpoint to 'complete' a purchase order."""
queryset = models.PurchaseOrder.objects.all()
@ -333,9 +323,7 @@ class PurchaseOrderComplete(PurchaseOrderContextMixin, generics.CreateAPIView):
class PurchaseOrderIssue(PurchaseOrderContextMixin, generics.CreateAPIView):
"""
API endpoint to 'complete' a purchase order
"""
"""API endpoint to 'complete' a purchase order."""
queryset = models.PurchaseOrder.objects.all()
@ -343,17 +331,17 @@ class PurchaseOrderIssue(PurchaseOrderContextMixin, generics.CreateAPIView):
class PurchaseOrderMetadata(generics.RetrieveUpdateAPIView):
"""API endpoint for viewing / updating PurchaseOrder metadata"""
"""API endpoint for viewing / updating PurchaseOrder metadata."""
def get_serializer(self, *args, **kwargs):
"""Return MetadataSerializer instance for a PurchaseOrder"""
return MetadataSerializer(models.PurchaseOrder, *args, **kwargs)
queryset = models.PurchaseOrder.objects.all()
class PurchaseOrderReceive(PurchaseOrderContextMixin, generics.CreateAPIView):
"""
API endpoint to receive stock items against a purchase order.
"""API endpoint to receive stock items against a purchase order.
- The purchase order is specified in the URL.
- Items to receive are specified as a list called "items" with the following options:
@ -370,11 +358,11 @@ class PurchaseOrderReceive(PurchaseOrderContextMixin, generics.CreateAPIView):
class PurchaseOrderLineItemFilter(rest_filters.FilterSet):
"""
Custom filters for the PurchaseOrderLineItemList endpoint
"""
"""Custom filters for the PurchaseOrderLineItemList endpoint."""
class Meta:
"""Metaclass options."""
model = models.PurchaseOrderLineItem
fields = [
'order',
@ -384,10 +372,7 @@ class PurchaseOrderLineItemFilter(rest_filters.FilterSet):
pending = rest_filters.BooleanFilter(label='pending', method='filter_pending')
def filter_pending(self, queryset, name, value):
"""
Filter by "pending" status (order status = pending)
"""
"""Filter by "pending" status (order status = pending)"""
value = str2bool(value)
if value:
@ -402,12 +387,10 @@ class PurchaseOrderLineItemFilter(rest_filters.FilterSet):
received = rest_filters.BooleanFilter(label='received', method='filter_received')
def filter_received(self, queryset, name, value):
"""
Filter by lines which are "received" (or "not" received)
"""Filter by lines which are "received" (or "not" received)
A line is considered "received" when received >= quantity
"""
value = str2bool(value)
q = Q(received__gte=F('quantity'))
@ -422,7 +405,7 @@ class PurchaseOrderLineItemFilter(rest_filters.FilterSet):
class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView):
""" API endpoint for accessing a list of PurchaseOrderLineItem objects
"""API endpoint for accessing a list of PurchaseOrderLineItem objects.
- GET: Return a list of PurchaseOrder Line Item objects
- POST: Create a new PurchaseOrderLineItem object
@ -433,7 +416,7 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView):
filterset_class = PurchaseOrderLineItemFilter
def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = serializers.PurchaseOrderLineItemSerializer.annotate_queryset(queryset)
@ -441,7 +424,7 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView):
return queryset
def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint"""
try:
kwargs['part_detail'] = str2bool(self.request.query_params.get('part_detail', False))
kwargs['order_detail'] = str2bool(self.request.query_params.get('order_detail', False))
@ -453,10 +436,7 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView):
return self.serializer_class(*args, **kwargs)
def filter_queryset(self, queryset):
"""
Additional filtering options
"""
"""Additional filtering options."""
params = self.request.query_params
queryset = super().filter_queryset(queryset)
@ -475,6 +455,8 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView):
return queryset
def download_queryset(self, queryset, export_format):
"""Download the requested queryset as a file"""
dataset = PurchaseOrderLineItemResource().export(queryset=queryset)
filedata = dataset.export(export_format)
@ -483,19 +465,6 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView):
return DownloadFile(filedata, filename)
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
filter_backends = [
rest_filters.DjangoFilterBackend,
filters.SearchFilter,
@ -530,15 +499,13 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView):
class PurchaseOrderLineItemDetail(generics.RetrieveUpdateDestroyAPIView):
"""
Detail API endpoint for PurchaseOrderLineItem object
"""
"""Detail API endpoint for PurchaseOrderLineItem object."""
queryset = models.PurchaseOrderLineItem.objects.all()
serializer_class = serializers.PurchaseOrderLineItemSerializer
def get_queryset(self):
"""Return annotated queryset for this endpoint"""
queryset = super().get_queryset()
queryset = serializers.PurchaseOrderLineItemSerializer.annotate_queryset(queryset)
@ -547,25 +514,21 @@ class PurchaseOrderLineItemDetail(generics.RetrieveUpdateDestroyAPIView):
class PurchaseOrderExtraLineList(GeneralExtraLineList, generics.ListCreateAPIView):
"""
API endpoint for accessing a list of PurchaseOrderExtraLine objects.
"""
"""API endpoint for accessing a list of PurchaseOrderExtraLine objects."""
queryset = models.PurchaseOrderExtraLine.objects.all()
serializer_class = serializers.PurchaseOrderExtraLineSerializer
class PurchaseOrderExtraLineDetail(generics.RetrieveUpdateDestroyAPIView):
""" API endpoint for detail view of a PurchaseOrderExtraLine object """
"""API endpoint for detail view of a PurchaseOrderExtraLine object."""
queryset = models.PurchaseOrderExtraLine.objects.all()
serializer_class = serializers.PurchaseOrderExtraLineSerializer
class SalesOrderAttachmentList(generics.ListCreateAPIView, AttachmentMixin):
"""
API endpoint for listing (and creating) a SalesOrderAttachment (file upload)
"""
"""API endpoint for listing (and creating) a SalesOrderAttachment (file upload)"""
queryset = models.SalesOrderAttachment.objects.all()
serializer_class = serializers.SalesOrderAttachmentSerializer
@ -580,17 +543,14 @@ class SalesOrderAttachmentList(generics.ListCreateAPIView, AttachmentMixin):
class SalesOrderAttachmentDetail(generics.RetrieveUpdateDestroyAPIView, AttachmentMixin):
"""
Detail endpoint for SalesOrderAttachment
"""
"""Detail endpoint for SalesOrderAttachment."""
queryset = models.SalesOrderAttachment.objects.all()
serializer_class = serializers.SalesOrderAttachmentSerializer
class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView):
"""
API endpoint for accessing a list of SalesOrder objects.
"""API endpoint for accessing a list of SalesOrder objects.
- GET: Return list of SalesOrder objects (with filters)
- POST: Create a new SalesOrder
@ -600,9 +560,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView):
serializer_class = serializers.SalesOrderSerializer
def create(self, request, *args, **kwargs):
"""
Save user information on create
"""
"""Save user information on create."""
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@ -614,7 +572,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView):
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint"""
try:
kwargs['customer_detail'] = str2bool(self.request.query_params.get('customer_detail', False))
except AttributeError:
@ -626,7 +584,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView):
return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related(
@ -639,6 +597,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView):
return queryset
def download_queryset(self, queryset, export_format):
"""Download this queryset as a file"""
dataset = SalesOrderResource().export(queryset=queryset)
filedata = dataset.export(export_format)
@ -648,10 +607,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView):
return DownloadFile(filedata, filename)
def filter_queryset(self, queryset):
"""
Perform custom filtering operations on the SalesOrder queryset.
"""
"""Perform custom filtering operations on the SalesOrder queryset."""
queryset = super().filter_queryset(queryset)
params = self.request.query_params
@ -739,15 +695,13 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView):
class SalesOrderDetail(generics.RetrieveUpdateDestroyAPIView):
"""
API endpoint for detail view of a SalesOrder object.
"""
"""API endpoint for detail view of a SalesOrder object."""
queryset = models.SalesOrder.objects.all()
serializer_class = serializers.SalesOrderSerializer
def get_serializer(self, *args, **kwargs):
"""Return the serializer instance for this endpoint"""
try:
kwargs['customer_detail'] = str2bool(self.request.query_params.get('customer_detail', False))
except AttributeError:
@ -758,7 +712,7 @@ class SalesOrderDetail(generics.RetrieveUpdateDestroyAPIView):
return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs):
"""Return the annotated queryset for this serializer"""
queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('customer', 'lines')
@ -769,11 +723,11 @@ class SalesOrderDetail(generics.RetrieveUpdateDestroyAPIView):
class SalesOrderLineItemFilter(rest_filters.FilterSet):
"""
Custom filters for SalesOrderLineItemList endpoint
"""
"""Custom filters for SalesOrderLineItemList endpoint."""
class Meta:
"""Metaclass options."""
model = models.SalesOrderLineItem
fields = [
'order',
@ -783,12 +737,10 @@ class SalesOrderLineItemFilter(rest_filters.FilterSet):
completed = rest_filters.BooleanFilter(label='completed', method='filter_completed')
def filter_completed(self, queryset, name, value):
"""
Filter by lines which are "completed"
"""Filter by lines which are "completed".
A line is completed when shipped >= quantity
"""
value = str2bool(value)
q = Q(shipped__gte=F('quantity'))
@ -802,16 +754,14 @@ class SalesOrderLineItemFilter(rest_filters.FilterSet):
class SalesOrderLineItemList(generics.ListCreateAPIView):
"""
API endpoint for accessing a list of SalesOrderLineItem objects.
"""
"""API endpoint for accessing a list of SalesOrderLineItem objects."""
queryset = models.SalesOrderLineItem.objects.all()
serializer_class = serializers.SalesOrderLineItemSerializer
filterset_class = SalesOrderLineItemFilter
def get_serializer(self, *args, **kwargs):
"""Return serializer for this endpoint with extra data as requested"""
try:
params = self.request.query_params
@ -826,7 +776,7 @@ class SalesOrderLineItemList(generics.ListCreateAPIView):
return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint"""
queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related(
@ -866,33 +816,31 @@ class SalesOrderLineItemList(generics.ListCreateAPIView):
class SalesOrderExtraLineList(GeneralExtraLineList, generics.ListCreateAPIView):
"""
API endpoint for accessing a list of SalesOrderExtraLine objects.
"""
"""API endpoint for accessing a list of SalesOrderExtraLine objects."""
queryset = models.SalesOrderExtraLine.objects.all()
serializer_class = serializers.SalesOrderExtraLineSerializer
class SalesOrderExtraLineDetail(generics.RetrieveUpdateDestroyAPIView):
""" API endpoint for detail view of a SalesOrderExtraLine object """
"""API endpoint for detail view of a SalesOrderExtraLine object."""
queryset = models.SalesOrderExtraLine.objects.all()
serializer_class = serializers.SalesOrderExtraLineSerializer
class SalesOrderLineItemDetail(generics.RetrieveUpdateDestroyAPIView):
""" API endpoint for detail view of a SalesOrderLineItem object """
"""API endpoint for detail view of a SalesOrderLineItem object."""
queryset = models.SalesOrderLineItem.objects.all()
serializer_class = serializers.SalesOrderLineItemSerializer
class SalesOrderContextMixin:
""" Mixin to add sales order object as serializer context variable """
"""Mixin to add sales order object as serializer context variable."""
def get_serializer_context(self):
"""Add the 'order' reference to the serializer context for any classes which inherit this mixin"""
ctx = super().get_serializer_context()
ctx['request'] = self.request
@ -906,42 +854,38 @@ class SalesOrderContextMixin:
class SalesOrderCancel(SalesOrderContextMixin, generics.CreateAPIView):
"""API endpoint to cancel a SalesOrder"""
queryset = models.SalesOrder.objects.all()
serializer_class = serializers.SalesOrderCancelSerializer
class SalesOrderComplete(SalesOrderContextMixin, generics.CreateAPIView):
"""
API endpoint for manually marking a SalesOrder as "complete".
"""
"""API endpoint for manually marking a SalesOrder as "complete"."""
queryset = models.SalesOrder.objects.all()
serializer_class = serializers.SalesOrderCompleteSerializer
class SalesOrderMetadata(generics.RetrieveUpdateAPIView):
"""API endpoint for viewing / updating SalesOrder metadata"""
"""API endpoint for viewing / updating SalesOrder metadata."""
def get_serializer(self, *args, **kwargs):
"""Return a metadata serializer for the SalesOrder model"""
return MetadataSerializer(models.SalesOrder, *args, **kwargs)
queryset = models.SalesOrder.objects.all()
class SalesOrderAllocateSerials(SalesOrderContextMixin, generics.CreateAPIView):
"""
API endpoint to allocation stock items against a SalesOrder,
by specifying serial numbers.
"""
"""API endpoint to allocation stock items against a SalesOrder, by specifying serial numbers."""
queryset = models.SalesOrder.objects.none()
serializer_class = serializers.SalesOrderSerialAllocationSerializer
class SalesOrderAllocate(SalesOrderContextMixin, generics.CreateAPIView):
"""
API endpoint to allocate stock items against a SalesOrder
"""API endpoint to allocate stock items against a SalesOrder.
- The SalesOrder is specified in the URL
- See the SalesOrderShipmentAllocationSerializer class
@ -952,24 +896,23 @@ class SalesOrderAllocate(SalesOrderContextMixin, generics.CreateAPIView):
class SalesOrderAllocationDetail(generics.RetrieveUpdateDestroyAPIView):
"""
API endpoint for detali view of a SalesOrderAllocation object
"""
"""API endpoint for detali view of a SalesOrderAllocation object."""
queryset = models.SalesOrderAllocation.objects.all()
serializer_class = serializers.SalesOrderAllocationSerializer
class SalesOrderAllocationList(generics.ListAPIView):
"""
API endpoint for listing SalesOrderAllocation objects
"""
"""API endpoint for listing SalesOrderAllocation objects."""
queryset = models.SalesOrderAllocation.objects.all()
serializer_class = serializers.SalesOrderAllocationSerializer
def get_serializer(self, *args, **kwargs):
"""Return the serializer instance for this endpoint.
Adds extra detail serializers if requested
"""
try:
params = self.request.query_params
@ -984,7 +927,7 @@ class SalesOrderAllocationList(generics.ListAPIView):
return self.serializer_class(*args, **kwargs)
def filter_queryset(self, queryset):
"""Custom queryset filtering"""
queryset = super().filter_queryset(queryset)
# Filter by order
@ -1039,14 +982,12 @@ class SalesOrderAllocationList(generics.ListAPIView):
class SalesOrderShipmentFilter(rest_filters.FilterSet):
"""
Custom filterset for the SalesOrderShipmentList endpoint
"""
"""Custom filterset for the SalesOrderShipmentList endpoint."""
shipped = rest_filters.BooleanFilter(label='shipped', method='filter_shipped')
def filter_shipped(self, queryset, name, value):
"""Filter SalesOrder list by 'shipped' status (boolean)"""
value = str2bool(value)
if value:
@ -1057,6 +998,8 @@ class SalesOrderShipmentFilter(rest_filters.FilterSet):
return queryset
class Meta:
"""Metaclass options."""
model = models.SalesOrderShipment
fields = [
'order',
@ -1064,9 +1007,7 @@ class SalesOrderShipmentFilter(rest_filters.FilterSet):
class SalesOrderShipmentList(generics.ListCreateAPIView):
"""
API list endpoint for SalesOrderShipment model
"""
"""API list endpoint for SalesOrderShipment model."""
queryset = models.SalesOrderShipment.objects.all()
serializer_class = serializers.SalesOrderShipmentSerializer
@ -1078,27 +1019,20 @@ class SalesOrderShipmentList(generics.ListCreateAPIView):
class SalesOrderShipmentDetail(generics.RetrieveUpdateDestroyAPIView):
"""
API detail endpooint for SalesOrderShipment model
"""
"""API detail endpooint for SalesOrderShipment model."""
queryset = models.SalesOrderShipment.objects.all()
serializer_class = serializers.SalesOrderShipmentSerializer
class SalesOrderShipmentComplete(generics.CreateAPIView):
"""
API endpoint for completing (shipping) a SalesOrderShipment
"""
"""API endpoint for completing (shipping) a SalesOrderShipment."""
queryset = models.SalesOrderShipment.objects.all()
serializer_class = serializers.SalesOrderShipmentCompleteSerializer
def get_serializer_context(self):
"""
Pass the request object to the serializer
"""
"""Pass the request object to the serializer."""
ctx = super().get_serializer_context()
ctx['request'] = self.request
@ -1113,9 +1047,7 @@ class SalesOrderShipmentComplete(generics.CreateAPIView):
class PurchaseOrderAttachmentList(generics.ListCreateAPIView, AttachmentMixin):
"""
API endpoint for listing (and creating) a PurchaseOrderAttachment (file upload)
"""
"""API endpoint for listing (and creating) a PurchaseOrderAttachment (file upload)"""
queryset = models.PurchaseOrderAttachment.objects.all()
serializer_class = serializers.PurchaseOrderAttachmentSerializer
@ -1130,9 +1062,7 @@ class PurchaseOrderAttachmentList(generics.ListCreateAPIView, AttachmentMixin):
class PurchaseOrderAttachmentDetail(generics.RetrieveUpdateDestroyAPIView, AttachmentMixin):
"""
Detail endpoint for a PurchaseOrderAttachment
"""
"""Detail endpoint for a PurchaseOrderAttachment."""
queryset = models.PurchaseOrderAttachment.objects.all()
serializer_class = serializers.PurchaseOrderAttachmentSerializer

View File

@ -1,5 +1,8 @@
"""Config for the 'order' app"""
from django.apps import AppConfig
class OrderConfig(AppConfig):
"""Configuration class for the 'order' app"""
name = 'order'

View File

@ -1,6 +1,4 @@
"""
Django Forms for interacting with Order objects
"""
"""Django Forms for interacting with Order objects."""
from django import forms
from django.utils.translation import gettext_lazy as _
@ -11,11 +9,10 @@ from InvenTree.helpers import clean_decimal
class OrderMatchItemForm(MatchItemForm):
""" Override MatchItemForm fields """
"""Override MatchItemForm fields."""
def get_special_field(self, col_guess, row, file_manager):
""" Set special fields """
"""Set special fields."""
# set quantity field
if 'quantity' in col_guess.lower():
return forms.CharField(

View File

@ -1,8 +1,4 @@
"""
Order model definitions
"""
# -*- coding: utf-8 -*-
"""Order model definitions."""
import logging
import os
@ -47,10 +43,7 @@ logger = logging.getLogger('inventree')
def get_next_po_number():
"""
Returns the next available PurchaseOrder reference number
"""
"""Returns the next available PurchaseOrder reference number."""
if PurchaseOrder.objects.count() == 0:
return '0001'
@ -76,10 +69,7 @@ def get_next_po_number():
def get_next_so_number():
"""
Returns the next available SalesOrder reference number
"""
"""Returns the next available SalesOrder reference number."""
if SalesOrder.objects.count() == 0:
return '0001'
@ -105,7 +95,7 @@ def get_next_so_number():
class Order(MetadataMixin, ReferenceIndexingMixin):
""" Abstract model for an order.
"""Abstract model for an order.
Instances of this class:
@ -123,7 +113,10 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
"""
def save(self, *args, **kwargs):
"""Custom save method for the order models:
Ensures that the reference field is rebuilt whenever the instance is saved.
"""
self.rebuild_reference_field()
if not self.creation_date:
@ -132,6 +125,8 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
super().save(*args, **kwargs)
class Meta:
"""Metaclass options. Abstract ensures no database table is created."""
abstract = True
description = models.CharField(max_length=250, verbose_name=_('Description'), help_text=_('Order description'))
@ -159,15 +154,13 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
notes = MarkdownxField(blank=True, verbose_name=_('Notes'), help_text=_('Order notes'))
def get_total_price(self, target_currency=currency_code_default()):
"""
Calculates the total price of all order lines, and converts to the specified target currency.
"""Calculates the total price of all order lines, and converts to the specified target currency.
If not specified, the default system currency is used.
If currency conversion fails (e.g. there are no valid conversion rates),
then we simply return zero, rather than attempting some other calculation.
"""
total = Money(0, target_currency)
# gather name reference
@ -230,7 +223,7 @@ class Order(MetadataMixin, ReferenceIndexingMixin):
class PurchaseOrder(Order):
""" A PurchaseOrder represents goods shipped inwards from an external supplier.
"""A PurchaseOrder represents goods shipped inwards from an external supplier.
Attributes:
supplier: Reference to the company supplying the goods in the order
@ -241,14 +234,14 @@ class PurchaseOrder(Order):
@staticmethod
def get_api_url():
"""Return the API URL associated with the PurchaseOrder model"""
return reverse('api-po-list')
OVERDUE_FILTER = Q(status__in=PurchaseOrderStatus.OPEN) & ~Q(target_date=None) & Q(target_date__lte=datetime.now().date())
@staticmethod
def filterByDate(queryset, min_date, max_date):
"""
Filter by 'minimum and maximum date range'
"""Filter by 'minimum and maximum date range'.
- Specified as min_date, max_date
- Both must be specified for filter to be applied
@ -259,7 +252,6 @@ class PurchaseOrder(Order):
- A "pending" order where the target date lies within the date range
- TODO: An "overdue" order where the target date is in the past
"""
date_fmt = '%Y-%m-%d' # ISO format date string
# Ensure that both dates are valid
@ -283,7 +275,7 @@ class PurchaseOrder(Order):
return queryset
def __str__(self):
"""Render a string representation of this PurchaseOrder"""
prefix = getSetting('PURCHASEORDER_REFERENCE_PREFIX')
return f"{prefix}{self.reference} - {self.supplier.name if self.supplier else _('deleted')}"
@ -340,22 +332,29 @@ class PurchaseOrder(Order):
)
def get_absolute_url(self):
"""Return the web URL of the detail view for this order"""
return reverse('po-detail', kwargs={'pk': self.id})
@transaction.atomic
def add_line_item(self, supplier_part, quantity, group=True, reference='', purchase_price=None):
""" Add a new line item to this purchase order.
This function will check that:
def add_line_item(self, supplier_part, quantity, group: bool = True, reference: str = '', purchase_price=None):
"""Add a new line item to this purchase order.
This function will check that:
* The supplier part matches the supplier specified for this purchase order
* The quantity is greater than zero
Args:
supplier_part - The supplier_part to add
quantity - The number of items to add
group - If True, this new quantity will be added to an existing line item for the same supplier_part (if it exists)
"""
supplier_part: The supplier_part to add
quantity : The number of items to add
group (bool, optional): If True, this new quantity will be added to an existing line item for the same supplier_part (if it exists). Defaults to True.
reference (str, optional): Reference to item. Defaults to ''.
purchase_price (optional): Price of item. Defaults to None.
Raises:
ValidationError: quantity is smaller than 0
ValidationError: quantity is not type int
ValidationError: supplier is not supplier of purchase order
"""
try:
quantity = int(quantity)
if quantity <= 0:
@ -396,8 +395,10 @@ class PurchaseOrder(Order):
@transaction.atomic
def place_order(self):
""" Marks the PurchaseOrder as PLACED. Order must be currently PENDING. """
"""Marks the PurchaseOrder as PLACED.
Order must be currently PENDING.
"""
if self.status == PurchaseOrderStatus.PENDING:
self.status = PurchaseOrderStatus.PLACED
self.issue_date = datetime.now().date()
@ -407,8 +408,10 @@ class PurchaseOrder(Order):
@transaction.atomic
def complete_order(self):
""" Marks the PurchaseOrder as COMPLETE. Order must be currently PLACED. """
"""Marks the PurchaseOrder as COMPLETE.
Order must be currently PLACED.
"""
if self.status == PurchaseOrderStatus.PLACED:
self.status = PurchaseOrderStatus.COMPLETE
self.complete_date = datetime.now().date()
@ -418,22 +421,21 @@ class PurchaseOrder(Order):
@property
def is_overdue(self):
"""
Returns True if this PurchaseOrder is "overdue"
"""Returns True if this PurchaseOrder is "overdue".
Makes use of the OVERDUE_FILTER to avoid code duplication.
"""
query = PurchaseOrder.objects.filter(pk=self.pk)
query = query.filter(PurchaseOrder.OVERDUE_FILTER)
return query.exists()
def can_cancel(self):
"""
A PurchaseOrder can only be cancelled under the following circumstances:
"""
"""A PurchaseOrder can only be cancelled under the following circumstances.
- Status is PLACED
- Status is PENDING
"""
return self.status in [
PurchaseOrderStatus.PLACED,
PurchaseOrderStatus.PENDING
@ -441,8 +443,7 @@ class PurchaseOrder(Order):
@transaction.atomic
def cancel_order(self):
""" Marks the PurchaseOrder as CANCELLED. """
"""Marks the PurchaseOrder as CANCELLED."""
if self.can_cancel():
self.status = PurchaseOrderStatus.CANCELLED
self.save()
@ -450,43 +451,39 @@ class PurchaseOrder(Order):
trigger_event('purchaseorder.cancelled', id=self.pk)
def pending_line_items(self):
""" Return a list of pending line items for this order.
"""Return a list of pending line items for this order.
Any line item where 'received' < 'quantity' will be returned.
"""
return self.lines.filter(quantity__gt=F('received'))
def completed_line_items(self):
"""
Return a list of completed line items against this order
"""
"""Return a list of completed line items against this order."""
return self.lines.filter(quantity__lte=F('received'))
@property
def line_count(self):
"""Return the total number of line items associated with this order"""
return self.lines.count()
@property
def completed_line_count(self):
"""Return the number of complete line items associated with this order"""
return self.completed_line_items().count()
@property
def pending_line_count(self):
"""Return the number of pending line items associated with this order"""
return self.pending_line_items().count()
@property
def is_complete(self):
""" Return True if all line items have been received """
"""Return True if all line items have been received."""
return self.lines.count() > 0 and self.pending_line_items().count() == 0
@transaction.atomic
def receive_line_item(self, line, location, quantity, user, status=StockStatus.OK, **kwargs):
"""
Receive a line item (or partial line item) against this PurchaseOrder
"""
"""Receive a line item (or partial line item) against this PurchaseOrder."""
# Extract optional batch code for the new stock item
batch_code = kwargs.get('batch_code', '')
@ -573,8 +570,7 @@ class PurchaseOrder(Order):
class SalesOrder(Order):
"""
A SalesOrder represents a list of goods shipped outwards to a customer.
"""A SalesOrder represents a list of goods shipped outwards to a customer.
Attributes:
customer: Reference to the company receiving the goods in the order
@ -584,14 +580,14 @@ class SalesOrder(Order):
@staticmethod
def get_api_url():
"""Return the API URL associated with the SalesOrder model"""
return reverse('api-so-list')
OVERDUE_FILTER = Q(status__in=SalesOrderStatus.OPEN) & ~Q(target_date=None) & Q(target_date__lte=datetime.now().date())
@staticmethod
def filterByDate(queryset, min_date, max_date):
"""
Filter by "minimum and maximum date range"
"""Filter by "minimum and maximum date range".
- Specified as min_date, max_date
- Both must be specified for filter to be applied
@ -602,7 +598,6 @@ class SalesOrder(Order):
- A "pending" order where the target date lies within the date range
- TODO: An "overdue" order where the target date is in the past
"""
date_fmt = '%Y-%m-%d' # ISO format date string
# Ensure that both dates are valid
@ -625,19 +620,14 @@ class SalesOrder(Order):
return queryset
def save(self, *args, **kwargs):
self.rebuild_reference_field()
super().save(*args, **kwargs)
def __str__(self):
"""Render a string representation of this SalesOrder"""
prefix = getSetting('SALESORDER_REFERENCE_PREFIX')
return f"{prefix}{self.reference} - {self.customer.name if self.customer else _('deleted')}"
def get_absolute_url(self):
"""Return the web URL for the detail view of this order"""
return reverse('so-detail', kwargs={'pk': self.id})
reference = models.CharField(
@ -682,12 +672,10 @@ class SalesOrder(Order):
@property
def is_overdue(self):
"""
Returns true if this SalesOrder is "overdue":
"""Returns true if this SalesOrder is "overdue".
Makes use of the OVERDUE_FILTER to avoid code duplication.
"""
query = SalesOrder.objects.filter(pk=self.pk)
query = query.filter(SalesOrder.OVERDUE_FILTER)
@ -695,21 +683,18 @@ class SalesOrder(Order):
@property
def is_pending(self):
"""Return True if this order is 'pending'"""
return self.status == SalesOrderStatus.PENDING
@property
def stock_allocations(self):
"""
Return a queryset containing all allocations for this order
"""
"""Return a queryset containing all allocations for this order."""
return SalesOrderAllocation.objects.filter(
line__in=[line.pk for line in self.lines.all()]
)
def is_fully_allocated(self):
""" Return True if all line items are fully allocated """
"""Return True if all line items are fully allocated."""
for line in self.lines.all():
if not line.is_fully_allocated():
return False
@ -717,8 +702,7 @@ class SalesOrder(Order):
return True
def is_over_allocated(self):
""" Return true if any lines in the order are over-allocated """
"""Return true if any lines in the order are over-allocated."""
for line in self.lines.all():
if line.is_over_allocated():
return True
@ -726,19 +710,14 @@ class SalesOrder(Order):
return False
def is_completed(self):
"""
Check if this order is "shipped" (all line items delivered),
"""
"""Check if this order is "shipped" (all line items delivered)."""
return self.lines.count() > 0 and all([line.is_completed() for line in self.lines.all()])
def can_complete(self, raise_error=False):
"""
Test if this SalesOrder can be completed.
"""Test if this SalesOrder can be completed.
Throws a ValidationError if cannot be completed.
"""
try:
# Order without line items cannot be completed
@ -765,10 +744,7 @@ class SalesOrder(Order):
return True
def complete_order(self, user):
"""
Mark this order as "complete"
"""
"""Mark this order as "complete."""
if not self.can_complete():
return False
@ -783,10 +759,7 @@ class SalesOrder(Order):
return True
def can_cancel(self):
"""
Return True if this order can be cancelled
"""
"""Return True if this order can be cancelled."""
if self.status != SalesOrderStatus.PENDING:
return False
@ -794,13 +767,12 @@ class SalesOrder(Order):
@transaction.atomic
def cancel_order(self):
"""
Cancel this order (only if it is "pending")
"""Cancel this order (only if it is "pending").
Executes:
- Mark the order as 'cancelled'
- Delete any StockItems which have been allocated
"""
if not self.can_cancel():
return False
@ -817,59 +789,54 @@ class SalesOrder(Order):
@property
def line_count(self):
"""Return the total number of lines associated with this order"""
return self.lines.count()
def completed_line_items(self):
"""
Return a queryset of the completed line items for this order
"""
"""Return a queryset of the completed line items for this order."""
return self.lines.filter(shipped__gte=F('quantity'))
def pending_line_items(self):
"""
Return a queryset of the pending line items for this order
"""
"""Return a queryset of the pending line items for this order."""
return self.lines.filter(shipped__lt=F('quantity'))
@property
def completed_line_count(self):
"""Return the number of completed lines for this order"""
return self.completed_line_items().count()
@property
def pending_line_count(self):
"""Return the number of pending (incomplete) lines associated with this order"""
return self.pending_line_items().count()
def completed_shipments(self):
"""
Return a queryset of the completed shipments for this order
"""
"""Return a queryset of the completed shipments for this order."""
return self.shipments.exclude(shipment_date=None)
def pending_shipments(self):
"""
Return a queryset of the pending shipments for this order
"""
"""Return a queryset of the pending shipments for this order."""
return self.shipments.filter(shipment_date=None)
@property
def shipment_count(self):
"""Return the total number of shipments associated with this order"""
return self.shipments.count()
@property
def completed_shipment_count(self):
"""Return the number of completed shipments associated with this order"""
return self.completed_shipments().count()
@property
def pending_shipment_count(self):
"""Return the number of pending shipments associated with this order"""
return self.pending_shipments().count()
@receiver(post_save, sender=SalesOrder, dispatch_uid='build_post_save_log')
def after_save_sales_order(sender, instance: SalesOrder, created: bool, **kwargs):
"""
Callback function to be executed after a SalesOrder instance is saved
"""
"""Callback function to be executed after a SalesOrder instance is saved."""
if created and getSetting('SALESORDER_DEFAULT_SHIPMENT'):
# A new SalesOrder has just been created
@ -881,37 +848,37 @@ def after_save_sales_order(sender, instance: SalesOrder, created: bool, **kwargs
class PurchaseOrderAttachment(InvenTreeAttachment):
"""
Model for storing file attachments against a PurchaseOrder object
"""
"""Model for storing file attachments against a PurchaseOrder object."""
@staticmethod
def get_api_url():
"""Return the API URL associated with the PurchaseOrderAttachment model"""
return reverse('api-po-attachment-list')
def getSubdir(self):
"""Return the directory path where PurchaseOrderAttachment files are located"""
return os.path.join("po_files", str(self.order.id))
order = models.ForeignKey(PurchaseOrder, on_delete=models.CASCADE, related_name="attachments")
class SalesOrderAttachment(InvenTreeAttachment):
"""
Model for storing file attachments against a SalesOrder object
"""
"""Model for storing file attachments against a SalesOrder object."""
@staticmethod
def get_api_url():
"""Return the API URL associated with the SalesOrderAttachment class"""
return reverse('api-so-attachment-list')
def getSubdir(self):
"""Return the directory path where SalesOrderAttachment files are located"""
return os.path.join("so_files", str(self.order.id))
order = models.ForeignKey(SalesOrder, on_delete=models.CASCADE, related_name='attachments')
class OrderLineItem(models.Model):
""" Abstract model for an order line item
"""Abstract model for an order line item.
Attributes:
quantity: Number of items
@ -929,6 +896,8 @@ class OrderLineItem(models.Model):
OVERDUE_FILTER = Q(received__lt=F('quantity')) & ~Q(target_date=None) & Q(target_date__lt=datetime.now().date())
class Meta:
"""Metaclass options. Abstract ensures no database table is created."""
abstract = True
quantity = RoundingDecimalField(
@ -951,16 +920,16 @@ class OrderLineItem(models.Model):
class OrderExtraLine(OrderLineItem):
"""
Abstract Model for a single ExtraLine in a Order
"""Abstract Model for a single ExtraLine in a Order.
Attributes:
price: The unit sale price for this OrderLineItem
"""
class Meta:
"""Metaclass options. Abstract ensures no database table is created."""
abstract = True
unique_together = [
]
context = models.JSONField(
blank=True, null=True,
@ -976,30 +945,24 @@ class OrderExtraLine(OrderLineItem):
help_text=_('Unit price'),
)
def price_converted(self):
return convert_money(self.price, currency_code_default())
def price_converted_currency(self):
return currency_code_default()
class PurchaseOrderLineItem(OrderLineItem):
""" Model for a purchase order line item.
"""Model for a purchase order line item.
Attributes:
order: Reference to a PurchaseOrder object
"""
class Meta:
unique_together = (
)
@staticmethod
def get_api_url():
"""Return the API URL associated with the PurchaseOrderLineItem model"""
return reverse('api-po-line-list')
def clean(self):
"""Custom clean method for the PurchaseOrderLineItem model:
- Ensure the supplier part matches the supplier
"""
super().clean()
if self.order.supplier and self.part:
@ -1010,6 +973,7 @@ class PurchaseOrderLineItem(OrderLineItem):
})
def __str__(self):
"""Render a string representation of a PurchaseOrderLineItem instance"""
return "{n} x {part} from {supplier} (for {po})".format(
n=decimal2string(self.quantity),
part=self.part.SKU if self.part else 'unknown part',
@ -1024,8 +988,7 @@ class PurchaseOrderLineItem(OrderLineItem):
)
def get_base_part(self):
"""
Return the base part.Part object for the line item
"""Return the base part.Part object for the line item.
Note: Returns None if the SupplierPart is not set!
"""
@ -1067,14 +1030,12 @@ class PurchaseOrderLineItem(OrderLineItem):
)
def get_destination(self):
"""
Show where the line item is or should be placed
"""Show where the line item is or should be placed.
NOTE: If a line item gets split when recieved, only an arbitrary
stock items location will be reported as the location for the
entire line.
"""
for stock in stock_models.StockItem.objects.filter(supplier_part=self.part, purchase_order=self.order):
if stock.location:
return stock.location
@ -1084,14 +1045,14 @@ class PurchaseOrderLineItem(OrderLineItem):
return self.part.part.default_location
def remaining(self):
""" Calculate the number of items remaining to be received """
"""Calculate the number of items remaining to be received."""
r = self.quantity - self.received
return max(r, 0)
class PurchaseOrderExtraLine(OrderExtraLine):
"""
Model for a single ExtraLine in a PurchaseOrder
"""Model for a single ExtraLine in a PurchaseOrder.
Attributes:
order: Link to the PurchaseOrder that this line belongs to
title: title of line
@ -1099,14 +1060,14 @@ class PurchaseOrderExtraLine(OrderExtraLine):
"""
@staticmethod
def get_api_url():
"""Return the API URL associated with the PurchaseOrderExtraLine model"""
return reverse('api-po-extra-line-list')
order = models.ForeignKey(PurchaseOrder, on_delete=models.CASCADE, related_name='extra_lines', verbose_name=_('Order'), help_text=_('Purchase Order'))
class SalesOrderLineItem(OrderLineItem):
"""
Model for a single LineItem in a SalesOrder
"""Model for a single LineItem in a SalesOrder.
Attributes:
order: Link to the SalesOrder that this line item belongs to
@ -1117,6 +1078,7 @@ class SalesOrderLineItem(OrderLineItem):
@staticmethod
def get_api_url():
"""Return the API URL associated with the SalesOrderLineItem model"""
return reverse('api-so-line-list')
order = models.ForeignKey(
@ -1145,52 +1107,39 @@ class SalesOrderLineItem(OrderLineItem):
validators=[MinValueValidator(0)]
)
class Meta:
unique_together = [
]
def fulfilled_quantity(self):
"""
Return the total stock quantity fulfilled against this line item.
"""
"""Return the total stock quantity fulfilled against this line item."""
query = self.order.stock_items.filter(part=self.part).aggregate(fulfilled=Coalesce(Sum('quantity'), Decimal(0)))
return query['fulfilled']
def allocated_quantity(self):
""" Return the total stock quantity allocated to this LineItem.
"""Return the total stock quantity allocated to this LineItem.
This is a summation of the quantity of each attached StockItem
"""
query = self.allocations.aggregate(allocated=Coalesce(Sum('quantity'), Decimal(0)))
return query['allocated']
def is_fully_allocated(self):
""" Return True if this line item is fully allocated """
"""Return True if this line item is fully allocated."""
if self.order.status == SalesOrderStatus.SHIPPED:
return self.fulfilled_quantity() >= self.quantity
return self.allocated_quantity() >= self.quantity
def is_over_allocated(self):
""" Return True if this line item is over allocated """
"""Return True if this line item is over allocated."""
return self.allocated_quantity() > self.quantity
def is_completed(self):
"""
Return True if this line item is completed (has been fully shipped)
"""
"""Return True if this line item is completed (has been fully shipped)."""
return self.shipped >= self.quantity
class SalesOrderShipment(models.Model):
"""
The SalesOrderShipment model represents a physical shipment made against a SalesOrder.
"""The SalesOrderShipment model represents a physical shipment made against a SalesOrder.
- Points to a single SalesOrder object
- Multiple SalesOrderAllocation objects point to a particular SalesOrderShipment
@ -1205,6 +1154,7 @@ class SalesOrderShipment(models.Model):
"""
class Meta:
"""Metaclass defines extra model options"""
# Shipment reference must be unique for a given sales order
unique_together = [
'order', 'reference',
@ -1212,6 +1162,7 @@ class SalesOrderShipment(models.Model):
@staticmethod
def get_api_url():
"""Return the API URL associated with the SalesOrderShipment model"""
return reverse('api-so-shipment-list')
order = models.ForeignKey(
@ -1275,10 +1226,11 @@ class SalesOrderShipment(models.Model):
)
def is_complete(self):
"""Return True if this shipment has already been completed"""
return self.shipment_date is not None
def check_can_complete(self, raise_error=True):
"""Check if this shipment is able to be completed"""
try:
if self.shipment_date:
# Shipment has already been sent!
@ -1297,14 +1249,13 @@ class SalesOrderShipment(models.Model):
@transaction.atomic
def complete_shipment(self, user, **kwargs):
"""
Complete this particular shipment:
"""Complete this particular shipment.
Executes:
1. Update any stock items associated with this shipment
2. Update the "shipped" quantity of all associated line items
3. Set the "shipment_date" to now
"""
# Check if the shipment can be completed (throw error if not)
self.check_can_complete()
@ -1343,8 +1294,8 @@ class SalesOrderShipment(models.Model):
class SalesOrderExtraLine(OrderExtraLine):
"""
Model for a single ExtraLine in a SalesOrder
"""Model for a single ExtraLine in a SalesOrder.
Attributes:
order: Link to the SalesOrder that this line belongs to
title: title of line
@ -1352,40 +1303,37 @@ class SalesOrderExtraLine(OrderExtraLine):
"""
@staticmethod
def get_api_url():
"""Return the API URL associated with the SalesOrderExtraLine model"""
return reverse('api-so-extra-line-list')
order = models.ForeignKey(SalesOrder, on_delete=models.CASCADE, related_name='extra_lines', verbose_name=_('Order'), help_text=_('Sales Order'))
class SalesOrderAllocation(models.Model):
"""
This model is used to 'allocate' stock items to a SalesOrder.
Items that are "allocated" to a SalesOrder are not yet "attached" to the order,
but they will be once the order is fulfilled.
"""This model is used to 'allocate' stock items to a SalesOrder. Items that are "allocated" to a SalesOrder are not yet "attached" to the order, but they will be once the order is fulfilled.
Attributes:
line: SalesOrderLineItem reference
shipment: SalesOrderShipment reference
item: StockItem reference
quantity: Quantity to take from the StockItem
"""
@staticmethod
def get_api_url():
"""Return the API URL associated with the SalesOrderAllocation model"""
return reverse('api-so-allocation-list')
def clean(self):
"""
Validate the SalesOrderAllocation object:
"""Validate the SalesOrderAllocation object.
Executes:
- Cannot allocate stock to a line item without a part reference
- The referenced part must match the part associated with the line item
- Allocated quantity cannot exceed the quantity of the stock item
- Allocation quantity must be "1" if the StockItem is serialized
- Allocation quantity cannot be zero
"""
super().clean()
errors = {}
@ -1452,29 +1400,21 @@ class SalesOrderAllocation(models.Model):
quantity = RoundingDecimalField(max_digits=15, decimal_places=5, validators=[MinValueValidator(0)], default=1, verbose_name=_('Quantity'), help_text=_('Enter stock allocation quantity'))
def get_serial(self):
return self.item.serial
def get_location(self):
"""Return the <pk> value of the location associated with this allocation"""
return self.item.location.id if self.item.location else None
def get_location_path(self):
if self.item.location:
return self.item.location.pathstring
else:
return ""
def get_po(self):
"""Return the PurchaseOrder associated with this allocation"""
return self.item.purchase_order
def complete_allocation(self, user):
"""
Complete this allocation (called when the parent SalesOrder is marked as "shipped"):
"""Complete this allocation (called when the parent SalesOrder is marked as "shipped").
Executes:
- Determine if the referenced StockItem needs to be "split" (if allocated quantity != stock quantity)
- Mark the StockItem as belonging to the Customer (this will remove it from stock)
"""
order = self.line.order
item = self.item.allocateToCustomer(

Some files were not shown because too many files have changed in this diff Show More