diff --git a/.github/workflows/qc_checks.yaml b/.github/workflows/qc_checks.yaml index 7ff9315851..b3c9e18a78 100644 --- a/.github/workflows/qc_checks.yaml +++ b/.github/workflows/qc_checks.yaml @@ -130,6 +130,22 @@ jobs: invoke check-server coverage run -m unittest discover -s test/ + docstyle: + name: Style [Python Docstrings] + runs-on: ubuntu-20.04 + + needs: pre-commit + continue-on-error: true + + steps: + - uses: actions/checkout@v1 + - name: Enviroment Setup + uses: ./.github/actions/setup + with: + install: true + - name: Run flake8 + run: flake8 InvenTree --statistics + coverage: name: Tests - DB [SQLite] + Coverage runs-on: ubuntu-20.04 diff --git a/InvenTree/InvenTree/__init__.py b/InvenTree/InvenTree/__init__.py index 521e0a60fc..006862ae55 100644 --- a/InvenTree/InvenTree/__init__.py +++ b/InvenTree/InvenTree/__init__.py @@ -1,5 +1,4 @@ -""" -The InvenTree module provides high-level management and functionality. +"""The InvenTree module provides high-level management and functionality. It provides a number of helper functions and generic classes which are used by InvenTree apps. """ diff --git a/InvenTree/InvenTree/api.py b/InvenTree/InvenTree/api.py index 0713479f7a..974d518e6a 100644 --- a/InvenTree/InvenTree/api.py +++ b/InvenTree/InvenTree/api.py @@ -1,6 +1,4 @@ -""" -Main JSON interface views -""" +"""Main JSON interface views.""" from django.conf import settings from django.http import JsonResponse @@ -16,14 +14,15 @@ from .views import AjaxView class InfoView(AjaxView): - """ Simple JSON endpoint for InvenTree information. + """Simple JSON endpoint for InvenTree information. + Use to confirm that the server is running, etc. """ permission_classes = [permissions.AllowAny] def get(self, request, *args, **kwargs): - + """Serve current server information.""" data = { 'server': 'InvenTree', 'version': inventreeVersion(), @@ -37,14 +36,12 @@ class InfoView(AjaxView): class NotFoundView(AjaxView): - """ - Simple JSON view when accessing an invalid API view. - """ + """Simple JSON view when accessing an invalid API view.""" permission_classes = [permissions.AllowAny] def get(self, request, *args, **kwargs): - + """Proces an `not found` event on the API.""" data = { 'details': _('API endpoint not found'), 'url': request.build_absolute_uri(), @@ -54,8 +51,7 @@ class NotFoundView(AjaxView): class APIDownloadMixin: - """ - Mixin for enabling a LIST endpoint to be downloaded a file. + """Mixin for enabling a LIST endpoint to be downloaded a file. To download the data, add the ?export= to the query string. @@ -76,7 +72,7 @@ class APIDownloadMixin: """ def get(self, request, *args, **kwargs): - + """Generic handler for a download request.""" export_format = request.query_params.get('export', None) if export_format and export_format in ['csv', 'tsv', 'xls', 'xlsx']: @@ -88,14 +84,12 @@ class APIDownloadMixin: return super().get(request, *args, **kwargs) def download_queryset(self, queryset, export_format): + """This function must be implemented to provide a downloadFile request.""" raise NotImplementedError("download_queryset method not implemented!") class AttachmentMixin: - """ - Mixin for creating attachment objects, - and ensuring the user information is saved correctly. - """ + """Mixin for creating attachment objects, and ensuring the user information is saved correctly.""" permission_classes = [permissions.IsAuthenticated] @@ -106,8 +100,7 @@ class AttachmentMixin: ] def perform_create(self, serializer): - """ Save the user information when a file is uploaded """ - + """Save the user information when a file is uploaded.""" attachment = serializer.save() attachment.user = self.request.user attachment.save() diff --git a/InvenTree/InvenTree/api_tester.py b/InvenTree/InvenTree/api_tester.py index 5385f8f01b..052e1c31a4 100644 --- a/InvenTree/InvenTree/api_tester.py +++ b/InvenTree/InvenTree/api_tester.py @@ -1,6 +1,4 @@ -""" -Helper functions for performing API unit tests -""" +"""Helper functions for performing API unit tests.""" import csv import io @@ -14,6 +12,10 @@ from rest_framework.test import APITestCase class UserMixin: + """Mixin to setup a user and login for tests. + + Use parameters to set username, password, email, roles and permissions. + """ # User information username = 'testuser' @@ -28,7 +30,7 @@ class UserMixin: roles = [] def setUp(self): - + """Setup for all tests.""" super().setUp() # Create a user to log in with @@ -62,10 +64,7 @@ class UserMixin: self.client.login(username=self.username, password=self.password) def assignRole(self, role=None, assign_all: bool = False): - """ - Set the user roles for the registered user - """ - + """Set the user roles for the registered user.""" # role is of the format 'rule.permission' e.g. 'part.add' if not assign_all and role: @@ -89,16 +88,13 @@ class UserMixin: class InvenTreeAPITestCase(UserMixin, APITestCase): - """ - Base class for running InvenTree API tests - """ + """Base class for running InvenTree API tests.""" def getActions(self, url): - """ - Return a dict of the 'actions' available at a given endpoint. + """Return a dict of the 'actions' available at a given endpoint. + Makes use of the HTTP 'OPTIONS' method to request this. """ - response = self.client.options(url) self.assertEqual(response.status_code, 200) @@ -110,10 +106,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase): return actions def get(self, url, data={}, expected_code=200): - """ - Issue a GET request - """ - + """Issue a GET request.""" response = self.client.get(url, data, format='json') if expected_code is not None: @@ -127,10 +120,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase): return response def post(self, url, data, expected_code=None, format='json'): - """ - Issue a POST request - """ - + """Issue a POST request.""" response = self.client.post(url, data=data, format=format) if expected_code is not None: @@ -139,10 +129,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase): return response def delete(self, url, expected_code=None): - """ - Issue a DELETE request - """ - + """Issue a DELETE request.""" response = self.client.delete(url) if expected_code is not None: @@ -151,10 +138,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase): return response def patch(self, url, data, expected_code=None, format='json'): - """ - Issue a PATCH request - """ - + """Issue a PATCH request.""" response = self.client.patch(url, data=data, format=format) if expected_code is not None: @@ -163,10 +147,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase): return response def put(self, url, data, expected_code=None, format='json'): - """ - Issue a PUT request - """ - + """Issue a PUT request.""" response = self.client.put(url, data=data, format=format) if expected_code is not None: @@ -175,10 +156,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase): return response def options(self, url, expected_code=None): - """ - Issue an OPTIONS request - """ - + """Issue an OPTIONS request.""" response = self.client.options(url, format='json') if expected_code is not None: @@ -187,10 +165,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase): return response def download_file(self, url, data, expected_code=None, expected_fn=None, decode=True): - """ - Download a file from the server, and return an in-memory file - """ - + """Download a file from the server, and return an in-memory file.""" response = self.client.get(url, data=data, format='json') if expected_code is not None: @@ -226,10 +201,7 @@ class InvenTreeAPITestCase(UserMixin, APITestCase): return fo def process_csv(self, fo, delimiter=',', required_cols=None, excluded_cols=None, required_rows=None): - """ - Helper function to process and validate a downloaded csv file - """ - + """Helper function to process and validate a downloaded csv file.""" # Check that the correct object type has been passed self.assertTrue(isinstance(fo, io.StringIO)) diff --git a/InvenTree/InvenTree/api_version.py b/InvenTree/InvenTree/api_version.py index 1fc395e83d..423aea9f97 100644 --- a/InvenTree/InvenTree/api_version.py +++ b/InvenTree/InvenTree/api_version.py @@ -1,6 +1,4 @@ -""" -InvenTree API version information -""" +"""InvenTree API version information.""" # InvenTree API version diff --git a/InvenTree/InvenTree/apps.py b/InvenTree/InvenTree/apps.py index 74dbca444c..c8a8589ae8 100644 --- a/InvenTree/InvenTree/apps.py +++ b/InvenTree/InvenTree/apps.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +"""AppConfig for inventree app.""" import logging @@ -18,10 +18,11 @@ logger = logging.getLogger("inventree") class InvenTreeConfig(AppConfig): + """AppConfig for inventree app.""" name = 'InvenTree' def ready(self): - + """Setup background tasks and update exchange rates.""" if canAppAccessDatabase(): self.remove_obsolete_tasks() @@ -37,10 +38,7 @@ class InvenTreeConfig(AppConfig): self.add_user_on_startup() def remove_obsolete_tasks(self): - """ - Delete any obsolete scheduled tasks in the database - """ - + """Delete any obsolete scheduled tasks in the database.""" obsolete = [ 'InvenTree.tasks.delete_expired_sessions', 'stock.tasks.delete_old_stock_items', @@ -55,7 +53,7 @@ class InvenTreeConfig(AppConfig): Schedule.objects.filter(func__in=obsolete).delete() def start_background_tasks(self): - + """Start all background tests for InvenTree.""" try: from django_q.models import Schedule except AppRegistryNotReady: # pragma: no cover @@ -101,13 +99,12 @@ class InvenTreeConfig(AppConfig): ) def update_exchange_rates(self): # pragma: no cover - """ - Update exchange rates each time the server is started, *if*: + """Update exchange rates each time the server is started. + Only runs *if*: a) Have not been updated recently (one day or less) b) The base exchange rate has been altered """ - try: from djmoney.contrib.exchange.models import ExchangeBackend @@ -150,7 +147,7 @@ class InvenTreeConfig(AppConfig): logger.error(f"Error updating exchange rates: {e}") def add_user_on_startup(self): - """Add a user on startup""" + """Add a user on startup.""" # stop if checks were already created if hasattr(settings, 'USER_ADDED') and settings.USER_ADDED: return @@ -202,9 +199,7 @@ class InvenTreeConfig(AppConfig): settings.USER_ADDED = True def collect_notification_methods(self): - """ - Collect all notification methods - """ + """Collect all notification methods.""" from common.notifications import storage storage.collect() diff --git a/InvenTree/InvenTree/ci_render_js.py b/InvenTree/InvenTree/ci_render_js.py index 1be38df107..5c7c36a30c 100644 --- a/InvenTree/InvenTree/ci_render_js.py +++ b/InvenTree/InvenTree/ci_render_js.py @@ -1,6 +1,6 @@ -""" -Pull rendered copies of the templated -only used for testing the js files! - This file is omited from coverage +"""Pull rendered copies of the templated. + +Only used for testing the js files! - This file is omited from coverage. """ import os # pragma: no cover @@ -10,15 +10,14 @@ from InvenTree.helpers import InvenTreeTestCase # pragma: no cover class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover - """ - A unit test to "render" javascript files. + """A unit test to "render" javascript files. The server renders templated javascript files, we need the fully-rendered files for linting and static tests. """ def download_file(self, filename, prefix): - + """Function to `download`(copy) a file to a temporay firectory.""" url = os.path.join(prefix, filename) response = self.client.get(url) @@ -46,6 +45,7 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover output.write(response.content) def download_files(self, subdir, prefix): + """Download files in directory.""" here = os.path.abspath(os.path.dirname(__file__)) js_template_dir = os.path.join( @@ -73,10 +73,7 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover return n def test_render_files(self): - """ - Look for all javascript files - """ - + """Look for all javascript files.""" n = 0 print("Rendering javascript files...") diff --git a/InvenTree/InvenTree/config.py b/InvenTree/InvenTree/config.py index d7691cf4cc..fe3e41a10f 100644 --- a/InvenTree/InvenTree/config.py +++ b/InvenTree/InvenTree/config.py @@ -1,6 +1,4 @@ -""" -Helper functions for loading InvenTree configuration options -""" +"""Helper functions for loading InvenTree configuration options.""" import logging import os @@ -10,17 +8,15 @@ logger = logging.getLogger('inventree') def get_base_dir(): - """ Returns the base (top-level) InvenTree directory """ + """Returns the base (top-level) InvenTree directory.""" return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) def get_config_file(): - """ - Returns the path of the InvenTree configuration file. + """Returns the path of the InvenTree configuration file. Note: It will be created it if does not already exist! """ - base_dir = get_base_dir() cfg_filename = os.getenv('INVENTREE_CONFIG_FILE') @@ -43,8 +39,7 @@ def get_config_file(): def get_plugin_file(): - """ - Returns the path of the InvenTree plugins specification file. + """Returns the path of the InvenTree plugins specification file. Note: It will be created if it does not already exist! """ @@ -70,14 +65,12 @@ def get_plugin_file(): def get_setting(environment_var, backup_val, default_value=None): - """ - Helper function for retrieving a configuration setting value + """Helper function for retrieving a configuration setting value. - First preference is to look for the environment variable - Second preference is to look for the value of the settings file - Third preference is the default value """ - val = os.getenv(environment_var) if val is not None: diff --git a/InvenTree/InvenTree/context.py b/InvenTree/InvenTree/context.py index c0b27ad5bf..4e3d38d1b0 100644 --- a/InvenTree/InvenTree/context.py +++ b/InvenTree/InvenTree/context.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- -""" -Provides extra global data to all templates. -""" +"""Provides extra global data to all templates.""" import InvenTree.status from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatus, @@ -12,13 +10,11 @@ from users.models import RuleSet def health_status(request): - """ - Provide system health status information to the global context. + """Provide system health status information to the global context. - Not required for AJAX requests - Do not provide if it is already provided to the context """ - if request.path.endswith('.js'): # Do not provide to script requests return {} # pragma: no cover @@ -53,10 +49,7 @@ def health_status(request): def status_codes(request): - """ - Provide status code enumerations. - """ - + """Provide status code enumerations.""" if hasattr(request, '_inventree_status_codes'): # Do not duplicate efforts return {} @@ -74,8 +67,7 @@ def status_codes(request): def user_roles(request): - """ - Return a map of the current roles assigned to the user. + """Return a map of the current roles assigned to the user. Roles are denoted by their simple names, and then the permission type. @@ -86,7 +78,6 @@ def user_roles(request): Each value will return a boolean True / False """ - user = request.user roles = { diff --git a/InvenTree/InvenTree/exceptions.py b/InvenTree/InvenTree/exceptions.py index a4737bac4d..2b64fd8b64 100644 --- a/InvenTree/InvenTree/exceptions.py +++ b/InvenTree/InvenTree/exceptions.py @@ -1,6 +1,4 @@ -""" -Custom exception handling for the DRF API -""" +"""Custom exception handling for the DRF API.""" # -*- coding: utf-8 -*- from __future__ import unicode_literals @@ -21,13 +19,11 @@ from rest_framework.response import Response def exception_handler(exc, context): - """ - Custom exception handler for DRF framework. - Ref: https://www.django-rest-framework.org/api-guide/exceptions/#custom-exception-handling + """Custom exception handler for DRF framework. + Ref: https://www.django-rest-framework.org/api-guide/exceptions/#custom-exception-handling Catches any errors not natively handled by DRF, and re-throws as an error DRF can handle """ - response = None # Catch any django validation error, and re-throw a DRF validation error diff --git a/InvenTree/InvenTree/exchange.py b/InvenTree/InvenTree/exchange.py index a46e1356a8..52f6b06725 100644 --- a/InvenTree/InvenTree/exchange.py +++ b/InvenTree/InvenTree/exchange.py @@ -1,3 +1,5 @@ +"""Exchangerate backend to use `exchangerate.host` to get rates.""" + import ssl from urllib.error import URLError from urllib.request import urlopen @@ -11,30 +13,30 @@ from common.settings import currency_code_default, currency_codes class InvenTreeExchange(SimpleExchangeBackend): - """ - Backend for automatically updating currency exchange rates. + """Backend for automatically updating currency exchange rates. - Uses the exchangerate.host service API + Uses the `exchangerate.host` service API """ name = "InvenTreeExchange" def __init__(self): + """Set API url.""" self.url = "https://api.exchangerate.host/latest" super().__init__() def get_params(self): + """Placeholder to set API key. Currently not required by `exchangerate.host`.""" # No API key is required return { } def get_response(self, **kwargs): - """ - Custom code to get response from server. + """Custom code to get response from server. + Note: Adds a 5-second timeout """ - url = self.get_url(**kwargs) try: @@ -46,7 +48,7 @@ class InvenTreeExchange(SimpleExchangeBackend): return None def update_rates(self, base_currency=currency_code_default()): - + """Set the requested currency codes and get rates.""" symbols = ','.join(currency_codes()) try: diff --git a/InvenTree/InvenTree/fields.py b/InvenTree/InvenTree/fields.py index 9995c5444c..9f5d63312b 100644 --- a/InvenTree/InvenTree/fields.py +++ b/InvenTree/InvenTree/fields.py @@ -1,4 +1,4 @@ -""" Custom fields used in InvenTree """ +"""Custom fields used in InvenTree.""" import sys from decimal import Decimal @@ -19,24 +19,25 @@ from .validators import allowable_url_schemes class InvenTreeURLFormField(FormURLField): - """ Custom URL form field with custom scheme validators """ + """Custom URL form field with custom scheme validators.""" default_validators = [validators.URLValidator(schemes=allowable_url_schemes())] class InvenTreeURLField(models.URLField): - """ Custom URL field which has custom scheme validators """ + """Custom URL field which has custom scheme validators.""" default_validators = [validators.URLValidator(schemes=allowable_url_schemes())] def formfield(self, **kwargs): + """Return a Field instance for this field.""" return super().formfield(**{ 'form_class': InvenTreeURLFormField }) def money_kwargs(): - """ returns the database settings for MoneyFields """ + """Returns the database settings for MoneyFields.""" from common.settings import currency_code_default, currency_code_mappings kwargs = {} @@ -46,11 +47,10 @@ def money_kwargs(): class InvenTreeModelMoneyField(ModelMoneyField): - """ - Custom MoneyField for clean migrations while using dynamic currency settings - """ + """Custom MoneyField for clean migrations while using dynamic currency settings.""" def __init__(self, **kwargs): + """Overwrite default values and validators.""" # detect if creating migration if 'migrate' in sys.argv or 'makemigrations' in sys.argv: # remove currency information for a clean migration @@ -73,26 +73,24 @@ class InvenTreeModelMoneyField(ModelMoneyField): super().__init__(**kwargs) def formfield(self, **kwargs): - """ override form class to use own function """ + """Override form class to use own function.""" kwargs['form_class'] = InvenTreeMoneyField return super().formfield(**kwargs) class InvenTreeMoneyField(MoneyField): - """ custom MoneyField for clean migrations while using dynamic currency settings """ + """Custom MoneyField for clean migrations while using dynamic currency settings.""" def __init__(self, *args, **kwargs): - # override initial values with the real info from database + """Override initial values with the real info from database.""" kwargs.update(money_kwargs()) super().__init__(*args, **kwargs) class DatePickerFormField(forms.DateField): - """ - Custom date-picker field - """ + """Custom date-picker field.""" def __init__(self, **kwargs): - + """Set up custom values.""" help_text = kwargs.get('help_text', _('Enter date')) label = kwargs.get('label', None) required = kwargs.get('required', False) @@ -115,10 +113,7 @@ class DatePickerFormField(forms.DateField): def round_decimal(value, places): - """ - Round value to the specified number of places. - """ - + """Round value to the specified number of places.""" if value is not None: # see https://docs.python.org/2/library/decimal.html#decimal.Decimal.quantize for options return value.quantize(Decimal(10) ** -places) @@ -126,17 +121,19 @@ def round_decimal(value, places): class RoundingDecimalFormField(forms.DecimalField): + """Custom FormField that automatically rounds inputs.""" + def to_python(self, value): + """Convert value to python type.""" value = super().to_python(value) value = round_decimal(value, self.decimal_places) return value def prepare_value(self, value): - """ - Override the 'prepare_value' method, to remove trailing zeros when displaying. + """Override the 'prepare_value' method, to remove trailing zeros when displaying. + Why? It looks nice! """ - if type(value) == Decimal: return InvenTree.helpers.normalize(value) else: @@ -144,11 +141,15 @@ class RoundingDecimalFormField(forms.DecimalField): class RoundingDecimalField(models.DecimalField): + """Custom Field that automatically rounds inputs.""" + def to_python(self, value): + """Convert value to python type.""" value = super().to_python(value) return round_decimal(value, self.decimal_places) def formfield(self, **kwargs): + """Return a Field instance for this field.""" defaults = { 'form_class': RoundingDecimalFormField } diff --git a/InvenTree/InvenTree/filters.py b/InvenTree/InvenTree/filters.py index f0058e399a..6cd405cf9e 100644 --- a/InvenTree/InvenTree/filters.py +++ b/InvenTree/InvenTree/filters.py @@ -1,9 +1,10 @@ +"""General filters for InvenTree.""" + from rest_framework.filters import OrderingFilter class InvenTreeOrderingFilter(OrderingFilter): - """ - Custom OrderingFilter class which allows aliased filtering of related fields. + """Custom OrderingFilter class which allows aliased filtering of related fields. To use, simply specify this filter in the "filter_backends" section. @@ -20,16 +21,14 @@ class InvenTreeOrderingFilter(OrderingFilter): """ def get_ordering(self, request, queryset, view): - + """Override ordering for supporting aliases.""" ordering = super().get_ordering(request, queryset, view) aliases = getattr(view, 'ordering_field_aliases', None) # Attempt to map ordering fields based on provided aliases if ordering is not None and aliases is not None: - """ - Ordering fields should be mapped to separate fields - """ + """Ordering fields should be mapped to separate fields.""" ordering_initial = ordering ordering = [] diff --git a/InvenTree/InvenTree/forms.py b/InvenTree/InvenTree/forms.py index 50e2d26fff..7ee9c8634a 100644 --- a/InvenTree/InvenTree/forms.py +++ b/InvenTree/InvenTree/forms.py @@ -1,6 +1,4 @@ -""" -Helper forms which subclass Django forms to provide additional functionality -""" +"""Helper forms which subclass Django forms to provide additional functionality.""" import logging from urllib.parse import urlencode @@ -30,7 +28,7 @@ logger = logging.getLogger('inventree') class HelperForm(forms.ModelForm): - """ Provides simple integration of crispy_forms extension. """ + """Provides simple integration of crispy_forms extension.""" # Custom field decorations can be specified here, per form class field_prefix = {} @@ -38,6 +36,7 @@ class HelperForm(forms.ModelForm): field_placeholder = {} def __init__(self, *args, **kwargs): + """Setup layout.""" super(forms.ModelForm, self).__init__(*args, **kwargs) self.helper = FormHelper() @@ -54,14 +53,8 @@ class HelperForm(forms.ModelForm): self.rebuild_layout() - def is_valid(self): - - valid = super().is_valid() - - return valid - def rebuild_layout(self): - + """Build crispy layout out of current fields.""" layouts = [] for field in self.fields: @@ -117,7 +110,7 @@ class HelperForm(forms.ModelForm): class ConfirmForm(forms.Form): - """ Generic confirmation form """ + """Generic confirmation form.""" confirm = forms.BooleanField( required=False, initial=False, @@ -125,14 +118,15 @@ class ConfirmForm(forms.Form): ) class Meta: + """Metaclass options.""" + fields = [ 'confirm' ] class DeleteForm(forms.Form): - """ Generic deletion form which provides simple user confirmation - """ + """Generic deletion form which provides simple user confirmation.""" confirm_delete = forms.BooleanField( required=False, @@ -142,17 +136,19 @@ class DeleteForm(forms.Form): ) class Meta: + """Metaclass options.""" + fields = [ 'confirm_delete' ] class EditUserForm(HelperForm): - """ - Form for editing user information - """ + """Form for editing user information.""" class Meta: + """Metaclass options.""" + model = User fields = [ 'first_name', @@ -161,8 +157,7 @@ class EditUserForm(HelperForm): class SetPasswordForm(HelperForm): - """ Form for setting user password - """ + """Form for setting user password.""" enter_password = forms.CharField(max_length=100, min_length=8, @@ -181,6 +176,8 @@ class SetPasswordForm(HelperForm): help_text=_('Confirm new password')) class Meta: + """Metaclass options.""" + model = User fields = [ 'enter_password', @@ -189,18 +186,21 @@ class SetPasswordForm(HelperForm): class SettingCategorySelectForm(forms.ModelForm): - """ Form for setting category settings """ + """Form for setting category settings.""" category = forms.ModelChoiceField(queryset=PartCategory.objects.all()) class Meta: + """Metaclass options.""" + model = PartCategory fields = [ 'category' ] def __init__(self, *args, **kwargs): - super(SettingCategorySelectForm, self).__init__(*args, **kwargs) + """Setup form layout.""" + super().__init__(*args, **kwargs) self.helper = FormHelper() # Form rendering @@ -220,10 +220,10 @@ class SettingCategorySelectForm(forms.ModelForm): # override allauth class CustomSignupForm(SignupForm): - """ - Override to use dynamic settings - """ + """Override to use dynamic settings.""" + def __init__(self, *args, **kwargs): + """Check settings to influence which fields are needed.""" kwargs['email_required'] = InvenTreeSetting.get_setting('LOGIN_MAIL_REQUIRED') super().__init__(*args, **kwargs) @@ -248,6 +248,7 @@ class CustomSignupForm(SignupForm): set_form_field_order(self, ["username", "email", "email2", "password1", "password2", ]) def clean(self): + """Make sure the supllied emails match if enabled in settings.""" cleaned_data = super().clean() # check for two mail fields @@ -261,15 +262,16 @@ class CustomSignupForm(SignupForm): class RegistratonMixin: - """ - Mixin to check if registration should be enabled - """ + """Mixin to check if registration should be enabled.""" + def is_open_for_signup(self, request, *args, **kwargs): + """Check if signup is enabled in settings.""" if settings.EMAIL_HOST and InvenTreeSetting.get_setting('LOGIN_ENABLE_REG', True): return super().is_open_for_signup(request, *args, **kwargs) return False def save_user(self, request, user, form, commit=True): + """Check if a default group is set in settings.""" user = super().save_user(request, user, form) start_group = InvenTreeSetting.get_setting('SIGNUP_GROUP') if start_group: @@ -283,21 +285,19 @@ class RegistratonMixin: class CustomAccountAdapter(RegistratonMixin, OTPAdapter, DefaultAccountAdapter): - """ - Override of adapter to use dynamic settings - """ + """Override of adapter to use dynamic settings.""" def send_mail(self, template_prefix, email, context): - """only send mail if backend configured""" + """Only send mail if backend configured.""" if settings.EMAIL_HOST: return super().send_mail(template_prefix, email, context) return False class CustomSocialAccountAdapter(RegistratonMixin, DefaultSocialAccountAdapter): - """ - Override of adapter to use dynamic settings - """ + """Override of adapter to use dynamic settings.""" + def is_auto_signup_allowed(self, request, sociallogin): + """Check if auto signup is enabled in settings.""" if InvenTreeSetting.get_setting('LOGIN_SIGNUP_SSO_AUTO', True): return super().is_auto_signup_allowed(request, sociallogin) return False @@ -308,6 +308,7 @@ class CustomSocialAccountAdapter(RegistratonMixin, DefaultSocialAccountAdapter): return user_has_valid_totp_device(user) def login(self, request, user): + """Ensure user is send to 2FA before login if enabled.""" # Require two-factor authentication if it has been configured. if self.has_2fa_enabled(user): # Cast to string for the case when this is not a JSON serializable diff --git a/InvenTree/InvenTree/helpers.py b/InvenTree/InvenTree/helpers.py index 00ac33ae68..2a7f6f57c2 100644 --- a/InvenTree/InvenTree/helpers.py +++ b/InvenTree/InvenTree/helpers.py @@ -1,6 +1,4 @@ -""" -Provides helper functions used throughout the InvenTree project -""" +"""Provides helper functions used throughout the InvenTree project.""" import io import json @@ -27,21 +25,15 @@ from .settings import MEDIA_URL, STATIC_URL def getSetting(key, backup_value=None): - """ - Shortcut for reading a setting value from the database - """ - + """Shortcut for reading a setting value from the database.""" return InvenTreeSetting.get_setting(key, backup_value=backup_value) def generateTestKey(test_name): - """ - Generate a test 'key' for a given test name. - This must not have illegal chars as it will be used for dict lookup in a template. + """Generate a test 'key' for a given test name. This must not have illegal chars as it will be used for dict lookup in a template. Tests must be named such that they will have unique keys. """ - key = test_name.strip().lower() key = key.replace(" ", "") @@ -52,33 +44,23 @@ def generateTestKey(test_name): def getMediaUrl(filename): - """ - Return the qualified access path for the given file, - under the media directory. - """ - + """Return the qualified access path for the given file, under the media directory.""" return os.path.join(MEDIA_URL, str(filename)) def getStaticUrl(filename): - """ - Return the qualified access path for the given file, - under the static media directory. - """ - + """Return the qualified access path for the given file, under the static media directory.""" return os.path.join(STATIC_URL, str(filename)) def construct_absolute_url(*arg): - """ - Construct (or attempt to construct) an absolute URL from a relative URL. + """Construct (or attempt to construct) an absolute URL from a relative URL. This is useful when (for example) sending an email to a user with a link to something in the InvenTree web framework. This requires the BASE_URL configuration option to be set! """ - base = str(InvenTreeSetting.get_setting('INVENTREE_BASE_URL')) url = '/'.join(arg) @@ -99,23 +81,17 @@ def construct_absolute_url(*arg): def getBlankImage(): - """ - Return the qualified path for the 'blank image' placeholder. - """ - + """Return the qualified path for the 'blank image' placeholder.""" return getStaticUrl("img/blank_image.png") def getBlankThumbnail(): - """ - Return the qualified path for the 'blank image' thumbnail placeholder. - """ - + """Return the qualified path for the 'blank image' thumbnail placeholder.""" return getStaticUrl("img/blank_image.thumbnail.png") def TestIfImage(img): - """ Test if an image file is indeed an image """ + """Test if an image file is indeed an image.""" try: Image.open(img).verify() return True @@ -124,7 +100,7 @@ def TestIfImage(img): def TestIfImageURL(url): - """ Test if an image URL (or filename) looks like a valid image format. + """Test if an image URL (or filename) looks like a valid image format. Simply tests the extension against a set of allowed values """ @@ -137,7 +113,7 @@ def TestIfImageURL(url): def str2bool(text, test=True): - """ Test if a string 'looks' like a boolean value. + """Test if a string 'looks' like a boolean value. Args: text: Input text @@ -153,10 +129,7 @@ def str2bool(text, test=True): def is_bool(text): - """ - Determine if a string value 'looks' like a boolean. - """ - + """Determine if a string value 'looks' like a boolean.""" if str2bool(text, True): return True elif str2bool(text, False): @@ -166,9 +139,7 @@ def is_bool(text): def isNull(text): - """ - Test if a string 'looks' like a null value. - This is useful for querying the API against a null key. + """Test if a string 'looks' like a null value. This is useful for querying the API against a null key. Args: text: Input text @@ -176,15 +147,11 @@ def isNull(text): Returns: True if the text looks like a null value """ - return str(text).strip().lower() in ['top', 'null', 'none', 'empty', 'false', '-1', ''] def normalize(d): - """ - Normalize a decimal number, and remove exponential formatting. - """ - + """Normalize a decimal number, and remove exponential formatting.""" if type(d) is not Decimal: d = Decimal(d) @@ -195,8 +162,7 @@ def normalize(d): def increment(n): - """ - Attempt to increment an integer (or a string that looks like an integer!) + """Attempt to increment an integer (or a string that looks like an integer). e.g. @@ -204,9 +170,7 @@ def increment(n): 2 -> 3 AB01 -> AB02 QQQ -> QQQ - """ - value = str(n).strip() # Ignore empty strings @@ -248,10 +212,7 @@ def increment(n): def decimal2string(d): - """ - Format a Decimal number as a string, - stripping out any trailing zeroes or decimal points. - Essentially make it look like a whole number if it is one. + """Format a Decimal number as a string, stripping out any trailing zeroes or decimal points. Essentially make it look like a whole number if it is one. Args: d: A python Decimal object @@ -259,7 +220,6 @@ def decimal2string(d): Returns: A string representation of the input number """ - if type(d) is Decimal: d = normalize(d) @@ -280,8 +240,7 @@ def decimal2string(d): def decimal2money(d, currency=None): - """ - Format a Decimal number as Money + """Format a Decimal number as Money. Args: d: A python Decimal object @@ -296,7 +255,7 @@ def decimal2money(d, currency=None): def WrapWithQuotes(text, quote='"'): - """ Wrap the supplied text with quotes + """Wrap the supplied text with quotes. Args: text: Input text to wrap @@ -305,7 +264,6 @@ def WrapWithQuotes(text, quote='"'): Returns: Supplied text wrapped in quote char """ - if not text.startswith(quote): text = quote + text @@ -316,7 +274,7 @@ def WrapWithQuotes(text, quote='"'): def MakeBarcode(object_name, object_pk, object_data=None, **kwargs): - """ Generate a string for a barcode. Adds some global InvenTree parameters. + """Generate a string for a barcode. Adds some global InvenTree parameters. Args: object_type: string describing the object type e.g. 'StockItem' @@ -363,8 +321,7 @@ def MakeBarcode(object_name, object_pk, object_data=None, **kwargs): def GetExportFormats(): - """ Return a list of allowable file formats for exporting data """ - + """Return a list of allowable file formats for exporting data.""" return [ 'csv', 'tsv', @@ -375,9 +332,8 @@ def GetExportFormats(): ] -def DownloadFile(data, filename, content_type='application/text', inline=False): - """ - Create a dynamic file for the user to download. +def DownloadFile(data, filename, content_type='application/text', inline=False) -> StreamingHttpResponse: + """Create a dynamic file for the user to download. Args: data: Raw file data (string or bytes) @@ -388,7 +344,6 @@ def DownloadFile(data, filename, content_type='application/text', inline=False): Return: A StreamingHttpResponse object wrapping the supplied data """ - filename = WrapWithQuotes(filename) if type(data) == str: @@ -407,8 +362,7 @@ def DownloadFile(data, filename, content_type='application/text', inline=False): def extract_serial_numbers(serials, expected_quantity, next_number: int): - """ - Attempt to extract serial numbers from an input string: + """Attempt to extract serial numbers from an input string. Requirements: - Serial numbers can be either strings, or integers @@ -423,7 +377,6 @@ def extract_serial_numbers(serials, expected_quantity, next_number: int): expected_quantity: The number of (unique) serial numbers we expect next_number(int): the next possible serial number """ - serials = serials.strip() # fill in the next serial number into the serial @@ -543,8 +496,7 @@ def extract_serial_numbers(serials, expected_quantity, next_number: int): def validateFilterString(value, model=None): - """ - Validate that a provided filter string looks like a list of comma-separated key=value pairs + """Validate that a provided filter string looks like a list of comma-separated key=value pairs. These should nominally match to a valid database filter based on the model being filtered. @@ -559,7 +511,6 @@ def validateFilterString(value, model=None): Returns a map of key:value pairs """ - # Empty results map results = {} @@ -605,28 +556,19 @@ def validateFilterString(value, model=None): def addUserPermission(user, permission): - """ - Shortcut function for adding a certain permission to a user. - """ - + """Shortcut function for adding a certain permission to a user.""" perm = Permission.objects.get(codename=permission) user.user_permissions.add(perm) def addUserPermissions(user, permissions): - """ - Shortcut function for adding multiple permissions to a user. - """ - + """Shortcut function for adding multiple permissions to a user.""" for permission in permissions: addUserPermission(user, permission) def getMigrationFileNames(app): - """ - Return a list of all migration filenames for provided app - """ - + """Return a list of all migration filenames for provided app.""" local_dir = os.path.dirname(os.path.abspath(__file__)) migration_dir = os.path.join(local_dir, '..', app, 'migrations') @@ -646,10 +588,7 @@ def getMigrationFileNames(app): def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True): - """ - Return the filename associated with the oldest migration - """ - + """Return the filename associated with the oldest migration.""" oldest_num = -1 oldest_file = None @@ -671,10 +610,7 @@ def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True): def getNewestMigrationFile(app, exclude_extension=True): - """ - Return the filename associated with the newest migration - """ - + """Return the filename associated with the newest migration.""" newest_file = None newest_num = -1 @@ -692,8 +628,7 @@ def getNewestMigrationFile(app, exclude_extension=True): def clean_decimal(number): - """ Clean-up decimal value """ - + """Clean-up decimal value.""" # Check if empty if number is None or number == '' or number == 0: return Decimal(0) @@ -729,7 +664,7 @@ def clean_decimal(number): def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = 'object_id'): - """lookup method for the GenericForeignKey fields + """Lookup method for the GenericForeignKey fields. Attributes: - obj: object that will be resolved @@ -769,9 +704,7 @@ def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = ' def inheritors(cls): - """ - Return all classes that are subclasses from the supplied cls - """ + """Return all classes that are subclasses from the supplied cls.""" subcls = set() work = [cls] while work: @@ -784,4 +717,5 @@ def inheritors(cls): class InvenTreeTestCase(UserMixin, TestCase): + """Testcase with user setup buildin.""" pass diff --git a/InvenTree/InvenTree/management/commands/clean_settings.py b/InvenTree/InvenTree/management/commands/clean_settings.py index 7607f2a574..77cc444422 100644 --- a/InvenTree/InvenTree/management/commands/clean_settings.py +++ b/InvenTree/InvenTree/management/commands/clean_settings.py @@ -1,6 +1,4 @@ -""" -Custom management command to cleanup old settings that are not defined anymore -""" +"""Custom management command to cleanup old settings that are not defined anymore.""" import logging @@ -10,12 +8,10 @@ logger = logging.getLogger('inventree') class Command(BaseCommand): - """ - Cleanup old (undefined) settings in the database - """ + """Cleanup old (undefined) settings in the database.""" def handle(self, *args, **kwargs): - + """Cleanup old (undefined) settings in the database.""" logger.info("Collecting settings") from common.models import InvenTreeSetting, InvenTreeUserSetting diff --git a/InvenTree/InvenTree/management/commands/prerender.py b/InvenTree/InvenTree/management/commands/prerender.py index efaaad80dc..a7338338cb 100644 --- a/InvenTree/InvenTree/management/commands/prerender.py +++ b/InvenTree/InvenTree/management/commands/prerender.py @@ -1,6 +1,4 @@ -""" -Custom management command to prerender files -""" +"""Custom management command to prerender files.""" import os @@ -13,7 +11,7 @@ from django.utils.translation import override as lang_over def render_file(file_name, source, target, locales, ctx): - """ renders a file into all provided locales """ + """Renders a file into all provided locales.""" for locale in locales: target_file = os.path.join(target, locale + '.' + file_name) with open(target_file, 'w') as localised_file: @@ -23,11 +21,10 @@ def render_file(file_name, source, target, locales, ctx): class Command(BaseCommand): - """ - django command to prerender files - """ + """Django command to prerender files.""" def handle(self, *args, **kwargs): + """Django command to prerender files.""" # static directories LC_DIR = settings.LOCALE_PATHS[0] SOURCE_DIR = settings.STATICFILES_I18_SRC diff --git a/InvenTree/InvenTree/management/commands/rebuild_models.py b/InvenTree/InvenTree/management/commands/rebuild_models.py index 2a60da9365..b93d1b4627 100644 --- a/InvenTree/InvenTree/management/commands/rebuild_models.py +++ b/InvenTree/InvenTree/management/commands/rebuild_models.py @@ -1,5 +1,4 @@ -""" -Custom management command to rebuild all MPTT models +"""Custom management command to rebuild all MPTT models. - This is crucial after importing any fixtures, etc """ @@ -8,12 +7,10 @@ from django.core.management.base import BaseCommand class Command(BaseCommand): - """ - Rebuild all database models which leverage the MPTT structure. - """ + """Rebuild all database models which leverage the MPTT structure.""" def handle(self, *args, **kwargs): - + """Rebuild all database models which leverage the MPTT structure.""" # Part model try: print("Rebuilding Part objects") diff --git a/InvenTree/InvenTree/management/commands/rebuild_thumbnails.py b/InvenTree/InvenTree/management/commands/rebuild_thumbnails.py index 3d684df06d..a37be73b9c 100644 --- a/InvenTree/InvenTree/management/commands/rebuild_thumbnails.py +++ b/InvenTree/InvenTree/management/commands/rebuild_thumbnails.py @@ -1,5 +1,4 @@ -""" -Custom management command to rebuild thumbnail images +"""Custom management command to rebuild thumbnail images. - May be required after importing a new dataset, for example """ @@ -20,15 +19,10 @@ logger = logging.getLogger('inventree') class Command(BaseCommand): - """ - Rebuild all thumbnail images - """ + """Rebuild all thumbnail images.""" def rebuild_thumbnail(self, model): - """ - Rebuild the thumbnail specified by the "image" field of the provided model - """ - + """Rebuild the thumbnail specified by the "image" field of the provided model.""" if not model.image: return @@ -47,7 +41,7 @@ class Command(BaseCommand): logger.warning(f"Warning: Image file '{img}' is not a valid image") def handle(self, *args, **kwargs): - + """Rebuild all thumbnail images.""" logger.info("Rebuilding Part thumbnails") for part in Part.objects.exclude(image=None): diff --git a/InvenTree/InvenTree/management/commands/remove_mfa.py b/InvenTree/InvenTree/management/commands/remove_mfa.py index 53266348e2..cd4ef8c88b 100644 --- a/InvenTree/InvenTree/management/commands/remove_mfa.py +++ b/InvenTree/InvenTree/management/commands/remove_mfa.py @@ -1,21 +1,18 @@ -""" -Custom management command to remove MFA for a user -""" +"""Custom management command to remove MFA for a user.""" from django.contrib.auth import get_user_model from django.core.management.base import BaseCommand class Command(BaseCommand): - """ - Remove MFA for a user - """ + """Remove MFA for a user.""" def add_arguments(self, parser): + """Add the arguments.""" parser.add_argument('mail', type=str) def handle(self, *args, **kwargs): - + """Remove MFA for the supplied user (by mail).""" # general settings mail = kwargs.get('mail') if not mail: diff --git a/InvenTree/InvenTree/management/commands/wait_for_db.py b/InvenTree/InvenTree/management/commands/wait_for_db.py index ebd6999f7f..b8b9742f0e 100644 --- a/InvenTree/InvenTree/management/commands/wait_for_db.py +++ b/InvenTree/InvenTree/management/commands/wait_for_db.py @@ -1,6 +1,4 @@ -""" -Custom management command, wait for the database to be ready! -""" +"""Custom management command, wait for the database to be ready!""" import time @@ -10,12 +8,10 @@ from django.db.utils import ImproperlyConfigured, OperationalError class Command(BaseCommand): - """ - django command to pause execution until the database is ready - """ + """Django command to pause execution until the database is ready.""" def handle(self, *args, **kwargs): - + """Wait till the database is ready.""" self.stdout.write("Waiting for database...") connected = False diff --git a/InvenTree/InvenTree/metadata.py b/InvenTree/InvenTree/metadata.py index e2d42bb539..473438b5df 100644 --- a/InvenTree/InvenTree/metadata.py +++ b/InvenTree/InvenTree/metadata.py @@ -1,3 +1,5 @@ +"""Custom metadata for DRF.""" + import logging from rest_framework import serializers @@ -12,8 +14,7 @@ logger = logging.getLogger('inventree') class InvenTreeMetadata(SimpleMetadata): - """ - Custom metadata class for the DRF API. + """Custom metadata class for the DRF API. This custom metadata class imits the available "actions", based on the user's role permissions. @@ -23,11 +24,10 @@ class InvenTreeMetadata(SimpleMetadata): Additionally, we include some extra information about database models, so we can perform lookup for ForeignKey related fields. - """ def determine_metadata(self, request, view): - + """Overwrite the metadata to adapt to hte request user.""" self.request = request self.view = view @@ -106,11 +106,7 @@ class InvenTreeMetadata(SimpleMetadata): return metadata def get_serializer_info(self, serializer): - """ - Override get_serializer_info so that we can add 'default' values - to any fields whose Meta.model specifies a default value - """ - + """Override get_serializer_info so that we can add 'default' values to any fields whose Meta.model specifies a default value.""" self.serializer = serializer serializer_info = super().get_serializer_info(serializer) @@ -208,10 +204,7 @@ class InvenTreeMetadata(SimpleMetadata): pass if instance is not None: - """ - If there is an instance associated with this API View, - introspect that instance to find any specific API info. - """ + """If there is an instance associated with this API View, introspect that instance to find any specific API info.""" if hasattr(instance, 'api_instance_filters'): @@ -233,13 +226,10 @@ class InvenTreeMetadata(SimpleMetadata): return serializer_info def get_field_info(self, field): - """ - Given an instance of a serializer field, return a dictionary - of metadata about it. + """Given an instance of a serializer field, return a dictionary of metadata about it. We take the regular DRF metadata and add our own unique flavor """ - # Run super method first field_info = super().get_field_info(field) diff --git a/InvenTree/InvenTree/middleware.py b/InvenTree/InvenTree/middleware.py index aaf13f7623..60df027bf5 100644 --- a/InvenTree/InvenTree/middleware.py +++ b/InvenTree/InvenTree/middleware.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +"""Middleware for InvenTree.""" import logging @@ -19,10 +19,17 @@ logger = logging.getLogger("inventree") class AuthRequiredMiddleware(object): + """Check for user to be authenticated.""" + def __init__(self, get_response): + """Save response object.""" self.get_response = get_response def __call__(self, request): + """Check if user needs to be authenticated and is. + + Redirects to login if not authenticated. + """ # Code to be executed for each request before # the view (and later middleware) are called. @@ -35,6 +42,7 @@ class AuthRequiredMiddleware(object): if not request.user.is_authenticated: """ Normally, a web-based session would use csrftoken based authentication. + However when running an external application (e.g. the InvenTree app or Python library), we must validate the user token manually. """ @@ -105,9 +113,9 @@ url_matcher = re_path('', include(frontendpatterns)) class Check2FAMiddleware(BaseRequire2FAMiddleware): - """check if user is required to have MFA enabled""" + """Check if user is required to have MFA enabled.""" def require_2fa(self, request): - # Superusers are require to have 2FA. + """Use setting to check if MFA should be enforced for frontend page.""" try: if url_matcher.resolve(request.path[1:]): return InvenTreeSetting.get_setting('LOGIN_ENFORCE_MFA') @@ -117,8 +125,9 @@ class Check2FAMiddleware(BaseRequire2FAMiddleware): class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware): - """This function ensures only frontend code triggers the MFA auth cycle""" + """This function ensures only frontend code triggers the MFA auth cycle.""" def process_request(self, request): + """Check if requested url is forntend and enforce MFA check.""" try: if not url_matcher.resolve(request.path[1:]): super().process_request(request) @@ -127,12 +136,11 @@ class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware): class InvenTreeRemoteUserMiddleware(PersistentRemoteUserMiddleware): - """ - Middleware to check if HTTP-header based auth is enabled and to set it up - """ + """Middleware to check if HTTP-header based auth is enabled and to set it up.""" header = settings.REMOTE_LOGIN_HEADER def process_request(self, request): + """Check if proxy login is enabled.""" if not settings.REMOTE_LOGIN: return diff --git a/InvenTree/InvenTree/models.py b/InvenTree/InvenTree/models.py index aba8b36763..a2b775f71b 100644 --- a/InvenTree/InvenTree/models.py +++ b/InvenTree/InvenTree/models.py @@ -1,6 +1,4 @@ -""" -Generic models which provide extra functionality over base Django model types. -""" +"""Generic models which provide extra functionality over base Django model types.""" import logging import os @@ -25,25 +23,21 @@ logger = logging.getLogger('inventree') def rename_attachment(instance, filename): - """ - Function for renaming an attachment file. - The subdirectory for the uploaded file is determined by the implementing class. + """Function for renaming an attachment file. The subdirectory for the uploaded file is determined by the implementing class. - Args: + Args: instance: Instance of a PartAttachment object filename: name of uploaded file Returns: path to store file, format: '//filename' """ - # Construct a path to store a file attachment for a given model type return os.path.join(instance.getSubdir(), filename) class DataImportMixin(object): - """ - Model mixin class which provides support for 'data import' functionality. + """Model mixin class which provides support for 'data import' functionality. Models which implement this mixin should provide information on the fields available for import """ @@ -53,12 +47,10 @@ class DataImportMixin(object): @classmethod def get_import_fields(cls): - """ - Return all available import fields + """Return all available import fields. Where information on a particular field is not explicitly provided, introspect the base model to (attempt to) find that information. - """ fields = cls.IMPORT_FIELDS @@ -85,7 +77,7 @@ class DataImportMixin(object): @classmethod def get_required_import_fields(cls): - """ Return all *required* import fields """ + """Return all *required* import fields.""" fields = {} for name, field in cls.get_import_fields().items(): @@ -98,8 +90,7 @@ class DataImportMixin(object): class ReferenceIndexingMixin(models.Model): - """ - A mixin for keeping track of numerical copies of the "reference" field. + """A mixin for keeping track of numerical copies of the "reference" field. !!DANGER!! always add `ReferenceIndexingSerializerMixin`to all your models serializers to ensure the reference field is not too big @@ -118,18 +109,20 @@ class ReferenceIndexingMixin(models.Model): """ class Meta: + """Metaclass options. Abstract ensures no database table is created.""" + abstract = True def rebuild_reference_field(self): - + """Extract integer out of reference for sorting.""" reference = getattr(self, 'reference', '') - self.reference_int = extract_int(reference) reference_int = models.BigIntegerField(default=0) def extract_int(reference, clip=0x7fffffff): + """Extract integer out of reference.""" # Default value if we cannot convert to an integer ref_int = 0 @@ -155,7 +148,7 @@ def extract_int(reference, clip=0x7fffffff): class InvenTreeAttachment(models.Model): - """ Provides an abstracted class for managing file attachments. + """Provides an abstracted class for managing file attachments. An attachment can be either an uploaded file, or an external URL @@ -167,14 +160,14 @@ class InvenTreeAttachment(models.Model): """ def getSubdir(self): - """ - Return the subdirectory under which attachments should be stored. + """Return the subdirectory under which attachments should be stored. + Note: Re-implement this for each subclass of InvenTreeAttachment """ - return "attachments" def save(self, *args, **kwargs): + """Provide better validation error.""" # Either 'attachment' or 'link' must be specified! if not self.attachment and not self.link: raise ValidationError({ @@ -185,6 +178,7 @@ class InvenTreeAttachment(models.Model): super().save(*args, **kwargs) def __str__(self): + """Human name for attachment.""" if self.attachment is not None: return os.path.basename(self.attachment.name) else: @@ -215,6 +209,7 @@ class InvenTreeAttachment(models.Model): @property def basename(self): + """Base name/path for attachment.""" if self.attachment: return os.path.basename(self.attachment.name) else: @@ -222,15 +217,13 @@ class InvenTreeAttachment(models.Model): @basename.setter def basename(self, fn): - """ - Function to rename the attachment file. + """Function to rename the attachment file. - Filename cannot be empty - Filename cannot contain illegal characters - Filename must specify an extension - Filename cannot match an existing file """ - fn = fn.strip() if len(fn) == 0: @@ -287,11 +280,13 @@ class InvenTreeAttachment(models.Model): raise ValidationError(_("Error renaming file")) class Meta: + """Metaclass options. Abstract ensures no database table is created.""" + abstract = True class InvenTreeTree(MPTTModel): - """ Provides an abstracted self-referencing tree model for data categories. + """Provides an abstracted self-referencing tree model for data categories. - Each Category has one parent Category, which can be blank (for a top-level Category). - Each Category can have zero-or-more child Categor(y/ies) @@ -303,10 +298,7 @@ class InvenTreeTree(MPTTModel): """ def api_instance_filters(self): - """ - Instance filters for InvenTreeTree models - """ - + """Instance filters for InvenTreeTree models.""" return { 'parent': { 'exclude_tree': self.pk, @@ -314,7 +306,7 @@ class InvenTreeTree(MPTTModel): } def save(self, *args, **kwargs): - + """Provide better error for invalid moves.""" try: super().save(*args, **kwargs) except InvalidMove: @@ -323,12 +315,15 @@ class InvenTreeTree(MPTTModel): }) class Meta: + """Metaclass defines extra model properties.""" + abstract = True # Names must be unique at any given level in the tree unique_together = ('name', 'parent') class MPTTMeta: + """Set insert order.""" order_insertion_by = ['name'] name = models.CharField( @@ -356,7 +351,7 @@ class InvenTreeTree(MPTTModel): @property def item_count(self): - """ Return the number of items which exist *under* this node in the tree. + """Return the number of items which exist *under* this node in the tree. Here an 'item' is considered to be the 'leaf' at the end of each branch, and the exact nature here will depend on the class implementation. @@ -366,30 +361,29 @@ class InvenTreeTree(MPTTModel): return 0 def getUniqueParents(self): - """ Return a flat set of all parent items that exist above this node. + """Return a flat set of all parent items that exist above this node. + If any parents are repeated (which would be very bad!), the process is halted """ - return self.get_ancestors() def getUniqueChildren(self, include_self=True): - """ Return a flat set of all child items that exist under this node. + """Return a flat set of all child items that exist under this node. + If any child items are repeated, the repetitions are omitted. """ - return self.get_descendants(include_self=include_self) @property def has_children(self): - """ True if there are any children under this item """ + """True if there are any children under this item.""" return self.getUniqueChildren(include_self=False).count() > 0 def getAcceptableParents(self): - """ Returns a list of acceptable parent items within this model - Acceptable parents are ones which are not underneath this item. + """Returns a list of acceptable parent items within this model Acceptable parents are ones which are not underneath this item. + Setting the parent of an item to its own child results in recursion. """ - contents = ContentType.objects.get_for_model(type(self)) available = contents.get_all_objects_for_this_type() @@ -407,17 +401,16 @@ class InvenTreeTree(MPTTModel): @property def parentpath(self): - """ Get the parent path of this category + """Get the parent path of this category. Returns: List of category names from the top level to the parent of this category """ - return [a for a in self.get_ancestors()] @property def path(self): - """ Get the complete part of this category. + """Get the complete part of this category. e.g. ["Top", "Second", "Third", "This"] @@ -428,25 +421,23 @@ class InvenTreeTree(MPTTModel): @property def pathstring(self): - """ Get a string representation for the path of this item. + """Get a string representation for the path of this item. e.g. "Top/Second/Third/This" """ return '/'.join([item.name for item in self.path]) def __str__(self): - """ String representation of a category is the full path to that category """ - + """String representation of a category is the full path to that category.""" return "{path} - {desc}".format(path=self.pathstring, desc=self.description) @receiver(pre_delete, sender=InvenTreeTree, dispatch_uid='tree_pre_delete_log') def before_delete_tree_item(sender, instance, using, **kwargs): - """ Receives pre_delete signal from InvenTreeTree object. + """Receives pre_delete signal from InvenTreeTree object. Before an item is deleted, update each child object to point to the parent of the object being deleted. """ - # Update each tree item below this one for child in instance.children.all(): child.parent = instance.parent diff --git a/InvenTree/InvenTree/permissions.py b/InvenTree/InvenTree/permissions.py index 920e111ce2..7f607f6dbf 100644 --- a/InvenTree/InvenTree/permissions.py +++ b/InvenTree/InvenTree/permissions.py @@ -1,12 +1,12 @@ +"""Permission set for InvenTree.""" + from rest_framework import permissions import users.models class RolePermission(permissions.BasePermission): - """ - Role mixin for API endpoints, allowing us to specify the user "role" - which is required for certain operations. + """Role mixin for API endpoints, allowing us to specify the user "role" which is required for certain operations. Each endpoint can have one or more of the following actions: - GET @@ -25,14 +25,10 @@ class RolePermission(permissions.BasePermission): to perform the specified action. For example, a DELETE action will be rejected unless the user has the "part.remove" permission - """ def has_permission(self, request, view): - """ - Determine if the current user has the specified permissions - """ - + """Determine if the current user has the specified permissions.""" user = request.user # Superuser can do it all diff --git a/InvenTree/InvenTree/ready.py b/InvenTree/InvenTree/ready.py index e93972cf2e..f7a319a92d 100644 --- a/InvenTree/InvenTree/ready.py +++ b/InvenTree/InvenTree/ready.py @@ -1,31 +1,24 @@ +"""Functions to check if certain parts of InvenTree are ready.""" + import sys def isInTestMode(): - """ - Returns True if the database is in testing mode - """ - + """Returns True if the database is in testing mode.""" return 'test' in sys.argv def isImportingData(): - """ - Returns True if the database is currently importing data, - e.g. 'loaddata' command is performed - """ - + """Returns True if the database is currently importing data, e.g. 'loaddata' command is performed.""" return 'loaddata' in sys.argv def canAppAccessDatabase(allow_test=False): - """ - Returns True if the apps.py file can access database records. + """Returns True if the apps.py file can access database records. There are some circumstances where we don't want the ready function in apps.py to touch the database """ - # If any of the following management commands are being executed, # prevent custom "on load" code from running! excluded_commands = [ diff --git a/InvenTree/InvenTree/serializers.py b/InvenTree/InvenTree/serializers.py index e91bcab484..22235bd3d8 100644 --- a/InvenTree/InvenTree/serializers.py +++ b/InvenTree/InvenTree/serializers.py @@ -1,6 +1,4 @@ -""" -Serializers used in various InvenTree apps -""" +"""Serializers used in various InvenTree apps.""" import os from collections import OrderedDict @@ -26,25 +24,20 @@ from .models import extract_int class InvenTreeMoneySerializer(MoneyField): - """ - Custom serializer for 'MoneyField', - which ensures that passed values are numerically valid + """Custom serializer for 'MoneyField', which ensures that passed values are numerically valid. Ref: https://github.com/django-money/django-money/blob/master/djmoney/contrib/django_rest_framework/fields.py """ def __init__(self, *args, **kwargs): - + """Overrite default values.""" kwargs["max_digits"] = kwargs.get("max_digits", 19) kwargs["decimal_places"] = kwargs.get("decimal_places", 4) super().__init__(*args, **kwargs) def get_value(self, data): - """ - Test that the returned amount is a valid Decimal - """ - + """Test that the returned amount is a valid Decimal.""" amount = super(DecimalField, self).get_value(data) # Convert an empty string to None @@ -68,17 +61,21 @@ class InvenTreeMoneySerializer(MoneyField): class UserSerializer(serializers.ModelSerializer): - """ Serializer for User - provides all fields """ + """Serializer for User - provides all fields.""" class Meta: + """Metaclass options.""" + model = User fields = 'all' class UserSerializerBrief(serializers.ModelSerializer): - """ Serializer for User - provides limited information """ + """Serializer for User - provides limited information.""" class Meta: + """Metaclass options.""" + model = User fields = [ 'pk', @@ -87,17 +84,10 @@ class UserSerializerBrief(serializers.ModelSerializer): class InvenTreeModelSerializer(serializers.ModelSerializer): - """ - Inherits the standard Django ModelSerializer class, - but also ensures that the underlying model class data are checked on validation. - """ + """Inherits the standard Django ModelSerializer class, but also ensures that the underlying model class data are checked on validation.""" def __init__(self, instance=None, data=empty, **kwargs): - """ - Custom __init__ routine to ensure that *default* values (as specified in the ORM) - are used by the DRF serializers, *if* the values are not provided by the user. - """ - + """Custom __init__ routine to ensure that *default* values (as specified in the ORM) are used by the DRF serializers, *if* the values are not provided by the user.""" # If instance is None, we are creating a new instance if instance is None and data is not empty: @@ -118,6 +108,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): """ Update the field IF (and ONLY IF): + - The field has a specified default value - The field does not already have a value set """ @@ -137,11 +128,10 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): super().__init__(instance, data, **kwargs) def get_initial(self): - """ - Construct initial data for the serializer. + """Construct initial data for the serializer. + Use the 'default' values specified by the django model definition """ - initials = super().get_initial().copy() # Are we creating a new instance? @@ -168,11 +158,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): return initials def save(self, **kwargs): - """ - Catch any django ValidationError thrown at the moment save() is called, - and re-throw as a DRF ValidationError - """ - + """Catch any django ValidationError thrown at the moment `save` is called, and re-throw as a DRF ValidationError.""" try: super().save(**kwargs) except (ValidationError, DjangoValidationError) as exc: @@ -181,10 +167,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): return self.instance def update(self, instance, validated_data): - """ - Catch any django ValidationError, and re-throw as a DRF ValidationError - """ - + """Catch any django ValidationError, and re-throw as a DRF ValidationError.""" try: instance = super().update(instance, validated_data) except (ValidationError, DjangoValidationError) as exc: @@ -193,12 +176,11 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): return instance def run_validation(self, data=empty): - """ - Perform serializer validation. + """Perform serializer validation. + In addition to running validators on the serializer fields, this class ensures that the underlying model is also validated. """ - # Run any native validation checks first (may raise a ValidationError) data = super().run_validation(data) @@ -237,20 +219,17 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): class ReferenceIndexingSerializerMixin(): - """ - This serializer mixin ensures the the reference is not to big / small - for the BigIntegerField - """ + """This serializer mixin ensures the the reference is not to big / small for the BigIntegerField.""" + def validate_reference(self, value): + """Ensures the reference is not to big / small for the BigIntegerField.""" if extract_int(value) > models.BigIntegerField.MAX_BIGINT: raise serializers.ValidationError('reference is to to big') return value class InvenTreeAttachmentSerializerField(serializers.FileField): - """ - Override the DRF native FileField serializer, - to remove the leading server path. + """Override the DRF native FileField serializer, to remove the leading server path. For example, the FileField might supply something like: @@ -269,7 +248,7 @@ class InvenTreeAttachmentSerializerField(serializers.FileField): """ def to_representation(self, value): - + """To json-serializable type.""" if not value: return None @@ -277,8 +256,7 @@ class InvenTreeAttachmentSerializerField(serializers.FileField): class InvenTreeAttachmentSerializer(InvenTreeModelSerializer): - """ - Special case of an InvenTreeModelSerializer, which handles an "attachment" model. + """Special case of an InvenTreeModelSerializer, which handles an "attachment" model. The only real addition here is that we support "renaming" of the attachment file. """ @@ -298,13 +276,13 @@ class InvenTreeAttachmentSerializer(InvenTreeModelSerializer): class InvenTreeImageSerializerField(serializers.ImageField): - """ - Custom image serializer. + """Custom image serializer. + On upload, validate that the file is a valid image file """ def to_representation(self, value): - + """To json-serializable type.""" if not value: return None @@ -312,15 +290,15 @@ class InvenTreeImageSerializerField(serializers.ImageField): class InvenTreeDecimalField(serializers.FloatField): - """ - Custom serializer for decimal fields. Solves the following issues: + """Custom serializer for decimal fields. + Solves the following issues: - The normal DRF DecimalField renders values with trailing zeros - Using a FloatField can result in rounding issues: https://code.djangoproject.com/ticket/30290 """ def to_internal_value(self, data): - + """Convert to python type.""" # Convert the value to a string, and then a decimal try: return Decimal(str(data)) @@ -329,8 +307,7 @@ class InvenTreeDecimalField(serializers.FloatField): class DataFileUploadSerializer(serializers.Serializer): - """ - Generic serializer for uploading a data file, and extracting a dataset. + """Generic serializer for uploading a data file, and extracting a dataset. - Validates uploaded file - Extracts column names @@ -341,6 +318,8 @@ class DataFileUploadSerializer(serializers.Serializer): TARGET_MODEL = None class Meta: + """Metaclass options.""" + fields = [ 'data_file', ] @@ -353,10 +332,7 @@ class DataFileUploadSerializer(serializers.Serializer): ) def validate_data_file(self, data_file): - """ - Perform validation checks on the uploaded data file. - """ - + """Perform validation checks on the uploaded data file.""" self.filename = data_file.name name, ext = os.path.splitext(data_file.name) @@ -406,15 +382,13 @@ class DataFileUploadSerializer(serializers.Serializer): return data_file def match_column(self, column_name, field_names, exact=False): - """ - Attempt to match a column name (from the file) to a field (defined in the model) + """Attempt to match a column name (from the file) to a field (defined in the model). Order of matching is: - Direct match - Case insensitive match - Fuzzy match """ - if not column_name: return None @@ -439,10 +413,7 @@ class DataFileUploadSerializer(serializers.Serializer): return None def extract_data(self): - """ - Returns dataset extracted from the file - """ - + """Returns dataset extracted from the file.""" # Provide a dict of available import fields for the model model_fields = {} @@ -483,12 +454,12 @@ class DataFileUploadSerializer(serializers.Serializer): } def save(self): + """Empty overwrite for save.""" ... class DataFileExtractSerializer(serializers.Serializer): - """ - Generic serializer for extracting data from an imported dataset. + """Generic serializer for extracting data from an imported dataset. - User provides an array of matched headers - User provides an array of raw data rows @@ -498,6 +469,8 @@ class DataFileExtractSerializer(serializers.Serializer): TARGET_MODEL = None class Meta: + """Metaclass options.""" + fields = [ 'columns', 'rows', @@ -520,7 +493,7 @@ class DataFileExtractSerializer(serializers.Serializer): ) def validate(self, data): - + """Clean data.""" data = super().validate(data) self.columns = data.get('columns', []) @@ -538,7 +511,7 @@ class DataFileExtractSerializer(serializers.Serializer): @property def data(self): - + """Returns current data.""" if self.TARGET_MODEL: try: model_fields = self.TARGET_MODEL.get_import_fields() @@ -548,9 +521,7 @@ class DataFileExtractSerializer(serializers.Serializer): rows = [] for row in self.rows: - """ - Optionally pre-process each row, before sending back to the client - """ + """Optionally pre-process each row, before sending back to the client.""" processed_row = self.process_row(self.row_to_dict(row)) @@ -567,22 +538,17 @@ class DataFileExtractSerializer(serializers.Serializer): } def process_row(self, row): - """ - Process a 'row' of data, which is a mapped column:value dict + """Process a 'row' of data, which is a mapped column:value dict. Returns either a mapped column:value dict, or None. If the function returns None, the column is ignored! """ - # Default implementation simply returns the original row data return row def row_to_dict(self, row): - """ - Convert a "row" to a named data dict - """ - + """Convert a "row" to a named data dict.""" row_dict = { 'errors': {}, } @@ -598,10 +564,7 @@ class DataFileExtractSerializer(serializers.Serializer): return row_dict def validate_extracted_columns(self): - """ - Perform custom validation of header mapping. - """ - + """Perform custom validation of header mapping.""" if self.TARGET_MODEL: try: model_fields = self.TARGET_MODEL.get_import_fields() @@ -631,7 +594,5 @@ class DataFileExtractSerializer(serializers.Serializer): cols_seen.add(col) def save(self): - """ - No "save" action for this serializer - """ - ... + """No "save" action for this serializer.""" + pass diff --git a/InvenTree/InvenTree/settings.py b/InvenTree/InvenTree/settings.py index 2f99274bd4..9a6b81d16f 100644 --- a/InvenTree/InvenTree/settings.py +++ b/InvenTree/InvenTree/settings.py @@ -1,5 +1,4 @@ -""" -Django settings for InvenTree project. +"""Django settings for InvenTree project. In practice the settings in this file should not be adjusted, instead settings can be configured in the config.yaml file @@ -8,7 +7,6 @@ located in the top level project directory. This allows implementation configuration to be hidden from source control, as well as separate configuration parameters from the more complex database setup in this file. - """ import logging diff --git a/InvenTree/InvenTree/status.py b/InvenTree/InvenTree/status.py index 9ea57024d8..cc74b850ea 100644 --- a/InvenTree/InvenTree/status.py +++ b/InvenTree/InvenTree/status.py @@ -1,6 +1,4 @@ -""" -Provides system status functionality checks. -""" +"""Provides system status functionality checks.""" # -*- coding: utf-8 -*- import logging @@ -19,10 +17,7 @@ logger = logging.getLogger("inventree") def is_worker_running(**kwargs): - """ - Return True if the background worker process is oprational - """ - + """Return True if the background worker process is oprational.""" clusters = Stat.get_all() if len(clusters) > 0: @@ -48,12 +43,10 @@ def is_worker_running(**kwargs): def is_email_configured(): - """ - Check if email backend is configured. + """Check if email backend is configured. NOTE: This does not check if the configuration is valid! """ - configured = True if InvenTree.ready.isInTestMode(): @@ -87,12 +80,10 @@ def is_email_configured(): def check_system_health(**kwargs): - """ - Check that the InvenTree system is running OK. + """Check that the InvenTree system is running OK. Returns True if all system checks pass. """ - result = True if InvenTree.ready.isInTestMode(): diff --git a/InvenTree/InvenTree/status_codes.py b/InvenTree/InvenTree/status_codes.py index 15f3d872bb..73684817e8 100644 --- a/InvenTree/InvenTree/status_codes.py +++ b/InvenTree/InvenTree/status_codes.py @@ -1,9 +1,11 @@ +"""Status codes for InvenTree.""" + from django.utils.translation import gettext_lazy as _ class StatusCode: - """ - Base class for representing a set of StatusCodes. + """Base class for representing a set of StatusCodes. + This is used to map a set of integer values to text. """ @@ -11,10 +13,7 @@ class StatusCode: @classmethod def render(cls, key, large=False): - """ - Render the value as a HTML label. - """ - + """Render the value as a HTML label.""" # If the key cannot be found, pass it back if key not in cls.options.keys(): return key @@ -31,10 +30,7 @@ class StatusCode: @classmethod def list(cls): - """ - Return the StatusCode options as a list of mapped key / value items - """ - + """Return the StatusCode options as a list of mapped key / value items.""" codes = [] for key in cls.options.keys(): @@ -55,28 +51,32 @@ class StatusCode: @classmethod def text(cls, key): + """Text for supplied status code.""" return cls.options.get(key, None) @classmethod def items(cls): + """All status code items.""" return cls.options.items() @classmethod def keys(cls): + """All status code keys.""" return cls.options.keys() @classmethod def labels(cls): + """All status code labels.""" return cls.options.values() @classmethod def label(cls, value): - """ Return the status code label associated with the provided value """ + """Return the status code label associated with the provided value.""" return cls.options.get(value, value) @classmethod def value(cls, label): - """ Return the value associated with the provided label """ + """Return the value associated with the provided label.""" for k in cls.options.keys(): if cls.options[k].lower() == label.lower(): return k @@ -85,9 +85,7 @@ class StatusCode: class PurchaseOrderStatus(StatusCode): - """ - Defines a set of status codes for a PurchaseOrder - """ + """Defines a set of status codes for a PurchaseOrder.""" # Order status codes PENDING = 10 # Order is pending (not yet placed) @@ -130,7 +128,7 @@ class PurchaseOrderStatus(StatusCode): class SalesOrderStatus(StatusCode): - """ Defines a set of status codes for a SalesOrder """ + """Defines a set of status codes for a SalesOrder.""" PENDING = 10 # Order is pending SHIPPED = 20 # Order has been shipped to customer @@ -166,6 +164,7 @@ class SalesOrderStatus(StatusCode): class StockStatus(StatusCode): + """Status codes for Stock.""" OK = 10 # Item is OK ATTENTION = 50 # Item requires attention @@ -207,6 +206,7 @@ class StockStatus(StatusCode): class StockHistoryCode(StatusCode): + """Status codes for StockHistory.""" LEGACY = 0 @@ -295,8 +295,8 @@ class StockHistoryCode(StatusCode): class BuildStatus(StatusCode): + """Build status codes.""" - # Build status codes PENDING = 10 # Build is pending / active PRODUCTION = 20 # BuildOrder is in production CANCELLED = 30 # Build was cancelled diff --git a/InvenTree/InvenTree/tasks.py b/InvenTree/InvenTree/tasks.py index e118c607fb..35893af2c9 100644 --- a/InvenTree/InvenTree/tasks.py +++ b/InvenTree/InvenTree/tasks.py @@ -1,3 +1,5 @@ +"""Functions for tasks and a few general async tasks.""" + import json import logging import re @@ -16,11 +18,10 @@ logger = logging.getLogger("inventree") def schedule_task(taskname, **kwargs): - """ - Create a scheduled task. + """Create a scheduled task. + If the task has already been scheduled, ignore! """ - # If unspecified, repeat indefinitely repeats = kwargs.pop('repeats', -1) kwargs['repeats'] = repeats @@ -52,7 +53,7 @@ def schedule_task(taskname, **kwargs): def raise_warning(msg): - """Log and raise a warning""" + """Log and raise a warning.""" logger.warning(msg) # If testing is running raise a warning that can be asserted @@ -61,15 +62,11 @@ def raise_warning(msg): def offload_task(taskname, *args, force_sync=False, **kwargs): - """ - Create an AsyncTask if workers are running. - This is different to a 'scheduled' task, - in that it only runs once! + """Create an AsyncTask if workers are running. This is different to a 'scheduled' task, in that it only runs once! - If workers are not running or force_sync flag - is set then the task is ran synchronously. + If workers are not running or force_sync flag + is set then the task is ran synchronously. """ - try: import importlib @@ -129,14 +126,10 @@ def offload_task(taskname, *args, force_sync=False, **kwargs): def heartbeat(): - """ - Simple task which runs at 5 minute intervals, - so we can determine that the background worker - is actually running. + """Simple task which runs at 5 minute intervals, so we can determine that the background worker is actually running. (There is probably a less "hacky" way of achieving this)? """ - try: from django_q.models import Success except AppRegistryNotReady: # pragma: no cover @@ -156,11 +149,7 @@ def heartbeat(): def delete_successful_tasks(): - """ - Delete successful task logs - which are more than a month old. - """ - + """Delete successful task logs which are more than a month old.""" try: from django_q.models import Success except AppRegistryNotReady: # pragma: no cover @@ -179,10 +168,7 @@ def delete_successful_tasks(): def delete_old_error_logs(): - """ - Delete old error logs from the server - """ - + """Delete old error logs from the server.""" try: from error_report.models import Error @@ -204,10 +190,7 @@ def delete_old_error_logs(): def check_for_updates(): - """ - Check if there is an update for InvenTree - """ - + """Check if there is an update for InvenTree.""" try: import common.models except AppRegistryNotReady: # pragma: no cover @@ -249,10 +232,7 @@ def check_for_updates(): def update_exchange_rates(): - """ - Update currency exchange rates - """ - + """Update currency exchange rates.""" try: from djmoney.contrib.exchange.models import ExchangeBackend, Rate @@ -293,11 +273,7 @@ def update_exchange_rates(): def send_email(subject, body, recipients, from_email=None, html_message=None): - """ - Send an email with the specified subject and body, - to the specified recipients list. - """ - + """Send an email with the specified subject and body, to the specified recipients list.""" if type(recipients) == str: recipients = [recipients] diff --git a/InvenTree/InvenTree/test_api.py b/InvenTree/InvenTree/test_api.py index 889ff674b3..46029d1540 100644 --- a/InvenTree/InvenTree/test_api.py +++ b/InvenTree/InvenTree/test_api.py @@ -1,4 +1,4 @@ -""" Low level tests for the InvenTree API """ +"""Low level tests for the InvenTree API.""" from base64 import b64encode @@ -12,8 +12,7 @@ from users.models import RuleSet class HTMLAPITests(InvenTreeTestCase): - """ - Test that we can access the REST API endpoints via the HTML interface. + """Test that we can access the REST API endpoints via the HTML interface. History: Discovered on 2021-06-28 a bug in InvenTreeModelSerializer, which raised an AssertionError when using the HTML API interface, @@ -22,6 +21,7 @@ class HTMLAPITests(InvenTreeTestCase): roles = 'all' def test_part_api(self): + """Test that part list is working.""" url = reverse('api-part-list') # Check JSON response @@ -33,6 +33,7 @@ class HTMLAPITests(InvenTreeTestCase): self.assertEqual(response.status_code, 200) def test_build_api(self): + """Test that build list is working.""" url = reverse('api-build-list') # Check JSON response @@ -44,6 +45,7 @@ class HTMLAPITests(InvenTreeTestCase): self.assertEqual(response.status_code, 200) def test_stock_api(self): + """Test that stock list is working.""" url = reverse('api-stock-list') # Check JSON response @@ -55,6 +57,7 @@ class HTMLAPITests(InvenTreeTestCase): self.assertEqual(response.status_code, 200) def test_company_list(self): + """Test that company list is working.""" url = reverse('api-company-list') # Check JSON response @@ -66,14 +69,13 @@ class HTMLAPITests(InvenTreeTestCase): self.assertEqual(response.status_code, 200) def test_not_found(self): - """Test that the NotFoundView is working""" - + """Test that the NotFoundView is working.""" response = self.client.get('/api/anc') self.assertEqual(response.status_code, 404) class APITests(InvenTreeAPITestCase): - """ Tests for the InvenTree API """ + """Tests for the InvenTree API.""" fixtures = [ 'location', @@ -81,16 +83,11 @@ class APITests(InvenTreeAPITestCase): 'part', 'stock' ] - token = None - auto_login = False - def setUp(self): - - super().setUp() - def basicAuth(self): + """Helper function to use basic auth.""" # Use basic authentication authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii") @@ -100,7 +97,7 @@ class APITests(InvenTreeAPITestCase): self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth)) def tokenAuth(self): - + """Helper function to use token auth.""" self.basicAuth() token_url = reverse('api-token') response = self.client.get(token_url, format='json', data={}) @@ -112,6 +109,7 @@ class APITests(InvenTreeAPITestCase): self.token = token def test_token_failure(self): + """Test token resolve endpoint does not work without basic auth.""" # Test token endpoint without basic auth url = reverse('api-token') response = self.client.get(url, format='json') @@ -120,15 +118,12 @@ class APITests(InvenTreeAPITestCase): self.assertIsNone(self.token) def test_token_success(self): - + """Test token auth works.""" self.tokenAuth() self.assertIsNotNone(self.token) def test_info_view(self): - """ - Test that we can read the 'info-view' endpoint. - """ - + """Test that we can read the 'info-view' endpoint.""" url = reverse('api-inventree-info') response = self.client.get(url, format='json') @@ -141,12 +136,10 @@ class APITests(InvenTreeAPITestCase): self.assertEqual('InvenTree', data['server']) def test_role_view(self): - """ - Test that we can access the 'roles' view for the logged in user. + """Test that we can access the 'roles' view for the logged in user. Also tests that it is *not* accessible if the client is not logged in. """ - url = reverse('api-user-roles') response = self.client.get(url, format='json') @@ -182,10 +175,7 @@ class APITests(InvenTreeAPITestCase): self.assertNotIn('delete', roles[rule]) def test_with_superuser(self): - """ - Superuser should have *all* roles assigned - """ - + """Superuser should have *all* roles assigned.""" self.user.is_superuser = True self.user.save() @@ -202,10 +192,7 @@ class APITests(InvenTreeAPITestCase): self.assertIn(perm, roles[rule]) def test_with_roles(self): - """ - Assign some roles to the user - """ - + """Assign some roles to the user.""" self.basicAuth() response = self.get(reverse('api-user-roles')) @@ -220,10 +207,7 @@ class APITests(InvenTreeAPITestCase): self.assertIn('change', roles['build']) def test_list_endpoint_actions(self): - """ - Tests for the OPTIONS method for API endpoints. - """ - + """Tests for the OPTIONS method for API endpoints.""" self.basicAuth() # Without any 'part' permissions, we should not see any available actions @@ -252,10 +236,7 @@ class APITests(InvenTreeAPITestCase): self.assertIn('GET', actions) def test_detail_endpoint_actions(self): - """ - Tests for detail API endpoint actions - """ - + """Tests for detail API endpoint actions.""" self.basicAuth() url = reverse('api-part-detail', kwargs={'pk': 1}) diff --git a/InvenTree/InvenTree/test_middleware.py b/InvenTree/InvenTree/test_middleware.py index e9e5e4846f..96748b00f1 100644 --- a/InvenTree/InvenTree/test_middleware.py +++ b/InvenTree/InvenTree/test_middleware.py @@ -1,4 +1,4 @@ -"""Tests for middleware functions""" +"""Tests for middleware functions.""" from django.urls import reverse @@ -6,16 +6,16 @@ from InvenTree.helpers import InvenTreeTestCase class MiddlewareTests(InvenTreeTestCase): - """Test for middleware functions""" + """Test for middleware functions.""" def check_path(self, url, code=200, **kwargs): + """Helper function to run a request.""" response = self.client.get(url, HTTP_ACCEPT='application/json', **kwargs) self.assertEqual(response.status_code, code) return response def test_AuthRequiredMiddleware(self): - """Test the auth middleware""" - + """Test the auth middleware.""" # test that /api/ routes go through self.check_path(reverse('api-inventree-info')) @@ -40,7 +40,7 @@ class MiddlewareTests(InvenTreeTestCase): self.check_path(reverse('settings.js'), 401) def test_token_auth(self): - """Test auth with token auth""" + """Test auth with token auth.""" # get token response = self.client.get(reverse('api-token'), format='json', data={}) token = response.data['token'] diff --git a/InvenTree/InvenTree/test_tasks.py b/InvenTree/InvenTree/test_tasks.py index 1b04e7da5b..500d1d7c7f 100644 --- a/InvenTree/InvenTree/test_tasks.py +++ b/InvenTree/InvenTree/test_tasks.py @@ -1,6 +1,4 @@ -""" -Unit tests for task management -""" +"""Unit tests for task management.""" from datetime import timedelta @@ -18,19 +16,14 @@ threshold_low = threshold - timedelta(days=1) class ScheduledTaskTests(TestCase): - """ - Unit tests for scheduled tasks - """ + """Unit tests for scheduled tasks.""" def get_tasks(self, name): - + """Helper function to get a Schedule object.""" return Schedule.objects.filter(func=name) def test_add_task(self): - """ - Ensure that duplicate tasks cannot be added. - """ - + """Ensure that duplicate tasks cannot be added.""" task = 'InvenTree.tasks.heartbeat' self.assertEqual(self.get_tasks(task).count(), 0) @@ -53,16 +46,15 @@ class ScheduledTaskTests(TestCase): def get_result(): - """Demo function for test_offloading""" + """Demo function for test_offloading.""" return 'abc' class InvenTreeTaskTests(TestCase): - """Unit tests for tasks""" + """Unit tests for tasks.""" def test_offloading(self): - """Test task offloading""" - + """Test task offloading.""" # Run with function ref InvenTree.tasks.offload_task(get_result) @@ -83,11 +75,11 @@ class InvenTreeTaskTests(TestCase): InvenTree.tasks.offload_task('InvenTree.test_tasks.doesnotexsist') def test_task_hearbeat(self): - """Test the task heartbeat""" + """Test the task heartbeat.""" InvenTree.tasks.offload_task(InvenTree.tasks.heartbeat) def test_task_delete_successful_tasks(self): - """Test the task delete_successful_tasks""" + """Test the task delete_successful_tasks.""" from django_q.models import Success Success.objects.create(name='abc', func='abc', stopped=threshold, started=threshold_low) @@ -96,8 +88,7 @@ class InvenTreeTaskTests(TestCase): self.assertEqual(len(results), 0) def test_task_delete_old_error_logs(self): - """Test the task delete_old_error_logs""" - + """Test the task delete_old_error_logs.""" # Create error error_obj = Error.objects.create() error_obj.when = threshold_low @@ -115,7 +106,7 @@ class InvenTreeTaskTests(TestCase): self.assertEqual(len(errors), 0) def test_task_check_for_updates(self): - """Test the task check_for_updates""" + """Test the task check_for_updates.""" # Check that setting should be empty self.assertEqual(InvenTreeSetting.get_setting('INVENTREE_LATEST_VERSION'), '') diff --git a/InvenTree/InvenTree/test_urls.py b/InvenTree/InvenTree/test_urls.py index 7f41f6e9fe..8586de0e69 100644 --- a/InvenTree/InvenTree/test_urls.py +++ b/InvenTree/InvenTree/test_urls.py @@ -1,6 +1,4 @@ -""" -Validate that all URLs specified in template files are correct. -""" +"""Validate that all URLs specified in template files are correct.""" import os import re @@ -11,6 +9,7 @@ from django.urls import reverse class URLTest(TestCase): + """Test all files for broken url tags.""" # Need fixture data in the database fixtures = [ @@ -35,11 +34,7 @@ class URLTest(TestCase): ] def find_files(self, suffix): - """ - Search for all files in the template directories, - which can have URLs rendered - """ - + """Search for all files in the template directories, which can have URLs rendered.""" template_dirs = [ ('build', 'templates'), ('common', 'templates'), @@ -71,10 +66,7 @@ class URLTest(TestCase): return template_files def find_urls(self, input_file): - """ - Search for all instances of {% url %} in supplied template file - """ - + """Search for all instances of {% url %} in supplied template file.""" urls = [] pattern = "{% url ['\"]([^'\"]+)['\"]([^%]*)%}" @@ -100,10 +92,7 @@ class URLTest(TestCase): return urls def reverse_url(self, url_pair): - """ - Perform lookup on the URL - """ - + """Perform lookup on the URL.""" url, pk = url_pair # Ignore "renaming" @@ -125,24 +114,21 @@ class URLTest(TestCase): reverse(url) def check_file(self, f): - """ - Run URL checks for the provided file - """ - + """Run URL checks for the provided file.""" urls = self.find_urls(f) for url in urls: self.reverse_url(url) def test_html_templates(self): - + """Test all HTML templates for broken url tags.""" template_files = self.find_files("*.html") for f in template_files: self.check_file(f) def test_js_templates(self): - + """Test all JS templates for broken url tags.""" template_files = self.find_files("*.js") for f in template_files: diff --git a/InvenTree/InvenTree/test_views.py b/InvenTree/InvenTree/test_views.py index 425de5e46c..c4d4cfd513 100644 --- a/InvenTree/InvenTree/test_views.py +++ b/InvenTree/InvenTree/test_views.py @@ -1,6 +1,4 @@ -""" -Unit tests for the main web views -""" +"""Unit tests for the main web views.""" import os import re @@ -11,33 +9,26 @@ from InvenTree.helpers import InvenTreeTestCase class ViewTests(InvenTreeTestCase): - """ Tests for various top-level views """ + """Tests for various top-level views.""" username = 'test_user' password = 'test_pass' def test_api_doc(self): - """ Test that the api-doc view works """ - + """Test that the api-doc view works.""" api_url = os.path.join(reverse('index'), 'api-doc') + '/' response = self.client.get(api_url) self.assertEqual(response.status_code, 200) def test_index_redirect(self): - """ - top-level URL should redirect to "index" page - """ - + """Top-level URL should redirect to "index" page.""" response = self.client.get("/") self.assertEqual(response.status_code, 302) def get_index_page(self): - """ - Retrieve the index page (used for subsequent unit tests) - """ - + """Retrieve the index page (used for subsequent unit tests)""" response = self.client.get("/index/") self.assertEqual(response.status_code, 200) @@ -45,10 +36,7 @@ class ViewTests(InvenTreeTestCase): return str(response.content.decode()) def test_panels(self): - """ - Test that the required 'panels' are present - """ - + """Test that the required 'panels' are present.""" content = self.get_index_page() self.assertIn("
", content) @@ -56,10 +44,7 @@ class ViewTests(InvenTreeTestCase): # TODO: In future, run the javascript and ensure that the panels get created! def test_js_load(self): - """ - Test that the required javascript files are loaded correctly - """ - + """Test that the required javascript files are loaded correctly.""" # Change this number as more javascript files are added to the index page N_SCRIPT_FILES = 40 diff --git a/InvenTree/InvenTree/tests.py b/InvenTree/InvenTree/tests.py index 3de293ca66..c31be5bfce 100644 --- a/InvenTree/InvenTree/tests.py +++ b/InvenTree/InvenTree/tests.py @@ -1,3 +1,5 @@ +"""Test general functions and helpers.""" + import json import os import time @@ -24,20 +26,17 @@ from .validators import validate_overage, validate_part_name class ValidatorTest(TestCase): - - """ Simple tests for custom field validators """ + """Simple tests for custom field validators.""" def test_part_name(self): - """ Test part name validator """ - + """Test part name validator.""" validate_part_name('hello world') with self.assertRaises(django_exceptions.ValidationError): validate_part_name('This | name is not } valid') def test_overage(self): - """ Test overage validator """ - + """Test overage validator.""" validate_overage("100%") validate_overage("10") validate_overage("45.2 %") @@ -59,11 +58,10 @@ class ValidatorTest(TestCase): class TestHelpers(TestCase): - """ Tests for InvenTree helper functions """ + """Tests for InvenTree helper functions.""" def test_image_url(self): - """ Test if a filename looks like an image """ - + """Test if a filename looks like an image.""" for name in ['ape.png', 'bat.GiF', 'apple.WeBP', 'BiTMap.Bmp']: self.assertTrue(helpers.TestIfImageURL(name)) @@ -71,8 +69,7 @@ class TestHelpers(TestCase): self.assertFalse(helpers.TestIfImageURL(name)) def test_str2bool(self): - """ Test string to boolean conversion """ - + """Test string to boolean conversion.""" for s in ['yes', 'Y', 'ok', '1', 'OK', 'Ok', 'tRuE', 'oN']: self.assertTrue(helpers.str2bool(s)) self.assertFalse(helpers.str2bool(s, test=False)) @@ -86,7 +83,7 @@ class TestHelpers(TestCase): self.assertFalse(helpers.str2bool(s, test=False)) def test_isnull(self): - + """Test isNull.""" for s in ['null', 'none', '', '-1', 'false']: self.assertTrue(helpers.isNull(s)) @@ -94,35 +91,35 @@ class TestHelpers(TestCase): self.assertFalse(helpers.isNull(s)) def testStaticUrl(self): - + """Test static url helpers.""" self.assertEqual(helpers.getStaticUrl('test.jpg'), '/static/test.jpg') self.assertEqual(helpers.getBlankImage(), '/static/img/blank_image.png') self.assertEqual(helpers.getBlankThumbnail(), '/static/img/blank_image.thumbnail.png') def testMediaUrl(self): - + """Test getMediaUrl.""" self.assertEqual(helpers.getMediaUrl('xx/yy.png'), '/media/xx/yy.png') def testDecimal2String(self): - + """Test decimal2string.""" self.assertEqual(helpers.decimal2string(Decimal('1.2345000')), '1.2345') self.assertEqual(helpers.decimal2string('test'), 'test') class TestQuoteWrap(TestCase): - """ Tests for string wrapping """ + """Tests for string wrapping.""" def test_single(self): - + """Test WrapWithQuotes.""" self.assertEqual(helpers.WrapWithQuotes('hello'), '"hello"') self.assertEqual(helpers.WrapWithQuotes('hello"'), '"hello"') class TestIncrement(TestCase): + """Tests for increment function.""" def tests(self): - """ Test 'intelligent' incrementing function """ - + """Test 'intelligent' incrementing function.""" tests = [ ("", ""), (1, "2"), @@ -142,10 +139,10 @@ class TestIncrement(TestCase): class TestMakeBarcode(TestCase): - """ Tests for barcode string creation """ + """Tests for barcode string creation.""" def test_barcode_extended(self): - + """Test creation of barcode with extended data.""" bc = helpers.MakeBarcode( "part", 3, @@ -166,7 +163,7 @@ class TestMakeBarcode(TestCase): self.assertEqual(data['part']['url'], 'www.google.com') def test_barcode_brief(self): - + """Test creation of simple barcode.""" bc = helpers.MakeBarcode( "stockitem", 7, @@ -178,27 +175,29 @@ class TestMakeBarcode(TestCase): class TestDownloadFile(TestCase): + """Tests for DownloadFile.""" def test_download(self): + """Tests for DownloadFile.""" helpers.DownloadFile("hello world", "out.txt") helpers.DownloadFile(bytes(b"hello world"), "out.bin") class TestMPTT(TestCase): - """ Tests for the MPTT tree models """ + """Tests for the MPTT tree models.""" fixtures = [ 'location', ] def setUp(self): + """Setup for all tests.""" super().setUp() StockLocation.objects.rebuild() def test_self_as_parent(self): - """ Test that we cannot set self as parent """ - + """Test that we cannot set self as parent.""" loc = StockLocation.objects.get(pk=4) loc.parent = loc @@ -206,8 +205,7 @@ class TestMPTT(TestCase): loc.save() def test_child_as_parent(self): - """ Test that we cannot set a child as parent """ - + """Test that we cannot set a child as parent.""" parent = StockLocation.objects.get(pk=4) child = StockLocation.objects.get(pk=5) @@ -217,8 +215,7 @@ class TestMPTT(TestCase): parent.save() def test_move(self): - """ Move an item to a different tree """ - + """Move an item to a different tree.""" drawer = StockLocation.objects.get(name='Drawer_1') # Record the tree ID @@ -233,10 +230,10 @@ class TestMPTT(TestCase): class TestSerialNumberExtraction(TestCase): - """ Tests for serial number extraction code """ + """Tests for serial number extraction code.""" def test_simple(self): - + """Test simple serial numbers.""" e = helpers.extract_serial_numbers sn = e("1-5", 5, 1) @@ -301,7 +298,7 @@ class TestSerialNumberExtraction(TestCase): self.assertEqual(sn, [5, 6, 7, 8]) def test_failures(self): - + """Test wron serial numbers.""" e = helpers.extract_serial_numbers # Test duplicates @@ -332,6 +329,7 @@ class TestSerialNumberExtraction(TestCase): e("1, 2, 3, E-5", 5, 1) def test_combinations(self): + """Test complex serial number combinations.""" e = helpers.extract_serial_numbers sn = e("1 3-5 9+2", 7, 1) @@ -352,12 +350,10 @@ class TestSerialNumberExtraction(TestCase): class TestVersionNumber(TestCase): - """ - Unit tests for version number functions - """ + """Unit tests for version number functions.""" def test_tuple(self): - + """Test inventreeVersionTuple.""" v = version.inventreeVersionTuple() self.assertEqual(len(v), 3) @@ -366,10 +362,7 @@ class TestVersionNumber(TestCase): self.assertTrue(s in version.inventreeVersion()) def test_comparison(self): - """ - Test direct comparison of version numbers - """ - + """Test direct comparison of version numbers.""" v_a = version.inventreeVersionTuple('1.2.0') v_b = version.inventreeVersionTuple('1.2.3') v_c = version.inventreeVersionTuple('1.2.4') @@ -381,8 +374,7 @@ class TestVersionNumber(TestCase): self.assertTrue(v_d > v_a) def test_commit_info(self): - """Test that the git commit information is extracted successfully""" - + """Test that the git commit information is extracted successfully.""" envs = { 'INVENTREE_COMMIT_HASH': 'abcdef', 'INVENTREE_COMMIT_DATE': '2022-12-31' @@ -406,12 +398,10 @@ class TestVersionNumber(TestCase): class CurrencyTests(TestCase): - """ - Unit tests for currency / exchange rate functionality - """ + """Unit tests for currency / exchange rate functionality.""" def test_rates(self): - + """Test exchange rate update.""" # Initially, there will not be any exchange rate information rates = Rate.objects.all() @@ -459,33 +449,32 @@ class CurrencyTests(TestCase): class TestStatus(TestCase): - """ - Unit tests for status functions - """ + """Unit tests for status functions.""" def test_check_system_healt(self): - """test that the system health check is false in testing -> background worker not running""" + """Test that the system health check is false in testing -> background worker not running.""" self.assertEqual(status.check_system_health(), False) def test_TestMode(self): + """Test isInTestMode check.""" self.assertTrue(ready.isInTestMode()) def test_Importing(self): + """Test isImportingData check.""" self.assertEqual(ready.isImportingData(), False) class TestSettings(helpers.InvenTreeTestCase): - """ - Unit tests for settings - """ + """Unit tests for settings.""" superuser = True def in_env_context(self, envs={}): - """Patch the env to include the given dict""" + """Patch the env to include the given dict.""" return mock.patch.dict(os.environ, envs) def run_reload(self, envs={}): + """Helper function to reload InvenTree.""" from plugin import registry with self.in_env_context(envs): @@ -494,6 +483,7 @@ class TestSettings(helpers.InvenTreeTestCase): @override_settings(TESTING_ENV=True) def test_set_user_to_few(self): + """Test adding an admin user via env variables.""" user_model = get_user_model() # add shortcut user_count = user_model.objects.count @@ -537,7 +527,7 @@ class TestSettings(helpers.InvenTreeTestCase): settings.TESTING_ENV = False def test_initial_install(self): - """Test if install of plugins on startup works""" + """Test if install of plugins on startup works.""" from plugin import registry # Check an install run @@ -553,6 +543,7 @@ class TestSettings(helpers.InvenTreeTestCase): self.assertEqual(response, True) def test_helpers_cfg_file(self): + """Test get_config_file.""" # normal run - not configured valid = [ @@ -567,6 +558,7 @@ class TestSettings(helpers.InvenTreeTestCase): self.assertIn('inventree/my_special_conf.yaml', config.get_config_file().lower()) def test_helpers_plugin_file(self): + """Test get_plugin_file.""" # normal run - not configured valid = [ @@ -581,6 +573,7 @@ class TestSettings(helpers.InvenTreeTestCase): self.assertIn('my_special_plugins.txt', config.get_plugin_file()) def test_helpers_setting(self): + """Test get_setting.""" TEST_ENV_NAME = '123TEST' # check that default gets returned if not present self.assertEqual(config.get_setting(TEST_ENV_NAME, None, '123!'), '123!') @@ -591,12 +584,10 @@ class TestSettings(helpers.InvenTreeTestCase): class TestInstanceName(helpers.InvenTreeTestCase): - """ - Unit tests for instance name - """ + """Unit tests for instance name.""" def test_instance_name(self): - + """Test instance name settings.""" # default setting self.assertEqual(version.inventreeInstanceTitle(), 'InvenTree') diff --git a/InvenTree/InvenTree/urls.py b/InvenTree/InvenTree/urls.py index c3f5b87169..e1ac6ed1f3 100644 --- a/InvenTree/InvenTree/urls.py +++ b/InvenTree/InvenTree/urls.py @@ -1,5 +1,4 @@ -""" -Top-level URL lookup for InvenTree application. +"""Top-level URL lookup for InvenTree application. Passes URL lookup downstream to each app as required. """ diff --git a/InvenTree/InvenTree/validators.py b/InvenTree/InvenTree/validators.py index 8a23eae39b..bd4c6d1485 100644 --- a/InvenTree/InvenTree/validators.py +++ b/InvenTree/InvenTree/validators.py @@ -1,6 +1,4 @@ -""" -Custom field validators for InvenTree -""" +"""Custom field validators for InvenTree.""" import re from decimal import Decimal, InvalidOperation @@ -15,20 +13,18 @@ import common.models def validate_currency_code(code): - """ - Check that a given code is a valid currency code. - """ - + """Check that a given code is a valid currency code.""" if code not in CURRENCIES: raise ValidationError(_('Not a valid currency code')) def allowable_url_schemes(): - """ Return the list of allowable URL schemes. + """Return the list of allowable URL schemes. + In addition to the default schemes allowed by Django, the install configuration file (config.yaml) can specify - extra schemas """ - + extra schemas + """ # Default schemes schemes = ['http', 'https', 'ftp', 'ftps'] @@ -42,9 +38,7 @@ def allowable_url_schemes(): def validate_part_name(value): - """ Prevent some illegal characters in part names. - """ - + """Prevent some illegal characters in part names.""" for c in ['|', '#', '$', '{', '}']: if c in str(value): raise ValidationError( @@ -53,8 +47,7 @@ def validate_part_name(value): def validate_part_ipn(value): - """ Validate the Part IPN against regex rule """ - + """Validate the Part IPN against regex rule.""" pattern = common.models.InvenTreeSetting.get_setting('PART_IPN_REGEX') if pattern: @@ -65,10 +58,7 @@ def validate_part_ipn(value): def validate_build_order_reference(value): - """ - Validate the 'reference' field of a BuildOrder - """ - + """Validate the 'reference' field of a BuildOrder.""" pattern = common.models.InvenTreeSetting.get_setting('BUILDORDER_REFERENCE_REGEX') if pattern: @@ -79,10 +69,7 @@ def validate_build_order_reference(value): def validate_purchase_order_reference(value): - """ - Validate the 'reference' field of a PurchaseOrder - """ - + """Validate the 'reference' field of a PurchaseOrder.""" pattern = common.models.InvenTreeSetting.get_setting('PURCHASEORDER_REFERENCE_REGEX') if pattern: @@ -93,10 +80,7 @@ def validate_purchase_order_reference(value): def validate_sales_order_reference(value): - """ - Validate the 'reference' field of a SalesOrder - """ - + """Validate the 'reference' field of a SalesOrder.""" pattern = common.models.InvenTreeSetting.get_setting('SALESORDER_REFERENCE_REGEX') if pattern: @@ -107,16 +91,14 @@ def validate_sales_order_reference(value): def validate_tree_name(value): - """ Prevent illegal characters in tree item names """ - + """Prevent illegal characters in tree item names.""" for c in "!@#$%^&*'\"\\/[]{}<>,|+=~`\"": if c in str(value): raise ValidationError(_('Illegal character in name ({x})'.format(x=c))) def validate_overage(value): - """ - Validate that a BOM overage string is properly formatted. + """Validate that a BOM overage string is properly formatted. An overage string can look like: @@ -124,7 +106,6 @@ def validate_overage(value): - A decimal number ('0.123') - A percentage ('5%' / '10 %') """ - value = str(value).lower().strip() # First look for a simple numerical value @@ -162,11 +143,10 @@ def validate_overage(value): def validate_part_name_format(self): - """ - Validate part name format. + """Validate part name format. + Make sure that each template container has a field of Part Model """ - jinja_template_regex = re.compile('{{.*?}}') field_name_regex = re.compile('(?<=part\\.)[A-z]+') for jinja_template in jinja_template_regex.findall(str(self)): diff --git a/InvenTree/InvenTree/version.py b/InvenTree/InvenTree/version.py index fe970ee5d4..3bcb3293bb 100644 --- a/InvenTree/InvenTree/version.py +++ b/InvenTree/InvenTree/version.py @@ -1,5 +1,5 @@ -""" -Version information for InvenTree. +"""Version information for InvenTree. + Provides information on the current InvenTree version """ @@ -17,12 +17,12 @@ INVENTREE_SW_VERSION = "0.8.0 dev" def inventreeInstanceName(): - """ Returns the InstanceName settings for the current database """ + """Returns the InstanceName settings for the current database.""" return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "") def inventreeInstanceTitle(): - """ Returns the InstanceTitle for the current database """ + """Returns the InstanceTitle for the current database.""" if common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE_TITLE", False): return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "") else: @@ -30,13 +30,12 @@ def inventreeInstanceTitle(): def inventreeVersion(): - """ Returns the InvenTree version string """ + """Returns the InvenTree version string.""" return INVENTREE_SW_VERSION.lower().strip() def inventreeVersionTuple(version=None): - """ Return the InvenTree version string as (maj, min, sub) tuple """ - + """Return the InvenTree version string as (maj, min, sub) tuple.""" if version is None: version = INVENTREE_SW_VERSION @@ -46,21 +45,16 @@ def inventreeVersionTuple(version=None): def isInvenTreeDevelopmentVersion(): - """ - Return True if current InvenTree version is a "development" version - """ + """Return True if current InvenTree version is a "development" version.""" return inventreeVersion().endswith('dev') def inventreeDocsVersion(): - """ - Return the version string matching the latest documentation. + """Return the version string matching the latest documentation. Development -> "latest" Release -> "major.minor.sub" e.g. "0.5.2" - """ - if isInvenTreeDevelopmentVersion(): return "latest" else: @@ -68,13 +62,10 @@ def inventreeDocsVersion(): def isInvenTreeUpToDate(): - """ - Test if the InvenTree instance is "up to date" with the latest version. + """Test if the InvenTree instance is "up to date" with the latest version. - A background task periodically queries GitHub for latest version, - and stores it to the database as INVENTREE_LATEST_VERSION + A background task periodically queries GitHub for latest version, and stores it to the database as INVENTREE_LATEST_VERSION """ - latest = common.models.InvenTreeSetting.get_setting('INVENTREE_LATEST_VERSION', backup_value=None, create=False) # No record for "latest" version - we must assume we are up to date! @@ -89,17 +80,17 @@ def isInvenTreeUpToDate(): def inventreeApiVersion(): + """Returns current API version of InvenTree.""" return INVENTREE_API_VERSION def inventreeDjangoVersion(): - """ Return the version of Django library """ + """Returns the version of Django library.""" return django.get_version() def inventreeCommitHash(): - """ Returns the git commit hash for the running codebase """ - + """Returns the git commit hash for the running codebase.""" # First look in the environment variables, i.e. if running in docker commit_hash = os.environ.get('INVENTREE_COMMIT_HASH', '') @@ -113,8 +104,7 @@ def inventreeCommitHash(): def inventreeCommitDate(): - """ Returns the git commit date for the running codebase """ - + """Returns the git commit date for the running codebase.""" # First look in the environment variables, e.g. if running in docker commit_date = os.environ.get('INVENTREE_COMMIT_DATE', '') diff --git a/InvenTree/InvenTree/views.py b/InvenTree/InvenTree/views.py index b29ea7cd44..c2463f0b5e 100644 --- a/InvenTree/InvenTree/views.py +++ b/InvenTree/InvenTree/views.py @@ -1,5 +1,4 @@ -""" -Various Views which provide extra functionality over base Django Views. +"""Various Views which provide extra functionality over base Django Views. In particular these views provide base functionality for rendering Django forms as JSON objects and passing them to modal forms (using jQuery / bootstrap). @@ -41,12 +40,10 @@ from .helpers import str2bool def auth_request(request): - """ - Simple 'auth' endpoint used to determine if the user is authenticated. - Useful for (for example) redirecting authentication requests through - django's permission framework. - """ + """Simple 'auth' endpoint used to determine if the user is authenticated. + Useful for (for example) redirecting authentication requests through django's permission framework. + """ if request.user.is_authenticated: return HttpResponse(status=200) else: @@ -54,8 +51,7 @@ def auth_request(request): class InvenTreeRoleMixin(PermissionRequiredMixin): - """ - Permission class based on user roles, not user 'permissions'. + """Permission class based on user roles, not user 'permissions'. There are a number of ways that the permissions can be specified for a view: @@ -97,10 +93,7 @@ class InvenTreeRoleMixin(PermissionRequiredMixin): role_required = None def has_permission(self): - """ - Determine if the current user has specified permissions - """ - + """Determine if the current user has specified permissions.""" roles_required = [] if type(self.role_required) is str: @@ -163,8 +156,7 @@ class InvenTreeRoleMixin(PermissionRequiredMixin): return True def get_permission_class(self): - """ - Return the 'permission_class' required for the current View. + """Return the 'permission_class' required for the current View. Must be one of: @@ -177,7 +169,6 @@ class InvenTreeRoleMixin(PermissionRequiredMixin): 'permission_class' attribute, or it can be "guessed" by looking at the type of class """ - perm = getattr(self, 'permission_class', None) # Permission is specified by the class itself @@ -204,13 +195,10 @@ class InvenTreeRoleMixin(PermissionRequiredMixin): class AjaxMixin(InvenTreeRoleMixin): - """ AjaxMixin provides basic functionality for rendering a Django form to JSON. - Handles jsonResponse rendering, and adds extra data for the modal forms to process - on the client side. + """AjaxMixin provides basic functionality for rendering a Django form to JSON. Handles jsonResponse rendering, and adds extra data for the modal forms to process on the client side. Any view which inherits the AjaxMixin will need correct permissions set using the 'role_required' attribute - """ # By default, allow *any* role @@ -223,11 +211,11 @@ class AjaxMixin(InvenTreeRoleMixin): ajax_form_title = '' def get_form_title(self): - """ Default implementation - return the ajax_form_title variable """ + """Default implementation - return the ajax_form_title variable.""" return self.ajax_form_title def get_param(self, name, method='GET'): - """ Get a request query parameter value from URL e.g. ?part=3 + """Get a request query parameter value from URL e.g. ?part=3. Args: name: Variable name e.g. 'part' @@ -236,14 +224,13 @@ class AjaxMixin(InvenTreeRoleMixin): Returns: Value of the supplier parameter or None if parameter is not available """ - if method == 'POST': return self.request.POST.get(name, None) else: return self.request.GET.get(name, None) def get_data(self): - """ Get extra context data (default implementation is empty dict) + """Get extra context data (default implementation is empty dict). Returns: dict object (empty) @@ -251,20 +238,18 @@ class AjaxMixin(InvenTreeRoleMixin): return {} def validate(self, obj, form, **kwargs): - """ - Hook for performing custom form validation steps. + """Hook for performing custom form validation steps. If a form error is detected, add it to the form, with 'form.add_error()' Ref: https://docs.djangoproject.com/en/dev/topics/forms/ """ - # Do nothing by default pass def renderJsonResponse(self, request, form=None, data=None, context=None): - """ Render a JSON response based on specific class context. + """Render a JSON response based on specific class context. Args: request: HTTP request object (e.g. GET / POST) @@ -318,19 +303,25 @@ class AjaxMixin(InvenTreeRoleMixin): class AjaxView(AjaxMixin, View): - """ An 'AJAXified' View for displaying an object - """ + """An 'AJAXified' View for displaying an object.""" def post(self, request, *args, **kwargs): + """Return a json formatted response. + + This renderJsonResponse function must be supplied by your function. + """ return self.renderJsonResponse(request) def get(self, request, *args, **kwargs): + """Return a json formatted response. + This renderJsonResponse function must be supplied by your function. + """ return self.renderJsonResponse(request) class QRCodeView(AjaxView): - """ An 'AJAXified' view for displaying a QR code. + """An 'AJAXified' view for displaying a QR code. Subclasses should implement the get_qr_data(self) function. """ @@ -338,22 +329,23 @@ class QRCodeView(AjaxView): ajax_template_name = "qr_code.html" def get(self, request, *args, **kwargs): + """Return json with qr-code data.""" self.request = request self.pk = self.kwargs['pk'] return self.renderJsonResponse(request, None, context=self.get_context_data()) def get_qr_data(self): - """ Returns the text object to render to a QR code. - The actual rendering will be handled by the template """ + """Returns the text object to render to a QR code. + The actual rendering will be handled by the template + """ return None def get_context_data(self): - """ Get context data for passing to the rendering template. + """Get context data for passing to the rendering template. Explicity passes the parameter 'qr_data' """ - context = {} qr = self.get_qr_data() @@ -367,15 +359,14 @@ class QRCodeView(AjaxView): class AjaxCreateView(AjaxMixin, CreateView): + """An 'AJAXified' CreateView for creating a new object in the db. - """ An 'AJAXified' CreateView for creating a new object in the db - Returns a form in JSON format (for delivery to a modal window) - Handles form validation via AJAX POST requests """ def get(self, request, *args, **kwargs): - """ Creates form with initial data, and renders JSON response """ - + """Creates form with initial data, and renders JSON response.""" super(CreateView, self).get(request, *args, **kwargs) self.request = request @@ -383,18 +374,16 @@ class AjaxCreateView(AjaxMixin, CreateView): return self.renderJsonResponse(request, form) def save(self, form): - """ - Method for actually saving the form to the database. - Default implementation is very simple, - but can be overridden if required. - """ + """Method for actually saving the form to the database. + Default implementation is very simple, but can be overridden if required. + """ self.object = form.save() return self.object def post(self, request, *args, **kwargs): - """ Responds to form POST. Validates POST data and returns status info. + """Responds to form POST. Validates POST data and returns status info. - Validate POST form data - If valid, save form @@ -441,45 +430,44 @@ class AjaxCreateView(AjaxMixin, CreateView): class AjaxUpdateView(AjaxMixin, UpdateView): - """ An 'AJAXified' UpdateView for updating an object in the db + """An 'AJAXified' UpdateView for updating an object in the db. + - Returns form in JSON format (for delivery to a modal window) - Handles repeated form validation (via AJAX) until the form is valid """ def get(self, request, *args, **kwargs): - """ Respond to GET request. + """Respond to GET request. - Populates form with object data - Renders form to JSON and returns to client """ - super(UpdateView, self).get(request, *args, **kwargs) return self.renderJsonResponse(request, self.get_form(), context=self.get_context_data()) def save(self, object, form, **kwargs): - """ - Method for updating the object in the database. - Default implementation is very simple, but can be overridden if required. + """Method for updating the object in the database. Default implementation is very simple, but can be overridden if required. Args: object - The current object, to be updated form - The validated form - """ + Returns: + object instance for supplied form + """ self.object = form.save() return self.object def post(self, request, *args, **kwargs): - """ Respond to POST request. + """Respond to POST request. - Updates model with POST field data - Performs form and object validation - If errors exist, re-render the form - Otherwise, return sucess status """ - self.request = request # Make sure we have an object to point to @@ -524,8 +512,8 @@ class AjaxUpdateView(AjaxMixin, UpdateView): class AjaxDeleteView(AjaxMixin, UpdateView): + """An 'AJAXified DeleteView for removing an object from the DB. - """ An 'AJAXified DeleteView for removing an object from the DB - Returns a HTML object (not a form!) in JSON format (for delivery to a modal window) - Handles deletion """ @@ -536,6 +524,7 @@ class AjaxDeleteView(AjaxMixin, UpdateView): context_object_name = 'item' def get_object(self): + """Return object matched to the model of the calling class.""" try: self.object = self.model.objects.get(pk=self.kwargs['pk']) except: @@ -543,15 +532,15 @@ class AjaxDeleteView(AjaxMixin, UpdateView): return self.object def get_form(self): + """Returns a form instance for the form_class of the calling class.""" return self.form_class(self.get_form_kwargs()) def get(self, request, *args, **kwargs): - """ Respond to GET request + """Respond to GET request. - Render a DELETE confirmation form to JSON - Return rendered form to client """ - super(UpdateView, self).get(request, *args, **kwargs) form = self.get_form() @@ -563,12 +552,11 @@ class AjaxDeleteView(AjaxMixin, UpdateView): return self.renderJsonResponse(request, form, context=context) def post(self, request, *args, **kwargs): - """ Respond to POST request + """Respond to POST request. - DELETE the object - Render success message to JSON and return to client """ - obj = self.get_object() pk = obj.id @@ -592,28 +580,30 @@ class AjaxDeleteView(AjaxMixin, UpdateView): class EditUserView(AjaxUpdateView): - """ View for editing user information """ + """View for editing user information.""" ajax_template_name = "modal_form.html" ajax_form_title = _("Edit User Information") form_class = EditUserForm def get_object(self): + """Set form to edit current user.""" return self.request.user class SetPasswordView(AjaxUpdateView): - """ View for setting user password """ + """View for setting user password.""" ajax_template_name = "InvenTree/password.html" ajax_form_title = _("Set Password") form_class = SetPasswordForm def get_object(self): + """Set form to edit current user.""" return self.request.user def post(self, request, *args, **kwargs): - + """Validate inputs and change password.""" form = self.get_form() valid = form.is_valid() @@ -645,19 +635,13 @@ class SetPasswordView(AjaxUpdateView): class IndexView(TemplateView): - """ View for InvenTree index page """ + """View for InvenTree index page.""" template_name = 'InvenTree/index.html' - def get_context_data(self, **kwargs): - - context = super().get_context_data(**kwargs) - - return context - class SearchView(TemplateView): - """ View for InvenTree search page. + """View for InvenTree search page. Displays results of search query """ @@ -665,11 +649,10 @@ class SearchView(TemplateView): template_name = 'InvenTree/search.html' def post(self, request, *args, **kwargs): - """ Handle POST request (which contains search query). + """Handle POST request (which contains search query). Pass the search query to the page template """ - context = self.get_context_data() query = request.POST.get('search', '') @@ -680,24 +663,19 @@ class SearchView(TemplateView): class DynamicJsView(TemplateView): - """ - View for returning javacsript files, - which instead of being served dynamically, - are passed through the django translation engine! - """ + """View for returning javacsript files, which instead of being served dynamically, are passed through the django translation engine!""" template_name = "" content_type = 'text/javascript' class SettingsView(TemplateView): - """ View for configuring User settings - """ + """View for configuring User settings.""" template_name = "InvenTree/settings/settings.html" def get_context_data(self, **kwargs): - + """Add data for template.""" ctx = super().get_context_data(**kwargs).copy() ctx['settings'] = InvenTreeSetting.objects.all().order_by('key') @@ -739,61 +717,53 @@ class SettingsView(TemplateView): class AllauthOverrides(LoginRequiredMixin): - """ - Override allauths views to always redirect to success_url - """ + """Override allauths views to always redirect to success_url.""" + def get(self, request, *args, **kwargs): - # always redirect to settings + """Always redirect to success_url (set to settings).""" return HttpResponseRedirect(self.success_url) class CustomEmailView(AllauthOverrides, EmailView): - """ - Override of allauths EmailView to always show the settings but leave the functions allow - """ + """Override of allauths EmailView to always show the settings but leave the functions allow.""" success_url = reverse_lazy("settings") class CustomConnectionsView(AllauthOverrides, ConnectionsView): - """ - Override of allauths ConnectionsView to always show the settings but leave the functions allow - """ + """Override of allauths ConnectionsView to always show the settings but leave the functions allow.""" success_url = reverse_lazy("settings") class CustomPasswordResetFromKeyView(PasswordResetFromKeyView): - """ - Override of allauths PasswordResetFromKeyView to always show the settings but leave the functions allow - """ + """Override of allauths PasswordResetFromKeyView to always show the settings but leave the functions allow.""" success_url = reverse_lazy("account_login") class UserSessionOverride(): - """overrides sucessurl to lead to settings""" + """Overrides sucessurl to lead to settings.""" + def get_success_url(self): + """Revert to settings page after success.""" return str(reverse_lazy('settings')) class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView): + """Revert to settings after session delete.""" pass class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView): + """Revert to settings after session delete.""" pass class CurrencyRefreshView(RedirectView): - """ - POST endpoint to refresh / update exchange rates - """ + """POST endpoint to refresh / update exchange rates.""" url = reverse_lazy("settings-currencies") def post(self, request, *args, **kwargs): - """ - On a POST request we will attempt to refresh the exchange rates - """ - + """On a POST request we will attempt to refresh the exchange rates.""" from InvenTree.tasks import offload_task, update_exchange_rates offload_task(update_exchange_rates, force_sync=True) @@ -802,10 +772,10 @@ class CurrencyRefreshView(RedirectView): class AppearanceSelectView(RedirectView): - """ View for selecting a color theme """ + """View for selecting a color theme.""" def get_user_theme(self): - """ Get current user color theme """ + """Get current user color theme.""" try: user_theme = ColorTheme.objects.filter(user=self.request.user).get() except ColorTheme.DoesNotExist: @@ -814,8 +784,7 @@ class AppearanceSelectView(RedirectView): return user_theme def post(self, request, *args, **kwargs): - """ Save user color theme selection """ - + """Save user color theme selection.""" theme = request.POST.get('theme', None) # Get current user theme @@ -833,15 +802,14 @@ class AppearanceSelectView(RedirectView): class SettingCategorySelectView(FormView): - """ View for selecting categories in settings """ + """View for selecting categories in settings.""" form_class = SettingCategorySelectForm success_url = reverse_lazy('settings-category') template_name = "InvenTree/settings/category.html" def get_initial(self): - """ Set category selection """ - + """Set category selection.""" initial = super().get_initial() category = self.request.GET.get('category', None) @@ -851,11 +819,10 @@ class SettingCategorySelectView(FormView): return initial def post(self, request, *args, **kwargs): - """ Handle POST request (which contains category selection). + """Handle POST request (which contains category selection). Pass the selected category to the page template """ - form = self.get_form() if form.is_valid(): @@ -869,14 +836,13 @@ class SettingCategorySelectView(FormView): class DatabaseStatsView(AjaxView): - """ View for displaying database statistics """ + """View for displaying database statistics.""" ajax_template_name = "stats.html" ajax_form_title = _("System Information") class NotificationsView(TemplateView): - """ View for showing notifications - """ + """View for showing notifications.""" template_name = "InvenTree/notifications/notifications.html" diff --git a/InvenTree/InvenTree/wsgi.py b/InvenTree/InvenTree/wsgi.py index 9c4d72edbc..dfced329a8 100644 --- a/InvenTree/InvenTree/wsgi.py +++ b/InvenTree/InvenTree/wsgi.py @@ -1,5 +1,4 @@ -""" -WSGI config for InvenTree project. +"""WSGI config for InvenTree project. It exposes the WSGI callable as a module-level variable named ``application``. diff --git a/InvenTree/build/__init__.py b/InvenTree/build/__init__.py index 747a20f837..924a19b982 100644 --- a/InvenTree/build/__init__.py +++ b/InvenTree/build/__init__.py @@ -1,5 +1,4 @@ -""" -The Build module is responsible for managing "Build" transactions. +"""The Build module is responsible for managing "Build" transactions. A Build consumes parts from stock to create new parts """ diff --git a/InvenTree/build/admin.py b/InvenTree/build/admin.py index 5988850fe4..eec7376ede 100644 --- a/InvenTree/build/admin.py +++ b/InvenTree/build/admin.py @@ -1,3 +1,5 @@ +"""Admin functionality for the BuildOrder app""" + from django.contrib import admin from import_export.admin import ImportExportModelAdmin @@ -11,7 +13,7 @@ import part.models class BuildResource(ModelResource): - """Class for managing import/export of Build data""" + """Class for managing import/export of Build data.""" # For some reason, we need to specify the fields individually for this ModelResource, # but we don't for other ones. # TODO: 2022-05-12 - Need to investigate why this is the case! @@ -39,6 +41,7 @@ class BuildResource(ModelResource): notes = Field(attribute='notes') class Meta: + """Metaclass options""" models = Build skip_unchanged = True report_skipped = False @@ -50,6 +53,7 @@ class BuildResource(ModelResource): class BuildAdmin(ImportExportModelAdmin): + """Class for managing the Build model via the admin interface""" exclude = [ 'reference_int', @@ -81,6 +85,7 @@ class BuildAdmin(ImportExportModelAdmin): class BuildItemAdmin(admin.ModelAdmin): + """Class for managing the BuildItem model via the admin interface""" list_display = ( 'build', diff --git a/InvenTree/build/api.py b/InvenTree/build/api.py index fffd77ed05..6703c9a511 100644 --- a/InvenTree/build/api.py +++ b/InvenTree/build/api.py @@ -1,6 +1,4 @@ -""" -JSON API for the Build app -""" +"""JSON API for the Build app.""" from django.urls import include, re_path @@ -22,16 +20,14 @@ from users.models import Owner class BuildFilter(rest_filters.FilterSet): - """ - Custom filterset for BuildList API endpoint - """ + """Custom filterset for BuildList API endpoint.""" status = rest_filters.NumberFilter(label='Status') active = rest_filters.BooleanFilter(label='Build is active', method='filter_active') def filter_active(self, queryset, name, value): - + """Filter the queryset to either include or exclude orders which are active.""" if str2bool(value): queryset = queryset.filter(status__in=BuildStatus.ACTIVE_CODES) else: @@ -42,7 +38,7 @@ class BuildFilter(rest_filters.FilterSet): overdue = rest_filters.BooleanFilter(label='Build is overdue', method='filter_overdue') def filter_overdue(self, queryset, name, value): - + """Filter the queryset to either include or exclude orders which are overdue.""" if str2bool(value): queryset = queryset.filter(Build.OVERDUE_FILTER) else: @@ -53,10 +49,7 @@ class BuildFilter(rest_filters.FilterSet): assigned_to_me = rest_filters.BooleanFilter(label='assigned_to_me', method='filter_assigned_to_me') def filter_assigned_to_me(self, queryset, name, value): - """ - Filter by orders which are assigned to the current user - """ - + """Filter by orders which are assigned to the current user.""" value = str2bool(value) # Work out who "me" is! @@ -71,7 +64,7 @@ class BuildFilter(rest_filters.FilterSet): class BuildList(APIDownloadMixin, generics.ListCreateAPIView): - """ API endpoint for accessing a list of Build objects. + """API endpoint for accessing a list of Build objects. - GET: Return list of objects (with filters) - POST: Create a new Build object @@ -113,11 +106,7 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView): ] def get_queryset(self): - """ - Override the queryset filtering, - as some of the fields don't natively play nicely with DRF - """ - + """Override the queryset filtering, as some of the fields don't natively play nicely with DRF.""" queryset = super().get_queryset().select_related('part') queryset = build.serializers.BuildSerializer.annotate_queryset(queryset) @@ -125,6 +114,7 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView): return queryset def download_queryset(self, queryset, export_format): + """Download the queryset data as a file.""" dataset = build.admin.BuildResource().export(queryset=queryset) filedata = dataset.export(export_format) @@ -133,7 +123,7 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView): return DownloadFile(filedata, filename) def filter_queryset(self, queryset): - + """Custom query filtering for the BuildList endpoint.""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -197,7 +187,7 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView): return queryset def get_serializer(self, *args, **kwargs): - + """Add extra context information to the endpoint serializer.""" try: part_detail = str2bool(self.request.GET.get('part_detail', None)) except AttributeError: @@ -209,15 +199,14 @@ class BuildList(APIDownloadMixin, generics.ListCreateAPIView): class BuildDetail(generics.RetrieveUpdateAPIView): - """ API endpoint for detail view of a Build object """ + """API endpoint for detail view of a Build object.""" queryset = Build.objects.all() serializer_class = build.serializers.BuildSerializer class BuildUnallocate(generics.CreateAPIView): - """ - API endpoint for unallocating stock items from a build order + """API endpoint for unallocating stock items from a build order. - The BuildOrder object is specified by the URL - "output" (StockItem) can optionally be specified @@ -229,7 +218,7 @@ class BuildUnallocate(generics.CreateAPIView): serializer_class = build.serializers.BuildUnallocationSerializer def get_serializer_context(self): - + """Add extra context information to the endpoint serializer.""" ctx = super().get_serializer_context() try: @@ -243,9 +232,10 @@ class BuildUnallocate(generics.CreateAPIView): class BuildOrderContextMixin: - """ Mixin class which adds build order as serializer context variable """ + """Mixin class which adds build order as serializer context variable.""" def get_serializer_context(self): + """Add extra context information to the endpoint serializer.""" ctx = super().get_serializer_context() ctx['request'] = self.request @@ -260,9 +250,7 @@ class BuildOrderContextMixin: class BuildOutputCreate(BuildOrderContextMixin, generics.CreateAPIView): - """ - API endpoint for creating new build output(s) - """ + """API endpoint for creating new build output(s).""" queryset = Build.objects.none() @@ -270,9 +258,7 @@ class BuildOutputCreate(BuildOrderContextMixin, generics.CreateAPIView): class BuildOutputComplete(BuildOrderContextMixin, generics.CreateAPIView): - """ - API endpoint for completing build outputs - """ + """API endpoint for completing build outputs.""" queryset = Build.objects.none() @@ -280,11 +266,10 @@ class BuildOutputComplete(BuildOrderContextMixin, generics.CreateAPIView): class BuildOutputDelete(BuildOrderContextMixin, generics.CreateAPIView): - """ - API endpoint for deleting multiple build outputs - """ + """API endpoint for deleting multiple build outputs.""" def get_serializer_context(self): + """Add extra context information to the endpoint serializer.""" ctx = super().get_serializer_context() ctx['to_complete'] = False @@ -297,9 +282,7 @@ class BuildOutputDelete(BuildOrderContextMixin, generics.CreateAPIView): class BuildFinish(BuildOrderContextMixin, generics.CreateAPIView): - """ - API endpoint for marking a build as finished (completed) - """ + """API endpoint for marking a build as finished (completed).""" queryset = Build.objects.none() @@ -307,8 +290,7 @@ class BuildFinish(BuildOrderContextMixin, generics.CreateAPIView): class BuildAutoAllocate(BuildOrderContextMixin, generics.CreateAPIView): - """ - API endpoint for 'automatically' allocating stock against a build order. + """API endpoint for 'automatically' allocating stock against a build order. - Only looks at 'untracked' parts - If stock exists in a single location, easy! @@ -322,8 +304,7 @@ class BuildAutoAllocate(BuildOrderContextMixin, generics.CreateAPIView): class BuildAllocate(BuildOrderContextMixin, generics.CreateAPIView): - """ - API endpoint to allocate stock items to a build order + """API endpoint to allocate stock items to a build order. - The BuildOrder object is specified by the URL - Items to allocate are specified as a list called "items" with the following options: @@ -339,23 +320,21 @@ class BuildAllocate(BuildOrderContextMixin, generics.CreateAPIView): class BuildCancel(BuildOrderContextMixin, generics.CreateAPIView): - """ API endpoint for cancelling a BuildOrder """ + """API endpoint for cancelling a BuildOrder.""" queryset = Build.objects.all() serializer_class = build.serializers.BuildCancelSerializer class BuildItemDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for detail view of a BuildItem object - """ + """API endpoint for detail view of a BuildItem object.""" queryset = BuildItem.objects.all() serializer_class = build.serializers.BuildItemSerializer class BuildItemList(generics.ListCreateAPIView): - """ API endpoint for accessing a list of BuildItem objects + """API endpoint for accessing a list of BuildItem objects. - GET: Return list of objects - POST: Create a new BuildItem object @@ -364,7 +343,7 @@ class BuildItemList(generics.ListCreateAPIView): serializer_class = build.serializers.BuildItemSerializer def get_serializer(self, *args, **kwargs): - + """Returns a BuildItemSerializer instance based on the request.""" try: params = self.request.query_params @@ -377,10 +356,7 @@ class BuildItemList(generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def get_queryset(self): - """ Override the queryset method, - to allow filtering by stock_item.part - """ - + """Override the queryset method, to allow filtering by stock_item.part.""" query = BuildItem.objects.all() query = query.select_related('stock_item__location') @@ -390,7 +366,7 @@ class BuildItemList(generics.ListCreateAPIView): return query def filter_queryset(self, queryset): - + """Customm query filtering for the BuildItem list.""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -438,9 +414,7 @@ class BuildItemList(generics.ListCreateAPIView): class BuildAttachmentList(generics.ListCreateAPIView, AttachmentMixin): - """ - API endpoint for listing (and creating) BuildOrderAttachment objects - """ + """API endpoint for listing (and creating) BuildOrderAttachment objects.""" queryset = BuildOrderAttachment.objects.all() serializer_class = build.serializers.BuildAttachmentSerializer @@ -455,9 +429,7 @@ class BuildAttachmentList(generics.ListCreateAPIView, AttachmentMixin): class BuildAttachmentDetail(generics.RetrieveUpdateDestroyAPIView, AttachmentMixin): - """ - Detail endpoint for a BuildOrderAttachment object - """ + """Detail endpoint for a BuildOrderAttachment object.""" queryset = BuildOrderAttachment.objects.all() serializer_class = build.serializers.BuildAttachmentSerializer diff --git a/InvenTree/build/apps.py b/InvenTree/build/apps.py index 9025e77e0e..683e410b66 100644 --- a/InvenTree/build/apps.py +++ b/InvenTree/build/apps.py @@ -1,5 +1,8 @@ +"""Django app for the BuildOrder module""" + from django.apps import AppConfig class BuildConfig(AppConfig): + """BuildOrder app config class""" name = 'build' diff --git a/InvenTree/build/models.py b/InvenTree/build/models.py index 0f539dc158..cc81fd7f26 100644 --- a/InvenTree/build/models.py +++ b/InvenTree/build/models.py @@ -1,6 +1,4 @@ -""" -Build database model definitions -""" +"""Build database model definitions.""" import decimal @@ -42,10 +40,7 @@ from users import models as UserModels def get_next_build_number(): - """ - Returns the next available BuildOrder reference number - """ - + """Returns the next available BuildOrder reference number.""" if Build.objects.count() == 0: return '0001' @@ -71,7 +66,7 @@ def get_next_build_number(): class Build(MPTTModel, ReferenceIndexingMixin): - """ A Build object organises the creation of new StockItem objects from other existing StockItem objects. + """A Build object organises the creation of new StockItem objects from other existing StockItem objects. Attributes: part: The part to be built (from component BOM items) @@ -97,10 +92,11 @@ class Build(MPTTModel, ReferenceIndexingMixin): @staticmethod def get_api_url(): + """Return the API URL associated with the BuildOrder model""" return reverse('api-build-list') def api_instance_filters(self): - + """Returns custom API filters for the particular BuildOrder instance""" return { 'parent': { 'exclude_tree': self.pk, @@ -109,10 +105,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): @classmethod def api_defaults(cls, request): - """ - Return default values for this model when issuing an API OPTIONS request - """ - + """Return default values for this model when issuing an API OPTIONS request.""" defaults = { 'reference': get_next_build_number(), } @@ -123,7 +116,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): return defaults def save(self, *args, **kwargs): - + """Custom save method for the BuildOrder model""" self.rebuild_reference_field() try: @@ -134,14 +127,12 @@ class Build(MPTTModel, ReferenceIndexingMixin): }) class Meta: + """Metaclass options for the BuildOrder model""" verbose_name = _("Build Order") verbose_name_plural = _("Build Orders") def format_barcode(self, **kwargs): - """ - Return a JSON string to represent this build as a barcode - """ - + """Return a JSON string to represent this build as a barcode.""" return MakeBarcode( "buildorder", self.pk, @@ -153,13 +144,11 @@ class Build(MPTTModel, ReferenceIndexingMixin): @staticmethod def filterByDate(queryset, min_date, max_date): - """ - Filter by 'minimum and maximum date range' + """Filter by 'minimum and maximum date range'. - Specified as min_date, max_date - Both must be specified for filter to be applied """ - date_fmt = '%Y-%m-%d' # ISO format date string # Ensure that both dates are valid @@ -183,12 +172,13 @@ class Build(MPTTModel, ReferenceIndexingMixin): return queryset def __str__(self): - + """String representation of a BuildOrder""" prefix = getSetting("BUILDORDER_REFERENCE_PREFIX") return f"{prefix}{self.reference}" def get_absolute_url(self): + """Return the web URL associated with this BuildOrder""" return reverse('build-detail', kwargs={'pk': self.id}) reference = models.CharField( @@ -336,10 +326,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): ) def sub_builds(self, cascade=True): - """ - Return all Build Order objects under this one. - """ - + """Return all Build Order objects under this one.""" if cascade: return Build.objects.filter(parent=self.pk) else: @@ -347,23 +334,22 @@ class Build(MPTTModel, ReferenceIndexingMixin): Build.objects.filter(parent__pk__in=[d.pk for d in descendants]) def sub_build_count(self, cascade=True): - """ - Return the number of sub builds under this one. + """Return the number of sub builds under this one. Args: cascade: If True (defualt), include cascading builds under sub builds """ - return self.sub_builds(cascade=cascade).count() @property def is_overdue(self): - """ - Returns true if this build is "overdue": + """Returns true if this build is "overdue". Makes use of the OVERDUE_FILTER to avoid code duplication - """ + Returns: + bool: Is the build overdue + """ query = Build.objects.filter(pk=self.pk) query = query.filter(Build.OVERDUE_FILTER) @@ -371,80 +357,59 @@ class Build(MPTTModel, ReferenceIndexingMixin): @property def active(self): - """ - Return True if this build is active - """ - + """Return True if this build is active.""" return self.status in BuildStatus.ACTIVE_CODES @property def bom_items(self): - """ - Returns the BOM items for the part referenced by this BuildOrder - """ - + """Returns the BOM items for the part referenced by this BuildOrder.""" return self.part.get_bom_items() @property def tracked_bom_items(self): - """ - Returns the "trackable" BOM items for this BuildOrder - """ - + """Returns the "trackable" BOM items for this BuildOrder.""" items = self.bom_items items = items.filter(sub_part__trackable=True) return items def has_tracked_bom_items(self): - """ - Returns True if this BuildOrder has trackable BomItems - """ - + """Returns True if this BuildOrder has trackable BomItems.""" return self.tracked_bom_items.count() > 0 @property def untracked_bom_items(self): - """ - Returns the "non trackable" BOM items for this BuildOrder - """ - + """Returns the "non trackable" BOM items for this BuildOrder.""" items = self.bom_items items = items.filter(sub_part__trackable=False) return items def has_untracked_bom_items(self): - """ - Returns True if this BuildOrder has non trackable BomItems - """ - + """Returns True if this BuildOrder has non trackable BomItems.""" return self.untracked_bom_items.count() > 0 @property def remaining(self): - """ - Return the number of outputs remaining to be completed. - """ - + """Return the number of outputs remaining to be completed.""" return max(0, self.quantity - self.completed) @property def output_count(self): + """Return the number of build outputs (StockItem) associated with this build order""" return self.build_outputs.count() def has_build_outputs(self): + """Returns True if this build has more than zero build outputs""" return self.output_count > 0 def get_build_outputs(self, **kwargs): - """ - Return a list of build outputs. + """Return a list of build outputs. kwargs: complete = (True / False) - If supplied, filter by completed status in_stock = (True / False) - If supplied, filter by 'in-stock' status """ - outputs = self.build_outputs.all() # Filter by 'in stock' status @@ -469,17 +434,14 @@ class Build(MPTTModel, ReferenceIndexingMixin): @property def complete_outputs(self): - """ - Return all the "completed" build outputs - """ - + """Return all the "completed" build outputs.""" outputs = self.get_build_outputs(complete=True) return outputs @property def complete_count(self): - + """Return the total quantity of completed outputs""" quantity = 0 for output in self.complete_outputs: @@ -489,20 +451,14 @@ class Build(MPTTModel, ReferenceIndexingMixin): @property def incomplete_outputs(self): - """ - Return all the "incomplete" build outputs - """ - + """Return all the "incomplete" build outputs.""" outputs = self.get_build_outputs(complete=False) return outputs @property def incomplete_count(self): - """ - Return the total number of "incomplete" outputs - """ - + """Return the total number of "incomplete" outputs.""" quantity = 0 for output in self.incomplete_outputs: @@ -512,10 +468,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): @classmethod def getNextBuildNumber(cls): - """ - Try to predict the next Build Order reference: - """ - + """Try to predict the next Build Order reference.""" if cls.objects.count() == 0: return None @@ -552,13 +505,11 @@ class Build(MPTTModel, ReferenceIndexingMixin): @property def can_complete(self): - """ - Returns True if this build can be "completed" + """Returns True if this build can be "completed". - Must not have any outstanding build outputs - 'completed' value must meet (or exceed) the 'quantity' value """ - if self.incomplete_count > 0: return False @@ -573,10 +524,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): @transaction.atomic def complete_build(self, user): - """ - Mark this build as complete - """ - + """Mark this build as complete.""" if self.incomplete_count > 0: return @@ -597,13 +545,12 @@ class Build(MPTTModel, ReferenceIndexingMixin): @transaction.atomic def cancel_build(self, user, **kwargs): - """ Mark the Build as CANCELLED + """Mark the Build as CANCELLED. - Delete any pending BuildItem objects (but do not remove items from stock) - Set build status to CANCELLED - Save the Build object """ - remove_allocated_stock = kwargs.get('remove_allocated_stock', False) remove_incomplete_outputs = kwargs.get('remove_incomplete_outputs', False) @@ -633,14 +580,12 @@ class Build(MPTTModel, ReferenceIndexingMixin): @transaction.atomic def unallocateStock(self, bom_item=None, output=None): - """ - Unallocate stock from this Build + """Unallocate stock from this Build. - arguments: - - bom_item: Specify a particular BomItem to unallocate stock against - - output: Specify a particular StockItem (output) to unallocate stock against + Args: + bom_item: Specify a particular BomItem to unallocate stock against + output: Specify a particular StockItem (output) to unallocate stock against """ - allocations = BuildItem.objects.filter( build=self, install_into=output @@ -653,19 +598,17 @@ class Build(MPTTModel, ReferenceIndexingMixin): @transaction.atomic def create_build_output(self, quantity, **kwargs): - """ - Create a new build output against this BuildOrder. + """Create a new build output against this BuildOrder. - args: + Args: quantity: The quantity of the item to produce - kwargs: + Kwargs: batch: Override batch code serials: Serial numbers location: Override location auto_allocate: Automatically allocate stock with matching serial numbers """ - batch = kwargs.get('batch', self.batch) location = kwargs.get('location', self.destination) serials = kwargs.get('serials', None) @@ -687,9 +630,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): multiple = True if multiple: - """ - Create multiple build outputs with a single quantity of 1 - """ + """Create multiple build outputs with a single quantity of 1.""" # Quantity *must* be an integer at this point! quantity = int(quantity) @@ -743,9 +684,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): ) else: - """ - Create a single build output of the given quantity - """ + """Create a single build output of the given quantity.""" StockModels.StockItem.objects.create( quantity=quantity, @@ -762,13 +701,12 @@ class Build(MPTTModel, ReferenceIndexingMixin): @transaction.atomic def delete_output(self, output): - """ - Remove a build output from the database: + """Remove a build output from the database. + Executes: - Unallocate any build items against the output - Delete the output StockItem """ - if not output: raise ValidationError(_("No build output specified")) @@ -786,11 +724,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): @transaction.atomic def subtract_allocated_stock(self, user): - """ - Called when the Build is marked as "complete", - this function removes the allocated untracked items from stock. - """ - + """Called when the Build is marked as "complete", this function removes the allocated untracked items from stock.""" items = self.allocated_stock.filter( stock_item__part__trackable=False ) @@ -804,13 +738,11 @@ class Build(MPTTModel, ReferenceIndexingMixin): @transaction.atomic def complete_build_output(self, output, user, **kwargs): - """ - Complete a particular build output + """Complete a particular build output. - Remove allocated StockItems - Mark the output as complete """ - # Select the location for the build output location = kwargs.get('location', self.destination) status = kwargs.get('status', StockStatus.OK) @@ -850,10 +782,9 @@ class Build(MPTTModel, ReferenceIndexingMixin): @transaction.atomic def auto_allocate_stock(self, **kwargs): - """ - Automatically allocate stock items against this build order, - following a number of 'guidelines': + """Automatically allocate stock items against this build order. + Following a number of 'guidelines': - Only "untracked" BOM items are considered (tracked BOM items must be manually allocated) - If a particular BOM item is already fully allocated, it is skipped - Extract all available stock items for the BOM part @@ -863,7 +794,6 @@ class Build(MPTTModel, ReferenceIndexingMixin): - If multiple stock items are found, we *may* be able to allocate: - If the calling function has specified that items are interchangeable """ - location = kwargs.get('location', None) exclude_location = kwargs.get('exclude_location', None) interchangeable = kwargs.get('interchangeable', False) @@ -958,14 +888,12 @@ class Build(MPTTModel, ReferenceIndexingMixin): break def required_quantity(self, bom_item, output=None): - """ - Get the quantity of a part required to complete the particular build output. + """Get the quantity of a part required to complete the particular build output. Args: - part: The Part object - output - The particular build output (StockItem) + bom_item: The Part object + output: The particular build output (StockItem) """ - quantity = bom_item.quantity if output: @@ -976,17 +904,15 @@ class Build(MPTTModel, ReferenceIndexingMixin): return quantity def allocated_bom_items(self, bom_item, output=None): - """ - Return all BuildItem objects which allocate stock of to + """Return all BuildItem objects which allocate stock of to . Note that the bom_item may allow variants, or direct substitutes, making things difficult. Args: - bom_item - The BomItem object - output - Build output (StockItem). + bom_item: The BomItem object + output: Build output (StockItem). """ - allocations = BuildItem.objects.filter( build=self, bom_item=bom_item, @@ -996,10 +922,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): return allocations def allocated_quantity(self, bom_item, output=None): - """ - Return the total quantity of given part allocated to a given build output. - """ - + """Return the total quantity of given part allocated to a given build output.""" allocations = self.allocated_bom_items(bom_item, output) allocated = allocations.aggregate( @@ -1013,27 +936,18 @@ class Build(MPTTModel, ReferenceIndexingMixin): return allocated['q'] def unallocated_quantity(self, bom_item, output=None): - """ - Return the total unallocated (remaining) quantity of a part against a particular output. - """ - + """Return the total unallocated (remaining) quantity of a part against a particular output.""" required = self.required_quantity(bom_item, output) allocated = self.allocated_quantity(bom_item, output) return max(required - allocated, 0) def is_bom_item_allocated(self, bom_item, output=None): - """ - Test if the supplied BomItem has been fully allocated! - """ - + """Test if the supplied BomItem has been fully allocated!""" return self.unallocated_quantity(bom_item, output) == 0 def is_fully_allocated(self, output): - """ - Returns True if the particular build output is fully allocated. - """ - + """Returns True if the particular build output is fully allocated.""" # If output is not specified, we are talking about "untracked" items if output is None: bom_items = self.untracked_bom_items @@ -1049,10 +963,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): return True def is_partially_allocated(self, output): - """ - Returns True if the particular build output is (at least) partially allocated - """ - + """Returns True if the particular build output is (at least) partially allocated.""" # If output is not specified, we are talking about "untracked" items if output is None: bom_items = self.untracked_bom_items @@ -1067,17 +978,11 @@ class Build(MPTTModel, ReferenceIndexingMixin): return False def are_untracked_parts_allocated(self): - """ - Returns True if the un-tracked parts are fully allocated for this BuildOrder - """ - + """Returns True if the un-tracked parts are fully allocated for this BuildOrder.""" return self.is_fully_allocated(None) def unallocated_bom_items(self, output): - """ - Return a list of bom items which have *not* been fully allocated against a particular output - """ - + """Return a list of bom items which have *not* been fully allocated against a particular output.""" unallocated = [] # If output is not specified, we are talking about "untracked" items @@ -1095,7 +1000,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): @property def required_parts(self): - """ Returns a list of parts required to build this part (BOM) """ + """Returns a list of parts required to build this part (BOM).""" parts = [] for item in self.bom_items: @@ -1105,7 +1010,7 @@ class Build(MPTTModel, ReferenceIndexingMixin): @property def required_parts_to_complete_build(self): - """ Returns a list of parts required to complete the full build """ + """Returns a list of parts required to complete the full build.""" parts = [] for bom_item in self.bom_items: @@ -1119,26 +1024,23 @@ class Build(MPTTModel, ReferenceIndexingMixin): @property def is_active(self): - """ Is this build active? An active build is either: + """Is this build active? + An active build is either: - PENDING - HOLDING """ - return self.status in BuildStatus.ACTIVE_CODES @property def is_complete(self): - """ Returns True if the build status is COMPLETE """ - + """Returns True if the build status is COMPLETE.""" return self.status == BuildStatus.COMPLETE @receiver(post_save, sender=Build, dispatch_uid='build_post_save_log') def after_save_build(sender, instance: Build, created: bool, **kwargs): - """ - Callback function to be executed after a Build instance is saved - """ + """Callback function to be executed after a Build instance is saved.""" from . import tasks as build_tasks if created: @@ -1149,21 +1051,19 @@ def after_save_build(sender, instance: Build, created: bool, **kwargs): class BuildOrderAttachment(InvenTreeAttachment): - """ - Model for storing file attachments against a BuildOrder object - """ + """Model for storing file attachments against a BuildOrder object.""" def getSubdir(self): + """Return the media file subdirectory for storing BuildOrder attachments""" return os.path.join('bo_files', str(self.build.id)) build = models.ForeignKey(Build, on_delete=models.CASCADE, related_name='attachments') class BuildItem(models.Model): - """ A BuildItem links multiple StockItem objects to a Build. - These are used to allocate part stock to a build. - Once the Build is completed, the parts are removed from stock and the - BuildItemAllocation objects are removed. + """A BuildItem links multiple StockItem objects to a Build. + + These are used to allocate part stock to a build. Once the Build is completed, the parts are removed from stock and the BuildItemAllocation objects are removed. Attributes: build: Link to a Build object @@ -1175,33 +1075,28 @@ class BuildItem(models.Model): @staticmethod def get_api_url(): + """Return the API URL used to access this model""" return reverse('api-build-item-list') - def get_absolute_url(self): - # TODO - Fix! - return '/build/item/{pk}/'.format(pk=self.id) - # return reverse('build-detail', kwargs={'pk': self.id}) - class Meta: + """Serializer metaclass""" unique_together = [ ('build', 'stock_item', 'install_into'), ] def save(self, *args, **kwargs): - + """Custom save method for the BuildItem model""" self.clean() super().save() def clean(self): - """ - Check validity of this BuildItem instance. - The following checks are performed: + """Check validity of this BuildItem instance. + The following checks are performed: - StockItem.part must be in the BOM of the Part object referenced by Build - Allocation quantity cannot exceed available quantity """ - self.validate_unique() super().clean() @@ -1303,13 +1198,11 @@ class BuildItem(models.Model): @transaction.atomic def complete_allocation(self, user, notes=''): - """ - Complete the allocation of this BuildItem into the output stock item. + """Complete the allocation of this BuildItem into the output stock item. - If the referenced part is trackable, the stock item will be *installed* into the build output - If the referenced part is *not* trackable, the stock item will be removed from stock """ - item = self.stock_item # For a trackable part, special consideration needed! @@ -1344,10 +1237,7 @@ class BuildItem(models.Model): ) def getStockItemThumbnail(self): - """ - Return qualified URL for part thumbnail image - """ - + """Return qualified URL for part thumbnail image.""" thumb_url = None if self.stock_item and self.stock_item.part: diff --git a/InvenTree/build/serializers.py b/InvenTree/build/serializers.py index e4919f1433..6659dd1e42 100644 --- a/InvenTree/build/serializers.py +++ b/InvenTree/build/serializers.py @@ -1,6 +1,4 @@ -""" -JSON serializers for Build API -""" +"""JSON serializers for Build API.""" from django.db import transaction from django.core.exceptions import ValidationError as DjangoValidationError @@ -31,9 +29,7 @@ from .models import Build, BuildItem, BuildOrderAttachment class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer): - """ - Serializes a Build object - """ + """Serializes a Build object.""" url = serializers.CharField(source='get_absolute_url', read_only=True) status_text = serializers.CharField(source='get_status_display', read_only=True) @@ -50,16 +46,12 @@ class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer @staticmethod def annotate_queryset(queryset): - """ - Add custom annotations to the BuildSerializer queryset, - performing database queries as efficiently as possible. + """Add custom annotations to the BuildSerializer queryset, performing database queries as efficiently as possible. The following annoted fields are added: - overdue: True if the build is outstanding *and* the completion date has past - """ - # Annotate a boolean 'overdue' flag queryset = queryset.annotate( @@ -74,6 +66,7 @@ class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer return queryset def __init__(self, *args, **kwargs): + """Determine if extra serializer fields are required""" part_detail = kwargs.pop('part_detail', True) super().__init__(*args, **kwargs) @@ -82,6 +75,7 @@ class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer self.fields.pop('part_detail') class Meta: + """Serializer metaclass""" model = Build fields = [ 'pk', @@ -121,8 +115,7 @@ class BuildSerializer(ReferenceIndexingSerializerMixin, InvenTreeModelSerializer class BuildOutputSerializer(serializers.Serializer): - """ - Serializer for a "BuildOutput" + """Serializer for a "BuildOutput". Note that a "BuildOutput" is really just a StockItem which is "in production"! """ @@ -136,7 +129,7 @@ class BuildOutputSerializer(serializers.Serializer): ) def validate_output(self, output): - + """Perform validation for the output (StockItem) provided to the serializer""" build = self.context['build'] # As this serializer can be used in multiple contexts, we need to work out why we are here @@ -168,14 +161,14 @@ class BuildOutputSerializer(serializers.Serializer): return output class Meta: + """Serializer metaclass""" fields = [ 'output', ] class BuildOutputCreateSerializer(serializers.Serializer): - """ - Serializer for creating a new BuildOutput against a BuildOrder. + """Serializer for creating a new BuildOutput against a BuildOrder. URL pattern is "/api/build//create-output/", where is the PK of a Build. @@ -192,13 +185,15 @@ class BuildOutputCreateSerializer(serializers.Serializer): ) def get_build(self): + """Return the Build instance associated with this serializer""" return self.context["build"] def get_part(self): + """Return the Part instance associated with the build""" return self.get_build().part def validate_quantity(self, quantity): - + """Validate the provided quantity field""" if quantity <= 0: raise ValidationError(_("Quantity must be greater than zero")) @@ -229,7 +224,7 @@ class BuildOutputCreateSerializer(serializers.Serializer): ) def validate_serial_numbers(self, serial_numbers): - + """Clean the provided serial number string""" serial_numbers = serial_numbers.strip() return serial_numbers @@ -243,10 +238,7 @@ class BuildOutputCreateSerializer(serializers.Serializer): ) def validate(self, data): - """ - Perform form validation - """ - + """Perform form validation.""" part = self.get_part() # Cache a list of serial numbers (to be used in the "save" method) @@ -284,10 +276,7 @@ class BuildOutputCreateSerializer(serializers.Serializer): return data def save(self): - """ - Generate the new build output(s) - """ - + """Generate the new build output(s)""" data = self.validated_data quantity = data['quantity'] @@ -305,11 +294,10 @@ class BuildOutputCreateSerializer(serializers.Serializer): class BuildOutputDeleteSerializer(serializers.Serializer): - """ - DRF serializer for deleting (cancelling) one or more build outputs - """ + """DRF serializer for deleting (cancelling) one or more build outputs.""" class Meta: + """Serializer metaclass""" fields = [ 'outputs', ] @@ -320,7 +308,7 @@ class BuildOutputDeleteSerializer(serializers.Serializer): ) def validate(self, data): - + """Perform data validation for this serializer""" data = super().validate(data) outputs = data.get('outputs', []) @@ -331,10 +319,7 @@ class BuildOutputDeleteSerializer(serializers.Serializer): return data def save(self): - """ - 'save' the serializer to delete the build outputs - """ - + """'save' the serializer to delete the build outputs.""" data = self.validated_data outputs = data.get('outputs', []) @@ -347,11 +332,10 @@ class BuildOutputDeleteSerializer(serializers.Serializer): class BuildOutputCompleteSerializer(serializers.Serializer): - """ - DRF serializer for completing one or more build outputs - """ + """DRF serializer for completing one or more build outputs.""" class Meta: + """Serializer metaclass""" fields = [ 'outputs', 'location', @@ -393,7 +377,7 @@ class BuildOutputCompleteSerializer(serializers.Serializer): ) def validate(self, data): - + """Perform data validation for this serializer""" super().validate(data) outputs = data.get('outputs', []) @@ -404,10 +388,7 @@ class BuildOutputCompleteSerializer(serializers.Serializer): return data def save(self): - """ - "save" the serializer to complete the build outputs - """ - + """Save the serializer to complete the build outputs.""" build = self.context['build'] request = self.context['request'] @@ -435,15 +416,17 @@ class BuildOutputCompleteSerializer(serializers.Serializer): class BuildCancelSerializer(serializers.Serializer): + """DRF serializer class for cancelling an active BuildOrder""" class Meta: + """Serializer metaclass""" fields = [ 'remove_allocated_stock', 'remove_incomplete_outputs', ] def get_context_data(self): - + """Retrieve extra context data from this serializer""" build = self.context['build'] return { @@ -467,7 +450,7 @@ class BuildCancelSerializer(serializers.Serializer): ) def save(self): - + """Cancel the specified build""" build = self.context['build'] request = self.context['request'] @@ -481,9 +464,7 @@ class BuildCancelSerializer(serializers.Serializer): class BuildCompleteSerializer(serializers.Serializer): - """ - DRF serializer for marking a BuildOrder as complete - """ + """DRF serializer for marking a BuildOrder as complete.""" accept_unallocated = serializers.BooleanField( label=_('Accept Unallocated'), @@ -493,7 +474,7 @@ class BuildCompleteSerializer(serializers.Serializer): ) def validate_accept_unallocated(self, value): - + """Check if the 'accept_unallocated' field is required""" build = self.context['build'] if not build.are_untracked_parts_allocated() and not value: @@ -509,7 +490,7 @@ class BuildCompleteSerializer(serializers.Serializer): ) def validate_accept_incomplete(self, value): - + """Check if the 'accept_incomplete' field is required""" build = self.context['build'] if build.remaining > 0 and not value: @@ -518,7 +499,7 @@ class BuildCompleteSerializer(serializers.Serializer): return value def validate(self, data): - + """Perform validation of this serializer prior to saving""" build = self.context['build'] if build.incomplete_count > 0: @@ -530,7 +511,7 @@ class BuildCompleteSerializer(serializers.Serializer): return data def save(self): - + """Complete the specified build output""" request = self.context['request'] build = self.context['build'] @@ -538,14 +519,12 @@ class BuildCompleteSerializer(serializers.Serializer): class BuildUnallocationSerializer(serializers.Serializer): - """ - DRF serializer for unallocating stock from a BuildOrder + """DRF serializer for unallocating stock from a BuildOrder. Allocated stock can be unallocated with a number of filters: - output: Filter against a particular build output (blank = untracked stock) - bom_item: Filter against a particular BOM line item - """ bom_item = serializers.PrimaryKeyRelatedField( @@ -567,8 +546,7 @@ class BuildUnallocationSerializer(serializers.Serializer): ) def validate_output(self, stock_item): - - # Stock item must point to the same build order! + """Validation for the output StockItem instance. Stock item must point to the same build order!""" build = self.context['build'] if stock_item and stock_item.build != build: @@ -577,11 +555,10 @@ class BuildUnallocationSerializer(serializers.Serializer): return stock_item def save(self): - """ - 'Save' the serializer data. + """Save the serializer data. + This performs the actual unallocation against the build order """ - build = self.context['build'] data = self.validated_data @@ -593,9 +570,7 @@ class BuildUnallocationSerializer(serializers.Serializer): class BuildAllocationItemSerializer(serializers.Serializer): - """ - A serializer for allocating a single stock item against a build order - """ + """A serializer for allocating a single stock item against a build order.""" bom_item = serializers.PrimaryKeyRelatedField( queryset=BomItem.objects.all(), @@ -606,10 +581,7 @@ class BuildAllocationItemSerializer(serializers.Serializer): ) def validate_bom_item(self, bom_item): - """ - Check if the parts match! - """ - + """Check if the parts match""" build = self.context['build'] # BomItem should point to the same 'part' as the parent build @@ -632,7 +604,7 @@ class BuildAllocationItemSerializer(serializers.Serializer): ) def validate_stock_item(self, stock_item): - + """Perform validation of the stock_item field""" if not stock_item.in_stock: raise ValidationError(_("Item must be in stock")) @@ -646,7 +618,7 @@ class BuildAllocationItemSerializer(serializers.Serializer): ) def validate_quantity(self, quantity): - + """Perform validation of the 'quantity' field""" if quantity <= 0: raise ValidationError(_("Quantity must be greater than zero")) @@ -661,6 +633,7 @@ class BuildAllocationItemSerializer(serializers.Serializer): ) class Meta: + """Serializer metaclass""" fields = [ 'bom_item', 'stock_item', @@ -669,7 +642,7 @@ class BuildAllocationItemSerializer(serializers.Serializer): ] def validate(self, data): - + """Perfofrm data validation for this item""" super().validate(data) build = self.context['build'] @@ -715,22 +688,18 @@ class BuildAllocationItemSerializer(serializers.Serializer): class BuildAllocationSerializer(serializers.Serializer): - """ - DRF serializer for allocation stock items against a build order - """ + """DRF serializer for allocation stock items against a build order.""" items = BuildAllocationItemSerializer(many=True) class Meta: + """Serializer metaclass""" fields = [ 'items', ] def validate(self, data): - """ - Validation - """ - + """Validation.""" data = super().validate(data) items = data.get('items', []) @@ -741,7 +710,7 @@ class BuildAllocationSerializer(serializers.Serializer): return data def save(self): - + """Perform the allocation""" data = self.validated_data items = data.get('items', []) @@ -770,11 +739,10 @@ class BuildAllocationSerializer(serializers.Serializer): class BuildAutoAllocationSerializer(serializers.Serializer): - """ - DRF serializer for auto allocating stock items against a build order - """ + """DRF serializer for auto allocating stock items against a build order.""" class Meta: + """Serializer metaclass""" fields = [ 'location', 'exclude_location', @@ -813,7 +781,7 @@ class BuildAutoAllocationSerializer(serializers.Serializer): ) def save(self): - + """Perform the auto-allocation step""" data = self.validated_data build = self.context['build'] @@ -827,7 +795,7 @@ class BuildAutoAllocationSerializer(serializers.Serializer): class BuildItemSerializer(InvenTreeModelSerializer): - """ Serializes a BuildItem object """ + """Serializes a BuildItem object.""" bom_part = serializers.IntegerField(source='bom_item.sub_part.pk', read_only=True) part = serializers.IntegerField(source='stock_item.part.pk', read_only=True) @@ -842,7 +810,7 @@ class BuildItemSerializer(InvenTreeModelSerializer): quantity = InvenTreeDecimalField() def __init__(self, *args, **kwargs): - + """Determine which extra details fields should be included""" build_detail = kwargs.pop('build_detail', False) part_detail = kwargs.pop('part_detail', False) location_detail = kwargs.pop('location_detail', False) @@ -859,6 +827,7 @@ class BuildItemSerializer(InvenTreeModelSerializer): self.fields.pop('location_detail') class Meta: + """Serializer metaclass""" model = BuildItem fields = [ 'pk', @@ -877,11 +846,10 @@ class BuildItemSerializer(InvenTreeModelSerializer): class BuildAttachmentSerializer(InvenTreeAttachmentSerializer): - """ - Serializer for a BuildAttachment - """ + """Serializer for a BuildAttachment.""" class Meta: + """Serializer metaclass""" model = BuildOrderAttachment fields = [ diff --git a/InvenTree/build/tasks.py b/InvenTree/build/tasks.py index daf231d951..ef32fc73ee 100644 --- a/InvenTree/build/tasks.py +++ b/InvenTree/build/tasks.py @@ -1,3 +1,5 @@ +"""Background task definitions for the BuildOrder app""" + from decimal import Decimal import logging @@ -18,11 +20,10 @@ logger = logging.getLogger('inventree') def check_build_stock(build: build.models.Build): - """ - Check the required stock for a newly created build order, - and send an email out to any subscribed users if stock is low. - """ + """Check the required stock for a newly created build order. + Send an email out to any subscribed users if stock is low. + """ # Do not notify if we are importing data if isImportingData(): return diff --git a/InvenTree/build/test_api.py b/InvenTree/build/test_api.py index 1adf38935e..bd969bcea9 100644 --- a/InvenTree/build/test_api.py +++ b/InvenTree/build/test_api.py @@ -1,3 +1,5 @@ +"""Unit tests for the BuildOrder API""" + from datetime import datetime, timedelta from django.urls import reverse @@ -13,8 +15,8 @@ from InvenTree.api_tester import InvenTreeAPITestCase class TestBuildAPI(InvenTreeAPITestCase): - """ - Series of tests for the Build DRF API + """Series of tests for the Build DRF API. + - Tests for Build API - Tests for BuildItem API """ @@ -33,10 +35,7 @@ class TestBuildAPI(InvenTreeAPITestCase): ] def test_get_build_list(self): - """ - Test that we can retrieve list of build objects - """ - + """Test that we can retrieve list of build objects.""" url = reverse('api-build-list') response = self.client.get(url, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -65,7 +64,7 @@ class TestBuildAPI(InvenTreeAPITestCase): self.assertEqual(len(response.data), 0) def test_get_build_item_list(self): - """ Test that we can retrieve list of BuildItem objects """ + """Test that we can retrieve list of BuildItem objects.""" url = reverse('api-build-item-list') response = self.client.get(url, format='json') @@ -77,9 +76,7 @@ class TestBuildAPI(InvenTreeAPITestCase): class BuildAPITest(InvenTreeAPITestCase): - """ - Series of tests for the Build DRF API - """ + """Series of tests for the Build DRF API.""" fixtures = [ 'category', @@ -96,18 +93,12 @@ class BuildAPITest(InvenTreeAPITestCase): 'build.add' ] - def setUp(self): - - super().setUp() - class BuildTest(BuildAPITest): - """ - Unit testing for the build complete API endpoint - """ + """Unit testing for the build complete API endpoint.""" def setUp(self): - + """Basic setup for this test suite""" super().setUp() self.build = Build.objects.get(pk=1) @@ -115,10 +106,7 @@ class BuildTest(BuildAPITest): self.url = reverse('api-build-output-complete', kwargs={'pk': self.build.pk}) def test_invalid(self): - """ - Test with invalid data - """ - + """Test with invalid data.""" # Test with an invalid build ID self.post( reverse('api-build-output-complete', kwargs={'pk': 99999}), @@ -199,10 +187,7 @@ class BuildTest(BuildAPITest): ) def test_complete(self): - """ - Test build order completion - """ - + """Test build order completion.""" # Initially, build should not be able to be completed self.assertFalse(self.build.can_complete) @@ -270,8 +255,7 @@ class BuildTest(BuildAPITest): self.assertTrue(self.build.is_complete) def test_cancel(self): - """ Test that we can cancel a BuildOrder via the API """ - + """Test that we can cancel a BuildOrder via the API.""" bo = Build.objects.get(pk=1) url = reverse('api-build-cancel', kwargs={'pk': bo.pk}) @@ -285,10 +269,7 @@ class BuildTest(BuildAPITest): self.assertEqual(bo.status, BuildStatus.CANCELLED) def test_create_delete_output(self): - """ - Test that we can create and delete build outputs via the API - """ - + """Test that we can create and delete build outputs via the API.""" bo = Build.objects.get(pk=1) n_outputs = bo.output_count @@ -494,7 +475,7 @@ class BuildTest(BuildAPITest): self.assertIn('This build output has already been completed', str(response.data)) def test_download_build_orders(self): - + """Test that we can download a list of build orders via the API""" required_cols = [ 'reference', 'status', @@ -539,19 +520,17 @@ class BuildTest(BuildAPITest): class BuildAllocationTest(BuildAPITest): - """ - Unit tests for allocation of stock items against a build order. + """Unit tests for allocation of stock items against a build order. For this test, we will be using Build ID=1; - This points to Part 100 (see fixture data in part.yaml) - This Part already has a BOM with 4 items (see fixture data in bom.yaml) - There are no BomItem objects yet created for this build - """ def setUp(self): - + """Basic operation as part of test suite setup""" super().setUp() self.assignRole('build.add') @@ -565,10 +544,7 @@ class BuildAllocationTest(BuildAPITest): self.n = BuildItem.objects.count() def test_build_data(self): - """ - Check that our assumptions about the particular BuildOrder are correct - """ - + """Check that our assumptions about the particular BuildOrder are correct.""" self.assertEqual(self.build.part.pk, 100) # There should be 4x BOM items we can use @@ -578,26 +554,17 @@ class BuildAllocationTest(BuildAPITest): self.assertEqual(self.build.allocated_stock.count(), 0) def test_get(self): - """ - A GET request to the endpoint should return an error - """ - + """A GET request to the endpoint should return an error.""" self.get(self.url, expected_code=405) def test_options(self): - """ - An OPTIONS request to the endpoint should return information about the endpoint - """ - + """An OPTIONS request to the endpoint should return information about the endpoint.""" response = self.options(self.url, expected_code=200) self.assertIn("API endpoint to allocate stock items to a build order", str(response.data)) def test_empty(self): - """ - Test without any POST data - """ - + """Test without any POST data.""" # Initially test with an empty data set data = self.post(self.url, {}, expected_code=400).data @@ -618,10 +585,7 @@ class BuildAllocationTest(BuildAPITest): self.assertEqual(self.n, BuildItem.objects.count()) def test_missing(self): - """ - Test with missing data - """ - + """Test with missing data.""" # Missing quantity data = self.post( self.url, @@ -674,10 +638,7 @@ class BuildAllocationTest(BuildAPITest): self.assertEqual(self.n, BuildItem.objects.count()) def test_invalid_bom_item(self): - """ - Test by passing an invalid BOM item - """ - + """Test by passing an invalid BOM item.""" data = self.post( self.url, { @@ -695,11 +656,10 @@ class BuildAllocationTest(BuildAPITest): self.assertIn('must point to the same part', str(data)) def test_valid_data(self): - """ - Test with valid data. + """Test with valid data. + This should result in creation of a new BuildItem object """ - self.post( self.url, { @@ -725,17 +685,12 @@ class BuildAllocationTest(BuildAPITest): class BuildListTest(BuildAPITest): - """ - Tests for the BuildOrder LIST API - """ + """Tests for the BuildOrder LIST API.""" url = reverse('api-build-list') def test_get_all_builds(self): - """ - Retrieve *all* builds via the API - """ - + """Retrieve *all* builds via the API.""" builds = self.get(self.url) self.assertEqual(len(builds.data), 5) @@ -753,10 +708,7 @@ class BuildListTest(BuildAPITest): self.assertEqual(len(builds.data), 0) def test_overdue(self): - """ - Create a new build, in the past - """ - + """Create a new build, in the past.""" in_the_past = datetime.now().date() - timedelta(days=50) part = Part.objects.get(pk=50) @@ -776,10 +728,7 @@ class BuildListTest(BuildAPITest): self.assertEqual(len(builds), 1) def test_sub_builds(self): - """ - Test the build / sub-build relationship - """ - + """Test the build / sub-build relationship.""" parent = Build.objects.get(pk=5) part = Part.objects.get(pk=50) diff --git a/InvenTree/build/test_build.py b/InvenTree/build/test_build.py index d61b2adbaa..e17103ae36 100644 --- a/InvenTree/build/test_build.py +++ b/InvenTree/build/test_build.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +"""Unit tests for the 'build' models""" from django.test import TestCase @@ -12,13 +12,10 @@ from stock.models import StockItem class BuildTestBase(TestCase): - """ - Run some tests to ensure that the Build model is working properly. - """ + """Run some tests to ensure that the Build model is working properly.""" def setUp(self): - """ - Initialize data to use for these tests. + """Initialize data to use for these tests. The base Part 'assembly' has a BOM consisting of three parts: @@ -119,11 +116,10 @@ class BuildTestBase(TestCase): class BuildTest(BuildTestBase): + """Unit testing class for the Build model""" def test_ref_int(self): - """ - Test the "integer reference" field used for natural sorting - """ + """Test the "integer reference" field used for natural sorting""" for ii in range(10): build = Build( @@ -141,7 +137,7 @@ class BuildTest(BuildTestBase): self.assertEqual(build.reference_int, ii) def test_init(self): - # Perform some basic tests before we start the ball rolling + """Perform some basic tests before we start the ball rolling""" self.assertEqual(StockItem.objects.count(), 10) @@ -166,7 +162,7 @@ class BuildTest(BuildTestBase): self.assertFalse(self.build.is_complete) def test_build_item_clean(self): - # Ensure that dodgy BuildItem objects cannot be created + """Ensure that dodgy BuildItem objects cannot be created""" stock = StockItem.objects.create(part=self.assembly, quantity=99) @@ -193,7 +189,7 @@ class BuildTest(BuildTestBase): b.save() def test_duplicate_bom_line(self): - # Try to add a duplicate BOM item - it should be allowed + """Try to add a duplicate BOM item - it should be allowed""" BomItem.objects.create( part=self.assembly, @@ -202,12 +198,11 @@ class BuildTest(BuildTestBase): ) def allocate_stock(self, output, allocations): - """ - Allocate stock to this build, against a particular output + """Allocate stock to this build, against a particular output Args: - output - StockItem object (or None) - allocations - Map of {StockItem: quantity} + output: StockItem object (or None) + allocations: Map of {StockItem: quantity} """ for item, quantity in allocations.items(): @@ -219,9 +214,7 @@ class BuildTest(BuildTestBase): ) def test_partial_allocation(self): - """ - Test partial allocation of stock - """ + """Test partial allocation of stock""" # Fully allocate tracked stock against build output 1 self.allocate_stock( @@ -294,9 +287,7 @@ class BuildTest(BuildTestBase): self.assertTrue(self.build.are_untracked_parts_allocated()) def test_cancel(self): - """ - Test cancellation of the build - """ + """Test cancellation of the build""" # TODO @@ -309,9 +300,7 @@ class BuildTest(BuildTestBase): pass def test_complete(self): - """ - Test completion of a build output - """ + """Test completion of a build output""" self.stock_1_1.quantity = 1000 self.stock_1_1.save() @@ -385,12 +374,10 @@ class BuildTest(BuildTestBase): class AutoAllocationTests(BuildTestBase): - """ - Tests for auto allocating stock against a build order - """ + """Tests for auto allocating stock against a build order""" def setUp(self): - + """Init routines for this unit test class""" super().setUp() # Add a "substitute" part for bom_item_2 @@ -411,8 +398,7 @@ class AutoAllocationTests(BuildTestBase): ) def test_auto_allocate(self): - """ - Run the 'auto-allocate' function. What do we expect to happen? + """Run the 'auto-allocate' function. What do we expect to happen? There are two "untracked" parts: - sub_part_1 (quantity 5 per BOM = 50 required total) / 103 in stock (2 items) @@ -474,9 +460,7 @@ class AutoAllocationTests(BuildTestBase): self.assertTrue(self.build.is_bom_item_allocated(self.bom_item_2)) def test_fully_auto(self): - """ - We should be able to auto-allocate against a build in a single go - """ + """We should be able to auto-allocate against a build in a single go""" self.build.auto_allocate_stock( interchangeable=True, diff --git a/InvenTree/build/test_migrations.py b/InvenTree/build/test_migrations.py index 1e95cfb54e..0a6ea3dd8b 100644 --- a/InvenTree/build/test_migrations.py +++ b/InvenTree/build/test_migrations.py @@ -1,6 +1,4 @@ -""" -Tests for the build model database migrations -""" +"""Tests for the build model database migrations.""" from django_test_migrations.contrib.unittest_case import MigratorTestCase @@ -8,18 +6,13 @@ from InvenTree import helpers class TestForwardMigrations(MigratorTestCase): - """ - Test entire schema migration sequence for the build app - """ + """Test entire schema migration sequence for the build app.""" migrate_from = ('build', helpers.getOldestMigrationFile('build')) migrate_to = ('build', helpers.getNewestMigrationFile('build')) def prepare(self): - """ - Create initial data! - """ - + """Create initial data!""" Part = self.old_state.apps.get_model('part', 'part') buildable_part = Part.objects.create( @@ -45,7 +38,7 @@ class TestForwardMigrations(MigratorTestCase): ) def test_items_exist(self): - + """Test to ensure that the 'assembly' field is correctly configured""" Part = self.new_state.apps.get_model('part', 'part') self.assertEqual(Part.objects.count(), 1) @@ -63,18 +56,13 @@ class TestForwardMigrations(MigratorTestCase): class TestReferenceMigration(MigratorTestCase): - """ - Test custom migration which adds 'reference' field to Build model - """ + """Test custom migration which adds 'reference' field to Build model.""" migrate_from = ('build', helpers.getOldestMigrationFile('build')) migrate_to = ('build', '0018_build_reference') def prepare(self): - """ - Create some builds - """ - + """Create some builds.""" Part = self.old_state.apps.get_model('part', 'part') part = Part.objects.create( @@ -108,7 +96,7 @@ class TestReferenceMigration(MigratorTestCase): print(build.reference) def test_build_reference(self): - + """Test that the build reference is correctly assigned to the PK of the Build""" Build = self.new_state.apps.get_model('build', 'build') self.assertEqual(Build.objects.count(), 3) diff --git a/InvenTree/build/tests.py b/InvenTree/build/tests.py index 8bd7553a52..e09e59d958 100644 --- a/InvenTree/build/tests.py +++ b/InvenTree/build/tests.py @@ -1,3 +1,5 @@ +"""Basic unit tests for the BuildOrder app""" + from django.urls import reverse from datetime import datetime, timedelta @@ -11,6 +13,7 @@ from InvenTree.status_codes import BuildStatus class BuildTestSimple(InvenTreeTestCase): + """Basic set of tests for the BuildOrder model functionality""" fixtures = [ 'category', @@ -26,7 +29,7 @@ class BuildTestSimple(InvenTreeTestCase): ] def test_build_objects(self): - # Ensure the Build objects were correctly created + """Ensure the Build objects were correctly created""" self.assertEqual(Build.objects.count(), 5) b = Build.objects.get(pk=2) self.assertEqual(b.batch, 'B2') @@ -35,10 +38,12 @@ class BuildTestSimple(InvenTreeTestCase): self.assertEqual(str(b), 'BO0002') def test_url(self): + """Test URL lookup""" b1 = Build.objects.get(pk=1) self.assertEqual(b1.get_absolute_url(), '/build/1/') def test_is_complete(self): + """Test build completion status""" b1 = Build.objects.get(pk=1) b2 = Build.objects.get(pk=2) @@ -48,10 +53,7 @@ class BuildTestSimple(InvenTreeTestCase): self.assertEqual(b2.status, BuildStatus.COMPLETE) def test_overdue(self): - """ - Test overdue status functionality - """ - + """Test overdue status functionality.""" today = datetime.now().date() build = Build.objects.get(pk=1) @@ -66,6 +68,7 @@ class BuildTestSimple(InvenTreeTestCase): self.assertFalse(build.is_overdue) def test_is_active(self): + """Test active / inactive build status""" b1 = Build.objects.get(pk=1) b2 = Build.objects.get(pk=2) @@ -73,12 +76,12 @@ class BuildTestSimple(InvenTreeTestCase): self.assertEqual(b2.is_active, False) def test_required_parts(self): - # TODO - Generate BOM for test part - pass + """Test set of required BOM items for the build""" + # TODO: Generate BOM for test part + ... def test_cancel_build(self): - """ Test build cancellation function """ - + """Test build cancellation function.""" build = Build.objects.get(id=1) self.assertEqual(build.status, BuildStatus.PENDING) @@ -89,7 +92,7 @@ class BuildTestSimple(InvenTreeTestCase): class TestBuildViews(InvenTreeTestCase): - """ Tests for Build app views """ + """Tests for Build app views.""" fixtures = [ 'category', @@ -105,6 +108,7 @@ class TestBuildViews(InvenTreeTestCase): ] def setUp(self): + """Fixturing for this suite of unit tests""" super().setUp() # Create a build output for build # 1 @@ -118,14 +122,12 @@ class TestBuildViews(InvenTreeTestCase): ) def test_build_index(self): - """ test build index view """ - + """Test build index view.""" response = self.client.get(reverse('build-index')) self.assertEqual(response.status_code, 200) def test_build_detail(self): - """ Test the detail view for a Build object """ - + """Test the detail view for a Build object.""" pk = 1 response = self.client.get(reverse('build-detail', args=(pk,))) diff --git a/InvenTree/build/urls.py b/InvenTree/build/urls.py index 0788a1de37..b524df5627 100644 --- a/InvenTree/build/urls.py +++ b/InvenTree/build/urls.py @@ -1,6 +1,4 @@ -""" -URL lookup for Build app -""" +"""URL lookup for Build app.""" from django.urls import include, re_path diff --git a/InvenTree/build/views.py b/InvenTree/build/views.py index 57d8f8bd37..9d01ddc3d6 100644 --- a/InvenTree/build/views.py +++ b/InvenTree/build/views.py @@ -1,6 +1,4 @@ -""" -Django views for interacting with Build objects -""" +"""Django views for interacting with Build objects.""" from django.utils.translation import gettext_lazy as _ from django.views.generic import DetailView, ListView @@ -15,42 +13,25 @@ from plugin.views import InvenTreePluginViewMixin class BuildIndex(InvenTreeRoleMixin, ListView): - """ - View for displaying list of Builds - """ + """View for displaying list of Builds.""" model = Build template_name = 'build/index.html' context_object_name = 'builds' def get_queryset(self): - """ Return all Build objects (order by date, newest first) """ + """Return all Build objects (order by date, newest first)""" return Build.objects.order_by('status', '-completion_date') - def get_context_data(self, **kwargs): - - context = super().get_context_data(**kwargs) - - context['BuildStatus'] = BuildStatus - - context['active'] = self.get_queryset().filter(status__in=BuildStatus.ACTIVE_CODES) - - context['completed'] = self.get_queryset().filter(status=BuildStatus.COMPLETE) - context['cancelled'] = self.get_queryset().filter(status=BuildStatus.CANCELLED) - - return context - class BuildDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): - """ - Detail view of a single Build object. - """ + """Detail view of a single Build object.""" model = Build template_name = 'build/detail.html' context_object_name = 'build' def get_context_data(self, **kwargs): - + """Return extra context information for the BuildDetail view""" ctx = super().get_context_data(**kwargs) build = self.get_object() @@ -71,9 +52,7 @@ class BuildDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): class BuildDelete(AjaxDeleteView): - """ - View to delete a build - """ + """View to delete a build.""" model = Build ajax_template_name = 'build/delete_build.html' diff --git a/InvenTree/common/admin.py b/InvenTree/common/admin.py index f9ae7b557f..021704615b 100644 --- a/InvenTree/common/admin.py +++ b/InvenTree/common/admin.py @@ -1,3 +1,5 @@ +"""Admin for the common app.""" + from django.contrib import admin from import_export.admin import ImportExportModelAdmin @@ -6,14 +8,12 @@ import common.models class SettingsAdmin(ImportExportModelAdmin): + """Admin settings for InvenTreeSetting.""" list_display = ('key', 'value') def get_readonly_fields(self, request, obj=None): # pragma: no cover - """ - Prevent the 'key' field being edited once the setting is created - """ - + """Prevent the 'key' field being edited once the setting is created.""" if obj: return ['key'] else: @@ -21,14 +21,12 @@ class SettingsAdmin(ImportExportModelAdmin): class UserSettingsAdmin(ImportExportModelAdmin): + """Admin settings for InvenTreeUserSetting.""" list_display = ('key', 'value', 'user', ) def get_readonly_fields(self, request, obj=None): # pragma: no cover - """ - Prevent the 'key' field being edited once the setting is created - """ - + """Prevent the 'key' field being edited once the setting is created.""" if obj: return ['key'] else: @@ -36,16 +34,19 @@ class UserSettingsAdmin(ImportExportModelAdmin): class WebhookAdmin(ImportExportModelAdmin): + """Admin settings for Webhook.""" list_display = ('endpoint_id', 'name', 'active', 'user') class NotificationEntryAdmin(admin.ModelAdmin): + """Admin settings for NotificationEntry.""" list_display = ('key', 'uid', 'updated', ) class NotificationMessageAdmin(admin.ModelAdmin): + """Admin settings for NotificationMessage.""" list_display = ('age_human', 'user', 'category', 'name', 'read', 'target_object', 'source_object', ) diff --git a/InvenTree/common/api.py b/InvenTree/common/api.py index d805626da7..0323778129 100644 --- a/InvenTree/common/api.py +++ b/InvenTree/common/api.py @@ -1,6 +1,4 @@ -""" -Provides a JSON API for common components. -""" +"""Provides a JSON API for common components.""" import json @@ -24,25 +22,23 @@ from plugin.serializers import NotificationUserSettingSerializer class CsrfExemptMixin(object): - """ - Exempts the view from CSRF requirements. - """ + """Exempts the view from CSRF requirements.""" @method_decorator(csrf_exempt) def dispatch(self, *args, **kwargs): - return super(CsrfExemptMixin, self).dispatch(*args, **kwargs) + """Overwrites dispatch to be extempt from csrf checks.""" + return super().dispatch(*args, **kwargs) class WebhookView(CsrfExemptMixin, APIView): - """ - Endpoint for receiving webhooks. - """ + """Endpoint for receiving webhooks.""" authentication_classes = [] permission_classes = [] model_class = common.models.WebhookEndpoint run_async = False def post(self, request, endpoint, *args, **kwargs): + """Process incomming webhook.""" # get webhook definition self._get_webhook(endpoint, request, *args, **kwargs) @@ -101,6 +97,10 @@ class WebhookView(CsrfExemptMixin, APIView): class SettingsList(generics.ListAPIView): + """Generic ListView for settings. + + This is inheritted by all list views for settings. + """ filter_backends = [ DjangoFilterBackend, @@ -120,24 +120,17 @@ class SettingsList(generics.ListAPIView): class GlobalSettingsList(SettingsList): - """ - API endpoint for accessing a list of global settings objects - """ + """API endpoint for accessing a list of global settings objects.""" queryset = common.models.InvenTreeSetting.objects.all() serializer_class = common.serializers.GlobalSettingsSerializer class GlobalSettingsPermissions(permissions.BasePermission): - """ - Special permission class to determine if the user is "staff" - """ + """Special permission class to determine if the user is "staff".""" def has_permission(self, request, view): - """ - Check that the requesting user is 'admin' - """ - + """Check that the requesting user is 'admin'.""" try: user = request.user @@ -152,8 +145,7 @@ class GlobalSettingsPermissions(permissions.BasePermission): class GlobalSettingsDetail(generics.RetrieveUpdateAPIView): - """ - Detail view for an individual "global setting" object. + """Detail view for an individual "global setting" object. - User must have 'staff' status to view / edit """ @@ -163,10 +155,7 @@ class GlobalSettingsDetail(generics.RetrieveUpdateAPIView): serializer_class = common.serializers.GlobalSettingsSerializer def get_object(self): - """ - Attempt to find a global setting object with the provided key. - """ - + """Attempt to find a global setting object with the provided key.""" key = self.kwargs['key'] if key not in common.models.InvenTreeSetting.SETTINGS.keys(): @@ -181,18 +170,13 @@ class GlobalSettingsDetail(generics.RetrieveUpdateAPIView): class UserSettingsList(SettingsList): - """ - API endpoint for accessing a list of user settings objects - """ + """API endpoint for accessing a list of user settings objects.""" queryset = common.models.InvenTreeUserSetting.objects.all() serializer_class = common.serializers.UserSettingsSerializer def filter_queryset(self, queryset): - """ - Only list settings which apply to the current user - """ - + """Only list settings which apply to the current user.""" try: user = self.request.user except AttributeError: # pragma: no cover @@ -206,12 +190,10 @@ class UserSettingsList(SettingsList): class UserSettingsPermissions(permissions.BasePermission): - """ - Special permission class to determine if the user can view / edit a particular setting - """ + """Special permission class to determine if the user can view / edit a particular setting.""" def has_object_permission(self, request, view, obj): - + """Check if the user that requested is also the object owner.""" try: user = request.user except AttributeError: # pragma: no cover @@ -221,8 +203,7 @@ class UserSettingsPermissions(permissions.BasePermission): class UserSettingsDetail(generics.RetrieveUpdateAPIView): - """ - Detail view for an individual "user setting" object + """Detail view for an individual "user setting" object. - User can only view / edit settings their own settings objects """ @@ -232,10 +213,7 @@ class UserSettingsDetail(generics.RetrieveUpdateAPIView): serializer_class = common.serializers.UserSettingsSerializer def get_object(self): - """ - Attempt to find a user setting object with the provided key. - """ - + """Attempt to find a user setting object with the provided key.""" key = self.kwargs['key'] if key not in common.models.InvenTreeUserSetting.SETTINGS.keys(): @@ -249,18 +227,13 @@ class UserSettingsDetail(generics.RetrieveUpdateAPIView): class NotificationUserSettingsList(SettingsList): - """ - API endpoint for accessing a list of notification user settings objects - """ + """API endpoint for accessing a list of notification user settings objects.""" queryset = NotificationUserSetting.objects.all() serializer_class = NotificationUserSettingSerializer def filter_queryset(self, queryset): - """ - Only list settings which apply to the current user - """ - + """Only list settings which apply to the current user.""" try: user = self.request.user except AttributeError: @@ -272,8 +245,7 @@ class NotificationUserSettingsList(SettingsList): class NotificationUserSettingsDetail(generics.RetrieveUpdateAPIView): - """ - Detail view for an individual "notification user setting" object + """Detail view for an individual "notification user setting" object. - User can only view / edit settings their own settings objects """ @@ -287,6 +259,8 @@ class NotificationUserSettingsDetail(generics.RetrieveUpdateAPIView): class NotificationList(generics.ListAPIView): + """List view for all notifications of the current user.""" + queryset = common.models.NotificationMessage.objects.all() serializer_class = common.serializers.NotificationMessageSerializer @@ -313,10 +287,7 @@ class NotificationList(generics.ListAPIView): ] def filter_queryset(self, queryset): - """ - Only list notifications which apply to the current user - """ - + """Only list notifications which apply to the current user.""" try: user = self.request.user except AttributeError: @@ -328,8 +299,7 @@ class NotificationList(generics.ListAPIView): class NotificationDetail(generics.RetrieveUpdateDestroyAPIView): - """ - Detail view for an individual notification object + """Detail view for an individual notification object. - User can only view / delete their own notification objects """ @@ -342,9 +312,7 @@ class NotificationDetail(generics.RetrieveUpdateDestroyAPIView): class NotificationReadEdit(generics.CreateAPIView): - """ - general API endpoint to manipulate read state of a notification - """ + """General API endpoint to manipulate read state of a notification.""" queryset = common.models.NotificationMessage.objects.all() serializer_class = common.serializers.NotificationReadSerializer @@ -354,12 +322,14 @@ class NotificationReadEdit(generics.CreateAPIView): ] def get_serializer_context(self): + """Add instance to context so it can be accessed in the serializer.""" context = super().get_serializer_context() if self.request: context['instance'] = self.get_object() return context def perform_create(self, serializer): + """Set the `read` status to the target value.""" message = self.get_object() try: message.read = self.target @@ -369,23 +339,17 @@ class NotificationReadEdit(generics.CreateAPIView): class NotificationRead(NotificationReadEdit): - """ - API endpoint to mark a notification as read. - """ + """API endpoint to mark a notification as read.""" target = True class NotificationUnread(NotificationReadEdit): - """ - API endpoint to mark a notification as unread. - """ + """API endpoint to mark a notification as unread.""" target = False class NotificationReadAll(generics.RetrieveAPIView): - """ - API endpoint to mark all notifications as read. - """ + """API endpoint to mark all notifications as read.""" queryset = common.models.NotificationMessage.objects.all() @@ -394,6 +358,7 @@ class NotificationReadAll(generics.RetrieveAPIView): ] def get(self, request, *args, **kwargs): + """Set all messages for the current user as read.""" try: self.queryset.filter(user=request.user, read=False).update(read=True) return Response({'status': 'ok'}) diff --git a/InvenTree/common/apps.py b/InvenTree/common/apps.py index 629b5179e6..664fc430e2 100644 --- a/InvenTree/common/apps.py +++ b/InvenTree/common/apps.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +"""App config for common app.""" import logging @@ -8,17 +8,19 @@ logger = logging.getLogger('inventree') class CommonConfig(AppConfig): + """AppConfig for common app. + + Clears system wide flags on ready. + """ + name = 'common' def ready(self): - + """Initialize restart flag clearance on startup.""" self.clear_restart_flag() def clear_restart_flag(self): - """ - Clear the SERVER_RESTART_REQUIRED setting - """ - + """Clear the SERVER_RESTART_REQUIRED setting.""" try: import common.models diff --git a/InvenTree/common/files.py b/InvenTree/common/files.py index 704d02a681..5a4c5a975c 100644 --- a/InvenTree/common/files.py +++ b/InvenTree/common/files.py @@ -1,6 +1,4 @@ -""" -Files management tools. -""" +"""Files management tools.""" import os @@ -12,7 +10,7 @@ from rapidfuzz import fuzz class FileManager: - """ Class for managing an uploaded file """ + """Class for managing an uploaded file.""" name = '' @@ -32,8 +30,7 @@ class FileManager: HEADERS = [] def __init__(self, file, name=None): - """ Initialize the FileManager class with a user-uploaded file object """ - + """Initialize the FileManager class with a user-uploaded file object.""" # Set name if name: self.name = name @@ -46,8 +43,7 @@ class FileManager: @classmethod def validate(cls, file): - """ Validate file extension and data """ - + """Validate file extension and data.""" cleaned_data = None ext = os.path.splitext(file.name)[-1].lower().replace('.', '') @@ -79,21 +75,15 @@ class FileManager: return cleaned_data def process(self, file): - """ Process file """ - + """Process file.""" self.data = self.__class__.validate(file) def update_headers(self): - """ Update headers """ - + """Update headers.""" self.HEADERS = self.REQUIRED_HEADERS + self.ITEM_MATCH_HEADERS + self.OPTIONAL_MATCH_HEADERS + self.OPTIONAL_HEADERS def setup(self): - """ - Setup headers - should be overriden in usage to set the Different Headers - """ - + """Setup headers should be overriden in usage to set the Different Headers.""" if not self.name: return @@ -101,14 +91,15 @@ class FileManager: self.update_headers() def guess_header(self, header, threshold=80): - """ - Try to match a header (from the file) to a list of known headers + """Try to match a header (from the file) to a list of known headers. Args: - header - Header name to look for - threshold - Match threshold for fuzzy search - """ + header (Any): Header name to look for + threshold (int, optional): Match threshold for fuzzy search. Defaults to 80. + Returns: + Any: Matched headers + """ # Replace null values with empty string if header is None: header = '' @@ -143,7 +134,7 @@ class FileManager: return None def columns(self): - """ Return a list of headers for the thingy """ + """Return a list of headers for the thingy.""" headers = [] for header in self.data.headers: @@ -170,21 +161,21 @@ class FileManager: return headers def col_count(self): + """Return the number of columns in the file.""" if self.data is None: return 0 return len(self.data.headers) def row_count(self): - """ Return the number of rows in the file. """ - + """Return the number of rows in the file.""" if self.data is None: return 0 return len(self.data) def rows(self): - """ Return a list of all rows """ + """Return a list of all rows.""" rows = [] for i in range(self.row_count()): @@ -221,15 +212,14 @@ class FileManager: return rows def get_row_data(self, index): - """ Retrieve row data at a particular index """ + """Retrieve row data at a particular index.""" if self.data is None or index >= len(self.data): return None return self.data[index] def get_row_dict(self, index): - """ Retrieve a dict object representing the data row at a particular offset """ - + """Retrieve a dict object representing the data row at a particular offset.""" if self.data is None or index >= len(self.data): return None diff --git a/InvenTree/common/forms.py b/InvenTree/common/forms.py index 7fcdf8535c..224ba3a9c5 100644 --- a/InvenTree/common/forms.py +++ b/InvenTree/common/forms.py @@ -1,6 +1,4 @@ -""" -Django forms for interacting with common objects -""" +"""Django forms for interacting with common objects.""" from django import forms from django.utils.translation import gettext as _ @@ -12,11 +10,11 @@ from .models import InvenTreeSetting class SettingEditForm(HelperForm): - """ - Form for creating / editing a settings object - """ + """Form for creating / editing a settings object.""" class Meta: + """Metaclassoptions for SettingEditForm.""" + model = InvenTreeSetting fields = [ @@ -25,7 +23,7 @@ class SettingEditForm(HelperForm): class UploadFileForm(forms.Form): - """ Step 1 of FileManagementFormView """ + """Step 1 of FileManagementFormView.""" file = forms.FileField( label=_('File'), @@ -33,8 +31,7 @@ class UploadFileForm(forms.Form): ) def __init__(self, *args, **kwargs): - """ Update label and help_text """ - + """Update label and help_text.""" # Get file name name = None if 'name' in kwargs: @@ -48,11 +45,10 @@ class UploadFileForm(forms.Form): self.fields['file'].help_text = _(f'Select {name} file to upload') def clean_file(self): - """ - Run tabular file validation. - If anything is wrong with the file, it will raise ValidationError - """ + """Run tabular file validation. + If anything is wrong with the file, it will raise ValidationError + """ file = self.cleaned_data['file'] # Validate file using FileManager class - will perform initial data validation @@ -63,10 +59,10 @@ class UploadFileForm(forms.Form): class MatchFieldForm(forms.Form): - """ Step 2 of FileManagementFormView """ + """Step 2 of FileManagementFormView.""" def __init__(self, *args, **kwargs): - + """Setup filemanager and check columsn.""" # Get FileManager file_manager = None if 'file_manager' in kwargs: @@ -96,10 +92,10 @@ class MatchFieldForm(forms.Form): class MatchItemForm(forms.Form): - """ Step 3 of FileManagementFormView """ + """Step 3 of FileManagementFormView.""" def __init__(self, *args, **kwargs): - + """Setup filemanager and create fields.""" # Get FileManager file_manager = None if 'file_manager' in kwargs: @@ -194,6 +190,5 @@ class MatchItemForm(forms.Form): ) def get_special_field(self, col_guess, row, file_manager): - """ Function to be overriden in inherited forms to add specific form settings """ - + """Function to be overriden in inherited forms to add specific form settings.""" return None diff --git a/InvenTree/common/models.py b/InvenTree/common/models.py index a5bf5267eb..82f6603a66 100644 --- a/InvenTree/common/models.py +++ b/InvenTree/common/models.py @@ -1,5 +1,5 @@ -""" -Common database model definitions. +"""Common database model definitions. + These models are 'generic' and do not fit a particular business logic object. """ @@ -42,9 +42,10 @@ logger = logging.getLogger('inventree') class EmptyURLValidator(URLValidator): + """Validator for filed with url - that can be empty.""" def __call__(self, value): - + """Make sure empty values pass.""" value = str(value).strip() if len(value) == 0: @@ -55,21 +56,17 @@ class EmptyURLValidator(URLValidator): class BaseInvenTreeSetting(models.Model): - """ - An base InvenTreeSetting object is a key:value pair used for storing - single values (e.g. one-off settings values). - """ + """An base InvenTreeSetting object is a key:value pair used for storing single values (e.g. one-off settings values).""" SETTINGS = {} class Meta: + """Meta options for BaseInvenTreeSetting -> abstract stops creation of database entry.""" + abstract = True def save(self, *args, **kwargs): - """ - Enforce validation and clean before saving - """ - + """Enforce validation and clean before saving.""" self.key = str(self.key).upper() self.clean(**kwargs) @@ -79,14 +76,12 @@ class BaseInvenTreeSetting(models.Model): @classmethod def allValues(cls, user=None, exclude_hidden=False): - """ - Return a dict of "all" defined global settings. + """Return a dict of "all" defined global settings. This performs a single database lookup, and then any settings which are not *in* the database are assigned their default values """ - results = cls.objects.all() # Optionally filter by user @@ -131,28 +126,23 @@ class BaseInvenTreeSetting(models.Model): return settings def get_kwargs(self): - """ - Construct kwargs for doing class-based settings lookup, - depending on *which* class we are. + """Construct kwargs for doing class-based settings lookup, depending on *which* class we are. This is necessary to abtract the settings object from the implementing class (e.g plugins) Subclasses should override this function to ensure the kwargs are correctly set. """ - return {} @classmethod def get_setting_definition(cls, key, **kwargs): - """ - Return the 'definition' of a particular settings value, as a dict object. + """Return the 'definition' of a particular settings value, as a dict object. - The 'settings' dict can be passed as a kwarg - If not passed, look for cls.SETTINGS - Returns an empty dict if the key is not found """ - settings = kwargs.get('settings', cls.SETTINGS) key = str(key).strip().upper() @@ -164,69 +154,56 @@ class BaseInvenTreeSetting(models.Model): @classmethod def get_setting_name(cls, key, **kwargs): - """ - Return the name of a particular setting. + """Return the name of a particular setting. If it does not exist, return an empty string. """ - setting = cls.get_setting_definition(key, **kwargs) return setting.get('name', '') @classmethod def get_setting_description(cls, key, **kwargs): - """ - Return the description for a particular setting. + """Return the description for a particular setting. If it does not exist, return an empty string. """ - setting = cls.get_setting_definition(key, **kwargs) return setting.get('description', '') @classmethod def get_setting_units(cls, key, **kwargs): - """ - Return the units for a particular setting. + """Return the units for a particular setting. If it does not exist, return an empty string. """ - setting = cls.get_setting_definition(key, **kwargs) return setting.get('units', '') @classmethod def get_setting_validator(cls, key, **kwargs): - """ - Return the validator for a particular setting. + """Return the validator for a particular setting. If it does not exist, return None """ - setting = cls.get_setting_definition(key, **kwargs) return setting.get('validator', None) @classmethod def get_setting_default(cls, key, **kwargs): - """ - Return the default value for a particular setting. + """Return the default value for a particular setting. If it does not exist, return an empty string """ - setting = cls.get_setting_definition(key, **kwargs) return setting.get('default', '') @classmethod def get_setting_choices(cls, key, **kwargs): - """ - Return the validator choices available for a particular setting. - """ - + """Return the validator choices available for a particular setting.""" setting = cls.get_setting_definition(key, **kwargs) choices = setting.get('choices', None) @@ -239,13 +216,11 @@ class BaseInvenTreeSetting(models.Model): @classmethod def get_setting_object(cls, key, **kwargs): - """ - Return an InvenTreeSetting object matching the given key. + """Return an InvenTreeSetting object matching the given key. - Key is case-insensitive - Returns None if no match is made """ - key = str(key).strip().upper() settings = cls.objects.all() @@ -311,11 +286,10 @@ class BaseInvenTreeSetting(models.Model): @classmethod def get_setting(cls, key, backup_value=None, **kwargs): - """ - Get the value of a particular setting. + """Get the value of a particular setting. + If it does not exist, return the backup value (default = None) """ - # If no backup value is specified, atttempt to retrieve a "default" value if backup_value is None: backup_value = cls.get_setting_default(key, **kwargs) @@ -343,9 +317,7 @@ class BaseInvenTreeSetting(models.Model): @classmethod def set_setting(cls, key, value, change_user, create=True, **kwargs): - """ - Set the value of a particular setting. - If it does not exist, option to create it. + """Set the value of a particular setting. If it does not exist, option to create it. Args: key: settings key @@ -353,7 +325,6 @@ class BaseInvenTreeSetting(models.Model): change_user: User object (must be staff member to update a core setting) create: If True, create a new setting if the specified key does not exist. """ - if change_user is not None and not change_user.is_staff: return @@ -397,26 +368,26 @@ class BaseInvenTreeSetting(models.Model): @property def name(self): + """Return name for setting.""" return self.__class__.get_setting_name(self.key, **self.get_kwargs()) @property def default_value(self): + """Return default_value for setting.""" return self.__class__.get_setting_default(self.key, **self.get_kwargs()) @property def description(self): + """Return description for setting.""" return self.__class__.get_setting_description(self.key, **self.get_kwargs()) @property def units(self): + """Return units for setting.""" return self.__class__.get_setting_units(self.key, **self.get_kwargs()) def clean(self, **kwargs): - """ - If a validator (or multiple validators) are defined for a particular setting key, - run them against the 'value' field. - """ - + """If a validator (or multiple validators) are defined for a particular setting key, run them against the 'value' field.""" super().clean() # Encode as native values @@ -437,10 +408,7 @@ class BaseInvenTreeSetting(models.Model): raise ValidationError(_("Chosen value is not a valid option")) def run_validator(self, validator): - """ - Run a validator against the 'value' field for this InvenTreeSetting object. - """ - + """Run a validator against the 'value' field for this InvenTreeSetting object.""" if validator is None: return @@ -485,15 +453,11 @@ class BaseInvenTreeSetting(models.Model): validator(value) def validate_unique(self, exclude=None, **kwargs): - """ - Ensure that the key:value pair is unique. - In addition to the base validators, this ensures that the 'key' - is unique, using a case-insensitive comparison. + """Ensure that the key:value pair is unique. In addition to the base validators, this ensures that the 'key' is unique, using a case-insensitive comparison. Note that sub-classes (UserSetting, PluginSetting) use other filters to determine if the setting is 'unique' or not """ - super().validate_unique(exclude) filters = { @@ -520,17 +484,11 @@ class BaseInvenTreeSetting(models.Model): pass def choices(self): - """ - Return the available choices for this setting (or None if no choices are defined) - """ - + """Return the available choices for this setting (or None if no choices are defined).""" return self.__class__.get_setting_choices(self.key, **self.get_kwargs()) def valid_options(self): - """ - Return a list of valid options for this setting - """ - + """Return a list of valid options for this setting.""" choices = self.choices() if not choices: @@ -539,21 +497,17 @@ class BaseInvenTreeSetting(models.Model): return [opt[0] for opt in choices] def is_choice(self): - """ - Check if this setting is a "choice" field - """ - + """Check if this setting is a "choice" field.""" return self.__class__.get_setting_choices(self.key, **self.get_kwargs()) is not None def as_choice(self): - """ - Render this setting as the "display" value of a choice field, - e.g. if the choices are: + """Render this setting as the "display" value of a choice field. + + E.g. if the choices are: [('A4', 'A4 paper'), ('A3', 'A3 paper')], and the value is 'A4', then display 'A4 paper' """ - choices = self.get_setting_choices(self.key, **self.get_kwargs()) if not choices: @@ -566,30 +520,23 @@ class BaseInvenTreeSetting(models.Model): return self.value def is_model(self): - """ - Check if this setting references a model instance in the database - """ - + """Check if this setting references a model instance in the database.""" return self.model_name() is not None def model_name(self): - """ - Return the model name associated with this setting - """ - + """Return the model name associated with this setting.""" setting = self.get_setting_definition(self.key, **self.get_kwargs()) return setting.get('model', None) def model_class(self): - """ - Return the model class associated with this setting, if (and only if): + """Return the model class associated with this setting. + If (and only if): - It has a defined 'model' parameter - The 'model' parameter is of the form app.model - The 'model' parameter has matches a known app model """ - model_name = self.model_name() if not model_name: @@ -617,11 +564,7 @@ class BaseInvenTreeSetting(models.Model): return model def api_url(self): - """ - Return the API url associated with the linked model, - if provided, and valid! - """ - + """Return the API url associated with the linked model, if provided, and valid!""" model_class = self.model_class() if model_class: @@ -634,28 +577,20 @@ class BaseInvenTreeSetting(models.Model): return None def is_bool(self): - """ - Check if this setting is required to be a boolean value - """ - + """Check if this setting is required to be a boolean value.""" validator = self.__class__.get_setting_validator(self.key, **self.get_kwargs()) return self.__class__.validator_is_bool(validator) def as_bool(self): - """ - Return the value of this setting converted to a boolean value. + """Return the value of this setting converted to a boolean value. Warning: Only use on values where is_bool evaluates to true! """ - return InvenTree.helpers.str2bool(self.value) def setting_type(self): - """ - Return the field type identifier for this setting object - """ - + """Return the field type identifier for this setting object.""" if self.is_bool(): return 'boolean' @@ -670,7 +605,7 @@ class BaseInvenTreeSetting(models.Model): @classmethod def validator_is_bool(cls, validator): - + """Return if validator is for bool.""" if validator == bool: return True @@ -682,17 +617,14 @@ class BaseInvenTreeSetting(models.Model): return False def is_int(self,): - """ - Check if the setting is required to be an integer value: - """ - + """Check if the setting is required to be an integer value.""" validator = self.__class__.get_setting_validator(self.key, **self.get_kwargs()) return self.__class__.validator_is_int(validator) @classmethod def validator_is_int(cls, validator): - + """Return if validator is for int.""" if validator == int: return True @@ -704,12 +636,10 @@ class BaseInvenTreeSetting(models.Model): return False def as_int(self): - """ - Return the value of this setting converted to a boolean value. + """Return the value of this setting converted to a boolean value. If an error occurs, return the default value """ - try: value = int(self.value) except (ValueError, TypeError): @@ -719,41 +649,34 @@ class BaseInvenTreeSetting(models.Model): @classmethod def is_protected(cls, key, **kwargs): - """ - Check if the setting value is protected - """ - + """Check if the setting value is protected.""" setting = cls.get_setting_definition(key, **kwargs) return setting.get('protected', False) @property def protected(self): + """Returns if setting is protected from rendering.""" return self.__class__.is_protected(self.key, **self.get_kwargs()) def settings_group_options(): - """ - Build up group tuple for settings based on your choices - """ + """Build up group tuple for settings based on your choices.""" return [('', _('No group')), *[(str(a.id), str(a)) for a in Group.objects.all()]] class InvenTreeSetting(BaseInvenTreeSetting): - """ - An InvenTreeSetting object is a key:value pair used for storing - single values (e.g. one-off settings values). + """An InvenTreeSetting object is a key:value pair used for storing single values (e.g. one-off settings values). The class provides a way of retrieving the value for a particular key, even if that key does not exist. """ def save(self, *args, **kwargs): - """ - When saving a global setting, check to see if it requires a server restart. + """When saving a global setting, check to see if it requires a server restart. + If so, set the "SERVER_RESTART_REQUIRED" setting to True """ - super().save() if self.requires_restart(): @@ -1235,6 +1158,8 @@ class InvenTreeSetting(BaseInvenTreeSetting): } class Meta: + """Meta options for InvenTreeSetting.""" + verbose_name = "InvenTree Setting" verbose_name_plural = "InvenTree Settings" @@ -1246,18 +1171,11 @@ class InvenTreeSetting(BaseInvenTreeSetting): ) def to_native_value(self): - """ - Return the "pythonic" value, - e.g. convert "True" to True, and "1" to 1 - """ - + """Return the "pythonic" value, e.g. convert "True" to True, and "1" to 1.""" return self.__class__.get_setting(self.key) def requires_restart(self): - """ - Return True if this setting requires a server restart after changing - """ - + """Return True if this setting requires a server restart after changing.""" options = InvenTreeSetting.SETTINGS.get(self.key, None) if options: @@ -1267,9 +1185,7 @@ class InvenTreeSetting(BaseInvenTreeSetting): class InvenTreeUserSetting(BaseInvenTreeSetting): - """ - An InvenTreeSetting object with a usercontext - """ + """An InvenTreeSetting object with a usercontext.""" SETTINGS = { 'HOMEPAGE_PART_STARRED': { @@ -1561,6 +1477,8 @@ class InvenTreeUserSetting(BaseInvenTreeSetting): } class Meta: + """Meta options for InvenTreeUserSetting.""" + verbose_name = "InvenTree User Setting" verbose_name_plural = "InvenTree User Settings" constraints = [ @@ -1584,36 +1502,30 @@ class InvenTreeUserSetting(BaseInvenTreeSetting): @classmethod def get_setting_object(cls, key, user=None): + """Return setting object for provided user.""" return super().get_setting_object(key, user=user) def validate_unique(self, exclude=None, **kwargs): + """Return if the setting (including key) is unique.""" return super().validate_unique(exclude=exclude, user=self.user) def to_native_value(self): - """ - Return the "pythonic" value, - e.g. convert "True" to True, and "1" to 1 - """ - + """Return the "pythonic" value, e.g. convert "True" to True, and "1" to 1.""" return self.__class__.get_setting(self.key, user=self.user) def get_kwargs(self): - """ - Explicit kwargs required to uniquely identify a particular setting object, - in addition to the 'key' parameter - """ - + """Explicit kwargs required to uniquely identify a particular setting object, in addition to the 'key' parameter.""" return { 'user': self.user, } class PriceBreak(models.Model): - """ - Represents a PriceBreak model - """ + """Represents a PriceBreak model.""" class Meta: + """Define this as abstract -> no DB entry is created.""" + abstract = True quantity = InvenTree.fields.RoundingDecimalField( @@ -1634,13 +1546,11 @@ class PriceBreak(models.Model): ) def convert_to(self, currency_code): - """ - Convert the unit-price at this price break to the specified currency code. + """Convert the unit-price at this price break to the specified currency code. Args: - currency_code - The currency code to convert to (e.g "USD" or "AUD") + currency_code: The currency code to convert to (e.g "USD" or "AUD") """ - try: converted = convert_money(self.price, currency_code) except MissingRate: @@ -1651,7 +1561,7 @@ class PriceBreak(models.Model): def get_price(instance, quantity, moq=True, multiples=True, currency=None, break_name: str = 'price_breaks'): - """ Calculate the price based on quantity price breaks. + """Calculate the price based on quantity price breaks. - Don't forget to add in flat-fee cost (base_cost field) - If MOQ (minimum order quantity) is required, bump quantity @@ -1721,7 +1631,7 @@ def get_price(instance, quantity, moq=True, multiples=True, currency=None, break class ColorTheme(models.Model): - """ Color Theme Setting """ + """Color Theme Setting.""" name = models.CharField(max_length=20, default='', blank=True) @@ -1731,7 +1641,7 @@ class ColorTheme(models.Model): @classmethod def get_color_themes_choices(cls): - """ Get all color themes from static folder """ + """Get all color themes from static folder.""" if settings.TESTING and not os.path.exists(settings.STATIC_COLOR_THEMES_DIR): logger.error('Theme directory does not exsist') return [] @@ -1750,7 +1660,7 @@ class ColorTheme(models.Model): @classmethod def is_valid_choice(cls, user_color_theme): - """ Check if color theme is valid choice """ + """Check if color theme is valid choice.""" try: user_color_theme_name = user_color_theme.name except AttributeError: @@ -1764,13 +1674,15 @@ class ColorTheme(models.Model): class VerificationMethod: + """Class to hold method references.""" + NONE = 0 TOKEN = 1 HMAC = 2 class WebhookEndpoint(models.Model): - """ Defines a Webhook entdpoint + """Defines a Webhook entdpoint. Attributes: endpoint_id: Path to the webhook, @@ -1835,9 +1747,19 @@ class WebhookEndpoint(models.Model): # To be overridden def init(self, request, *args, **kwargs): + """Set verification method. + + Args: + request: Original request object. + """ self.verify = self.VERIFICATION_METHOD def process_webhook(self): + """Process the webhook incomming. + + This does not deal with the data itself - that happens in process_payload. + Do not touch or pickle data here - it was not verified to be safe. + """ if self.token: self.verify = VerificationMethod.TOKEN if self.secret: @@ -1845,6 +1767,10 @@ class WebhookEndpoint(models.Model): return True def validate_token(self, payload, headers, request): + """Make sure that the provided token (if any) confirms to the setting for this endpoint. + + This can be overridden to create your own token validation method. + """ token = headers.get(self.TOKEN_NAME, "") # no token @@ -1866,7 +1792,14 @@ class WebhookEndpoint(models.Model): return True - def save_data(self, payload, headers=None, request=None): + def save_data(self, payload=None, headers=None, request=None): + """Safes payload to database. + + Args: + payload (optional): Payload that was send along. Defaults to None. + headers (optional): Headers that were send along. Defaults to None. + request (optional): Original request object. Defaults to None. + """ return WebhookMessage.objects.create( host=request.get_host(), header=json.dumps({key: val for key, val in headers.items()}), @@ -1874,15 +1807,35 @@ class WebhookEndpoint(models.Model): endpoint=self, ) - def process_payload(self, message, payload=None, headers=None): + def process_payload(self, message, payload=None, headers=None) -> bool: + """Process a payload. + + Args: + message: DB entry for this message mm + payload (optional): Payload that was send along. Defaults to None. + headers (optional): Headers that were included. Defaults to None. + + Returns: + bool: Was the message processed + """ return True - def get_return(self, payload, headers=None, request=None): + def get_return(self, payload=None, headers=None, request=None) -> str: + """Returns the message that should be returned to the endpoint caller. + + Args: + payload (optional): Payload that was send along. Defaults to None. + headers (optional): Headers that were send along. Defaults to None. + request (optional): Original request object. Defaults to None. + + Returns: + str: Message for caller. + """ return self.MESSAGE_OK class WebhookMessage(models.Model): - """ Defines a webhook message + """Defines a webhook message. Attributes: message_id: Unique identifier for this message, @@ -1939,8 +1892,7 @@ class WebhookMessage(models.Model): class NotificationEntry(models.Model): - """ - A NotificationEntry records the last time a particular notifaction was sent out. + """A NotificationEntry records the last time a particular notifaction was sent out. It is recorded to ensure that notifications are not sent out "too often" to users. @@ -1951,6 +1903,8 @@ class NotificationEntry(models.Model): """ class Meta: + """Meta options for NotificationEntry.""" + unique_together = [ ('key', 'uid'), ] @@ -1970,10 +1924,7 @@ class NotificationEntry(models.Model): @classmethod def check_recent(cls, key: str, uid: int, delta: timedelta): - """ - Test if a particular notification has been sent in the specified time period - """ - + """Test if a particular notification has been sent in the specified time period.""" since = datetime.now().date() - delta entries = cls.objects.filter( @@ -1986,10 +1937,7 @@ class NotificationEntry(models.Model): @classmethod def notify(cls, key: str, uid: int): - """ - Notify the database that a particular notification has been sent out - """ - + """Notify the database that a particular notification has been sent out.""" entry, created = cls.objects.get_or_create( key=key, uid=uid @@ -1999,8 +1947,7 @@ class NotificationEntry(models.Model): class NotificationMessage(models.Model): - """ - A NotificationEntry records the last time a particular notifaction was sent out. + """A NotificationEntry records the last time a particular notifaction was sent out. It is recorded to ensure that notifications are not sent out "too often" to users. @@ -2073,13 +2020,14 @@ class NotificationMessage(models.Model): @staticmethod def get_api_url(): + """Return API endpoint.""" return reverse('api-notifications-list') def age(self): - """age of the message in seconds""" + """Age of the message in seconds.""" delta = now() - self.creation return delta.seconds def age_human(self): - """humanized age""" + """Humanized age.""" return naturaltime(self.creation) diff --git a/InvenTree/common/notifications.py b/InvenTree/common/notifications.py index aa39ad20ef..122d7deaeb 100644 --- a/InvenTree/common/notifications.py +++ b/InvenTree/common/notifications.py @@ -1,3 +1,5 @@ +"""Base classes and functions for notifications.""" + import logging from datetime import timedelta @@ -12,9 +14,7 @@ logger = logging.getLogger('inventree') # region methods class NotificationMethod: - """ - Base class for notification methods - """ + """Base class for notification methods.""" METHOD_NAME = '' METHOD_ICON = None @@ -24,6 +24,13 @@ class NotificationMethod: USER_SETTING = None def __init__(self, obj, category, targets, context) -> None: + """Check that the method is read. + + This checks that: + - All needed functions are implemented + - The method is not disabled via plugin + - All needed contaxt values were provided + """ # Check if a sending fnc is defined if (not hasattr(self, 'send')) and (not hasattr(self, 'send_bulk')): raise NotImplementedError('A NotificationMethod must either define a `send` or a `send_bulk` method') @@ -47,6 +54,7 @@ class NotificationMethod: self.targets = self.get_targets() def check_context(self, context): + """Check that all values defined in the methods CONTEXT were provided in the current context.""" def check(ref, obj): # the obj is not accesible so we are on the end if not isinstance(obj, (list, dict, tuple, )): @@ -82,21 +90,33 @@ class NotificationMethod: return context def get_targets(self): + """Returns targets for notifications. + + Processes `self.targets` to extract all users that should be notified. + """ raise NotImplementedError('The `get_targets` method must be implemented!') def setup(self): + """Set up context before notifications are send. + + This is intended to be overridden in method implementations. + """ return True def cleanup(self): + """Clean up context after all notifications were send. + + This is intended to be overridden in method implementations. + """ return True # region plugins def get_plugin(self): - """Returns plugin class""" + """Returns plugin class.""" return False def global_setting_disable(self): - """Check if the method is defined in a plugin and has a global setting""" + """Check if the method is defined in a plugin and has a global setting.""" # Check if plugin has a setting if not self.GLOBAL_SETTING: return False @@ -115,29 +135,45 @@ class NotificationMethod: return False def usersetting(self, target): - """ - Returns setting for this method for a given user - """ + """Returns setting for this method for a given user.""" return NotificationUserSetting.get_setting(f'NOTIFICATION_METHOD_{self.METHOD_NAME.upper()}', user=target, method=self.METHOD_NAME) # endregion class SingleNotificationMethod(NotificationMethod): + """NotificationMethod that sends notifications one by one.""" + def send(self, target): + """This function must be overriden.""" raise NotImplementedError('The `send` method must be overriden!') class BulkNotificationMethod(NotificationMethod): + """NotificationMethod that sends all notifications in bulk.""" + def send_bulk(self): + """This function must be overriden.""" raise NotImplementedError('The `send` method must be overriden!') # endregion class MethodStorageClass: + """Class that works as registry for all available notification methods in InvenTree. + + Is initialized on startup as one instance named `storage` in this file. + """ + liste = None user_settings = {} def collect(self, selected_classes=None): + """Collect all classes in the enviroment that are notification methods. + + Can be filtered to only include provided classes for testing. + + Args: + selected_classes (class, optional): References to the classes that should be registered. Defaults to None. + """ logger.info('collecting notification methods') current_method = inheritors(NotificationMethod) - IGNORED_NOTIFICATION_CLS @@ -155,7 +191,17 @@ class MethodStorageClass: storage.liste = list(filtered_list.values()) logger.info(f'found {len(storage.liste)} notification methods') - def get_usersettings(self, user): + def get_usersettings(self, user) -> list: + """Returns all user settings for a specific user. + + This is needed to show them in the settings UI. + + Args: + user (User): User that should be used as a filter. + + Returns: + list: All applicablae notification settings. + """ methods = [] for item in storage.liste: if item.USER_SETTING: @@ -186,12 +232,16 @@ storage = MethodStorageClass() class UIMessageNotification(SingleNotificationMethod): + """Delivery method for sending specific users notifications in the notification pain in the web UI.""" + METHOD_NAME = 'ui_message' def get_targets(self): + """Just return the targets - no tricks here.""" return self.targets def send(self, target): + """Send a UI notification to a user.""" NotificationMessage.objects.create( target_object=self.obj, source_object=target, @@ -204,10 +254,7 @@ class UIMessageNotification(SingleNotificationMethod): def trigger_notification(obj, category=None, obj_ref='pk', **kwargs): - """ - Send out a notification - """ - + """Send out a notification.""" targets = kwargs.get('targets', None) target_fnc = kwargs.get('target_fnc', None) target_args = kwargs.get('target_args', []) @@ -267,6 +314,15 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs): def deliver_notification(cls: NotificationMethod, obj, category: str, targets, context: dict): + """Send notification with the provided class. + + This: + - Intis the method + - Checks that there are valid targets + - Runs the delivery setup + - Sends notifications either via `send_bulk` or send` + - Runs the delivery cleanup + """ # Init delivery method method = cls(obj, category, targets, context) diff --git a/InvenTree/common/serializers.py b/InvenTree/common/serializers.py index c91ad0f45d..f703b82fbc 100644 --- a/InvenTree/common/serializers.py +++ b/InvenTree/common/serializers.py @@ -1,6 +1,4 @@ -""" -JSON serializers for common components -""" +"""JSON serializers for common components.""" from rest_framework import serializers @@ -11,9 +9,7 @@ from InvenTree.serializers import InvenTreeModelSerializer class SettingsSerializer(InvenTreeModelSerializer): - """ - Base serializer for a settings object - """ + """Base serializer for a settings object.""" key = serializers.CharField(read_only=True) @@ -30,10 +26,7 @@ class SettingsSerializer(InvenTreeModelSerializer): api_url = serializers.CharField(read_only=True) def get_choices(self, obj): - """ - Returns the choices available for a given item - """ - + """Returns the choices available for a given item.""" results = [] choices = obj.choices() @@ -48,10 +41,7 @@ class SettingsSerializer(InvenTreeModelSerializer): return results def get_value(self, obj): - """ - Make sure protected values are not returned - """ - + """Make sure protected values are not returned.""" # never return protected values if obj.protected: result = '***' @@ -62,11 +52,11 @@ class SettingsSerializer(InvenTreeModelSerializer): class GlobalSettingsSerializer(SettingsSerializer): - """ - Serializer for the InvenTreeSetting model - """ + """Serializer for the InvenTreeSetting model.""" class Meta: + """Meta options for GlobalSettingsSerializer.""" + model = InvenTreeSetting fields = [ 'pk', @@ -82,13 +72,13 @@ class GlobalSettingsSerializer(SettingsSerializer): class UserSettingsSerializer(SettingsSerializer): - """ - Serializer for the InvenTreeUserSetting model - """ + """Serializer for the InvenTreeUserSetting model.""" user = serializers.PrimaryKeyRelatedField(read_only=True) class Meta: + """Meta options for UserSettingsSerializer.""" + model = InvenTreeUserSetting fields = [ 'pk', @@ -105,8 +95,7 @@ class UserSettingsSerializer(SettingsSerializer): class GenericReferencedSettingSerializer(SettingsSerializer): - """ - Serializer for a GenericReferencedSetting model + """Serializer for a GenericReferencedSetting model. Args: MODEL: model class for the serializer @@ -118,9 +107,9 @@ class GenericReferencedSettingSerializer(SettingsSerializer): EXTRA_FIELDS = None def __init__(self, *args, **kwargs): - """Init overrides the Meta class to make it dynamic""" + """Init overrides the Meta class to make it dynamic.""" class CustomMeta: - """Scaffold for custom Meta class""" + """Scaffold for custom Meta class.""" fields = [ 'pk', 'key', @@ -144,9 +133,7 @@ class GenericReferencedSettingSerializer(SettingsSerializer): class NotificationMessageSerializer(InvenTreeModelSerializer): - """ - Serializer for the InvenTreeUserSetting model - """ + """Serializer for the InvenTreeUserSetting model.""" target = serializers.SerializerMethodField(read_only=True) @@ -169,12 +156,16 @@ class NotificationMessageSerializer(InvenTreeModelSerializer): read = serializers.BooleanField(read_only=True) def get_target(self, obj): + """Function to resolve generic object reference to target.""" return get_objectreference(obj, 'target_content_type', 'target_object_id') def get_source(self, obj): + """Function to resolve generic object reference to source.""" return get_objectreference(obj, 'source_content_type', 'source_object_id') class Meta: + """Meta options for NotificationMessageSerializer.""" + model = NotificationMessage fields = [ 'pk', @@ -192,8 +183,10 @@ class NotificationMessageSerializer(InvenTreeModelSerializer): class NotificationReadSerializer(NotificationMessageSerializer): + """Serializer for reading a notification.""" def is_valid(self, raise_exception=False): + """Ensure instance data is available for view and let validation pass.""" self.instance = self.context['instance'] # set instance that should be returned self._validated_data = True return True diff --git a/InvenTree/common/settings.py b/InvenTree/common/settings.py index f010e65567..3bf544e13e 100644 --- a/InvenTree/common/settings.py +++ b/InvenTree/common/settings.py @@ -1,6 +1,4 @@ -""" -User-configurable settings for the common app -""" +"""User-configurable settings for the common app.""" from django.conf import settings @@ -8,9 +6,7 @@ from moneyed import CURRENCIES def currency_code_default(): - """ - Returns the default currency code (or USD if not specified) - """ + """Returns the default currency code (or USD if not specified)""" from django.db.utils import ProgrammingError from common.models import InvenTreeSetting @@ -28,23 +24,17 @@ def currency_code_default(): def currency_code_mappings(): - """ - Returns the current currency choices - """ + """Returns the current currency choices.""" return [(a, CURRENCIES[a].name) for a in settings.CURRENCIES] def currency_codes(): - """ - Returns the current currency codes - """ + """Returns the current currency codes.""" return [a for a in settings.CURRENCIES] def stock_expiry_enabled(): - """ - Returns True if the stock expiry feature is enabled - """ + """Returns True if the stock expiry feature is enabled.""" from common.models import InvenTreeSetting return InvenTreeSetting.get_setting('STOCK_ENABLE_EXPIRY') diff --git a/InvenTree/common/tasks.py b/InvenTree/common/tasks.py index 7fdd57999a..56fd3fb04f 100644 --- a/InvenTree/common/tasks.py +++ b/InvenTree/common/tasks.py @@ -1,3 +1,5 @@ +"""Tasks (processes that get offloaded) for common app.""" + import logging from datetime import datetime, timedelta @@ -7,12 +9,10 @@ logger = logging.getLogger('inventree') def delete_old_notifications(): - """ - Remove old notifications from the database. + """Remove old notifications from the database. Anything older than ~3 months is removed """ - try: from common.models import NotificationEntry except AppRegistryNotReady: # pragma: no cover diff --git a/InvenTree/common/test_notifications.py b/InvenTree/common/test_notifications.py index 41e6303a81..ab4b0838fd 100644 --- a/InvenTree/common/test_notifications.py +++ b/InvenTree/common/test_notifications.py @@ -1,3 +1,5 @@ +"""Tests for basic notification methods and functions in InvenTree.""" + import plugin.templatetags.plugin_extras as plugin_tags from common.notifications import (BulkNotificationMethod, NotificationMethod, SingleNotificationMethod, storage) @@ -6,9 +8,10 @@ from plugin.models import NotificationUserSetting class BaseNotificationTests(BaseNotificationIntegrationTest): + """Tests for basic NotificationMethod.""" def test_NotificationMethod(self): - """ensure the implementation requirements are tested""" + """Ensure the implementation requirements are tested.""" class FalseNotificationMethod(NotificationMethod): METHOD_NAME = 'FalseNotification' @@ -17,12 +20,12 @@ class BaseNotificationTests(BaseNotificationIntegrationTest): METHOD_NAME = 'AnotherFalseNotification' def send(self): - """a comment so we do not need a pass""" + """A comment so we do not need a pass.""" class NoNameNotificationMethod(NotificationMethod): def send(self): - """a comment so we do not need a pass""" + """A comment so we do not need a pass.""" class WrongContextNotificationMethod(NotificationMethod): METHOD_NAME = 'WrongContextNotification' @@ -34,7 +37,7 @@ class BaseNotificationTests(BaseNotificationIntegrationTest): ] def send(self): - """a comment so we do not need a pass""" + """A comment so we do not need a pass.""" # no send / send bulk with self.assertRaises(NotImplementedError): @@ -53,11 +56,12 @@ class BaseNotificationTests(BaseNotificationIntegrationTest): AnotherFalseNotificationMethod('', '', '', {'name': 1, 'message': 2, }, ) def test_failing_passing(self): + """Ensure that an error in one deliverymethod is not blocking all mehthods.""" # cover failing delivery self._notification_run() def test_errors_passing(self): - """ensure that errors do not kill the whole delivery""" + """Ensure that errors do not kill the whole delivery.""" class ErrorImplementation(SingleNotificationMethod): METHOD_NAME = 'ErrorImplementation' @@ -72,10 +76,14 @@ class BaseNotificationTests(BaseNotificationIntegrationTest): class BulkNotificationMethodTests(BaseNotificationIntegrationTest): + """Tests for BulkNotificationMethod classes specifically. + + General tests for NotificationMethods are in BaseNotificationTests. + """ def test_BulkNotificationMethod(self): - """ - Ensure the implementation requirements are tested. + """Ensure the implementation requirements are tested. + MixinNotImplementedError needs to raise if the send_bulk() method is not set. """ @@ -90,10 +98,14 @@ class BulkNotificationMethodTests(BaseNotificationIntegrationTest): class SingleNotificationMethodTests(BaseNotificationIntegrationTest): + """Tests for SingleNotificationMethod classes specifically. + + General tests for NotificationMethods are in BaseNotificationTests. + """ def test_SingleNotificationMethod(self): - """ - Ensure the implementation requirements are tested. + """Ensure the implementation requirements are tested. + MixinNotImplementedError needs to raise if the send() method is not set. """ @@ -110,14 +122,15 @@ class SingleNotificationMethodTests(BaseNotificationIntegrationTest): class NotificationUserSettingTests(BaseNotificationIntegrationTest): - """ Tests for NotificationUserSetting """ + """Tests for NotificationUserSetting.""" def setUp(self): + """Setup for all tests.""" super().setUp() self.client.login(username=self.user.username, password='password') def test_setting_attributes(self): - """check notification method plugin methods: usersettings and tags """ + """Check notification method plugin methods: usersettings and tags.""" class SampleImplementation(BulkNotificationMethod): METHOD_NAME = 'test' diff --git a/InvenTree/common/test_tasks.py b/InvenTree/common/test_tasks.py index 0f19720b95..d551bc82ac 100644 --- a/InvenTree/common/test_tasks.py +++ b/InvenTree/common/test_tasks.py @@ -1,4 +1,5 @@ -# -*- coding: utf-8 -*- +"""Tests for tasks in app common.""" + from django.test import TestCase from common.models import NotificationEntry @@ -8,12 +9,10 @@ from . import tasks as common_tasks class TaskTest(TestCase): - """ - Tests for common tasks - """ + """Tests for common tasks.""" def test_delete(self): - + """Test that the task `delete_old_notifications` runs through without errors.""" # check empty run self.assertEqual(NotificationEntry.objects.all().count(), 0) offload_task(common_tasks.delete_old_notifications,) diff --git a/InvenTree/common/test_views.py b/InvenTree/common/test_views.py index 2394913c73..0e43770f02 100644 --- a/InvenTree/common/test_views.py +++ b/InvenTree/common/test_views.py @@ -1,3 +1 @@ -""" -Unit tests for the views associated with the 'common' app -""" +"""Unit tests for the views associated with the 'common' app.""" diff --git a/InvenTree/common/tests.py b/InvenTree/common/tests.py index b325d50170..b5e4ee330c 100644 --- a/InvenTree/common/tests.py +++ b/InvenTree/common/tests.py @@ -1,3 +1,4 @@ +"""Tests for mechanisms in common.""" import json from datetime import timedelta @@ -19,16 +20,14 @@ CONTENT_TYPE_JSON = 'application/json' class SettingsTest(InvenTreeTestCase): - """ - Tests for the 'settings' model - """ + """Tests for the 'settings' model.""" fixtures = [ 'settings', ] def test_settings_objects(self): - + """Test fixture loading and lookup for settings.""" # There should be two settings objects in the database settings = InvenTreeSetting.objects.all() @@ -42,9 +41,7 @@ class SettingsTest(InvenTreeTestCase): self.assertEqual(InvenTreeSetting.get_setting_object('iNvEnTrEE_inSTanCE').pk, 1) def test_settings_functions(self): - """ - Test settings functions and properties - """ + """Test settings functions and properties.""" # define settings to check instance_ref = 'INVENTREE_INSTANCE' instance_obj = InvenTreeSetting.get_setting_object(instance_ref) @@ -90,9 +87,7 @@ class SettingsTest(InvenTreeTestCase): self.assertEqual(stale_days.to_native_value(), 0) def test_allValues(self): - """ - Make sure that the allValues functions returns correctly - """ + """Make sure that the allValues functions returns correctly.""" # define testing settings # check a few keys @@ -103,7 +98,13 @@ class SettingsTest(InvenTreeTestCase): self.assertIn('SIGNUP_GROUP', result) def run_settings_check(self, key, setting): + """Test that all settings are valid. + - Ensure that a name is set and that it is translated + - Ensure that a description is set + - Ensure that every setting key is valid + - Ensure that a validator is supplied + """ self.assertTrue(type(setting) is dict) name = setting.get('name', None) @@ -147,11 +148,11 @@ class SettingsTest(InvenTreeTestCase): self.assertIn(default, [True, False]) def test_setting_data(self): - """ + """Test for settings data. + - Ensure that every setting has a name, which is translated - Ensure that every setting has a description, which is translated """ - for key, setting in InvenTreeSetting.SETTINGS.items(): try: @@ -168,10 +169,7 @@ class SettingsTest(InvenTreeTestCase): raise exc def test_defaults(self): - """ - Populate the settings with default values - """ - + """Populate the settings with default values.""" for key in InvenTreeSetting.SETTINGS.keys(): value = InvenTreeSetting.get_setting_default(key) @@ -192,14 +190,10 @@ class SettingsTest(InvenTreeTestCase): class GlobalSettingsApiTest(InvenTreeAPITestCase): - """ - Tests for the global settings API - """ + """Tests for the global settings API.""" def test_global_settings_api_list(self): - """ - Test list URL for global settings - """ + """Test list URL for global settings.""" url = reverse('api-global-setting-list') # Read out each of the global settings value, to ensure they are instantiated in the database @@ -212,7 +206,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), len(InvenTreeSetting.SETTINGS.keys())) def test_company_name(self): - + """Test a settings object lifecyle e2e.""" setting = InvenTreeSetting.get_setting_object('INVENTREE_COMPANY_NAME') # Check default value @@ -245,8 +239,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase): self.assertEqual(setting.value, val) def test_api_detail(self): - """Test that we can access the detail view for a setting based on the """ - + """Test that we can access the detail view for a setting based on the .""" # These keys are invalid, and should return 404 for key in ["apple", "carrot", "dog"]: response = self.get( @@ -287,28 +280,22 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase): class UserSettingsApiTest(InvenTreeAPITestCase): - """ - Tests for the user settings API - """ + """Tests for the user settings API.""" def test_user_settings_api_list(self): - """ - Test list URL for user settings - """ + """Test list URL for user settings.""" url = reverse('api-user-setting-list') self.get(url, expected_code=200) def test_user_setting_invalid(self): - """Test a user setting with an invalid key""" - + """Test a user setting with an invalid key.""" url = reverse('api-user-setting-detail', kwargs={'key': 'DONKEY'}) self.get(url, expected_code=404) def test_user_setting_init(self): - """Test we can retrieve a setting which has not yet been initialized""" - + """Test we can retrieve a setting which has not yet been initialized.""" key = 'HOMEPAGE_PART_LATEST' # Ensure it does not actually exist in the database @@ -328,10 +315,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): self.assertEqual(setting.to_native_value(), False) def test_user_setting_boolean(self): - """ - Test a boolean user setting value - """ - + """Test a boolean user setting value.""" # Ensure we have a boolean setting available setting = InvenTreeUserSetting.get_setting_object( 'SEARCH_PREVIEW_SHOW_PARTS', @@ -395,7 +379,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): self.assertFalse(str2bool(response.data['value'])) def test_user_setting_choice(self): - + """Test a user setting with choices.""" setting = InvenTreeUserSetting.get_setting_object( 'DATE_DISPLAY_FORMAT', user=self.user @@ -434,7 +418,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): self.assertIn('Chosen value is not a valid option', str(response.data)) def test_user_setting_integer(self): - + """Test a integer user setting value.""" setting = InvenTreeUserSetting.get_setting_object( 'SEARCH_PREVIEW_RESULTS', user=self.user @@ -480,25 +464,25 @@ class UserSettingsApiTest(InvenTreeAPITestCase): class NotificationUserSettingsApiTest(InvenTreeAPITestCase): - """Tests for the notification user settings API""" + """Tests for the notification user settings API.""" def test_api_list(self): - """Test list URL""" + """Test list URL.""" url = reverse('api-notifcation-setting-list') self.get(url, expected_code=200) def test_setting(self): - """Test the string name for NotificationUserSetting""" + """Test the string name for NotificationUserSetting.""" test_setting = NotificationUserSetting.get_setting_object('NOTIFICATION_METHOD_MAIL', user=self.user) self.assertEqual(str(test_setting), 'NOTIFICATION_METHOD_MAIL (for testuser): ') class PluginSettingsApiTest(InvenTreeAPITestCase): - """Tests for the plugin settings API""" + """Tests for the plugin settings API.""" def test_plugin_list(self): - """List installed plugins via API""" + """List installed plugins via API.""" url = reverse('api-plugin-list') # Simple request @@ -508,13 +492,13 @@ class PluginSettingsApiTest(InvenTreeAPITestCase): self.get(url, expected_code=200, data={'mixin': 'settings'}) def test_api_list(self): - """Test list URL""" + """Test list URL.""" url = reverse('api-plugin-setting-list') self.get(url, expected_code=200) def test_valid_plugin_slug(self): - """Test that an valid plugin slug runs through""" + """Test that an valid plugin slug runs through.""" # load plugin configs fixtures = PluginConfig.objects.all() if not fixtures: @@ -544,26 +528,30 @@ class PluginSettingsApiTest(InvenTreeAPITestCase): self.assertIn("Plugin 'sample' has no setting matching 'doesnotexsist'", str(response.data)) def test_invalid_setting_key(self): - """Test that an invalid setting key returns a 404""" + """Test that an invalid setting key returns a 404.""" ... def test_uninitialized_setting(self): - """Test that requesting an uninitialized setting creates the setting""" + """Test that requesting an uninitialized setting creates the setting.""" ... class WebhookMessageTests(TestCase): + """Tests for webhooks.""" + def setUp(self): + """Setup for all tests.""" self.endpoint_def = WebhookEndpoint.objects.create() self.url = f'/api/webhook/{self.endpoint_def.endpoint_id}/' self.client = Client(enforce_csrf_checks=True) def test_bad_method(self): + """Test that a wrong HTTP method does not work.""" response = self.client.get(self.url) - assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED def test_missing_token(self): + """Tests that token checks work.""" response = self.client.post( self.url, content_type=CONTENT_TYPE_JSON, @@ -575,6 +563,7 @@ class WebhookMessageTests(TestCase): ) def test_bad_token(self): + """Test that a wrong token is not working.""" response = self.client.post( self.url, content_type=CONTENT_TYPE_JSON, @@ -585,6 +574,7 @@ class WebhookMessageTests(TestCase): assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR) def test_bad_url(self): + """Test that a wrongly formed url is not working.""" response = self.client.post( '/api/webhook/1234/', content_type=CONTENT_TYPE_JSON, @@ -593,6 +583,7 @@ class WebhookMessageTests(TestCase): assert response.status_code == HTTPStatus.NOT_FOUND def test_bad_json(self): + """Test that malformed JSON is not accepted.""" response = self.client.post( self.url, data="{'this': 123}", @@ -606,6 +597,7 @@ class WebhookMessageTests(TestCase): ) def test_success_no_token_check(self): + """Test that a endpoint without a token set does not require one.""" # delete token self.endpoint_def.token = '' self.endpoint_def.save() @@ -620,6 +612,7 @@ class WebhookMessageTests(TestCase): assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK def test_bad_hmac(self): + """Test that a malformed HMAC does not pass.""" # delete token self.endpoint_def.token = '' self.endpoint_def.secret = '123abc' @@ -635,6 +628,7 @@ class WebhookMessageTests(TestCase): assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR) def test_success_hmac(self): + """Test with a valid HMAC provided.""" # delete token self.endpoint_def.token = '' self.endpoint_def.secret = '123abc' @@ -651,6 +645,10 @@ class WebhookMessageTests(TestCase): assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK def test_success(self): + """Test full e2e webhook call. + + The message should go through and save the json payload. + """ response = self.client.post( self.url, data={"this": "is a message"}, @@ -665,9 +663,10 @@ class WebhookMessageTests(TestCase): class NotificationTest(InvenTreeAPITestCase): + """Tests for NotificationEntriy.""" def test_check_notification_entries(self): - + """Test that notification entries can be created.""" # Create some notification entries self.assertEqual(NotificationEntry.objects.count(), 0) @@ -684,21 +683,16 @@ class NotificationTest(InvenTreeAPITestCase): self.assertTrue(NotificationEntry.check_recent('test.notification', 1, delta)) def test_api_list(self): - """Test list URL""" + """Test list URL.""" url = reverse('api-notifications-list') self.get(url, expected_code=200) class LoadingTest(TestCase): - """ - Tests for the common config - """ + """Tests for the common config.""" def test_restart_flag(self): - """ - Test that the restart flag is reset on start - """ - + """Test that the restart flag is reset on start.""" import common.models from plugin import registry @@ -713,10 +707,10 @@ class LoadingTest(TestCase): class ColorThemeTest(TestCase): - """Tests for ColorTheme""" + """Tests for ColorTheme.""" def test_choices(self): - """Test that default choices are returned""" + """Test that default choices are returned.""" result = ColorTheme.get_color_themes_choices() # skip @@ -725,7 +719,7 @@ class ColorThemeTest(TestCase): self.assertIn(('default', 'Default'), result) def test_valid_choice(self): - """Check that is_valid_choice works correctly""" + """Check that is_valid_choice works correctly.""" result = ColorTheme.get_color_themes_choices() # skip diff --git a/InvenTree/common/urls.py b/InvenTree/common/urls.py index 261ea1a691..4803d86236 100644 --- a/InvenTree/common/urls.py +++ b/InvenTree/common/urls.py @@ -1,6 +1,4 @@ -""" -URL lookup for common views -""" +"""URL lookup for common views.""" common_urls = [ ] diff --git a/InvenTree/common/views.py b/InvenTree/common/views.py index ee5a02a288..3fbbceee37 100644 --- a/InvenTree/common/views.py +++ b/InvenTree/common/views.py @@ -1,6 +1,4 @@ -""" -Django views for interacting with common models -""" +"""Django views for interacting with common models.""" import os @@ -18,10 +16,10 @@ from .files import FileManager class MultiStepFormView(SessionWizardView): - """ Setup basic methods of multi-step form + """Setup basic methods of multi-step form. - form_list: list of forms - form_steps_description: description for each form + form_list: list of forms + form_steps_description: description for each form """ form_steps_template = [] @@ -31,14 +29,13 @@ class MultiStepFormView(SessionWizardView): file_storage = FileSystemStorage(settings.MEDIA_ROOT) def __init__(self, *args, **kwargs): - """ Override init method to set media folder """ + """Override init method to set media folder.""" super().__init__(**kwargs) self.process_media_folder() def process_media_folder(self): - """ Process media folder """ - + """Process media folder.""" if self.media_folder: media_folder_abs = os.path.join(settings.MEDIA_ROOT, self.media_folder) if not os.path.exists(media_folder_abs): @@ -46,8 +43,7 @@ class MultiStepFormView(SessionWizardView): self.file_storage = FileSystemStorage(location=media_folder_abs) def get_template_names(self): - """ Select template """ - + """Select template.""" try: # Get template template = self.form_steps_template[self.steps.index] @@ -57,8 +53,7 @@ class MultiStepFormView(SessionWizardView): return template def get_context_data(self, **kwargs): - """ Update context data """ - + """Update context data.""" # Retrieve current context context = super().get_context_data(**kwargs) @@ -74,7 +69,9 @@ class MultiStepFormView(SessionWizardView): class FileManagementFormView(MultiStepFormView): - """ Setup form wizard to perform the following steps: + """File management form wizard. + + Perform the following steps: 1. Upload tabular data file 2. Match headers to InvenTree fields 3. Edit row data and match InvenTree items @@ -95,8 +92,7 @@ class FileManagementFormView(MultiStepFormView): extra_context_data = {} def __init__(self, *args, **kwargs): - """ Initialize the FormView """ - + """Initialize the FormView.""" # Perform all checks and inits for MultiStepFormView super().__init__(self, *args, **kwargs) @@ -105,8 +101,7 @@ class FileManagementFormView(MultiStepFormView): raise NotImplementedError('A subclass of a file manager class needs to be set!') def get_context_data(self, form=None, **kwargs): - """ Handle context data """ - + """Handle context data.""" if form is None: form = self.get_form() @@ -136,8 +131,7 @@ class FileManagementFormView(MultiStepFormView): return context def get_file_manager(self, step=None, form=None): - """ Get FileManager instance from uploaded file """ - + """Get FileManager instance from uploaded file.""" if self.file_manager: return @@ -151,8 +145,7 @@ class FileManagementFormView(MultiStepFormView): self.file_manager = self.file_manager_class(file=file, name=self.name) def get_form_kwargs(self, step=None): - """ Update kwargs to dynamically build forms """ - + """Update kwargs to dynamically build forms.""" # Always retrieve FileManager instance from uploaded file self.get_file_manager(step) @@ -191,7 +184,7 @@ class FileManagementFormView(MultiStepFormView): return super().get_form_kwargs() def get_form(self, step=None, data=None, files=None): - """ add crispy-form helper to form """ + """Add crispy-form helper to form.""" form = super().get_form(step=step, data=data, files=files) form.helper = FormHelper() @@ -200,17 +193,14 @@ class FileManagementFormView(MultiStepFormView): return form def get_form_table_data(self, form_data): - """ Extract table cell data from form data and fields. - These data are used to maintain state between sessions. + """Extract table cell data from form data and fields. These data are used to maintain state between sessions. Table data keys are as follows: col_name_ - Column name at idx as provided in the uploaded file col_guess_ - Column guess at idx as selected row__col - Cell data as provided in the uploaded file - """ - # Map the columns self.column_names = {} self.column_selections = {} @@ -264,8 +254,7 @@ class FileManagementFormView(MultiStepFormView): self.row_data[row_id][col_id] = value def set_form_table_data(self, form=None): - """ Set the form table data """ - + """Set the form table data.""" if self.column_names: # Re-construct the column data self.columns = [] @@ -324,10 +313,10 @@ class FileManagementFormView(MultiStepFormView): row[field_key] = field_key + '-' + str(row['index']) def get_column_index(self, name): - """ Return the index of the column with the given name. + """Return the index of the column with the given name. + It named column is not found, return -1 """ - try: idx = list(self.column_selections.values()).index(name) except ValueError: @@ -336,9 +325,7 @@ class FileManagementFormView(MultiStepFormView): return idx def get_field_selection(self): - """ Once data columns have been selected, attempt to pre-select the proper data from the database. - This function is called once the field selection has been validated. - The pre-fill data are then passed through to the part selection form. + """Once data columns have been selected, attempt to pre-select the proper data from the database. This function is called once the field selection has been validated. The pre-fill data are then passed through to the part selection form. This method is very specific to the type of data found in the file, therefore overwrite it in the subclass. @@ -346,7 +333,7 @@ class FileManagementFormView(MultiStepFormView): pass def get_clean_items(self): - """ returns dict with all cleaned values """ + """Returns dict with all cleaned values.""" items = {} for form_key, form_value in self.get_all_cleaned_data().items(): @@ -373,8 +360,7 @@ class FileManagementFormView(MultiStepFormView): return items def check_field_selection(self, form): - """ Check field matching """ - + """Check field matching.""" # Are there any missing columns? missing_columns = [] @@ -422,8 +408,7 @@ class FileManagementFormView(MultiStepFormView): return valid def validate(self, step, form): - """ Validate forms """ - + """Validate forms.""" valid = True # Get form table data @@ -442,8 +427,7 @@ class FileManagementFormView(MultiStepFormView): return valid def post(self, request, *args, **kwargs): - """ Perform validations before posting data """ - + """Perform validations before posting data.""" wizard_goto_step = self.request.POST.get('wizard_goto_step', None) form = self.get_form(data=self.request.POST, files=self.request.FILES) @@ -458,14 +442,21 @@ class FileManagementFormView(MultiStepFormView): class FileManagementAjaxView(AjaxView): - """ Use a FileManagementFormView as base for a AjaxView - Inherit this class before inheriting the base FileManagementFormView + """Use a FileManagementFormView as base for a AjaxView Inherit this class before inheriting the base FileManagementFormView. ajax_form_steps_template: templates for rendering ajax validate: function to validate the current form -> normally point to the same function in the base FileManagementFormView """ def post(self, request): + """Handle wizard step call. + + Possible actions: + - Step back -> render previous step + - Invalid form -> render error + - Valid form and not done -> render next step + - Valid form and done -> render final step + """ # check if back-step button was selected wizard_back = self.request.POST.get('act-btn_back', None) if wizard_back: @@ -497,6 +488,7 @@ class FileManagementAjaxView(AjaxView): return self.renderJsonResponse(request, data={'form_valid': None}) def get(self, request): + """Reset storage if flag is set, proceed to render JsonResponse.""" if 'reset' in request.GET: # reset form self.storage.reset() @@ -504,11 +496,12 @@ class FileManagementAjaxView(AjaxView): return self.renderJsonResponse(request) def renderJsonResponse(self, request, form=None, data={}, context=None): - """ always set the right templates before rendering """ + """Always set the right templates before rendering.""" self.setTemplate() return super().renderJsonResponse(request, form=form, data=data, context=context) - def get_data(self): + def get_data(self) -> dict: + """Get extra context data.""" data = super().get_data() data['hideErrorMessage'] = '1' # hide the error buttons = [{'name': 'back', 'title': _('Previous Step')}] if self.get_step_index() > 0 else [] @@ -516,9 +509,13 @@ class FileManagementAjaxView(AjaxView): return data def setTemplate(self): - """ set template name and title """ + """Set template name and title.""" self.ajax_template_name = self.ajax_form_steps_template[self.get_step_index()] self.ajax_form_title = self.form_steps_description[self.get_step_index()] def validate(self, obj, form, **kwargs): + """Generic validate action. + + This is the point to process provided userinput. + """ raise NotImplementedError('This function needs to be overridden!') diff --git a/InvenTree/company/__init__.py b/InvenTree/company/__init__.py index 6c89578ac3..38da11ee68 100644 --- a/InvenTree/company/__init__.py +++ b/InvenTree/company/__init__.py @@ -1,5 +1,4 @@ -""" -The Company module is responsible for managing Company interactions. +"""The Company module is responsible for managing Company interactions. A company can be either (or both): diff --git a/InvenTree/company/admin.py b/InvenTree/company/admin.py index d3bf75dab3..d279c9227d 100644 --- a/InvenTree/company/admin.py +++ b/InvenTree/company/admin.py @@ -1,3 +1,5 @@ +"""Admin class for the 'company' app""" + from django.contrib import admin import import_export.widgets as widgets @@ -13,9 +15,10 @@ from .models import (Company, ManufacturerPart, ManufacturerPartAttachment, class CompanyResource(ModelResource): - """ Class for managing Company data import/export """ + """Class for managing Company data import/export.""" class Meta: + """Metaclass defines extra options""" model = Company skip_unchanged = True report_skipped = False @@ -23,6 +26,7 @@ class CompanyResource(ModelResource): class CompanyAdmin(ImportExportModelAdmin): + """Admin class for the Company model""" resource_class = CompanyResource @@ -35,9 +39,7 @@ class CompanyAdmin(ImportExportModelAdmin): class SupplierPartResource(ModelResource): - """ - Class for managing SupplierPart data import/export - """ + """Class for managing SupplierPart data import/export.""" part = Field(attribute='part', widget=widgets.ForeignKeyWidget(Part)) @@ -48,6 +50,7 @@ class SupplierPartResource(ModelResource): supplier_name = Field(attribute='supplier__name', readonly=True) class Meta: + """Metaclass defines extra admin options""" model = SupplierPart skip_unchanged = True report_skipped = True @@ -55,6 +58,7 @@ class SupplierPartResource(ModelResource): class SupplierPartAdmin(ImportExportModelAdmin): + """Admin class for the SupplierPart model""" resource_class = SupplierPartResource @@ -71,9 +75,7 @@ class SupplierPartAdmin(ImportExportModelAdmin): class ManufacturerPartResource(ModelResource): - """ - Class for managing ManufacturerPart data import/export - """ + """Class for managing ManufacturerPart data import/export.""" part = Field(attribute='part', widget=widgets.ForeignKeyWidget(Part)) @@ -84,6 +86,7 @@ class ManufacturerPartResource(ModelResource): manufacturer_name = Field(attribute='manufacturer__name', readonly=True) class Meta: + """Metaclass defines extra admin options""" model = ManufacturerPart skip_unchanged = True report_skipped = True @@ -91,9 +94,7 @@ class ManufacturerPartResource(ModelResource): class ManufacturerPartAdmin(ImportExportModelAdmin): - """ - Admin class for ManufacturerPart model - """ + """Admin class for ManufacturerPart model.""" resource_class = ManufacturerPartResource @@ -109,9 +110,7 @@ class ManufacturerPartAdmin(ImportExportModelAdmin): class ManufacturerPartAttachmentAdmin(ImportExportModelAdmin): - """ - Admin class for ManufacturerPartAttachment model - """ + """Admin class for ManufacturerPartAttachment model.""" list_display = ('manufacturer_part', 'attachment', 'comment') @@ -119,11 +118,10 @@ class ManufacturerPartAttachmentAdmin(ImportExportModelAdmin): class ManufacturerPartParameterResource(ModelResource): - """ - Class for managing ManufacturerPartParameter data import/export - """ + """Class for managing ManufacturerPartParameter data import/export.""" class Meta: + """Metaclass defines extra admin options""" model = ManufacturerPartParameter skip_unchanged = True report_skipped = True @@ -131,9 +129,7 @@ class ManufacturerPartParameterResource(ModelResource): class ManufacturerPartParameterAdmin(ImportExportModelAdmin): - """ - Admin class for ManufacturerPartParameter model - """ + """Admin class for ManufacturerPartParameter model.""" resource_class = ManufacturerPartParameterResource @@ -149,7 +145,7 @@ class ManufacturerPartParameterAdmin(ImportExportModelAdmin): class SupplierPriceBreakResource(ModelResource): - """ Class for managing SupplierPriceBreak data import/export """ + """Class for managing SupplierPriceBreak data import/export.""" part = Field(attribute='part', widget=widgets.ForeignKeyWidget(SupplierPart)) @@ -164,6 +160,7 @@ class SupplierPriceBreakResource(ModelResource): MPN = Field(attribute='part__MPN', readonly=True) class Meta: + """Metaclass defines extra admin options""" model = SupplierPriceBreak skip_unchanged = True report_skipped = False @@ -171,6 +168,7 @@ class SupplierPriceBreakResource(ModelResource): class SupplierPriceBreakAdmin(ImportExportModelAdmin): + """Admin class for the SupplierPriceBreak model""" resource_class = SupplierPriceBreakResource diff --git a/InvenTree/company/api.py b/InvenTree/company/api.py index b63cd65c83..c0455abcca 100644 --- a/InvenTree/company/api.py +++ b/InvenTree/company/api.py @@ -1,6 +1,4 @@ -""" -Provides a JSON API for the Company app -""" +"""Provides a JSON API for the Company app.""" from django.db.models import Q from django.urls import include, re_path @@ -23,7 +21,7 @@ from .serializers import (CompanySerializer, class CompanyList(generics.ListCreateAPIView): - """ API endpoint for accessing a list of Company objects + """API endpoint for accessing a list of Company objects. Provides two methods: @@ -35,7 +33,7 @@ class CompanyList(generics.ListCreateAPIView): queryset = Company.objects.all() def get_queryset(self): - + """Return annotated queryset for the company list endpoint""" queryset = super().get_queryset() queryset = CompanySerializer.annotate_queryset(queryset) @@ -70,13 +68,13 @@ class CompanyList(generics.ListCreateAPIView): class CompanyDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail of a single Company object """ + """API endpoint for detail of a single Company object.""" queryset = Company.objects.all() serializer_class = CompanySerializer def get_queryset(self): - + """Return annotated queryset for the company detail endpoint""" queryset = super().get_queryset() queryset = CompanySerializer.annotate_queryset(queryset) @@ -84,11 +82,11 @@ class CompanyDetail(generics.RetrieveUpdateDestroyAPIView): class ManufacturerPartFilter(rest_filters.FilterSet): - """ - Custom API filters for the ManufacturerPart list endpoint. - """ + """Custom API filters for the ManufacturerPart list endpoint.""" class Meta: + """Metaclass options.""" + model = ManufacturerPart fields = [ 'manufacturer', @@ -101,7 +99,7 @@ class ManufacturerPartFilter(rest_filters.FilterSet): class ManufacturerPartList(generics.ListCreateAPIView): - """ API endpoint for list view of ManufacturerPart object + """API endpoint for list view of ManufacturerPart object. - GET: Return list of ManufacturerPart objects - POST: Create a new ManufacturerPart object @@ -117,7 +115,7 @@ class ManufacturerPartList(generics.ListCreateAPIView): filterset_class = ManufacturerPartFilter def get_serializer(self, *args, **kwargs): - + """Return serializer instance for this endpoint""" # Do we wish to include extra detail? try: params = self.request.query_params @@ -149,7 +147,7 @@ class ManufacturerPartList(generics.ListCreateAPIView): class ManufacturerPartDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of ManufacturerPart object + """API endpoint for detail view of ManufacturerPart object. - GET: Retrieve detail view - PATCH: Update object @@ -161,9 +159,7 @@ class ManufacturerPartDetail(generics.RetrieveUpdateDestroyAPIView): class ManufacturerPartAttachmentList(AttachmentMixin, generics.ListCreateAPIView): - """ - API endpoint for listing (and creating) a ManufacturerPartAttachment (file upload). - """ + """API endpoint for listing (and creating) a ManufacturerPartAttachment (file upload).""" queryset = ManufacturerPartAttachment.objects.all() serializer_class = ManufacturerPartAttachmentSerializer @@ -178,24 +174,20 @@ class ManufacturerPartAttachmentList(AttachmentMixin, generics.ListCreateAPIView class ManufacturerPartAttachmentDetail(AttachmentMixin, generics.RetrieveUpdateDestroyAPIView): - """ - Detail endpooint for ManufacturerPartAttachment model - """ + """Detail endpooint for ManufacturerPartAttachment model.""" queryset = ManufacturerPartAttachment.objects.all() serializer_class = ManufacturerPartAttachmentSerializer class ManufacturerPartParameterList(generics.ListCreateAPIView): - """ - API endpoint for list view of ManufacturerPartParamater model. - """ + """API endpoint for list view of ManufacturerPartParamater model.""" queryset = ManufacturerPartParameter.objects.all() serializer_class = ManufacturerPartParameterSerializer def get_serializer(self, *args, **kwargs): - + """Return serializer instance for this endpoint""" # Do we wish to include any extra detail? try: params = self.request.query_params @@ -215,10 +207,7 @@ class ManufacturerPartParameterList(generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def filter_queryset(self, queryset): - """ - Custom filtering for the queryset - """ - + """Custom filtering for the queryset.""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -258,16 +247,14 @@ class ManufacturerPartParameterList(generics.ListCreateAPIView): class ManufacturerPartParameterDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for detail view of ManufacturerPartParameter model - """ + """API endpoint for detail view of ManufacturerPartParameter model.""" queryset = ManufacturerPartParameter.objects.all() serializer_class = ManufacturerPartParameterSerializer class SupplierPartList(generics.ListCreateAPIView): - """ API endpoint for list view of SupplierPart object + """API endpoint for list view of SupplierPart object. - GET: Return list of SupplierPart objects - POST: Create a new SupplierPart object @@ -275,17 +262,8 @@ class SupplierPartList(generics.ListCreateAPIView): queryset = SupplierPart.objects.all() - def get_queryset(self): - - queryset = super().get_queryset() - - return queryset - def filter_queryset(self, queryset): - """ - Custom filtering for the queryset. - """ - + """Custom filtering for the queryset.""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -330,6 +308,7 @@ class SupplierPartList(generics.ListCreateAPIView): return queryset def get_serializer(self, *args, **kwargs): + """Return serializer instance for this endpoint""" # Do we wish to include extra detail? try: @@ -369,7 +348,7 @@ class SupplierPartList(generics.ListCreateAPIView): class SupplierPartDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of SupplierPart object + """API endpoint for detail view of SupplierPart object. - GET: Retrieve detail view - PATCH: Update object @@ -384,7 +363,7 @@ class SupplierPartDetail(generics.RetrieveUpdateDestroyAPIView): class SupplierPriceBreakList(generics.ListCreateAPIView): - """ API endpoint for list view of SupplierPriceBreak object + """API endpoint for list view of SupplierPriceBreak object. - GET: Retrieve list of SupplierPriceBreak objects - POST: Create a new SupplierPriceBreak object @@ -403,9 +382,7 @@ class SupplierPriceBreakList(generics.ListCreateAPIView): class SupplierPriceBreakDetail(generics.RetrieveUpdateDestroyAPIView): - """ - Detail endpoint for SupplierPriceBreak object - """ + """Detail endpoint for SupplierPriceBreak object.""" queryset = SupplierPriceBreak.objects.all() serializer_class = SupplierPriceBreakSerializer diff --git a/InvenTree/company/apps.py b/InvenTree/company/apps.py index a0cd4919cf..feff9a0e0f 100644 --- a/InvenTree/company/apps.py +++ b/InvenTree/company/apps.py @@ -1,12 +1,13 @@ +"""Config for the 'company' app""" + from django.apps import AppConfig class CompanyConfig(AppConfig): + """Config class for the 'company' app""" + name = 'company' def ready(self): - """ - This function is called whenever the Company app is loaded. - """ - + """This function is called whenever the Company app is loaded.""" pass diff --git a/InvenTree/company/forms.py b/InvenTree/company/forms.py index bf15038898..7e0b0c66fb 100644 --- a/InvenTree/company/forms.py +++ b/InvenTree/company/forms.py @@ -1,6 +1,4 @@ -""" -Django Forms for interacting with Company app -""" +"""Django Forms for interacting with Company app.""" import django.forms from django.utils.translation import gettext_lazy as _ @@ -12,9 +10,7 @@ from .models import Company, SupplierPriceBreak class CompanyImageDownloadForm(HelperForm): - """ - Form for downloading an image from a URL - """ + """Form for downloading an image from a URL.""" url = django.forms.URLField( label=_('URL'), @@ -23,6 +19,8 @@ class CompanyImageDownloadForm(HelperForm): ) class Meta: + """Metaclass options.""" + model = Company fields = [ 'url', @@ -30,7 +28,7 @@ class CompanyImageDownloadForm(HelperForm): class EditPriceBreakForm(HelperForm): - """ Form for creating / editing a supplier price break """ + """Form for creating / editing a supplier price break.""" quantity = RoundingDecimalFormField( max_digits=10, @@ -40,6 +38,8 @@ class EditPriceBreakForm(HelperForm): ) class Meta: + """Metaclass options.""" + model = SupplierPriceBreak fields = [ 'part', diff --git a/InvenTree/company/models.py b/InvenTree/company/models.py index 1feac56f67..0af1ec73bc 100644 --- a/InvenTree/company/models.py +++ b/InvenTree/company/models.py @@ -1,6 +1,4 @@ -""" -Company database model definitions -""" +"""Company database model definitions.""" import os @@ -27,7 +25,7 @@ from InvenTree.status_codes import PurchaseOrderStatus def rename_company_image(instance, filename): - """ Function to rename a company image after upload + """Function to rename a company image after upload. Args: instance: Company object @@ -36,7 +34,6 @@ def rename_company_image(instance, filename): Returns: New image filename """ - base = 'company_images' if filename.count('.') > 0: @@ -53,7 +50,8 @@ def rename_company_image(instance, filename): class Company(models.Model): - """ A Company object represents an external company. + """A Company object represents an external company. + It may be a supplier or a customer or a manufacturer (or a combination) - A supplier is a company from which parts can be purchased @@ -79,9 +77,11 @@ class Company(models.Model): @staticmethod def get_api_url(): + """Return the API URL associated with the Company model""" return reverse('api-company-list') class Meta: + """Metaclass defines extra model options""" ordering = ['name', ] constraints = [ UniqueConstraint(fields=['name', 'email'], name='unique_name_email_pair') @@ -150,13 +150,11 @@ class Company(models.Model): @property def currency_code(self): - """ - Return the currency code associated with this company. + """Return the currency code associated with this company. - If the currency code is invalid, use the default currency - If the currency code is not specified, use the default currency """ - code = self.currency if code not in CURRENCIES: @@ -165,103 +163,41 @@ class Company(models.Model): return code def __str__(self): - """ Get string representation of a Company """ + """Get string representation of a Company.""" return "{n} - {d}".format(n=self.name, d=self.description) def get_absolute_url(self): - """ Get the web URL for the detail view for this Company """ + """Get the web URL for the detail view for this Company.""" return reverse('company-detail', kwargs={'pk': self.id}) def get_image_url(self): - """ Return the URL of the image for this company """ - + """Return the URL of the image for this company.""" if self.image: return getMediaUrl(self.image.url) else: return getBlankImage() def get_thumbnail_url(self): - """ Return the URL for the thumbnail image for this Company """ - + """Return the URL for the thumbnail image for this Company.""" if self.image: return getMediaUrl(self.image.thumbnail.url) else: return getBlankThumbnail() - @property - def manufactured_part_count(self): - """ The number of parts manufactured by this company """ - return self.manufactured_parts.count() - - @property - def has_manufactured_parts(self): - return self.manufactured_part_count > 0 - - @property - def supplied_part_count(self): - """ The number of parts supplied by this company """ - return self.supplied_parts.count() - - @property - def has_supplied_parts(self): - """ Return True if this company supplies any parts """ - return self.supplied_part_count > 0 - @property def parts(self): - """ Return SupplierPart objects which are supplied or manufactured by this company """ + """Return SupplierPart objects which are supplied or manufactured by this company.""" return SupplierPart.objects.filter(Q(supplier=self.id) | Q(manufacturer_part__manufacturer=self.id)) - @property - def part_count(self): - """ The number of parts manufactured (or supplied) by this Company """ - return self.parts.count() - - @property - def has_parts(self): - return self.part_count > 0 - @property def stock_items(self): - """ Return a list of all stock items supplied or manufactured by this company """ + """Return a list of all stock items supplied or manufactured by this company.""" stock = apps.get_model('stock', 'StockItem') return stock.objects.filter(Q(supplier_part__supplier=self.id) | Q(supplier_part__manufacturer_part__manufacturer=self.id)).all() - @property - def stock_count(self): - """ Return the number of stock items supplied or manufactured by this company """ - return self.stock_items.count() - - def outstanding_purchase_orders(self): - """ Return purchase orders which are 'outstanding' """ - return self.purchase_orders.filter(status__in=PurchaseOrderStatus.OPEN) - - def pending_purchase_orders(self): - """ Return purchase orders which are PENDING (not yet issued) """ - return self.purchase_orders.filter(status=PurchaseOrderStatus.PENDING) - - def closed_purchase_orders(self): - """ Return purchase orders which are not 'outstanding' - - - Complete - - Failed / lost - - Returned - """ - - return self.purchase_orders.exclude(status__in=PurchaseOrderStatus.OPEN) - - def complete_purchase_orders(self): - return self.purchase_orders.filter(status=PurchaseOrderStatus.COMPLETE) - - def failed_purchase_orders(self): - """ Return any purchase orders which were not successful """ - - return self.purchase_orders.filter(status__in=PurchaseOrderStatus.FAILED) - class Contact(models.Model): - """ A Contact represents a person who works at a particular company. - A Company may have zero or more associated Contact objects. + """A Contact represents a person who works at a particular company. A Company may have zero or more associated Contact objects. Attributes: company: Company link for this contact @@ -284,10 +220,7 @@ class Contact(models.Model): class ManufacturerPart(models.Model): - """ Represents a unique part as provided by a Manufacturer - Each ManufacturerPart is identified by a MPN (Manufacturer Part Number) - Each ManufacturerPart is also linked to a Part object. - A Part may be available from multiple manufacturers + """Represents a unique part as provided by a Manufacturer Each ManufacturerPart is identified by a MPN (Manufacturer Part Number) Each ManufacturerPart is also linked to a Part object. A Part may be available from multiple manufacturers. Attributes: part: Link to the master Part @@ -299,9 +232,11 @@ class ManufacturerPart(models.Model): @staticmethod def get_api_url(): + """Return the API URL associated with the ManufacturerPart instance""" return reverse('api-manufacturer-part-list') class Meta: + """Metaclass defines extra model options""" unique_together = ('part', 'manufacturer', 'MPN') part = models.ForeignKey('part.Part', on_delete=models.CASCADE, @@ -346,10 +281,7 @@ class ManufacturerPart(models.Model): @classmethod def create(cls, part, manufacturer, mpn, description, link=None): - """ Check if ManufacturerPart instance does not already exist - then create it - """ - + """Check if ManufacturerPart instance does not already exist then create it.""" manufacturer_part = None try: @@ -364,6 +296,7 @@ class ManufacturerPart(models.Model): return manufacturer_part def __str__(self): + """Format a string representation of a ManufacturerPart""" s = '' if self.manufacturer: @@ -376,15 +309,15 @@ class ManufacturerPart(models.Model): class ManufacturerPartAttachment(InvenTreeAttachment): - """ - Model for storing file attachments against a ManufacturerPart object - """ + """Model for storing file attachments against a ManufacturerPart object.""" @staticmethod def get_api_url(): + """Return the API URL associated with the ManufacturerPartAttachment model""" return reverse('api-manufacturer-part-attachment-list') def getSubdir(self): + """Return the subdirectory where attachment files for the ManufacturerPart model are located""" return os.path.join("manufacturer_part_files", str(self.manufacturer_part.id)) manufacturer_part = models.ForeignKey(ManufacturerPart, on_delete=models.CASCADE, @@ -392,8 +325,7 @@ class ManufacturerPartAttachment(InvenTreeAttachment): class ManufacturerPartParameter(models.Model): - """ - A ManufacturerPartParameter represents a key:value parameter for a MnaufacturerPart. + """A ManufacturerPartParameter represents a key:value parameter for a MnaufacturerPart. This is used to represent parmeters / properties for a particular manufacturer part. @@ -402,9 +334,11 @@ class ManufacturerPartParameter(models.Model): @staticmethod def get_api_url(): + """Return the API URL associated with the ManufacturerPartParameter model""" return reverse('api-manufacturer-part-parameter-list') class Meta: + """Metaclass defines extra model options""" unique_together = ('manufacturer_part', 'name') manufacturer_part = models.ForeignKey( @@ -437,13 +371,14 @@ class ManufacturerPartParameter(models.Model): class SupplierPartManager(models.Manager): - """ Define custom SupplierPart objects manager + """Define custom SupplierPart objects manager. - The main purpose of this manager is to improve database hit as the - SupplierPart model involves A LOT of foreign keys lookups + The main purpose of this manager is to improve database hit as the + SupplierPart model involves A LOT of foreign keys lookups """ def get_queryset(self): + """Prefetch related fields when querying against the SupplierPart model""" # Always prefetch related models return super().get_queryset().prefetch_related( 'part', @@ -453,10 +388,7 @@ class SupplierPartManager(models.Manager): class SupplierPart(models.Model): - """ Represents a unique part as provided by a Supplier - Each SupplierPart is identified by a SKU (Supplier Part Number) - Each SupplierPart is also linked to a Part or ManufacturerPart object. - A Part may be available from multiple suppliers + """Represents a unique part as provided by a Supplier Each SupplierPart is identified by a SKU (Supplier Part Number) Each SupplierPart is also linked to a Part or ManufacturerPart object. A Part may be available from multiple suppliers. Attributes: part: Link to the master Part (Obsolete) @@ -476,13 +408,15 @@ class SupplierPart(models.Model): @staticmethod def get_api_url(): + """Return the API URL associated with the SupplierPart model""" return reverse('api-supplier-part-list') def get_absolute_url(self): + """Return the web URL of the detail view for this SupplierPart""" return reverse('supplier-part-detail', kwargs={'pk': self.id}) def api_instance_filters(self): - + """Return custom API filters for this particular instance""" return { 'manufacturer_part': { 'part': self.part.pk @@ -490,13 +424,17 @@ class SupplierPart(models.Model): } class Meta: + """Metaclass defines extra model options""" unique_together = ('part', 'supplier', 'SKU') # This model was moved from the 'Part' app db_table = 'part_supplierpart' def clean(self): + """Custom clean action for the SupplierPart model: + - Ensure that manufacturer_part.part and part are the same! + """ super().clean() # Ensure that the linked manufacturer_part points to the same part! @@ -508,8 +446,7 @@ class SupplierPart(models.Model): }) def save(self, *args, **kwargs): - """ Overriding save method to connect an existing ManufacturerPart """ - + """Overriding save method to connect an existing ManufacturerPart.""" manufacturer_part = None if all(key in kwargs for key in ('manufacturer', 'MPN')): @@ -593,10 +530,10 @@ class SupplierPart(models.Model): @property def manufacturer_string(self): - """ Format a MPN string for this SupplierPart. + """Format a MPN string for this SupplierPart. + Concatenates manufacture name and part number. """ - items = [] if self.manufacturer_part: @@ -609,26 +546,26 @@ class SupplierPart(models.Model): @property def has_price_breaks(self): + """Return True if this SupplierPart has associated price breaks""" return self.price_breaks.count() > 0 @property def price_breaks(self): - """ Return the associated price breaks in the correct order """ + """Return the associated price breaks in the correct order.""" return self.pricebreaks.order_by('quantity').all() @property def unit_pricing(self): + """Return the single-quantity pricing for this SupplierPart""" return self.get_price(1) - def add_price_break(self, quantity, price): - """ - Create a new price break for this part + def add_price_break(self, quantity, price) -> None: + """Create a new price break for this part. - args: - quantity - Numerical quantity - price - Must be a Money object + Args: + quantity: Numerical quantity + price: Must be a Money object """ - # Check if a price break at that quantity already exists... if self.price_breaks.filter(quantity=quantity, part=self.pk).exists(): return @@ -642,18 +579,14 @@ class SupplierPart(models.Model): get_price = common.models.get_price def open_orders(self): - """ Return a database query for PurchaseOrder line items for this SupplierPart, - limited to purchase orders that are open / outstanding. - """ - + """Return a database query for PurchaseOrder line items for this SupplierPart, limited to purchase orders that are open / outstanding.""" return self.purchase_order_line_items.prefetch_related('order').filter(order__status__in=PurchaseOrderStatus.OPEN) def on_order(self): - """ Return the total quantity of items currently on order. + """Return the total quantity of items currently on order. Subtract partially received stock as appropriate """ - totals = self.open_orders().aggregate(Sum('quantity'), Sum('received')) # Quantity on order @@ -668,15 +601,16 @@ class SupplierPart(models.Model): return max(q - r, 0) def purchase_orders(self): - """ Returns a list of purchase orders relating to this supplier part """ - + """Returns a list of purchase orders relating to this supplier part.""" return [line.order for line in self.purchase_order_line_items.all().prefetch_related('order')] @property def pretty_name(self): + """Format a 'pretty' name for this SupplierPart""" return str(self) def __str__(self): + """Format a string representation of a SupplierPart""" s = '' if self.part.IPN: @@ -692,7 +626,8 @@ class SupplierPart(models.Model): class SupplierPriceBreak(common.models.PriceBreak): - """ Represents a quantity price break for a SupplierPart. + """Represents a quantity price break for a SupplierPart. + - Suppliers can offer discounts at larger quantities - SupplierPart(s) may have zero-or-more associated SupplierPriceBreak(s) @@ -706,6 +641,7 @@ class SupplierPriceBreak(common.models.PriceBreak): @staticmethod def get_api_url(): + """Return the API URL associated with the SupplierPriceBreak model""" return reverse('api-part-supplier-price-list') part = models.ForeignKey(SupplierPart, on_delete=models.CASCADE, related_name='pricebreaks', verbose_name=_('Part'),) @@ -713,10 +649,12 @@ class SupplierPriceBreak(common.models.PriceBreak): updated = models.DateTimeField(auto_now=True, null=True, verbose_name=_('last updated')) class Meta: + """Metaclass defines extra model options""" unique_together = ("part", "quantity") # This model was moved from the 'Part' app db_table = 'part_supplierpricebreak' def __str__(self): + """Format a string representation of a SupplierPriceBreak instance""" return f'{self.part.SKU} - {self.price} @ {self.quantity}' diff --git a/InvenTree/company/serializers.py b/InvenTree/company/serializers.py index e0c34d077f..e1d95637b8 100644 --- a/InvenTree/company/serializers.py +++ b/InvenTree/company/serializers.py @@ -1,6 +1,4 @@ -""" -JSON serializers for Company app -""" +"""JSON serializers for Company app.""" from django.utils.translation import gettext_lazy as _ @@ -21,13 +19,15 @@ from .models import (Company, ManufacturerPart, ManufacturerPartAttachment, class CompanyBriefSerializer(InvenTreeModelSerializer): - """ Serializer for Company object (limited detail) """ + """Serializer for Company object (limited detail)""" url = serializers.CharField(source='get_absolute_url', read_only=True) image = serializers.CharField(source='get_thumbnail_url', read_only=True) class Meta: + """Metaclass options.""" + model = Company fields = [ 'pk', @@ -39,11 +39,11 @@ class CompanyBriefSerializer(InvenTreeModelSerializer): class CompanySerializer(InvenTreeModelSerializer): - """ Serializer for Company object (full detail) """ + """Serializer for Company object (full detail)""" @staticmethod def annotate_queryset(queryset): - + """Annoate the supplied queryset with aggregated information""" # Add count of parts manufactured queryset = queryset.annotate( parts_manufactured=SubqueryCount('manufactured_parts') @@ -71,6 +71,8 @@ class CompanySerializer(InvenTreeModelSerializer): ) class Meta: + """Metaclass options.""" + model = Company fields = [ 'pk', @@ -96,9 +98,7 @@ class CompanySerializer(InvenTreeModelSerializer): class ManufacturerPartSerializer(InvenTreeModelSerializer): - """ - Serializer for ManufacturerPart object - """ + """Serializer for ManufacturerPart object.""" part_detail = PartBriefSerializer(source='part', many=False, read_only=True) @@ -107,7 +107,7 @@ class ManufacturerPartSerializer(InvenTreeModelSerializer): pretty_name = serializers.CharField(read_only=True) def __init__(self, *args, **kwargs): - + """Initialize this serializer with extra detail fields as required""" part_detail = kwargs.pop('part_detail', True) manufacturer_detail = kwargs.pop('manufacturer_detail', True) prettify = kwargs.pop('pretty', False) @@ -126,6 +126,8 @@ class ManufacturerPartSerializer(InvenTreeModelSerializer): manufacturer = serializers.PrimaryKeyRelatedField(queryset=Company.objects.filter(is_manufacturer=True)) class Meta: + """Metaclass options.""" + model = ManufacturerPart fields = [ 'pk', @@ -141,11 +143,11 @@ class ManufacturerPartSerializer(InvenTreeModelSerializer): class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer): - """ - Serializer for the ManufacturerPartAttachment class - """ + """Serializer for the ManufacturerPartAttachment class.""" class Meta: + """Metaclass options.""" + model = ManufacturerPartAttachment fields = [ @@ -164,14 +166,12 @@ class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer): class ManufacturerPartParameterSerializer(InvenTreeModelSerializer): - """ - Serializer for the ManufacturerPartParameter model - """ + """Serializer for the ManufacturerPartParameter model.""" manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', many=False, read_only=True) def __init__(self, *args, **kwargs): - + """Initialize this serializer with extra detail fields as required""" man_detail = kwargs.pop('manufacturer_part_detail', False) super(ManufacturerPartParameterSerializer, self).__init__(*args, **kwargs) @@ -180,6 +180,8 @@ class ManufacturerPartParameterSerializer(InvenTreeModelSerializer): self.fields.pop('manufacturer_part_detail') class Meta: + """Metaclass options.""" + model = ManufacturerPartParameter fields = [ @@ -193,7 +195,7 @@ class ManufacturerPartParameterSerializer(InvenTreeModelSerializer): class SupplierPartSerializer(InvenTreeModelSerializer): - """ Serializer for SupplierPart object """ + """Serializer for SupplierPart object.""" part_detail = PartBriefSerializer(source='part', many=False, read_only=True) @@ -204,7 +206,7 @@ class SupplierPartSerializer(InvenTreeModelSerializer): pretty_name = serializers.CharField(read_only=True) def __init__(self, *args, **kwargs): - + """Initialize this serializer with extra detail fields as required""" part_detail = kwargs.pop('part_detail', True) supplier_detail = kwargs.pop('supplier_detail', True) manufacturer_detail = kwargs.pop('manufacturer_detail', True) @@ -234,6 +236,8 @@ class SupplierPartSerializer(InvenTreeModelSerializer): manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', read_only=True) class Meta: + """Metaclass options.""" + model = SupplierPart fields = [ 'description', @@ -255,8 +259,7 @@ class SupplierPartSerializer(InvenTreeModelSerializer): ] def create(self, validated_data): - """ Extract manufacturer data and process ManufacturerPart """ - + """Extract manufacturer data and process ManufacturerPart.""" # Create SupplierPart supplier_part = super().create(validated_data) @@ -275,7 +278,7 @@ class SupplierPartSerializer(InvenTreeModelSerializer): class SupplierPriceBreakSerializer(InvenTreeModelSerializer): - """ Serializer for SupplierPriceBreak object """ + """Serializer for SupplierPriceBreak object.""" quantity = InvenTreeDecimalField() @@ -292,6 +295,8 @@ class SupplierPriceBreakSerializer(InvenTreeModelSerializer): ) class Meta: + """Metaclass options.""" + model = SupplierPriceBreak fields = [ 'pk', diff --git a/InvenTree/company/test_api.py b/InvenTree/company/test_api.py index e0b54130ff..d3dd12eb81 100644 --- a/InvenTree/company/test_api.py +++ b/InvenTree/company/test_api.py @@ -1,3 +1,5 @@ +"""Unit testing for the company app API functions""" + from django.urls import reverse from rest_framework import status @@ -8,9 +10,7 @@ from .models import Company class CompanyTest(InvenTreeAPITestCase): - """ - Series of tests for the Company DRF API - """ + """Series of tests for the Company DRF API.""" roles = [ 'purchase_order.add', @@ -18,7 +18,7 @@ class CompanyTest(InvenTreeAPITestCase): ] def setUp(self): - + """Perform initialization for the unit test class""" super().setUp() self.acme = Company.objects.create(name='ACME', description='Supplier', is_customer=False, is_supplier=True) @@ -26,6 +26,7 @@ class CompanyTest(InvenTreeAPITestCase): Company.objects.create(name='Sippy Cup Emporium', description='Another supplier') def test_company_list(self): + """Test the list API endpoint for the Company model""" url = reverse('api-company-list') # There should be three companies @@ -45,10 +46,7 @@ class CompanyTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 2) def test_company_detail(self): - """ - Tests for the Company detail endpoint - """ - + """Tests for the Company detail endpoint.""" url = reverse('api-company-detail', kwargs={'pk': self.acme.pk}) response = self.get(url) @@ -71,20 +69,14 @@ class CompanyTest(InvenTreeAPITestCase): self.assertEqual(response.data['currency'], 'NZD') def test_company_search(self): - """ - Test search functionality in company list - """ - + """Test search functionality in company list.""" url = reverse('api-company-list') data = {'search': 'cup'} response = self.get(url, data) self.assertEqual(len(response.data), 2) def test_company_create(self): - """ - Test that we can create a company via the API! - """ - + """Test that we can create a company via the API!""" url = reverse('api-company-list') # Name is required @@ -146,9 +138,7 @@ class CompanyTest(InvenTreeAPITestCase): class ManufacturerTest(InvenTreeAPITestCase): - """ - Series of tests for the Manufacturer DRF API - """ + """Series of tests for the Manufacturer DRF API.""" fixtures = [ 'category', @@ -164,6 +154,7 @@ class ManufacturerTest(InvenTreeAPITestCase): ] def test_manufacturer_part_list(self): + """Test the ManufacturerPart API list functionality""" url = reverse('api-manufacturer-part-list') # There should be three manufacturer parts @@ -191,9 +182,7 @@ class ManufacturerTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 2) def test_manufacturer_part_detail(self): - """ - Tests for the ManufacturerPart detail endpoint - """ + """Tests for the ManufacturerPart detail endpoint.""" url = reverse('api-manufacturer-part-detail', kwargs={'pk': 1}) response = self.get(url) @@ -210,13 +199,14 @@ class ManufacturerTest(InvenTreeAPITestCase): self.assertEqual(response.data['MPN'], 'MPN-TEST-123') def test_manufacturer_part_search(self): - # Test search functionality in manufacturer list + """Test search functionality in manufacturer list""" url = reverse('api-manufacturer-part-list') data = {'search': 'MPN'} response = self.get(url, data) self.assertEqual(len(response.data), 3) def test_supplier_part_create(self): + """Test a SupplierPart can be created via the API""" url = reverse('api-supplier-part-list') # Create a manufacturer part diff --git a/InvenTree/company/test_migrations.py b/InvenTree/company/test_migrations.py index 882d54260f..4c8b0cb586 100644 --- a/InvenTree/company/test_migrations.py +++ b/InvenTree/company/test_migrations.py @@ -1,6 +1,4 @@ -""" -Tests for the company model database migrations -""" +"""Tests for the company model database migrations.""" from django_test_migrations.contrib.unittest_case import MigratorTestCase @@ -8,15 +6,13 @@ from InvenTree import helpers class TestForwardMigrations(MigratorTestCase): + """Unit testing class for testing 'company' app migrations""" migrate_from = ('company', helpers.getOldestMigrationFile('company')) migrate_to = ('company', helpers.getNewestMigrationFile('company')) def prepare(self): - """ - Create some simple Company data, and ensure that it migrates OK - """ - + """Create some simple Company data, and ensure that it migrates OK.""" Company = self.old_state.apps.get_model('company', 'company') Company.objects.create( @@ -26,29 +22,25 @@ class TestForwardMigrations(MigratorTestCase): ) def test_migrations(self): - + """Test the database state after applying all migrations""" Company = self.new_state.apps.get_model('company', 'company') self.assertEqual(Company.objects.count(), 1) class TestManufacturerField(MigratorTestCase): - """ - Tests for migration 0019 which migrates from old 'manufacturer_name' field to new 'manufacturer' field - """ + """Tests for migration 0019 which migrates from old 'manufacturer_name' field to new 'manufacturer' field.""" migrate_from = ('company', '0018_supplierpart_manufacturer') migrate_to = ('company', '0019_auto_20200413_0642') def prepare(self): - """ - Prepare the database by adding some test data 'before' the change: + """Prepare the database by adding some test data 'before' the change: - Part object - Company object (supplier) - SupplierPart object """ - Part = self.old_state.apps.get_model('part', 'part') Company = self.old_state.apps.get_model('company', 'company') SupplierPart = self.old_state.apps.get_model('company', 'supplierpart') @@ -85,10 +77,7 @@ class TestManufacturerField(MigratorTestCase): self.assertEqual(Company.objects.count(), 1) def test_company_objects(self): - """ - Test that the new companies have been created successfully - """ - + """Test that the new companies have been created successfully.""" # Two additional company objects should have been created Company = self.new_state.apps.get_model('company', 'company') self.assertEqual(Company.objects.count(), 3) @@ -108,22 +97,18 @@ class TestManufacturerField(MigratorTestCase): class TestManufacturerPart(MigratorTestCase): - """ - Tests for migration 0034-0037 which added and transitioned to the ManufacturerPart model - """ + """Tests for migration 0034-0037 which added and transitioned to the ManufacturerPart model.""" migrate_from = ('company', '0033_auto_20210410_1528') migrate_to = ('company', '0037_supplierpart_update_3') def prepare(self): - """ - Prepare the database by adding some test data 'before' the change: + """Prepare the database by adding some test data 'before' the change: - Part object - Company object (supplier) - SupplierPart object """ - Part = self.old_state.apps.get_model('part', 'part') Company = self.old_state.apps.get_model('company', 'company') SupplierPart = self.old_state.apps.get_model('company', 'supplierpart') @@ -214,10 +199,7 @@ class TestManufacturerPart(MigratorTestCase): ) def test_manufacturer_part_objects(self): - """ - Test that the new companies have been created successfully - """ - + """Test that the new companies have been created successfully.""" # Check on the SupplierPart objects SupplierPart = self.new_state.apps.get_model('company', 'supplierpart') @@ -238,16 +220,13 @@ class TestManufacturerPart(MigratorTestCase): class TestCurrencyMigration(MigratorTestCase): - """ - Tests for upgrade from basic currency support to django-money - """ + """Tests for upgrade from basic currency support to django-money.""" migrate_from = ('company', '0025_auto_20201110_1001') migrate_to = ('company', '0026_auto_20201110_1011') def prepare(self): - """ - Prepare some data: + """Prepare some data: - A part to buy - A supplier to buy from @@ -255,7 +234,6 @@ class TestCurrencyMigration(MigratorTestCase): - Multiple currency objects - Multiple supplier price breaks """ - Part = self.old_state.apps.get_model('part', 'part') part = Part.objects.create( @@ -293,7 +271,7 @@ class TestCurrencyMigration(MigratorTestCase): self.assertIsNone(pb.price) def test_currency_migration(self): - + """Test database state after applying migrations""" PB = self.new_state.apps.get_model('company', 'supplierpricebreak') for pb in PB.objects.all(): diff --git a/InvenTree/company/test_views.py b/InvenTree/company/test_views.py index a3ecd1651a..b9e95af7e2 100644 --- a/InvenTree/company/test_views.py +++ b/InvenTree/company/test_views.py @@ -1,11 +1,12 @@ -""" Unit tests for Company views (see views.py) """ +"""Unit tests for Company views (see views.py)""" from django.urls import reverse from InvenTree.helpers import InvenTreeTestCase -class CompanyViewTestBase(InvenTreeTestCase): +class CompanyViewTest(InvenTreeTestCase): + """Tests for various 'Company' views.""" fixtures = [ 'category', @@ -18,40 +19,29 @@ class CompanyViewTestBase(InvenTreeTestCase): roles = 'all' - -class CompanyViewTest(CompanyViewTestBase): - """ - Tests for various 'Company' views - """ - def test_company_index(self): - """ Test the company index """ - + """Test the company index.""" response = self.client.get(reverse('company-index')) self.assertEqual(response.status_code, 200) def test_manufacturer_index(self): - """ Test the manufacturer index """ - + """Test the manufacturer index.""" response = self.client.get(reverse('manufacturer-index')) self.assertEqual(response.status_code, 200) def test_customer_index(self): - """ Test the customer index """ - + """Test the customer index.""" response = self.client.get(reverse('customer-index')) self.assertEqual(response.status_code, 200) def test_manufacturer_part_detail_view(self): - """ Test the manufacturer part detail view """ - + """Test the manufacturer part detail view.""" response = self.client.get(reverse('manufacturer-part-detail', kwargs={'pk': 1})) self.assertEqual(response.status_code, 200) self.assertContains(response, 'MPN123') def test_supplier_part_detail_view(self): - """ Test the supplier part detail view """ - + """Test the supplier part detail view.""" response = self.client.get(reverse('supplier-part-detail', kwargs={'pk': 10})) self.assertEqual(response.status_code, 200) self.assertContains(response, 'MPN456-APPEL') diff --git a/InvenTree/company/tests.py b/InvenTree/company/tests.py index 008ac066b2..cb42653afc 100644 --- a/InvenTree/company/tests.py +++ b/InvenTree/company/tests.py @@ -1,3 +1,5 @@ +"""Unit tests for the models in the 'company' app""" + import os from decimal import Decimal @@ -11,6 +13,7 @@ from .models import (Company, Contact, ManufacturerPart, SupplierPart, class CompanySimpleTest(TestCase): + """Unit tests for the Company model""" fixtures = [ 'company', @@ -24,6 +27,7 @@ class CompanySimpleTest(TestCase): ] def setUp(self): + """Perform initialization for the tests in this class""" Company.objects.create(name='ABC Co.', description='Seller of ABC products', website='www.abc-sales.com', @@ -37,15 +41,18 @@ class CompanySimpleTest(TestCase): self.zergm312 = SupplierPart.objects.get(SKU='ZERGM312') def test_company_model(self): + """Tests for the company model data""" c = Company.objects.get(name='ABC Co.') self.assertEqual(c.name, 'ABC Co.') self.assertEqual(str(c), 'ABC Co. - Seller of ABC products') def test_company_url(self): + """Test the detail URL for a company""" c = Company.objects.get(pk=1) self.assertEqual(c.get_absolute_url(), '/company/1/') def test_image_renamer(self): + """Test the company image upload functionality""" c = Company.objects.get(pk=1) rn = rename_company_image(c, 'test.png') self.assertEqual(rn, 'company_images' + os.path.sep + 'company_1_img.png') @@ -53,23 +60,8 @@ class CompanySimpleTest(TestCase): rn = rename_company_image(c, 'test2') self.assertEqual(rn, 'company_images' + os.path.sep + 'company_1_img') - def test_part_count(self): - - acme = Company.objects.get(pk=1) - appel = Company.objects.get(pk=2) - zerg = Company.objects.get(pk=3) - - self.assertTrue(acme.has_parts) - self.assertEqual(acme.supplied_part_count, 4) - - self.assertTrue(appel.has_parts) - self.assertEqual(appel.supplied_part_count, 4) - - self.assertTrue(zerg.has_parts) - self.assertEqual(zerg.supplied_part_count, 2) - def test_price_breaks(self): - + """Unit tests for price breaks""" self.assertTrue(self.acme0001.has_price_breaks) self.assertTrue(self.acme0002.has_price_breaks) self.assertTrue(self.zergm312.has_price_breaks) @@ -81,8 +73,7 @@ class CompanySimpleTest(TestCase): self.assertEqual(self.zergm312.price_breaks.count(), 2) def test_quantity_pricing(self): - """ Simple test for quantity pricing """ - + """Simple test for quantity pricing.""" p = self.acme0001.get_price self.assertEqual(p(1), 10) self.assertEqual(p(4), 40) @@ -99,6 +90,7 @@ class CompanySimpleTest(TestCase): self.assertEqual(p(55), 68.75) def test_part_pricing(self): + """Unit tests for supplier part pricing""" m2x4 = Part.objects.get(name='M2x4 LPHS') self.assertEqual(m2x4.get_price_info(5.5), "38.5 - 41.25") @@ -116,10 +108,7 @@ class CompanySimpleTest(TestCase): self.assertIsNotNone(m3x12.get_price_info(50)) def test_currency_validation(self): - """ - Test validation for currency selection - """ - + """Test validation for currency selection.""" # Create a company with a valid currency code (should pass) company = Company.objects.create( name='Test', @@ -141,8 +130,10 @@ class CompanySimpleTest(TestCase): class ContactSimpleTest(TestCase): + """Unit tests for the Contact model""" def setUp(self): + """Initialization for the tests in this class""" # Create a simple company self.c = Company.objects.create(name='Test Corp.', description='We make stuff good') @@ -152,15 +143,18 @@ class ContactSimpleTest(TestCase): Contact.objects.create(name='Sally Smith', company=self.c) def test_exists(self): + """Test that contacts exist""" self.assertEqual(Contact.objects.count(), 3) def test_delete(self): + """Test deletion of a Contact instance""" # Remove the parent company Company.objects.get(pk=self.c.pk).delete() self.assertEqual(Contact.objects.count(), 0) class ManufacturerPartSimpleTest(TestCase): + """Unit tests for the ManufacturerPart model""" fixtures = [ 'category', @@ -171,6 +165,8 @@ class ManufacturerPartSimpleTest(TestCase): ] def setUp(self): + """Initialization for the unit tests in this class""" + # Create a manufacturer part self.part = Part.objects.get(pk=1) manufacturer = Company.objects.get(pk=1) @@ -193,6 +189,7 @@ class ManufacturerPartSimpleTest(TestCase): supplier_part.save() def test_exists(self): + """That that a ManufacturerPart has been created""" self.assertEqual(ManufacturerPart.objects.count(), 4) # Check that manufacturer part was created from supplier part creation @@ -200,7 +197,7 @@ class ManufacturerPartSimpleTest(TestCase): self.assertEqual(manufacturer_parts.count(), 1) def test_delete(self): - # Remove a part + """Test deletion of a ManufacturerPart""" Part.objects.get(pk=self.part.id).delete() # Check that ManufacturerPart was deleted self.assertEqual(ManufacturerPart.objects.count(), 3) diff --git a/InvenTree/company/urls.py b/InvenTree/company/urls.py index a7b9584c54..71adc1be32 100644 --- a/InvenTree/company/urls.py +++ b/InvenTree/company/urls.py @@ -1,6 +1,4 @@ -""" -URL lookup for Company app -""" +"""URL lookup for Company app.""" from django.urls import include, re_path diff --git a/InvenTree/company/views.py b/InvenTree/company/views.py index 6a21d1e57a..b967d8e0d2 100644 --- a/InvenTree/company/views.py +++ b/InvenTree/company/views.py @@ -1,6 +1,4 @@ -""" -Django views for interacting with Company app -""" +"""Django views for interacting with Company app.""" import io @@ -20,8 +18,7 @@ from .models import Company, ManufacturerPart, SupplierPart class CompanyIndex(InvenTreeRoleMixin, ListView): - """ View for displaying list of companies - """ + """View for displaying list of companies.""" model = Company template_name = 'company/index.html' @@ -30,6 +27,7 @@ class CompanyIndex(InvenTreeRoleMixin, ListView): permission_required = 'company.view_company' def get_context_data(self, **kwargs): + """Add extra context data to the company index page""" ctx = super().get_context_data(**kwargs) @@ -80,7 +78,7 @@ class CompanyIndex(InvenTreeRoleMixin, ListView): return ctx def get_queryset(self): - """ Retrieve the Company queryset based on HTTP request parameters. + """Retrieve the Company queryset based on HTTP request parameters. - supplier: Filter by supplier - customer: Filter by customer @@ -97,23 +95,16 @@ class CompanyIndex(InvenTreeRoleMixin, ListView): class CompanyDetail(InvenTreePluginViewMixin, DetailView): - """ Detail view for Company object """ + """Detail view for Company object.""" context_obect_name = 'company' template_name = 'company/detail.html' queryset = Company.objects.all() model = Company permission_required = 'company.view_company' - def get_context_data(self, **kwargs): - ctx = super().get_context_data(**kwargs) - - return ctx - class CompanyImageDownloadFromURL(AjaxUpdateView): - """ - View for downloading an image from a provided URL - """ + """View for downloading an image from a provided URL.""" model = Company ajax_template_name = 'image_download.html' @@ -121,9 +112,7 @@ class CompanyImageDownloadFromURL(AjaxUpdateView): ajax_form_title = _('Download Image') def validate(self, company, form): - """ - Validate that the image data are correct - """ + """Validate that the image data are correct.""" # First ensure that the normal validation routines pass if not form.is_valid(): return @@ -167,9 +156,7 @@ class CompanyImageDownloadFromURL(AjaxUpdateView): return def save(self, company, form, **kwargs): - """ - Save the downloaded image to the company - """ + """Save the downloaded image to the company.""" fmt = self.image.format if not fmt: @@ -189,28 +176,18 @@ class CompanyImageDownloadFromURL(AjaxUpdateView): class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView): - """ Detail view for ManufacturerPart """ + """Detail view for ManufacturerPart.""" model = ManufacturerPart template_name = 'company/manufacturer_part_detail.html' context_object_name = 'part' queryset = ManufacturerPart.objects.all() permission_required = 'purchase_order.view' - def get_context_data(self, **kwargs): - ctx = super().get_context_data(**kwargs) - - return ctx - class SupplierPartDetail(InvenTreePluginViewMixin, DetailView): - """ Detail view for SupplierPart """ + """Detail view for SupplierPart.""" model = SupplierPart template_name = 'company/supplier_part_detail.html' context_object_name = 'part' queryset = SupplierPart.objects.all() permission_required = 'purchase_order.view' - - def get_context_data(self, **kwargs): - ctx = super().get_context_data(**kwargs) - - return ctx diff --git a/InvenTree/gunicorn.conf.py b/InvenTree/gunicorn.conf.py index 0e5eee75f9..b11999aadd 100644 --- a/InvenTree/gunicorn.conf.py +++ b/InvenTree/gunicorn.conf.py @@ -1,3 +1,5 @@ +"""Gunicorn configuration script for InvenTree web server""" + import multiprocessing bind = "0.0.0.0:8000" diff --git a/InvenTree/label/admin.py b/InvenTree/label/admin.py index 112ac55a4b..647017eeb3 100644 --- a/InvenTree/label/admin.py +++ b/InvenTree/label/admin.py @@ -1,10 +1,12 @@ +"""Admin functionality for the 'label' app""" + from django.contrib import admin from .models import PartLabel, StockItemLabel, StockLocationLabel class LabelAdmin(admin.ModelAdmin): - + """Admin class for the various label models""" list_display = ('name', 'description', 'label', 'filters', 'enabled') diff --git a/InvenTree/label/api.py b/InvenTree/label/api.py index 7988bfb32f..2816572bd5 100644 --- a/InvenTree/label/api.py +++ b/InvenTree/label/api.py @@ -1,3 +1,5 @@ +"""API functionality for the 'label' app""" + from django.conf import settings from django.core.exceptions import FieldError, ValidationError from django.http import HttpResponse, JsonResponse @@ -21,9 +23,7 @@ from .serializers import (PartLabelSerializer, StockItemLabelSerializer, class LabelListView(generics.ListAPIView): - """ - Generic API class for label templates - """ + """Generic API class for label templates.""" filter_backends = [ DjangoFilterBackend, @@ -41,13 +41,11 @@ class LabelListView(generics.ListAPIView): class LabelPrintMixin: - """ - Mixin for printing labels - """ + """Mixin for printing labels.""" def get_plugin(self, request): - """ - Return the label printing plugin associated with this request. + """Return the label printing plugin associated with this request. + This is provided in the url, e.g. ?plugin=myprinter Requires: @@ -56,7 +54,6 @@ class LabelPrintMixin: - matching plugin implements the 'labels' mixin - matching plugin is enabled """ - if not settings.PLUGINS_ENABLED: return None # pragma: no cover @@ -80,10 +77,7 @@ class LabelPrintMixin: raise NotFound(f"Plugin '{plugin_key}' not found") def print(self, request, items_to_print): - """ - Print this label template against a number of pre-validated items - """ - + """Print this label template against a number of pre-validated items.""" # Check the request to determine if the user has selected a label printing plugin plugin = self.get_plugin(request) @@ -119,26 +113,20 @@ class LabelPrintMixin: label_name += ".pdf" if plugin is not None: - """ - Label printing is to be handled by a plugin, - rather than being exported to PDF. + """Label printing is to be handled by a plugin, rather than being exported to PDF. In this case, we do the following: - Individually generate each label, exporting as an image file - Pass all the images through to the label printing plugin - Return a JSON response indicating that the printing has been offloaded - """ # Label instance label_instance = self.get_object() for idx, output in enumerate(outputs): - """ - For each output, we generate a temporary image file, - which will then get sent to the printer - """ + """For each output, we generate a temporary image file, which will then get sent to the printer.""" # Generate PDF data for the label pdf = output.get_document().write_pdf() @@ -159,20 +147,14 @@ class LabelPrintMixin: }) elif debug_mode: - """ - Contatenate all rendered templates into a single HTML string, - and return the string as a HTML response. - """ + """Contatenate all rendered templates into a single HTML string, and return the string as a HTML response.""" html = "\n".join(outputs) return HttpResponse(html) else: - """ - Concatenate all rendered pages into a single PDF object, - and return the resulting document! - """ + """Concatenate all rendered pages into a single PDF object, and return the resulting document!""" pages = [] @@ -198,15 +180,10 @@ class LabelPrintMixin: class StockItemLabelMixin: - """ - Mixin for extracting stock items from query params - """ + """Mixin for extracting stock items from query params.""" def get_items(self): - """ - Return a list of requested stock items - """ - + """Return a list of requested stock items.""" items = [] params = self.request.query_params @@ -231,25 +208,20 @@ class StockItemLabelMixin: class StockItemLabelList(LabelListView, StockItemLabelMixin): - """ - API endpoint for viewing list of StockItemLabel objects. + """API endpoint for viewing list of StockItemLabel objects. Filterable by: - enabled: Filter by enabled / disabled status - item: Filter by single stock item - items: Filter by list of stock items - """ queryset = StockItemLabel.objects.all() serializer_class = StockItemLabelSerializer def filter_queryset(self, queryset): - """ - Filter the StockItem label queryset. - """ - + """Filter the StockItem label queryset.""" queryset = super().filter_queryset(queryset) # List of StockItem objects to match against @@ -304,42 +276,30 @@ class StockItemLabelList(LabelListView, StockItemLabelMixin): class StockItemLabelDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for a single StockItemLabel object - """ + """API endpoint for a single StockItemLabel object.""" queryset = StockItemLabel.objects.all() serializer_class = StockItemLabelSerializer class StockItemLabelPrint(generics.RetrieveAPIView, StockItemLabelMixin, LabelPrintMixin): - """ - API endpoint for printing a StockItemLabel object - """ + """API endpoint for printing a StockItemLabel object.""" queryset = StockItemLabel.objects.all() serializer_class = StockItemLabelSerializer def get(self, request, *args, **kwargs): - """ - Check if valid stock item(s) have been provided. - """ - + """Check if valid stock item(s) have been provided.""" items = self.get_items() return self.print(request, items) class StockLocationLabelMixin: - """ - Mixin for extracting stock locations from query params - """ + """Mixin for extracting stock locations from query params.""" def get_locations(self): - """ - Return a list of requested stock locations - """ - + """Return a list of requested stock locations.""" locations = [] params = self.request.query_params @@ -364,8 +324,7 @@ class StockLocationLabelMixin: class StockLocationLabelList(LabelListView, StockLocationLabelMixin): - """ - API endpoint for viewiing list of StockLocationLabel objects. + """API endpoint for viewiing list of StockLocationLabel objects. Filterable by: @@ -378,10 +337,7 @@ class StockLocationLabelList(LabelListView, StockLocationLabelMixin): serializer_class = StockLocationLabelSerializer def filter_queryset(self, queryset): - """ - Filter the StockLocationLabel queryset - """ - + """Filter the StockLocationLabel queryset.""" queryset = super().filter_queryset(queryset) # List of StockLocation objects to match against @@ -436,39 +392,30 @@ class StockLocationLabelList(LabelListView, StockLocationLabelMixin): class StockLocationLabelDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for a single StockLocationLabel object - """ + """API endpoint for a single StockLocationLabel object.""" queryset = StockLocationLabel.objects.all() serializer_class = StockLocationLabelSerializer class StockLocationLabelPrint(generics.RetrieveAPIView, StockLocationLabelMixin, LabelPrintMixin): - """ - API endpoint for printing a StockLocationLabel object - """ + """API endpoint for printing a StockLocationLabel object.""" queryset = StockLocationLabel.objects.all() seiralizer_class = StockLocationLabelSerializer def get(self, request, *args, **kwargs): - + """Print labels based on the request parameters""" locations = self.get_locations() return self.print(request, locations) class PartLabelMixin: - """ - Mixin for extracting Part objects from query parameters - """ + """Mixin for extracting Part objects from query parameters.""" def get_parts(self): - """ - Return a list of requested Part objects - """ - + """Return a list of requested Part objects.""" parts = [] params = self.request.query_params @@ -491,15 +438,13 @@ class PartLabelMixin: class PartLabelList(LabelListView, PartLabelMixin): - """ - API endpoint for viewing list of PartLabel objects - """ + """API endpoint for viewing list of PartLabel objects.""" queryset = PartLabel.objects.all() serializer_class = PartLabelSerializer def filter_queryset(self, queryset): - + """Custom queryset filtering for the PartLabel list""" queryset = super().filter_queryset(queryset) parts = self.get_parts() @@ -539,27 +484,20 @@ class PartLabelList(LabelListView, PartLabelMixin): class PartLabelDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for a single PartLabel object - """ + """API endpoint for a single PartLabel object.""" queryset = PartLabel.objects.all() serializer_class = PartLabelSerializer class PartLabelPrint(generics.RetrieveAPIView, PartLabelMixin, LabelPrintMixin): - """ - API endpoint for printing a PartLabel object - """ + """API endpoint for printing a PartLabel object.""" queryset = PartLabel.objects.all() serializer_class = PartLabelSerializer def get(self, request, *args, **kwargs): - """ - Check if valid part(s) have been provided - """ - + """Check if valid part(s) have been provided.""" parts = self.get_parts() return self.print(request, parts) diff --git a/InvenTree/label/apps.py b/InvenTree/label/apps.py index b26b7fb692..ab37e2f275 100644 --- a/InvenTree/label/apps.py +++ b/InvenTree/label/apps.py @@ -1,3 +1,5 @@ +"""label app specification""" + import hashlib import logging import os @@ -14,10 +16,7 @@ logger = logging.getLogger("inventree") def hashFile(filename): - """ - Calculate the MD5 hash of a file - """ - + """Calculate the MD5 hash of a file.""" md5 = hashlib.md5() with open(filename, 'rb') as f: @@ -28,20 +27,17 @@ def hashFile(filename): class LabelConfig(AppConfig): + """App configuration class for the 'label' app""" + name = 'label' def ready(self): - """ - This function is called whenever the label app is loaded - """ - + """This function is called whenever the label app is loaded.""" if canAppAccessDatabase(): self.create_labels() # pragma: no cover def create_labels(self): - """ - Create all default templates - """ + """Create all default templates.""" # Test if models are ready try: from .models import StockLocationLabel @@ -56,11 +52,7 @@ class LabelConfig(AppConfig): self.create_part_labels() def create_stock_item_labels(self): - """ - Create database entries for the default StockItemLabel templates, - if they do not already exist - """ - + """Create database entries for the default StockItemLabel templates, if they do not already exist.""" from .models import StockItemLabel src_dir = os.path.join( @@ -139,11 +131,7 @@ class LabelConfig(AppConfig): ) def create_stock_location_labels(self): - """ - Create database entries for the default StockItemLocation templates, - if they do not already exist - """ - + """Create database entries for the default StockItemLocation templates, if they do not already exist.""" from .models import StockLocationLabel src_dir = os.path.join( @@ -229,11 +217,7 @@ class LabelConfig(AppConfig): ) def create_part_labels(self): - """ - Create database entries for the default PartLabel templates, - if they do not already exist. - """ - + """Create database entries for the default PartLabel templates, if they do not already exist.""" from .models import PartLabel src_dir = os.path.join( diff --git a/InvenTree/label/models.py b/InvenTree/label/models.py index 4da42d73a9..03dcd7468b 100644 --- a/InvenTree/label/models.py +++ b/InvenTree/label/models.py @@ -1,6 +1,4 @@ -""" -Label printing models -""" +"""Label printing models.""" import datetime import logging @@ -32,55 +30,52 @@ logger = logging.getLogger("inventree") def rename_label(instance, filename): - """ Place the label file into the correct subdirectory """ - + """Place the label file into the correct subdirectory.""" filename = os.path.basename(filename) return os.path.join('label', 'template', instance.SUBDIR, filename) def validate_stock_item_filters(filters): - + """Validate query filters for the StockItemLabel model""" filters = validateFilterString(filters, model=stock.models.StockItem) return filters def validate_stock_location_filters(filters): - + """Validate query filters for the StockLocationLabel model""" filters = validateFilterString(filters, model=stock.models.StockLocation) return filters def validate_part_filters(filters): - + """Validate query filters for the PartLabel model""" filters = validateFilterString(filters, model=part.models.Part) return filters class WeasyprintLabelMixin(WeasyTemplateResponseMixin): - """ - Class for rendering a label to a PDF - """ + """Class for rendering a label to a PDF.""" pdf_filename = 'label.pdf' pdf_attachment = True def __init__(self, request, template, **kwargs): - + """Initialize a label mixin with certain properties""" self.request = request self.template_name = template self.pdf_filename = kwargs.get('filename', 'label.pdf') class LabelTemplate(models.Model): - """ - Base class for generic, filterable labels. - """ + """Base class for generic, filterable labels.""" class Meta: + """Metaclass options. Abstract ensures no database table is created.""" + abstract = True # Each class of label files will be stored in a separate subdirectory @@ -91,9 +86,11 @@ class LabelTemplate(models.Model): @property def template(self): + """Return the file path of the template associated with this label instance""" return self.label.path def __str__(self): + """Format a string representation of a label instance""" return "{n} - {d}".format( n=self.name, d=self.description @@ -150,11 +147,10 @@ class LabelTemplate(models.Model): @property def template_name(self): - """ - Returns the file system path to the template file. + """Returns the file system path to the template file. + Required for passing the file to an external process """ - template = self.label.name template = template.replace('/', os.path.sep) template = template.replace('\\', os.path.sep) @@ -164,19 +160,14 @@ class LabelTemplate(models.Model): return template def get_context_data(self, request): - """ - Supply custom context data to the template for rendering. + """Supply custom context data to the template for rendering. Note: Override this in any subclass """ - return {} # pragma: no cover def generate_filename(self, request, **kwargs): - """ - Generate a filename for this label - """ - + """Generate a filename for this label.""" template_string = Template(self.filename_pattern) ctx = self.context(request) @@ -186,10 +177,7 @@ class LabelTemplate(models.Model): return template_string.render(context) def context(self, request): - """ - Provides context data to the template. - """ - + """Provides context data to the template.""" context = self.get_context_data(request) # Add "basic" context data which gets passed to every label @@ -204,21 +192,17 @@ class LabelTemplate(models.Model): return context def render_as_string(self, request, **kwargs): - """ - Render the label to a HTML string + """Render the label to a HTML string. Useful for debug mode (viewing generated code) """ - return render_to_string(self.template_name, self.context(request), request) def render(self, request, **kwargs): - """ - Render the label template to a PDF file + """Render the label template to a PDF file. Uses django-weasyprint plugin to render HTML template """ - wp = WeasyprintLabelMixin( request, self.template_name, @@ -235,12 +219,11 @@ class LabelTemplate(models.Model): class StockItemLabel(LabelTemplate): - """ - Template for printing StockItem labels - """ + """Template for printing StockItem labels.""" @staticmethod def get_api_url(): + """Return the API URL associated with the StockItemLabel model""" return reverse('api-stockitem-label-list') # pragma: no cover SUBDIR = "stockitem" @@ -255,10 +238,7 @@ class StockItemLabel(LabelTemplate): ) def get_context_data(self, request): - """ - Generate context data for each provided StockItem - """ - + """Generate context data for each provided StockItem.""" stock_item = self.object_to_print return { @@ -279,12 +259,11 @@ class StockItemLabel(LabelTemplate): class StockLocationLabel(LabelTemplate): - """ - Template for printing StockLocation labels - """ + """Template for printing StockLocation labels.""" @staticmethod def get_api_url(): + """Return the API URL associated with the StockLocationLabel model""" return reverse('api-stocklocation-label-list') # pragma: no cover SUBDIR = "stocklocation" @@ -298,10 +277,7 @@ class StockLocationLabel(LabelTemplate): ) def get_context_data(self, request): - """ - Generate context data for each provided StockLocation - """ - + """Generate context data for each provided StockLocation.""" location = self.object_to_print return { @@ -311,12 +287,11 @@ class StockLocationLabel(LabelTemplate): class PartLabel(LabelTemplate): - """ - Template for printing Part labels - """ + """Template for printing Part labels.""" @staticmethod def get_api_url(): + """Return the API url associated with the PartLabel model""" return reverse('api-part-label-list') # pragma: no cover SUBDIR = 'part' @@ -331,10 +306,7 @@ class PartLabel(LabelTemplate): ) def get_context_data(self, request): - """ - Generate context data for each provided Part object - """ - + """Generate context data for each provided Part object.""" part = self.object_to_print return { diff --git a/InvenTree/label/serializers.py b/InvenTree/label/serializers.py index 5428424572..b250b6f441 100644 --- a/InvenTree/label/serializers.py +++ b/InvenTree/label/serializers.py @@ -1,3 +1,5 @@ +"""API serializers for the label app""" + from InvenTree.serializers import (InvenTreeAttachmentSerializerField, InvenTreeModelSerializer) @@ -5,13 +7,13 @@ from .models import PartLabel, StockItemLabel, StockLocationLabel class StockItemLabelSerializer(InvenTreeModelSerializer): - """ - Serializes a StockItemLabel object. - """ + """Serializes a StockItemLabel object.""" label = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass options.""" + model = StockItemLabel fields = [ 'pk', @@ -24,13 +26,13 @@ class StockItemLabelSerializer(InvenTreeModelSerializer): class StockLocationLabelSerializer(InvenTreeModelSerializer): - """ - Serializes a StockLocationLabel object - """ + """Serializes a StockLocationLabel object.""" label = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass options.""" + model = StockLocationLabel fields = [ 'pk', @@ -43,13 +45,13 @@ class StockLocationLabelSerializer(InvenTreeModelSerializer): class PartLabelSerializer(InvenTreeModelSerializer): - """ - Serializes a PartLabel object - """ + """Serializes a PartLabel object.""" label = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass options.""" + model = PartLabel fields = [ 'pk', diff --git a/InvenTree/label/test_api.py b/InvenTree/label/test_api.py index d9b0bafc9a..25277c964b 100644 --- a/InvenTree/label/test_api.py +++ b/InvenTree/label/test_api.py @@ -1,4 +1,4 @@ -# Tests for labels +"""Unit tests for label API""" from django.urls import reverse @@ -6,9 +6,7 @@ from InvenTree.api_tester import InvenTreeAPITestCase class TestReportTests(InvenTreeAPITestCase): - """ - Tests for the StockItem TestReport templates - """ + """Tests for the StockItem TestReport templates.""" fixtures = [ 'category', @@ -24,12 +22,8 @@ class TestReportTests(InvenTreeAPITestCase): list_url = reverse('api-stockitem-testreport-list') - def setUp(self): - - super().setUp() - def do_list(self, filters={}): - + """Helper function to request list of labels with provided filters""" response = self.client.get(self.list_url, filters, format='json') self.assertEqual(response.status_code, 200) @@ -37,7 +31,7 @@ class TestReportTests(InvenTreeAPITestCase): return response.data def test_list(self): - + """Test the API list endpoint""" response = self.do_list() # TODO - Add some report templates to the fixtures diff --git a/InvenTree/label/tests.py b/InvenTree/label/tests.py index f94efafb84..7de7fd6c80 100644 --- a/InvenTree/label/tests.py +++ b/InvenTree/label/tests.py @@ -1,4 +1,4 @@ -# Tests for labels +"""Tests for labels""" import os @@ -16,6 +16,7 @@ from .models import PartLabel, StockItemLabel, StockLocationLabel class LabelTest(InvenTreeAPITestCase): + """Unit test class for label models""" fixtures = [ 'category', @@ -25,15 +26,12 @@ class LabelTest(InvenTreeAPITestCase): ] def setUp(self) -> None: + """Ensure that some label instances exist as part of init routine""" super().setUp() - # ensure the labels were created apps.get_app_config('label').create_labels() def test_default_labels(self): - """ - Test that the default label templates are copied across - """ - + """Test that the default label templates are copied across.""" labels = StockItemLabel.objects.all() self.assertTrue(labels.count() > 0) @@ -43,10 +41,7 @@ class LabelTest(InvenTreeAPITestCase): self.assertTrue(labels.count() > 0) def test_default_files(self): - """ - Test that label files exist in the MEDIA directory - """ - + """Test that label files exist in the MEDIA directory.""" item_dir = os.path.join( settings.MEDIA_ROOT, 'label', @@ -70,10 +65,7 @@ class LabelTest(InvenTreeAPITestCase): self.assertTrue(len(files) > 0) def test_filters(self): - """ - Test the label filters - """ - + """Test the label filters.""" filter_string = "part__pk=10" filters = validateFilterString(filter_string, model=StockItem) @@ -86,8 +78,7 @@ class LabelTest(InvenTreeAPITestCase): validateFilterString(bad_filter_string, model=StockItem) def test_label_rendering(self): - """Test label rendering""" - + """Test label rendering.""" labels = PartLabel.objects.all() part = Part.objects.first() diff --git a/InvenTree/label/views.py b/InvenTree/label/views.py deleted file mode 100644 index 60f00ef0ef..0000000000 --- a/InvenTree/label/views.py +++ /dev/null @@ -1 +0,0 @@ -# Create your views here. diff --git a/InvenTree/manage.py b/InvenTree/manage.py index 959fc2787e..9e50f6cd15 100755 --- a/InvenTree/manage.py +++ b/InvenTree/manage.py @@ -1,4 +1,5 @@ -#!/usr/bin/env python +"""InvenTree / django management commands""" + import os import sys diff --git a/InvenTree/order/__init__.py b/InvenTree/order/__init__.py index 896e9facd5..0ef70f99b7 100644 --- a/InvenTree/order/__init__.py +++ b/InvenTree/order/__init__.py @@ -1,3 +1 @@ -""" -The Order module is responsible for managing Orders -""" +"""The Order module is responsible for managing Orders.""" diff --git a/InvenTree/order/admin.py b/InvenTree/order/admin.py index a46fe62532..be953de701 100644 --- a/InvenTree/order/admin.py +++ b/InvenTree/order/admin.py @@ -1,3 +1,5 @@ +"""Admin functionality for the 'order' app""" + from django.contrib import admin import import_export.widgets as widgets @@ -13,6 +15,7 @@ from .models import (PurchaseOrder, PurchaseOrderExtraLine, # region general classes class GeneralExtraLineAdmin: + """Admin class template for the 'ExtraLineItem' models""" list_display = ( 'order', 'quantity', @@ -29,6 +32,7 @@ class GeneralExtraLineAdmin: class GeneralExtraLineMeta: + """Metaclass template for the 'ExtraLineItem' models""" skip_unchanged = True report_skipped = False clean_model_instances = True @@ -36,11 +40,13 @@ class GeneralExtraLineMeta: class PurchaseOrderLineItemInlineAdmin(admin.StackedInline): + """Inline admin class for the PurchaseOrderLineItem model""" model = PurchaseOrderLineItem extra = 0 class PurchaseOrderAdmin(ImportExportModelAdmin): + """Admin class for the PurchaseOrder model""" exclude = [ 'reference_int', @@ -68,6 +74,7 @@ class PurchaseOrderAdmin(ImportExportModelAdmin): class SalesOrderAdmin(ImportExportModelAdmin): + """Admin class for the SalesOrder model""" exclude = [ 'reference_int', @@ -91,9 +98,7 @@ class SalesOrderAdmin(ImportExportModelAdmin): class PurchaseOrderResource(ModelResource): - """ - Class for managing import / export of PurchaseOrder data - """ + """Class for managing import / export of PurchaseOrder data.""" # Add number of line items line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True) @@ -102,6 +107,7 @@ class PurchaseOrderResource(ModelResource): overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True) class Meta: + """Metaclass""" model = PurchaseOrder skip_unchanged = True clean_model_instances = True @@ -111,7 +117,7 @@ class PurchaseOrderResource(ModelResource): class PurchaseOrderLineItemResource(ModelResource): - """ Class for managing import / export of PurchaseOrderLineItem data """ + """Class for managing import / export of PurchaseOrderLineItem data.""" part_name = Field(attribute='part__part__name', readonly=True) @@ -122,6 +128,7 @@ class PurchaseOrderLineItemResource(ModelResource): SKU = Field(attribute='part__SKU', readonly=True) class Meta: + """Metaclass""" model = PurchaseOrderLineItem skip_unchanged = True report_skipped = False @@ -129,16 +136,16 @@ class PurchaseOrderLineItemResource(ModelResource): class PurchaseOrderExtraLineResource(ModelResource): - """ Class for managing import / export of PurchaseOrderExtraLine data """ + """Class for managing import / export of PurchaseOrderExtraLine data.""" class Meta(GeneralExtraLineMeta): + """Metaclass options.""" + model = PurchaseOrderExtraLine class SalesOrderResource(ModelResource): - """ - Class for managing import / export of SalesOrder data - """ + """Class for managing import / export of SalesOrder data.""" # Add number of line items line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True) @@ -147,6 +154,7 @@ class SalesOrderResource(ModelResource): overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True) class Meta: + """Metaclass options""" model = SalesOrder skip_unchanged = True clean_model_instances = True @@ -156,9 +164,7 @@ class SalesOrderResource(ModelResource): class SalesOrderLineItemResource(ModelResource): - """ - Class for managing import / export of SalesOrderLineItem data - """ + """Class for managing import / export of SalesOrderLineItem data.""" part_name = Field(attribute='part__name', readonly=True) @@ -169,17 +175,17 @@ class SalesOrderLineItemResource(ModelResource): fulfilled = Field(attribute='fulfilled_quantity', readonly=True) def dehydrate_sale_price(self, item): - """ - Return a string value of the 'sale_price' field, rather than the 'Money' object. + """Return a string value of the 'sale_price' field, rather than the 'Money' object. + Ref: https://github.com/inventree/InvenTree/issues/2207 """ - if item.sale_price: return str(item.sale_price) else: return '' class Meta: + """Metaclass options""" model = SalesOrderLineItem skip_unchanged = True report_skipped = False @@ -187,13 +193,16 @@ class SalesOrderLineItemResource(ModelResource): class SalesOrderExtraLineResource(ModelResource): - """ Class for managing import / export of SalesOrderExtraLine data """ + """Class for managing import / export of SalesOrderExtraLine data.""" class Meta(GeneralExtraLineMeta): + """Metaclass options.""" + model = SalesOrderExtraLine class PurchaseOrderLineItemAdmin(ImportExportModelAdmin): + """Admin class for the PurchaseOrderLine model""" resource_class = PurchaseOrderLineItemResource @@ -210,11 +219,12 @@ class PurchaseOrderLineItemAdmin(ImportExportModelAdmin): class PurchaseOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): - + """Admin class for the PurchaseOrderExtraLine model""" resource_class = PurchaseOrderExtraLineResource class SalesOrderLineItemAdmin(ImportExportModelAdmin): + """Admin class for the SalesOrderLine model""" resource_class = SalesOrderLineItemResource @@ -236,11 +246,12 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin): class SalesOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): - + """Admin class for the SalesOrderExtraLine model""" resource_class = SalesOrderExtraLineResource class SalesOrderShipmentAdmin(ImportExportModelAdmin): + """Admin class for the SalesOrderShipment model""" list_display = [ 'order', @@ -258,6 +269,7 @@ class SalesOrderShipmentAdmin(ImportExportModelAdmin): class SalesOrderAllocationAdmin(ImportExportModelAdmin): + """Admin class for the SalesOrderAllocation model""" list_display = ( 'line', diff --git a/InvenTree/order/api.py b/InvenTree/order/api.py index e6e6767edb..2cb3e394ad 100644 --- a/InvenTree/order/api.py +++ b/InvenTree/order/api.py @@ -1,6 +1,4 @@ -""" -JSON API for the Order app -""" +"""JSON API for the Order app.""" from django.db.models import F, Q from django.urls import include, path, re_path @@ -24,11 +22,10 @@ from users.models import Owner class GeneralExtraLineList: - """ - General template for ExtraLine API classes - """ + """General template for ExtraLine API classes.""" def get_serializer(self, *args, **kwargs): + """Return the serializer instance for this endpoint""" try: params = self.request.query_params @@ -41,7 +38,7 @@ class GeneralExtraLineList: return self.serializer_class(*args, **kwargs) def get_queryset(self, *args, **kwargs): - + """Return the annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) queryset = queryset.prefetch_related( @@ -76,17 +73,12 @@ class GeneralExtraLineList: class PurchaseOrderFilter(rest_filters.FilterSet): - """ - Custom API filters for the PurchaseOrderList endpoint - """ + """Custom API filters for the PurchaseOrderList endpoint.""" assigned_to_me = rest_filters.BooleanFilter(label='assigned_to_me', method='filter_assigned_to_me') def filter_assigned_to_me(self, queryset, name, value): - """ - Filter by orders which are assigned to the current user - """ - + """Filter by orders which are assigned to the current user.""" value = str2bool(value) # Work out who "me" is! @@ -100,6 +92,8 @@ class PurchaseOrderFilter(rest_filters.FilterSet): return queryset class Meta: + """Metaclass options.""" + model = models.PurchaseOrder fields = [ 'supplier', @@ -107,7 +101,7 @@ class PurchaseOrderFilter(rest_filters.FilterSet): class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView): - """ API endpoint for accessing a list of PurchaseOrder objects + """API endpoint for accessing a list of PurchaseOrder objects. - GET: Return list of PurchaseOrder objects (with filters) - POST: Create a new PurchaseOrder object @@ -118,9 +112,7 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView): filterset_class = PurchaseOrderFilter def create(self, request, *args, **kwargs): - """ - Save user information on create - """ + """Save user information on create.""" serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) @@ -132,7 +124,7 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView): return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) def get_serializer(self, *args, **kwargs): - + """Return the serializer instance for this endpoint""" try: kwargs['supplier_detail'] = str2bool(self.request.query_params.get('supplier_detail', False)) except AttributeError: @@ -144,7 +136,7 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def get_queryset(self, *args, **kwargs): - + """Return the annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) queryset = queryset.prefetch_related( @@ -157,6 +149,8 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView): return queryset def download_queryset(self, queryset, export_format): + """Download the filtered queryset as a file""" + dataset = PurchaseOrderResource().export(queryset=queryset) filedata = dataset.export(export_format) @@ -166,7 +160,7 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView): return DownloadFile(filedata, filename) def filter_queryset(self, queryset): - + """Custom queryset filtering""" # Perform basic filtering queryset = super().filter_queryset(queryset) @@ -260,13 +254,13 @@ class PurchaseOrderList(APIDownloadMixin, generics.ListCreateAPIView): class PurchaseOrderDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of a PurchaseOrder object """ + """API endpoint for detail view of a PurchaseOrder object.""" queryset = models.PurchaseOrder.objects.all() serializer_class = serializers.PurchaseOrderSerializer def get_serializer(self, *args, **kwargs): - + """Return serializer instance for this endpoint""" try: kwargs['supplier_detail'] = str2bool(self.request.query_params.get('supplier_detail', False)) except AttributeError: @@ -278,7 +272,7 @@ class PurchaseOrderDetail(generics.RetrieveUpdateDestroyAPIView): return self.serializer_class(*args, **kwargs) def get_queryset(self, *args, **kwargs): - + """Return annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) queryset = queryset.prefetch_related( @@ -292,11 +286,10 @@ class PurchaseOrderDetail(generics.RetrieveUpdateDestroyAPIView): class PurchaseOrderContextMixin: - """ Mixin to add purchase order object as serializer context variable """ + """Mixin to add purchase order object as serializer context variable.""" def get_serializer_context(self): - """ Add the PurchaseOrder object to the serializer context """ - + """Add the PurchaseOrder object to the serializer context.""" context = super().get_serializer_context() # Pass the purchase order through to the serializer for validation @@ -311,8 +304,7 @@ class PurchaseOrderContextMixin: class PurchaseOrderCancel(PurchaseOrderContextMixin, generics.CreateAPIView): - """ - API endpoint to 'cancel' a purchase order. + """API endpoint to 'cancel' a purchase order. The purchase order must be in a state which can be cancelled """ @@ -323,9 +315,7 @@ class PurchaseOrderCancel(PurchaseOrderContextMixin, generics.CreateAPIView): class PurchaseOrderComplete(PurchaseOrderContextMixin, generics.CreateAPIView): - """ - API endpoint to 'complete' a purchase order - """ + """API endpoint to 'complete' a purchase order.""" queryset = models.PurchaseOrder.objects.all() @@ -333,9 +323,7 @@ class PurchaseOrderComplete(PurchaseOrderContextMixin, generics.CreateAPIView): class PurchaseOrderIssue(PurchaseOrderContextMixin, generics.CreateAPIView): - """ - API endpoint to 'complete' a purchase order - """ + """API endpoint to 'complete' a purchase order.""" queryset = models.PurchaseOrder.objects.all() @@ -343,17 +331,17 @@ class PurchaseOrderIssue(PurchaseOrderContextMixin, generics.CreateAPIView): class PurchaseOrderMetadata(generics.RetrieveUpdateAPIView): - """API endpoint for viewing / updating PurchaseOrder metadata""" + """API endpoint for viewing / updating PurchaseOrder metadata.""" def get_serializer(self, *args, **kwargs): + """Return MetadataSerializer instance for a PurchaseOrder""" return MetadataSerializer(models.PurchaseOrder, *args, **kwargs) queryset = models.PurchaseOrder.objects.all() class PurchaseOrderReceive(PurchaseOrderContextMixin, generics.CreateAPIView): - """ - API endpoint to receive stock items against a purchase order. + """API endpoint to receive stock items against a purchase order. - The purchase order is specified in the URL. - Items to receive are specified as a list called "items" with the following options: @@ -370,11 +358,11 @@ class PurchaseOrderReceive(PurchaseOrderContextMixin, generics.CreateAPIView): class PurchaseOrderLineItemFilter(rest_filters.FilterSet): - """ - Custom filters for the PurchaseOrderLineItemList endpoint - """ + """Custom filters for the PurchaseOrderLineItemList endpoint.""" class Meta: + """Metaclass options.""" + model = models.PurchaseOrderLineItem fields = [ 'order', @@ -384,10 +372,7 @@ class PurchaseOrderLineItemFilter(rest_filters.FilterSet): pending = rest_filters.BooleanFilter(label='pending', method='filter_pending') def filter_pending(self, queryset, name, value): - """ - Filter by "pending" status (order status = pending) - """ - + """Filter by "pending" status (order status = pending)""" value = str2bool(value) if value: @@ -402,12 +387,10 @@ class PurchaseOrderLineItemFilter(rest_filters.FilterSet): received = rest_filters.BooleanFilter(label='received', method='filter_received') def filter_received(self, queryset, name, value): - """ - Filter by lines which are "received" (or "not" received) + """Filter by lines which are "received" (or "not" received) A line is considered "received" when received >= quantity """ - value = str2bool(value) q = Q(received__gte=F('quantity')) @@ -422,7 +405,7 @@ class PurchaseOrderLineItemFilter(rest_filters.FilterSet): class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView): - """ API endpoint for accessing a list of PurchaseOrderLineItem objects + """API endpoint for accessing a list of PurchaseOrderLineItem objects. - GET: Return a list of PurchaseOrder Line Item objects - POST: Create a new PurchaseOrderLineItem object @@ -433,7 +416,7 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView): filterset_class = PurchaseOrderLineItemFilter def get_queryset(self, *args, **kwargs): - + """Return annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) queryset = serializers.PurchaseOrderLineItemSerializer.annotate_queryset(queryset) @@ -441,7 +424,7 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView): return queryset def get_serializer(self, *args, **kwargs): - + """Return serializer instance for this endpoint""" try: kwargs['part_detail'] = str2bool(self.request.query_params.get('part_detail', False)) kwargs['order_detail'] = str2bool(self.request.query_params.get('order_detail', False)) @@ -453,10 +436,7 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def filter_queryset(self, queryset): - """ - Additional filtering options - """ - + """Additional filtering options.""" params = self.request.query_params queryset = super().filter_queryset(queryset) @@ -475,6 +455,8 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView): return queryset def download_queryset(self, queryset, export_format): + """Download the requested queryset as a file""" + dataset = PurchaseOrderLineItemResource().export(queryset=queryset) filedata = dataset.export(export_format) @@ -483,19 +465,6 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView): return DownloadFile(filedata, filename) - def list(self, request, *args, **kwargs): - - queryset = self.filter_queryset(self.get_queryset()) - - page = self.paginate_queryset(queryset) - - if page is not None: - serializer = self.get_serializer(page, many=True) - return self.get_paginated_response(serializer.data) - - serializer = self.get_serializer(queryset, many=True) - return Response(serializer.data) - filter_backends = [ rest_filters.DjangoFilterBackend, filters.SearchFilter, @@ -530,15 +499,13 @@ class PurchaseOrderLineItemList(APIDownloadMixin, generics.ListCreateAPIView): class PurchaseOrderLineItemDetail(generics.RetrieveUpdateDestroyAPIView): - """ - Detail API endpoint for PurchaseOrderLineItem object - """ + """Detail API endpoint for PurchaseOrderLineItem object.""" queryset = models.PurchaseOrderLineItem.objects.all() serializer_class = serializers.PurchaseOrderLineItemSerializer def get_queryset(self): - + """Return annotated queryset for this endpoint""" queryset = super().get_queryset() queryset = serializers.PurchaseOrderLineItemSerializer.annotate_queryset(queryset) @@ -547,25 +514,21 @@ class PurchaseOrderLineItemDetail(generics.RetrieveUpdateDestroyAPIView): class PurchaseOrderExtraLineList(GeneralExtraLineList, generics.ListCreateAPIView): - """ - API endpoint for accessing a list of PurchaseOrderExtraLine objects. - """ + """API endpoint for accessing a list of PurchaseOrderExtraLine objects.""" queryset = models.PurchaseOrderExtraLine.objects.all() serializer_class = serializers.PurchaseOrderExtraLineSerializer class PurchaseOrderExtraLineDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of a PurchaseOrderExtraLine object """ + """API endpoint for detail view of a PurchaseOrderExtraLine object.""" queryset = models.PurchaseOrderExtraLine.objects.all() serializer_class = serializers.PurchaseOrderExtraLineSerializer class SalesOrderAttachmentList(generics.ListCreateAPIView, AttachmentMixin): - """ - API endpoint for listing (and creating) a SalesOrderAttachment (file upload) - """ + """API endpoint for listing (and creating) a SalesOrderAttachment (file upload)""" queryset = models.SalesOrderAttachment.objects.all() serializer_class = serializers.SalesOrderAttachmentSerializer @@ -580,17 +543,14 @@ class SalesOrderAttachmentList(generics.ListCreateAPIView, AttachmentMixin): class SalesOrderAttachmentDetail(generics.RetrieveUpdateDestroyAPIView, AttachmentMixin): - """ - Detail endpoint for SalesOrderAttachment - """ + """Detail endpoint for SalesOrderAttachment.""" queryset = models.SalesOrderAttachment.objects.all() serializer_class = serializers.SalesOrderAttachmentSerializer class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView): - """ - API endpoint for accessing a list of SalesOrder objects. + """API endpoint for accessing a list of SalesOrder objects. - GET: Return list of SalesOrder objects (with filters) - POST: Create a new SalesOrder @@ -600,9 +560,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView): serializer_class = serializers.SalesOrderSerializer def create(self, request, *args, **kwargs): - """ - Save user information on create - """ + """Save user information on create.""" serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) @@ -614,7 +572,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView): return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) def get_serializer(self, *args, **kwargs): - + """Return serializer instance for this endpoint""" try: kwargs['customer_detail'] = str2bool(self.request.query_params.get('customer_detail', False)) except AttributeError: @@ -626,7 +584,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def get_queryset(self, *args, **kwargs): - + """Return annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) queryset = queryset.prefetch_related( @@ -639,6 +597,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView): return queryset def download_queryset(self, queryset, export_format): + """Download this queryset as a file""" dataset = SalesOrderResource().export(queryset=queryset) filedata = dataset.export(export_format) @@ -648,10 +607,7 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView): return DownloadFile(filedata, filename) def filter_queryset(self, queryset): - """ - Perform custom filtering operations on the SalesOrder queryset. - """ - + """Perform custom filtering operations on the SalesOrder queryset.""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -739,15 +695,13 @@ class SalesOrderList(APIDownloadMixin, generics.ListCreateAPIView): class SalesOrderDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for detail view of a SalesOrder object. - """ + """API endpoint for detail view of a SalesOrder object.""" queryset = models.SalesOrder.objects.all() serializer_class = serializers.SalesOrderSerializer def get_serializer(self, *args, **kwargs): - + """Return the serializer instance for this endpoint""" try: kwargs['customer_detail'] = str2bool(self.request.query_params.get('customer_detail', False)) except AttributeError: @@ -758,7 +712,7 @@ class SalesOrderDetail(generics.RetrieveUpdateDestroyAPIView): return self.serializer_class(*args, **kwargs) def get_queryset(self, *args, **kwargs): - + """Return the annotated queryset for this serializer""" queryset = super().get_queryset(*args, **kwargs) queryset = queryset.prefetch_related('customer', 'lines') @@ -769,11 +723,11 @@ class SalesOrderDetail(generics.RetrieveUpdateDestroyAPIView): class SalesOrderLineItemFilter(rest_filters.FilterSet): - """ - Custom filters for SalesOrderLineItemList endpoint - """ + """Custom filters for SalesOrderLineItemList endpoint.""" class Meta: + """Metaclass options.""" + model = models.SalesOrderLineItem fields = [ 'order', @@ -783,12 +737,10 @@ class SalesOrderLineItemFilter(rest_filters.FilterSet): completed = rest_filters.BooleanFilter(label='completed', method='filter_completed') def filter_completed(self, queryset, name, value): - """ - Filter by lines which are "completed" + """Filter by lines which are "completed". A line is completed when shipped >= quantity """ - value = str2bool(value) q = Q(shipped__gte=F('quantity')) @@ -802,16 +754,14 @@ class SalesOrderLineItemFilter(rest_filters.FilterSet): class SalesOrderLineItemList(generics.ListCreateAPIView): - """ - API endpoint for accessing a list of SalesOrderLineItem objects. - """ + """API endpoint for accessing a list of SalesOrderLineItem objects.""" queryset = models.SalesOrderLineItem.objects.all() serializer_class = serializers.SalesOrderLineItemSerializer filterset_class = SalesOrderLineItemFilter def get_serializer(self, *args, **kwargs): - + """Return serializer for this endpoint with extra data as requested""" try: params = self.request.query_params @@ -826,7 +776,7 @@ class SalesOrderLineItemList(generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def get_queryset(self, *args, **kwargs): - + """Return annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) queryset = queryset.prefetch_related( @@ -866,33 +816,31 @@ class SalesOrderLineItemList(generics.ListCreateAPIView): class SalesOrderExtraLineList(GeneralExtraLineList, generics.ListCreateAPIView): - """ - API endpoint for accessing a list of SalesOrderExtraLine objects. - """ + """API endpoint for accessing a list of SalesOrderExtraLine objects.""" queryset = models.SalesOrderExtraLine.objects.all() serializer_class = serializers.SalesOrderExtraLineSerializer class SalesOrderExtraLineDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of a SalesOrderExtraLine object """ + """API endpoint for detail view of a SalesOrderExtraLine object.""" queryset = models.SalesOrderExtraLine.objects.all() serializer_class = serializers.SalesOrderExtraLineSerializer class SalesOrderLineItemDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of a SalesOrderLineItem object """ + """API endpoint for detail view of a SalesOrderLineItem object.""" queryset = models.SalesOrderLineItem.objects.all() serializer_class = serializers.SalesOrderLineItemSerializer class SalesOrderContextMixin: - """ Mixin to add sales order object as serializer context variable """ + """Mixin to add sales order object as serializer context variable.""" def get_serializer_context(self): - + """Add the 'order' reference to the serializer context for any classes which inherit this mixin""" ctx = super().get_serializer_context() ctx['request'] = self.request @@ -906,42 +854,38 @@ class SalesOrderContextMixin: class SalesOrderCancel(SalesOrderContextMixin, generics.CreateAPIView): + """API endpoint to cancel a SalesOrder""" queryset = models.SalesOrder.objects.all() serializer_class = serializers.SalesOrderCancelSerializer class SalesOrderComplete(SalesOrderContextMixin, generics.CreateAPIView): - """ - API endpoint for manually marking a SalesOrder as "complete". - """ + """API endpoint for manually marking a SalesOrder as "complete".""" queryset = models.SalesOrder.objects.all() serializer_class = serializers.SalesOrderCompleteSerializer class SalesOrderMetadata(generics.RetrieveUpdateAPIView): - """API endpoint for viewing / updating SalesOrder metadata""" + """API endpoint for viewing / updating SalesOrder metadata.""" def get_serializer(self, *args, **kwargs): + """Return a metadata serializer for the SalesOrder model""" return MetadataSerializer(models.SalesOrder, *args, **kwargs) queryset = models.SalesOrder.objects.all() class SalesOrderAllocateSerials(SalesOrderContextMixin, generics.CreateAPIView): - """ - API endpoint to allocation stock items against a SalesOrder, - by specifying serial numbers. - """ + """API endpoint to allocation stock items against a SalesOrder, by specifying serial numbers.""" queryset = models.SalesOrder.objects.none() serializer_class = serializers.SalesOrderSerialAllocationSerializer class SalesOrderAllocate(SalesOrderContextMixin, generics.CreateAPIView): - """ - API endpoint to allocate stock items against a SalesOrder + """API endpoint to allocate stock items against a SalesOrder. - The SalesOrder is specified in the URL - See the SalesOrderShipmentAllocationSerializer class @@ -952,24 +896,23 @@ class SalesOrderAllocate(SalesOrderContextMixin, generics.CreateAPIView): class SalesOrderAllocationDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for detali view of a SalesOrderAllocation object - """ + """API endpoint for detali view of a SalesOrderAllocation object.""" queryset = models.SalesOrderAllocation.objects.all() serializer_class = serializers.SalesOrderAllocationSerializer class SalesOrderAllocationList(generics.ListAPIView): - """ - API endpoint for listing SalesOrderAllocation objects - """ + """API endpoint for listing SalesOrderAllocation objects.""" queryset = models.SalesOrderAllocation.objects.all() serializer_class = serializers.SalesOrderAllocationSerializer def get_serializer(self, *args, **kwargs): + """Return the serializer instance for this endpoint. + Adds extra detail serializers if requested + """ try: params = self.request.query_params @@ -984,7 +927,7 @@ class SalesOrderAllocationList(generics.ListAPIView): return self.serializer_class(*args, **kwargs) def filter_queryset(self, queryset): - + """Custom queryset filtering""" queryset = super().filter_queryset(queryset) # Filter by order @@ -1039,14 +982,12 @@ class SalesOrderAllocationList(generics.ListAPIView): class SalesOrderShipmentFilter(rest_filters.FilterSet): - """ - Custom filterset for the SalesOrderShipmentList endpoint - """ + """Custom filterset for the SalesOrderShipmentList endpoint.""" shipped = rest_filters.BooleanFilter(label='shipped', method='filter_shipped') def filter_shipped(self, queryset, name, value): - + """Filter SalesOrder list by 'shipped' status (boolean)""" value = str2bool(value) if value: @@ -1057,6 +998,8 @@ class SalesOrderShipmentFilter(rest_filters.FilterSet): return queryset class Meta: + """Metaclass options.""" + model = models.SalesOrderShipment fields = [ 'order', @@ -1064,9 +1007,7 @@ class SalesOrderShipmentFilter(rest_filters.FilterSet): class SalesOrderShipmentList(generics.ListCreateAPIView): - """ - API list endpoint for SalesOrderShipment model - """ + """API list endpoint for SalesOrderShipment model.""" queryset = models.SalesOrderShipment.objects.all() serializer_class = serializers.SalesOrderShipmentSerializer @@ -1078,27 +1019,20 @@ class SalesOrderShipmentList(generics.ListCreateAPIView): class SalesOrderShipmentDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API detail endpooint for SalesOrderShipment model - """ + """API detail endpooint for SalesOrderShipment model.""" queryset = models.SalesOrderShipment.objects.all() serializer_class = serializers.SalesOrderShipmentSerializer class SalesOrderShipmentComplete(generics.CreateAPIView): - """ - API endpoint for completing (shipping) a SalesOrderShipment - """ + """API endpoint for completing (shipping) a SalesOrderShipment.""" queryset = models.SalesOrderShipment.objects.all() serializer_class = serializers.SalesOrderShipmentCompleteSerializer def get_serializer_context(self): - """ - Pass the request object to the serializer - """ - + """Pass the request object to the serializer.""" ctx = super().get_serializer_context() ctx['request'] = self.request @@ -1113,9 +1047,7 @@ class SalesOrderShipmentComplete(generics.CreateAPIView): class PurchaseOrderAttachmentList(generics.ListCreateAPIView, AttachmentMixin): - """ - API endpoint for listing (and creating) a PurchaseOrderAttachment (file upload) - """ + """API endpoint for listing (and creating) a PurchaseOrderAttachment (file upload)""" queryset = models.PurchaseOrderAttachment.objects.all() serializer_class = serializers.PurchaseOrderAttachmentSerializer @@ -1130,9 +1062,7 @@ class PurchaseOrderAttachmentList(generics.ListCreateAPIView, AttachmentMixin): class PurchaseOrderAttachmentDetail(generics.RetrieveUpdateDestroyAPIView, AttachmentMixin): - """ - Detail endpoint for a PurchaseOrderAttachment - """ + """Detail endpoint for a PurchaseOrderAttachment.""" queryset = models.PurchaseOrderAttachment.objects.all() serializer_class = serializers.PurchaseOrderAttachmentSerializer diff --git a/InvenTree/order/apps.py b/InvenTree/order/apps.py index 821e6d872c..e3bc0fc6d2 100644 --- a/InvenTree/order/apps.py +++ b/InvenTree/order/apps.py @@ -1,5 +1,8 @@ +"""Config for the 'order' app""" + from django.apps import AppConfig class OrderConfig(AppConfig): + """Configuration class for the 'order' app""" name = 'order' diff --git a/InvenTree/order/forms.py b/InvenTree/order/forms.py index 2b06a92d00..c38f50ed56 100644 --- a/InvenTree/order/forms.py +++ b/InvenTree/order/forms.py @@ -1,6 +1,4 @@ -""" -Django Forms for interacting with Order objects -""" +"""Django Forms for interacting with Order objects.""" from django import forms from django.utils.translation import gettext_lazy as _ @@ -11,11 +9,10 @@ from InvenTree.helpers import clean_decimal class OrderMatchItemForm(MatchItemForm): - """ Override MatchItemForm fields """ + """Override MatchItemForm fields.""" def get_special_field(self, col_guess, row, file_manager): - """ Set special fields """ - + """Set special fields.""" # set quantity field if 'quantity' in col_guess.lower(): return forms.CharField( diff --git a/InvenTree/order/models.py b/InvenTree/order/models.py index 3871ba38a1..bc915fa564 100644 --- a/InvenTree/order/models.py +++ b/InvenTree/order/models.py @@ -1,8 +1,4 @@ -""" -Order model definitions -""" - -# -*- coding: utf-8 -*- +"""Order model definitions.""" import logging import os @@ -47,10 +43,7 @@ logger = logging.getLogger('inventree') def get_next_po_number(): - """ - Returns the next available PurchaseOrder reference number - """ - + """Returns the next available PurchaseOrder reference number.""" if PurchaseOrder.objects.count() == 0: return '0001' @@ -76,10 +69,7 @@ def get_next_po_number(): def get_next_so_number(): - """ - Returns the next available SalesOrder reference number - """ - + """Returns the next available SalesOrder reference number.""" if SalesOrder.objects.count() == 0: return '0001' @@ -105,7 +95,7 @@ def get_next_so_number(): class Order(MetadataMixin, ReferenceIndexingMixin): - """ Abstract model for an order. + """Abstract model for an order. Instances of this class: @@ -123,7 +113,10 @@ class Order(MetadataMixin, ReferenceIndexingMixin): """ def save(self, *args, **kwargs): + """Custom save method for the order models: + Ensures that the reference field is rebuilt whenever the instance is saved. + """ self.rebuild_reference_field() if not self.creation_date: @@ -132,6 +125,8 @@ class Order(MetadataMixin, ReferenceIndexingMixin): super().save(*args, **kwargs) class Meta: + """Metaclass options. Abstract ensures no database table is created.""" + abstract = True description = models.CharField(max_length=250, verbose_name=_('Description'), help_text=_('Order description')) @@ -159,15 +154,13 @@ class Order(MetadataMixin, ReferenceIndexingMixin): notes = MarkdownxField(blank=True, verbose_name=_('Notes'), help_text=_('Order notes')) def get_total_price(self, target_currency=currency_code_default()): - """ - Calculates the total price of all order lines, and converts to the specified target currency. + """Calculates the total price of all order lines, and converts to the specified target currency. If not specified, the default system currency is used. If currency conversion fails (e.g. there are no valid conversion rates), then we simply return zero, rather than attempting some other calculation. """ - total = Money(0, target_currency) # gather name reference @@ -230,7 +223,7 @@ class Order(MetadataMixin, ReferenceIndexingMixin): class PurchaseOrder(Order): - """ A PurchaseOrder represents goods shipped inwards from an external supplier. + """A PurchaseOrder represents goods shipped inwards from an external supplier. Attributes: supplier: Reference to the company supplying the goods in the order @@ -241,14 +234,14 @@ class PurchaseOrder(Order): @staticmethod def get_api_url(): + """Return the API URL associated with the PurchaseOrder model""" return reverse('api-po-list') OVERDUE_FILTER = Q(status__in=PurchaseOrderStatus.OPEN) & ~Q(target_date=None) & Q(target_date__lte=datetime.now().date()) @staticmethod def filterByDate(queryset, min_date, max_date): - """ - Filter by 'minimum and maximum date range' + """Filter by 'minimum and maximum date range'. - Specified as min_date, max_date - Both must be specified for filter to be applied @@ -259,7 +252,6 @@ class PurchaseOrder(Order): - A "pending" order where the target date lies within the date range - TODO: An "overdue" order where the target date is in the past """ - date_fmt = '%Y-%m-%d' # ISO format date string # Ensure that both dates are valid @@ -283,7 +275,7 @@ class PurchaseOrder(Order): return queryset def __str__(self): - + """Render a string representation of this PurchaseOrder""" prefix = getSetting('PURCHASEORDER_REFERENCE_PREFIX') return f"{prefix}{self.reference} - {self.supplier.name if self.supplier else _('deleted')}" @@ -340,22 +332,29 @@ class PurchaseOrder(Order): ) def get_absolute_url(self): + """Return the web URL of the detail view for this order""" return reverse('po-detail', kwargs={'pk': self.id}) @transaction.atomic - def add_line_item(self, supplier_part, quantity, group=True, reference='', purchase_price=None): - """ Add a new line item to this purchase order. - This function will check that: + def add_line_item(self, supplier_part, quantity, group: bool = True, reference: str = '', purchase_price=None): + """Add a new line item to this purchase order. + This function will check that: * The supplier part matches the supplier specified for this purchase order * The quantity is greater than zero Args: - supplier_part - The supplier_part to add - quantity - The number of items to add - group - If True, this new quantity will be added to an existing line item for the same supplier_part (if it exists) - """ + supplier_part: The supplier_part to add + quantity : The number of items to add + group (bool, optional): If True, this new quantity will be added to an existing line item for the same supplier_part (if it exists). Defaults to True. + reference (str, optional): Reference to item. Defaults to ''. + purchase_price (optional): Price of item. Defaults to None. + Raises: + ValidationError: quantity is smaller than 0 + ValidationError: quantity is not type int + ValidationError: supplier is not supplier of purchase order + """ try: quantity = int(quantity) if quantity <= 0: @@ -396,8 +395,10 @@ class PurchaseOrder(Order): @transaction.atomic def place_order(self): - """ Marks the PurchaseOrder as PLACED. Order must be currently PENDING. """ + """Marks the PurchaseOrder as PLACED. + Order must be currently PENDING. + """ if self.status == PurchaseOrderStatus.PENDING: self.status = PurchaseOrderStatus.PLACED self.issue_date = datetime.now().date() @@ -407,8 +408,10 @@ class PurchaseOrder(Order): @transaction.atomic def complete_order(self): - """ Marks the PurchaseOrder as COMPLETE. Order must be currently PLACED. """ + """Marks the PurchaseOrder as COMPLETE. + Order must be currently PLACED. + """ if self.status == PurchaseOrderStatus.PLACED: self.status = PurchaseOrderStatus.COMPLETE self.complete_date = datetime.now().date() @@ -418,22 +421,21 @@ class PurchaseOrder(Order): @property def is_overdue(self): - """ - Returns True if this PurchaseOrder is "overdue" + """Returns True if this PurchaseOrder is "overdue". Makes use of the OVERDUE_FILTER to avoid code duplication. """ - query = PurchaseOrder.objects.filter(pk=self.pk) query = query.filter(PurchaseOrder.OVERDUE_FILTER) return query.exists() def can_cancel(self): - """ - A PurchaseOrder can only be cancelled under the following circumstances: - """ + """A PurchaseOrder can only be cancelled under the following circumstances. + - Status is PLACED + - Status is PENDING + """ return self.status in [ PurchaseOrderStatus.PLACED, PurchaseOrderStatus.PENDING @@ -441,8 +443,7 @@ class PurchaseOrder(Order): @transaction.atomic def cancel_order(self): - """ Marks the PurchaseOrder as CANCELLED. """ - + """Marks the PurchaseOrder as CANCELLED.""" if self.can_cancel(): self.status = PurchaseOrderStatus.CANCELLED self.save() @@ -450,43 +451,39 @@ class PurchaseOrder(Order): trigger_event('purchaseorder.cancelled', id=self.pk) def pending_line_items(self): - """ Return a list of pending line items for this order. + """Return a list of pending line items for this order. + Any line item where 'received' < 'quantity' will be returned. """ - return self.lines.filter(quantity__gt=F('received')) def completed_line_items(self): - """ - Return a list of completed line items against this order - """ + """Return a list of completed line items against this order.""" return self.lines.filter(quantity__lte=F('received')) @property def line_count(self): + """Return the total number of line items associated with this order""" return self.lines.count() @property def completed_line_count(self): - + """Return the number of complete line items associated with this order""" return self.completed_line_items().count() @property def pending_line_count(self): + """Return the number of pending line items associated with this order""" return self.pending_line_items().count() @property def is_complete(self): - """ Return True if all line items have been received """ - + """Return True if all line items have been received.""" return self.lines.count() > 0 and self.pending_line_items().count() == 0 @transaction.atomic def receive_line_item(self, line, location, quantity, user, status=StockStatus.OK, **kwargs): - """ - Receive a line item (or partial line item) against this PurchaseOrder - """ - + """Receive a line item (or partial line item) against this PurchaseOrder.""" # Extract optional batch code for the new stock item batch_code = kwargs.get('batch_code', '') @@ -573,8 +570,7 @@ class PurchaseOrder(Order): class SalesOrder(Order): - """ - A SalesOrder represents a list of goods shipped outwards to a customer. + """A SalesOrder represents a list of goods shipped outwards to a customer. Attributes: customer: Reference to the company receiving the goods in the order @@ -584,14 +580,14 @@ class SalesOrder(Order): @staticmethod def get_api_url(): + """Return the API URL associated with the SalesOrder model""" return reverse('api-so-list') OVERDUE_FILTER = Q(status__in=SalesOrderStatus.OPEN) & ~Q(target_date=None) & Q(target_date__lte=datetime.now().date()) @staticmethod def filterByDate(queryset, min_date, max_date): - """ - Filter by "minimum and maximum date range" + """Filter by "minimum and maximum date range". - Specified as min_date, max_date - Both must be specified for filter to be applied @@ -602,7 +598,6 @@ class SalesOrder(Order): - A "pending" order where the target date lies within the date range - TODO: An "overdue" order where the target date is in the past """ - date_fmt = '%Y-%m-%d' # ISO format date string # Ensure that both dates are valid @@ -625,19 +620,14 @@ class SalesOrder(Order): return queryset - def save(self, *args, **kwargs): - - self.rebuild_reference_field() - - super().save(*args, **kwargs) - def __str__(self): - + """Render a string representation of this SalesOrder""" prefix = getSetting('SALESORDER_REFERENCE_PREFIX') return f"{prefix}{self.reference} - {self.customer.name if self.customer else _('deleted')}" def get_absolute_url(self): + """Return the web URL for the detail view of this order""" return reverse('so-detail', kwargs={'pk': self.id}) reference = models.CharField( @@ -682,12 +672,10 @@ class SalesOrder(Order): @property def is_overdue(self): - """ - Returns true if this SalesOrder is "overdue": + """Returns true if this SalesOrder is "overdue". Makes use of the OVERDUE_FILTER to avoid code duplication. """ - query = SalesOrder.objects.filter(pk=self.pk) query = query.filter(SalesOrder.OVERDUE_FILTER) @@ -695,21 +683,18 @@ class SalesOrder(Order): @property def is_pending(self): + """Return True if this order is 'pending'""" return self.status == SalesOrderStatus.PENDING @property def stock_allocations(self): - """ - Return a queryset containing all allocations for this order - """ - + """Return a queryset containing all allocations for this order.""" return SalesOrderAllocation.objects.filter( line__in=[line.pk for line in self.lines.all()] ) def is_fully_allocated(self): - """ Return True if all line items are fully allocated """ - + """Return True if all line items are fully allocated.""" for line in self.lines.all(): if not line.is_fully_allocated(): return False @@ -717,8 +702,7 @@ class SalesOrder(Order): return True def is_over_allocated(self): - """ Return true if any lines in the order are over-allocated """ - + """Return true if any lines in the order are over-allocated.""" for line in self.lines.all(): if line.is_over_allocated(): return True @@ -726,19 +710,14 @@ class SalesOrder(Order): return False def is_completed(self): - """ - Check if this order is "shipped" (all line items delivered), - """ - + """Check if this order is "shipped" (all line items delivered).""" return self.lines.count() > 0 and all([line.is_completed() for line in self.lines.all()]) def can_complete(self, raise_error=False): - """ - Test if this SalesOrder can be completed. + """Test if this SalesOrder can be completed. Throws a ValidationError if cannot be completed. """ - try: # Order without line items cannot be completed @@ -765,10 +744,7 @@ class SalesOrder(Order): return True def complete_order(self, user): - """ - Mark this order as "complete" - """ - + """Mark this order as "complete.""" if not self.can_complete(): return False @@ -783,10 +759,7 @@ class SalesOrder(Order): return True def can_cancel(self): - """ - Return True if this order can be cancelled - """ - + """Return True if this order can be cancelled.""" if self.status != SalesOrderStatus.PENDING: return False @@ -794,13 +767,12 @@ class SalesOrder(Order): @transaction.atomic def cancel_order(self): - """ - Cancel this order (only if it is "pending") + """Cancel this order (only if it is "pending"). + Executes: - Mark the order as 'cancelled' - Delete any StockItems which have been allocated """ - if not self.can_cancel(): return False @@ -817,59 +789,54 @@ class SalesOrder(Order): @property def line_count(self): + """Return the total number of lines associated with this order""" return self.lines.count() def completed_line_items(self): - """ - Return a queryset of the completed line items for this order - """ + """Return a queryset of the completed line items for this order.""" return self.lines.filter(shipped__gte=F('quantity')) def pending_line_items(self): - """ - Return a queryset of the pending line items for this order - """ + """Return a queryset of the pending line items for this order.""" return self.lines.filter(shipped__lt=F('quantity')) @property def completed_line_count(self): + """Return the number of completed lines for this order""" return self.completed_line_items().count() @property def pending_line_count(self): + """Return the number of pending (incomplete) lines associated with this order""" return self.pending_line_items().count() def completed_shipments(self): - """ - Return a queryset of the completed shipments for this order - """ + """Return a queryset of the completed shipments for this order.""" return self.shipments.exclude(shipment_date=None) def pending_shipments(self): - """ - Return a queryset of the pending shipments for this order - """ - + """Return a queryset of the pending shipments for this order.""" return self.shipments.filter(shipment_date=None) @property def shipment_count(self): + """Return the total number of shipments associated with this order""" return self.shipments.count() @property def completed_shipment_count(self): + """Return the number of completed shipments associated with this order""" return self.completed_shipments().count() @property def pending_shipment_count(self): + """Return the number of pending shipments associated with this order""" return self.pending_shipments().count() @receiver(post_save, sender=SalesOrder, dispatch_uid='build_post_save_log') def after_save_sales_order(sender, instance: SalesOrder, created: bool, **kwargs): - """ - Callback function to be executed after a SalesOrder instance is saved - """ + """Callback function to be executed after a SalesOrder instance is saved.""" if created and getSetting('SALESORDER_DEFAULT_SHIPMENT'): # A new SalesOrder has just been created @@ -881,37 +848,37 @@ def after_save_sales_order(sender, instance: SalesOrder, created: bool, **kwargs class PurchaseOrderAttachment(InvenTreeAttachment): - """ - Model for storing file attachments against a PurchaseOrder object - """ + """Model for storing file attachments against a PurchaseOrder object.""" @staticmethod def get_api_url(): + """Return the API URL associated with the PurchaseOrderAttachment model""" return reverse('api-po-attachment-list') def getSubdir(self): + """Return the directory path where PurchaseOrderAttachment files are located""" return os.path.join("po_files", str(self.order.id)) order = models.ForeignKey(PurchaseOrder, on_delete=models.CASCADE, related_name="attachments") class SalesOrderAttachment(InvenTreeAttachment): - """ - Model for storing file attachments against a SalesOrder object - """ + """Model for storing file attachments against a SalesOrder object.""" @staticmethod def get_api_url(): + """Return the API URL associated with the SalesOrderAttachment class""" return reverse('api-so-attachment-list') def getSubdir(self): + """Return the directory path where SalesOrderAttachment files are located""" return os.path.join("so_files", str(self.order.id)) order = models.ForeignKey(SalesOrder, on_delete=models.CASCADE, related_name='attachments') class OrderLineItem(models.Model): - """ Abstract model for an order line item + """Abstract model for an order line item. Attributes: quantity: Number of items @@ -929,6 +896,8 @@ class OrderLineItem(models.Model): OVERDUE_FILTER = Q(received__lt=F('quantity')) & ~Q(target_date=None) & Q(target_date__lt=datetime.now().date()) class Meta: + """Metaclass options. Abstract ensures no database table is created.""" + abstract = True quantity = RoundingDecimalField( @@ -951,16 +920,16 @@ class OrderLineItem(models.Model): class OrderExtraLine(OrderLineItem): - """ - Abstract Model for a single ExtraLine in a Order + """Abstract Model for a single ExtraLine in a Order. + Attributes: price: The unit sale price for this OrderLineItem """ class Meta: + """Metaclass options. Abstract ensures no database table is created.""" + abstract = True - unique_together = [ - ] context = models.JSONField( blank=True, null=True, @@ -976,30 +945,24 @@ class OrderExtraLine(OrderLineItem): help_text=_('Unit price'), ) - def price_converted(self): - return convert_money(self.price, currency_code_default()) - - def price_converted_currency(self): - return currency_code_default() - class PurchaseOrderLineItem(OrderLineItem): - """ Model for a purchase order line item. + """Model for a purchase order line item. Attributes: order: Reference to a PurchaseOrder object """ - class Meta: - unique_together = ( - ) - @staticmethod def get_api_url(): + """Return the API URL associated with the PurchaseOrderLineItem model""" return reverse('api-po-line-list') def clean(self): + """Custom clean method for the PurchaseOrderLineItem model: + - Ensure the supplier part matches the supplier + """ super().clean() if self.order.supplier and self.part: @@ -1010,6 +973,7 @@ class PurchaseOrderLineItem(OrderLineItem): }) def __str__(self): + """Render a string representation of a PurchaseOrderLineItem instance""" return "{n} x {part} from {supplier} (for {po})".format( n=decimal2string(self.quantity), part=self.part.SKU if self.part else 'unknown part', @@ -1024,8 +988,7 @@ class PurchaseOrderLineItem(OrderLineItem): ) def get_base_part(self): - """ - Return the base part.Part object for the line item + """Return the base part.Part object for the line item. Note: Returns None if the SupplierPart is not set! """ @@ -1067,14 +1030,12 @@ class PurchaseOrderLineItem(OrderLineItem): ) def get_destination(self): - """ - Show where the line item is or should be placed + """Show where the line item is or should be placed. NOTE: If a line item gets split when recieved, only an arbitrary stock items location will be reported as the location for the entire line. """ - for stock in stock_models.StockItem.objects.filter(supplier_part=self.part, purchase_order=self.order): if stock.location: return stock.location @@ -1084,14 +1045,14 @@ class PurchaseOrderLineItem(OrderLineItem): return self.part.part.default_location def remaining(self): - """ Calculate the number of items remaining to be received """ + """Calculate the number of items remaining to be received.""" r = self.quantity - self.received return max(r, 0) class PurchaseOrderExtraLine(OrderExtraLine): - """ - Model for a single ExtraLine in a PurchaseOrder + """Model for a single ExtraLine in a PurchaseOrder. + Attributes: order: Link to the PurchaseOrder that this line belongs to title: title of line @@ -1099,14 +1060,14 @@ class PurchaseOrderExtraLine(OrderExtraLine): """ @staticmethod def get_api_url(): + """Return the API URL associated with the PurchaseOrderExtraLine model""" return reverse('api-po-extra-line-list') order = models.ForeignKey(PurchaseOrder, on_delete=models.CASCADE, related_name='extra_lines', verbose_name=_('Order'), help_text=_('Purchase Order')) class SalesOrderLineItem(OrderLineItem): - """ - Model for a single LineItem in a SalesOrder + """Model for a single LineItem in a SalesOrder. Attributes: order: Link to the SalesOrder that this line item belongs to @@ -1117,6 +1078,7 @@ class SalesOrderLineItem(OrderLineItem): @staticmethod def get_api_url(): + """Return the API URL associated with the SalesOrderLineItem model""" return reverse('api-so-line-list') order = models.ForeignKey( @@ -1145,52 +1107,39 @@ class SalesOrderLineItem(OrderLineItem): validators=[MinValueValidator(0)] ) - class Meta: - unique_together = [ - ] - def fulfilled_quantity(self): - """ - Return the total stock quantity fulfilled against this line item. - """ - + """Return the total stock quantity fulfilled against this line item.""" query = self.order.stock_items.filter(part=self.part).aggregate(fulfilled=Coalesce(Sum('quantity'), Decimal(0))) return query['fulfilled'] def allocated_quantity(self): - """ Return the total stock quantity allocated to this LineItem. + """Return the total stock quantity allocated to this LineItem. This is a summation of the quantity of each attached StockItem """ - query = self.allocations.aggregate(allocated=Coalesce(Sum('quantity'), Decimal(0))) return query['allocated'] def is_fully_allocated(self): - """ Return True if this line item is fully allocated """ - + """Return True if this line item is fully allocated.""" if self.order.status == SalesOrderStatus.SHIPPED: return self.fulfilled_quantity() >= self.quantity return self.allocated_quantity() >= self.quantity def is_over_allocated(self): - """ Return True if this line item is over allocated """ + """Return True if this line item is over allocated.""" return self.allocated_quantity() > self.quantity def is_completed(self): - """ - Return True if this line item is completed (has been fully shipped) - """ - + """Return True if this line item is completed (has been fully shipped).""" return self.shipped >= self.quantity class SalesOrderShipment(models.Model): - """ - The SalesOrderShipment model represents a physical shipment made against a SalesOrder. + """The SalesOrderShipment model represents a physical shipment made against a SalesOrder. - Points to a single SalesOrder object - Multiple SalesOrderAllocation objects point to a particular SalesOrderShipment @@ -1205,6 +1154,7 @@ class SalesOrderShipment(models.Model): """ class Meta: + """Metaclass defines extra model options""" # Shipment reference must be unique for a given sales order unique_together = [ 'order', 'reference', @@ -1212,6 +1162,7 @@ class SalesOrderShipment(models.Model): @staticmethod def get_api_url(): + """Return the API URL associated with the SalesOrderShipment model""" return reverse('api-so-shipment-list') order = models.ForeignKey( @@ -1275,10 +1226,11 @@ class SalesOrderShipment(models.Model): ) def is_complete(self): + """Return True if this shipment has already been completed""" return self.shipment_date is not None def check_can_complete(self, raise_error=True): - + """Check if this shipment is able to be completed""" try: if self.shipment_date: # Shipment has already been sent! @@ -1297,14 +1249,13 @@ class SalesOrderShipment(models.Model): @transaction.atomic def complete_shipment(self, user, **kwargs): - """ - Complete this particular shipment: + """Complete this particular shipment. + Executes: 1. Update any stock items associated with this shipment 2. Update the "shipped" quantity of all associated line items 3. Set the "shipment_date" to now """ - # Check if the shipment can be completed (throw error if not) self.check_can_complete() @@ -1343,8 +1294,8 @@ class SalesOrderShipment(models.Model): class SalesOrderExtraLine(OrderExtraLine): - """ - Model for a single ExtraLine in a SalesOrder + """Model for a single ExtraLine in a SalesOrder. + Attributes: order: Link to the SalesOrder that this line belongs to title: title of line @@ -1352,40 +1303,37 @@ class SalesOrderExtraLine(OrderExtraLine): """ @staticmethod def get_api_url(): + """Return the API URL associated with the SalesOrderExtraLine model""" return reverse('api-so-extra-line-list') order = models.ForeignKey(SalesOrder, on_delete=models.CASCADE, related_name='extra_lines', verbose_name=_('Order'), help_text=_('Sales Order')) class SalesOrderAllocation(models.Model): - """ - This model is used to 'allocate' stock items to a SalesOrder. - Items that are "allocated" to a SalesOrder are not yet "attached" to the order, - but they will be once the order is fulfilled. + """This model is used to 'allocate' stock items to a SalesOrder. Items that are "allocated" to a SalesOrder are not yet "attached" to the order, but they will be once the order is fulfilled. Attributes: line: SalesOrderLineItem reference shipment: SalesOrderShipment reference item: StockItem reference quantity: Quantity to take from the StockItem - """ @staticmethod def get_api_url(): + """Return the API URL associated with the SalesOrderAllocation model""" return reverse('api-so-allocation-list') def clean(self): - """ - Validate the SalesOrderAllocation object: + """Validate the SalesOrderAllocation object. + Executes: - Cannot allocate stock to a line item without a part reference - The referenced part must match the part associated with the line item - Allocated quantity cannot exceed the quantity of the stock item - Allocation quantity must be "1" if the StockItem is serialized - Allocation quantity cannot be zero """ - super().clean() errors = {} @@ -1452,29 +1400,21 @@ class SalesOrderAllocation(models.Model): quantity = RoundingDecimalField(max_digits=15, decimal_places=5, validators=[MinValueValidator(0)], default=1, verbose_name=_('Quantity'), help_text=_('Enter stock allocation quantity')) - def get_serial(self): - return self.item.serial - def get_location(self): + """Return the value of the location associated with this allocation""" return self.item.location.id if self.item.location else None - def get_location_path(self): - if self.item.location: - return self.item.location.pathstring - else: - return "" - def get_po(self): + """Return the PurchaseOrder associated with this allocation""" return self.item.purchase_order def complete_allocation(self, user): - """ - Complete this allocation (called when the parent SalesOrder is marked as "shipped"): + """Complete this allocation (called when the parent SalesOrder is marked as "shipped"). + Executes: - Determine if the referenced StockItem needs to be "split" (if allocated quantity != stock quantity) - Mark the StockItem as belonging to the Customer (this will remove it from stock) """ - order = self.line.order item = self.item.allocateToCustomer( diff --git a/InvenTree/order/serializers.py b/InvenTree/order/serializers.py index 0743d95c33..0ebeb95c42 100644 --- a/InvenTree/order/serializers.py +++ b/InvenTree/order/serializers.py @@ -1,6 +1,4 @@ -""" -JSON serializers for the Order API -""" +"""JSON serializers for the Order API.""" from datetime import datetime from decimal import Decimal @@ -33,9 +31,8 @@ from users.serializers import OwnerSerializer class AbstractOrderSerializer(serializers.Serializer): - """ - Abstract field definitions for OrderSerializers - """ + """Abstract field definitions for OrderSerializers.""" + total_price = InvenTreeMoneySerializer( source='get_total_price', allow_null=True, @@ -46,9 +43,10 @@ class AbstractOrderSerializer(serializers.Serializer): class AbstractExtraLineSerializer(serializers.Serializer): - """ Abstract Serializer for a ExtraLine object """ - def __init__(self, *args, **kwargs): + """Abstract Serializer for a ExtraLine object.""" + def __init__(self, *args, **kwargs): + """Initialization routine for the serializer""" order_detail = kwargs.pop('order_detail', False) super().__init__(*args, **kwargs) @@ -71,9 +69,7 @@ class AbstractExtraLineSerializer(serializers.Serializer): class AbstractExtraLineMeta: - """ - Abstract Meta for ExtraLine - """ + """Abstract Meta for ExtraLine.""" fields = [ 'pk', @@ -90,10 +86,10 @@ class AbstractExtraLineMeta: class PurchaseOrderSerializer(AbstractOrderSerializer, ReferenceIndexingSerializerMixin, InvenTreeModelSerializer): - """ Serializer for a PurchaseOrder object """ + """Serializer for a PurchaseOrder object.""" def __init__(self, *args, **kwargs): - + """Initialization routine for the serializer""" supplier_detail = kwargs.pop('supplier_detail', False) super().__init__(*args, **kwargs) @@ -103,13 +99,11 @@ class PurchaseOrderSerializer(AbstractOrderSerializer, ReferenceIndexingSerializ @staticmethod def annotate_queryset(queryset): - """ - Add extra information to the queryset + """Add extra information to the queryset. - Number of lines in the PurchaseOrder - Overdue status of the PurchaseOrder """ - queryset = queryset.annotate( line_items=SubqueryCount('lines') ) @@ -138,6 +132,8 @@ class PurchaseOrderSerializer(AbstractOrderSerializer, ReferenceIndexingSerializ responsible_detail = OwnerSerializer(source='responsible', read_only=True, many=False) class Meta: + """Metaclass options.""" + model = order.models.PurchaseOrder fields = [ @@ -172,18 +168,15 @@ class PurchaseOrderSerializer(AbstractOrderSerializer, ReferenceIndexingSerializ class PurchaseOrderCancelSerializer(serializers.Serializer): - """ - Serializer for cancelling a PurchaseOrder - """ + """Serializer for cancelling a PurchaseOrder.""" class Meta: + """Metaclass options.""" + fields = [], def get_context_data(self): - """ - Return custom context information about the order - """ - + """Return custom context information about the order.""" self.order = self.context['order'] return { @@ -191,7 +184,7 @@ class PurchaseOrderCancelSerializer(serializers.Serializer): } def save(self): - + """Save the serializer to 'cancel' the order""" order = self.context['order'] if not order.can_cancel(): @@ -201,18 +194,15 @@ class PurchaseOrderCancelSerializer(serializers.Serializer): class PurchaseOrderCompleteSerializer(serializers.Serializer): - """ - Serializer for completing a purchase order - """ + """Serializer for completing a purchase order.""" class Meta: + """Metaclass options.""" + fields = [] def get_context_data(self): - """ - Custom context information for this serializer - """ - + """Custom context information for this serializer.""" order = self.context['order'] return { @@ -220,34 +210,34 @@ class PurchaseOrderCompleteSerializer(serializers.Serializer): } def save(self): - + """Save the serializer to 'complete' the order""" order = self.context['order'] order.complete_order() class PurchaseOrderIssueSerializer(serializers.Serializer): - """ Serializer for issuing (sending) a purchase order """ + """Serializer for issuing (sending) a purchase order.""" class Meta: + """Metaclass options.""" + fields = [] def save(self): - + """Save the serializer to 'place' the order""" order = self.context['order'] order.place_order() class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): - + """Serializer class for the PurchaseOrderLineItem model""" @staticmethod def annotate_queryset(queryset): - """ - Add some extra annotations to this queryset: + """Add some extra annotations to this queryset: - Total price = purchase_price * quantity - "Overdue" status (boolean field) """ - queryset = queryset.annotate( total_price=ExpressionWrapper( F('purchase_price') * F('quantity'), @@ -267,7 +257,7 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): return queryset def __init__(self, *args, **kwargs): - + """Initialization routine for the serializer""" part_detail = kwargs.pop('part_detail', False) order_detail = kwargs.pop('order_detail', False) @@ -284,14 +274,14 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): quantity = serializers.FloatField(min_value=0, required=True) def validate_quantity(self, quantity): - + """Validation for the 'quantity' field""" if quantity <= 0: raise ValidationError(_("Quantity must be greater than zero")) return quantity def validate_purchase_order(self, purchase_order): - + """Validation for the 'purchase_order' field""" if purchase_order.status not in PurchaseOrderStatus.OPEN: raise ValidationError(_('Order is not open')) @@ -323,7 +313,12 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): order_detail = PurchaseOrderSerializer(source='order', read_only=True, many=False) def validate(self, data): + """Custom validation for the serializer: + - Ensure the supplier_part field is supplied + - Ensure the purchase_order field is supplied + - Ensure that the supplier_part and supplier references match + """ data = super().validate(data) supplier_part = data.get('part', None) @@ -349,6 +344,8 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): return data class Meta: + """Metaclass options.""" + model = order.models.PurchaseOrderLineItem fields = [ @@ -374,20 +371,22 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): class PurchaseOrderExtraLineSerializer(AbstractExtraLineSerializer, InvenTreeModelSerializer): - """ Serializer for a PurchaseOrderExtraLine object """ + """Serializer for a PurchaseOrderExtraLine object.""" order_detail = PurchaseOrderSerializer(source='order', many=False, read_only=True) class Meta(AbstractExtraLineMeta): + """Metaclass options.""" + model = order.models.PurchaseOrderExtraLine class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): - """ - A serializer for receiving a single purchase order line item against a purchase order - """ + """A serializer for receiving a single purchase order line item against a purchase order.""" class Meta: + """Metaclass options.""" + fields = [ 'barcode', 'line_item', @@ -407,7 +406,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): ) def validate_line_item(self, item): - + """Validation for the 'line_item' field""" if item.order != self.context['order']: raise ValidationError(_('Line item does not match purchase order')) @@ -430,7 +429,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): ) def validate_quantity(self, quantity): - + """Validation for the 'quantity' field""" if quantity <= 0: raise ValidationError(_("Quantity must be greater than zero")) @@ -468,10 +467,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): ) def validate_barcode(self, barcode): - """ - Cannot check in a LineItem with a barcode that is already assigned - """ - + """Cannot check in a LineItem with a barcode that is already assigned.""" # Ignore empty barcode values if not barcode or barcode.strip() == '': return None @@ -482,7 +478,11 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): return barcode def validate(self, data): + """Custom validation for the serializer: + - Integer quantity must be provided for serialized stock + - Validate serial numbers (if provided) + """ data = super().validate(data) line_item = data['line_item'] @@ -513,9 +513,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): class PurchaseOrderReceiveSerializer(serializers.Serializer): - """ - Serializer for receiving items against a purchase order - """ + """Serializer for receiving items against a purchase order.""" items = PurchaseOrderLineItemReceiveSerializer(many=True) @@ -528,7 +526,11 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer): ) def validate(self, data): + """Custom validation for the serializer: + - Ensure line items are provided + - Check that a location is specified + """ super().validate(data) items = data.get('items', []) @@ -571,10 +573,7 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer): return data def save(self): - """ - Perform the actual database transaction to receive purchase order items - """ - + """Perform the actual database transaction to receive purchase order items.""" data = self.validated_data request = self.context['request'] @@ -606,6 +605,8 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer): raise ValidationError(detail=serializers.as_serializer_error(exc)) class Meta: + """Metaclass options.""" + fields = [ 'items', 'location', @@ -613,11 +614,11 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer): class PurchaseOrderAttachmentSerializer(InvenTreeAttachmentSerializer): - """ - Serializers for the PurchaseOrderAttachment model - """ + """Serializers for the PurchaseOrderAttachment model.""" class Meta: + """Metaclass options.""" + model = order.models.PurchaseOrderAttachment fields = [ @@ -636,12 +637,10 @@ class PurchaseOrderAttachmentSerializer(InvenTreeAttachmentSerializer): class SalesOrderSerializer(AbstractOrderSerializer, ReferenceIndexingSerializerMixin, InvenTreeModelSerializer): - """ - Serializers for the SalesOrder object - """ + """Serializers for the SalesOrder object.""" def __init__(self, *args, **kwargs): - + """Initialization routine for the serializer""" customer_detail = kwargs.pop('customer_detail', False) super().__init__(*args, **kwargs) @@ -651,13 +650,11 @@ class SalesOrderSerializer(AbstractOrderSerializer, ReferenceIndexingSerializerM @staticmethod def annotate_queryset(queryset): - """ - Add extra information to the queryset + """Add extra information to the queryset. - Number of line items in the SalesOrder - Overdue status of the SalesOrder """ - queryset = queryset.annotate( line_items=SubqueryCount('lines') ) @@ -684,6 +681,8 @@ class SalesOrderSerializer(AbstractOrderSerializer, ReferenceIndexingSerializerM reference = serializers.CharField(required=True) class Meta: + """Metaclass options.""" + model = order.models.SalesOrder fields = [ @@ -715,8 +714,8 @@ class SalesOrderSerializer(AbstractOrderSerializer, ReferenceIndexingSerializerM class SalesOrderAllocationSerializer(InvenTreeModelSerializer): - """ - Serializer for the SalesOrderAllocation model. + """Serializer for the SalesOrderAllocation model. + This includes some fields from the related model objects. """ @@ -736,7 +735,7 @@ class SalesOrderAllocationSerializer(InvenTreeModelSerializer): shipment_date = serializers.DateField(source='shipment.shipment_date', read_only=True) def __init__(self, *args, **kwargs): - + """Initialization routine for the serializer""" order_detail = kwargs.pop('order_detail', False) part_detail = kwargs.pop('part_detail', True) item_detail = kwargs.pop('item_detail', False) @@ -761,6 +760,8 @@ class SalesOrderAllocationSerializer(InvenTreeModelSerializer): self.fields.pop('customer_detail') class Meta: + """Metaclass options.""" + model = order.models.SalesOrderAllocation fields = [ @@ -783,16 +784,14 @@ class SalesOrderAllocationSerializer(InvenTreeModelSerializer): class SalesOrderLineItemSerializer(InvenTreeModelSerializer): - """ Serializer for a SalesOrderLineItem object """ + """Serializer for a SalesOrderLineItem object.""" @staticmethod def annotate_queryset(queryset): - """ - Add some extra annotations to this queryset: + """Add some extra annotations to this queryset: - "Overdue" status (boolean field) """ - queryset = queryset.annotate( overdue=Case( When( @@ -803,7 +802,10 @@ class SalesOrderLineItemSerializer(InvenTreeModelSerializer): ) def __init__(self, *args, **kwargs): + """Initializion routine for the serializer: + - Add extra related serializer information if required + """ part_detail = kwargs.pop('part_detail', False) order_detail = kwargs.pop('order_detail', False) allocations = kwargs.pop('allocations', False) @@ -843,6 +845,8 @@ class SalesOrderLineItemSerializer(InvenTreeModelSerializer): ) class Meta: + """Metaclass options.""" + model = order.models.SalesOrderLineItem fields = [ @@ -866,15 +870,15 @@ class SalesOrderLineItemSerializer(InvenTreeModelSerializer): class SalesOrderShipmentSerializer(InvenTreeModelSerializer): - """ - Serializer for the SalesOrderShipment class - """ + """Serializer for the SalesOrderShipment class.""" allocations = SalesOrderAllocationSerializer(many=True, read_only=True, location_detail=True) order_detail = SalesOrderSerializer(source='order', read_only=True, many=False) class Meta: + """Metaclass options.""" + model = order.models.SalesOrderShipment fields = [ @@ -893,11 +897,11 @@ class SalesOrderShipmentSerializer(InvenTreeModelSerializer): class SalesOrderShipmentCompleteSerializer(serializers.ModelSerializer): - """ - Serializer for completing (shipping) a SalesOrderShipment - """ + """Serializer for completing (shipping) a SalesOrderShipment.""" class Meta: + """Metaclass options.""" + model = order.models.SalesOrderShipment fields = [ @@ -908,7 +912,10 @@ class SalesOrderShipmentCompleteSerializer(serializers.ModelSerializer): ] def validate(self, data): + """Custom validation for the serializer: + - Ensure the shipment reference is provided + """ data = super().validate(data) shipment = self.context.get('shipment', None) @@ -921,7 +928,7 @@ class SalesOrderShipmentCompleteSerializer(serializers.ModelSerializer): return data def save(self): - + """Save the serializer to complete the SalesOrderShipment""" shipment = self.context.get('shipment', None) if not shipment: @@ -945,11 +952,11 @@ class SalesOrderShipmentCompleteSerializer(serializers.ModelSerializer): class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer): - """ - A serializer for allocating a single stock-item against a SalesOrder shipment - """ + """A serializer for allocating a single stock-item against a SalesOrder shipment.""" class Meta: + """Metaclass options.""" + fields = [ 'line_item', 'stock_item', @@ -965,7 +972,10 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer): ) def validate_line_item(self, line_item): + """Custom validation for the 'line_item' field: + - Ensure the line_item is associated with the particular SalesOrder + """ order = self.context['order'] # Ensure that the line item points to the correct order @@ -990,14 +1000,18 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer): ) def validate_quantity(self, quantity): - + """Custom validation for the 'quantity' field""" if quantity <= 0: raise ValidationError(_("Quantity must be positive")) return quantity def validate(self, data): + """Custom validation for the serializer: + - Ensure that the quantity is 1 for serialized stock + - Quantity cannot exceed the available amount + """ data = super().validate(data) stock_item = data['stock_item'] @@ -1019,12 +1033,10 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer): class SalesOrderCompleteSerializer(serializers.Serializer): - """ - DRF serializer for manually marking a sales order as complete - """ + """DRF serializer for manually marking a sales order as complete.""" def validate(self, data): - + """Custom validation for the serializer""" data = super().validate(data) order = self.context['order'] @@ -1034,7 +1046,7 @@ class SalesOrderCompleteSerializer(serializers.Serializer): return data def save(self): - + """Save the serializer to complete the SalesOrder""" request = self.context['request'] order = self.context['order'] @@ -1044,11 +1056,10 @@ class SalesOrderCompleteSerializer(serializers.Serializer): class SalesOrderCancelSerializer(serializers.Serializer): - """ Serializer for marking a SalesOrder as cancelled - """ + """Serializer for marking a SalesOrder as cancelled.""" def get_context_data(self): - + """Add extra context data to the serializer""" order = self.context['order'] return { @@ -1056,18 +1067,18 @@ class SalesOrderCancelSerializer(serializers.Serializer): } def save(self): - + """Save the serializer to cancel the order""" order = self.context['order'] order.cancel_order() class SalesOrderSerialAllocationSerializer(serializers.Serializer): - """ - DRF serializer for allocation of serial numbers against a sales order / shipment - """ + """DRF serializer for allocation of serial numbers against a sales order / shipment.""" class Meta: + """Metaclass options.""" + fields = [ 'line_item', 'quantity', @@ -1084,10 +1095,7 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): ) def validate_line_item(self, line_item): - """ - Ensure that the line_item is valid - """ - + """Ensure that the line_item is valid.""" order = self.context['order'] # Ensure that the line item points to the correct order @@ -1119,13 +1127,11 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): ) def validate_shipment(self, shipment): - """ - Validate the shipment: + """Validate the shipment: - Must point to the same order - Must not be shipped """ - order = self.context['order'] if shipment.shipment_date is not None: @@ -1137,14 +1143,12 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): return shipment def validate(self, data): - """ - Validation for the serializer: + """Validation for the serializer: - Ensure the serial_numbers and quantity fields match - Check that all serial numbers exist - Check that the serial numbers are not yet allocated """ - data = super().validate(data) line_item = data['line_item'] @@ -1207,7 +1211,7 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): return data def save(self): - + """Allocate stock items against the sales order""" data = self.validated_data line_item = data['line_item'] @@ -1226,11 +1230,11 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): class SalesOrderShipmentAllocationSerializer(serializers.Serializer): - """ - DRF serializer for allocation of stock items against a sales order / shipment - """ + """DRF serializer for allocation of stock items against a sales order / shipment.""" class Meta: + """Metaclass options.""" + fields = [ 'items', 'shipment', @@ -1247,10 +1251,7 @@ class SalesOrderShipmentAllocationSerializer(serializers.Serializer): ) def validate_shipment(self, shipment): - """ - Run validation against the provided shipment instance - """ - + """Run validation against the provided shipment instance.""" order = self.context['order'] if shipment.shipment_date is not None: @@ -1262,10 +1263,7 @@ class SalesOrderShipmentAllocationSerializer(serializers.Serializer): return shipment def validate(self, data): - """ - Serializer validation - """ - + """Serializer validation.""" data = super().validate(data) # Extract SalesOrder from serializer context @@ -1279,10 +1277,7 @@ class SalesOrderShipmentAllocationSerializer(serializers.Serializer): return data def save(self): - """ - Perform the allocation of items against this order - """ - + """Perform the allocation of items against this order.""" data = self.validated_data items = data['items'] @@ -1304,20 +1299,22 @@ class SalesOrderShipmentAllocationSerializer(serializers.Serializer): class SalesOrderExtraLineSerializer(AbstractExtraLineSerializer, InvenTreeModelSerializer): - """ Serializer for a SalesOrderExtraLine object """ + """Serializer for a SalesOrderExtraLine object.""" order_detail = SalesOrderSerializer(source='order', many=False, read_only=True) class Meta(AbstractExtraLineMeta): + """Metaclass options.""" + model = order.models.SalesOrderExtraLine class SalesOrderAttachmentSerializer(InvenTreeAttachmentSerializer): - """ - Serializers for the SalesOrderAttachment model - """ + """Serializers for the SalesOrderAttachment model.""" class Meta: + """Metaclass options.""" + model = order.models.SalesOrderAttachment fields = [ diff --git a/InvenTree/order/test_api.py b/InvenTree/order/test_api.py index 6be694bb11..f3bc6e9a42 100644 --- a/InvenTree/order/test_api.py +++ b/InvenTree/order/test_api.py @@ -1,6 +1,4 @@ -""" -Tests for the Order API -""" +"""Tests for the Order API.""" import io from datetime import datetime, timedelta @@ -18,7 +16,7 @@ from stock.models import StockItem class OrderTest(InvenTreeAPITestCase): - + """Base class for order API unit testing""" fixtures = [ 'category', 'part', @@ -35,14 +33,8 @@ class OrderTest(InvenTreeAPITestCase): 'sales_order.change', ] - def setUp(self): - super().setUp() - def filter(self, filters, count): - """ - Test API filters - """ - + """Test API filters.""" response = self.get( self.LIST_URL, filters @@ -55,14 +47,12 @@ class OrderTest(InvenTreeAPITestCase): class PurchaseOrderTest(OrderTest): - """ - Tests for the PurchaseOrder API - """ + """Tests for the PurchaseOrder API.""" LIST_URL = reverse('api-po-list') def test_po_list(self): - + """Test the PurchaseOrder list API endpoint""" # List *ALL* PurchaseOrder items self.filter({}, 7) @@ -79,10 +69,7 @@ class PurchaseOrderTest(OrderTest): self.filter({'status': 40}, 1) def test_overdue(self): - """ - Test "overdue" status - """ - + """Test "overdue" status.""" self.filter({'overdue': True}, 0) self.filter({'overdue': False}, 7) @@ -94,7 +81,7 @@ class PurchaseOrderTest(OrderTest): self.filter({'overdue': False}, 6) def test_po_detail(self): - + """Test the PurchaseOrder detail API endpoint""" url = '/api/order/po/1/' response = self.get(url) @@ -107,7 +94,7 @@ class PurchaseOrderTest(OrderTest): self.assertEqual(data['description'], 'Ordering some screws') def test_po_reference(self): - """test that a reference with a too big / small reference is not possible""" + """Test that a reference with a too big / small reference is not possible.""" # get permissions self.assignRole('purchase_order.add') @@ -125,7 +112,7 @@ class PurchaseOrderTest(OrderTest): ) def test_po_attachments(self): - + """Test the list endpoint for the PurchaseOrderAttachment model""" url = reverse('api-po-attachment-list') response = self.get(url) @@ -133,10 +120,7 @@ class PurchaseOrderTest(OrderTest): self.assertEqual(response.status_code, status.HTTP_200_OK) def test_po_operations(self): - """ - Test that we can create / edit and delete a PurchaseOrder via the API - """ - + """Test that we can create / edit and delete a PurchaseOrder via the API.""" n = models.PurchaseOrder.objects.count() url = reverse('api-po-list') @@ -223,10 +207,7 @@ class PurchaseOrderTest(OrderTest): response = self.get(url, expected_code=404) def test_po_create(self): - """ - Test that we can create a new PurchaseOrder via the API - """ - + """Test that we can create a new PurchaseOrder via the API.""" self.assignRole('purchase_order.add') self.post( @@ -240,10 +221,7 @@ class PurchaseOrderTest(OrderTest): ) def test_po_cancel(self): - """ - Test the PurchaseOrderCancel API endpoint - """ - + """Test the PurchaseOrderCancel API endpoint.""" po = models.PurchaseOrder.objects.get(pk=1) self.assertEqual(po.status, PurchaseOrderStatus.PENDING) @@ -269,8 +247,7 @@ class PurchaseOrderTest(OrderTest): self.post(url, {}, expected_code=400) def test_po_complete(self): - """ Test the PurchaseOrderComplete API endpoint """ - + """Test the PurchaseOrderComplete API endpoint.""" po = models.PurchaseOrder.objects.get(pk=3) url = reverse('api-po-complete', kwargs={'pk': po.pk}) @@ -289,8 +266,7 @@ class PurchaseOrderTest(OrderTest): self.assertEqual(po.status, PurchaseOrderStatus.COMPLETE) def test_po_issue(self): - """ Test the PurchaseOrderIssue API endpoint """ - + """Test the PurchaseOrderIssue API endpoint.""" po = models.PurchaseOrder.objects.get(pk=2) url = reverse('api-po-issue', kwargs={'pk': po.pk}) @@ -307,6 +283,7 @@ class PurchaseOrderTest(OrderTest): self.assertEqual(po.status, PurchaseOrderStatus.PLACED) def test_po_metadata(self): + """Test the 'metadata' endpoint for the PurchaseOrder model""" url = reverse('api-po-metadata', kwargs={'pk': 1}) self.patch( @@ -324,7 +301,7 @@ class PurchaseOrderTest(OrderTest): class PurchaseOrderDownloadTest(OrderTest): - """Unit tests for downloading PurchaseOrder data via the API endpoint""" + """Unit tests for downloading PurchaseOrder data via the API endpoint.""" required_cols = [ 'id', @@ -342,8 +319,7 @@ class PurchaseOrderDownloadTest(OrderTest): ] def test_download_wrong_format(self): - """Incorrect format should default raise an error""" - + """Incorrect format should default raise an error.""" url = reverse('api-po-list') with self.assertRaises(ValueError): @@ -355,8 +331,7 @@ class PurchaseOrderDownloadTest(OrderTest): ) def test_download_csv(self): - """Download PurchaseOrder data as .csv""" - + """Download PurchaseOrder data as .csv.""" with self.download_file( reverse('api-po-list'), { @@ -380,7 +355,7 @@ class PurchaseOrderDownloadTest(OrderTest): self.assertEqual(order.reference, row['reference']) def test_download_line_items(self): - + """Test that the PurchaseOrderLineItems can be downloaded to a file""" with self.download_file( reverse('api-po-line-list'), { @@ -395,11 +370,10 @@ class PurchaseOrderDownloadTest(OrderTest): class PurchaseOrderReceiveTest(OrderTest): - """ - Unit tests for receiving items against a PurchaseOrder - """ + """Unit tests for receiving items against a PurchaseOrder.""" def setUp(self): + """Init routines for this unit test class""" super().setUp() self.assignRole('purchase_order.add') @@ -415,10 +389,7 @@ class PurchaseOrderReceiveTest(OrderTest): order.save() def test_empty(self): - """ - Test without any POST data - """ - + """Test without any POST data.""" data = self.post(self.url, {}, expected_code=400).data self.assertIn('This field is required', str(data['items'])) @@ -428,10 +399,7 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertEqual(self.n, StockItem.objects.count()) def test_no_items(self): - """ - Test with an empty list of items - """ - + """Test with an empty list of items.""" data = self.post( self.url, { @@ -447,10 +415,7 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertEqual(self.n, StockItem.objects.count()) def test_invalid_items(self): - """ - Test than errors are returned as expected for invalid data - """ - + """Test than errors are returned as expected for invalid data.""" data = self.post( self.url, { @@ -473,10 +438,7 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertEqual(self.n, StockItem.objects.count()) def test_invalid_status(self): - """ - Test with an invalid StockStatus value - """ - + """Test with an invalid StockStatus value.""" data = self.post( self.url, { @@ -498,10 +460,7 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertEqual(self.n, StockItem.objects.count()) def test_mismatched_items(self): - """ - Test for supplier parts which *do* exist but do not match the order supplier - """ - + """Test for supplier parts which *do* exist but do not match the order supplier.""" data = self.post( self.url, { @@ -523,10 +482,7 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertEqual(self.n, StockItem.objects.count()) def test_null_barcode(self): - """ - Test than a "null" barcode field can be provided - """ - + """Test than a "null" barcode field can be provided.""" # Set stock item barcode item = StockItem.objects.get(pk=1) item.save() @@ -548,13 +504,11 @@ class PurchaseOrderReceiveTest(OrderTest): ) def test_invalid_barcodes(self): - """ - Tests for checking in items with invalid barcodes: + """Tests for checking in items with invalid barcodes: - Cannot check in "duplicate" barcodes - Barcodes cannot match UID field for existing StockItem """ - # Set stock item barcode item = StockItem.objects.get(pk=1) item.uid = 'MY-BARCODE-HASH' @@ -603,10 +557,7 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertEqual(self.n, StockItem.objects.count()) def test_valid(self): - """ - Test receipt of valid data - """ - + """Test receipt of valid data.""" line_1 = models.PurchaseOrderLineItem.objects.get(pk=1) line_2 = models.PurchaseOrderLineItem.objects.get(pk=2) @@ -683,10 +634,7 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertTrue(StockItem.objects.filter(uid='MY-UNIQUE-BARCODE-456').exists()) def test_batch_code(self): - """ - Test that we can supply a 'batch code' when receiving items - """ - + """Test that we can supply a 'batch code' when receiving items.""" line_1 = models.PurchaseOrderLineItem.objects.get(pk=1) line_2 = models.PurchaseOrderLineItem.objects.get(pk=2) @@ -727,10 +675,7 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertEqual(item_2.batch, 'xyz-789') def test_serial_numbers(self): - """ - Test that we can supply a 'serial number' when receiving items - """ - + """Test that we can supply a 'serial number' when receiving items.""" line_1 = models.PurchaseOrderLineItem.objects.get(pk=1) line_2 = models.PurchaseOrderLineItem.objects.get(pk=2) @@ -786,14 +731,12 @@ class PurchaseOrderReceiveTest(OrderTest): class SalesOrderTest(OrderTest): - """ - Tests for the SalesOrder API - """ + """Tests for the SalesOrder API.""" LIST_URL = reverse('api-so-list') def test_so_list(self): - + """Test the SalesOrder list API endpoint""" # All orders self.filter({}, 5) @@ -811,10 +754,7 @@ class SalesOrderTest(OrderTest): self.filter({'status': 99}, 0) # Invalid def test_overdue(self): - """ - Test "overdue" status - """ - + """Test "overdue" status.""" self.filter({'overdue': True}, 0) self.filter({'overdue': False}, 5) @@ -827,7 +767,7 @@ class SalesOrderTest(OrderTest): self.filter({'overdue': False}, 3) def test_so_detail(self): - + """Test the SalesOrder detail endpoint""" url = '/api/order/so/1/' response = self.get(url) @@ -837,16 +777,13 @@ class SalesOrderTest(OrderTest): self.assertEqual(data['pk'], 1) def test_so_attachments(self): - + """Test the list endpoint for the SalesOrderAttachment model""" url = reverse('api-so-attachment-list') self.get(url) def test_so_operations(self): - """ - Test that we can create / edit and delete a SalesOrder via the API - """ - + """Test that we can create / edit and delete a SalesOrder via the API.""" n = models.SalesOrder.objects.count() url = reverse('api-so-list') @@ -926,10 +863,7 @@ class SalesOrderTest(OrderTest): response = self.get(url, expected_code=404) def test_so_create(self): - """ - Test that we can create a new SalesOrder via the API - """ - + """Test that we can create a new SalesOrder via the API.""" self.assignRole('sales_order.add') self.post( @@ -943,8 +877,7 @@ class SalesOrderTest(OrderTest): ) def test_so_cancel(self): - """ Test API endpoint for cancelling a SalesOrder """ - + """Test API endpoint for cancelling a SalesOrder.""" so = models.SalesOrder.objects.get(pk=1) self.assertEqual(so.status, SalesOrderStatus.PENDING) @@ -963,6 +896,7 @@ class SalesOrderTest(OrderTest): self.assertEqual(so.status, SalesOrderStatus.CANCELLED) def test_so_metadata(self): + """Test the 'metadata' API endpoint for the SalesOrder model""" url = reverse('api-so-metadata', kwargs={'pk': 1}) self.patch( @@ -980,12 +914,10 @@ class SalesOrderTest(OrderTest): class SalesOrderLineItemTest(OrderTest): - """ - Tests for the SalesOrderLineItem API - """ + """Tests for the SalesOrderLineItem API.""" def setUp(self): - + """Init routine for this unit test class""" super().setUp() # List of salable parts @@ -1005,9 +937,7 @@ class SalesOrderLineItemTest(OrderTest): self.url = reverse('api-so-line-list') def test_so_line_list(self): - - # List *all* lines - + """Test list endpoint""" response = self.get( self.url, {}, @@ -1060,17 +990,17 @@ class SalesOrderLineItemTest(OrderTest): class SalesOrderDownloadTest(OrderTest): - """Unit tests for downloading SalesOrder data via the API endpoint""" + """Unit tests for downloading SalesOrder data via the API endpoint.""" def test_download_fail(self): - """Test that downloading without the 'export' option fails""" - + """Test that downloading without the 'export' option fails.""" url = reverse('api-so-list') with self.assertRaises(ValueError): self.download_file(url, {}, expected_code=200) def test_download_xls(self): + """Test xls file download""" url = reverse('api-so-list') # Download .xls file @@ -1086,7 +1016,7 @@ class SalesOrderDownloadTest(OrderTest): self.assertTrue(isinstance(fo, io.BytesIO)) def test_download_csv(self): - + """Tesst that the list of sales orders can be downloaded as a .csv file""" url = reverse('api-so-list') required_cols = [ @@ -1151,11 +1081,10 @@ class SalesOrderDownloadTest(OrderTest): class SalesOrderAllocateTest(OrderTest): - """ - Unit tests for allocating stock items against a SalesOrder - """ + """Unit tests for allocating stock items against a SalesOrder.""" def setUp(self): + """Init routines for this unit testing class""" super().setUp() self.assignRole('sales_order.add') @@ -1188,10 +1117,7 @@ class SalesOrderAllocateTest(OrderTest): ) def test_invalid(self): - """ - Test POST with invalid data - """ - + """Test POST with invalid data.""" # No data response = self.post(self.url, {}, expected_code=400) @@ -1244,11 +1170,7 @@ class SalesOrderAllocateTest(OrderTest): self.assertIn('Shipment is not associated with this order', str(response.data['shipment'])) def test_allocate(self): - """ - Test the the allocation endpoint acts as expected, - when provided with valid data! - """ - + """Test the the allocation endpoint acts as expected, when provided with valid data!""" # First, check that there are no line items allocated against this SalesOrder self.assertEqual(self.order.stock_allocations.count(), 0) @@ -1278,8 +1200,7 @@ class SalesOrderAllocateTest(OrderTest): self.assertEqual(line.allocations.count(), 1) def test_shipment_complete(self): - """Test that we can complete a shipment via the API""" - + """Test that we can complete a shipment via the API.""" url = reverse('api-so-shipment-ship', kwargs={'pk': self.shipment.pk}) self.assertFalse(self.shipment.is_complete()) @@ -1340,7 +1261,7 @@ class SalesOrderAllocateTest(OrderTest): self.assertEqual(self.shipment.link, 'http://test.com/link.html') def test_sales_order_shipment_list(self): - + """Test the SalesOrderShipment list API endpoint""" url = reverse('api-so-shipment-list') # Create some new shipments via the API diff --git a/InvenTree/order/test_migrations.py b/InvenTree/order/test_migrations.py index 61299a8e2f..1734501a9c 100644 --- a/InvenTree/order/test_migrations.py +++ b/InvenTree/order/test_migrations.py @@ -1,6 +1,4 @@ -""" -Unit tests for the 'order' model data migrations -""" +"""Unit tests for the 'order' model data migrations.""" from django_test_migrations.contrib.unittest_case import MigratorTestCase @@ -8,18 +6,13 @@ from InvenTree.status_codes import SalesOrderStatus class TestRefIntMigrations(MigratorTestCase): - """ - Test entire schema migration - """ + """Test entire schema migration.""" migrate_from = ('order', '0040_salesorder_target_date') migrate_to = ('order', '0061_merge_0054_auto_20211201_2139_0060_auto_20211129_1339') def prepare(self): - """ - Create initial data set - """ - + """Create initial data set.""" # Create a purchase order from a supplier Company = self.old_state.apps.get_model('company', 'company') @@ -57,10 +50,7 @@ class TestRefIntMigrations(MigratorTestCase): print(sales_order.reference_int) def test_ref_field(self): - """ - Test that the 'reference_int' field has been created and is filled out correctly - """ - + """Test that the 'reference_int' field has been created and is filled out correctly.""" PurchaseOrder = self.new_state.apps.get_model('order', 'purchaseorder') SalesOrder = self.new_state.apps.get_model('order', 'salesorder') @@ -75,18 +65,13 @@ class TestRefIntMigrations(MigratorTestCase): class TestShipmentMigration(MigratorTestCase): - """ - Test data migration for the "SalesOrderShipment" model - """ + """Test data migration for the "SalesOrderShipment" model.""" migrate_from = ('order', '0051_auto_20211014_0623') migrate_to = ('order', '0055_auto_20211025_0645') def prepare(self): - """ - Create an initial SalesOrder - """ - + """Create an initial SalesOrder.""" Company = self.old_state.apps.get_model('company', 'company') customer = Company.objects.create( @@ -112,10 +97,7 @@ class TestShipmentMigration(MigratorTestCase): self.old_state.apps.get_model('order', 'salesordershipment') def test_shipment_creation(self): - """ - Check that a SalesOrderShipment has been created - """ - + """Check that a SalesOrderShipment has been created.""" SalesOrder = self.new_state.apps.get_model('order', 'salesorder') Shipment = self.new_state.apps.get_model('order', 'salesordershipment') @@ -125,18 +107,13 @@ class TestShipmentMigration(MigratorTestCase): class TestAdditionalLineMigration(MigratorTestCase): - """ - Test entire schema migration - """ + """Test entire schema migration.""" migrate_from = ('order', '0063_alter_purchaseorderlineitem_unique_together') migrate_to = ('order', '0064_purchaseorderextraline_salesorderextraline') def prepare(self): - """ - Create initial data set - """ - + """Create initial data set.""" # Create a purchase order from a supplier Company = self.old_state.apps.get_model('company', 'company') PurchaseOrder = self.old_state.apps.get_model('order', 'purchaseorder') @@ -199,10 +176,7 @@ class TestAdditionalLineMigration(MigratorTestCase): # ) def test_po_migration(self): - """ - Test that the the PO lines where converted correctly - """ - + """Test that the the PO lines where converted correctly.""" PurchaseOrder = self.new_state.apps.get_model('order', 'purchaseorder') for ii in range(10): diff --git a/InvenTree/order/test_sales_order.py b/InvenTree/order/test_sales_order.py index b357b14f55..aeb11f72bc 100644 --- a/InvenTree/order/test_sales_order.py +++ b/InvenTree/order/test_sales_order.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +"""Unit tests for the SalesOrder models""" from datetime import datetime, timedelta @@ -15,13 +15,10 @@ from stock.models import StockItem class SalesOrderTest(TestCase): - """ - Run tests to ensure that the SalesOrder model is working correctly. - - """ + """Run tests to ensure that the SalesOrder model is working correctly.""" def setUp(self): - + """Initial setup for this set of unit tests""" # Create a Company to ship the goods to self.customer = Company.objects.create(name="ABC Co", description="My customer", is_customer=True) @@ -48,11 +45,21 @@ class SalesOrderTest(TestCase): # Create a line item self.line = SalesOrderLineItem.objects.create(quantity=50, order=self.order, part=self.part) - def test_overdue(self): - """ - Tests for overdue functionality - """ + def test_rebuild_reference(self): + """Test that the 'reference_int' field gets rebuilt when the model is saved""" + self.assertEqual(self.order.reference_int, 1234) + + self.order.reference = '999' + self.order.save() + self.assertEqual(self.order.reference_int, 999) + + self.order.reference = '1000K' + self.order.save() + self.assertEqual(self.order.reference_int, 1000) + + def test_overdue(self): + """Tests for overdue functionality.""" today = datetime.now().date() # By default, order is *not* overdue as the target date is not set @@ -69,6 +76,7 @@ class SalesOrderTest(TestCase): self.assertFalse(self.order.is_overdue) def test_empty_order(self): + """Test for an empty order""" self.assertEqual(self.line.quantity, 50) self.assertEqual(self.line.allocated_quantity(), 0) self.assertEqual(self.line.fulfilled_quantity(), 0) @@ -79,14 +87,13 @@ class SalesOrderTest(TestCase): self.assertFalse(self.order.is_fully_allocated()) def test_add_duplicate_line_item(self): - # Adding a duplicate line item to a SalesOrder is accepted + """Adding a duplicate line item to a SalesOrder is accepted""" for ii in range(1, 5): SalesOrderLineItem.objects.create(order=self.order, part=self.part, quantity=ii) def allocate_stock(self, full=True): - - # Allocate stock to the order + """Allocate stock to the order""" SalesOrderAllocation.objects.create( line=self.line, shipment=self.shipment, @@ -101,7 +108,7 @@ class SalesOrderTest(TestCase): ) def test_allocate_partial(self): - # Partially allocate stock + """Partially allocate stock""" self.allocate_stock(False) self.assertFalse(self.order.is_fully_allocated()) @@ -110,7 +117,7 @@ class SalesOrderTest(TestCase): self.assertEqual(self.line.fulfilled_quantity(), 0) def test_allocate_full(self): - # Fully allocate stock + """Fully allocate stock""" self.allocate_stock(True) self.assertTrue(self.order.is_fully_allocated()) @@ -118,8 +125,7 @@ class SalesOrderTest(TestCase): self.assertEqual(self.line.allocated_quantity(), 50) def test_order_cancel(self): - # Allocate line items then cancel the order - + """Allocate line items then cancel the order""" self.allocate_stock(True) self.assertEqual(SalesOrderAllocation.objects.count(), 2) @@ -137,8 +143,7 @@ class SalesOrderTest(TestCase): self.assertFalse(result) def test_complete_order(self): - # Allocate line items, then ship the order - + """Allocate line items, then ship the order""" # Assert some stuff before we run the test # Initially there are two stock items self.assertEqual(StockItem.objects.count(), 2) @@ -199,8 +204,7 @@ class SalesOrderTest(TestCase): self.assertEqual(self.line.allocated_quantity(), 50) def test_default_shipment(self): - # Test sales order default shipment creation - + """Test sales order default shipment creation""" # Default setting value should be False self.assertEqual(False, InvenTreeSetting.get_setting('SALESORDER_DEFAULT_SHIPMENT')) diff --git a/InvenTree/order/test_views.py b/InvenTree/order/test_views.py index aad0fed25d..cadd2914ad 100644 --- a/InvenTree/order/test_views.py +++ b/InvenTree/order/test_views.py @@ -1,4 +1,4 @@ -""" Unit tests for Order views (see views.py) """ +"""Unit tests for Order views (see views.py)""" from django.urls import reverse @@ -6,7 +6,7 @@ from InvenTree.helpers import InvenTreeTestCase class OrderViewTestCase(InvenTreeTestCase): - + """Base unit test class for order views""" fixtures = [ 'category', 'part', @@ -29,26 +29,26 @@ class OrderViewTestCase(InvenTreeTestCase): class OrderListTest(OrderViewTestCase): - + """Unit tests for the PurchaseOrder index page""" def test_order_list(self): + """Tests for the PurchaseOrder index page""" response = self.client.get(reverse('po-index')) self.assertEqual(response.status_code, 200) class PurchaseOrderTests(OrderViewTestCase): - """ Tests for PurchaseOrder views """ + """Tests for PurchaseOrder views.""" def test_detail_view(self): - """ Retrieve PO detail view """ + """Retrieve PO detail view.""" response = self.client.get(reverse('po-detail', args=(1,))) self.assertEqual(response.status_code, 200) keys = response.context.keys() self.assertIn('PurchaseOrderStatus', keys) def test_po_export(self): - """ Export PurchaseOrder """ - + """Export PurchaseOrder.""" response = self.client.get(reverse('po-export', args=(1,)), HTTP_X_REQUESTED_WITH='XMLHttpRequest') # Response should be streaming-content (file download) diff --git a/InvenTree/order/tests.py b/InvenTree/order/tests.py index 1c23be2829..b8fd945518 100644 --- a/InvenTree/order/tests.py +++ b/InvenTree/order/tests.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +"""Various unit tests for order models""" from datetime import datetime, timedelta @@ -14,9 +14,7 @@ from .models import PurchaseOrder, PurchaseOrderLineItem class OrderTest(TestCase): - """ - Tests to ensure that the order models are functioning correctly. - """ + """Tests to ensure that the order models are functioning correctly.""" fixtures = [ 'company', @@ -30,8 +28,7 @@ class OrderTest(TestCase): ] def test_basics(self): - """ Basic tests e.g. repr functions etc """ - + """Basic tests e.g. repr functions etc.""" order = PurchaseOrder.objects.get(pk=1) self.assertEqual(order.get_absolute_url(), '/order/purchase-order/1/') @@ -42,11 +39,19 @@ class OrderTest(TestCase): self.assertEqual(str(line), "100 x ACME0001 from ACME (for PO0001 - ACME)") - def test_overdue(self): - """ - Test overdue status functionality - """ + def test_rebuild_reference(self): + """Test that the reference_int field is correctly updated when the model is saved""" + order = PurchaseOrder.objects.get(pk=1) + order.save() + self.assertEqual(order.reference_int, 1) + + order.reference = '12345XYZ' + order.save() + self.assertEqual(order.reference_int, 12345) + + def test_overdue(self): + """Test overdue status functionality.""" today = datetime.now().date() order = PurchaseOrder.objects.get(pk=1) @@ -61,8 +66,7 @@ class OrderTest(TestCase): self.assertFalse(order.is_overdue) def test_on_order(self): - """ There should be 3 separate items on order for the M2x4 LPHS part """ - + """There should be 3 separate items on order for the M2x4 LPHS part.""" part = Part.objects.get(name='M2x4 LPHS') open_orders = [] @@ -76,8 +80,7 @@ class OrderTest(TestCase): self.assertEqual(part.on_order, 1400) def test_add_items(self): - """ Test functions for adding line items to an order """ - + """Test functions for adding line items to an order.""" order = PurchaseOrder.objects.get(pk=1) self.assertEqual(order.status, PurchaseOrderStatus.PENDING) @@ -113,8 +116,7 @@ class OrderTest(TestCase): order.add_line_item(sku, 99) def test_pricing(self): - """ Test functions for adding line items to an order including price-breaks """ - + """Test functions for adding line items to an order including price-breaks.""" order = PurchaseOrder.objects.get(pk=7) self.assertEqual(order.status, PurchaseOrderStatus.PENDING) @@ -146,8 +148,7 @@ class OrderTest(TestCase): self.assertEqual(order.lines.first().purchase_price.amount, 1.25) def test_receive(self): - """ Test order receiving functions """ - + """Test order receiving functions.""" part = Part.objects.get(name='M2x4 LPHS') # Receive some items diff --git a/InvenTree/order/urls.py b/InvenTree/order/urls.py index 15e7f5b1bb..278914bd75 100644 --- a/InvenTree/order/urls.py +++ b/InvenTree/order/urls.py @@ -1,5 +1,4 @@ -""" -URL lookup for the Order app. Provides URL endpoints for: +"""URL lookup for the Order app. Provides URL endpoints for: - List view of Purchase Orders - Detail view of Purchase Orders diff --git a/InvenTree/order/views.py b/InvenTree/order/views.py index 45b7402dc0..b3909971e0 100644 --- a/InvenTree/order/views.py +++ b/InvenTree/order/views.py @@ -1,6 +1,4 @@ -""" -Django views for interacting with Order app -""" +"""Django views for interacting with Order app.""" import logging from decimal import Decimal, InvalidOperation @@ -33,48 +31,36 @@ logger = logging.getLogger("inventree") class PurchaseOrderIndex(InvenTreeRoleMixin, ListView): - """ List view for all purchase orders """ + """List view for all purchase orders.""" model = PurchaseOrder template_name = 'order/purchase_orders.html' context_object_name = 'orders' def get_queryset(self): - """ Retrieve the list of purchase orders, - ensure that the most recent ones are returned first. """ - + """Retrieve the list of purchase orders, ensure that the most recent ones are returned first.""" queryset = PurchaseOrder.objects.all().order_by('-creation_date') return queryset - def get_context_data(self, **kwargs): - ctx = super().get_context_data(**kwargs) - - return ctx - class SalesOrderIndex(InvenTreeRoleMixin, ListView): - + """SalesOrder index (list) view class""" model = SalesOrder template_name = 'order/sales_orders.html' context_object_name = 'orders' class PurchaseOrderDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): - """ Detail view for a PurchaseOrder object """ + """Detail view for a PurchaseOrder object.""" context_object_name = 'order' queryset = PurchaseOrder.objects.all().prefetch_related('lines') template_name = 'order/purchase_order_detail.html' - def get_context_data(self, **kwargs): - ctx = super().get_context_data(**kwargs) - - return ctx - class SalesOrderDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): - """ Detail view for a SalesOrder object """ + """Detail view for a SalesOrder object.""" context_object_name = 'order' queryset = SalesOrder.objects.all().prefetch_related('lines__allocations__item__purchase_order') @@ -82,9 +68,10 @@ class SalesOrderDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView) class PurchaseOrderUpload(FileManagementFormView): - ''' PurchaseOrder: Upload file, match to fields and parts (using multi-Step form) ''' + """PurchaseOrder: Upload file, match to fields and parts (using multi-Step form)""" class OrderFileManager(FileManager): + """Specify required fields""" REQUIRED_HEADERS = [ 'Quantity', ] @@ -126,13 +113,11 @@ class PurchaseOrderUpload(FileManagementFormView): file_manager_class = OrderFileManager def get_order(self): - """ Get order or return 404 """ - + """Get order or return 404.""" return get_object_or_404(PurchaseOrder, pk=self.kwargs['pk']) def get_context_data(self, form, **kwargs): - """ Handle context data for order """ - + """Handle context data for order.""" context = super().get_context_data(form=form, **kwargs) order = self.get_order() @@ -142,11 +127,11 @@ class PurchaseOrderUpload(FileManagementFormView): return context def get_field_selection(self): - """ Once data columns have been selected, attempt to pre-select the proper data from the database. + """Once data columns have been selected, attempt to pre-select the proper data from the database. + This function is called once the field selection has been validated. The pre-fill data are then passed through to the SupplierPart selection form. """ - order = self.get_order() self.allowed_items = SupplierPart.objects.filter(supplier=order.supplier).prefetch_related('manufacturer_part') @@ -231,8 +216,7 @@ class PurchaseOrderUpload(FileManagementFormView): row['notes'] = notes def done(self, form_list, **kwargs): - """ Once all the data is in, process it to add PurchaseOrderLineItem instances to the order """ - + """Once all the data is in, process it to add PurchaseOrderLineItem instances to the order.""" order = self.get_order() items = self.get_clean_items() @@ -263,8 +247,7 @@ class PurchaseOrderUpload(FileManagementFormView): class SalesOrderExport(AjaxView): - """ - Export a sales order + """Export a sales order. - File format can optionally be passed as a query parameter e.g. ?format=CSV - Default file format is CSV @@ -275,7 +258,7 @@ class SalesOrderExport(AjaxView): role_required = 'sales_order.view' def get(self, request, *args, **kwargs): - + """Perform GET request to export SalesOrder dataset""" order = get_object_or_404(SalesOrder, pk=self.kwargs.get('pk', None)) export_format = request.GET.get('format', 'csv') @@ -290,7 +273,7 @@ class SalesOrderExport(AjaxView): class PurchaseOrderExport(AjaxView): - """ File download for a purchase order + """File download for a purchase order. - File format can be optionally passed as a query param e.g. ?format=CSV - Default file format is CSV @@ -302,7 +285,7 @@ class PurchaseOrderExport(AjaxView): role_required = 'purchase_order.view' def get(self, request, *args, **kwargs): - + """Perform GET request to export PurchaseOrder dataset""" order = get_object_or_404(PurchaseOrder, pk=self.kwargs.get('pk', None)) export_format = request.GET.get('format', 'csv') @@ -321,15 +304,17 @@ class PurchaseOrderExport(AjaxView): class LineItemPricing(PartPricing): - """ View for inspecting part pricing information """ + """View for inspecting part pricing information.""" class EnhancedForm(PartPricing.form_class): + """Extra form options""" pk = IntegerField(widget=HiddenInput()) so_line = IntegerField(widget=HiddenInput()) form_class = EnhancedForm def get_part(self, id=False): + """Return the Part instance associated with this view""" if 'line_item' in self.request.GET: try: part_id = self.request.GET.get('line_item') @@ -350,6 +335,7 @@ class LineItemPricing(PartPricing): return part def get_so(self, pk=False): + """Return the SalesOrderLineItem associated with this view""" so_line = self.request.GET.get('line_item', None) if not so_line: so_line = self.request.POST.get('so_line', None) @@ -365,20 +351,21 @@ class LineItemPricing(PartPricing): return None def get_quantity(self): - """ Return set quantity in decimal format """ + """Return set quantity in decimal format.""" qty = Decimal(self.request.GET.get('quantity', 1)) if qty == 1: return Decimal(self.request.POST.get('quantity', 1)) return qty def get_initials(self): + """Return initial context values for this view""" initials = super().get_initials() initials['pk'] = self.get_part(id=True) initials['so_line'] = self.get_so(pk=True) return initials def post(self, request, *args, **kwargs): - # parse extra actions + """Respond to a POST request to get particular pricing information""" REF = 'act-btn_' act_btn = [a.replace(REF, '') for a in self.request.POST if REF in a] diff --git a/InvenTree/part/__init__.py b/InvenTree/part/__init__.py index e5912481ff..1412ff8f11 100644 --- a/InvenTree/part/__init__.py +++ b/InvenTree/part/__init__.py @@ -1,9 +1 @@ -""" -The Part module is responsible for Part management. - -It includes models for: - -- PartCategory -- Part -- BomItem -""" +"""The Part module is responsible for Part management.""" diff --git a/InvenTree/part/admin.py b/InvenTree/part/admin.py index 88064ff275..9c1648e616 100644 --- a/InvenTree/part/admin.py +++ b/InvenTree/part/admin.py @@ -1,3 +1,5 @@ +"""Admin class definitions for the 'part' app""" + from django.contrib import admin import import_export.widgets as widgets @@ -11,7 +13,7 @@ from stock.models import StockLocation class PartResource(ModelResource): - """ Class for managing Part data import/export """ + """Class for managing Part data import/export.""" # ForeignKey fields category = Field(attribute='category', widget=widgets.ForeignKeyWidget(models.PartCategory)) @@ -38,6 +40,7 @@ class PartResource(ModelResource): building = Field(attribute='quantity_being_built', readonly=True, widget=widgets.IntegerWidget()) class Meta: + """Metaclass definition""" model = models.Part skip_unchanged = True report_skipped = False @@ -49,8 +52,7 @@ class PartResource(ModelResource): ] def get_queryset(self): - """ Prefetch related data for quicker access """ - + """Prefetch related data for quicker access.""" query = super().get_queryset() query = query.prefetch_related( 'category', @@ -62,8 +64,17 @@ class PartResource(ModelResource): return query + def after_import(self, dataset, result, using_transactions, dry_run, **kwargs): + """Rebuild MPTT tree structure after importing Part data""" + + super().after_import(dataset, result, using_transactions, dry_run, **kwargs) + + # Rebuild the Part tree(s) + models.Part.objects.rebuild() + class PartAdmin(ImportExportModelAdmin): + """Admin class for the Part model""" resource_class = PartResource @@ -82,7 +93,7 @@ class PartAdmin(ImportExportModelAdmin): class PartCategoryResource(ModelResource): - """ Class for managing PartCategory data import/export """ + """Class for managing PartCategory data import/export.""" parent = Field(attribute='parent', widget=widgets.ForeignKeyWidget(models.PartCategory)) @@ -91,6 +102,7 @@ class PartCategoryResource(ModelResource): default_location = Field(attribute='default_location', widget=widgets.ForeignKeyWidget(StockLocation)) class Meta: + """Metaclass definition""" model = models.PartCategory skip_unchanged = True report_skipped = False @@ -103,6 +115,7 @@ class PartCategoryResource(ModelResource): ] def after_import(self, dataset, result, using_transactions, dry_run, **kwargs): + """Rebuild MPTT tree structure after importing PartCategory data""" super().after_import(dataset, result, using_transactions, dry_run, **kwargs) @@ -111,6 +124,7 @@ class PartCategoryResource(ModelResource): class PartCategoryAdmin(ImportExportModelAdmin): + """Admin class for the PartCategory model""" resource_class = PartCategoryResource @@ -122,35 +136,21 @@ class PartCategoryAdmin(ImportExportModelAdmin): class PartRelatedAdmin(admin.ModelAdmin): - """ - Class to manage PartRelated objects - """ + """Class to manage PartRelated objects.""" autocomplete_fields = ('part_1', 'part_2') class PartAttachmentAdmin(admin.ModelAdmin): + """Admin class for the PartAttachment model""" list_display = ('part', 'attachment', 'comment') autocomplete_fields = ('part',) -class PartStarAdmin(admin.ModelAdmin): - - list_display = ('part', 'user') - - autocomplete_fields = ('part',) - - -class PartCategoryStarAdmin(admin.ModelAdmin): - - list_display = ('category', 'user') - - autocomplete_fields = ('category',) - - class PartTestTemplateAdmin(admin.ModelAdmin): + """Admin class for the PartTestTemplate model""" list_display = ('part', 'test_name', 'required') @@ -158,7 +158,7 @@ class PartTestTemplateAdmin(admin.ModelAdmin): class BomItemResource(ModelResource): - """ Class for managing BomItem data import/export """ + """Class for managing BomItem data import/export.""" level = Field(attribute='level', readonly=True) @@ -189,25 +189,18 @@ class BomItemResource(ModelResource): sub_assembly = Field(attribute='sub_part__assembly', readonly=True) def dehydrate_quantity(self, item): - """ - Special consideration for the 'quantity' field on data export. - We do not want a spreadsheet full of "1.0000" (we'd rather "1") + """Special consideration for the 'quantity' field on data export. We do not want a spreadsheet full of "1.0000" (we'd rather "1") Ref: https://django-import-export.readthedocs.io/en/latest/getting_started.html#advanced-data-manipulation-on-export """ return float(item.quantity) def before_export(self, queryset, *args, **kwargs): - + """Perform before exporting data""" self.is_importing = kwargs.get('importing', False) def get_fields(self, **kwargs): - """ - If we are exporting for the purposes of generating - a 'bom-import' template, there are some fields which - we are not interested in. - """ - + """If we are exporting for the purposes of generating a 'bom-import' template, there are some fields which we are not interested in.""" fields = super().get_fields(**kwargs) # If we are not generating an "import" template, @@ -239,6 +232,7 @@ class BomItemResource(ModelResource): return fields class Meta: + """Metaclass definition""" model = models.BomItem skip_unchanged = True report_skipped = False @@ -253,6 +247,7 @@ class BomItemResource(ModelResource): class BomItemAdmin(ImportExportModelAdmin): + """Admin class for the BomItem model""" resource_class = BomItemResource @@ -264,13 +259,15 @@ class BomItemAdmin(ImportExportModelAdmin): class ParameterTemplateAdmin(ImportExportModelAdmin): + """Admin class for the PartParameterTemplate model""" + list_display = ('name', 'units') search_fields = ('name', 'units') class ParameterResource(ModelResource): - """ Class for managing PartParameter data import/export """ + """Class for managing PartParameter data import/export.""" part = Field(attribute='part', widget=widgets.ForeignKeyWidget(models.Part)) @@ -281,6 +278,7 @@ class ParameterResource(ModelResource): template_name = Field(attribute='template__name', readonly=True) class Meta: + """Metaclass definition""" model = models.PartParameter skip_unchanged = True report_skipped = False @@ -288,6 +286,7 @@ class ParameterResource(ModelResource): class ParameterAdmin(ImportExportModelAdmin): + """Admin class for the PartParameter model""" resource_class = ParameterResource @@ -297,21 +296,26 @@ class ParameterAdmin(ImportExportModelAdmin): class PartCategoryParameterAdmin(admin.ModelAdmin): + """Admin class for the PartCategoryParameterTemplate model""" autocomplete_fields = ('category', 'parameter_template',) class PartSellPriceBreakAdmin(admin.ModelAdmin): + """Admin class for the PartSellPriceBreak model""" class Meta: + """Metaclass definition""" model = models.PartSellPriceBreak list_display = ('part', 'quantity', 'price',) class PartInternalPriceBreakAdmin(admin.ModelAdmin): + """Admin class for the PartInternalPriceBreak model""" class Meta: + """Metaclass definition""" model = models.PartInternalPriceBreak list_display = ('part', 'quantity', 'price',) @@ -323,8 +327,6 @@ admin.site.register(models.Part, PartAdmin) admin.site.register(models.PartCategory, PartCategoryAdmin) admin.site.register(models.PartRelated, PartRelatedAdmin) admin.site.register(models.PartAttachment, PartAttachmentAdmin) -admin.site.register(models.PartStar, PartStarAdmin) -admin.site.register(models.PartCategoryStar, PartCategoryStarAdmin) admin.site.register(models.BomItem, BomItemAdmin) admin.site.register(models.PartParameterTemplate, ParameterTemplateAdmin) admin.site.register(models.PartParameter, ParameterAdmin) diff --git a/InvenTree/part/api.py b/InvenTree/part/api.py index c6c42a6aa9..348dec372d 100644 --- a/InvenTree/part/api.py +++ b/InvenTree/part/api.py @@ -1,6 +1,4 @@ -""" -Provides a JSON API for the Part app -""" +"""Provides a JSON API for the Part app.""" import datetime from decimal import Decimal, InvalidOperation @@ -41,7 +39,7 @@ from .models import (BomItem, BomItemSubstitute, Part, PartAttachment, class CategoryList(generics.ListCreateAPIView): - """ API endpoint for accessing a list of PartCategory objects. + """API endpoint for accessing a list of PartCategory objects. - GET: Return a list of PartCategory objects - POST: Create a new PartCategory object @@ -51,7 +49,7 @@ class CategoryList(generics.ListCreateAPIView): serializer_class = part_serializers.CategorySerializer def get_serializer_context(self): - + """Add extra context data to the serializer for the PartCategoryList endpoint""" ctx = super().get_serializer_context() try: @@ -63,11 +61,10 @@ class CategoryList(generics.ListCreateAPIView): return ctx def filter_queryset(self, queryset): - """ - Custom filtering: + """Custom filtering: + - Allow filtering by "null" parent to retrieve top-level part categories """ - queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -158,15 +155,13 @@ class CategoryList(generics.ListCreateAPIView): class CategoryDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for detail view of a single PartCategory object - """ + """API endpoint for detail view of a single PartCategory object.""" serializer_class = part_serializers.CategorySerializer queryset = PartCategory.objects.all() def get_serializer_context(self): - + """Add extra context to the serializer for the CategoryDetail endpoint""" ctx = super().get_serializer_context() try: @@ -178,7 +173,7 @@ class CategoryDetail(generics.RetrieveUpdateDestroyAPIView): return ctx def update(self, request, *args, **kwargs): - + """Perform 'update' function and mark this part as 'starred' (or not)""" if 'starred' in request.data: starred = str2bool(request.data.get('starred', False)) @@ -190,16 +185,17 @@ class CategoryDetail(generics.RetrieveUpdateDestroyAPIView): class CategoryMetadata(generics.RetrieveUpdateAPIView): - """API endpoint for viewing / updating PartCategory metadata""" + """API endpoint for viewing / updating PartCategory metadata.""" def get_serializer(self, *args, **kwargs): + """Return a MetadataSerializer pointing to the referenced PartCategory instance""" return MetadataSerializer(PartCategory, *args, **kwargs) queryset = PartCategory.objects.all() class CategoryParameterList(generics.ListAPIView): - """ API endpoint for accessing a list of PartCategoryParameterTemplate objects. + """API endpoint for accessing a list of PartCategoryParameterTemplate objects. - GET: Return a list of PartCategoryParameterTemplate objects """ @@ -208,13 +204,12 @@ class CategoryParameterList(generics.ListAPIView): serializer_class = part_serializers.CategoryParameterTemplateSerializer def get_queryset(self): - """ - Custom filtering: + """Custom filtering: + - Allow filtering by "null" parent to retrieve all categories parameter templates - Allow filtering by category - Allow traversing all parent categories """ - queryset = super().get_queryset() params = self.request.query_params @@ -241,9 +236,7 @@ class CategoryParameterList(generics.ListAPIView): class CategoryTree(generics.ListAPIView): - """ - API endpoint for accessing a list of PartCategory objects ready for rendering a tree. - """ + """API endpoint for accessing a list of PartCategory objects ready for rendering a tree.""" queryset = PartCategory.objects.all() serializer_class = part_serializers.CategoryTree @@ -258,18 +251,14 @@ class CategoryTree(generics.ListAPIView): class PartSalePriceDetail(generics.RetrieveUpdateDestroyAPIView): - """ - Detail endpoint for PartSellPriceBreak model - """ + """Detail endpoint for PartSellPriceBreak model.""" queryset = PartSellPriceBreak.objects.all() serializer_class = part_serializers.PartSalePriceSerializer class PartSalePriceList(generics.ListCreateAPIView): - """ - API endpoint for list view of PartSalePriceBreak model - """ + """API endpoint for list view of PartSalePriceBreak model.""" queryset = PartSellPriceBreak.objects.all() serializer_class = part_serializers.PartSalePriceSerializer @@ -284,18 +273,14 @@ class PartSalePriceList(generics.ListCreateAPIView): class PartInternalPriceDetail(generics.RetrieveUpdateDestroyAPIView): - """ - Detail endpoint for PartInternalPriceBreak model - """ + """Detail endpoint for PartInternalPriceBreak model.""" queryset = PartInternalPriceBreak.objects.all() serializer_class = part_serializers.PartInternalPriceSerializer class PartInternalPriceList(generics.ListCreateAPIView): - """ - API endpoint for list view of PartInternalPriceBreak model - """ + """API endpoint for list view of PartInternalPriceBreak model.""" queryset = PartInternalPriceBreak.objects.all() serializer_class = part_serializers.PartInternalPriceSerializer @@ -311,9 +296,7 @@ class PartInternalPriceList(generics.ListCreateAPIView): class PartAttachmentList(generics.ListCreateAPIView, AttachmentMixin): - """ - API endpoint for listing (and creating) a PartAttachment (file upload). - """ + """API endpoint for listing (and creating) a PartAttachment (file upload).""" queryset = PartAttachment.objects.all() serializer_class = part_serializers.PartAttachmentSerializer @@ -328,38 +311,30 @@ class PartAttachmentList(generics.ListCreateAPIView, AttachmentMixin): class PartAttachmentDetail(generics.RetrieveUpdateDestroyAPIView, AttachmentMixin): - """ - Detail endpoint for PartAttachment model - """ + """Detail endpoint for PartAttachment model.""" queryset = PartAttachment.objects.all() serializer_class = part_serializers.PartAttachmentSerializer class PartTestTemplateDetail(generics.RetrieveUpdateDestroyAPIView): - """ - Detail endpoint for PartTestTemplate model - """ + """Detail endpoint for PartTestTemplate model.""" queryset = PartTestTemplate.objects.all() serializer_class = part_serializers.PartTestTemplateSerializer class PartTestTemplateList(generics.ListCreateAPIView): - """ - API endpoint for listing (and creating) a PartTestTemplate. - """ + """API endpoint for listing (and creating) a PartTestTemplate.""" queryset = PartTestTemplate.objects.all() serializer_class = part_serializers.PartTestTemplateSerializer def filter_queryset(self, queryset): - """ - Filter the test list queryset. + """Filter the test list queryset. If filtering by 'part', we include results for any parts "above" the specified part. """ - queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -390,15 +365,13 @@ class PartTestTemplateList(generics.ListCreateAPIView): class PartThumbs(generics.ListAPIView): - """ - API endpoint for retrieving information on available Part thumbnails - """ + """API endpoint for retrieving information on available Part thumbnails.""" queryset = Part.objects.all() serializer_class = part_serializers.PartThumbSerializer def get_queryset(self): - + """Return a queryset which exlcudes any parts without images""" queryset = super().get_queryset() # Get all Parts which have an associated image @@ -407,11 +380,10 @@ class PartThumbs(generics.ListAPIView): return queryset def list(self, request, *args, **kwargs): - """ - Serialize the available Part images. + """Serialize the available Part images. + - Images may be used for multiple parts! """ - queryset = self.filter_queryset(self.get_queryset()) # Return the most popular parts first @@ -436,7 +408,7 @@ class PartThumbs(generics.ListAPIView): class PartThumbsUpdate(generics.RetrieveUpdateAPIView): - """ API endpoint for updating Part thumbnails""" + """API endpoint for updating Part thumbnails.""" queryset = Part.objects.all() serializer_class = part_serializers.PartThumbSerializerUpdate @@ -447,8 +419,7 @@ class PartThumbsUpdate(generics.RetrieveUpdateAPIView): class PartScheduling(generics.RetrieveAPIView): - """ - API endpoint for delivering "scheduling" information about a given part via the API. + """API endpoint for delivering "scheduling" information about a given part via the API. Returns a chronologically ordered list about future "scheduled" events, concerning stock levels for the part: @@ -462,7 +433,7 @@ class PartScheduling(generics.RetrieveAPIView): queryset = Part.objects.all() def retrieve(self, request, *args, **kwargs): - + """Return scheduling information for the referenced Part instance""" today = datetime.datetime.now().date() part = self.get_object() @@ -470,13 +441,12 @@ class PartScheduling(generics.RetrieveAPIView): schedule = [] def add_schedule_entry(date, quantity, title, label, url): - """ - Check if a scheduled entry should be added: + """Check if a scheduled entry should be added: + - date must be non-null - date cannot be in the "past" - quantity must not be zero """ - if date and date >= today and quantity != 0: schedule.append({ 'date': date, @@ -583,25 +553,22 @@ class PartScheduling(generics.RetrieveAPIView): class PartMetadata(generics.RetrieveUpdateAPIView): - """ - API endpoint for viewing / updating Part metadata - """ + """API endpoint for viewing / updating Part metadata.""" def get_serializer(self, *args, **kwargs): + """Returns a MetadataSerializer instance pointing to the referenced Part""" return MetadataSerializer(Part, *args, **kwargs) queryset = Part.objects.all() class PartSerialNumberDetail(generics.RetrieveAPIView): - """ - API endpoint for returning extra serial number information about a particular part - """ + """API endpoint for returning extra serial number information about a particular part.""" queryset = Part.objects.all() def retrieve(self, request, *args, **kwargs): - + """Return serial number information for the referenced Part instance""" part = self.get_object() # Calculate the "latest" serial number @@ -621,15 +588,13 @@ class PartSerialNumberDetail(generics.RetrieveAPIView): class PartCopyBOM(generics.CreateAPIView): - """ - API endpoint for duplicating a BOM - """ + """API endpoint for duplicating a BOM.""" queryset = Part.objects.all() serializer_class = part_serializers.PartCopyBOMSerializer def get_serializer_context(self): - + """Add custom information to the serializer context for this endpoint""" ctx = super().get_serializer_context() try: @@ -641,13 +606,13 @@ class PartCopyBOM(generics.CreateAPIView): class PartValidateBOM(generics.RetrieveUpdateAPIView): - """ - API endpoint for 'validating' the BOM for a given Part - """ + """API endpoint for 'validating' the BOM for a given Part.""" class BOMValidateSerializer(serializers.ModelSerializer): + """Simple serializer class for validating a single BomItem instance""" class Meta: + """Metaclass defines serializer fields""" model = Part fields = [ 'checksum', @@ -667,6 +632,7 @@ class PartValidateBOM(generics.RetrieveUpdateAPIView): ) def validate_valid(self, valid): + """Check that the 'valid' input was flagged""" if not valid: raise ValidationError(_('This option must be selected')) @@ -675,7 +641,7 @@ class PartValidateBOM(generics.RetrieveUpdateAPIView): serializer_class = BOMValidateSerializer def update(self, request, *args, **kwargs): - + """Validate the referenced BomItem instance""" part = self.get_object() partial = kwargs.pop('partial', False) @@ -691,7 +657,7 @@ class PartValidateBOM(generics.RetrieveUpdateAPIView): class PartDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of a single Part object """ + """API endpoint for detail view of a single Part object.""" queryset = Part.objects.all() serializer_class = part_serializers.PartSerializer @@ -699,6 +665,7 @@ class PartDetail(generics.RetrieveUpdateDestroyAPIView): starred_parts = None def get_queryset(self, *args, **kwargs): + """Return an annotated queryset object for the PartDetail endpoint""" queryset = super().get_queryset(*args, **kwargs) queryset = part_serializers.PartSerializer.annotate_queryset(queryset) @@ -706,7 +673,7 @@ class PartDetail(generics.RetrieveUpdateDestroyAPIView): return queryset def get_serializer(self, *args, **kwargs): - + """Return a serializer instance for the PartDetail endpoint""" # By default, include 'category_detail' information in the detail view try: kwargs['category_detail'] = str2bool(self.request.query_params.get('category_detail', True)) @@ -726,7 +693,11 @@ class PartDetail(generics.RetrieveUpdateDestroyAPIView): return self.serializer_class(*args, **kwargs) def destroy(self, request, *args, **kwargs): - # Retrieve part + """Delete a Part instance via the API + + - If the part is 'active' it cannot be deleted + - It must first be marked as 'inactive' + """ part = Part.objects.get(pk=int(kwargs['pk'])) # Check if inactive if not part.active: @@ -734,16 +705,14 @@ class PartDetail(generics.RetrieveUpdateDestroyAPIView): return super(PartDetail, self).destroy(request, *args, **kwargs) else: # Return 405 error - message = f'Part \'{part.name}\' (pk = {part.pk}) is active: cannot delete' + message = 'Part is active: cannot delete' return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED, data=message) def update(self, request, *args, **kwargs): - """ - Custom update functionality for Part instance. + """Custom update functionality for Part instance. - If the 'starred' field is provided, update the 'starred' status against current user """ - if 'starred' in request.data: starred = str2bool(request.data.get('starred', False)) @@ -755,8 +724,8 @@ class PartDetail(generics.RetrieveUpdateDestroyAPIView): class PartFilter(rest_filters.FilterSet): - """ - Custom filters for the PartList endpoint. + """Custom filters for the PartList endpoint. + Uses the django_filters extension framework """ @@ -764,7 +733,7 @@ class PartFilter(rest_filters.FilterSet): has_ipn = rest_filters.BooleanFilter(label='Has IPN', method='filter_has_ipn') def filter_has_ipn(self, queryset, name, value): - + """Filter by whether the Part has an IPN (internal part number) or not""" value = str2bool(value) if value: @@ -791,10 +760,7 @@ class PartFilter(rest_filters.FilterSet): low_stock = rest_filters.BooleanFilter(label='Low stock', method='filter_low_stock') def filter_low_stock(self, queryset, name, value): - """ - Filter by "low stock" status - """ - + """Filter by "low stock" status.""" value = str2bool(value) if value: @@ -812,7 +778,7 @@ class PartFilter(rest_filters.FilterSet): has_stock = rest_filters.BooleanFilter(label='Has stock', method='filter_has_stock') def filter_has_stock(self, queryset, name, value): - + """Filter by whether the Part has any stock""" value = str2bool(value) if value: @@ -826,7 +792,7 @@ class PartFilter(rest_filters.FilterSet): unallocated_stock = rest_filters.BooleanFilter(label='Unallocated stock', method='filter_unallocated_stock') def filter_unallocated_stock(self, queryset, name, value): - + """Filter by whether the Part has unallocated stock""" value = str2bool(value) if value: @@ -854,8 +820,7 @@ class PartFilter(rest_filters.FilterSet): class PartList(APIDownloadMixin, generics.ListCreateAPIView): - """ - API endpoint for accessing a list of Part objects + """API endpoint for accessing a list of Part objects. - GET: Return list of objects - POST: Create a new Part object @@ -882,7 +847,7 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView): starred_parts = None def get_serializer(self, *args, **kwargs): - + """Return a serializer instance for this endpoint""" # Ensure the request context is passed through kwargs['context'] = self.get_serializer_context() @@ -904,6 +869,7 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def download_queryset(self, queryset, export_format): + """Download the filtered queryset as a data file""" dataset = PartResource().export(queryset=queryset) filedata = dataset.export(export_format) @@ -912,14 +878,10 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView): return DownloadFile(filedata, filename) def list(self, request, *args, **kwargs): - """ - Overide the 'list' method, as the PartCategory objects are - very expensive to serialize! + """Overide the 'list' method, as the PartCategory objects are very expensive to serialize! - So we will serialize them first, and keep them in memory, - so that they do not have to be serialized multiple times... + So we will serialize them first, and keep them in memory, so that they do not have to be serialized multiple times... """ - queryset = self.filter_queryset(self.get_queryset()) page = self.paginate_queryset(queryset) @@ -980,12 +942,10 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView): @transaction.atomic def create(self, request, *args, **kwargs): - """ - We wish to save the user who created this part! + """We wish to save the user who created this part! Note: Implementation copied from DRF class CreateModelMixin """ - # TODO: Unit tests for this function! serializer = self.get_serializer(data=request.data) @@ -1128,18 +1088,14 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView): return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) def get_queryset(self, *args, **kwargs): - + """Return an annotated queryset object""" queryset = super().get_queryset(*args, **kwargs) queryset = part_serializers.PartSerializer.annotate_queryset(queryset) return queryset def filter_queryset(self, queryset): - """ - Perform custom filtering of the queryset. - We overide the DRF filter_fields here because - """ - + """Perform custom filtering of the queryset""" params = self.request.query_params queryset = super().filter_queryset(queryset) @@ -1392,15 +1348,13 @@ class PartList(APIDownloadMixin, generics.ListCreateAPIView): class PartRelatedList(generics.ListCreateAPIView): - """ - API endpoint for accessing a list of PartRelated objects - """ + """API endpoint for accessing a list of PartRelated objects.""" queryset = PartRelated.objects.all() serializer_class = part_serializers.PartRelationSerializer def filter_queryset(self, queryset): - + """Custom queryset filtering""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -1421,16 +1375,14 @@ class PartRelatedList(generics.ListCreateAPIView): class PartRelatedDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for accessing detail view of a PartRelated object - """ + """API endpoint for accessing detail view of a PartRelated object.""" queryset = PartRelated.objects.all() serializer_class = part_serializers.PartRelationSerializer class PartParameterTemplateList(generics.ListCreateAPIView): - """ API endpoint for accessing a list of PartParameterTemplate objects. + """API endpoint for accessing a list of PartParameterTemplate objects. - GET: Return list of PartParameterTemplate objects - POST: Create a new PartParameterTemplate object @@ -1454,10 +1406,7 @@ class PartParameterTemplateList(generics.ListCreateAPIView): ] def filter_queryset(self, queryset): - """ - Custom filtering for the PartParameterTemplate API - """ - + """Custom filtering for the PartParameterTemplate API.""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -1493,7 +1442,7 @@ class PartParameterTemplateList(generics.ListCreateAPIView): class PartParameterList(generics.ListCreateAPIView): - """ API endpoint for accessing a list of PartParameter objects + """API endpoint for accessing a list of PartParameter objects. - GET: Return list of PartParameter objects - POST: Create a new PartParameter object @@ -1513,18 +1462,14 @@ class PartParameterList(generics.ListCreateAPIView): class PartParameterDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for detail view of a single PartParameter object - """ + """API endpoint for detail view of a single PartParameter object.""" queryset = PartParameter.objects.all() serializer_class = part_serializers.PartParameterSerializer class BomFilter(rest_filters.FilterSet): - """ - Custom filters for the BOM list - """ + """Custom filters for the BOM list.""" # Boolean filters for BOM item optional = rest_filters.BooleanFilter(label='BOM line is optional') @@ -1542,8 +1487,7 @@ class BomFilter(rest_filters.FilterSet): validated = rest_filters.BooleanFilter(label='BOM line has been validated', method='filter_validated') def filter_validated(self, queryset, name, value): - - # Work out which lines have actually been validated + """Filter by which lines have actually been validated""" pks = [] value = str2bool(value) @@ -1565,8 +1509,7 @@ class BomFilter(rest_filters.FilterSet): class BomList(generics.ListCreateAPIView): - """ - API endpoint for accessing a list of BomItem objects. + """API endpoint for accessing a list of BomItem objects. - GET: Return list of BomItem objects - POST: Create a new BomItem object @@ -1577,6 +1520,7 @@ class BomList(generics.ListCreateAPIView): filterset_class = BomFilter def list(self, request, *args, **kwargs): + """Return serialized list response for this endpoint""" queryset = self.filter_queryset(self.get_queryset()) @@ -1602,6 +1546,13 @@ class BomList(generics.ListCreateAPIView): return Response(data) def get_serializer(self, *args, **kwargs): + """Return the serializer instance for this API endpoint + + If requested, extra detail fields are annotated to the queryset: + - part_detail + - sub_part_detail + - include_pricing + """ # Do we wish to include extra detail? try: @@ -1626,7 +1577,7 @@ class BomList(generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def get_queryset(self, *args, **kwargs): - + """Return the queryset object for this endpoint""" queryset = super().get_queryset(*args, **kwargs) queryset = self.get_serializer_class().setup_eager_loading(queryset) @@ -1635,7 +1586,7 @@ class BomList(generics.ListCreateAPIView): return queryset def filter_queryset(self, queryset): - + """Custom query filtering for the BomItem list API""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -1716,18 +1667,13 @@ class BomList(generics.ListCreateAPIView): return queryset def include_pricing(self): - """ - Determine if pricing information should be included in the response - """ + """Determine if pricing information should be included in the response.""" pricing_default = InvenTreeSetting.get_setting('PART_SHOW_PRICE_IN_BOM') return str2bool(self.request.query_params.get('include_pricing', pricing_default)) def annotate_pricing(self, queryset): - """ - Add part pricing information to the queryset - """ - + """Add part pricing information to the queryset.""" # Annotate with purchase prices queryset = queryset.annotate( purchase_price_min=Min('sub_part__stock_items__purchase_price'), @@ -1742,8 +1688,7 @@ class BomList(generics.ListCreateAPIView): ).values('pk', 'sub_part', 'purchase_price', 'purchase_price_currency') def convert_price(price, currency, decimal_places=4): - """ Convert price field, returns Money field """ - + """Convert price field, returns Money field.""" price_adjusted = None # Get default currency from settings @@ -1796,8 +1741,7 @@ class BomList(generics.ListCreateAPIView): class BomImportUpload(generics.CreateAPIView): - """ - API endpoint for uploading a complete Bill of Materials. + """API endpoint for uploading a complete Bill of Materials. It is assumed that the BOM has been extracted from a file using the BomExtract endpoint. """ @@ -1806,10 +1750,7 @@ class BomImportUpload(generics.CreateAPIView): serializer_class = part_serializers.BomImportUploadSerializer def create(self, request, *args, **kwargs): - """ - Custom create function to return the extracted data - """ - + """Custom create function to return the extracted data.""" serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) self.perform_create(serializer) @@ -1821,31 +1762,27 @@ class BomImportUpload(generics.CreateAPIView): class BomImportExtract(generics.CreateAPIView): - """ - API endpoint for extracting BOM data from a BOM file. - """ + """API endpoint for extracting BOM data from a BOM file.""" queryset = Part.objects.none() serializer_class = part_serializers.BomImportExtractSerializer class BomImportSubmit(generics.CreateAPIView): - """ - API endpoint for submitting BOM data from a BOM file - """ + """API endpoint for submitting BOM data from a BOM file.""" queryset = BomItem.objects.none() serializer_class = part_serializers.BomImportSubmitSerializer class BomDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of a single BomItem object """ + """API endpoint for detail view of a single BomItem object.""" queryset = BomItem.objects.all() serializer_class = part_serializers.BomItemSerializer def get_queryset(self, *args, **kwargs): - + """Prefetch related fields for this queryset""" queryset = super().get_queryset(*args, **kwargs) queryset = self.get_serializer_class().setup_eager_loading(queryset) @@ -1855,19 +1792,17 @@ class BomDetail(generics.RetrieveUpdateDestroyAPIView): class BomItemValidate(generics.UpdateAPIView): - """ API endpoint for validating a BomItem """ + """API endpoint for validating a BomItem.""" - # Very simple serializers class BomItemValidationSerializer(serializers.Serializer): - + """Simple serializer for passing a single boolean field""" valid = serializers.BooleanField(default=False) queryset = BomItem.objects.all() serializer_class = BomItemValidationSerializer def update(self, request, *args, **kwargs): - """ Perform update request """ - + """Perform update request.""" partial = kwargs.pop('partial', False) valid = request.data.get('valid', False) @@ -1884,9 +1819,7 @@ class BomItemValidate(generics.UpdateAPIView): class BomItemSubstituteList(generics.ListCreateAPIView): - """ - API endpoint for accessing a list of BomItemSubstitute objects - """ + """API endpoint for accessing a list of BomItemSubstitute objects.""" serializer_class = part_serializers.BomItemSubstituteSerializer queryset = BomItemSubstitute.objects.all() @@ -1904,9 +1837,7 @@ class BomItemSubstituteList(generics.ListCreateAPIView): class BomItemSubstituteDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for detail view of a single BomItemSubstitute object - """ + """API endpoint for detail view of a single BomItemSubstitute object.""" queryset = BomItemSubstitute.objects.all() serializer_class = part_serializers.BomItemSubstituteSerializer diff --git a/InvenTree/part/apps.py b/InvenTree/part/apps.py index 9f66be5009..9a8ce076b5 100644 --- a/InvenTree/part/apps.py +++ b/InvenTree/part/apps.py @@ -1,3 +1,5 @@ +"""part app specification""" + import logging from django.apps import AppConfig @@ -9,24 +11,19 @@ logger = logging.getLogger("inventree") class PartConfig(AppConfig): + """Config class for the 'part' app""" name = 'part' def ready(self): - """ - This function is called whenever the Part app is loaded. - """ - + """This function is called whenever the Part app is loaded.""" if canAppAccessDatabase(): self.update_trackable_status() def update_trackable_status(self): - """ - Check for any instances where a trackable part is used in the BOM - for a non-trackable part. + """Check for any instances where a trackable part is used in the BOM for a non-trackable part. In such a case, force the top-level part to be trackable too. """ - from .models import BomItem try: diff --git a/InvenTree/part/bom.py b/InvenTree/part/bom.py index fcb86f6204..2e1018060d 100644 --- a/InvenTree/part/bom.py +++ b/InvenTree/part/bom.py @@ -1,5 +1,5 @@ -""" -Functionality for Bill of Material (BOM) management. +"""Functionality for Bill of Material (BOM) management. + Primarily BOM upload tools. """ @@ -11,18 +11,16 @@ from company.models import ManufacturerPart, SupplierPart from InvenTree.helpers import DownloadFile, GetExportFormats, normalize from .admin import BomItemResource -from .models import BomItem +from .models import BomItem, Part def IsValidBOMFormat(fmt): - """ Test if a file format specifier is in the valid list of BOM file formats """ - + """Test if a file format specifier is in the valid list of BOM file formats.""" return fmt.strip().lower() in GetExportFormats() def MakeBomTemplate(fmt): - """ Generate a Bill of Materials upload template file (for user download) """ - + """Generate a Bill of Materials upload template file (for user download).""" fmt = fmt.strip().lower() if not IsValidBOMFormat(fmt): @@ -44,14 +42,22 @@ def MakeBomTemplate(fmt): return DownloadFile(data, filename) -def ExportBom(part, fmt='csv', cascade=False, max_levels=None, parameter_data=False, stock_data=False, supplier_data=False, manufacturer_data=False): - """ Export a BOM (Bill of Materials) for a given part. +def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = None, parameter_data=False, stock_data=False, supplier_data=False, manufacturer_data=False): + """Export a BOM (Bill of Materials) for a given part. Args: - fmt: File format (default = 'csv') - cascade: If True, multi-level BOM output is supported. Otherwise, a flat top-level-only BOM is exported. - """ + part (Part): Part for which the BOM should be exported + fmt (str, optional): file format. Defaults to 'csv'. + cascade (bool, optional): If True, multi-level BOM output is supported. Otherwise, a flat top-level-only BOM is exported.. Defaults to False. + max_levels (int, optional): Levels of items that should be included. None for np sublevels. Defaults to None. + parameter_data (bool, optional): Additonal data that should be added. Defaults to False. + stock_data (bool, optional): Additonal data that should be added. Defaults to False. + supplier_data (bool, optional): Additonal data that should be added. Defaults to False. + manufacturer_data (bool, optional): Additonal data that should be added. Defaults to False. + Returns: + StreamingHttpResponse: Response that can be passed to the endpoint + """ if not IsValidBOMFormat(fmt): fmt = 'csv' @@ -92,9 +98,7 @@ def ExportBom(part, fmt='csv', cascade=False, max_levels=None, parameter_data=Fa pass if parameter_data: - """ - If requested, add extra columns for each PartParameter associated with each line item - """ + """If requested, add extra columns for each PartParameter associated with each line item.""" parameter_cols = {} @@ -117,9 +121,7 @@ def ExportBom(part, fmt='csv', cascade=False, max_levels=None, parameter_data=Fa add_columns_to_dataset(parameter_cols_ordered, len(bom_items)) if stock_data: - """ - If requested, add extra columns for stock data associated with each line item - """ + """If requested, add extra columns for stock data associated with each line item.""" stock_headers = [ _('Default Location'), @@ -172,9 +174,7 @@ def ExportBom(part, fmt='csv', cascade=False, max_levels=None, parameter_data=Fa add_columns_to_dataset(stock_cols, len(bom_items)) if manufacturer_data or supplier_data: - """ - If requested, add extra columns for each SupplierPart and ManufacturerPart associated with each line item - """ + """If requested, add extra columns for each SupplierPart and ManufacturerPart associated with each line item.""" # Keep track of the supplier parts we have already exported supplier_parts_used = set() diff --git a/InvenTree/part/forms.py b/InvenTree/part/forms.py index e1fe938bd2..f4c4810d37 100644 --- a/InvenTree/part/forms.py +++ b/InvenTree/part/forms.py @@ -1,13 +1,10 @@ -""" -Django Forms for interacting with Part objects -""" +"""Django Forms for interacting with Part objects.""" from django import forms from django.utils.translation import gettext_lazy as _ from mptt.fields import TreeNodeChoiceField -import common.models from common.forms import MatchItemForm from InvenTree.fields import RoundingDecimalFormField from InvenTree.forms import HelperForm @@ -18,24 +15,8 @@ from .models import (Part, PartCategory, PartCategoryParameterTemplate, PartSellPriceBreak) -class PartModelChoiceField(forms.ModelChoiceField): - """ Extending string representation of Part instance with available stock """ - - def label_from_instance(self, part): - - label = str(part) - - # Optionally display available part quantity - if common.models.InvenTreeSetting.get_setting('PART_SHOW_QUANTITY_IN_FORMS'): - label += f" - {part.available_stock}" - - return label - - class PartImageDownloadForm(HelperForm): - """ - Form for downloading an image from a URL - """ + """Form for downloading an image from a URL.""" url = forms.URLField( label=_('URL'), @@ -44,6 +25,7 @@ class PartImageDownloadForm(HelperForm): ) class Meta: + """Metaclass defines fields for this form""" model = Part fields = [ 'url', @@ -51,11 +33,10 @@ class PartImageDownloadForm(HelperForm): class BomMatchItemForm(MatchItemForm): - """ Override MatchItemForm fields """ + """Override MatchItemForm fields.""" def get_special_field(self, col_guess, row, file_manager): - """ Set special fields """ - + """Set special fields.""" # set quantity field if 'quantity' in col_guess.lower(): return forms.CharField( @@ -74,15 +55,16 @@ class BomMatchItemForm(MatchItemForm): class SetPartCategoryForm(forms.Form): - """ Form for setting the category of multiple Part objects """ + """Form for setting the category of multiple Part objects.""" part_category = TreeNodeChoiceField(queryset=PartCategory.objects.all(), required=True, help_text=_('Select part category')) class EditPartParameterTemplateForm(HelperForm): - """ Form for editing a PartParameterTemplate object """ + """Form for editing a PartParameterTemplate object.""" class Meta: + """Metaclass defines fields for this form""" model = PartParameterTemplate fields = [ 'name', @@ -91,7 +73,7 @@ class EditPartParameterTemplateForm(HelperForm): class EditCategoryParameterTemplateForm(HelperForm): - """ Form for editing a PartCategoryParameterTemplate object """ + """Form for editing a PartCategoryParameterTemplate object.""" add_to_same_level_categories = forms.BooleanField(required=False, initial=False, @@ -102,6 +84,7 @@ class EditCategoryParameterTemplateForm(HelperForm): help_text=_('Add parameter template to all categories')) class Meta: + """Metaclass defines fields for this form""" model = PartCategoryParameterTemplate fields = [ 'category', @@ -113,7 +96,7 @@ class EditCategoryParameterTemplateForm(HelperForm): class PartPriceForm(forms.Form): - """ Simple form for viewing part pricing information """ + """Simple form for viewing part pricing information.""" quantity = forms.IntegerField( required=True, @@ -123,6 +106,7 @@ class PartPriceForm(forms.Form): ) class Meta: + """Metaclass defines fields for this form""" model = Part fields = [ 'quantity', @@ -130,13 +114,12 @@ class PartPriceForm(forms.Form): class EditPartSalePriceBreakForm(HelperForm): - """ - Form for creating / editing a sale price for a part - """ + """Form for creating / editing a sale price for a part.""" quantity = RoundingDecimalFormField(max_digits=10, decimal_places=5, label=_('Quantity')) class Meta: + """Metaclass defines fields for this form""" model = PartSellPriceBreak fields = [ 'part', @@ -146,13 +129,12 @@ class EditPartSalePriceBreakForm(HelperForm): class EditPartInternalPriceBreakForm(HelperForm): - """ - Form for creating / editing a internal price for a part - """ + """Form for creating / editing a internal price for a part.""" quantity = RoundingDecimalFormField(max_digits=10, decimal_places=5, label=_('Quantity')) class Meta: + """Metaclass defines fields for this form""" model = PartInternalPriceBreak fields = [ 'part', diff --git a/InvenTree/part/models.py b/InvenTree/part/models.py index 3437437cef..d5a6900c8c 100644 --- a/InvenTree/part/models.py +++ b/InvenTree/part/models.py @@ -1,6 +1,6 @@ -""" -Part database model definitions -""" +"""Part database model definitions.""" + +from __future__ import annotations import decimal import hashlib @@ -54,7 +54,7 @@ logger = logging.getLogger("inventree") class PartCategory(MetadataMixin, InvenTreeTree): - """ PartCategory provides hierarchical organization of Part objects. + """PartCategory provides hierarchical organization of Part objects. Attributes: name: Name of this category @@ -64,11 +64,10 @@ class PartCategory(MetadataMixin, InvenTreeTree): """ def delete(self, *args, **kwargs): - """ - Custom model deletion routine, which updates any child categories or parts. + """Custom model deletion routine, which updates any child categories or parts. + This must be handled within a transaction.atomic(), otherwise the tree structure is damaged """ - with transaction.atomic(): parent = self.parent @@ -104,24 +103,29 @@ class PartCategory(MetadataMixin, InvenTreeTree): @staticmethod def get_api_url(): + """Return the API url associated with the PartCategory model""" return reverse('api-part-category-list') def get_absolute_url(self): + """Return the web URL associated with the detail view for this PartCategory instance""" return reverse('category-detail', kwargs={'pk': self.id}) class Meta: + """Metaclass defines extra model properties""" verbose_name = _("Part Category") verbose_name_plural = _("Part Categories") - def get_parts(self, cascade=True): - """ Return a queryset for all parts under this category. + def get_parts(self, cascade=True) -> set[Part]: + """Return a queryset for all parts under this category. - args: - cascade - If True, also look under subcategories (default = True) + Args: + cascade (bool, optional): If True, also look under subcategories. Defaults to True. + + Returns: + set[Part]: All matching parts """ - if cascade: - """ Select any parts which exist in this category or any child categories """ + """Select any parts which exist in this category or any child categories.""" queryset = Part.objects.filter(category__in=self.getUniqueChildren(include_self=True)) else: queryset = Part.objects.filter(category=self.pk) @@ -130,13 +134,11 @@ class PartCategory(MetadataMixin, InvenTreeTree): @property def item_count(self): + """Return the number of parts contained in this PartCategory""" return self.partcount() def partcount(self, cascade=True, active=False): - """ Return the total part count under this category - (including children of child categories) - """ - + """Return the total part count under this category (including children of child categories).""" query = self.get_parts(cascade=cascade) if active: @@ -144,19 +146,12 @@ class PartCategory(MetadataMixin, InvenTreeTree): return query.count() - @property - def has_parts(self): - """ True if there are any parts in this category """ - return self.partcount() > 0 - def prefetch_parts_parameters(self, cascade=True): - """ Prefectch parts parameters """ - + """Prefectch parts parameters.""" return self.get_parts(cascade=cascade).prefetch_related('parameters', 'parameters__template').all() def get_unique_parameters(self, cascade=True, prefetch=None): - """ Get all unique parameter names for all parts from this category """ - + """Get all unique parameter names for all parts from this category.""" unique_parameters_names = [] if prefetch: @@ -173,8 +168,7 @@ class PartCategory(MetadataMixin, InvenTreeTree): return sorted(unique_parameters_names) def get_parts_parameters(self, cascade=True, prefetch=None): - """ Get all parameter names and values for all parts from this category """ - + """Get all parameter names and values for all parts from this category.""" category_parameters = [] if prefetch: @@ -203,8 +197,7 @@ class PartCategory(MetadataMixin, InvenTreeTree): @classmethod def get_parent_categories(cls): - """ Return tuple list of parent (root) categories """ - + """Return tuple list of parent (root) categories.""" # Get root nodes root_categories = cls.objects.filter(level=0) @@ -215,17 +208,13 @@ class PartCategory(MetadataMixin, InvenTreeTree): return parent_categories def get_parameter_templates(self): - """ Return parameter templates associated to category """ - + """Return parameter templates associated to category.""" prefetch = PartCategoryParameterTemplate.objects.prefetch_related('category', 'parameter_template') return prefetch.filter(category=self.id) def get_subscribers(self, include_parents=True): - """ - Return a list of users who subscribe to this PartCategory - """ - + """Return a list of users who subscribe to this PartCategory.""" cats = self.get_ancestors(include_self=True) subscribers = set() @@ -245,17 +234,11 @@ class PartCategory(MetadataMixin, InvenTreeTree): return [s for s in subscribers] def is_starred_by(self, user, **kwargs): - """ - Returns True if the specified user subscribes to this category - """ - + """Returns True if the specified user subscribes to this category.""" return user in self.get_subscribers(**kwargs) def set_starred(self, user, status): - """ - Set the "subscription" status of this PartCategory against the specified user - """ - + """Set the "subscription" status of this PartCategory against the specified user.""" if not user: return @@ -277,7 +260,7 @@ class PartCategory(MetadataMixin, InvenTreeTree): def rename_part_image(instance, filename): - """ Function for renaming a part image file + """Function for renaming a part image file. Args: instance: Instance of a Part object @@ -286,7 +269,6 @@ def rename_part_image(instance, filename): Returns: Cleaned filename in format part__img """ - base = 'part_images' fname = os.path.basename(filename) @@ -294,15 +276,14 @@ def rename_part_image(instance, filename): class PartManager(TreeManager): - """ - Defines a custom object manager for the Part model. + """Defines a custom object manager for the Part model. The main purpose of this manager is to reduce the number of database hits, as the Part model has a large number of ForeignKey fields! """ def get_queryset(self): - + """Perform default prefetch operations when accessing Part model from the database""" return super().get_queryset().prefetch_related( 'category', 'category__parent', @@ -313,7 +294,7 @@ class PartManager(TreeManager): @cleanup.ignore class Part(MetadataMixin, MPTTModel): - """ The Part object represents an abstract part, the 'concept' of an actual entity. + """The Part object represents an abstract part, the 'concept' of an actual entity. An actual physical instance of a Part is a StockItem which is treated separately. @@ -351,6 +332,7 @@ class Part(MetadataMixin, MPTTModel): objects = PartManager() class Meta: + """Metaclass defines extra model properties""" verbose_name = _("Part") verbose_name_plural = _("Parts") ordering = ['name', ] @@ -359,19 +341,17 @@ class Part(MetadataMixin, MPTTModel): ] class MPTTMeta: + """MPTT metaclass definitions""" # For legacy reasons the 'variant_of' field is used to indicate the MPTT parent parent_attr = 'variant_of' @staticmethod def get_api_url(): - + """Return the list API endpoint URL associated with the Part model""" return reverse('api-part-list') def api_instance_filters(self): - """ - Return API query filters for limiting field results against this instance - """ - + """Return API query filters for limiting field results against this instance.""" return { 'variant_of': { 'exclude_tree': self.pk, @@ -379,10 +359,7 @@ class Part(MetadataMixin, MPTTModel): } def get_context_data(self, request, **kwargs): - """ - Return some useful context data about this part for template rendering - """ - + """Return some useful context data about this part for template rendering.""" context = {} context['disabled'] = not self.active @@ -415,13 +392,11 @@ class Part(MetadataMixin, MPTTModel): return context def save(self, *args, **kwargs): - """ - Overrides the save() function for the Part model. - If the part image has been updated, - then check if the "old" (previous) image is still used by another part. + """Overrides the save function for the Part model. + + If the part image has been updated, then check if the "old" (previous) image is still used by another part. If not, it is considered "orphaned" and will be deleted. """ - # Get category templates settings add_category_templates = kwargs.pop('add_category_templates', False) @@ -476,14 +451,14 @@ class Part(MetadataMixin, MPTTModel): pass def __str__(self): + """Return a string representation of the Part (for use in the admin interface)""" return f"{self.full_name} - {self.description}" def get_parts_in_bom(self, **kwargs): - """ - Return a list of all parts in the BOM for this part. + """Return a list of all parts in the BOM for this part. + Takes into account substitutes, variant parts, and inherited BOM items """ - parts = set() for bom_item in self.get_bom_items(**kwargs): @@ -493,28 +468,23 @@ class Part(MetadataMixin, MPTTModel): return parts def check_if_part_in_bom(self, other_part, **kwargs): - """ - Check if the other_part is in the BOM for *this* part. + """Check if the other_part is in the BOM for *this* part. Note: - Accounts for substitute parts - Accounts for variant BOMs """ - return other_part in self.get_parts_in_bom(**kwargs) def check_add_to_bom(self, parent, raise_error=False, recursive=True): - """ - Check if this Part can be added to the BOM of another part. + """Check if this Part can be added to the BOM of another part. This will fail if: a) The parent part is the same as this one b) The parent part is used in the BOM for *this* part c) The parent part is used in the BOM for any child parts under this one - """ - result = True try: @@ -553,13 +523,10 @@ class Part(MetadataMixin, MPTTModel): return result def checkIfSerialNumberExists(self, sn, exclude_self=False): - """ - Check if a serial number exists for this Part. + """Check if a serial number exists for this Part. - Note: Serial numbers must be unique across an entire Part "tree", - so here we filter by the entire tree. + Note: Serial numbers must be unique across an entire Part "tree", so here we filter by the entire tree. """ - parts = Part.objects.filter(tree_id=self.tree_id) stock = StockModels.StockItem.objects.filter(part__in=parts, serial=sn) @@ -570,10 +537,7 @@ class Part(MetadataMixin, MPTTModel): return stock.exists() def find_conflicting_serial_numbers(self, serials): - """ - For a provided list of serials, return a list of those which are conflicting. - """ - + """For a provided list of serials, return a list of those which are conflicting.""" conflicts = [] for serial in serials: @@ -583,8 +547,7 @@ class Part(MetadataMixin, MPTTModel): return conflicts def getLatestSerialNumber(self): - """ - Return the "latest" serial number for this Part. + """Return the "latest" serial number for this Part. If *all* the serial numbers are integers, then this will return the highest one. Otherwise, it will simply return the serial number most recently added. @@ -592,7 +555,6 @@ class Part(MetadataMixin, MPTTModel): Note: Serial numbers must be unique across an entire Part "tree", so we filter by the entire tree. """ - parts = Part.objects.filter(tree_id=self.tree_id) stock = StockModels.StockItem.objects.filter(part__in=parts).exclude(serial=None) @@ -617,11 +579,10 @@ class Part(MetadataMixin, MPTTModel): return None def getLatestSerialNumberInt(self): - """ - Return the "latest" serial number for this Part as a integer. + """Return the "latest" serial number for this Part as a integer. + If it is not an integer the result is 0 """ - latest = self.getLatestSerialNumber() # No serial number = > 0 @@ -637,11 +598,7 @@ class Part(MetadataMixin, MPTTModel): return 0 def getSerialNumberString(self, quantity=1): - """ - Return a formatted string representing the next available serial numbers, - given a certain quantity of items. - """ - + """Return a formatted string representing the next available serial numbers, given a certain quantity of items.""" latest = self.getLatestSerialNumber() quantity = int(quantity) @@ -674,9 +631,9 @@ class Part(MetadataMixin, MPTTModel): @property def full_name(self): - """ Format a 'full name' for this Part based on the format PART_NAME_FORMAT defined in Inventree settings + """Format a 'full name' for this Part based on the format PART_NAME_FORMAT defined in Inventree settings. - As a failsafe option, the following is done + As a failsafe option, the following is done: - IPN (if not null) - Part name @@ -684,7 +641,6 @@ class Part(MetadataMixin, MPTTModel): Elements are joined by the | character """ - full_name_pattern = InvenTreeSetting.get_setting('PART_NAME_FORMAT') try: @@ -711,48 +667,34 @@ class Part(MetadataMixin, MPTTModel): return ' | '.join(elements) - def set_category(self, category): - - # Ignore if the category is already the same - if self.category == category: - return - - self.category = category - self.save() - def get_absolute_url(self): - """ Return the web URL for viewing this part """ + """Return the web URL for viewing this part.""" return reverse('part-detail', kwargs={'pk': self.id}) def get_image_url(self): - """ Return the URL of the image for this part """ - + """Return the URL of the image for this part.""" if self.image: return helpers.getMediaUrl(self.image.url) else: return helpers.getBlankImage() def get_thumbnail_url(self): - """ - Return the URL of the image thumbnail for this part - """ - + """Return the URL of the image thumbnail for this part.""" if self.image: return helpers.getMediaUrl(self.image.thumbnail.url) else: return helpers.getBlankThumbnail() def validate_unique(self, exclude=None): - """ Validate that a part is 'unique'. - Uniqueness is checked across the following (case insensitive) fields: + """Validate that a part is 'unique'. - * Name - * IPN - * Revision + Uniqueness is checked across the following (case insensitive) fields: + - Name + - IPN + - Revision e.g. there can exist multiple parts with the same name, but only if they have a different revision or internal part number. - """ super().validate_unique(exclude) @@ -770,15 +712,13 @@ class Part(MetadataMixin, MPTTModel): }) def clean(self): - """ - Perform cleaning operations for the Part model + """Perform cleaning operations for the Part model. Update trackable status: If this part is trackable, and it is used in the BOM for a parent part which is *not* trackable, then we will force the parent part to be trackable. """ - super().clean() # Strip IPN field @@ -875,13 +815,12 @@ class Part(MetadataMixin, MPTTModel): ) def get_default_location(self): - """ Get the default location for a Part (may be None). + """Get the default location for a Part (may be None). If the Part does not specify a default location, look at the Category this part is in. The PartCategory object may also specify a default stock location """ - if self.default_location: return self.default_location elif self.category: @@ -896,13 +835,12 @@ class Part(MetadataMixin, MPTTModel): return None def get_default_supplier(self): - """ Get the default supplier part for this part (may be None). + """Get the default supplier part for this part (may be None). - If the part specifies a default_supplier, return that - If there is only one supplier part available, return that - Else, return None """ - if self.default_supplier: return self.default_supplier @@ -998,8 +936,7 @@ class Part(MetadataMixin, MPTTModel): responsible = models.ForeignKey(User, on_delete=models.SET_NULL, blank=True, null=True, verbose_name=_('Responsible'), related_name='parts_responible') def format_barcode(self, **kwargs): - """ Return a JSON string for formatting a barcode for this Part object """ - + """Return a JSON string for formatting a barcode for this Part object.""" return helpers.MakeBarcode( "part", self.id, @@ -1012,28 +949,24 @@ class Part(MetadataMixin, MPTTModel): @property def category_path(self): + """Return the category path of this Part instance""" if self.category: return self.category.pathstring return '' @property def available_stock(self): - """ - Return the total available stock. + """Return the total available stock. - This subtracts stock which is already allocated to builds """ - total = self.total_stock total -= self.allocation_count() return max(total, 0) def requiring_build_orders(self): - """ - Return list of outstanding build orders which require this part - """ - + """Return list of outstanding build orders which require this part.""" # List parts that this part is required for parts = self.get_used_in().all() @@ -1048,10 +981,7 @@ class Part(MetadataMixin, MPTTModel): return builds def required_build_order_quantity(self): - """ - Return the quantity of this part required for active build orders - """ - + """Return the quantity of this part required for active build orders.""" # List active build orders which reference this part builds = self.requiring_build_orders() @@ -1074,10 +1004,7 @@ class Part(MetadataMixin, MPTTModel): return quantity def requiring_sales_orders(self): - """ - Return a list of sales orders which require this part - """ - + """Return a list of sales orders which require this part.""" orders = set() # Get a list of line items for open orders which match this part @@ -1092,10 +1019,7 @@ class Part(MetadataMixin, MPTTModel): return orders def required_sales_order_quantity(self): - """ - Return the quantity of this part required for active sales orders - """ - + """Return the quantity of this part required for active sales orders.""" # Get a list of line items for open orders which match this part open_lines = OrderModels.SalesOrderLineItem.objects.filter( order__status__in=SalesOrderStatus.OPEN, @@ -1112,16 +1036,12 @@ class Part(MetadataMixin, MPTTModel): return quantity def required_order_quantity(self): - """ - Return total required to fulfil orders - """ - + """Return total required to fulfil orders.""" return self.required_build_order_quantity() + self.required_sales_order_quantity() @property def quantity_to_order(self): - """ - Return the quantity needing to be ordered for this part. + """Return the quantity needing to be ordered for this part. Here, an "order" could be one of: - Build Order @@ -1133,9 +1053,7 @@ class Part(MetadataMixin, MPTTModel): Required for orders = self.required_order_quantity() Currently on order = self.on_order Currently building = self.quantity_being_built - """ - # Total requirement required = self.required_order_quantity() @@ -1152,20 +1070,19 @@ class Part(MetadataMixin, MPTTModel): @property def net_stock(self): - """ Return the 'net' stock. It takes into account: + """Return the 'net' stock. + It takes into account: - Stock on hand (total_stock) - Stock on order (on_order) - Stock allocated (allocation_count) This number (unlike 'available_stock') can be negative. """ - return self.total_stock - self.allocation_count() + self.on_order def get_subscribers(self, include_variants=True, include_categories=True): - """ - Return a list of users who are 'subscribed' to this part. + """Return a list of users who are 'subscribed' to this part. A user may 'subscribe' to this part in the following ways: @@ -1173,9 +1090,7 @@ class Part(MetadataMixin, MPTTModel): b) Subscribing to a template part "above" this part (if it is a variant) c) Subscribing to the part category that this part belongs to d) Subscribing to a parent category of the category in c) - """ - subscribers = set() # Start by looking at direct subscriptions to a Part model @@ -1199,17 +1114,11 @@ class Part(MetadataMixin, MPTTModel): return [s for s in subscribers] def is_starred_by(self, user, **kwargs): - """ - Return True if the specified user subscribes to this part - """ - + """Return True if the specified user subscribes to this part.""" return user in self.get_subscribers(**kwargs) def set_starred(self, user, status): - """ - Set the "subscription" status of this Part against the specified user - """ - + """Set the "subscription" status of this Part against the specified user.""" if not user: return @@ -1224,21 +1133,9 @@ class Part(MetadataMixin, MPTTModel): # if the user is subscribed to a parent part or category PartStar.objects.filter(part=self, user=user).delete() - def need_to_restock(self): - """ Return True if this part needs to be restocked - (either by purchasing or building). - - If the allocated_stock exceeds the total_stock, - then we need to restock. - """ - - return (self.total_stock + self.on_order - self.allocation_count) < self.minimum_stock - @property def can_build(self): - """ Return the number of units that can be build with available stock - """ - + """Return the number of units that can be build with available stock.""" # If this part does NOT have a BOM, result is simply the currently available stock if not self.has_bom: return 0 @@ -1268,28 +1165,19 @@ class Part(MetadataMixin, MPTTModel): @property def active_builds(self): - """ Return a list of outstanding builds. + """Return a list of outstanding builds. + Builds marked as 'complete' or 'cancelled' are ignored """ - return self.builds.filter(status__in=BuildStatus.ACTIVE_CODES) - @property - def inactive_builds(self): - """ Return a list of inactive builds - """ - - return self.builds.exclude(status__in=BuildStatus.ACTIVE_CODES) - @property def quantity_being_built(self): - """ - Return the current number of parts currently being built. + """Return the current number of parts currently being built. Note: This is the total quantity of Build orders, *not* the number of build outputs. In this fashion, it is the "projected" quantity of builds """ - builds = self.active_builds quantity = 0 @@ -1301,10 +1189,7 @@ class Part(MetadataMixin, MPTTModel): return quantity def build_order_allocations(self, **kwargs): - """ - Return all 'BuildItem' objects which allocate this part to Build objects - """ - + """Return all 'BuildItem' objects which allocate this part to Build objects.""" include_variants = kwargs.get('include_variants', True) queryset = BuildModels.BuildItem.objects.all() @@ -1320,10 +1205,7 @@ class Part(MetadataMixin, MPTTModel): return queryset def build_order_allocation_count(self, **kwargs): - """ - Return the total amount of this part allocated to build orders - """ - + """Return the total amount of this part allocated to build orders.""" query = self.build_order_allocations(**kwargs).aggregate( total=Coalesce( Sum( @@ -1338,10 +1220,7 @@ class Part(MetadataMixin, MPTTModel): return query['total'] def sales_order_allocations(self, **kwargs): - """ - Return all sales-order-allocation objects which allocate this part to a SalesOrder - """ - + """Return all sales-order-allocation objects which allocate this part to a SalesOrder.""" include_variants = kwargs.get('include_variants', True) queryset = OrderModels.SalesOrderAllocation.objects.all() @@ -1375,10 +1254,7 @@ class Part(MetadataMixin, MPTTModel): return queryset def sales_order_allocation_count(self, **kwargs): - """ - Return the total quantity of this part allocated to sales orders - """ - + """Return the total quantity of this part allocated to sales orders.""" query = self.sales_order_allocations(**kwargs).aggregate( total=Coalesce( Sum( @@ -1393,11 +1269,7 @@ class Part(MetadataMixin, MPTTModel): return query['total'] def allocation_count(self, **kwargs): - """ - Return the total quantity of stock allocated for this part, - against both build orders and sales orders. - """ - + """Return the total quantity of stock allocated for this part, against both build orders and sales orders.""" return sum( [ self.build_order_allocation_count(**kwargs), @@ -1406,14 +1278,13 @@ class Part(MetadataMixin, MPTTModel): ) def stock_entries(self, include_variants=True, in_stock=None): - """ Return all stock entries for this Part. + """Return all stock entries for this Part. - If this is a template part, include variants underneath this. Note: To return all stock-entries for all part variants under this one, we need to be creative with the filtering. """ - if include_variants: query = StockModels.StockItem.objects.filter(part__in=self.get_descendants(include_self=True)) else: @@ -1427,10 +1298,7 @@ class Part(MetadataMixin, MPTTModel): return query def get_stock_count(self, include_variants=True): - """ - Return the total "in stock" count for this part - """ - + """Return the total "in stock" count for this part.""" entries = self.stock_entries(in_stock=True, include_variants=include_variants) query = entries.aggregate(t=Coalesce(Sum('quantity'), Decimal(0))) @@ -1439,17 +1307,15 @@ class Part(MetadataMixin, MPTTModel): @property def total_stock(self): - """ Return the total stock quantity for this part. + """Return the total stock quantity for this part. - Part may be stored in multiple locations - If this part is a "template" (variants exist) then these are counted too """ - return self.get_stock_count(include_variants=True) def get_bom_item_filter(self, include_inherited=True): - """ - Returns a query filter for all BOM items associated with this Part. + """Returns a query filter for all BOM items associated with this Part. There are some considerations: @@ -1461,9 +1327,7 @@ class Part(MetadataMixin, MPTTModel): Note: This does *not* return a queryset, it returns a Q object, which can be used by some other query operation! Because we want to keep our code DRY! - """ - bom_filter = Q(part=self) if include_inherited: @@ -1486,25 +1350,21 @@ class Part(MetadataMixin, MPTTModel): return bom_filter def get_bom_items(self, include_inherited=True): - """ - Return a queryset containing all BOM items for this part + """Return a queryset containing all BOM items for this part. By default, will include inherited BOM items """ - queryset = BomItem.objects.filter(self.get_bom_item_filter(include_inherited=include_inherited)) return queryset.prefetch_related('sub_part') - def get_installed_part_options(self, include_inherited=True, include_variants=True): - """ - Return a set of all Parts which can be "installed" into this part, based on the BOM. + def get_installed_part_options(self, include_inherited: bool = True, include_variants: bool = True): + """Return a set of all Parts which can be "installed" into this part, based on the BOM. - arguments: - include_inherited - If set, include BomItem entries defined for parent parts - include_variants - If set, include variant parts for BomItems which allow variants + Arguments: + include_inherited (bool): If set, include BomItem entries defined for parent parts + include_variants (bool): If set, include variant parts for BomItems which allow variants """ - parts = set() for bom_item in self.get_bom_items(include_inherited=include_inherited): @@ -1518,8 +1378,7 @@ class Part(MetadataMixin, MPTTModel): return parts def get_used_in_filter(self, include_inherited=True): - """ - Return a query filter for all parts that this part is used in. + """Return a query filter for all parts that this part is used in. There are some considerations: @@ -1529,7 +1388,6 @@ class Part(MetadataMixin, MPTTModel): Note: This function returns a Q object, not an actual queryset. The Q object is used to filter against a list of Part objects """ - # This is pretty expensive - we need to traverse multiple variant lists! # TODO - In the future, could this be improved somehow? @@ -1558,8 +1416,7 @@ class Part(MetadataMixin, MPTTModel): return Q(id__in=part_ids) def get_used_in(self, include_inherited=True): - """ - Return a queryset containing all parts this part is used in. + """Return a queryset containing all parts this part is used in. Includes consideration of inherited BOMs """ @@ -1567,13 +1424,11 @@ class Part(MetadataMixin, MPTTModel): @property def has_bom(self): + """Return True if this Part instance has any BOM items""" return self.get_bom_items().count() > 0 def get_trackable_parts(self): - """ - Return a queryset of all trackable parts in the BOM for this part - """ - + """Return a queryset of all trackable parts in the BOM for this part.""" queryset = self.get_bom_items() queryset = queryset.filter(sub_part__trackable=True) @@ -1581,32 +1436,28 @@ class Part(MetadataMixin, MPTTModel): @property def has_trackable_parts(self): - """ - Return True if any parts linked in the Bill of Materials are trackable. + """Return True if any parts linked in the Bill of Materials are trackable. + This is important when building the part. """ - return self.get_trackable_parts().count() > 0 @property def bom_count(self): - """ Return the number of items contained in the BOM for this part """ + """Return the number of items contained in the BOM for this part.""" return self.get_bom_items().count() @property def used_in_count(self): - """ Return the number of part BOMs that this part appears in """ + """Return the number of part BOMs that this part appears in.""" return self.get_used_in().count() def get_bom_hash(self): - """ Return a checksum hash for the BOM for this part. + """Return a checksum hash for the BOM for this part. + Used to determine if the BOM has changed (and needs to be signed off!) - - The hash is calculated by hashing each line item in the BOM. - - returns a string representation of a hash object which can be compared with a stored value + The hash is calculated by hashing each line item in the BOM. Returns a string representation of a hash object which can be compared with a stored value """ - result_hash = hashlib.md5(str(self.id).encode()) # List *all* BOM items (including inherited ones!) @@ -1618,19 +1469,16 @@ class Part(MetadataMixin, MPTTModel): return str(result_hash.digest()) def is_bom_valid(self): - """ Check if the BOM is 'valid' - if the calculated checksum matches the stored value - """ - + """Check if the BOM is 'valid' - if the calculated checksum matches the stored value.""" return self.get_bom_hash() == self.bom_checksum or not self.has_bom @transaction.atomic def validate_bom(self, user): - """ Validate the BOM (mark the BOM as validated by the given User. + """Validate the BOM (mark the BOM as validated by the given User. - Calculates and stores the hash for the BOM - Saves the current date and the checking user """ - # Validate each line item, ignoring inherited ones bom_items = self.get_bom_items(include_inherited=False) @@ -1645,23 +1493,19 @@ class Part(MetadataMixin, MPTTModel): @transaction.atomic def clear_bom(self): - """ - Clear the BOM items for the part (delete all BOM lines). + """Clear the BOM items for the part (delete all BOM lines). Note: Does *NOT* delete inherited BOM items! """ - self.bom_items.all().delete() def getRequiredParts(self, recursive=False, parts=None): - """ - Return a list of parts required to make this part (i.e. BOM items). + """Return a list of parts required to make this part (i.e. BOM items). Args: recursive: If True iterate down through sub-assemblies parts: Set of parts already found (to prevent recursion issues) """ - if parts is None: parts = set() @@ -1680,40 +1524,19 @@ class Part(MetadataMixin, MPTTModel): return parts - def get_allowed_bom_items(self): - """ - Return a list of parts which can be added to a BOM for this part. - - - Exclude parts which are not 'component' parts - - Exclude parts which this part is in the BOM for - """ - - # Start with a list of all parts designated as 'sub components' - parts = Part.objects.filter(component=True) - - # Exclude this part - parts = parts.exclude(id=self.id) - - # Exclude any parts that this part is used *in* (to prevent recursive BOMs) - used_in = self.get_used_in().all() - - parts = parts.exclude(id__in=[part.id for part in used_in]) - - return parts - @property def supplier_count(self): - """ Return the number of supplier parts available for this part """ + """Return the number of supplier parts available for this part.""" return self.supplier_parts.count() @property def has_pricing_info(self, internal=False): - """ Return true if there is pricing information for this part """ + """Return true if there is pricing information for this part.""" return self.get_price_range(internal=internal) is not None @property def has_complete_bom_pricing(self): - """ Return true if there is pricing information for each item in the BOM. """ + """Return true if there is pricing information for each item in the BOM.""" use_internal = common.models.get_setting('PART_BOM_USE_INTERNAL_PRICE', False) for item in self.get_bom_items().all().select_related('sub_part'): if not item.sub_part.has_pricing_info(use_internal): @@ -1722,7 +1545,7 @@ class Part(MetadataMixin, MPTTModel): return True def get_price_info(self, quantity=1, buy=True, bom=True, internal=False): - """ Return a simplified pricing string for this part + """Return a simplified pricing string for this part. Args: quantity: Number of units to calculate price for @@ -1730,7 +1553,6 @@ class Part(MetadataMixin, MPTTModel): bom: Include BOM pricing (default = True) internal: Include internal pricing (default = False) """ - price_range = self.get_price_range(quantity, buy, bom, internal) if price_range is None: @@ -1747,7 +1569,17 @@ class Part(MetadataMixin, MPTTModel): return "{a} - {b}".format(a=min_price, b=max_price) def get_supplier_price_range(self, quantity=1): + """Return the supplier price range of this part: + - Checks if there is any supplier pricing information associated with this Part + - Iterate through available supplier pricing and select (min, max) + - Returns tuple of (min, max) + + Arguments: + quantity: Quantity at which to calculate price (default=1) + + Returns: (min, max) tuple or (None, None) if no supplier pricing available + """ min_price = None max_price = None @@ -1773,13 +1605,12 @@ class Part(MetadataMixin, MPTTModel): return (min_price, max_price) def get_bom_price_range(self, quantity=1, internal=False, purchase=False): - """ Return the price range of the BOM for this part. - Adds the minimum price for all components in the BOM. + """Return the price range of the BOM for this part. + Adds the minimum price for all components in the BOM. Note: If the BOM contains items without pricing information, these items cannot be included in the BOM! """ - min_price = None max_price = None @@ -1817,9 +1648,9 @@ class Part(MetadataMixin, MPTTModel): return (min_price, max_price) def get_price_range(self, quantity=1, buy=True, bom=True, internal=False, purchase=False): + """Return the price range for this part. - """ Return the price range for this part. This price can be either: - + This price can be either: - Supplier price (if purchased from suppliers) - BOM price (if built from other parts) - Internal price (if set for the part) @@ -1828,7 +1659,6 @@ class Part(MetadataMixin, MPTTModel): Returns: Minimum of the supplier, BOM, internal or purchase price. If no pricing available, returns None """ - # only get internal price if set and should be used if internal and self.has_internal_price_breaks: internal_price = self.get_internal_price(quantity) @@ -1863,26 +1693,26 @@ class Part(MetadataMixin, MPTTModel): @property def has_price_breaks(self): + """Return True if this part has sale price breaks""" return self.price_breaks.count() > 0 @property def price_breaks(self): - """ Return the associated price breaks in the correct order """ + """Return the associated price breaks in the correct order.""" return self.salepricebreaks.order_by('quantity').all() @property def unit_pricing(self): + """Returns the price of this Part at quantity=1""" return self.get_price(1) def add_price_break(self, quantity, price): - """ - Create a new price break for this part + """Create a new price break for this part. - args: - quantity - Numerical quantity - price - Must be a Money object + Args: + quantity: Numerical quantity + price: Must be a Money object """ - # Check if a price break at that quantity already exists... if self.price_breaks.filter(quantity=quantity, part=self.pk).exists(): return @@ -1894,22 +1724,25 @@ class Part(MetadataMixin, MPTTModel): ) def get_internal_price(self, quantity, moq=True, multiples=True, currency=None): + """Return the internal price of this Part at the specified quantity""" return common.models.get_price(self, quantity, moq, multiples, currency, break_name='internal_price_breaks') @property def has_internal_price_breaks(self): + """Return True if this Part has internal pricing information""" return self.internal_price_breaks.count() > 0 @property def internal_price_breaks(self): - """ Return the associated price breaks in the correct order """ + """Return the associated price breaks in the correct order.""" return self.internalpricebreaks.order_by('quantity').all() - @property - def internal_unit_pricing(self): - return self.get_internal_price(1) - def get_purchase_price(self, quantity): + """Calculate the purchase price for this part at the specified quantity + + - Looks at available supplier pricing data + - Calculates the price base on the closest price point + """ currency = currency_code_default() try: prices = [convert_money(item.purchase_price, currency).amount for item in self.stock_items.all() if item.purchase_price] @@ -1923,14 +1756,12 @@ class Part(MetadataMixin, MPTTModel): @transaction.atomic def copy_bom_from(self, other, clear=True, **kwargs): - """ - Copy the BOM from another part. + """Copy the BOM from another part. - args: - other - The part to copy the BOM from - clear - Remove existing BOM items first (default=True) + Args: + other: The part to copy the BOM from + clear (bool, optional): Remove existing BOM items first. Defaults to True. """ - # Ignore if the other part is actually this part? if other == self: return @@ -1991,7 +1822,7 @@ class Part(MetadataMixin, MPTTModel): @transaction.atomic def copy_parameters_from(self, other, **kwargs): - + """Copy all parameter values from another Part instance""" clear = kwargs.get('clear', True) if clear: @@ -2015,16 +1846,15 @@ class Part(MetadataMixin, MPTTModel): @transaction.atomic def deep_copy(self, other, **kwargs): - """ Duplicates non-field data from another part. - Does not alter the normal fields of this part, - but can be used to copy other data linked by ForeignKey refernce. + """Duplicates non-field data from another part. + + Does not alter the normal fields of this part, but can be used to copy other data linked by ForeignKey refernce. Keyword Args: image: If True, copies Part image (default = True) bom: If True, copies BOM data (default = False) parameters: If True, copies Parameters data (default = True) """ - # Copy the part image if kwargs.get('image', True): if other.image: @@ -2050,15 +1880,14 @@ class Part(MetadataMixin, MPTTModel): self.save() def getTestTemplates(self, required=None, include_parent=True): - """ - Return a list of all test templates associated with this Part. + """Return a list of all test templates associated with this Part. + These are used for validation of a StockItem. - args: + Args: required: Set to True or False to filter by "required" status include_parent: Set to True to traverse upwards """ - if include_parent: tests = PartTestTemplate.objects.filter(part__in=self.get_ancestors(include_self=True)) else: @@ -2070,30 +1899,21 @@ class Part(MetadataMixin, MPTTModel): return tests def getRequiredTests(self): - # Return the tests which are required by this part + """Return the tests which are required by this part""" return self.getTestTemplates(required=True) - def requiredTestCount(self): - return self.getRequiredTests().count() - @property def attachment_count(self): - """ - Count the number of attachments for this part. + """Count the number of attachments for this part. + If the part is a variant of a template part, include the number of attachments for the template part. """ - return self.part_attachments.count() @property def part_attachments(self): - """ - Return *all* attachments for this part, - potentially including attachments for template parts - above this one. - """ - + """Return *all* attachments for this part, potentially including attachments for template parts above this one.""" ancestors = self.get_ancestors(include_self=True) attachments = PartAttachment.objects.filter(part__in=ancestors) @@ -2101,8 +1921,7 @@ class Part(MetadataMixin, MPTTModel): return attachments def sales_orders(self): - """ Return a list of sales orders which reference this part """ - + """Return a list of sales orders which reference this part.""" orders = [] for line in self.sales_order_line_items.all().prefetch_related('order'): @@ -2112,8 +1931,7 @@ class Part(MetadataMixin, MPTTModel): return orders def purchase_orders(self): - """ Return a list of purchase orders which reference this part """ - + """Return a list of purchase orders which reference this part.""" orders = [] for part in self.supplier_parts.all().prefetch_related('purchase_order_line_items'): @@ -2123,20 +1941,9 @@ class Part(MetadataMixin, MPTTModel): return orders - def open_purchase_orders(self): - """ Return a list of open purchase orders against this part """ - - return [order for order in self.purchase_orders() if order.status in PurchaseOrderStatus.OPEN] - - def closed_purchase_orders(self): - """ Return a list of closed purchase orders against this part """ - - return [order for order in self.purchase_orders() if order.status not in PurchaseOrderStatus.OPEN] - @property def on_order(self): - """ Return the total number of items on order for this part. """ - + """Return the total number of items on order for this part.""" orders = self.supplier_parts.filter(purchase_order_line_items__order__status__in=PurchaseOrderStatus.OPEN).aggregate( quantity=Sum('purchase_order_line_items__quantity'), received=Sum('purchase_order_line_items__received') @@ -2154,20 +1961,18 @@ class Part(MetadataMixin, MPTTModel): return quantity - received def get_parameters(self): - """ Return all parameters for this part, ordered by name """ - + """Return all parameters for this part, ordered by name.""" return self.parameters.order_by('template__name') def parameters_map(self): - """ - Return a map (dict) of parameter values assocaited with this Part instance, - of the form: + """Return a map (dict) of parameter values assocaited with this Part instance, of the form. + + Example: { "name_1": "value_1", "name_2": "value_2", } """ - params = {} for parameter in self.parameters.all(): @@ -2177,39 +1982,31 @@ class Part(MetadataMixin, MPTTModel): @property def has_variants(self): - """ Check if this Part object has variants underneath it. """ - + """Check if this Part object has variants underneath it.""" return self.get_all_variants().count() > 0 def get_all_variants(self): - """ Return all Part object which exist as a variant under this part. """ - + """Return all Part object which exist as a variant under this part.""" return self.get_descendants(include_self=False) @property def can_convert(self): - """ - Check if this Part can be "converted" to a different variant: + """Check if this Part can be "converted" to a different variant. It can be converted if: - a) It has non-virtual variant parts underneath it b) It has non-virtual template parts above it c) It has non-virtual sibling variants - """ - return self.get_conversion_options().count() > 0 def get_conversion_options(self): - """ - Return options for converting this part to a "variant" within the same tree + """Return options for converting this part to a "variant" within the same tree. a) Variants underneath this one b) Immediate parent c) Siblings """ - parts = [] # Child parts @@ -2240,44 +2037,43 @@ class Part(MetadataMixin, MPTTModel): return filtered_parts def get_related_parts(self): - """ Return list of tuples for all related parts: - - first value is PartRelated object - - second value is matching Part object - """ + """Return list of tuples for all related parts. + Includes: + - first value is PartRelated object + - second value is matching Part object + """ related_parts = [] related_parts_1 = self.related_parts_1.filter(part_1__id=self.pk) related_parts_2 = self.related_parts_2.filter(part_2__id=self.pk) + related_parts.append() + for related_part in related_parts_1: # Add to related parts list - related_parts.append((related_part, related_part.part_2)) + related_parts.append(related_part.part_2) for related_part in related_parts_2: # Add to related parts list - related_parts.append((related_part, related_part.part_1)) + related_parts.append(related_part.part_1) return related_parts @property def related_count(self): + """Return the number of 'related parts' which point to this Part""" return len(self.get_related_parts()) def is_part_low_on_stock(self): - """ - Returns True if the total stock for this part is less than the minimum stock level - """ - + """Returns True if the total stock for this part is less than the minimum stock level.""" return self.get_stock_count() < self.minimum_stock @receiver(post_save, sender=Part, dispatch_uid='part_post_save_log') def after_save_part(sender, instance: Part, created, **kwargs): - """ - Function to be executed after a Part is saved - """ + """Function to be executed after a Part is saved.""" from part import tasks as part_tasks if not created and not InvenTree.ready.isImportingData(): @@ -2288,15 +2084,15 @@ def after_save_part(sender, instance: Part, created, **kwargs): class PartAttachment(InvenTreeAttachment): - """ - Model for storing file attachments against a Part object - """ + """Model for storing file attachments against a Part object.""" @staticmethod def get_api_url(): + """Return the list API endpoint URL associated with the PartAttachment model""" return reverse('api-part-attachment-list') def getSubdir(self): + """Returns the media subdirectory where part attachments are stored""" return os.path.join("part_files", str(self.part.id)) part = models.ForeignKey(Part, on_delete=models.CASCADE, @@ -2304,12 +2100,11 @@ class PartAttachment(InvenTreeAttachment): class PartSellPriceBreak(common.models.PriceBreak): - """ - Represents a price break for selling this part - """ + """Represents a price break for selling this part.""" @staticmethod def get_api_url(): + """Return the list API endpoint URL associated with the PartSellPriceBreak model""" return reverse('api-part-sale-price-list') part = models.ForeignKey( @@ -2320,16 +2115,16 @@ class PartSellPriceBreak(common.models.PriceBreak): ) class Meta: + """Metaclass providing extra model definition""" unique_together = ('part', 'quantity') class PartInternalPriceBreak(common.models.PriceBreak): - """ - Represents a price break for internally selling this part - """ + """Represents a price break for internally selling this part.""" @staticmethod def get_api_url(): + """Return the list API endpoint URL associated with the PartInternalPriceBreak model""" return reverse('api-part-internal-price-list') part = models.ForeignKey( @@ -2339,11 +2134,12 @@ class PartInternalPriceBreak(common.models.PriceBreak): ) class Meta: + """Metaclass providing extra model definition""" unique_together = ('part', 'quantity') class PartStar(models.Model): - """ A PartStar object creates a subscription relationship between a User and a Part. + """A PartStar object creates a subscription relationship between a User and a Part. It is used to designate a Part as 'subscribed' for a given User. @@ -2357,6 +2153,7 @@ class PartStar(models.Model): user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name=_('User'), related_name='starred_parts') class Meta: + """Metaclass providing extra model definition""" unique_together = [ 'part', 'user' @@ -2364,8 +2161,7 @@ class PartStar(models.Model): class PartCategoryStar(models.Model): - """ - A PartCategoryStar creates a subscription relationship between a User and a PartCategory. + """A PartCategoryStar creates a subscription relationship between a User and a PartCategory. Attributes: category: Link to a PartCategory object @@ -2377,6 +2173,7 @@ class PartCategoryStar(models.Model): user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name=_('User'), related_name='starred_categories') class Meta: + """Metaclass providing extra model definition""" unique_together = [ 'category', 'user', @@ -2384,9 +2181,7 @@ class PartCategoryStar(models.Model): class PartTestTemplate(models.Model): - """ - A PartTestTemplate defines a 'template' for a test which is required to be run - against a StockItem (an instance of the Part). + """A PartTestTemplate defines a 'template' for a test which is required to be run against a StockItem (an instance of the Part). The test template applies "recursively" to part variants, allowing tests to be defined in a heirarchy. @@ -2400,26 +2195,24 @@ class PartTestTemplate(models.Model): @staticmethod def get_api_url(): + """Return the list API endpoint URL associated with the PartTestTemplate model""" return reverse('api-part-test-template-list') def save(self, *args, **kwargs): - + """Enforce 'clean' operation when saving a PartTestTemplate instance""" self.clean() super().save(*args, **kwargs) def clean(self): - + """Clean fields for the PartTestTemplate model""" self.test_name = self.test_name.strip() self.validate_unique() super().clean() def validate_unique(self, exclude=None): - """ - Test that this test template is 'unique' within this part tree. - """ - + """Test that this test template is 'unique' within this part tree.""" if not self.part.trackable: raise ValidationError({ 'part': _('Test templates can only be created for trackable parts') @@ -2446,7 +2239,7 @@ class PartTestTemplate(models.Model): @property def key(self): - """ Generate a key for this test """ + """Generate a key for this test.""" return helpers.generateTestKey(self.test_name) part = models.ForeignKey( @@ -2489,21 +2282,16 @@ class PartTestTemplate(models.Model): def validate_template_name(name): - """ - Prevent illegal characters in "name" field for PartParameterTemplate - """ - + """Prevent illegal characters in "name" field for PartParameterTemplate.""" for c in "!@#$%^&*()<>{}[].,?/\\|~`_+-=\'\"": if c in str(name): raise ValidationError(_(f"Illegal character in template name ({c})")) class PartParameterTemplate(models.Model): - """ - A PartParameterTemplate provides a template for key:value pairs for extra - parameters fields/values to be added to a Part. - This allows users to arbitrarily assign data fields to a Part - beyond the built-in attributes. + """A PartParameterTemplate provides a template for key:value pairs for extra parameters fields/values to be added to a Part. + + This allows users to arbitrarily assign data fields to a Part beyond the built-in attributes. Attributes: name: The name (key) of the Parameter [string] @@ -2512,19 +2300,21 @@ class PartParameterTemplate(models.Model): @staticmethod def get_api_url(): + """Return the list API endpoint URL associated with the PartParameterTemplate model""" return reverse('api-part-parameter-template-list') def __str__(self): + """Return a string representation of a PartParameterTemplate instance""" s = str(self.name) if self.units: s += " ({units})".format(units=self.units) return s def validate_unique(self, exclude=None): - """ Ensure that PartParameterTemplates cannot be created with the same name. + """Ensure that PartParameterTemplates cannot be created with the same name. + This test should be case-insensitive (which the unique caveat does not cover). """ - super().validate_unique(exclude) try: @@ -2550,8 +2340,7 @@ class PartParameterTemplate(models.Model): class PartParameter(models.Model): - """ - A PartParameter is a specific instance of a PartParameterTemplate. It assigns a particular parameter pair to a part. + """A PartParameter is a specific instance of a PartParameterTemplate. It assigns a particular parameter pair to a part. Attributes: part: Reference to a single Part object @@ -2561,10 +2350,11 @@ class PartParameter(models.Model): @staticmethod def get_api_url(): + """Return the list API endpoint URL associated with the PartParameter model""" return reverse('api-part-parameter-list') def __str__(self): - # String representation of a PartParameter (used in the admin interface) + """String representation of a PartParameter (used in the admin interface)""" return "{part} : {param} = {data}{units}".format( part=str(self.part.full_name), param=str(self.template.name), @@ -2573,6 +2363,7 @@ class PartParameter(models.Model): ) class Meta: + """Metaclass providing extra model definition""" # Prevent multiple instances of a parameter for a single part unique_together = ('part', 'template') @@ -2584,6 +2375,7 @@ class PartParameter(models.Model): @classmethod def create(cls, part, template, data, save=False): + """Custom save method for the PartParameter class""" part_parameter = cls(part=part, template=template, data=data) if save: part_parameter.save() @@ -2591,11 +2383,7 @@ class PartParameter(models.Model): class PartCategoryParameterTemplate(models.Model): - """ - A PartCategoryParameterTemplate creates a unique relationship between a PartCategory - and a PartParameterTemplate. - Multiple PartParameterTemplate instances can be associated to a PartCategory to drive - a default list of parameter templates attached to a Part instance upon creation. + """A PartCategoryParameterTemplate creates a unique relationship between a PartCategory and a PartParameterTemplate. Multiple PartParameterTemplate instances can be associated to a PartCategory to drive a default list of parameter templates attached to a Part instance upon creation. Attributes: category: Reference to a single PartCategory object @@ -2605,14 +2393,14 @@ class PartCategoryParameterTemplate(models.Model): """ class Meta: + """Metaclass providing extra model definition""" constraints = [ UniqueConstraint(fields=['category', 'parameter_template'], name='unique_category_parameter_template_pair') ] def __str__(self): - """ String representation of a PartCategoryParameterTemplate (admin interface) """ - + """String representation of a PartCategoryParameterTemplate (admin interface).""" if self.default_value: return f'{self.category.name} | {self.parameter_template.name} | {self.default_value}' else: @@ -2636,8 +2424,9 @@ class PartCategoryParameterTemplate(models.Model): help_text=_('Default Parameter Value')) -class BomItem(models.Model, DataImportMixin): - """ A BomItem links a part to its component items. +class BomItem(DataImportMixin, models.Model): + """A BomItem links a part to its component items. + A part can have a BOM (bill of materials) which defines which parts are required (and in what quantity) to make it. @@ -2689,17 +2478,17 @@ class BomItem(models.Model, DataImportMixin): @staticmethod def get_api_url(): + """Return the list API endpoint URL associated with the BomItem model""" return reverse('api-bom-list') def get_valid_parts_for_allocation(self, allow_variants=True, allow_substitutes=True): - """ - Return a list of valid parts which can be allocated against this BomItem: + """Return a list of valid parts which can be allocated against this BomItem. - - Include the referenced sub_part - - Include any directly specvified substitute parts - - If allow_variants is True, allow all variants of sub_part + Includes: + - The referenced sub_part + - Any directly specvified substitute parts + - If allow_variants is True, all variants of sub_part """ - # Set of parts we will allow parts = set() @@ -2732,25 +2521,19 @@ class BomItem(models.Model, DataImportMixin): return valid_parts def is_stock_item_valid(self, stock_item): - """ - Check if the provided StockItem object is "valid" for assignment against this BomItem - """ - + """Check if the provided StockItem object is "valid" for assignment against this BomItem.""" return stock_item.part in self.get_valid_parts_for_allocation() def get_stock_filter(self): - """ - Return a queryset filter for selecting StockItems which match this BomItem + """Return a queryset filter for selecting StockItems which match this BomItem. - Allow stock from all directly specified substitute parts - If allow_variants is True, allow all part variants - """ - return Q(part__in=[part.pk for part in self.get_valid_parts_for_allocation()]) def save(self, *args, **kwargs): - + """Enforce 'clean' operation when saving a BomItem instance""" self.clean() super().save(*args, **kwargs) @@ -2802,19 +2585,16 @@ class BomItem(models.Model, DataImportMixin): ) def get_item_hash(self): - """ Calculate the checksum hash of this BOM line item: + """Calculate the checksum hash of this BOM line item. The hash is calculated from the following fields: - - Part.full_name (if the part name changes, the BOM checksum is invalidated) - Quantity - Reference field - Note field - Optional field - Inherited field - """ - # Seed the hash with the ID of this BOM item result_hash = hashlib.md5(str(self.id).encode()) @@ -2830,12 +2610,11 @@ class BomItem(models.Model, DataImportMixin): return str(result_hash.digest()) def validate_hash(self, valid=True): - """ Mark this item as 'valid' (store the checksum hash). + """Mark this item as 'valid' (store the checksum hash). Args: valid: If true, validate the hash, otherwise invalidate it (default = True) """ - if valid: self.checksum = str(self.get_item_hash()) else: @@ -2845,8 +2624,7 @@ class BomItem(models.Model, DataImportMixin): @property def is_line_valid(self): - """ Check if this line item has been validated by the user """ - + """Check if this line item has been validated by the user.""" # Ensure an empty checksum returns False if len(self.checksum) == 0: return False @@ -2854,8 +2632,7 @@ class BomItem(models.Model, DataImportMixin): return self.get_item_hash() == self.checksum def clean(self): - """ - Check validity of the BomItem model. + """Check validity of the BomItem model. Performs model checks beyond simple field validation. @@ -2864,7 +2641,6 @@ class BomItem(models.Model, DataImportMixin): - If the "sub_part" is trackable, then the "part" must be trackable too! """ - super().clean() try: @@ -2897,18 +2673,18 @@ class BomItem(models.Model, DataImportMixin): raise ValidationError({'sub_part': _('Sub part must be specified')}) class Meta: + """Metaclass providing extra model definition""" verbose_name = _("BOM Item") def __str__(self): + """Return a string representation of this BomItem instance""" return "{n} x {child} to make {parent}".format( parent=self.part.full_name, child=self.sub_part.full_name, n=decimal2string(self.quantity)) def get_overage_quantity(self, quantity): - """ Calculate overage quantity - """ - + """Calculate overage quantity.""" # Most of the time overage string will be empty if len(self.overage) == 0: return 0 @@ -2949,8 +2725,7 @@ class BomItem(models.Model, DataImportMixin): return 0 def get_required_quantity(self, build_quantity): - """ Calculate the required part quantity, based on the supplier build_quantity. - Includes overage estimate in the returned value. + """Calculate the required part quantity, based on the supplier build_quantity. Includes overage estimate in the returned value. Args: build_quantity: Number of parts to build @@ -2958,7 +2733,6 @@ class BomItem(models.Model, DataImportMixin): Returns: Quantity required for this build (including overage) """ - # Base quantity requirement base_quantity = self.quantity * build_quantity @@ -2971,8 +2745,7 @@ class BomItem(models.Model, DataImportMixin): @property def price_range(self, internal=False): - """ Return the price-range for this BOM item. """ - + """Return the price-range for this BOM item.""" # get internal price setting use_internal = common.models.InvenTreeSetting.get_setting('PART_BOM_USE_INTERNAL_PRICE', False) prange = self.sub_part.get_price_range(self.quantity, internal=use_internal and internal) @@ -2993,9 +2766,7 @@ class BomItem(models.Model, DataImportMixin): class BomItemSubstitute(models.Model): - """ - A BomItemSubstitute provides a specification for alternative parts, - which can be used in a bill of materials. + """A BomItemSubstitute provides a specification for alternative parts, which can be used in a bill of materials. Attributes: bom_item: Link to the parent BomItem instance @@ -3003,24 +2774,24 @@ class BomItemSubstitute(models.Model): """ class Meta: + """Metaclass providing extra model definition""" verbose_name = _("BOM Item Substitute") # Prevent duplication of substitute parts unique_together = ('part', 'bom_item') def save(self, *args, **kwargs): - + """Enforce a full_clean when saving the BomItemSubstitute model""" self.full_clean() super().save(*args, **kwargs) def validate_unique(self, exclude=None): - """ - Ensure that this BomItemSubstitute is "unique": + """Ensure that this BomItemSubstitute is "unique". + Ensure: - It cannot point to the same "part" as the "sub_part" of the parent "bom_item" """ - super().validate_unique(exclude=exclude) if self.part == self.bom_item.sub_part: @@ -3030,6 +2801,7 @@ class BomItemSubstitute(models.Model): @staticmethod def get_api_url(): + """Returns the list API endpoint URL associated with this model""" return reverse('api-bom-substitute-list') bom_item = models.ForeignKey( @@ -3053,7 +2825,7 @@ class BomItemSubstitute(models.Model): class PartRelated(models.Model): - """ Store and handle related parts (eg. mating connector, crimps, etc.) """ + """Store and handle related parts (eg. mating connector, crimps, etc.).""" part_1 = models.ForeignKey(Part, related_name='related_parts_1', verbose_name=_('Part 1'), on_delete=models.DO_NOTHING) @@ -3063,11 +2835,11 @@ class PartRelated(models.Model): verbose_name=_('Part 2'), help_text=_('Select Related Part')) def __str__(self): + """Return a string representation of this Part-Part relationship""" return f'{self.part_1} <--> {self.part_2}' def validate(self, part_1, part_2): - ''' Validate that the two parts relationship is unique ''' - + """Validate that the two parts relationship is unique.""" validate = True parts = Part.objects.all() @@ -3087,8 +2859,7 @@ class PartRelated(models.Model): return validate def clean(self): - ''' Overwrite clean method to check that relation is unique ''' - + """Overwrite clean method to check that relation is unique.""" validate = self.validate(self.part_1, self.part_2) if not validate: diff --git a/InvenTree/part/serializers.py b/InvenTree/part/serializers.py index 4675c529da..4094821670 100644 --- a/InvenTree/part/serializers.py +++ b/InvenTree/part/serializers.py @@ -1,6 +1,4 @@ -""" -JSON serializers for Part app -""" +"""DRF data serializers for Part app.""" import imghdr from decimal import Decimal @@ -37,17 +35,10 @@ from .models import (BomItem, BomItemSubstitute, Part, PartAttachment, class CategorySerializer(InvenTreeModelSerializer): - """ Serializer for PartCategory """ - - def __init__(self, *args, **kwargs): - - super().__init__(*args, **kwargs) + """Serializer for PartCategory.""" def get_starred(self, category): - """ - Return True if the category is directly "starred" by the current user - """ - + """Return True if the category is directly "starred" by the current user.""" return category in self.context.get('starred_categories', []) url = serializers.CharField(source='get_absolute_url', read_only=True) @@ -59,6 +50,7 @@ class CategorySerializer(InvenTreeModelSerializer): starred = serializers.SerializerMethodField() class Meta: + """Metaclass defining serializer fields""" model = PartCategory fields = [ 'pk', @@ -76,11 +68,10 @@ class CategorySerializer(InvenTreeModelSerializer): class CategoryTree(InvenTreeModelSerializer): - """ - Serializer for PartCategory tree - """ + """Serializer for PartCategory tree.""" class Meta: + """Metaclass defining serializer fields""" model = PartCategory fields = [ 'pk', @@ -90,11 +81,10 @@ class CategoryTree(InvenTreeModelSerializer): class PartAttachmentSerializer(InvenTreeAttachmentSerializer): - """ - Serializer for the PartAttachment class - """ + """Serializer for the PartAttachment class.""" class Meta: + """Metaclass defining serializer fields""" model = PartAttachment fields = [ @@ -113,13 +103,12 @@ class PartAttachmentSerializer(InvenTreeAttachmentSerializer): class PartTestTemplateSerializer(InvenTreeModelSerializer): - """ - Serializer for the PartTestTemplate class - """ + """Serializer for the PartTestTemplate class.""" key = serializers.CharField(read_only=True) class Meta: + """Metaclass defining serializer fields""" model = PartTestTemplate fields = [ @@ -135,9 +124,7 @@ class PartTestTemplateSerializer(InvenTreeModelSerializer): class PartSalePriceSerializer(InvenTreeModelSerializer): - """ - Serializer for sale prices for Part model. - """ + """Serializer for sale prices for Part model.""" quantity = InvenTreeDecimalField() @@ -155,6 +142,7 @@ class PartSalePriceSerializer(InvenTreeModelSerializer): price_string = serializers.CharField(source='price', read_only=True) class Meta: + """Metaclass defining serializer fields""" model = PartSellPriceBreak fields = [ 'pk', @@ -167,9 +155,7 @@ class PartSalePriceSerializer(InvenTreeModelSerializer): class PartInternalPriceSerializer(InvenTreeModelSerializer): - """ - Serializer for internal prices for Part model. - """ + """Serializer for internal prices for Part model.""" quantity = InvenTreeDecimalField() @@ -187,6 +173,7 @@ class PartInternalPriceSerializer(InvenTreeModelSerializer): price_string = serializers.CharField(source='price', read_only=True) class Meta: + """Metaclass defining serializer fields""" model = PartInternalPriceBreak fields = [ 'pk', @@ -199,8 +186,8 @@ class PartInternalPriceSerializer(InvenTreeModelSerializer): class PartThumbSerializer(serializers.Serializer): - """ - Serializer for the 'image' field of the Part model. + """Serializer for the 'image' field of the Part model. + Used to serve and display existing Part images. """ @@ -209,12 +196,10 @@ class PartThumbSerializer(serializers.Serializer): class PartThumbSerializerUpdate(InvenTreeModelSerializer): - """ Serializer for updating Part thumbnail """ + """Serializer for updating Part thumbnail.""" def validate_image(self, value): - """ - Check that file is an image. - """ + """Check that file is an image.""" validate = imghdr.what(value) if not validate: raise serializers.ValidationError("File is not an image") @@ -223,6 +208,7 @@ class PartThumbSerializerUpdate(InvenTreeModelSerializer): image = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass defining serializer fields""" model = Part fields = [ 'image', @@ -230,9 +216,10 @@ class PartThumbSerializerUpdate(InvenTreeModelSerializer): class PartParameterTemplateSerializer(InvenTreeModelSerializer): - """ JSON serializer for the PartParameterTemplate model """ + """JSON serializer for the PartParameterTemplate model.""" class Meta: + """Metaclass defining serializer fields""" model = PartParameterTemplate fields = [ 'pk', @@ -242,11 +229,12 @@ class PartParameterTemplateSerializer(InvenTreeModelSerializer): class PartParameterSerializer(InvenTreeModelSerializer): - """ JSON serializers for the PartParameter model """ + """JSON serializers for the PartParameter model.""" template_detail = PartParameterTemplateSerializer(source='template', many=False, read_only=True) class Meta: + """Metaclass defining serializer fields""" model = PartParameter fields = [ 'pk', @@ -258,13 +246,14 @@ class PartParameterSerializer(InvenTreeModelSerializer): class PartBriefSerializer(InvenTreeModelSerializer): - """ Serializer for Part (brief detail) """ + """Serializer for Part (brief detail)""" thumbnail = serializers.CharField(source='get_thumbnail_url', read_only=True) stock = serializers.FloatField(source='total_stock') class Meta: + """Metaclass defining serializer fields""" model = Part fields = [ 'pk', @@ -288,19 +277,20 @@ class PartBriefSerializer(InvenTreeModelSerializer): class PartSerializer(InvenTreeModelSerializer): - """ Serializer for complete detail information of a part. + """Serializer for complete detail information of a part. + Used when displaying all details of a single component. """ def get_api_url(self): + """Return the API url associated with this serializer""" return reverse_lazy('api-part-list') def __init__(self, *args, **kwargs): - """ - Custom initialization method for PartSerializer, - so that we can optionally pass extra fields based on the query. - """ + """Custom initialization method for PartSerializer: + - Allows us to optionally pass extra fields based on the query. + """ self.starred_parts = kwargs.pop('starred_parts', []) category_detail = kwargs.pop('category_detail', False) @@ -317,12 +307,10 @@ class PartSerializer(InvenTreeModelSerializer): @staticmethod def annotate_queryset(queryset): - """ - Add some extra annotations to the queryset, - performing database queries as efficiently as possible, - to reduce database trips. - """ + """Add some extra annotations to the queryset. + Performing database queries as efficiently as possible, to reduce database trips. + """ # Annotate with the total 'in stock' quantity queryset = queryset.annotate( in_stock=Coalesce( @@ -444,10 +432,7 @@ class PartSerializer(InvenTreeModelSerializer): return queryset def get_starred(self, part): - """ - Return "true" if the part is starred by the current user. - """ - + """Return "true" if the part is starred by the current user.""" return part in self.starred_parts # Extra detail for the category @@ -477,6 +462,7 @@ class PartSerializer(InvenTreeModelSerializer): ) class Meta: + """Metaclass defining serializer fields""" model = Part partial = True fields = [ @@ -522,14 +508,13 @@ class PartSerializer(InvenTreeModelSerializer): class PartRelationSerializer(InvenTreeModelSerializer): - """ - Serializer for a PartRelated model - """ + """Serializer for a PartRelated model.""" part_1_detail = PartSerializer(source='part_1', read_only=True, many=False) part_2_detail = PartSerializer(source='part_2', read_only=True, many=False) class Meta: + """Metaclass defining serializer fields""" model = PartRelated fields = [ 'pk', @@ -541,12 +526,13 @@ class PartRelationSerializer(InvenTreeModelSerializer): class PartStarSerializer(InvenTreeModelSerializer): - """ Serializer for a PartStar object """ + """Serializer for a PartStar object.""" partname = serializers.CharField(source='part.full_name', read_only=True) username = serializers.CharField(source='user.username', read_only=True) class Meta: + """Metaclass defining serializer fields""" model = PartStar fields = [ 'pk', @@ -558,13 +544,12 @@ class PartStarSerializer(InvenTreeModelSerializer): class BomItemSubstituteSerializer(InvenTreeModelSerializer): - """ - Serializer for the BomItemSubstitute class - """ + """Serializer for the BomItemSubstitute class.""" part_detail = PartBriefSerializer(source='part', read_only=True, many=False) class Meta: + """Metaclass defining serializer fields""" model = BomItemSubstitute fields = [ 'pk', @@ -575,15 +560,14 @@ class BomItemSubstituteSerializer(InvenTreeModelSerializer): class BomItemSerializer(InvenTreeModelSerializer): - """ - Serializer for BomItem object - """ + """Serializer for BomItem object.""" price_range = serializers.CharField(read_only=True) quantity = InvenTreeDecimalField(required=True) def validate_quantity(self, quantity): + """Perform validation for the BomItem quantity field""" if quantity <= 0: raise serializers.ValidationError(_("Quantity must be greater than zero")) @@ -615,9 +599,11 @@ class BomItemSerializer(InvenTreeModelSerializer): available_variant_stock = serializers.FloatField(read_only=True) def __init__(self, *args, **kwargs): - # part_detail and sub_part_detail serializers are only included if requested. - # This saves a bunch of database requests + """Determine if extra detail fields are to be annotated on this serializer + - part_detail and sub_part_detail serializers are only included if requested. + - This saves a bunch of database requests + """ part_detail = kwargs.pop('part_detail', False) sub_part_detail = kwargs.pop('sub_part_detail', False) include_pricing = kwargs.pop('include_pricing', False) @@ -640,6 +626,7 @@ class BomItemSerializer(InvenTreeModelSerializer): @staticmethod def setup_eager_loading(queryset): + """Prefetch against the provided queryset to speed up database access""" queryset = queryset.prefetch_related('part') queryset = queryset.prefetch_related('part__category') queryset = queryset.prefetch_related('part__stock_items') @@ -663,18 +650,15 @@ class BomItemSerializer(InvenTreeModelSerializer): @staticmethod def annotate_queryset(queryset): - """ - Annotate the BomItem queryset with extra information: + """Annotate the BomItem queryset with extra information: Annotations: available_stock: The amount of stock available for the sub_part Part object """ - """ Construct an "available stock" quantity: available_stock = total_stock - build_order_allocations - sales_order_allocations """ - build_order_filter = Q(build__status__in=BuildStatus.ACTIVE_CODES) sales_order_filter = Q( line__order__status__in=SalesOrderStatus.OPEN, @@ -799,8 +783,7 @@ class BomItemSerializer(InvenTreeModelSerializer): return queryset def get_purchase_price_range(self, obj): - """ Return purchase price range """ - + """Return purchase price range.""" try: purchase_price_min = obj.purchase_price_min except AttributeError: @@ -830,8 +813,7 @@ class BomItemSerializer(InvenTreeModelSerializer): return purchase_price_range def get_purchase_price_avg(self, obj): - """ Return purchase price average """ - + """Return purchase price average.""" try: purchase_price_avg = obj.purchase_price_avg except AttributeError: @@ -846,6 +828,7 @@ class BomItemSerializer(InvenTreeModelSerializer): return purchase_price_avg class Meta: + """Metaclass defining serializer fields""" model = BomItem fields = [ 'allow_variants', @@ -877,7 +860,7 @@ class BomItemSerializer(InvenTreeModelSerializer): class CategoryParameterTemplateSerializer(InvenTreeModelSerializer): - """ Serializer for PartCategoryParameterTemplate """ + """Serializer for PartCategoryParameterTemplate.""" parameter_template = PartParameterTemplateSerializer(many=False, read_only=True) @@ -885,6 +868,7 @@ class CategoryParameterTemplateSerializer(InvenTreeModelSerializer): category_detail = CategorySerializer(source='category', many=False, read_only=True) class Meta: + """Metaclass defining serializer fields""" model = PartCategoryParameterTemplate fields = [ 'pk', @@ -896,11 +880,10 @@ class CategoryParameterTemplateSerializer(InvenTreeModelSerializer): class PartCopyBOMSerializer(serializers.Serializer): - """ - Serializer for copying a BOM from another part - """ + """Serializer for copying a BOM from another part.""" class Meta: + """Metaclass defining serializer fields""" fields = [ 'part', 'remove_existing', @@ -919,10 +902,7 @@ class PartCopyBOMSerializer(serializers.Serializer): ) def validate_part(self, part): - """ - Check that a 'valid' part was selected - """ - + """Check that a 'valid' part was selected.""" return part remove_existing = serializers.BooleanField( @@ -950,10 +930,7 @@ class PartCopyBOMSerializer(serializers.Serializer): ) def save(self): - """ - Actually duplicate the BOM - """ - + """Actually duplicate the BOM.""" base_part = self.context['part'] data = self.validated_data @@ -968,13 +945,12 @@ class PartCopyBOMSerializer(serializers.Serializer): class BomImportUploadSerializer(DataFileUploadSerializer): - """ - Serializer for uploading a file and extracting data from it. - """ + """Serializer for uploading a file and extracting data from it.""" TARGET_MODEL = BomItem class Meta: + """Metaclass defining serializer fields""" fields = [ 'data_file', 'part', @@ -994,7 +970,7 @@ class BomImportUploadSerializer(DataFileUploadSerializer): ) def save(self): - + """The uploaded data file has been validated, accept the submitted data""" data = self.validated_data if data.get('clear_existing_bom', False): @@ -1005,12 +981,15 @@ class BomImportUploadSerializer(DataFileUploadSerializer): class BomImportExtractSerializer(DataFileExtractSerializer): - """ + """Serializer class for exatracting BOM data from an uploaded file. + + The parent class DataFileExtractSerializer does most of the heavy lifting here. """ TARGET_MODEL = BomItem def validate_extracted_columns(self): + """Validate that the extracted columns are correct""" super().validate_extracted_columns() part_columns = ['part', 'part_name', 'part_ipn', 'part_id'] @@ -1020,7 +999,7 @@ class BomImportExtractSerializer(DataFileExtractSerializer): raise serializers.ValidationError(_("No part column specified")) def process_row(self, row): - + """Process a single row from the loaded BOM file""" # Skip any rows which are at a lower "level" level = row.get('level', None) @@ -1089,8 +1068,7 @@ class BomImportExtractSerializer(DataFileExtractSerializer): class BomImportSubmitSerializer(serializers.Serializer): - """ - Serializer for uploading a BOM against a specified part. + """Serializer for uploading a BOM against a specified part. A "BOM" is a set of BomItem objects which are to be validated together as a set """ @@ -1098,7 +1076,10 @@ class BomImportSubmitSerializer(serializers.Serializer): items = BomItemSerializer(many=True, required=True) def validate(self, data): + """Validate the submitted BomItem data: + - At least one line (BomItem) is required + """ items = data['items'] if len(items) == 0: @@ -1109,7 +1090,11 @@ class BomImportSubmitSerializer(serializers.Serializer): return data def save(self): + """POST: Perform final save of submitted BOM data: + - By this stage each line in the BOM has been validated + - Individually 'save' (create) each BomItem line + """ data = self.validated_data items = data['items'] diff --git a/InvenTree/part/settings.py b/InvenTree/part/settings.py index 41ecbb0d00..e5706a87ff 100644 --- a/InvenTree/part/settings.py +++ b/InvenTree/part/settings.py @@ -1,61 +1,38 @@ -""" -User-configurable settings for the Part app -""" +"""User-configurable settings for the Part app.""" from common.models import InvenTreeSetting def part_assembly_default(): - """ - Returns the default value for the 'assembly' field of a Part object - """ - + """Returns the default value for the 'assembly' field of a Part object.""" return InvenTreeSetting.get_setting('PART_ASSEMBLY') def part_template_default(): - """ - Returns the default value for the 'is_template' field of a Part object - """ - + """Returns the default value for the 'is_template' field of a Part object.""" return InvenTreeSetting.get_setting('PART_TEMPLATE') def part_virtual_default(): - """ - Returns the default value for the 'is_virtual' field of Part object - """ - + """Returns the default value for the 'is_virtual' field of Part object.""" return InvenTreeSetting.get_setting('PART_VIRTUAL') def part_component_default(): - """ - Returns the default value for the 'component' field of a Part object - """ - + """Returns the default value for the 'component' field of a Part object.""" return InvenTreeSetting.get_setting('PART_COMPONENT') def part_purchaseable_default(): - """ - Returns the default value for the 'purchasable' field for a Part object - """ - + """Returns the default value for the 'purchasable' field for a Part object.""" return InvenTreeSetting.get_setting('PART_PURCHASEABLE') def part_salable_default(): - """ - Returns the default value for the 'salable' field for a Part object - """ - + """Returns the default value for the 'salable' field for a Part object.""" return InvenTreeSetting.get_setting('PART_SALABLE') def part_trackable_default(): - """ - Returns the default value for the 'trackable' field for a Part object - """ - + """Returns the default value for the 'trackable' field for a Part object.""" return InvenTreeSetting.get_setting('PART_TRACKABLE') diff --git a/InvenTree/part/tasks.py b/InvenTree/part/tasks.py index 9037fb92be..db1443042a 100644 --- a/InvenTree/part/tasks.py +++ b/InvenTree/part/tasks.py @@ -1,3 +1,5 @@ +"""Background task definitions for the 'part' app""" + import logging from django.utils.translation import gettext_lazy as _ @@ -11,6 +13,11 @@ logger = logging.getLogger("inventree") def notify_low_stock(part: part.models.Part): + """Notify interested users that a part is 'low stock': + + - Triggered when the available stock for a given part falls be low the configured threhsold + - A notification is delivered to any users who are 'subscribed' to this part + """ name = _("Low stock notification") message = _(f'The available stock for {part.name} has fallen below the configured minimum level') context = { @@ -33,12 +40,10 @@ def notify_low_stock(part: part.models.Part): def notify_low_stock_if_required(part: part.models.Part): - """ - Check if the stock quantity has fallen below the minimum threshold of part. + """Check if the stock quantity has fallen below the minimum threshold of part. If true, notify the users who have subscribed to the part """ - # Run "up" the tree, to allow notification for "parent" parts parts = part.get_ancestors(include_self=True, ascending=True) diff --git a/InvenTree/part/templatetags/__init__.py b/InvenTree/part/templatetags/__init__.py index e69de29bb2..ce063ef241 100644 --- a/InvenTree/part/templatetags/__init__.py +++ b/InvenTree/part/templatetags/__init__.py @@ -0,0 +1 @@ +"""Custom InvenTree template tags for HTML template rendering""" diff --git a/InvenTree/part/templatetags/inventree_extras.py b/InvenTree/part/templatetags/inventree_extras.py index 7c7ee4de12..e48c6f89c6 100644 --- a/InvenTree/part/templatetags/inventree_extras.py +++ b/InvenTree/part/templatetags/inventree_extras.py @@ -1,9 +1,4 @@ -# -*- coding: utf-8 -*- - -""" -This module provides template tags for extra functionality, -over and above the built-in Django tags. -""" +"""This module provides template tags for extra functionality, over and above the built-in Django tags.""" import logging import os @@ -33,26 +28,22 @@ logger = logging.getLogger('inventree') @register.simple_tag() def define(value, *args, **kwargs): - """ - Shortcut function to overcome the shortcomings of the django templating language + """Shortcut function to overcome the shortcomings of the django templating language. Use as follows: {% define "hello_world" as hello %} Ref: https://stackoverflow.com/questions/1070398/how-to-set-a-value-of-a-variable-inside-a-template-code """ - return value @register.simple_tag(takes_context=True) def render_date(context, date_object): - """ - Renders a date according to the preference of the provided user + """Renders a date according to the preference of the provided user. Note that the user preference is stored using the formatting adopted by moment.js, which differs from the python formatting! """ - if date_object is None: return None @@ -105,59 +96,43 @@ def render_date(context, date_object): @register.simple_tag() def decimal(x, *args, **kwargs): - """ Simplified rendering of a decimal number """ - + """Simplified rendering of a decimal number.""" return InvenTree.helpers.decimal2string(x) @register.simple_tag() def str2bool(x, *args, **kwargs): - """ Convert a string to a boolean value """ - + """Convert a string to a boolean value.""" return InvenTree.helpers.str2bool(x) -@register.simple_tag() -def inrange(n, *args, **kwargs): - """ Return range(n) for iterating through a numeric quantity """ - return range(n) - - -@register.simple_tag() -def multiply(x, y, *args, **kwargs): - """ Multiply two numbers together """ - return InvenTree.helpers.decimal2string(x * y) - - @register.simple_tag() def add(x, y, *args, **kwargs): - """ Add two numbers together """ + """Add two numbers together.""" return x + y @register.simple_tag() def to_list(*args): - """ Return the input arguments as list """ + """Return the input arguments as list.""" return args @register.simple_tag() def part_allocation_count(build, part, *args, **kwargs): - """ Return the total number of allocated to """ - + """Return the total number of allocated to """ return InvenTree.helpers.decimal2string(build.getAllocatedQuantity(part)) @register.simple_tag() def inventree_in_debug_mode(*args, **kwargs): - """ Return True if the server is running in DEBUG mode """ - + """Return True if the server is running in DEBUG mode.""" return djangosettings.DEBUG @register.simple_tag() def inventree_show_about(user, *args, **kwargs): - """ Return True if the about modal should be shown """ + """Return True if the about modal should be shown.""" if InvenTreeSetting.get_setting('INVENTREE_RESTRICT_ABOUT') and not user.is_superuser: return False return True @@ -165,22 +140,19 @@ def inventree_show_about(user, *args, **kwargs): @register.simple_tag() def inventree_docker_mode(*args, **kwargs): - """ Return True if the server is running as a Docker image """ - + """Return True if the server is running as a Docker image.""" return djangosettings.DOCKER @register.simple_tag() def plugins_enabled(*args, **kwargs): - """ Return True if plugins are enabled for the server instance """ - + """Return True if plugins are enabled for the server instance.""" return djangosettings.PLUGINS_ENABLED @register.simple_tag() def inventree_db_engine(*args, **kwargs): - """ Return the InvenTree database backend e.g. 'postgresql' """ - + """Return the InvenTree database backend e.g. 'postgresql'.""" db = djangosettings.DATABASES['default'] engine = db.get('ENGINE', _('Unknown database')) @@ -192,33 +164,31 @@ def inventree_db_engine(*args, **kwargs): @register.simple_tag() def inventree_instance_name(*args, **kwargs): - """ Return the InstanceName associated with the current database """ + """Return the InstanceName associated with the current database.""" return version.inventreeInstanceName() @register.simple_tag() def inventree_title(*args, **kwargs): - """ Return the title for the current instance - respecting the settings """ + """Return the title for the current instance - respecting the settings""" return version.inventreeInstanceTitle() @register.simple_tag() def inventree_base_url(*args, **kwargs): - """ Return the INVENTREE_BASE_URL setting """ + """Return the INVENTREE_BASE_URL setting.""" return InvenTreeSetting.get_setting('INVENTREE_BASE_URL') @register.simple_tag() def python_version(*args, **kwargs): - """ - Return the current python version - """ + """Return the current python version.""" return sys.version.split(' ')[0] @register.simple_tag() def inventree_version(shortstring=False, *args, **kwargs): - """ Return InvenTree version string """ + """Return InvenTree version string.""" if shortstring: return _("{title} v{version}".format( title=version.inventreeInstanceTitle(), @@ -229,53 +199,55 @@ def inventree_version(shortstring=False, *args, **kwargs): @register.simple_tag() def inventree_is_development(*args, **kwargs): + """Returns True if this is a development version of InvenTree""" return version.isInvenTreeDevelopmentVersion() @register.simple_tag() def inventree_is_release(*args, **kwargs): + """Returns True if this is a release version of InvenTree""" return not version.isInvenTreeDevelopmentVersion() @register.simple_tag() def inventree_docs_version(*args, **kwargs): + """Returns the InvenTree documentation version""" return version.inventreeDocsVersion() @register.simple_tag() def inventree_api_version(*args, **kwargs): - """ Return InvenTree API version """ + """Return InvenTree API version.""" return version.inventreeApiVersion() @register.simple_tag() def django_version(*args, **kwargs): - """ Return Django version string """ + """Return Django version string.""" return version.inventreeDjangoVersion() @register.simple_tag() def inventree_commit_hash(*args, **kwargs): - """ Return InvenTree git commit hash string """ + """Return InvenTree git commit hash string.""" return version.inventreeCommitHash() @register.simple_tag() def inventree_commit_date(*args, **kwargs): - """ Return InvenTree git commit date string """ + """Return InvenTree git commit date string.""" return version.inventreeCommitDate() @register.simple_tag() def inventree_github_url(*args, **kwargs): - """ Return URL for InvenTree github site """ + """Return URL for InvenTree github site.""" return "https://github.com/InvenTree/InvenTree/" @register.simple_tag() def inventree_docs_url(*args, **kwargs): - """ Return URL for InvenTree documenation site """ - + """Return URL for InvenTree documenation site.""" tag = version.inventreeDocsVersion() return f"https://inventree.readthedocs.io/en/{tag}" @@ -283,24 +255,23 @@ def inventree_docs_url(*args, **kwargs): @register.simple_tag() def inventree_credits_url(*args, **kwargs): - """ Return URL for InvenTree credits site """ + """Return URL for InvenTree credits site.""" return "https://inventree.readthedocs.io/en/latest/credits/" @register.simple_tag() def default_currency(*args, **kwargs): - """ Returns the default currency code """ + """Returns the default currency code.""" return currency_code_default() @register.simple_tag() def setting_object(key, *args, **kwargs): - """ - Return a setting object speciifed by the given key + """Return a setting object speciifed by the given key. + (Or return None if the setting does not exist) if a user-setting was requested return that """ - if 'plugin' in kwargs: # Note, 'plugin' is an instance of an InvenTreePlugin class @@ -319,10 +290,7 @@ def setting_object(key, *args, **kwargs): @register.simple_tag() def settings_value(key, *args, **kwargs): - """ - Return a settings value specified by the given key - """ - + """Return a settings value specified by the given key.""" if 'user' in kwargs: if not kwargs['user'] or (kwargs['user'] and kwargs['user'].is_authenticated is False): return InvenTreeUserSetting.get_setting(key) @@ -333,37 +301,25 @@ def settings_value(key, *args, **kwargs): @register.simple_tag() def user_settings(user, *args, **kwargs): - """ - Return all USER settings as a key:value dict - """ - + """Return all USER settings as a key:value dict.""" return InvenTreeUserSetting.allValues(user=user) @register.simple_tag() def global_settings(*args, **kwargs): - """ - Return all GLOBAL InvenTree settings as a key:value dict - """ - + """Return all GLOBAL InvenTree settings as a key:value dict.""" return InvenTreeSetting.allValues() @register.simple_tag() def visible_global_settings(*args, **kwargs): - """ - Return any global settings which are not marked as 'hidden' - """ - + """Return any global settings which are not marked as 'hidden'.""" return InvenTreeSetting.allValues(exclude_hidden=True) @register.simple_tag() def progress_bar(val, max_val, *args, **kwargs): - """ - Render a progress bar element - """ - + """Render a progress bar element.""" item_id = kwargs.get('id', 'progress-bar') val = InvenTree.helpers.normalize(val) @@ -402,6 +358,7 @@ def progress_bar(val, max_val, *args, **kwargs): @register.simple_tag() def get_color_theme_css(username): + """Return the cutsom theme .css file for the selected user""" user_theme_name = get_user_color_theme(username) # Build path to CSS sheet inventree_css_sheet = os.path.join('css', 'color-themes', user_theme_name + '.css') @@ -414,7 +371,7 @@ def get_color_theme_css(username): @register.simple_tag() def get_user_color_theme(username): - """ Get current user color theme """ + """Get current user color theme.""" try: user_theme = ColorTheme.objects.filter(user=username).get() user_theme_name = user_theme.name @@ -428,10 +385,7 @@ def get_user_color_theme(username): @register.simple_tag() def get_available_themes(*args, **kwargs): - """ - Return the available theme choices - """ - + """Return the available theme choices.""" themes = [] for key, name in ColorTheme.get_color_themes_choices(): @@ -445,13 +399,11 @@ def get_available_themes(*args, **kwargs): @register.simple_tag() def primitive_to_javascript(primitive): - """ - Convert a python primitive to a javascript primitive. + """Convert a python primitive to a javascript primitive. e.g. True -> true 'hello' -> '"hello"' """ - if type(primitive) is bool: return str(primitive).lower() @@ -465,10 +417,9 @@ def primitive_to_javascript(primitive): @register.filter def keyvalue(dict, key): - """ - access to key of supplied dict + """Access to key of supplied dict. - usage: + Usage: {% mydict|keyvalue:mykey %} """ return dict.get(key) @@ -476,10 +427,9 @@ def keyvalue(dict, key): @register.simple_tag() def call_method(obj, method_name, *args): - """ - enables calling model methods / functions from templates with arguments + """Enables calling model methods / functions from templates with arguments. - usage: + Usage: {% call_method model_object 'fnc_name' argument1 %} """ method = getattr(obj, method_name) @@ -488,8 +438,7 @@ def call_method(obj, method_name, *args): @register.simple_tag() def authorized_owners(group): - """ Return authorized owners """ - + """Return authorized owners.""" owners = [] try: @@ -507,41 +456,39 @@ def authorized_owners(group): @register.simple_tag() def object_link(url_name, pk, ref): - """ Return highlighted link to object """ - + """Return highlighted link to object.""" ref_url = reverse(url_name, kwargs={'pk': pk}) return mark_safe('{}'.format(ref_url, ref)) @register.simple_tag() def mail_configured(): - """ Return if mail is configured """ + """Return if mail is configured.""" return bool(settings.EMAIL_HOST) @register.simple_tag() def inventree_customize(reference, *args, **kwargs): - """ Return customization values for the user interface """ - + """Return customization values for the user interface.""" return djangosettings.CUSTOMIZE.get(reference, '') @register.simple_tag() def inventree_logo(*args, **kwargs): - """ Return the path to the logo-file """ - + """Return the path to the logo-file.""" if settings.CUSTOM_LOGO: return default_storage.url(settings.CUSTOM_LOGO) return static('img/inventree.png') class I18nStaticNode(StaticNode): - """ - custom StaticNode - replaces a variable named *lng* in the path with the current language - """ - def render(self, context): # pragma: no cover + """Custom StaticNode. + Replaces a variable named *lng* in the path with the current language + """ + + def render(self, context): # pragma: no cover + """Render this node with the determined locale context.""" self.original = getattr(self, 'original', None) if not self.original: @@ -561,17 +508,16 @@ if settings.DEBUG: @register.simple_tag() def i18n_static(url_name): - """ simple tag to enable {% url %} functionality instead of {% static %} """ + """Simple tag to enable {% url %} functionality instead of {% static %}""" return reverse(url_name) else: # pragma: no cover @register.tag('i18n_static') def do_i18n_static(parser, token): - """ - Overrides normal static, adds language - lookup for prerenderd files #1485 + """Overrides normal static, adds language - lookup for prerenderd files #1485 - usage (like static): + Usage (like static): {% i18n_static path [as varname] %} """ bits = token.split_contents() diff --git a/InvenTree/part/templatetags/status_codes.py b/InvenTree/part/templatetags/status_codes.py index 12072d3d40..d3811c38f5 100644 --- a/InvenTree/part/templatetags/status_codes.py +++ b/InvenTree/part/templatetags/status_codes.py @@ -1,6 +1,4 @@ -""" -Provide templates for the various model status codes. -""" +"""Provide templates for the various model status codes.""" from django import template from django.utils.safestring import mark_safe @@ -13,28 +11,29 @@ register = template.Library() @register.simple_tag def purchase_order_status_label(key, *args, **kwargs): - """ Render a PurchaseOrder status label """ + """Render a PurchaseOrder status label.""" return mark_safe(PurchaseOrderStatus.render(key, large=kwargs.get('large', False))) @register.simple_tag def sales_order_status_label(key, *args, **kwargs): - """ Render a SalesOrder status label """ + """Render a SalesOrder status label.""" return mark_safe(SalesOrderStatus.render(key, large=kwargs.get('large', False))) @register.simple_tag def stock_status_label(key, *args, **kwargs): - """ Render a StockItem status label """ + """Render a StockItem status label.""" return mark_safe(StockStatus.render(key, large=kwargs.get('large', False))) @register.simple_tag def stock_status_text(key, *args, **kwargs): + """Render the text value of a StockItem status value""" return mark_safe(StockStatus.text(key)) @register.simple_tag def build_status_label(key, *args, **kwargs): - """ Render a Build status label """ + """Render a Build status label.""" return mark_safe(BuildStatus.render(key, large=kwargs.get('large', False))) diff --git a/InvenTree/part/test_api.py b/InvenTree/part/test_api.py index 02fbe5d29e..6af000f283 100644 --- a/InvenTree/part/test_api.py +++ b/InvenTree/part/test_api.py @@ -1,3 +1,5 @@ +"""Unit tests for the various part API endpoints""" + from django.urls import reverse import PIL @@ -11,12 +13,13 @@ from company.models import Company from InvenTree.api_tester import InvenTreeAPITestCase from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatus, StockStatus) -from part.models import BomItem, BomItemSubstitute, Part, PartCategory +from part.models import (BomItem, BomItemSubstitute, Part, PartCategory, + PartRelated) from stock.models import StockItem, StockLocation class PartCategoryAPITest(InvenTreeAPITestCase): - """Unit tests for the PartCategory API""" + """Unit tests for the PartCategory API.""" fixtures = [ 'category', @@ -40,8 +43,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase): ] def test_category_list(self): - - # List all part categories + """Test the PartCategoryList API endpoint""" url = reverse('api-part-category-list') response = self.get(url, expected_code=200) @@ -73,8 +75,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 5) def test_category_metadata(self): - """Test metadata endpoint for the PartCategory""" - + """Test metadata endpoint for the PartCategory.""" cat = PartCategory.objects.get(pk=1) cat.metadata = { @@ -95,8 +96,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase): class PartOptionsAPITest(InvenTreeAPITestCase): - """ - Tests for the various OPTIONS endpoints in the /part/ API + """Tests for the various OPTIONS endpoints in the /part/ API. Ensure that the required field details are provided! """ @@ -105,15 +105,8 @@ class PartOptionsAPITest(InvenTreeAPITestCase): 'part.add', ] - def setUp(self): - - super().setUp() - def test_part(self): - """ - Test the Part API OPTIONS - """ - + """Test the Part API OPTIONS.""" actions = self.getActions(reverse('api-part-list'))['POST'] # Check that a bunch o' fields are contained @@ -147,10 +140,7 @@ class PartOptionsAPITest(InvenTreeAPITestCase): self.assertEqual(category['help_text'], 'Part category') def test_category(self): - """ - Test the PartCategory API OPTIONS endpoint - """ - + """Test the PartCategory API OPTIONS endpoint.""" actions = self.getActions(reverse('api-part-category-list')) # actions should *not* contain 'POST' as we do not have the correct role @@ -169,10 +159,7 @@ class PartOptionsAPITest(InvenTreeAPITestCase): self.assertEqual(loc['api_url'], reverse('api-location-list')) def test_bom_item(self): - """ - Test the BomItem API OPTIONS endpoint - """ - + """Test the BomItem API OPTIONS endpoint.""" actions = self.getActions(reverse('api-bom-list'))['POST'] inherited = actions['inherited'] @@ -195,8 +182,8 @@ class PartOptionsAPITest(InvenTreeAPITestCase): class PartAPITest(InvenTreeAPITestCase): - """ - Series of tests for the Part DRF API + """Series of tests for the Part DRF API. + - Tests for Part API - Tests for PartCategory API """ @@ -218,67 +205,57 @@ class PartAPITest(InvenTreeAPITestCase): 'part_category.add', ] - def setUp(self): - super().setUp() - def test_get_categories(self): - """ - Test that we can retrieve list of part categories, - with various filtering options. - """ - + """Test that we can retrieve list of part categories, with various filtering options.""" url = reverse('api-part-category-list') # Request *all* part categories - response = self.client.get(url, format='json') + response = self.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 8) # Request top-level part categories only - response = self.client.get( + response = self.get( url, { 'parent': 'null', - }, - format='json' + } ) self.assertEqual(len(response.data), 2) # Children of PartCategory<1>, cascade - response = self.client.get( + response = self.get( url, { 'parent': 1, 'cascade': 'true', - }, - format='json', + } ) self.assertEqual(len(response.data), 5) # Children of PartCategory<1>, do not cascade - response = self.client.get( + response = self.get( url, { 'parent': 1, 'cascade': 'false', - }, - format='json', + } ) self.assertEqual(len(response.data), 3) def test_add_categories(self): - """ Check that we can add categories """ + """Check that we can add categories.""" data = { 'name': 'Animals', 'description': 'All animals go here' } url = reverse('api-part-category-list') - response = self.client.post(url, data, format='json') + response = self.post(url, data) self.assertEqual(response.status_code, status.HTTP_201_CREATED) parent = response.data['pk'] @@ -290,19 +267,20 @@ class PartAPITest(InvenTreeAPITestCase): 'description': 'A sort of animal', 'parent': parent, } - response = self.client.post(url, data, format='json') + response = self.post(url, data) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data['parent'], parent) self.assertEqual(response.data['name'], animal) self.assertEqual(response.data['pathstring'], 'Animals/' + animal) # There should be now 8 categories - response = self.client.get(url, format='json') + response = self.get(url) self.assertEqual(len(response.data), 12) def test_cat_detail(self): + """Test the PartCategoryDetail API endpoint""" url = reverse('api-part-category-detail', kwargs={'pk': 4}) - response = self.client.get(url, format='json') + response = self.get(url) # Test that we have retrieved the category self.assertEqual(response.data['description'], 'Integrated Circuits') @@ -313,22 +291,22 @@ class PartAPITest(InvenTreeAPITestCase): data['name'] = 'Changing category' data['parent'] = None data['description'] = 'Changing the description' - response = self.client.patch(url, data, format='json') + response = self.patch(url, data) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['description'], 'Changing the description') self.assertIsNone(response.data['parent']) - def test_get_all_parts(self): + def test_filter_parts(self): + """Test part filtering using the API""" url = reverse('api-part-list') data = {'cascade': True} - response = self.client.get(url, data, format='json') + response = self.get(url, data) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), Part.objects.count()) - def test_get_parts_by_cat(self): - url = reverse('api-part-list') + # Test filtering parts by category data = {'category': 2} - response = self.client.get(url, data, format='json') + response = self.get(url, data) self.assertEqual(response.status_code, status.HTTP_200_OK) # There should only be 2 objects in category C @@ -337,39 +315,62 @@ class PartAPITest(InvenTreeAPITestCase): for part in response.data: self.assertEqual(part['category'], 2) + def test_filter_by_related(self): + """Test that we can filter by the 'related' status""" + url = reverse('api-part-list') + + # Initially there are no relations, so this should return zero results + response = self.get(url, {'related': 1}, expected_code=200) + self.assertEqual(len(response.data), 0) + + # Add some relationships + PartRelated.objects.create( + part_1=Part.objects.get(pk=1), + part_2=Part.objects.get(pk=2), + ) + + PartRelated.objects.create( + part_2=Part.objects.get(pk=1), + part_1=Part.objects.get(pk=3) + ) + + response = self.get(url, {'related': 1}, expected_code=200) + self.assertEqual(len(response.data), 2) + def test_include_children(self): - """ Test the special 'include_child_categories' flag + """Test the special 'include_child_categories' flag. + If provided, parts are provided for ANY child category (recursive) """ url = reverse('api-part-list') data = {'category': 1, 'cascade': True} # Now request to include child categories - response = self.client.get(url, data, format='json') + response = self.get(url, data) # Now there should be 5 total parts self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 3) def test_test_templates(self): - + """Test the PartTestTemplate API""" url = reverse('api-part-test-template-list') # List ALL items - response = self.client.get(url) + response = self.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data), 7) # Request for a particular part - response = self.client.get(url, data={'part': 10000}) + response = self.get(url, data={'part': 10000}) self.assertEqual(len(response.data), 5) - response = self.client.get(url, data={'part': 10004}) + response = self.get(url, data={'part': 10004}) self.assertEqual(len(response.data), 7) # Try to post a new object (missing description) - response = self.client.post( + response = self.post( url, data={ 'part': 10000, @@ -381,34 +382,32 @@ class PartAPITest(InvenTreeAPITestCase): self.assertEqual(response.status_code, 400) # Try to post a new object (should succeed) - response = self.client.post( + response = self.post( url, data={ 'part': 10000, 'test_name': 'New Test', 'required': True, 'description': 'a test description' - }, - format='json', + } ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) # Try to post a new test with the same name (should fail) - response = self.client.post( + response = self.post( url, data={ 'part': 10004, 'test_name': " newtest", 'description': 'dafsdf', - }, - format='json', + } ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) # Try to post a new test against a non-trackable part (should fail) - response = self.client.post( + response = self.post( url, data={ 'part': 1, @@ -419,21 +418,15 @@ class PartAPITest(InvenTreeAPITestCase): self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def test_get_thumbs(self): - """ - Return list of part thumbnails - """ - + """Return list of part thumbnails.""" url = reverse('api-part-thumbs') - response = self.client.get(url) + response = self.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_paginate(self): - """ - Test pagination of the Part list API - """ - + """Test pagination of the Part list API.""" for n in [1, 5, 10]: response = self.get(reverse('api-part-list'), {'limit': n}) @@ -445,16 +438,14 @@ class PartAPITest(InvenTreeAPITestCase): self.assertEqual(len(data['results']), n) def test_default_values(self): - """ - Tests for 'default' values: + """Tests for 'default' values: Ensure that unspecified fields revert to "default" values (as specified in the model field definition) """ - url = reverse('api-part-list') - response = self.client.post(url, { + response = self.post(url, { 'name': 'all defaults', 'description': 'my test part', 'category': 1, @@ -476,7 +467,7 @@ class PartAPITest(InvenTreeAPITestCase): self.user ) - response = self.client.post(url, { + response = self.post(url, { 'name': 'all defaults', 'description': 'my test part 2', 'category': 1, @@ -486,7 +477,7 @@ class PartAPITest(InvenTreeAPITestCase): self.assertTrue(response.data['purchaseable']) # "default" values should not be used if the value is specified - response = self.client.post(url, { + response = self.post(url, { 'name': 'all defaults', 'description': 'my test part 2', 'category': 1, @@ -498,10 +489,7 @@ class PartAPITest(InvenTreeAPITestCase): self.assertFalse(response.data['purchaseable']) def test_initial_stock(self): - """ - Tests for initial stock quantity creation - """ - + """Tests for initial stock quantity creation.""" url = reverse('api-part-list') # Track how many parts exist at the start of this test @@ -555,10 +543,7 @@ class PartAPITest(InvenTreeAPITestCase): self.assertEqual(new_part.total_stock, 12345) def test_initial_supplier_data(self): - """ - Tests for initial creation of supplier / manufacturer data - """ - + """Tests for initial creation of supplier / manufacturer data.""" url = reverse('api-part-list') n = Part.objects.count() @@ -620,10 +605,7 @@ class PartAPITest(InvenTreeAPITestCase): self.assertEqual(new_part.manufacturer_parts.count(), 1) def test_strange_chars(self): - """ - Test that non-standard ASCII chars are accepted - """ - + """Test that non-standard ASCII chars are accepted.""" url = reverse('api-part-list') name = "Kaltgerätestecker" @@ -641,15 +623,13 @@ class PartAPITest(InvenTreeAPITestCase): self.assertEqual(response.data['description'], description) def test_template_filters(self): - """ - Unit tests for API filters related to template parts: + """Unit tests for API filters related to template parts: - variant_of : Return children of specified part - ancestor : Return descendants of specified part Uses the 'chair template' part (pk=10000) """ - # Rebuild the MPTT structure before running these tests Part.objects.rebuild() @@ -732,11 +712,7 @@ class PartAPITest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 101) def test_variant_stock(self): - """ - Unit tests for the 'variant_stock' annotation, - which provides a stock count for *variant* parts - """ - + """Unit tests for the 'variant_stock' annotation, which provides a stock count for *variant* parts.""" # Ensure the MPTT structure is in a known state before running tests Part.objects.rebuild() @@ -820,8 +796,7 @@ class PartAPITest(InvenTreeAPITestCase): self.assertEqual(response.data['variant_stock'], 500) def test_part_download(self): - """Test download of part data via the API""" - + """Test download of part data via the API.""" url = reverse('api-part-list') required_cols = [ @@ -873,9 +848,7 @@ class PartAPITest(InvenTreeAPITestCase): class PartDetailTests(InvenTreeAPITestCase): - """ - Test that we can create / edit / delete Part objects via the API - """ + """Test that we can create / edit / delete Part objects via the API.""" fixtures = [ 'category', @@ -898,14 +871,12 @@ class PartDetailTests(InvenTreeAPITestCase): 'part_category.add', ] - def setUp(self): - super().setUp() - def test_part_operations(self): + """Test that Part instances can be adjusted via the API""" n = Part.objects.count() # Create a part - response = self.client.post( + response = self.post( reverse('api-part-list'), { 'name': 'my test api part', @@ -930,7 +901,7 @@ class PartDetailTests(InvenTreeAPITestCase): # Let's change the name of the part - response = self.client.patch(url, { + response = self.patch(url, { 'name': 'a new better name', }) @@ -948,14 +919,14 @@ class PartDetailTests(InvenTreeAPITestCase): # Now, try to set the name to the *same* value # 2021-06-22 this test is to check that the "duplicate part" checks don't do strange things - response = self.client.patch(url, { + response = self.patch(url, { 'name': 'a new better name', }) self.assertEqual(response.status_code, 200) # Try to remove the part - response = self.client.delete(url) + response = self.delete(url) # As the part is 'active' we cannot delete it self.assertEqual(response.status_code, 405) @@ -963,19 +934,16 @@ class PartDetailTests(InvenTreeAPITestCase): # So, let's make it not active response = self.patch(url, {'active': False}, expected_code=200) - response = self.client.delete(url) + response = self.delete(url) self.assertEqual(response.status_code, 204) # Part count should have reduced self.assertEqual(Part.objects.count(), n) def test_duplicates(self): - """ - Check that trying to create 'duplicate' parts results in errors - """ - + """Check that trying to create 'duplicate' parts results in errors.""" # Create a part - response = self.client.post(reverse('api-part-list'), { + response = self.post(reverse('api-part-list'), { 'name': 'part', 'description': 'description', 'IPN': 'IPN-123', @@ -988,7 +956,7 @@ class PartDetailTests(InvenTreeAPITestCase): n = Part.objects.count() # Check that we cannot create a duplicate in a different category - response = self.client.post(reverse('api-part-list'), { + response = self.post(reverse('api-part-list'), { 'name': 'part', 'description': 'description', 'IPN': 'IPN-123', @@ -1011,7 +979,7 @@ class PartDetailTests(InvenTreeAPITestCase): self.assertEqual(Part.objects.count(), n) # But a different 'revision' *can* be created - response = self.client.post(reverse('api-part-list'), { + response = self.post(reverse('api-part-list'), { 'name': 'part', 'description': 'description', 'IPN': 'IPN-123', @@ -1028,18 +996,17 @@ class PartDetailTests(InvenTreeAPITestCase): url = reverse('api-part-detail', kwargs={'pk': pk}) # Attempt to alter the revision code - response = self.client.patch( + response = self.patch( url, { 'revision': 'A', - }, - format='json', + } ) self.assertEqual(response.status_code, 400) # But we *can* change it to a unique revision code - response = self.client.patch( + response = self.patch( url, { 'revision': 'C', @@ -1049,14 +1016,11 @@ class PartDetailTests(InvenTreeAPITestCase): self.assertEqual(response.status_code, 200) def test_image_upload(self): - """ - Test that we can upload an image to the part API - """ - + """Test that we can upload an image to the part API.""" self.assignRole('part.add') # Create a new part - response = self.client.post( + response = self.post( reverse('api-part-list'), { 'name': 'imagine', @@ -1120,10 +1084,7 @@ class PartDetailTests(InvenTreeAPITestCase): self.assertIsNotNone(p.image) def test_details(self): - """ - Test that the required details are available - """ - + """Test that the required details are available.""" p = Part.objects.get(pk=1) url = reverse('api-part-detail', kwargs={'pk': 1}) @@ -1152,10 +1113,7 @@ class PartDetailTests(InvenTreeAPITestCase): self.assertEqual(data['unallocated_stock'], 9000) def test_part_metadata(self): - """ - Tests for the part metadata endpoint - """ - + """Tests for the part metadata endpoint.""" url = reverse('api-part-metadata', kwargs={'pk': 1}) part = Part.objects.get(pk=1) @@ -1206,9 +1164,7 @@ class PartDetailTests(InvenTreeAPITestCase): class PartAPIAggregationTest(InvenTreeAPITestCase): - """ - Tests to ensure that the various aggregation annotations are working correctly... - """ + """Tests to ensure that the various aggregation annotations are working correctly...""" fixtures = [ 'category', @@ -1229,7 +1185,7 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): ] def setUp(self): - + """Create test data as part of setup routine""" super().setUp() # Ensure the part "variant" tree is correctly structured @@ -1253,9 +1209,10 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): StockItem.objects.create(part=self.part, quantity=400, status=StockStatus.LOST) def get_part_data(self): + """Helper function for retrieving part data""" url = reverse('api-part-list') - response = self.client.get(url, format='json') + response = self.get(url) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -1267,10 +1224,7 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): self.assertTrue(False) # pragma: no cover def test_stock_quantity(self): - """ - Simple test for the stock quantity - """ - + """Simple test for the stock quantity.""" data = self.get_part_data() self.assertEqual(data['in_stock'], 600) @@ -1290,11 +1244,10 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): self.assertEqual(data['stock_item_count'], 105) def test_allocation_annotations(self): - """ - Tests for query annotations which add allocation information. + """Tests for query annotations which add allocation information. + Ref: https://github.com/inventree/InvenTree/pull/2797 """ - # We are looking at Part ID 100 ("Bob") url = reverse('api-part-detail', kwargs={'pk': 100}) @@ -1438,9 +1391,7 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): class BomItemTest(InvenTreeAPITestCase): - """ - Unit tests for the BomItem API - """ + """Unit tests for the BomItem API.""" fixtures = [ 'category', @@ -1457,14 +1408,8 @@ class BomItemTest(InvenTreeAPITestCase): 'part.delete', ] - def setUp(self): - super().setUp() - def test_bom_list(self): - """ - Tests for the BomItem list endpoint - """ - + """Tests for the BomItem list endpoint.""" # How many BOM items currently exist in the database? n = BomItem.objects.count() @@ -1529,10 +1474,7 @@ class BomItemTest(InvenTreeAPITestCase): self.assertTrue(key in el) def test_get_bom_detail(self): - """ - Get the detail view for a single BomItem object - """ - + """Get the detail view for a single BomItem object.""" url = reverse('api-bom-item-detail', kwargs={'pk': 3}) response = self.get(url, expected_code=200) @@ -1570,10 +1512,7 @@ class BomItemTest(InvenTreeAPITestCase): self.assertEqual(response.data['note'], 'Added a note') def test_add_bom_item(self): - """ - Test that we can create a new BomItem via the API - """ - + """Test that we can create a new BomItem via the API.""" url = reverse('api-bom-list') data = { @@ -1587,13 +1526,10 @@ class BomItemTest(InvenTreeAPITestCase): # Now try to create a BomItem which references itself data['part'] = 100 data['sub_part'] = 100 - self.client.post(url, data, expected_code=400) + self.post(url, data, expected_code=400) def test_variants(self): - """ - Tests for BomItem use with variants - """ - + """Tests for BomItem use with variants.""" stock_url = reverse('api-stock-list') # BOM item we are interested in @@ -1675,10 +1611,7 @@ class BomItemTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 2) def test_substitutes(self): - """ - Tests for BomItem substitutes - """ - + """Tests for BomItem substitutes.""" url = reverse('api-bom-substitute-list') stock_url = reverse('api-stock-list') @@ -1760,10 +1693,7 @@ class BomItemTest(InvenTreeAPITestCase): self.assertEqual(data['available_stock'], 9000) def test_bom_item_uses(self): - """ - Tests for the 'uses' field - """ - + """Tests for the 'uses' field.""" url = reverse('api-bom-list') # Test that the direct 'sub_part' association works @@ -1813,10 +1743,7 @@ class BomItemTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), i) def test_bom_variant_stock(self): - """ - Test for 'available_variant_stock' annotation - """ - + """Test for 'available_variant_stock' annotation.""" Part.objects.rebuild() # BOM item we are interested in @@ -1852,10 +1779,7 @@ class BomItemTest(InvenTreeAPITestCase): class PartParameterTest(InvenTreeAPITestCase): - """ - Tests for the ParParameter API - """ - + """Tests for the ParParameter API.""" superuser = True fixtures = [ @@ -1865,51 +1789,39 @@ class PartParameterTest(InvenTreeAPITestCase): 'params', ] - def setUp(self): - - super().setUp() - def test_list_params(self): - """ - Test for listing part parameters - """ - + """Test for listing part parameters.""" url = reverse('api-part-parameter-list') - response = self.client.get(url, format='json') + response = self.get(url) self.assertEqual(len(response.data), 5) # Filter by part - response = self.client.get( + response = self.get( url, { 'part': 3, - }, - format='json' + } ) self.assertEqual(len(response.data), 3) # Filter by template - response = self.client.get( + response = self.get( url, { 'template': 1, - }, - format='json', + } ) self.assertEqual(len(response.data), 3) def test_create_param(self): - """ - Test that we can create a param via the API - """ - + """Test that we can create a param via the API.""" url = reverse('api-part-parameter-list') - response = self.client.post( + response = self.post( url, { 'part': '2', @@ -1920,18 +1832,15 @@ class PartParameterTest(InvenTreeAPITestCase): self.assertEqual(response.status_code, 201) - response = self.client.get(url, format='json') + response = self.get(url) self.assertEqual(len(response.data), 6) def test_param_detail(self): - """ - Tests for the PartParameter detail endpoint - """ - + """Tests for the PartParameter detail endpoint.""" url = reverse('api-part-parameter-detail', kwargs={'pk': 5}) - response = self.client.get(url) + response = self.get(url) self.assertEqual(response.status_code, 200) @@ -1942,12 +1851,12 @@ class PartParameterTest(InvenTreeAPITestCase): self.assertEqual(data['data'], '12') # PATCH data back in - response = self.client.patch(url, {'data': '15'}, format='json') + response = self.patch(url, {'data': '15'}) self.assertEqual(response.status_code, 200) # Check that the data changed! - response = self.client.get(url, format='json') + response = self.get(url) data = response.data diff --git a/InvenTree/part/test_bom_export.py b/InvenTree/part/test_bom_export.py index d78ced4b17..d7e92d1da4 100644 --- a/InvenTree/part/test_bom_export.py +++ b/InvenTree/part/test_bom_export.py @@ -1,6 +1,4 @@ -""" -Unit testing for BOM export functionality -""" +"""Unit testing for BOM export functionality.""" import csv @@ -10,6 +8,7 @@ from InvenTree.helpers import InvenTreeTestCase class BomExportTest(InvenTreeTestCase): + """Class for performing unit testing of BOM export functionality""" fixtures = [ 'category', @@ -21,15 +20,13 @@ class BomExportTest(InvenTreeTestCase): roles = 'all' def setUp(self): + """Perform test setup functions""" super().setUp() self.url = reverse('bom-download', kwargs={'pk': 100}) def test_bom_template(self): - """ - Test that the BOM template can be downloaded from the server - """ - + """Test that the BOM template can be downloaded from the server.""" url = reverse('bom-upload-template') # Download an XLS template @@ -78,10 +75,7 @@ class BomExportTest(InvenTreeTestCase): self.assertTrue(header in headers) def test_export_csv(self): - """ - Test BOM download in CSV format - """ - + """Test BOM download in CSV format.""" params = { 'format': 'csv', 'cascade': True, @@ -142,10 +136,7 @@ class BomExportTest(InvenTreeTestCase): self.assertTrue(header in expected) def test_export_xls(self): - """ - Test BOM download in XLS format - """ - + """Test BOM download in XLS format.""" params = { 'format': 'xls', 'cascade': True, @@ -163,10 +154,7 @@ class BomExportTest(InvenTreeTestCase): self.assertEqual(content, 'attachment; filename="BOB | Bob | A2_BOM.xls"') def test_export_xlsx(self): - """ - Test BOM download in XLSX format - """ - + """Test BOM download in XLSX format.""" params = { 'format': 'xlsx', 'cascade': True, @@ -181,10 +169,7 @@ class BomExportTest(InvenTreeTestCase): self.assertEqual(response.status_code, 200) def test_export_json(self): - """ - Test BOM download in JSON format - """ - + """Test BOM download in JSON format.""" params = { 'format': 'json', 'cascade': True, diff --git a/InvenTree/part/test_bom_import.py b/InvenTree/part/test_bom_import.py index a9c853ddc4..71f47510a1 100644 --- a/InvenTree/part/test_bom_import.py +++ b/InvenTree/part/test_bom_import.py @@ -1,6 +1,4 @@ -""" -Unit testing for BOM upload / import functionality -""" +"""Unit testing for BOM upload / import functionality.""" from django.core.files.uploadedfile import SimpleUploadedFile from django.urls import reverse @@ -12,9 +10,7 @@ from part.models import Part class BomUploadTest(InvenTreeAPITestCase): - """ - Test BOM file upload API endpoint - """ + """Test BOM file upload API endpoint.""" roles = [ 'part.add', @@ -22,6 +18,7 @@ class BomUploadTest(InvenTreeAPITestCase): ] def setUp(self): + """Create BOM data as part of setup routine""" super().setUp() self.part = Part.objects.create( @@ -41,7 +38,7 @@ class BomUploadTest(InvenTreeAPITestCase): ) def post_bom(self, filename, file_data, clear_existing=None, expected_code=None, content_type='text/plain'): - + """Helper function for submitting a BOM file""" bom_file = SimpleUploadedFile( filename, file_data, @@ -63,10 +60,7 @@ class BomUploadTest(InvenTreeAPITestCase): return response def test_missing_file(self): - """ - POST without a file - """ - + """POST without a file.""" response = self.post( reverse('api-bom-import-upload'), data={}, @@ -76,10 +70,7 @@ class BomUploadTest(InvenTreeAPITestCase): self.assertIn('No file was submitted', str(response.data['data_file'])) def test_unsupported_file(self): - """ - POST with an unsupported file type - """ - + """POST with an unsupported file type.""" response = self.post_bom( 'sample.txt', b'hello world', @@ -89,10 +80,7 @@ class BomUploadTest(InvenTreeAPITestCase): self.assertIn('Unsupported file type', str(response.data['data_file'])) def test_broken_file(self): - """ - Test upload with broken (corrupted) files - """ - + """Test upload with broken (corrupted) files.""" response = self.post_bom( 'sample.csv', b'', @@ -111,10 +99,7 @@ class BomUploadTest(InvenTreeAPITestCase): self.assertIn('Unsupported format, or corrupt file', str(response.data['data_file'])) def test_missing_rows(self): - """ - Test upload of an invalid file (without data rows) - """ - + """Test upload of an invalid file (without data rows)""" dataset = tablib.Dataset() dataset.headers = [ @@ -142,10 +127,7 @@ class BomUploadTest(InvenTreeAPITestCase): self.assertIn('No data rows found in file', str(response.data)) def test_missing_columns(self): - """ - Upload extracted data, but with missing columns - """ - + """Upload extracted data, but with missing columns.""" url = reverse('api-bom-import-extract') rows = [ @@ -195,10 +177,7 @@ class BomUploadTest(InvenTreeAPITestCase): ) def test_invalid_data(self): - """ - Upload data which contains errors - """ - + """Upload data which contains errors.""" dataset = tablib.Dataset() # Only these headers are strictly necessary @@ -241,10 +220,7 @@ class BomUploadTest(InvenTreeAPITestCase): self.assertEqual(rows[5]['data']['errors']['part'], 'Part is not designated as a component') def test_part_guess(self): - """ - Test part 'guessing' when PK values are not supplied - """ - + """Test part 'guessing' when PK values are not supplied.""" dataset = tablib.Dataset() # Should be able to 'guess' the part from the name @@ -304,10 +280,7 @@ class BomUploadTest(InvenTreeAPITestCase): self.assertEqual(rows[idx]['data']['part'], components[idx].pk) def test_levels(self): - """ - Test that multi-level BOMs are correctly handled during upload - """ - + """Test that multi-level BOMs are correctly handled during upload.""" url = reverse('api-bom-import-extract') dataset = tablib.Dataset() diff --git a/InvenTree/part/test_bom_item.py b/InvenTree/part/test_bom_item.py index 0a48733c85..58c5dfb032 100644 --- a/InvenTree/part/test_bom_item.py +++ b/InvenTree/part/test_bom_item.py @@ -1,4 +1,6 @@ +"""Unit tests for the BomItem model""" + from decimal import Decimal import django.core.exceptions as django_exceptions @@ -9,6 +11,7 @@ from .models import BomItem, BomItemSubstitute, Part class BomItemTest(TestCase): + """Class for unit testing BomItem model""" fixtures = [ 'category', @@ -22,21 +25,25 @@ class BomItemTest(TestCase): ] def setUp(self): + """Create initial data""" self.bob = Part.objects.get(id=100) self.orphan = Part.objects.get(name='Orphan') self.r1 = Part.objects.get(name='R_2K2_0805') def test_str(self): + """Test the string representation of a BOMItem""" b = BomItem.objects.get(id=1) self.assertEqual(str(b), '10 x M2x4 LPHS to make BOB | Bob | A2') def test_has_bom(self): + """Test the has_bom attribute""" self.assertFalse(self.orphan.has_bom) self.assertTrue(self.bob.has_bom) self.assertEqual(self.bob.bom_count, 4) def test_in_bom(self): + """Test BOM aggregation""" parts = self.bob.getRequiredParts() self.assertIn(self.orphan, parts) @@ -44,22 +51,19 @@ class BomItemTest(TestCase): self.assertTrue(self.bob.check_if_part_in_bom(self.orphan)) def test_used_in(self): + """Test that the 'used_in_count' attribute is calculated correctly""" self.assertEqual(self.bob.used_in_count, 1) self.assertEqual(self.orphan.used_in_count, 1) def test_self_reference(self): - """ Test that we get an appropriate error when we create a BomItem which points to itself """ - + """Test that we get an appropriate error when we create a BomItem which points to itself.""" with self.assertRaises(django_exceptions.ValidationError): # A validation error should be raised here item = BomItem.objects.create(part=self.bob, sub_part=self.bob, quantity=7) item.clean() # pragma: no cover def test_integer_quantity(self): - """ - Test integer validation for BomItem - """ - + """Test integer validation for BomItem.""" p = Part.objects.create(name="test", description="d", component=True, trackable=True) # Creation of a BOMItem with a non-integer quantity of a trackable Part should fail @@ -70,8 +74,7 @@ class BomItemTest(TestCase): BomItem.objects.create(part=self.bob, sub_part=p, quantity=21) def test_overage(self): - """ Test that BOM line overages are calculated correctly """ - + """Test that BOM line overages are calculated correctly.""" item = BomItem.objects.get(part=100, sub_part=50) q = 300 @@ -106,8 +109,7 @@ class BomItemTest(TestCase): self.assertEqual(n, 3150) def test_item_hash(self): - """ Test BOM item hash encoding """ - + """Test BOM item hash encoding.""" item = BomItem.objects.get(part=100, sub_part=50) h1 = item.get_item_hash() @@ -122,6 +124,7 @@ class BomItemTest(TestCase): self.assertNotEqual(h1, h2) def test_pricing(self): + """Test BOM pricing""" self.bob.get_price(1) self.assertEqual( self.bob.get_bom_price_range(1, internal=True), @@ -135,10 +138,7 @@ class BomItemTest(TestCase): ) def test_substitutes(self): - """ - Tests for BOM item substitutes - """ - + """Tests for BOM item substitutes.""" # We will make some subtitute parts for the "orphan" part bom_item = BomItem.objects.get( part=self.bob, diff --git a/InvenTree/part/test_category.py b/InvenTree/part/test_category.py index 4da893f188..267e9a9733 100644 --- a/InvenTree/part/test_category.py +++ b/InvenTree/part/test_category.py @@ -1,3 +1,5 @@ +"""Unit tests for the PartCategory model""" + from django.core.exceptions import ValidationError from django.test import TestCase @@ -5,8 +7,7 @@ from .models import Part, PartCategory, PartParameter, PartParameterTemplate class CategoryTest(TestCase): - """ - Tests to ensure that the relational category tree functions correctly. + """Tests to ensure that the relational category tree functions correctly. Loads the following test fixtures: - category.yaml @@ -19,7 +20,7 @@ class CategoryTest(TestCase): ] def setUp(self): - # Extract some interesting categories for time-saving + """Extract some interesting categories for time-saving""" self.electronics = PartCategory.objects.get(name='Electronics') self.mechanical = PartCategory.objects.get(name='Mechanical') self.resistors = PartCategory.objects.get(name='Resistors') @@ -29,9 +30,7 @@ class CategoryTest(TestCase): self.transceivers = PartCategory.objects.get(name='Transceivers') def test_parents(self): - """ Test that the parent fields are properly set, - based on the test fixtures """ - + """Test that the parent fields are properly set, based on the test fixtures.""" self.assertEqual(self.resistors.parent, self.electronics) self.assertEqual(self.capacitors.parent, self.electronics) self.assertEqual(self.electronics.parent, None) @@ -39,8 +38,7 @@ class CategoryTest(TestCase): self.assertEqual(self.fasteners.parent, self.mechanical) def test_children_count(self): - """ Test that categories have the correct number of children """ - + """Test that categories have the correct number of children.""" self.assertTrue(self.electronics.has_children) self.assertTrue(self.mechanical.has_children) @@ -48,8 +46,7 @@ class CategoryTest(TestCase): self.assertEqual(len(self.mechanical.children.all()), 1) def test_unique_childs(self): - """ Test the 'unique_children' functionality """ - + """Test the 'unique_children' functionality.""" childs = [item.pk for item in self.electronics.getUniqueChildren()] self.assertIn(self.transceivers.id, childs) @@ -58,8 +55,7 @@ class CategoryTest(TestCase): self.assertNotIn(self.fasteners.id, childs) def test_unique_parents(self): - """ Test the 'unique_parents' functionality """ - + """Test the 'unique_parents' functionality.""" parents = [item.pk for item in self.transceivers.getUniqueParents()] self.assertIn(self.electronics.id, parents) @@ -67,22 +63,16 @@ class CategoryTest(TestCase): self.assertNotIn(self.fasteners.id, parents) def test_path_string(self): - """ Test that the category path string works correctly """ - + """Test that the category path string works correctly.""" self.assertEqual(str(self.resistors), 'Electronics/Resistors - Resistors') self.assertEqual(str(self.transceivers.pathstring), 'Electronics/IC/Transceivers') def test_url(self): - """ Test that the PartCategory URL works """ - + """Test that the PartCategory URL works.""" self.assertEqual(self.capacitors.get_absolute_url(), '/part/category/3/') def test_part_count(self): - """ Test that the Category part count works """ - - self.assertTrue(self.resistors.has_parts) - self.assertTrue(self.fasteners.has_parts) - self.assertFalse(self.transceivers.has_parts) + """Test that the Category part count works.""" self.assertEqual(self.fasteners.partcount(), 2) self.assertEqual(self.capacitors.partcount(), 1) @@ -96,8 +86,7 @@ class CategoryTest(TestCase): self.assertEqual(self.electronics.item_count, self.electronics.partcount()) def test_parameters(self): - """ Test that the Category parameters are correctly fetched """ - + """Test that the Category parameters are correctly fetched.""" # Check number of SQL queries to iterate other parameters with self.assertNumQueries(7): # Prefetch: 3 queries (parts, parameters and parameters_template) @@ -121,8 +110,7 @@ class CategoryTest(TestCase): self.assertEqual(len(part_parameter), 1) def test_invalid_name(self): - # Test that an illegal character is prohibited in a category name - + """Test that an illegal character is prohibited in a category name""" cat = PartCategory(name='test/with/illegal/chars', description='Test category', parent=None) with self.assertRaises(ValidationError) as err: @@ -135,8 +123,7 @@ class CategoryTest(TestCase): cat.save() def test_delete(self): - """ Test that category deletion moves the children properly """ - + """Test that category deletion moves the children properly.""" # Delete the 'IC' category and 'Transceiver' should move to be under 'Electronics' self.assertEqual(self.transceivers.parent, self.ic) self.assertEqual(self.ic.parent, self.electronics) @@ -156,8 +143,7 @@ class CategoryTest(TestCase): self.assertEqual(f.category, self.mechanical) def test_default_locations(self): - """ Test traversal for default locations """ - + """Test traversal for default locations.""" self.assertEqual(str(self.fasteners.default_location), 'Office/Drawer_1 - In my desk') # Any part under electronics should default to 'Home' @@ -174,12 +160,11 @@ class CategoryTest(TestCase): self.assertIsNone(w.get_default_location()) def test_category_tree(self): - """ - Unit tests for the part category tree structure (MPTT) + """Unit tests for the part category tree structure (MPTT) + Ensure that the MPTT structure is rebuilt correctly, and the correct ancestor tree is observed. """ - # Clear out any existing parts Part.objects.all().delete() diff --git a/InvenTree/part/test_migrations.py b/InvenTree/part/test_migrations.py index 9da287519e..f3562af872 100644 --- a/InvenTree/part/test_migrations.py +++ b/InvenTree/part/test_migrations.py @@ -1,6 +1,4 @@ -""" -Unit tests for the part model database migrations -""" +"""Unit tests for the part model database migrations.""" from django_test_migrations.contrib.unittest_case import MigratorTestCase @@ -8,18 +6,13 @@ from InvenTree import helpers class TestForwardMigrations(MigratorTestCase): - """ - Test entire schema migration sequence for the part app - """ + """Test entire schema migration sequence for the part app.""" migrate_from = ('part', helpers.getOldestMigrationFile('part')) migrate_to = ('part', helpers.getNewestMigrationFile('part')) def prepare(self): - """ - Create initial data - """ - + """Create initial data.""" Part = self.old_state.apps.get_model('part', 'part') Part.objects.create(name='A', description='My part A') @@ -39,7 +32,7 @@ class TestForwardMigrations(MigratorTestCase): print(p.is_template) def test_models_exist(self): - + """Test that the Part model can still be accessed at the end of schema migration""" Part = self.new_state.apps.get_model('part', 'part') self.assertEqual(Part.objects.count(), 5) @@ -49,3 +42,7 @@ class TestForwardMigrations(MigratorTestCase): part.save() part.is_template = False part.save() + + for name in ['A', 'C', 'E']: + part = Part.objects.get(name=name) + self.assertEqual(part.description, f"My part {name}") diff --git a/InvenTree/part/test_param.py b/InvenTree/part/test_param.py index 1db23f0be2..7a13bfa6ce 100644 --- a/InvenTree/part/test_param.py +++ b/InvenTree/part/test_param.py @@ -1,4 +1,4 @@ -# Tests for Part Parameters +"""Various unit tests for Part Parameters""" import django.core.exceptions as django_exceptions from django.test import TestCase, TransactionTestCase @@ -8,6 +8,7 @@ from .models import (Part, PartCategory, PartCategoryParameterTemplate, class TestParams(TestCase): + """Unit test class for testing the PartParameter model""" fixtures = [ 'location', @@ -17,7 +18,7 @@ class TestParams(TestCase): ] def test_str(self): - + """Test the str representation of the PartParameterTemplate model""" t1 = PartParameterTemplate.objects.get(pk=1) self.assertEqual(str(t1), 'Length (mm)') @@ -28,7 +29,7 @@ class TestParams(TestCase): self.assertEqual(str(c1), 'Mechanical | Length | 2.8') def test_validate(self): - + """Test validation for part templates""" n = PartParameterTemplate.objects.all().count() t1 = PartParameterTemplate(name='abcde', units='dd') @@ -44,6 +45,7 @@ class TestParams(TestCase): class TestCategoryTemplates(TransactionTestCase): + """Test class for PartCategoryParameterTemplate model""" fixtures = [ 'location', @@ -53,7 +55,7 @@ class TestCategoryTemplates(TransactionTestCase): ] def test_validate(self): - + """Test that category templates are correctly applied to Part instances""" # Category templates n = PartCategoryParameterTemplate.objects.all().count() self.assertEqual(n, 2) @@ -79,6 +81,7 @@ class TestCategoryTemplates(TransactionTestCase): 'main': True, 'parent': True, } + # Save it with category parameters part.save(**{'add_category_templates': add_category_templates}) diff --git a/InvenTree/part/test_part.py b/InvenTree/part/test_part.py index d3b8e85a02..46f1473983 100644 --- a/InvenTree/part/test_part.py +++ b/InvenTree/part/test_part.py @@ -1,4 +1,4 @@ -# Tests for the Part model +"""Tests for the Part model.""" import os @@ -21,42 +21,45 @@ from .templatetags import inventree_extras class TemplateTagTest(InvenTreeTestCase): - """ Tests for the custom template tag code """ + """Tests for the custom template tag code.""" def test_define(self): + """Test the 'define' template tag""" self.assertEqual(int(inventree_extras.define(3)), 3) def test_str2bool(self): + """Various test for the str2bool template tag""" self.assertEqual(int(inventree_extras.str2bool('true')), True) self.assertEqual(int(inventree_extras.str2bool('yes')), True) self.assertEqual(int(inventree_extras.str2bool('none')), False) self.assertEqual(int(inventree_extras.str2bool('off')), False) - def test_inrange(self): - self.assertEqual(inventree_extras.inrange(3), range(3)) - - def test_multiply(self): - self.assertEqual(int(inventree_extras.multiply(3, 5)), 15) - def test_add(self): + """Test that the 'add""" self.assertEqual(int(inventree_extras.add(3, 5)), 8) def test_plugins_enabled(self): + """Test the plugins_enabled tag""" self.assertEqual(inventree_extras.plugins_enabled(), True) def test_inventree_instance_name(self): + """Test the 'instance name' setting""" self.assertEqual(inventree_extras.inventree_instance_name(), 'InvenTree server') def test_inventree_base_url(self): + """Test that the base URL tag returns correctly""" self.assertEqual(inventree_extras.inventree_base_url(), '') def test_inventree_is_release(self): + """Test that the release version check functions as expected""" self.assertEqual(inventree_extras.inventree_is_release(), not version.isInvenTreeDevelopmentVersion()) def test_inventree_docs_version(self): + """Test that the documentation version template tag returns correctly""" self.assertEqual(inventree_extras.inventree_docs_version(), version.inventreeDocsVersion()) def test_hash(self): + """Test that the commit hash template tag returns correctly""" result_hash = inventree_extras.inventree_commit_hash() if settings.DOCKER: # pragma: no cover # Testing inside docker environment *may* return an empty git commit hash @@ -66,6 +69,7 @@ class TemplateTagTest(InvenTreeTestCase): self.assertGreater(len(result_hash), 5) def test_date(self): + """Test that the commit date template tag returns correctly""" d = inventree_extras.inventree_commit_date() if settings.DOCKER: # pragma: no cover # Testing inside docker environment *may* return an empty git commit hash @@ -75,26 +79,33 @@ class TemplateTagTest(InvenTreeTestCase): self.assertEqual(len(d.split('-')), 3) def test_github(self): + """Test that the github URL template tag returns correctly""" self.assertIn('github.com', inventree_extras.inventree_github_url()) def test_docs(self): + """Test that the documentation URL template tag returns correctly""" self.assertIn('inventree.readthedocs.io', inventree_extras.inventree_docs_url()) def test_keyvalue(self): + """Test keyvalue template tag""" self.assertEqual(inventree_extras.keyvalue({'a': 'a'}, 'a'), 'a') def test_mail_configured(self): + """Test that mail configuration returns False""" self.assertEqual(inventree_extras.mail_configured(), False) def test_user_settings(self): + """Test user settings""" result = inventree_extras.user_settings(self.user) self.assertEqual(len(result), len(InvenTreeUserSetting.SETTINGS)) def test_global_settings(self): + """Test global settings""" result = inventree_extras.global_settings() self.assertEqual(len(result), len(InvenTreeSetting.SETTINGS)) def test_visible_global_settings(self): + """Test that hidden global settings are actually hidden""" result = inventree_extras.visible_global_settings() n = len(result) @@ -112,7 +123,7 @@ class TemplateTagTest(InvenTreeTestCase): class PartTest(TestCase): - """ Tests for the Part model """ + """Tests for the Part model.""" fixtures = [ 'category', @@ -122,6 +133,9 @@ class PartTest(TestCase): ] def setUp(self): + """Create some Part instances as part of init routine""" + super().setUp() + self.r1 = Part.objects.get(name='R_2K2_0805') self.r2 = Part.objects.get(name='R_4K7_0603') @@ -130,7 +144,7 @@ class PartTest(TestCase): Part.objects.rebuild() def test_tree(self): - # Test that the part variant tree is working properly + """Test that the part variant tree is working properly""" chair = Part.objects.get(pk=10000) self.assertEqual(chair.get_children().count(), 3) self.assertEqual(chair.get_descendant_count(), 4) @@ -142,14 +156,12 @@ class PartTest(TestCase): self.assertEqual(Part.objects.filter(tree_id=chair.tree_id).count(), 5) def test_str(self): + """Test string representation of a Part""" p = Part.objects.get(pk=100) self.assertEqual(str(p), "BOB | Bob | A2 - Can we build it?") def test_duplicate(self): - """ - Test that we cannot create a "duplicate" Part - """ - + """Test that we cannot create a "duplicate" Part.""" n = Part.objects.count() cat = PartCategory.objects.get(pk=1) @@ -201,10 +213,12 @@ class PartTest(TestCase): part_2.validate_unique() def test_attributes(self): + """Test Part attributes""" self.assertEqual(self.r1.name, 'R_2K2_0805') self.assertEqual(self.r1.get_absolute_url(), '/part/3/') def test_category(self): + """Test PartCategory path""" self.assertEqual(str(self.c1.category), 'Electronics/Capacitors - Capacitors') orphan = Part.objects.get(name='Orphan') @@ -212,26 +226,29 @@ class PartTest(TestCase): self.assertEqual(orphan.category_path, '') def test_rename_img(self): + """Test that an image can be renamed""" img = rename_part_image(self.r1, 'hello.png') self.assertEqual(img, os.path.join('part_images', 'hello.png')) def test_stock(self): - # No stock of any resistors + """Test case where there is zero stock""" res = Part.objects.filter(description__contains='resistor') for r in res: self.assertEqual(r.total_stock, 0) self.assertEqual(r.available_stock, 0) def test_barcode(self): + """Test barcode format functionality""" barcode = self.r1.format_barcode(brief=False) self.assertIn('InvenTree', barcode) self.assertIn(self.r1.name, barcode) def test_copy(self): + """Test that we can 'deep copy' a Part instance""" self.r2.deep_copy(self.r1, image=True, bom=True) def test_sell_pricing(self): - # check that the sell pricebreaks were loaded + """Check that the sell pricebreaks were loaded""" self.assertTrue(self.r1.has_price_breaks) self.assertEqual(self.r1.price_breaks.count(), 2) # check that the sell pricebreaks work @@ -239,7 +256,7 @@ class PartTest(TestCase): self.assertEqual(float(self.r1.get_price(10)), 1.0) def test_internal_pricing(self): - # check that the sell pricebreaks were loaded + """Check that the sell pricebreaks were loaded""" self.assertTrue(self.r1.has_internal_price_breaks) self.assertEqual(self.r1.internal_price_breaks.count(), 2) # check that the sell pricebreaks work @@ -247,8 +264,7 @@ class PartTest(TestCase): self.assertEqual(float(self.r1.get_internal_price(10)), 0.5) def test_metadata(self): - """Unit tests for the Part metadata field""" - + """Unit tests for the Part metadata field.""" p = Part.objects.get(pk=1) self.assertIsNone(p.metadata) @@ -266,6 +282,7 @@ class PartTest(TestCase): class TestTemplateTest(TestCase): + """Unit test for the TestTemplate class""" fixtures = [ 'category', @@ -275,7 +292,7 @@ class TestTemplateTest(TestCase): ] def test_template_count(self): - + """Tests for the test template functions""" chair = Part.objects.get(pk=10000) # Tests for the top-level chair object (nothing above it!) @@ -292,8 +309,7 @@ class TestTemplateTest(TestCase): self.assertEqual(variant.getTestTemplates(required=True).count(), 5) def test_uniqueness(self): - # Test names must be unique for this part and also parts above - + """Test names must be unique for this part and also parts above""" variant = Part.objects.get(pk=10004) with self.assertRaises(ValidationError): @@ -324,17 +340,13 @@ class TestTemplateTest(TestCase): class PartSettingsTest(InvenTreeTestCase): - """ - Tests to ensure that the user-configurable default values work as expected. + """Tests to ensure that the user-configurable default values work as expected. Some fields for the Part model can have default values specified by the user. """ def make_part(self): - """ - Helper function to create a simple part - """ - + """Helper function to create a simple part.""" part = Part.objects.create( name='Test Part', description='I am but a humble test part', @@ -344,20 +356,14 @@ class PartSettingsTest(InvenTreeTestCase): return part def test_defaults(self): - """ - Test that the default values for the part settings are correct - """ - + """Test that the default values for the part settings are correct.""" self.assertTrue(part.settings.part_component_default()) self.assertTrue(part.settings.part_purchaseable_default()) self.assertFalse(part.settings.part_salable_default()) self.assertFalse(part.settings.part_trackable_default()) def test_initial(self): - """ - Test the 'initial' default values (no default values have been set) - """ - + """Test the 'initial' default values (no default values have been set)""" part = self.make_part() self.assertTrue(part.component) @@ -366,10 +372,7 @@ class PartSettingsTest(InvenTreeTestCase): self.assertFalse(part.trackable) def test_custom(self): - """ - Update some of the part values and re-test - """ - + """Update some of the part values and re-test.""" for val in [True, False]: InvenTreeSetting.set_setting('PART_COMPONENT', val, self.user) InvenTreeSetting.set_setting('PART_PURCHASEABLE', val, self.user) @@ -395,10 +398,7 @@ class PartSettingsTest(InvenTreeTestCase): Part.objects.filter(pk=part.pk).delete() def test_duplicate_ipn(self): - """ - Test the setting which controls duplicate IPN values - """ - + """Test the setting which controls duplicate IPN values.""" # Create a part Part.objects.create(name='Hello', description='A thing', IPN='IPN123', revision='A') @@ -444,6 +444,7 @@ class PartSettingsTest(InvenTreeTestCase): class PartSubscriptionTests(InvenTreeTestCase): + """Unit tests for part 'subscription'""" fixtures = [ 'location', @@ -452,6 +453,7 @@ class PartSubscriptionTests(InvenTreeTestCase): ] def setUp(self): + """Create category and part data as part of setup routine""" super().setUp() # electronics / IC / MCU @@ -465,10 +467,7 @@ class PartSubscriptionTests(InvenTreeTestCase): ) def test_part_subcription(self): - """ - Test basic subscription against a part - """ - + """Test basic subscription against a part.""" # First check that the user is *not* subscribed to the part self.assertFalse(self.part.is_starred_by(self.user)) @@ -485,10 +484,7 @@ class PartSubscriptionTests(InvenTreeTestCase): self.assertFalse(self.part.is_starred_by(self.user)) def test_variant_subscription(self): - """ - Test subscription against a parent part - """ - + """Test subscription against a parent part.""" # Construct a sub-part to star against sub_part = Part.objects.create( name='sub_part', @@ -505,10 +501,7 @@ class PartSubscriptionTests(InvenTreeTestCase): self.assertTrue(sub_part.is_starred_by(self.user)) def test_category_subscription(self): - """ - Test subscription against a PartCategory - """ - + """Test subscription against a PartCategory.""" self.assertEqual(PartCategoryStar.objects.count(), 0) self.assertFalse(self.part.is_starred_by(self.user)) @@ -533,10 +526,7 @@ class PartSubscriptionTests(InvenTreeTestCase): self.assertFalse(self.part.is_starred_by(self.user)) def test_parent_category_subscription(self): - """ - Check that a parent category can be subscribed to - """ - + """Check that a parent category can be subscribed to.""" # Top-level "electronics" category cat = PartCategory.objects.get(pk=1) @@ -553,7 +543,7 @@ class PartSubscriptionTests(InvenTreeTestCase): class BaseNotificationIntegrationTest(InvenTreeTestCase): - """ Integration test for notifications """ + """Integration test for notifications.""" fixtures = [ 'location', @@ -563,6 +553,7 @@ class BaseNotificationIntegrationTest(InvenTreeTestCase): ] def setUp(self): + """Add an email address as part of initialization""" super().setUp() # Add Mailadress EmailAddress.objects.create(user=self.user, email='test@testing.com') @@ -571,8 +562,8 @@ class BaseNotificationIntegrationTest(InvenTreeTestCase): self.part = Part.objects.get(name='R_2K2_0805') def _notification_run(self, run_class=None): - """ - Run a notification test suit through. + """Run a notification test suit through. + If you only want to test one class pass it to run_class """ # reload notification methods @@ -597,9 +588,10 @@ class BaseNotificationIntegrationTest(InvenTreeTestCase): class PartNotificationTest(BaseNotificationIntegrationTest): - """ Integration test for part notifications """ + """Integration test for part notifications.""" def test_notification(self): + """Test that a notification is generated""" self._notification_run(UIMessageNotification) # There should be 1 notification message right now diff --git a/InvenTree/part/test_views.py b/InvenTree/part/test_views.py index 21ec6f7909..761356a5de 100644 --- a/InvenTree/part/test_views.py +++ b/InvenTree/part/test_views.py @@ -1,4 +1,4 @@ -""" Unit tests for Part Views (see views.py) """ +"""Unit tests for Part Views (see views.py)""" from django.urls import reverse @@ -8,6 +8,7 @@ from .models import Part class PartViewTestCase(InvenTreeTestCase): + """Base class for unit testing the various Part views""" fixtures = [ 'category', @@ -21,13 +22,12 @@ class PartViewTestCase(InvenTreeTestCase): roles = 'all' superuser = True - def setUp(self): - super().setUp() - class PartListTest(PartViewTestCase): + """Unit tests for the PartList view""" def test_part_index(self): + """Test that the PartIndex page returns successfully""" response = self.client.get(reverse('part-index')) self.assertEqual(response.status_code, 200) @@ -38,10 +38,10 @@ class PartListTest(PartViewTestCase): class PartDetailTest(PartViewTestCase): + """Unit tests for the PartDetail view""" def test_part_detail(self): - """ Test that we can retrieve a part detail page """ - + """Test that we can retrieve a part detail page.""" pk = 1 response = self.client.get(reverse('part-detail', args=(pk,))) @@ -58,8 +58,8 @@ class PartDetailTest(PartViewTestCase): self.assertEqual(response.context['category'], part.category) def test_part_detail_from_ipn(self): - """ - Test that we can retrieve a part detail page from part IPN: + """Test that we can retrieve a part detail page from part IPN: + - if no part with matching IPN -> return part index - if unique IPN match -> return part detail page - if multiple IPN matches -> return part index @@ -68,6 +68,7 @@ class PartDetailTest(PartViewTestCase): pk = 1 def test_ipn_match(index_result=False, detail_result=False): + """Helper function for matching IPN detail view""" index_redirect = False detail_redirect = False @@ -108,22 +109,22 @@ class PartDetailTest(PartViewTestCase): test_ipn_match(index_result=True, detail_result=False) def test_bom_download(self): - """ Test downloading a BOM for a valid part """ - + """Test downloading a BOM for a valid part.""" response = self.client.get(reverse('bom-download', args=(1,)), HTTP_X_REQUESTED_WITH='XMLHttpRequest') self.assertEqual(response.status_code, 200) self.assertIn('streaming_content', dir(response)) class PartQRTest(PartViewTestCase): - """ Tests for the Part QR Code AJAX view """ + """Tests for the Part QR Code AJAX view.""" def test_html_redirect(self): - # A HTML request for a QR code should be redirected (use an AJAX request instead) + """A HTML request for a QR code should be redirected (use an AJAX request instead)""" response = self.client.get(reverse('part-qr', args=(1,))) self.assertEqual(response.status_code, 302) def test_valid_part(self): + """Test QR code response for a Part""" response = self.client.get(reverse('part-qr', args=(1,)), HTTP_X_REQUESTED_WITH='XMLHttpRequest') self.assertEqual(response.status_code, 200) @@ -133,17 +134,17 @@ class PartQRTest(PartViewTestCase): self.assertIn(' str: + """Returns an encoded path for the provided dict.""" groups = [] for key, val in arguments.items(): groups.append(f'{key}={",".join([str(a) for a in val])}') return f'?{"&".join(groups)}' - def api_call(self, endpoint, method: str = 'GET', url_args=None, data=None, headers=None, simple_response: bool = True, endpoint_is_url: bool = False): + def api_call(self, endpoint: str, method: str = 'GET', url_args: dict = None, data=None, headers: dict = None, simple_response: bool = True, endpoint_is_url: bool = False): + """Do an API call. + + Simplest call example: + ```python + self.api_call('hello') + ``` + Will call the `{base_url}/hello` with a GET request and - if set - the token for this plugin. + + Args: + endpoint (str): Path to current endpoint. Either the endpoint or the full or if the flag is set + method (str, optional): HTTP method that should be uses - capitalized. Defaults to 'GET'. + url_args (dict, optional): arguments that should be appended to the url. Defaults to None. + data (Any, optional): Data that should be transmitted in the body - must be JSON serializable. Defaults to None. + headers (dict, optional): Headers that should be used for the request. Defaults to self.api_headers. + simple_response (bool, optional): Return the response as JSON. Defaults to True. + endpoint_is_url (bool, optional): The provided endpoint is the full url - do not use self.api_url as base. Defaults to False. + + Returns: + Response + """ if url_args: endpoint += self.api_build_url_args(url_args) @@ -487,8 +463,7 @@ class APICallMixin: class PanelMixin: - """ - Mixin which allows integration of custom 'panels' into a particular page. + """Mixin which allows integration of custom 'panels' into a particular page. The mixin provides a number of key functionalities: @@ -529,28 +504,29 @@ class PanelMixin: 'javascript': 'alert("You just loaded this panel!")', 'content': 'Hello world', } - """ class MixinMeta: + """Meta for mixin.""" + MIXIN_NAME = 'Panel' def __init__(self): + """Register mixin.""" super().__init__() self.add_mixin('panel', True, __class__) def get_custom_panels(self, view, request): - """ This method *must* be implemented by the plugin class """ + """This method *must* be implemented by the plugin class.""" raise MixinNotImplementedError(f"{__class__} is missing the 'get_custom_panels' method") def get_panel_context(self, view, request, context): - """ - Build the context data to be used for template rendering. + """Build the context data to be used for template rendering. + Custom class can override this to provide any custom context data. (See the example in "custom_panel_sample.py") """ - # Provide some standard context items to the template for rendering context['plugin'] = self context['request'] = request @@ -565,7 +541,16 @@ class PanelMixin: return context def render_panels(self, view, request, context): + """Get panels for a view. + Args: + view: Current view context + request: Current request for passthrough + context: Rendering context + + Returns: + Array of panels + """ panels = [] # Construct an updated context object for template rendering diff --git a/InvenTree/plugin/base/integration/test_mixins.py b/InvenTree/plugin/base/integration/test_mixins.py index e91f4da365..40cb1591b8 100644 --- a/InvenTree/plugin/base/integration/test_mixins.py +++ b/InvenTree/plugin/base/integration/test_mixins.py @@ -1,4 +1,4 @@ -""" Unit tests for base mixins for plugins """ +"""Unit tests for base mixins for plugins.""" from django.conf import settings from django.test import TestCase @@ -17,7 +17,10 @@ from plugin.urls import PLUGIN_BASE class BaseMixinDefinition: + """Mixin to test the meta functions of all mixins.""" + def test_mixin_name(self): + """Test that the mixin registers itseld correctly.""" # mixin name self.assertIn(self.MIXIN_NAME, [item['key'] for item in self.mixin.registered_mixins]) # human name @@ -25,6 +28,8 @@ class BaseMixinDefinition: class SettingsMixinTest(BaseMixinDefinition, InvenTreeTestCase): + """Tests for SettingsMixin.""" + MIXIN_HUMAN_NAME = 'Settings' MIXIN_NAME = 'settings' MIXIN_ENABLE_CHECK = 'has_settings' @@ -32,6 +37,7 @@ class SettingsMixinTest(BaseMixinDefinition, InvenTreeTestCase): TEST_SETTINGS = {'SETTING1': {'default': '123', }} def setUp(self): + """Setup for all tests.""" class SettingsCls(SettingsMixin, InvenTreePlugin): SETTINGS = self.TEST_SETTINGS self.mixin = SettingsCls() @@ -43,6 +49,7 @@ class SettingsMixinTest(BaseMixinDefinition, InvenTreeTestCase): super().setUp() def test_function(self): + """Test that the mixin functions.""" # settings variable self.assertEqual(self.mixin.settings, self.TEST_SETTINGS) @@ -60,11 +67,14 @@ class SettingsMixinTest(BaseMixinDefinition, InvenTreeTestCase): class UrlsMixinTest(BaseMixinDefinition, TestCase): + """Tests for UrlsMixin.""" + MIXIN_HUMAN_NAME = 'URLs' MIXIN_NAME = 'urls' MIXIN_ENABLE_CHECK = 'has_urls' def setUp(self): + """Setup for all tests.""" class UrlsCls(UrlsMixin, InvenTreePlugin): def test(): return 'ccc' @@ -76,6 +86,7 @@ class UrlsMixinTest(BaseMixinDefinition, TestCase): self.mixin_nothing = NoUrlsCls() def test_function(self): + """Test that the mixin functions.""" plg_name = self.mixin.plugin_name() # base_url @@ -99,26 +110,32 @@ class UrlsMixinTest(BaseMixinDefinition, TestCase): class AppMixinTest(BaseMixinDefinition, TestCase): + """Tests for AppMixin.""" + MIXIN_HUMAN_NAME = 'App registration' MIXIN_NAME = 'app' MIXIN_ENABLE_CHECK = 'has_app' def setUp(self): + """Setup for all tests.""" class TestCls(AppMixin, InvenTreePlugin): pass self.mixin = TestCls() def test_function(self): - # test that this plugin is in settings + """Test that the sample plugin registers in settings.""" self.assertIn('plugin.samples.integration', settings.INSTALLED_APPS) class NavigationMixinTest(BaseMixinDefinition, TestCase): + """Tests for NavigationMixin.""" + MIXIN_HUMAN_NAME = 'Navigation Links' MIXIN_NAME = 'navigation' MIXIN_ENABLE_CHECK = 'has_naviation' def setUp(self): + """Setup for all tests.""" class NavigationCls(NavigationMixin, InvenTreePlugin): NAVIGATION = [ {'name': 'aa', 'link': 'plugin:test:test_view'}, @@ -131,6 +148,7 @@ class NavigationMixinTest(BaseMixinDefinition, TestCase): self.nothing_mixin = NothingNavigationCls() def test_function(self): + """Test that a correct configuration functions.""" # check right configuration self.assertEqual(self.mixin.navigation, [{'name': 'aa', 'link': 'plugin:test:test_view'}, ]) @@ -139,7 +157,7 @@ class NavigationMixinTest(BaseMixinDefinition, TestCase): self.assertEqual(self.nothing_mixin.navigation_name, '') def test_fail(self): - # check wrong links fails + """Test that wrong links fail.""" with self.assertRaises(NotImplementedError): class NavigationCls(NavigationMixin, InvenTreePlugin): NAVIGATION = ['aa', 'aa'] @@ -147,11 +165,14 @@ class NavigationMixinTest(BaseMixinDefinition, TestCase): class APICallMixinTest(BaseMixinDefinition, TestCase): + """Tests for APICallMixin.""" + MIXIN_HUMAN_NAME = 'API calls' MIXIN_NAME = 'api_call' MIXIN_ENABLE_CHECK = 'has_api_call' def setUp(self): + """Setup for all tests.""" class MixinCls(APICallMixin, SettingsMixin, InvenTreePlugin): NAME = "Sample API Caller" @@ -170,9 +191,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): API_TOKEN_SETTING = 'API_TOKEN' def get_external_url(self, simple: bool = True): - ''' - returns data from the sample endpoint - ''' + """Returns data from the sample endpoint.""" return self.api_call('api/users/2', simple_response=simple) self.mixin = MixinCls() @@ -185,7 +204,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): self.mixin_wrong2 = WrongCLS2() def test_base_setup(self): - """Test that the base settings work""" + """Test that the base settings work.""" # check init self.assertTrue(self.mixin.has_api_call) # api_url @@ -196,7 +215,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): self.assertEqual(headers, {'Bearer': '', 'Content-Type': 'application/json'}) def test_args(self): - """Test that building up args work""" + """Test that building up args work.""" # api_build_url_args # 1 arg result = self.mixin.api_build_url_args({'a': 'b'}) @@ -209,7 +228,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): self.assertEqual(result, '?a=b&c=d,e,f') def test_api_call(self): - """Test that api calls work""" + """Test that api calls work.""" # api_call result = self.mixin.get_external_url() self.assertTrue(result) @@ -239,7 +258,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): self.assertEqual(result['page'], 2) def test_function_errors(self): - """Test function errors""" + """Test function errors.""" # wrongly defined plugins should not load with self.assertRaises(MixinNotImplementedError): self.mixin_wrong.has_api_call() @@ -250,7 +269,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): class PanelMixinTests(InvenTreeTestCase): - """Test that the PanelMixin plugin operates correctly""" + """Test that the PanelMixin plugin operates correctly.""" fixtures = [ 'category', @@ -262,8 +281,7 @@ class PanelMixinTests(InvenTreeTestCase): roles = 'all' def test_installed(self): - """Test that the sample panel plugin is installed""" - + """Test that the sample panel plugin is installed.""" plugins = registry.with_mixin('panel') self.assertTrue(len(plugins) > 0) @@ -275,8 +293,7 @@ class PanelMixinTests(InvenTreeTestCase): self.assertEqual(len(plugins), 0) def test_disabled(self): - """Test that the panels *do not load* if the plugin is not enabled""" - + """Test that the panels *do not load* if the plugin is not enabled.""" plugin = registry.get_plugin('samplepanel') plugin.set_setting('ENABLE_HELLO_WORLD', True) @@ -305,10 +322,7 @@ class PanelMixinTests(InvenTreeTestCase): self.assertNotIn('Custom Part Panel', str(response.content)) def test_enabled(self): - """ - Test that the panels *do* load if the plugin is enabled - """ - + """Test that the panels *do* load if the plugin is enabled.""" plugin = registry.get_plugin('samplepanel') self.assertEqual(len(registry.with_mixin('panel', active=True)), 0) @@ -382,8 +396,7 @@ class PanelMixinTests(InvenTreeTestCase): self.assertEqual(Error.objects.count(), n_errors + len(urls)) def test_mixin(self): - """Test that ImplementationError is raised""" - + """Test that ImplementationError is raised.""" with self.assertRaises(MixinNotImplementedError): class Wrong(PanelMixin, InvenTreePlugin): pass diff --git a/InvenTree/plugin/base/label/label.py b/InvenTree/plugin/base/label/label.py index 56eaf1bc20..6880a8b3e6 100644 --- a/InvenTree/plugin/base/label/label.py +++ b/InvenTree/plugin/base/label/label.py @@ -1,4 +1,5 @@ -"""Functions to print a label to a mixin printer""" +"""Functions to print a label to a mixin printer.""" + import logging import sys import traceback @@ -16,20 +17,19 @@ from plugin.registry import registry logger = logging.getLogger('inventree') -def print_label(plugin_slug, pdf_data, filename=None, label_instance=None, user=None): - """ - Print label with the provided plugin. +def print_label(plugin_slug: str, pdf_data, filename=None, label_instance=None, user=None): + """Print label with the provided plugin. This task is nominally handled by the background worker. - If the printing fails (throws an exception) then the user is notified. - Arguments: - plugin_slug: The unique slug (key) of the plugin - pdf_data: Binary PDF data - filename: The intended name of the printed label + Args: + plugin_slug (str): The unique slug (key) of the plugin. + pdf_data: Binary PDF data. + filename: The intended name of the printed label. Defaults to None. + label_instance (Union[LabelTemplate, None], optional): The template instance that should be printed. Defaults to None. + user (Union[User, None], optional): User that should be informed of errors. Defaults to None. """ - logger.info(f"Plugin '{plugin_slug}' is printing a label '{filename}'") plugin = registry.plugins.get(plugin_slug, None) diff --git a/InvenTree/plugin/base/label/mixins.py b/InvenTree/plugin/base/label/mixins.py index aa17b1812a..b439d28c97 100644 --- a/InvenTree/plugin/base/label/mixins.py +++ b/InvenTree/plugin/base/label/mixins.py @@ -1,11 +1,10 @@ -"""Plugin mixin classes for label plugins""" +"""Plugin mixin classes for label plugins.""" from plugin.helpers import MixinNotImplementedError class LabelPrintingMixin: - """ - Mixin which enables direct printing of stock labels. + """Mixin which enables direct printing of stock labels. Each plugin must provide a NAME attribute, which is used to uniquely identify the printer. @@ -13,18 +12,16 @@ class LabelPrintingMixin: """ class MixinMeta: - """ - Meta options for this mixin - """ + """Meta options for this mixin.""" MIXIN_NAME = 'Label printing' def __init__(self): # pragma: no cover + """Register mixin.""" super().__init__() self.add_mixin('labels', True, __class__) def print_label(self, **kwargs): - """ - Callback to print a single label + """Callback to print a single label. kwargs: pdf_data: Raw PDF data of the rendered label @@ -35,6 +32,5 @@ class LabelPrintingMixin: filename: The filename of this PDF label user: The user who printed this label """ - # Unimplemented (to be implemented by the particular plugin class) raise MixinNotImplementedError('This Plugin must implement a `print_label` method') diff --git a/InvenTree/plugin/base/label/test_label_mixin.py b/InvenTree/plugin/base/label/test_label_mixin.py index 53294d2f24..1a16e2d8c8 100644 --- a/InvenTree/plugin/base/label/test_label_mixin.py +++ b/InvenTree/plugin/base/label/test_label_mixin.py @@ -1,4 +1,5 @@ -"""Unit tests for the label printing mixin""" +"""Unit tests for the label printing mixin.""" + import os from django.apps import apps @@ -18,7 +19,7 @@ from stock.models import StockItem, StockLocation class LabelMixinTests(InvenTreeAPITestCase): - """Test that the Label mixin operates correctly""" + """Test that the Label mixin operates correctly.""" fixtures = [ 'category', @@ -30,14 +31,13 @@ class LabelMixinTests(InvenTreeAPITestCase): roles = 'all' def do_activate_plugin(self): - """Activate the 'samplelabel' plugin""" - + """Activate the 'samplelabel' plugin.""" config = registry.get_plugin('samplelabel').plugin_config() config.active = True config.save() def do_url(self, parts, plugin_ref, label, url_name: str = 'api-part-label-print', url_single: str = 'part', invalid: bool = False): - """Generate an URL to print a label""" + """Generate an URL to print a label.""" # Construct URL kwargs = {} if label: @@ -64,7 +64,7 @@ class LabelMixinTests(InvenTreeAPITestCase): return url def test_wrong_implementation(self): - """Test that a wrong implementation raises an error""" + """Test that a wrong implementation raises an error.""" class WrongPlugin(LabelPrintingMixin, InvenTreePlugin): pass @@ -74,8 +74,7 @@ class LabelMixinTests(InvenTreeAPITestCase): plugin.print_label(filename='test') def test_installed(self): - """Test that the sample printing plugin is installed""" - + """Test that the sample printing plugin is installed.""" # Get all label plugins plugins = registry.with_mixin('labels') self.assertEqual(len(plugins), 1) @@ -85,8 +84,7 @@ class LabelMixinTests(InvenTreeAPITestCase): self.assertEqual(len(plugins), 0) def test_api(self): - """Test that we can filter the API endpoint by mixin""" - + """Test that we can filter the API endpoint by mixin.""" url = reverse('api-plugin-list') # Try POST (disallowed) @@ -130,8 +128,7 @@ class LabelMixinTests(InvenTreeAPITestCase): self.assertEqual(data['key'], 'samplelabel') def test_printing_process(self): - """Test that a label can be printed""" - + """Test that a label can be printed.""" # Ensure the labels were created apps.get_app_config('label').create_labels() @@ -186,7 +183,7 @@ class LabelMixinTests(InvenTreeAPITestCase): Image.open('label.png') def test_printing_endpoints(self): - """Cover the endpoints not covered by `test_printing_process`""" + """Cover the endpoints not covered by `test_printing_process`.""" plugin_ref = 'samplelabel' # Activate the label components @@ -194,13 +191,13 @@ class LabelMixinTests(InvenTreeAPITestCase): self.do_activate_plugin() def run_print_test(label, qs, url_name, url_single): - """Run tests on single and multiple page printing + """Run tests on single and multiple page printing. Args: - label (_type_): class of the label - qs (_type_): class of the base queryset - url_name (_type_): url for endpoints - url_single (_type_): item lookup reference + label: class of the label + qs: class of the base queryset + url_name: url for endpoints + url_single: item lookup reference """ label = label.objects.first() qs = qs.objects.all() diff --git a/InvenTree/plugin/base/locate/api.py b/InvenTree/plugin/base/locate/api.py index da0488591f..a7d6f5d86e 100644 --- a/InvenTree/plugin/base/locate/api.py +++ b/InvenTree/plugin/base/locate/api.py @@ -1,4 +1,4 @@ -"""API for location plugins""" +"""API for location plugins.""" from rest_framework import permissions from rest_framework.exceptions import NotFound, ParseError @@ -11,16 +11,14 @@ from stock.models import StockItem, StockLocation class LocatePluginView(APIView): - """ - Endpoint for using a custom plugin to identify or 'locate' a stock item or location - """ + """Endpoint for using a custom plugin to identify or 'locate' a stock item or location.""" permission_classes = [ permissions.IsAuthenticated, ] def post(self, request, *args, **kwargs): - + """Check inputs and offload the task to the plugin.""" # Which plugin to we wish to use? plugin = request.data.get('plugin', None) diff --git a/InvenTree/plugin/base/locate/mixins.py b/InvenTree/plugin/base/locate/mixins.py index 5d804b70f2..82780f7823 100644 --- a/InvenTree/plugin/base/locate/mixins.py +++ b/InvenTree/plugin/base/locate/mixins.py @@ -1,4 +1,4 @@ -"""Plugin mixin for locating stock items and locations""" +"""Plugin mixin for locating stock items and locations.""" import logging @@ -8,9 +8,7 @@ logger = logging.getLogger('inventree') class LocateMixin: - """ - Mixin class which provides support for 'locating' inventory items, - for example identifying the location of a particular StockLocation. + """Mixin class which provides support for 'locating' inventory items, for example identifying the location of a particular StockLocation. Plugins could implement audible or visual cues to direct attention to the location, with (for e.g.) LED strips or buzzers, or some other method. @@ -23,19 +21,19 @@ class LocateMixin: - locate_stock_location : Used to locate / identify a particular stock location Refer to the default method implementations below for more information! - """ class MixinMeta: + """Meta for mixin.""" MIXIN_NAME = "Locate" def __init__(self): + """Register the mixin.""" super().__init__() self.add_mixin('locate', True, __class__) def locate_stock_item(self, item_pk): - """ - Attempt to locate a particular StockItem + """Attempt to locate a particular StockItem. Arguments: item_pk: The PK (primary key) of the StockItem to be located @@ -47,7 +45,6 @@ class LocateMixin: Note: A custom implemenation could always change this behaviour """ - logger.info(f"LocateMixin: Attempting to locate StockItem pk={item_pk}") from stock.models import StockItem @@ -63,8 +60,7 @@ class LocateMixin: pass def locate_stock_location(self, location_pk): - """ - Attempt to location a particular StockLocation + """Attempt to location a particular StockLocation. Arguments: location_pk: The PK (primary key) of the StockLocation to be located diff --git a/InvenTree/plugin/base/locate/test_locate.py b/InvenTree/plugin/base/locate/test_locate.py index 96dfca7cb1..32e8540fb2 100644 --- a/InvenTree/plugin/base/locate/test_locate.py +++ b/InvenTree/plugin/base/locate/test_locate.py @@ -1,6 +1,4 @@ -""" -Unit tests for the 'locate' plugin mixin class -""" +"""Unit tests for the 'locate' plugin mixin class.""" from django.urls import reverse @@ -11,6 +9,7 @@ from stock.models import StockItem, StockLocation class LocatePluginTests(InvenTreeAPITestCase): + """Tests for LocateMixin.""" fixtures = [ 'category', @@ -20,8 +19,7 @@ class LocatePluginTests(InvenTreeAPITestCase): ] def test_installed(self): - """Test that a locate plugin is actually installed""" - + """Test that a locate plugin is actually installed.""" plugins = registry.with_mixin('locate') self.assertTrue(len(plugins) > 0) @@ -29,8 +27,7 @@ class LocatePluginTests(InvenTreeAPITestCase): self.assertTrue('samplelocate' in [p.slug for p in plugins]) def test_locate_fail(self): - """Test various API failure modes""" - + """Test various API failure modes.""" url = reverse('api-locate-plugin') # Post without a plugin @@ -90,13 +87,11 @@ class LocatePluginTests(InvenTreeAPITestCase): self.assertIn(f"StockLocation matching PK '{pk}' not found", str(response.data)) def test_locate_item(self): - """ - Test that the plugin correctly 'locates' a StockItem + """Test that the plugin correctly 'locates' a StockItem. As the background worker is not running during unit testing, the sample 'locate' function will be called 'inline' """ - url = reverse('api-locate-plugin') item = StockItem.objects.get(pk=1) @@ -121,10 +116,7 @@ class LocatePluginTests(InvenTreeAPITestCase): self.assertTrue(item.metadata['located']) def test_locate_location(self): - """ - Test that the plugin correctly 'locates' a StockLocation - """ - + """Test that the plugin correctly 'locates' a StockLocation.""" url = reverse('api-locate-plugin') for location in StockLocation.objects.all(): @@ -148,7 +140,7 @@ class LocatePluginTests(InvenTreeAPITestCase): self.assertTrue(location.metadata['located']) def test_mixin_locate(self): - """Test the sample mixin redirection""" + """Test the sample mixin redirection.""" class SamplePlugin(LocateMixin, InvenTreePlugin): pass diff --git a/InvenTree/plugin/builtin/action/simpleactionplugin.py b/InvenTree/plugin/builtin/action/simpleactionplugin.py index 54598e72cb..af4b404764 100644 --- a/InvenTree/plugin/builtin/action/simpleactionplugin.py +++ b/InvenTree/plugin/builtin/action/simpleactionplugin.py @@ -1,26 +1,26 @@ -# -*- coding: utf-8 -*- -"""sample implementation for ActionMixin""" +"""Sample implementation for ActionMixin.""" + from plugin import InvenTreePlugin from plugin.mixins import ActionMixin class SimpleActionPlugin(ActionMixin, InvenTreePlugin): - """ - An EXTREMELY simple action plugin which demonstrates - the capability of the ActionMixin class - """ + """An EXTREMELY simple action plugin which demonstrates the capability of the ActionMixin class.""" NAME = "SimpleActionPlugin" ACTION_NAME = "simple" def perform_action(self, user=None, data=None): + """Sample method.""" print("Action plugin in action!") def get_info(self, user, data=None): + """Sample method.""" return { "user": user.username, "hello": "world", } def get_result(self, user=None, data=None): + """Sample method.""" return True diff --git a/InvenTree/plugin/builtin/action/test_simpleactionplugin.py b/InvenTree/plugin/builtin/action/test_simpleactionplugin.py index e645dea85c..6bc2329496 100644 --- a/InvenTree/plugin/builtin/action/test_simpleactionplugin.py +++ b/InvenTree/plugin/builtin/action/test_simpleactionplugin.py @@ -1,24 +1,25 @@ -""" Unit tests for action plugins """ +"""Unit tests for action plugins.""" from InvenTree.helpers import InvenTreeTestCase from plugin.builtin.action.simpleactionplugin import SimpleActionPlugin class SimpleActionPluginTests(InvenTreeTestCase): - """ Tests for SampleIntegrationPlugin """ + """Tests for SampleIntegrationPlugin.""" def setUp(self): + """Setup for tests.""" super().setUp() self.plugin = SimpleActionPlugin() def test_name(self): - """check plugn names """ + """Check plugn names.""" self.assertEqual(self.plugin.plugin_name(), "SimpleActionPlugin") self.assertEqual(self.plugin.action_name(), "simple") def test_function(self): - """check if functions work """ + """Check if functions work.""" # test functions response = self.client.post('/api/action/', data={'action': "simple", 'data': {'foo': "bar", }}) self.assertEqual(response.status_code, 200) diff --git a/InvenTree/plugin/builtin/barcodes/inventree_barcode.py b/InvenTree/plugin/builtin/barcodes/inventree_barcode.py index bbc2d29c68..52e97ddbd6 100644 --- a/InvenTree/plugin/builtin/barcodes/inventree_barcode.py +++ b/InvenTree/plugin/builtin/barcodes/inventree_barcode.py @@ -1,6 +1,4 @@ -""" -The InvenTreeBarcodePlugin validates barcodes generated by InvenTree itself. -It can be used as a template for developing third-party barcode plugins. +"""The InvenTreeBarcodePlugin validates barcodes generated by InvenTree itself. It can be used as a template for developing third-party barcode plugins. The data format is very simple, and maps directly to database objects, via the "id" parameter. @@ -9,8 +7,6 @@ Parsing an InvenTree barcode simply involves validating that the references model objects actually exist in the database. """ -# -*- coding: utf-8 -*- - import json from rest_framework.exceptions import ValidationError @@ -22,20 +18,19 @@ from stock.models import StockItem, StockLocation class InvenTreeBarcodePlugin(BarcodeMixin, InvenTreePlugin): + """Builtin BarcodePlugin for matching and generating internal barcodes.""" NAME = "InvenTreeBarcode" def validate(self): - """ - An "InvenTree" barcode must be a jsonnable-dict with the following tags: + """Validate a barcode. + An "InvenTree" barcode must be a jsonnable-dict with the following tags: { 'tool': 'InvenTree', 'version': } - """ - # The data must either be dict or be able to dictified if type(self.data) is dict: pass @@ -59,7 +54,7 @@ class InvenTreeBarcodePlugin(BarcodeMixin, InvenTreePlugin): return True def getStockItem(self): - + """Lookup StockItem by 'stockitem' key in barcode data.""" for k in self.data.keys(): if k.lower() == 'stockitem': @@ -88,7 +83,7 @@ class InvenTreeBarcodePlugin(BarcodeMixin, InvenTreePlugin): return None def getStockLocation(self): - + """Lookup StockLocation by 'stocklocation' key in barcode data.""" for k in self.data.keys(): if k.lower() == 'stocklocation': @@ -116,7 +111,7 @@ class InvenTreeBarcodePlugin(BarcodeMixin, InvenTreePlugin): return None def getPart(self): - + """Lookup Part by 'part' key in barcode data.""" for k in self.data.keys(): if k.lower() == 'part': diff --git a/InvenTree/plugin/builtin/barcodes/test_inventree_barcode.py b/InvenTree/plugin/builtin/barcodes/test_inventree_barcode.py index 0bc855aa24..b3fd51c781 100644 --- a/InvenTree/plugin/builtin/barcodes/test_inventree_barcode.py +++ b/InvenTree/plugin/builtin/barcodes/test_inventree_barcode.py @@ -1,5 +1,4 @@ -# -*- coding: utf-8 -*- -"""Unit tests for InvenTreeBarcodePlugin""" +"""Unit tests for InvenTreeBarcodePlugin.""" from django.urls import reverse @@ -9,6 +8,7 @@ from InvenTree.api_tester import InvenTreeAPITestCase class TestInvenTreeBarcode(InvenTreeAPITestCase): + """Tests for the integrated InvenTreeBarcode barcode plugin.""" fixtures = [ 'category', @@ -18,9 +18,7 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): ] def test_errors(self): - """ - Test all possible error cases for assigment action - """ + """Test all possible error cases for assigment action.""" def test_assert_error(barcode_data): response = self.client.post( @@ -46,10 +44,7 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): test_assert_error('{"blbla": 10004}') def test_scan(self): - """ - Test that a barcode can be scanned - """ - + """Test that a barcode can be scanned.""" response = self.client.post(reverse('api-barcode-scan'), format='json', data={'barcode': 'blbla=10004'}) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn('success', response.data) diff --git a/InvenTree/plugin/builtin/integration/core_notifications.py b/InvenTree/plugin/builtin/integration/core_notifications.py index 179fb7994f..d5c3cffd51 100644 --- a/InvenTree/plugin/builtin/integration/core_notifications.py +++ b/InvenTree/plugin/builtin/integration/core_notifications.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- -"""Core set of Notifications as a Plugin""" +"""Core set of Notifications as a Plugin.""" + from django.template.loader import render_to_string from django.utils.translation import ugettext_lazy as _ @@ -11,14 +11,18 @@ from plugin.mixins import BulkNotificationMethod, SettingsMixin class PlgMixin: + """Mixin to access plugin easier. + + This needs to be spit out to reference the class. Perks of python. + """ + def get_plugin(self): + """Return plugin reference.""" return CoreNotificationsPlugin class CoreNotificationsPlugin(SettingsMixin, InvenTreePlugin): - """ - Core notification methods for InvenTree - """ + """Core notification methods for InvenTree.""" NAME = "CoreNotificationsPlugin" AUTHOR = _('InvenTree contributors') @@ -34,6 +38,8 @@ class CoreNotificationsPlugin(SettingsMixin, InvenTreePlugin): } class EmailNotification(PlgMixin, BulkNotificationMethod): + """Notificationmethod for delivery via Email.""" + METHOD_NAME = 'mail' METHOD_ICON = 'fa-envelope' CONTEXT_EXTRA = [ @@ -50,11 +56,7 @@ class CoreNotificationsPlugin(SettingsMixin, InvenTreePlugin): } def get_targets(self): - """ - Return a list of target email addresses, - only for users which allow email notifications - """ - + """Return a list of target email addresses, only for users which allow email notifications.""" allowed_users = [] for user in self.targets: @@ -68,6 +70,7 @@ class CoreNotificationsPlugin(SettingsMixin, InvenTreePlugin): ) def send_bulk(self): + """Send the notifications out via email.""" html_message = render_to_string(self.context['template']['html'], self.context) targets = self.targets.values_list('email', flat=True) diff --git a/InvenTree/plugin/builtin/integration/test_core_notifications.py b/InvenTree/plugin/builtin/integration/test_core_notifications.py index 3badac2562..1ff60c74b1 100644 --- a/InvenTree/plugin/builtin/integration/test_core_notifications.py +++ b/InvenTree/plugin/builtin/integration/test_core_notifications.py @@ -1,3 +1,5 @@ +"""Tests for core_notifications.""" + from part.test_part import BaseNotificationIntegrationTest from plugin import registry from plugin.builtin.integration.core_notifications import \ @@ -6,12 +8,10 @@ from plugin.models import NotificationUserSetting class CoreNotificationTestTests(BaseNotificationIntegrationTest): + """Tests for CoreNotificationsPlugin.""" def test_email(self): - """ - Ensure that the email notifications run - """ - + """Ensure that the email notifications run.""" # enable plugin and set mail setting to true plugin = registry.plugins.get('corenotificationsplugin') plugin.set_setting('ENABLE_NOTIFICATION_EMAILS', True) diff --git a/InvenTree/plugin/events.py b/InvenTree/plugin/events.py index 452e71d4ef..800233bd9a 100644 --- a/InvenTree/plugin/events.py +++ b/InvenTree/plugin/events.py @@ -1,6 +1,4 @@ -""" -Import helper for events -""" +"""Import helper for events.""" from plugin.base.event.events import (process_event, register_event, trigger_event) diff --git a/InvenTree/plugin/helpers.py b/InvenTree/plugin/helpers.py index 8a3a38a4ca..247ae72654 100644 --- a/InvenTree/plugin/helpers.py +++ b/InvenTree/plugin/helpers.py @@ -1,6 +1,5 @@ -""" -Helpers for plugin app -""" +"""Helpers for plugin app.""" + import inspect import logging import os @@ -20,36 +19,38 @@ logger = logging.getLogger('inventree') # region logging / errors class IntegrationPluginError(Exception): - """ - Error that encapsulates another error and adds the path / reference of the raising plugin - """ + """Error that encapsulates another error and adds the path / reference of the raising plugin.""" + def __init__(self, path, message): + """Init a plugin error. + + Args: + path: Path on which the error occured - used to find out which plugin it was + message: The original error message + """ self.path = path self.message = message def __str__(self): + """Returns the error message.""" return self.message # pragma: no cover class MixinImplementationError(ValueError): - """ - Error if mixin was implemented wrong in plugin + """Error if mixin was implemented wrong in plugin. + Mostly raised if constant is missing """ pass class MixinNotImplementedError(NotImplementedError): - """ - Error if necessary mixin function was not overwritten - """ + """Error if necessary mixin function was not overwritten.""" pass def log_error(error, reference: str = 'general'): - """ - Log an plugin error - """ + """Log an plugin error.""" from plugin import registry # make sure the registry is set up @@ -61,9 +62,7 @@ def log_error(error, reference: str = 'general'): def handle_error(error, do_raise: bool = True, do_log: bool = True, log_name: str = ''): - """ - Handles an error and casts it as an IntegrationPluginError - """ + """Handles an error and casts it as an IntegrationPluginError.""" package_path = traceback.extract_tb(error.__traceback__)[-1].filename install_path = sysconfig.get_paths()["purelib"] try: @@ -99,9 +98,7 @@ def handle_error(error, do_raise: bool = True, do_log: bool = True, log_name: st # region git-helpers def get_git_log(path): - """ - Get dict with info of the last commit to file named in path - """ + """Get dict with info of the last commit to file named in path.""" from plugin import registry output = None @@ -122,8 +119,7 @@ def get_git_log(path): def check_git_version(): - """returns if the current git version supports modern features""" - + """Returns if the current git version supports modern features.""" # get version string try: output = str(subprocess.check_output(['git', '--version'], cwd=os.path.dirname(settings.BASE_DIR)), 'utf-8') @@ -143,18 +139,17 @@ def check_git_version(): class GitStatus: - """ - Class for resolving git gpg singing state - """ + """Class for resolving git gpg singing state.""" + class Definition: - """ - Definition of a git gpg sing state - """ + """Definition of a git gpg sing state.""" + key: str = 'N' status: int = 2 msg: str = '' def __init__(self, key: str = 'N', status: int = 2, msg: str = '') -> None: + """Define a git Status -> needed for lookup.""" self.key = key self.status = status self.msg = msg @@ -172,8 +167,7 @@ class GitStatus: # region plugin finders def get_modules(pkg): - """get all modules in a package""" - + """Get all modules in a package.""" context = {} for loader, name, ispkg in pkgutil.walk_packages(pkg.__path__): try: @@ -195,18 +189,16 @@ def get_modules(pkg): def get_classes(module): - """get all classes in a given module""" + """Get all classes in a given module.""" return inspect.getmembers(module, inspect.isclass) def get_plugins(pkg, baseclass): - """ - Return a list of all modules under a given package. + """Return a list of all modules under a given package. - Modules must be a subclass of the provided 'baseclass' - Modules must have a non-empty NAME parameter """ - plugins = [] modules = get_modules(pkg) @@ -225,10 +217,7 @@ def get_plugins(pkg, baseclass): # region templates def render_template(plugin, template_file, context=None): - """ - Locate and render a template file, available in the global template context. - """ - + """Locate and render a template file, available in the global template context.""" try: tmp = template.loader.get_template(template_file) except template.TemplateDoesNotExist: @@ -247,10 +236,7 @@ def render_template(plugin, template_file, context=None): def render_text(text, context=None): - """ - Locate a raw string with provided context - """ - + """Locate a raw string with provided context.""" ctx = template.Context(context) return template.Template(text).render(ctx) diff --git a/InvenTree/plugin/mixins/__init__.py b/InvenTree/plugin/mixins/__init__.py index 80088de561..207a65ab42 100644 --- a/InvenTree/plugin/mixins/__init__.py +++ b/InvenTree/plugin/mixins/__init__.py @@ -1,6 +1,4 @@ -""" -Utility class to enable simpler imports -""" +"""Utility class to enable simpler imports.""" from common.notifications import (BulkNotificationMethod, SingleNotificationMethod) diff --git a/InvenTree/plugin/models.py b/InvenTree/plugin/models.py index 35becb9c47..9a58e96fdc 100644 --- a/InvenTree/plugin/models.py +++ b/InvenTree/plugin/models.py @@ -1,6 +1,4 @@ -""" -Plugin model definitions -""" +"""Plugin model definitions.""" import warnings @@ -14,9 +12,7 @@ from plugin import InvenTreePlugin, registry class MetadataMixin(models.Model): - """ - Model mixin class which adds a JSON metadata field to a model, - for use by any (and all) plugins. + """Model mixin class which adds a JSON metadata field to a model, for use by any (and all) plugins. The intent of this mixin is to provide a metadata field on a model instance, for plugins to read / modify as required, to store any extra information. @@ -28,6 +24,7 @@ class MetadataMixin(models.Model): """ class Meta: + """Meta for MetadataMixin.""" abstract = True metadata = models.JSONField( @@ -37,8 +34,7 @@ class MetadataMixin(models.Model): ) def get_metadata(self, key: str, backup_value=None): - """ - Finds metadata for this model instance, using the provided key for lookup + """Finds metadata for this model instance, using the provided key for lookup. Args: key: String key for requesting metadata. e.g. if a plugin is accessing the metadata, the plugin slug should be used @@ -46,22 +42,19 @@ class MetadataMixin(models.Model): Returns: Python dict object containing requested metadata. If no matching metadata is found, returns None """ - if self.metadata is None: return backup_value return self.metadata.get(key, backup_value) - def set_metadata(self, key: str, data, commit=True): - """ - Save the provided metadata under the provided key. + def set_metadata(self, key: str, data, commit: bool = True): + """Save the provided metadata under the provided key. Args: - key: String key for saving metadata - data: Data object to save - must be able to be rendered as a JSON string - overwrite: If true, existing metadata with the provided key will be overwritten. If false, a merge will be attempted + key (str): Key for saving metadata + data (Any): Data object to save - must be able to be rendered as a JSON string + commit (bool, optional): If true, existing metadata with the provided key will be overwritten. If false, a merge will be attempted. Defaults to True. """ - if self.metadata is None: # Handle a null field value self.metadata = {} @@ -73,15 +66,16 @@ class MetadataMixin(models.Model): class PluginConfig(models.Model): - """ - A PluginConfig object holds settings for plugins. + """A PluginConfig object holds settings for plugins. Attributes: key: slug of the plugin (this must be unique across all installed plugins!) name: PluginName of the plugin - serves for a manual double check if the right plugin is used active: Should the plugin be loaded? """ + class Meta: + """Meta for PluginConfig.""" verbose_name = _("Plugin Configuration") verbose_name_plural = _("Plugin Configurations") @@ -107,6 +101,7 @@ class PluginConfig(models.Model): ) def __str__(self) -> str: + """Nice name for printing.""" name = f'{self.name} - {self.key}' if not self.active: name += '(not active)' @@ -114,7 +109,7 @@ class PluginConfig(models.Model): # extra attributes from the registry def mixins(self): - + """Returns all registered mixins.""" try: return self.plugin._mixinreg except (AttributeError, ValueError): # pragma: no cover @@ -123,10 +118,7 @@ class PluginConfig(models.Model): # functions def __init__(self, *args, **kwargs): - """ - Override to set original state of the plugin-config instance - """ - + """Override to set original state of the plugin-config instance.""" super().__init__(*args, **kwargs) self.__org_active = self.active @@ -145,9 +137,7 @@ class PluginConfig(models.Model): } def save(self, force_insert=False, force_update=False, *args, **kwargs): - """ - Extend save method to reload plugins if the 'active' status changes - """ + """Extend save method to reload plugins if the 'active' status changes.""" reload = kwargs.pop('no_reload', False) # check if no_reload flag is set ret = super().save(force_insert, force_update, *args, **kwargs) @@ -163,11 +153,10 @@ class PluginConfig(models.Model): class PluginSetting(common.models.BaseInvenTreeSetting): - """ - This model represents settings for individual plugins - """ + """This model represents settings for individual plugins.""" class Meta: + """Meta for PluginSetting.""" unique_together = [ ('plugin', 'key'), ] @@ -182,9 +171,7 @@ class PluginSetting(common.models.BaseInvenTreeSetting): @classmethod def get_setting_definition(cls, key, **kwargs): - """ - In the BaseInvenTreeSetting class, we have a class attribute named 'SETTINGS', - which is a dict object that fully defines all the setting parameters. + """In the BaseInvenTreeSetting class, we have a class attribute named 'SETTINGS', which is a dict object that fully defines all the setting parameters. Here, unlike the BaseInvenTreeSetting, we do not know the definitions of all settings 'ahead of time' (as they are defined externally in the plugins). @@ -194,7 +181,6 @@ class PluginSetting(common.models.BaseInvenTreeSetting): If not provided, we'll look at the plugin registry to see what settings are available, (if the plugin is specified!) """ - if 'settings' not in kwargs: plugin = kwargs.pop('plugin', None) @@ -209,28 +195,24 @@ class PluginSetting(common.models.BaseInvenTreeSetting): return super().get_setting_definition(key, **kwargs) def get_kwargs(self): - """ - Explicit kwargs required to uniquely identify a particular setting object, - in addition to the 'key' parameter - """ - + """Explicit kwargs required to uniquely identify a particular setting object, in addition to the 'key' parameter.""" return { 'plugin': self.plugin, } class NotificationUserSetting(common.models.BaseInvenTreeSetting): - """ - This model represents notification settings for a user - """ + """This model represents notification settings for a user.""" class Meta: + """Meta for NotificationUserSetting.""" unique_together = [ ('method', 'user', 'key'), ] @classmethod def get_setting_definition(cls, key, **kwargs): + """Override setting_definition to use notification settings.""" from common.notifications import storage kwargs['settings'] = storage.user_settings @@ -238,11 +220,7 @@ class NotificationUserSetting(common.models.BaseInvenTreeSetting): return super().get_setting_definition(key, **kwargs) def get_kwargs(self): - """ - Explicit kwargs required to uniquely identify a particular setting object, - in addition to the 'key' parameter - """ - + """Explicit kwargs required to uniquely identify a particular setting object, in addition to the 'key' parameter.""" return { 'method': self.method, 'user': self.user, @@ -262,4 +240,5 @@ class NotificationUserSetting(common.models.BaseInvenTreeSetting): ) def __str__(self) -> str: + """Nice name of printing.""" return f'{self.key} (for {self.user}): {self.value}' diff --git a/InvenTree/plugin/plugin.py b/InvenTree/plugin/plugin.py index d5f6a7bccf..8b0c0bc847 100644 --- a/InvenTree/plugin/plugin.py +++ b/InvenTree/plugin/plugin.py @@ -1,7 +1,5 @@ -# -*- coding: utf-8 -*- -""" -Base Class for InvenTree plugins -""" +"""Base Class for InvenTree plugins.""" + import inspect import logging import os @@ -21,7 +19,7 @@ logger = logging.getLogger("inventree") class MetaBase: - """Base class for a plugins metadata""" + """Base class for a plugins metadata.""" # Override the plugin name for each concrete plugin instance NAME = '' @@ -29,7 +27,7 @@ class MetaBase: TITLE = None def get_meta_value(self, key: str, old_key: str = None, __default=None): - """Reference a meta item with a key + """Reference a meta item with a key. Args: key (str): key for the value @@ -55,24 +53,19 @@ class MetaBase: return value def plugin_name(self): - """ - Name of plugin - """ + """Name of plugin.""" return self.get_meta_value('NAME', 'PLUGIN_NAME') @property def name(self): - """ - Name of plugin - """ + """Name of plugin.""" return self.plugin_name() def plugin_slug(self): - """ - Slug of plugin + """Slug of plugin. + If not set plugin name slugified """ - slug = self.get_meta_value('SLUG', 'PLUGIN_SLUG', None) if not slug: slug = self.plugin_name() @@ -81,16 +74,11 @@ class MetaBase: @property def slug(self): - """ - Slug of plugin - """ + """Slug of plugin.""" return self.plugin_slug() def plugin_title(self): - """ - Title of plugin - """ - + """Title of plugin.""" title = self.get_meta_value('TITLE', 'PLUGIN_TITLE', None) if title: return title @@ -98,16 +86,11 @@ class MetaBase: @property def human_name(self): - """ - Human readable name of plugin - """ + """Human readable name of plugin.""" return self.plugin_title() def plugin_config(self): - """ - Return the PluginConfig object associated with this plugin - """ - + """Return the PluginConfig object associated with this plugin.""" try: import plugin.models @@ -121,10 +104,7 @@ class MetaBase: return cfg def is_active(self): - """ - Return True if this plugin is currently active - """ - + """Return True if this plugin is currently active.""" cfg = self.plugin_config() if cfg: @@ -134,25 +114,23 @@ class MetaBase: class MixinBase: - """ - Base set of mixin functions and mechanisms - """ + """Base set of mixin functions and mechanisms.""" def __init__(self, *args, **kwargs) -> None: + """Init sup-parts. + + Adds state dicts. + """ self._mixinreg = {} self._mixins = {} super().__init__(*args, **kwargs) def mixin(self, key): - """ - Check if mixin is registered - """ + """Check if mixin is registered.""" return key in self._mixins def mixin_enabled(self, key): - """ - Check if mixin is registered, enabled and ready - """ + """Check if mixin is registered, enabled and ready.""" if self.mixin(key): fnc_name = self._mixins.get(key) @@ -164,18 +142,12 @@ class MixinBase: return False def add_mixin(self, key: str, fnc_enabled=True, cls=None): - """ - Add a mixin to the plugins registry - """ - + """Add a mixin to the plugins registry.""" self._mixins[key] = fnc_enabled self.setup_mixin(key, cls=cls) def setup_mixin(self, key, cls=None): - """ - Define mixin details for the current mixin -> provides meta details for all active mixins - """ - + """Define mixin details for the current mixin -> provides meta details for all active mixins.""" # get human name human_name = getattr(cls.MixinMeta, 'MIXIN_NAME', key) if cls and hasattr(cls, 'MixinMeta') else key @@ -187,10 +159,7 @@ class MixinBase: @property def registered_mixins(self, with_base: bool = False): - """ - Get all registered mixins for the plugin - """ - + """Get all registered mixins for the plugin.""" mixins = getattr(self, '_mixinreg', None) if mixins: # filter out base @@ -202,8 +171,7 @@ class MixinBase: class InvenTreePlugin(MixinBase, MetaBase): - """ - The InvenTreePlugin class is used to integrate with 3rd party software + """The InvenTreePlugin class is used to integrate with 3rd party software. DO NOT USE THIS DIRECTLY, USE plugin.InvenTreePlugin """ @@ -216,6 +184,10 @@ class InvenTreePlugin(MixinBase, MetaBase): LICENSE = None def __init__(self): + """Init a plugin. + + Set paths and load metadata. + """ super().__init__() self.add_mixin('base') self.def_path = inspect.getfile(self.__class__) @@ -226,9 +198,7 @@ class InvenTreePlugin(MixinBase, MetaBase): # region properties @property def description(self): - """ - Description of plugin - """ + """Description of plugin.""" description = getattr(self, 'DESCRIPTION', None) if not description: description = self.plugin_name() @@ -236,9 +206,7 @@ class InvenTreePlugin(MixinBase, MetaBase): @property def author(self): - """ - Author of plugin - either from plugin settings or git - """ + """Author of plugin - either from plugin settings or git.""" author = getattr(self, 'AUTHOR', None) if not author: author = self.package.get('author') @@ -248,9 +216,7 @@ class InvenTreePlugin(MixinBase, MetaBase): @property def pub_date(self): - """ - Publishing date of plugin - either from plugin settings or git - """ + """Publishing date of plugin - either from plugin settings or git.""" pub_date = getattr(self, 'PUBLISH_DATE', None) if not pub_date: pub_date = self.package.get('date') @@ -262,77 +228,57 @@ class InvenTreePlugin(MixinBase, MetaBase): @property def version(self): - """ - Version of plugin - """ + """Version of plugin.""" version = getattr(self, 'VERSION', None) return version @property def website(self): - """ - Website of plugin - if set else None - """ + """Website of plugin - if set else None.""" website = getattr(self, 'WEBSITE', None) return website @property def license(self): - """ - License of plugin - """ + """License of plugin.""" lic = getattr(self, 'LICENSE', None) return lic # endregion @property def _is_package(self): - """ - Is the plugin delivered as a package - """ + """Is the plugin delivered as a package.""" return getattr(self, 'is_package', False) @property def is_sample(self): - """ - Is this plugin part of the samples? - """ + """Is this plugin part of the samples?""" path = str(self.package_path) return path.startswith('plugin/samples/') @property def package_path(self): - """ - Path to the plugin - """ + """Path to the plugin.""" if self._is_package: return self.__module__ # pragma: no cover return pathlib.Path(self.def_path).relative_to(settings.BASE_DIR) @property def settings_url(self): - """ - URL to the settings panel for this plugin - """ + """URL to the settings panel for this plugin.""" return f'{reverse("settings")}#select-plugin-{self.slug}' # region package info def _get_package_commit(self): - """ - Get last git commit for the plugin - """ + """Get last git commit for the plugin.""" return get_git_log(self.def_path) def _get_package_metadata(self): - """ - Get package metadata for plugin - """ + """Get package metadata for plugin.""" return {} # pragma: no cover # TODO add usage for package metadata def define_package(self): - """ - Add package info of the plugin into plugins context - """ + """Add package info of the plugin into plugins context.""" package = self._get_package_metadata() if self._is_package else self._get_package_commit() # process date @@ -355,8 +301,13 @@ class InvenTreePlugin(MixinBase, MetaBase): class IntegrationPluginBase(InvenTreePlugin): + """Legacy base class for plugins. + + Do not use! + """ + def __init__(self, *args, **kwargs): - """Send warning about using this reference""" + """Send warning about using this reference.""" # TODO remove in 0.8.0 warnings.warn("This import is deprecated - use InvenTreePlugin", DeprecationWarning) super().__init__(*args, **kwargs) diff --git a/InvenTree/plugin/registry.py b/InvenTree/plugin/registry.py index e8e61496b4..574cf654ff 100644 --- a/InvenTree/plugin/registry.py +++ b/InvenTree/plugin/registry.py @@ -1,5 +1,4 @@ -""" -Registry for loading and managing multiple plugins at run-time +"""Registry for loading and managing multiple plugins at run-time. - Holds the class and the object that contains all code to maintain plugin states - Manages setup and teardown of plugin class instances @@ -31,11 +30,13 @@ logger = logging.getLogger('inventree') class PluginsRegistry: - """ - The PluginsRegistry class - """ + """The PluginsRegistry class.""" def __init__(self) -> None: + """Initialize registry. + + Set up all needed references for internal and external states. + """ # plugin registry self.plugins = {} self.plugins_inactive = {} @@ -55,10 +56,7 @@ class PluginsRegistry: self.mixins_settings = {} def get_plugin(self, slug): - """ - Lookup plugin by slug (unique key). - """ - + """Lookup plugin by slug (unique key).""" if slug not in self.plugins: logger.warning(f"Plugin registry has no record of plugin '{slug}'") return None @@ -66,15 +64,13 @@ class PluginsRegistry: return self.plugins[slug] def call_plugin_function(self, slug, func, *args, **kwargs): - """ - Call a member function (named by 'func') of the plugin named by 'slug'. + """Call a member function (named by 'func') of the plugin named by 'slug'. As this is intended to be run by the background worker, we do not perform any try/except here. Instead, any error messages are returned to the worker. """ - plugin = self.get_plugin(slug) if not plugin: @@ -87,7 +83,7 @@ class PluginsRegistry: # region public functions # region loading / unloading def load_plugins(self, full_reload: bool = False): - """Load and activate all IntegrationPlugins + """Load and activate all IntegrationPlugins. Args: full_reload (bool, optional): Reload everything - including plugin mechanism. Defaults to False. @@ -150,10 +146,7 @@ class PluginsRegistry: logger.info('Finished loading plugins') def unload_plugins(self): - """ - Unload and deactivate all IntegrationPlugins - """ - + """Unload and deactivate all IntegrationPlugins.""" if not settings.PLUGINS_ENABLED: # Plugins not enabled, do nothing return # pragma: no cover @@ -177,12 +170,11 @@ class PluginsRegistry: logger.info('Finished unloading plugins') def reload_plugins(self, full_reload: bool = False): - """Safely reload IntegrationPlugins + """Safely reload. Args: full_reload (bool, optional): Reload everything - including plugin mechanism. Defaults to False. """ - # Do not reload whe currently loading if self.is_loading: return # pragma: no cover @@ -196,8 +188,7 @@ class PluginsRegistry: logger.info('Finished reloading plugins') def collect_plugins(self): - """Collect plugins from all possible ways of loading""" - + """Collect plugins from all possible ways of loading.""" if not settings.PLUGINS_ENABLED: # Plugins not enabled, do nothing return # pragma: no cover @@ -226,10 +217,7 @@ class PluginsRegistry: logger.info(", ".join([a.__module__ for a in self.plugin_modules])) def install_plugin_file(self): - """ - Make sure all plugins are installed in the current enviroment - """ - + """Make sure all plugins are installed in the current enviroment.""" if settings.PLUGIN_FILE_CHECKED: logger.info('Plugin file was already checked') return True @@ -250,9 +238,7 @@ class PluginsRegistry: # region registry functions def with_mixin(self, mixin: str, active=None): - """ - Returns reference to all plugins that have a specified mixin enabled - """ + """Returns reference to all plugins that have a specified mixin enabled.""" result = [] for plugin in self.plugins.values(): @@ -273,14 +259,12 @@ class PluginsRegistry: # region general internal loading /activating / deactivating / deloading def _init_plugins(self, disabled=None): - """ - Initialise all found plugins + """Initialise all found plugins. :param disabled: loading path of disabled app, defaults to None :type disabled: str, optional :raises error: IntegrationPluginError """ - from plugin.models import PluginConfig logger.info('Starting plugin initialisation') @@ -344,7 +328,7 @@ class PluginsRegistry: self.plugins_inactive[plug_key] = plugin_db_setting # pragma: no cover def _activate_plugins(self, force_reload=False, full_reload: bool = False): - """Run activation functions for all plugins + """Run activation functions for all plugins. Args: force_reload (bool, optional): Also reload base apps. Defaults to False. @@ -359,8 +343,7 @@ class PluginsRegistry: self.activate_plugin_app(plugins, force_reload=force_reload, full_reload=full_reload) def _deactivate_plugins(self): - """Run deactivation functions for all plugins""" - + """Run deactivation functions for all plugins.""" self.deactivate_plugin_app() self.deactivate_plugin_schedule() self.deactivate_plugin_settings() @@ -368,7 +351,11 @@ class PluginsRegistry: # region mixin specific loading ... def activate_plugin_settings(self, plugins): + """Activate plugin settings. + Add all defined settings form the plugins to a unified dict in the registry. + This dict is referenced by the PluginSettings for settings definitions. + """ logger.info('Activating plugin settings') self.mixins_settings = {} @@ -379,18 +366,13 @@ class PluginsRegistry: self.mixins_settings[slug] = plugin_setting def deactivate_plugin_settings(self): - - # collect all settings - plugin_settings = {} - - for _, plugin_setting in self.mixins_settings.items(): - plugin_settings.update(plugin_setting) - - # clear cache + """Deactivate all plugin settings.""" + logger.info('Deactivating plugin settings') + # clear settings cache self.mixins_settings = {} def activate_plugin_schedule(self, plugins): - + """Activate scheudles from plugins with the ScheduleMixin.""" logger.info('Activating plugin tasks') from common.models import InvenTreeSetting @@ -434,21 +416,20 @@ class PluginsRegistry: logger.warning("activate_integration_schedule failed, database not ready") def deactivate_plugin_schedule(self): - """ - Deactivate ScheduleMixin - currently nothing is done + """Deactivate ScheduleMixin. + + Currently nothing is done here. """ pass def activate_plugin_app(self, plugins, force_reload=False, full_reload: bool = False): - """Activate AppMixin plugins - add custom apps and reload + """Activate AppMixin plugins - add custom apps and reload. Args: plugins (dict): List of IntegrationPlugins that should be installed force_reload (bool, optional): Only reload base apps. Defaults to False. full_reload (bool, optional): Reload everything - including plugin mechanism. Defaults to False. """ - from common.models import InvenTreeSetting if settings.PLUGIN_TESTING or InvenTreeSetting.get_setting('ENABLE_PLUGINS_APP'): @@ -480,9 +461,10 @@ class PluginsRegistry: self._update_urls() def _reregister_contrib_apps(self): - """fix reloading of contrib apps - models and admin - this is needed if plugins were loaded earlier and then reloaded as models and admins rely on imports - those register models and admin in their respective objects (e.g. admin.site for admin) + """Fix reloading of contrib apps - models and admin. + + This is needed if plugins were loaded earlier and then reloaded as models and admins rely on imports. + Those register models and admin in their respective objects (e.g. admin.site for admin). """ for plugin_path in self.installed_apps: try: @@ -512,8 +494,9 @@ class PluginsRegistry: reload(app_config.module.admin) def _get_plugin_path(self, plugin): - """parse plugin path - the input can be eiter: + """Parse plugin path. + + The input can be eiter: - a local file / dir - a package """ @@ -526,8 +509,7 @@ class PluginsRegistry: return plugin_path def deactivate_plugin_app(self): - """Deactivate AppMixin plugins - some magic required""" - + """Deactivate AppMixin plugins - some magic required.""" # unregister models from admin for plugin_path in self.installed_apps: models = [] # the modelrefs need to be collected as poping an item in a iter is not welcomed @@ -598,13 +580,12 @@ class PluginsRegistry: clear_url_caches() def _reload_apps(self, force_reload: bool = False, full_reload: bool = False): - """Internal: reload apps using django internal functions + """Internal: reload apps using django internal functions. Args: force_reload (bool, optional): Also reload base apps. Defaults to False. full_reload (bool, optional): Reload everything - including plugin mechanism. Defaults to False. """ - # If full_reloading is set to true we do not want to set the flag if not full_reload: self.is_loading = True # set flag to disable loop reloading @@ -619,9 +600,9 @@ class PluginsRegistry: self.is_loading = False def _try_reload(self, cmd, *args, **kwargs): - """ - wrapper to try reloading the apps - throws an custom error that gets handled by the loading function + """Wrapper to try reloading the apps. + + Throws an custom error that gets handled by the loading function. """ try: cmd(*args, **kwargs) @@ -635,5 +616,5 @@ registry = PluginsRegistry() def call_function(plugin_name, function_name, *args, **kwargs): - """ Global helper function to call a specific member function of a plugin """ + """Global helper function to call a specific member function of a plugin.""" return registry.call_plugin_function(plugin_name, function_name, *args, **kwargs) diff --git a/InvenTree/plugin/samples/event/event_sample.py b/InvenTree/plugin/samples/event/event_sample.py index acddbf95c6..9912d5f80c 100644 --- a/InvenTree/plugin/samples/event/event_sample.py +++ b/InvenTree/plugin/samples/event/event_sample.py @@ -1,6 +1,4 @@ -""" -Sample plugin which responds to events -""" +"""Sample plugin which responds to events.""" import warnings @@ -11,17 +9,14 @@ from plugin.mixins import EventMixin class EventPluginSample(EventMixin, InvenTreePlugin): - """ - A sample plugin which provides supports for triggered events - """ + """A sample plugin which provides supports for triggered events.""" NAME = "EventPlugin" SLUG = "sampleevent" TITLE = "Triggered Events" def process_event(self, event, *args, **kwargs): - """ Custom event processing """ - + """Custom event processing.""" print(f"Processing triggered event: '{event}'") print("args:", str(args)) print("kwargs:", str(kwargs)) diff --git a/InvenTree/plugin/samples/event/test_event_sample.py b/InvenTree/plugin/samples/event/test_event_sample.py index 284b306a92..4f28e6595f 100644 --- a/InvenTree/plugin/samples/event/test_event_sample.py +++ b/InvenTree/plugin/samples/event/test_event_sample.py @@ -1,4 +1,4 @@ -"""Unit tests for event_sample sample plugins""" +"""Unit tests for event_sample sample plugins.""" from django.conf import settings from django.test import TestCase @@ -10,10 +10,10 @@ from plugin.mixins import EventMixin class EventPluginSampleTests(TestCase): - """Tests for EventPluginSample""" + """Tests for EventPluginSample.""" def test_run_event(self): - """Check if the event is issued""" + """Check if the event is issued.""" # Activate plugin config = registry.get_plugin('sampleevent').plugin_config() config.active = True @@ -30,8 +30,7 @@ class EventPluginSampleTests(TestCase): settings.PLUGIN_TESTING_EVENTS = False def test_mixin(self): - """Test that MixinNotImplementedError is raised""" - + """Test that MixinNotImplementedError is raised.""" with self.assertRaises(MixinNotImplementedError): class Wrong(EventMixin, InvenTreePlugin): pass diff --git a/InvenTree/plugin/samples/integration/another_sample.py b/InvenTree/plugin/samples/integration/another_sample.py index 0cc5ce21c3..580e69d9cf 100644 --- a/InvenTree/plugin/samples/integration/another_sample.py +++ b/InvenTree/plugin/samples/integration/another_sample.py @@ -1,19 +1,15 @@ -"""sample implementation for IntegrationPlugin""" +"""Sample implementation for IntegrationPlugin.""" from plugin import InvenTreePlugin from plugin.mixins import UrlsMixin class NoIntegrationPlugin(InvenTreePlugin): - """ - An basic plugin - """ + """A basic plugin.""" NAME = "NoIntegrationPlugin" class WrongIntegrationPlugin(UrlsMixin, InvenTreePlugin): - """ - An basic wron plugin with urls - """ + """A basic wron plugin with urls.""" NAME = "WrongIntegrationPlugin" diff --git a/InvenTree/plugin/samples/integration/api_caller.py b/InvenTree/plugin/samples/integration/api_caller.py index 98a145de34..0968780e8b 100644 --- a/InvenTree/plugin/samples/integration/api_caller.py +++ b/InvenTree/plugin/samples/integration/api_caller.py @@ -1,14 +1,11 @@ -""" -Sample plugin for calling an external API -""" +"""Sample plugin for calling an external API.""" from plugin import InvenTreePlugin from plugin.mixins import APICallMixin, SettingsMixin class SampleApiCallerPlugin(APICallMixin, SettingsMixin, InvenTreePlugin): - """ - A small api call sample - """ + """A small api call sample.""" + NAME = "Sample API Caller" SETTINGS = { @@ -26,7 +23,5 @@ class SampleApiCallerPlugin(APICallMixin, SettingsMixin, InvenTreePlugin): API_TOKEN_SETTING = 'API_TOKEN' def get_external_url(self): - """ - returns data from the sample endpoint - """ + """Returns data from the sample endpoint.""" return self.api_call('api/users/2') diff --git a/InvenTree/plugin/samples/integration/broken_file.py b/InvenTree/plugin/samples/integration/broken_file.py index 52c6005771..f56932e876 100644 --- a/InvenTree/plugin/samples/integration/broken_file.py +++ b/InvenTree/plugin/samples/integration/broken_file.py @@ -1,11 +1,10 @@ -"""sample of a broken python file that will be ignored on import""" +"""Sample of a broken python file that will be ignored on import.""" + from plugin import InvenTreePlugin class BrokenFileIntegrationPlugin(InvenTreePlugin): - """ - An very broken plugin - """ + """An very broken plugin.""" aaa = bb # noqa: F821 diff --git a/InvenTree/plugin/samples/integration/broken_sample.py b/InvenTree/plugin/samples/integration/broken_sample.py index ebd7821fe0..b873b93fa6 100644 --- a/InvenTree/plugin/samples/integration/broken_sample.py +++ b/InvenTree/plugin/samples/integration/broken_sample.py @@ -1,16 +1,16 @@ -"""sample of a broken plugin""" +"""Sample of a broken plugin.""" from plugin import InvenTreePlugin class BrokenIntegrationPlugin(InvenTreePlugin): - """ - An very broken plugin - """ + """A very broken plugin.""" + NAME = 'Test' TITLE = 'Broken Plugin' SLUG = 'broken' def __init__(self): + """Raise a KeyError to provoke a range of unit tests and safety mechanisms in the plugin loading mechanism.""" super().__init__() raise KeyError('This is a dummy error') diff --git a/InvenTree/plugin/samples/integration/custom_panel_sample.py b/InvenTree/plugin/samples/integration/custom_panel_sample.py index 3671dc532e..a3228582e6 100644 --- a/InvenTree/plugin/samples/integration/custom_panel_sample.py +++ b/InvenTree/plugin/samples/integration/custom_panel_sample.py @@ -1,6 +1,4 @@ -""" -Sample plugin which renders custom panels on certain pages -""" +"""Sample plugin which renders custom panels on certain pages.""" from part.views import PartDetail from plugin import InvenTreePlugin @@ -9,9 +7,7 @@ from stock.views import StockLocationDetail class CustomPanelSample(PanelMixin, SettingsMixin, InvenTreePlugin): - """ - A sample plugin which renders some custom panels. - """ + """A sample plugin which renders some custom panels.""" NAME = "CustomPanelExample" SLUG = "samplepanel" @@ -35,7 +31,7 @@ class CustomPanelSample(PanelMixin, SettingsMixin, InvenTreePlugin): } def get_panel_context(self, view, request, context): - + """Returns enriched context.""" ctx = super().get_panel_context(view, request, context) # If we are looking at a StockLocationDetail view, add location context object @@ -45,16 +41,13 @@ class CustomPanelSample(PanelMixin, SettingsMixin, InvenTreePlugin): return ctx def get_custom_panels(self, view, request): - - """ - You can decide at run-time which custom panels you want to display! + """You can decide at run-time which custom panels you want to display! - Display on every page - Only on a single page or set of pages - Only for a specific instance (e.g. part) - Based on the user viewing the page! """ - panels = [ { # Simple panel without any actual content diff --git a/InvenTree/plugin/samples/integration/label_sample.py b/InvenTree/plugin/samples/integration/label_sample.py index c01c575012..61ec11b829 100644 --- a/InvenTree/plugin/samples/integration/label_sample.py +++ b/InvenTree/plugin/samples/integration/label_sample.py @@ -1,12 +1,14 @@ +"""Simple sample for a plugin with the LabelPrintingMixin. + +This does not function in real usage and is more to show the required components and for unit tests. +""" from plugin import InvenTreePlugin from plugin.mixins import LabelPrintingMixin class SampleLabelPrinter(LabelPrintingMixin, InvenTreePlugin): - """ - Sample plugin which provides a 'fake' label printer endpoint - """ + """Sample plugin which provides a 'fake' label printer endpoint.""" NAME = "Label Printer" SLUG = "samplelabel" @@ -15,7 +17,10 @@ class SampleLabelPrinter(LabelPrintingMixin, InvenTreePlugin): VERSION = "0.2" def print_label(self, **kwargs): + """Sample printing step. + Normally here the connection to the printer and transfer of the label would take place. + """ # Test that the expected kwargs are present print(f"Printing Label: {kwargs['filename']} (User: {kwargs['user']})") print(f"Width: {kwargs['width']} x Height: {kwargs['height']}") diff --git a/InvenTree/plugin/samples/integration/sample.py b/InvenTree/plugin/samples/integration/sample.py index a768fb3de3..402159455a 100644 --- a/InvenTree/plugin/samples/integration/sample.py +++ b/InvenTree/plugin/samples/integration/sample.py @@ -1,6 +1,4 @@ -""" -Sample implementations for IntegrationPlugin -""" +"""Sample implementations for IntegrationPlugin.""" from django.http import HttpResponse from django.urls import include, re_path @@ -11,9 +9,7 @@ from plugin.mixins import AppMixin, NavigationMixin, SettingsMixin, UrlsMixin class SampleIntegrationPlugin(AppMixin, SettingsMixin, UrlsMixin, NavigationMixin, InvenTreePlugin): - """ - A full plugin example - """ + """A full plugin example.""" NAME = "SampleIntegrationPlugin" SLUG = "sample" @@ -23,10 +19,11 @@ class SampleIntegrationPlugin(AppMixin, SettingsMixin, UrlsMixin, NavigationMixi NAVIGATION_TAB_ICON = 'fas fa-plus' def view_test(self, request): - """very basic view""" + """Very basic view.""" return HttpResponse(f'Hi there {request.user.username} this works') def setup_urls(self): + """Urls that are exposed by this plugin.""" he_urls = [ re_path(r'^he/', self.view_test, name='he'), re_path(r'^ha/', self.view_test, name='ha'), diff --git a/InvenTree/plugin/samples/integration/scheduled_task.py b/InvenTree/plugin/samples/integration/scheduled_task.py index 2a59f820c6..101a2e3687 100644 --- a/InvenTree/plugin/samples/integration/scheduled_task.py +++ b/InvenTree/plugin/samples/integration/scheduled_task.py @@ -1,6 +1,4 @@ -""" -Sample plugin which supports task scheduling -""" +"""Sample plugin which supports task scheduling.""" from plugin import InvenTreePlugin from plugin.mixins import ScheduleMixin, SettingsMixin @@ -8,17 +6,23 @@ from plugin.mixins import ScheduleMixin, SettingsMixin # Define some simple tasks to perform def print_hello(): + """Sample function that can be called on schedule. + + Contents do not matter - therefore no coverage. + """ print("Hello") # pragma: no cover def print_world(): + """Sample function that can be called on schedule. + + Contents do not matter - therefore no coverage. + """ print("World") # pragma: no cover class ScheduledTaskPlugin(ScheduleMixin, SettingsMixin, InvenTreePlugin): - """ - A sample plugin which provides support for scheduled tasks - """ + """A sample plugin which provides support for scheduled tasks.""" NAME = "ScheduledTasksPlugin" SLUG = "schedule" @@ -51,10 +55,7 @@ class ScheduledTaskPlugin(ScheduleMixin, SettingsMixin, InvenTreePlugin): } def member_func(self, *args, **kwargs): - """ - A simple member function to demonstrate functionality - """ - + """A simple member function to demonstrate functionality.""" t_or_f = self.get_setting('T_OR_F') print(f"Called member_func - value is {t_or_f}") diff --git a/InvenTree/plugin/samples/integration/test_api_caller.py b/InvenTree/plugin/samples/integration/test_api_caller.py index 32ee07bbd3..28b51835c8 100644 --- a/InvenTree/plugin/samples/integration/test_api_caller.py +++ b/InvenTree/plugin/samples/integration/test_api_caller.py @@ -1,4 +1,4 @@ -""" Unit tests for action caller sample""" +"""Unit tests for action caller sample.""" from django.test import TestCase @@ -6,10 +6,10 @@ from plugin import registry class SampleApiCallerPluginTests(TestCase): - """ Tests for SampleApiCallerPluginTests """ + """Tests for SampleApiCallerPluginTests.""" def test_return(self): - """check if the external api call works""" + """Check if the external api call works.""" # The plugin should be defined self.assertIn('sample-api-caller', registry.plugins) plg = registry.plugins['sample-api-caller'] diff --git a/InvenTree/plugin/samples/integration/test_sample.py b/InvenTree/plugin/samples/integration/test_sample.py index 577aa5812e..bd04377f97 100644 --- a/InvenTree/plugin/samples/integration/test_sample.py +++ b/InvenTree/plugin/samples/integration/test_sample.py @@ -1,13 +1,13 @@ -""" Unit tests for action plugins """ +"""Unit tests for action plugins.""" from InvenTree.helpers import InvenTreeTestCase class SampleIntegrationPluginTests(InvenTreeTestCase): - """ Tests for SampleIntegrationPlugin """ + """Tests for SampleIntegrationPlugin.""" def test_view(self): - """check the function of the custom sample plugin """ + """Check the function of the custom sample plugin.""" response = self.client.get('/plugin/sample/ho/he/') self.assertEqual(response.status_code, 200) self.assertEqual(response.content, b'Hi there testuser this works') diff --git a/InvenTree/plugin/samples/integration/test_scheduled_task.py b/InvenTree/plugin/samples/integration/test_scheduled_task.py index c99a662ff9..32627c0aa9 100644 --- a/InvenTree/plugin/samples/integration/test_scheduled_task.py +++ b/InvenTree/plugin/samples/integration/test_scheduled_task.py @@ -1,4 +1,4 @@ -""" Unit tests for scheduled tasks""" +"""Unit tests for scheduled tasks.""" from django.test import TestCase @@ -9,10 +9,10 @@ from plugin.registry import call_function class ExampleScheduledTaskPluginTests(TestCase): - """ Tests for provided ScheduledTaskPlugin """ + """Tests for provided ScheduledTaskPlugin.""" def test_function(self): - """check if the scheduling works""" + """Check if the scheduling works.""" # The plugin should be defined self.assertIn('schedule', registry.plugins) plg = registry.plugins['schedule'] @@ -44,7 +44,7 @@ class ExampleScheduledTaskPluginTests(TestCase): self.assertEqual(len(scheduled_plugin_tasks), 0) def test_calling(self): - """check if a function can be called without errors""" + """Check if a function can be called without errors.""" # Check with right parameters self.assertEqual(call_function('schedule', 'member_func'), False) @@ -53,23 +53,22 @@ class ExampleScheduledTaskPluginTests(TestCase): class ScheduledTaskPluginTests(TestCase): - """ Tests for ScheduledTaskPluginTests mixin base """ + """Tests for ScheduledTaskPluginTests mixin base.""" def test_init(self): - """Check that all MixinImplementationErrors raise""" + """Check that all MixinImplementationErrors raise.""" class Base(ScheduleMixin, InvenTreePlugin): NAME = 'APlugin' class NoSchedules(Base): - """Plugin without schedules""" + """Plugin without schedules.""" pass with self.assertRaises(MixinImplementationError): NoSchedules() class WrongFuncSchedules(Base): - """ - Plugin with broken functions + """Plugin with broken functions. This plugin is missing a func """ @@ -88,8 +87,7 @@ class ScheduledTaskPluginTests(TestCase): WrongFuncSchedules() class WrongFuncSchedules1(WrongFuncSchedules): - """ - Plugin with broken functions + """Plugin with broken functions. This plugin is missing a schedule """ @@ -105,8 +103,7 @@ class ScheduledTaskPluginTests(TestCase): WrongFuncSchedules1() class WrongFuncSchedules2(WrongFuncSchedules): - """ - Plugin with broken functions + """Plugin with broken functions. This plugin is missing a schedule """ @@ -122,8 +119,7 @@ class ScheduledTaskPluginTests(TestCase): WrongFuncSchedules2() class WrongFuncSchedules3(WrongFuncSchedules): - """ - Plugin with broken functions + """Plugin with broken functions. This plugin has a broken schedule """ @@ -140,8 +136,7 @@ class ScheduledTaskPluginTests(TestCase): WrongFuncSchedules3() class WrongFuncSchedules4(WrongFuncSchedules): - """ - Plugin with broken functions + """Plugin with broken functions. This plugin is missing a minute marker for its schedule """ diff --git a/InvenTree/plugin/samples/locate/locate_sample.py b/InvenTree/plugin/samples/locate/locate_sample.py index d4ce411098..7cf63e2bf2 100644 --- a/InvenTree/plugin/samples/locate/locate_sample.py +++ b/InvenTree/plugin/samples/locate/locate_sample.py @@ -1,5 +1,4 @@ -""" -Sample plugin for locating stock items / locations. +"""Sample plugin for locating stock items / locations. Note: This plugin does not *actually* locate anything! """ @@ -13,8 +12,8 @@ logger = logging.getLogger('inventree') class SampleLocatePlugin(LocateMixin, InvenTreePlugin): - """ - A very simple example of the 'locate' plugin. + """A very simple example of the 'locate' plugin. + This plugin class simply prints location information to the logger. """ @@ -25,7 +24,11 @@ class SampleLocatePlugin(LocateMixin, InvenTreePlugin): VERSION = "0.2" def locate_stock_item(self, item_pk): + """Locate a StockItem. + Args: + item_pk: primary key for item + """ from stock.models import StockItem logger.info(f"SampleLocatePlugin attempting to locate item ID {item_pk}") @@ -41,7 +44,11 @@ class SampleLocatePlugin(LocateMixin, InvenTreePlugin): logger.error(f"StockItem ID {item_pk} does not exist!") def locate_stock_location(self, location_pk): + """Locate a StockLocation. + Args: + location_pk: primary key for location + """ from stock.models import StockLocation logger.info(f"SampleLocatePlugin attempting to locate location ID {location_pk}") diff --git a/InvenTree/plugin/samples/locate/test_locate_sample.py b/InvenTree/plugin/samples/locate/test_locate_sample.py index fe7ba28cca..1e85ceb566 100644 --- a/InvenTree/plugin/samples/locate/test_locate_sample.py +++ b/InvenTree/plugin/samples/locate/test_locate_sample.py @@ -1,4 +1,4 @@ -"""Unit tests for locate_sample sample plugins""" +"""Unit tests for locate_sample sample plugins.""" from django.urls import reverse @@ -9,7 +9,7 @@ from plugin.mixins import LocateMixin class SampleLocatePlugintests(InvenTreeAPITestCase): - """Tests for SampleLocatePlugin""" + """Tests for SampleLocatePlugin.""" fixtures = [ 'location', @@ -19,7 +19,7 @@ class SampleLocatePlugintests(InvenTreeAPITestCase): ] def test_run_locator(self): - """Check if the event is issued""" + """Check if the event is issued.""" # Activate plugin config = registry.get_plugin('samplelocate').plugin_config() config.active = True @@ -50,8 +50,7 @@ class SampleLocatePlugintests(InvenTreeAPITestCase): self.post(url, {'plugin': 'samplelocate', 'location': 1}, expected_code=200) def test_mixin(self): - """Test that MixinNotImplementedError is raised""" - + """Test that MixinNotImplementedError is raised.""" with self.assertRaises(MixinNotImplementedError): class Wrong(LocateMixin, InvenTreePlugin): pass diff --git a/InvenTree/plugin/serializers.py b/InvenTree/plugin/serializers.py index 4ab067eebb..28bc8f6fa7 100644 --- a/InvenTree/plugin/serializers.py +++ b/InvenTree/plugin/serializers.py @@ -1,6 +1,4 @@ -""" -JSON serializers for plugin app -""" +"""JSON serializers for plugin app.""" import os import subprocess @@ -17,24 +15,28 @@ from plugin.models import NotificationUserSetting, PluginConfig, PluginSetting class MetadataSerializer(serializers.ModelSerializer): - """ - Serializer class for model metadata API access. - """ + """Serializer class for model metadata API access.""" metadata = serializers.JSONField(required=True) def __init__(self, model_type, *args, **kwargs): - + """Initialize the metadata serializer with information on the model type""" self.Meta.model = model_type super().__init__(*args, **kwargs) class Meta: + """Metaclass options.""" + fields = [ 'metadata', ] def update(self, instance, data): + """Perform update on the metadata field: + - If this is a partial (PATCH) update, try to 'merge' data in + - Else, if it is a PUT update, overwrite any existing metadata + """ if self.partial: # Default behaviour is to "merge" new data in metadata = instance.metadata.copy() if instance.metadata else {} @@ -45,14 +47,13 @@ class MetadataSerializer(serializers.ModelSerializer): class PluginConfigSerializer(serializers.ModelSerializer): - """ - Serializer for a PluginConfig: - """ + """Serializer for a PluginConfig.""" meta = serializers.DictField(read_only=True) mixins = serializers.DictField(read_only=True) class Meta: + """Meta for serializer.""" model = PluginConfig fields = [ 'key', @@ -64,9 +65,7 @@ class PluginConfigSerializer(serializers.ModelSerializer): class PluginConfigInstallSerializer(serializers.Serializer): - """ - Serializer for installing a new plugin - """ + """Serializer for installing a new plugin.""" url = serializers.CharField( required=False, @@ -86,6 +85,7 @@ class PluginConfigInstallSerializer(serializers.Serializer): ) class Meta: + """Meta for serializer.""" fields = [ 'url', 'packagename', @@ -93,6 +93,10 @@ class PluginConfigInstallSerializer(serializers.Serializer): ] def validate(self, data): + """Validate inputs. + + Make sure both confirm and url are provided. + """ super().validate(data) # check the base requirements are met @@ -105,6 +109,7 @@ class PluginConfigInstallSerializer(serializers.Serializer): return data def save(self): + """Install a plugin from a package registry and set operational results as instance data.""" data = self.validated_data packagename = data.get('packagename', '') @@ -156,9 +161,7 @@ class PluginConfigInstallSerializer(serializers.Serializer): class PluginSettingSerializer(GenericReferencedSettingSerializer): - """ - Serializer for the PluginSetting model - """ + """Serializer for the PluginSetting model.""" MODEL = PluginSetting EXTRA_FIELDS = [ @@ -169,9 +172,7 @@ class PluginSettingSerializer(GenericReferencedSettingSerializer): class NotificationUserSettingSerializer(GenericReferencedSettingSerializer): - """ - Serializer for the PluginSetting model - """ + """Serializer for the PluginSetting model.""" MODEL = NotificationUserSetting EXTRA_FIELDS = ['method', ] diff --git a/InvenTree/plugin/template.py b/InvenTree/plugin/template.py index f33b5f13fb..c16135a7a3 100644 --- a/InvenTree/plugin/template.py +++ b/InvenTree/plugin/template.py @@ -1,4 +1,4 @@ -"""Load templates for loaded plugins""" +"""Load templates for loaded plugins.""" from pathlib import Path @@ -8,8 +8,7 @@ from plugin import registry class PluginTemplateLoader(FilesystemLoader): - """ - A custom template loader which allows loading of templates from installed plugins. + """A custom template loader which allows loading of templates from installed plugins. Each plugin can register templates simply by providing a 'templates' directory in its root path. @@ -22,6 +21,7 @@ class PluginTemplateLoader(FilesystemLoader): """ def get_dirs(self): + """Returns all template dir paths in plugins.""" dirname = 'templates' template_dirs = [] diff --git a/InvenTree/plugin/templatetags/plugin_extras.py b/InvenTree/plugin/templatetags/plugin_extras.py index b949ced8e3..35d4db5e17 100644 --- a/InvenTree/plugin/templatetags/plugin_extras.py +++ b/InvenTree/plugin/templatetags/plugin_extras.py @@ -1,7 +1,5 @@ -# -*- coding: utf-8 -*- +"""This module provides template tags for handeling plugins.""" -""" This module provides template tags for handeling plugins -""" from django import template from django.conf import settings as djangosettings from django.urls import reverse @@ -15,49 +13,37 @@ register = template.Library() @register.simple_tag() def plugin_list(*args, **kwargs): - """ - List of all installed plugins - """ + """List of all installed plugins.""" return registry.plugins @register.simple_tag() def inactive_plugin_list(*args, **kwargs): - """ - List of all inactive plugins - """ + """List of all inactive plugins.""" return registry.plugins_inactive @register.simple_tag() def plugin_settings(plugin, *args, **kwargs): - """ - List of all settings for the plugin - """ + """List of all settings for the plugin.""" return registry.mixins_settings.get(plugin) @register.simple_tag() def mixin_enabled(plugin, key, *args, **kwargs): - """ - Is the mixin registerd and configured in the plugin? - """ + """Is the mixin registerd and configured in the plugin?""" return plugin.mixin_enabled(key) @register.simple_tag() def mixin_available(mixin, *args, **kwargs): - """ - Returns True if there is at least one active plugin which supports the provided mixin - """ + """Returns True if there is at least one active plugin which supports the provided mixin.""" return len(registry.with_mixin(mixin)) > 0 @register.simple_tag() def navigation_enabled(*args, **kwargs): - """ - Is plugin navigation enabled? - """ + """Is plugin navigation enabled?""" if djangosettings.PLUGIN_TESTING: return True return InvenTreeSetting.get_setting('ENABLE_PLUGINS_NAVIGATION') # pragma: no cover @@ -65,8 +51,8 @@ def navigation_enabled(*args, **kwargs): @register.simple_tag() def safe_url(view_name, *args, **kwargs): - """ - Safe lookup fnc for URLs + """Safe lookup fnc for URLs. + Returns None if not found """ try: @@ -77,15 +63,11 @@ def safe_url(view_name, *args, **kwargs): @register.simple_tag() def plugin_errors(*args, **kwargs): - """ - All plugin errors in the current session - """ + """All plugin errors in the current session.""" return registry.errors @register.simple_tag(takes_context=True) def notification_settings_list(context, *args, **kwargs): - """ - List of all user notification settings - """ + """List of all user notification settings.""" return storage.get_usersettings(user=context.get('user', None)) diff --git a/InvenTree/plugin/test_api.py b/InvenTree/plugin/test_api.py index e04c0ee84b..53e5a67e2c 100644 --- a/InvenTree/plugin/test_api.py +++ b/InvenTree/plugin/test_api.py @@ -1,3 +1,4 @@ +"""Tests for general API tests for the plugin app.""" from django.urls import reverse @@ -5,9 +6,7 @@ from InvenTree.api_tester import InvenTreeAPITestCase class PluginDetailAPITest(InvenTreeAPITestCase): - """ - Tests the plugin API endpoints - """ + """Tests the plugin API endpoints.""" roles = [ 'admin.add', @@ -17,6 +16,7 @@ class PluginDetailAPITest(InvenTreeAPITestCase): ] def setUp(self): + """Setup for all tests.""" self.MSG_NO_PKG = 'Either packagename of URL must be provided' self.PKG_NAME = 'minimal' @@ -24,9 +24,7 @@ class PluginDetailAPITest(InvenTreeAPITestCase): super().setUp() def test_plugin_install(self): - """ - Test the plugin install command - """ + """Test the plugin install command.""" url = reverse('api-plugin-install') # valid - Pypi @@ -73,9 +71,7 @@ class PluginDetailAPITest(InvenTreeAPITestCase): self.assertEqual(data['confirm'][0].title().upper(), 'Installation not confirmed'.upper()) def test_admin_action(self): - """ - Test the PluginConfig action commands - """ + """Test the PluginConfig action commands.""" from plugin import registry from plugin.models import PluginConfig @@ -132,9 +128,7 @@ class PluginDetailAPITest(InvenTreeAPITestCase): self.assertEqual(response.status_code, 200) def test_model(self): - """ - Test the PluginConfig model - """ + """Test the PluginConfig model.""" from plugin import registry from plugin.models import PluginConfig diff --git a/InvenTree/plugin/test_helpers.py b/InvenTree/plugin/test_helpers.py index 1b9cd104cc..53b2622592 100644 --- a/InvenTree/plugin/test_helpers.py +++ b/InvenTree/plugin/test_helpers.py @@ -1,4 +1,4 @@ -"""Unit tests for helpers.py""" +"""Unit tests for helpers.py.""" from django.test import TestCase @@ -6,10 +6,10 @@ from .helpers import render_template class HelperTests(TestCase): - """Tests for helpers""" + """Tests for helpers.""" def test_render_template(self): - """Check if render_template helper works""" + """Check if render_template helper works.""" class ErrorSource: slug = 'sampleplg' diff --git a/InvenTree/plugin/test_plugin.py b/InvenTree/plugin/test_plugin.py index 1516c4eee8..330071bf3d 100644 --- a/InvenTree/plugin/test_plugin.py +++ b/InvenTree/plugin/test_plugin.py @@ -1,6 +1,4 @@ -""" -Unit tests for plugins -""" +"""Unit tests for plugins.""" from datetime import datetime @@ -14,30 +12,31 @@ from plugin.samples.integration.sample import SampleIntegrationPlugin class PluginTagTests(TestCase): - """ Tests for the plugin extras """ + """Tests for the plugin extras.""" def setUp(self): + """Setup for all tests.""" self.sample = SampleIntegrationPlugin() self.plugin_no = NoIntegrationPlugin() self.plugin_wrong = WrongIntegrationPlugin() def test_tag_plugin_list(self): - """test that all plugins are listed""" + """Test that all plugins are listed.""" self.assertEqual(plugin_tags.plugin_list(), registry.plugins) def test_tag_incative_plugin_list(self): - """test that all inactive plugins are listed""" + """Test that all inactive plugins are listed.""" self.assertEqual(plugin_tags.inactive_plugin_list(), registry.plugins_inactive) def test_tag_plugin_settings(self): - """check all plugins are listed""" + """Check all plugins are listed.""" self.assertEqual( plugin_tags.plugin_settings(self.sample), registry.mixins_settings.get(self.sample) ) def test_tag_mixin_enabled(self): - """check that mixin enabled functions work""" + """Check that mixin enabled functions work.""" key = 'urls' # mixin enabled self.assertEqual(plugin_tags.mixin_enabled(self.sample, key), True) @@ -47,25 +46,26 @@ class PluginTagTests(TestCase): self.assertEqual(plugin_tags.mixin_enabled(self.plugin_no, key), False) def test_tag_safe_url(self): - """test that the safe url tag works expected""" + """Test that the safe url tag works expected.""" # right url self.assertEqual(plugin_tags.safe_url('api-plugin-install'), '/api/plugin/install/') # wrong url self.assertEqual(plugin_tags.safe_url('indexas'), None) def test_tag_plugin_errors(self): - """test that all errors are listed""" + """Test that all errors are listed.""" self.assertEqual(plugin_tags.plugin_errors(), registry.errors) class InvenTreePluginTests(TestCase): - """ Tests for InvenTreePlugin """ + """Tests for InvenTreePlugin.""" def setUp(self): + """Setup for all tests.""" self.plugin = InvenTreePlugin() class NamedPlugin(InvenTreePlugin): - """a named plugin""" + """a named plugin.""" NAME = 'abc123' self.named_plugin = NamedPlugin() @@ -95,21 +95,21 @@ class InvenTreePluginTests(TestCase): self.plugin_sample = SampleIntegrationPlugin() def test_basic_plugin_init(self): - """check if a basic plugin intis""" + """Check if a basic plugin intis.""" self.assertEqual(self.plugin.NAME, '') self.assertEqual(self.plugin.plugin_name(), '') def test_basic_plugin_name(self): - """check if the name of a basic plugin can be set""" + """Check if the name of a basic plugin can be set.""" self.assertEqual(self.named_plugin.NAME, 'abc123') self.assertEqual(self.named_plugin.plugin_name(), 'abc123') def test_basic_is_active(self): - """check if a basic plugin is active""" + """Check if a basic plugin is active.""" self.assertEqual(self.plugin.is_active(), False) def test_action_name(self): - """check the name definition possibilities""" + """Check the name definition possibilities.""" # plugin_name self.assertEqual(self.plugin.plugin_name(), '') self.assertEqual(self.plugin_simple.plugin_name(), 'SimplePlugin') @@ -156,8 +156,7 @@ class InvenTreePluginTests(TestCase): self.assertEqual(self.plugin_name.license, 'MIT') def test_depreciation(self): - """Check if depreciations raise as expected""" - + """Check if depreciations raise as expected.""" # check deprecation warning is firing with self.assertWarns(DeprecationWarning): self.assertEqual(self.plugin_old.slug, 'old') diff --git a/InvenTree/plugin/urls.py b/InvenTree/plugin/urls.py index be90ff1618..6da54a83ee 100644 --- a/InvenTree/plugin/urls.py +++ b/InvenTree/plugin/urls.py @@ -1,6 +1,4 @@ -""" -URL lookup for plugin app -""" +"""URL lookup for plugin app.""" from django.urls import include, re_path @@ -10,10 +8,7 @@ PLUGIN_BASE = 'plugin' # Constant for links def get_plugin_urls(): - """ - Returns a urlpattern that can be integrated into the global urls - """ - + """Returns a urlpattern that can be integrated into the global urls.""" urls = [] for plugin in registry.plugins.values(): diff --git a/InvenTree/plugin/views.py b/InvenTree/plugin/views.py index 066047d33f..aa43b9bd09 100644 --- a/InvenTree/plugin/views.py +++ b/InvenTree/plugin/views.py @@ -1,3 +1,5 @@ +"""Views for plugin app.""" + import logging import sys import traceback @@ -13,19 +15,13 @@ logger = logging.getLogger('inventree') class InvenTreePluginViewMixin: - """ - Custom view mixin which adds context data to the view, - based on loaded plugins. + """Custom view mixin which adds context data to the view, based on loaded plugins. This allows rendered pages to be augmented by loaded plugins. - """ def get_plugin_panels(self, ctx): - """ - Return a list of extra 'plugin panels' associated with this view - """ - + """Return a list of extra 'plugin panels' associated with this view.""" panels = [] for plug in registry.with_mixin('panel', active=True): @@ -50,10 +46,7 @@ class InvenTreePluginViewMixin: return panels def get_context_data(self, **kwargs): - """ - Add plugin context data to the view - """ - + """Add plugin context data to the view.""" ctx = super().get_context_data(**kwargs) if settings.PLUGINS_ENABLED: diff --git a/InvenTree/plugins/__init__.py b/InvenTree/plugins/__init__.py index 926e30e23c..a77bf4fe18 100644 --- a/InvenTree/plugins/__init__.py +++ b/InvenTree/plugins/__init__.py @@ -1,5 +1,5 @@ """ -Directory for custom plugin development +Directory for custom plugin development. Please read the docs for more information https://inventree.readthedocs.io/en/latest/extend/plugins/#local-directory """ diff --git a/InvenTree/report/admin.py b/InvenTree/report/admin.py index 6e1b919633..adf813aede 100644 --- a/InvenTree/report/admin.py +++ b/InvenTree/report/admin.py @@ -1,3 +1,5 @@ +"""Admin functionality for the 'report' app""" + from django.contrib import admin from .models import (BillOfMaterialsReport, BuildReport, PurchaseOrderReport, @@ -5,17 +7,17 @@ from .models import (BillOfMaterialsReport, BuildReport, PurchaseOrderReport, class ReportTemplateAdmin(admin.ModelAdmin): - + """Admin class for the various reporting models""" list_display = ('name', 'description', 'template', 'filters', 'enabled', 'revision') class ReportSnippetAdmin(admin.ModelAdmin): - + """Admin class for the ReportSnippet model""" list_display = ('id', 'snippet', 'description') class ReportAssetAdmin(admin.ModelAdmin): - + """Admin class for the ReportAsset model""" list_display = ('id', 'asset', 'description') diff --git a/InvenTree/report/api.py b/InvenTree/report/api.py index 106da81f08..79c2d098e5 100644 --- a/InvenTree/report/api.py +++ b/InvenTree/report/api.py @@ -1,3 +1,4 @@ +"""API functionality for the 'report' app""" from django.core.exceptions import FieldError, ValidationError from django.http import HttpResponse @@ -24,9 +25,7 @@ from .serializers import (BOMReportSerializer, BuildReportSerializer, class ReportListView(generics.ListAPIView): - """ - Generic API class for report templates - """ + """Generic API class for report templates.""" filter_backends = [ DjangoFilterBackend, @@ -44,15 +43,10 @@ class ReportListView(generics.ListAPIView): class StockItemReportMixin: - """ - Mixin for extracting stock items from query params - """ + """Mixin for extracting stock items from query params.""" def get_items(self): - """ - Return a list of requested stock items - """ - + """Return a list of requested stock items.""" items = [] params = self.request.query_params @@ -77,15 +71,10 @@ class StockItemReportMixin: class BuildReportMixin: - """ - Mixin for extracting Build items from query params - """ + """Mixin for extracting Build items from query params.""" def get_builds(self): - """ - Return a list of requested Build objects - """ - + """Return a list of requested Build objects.""" builds = [] params = self.request.query_params @@ -109,17 +98,13 @@ class BuildReportMixin: class OrderReportMixin: - """ - Mixin for extracting order items from query params + """Mixin for extracting order items from query params. requires the OrderModel class attribute to be set! """ def get_orders(self): - """ - Return a list of order objects - """ - + """Return a list of order objects.""" orders = [] params = self.request.query_params @@ -143,15 +128,10 @@ class OrderReportMixin: class PartReportMixin: - """ - Mixin for extracting part items from query params - """ + """Mixin for extracting part items from query params.""" def get_parts(self): - """ - Return a list of requested part objects - """ - + """Return a list of requested part objects.""" parts = [] params = self.request.query_params @@ -176,15 +156,10 @@ class PartReportMixin: class ReportPrintMixin: - """ - Mixin for printing reports - """ + """Mixin for printing reports.""" def print(self, request, items_to_print): - """ - Print this report template against a number of pre-validated items. - """ - + """Print this report template against a number of pre-validated items.""" if len(items_to_print) == 0: # No valid items provided, return an error message data = { @@ -229,19 +204,13 @@ class ReportPrintMixin: report_name += '.pdf' if debug_mode: - """ - Contatenate all rendered templates into a single HTML string, - and return the string as a HTML response. - """ + """Contatenate all rendered templates into a single HTML string, and return the string as a HTML response.""" html = "\n".join(outputs) return HttpResponse(html) else: - """ - Concatenate all rendered pages into a single PDF object, - and return the resulting document! - """ + """Concatenate all rendered pages into a single PDF object, and return the resulting document!""" pages = [] @@ -283,21 +252,19 @@ class ReportPrintMixin: class StockItemTestReportList(ReportListView, StockItemReportMixin): - """ - API endpoint for viewing list of TestReport objects. + """API endpoint for viewing list of TestReport objects. Filterable by: - enabled: Filter by enabled / disabled status - item: Filter by stock item(s) - """ queryset = TestReport.objects.all() serializer_class = TestReportSerializer def filter_queryset(self, queryset): - + """Custom queryset filtering""" queryset = super().filter_queryset(queryset) # List of StockItem objects to match against @@ -347,35 +314,27 @@ class StockItemTestReportList(ReportListView, StockItemReportMixin): class StockItemTestReportDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for a single TestReport object - """ + """API endpoint for a single TestReport object.""" queryset = TestReport.objects.all() serializer_class = TestReportSerializer class StockItemTestReportPrint(generics.RetrieveAPIView, StockItemReportMixin, ReportPrintMixin): - """ - API endpoint for printing a TestReport object - """ + """API endpoint for printing a TestReport object.""" queryset = TestReport.objects.all() serializer_class = TestReportSerializer def get(self, request, *args, **kwargs): - """ - Check if valid stock item(s) have been provided. - """ - + """Check if valid stock item(s) have been provided.""" items = self.get_items() return self.print(request, items) class BOMReportList(ReportListView, PartReportMixin): - """ - API endpoint for viewing a list of BillOfMaterialReport objects. + """API endpoint for viewing a list of BillOfMaterialReport objects. Filterably by: @@ -387,7 +346,7 @@ class BOMReportList(ReportListView, PartReportMixin): serializer_class = BOMReportSerializer def filter_queryset(self, queryset): - + """Custom queryset filtering""" queryset = super().filter_queryset(queryset) # List of Part objects to match against @@ -436,35 +395,27 @@ class BOMReportList(ReportListView, PartReportMixin): class BOMReportDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for a single BillOfMaterialReport object - """ + """API endpoint for a single BillOfMaterialReport object.""" queryset = BillOfMaterialsReport.objects.all() serializer_class = BOMReportSerializer class BOMReportPrint(generics.RetrieveAPIView, PartReportMixin, ReportPrintMixin): - """ - API endpoint for printing a BillOfMaterialReport object - """ + """API endpoint for printing a BillOfMaterialReport object.""" queryset = BillOfMaterialsReport.objects.all() serializer_class = BOMReportSerializer def get(self, request, *args, **kwargs): - """ - Check if valid part item(s) have been provided - """ - + """Check if valid part item(s) have been provided.""" parts = self.get_parts() return self.print(request, parts) class BuildReportList(ReportListView, BuildReportMixin): - """ - API endpoint for viewing a list of BuildReport objects. + """API endpoint for viewing a list of BuildReport objects. Can be filtered by: @@ -476,7 +427,7 @@ class BuildReportList(ReportListView, BuildReportMixin): serializer_class = BuildReportSerializer def filter_queryset(self, queryset): - + """Custom queryset filtering""" queryset = super().filter_queryset(queryset) # List of Build objects to match against @@ -526,45 +477,41 @@ class BuildReportList(ReportListView, BuildReportMixin): class BuildReportDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for a single BuildReport object - """ + """API endpoint for a single BuildReport object.""" queryset = BuildReport.objects.all() serializer_class = BuildReportSerializer class BuildReportPrint(generics.RetrieveAPIView, BuildReportMixin, ReportPrintMixin): - """ - API endpoint for printing a BuildReport - """ + """API endpoint for printing a BuildReport.""" queryset = BuildReport.objects.all() serializer_class = BuildReportSerializer def get(self, request, *ars, **kwargs): - + """Perform a GET action to print the report""" builds = self.get_builds() return self.print(request, builds) class PurchaseOrderReportList(ReportListView, OrderReportMixin): - + """API list endpoint for the PurchaseOrderReport model""" OrderModel = order.models.PurchaseOrder queryset = PurchaseOrderReport.objects.all() serializer_class = PurchaseOrderReportSerializer def filter_queryset(self, queryset): - + """Custom queryset filter for the PurchaseOrderReport list""" queryset = super().filter_queryset(queryset) orders = self.get_orders() if len(orders) > 0: """ - We wish to filter by purchase orders + We wish to filter by purchase orders. We need to compare the 'filters' string of each report, and see if it matches against each of the specified orders. @@ -607,18 +554,14 @@ class PurchaseOrderReportList(ReportListView, OrderReportMixin): class PurchaseOrderReportDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for a single PurchaseOrderReport object - """ + """API endpoint for a single PurchaseOrderReport object.""" queryset = PurchaseOrderReport.objects.all() serializer_class = PurchaseOrderReportSerializer class PurchaseOrderReportPrint(generics.RetrieveAPIView, OrderReportMixin, ReportPrintMixin): - """ - API endpoint for printing a PurchaseOrderReport object - """ + """API endpoint for printing a PurchaseOrderReport object.""" OrderModel = order.models.PurchaseOrder @@ -626,28 +569,28 @@ class PurchaseOrderReportPrint(generics.RetrieveAPIView, OrderReportMixin, Repor serializer_class = PurchaseOrderReportSerializer def get(self, request, *args, **kwargs): - + """Perform GET request to print the report""" orders = self.get_orders() return self.print(request, orders) class SalesOrderReportList(ReportListView, OrderReportMixin): - + """API list endpoint for the SalesOrderReport model""" OrderModel = order.models.SalesOrder queryset = SalesOrderReport.objects.all() serializer_class = SalesOrderReportSerializer def filter_queryset(self, queryset): - + """Custom queryset filtering for the SalesOrderReport API list""" queryset = super().filter_queryset(queryset) orders = self.get_orders() if len(orders) > 0: """ - We wish to filter by purchase orders + We wish to filter by purchase orders. We need to compare the 'filters' string of each report, and see if it matches against each of the specified orders. @@ -690,18 +633,14 @@ class SalesOrderReportList(ReportListView, OrderReportMixin): class SalesOrderReportDetail(generics.RetrieveUpdateDestroyAPIView): - """ - API endpoint for a single SalesOrderReport object - """ + """API endpoint for a single SalesOrderReport object.""" queryset = SalesOrderReport.objects.all() serializer_class = SalesOrderReportSerializer class SalesOrderReportPrint(generics.RetrieveAPIView, OrderReportMixin, ReportPrintMixin): - """ - API endpoint for printing a PurchaseOrderReport object - """ + """API endpoint for printing a PurchaseOrderReport object.""" OrderModel = order.models.SalesOrder @@ -709,7 +648,7 @@ class SalesOrderReportPrint(generics.RetrieveAPIView, OrderReportMixin, ReportPr serializer_class = SalesOrderReportSerializer def get(self, request, *args, **kwargs): - + """Perform a GET request to print the report""" orders = self.get_orders() return self.print(request, orders) diff --git a/InvenTree/report/apps.py b/InvenTree/report/apps.py index cc7fd6eded..9e91627e62 100644 --- a/InvenTree/report/apps.py +++ b/InvenTree/report/apps.py @@ -1,3 +1,5 @@ +"""Config options for the 'report' app""" + import logging import os import shutil @@ -11,22 +13,17 @@ logger = logging.getLogger("inventree") class ReportConfig(AppConfig): + """Configuration class for the 'report' app""" name = 'report' def ready(self): - """ - This function is called whenever the report app is loaded - """ - + """This function is called whenever the report app is loaded.""" if canAppAccessDatabase(allow_test=True): self.create_default_test_reports() self.create_default_build_reports() def create_default_reports(self, model, reports): - """ - Copy defualt report files across to the media directory. - """ - + """Copy defualt report files across to the media directory.""" # Source directory for report templates src_dir = os.path.join( os.path.dirname(os.path.realpath(__file__)), @@ -82,11 +79,7 @@ class ReportConfig(AppConfig): pass def create_default_test_reports(self): - """ - Create database entries for the default TestReport templates, - if they do not already exist - """ - + """Create database entries for the default TestReport templates, if they do not already exist.""" try: from .models import TestReport except: # pragma: no cover @@ -105,11 +98,7 @@ class ReportConfig(AppConfig): self.create_default_reports(TestReport, reports) def create_default_build_reports(self): - """ - Create database entries for the default BuildReport templates - (if they do not already exist) - """ - + """Create database entries for the default BuildReport templates (if they do not already exist)""" try: from .models import BuildReport except: # pragma: no cover diff --git a/InvenTree/report/models.py b/InvenTree/report/models.py index bc852b4faa..453ade931f 100644 --- a/InvenTree/report/models.py +++ b/InvenTree/report/models.py @@ -1,6 +1,4 @@ -""" -Report template model definitions -""" +"""Report template model definitions.""" import datetime import logging @@ -36,94 +34,78 @@ logger = logging.getLogger("inventree") def rename_template(instance, filename): - """ - Helper function for 'renaming' uploaded report files. + """Helper function for 'renaming' uploaded report files. + Pass responsibility back to the calling class, to ensure that files are uploaded to the correct directory. """ - return instance.rename_file(filename) def validate_stock_item_report_filters(filters): - """ - Validate filter string against StockItem model - """ - + """Validate filter string against StockItem model.""" return validateFilterString(filters, model=stock.models.StockItem) def validate_part_report_filters(filters): - """ - Validate filter string against Part model - """ - + """Validate filter string against Part model.""" return validateFilterString(filters, model=part.models.Part) def validate_build_report_filters(filters): - """ - Validate filter string against Build model - """ - + """Validate filter string against Build model.""" return validateFilterString(filters, model=build.models.Build) def validate_purchase_order_filters(filters): - """ - Validate filter string against PurchaseOrder model - """ - + """Validate filter string against PurchaseOrder model.""" return validateFilterString(filters, model=order.models.PurchaseOrder) def validate_sales_order_filters(filters): - """ - Validate filter string against SalesOrder model - """ - + """Validate filter string against SalesOrder model.""" return validateFilterString(filters, model=order.models.SalesOrder) class WeasyprintReportMixin(WeasyTemplateResponseMixin): - """ - Class for rendering a HTML template to a PDF. - """ + """Class for rendering a HTML template to a PDF.""" pdf_filename = 'report.pdf' pdf_attachment = True def __init__(self, request, template, **kwargs): - + """Initialize the report mixin with some standard attributes""" self.request = request self.template_name = template self.pdf_filename = kwargs.get('filename', 'report.pdf') class ReportBase(models.Model): - """ - Base class for uploading html templates - """ + """Base class for uploading html templates.""" class Meta: + """Metaclass options. Abstract ensures no database table is created.""" + abstract = True def save(self, *args, **kwargs): - + """Perform additional actions when the report is saved""" # Increment revision number self.revision += 1 super().save() def __str__(self): + """Format a string representation of a report instance""" return "{n} - {d}".format(n=self.name, d=self.description) @classmethod def getSubdir(cls): + """Return the subdirectory where template files for this report model will be located.""" return '' def rename_file(self, filename): - # Function for renaming uploaded file + """Function for renaming uploaded file""" filename = os.path.basename(filename) @@ -147,15 +129,15 @@ class ReportBase(models.Model): @property def extension(self): + """Return the filename extension of the associated template file""" return os.path.splitext(self.template.name)[1].lower() @property def template_name(self): - """ - Returns the file system path to the template file. + """Returns the file system path to the template file. + Required for passing the file to an external process """ - template = self.template.name template = template.replace('/', os.path.sep) template = template.replace('\\', os.path.sep) @@ -192,28 +174,20 @@ class ReportBase(models.Model): class ReportTemplateBase(ReportBase): - """ - Reporting template model. + """Reporting template model. Able to be passed context data - """ # Pass a single top-level object to the report template object_to_print = None def get_context_data(self, request): - """ - Supply context data to the template for rendering - """ - + """Supply context data to the template for rendering.""" return {} def context(self, request): - """ - All context to be passed to the renderer. - """ - + """All context to be passed to the renderer.""" # Generate custom context data based on the particular report subclass context = self.get_context_data(request) @@ -230,10 +204,7 @@ class ReportTemplateBase(ReportBase): return context def generate_filename(self, request, **kwargs): - """ - Generate a filename for this report - """ - + """Generate a filename for this report.""" template_string = Template(self.filename_pattern) ctx = self.context(request) @@ -243,21 +214,17 @@ class ReportTemplateBase(ReportBase): return template_string.render(context) def render_as_string(self, request, **kwargs): - """ - Render the report to a HTML string. + """Render the report to a HTML string. Useful for debug mode (viewing generated code) """ - return render_to_string(self.template_name, self.context(request), request) def render(self, request, **kwargs): - """ - Render the template to a PDF file. + """Render the template to a PDF file. Uses django-weasyprint plugin to render HTML template against Weasyprint """ - # TODO: Support custom filename generation! # filename = kwargs.get('filename', 'report.pdf') @@ -288,20 +255,22 @@ class ReportTemplateBase(ReportBase): ) class Meta: + """Metaclass options. Abstract ensures no database table is created.""" + abstract = True class TestReport(ReportTemplateBase): - """ - Render a TestReport against a StockItem object. - """ + """Render a TestReport against a StockItem object.""" @staticmethod def get_api_url(): + """Return the API URL associated with the TestReport model""" return reverse('api-stockitem-testreport-list') @classmethod def getSubdir(cls): + """Return the subdirectory where TestReport templates are located""" return 'test' filters = models.CharField( @@ -321,10 +290,7 @@ class TestReport(ReportTemplateBase): ) def matches_stock_item(self, item): - """ - Test if this report template matches a given StockItem objects - """ - + """Test if this report template matches a given StockItem objects.""" try: filters = validateFilterString(self.filters) items = stock.models.StockItem.objects.filter(**filters) @@ -337,7 +303,7 @@ class TestReport(ReportTemplateBase): return items.exists() def get_context_data(self, request): - + """Return custom context data for the TestReport template""" stock_item = self.object_to_print return { @@ -352,16 +318,16 @@ class TestReport(ReportTemplateBase): class BuildReport(ReportTemplateBase): - """ - Build order / work order report - """ + """Build order / work order report.""" @staticmethod def get_api_url(): + """Return the API URL associated with the BuildReport model""" return reverse('api-build-report-list') @classmethod def getSubdir(cls): + """Return the subdirectory where BuildReport templates are located""" return 'build' filters = models.CharField( @@ -375,10 +341,7 @@ class BuildReport(ReportTemplateBase): ) def get_context_data(self, request): - """ - Custom context data for the build report - """ - + """Custom context data for the build report.""" my_build = self.object_to_print if type(my_build) != build.models.Build: @@ -395,16 +358,16 @@ class BuildReport(ReportTemplateBase): class BillOfMaterialsReport(ReportTemplateBase): - """ - Render a Bill of Materials against a Part object - """ + """Render a Bill of Materials against a Part object.""" @staticmethod def get_api_url(): + """Return the API URL associated with the BillOfMaterialsReport model""" return reverse('api-bom-report-list') @classmethod def getSubdir(cls): + """Retun the directory where BillOfMaterialsReport templates are located""" return 'bom' filters = models.CharField( @@ -418,7 +381,7 @@ class BillOfMaterialsReport(ReportTemplateBase): ) def get_context_data(self, request): - + """Return custom context data for the BillOfMaterialsReport template""" part = self.object_to_print return { @@ -429,16 +392,16 @@ class BillOfMaterialsReport(ReportTemplateBase): class PurchaseOrderReport(ReportTemplateBase): - """ - Render a report against a PurchaseOrder object - """ + """Render a report against a PurchaseOrder object.""" @staticmethod def get_api_url(): + """Return the API URL associated with the PurchaseOrderReport model""" return reverse('api-po-report-list') @classmethod def getSubdir(cls): + """Return the directory where PurchaseOrderReport templates are stored""" return 'purchaseorder' filters = models.CharField( @@ -452,7 +415,7 @@ class PurchaseOrderReport(ReportTemplateBase): ) def get_context_data(self, request): - + """Return custom context data for the PurchaseOrderReport template""" order = self.object_to_print return { @@ -468,16 +431,16 @@ class PurchaseOrderReport(ReportTemplateBase): class SalesOrderReport(ReportTemplateBase): - """ - Render a report against a SalesOrder object - """ + """Render a report against a SalesOrder object.""" @staticmethod def get_api_url(): + """Return the API URL associated with the SalesOrderReport model""" return reverse('api-so-report-list') @classmethod def getSubdir(cls): + """Retun the subdirectory where SalesOrderReport templates are located""" return 'salesorder' filters = models.CharField( @@ -491,7 +454,7 @@ class SalesOrderReport(ReportTemplateBase): ) def get_context_data(self, request): - + """Return custom context data for a SalesOrderReport template""" order = self.object_to_print return { @@ -507,6 +470,7 @@ class SalesOrderReport(ReportTemplateBase): def rename_snippet(instance, filename): + """Function to rename a report snippet once uploaded""" filename = os.path.basename(filename) @@ -530,9 +494,7 @@ def rename_snippet(instance, filename): class ReportSnippet(models.Model): - """ - Report template 'snippet' which can be used to make templates - that can then be included in other reports. + """Report template 'snippet' which can be used to make templates that can then be included in other reports. Useful for 'common' template actions, sub-templates, etc """ @@ -548,6 +510,7 @@ class ReportSnippet(models.Model): def rename_asset(instance, filename): + """Function to rename an asset file when uploaded""" filename = os.path.basename(filename) @@ -567,14 +530,15 @@ def rename_asset(instance, filename): class ReportAsset(models.Model): - """ - Asset file for use in report templates. + """Asset file for use in report templates. + For example, an image to use in a header file. Uploaded asset files appear in MEDIA_ROOT/report/assets, and can be loaded in a template using the {% report_asset %} tag. """ def __str__(self): + """String representation of a ReportAsset instance""" return os.path.basename(self.asset.name) asset = models.FileField( diff --git a/InvenTree/report/serializers.py b/InvenTree/report/serializers.py index daa1e254d5..50e9592373 100644 --- a/InvenTree/report/serializers.py +++ b/InvenTree/report/serializers.py @@ -1,3 +1,4 @@ +"""API serializers for the reporting models""" from InvenTree.serializers import (InvenTreeAttachmentSerializerField, InvenTreeModelSerializer) @@ -7,10 +8,13 @@ from .models import (BillOfMaterialsReport, BuildReport, PurchaseOrderReport, class TestReportSerializer(InvenTreeModelSerializer): + """Serializer class for the TestReport model""" template = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass options.""" + model = TestReport fields = [ 'pk', @@ -23,10 +27,13 @@ class TestReportSerializer(InvenTreeModelSerializer): class BuildReportSerializer(InvenTreeModelSerializer): + """Serializer class for the BuildReport model""" template = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass options.""" + model = BuildReport fields = [ 'pk', @@ -39,10 +46,12 @@ class BuildReportSerializer(InvenTreeModelSerializer): class BOMReportSerializer(InvenTreeModelSerializer): - + """Serializer class for the BillOfMaterialsReport model""" template = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass options.""" + model = BillOfMaterialsReport fields = [ 'pk', @@ -55,10 +64,12 @@ class BOMReportSerializer(InvenTreeModelSerializer): class PurchaseOrderReportSerializer(InvenTreeModelSerializer): - + """Serializer class for the PurchaseOrdeReport model""" template = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass options.""" + model = PurchaseOrderReport fields = [ 'pk', @@ -71,10 +82,12 @@ class PurchaseOrderReportSerializer(InvenTreeModelSerializer): class SalesOrderReportSerializer(InvenTreeModelSerializer): - + """Serializer class for the SalesOrderReport model""" template = InvenTreeAttachmentSerializerField(required=True) class Meta: + """Metaclass options.""" + model = SalesOrderReport fields = [ 'pk', diff --git a/InvenTree/report/templatetags/barcode.py b/InvenTree/report/templatetags/barcode.py index 8d8397d50a..0f1d885201 100644 --- a/InvenTree/report/templatetags/barcode.py +++ b/InvenTree/report/templatetags/barcode.py @@ -1,6 +1,4 @@ -""" -Template tags for rendering various barcodes -""" +"""Template tags for rendering various barcodes.""" import base64 from io import BytesIO @@ -14,12 +12,10 @@ register = template.Library() def image_data(img, fmt='PNG'): - """ - Convert an image into HTML renderable data + """Convert an image into HTML renderable data. Returns a string ``data:image/FMT;base64,xxxxxxxxx`` which can be rendered to an tag """ - buffered = BytesIO() img.save(buffered, format=fmt) @@ -30,8 +26,7 @@ def image_data(img, fmt='PNG'): @register.simple_tag() def qrcode(data, **kwargs): - """ - Return a byte-encoded QR code image + """Return a byte-encoded QR code image. Optional kwargs --------------- @@ -39,7 +34,6 @@ def qrcode(data, **kwargs): fill_color: Fill color (default = black) back_color: Background color (default = white) """ - # Construct "default" values params = dict( box_size=20, @@ -63,10 +57,7 @@ def qrcode(data, **kwargs): @register.simple_tag() def barcode(data, barcode_class='code128', **kwargs): - """ - Render a barcode - """ - + """Render a barcode.""" constructor = python_barcode.get_barcode_class(barcode_class) data = str(data).zfill(constructor.digits) diff --git a/InvenTree/report/templatetags/report.py b/InvenTree/report/templatetags/report.py index 8a14501933..9e82202196 100644 --- a/InvenTree/report/templatetags/report.py +++ b/InvenTree/report/templatetags/report.py @@ -1,6 +1,4 @@ -""" -Custom template tags for report generation -""" +"""Custom template tags for report generation.""" import os @@ -19,10 +17,7 @@ register = template.Library() @register.simple_tag() def asset(filename): - """ - Return fully-qualified path for an upload report asset file. - """ - + """Return fully-qualified path for an upload report asset file.""" # If in debug mode, return URL to the image, not a local file debug_mode = InvenTreeSetting.get_setting('REPORT_DEBUG_MODE') @@ -38,10 +33,7 @@ def asset(filename): @register.simple_tag() def part_image(part): - """ - Return a fully-qualified path for a part image - """ - + """Return a fully-qualified path for a part image.""" # If in debug mode, return URL to the image, not a local file debug_mode = InvenTreeSetting.get_setting('REPORT_DEBUG_MODE') @@ -75,10 +67,7 @@ def part_image(part): @register.simple_tag() def company_image(company): - """ - Return a fully-qualified path for a company image - """ - + """Return a fully-qualified path for a company image.""" # If in debug mode, return the URL to the image, not a local file debug_mode = InvenTreeSetting.get_setting('REPORT_DEBUG_MODE') @@ -108,15 +97,13 @@ def company_image(company): @register.simple_tag() def internal_link(link, text): - """ - Make a href which points to an InvenTree URL. + """Make a href which points to an InvenTree URL. Important Note: This only works if the INVENTREE_BASE_URL parameter is set! If the INVENTREE_BASE_URL parameter is not configured, the text will be returned (unlinked) """ - text = str(text) url = InvenTree.helpers.construct_absolute_url(link) diff --git a/InvenTree/report/tests.py b/InvenTree/report/tests.py index 38fbb491bf..e6a89a0304 100644 --- a/InvenTree/report/tests.py +++ b/InvenTree/report/tests.py @@ -1,3 +1,4 @@ +"""Unit testing for the various report models""" import os import shutil @@ -14,7 +15,7 @@ from stock.models import StockItem class ReportTest(InvenTreeAPITestCase): - + """Base class for unit testing reporting models""" fixtures = [ 'category', 'part', @@ -32,14 +33,8 @@ class ReportTest(InvenTreeAPITestCase): detail_url = None print_url = None - def setUp(self): - super().setUp() - def copyReportTemplate(self, filename, description): - """ - Copy the provided report template into the required media directory - """ - + """Copy the provided report template into the required media directory.""" src_dir = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'templates', @@ -81,10 +76,7 @@ class ReportTest(InvenTreeAPITestCase): ) def test_list_endpoint(self): - """ - Test that the LIST endpoint works for each report - """ - + """Test that the LIST endpoint works for each report.""" if not self.list_url: return @@ -121,7 +113,7 @@ class ReportTest(InvenTreeAPITestCase): class TestReportTest(ReportTest): - + """Unit testing class for the stock item TestReport model""" model = report_models.TestReport list_url = 'api-stockitem-testreport-list' @@ -129,16 +121,13 @@ class TestReportTest(ReportTest): print_url = 'api-stockitem-testreport-print' def setUp(self): - + """Setup function for the stock item TestReport""" self.copyReportTemplate('inventree_test_report.html', 'stock item test report') return super().setUp() def test_print(self): - """ - Printing tests for the TestReport - """ - + """Printing tests for the TestReport.""" report = self.model.objects.first() url = reverse(self.print_url, kwargs={'pk': report.pk}) @@ -163,7 +152,7 @@ class TestReportTest(ReportTest): class BuildReportTest(ReportTest): - + """Unit test class for the BuildReport model""" model = report_models.BuildReport list_url = 'api-build-report-list' @@ -171,16 +160,13 @@ class BuildReportTest(ReportTest): print_url = 'api-build-report-print' def setUp(self): - + """Setup unit testing functions""" self.copyReportTemplate('inventree_build_order.html', 'build order template') return super().setUp() def test_print(self): - """ - Printing tests for the BuildReport - """ - + """Printing tests for the BuildReport.""" report = self.model.objects.first() url = reverse(self.print_url, kwargs={'pk': report.pk}) @@ -216,7 +202,7 @@ class BuildReportTest(ReportTest): class BOMReportTest(ReportTest): - + """Unit test class fot the BillOfMaterialsReport model""" model = report_models.BillOfMaterialsReport list_url = 'api-bom-report-list' @@ -225,7 +211,7 @@ class BOMReportTest(ReportTest): class PurchaseOrderReportTest(ReportTest): - + """Unit test class fort he PurchaseOrderReport model""" model = report_models.PurchaseOrderReport list_url = 'api-po-report-list' @@ -234,7 +220,7 @@ class PurchaseOrderReportTest(ReportTest): class SalesOrderReportTest(ReportTest): - + """Unit test class for the SalesOrderReport model""" model = report_models.SalesOrderReport list_url = 'api-so-report-list' diff --git a/InvenTree/script/translation_stats.py b/InvenTree/script/translation_stats.py index e96708d1f2..e46d682fae 100644 --- a/InvenTree/script/translation_stats.py +++ b/InvenTree/script/translation_stats.py @@ -1,6 +1,4 @@ -""" -This script calculates translation coverage for various languages -""" +"""This script calculates translation coverage for various languages.""" import json import os @@ -8,10 +6,7 @@ import sys def calculate_coverage(filename): - """ - Calculate translation coverage for a .po file - """ - + """Calculate translation coverage for a .po file.""" with open(filename, 'r') as f: lines = f.readlines() diff --git a/InvenTree/stock/__init__.py b/InvenTree/stock/__init__.py index 6970329be1..524da0600c 100644 --- a/InvenTree/stock/__init__.py +++ b/InvenTree/stock/__init__.py @@ -1,5 +1,4 @@ -""" -The Stock module is responsible for Stock management. +"""The Stock module is responsible for Stock management. It includes models for: diff --git a/InvenTree/stock/admin.py b/InvenTree/stock/admin.py index 85d7b7afe0..f3c56553c5 100644 --- a/InvenTree/stock/admin.py +++ b/InvenTree/stock/admin.py @@ -1,3 +1,5 @@ +"""Admin for stock app.""" + from django.contrib import admin import import_export.widgets as widgets @@ -15,13 +17,15 @@ from .models import (StockItem, StockItemAttachment, StockItemTestResult, class LocationResource(ModelResource): - """ Class for managing StockLocation data import/export """ + """Class for managing StockLocation data import/export.""" parent = Field(attribute='parent', widget=widgets.ForeignKeyWidget(StockLocation)) parent_name = Field(attribute='parent__name', readonly=True) class Meta: + """Metaclass options.""" + model = StockLocation skip_unchanged = True report_skipped = False @@ -34,7 +38,7 @@ class LocationResource(ModelResource): ] def after_import(self, dataset, result, using_transactions, dry_run, **kwargs): - + """Rebuild after import to keep tree intact.""" super().after_import(dataset, result, using_transactions, dry_run, **kwargs) # Rebuild the StockLocation tree(s) @@ -42,13 +46,12 @@ class LocationResource(ModelResource): class LocationInline(admin.TabularInline): - """ - Inline for sub-locations - """ + """Inline for sub-locations.""" model = StockLocation class LocationAdmin(ImportExportModelAdmin): + """Admin class for Location.""" resource_class = LocationResource @@ -66,7 +69,7 @@ class LocationAdmin(ImportExportModelAdmin): class StockItemResource(ModelResource): - """ Class for managing StockItem data import/export """ + """Class for managing StockItem data import/export.""" # Custom managers for ForeignKey fields part = Field(attribute='part', widget=widgets.ForeignKeyWidget(Part)) @@ -103,13 +106,15 @@ class StockItemResource(ModelResource): stocktake_date = Field(attribute='stocktake_date', widget=widgets.DateWidget()) def after_import(self, dataset, result, using_transactions, dry_run, **kwargs): - + """Rebuild after import to keep tree intact.""" super().after_import(dataset, result, using_transactions, dry_run, **kwargs) # Rebuild the StockItem tree(s) StockItem.objects.rebuild() class Meta: + """Metaclass options.""" + model = StockItem skip_unchanged = True report_skipped = False @@ -124,6 +129,7 @@ class StockItemResource(ModelResource): class StockItemAdmin(ImportExportModelAdmin): + """Admin class for StockItem.""" resource_class = StockItemResource @@ -152,6 +158,7 @@ class StockItemAdmin(ImportExportModelAdmin): class StockAttachmentAdmin(admin.ModelAdmin): + """Admin class for StockAttachment.""" list_display = ('stock_item', 'attachment', 'comment') @@ -161,6 +168,8 @@ class StockAttachmentAdmin(admin.ModelAdmin): class StockTrackingAdmin(ImportExportModelAdmin): + """Admin class for StockTracking.""" + list_display = ('item', 'date', 'label') autocomplete_fields = [ @@ -169,6 +178,7 @@ class StockTrackingAdmin(ImportExportModelAdmin): class StockItemTestResultAdmin(admin.ModelAdmin): + """Admin class for StockItemTestResult.""" list_display = ('stock_item', 'test', 'result', 'value') diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index 4bc21df5a9..536ac66a4e 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -1,6 +1,4 @@ -""" -JSON API for the Stock app -""" +"""JSON API for the Stock app.""" from collections import OrderedDict from datetime import datetime, timedelta @@ -39,7 +37,7 @@ from stock.models import (StockItem, StockItemAttachment, StockItemTestResult, class StockDetail(generics.RetrieveUpdateDestroyAPIView): - """ API detail endpoint for Stock object + """API detail endpoint for Stock object. get: Return a single StockItem object @@ -55,21 +53,21 @@ class StockDetail(generics.RetrieveUpdateDestroyAPIView): serializer_class = StockSerializers.StockItemSerializer def get_queryset(self, *args, **kwargs): - + """Annotate queryset.""" queryset = super().get_queryset(*args, **kwargs) queryset = StockSerializers.StockItemSerializer.annotate_queryset(queryset) return queryset def get_serializer_context(self): - + """Extend serializer context.""" ctx = super().get_serializer_context() ctx['user'] = getattr(self.request, 'user', None) return ctx def get_serializer(self, *args, **kwargs): - + """Set context before returning serializer.""" kwargs['part_detail'] = True kwargs['location_detail'] = True kwargs['supplier_part_detail'] = True @@ -80,19 +78,20 @@ class StockDetail(generics.RetrieveUpdateDestroyAPIView): class StockMetadata(generics.RetrieveUpdateAPIView): - """API endpoint for viewing / updating StockItem metadata""" + """API endpoint for viewing / updating StockItem metadata.""" def get_serializer(self, *args, **kwargs): + """Return serializer.""" return MetadataSerializer(StockItem, *args, **kwargs) queryset = StockItem.objects.all() class StockItemContextMixin: - """ Mixin class for adding StockItem object to serializer context """ + """Mixin class for adding StockItem object to serializer context.""" def get_serializer_context(self): - + """Extend serializer context.""" context = super().get_serializer_context() context['request'] = self.request @@ -105,17 +104,14 @@ class StockItemContextMixin: class StockItemSerialize(StockItemContextMixin, generics.CreateAPIView): - """ - API endpoint for serializing a stock item - """ + """API endpoint for serializing a stock item.""" queryset = StockItem.objects.none() serializer_class = StockSerializers.SerializeStockItemSerializer class StockItemInstall(StockItemContextMixin, generics.CreateAPIView): - """ - API endpoint for installing a particular stock item into this stock item. + """API endpoint for installing a particular stock item into this stock item. - stock_item.part must be in the BOM for this part - stock_item must currently be "in stock" @@ -127,17 +123,14 @@ class StockItemInstall(StockItemContextMixin, generics.CreateAPIView): class StockItemUninstall(StockItemContextMixin, generics.CreateAPIView): - """ - API endpoint for removing (uninstalling) items from this item - """ + """API endpoint for removing (uninstalling) items from this item.""" queryset = StockItem.objects.none() serializer_class = StockSerializers.UninstallStockItemSerializer class StockAdjustView(generics.CreateAPIView): - """ - A generic class for handling stocktake actions. + """A generic class for handling stocktake actions. Subclasses exist for: @@ -150,80 +143,66 @@ class StockAdjustView(generics.CreateAPIView): queryset = StockItem.objects.none() def get_serializer_context(self): - + """Extend serializer context.""" context = super().get_serializer_context() - context['request'] = self.request return context class StockCount(StockAdjustView): - """ - Endpoint for counting stock (performing a stocktake). - """ + """Endpoint for counting stock (performing a stocktake).""" serializer_class = StockSerializers.StockCountSerializer class StockAdd(StockAdjustView): - """ - Endpoint for adding a quantity of stock to an existing StockItem - """ + """Endpoint for adding a quantity of stock to an existing StockItem.""" serializer_class = StockSerializers.StockAddSerializer class StockRemove(StockAdjustView): - """ - Endpoint for removing a quantity of stock from an existing StockItem. - """ + """Endpoint for removing a quantity of stock from an existing StockItem.""" serializer_class = StockSerializers.StockRemoveSerializer class StockTransfer(StockAdjustView): - """ - API endpoint for performing stock movements - """ + """API endpoint for performing stock movements.""" serializer_class = StockSerializers.StockTransferSerializer class StockAssign(generics.CreateAPIView): - """ - API endpoint for assigning stock to a particular customer - """ + """API endpoint for assigning stock to a particular customer.""" queryset = StockItem.objects.all() serializer_class = StockSerializers.StockAssignmentSerializer def get_serializer_context(self): - + """Extend serializer context.""" ctx = super().get_serializer_context() - ctx['request'] = self.request return ctx class StockMerge(generics.CreateAPIView): - """ - API endpoint for merging multiple stock items - """ + """API endpoint for merging multiple stock items.""" queryset = StockItem.objects.none() serializer_class = StockSerializers.StockMergeSerializer def get_serializer_context(self): + """Extend serializer context.""" ctx = super().get_serializer_context() ctx['request'] = self.request return ctx class StockLocationList(generics.ListCreateAPIView): - """ - API endpoint for list view of StockLocation objects: + """API endpoint for list view of StockLocation objects. - GET: Return list of StockLocation objects - POST: Create a new StockLocation @@ -233,11 +212,7 @@ class StockLocationList(generics.ListCreateAPIView): serializer_class = StockSerializers.LocationSerializer def filter_queryset(self, queryset): - """ - Custom filtering: - - Allow filtering by "null" parent to retrieve top-level stock locations - """ - + """Custom filtering: - Allow filtering by "null" parent to retrieve top-level stock locations.""" queryset = super().filter_queryset(queryset) params = self.request.query_params @@ -319,10 +294,7 @@ class StockLocationList(generics.ListCreateAPIView): class StockLocationTree(generics.ListAPIView): - """ - API endpoint for accessing a list of StockLocation objects, - ready for rendering as a tree - """ + """API endpoint for accessing a list of StockLocation objects, ready for rendering as a tree.""" queryset = StockLocation.objects.all() serializer_class = StockSerializers.LocationTreeSerializer @@ -337,9 +309,7 @@ class StockLocationTree(generics.ListAPIView): class StockFilter(rest_filters.FilterSet): - """ - FilterSet for StockItem LIST API - """ + """FilterSet for StockItem LIST API.""" # Part name filters name = rest_filters.CharFilter(label='Part name (case insensitive)', field_name='part__name', lookup_expr='iexact') @@ -361,7 +331,7 @@ class StockFilter(rest_filters.FilterSet): in_stock = rest_filters.BooleanFilter(label='In Stock', method='filter_in_stock') def filter_in_stock(self, queryset, name, value): - + """Filter by if item is in stock.""" if str2bool(value): queryset = queryset.filter(StockItem.IN_STOCK_FILTER) else: @@ -372,12 +342,10 @@ class StockFilter(rest_filters.FilterSet): available = rest_filters.BooleanFilter(label='Available', method='filter_available') def filter_available(self, queryset, name, value): - """ - Filter by whether the StockItem is "available" or not. + """Filter by whether the StockItem is "available" or not. Here, "available" means that the allocated quantity is less than the total quantity """ - if str2bool(value): # The 'quantity' field is greater than the calculated 'allocated' field queryset = queryset.filter(Q(quantity__gt=F('allocated'))) @@ -401,10 +369,7 @@ class StockFilter(rest_filters.FilterSet): serialized = rest_filters.BooleanFilter(label='Has serial number', method='filter_serialized') def filter_serialized(self, queryset, name, value): - """ - Filter by whether the StockItem has a serial number (or not) - """ - + """Filter by whether the StockItem has a serial number (or not).""" q = Q(serial=None) | Q(serial='') if str2bool(value): @@ -417,10 +382,7 @@ class StockFilter(rest_filters.FilterSet): has_batch = rest_filters.BooleanFilter(label='Has batch code', method='filter_has_batch') def filter_has_batch(self, queryset, name, value): - """ - Filter by whether the StockItem has a batch code (or not) - """ - + """Filter by whether the StockItem has a batch code (or not).""" q = Q(batch=None) | Q(batch='') if str2bool(value): @@ -433,12 +395,12 @@ class StockFilter(rest_filters.FilterSet): tracked = rest_filters.BooleanFilter(label='Tracked', method='filter_tracked') def filter_tracked(self, queryset, name, value): - """ - Filter by whether this stock item is *tracked*, meaning either: + """Filter by whether this stock item is *tracked*. + + Meaning either: - It has a serial number - It has a batch code """ - q_batch = Q(batch=None) | Q(batch='') q_serial = Q(serial=None) | Q(serial='') @@ -452,10 +414,7 @@ class StockFilter(rest_filters.FilterSet): installed = rest_filters.BooleanFilter(label='Installed in other stock item', method='filter_installed') def filter_installed(self, queryset, name, value): - """ - Filter stock items by "belongs_to" field being empty - """ - + """Filter stock items by "belongs_to" field being empty.""" if str2bool(value): queryset = queryset.exclude(belongs_to=None) else: @@ -466,7 +425,7 @@ class StockFilter(rest_filters.FilterSet): sent_to_customer = rest_filters.BooleanFilter(label='Sent to customer', method='filter_sent_to_customer') def filter_sent_to_customer(self, queryset, name, value): - + """Filter by sent to customer.""" if str2bool(value): queryset = queryset.exclude(customer=None) else: @@ -477,7 +436,7 @@ class StockFilter(rest_filters.FilterSet): depleted = rest_filters.BooleanFilter(label='Depleted', method='filter_depleted') def filter_depleted(self, queryset, name, value): - + """Filter by depleted items.""" if str2bool(value): queryset = queryset.filter(quantity__lte=0) else: @@ -488,9 +447,9 @@ class StockFilter(rest_filters.FilterSet): has_purchase_price = rest_filters.BooleanFilter(label='Has purchase price', method='filter_has_purchase_price') def filter_has_purchase_price(self, queryset, name, value): - + """Filter by having a purchase price.""" if str2bool(value): - queryset = queryset.exclude(purcahse_price=None) + queryset = queryset.exclude(purchase_price=None) else: queryset = queryset.filter(purchase_price=None) @@ -502,7 +461,7 @@ class StockFilter(rest_filters.FilterSet): class StockList(APIDownloadMixin, generics.ListCreateAPIView): - """ API endpoint for list view of Stock objects + """API endpoint for list view of Stock objects. - GET: Return a list of all StockItem objects (with optional query filters) - POST: Create a new StockItem @@ -513,22 +472,20 @@ class StockList(APIDownloadMixin, generics.ListCreateAPIView): filterset_class = StockFilter def get_serializer_context(self): - + """Extend serializer context.""" ctx = super().get_serializer_context() ctx['user'] = getattr(self.request, 'user', None) return ctx def create(self, request, *args, **kwargs): - """ - Create a new StockItem object via the API. + """Create a new StockItem object via the API. We override the default 'create' implementation. If a location is *not* specified, but the linked *part* has a default location, we can pre-fill the location automatically. """ - user = request.user # Copy the request data, to side-step "mutability" issues @@ -602,9 +559,7 @@ class StockList(APIDownloadMixin, generics.ListCreateAPIView): }) if serials is not None: - """ - If the stock item is going to be serialized, set the quantity to 1 - """ + """If the stock item is going to be serialized, set the quantity to 1.""" data['quantity'] = 1 # De-serialize the provided data @@ -643,8 +598,8 @@ class StockList(APIDownloadMixin, generics.ListCreateAPIView): return Response(response_data, status=status.HTTP_201_CREATED, headers=self.get_success_headers(serializer.data)) def download_queryset(self, queryset, export_format): - """ - Download this queryset as a file. + """Download this queryset as a file. + Uses the APIDownloadMixin mixin class """ dataset = StockItemResource().export(queryset=queryset) @@ -659,13 +614,10 @@ class StockList(APIDownloadMixin, generics.ListCreateAPIView): return DownloadFile(filedata, filename) def list(self, request, *args, **kwargs): - """ - Override the 'list' method, as the StockLocation objects - are very expensive to serialize. + """Override the 'list' method, as the StockLocation objects are very expensive to serialize. So, we fetch and serialize the required StockLocation objects only as required. """ - queryset = self.filter_queryset(self.get_queryset()) params = request.query_params @@ -767,7 +719,7 @@ class StockList(APIDownloadMixin, generics.ListCreateAPIView): return Response(data) def get_queryset(self, *args, **kwargs): - + """Annotate queryset before returning.""" queryset = super().get_queryset(*args, **kwargs) queryset = StockSerializers.StockItemSerializer.annotate_queryset(queryset) @@ -775,10 +727,7 @@ class StockList(APIDownloadMixin, generics.ListCreateAPIView): return queryset def filter_queryset(self, queryset): - """ - Custom filtering for the StockItem queryset - """ - + """Custom filtering for the StockItem queryset.""" params = self.request.query_params queryset = super().filter_queryset(queryset) @@ -1090,9 +1039,7 @@ class StockList(APIDownloadMixin, generics.ListCreateAPIView): class StockAttachmentList(generics.ListCreateAPIView, AttachmentMixin): - """ - API endpoint for listing (and creating) a StockItemAttachment (file upload) - """ + """API endpoint for listing (and creating) a StockItemAttachment (file upload).""" queryset = StockItemAttachment.objects.all() serializer_class = StockSerializers.StockItemAttachmentSerializer @@ -1109,27 +1056,21 @@ class StockAttachmentList(generics.ListCreateAPIView, AttachmentMixin): class StockAttachmentDetail(generics.RetrieveUpdateDestroyAPIView, AttachmentMixin): - """ - Detail endpoint for StockItemAttachment - """ + """Detail endpoint for StockItemAttachment.""" queryset = StockItemAttachment.objects.all() serializer_class = StockSerializers.StockItemAttachmentSerializer class StockItemTestResultDetail(generics.RetrieveUpdateDestroyAPIView): - """ - Detail endpoint for StockItemTestResult - """ + """Detail endpoint for StockItemTestResult.""" queryset = StockItemTestResult.objects.all() serializer_class = StockSerializers.StockItemTestResultSerializer class StockItemTestResultList(generics.ListCreateAPIView): - """ - API endpoint for listing (and creating) a StockItemTestResult object. - """ + """API endpoint for listing (and creating) a StockItemTestResult object.""" queryset = StockItemTestResult.objects.all() serializer_class = StockSerializers.StockItemTestResultSerializer @@ -1150,7 +1091,7 @@ class StockItemTestResultList(generics.ListCreateAPIView): ordering = 'date' def filter_queryset(self, queryset): - + """Filter by build or stock_item.""" params = self.request.query_params queryset = super().filter_queryset(queryset) @@ -1195,6 +1136,7 @@ class StockItemTestResultList(generics.ListCreateAPIView): return queryset def get_serializer(self, *args, **kwargs): + """Set context before returning serializer.""" try: kwargs['user_detail'] = str2bool(self.request.query_params.get('user_detail', False)) except: @@ -1205,13 +1147,11 @@ class StockItemTestResultList(generics.ListCreateAPIView): return self.serializer_class(*args, **kwargs) def perform_create(self, serializer): - """ - Create a new test result object. + """Create a new test result object. Also, check if an attachment was uploaded alongside the test result, and save it to the database if it were. """ - # Capture the user information test_result = serializer.save() test_result.user = self.request.user @@ -1219,16 +1159,14 @@ class StockItemTestResultList(generics.ListCreateAPIView): class StockTrackingDetail(generics.RetrieveAPIView): - """ - Detail API endpoint for StockItemTracking model - """ + """Detail API endpoint for StockItemTracking model.""" queryset = StockItemTracking.objects.all() serializer_class = StockSerializers.StockTrackingSerializer class StockTrackingList(generics.ListAPIView): - """ API endpoint for list view of StockItemTracking objects. + """API endpoint for list view of StockItemTracking objects. StockItemTracking objects are read-only (they are created by internal model functionality) @@ -1240,6 +1178,7 @@ class StockTrackingList(generics.ListAPIView): serializer_class = StockSerializers.StockTrackingSerializer def get_serializer(self, *args, **kwargs): + """Set context before returning serializer.""" try: kwargs['item_detail'] = str2bool(self.request.query_params.get('item_detail', False)) except: @@ -1255,7 +1194,7 @@ class StockTrackingList(generics.ListAPIView): return self.serializer_class(*args, **kwargs) def list(self, request, *args, **kwargs): - + """List all stock tracking entries.""" queryset = self.filter_queryset(self.get_queryset()) serializer = self.get_serializer(queryset, many=True) @@ -1320,12 +1259,11 @@ class StockTrackingList(generics.ListAPIView): return Response(data) def create(self, request, *args, **kwargs): - """ Create a new StockItemTracking object + """Create a new StockItemTracking object. Here we override the default 'create' implementation, to save the user information associated with the request object. """ - serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) @@ -1365,16 +1303,17 @@ class StockTrackingList(generics.ListAPIView): class LocationMetadata(generics.RetrieveUpdateAPIView): - """API endpoint for viewing / updating StockLocation metadata""" + """API endpoint for viewing / updating StockLocation metadata.""" def get_serializer(self, *args, **kwargs): + """Return serializer.""" return MetadataSerializer(StockLocation, *args, **kwargs) queryset = StockLocation.objects.all() class LocationDetail(generics.RetrieveUpdateDestroyAPIView): - """ API endpoint for detail view of StockLocation object + """API endpoint for detail view of StockLocation object. - GET: Return a single StockLocation object - PATCH: Update a StockLocation object diff --git a/InvenTree/stock/apps.py b/InvenTree/stock/apps.py index ca24ee7312..5c1d48545d 100644 --- a/InvenTree/stock/apps.py +++ b/InvenTree/stock/apps.py @@ -1,6 +1,8 @@ +"""AppConfig for stock app.""" from django.apps import AppConfig class StockConfig(AppConfig): + """AppConfig for stock app.""" name = 'stock' diff --git a/InvenTree/stock/forms.py b/InvenTree/stock/forms.py index 7e615b5eae..5eaf9d917a 100644 --- a/InvenTree/stock/forms.py +++ b/InvenTree/stock/forms.py @@ -1,6 +1,4 @@ -""" -Django Forms for interacting with Stock app -""" +"""Django Forms for interacting with Stock app.""" from InvenTree.forms import HelperForm @@ -8,13 +6,14 @@ from .models import StockItem, StockItemTracking class ReturnStockItemForm(HelperForm): - """ - Form for manually returning a StockItem into stock + """Form for manually returning a StockItem into stock. TODO: This could be a simple API driven form! """ class Meta: + """Metaclass options.""" + model = StockItem fields = [ 'location', @@ -22,13 +21,14 @@ class ReturnStockItemForm(HelperForm): class ConvertStockItemForm(HelperForm): - """ - Form for converting a StockItem to a variant of its current part. + """Form for converting a StockItem to a variant of its current part. TODO: Migrate this form to the modern API forms interface """ class Meta: + """Metaclass options.""" + model = StockItem fields = [ 'part' @@ -36,13 +36,14 @@ class ConvertStockItemForm(HelperForm): class TrackingEntryForm(HelperForm): - """ - Form for creating / editing a StockItemTracking object. + """Form for creating / editing a StockItemTracking object. Note: 2021-05-11 - This form is not currently used - should delete? """ class Meta: + """Metaclass options.""" + model = StockItemTracking fields = [ diff --git a/InvenTree/stock/models.py b/InvenTree/stock/models.py index 3faea832d9..292afae53d 100644 --- a/InvenTree/stock/models.py +++ b/InvenTree/stock/models.py @@ -1,6 +1,6 @@ -""" -Stock database model definitions -""" +"""Stock database model definitions.""" + +from __future__ import annotations import os from datetime import datetime, timedelta @@ -40,17 +40,17 @@ from users.models import Owner class StockLocation(MetadataMixin, InvenTreeTree): - """ Organization tree for StockItem objects + """Organization tree for StockItem objects. + A "StockLocation" can be considered a warehouse, or storage location Stock locations can be heirarchical as required """ def delete(self, *args, **kwargs): - """ - Custom model deletion routine, which updates any child locations or items. + """Custom model deletion routine, which updates any child locations or items. + This must be handled within a transaction.atomic(), otherwise the tree structure is damaged """ - with transaction.atomic(): parent = self.parent @@ -76,6 +76,7 @@ class StockLocation(MetadataMixin, InvenTreeTree): @staticmethod def get_api_url(): + """Return API url.""" return reverse('api-location-list') owner = models.ForeignKey(Owner, on_delete=models.SET_NULL, blank=True, null=True, @@ -84,12 +85,10 @@ class StockLocation(MetadataMixin, InvenTreeTree): related_name='stock_locations') def get_location_owner(self): - """ - Get the closest "owner" for this location. + """Get the closest "owner" for this location. Start at this location, and traverse "up" the location tree until we find an owner """ - for loc in self.get_ancestors(include_self=True, ascending=True): if loc.owner is not None: return loc.owner @@ -97,10 +96,7 @@ class StockLocation(MetadataMixin, InvenTreeTree): return None def check_ownership(self, user): - """ - Check if the user "owns" (is one of the owners of) the location. - """ - + """Check if the user "owns" (is one of the owners of) the location.""" # Superuser accounts automatically "own" everything if user.is_superuser: return True @@ -121,11 +117,11 @@ class StockLocation(MetadataMixin, InvenTreeTree): return user in owner.get_related_owners(include_group=True) def get_absolute_url(self): + """Return url for instance.""" return reverse('stock-location-detail', kwargs={'pk': self.id}) def format_barcode(self, **kwargs): - """ Return a JSON string for formatting a barcode for this StockLocation object """ - + """Return a JSON string for formatting a barcode for this StockLocation object.""" return InvenTree.helpers.MakeBarcode( 'stocklocation', self.pk, @@ -137,19 +133,20 @@ class StockLocation(MetadataMixin, InvenTreeTree): ) @property - def barcode(self): - """ - Brief payload data (e.g. for labels) + def barcode(self) -> str: + """Get Brief payload data (e.g. for labels). + + Returns: + str: Brief pyload data """ return self.format_barcode(brief=True) def get_stock_items(self, cascade=True): - """ Return a queryset for all stock items under this category. + """Return a queryset for all stock items under this category. Args: cascade: If True, also look under sublocations (default = True) """ - if cascade: query = StockItem.objects.filter(location__in=self.getUniqueChildren(include_self=True)) else: @@ -158,36 +155,26 @@ class StockLocation(MetadataMixin, InvenTreeTree): return query def stock_item_count(self, cascade=True): - """ Return the number of StockItem objects which live in or under this category - """ - + """Return the number of StockItem objects which live in or under this category.""" return self.get_stock_items(cascade).count() - def has_items(self, cascade=True): - """ Return True if there are StockItems existing in this category. - - Args: - cascade: If True, also search an sublocations (default = True) - """ - return self.stock_item_count(cascade) > 0 - @property def item_count(self): - """ Simply returns the number of stock items in this location. + """Simply returns the number of stock items in this location. + Required for tree view serializer. """ return self.stock_item_count() class StockItemManager(TreeManager): - """ - Custom database manager for the StockItem class. + """Custom database manager for the StockItem class. StockItem querysets will automatically prefetch related fields. """ def get_queryset(self): - + """Prefetch queryset to optimise db hits.""" return super().get_queryset().prefetch_related( 'belongs_to', 'build', @@ -205,13 +192,11 @@ class StockItemManager(TreeManager): def generate_batch_code(): - """ - Generate a default 'batch code' for a new StockItem. + """Generate a default 'batch code' for a new StockItem. This uses the value of the 'STOCK_BATCH_CODE_TEMPLATE' setting (if configured), which can be passed through a simple template. """ - batch_template = common.models.InvenTreeSetting.get_setting('STOCK_BATCH_CODE_TEMPLATE', '') now = datetime.now() @@ -231,8 +216,7 @@ def generate_batch_code(): class StockItem(MetadataMixin, MPTTModel): - """ - A StockItem object represents a quantity of physical instances of a part. + """A StockItem object represents a quantity of physical instances of a part. Attributes: parent: Link to another StockItem from which this StockItem was created @@ -263,13 +247,11 @@ class StockItem(MetadataMixin, MPTTModel): @staticmethod def get_api_url(): + """Return API url.""" return reverse('api-stock-list') def api_instance_filters(self): - """ - Custom API instance filters - """ - + """Custom API instance filters.""" return { 'parent': { 'exclude_tree': self.pk, @@ -290,11 +272,10 @@ class StockItem(MetadataMixin, MPTTModel): EXPIRED_FILTER = IN_STOCK_FILTER & ~Q(expiry_date=None) & Q(expiry_date__lt=datetime.now().date()) def update_serial_number(self): - """ - Update the 'serial_int' field, to be an integer representation of the serial number. + """Update the 'serial_int' field, to be an integer representation of the serial number. + This is used for efficient numerical sorting """ - serial = getattr(self, 'serial', '') # Default value if we cannot convert to an integer @@ -309,8 +290,7 @@ class StockItem(MetadataMixin, MPTTModel): self.serial_int = serial_int def get_next_serialized_item(self, include_variants=True, reverse=False): - """ - Get the "next" serial number for the part this stock item references. + """Get the "next" serial number for the part this stock item references. e.g. if this stock item has a serial number 100, we may return the stock item with serial number 101 @@ -322,9 +302,7 @@ class StockItem(MetadataMixin, MPTTModel): Returns: A StockItem object matching the requirements, or None - """ - if not self.serialized: return None @@ -358,13 +336,12 @@ class StockItem(MetadataMixin, MPTTModel): return None def save(self, *args, **kwargs): - """ - Save this StockItem to the database. Performs a number of checks: + """Save this StockItem to the database. + Performs a number of checks: - Unique serial number requirement - Adds a transaction note when the item is first created. """ - self.validate_unique() self.clean() @@ -428,21 +405,20 @@ class StockItem(MetadataMixin, MPTTModel): @property def status_label(self): - + """Return label.""" return StockStatus.label(self.status) @property def serialized(self): - """ Return True if this StockItem is serialized """ + """Return True if this StockItem is serialized.""" return self.serial is not None and len(str(self.serial).strip()) > 0 and self.quantity == 1 def validate_unique(self, exclude=None): - """ - Test that this StockItem is "unique". + """Test that this StockItem is "unique". + If the StockItem is serialized, the same serial number. cannot exist for the same part (or part tree). """ - super(StockItem, self).validate_unique(exclude) # If the serial number is set, make sure it is not a duplicate @@ -459,15 +435,13 @@ class StockItem(MetadataMixin, MPTTModel): raise ValidationError({"serial": _("StockItem with this serial number already exists")}) def clean(self): - """ Validate the StockItem object (separate to field validation) + """Validate the StockItem object (separate to field validation). The following validation checks are performed: - - The 'part' and 'supplier_part.part' fields cannot point to the same Part object - The 'part' does not belong to itself - Quantity must be 1 if the StockItem has a serial number """ - super().clean() # Strip serial number field @@ -557,22 +531,23 @@ class StockItem(MetadataMixin, MPTTModel): }) def get_absolute_url(self): + """Return url for instance.""" return reverse('stock-item-detail', kwargs={'pk': self.id}) def get_part_name(self): + """Returns part name.""" return self.part.full_name def format_barcode(self, **kwargs): - """ Return a JSON string for formatting a barcode for this StockItem. - Can be used to perform lookup of a stockitem using barcode + """Return a JSON string for formatting a barcode for this StockItem. + + Can be used to perform lookup of a stockitem using barcode. Contains the following data: - - { type: 'StockItem', stock_id: , part_id: } + `{ type: 'StockItem', stock_id: , part_id: }` Voltagile data (e.g. stock quantity) should be looked up using the InvenTree API (as it may change) """ - return InvenTree.helpers.MakeBarcode( "stockitem", self.id, @@ -586,8 +561,10 @@ class StockItem(MetadataMixin, MPTTModel): @property def barcode(self): - """ - Brief payload data (e.g. for labels) + """Get Brief payload data (e.g. for labels). + + Returns: + str: Brief pyload data """ return self.format_barcode(brief=True) @@ -753,11 +730,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def convert_to_variant(self, variant, user, notes=None): - """ - Convert this StockItem instance to a "variant", - i.e. change the "part" reference field - """ - + """Convert this StockItem instance to a "variant", i.e. change the "part" reference field.""" if not variant: # Ignore null values return @@ -779,14 +752,12 @@ class StockItem(MetadataMixin, MPTTModel): ) def get_item_owner(self): - """ - Return the closest "owner" for this StockItem. + """Return the closest "owner" for this StockItem. - If the item has an owner set, return that - If the item is "in stock", check the StockLocation - Otherwise, return None """ - if self.owner is not None: return self.owner @@ -799,10 +770,7 @@ class StockItem(MetadataMixin, MPTTModel): return None def check_ownership(self, user): - """ - Check if the user "owns" (or is one of the owners of) the item - """ - + """Check if the user "owns" (or is one of the owners of) the item.""" # Superuser accounts automatically "own" everything if user.is_superuser: return True @@ -821,16 +789,13 @@ class StockItem(MetadataMixin, MPTTModel): return user in owner.get_related_owners(include_group=True) def is_stale(self): - """ - Returns True if this Stock item is "stale". + """Returns True if this Stock item is "stale". To be "stale", the following conditions must be met: - - Expiry date is not None - Expiry date will "expire" within the configured stale date - The StockItem is otherwise "in stock" """ - if self.expiry_date is None: return False @@ -849,16 +814,13 @@ class StockItem(MetadataMixin, MPTTModel): return self.expiry_date < expiry_date def is_expired(self): - """ - Returns True if this StockItem is "expired". + """Returns True if this StockItem is "expired". To be "expired", the following conditions must be met: - - Expiry date is not None - Expiry date is "in the past" - The StockItem is otherwise "in stock" """ - if self.expiry_date is None: return False @@ -870,13 +832,12 @@ class StockItem(MetadataMixin, MPTTModel): return self.expiry_date < today def clearAllocations(self): - """ - Clear all order allocations for this StockItem: + """Clear all order allocations for this StockItem. + Clears: - SalesOrder allocations - Build allocations """ - # Delete outstanding SalesOrder allocations self.sales_order_allocations.all().delete() @@ -884,8 +845,7 @@ class StockItem(MetadataMixin, MPTTModel): self.allocations.all().delete() def allocateToCustomer(self, customer, quantity=None, order=None, user=None, notes=None): - """ - Allocate a StockItem to a customer. + """Allocate a StockItem to a customer. This action can be called by the following processes: - Completion of a SalesOrder @@ -898,7 +858,6 @@ class StockItem(MetadataMixin, MPTTModel): user: User that performed the action notes: Notes field """ - if quantity is None: quantity = self.quantity @@ -936,10 +895,7 @@ class StockItem(MetadataMixin, MPTTModel): return item def returnFromCustomer(self, location, user=None, **kwargs): - """ - Return stock item from customer, back into the specified location. - """ - + """Return stock item from customer, back into the specified location.""" notes = kwargs.get('notes', '') tracking_info = {} @@ -972,10 +928,7 @@ class StockItem(MetadataMixin, MPTTModel): infinite = models.BooleanField(default=False) def is_allocated(self): - """ - Return True if this StockItem is allocated to a SalesOrder or a Build - """ - + """Return True if this StockItem is allocated to a SalesOrder or a Build.""" # TODO - For now this only checks if the StockItem is allocated to a SalesOrder # TODO - In future, once the "build" is working better, check this too @@ -988,10 +941,7 @@ class StockItem(MetadataMixin, MPTTModel): return False def build_allocation_count(self): - """ - Return the total quantity allocated to builds - """ - + """Return the total quantity allocated to builds.""" query = self.allocations.aggregate(q=Coalesce(Sum('quantity'), Decimal(0))) total = query['q'] @@ -1002,10 +952,7 @@ class StockItem(MetadataMixin, MPTTModel): return total def sales_order_allocation_count(self): - """ - Return the total quantity allocated to SalesOrders - """ - + """Return the total quantity allocated to SalesOrders.""" query = self.sales_order_allocations.aggregate(q=Coalesce(Sum('quantity'), Decimal(0))) total = query['q'] @@ -1016,31 +963,25 @@ class StockItem(MetadataMixin, MPTTModel): return total def allocation_count(self): - """ - Return the total quantity allocated to builds or orders - """ - + """Return the total quantity allocated to builds or orders.""" bo = self.build_allocation_count() so = self.sales_order_allocation_count() return bo + so def unallocated_quantity(self): - """ - Return the quantity of this StockItem which is *not* allocated - """ - + """Return the quantity of this StockItem which is *not* allocated.""" return max(self.quantity - self.allocation_count(), 0) def can_delete(self): - """ Can this stock item be deleted? It can NOT be deleted under the following circumstances: + """Can this stock item be deleted? + It can NOT be deleted under the following circumstances: - Has installed stock items - Is installed inside another StockItem - It has been assigned to a SalesOrder - It has been assigned to a BuildOrder """ - if self.installed_item_count() > 0: return False @@ -1049,16 +990,17 @@ class StockItem(MetadataMixin, MPTTModel): return True - def get_installed_items(self, cascade=False): - """ - Return all stock items which are *installed* in this one! - - Args: - cascade - Include items which are installed in items which are installed in items + def get_installed_items(self, cascade: bool = False) -> set[StockItem]: + """Return all stock items which are *installed* in this one! Note: This function is recursive, and may result in a number of database hits! - """ + Args: + cascade (bool, optional): Include items which are installed in items which are installed in items. Defaults to False. + + Returns: + set[StockItem]: Sll stock items which are installed + """ installed = set() items = StockItem.objects.filter(belongs_to=self) @@ -1085,24 +1027,19 @@ class StockItem(MetadataMixin, MPTTModel): return installed def installed_item_count(self): - """ - Return the number of stock items installed inside this one. - """ - + """Return the number of stock items installed inside this one.""" return self.installed_parts.count() @transaction.atomic def installStockItem(self, other_item, quantity, user, notes): - """ - Install another stock item into this stock item. + """Install another stock item into this stock item. - Args + Args: other_item: The stock item to install into this stock item quantity: The quantity of stock to install user: The user performing the operation notes: Any notes associated with the operation """ - # Cannot be already installed in another stock item! if self.belongs_to is not None: return False @@ -1139,15 +1076,13 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def uninstall_into_location(self, location, user, notes): - """ - Uninstall this stock item from another item, into a location. + """Uninstall this stock item from another item, into a location. Args: location: The stock location where the item will be moved user: The user performing the operation notes: Any notes associated with the operation """ - # If the stock item is not installed in anything, ignore if self.belongs_to is None: return False @@ -1184,24 +1119,23 @@ class StockItem(MetadataMixin, MPTTModel): @property def children(self): - """ Return a list of the child items which have been split from this stock item """ + """Return a list of the child items which have been split from this stock item.""" return self.get_descendants(include_self=False) @property def child_count(self): - """ Return the number of 'child' items associated with this StockItem. + """Return the number of 'child' items associated with this StockItem. + A child item is one which has been split from this one. """ return self.children.count() @property def in_stock(self): - """ - Returns True if this item is in stock. + """Returns True if this item is in stock. See also: IN_STOCK_FILTER """ - query = StockItem.objects.filter(pk=self.pk) query = query.filter(StockItem.IN_STOCK_FILTER) @@ -1210,14 +1144,12 @@ class StockItem(MetadataMixin, MPTTModel): @property def can_adjust_location(self): - """ - Returns True if the stock location can be "adjusted" for this part + """Returns True if the stock location can be "adjusted" for this part. Cannot be adjusted if: - Has been delivered to a customer - Has been installed inside another StockItem """ - if self.customer is not None: return False @@ -1231,22 +1163,22 @@ class StockItem(MetadataMixin, MPTTModel): @property def tracking_info_count(self): + """How many tracking entries are available?""" return self.tracking_info.count() @property def has_tracking_info(self): + """Is tracking info available?""" return self.tracking_info_count > 0 - def add_tracking_entry(self, entry_type, user, deltas=None, notes='', **kwargs): - """ - Add a history tracking entry for this StockItem + def add_tracking_entry(self, entry_type: int, user: User, deltas: dict = None, notes: str = '', **kwargs): + """Add a history tracking entry for this StockItem. Args: - entry_type - Integer code describing the "type" of historical action (see StockHistoryCode) - user - The user performing this action - deltas - A map of the changes made to the model - notes - User notes associated with this tracking entry - url - Optional URL associated with this tracking entry + entry_type (int): Code describing the "type" of historical action (see StockHistoryCode) + user (User): The user performing this action + deltas (dict, optional): A map of the changes made to the model. Defaults to None. + notes (str, optional): URL associated with this tracking entry. Defaults to ''. """ if deltas is None: deltas = {} @@ -1276,7 +1208,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def serializeStock(self, quantity, serials, user, notes='', location=None): - """ Split this stock item into unique serial numbers. + """Split this stock item into unique serial numbers. - Quantity can be less than or equal to the quantity of the stock item - Number of serial numbers must match the quantity @@ -1289,7 +1221,6 @@ class StockItem(MetadataMixin, MPTTModel): notes: Optional notes for tracking location: If specified, serialized items will be placed in the given location """ - # Cannot serialize stock that is already serialized! if self.serialized: return @@ -1360,8 +1291,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def copyHistoryFrom(self, other): - """ Copy stock history from another StockItem """ - + """Copy stock history from another StockItem.""" for item in other.tracking_info.all(): item.item = self @@ -1370,8 +1300,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def copyTestResultsFrom(self, other, filters={}): - """ Copy all test results from another StockItem """ - + """Copy all test results from another StockItem.""" for result in other.test_results.all().filter(**filters): # Create a copy of the test result by nulling-out the pk @@ -1380,10 +1309,7 @@ class StockItem(MetadataMixin, MPTTModel): result.save() def can_merge(self, other=None, raise_error=False, **kwargs): - """ - Check if this stock item can be merged into another stock item - """ - + """Check if this stock item can be merged into another stock item.""" allow_mismatched_suppliers = kwargs.get('allow_mismatched_suppliers', False) allow_mismatched_status = kwargs.get('allow_mismatched_status', False) @@ -1437,8 +1363,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def merge_stock_items(self, other_items, raise_error=False, **kwargs): - """ - Merge another stock item into this one; the two become one! + """Merge another stock item into this one; the two become one! *This* stock item subsumes the other, which is essentially deleted: @@ -1446,7 +1371,6 @@ class StockItem(MetadataMixin, MPTTModel): - Tracking history for the *other* item is deleted - Any allocations (build order, sales order) are moved to this StockItem """ - if len(other_items) == 0: return @@ -1499,7 +1423,8 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def splitStock(self, quantity, location, user, **kwargs): - """ Split this stock item into two items, in the same location. + """Split this stock item into two items, in the same location. + Stock tracking notes for this StockItem will be duplicated, and added to the new StockItem. @@ -1511,7 +1436,6 @@ class StockItem(MetadataMixin, MPTTModel): The provided quantity will be subtracted from this item and given to the new one. The new item will have a different StockItem ID, while this will remain the same. """ - notes = kwargs.get('notes', '') code = kwargs.get('code', StockHistoryCode.SPLIT_FROM_PARENT) @@ -1576,7 +1500,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def move(self, location, notes, user, **kwargs): - """ Move part to a new location. + """Move part to a new location. If less than the available quantity is to be moved, a new StockItem is created, with the defined quantity, @@ -1590,7 +1514,6 @@ class StockItem(MetadataMixin, MPTTModel): kwargs: quantity: If provided, override the quantity (default = total stock quantity) """ - try: quantity = Decimal(kwargs.get('quantity', self.quantity)) except InvalidOperation: @@ -1636,7 +1559,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def updateQuantity(self, quantity): - """ Update stock quantity for this item. + """Update stock quantity for this item. If the quantity has reached zero, this StockItem will be deleted. @@ -1644,7 +1567,6 @@ class StockItem(MetadataMixin, MPTTModel): - True if the quantity was saved - False if the StockItem was deleted """ - # Do not adjust quantity of a serialized part if self.serialized: return @@ -1669,11 +1591,11 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def stocktake(self, count, user, notes=''): - """ Perform item stocktake. + """Perform item stocktake. + When the quantity of an item is counted, record the date of stocktake """ - try: count = Decimal(count) except InvalidOperation: @@ -1700,11 +1622,11 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def add_stock(self, quantity, user, notes=''): - """ Add items to stock + """Add items to stock. + This function can be called by initiating a ProjectRun, or by manually adding the items to the stock location """ - # Cannot add items to a serialized part if self.serialized: return False @@ -1734,10 +1656,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def take_stock(self, quantity, user, notes='', code=StockHistoryCode.STOCK_REMOVE): - """ - Remove items from stock - """ - + """Remove items from stock.""" # Cannot remove items from a serialized part if self.serialized: return False @@ -1765,6 +1684,7 @@ class StockItem(MetadataMixin, MPTTModel): return True def __str__(self): + """Human friendly name.""" if self.part.trackable and self.serial: s = '{part} #{sn}'.format( part=self.part.full_name, @@ -1787,13 +1707,7 @@ class StockItem(MetadataMixin, MPTTModel): @transaction.atomic def clear_test_results(self, **kwargs): - """ - Remove all test results - - kwargs: - TODO - """ - + """Remove all test results.""" # All test results results = self.test_results.all() @@ -1802,15 +1716,13 @@ class StockItem(MetadataMixin, MPTTModel): results.delete() def getTestResults(self, test=None, result=None, user=None): - """ - Return all test results associated with this StockItem. + """Return all test results associated with this StockItem. Optionally can filter results by: - Test name - Test result - User """ - results = self.test_results if test: @@ -1828,15 +1740,14 @@ class StockItem(MetadataMixin, MPTTModel): return results def testResultMap(self, **kwargs): - """ - Return a map of test-results using the test name as the key. + """Return a map of test-results using the test name as the key. + Where multiple test results exist for a given name, the *most recent* test is used. This map is useful for rendering to a template (e.g. a test report), as all named tests are accessible. """ - # Do we wish to include test results from installed items? include_installed = kwargs.pop('include_installed', False) @@ -1867,23 +1778,18 @@ class StockItem(MetadataMixin, MPTTModel): return result_map def testResultList(self, **kwargs): - """ - Return a list of test-result objects for this StockItem - """ - + """Return a list of test-result objects for this StockItem.""" return self.testResultMap(**kwargs).values() def requiredTestStatus(self): - """ - Return the status of the tests required for this StockItem. + """Return the status of the tests required for this StockItem. - return: + Return: A dict containing the following items: - total: Number of required tests - passed: Number of tests that have passed - failed: Number of tests that have failed """ - # All the tests required by the part object required = self.part.getRequiredTests() @@ -1912,31 +1818,21 @@ class StockItem(MetadataMixin, MPTTModel): @property def required_test_count(self): - """ - Return the number of 'required tests' for this StockItem - """ + """Return the number of 'required tests' for this StockItem.""" return self.part.getRequiredTests().count() def hasRequiredTests(self): - """ - Return True if there are any 'required tests' associated with this StockItem - """ + """Return True if there are any 'required tests' associated with this StockItem.""" return self.part.getRequiredTests().count() > 0 def passedAllRequiredTests(self): - """ - Returns True if this StockItem has passed all required tests - """ - + """Returns True if this StockItem has passed all required tests.""" status = self.requiredTestStatus() return status['passed'] >= status['total'] def available_test_reports(self): - """ - Return a list of TestReport objects which match this StockItem. - """ - + """Return a list of TestReport objects which match this StockItem.""" reports = [] item_query = StockItem.objects.filter(pk=self.pk) @@ -1955,17 +1851,11 @@ class StockItem(MetadataMixin, MPTTModel): @property def has_test_reports(self): - """ - Return True if there are test reports available for this stock item - """ - + """Return True if there are test reports available for this stock item.""" return len(self.available_test_reports()) > 0 def available_labels(self): - """ - Return a list of Label objects which match this StockItem - """ - + """Return a list of Label objects which match this StockItem.""" labels = [] item_query = StockItem.objects.filter(pk=self.pk) @@ -1984,22 +1874,17 @@ class StockItem(MetadataMixin, MPTTModel): @property def has_labels(self): - """ - Return True if there are any label templates available for this stock item - """ - + """Return True if there are any label templates available for this stock item.""" return len(self.available_labels()) > 0 @receiver(pre_delete, sender=StockItem, dispatch_uid='stock_item_pre_delete_log') def before_delete_stock_item(sender, instance, using, **kwargs): - """ - Receives pre_delete signal from StockItem object. + """Receives pre_delete signal from StockItem object. Before a StockItem is deleted, ensure that each child object is updated, to point to the new parent item. """ - # Update each StockItem parent field for child in instance.children.all(): child.parent = instance.parent @@ -2008,9 +1893,7 @@ def before_delete_stock_item(sender, instance, using, **kwargs): @receiver(post_delete, sender=StockItem, dispatch_uid='stock_item_post_delete_log') def after_delete_stock_item(sender, instance: StockItem, **kwargs): - """ - Function to be executed after a StockItem object is deleted - """ + """Function to be executed after a StockItem object is deleted.""" from part import tasks as part_tasks if not InvenTree.ready.isImportingData(): @@ -2020,9 +1903,7 @@ def after_delete_stock_item(sender, instance: StockItem, **kwargs): @receiver(post_save, sender=StockItem, dispatch_uid='stock_item_post_save_log') def after_save_stock_item(sender, instance: StockItem, created, **kwargs): - """ - Hook function to be executed after StockItem object is saved/updated - """ + """Hook function to be executed after StockItem object is saved/updated.""" from part import tasks as part_tasks if not InvenTree.ready.isImportingData(): @@ -2031,15 +1912,15 @@ def after_save_stock_item(sender, instance: StockItem, created, **kwargs): class StockItemAttachment(InvenTreeAttachment): - """ - Model for storing file attachments against a StockItem object. - """ + """Model for storing file attachments against a StockItem object.""" @staticmethod def get_api_url(): + """Return API url.""" return reverse('api-stock-attachment-list') def getSubdir(self): + """Override attachment location.""" return os.path.join("stock_files", str(self.stock_item.id)) stock_item = models.ForeignKey( @@ -2050,8 +1931,7 @@ class StockItemAttachment(InvenTreeAttachment): class StockItemTracking(models.Model): - """ - Stock tracking entry - used for tracking history of a particular StockItem + """Stock tracking entry - used for tracking history of a particular StockItem. Note: 2021-05-11 The legacy StockTrackingItem model contained very litle information about the "history" of the item. @@ -2072,13 +1952,15 @@ class StockItemTracking(models.Model): @staticmethod def get_api_url(): + """Return API url.""" return reverse('api-stock-tracking-list') def get_absolute_url(self): + """Return url for instance.""" return '/stock/track/{pk}'.format(pk=self.id) def label(self): - + """Return label.""" if self.tracking_type in StockHistoryCode.keys(): return StockHistoryCode.label(self.tracking_type) else: @@ -2109,13 +1991,13 @@ class StockItemTracking(models.Model): def rename_stock_item_test_result_attachment(instance, filename): - + """Rename test result.""" return os.path.join('stock_files', str(instance.stock_item.pk), os.path.basename(filename)) class StockItemTestResult(models.Model): - """ - A StockItemTestResult records results of custom tests against individual StockItem objects. + """A StockItemTestResult records results of custom tests against individual StockItem objects. + This is useful for tracking unit acceptance tests, and particularly useful when integrated with automated testing setups. @@ -2134,16 +2016,17 @@ class StockItemTestResult(models.Model): @staticmethod def get_api_url(): + """Return API url.""" return reverse('api-stock-test-result-list') def save(self, *args, **kwargs): - + """Validate result is unique before saving.""" super().clean() super().validate_unique() super().save(*args, **kwargs) def clean(self): - + """Make sure all values - including for templates - are provided.""" super().clean() # If this test result corresponds to a template, check the requirements of the template @@ -2170,6 +2053,7 @@ class StockItemTestResult(models.Model): @property def key(self): + """Return key for test.""" return InvenTree.helpers.generateTestKey(self.test) stock_item = models.ForeignKey( diff --git a/InvenTree/stock/serializers.py b/InvenTree/stock/serializers.py index 920052a0ec..fe4957d5b1 100644 --- a/InvenTree/stock/serializers.py +++ b/InvenTree/stock/serializers.py @@ -1,6 +1,4 @@ -""" -JSON serializers for Stock app -""" +"""JSON serializers for Stock app.""" from datetime import datetime, timedelta from decimal import Decimal @@ -29,11 +27,11 @@ from .models import (StockItem, StockItemAttachment, StockItemTestResult, class LocationBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): - """ - Provides a brief serializer for a StockLocation object - """ + """Provides a brief serializer for a StockLocation object.""" class Meta: + """Metaclass options.""" + model = StockLocation fields = [ 'pk', @@ -43,7 +41,7 @@ class LocationBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): class StockItemSerializerBrief(InvenTree.serializers.InvenTreeModelSerializer): - """ Brief serializers for a StockItem """ + """Brief serializers for a StockItem.""" location_name = serializers.CharField(source='location', read_only=True) part_name = serializers.CharField(source='part.full_name', read_only=True) @@ -51,6 +49,8 @@ class StockItemSerializerBrief(InvenTree.serializers.InvenTreeModelSerializer): quantity = InvenTreeDecimalField() class Meta: + """Metaclass options.""" + model = StockItem fields = [ 'part', @@ -65,34 +65,28 @@ class StockItemSerializerBrief(InvenTree.serializers.InvenTreeModelSerializer): ] def validate_serial(self, value): + """Make sure serial is not to big.""" if extract_int(value) > 2147483647: raise serializers.ValidationError('serial is to to big') return value class StockItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): - """ Serializer for a StockItem: + """Serializer for a StockItem. - Includes serialization for the linked part - Includes serialization for the item location """ def update(self, instance, validated_data): - """ - Custom update method to pass the user information through to the instance - """ - + """Custom update method to pass the user information through to the instance.""" instance._user = self.context['user'] return super().update(instance, validated_data) @staticmethod def annotate_queryset(queryset): - """ - Add some extra annotations to the queryset, - performing database queries as efficiently as possible. - """ - + """Add some extra annotations to the queryset, performing database queries as efficiently as possible.""" # Annotate the queryset with the total allocated to sales orders queryset = queryset.annotate( allocated=Coalesce( @@ -172,7 +166,7 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): purchase_price_string = serializers.SerializerMethodField() def get_purchase_price_string(self, obj): - + """Return purchase price as string.""" return str(obj.purchase_price) if obj.purchase_price else '-' purchase_order_reference = serializers.CharField(source='purchase_order.reference', read_only=True) @@ -180,7 +174,7 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): sales_order_reference = serializers.CharField(source='sales_order.reference', read_only=True) def __init__(self, *args, **kwargs): - + """Add detail fields.""" part_detail = kwargs.pop('part_detail', False) location_detail = kwargs.pop('location_detail', False) supplier_part_detail = kwargs.pop('supplier_part_detail', False) @@ -201,6 +195,8 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): self.fields.pop('required_tests') class Meta: + """Metaclass options.""" + model = StockItem fields = [ 'allocated', @@ -257,8 +253,7 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): class SerializeStockItemSerializer(serializers.Serializer): - """ - A DRF serializer for "serializing" a StockItem. + """A DRF serializer for "serializing" a StockItem. (Sorry for the confusing naming...) @@ -269,6 +264,8 @@ class SerializeStockItemSerializer(serializers.Serializer): """ class Meta: + """Metaclass options.""" + fields = [ 'quantity', 'serial_numbers', @@ -284,10 +281,7 @@ class SerializeStockItemSerializer(serializers.Serializer): ) def validate_quantity(self, quantity): - """ - Validate that the quantity value is correct - """ - + """Validate that the quantity value is correct.""" item = self.context['item'] if quantity < 0: @@ -323,10 +317,7 @@ class SerializeStockItemSerializer(serializers.Serializer): ) def validate(self, data): - """ - Check that the supplied serial numbers are valid - """ - + """Check that the supplied serial numbers are valid.""" data = super().validate(data) item = self.context['item'] @@ -358,7 +349,7 @@ class SerializeStockItemSerializer(serializers.Serializer): return data def save(self): - + """Serialize stock item.""" item = self.context['item'] request = self.context['request'] user = request.user @@ -381,9 +372,7 @@ class SerializeStockItemSerializer(serializers.Serializer): class InstallStockItemSerializer(serializers.Serializer): - """ - Serializer for installing a stock item into a given part - """ + """Serializer for installing a stock item into a given part.""" stock_item = serializers.PrimaryKeyRelatedField( queryset=StockItem.objects.all(), @@ -401,10 +390,7 @@ class InstallStockItemSerializer(serializers.Serializer): ) def validate_stock_item(self, stock_item): - """ - Validate the selected stock item - """ - + """Validate the selected stock item.""" if not stock_item.in_stock: # StockItem must be in stock to be "installed" raise ValidationError(_("Stock item is unavailable")) @@ -419,8 +405,7 @@ class InstallStockItemSerializer(serializers.Serializer): return stock_item def save(self): - """ Install the selected stock item into this one """ - + """Install the selected stock item into this one.""" data = self.validated_data stock_item = data['stock_item'] @@ -438,11 +423,11 @@ class InstallStockItemSerializer(serializers.Serializer): class UninstallStockItemSerializer(serializers.Serializer): - """ - API serializers for uninstalling an installed item from a stock item - """ + """API serializers for uninstalling an installed item from a stock item.""" class Meta: + """Metaclass options.""" + fields = [ 'location', 'note', @@ -462,7 +447,7 @@ class UninstallStockItemSerializer(serializers.Serializer): ) def save(self): - + """Uninstall stock item.""" item = self.context['item'] data = self.validated_data @@ -480,11 +465,11 @@ class UninstallStockItemSerializer(serializers.Serializer): class LocationTreeSerializer(InvenTree.serializers.InvenTreeModelSerializer): - """ - Serializer for a simple tree view - """ + """Serializer for a simple tree view.""" class Meta: + """Metaclass options.""" + model = StockLocation fields = [ 'pk', @@ -494,8 +479,7 @@ class LocationTreeSerializer(InvenTree.serializers.InvenTreeModelSerializer): class LocationSerializer(InvenTree.serializers.InvenTreeModelSerializer): - """ Detailed information about a stock location - """ + """Detailed information about a stock location.""" url = serializers.CharField(source='get_absolute_url', read_only=True) @@ -504,6 +488,8 @@ class LocationSerializer(InvenTree.serializers.InvenTreeModelSerializer): level = serializers.IntegerField(read_only=True) class Meta: + """Metaclass options.""" + model = StockLocation fields = [ 'pk', @@ -519,9 +505,10 @@ class LocationSerializer(InvenTree.serializers.InvenTreeModelSerializer): class StockItemAttachmentSerializer(InvenTree.serializers.InvenTreeAttachmentSerializer): - """ Serializer for StockItemAttachment model """ + """Serializer for StockItemAttachment model.""" def __init__(self, *args, **kwargs): + """Add detail fields.""" user_detail = kwargs.pop('user_detail', False) super().__init__(*args, **kwargs) @@ -534,6 +521,8 @@ class StockItemAttachmentSerializer(InvenTree.serializers.InvenTreeAttachmentSer # TODO: Record the uploading user when creating or updating an attachment! class Meta: + """Metaclass options.""" + model = StockItemAttachment fields = [ @@ -556,7 +545,7 @@ class StockItemAttachmentSerializer(InvenTree.serializers.InvenTreeAttachmentSer class StockItemTestResultSerializer(InvenTree.serializers.InvenTreeModelSerializer): - """ Serializer for the StockItemTestResult model """ + """Serializer for the StockItemTestResult model.""" user_detail = InvenTree.serializers.UserSerializerBrief(source='user', read_only=True) @@ -565,6 +554,7 @@ class StockItemTestResultSerializer(InvenTree.serializers.InvenTreeModelSerializ attachment = InvenTree.serializers.InvenTreeAttachmentSerializerField(required=False) def __init__(self, *args, **kwargs): + """Add detail fields.""" user_detail = kwargs.pop('user_detail', False) super().__init__(*args, **kwargs) @@ -573,6 +563,8 @@ class StockItemTestResultSerializer(InvenTree.serializers.InvenTreeModelSerializ self.fields.pop('user_detail') class Meta: + """Metaclass options.""" + model = StockItemTestResult fields = [ @@ -597,10 +589,10 @@ class StockItemTestResultSerializer(InvenTree.serializers.InvenTreeModelSerializ class StockTrackingSerializer(InvenTree.serializers.InvenTreeModelSerializer): - """ Serializer for StockItemTracking model """ + """Serializer for StockItemTracking model.""" def __init__(self, *args, **kwargs): - + """Add detail fields.""" item_detail = kwargs.pop('item_detail', False) user_detail = kwargs.pop('user_detail', False) @@ -621,6 +613,8 @@ class StockTrackingSerializer(InvenTree.serializers.InvenTreeModelSerializer): deltas = serializers.JSONField(read_only=True) class Meta: + """Metaclass options.""" + model = StockItemTracking fields = [ 'pk', @@ -644,8 +638,7 @@ class StockTrackingSerializer(InvenTree.serializers.InvenTreeModelSerializer): class StockAssignmentItemSerializer(serializers.Serializer): - """ - Serializer for a single StockItem with in StockAssignment request. + """Serializer for a single StockItem with in StockAssignment request. Here, the particular StockItem is being assigned (manually) to a customer @@ -654,6 +647,8 @@ class StockAssignmentItemSerializer(serializers.Serializer): """ class Meta: + """Metaclass options.""" + fields = [ 'item', ] @@ -667,7 +662,13 @@ class StockAssignmentItemSerializer(serializers.Serializer): ) def validate_item(self, item): + """Validate item. + Ensures: + - is in stock + - Is salable + - Is not allocated + """ # The item must currently be "in stock" if not item.in_stock: raise ValidationError(_("Item must be in stock")) @@ -688,13 +689,14 @@ class StockAssignmentItemSerializer(serializers.Serializer): class StockAssignmentSerializer(serializers.Serializer): - """ - Serializer for assigning one (or more) stock items to a customer. + """Serializer for assigning one (or more) stock items to a customer. This is a manual assignment process, separate for (for example) a Sales Order """ class Meta: + """Metaclass options.""" + fields = [ 'items', 'customer', @@ -716,7 +718,7 @@ class StockAssignmentSerializer(serializers.Serializer): ) def validate_customer(self, customer): - + """Make sure provided company is customer.""" if customer and not customer.is_customer: raise ValidationError(_('Selected company is not a customer')) @@ -730,7 +732,7 @@ class StockAssignmentSerializer(serializers.Serializer): ) def validate(self, data): - + """Make sure items were provided.""" data = super().validate(data) items = data.get('items', []) @@ -741,7 +743,7 @@ class StockAssignmentSerializer(serializers.Serializer): return data def save(self): - + """Assign stock.""" request = self.context['request'] user = getattr(request, 'user', None) @@ -765,13 +767,14 @@ class StockAssignmentSerializer(serializers.Serializer): class StockMergeItemSerializer(serializers.Serializer): - """ - Serializer for a single StockItem within the StockMergeSerializer class. + """Serializer for a single StockItem within the StockMergeSerializer class. Here, the individual StockItem is being checked for merge compatibility. """ class Meta: + """Metaclass options.""" + fields = [ 'item', ] @@ -785,7 +788,7 @@ class StockMergeItemSerializer(serializers.Serializer): ) def validate_item(self, item): - + """Make sure item can be merged.""" # Check that the stock item is able to be merged item.can_merge(raise_error=True) @@ -793,11 +796,11 @@ class StockMergeItemSerializer(serializers.Serializer): class StockMergeSerializer(serializers.Serializer): - """ - Serializer for merging two (or more) stock items together - """ + """Serializer for merging two (or more) stock items together.""" class Meta: + """Metaclass options.""" + fields = [ 'items', 'location', @@ -840,7 +843,7 @@ class StockMergeSerializer(serializers.Serializer): ) def validate(self, data): - + """Make sure all needed values are provided and that the items can be merged.""" data = super().validate(data) items = data['items'] @@ -879,11 +882,10 @@ class StockMergeSerializer(serializers.Serializer): return data def save(self): - """ - Actually perform the stock merging action. + """Actually perform the stock merging action. + At this point we are confident that the merge can take place """ - data = self.validated_data base_item = data['base_item'] @@ -908,8 +910,7 @@ class StockMergeSerializer(serializers.Serializer): class StockAdjustmentItemSerializer(serializers.Serializer): - """ - Serializer for a single StockItem within a stock adjument request. + """Serializer for a single StockItem within a stock adjument request. Fields: - item: StockItem object @@ -917,6 +918,8 @@ class StockAdjustmentItemSerializer(serializers.Serializer): """ class Meta: + """Metaclass options.""" + fields = [ 'item', 'quantity' @@ -940,11 +943,11 @@ class StockAdjustmentItemSerializer(serializers.Serializer): class StockAdjustmentSerializer(serializers.Serializer): - """ - Base class for managing stock adjustment actions via the API - """ + """Base class for managing stock adjustment actions via the API.""" class Meta: + """Metaclass options.""" + fields = [ 'items', 'notes', @@ -960,7 +963,7 @@ class StockAdjustmentSerializer(serializers.Serializer): ) def validate(self, data): - + """Make sure items are provided.""" super().validate(data) items = data.get('items', []) @@ -972,12 +975,10 @@ class StockAdjustmentSerializer(serializers.Serializer): class StockCountSerializer(StockAdjustmentSerializer): - """ - Serializer for counting stock items - """ + """Serializer for counting stock items.""" def save(self): - + """Count stock.""" request = self.context['request'] data = self.validated_data @@ -998,12 +999,10 @@ class StockCountSerializer(StockAdjustmentSerializer): class StockAddSerializer(StockAdjustmentSerializer): - """ - Serializer for adding stock to stock item(s) - """ + """Serializer for adding stock to stock item(s).""" def save(self): - + """Add stock.""" request = self.context['request'] data = self.validated_data @@ -1023,12 +1022,10 @@ class StockAddSerializer(StockAdjustmentSerializer): class StockRemoveSerializer(StockAdjustmentSerializer): - """ - Serializer for removing stock from stock item(s) - """ + """Serializer for removing stock from stock item(s).""" def save(self): - + """Remove stock.""" request = self.context['request'] data = self.validated_data @@ -1048,9 +1045,7 @@ class StockRemoveSerializer(StockAdjustmentSerializer): class StockTransferSerializer(StockAdjustmentSerializer): - """ - Serializer for transferring (moving) stock item(s) - """ + """Serializer for transferring (moving) stock item(s).""" location = serializers.PrimaryKeyRelatedField( queryset=StockLocation.objects.all(), @@ -1062,22 +1057,16 @@ class StockTransferSerializer(StockAdjustmentSerializer): ) class Meta: + """Metaclass options.""" + fields = [ 'items', 'notes', 'location', ] - def validate(self, data): - - data = super().validate(data) - - # TODO: Any specific validation of location field? - - return data - def save(self): - + """Transfer stock.""" request = self.context['request'] data = self.validated_data diff --git a/InvenTree/stock/test_api.py b/InvenTree/stock/test_api.py index 87d8bfaada..7a7c5aa67f 100644 --- a/InvenTree/stock/test_api.py +++ b/InvenTree/stock/test_api.py @@ -1,6 +1,4 @@ -""" -Unit testing for the Stock API -""" +"""Unit testing for the Stock API.""" import io import os @@ -21,6 +19,7 @@ from stock.models import StockItem, StockLocation class StockAPITestCase(InvenTreeAPITestCase): + """Mixin for stock api tests.""" fixtures = [ 'category', @@ -41,30 +40,28 @@ class StockAPITestCase(InvenTreeAPITestCase): 'stock.delete', ] - def setUp(self): - - super().setUp() - class StockLocationTest(StockAPITestCase): - """ - Series of API tests for the StockLocation API - """ + """Series of API tests for the StockLocation API.""" + list_url = reverse('api-location-list') def setUp(self): + """Setup for all tests.""" super().setUp() # Add some stock locations StockLocation.objects.create(name='top', description='top category') def test_list(self): + """Test StockLocation list.""" # Check that we can request the StockLocation list response = self.client.get(self.list_url, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertGreaterEqual(len(response.data), 1) def test_add(self): + """Test adding StockLocation.""" # Check that we can add a new StockLocation data = { 'parent': 1, @@ -76,17 +73,12 @@ class StockLocationTest(StockAPITestCase): class StockItemListTest(StockAPITestCase): - """ - Tests for the StockItem API LIST endpoint - """ + """Tests for the StockItem API LIST endpoint.""" list_url = reverse('api-stock-list') def get_stock(self, **kwargs): - """ - Filter stock and return JSON object - """ - + """Filter stock and return JSON object.""" response = self.client.get(self.list_url, format='json', data=kwargs) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -95,19 +87,13 @@ class StockItemListTest(StockAPITestCase): return response.data def test_get_stock_list(self): - """ - List *all* StockItem objects. - """ - + """List *all* StockItem objects.""" response = self.get_stock() self.assertEqual(len(response), 29) def test_filter_by_part(self): - """ - Filter StockItem by Part reference - """ - + """Filter StockItem by Part reference.""" response = self.get_stock(part=25) self.assertEqual(len(response), 17) @@ -116,19 +102,13 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(response), 12) - def test_filter_by_IPN(self): - """ - Filter StockItem by IPN reference - """ - + def test_filter_by_ipn(self): + """Filter StockItem by IPN reference.""" response = self.get_stock(IPN="R.CH") self.assertEqual(len(response), 3) def test_filter_by_location(self): - """ - Filter StockItem by StockLocation reference - """ - + """Filter StockItem by StockLocation reference.""" response = self.get_stock(location=5) self.assertEqual(len(response), 1) @@ -142,10 +122,7 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(response), 18) def test_filter_by_depleted(self): - """ - Filter StockItem by depleted status - """ - + """Filter StockItem by depleted status.""" response = self.get_stock(depleted=1) self.assertEqual(len(response), 1) @@ -153,10 +130,7 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(response), 28) def test_filter_by_in_stock(self): - """ - Filter StockItem by 'in stock' status - """ - + """Filter StockItem by 'in stock' status.""" response = self.get_stock(in_stock=1) self.assertEqual(len(response), 26) @@ -164,10 +138,7 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(response), 3) def test_filter_by_status(self): - """ - Filter StockItem by 'status' field - """ - + """Filter StockItem by 'status' field.""" codes = { StockStatus.OK: 27, StockStatus.DESTROYED: 1, @@ -183,18 +154,12 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(response), num) def test_filter_by_batch(self): - """ - Filter StockItem by batch code - """ - + """Filter StockItem by batch code.""" response = self.get_stock(batch='B123') self.assertEqual(len(response), 1) def test_filter_by_serialized(self): - """ - Filter StockItem by serialized status - """ - + """Filter StockItem by serialized status.""" response = self.get_stock(serialized=1) self.assertEqual(len(response), 12) @@ -208,10 +173,7 @@ class StockItemListTest(StockAPITestCase): self.assertIsNone(item['serial']) def test_filter_by_has_batch(self): - """ - Test the 'has_batch' filter, which tests if the stock item has been assigned a batch code - """ - + """Test the 'has_batch' filter, which tests if the stock item has been assigned a batch code.""" with_batch = self.get_stock(has_batch=1) without_batch = self.get_stock(has_batch=0) @@ -227,11 +189,10 @@ class StockItemListTest(StockAPITestCase): self.assertTrue(item['batch'] in [None, '']) def test_filter_by_tracked(self): - """ - Test the 'tracked' filter. + """Test the 'tracked' filter. + This checks if the stock item has either a batch code *or* a serial number """ - tracked = self.get_stock(tracked=True) untracked = self.get_stock(tracked=False) @@ -248,10 +209,7 @@ class StockItemListTest(StockAPITestCase): self.assertTrue(item['batch'] in blank and item['serial'] in blank) def test_filter_by_expired(self): - """ - Filter StockItem by expiry status - """ - + """Filter StockItem by expiry status.""" # First, we can assume that the 'stock expiry' feature is disabled response = self.get_stock(expired=1) self.assertEqual(len(response), 29) @@ -289,10 +247,7 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(response), 25) def test_paginate(self): - """ - Test that we can paginate results correctly - """ - + """Test that we can paginate results correctly.""" for n in [1, 5, 10]: response = self.get_stock(limit=n) @@ -302,7 +257,7 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(response['results']), n) def export_data(self, filters=None): - + """Helper to test exports.""" if not filters: filters = {} @@ -321,10 +276,7 @@ class StockItemListTest(StockAPITestCase): return dataset def test_export(self): - """ - Test exporting of Stock data via the API - """ - + """Test exporting of Stock data via the API.""" dataset = self.export_data({}) # Check that *all* stock item objects have been exported @@ -361,13 +313,12 @@ class StockItemListTest(StockAPITestCase): class StockItemTest(StockAPITestCase): - """ - Series of API tests for the StockItem API - """ + """Series of API tests for the StockItem API.""" list_url = reverse('api-stock-list') def setUp(self): + """Setup for all tests.""" super().setUp() # Create some stock locations top = StockLocation.objects.create(name='A', description='top') @@ -376,11 +327,7 @@ class StockItemTest(StockAPITestCase): StockLocation.objects.create(name='C', description='location c', parent=top) def test_create_default_location(self): - """ - Test the default location functionality, - if a 'location' is not specified in the creation request. - """ - + """Test the default location functionality, if a 'location' is not specified in the creation request.""" # The part 'R_4K7_0603' (pk=4) has a default location specified response = self.client.post( @@ -423,10 +370,7 @@ class StockItemTest(StockAPITestCase): self.assertEqual(response.data['location'], None) def test_stock_item_create(self): - """ - Test creation of a StockItem via the API - """ - + """Test creation of a StockItem via the API.""" # POST with an empty part reference response = self.client.post( @@ -476,10 +420,7 @@ class StockItemTest(StockAPITestCase): ) def test_creation_with_serials(self): - """ - Test that serialized stock items can be created via the API, - """ - + """Test that serialized stock items can be created via the API.""" trackable_part = part.models.Part.objects.create( name='My part', description='A trackable part', @@ -537,8 +478,7 @@ class StockItemTest(StockAPITestCase): self.assertEqual(trackable_part.get_stock_count(), 10) def test_default_expiry(self): - """ - Test that the "default_expiry" functionality works via the API. + """Test that the "default_expiry" functionality works via the API. - If an expiry_date is specified, use that - Otherwise, check if the referenced part has a default_expiry defined @@ -547,9 +487,7 @@ class StockItemTest(StockAPITestCase): Notes: - Part <25> has a default_expiry of 10 days - """ - # First test - create a new StockItem without an expiry date data = { 'part': 4, @@ -587,10 +525,7 @@ class StockItemTest(StockAPITestCase): self.assertEqual(response.data['expiry_date'], expiry.isoformat()) def test_purchase_price(self): - """ - Test that we can correctly read and adjust purchase price information via the API - """ - + """Test that we can correctly read and adjust purchase price information via the API.""" url = reverse('api-stock-detail', kwargs={'pk': 1}) data = self.get(url, expected_code=200).data @@ -648,8 +583,7 @@ class StockItemTest(StockAPITestCase): self.assertEqual(data['purchase_price_currency'], 'NZD') def test_install(self): - """ Test that stock item can be installed into antoher item, via the API """ - + """Test that stock item can be installed into antoher item, via the API.""" # Select the "parent" stock item parent_part = part.models.Part.objects.get(pk=100) @@ -731,16 +665,10 @@ class StockItemTest(StockAPITestCase): class StocktakeTest(StockAPITestCase): - """ - Series of tests for the Stocktake API - """ + """Series of tests for the Stocktake API.""" def test_action(self): - """ - Test each stocktake action endpoint, - for validation - """ - + """Test each stocktake action endpoint, for validation.""" for endpoint in ['api-stock-count', 'api-stock-add', 'api-stock-remove']: url = reverse(endpoint) @@ -796,10 +724,7 @@ class StocktakeTest(StockAPITestCase): self.assertContains(response, 'Ensure this value is greater than or equal to 0', status_code=status.HTTP_400_BAD_REQUEST) def test_transfer(self): - """ - Test stock transfers - """ - + """Test stock transfers.""" data = { 'items': [ { @@ -825,12 +750,10 @@ class StocktakeTest(StockAPITestCase): class StockItemDeletionTest(StockAPITestCase): - """ - Tests for stock item deletion via the API - """ + """Tests for stock item deletion via the API.""" def test_delete(self): - + """Test stock item deletion.""" n = StockItem.objects.count() # Create and then delete a bunch of stock items @@ -861,12 +784,14 @@ class StockItemDeletionTest(StockAPITestCase): class StockTestResultTest(StockAPITestCase): + """Tests for StockTestResult APIs.""" def get_url(self): + """Helper funtion to get test-result api url.""" return reverse('api-stock-test-result-list') def test_list(self): - + """Test list endpoint.""" url = self.get_url() response = self.client.get(url) @@ -878,6 +803,7 @@ class StockTestResultTest(StockAPITestCase): self.assertGreaterEqual(len(response.data), 4) def test_post_fail(self): + """Test failing posts.""" # Attempt to post a new test result without specifying required data url = self.get_url() @@ -907,8 +833,7 @@ class StockTestResultTest(StockAPITestCase): self.assertEqual(response.status_code, status.HTTP_201_CREATED) def test_post(self): - # Test creation of a new test result - + """Test creation of a new test result.""" url = self.get_url() response = self.client.get(url) @@ -939,8 +864,7 @@ class StockTestResultTest(StockAPITestCase): self.assertEqual(test['user'], self.user.pk) def test_post_bitmap(self): - """ - 2021-08-25 + """2021-08-25. For some (unknown) reason, prior to fix https://github.com/inventree/InvenTree/pull/2018 uploading a bitmap image would result in a failure. @@ -949,7 +873,6 @@ class StockTestResultTest(StockAPITestCase): As a bonus this also tests the file-upload component """ - here = os.path.dirname(__file__) image_file = os.path.join(here, 'fixtures', 'test_image.bmp') @@ -974,15 +897,12 @@ class StockTestResultTest(StockAPITestCase): class StockAssignTest(StockAPITestCase): - """ - Unit tests for the stock assignment API endpoint, - where stock items are manually assigned to a customer - """ + """Unit tests for the stock assignment API endpoint, where stock items are manually assigned to a customer.""" URL = reverse('api-stock-assign') def test_invalid(self): - + """Test invalid assign.""" # Test with empty data response = self.post( self.URL, @@ -1049,7 +969,7 @@ class StockAssignTest(StockAPITestCase): self.assertIn('Item must be in stock', str(response.data['items'][0])) def test_valid(self): - + """Test valid assign.""" stock_items = [] for i in range(5): @@ -1083,14 +1003,12 @@ class StockAssignTest(StockAPITestCase): class StockMergeTest(StockAPITestCase): - """ - Unit tests for merging stock items via the API - """ + """Unit tests for merging stock items via the API.""" URL = reverse('api-stock-merge') def setUp(self): - + """Setup for all tests.""" super().setUp() self.part = part.models.Part.objects.get(pk=25) @@ -1117,10 +1035,7 @@ class StockMergeTest(StockAPITestCase): ) def test_missing_data(self): - """ - Test responses which are missing required data - """ - + """Test responses which are missing required data.""" # Post completely empty data = self.post( @@ -1145,10 +1060,7 @@ class StockMergeTest(StockAPITestCase): self.assertIn('At least two stock items', str(data)) def test_invalid_data(self): - """ - Test responses which have invalid data - """ - + """Test responses which have invalid data.""" # Serialized stock items should be rejected data = self.post( self.URL, @@ -1229,10 +1141,7 @@ class StockMergeTest(StockAPITestCase): self.assertIn('Stock items must refer to the same supplier part', str(data)) def test_valid_merge(self): - """ - Test valid merging of stock items - """ - + """Test valid merging of stock items.""" # Check initial conditions n = StockItem.objects.filter(part=self.part).count() self.assertEqual(self.item_1.quantity, 100) diff --git a/InvenTree/stock/test_views.py b/InvenTree/stock/test_views.py index dba39334de..f54b188dcd 100644 --- a/InvenTree/stock/test_views.py +++ b/InvenTree/stock/test_views.py @@ -1,4 +1,4 @@ -""" Unit tests for Stock views (see views.py) """ +"""Unit tests for Stock views (see views.py).""" from django.urls import reverse @@ -8,6 +8,7 @@ from InvenTree.helpers import InvenTreeTestCase class StockViewTestCase(InvenTreeTestCase): + """Mixin for Stockview tests.""" fixtures = [ 'category', @@ -22,18 +23,19 @@ class StockViewTestCase(InvenTreeTestCase): class StockListTest(StockViewTestCase): - """ Tests for Stock list views """ + """Tests for Stock list views.""" def test_stock_index(self): + """Test stock index page.""" response = self.client.get(reverse('stock-index')) self.assertEqual(response.status_code, 200) class StockOwnershipTest(StockViewTestCase): - """ Tests for stock ownership views """ + """Tests for stock ownership views.""" def setUp(self): - """ Add another user for ownership tests """ + """Add another user for ownership tests.""" """ TODO: Refactor this following test to use the new API form diff --git a/InvenTree/stock/tests.py b/InvenTree/stock/tests.py index fed4b9858e..edeea46486 100644 --- a/InvenTree/stock/tests.py +++ b/InvenTree/stock/tests.py @@ -1,3 +1,5 @@ +"""Tests for stock app.""" + import datetime from django.core.exceptions import ValidationError @@ -13,9 +15,7 @@ from .models import (StockItem, StockItemTestResult, StockItemTracking, class StockTest(InvenTreeTestCase): - """ - Tests to ensure that the stock location tree functions correcly - """ + """Tests to ensure that the stock location tree functions correcly.""" fixtures = [ 'category', @@ -27,6 +27,7 @@ class StockTest(InvenTreeTestCase): ] def setUp(self): + """Setup for all tests.""" super().setUp() # Extract some shortcuts from the fixtures @@ -44,10 +45,7 @@ class StockTest(InvenTreeTestCase): StockItem.objects.rebuild() def test_expiry(self): - """ - Test expiry date functionality for StockItem model. - """ - + """Test expiry date functionality for StockItem model.""" today = datetime.datetime.now().date() item = StockItem.objects.create( @@ -78,10 +76,7 @@ class StockTest(InvenTreeTestCase): self.assertTrue(item.is_expired()) def test_is_building(self): - """ - Test that the is_building flag does not count towards stock. - """ - + """Test that the is_building flag does not count towards stock.""" part = Part.objects.get(pk=1) # Record the total stock count @@ -107,25 +102,29 @@ class StockTest(InvenTreeTestCase): self.assertEqual(part.quantity_being_built, 1) def test_loc_count(self): + """Test count function.""" self.assertEqual(StockLocation.objects.count(), 7) def test_url(self): + """Test get_absolute_url function.""" it = StockItem.objects.get(pk=2) self.assertEqual(it.get_absolute_url(), '/stock/item/2/') self.assertEqual(self.home.get_absolute_url(), '/stock/location/1/') def test_barcode(self): + """Test format_barcode.""" barcode = self.office.format_barcode(brief=False) self.assertIn('"name": "Office"', barcode) def test_strings(self): + """Test str function.""" it = StockItem.objects.get(pk=1) self.assertEqual(str(it), '4000 x M2x4 LPHS @ Dining Room') def test_parent_locations(self): - + """Test parent.""" self.assertEqual(self.office.parent, None) self.assertEqual(self.drawer1.parent, self.office) self.assertEqual(self.drawer2.parent, self.office) @@ -142,6 +141,7 @@ class StockTest(InvenTreeTestCase): self.assertEqual(self.drawer3.pathstring, 'Home/Drawer_3') def test_children(self): + """Test has_children.""" self.assertTrue(self.office.has_children) self.assertFalse(self.drawer2.has_children) @@ -154,15 +154,14 @@ class StockTest(InvenTreeTestCase): self.assertNotIn(self.bathroom.id, childs) def test_items(self): - self.assertTrue(self.drawer1.has_items()) - self.assertTrue(self.drawer3.has_items()) - self.assertFalse(self.drawer2.has_items()) + """Test has_items.""" # Drawer 3 should have three stock items self.assertEqual(self.drawer3.stock_items.count(), 18) self.assertEqual(self.drawer3.item_count, 18) def test_stock_count(self): + """Test stock count.""" part = Part.objects.get(pk=1) entries = part.stock_entries() @@ -177,7 +176,7 @@ class StockTest(InvenTreeTestCase): ) def test_delete_location(self): - + """Test deleting stock.""" # How many stock items are there? n_stock = StockItem.objects.count() @@ -196,8 +195,7 @@ class StockTest(InvenTreeTestCase): self.assertEqual(s_item.location, self.office) def test_move(self): - """ Test stock movement functions """ - + """Test stock movement functions.""" # Move 4,000 screws to the bathroom it = StockItem.objects.get(pk=1) self.assertNotEqual(it.location, self.bathroom) @@ -215,6 +213,7 @@ class StockTest(InvenTreeTestCase): self.assertEqual(track.notes, 'Moved to the bathroom') def test_self_move(self): + """Test moving stock to itself does not work.""" # Try to move an item to its current location (should fail) it = StockItem.objects.get(pk=1) @@ -225,6 +224,7 @@ class StockTest(InvenTreeTestCase): self.assertEqual(it.tracking_info.count(), n) def test_partial_move(self): + """Test partial stock moving.""" w1 = StockItem.objects.get(pk=100) # A batch code is required to split partial stock! @@ -249,6 +249,7 @@ class StockTest(InvenTreeTestCase): self.assertFalse(widget.move(None, 'null', None)) def test_split_stock(self): + """Test stock splitting.""" # Split the 1234 x 2K2 resistors in Drawer_1 n = StockItem.objects.filter(part=3).count() @@ -268,6 +269,7 @@ class StockTest(InvenTreeTestCase): self.assertEqual(StockItem.objects.filter(part=3).count(), n + 1) def test_stocktake(self): + """Test stocktake function.""" # Perform stocktake it = StockItem.objects.get(pk=2) self.assertEqual(it.quantity, 5000) @@ -288,6 +290,7 @@ class StockTest(InvenTreeTestCase): self.assertEqual(it.tracking_info.count(), n) def test_add_stock(self): + """Test adding stock.""" it = StockItem.objects.get(pk=2) n = it.quantity it.add_stock(45, None, notes='Added some items') @@ -303,6 +306,7 @@ class StockTest(InvenTreeTestCase): self.assertFalse(it.add_stock(-10, None)) def test_take_stock(self): + """Test stock removal.""" it = StockItem.objects.get(pk=2) n = it.quantity it.take_stock(15, None, notes='Removed some items') @@ -320,7 +324,7 @@ class StockTest(InvenTreeTestCase): self.assertFalse(it.take_stock(-10, None)) def test_deplete_stock(self): - + """Test depleted stock deletion.""" w1 = StockItem.objects.get(pk=100) w2 = StockItem.objects.get(pk=101) @@ -339,10 +343,7 @@ class StockTest(InvenTreeTestCase): w2 = StockItem.objects.get(pk=101) def test_serials(self): - """ - Tests for stock serialization - """ - + """Tests for stock serialization.""" p = Part.objects.create( name='trackable part', description='trackable part', @@ -373,10 +374,7 @@ class StockTest(InvenTreeTestCase): self.assertTrue(item.serialized) def test_big_serials(self): - """ - Unit tests for "large" serial numbers which exceed integer encoding - """ - + """Unit tests for "large" serial numbers which exceed integer encoding.""" p = Part.objects.create( name='trackable part', description='trackable part', @@ -451,11 +449,10 @@ class StockTest(InvenTreeTestCase): self.assertEqual(item_prev.serial_int, 99) def test_serialize_stock_invalid(self): - """ - Test manual serialization of parts. + """Test manual serialization of parts. + Each of these tests should fail """ - # Test serialization of non-serializable part item = StockItem.objects.get(pk=1234) @@ -480,8 +477,7 @@ class StockTest(InvenTreeTestCase): item.serializeStock(3, "hello", self.user) def test_serialize_stock_valid(self): - """ Perform valid stock serializations """ - + """Perform valid stock serializations.""" # There are 10 of these in stock # Item will deplete when deleted item = StockItem.objects.get(pk=100) @@ -517,15 +513,14 @@ class StockTest(InvenTreeTestCase): item.serializeStock(2, [99, 100], self.user) def test_location_tree(self): - """ - Unit tests for stock location tree structure (MPTT). + """Unit tests for stock location tree structure (MPTT). + Ensure that the MPTT structure is rebuilt correctly, and the corrent ancestor tree is observed. Ref: https://github.com/inventree/InvenTree/issues/2636 Ref: https://github.com/inventree/InvenTree/issues/2733 """ - # First, we will create a stock location structure A = StockLocation.objects.create( @@ -686,11 +681,10 @@ class StockTest(InvenTreeTestCase): class VariantTest(StockTest): - """ - Tests for calculation stock counts against templates / variants - """ + """Tests for calculation stock counts against templates / variants.""" def test_variant_stock(self): + """Test variant functions.""" # Check the 'Chair' variant chair = Part.objects.get(pk=10000) @@ -704,8 +698,7 @@ class VariantTest(StockTest): self.assertEqual(green.stock_entries().count(), 3) def test_serial_numbers(self): - # Test serial number functionality for variant / template parts - + """Test serial number functionality for variant / template parts.""" chair = Part.objects.get(pk=10000) # Operations on the top-level object @@ -769,11 +762,10 @@ class VariantTest(StockTest): class TestResultTest(StockTest): - """ - Tests for the StockItemTestResult model. - """ + """Tests for the StockItemTestResult model.""" def test_test_count(self): + """Test test count.""" item = StockItem.objects.get(pk=105) tests = item.test_results self.assertEqual(tests.count(), 4) @@ -795,7 +787,7 @@ class TestResultTest(StockTest): self.assertIn(test, result_map.keys()) def test_test_results(self): - + """Test test results.""" item = StockItem.objects.get(pk=522) status = item.requiredTestStatus() @@ -832,7 +824,7 @@ class TestResultTest(StockTest): self.assertTrue(item.passedAllRequiredTests()) def test_duplicate_item_tests(self): - + """Test duplicate item behaviour.""" # Create an example stock item by copying one from the database (because we are lazy) item = StockItem.objects.get(pk=522) @@ -898,12 +890,10 @@ class TestResultTest(StockTest): self.assertEqual(item3.test_results.count(), 4) def test_installed_tests(self): - """ - Test test results for stock in stock. + """Test test results for stock in stock. Or, test "test results" for "stock items" installed "inside" a "stock item" """ - # Get a "master" stock item item = StockItem.objects.get(pk=105) diff --git a/InvenTree/stock/urls.py b/InvenTree/stock/urls.py index 3a9ad6c490..859a9114a4 100644 --- a/InvenTree/stock/urls.py +++ b/InvenTree/stock/urls.py @@ -1,6 +1,4 @@ -""" -URL lookup for Stock app -""" +"""URL lookup for Stock app.""" from django.urls import include, re_path diff --git a/InvenTree/stock/views.py b/InvenTree/stock/views.py index 7a0f1ab978..1d49b13f13 100644 --- a/InvenTree/stock/views.py +++ b/InvenTree/stock/views.py @@ -1,6 +1,4 @@ -""" -Django views for interacting with Stock app -""" +"""Django views for interacting with Stock app.""" from datetime import datetime @@ -21,13 +19,14 @@ from .models import StockItem, StockItemTracking, StockLocation class StockIndex(InvenTreeRoleMixin, InvenTreePluginViewMixin, ListView): - """ StockIndex view loads all StockLocation and StockItem object - """ + """StockIndex view loads all StockLocation and StockItem object.""" + model = StockItem template_name = 'stock/location.html' context_obect_name = 'locations' def get_context_data(self, **kwargs): + """Extend template context.""" context = super().get_context_data(**kwargs).copy() # Return all top-level locations @@ -48,9 +47,7 @@ class StockIndex(InvenTreeRoleMixin, InvenTreePluginViewMixin, ListView): class StockLocationDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): - """ - Detailed view of a single StockLocation object - """ + """Detailed view of a single StockLocation object.""" context_object_name = 'location' template_name = 'stock/location.html' @@ -58,7 +55,7 @@ class StockLocationDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailVi model = StockLocation def get_context_data(self, **kwargs): - + """Extend template context.""" context = super().get_context_data(**kwargs) context['ownership_enabled'] = common.models.InvenTreeSetting.get_setting('STOCK_OWNERSHIP_CONTROL') @@ -69,9 +66,7 @@ class StockLocationDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailVi class StockItemDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): - """ - Detailed view of a single StockItem object - """ + """Detailed view of a single StockItem object.""" context_object_name = 'item' template_name = 'stock/item.html' @@ -79,11 +74,7 @@ class StockItemDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): model = StockItem def get_context_data(self, **kwargs): - """ - Add information on the "next" and "previous" StockItem objects, - based on the serial numbers. - """ - + """Add information on the "next" and "previous" StockItem objects, based on the serial numbers.""" data = super().get_context_data(**kwargs) if self.object.serialized: @@ -103,8 +94,7 @@ class StockItemDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): return data def get(self, request, *args, **kwargs): - """ check if item exists else return to stock index """ - + """Check if item exists else return to stock index.""" stock_pk = kwargs.get('pk', None) if stock_pk: @@ -120,14 +110,14 @@ class StockItemDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): class StockLocationQRCode(QRCodeView): - """ View for displaying a QR code for a StockLocation object """ + """View for displaying a QR code for a StockLocation object.""" ajax_form_title = _("Stock Location QR code") role_required = ['stock_location.view', 'stock.view'] def get_qr_data(self): - """ Generate QR code data for the StockLocation """ + """Generate QR code data for the StockLocation.""" try: loc = StockLocation.objects.get(id=self.pk) return loc.format_barcode() @@ -136,9 +126,7 @@ class StockLocationQRCode(QRCodeView): class StockItemReturnToStock(AjaxUpdateView): - """ - View for returning a stock item (which is assigned to a customer) to stock. - """ + """View for returning a stock item (which is assigned to a customer) to stock.""" model = StockItem ajax_form_title = _("Return to Stock") @@ -146,29 +134,28 @@ class StockItemReturnToStock(AjaxUpdateView): form_class = StockForms.ReturnStockItemForm def validate(self, item, form, **kwargs): - + """Make sure required data is there.""" location = form.cleaned_data.get('location', None) if not location: form.add_error('location', _('Specify a valid location')) def save(self, item, form, **kwargs): - + """Return stock.""" location = form.cleaned_data.get('location', None) if location: item.returnFromCustomer(location, self.request.user) def get_data(self): + """Set success message.""" return { 'success': _('Stock item returned from customer') } class StockItemDeleteTestData(AjaxUpdateView): - """ - View for deleting all test data - """ + """View for deleting all test data.""" model = StockItem form_class = ConfirmForm @@ -177,10 +164,11 @@ class StockItemDeleteTestData(AjaxUpdateView): role_required = ['stock.change', 'stock.delete'] def get_form(self): + """Require confirm.""" return ConfirmForm() def post(self, request, *args, **kwargs): - + """Delete test data.""" valid = False stock_item = StockItem.objects.get(pk=self.kwargs['pk']) @@ -203,13 +191,13 @@ class StockItemDeleteTestData(AjaxUpdateView): class StockItemQRCode(QRCodeView): - """ View for displaying a QR code for a StockItem object """ + """View for displaying a QR code for a StockItem object.""" ajax_form_title = _("Stock Item QR Code") role_required = 'stock.view' def get_qr_data(self): - """ Generate QR code data for the StockItem """ + """Generate QR code data for the StockItem.""" try: item = StockItem.objects.get(id=self.pk) return item.format_barcode() @@ -218,9 +206,7 @@ class StockItemQRCode(QRCodeView): class StockItemConvert(AjaxUpdateView): - """ - View for 'converting' a StockItem to a variant of its current part. - """ + """View for 'converting' a StockItem to a variant of its current part.""" model = StockItem form_class = StockForms.ConvertStockItemForm @@ -229,10 +215,7 @@ class StockItemConvert(AjaxUpdateView): context_object_name = 'item' def get_form(self): - """ - Filter the available parts. - """ - + """Filter the available parts.""" form = super().get_form() item = self.get_object() @@ -241,7 +224,7 @@ class StockItemConvert(AjaxUpdateView): return form def save(self, obj, form): - + """Convert item to variant.""" stock_item = self.get_object() variant = form.cleaned_data.get('part', None) @@ -252,8 +235,8 @@ class StockItemConvert(AjaxUpdateView): class StockLocationDelete(AjaxDeleteView): - """ - View to delete a StockLocation + """View to delete a StockLocation. + Presents a deletion confirmation form to the user """ @@ -265,8 +248,8 @@ class StockLocationDelete(AjaxDeleteView): class StockItemDelete(AjaxDeleteView): - """ - View to delete a StockItem + """View to delete a StockItem. + Presents a deletion confirmation form to the user """ @@ -278,8 +261,8 @@ class StockItemDelete(AjaxDeleteView): class StockItemTrackingDelete(AjaxDeleteView): - """ - View to delete a StockItemTracking object + """View to delete a StockItemTracking object. + Presents a deletion confirmation form to the user """ @@ -289,7 +272,7 @@ class StockItemTrackingDelete(AjaxDeleteView): class StockItemTrackingEdit(AjaxUpdateView): - """ View for editing a StockItemTracking object """ + """View for editing a StockItemTracking object.""" model = StockItemTracking ajax_form_title = _('Edit Stock Tracking Entry') @@ -297,15 +280,14 @@ class StockItemTrackingEdit(AjaxUpdateView): class StockItemTrackingCreate(AjaxCreateView): - """ View for creating a new StockItemTracking object. - """ + """View for creating a new StockItemTracking object.""" model = StockItemTracking ajax_form_title = _("Add Stock Tracking Entry") form_class = StockForms.TrackingEntryForm def post(self, request, *args, **kwargs): - + """Create StockItemTracking object.""" self.request = request self.form = self.get_form() diff --git a/InvenTree/users/admin.py b/InvenTree/users/admin.py index 2be478fa0e..b608f91b46 100644 --- a/InvenTree/users/admin.py +++ b/InvenTree/users/admin.py @@ -1,3 +1,4 @@ +"""Admin classes for the 'users' app""" from django import forms from django.contrib import admin, messages @@ -14,9 +15,7 @@ User = get_user_model() class RuleSetInline(admin.TabularInline): - """ - Class for displaying inline RuleSet data in the Group admin page. - """ + """Class for displaying inline RuleSet data in the Group admin page.""" model = RuleSet can_delete = False @@ -32,13 +31,13 @@ class RuleSetInline(admin.TabularInline): class InvenTreeGroupAdminForm(forms.ModelForm): - """ - Custom admin form for the Group model. + """Custom admin form for the Group model. Adds the ability for editing user membership directly in the group admin page. """ class Meta: + """Metaclass defines extra fields""" model = Group exclude = [] fields = [ @@ -47,6 +46,7 @@ class InvenTreeGroupAdminForm(forms.ModelForm): ] def __init__(self, *args, **kwargs): # pragma: no cover + """Populate the 'users' field with the users in the current group""" super().__init__(*args, **kwargs) if self.instance.pk: @@ -63,11 +63,11 @@ class InvenTreeGroupAdminForm(forms.ModelForm): ) def save_m2m(self): # pragma: no cover - # Add the users to the Group. - + """Add the users to the Group""" self.instance.user_set.set(self.cleaned_data['users']) def save(self, *args, **kwargs): # pragma: no cover + """Custom save method for Group admin form""" # Default save instance = super().save() # Save many-to-many data @@ -76,9 +76,7 @@ class InvenTreeGroupAdminForm(forms.ModelForm): class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover - """ - Custom admin interface for the Group model - """ + """Custom admin interface for the Group model.""" form = InvenTreeGroupAdminForm @@ -90,8 +88,7 @@ class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover 'stock_item', 'build', 'purchase_order', 'sales_order') def get_rule_set(self, obj, rule_set_type): - ''' Return list of permissions for the given ruleset ''' - + """Return list of permissions for the given ruleset.""" # Get all rulesets associated to object rule_sets = RuleSet.objects.filter(group=obj.pk) @@ -101,6 +98,7 @@ class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover break def append_permission_level(permission_level, next_level): + """Appen permission level""" if not permission_level: return next_level @@ -128,30 +126,39 @@ class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover return permission_level def admin(self, obj): + """Return the ruleset for the admin role""" return self.get_rule_set(obj, 'admin') def part_category(self, obj): + """Return the ruleset for the PartCategory role""" return self.get_rule_set(obj, 'part_category') def part(self, obj): + """Return the ruleset for the Part role""" return self.get_rule_set(obj, 'part') def stock_location(self, obj): + """Return the ruleset for the StockLocation role""" return self.get_rule_set(obj, 'stock_location') def stock_item(self, obj): + """Return the ruleset for the StockItem role""" return self.get_rule_set(obj, 'stock') def build(self, obj): + """Return the ruleset for the BuildOrder role""" return self.get_rule_set(obj, 'build') def purchase_order(self, obj): + """Return the ruleset for the PurchaseOrder role""" return self.get_rule_set(obj, 'purchase_order') def sales_order(self, obj): + """Return the ruleset for the SalesOrder role""" return self.get_rule_set(obj, 'sales_order') def get_formsets_with_inlines(self, request, obj=None): + """Return all inline formsets""" for inline in self.get_inline_instances(request, obj): # Hide RuleSetInline in the 'Add role' view if not isinstance(inline, RuleSetInline) or obj is not None: @@ -160,12 +167,12 @@ class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover filter_horizontal = ['permissions'] def save_model(self, request, obj, form, change): - """ - This method serves two purposes: + """Save overwrite. + + This method serves two purposes: - show warning message whenever the group users belong to multiple groups - skip saving of the group instance model as inlines needs to be saved before. """ - # Get form cleaned data users = form.cleaned_data['users'] @@ -186,6 +193,7 @@ class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover messages.add_message(request, messages.WARNING, warning_message) def save_formset(self, request, form, formset, change): + """Save the inline formset""" # Save inline Rulesets formset.save() # Save Group instance and update permissions @@ -193,8 +201,7 @@ class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover class InvenTreeUserAdmin(UserAdmin): - """ - Custom admin page for the User model. + """Custom admin page for the User model. Hides the "permissions" view as this is now handled entirely by groups and RuleSets. @@ -213,9 +220,7 @@ class InvenTreeUserAdmin(UserAdmin): class OwnerAdmin(admin.ModelAdmin): - """ - Custom admin interface for the Owner model - """ + """Custom admin interface for the Owner model.""" pass diff --git a/InvenTree/users/api.py b/InvenTree/users/api.py index d2b25ceb5b..332c5ec67d 100644 --- a/InvenTree/users/api.py +++ b/InvenTree/users/api.py @@ -1,3 +1,4 @@ +"""DRF API definition for the 'users' app""" from django.contrib.auth.models import User from django.core.exceptions import ObjectDoesNotExist @@ -14,16 +15,16 @@ from users.serializers import OwnerSerializer, UserSerializer class OwnerList(generics.ListAPIView): - """ - List API endpoint for Owner model. Cannot create. + """List API endpoint for Owner model. + + Cannot create. """ queryset = Owner.objects.all() serializer_class = OwnerSerializer def filter_queryset(self, queryset): - """ - Implement text search for the "owner" model. + """Implement text search for the "owner" model. Note that an "owner" can be either a group, or a user, so we cannot do a direct text search. @@ -34,7 +35,6 @@ class OwnerList(generics.ListAPIView): It is not necessarily "efficient" to do it this way, but until we determine a better way, this is what we have... """ - search_term = str(self.request.query_params.get('search', '')).lower() queryset = super().filter_queryset(queryset) @@ -54,8 +54,9 @@ class OwnerList(generics.ListAPIView): class OwnerDetail(generics.RetrieveAPIView): - """ - Detail API endpoint for Owner model. Cannot edit or delete + """Detail API endpoint for Owner model. + + Cannot edit or delete """ queryset = Owner.objects.all() @@ -63,9 +64,7 @@ class OwnerDetail(generics.RetrieveAPIView): class RoleDetails(APIView): - """ - API endpoint which lists the available role permissions - for the current user + """API endpoint which lists the available role permissions for the current user. (Requires authentication) """ @@ -75,7 +74,7 @@ class RoleDetails(APIView): ] def get(self, request, *args, **kwargs): - + """Return the list of roles / permissions available to the current user""" user = request.user roles = {} @@ -108,7 +107,7 @@ class RoleDetails(APIView): class UserDetail(generics.RetrieveAPIView): - """ Detail endpoint for a single user """ + """Detail endpoint for a single user.""" queryset = User.objects.all() serializer_class = UserSerializer @@ -116,7 +115,7 @@ class UserDetail(generics.RetrieveAPIView): class UserList(generics.ListAPIView): - """ List endpoint for detail on all users """ + """List endpoint for detail on all users.""" queryset = User.objects.all() serializer_class = UserSerializer @@ -135,20 +134,18 @@ class UserList(generics.ListAPIView): class GetAuthToken(APIView): - """ Return authentication token for an authenticated user. """ + """Return authentication token for an authenticated user.""" permission_classes = [ permissions.IsAuthenticated, ] def get(self, request, *args, **kwargs): - return self.login(request) - - def delete(self, request): - return self.logout(request) - - def login(self, request): + """Return an API token if the user is authenticated + - If the user already has a token, return it + - Otherwise, create a new token + """ if request.user.is_authenticated: # Get the user token (or create one if it does not exist) token, created = Token.objects.get_or_create(user=request.user) @@ -156,7 +153,8 @@ class GetAuthToken(APIView): 'token': token.key, }) - def logout(self, request): + def delete(self, request): + """User has requested deletion of API token""" try: request.user.auth_token.delete() return Response({"success": "Successfully logged out."}, diff --git a/InvenTree/users/apps.py b/InvenTree/users/apps.py index b20927f0ed..71c2d079ce 100644 --- a/InvenTree/users/apps.py +++ b/InvenTree/users/apps.py @@ -1,14 +1,22 @@ +"""App configuration class for the 'users' app""" + +import logging + from django.apps import AppConfig from django.db.utils import OperationalError, ProgrammingError from InvenTree.ready import canAppAccessDatabase +logger = logging.getLogger('inventree') + class UsersConfig(AppConfig): + """Config class for the 'users' app""" + name = 'users' def ready(self): - + """Called when the 'users' app is loaded at runtime""" if canAppAccessDatabase(allow_test=True): try: @@ -22,15 +30,15 @@ class UsersConfig(AppConfig): pass def assign_permissions(self): - + """Update role permissions for existing groups""" from django.contrib.auth.models import Group from users.models import RuleSet, update_group_roles # First, delete any rule_set objects which have become outdated! for rule in RuleSet.objects.all(): - if rule.name not in RuleSet.RULESET_NAMES: # pragma: no cover # can not change ORM without the app beeing loaded - print("need to delete:", rule.name) + if rule.name not in RuleSet.RULESET_NAMES: # pragma: no cover # can not change ORM without the app being loaded + logger.info(f"Deleting outdated ruleset: {rule.name}") rule.delete() # Update group permission assignments for all groups @@ -39,7 +47,7 @@ class UsersConfig(AppConfig): update_group_roles(group) def update_owners(self): - + """Create an 'owner' object for each user and group instance""" from django.contrib.auth import get_user_model from django.contrib.auth.models import Group diff --git a/InvenTree/users/models.py b/InvenTree/users/models.py index be5fac1641..48b35e4562 100644 --- a/InvenTree/users/models.py +++ b/InvenTree/users/models.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +"""Database model definitions for the 'users' app""" import logging @@ -20,9 +20,7 @@ logger = logging.getLogger("inventree") class RuleSet(models.Model): - """ - A RuleSet is somewhat like a superset of the django permission class, - in that in encapsulates a bunch of permissions. + """A RuleSet is somewhat like a superset of the django permission class, in that in encapsulates a bunch of permissions. There are *many* apps models used within InvenTree, so it makes sense to group them into "roles". @@ -192,6 +190,7 @@ class RuleSet(models.Model): ] class Meta: + """Metaclass defines additional model properties""" unique_together = ( ('name', 'group'), ) @@ -221,10 +220,7 @@ class RuleSet(models.Model): @classmethod def check_table_permission(cls, user, table, permission): - """ - Check if the provided user has the specified permission against the table - """ - + """Check if the provided user has the specified permission against the table.""" # If the table does *not* require permissions if table in cls.RULESET_IGNORE: return True @@ -255,11 +251,7 @@ class RuleSet(models.Model): @staticmethod def get_model_permission_string(model, permission): - """ - Construct the correctly formatted permission string, - given the app_model name, and the permission type. - """ - + """Construct the correctly formatted permission string, given the app_model name, and the permission type.""" model, app = split_model(model) return "{app}.{perm}_{model}".format( @@ -269,7 +261,7 @@ class RuleSet(models.Model): ) def __str__(self, debug=False): # pragma: no cover - """ Ruleset string representation """ + """Ruleset string representation.""" if debug: # Makes debugging easier return f'{str(self.group).ljust(15)}: {self.name.title().ljust(15)} | ' \ @@ -279,10 +271,11 @@ class RuleSet(models.Model): return self.name def save(self, *args, **kwargs): + """Intercept the 'save' functionality to make addtional permission changes: - # It does not make sense to be able to change / create something, - # but not be able to view it! - + It does not make sense to be able to change / create something, + but not be able to view it! + """ if self.can_add or self.can_change or self.can_delete: self.can_view = True @@ -296,15 +289,12 @@ class RuleSet(models.Model): self.group.save() def get_models(self): - """ - Return the database tables / models that this ruleset covers. - """ - + """Return the database tables / models that this ruleset covers.""" return self.RULESET_MODELS.get(self.name, []) def split_model(model): - """get modelname and app from modelstring""" + """Get modelname and app from modelstring.""" *app, model = model.split('_') # handle models that have @@ -317,7 +307,7 @@ def split_model(model): def split_permission(app, perm): - """split permission string into permission and model""" + """Split permission string into permission and model.""" permission_name, *model = perm.split('_') # handle models that have underscores if len(model) > 1: # pragma: no cover @@ -328,10 +318,7 @@ def split_permission(app, perm): def update_group_roles(group, debug=False): - """ - - Iterates through all of the RuleSets associated with the group, - and ensures that the correct permissions are either applied or removed from the group. + """Iterates through all of the RuleSets associated with the group, and ensures that the correct permissions are either applied or removed from the group. This function is called under the following conditions: @@ -339,9 +326,7 @@ def update_group_roles(group, debug=False): b) Whenver the group object is updated The RuleSet model has complete control over the permissions applied to any group. - """ - if not canAppAccessDatabase(allow_test=True): return # pragma: no cover @@ -367,15 +352,13 @@ def update_group_roles(group, debug=False): permissions_to_delete = set() def add_model(name, action, allowed): - """ - Add a new model to the pile: + """Add a new model to the pile. - args: - name - The name of the model e.g. part_part - action - The permission action e.g. view - allowed - Whether or not the action is allowed + Args: + name: The name of the model e.g. part_part + action: The permission action e.g. view + allowed: Whether or not the action is allowed """ - if action not in ['view', 'add', 'change', 'delete']: # pragma: no cover raise ValueError("Action {a} is invalid".format(a=action)) @@ -418,16 +401,13 @@ def update_group_roles(group, debug=False): add_model(model, 'delete', ruleset.can_delete) def get_permission_object(permission_string): - """ - Find the permission object in the database, - from the simplified permission string + """Find the permission object in the database, from the simplified permission string. Args: - permission_string - a simplified permission_string e.g. 'part.view_partcategory' + permission_string: a simplified permission_string e.g. 'part.view_partcategory' Returns the permission object in the database associated with the permission string """ - (app, perm) = permission_string.split('.') perm, model = split_permission(app, perm) @@ -496,23 +476,18 @@ def update_group_roles(group, debug=False): @receiver(post_save, sender=Group, dispatch_uid='create_missing_rule_sets') def create_missing_rule_sets(sender, instance, **kwargs): - """ - Called *after* a Group object is saved. - As the linked RuleSet instances are saved *before* the Group, - then we can now use these RuleSet values to update the - group permissions. - """ + """Called *after* a Group object is saved. + As the linked RuleSet instances are saved *before* the Group, then we can now use these RuleSet values to update the group permissions. + """ update_group_roles(instance) def check_user_role(user, role, permission): - """ - Check if a user has a particular role:permission combination. + """Check if a user has a particular role:permission combination. If the user is a superuser, this will return True """ - if user.is_superuser: return True @@ -539,8 +514,8 @@ def check_user_role(user, role, permission): class Owner(models.Model): - """ - The Owner class is a proxy for a Group or User instance. + """The Owner class is a proxy for a Group or User instance. + Owner can be associated to any InvenTree model (part, stock, build, etc.) owner_type: Model type (Group or User) @@ -550,13 +525,12 @@ class Owner(models.Model): @classmethod def get_owners_matching_user(cls, user): - """ - Return all "owner" objects matching the provided user: + """Return all "owner" objects matching the provided user. - A) An exact match for the user - B) Any groups that the user is a part of + Includes: + - An exact match for the user + - Any groups that the user is a part of """ - user_type = ContentType.objects.get(app_label='auth', model='user') group_type = ContentType.objects.get(app_label='auth', model='group') @@ -578,9 +552,11 @@ class Owner(models.Model): @staticmethod def get_api_url(): # pragma: no cover + """Returns the API endpoint URL associated with the Owner model""" return reverse('api-owner-list') class Meta: + """Metaclass defines extra model properties""" # Ensure all owners are unique constraints = [ UniqueConstraint(fields=['owner_type', 'owner_id'], @@ -594,25 +570,20 @@ class Owner(models.Model): owner = GenericForeignKey('owner_type', 'owner_id') def __str__(self): - """ Defines the owner string representation """ + """Defines the owner string representation.""" return f'{self.owner} ({self.owner_type.name})' def name(self): - """ - Return the 'name' of this owner - """ + """Return the 'name' of this owner.""" return str(self.owner) def label(self): - """ - Return the 'type' label of this owner i.e. 'user' or 'group' - """ + """Return the 'type' label of this owner i.e. 'user' or 'group'.""" return str(self.owner_type.name) @classmethod def create(cls, obj): - """ Check if owner exist then create new owner entry """ - + """Check if owner exist then create new owner entry.""" # Check for existing owner existing_owner = cls.get_owner(obj) @@ -627,8 +598,7 @@ class Owner(models.Model): @classmethod def get_owner(cls, user_or_group): - """ Get owner instance for a group or user """ - + """Get owner instance for a group or user.""" user_model = get_user_model() owner = None content_type_id = 0 @@ -651,11 +621,10 @@ class Owner(models.Model): return owner def get_related_owners(self, include_group=False): - """ - Get all owners "related" to an owner. + """Get all owners "related" to an owner. + This method is useful to retrieve all "user-type" owners linked to a "group-type" owner """ - user_model = get_user_model() related_owners = None @@ -680,21 +649,13 @@ class Owner(models.Model): @receiver(post_save, sender=Group, dispatch_uid='create_owner') @receiver(post_save, sender=get_user_model(), dispatch_uid='create_owner') def create_owner(sender, instance, **kwargs): - """ - Callback function to create a new owner instance - after either a new group or user instance is saved. - """ - + """Callback function to create a new owner instance after either a new group or user instance is saved.""" Owner.create(obj=instance) @receiver(post_delete, sender=Group, dispatch_uid='delete_owner') @receiver(post_delete, sender=get_user_model(), dispatch_uid='delete_owner') def delete_owner(sender, instance, **kwargs): - """ - Callback function to delete an owner instance - after either a new group or user instance is deleted. - """ - + """Callback function to delete an owner instance after either a new group or user instance is deleted.""" owner = Owner.get_owner(instance) owner.delete() diff --git a/InvenTree/users/serializers.py b/InvenTree/users/serializers.py index fcf1ed5678..1ae2dc28e1 100644 --- a/InvenTree/users/serializers.py +++ b/InvenTree/users/serializers.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +"""DRF API serializers for the 'users' app""" from django.contrib.auth.models import User @@ -10,10 +10,10 @@ from .models import Owner class UserSerializer(InvenTreeModelSerializer): - """ Serializer for a User - """ + """Serializer for a User.""" class Meta: + """Metaclass defines serializer fields.""" model = User fields = ('pk', 'username', @@ -23,15 +23,14 @@ class UserSerializer(InvenTreeModelSerializer): class OwnerSerializer(InvenTreeModelSerializer): - """ - Serializer for an "Owner" (either a "user" or a "group") - """ + """Serializer for an "Owner" (either a "user" or a "group")""" name = serializers.CharField(read_only=True) label = serializers.CharField(read_only=True) class Meta: + """Metaclass defines serializer fields.""" model = Owner fields = [ 'pk', diff --git a/InvenTree/users/test_migrations.py b/InvenTree/users/test_migrations.py index 7bb17d0070..9cc8b404f8 100644 --- a/InvenTree/users/test_migrations.py +++ b/InvenTree/users/test_migrations.py @@ -1,6 +1,4 @@ -""" -Unit tests for the user model database migrations -""" +"""Unit tests for the user model database migrations.""" from django_test_migrations.contrib.unittest_case import MigratorTestCase @@ -8,15 +6,13 @@ from InvenTree import helpers class TestForwardMigrations(MigratorTestCase): - """ - Test entire schema migration sequence for the users app - """ + """Test entire schema migration sequence for the users app.""" migrate_from = ('users', helpers.getOldestMigrationFile('users')) migrate_to = ('users', helpers.getNewestMigrationFile('users')) def prepare(self): - + """Setup the initial state of the database before migrations""" User = self.old_state.apps.get_model('auth', 'user') User.objects.create( @@ -32,7 +28,7 @@ class TestForwardMigrations(MigratorTestCase): ) def test_users_exist(self): - + """Test that users exist in the database""" User = self.new_state.apps.get_model('auth', 'user') self.assertEqual(User.objects.count(), 2) diff --git a/InvenTree/users/tests.py b/InvenTree/users/tests.py index 393b640d2e..f8892f2155 100644 --- a/InvenTree/users/tests.py +++ b/InvenTree/users/tests.py @@ -1,3 +1,5 @@ +"""Unit tests for the 'users' app""" + from django.apps import apps from django.contrib.auth.models import Group from django.test import TestCase @@ -10,12 +12,10 @@ from users.models import Owner, RuleSet class RuleSetModelTest(TestCase): - """ - Some simplistic tests to ensure the RuleSet model is setup correctly. - """ + """Some simplistic tests to ensure the RuleSet model is setup correctly.""" def test_ruleset_models(self): - + """Test that the role rulesets work as intended""" keys = RuleSet.RULESET_MODELS.keys() # Check if there are any rulesets which do not have models defined @@ -48,11 +48,7 @@ class RuleSetModelTest(TestCase): self.assertEqual(len(empty), 0) def test_model_names(self): - """ - Test that each model defined in the rulesets is valid, - based on the database schema! - """ - + """Test that each model defined in the rulesets is valid, based on the database schema!""" available_models = apps.get_models() available_tables = set() @@ -108,10 +104,7 @@ class RuleSetModelTest(TestCase): self.assertEqual(len(extra_models), 0) def test_permission_assign(self): - """ - Test that the permission assigning works! - """ - + """Test that the permission assigning works!""" # Create a new group group = Group.objects.create(name="Test group") @@ -161,17 +154,16 @@ class RuleSetModelTest(TestCase): class OwnerModelTest(InvenTreeTestCase): - """ - Some simplistic tests to ensure the Owner model is setup correctly. - """ + """Some simplistic tests to ensure the Owner model is setup correctly.""" def do_request(self, endpoint, filters, status_code=200): + """Perform an API request""" response = self.client.get(endpoint, filters, format='json') self.assertEqual(response.status_code, status_code) return response.data def test_owner(self): - + """Tests for the 'owner' model""" # Check that owner was created for user user_as_owner = Owner.get_owner(self.user) self.assertEqual(type(user_as_owner), Owner) @@ -212,9 +204,7 @@ class OwnerModelTest(InvenTreeTestCase): self.assertEqual(group_as_owner, None) def test_api(self): - """ - Test user APIs - """ + """Test user APIs.""" self.client.logout() # not authed @@ -231,9 +221,7 @@ class OwnerModelTest(InvenTreeTestCase): # self.do_request(reverse('api-owner-detail', kwargs={'pk': self.user.id}), {}) def test_token(self): - """ - Test token mechanisms - """ + """Test token mechanisms.""" self.client.logout() token = Token.objects.filter(user=self.user) diff --git a/InvenTree/users/urls.py b/InvenTree/users/urls.py deleted file mode 100644 index 40a96afc6f..0000000000 --- a/InvenTree/users/urls.py +++ /dev/null @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/ci/check_api_endpoint.py b/ci/check_api_endpoint.py index dffacfcf3a..0bb5618696 100644 --- a/ci/check_api_endpoint.py +++ b/ci/check_api_endpoint.py @@ -1,6 +1,4 @@ -""" -Test that the root API endpoint is available. -""" +"""Test that the root API endpoint is available.""" import json diff --git a/ci/check_js_templates.py b/ci/check_js_templates.py index 84e5198982..5fd7521c55 100644 --- a/ci/check_js_templates.py +++ b/ci/check_js_templates.py @@ -1,9 +1,6 @@ -""" -Test that the "translated" javascript files to not contain template tags -which need to be determined at "run time". +"""Test that the "translated" javascript files to not contain template tags which need to be determined at "run time". This is because the "translated" javascript files are compiled into the "static" directory. - They should only contain template tags that render static information. """ diff --git a/ci/check_locale_files.py b/ci/check_locale_files.py index 808e3ab797..3bfbd66943 100644 --- a/ci/check_locale_files.py +++ b/ci/check_locale_files.py @@ -1,4 +1,4 @@ -""" Check that there are no database migration files which have not been committed. """ +"""Check that there are no database migration files which have not been committed.""" import subprocess import sys diff --git a/ci/check_migration_files.py b/ci/check_migration_files.py index 2f40d49a9d..24668350be 100644 --- a/ci/check_migration_files.py +++ b/ci/check_migration_files.py @@ -1,4 +1,4 @@ -""" Check that there are no database migration files which have not been committed. """ +"""Check that there are no database migration files which have not been committed.""" import subprocess import sys diff --git a/setup.cfg b/setup.cfg index 3aa96371ec..3386be2230 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,9 +15,14 @@ ignore = N806, # - N812 - lowercase imported as non-lowercase N812, + # - D202 - No blank lines allowed after function docstring + D202, # - D415 - First line should end with a period, question mark, or exclamation point D415, exclude = .git,__pycache__,*/migrations/*,*/lib/*,*/bin/*,*/media/*,*/static/*,InvenTree/plugins/* +per-file-ignores = + # Do not enforce docstring on __init__ + __init__.py: D104 max-complexity = 20 docstring-convention=google diff --git a/tasks.py b/tasks.py index 4ecebfc8e8..a9036e460b 100644 --- a/tasks.py +++ b/tasks.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - import json import os import pathlib @@ -10,10 +8,7 @@ from invoke import task def apps(): - """ - Returns a list of installed apps - """ - + """Returns a list of installed apps""" return [ 'build', 'common', @@ -30,8 +25,8 @@ def apps(): def localDir(): - """ - Returns the directory of *THIS* file. + """Returns the directory of *THIS* file. + Used to ensure that the various scripts always run in the correct directory. """ @@ -39,30 +34,22 @@ def localDir(): def managePyDir(): - """ - Returns the directory of the manage.py file - """ - + """Returns the directory of the manage.py file""" return os.path.join(localDir(), 'InvenTree') def managePyPath(): - """ - Return the path of the manage.py file - """ - + """Return the path of the manage.py file""" return os.path.join(managePyDir(), 'manage.py') def manage(c, cmd, pty=False): - """ - Runs a given command against django's "manage.py" script. + """Runs a given command against django's "manage.py" script. Args: c - Command line context cmd - django command to run """ - c.run('cd "{path}" && python3 manage.py {cmd}'.format( path=managePyDir(), cmd=cmd @@ -71,10 +58,7 @@ def manage(c, cmd, pty=False): @task def plugins(c): - """ - Installs all plugins as specified in 'plugins.txt' - """ - + """Installs all plugins as specified in 'plugins.txt'""" from InvenTree.InvenTree.config import get_plugin_file plugin_file = get_plugin_file() @@ -87,10 +71,7 @@ def plugins(c): @task(post=[plugins]) def install(c): - """ - Installs required python packages - """ - + """Installs required python packages""" print("Installing required python packages from 'requirements.txt'") # Install required Python packages with PIP @@ -99,10 +80,7 @@ def install(c): @task def setup_dev(c): - """ - Sets up everything needed for the dev enviroment - """ - + """Sets up everything needed for the dev enviroment""" print("Installing required python packages from 'requirements.txt'") # Install required Python packages with PIP @@ -117,82 +95,55 @@ def setup_dev(c): @task def shell(c): - """ - Open a python shell with access to the InvenTree database models. - """ - + """Open a python shell with access to the InvenTree database models.""" manage(c, 'shell', pty=True) @task def superuser(c): - """ - Create a superuser (admin) account for the database. - """ - + """Create a superuser/admin account for the database.""" manage(c, 'createsuperuser', pty=True) @task def check(c): - """ - Check validity of django codebase - """ - + """Check validity of django codebase""" manage(c, "check") @task def wait(c): - """ - Wait until the database connection is ready - """ - + """Wait until the database connection is ready""" return manage(c, "wait_for_db") @task(pre=[wait]) def worker(c): - """ - Run the InvenTree background worker process - """ - + """Run the InvenTree background worker process""" manage(c, 'qcluster', pty=True) @task def rebuild_models(c): - """ - Rebuild database models with MPTT structures - """ - + """Rebuild database models with MPTT structures""" manage(c, "rebuild_models", pty=True) @task def rebuild_thumbnails(c): - """ - Rebuild missing image thumbnails - """ - + """Rebuild missing image thumbnails""" manage(c, "rebuild_thumbnails", pty=True) @task def clean_settings(c): - """ - Clean the setting tables of old settings - """ - + """Clean the setting tables of old settings""" manage(c, "clean_settings") @task(help={'mail': 'mail of the user whos MFA should be disabled'}) def remove_mfa(c, mail=''): - """ - Remove MFA for a user - """ - + """Remove MFA for a user""" if not mail: print('You must provide a users mail') @@ -201,11 +152,10 @@ def remove_mfa(c, mail=''): @task(post=[rebuild_models, rebuild_thumbnails]) def migrate(c): - """ - Performs database migrations. + """Performs database migrations. + This is a critical step if the database schema have been altered! """ - print("Running InvenTree database migrations...") print("========================================") @@ -220,35 +170,28 @@ def migrate(c): @task def static(c): - """ - Copies required static files to the STATIC_ROOT directory, - as per Django requirements. - """ - + """Copies required static files to the STATIC_ROOT directory, as per Django requirements.""" manage(c, "prerender") manage(c, "collectstatic --no-input") @task def translate_stats(c): - """ - Collect translation stats. + """Collect translation stats. + The file generated from this is needed for the UI. """ - path = os.path.join('InvenTree', 'script', 'translation_stats.py') c.run(f'python3 {path}') @task(post=[translate_stats, static]) def translate(c): - """ - Rebuild translation source files. (Advanced use only!) + """Rebuild translation source files. (Advanced use only!) Note: This command should not be used on a local install, it is performed as part of the InvenTree translation toolchain. """ - # Translate applicable .py / .html / .js files manage(c, "makemessages --all -e py,html,js --no-wrap") manage(c, "compilemessages") @@ -256,8 +199,7 @@ def translate(c): @task(pre=[install, migrate, static, clean_settings]) def update(c): - """ - Update InvenTree installation. + """Update InvenTree installation. This command should be invoked after source code has been updated, e.g. downloading new code from GitHub. @@ -270,7 +212,6 @@ def update(c): - static - clean_settings """ - # Recompile the translation files (.mo) # We do not run 'invoke translate' here, as that will touch the source (.po) files too! manage(c, 'compilemessages', pty=True) @@ -278,19 +219,14 @@ def update(c): @task def style(c): - """ - Run PEP style checks against InvenTree sourcecode - """ - + """Run PEP style checks against InvenTree sourcecode""" print("Running PEP style checks...") c.run('flake8 InvenTree') @task def test(c, database=None): - """ - Run unit-tests for InvenTree codebase. - """ + """Run unit-tests for InvenTree codebase.""" # Run sanity check on the django install manage(c, 'check') @@ -300,13 +236,10 @@ def test(c, database=None): @task def coverage(c): - """ - Run code-coverage of the InvenTree codebase, - using the 'coverage' code-analysis tools. + """Run code-coverage of the InvenTree codebase, using the 'coverage' code-analysis tools. Generates a code coverage report (available in the htmlcov directory) """ - # Run sanity check on the django install manage(c, 'check') @@ -321,10 +254,7 @@ def coverage(c): def content_excludes(): - """ - Returns a list of content types to exclude from import/export - """ - + """Returns a list of content types to exclude from import/export""" excludes = [ "contenttypes", "auth.permission", @@ -351,10 +281,7 @@ def content_excludes(): @task(help={'filename': "Output filename (default = 'data.json')"}) def export_records(c, filename='data.json'): - """ - Export all database records to a file - """ - + """Export all database records to a file""" # Get an absolute path to the file if not os.path.isabs(filename): filename = os.path.join(localDir(), filename) @@ -403,10 +330,7 @@ def export_records(c, filename='data.json'): @task(help={'filename': 'Input filename', 'clear': 'Clear existing data before import'}, post=[rebuild_models, rebuild_thumbnails]) def import_records(c, filename='data.json', clear=False): - """ - Import database records from a file - """ - + """Import database records from a file""" # Get an absolute path to the supplied filename if not os.path.isabs(filename): filename = os.path.join(localDir(), filename) @@ -450,12 +374,10 @@ def import_records(c, filename='data.json', clear=False): @task def delete_data(c, force=False): - """ - Delete all database records! + """Delete all database records! Warning: This will REALLY delete all records in the database!! """ - print("Deleting all data from InvenTree database...") if force: @@ -466,8 +388,7 @@ def delete_data(c, force=False): @task(post=[rebuild_models, rebuild_thumbnails]) def import_fixtures(c): - """ - Import fixture data into the database. + """Import fixture data into the database. This command imports all existing test fixture data into the database. @@ -476,7 +397,6 @@ def import_fixtures(c): - Running this command may overwrite existing database data!! - Don't say you were not warned... """ - fixtures = [ # Build model 'build', @@ -515,20 +435,16 @@ def import_fixtures(c): @task(help={'address': 'Server address:port (default=127.0.0.1:8000)'}) def server(c, address="127.0.0.1:8000"): - """ - Launch a (deveopment) server using Django's in-built webserver. + """Launch a (deveopment) server using Django's in-built webserver. Note: This is *not* sufficient for a production installation. """ - manage(c, "runserver {address}".format(address=address), pty=True) @task(post=[translate_stats, static, server]) def test_translations(c): - """ - Add a fictional language to test if each component is ready for translations - """ + """Add a fictional language to test if each component is ready for translations""" import django from django.conf import settings @@ -595,8 +511,5 @@ def test_translations(c): @task def render_js_files(c): - """ - Render templated javascript files (used for static testing). - """ - + """Render templated javascript files (used for static testing).""" manage(c, "test InvenTree.ci_render_js")