[CI] docstrings (#6172)

* Squashed commit of the following:

commit 52d7ff0f65
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 23:03:20 2024 +0100

    fixed lookup

commit 0d076eaea8
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 23:03:08 2024 +0100

    switched to pathlib for lookup

commit 473e75eda2
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:52:30 2024 +0100

    fix wrong url response

commit fd74f8d703
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 21:14:38 2024 +0100

    switched to ruff for import sorting

commit f83fedbbb8
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 21:03:14 2024 +0100

    switched to single quotes everywhere

commit a92442e60e
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:58:23 2024 +0100

    added autofixes

commit cc66c93136
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:56:47 2024 +0100

    enable autoformat

commit 1f343606ec
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:42:14 2024 +0100

    Squashed commit of the following:

    commit f5cf7b2e78
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 20:36:57 2024 +0100

        fixed reqs

    commit 9d845bee98
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 20:32:35 2024 +0100

        disable autofix/format

    commit aff5f27148
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 20:28:50 2024 +0100

        adjust checks

    commit 47271cf1ef
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 20:28:22 2024 +0100

        reorder order of operations

    commit e1bf178b40
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 20:01:09 2024 +0100

        adapted ruff settings to better fit code base

    commit ad7d88a6f4
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 19:59:45 2024 +0100

        auto fixed docstring

    commit a2e54a760e
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 19:46:35 2024 +0100

        fix getattr useage

    commit cb80c73bc6
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 19:25:09 2024 +0100

        fix requirements file

    commit b7780bbd21
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:42:28 2024 +0100

        fix removed sections

    commit 71f1681f55
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:41:21 2024 +0100

        fix djlint syntax

    commit a0bcf1bcce
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:35:28 2024 +0100

        remove flake8 from code base

    commit 22475b31cc
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:34:56 2024 +0100

        remove flake8 from code base

    commit 0413350f14
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:24:39 2024 +0100

        moved ruff section

    commit d90c48a0bf
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:24:24 2024 +0100

        move djlint config to pyproject

    commit c5ce55d511
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:20:39 2024 +0100

        added isort again

    commit 42a41d23af
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:19:02 2024 +0100

        move config section

    commit 8569233181
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 18:17:52 2024 +0100

        fix codespell error

    commit 2897c6704d
    Author: Matthias Mair <code@mjmair.com>
    Date:   Sun Jan 7 17:29:21 2024 +0100

        replaced flake8 with ruff
        mostly for speed improvements

* enable docstring checks

* fix docstrings

* fixed D417 Missing argument description

* Squashed commit of the following:

commit d3b795824b
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:56:17 2024 +0100

    fixed source path

commit 0bac0c19b8
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:47:53 2024 +0100

    fixed req

commit 9f61f01d9c
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:45:18 2024 +0100

    added missing toml req

commit 91b71ed24a
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:49:50 2024 +0100

    moved isort config

commit 12460b0419
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:43:22 2024 +0100

    remove flake8 section from setup.cfg

commit f5cf7b2e78
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:36:57 2024 +0100

    fixed reqs

commit 9d845bee98
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:32:35 2024 +0100

    disable autofix/format

commit aff5f27148
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:50 2024 +0100

    adjust checks

commit 47271cf1ef
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:22 2024 +0100

    reorder order of operations

commit e1bf178b40
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:01:09 2024 +0100

    adapted ruff settings to better fit code base

commit ad7d88a6f4
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:59:45 2024 +0100

    auto fixed docstring

commit a2e54a760e
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:46:35 2024 +0100

    fix getattr useage

commit cb80c73bc6
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:25:09 2024 +0100

    fix requirements file

commit b7780bbd21
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:42:28 2024 +0100

    fix removed sections

commit 71f1681f55
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:41:21 2024 +0100

    fix djlint syntax

commit a0bcf1bcce
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:35:28 2024 +0100

    remove flake8 from code base

commit 22475b31cc
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:34:56 2024 +0100

    remove flake8 from code base

commit 0413350f14
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:39 2024 +0100

    moved ruff section

commit d90c48a0bf
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:24 2024 +0100

    move djlint config to pyproject

commit c5ce55d511
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:20:39 2024 +0100

    added isort again

commit 42a41d23af
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:19:02 2024 +0100

    move config section

commit 8569233181
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:17:52 2024 +0100

    fix codespell error

commit 2897c6704d
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 17:29:21 2024 +0100

    replaced flake8 with ruff
    mostly for speed improvements

* fix pyproject

* make docstrings more uniform

* auto-format

* fix order

* revert url change
This commit is contained in:
Matthias Mair 2024-01-11 04:40:07 +01:00 committed by GitHub
parent 9db3efa085
commit 9d0264c319
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
168 changed files with 1659 additions and 1754 deletions

View File

@ -1,4 +1,4 @@
"""Admin classes""" """Admin classes."""
from django.contrib import admin from django.contrib import admin
from django.http.request import HttpRequest from django.http.request import HttpRequest
@ -10,7 +10,7 @@ from import_export.resources import ModelResource
class InvenTreeResource(ModelResource): class InvenTreeResource(ModelResource):
"""Custom subclass of the ModelResource class provided by django-import-export" """Custom subclass of the ModelResource class provided by django-import-export".
Ensures that exported data are escaped to prevent malicious formula injection. Ensures that exported data are escaped to prevent malicious formula injection.
Ref: https://owasp.org/www-community/attacks/CSV_Injection Ref: https://owasp.org/www-community/attacks/CSV_Injection
@ -32,7 +32,7 @@ class InvenTreeResource(ModelResource):
rollback_on_validation_errors=None, rollback_on_validation_errors=None,
**kwargs, **kwargs,
): ):
"""Override the default import_data_inner function to provide better error handling""" """Override the default import_data_inner function to provide better error handling."""
if len(dataset) > self.MAX_IMPORT_ROWS: if len(dataset) > self.MAX_IMPORT_ROWS:
raise ImportExportError( raise ImportExportError(
f'Dataset contains too many rows (max {self.MAX_IMPORT_ROWS})' f'Dataset contains too many rows (max {self.MAX_IMPORT_ROWS})'
@ -76,7 +76,7 @@ class InvenTreeResource(ModelResource):
return row return row
def get_fields(self, **kwargs): def get_fields(self, **kwargs):
"""Return fields, with some common exclusions""" """Return fields, with some common exclusions."""
fields = super().get_fields(**kwargs) fields = super().get_fields(**kwargs)
fields_to_exclude = ['metadata', 'lft', 'rght', 'tree_id', 'level'] fields_to_exclude = ['metadata', 'lft', 'rght', 'tree_id', 'level']
@ -84,18 +84,17 @@ class InvenTreeResource(ModelResource):
return [f for f in fields if f.column_name not in fields_to_exclude] return [f for f in fields if f.column_name not in fields_to_exclude]
def before_import_row(self, row, row_number=None, **kwargs): def before_import_row(self, row, row_number=None, **kwargs):
"""Run custom code before importing each row""" """Run custom code before importing each row."""
for field in self.CONVERT_NULL_FIELDS: for field in self.CONVERT_NULL_FIELDS:
if field in row and row[field] is None: if field in row and row[field] is None:
row[field] = '' row[field] = ''
class CustomRateAdmin(RateAdmin): class CustomRateAdmin(RateAdmin):
"""Admin interface for the Rate class""" """Admin interface for the Rate class."""
def has_add_permission(self, request: HttpRequest) -> bool: def has_add_permission(self, request: HttpRequest) -> bool:
"""Disable the 'add' permission for Rate objects""" """Disable the 'add' permission for Rate objects."""
return False return False

View File

@ -98,7 +98,7 @@ class InfoView(AjaxView):
permission_classes = [permissions.AllowAny] permission_classes = [permissions.AllowAny]
def worker_pending_tasks(self): def worker_pending_tasks(self):
"""Return the current number of outstanding background tasks""" """Return the current number of outstanding background tasks."""
return OrmQ.objects.count() return OrmQ.objects.count()
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
@ -158,7 +158,7 @@ class NotFoundView(AjaxView):
permission_classes = [permissions.AllowAny] permission_classes = [permissions.AllowAny]
def not_found(self, request): def not_found(self, request):
"""Return a 404 error""" """Return a 404 error."""
return JsonResponse( return JsonResponse(
{ {
'detail': _('API endpoint not found'), 'detail': _('API endpoint not found'),
@ -168,27 +168,27 @@ class NotFoundView(AjaxView):
) )
def options(self, request, *args, **kwargs): def options(self, request, *args, **kwargs):
"""Return 404""" """Return 404."""
return self.not_found(request) return self.not_found(request)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Return 404""" """Return 404."""
return self.not_found(request) return self.not_found(request)
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
"""Return 404""" """Return 404."""
return self.not_found(request) return self.not_found(request)
def patch(self, request, *args, **kwargs): def patch(self, request, *args, **kwargs):
"""Return 404""" """Return 404."""
return self.not_found(request) return self.not_found(request)
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
"""Return 404""" """Return 404."""
return self.not_found(request) return self.not_found(request)
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
"""Return 404""" """Return 404."""
return self.not_found(request) return self.not_found(request)
@ -204,7 +204,7 @@ class BulkDeleteMixin:
""" """
def filter_delete_queryset(self, queryset, request): def filter_delete_queryset(self, queryset, request):
"""Provide custom filtering for the queryset *before* it is deleted""" """Provide custom filtering for the queryset *before* it is deleted."""
return queryset return queryset
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
@ -270,7 +270,7 @@ class BulkDeleteMixin:
class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI): class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI):
"""Custom API endpoint which provides BulkDelete functionality in addition to List and Create""" """Custom API endpoint which provides BulkDelete functionality in addition to List and Create."""
... ...
@ -328,7 +328,7 @@ class AttachmentMixin:
class APISearchView(APIView): class APISearchView(APIView):
"""A general-purpose 'search' API endpoint """A general-purpose 'search' API endpoint.
Returns hits against a number of different models simultaneously, Returns hits against a number of different models simultaneously,
to consolidate multiple API requests into a single query. to consolidate multiple API requests into a single query.
@ -339,7 +339,7 @@ class APISearchView(APIView):
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
def get_result_types(self): def get_result_types(self):
"""Construct a list of search types we can return""" """Construct a list of search types we can return."""
import build.api import build.api
import company.api import company.api
import order.api import order.api
@ -361,7 +361,7 @@ class APISearchView(APIView):
} }
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
"""Perform search query against available models""" """Perform search query against available models."""
data = request.data data = request.data
results = {} results = {}
@ -424,12 +424,12 @@ class APISearchView(APIView):
class MetadataView(RetrieveUpdateAPI): class MetadataView(RetrieveUpdateAPI):
"""Generic API endpoint for reading and editing metadata for a model""" """Generic API endpoint for reading and editing metadata for a model."""
MODEL_REF = 'model' MODEL_REF = 'model'
def get_model_type(self): def get_model_type(self):
"""Return the model type associated with this API instance""" """Return the model type associated with this API instance."""
model = self.kwargs.get(self.MODEL_REF, None) model = self.kwargs.get(self.MODEL_REF, None)
if model is None: if model is None:
@ -440,13 +440,13 @@ class MetadataView(RetrieveUpdateAPI):
return model return model
def get_permission_model(self): def get_permission_model(self):
"""Return the 'permission' model associated with this view""" """Return the 'permission' model associated with this view."""
return self.get_model_type() return self.get_model_type()
def get_queryset(self): def get_queryset(self):
"""Return the queryset for this endpoint""" """Return the queryset for this endpoint."""
return self.get_model_type().objects.all() return self.get_model_type().objects.all()
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return MetadataSerializer instance""" """Return MetadataSerializer instance."""
return MetadataSerializer(self.get_model_type(), *args, **kwargs) return MetadataSerializer(self.get_model_type(), *args, **kwargs)

View File

@ -60,7 +60,7 @@ def to_dict(value):
def is_true(x): def is_true(x):
"""Shortcut function to determine if a value "looks" like a boolean""" """Shortcut function to determine if a value "looks" like a boolean."""
return str(x).strip().lower() in ['1', 'y', 'yes', 't', 'true', 'on'] return str(x).strip().lower() in ['1', 'y', 'yes', 't', 'true', 'on']
@ -226,12 +226,12 @@ def get_setting(env_var=None, config_key=None, default_value=None, typecast=None
def get_boolean_setting(env_var=None, config_key=None, default_value=False): def get_boolean_setting(env_var=None, config_key=None, default_value=False):
"""Helper function for retrieving a boolean configuration setting""" """Helper function for retrieving a boolean configuration setting."""
return is_true(get_setting(env_var, config_key, default_value)) return is_true(get_setting(env_var, config_key, default_value))
def get_media_dir(create=True): def get_media_dir(create=True):
"""Return the absolute path for the 'media' directory (where uploaded files are stored)""" """Return the absolute path for the 'media' directory (where uploaded files are stored)."""
md = get_setting('INVENTREE_MEDIA_ROOT', 'media_root') md = get_setting('INVENTREE_MEDIA_ROOT', 'media_root')
if not md: if not md:
@ -246,7 +246,7 @@ def get_media_dir(create=True):
def get_static_dir(create=True): def get_static_dir(create=True):
"""Return the absolute path for the 'static' directory (where static files are stored)""" """Return the absolute path for the 'static' directory (where static files are stored)."""
sd = get_setting('INVENTREE_STATIC_ROOT', 'static_root') sd = get_setting('INVENTREE_STATIC_ROOT', 'static_root')
if not sd: if not sd:
@ -261,7 +261,7 @@ def get_static_dir(create=True):
def get_backup_dir(create=True): def get_backup_dir(create=True):
"""Return the absolute path for the backup directory""" """Return the absolute path for the backup directory."""
bd = get_setting('INVENTREE_BACKUP_DIR', 'backup_dir') bd = get_setting('INVENTREE_BACKUP_DIR', 'backup_dir')
if not bd: if not bd:
@ -307,7 +307,7 @@ def get_plugin_file():
def get_plugin_dir(): def get_plugin_dir():
"""Returns the path of the custom plugins directory""" """Returns the path of the custom plugins directory."""
return get_setting('INVENTREE_PLUGIN_DIR', 'plugin_dir') return get_setting('INVENTREE_PLUGIN_DIR', 'plugin_dir')
@ -389,7 +389,6 @@ def get_frontend_settings(debug=True):
Note that the new config settings use the 'FRONTEND' key, Note that the new config settings use the 'FRONTEND' key,
whereas the legacy key was 'PUI' (platform UI) which is now deprecated whereas the legacy key was 'PUI' (platform UI) which is now deprecated
""" """
# Legacy settings # Legacy settings
pui_settings = get_setting( pui_settings = get_setting(
'INVENTREE_PUI_SETTINGS', 'pui_settings', {}, typecast=dict 'INVENTREE_PUI_SETTINGS', 'pui_settings', {}, typecast=dict

View File

@ -151,7 +151,7 @@ def convert_physical_value(value: str, unit: str = None, strip_units=True):
def is_dimensionless(value): def is_dimensionless(value):
"""Determine if the provided value is 'dimensionless' """Determine if the provided value is 'dimensionless'.
A dimensionless value might look like: A dimensionless value might look like:

View File

@ -80,7 +80,7 @@ class InvenTreeExchange(SimpleExchangeBackend):
@atomic @atomic
def update_rates(self, base_currency=None, **kwargs): def update_rates(self, base_currency=None, **kwargs):
"""Call to update all exchange rates""" """Call to update all exchange rates."""
backend, _ = ExchangeBackend.objects.update_or_create( backend, _ = ExchangeBackend.objects.update_or_create(
name=self.name, defaults={'base_currency': base_currency} name=self.name, defaults={'base_currency': base_currency}
) )

View File

@ -31,8 +31,7 @@ class InvenTreeRestURLField(RestURLField):
self.validators[-1].schemes = allowable_url_schemes() self.validators[-1].schemes = allowable_url_schemes()
def run_validation(self, data=empty): def run_validation(self, data=empty):
"""Override default validation behaviour for this field type""" """Override default validation behaviour for this field type."""
import common.models import common.models
strict_urls = common.models.InvenTreeSetting.get_setting( strict_urls = common.models.InvenTreeSetting.get_setting(
@ -53,7 +52,7 @@ class InvenTreeURLField(models.URLField):
default_validators = [AllowedURLValidator()] default_validators = [AllowedURLValidator()]
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Initialization method for InvenTreeURLField""" """Initialization method for InvenTreeURLField."""
# Max length for InvenTreeURLField is set to 200 # Max length for InvenTreeURLField is set to 200
kwargs['max_length'] = 200 kwargs['max_length'] = 200
super().__init__(**kwargs) super().__init__(**kwargs)
@ -199,13 +198,13 @@ class RoundingDecimalField(models.DecimalField):
class InvenTreeNotesField(models.TextField): class InvenTreeNotesField(models.TextField):
"""Custom implementation of a 'notes' field""" """Custom implementation of a 'notes' field."""
# Maximum character limit for the various 'notes' fields # Maximum character limit for the various 'notes' fields
NOTES_MAX_LENGTH = 50000 NOTES_MAX_LENGTH = 50000
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""Configure default initial values for this field""" """Configure default initial values for this field."""
kwargs['max_length'] = self.NOTES_MAX_LENGTH kwargs['max_length'] = self.NOTES_MAX_LENGTH
kwargs['verbose_name'] = _('Notes') kwargs['verbose_name'] = _('Notes')
kwargs['blank'] = True kwargs['blank'] = True

View File

@ -17,7 +17,6 @@ class InvenTreeDateFilter(rest_filters.DateFilter):
def filter(self, qs, value): def filter(self, qs, value):
"""Override the filter method to handle timezones correctly.""" """Override the filter method to handle timezones correctly."""
if settings.USE_TZ: if settings.USE_TZ:
if value is not None: if value is not None:
tz = timezone.get_current_timezone() tz = timezone.get_current_timezone()
@ -28,7 +27,7 @@ class InvenTreeDateFilter(rest_filters.DateFilter):
class InvenTreeSearchFilter(filters.SearchFilter): class InvenTreeSearchFilter(filters.SearchFilter):
"""Custom search filter which allows adjusting of search terms dynamically""" """Custom search filter which allows adjusting of search terms dynamically."""
def get_search_fields(self, view, request): def get_search_fields(self, view, request):
"""Return a set of search fields for the request, adjusted based on request params. """Return a set of search fields for the request, adjusted based on request params.

View File

@ -1,4 +1,4 @@
"""Custom string formatting functions and helpers""" """Custom string formatting functions and helpers."""
import re import re
import string import string
@ -42,7 +42,7 @@ def parse_format_string(fmt_string: str) -> dict:
def construct_format_regex(fmt_string: str) -> str: def construct_format_regex(fmt_string: str) -> str:
r"""Construct a regular expression based on a provided format string r"""Construct a regular expression based on a provided format string.
This function turns a python format string into a regular expression, This function turns a python format string into a regular expression,
which can be used for two purposes: which can be used for two purposes:
@ -143,7 +143,7 @@ def validate_string(value: str, fmt_string: str) -> str:
def extract_named_group(name: str, value: str, fmt_string: str) -> str: def extract_named_group(name: str, value: str, fmt_string: str) -> str:
"""Extract a named value from the provided string, given the provided format string """Extract a named value from the provided string, given the provided format string.
Args: Args:
name: Name of group to extract e.g. 'ref' name: Name of group to extract e.g. 'ref'
@ -181,11 +181,12 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str:
def format_money(money: Money, decimal_places: int = None, format: str = None) -> str: def format_money(money: Money, decimal_places: int = None, format: str = None) -> str:
"""Format money object according to the currently set local """Format money object according to the currently set local.
Args: Args:
decimal_places: Number of decimal places to use money (Money): The money object to format
format: Format pattern according LDML / the babel format pattern syntax (https://babel.pocoo.org/en/latest/numbers.html) decimal_places (int): Number of decimal places to use
format (str): Format pattern according LDML / the babel format pattern syntax (https://babel.pocoo.org/en/latest/numbers.html)
Returns: Returns:
str: The formatted string str: The formatted string

View File

@ -151,7 +151,7 @@ class SetPasswordForm(HelperForm):
# override allauth # override allauth
class CustomLoginForm(LoginForm): class CustomLoginForm(LoginForm):
"""Custom login form to override default allauth behaviour""" """Custom login form to override default allauth behaviour."""
def login(self, request, redirect_url=None): def login(self, request, redirect_url=None):
"""Perform login action. """Perform login action.
@ -317,7 +317,7 @@ class CustomAccountAdapter(
return False return False
def get_email_confirmation_url(self, request, emailconfirmation): def get_email_confirmation_url(self, request, emailconfirmation):
"""Construct the email confirmation url""" """Construct the email confirmation url."""
from InvenTree.helpers_model import construct_absolute_url from InvenTree.helpers_model import construct_absolute_url
url = super().get_email_confirmation_url(request, emailconfirmation) url = super().get_email_confirmation_url(request, emailconfirmation)
@ -363,7 +363,6 @@ class CustomSocialAccountAdapter(
self, request, provider_id, error=None, exception=None, extra_context=None self, request, provider_id, error=None, exception=None, extra_context=None
): ):
"""Callback method for authentication errors.""" """Callback method for authentication errors."""
# Log the error to the database # Log the error to the database
log_error(request.path if request else 'sso') log_error(request.path if request else 'sso')
logger.error("SSO error for provider '%s' - check admin error log", provider_id) logger.error("SSO error for provider '%s' - check admin error log", provider_id)

View File

@ -116,7 +116,7 @@ def getLogoImage(as_file=False, custom=True):
def getSplashScreen(custom=True): def getSplashScreen(custom=True):
"""Return the InvenTree splash screen, or a custom splash if available""" """Return the InvenTree splash screen, or a custom splash if available."""
static_storage = StaticFilesStorage() static_storage = StaticFilesStorage()
if custom and settings.CUSTOM_SPLASH: if custom and settings.CUSTOM_SPLASH:
@ -161,7 +161,7 @@ def str2bool(text, test=True):
def str2int(text, default=None): def str2int(text, default=None):
"""Convert a string to int if possible """Convert a string to int if possible.
Args: Args:
text: Int like string text: Int like string
@ -334,10 +334,9 @@ def MakeBarcode(cls_name, object_pk: int, object_data=None, **kwargs):
"""Generate a string for a barcode. Adds some global InvenTree parameters. """Generate a string for a barcode. Adds some global InvenTree parameters.
Args: Args:
object_type: string describing the object type e.g. 'StockItem' cls_name: string describing the object type e.g. 'StockItem'
object_id: ID (Primary Key) of the object in the database object_pk (int): ID (Primary Key) of the object in the database
object_url: url for JSON API detail view of the object object_data: Python dict object containing extra data which will be rendered to string (must only contain stringable values)
data: Python dict object containing extra data which will be rendered to string (must only contain stringable values)
Returns: Returns:
json string of the supplied data plus some other data json string of the supplied data plus some other data
@ -479,13 +478,12 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
errors = [] errors = []
def add_error(error: str): def add_error(error: str):
"""Helper function for adding an error message""" """Helper function for adding an error message."""
if error not in errors: if error not in errors:
errors.append(error) errors.append(error)
def add_serial(serial): def add_serial(serial):
"""Helper function to check for duplicated values""" """Helper function to check for duplicated values."""
serial = serial.strip() serial = serial.strip()
# Ignore blank / empty serials # Ignore blank / empty serials
@ -753,7 +751,7 @@ def strip_html_tags(value: str, raise_error=True, field_name=None):
def remove_non_printable_characters( def remove_non_printable_characters(
value: str, remove_newline=True, remove_ascii=True, remove_unicode=True value: str, remove_newline=True, remove_ascii=True, remove_unicode=True
): ):
"""Remove non-printable / control characters from the provided string""" """Remove non-printable / control characters from the provided string."""
cleaned = value cleaned = value
if remove_ascii: if remove_ascii:

View File

@ -73,7 +73,7 @@ def construct_absolute_url(*arg, **kwargs):
def get_base_url(**kwargs): def get_base_url(**kwargs):
"""Return the base URL for the InvenTree server""" """Return the base URL for the InvenTree server."""
return construct_absolute_url('', **kwargs) return construct_absolute_url('', **kwargs)
@ -193,7 +193,7 @@ def render_currency(
min_decimal_places=None, min_decimal_places=None,
max_decimal_places=None, max_decimal_places=None,
): ):
"""Render a currency / Money object to a formatted string (e.g. for reports) """Render a currency / Money object to a formatted string (e.g. for reports).
Arguments: Arguments:
money: The Money instance to be rendered money: The Money instance to be rendered

View File

@ -18,7 +18,6 @@ class DiffMixin:
Returns: Returns:
object: Instance of the object saved in the database object: Instance of the object saved in the database
""" """
if self.pk: if self.pk:
try: try:
return self.__class__.objects.get(pk=self.pk) return self.__class__.objects.get(pk=self.pk)
@ -36,7 +35,6 @@ class DiffMixin:
Returns: Returns:
dict: Dict of field deltas dict: Dict of field deltas
""" """
db_instance = self.get_db_instance() db_instance = self.get_db_instance()
if db_instance is None: if db_instance is None:
@ -58,7 +56,6 @@ class DiffMixin:
def has_field_changed(self, field_name): def has_field_changed(self, field_name):
"""Determine if a particular field has changed.""" """Determine if a particular field has changed."""
return field_name in self.get_field_deltas() return field_name in self.get_field_deltas()

View File

@ -95,6 +95,7 @@ class MetadataMixin(models.Model):
Args: Args:
key: String key for requesting metadata. e.g. if a plugin is accessing the metadata, the plugin slug should be used key: String key for requesting metadata. e.g. if a plugin is accessing the metadata, the plugin slug should be used
backup_value: Value that should be used if no value is found
Returns: Returns:
Python dict object containing requested metadata. If no matching metadata is found, returns None Python dict object containing requested metadata. If no matching metadata is found, returns None
@ -228,7 +229,7 @@ class ReferenceIndexingMixin(models.Model):
@classmethod @classmethod
def get_most_recent_item(cls): def get_most_recent_item(cls):
"""Return the item which is 'most recent' """Return the item which is 'most recent'.
In practice, this means the item with the highest reference value In practice, this means the item with the highest reference value
""" """
@ -270,7 +271,7 @@ class ReferenceIndexingMixin(models.Model):
@classmethod @classmethod
def generate_reference(cls): def generate_reference(cls):
"""Generate the next 'reference' field based on specified pattern""" """Generate the next 'reference' field based on specified pattern."""
fmt = cls.get_reference_pattern() fmt = cls.get_reference_pattern()
ctx = cls.get_reference_context() ctx = cls.get_reference_context()
@ -309,7 +310,7 @@ class ReferenceIndexingMixin(models.Model):
@classmethod @classmethod
def validate_reference_pattern(cls, pattern): def validate_reference_pattern(cls, pattern):
"""Ensure that the provided pattern is valid""" """Ensure that the provided pattern is valid."""
ctx = cls.get_reference_context() ctx = cls.get_reference_context()
try: try:
@ -334,7 +335,7 @@ class ReferenceIndexingMixin(models.Model):
@classmethod @classmethod
def validate_reference_field(cls, value): def validate_reference_field(cls, value):
"""Check that the provided 'reference' value matches the requisite pattern""" """Check that the provided 'reference' value matches the requisite pattern."""
pattern = cls.get_reference_pattern() pattern = cls.get_reference_pattern()
value = str(value).strip() value = str(value).strip()
@ -713,7 +714,6 @@ class InvenTreeTree(MPTTModel):
C) delete_children = False and delete_items = True C) delete_children = False and delete_items = True
D) delete_children = False and delete_items = False D) delete_children = False and delete_items = False
""" """
child_nodes = self.get_descendants(include_self=False) child_nodes = self.get_descendants(include_self=False)
# Case A: Delete all child items, and all child nodes. # Case A: Delete all child items, and all child nodes.
@ -756,7 +756,6 @@ class InvenTreeTree(MPTTModel):
Arguments: Arguments:
nodes: A queryset of nodes to delete nodes: A queryset of nodes to delete
""" """
nodes.update(parent=None) nodes.update(parent=None)
nodes.delete() nodes.delete()
@ -782,11 +781,11 @@ class InvenTreeTree(MPTTModel):
return {'parent': {'exclude_tree': self.pk}} return {'parent': {'exclude_tree': self.pk}}
def construct_pathstring(self): def construct_pathstring(self):
"""Construct the pathstring for this tree node""" """Construct the pathstring for this tree node."""
return InvenTree.helpers.constructPathString([item.name for item in self.path]) return InvenTree.helpers.constructPathString([item.name for item in self.path])
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Custom save method for InvenTreeTree abstract model""" """Custom save method for InvenTreeTree abstract model."""
try: try:
super().save(*args, **kwargs) super().save(*args, **kwargs)
except InvalidMove: except InvalidMove:
@ -1003,7 +1002,6 @@ class InvenTreeBarcodeMixin(models.Model):
def format_matched_response(self): def format_matched_response(self):
"""Format a standard response for a matched barcode.""" """Format a standard response for a matched barcode."""
data = {'pk': self.pk} data = {'pk': self.pk}
if hasattr(self, 'get_api_url'): if hasattr(self, 'get_api_url'):
@ -1017,7 +1015,7 @@ class InvenTreeBarcodeMixin(models.Model):
@property @property
def barcode(self): def barcode(self):
"""Format a minimal barcode string (e.g. for label printing)""" """Format a minimal barcode string (e.g. for label printing)."""
return self.format_barcode(brief=True) return self.format_barcode(brief=True)
@classmethod @classmethod
@ -1055,7 +1053,7 @@ class InvenTreeBarcodeMixin(models.Model):
return True return True
def unassign_barcode(self): def unassign_barcode(self):
"""Unassign custom barcode from this model""" """Unassign custom barcode from this model."""
self.barcode_data = '' self.barcode_data = ''
self.barcode_hash = '' self.barcode_hash = ''

View File

@ -8,7 +8,7 @@ import users.models
def get_model_for_view(view, raise_error=True): def get_model_for_view(view, raise_error=True):
"""Attempt to introspect the 'model' type for an API view""" """Attempt to introspect the 'model' type for an API view."""
if hasattr(view, 'get_permission_model'): if hasattr(view, 'get_permission_model'):
return view.get_permission_model() return view.get_permission_model()

View File

@ -16,7 +16,7 @@ logger = logging.getLogger('inventree')
def default_sentry_dsn(): def default_sentry_dsn():
"""Return the default Sentry.io DSN for InvenTree""" """Return the default Sentry.io DSN for InvenTree."""
return 'https://3928ccdba1d34895abde28031fd00100@o378676.ingest.sentry.io/6494600' return 'https://3928ccdba1d34895abde28031fd00100@o378676.ingest.sentry.io/6494600'
@ -36,7 +36,7 @@ def sentry_ignore_errors():
def init_sentry(dsn, sample_rate, tags): def init_sentry(dsn, sample_rate, tags):
"""Initialize sentry.io error reporting""" """Initialize sentry.io error reporting."""
logger.info('Initializing sentry.io integration') logger.info('Initializing sentry.io integration')
sentry_sdk.init( sentry_sdk.init(
@ -62,7 +62,7 @@ def init_sentry(dsn, sample_rate, tags):
def report_exception(exc): def report_exception(exc):
"""Report an exception to sentry.io""" """Report an exception to sentry.io."""
if settings.SENTRY_ENABLED and settings.SENTRY_DSN: if settings.SENTRY_ENABLED and settings.SENTRY_DSN:
if not any(isinstance(exc, e) for e in sentry_ignore_errors()): if not any(isinstance(exc, e) for e in sentry_ignore_errors()):
logger.info('Reporting exception to sentry.io: %s', exc) logger.info('Reporting exception to sentry.io: %s', exc)

View File

@ -75,10 +75,10 @@ class InvenTreeMoneySerializer(MoneyField):
class InvenTreeCurrencySerializer(serializers.ChoiceField): class InvenTreeCurrencySerializer(serializers.ChoiceField):
"""Custom serializers for selecting currency option""" """Custom serializers for selecting currency option."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialize the currency serializer""" """Initialize the currency serializer."""
choices = currency_code_mappings() choices = currency_code_mappings()
allow_blank = kwargs.get('allow_blank', False) or kwargs.get( allow_blank = kwargs.get('allow_blank', False) or kwargs.get(
@ -289,7 +289,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
return self.instance return self.instance
def create(self, validated_data): def create(self, validated_data):
"""Custom create method which supports field adjustment""" """Custom create method which supports field adjustment."""
initial_data = validated_data.copy() initial_data = validated_data.copy()
# Remove any fields which do not exist on the model # Remove any fields which do not exist on the model
@ -494,7 +494,7 @@ class InvenTreeAttachmentSerializer(InvenTreeModelSerializer):
@staticmethod @staticmethod
def attachment_fields(extra_fields=None): def attachment_fields(extra_fields=None):
"""Default set of fields for an attachment serializer""" """Default set of fields for an attachment serializer."""
fields = [ fields = [
'pk', 'pk',
'attachment', 'attachment',
@ -827,7 +827,7 @@ class RemoteImageMixin(metaclass=serializers.SerializerMetaclass):
""" """
def skip_create_fields(self): def skip_create_fields(self):
"""Ensure the 'remote_image' field is skipped when creating a new instance""" """Ensure the 'remote_image' field is skipped when creating a new instance."""
return ['remote_image'] return ['remote_image']
remote_image = serializers.URLField( remote_image = serializers.URLField(

View File

@ -22,7 +22,7 @@ logger = logging.getLogger('inventree')
class GenericOAuth2ApiLoginView(OAuth2LoginView): class GenericOAuth2ApiLoginView(OAuth2LoginView):
"""Api view to login a user with a social account""" """Api view to login a user with a social account."""
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
"""Dispatch the regular login view directly.""" """Dispatch the regular login view directly."""
@ -30,7 +30,7 @@ class GenericOAuth2ApiLoginView(OAuth2LoginView):
class GenericOAuth2ApiConnectView(GenericOAuth2ApiLoginView): class GenericOAuth2ApiConnectView(GenericOAuth2ApiLoginView):
"""Api view to connect a social account to the current user""" """Api view to connect a social account to the current user."""
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
"""Dispatch the connect request directly.""" """Dispatch the connect request directly."""

View File

@ -1,4 +1,4 @@
"""Helper functions for Single Sign On functionality""" """Helper functions for Single Sign On functionality."""
import logging import logging
@ -9,8 +9,7 @@ logger = logging.getLogger('inventree')
def get_provider_app(provider): def get_provider_app(provider):
"""Return the SocialApp object for the given provider""" """Return the SocialApp object for the given provider."""
from allauth.socialaccount.models import SocialApp from allauth.socialaccount.models import SocialApp
try: try:
@ -36,7 +35,6 @@ def check_provider(provider, raise_error=False):
- Provider must either have a registered SocialApp - Provider must either have a registered SocialApp
- Must have at least one site enabled - Must have at least one site enabled
""" """
import allauth.app_settings import allauth.app_settings
# First, check that the provider is enabled # First, check that the provider is enabled
@ -56,8 +54,7 @@ def check_provider(provider, raise_error=False):
def provider_display_name(provider): def provider_display_name(provider):
"""Return the 'display name' for the given provider""" """Return the 'display name' for the given provider."""
if app := get_provider_app(provider): if app := get_provider_app(provider):
return app.name return app.name
@ -66,15 +63,15 @@ def provider_display_name(provider):
def login_enabled() -> bool: def login_enabled() -> bool:
"""Return True if SSO login is enabled""" """Return True if SSO login is enabled."""
return str2bool(InvenTreeSetting.get_setting('LOGIN_ENABLE_SSO')) return str2bool(InvenTreeSetting.get_setting('LOGIN_ENABLE_SSO'))
def registration_enabled() -> bool: def registration_enabled() -> bool:
"""Return True if SSO registration is enabled""" """Return True if SSO registration is enabled."""
return str2bool(InvenTreeSetting.get_setting('LOGIN_ENABLE_SSO_REG')) return str2bool(InvenTreeSetting.get_setting('LOGIN_ENABLE_SSO_REG'))
def auto_registration_enabled() -> bool: def auto_registration_enabled() -> bool:
"""Return True if SSO auto-registration is enabled""" """Return True if SSO auto-registration is enabled."""
return str2bool(InvenTreeSetting.get_setting('LOGIN_SIGNUP_SSO_AUTO')) return str2bool(InvenTreeSetting.get_setting('LOGIN_SIGNUP_SSO_AUTO'))

View File

@ -158,7 +158,7 @@ class BuildStatusGroups:
class ReturnOrderStatus(StatusCode): class ReturnOrderStatus(StatusCode):
"""Defines a set of status codes for a ReturnOrder""" """Defines a set of status codes for a ReturnOrder."""
# Order is pending, waiting for receipt of items # Order is pending, waiting for receipt of items
PENDING = 10, _('Pending'), 'secondary' PENDING = 10, _('Pending'), 'secondary'
@ -177,7 +177,7 @@ class ReturnOrderStatusGroups:
class ReturnOrderLineStatus(StatusCode): class ReturnOrderLineStatus(StatusCode):
"""Defines a set of status codes for a ReturnOrderLineItem""" """Defines a set of status codes for a ReturnOrderLineItem."""
PENDING = 10, _('Pending'), 'secondary' PENDING = 10, _('Pending'), 'secondary'

View File

@ -78,8 +78,8 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool:
"""Check if a periodic task should be run, based on the provided setting name. """Check if a periodic task should be run, based on the provided setting name.
Arguments: Arguments:
task_name: The name of the task being run, e.g. 'dummy_task' task_name (str): The name of the task being run, e.g. 'dummy_task'
setting_name: The name of the global setting, e.g. 'INVENTREE_DUMMY_TASK_INTERVAL' n_days (int): The number of days between task runs (default = 1)
Returns: Returns:
bool: If the task should be run *now*, or wait another day bool: If the task should be run *now*, or wait another day
@ -151,7 +151,7 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool:
def record_task_attempt(task_name: str): def record_task_attempt(task_name: str):
"""Record that a multi-day task has been attempted *now*""" """Record that a multi-day task has been attempted *now*."""
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
logger.info("Logging task attempt for '%s'", task_name) logger.info("Logging task attempt for '%s'", task_name)
@ -162,7 +162,7 @@ def record_task_attempt(task_name: str):
def record_task_success(task_name: str): def record_task_success(task_name: str):
"""Record that a multi-day task was successful *now*""" """Record that a multi-day task was successful *now*."""
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
InvenTreeSetting.set_setting( InvenTreeSetting.set_setting(
@ -365,7 +365,7 @@ def heartbeat():
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def delete_successful_tasks(): def delete_successful_tasks():
"""Delete successful task logs which are older than a specified period""" """Delete successful task logs which are older than a specified period."""
try: try:
from django_q.models import Success from django_q.models import Success
@ -389,7 +389,7 @@ def delete_successful_tasks():
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def delete_failed_tasks(): def delete_failed_tasks():
"""Delete failed task logs which are older than a specified period""" """Delete failed task logs which are older than a specified period."""
try: try:
from django_q.models import Failure from django_q.models import Failure
@ -435,7 +435,7 @@ def delete_old_error_logs():
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def delete_old_notifications(): def delete_old_notifications():
"""Delete old notification logs""" """Delete old notification logs."""
try: try:
from common.models import ( from common.models import (
InvenTreeSetting, InvenTreeSetting,
@ -552,7 +552,7 @@ def check_for_updates():
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def update_exchange_rates(force: bool = False): def update_exchange_rates(force: bool = False):
"""Update currency exchange rates """Update currency exchange rates.
Arguments: Arguments:
force: If True, force the update to run regardless of the last update time force: If True, force the update to run regardless of the last update time
@ -648,8 +648,7 @@ def check_for_migrations():
from plugin import registry from plugin import registry
def set_pending_migrations(n: int): def set_pending_migrations(n: int):
"""Helper function to inform the user about pending migrations""" """Helper function to inform the user about pending migrations."""
logger.info('There are %s pending migrations', n) logger.info('There are %s pending migrations', n)
InvenTreeSetting.set_setting('_PENDING_MIGRATIONS', n, None) InvenTreeSetting.set_setting('_PENDING_MIGRATIONS', n, None)

View File

@ -1,4 +1,4 @@
"""Custom template loader for InvenTree""" """Custom template loader for InvenTree."""
import os import os
@ -8,7 +8,7 @@ from django.template.loaders.cached import Loader as CachedLoader
class InvenTreeTemplateLoader(CachedLoader): class InvenTreeTemplateLoader(CachedLoader):
"""Custom template loader which bypasses cache for PDF export""" """Custom template loader which bypasses cache for PDF export."""
def get_template(self, template_name, skip=None): def get_template(self, template_name, skip=None):
"""Return a template object for the given template name. """Return a template object for the given template name.

View File

@ -257,12 +257,12 @@ class APITests(InvenTreeAPITestCase):
class BulkDeleteTests(InvenTreeAPITestCase): class BulkDeleteTests(InvenTreeAPITestCase):
"""Unit tests for the BulkDelete endpoints""" """Unit tests for the BulkDelete endpoints."""
superuser = True superuser = True
def test_errors(self): def test_errors(self):
"""Test that the correct errors are thrown""" """Test that the correct errors are thrown."""
url = reverse('api-stock-test-result-list') url = reverse('api-stock-test-result-list')
# DELETE without any of the required fields # DELETE without any of the required fields
@ -285,7 +285,7 @@ class BulkDeleteTests(InvenTreeAPITestCase):
class SearchTests(InvenTreeAPITestCase): class SearchTests(InvenTreeAPITestCase):
"""Unit tests for global search endpoint""" """Unit tests for global search endpoint."""
fixtures = [ fixtures = [
'category', 'category',
@ -299,7 +299,7 @@ class SearchTests(InvenTreeAPITestCase):
] ]
def test_empty(self): def test_empty(self):
"""Test empty request""" """Test empty request."""
data = ['', None, {}] data = ['', None, {}]
for d in data: for d in data:
@ -307,7 +307,7 @@ class SearchTests(InvenTreeAPITestCase):
self.assertIn('Search term must be provided', str(response.data)) self.assertIn('Search term must be provided', str(response.data))
def test_results(self): def test_results(self):
"""Test individual result types""" """Test individual result types."""
response = self.post( response = self.post(
reverse('api-search'), reverse('api-search'),
{'search': 'chair', 'limit': 3, 'part': {}, 'build': {}}, {'search': 'chair', 'limit': 3, 'part': {}, 'build': {}},
@ -339,7 +339,7 @@ class SearchTests(InvenTreeAPITestCase):
self.assertNotIn('build', response.data) self.assertNotIn('build', response.data)
def test_permissions(self): def test_permissions(self):
"""Test that users with insufficient permissions are handled correctly""" """Test that users with insufficient permissions are handled correctly."""
# First, remove all roles # First, remove all roles
for ruleset in self.group.rule_sets.all(): for ruleset in self.group.rule_sets.all():
ruleset.can_view = False ruleset.can_view = False

View File

@ -28,7 +28,7 @@ class ViewTests(InvenTreeTestCase):
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
def get_index_page(self): def get_index_page(self):
"""Retrieve the index page (used for subsequent unit tests)""" """Retrieve the index page (used for subsequent unit tests)."""
response = self.client.get('/index/') response = self.client.get('/index/')
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@ -44,7 +44,7 @@ class ViewTests(InvenTreeTestCase):
# TODO: In future, run the javascript and ensure that the panels get created! # TODO: In future, run the javascript and ensure that the panels get created!
def test_settings_page(self): def test_settings_page(self):
"""Test that the 'settings' page loads correctly""" """Test that the 'settings' page loads correctly."""
# Settings page loads # Settings page loads
url = reverse('settings') url = reverse('settings')
@ -102,7 +102,7 @@ class ViewTests(InvenTreeTestCase):
self.assertNotIn(f'panel-{panel}', content) self.assertNotIn(f'panel-{panel}', content)
def test_url_login(self): def test_url_login(self):
"""Test logging in via arguments""" """Test logging in via arguments."""
# Log out # Log out
self.client.logout() self.client.logout()
response = self.client.get('/index/') response = self.client.get('/index/')

View File

@ -40,10 +40,10 @@ from .validators import validate_overage
class ConversionTest(TestCase): class ConversionTest(TestCase):
"""Tests for conversion of physical units""" """Tests for conversion of physical units."""
def test_prefixes(self): def test_prefixes(self):
"""Test inputs where prefixes are used""" """Test inputs where prefixes are used."""
tests = { tests = {
'3': 3, '3': 3,
'3m': 3, '3m': 3,
@ -58,7 +58,7 @@ class ConversionTest(TestCase):
self.assertAlmostEqual(q, expected, 3) self.assertAlmostEqual(q, expected, 3)
def test_base_units(self): def test_base_units(self):
"""Test conversion to specified base units""" """Test conversion to specified base units."""
tests = { tests = {
'3': 3, '3': 3,
'3 dozen': 36, '3 dozen': 36,
@ -76,7 +76,7 @@ class ConversionTest(TestCase):
self.assertAlmostEqual(float(q.magnitude), expected, places=2) self.assertAlmostEqual(float(q.magnitude), expected, places=2)
def test_dimensionless_units(self): def test_dimensionless_units(self):
"""Tests for 'dimensionless' unit quantities""" """Tests for 'dimensionless' unit quantities."""
# Test some dimensionless units # Test some dimensionless units
tests = { tests = {
'ea': 1, 'ea': 1,
@ -103,7 +103,7 @@ class ConversionTest(TestCase):
self.assertAlmostEqual(q, expected, 3) self.assertAlmostEqual(q, expected, 3)
def test_invalid_units(self): def test_invalid_units(self):
"""Test conversion with bad units""" """Test conversion with bad units."""
tests = {'3': '10', '13': '-?-', '-3': 'xyz', '-12': '-12', '1/0': '1/0'} tests = {'3': '10', '13': '-?-', '-3': 'xyz', '-12': '-12', '1/0': '1/0'}
for val, unit in tests.items(): for val, unit in tests.items():
@ -111,7 +111,7 @@ class ConversionTest(TestCase):
InvenTree.conversion.convert_physical_value(val, unit) InvenTree.conversion.convert_physical_value(val, unit)
def test_invalid_values(self): def test_invalid_values(self):
"""Test conversion of invalid inputs""" """Test conversion of invalid inputs."""
inputs = ['-x', '1/0', 'xyz', '12B45C'] inputs = ['-x', '1/0', 'xyz', '12B45C']
for val in inputs: for val in inputs:
@ -124,7 +124,7 @@ class ConversionTest(TestCase):
InvenTree.conversion.convert_physical_value(val) InvenTree.conversion.convert_physical_value(val)
def test_custom_units(self): def test_custom_units(self):
"""Tests for custom unit conversion""" """Tests for custom unit conversion."""
# Start with an empty set of units # Start with an empty set of units
CustomUnit.objects.all().delete() CustomUnit.objects.all().delete()
InvenTree.conversion.reload_unit_registry() InvenTree.conversion.reload_unit_registry()
@ -193,8 +193,7 @@ class ValidatorTest(TestCase):
validate_overage('aaaa') validate_overage('aaaa')
def test_url_validation(self): def test_url_validation(self):
"""Test for AllowedURLValidator""" """Test for AllowedURLValidator."""
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
from part.models import Part, PartCategory from part.models import Part, PartCategory
@ -225,10 +224,10 @@ class ValidatorTest(TestCase):
class FormatTest(TestCase): class FormatTest(TestCase):
"""Unit tests for custom string formatting functionality""" """Unit tests for custom string formatting functionality."""
def test_parse(self): def test_parse(self):
"""Tests for the 'parse_format_string' function""" """Tests for the 'parse_format_string' function."""
# Extract data from a valid format string # Extract data from a valid format string
fmt = 'PO-{abc:02f}-{ref:04d}-{date}-???' fmt = 'PO-{abc:02f}-{ref:04d}-{date}-???'
@ -244,7 +243,7 @@ class FormatTest(TestCase):
InvenTree.format.parse_format_string(fmt) InvenTree.format.parse_format_string(fmt)
def test_create_regex(self): def test_create_regex(self):
"""Test function for creating a regex from a format string""" """Test function for creating a regex from a format string."""
tests = { tests = {
'PO-123-{ref:04f}': r'^PO\-123\-(?P<ref>.+)$', 'PO-123-{ref:04f}': r'^PO\-123\-(?P<ref>.+)$',
'{PO}-???-{ref}-{date}-22': r'^(?P<PO>.+)\-...\-(?P<ref>.+)\-(?P<date>.+)\-22$', '{PO}-???-{ref}-{date}-22': r'^(?P<PO>.+)\-...\-(?P<ref>.+)\-(?P<date>.+)\-22$',
@ -256,7 +255,7 @@ class FormatTest(TestCase):
self.assertEqual(InvenTree.format.construct_format_regex(fmt), reg) self.assertEqual(InvenTree.format.construct_format_regex(fmt), reg)
def test_validate_format(self): def test_validate_format(self):
"""Test that string validation works as expected""" """Test that string validation works as expected."""
# These tests should pass # These tests should pass
for value, pattern in { for value, pattern in {
'ABC-hello-123': '???-{q}-###', 'ABC-hello-123': '???-{q}-###',
@ -276,7 +275,7 @@ class FormatTest(TestCase):
self.assertFalse(InvenTree.format.validate_string(value, pattern)) self.assertFalse(InvenTree.format.validate_string(value, pattern))
def test_extract_value(self): def test_extract_value(self):
"""Test that we can extract named values based on a format string""" """Test that we can extract named values based on a format string."""
# Simple tests based on a straight-forward format string # Simple tests based on a straight-forward format string
fmt = 'PO-###-{ref:04d}' fmt = 'PO-###-{ref:04d}'
@ -319,8 +318,7 @@ class FormatTest(TestCase):
InvenTree.format.extract_named_group('test', 'PO-ABC-xyz', 'PO-###-{test}') InvenTree.format.extract_named_group('test', 'PO-ABC-xyz', 'PO-###-{test}')
def test_currency_formatting(self): def test_currency_formatting(self):
"""Test that currency formatting works correctly for multiple currencies""" """Test that currency formatting works correctly for multiple currencies."""
test_data = ( test_data = (
(Money(3651.285718, 'USD'), 4, '$3,651.2857'), # noqa: E201,E202 (Money(3651.285718, 'USD'), 4, '$3,651.2857'), # noqa: E201,E202
(Money(487587.849178, 'CAD'), 5, 'CA$487,587.84918'), # noqa: E201,E202 (Money(487587.849178, 'CAD'), 5, 'CA$487,587.84918'), # noqa: E201,E202
@ -352,7 +350,7 @@ class TestHelpers(TestCase):
"""Tests for InvenTree helper functions.""" """Tests for InvenTree helper functions."""
def test_absolute_url(self): def test_absolute_url(self):
"""Test helper function for generating an absolute URL""" """Test helper function for generating an absolute URL."""
base = 'https://demo.inventree.org:12345' base = 'https://demo.inventree.org:12345'
InvenTreeSetting.set_setting('INVENTREE_BASE_URL', base, change_user=None) InvenTreeSetting.set_setting('INVENTREE_BASE_URL', base, change_user=None)
@ -431,7 +429,7 @@ class TestHelpers(TestCase):
self.assertEqual(helpers.decimal2string('test'), 'test') self.assertEqual(helpers.decimal2string('test'), 'test')
def test_logo_image(self): def test_logo_image(self):
"""Test for retrieving logo image""" """Test for retrieving logo image."""
# By default, there is no custom logo provided # By default, there is no custom logo provided
logo = helpers.getLogoImage() logo = helpers.getLogoImage()
self.assertEqual(logo, '/static/img/inventree.png') self.assertEqual(logo, '/static/img/inventree.png')
@ -440,7 +438,7 @@ class TestHelpers(TestCase):
self.assertEqual(logo, f'file://{settings.STATIC_ROOT}/img/inventree.png') self.assertEqual(logo, f'file://{settings.STATIC_ROOT}/img/inventree.png')
def test_download_image(self): def test_download_image(self):
"""Test function for downloading image from remote URL""" """Test function for downloading image from remote URL."""
# Run check with a sequence of bad URLs # Run check with a sequence of bad URLs
for url in ['blog', 'htp://test.com/?', 'google', '\\invalid-url']: for url in ['blog', 'htp://test.com/?', 'google', '\\invalid-url']:
with self.assertRaises(django_exceptions.ValidationError): with self.assertRaises(django_exceptions.ValidationError):
@ -452,7 +450,6 @@ class TestHelpers(TestCase):
As the httpstat.us service occasionally refuses a connection, As the httpstat.us service occasionally refuses a connection,
we will simply try multiple times we will simply try multiple times
""" """
tries = 0 tries = 0
with self.assertRaises(expected_error): with self.assertRaises(expected_error):
@ -499,7 +496,7 @@ class TestHelpers(TestCase):
InvenTree.helpers_model.download_image_from_url(large_img, timeout=10) InvenTree.helpers_model.download_image_from_url(large_img, timeout=10)
def test_model_mixin(self): def test_model_mixin(self):
"""Test the getModelsWithMixin function""" """Test the getModelsWithMixin function."""
from InvenTree.models import InvenTreeBarcodeMixin from InvenTree.models import InvenTreeBarcodeMixin
models = InvenTree.helpers_model.getModelsWithMixin(InvenTreeBarcodeMixin) models = InvenTree.helpers_model.getModelsWithMixin(InvenTreeBarcodeMixin)
@ -1069,7 +1066,7 @@ class TestInstanceName(InvenTreeTestCase):
class TestOffloadTask(InvenTreeTestCase): class TestOffloadTask(InvenTreeTestCase):
"""Tests for offloading tasks to the background worker""" """Tests for offloading tasks to the background worker."""
fixtures = ['category', 'part', 'location', 'stock'] fixtures = ['category', 'part', 'location', 'stock']
@ -1086,7 +1083,6 @@ class TestOffloadTask(InvenTreeTestCase):
Ref: https://github.com/inventree/InvenTree/pull/3273 Ref: https://github.com/inventree/InvenTree/pull/3273
""" """
self.assertTrue( self.assertTrue(
offload_task( offload_task(
'dummy_tasks.stock', 'dummy_tasks.stock',
@ -1121,7 +1117,7 @@ class TestOffloadTask(InvenTreeTestCase):
) )
def test_daily_holdoff(self): def test_daily_holdoff(self):
"""Tests for daily task holdoff helper functions""" """Tests for daily task holdoff helper functions."""
import InvenTree.tasks import InvenTree.tasks
with self.assertLogs(logger='inventree', level='INFO') as cm: with self.assertLogs(logger='inventree', level='INFO') as cm:
@ -1179,10 +1175,10 @@ class TestOffloadTask(InvenTreeTestCase):
class BarcodeMixinTest(InvenTreeTestCase): class BarcodeMixinTest(InvenTreeTestCase):
"""Tests for the InvenTreeBarcodeMixin mixin class""" """Tests for the InvenTreeBarcodeMixin mixin class."""
def test_barcode_model_type(self): def test_barcode_model_type(self):
"""Test that the barcode_model_type property works for each class""" """Test that the barcode_model_type property works for each class."""
from part.models import Part from part.models import Part
from stock.models import StockItem, StockLocation from stock.models import StockItem, StockLocation
@ -1191,7 +1187,7 @@ class BarcodeMixinTest(InvenTreeTestCase):
self.assertEqual(StockLocation.barcode_model_type(), 'stocklocation') self.assertEqual(StockLocation.barcode_model_type(), 'stocklocation')
def test_barcode_hash(self): def test_barcode_hash(self):
"""Test that the barcode hashing function provides correct results""" """Test that the barcode hashing function provides correct results."""
# Test multiple values for the hashing function # Test multiple values for the hashing function
# This is to ensure that the hash function is always "backwards compatible" # This is to ensure that the hash function is always "backwards compatible"
hashing_tests = { hashing_tests = {
@ -1226,7 +1222,7 @@ class MagicLoginTest(InvenTreeTestCase):
"""Test magic login token generation.""" """Test magic login token generation."""
def test_generation(self): def test_generation(self):
"""Test that magic login tokens are generated correctly""" """Test that magic login tokens are generated correctly."""
# User does not exists # User does not exists
resp = self.client.post(reverse('sesame-generate'), {'email': 1}) resp = self.client.post(reverse('sesame-generate'), {'email': 1})
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)

View File

@ -1,4 +1,4 @@
"""Translation helper functions""" """Translation helper functions."""
import json import json
@ -9,7 +9,7 @@ _translation_stats = None
def reload_translation_stats(): def reload_translation_stats():
"""Reload the translation stats from the compiled file""" """Reload the translation stats from the compiled file."""
global _translation_stats global _translation_stats
STATS_FILE = settings.BASE_DIR.joinpath('InvenTree/locale_stats.json').absolute() STATS_FILE = settings.BASE_DIR.joinpath('InvenTree/locale_stats.json').absolute()
@ -39,7 +39,7 @@ def reload_translation_stats():
def get_translation_percent(lang_code): def get_translation_percent(lang_code):
"""Return the translation percentage for the given language code""" """Return the translation percentage for the given language code."""
if _translation_stats is None: if _translation_stats is None:
reload_translation_stats() reload_translation_stats()

View File

@ -1,4 +1,4 @@
"""Helper functions for unit testing / CI""" """Helper functions for unit testing / CI."""
import csv import csv
import io import io
@ -109,7 +109,7 @@ class UserMixin:
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Run setup for all tests in a given class""" """Run setup for all tests in a given class."""
super().setUpTestData() super().setUpTestData()
# Create a user to log in with # Create a user to log in with
@ -139,7 +139,7 @@ class UserMixin:
cls.assignRole(role=role, group=cls.group) cls.assignRole(role=role, group=cls.group)
def setUp(self): def setUp(self):
"""Run setup for individual test methods""" """Run setup for individual test methods."""
if self.auto_login: if self.auto_login:
self.client.login(username=self.username, password=self.password) self.client.login(username=self.username, password=self.password)
@ -198,10 +198,10 @@ class PluginMixin:
class ExchangeRateMixin: class ExchangeRateMixin:
"""Mixin class for generating exchange rate data""" """Mixin class for generating exchange rate data."""
def generate_exchange_rates(self): def generate_exchange_rates(self):
"""Helper function which generates some exchange rates to work with""" """Helper function which generates some exchange rates to work with."""
rates = {'AUD': 1.5, 'CAD': 1.7, 'GBP': 0.9, 'USD': 1.0} rates = {'AUD': 1.5, 'CAD': 1.7, 'GBP': 0.9, 'USD': 1.0}
# Create a dummy backend # Create a dummy backend
@ -253,7 +253,7 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
self.assertLess(n, value, msg=msg) self.assertLess(n, value, msg=msg)
def checkResponse(self, url, method, expected_code, response): def checkResponse(self, url, method, expected_code, response):
"""Debug output for an unexpected response""" """Debug output for an unexpected response."""
# No expected code, return # No expected code, return
if expected_code is None: if expected_code is None:
return return

View File

@ -63,7 +63,6 @@ class AllowedURLValidator(validators.URLValidator):
def __call__(self, value): def __call__(self, value):
"""Validate the URL.""" """Validate the URL."""
import common.models import common.models
self.schemes = allowable_url_schemes() self.schemes = allowable_url_schemes()

View File

@ -30,8 +30,7 @@ except (NotGitRepository, FileNotFoundError):
def checkMinPythonVersion(): def checkMinPythonVersion():
"""Check that the Python version is at least 3.9""" """Check that the Python version is at least 3.9."""
version = sys.version.split(' ')[0] version = sys.version.split(' ')[0]
docs = 'https://docs.inventree.org/en/stable/start/intro/#python-requirements' docs = 'https://docs.inventree.org/en/stable/start/intro/#python-requirements'
@ -199,7 +198,7 @@ def inventreeDjangoVersion():
def inventreePythonVersion(): def inventreePythonVersion():
"""Returns the version of python""" """Returns the version of python."""
return sys.version.split(' ')[0] return sys.version.split(' ')[0]

View File

@ -333,8 +333,8 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
"""Method for updating the object in the database. Default implementation is very simple, but can be overridden if required. """Method for updating the object in the database. Default implementation is very simple, but can be overridden if required.
Args: Args:
object - The current object, to be updated object: The current object, to be updated
form - The validated form form: The validated form
Returns: Returns:
object instance for supplied form object instance for supplied form
@ -651,7 +651,7 @@ class DatabaseStatsView(AjaxView):
class AboutView(AjaxView): class AboutView(AjaxView):
"""A view for displaying InvenTree version information""" """A view for displaying InvenTree version information."""
ajax_template_name = 'about.html' ajax_template_name = 'about.html'
ajax_form_title = _('About InvenTree') ajax_form_title = _('About InvenTree')

View File

@ -84,7 +84,7 @@ class BuildAdmin(ImportExportModelAdmin):
class BuildItemAdmin(admin.ModelAdmin): class BuildItemAdmin(admin.ModelAdmin):
"""Class for managing the BuildItem model via the admin interface""" """Class for managing the BuildItem model via the admin interface."""
list_display = ( list_display = (
'stock_item', 'stock_item',

View File

@ -110,12 +110,12 @@ class WebhookView(CsrfExemptMixin, APIView):
class CurrencyExchangeView(APIView): class CurrencyExchangeView(APIView):
"""API endpoint for displaying currency information""" """API endpoint for displaying currency information."""
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
def get(self, request, format=None): def get(self, request, format=None):
"""Return information on available currency conversions""" """Return information on available currency conversions."""
# Extract a list of all available rates # Extract a list of all available rates
try: try:
rates = Rate.objects.all() rates = Rate.objects.all()
@ -157,7 +157,7 @@ class CurrencyRefreshView(APIView):
permission_classes = [permissions.IsAuthenticated, permissions.IsAdminUser] permission_classes = [permissions.IsAuthenticated, permissions.IsAdminUser]
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
"""Performing a POST request will update currency exchange rates""" """Performing a POST request will update currency exchange rates."""
from InvenTree.tasks import update_exchange_rates from InvenTree.tasks import update_exchange_rates
update_exchange_rates(force=True) update_exchange_rates(force=True)
@ -185,7 +185,7 @@ class GlobalSettingsList(SettingsList):
serializer_class = common.serializers.GlobalSettingsSerializer serializer_class = common.serializers.GlobalSettingsSerializer
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
"""Ensure all global settings are created""" """Ensure all global settings are created."""
common.models.InvenTreeSetting.build_default_values() common.models.InvenTreeSetting.build_default_values()
return super().list(request, *args, **kwargs) return super().list(request, *args, **kwargs)
@ -241,7 +241,7 @@ class UserSettingsList(SettingsList):
serializer_class = common.serializers.UserSettingsSerializer serializer_class = common.serializers.UserSettingsSerializer
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
"""Ensure all user settings are created""" """Ensure all user settings are created."""
common.models.InvenTreeUserSetting.build_default_values(user=request.user) common.models.InvenTreeUserSetting.build_default_values(user=request.user)
return super().list(request, *args, **kwargs) return super().list(request, *args, **kwargs)
@ -361,7 +361,7 @@ class NotificationList(NotificationMessageMixin, BulkDeleteMixin, ListAPI):
return queryset return queryset
def filter_delete_queryset(self, queryset, request): def filter_delete_queryset(self, queryset, request):
"""Ensure that the user can only delete their *own* notifications""" """Ensure that the user can only delete their *own* notifications."""
queryset = queryset.filter(user=request.user) queryset = queryset.filter(user=request.user)
return queryset return queryset
@ -440,7 +440,7 @@ class NotesImageList(ListCreateAPI):
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
def perform_create(self, serializer): def perform_create(self, serializer):
"""Create (upload) a new notes image""" """Create (upload) a new notes image."""
image = serializer.save() image = serializer.save()
image.user = self.request.user image.user = self.request.user
image.save() image.save()
@ -460,7 +460,7 @@ class ProjectCodeList(ListCreateAPI):
class ProjectCodeDetail(RetrieveUpdateDestroyAPI): class ProjectCodeDetail(RetrieveUpdateDestroyAPI):
"""Detail view for a particular project code""" """Detail view for a particular project code."""
queryset = common.models.ProjectCode.objects.all() queryset = common.models.ProjectCode.objects.all()
serializer_class = common.serializers.ProjectCodeSerializer serializer_class = common.serializers.ProjectCodeSerializer
@ -468,7 +468,7 @@ class ProjectCodeDetail(RetrieveUpdateDestroyAPI):
class CustomUnitList(ListCreateAPI): class CustomUnitList(ListCreateAPI):
"""List view for custom units""" """List view for custom units."""
queryset = common.models.CustomUnit.objects.all() queryset = common.models.CustomUnit.objects.all()
serializer_class = common.serializers.CustomUnitSerializer serializer_class = common.serializers.CustomUnitSerializer
@ -477,7 +477,7 @@ class CustomUnitList(ListCreateAPI):
class CustomUnitDetail(RetrieveUpdateDestroyAPI): class CustomUnitDetail(RetrieveUpdateDestroyAPI):
"""Detail view for a particular custom unit""" """Detail view for a particular custom unit."""
queryset = common.models.CustomUnit.objects.all() queryset = common.models.CustomUnit.objects.all()
serializer_class = common.serializers.CustomUnitSerializer serializer_class = common.serializers.CustomUnitSerializer

View File

@ -19,7 +19,6 @@ class CommonConfig(AppConfig):
def ready(self): def ready(self):
"""Initialize restart flag clearance on startup.""" """Initialize restart flag clearance on startup."""
if InvenTree.ready.isRunningMigrations(): if InvenTree.ready.isRunningMigrations():
return return

View File

@ -77,14 +77,15 @@ class MetaMixin(models.Model):
class BaseURLValidator(URLValidator): class BaseURLValidator(URLValidator):
"""Validator for the InvenTree base URL: """Validator for the InvenTree base URL.
Rules:
- Allow empty value - Allow empty value
- Allow value without specified TLD (top level domain) - Allow value without specified TLD (top level domain)
""" """
def __init__(self, schemes=None, **kwargs): def __init__(self, schemes=None, **kwargs):
"""Custom init routine""" """Custom init routine."""
super().__init__(schemes, **kwargs) super().__init__(schemes, **kwargs)
# Override default host_re value - allow optional tld regex # Override default host_re value - allow optional tld regex
@ -145,7 +146,7 @@ class ProjectCode(InvenTree.models.MetadataMixin, models.Model):
class SettingsKeyType(TypedDict, total=False): class SettingsKeyType(TypedDict, total=False):
"""Type definitions for a SettingsKeyType """Type definitions for a SettingsKeyType.
Attributes: Attributes:
name: Translatable string name of the setting (required) name: Translatable string name of the setting (required)
@ -216,7 +217,7 @@ class BaseInvenTreeSetting(models.Model):
@classmethod @classmethod
def build_default_values(cls, **kwargs): def build_default_values(cls, **kwargs):
"""Ensure that all values defined in SETTINGS are present in the database """Ensure that all values defined in SETTINGS are present in the database.
If a particular setting is not present, create it with the default value If a particular setting is not present, create it with the default value
""" """
@ -269,13 +270,13 @@ class BaseInvenTreeSetting(models.Model):
@property @property
def cache_key(self): def cache_key(self):
"""Generate a unique cache key for this settings object""" """Generate a unique cache key for this settings object."""
return self.__class__.create_cache_key( return self.__class__.create_cache_key(
self.key, **self.get_filters_for_instance() self.key, **self.get_filters_for_instance()
) )
def save_to_cache(self): def save_to_cache(self):
"""Save this setting object to cache""" """Save this setting object to cache."""
ckey = self.cache_key ckey = self.cache_key
# skip saving to cache if no pk is set # skip saving to cache if no pk is set
@ -308,7 +309,7 @@ class BaseInvenTreeSetting(models.Model):
@classmethod @classmethod
def get_filters(cls, **kwargs): def get_filters(cls, **kwargs):
"""Enable to filter by other kwargs defined in cls.extra_unique_fields""" """Enable to filter by other kwargs defined in cls.extra_unique_fields."""
return { return {
key: value key: value
for key, value in kwargs.items() for key, value in kwargs.items()
@ -316,7 +317,7 @@ class BaseInvenTreeSetting(models.Model):
} }
def get_filters_for_instance(self): def get_filters_for_instance(self):
"""Enable to filter by other fields defined in self.extra_unique_fields""" """Enable to filter by other fields defined in self.extra_unique_fields."""
return { return {
key: getattr(self, key, None) key: getattr(self, key, None)
for key in self.extra_unique_fields for key in self.extra_unique_fields
@ -1090,7 +1091,7 @@ def validate_email_domains(setting):
def currency_exchange_plugins(): def currency_exchange_plugins():
"""Return a set of plugin choices which can be used for currency exchange""" """Return a set of plugin choices which can be used for currency exchange."""
try: try:
from plugin import registry from plugin import registry
@ -1102,7 +1103,7 @@ def currency_exchange_plugins():
def update_exchange_rates(setting): def update_exchange_rates(setting):
"""Update exchange rates when base currency is changed""" """Update exchange rates when base currency is changed."""
if InvenTree.ready.isImportingData(): if InvenTree.ready.isImportingData():
return return
@ -1113,7 +1114,7 @@ def update_exchange_rates(setting):
def reload_plugin_registry(setting): def reload_plugin_registry(setting):
"""When a core plugin setting is changed, reload the plugin registry""" """When a core plugin setting is changed, reload the plugin registry."""
from plugin import registry from plugin import registry
logger.info("Reloading plugin registry due to change in setting '%s'", setting.key) logger.info("Reloading plugin registry due to change in setting '%s'", setting.key)
@ -2752,7 +2753,7 @@ class NotificationEntry(MetaMixin):
class NotificationMessage(models.Model): class NotificationMessage(models.Model):
"""A NotificationMessage is a message sent to a particular user, notifying them of some *important information* """A NotificationMessage is a message sent to a particular user, notifying them of some *important information*.
Notification messages can be generated by a variety of sources. Notification messages can be generated by a variety of sources.
@ -2870,7 +2871,7 @@ class NotesImage(models.Model):
class CustomUnit(models.Model): class CustomUnit(models.Model):
"""Model for storing custom physical unit definitions """Model for storing custom physical unit definitions.
Model Attributes: Model Attributes:
name: Name of the unit name: Name of the unit
@ -2882,7 +2883,7 @@ class CustomUnit(models.Model):
""" """
def fmt_string(self): def fmt_string(self):
"""Construct a unit definition string e.g. 'dog_year = 52 * day = dy'""" """Construct a unit definition string e.g. 'dog_year = 52 * day = dy'."""
fmt = f'{self.name} = {self.definition}' fmt = f'{self.name} = {self.definition}'
if self.symbol: if self.symbol:
@ -2891,7 +2892,7 @@ class CustomUnit(models.Model):
return fmt return fmt
def clean(self): def clean(self):
"""Validate that the provided custom unit is indeed valid""" """Validate that the provided custom unit is indeed valid."""
super().clean() super().clean()
from InvenTree.conversion import get_unit_registry from InvenTree.conversion import get_unit_registry
@ -2946,7 +2947,7 @@ class CustomUnit(models.Model):
@receiver(post_save, sender=CustomUnit, dispatch_uid='custom_unit_saved') @receiver(post_save, sender=CustomUnit, dispatch_uid='custom_unit_saved')
@receiver(post_delete, sender=CustomUnit, dispatch_uid='custom_unit_deleted') @receiver(post_delete, sender=CustomUnit, dispatch_uid='custom_unit_deleted')
def after_custom_unit_updated(sender, instance, **kwargs): def after_custom_unit_updated(sender, instance, **kwargs):
"""Callback when a custom unit is updated or deleted""" """Callback when a custom unit is updated or deleted."""
# Force reload of the unit registry # Force reload of the unit registry
from InvenTree.conversion import reload_unit_registry from InvenTree.conversion import reload_unit_registry

View File

@ -263,7 +263,7 @@ class UIMessageNotification(SingleNotificationMethod):
METHOD_NAME = 'ui_message' METHOD_NAME = 'ui_message'
def get_targets(self): def get_targets(self):
"""Only send notifications for active users""" """Only send notifications for active users."""
return [target for target in self.targets if target.is_active] return [target for target in self.targets if target.is_active]
def send(self, target): def send(self, target):

View File

@ -23,14 +23,14 @@ class SettingsValueField(serializers.Field):
return instance return instance
def to_representation(self, instance): def to_representation(self, instance):
"""Return the value of the setting: """Return the value of the setting.
- Protected settings are returned as '***' Protected settings are returned as '***'
""" """
return '***' if instance.protected else str(instance.value) return '***' if instance.protected else str(instance.value)
def to_internal_value(self, data): def to_internal_value(self, data):
"""Return the internal value of the setting""" """Return the internal value of the setting."""
return str(data) return str(data)

View File

@ -11,7 +11,7 @@ logger = logging.getLogger('inventree')
def currency_code_default(): def currency_code_default():
"""Returns the default currency code (or USD if not specified)""" """Returns the default currency code (or USD if not specified)."""
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
cached_value = cache.get('currency_code_default', '') cached_value = cache.get('currency_code_default', '')

View File

@ -134,14 +134,14 @@ class SettingsTest(InvenTreeTestCase):
self.assertNotIn('SERVER_RESTART_REQUIRED', result) self.assertNotIn('SERVER_RESTART_REQUIRED', result)
def test_all_settings(self): def test_all_settings(self):
"""Make sure that the all_settings function returns correctly""" """Make sure that the all_settings function returns correctly."""
result = InvenTreeSetting.all_settings() result = InvenTreeSetting.all_settings()
self.assertIn('INVENTREE_INSTANCE', result) self.assertIn('INVENTREE_INSTANCE', result)
self.assertIsInstance(result['INVENTREE_INSTANCE'], InvenTreeSetting) self.assertIsInstance(result['INVENTREE_INSTANCE'], InvenTreeSetting)
@mock.patch('common.models.InvenTreeSetting.get_setting_definition') @mock.patch('common.models.InvenTreeSetting.get_setting_definition')
def test_check_all_settings(self, get_setting_definition): def test_check_all_settings(self, get_setting_definition):
"""Make sure that the check_all_settings function returns correctly""" """Make sure that the check_all_settings function returns correctly."""
# define partial schema # define partial schema
settings_definition = { settings_definition = {
'AB': { # key that's has not already been accessed 'AB': { # key that's has not already been accessed
@ -295,7 +295,7 @@ class SettingsTest(InvenTreeTestCase):
) # pragma: no cover ) # pragma: no cover
def test_global_setting_caching(self): def test_global_setting_caching(self):
"""Test caching operations for the global settings class""" """Test caching operations for the global settings class."""
key = 'PART_NAME_FORMAT' key = 'PART_NAME_FORMAT'
cache_key = InvenTreeSetting.create_cache_key(key) cache_key = InvenTreeSetting.create_cache_key(key)
@ -316,7 +316,7 @@ class SettingsTest(InvenTreeTestCase):
self.assertEqual(InvenTreeSetting.get_setting(key), val) self.assertEqual(InvenTreeSetting.get_setting(key), val)
def test_user_setting_caching(self): def test_user_setting_caching(self):
"""Test caching operation for the user settings class""" """Test caching operation for the user settings class."""
cache.clear() cache.clear()
# Generate a number of new users # Generate a number of new users
@ -348,7 +348,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
"""Tests for the global settings API.""" """Tests for the global settings API."""
def setUp(self): def setUp(self):
"""Ensure cache is cleared as part of test setup""" """Ensure cache is cleared as part of test setup."""
cache.clear() cache.clear()
return super().setUp() return super().setUp()
@ -825,7 +825,7 @@ class NotificationTest(InvenTreeAPITestCase):
response = self.post(url, {}, expected_code=405) response = self.post(url, {}, expected_code=405)
def test_bulk_delete(self): def test_bulk_delete(self):
"""Tests for bulk deletion of user notifications""" """Tests for bulk deletion of user notifications."""
from error_report.models import Error from error_report.models import Error
# Create some notification messages by throwing errors # Create some notification messages by throwing errors
@ -1004,17 +1004,17 @@ class ColorThemeTest(TestCase):
class CurrencyAPITests(InvenTreeAPITestCase): class CurrencyAPITests(InvenTreeAPITestCase):
"""Unit tests for the currency exchange API endpoints""" """Unit tests for the currency exchange API endpoints."""
def test_exchange_endpoint(self): def test_exchange_endpoint(self):
"""Test that the currency exchange endpoint works as expected""" """Test that the currency exchange endpoint works as expected."""
response = self.get(reverse('api-currency-exchange'), expected_code=200) response = self.get(reverse('api-currency-exchange'), expected_code=200)
self.assertIn('base_currency', response.data) self.assertIn('base_currency', response.data)
self.assertIn('exchange_rates', response.data) self.assertIn('exchange_rates', response.data)
def test_refresh_endpoint(self): def test_refresh_endpoint(self):
"""Call the 'refresh currencies' endpoint""" """Call the 'refresh currencies' endpoint."""
from djmoney.contrib.exchange.models import Rate from djmoney.contrib.exchange.models import Rate
# Delete any existing exchange rate data # Delete any existing exchange rate data
@ -1074,7 +1074,7 @@ class NotesImageTest(InvenTreeAPITestCase):
self.assertEqual(NotesImage.objects.count(), n) self.assertEqual(NotesImage.objects.count(), n)
def test_valid_image(self): def test_valid_image(self):
"""Test upload of a valid image file""" """Test upload of a valid image file."""
n = NotesImage.objects.count() n = NotesImage.objects.count()
# Construct a simple image file # Construct a simple image file
@ -1100,16 +1100,16 @@ class NotesImageTest(InvenTreeAPITestCase):
class ProjectCodesTest(InvenTreeAPITestCase): class ProjectCodesTest(InvenTreeAPITestCase):
"""Units tests for the ProjectCodes model and API endpoints""" """Units tests for the ProjectCodes model and API endpoints."""
@property @property
def url(self): def url(self):
"""Return the URL for the project code list endpoint""" """Return the URL for the project code list endpoint."""
return reverse('api-project-code-list') return reverse('api-project-code-list')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Create some initial project codes""" """Create some initial project codes."""
super().setUpTestData() super().setUpTestData()
codes = [ codes = [
@ -1122,12 +1122,12 @@ class ProjectCodesTest(InvenTreeAPITestCase):
ProjectCode.objects.bulk_create(codes) ProjectCode.objects.bulk_create(codes)
def test_list(self): def test_list(self):
"""Test that the list endpoint works as expected""" """Test that the list endpoint works as expected."""
response = self.get(self.url, expected_code=200) response = self.get(self.url, expected_code=200)
self.assertEqual(len(response.data), ProjectCode.objects.count()) self.assertEqual(len(response.data), ProjectCode.objects.count())
def test_delete(self): def test_delete(self):
"""Test we can delete a project code via the API""" """Test we can delete a project code via the API."""
n = ProjectCode.objects.count() n = ProjectCode.objects.count()
# Get the first project code # Get the first project code
@ -1143,7 +1143,7 @@ class ProjectCodesTest(InvenTreeAPITestCase):
self.assertEqual(ProjectCode.objects.count(), n - 1) self.assertEqual(ProjectCode.objects.count(), n - 1)
def test_duplicate_code(self): def test_duplicate_code(self):
"""Test that we cannot create two project codes with the same code""" """Test that we cannot create two project codes with the same code."""
# Create a new project code # Create a new project code
response = self.post( response = self.post(
self.url, self.url,
@ -1157,7 +1157,7 @@ class ProjectCodesTest(InvenTreeAPITestCase):
) )
def test_write_access(self): def test_write_access(self):
"""Test that non-staff users have read-only access""" """Test that non-staff users have read-only access."""
# By default user has staff access, can create a new project code # By default user has staff access, can create a new project code
response = self.post( response = self.post(
self.url, self.url,
@ -1196,16 +1196,16 @@ class ProjectCodesTest(InvenTreeAPITestCase):
class CustomUnitAPITest(InvenTreeAPITestCase): class CustomUnitAPITest(InvenTreeAPITestCase):
"""Unit tests for the CustomUnit API""" """Unit tests for the CustomUnit API."""
@property @property
def url(self): def url(self):
"""Return the API endpoint for the CustomUnit list""" """Return the API endpoint for the CustomUnit list."""
return reverse('api-custom-unit-list') return reverse('api-custom-unit-list')
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Construct some initial test fixture data""" """Construct some initial test fixture data."""
super().setUpTestData() super().setUpTestData()
units = [ units = [
@ -1222,12 +1222,12 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
CustomUnit.objects.bulk_create(units) CustomUnit.objects.bulk_create(units)
def test_list(self): def test_list(self):
"""Test API list functionality""" """Test API list functionality."""
response = self.get(self.url, expected_code=200) response = self.get(self.url, expected_code=200)
self.assertEqual(len(response.data), CustomUnit.objects.count()) self.assertEqual(len(response.data), CustomUnit.objects.count())
def test_edit(self): def test_edit(self):
"""Test edit permissions for CustomUnit model""" """Test edit permissions for CustomUnit model."""
unit = CustomUnit.objects.first() unit = CustomUnit.objects.first()
# Try to edit without permission # Try to edit without permission
@ -1254,7 +1254,7 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
self.assertEqual(unit.name, 'new_unit_name') self.assertEqual(unit.name, 'new_unit_name')
def test_validation(self): def test_validation(self):
"""Test that validation works as expected""" """Test that validation works as expected."""
unit = CustomUnit.objects.first() unit = CustomUnit.objects.first()
self.user.is_staff = True self.user.is_staff = True

View File

@ -1,4 +1,4 @@
"""Admin class for the 'company' app""" """Admin class for the 'company' app."""
from django.contrib import admin from django.contrib import admin
@ -25,7 +25,7 @@ class CompanyResource(InvenTreeResource):
"""Class for managing Company data import/export.""" """Class for managing Company data import/export."""
class Meta: class Meta:
"""Metaclass defines extra options""" """Metaclass defines extra options."""
model = Company model = Company
skip_unchanged = True skip_unchanged = True
@ -34,7 +34,7 @@ class CompanyResource(InvenTreeResource):
class CompanyAdmin(ImportExportModelAdmin): class CompanyAdmin(ImportExportModelAdmin):
"""Admin class for the Company model""" """Admin class for the Company model."""
resource_class = CompanyResource resource_class = CompanyResource
@ -47,7 +47,7 @@ class SupplierPartResource(InvenTreeResource):
"""Class for managing SupplierPart data import/export.""" """Class for managing SupplierPart data import/export."""
class Meta: class Meta:
"""Metaclass defines extra admin options""" """Metaclass defines extra admin options."""
model = SupplierPart model = SupplierPart
skip_unchanged = True skip_unchanged = True
@ -64,13 +64,13 @@ class SupplierPartResource(InvenTreeResource):
class SupplierPriceBreakInline(admin.TabularInline): class SupplierPriceBreakInline(admin.TabularInline):
"""Inline for supplier-part pricing""" """Inline for supplier-part pricing."""
model = SupplierPriceBreak model = SupplierPriceBreak
class SupplierPartAdmin(ImportExportModelAdmin): class SupplierPartAdmin(ImportExportModelAdmin):
"""Admin class for the SupplierPart model""" """Admin class for the SupplierPart model."""
resource_class = SupplierPartResource resource_class = SupplierPartResource
@ -87,7 +87,7 @@ class ManufacturerPartResource(InvenTreeResource):
"""Class for managing ManufacturerPart data import/export.""" """Class for managing ManufacturerPart data import/export."""
class Meta: class Meta:
"""Metaclass defines extra admin options""" """Metaclass defines extra admin options."""
model = ManufacturerPart model = ManufacturerPart
skip_unchanged = True skip_unchanged = True
@ -129,7 +129,7 @@ class ManufacturerPartParameterResource(InvenTreeResource):
"""Class for managing ManufacturerPartParameter data import/export.""" """Class for managing ManufacturerPartParameter data import/export."""
class Meta: class Meta:
"""Metaclass defines extra admin options""" """Metaclass defines extra admin options."""
model = ManufacturerPartParameter model = ManufacturerPartParameter
skip_unchanged = True skip_unchanged = True
@ -153,7 +153,7 @@ class SupplierPriceBreakResource(InvenTreeResource):
"""Class for managing SupplierPriceBreak data import/export.""" """Class for managing SupplierPriceBreak data import/export."""
class Meta: class Meta:
"""Metaclass defines extra admin options""" """Metaclass defines extra admin options."""
model = SupplierPriceBreak model = SupplierPriceBreak
skip_unchanged = True skip_unchanged = True
@ -174,7 +174,7 @@ class SupplierPriceBreakResource(InvenTreeResource):
class SupplierPriceBreakAdmin(ImportExportModelAdmin): class SupplierPriceBreakAdmin(ImportExportModelAdmin):
"""Admin class for the SupplierPriceBreak model""" """Admin class for the SupplierPriceBreak model."""
resource_class = SupplierPriceBreakResource resource_class = SupplierPriceBreakResource
@ -184,10 +184,10 @@ class SupplierPriceBreakAdmin(ImportExportModelAdmin):
class AddressResource(InvenTreeResource): class AddressResource(InvenTreeResource):
"""Class for managing Address data import/export""" """Class for managing Address data import/export."""
class Meta: class Meta:
"""Metaclass defining extra options""" """Metaclass defining extra options."""
model = Address model = Address
skip_unchanged = True skip_unchanged = True
@ -198,7 +198,7 @@ class AddressResource(InvenTreeResource):
class AddressAdmin(ImportExportModelAdmin): class AddressAdmin(ImportExportModelAdmin):
"""Admin class for the Address model""" """Admin class for the Address model."""
resource_class = AddressResource resource_class = AddressResource
@ -208,10 +208,10 @@ class AddressAdmin(ImportExportModelAdmin):
class ContactResource(InvenTreeResource): class ContactResource(InvenTreeResource):
"""Class for managing Contact data import/export""" """Class for managing Contact data import/export."""
class Meta: class Meta:
"""Metaclass defining extra options""" """Metaclass defining extra options."""
model = Contact model = Contact
skip_unchanged = True skip_unchanged = True
@ -222,7 +222,7 @@ class ContactResource(InvenTreeResource):
class ContactAdmin(ImportExportModelAdmin): class ContactAdmin(ImportExportModelAdmin):
"""Admin class for the Contact model""" """Admin class for the Contact model."""
resource_class = ContactResource resource_class = ContactResource

View File

@ -52,7 +52,7 @@ class CompanyList(ListCreateAPI):
queryset = Company.objects.all() queryset = Company.objects.all()
def get_queryset(self): def get_queryset(self):
"""Return annotated queryset for the company list endpoint""" """Return annotated queryset for the company list endpoint."""
queryset = super().get_queryset() queryset = super().get_queryset()
queryset = CompanySerializer.annotate_queryset(queryset) queryset = CompanySerializer.annotate_queryset(queryset)
@ -76,7 +76,7 @@ class CompanyDetail(RetrieveUpdateDestroyAPI):
serializer_class = CompanySerializer serializer_class = CompanySerializer
def get_queryset(self): def get_queryset(self):
"""Return annotated queryset for the company detail endpoint""" """Return annotated queryset for the company detail endpoint."""
queryset = super().get_queryset() queryset = super().get_queryset()
queryset = CompanySerializer.annotate_queryset(queryset) queryset = CompanySerializer.annotate_queryset(queryset)
@ -84,7 +84,7 @@ class CompanyDetail(RetrieveUpdateDestroyAPI):
class CompanyAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): class CompanyAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
"""API endpoint for the CompanyAttachment model""" """API endpoint for the CompanyAttachment model."""
queryset = CompanyAttachment.objects.all() queryset = CompanyAttachment.objects.all()
serializer_class = CompanyAttachmentSerializer serializer_class = CompanyAttachmentSerializer
@ -100,7 +100,7 @@ class CompanyAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI):
class ContactList(ListCreateDestroyAPIView): class ContactList(ListCreateDestroyAPIView):
"""API endpoint for list view of Company model""" """API endpoint for list view of Company model."""
queryset = Contact.objects.all() queryset = Contact.objects.all()
serializer_class = ContactSerializer serializer_class = ContactSerializer
@ -117,14 +117,14 @@ class ContactList(ListCreateDestroyAPIView):
class ContactDetail(RetrieveUpdateDestroyAPI): class ContactDetail(RetrieveUpdateDestroyAPI):
"""Detail endpoint for Company model""" """Detail endpoint for Company model."""
queryset = Contact.objects.all() queryset = Contact.objects.all()
serializer_class = ContactSerializer serializer_class = ContactSerializer
class AddressList(ListCreateDestroyAPIView): class AddressList(ListCreateDestroyAPIView):
"""API endpoint for list view of Address model""" """API endpoint for list view of Address model."""
queryset = Address.objects.all() queryset = Address.objects.all()
serializer_class = AddressSerializer serializer_class = AddressSerializer
@ -139,7 +139,7 @@ class AddressList(ListCreateDestroyAPIView):
class AddressDetail(RetrieveUpdateDestroyAPI): class AddressDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for a single Address object""" """API endpoint for a single Address object."""
queryset = Address.objects.all() queryset = Address.objects.all()
serializer_class = AddressSerializer serializer_class = AddressSerializer
@ -173,7 +173,7 @@ class ManufacturerPartList(ListCreateDestroyAPIView):
filterset_class = ManufacturerPartFilter filterset_class = ManufacturerPartFilter
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint""" """Return serializer instance for this endpoint."""
# Do we wish to include extra detail? # Do we wish to include extra detail?
try: try:
params = self.request.query_params params = self.request.query_params
@ -233,10 +233,10 @@ class ManufacturerPartAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI
class ManufacturerPartParameterFilter(rest_filters.FilterSet): class ManufacturerPartParameterFilter(rest_filters.FilterSet):
"""Custom filterset for the ManufacturerPartParameterList API endpoint""" """Custom filterset for the ManufacturerPartParameterList API endpoint."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = ManufacturerPartParameter model = ManufacturerPartParameter
fields = ['name', 'value', 'units', 'manufacturer_part'] fields = ['name', 'value', 'units', 'manufacturer_part']
@ -258,7 +258,7 @@ class ManufacturerPartParameterList(ListCreateDestroyAPIView):
filterset_class = ManufacturerPartParameterFilter filterset_class = ManufacturerPartParameterFilter
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint""" """Return serializer instance for this endpoint."""
# Do we wish to include any extra detail? # Do we wish to include any extra detail?
try: try:
params = self.request.query_params params = self.request.query_params
@ -288,10 +288,10 @@ class ManufacturerPartParameterDetail(RetrieveUpdateDestroyAPI):
class SupplierPartFilter(rest_filters.FilterSet): class SupplierPartFilter(rest_filters.FilterSet):
"""API filters for the SupplierPartList endpoint""" """API filters for the SupplierPartList endpoint."""
class Meta: class Meta:
"""Metaclass option""" """Metaclass option."""
model = SupplierPart model = SupplierPart
fields = [ fields = [
@ -325,7 +325,7 @@ class SupplierPartList(ListCreateDestroyAPIView):
filterset_class = SupplierPartFilter filterset_class = SupplierPartFilter
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return annotated queryest object for the SupplierPart list""" """Return annotated queryest object for the SupplierPart list."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = SupplierPartSerializer.annotate_queryset(queryset) queryset = SupplierPartSerializer.annotate_queryset(queryset)
@ -354,7 +354,7 @@ class SupplierPartList(ListCreateDestroyAPIView):
return queryset return queryset
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint""" """Return serializer instance for this endpoint."""
# Do we wish to include extra detail? # Do we wish to include extra detail?
try: try:
params = self.request.query_params params = self.request.query_params
@ -425,10 +425,10 @@ class SupplierPartDetail(RetrieveUpdateDestroyAPI):
class SupplierPriceBreakFilter(rest_filters.FilterSet): class SupplierPriceBreakFilter(rest_filters.FilterSet):
"""Custom API filters for the SupplierPriceBreak list endpoint""" """Custom API filters for the SupplierPriceBreak list endpoint."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = SupplierPriceBreak model = SupplierPriceBreak
fields = ['part', 'quantity'] fields = ['part', 'quantity']
@ -456,7 +456,7 @@ class SupplierPriceBreakList(ListCreateAPI):
filterset_class = SupplierPriceBreakFilter filterset_class = SupplierPriceBreakFilter
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint""" """Return serializer instance for this endpoint."""
try: try:
params = self.request.query_params params = self.request.query_params

View File

@ -1,10 +1,10 @@
"""Config for the 'company' app""" """Config for the 'company' app."""
from django.apps import AppConfig from django.apps import AppConfig
class CompanyConfig(AppConfig): class CompanyConfig(AppConfig):
"""Config class for the 'company' app""" """Config class for the 'company' app."""
name = 'company' name = 'company'

View File

@ -90,7 +90,7 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
""" """
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options."""
ordering = ['name'] ordering = ['name']
constraints = [ constraints = [
@ -100,7 +100,7 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the Company model""" """Return the API URL associated with the Company model."""
return reverse('api-company-list') return reverse('api-company-list')
name = models.CharField( name = models.CharField(
@ -186,7 +186,7 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
@property @property
def address(self): def address(self):
"""Return the string representation for the primary address """Return the string representation for the primary address.
This property exists for backwards compatibility This property exists for backwards compatibility
""" """
@ -196,7 +196,7 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
@property @property
def primary_address(self): def primary_address(self):
"""Returns address object of primary address. Parsed by serializer""" """Returns address object of primary address. Parsed by serializer."""
return Address.objects.filter(company=self.id).filter(primary=True).first() return Address.objects.filter(company=self.id).filter(primary=True).first()
@property @property
@ -251,15 +251,15 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
class CompanyAttachment(InvenTreeAttachment): class CompanyAttachment(InvenTreeAttachment):
"""Model for storing file or URL attachments against a Company object""" """Model for storing file or URL attachments against a Company object."""
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with this model""" """Return the API URL associated with this model."""
return reverse('api-company-attachment-list') return reverse('api-company-attachment-list')
def getSubdir(self): def getSubdir(self):
"""Return the subdirectory where these attachments are uploaded""" """Return the subdirectory where these attachments are uploaded."""
return os.path.join('company_files', str(self.company.pk)) return os.path.join('company_files', str(self.company.pk))
company = models.ForeignKey( company = models.ForeignKey(
@ -283,7 +283,7 @@ class Contact(MetadataMixin, models.Model):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the Contcat model""" """Return the API URL associated with the Contcat model."""
return reverse('api-contact-list') return reverse('api-contact-list')
company = models.ForeignKey( company = models.ForeignKey(
@ -300,7 +300,7 @@ class Contact(MetadataMixin, models.Model):
class Address(models.Model): class Address(models.Model):
"""An address represents a physical location where the company is located. It is possible for a company to have multiple locations """An address represents a physical location where the company is located. It is possible for a company to have multiple locations.
Attributes: Attributes:
company: Company link for this address company: Company link for this address
@ -316,16 +316,16 @@ class Address(models.Model):
""" """
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options."""
verbose_name_plural = 'Addresses' verbose_name_plural = 'Addresses'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Custom init function""" """Custom init function."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def __str__(self): def __str__(self):
"""Defines string representation of address to supple a one-line to API calls""" """Defines string representation of address to supple a one-line to API calls."""
available_lines = [ available_lines = [
self.line1, self.line1,
self.line2, self.line2,
@ -343,8 +343,9 @@ class Address(models.Model):
return ', '.join(populated_lines) return ', '.join(populated_lines)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Run checks when saving an address: """Run checks when saving an address.
Rules:
- If this address is marked as "primary", ensure that all other addresses for this company are marked as non-primary - If this address is marked as "primary", ensure that all other addresses for this company are marked as non-primary
""" """
others = list( others = list(
@ -366,7 +367,7 @@ class Address(models.Model):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the Contcat model""" """Return the API URL associated with the Contcat model."""
return reverse('api-address-list') return reverse('api-address-list')
company = models.ForeignKey( company = models.ForeignKey(
@ -465,13 +466,13 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
""" """
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options."""
unique_together = ('part', 'manufacturer', 'MPN') unique_together = ('part', 'manufacturer', 'MPN')
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the ManufacturerPart instance""" """Return the API URL associated with the ManufacturerPart instance."""
return reverse('api-manufacturer-part-list') return reverse('api-manufacturer-part-list')
part = models.ForeignKey( part = models.ForeignKey(
@ -542,7 +543,7 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
return manufacturer_part return manufacturer_part
def __str__(self): def __str__(self):
"""Format a string representation of a ManufacturerPart""" """Format a string representation of a ManufacturerPart."""
s = '' s = ''
if self.manufacturer: if self.manufacturer:
@ -559,11 +560,11 @@ class ManufacturerPartAttachment(InvenTreeAttachment):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the ManufacturerPartAttachment model""" """Return the API URL associated with the ManufacturerPartAttachment model."""
return reverse('api-manufacturer-part-attachment-list') return reverse('api-manufacturer-part-attachment-list')
def getSubdir(self): def getSubdir(self):
"""Return the subdirectory where attachment files for the ManufacturerPart model are located""" """Return the subdirectory where attachment files for the ManufacturerPart model are located."""
return os.path.join('manufacturer_part_files', str(self.manufacturer_part.id)) return os.path.join('manufacturer_part_files', str(self.manufacturer_part.id))
manufacturer_part = models.ForeignKey( manufacturer_part = models.ForeignKey(
@ -583,13 +584,13 @@ class ManufacturerPartParameter(models.Model):
""" """
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options."""
unique_together = ('manufacturer_part', 'name') unique_together = ('manufacturer_part', 'name')
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the ManufacturerPartParameter model""" """Return the API URL associated with the ManufacturerPartParameter model."""
return reverse('api-manufacturer-part-parameter-list') return reverse('api-manufacturer-part-parameter-list')
manufacturer_part = models.ForeignKey( manufacturer_part = models.ForeignKey(
@ -630,7 +631,7 @@ class SupplierPartManager(models.Manager):
""" """
def get_queryset(self): def get_queryset(self):
"""Prefetch related fields when querying against the SupplierPart model""" """Prefetch related fields when querying against the SupplierPart model."""
# Always prefetch related models # Always prefetch related models
return ( return (
super() super()
@ -660,7 +661,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
""" """
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options."""
unique_together = ('part', 'supplier', 'SKU') unique_together = ('part', 'supplier', 'SKU')
@ -673,20 +674,21 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the SupplierPart model""" """Return the API URL associated with the SupplierPart model."""
return reverse('api-supplier-part-list') return reverse('api-supplier-part-list')
def get_absolute_url(self): def get_absolute_url(self):
"""Return the web URL of the detail view for this SupplierPart""" """Return the web URL of the detail view for this SupplierPart."""
return reverse('supplier-part-detail', kwargs={'pk': self.id}) return reverse('supplier-part-detail', kwargs={'pk': self.id})
def api_instance_filters(self): def api_instance_filters(self):
"""Return custom API filters for this particular instance""" """Return custom API filters for this particular instance."""
return {'manufacturer_part': {'part': self.part.pk}} return {'manufacturer_part': {'part': self.part.pk}}
def clean(self): def clean(self):
"""Custom clean action for the SupplierPart model: """Custom clean action for the SupplierPart model.
Rules:
- Ensure that manufacturer_part.part and part are the same! - Ensure that manufacturer_part.part and part are the same!
""" """
super().clean() super().clean()
@ -888,7 +890,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
) )
def update_available_quantity(self, quantity): def update_available_quantity(self, quantity):
"""Update the available quantity for this SupplierPart""" """Update the available quantity for this SupplierPart."""
self.available = quantity self.available = quantity
self.availability_updated = datetime.now() self.availability_updated = datetime.now()
self.save() self.save()
@ -911,7 +913,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
@property @property
def has_price_breaks(self): def has_price_breaks(self):
"""Return True if this SupplierPart has associated price breaks""" """Return True if this SupplierPart has associated price breaks."""
return self.price_breaks.count() > 0 return self.price_breaks.count() > 0
@property @property
@ -921,7 +923,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
@property @property
def unit_pricing(self): def unit_pricing(self):
"""Return the single-quantity pricing for this SupplierPart""" """Return the single-quantity pricing for this SupplierPart."""
return self.get_price(1) return self.get_price(1)
def add_price_break(self, quantity, price) -> None: def add_price_break(self, quantity, price) -> None:
@ -971,11 +973,11 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
@property @property
def pretty_name(self): def pretty_name(self):
"""Format a 'pretty' name for this SupplierPart""" """Format a 'pretty' name for this SupplierPart."""
return str(self) return str(self)
def __str__(self): def __str__(self):
"""Format a string representation of a SupplierPart""" """Format a string representation of a SupplierPart."""
s = '' s = ''
if self.part.IPN: if self.part.IPN:
@ -1005,7 +1007,7 @@ class SupplierPriceBreak(common.models.PriceBreak):
""" """
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options."""
unique_together = ('part', 'quantity') unique_together = ('part', 'quantity')
@ -1013,12 +1015,12 @@ class SupplierPriceBreak(common.models.PriceBreak):
db_table = 'part_supplierpricebreak' db_table = 'part_supplierpricebreak'
def __str__(self): def __str__(self):
"""Format a string representation of a SupplierPriceBreak instance""" """Format a string representation of a SupplierPriceBreak instance."""
return f'{self.part.SKU} - {self.price} @ {self.quantity}' return f'{self.part.SKU} - {self.price} @ {self.quantity}'
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the SupplierPriceBreak model""" """Return the API URL associated with the SupplierPriceBreak model."""
return reverse('api-part-supplier-price-list') return reverse('api-part-supplier-price-list')
part = models.ForeignKey( part = models.ForeignKey(
@ -1033,7 +1035,7 @@ class SupplierPriceBreak(common.models.PriceBreak):
post_save, sender=SupplierPriceBreak, dispatch_uid='post_save_supplier_price_break' post_save, sender=SupplierPriceBreak, dispatch_uid='post_save_supplier_price_break'
) )
def after_save_supplier_price(sender, instance, created, **kwargs): def after_save_supplier_price(sender, instance, created, **kwargs):
"""Callback function when a SupplierPriceBreak is created or updated""" """Callback function when a SupplierPriceBreak is created or updated."""
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
if instance.part and instance.part.part: if instance.part and instance.part.part:
instance.part.part.schedule_pricing_update(create=True) instance.part.part.schedule_pricing_update(create=True)
@ -1045,7 +1047,7 @@ def after_save_supplier_price(sender, instance, created, **kwargs):
dispatch_uid='post_delete_supplier_price_break', dispatch_uid='post_delete_supplier_price_break',
) )
def after_delete_supplier_price(sender, instance, **kwargs): def after_delete_supplier_price(sender, instance, **kwargs):
"""Callback function when a SupplierPriceBreak is deleted""" """Callback function when a SupplierPriceBreak is deleted."""
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
if instance.part and instance.part.part: if instance.part and instance.part.part:
instance.part.part.schedule_pricing_update(create=False) instance.part.part.schedule_pricing_update(create=False)

View File

@ -36,7 +36,7 @@ from .models import (
class CompanyBriefSerializer(InvenTreeModelSerializer): class CompanyBriefSerializer(InvenTreeModelSerializer):
"""Serializer for Company object (limited detail)""" """Serializer for Company object (limited detail)."""
class Meta: class Meta:
"""Metaclass options.""" """Metaclass options."""
@ -50,10 +50,10 @@ class CompanyBriefSerializer(InvenTreeModelSerializer):
class AddressSerializer(InvenTreeModelSerializer): class AddressSerializer(InvenTreeModelSerializer):
"""Serializer for the Address Model""" """Serializer for the Address Model."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = Address model = Address
fields = [ fields = [
@ -74,10 +74,10 @@ class AddressSerializer(InvenTreeModelSerializer):
class AddressBriefSerializer(InvenTreeModelSerializer): class AddressBriefSerializer(InvenTreeModelSerializer):
"""Serializer for Address Model (limited)""" """Serializer for Address Model (limited)."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = Address model = Address
fields = [ fields = [
@ -94,7 +94,7 @@ class AddressBriefSerializer(InvenTreeModelSerializer):
class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer): class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
"""Serializer for Company object (full detail)""" """Serializer for Company object (full detail)."""
class Meta: class Meta:
"""Metaclass options.""" """Metaclass options."""
@ -127,7 +127,7 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
"""Annotate the supplied queryset with aggregated information""" """Annotate the supplied queryset with aggregated information."""
# Add count of parts manufactured # Add count of parts manufactured
queryset = queryset.annotate( queryset = queryset.annotate(
parts_manufactured=SubqueryCount('manufactured_parts') parts_manufactured=SubqueryCount('manufactured_parts')
@ -154,7 +154,7 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
) )
def save(self): def save(self):
"""Save the Company instance""" """Save the Company instance."""
super().save() super().save()
company = self.instance company = self.instance
@ -176,10 +176,10 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
class CompanyAttachmentSerializer(InvenTreeAttachmentSerializer): class CompanyAttachmentSerializer(InvenTreeAttachmentSerializer):
"""Serializer for the CompanyAttachment class""" """Serializer for the CompanyAttachment class."""
class Meta: class Meta:
"""Metaclass defines serializer options""" """Metaclass defines serializer options."""
model = CompanyAttachment model = CompanyAttachment
@ -187,10 +187,10 @@ class CompanyAttachmentSerializer(InvenTreeAttachmentSerializer):
class ContactSerializer(InvenTreeModelSerializer): class ContactSerializer(InvenTreeModelSerializer):
"""Serializer class for the Contact model""" """Serializer class for the Contact model."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = Contact model = Contact
fields = ['pk', 'company', 'name', 'phone', 'email', 'role'] fields = ['pk', 'company', 'name', 'phone', 'email', 'role']
@ -220,7 +220,7 @@ class ManufacturerPartSerializer(InvenTreeTagModelSerializer):
tags = TagListSerializerField(required=False) tags = TagListSerializerField(required=False)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialize this serializer with extra detail fields as required""" """Initialize this serializer with extra detail fields as required."""
part_detail = kwargs.pop('part_detail', True) part_detail = kwargs.pop('part_detail', True)
manufacturer_detail = kwargs.pop('manufacturer_detail', True) manufacturer_detail = kwargs.pop('manufacturer_detail', True)
prettify = kwargs.pop('pretty', False) prettify = kwargs.pop('pretty', False)
@ -278,7 +278,7 @@ class ManufacturerPartParameterSerializer(InvenTreeModelSerializer):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialize this serializer with extra detail fields as required""" """Initialize this serializer with extra detail fields as required."""
man_detail = kwargs.pop('manufacturer_part_detail', False) man_detail = kwargs.pop('manufacturer_part_detail', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -331,7 +331,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
tags = TagListSerializerField(required=False) tags = TagListSerializerField(required=False)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialize this serializer with extra detail fields as required""" """Initialize this serializer with extra detail fields as required."""
# Check if 'available' quantity was supplied # Check if 'available' quantity was supplied
self.has_available_quantity = 'available' in kwargs.get('data', {}) self.has_available_quantity = 'available' in kwargs.get('data', {})
@ -395,7 +395,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
"""Annotate the SupplierPart queryset with extra fields: """Annotate the SupplierPart queryset with extra fields.
Fields: Fields:
in_stock: Current stock quantity for each SupplierPart in_stock: Current stock quantity for each SupplierPart
@ -405,7 +405,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
return queryset return queryset
def update(self, supplier_part, data): def update(self, supplier_part, data):
"""Custom update functionality for the serializer""" """Custom update functionality for the serializer."""
available = data.pop('available', None) available = data.pop('available', None)
response = super().update(supplier_part, data) response = super().update(supplier_part, data)
@ -457,7 +457,7 @@ class SupplierPriceBreakSerializer(InvenTreeModelSerializer):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialize this serializer with extra fields as required""" """Initialize this serializer with extra fields as required."""
supplier_detail = kwargs.pop('supplier_detail', False) supplier_detail = kwargs.pop('supplier_detail', False)
part_detail = kwargs.pop('part_detail', False) part_detail = kwargs.pop('part_detail', False)

View File

@ -1,4 +1,4 @@
"""Unit testing for the company app API functions""" """Unit testing for the company app API functions."""
from django.urls import reverse from django.urls import reverse
@ -16,7 +16,7 @@ class CompanyTest(InvenTreeAPITestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Perform initialization for the unit test class""" """Perform initialization for the unit test class."""
super().setUpTestData() super().setUpTestData()
# Create some company objects to work with # Create some company objects to work with
@ -34,7 +34,7 @@ class CompanyTest(InvenTreeAPITestCase):
) )
def test_company_list(self): def test_company_list(self):
"""Test the list API endpoint for the Company model""" """Test the list API endpoint for the Company model."""
url = reverse('api-company-list') url = reverse('api-company-list')
# There should be three companies # There should be three companies
@ -133,13 +133,13 @@ class CompanyTest(InvenTreeAPITestCase):
class ContactTest(InvenTreeAPITestCase): class ContactTest(InvenTreeAPITestCase):
"""Tests for the Contact models""" """Tests for the Contact models."""
roles = [] roles = []
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Perform init for this test class""" """Perform init for this test class."""
super().setUpTestData() super().setUpTestData()
# Create some companies # Create some companies
@ -163,7 +163,7 @@ class ContactTest(InvenTreeAPITestCase):
cls.url = reverse('api-contact-list') cls.url = reverse('api-contact-list')
def test_list(self): def test_list(self):
"""Test company list API endpoint""" """Test company list API endpoint."""
# List all results # List all results
response = self.get(self.url, {}, expected_code=200) response = self.get(self.url, {}, expected_code=200)
@ -180,7 +180,7 @@ class ContactTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
def test_create(self): def test_create(self):
"""Test that we can create a new Contact object via the API""" """Test that we can create a new Contact object via the API."""
n = Contact.objects.count() n = Contact.objects.count()
company = Company.objects.first() company = Company.objects.first()
@ -199,7 +199,7 @@ class ContactTest(InvenTreeAPITestCase):
self.assertEqual(Contact.objects.count(), n + 1) self.assertEqual(Contact.objects.count(), n + 1)
def test_edit(self): def test_edit(self):
"""Test that we can edit a Contact via the API""" """Test that we can edit a Contact via the API."""
# Get the first contact # Get the first contact
contact = Contact.objects.first() contact = Contact.objects.first()
# Use this contact in the tests # Use this contact in the tests
@ -222,7 +222,7 @@ class ContactTest(InvenTreeAPITestCase):
self.assertEqual(contact.role, 'x') self.assertEqual(contact.role, 'x')
def test_delete(self): def test_delete(self):
"""Tests that we can delete a Contact via the API""" """Tests that we can delete a Contact via the API."""
# Get the last contact # Get the last contact
contact = Contact.objects.first() contact = Contact.objects.first()
url = reverse('api-contact-detail', kwargs={'pk': contact.pk}) url = reverse('api-contact-detail', kwargs={'pk': contact.pk})
@ -239,13 +239,13 @@ class ContactTest(InvenTreeAPITestCase):
class AddressTest(InvenTreeAPITestCase): class AddressTest(InvenTreeAPITestCase):
"""Test cases for Address API endpoints""" """Test cases for Address API endpoints."""
roles = [] roles = []
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Perform initialization for this test class""" """Perform initialization for this test class."""
super().setUpTestData() super().setUpTestData()
cls.num_companies = 3 cls.num_companies = 3
cls.num_addr = 3 cls.num_addr = 3
@ -271,13 +271,13 @@ class AddressTest(InvenTreeAPITestCase):
Address.objects.bulk_create(addresses) Address.objects.bulk_create(addresses)
def test_list(self): def test_list(self):
"""Test listing all addresses without filtering""" """Test listing all addresses without filtering."""
response = self.get(self.url, expected_code=200) response = self.get(self.url, expected_code=200)
self.assertEqual(len(response.data), self.num_companies * self.num_addr) self.assertEqual(len(response.data), self.num_companies * self.num_addr)
def test_filter_list(self): def test_filter_list(self):
"""Test listing addresses filtered on company""" """Test listing addresses filtered on company."""
company = Company.objects.first() company = Company.objects.first()
response = self.get(self.url, {'company': company.pk}, expected_code=200) response = self.get(self.url, {'company': company.pk}, expected_code=200)
@ -285,7 +285,7 @@ class AddressTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), self.num_addr) self.assertEqual(len(response.data), self.num_addr)
def test_create(self): def test_create(self):
"""Test creating a new address""" """Test creating a new address."""
company = Company.objects.first() company = Company.objects.first()
self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=403) self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=403)
@ -295,7 +295,7 @@ class AddressTest(InvenTreeAPITestCase):
self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=201) self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=201)
def test_get(self): def test_get(self):
"""Test that objects are properly returned from a get""" """Test that objects are properly returned from a get."""
addr = Address.objects.first() addr = Address.objects.first()
url = reverse('api-address-detail', kwargs={'pk': addr.pk}) url = reverse('api-address-detail', kwargs={'pk': addr.pk})
@ -315,7 +315,7 @@ class AddressTest(InvenTreeAPITestCase):
self.assertIn(key, response.data) self.assertIn(key, response.data)
def test_edit(self): def test_edit(self):
"""Test editing an object""" """Test editing an object."""
addr = Address.objects.first() addr = Address.objects.first()
url = reverse('api-address-detail', kwargs={'pk': addr.pk}) url = reverse('api-address-detail', kwargs={'pk': addr.pk})
@ -331,7 +331,7 @@ class AddressTest(InvenTreeAPITestCase):
self.assertEqual(data['title'], 'World') self.assertEqual(data['title'], 'World')
def test_delete(self): def test_delete(self):
"""Test deleting an object""" """Test deleting an object."""
addr = Address.objects.first() addr = Address.objects.first()
url = reverse('api-address-detail', kwargs={'pk': addr.pk}) url = reverse('api-address-detail', kwargs={'pk': addr.pk})
@ -360,7 +360,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
roles = ['part.add', 'part.change'] roles = ['part.add', 'part.change']
def test_manufacturer_part_list(self): def test_manufacturer_part_list(self):
"""Test the ManufacturerPart API list functionality""" """Test the ManufacturerPart API list functionality."""
url = reverse('api-manufacturer-part-list') url = reverse('api-manufacturer-part-list')
# There should be three manufacturer parts # There should be three manufacturer parts
@ -399,14 +399,14 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(response.data['MPN'], 'MPN-TEST-123') self.assertEqual(response.data['MPN'], 'MPN-TEST-123')
def test_manufacturer_part_search(self): def test_manufacturer_part_search(self):
"""Test search functionality in manufacturer list""" """Test search functionality in manufacturer list."""
url = reverse('api-manufacturer-part-list') url = reverse('api-manufacturer-part-list')
data = {'search': 'MPN'} data = {'search': 'MPN'}
response = self.get(url, data) response = self.get(url, data)
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
def test_supplier_part_create(self): def test_supplier_part_create(self):
"""Test a SupplierPart can be created via the API""" """Test a SupplierPart can be created via the API."""
url = reverse('api-supplier-part-list') url = reverse('api-supplier-part-list')
# Create a manufacturer part # Create a manufacturer part
@ -449,7 +449,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
class SupplierPartTest(InvenTreeAPITestCase): class SupplierPartTest(InvenTreeAPITestCase):
"""Unit tests for the SupplierPart API endpoints""" """Unit tests for the SupplierPart API endpoints."""
fixtures = [ fixtures = [
'category', 'category',
@ -463,7 +463,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
roles = ['part.add', 'part.change', 'part.add', 'purchase_order.change'] roles = ['part.add', 'part.change', 'part.add', 'purchase_order.change']
def test_supplier_part_list(self): def test_supplier_part_list(self):
"""Test the SupplierPart API list functionality""" """Test the SupplierPart API list functionality."""
url = reverse('api-supplier-part-list') url = reverse('api-supplier-part-list')
# Return *all* SupplierParts # Return *all* SupplierParts
@ -484,7 +484,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), n) self.assertEqual(len(response.data), n)
def test_available(self): def test_available(self):
"""Tests for updating the 'available' field""" """Tests for updating the 'available' field."""
url = reverse('api-supplier-part-list') url = reverse('api-supplier-part-list')
# Should fail when sending an invalid 'available' field # Should fail when sending an invalid 'available' field
@ -545,7 +545,7 @@ class CompanyMetadataAPITest(InvenTreeAPITestCase):
roles = ['company.change', 'purchase_order.change', 'part.change'] roles = ['company.change', 'purchase_order.change', 'part.change']
def metatester(self, apikey, model): def metatester(self, apikey, model):
"""Generic tester""" """Generic tester."""
modeldata = model.objects.first() modeldata = model.objects.first()
# Useless test unless a model object is found # Useless test unless a model object is found
@ -571,7 +571,7 @@ class CompanyMetadataAPITest(InvenTreeAPITestCase):
) )
def test_metadata(self): def test_metadata(self):
"""Test all endpoints""" """Test all endpoints."""
for apikey, model in { for apikey, model in {
'api-manufacturer-part-metadata': ManufacturerPart, 'api-manufacturer-part-metadata': ManufacturerPart,
'api-supplier-part-metadata': SupplierPart, 'api-supplier-part-metadata': SupplierPart,

View File

@ -6,7 +6,7 @@ from InvenTree import unit_test
class TestForwardMigrations(MigratorTestCase): class TestForwardMigrations(MigratorTestCase):
"""Unit testing class for testing 'company' app migrations""" """Unit testing class for testing 'company' app migrations."""
migrate_from = ('company', unit_test.getOldestMigrationFile('company')) migrate_from = ('company', unit_test.getOldestMigrationFile('company'))
migrate_to = ('company', unit_test.getNewestMigrationFile('company')) migrate_to = ('company', unit_test.getNewestMigrationFile('company'))
@ -20,7 +20,7 @@ class TestForwardMigrations(MigratorTestCase):
) )
def test_migrations(self): def test_migrations(self):
"""Test the database state after applying all migrations""" """Test the database state after applying all migrations."""
Company = self.new_state.apps.get_model('company', 'company') Company = self.new_state.apps.get_model('company', 'company')
self.assertEqual(Company.objects.count(), 1) self.assertEqual(Company.objects.count(), 1)
@ -33,8 +33,9 @@ class TestManufacturerField(MigratorTestCase):
migrate_to = ('company', '0019_auto_20200413_0642') migrate_to = ('company', '0019_auto_20200413_0642')
def prepare(self): def prepare(self):
"""Prepare the database by adding some test data 'before' the change: """Prepare the database by adding some test data 'before' the change.
Changes:
- Part object - Part object
- Company object (supplier) - Company object (supplier)
- SupplierPart object - SupplierPart object
@ -99,8 +100,9 @@ class TestManufacturerPart(MigratorTestCase):
migrate_to = ('company', '0037_supplierpart_update_3') migrate_to = ('company', '0037_supplierpart_update_3')
def prepare(self): def prepare(self):
"""Prepare the database by adding some test data 'before' the change: """Prepare the database by adding some test data 'before' the change.
Changes:
- Part object - Part object
- Company object (supplier) - Company object (supplier)
- SupplierPart object - SupplierPart object
@ -217,8 +219,9 @@ class TestCurrencyMigration(MigratorTestCase):
migrate_to = ('company', '0026_auto_20201110_1011') migrate_to = ('company', '0026_auto_20201110_1011')
def prepare(self): def prepare(self):
"""Prepare some data: """Prepare some data.
Changes:
- A part to buy - A part to buy
- A supplier to buy from - A supplier to buy from
- A supplier part - A supplier part
@ -269,7 +272,7 @@ class TestCurrencyMigration(MigratorTestCase):
self.assertIsNone(pb.price) self.assertIsNone(pb.price)
def test_currency_migration(self): def test_currency_migration(self):
"""Test database state after applying migrations""" """Test database state after applying migrations."""
PB = self.new_state.apps.get_model('company', 'supplierpricebreak') PB = self.new_state.apps.get_model('company', 'supplierpricebreak')
for pb in PB.objects.all(): for pb in PB.objects.all():
@ -278,7 +281,7 @@ class TestCurrencyMigration(MigratorTestCase):
class TestAddressMigration(MigratorTestCase): class TestAddressMigration(MigratorTestCase):
"""Test moving address data into Address model""" """Test moving address data into Address model."""
migrate_from = ('company', '0063_auto_20230502_1956') migrate_from = ('company', '0063_auto_20230502_1956')
migrate_to = ('company', '0064_move_address_field_to_address_model') migrate_to = ('company', '0064_move_address_field_to_address_model')
@ -289,14 +292,14 @@ class TestAddressMigration(MigratorTestCase):
l2 = 'splitting functionality' l2 = 'splitting functionality'
def prepare(self): def prepare(self):
"""Set up some companies with addresses""" """Set up some companies with addresses."""
Company = self.old_state.apps.get_model('company', 'company') Company = self.old_state.apps.get_model('company', 'company')
Company.objects.create(name='Company 1', address=self.short_l1) Company.objects.create(name='Company 1', address=self.short_l1)
Company.objects.create(name='Company 2', address=self.long_l1 + self.l2) Company.objects.create(name='Company 2', address=self.long_l1 + self.l2)
def test_address_migration(self): def test_address_migration(self):
"""Test database state after applying the migration""" """Test database state after applying the migration."""
Address = self.new_state.apps.get_model('company', 'address') Address = self.new_state.apps.get_model('company', 'address')
Company = self.new_state.apps.get_model('company', 'company') Company = self.new_state.apps.get_model('company', 'company')
@ -323,7 +326,7 @@ class TestSupplierPartQuantity(MigratorTestCase):
migrate_to = ('company', unit_test.getNewestMigrationFile('company')) migrate_to = ('company', unit_test.getNewestMigrationFile('company'))
def prepare(self): def prepare(self):
"""Prepare a number of SupplierPart objects""" """Prepare a number of SupplierPart objects."""
Part = self.old_state.apps.get_model('part', 'part') Part = self.old_state.apps.get_model('part', 'part')
Company = self.old_state.apps.get_model('company', 'company') Company = self.old_state.apps.get_model('company', 'company')
SupplierPart = self.old_state.apps.get_model('company', 'supplierpart') SupplierPart = self.old_state.apps.get_model('company', 'supplierpart')

View File

@ -1,4 +1,4 @@
"""Unit tests specific to the SupplierPart model""" """Unit tests specific to the SupplierPart model."""
from decimal import Decimal from decimal import Decimal
@ -10,10 +10,10 @@ from part.models import Part
class SupplierPartPackUnitsTests(InvenTreeTestCase): class SupplierPartPackUnitsTests(InvenTreeTestCase):
"""Unit tests for the SupplierPart pack_quantity field""" """Unit tests for the SupplierPart pack_quantity field."""
def test_pack_quantity_dimensionless(self): def test_pack_quantity_dimensionless(self):
"""Test valid values for the 'pack_quantity' field""" """Test valid values for the 'pack_quantity' field."""
# Create a part without units (dimensionless) # Create a part without units (dimensionless)
part = Part.objects.create( part = Part.objects.create(
name='Test Part', description='Test part description', component=True name='Test Part', description='Test part description', component=True
@ -48,7 +48,7 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
sp.full_clean() sp.full_clean()
def test_pack_quantity(self): def test_pack_quantity(self):
"""Test pack_quantity for a part with a specified dimension""" """Test pack_quantity for a part with a specified dimension."""
# Create a part with units 'm' # Create a part with units 'm'
part = Part.objects.create( part = Part.objects.create(
name='Test Part', name='Test Part',

View File

@ -1,4 +1,4 @@
"""Unit tests for Company views (see views.py)""" """Unit tests for Company views (see views.py)."""
from django.urls import reverse from django.urls import reverse

View File

@ -1,4 +1,4 @@
"""Unit tests for the models in the 'company' app""" """Unit tests for the models in the 'company' app."""
import os import os
from decimal import Decimal from decimal import Decimal
@ -19,7 +19,7 @@ from .models import (
class CompanySimpleTest(TestCase): class CompanySimpleTest(TestCase):
"""Unit tests for the Company model""" """Unit tests for the Company model."""
fixtures = [ fixtures = [
'company', 'company',
@ -34,7 +34,7 @@ class CompanySimpleTest(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Perform initialization for the tests in this class""" """Perform initialization for the tests in this class."""
super().setUpTestData() super().setUpTestData()
Company.objects.create( Company.objects.create(
@ -51,18 +51,18 @@ class CompanySimpleTest(TestCase):
cls.zergm312 = SupplierPart.objects.get(SKU='ZERGM312') cls.zergm312 = SupplierPart.objects.get(SKU='ZERGM312')
def test_company_model(self): def test_company_model(self):
"""Tests for the company model data""" """Tests for the company model data."""
c = Company.objects.get(name='ABC Co.') c = Company.objects.get(name='ABC Co.')
self.assertEqual(c.name, 'ABC Co.') self.assertEqual(c.name, 'ABC Co.')
self.assertEqual(str(c), 'ABC Co. - Seller of ABC products') self.assertEqual(str(c), 'ABC Co. - Seller of ABC products')
def test_company_url(self): def test_company_url(self):
"""Test the detail URL for a company""" """Test the detail URL for a company."""
c = Company.objects.get(pk=1) c = Company.objects.get(pk=1)
self.assertEqual(c.get_absolute_url(), '/company/1/') self.assertEqual(c.get_absolute_url(), '/company/1/')
def test_image_renamer(self): def test_image_renamer(self):
"""Test the company image upload functionality""" """Test the company image upload functionality."""
c = Company.objects.get(pk=1) c = Company.objects.get(pk=1)
rn = rename_company_image(c, 'test.png') rn = rename_company_image(c, 'test.png')
self.assertEqual(rn, 'company_images' + os.path.sep + 'company_1_img.png') self.assertEqual(rn, 'company_images' + os.path.sep + 'company_1_img.png')
@ -71,7 +71,7 @@ class CompanySimpleTest(TestCase):
self.assertEqual(rn, 'company_images' + os.path.sep + 'company_1_img') self.assertEqual(rn, 'company_images' + os.path.sep + 'company_1_img')
def test_price_breaks(self): def test_price_breaks(self):
"""Unit tests for price breaks""" """Unit tests for price breaks."""
self.assertTrue(self.acme0001.has_price_breaks) self.assertTrue(self.acme0001.has_price_breaks)
self.assertTrue(self.acme0002.has_price_breaks) self.assertTrue(self.acme0002.has_price_breaks)
self.assertTrue(self.zergm312.has_price_breaks) self.assertTrue(self.zergm312.has_price_breaks)
@ -100,7 +100,7 @@ class CompanySimpleTest(TestCase):
self.assertEqual(p(55), 68.75) self.assertEqual(p(55), 68.75)
def test_part_pricing(self): def test_part_pricing(self):
"""Unit tests for supplier part pricing""" """Unit tests for supplier part pricing."""
m2x4 = Part.objects.get(name='M2x4 LPHS') m2x4 = Part.objects.get(name='M2x4 LPHS')
self.assertEqual(m2x4.get_price_info(5.5), '38.5 - 41.25') self.assertEqual(m2x4.get_price_info(5.5), '38.5 - 41.25')
@ -153,10 +153,10 @@ class CompanySimpleTest(TestCase):
class ContactSimpleTest(TestCase): class ContactSimpleTest(TestCase):
"""Unit tests for the Contact model""" """Unit tests for the Contact model."""
def setUp(self): def setUp(self):
"""Initialization for the tests in this class""" """Initialization for the tests in this class."""
# Create a simple company # Create a simple company
self.c = Company.objects.create( self.c = Company.objects.create(
name='Test Corp.', description='We make stuff good' name='Test Corp.', description='We make stuff good'
@ -168,39 +168,39 @@ class ContactSimpleTest(TestCase):
Contact.objects.create(name='Sally Smith', company=self.c) Contact.objects.create(name='Sally Smith', company=self.c)
def test_exists(self): def test_exists(self):
"""Test that contacts exist""" """Test that contacts exist."""
self.assertEqual(Contact.objects.count(), 3) self.assertEqual(Contact.objects.count(), 3)
def test_delete(self): def test_delete(self):
"""Test deletion of a Contact instance""" """Test deletion of a Contact instance."""
# Remove the parent company # Remove the parent company
Company.objects.get(pk=self.c.pk).delete() Company.objects.get(pk=self.c.pk).delete()
self.assertEqual(Contact.objects.count(), 0) self.assertEqual(Contact.objects.count(), 0)
class AddressTest(TestCase): class AddressTest(TestCase):
"""Unit tests for the Address model""" """Unit tests for the Address model."""
def setUp(self): def setUp(self):
"""Initialization for the tests in this class""" """Initialization for the tests in this class."""
# Create a simple company # Create a simple company
self.c = Company.objects.create( self.c = Company.objects.create(
name='Test Corp.', description='We make stuff good' name='Test Corp.', description='We make stuff good'
) )
def test_create(self): def test_create(self):
"""Test that object creation with only company supplied is successful""" """Test that object creation with only company supplied is successful."""
Address.objects.create(company=self.c) Address.objects.create(company=self.c)
self.assertEqual(Address.objects.count(), 1) self.assertEqual(Address.objects.count(), 1)
def test_delete(self): def test_delete(self):
"""Test Address deletion""" """Test Address deletion."""
addr = Address.objects.create(company=self.c) addr = Address.objects.create(company=self.c)
addr.delete() addr.delete()
self.assertEqual(Address.objects.count(), 0) self.assertEqual(Address.objects.count(), 0)
def test_primary_constraint(self): def test_primary_constraint(self):
"""Test that there can only be one company-'primary=true' pair""" """Test that there can only be one company-'primary=true' pair."""
Address.objects.create(company=self.c, primary=True) Address.objects.create(company=self.c, primary=True)
Address.objects.create(company=self.c, primary=False) Address.objects.create(company=self.c, primary=False)
@ -216,12 +216,12 @@ class AddressTest(TestCase):
self.assertTrue(Address.objects.last().primary) self.assertTrue(Address.objects.last().primary)
def test_first_address_is_primary(self): def test_first_address_is_primary(self):
"""Test that first address related to company is always set to primary""" """Test that first address related to company is always set to primary."""
addr = Address.objects.create(company=self.c) addr = Address.objects.create(company=self.c)
self.assertTrue(addr.primary) self.assertTrue(addr.primary)
def test_model_str(self): def test_model_str(self):
"""Test value of __str__""" """Test value of __str__."""
t = 'Test address' t = 'Test address'
l1 = 'Busy street 56' l1 = 'Busy street 56'
l2 = 'Red building' l2 = 'Red building'
@ -249,12 +249,12 @@ class AddressTest(TestCase):
class ManufacturerPartSimpleTest(TestCase): class ManufacturerPartSimpleTest(TestCase):
"""Unit tests for the ManufacturerPart model""" """Unit tests for the ManufacturerPart model."""
fixtures = ['category', 'company', 'location', 'part', 'manufacturer_part'] fixtures = ['category', 'company', 'location', 'part', 'manufacturer_part']
def setUp(self): def setUp(self):
"""Initialization for the unit tests in this class""" """Initialization for the unit tests in this class."""
# Create a manufacturer part # Create a manufacturer part
self.part = Part.objects.get(pk=1) self.part = Part.objects.get(pk=1)
manufacturer = Company.objects.get(pk=1) manufacturer = Company.objects.get(pk=1)
@ -275,7 +275,7 @@ class ManufacturerPartSimpleTest(TestCase):
supplier_part.save() supplier_part.save()
def test_exists(self): def test_exists(self):
"""That that a ManufacturerPart has been created""" """That that a ManufacturerPart has been created."""
self.assertEqual(ManufacturerPart.objects.count(), 4) self.assertEqual(ManufacturerPart.objects.count(), 4)
# Check that manufacturer part was created from supplier part creation # Check that manufacturer part was created from supplier part creation
@ -283,7 +283,7 @@ class ManufacturerPartSimpleTest(TestCase):
self.assertEqual(manufacturer_parts.count(), 1) self.assertEqual(manufacturer_parts.count(), 1)
def test_delete(self): def test_delete(self):
"""Test deletion of a ManufacturerPart""" """Test deletion of a ManufacturerPart."""
Part.objects.get(pk=self.part.id).delete() Part.objects.get(pk=self.part.id).delete()
# Check that ManufacturerPart was deleted # Check that ManufacturerPart was deleted
self.assertEqual(ManufacturerPart.objects.count(), 3) self.assertEqual(ManufacturerPart.objects.count(), 3)

View File

@ -20,7 +20,7 @@ class CompanyIndex(InvenTreeRoleMixin, ListView):
permission_required = 'company.view_company' permission_required = 'company.view_company'
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
"""Add extra context data to the company index page""" """Add extra context data to the company index page."""
ctx = super().get_context_data(**kwargs) ctx = super().get_context_data(**kwargs)
# Provide custom context data to the template, # Provide custom context data to the template,

View File

@ -24,7 +24,7 @@ class StatusView(APIView):
MODEL_REF = 'statusmodel' MODEL_REF = 'statusmodel'
def get_status_model(self, *args, **kwargs): def get_status_model(self, *args, **kwargs):
"""Return the StatusCode model based on extra parameters passed to the view""" """Return the StatusCode model based on extra parameters passed to the view."""
status_model = self.kwargs.get(self.MODEL_REF, None) status_model = self.kwargs.get(self.MODEL_REF, None)
if status_model is None: if status_model is None:
@ -35,7 +35,7 @@ class StatusView(APIView):
return status_model return status_model
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Perform a GET request to learn information about status codes""" """Perform a GET request to learn information about status codes."""
status_class = self.get_status_model() status_class = self.get_status_model()
if not inspect.isclass(status_class): if not inspect.isclass(status_class):
@ -55,7 +55,7 @@ class AllStatusViews(StatusView):
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Perform a GET request to learn information about status codes""" """Perform a GET request to learn information about status codes."""
data = {} data = {}
for status_class in StatusCode.__subclasses__(): for status_class in StatusCode.__subclasses__():

View File

@ -77,7 +77,7 @@ class StatusCode(BaseEnum):
@classmethod @classmethod
def values(cls, key=None): def values(cls, key=None):
"""Return a dict representation containing all required information""" """Return a dict representation containing all required information."""
elements = [itm for itm in cls if cls._is_element(itm.name)] elements = [itm for itm in cls if cls._is_element(itm.name)]
if key is None: if key is None:
return elements return elements
@ -148,7 +148,7 @@ class StatusCode(BaseEnum):
@classmethod @classmethod
def dict(cls, key=None): def dict(cls, key=None):
"""Return a dict representation containing all required information""" """Return a dict representation containing all required information."""
return { return {
x.name: {'color': x.color, 'key': x.value, 'label': x.label, 'name': x.name} x.name: {'color': x.color, 'key': x.value, 'label': x.label, 'name': x.name}
for x in cls.values(key) for x in cls.values(key)

View File

@ -23,7 +23,6 @@ def dflt(*args, **kwargs):
def _clean_storage(refs): def _clean_storage(refs):
"""Clean the storage.""" """Clean the storage."""
for ref in refs: for ref in refs:
del ref del ref
storage.collect() storage.collect()
@ -44,7 +43,6 @@ class TransitionTests(InvenTreeTestCase):
def test_storage(self): def test_storage(self):
"""Ensure that the storage collection mechanism works.""" """Ensure that the storage collection mechanism works."""
global raise_storage global raise_storage
global raise_function global raise_function
@ -54,7 +52,6 @@ class TransitionTests(InvenTreeTestCase):
class RaisingImplementation(TransitionMethod): class RaisingImplementation(TransitionMethod):
def transition(self, *args, **kwargs): def transition(self, *args, **kwargs):
"""Custom transition method.""" """Custom transition method."""
global raise_storage global raise_storage
if raise_storage: if raise_storage:
@ -75,7 +72,6 @@ class TransitionTests(InvenTreeTestCase):
def test_function(self): def test_function(self):
"""Ensure that a TransitionMethod's function is called.""" """Ensure that a TransitionMethod's function is called."""
global raise_storage global raise_storage
global raise_function global raise_function

View File

@ -22,7 +22,7 @@ class GeneralStatus(StatusCode):
jkl = None # This should be ignored jkl = None # This should be ignored
def GHI(self): # This should be ignored def GHI(self): # This should be ignored
"""A invalid function""" """A invalid function."""
pass pass
@ -36,7 +36,7 @@ class GeneralStateTest(InvenTreeTestCase):
self.assertEqual(GeneralStatus.COMPLETE, 30) self.assertEqual(GeneralStatus.COMPLETE, 30)
def test_code_functions(self): def test_code_functions(self):
"""Test that the status code class functions work correctly""" """Test that the status code class functions work correctly."""
# render # render
self.assertEqual( self.assertEqual(
GeneralStatus.render(10), GeneralStatus.render(10),

View File

@ -76,7 +76,6 @@ class StateTransitionMixin:
instance: Object instance instance: Object instance
default_action: Default action to be taken if none of the transitions returns a boolean true value default_action: Default action to be taken if none of the transitions returns a boolean true value
""" """
# Check if there is a custom override function for this transition # Check if there is a custom override function for this transition
for override in storage.list: for override in storage.list:
rslt = override.transition( rslt = override.transition(

View File

@ -1,4 +1,4 @@
"""Gunicorn configuration script for InvenTree web server""" """Gunicorn configuration script for InvenTree web server."""
import multiprocessing import multiprocessing

View File

@ -1,4 +1,4 @@
"""Admin functionality for the 'label' app""" """Admin functionality for the 'label' app."""
from django.contrib import admin from django.contrib import admin
@ -6,7 +6,7 @@ import label.models
class LabelAdmin(admin.ModelAdmin): class LabelAdmin(admin.ModelAdmin):
"""Admin class for the various label models""" """Admin class for the various label models."""
list_display = ('name', 'description', 'label', 'filters', 'enabled') list_display = ('name', 'description', 'label', 'filters', 'enabled')

View File

@ -1,4 +1,4 @@
"""API functionality for the 'label' app""" """API functionality for the 'label' app."""
from django.core.exceptions import FieldError, ValidationError from django.core.exceptions import FieldError, ValidationError
from django.http import JsonResponse from django.http import JsonResponse
@ -42,7 +42,7 @@ class LabelFilterMixin:
ITEM_KEY = 'item' ITEM_KEY = 'item'
def get_items(self): def get_items(self):
"""Return a list of database objects from query parameter""" """Return a list of database objects from query parameter."""
ids = [] ids = []
# Construct a list of possible query parameter value options # Construct a list of possible query parameter value options
@ -140,7 +140,7 @@ class LabelPrintMixin(LabelFilterMixin):
@method_decorator(never_cache) @method_decorator(never_cache)
def dispatch(self, *args, **kwargs): def dispatch(self, *args, **kwargs):
"""Prevent caching when printing report templates""" """Prevent caching when printing report templates."""
return super().dispatch(*args, **kwargs) return super().dispatch(*args, **kwargs)
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
@ -160,7 +160,7 @@ class LabelPrintMixin(LabelFilterMixin):
return serializer return serializer
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Perform a GET request against this endpoint to print labels""" """Perform a GET request against this endpoint to print labels."""
common.models.InvenTreeUserSetting.set_setting( common.models.InvenTreeUserSetting.set_setting(
'DEFAULT_' + self.ITEM_KEY.upper() + '_LABEL_TEMPLATE', 'DEFAULT_' + self.ITEM_KEY.upper() + '_LABEL_TEMPLATE',
self.get_object().pk, self.get_object().pk,
@ -170,7 +170,7 @@ class LabelPrintMixin(LabelFilterMixin):
return self.print(request, self.get_items()) return self.print(request, self.get_items())
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
"""Perform a GET request against this endpoint to print labels""" """Perform a GET request against this endpoint to print labels."""
return self.get(request, *args, **kwargs) return self.get(request, *args, **kwargs)
def get_plugin(self, request): def get_plugin(self, request):
@ -245,7 +245,7 @@ class LabelPrintMixin(LabelFilterMixin):
class StockItemLabelMixin: class StockItemLabelMixin:
"""Mixin for StockItemLabel endpoints""" """Mixin for StockItemLabel endpoints."""
queryset = label.models.StockItemLabel.objects.all() queryset = label.models.StockItemLabel.objects.all()
serializer_class = label.serializers.StockItemLabelSerializer serializer_class = label.serializers.StockItemLabelSerializer
@ -280,7 +280,7 @@ class StockItemLabelPrint(StockItemLabelMixin, LabelPrintMixin, RetrieveAPI):
class StockLocationLabelMixin: class StockLocationLabelMixin:
"""Mixin for StockLocationLabel endpoints""" """Mixin for StockLocationLabel endpoints."""
queryset = label.models.StockLocationLabel.objects.all() queryset = label.models.StockLocationLabel.objects.all()
serializer_class = label.serializers.StockLocationLabelSerializer serializer_class = label.serializers.StockLocationLabelSerializer
@ -315,7 +315,7 @@ class StockLocationLabelPrint(StockLocationLabelMixin, LabelPrintMixin, Retrieve
class PartLabelMixin: class PartLabelMixin:
"""Mixin for PartLabel endpoints""" """Mixin for PartLabel endpoints."""
queryset = label.models.PartLabel.objects.all() queryset = label.models.PartLabel.objects.all()
serializer_class = label.serializers.PartLabelSerializer serializer_class = label.serializers.PartLabelSerializer
@ -343,7 +343,7 @@ class PartLabelPrint(PartLabelMixin, LabelPrintMixin, RetrieveAPI):
class BuildLineLabelMixin: class BuildLineLabelMixin:
"""Mixin class for BuildLineLabel endpoints""" """Mixin class for BuildLineLabel endpoints."""
queryset = label.models.BuildLineLabel.objects.all() queryset = label.models.BuildLineLabel.objects.all()
serializer_class = label.serializers.BuildLineLabelSerializer serializer_class = label.serializers.BuildLineLabelSerializer
@ -353,19 +353,19 @@ class BuildLineLabelMixin:
class BuildLineLabelList(BuildLineLabelMixin, LabelListView): class BuildLineLabelList(BuildLineLabelMixin, LabelListView):
"""API endpoint for viewing a list of BuildLineLabel objects""" """API endpoint for viewing a list of BuildLineLabel objects."""
pass pass
class BuildLineLabelDetail(BuildLineLabelMixin, RetrieveUpdateDestroyAPI): class BuildLineLabelDetail(BuildLineLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single BuildLineLabel object""" """API endpoint for a single BuildLineLabel object."""
pass pass
class BuildLineLabelPrint(BuildLineLabelMixin, LabelPrintMixin, RetrieveAPI): class BuildLineLabelPrint(BuildLineLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a BuildLineLabel object""" """API endpoint for printing a BuildLineLabel object."""
pass pass

View File

@ -1,4 +1,4 @@
"""label app specification""" """label app specification."""
import hashlib import hashlib
import logging import logging
@ -29,7 +29,7 @@ def hashFile(filename):
class LabelConfig(AppConfig): class LabelConfig(AppConfig):
"""App configuration class for the 'label' app""" """App configuration class for the 'label' app."""
name = 'label' name = 'label'

View File

@ -48,28 +48,28 @@ def rename_label_output(instance, filename):
def validate_stock_item_filters(filters): def validate_stock_item_filters(filters):
"""Validate query filters for the StockItemLabel model""" """Validate query filters for the StockItemLabel model."""
filters = validateFilterString(filters, model=stock.models.StockItem) filters = validateFilterString(filters, model=stock.models.StockItem)
return filters return filters
def validate_stock_location_filters(filters): def validate_stock_location_filters(filters):
"""Validate query filters for the StockLocationLabel model""" """Validate query filters for the StockLocationLabel model."""
filters = validateFilterString(filters, model=stock.models.StockLocation) filters = validateFilterString(filters, model=stock.models.StockLocation)
return filters return filters
def validate_part_filters(filters): def validate_part_filters(filters):
"""Validate query filters for the PartLabel model""" """Validate query filters for the PartLabel model."""
filters = validateFilterString(filters, model=part.models.Part) filters = validateFilterString(filters, model=part.models.Part)
return filters return filters
def validate_build_line_filters(filters): def validate_build_line_filters(filters):
"""Validate query filters for the BuildLine model""" """Validate query filters for the BuildLine model."""
filters = validateFilterString(filters, model=build.models.BuildLine) filters = validateFilterString(filters, model=build.models.BuildLine)
return filters return filters
@ -82,7 +82,7 @@ class WeasyprintLabelMixin(WeasyTemplateResponseMixin):
pdf_attachment = True pdf_attachment = True
def __init__(self, request, template, **kwargs): def __init__(self, request, template, **kwargs):
"""Initialize a label mixin with certain properties""" """Initialize a label mixin with certain properties."""
self.request = request self.request = request
self.template_name = template self.template_name = template
self.pdf_filename = kwargs.get('filename', 'label.pdf') self.pdf_filename = kwargs.get('filename', 'label.pdf')
@ -104,11 +104,11 @@ class LabelTemplate(MetadataMixin, models.Model):
@property @property
def template(self): def template(self):
"""Return the file path of the template associated with this label instance""" """Return the file path of the template associated with this label instance."""
return self.label.path return self.label.path
def __str__(self): def __str__(self):
"""Format a string representation of a label instance""" """Format a string representation of a label instance."""
return f'{self.name} - {self.description}' return f'{self.name} - {self.description}'
name = models.CharField( name = models.CharField(
@ -196,7 +196,6 @@ class LabelTemplate(MetadataMixin, models.Model):
This is inserted at the top of the style block for a given label This is inserted at the top of the style block for a given label
""" """
width = kwargs.get('width', self.width) width = kwargs.get('width', self.width)
height = kwargs.get('height', self.height) height = kwargs.get('height', self.height)
margin = kwargs.get('margin', 0) margin = kwargs.get('margin', 0)
@ -215,7 +214,6 @@ class LabelTemplate(MetadataMixin, models.Model):
request: The HTTP request object request: The HTTP request object
kwargs: Additional keyword arguments kwargs: Additional keyword arguments
""" """
context = self.get_context_data(request) context = self.get_context_data(request)
# By default, each label is supplied with '@page' data # By default, each label is supplied with '@page' data
@ -242,8 +240,7 @@ class LabelTemplate(MetadataMixin, models.Model):
return context return context
def render_as_string(self, request, target_object=None, **kwargs): def render_as_string(self, request, target_object=None, **kwargs):
"""Render the label to a HTML string""" """Render the label to a HTML string."""
if target_object: if target_object:
self.object_to_print = target_object self.object_to_print = target_object
@ -256,7 +253,6 @@ class LabelTemplate(MetadataMixin, models.Model):
Uses django-weasyprint plugin to render HTML template Uses django-weasyprint plugin to render HTML template
""" """
if target_object: if target_object:
self.object_to_print = target_object self.object_to_print = target_object
@ -275,7 +271,7 @@ class LabelTemplate(MetadataMixin, models.Model):
class LabelOutput(models.Model): class LabelOutput(models.Model):
"""Class representing a label output file """Class representing a label output file.
'Printing' a label may generate a file object (such as PDF) 'Printing' a label may generate a file object (such as PDF)
which is made available for download. which is made available for download.
@ -301,7 +297,7 @@ class StockItemLabel(LabelTemplate):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the StockItemLabel model""" """Return the API URL associated with the StockItemLabel model."""
return reverse('api-stockitem-label-list') # pragma: no cover return reverse('api-stockitem-label-list') # pragma: no cover
SUBDIR = 'stockitem' SUBDIR = 'stockitem'
@ -340,7 +336,7 @@ class StockLocationLabel(LabelTemplate):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the StockLocationLabel model""" """Return the API URL associated with the StockLocationLabel model."""
return reverse('api-stocklocation-label-list') # pragma: no cover return reverse('api-stocklocation-label-list') # pragma: no cover
SUBDIR = 'stocklocation' SUBDIR = 'stocklocation'
@ -365,7 +361,7 @@ class PartLabel(LabelTemplate):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API url associated with the PartLabel model""" """Return the API url associated with the PartLabel model."""
return reverse('api-part-label-list') # pragma: no cover return reverse('api-part-label-list') # pragma: no cover
SUBDIR = 'part' SUBDIR = 'part'
@ -396,11 +392,11 @@ class PartLabel(LabelTemplate):
class BuildLineLabel(LabelTemplate): class BuildLineLabel(LabelTemplate):
"""Template for printing labels against BuildLine objects""" """Template for printing labels against BuildLine objects."""
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the BuildLineLabel model""" """Return the API URL associated with the BuildLineLabel model."""
return reverse('api-buildline-label-list') return reverse('api-buildline-label-list')
SUBDIR = 'buildline' SUBDIR = 'buildline'

View File

@ -1,4 +1,4 @@
"""API serializers for the label app""" """API serializers for the label app."""
import label.models import label.models
from InvenTree.serializers import ( from InvenTree.serializers import (
@ -8,13 +8,13 @@ from InvenTree.serializers import (
class LabelSerializerBase(InvenTreeModelSerializer): class LabelSerializerBase(InvenTreeModelSerializer):
"""Base class for label serializer""" """Base class for label serializer."""
label = InvenTreeAttachmentSerializerField(required=True) label = InvenTreeAttachmentSerializerField(required=True)
@staticmethod @staticmethod
def label_fields(): def label_fields():
"""Generic serializer fields for a label template""" """Generic serializer fields for a label template."""
return ['pk', 'name', 'description', 'label', 'filters', 'enabled'] return ['pk', 'name', 'description', 'label', 'filters', 'enabled']
@ -49,7 +49,7 @@ class PartLabelSerializer(LabelSerializerBase):
class BuildLineLabelSerializer(LabelSerializerBase): class BuildLineLabelSerializer(LabelSerializerBase):
"""Serializes a BuildLineLabel object""" """Serializes a BuildLineLabel object."""
class Meta: class Meta:
"""Metaclass options.""" """Metaclass options."""

View File

@ -1,4 +1,4 @@
"""Background tasks for the label app""" """Background tasks for the label app."""
from datetime import timedelta from datetime import timedelta
@ -10,6 +10,6 @@ from label.models import LabelOutput
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def cleanup_old_label_outputs(): def cleanup_old_label_outputs():
"""Remove old label outputs from the database""" """Remove old label outputs from the database."""
# Remove any label outputs which are older than 30 days # Remove any label outputs which are older than 30 days
LabelOutput.objects.filter(created__lte=timezone.now() - timedelta(days=5)).delete() LabelOutput.objects.filter(created__lte=timezone.now() - timedelta(days=5)).delete()

View File

@ -1,4 +1,4 @@
"""Unit tests for label API""" """Unit tests for label API."""
from django.urls import reverse from django.urls import reverse
@ -15,7 +15,7 @@ class TestReportTests(InvenTreeAPITestCase):
list_url = reverse('api-stockitem-testreport-list') list_url = reverse('api-stockitem-testreport-list')
def do_list(self, filters=None): def do_list(self, filters=None):
"""Helper function to request list of labels with provided filters""" """Helper function to request list of labels with provided filters."""
# Set default - see B006 # Set default - see B006
if filters is None: if filters is None:
filters = {} filters = {}
@ -27,7 +27,7 @@ class TestReportTests(InvenTreeAPITestCase):
return response.data return response.data
def test_list(self): def test_list(self):
"""Test the API list endpoint""" """Test the API list endpoint."""
response = self.do_list() response = self.do_list()
# TODO - Add some report templates to the fixtures # TODO - Add some report templates to the fixtures

View File

@ -1,4 +1,4 @@
"""Tests for labels""" """Tests for labels."""
import io import io
import json import json
@ -22,13 +22,13 @@ from .models import PartLabel, StockItemLabel, StockLocationLabel
class LabelTest(InvenTreeAPITestCase): class LabelTest(InvenTreeAPITestCase):
"""Unit test class for label models""" """Unit test class for label models."""
fixtures = ['category', 'part', 'location', 'stock'] fixtures = ['category', 'part', 'location', 'stock']
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Ensure that some label instances exist as part of init routine""" """Ensure that some label instances exist as part of init routine."""
super().setUpTestData() super().setUpTestData()
apps.get_app_config('label').create_labels() apps.get_app_config('label').create_labels()

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
"""InvenTree / django management commands""" """InvenTree / django management commands."""
import os import os
import sys import sys

View File

@ -1,4 +1,4 @@
"""Admin functionality for the 'order' app""" """Admin functionality for the 'order' app."""
from django.contrib import admin from django.contrib import admin
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -12,36 +12,36 @@ from order import models
class ProjectCodeResourceMixin: class ProjectCodeResourceMixin:
"""Mixin for exporting project code data""" """Mixin for exporting project code data."""
project_code = Field(attribute='project_code', column_name=_('Project Code')) project_code = Field(attribute='project_code', column_name=_('Project Code'))
def dehydrate_project_code(self, order): def dehydrate_project_code(self, order):
"""Return the project code value, not the pk""" """Return the project code value, not the pk."""
if order.project_code: if order.project_code:
return order.project_code.code return order.project_code.code
return '' return ''
class TotalPriceResourceMixin: class TotalPriceResourceMixin:
"""Mixin for exporting total price data""" """Mixin for exporting total price data."""
total_price = Field(attribute='total_price', column_name=_('Total Price')) total_price = Field(attribute='total_price', column_name=_('Total Price'))
def dehydrate_total_price(self, order): def dehydrate_total_price(self, order):
"""Return the total price amount, not the object itself""" """Return the total price amount, not the object itself."""
if order.total_price: if order.total_price:
return order.total_price.amount return order.total_price.amount
return '' return ''
class PriceResourceMixin: class PriceResourceMixin:
"""Mixin for 'price' field""" """Mixin for 'price' field."""
price = Field(attribute='price', column_name=_('Price')) price = Field(attribute='price', column_name=_('Price'))
def dehydrate_price(self, line): def dehydrate_price(self, line):
"""Return the price amount, not the object itself""" """Return the price amount, not the object itself."""
if line.price: if line.price:
return line.price.amount return line.price.amount
return '' return ''
@ -49,7 +49,7 @@ class PriceResourceMixin:
# region general classes # region general classes
class GeneralExtraLineAdmin: class GeneralExtraLineAdmin:
"""Admin class template for the 'ExtraLineItem' models""" """Admin class template for the 'ExtraLineItem' models."""
list_display = ('order', 'quantity', 'reference') list_display = ('order', 'quantity', 'reference')
@ -59,7 +59,7 @@ class GeneralExtraLineAdmin:
class GeneralExtraLineMeta: class GeneralExtraLineMeta:
"""Metaclass template for the 'ExtraLineItem' models""" """Metaclass template for the 'ExtraLineItem' models."""
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -70,7 +70,7 @@ class GeneralExtraLineMeta:
class PurchaseOrderLineItemInlineAdmin(admin.StackedInline): class PurchaseOrderLineItemInlineAdmin(admin.StackedInline):
"""Inline admin class for the PurchaseOrderLineItem model""" """Inline admin class for the PurchaseOrderLineItem model."""
model = models.PurchaseOrderLineItem model = models.PurchaseOrderLineItem
extra = 0 extra = 0
@ -82,7 +82,7 @@ class PurchaseOrderResource(
"""Class for managing import / export of PurchaseOrder data.""" """Class for managing import / export of PurchaseOrder data."""
class Meta: class Meta:
"""Metaclass""" """Metaclass options."""
model = models.PurchaseOrder model = models.PurchaseOrder
skip_unchanged = True skip_unchanged = True
@ -101,7 +101,7 @@ class PurchaseOrderResource(
class PurchaseOrderAdmin(ImportExportModelAdmin): class PurchaseOrderAdmin(ImportExportModelAdmin):
"""Admin class for the PurchaseOrder model""" """Admin class for the PurchaseOrder model."""
resource_class = PurchaseOrderResource resource_class = PurchaseOrderResource
@ -122,7 +122,7 @@ class SalesOrderResource(
"""Class for managing import / export of SalesOrder data.""" """Class for managing import / export of SalesOrder data."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = models.SalesOrder model = models.SalesOrder
skip_unchanged = True skip_unchanged = True
@ -141,7 +141,7 @@ class SalesOrderResource(
class SalesOrderAdmin(ImportExportModelAdmin): class SalesOrderAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrder model""" """Admin class for the SalesOrder model."""
resource_class = SalesOrderResource resource_class = SalesOrderResource
@ -158,7 +158,7 @@ class PurchaseOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
"""Class for managing import / export of PurchaseOrderLineItem data.""" """Class for managing import / export of PurchaseOrderLineItem data."""
class Meta: class Meta:
"""Metaclass""" """Metaclass."""
model = models.PurchaseOrderLineItem model = models.PurchaseOrderLineItem
skip_unchanged = True skip_unchanged = True
@ -174,7 +174,7 @@ class PurchaseOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
SKU = Field(attribute='part__SKU', readonly=True) SKU = Field(attribute='part__SKU', readonly=True)
def dehydrate_purchase_price(self, line): def dehydrate_purchase_price(self, line):
"""Return a string value of the 'purchase_price' field, rather than the 'Money' object""" """Return a string value of the 'purchase_price' field, rather than the 'Money' object."""
if line.purchase_price: if line.purchase_price:
return line.purchase_price.amount return line.purchase_price.amount
return '' return ''
@ -193,7 +193,7 @@ class SalesOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
"""Class for managing import / export of SalesOrderLineItem data.""" """Class for managing import / export of SalesOrderLineItem data."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = models.SalesOrderLineItem model = models.SalesOrderLineItem
skip_unchanged = True skip_unchanged = True
@ -228,7 +228,7 @@ class SalesOrderExtraLineResource(PriceResourceMixin, InvenTreeResource):
class PurchaseOrderLineItemAdmin(ImportExportModelAdmin): class PurchaseOrderLineItemAdmin(ImportExportModelAdmin):
"""Admin class for the PurchaseOrderLine model""" """Admin class for the PurchaseOrderLine model."""
resource_class = PurchaseOrderLineItemResource resource_class = PurchaseOrderLineItemResource
@ -240,13 +240,13 @@ class PurchaseOrderLineItemAdmin(ImportExportModelAdmin):
class PurchaseOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): class PurchaseOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the PurchaseOrderExtraLine model""" """Admin class for the PurchaseOrderExtraLine model."""
resource_class = PurchaseOrderExtraLineResource resource_class = PurchaseOrderExtraLineResource
class SalesOrderLineItemAdmin(ImportExportModelAdmin): class SalesOrderLineItemAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderLine model""" """Admin class for the SalesOrderLine model."""
resource_class = SalesOrderLineItemResource resource_class = SalesOrderLineItemResource
@ -263,13 +263,13 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin):
class SalesOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): class SalesOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the SalesOrderExtraLine model""" """Admin class for the SalesOrderExtraLine model."""
resource_class = SalesOrderExtraLineResource resource_class = SalesOrderExtraLineResource
class SalesOrderShipmentAdmin(ImportExportModelAdmin): class SalesOrderShipmentAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderShipment model""" """Admin class for the SalesOrderShipment model."""
list_display = ['order', 'shipment_date', 'reference'] list_display = ['order', 'shipment_date', 'reference']
@ -279,7 +279,7 @@ class SalesOrderShipmentAdmin(ImportExportModelAdmin):
class SalesOrderAllocationAdmin(ImportExportModelAdmin): class SalesOrderAllocationAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderAllocation model""" """Admin class for the SalesOrderAllocation model."""
list_display = ('line', 'item', 'quantity') list_display = ('line', 'item', 'quantity')
@ -289,10 +289,10 @@ class SalesOrderAllocationAdmin(ImportExportModelAdmin):
class ReturnOrderResource( class ReturnOrderResource(
ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource
): ):
"""Class for managing import / export of ReturnOrder data""" """Class for managing import / export of ReturnOrder data."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = models.ReturnOrder model = models.ReturnOrder
skip_unchanged = True skip_unchanged = True
@ -301,7 +301,7 @@ class ReturnOrderResource(
class ReturnOrderAdmin(ImportExportModelAdmin): class ReturnOrderAdmin(ImportExportModelAdmin):
"""Admin class for the ReturnOrder model""" """Admin class for the ReturnOrder model."""
resource_class = ReturnOrderResource resource_class = ReturnOrderResource
@ -315,10 +315,10 @@ class ReturnOrderAdmin(ImportExportModelAdmin):
class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource): class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
"""Class for managing import / export of ReturnOrderLineItem data""" """Class for managing import / export of ReturnOrderLineItem data."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = models.ReturnOrderLineItem model = models.ReturnOrderLineItem
skip_unchanged = True skip_unchanged = True
@ -327,7 +327,7 @@ class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
class ReturnOrderLineItemAdmin(ImportExportModelAdmin): class ReturnOrderLineItemAdmin(ImportExportModelAdmin):
"""Admin class for ReturnOrderLine model""" """Admin class for ReturnOrderLine model."""
resource_class = ReturnOrderLineItemResource resource_class = ReturnOrderLineItemResource
@ -335,16 +335,16 @@ class ReturnOrderLineItemAdmin(ImportExportModelAdmin):
class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource): class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource):
"""Class for managing import/export of ReturnOrderExtraLine data""" """Class for managing import/export of ReturnOrderExtraLine data."""
class Meta(GeneralExtraLineMeta): class Meta(GeneralExtraLineMeta):
"""Metaclass options""" """Metaclass options."""
model = models.ReturnOrderExtraLine model = models.ReturnOrderExtraLine
class ReturnOrdeerExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): class ReturnOrdeerExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the ReturnOrderExtraLine model""" """Admin class for the ReturnOrderExtraLine model."""
resource_class = ReturnOrderExtraLineClass resource_class = ReturnOrderExtraLineClass

View File

@ -53,7 +53,7 @@ class GeneralExtraLineList(APIDownloadMixin):
"""General template for ExtraLine API classes.""" """General template for ExtraLine API classes."""
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return the serializer instance for this endpoint""" """Return the serializer instance for this endpoint."""
try: try:
params = self.request.query_params3 params = self.request.query_params3
@ -66,7 +66,7 @@ class GeneralExtraLineList(APIDownloadMixin):
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return the annotated queryset for this endpoint""" """Return the annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('order') queryset = queryset.prefetch_related('order')
@ -89,7 +89,7 @@ class OrderFilter(rest_filters.FilterSet):
status = rest_filters.NumberFilter(label='Order Status', method='filter_status') status = rest_filters.NumberFilter(label='Order Status', method='filter_status')
def filter_status(self, queryset, name, value): def filter_status(self, queryset, name, value):
"""Filter by integer status code""" """Filter by integer status code."""
return queryset.filter(status=value) return queryset.filter(status=value)
# Exact match for reference # Exact match for reference
@ -126,7 +126,7 @@ class OrderFilter(rest_filters.FilterSet):
) )
def filter_outstanding(self, queryset, name, value): def filter_outstanding(self, queryset, name, value):
"""Generic filter for determining if an order is 'outstanding'""" """Generic filter for determining if an order is 'outstanding'."""
if str2bool(value): if str2bool(value):
return queryset.filter(status__in=self.Meta.model.get_status_class().OPEN) return queryset.filter(status__in=self.Meta.model.get_status_class().OPEN)
return queryset.exclude(status__in=self.Meta.model.get_status_class().OPEN) return queryset.exclude(status__in=self.Meta.model.get_status_class().OPEN)
@ -140,14 +140,14 @@ class OrderFilter(rest_filters.FilterSet):
) )
def filter_has_project_code(self, queryset, name, value): def filter_has_project_code(self, queryset, name, value):
"""Filter by whether or not the order has a project code""" """Filter by whether or not the order has a project code."""
if str2bool(value): if str2bool(value):
return queryset.exclude(project_code=None) return queryset.exclude(project_code=None)
return queryset.filter(project_code=None) return queryset.filter(project_code=None)
class LineItemFilter(rest_filters.FilterSet): class LineItemFilter(rest_filters.FilterSet):
"""Base class for custom API filters for order line item list(s)""" """Base class for custom API filters for order line item list(s)."""
# Filter by order status # Filter by order status
order_status = rest_filters.NumberFilter( order_status = rest_filters.NumberFilter(
@ -159,7 +159,7 @@ class LineItemFilter(rest_filters.FilterSet):
) )
def filter_has_pricing(self, queryset, name, value): def filter_has_pricing(self, queryset, name, value):
"""Filter by whether or not the line item has pricing information""" """Filter by whether or not the line item has pricing information."""
filters = {self.Meta.price_field: None} filters = {self.Meta.price_field: None}
if str2bool(value): if str2bool(value):
@ -178,13 +178,13 @@ class PurchaseOrderFilter(OrderFilter):
class PurchaseOrderMixin: class PurchaseOrderMixin:
"""Mixin class for PurchaseOrder endpoints""" """Mixin class for PurchaseOrder endpoints."""
queryset = models.PurchaseOrder.objects.all() queryset = models.PurchaseOrder.objects.all()
serializer_class = serializers.PurchaseOrderSerializer serializer_class = serializers.PurchaseOrderSerializer
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return the serializer instance for this endpoint""" """Return the serializer instance for this endpoint."""
try: try:
kwargs['supplier_detail'] = str2bool( kwargs['supplier_detail'] = str2bool(
self.request.query_params.get('supplier_detail', False) self.request.query_params.get('supplier_detail', False)
@ -198,7 +198,7 @@ class PurchaseOrderMixin:
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return the annotated queryset for this endpoint""" """Return the annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('supplier', 'lines') queryset = queryset.prefetch_related('supplier', 'lines')
@ -266,7 +266,7 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI):
) )
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download the filtered queryset as a file""" """Download the filtered queryset as a file."""
dataset = PurchaseOrderResource().export(queryset=queryset) dataset = PurchaseOrderResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
@ -276,7 +276,7 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI):
return DownloadFile(filedata, filename) return DownloadFile(filedata, filename)
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
"""Custom queryset filtering""" """Custom queryset filtering."""
# Perform basic filtering # Perform basic filtering
queryset = super().filter_queryset(queryset) queryset = super().filter_queryset(queryset)
@ -427,7 +427,7 @@ class PurchaseOrderLineItemFilter(LineItemFilter):
pending = rest_filters.BooleanFilter(label='pending', method='filter_pending') pending = rest_filters.BooleanFilter(label='pending', method='filter_pending')
def filter_pending(self, queryset, name, value): def filter_pending(self, queryset, name, value):
"""Filter by "pending" status (order status = pending)""" """Filter by "pending" status (order status = pending)."""
if str2bool(value): if str2bool(value):
return queryset.filter(order__status__in=PurchaseOrderStatusGroups.OPEN) return queryset.filter(order__status__in=PurchaseOrderStatusGroups.OPEN)
return queryset.exclude(order__status__in=PurchaseOrderStatusGroups.OPEN) return queryset.exclude(order__status__in=PurchaseOrderStatusGroups.OPEN)
@ -435,7 +435,7 @@ class PurchaseOrderLineItemFilter(LineItemFilter):
received = rest_filters.BooleanFilter(label='received', method='filter_received') received = rest_filters.BooleanFilter(label='received', method='filter_received')
def filter_received(self, queryset, name, value): def filter_received(self, queryset, name, value):
"""Filter by lines which are "received" (or "not" received) """Filter by lines which are "received" (or "not" received).
A line is considered "received" when received >= quantity A line is considered "received" when received >= quantity
""" """
@ -450,13 +450,13 @@ class PurchaseOrderLineItemFilter(LineItemFilter):
class PurchaseOrderLineItemMixin: class PurchaseOrderLineItemMixin:
"""Mixin class for PurchaseOrderLineItem endpoints""" """Mixin class for PurchaseOrderLineItem endpoints."""
queryset = models.PurchaseOrderLineItem.objects.all() queryset = models.PurchaseOrderLineItem.objects.all()
serializer_class = serializers.PurchaseOrderLineItemSerializer serializer_class = serializers.PurchaseOrderLineItemSerializer
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint""" """Return annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = serializers.PurchaseOrderLineItemSerializer.annotate_queryset( queryset = serializers.PurchaseOrderLineItemSerializer.annotate_queryset(
@ -466,7 +466,7 @@ class PurchaseOrderLineItemMixin:
return queryset return queryset
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint""" """Return serializer instance for this endpoint."""
try: try:
kwargs['part_detail'] = str2bool( kwargs['part_detail'] = str2bool(
self.request.query_params.get('part_detail', False) self.request.query_params.get('part_detail', False)
@ -513,7 +513,7 @@ class PurchaseOrderLineItemList(
return queryset return queryset
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download the requested queryset as a file""" """Download the requested queryset as a file."""
dataset = PurchaseOrderLineItemResource().export(queryset=queryset) dataset = PurchaseOrderLineItemResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
@ -564,7 +564,7 @@ class PurchaseOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
serializer_class = serializers.PurchaseOrderExtraLineSerializer serializer_class = serializers.PurchaseOrderExtraLineSerializer
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download this queryset as a file""" """Download this queryset as a file."""
dataset = PurchaseOrderExtraLineResource().export(queryset=queryset) dataset = PurchaseOrderExtraLineResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
filename = f'InvenTree_ExtraPurchaseOrderLines.{export_format}' filename = f'InvenTree_ExtraPurchaseOrderLines.{export_format}'
@ -580,7 +580,7 @@ class PurchaseOrderExtraLineDetail(RetrieveUpdateDestroyAPI):
class SalesOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): class SalesOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
"""API endpoint for listing (and creating) a SalesOrderAttachment (file upload)""" """API endpoint for listing (and creating) a SalesOrderAttachment (file upload)."""
queryset = models.SalesOrderAttachment.objects.all() queryset = models.SalesOrderAttachment.objects.all()
serializer_class = serializers.SalesOrderAttachmentSerializer serializer_class = serializers.SalesOrderAttachmentSerializer
@ -606,13 +606,13 @@ class SalesOrderFilter(OrderFilter):
class SalesOrderMixin: class SalesOrderMixin:
"""Mixin class for SalesOrder endpoints""" """Mixin class for SalesOrder endpoints."""
queryset = models.SalesOrder.objects.all() queryset = models.SalesOrder.objects.all()
serializer_class = serializers.SalesOrderSerializer serializer_class = serializers.SalesOrderSerializer
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint""" """Return serializer instance for this endpoint."""
try: try:
kwargs['customer_detail'] = str2bool( kwargs['customer_detail'] = str2bool(
self.request.query_params.get('customer_detail', False) self.request.query_params.get('customer_detail', False)
@ -626,7 +626,7 @@ class SalesOrderMixin:
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint""" """Return annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('customer', 'lines') queryset = queryset.prefetch_related('customer', 'lines')
@ -660,7 +660,7 @@ class SalesOrderList(SalesOrderMixin, APIDownloadMixin, ListCreateAPI):
) )
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download this queryset as a file""" """Download this queryset as a file."""
dataset = SalesOrderResource().export(queryset=queryset) dataset = SalesOrderResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
@ -759,13 +759,13 @@ class SalesOrderLineItemFilter(LineItemFilter):
class SalesOrderLineItemMixin: class SalesOrderLineItemMixin:
"""Mixin class for SalesOrderLineItem endpoints""" """Mixin class for SalesOrderLineItem endpoints."""
queryset = models.SalesOrderLineItem.objects.all() queryset = models.SalesOrderLineItem.objects.all()
serializer_class = serializers.SalesOrderLineItemSerializer serializer_class = serializers.SalesOrderLineItemSerializer
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer for this endpoint with extra data as requested""" """Return serializer for this endpoint with extra data as requested."""
try: try:
params = self.request.query_params params = self.request.query_params
@ -782,7 +782,7 @@ class SalesOrderLineItemMixin:
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint""" """Return annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related( queryset = queryset.prefetch_related(
@ -805,7 +805,7 @@ class SalesOrderLineItemList(SalesOrderLineItemMixin, APIDownloadMixin, ListCrea
filterset_class = SalesOrderLineItemFilter filterset_class = SalesOrderLineItemFilter
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download the requested queryset as a file""" """Download the requested queryset as a file."""
dataset = SalesOrderLineItemResource().export(queryset=queryset) dataset = SalesOrderLineItemResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
@ -833,7 +833,7 @@ class SalesOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
serializer_class = serializers.SalesOrderExtraLineSerializer serializer_class = serializers.SalesOrderExtraLineSerializer
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download this queryset as a file""" """Download this queryset as a file."""
dataset = SalesOrderExtraLineResource().export(queryset=queryset) dataset = SalesOrderExtraLineResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
filename = f'InvenTree_ExtraSalesOrderLines.{export_format}' filename = f'InvenTree_ExtraSalesOrderLines.{export_format}'
@ -854,7 +854,7 @@ class SalesOrderContextMixin:
queryset = models.SalesOrder.objects.all() queryset = models.SalesOrder.objects.all()
def get_serializer_context(self): def get_serializer_context(self):
"""Add the 'order' reference to the serializer context for any classes which inherit this mixin""" """Add the 'order' reference to the serializer context for any classes which inherit this mixin."""
ctx = super().get_serializer_context() ctx = super().get_serializer_context()
ctx['request'] = self.request ctx['request'] = self.request
@ -868,13 +868,13 @@ class SalesOrderContextMixin:
class SalesOrderCancel(SalesOrderContextMixin, CreateAPI): class SalesOrderCancel(SalesOrderContextMixin, CreateAPI):
"""API endpoint to cancel a SalesOrder""" """API endpoint to cancel a SalesOrder."""
serializer_class = serializers.SalesOrderCancelSerializer serializer_class = serializers.SalesOrderCancelSerializer
class SalesOrderIssue(SalesOrderContextMixin, CreateAPI): class SalesOrderIssue(SalesOrderContextMixin, CreateAPI):
"""API endpoint to issue a SalesOrder""" """API endpoint to issue a SalesOrder."""
serializer_class = serializers.SalesOrderIssueSerializer serializer_class = serializers.SalesOrderIssueSerializer
@ -935,7 +935,7 @@ class SalesOrderAllocationList(ListAPI):
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
"""Custom queryset filtering""" """Custom queryset filtering."""
queryset = super().filter_queryset(queryset) queryset = super().filter_queryset(queryset)
# Filter by order # Filter by order
@ -995,7 +995,7 @@ class SalesOrderShipmentFilter(rest_filters.FilterSet):
shipped = rest_filters.BooleanFilter(label='shipped', method='filter_shipped') shipped = rest_filters.BooleanFilter(label='shipped', method='filter_shipped')
def filter_shipped(self, queryset, name, value): def filter_shipped(self, queryset, name, value):
"""Filter SalesOrder list by 'shipped' status (boolean)""" """Filter SalesOrder list by 'shipped' status (boolean)."""
if str2bool(value): if str2bool(value):
return queryset.exclude(shipment_date=None) return queryset.exclude(shipment_date=None)
return queryset.filter(shipment_date=None) return queryset.filter(shipment_date=None)
@ -1003,7 +1003,7 @@ class SalesOrderShipmentFilter(rest_filters.FilterSet):
delivered = rest_filters.BooleanFilter(label='delivered', method='filter_delivered') delivered = rest_filters.BooleanFilter(label='delivered', method='filter_delivered')
def filter_delivered(self, queryset, name, value): def filter_delivered(self, queryset, name, value):
"""Filter SalesOrder list by 'delivered' status (boolean)""" """Filter SalesOrder list by 'delivered' status (boolean)."""
if str2bool(value): if str2bool(value):
return queryset.exclude(delivery_date=None) return queryset.exclude(delivery_date=None)
return queryset.filter(delivery_date=None) return queryset.filter(delivery_date=None)
@ -1048,7 +1048,7 @@ class SalesOrderShipmentComplete(CreateAPI):
class PurchaseOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): class PurchaseOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
"""API endpoint for listing (and creating) a PurchaseOrderAttachment (file upload)""" """API endpoint for listing (and creating) a PurchaseOrderAttachment (file upload)."""
queryset = models.PurchaseOrderAttachment.objects.all() queryset = models.PurchaseOrderAttachment.objects.all()
serializer_class = serializers.PurchaseOrderAttachmentSerializer serializer_class = serializers.PurchaseOrderAttachmentSerializer
@ -1064,23 +1064,23 @@ class PurchaseOrderAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI):
class ReturnOrderFilter(OrderFilter): class ReturnOrderFilter(OrderFilter):
"""Custom API filters for the ReturnOrderList endpoint""" """Custom API filters for the ReturnOrderList endpoint."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = models.ReturnOrder model = models.ReturnOrder
fields = ['customer'] fields = ['customer']
class ReturnOrderMixin: class ReturnOrderMixin:
"""Mixin class for ReturnOrder endpoints""" """Mixin class for ReturnOrder endpoints."""
queryset = models.ReturnOrder.objects.all() queryset = models.ReturnOrder.objects.all()
serializer_class = serializers.ReturnOrderSerializer serializer_class = serializers.ReturnOrderSerializer
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer instance for this endpoint""" """Return serializer instance for this endpoint."""
try: try:
kwargs['customer_detail'] = str2bool( kwargs['customer_detail'] = str2bool(
self.request.query_params.get('customer_detail', False) self.request.query_params.get('customer_detail', False)
@ -1094,7 +1094,7 @@ class ReturnOrderMixin:
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint""" """Return annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('customer') queryset = queryset.prefetch_related('customer')
@ -1105,7 +1105,7 @@ class ReturnOrderMixin:
class ReturnOrderList(ReturnOrderMixin, APIDownloadMixin, ListCreateAPI): class ReturnOrderList(ReturnOrderMixin, APIDownloadMixin, ListCreateAPI):
"""API endpoint for accessing a list of ReturnOrder objects""" """API endpoint for accessing a list of ReturnOrder objects."""
filterset_class = ReturnOrderFilter filterset_class = ReturnOrderFilter
@ -1124,7 +1124,7 @@ class ReturnOrderList(ReturnOrderMixin, APIDownloadMixin, ListCreateAPI):
) )
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download this queryset as a file""" """Download this queryset as a file."""
dataset = ReturnOrderResource().export(queryset=queryset) dataset = ReturnOrderResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
filename = f'InvenTree_ReturnOrders.{export_format}' filename = f'InvenTree_ReturnOrders.{export_format}'
@ -1161,13 +1161,13 @@ class ReturnOrderList(ReturnOrderMixin, APIDownloadMixin, ListCreateAPI):
class ReturnOrderDetail(ReturnOrderMixin, RetrieveUpdateDestroyAPI): class ReturnOrderDetail(ReturnOrderMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single ReturnOrder object""" """API endpoint for detail view of a single ReturnOrder object."""
pass pass
class ReturnOrderContextMixin: class ReturnOrderContextMixin:
"""Simple mixin class to add a ReturnOrder to the serializer context""" """Simple mixin class to add a ReturnOrder to the serializer context."""
queryset = models.ReturnOrder.objects.all() queryset = models.ReturnOrder.objects.all()
@ -1189,35 +1189,35 @@ class ReturnOrderContextMixin:
class ReturnOrderCancel(ReturnOrderContextMixin, CreateAPI): class ReturnOrderCancel(ReturnOrderContextMixin, CreateAPI):
"""API endpoint to cancel a ReturnOrder""" """API endpoint to cancel a ReturnOrder."""
serializer_class = serializers.ReturnOrderCancelSerializer serializer_class = serializers.ReturnOrderCancelSerializer
class ReturnOrderComplete(ReturnOrderContextMixin, CreateAPI): class ReturnOrderComplete(ReturnOrderContextMixin, CreateAPI):
"""API endpoint to complete a ReturnOrder""" """API endpoint to complete a ReturnOrder."""
serializer_class = serializers.ReturnOrderCompleteSerializer serializer_class = serializers.ReturnOrderCompleteSerializer
class ReturnOrderIssue(ReturnOrderContextMixin, CreateAPI): class ReturnOrderIssue(ReturnOrderContextMixin, CreateAPI):
"""API endpoint to issue (place) a ReturnOrder""" """API endpoint to issue (place) a ReturnOrder."""
serializer_class = serializers.ReturnOrderIssueSerializer serializer_class = serializers.ReturnOrderIssueSerializer
class ReturnOrderReceive(ReturnOrderContextMixin, CreateAPI): class ReturnOrderReceive(ReturnOrderContextMixin, CreateAPI):
"""API endpoint to receive items against a ReturnOrder""" """API endpoint to receive items against a ReturnOrder."""
queryset = models.ReturnOrder.objects.none() queryset = models.ReturnOrder.objects.none()
serializer_class = serializers.ReturnOrderReceiveSerializer serializer_class = serializers.ReturnOrderReceiveSerializer
class ReturnOrderLineItemFilter(LineItemFilter): class ReturnOrderLineItemFilter(LineItemFilter):
"""Custom filters for the ReturnOrderLineItemList endpoint""" """Custom filters for the ReturnOrderLineItemList endpoint."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
price_field = 'price' price_field = 'price'
model = models.ReturnOrderLineItem model = models.ReturnOrderLineItem
@ -1228,20 +1228,20 @@ class ReturnOrderLineItemFilter(LineItemFilter):
received = rest_filters.BooleanFilter(label='received', method='filter_received') received = rest_filters.BooleanFilter(label='received', method='filter_received')
def filter_received(self, queryset, name, value): def filter_received(self, queryset, name, value):
"""Filter by 'received' field""" """Filter by 'received' field."""
if str2bool(value): if str2bool(value):
return queryset.exclude(received_date=None) return queryset.exclude(received_date=None)
return queryset.filter(received_date=None) return queryset.filter(received_date=None)
class ReturnOrderLineItemMixin: class ReturnOrderLineItemMixin:
"""Mixin class for ReturnOrderLineItem endpoints""" """Mixin class for ReturnOrderLineItem endpoints."""
queryset = models.ReturnOrderLineItem.objects.all() queryset = models.ReturnOrderLineItem.objects.all()
serializer_class = serializers.ReturnOrderLineItemSerializer serializer_class = serializers.ReturnOrderLineItemSerializer
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return serializer for this endpoint with extra data as requested""" """Return serializer for this endpoint with extra data as requested."""
try: try:
params = self.request.query_params params = self.request.query_params
@ -1256,7 +1256,7 @@ class ReturnOrderLineItemMixin:
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return annotated queryset for this endpoint""" """Return annotated queryset for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('order', 'item', 'item__part') queryset = queryset.prefetch_related('order', 'item', 'item__part')
@ -1267,12 +1267,12 @@ class ReturnOrderLineItemMixin:
class ReturnOrderLineItemList( class ReturnOrderLineItemList(
ReturnOrderLineItemMixin, APIDownloadMixin, ListCreateAPI ReturnOrderLineItemMixin, APIDownloadMixin, ListCreateAPI
): ):
"""API endpoint for accessing a list of ReturnOrderLineItemList objects""" """API endpoint for accessing a list of ReturnOrderLineItemList objects."""
filterset_class = ReturnOrderLineItemFilter filterset_class = ReturnOrderLineItemFilter
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download the requested queryset as a file""" """Download the requested queryset as a file."""
raise NotImplementedError( raise NotImplementedError(
'download_queryset not yet implemented for this endpoint' 'download_queryset not yet implemented for this endpoint'
) )
@ -1290,31 +1290,31 @@ class ReturnOrderLineItemList(
class ReturnOrderLineItemDetail(ReturnOrderLineItemMixin, RetrieveUpdateDestroyAPI): class ReturnOrderLineItemDetail(ReturnOrderLineItemMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a ReturnOrderLineItem object""" """API endpoint for detail view of a ReturnOrderLineItem object."""
pass pass
class ReturnOrderExtraLineList(GeneralExtraLineList, ListCreateAPI): class ReturnOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
"""API endpoint for accessing a list of ReturnOrderExtraLine objects""" """API endpoint for accessing a list of ReturnOrderExtraLine objects."""
queryset = models.ReturnOrderExtraLine.objects.all() queryset = models.ReturnOrderExtraLine.objects.all()
serializer_class = serializers.ReturnOrderExtraLineSerializer serializer_class = serializers.ReturnOrderExtraLineSerializer
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download this queryset as a file""" """Download this queryset as a file."""
raise NotImplementedError('download_queryset not yet implemented') raise NotImplementedError('download_queryset not yet implemented')
class ReturnOrderExtraLineDetail(RetrieveUpdateDestroyAPI): class ReturnOrderExtraLineDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a ReturnOrderExtraLine object""" """API endpoint for detail view of a ReturnOrderExtraLine object."""
queryset = models.ReturnOrderExtraLine.objects.all() queryset = models.ReturnOrderExtraLine.objects.all()
serializer_class = serializers.ReturnOrderExtraLineSerializer serializer_class = serializers.ReturnOrderExtraLineSerializer
class ReturnOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): class ReturnOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
"""API endpoint for listing (and creating) a ReturnOrderAttachment (file upload)""" """API endpoint for listing (and creating) a ReturnOrderAttachment (file upload)."""
queryset = models.ReturnOrderAttachment.objects.all() queryset = models.ReturnOrderAttachment.objects.all()
serializer_class = serializers.ReturnOrderAttachmentSerializer serializer_class = serializers.ReturnOrderAttachmentSerializer
@ -1323,14 +1323,14 @@ class ReturnOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
class ReturnOrderAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): class ReturnOrderAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI):
"""Detail endpoint for the ReturnOrderAttachment model""" """Detail endpoint for the ReturnOrderAttachment model."""
queryset = models.ReturnOrderAttachment.objects.all() queryset = models.ReturnOrderAttachment.objects.all()
serializer_class = serializers.ReturnOrderAttachmentSerializer serializer_class = serializers.ReturnOrderAttachmentSerializer
class OrderCalendarExport(ICalFeed): class OrderCalendarExport(ICalFeed):
"""Calendar export for Purchase/Sales Orders """Calendar export for Purchase/Sales Orders.
Optional parameters: Optional parameters:
- include_completed: true/false - include_completed: true/false
@ -1390,7 +1390,7 @@ class OrderCalendarExport(ICalFeed):
return response return response
def get_object(self, request, *args, **kwargs): def get_object(self, request, *args, **kwargs):
"""This is where settings from the URL etc will be obtained""" """This is where settings from the URL etc will be obtained."""
# Help: # Help:
# https://django.readthedocs.io/en/stable/ref/contrib/syndication.html # https://django.readthedocs.io/en/stable/ref/contrib/syndication.html
@ -1457,11 +1457,11 @@ class OrderCalendarExport(ICalFeed):
return outlist return outlist
def item_title(self, item): def item_title(self, item):
"""Set the event title to the order reference""" """Set the event title to the order reference."""
return f'{item.reference}' return f'{item.reference}'
def item_description(self, item): def item_description(self, item):
"""Set the event description""" """Set the event description."""
return f'Company: {item.company.name}\nStatus: {item.get_status_display()}\nDescription: {item.description}' return f'Company: {item.company.name}\nStatus: {item.get_status_display()}\nDescription: {item.description}'
def item_start_datetime(self, item): def item_start_datetime(self, item):
@ -1477,11 +1477,11 @@ class OrderCalendarExport(ICalFeed):
return item.creation_date return item.creation_date
def item_class(self, item): def item_class(self, item):
"""Set item class to PUBLIC""" """Set item class to PUBLIC."""
return 'PUBLIC' return 'PUBLIC'
def item_guid(self, item): def item_guid(self, item):
"""Return globally unique UID for event""" """Return globally unique UID for event."""
return f'po_{item.pk}_{item.reference.replace(" ","-")}@{self.instance_url}' return f'po_{item.pk}_{item.reference.replace(" ","-")}@{self.instance_url}'
def item_link(self, item): def item_link(self, item):

View File

@ -1,9 +1,9 @@
"""Config for the 'order' app""" """Config for the 'order' app."""
from django.apps import AppConfig from django.apps import AppConfig
class OrderConfig(AppConfig): class OrderConfig(AppConfig):
"""Configuration class for the 'order' app""" """Configuration class for the 'order' app."""
name = 'order' name = 'order'

View File

@ -67,7 +67,7 @@ logger = logging.getLogger('inventree')
class TotalPriceMixin(models.Model): class TotalPriceMixin(models.Model):
"""Mixin which provides 'total_price' field for an order""" """Mixin which provides 'total_price' field for an order."""
class Meta: class Meta:
"""Meta for MetadataMixin.""" """Meta for MetadataMixin."""
@ -75,7 +75,7 @@ class TotalPriceMixin(models.Model):
abstract = True abstract = True
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Update the total_price field when saved""" """Update the total_price field when saved."""
# Recalculate total_price for this order # Recalculate total_price for this order
self.update_total_price(commit=False) self.update_total_price(commit=False)
super().save(*args, **kwargs) super().save(*args, **kwargs)
@ -99,8 +99,9 @@ class TotalPriceMixin(models.Model):
@property @property
def currency(self): def currency(self):
"""Return the currency associated with this order instance: """Return the currency associated with this order instance.
Rules:
- If the order_currency field is set, return that - If the order_currency field is set, return that
- Otherwise, return the currency associated with the company - Otherwise, return the currency associated with the company
- Finally, return the default currency code - Finally, return the default currency code
@ -115,7 +116,7 @@ class TotalPriceMixin(models.Model):
return currency_code_default() return currency_code_default()
def update_total_price(self, commit=True): def update_total_price(self, commit=True):
"""Recalculate and save the total_price for this order""" """Recalculate and save the total_price for this order."""
self.total_price = self.calculate_total_price(target_currency=self.currency) self.total_price = self.calculate_total_price(target_currency=self.currency)
if commit: if commit:
@ -205,7 +206,7 @@ class Order(
abstract = True abstract = True
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Custom save method for the order models: """Custom save method for the order models.
Ensures that the reference field is rebuilt whenever the instance is saved. Ensures that the reference field is rebuilt whenever the instance is saved.
""" """
@ -217,7 +218,7 @@ class Order(
super().save(*args, **kwargs) super().save(*args, **kwargs)
def clean(self): def clean(self):
"""Custom clean method for the generic order class""" """Custom clean method for the generic order class."""
super().clean() super().clean()
# Check that the referenced 'contact' matches the correct 'company' # Check that the referenced 'contact' matches the correct 'company'
@ -229,7 +230,7 @@ class Order(
@classmethod @classmethod
def overdue_filter(cls): def overdue_filter(cls):
"""A generic implementation of an 'overdue' filter for the Model class """A generic implementation of an 'overdue' filter for the Model class.
It requires any subclasses to implement the get_status_class() class method It requires any subclasses to implement the get_status_class() class method
""" """
@ -326,7 +327,7 @@ class Order(
@classmethod @classmethod
def get_status_class(cls): def get_status_class(cls):
"""Return the enumeration class which represents the 'status' field for this model""" """Return the enumeration class which represents the 'status' field for this model."""
raise NotImplementedError(f'get_status_class() not implemented for {__class__}') raise NotImplementedError(f'get_status_class() not implemented for {__class__}')
@ -341,22 +342,22 @@ class PurchaseOrder(TotalPriceMixin, Order):
""" """
def get_absolute_url(self): def get_absolute_url(self):
"""Get the 'web' URL for this order""" """Get the 'web' URL for this order."""
return reverse('po-detail', kwargs={'pk': self.pk}) return reverse('po-detail', kwargs={'pk': self.pk})
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the PurchaseOrder model""" """Return the API URL associated with the PurchaseOrder model."""
return reverse('api-po-list') return reverse('api-po-list')
@classmethod @classmethod
def get_status_class(cls): def get_status_class(cls):
"""Return the PurchasOrderStatus class""" """Return the PurchasOrderStatus class."""
return PurchaseOrderStatusGroups return PurchaseOrderStatusGroups
@classmethod @classmethod
def api_defaults(cls, request): def api_defaults(cls, request):
"""Return default values for this model when issuing an API OPTIONS request""" """Return default values for this model when issuing an API OPTIONS request."""
defaults = { defaults = {
'reference': order.validators.generate_next_purchase_order_reference() 'reference': order.validators.generate_next_purchase_order_reference()
} }
@ -411,7 +412,7 @@ class PurchaseOrder(TotalPriceMixin, Order):
return queryset return queryset
def __str__(self): def __str__(self):
"""Render a string representation of this PurchaseOrder""" """Render a string representation of this PurchaseOrder."""
return f"{self.reference} - {self.supplier.name if self.supplier else _('deleted')}" return f"{self.reference} - {self.supplier.name if self.supplier else _('deleted')}"
reference = models.CharField( reference = models.CharField(
@ -432,7 +433,7 @@ class PurchaseOrder(TotalPriceMixin, Order):
@property @property
def status_text(self): def status_text(self):
"""Return the text representation of the status field""" """Return the text representation of the status field."""
return PurchaseOrderStatus.text(self.status) return PurchaseOrderStatus.text(self.status)
supplier = models.ForeignKey( supplier = models.ForeignKey(
@ -447,7 +448,7 @@ class PurchaseOrder(TotalPriceMixin, Order):
@property @property
def company(self): def company(self):
"""Accessor helper for Order base class""" """Accessor helper for Order base class."""
return self.supplier return self.supplier
supplier_reference = models.CharField( supplier_reference = models.CharField(
@ -617,12 +618,12 @@ class PurchaseOrder(TotalPriceMixin, Order):
@property @property
def is_pending(self): def is_pending(self):
"""Return True if the PurchaseOrder is 'pending'""" """Return True if the PurchaseOrder is 'pending'."""
return self.status == PurchaseOrderStatus.PENDING.value return self.status == PurchaseOrderStatus.PENDING.value
@property @property
def is_open(self): def is_open(self):
"""Return True if the PurchaseOrder is 'open'""" """Return True if the PurchaseOrder is 'open'."""
return self.status in PurchaseOrderStatusGroups.OPEN return self.status in PurchaseOrderStatusGroups.OPEN
@property @property
@ -668,17 +669,17 @@ class PurchaseOrder(TotalPriceMixin, Order):
@property @property
def line_count(self): def line_count(self):
"""Return the total number of line items associated with this order""" """Return the total number of line items associated with this order."""
return self.lines.count() return self.lines.count()
@property @property
def completed_line_count(self): def completed_line_count(self):
"""Return the number of complete line items associated with this order""" """Return the number of complete line items associated with this order."""
return self.completed_line_items().count() return self.completed_line_items().count()
@property @property
def pending_line_count(self): def pending_line_count(self):
"""Return the number of pending line items associated with this order""" """Return the number of pending line items associated with this order."""
return self.pending_line_items().count() return self.pending_line_items().count()
@property @property
@ -794,22 +795,22 @@ class SalesOrder(TotalPriceMixin, Order):
"""A SalesOrder represents a list of goods shipped outwards to a customer.""" """A SalesOrder represents a list of goods shipped outwards to a customer."""
def get_absolute_url(self): def get_absolute_url(self):
"""Get the 'web' URL for this order""" """Get the 'web' URL for this order."""
return reverse('so-detail', kwargs={'pk': self.pk}) return reverse('so-detail', kwargs={'pk': self.pk})
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the SalesOrder model""" """Return the API URL associated with the SalesOrder model."""
return reverse('api-so-list') return reverse('api-so-list')
@classmethod @classmethod
def get_status_class(cls): def get_status_class(cls):
"""Return the SalesOrderStatus class""" """Return the SalesOrderStatus class."""
return SalesOrderStatusGroups return SalesOrderStatusGroups
@classmethod @classmethod
def api_defaults(cls, request): def api_defaults(cls, request):
"""Return default values for this model when issuing an API OPTIONS request""" """Return default values for this model when issuing an API OPTIONS request."""
defaults = {'reference': order.validators.generate_next_sales_order_reference()} defaults = {'reference': order.validators.generate_next_sales_order_reference()}
return defaults return defaults
@ -862,7 +863,7 @@ class SalesOrder(TotalPriceMixin, Order):
return queryset return queryset
def __str__(self): def __str__(self):
"""Render a string representation of this SalesOrder""" """Render a string representation of this SalesOrder."""
return f"{self.reference} - {self.customer.name if self.customer else _('deleted')}" return f"{self.reference} - {self.customer.name if self.customer else _('deleted')}"
reference = models.CharField( reference = models.CharField(
@ -887,7 +888,7 @@ class SalesOrder(TotalPriceMixin, Order):
@property @property
def company(self): def company(self):
"""Accessor helper for Order base""" """Accessor helper for Order base."""
return self.customer return self.customer
status = models.PositiveIntegerField( status = models.PositiveIntegerField(
@ -899,7 +900,7 @@ class SalesOrder(TotalPriceMixin, Order):
@property @property
def status_text(self): def status_text(self):
"""Return the text representation of the status field""" """Return the text representation of the status field."""
return SalesOrderStatus.text(self.status) return SalesOrderStatus.text(self.status)
customer_reference = models.CharField( customer_reference = models.CharField(
@ -924,12 +925,12 @@ class SalesOrder(TotalPriceMixin, Order):
@property @property
def is_pending(self): def is_pending(self):
"""Return True if this order is 'pending'""" """Return True if this order is 'pending'."""
return self.status == SalesOrderStatus.PENDING return self.status == SalesOrderStatus.PENDING
@property @property
def is_open(self): def is_open(self):
"""Return True if this order is 'open' (either 'pending' or 'in_progress')""" """Return True if this order is 'open' (either 'pending' or 'in_progress')."""
return self.status in SalesOrderStatusGroups.OPEN return self.status in SalesOrderStatusGroups.OPEN
@property @property
@ -997,11 +998,11 @@ class SalesOrder(TotalPriceMixin, Order):
# region state changes # region state changes
def place_order(self): def place_order(self):
"""Deprecated version of 'issue_order'""" """Deprecated version of 'issue_order'."""
self.issue_order() self.issue_order()
def _action_place(self, *args, **kwargs): def _action_place(self, *args, **kwargs):
"""Change this order from 'PENDING' to 'IN_PROGRESS'""" """Change this order from 'PENDING' to 'IN_PROGRESS'."""
if self.status == SalesOrderStatus.PENDING: if self.status == SalesOrderStatus.PENDING:
self.status = SalesOrderStatus.IN_PROGRESS.value self.status = SalesOrderStatus.IN_PROGRESS.value
self.issue_date = datetime.now().date() self.issue_date = datetime.now().date()
@ -1094,7 +1095,7 @@ class SalesOrder(TotalPriceMixin, Order):
@property @property
def line_count(self): def line_count(self):
"""Return the total number of lines associated with this order""" """Return the total number of lines associated with this order."""
return self.lines.count() return self.lines.count()
def completed_line_items(self): def completed_line_items(self):
@ -1107,12 +1108,12 @@ class SalesOrder(TotalPriceMixin, Order):
@property @property
def completed_line_count(self): def completed_line_count(self):
"""Return the number of completed lines for this order""" """Return the number of completed lines for this order."""
return self.completed_line_items().count() return self.completed_line_items().count()
@property @property
def pending_line_count(self): def pending_line_count(self):
"""Return the number of pending (incomplete) lines associated with this order""" """Return the number of pending (incomplete) lines associated with this order."""
return self.pending_line_items().count() return self.pending_line_items().count()
def completed_shipments(self): def completed_shipments(self):
@ -1125,17 +1126,17 @@ class SalesOrder(TotalPriceMixin, Order):
@property @property
def shipment_count(self): def shipment_count(self):
"""Return the total number of shipments associated with this order""" """Return the total number of shipments associated with this order."""
return self.shipments.count() return self.shipments.count()
@property @property
def completed_shipment_count(self): def completed_shipment_count(self):
"""Return the number of completed shipments associated with this order""" """Return the number of completed shipments associated with this order."""
return self.completed_shipments().count() return self.completed_shipments().count()
@property @property
def pending_shipment_count(self): def pending_shipment_count(self):
"""Return the number of pending shipments associated with this order""" """Return the number of pending shipments associated with this order."""
return self.pending_shipments().count() return self.pending_shipments().count()
@ -1169,11 +1170,11 @@ class PurchaseOrderAttachment(InvenTreeAttachment):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the PurchaseOrderAttachment model""" """Return the API URL associated with the PurchaseOrderAttachment model."""
return reverse('api-po-attachment-list') return reverse('api-po-attachment-list')
def getSubdir(self): def getSubdir(self):
"""Return the directory path where PurchaseOrderAttachment files are located""" """Return the directory path where PurchaseOrderAttachment files are located."""
return os.path.join('po_files', str(self.order.id)) return os.path.join('po_files', str(self.order.id))
order = models.ForeignKey( order = models.ForeignKey(
@ -1186,11 +1187,11 @@ class SalesOrderAttachment(InvenTreeAttachment):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the SalesOrderAttachment class""" """Return the API URL associated with the SalesOrderAttachment class."""
return reverse('api-so-attachment-list') return reverse('api-so-attachment-list')
def getSubdir(self): def getSubdir(self):
"""Return the directory path where SalesOrderAttachment files are located""" """Return the directory path where SalesOrderAttachment files are located."""
return os.path.join('so_files', str(self.order.id)) return os.path.join('so_files', str(self.order.id))
order = models.ForeignKey( order = models.ForeignKey(
@ -1214,7 +1215,7 @@ class OrderLineItem(MetadataMixin, models.Model):
abstract = True abstract = True
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Custom save method for the OrderLineItem model """Custom save method for the OrderLineItem model.
Calls save method on the linked order Calls save method on the linked order
""" """
@ -1222,7 +1223,7 @@ class OrderLineItem(MetadataMixin, models.Model):
self.order.save() self.order.save()
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
"""Custom delete method for the OrderLineItem model """Custom delete method for the OrderLineItem model.
Calls save method on the linked order Calls save method on the linked order
""" """
@ -1240,7 +1241,7 @@ class OrderLineItem(MetadataMixin, models.Model):
@property @property
def total_line_price(self): def total_line_price(self):
"""Return the total price for this line item""" """Return the total price for this line item."""
if self.price: if self.price:
return self.quantity * self.price return self.quantity * self.price
@ -1325,13 +1326,13 @@ class PurchaseOrderLineItem(OrderLineItem):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the PurchaseOrderLineItem model""" """Return the API URL associated with the PurchaseOrderLineItem model."""
return reverse('api-po-line-list') return reverse('api-po-line-list')
def clean(self): def clean(self):
"""Custom clean method for the PurchaseOrderLineItem model: """Custom clean method for the PurchaseOrderLineItem model.
- Ensure the supplier part matches the supplier Ensure the supplier part matches the supplier
""" """
super().clean() super().clean()
@ -1341,7 +1342,7 @@ class PurchaseOrderLineItem(OrderLineItem):
raise ValidationError({'part': _('Supplier part must match supplier')}) raise ValidationError({'part': _('Supplier part must match supplier')})
def __str__(self): def __str__(self):
"""Render a string representation of a PurchaseOrderLineItem instance""" """Render a string representation of a PurchaseOrderLineItem instance."""
return '{n} x {part} from {supplier} (for {po})'.format( return '{n} x {part} from {supplier} (for {po})'.format(
n=decimal2string(self.quantity), n=decimal2string(self.quantity),
part=self.part.SKU if self.part else 'unknown part', part=self.part.SKU if self.part else 'unknown part',
@ -1395,7 +1396,7 @@ class PurchaseOrderLineItem(OrderLineItem):
@property @property
def price(self): def price(self):
"""Return the 'purchase_price' field as 'price'""" """Return the 'purchase_price' field as 'price'."""
return self.purchase_price return self.purchase_price
destination = TreeForeignKey( destination = TreeForeignKey(
@ -1442,7 +1443,7 @@ class PurchaseOrderExtraLine(OrderExtraLine):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the PurchaseOrderExtraLine model""" """Return the API URL associated with the PurchaseOrderExtraLine model."""
return reverse('api-po-extra-line-list') return reverse('api-po-extra-line-list')
order = models.ForeignKey( order = models.ForeignKey(
@ -1473,11 +1474,11 @@ class SalesOrderLineItem(OrderLineItem):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the SalesOrderLineItem model""" """Return the API URL associated with the SalesOrderLineItem model."""
return reverse('api-so-line-list') return reverse('api-so-line-list')
def clean(self): def clean(self):
"""Perform extra validation steps for this SalesOrderLineItem instance""" """Perform extra validation steps for this SalesOrderLineItem instance."""
super().clean() super().clean()
if self.part: if self.part:
@ -1520,7 +1521,7 @@ class SalesOrderLineItem(OrderLineItem):
@property @property
def price(self): def price(self):
"""Return the 'sale_price' field as 'price'""" """Return the 'sale_price' field as 'price'."""
return self.sale_price return self.sale_price
shipped = RoundingDecimalField( shipped = RoundingDecimalField(
@ -1583,14 +1584,14 @@ class SalesOrderShipment(InvenTreeNotesMixin, MetadataMixin, models.Model):
""" """
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options."""
# Shipment reference must be unique for a given sales order # Shipment reference must be unique for a given sales order
unique_together = ['order', 'reference'] unique_together = ['order', 'reference']
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the SalesOrderShipment model""" """Return the API URL associated with the SalesOrderShipment model."""
return reverse('api-so-shipment-list') return reverse('api-so-shipment-list')
order = models.ForeignKey( order = models.ForeignKey(
@ -1656,15 +1657,15 @@ class SalesOrderShipment(InvenTreeNotesMixin, MetadataMixin, models.Model):
) )
def is_complete(self): def is_complete(self):
"""Return True if this shipment has already been completed""" """Return True if this shipment has already been completed."""
return self.shipment_date is not None return self.shipment_date is not None
def is_delivered(self): def is_delivered(self):
"""Return True if this shipment has already been delivered""" """Return True if this shipment has already been delivered."""
return self.delivery_date is not None return self.delivery_date is not None
def check_can_complete(self, raise_error=True): def check_can_complete(self, raise_error=True):
"""Check if this shipment is able to be completed""" """Check if this shipment is able to be completed."""
try: try:
if self.shipment_date: if self.shipment_date:
# Shipment has already been sent! # Shipment has already been sent!
@ -1744,7 +1745,7 @@ class SalesOrderExtraLine(OrderExtraLine):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the SalesOrderExtraLine model""" """Return the API URL associated with the SalesOrderExtraLine model."""
return reverse('api-so-extra-line-list') return reverse('api-so-extra-line-list')
order = models.ForeignKey( order = models.ForeignKey(
@ -1768,7 +1769,7 @@ class SalesOrderAllocation(models.Model):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the SalesOrderAllocation model""" """Return the API URL associated with the SalesOrderAllocation model."""
return reverse('api-so-allocation-list') return reverse('api-so-allocation-list')
def clean(self): def clean(self):
@ -1869,11 +1870,11 @@ class SalesOrderAllocation(models.Model):
) )
def get_location(self): def get_location(self):
"""Return the <pk> value of the location associated with this allocation""" """Return the <pk> value of the location associated with this allocation."""
return self.item.location.id if self.item.location else None return self.item.location.id if self.item.location else None
def get_po(self): def get_po(self):
"""Return the PurchaseOrder associated with this allocation""" """Return the PurchaseOrder associated with this allocation."""
return self.item.purchase_order return self.item.purchase_order
def complete_allocation(self, user): def complete_allocation(self, user):
@ -1900,7 +1901,7 @@ class SalesOrderAllocation(models.Model):
class ReturnOrder(TotalPriceMixin, Order): class ReturnOrder(TotalPriceMixin, Order):
"""A ReturnOrder represents goods returned from a customer, e.g. an RMA or warranty """A ReturnOrder represents goods returned from a customer, e.g. an RMA or warranty.
Attributes: Attributes:
customer: Reference to the customer customer: Reference to the customer
@ -1909,22 +1910,22 @@ class ReturnOrder(TotalPriceMixin, Order):
""" """
def get_absolute_url(self): def get_absolute_url(self):
"""Get the 'web' URL for this order""" """Get the 'web' URL for this order."""
return reverse('return-order-detail', kwargs={'pk': self.pk}) return reverse('return-order-detail', kwargs={'pk': self.pk})
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the ReturnOrder model""" """Return the API URL associated with the ReturnOrder model."""
return reverse('api-return-order-list') return reverse('api-return-order-list')
@classmethod @classmethod
def get_status_class(cls): def get_status_class(cls):
"""Return the ReturnOrderStatus class""" """Return the ReturnOrderStatus class."""
return ReturnOrderStatusGroups return ReturnOrderStatusGroups
@classmethod @classmethod
def api_defaults(cls, request): def api_defaults(cls, request):
"""Return default values for this model when issuing an API OPTIONS request""" """Return default values for this model when issuing an API OPTIONS request."""
defaults = { defaults = {
'reference': order.validators.generate_next_return_order_reference() 'reference': order.validators.generate_next_return_order_reference()
} }
@ -1934,7 +1935,7 @@ class ReturnOrder(TotalPriceMixin, Order):
REFERENCE_PATTERN_SETTING = 'RETURNORDER_REFERENCE_PATTERN' REFERENCE_PATTERN_SETTING = 'RETURNORDER_REFERENCE_PATTERN'
def __str__(self): def __str__(self):
"""Render a string representation of this ReturnOrder""" """Render a string representation of this ReturnOrder."""
return f"{self.reference} - {self.customer.name if self.customer else _('no customer')}" return f"{self.reference} - {self.customer.name if self.customer else _('no customer')}"
reference = models.CharField( reference = models.CharField(
@ -1959,7 +1960,7 @@ class ReturnOrder(TotalPriceMixin, Order):
@property @property
def company(self): def company(self):
"""Accessor helper for Order base class""" """Accessor helper for Order base class."""
return self.customer return self.customer
status = models.PositiveIntegerField( status = models.PositiveIntegerField(
@ -1993,21 +1994,21 @@ class ReturnOrder(TotalPriceMixin, Order):
# region state changes # region state changes
@property @property
def is_pending(self): def is_pending(self):
"""Return True if this order is pending""" """Return True if this order is pending."""
return self.status == ReturnOrderStatus.PENDING return self.status == ReturnOrderStatus.PENDING
@property @property
def is_open(self): def is_open(self):
"""Return True if this order is outstanding""" """Return True if this order is outstanding."""
return self.status in ReturnOrderStatusGroups.OPEN return self.status in ReturnOrderStatusGroups.OPEN
@property @property
def is_received(self): def is_received(self):
"""Return True if this order is fully received""" """Return True if this order is fully received."""
return not self.lines.filter(received_date=None).exists() return not self.lines.filter(received_date=None).exists()
def _action_cancel(self, *args, **kwargs): def _action_cancel(self, *args, **kwargs):
"""Cancel this ReturnOrder (if not already cancelled)""" """Cancel this ReturnOrder (if not already cancelled)."""
if self.status != ReturnOrderStatus.CANCELLED: if self.status != ReturnOrderStatus.CANCELLED:
self.status = ReturnOrderStatus.CANCELLED.value self.status = ReturnOrderStatus.CANCELLED.value
self.save() self.save()
@ -2023,7 +2024,7 @@ class ReturnOrder(TotalPriceMixin, Order):
) )
def _action_complete(self, *args, **kwargs): def _action_complete(self, *args, **kwargs):
"""Complete this ReturnOrder (if not already completed)""" """Complete this ReturnOrder (if not already completed)."""
if self.status == ReturnOrderStatus.IN_PROGRESS: if self.status == ReturnOrderStatus.IN_PROGRESS:
self.status = ReturnOrderStatus.COMPLETE.value self.status = ReturnOrderStatus.COMPLETE.value
self.complete_date = datetime.now().date() self.complete_date = datetime.now().date()
@ -2032,11 +2033,11 @@ class ReturnOrder(TotalPriceMixin, Order):
trigger_event('returnorder.completed', id=self.pk) trigger_event('returnorder.completed', id=self.pk)
def place_order(self): def place_order(self):
"""Deprecated version of 'issue_order""" """Deprecated version of 'issue_order."""
self.issue_order() self.issue_order()
def _action_place(self, *args, **kwargs): def _action_place(self, *args, **kwargs):
"""Issue this ReturnOrder (if currently pending)""" """Issue this ReturnOrder (if currently pending)."""
if self.status == ReturnOrderStatus.PENDING: if self.status == ReturnOrderStatus.PENDING:
self.status = ReturnOrderStatus.IN_PROGRESS.value self.status = ReturnOrderStatus.IN_PROGRESS.value
self.issue_date = datetime.now().date() self.issue_date = datetime.now().date()
@ -2069,8 +2070,9 @@ class ReturnOrder(TotalPriceMixin, Order):
@transaction.atomic @transaction.atomic
def receive_line_item(self, line, location, user, note=''): def receive_line_item(self, line, location, user, note=''):
"""Receive a line item against this ReturnOrder: """Receive a line item against this ReturnOrder.
Rules:
- Transfers the StockItem to the specified location - Transfers the StockItem to the specified location
- Marks the StockItem as "quarantined" - Marks the StockItem as "quarantined"
- Adds a tracking entry to the StockItem - Adds a tracking entry to the StockItem
@ -2126,20 +2128,20 @@ class ReturnOrder(TotalPriceMixin, Order):
class ReturnOrderLineItem(OrderLineItem): class ReturnOrderLineItem(OrderLineItem):
"""Model for a single LineItem in a ReturnOrder""" """Model for a single LineItem in a ReturnOrder."""
class Meta: class Meta:
"""Metaclass options for this model""" """Metaclass options for this model."""
unique_together = [('order', 'item')] unique_together = [('order', 'item')]
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with this model""" """Return the API URL associated with this model."""
return reverse('api-return-order-line-list') return reverse('api-return-order-line-list')
def clean(self): def clean(self):
"""Perform extra validation steps for the ReturnOrderLineItem model""" """Perform extra validation steps for the ReturnOrderLineItem model."""
super().clean() super().clean()
if self.item and not self.item.serialized: if self.item and not self.item.serialized:
@ -2172,7 +2174,7 @@ class ReturnOrderLineItem(OrderLineItem):
@property @property
def received(self): def received(self):
"""Return True if this item has been received""" """Return True if this item has been received."""
return self.received_date is not None return self.received_date is not None
outcome = models.PositiveIntegerField( outcome = models.PositiveIntegerField(
@ -2191,11 +2193,11 @@ class ReturnOrderLineItem(OrderLineItem):
class ReturnOrderExtraLine(OrderExtraLine): class ReturnOrderExtraLine(OrderExtraLine):
"""Model for a single ExtraLine in a ReturnOrder""" """Model for a single ExtraLine in a ReturnOrder."""
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the ReturnOrderExtraLine model""" """Return the API URL associated with the ReturnOrderExtraLine model."""
return reverse('api-return-order-extra-line-list') return reverse('api-return-order-extra-line-list')
order = models.ForeignKey( order = models.ForeignKey(
@ -2208,15 +2210,15 @@ class ReturnOrderExtraLine(OrderExtraLine):
class ReturnOrderAttachment(InvenTreeAttachment): class ReturnOrderAttachment(InvenTreeAttachment):
"""Model for storing file attachments against a ReturnOrder object""" """Model for storing file attachments against a ReturnOrder object."""
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the ReturnOrderAttachment class""" """Return the API URL associated with the ReturnOrderAttachment class."""
return reverse('api-return-order-attachment-list') return reverse('api-return-order-attachment-list')
def getSubdir(self): def getSubdir(self):
"""Return the directory path where ReturnOrderAttachment files are located""" """Return the directory path where ReturnOrderAttachment files are located."""
return os.path.join('return_files', str(self.order.id)) return os.path.join('return_files', str(self.order.id))
order = models.ForeignKey( order = models.ForeignKey(

View File

@ -43,7 +43,7 @@ from users.serializers import OwnerSerializer
class TotalPriceMixin(serializers.Serializer): class TotalPriceMixin(serializers.Serializer):
"""Serializer mixin which provides total price fields""" """Serializer mixin which provides total price fields."""
total_price = InvenTreeMoneySerializer(allow_null=True, read_only=True) total_price = InvenTreeMoneySerializer(allow_null=True, read_only=True)
@ -57,7 +57,7 @@ class TotalPriceMixin(serializers.Serializer):
class AbstractOrderSerializer(serializers.Serializer): class AbstractOrderSerializer(serializers.Serializer):
"""Abstract serializer class which provides fields common to all order types""" """Abstract serializer class which provides fields common to all order types."""
# Number of line items in this order # Number of line items in this order
line_items = serializers.IntegerField(read_only=True) line_items = serializers.IntegerField(read_only=True)
@ -98,20 +98,20 @@ class AbstractOrderSerializer(serializers.Serializer):
barcode_hash = serializers.CharField(read_only=True) barcode_hash = serializers.CharField(read_only=True)
def validate_reference(self, reference): def validate_reference(self, reference):
"""Custom validation for the reference field""" """Custom validation for the reference field."""
self.Meta.model.validate_reference_field(reference) self.Meta.model.validate_reference_field(reference)
return reference return reference
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
"""Add extra information to the queryset""" """Add extra information to the queryset."""
queryset = queryset.annotate(line_items=SubqueryCount('lines')) queryset = queryset.annotate(line_items=SubqueryCount('lines'))
return queryset return queryset
@staticmethod @staticmethod
def order_fields(extra_fields): def order_fields(extra_fields):
"""Construct a set of fields for this serializer""" """Construct a set of fields for this serializer."""
return [ return [
'pk', 'pk',
'creation_date', 'creation_date',
@ -141,7 +141,7 @@ class AbstractExtraLineSerializer(serializers.Serializer):
"""Abstract Serializer for a ExtraLine object.""" """Abstract Serializer for a ExtraLine object."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialization routine for the serializer""" """Initialization routine for the serializer."""
order_detail = kwargs.pop('order_detail', False) order_detail = kwargs.pop('order_detail', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -202,7 +202,7 @@ class PurchaseOrderSerializer(
} }
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialization routine for the serializer""" """Initialization routine for the serializer."""
supplier_detail = kwargs.pop('supplier_detail', False) supplier_detail = kwargs.pop('supplier_detail', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -257,7 +257,7 @@ class PurchaseOrderCancelSerializer(serializers.Serializer):
return {'can_cancel': self.order.can_cancel} return {'can_cancel': self.order.can_cancel}
def save(self): def save(self):
"""Save the serializer to 'cancel' the order""" """Save the serializer to 'cancel' the order."""
order = self.context['order'] order = self.context['order']
if not order.can_cancel: if not order.can_cancel:
@ -282,7 +282,7 @@ class PurchaseOrderCompleteSerializer(serializers.Serializer):
) )
def validate_accept_incomplete(self, value): def validate_accept_incomplete(self, value):
"""Check if the 'accept_incomplete' field is required""" """Check if the 'accept_incomplete' field is required."""
order = self.context['order'] order = self.context['order']
if not value and not order.is_complete: if not value and not order.is_complete:
@ -297,7 +297,7 @@ class PurchaseOrderCompleteSerializer(serializers.Serializer):
return {'is_complete': order.is_complete} return {'is_complete': order.is_complete}
def save(self): def save(self):
"""Save the serializer to 'complete' the order""" """Save the serializer to 'complete' the order."""
order = self.context['order'] order = self.context['order']
order.complete_order() order.complete_order()
@ -311,13 +311,13 @@ class PurchaseOrderIssueSerializer(serializers.Serializer):
fields = [] fields = []
def save(self): def save(self):
"""Save the serializer to 'place' the order""" """Save the serializer to 'place' the order."""
order = self.context['order'] order = self.context['order']
order.place_order() order.place_order()
class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer):
"""Serializer class for the PurchaseOrderLineItem model""" """Serializer class for the PurchaseOrderLineItem model."""
class Meta: class Meta:
"""Metaclass options.""" """Metaclass options."""
@ -346,7 +346,7 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialization routine for the serializer""" """Initialization routine for the serializer."""
part_detail = kwargs.pop('part_detail', False) part_detail = kwargs.pop('part_detail', False)
order_detail = kwargs.pop('order_detail', False) order_detail = kwargs.pop('order_detail', False)
@ -362,10 +362,10 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer):
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
"""Add some extra annotations to this queryset: """Add some extra annotations to this queryset.
- Total price = purchase_price * quantity - "total_price" = purchase_price * quantity
- "Overdue" status (boolean field) - "overdue" status (boolean field)
""" """
queryset = queryset.annotate( queryset = queryset.annotate(
total_price=ExpressionWrapper( total_price=ExpressionWrapper(
@ -388,14 +388,14 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer):
quantity = serializers.FloatField(min_value=0, required=True) quantity = serializers.FloatField(min_value=0, required=True)
def validate_quantity(self, quantity): def validate_quantity(self, quantity):
"""Validation for the 'quantity' field""" """Validation for the 'quantity' field."""
if quantity <= 0: if quantity <= 0:
raise ValidationError(_('Quantity must be greater than zero')) raise ValidationError(_('Quantity must be greater than zero'))
return quantity return quantity
def validate_purchase_order(self, purchase_order): def validate_purchase_order(self, purchase_order):
"""Validation for the 'purchase_order' field""" """Validation for the 'purchase_order' field."""
if purchase_order.status not in PurchaseOrderStatusGroups.OPEN: if purchase_order.status not in PurchaseOrderStatusGroups.OPEN:
raise ValidationError(_('Order is not open')) raise ValidationError(_('Order is not open'))
@ -428,7 +428,7 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer):
order_detail = PurchaseOrderSerializer(source='order', read_only=True, many=False) order_detail = PurchaseOrderSerializer(source='order', read_only=True, many=False)
def validate(self, data): def validate(self, data):
"""Custom validation for the serializer: """Custom validation for the serializer.
- Ensure the supplier_part field is supplied - Ensure the supplier_part field is supplied
- Ensure the purchase_order field is supplied - Ensure the purchase_order field is supplied
@ -495,7 +495,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer):
) )
def validate_line_item(self, item): def validate_line_item(self, item):
"""Validation for the 'line_item' field""" """Validation for the 'line_item' field."""
if item.order != self.context['order']: if item.order != self.context['order']:
raise ValidationError(_('Line item does not match purchase order')) raise ValidationError(_('Line item does not match purchase order'))
@ -515,7 +515,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer):
) )
def validate_quantity(self, quantity): def validate_quantity(self, quantity):
"""Validation for the 'quantity' field""" """Validation for the 'quantity' field."""
if quantity <= 0: if quantity <= 0:
raise ValidationError(_('Quantity must be greater than zero')) raise ValidationError(_('Quantity must be greater than zero'))
@ -564,7 +564,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer):
return barcode return barcode
def validate(self, data): def validate(self, data):
"""Custom validation for the serializer: """Custom validation for the serializer.
- Integer quantity must be provided for serialized stock - Integer quantity must be provided for serialized stock
- Validate serial numbers (if provided) - Validate serial numbers (if provided)
@ -619,7 +619,7 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer):
) )
def validate(self, data): def validate(self, data):
"""Custom validation for the serializer: """Custom validation for the serializer.
- Ensure line items are provided - Ensure line items are provided
- Check that a location is specified - Check that a location is specified
@ -714,7 +714,7 @@ class PurchaseOrderAttachmentSerializer(InvenTreeAttachmentSerializer):
class SalesOrderSerializer( class SalesOrderSerializer(
TotalPriceMixin, AbstractOrderSerializer, InvenTreeModelSerializer TotalPriceMixin, AbstractOrderSerializer, InvenTreeModelSerializer
): ):
"""Serializer for the SalesOrder model class""" """Serializer for the SalesOrder model class."""
class Meta: class Meta:
"""Metaclass options.""" """Metaclass options."""
@ -735,7 +735,7 @@ class SalesOrderSerializer(
extra_kwargs = {'order_currency': {'required': False}} extra_kwargs = {'order_currency': {'required': False}}
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialization routine for the serializer""" """Initialization routine for the serializer."""
customer_detail = kwargs.pop('customer_detail', False) customer_detail = kwargs.pop('customer_detail', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -775,15 +775,15 @@ class SalesOrderSerializer(
class SalesOrderIssueSerializer(serializers.Serializer): class SalesOrderIssueSerializer(serializers.Serializer):
"""Serializer for issuing a SalesOrder""" """Serializer for issuing a SalesOrder."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
fields = [] fields = []
def save(self): def save(self):
"""Save the serializer to 'issue' the order""" """Save the serializer to 'issue' the order."""
order = self.context['order'] order = self.context['order']
order.issue_order() order.issue_order()
@ -818,7 +818,7 @@ class SalesOrderAllocationSerializer(InvenTreeModelSerializer):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialization routine for the serializer""" """Initialization routine for the serializer."""
order_detail = kwargs.pop('order_detail', False) order_detail = kwargs.pop('order_detail', False)
part_detail = kwargs.pop('part_detail', True) part_detail = kwargs.pop('part_detail', True)
item_detail = kwargs.pop('item_detail', True) item_detail = kwargs.pop('item_detail', True)
@ -901,7 +901,7 @@ class SalesOrderLineItemSerializer(InvenTreeModelSerializer):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialization routine for the serializer: """Initialization routine for the serializer.
- Add extra related serializer information if required - Add extra related serializer information if required
""" """
@ -926,7 +926,7 @@ class SalesOrderLineItemSerializer(InvenTreeModelSerializer):
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
"""Add some extra annotations to this queryset: """Add some extra annotations to this queryset.
- "overdue" status (boolean field) - "overdue" status (boolean field)
- "available_quantity" - "available_quantity"
@ -1066,7 +1066,7 @@ class SalesOrderShipmentCompleteSerializer(serializers.ModelSerializer):
] ]
def validate(self, data): def validate(self, data):
"""Custom validation for the serializer: """Custom validation for the serializer.
- Ensure the shipment reference is provided - Ensure the shipment reference is provided
""" """
@ -1082,7 +1082,7 @@ class SalesOrderShipmentCompleteSerializer(serializers.ModelSerializer):
return data return data
def save(self): def save(self):
"""Save the serializer to complete the SalesOrderShipment""" """Save the serializer to complete the SalesOrderShipment."""
shipment = self.context.get('shipment', None) shipment = self.context.get('shipment', None)
if not shipment: if not shipment:
@ -1127,7 +1127,7 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer):
) )
def validate_line_item(self, line_item): def validate_line_item(self, line_item):
"""Custom validation for the 'line_item' field: """Custom validation for the 'line_item' field.
- Ensure the line_item is associated with the particular SalesOrder - Ensure the line_item is associated with the particular SalesOrder
""" """
@ -1152,14 +1152,14 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer):
) )
def validate_quantity(self, quantity): def validate_quantity(self, quantity):
"""Custom validation for the 'quantity' field""" """Custom validation for the 'quantity' field."""
if quantity <= 0: if quantity <= 0:
raise ValidationError(_('Quantity must be positive')) raise ValidationError(_('Quantity must be positive'))
return quantity return quantity
def validate(self, data): def validate(self, data):
"""Custom validation for the serializer: """Custom validation for the serializer.
- Ensure that the quantity is 1 for serialized stock - Ensure that the quantity is 1 for serialized stock
- Quantity cannot exceed the available amount - Quantity cannot exceed the available amount
@ -1193,7 +1193,7 @@ class SalesOrderCompleteSerializer(serializers.Serializer):
) )
def validate_accept_incomplete(self, value): def validate_accept_incomplete(self, value):
"""Check if the 'accept_incomplete' field is required""" """Check if the 'accept_incomplete' field is required."""
order = self.context['order'] order = self.context['order']
if not value and not order.is_completed(): if not value and not order.is_completed():
@ -1202,7 +1202,7 @@ class SalesOrderCompleteSerializer(serializers.Serializer):
return value return value
def get_context_data(self): def get_context_data(self):
"""Custom context data for this serializer""" """Custom context data for this serializer."""
order = self.context['order'] order = self.context['order']
return { return {
@ -1211,7 +1211,7 @@ class SalesOrderCompleteSerializer(serializers.Serializer):
} }
def validate(self, data): def validate(self, data):
"""Custom validation for the serializer""" """Custom validation for the serializer."""
data = super().validate(data) data = super().validate(data)
order = self.context['order'] order = self.context['order']
@ -1224,7 +1224,7 @@ class SalesOrderCompleteSerializer(serializers.Serializer):
return data return data
def save(self): def save(self):
"""Save the serializer to complete the SalesOrder""" """Save the serializer to complete the SalesOrder."""
request = self.context['request'] request = self.context['request']
order = self.context['order'] order = self.context['order']
data = self.validated_data data = self.validated_data
@ -1240,13 +1240,13 @@ class SalesOrderCancelSerializer(serializers.Serializer):
"""Serializer for marking a SalesOrder as cancelled.""" """Serializer for marking a SalesOrder as cancelled."""
def get_context_data(self): def get_context_data(self):
"""Add extra context data to the serializer""" """Add extra context data to the serializer."""
order = self.context['order'] order = self.context['order']
return {'can_cancel': order.can_cancel} return {'can_cancel': order.can_cancel}
def save(self): def save(self):
"""Save the serializer to cancel the order""" """Save the serializer to cancel the order."""
order = self.context['order'] order = self.context['order']
order.cancel_order() order.cancel_order()
@ -1298,7 +1298,7 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer):
) )
def validate_shipment(self, shipment): def validate_shipment(self, shipment):
"""Validate the shipment: """Validate the shipment.
- Must point to the same order - Must point to the same order
- Must not be shipped - Must not be shipped
@ -1314,7 +1314,7 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer):
return shipment return shipment
def validate(self, data): def validate(self, data):
"""Validation for the serializer: """Validation for the serializer.
- Ensure the serial_numbers and quantity fields match - Ensure the serial_numbers and quantity fields match
- Check that all serial numbers exist - Check that all serial numbers exist
@ -1374,7 +1374,7 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer):
return data return data
def save(self): def save(self):
"""Allocate stock items against the sales order""" """Allocate stock items against the sales order."""
data = self.validated_data data = self.validated_data
line_item = data['line_item'] line_item = data['line_item']
@ -1481,10 +1481,10 @@ class SalesOrderAttachmentSerializer(InvenTreeAttachmentSerializer):
class ReturnOrderSerializer( class ReturnOrderSerializer(
AbstractOrderSerializer, TotalPriceMixin, InvenTreeModelSerializer AbstractOrderSerializer, TotalPriceMixin, InvenTreeModelSerializer
): ):
"""Serializer for the ReturnOrder model class""" """Serializer for the ReturnOrder model class."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = order.models.ReturnOrder model = order.models.ReturnOrder
@ -1499,7 +1499,7 @@ class ReturnOrderSerializer(
read_only_fields = ['creation_date'] read_only_fields = ['creation_date']
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialization routine for the serializer""" """Initialization routine for the serializer."""
customer_detail = kwargs.pop('customer_detail', False) customer_detail = kwargs.pop('customer_detail', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -1509,7 +1509,7 @@ class ReturnOrderSerializer(
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
"""Custom annotation for the serializer queryset""" """Custom annotation for the serializer queryset."""
queryset = AbstractOrderSerializer.annotate_queryset(queryset) queryset = AbstractOrderSerializer.annotate_queryset(queryset)
queryset = queryset.annotate( queryset = queryset.annotate(
@ -1536,52 +1536,52 @@ class ReturnOrderSerializer(
class ReturnOrderIssueSerializer(serializers.Serializer): class ReturnOrderIssueSerializer(serializers.Serializer):
"""Serializer for issuing a ReturnOrder""" """Serializer for issuing a ReturnOrder."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
fields = [] fields = []
def save(self): def save(self):
"""Save the serializer to 'issue' the order""" """Save the serializer to 'issue' the order."""
order = self.context['order'] order = self.context['order']
order.issue_order() order.issue_order()
class ReturnOrderCancelSerializer(serializers.Serializer): class ReturnOrderCancelSerializer(serializers.Serializer):
"""Serializer for cancelling a ReturnOrder""" """Serializer for cancelling a ReturnOrder."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
fields = [] fields = []
def save(self): def save(self):
"""Save the serializer to 'cancel' the order""" """Save the serializer to 'cancel' the order."""
order = self.context['order'] order = self.context['order']
order.cancel_order() order.cancel_order()
class ReturnOrderCompleteSerializer(serializers.Serializer): class ReturnOrderCompleteSerializer(serializers.Serializer):
"""Serializer for completing a ReturnOrder""" """Serializer for completing a ReturnOrder."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
fields = [] fields = []
def save(self): def save(self):
"""Save the serializer to 'complete' the order""" """Save the serializer to 'complete' the order."""
order = self.context['order'] order = self.context['order']
order.complete_order() order.complete_order()
class ReturnOrderLineItemReceiveSerializer(serializers.Serializer): class ReturnOrderLineItemReceiveSerializer(serializers.Serializer):
"""Serializer for receiving a single line item against a ReturnOrder""" """Serializer for receiving a single line item against a ReturnOrder."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
fields = ['item'] fields = ['item']
@ -1594,7 +1594,7 @@ class ReturnOrderLineItemReceiveSerializer(serializers.Serializer):
) )
def validate_line_item(self, item): def validate_line_item(self, item):
"""Validation for a single line item""" """Validation for a single line item."""
if item.order != self.context['order']: if item.order != self.context['order']:
raise ValidationError(_('Line item does not match return order')) raise ValidationError(_('Line item does not match return order'))
@ -1605,10 +1605,10 @@ class ReturnOrderLineItemReceiveSerializer(serializers.Serializer):
class ReturnOrderReceiveSerializer(serializers.Serializer): class ReturnOrderReceiveSerializer(serializers.Serializer):
"""Serializer for receiving items against a ReturnOrder""" """Serializer for receiving items against a ReturnOrder."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
fields = ['items', 'location'] fields = ['items', 'location']
@ -1624,7 +1624,7 @@ class ReturnOrderReceiveSerializer(serializers.Serializer):
) )
def validate(self, data): def validate(self, data):
"""Perform data validation for this serializer""" """Perform data validation for this serializer."""
order = self.context['order'] order = self.context['order']
if order.status != ReturnOrderStatus.IN_PROGRESS: if order.status != ReturnOrderStatus.IN_PROGRESS:
raise ValidationError( raise ValidationError(
@ -1642,7 +1642,7 @@ class ReturnOrderReceiveSerializer(serializers.Serializer):
@transaction.atomic @transaction.atomic
def save(self): def save(self):
"""Saving this serializer marks the returned items as received""" """Saving this serializer marks the returned items as received."""
order = self.context['order'] order = self.context['order']
request = self.context['request'] request = self.context['request']
@ -1657,10 +1657,10 @@ class ReturnOrderReceiveSerializer(serializers.Serializer):
class ReturnOrderLineItemSerializer(InvenTreeModelSerializer): class ReturnOrderLineItemSerializer(InvenTreeModelSerializer):
"""Serializer for a ReturnOrderLineItem object""" """Serializer for a ReturnOrderLineItem object."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = order.models.ReturnOrderLineItem model = order.models.ReturnOrderLineItem
@ -1683,7 +1683,7 @@ class ReturnOrderLineItemSerializer(InvenTreeModelSerializer):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Initialization routine for the serializer""" """Initialization routine for the serializer."""
order_detail = kwargs.pop('order_detail', False) order_detail = kwargs.pop('order_detail', False)
item_detail = kwargs.pop('item_detail', False) item_detail = kwargs.pop('item_detail', False)
part_detail = kwargs.pop('part_detail', False) part_detail = kwargs.pop('part_detail', False)
@ -1712,10 +1712,10 @@ class ReturnOrderLineItemSerializer(InvenTreeModelSerializer):
class ReturnOrderExtraLineSerializer( class ReturnOrderExtraLineSerializer(
AbstractExtraLineSerializer, InvenTreeModelSerializer AbstractExtraLineSerializer, InvenTreeModelSerializer
): ):
"""Serializer for a ReturnOrderExtraLine object""" """Serializer for a ReturnOrderExtraLine object."""
class Meta(AbstractExtraLineMeta): class Meta(AbstractExtraLineMeta):
"""Metaclass options""" """Metaclass options."""
model = order.models.ReturnOrderExtraLine model = order.models.ReturnOrderExtraLine
@ -1723,10 +1723,10 @@ class ReturnOrderExtraLineSerializer(
class ReturnOrderAttachmentSerializer(InvenTreeAttachmentSerializer): class ReturnOrderAttachmentSerializer(InvenTreeAttachmentSerializer):
"""Serializer for the ReturnOrderAttachment model""" """Serializer for the ReturnOrderAttachment model."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = order.models.ReturnOrderAttachment model = order.models.ReturnOrderAttachment

View File

@ -1,4 +1,4 @@
"""Background tasks for the 'order' app""" """Background tasks for the 'order' app."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -13,7 +13,7 @@ from plugin.events import trigger_event
def notify_overdue_purchase_order(po: order.models.PurchaseOrder): def notify_overdue_purchase_order(po: order.models.PurchaseOrder):
"""Notify users that a PurchaseOrder has just become 'overdue'""" """Notify users that a PurchaseOrder has just become 'overdue'."""
targets = [] targets = []
if po.created_by: if po.created_by:
@ -45,8 +45,9 @@ def notify_overdue_purchase_order(po: order.models.PurchaseOrder):
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def check_overdue_purchase_orders(): def check_overdue_purchase_orders():
"""Check if any outstanding PurchaseOrders have just become overdue: """Check if any outstanding PurchaseOrders have just become overdue.
Rules:
- This check is performed daily - This check is performed daily
- Look at the 'target_date' of any outstanding PurchaseOrder objects - Look at the 'target_date' of any outstanding PurchaseOrder objects
- If the 'target_date' expired *yesterday* then the order is just out of date - If the 'target_date' expired *yesterday* then the order is just out of date
@ -62,7 +63,7 @@ def check_overdue_purchase_orders():
def notify_overdue_sales_order(so: order.models.SalesOrder): def notify_overdue_sales_order(so: order.models.SalesOrder):
"""Notify appropriate users that a SalesOrder has just become 'overdue'""" """Notify appropriate users that a SalesOrder has just become 'overdue'."""
targets = [] targets = []
if so.created_by: if so.created_by:
@ -94,7 +95,7 @@ def notify_overdue_sales_order(so: order.models.SalesOrder):
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def check_overdue_sales_orders(): def check_overdue_sales_orders():
"""Check if any outstanding SalesOrders have just become overdue """Check if any outstanding SalesOrders have just become overdue.
- This check is performed daily - This check is performed daily
- Look at the 'target_date' of any outstanding SalesOrder objects - Look at the 'target_date' of any outstanding SalesOrder objects

View File

@ -30,7 +30,7 @@ from stock.models import StockItem
class OrderTest(InvenTreeAPITestCase): class OrderTest(InvenTreeAPITestCase):
"""Base class for order API unit testing""" """Base class for order API unit testing."""
fixtures = [ fixtures = [
'category', 'category',
@ -112,7 +112,7 @@ class PurchaseOrderTest(OrderTest):
) )
def test_po_list(self): def test_po_list(self):
"""Test the PurchaseOrder list API endpoint""" """Test the PurchaseOrder list API endpoint."""
# List *ALL* PurchaseOrder items # List *ALL* PurchaseOrder items
self.filter({}, 7) self.filter({}, 7)
@ -150,7 +150,7 @@ class PurchaseOrderTest(OrderTest):
self.filter({'supplier_part': 4}, 0) self.filter({'supplier_part': 4}, 0)
def test_total_price(self): def test_total_price(self):
"""Unit tests for the 'total_price' field""" """Unit tests for the 'total_price' field."""
# Ensure we have exchange rate data # Ensure we have exchange rate data
self.generate_exchange_rates() self.generate_exchange_rates()
@ -219,7 +219,7 @@ class PurchaseOrderTest(OrderTest):
self.filter({'overdue': False}, 6) self.filter({'overdue': False}, 6)
def test_po_detail(self): def test_po_detail(self):
"""Test the PurchaseOrder detail API endpoint""" """Test the PurchaseOrder detail API endpoint."""
url = '/api/order/po/1/' url = '/api/order/po/1/'
response = self.get(url) response = self.get(url)
@ -255,7 +255,7 @@ class PurchaseOrderTest(OrderTest):
self.assertEqual(order.reference_int, 0x7FFFFFFF) self.assertEqual(order.reference_int, 0x7FFFFFFF)
def test_po_attachments(self): def test_po_attachments(self):
"""Test the list endpoint for the PurchaseOrderAttachment model""" """Test the list endpoint for the PurchaseOrderAttachment model."""
url = reverse('api-po-attachment-list') url = reverse('api-po-attachment-list')
response = self.get(url) response = self.get(url)
@ -358,7 +358,7 @@ class PurchaseOrderTest(OrderTest):
) )
def test_po_duplicate(self): def test_po_duplicate(self):
"""Test that we can duplicate a PurchaseOrder via the API""" """Test that we can duplicate a PurchaseOrder via the API."""
self.assignRole('purchase_order.add') self.assignRole('purchase_order.add')
po = models.PurchaseOrder.objects.get(pk=1) po = models.PurchaseOrder.objects.get(pk=1)
@ -488,7 +488,7 @@ class PurchaseOrderTest(OrderTest):
self.assertEqual(po.status, PurchaseOrderStatus.PLACED) self.assertEqual(po.status, PurchaseOrderStatus.PLACED)
def test_po_calendar(self): def test_po_calendar(self):
"""Test the calendar export endpoint""" """Test the calendar export endpoint."""
# Create required purchase orders # Create required purchase orders
self.assignRole('purchase_order.add') self.assignRole('purchase_order.add')
@ -578,7 +578,7 @@ class PurchaseOrderTest(OrderTest):
self.assertEqual(number_orders_incl_completed, n_events) self.assertEqual(number_orders_incl_completed, n_events)
def test_po_calendar_noauth(self): def test_po_calendar_noauth(self):
"""Test accessing calendar without authorization""" """Test accessing calendar without authorization."""
self.client.logout() self.client.logout()
response = self.client.get( response = self.client.get(
reverse('api-po-so-calendar', kwargs={'ordertype': 'purchase-order'}), reverse('api-po-so-calendar', kwargs={'ordertype': 'purchase-order'}),
@ -593,7 +593,7 @@ class PurchaseOrderTest(OrderTest):
) )
def test_po_calendar_auth(self): def test_po_calendar_auth(self):
"""Test accessing calendar with header authorization""" """Test accessing calendar with header authorization."""
self.client.logout() self.client.logout()
base64_token = base64.b64encode( base64_token = base64.b64encode(
f'{self.username}:{self.password}'.encode('ascii') f'{self.username}:{self.password}'.encode('ascii')
@ -612,7 +612,7 @@ class PurchaseOrderLineItemTest(OrderTest):
LIST_URL = reverse('api-po-line-list') LIST_URL = reverse('api-po-line-list')
def test_po_line_list(self): def test_po_line_list(self):
"""Test the PurchaseOrderLine list API endpoint""" """Test the PurchaseOrderLine list API endpoint."""
# List *ALL* PurchaseOrderLine items # List *ALL* PurchaseOrderLine items
self.filter({}, 5) self.filter({}, 5)
@ -688,7 +688,7 @@ class PurchaseOrderDownloadTest(OrderTest):
self.assertEqual(order.reference, row['reference']) self.assertEqual(order.reference, row['reference'])
def test_download_line_items(self): def test_download_line_items(self):
"""Test that the PurchaseOrderLineItems can be downloaded to a file""" """Test that the PurchaseOrderLineItems can be downloaded to a file."""
with self.download_file( with self.download_file(
reverse('api-po-line-list'), reverse('api-po-line-list'),
{'export': 'xlsx'}, {'export': 'xlsx'},
@ -703,7 +703,7 @@ class PurchaseOrderReceiveTest(OrderTest):
"""Unit tests for receiving items against a PurchaseOrder.""" """Unit tests for receiving items against a PurchaseOrder."""
def setUp(self): def setUp(self):
"""Init routines for this unit test class""" """Init routines for this unit test class."""
super().setUp() super().setUp()
self.assignRole('purchase_order.add') self.assignRole('purchase_order.add')
@ -805,7 +805,7 @@ class PurchaseOrderReceiveTest(OrderTest):
) )
def test_invalid_barcodes(self): def test_invalid_barcodes(self):
"""Tests for checking in items with invalid barcodes: """Tests for checking in items with invalid barcodes.
- Cannot check in "duplicate" barcodes - Cannot check in "duplicate" barcodes
- Barcodes cannot match 'barcode_hash' field for existing StockItem - Barcodes cannot match 'barcode_hash' field for existing StockItem
@ -998,7 +998,7 @@ class SalesOrderTest(OrderTest):
LIST_URL = reverse('api-so-list') LIST_URL = reverse('api-so-list')
def test_so_list(self): def test_so_list(self):
"""Test the SalesOrder list API endpoint""" """Test the SalesOrder list API endpoint."""
# All orders # All orders
self.filter({}, 5) self.filter({}, 5)
@ -1024,7 +1024,7 @@ class SalesOrderTest(OrderTest):
self.filter({'assigned_to_me': 0}, 5) self.filter({'assigned_to_me': 0}, 5)
def test_total_price(self): def test_total_price(self):
"""Unit tests for the 'total_price' field""" """Unit tests for the 'total_price' field."""
# Ensure we have exchange rate data # Ensure we have exchange rate data
self.generate_exchange_rates() self.generate_exchange_rates()
@ -1106,7 +1106,7 @@ class SalesOrderTest(OrderTest):
self.filter({'overdue': False}, 3) self.filter({'overdue': False}, 3)
def test_so_detail(self): def test_so_detail(self):
"""Test the SalesOrder detail endpoint""" """Test the SalesOrder detail endpoint."""
url = '/api/order/so/1/' url = '/api/order/so/1/'
response = self.get(url) response = self.get(url)
@ -1116,7 +1116,7 @@ class SalesOrderTest(OrderTest):
self.assertEqual(data['pk'], 1) self.assertEqual(data['pk'], 1)
def test_so_attachments(self): def test_so_attachments(self):
"""Test the list endpoint for the SalesOrderAttachment model""" """Test the list endpoint for the SalesOrderAttachment model."""
url = reverse('api-so-attachment-list') url = reverse('api-so-attachment-list')
self.get(url) self.get(url)
@ -1241,7 +1241,7 @@ class SalesOrderTest(OrderTest):
self.assertEqual(so.status, SalesOrderStatus.CANCELLED) self.assertEqual(so.status, SalesOrderStatus.CANCELLED)
def test_so_calendar(self): def test_so_calendar(self):
"""Test the calendar export endpoint""" """Test the calendar export endpoint."""
# Create required sales orders # Create required sales orders
self.assignRole('sales_order.add') self.assignRole('sales_order.add')
@ -1313,7 +1313,7 @@ class SalesOrderTest(OrderTest):
self.assertEqual(number_orders_incl_complete, n_events) self.assertEqual(number_orders_incl_complete, n_events)
def test_export(self): def test_export(self):
"""Test we can export the SalesOrder list""" """Test we can export the SalesOrder list."""
n = models.SalesOrder.objects.count() n = models.SalesOrder.objects.count()
# Check there are some sales orders # Check there are some sales orders
@ -1341,7 +1341,7 @@ class SalesOrderLineItemTest(OrderTest):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Init routine for this unit test class""" """Init routine for this unit test class."""
super().setUpTestData() super().setUpTestData()
# List of salable parts # List of salable parts
@ -1367,7 +1367,7 @@ class SalesOrderLineItemTest(OrderTest):
cls.url = reverse('api-so-line-list') cls.url = reverse('api-so-line-list')
def test_so_line_list(self): def test_so_line_list(self):
"""Test list endpoint""" """Test list endpoint."""
response = self.get(self.url, {}, expected_code=200) response = self.get(self.url, {}, expected_code=200)
n = models.SalesOrderLineItem.objects.count() n = models.SalesOrderLineItem.objects.count()
@ -1416,7 +1416,7 @@ class SalesOrderDownloadTest(OrderTest):
self.download_file(url, {}, expected_code=200) self.download_file(url, {}, expected_code=200)
def test_download_xls(self): def test_download_xls(self):
"""Test xls file download""" """Test xls file download."""
url = reverse('api-so-list') url = reverse('api-so-list')
# Download .xls file # Download .xls file
@ -1430,7 +1430,7 @@ class SalesOrderDownloadTest(OrderTest):
self.assertTrue(isinstance(file, io.BytesIO)) self.assertTrue(isinstance(file, io.BytesIO))
def test_download_csv(self): def test_download_csv(self):
"""Test that the list of sales orders can be downloaded as a .csv file""" """Test that the list of sales orders can be downloaded as a .csv file."""
url = reverse('api-so-list') url = reverse('api-so-list')
required_cols = [ required_cols = [
@ -1490,7 +1490,7 @@ class SalesOrderAllocateTest(OrderTest):
"""Unit tests for allocating stock items against a SalesOrder.""" """Unit tests for allocating stock items against a SalesOrder."""
def setUp(self): def setUp(self):
"""Init routines for this unit testing class""" """Init routines for this unit testing class."""
super().setUp() super().setUp()
self.assignRole('sales_order.add') self.assignRole('sales_order.add')
@ -1599,7 +1599,7 @@ class SalesOrderAllocateTest(OrderTest):
self.assertEqual(line.allocations.count(), 1) self.assertEqual(line.allocations.count(), 1)
def test_allocate_variant(self): def test_allocate_variant(self):
"""Test that the allocation endpoint acts as expected, when provided with variant""" """Test that the allocation endpoint acts as expected, when provided with variant."""
# First, check that there are no line items allocated against this SalesOrder # First, check that there are no line items allocated against this SalesOrder
self.assertEqual(self.order.stock_allocations.count(), 0) self.assertEqual(self.order.stock_allocations.count(), 0)
@ -1706,7 +1706,7 @@ class SalesOrderAllocateTest(OrderTest):
self.assertEqual(self.shipment.delivery_date, datetime(2023, 5, 15).date()) self.assertEqual(self.shipment.delivery_date, datetime(2023, 5, 15).date())
def test_sales_order_shipment_list(self): def test_sales_order_shipment_list(self):
"""Test the SalesOrderShipment list API endpoint""" """Test the SalesOrderShipment list API endpoint."""
url = reverse('api-so-shipment-list') url = reverse('api-so-shipment-list')
# Count before creation # Count before creation
@ -1740,7 +1740,7 @@ class SalesOrderAllocateTest(OrderTest):
class ReturnOrderTests(InvenTreeAPITestCase): class ReturnOrderTests(InvenTreeAPITestCase):
"""Unit tests for ReturnOrder API endpoints""" """Unit tests for ReturnOrder API endpoints."""
fixtures = [ fixtures = [
'category', 'category',
@ -1753,7 +1753,7 @@ class ReturnOrderTests(InvenTreeAPITestCase):
] ]
def test_options(self): def test_options(self):
"""Test the OPTIONS endpoint""" """Test the OPTIONS endpoint."""
self.assignRole('return_order.add') self.assignRole('return_order.add')
data = self.options(reverse('api-return-order-list'), expected_code=200).data data = self.options(reverse('api-return-order-list'), expected_code=200).data
@ -1770,7 +1770,7 @@ class ReturnOrderTests(InvenTreeAPITestCase):
self.assertEqual(reference['type'], 'string') self.assertEqual(reference['type'], 'string')
def test_list(self): def test_list(self):
"""Tests for the list endpoint""" """Tests for the list endpoint."""
url = reverse('api-return-order-list') url = reverse('api-return-order-list')
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
@ -1813,7 +1813,7 @@ class ReturnOrderTests(InvenTreeAPITestCase):
self.assertEqual(result['status'], 20) self.assertEqual(result['status'], 20)
def test_create(self): def test_create(self):
"""Test creation of ReturnOrder via the API""" """Test creation of ReturnOrder via the API."""
url = reverse('api-return-order-list') url = reverse('api-return-order-list')
# Do not have required permissions yet # Do not have required permissions yet
@ -1838,7 +1838,7 @@ class ReturnOrderTests(InvenTreeAPITestCase):
self.assertEqual(data['customer_reference'], 'cr') self.assertEqual(data['customer_reference'], 'cr')
def test_update(self): def test_update(self):
"""Test that we can update a ReturnOrder via the API""" """Test that we can update a ReturnOrder via the API."""
url = reverse('api-return-order-detail', kwargs={'pk': 1}) url = reverse('api-return-order-detail', kwargs={'pk': 1})
# Test detail endpoint # Test detail endpoint
@ -1859,7 +1859,7 @@ class ReturnOrderTests(InvenTreeAPITestCase):
self.assertEqual(rma.customer_reference, 'customer ref') self.assertEqual(rma.customer_reference, 'customer ref')
def test_ro_issue(self): def test_ro_issue(self):
"""Test the 'issue' order for a ReturnOrder""" """Test the 'issue' order for a ReturnOrder."""
order = models.ReturnOrder.objects.get(pk=1) order = models.ReturnOrder.objects.get(pk=1)
self.assertEqual(order.status, ReturnOrderStatus.PENDING) self.assertEqual(order.status, ReturnOrderStatus.PENDING)
self.assertIsNone(order.issue_date) self.assertIsNone(order.issue_date)
@ -1877,7 +1877,7 @@ class ReturnOrderTests(InvenTreeAPITestCase):
self.assertIsNotNone(order.issue_date) self.assertIsNotNone(order.issue_date)
def test_receive(self): def test_receive(self):
"""Test that we can receive items against a ReturnOrder""" """Test that we can receive items against a ReturnOrder."""
customer = Company.objects.get(pk=4) customer = Company.objects.get(pk=4)
# Create an order # Create an order
@ -1905,7 +1905,7 @@ class ReturnOrderTests(InvenTreeAPITestCase):
self.assertEqual(rma.lines.count(), 3) self.assertEqual(rma.lines.count(), 3)
def receive(items, location=None, expected_code=400): def receive(items, location=None, expected_code=400):
"""Helper function to receive items against this ReturnOrder""" """Helper function to receive items against this ReturnOrder."""
url = reverse('api-return-order-receive', kwargs={'pk': rma.pk}) url = reverse('api-return-order-receive', kwargs={'pk': rma.pk})
response = self.post( response = self.post(
@ -1970,7 +1970,7 @@ class ReturnOrderTests(InvenTreeAPITestCase):
self.assertEqual(deltas['returnorder'], rma.pk) self.assertEqual(deltas['returnorder'], rma.pk)
def test_ro_calendar(self): def test_ro_calendar(self):
"""Test the calendar export endpoint""" """Test the calendar export endpoint."""
# Full test is in test_po_calendar. Since these use the same backend, test only # Full test is in test_po_calendar. Since these use the same backend, test only
# that the endpoint is available # that the endpoint is available
url = reverse('api-po-so-calendar', kwargs={'ordertype': 'return-order'}) url = reverse('api-po-so-calendar', kwargs={'ordertype': 'return-order'})
@ -1999,7 +1999,7 @@ class OrderMetadataAPITest(InvenTreeAPITestCase):
roles = ['purchase_order.change', 'sales_order.change', 'return_order.change'] roles = ['purchase_order.change', 'sales_order.change', 'return_order.change']
def metatester(self, apikey, model): def metatester(self, apikey, model):
"""Generic tester""" """Generic tester."""
modeldata = model.objects.first() modeldata = model.objects.first()
# Useless test unless a model object is found # Useless test unless a model object is found
@ -2025,7 +2025,7 @@ class OrderMetadataAPITest(InvenTreeAPITestCase):
) )
def test_metadata(self): def test_metadata(self):
"""Test all endpoints""" """Test all endpoints."""
for apikey, model in { for apikey, model in {
'api-po-metadata': models.PurchaseOrder, 'api-po-metadata': models.PurchaseOrder,
'api-po-line-metadata': models.PurchaseOrderLineItem, 'api-po-line-metadata': models.PurchaseOrderLineItem,

View File

@ -1,4 +1,4 @@
"""Unit tests for the SalesOrder models""" """Unit tests for the SalesOrder models."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -30,7 +30,7 @@ class SalesOrderTest(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Initial setup for this set of unit tests""" """Initial setup for this set of unit tests."""
# Create a Company to ship the goods to # Create a Company to ship the goods to
cls.customer = Company.objects.create( cls.customer = Company.objects.create(
name='ABC Co', description='My customer', is_customer=True name='ABC Co', description='My customer', is_customer=True
@ -76,14 +76,14 @@ class SalesOrderTest(TestCase):
) )
def test_so_reference(self): def test_so_reference(self):
"""Unit tests for sales order generation""" """Unit tests for sales order generation."""
# Test that a good reference is created when we have no existing orders # Test that a good reference is created when we have no existing orders
SalesOrder.objects.all().delete() SalesOrder.objects.all().delete()
self.assertEqual(SalesOrder.generate_reference(), 'SO-0001') self.assertEqual(SalesOrder.generate_reference(), 'SO-0001')
def test_rebuild_reference(self): def test_rebuild_reference(self):
"""Test that the 'reference_int' field gets rebuilt when the model is saved""" """Test that the 'reference_int' field gets rebuilt when the model is saved."""
self.assertEqual(self.order.reference_int, 1234) self.assertEqual(self.order.reference_int, 1234)
self.order.reference = '999' self.order.reference = '999'
@ -112,7 +112,7 @@ class SalesOrderTest(TestCase):
self.assertFalse(self.order.is_overdue) self.assertFalse(self.order.is_overdue)
def test_empty_order(self): def test_empty_order(self):
"""Test for an empty order""" """Test for an empty order."""
self.assertEqual(self.line.quantity, 50) self.assertEqual(self.line.quantity, 50)
self.assertEqual(self.line.allocated_quantity(), 0) self.assertEqual(self.line.allocated_quantity(), 0)
self.assertEqual(self.line.fulfilled_quantity(), 0) self.assertEqual(self.line.fulfilled_quantity(), 0)
@ -123,14 +123,14 @@ class SalesOrderTest(TestCase):
self.assertFalse(self.order.is_fully_allocated()) self.assertFalse(self.order.is_fully_allocated())
def test_add_duplicate_line_item(self): def test_add_duplicate_line_item(self):
"""Adding a duplicate line item to a SalesOrder is accepted""" """Adding a duplicate line item to a SalesOrder is accepted."""
for ii in range(1, 5): for ii in range(1, 5):
SalesOrderLineItem.objects.create( SalesOrderLineItem.objects.create(
order=self.order, part=self.part, quantity=ii order=self.order, part=self.part, quantity=ii
) )
def allocate_stock(self, full=True): def allocate_stock(self, full=True):
"""Allocate stock to the order""" """Allocate stock to the order."""
SalesOrderAllocation.objects.create( SalesOrderAllocation.objects.create(
line=self.line, line=self.line,
shipment=self.shipment, shipment=self.shipment,
@ -146,8 +146,7 @@ class SalesOrderTest(TestCase):
) )
def test_over_allocate(self): def test_over_allocate(self):
"""Test that over allocation logic works""" """Test that over allocation logic works."""
SA = StockItem.objects.create(part=self.part, quantity=9) SA = StockItem.objects.create(part=self.part, quantity=9)
# First three allocations should succeed # First three allocations should succeed
@ -171,7 +170,7 @@ class SalesOrderTest(TestCase):
allocation.clean() allocation.clean()
def test_allocate_partial(self): def test_allocate_partial(self):
"""Partially allocate stock""" """Partially allocate stock."""
self.allocate_stock(False) self.allocate_stock(False)
self.assertFalse(self.order.is_fully_allocated()) self.assertFalse(self.order.is_fully_allocated())
@ -180,7 +179,7 @@ class SalesOrderTest(TestCase):
self.assertEqual(self.line.fulfilled_quantity(), 0) self.assertEqual(self.line.fulfilled_quantity(), 0)
def test_allocate_full(self): def test_allocate_full(self):
"""Fully allocate stock""" """Fully allocate stock."""
self.allocate_stock(True) self.allocate_stock(True)
self.assertTrue(self.order.is_fully_allocated()) self.assertTrue(self.order.is_fully_allocated())
@ -188,7 +187,7 @@ class SalesOrderTest(TestCase):
self.assertEqual(self.line.allocated_quantity(), 50) self.assertEqual(self.line.allocated_quantity(), 50)
def test_allocate_variant(self): def test_allocate_variant(self):
"""Allocate a variant of the designated item""" """Allocate a variant of the designated item."""
SalesOrderAllocation.objects.create( SalesOrderAllocation.objects.create(
line=self.line, line=self.line,
shipment=self.shipment, shipment=self.shipment,
@ -198,7 +197,7 @@ class SalesOrderTest(TestCase):
self.assertEqual(self.line.allocated_quantity(), 50) self.assertEqual(self.line.allocated_quantity(), 50)
def test_order_cancel(self): def test_order_cancel(self):
"""Allocate line items then cancel the order""" """Allocate line items then cancel the order."""
self.allocate_stock(True) self.allocate_stock(True)
self.assertEqual(SalesOrderAllocation.objects.count(), 2) self.assertEqual(SalesOrderAllocation.objects.count(), 2)
@ -216,7 +215,7 @@ class SalesOrderTest(TestCase):
self.assertFalse(result) self.assertFalse(result)
def test_complete_order(self): def test_complete_order(self):
"""Allocate line items, then ship the order""" """Allocate line items, then ship the order."""
# Assert some stuff before we run the test # Assert some stuff before we run the test
# Initially there are three stock items # Initially there are three stock items
self.assertEqual(StockItem.objects.count(), 3) self.assertEqual(StockItem.objects.count(), 3)
@ -279,7 +278,7 @@ class SalesOrderTest(TestCase):
self.assertEqual(self.line.allocated_quantity(), 50) self.assertEqual(self.line.allocated_quantity(), 50)
def test_default_shipment(self): def test_default_shipment(self):
"""Test sales order default shipment creation""" """Test sales order default shipment creation."""
# Default setting value should be False # Default setting value should be False
self.assertEqual( self.assertEqual(
False, InvenTreeSetting.get_setting('SALESORDER_DEFAULT_SHIPMENT') False, InvenTreeSetting.get_setting('SALESORDER_DEFAULT_SHIPMENT')
@ -312,13 +311,13 @@ class SalesOrderTest(TestCase):
self.assertEqual('1', order_2.pending_shipments()[0].reference) self.assertEqual('1', order_2.pending_shipments()[0].reference)
def test_shipment_delivery(self): def test_shipment_delivery(self):
"""Test the shipment delivery settings""" """Test the shipment delivery settings."""
# Shipment delivery date should be empty before setting date # Shipment delivery date should be empty before setting date
self.assertIsNone(self.shipment.delivery_date) self.assertIsNone(self.shipment.delivery_date)
self.assertFalse(self.shipment.is_delivered()) self.assertFalse(self.shipment.is_delivered())
def test_overdue_notification(self): def test_overdue_notification(self):
"""Test overdue sales order notification""" """Test overdue sales order notification."""
self.order.created_by = get_user_model().objects.get(pk=3) self.order.created_by = get_user_model().objects.get(pk=3)
self.order.responsible = Owner.create(obj=Group.objects.get(pk=2)) self.order.responsible = Owner.create(obj=Group.objects.get(pk=2))
self.order.target_date = datetime.now().date() - timedelta(days=1) self.order.target_date = datetime.now().date() - timedelta(days=1)

View File

@ -1,4 +1,4 @@
"""Unit tests for Order views (see views.py)""" """Unit tests for Order views (see views.py)."""
from django.urls import reverse from django.urls import reverse
@ -6,7 +6,7 @@ from InvenTree.unit_test import InvenTreeTestCase
class OrderViewTestCase(InvenTreeTestCase): class OrderViewTestCase(InvenTreeTestCase):
"""Base unit test class for order views""" """Base unit test class for order views."""
fixtures = [ fixtures = [
'category', 'category',
@ -35,10 +35,10 @@ class OrderViewTestCase(InvenTreeTestCase):
class PurchaseOrderListTest(OrderViewTestCase): class PurchaseOrderListTest(OrderViewTestCase):
"""Unit tests for the PurchaseOrder index page""" """Unit tests for the PurchaseOrder index page."""
def test_order_list(self): def test_order_list(self):
"""Tests for the PurchaseOrder index page""" """Tests for the PurchaseOrder index page."""
response = self.client.get(reverse('purchase-order-index')) response = self.client.get(reverse('purchase-order-index'))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@ -65,28 +65,28 @@ class PurchaseOrderTests(OrderViewTestCase):
class SalesOrderViews(OrderViewTestCase): class SalesOrderViews(OrderViewTestCase):
"""Unit tests for the SalesOrder pages""" """Unit tests for the SalesOrder pages."""
def test_index(self): def test_index(self):
"""Test the SalesOrder index page""" """Test the SalesOrder index page."""
response = self.client.get(reverse('sales-order-index')) response = self.client.get(reverse('sales-order-index'))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_detail(self): def test_detail(self):
"""Test SalesOrder detail view""" """Test SalesOrder detail view."""
response = self.client.get(reverse('so-detail', args=(1,))) response = self.client.get(reverse('so-detail', args=(1,)))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
class ReturnOrderVIews(OrderViewTestCase): class ReturnOrderVIews(OrderViewTestCase):
"""Unit tests for the ReturnOrder pages""" """Unit tests for the ReturnOrder pages."""
def test_index(self): def test_index(self):
"""Test the ReturnOrder index page""" """Test the ReturnOrder index page."""
response = self.client.get(reverse('return-order-index')) response = self.client.get(reverse('return-order-index'))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_detail(self): def test_detail(self):
"""Test ReturnOrder detail view""" """Test ReturnOrder detail view."""
response = self.client.get(reverse('return-order-detail', args=(1,))) response = self.client.get(reverse('return-order-detail', args=(1,)))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)

View File

@ -1,4 +1,4 @@
"""Various unit tests for order models""" """Various unit tests for order models."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from decimal import Decimal from decimal import Decimal
@ -49,7 +49,7 @@ class OrderTest(TestCase):
self.assertEqual(str(line), '100 x ACME0001 from ACME (for PO-0001 - ACME)') self.assertEqual(str(line), '100 x ACME0001 from ACME (for PO-0001 - ACME)')
def test_rebuild_reference(self): def test_rebuild_reference(self):
"""Test that the reference_int field is correctly updated when the model is saved""" """Test that the reference_int field is correctly updated when the model is saved."""
order = PurchaseOrder.objects.get(pk=1) order = PurchaseOrder.objects.get(pk=1)
order.save() order.save()
self.assertEqual(order.reference_int, 1) self.assertEqual(order.reference_int, 1)
@ -214,7 +214,7 @@ class OrderTest(TestCase):
self.assertEqual(order.status, PurchaseOrderStatus.COMPLETE) self.assertEqual(order.status, PurchaseOrderStatus.COMPLETE)
def test_receive_pack_size(self): def test_receive_pack_size(self):
"""Test receiving orders from suppliers with different pack_size values""" """Test receiving orders from suppliers with different pack_size values."""
prt = Part.objects.get(pk=1) prt = Part.objects.get(pk=1)
sup = Company.objects.get(pk=1) sup = Company.objects.get(pk=1)
@ -305,7 +305,7 @@ class OrderTest(TestCase):
self.assertEqual(si.purchase_price, Money(100, 'USD')) self.assertEqual(si.purchase_price, Money(100, 'USD'))
def test_overdue_notification(self): def test_overdue_notification(self):
"""Test overdue purchase order notification """Test overdue purchase order notification.
Ensure that a notification is sent when a PurchaseOrder becomes overdue Ensure that a notification is sent when a PurchaseOrder becomes overdue
""" """
@ -343,7 +343,7 @@ class OrderTest(TestCase):
self.assertEqual(msg.name, 'Overdue Purchase Order') self.assertEqual(msg.name, 'Overdue Purchase Order')
def test_new_po_notification(self): def test_new_po_notification(self):
"""Test that a notification is sent when a new PurchaseOrder is created """Test that a notification is sent when a new PurchaseOrder is created.
- The responsible user(s) should receive a notification - The responsible user(s) should receive a notification
- The creating user should *not* receive a notification - The creating user should *not* receive a notification

View File

@ -1,5 +1,6 @@
"""URL lookup for the Order app. Provides URL endpoints for: """URL lookup for the Order app.
Provides URL endpoints for:
- List view of Purchase Orders - List view of Purchase Orders
- Detail view of Purchase Orders - Detail view of Purchase Orders
""" """

View File

@ -1,64 +1,64 @@
"""Validation methods for the order app""" """Validation methods for the order app."""
def generate_next_sales_order_reference(): def generate_next_sales_order_reference():
"""Generate the next available SalesOrder reference""" """Generate the next available SalesOrder reference."""
from order.models import SalesOrder from order.models import SalesOrder
return SalesOrder.generate_reference() return SalesOrder.generate_reference()
def generate_next_purchase_order_reference(): def generate_next_purchase_order_reference():
"""Generate the next available PurchasesOrder reference""" """Generate the next available PurchasesOrder reference."""
from order.models import PurchaseOrder from order.models import PurchaseOrder
return PurchaseOrder.generate_reference() return PurchaseOrder.generate_reference()
def generate_next_return_order_reference(): def generate_next_return_order_reference():
"""Generate the next available ReturnOrder reference""" """Generate the next available ReturnOrder reference."""
from order.models import ReturnOrder from order.models import ReturnOrder
return ReturnOrder.generate_reference() return ReturnOrder.generate_reference()
def validate_sales_order_reference_pattern(pattern): def validate_sales_order_reference_pattern(pattern):
"""Validate the SalesOrder reference 'pattern' setting""" """Validate the SalesOrder reference 'pattern' setting."""
from order.models import SalesOrder from order.models import SalesOrder
SalesOrder.validate_reference_pattern(pattern) SalesOrder.validate_reference_pattern(pattern)
def validate_purchase_order_reference_pattern(pattern): def validate_purchase_order_reference_pattern(pattern):
"""Validate the PurchaseOrder reference 'pattern' setting""" """Validate the PurchaseOrder reference 'pattern' setting."""
from order.models import PurchaseOrder from order.models import PurchaseOrder
PurchaseOrder.validate_reference_pattern(pattern) PurchaseOrder.validate_reference_pattern(pattern)
def validate_return_order_reference_pattern(pattern): def validate_return_order_reference_pattern(pattern):
"""Validate the ReturnOrder reference 'pattern' setting""" """Validate the ReturnOrder reference 'pattern' setting."""
from order.models import ReturnOrder from order.models import ReturnOrder
ReturnOrder.validate_reference_pattern(pattern) ReturnOrder.validate_reference_pattern(pattern)
def validate_sales_order_reference(value): def validate_sales_order_reference(value):
"""Validate that the SalesOrder reference field matches the required pattern""" """Validate that the SalesOrder reference field matches the required pattern."""
from order.models import SalesOrder from order.models import SalesOrder
SalesOrder.validate_reference_field(value) SalesOrder.validate_reference_field(value)
def validate_purchase_order_reference(value): def validate_purchase_order_reference(value):
"""Validate that the PurchaseOrder reference field matches the required pattern""" """Validate that the PurchaseOrder reference field matches the required pattern."""
from order.models import PurchaseOrder from order.models import PurchaseOrder
PurchaseOrder.validate_reference_field(value) PurchaseOrder.validate_reference_field(value)
def validate_return_order_reference(value): def validate_return_order_reference(value):
"""Validate that the ReturnOrder reference field matches the required pattern""" """Validate that the ReturnOrder reference field matches the required pattern."""
from order.models import ReturnOrder from order.models import ReturnOrder
ReturnOrder.validate_reference_field(value) ReturnOrder.validate_reference_field(value)

View File

@ -50,7 +50,7 @@ class PurchaseOrderIndex(InvenTreeRoleMixin, ListView):
class SalesOrderIndex(InvenTreeRoleMixin, ListView): class SalesOrderIndex(InvenTreeRoleMixin, ListView):
"""SalesOrder index (list) view class""" """SalesOrder index (list) view class."""
model = SalesOrder model = SalesOrder
template_name = 'order/sales_orders.html' template_name = 'order/sales_orders.html'
@ -58,7 +58,7 @@ class SalesOrderIndex(InvenTreeRoleMixin, ListView):
class ReturnOrderIndex(InvenTreeRoleMixin, ListView): class ReturnOrderIndex(InvenTreeRoleMixin, ListView):
"""ReturnOrder index (list) view""" """ReturnOrder index (list) view."""
model = ReturnOrder model = ReturnOrder
template_name = 'order/return_orders.html' template_name = 'order/return_orders.html'
@ -84,7 +84,7 @@ class SalesOrderDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView)
class ReturnOrderDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): class ReturnOrderDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView):
"""Detail view for a ReturnOrder object""" """Detail view for a ReturnOrder object."""
context_object_name = 'order' context_object_name = 'order'
queryset = ReturnOrder.objects.all() queryset = ReturnOrder.objects.all()
@ -92,10 +92,10 @@ class ReturnOrderDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView
class PurchaseOrderUpload(FileManagementFormView): class PurchaseOrderUpload(FileManagementFormView):
"""PurchaseOrder: Upload file, match to fields and parts (using multi-Step form)""" """PurchaseOrder: Upload file, match to fields and parts (using multi-Step form)."""
class OrderFileManager(FileManager): class OrderFileManager(FileManager):
"""Specify required fields""" """Specify required fields."""
REQUIRED_HEADERS = ['Quantity'] REQUIRED_HEADERS = ['Quantity']
@ -289,7 +289,7 @@ class SalesOrderExport(AjaxView):
role_required = 'sales_order.view' role_required = 'sales_order.view'
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Perform GET request to export SalesOrder dataset""" """Perform GET request to export SalesOrder dataset."""
order = get_object_or_404(SalesOrder, pk=self.kwargs.get('pk', None)) order = get_object_or_404(SalesOrder, pk=self.kwargs.get('pk', None))
export_format = request.GET.get('format', 'csv') export_format = request.GET.get('format', 'csv')
@ -316,7 +316,7 @@ class PurchaseOrderExport(AjaxView):
role_required = 'purchase_order.view' role_required = 'purchase_order.view'
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Perform GET request to export PurchaseOrder dataset""" """Perform GET request to export PurchaseOrder dataset."""
order = get_object_or_404(PurchaseOrder, pk=self.kwargs.get('pk', None)) order = get_object_or_404(PurchaseOrder, pk=self.kwargs.get('pk', None))
export_format = request.GET.get('format', 'csv') export_format = request.GET.get('format', 'csv')
@ -334,7 +334,7 @@ class LineItemPricing(PartPricing):
"""View for inspecting part pricing information.""" """View for inspecting part pricing information."""
class EnhancedForm(PartPricing.form_class): class EnhancedForm(PartPricing.form_class):
"""Extra form options""" """Extra form options."""
pk = IntegerField(widget=HiddenInput()) pk = IntegerField(widget=HiddenInput())
so_line = IntegerField(widget=HiddenInput()) so_line = IntegerField(widget=HiddenInput())
@ -342,7 +342,7 @@ class LineItemPricing(PartPricing):
form_class = EnhancedForm form_class = EnhancedForm
def get_part(self, id=False): def get_part(self, id=False):
"""Return the Part instance associated with this view""" """Return the Part instance associated with this view."""
if 'line_item' in self.request.GET: if 'line_item' in self.request.GET:
try: try:
part_id = self.request.GET.get('line_item') part_id = self.request.GET.get('line_item')
@ -364,7 +364,7 @@ class LineItemPricing(PartPricing):
return part return part
def get_so(self, pk=False): def get_so(self, pk=False):
"""Return the SalesOrderLineItem associated with this view""" """Return the SalesOrderLineItem associated with this view."""
so_line = self.request.GET.get('line_item', None) so_line = self.request.GET.get('line_item', None)
if not so_line: if not so_line:
so_line = self.request.POST.get('so_line', None) so_line = self.request.POST.get('so_line', None)
@ -387,14 +387,14 @@ class LineItemPricing(PartPricing):
return qty return qty
def get_initials(self): def get_initials(self):
"""Return initial context values for this view""" """Return initial context values for this view."""
initials = super().get_initials() initials = super().get_initials()
initials['pk'] = self.get_part(id=True) initials['pk'] = self.get_part(id=True)
initials['so_line'] = self.get_so(pk=True) initials['so_line'] = self.get_so(pk=True)
return initials return initials
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
"""Respond to a POST request to get particular pricing information""" """Respond to a POST request to get particular pricing information."""
REF = 'act-btn_' REF = 'act-btn_'
act_btn = [a.replace(REF, '') for a in self.request.POST if REF in a] act_btn = [a.replace(REF, '') for a in self.request.POST if REF in a]

View File

@ -1,4 +1,4 @@
"""Admin class definitions for the 'part' app""" """Admin class definitions for the 'part' app."""
from django.contrib import admin from django.contrib import admin
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -17,7 +17,7 @@ class PartResource(InvenTreeResource):
"""Class for managing Part data import/export.""" """Class for managing Part data import/export."""
class Meta: class Meta:
"""Metaclass definition""" """Metaclass options."""
model = models.Part model = models.Part
skip_unchanged = True skip_unchanged = True
@ -159,14 +159,14 @@ class PartResource(InvenTreeResource):
) )
def dehydrate_min_cost(self, part): def dehydrate_min_cost(self, part):
"""Render minimum cost value for this Part""" """Render minimum cost value for this Part."""
min_cost = part.pricing.overall_min if part.pricing else None min_cost = part.pricing.overall_min if part.pricing else None
if min_cost is not None: if min_cost is not None:
return float(min_cost.amount) return float(min_cost.amount)
def dehydrate_max_cost(self, part): def dehydrate_max_cost(self, part):
"""Render maximum cost value for this Part""" """Render maximum cost value for this Part."""
max_cost = part.pricing.overall_max if part.pricing else None max_cost = part.pricing.overall_max if part.pricing else None
if max_cost is not None: if max_cost is not None:
@ -186,7 +186,7 @@ class PartResource(InvenTreeResource):
return query return query
def after_import(self, dataset, result, using_transactions, dry_run, **kwargs): def after_import(self, dataset, result, using_transactions, dry_run, **kwargs):
"""Rebuild MPTT tree structure after importing Part data""" """Rebuild MPTT tree structure after importing Part data."""
super().after_import(dataset, result, using_transactions, dry_run, **kwargs) super().after_import(dataset, result, using_transactions, dry_run, **kwargs)
# Rebuild the Part tree(s) # Rebuild the Part tree(s)
@ -197,7 +197,7 @@ class PartImportResource(InvenTreeResource):
"""Class for managing Part data import/export.""" """Class for managing Part data import/export."""
class Meta(PartResource.Meta): class Meta(PartResource.Meta):
"""Metaclass definition""" """Metaclass options."""
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -223,13 +223,13 @@ class PartImportResource(InvenTreeResource):
class PartParameterInline(admin.TabularInline): class PartParameterInline(admin.TabularInline):
"""Inline for part parameter data""" """Inline for part parameter data."""
model = models.PartParameter model = models.PartParameter
class PartAdmin(ImportExportModelAdmin): class PartAdmin(ImportExportModelAdmin):
"""Admin class for the Part model""" """Admin class for the Part model."""
resource_class = PartResource resource_class = PartResource
@ -256,7 +256,7 @@ class PartAdmin(ImportExportModelAdmin):
class PartPricingAdmin(admin.ModelAdmin): class PartPricingAdmin(admin.ModelAdmin):
"""Admin class for PartPricing model""" """Admin class for PartPricing model."""
list_display = ('part', 'overall_min', 'overall_max') list_display = ('part', 'overall_min', 'overall_max')
@ -264,13 +264,13 @@ class PartPricingAdmin(admin.ModelAdmin):
class PartStocktakeAdmin(admin.ModelAdmin): class PartStocktakeAdmin(admin.ModelAdmin):
"""Admin class for PartStocktake model""" """Admin class for PartStocktake model."""
list_display = ['part', 'date', 'quantity', 'user'] list_display = ['part', 'date', 'quantity', 'user']
class PartStocktakeReportAdmin(admin.ModelAdmin): class PartStocktakeReportAdmin(admin.ModelAdmin):
"""Admin class for PartStocktakeReport model""" """Admin class for PartStocktakeReport model."""
list_display = ['date', 'user'] list_display = ['date', 'user']
@ -279,7 +279,7 @@ class PartCategoryResource(InvenTreeResource):
"""Class for managing PartCategory data import/export.""" """Class for managing PartCategory data import/export."""
class Meta: class Meta:
"""Metaclass definition""" """Metaclass options."""
model = models.PartCategory model = models.PartCategory
skip_unchanged = True skip_unchanged = True
@ -326,7 +326,7 @@ class PartCategoryResource(InvenTreeResource):
) )
def after_import(self, dataset, result, using_transactions, dry_run, **kwargs): def after_import(self, dataset, result, using_transactions, dry_run, **kwargs):
"""Rebuild MPTT tree structure after importing PartCategory data""" """Rebuild MPTT tree structure after importing PartCategory data."""
super().after_import(dataset, result, using_transactions, dry_run, **kwargs) super().after_import(dataset, result, using_transactions, dry_run, **kwargs)
# Rebuild the PartCategory tree(s) # Rebuild the PartCategory tree(s)
@ -334,7 +334,7 @@ class PartCategoryResource(InvenTreeResource):
class PartCategoryAdmin(ImportExportModelAdmin): class PartCategoryAdmin(ImportExportModelAdmin):
"""Admin class for the PartCategory model""" """Admin class for the PartCategory model."""
resource_class = PartCategoryResource resource_class = PartCategoryResource
@ -352,7 +352,7 @@ class PartRelatedAdmin(admin.ModelAdmin):
class PartAttachmentAdmin(admin.ModelAdmin): class PartAttachmentAdmin(admin.ModelAdmin):
"""Admin class for the PartAttachment model""" """Admin class for the PartAttachment model."""
list_display = ('part', 'attachment', 'comment') list_display = ('part', 'attachment', 'comment')
@ -360,7 +360,7 @@ class PartAttachmentAdmin(admin.ModelAdmin):
class PartTestTemplateAdmin(admin.ModelAdmin): class PartTestTemplateAdmin(admin.ModelAdmin):
"""Admin class for the PartTestTemplate model""" """Admin class for the PartTestTemplate model."""
list_display = ('part', 'test_name', 'required') list_display = ('part', 'test_name', 'required')
@ -371,7 +371,7 @@ class BomItemResource(InvenTreeResource):
"""Class for managing BomItem data import/export.""" """Class for managing BomItem data import/export."""
class Meta: class Meta:
"""Metaclass definition""" """Metaclass options."""
model = models.BomItem model = models.BomItem
skip_unchanged = True skip_unchanged = True
@ -431,28 +431,28 @@ class BomItemResource(InvenTreeResource):
) )
def dehydrate_min_cost(self, item): def dehydrate_min_cost(self, item):
"""Render minimum cost value for the BOM line item""" """Render minimum cost value for the BOM line item."""
min_price = item.sub_part.pricing.overall_min if item.sub_part.pricing else None min_price = item.sub_part.pricing.overall_min if item.sub_part.pricing else None
if min_price is not None: if min_price is not None:
return float(min_price.amount) * float(item.quantity) return float(min_price.amount) * float(item.quantity)
def dehydrate_max_cost(self, item): def dehydrate_max_cost(self, item):
"""Render maximum cost value for the BOM line item""" """Render maximum cost value for the BOM line item."""
max_price = item.sub_part.pricing.overall_max if item.sub_part.pricing else None max_price = item.sub_part.pricing.overall_max if item.sub_part.pricing else None
if max_price is not None: if max_price is not None:
return float(max_price.amount) * float(item.quantity) return float(max_price.amount) * float(item.quantity)
def dehydrate_quantity(self, item): def dehydrate_quantity(self, item):
"""Special consideration for the 'quantity' field on data export. We do not want a spreadsheet full of "1.0000" (we'd rather "1") """Special consideration for the 'quantity' field on data export. We do not want a spreadsheet full of "1.0000" (we'd rather "1").
Ref: https://django-import-export.readthedocs.io/en/latest/getting_started.html#advanced-data-manipulation-on-export Ref: https://django-import-export.readthedocs.io/en/latest/getting_started.html#advanced-data-manipulation-on-export
""" """
return float(item.quantity) return float(item.quantity)
def before_export(self, queryset, *args, **kwargs): def before_export(self, queryset, *args, **kwargs):
"""Perform before exporting data""" """Perform before exporting data."""
self.is_importing = kwargs.get('importing', False) self.is_importing = kwargs.get('importing', False)
self.include_pricing = kwargs.pop('include_pricing', False) self.include_pricing = kwargs.pop('include_pricing', False)
@ -496,7 +496,7 @@ class BomItemResource(InvenTreeResource):
class BomItemAdmin(ImportExportModelAdmin): class BomItemAdmin(ImportExportModelAdmin):
"""Admin class for the BomItem model""" """Admin class for the BomItem model."""
resource_class = BomItemResource resource_class = BomItemResource
@ -513,13 +513,13 @@ class BomItemAdmin(ImportExportModelAdmin):
class ParameterTemplateResource(InvenTreeResource): class ParameterTemplateResource(InvenTreeResource):
"""Class for managing ParameterTemplate import/export""" """Class for managing ParameterTemplate import/export."""
# The following fields will be converted from None to '' # The following fields will be converted from None to ''
CONVERT_NULL_FIELDS = ['choices', 'units'] CONVERT_NULL_FIELDS = ['choices', 'units']
class Meta: class Meta:
"""Metaclass definition""" """Metaclass options."""
model = models.PartParameterTemplate model = models.PartParameterTemplate
skip_unchanged = True skip_unchanged = True
@ -530,7 +530,7 @@ class ParameterTemplateResource(InvenTreeResource):
class ParameterTemplateAdmin(ImportExportModelAdmin): class ParameterTemplateAdmin(ImportExportModelAdmin):
"""Admin class for the PartParameterTemplate model""" """Admin class for the PartParameterTemplate model."""
resource_class = ParameterTemplateResource resource_class = ParameterTemplateResource
@ -543,7 +543,7 @@ class ParameterResource(InvenTreeResource):
"""Class for managing PartParameter data import/export.""" """Class for managing PartParameter data import/export."""
class Meta: class Meta:
"""Metaclass definition""" """Metaclass options."""
model = models.PartParameter model = models.PartParameter
skip_unchanged = True skip_unchanged = True
@ -563,7 +563,7 @@ class ParameterResource(InvenTreeResource):
class ParameterAdmin(ImportExportModelAdmin): class ParameterAdmin(ImportExportModelAdmin):
"""Admin class for the PartParameter model""" """Admin class for the PartParameter model."""
resource_class = ParameterResource resource_class = ParameterResource
@ -573,16 +573,16 @@ class ParameterAdmin(ImportExportModelAdmin):
class PartCategoryParameterAdmin(admin.ModelAdmin): class PartCategoryParameterAdmin(admin.ModelAdmin):
"""Admin class for the PartCategoryParameterTemplate model""" """Admin class for the PartCategoryParameterTemplate model."""
autocomplete_fields = ('category', 'parameter_template') autocomplete_fields = ('category', 'parameter_template')
class PartSellPriceBreakAdmin(admin.ModelAdmin): class PartSellPriceBreakAdmin(admin.ModelAdmin):
"""Admin class for the PartSellPriceBreak model""" """Admin class for the PartSellPriceBreak model."""
class Meta: class Meta:
"""Metaclass definition""" """Metaclass options."""
model = models.PartSellPriceBreak model = models.PartSellPriceBreak
@ -590,10 +590,10 @@ class PartSellPriceBreakAdmin(admin.ModelAdmin):
class PartInternalPriceBreakAdmin(admin.ModelAdmin): class PartInternalPriceBreakAdmin(admin.ModelAdmin):
"""Admin class for the PartInternalPriceBreak model""" """Admin class for the PartInternalPriceBreak model."""
class Meta: class Meta:
"""Metaclass definition""" """Metaclass options."""
model = models.PartInternalPriceBreak model = models.PartInternalPriceBreak

View File

@ -77,19 +77,19 @@ from .models import (
class CategoryMixin: class CategoryMixin:
"""Mixin class for PartCategory endpoints""" """Mixin class for PartCategory endpoints."""
serializer_class = part_serializers.CategorySerializer serializer_class = part_serializers.CategorySerializer
queryset = PartCategory.objects.all() queryset = PartCategory.objects.all()
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return an annotated queryset for the CategoryDetail endpoint""" """Return an annotated queryset for the CategoryDetail endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = part_serializers.CategorySerializer.annotate_queryset(queryset) queryset = part_serializers.CategorySerializer.annotate_queryset(queryset)
return queryset return queryset
def get_serializer_context(self): def get_serializer_context(self):
"""Add extra context to the serializer for the CategoryDetail endpoint""" """Add extra context to the serializer for the CategoryDetail endpoint."""
ctx = super().get_serializer_context() ctx = super().get_serializer_context()
try: try:
@ -111,7 +111,7 @@ class CategoryList(CategoryMixin, APIDownloadMixin, ListCreateAPI):
""" """
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download the filtered queryset as a data file""" """Download the filtered queryset as a data file."""
dataset = PartCategoryResource().export(queryset=queryset) dataset = PartCategoryResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
filename = f'InvenTree_Categories.{export_format}' filename = f'InvenTree_Categories.{export_format}'
@ -119,8 +119,9 @@ class CategoryList(CategoryMixin, APIDownloadMixin, ListCreateAPI):
return DownloadFile(filedata, filename) return DownloadFile(filedata, filename)
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
"""Custom filtering: """Custom filtering.
Rules:
- Allow filtering by "null" parent to retrieve top-level part categories - Allow filtering by "null" parent to retrieve top-level part categories
""" """
queryset = super().filter_queryset(queryset) queryset = super().filter_queryset(queryset)
@ -208,7 +209,7 @@ class CategoryDetail(CategoryMixin, CustomRetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single PartCategory object.""" """API endpoint for detail view of a single PartCategory object."""
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Add additional context based on query parameters""" """Add additional context based on query parameters."""
try: try:
params = self.request.query_params params = self.request.query_params
@ -219,7 +220,7 @@ class CategoryDetail(CategoryMixin, CustomRetrieveUpdateDestroyAPI):
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
"""Perform 'update' function and mark this part as 'starred' (or not)""" """Perform 'update' function and mark this part as 'starred' (or not)."""
# Clean up input data # Clean up input data
data = self.clean_data(request.data) data = self.clean_data(request.data)
@ -233,7 +234,7 @@ class CategoryDetail(CategoryMixin, CustomRetrieveUpdateDestroyAPI):
return response return response
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
"""Delete a Part category instance via the API""" """Delete a Part category instance via the API."""
delete_parts = ( delete_parts = (
'delete_parts' in request.data and request.data['delete_parts'] == '1' 'delete_parts' in request.data and request.data['delete_parts'] == '1'
) )
@ -274,8 +275,9 @@ class CategoryParameterList(ListCreateAPI):
serializer_class = part_serializers.CategoryParameterTemplateSerializer serializer_class = part_serializers.CategoryParameterTemplateSerializer
def get_queryset(self): def get_queryset(self):
"""Custom filtering: """Custom filtering.
Rules:
- Allow filtering by "null" parent to retrieve all categories parameter templates - Allow filtering by "null" parent to retrieve all categories parameter templates
- Allow filtering by category - Allow filtering by category
- Allow traversing all parent categories - Allow traversing all parent categories
@ -305,7 +307,7 @@ class CategoryParameterList(ListCreateAPI):
class CategoryParameterDetail(RetrieveUpdateDestroyAPI): class CategoryParameterDetail(RetrieveUpdateDestroyAPI):
"""Detail endpoint for the PartCategoryParameterTemplate model""" """Detail endpoint for the PartCategoryParameterTemplate model."""
queryset = PartCategoryParameterTemplate.objects.all() queryset = PartCategoryParameterTemplate.objects.all()
serializer_class = part_serializers.CategoryParameterTemplateSerializer serializer_class = part_serializers.CategoryParameterTemplateSerializer
@ -419,7 +421,7 @@ class PartThumbs(ListAPI):
serializer_class = part_serializers.PartThumbSerializer serializer_class = part_serializers.PartThumbSerializer
def get_queryset(self): def get_queryset(self):
"""Return a queryset which excludes any parts without images""" """Return a queryset which excludes any parts without images."""
queryset = super().get_queryset() queryset = super().get_queryset()
# Get all Parts which have an associated image # Get all Parts which have an associated image
@ -477,7 +479,7 @@ class PartScheduling(RetrieveAPI):
queryset = Part.objects.all() queryset = Part.objects.all()
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
"""Return scheduling information for the referenced Part instance""" """Return scheduling information for the referenced Part instance."""
part = self.get_object() part = self.get_object()
schedule = [] schedule = []
@ -485,13 +487,13 @@ class PartScheduling(RetrieveAPI):
def add_schedule_entry( def add_schedule_entry(
date, quantity, title, label, url, speculative_quantity=0 date, quantity, title, label, url, speculative_quantity=0
): ):
"""Check if a scheduled entry should be added: """Check if a scheduled entry should be added.
Rules:
- date must be non-null - date must be non-null
- date cannot be in the "past" - date cannot be in the "past"
- quantity must not be zero - quantity must not be zero
""" """
schedule.append({ schedule.append({
'date': date, 'date': date,
'quantity': quantity, 'quantity': quantity,
@ -646,7 +648,6 @@ class PartScheduling(RetrieveAPI):
Account for the fact that either date might be None Account for the fact that either date might be None
""" """
date_1 = entry_1['date'] date_1 = entry_1['date']
date_2 = entry_2['date'] date_2 = entry_2['date']
@ -678,7 +679,7 @@ class PartRequirements(RetrieveAPI):
queryset = Part.objects.all() queryset = Part.objects.all()
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
"""Construct a response detailing Part requirements""" """Construct a response detailing Part requirements."""
part = self.get_object() part = self.get_object()
data = { data = {
@ -705,18 +706,18 @@ class PartRequirements(RetrieveAPI):
class PartPricingDetail(RetrieveUpdateAPI): class PartPricingDetail(RetrieveUpdateAPI):
"""API endpoint for viewing part pricing data""" """API endpoint for viewing part pricing data."""
serializer_class = part_serializers.PartPricingSerializer serializer_class = part_serializers.PartPricingSerializer
queryset = Part.objects.all() queryset = Part.objects.all()
def get_object(self): def get_object(self):
"""Return the PartPricing object associated with the linked Part""" """Return the PartPricing object associated with the linked Part."""
part = super().get_object() part = super().get_object()
return part.pricing return part.pricing
def _get_serializer(self, *args, **kwargs): def _get_serializer(self, *args, **kwargs):
"""Return a part pricing serializer object""" """Return a part pricing serializer object."""
part = self.get_object() part = self.get_object()
kwargs['instance'] = part.pricing kwargs['instance'] = part.pricing
@ -729,7 +730,7 @@ class PartSerialNumberDetail(RetrieveAPI):
queryset = Part.objects.all() queryset = Part.objects.all()
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
"""Return serial number information for the referenced Part instance""" """Return serial number information for the referenced Part instance."""
part = self.get_object() part = self.get_object()
# Calculate the "latest" serial number # Calculate the "latest" serial number
@ -753,7 +754,7 @@ class PartCopyBOM(CreateAPI):
serializer_class = part_serializers.PartCopyBOMSerializer serializer_class = part_serializers.PartCopyBOMSerializer
def get_serializer_context(self): def get_serializer_context(self):
"""Add custom information to the serializer context for this endpoint""" """Add custom information to the serializer context for this endpoint."""
ctx = super().get_serializer_context() ctx = super().get_serializer_context()
try: try:
@ -768,10 +769,10 @@ class PartValidateBOM(RetrieveUpdateAPI):
"""API endpoint for 'validating' the BOM for a given Part.""" """API endpoint for 'validating' the BOM for a given Part."""
class BOMValidateSerializer(serializers.ModelSerializer): class BOMValidateSerializer(serializers.ModelSerializer):
"""Simple serializer class for validating a single BomItem instance""" """Simple serializer class for validating a single BomItem instance."""
class Meta: class Meta:
"""Metaclass defines serializer fields""" """Metaclass defines serializer fields."""
model = Part model = Part
fields = ['checksum', 'valid'] fields = ['checksum', 'valid']
@ -786,7 +787,7 @@ class PartValidateBOM(RetrieveUpdateAPI):
) )
def validate_valid(self, valid): def validate_valid(self, valid):
"""Check that the 'valid' input was flagged""" """Check that the 'valid' input was flagged."""
if not valid: if not valid:
raise ValidationError(_('This option must be selected')) raise ValidationError(_('This option must be selected'))
@ -795,7 +796,7 @@ class PartValidateBOM(RetrieveUpdateAPI):
serializer_class = BOMValidateSerializer serializer_class = BOMValidateSerializer
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
"""Validate the referenced BomItem instance""" """Validate the referenced BomItem instance."""
part = self.get_object() part = self.get_object()
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
@ -818,7 +819,7 @@ class PartFilter(rest_filters.FilterSet):
""" """
class Meta: class Meta:
"""Metaclass options for this filter set""" """Metaclass options for this filter set."""
model = Part model = Part
fields = [] fields = []
@ -826,7 +827,7 @@ class PartFilter(rest_filters.FilterSet):
has_units = rest_filters.BooleanFilter(label='Has units', method='filter_has_units') has_units = rest_filters.BooleanFilter(label='Has units', method='filter_has_units')
def filter_has_units(self, queryset, name, value): def filter_has_units(self, queryset, name, value):
"""Filter by whether the Part has units or not""" """Filter by whether the Part has units or not."""
if str2bool(value): if str2bool(value):
return queryset.exclude(Q(units=None) | Q(units='')) return queryset.exclude(Q(units=None) | Q(units=''))
@ -836,7 +837,7 @@ class PartFilter(rest_filters.FilterSet):
has_ipn = rest_filters.BooleanFilter(label='Has IPN', method='filter_has_ipn') has_ipn = rest_filters.BooleanFilter(label='Has IPN', method='filter_has_ipn')
def filter_has_ipn(self, queryset, name, value): def filter_has_ipn(self, queryset, name, value):
"""Filter by whether the Part has an IPN (internal part number) or not""" """Filter by whether the Part has an IPN (internal part number) or not."""
if str2bool(value): if str2bool(value):
return queryset.exclude(IPN='') return queryset.exclude(IPN='')
return queryset.filter(IPN='') return queryset.filter(IPN='')
@ -878,7 +879,7 @@ class PartFilter(rest_filters.FilterSet):
has_stock = rest_filters.BooleanFilter(label='Has stock', method='filter_has_stock') has_stock = rest_filters.BooleanFilter(label='Has stock', method='filter_has_stock')
def filter_has_stock(self, queryset, name, value): def filter_has_stock(self, queryset, name, value):
"""Filter by whether the Part has any stock""" """Filter by whether the Part has any stock."""
if str2bool(value): if str2bool(value):
return queryset.filter(Q(in_stock__gt=0)) return queryset.filter(Q(in_stock__gt=0))
return queryset.filter(Q(in_stock__lte=0)) return queryset.filter(Q(in_stock__lte=0))
@ -889,7 +890,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_unallocated_stock(self, queryset, name, value): def filter_unallocated_stock(self, queryset, name, value):
"""Filter by whether the Part has unallocated stock""" """Filter by whether the Part has unallocated stock."""
if str2bool(value): if str2bool(value):
return queryset.filter(Q(unallocated_stock__gt=0)) return queryset.filter(Q(unallocated_stock__gt=0))
return queryset.filter(Q(unallocated_stock__lte=0)) return queryset.filter(Q(unallocated_stock__lte=0))
@ -901,7 +902,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_convert_from(self, queryset, name, part): def filter_convert_from(self, queryset, name, part):
"""Limit the queryset to valid conversion options for the specified part""" """Limit the queryset to valid conversion options for the specified part."""
conversion_options = part.get_conversion_options() conversion_options = part.get_conversion_options()
queryset = queryset.filter(pk__in=conversion_options) queryset = queryset.filter(pk__in=conversion_options)
@ -915,7 +916,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_exclude_tree(self, queryset, name, part): def filter_exclude_tree(self, queryset, name, part):
"""Exclude all parts and variants 'down' from the specified part from the queryset""" """Exclude all parts and variants 'down' from the specified part from the queryset."""
children = part.get_descendants(include_self=True) children = part.get_descendants(include_self=True)
return queryset.exclude(id__in=children) return queryset.exclude(id__in=children)
@ -925,7 +926,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_ancestor(self, queryset, name, part): def filter_ancestor(self, queryset, name, part):
"""Limit queryset to descendants of the specified ancestor part""" """Limit queryset to descendants of the specified ancestor part."""
descendants = part.get_descendants(include_self=False) descendants = part.get_descendants(include_self=False)
return queryset.filter(id__in=descendants) return queryset.filter(id__in=descendants)
@ -934,7 +935,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_variant_of(self, queryset, name, part): def filter_variant_of(self, queryset, name, part):
"""Limit queryset to direct children (variants) of the specified part""" """Limit queryset to direct children (variants) of the specified part."""
return queryset.filter(id__in=part.get_children()) return queryset.filter(id__in=part.get_children())
in_bom_for = rest_filters.ModelChoiceFilter( in_bom_for = rest_filters.ModelChoiceFilter(
@ -942,7 +943,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_in_bom(self, queryset, name, part): def filter_in_bom(self, queryset, name, part):
"""Limit queryset to parts in the BOM for the specified part""" """Limit queryset to parts in the BOM for the specified part."""
bom_parts = part.get_parts_in_bom() bom_parts = part.get_parts_in_bom()
return queryset.filter(id__in=[p.pk for p in bom_parts]) return queryset.filter(id__in=[p.pk for p in bom_parts])
@ -951,7 +952,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_has_pricing(self, queryset, name, value): def filter_has_pricing(self, queryset, name, value):
"""Filter the queryset based on whether pricing information is available for the sub_part""" """Filter the queryset based on whether pricing information is available for the sub_part."""
q_a = Q(pricing_data=None) q_a = Q(pricing_data=None)
q_b = Q(pricing_data__overall_min=None, pricing_data__overall_max=None) q_b = Q(pricing_data__overall_min=None, pricing_data__overall_max=None)
@ -965,7 +966,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_has_stocktake(self, queryset, name, value): def filter_has_stocktake(self, queryset, name, value):
"""Filter the queryset based on whether stocktake data is available""" """Filter the queryset based on whether stocktake data is available."""
if str2bool(value): if str2bool(value):
return queryset.exclude(last_stocktake=None) return queryset.exclude(last_stocktake=None)
return queryset.filter(last_stocktake=None) return queryset.filter(last_stocktake=None)
@ -975,7 +976,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_stock_to_build(self, queryset, name, value): def filter_stock_to_build(self, queryset, name, value):
"""Filter the queryset based on whether part stock is required for a pending BuildOrder""" """Filter the queryset based on whether part stock is required for a pending BuildOrder."""
if str2bool(value): if str2bool(value):
# Return parts which are required for a build order, but have not yet been allocated # Return parts which are required for a build order, but have not yet been allocated
return queryset.filter( return queryset.filter(
@ -991,7 +992,7 @@ class PartFilter(rest_filters.FilterSet):
) )
def filter_depleted_stock(self, queryset, name, value): def filter_depleted_stock(self, queryset, name, value):
"""Filter the queryset based on whether the part is fully depleted of stock""" """Filter the queryset based on whether the part is fully depleted of stock."""
if str2bool(value): if str2bool(value):
return queryset.filter(Q(in_stock=0) & ~Q(stock_item_count=0)) return queryset.filter(Q(in_stock=0) & ~Q(stock_item_count=0))
return queryset.exclude(Q(in_stock=0) & ~Q(stock_item_count=0)) return queryset.exclude(Q(in_stock=0) & ~Q(stock_item_count=0))
@ -1030,7 +1031,7 @@ class PartFilter(rest_filters.FilterSet):
class PartMixin: class PartMixin:
"""Mixin class for Part API endpoints""" """Mixin class for Part API endpoints."""
serializer_class = part_serializers.PartSerializer serializer_class = part_serializers.PartSerializer
queryset = Part.objects.all() queryset = Part.objects.all()
@ -1040,7 +1041,7 @@ class PartMixin:
is_create = False is_create = False
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return an annotated queryset object for the PartDetail endpoint""" """Return an annotated queryset object for the PartDetail endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = part_serializers.PartSerializer.annotate_queryset(queryset) queryset = part_serializers.PartSerializer.annotate_queryset(queryset)
@ -1048,7 +1049,7 @@ class PartMixin:
return queryset return queryset
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return a serializer instance for this endpoint""" """Return a serializer instance for this endpoint."""
# Ensure the request context is passed through # Ensure the request context is passed through
kwargs['context'] = self.get_serializer_context() kwargs['context'] = self.get_serializer_context()
@ -1077,7 +1078,7 @@ class PartMixin:
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def get_serializer_context(self): def get_serializer_context(self):
"""Extend serializer context data""" """Extend serializer context data."""
context = super().get_serializer_context() context = super().get_serializer_context()
context['request'] = self.request context['request'] = self.request
@ -1085,13 +1086,13 @@ class PartMixin:
class PartList(PartMixin, APIDownloadMixin, ListCreateAPI): class PartList(PartMixin, APIDownloadMixin, ListCreateAPI):
"""API endpoint for accessing a list of Part objects, or creating a new Part instance""" """API endpoint for accessing a list of Part objects, or creating a new Part instance."""
filterset_class = PartFilter filterset_class = PartFilter
is_create = True is_create = True
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""Download the filtered queryset as a data file""" """Download the filtered queryset as a data file."""
dataset = PartResource().export(queryset=queryset) dataset = PartResource().export(queryset=queryset)
filedata = dataset.export(export_format) filedata = dataset.export(export_format)
@ -1127,7 +1128,7 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI):
return Response(data) return Response(data)
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
"""Perform custom filtering of the queryset""" """Perform custom filtering of the queryset."""
params = self.request.query_params params = self.request.query_params
queryset = super().filter_queryset(queryset) queryset = super().filter_queryset(queryset)
@ -1311,7 +1312,7 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI):
class PartChangeCategory(CreateAPI): class PartChangeCategory(CreateAPI):
"""API endpoint to change the location of multiple parts in bulk""" """API endpoint to change the location of multiple parts in bulk."""
serializer_class = part_serializers.PartSetCategorySerializer serializer_class = part_serializers.PartSetCategorySerializer
queryset = Part.objects.none() queryset = Part.objects.none()
@ -1321,7 +1322,7 @@ class PartDetail(PartMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single Part object.""" """API endpoint for detail view of a single Part object."""
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
"""Delete a Part instance via the API """Delete a Part instance via the API.
- If the part is 'active' it cannot be deleted - If the part is 'active' it cannot be deleted
- It must first be marked as 'inactive' - It must first be marked as 'inactive'
@ -1360,7 +1361,7 @@ class PartRelatedList(ListCreateAPI):
serializer_class = part_serializers.PartRelationSerializer serializer_class = part_serializers.PartRelationSerializer
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
"""Custom queryset filtering""" """Custom queryset filtering."""
queryset = super().filter_queryset(queryset) queryset = super().filter_queryset(queryset)
params = self.request.query_params params = self.request.query_params
@ -1390,7 +1391,7 @@ class PartParameterTemplateFilter(rest_filters.FilterSet):
"""FilterSet for PartParameterTemplate objects.""" """FilterSet for PartParameterTemplate objects."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = PartParameterTemplate model = PartParameterTemplate
@ -1472,7 +1473,7 @@ class PartParameterTemplateList(ListCreateAPI):
class PartParameterTemplateDetail(RetrieveUpdateDestroyAPI): class PartParameterTemplateDetail(RetrieveUpdateDestroyAPI):
"""API endpoint for accessing the detail view for a PartParameterTemplate object""" """API endpoint for accessing the detail view for a PartParameterTemplate object."""
queryset = PartParameterTemplate.objects.all() queryset = PartParameterTemplate.objects.all()
serializer_class = part_serializers.PartParameterTemplateSerializer serializer_class = part_serializers.PartParameterTemplateSerializer
@ -1485,7 +1486,7 @@ class PartParameterAPIMixin:
serializer_class = part_serializers.PartParameterSerializer serializer_class = part_serializers.PartParameterSerializer
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Override get_queryset method to prefetch related fields""" """Override get_queryset method to prefetch related fields."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = queryset.prefetch_related('part', 'template') queryset = queryset.prefetch_related('part', 'template')
return queryset return queryset
@ -1509,10 +1510,10 @@ class PartParameterAPIMixin:
class PartParameterFilter(rest_filters.FilterSet): class PartParameterFilter(rest_filters.FilterSet):
"""Custom filters for the PartParameterList API endpoint""" """Custom filters for the PartParameterList API endpoint."""
class Meta: class Meta:
"""Metaclass options for the filterset""" """Metaclass options for the filterset."""
model = PartParameter model = PartParameter
fields = ['template'] fields = ['template']
@ -1570,24 +1571,24 @@ class PartParameterDetail(PartParameterAPIMixin, RetrieveUpdateDestroyAPI):
class PartStocktakeFilter(rest_filters.FilterSet): class PartStocktakeFilter(rest_filters.FilterSet):
"""Custom filter for the PartStocktakeList endpoint""" """Custom filter for the PartStocktakeList endpoint."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = PartStocktake model = PartStocktake
fields = ['part', 'user'] fields = ['part', 'user']
class PartStocktakeList(ListCreateAPI): class PartStocktakeList(ListCreateAPI):
"""API endpoint for listing part stocktake information""" """API endpoint for listing part stocktake information."""
queryset = PartStocktake.objects.all() queryset = PartStocktake.objects.all()
serializer_class = part_serializers.PartStocktakeSerializer serializer_class = part_serializers.PartStocktakeSerializer
filterset_class = PartStocktakeFilter filterset_class = PartStocktakeFilter
def get_serializer_context(self): def get_serializer_context(self):
"""Extend serializer context data""" """Extend serializer context data."""
context = super().get_serializer_context() context = super().get_serializer_context()
context['request'] = self.request context['request'] = self.request
@ -1612,7 +1613,7 @@ class PartStocktakeDetail(RetrieveUpdateDestroyAPI):
class PartStocktakeReportList(ListAPI): class PartStocktakeReportList(ListAPI):
"""API endpoint for listing part stocktake report information""" """API endpoint for listing part stocktake report information."""
queryset = PartStocktakeReport.objects.all() queryset = PartStocktakeReport.objects.all()
serializer_class = part_serializers.PartStocktakeReportSerializer serializer_class = part_serializers.PartStocktakeReportSerializer
@ -1626,7 +1627,7 @@ class PartStocktakeReportList(ListAPI):
class PartStocktakeReportGenerate(CreateAPI): class PartStocktakeReportGenerate(CreateAPI):
"""API endpoint for manually generating a new PartStocktakeReport""" """API endpoint for manually generating a new PartStocktakeReport."""
serializer_class = part_serializers.PartStocktakeReportGenerateSerializer serializer_class = part_serializers.PartStocktakeReportGenerateSerializer
@ -1635,7 +1636,7 @@ class PartStocktakeReportGenerate(CreateAPI):
role_required = 'stocktake' role_required = 'stocktake'
def get_serializer_context(self): def get_serializer_context(self):
"""Extend serializer context data""" """Extend serializer context data."""
context = super().get_serializer_context() context = super().get_serializer_context()
context['request'] = self.request context['request'] = self.request
@ -1646,7 +1647,7 @@ class BomFilter(rest_filters.FilterSet):
"""Custom filters for the BOM list.""" """Custom filters for the BOM list."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = BomItem model = BomItem
fields = ['optional', 'consumable', 'inherited', 'allow_variants', 'validated'] fields = ['optional', 'consumable', 'inherited', 'allow_variants', 'validated']
@ -1672,7 +1673,7 @@ class BomFilter(rest_filters.FilterSet):
) )
def filter_available_stock(self, queryset, name, value): def filter_available_stock(self, queryset, name, value):
"""Filter the queryset based on whether each line item has any available stock""" """Filter the queryset based on whether each line item has any available stock."""
if str2bool(value): if str2bool(value):
return queryset.filter(available_stock__gt=0) return queryset.filter(available_stock__gt=0)
return queryset.filter(available_stock=0) return queryset.filter(available_stock=0)
@ -1680,7 +1681,7 @@ class BomFilter(rest_filters.FilterSet):
on_order = rest_filters.BooleanFilter(label='On order', method='filter_on_order') on_order = rest_filters.BooleanFilter(label='On order', method='filter_on_order')
def filter_on_order(self, queryset, name, value): def filter_on_order(self, queryset, name, value):
"""Filter the queryset based on whether each line item has any stock on order""" """Filter the queryset based on whether each line item has any stock on order."""
if str2bool(value): if str2bool(value):
return queryset.filter(on_order__gt=0) return queryset.filter(on_order__gt=0)
return queryset.filter(on_order=0) return queryset.filter(on_order=0)
@ -1690,7 +1691,7 @@ class BomFilter(rest_filters.FilterSet):
) )
def filter_has_pricing(self, queryset, name, value): def filter_has_pricing(self, queryset, name, value):
"""Filter the queryset based on whether pricing information is available for the sub_part""" """Filter the queryset based on whether pricing information is available for the sub_part."""
q_a = Q(sub_part__pricing_data=None) q_a = Q(sub_part__pricing_data=None)
q_b = Q( q_b = Q(
sub_part__pricing_data__overall_min=None, sub_part__pricing_data__overall_min=None,
@ -1704,13 +1705,13 @@ class BomFilter(rest_filters.FilterSet):
class BomMixin: class BomMixin:
"""Mixin class for BomItem API endpoints""" """Mixin class for BomItem API endpoints."""
serializer_class = part_serializers.BomItemSerializer serializer_class = part_serializers.BomItemSerializer
queryset = BomItem.objects.all() queryset = BomItem.objects.all()
def get_serializer(self, *args, **kwargs): def get_serializer(self, *args, **kwargs):
"""Return the serializer instance for this API endpoint """Return the serializer instance for this API endpoint.
If requested, extra detail fields are annotated to the queryset: If requested, extra detail fields are annotated to the queryset:
- part_detail - part_detail
@ -1735,7 +1736,7 @@ class BomMixin:
return self.serializer_class(*args, **kwargs) return self.serializer_class(*args, **kwargs)
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
"""Return the queryset object for this endpoint""" """Return the queryset object for this endpoint."""
queryset = super().get_queryset(*args, **kwargs) queryset = super().get_queryset(*args, **kwargs)
queryset = self.get_serializer_class().setup_eager_loading(queryset) queryset = self.get_serializer_class().setup_eager_loading(queryset)
@ -1754,7 +1755,7 @@ class BomList(BomMixin, ListCreateDestroyAPIView):
filterset_class = BomFilter filterset_class = BomFilter
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
"""Return serialized list response for this endpoint""" """Return serialized list response for this endpoint."""
queryset = self.filter_queryset(self.get_queryset()) queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset) page = self.paginate_queryset(queryset)
@ -1778,7 +1779,7 @@ class BomList(BomMixin, ListCreateDestroyAPIView):
return Response(data) return Response(data)
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
"""Custom query filtering for the BomItem list API""" """Custom query filtering for the BomItem list API."""
queryset = super().filter_queryset(queryset) queryset = super().filter_queryset(queryset)
params = self.request.query_params params = self.request.query_params
@ -1906,7 +1907,7 @@ class BomItemValidate(UpdateAPI):
"""API endpoint for validating a BomItem.""" """API endpoint for validating a BomItem."""
class BomItemValidationSerializer(serializers.Serializer): class BomItemValidationSerializer(serializers.Serializer):
"""Simple serializer for passing a single boolean field""" """Simple serializer for passing a single boolean field."""
valid = serializers.BooleanField(default=False) valid = serializers.BooleanField(default=False)

View File

@ -1,4 +1,4 @@
"""part app specification""" """part app specification."""
import logging import logging
@ -11,7 +11,7 @@ logger = logging.getLogger('inventree')
class PartConfig(AppConfig): class PartConfig(AppConfig):
"""Config class for the 'part' app""" """Config class for the 'part' app."""
name = 'part' name = 'part'

View File

@ -1,4 +1,4 @@
"""Custom query filters for the Part models """Custom query filters for the Part models.
The code here makes heavy use of subquery annotations! The code here makes heavy use of subquery annotations!
@ -89,7 +89,7 @@ def annotate_on_order_quantity(reference: str = ''):
def annotate_total_stock(reference: str = ''): def annotate_total_stock(reference: str = ''):
"""Annotate 'total stock' quantity against a queryset: """Annotate 'total stock' quantity against a queryset.
- This function calculates the 'total stock' for a given part - This function calculates the 'total stock' for a given part
- Finds all stock items associated with each part (using the provided filter) - Finds all stock items associated with each part (using the provided filter)
@ -128,7 +128,7 @@ def annotate_build_order_requirements(reference: str = ''):
def annotate_build_order_allocations(reference: str = ''): def annotate_build_order_allocations(reference: str = ''):
"""Annotate the total quantity of each part allocated to build orders: """Annotate the total quantity of each part allocated to build orders.
- This function calculates the total part quantity allocated to open build orders - This function calculates the total part quantity allocated to open build orders
- Finds all build order allocations for each part (using the provided filter) - Finds all build order allocations for each part (using the provided filter)
@ -151,7 +151,7 @@ def annotate_build_order_allocations(reference: str = ''):
def annotate_sales_order_allocations(reference: str = ''): def annotate_sales_order_allocations(reference: str = ''):
"""Annotate the total quantity of each part allocated to sales orders: """Annotate the total quantity of each part allocated to sales orders.
- This function calculates the total part quantity allocated to open sales orders" - This function calculates the total part quantity allocated to open sales orders"
- Finds all sales order allocations for each part (using the provided filter) - Finds all sales order allocations for each part (using the provided filter)
@ -180,7 +180,7 @@ def annotate_sales_order_allocations(reference: str = ''):
def variant_stock_query( def variant_stock_query(
reference: str = '', filter: Q = stock.models.StockItem.IN_STOCK_FILTER reference: str = '', filter: Q = stock.models.StockItem.IN_STOCK_FILTER
): ):
"""Create a queryset to retrieve all stock items for variant parts under the specified part """Create a queryset to retrieve all stock items for variant parts under the specified part.
- Useful for annotating a queryset with aggregated information about variant parts - Useful for annotating a queryset with aggregated information about variant parts
@ -196,7 +196,7 @@ def variant_stock_query(
def annotate_variant_quantity(subquery: Q, reference: str = 'quantity'): def annotate_variant_quantity(subquery: Q, reference: str = 'quantity'):
"""Create a subquery annotation for all variant part stock items on the given parent query """Create a subquery annotation for all variant part stock items on the given parent query.
Args: Args:
subquery: A 'variant_stock_query' Q object subquery: A 'variant_stock_query' Q object
@ -239,15 +239,15 @@ def annotate_category_parts():
def filter_by_parameter(queryset, template_id: int, value: str, func: str = ''): def filter_by_parameter(queryset, template_id: int, value: str, func: str = ''):
"""Filter the given queryset by a given template parameter """Filter the given queryset by a given template parameter.
Parts which do not have a value for the given parameter are excluded. Parts which do not have a value for the given parameter are excluded.
Arguments: Arguments:
queryset - A queryset of Part objects queryset: A queryset of Part objects
template_id - The ID of the template parameter to filter by template_id (int): The ID of the template parameter to filter by
value - The value of the parameter to filter by value (str): The value of the parameter to filter by
func - The function to use for the filter (e.g. __gt, __lt, __contains) func (str): The function to use for the filter (e.g. __gt, __lt, __contains)
Returns: Returns:
A queryset of Part objects filtered by the given parameter A queryset of Part objects filtered by the given parameter
@ -257,13 +257,14 @@ def filter_by_parameter(queryset, template_id: int, value: str, func: str = ''):
def order_by_parameter(queryset, template_id: int, ascending=True): def order_by_parameter(queryset, template_id: int, ascending=True):
"""Order the given queryset by a given template parameter """Order the given queryset by a given template parameter.
Parts which do not have a value for the given parameter are ordered last. Parts which do not have a value for the given parameter are ordered last.
Arguments: Arguments:
queryset - A queryset of Part objects queryset: A queryset of Part objects
template_id - The ID of the template parameter to order by template_id (int): The ID of the template parameter to order by
ascending (bool): Order by ascending or descending (default = True)
Returns: Returns:
A queryset of Part objects ordered by the given parameter A queryset of Part objects ordered by the given parameter

View File

@ -37,7 +37,7 @@ class PartPriceForm(forms.Form):
"""Simple form for viewing part pricing information.""" """Simple form for viewing part pricing information."""
class Meta: class Meta:
"""Metaclass defines fields for this form""" """Metaclass defines fields for this form."""
model = Part model = Part
fields = ['quantity'] fields = ['quantity']

View File

@ -1,4 +1,4 @@
"""Various helper functions for the part app""" """Various helper functions for the part app."""
import logging import logging
import os import os
@ -20,7 +20,6 @@ def compile_full_name_template(*args, **kwargs):
This function is called whenever the 'PART_NAME_FORMAT' setting is changed. This function is called whenever the 'PART_NAME_FORMAT' setting is changed.
""" """
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
global _part_full_name_template global _part_full_name_template
@ -61,7 +60,6 @@ def render_part_full_name(part) -> str:
Args: Args:
part: The Part object to render part: The Part object to render
""" """
template = compile_full_name_template() template = compile_full_name_template()
if template: if template:
@ -91,7 +89,6 @@ def get_part_image_directory() -> str:
TODO: Future work may be needed here to support other storage backends, such as S3 TODO: Future work may be needed here to support other storage backends, such as S3
""" """
part_image_directory = os.path.abspath( part_image_directory = os.path.abspath(
os.path.join(settings.MEDIA_ROOT, PART_IMAGE_DIR) os.path.join(settings.MEDIA_ROOT, PART_IMAGE_DIR)
) )

View File

@ -83,7 +83,7 @@ class PartCategory(MetadataMixin, InvenTreeTree):
ITEM_PARENT_KEY = 'category' ITEM_PARENT_KEY = 'category'
class Meta: class Meta:
"""Metaclass defines extra model properties""" """Metaclass defines extra model properties."""
verbose_name = _('Part Category') verbose_name = _('Part Category')
verbose_name_plural = _('Part Categories') verbose_name_plural = _('Part Categories')
@ -93,7 +93,6 @@ class PartCategory(MetadataMixin, InvenTreeTree):
This must be handled within a transaction.atomic(), otherwise the tree structure is damaged This must be handled within a transaction.atomic(), otherwise the tree structure is damaged
""" """
super().delete( super().delete(
delete_children=kwargs.get('delete_child_categories', False), delete_children=kwargs.get('delete_child_categories', False),
delete_items=kwargs.get('delete_parts', False), delete_items=kwargs.get('delete_parts', False),
@ -135,17 +134,17 @@ class PartCategory(MetadataMixin, InvenTreeTree):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API url associated with the PartCategory model""" """Return the API url associated with the PartCategory model."""
return reverse('api-part-category-list') return reverse('api-part-category-list')
def get_absolute_url(self): def get_absolute_url(self):
"""Return the web URL associated with the detail view for this PartCategory instance""" """Return the web URL associated with the detail view for this PartCategory instance."""
return reverse('category-detail', kwargs={'pk': self.id}) return reverse('category-detail', kwargs={'pk': self.id})
def clean(self): def clean(self):
"""Custom clean action for the PartCategory model: """Custom clean action for the PartCategory model.
- Ensure that the structural parameter cannot get set if products already assigned to the category Ensure that the structural parameter cannot get set if products already assigned to the category
""" """
if self.pk and self.structural and self.partcount(False, False) > 0: if self.pk and self.structural and self.partcount(False, False) > 0:
raise ValidationError( raise ValidationError(
@ -177,11 +176,11 @@ class PartCategory(MetadataMixin, InvenTreeTree):
@property @property
def item_count(self): def item_count(self):
"""Return the number of parts contained in this PartCategory""" """Return the number of parts contained in this PartCategory."""
return self.partcount() return self.partcount()
def get_items(self, cascade=False): def get_items(self, cascade=False):
"""Return a queryset containing the parts which exist in this category""" """Return a queryset containing the parts which exist in this category."""
return self.get_parts(cascade=cascade) return self.get_parts(cascade=cascade)
def partcount(self, cascade=True, active=False): def partcount(self, cascade=True, active=False):
@ -312,7 +311,6 @@ def rename_part_image(instance, filename):
Returns: Returns:
Cleaned filename in format part_<n>_img Cleaned filename in format part_<n>_img
""" """
base = part_helpers.PART_IMAGE_DIR base = part_helpers.PART_IMAGE_DIR
fname = os.path.basename(filename) fname = os.path.basename(filename)
@ -327,7 +325,7 @@ class PartManager(TreeManager):
""" """
def get_queryset(self): def get_queryset(self):
"""Perform default prefetch operations when accessing Part model from the database""" """Perform default prefetch operations when accessing Part model from the database."""
return ( return (
super() super()
.get_queryset() .get_queryset()
@ -385,7 +383,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
tags = TaggableManager(blank=True) tags = TaggableManager(blank=True)
class Meta: class Meta:
"""Metaclass defines extra model properties""" """Metaclass defines extra model properties."""
verbose_name = _('Part') verbose_name = _('Part')
verbose_name_plural = _('Parts') verbose_name_plural = _('Parts')
@ -395,14 +393,14 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
] ]
class MPTTMeta: class MPTTMeta:
"""MPTT metaclass definitions""" """MPTT Metaclass options."""
# For legacy reasons the 'variant_of' field is used to indicate the MPTT parent # For legacy reasons the 'variant_of' field is used to indicate the MPTT parent
parent_attr = 'variant_of' parent_attr = 'variant_of'
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the list API endpoint URL associated with the Part model""" """Return the list API endpoint URL associated with the Part model."""
return reverse('api-part-list') return reverse('api-part-list')
def api_instance_filters(self): def api_instance_filters(self):
@ -481,7 +479,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
raise ValidationError({'variant_of': _('Invalid choice for parent part')}) raise ValidationError({'variant_of': _('Invalid choice for parent part')})
def __str__(self): def __str__(self):
"""Return a string representation of the Part (for use in the admin interface)""" """Return a string representation of the Part (for use in the admin interface)."""
return f'{self.full_name} - {self.description}' return f'{self.full_name} - {self.description}'
def get_parts_in_bom(self, **kwargs): def get_parts_in_bom(self, **kwargs):
@ -560,7 +558,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
return result return result
def validate_name(self, raise_error=True): def validate_name(self, raise_error=True):
"""Validate the name field for this Part instance """Validate the name field for this Part instance.
This function is exposed to any Validation plugins, and thus can be customized. This function is exposed to any Validation plugins, and thus can be customized.
""" """
@ -579,7 +577,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
raise ValidationError({'name': exc.message}) raise ValidationError({'name': exc.message})
def validate_ipn(self, raise_error=True): def validate_ipn(self, raise_error=True):
"""Ensure that the IPN (internal part number) is valid for this Part" """Ensure that the IPN (internal part number) is valid for this Part".
- Validation is handled by custom plugins - Validation is handled by custom plugins
- By default, no validation checks are performed - By default, no validation checks are performed
@ -745,8 +743,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
@property @property
def full_name(self): def full_name(self):
"""Format a 'full name' for this Part based on the format PART_NAME_FORMAT defined in InvenTree settings""" """Format a 'full name' for this Part based on the format PART_NAME_FORMAT defined in InvenTree settings."""
return part_helpers.render_part_full_name(self) return part_helpers.render_part_full_name(self)
def get_absolute_url(self): def get_absolute_url(self):
@ -1090,7 +1087,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
@property @property
def category_path(self): def category_path(self):
"""Return the category path of this Part instance""" """Return the category path of this Part instance."""
if self.category: if self.category:
return self.category.pathstring return self.category.pathstring
return '' return ''
@ -1652,7 +1649,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
@property @property
def has_bom(self): def has_bom(self):
"""Return True if this Part instance has any BOM items""" """Return True if this Part instance has any BOM items."""
return self.get_bom_items().exists() return self.get_bom_items().exists()
def get_trackable_parts(self): def get_trackable_parts(self):
@ -1756,7 +1753,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
return self.supplier_parts.count() return self.supplier_parts.count()
def update_pricing(self): def update_pricing(self):
"""Recalculate cached pricing for this Part instance""" """Recalculate cached pricing for this Part instance."""
self.pricing.update_pricing() self.pricing.update_pricing()
@property @property
@ -1826,8 +1823,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
return f'{min_price} - {max_price}' return f'{min_price} - {max_price}'
def get_supplier_price_range(self, quantity=1): def get_supplier_price_range(self, quantity=1):
"""Return the supplier price range of this part: """Return the supplier price range of this part.
Actions:
- Checks if there is any supplier pricing information associated with this Part - Checks if there is any supplier pricing information associated with this Part
- Iterate through available supplier pricing and select (min, max) - Iterate through available supplier pricing and select (min, max)
- Returns tuple of (min, max) - Returns tuple of (min, max)
@ -1964,7 +1962,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
@property @property
def has_price_breaks(self): def has_price_breaks(self):
"""Return True if this part has sale price breaks""" """Return True if this part has sale price breaks."""
return self.price_breaks.exists() return self.price_breaks.exists()
@property @property
@ -1974,7 +1972,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
@property @property
def unit_pricing(self): def unit_pricing(self):
"""Returns the price of this Part at quantity=1""" """Returns the price of this Part at quantity=1."""
return self.get_price(1) return self.get_price(1)
def add_price_break(self, quantity, price): def add_price_break(self, quantity, price):
@ -1991,14 +1989,14 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
PartSellPriceBreak.objects.create(part=self, quantity=quantity, price=price) PartSellPriceBreak.objects.create(part=self, quantity=quantity, price=price)
def get_internal_price(self, quantity, moq=True, multiples=True, currency=None): def get_internal_price(self, quantity, moq=True, multiples=True, currency=None):
"""Return the internal price of this Part at the specified quantity""" """Return the internal price of this Part at the specified quantity."""
return common.models.get_price( return common.models.get_price(
self, quantity, moq, multiples, currency, break_name='internal_price_breaks' self, quantity, moq, multiples, currency, break_name='internal_price_breaks'
) )
@property @property
def has_internal_price_breaks(self): def has_internal_price_breaks(self):
"""Return True if this Part has internal pricing information""" """Return True if this Part has internal pricing information."""
return self.internal_price_breaks.exists() return self.internal_price_breaks.exists()
@property @property
@ -2007,7 +2005,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
return self.internalpricebreaks.order_by('quantity').all() return self.internalpricebreaks.order_by('quantity').all()
def get_purchase_price(self, quantity): def get_purchase_price(self, quantity):
"""Calculate the purchase price for this part at the specified quantity """Calculate the purchase price for this part at the specified quantity.
- Looks at available supplier pricing data - Looks at available supplier pricing data
- Calculates the price base on the closest price point - Calculates the price base on the closest price point
@ -2091,7 +2089,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
@transaction.atomic @transaction.atomic
def copy_parameters_from(self, other, **kwargs): def copy_parameters_from(self, other, **kwargs):
"""Copy all parameter values from another Part instance""" """Copy all parameter values from another Part instance."""
clear = kwargs.get('clear', True) clear = kwargs.get('clear', True)
if clear: if clear:
@ -2136,7 +2134,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
return tests return tests
def getTestTemplateMap(self, **kwargs): def getTestTemplateMap(self, **kwargs):
"""Return a map of all test templates associated with this Part""" """Return a map of all test templates associated with this Part."""
templates = {} templates = {}
for template in self.getTestTemplates(**kwargs): for template in self.getTestTemplates(**kwargs):
@ -2145,7 +2143,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
return templates return templates
def getRequiredTests(self): def getRequiredTests(self):
"""Return the tests which are required by this part""" """Return the tests which are required by this part."""
return self.getTestTemplates(required=True) return self.getTestTemplates(required=True)
@property @property
@ -2246,7 +2244,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
@property @property
def latest_stocktake(self): def latest_stocktake(self):
"""Return the latest PartStocktake object associated with this part (if one exists)""" """Return the latest PartStocktake object associated with this part (if one exists)."""
return self.stocktakes.order_by('-pk').first() return self.stocktakes.order_by('-pk').first()
@property @property
@ -2303,7 +2301,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
return filtered_parts return filtered_parts
def get_related_parts(self): def get_related_parts(self):
"""Return a set of all related parts for this part""" """Return a set of all related parts for this part."""
related_parts = set() related_parts = set()
related_parts_1 = self.related_parts_1.filter(part_1__id=self.pk) related_parts_1 = self.related_parts_1.filter(part_1__id=self.pk)
@ -2322,7 +2320,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel)
@property @property
def related_count(self): def related_count(self):
"""Return the number of 'related parts' which point to this Part""" """Return the number of 'related parts' which point to this Part."""
return len(self.get_related_parts()) return len(self.get_related_parts())
def is_part_low_on_stock(self): def is_part_low_on_stock(self):
@ -2356,7 +2354,7 @@ def after_save_part(sender, instance: Part, created, **kwargs):
class PartPricing(common.models.MetaMixin): class PartPricing(common.models.MetaMixin):
"""Model for caching min/max pricing information for a particular Part """Model for caching min/max pricing information for a particular Part.
It is prohibitively expensive to calculate min/max pricing for a part "on the fly". It is prohibitively expensive to calculate min/max pricing for a part "on the fly".
As min/max pricing does not change very often, we pre-calculate and cache these values. As min/max pricing does not change very often, we pre-calculate and cache these values.
@ -2385,7 +2383,7 @@ class PartPricing(common.models.MetaMixin):
@property @property
def is_valid(self): def is_valid(self):
"""Return True if the cached pricing is valid""" """Return True if the cached pricing is valid."""
return self.updated is not None return self.updated is not None
def convert(self, money): def convert(self, money):
@ -2411,7 +2409,7 @@ class PartPricing(common.models.MetaMixin):
return result return result
def schedule_for_update(self, counter: int = 0, test: bool = False): def schedule_for_update(self, counter: int = 0, test: bool = False):
"""Schedule this pricing to be updated""" """Schedule this pricing to be updated."""
import InvenTree.ready import InvenTree.ready
# If we are running within CI, only schedule the update if the test flag is set # If we are running within CI, only schedule the update if the test flag is set
@ -2487,7 +2485,7 @@ class PartPricing(common.models.MetaMixin):
) )
def update_pricing(self, counter: int = 0, cascade: bool = True): def update_pricing(self, counter: int = 0, cascade: bool = True):
"""Recalculate all cost data for the referenced Part instance""" """Recalculate all cost data for the referenced Part instance."""
# If importing data, skip pricing update # If importing data, skip pricing update
if InvenTree.ready.isImportingData(): if InvenTree.ready.isImportingData():
@ -2526,7 +2524,7 @@ class PartPricing(common.models.MetaMixin):
self.update_templates(counter) self.update_templates(counter)
def update_assemblies(self, counter: int = 0): def update_assemblies(self, counter: int = 0):
"""Schedule updates for any assemblies which use this part""" """Schedule updates for any assemblies which use this part."""
# If the linked Part is used in any assemblies, schedule a pricing update for those assemblies # If the linked Part is used in any assemblies, schedule a pricing update for those assemblies
used_in_parts = self.part.get_used_in() used_in_parts = self.part.get_used_in()
@ -2534,14 +2532,14 @@ class PartPricing(common.models.MetaMixin):
p.pricing.schedule_for_update(counter + 1) p.pricing.schedule_for_update(counter + 1)
def update_templates(self, counter: int = 0): def update_templates(self, counter: int = 0):
"""Schedule updates for any template parts above this part""" """Schedule updates for any template parts above this part."""
templates = self.part.get_ancestors(include_self=False) templates = self.part.get_ancestors(include_self=False)
for p in templates: for p in templates:
p.pricing.schedule_for_update(counter + 1) p.pricing.schedule_for_update(counter + 1)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Whenever pricing model is saved, automatically update overall prices""" """Whenever pricing model is saved, automatically update overall prices."""
# Update the currency which was used to perform the calculation # Update the currency which was used to perform the calculation
self.currency = currency_code_default() self.currency = currency_code_default()
@ -2720,7 +2718,7 @@ class PartPricing(common.models.MetaMixin):
self.save() self.save()
def update_internal_cost(self, save=True): def update_internal_cost(self, save=True):
"""Recalculate internal cost for the referenced Part instance""" """Recalculate internal cost for the referenced Part instance."""
min_int_cost = None min_int_cost = None
max_int_cost = None max_int_cost = None
@ -2835,7 +2833,6 @@ class PartPricing(common.models.MetaMixin):
Here we simply take the minimum / maximum values of the other calculated fields. Here we simply take the minimum / maximum values of the other calculated fields.
""" """
overall_min = None overall_min = None
overall_max = None overall_max = None
@ -2907,7 +2904,7 @@ class PartPricing(common.models.MetaMixin):
self.overall_max = overall_max self.overall_max = overall_max
def update_sale_cost(self, save=True): def update_sale_cost(self, save=True):
"""Recalculate sale cost data""" """Recalculate sale cost data."""
# Iterate through the sell price breaks # Iterate through the sell price breaks
min_sell_price = None min_sell_price = None
max_sell_price = None max_sell_price = None
@ -3179,7 +3176,7 @@ class PartStocktake(models.Model):
@receiver(post_save, sender=PartStocktake, dispatch_uid='post_save_stocktake') @receiver(post_save, sender=PartStocktake, dispatch_uid='post_save_stocktake')
def update_last_stocktake(sender, instance, created, **kwargs): def update_last_stocktake(sender, instance, created, **kwargs):
"""Callback function when a PartStocktake instance is created / edited""" """Callback function when a PartStocktake instance is created / edited."""
# When a new PartStocktake instance is create, update the last_stocktake date for the Part # When a new PartStocktake instance is create, update the last_stocktake date for the Part
if created: if created:
try: try:
@ -3191,7 +3188,7 @@ def update_last_stocktake(sender, instance, created, **kwargs):
def save_stocktake_report(instance, filename): def save_stocktake_report(instance, filename):
"""Save stocktake reports to the correct subdirectory""" """Save stocktake reports to the correct subdirectory."""
filename = os.path.basename(filename) filename = os.path.basename(filename)
return os.path.join('stocktake', 'report', filename) return os.path.join('stocktake', 'report', filename)
@ -3214,11 +3211,11 @@ class PartStocktakeReport(models.Model):
""" """
def __str__(self): def __str__(self):
"""Construct a simple string representation for the report""" """Construct a simple string representation for the report."""
return os.path.basename(self.report.name) return os.path.basename(self.report.name)
def get_absolute_url(self): def get_absolute_url(self):
"""Return the URL for the associaed report file for download""" """Return the URL for the associaed report file for download."""
if self.report: if self.report:
return self.report.url return self.report.url
return None return None
@ -3255,11 +3252,11 @@ class PartAttachment(InvenTreeAttachment):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the list API endpoint URL associated with the PartAttachment model""" """Return the list API endpoint URL associated with the PartAttachment model."""
return reverse('api-part-attachment-list') return reverse('api-part-attachment-list')
def getSubdir(self): def getSubdir(self):
"""Returns the media subdirectory where part attachments are stored""" """Returns the media subdirectory where part attachments are stored."""
return os.path.join('part_files', str(self.part.id)) return os.path.join('part_files', str(self.part.id))
part = models.ForeignKey( part = models.ForeignKey(
@ -3274,13 +3271,13 @@ class PartSellPriceBreak(common.models.PriceBreak):
"""Represents a price break for selling this part.""" """Represents a price break for selling this part."""
class Meta: class Meta:
"""Metaclass providing extra model definition""" """Metaclass providing extra model definition."""
unique_together = ('part', 'quantity') unique_together = ('part', 'quantity')
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the list API endpoint URL associated with the PartSellPriceBreak model""" """Return the list API endpoint URL associated with the PartSellPriceBreak model."""
return reverse('api-part-sale-price-list') return reverse('api-part-sale-price-list')
part = models.ForeignKey( part = models.ForeignKey(
@ -3296,13 +3293,13 @@ class PartInternalPriceBreak(common.models.PriceBreak):
"""Represents a price break for internally selling this part.""" """Represents a price break for internally selling this part."""
class Meta: class Meta:
"""Metaclass providing extra model definition""" """Metaclass providing extra model definition."""
unique_together = ('part', 'quantity') unique_together = ('part', 'quantity')
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the list API endpoint URL associated with the PartInternalPriceBreak model""" """Return the list API endpoint URL associated with the PartInternalPriceBreak model."""
return reverse('api-part-internal-price-list') return reverse('api-part-internal-price-list')
part = models.ForeignKey( part = models.ForeignKey(
@ -3324,7 +3321,7 @@ class PartStar(models.Model):
""" """
class Meta: class Meta:
"""Metaclass providing extra model definition""" """Metaclass providing extra model definition."""
unique_together = ['part', 'user'] unique_together = ['part', 'user']
@ -3352,7 +3349,7 @@ class PartCategoryStar(models.Model):
""" """
class Meta: class Meta:
"""Metaclass providing extra model definition""" """Metaclass providing extra model definition."""
unique_together = ['category', 'user'] unique_together = ['category', 'user']
@ -3386,17 +3383,17 @@ class PartTestTemplate(MetadataMixin, models.Model):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the list API endpoint URL associated with the PartTestTemplate model""" """Return the list API endpoint URL associated with the PartTestTemplate model."""
return reverse('api-part-test-template-list') return reverse('api-part-test-template-list')
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Enforce 'clean' operation when saving a PartTestTemplate instance""" """Enforce 'clean' operation when saving a PartTestTemplate instance."""
self.clean() self.clean()
super().save(*args, **kwargs) super().save(*args, **kwargs)
def clean(self): def clean(self):
"""Clean fields for the PartTestTemplate model""" """Clean fields for the PartTestTemplate model."""
self.test_name = self.test_name.strip() self.test_name = self.test_name.strip()
self.validate_unique() self.validate_unique()
@ -3496,19 +3493,20 @@ class PartParameterTemplate(MetadataMixin, models.Model):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the list API endpoint URL associated with the PartParameterTemplate model""" """Return the list API endpoint URL associated with the PartParameterTemplate model."""
return reverse('api-part-parameter-template-list') return reverse('api-part-parameter-template-list')
def __str__(self): def __str__(self):
"""Return a string representation of a PartParameterTemplate instance""" """Return a string representation of a PartParameterTemplate instance."""
s = str(self.name) s = str(self.name)
if self.units: if self.units:
s += f' ({self.units})' s += f' ({self.units})'
return s return s
def clean(self): def clean(self):
"""Custom cleaning step for this model: """Custom cleaning step for this model.
Checks:
- A 'checkbox' field cannot have 'choices' set - A 'checkbox' field cannot have 'choices' set
- A 'checkbox' field cannot have 'units' set - A 'checkbox' field cannot have 'units' set
""" """
@ -3566,7 +3564,7 @@ class PartParameterTemplate(MetadataMixin, models.Model):
pass pass
def get_choices(self): def get_choices(self):
"""Return a list of choices for this parameter template""" """Return a list of choices for this parameter template."""
if not self.choices: if not self.choices:
return [] return []
@ -3614,7 +3612,7 @@ class PartParameterTemplate(MetadataMixin, models.Model):
dispatch_uid='post_save_part_parameter_template', dispatch_uid='post_save_part_parameter_template',
) )
def post_save_part_parameter_template(sender, instance, created, **kwargs): def post_save_part_parameter_template(sender, instance, created, **kwargs):
"""Callback function when a PartParameterTemplate is created or saved""" """Callback function when a PartParameterTemplate is created or saved."""
import part.tasks as part_tasks import part.tasks as part_tasks
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
@ -3635,18 +3633,18 @@ class PartParameter(MetadataMixin, models.Model):
""" """
class Meta: class Meta:
"""Metaclass providing extra model definition""" """Metaclass providing extra model definition."""
# Prevent multiple instances of a parameter for a single part # Prevent multiple instances of a parameter for a single part
unique_together = ('part', 'template') unique_together = ('part', 'template')
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the list API endpoint URL associated with the PartParameter model""" """Return the list API endpoint URL associated with the PartParameter model."""
return reverse('api-part-parameter-list') return reverse('api-part-parameter-list')
def __str__(self): def __str__(self):
"""String representation of a PartParameter (used in the admin interface)""" """String representation of a PartParameter (used in the admin interface)."""
return f'{self.part.full_name} : {self.template.name} = {self.data} ({self.template.units})' return f'{self.part.full_name} : {self.template.name} = {self.data} ({self.template.units})'
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
@ -3745,22 +3743,22 @@ class PartParameter(MetadataMixin, models.Model):
@property @property
def units(self): def units(self):
"""Return the units associated with the template""" """Return the units associated with the template."""
return self.template.units return self.template.units
@property @property
def name(self): def name(self):
"""Return the name of the template""" """Return the name of the template."""
return self.template.name return self.template.name
@property @property
def description(self): def description(self):
"""Return the description of the template""" """Return the description of the template."""
return self.template.description return self.template.description
@classmethod @classmethod
def create(cls, part, template, data, save=False): def create(cls, part, template, data, save=False):
"""Custom save method for the PartParameter class""" """Custom save method for the PartParameter class."""
part_parameter = cls(part=part, template=template, data=data) part_parameter = cls(part=part, template=template, data=data)
if save: if save:
part_parameter.save() part_parameter.save()
@ -3780,7 +3778,7 @@ class PartCategoryParameterTemplate(MetadataMixin, models.Model):
""" """
class Meta: class Meta:
"""Metaclass providing extra model definition""" """Metaclass providing extra model definition."""
constraints = [ constraints = [
UniqueConstraint( UniqueConstraint(
@ -3857,21 +3855,21 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model):
} }
class Meta: class Meta:
"""Metaclass providing extra model definition""" """Metaclass providing extra model definition."""
verbose_name = _('BOM Item') verbose_name = _('BOM Item')
def __str__(self): def __str__(self):
"""Return a string representation of this BomItem instance""" """Return a string representation of this BomItem instance."""
return f'{decimal2string(self.quantity)} x {self.sub_part.full_name} to make {self.part.full_name}' return f'{decimal2string(self.quantity)} x {self.sub_part.full_name} to make {self.part.full_name}'
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the list API endpoint URL associated with the BomItem model""" """Return the list API endpoint URL associated with the BomItem model."""
return reverse('api-bom-list') return reverse('api-bom-list')
def get_assemblies(self): def get_assemblies(self):
"""Return a list of assemblies which use this BomItem""" """Return a list of assemblies which use this BomItem."""
assemblies = [self.part] assemblies = [self.part]
if self.inherited: if self.inherited:
@ -3928,7 +3926,7 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model):
return Q(part__in=self.get_valid_parts_for_allocation()) return Q(part__in=self.get_valid_parts_for_allocation())
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Enforce 'clean' operation when saving a BomItem instance""" """Enforce 'clean' operation when saving a BomItem instance."""
self.clean() self.clean()
# Update the 'validated' field based on checksum calculation # Update the 'validated' field based on checksum calculation
@ -4212,7 +4210,7 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model):
@receiver(post_save, sender=BomItem, dispatch_uid='update_bom_build_lines') @receiver(post_save, sender=BomItem, dispatch_uid='update_bom_build_lines')
def update_bom_build_lines(sender, instance, created, **kwargs): def update_bom_build_lines(sender, instance, created, **kwargs):
"""Update existing build orders when a BomItem is created or edited""" """Update existing build orders when a BomItem is created or edited."""
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
import build.tasks import build.tasks
@ -4229,7 +4227,7 @@ def update_bom_build_lines(sender, instance, created, **kwargs):
dispatch_uid='post_save_internal_price_break', dispatch_uid='post_save_internal_price_break',
) )
def update_pricing_after_edit(sender, instance, created, **kwargs): def update_pricing_after_edit(sender, instance, created, **kwargs):
"""Callback function when a part price break is created or updated""" """Callback function when a part price break is created or updated."""
# Update part pricing *unless* we are importing data # Update part pricing *unless* we are importing data
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
instance.part.schedule_pricing_update(create=True) instance.part.schedule_pricing_update(create=True)
@ -4245,7 +4243,7 @@ def update_pricing_after_edit(sender, instance, created, **kwargs):
dispatch_uid='post_delete_internal_price_break', dispatch_uid='post_delete_internal_price_break',
) )
def update_pricing_after_delete(sender, instance, **kwargs): def update_pricing_after_delete(sender, instance, **kwargs):
"""Callback function when a part price break is deleted""" """Callback function when a part price break is deleted."""
# Update part pricing *unless* we are importing data # Update part pricing *unless* we are importing data
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
instance.part.schedule_pricing_update(create=False) instance.part.schedule_pricing_update(create=False)
@ -4260,7 +4258,7 @@ class BomItemSubstitute(MetadataMixin, models.Model):
""" """
class Meta: class Meta:
"""Metaclass providing extra model definition""" """Metaclass providing extra model definition."""
verbose_name = _('BOM Item Substitute') verbose_name = _('BOM Item Substitute')
@ -4268,7 +4266,7 @@ class BomItemSubstitute(MetadataMixin, models.Model):
unique_together = ('part', 'bom_item') unique_together = ('part', 'bom_item')
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Enforce a full_clean when saving the BomItemSubstitute model""" """Enforce a full_clean when saving the BomItemSubstitute model."""
self.full_clean() self.full_clean()
super().save(*args, **kwargs) super().save(*args, **kwargs)
@ -4288,7 +4286,7 @@ class BomItemSubstitute(MetadataMixin, models.Model):
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Returns the list API endpoint URL associated with this model""" """Returns the list API endpoint URL associated with this model."""
return reverse('api-bom-substitute-list') return reverse('api-bom-substitute-list')
bom_item = models.ForeignKey( bom_item = models.ForeignKey(
@ -4313,7 +4311,7 @@ class PartRelated(MetadataMixin, models.Model):
"""Store and handle related parts (eg. mating connector, crimps, etc.).""" """Store and handle related parts (eg. mating connector, crimps, etc.)."""
class Meta: class Meta:
"""Metaclass defines extra model properties""" """Metaclass defines extra model properties."""
unique_together = ('part_1', 'part_2') unique_together = ('part_1', 'part_2')
@ -4333,11 +4331,11 @@ class PartRelated(MetadataMixin, models.Model):
) )
def __str__(self): def __str__(self):
"""Return a string representation of this Part-Part relationship""" """Return a string representation of this Part-Part relationship."""
return f'{self.part_1} <--> {self.part_2}' return f'{self.part_1} <--> {self.part_2}'
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Enforce a 'clean' operation when saving a PartRelated instance""" """Enforce a 'clean' operation when saving a PartRelated instance."""
self.clean() self.clean()
self.validate_unique() self.validate_unique()
super().save(*args, **kwargs) super().save(*args, **kwargs)

View File

@ -62,7 +62,7 @@ class CategorySerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for PartCategory.""" """Serializer for PartCategory."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartCategory model = PartCategory
fields = [ fields = [
@ -83,7 +83,7 @@ class CategorySerializer(InvenTree.serializers.InvenTreeModelSerializer):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Optionally add or remove extra fields""" """Optionally add or remove extra fields."""
path_detail = kwargs.pop('path_detail', False) path_detail = kwargs.pop('path_detail', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -97,7 +97,7 @@ class CategorySerializer(InvenTree.serializers.InvenTreeModelSerializer):
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
"""Annotate extra information to the queryset""" """Annotate extra information to the queryset."""
# Annotate the number of 'parts' which exist in each category (including subcategories!) # Annotate the number of 'parts' which exist in each category (including subcategories!)
queryset = queryset.annotate(part_count=part.filters.annotate_category_parts()) queryset = queryset.annotate(part_count=part.filters.annotate_category_parts())
@ -120,7 +120,7 @@ class CategoryTree(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for PartCategory tree.""" """Serializer for PartCategory tree."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartCategory model = PartCategory
fields = ['pk', 'name', 'parent', 'icon', 'structural'] fields = ['pk', 'name', 'parent', 'icon', 'structural']
@ -130,7 +130,7 @@ class PartAttachmentSerializer(InvenTree.serializers.InvenTreeAttachmentSerializ
"""Serializer for the PartAttachment class.""" """Serializer for the PartAttachment class."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartAttachment model = PartAttachment
@ -143,7 +143,7 @@ class PartTestTemplateSerializer(InvenTree.serializers.InvenTreeModelSerializer)
"""Serializer for the PartTestTemplate class.""" """Serializer for the PartTestTemplate class."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartTestTemplate model = PartTestTemplate
@ -165,7 +165,7 @@ class PartSalePriceSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for sale prices for Part model.""" """Serializer for sale prices for Part model."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartSellPriceBreak model = PartSellPriceBreak
fields = ['pk', 'part', 'quantity', 'price', 'price_currency'] fields = ['pk', 'part', 'quantity', 'price', 'price_currency']
@ -183,7 +183,7 @@ class PartInternalPriceSerializer(InvenTree.serializers.InvenTreeModelSerializer
"""Serializer for internal prices for Part model.""" """Serializer for internal prices for Part model."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartInternalPriceBreak model = PartInternalPriceBreak
fields = ['pk', 'part', 'quantity', 'price', 'price_currency'] fields = ['pk', 'part', 'quantity', 'price', 'price_currency']
@ -211,7 +211,7 @@ class PartThumbSerializerUpdate(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for updating Part thumbnail.""" """Serializer for updating Part thumbnail."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = Part model = Part
fields = ['image'] fields = ['image']
@ -230,17 +230,17 @@ class PartParameterTemplateSerializer(InvenTree.serializers.InvenTreeModelSerial
"""JSON serializer for the PartParameterTemplate model.""" """JSON serializer for the PartParameterTemplate model."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartParameterTemplate model = PartParameterTemplate
fields = ['pk', 'name', 'units', 'description', 'checkbox', 'choices'] fields = ['pk', 'name', 'units', 'description', 'checkbox', 'choices']
class PartBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): class PartBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for Part (brief detail)""" """Serializer for Part (brief detail)."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = Part model = Part
fields = [ fields = [
@ -268,7 +268,7 @@ class PartBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer):
read_only_fields = ['barcode_hash'] read_only_fields = ['barcode_hash']
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Custom initialization routine for the PartBrief serializer""" """Custom initialization routine for the PartBrief serializer."""
pricing = kwargs.pop('pricing', True) pricing = kwargs.pop('pricing', True)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -292,7 +292,7 @@ class PartParameterSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""JSON serializers for the PartParameter model.""" """JSON serializers for the PartParameter model."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartParameter model = PartParameter
fields = [ fields = [
@ -328,10 +328,10 @@ class PartParameterSerializer(InvenTree.serializers.InvenTreeModelSerializer):
class PartSetCategorySerializer(serializers.Serializer): class PartSetCategorySerializer(serializers.Serializer):
"""Serializer for changing PartCategory for multiple Part objects""" """Serializer for changing PartCategory for multiple Part objects."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
fields = ['parts', 'category'] fields = ['parts', 'category']
@ -344,7 +344,7 @@ class PartSetCategorySerializer(serializers.Serializer):
) )
def validate_parts(self, parts): def validate_parts(self, parts):
"""Validate the selected parts""" """Validate the selected parts."""
if len(parts) == 0: if len(parts) == 0:
raise serializers.ValidationError(_('No parts selected')) raise serializers.ValidationError(_('No parts selected'))
@ -361,7 +361,7 @@ class PartSetCategorySerializer(serializers.Serializer):
@transaction.atomic @transaction.atomic
def save(self): def save(self):
"""Save the serializer to change the location of the selected parts""" """Save the serializer to change the location of the selected parts."""
data = self.validated_data data = self.validated_data
parts = data['parts'] parts = data['parts']
category = data['category'] category = data['category']
@ -444,7 +444,7 @@ class InitialStockSerializer(serializers.Serializer):
class InitialSupplierSerializer(serializers.Serializer): class InitialSupplierSerializer(serializers.Serializer):
"""Serializer for adding initial supplier / manufacturer information""" """Serializer for adding initial supplier / manufacturer information."""
supplier = serializers.PrimaryKeyRelatedField( supplier = serializers.PrimaryKeyRelatedField(
queryset=company.models.Company.objects.all(), queryset=company.models.Company.objects.all(),
@ -479,7 +479,7 @@ class InitialSupplierSerializer(serializers.Serializer):
) )
def validate_supplier(self, company): def validate_supplier(self, company):
"""Validation for the provided Supplier""" """Validation for the provided Supplier."""
if company and not company.is_supplier: if company and not company.is_supplier:
raise serializers.ValidationError( raise serializers.ValidationError(
_('Selected company is not a valid supplier') _('Selected company is not a valid supplier')
@ -488,7 +488,7 @@ class InitialSupplierSerializer(serializers.Serializer):
return company return company
def validate_manufacturer(self, company): def validate_manufacturer(self, company):
"""Validation for the provided Manufacturer""" """Validation for the provided Manufacturer."""
if company and not company.is_manufacturer: if company and not company.is_manufacturer:
raise serializers.ValidationError( raise serializers.ValidationError(
_('Selected company is not a valid manufacturer') _('Selected company is not a valid manufacturer')
@ -497,7 +497,7 @@ class InitialSupplierSerializer(serializers.Serializer):
return company return company
def validate(self, data): def validate(self, data):
"""Extra validation for this serializer""" """Extra validation for this serializer."""
if company.models.ManufacturerPart.objects.filter( if company.models.ManufacturerPart.objects.filter(
manufacturer=data.get('manufacturer', None), MPN=data.get('mpn', '') manufacturer=data.get('manufacturer', None), MPN=data.get('mpn', '')
).exists(): ).exists():
@ -525,7 +525,7 @@ class PartSerializer(
""" """
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = Part model = Part
partial = True partial = True
@ -593,7 +593,7 @@ class PartSerializer(
tags = TagListSerializerField(required=False) tags = TagListSerializerField(required=False)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Custom initialization method for PartSerializer: """Custom initialization method for PartSerializer.
- Allows us to optionally pass extra fields based on the query. - Allows us to optionally pass extra fields based on the query.
""" """
@ -628,11 +628,11 @@ class PartSerializer(
self.fields.pop('pricing_max') self.fields.pop('pricing_max')
def get_api_url(self): def get_api_url(self):
"""Return the API url associated with this serializer""" """Return the API url associated with this serializer."""
return reverse_lazy('api-part-list') return reverse_lazy('api-part-list')
def skip_create_fields(self): def skip_create_fields(self):
"""Skip these fields when instantiating a new Part instance""" """Skip these fields when instantiating a new Part instance."""
fields = super().skip_create_fields() fields = super().skip_create_fields()
fields += [ fields += [
@ -811,7 +811,7 @@ class PartSerializer(
) )
def validate_existing_image(self, img): def validate_existing_image(self, img):
"""Validate the selected image file""" """Validate the selected image file."""
if not img: if not img:
return img return img
@ -827,7 +827,7 @@ class PartSerializer(
@transaction.atomic @transaction.atomic
def create(self, validated_data): def create(self, validated_data):
"""Custom method for creating a new Part instance using this serializer""" """Custom method for creating a new Part instance using this serializer."""
duplicate = validated_data.pop('duplicate', None) duplicate = validated_data.pop('duplicate', None)
initial_stock = validated_data.pop('initial_stock', None) initial_stock = validated_data.pop('initial_stock', None)
initial_supplier = validated_data.pop('initial_supplier', None) initial_supplier = validated_data.pop('initial_supplier', None)
@ -925,7 +925,7 @@ class PartSerializer(
return instance return instance
def save(self): def save(self):
"""Save the Part instance""" """Save the Part instance."""
super().save() super().save()
part = self.instance part = self.instance
@ -956,10 +956,10 @@ class PartSerializer(
class PartStocktakeSerializer(InvenTree.serializers.InvenTreeModelSerializer): class PartStocktakeSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for the PartStocktake model""" """Serializer for the PartStocktake model."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options."""
model = PartStocktake model = PartStocktake
fields = [ fields = [
@ -992,7 +992,7 @@ class PartStocktakeSerializer(InvenTree.serializers.InvenTreeModelSerializer):
cost_max_currency = InvenTree.serializers.InvenTreeCurrencySerializer() cost_max_currency = InvenTree.serializers.InvenTreeCurrencySerializer()
def save(self): def save(self):
"""Called when this serializer is saved""" """Called when this serializer is saved."""
data = self.validated_data data = self.validated_data
# Add in user information automatically # Add in user information automatically
@ -1003,10 +1003,10 @@ class PartStocktakeSerializer(InvenTree.serializers.InvenTreeModelSerializer):
class PartStocktakeReportSerializer(InvenTree.serializers.InvenTreeModelSerializer): class PartStocktakeReportSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for stocktake report class""" """Serializer for stocktake report class."""
class Meta: class Meta:
"""Metaclass defines serializer fields""" """Metaclass defines serializer fields."""
model = PartStocktakeReport model = PartStocktakeReport
fields = ['pk', 'date', 'report', 'part_count', 'user', 'user_detail'] fields = ['pk', 'date', 'report', 'part_count', 'user', 'user_detail']
@ -1019,7 +1019,7 @@ class PartStocktakeReportSerializer(InvenTree.serializers.InvenTreeModelSerializ
class PartStocktakeReportGenerateSerializer(serializers.Serializer): class PartStocktakeReportGenerateSerializer(serializers.Serializer):
"""Serializer class for manually generating a new PartStocktakeReport via the API""" """Serializer class for manually generating a new PartStocktakeReport via the API."""
part = serializers.PrimaryKeyRelatedField( part = serializers.PrimaryKeyRelatedField(
queryset=Part.objects.all(), queryset=Part.objects.all(),
@ -1070,7 +1070,7 @@ class PartStocktakeReportGenerateSerializer(serializers.Serializer):
) )
def validate(self, data): def validate(self, data):
"""Custom validation for this serializer""" """Custom validation for this serializer."""
# Stocktake functionality must be enabled # Stocktake functionality must be enabled
if not common.models.InvenTreeSetting.get_setting('STOCKTAKE_ENABLE', False): if not common.models.InvenTreeSetting.get_setting('STOCKTAKE_ENABLE', False):
raise serializers.ValidationError( raise serializers.ValidationError(
@ -1084,7 +1084,7 @@ class PartStocktakeReportGenerateSerializer(serializers.Serializer):
return data return data
def save(self): def save(self):
"""Saving this serializer instance requests generation of a new stocktake report""" """Saving this serializer instance requests generation of a new stocktake report."""
data = self.validated_data data = self.validated_data
user = self.context['request'].user user = self.context['request'].user
@ -1103,10 +1103,10 @@ class PartStocktakeReportGenerateSerializer(serializers.Serializer):
class PartPricingSerializer(InvenTree.serializers.InvenTreeModelSerializer): class PartPricingSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for Part pricing information""" """Serializer for Part pricing information."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartPricing model = PartPricing
fields = [ fields = [
@ -1239,8 +1239,7 @@ class PartPricingSerializer(InvenTree.serializers.InvenTreeModelSerializer):
) )
def validate(self, data): def validate(self, data):
"""Validate supplied pricing data""" """Validate supplied pricing data."""
super().validate(data) super().validate(data)
# Check that override_min is not greater than override_max # Check that override_min is not greater than override_max
@ -1273,8 +1272,7 @@ class PartPricingSerializer(InvenTree.serializers.InvenTreeModelSerializer):
return data return data
def save(self): def save(self):
"""Called when the serializer is saved""" """Called when the serializer is saved."""
super().save() super().save()
# Update part pricing # Update part pricing
@ -1286,7 +1284,7 @@ class PartRelationSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for a PartRelated model.""" """Serializer for a PartRelated model."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartRelated model = PartRelated
fields = ['pk', 'part_1', 'part_1_detail', 'part_2', 'part_2_detail'] fields = ['pk', 'part_1', 'part_1_detail', 'part_2', 'part_2_detail']
@ -1299,7 +1297,7 @@ class PartStarSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for a PartStar object.""" """Serializer for a PartStar object."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartStar model = PartStar
fields = ['pk', 'part', 'partname', 'user', 'username'] fields = ['pk', 'part', 'partname', 'user', 'username']
@ -1312,7 +1310,7 @@ class BomItemSubstituteSerializer(InvenTree.serializers.InvenTreeModelSerializer
"""Serializer for the BomItemSubstitute class.""" """Serializer for the BomItemSubstitute class."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = BomItemSubstitute model = BomItemSubstitute
fields = ['pk', 'bom_item', 'part', 'part_detail'] fields = ['pk', 'bom_item', 'part', 'part_detail']
@ -1326,7 +1324,7 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""Serializer for BomItem object.""" """Serializer for BomItem object."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = BomItem model = BomItem
fields = [ fields = [
@ -1358,7 +1356,7 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer):
] ]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Determine if extra detail fields are to be annotated on this serializer """Determine if extra detail fields are to be annotated on this serializer.
- part_detail and sub_part_detail serializers are only included if requested. - part_detail and sub_part_detail serializers are only included if requested.
- This saves a bunch of database requests - This saves a bunch of database requests
@ -1382,7 +1380,7 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer):
quantity = InvenTree.serializers.InvenTreeDecimalField(required=True) quantity = InvenTree.serializers.InvenTreeDecimalField(required=True)
def validate_quantity(self, quantity): def validate_quantity(self, quantity):
"""Perform validation for the BomItem quantity field""" """Perform validation for the BomItem quantity field."""
if quantity <= 0: if quantity <= 0:
raise serializers.ValidationError(_('Quantity must be greater than zero')) raise serializers.ValidationError(_('Quantity must be greater than zero'))
@ -1420,7 +1418,7 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer):
@staticmethod @staticmethod
def setup_eager_loading(queryset): def setup_eager_loading(queryset):
"""Prefetch against the provided queryset to speed up database access""" """Prefetch against the provided queryset to speed up database access."""
queryset = queryset.prefetch_related('part') queryset = queryset.prefetch_related('part')
queryset = queryset.prefetch_related('part__category') queryset = queryset.prefetch_related('part__category')
queryset = queryset.prefetch_related('part__stock_items') queryset = queryset.prefetch_related('part__stock_items')
@ -1444,7 +1442,7 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer):
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
"""Annotate the BomItem queryset with extra information: """Annotate the BomItem queryset with extra information.
Annotations: Annotations:
available_stock: The amount of stock available for the sub_part Part object available_stock: The amount of stock available for the sub_part Part object
@ -1552,7 +1550,7 @@ class CategoryParameterTemplateSerializer(
"""Serializer for the PartCategoryParameterTemplate model.""" """Serializer for the PartCategoryParameterTemplate model."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
model = PartCategoryParameterTemplate model = PartCategoryParameterTemplate
fields = [ fields = [
@ -1575,7 +1573,7 @@ class PartCopyBOMSerializer(serializers.Serializer):
"""Serializer for copying a BOM from another part.""" """Serializer for copying a BOM from another part."""
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
fields = [ fields = [
'part', 'part',
@ -1643,7 +1641,7 @@ class BomImportUploadSerializer(InvenTree.serializers.DataFileUploadSerializer):
TARGET_MODEL = BomItem TARGET_MODEL = BomItem
class Meta: class Meta:
"""Metaclass defining serializer fields""" """Metaclass defining serializer fields."""
fields = ['data_file', 'part', 'clear_existing_bom'] fields = ['data_file', 'part', 'clear_existing_bom']
@ -1657,7 +1655,7 @@ class BomImportUploadSerializer(InvenTree.serializers.DataFileUploadSerializer):
) )
def save(self): def save(self):
"""The uploaded data file has been validated, accept the submitted data""" """The uploaded data file has been validated, accept the submitted data."""
data = self.validated_data data = self.validated_data
if data.get('clear_existing_bom', False): if data.get('clear_existing_bom', False):
@ -1676,7 +1674,7 @@ class BomImportExtractSerializer(InvenTree.serializers.DataFileExtractSerializer
TARGET_MODEL = BomItem TARGET_MODEL = BomItem
def validate_extracted_columns(self): def validate_extracted_columns(self):
"""Validate that the extracted columns are correct""" """Validate that the extracted columns are correct."""
super().validate_extracted_columns() super().validate_extracted_columns()
part_columns = ['part', 'part_name', 'part_ipn', 'part_id'] part_columns = ['part', 'part_name', 'part_ipn', 'part_id']
@ -1687,7 +1685,7 @@ class BomImportExtractSerializer(InvenTree.serializers.DataFileExtractSerializer
@staticmethod @staticmethod
def process_row(row): def process_row(row):
"""Process a single row from the loaded BOM file""" """Process a single row from the loaded BOM file."""
# Skip any rows which are at a lower "level" # Skip any rows which are at a lower "level"
level = row.get('level', None) level = row.get('level', None)
@ -1764,9 +1762,9 @@ class BomImportSubmitSerializer(serializers.Serializer):
items = BomItemSerializer(many=True, required=True) items = BomItemSerializer(many=True, required=True)
def validate(self, data): def validate(self, data):
"""Validate the submitted BomItem data: """Validate the submitted BomItem data.
- At least one line (BomItem) is required At least one line (BomItem) is required
""" """
items = data['items'] items = data['items']
@ -1778,8 +1776,9 @@ class BomImportSubmitSerializer(serializers.Serializer):
return data return data
def save(self): def save(self):
"""POST: Perform final save of submitted BOM data: """POST: Perform final save of submitted BOM data.
Actions:
- By this stage each line in the BOM has been validated - By this stage each line in the BOM has been validated
- Individually 'save' (create) each BomItem line - Individually 'save' (create) each BomItem line
""" """

View File

@ -1,4 +1,4 @@
"""Stocktake report functionality""" """Stocktake report functionality."""
import io import io
import logging import logging
@ -28,8 +28,9 @@ def perform_stocktake(
Arguments: Arguments:
target: A single Part model instance target: A single Part model instance
commit: If True (default) save the result to the database
user: User who requested this stocktake user: User who requested this stocktake
note: Optional note to attach to the stocktake
commit: If True (default) save the result to the database
kwargs: kwargs:
exclude_external: If True, exclude stock items in external locations (default = False) exclude_external: If True, exclude stock items in external locations (default = False)

View File

@ -1,4 +1,4 @@
"""Background task definitions for the 'part' app""" """Background task definitions for the 'part' app."""
import logging import logging
import random import random
@ -28,8 +28,9 @@ logger = logging.getLogger('inventree')
def notify_low_stock(part: part.models.Part): def notify_low_stock(part: part.models.Part):
"""Notify interested users that a part is 'low stock': """Notify interested users that a part is 'low stock'.
Rules:
- Triggered when the available stock for a given part falls be low the configured threhsold - Triggered when the available stock for a given part falls be low the configured threhsold
- A notification is delivered to any users who are 'subscribed' to this part - A notification is delivered to any users who are 'subscribed' to this part
""" """
@ -64,7 +65,7 @@ def notify_low_stock_if_required(part: part.models.Part):
def update_part_pricing(pricing: part.models.PartPricing, counter: int = 0): def update_part_pricing(pricing: part.models.PartPricing, counter: int = 0):
"""Update cached pricing data for the specified PartPricing instance """Update cached pricing data for the specified PartPricing instance.
Arguments: Arguments:
pricing: The target PartPricing instance to be updated pricing: The target PartPricing instance to be updated
@ -77,8 +78,9 @@ def update_part_pricing(pricing: part.models.PartPricing, counter: int = 0):
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def check_missing_pricing(limit=250): def check_missing_pricing(limit=250):
"""Check for parts with missing or outdated pricing information: """Check for parts with missing or outdated pricing information.
Tests for the following conditions:
- Pricing information does not exist - Pricing information does not exist
- Pricing information is "old" - Pricing information is "old"
- Pricing information is in the wrong currency - Pricing information is in the wrong currency

View File

@ -1 +1 @@
"""Custom InvenTree template tags for HTML template rendering""" """Custom InvenTree template tags for HTML template rendering."""

View File

@ -17,7 +17,7 @@ register = template.Library()
@register.simple_tag() @register.simple_tag()
def translation_stats(lang_code): def translation_stats(lang_code):
"""Return the translation percentage for the given language code""" """Return the translation percentage for the given language code."""
if lang_code is None: if lang_code is None:
return None return None
@ -25,10 +25,10 @@ def translation_stats(lang_code):
class CustomTranslateNode(TranslateNode): class CustomTranslateNode(TranslateNode):
"""Custom translation node class, which sanitizes the translated strings for javascript use""" """Custom translation node class, which sanitizes the translated strings for javascript use."""
def render(self, context): def render(self, context):
"""Custom render function overrides / extends default behaviour""" """Custom render function overrides / extends default behaviour."""
result = super().render(context) result = super().render(context)
result = bleach.clean(result) result = bleach.clean(result)
@ -52,7 +52,7 @@ class CustomTranslateNode(TranslateNode):
@register.tag('translate') @register.tag('translate')
@register.tag('trans') @register.tag('trans')
def do_translate(parser, token): def do_translate(parser, token):
"""Custom translation function, lifted from https://github.com/django/django/blob/main/django/templatetags/i18n.py """Custom translation function, lifted from https://github.com/django/django/blob/main/django/templatetags/i18n.py.
The only difference is that we pass this to our custom rendering node class The only difference is that we pass this to our custom rendering node class
""" """

View File

@ -98,7 +98,7 @@ def render_date(context, date_object):
@register.simple_tag @register.simple_tag
def render_currency(money, **kwargs): def render_currency(money, **kwargs):
"""Render a currency / Money object""" """Render a currency / Money object."""
return InvenTree.helpers_model.render_currency(money, **kwargs) return InvenTree.helpers_model.render_currency(money, **kwargs)
@ -122,7 +122,7 @@ def to_list(*args):
@register.simple_tag() @register.simple_tag()
def part_allocation_count(build, part, *args, **kwargs): def part_allocation_count(build, part, *args, **kwargs):
"""Return the total number of <part> allocated to <build>""" """Return the total number of <part> allocated to <build>."""
return InvenTree.helpers.decimal2string(build.getAllocatedQuantity(part)) return InvenTree.helpers.decimal2string(build.getAllocatedQuantity(part))
@ -185,7 +185,7 @@ def inventree_instance_name(*args, **kwargs):
@register.simple_tag() @register.simple_tag()
def inventree_title(*args, **kwargs): def inventree_title(*args, **kwargs):
"""Return the title for the current instance - respecting the settings""" """Return the title for the current instance - respecting the settings."""
return version.inventreeInstanceTitle() return version.inventreeInstanceTitle()
@ -206,7 +206,7 @@ def inventree_splash(**kwargs):
@register.simple_tag() @register.simple_tag()
def inventree_base_url(*args, **kwargs): def inventree_base_url(*args, **kwargs):
"""Return the base URL of the InvenTree server""" """Return the base URL of the InvenTree server."""
return InvenTree.helpers_model.get_base_url() return InvenTree.helpers_model.get_base_url()
@ -226,19 +226,19 @@ def inventree_version(shortstring=False, *args, **kwargs):
@register.simple_tag() @register.simple_tag()
def inventree_is_development(*args, **kwargs): def inventree_is_development(*args, **kwargs):
"""Returns True if this is a development version of InvenTree""" """Returns True if this is a development version of InvenTree."""
return version.isInvenTreeDevelopmentVersion() return version.isInvenTreeDevelopmentVersion()
@register.simple_tag() @register.simple_tag()
def inventree_is_release(*args, **kwargs): def inventree_is_release(*args, **kwargs):
"""Returns True if this is a release version of InvenTree""" """Returns True if this is a release version of InvenTree."""
return not version.isInvenTreeDevelopmentVersion() return not version.isInvenTreeDevelopmentVersion()
@register.simple_tag() @register.simple_tag()
def inventree_docs_version(*args, **kwargs): def inventree_docs_version(*args, **kwargs):
"""Returns the InvenTree documentation version""" """Returns the InvenTree documentation version."""
return version.inventreeDocsVersion() return version.inventreeDocsVersion()
@ -429,7 +429,7 @@ def progress_bar(val, max_val, *args, **kwargs):
@register.simple_tag() @register.simple_tag()
def get_color_theme_css(username): def get_color_theme_css(username):
"""Return the custom theme .css file for the selected user""" """Return the custom theme .css file for the selected user."""
user_theme_name = get_user_color_theme(username) user_theme_name = get_user_color_theme(username)
# Build path to CSS sheet # Build path to CSS sheet
inventree_css_sheet = os.path.join('css', 'color-themes', user_theme_name + '.css') inventree_css_sheet = os.path.join('css', 'color-themes', user_theme_name + '.css')
@ -443,7 +443,6 @@ def get_color_theme_css(username):
@register.simple_tag() @register.simple_tag()
def get_user_color_theme(username): def get_user_color_theme(username):
"""Get current user color theme.""" """Get current user color theme."""
from common.models import ColorTheme from common.models import ColorTheme
try: try:
@ -488,7 +487,7 @@ def primitive_to_javascript(primitive):
@register.simple_tag() @register.simple_tag()
def js_bool(val): def js_bool(val):
"""Return a javascript boolean value (true or false)""" """Return a javascript boolean value (true or false)."""
if val: if val:
return 'true' return 'true'
return 'false' return 'false'
@ -599,14 +598,14 @@ if settings.DEBUG:
@register.simple_tag() @register.simple_tag()
def i18n_static(url_name): def i18n_static(url_name):
"""Simple tag to enable {% url %} functionality instead of {% static %}""" """Simple tag to enable {% url %} functionality instead of {% static %}."""
return reverse(url_name) return reverse(url_name)
else: # pragma: no cover else: # pragma: no cover
@register.tag('i18n_static') @register.tag('i18n_static')
def do_i18n_static(parser, token): def do_i18n_static(parser, token):
"""Overrides normal static, adds language - lookup for prerenderd files #1485 """Overrides normal static, adds language - lookup for prerenderd files #1485.
Usage (like static): Usage (like static):
{% i18n_static path [as varname] %} {% i18n_static path [as varname] %}
@ -623,8 +622,7 @@ else: # pragma: no cover
@register.simple_tag() @register.simple_tag()
def admin_index(user): def admin_index(user):
"""Return a URL for the admin interface""" """Return a URL for the admin interface."""
if not djangosettings.INVENTREE_ADMIN_ENABLED: if not djangosettings.INVENTREE_ADMIN_ENABLED:
return '' return ''
@ -642,7 +640,6 @@ def admin_url(user, table, pk):
- If the user is not a staff user, an empty URL is returned - If the user is not a staff user, an empty URL is returned
- If the user does not have the correct permission, an empty URL is returned - If the user does not have the correct permission, an empty URL is returned
""" """
app, model = table.strip().split('.') app, model = table.strip().split('.')
from django.urls import reverse from django.urls import reverse

View File

@ -1,4 +1,4 @@
"""This module provides template tags pertaining to SSO functionality""" """This module provides template tags pertaining to SSO functionality."""
from django import template from django import template
@ -9,24 +9,23 @@ register = template.Library()
@register.simple_tag() @register.simple_tag()
def sso_login_enabled(): def sso_login_enabled():
"""Return True if single-sign-on is enabled""" """Return True if single-sign-on is enabled."""
return InvenTree.sso.login_enabled() return InvenTree.sso.login_enabled()
@register.simple_tag() @register.simple_tag()
def sso_reg_enabled(): def sso_reg_enabled():
"""Return True if single-sign-on is enabled for self-registration""" """Return True if single-sign-on is enabled for self-registration."""
return InvenTree.sso.registration_enabled() return InvenTree.sso.registration_enabled()
@register.simple_tag() @register.simple_tag()
def sso_auto_enabled(): def sso_auto_enabled():
"""Return True if single-sign-on is enabled for auto-registration""" """Return True if single-sign-on is enabled for auto-registration."""
return InvenTree.sso.auto_registration_enabled() return InvenTree.sso.auto_registration_enabled()
@register.simple_tag() @register.simple_tag()
def sso_check_provider(provider): def sso_check_provider(provider):
"""Return True if the given provider is correctly configured""" """Return True if the given provider is correctly configured."""
return InvenTree.sso.check_provider(provider) return InvenTree.sso.check_provider(provider)

View File

@ -1,4 +1,4 @@
"""Unit tests for the various part API endpoints""" """Unit tests for the various part API endpoints."""
import os import os
from datetime import datetime from datetime import datetime
@ -64,7 +64,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
] ]
def test_category_list(self): def test_category_list(self):
"""Test the PartCategoryList API endpoint""" """Test the PartCategoryList API endpoint."""
url = reverse('api-part-category-list') url = reverse('api-part-category-list')
# star categories manually for tests as it is not possible with fixures # star categories manually for tests as it is not possible with fixures
@ -167,7 +167,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
) )
def test_part_count(self): def test_part_count(self):
"""Test that the 'part_count' field is annotated correctly""" """Test that the 'part_count' field is annotated correctly."""
url = reverse('api-part-category-list') url = reverse('api-part-category-list')
# Create a parent category # Create a parent category
@ -206,7 +206,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
self.assertEqual(response.data['part_count'], 100) self.assertEqual(response.data['part_count'], 100)
def test_category_parameters(self): def test_category_parameters(self):
"""Test that the PartCategoryParameterTemplate API function work""" """Test that the PartCategoryParameterTemplate API function work."""
url = reverse('api-part-category-parameter-list') url = reverse('api-part-category-parameter-list')
response = self.get(url, {}, expected_code=200) response = self.get(url, {}, expected_code=200)
@ -288,7 +288,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
self.assertEqual(response.data['description'], val) self.assertEqual(response.data['description'], val)
def test_invisible_chars(self): def test_invisible_chars(self):
"""Test that invisible characters are removed from the input data""" """Test that invisible characters are removed from the input data."""
url = reverse('api-part-category-detail', kwargs={'pk': 1}) url = reverse('api-part-category-detail', kwargs={'pk': 1})
values = [ values = [
@ -304,7 +304,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
self.assertEqual(response.data['description'], 'A part category') self.assertEqual(response.data['description'], 'A part category')
def test_category_delete(self): def test_category_delete(self):
"""Test category deletion with different parameters""" """Test category deletion with different parameters."""
class Target(IntEnum): class Target(IntEnum):
move_subcategories_to_parent_move_parts_to_parent = (0,) move_subcategories_to_parent_move_parts_to_parent = (0,)
@ -419,7 +419,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
self.assertEqual(child.parent, parent_category) self.assertEqual(child.parent, parent_category)
def test_structural(self): def test_structural(self):
"""Test the effectiveness of structural categories """Test the effectiveness of structural categories.
Make sure: Make sure:
- Parts cannot be created in structural categories - Parts cannot be created in structural categories
@ -471,7 +471,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase):
self.assertEqual(part.category.pk, non_structural_category.pk) self.assertEqual(part.category.pk, non_structural_category.pk)
def test_path_detail(self): def test_path_detail(self):
"""Test path_detail information""" """Test path_detail information."""
url = reverse('api-part-category-detail', kwargs={'pk': 5}) url = reverse('api-part-category-detail', kwargs={'pk': 5})
# First, request without path detail # First, request without path detail
@ -578,7 +578,7 @@ class PartOptionsAPITest(InvenTreeAPITestCase):
class PartAPITestBase(InvenTreeAPITestCase): class PartAPITestBase(InvenTreeAPITestCase):
"""Base class for running tests on the Part API endpoints""" """Base class for running tests on the Part API endpoints."""
fixtures = [ fixtures = [
'category', 'category',
@ -669,7 +669,7 @@ class PartAPITest(PartAPITestBase):
self.assertEqual(len(response.data), 12) self.assertEqual(len(response.data), 12)
def test_cat_detail(self): def test_cat_detail(self):
"""Test the PartCategoryDetail API endpoint""" """Test the PartCategoryDetail API endpoint."""
url = reverse('api-part-category-detail', kwargs={'pk': 4}) url = reverse('api-part-category-detail', kwargs={'pk': 4})
response = self.get(url) response = self.get(url)
@ -688,7 +688,7 @@ class PartAPITest(PartAPITestBase):
self.assertIsNone(response.data['parent']) self.assertIsNone(response.data['parent'])
def test_filter_parts(self): def test_filter_parts(self):
"""Test part filtering using the API""" """Test part filtering using the API."""
url = reverse('api-part-list') url = reverse('api-part-list')
data = {'cascade': True} data = {'cascade': True}
response = self.get(url, data) response = self.get(url, data)
@ -707,7 +707,7 @@ class PartAPITest(PartAPITestBase):
self.assertEqual(part['category'], 2) self.assertEqual(part['category'], 2)
def test_filter_by_in_bom(self): def test_filter_by_in_bom(self):
"""Test that we can filter part list by the 'in_bom_for' parameter""" """Test that we can filter part list by the 'in_bom_for' parameter."""
url = reverse('api-part-list') url = reverse('api-part-list')
response = self.get(url, {'in_bom_for': 100}, expected_code=200) response = self.get(url, {'in_bom_for': 100}, expected_code=200)
@ -715,7 +715,7 @@ class PartAPITest(PartAPITestBase):
self.assertEqual(len(response.data), 4) self.assertEqual(len(response.data), 4)
def test_filter_by_related(self): def test_filter_by_related(self):
"""Test that we can filter by the 'related' status""" """Test that we can filter by the 'related' status."""
url = reverse('api-part-list') url = reverse('api-part-list')
# Initially there are no relations, so this should return zero results # Initially there are no relations, so this should return zero results
@ -735,7 +735,7 @@ class PartAPITest(PartAPITestBase):
self.assertEqual(len(response.data), 2) self.assertEqual(len(response.data), 2)
def test_filter_by_convert(self): def test_filter_by_convert(self):
"""Test that we can correctly filter the Part list by conversion options""" """Test that we can correctly filter the Part list by conversion options."""
category = PartCategory.objects.get(pk=3) category = PartCategory.objects.get(pk=3)
# First, construct a set of template / variant parts # First, construct a set of template / variant parts
@ -799,7 +799,7 @@ class PartAPITest(PartAPITestBase):
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
def test_test_templates(self): def test_test_templates(self):
"""Test the PartTestTemplate API""" """Test the PartTestTemplate API."""
url = reverse('api-part-test-template-list') url = reverse('api-part-test-template-list')
# List ALL items # List ALL items
@ -870,8 +870,9 @@ class PartAPITest(PartAPITestBase):
self.assertEqual(len(data['results']), n) self.assertEqual(len(data['results']), n)
def test_template_filters(self): def test_template_filters(self):
"""Unit tests for API filters related to template parts: """Unit tests for API filters related to template parts.
Test:
- variant_of : Return children of specified part - variant_of : Return children of specified part
- ancestor : Return descendants of specified part - ancestor : Return descendants of specified part
@ -1029,8 +1030,7 @@ class PartAPITest(PartAPITestBase):
self.assertEqual(part.category.name, row['Category Name']) self.assertEqual(part.category.name, row['Category Name'])
def test_date_filters(self): def test_date_filters(self):
"""Test that the creation date filters work correctly""" """Test that the creation date filters work correctly."""
url = reverse('api-part-list') url = reverse('api-part-list')
response = self.get(url) response = self.get(url)
@ -1064,10 +1064,10 @@ class PartAPITest(PartAPITestBase):
class PartCreationTests(PartAPITestBase): class PartCreationTests(PartAPITestBase):
"""Tests for creating new Part instances via the API""" """Tests for creating new Part instances via the API."""
def test_default_values(self): def test_default_values(self):
"""Tests for 'default' values: """Tests for 'default' values.
Ensure that unspecified fields revert to "default" values Ensure that unspecified fields revert to "default" values
(as specified in the model field definition) (as specified in the model field definition)
@ -1121,7 +1121,7 @@ class PartCreationTests(PartAPITestBase):
"""Tests for initial stock quantity creation.""" """Tests for initial stock quantity creation."""
def submit(stock_data, expected_code=None): def submit(stock_data, expected_code=None):
"""Helper function for submitting with initial stock data""" """Helper function for submitting with initial stock data."""
data = { data = {
'category': 1, 'category': 1,
'name': "My lil' test part", 'name': "My lil' test part",
@ -1162,7 +1162,7 @@ class PartCreationTests(PartAPITestBase):
"""Tests for initial creation of supplier / manufacturer data.""" """Tests for initial creation of supplier / manufacturer data."""
def submit(supplier_data, expected_code=400): def submit(supplier_data, expected_code=400):
"""Helper function for submitting with supplier data""" """Helper function for submitting with supplier data."""
data = { data = {
'name': 'My test part', 'name': 'My test part',
'description': 'A test part thingy', 'description': 'A test part thingy',
@ -1248,7 +1248,7 @@ class PartCreationTests(PartAPITestBase):
self.assertEqual(response.data['description'], description) self.assertEqual(response.data['description'], description)
def test_duplication(self): def test_duplication(self):
"""Test part duplication options""" """Test part duplication options."""
# Run a matrix of tests # Run a matrix of tests
for bom in [True, False]: for bom in [True, False]:
for img in [True, False]: for img in [True, False]:
@ -1276,7 +1276,7 @@ class PartCreationTests(PartAPITestBase):
self.assertEqual(part.parameters.count(), 2 if params else 0) self.assertEqual(part.parameters.count(), 2 if params else 0)
def test_category_parameters(self): def test_category_parameters(self):
"""Test that category parameters are correctly applied""" """Test that category parameters are correctly applied."""
cat = PartCategory.objects.get(pk=1) cat = PartCategory.objects.get(pk=1)
# Add some parameter template to the parent category # Add some parameter template to the parent category
@ -1325,7 +1325,7 @@ class PartDetailTests(PartAPITestBase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Custom setup routine for this class""" """Custom setup routine for this class."""
super().setUpTestData() super().setUpTestData()
# Create a custom APIClient for file uploads # Create a custom APIClient for file uploads
@ -1334,7 +1334,7 @@ class PartDetailTests(PartAPITestBase):
cls.upload_client.force_authenticate(user=cls.user) cls.upload_client.force_authenticate(user=cls.user)
def test_part_operations(self): def test_part_operations(self):
"""Test that Part instances can be adjusted via the API""" """Test that Part instances can be adjusted via the API."""
n = Part.objects.count() n = Part.objects.count()
# Create a part # Create a part
@ -1530,8 +1530,7 @@ class PartDetailTests(PartAPITestBase):
self.assertIsNotNone(p.image) self.assertIsNotNone(p.image)
def test_existing_image(self): def test_existing_image(self):
"""Test that we can allocate an existing uploaded image to a new Part""" """Test that we can allocate an existing uploaded image to a new Part."""
# First, upload an image for an existing part # First, upload an image for an existing part
p = Part.objects.first() p = Part.objects.first()
@ -1662,7 +1661,7 @@ class PartDetailTests(PartAPITestBase):
self.assertEqual(data['unallocated_stock'], 9000) self.assertEqual(data['unallocated_stock'], 9000)
def test_path_detail(self): def test_path_detail(self):
"""Check that path_detail can be requested against the serializer""" """Check that path_detail can be requested against the serializer."""
response = self.get( response = self.get(
reverse('api-part-detail', kwargs={'pk': 1}), reverse('api-part-detail', kwargs={'pk': 1}),
{'path_detail': True}, {'path_detail': True},
@ -1674,10 +1673,10 @@ class PartDetailTests(PartAPITestBase):
class PartListTests(PartAPITestBase): class PartListTests(PartAPITestBase):
"""Unit tests for the Part List API endpoint""" """Unit tests for the Part List API endpoint."""
def test_query_count(self): def test_query_count(self):
"""Test that the query count is unchanged, independent of query results""" """Test that the query count is unchanged, independent of query results."""
queries = [{'limit': 1}, {'limit': 10}, {'limit': 50}, {'category': 1}, {}] queries = [{'limit': 1}, {'limit': 10}, {'limit': 50}, {'category': 1}, {}]
url = reverse('api-part-list') url = reverse('api-part-list')
@ -1722,14 +1721,14 @@ class PartListTests(PartAPITestBase):
class PartNotesTests(InvenTreeAPITestCase): class PartNotesTests(InvenTreeAPITestCase):
"""Tests for the 'notes' field (markdown field)""" """Tests for the 'notes' field (markdown field)."""
fixtures = ['category', 'part', 'location', 'company'] fixtures = ['category', 'part', 'location', 'company']
roles = ['part.change', 'part.add'] roles = ['part.change', 'part.add']
def test_long_notes(self): def test_long_notes(self):
"""Test that very long notes field is rejected""" """Test that very long notes field is rejected."""
# Ensure that we cannot upload a very long piece of text # Ensure that we cannot upload a very long piece of text
url = reverse('api-part-detail', kwargs={'pk': 1}) url = reverse('api-part-detail', kwargs={'pk': 1})
@ -1741,7 +1740,7 @@ class PartNotesTests(InvenTreeAPITestCase):
) )
def test_multiline_formatting(self): def test_multiline_formatting(self):
"""Ensure that markdown formatting is retained""" """Ensure that markdown formatting is retained."""
url = reverse('api-part-detail', kwargs={'pk': 1}) url = reverse('api-part-detail', kwargs={'pk': 1})
notes = """ notes = """
@ -1765,18 +1764,18 @@ class PartNotesTests(InvenTreeAPITestCase):
class PartPricingDetailTests(InvenTreeAPITestCase): class PartPricingDetailTests(InvenTreeAPITestCase):
"""Tests for the part pricing API endpoint""" """Tests for the part pricing API endpoint."""
fixtures = ['category', 'part', 'location'] fixtures = ['category', 'part', 'location']
roles = ['part.change'] roles = ['part.change']
def url(self, pk): def url(self, pk):
"""Construct a pricing URL""" """Construct a pricing URL."""
return reverse('api-part-pricing', kwargs={'pk': pk}) return reverse('api-part-pricing', kwargs={'pk': pk})
def test_pricing_detail(self): def test_pricing_detail(self):
"""Test an empty pricing detail""" """Test an empty pricing detail."""
response = self.get(self.url(1), expected_code=200) response = self.get(self.url(1), expected_code=200)
# Check for expected fields # Check for expected fields
@ -1823,7 +1822,7 @@ class PartAPIAggregationTest(InvenTreeAPITestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Create test data as part of setup routine""" """Create test data as part of setup routine."""
super().setUpTestData() super().setUpTestData()
# Ensure the part "variant" tree is correctly structured # Ensure the part "variant" tree is correctly structured
@ -1849,7 +1848,7 @@ class PartAPIAggregationTest(InvenTreeAPITestCase):
) )
def get_part_data(self): def get_part_data(self):
"""Helper function for retrieving part data""" """Helper function for retrieving part data."""
url = reverse('api-part-list') url = reverse('api-part-list')
response = self.get(url) response = self.get(url)
@ -2119,7 +2118,7 @@ class BomItemTest(InvenTreeAPITestCase):
roles = ['part.add', 'part.change', 'part.delete'] roles = ['part.add', 'part.change', 'part.delete']
def setUp(self): def setUp(self):
"""Set up the test case""" """Set up the test case."""
super().setUp() super().setUp()
# Rebuild part tree so BOM items validate correctly # Rebuild part tree so BOM items validate correctly
@ -2166,7 +2165,7 @@ class BomItemTest(InvenTreeAPITestCase):
self.assertTrue(key in el) self.assertTrue(key in el)
def test_bom_list_search(self): def test_bom_list_search(self):
"""Test that we can search the BOM list API endpoint""" """Test that we can search the BOM list API endpoint."""
url = reverse('api-bom-list') url = reverse('api-bom-list')
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
@ -2191,7 +2190,7 @@ class BomItemTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 0) self.assertEqual(len(response.data), 0)
def test_bom_list_ordering(self): def test_bom_list_ordering(self):
"""Test that the BOM list results can be ordered""" """Test that the BOM list results can be ordered."""
url = reverse('api-bom-list') url = reverse('api-bom-list')
# Order by increasing quantity # Order by increasing quantity
@ -2227,7 +2226,6 @@ class BomItemTest(InvenTreeAPITestCase):
def test_get_bom_detail(self): def test_get_bom_detail(self):
"""Get the detail view for a single BomItem object.""" """Get the detail view for a single BomItem object."""
url = reverse('api-bom-item-detail', kwargs={'pk': 3}) url = reverse('api-bom-item-detail', kwargs={'pk': 3})
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
@ -2476,12 +2474,12 @@ class BomItemTest(InvenTreeAPITestCase):
class PartAttachmentTest(InvenTreeAPITestCase): class PartAttachmentTest(InvenTreeAPITestCase):
"""Unit tests for the PartAttachment API endpoint""" """Unit tests for the PartAttachment API endpoint."""
fixtures = ['category', 'part', 'location'] fixtures = ['category', 'part', 'location']
def test_add_attachment(self): def test_add_attachment(self):
"""Test that we can create a new PartAttachment via the API""" """Test that we can create a new PartAttachment via the API."""
url = reverse('api-part-attachment-list') url = reverse('api-part-attachment-list')
# Upload without permission # Upload without permission
@ -2525,7 +2523,7 @@ class PartAttachmentTest(InvenTreeAPITestCase):
class PartInternalPriceBreakTest(InvenTreeAPITestCase): class PartInternalPriceBreakTest(InvenTreeAPITestCase):
"""Unit tests for the PartInternalPrice API endpoints""" """Unit tests for the PartInternalPrice API endpoints."""
fixtures = [ fixtures = [
'category', 'category',
@ -2551,7 +2549,7 @@ class PartInternalPriceBreakTest(InvenTreeAPITestCase):
] ]
def test_create_price_breaks(self): def test_create_price_breaks(self):
"""Test we can create price breaks at various quantities""" """Test we can create price breaks at various quantities."""
url = reverse('api-part-internal-price-list') url = reverse('api-part-internal-price-list')
breaks = [ breaks = [
@ -2589,7 +2587,7 @@ class PartInternalPriceBreakTest(InvenTreeAPITestCase):
class PartStocktakeTest(InvenTreeAPITestCase): class PartStocktakeTest(InvenTreeAPITestCase):
"""Unit tests for the part stocktake functionality""" """Unit tests for the part stocktake functionality."""
superuser = False superuser = False
is_staff = False is_staff = False
@ -2597,7 +2595,7 @@ class PartStocktakeTest(InvenTreeAPITestCase):
fixtures = ['category', 'part', 'location', 'stock'] fixtures = ['category', 'part', 'location', 'stock']
def test_list_endpoint(self): def test_list_endpoint(self):
"""Test the list endpoint for the stocktake data""" """Test the list endpoint for the stocktake data."""
url = reverse('api-part-stocktake-list') url = reverse('api-part-stocktake-list')
self.assignRole('part.view') self.assignRole('part.view')
@ -2636,7 +2634,7 @@ class PartStocktakeTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), total) self.assertEqual(len(response.data), total)
def test_create_stocktake(self): def test_create_stocktake(self):
"""Test that stocktake entries can be created via the API""" """Test that stocktake entries can be created via the API."""
url = reverse('api-part-stocktake-list') url = reverse('api-part-stocktake-list')
self.assignRole('stocktake.add') self.assignRole('stocktake.add')
@ -2695,7 +2693,7 @@ class PartStocktakeTest(InvenTreeAPITestCase):
self.delete(url, expected_code=204) self.delete(url, expected_code=204)
def test_report_list(self): def test_report_list(self):
"""Test for PartStocktakeReport list endpoint""" """Test for PartStocktakeReport list endpoint."""
from part.stocktake import generate_stocktake_report from part.stocktake import generate_stocktake_report
# Initially, no stocktake records are available # Initially, no stocktake records are available
@ -2728,7 +2726,7 @@ class PartStocktakeTest(InvenTreeAPITestCase):
self.assertTrue(data['report'].endswith('.csv')) self.assertTrue(data['report'].endswith('.csv'))
def test_report_generate(self): def test_report_generate(self):
"""Test API functionality for generating a new stocktake report""" """Test API functionality for generating a new stocktake report."""
url = reverse('api-part-stocktake-report-generate') url = reverse('api-part-stocktake-report-generate')
# Permission denied, initially # Permission denied, initially
@ -2767,12 +2765,12 @@ class PartMetadataAPITest(InvenTreeAPITestCase):
roles = ['part.change', 'part_category.change'] roles = ['part.change', 'part_category.change']
def setUp(self): def setUp(self):
"""Setup unit tets""" """Setup unit tets."""
super().setUp() super().setUp()
Part.objects.rebuild() Part.objects.rebuild()
def metatester(self, apikey, model): def metatester(self, apikey, model):
"""Generic tester""" """Generic tester."""
modeldata = model.objects.first() modeldata = model.objects.first()
# Useless test unless a model object is found # Useless test unless a model object is found
@ -2798,7 +2796,7 @@ class PartMetadataAPITest(InvenTreeAPITestCase):
) )
def test_metadata(self): def test_metadata(self):
"""Test all endpoints""" """Test all endpoints."""
for apikey, model in { for apikey, model in {
'api-part-category-parameter-metadata': PartCategoryParameterTemplate, 'api-part-category-parameter-metadata': PartCategoryParameterTemplate,
'api-part-category-metadata': PartCategory, 'api-part-category-metadata': PartCategory,
@ -2814,10 +2812,10 @@ class PartMetadataAPITest(InvenTreeAPITestCase):
class PartSchedulingTest(PartAPITestBase): class PartSchedulingTest(PartAPITestBase):
"""Unit tests for the 'part scheduling' API endpoint""" """Unit tests for the 'part scheduling' API endpoint."""
def test_get_schedule(self): def test_get_schedule(self):
"""Test that the scheduling endpoint returns OK""" """Test that the scheduling endpoint returns OK."""
part_ids = [1, 3, 100, 101] part_ids = [1, 3, 100, 101]
for pk in part_ids: for pk in part_ids:

View File

@ -9,14 +9,14 @@ from InvenTree.unit_test import InvenTreeTestCase
class BomExportTest(InvenTreeTestCase): class BomExportTest(InvenTreeTestCase):
"""Class for performing unit testing of BOM export functionality""" """Class for performing unit testing of BOM export functionality."""
fixtures = ['category', 'part', 'location', 'bom'] fixtures = ['category', 'part', 'location', 'bom']
roles = 'all' roles = 'all'
def setUp(self): def setUp(self):
"""Perform test setup functions""" """Perform test setup functions."""
super().setUp() super().setUp()
part.models.Part.objects.rebuild() part.models.Part.objects.rebuild()

View File

@ -16,7 +16,7 @@ class BomUploadTest(InvenTreeAPITestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Create BOM data as part of setup routine""" """Create BOM data as part of setup routine."""
super().setUpTestData() super().setUpTestData()
Part.objects.rebuild() Part.objects.rebuild()
@ -55,7 +55,7 @@ class BomUploadTest(InvenTreeAPITestCase):
expected_code=None, expected_code=None,
content_type='text/plain', content_type='text/plain',
): ):
"""Helper function for submitting a BOM file""" """Helper function for submitting a BOM file."""
bom_file = SimpleUploadedFile(filename, file_data, content_type=content_type) bom_file = SimpleUploadedFile(filename, file_data, content_type=content_type)
if clear_existing is None: if clear_existing is None:
@ -102,7 +102,7 @@ class BomUploadTest(InvenTreeAPITestCase):
) )
def test_missing_rows(self): def test_missing_rows(self):
"""Test upload of an invalid file (without data rows)""" """Test upload of an invalid file (without data rows)."""
dataset = tablib.Dataset() dataset = tablib.Dataset()
dataset.headers = ['apple', 'banana'] dataset.headers = ['apple', 'banana']

View File

@ -1,4 +1,4 @@
"""Unit tests for the BomItem model""" """Unit tests for the BomItem model."""
from decimal import Decimal from decimal import Decimal
@ -12,7 +12,7 @@ from .models import BomItem, BomItemSubstitute, Part
class BomItemTest(TestCase): class BomItemTest(TestCase):
"""Class for unit testing BomItem model""" """Class for unit testing BomItem model."""
fixtures = [ fixtures = [
'category', 'category',
@ -26,7 +26,7 @@ class BomItemTest(TestCase):
] ]
def setUp(self): def setUp(self):
"""Create initial data""" """Create initial data."""
super().setUp() super().setUp()
Part.objects.rebuild() Part.objects.rebuild()
@ -36,19 +36,19 @@ class BomItemTest(TestCase):
self.r1 = Part.objects.get(name='R_2K2_0805') self.r1 = Part.objects.get(name='R_2K2_0805')
def test_str(self): def test_str(self):
"""Test the string representation of a BOMItem""" """Test the string representation of a BOMItem."""
b = BomItem.objects.get(id=1) b = BomItem.objects.get(id=1)
self.assertEqual(str(b), '10 x M2x4 LPHS to make BOB | Bob | A2') self.assertEqual(str(b), '10 x M2x4 LPHS to make BOB | Bob | A2')
def test_has_bom(self): def test_has_bom(self):
"""Test the has_bom attribute""" """Test the has_bom attribute."""
self.assertFalse(self.orphan.has_bom) self.assertFalse(self.orphan.has_bom)
self.assertTrue(self.bob.has_bom) self.assertTrue(self.bob.has_bom)
self.assertEqual(self.bob.bom_count, 4) self.assertEqual(self.bob.bom_count, 4)
def test_in_bom(self): def test_in_bom(self):
"""Test BOM aggregation""" """Test BOM aggregation."""
parts = self.bob.getRequiredParts() parts = self.bob.getRequiredParts()
self.assertIn(self.orphan, parts) self.assertIn(self.orphan, parts)
@ -56,7 +56,7 @@ class BomItemTest(TestCase):
self.assertTrue(self.bob.check_if_part_in_bom(self.orphan)) self.assertTrue(self.bob.check_if_part_in_bom(self.orphan))
def test_used_in(self): def test_used_in(self):
"""Test that the 'used_in_count' attribute is calculated correctly""" """Test that the 'used_in_count' attribute is calculated correctly."""
self.assertEqual(self.bob.used_in_count, 1) self.assertEqual(self.bob.used_in_count, 1)
self.assertEqual(self.orphan.used_in_count, 1) self.assertEqual(self.orphan.used_in_count, 1)
@ -131,7 +131,7 @@ class BomItemTest(TestCase):
self.assertNotEqual(h1, h2) self.assertNotEqual(h1, h2)
def test_pricing(self): def test_pricing(self):
"""Test BOM pricing""" """Test BOM pricing."""
self.bob.get_price(1) self.bob.get_price(1)
self.assertEqual( self.assertEqual(
self.bob.get_bom_price_range(1, internal=True), self.bob.get_bom_price_range(1, internal=True),
@ -193,7 +193,7 @@ class BomItemTest(TestCase):
self.assertEqual(bom_item.substitutes.count(), 0) self.assertEqual(bom_item.substitutes.count(), 0)
def test_consumable(self): def test_consumable(self):
"""Tests for the 'consumable' BomItem field""" """Tests for the 'consumable' BomItem field."""
# Create an assembly part # Create an assembly part
assembly = Part.objects.create( assembly = Part.objects.create(
name='An assembly', description='Made with parts', assembly=True name='An assembly', description='Made with parts', assembly=True
@ -252,8 +252,7 @@ class BomItemTest(TestCase):
self.assertEqual(len(p.metadata.keys()), 4) self.assertEqual(len(p.metadata.keys()), 4)
def test_invalid_bom(self): def test_invalid_bom(self):
"""Test that ValidationError is correctly raised for an invalid BOM item""" """Test that ValidationError is correctly raised for an invalid BOM item."""
# First test: A BOM item which points to itself # First test: A BOM item which points to itself
with self.assertRaises(django_exceptions.ValidationError): with self.assertRaises(django_exceptions.ValidationError):
BomItem.objects.create(part=self.bob, sub_part=self.bob, quantity=1) BomItem.objects.create(part=self.bob, sub_part=self.bob, quantity=1)

View File

@ -1,4 +1,4 @@
"""Unit tests for the PartCategory model""" """Unit tests for the PartCategory model."""
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.test import TestCase from django.test import TestCase
@ -17,7 +17,7 @@ class CategoryTest(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Extract some interesting categories for time-saving""" """Extract some interesting categories for time-saving."""
super().setUpTestData() super().setUpTestData()
cls.electronics = PartCategory.objects.get(name='Electronics') cls.electronics = PartCategory.objects.get(name='Electronics')
@ -216,7 +216,7 @@ class CategoryTest(TestCase):
self.assertIsNone(w.get_default_location()) self.assertIsNone(w.get_default_location())
def test_category_tree(self): def test_category_tree(self):
"""Unit tests for the part category tree structure (MPTT) """Unit tests for the part category tree structure (MPTT).
Ensure that the MPTT structure is rebuilt correctly, Ensure that the MPTT structure is rebuilt correctly,
and the correct ancestor tree is observed. and the correct ancestor tree is observed.

View File

@ -32,7 +32,7 @@ class TestForwardMigrations(MigratorTestCase):
print(p.is_template) print(p.is_template)
def test_models_exist(self): def test_models_exist(self):
"""Test that the Part model can still be accessed at the end of schema migration""" """Test that the Part model can still be accessed at the end of schema migration."""
Part = self.new_state.apps.get_model('part', 'part') Part = self.new_state.apps.get_model('part', 'part')
self.assertEqual(Part.objects.count(), 5) self.assertEqual(Part.objects.count(), 5)
@ -49,13 +49,13 @@ class TestForwardMigrations(MigratorTestCase):
class TestBomItemMigrations(MigratorTestCase): class TestBomItemMigrations(MigratorTestCase):
"""Tests for BomItem migrations""" """Tests for BomItem migrations."""
migrate_from = ('part', '0002_auto_20190520_2204') migrate_from = ('part', '0002_auto_20190520_2204')
migrate_to = ('part', unit_test.getNewestMigrationFile('part')) migrate_to = ('part', unit_test.getNewestMigrationFile('part'))
def prepare(self): def prepare(self):
"""Create initial dataset""" """Create initial dataset."""
Part = self.old_state.apps.get_model('part', 'part') Part = self.old_state.apps.get_model('part', 'part')
BomItem = self.old_state.apps.get_model('part', 'bomitem') BomItem = self.old_state.apps.get_model('part', 'bomitem')
@ -73,7 +73,7 @@ class TestBomItemMigrations(MigratorTestCase):
print(b.validated) print(b.validated)
def test_validated_field(self): def test_validated_field(self):
"""Test that the 'validated' field is added to the BomItem objects""" """Test that the 'validated' field is added to the BomItem objects."""
BomItem = self.new_state.apps.get_model('part', 'bomitem') BomItem = self.new_state.apps.get_model('part', 'bomitem')
self.assertEqual(BomItem.objects.count(), 2) self.assertEqual(BomItem.objects.count(), 2)
@ -83,13 +83,13 @@ class TestBomItemMigrations(MigratorTestCase):
class TestParameterMigrations(MigratorTestCase): class TestParameterMigrations(MigratorTestCase):
"""Unit test for part parameter migrations""" """Unit test for part parameter migrations."""
migrate_from = ('part', '0106_part_tags') migrate_from = ('part', '0106_part_tags')
migrate_to = ('part', unit_test.getNewestMigrationFile('part')) migrate_to = ('part', unit_test.getNewestMigrationFile('part'))
def prepare(self): def prepare(self):
"""Create some parts, and templates with parameters""" """Create some parts, and templates with parameters."""
Part = self.old_state.apps.get_model('part', 'part') Part = self.old_state.apps.get_model('part', 'part')
PartParameter = self.old_state.apps.get_model('part', 'partparameter') PartParameter = self.old_state.apps.get_model('part', 'partparameter')
PartParameterTemlate = self.old_state.apps.get_model( PartParameterTemlate = self.old_state.apps.get_model(
@ -117,7 +117,7 @@ class TestParameterMigrations(MigratorTestCase):
PartParameter.objects.create(part=b, template=t2, data='abc') PartParameter.objects.create(part=b, template=t2, data='abc')
def test_data_migration(self): def test_data_migration(self):
"""Test that the template units and values have been updated correctly""" """Test that the template units and values have been updated correctly."""
Part = self.new_state.apps.get_model('part', 'part') Part = self.new_state.apps.get_model('part', 'part')
PartParameter = self.new_state.apps.get_model('part', 'partparameter') PartParameter = self.new_state.apps.get_model('part', 'partparameter')
PartParameterTemlate = self.new_state.apps.get_model( PartParameterTemlate = self.new_state.apps.get_model(
@ -155,13 +155,13 @@ class TestParameterMigrations(MigratorTestCase):
class PartUnitsMigrationTest(MigratorTestCase): class PartUnitsMigrationTest(MigratorTestCase):
"""Test for data migration of Part.units field""" """Test for data migration of Part.units field."""
migrate_from = ('part', '0109_auto_20230517_1048') migrate_from = ('part', '0109_auto_20230517_1048')
migrate_to = ('part', unit_test.getNewestMigrationFile('part')) migrate_to = ('part', unit_test.getNewestMigrationFile('part'))
def prepare(self): def prepare(self):
"""Prepare some parts with units""" """Prepare some parts with units."""
Part = self.old_state.apps.get_model('part', 'part') Part = self.old_state.apps.get_model('part', 'part')
units = ['mm', 'INCH', '', '%'] units = ['mm', 'INCH', '', '%']
@ -178,7 +178,7 @@ class PartUnitsMigrationTest(MigratorTestCase):
) )
def test_units_migration(self): def test_units_migration(self):
"""Test that the units have migrated OK""" """Test that the units have migrated OK."""
Part = self.new_state.apps.get_model('part', 'part') Part = self.new_state.apps.get_model('part', 'part')
part_1 = Part.objects.get(name='Part 1') part_1 = Part.objects.get(name='Part 1')
@ -193,7 +193,7 @@ class PartUnitsMigrationTest(MigratorTestCase):
class TestPartParameterTemplateMigration(MigratorTestCase): class TestPartParameterTemplateMigration(MigratorTestCase):
"""Test for data migration of PartParameterTemplate """Test for data migration of PartParameterTemplate.
Ref: https://github.com/inventree/InvenTree/pull/4987 Ref: https://github.com/inventree/InvenTree/pull/4987
""" """
@ -202,7 +202,7 @@ class TestPartParameterTemplateMigration(MigratorTestCase):
migrate_to = ('part', '0113_auto_20230531_1205') migrate_to = ('part', '0113_auto_20230531_1205')
def prepare(self): def prepare(self):
"""Prepare some parts with units""" """Prepare some parts with units."""
PartParameterTemplate = self.old_state.apps.get_model( PartParameterTemplate = self.old_state.apps.get_model(
'part', 'partparametertemplate' 'part', 'partparametertemplate'
) )
@ -220,7 +220,7 @@ class TestPartParameterTemplateMigration(MigratorTestCase):
template.checkbox template.checkbox
def test_units_migration(self): def test_units_migration(self):
"""Test that the new fields have been added correctly""" """Test that the new fields have been added correctly."""
PartParameterTemplate = self.new_state.apps.get_model( PartParameterTemplate = self.new_state.apps.get_model(
'part', 'partparametertemplate' 'part', 'partparametertemplate'
) )

View File

@ -1,4 +1,4 @@
"""Various unit tests for Part Parameters""" """Various unit tests for Part Parameters."""
import django.core.exceptions as django_exceptions import django.core.exceptions as django_exceptions
from django.test import TestCase, TransactionTestCase from django.test import TestCase, TransactionTestCase
@ -17,12 +17,12 @@ from .models import (
class TestParams(TestCase): class TestParams(TestCase):
"""Unit test class for testing the PartParameter model""" """Unit test class for testing the PartParameter model."""
fixtures = ['location', 'category', 'part', 'params'] fixtures = ['location', 'category', 'part', 'params']
def test_str(self): def test_str(self):
"""Test the str representation of the PartParameterTemplate model""" """Test the str representation of the PartParameterTemplate model."""
t1 = PartParameterTemplate.objects.get(pk=1) t1 = PartParameterTemplate.objects.get(pk=1)
self.assertEqual(str(t1), 'Length (mm)') self.assertEqual(str(t1), 'Length (mm)')
@ -33,7 +33,7 @@ class TestParams(TestCase):
self.assertEqual(str(c1), 'Mechanical | Length | 2.8') self.assertEqual(str(c1), 'Mechanical | Length | 2.8')
def test_validate(self): def test_validate(self):
"""Test validation for part templates""" """Test validation for part templates."""
n = PartParameterTemplate.objects.all().count() n = PartParameterTemplate.objects.all().count()
t1 = PartParameterTemplate(name='abcde', units='dd') t1 = PartParameterTemplate(name='abcde', units='dd')
@ -65,7 +65,7 @@ class TestParams(TestCase):
self.assertEqual(len(p.metadata.keys()), 4) self.assertEqual(len(p.metadata.keys()), 4)
def test_get_parameter(self): def test_get_parameter(self):
"""Test the Part.get_parameter method""" """Test the Part.get_parameter method."""
prt = Part.objects.get(pk=3) prt = Part.objects.get(pk=3)
# Check that we can get a parameter by name # Check that we can get a parameter by name
@ -79,12 +79,12 @@ class TestParams(TestCase):
class TestCategoryTemplates(TransactionTestCase): class TestCategoryTemplates(TransactionTestCase):
"""Test class for PartCategoryParameterTemplate model""" """Test class for PartCategoryParameterTemplate model."""
fixtures = ['location', 'category', 'part', 'params'] fixtures = ['location', 'category', 'part', 'params']
def test_validate(self): def test_validate(self):
"""Test that category templates are correctly applied to Part instances""" """Test that category templates are correctly applied to Part instances."""
# Category templates # Category templates
n = PartCategoryParameterTemplate.objects.all().count() n = PartCategoryParameterTemplate.objects.all().count()
self.assertEqual(n, 2) self.assertEqual(n, 2)
@ -102,12 +102,12 @@ class TestCategoryTemplates(TransactionTestCase):
class ParameterTests(TestCase): class ParameterTests(TestCase):
"""Unit tests for parameter validation""" """Unit tests for parameter validation."""
fixtures = ['location', 'category', 'part', 'params'] fixtures = ['location', 'category', 'part', 'params']
def test_choice_validation(self): def test_choice_validation(self):
"""Test that parameter choices are correctly validated""" """Test that parameter choices are correctly validated."""
template = PartParameterTemplate.objects.create( template = PartParameterTemplate.objects.create(
name='My Template', name='My Template',
description='A template with choices', description='A template with choices',
@ -129,7 +129,7 @@ class ParameterTests(TestCase):
param.full_clean() param.full_clean()
def test_unit_validation(self): def test_unit_validation(self):
"""Test validation of 'units' field for PartParameterTemplate""" """Test validation of 'units' field for PartParameterTemplate."""
# Test that valid units pass # Test that valid units pass
for unit in [ for unit in [
None, None,
@ -156,7 +156,7 @@ class ParameterTests(TestCase):
tmp.full_clean() tmp.full_clean()
def test_param_unit_validation(self): def test_param_unit_validation(self):
"""Test that parameters are correctly validated against template units""" """Test that parameters are correctly validated against template units."""
template = PartParameterTemplate.objects.create(name='My Template', units='m') template = PartParameterTemplate.objects.create(name='My Template', units='m')
prt = Part.objects.get(pk=1) prt = Part.objects.get(pk=1)
@ -208,7 +208,7 @@ class ParameterTests(TestCase):
param.full_clean() param.full_clean()
def test_param_unit_conversion(self): def test_param_unit_conversion(self):
"""Test that parameters are correctly converted to template units""" """Test that parameters are correctly converted to template units."""
template = PartParameterTemplate.objects.create(name='My Template', units='m') template = PartParameterTemplate.objects.create(name='My Template', units='m')
tests = { tests = {
@ -325,7 +325,7 @@ class PartParameterTest(InvenTreeAPITestCase):
"""Test that we can order parts by a specified parameter.""" """Test that we can order parts by a specified parameter."""
def get_param_value(response, template, index): def get_param_value(response, template, index):
"""Helper function to extract a parameter value from a response""" """Helper function to extract a parameter value from a response."""
params = response.data[index]['parameters'] params = response.data[index]['parameters']
for param in params: for param in params:

View File

@ -37,47 +37,47 @@ class TemplateTagTest(InvenTreeTestCase):
"""Tests for the custom template tag code.""" """Tests for the custom template tag code."""
def test_define(self): def test_define(self):
"""Test the 'define' template tag""" """Test the 'define' template tag."""
self.assertEqual(int(inventree_extras.define(3)), 3) self.assertEqual(int(inventree_extras.define(3)), 3)
def test_str2bool(self): def test_str2bool(self):
"""Various test for the str2bool template tag""" """Various test for the str2bool template tag."""
self.assertEqual(int(inventree_extras.str2bool('true')), True) self.assertEqual(int(inventree_extras.str2bool('true')), True)
self.assertEqual(int(inventree_extras.str2bool('yes')), True) self.assertEqual(int(inventree_extras.str2bool('yes')), True)
self.assertEqual(int(inventree_extras.str2bool('none')), False) self.assertEqual(int(inventree_extras.str2bool('none')), False)
self.assertEqual(int(inventree_extras.str2bool('off')), False) self.assertEqual(int(inventree_extras.str2bool('off')), False)
def test_add(self): def test_add(self):
"""Test that the 'add""" """Test that the 'add."""
self.assertEqual(int(inventree_extras.add(3, 5)), 8) self.assertEqual(int(inventree_extras.add(3, 5)), 8)
def test_plugins_enabled(self): def test_plugins_enabled(self):
"""Test the plugins_enabled tag""" """Test the plugins_enabled tag."""
self.assertEqual(inventree_extras.plugins_enabled(), True) self.assertEqual(inventree_extras.plugins_enabled(), True)
def test_inventree_instance_name(self): def test_inventree_instance_name(self):
"""Test the 'instance name' setting""" """Test the 'instance name' setting."""
self.assertEqual(inventree_extras.inventree_instance_name(), 'InvenTree') self.assertEqual(inventree_extras.inventree_instance_name(), 'InvenTree')
def test_inventree_base_url(self): def test_inventree_base_url(self):
"""Test that the base URL tag returns correctly""" """Test that the base URL tag returns correctly."""
self.assertEqual(inventree_extras.inventree_base_url(), '') self.assertEqual(inventree_extras.inventree_base_url(), '')
def test_inventree_is_release(self): def test_inventree_is_release(self):
"""Test that the release version check functions as expected""" """Test that the release version check functions as expected."""
self.assertEqual( self.assertEqual(
inventree_extras.inventree_is_release(), inventree_extras.inventree_is_release(),
not version.isInvenTreeDevelopmentVersion(), not version.isInvenTreeDevelopmentVersion(),
) )
def test_inventree_docs_version(self): def test_inventree_docs_version(self):
"""Test that the documentation version template tag returns correctly""" """Test that the documentation version template tag returns correctly."""
self.assertEqual( self.assertEqual(
inventree_extras.inventree_docs_version(), version.inventreeDocsVersion() inventree_extras.inventree_docs_version(), version.inventreeDocsVersion()
) )
def test_hash(self): def test_hash(self):
"""Test that the commit hash template tag returns correctly""" """Test that the commit hash template tag returns correctly."""
result_hash = inventree_extras.inventree_commit_hash() result_hash = inventree_extras.inventree_commit_hash()
if settings.DOCKER: # pragma: no cover if settings.DOCKER: # pragma: no cover
# Testing inside docker environment *may* return an empty git commit hash # Testing inside docker environment *may* return an empty git commit hash
@ -87,7 +87,7 @@ class TemplateTagTest(InvenTreeTestCase):
self.assertGreater(len(result_hash), 5) self.assertGreater(len(result_hash), 5)
def test_date(self): def test_date(self):
"""Test that the commit date template tag returns correctly""" """Test that the commit date template tag returns correctly."""
d = inventree_extras.inventree_commit_date() d = inventree_extras.inventree_commit_date()
if settings.DOCKER: # pragma: no cover if settings.DOCKER: # pragma: no cover
# Testing inside docker environment *may* return an empty git commit hash # Testing inside docker environment *may* return an empty git commit hash
@ -97,33 +97,33 @@ class TemplateTagTest(InvenTreeTestCase):
self.assertEqual(len(d.split('-')), 3) self.assertEqual(len(d.split('-')), 3)
def test_github(self): def test_github(self):
"""Test that the github URL template tag returns correctly""" """Test that the github URL template tag returns correctly."""
self.assertIn('github.com', inventree_extras.inventree_github_url()) self.assertIn('github.com', inventree_extras.inventree_github_url())
def test_docs(self): def test_docs(self):
"""Test that the documentation URL template tag returns correctly""" """Test that the documentation URL template tag returns correctly."""
self.assertIn('docs.inventree.org', inventree_extras.inventree_docs_url()) self.assertIn('docs.inventree.org', inventree_extras.inventree_docs_url())
def test_keyvalue(self): def test_keyvalue(self):
"""Test keyvalue template tag""" """Test keyvalue template tag."""
self.assertEqual(inventree_extras.keyvalue({'a': 'a'}, 'a'), 'a') self.assertEqual(inventree_extras.keyvalue({'a': 'a'}, 'a'), 'a')
def test_mail_configured(self): def test_mail_configured(self):
"""Test that mail configuration returns False""" """Test that mail configuration returns False."""
self.assertEqual(inventree_extras.mail_configured(), False) self.assertEqual(inventree_extras.mail_configured(), False)
def test_user_settings(self): def test_user_settings(self):
"""Test user settings""" """Test user settings."""
result = inventree_extras.user_settings(self.user) result = inventree_extras.user_settings(self.user)
self.assertEqual(len(result), len(InvenTreeUserSetting.SETTINGS)) self.assertEqual(len(result), len(InvenTreeUserSetting.SETTINGS))
def test_global_settings(self): def test_global_settings(self):
"""Test global settings""" """Test global settings."""
result = inventree_extras.global_settings() result = inventree_extras.global_settings()
self.assertEqual(len(result), len(InvenTreeSetting.SETTINGS)) self.assertEqual(len(result), len(InvenTreeSetting.SETTINGS))
def test_visible_global_settings(self): def test_visible_global_settings(self):
"""Test that hidden global settings are actually hidden""" """Test that hidden global settings are actually hidden."""
result = inventree_extras.visible_global_settings() result = inventree_extras.visible_global_settings()
n = len(result) n = len(result)
@ -147,7 +147,7 @@ class PartTest(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Create some Part instances as part of init routine""" """Create some Part instances as part of init routine."""
super().setUpTestData() super().setUpTestData()
cls.r1 = Part.objects.get(name='R_2K2_0805') cls.r1 = Part.objects.get(name='R_2K2_0805')
@ -158,7 +158,7 @@ class PartTest(TestCase):
Part.objects.rebuild() Part.objects.rebuild()
def test_barcode_mixin(self): def test_barcode_mixin(self):
"""Test the barcode mixin functionality""" """Test the barcode mixin functionality."""
self.assertEqual(Part.barcode_model_type(), 'part') self.assertEqual(Part.barcode_model_type(), 'part')
p = Part.objects.get(pk=1) p = Part.objects.get(pk=1)
@ -166,7 +166,7 @@ class PartTest(TestCase):
self.assertEqual(barcode, '{"part": 1}') self.assertEqual(barcode, '{"part": 1}')
def test_tree(self): def test_tree(self):
"""Test that the part variant tree is working properly""" """Test that the part variant tree is working properly."""
chair = Part.objects.get(pk=10000) chair = Part.objects.get(pk=10000)
self.assertEqual(chair.get_children().count(), 3) self.assertEqual(chair.get_children().count(), 3)
self.assertEqual(chair.get_descendant_count(), 4) self.assertEqual(chair.get_descendant_count(), 4)
@ -178,7 +178,7 @@ class PartTest(TestCase):
self.assertEqual(Part.objects.filter(tree_id=chair.tree_id).count(), 5) self.assertEqual(Part.objects.filter(tree_id=chair.tree_id).count(), 5)
def test_str(self): def test_str(self):
"""Test string representation of a Part""" """Test string representation of a Part."""
p = Part.objects.get(pk=100) p = Part.objects.get(pk=100)
self.assertEqual(str(p), 'BOB | Bob | A2 - Can we build it? Yes we can!') self.assertEqual(str(p), 'BOB | Bob | A2 - Can we build it? Yes we can!')
@ -235,12 +235,12 @@ class PartTest(TestCase):
part_2.validate_unique() part_2.validate_unique()
def test_attributes(self): def test_attributes(self):
"""Test Part attributes""" """Test Part attributes."""
self.assertEqual(self.r1.name, 'R_2K2_0805') self.assertEqual(self.r1.name, 'R_2K2_0805')
self.assertEqual(self.r1.get_absolute_url(), '/part/3/') self.assertEqual(self.r1.get_absolute_url(), '/part/3/')
def test_category(self): def test_category(self):
"""Test PartCategory path""" """Test PartCategory path."""
self.c1.category.save() self.c1.category.save()
self.assertEqual(str(self.c1.category), 'Electronics/Capacitors - Capacitors') self.assertEqual(str(self.c1.category), 'Electronics/Capacitors - Capacitors')
@ -249,25 +249,25 @@ class PartTest(TestCase):
self.assertEqual(orphan.category_path, '') self.assertEqual(orphan.category_path, '')
def test_rename_img(self): def test_rename_img(self):
"""Test that an image can be renamed""" """Test that an image can be renamed."""
img = rename_part_image(self.r1, 'hello.png') img = rename_part_image(self.r1, 'hello.png')
self.assertEqual(img, os.path.join('part_images', 'hello.png')) self.assertEqual(img, os.path.join('part_images', 'hello.png'))
def test_stock(self): def test_stock(self):
"""Test case where there is zero stock""" """Test case where there is zero stock."""
res = Part.objects.filter(description__contains='resistor') res = Part.objects.filter(description__contains='resistor')
for r in res: for r in res:
self.assertEqual(r.total_stock, 0) self.assertEqual(r.total_stock, 0)
self.assertEqual(r.available_stock, 0) self.assertEqual(r.available_stock, 0)
def test_barcode(self): def test_barcode(self):
"""Test barcode format functionality""" """Test barcode format functionality."""
barcode = self.r1.format_barcode(brief=False) barcode = self.r1.format_barcode(brief=False)
self.assertIn('InvenTree', barcode) self.assertIn('InvenTree', barcode)
self.assertIn('"part": {"id": 3}', barcode) self.assertIn('"part": {"id": 3}', barcode)
def test_sell_pricing(self): def test_sell_pricing(self):
"""Check that the sell pricebreaks were loaded""" """Check that the sell pricebreaks were loaded."""
self.assertTrue(self.r1.has_price_breaks) self.assertTrue(self.r1.has_price_breaks)
self.assertEqual(self.r1.price_breaks.count(), 2) self.assertEqual(self.r1.price_breaks.count(), 2)
# check that the sell pricebreaks work # check that the sell pricebreaks work
@ -275,7 +275,7 @@ class PartTest(TestCase):
self.assertEqual(float(self.r1.get_price(10)), 1.0) self.assertEqual(float(self.r1.get_price(10)), 1.0)
def test_internal_pricing(self): def test_internal_pricing(self):
"""Check that the sell pricebreaks were loaded""" """Check that the sell pricebreaks were loaded."""
self.assertTrue(self.r1.has_internal_price_breaks) self.assertTrue(self.r1.has_internal_price_breaks)
self.assertEqual(self.r1.internal_price_breaks.count(), 2) self.assertEqual(self.r1.internal_price_breaks.count(), 2)
# check that the sell pricebreaks work # check that the sell pricebreaks work
@ -300,7 +300,7 @@ class PartTest(TestCase):
self.assertEqual(len(p.metadata.keys()), 4) self.assertEqual(len(p.metadata.keys()), 4)
def test_related(self): def test_related(self):
"""Unit tests for the PartRelated model""" """Unit tests for the PartRelated model."""
# Create a part relationship # Create a part relationship
# Count before creation # Count before creation
countbefore = PartRelated.objects.count() countbefore = PartRelated.objects.count()
@ -348,7 +348,7 @@ class PartTest(TestCase):
self.assertEqual(PartRelated.objects.count(), countbefore) self.assertEqual(PartRelated.objects.count(), countbefore)
def test_stocktake(self): def test_stocktake(self):
"""Test for adding stocktake data""" """Test for adding stocktake data."""
# Grab a part # Grab a part
p = Part.objects.all().first() p = Part.objects.all().first()
@ -361,12 +361,12 @@ class PartTest(TestCase):
class TestTemplateTest(TestCase): class TestTemplateTest(TestCase):
"""Unit test for the TestTemplate class""" """Unit test for the TestTemplate class."""
fixtures = ['category', 'part', 'location', 'test_templates'] fixtures = ['category', 'part', 'location', 'test_templates']
def test_template_count(self): def test_template_count(self):
"""Tests for the test template functions""" """Tests for the test template functions."""
chair = Part.objects.get(pk=10000) chair = Part.objects.get(pk=10000)
# Tests for the top-level chair object (nothing above it!) # Tests for the top-level chair object (nothing above it!)
@ -383,7 +383,7 @@ class TestTemplateTest(TestCase):
self.assertEqual(variant.getTestTemplates(required=True).count(), 5) self.assertEqual(variant.getTestTemplates(required=True).count(), 5)
def test_uniqueness(self): def test_uniqueness(self):
"""Test names must be unique for this part and also parts above""" """Test names must be unique for this part and also parts above."""
variant = Part.objects.get(pk=10004) variant = Part.objects.get(pk=10004)
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
@ -434,7 +434,7 @@ class PartSettingsTest(InvenTreeTestCase):
self.assertFalse(part.settings.part_trackable_default()) self.assertFalse(part.settings.part_trackable_default())
def test_initial(self): def test_initial(self):
"""Test the 'initial' default values (no default values have been set)""" """Test the 'initial' default values (no default values have been set)."""
cache.clear() cache.clear()
part = self.make_part() part = self.make_part()
@ -525,13 +525,13 @@ class PartSettingsTest(InvenTreeTestCase):
class PartSubscriptionTests(InvenTreeTestCase): class PartSubscriptionTests(InvenTreeTestCase):
"""Unit tests for part 'subscription'""" """Unit tests for part 'subscription'."""
fixtures = ['location', 'category', 'part'] fixtures = ['location', 'category', 'part']
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Create category and part data as part of setup routine""" """Create category and part data as part of setup routine."""
super().setUpTestData() super().setUpTestData()
# Electronics / IC / MCU # Electronics / IC / MCU
@ -625,7 +625,7 @@ class BaseNotificationIntegrationTest(InvenTreeTestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
"""Add an email address as part of initialization""" """Add an email address as part of initialization."""
super().setUpTestData() super().setUpTestData()
# Add email address # Add email address
@ -668,7 +668,7 @@ class PartNotificationTest(BaseNotificationIntegrationTest):
"""Integration test for part notifications.""" """Integration test for part notifications."""
def test_notification(self): def test_notification(self):
"""Test that a notification is generated""" """Test that a notification is generated."""
self._notification_run(UIMessageNotification) self._notification_run(UIMessageNotification)
# There should be 1 notification message right now # There should be 1 notification message right now

View File

@ -1,4 +1,4 @@
"""Unit tests for Part pricing calculations""" """Unit tests for Part pricing calculations."""
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
@ -16,10 +16,10 @@ from InvenTree.unit_test import InvenTreeTestCase
class PartPricingTests(InvenTreeTestCase): class PartPricingTests(InvenTreeTestCase):
"""Unit tests for part pricing calculations""" """Unit tests for part pricing calculations."""
def setUp(self): def setUp(self):
"""Setup routines""" """Setup routines."""
super().setUp() super().setUp()
self.generate_exchange_rates() self.generate_exchange_rates()
@ -35,7 +35,7 @@ class PartPricingTests(InvenTreeTestCase):
) )
def create_price_breaks(self): def create_price_breaks(self):
"""Create some price breaks for the part, in various currencies""" """Create some price breaks for the part, in various currencies."""
# First supplier part (CAD) # First supplier part (CAD)
self.supplier_1 = company.models.Company.objects.create( self.supplier_1 = company.models.Company.objects.create(
name='Supplier 1', is_supplier=True name='Supplier 1', is_supplier=True
@ -87,7 +87,7 @@ class PartPricingTests(InvenTreeTestCase):
) )
def test_pricing_data(self): def test_pricing_data(self):
"""Test link between Part and PartPricing model""" """Test link between Part and PartPricing model."""
# Initially there is no associated Pricing data # Initially there is no associated Pricing data
with self.assertRaises(ObjectDoesNotExist): with self.assertRaises(ObjectDoesNotExist):
pricing = self.part.pricing_data pricing = self.part.pricing_data
@ -108,11 +108,11 @@ class PartPricingTests(InvenTreeTestCase):
self.assertIsNone(pricing.overall_max) self.assertIsNone(pricing.overall_max)
def test_invalid_rate(self): def test_invalid_rate(self):
"""Ensure that conversion behaves properly with missing rates""" """Ensure that conversion behaves properly with missing rates."""
... ...
def test_simple(self): def test_simple(self):
"""Tests for hard-coded values""" """Tests for hard-coded values."""
pricing = self.part.pricing pricing = self.part.pricing
# Add internal pricing # Add internal pricing
@ -143,7 +143,7 @@ class PartPricingTests(InvenTreeTestCase):
self.assertEqual(pricing.overall_max, Money('25', 'USD')) self.assertEqual(pricing.overall_max, Money('25', 'USD'))
def test_supplier_part_pricing(self): def test_supplier_part_pricing(self):
"""Test for supplier part pricing""" """Test for supplier part pricing."""
pricing = self.part.pricing pricing = self.part.pricing
# Initially, no information (not yet calculated) # Initially, no information (not yet calculated)
@ -169,7 +169,7 @@ class PartPricingTests(InvenTreeTestCase):
self.assertIsNone(pricing.supplier_price_max) self.assertIsNone(pricing.supplier_price_max)
def test_internal_pricing(self): def test_internal_pricing(self):
"""Tests for internal price breaks""" """Tests for internal price breaks."""
# Ensure internal pricing is enabled # Ensure internal pricing is enabled
common.models.InvenTreeSetting.set_setting('PART_INTERNAL_PRICE', True, None) common.models.InvenTreeSetting.set_setting('PART_INTERNAL_PRICE', True, None)
@ -201,7 +201,7 @@ class PartPricingTests(InvenTreeTestCase):
self.assertEqual(pricing.overall_max, Money(10, currency)) self.assertEqual(pricing.overall_max, Money(10, currency))
def test_stock_item_pricing(self): def test_stock_item_pricing(self):
"""Test for stock item pricing data""" """Test for stock item pricing data."""
# Create a part # Create a part
p = part.models.Part.objects.create( p = part.models.Part.objects.create(
name='Test part for pricing', name='Test part for pricing',
@ -247,7 +247,7 @@ class PartPricingTests(InvenTreeTestCase):
self.assertEqual(pricing.overall_max, Money(6.666667, 'USD')) self.assertEqual(pricing.overall_max, Money(6.666667, 'USD'))
def test_bom_pricing(self): def test_bom_pricing(self):
"""Unit test for BOM pricing calculations""" """Unit test for BOM pricing calculations."""
pricing = self.part.pricing pricing = self.part.pricing
self.assertIsNone(pricing.bom_cost_min) self.assertIsNone(pricing.bom_cost_min)
@ -286,7 +286,7 @@ class PartPricingTests(InvenTreeTestCase):
self.assertEqual(pricing.overall_max, Money('550', 'USD')) self.assertEqual(pricing.overall_max, Money('550', 'USD'))
def test_purchase_pricing(self): def test_purchase_pricing(self):
"""Unit tests for historical purchase pricing""" """Unit tests for historical purchase pricing."""
self.create_price_breaks() self.create_price_breaks()
pricing = self.part.pricing pricing = self.part.pricing
@ -353,7 +353,7 @@ class PartPricingTests(InvenTreeTestCase):
self.assertAlmostEqual(float(pricing.purchase_cost_max.amount), 6.95, places=2) self.assertAlmostEqual(float(pricing.purchase_cost_max.amount), 6.95, places=2)
def test_delete_with_pricing(self): def test_delete_with_pricing(self):
"""Test for deleting a part which has pricing information""" """Test for deleting a part which has pricing information."""
# Create some pricing data # Create some pricing data
self.create_price_breaks() self.create_price_breaks()
@ -377,7 +377,7 @@ class PartPricingTests(InvenTreeTestCase):
pricing.refresh_from_db() pricing.refresh_from_db()
def test_delete_without_pricing(self): def test_delete_without_pricing(self):
"""Test that we can delete a part which does not have pricing information""" """Test that we can delete a part which does not have pricing information."""
pricing = self.part.pricing pricing = self.part.pricing
self.assertIsNone(pricing.pk) self.assertIsNone(pricing.pk)
@ -392,7 +392,7 @@ class PartPricingTests(InvenTreeTestCase):
self.part.refresh_from_db() self.part.refresh_from_db()
def test_check_missing_pricing(self): def test_check_missing_pricing(self):
"""Tests for check_missing_pricing background task """Tests for check_missing_pricing background task.
Calling the check_missing_pricing task should: Calling the check_missing_pricing task should:
- Create PartPricing objects where there are none - Create PartPricing objects where there are none

View File

@ -1,4 +1,4 @@
"""Unit tests for Part Views (see views.py)""" """Unit tests for Part Views (see views.py)."""
from django.urls import reverse from django.urls import reverse
@ -8,7 +8,7 @@ from .models import Part
class PartViewTestCase(InvenTreeTestCase): class PartViewTestCase(InvenTreeTestCase):
"""Base class for unit testing the various Part views""" """Base class for unit testing the various Part views."""
fixtures = ['category', 'part', 'bom', 'location', 'company', 'supplier_part'] fixtures = ['category', 'part', 'bom', 'location', 'company', 'supplier_part']
@ -17,10 +17,10 @@ class PartViewTestCase(InvenTreeTestCase):
class PartListTest(PartViewTestCase): class PartListTest(PartViewTestCase):
"""Unit tests for the PartList view""" """Unit tests for the PartList view."""
def test_part_index(self): def test_part_index(self):
"""Test that the PartIndex page returns successfully""" """Test that the PartIndex page returns successfully."""
response = self.client.get(reverse('part-index')) response = self.client.get(reverse('part-index'))
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@ -31,7 +31,7 @@ class PartListTest(PartViewTestCase):
class PartDetailTest(PartViewTestCase): class PartDetailTest(PartViewTestCase):
"""Unit tests for the PartDetail view""" """Unit tests for the PartDetail view."""
def test_part_detail(self): def test_part_detail(self):
"""Test that we can retrieve a part detail page.""" """Test that we can retrieve a part detail page."""
@ -51,8 +51,9 @@ class PartDetailTest(PartViewTestCase):
self.assertEqual(response.context['category'], part.category) self.assertEqual(response.context['category'], part.category)
def test_part_detail_from_ipn(self): def test_part_detail_from_ipn(self):
"""Test that we can retrieve a part detail page from part IPN: """Test that we can retrieve a part detail page from part IPN.
Rules:
- if no part with matching IPN -> return part index - if no part with matching IPN -> return part index
- if unique IPN match -> return part detail page - if unique IPN match -> return part detail page
- if multiple IPN matches -> return part index - if multiple IPN matches -> return part index
@ -61,7 +62,7 @@ class PartDetailTest(PartViewTestCase):
pk = 1 pk = 1
def test_ipn_match(index_result=False, detail_result=False): def test_ipn_match(index_result=False, detail_result=False):
"""Helper function for matching IPN detail view""" """Helper function for matching IPN detail view."""
index_redirect = False index_redirect = False
detail_redirect = False detail_redirect = False

Some files were not shown because too many files have changed in this diff Show More