diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2458c3f7f7..764691ede0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,18 +15,14 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: mixed-line-ending -- repo: https://github.com/pycqa/isort - rev: '5.12.0' - hooks: - - id: isort - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.11 hooks: -# - id: ruff-format -# args: [--preview] + - id: ruff-format + args: [--preview] - id: ruff args: [ - #--fix, + --fix, --preview ] - repo: https://github.com/jazzband/pip-tools diff --git a/InvenTree/InvenTree/admin.py b/InvenTree/InvenTree/admin.py index 12d6820db1..7c5c9f6cc2 100644 --- a/InvenTree/InvenTree/admin.py +++ b/InvenTree/InvenTree/admin.py @@ -30,14 +30,18 @@ class InvenTreeResource(ModelResource): using_transactions, collect_failed_rows, rollback_on_validation_errors=None, - **kwargs + **kwargs, ): """Override the default import_data_inner function to provide better error handling""" if len(dataset) > self.MAX_IMPORT_ROWS: - raise ImportExportError(f"Dataset contains too many rows (max {self.MAX_IMPORT_ROWS})") + raise ImportExportError( + f'Dataset contains too many rows (max {self.MAX_IMPORT_ROWS})' + ) if len(dataset.headers) > self.MAX_IMPORT_COLS: - raise ImportExportError(f"Dataset contains too many columns (max {self.MAX_IMPORT_COLS})") + raise ImportExportError( + f'Dataset contains too many columns (max {self.MAX_IMPORT_COLS})' + ) return super().import_data_inner( dataset, @@ -46,7 +50,7 @@ class InvenTreeResource(ModelResource): using_transactions, collect_failed_rows, rollback_on_validation_errors=rollback_on_validation_errors, - **kwargs + **kwargs, ) def export_resource(self, obj): @@ -75,10 +79,7 @@ class InvenTreeResource(ModelResource): """Return fields, with some common exclusions""" fields = super().get_fields(**kwargs) - fields_to_exclude = [ - 'metadata', - 'lft', 'rght', 'tree_id', 'level', - ] + fields_to_exclude = ['metadata', 'lft', 'rght', 'tree_id', 'level'] return [f for f in fields if f.column_name not in fields_to_exclude] diff --git a/InvenTree/InvenTree/api.py b/InvenTree/InvenTree/api.py index 2c47b338ad..3e449398b5 100644 --- a/InvenTree/InvenTree/api.py +++ b/InvenTree/InvenTree/api.py @@ -31,9 +31,7 @@ from .views import AjaxView class VersionView(APIView): """Simple JSON endpoint for InvenTree version information.""" - permission_classes = [ - permissions.IsAdminUser, - ] + permission_classes = [permissions.IsAdminUser] def get(self, request, *args, **kwargs): """Return information about the InvenTree server.""" @@ -47,20 +45,21 @@ class VersionView(APIView): 'commit_date': InvenTree.version.inventreeCommitDate(), 'commit_branch': InvenTree.version.inventreeBranch(), 'python': InvenTree.version.inventreePythonVersion(), - 'django': InvenTree.version.inventreeDjangoVersion() + 'django': InvenTree.version.inventreeDjangoVersion(), }, 'links': { 'doc': InvenTree.version.inventreeDocUrl(), 'code': InvenTree.version.inventreeGithubUrl(), 'credit': InvenTree.version.inventreeCreditsUrl(), 'app': InvenTree.version.inventreeAppUrl(), - 'bug': f'{InvenTree.version.inventreeGithubUrl()}/issues' - } + 'bug': f'{InvenTree.version.inventreeGithubUrl()}/issues', + }, }) class VersionSerializer(serializers.Serializer): """Serializer for a single version.""" + version = serializers.CharField() date = serializers.CharField() gh = serializers.CharField() @@ -69,16 +68,19 @@ class VersionSerializer(serializers.Serializer): class Meta: """Meta class for VersionSerializer.""" + fields = ['version', 'date', 'gh', 'text', 'latest'] class VersionApiSerializer(serializers.Serializer): """Serializer for the version api endpoint.""" + VersionSerializer(many=True) class VersionTextView(ListAPI): """Simple JSON endpoint for InvenTree version text.""" + permission_classes = [permissions.IsAdminUser] @extend_schema(responses={200: OpenApiResponse(response=VersionApiSerializer)}) @@ -119,10 +121,10 @@ class InfoView(AjaxView): 'debug_mode': settings.DEBUG, 'docker_mode': settings.DOCKER, 'system_health': check_system_health() if is_staff else None, - 'database': InvenTree.version.inventreeDatabase()if is_staff else None, + 'database': InvenTree.version.inventreeDatabase() if is_staff else None, 'platform': InvenTree.version.inventreePlatform() if is_staff else None, 'installer': InvenTree.version.inventreeInstaller() if is_staff else None, - 'target': InvenTree.version.inventreeTarget()if is_staff else None, + 'target': InvenTree.version.inventreeTarget() if is_staff else None, } return JsonResponse(data) @@ -130,7 +132,9 @@ class InfoView(AjaxView): def check_auth_header(self, request): """Check if user is authenticated via a token in the header.""" # TODO @matmair: remove after refacgtor of Token check is done - headers = request.headers.get('Authorization', request.headers.get('authorization')) + headers = request.headers.get( + 'Authorization', request.headers.get('authorization') + ) if not headers: return False @@ -160,7 +164,7 @@ class NotFoundView(AjaxView): 'detail': _('API endpoint not found'), 'url': request.build_absolute_uri(), }, - status=404 + status=404, ) def options(self, request, *args, **kwargs): @@ -228,24 +232,25 @@ class BulkDeleteMixin: if not items and not filters: raise ValidationError({ - "non_field_errors": ["List of items or filters must be provided for bulk deletion"], + 'non_field_errors': [ + 'List of items or filters must be provided for bulk deletion' + ] }) if items and type(items) is not list: raise ValidationError({ - "items": ["'items' must be supplied as a list object"] + 'items': ["'items' must be supplied as a list object"] }) if filters and type(filters) is not dict: raise ValidationError({ - "filters": ["'filters' must be supplied as a dict object"] + 'filters': ["'filters' must be supplied as a dict object"] }) # Keep track of how many items we deleted n_deleted = 0 with transaction.atomic(): - # Start with *all* models and perform basic filtering queryset = model.objects.all() queryset = self.filter_delete_queryset(queryset, request) @@ -261,16 +266,12 @@ class BulkDeleteMixin: n_deleted = queryset.count() queryset.delete() - return Response( - { - 'success': f"Deleted {n_deleted} items", - }, - status=204 - ) + return Response({'success': f'Deleted {n_deleted} items'}, status=204) class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI): """Custom API endpoint which provides BulkDelete functionality in addition to List and Create""" + ... @@ -307,24 +308,17 @@ class APIDownloadMixin: def download_queryset(self, queryset, export_format): """This function must be implemented to provide a downloadFile request.""" - raise NotImplementedError("download_queryset method not implemented!") + raise NotImplementedError('download_queryset method not implemented!') class AttachmentMixin: """Mixin for creating attachment objects, and ensuring the user information is saved correctly.""" - permission_classes = [ - permissions.IsAuthenticated, - RolePermission, - ] + permission_classes = [permissions.IsAuthenticated, RolePermission] filter_backends = SEARCH_ORDER_FILTER - search_fields = [ - 'attachment', - 'comment', - 'link', - ] + search_fields = ['attachment', 'comment', 'link'] def perform_create(self, serializer): """Save the user information when a file is uploaded.""" @@ -342,9 +336,7 @@ class APISearchView(APIView): Is much more efficient and simplifies code! """ - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] def get_result_types(self): """Construct a list of search types we can return""" @@ -384,14 +376,11 @@ class APISearchView(APIView): } if 'search' not in data: - raise ValidationError({ - 'search': 'Search term must be provided', - }) + raise ValidationError({'search': 'Search term must be provided'}) for key, cls in self.get_result_types().items(): # Only return results which are specifically requested if key in data: - params = data[key] for k, v in pass_through_params.items(): @@ -418,16 +407,18 @@ class APISearchView(APIView): table = f'{app_label}_{model_name}' try: - if users.models.RuleSet.check_table_permission(request.user, table, 'view'): + if users.models.RuleSet.check_table_permission( + request.user, table, 'view' + ): results[key] = view.list(request, *args, **kwargs).data else: results[key] = { - 'error': _('User does not have permission to view this model') + 'error': _( + 'User does not have permission to view this model' + ) } except Exception as exc: - results[key] = { - 'error': str(exc) - } + results[key] = {'error': str(exc)} return Response(results) @@ -442,7 +433,9 @@ class MetadataView(RetrieveUpdateAPI): model = self.kwargs.get(self.MODEL_REF, None) if model is None: - raise ValidationError(f"MetadataView called without '{self.MODEL_REF}' parameter") + raise ValidationError( + f"MetadataView called without '{self.MODEL_REF}' parameter" + ) return model diff --git a/InvenTree/InvenTree/api_version.py b/InvenTree/InvenTree/api_version.py index 1a869bdd51..696a125e67 100644 --- a/InvenTree/InvenTree/api_version.py +++ b/InvenTree/InvenTree/api_version.py @@ -1,6 +1,5 @@ """InvenTree API version information.""" - # InvenTree API version INVENTREE_API_VERSION = 160 """Increment this API version number whenever there is a significant change to the API that any clients need to know about.""" diff --git a/InvenTree/InvenTree/apps.py b/InvenTree/InvenTree/apps.py index 19b8aad4b4..56eb653f27 100644 --- a/InvenTree/InvenTree/apps.py +++ b/InvenTree/InvenTree/apps.py @@ -16,11 +16,12 @@ import InvenTree.ready import InvenTree.tasks from InvenTree.config import get_setting -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') class InvenTreeConfig(AppConfig): """AppConfig for inventree app.""" + name = 'InvenTree' def ready(self): @@ -36,7 +37,10 @@ class InvenTreeConfig(AppConfig): - Adding users set in the current environment """ # skip loading if plugin registry is not loaded or we run in a background thread - if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread(): + if ( + not InvenTree.ready.isPluginRegistryLoaded() + or not InvenTree.ready.isInMainThread() + ): return # Skip if running migrations @@ -44,7 +48,6 @@ class InvenTreeConfig(AppConfig): return if InvenTree.ready.canAppAccessDatabase() or settings.TESTING_ENV: - self.remove_obsolete_tasks() self.collect_tasks() @@ -81,11 +84,11 @@ class InvenTreeConfig(AppConfig): try: Schedule.objects.filter(func__in=obsolete).delete() except Exception: - logger.exception("Failed to remove obsolete tasks - database not ready") + logger.exception('Failed to remove obsolete tasks - database not ready') def start_background_tasks(self): """Start all background tests for InvenTree.""" - logger.info("Starting background tasks...") + logger.info('Starting background tasks...') from django_q.models import Schedule @@ -102,15 +105,16 @@ class InvenTreeConfig(AppConfig): tasks = InvenTree.tasks.tasks.task_list for task in tasks: - ref_name = f'{task.func.__module__}.{task.func.__name__}' if ref_name in existing_tasks.keys(): # This task already exists - update the details if required existing_task = existing_tasks[ref_name] - if existing_task.schedule_type != task.interval or existing_task.minutes != task.minutes: - + if ( + existing_task.schedule_type != task.interval + or existing_task.minutes != task.minutes + ): existing_task.schedule_type = task.interval existing_task.minutes = task.minutes tasks_to_update.append(existing_task) @@ -128,20 +132,17 @@ class InvenTreeConfig(AppConfig): if len(tasks_to_create) > 0: Schedule.objects.bulk_create(tasks_to_create) - logger.info("Created %s new scheduled tasks", len(tasks_to_create)) + logger.info('Created %s new scheduled tasks', len(tasks_to_create)) if len(tasks_to_update) > 0: Schedule.objects.bulk_update(tasks_to_update, ['schedule_type', 'minutes']) - logger.info("Updated %s existing scheduled tasks", len(tasks_to_update)) + logger.info('Updated %s existing scheduled tasks', len(tasks_to_update)) # Put at least one task onto the background worker stack, # which will be processed as soon as the worker comes online - InvenTree.tasks.offload_task( - InvenTree.tasks.heartbeat, - force_async=True, - ) + InvenTree.tasks.offload_task(InvenTree.tasks.heartbeat, force_async=True) - logger.info("Started %s scheduled background tasks...", len(tasks)) + logger.info('Started %s scheduled background tasks...', len(tasks)) def collect_tasks(self): """Collect all background tasks.""" @@ -153,7 +154,7 @@ class InvenTreeConfig(AppConfig): try: import_module(f'{app.module.__package__}.tasks') except Exception as e: # pragma: no cover - logger.exception("Error loading tasks for %s: %s", app_name, e) + logger.exception('Error loading tasks for %s: %s', app_name, e) def update_exchange_rates(self): # pragma: no cover """Update exchange rates each time the server is started. @@ -184,16 +185,20 @@ class InvenTreeConfig(AppConfig): if last_update is None: # Never been updated - logger.info("Exchange backend has never been updated") + logger.info('Exchange backend has never been updated') update = True # Backend currency has changed? if base_currency != backend.base_currency: - logger.info("Base currency changed from %s to %s", backend.base_currency, base_currency) + logger.info( + 'Base currency changed from %s to %s', + backend.base_currency, + base_currency, + ) update = True - except (ExchangeBackend.DoesNotExist): - logger.info("Exchange backend not found - updating") + except ExchangeBackend.DoesNotExist: + logger.info('Exchange backend not found - updating') update = True except Exception: @@ -204,9 +209,9 @@ class InvenTreeConfig(AppConfig): try: update_exchange_rates() except OperationalError: - logger.warning("Could not update exchange rates - database not ready") + logger.warning('Could not update exchange rates - database not ready') except Exception as e: - logger.exception("Error updating exchange rates: %s (%s)", e, type(e)) + logger.exception('Error updating exchange rates: %s (%s)', e, type(e)) def add_user_on_startup(self): """Add a user on startup.""" @@ -218,7 +223,9 @@ class InvenTreeConfig(AppConfig): add_user = get_setting('INVENTREE_ADMIN_USER', 'admin_user') add_email = get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email') add_password = get_setting('INVENTREE_ADMIN_PASSWORD', 'admin_password') - add_password_file = get_setting("INVENTREE_ADMIN_PASSWORD_FILE", "admin_password_file", None) + add_password_file = get_setting( + 'INVENTREE_ADMIN_PASSWORD_FILE', 'admin_password_file', None + ) # check if all values are present set_variables = 0 @@ -239,7 +246,9 @@ class InvenTreeConfig(AppConfig): # if a password file is present, do not warn - will be handled later if add_password_file: return - logger.warning('Not all required settings for adding a user on startup are present:\nINVENTREE_ADMIN_USER, INVENTREE_ADMIN_EMAIL, INVENTREE_ADMIN_PASSWORD') + logger.warning( + 'Not all required settings for adding a user on startup are present:\nINVENTREE_ADMIN_USER, INVENTREE_ADMIN_EMAIL, INVENTREE_ADMIN_PASSWORD' + ) return # good to go -> create user @@ -253,9 +262,11 @@ class InvenTreeConfig(AppConfig): try: with transaction.atomic(): if user.objects.filter(username=add_user).exists(): - logger.info("User %s already exists - skipping creation", add_user) + logger.info('User %s already exists - skipping creation', add_user) else: - new_user = user.objects.create_superuser(add_user, add_email, add_password) + new_user = user.objects.create_superuser( + add_user, add_email, add_password + ) logger.info('User %s was created!', str(new_user)) except IntegrityError: logger.warning('The user "%s" could not be created', add_user) @@ -263,12 +274,12 @@ class InvenTreeConfig(AppConfig): def add_user_from_file(self): """Add the superuser from a file.""" # stop if checks were already created - if hasattr(settings, "USER_ADDED_FILE") and settings.USER_ADDED_FILE: + if hasattr(settings, 'USER_ADDED_FILE') and settings.USER_ADDED_FILE: return # get values add_password_file = get_setting( - "INVENTREE_ADMIN_PASSWORD_FILE", "admin_password_file", None + 'INVENTREE_ADMIN_PASSWORD_FILE', 'admin_password_file', None ) # no variable set -> do not try anything @@ -284,7 +295,11 @@ class InvenTreeConfig(AppConfig): return # good to go -> create user - self._create_admin_user(get_setting('INVENTREE_ADMIN_USER', 'admin_user', 'admin'), get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email', ''), add_password_file.read_text(encoding="utf-8")) + self._create_admin_user( + get_setting('INVENTREE_ADMIN_USER', 'admin_user', 'admin'), + get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email', ''), + add_password_file.read_text(encoding='utf-8'), + ) # do not try again settings.USER_ADDED_FILE = True diff --git a/InvenTree/InvenTree/ci_render_js.py b/InvenTree/InvenTree/ci_render_js.py index f221349e08..6aa3847895 100644 --- a/InvenTree/InvenTree/ci_render_js.py +++ b/InvenTree/InvenTree/ci_render_js.py @@ -24,22 +24,14 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover here = os.path.abspath(os.path.dirname(__file__)) - output_dir = os.path.join( - here, - '..', - '..', - 'js_tmp', - ) + output_dir = os.path.join(here, '..', '..', 'js_tmp') output_dir = os.path.abspath(output_dir) if not os.path.exists(output_dir): os.mkdir(output_dir) - output_file = os.path.join( - output_dir, - filename, - ) + output_file = os.path.join(output_dir, filename) with open(output_file, 'wb') as output: output.write(response.content) @@ -48,12 +40,7 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover """Download files in directory.""" here = os.path.abspath(os.path.dirname(__file__)) - js_template_dir = os.path.join( - here, - '..', - 'templates', - 'js', - ) + js_template_dir = os.path.join(here, '..', 'templates', 'js') directory = os.path.join(js_template_dir, subdir) @@ -76,9 +63,9 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover """Look for all javascript files.""" n = 0 - print("Rendering javascript files...") + print('Rendering javascript files...') n += self.download_files('translated', '/js/i18n') n += self.download_files('dynamic', '/js/dynamic') - print(f"Rendered {n} javascript files.") + print(f'Rendered {n} javascript files.') diff --git a/InvenTree/InvenTree/config.py b/InvenTree/InvenTree/config.py index 204b4376a0..40c849bb36 100644 --- a/InvenTree/InvenTree/config.py +++ b/InvenTree/InvenTree/config.py @@ -51,7 +51,11 @@ def to_dict(value): try: return json.loads(value) except Exception as error: - logger.exception("Failed to parse value '%s' as JSON with error %s. Ensure value is a valid JSON string.", value, error) + logger.exception( + "Failed to parse value '%s' as JSON with error %s. Ensure value is a valid JSON string.", + value, + error, + ) return {} @@ -90,12 +94,14 @@ def get_config_file(create=True) -> Path: cfg_filename = base_dir.joinpath('config.yaml').resolve() if not cfg_filename.exists() and create: - print("InvenTree configuration file 'config.yaml' not found - creating default file") + print( + "InvenTree configuration file 'config.yaml' not found - creating default file" + ) ensure_dir(cfg_filename.parent) - cfg_template = base_dir.joinpath("config_template.yaml") + cfg_template = base_dir.joinpath('config_template.yaml') shutil.copyfile(cfg_template, cfg_filename) - print(f"Created config file {cfg_filename}") + print(f'Created config file {cfg_filename}') return cfg_filename @@ -153,7 +159,13 @@ def do_typecast(value, type, var_name=None): return val except Exception as error: if var_name: - logger.exception("Failed to typecast '%s' with value '%s' to type '%s' with error %s", var_name, value, type, error) + logger.exception( + "Failed to typecast '%s' with value '%s' to type '%s' with error %s", + var_name, + value, + type, + error, + ) return value @@ -174,7 +186,12 @@ def get_setting(env_var=None, config_key=None, default_value=None, typecast=None def set_metadata(source: str): """Set lookup metadata for the setting.""" key = env_var or config_key - CONFIG_LOOKUPS[key] = {'env_var': env_var, 'config_key': config_key, 'source': source, 'accessed': datetime.datetime.now()} + CONFIG_LOOKUPS[key] = { + 'env_var': env_var, + 'config_key': config_key, + 'source': source, + 'accessed': datetime.datetime.now(), + } # First, try to load from the environment variables if env_var is not None: @@ -192,7 +209,6 @@ def get_setting(env_var=None, config_key=None, default_value=None, typecast=None # Hack to allow 'path traversal' in configuration file for key in config_key.strip().split('.'): - if type(cfg_data) is not dict or key not in cfg_data: result = None break @@ -276,12 +292,16 @@ def get_plugin_file(): plugin_file = Path(plugin_file) if not plugin_file.exists(): - logger.warning("Plugin configuration file does not exist - creating default file") + logger.warning( + 'Plugin configuration file does not exist - creating default file' + ) logger.info("Creating plugin file at '%s'", plugin_file) ensure_dir(plugin_file.parent) # If opening the file fails (no write permission, for example), then this will throw an error - plugin_file.write_text("# InvenTree Plugins (uses PIP framework to install)\n\n") + plugin_file.write_text( + '# InvenTree Plugins (uses PIP framework to install)\n\n' + ) return plugin_file @@ -303,7 +323,7 @@ def get_secret_key(): """ # Look for environment variable if secret_key := get_setting('INVENTREE_SECRET_KEY', 'secret_key'): - logger.info("SECRET_KEY loaded by INVENTREE_SECRET_KEY") # pragma: no cover + logger.info('SECRET_KEY loaded by INVENTREE_SECRET_KEY') # pragma: no cover return secret_key # Look for secret key file @@ -311,7 +331,7 @@ def get_secret_key(): secret_key_file = Path(secret_key_file).resolve() else: # Default location for secret key file - secret_key_file = get_base_dir().joinpath("secret_key.txt").resolve() + secret_key_file = get_base_dir().joinpath('secret_key.txt').resolve() if not secret_key_file.exists(): logger.info("Generating random key file at '%s'", secret_key_file) @@ -329,7 +349,9 @@ def get_secret_key(): return key_data -def get_custom_file(env_ref: str, conf_ref: str, log_ref: str, lookup_media: bool = False): +def get_custom_file( + env_ref: str, conf_ref: str, log_ref: str, lookup_media: bool = False +): """Returns the checked path to a custom file. Set lookup_media to True to also search in the media folder. @@ -345,12 +367,17 @@ def get_custom_file(env_ref: str, conf_ref: str, log_ref: str, lookup_media: boo static_storage = StaticFilesStorage() if static_storage.exists(value): - logger.info("Loading %s from %s directory: %s", log_ref, 'static', value) + logger.info('Loading %s from %s directory: %s', log_ref, 'static', value) elif lookup_media and default_storage.exists(value): - logger.info("Loading %s from %s directory: %s", log_ref, 'media', value) + logger.info('Loading %s from %s directory: %s', log_ref, 'media', value) else: add_dir_str = ' or media' if lookup_media else '' - logger.warning("The %s file '%s' could not be found in the static %s directories", log_ref, value, add_dir_str) + logger.warning( + "The %s file '%s' could not be found in the static %s directories", + log_ref, + value, + add_dir_str, + ) value = False return value @@ -364,16 +391,21 @@ def get_frontend_settings(debug=True): """ # Legacy settings - pui_settings = get_setting('INVENTREE_PUI_SETTINGS', 'pui_settings', {}, typecast=dict) + pui_settings = get_setting( + 'INVENTREE_PUI_SETTINGS', 'pui_settings', {}, typecast=dict + ) if len(pui_settings) > 0: warnings.warn( "The 'INVENTREE_PUI_SETTINGS' key is deprecated. Please use 'INVENTREE_FRONTEND_SETTINGS' instead", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) # New settings - frontend_settings = get_setting('INVENTREE_FRONTEND_SETTINGS', 'frontend_settings', {}, typecast=dict) + frontend_settings = get_setting( + 'INVENTREE_FRONTEND_SETTINGS', 'frontend_settings', {}, typecast=dict + ) # Merge settings settings = {**pui_settings, **frontend_settings} @@ -385,10 +417,13 @@ def get_frontend_settings(debug=True): if base_url: warnings.warn( "The 'INVENTREE_PUI_URL_BASE' key is deprecated. Please use 'INVENTREE_FRONTEND_URL_BASE' instead", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) else: - base_url = get_setting('INVENTREE_FRONTEND_URL_BASE', 'frontend_url_base', 'platform') + base_url = get_setting( + 'INVENTREE_FRONTEND_URL_BASE', 'frontend_url_base', 'platform' + ) settings['base_url'] = base_url diff --git a/InvenTree/InvenTree/context.py b/InvenTree/InvenTree/context.py index 4518d2b951..e25a27d668 100644 --- a/InvenTree/InvenTree/context.py +++ b/InvenTree/InvenTree/context.py @@ -31,9 +31,7 @@ def health_status(request): } # The following keys are required to denote system health - health_keys = [ - 'django_q_running', - ] + health_keys = ['django_q_running'] all_healthy = True @@ -72,11 +70,9 @@ def user_roles(request): """ user = request.user - roles = { - } + roles = {} for role in RuleSet.RULESET_MODELS.keys(): - permissions = {} for perm in ['view', 'add', 'change', 'delete']: diff --git a/InvenTree/InvenTree/conversion.py b/InvenTree/InvenTree/conversion.py index 6324848693..b87d98c8a5 100644 --- a/InvenTree/InvenTree/conversion.py +++ b/InvenTree/InvenTree/conversion.py @@ -29,6 +29,7 @@ def reload_unit_registry(): This function is called at startup, and whenever the database is updated. """ import time + t_start = time.time() global _unit_registry @@ -52,7 +53,9 @@ def reload_unit_registry(): try: reg.define(cu.fmt_string()) except Exception as e: - logger.exception('Failed to load custom unit: %s - %s', cu.fmt_string(), e) + logger.exception( + 'Failed to load custom unit: %s - %s', cu.fmt_string(), e + ) # Once custom units are loaded, save registry _unit_registry = reg @@ -124,7 +127,7 @@ def convert_physical_value(value: str, unit: str = None, strip_units=True): if unit: raise ValidationError(_(f'Could not convert {original} to {unit}')) else: - raise ValidationError(_("Invalid quantity supplied")) + raise ValidationError(_('Invalid quantity supplied')) # Calculate the "magnitude" of the value, as a float # If the value is specified strangely (e.g. as a fraction or a dozen), this can cause issues diff --git a/InvenTree/InvenTree/email.py b/InvenTree/InvenTree/email.py index 9552262dd1..a5f7b283df 100644 --- a/InvenTree/InvenTree/email.py +++ b/InvenTree/InvenTree/email.py @@ -30,22 +30,22 @@ def is_email_configured(): # Display warning unless in test mode if not testing: # pragma: no cover - logger.debug("EMAIL_HOST is not configured") + logger.debug('EMAIL_HOST is not configured') # Display warning unless in test mode if not settings.EMAIL_HOST_USER and not testing: # pragma: no cover - logger.debug("EMAIL_HOST_USER is not configured") + logger.debug('EMAIL_HOST_USER is not configured') # Display warning unless in test mode if not settings.EMAIL_HOST_PASSWORD and testing: # pragma: no cover - logger.debug("EMAIL_HOST_PASSWORD is not configured") + logger.debug('EMAIL_HOST_PASSWORD is not configured') # Email sender must be configured if not settings.DEFAULT_FROM_EMAIL: configured = False if not testing: # pragma: no cover - logger.debug("DEFAULT_FROM_EMAIL is not configured") + logger.debug('DEFAULT_FROM_EMAIL is not configured') return configured @@ -75,7 +75,7 @@ def send_email(subject, body, recipients, from_email=None, html_message=None): if settings.TESTING: from_email = 'from@test.com' else: - logger.error("send_email failed: DEFAULT_FROM_EMAIL not specified") + logger.error('send_email failed: DEFAULT_FROM_EMAIL not specified') return InvenTree.tasks.offload_task( @@ -85,5 +85,5 @@ def send_email(subject, body, recipients, from_email=None, html_message=None): from_email, recipients, fail_silently=False, - html_message=html_message + html_message=html_message, ) diff --git a/InvenTree/InvenTree/exceptions.py b/InvenTree/InvenTree/exceptions.py index 75cbaff125..b2f014390d 100644 --- a/InvenTree/InvenTree/exceptions.py +++ b/InvenTree/InvenTree/exceptions.py @@ -86,7 +86,7 @@ def exception_handler(exc, context): # If in DEBUG mode, provide error information in the response error_detail = str(exc) else: - error_detail = _("Error details can be found in the admin panel") + error_detail = _('Error details can be found in the admin panel') response_data = { 'error': type(exc).__name__, diff --git a/InvenTree/InvenTree/exchange.py b/InvenTree/InvenTree/exchange.py index b3c6262d84..9a0d7ecf12 100644 --- a/InvenTree/InvenTree/exchange.py +++ b/InvenTree/InvenTree/exchange.py @@ -18,7 +18,7 @@ class InvenTreeExchange(SimpleExchangeBackend): Uses the plugin system to actually fetch the rates from an external API. """ - name = "InvenTreeExchange" + name = 'InvenTreeExchange' def get_rates(self, **kwargs) -> None: """Set the requested currency codes and get rates.""" @@ -44,7 +44,9 @@ class InvenTreeExchange(SimpleExchangeBackend): plugin = plugins[0] if not plugin: - logger.warning('No active currency exchange plugins found - skipping update') + logger.warning( + 'No active currency exchange plugins found - skipping update' + ) return {} logger.info("Running exchange rate update using plugin '%s'", plugin.name) @@ -53,16 +55,22 @@ class InvenTreeExchange(SimpleExchangeBackend): try: rates = plugin.update_exchange_rates(base_currency, symbols) except Exception as exc: - logger.exception("Exchange rate update failed: %s", exc) + logger.exception('Exchange rate update failed: %s', exc) return {} if not rates: - logger.warning("Exchange rate update failed - no data returned from plugin %s", slug) + logger.warning( + 'Exchange rate update failed - no data returned from plugin %s', slug + ) return {} # Update exchange rates based on returned data if type(rates) is not dict: - logger.warning("Invalid exchange rate data returned from plugin %s (type %s)", slug, type(rates)) + logger.warning( + 'Invalid exchange rate data returned from plugin %s (type %s)', + slug, + type(rates), + ) return {} # Ensure base currency is provided @@ -73,14 +81,20 @@ class InvenTreeExchange(SimpleExchangeBackend): @atomic def update_rates(self, base_currency=None, **kwargs): """Call to update all exchange rates""" - backend, _ = ExchangeBackend.objects.update_or_create(name=self.name, defaults={"base_currency": base_currency}) + backend, _ = ExchangeBackend.objects.update_or_create( + name=self.name, defaults={'base_currency': base_currency} + ) if base_currency is None: base_currency = currency_code_default() symbols = currency_codes() - logger.info("Updating exchange rates for %s (%s currencies)", base_currency, len(symbols)) + logger.info( + 'Updating exchange rates for %s (%s currencies)', + base_currency, + len(symbols), + ) # Fetch new rates from the backend # If the backend fails, the existing rates will not be updated @@ -95,6 +109,8 @@ class InvenTreeExchange(SimpleExchangeBackend): for currency, amount in rates.items() ]) else: - logger.info("No exchange rates returned from backend - currencies not updated") + logger.info( + 'No exchange rates returned from backend - currencies not updated' + ) - logger.info("Updated exchange rates for %s", base_currency) + logger.info('Updated exchange rates for %s', base_currency) diff --git a/InvenTree/InvenTree/fields.py b/InvenTree/InvenTree/fields.py index 70ca133a80..b91bdc39a4 100644 --- a/InvenTree/InvenTree/fields.py +++ b/InvenTree/InvenTree/fields.py @@ -35,7 +35,9 @@ class InvenTreeRestURLField(RestURLField): import common.models - strict_urls = common.models.InvenTreeSetting.get_setting('INVENTREE_STRICT_URLS', True, cache=False) + strict_urls = common.models.InvenTreeSetting.get_setting( + 'INVENTREE_STRICT_URLS', True, cache=False + ) if not strict_urls and data is not empty: if '://' not in data: @@ -97,11 +99,8 @@ class InvenTreeModelMoneyField(ModelMoneyField): # If no validators are provided, add some "standard" ones if len(validators) == 0: - if not allow_negative: - validators.append( - MinMoneyValidator(0), - ) + validators.append(MinMoneyValidator(0)) kwargs['validators'] = validators @@ -144,11 +143,7 @@ class DatePickerFormField(forms.DateField): required = kwargs.get('required', False) initial = kwargs.get('initial', None) - widget = forms.DateInput( - attrs={ - 'type': 'date', - } - ) + widget = forms.DateInput(attrs={'type': 'date'}) forms.DateField.__init__( self, @@ -156,7 +151,7 @@ class DatePickerFormField(forms.DateField): initial=initial, help_text=help_text, widget=widget, - label=label + label=label, ) diff --git a/InvenTree/InvenTree/filters.py b/InvenTree/InvenTree/filters.py index cea9b01b03..a3866cd0a0 100644 --- a/InvenTree/InvenTree/filters.py +++ b/InvenTree/InvenTree/filters.py @@ -36,7 +36,9 @@ class InvenTreeSearchFilter(filters.SearchFilter): The following query params are available to 'augment' the search (in decreasing order of priority) - search_regex: If True, search is performed on 'regex' comparison """ - regex = InvenTree.helpers.str2bool(request.query_params.get('search_regex', False)) + regex = InvenTree.helpers.str2bool( + request.query_params.get('search_regex', False) + ) search_fields = super().get_search_fields(view, request) @@ -56,7 +58,9 @@ class InvenTreeSearchFilter(filters.SearchFilter): Depending on the request parameters, we may "augment" these somewhat """ - whole = InvenTree.helpers.str2bool(request.query_params.get('search_whole', False)) + whole = InvenTree.helpers.str2bool( + request.query_params.get('search_whole', False) + ) terms = [] @@ -72,7 +76,7 @@ class InvenTreeSearchFilter(filters.SearchFilter): if whole: # Wrap the search term to enable word-boundary matching - term = r"\y" + term + r"\y" + term = r'\y' + term + r'\y' terms.append(term) @@ -110,7 +114,6 @@ class InvenTreeOrderingFilter(filters.OrderingFilter): ordering = [] for field in ordering_initial: - reverse = field.startswith('-') if reverse: @@ -164,7 +167,4 @@ SEARCH_ORDER_FILTER_ALIAS = [ InvenTreeOrderingFilter, ] -ORDER_FILTER = [ - rest_filters.DjangoFilterBackend, - filters.OrderingFilter, -] +ORDER_FILTER = [rest_filters.DjangoFilterBackend, filters.OrderingFilter] diff --git a/InvenTree/InvenTree/format.py b/InvenTree/InvenTree/format.py index ab4267959f..aa733cd094 100644 --- a/InvenTree/InvenTree/format.py +++ b/InvenTree/InvenTree/format.py @@ -36,10 +36,7 @@ def parse_format_string(fmt_string: str) -> dict: else: seen_groups.add(name) - info[group[1]] = { - 'format': group[1], - 'prefix': group[0], - } + info[group[1]] = {'format': group[1], 'prefix': group[0]} return info @@ -67,7 +64,7 @@ def construct_format_regex(fmt_string: str) -> str: Raises: ValueError: Format string is invalid """ - pattern = "^" + pattern = '^' for group in string.Formatter().parse(fmt_string): prefix = group[0] # Prefix (literal text appearing before this group) @@ -75,9 +72,23 @@ def construct_format_regex(fmt_string: str) -> str: format = group[2] # Format specifier e.g :04d rep = [ - '+', '-', '.', - '{', '}', '(', ')', - '^', '$', '~', '!', '@', ':', ';', '|', '\'', '"', + '+', + '-', + '.', + '{', + '}', + '(', + ')', + '^', + '$', + '~', + '!', + '@', + ':', + ';', + '|', + "'", + '"', ] # Escape any special regex characters @@ -94,7 +105,6 @@ def construct_format_regex(fmt_string: str) -> str: # Add a named capture group for the format entry if name: - # Check if integer values are required if format.endswith('d'): chr = '\d' @@ -105,9 +115,9 @@ def construct_format_regex(fmt_string: str) -> str: # TODO: Introspect required width w = '+' - pattern += f"(?P<{name}>{chr}{w})" + pattern += f'(?P<{name}>{chr}{w})' - pattern += "$" + pattern += '$' return pattern @@ -161,7 +171,9 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str: result = re.match(pattern, value) if not result: - raise ValueError(_("Provided value does not match required pattern: ") + fmt_string) + raise ValueError( + _('Provided value does not match required pattern: ') + fmt_string + ) # And return the value we are interested in # Note: This will raise an IndexError if the named group was not matched @@ -186,7 +198,7 @@ def format_money(money: Money, decimal_places: int = None, format: str = None) - if format: pattern = parse_pattern(format) else: - pattern = locale.currency_formats["standard"] + pattern = locale.currency_formats['standard'] if decimal_places is not None: pattern.frac_prec = (decimal_places, decimal_places) diff --git a/InvenTree/InvenTree/forms.py b/InvenTree/InvenTree/forms.py index 3532218d02..c33eaa8fae 100644 --- a/InvenTree/InvenTree/forms.py +++ b/InvenTree/InvenTree/forms.py @@ -17,8 +17,7 @@ from allauth.core.exceptions import ImmediateHttpResponse from allauth.socialaccount.adapter import DefaultSocialAccountAdapter from allauth_2fa.adapter import OTPAdapter from allauth_2fa.utils import user_has_valid_totp_device -from crispy_forms.bootstrap import (AppendedText, PrependedAppendedText, - PrependedText) +from crispy_forms.bootstrap import AppendedText, PrependedAppendedText, PrependedText from crispy_forms.helper import FormHelper from crispy_forms.layout import Field, Layout from dj_rest_auth.registration.serializers import RegisterSerializer @@ -80,31 +79,19 @@ class HelperForm(forms.ModelForm): field, prepended_text=prefix, appended_text=suffix, - placeholder=placeholder + placeholder=placeholder, ) ) ) elif prefix: layouts.append( - Field( - PrependedText( - field, - prefix, - placeholder=placeholder - ) - ) + Field(PrependedText(field, prefix, placeholder=placeholder)) ) elif suffix: layouts.append( - Field( - AppendedText( - field, - suffix, - placeholder=placeholder - ) - ) + Field(AppendedText(field, suffix, placeholder=placeholder)) ) else: @@ -120,10 +107,7 @@ class EditUserForm(HelperForm): """Metaclass options.""" model = User - fields = [ - 'first_name', - 'last_name', - ] + fields = ['first_name', 'last_name'] class SetPasswordForm(HelperForm): @@ -133,11 +117,7 @@ class SetPasswordForm(HelperForm): """Metaclass options.""" model = User - fields = [ - 'enter_password', - 'confirm_password', - 'old_password', - ] + fields = ['enter_password', 'confirm_password', 'old_password'] enter_password = forms.CharField( max_length=100, @@ -146,7 +126,7 @@ class SetPasswordForm(HelperForm): initial='', widget=forms.PasswordInput(attrs={'autocomplete': 'off'}), label=_('Enter password'), - help_text=_('Enter new password') + help_text=_('Enter new password'), ) confirm_password = forms.CharField( @@ -156,14 +136,16 @@ class SetPasswordForm(HelperForm): initial='', widget=forms.PasswordInput(attrs={'autocomplete': 'off'}), label=_('Confirm password'), - help_text=_('Confirm new password') + help_text=_('Confirm new password'), ) old_password = forms.CharField( - label=_("Old password"), + label=_('Old password'), strip=False, required=False, - widget=forms.PasswordInput(attrs={'autocomplete': 'current-password', 'autofocus': True}), + widget=forms.PasswordInput( + attrs={'autocomplete': 'current-password', 'autofocus': True} + ), ) @@ -196,22 +178,24 @@ class CustomSignupForm(SignupForm): # check for two mail fields if InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_TWICE'): - self.fields["email2"] = forms.EmailField( - label=_("Email (again)"), + self.fields['email2'] = forms.EmailField( + label=_('Email (again)'), widget=forms.TextInput( attrs={ - "type": "email", - "placeholder": _("Email address confirmation"), + 'type': 'email', + 'placeholder': _('Email address confirmation'), } ), ) # check for two password fields if not InvenTreeSetting.get_setting('LOGIN_SIGNUP_PWD_TWICE'): - self.fields.pop("password2") + self.fields.pop('password2') # reorder fields - set_form_field_order(self, ["username", "email", "email2", "password1", "password2", ]) + set_form_field_order( + self, ['username', 'email', 'email2', 'password1', 'password2'] + ) def clean(self): """Make sure the supplied emails match if enabled in settings.""" @@ -219,21 +203,26 @@ class CustomSignupForm(SignupForm): # check for two mail fields if InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_TWICE'): - email = cleaned_data.get("email") - email2 = cleaned_data.get("email2") + email = cleaned_data.get('email') + email2 = cleaned_data.get('email2') if (email and email2) and email != email2: - self.add_error("email2", _("You must type the same email each time.")) + self.add_error('email2', _('You must type the same email each time.')) return cleaned_data def registration_enabled(): """Determine whether user registration is enabled.""" - if InvenTreeSetting.get_setting('LOGIN_ENABLE_REG') or InvenTree.sso.registration_enabled(): + if ( + InvenTreeSetting.get_setting('LOGIN_ENABLE_REG') + or InvenTree.sso.registration_enabled() + ): if settings.EMAIL_HOST: return True else: - logger.error("Registration cannot be enabled, because EMAIL_HOST is not configured.") + logger.error( + 'Registration cannot be enabled, because EMAIL_HOST is not configured.' + ) return False @@ -251,20 +240,26 @@ class RegistratonMixin: def clean_email(self, email): """Check if the mail is valid to the pattern in LOGIN_SIGNUP_MAIL_RESTRICTION (if enabled in settings).""" - mail_restriction = InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_RESTRICTION', None) + mail_restriction = InvenTreeSetting.get_setting( + 'LOGIN_SIGNUP_MAIL_RESTRICTION', None + ) if not mail_restriction: return super().clean_email(email) split_email = email.split('@') if len(split_email) != 2: logger.error('The user %s has an invalid email address', email) - raise forms.ValidationError(_('The provided primary email address is not valid.')) + raise forms.ValidationError( + _('The provided primary email address is not valid.') + ) mailoptions = mail_restriction.split(',') for option in mailoptions: if not option.startswith('@'): log_error('LOGIN_SIGNUP_MAIL_RESTRICTION is not configured correctly') - raise forms.ValidationError(_('The provided primary email address is not valid.')) + raise forms.ValidationError( + _('The provided primary email address is not valid.') + ) else: if split_email[1] == option[1:]: return super().clean_email(email) @@ -284,7 +279,10 @@ class RegistratonMixin: group = Group.objects.get(id=start_group) user.groups.add(group) except Group.DoesNotExist: - logger.exception('The setting `SIGNUP_GROUP` contains an non existent group', start_group) + logger.exception( + 'The setting `SIGNUP_GROUP` contains an non existent group', + start_group, + ) user.save() return user @@ -294,11 +292,13 @@ class CustomUrlMixin: def get_email_confirmation_url(self, request, emailconfirmation): """Custom email confirmation (activation) url.""" - url = reverse("account_confirm_email", args=[emailconfirmation.key]) + url = reverse('account_confirm_email', args=[emailconfirmation.key]) return Site.objects.get_current().domain + url -class CustomAccountAdapter(CustomUrlMixin, RegistratonMixin, OTPAdapter, DefaultAccountAdapter): +class CustomAccountAdapter( + CustomUrlMixin, RegistratonMixin, OTPAdapter, DefaultAccountAdapter +): """Override of adapter to use dynamic settings.""" def send_mail(self, template_prefix, email, context): @@ -325,7 +325,9 @@ class CustomAccountAdapter(CustomUrlMixin, RegistratonMixin, OTPAdapter, Default return url -class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocialAccountAdapter): +class CustomSocialAccountAdapter( + CustomUrlMixin, RegistratonMixin, DefaultSocialAccountAdapter +): """Override of adapter to use dynamic settings.""" def is_auto_signup_allowed(self, request, sociallogin): @@ -352,14 +354,14 @@ class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocial if request.GET: redirect_url += '?' + urlencode(request.GET) - raise ImmediateHttpResponse( - response=HttpResponseRedirect(redirect_url) - ) + raise ImmediateHttpResponse(response=HttpResponseRedirect(redirect_url)) # Otherwise defer to the original allauth adapter. return super().login(request, user) - def authentication_error(self, request, provider_id, error=None, exception=None, extra_context=None): + def authentication_error( + self, request, provider_id, error=None, exception=None, extra_context=None + ): """Callback method for authentication errors.""" # Log the error to the database @@ -370,6 +372,7 @@ class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocial # override dj-rest-auth class CustomRegisterSerializer(RegisterSerializer): """Override of serializer to use dynamic settings.""" + email = serializers.EmailField() def __init__(self, instance=None, data=..., **kwargs): diff --git a/InvenTree/InvenTree/helpers.py b/InvenTree/InvenTree/helpers.py index 2f81b68a75..3e8b575f41 100644 --- a/InvenTree/InvenTree/helpers.py +++ b/InvenTree/InvenTree/helpers.py @@ -36,7 +36,7 @@ def generateTestKey(test_name): Tests must be named such that they will have unique keys. """ key = test_name.strip().lower() - key = key.replace(" ", "") + key = key.replace(' ', '') # Remove any characters that cannot be used to represent a variable key = re.sub(r'[^a-zA-Z0-9]', '', key) @@ -56,7 +56,7 @@ def constructPathString(path, max_chars=250): # Replace middle elements to limit the pathstring if len(pathstring) > max_chars: n = int(max_chars / 2 - 2) - pathstring = pathstring[:n] + "..." + pathstring[-n:] + pathstring = pathstring[:n] + '...' + pathstring[-n:] return pathstring @@ -82,19 +82,18 @@ def TestIfImage(img): def getBlankImage(): """Return the qualified path for the 'blank image' placeholder.""" - return getStaticUrl("img/blank_image.png") + return getStaticUrl('img/blank_image.png') def getBlankThumbnail(): """Return the qualified path for the 'blank image' thumbnail placeholder.""" - return getStaticUrl("img/blank_image.thumbnail.png") + return getStaticUrl('img/blank_image.thumbnail.png') def getLogoImage(as_file=False, custom=True): """Return the InvenTree logo image, or a custom logo if available.""" """Return the path to the logo-file.""" if custom and settings.CUSTOM_LOGO: - static_storage = StaticFilesStorage() if static_storage.exists(settings.CUSTOM_LOGO): @@ -106,13 +105,13 @@ def getLogoImage(as_file=False, custom=True): if storage is not None: if as_file: - return f"file://{storage.path(settings.CUSTOM_LOGO)}" + return f'file://{storage.path(settings.CUSTOM_LOGO)}' return storage.url(settings.CUSTOM_LOGO) # If we have got to this point, return the default logo if as_file: path = settings.STATIC_ROOT.joinpath('img/inventree.png') - return f"file://{path}" + return f'file://{path}' return getStaticUrl('img/inventree.png') @@ -121,12 +120,11 @@ def getSplashScreen(custom=True): static_storage = StaticFilesStorage() if custom and settings.CUSTOM_SPLASH: - if static_storage.exists(settings.CUSTOM_SPLASH): return static_storage.url(settings.CUSTOM_SPLASH) # No custom splash screen - return static_storage.url("img/inventree_splash.jpg") + return static_storage.url('img/inventree_splash.jpg') def TestIfImageURL(url): @@ -135,10 +133,15 @@ def TestIfImageURL(url): Simply tests the extension against a set of allowed values """ return os.path.splitext(os.path.basename(url))[-1].lower() in [ - '.jpg', '.jpeg', '.j2k', - '.png', '.bmp', - '.tif', '.tiff', - '.webp', '.gif', + '.jpg', + '.jpeg', + '.j2k', + '.png', + '.bmp', + '.tif', + '.tiff', + '.webp', + '.gif', ] @@ -153,8 +156,8 @@ def str2bool(text, test=True): True if the text looks like the selected boolean value """ if test: - return str(text).lower() in ['1', 'y', 'yes', 't', 'true', 'ok', 'on', ] - return str(text).lower() in ['0', 'n', 'no', 'none', 'f', 'false', 'off', ] + return str(text).lower() in ['1', 'y', 'yes', 't', 'true', 'ok', 'on'] + return str(text).lower() in ['0', 'n', 'no', 'none', 'f', 'false', 'off'] def str2int(text, default=None): @@ -191,7 +194,15 @@ def isNull(text): Returns: True if the text looks like a null value """ - return str(text).strip().lower() in ['top', 'null', 'none', 'empty', 'false', '-1', ''] + return str(text).strip().lower() in [ + 'top', + 'null', + 'none', + 'empty', + 'false', + '-1', + '', + ] def normalize(d): @@ -223,7 +234,7 @@ def increment(value): # Provide a default value if provided with a null input return '1' - pattern = r"(.*?)(\d+)?$" + pattern = r'(.*?)(\d+)?$' result = re.search(pattern, value) @@ -282,7 +293,7 @@ def decimal2string(d): if '.' not in s: return s - return s.rstrip("0").rstrip(".") + return s.rstrip('0').rstrip('.') def decimal2money(d, currency=None): @@ -354,17 +365,12 @@ def MakeBarcode(cls_name, object_pk: int, object_data=None, **kwargs): def GetExportFormats(): """Return a list of allowable file formats for exporting data.""" - return [ - 'csv', - 'tsv', - 'xls', - 'xlsx', - 'json', - 'yaml', - ] + return ['csv', 'tsv', 'xls', 'xlsx', 'json', 'yaml'] -def DownloadFile(data, filename, content_type='application/text', inline=False) -> StreamingHttpResponse: +def DownloadFile( + data, filename, content_type='application/text', inline=False +) -> StreamingHttpResponse: """Create a dynamic file for the user to download. Args: @@ -389,7 +395,7 @@ def DownloadFile(data, filename, content_type='application/text', inline=False) length = len(bytes(data, response.charset)) response['Content-Length'] = length - disposition = "inline" if inline else "attachment" + disposition = 'inline' if inline else 'attachment' response['Content-Disposition'] = f'{disposition}; filename={filename}' @@ -449,7 +455,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= try: expected_quantity = int(expected_quantity) except ValueError: - raise ValidationError([_("Invalid quantity provided")]) + raise ValidationError([_('Invalid quantity provided')]) if input_string: input_string = str(input_string).strip() @@ -457,7 +463,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= input_string = '' if len(input_string) == 0: - raise ValidationError([_("Empty serial number string")]) + raise ValidationError([_('Empty serial number string')]) next_value = increment_serial_number(starting_value) @@ -467,7 +473,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= next_value = increment_serial_number(next_value) # Split input string by whitespace or comma (,) characters - groups = re.split(r"[\s,]+", input_string) + groups = re.split(r'[\s,]+', input_string) serials = [] errors = [] @@ -487,7 +493,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= return if serial in serials: - add_error(_("Duplicate serial") + f": {serial}") + add_error(_('Duplicate serial') + f': {serial}') else: serials.append(serial) @@ -502,7 +508,6 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= return serials for group in groups: - # Calculate the "remaining" quantity of serial numbers remaining = expected_quantity - len(serials) @@ -520,7 +525,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= if a == b: # Invalid group - add_error(_(f"Invalid group range: {group}")) + add_error(_(f'Invalid group range: {group}')) continue group_items = [] @@ -549,13 +554,21 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= break if len(group_items) > remaining: - add_error(_(f"Group range {group} exceeds allowed quantity ({expected_quantity})")) - elif len(group_items) > 0 and group_items[0] == a and group_items[-1] == b: + add_error( + _( + f'Group range {group} exceeds allowed quantity ({expected_quantity})' + ) + ) + elif ( + len(group_items) > 0 + and group_items[0] == a + and group_items[-1] == b + ): # In this case, the range extraction looks like it has worked for item in group_items: add_serial(item) else: - add_error(_(f"Invalid group range: {group}")) + add_error(_(f'Invalid group range: {group}')) else: # In the case of a different number of hyphens, simply add the entire group @@ -573,20 +586,24 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= sequence_count = max(0, expected_quantity - len(serials)) if len(items) > 2 or len(items) == 0: - add_error(_(f"Invalid group sequence: {group}")) + add_error(_(f'Invalid group sequence: {group}')) continue elif len(items) == 2: try: if items[1]: sequence_count = int(items[1]) + 1 except ValueError: - add_error(_(f"Invalid group sequence: {group}")) + add_error(_(f'Invalid group sequence: {group}')) continue value = items[0] # Keep incrementing up to the specified quantity - while value is not None and value not in sequence_items and counter < sequence_count: + while ( + value is not None + and value not in sequence_items + and counter < sequence_count + ): sequence_items.append(value) value = increment_serial_number(value) counter += 1 @@ -595,7 +612,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= for item in sequence_items: add_serial(item) else: - add_error(_(f"Invalid group sequence: {group}")) + add_error(_(f'Invalid group sequence: {group}')) else: # At this point, we assume that the 'group' is just a single serial value @@ -605,10 +622,14 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value= raise ValidationError(errors) if len(serials) == 0: - raise ValidationError([_("No serial numbers found")]) + raise ValidationError([_('No serial numbers found')]) if len(errors) == 0 and len(serials) != expected_quantity: - raise ValidationError([_(f"Number of unique serial numbers ({len(serials)}) must match quantity ({expected_quantity})")]) + raise ValidationError([ + _( + f'Number of unique serial numbers ({len(serials)}) must match quantity ({expected_quantity})' + ) + ]) return serials @@ -645,9 +666,7 @@ def validateFilterString(value, model=None): pair = group.split('=') if len(pair) != 2: - raise ValidationError( - f"Invalid group: {group}" - ) + raise ValidationError(f'Invalid group: {group}') k, v = pair @@ -655,9 +674,7 @@ def validateFilterString(value, model=None): v = v.strip() if not k or not v: - raise ValidationError( - f"Invalid group: {group}" - ) + raise ValidationError(f'Invalid group: {group}') results[k] = v @@ -666,9 +683,7 @@ def validateFilterString(value, model=None): try: model.objects.filter(**results) except FieldError as e: - raise ValidationError( - str(e), - ) + raise ValidationError(str(e)) return results @@ -706,7 +721,11 @@ def clean_decimal(number): # Number cannot be converted to Decimal (eg. a string containing letters) return Decimal(0) - return clean_number.quantize(Decimal(1)) if clean_number == clean_number.to_integral() else clean_number.normalize() + return ( + clean_number.quantize(Decimal(1)) + if clean_number == clean_number.to_integral() + else clean_number.normalize() + ) def strip_html_tags(value: str, raise_error=True, field_name=None): @@ -714,55 +733,45 @@ def strip_html_tags(value: str, raise_error=True, field_name=None): If raise_error is True, a ValidationError will be thrown if HTML tags are detected """ - cleaned = clean( - value, - strip=True, - tags=[], - attributes=[], - ) + cleaned = clean(value, strip=True, tags=[], attributes=[]) # Add escaped characters back in - replacements = { - '>': '>', - '<': '<', - '&': '&', - } + replacements = {'>': '>', '<': '<', '&': '&'} for o, r in replacements.items(): cleaned = cleaned.replace(o, r) # If the length changed, it means that HTML tags were removed! if len(cleaned) != len(value) and raise_error: - field = field_name or 'non_field_errors' - raise ValidationError({ - field: [_("Remove HTML tags from this value")] - }) + raise ValidationError({field: [_('Remove HTML tags from this value')]}) return cleaned -def remove_non_printable_characters(value: str, remove_newline=True, remove_ascii=True, remove_unicode=True): +def remove_non_printable_characters( + value: str, remove_newline=True, remove_ascii=True, remove_unicode=True +): """Remove non-printable / control characters from the provided string""" cleaned = value if remove_ascii: # Remove ASCII control characters # Note that we do not sub out 0x0A (\n) here, it is done separately below - cleaned = regex.sub(u'[\x00-\x09]+', '', cleaned) - cleaned = regex.sub(u'[\x0b-\x1F\x7F]+', '', cleaned) + cleaned = regex.sub('[\x00-\x09]+', '', cleaned) + cleaned = regex.sub('[\x0b-\x1f\x7f]+', '', cleaned) if remove_newline: - cleaned = regex.sub(u'[\x0a]+', '', cleaned) + cleaned = regex.sub('[\x0a]+', '', cleaned) if remove_unicode: # Remove Unicode control characters if remove_newline: - cleaned = regex.sub(u'[^\P{C}]+', '', cleaned) + cleaned = regex.sub('[^\P{C}]+', '', cleaned) else: # Use 'negative-lookahead' to exclude newline character - cleaned = regex.sub(u'(?![\x0A])[^\P{C}]+', '', cleaned) + cleaned = regex.sub('(?![\x0a])[^\P{C}]+', '', cleaned) return cleaned @@ -783,7 +792,9 @@ def hash_barcode(barcode_data): return str(hash.hexdigest()) -def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = 'object_id'): +def get_objectreference( + obj, type_ref: str = 'content_type', object_ref: str = 'object_id' +): """Lookup method for the GenericForeignKey fields. Attributes: @@ -821,11 +832,7 @@ def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = ' ret = {} if url_fnc: ret['link'] = url_fnc() - return { - 'name': str(item), - 'model': str(model_cls._meta.verbose_name), - **ret - } + return {'name': str(item), 'model': str(model_cls._meta.verbose_name), **ret} def inheritors(cls): diff --git a/InvenTree/InvenTree/helpers_model.py b/InvenTree/InvenTree/helpers_model.py index 5dc37cf723..ab76bf00db 100644 --- a/InvenTree/InvenTree/helpers_model.py +++ b/InvenTree/InvenTree/helpers_model.py @@ -19,8 +19,11 @@ import common.models import InvenTree import InvenTree.helpers_model import InvenTree.version -from common.notifications import (InvenTreeNotificationBodies, - NotificationBody, trigger_notification) +from common.notifications import ( + InvenTreeNotificationBodies, + NotificationBody, + trigger_notification, +) from InvenTree.format import format_money logger = logging.getLogger('inventree') @@ -49,7 +52,9 @@ def construct_absolute_url(*arg, **kwargs): if not site_url: # Otherwise, try to use the InvenTree setting try: - site_url = common.models.InvenTreeSetting.get_setting('INVENTREE_BASE_URL', create=False, cache=False) + site_url = common.models.InvenTreeSetting.get_setting( + 'INVENTREE_BASE_URL', create=False, cache=False + ) except (ProgrammingError, OperationalError): pass @@ -100,12 +105,22 @@ def download_image_from_url(remote_url, timeout=2.5): validator(remote_url) # Calculate maximum allowable image size (in bytes) - max_size = int(common.models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE')) * 1024 * 1024 + max_size = ( + int( + common.models.InvenTreeSetting.get_setting( + 'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE' + ) + ) + * 1024 + * 1024 + ) # Add user specified user-agent to request (if specified) - user_agent = common.models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT') + user_agent = common.models.InvenTreeSetting.get_setting( + 'INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT' + ) if user_agent: - headers = {"User-Agent": user_agent} + headers = {'User-Agent': user_agent} else: headers = None @@ -120,24 +135,28 @@ def download_image_from_url(remote_url, timeout=2.5): # Throw an error if anything goes wrong response.raise_for_status() except requests.exceptions.ConnectionError as exc: - raise Exception(_("Connection error") + f": {str(exc)}") + raise Exception(_('Connection error') + f': {str(exc)}') except requests.exceptions.Timeout as exc: raise exc except requests.exceptions.HTTPError: - raise requests.exceptions.HTTPError(_("Server responded with invalid status code") + f": {response.status_code}") + raise requests.exceptions.HTTPError( + _('Server responded with invalid status code') + f': {response.status_code}' + ) except Exception as exc: - raise Exception(_("Exception occurred") + f": {str(exc)}") + raise Exception(_('Exception occurred') + f': {str(exc)}') if response.status_code != 200: - raise Exception(_("Server responded with invalid status code") + f": {response.status_code}") + raise Exception( + _('Server responded with invalid status code') + f': {response.status_code}' + ) try: content_length = int(response.headers.get('Content-Length', 0)) except ValueError: - raise ValueError(_("Server responded with invalid Content-Length value")) + raise ValueError(_('Server responded with invalid Content-Length value')) if content_length > max_size: - raise ValueError(_("Image size is too large")) + raise ValueError(_('Image size is too large')) # Download the file, ensuring we do not exceed the reported size file = io.BytesIO() @@ -149,12 +168,12 @@ def download_image_from_url(remote_url, timeout=2.5): dl_size += len(chunk) if dl_size > max_size: - raise ValueError(_("Image download exceeded maximum size")) + raise ValueError(_('Image download exceeded maximum size')) file.write(chunk) if dl_size == 0: - raise ValueError(_("Remote server returned empty response")) + raise ValueError(_('Remote server returned empty response')) # Now, attempt to convert the downloaded data to a valid image file # img.verify() will throw an exception if the image is not valid @@ -162,12 +181,18 @@ def download_image_from_url(remote_url, timeout=2.5): img = Image.open(file).convert() img.verify() except Exception: - raise TypeError(_("Supplied URL is not a valid image file")) + raise TypeError(_('Supplied URL is not a valid image file')) return img -def render_currency(money, decimal_places=None, currency=None, min_decimal_places=None, max_decimal_places=None): +def render_currency( + money, + decimal_places=None, + currency=None, + min_decimal_places=None, + max_decimal_places=None, +): """Render a currency / Money object to a formatted string (e.g. for reports) Arguments: @@ -192,13 +217,19 @@ def render_currency(money, decimal_places=None, currency=None, min_decimal_place pass if decimal_places is None: - decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES', 6) + decimal_places = common.models.InvenTreeSetting.get_setting( + 'PRICING_DECIMAL_PLACES', 6 + ) if min_decimal_places is None: - min_decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES_MIN', 0) + min_decimal_places = common.models.InvenTreeSetting.get_setting( + 'PRICING_DECIMAL_PLACES_MIN', 0 + ) if max_decimal_places is None: - max_decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES', 6) + max_decimal_places = common.models.InvenTreeSetting.get_setting( + 'PRICING_DECIMAL_PLACES', 6 + ) value = Decimal(str(money.amount)).normalize() value = str(value) @@ -229,7 +260,9 @@ def getModelsWithMixin(mixin_class) -> list: from django.contrib.contenttypes.models import ContentType try: - db_models = [x.model_class() for x in ContentType.objects.all() if x is not None] + db_models = [ + x.model_class() for x in ContentType.objects.all() if x is not None + ] except (OperationalError, ProgrammingError): # Database is likely not yet ready db_models = [] @@ -237,7 +270,12 @@ def getModelsWithMixin(mixin_class) -> list: return [x for x in db_models if x is not None and issubclass(x, mixin_class)] -def notify_responsible(instance, sender, content: NotificationBody = InvenTreeNotificationBodies.NewOrder, exclude=None): +def notify_responsible( + instance, + sender, + content: NotificationBody = InvenTreeNotificationBodies.NewOrder, + exclude=None, +): """Notify all responsible parties of a change in an instance. Parses the supplied content with the provided instance and sender and sends a notification to all responsible users, @@ -249,10 +287,18 @@ def notify_responsible(instance, sender, content: NotificationBody = InvenTreeNo content (NotificationBody, optional): _description_. Defaults to InvenTreeNotificationBodies.NewOrder. exclude (User, optional): User instance that should be excluded. Defaults to None. """ - notify_users([instance.responsible], instance, sender, content=content, exclude=exclude) + notify_users( + [instance.responsible], instance, sender, content=content, exclude=exclude + ) -def notify_users(users, instance, sender, content: NotificationBody = InvenTreeNotificationBodies.NewOrder, exclude=None): +def notify_users( + users, + instance, + sender, + content: NotificationBody = InvenTreeNotificationBodies.NewOrder, + exclude=None, +): """Notify all passed users or groups. Parses the supplied content with the provided instance and sender and sends a notification to all users, @@ -278,10 +324,10 @@ def notify_users(users, instance, sender, content: NotificationBody = InvenTreeN 'instance': instance, 'name': content.name.format(**content_context), 'message': content.message.format(**content_context), - 'link': InvenTree.helpers_model.construct_absolute_url(instance.get_absolute_url()), - 'template': { - 'subject': content.name.format(**content_context), - } + 'link': InvenTree.helpers_model.construct_absolute_url( + instance.get_absolute_url() + ), + 'template': {'subject': content.name.format(**content_context)}, } if content.template: diff --git a/InvenTree/InvenTree/magic_login.py b/InvenTree/InvenTree/magic_login.py index b48c3f3547..725b913bbf 100644 --- a/InvenTree/InvenTree/magic_login.py +++ b/InvenTree/InvenTree/magic_login.py @@ -1,4 +1,5 @@ """Functions for magic login.""" + from django.conf import settings from django.contrib.auth.models import User from django.contrib.sites.models import Site @@ -17,15 +18,13 @@ def send_simple_login_email(user, link): """Send an email with the login link to this user.""" site = Site.objects.get_current() - context = { - "username": user.username, - "site_name": site.name, - "link": link, - } - email_plaintext_message = render_to_string("InvenTree/user_simple_login.txt", context) + context = {'username': user.username, 'site_name': site.name, 'link': link} + email_plaintext_message = render_to_string( + 'InvenTree/user_simple_login.txt', context + ) send_mail( - _(f"[{site.name}] Log in to the app"), + _(f'[{site.name}] Log in to the app'), email_plaintext_message, settings.DEFAULT_FROM_EMAIL, [user.email], @@ -35,7 +34,7 @@ def send_simple_login_email(user, link): class GetSimpleLoginSerializer(serializers.Serializer): """Serializer for the simple login view.""" - email = serializers.CharField(label=_("Email")) + email = serializers.CharField(label=_('Email')) class GetSimpleLoginView(APIView): @@ -48,14 +47,14 @@ class GetSimpleLoginView(APIView): """Get the token for the current user or fail.""" serializer = self.serializer_class(data=request.data) serializer.is_valid(raise_exception=True) - self.email_submitted(email=serializer.data["email"]) - return Response({"status": "ok"}) + self.email_submitted(email=serializer.data['email']) + return Response({'status': 'ok'}) def email_submitted(self, email): """Notify user about link.""" user = self.get_user(email) if user is None: - print("user not found:", email) + print('user not found:', email) return link = self.create_link(user) send_simple_login_email(user, link) @@ -69,7 +68,7 @@ class GetSimpleLoginView(APIView): def create_link(self, user): """Create a login link for this user.""" - link = reverse("sesame-login") + link = reverse('sesame-login') link = self.request.build_absolute_uri(link) link += sesame.utils.get_query_string(user) return link diff --git a/InvenTree/InvenTree/management/commands/clean_settings.py b/InvenTree/InvenTree/management/commands/clean_settings.py index 4ef8269f24..45f99e98ab 100644 --- a/InvenTree/InvenTree/management/commands/clean_settings.py +++ b/InvenTree/InvenTree/management/commands/clean_settings.py @@ -12,7 +12,7 @@ class Command(BaseCommand): def handle(self, *args, **kwargs): """Cleanup old (undefined) settings in the database.""" - logger.info("Collecting settings") + logger.info('Collecting settings') from common.models import InvenTreeSetting, InvenTreeUserSetting # general settings @@ -35,4 +35,4 @@ class Command(BaseCommand): setting.delete() logger.info("deleted user setting '%s'", setting.key) - logger.info("checked all settings") + logger.info('checked all settings') diff --git a/InvenTree/InvenTree/management/commands/prerender.py b/InvenTree/InvenTree/management/commands/prerender.py index 5bba15109b..466b6666c7 100644 --- a/InvenTree/InvenTree/management/commands/prerender.py +++ b/InvenTree/InvenTree/management/commands/prerender.py @@ -13,7 +13,6 @@ from django.utils.translation import override as lang_over def render_file(file_name, source, target, locales, ctx): """Renders a file into all provided locales.""" for locale in locales: - # Enforce lower-case for locale names locale = locale.lower() locale = locale.replace('_', '-') @@ -50,15 +49,19 @@ class Command(BaseCommand): # render! request = HttpRequest() ctx = {} - processors = tuple(import_string(path) for path in settings.STATFILES_I18_PROCESSORS) + processors = tuple( + import_string(path) for path in settings.STATFILES_I18_PROCESSORS + ) for processor in processors: ctx.update(processor(request)) - for file in os.listdir(SOURCE_DIR, ): + for file in os.listdir(SOURCE_DIR): path = os.path.join(SOURCE_DIR, file) if os.path.exists(path) and os.path.isfile(path): - print(f"render {file}") + print(f'render {file}') render_file(file, SOURCE_DIR, TARGET_DIR, locales, ctx) else: - raise NotImplementedError('Using multi-level directories is not implemented at this point') # TODO multilevel dir if needed - print(f"rendered all files in {SOURCE_DIR}") + raise NotImplementedError( + 'Using multi-level directories is not implemented at this point' + ) # TODO multilevel dir if needed + print(f'rendered all files in {SOURCE_DIR}') diff --git a/InvenTree/InvenTree/management/commands/rebuild_models.py b/InvenTree/InvenTree/management/commands/rebuild_models.py index 87aa378f0b..02af71f3a5 100644 --- a/InvenTree/InvenTree/management/commands/rebuild_models.py +++ b/InvenTree/InvenTree/management/commands/rebuild_models.py @@ -13,45 +13,50 @@ class Command(BaseCommand): """Rebuild all database models which leverage the MPTT structure.""" # Part model try: - print("Rebuilding Part objects") + print('Rebuilding Part objects') from part.models import Part + Part.objects.rebuild() except Exception: - print("Error rebuilding Part objects") + print('Error rebuilding Part objects') # Part category try: - print("Rebuilding PartCategory objects") + print('Rebuilding PartCategory objects') from part.models import PartCategory + PartCategory.objects.rebuild() except Exception: - print("Error rebuilding PartCategory objects") + print('Error rebuilding PartCategory objects') # StockItem model try: - print("Rebuilding StockItem objects") + print('Rebuilding StockItem objects') from stock.models import StockItem + StockItem.objects.rebuild() except Exception: - print("Error rebuilding StockItem objects") + print('Error rebuilding StockItem objects') # StockLocation model try: - print("Rebuilding StockLocation objects") + print('Rebuilding StockLocation objects') from stock.models import StockLocation + StockLocation.objects.rebuild() except Exception: - print("Error rebuilding StockLocation objects") + print('Error rebuilding StockLocation objects') # Build model try: - print("Rebuilding Build objects") + print('Rebuilding Build objects') from build.models import Build + Build.objects.rebuild() except Exception: - print("Error rebuilding Build objects") + print('Error rebuilding Build objects') diff --git a/InvenTree/InvenTree/management/commands/rebuild_thumbnails.py b/InvenTree/InvenTree/management/commands/rebuild_thumbnails.py index 102a9e1bf1..8b54c98cc7 100644 --- a/InvenTree/InvenTree/management/commands/rebuild_thumbnails.py +++ b/InvenTree/InvenTree/management/commands/rebuild_thumbnails.py @@ -37,20 +37,20 @@ class Command(BaseCommand): def handle(self, *args, **kwargs): """Rebuild all thumbnail images.""" - logger.info("Rebuilding Part thumbnails") + logger.info('Rebuilding Part thumbnails') for part in Part.objects.exclude(image=None): try: self.rebuild_thumbnail(part) except (OperationalError, ProgrammingError): - logger.exception("ERROR: Database read error.") + logger.exception('ERROR: Database read error.') break - logger.info("Rebuilding Company thumbnails") + logger.info('Rebuilding Company thumbnails') for company in Company.objects.exclude(image=None): try: self.rebuild_thumbnail(company) except (OperationalError, ProgrammingError): - logger.exception("ERROR: abase read error.") + logger.exception('ERROR: abase read error.') break diff --git a/InvenTree/InvenTree/management/commands/remove_mfa.py b/InvenTree/InvenTree/management/commands/remove_mfa.py index cd4ef8c88b..4cb4f21659 100644 --- a/InvenTree/InvenTree/management/commands/remove_mfa.py +++ b/InvenTree/InvenTree/management/commands/remove_mfa.py @@ -18,7 +18,12 @@ class Command(BaseCommand): if not mail: raise KeyError('A mail is required') user = get_user_model() - mfa_user = [*set(user.objects.filter(email=mail) | user.objects.filter(emailaddress__email=mail))] + mfa_user = [ + *set( + user.objects.filter(email=mail) + | user.objects.filter(emailaddress__email=mail) + ) + ] if len(mfa_user) == 0: print('No user with this mail associated') diff --git a/InvenTree/InvenTree/management/commands/wait_for_db.py b/InvenTree/InvenTree/management/commands/wait_for_db.py index 3bcd2f5ffd..6bfdc98b58 100644 --- a/InvenTree/InvenTree/management/commands/wait_for_db.py +++ b/InvenTree/InvenTree/management/commands/wait_for_db.py @@ -12,12 +12,11 @@ class Command(BaseCommand): def handle(self, *args, **kwargs): """Wait till the database is ready.""" - self.stdout.write("Waiting for database...") + self.stdout.write('Waiting for database...') connected = False while not connected: - time.sleep(2) try: @@ -26,12 +25,12 @@ class Command(BaseCommand): connected = True except OperationalError as e: - self.stdout.write(f"Could not connect to database: {e}") + self.stdout.write(f'Could not connect to database: {e}') except ImproperlyConfigured as e: - self.stdout.write(f"Improperly configured: {e}") + self.stdout.write(f'Improperly configured: {e}') else: if not connection.is_usable(): - self.stdout.write("Database configuration is not usable") + self.stdout.write('Database configuration is not usable') if connected: - self.stdout.write("Database connection successful!") + self.stdout.write('Database connection successful!') diff --git a/InvenTree/InvenTree/metadata.py b/InvenTree/InvenTree/metadata.py index a526cf152e..11a61b3a69 100644 --- a/InvenTree/InvenTree/metadata.py +++ b/InvenTree/InvenTree/metadata.py @@ -45,8 +45,9 @@ class InvenTreeMetadata(SimpleMetadata): context = {} if str2bool(request.query_params.get('context', False)): - - if hasattr(self, 'serializer') and hasattr(self.serializer, 'get_context_data'): + if hasattr(self, 'serializer') and hasattr( + self.serializer, 'get_context_data' + ): context = self.serializer.get_context_data() metadata['context'] = context @@ -68,7 +69,7 @@ class InvenTreeMetadata(SimpleMetadata): metadata['model'] = tbl_label - table = f"{app_label}_{tbl_label}" + table = f'{app_label}_{tbl_label}' actions = metadata.get('actions', None) @@ -86,12 +87,11 @@ class InvenTreeMetadata(SimpleMetadata): } # let the view define a custom rolemap - if hasattr(view, "rolemap"): + if hasattr(view, 'rolemap'): rolemap.update(view.rolemap) # Remove any HTTP methods that the user does not have permission for for method, permission in rolemap.items(): - result = check(user, table, permission) if method in actions and not result: @@ -123,10 +123,7 @@ class InvenTreeMetadata(SimpleMetadata): model_class = None # Attributes to copy extra attributes from the model to the field (if they don't exist) - extra_attributes = [ - 'help_text', - 'max_length', - ] + extra_attributes = ['help_text', 'max_length'] try: model_class = serializer.Meta.model @@ -142,11 +139,8 @@ class InvenTreeMetadata(SimpleMetadata): # Iterate through simple fields for name, field in model_fields.fields.items(): - if name in serializer_info.keys(): - if field.has_default(): - default = field.default if callable(default): @@ -162,13 +156,11 @@ class InvenTreeMetadata(SimpleMetadata): for attr in extra_attributes: if attr not in serializer_info[name]: - if hasattr(field, attr): serializer_info[name][attr] = getattr(field, attr) # Iterate through relations for name, relation in model_fields.relations.items(): - if name not in serializer_info.keys(): # Skip relation not defined in serializer continue @@ -179,11 +171,17 @@ class InvenTreeMetadata(SimpleMetadata): # Extract and provide the "limit_choices_to" filters # This is used to automatically filter AJAX requests - serializer_info[name]['filters'] = relation.model_field.get_limit_choices_to() + serializer_info[name]['filters'] = ( + relation.model_field.get_limit_choices_to() + ) for attr in extra_attributes: - if attr not in serializer_info[name] and hasattr(relation.model_field, attr): - serializer_info[name][attr] = getattr(relation.model_field, attr) + if attr not in serializer_info[name] and hasattr( + relation.model_field, attr + ): + serializer_info[name][attr] = getattr( + relation.model_field, attr + ) if name in model_default_values: serializer_info[name]['default'] = model_default_values[name] @@ -220,11 +218,9 @@ class InvenTreeMetadata(SimpleMetadata): """If there is an instance associated with this API View, introspect that instance to find any specific API info.""" if hasattr(instance, 'api_instance_filters'): - instance_filters = instance.api_instance_filters() for field_name, field_filters in instance_filters.items(): - if field_name not in serializer_info.keys(): # The field might be missing, but is added later on # This function seems to get called multiple times? @@ -256,17 +252,20 @@ class InvenTreeMetadata(SimpleMetadata): # Force non-nullable fields to read as "required" # (even if there is a default value!) - if not field.allow_null and not (hasattr(field, 'allow_blank') and field.allow_blank): + if not field.allow_null and not ( + hasattr(field, 'allow_blank') and field.allow_blank + ): field_info['required'] = True # Introspect writable related fields if field_info['type'] == 'field' and not field_info['read_only']: - # If the field is a PrimaryKeyRelatedField, we can extract the model from the queryset if isinstance(field, serializers.PrimaryKeyRelatedField): model = field.queryset.model else: - logger.debug("Could not extract model for:", field_info.get('label'), '->', field) + logger.debug( + 'Could not extract model for:', field_info.get('label'), '->', field + ) model = None if model: @@ -287,4 +286,4 @@ class InvenTreeMetadata(SimpleMetadata): return field_info -InvenTreeMetadata.label_lookup[DependentField] = "dependent field" +InvenTreeMetadata.label_lookup[DependentField] = 'dependent field' diff --git a/InvenTree/InvenTree/middleware.py b/InvenTree/InvenTree/middleware.py index 7fac82e13f..52428cf3ee 100644 --- a/InvenTree/InvenTree/middleware.py +++ b/InvenTree/InvenTree/middleware.py @@ -9,14 +9,13 @@ from django.http import HttpResponse from django.shortcuts import redirect from django.urls import Resolver404, include, re_path, resolve, reverse_lazy -from allauth_2fa.middleware import (AllauthTwoFactorMiddleware, - BaseRequire2FAMiddleware) +from allauth_2fa.middleware import AllauthTwoFactorMiddleware, BaseRequire2FAMiddleware from error_report.middleware import ExceptionProcessor from InvenTree.urls import frontendpatterns from users.models import ApiToken -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') class AuthRequiredMiddleware(object): @@ -64,11 +63,20 @@ class AuthRequiredMiddleware(object): elif request.path_info.startswith('/accounts/'): authorized = True - elif request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/') or request.path_info.startswith('/assets/') or request.path_info == f'/{settings.FRONTEND_URL_BASE}': + elif ( + request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/') + or request.path_info.startswith('/assets/') + or request.path_info == f'/{settings.FRONTEND_URL_BASE}' + ): authorized = True - elif 'Authorization' in request.headers.keys() or 'authorization' in request.headers.keys(): - auth = request.headers.get('Authorization', request.headers.get('authorization')).strip() + elif ( + 'Authorization' in request.headers.keys() + or 'authorization' in request.headers.keys() + ): + auth = request.headers.get( + 'Authorization', request.headers.get('authorization') + ).strip() if auth.lower().startswith('token') and len(auth.split()) == 2: token_key = auth.split()[1] @@ -78,13 +86,12 @@ class AuthRequiredMiddleware(object): token = ApiToken.objects.get(key=token_key) if token.active and token.user: - # Provide the user information to the request request.user = token.user authorized = True except ApiToken.DoesNotExist: - logger.warning("Access denied for unknown token %s", token_key) + logger.warning('Access denied for unknown token %s', token_key) # No authorization was found for the request if not authorized: @@ -99,17 +106,16 @@ class AuthRequiredMiddleware(object): ] # Do not redirect requests to any of these paths - paths_ignore = [ - '/api/', - '/js/', - '/media/', - '/static/', - ] + paths_ignore = ['/api/', '/js/', '/media/', '/static/'] - if path not in urls and not any(path.startswith(p) for p in paths_ignore): + if path not in urls and not any( + path.startswith(p) for p in paths_ignore + ): # Save the 'next' parameter to pass through to the login view - return redirect(f'{reverse_lazy("account_login")}?next={request.path}') + return redirect( + f'{reverse_lazy("account_login")}?next={request.path}' + ) # Return a 401 (Unauthorized) response code for this request return HttpResponse('Unauthorized', status=401) @@ -123,6 +129,7 @@ url_matcher = re_path('', include(frontendpatterns)) class Check2FAMiddleware(BaseRequire2FAMiddleware): """Check if user is required to have MFA enabled.""" + def require_2fa(self, request): """Use setting to check if MFA should be enforced for frontend page.""" from common.models import InvenTreeSetting @@ -137,6 +144,7 @@ class Check2FAMiddleware(BaseRequire2FAMiddleware): class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware): """This function ensures only frontend code triggers the MFA auth cycle.""" + def process_request(self, request): """Check if requested url is forntend and enforce MFA check.""" try: @@ -148,6 +156,7 @@ class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware): class InvenTreeRemoteUserMiddleware(PersistentRemoteUserMiddleware): """Middleware to check if HTTP-header based auth is enabled and to set it up.""" + header = settings.REMOTE_LOGIN_HEADER def process_request(self, request): diff --git a/InvenTree/InvenTree/mixins.py b/InvenTree/InvenTree/mixins.py index 41eb8a5ec6..6981a6bb6f 100644 --- a/InvenTree/InvenTree/mixins.py +++ b/InvenTree/InvenTree/mixins.py @@ -62,7 +62,7 @@ class DiffMixin: return field_name in self.get_field_deltas() -class CleanMixin(): +class CleanMixin: """Model mixin class which cleans inputs using the Mozilla bleach tools.""" # Define a list of field names which will *not* be cleaned @@ -74,13 +74,17 @@ class CleanMixin(): serializer.is_valid(raise_exception=True) self.perform_create(serializer) headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) def update(self, request, *args, **kwargs): """Override to clean data before processing it.""" partial = kwargs.pop('partial', False) instance = self.get_object() - serializer = self.get_serializer(instance, data=self.clean_data(request.data), partial=partial) + serializer = self.get_serializer( + instance, data=self.clean_data(request.data), partial=partial + ) serializer.is_valid(raise_exception=True) self.perform_update(serializer) @@ -113,9 +117,7 @@ class CleanMixin(): field = model._meta.get_field(field) # The following field types allow newline characters - allow_newline = [ - InvenTreeNotesField, - ] + allow_newline = [InvenTreeNotesField] for field_type in allow_newline: if issubclass(type(field), field_type): @@ -127,7 +129,9 @@ class CleanMixin(): except FieldDoesNotExist: pass - cleaned = remove_non_printable_characters(cleaned, remove_newline=remove_newline) + cleaned = remove_non_printable_characters( + cleaned, remove_newline=remove_newline + ) return cleaned @@ -148,7 +152,6 @@ class CleanMixin(): clean_data = {} for k, v in data.items(): - if k in self.SAFE_FIELDS: ret = v elif isinstance(v, str): @@ -177,16 +180,19 @@ class CreateAPI(CleanMixin, generics.CreateAPIView): class RetrieveAPI(generics.RetrieveAPIView): """View for retrieve API.""" + pass class RetrieveUpdateAPI(CleanMixin, generics.RetrieveUpdateAPIView): """View for retrieve and update API.""" + pass class CustomDestroyModelMixin: """This mixin was created pass the kwargs from the API to the models.""" + def destroy(self, request, *args, **kwargs): """Custom destroy method to pass kwargs.""" instance = self.get_object() @@ -198,11 +204,14 @@ class CustomDestroyModelMixin: instance.delete(**kwargs) -class CustomRetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - CustomDestroyModelMixin, - generics.GenericAPIView): +class CustomRetrieveUpdateDestroyAPIView( + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + CustomDestroyModelMixin, + generics.GenericAPIView, +): """This APIView was created pass the kwargs from the API to the models.""" + def get(self, request, *args, **kwargs): """Custom get method to pass kwargs.""" return self.retrieve(request, *args, **kwargs) diff --git a/InvenTree/InvenTree/models.py b/InvenTree/InvenTree/models.py index a4b147ff08..65ca835ce6 100644 --- a/InvenTree/InvenTree/models.py +++ b/InvenTree/InvenTree/models.py @@ -58,6 +58,7 @@ class MetadataMixin(models.Model): class Meta: """Meta for MetadataMixin.""" + abstract = True def save(self, *args, **kwargs): @@ -78,10 +79,13 @@ class MetadataMixin(models.Model): self.metadata = {} if type(self.metadata) is not dict: - raise ValidationError({'metadata': _('Metadata must be a python dict object')}) + raise ValidationError({ + 'metadata': _('Metadata must be a python dict object') + }) metadata = models.JSONField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Plugin Metadata'), help_text=_('JSON metadata field, for use by external plugins'), ) @@ -100,7 +104,9 @@ class MetadataMixin(models.Model): return self.metadata.get(key, backup_value) - def set_metadata(self, key: str, data, commit: bool = True, overwrite: bool = False): + def set_metadata( + self, key: str, data, commit: bool = True, overwrite: bool = False + ): """Save the provided metadata under the provided key. Args: @@ -137,7 +143,6 @@ class DataImportMixin(object): fields = cls.IMPORT_FIELDS for name, field in fields.items(): - # Attempt to extract base field information from the model base_field = None @@ -207,7 +212,10 @@ class ReferenceIndexingMixin(models.Model): # import at function level to prevent cyclic imports from common.models import InvenTreeSetting - return InvenTreeSetting.get_setting(cls.REFERENCE_PATTERN_SETTING, create=False).strip() + + return InvenTreeSetting.get_setting( + cls.REFERENCE_PATTERN_SETTING, create=False + ).strip() @classmethod def get_reference_context(cls): @@ -216,10 +224,7 @@ class ReferenceIndexingMixin(models.Model): - Returns a python dict object which contains the context data for formatting the reference string. - The default implementation provides some default context information """ - return { - 'ref': cls.get_next_reference(), - 'date': datetime.now(), - } + return {'ref': cls.get_next_reference(), 'date': datetime.now()} @classmethod def get_most_recent_item(cls): @@ -246,7 +251,9 @@ class ReferenceIndexingMixin(models.Model): reference = latest.reference.strip try: - reference = InvenTree.format.extract_named_group('ref', reference, cls.get_reference_pattern()) + reference = InvenTree.format.extract_named_group( + 'ref', reference, cls.get_reference_pattern() + ) except Exception: # If reference cannot be extracted using the pattern, try just the integer value reference = str(latest.reference_int) @@ -296,7 +303,7 @@ class ReferenceIndexingMixin(models.Model): if recent: reference = recent.reference else: - reference = "" + reference = '' return reference @@ -309,20 +316,20 @@ class ReferenceIndexingMixin(models.Model): info = InvenTree.format.parse_format_string(pattern) except Exception as exc: raise ValidationError({ - "value": _("Improperly formatted pattern") + ": " + str(exc) + 'value': _('Improperly formatted pattern') + ': ' + str(exc) }) # Check that only 'allowed' keys are provided for key in info.keys(): if key not in ctx.keys(): raise ValidationError({ - "value": _("Unknown format key specified") + f": '{key}'" + 'value': _('Unknown format key specified') + f": '{key}'" }) # Check that the 'ref' variable is specified if 'ref' not in info.keys(): raise ValidationError({ - 'value': _("Missing required format key") + ": 'ref'" + 'value': _('Missing required format key') + ": 'ref'" }) @classmethod @@ -333,14 +340,16 @@ class ReferenceIndexingMixin(models.Model): value = str(value).strip() if len(value) == 0: - raise ValidationError(_("Reference field cannot be empty")) + raise ValidationError(_('Reference field cannot be empty')) # An 'empty' pattern means no further validation is required if not pattern: return if not InvenTree.format.validate_string(value, pattern): - raise ValidationError(_("Reference must match required pattern") + ": " + pattern) + raise ValidationError( + _('Reference must match required pattern') + ': ' + pattern + ) # Check that the reference field can be rebuild cls.rebuild_reference_field(value, validate=True) @@ -361,7 +370,9 @@ class ReferenceIndexingMixin(models.Model): """ try: # Extract named group based on provided pattern - reference = InvenTree.format.extract_named_group('ref', reference, cls.get_reference_pattern()) + reference = InvenTree.format.extract_named_group( + 'ref', reference, cls.get_reference_pattern() + ) except Exception: pass @@ -369,16 +380,14 @@ class ReferenceIndexingMixin(models.Model): if validate: if reference_int > models.BigIntegerField.MAX_BIGINT: - raise ValidationError({ - "reference": _("Reference number is too large") - }) + raise ValidationError({'reference': _('Reference number is too large')}) return reference_int reference_int = models.BigIntegerField(default=0) -def extract_int(reference, clip=0x7fffffff, allow_negative=False): +def extract_int(reference, clip=0x7FFFFFFF, allow_negative=False): """Extract an integer out of reference.""" # Default value if we cannot convert to an integer ref_int = 0 @@ -390,7 +399,7 @@ def extract_int(reference, clip=0x7fffffff, allow_negative=False): return 0 # Look at the start of the string - can it be "integerized"? - result = re.match(r"^(\d+)", reference) + result = re.match(r'^(\d+)', reference) if result and len(result.groups()) == 1: ref = result.groups()[0] @@ -438,6 +447,7 @@ class InvenTreeAttachment(models.Model): class Meta: """Metaclass options. Abstract ensures no database table is created.""" + abstract = True def getSubdir(self): @@ -445,7 +455,7 @@ class InvenTreeAttachment(models.Model): Note: Re-implement this for each subclass of InvenTreeAttachment """ - return "attachments" + return 'attachments' def save(self, *args, **kwargs): """Provide better validation error.""" @@ -472,28 +482,40 @@ class InvenTreeAttachment(models.Model): return os.path.basename(self.attachment.name) return str(self.link) - attachment = models.FileField(upload_to=rename_attachment, verbose_name=_('Attachment'), - help_text=_('Select file to attach'), - blank=True, null=True - ) - - link = InvenTree.fields.InvenTreeURLField( - blank=True, null=True, - verbose_name=_('Link'), - help_text=_('Link to external URL') + attachment = models.FileField( + upload_to=rename_attachment, + verbose_name=_('Attachment'), + help_text=_('Select file to attach'), + blank=True, + null=True, ) - comment = models.CharField(blank=True, max_length=100, verbose_name=_('Comment'), help_text=_('File comment')) + link = InvenTree.fields.InvenTreeURLField( + blank=True, + null=True, + verbose_name=_('Link'), + help_text=_('Link to external URL'), + ) + + comment = models.CharField( + blank=True, + max_length=100, + verbose_name=_('Comment'), + help_text=_('File comment'), + ) user = models.ForeignKey( User, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('User'), help_text=_('User'), ) - upload_date = models.DateField(auto_now_add=True, null=True, blank=True, verbose_name=_('upload date')) + upload_date = models.DateField( + auto_now_add=True, null=True, blank=True, verbose_name=_('upload date') + ) @property def basename(self): @@ -522,35 +544,59 @@ class InvenTreeAttachment(models.Model): # Check that there are no directory tricks going on... if new_file.parent != attachment_dir: - logger.error("Attempted to rename attachment outside valid directory: '%s'", new_file) - raise ValidationError(_("Invalid attachment directory")) + logger.error( + "Attempted to rename attachment outside valid directory: '%s'", new_file + ) + raise ValidationError(_('Invalid attachment directory')) # Ignore further checks if the filename is not actually being renamed if new_file == old_file: return - forbidden = ["'", '"', "#", "@", "!", "&", "^", "<", ">", ":", ";", "/", "\\", "|", "?", "*", "%", "~", "`"] + forbidden = [ + "'", + '"', + '#', + '@', + '!', + '&', + '^', + '<', + '>', + ':', + ';', + '/', + '\\', + '|', + '?', + '*', + '%', + '~', + '`', + ] for c in forbidden: if c in fn: raise ValidationError(_(f"Filename contains illegal character '{c}'")) if len(fn.split('.')) < 2: - raise ValidationError(_("Filename missing extension")) + raise ValidationError(_('Filename missing extension')) if not old_file.exists(): - logger.error("Trying to rename attachment '%s' which does not exist", old_file) + logger.error( + "Trying to rename attachment '%s' which does not exist", old_file + ) return if new_file.exists(): - raise ValidationError(_("Attachment with this filename already exists")) + raise ValidationError(_('Attachment with this filename already exists')) try: os.rename(old_file, new_file) self.attachment.name = os.path.join(self.getSubdir(), fn) self.save() except Exception: - raise ValidationError(_("Error renaming file")) + raise ValidationError(_('Error renaming file')) def fully_qualified_url(self): """Return a 'fully qualified' URL for this attachment. @@ -586,10 +632,12 @@ class InvenTreeTree(MPTTModel): class Meta: """Metaclass defines extra model properties.""" + abstract = True class MPTTMeta: """Set insert order.""" + order_insertion_by = ['name'] def delete(self, delete_children=False, delete_items=False): @@ -607,13 +655,19 @@ class InvenTreeTree(MPTTModel): self.refresh_from_db() except self.__class__.DoesNotExist: # If the object no longer exists, raise a ValidationError - raise ValidationError("Object %s of type %s no longer exists", str(self), str(self.__class__)) + raise ValidationError( + 'Object %s of type %s no longer exists', str(self), str(self.__class__) + ) # Cache node ID values for lower nodes, before we delete this one - lower_nodes = list(self.get_descendants(include_self=False).values_list('pk', flat=True)) + lower_nodes = list( + self.get_descendants(include_self=False).values_list('pk', flat=True) + ) # 1. Update nodes and items under the current node - self.handle_tree_delete(delete_children=delete_children, delete_items=delete_items) + self.handle_tree_delete( + delete_children=delete_children, delete_items=delete_items + ) # 2. Delete *this* node super().delete() @@ -673,9 +727,7 @@ class InvenTreeTree(MPTTModel): # - Move all items at any lower level to the parent of this item # - Delete all descendant nodes elif delete_children and not delete_items: - self.get_items(cascade=True).update(**{ - self.ITEM_PARENT_KEY: self.parent - }) + self.get_items(cascade=True).update(**{self.ITEM_PARENT_KEY: self.parent}) self.delete_nodes(child_nodes) @@ -690,9 +742,7 @@ class InvenTreeTree(MPTTModel): # - Move all items directly associated with this node up one level # - Move any direct child nodes up one level elif not delete_children and not delete_items: - self.get_items(cascade=False).update(**{ - self.ITEM_PARENT_KEY: self.parent - }) + self.get_items(cascade=False).update(**{self.ITEM_PARENT_KEY: self.parent}) self.get_children().update(parent=self.parent) def delete_nodes(self, nodes): @@ -719,8 +769,7 @@ class InvenTreeTree(MPTTModel): super().validate_unique(exclude) results = self.__class__.objects.filter( - name=self.name, - parent=self.parent + name=self.name, parent=self.parent ).exclude(pk=self.pk) if results.exists(): @@ -730,17 +779,11 @@ class InvenTreeTree(MPTTModel): def api_instance_filters(self): """Instance filters for InvenTreeTree models.""" - return { - 'parent': { - 'exclude_tree': self.pk, - } - } + return {'parent': {'exclude_tree': self.pk}} def construct_pathstring(self): """Construct the pathstring for this tree node""" - return InvenTree.helpers.constructPathString( - [item.name for item in self.path] - ) + return InvenTree.helpers.constructPathString([item.name for item in self.path]) def save(self, *args, **kwargs): """Custom save method for InvenTreeTree abstract model""" @@ -748,15 +791,12 @@ class InvenTreeTree(MPTTModel): super().save(*args, **kwargs) except InvalidMove: # Provide better error for parent selection - raise ValidationError({ - 'parent': _("Invalid choice"), - }) + raise ValidationError({'parent': _('Invalid choice')}) # Re-calculate the 'pathstring' field pathstring = self.construct_pathstring() if pathstring != self.pathstring: - if 'force_insert' in kwargs: del kwargs['force_insert'] @@ -781,33 +821,29 @@ class InvenTreeTree(MPTTModel): self.__class__.objects.bulk_update(nodes_to_update, ['pathstring']) name = models.CharField( - blank=False, - max_length=100, - verbose_name=_("Name"), - help_text=_("Name"), + blank=False, max_length=100, verbose_name=_('Name'), help_text=_('Name') ) description = models.CharField( blank=True, max_length=250, - verbose_name=_("Description"), - help_text=_("Description (optional)") + verbose_name=_('Description'), + help_text=_('Description (optional)'), ) # When a category is deleted, graft the children onto its parent - parent = TreeForeignKey('self', - on_delete=models.DO_NOTHING, - blank=True, - null=True, - verbose_name=_("parent"), - related_name='children') + parent = TreeForeignKey( + 'self', + on_delete=models.DO_NOTHING, + blank=True, + null=True, + verbose_name=_('parent'), + related_name='children', + ) # The 'pathstring' field is calculated each time the model is saved pathstring = models.CharField( - blank=True, - max_length=250, - verbose_name=_('Path'), - help_text=_('Path') + blank=True, max_length=250, verbose_name=_('Path'), help_text=_('Path') ) def get_items(self, cascade=False): @@ -818,7 +854,7 @@ class InvenTreeTree(MPTTModel): The default implementation returns an empty list """ - raise NotImplementedError(f"items() method not implemented for {type(self)}") + raise NotImplementedError(f'items() method not implemented for {type(self)}') def getUniqueParents(self): """Return a flat set of all parent items that exist above this node. @@ -889,16 +925,11 @@ class InvenTreeTree(MPTTModel): name: , } """ - return [ - { - 'pk': item.pk, - 'name': item.name - } for item in self.path - ] + return [{'pk': item.pk, 'name': item.name} for item in self.path] def __str__(self): """String representation of a category is the full path to that category.""" - return f"{self.pathstring} - {self.description}" + return f'{self.pathstring} - {self.description}' class InvenTreeNotesMixin(models.Model): @@ -914,11 +945,11 @@ class InvenTreeNotesMixin(models.Model): Note: abstract must be true, as this is only a mixin, not a separate table """ + abstract = True notes = InvenTree.fields.InvenTreeNotesField( - verbose_name=_('Notes'), - help_text=_('Markdown notes (optional)'), + verbose_name=_('Notes'), help_text=_('Markdown notes (optional)') ) @@ -941,18 +972,21 @@ class InvenTreeBarcodeMixin(models.Model): Note: abstract must be true, as this is only a mixin, not a separate table """ + abstract = True barcode_data = models.CharField( - blank=True, max_length=500, + blank=True, + max_length=500, verbose_name=_('Barcode Data'), help_text=_('Third party barcode data'), ) barcode_hash = models.CharField( - blank=True, max_length=128, + blank=True, + max_length=128, verbose_name=_('Barcode Hash'), - help_text=_('Unique hash of barcode data') + help_text=_('Unique hash of barcode data'), ) @classmethod @@ -964,21 +998,17 @@ class InvenTreeBarcodeMixin(models.Model): def format_barcode(self, **kwargs): """Return a JSON string for formatting a QR code for this model instance.""" return InvenTree.helpers.MakeBarcode( - self.__class__.barcode_model_type(), - self.pk, - **kwargs + self.__class__.barcode_model_type(), self.pk, **kwargs ) def format_matched_response(self): """Format a standard response for a matched barcode.""" - data = { - 'pk': self.pk, - } + data = {'pk': self.pk} if hasattr(self, 'get_api_url'): api_url = self.get_api_url() - data['api_url'] = f"{api_url}{self.pk}/" + data['api_url'] = f'{api_url}{self.pk}/' if hasattr(self, 'get_absolute_url'): data['web_url'] = self.get_absolute_url() @@ -995,7 +1025,9 @@ class InvenTreeBarcodeMixin(models.Model): """Check if a model instance exists with the specified third-party barcode hash.""" return cls.objects.filter(barcode_hash=barcode_hash).first() - def assign_barcode(self, barcode_hash=None, barcode_data=None, raise_error=True, save=True): + def assign_barcode( + self, barcode_hash=None, barcode_data=None, raise_error=True, save=True + ): """Assign an external (third-party) barcode to this object.""" # Must provide either barcode_hash or barcode_data if barcode_hash is None and barcode_data is None: @@ -1008,7 +1040,7 @@ class InvenTreeBarcodeMixin(models.Model): # Check for existing item if self.__class__.lookup_barcode(barcode_hash) is not None: if raise_error: - raise ValidationError(_("Existing barcode found")) + raise ValidationError(_('Existing barcode found')) else: return False @@ -1044,20 +1076,24 @@ def after_error_logged(sender, instance: Error, created: bool, **kwargs): users = get_user_model().objects.filter(is_staff=True) link = InvenTree.helpers_model.construct_absolute_url( - reverse('admin:error_report_error_change', kwargs={'object_id': instance.pk}) + reverse( + 'admin:error_report_error_change', kwargs={'object_id': instance.pk} + ) ) context = { 'error': instance, 'name': _('Server Error'), 'message': _('An error has been logged by the server.'), - 'link': link + 'link': link, } target_users = [] for user in users: - if common.models.InvenTreeUserSetting.get_setting('NOTIFICATION_ERROR_REPORT', True, user=user): + if common.models.InvenTreeUserSetting.get_setting( + 'NOTIFICATION_ERROR_REPORT', True, user=user + ): target_users.append(user) if len(target_users) > 0: @@ -1066,7 +1102,7 @@ def after_error_logged(sender, instance: Error, created: bool, **kwargs): 'inventree.error_log', context=context, targets=target_users, - delivery_methods={common.notifications.UIMessageNotification, }, + delivery_methods={common.notifications.UIMessageNotification}, ) except Exception as exc: diff --git a/InvenTree/InvenTree/permissions.py b/InvenTree/InvenTree/permissions.py index befbcab4c3..74b844f008 100644 --- a/InvenTree/InvenTree/permissions.py +++ b/InvenTree/InvenTree/permissions.py @@ -18,7 +18,7 @@ def get_model_for_view(view, raise_error=True): if hasattr(view, 'get_serializer_class'): return view.get_serializr_class().Meta.model - raise AttributeError(f"Serializer class not specified for {view.__class__}") + raise AttributeError(f'Serializer class not specified for {view.__class__}') class RolePermission(permissions.BasePermission): @@ -62,7 +62,7 @@ class RolePermission(permissions.BasePermission): } # let the view define a custom rolemap - if hasattr(view, "rolemap"): + if hasattr(view, 'rolemap'): rolemap.update(view.rolemap) permission = rolemap[request.method] @@ -78,7 +78,7 @@ class RolePermission(permissions.BasePermission): app_label = model._meta.app_label model_name = model._meta.model_name - table = f"{app_label}_{model_name}" + table = f'{app_label}_{model_name}' except AttributeError: # We will assume that if the serializer class does *not* have a Meta, # then we don't need a permission @@ -100,12 +100,18 @@ class IsStaffOrReadOnly(permissions.IsAdminUser): def has_permission(self, request, view): """Check if the user is a superuser.""" - return bool(request.user and request.user.is_staff or request.method in permissions.SAFE_METHODS) + return bool( + request.user + and request.user.is_staff + or request.method in permissions.SAFE_METHODS + ) def auth_exempt(view_func): """Mark a view function as being exempt from auth requirements.""" + def wrapped_view(*args, **kwargs): return view_func(*args, **kwargs) + wrapped_view.auth_exempt = True return wraps(view_func)(wrapped_view) diff --git a/InvenTree/InvenTree/ready.py b/InvenTree/InvenTree/ready.py index 4f1d46df44..64161f1a7a 100644 --- a/InvenTree/InvenTree/ready.py +++ b/InvenTree/InvenTree/ready.py @@ -16,11 +16,7 @@ def isImportingData(): def isRunningMigrations(): """Return True if the database is currently running migrations.""" - return any((x in sys.argv for x in [ - 'migrate', - 'makemigrations', - 'showmigrations' - ])) + return any((x in sys.argv for x in ['migrate', 'makemigrations', 'showmigrations'])) def isInMainThread(): @@ -29,13 +25,15 @@ def isInMainThread(): - The RUN_MAIN env is set in that case. However if --noreload is applied, this variable is not set because there are no different threads. """ - if "runserver" in sys.argv and "--noreload" not in sys.argv: - return os.environ.get('RUN_MAIN', None) == "true" + if 'runserver' in sys.argv and '--noreload' not in sys.argv: + return os.environ.get('RUN_MAIN', None) == 'true' return True -def canAppAccessDatabase(allow_test: bool = False, allow_plugins: bool = False, allow_shell: bool = False): +def canAppAccessDatabase( + allow_test: bool = False, allow_plugins: bool = False, allow_shell: bool = False +): """Returns True if the apps.py file can access database records. There are some circumstances where we don't want the ready function in apps.py diff --git a/InvenTree/InvenTree/sanitizer.py b/InvenTree/InvenTree/sanitizer.py index 0565c1b7f2..0272b50f94 100644 --- a/InvenTree/InvenTree/sanitizer.py +++ b/InvenTree/InvenTree/sanitizer.py @@ -1,49 +1,196 @@ """Functions to sanitize user input files.""" + from bleach import clean from bleach.css_sanitizer import CSSSanitizer ALLOWED_ELEMENTS_SVG = [ - 'a', 'animate', 'animateColor', 'animateMotion', - 'animateTransform', 'circle', 'defs', 'desc', 'ellipse', 'font-face', - 'font-face-name', 'font-face-src', 'g', 'glyph', 'hkern', - 'linearGradient', 'line', 'marker', 'metadata', 'missing-glyph', - 'mpath', 'path', 'polygon', 'polyline', 'radialGradient', 'rect', - 'set', 'stop', 'svg', 'switch', 'text', 'title', 'tspan', 'use' + 'a', + 'animate', + 'animateColor', + 'animateMotion', + 'animateTransform', + 'circle', + 'defs', + 'desc', + 'ellipse', + 'font-face', + 'font-face-name', + 'font-face-src', + 'g', + 'glyph', + 'hkern', + 'linearGradient', + 'line', + 'marker', + 'metadata', + 'missing-glyph', + 'mpath', + 'path', + 'polygon', + 'polyline', + 'radialGradient', + 'rect', + 'set', + 'stop', + 'svg', + 'switch', + 'text', + 'title', + 'tspan', + 'use', ] ALLOWED_ATTRIBUTES_SVG = [ - 'accent-height', 'accumulate', 'additive', 'alphabetic', - 'arabic-form', 'ascent', 'attributeName', 'attributeType', - 'baseProfile', 'bbox', 'begin', 'by', 'calcMode', 'cap-height', - 'class', 'color', 'color-rendering', 'content', 'cx', 'cy', 'd', 'dx', - 'dy', 'descent', 'display', 'dur', 'end', 'fill', 'fill-opacity', - 'fill-rule', 'font-family', 'font-size', 'font-stretch', 'font-style', - 'font-variant', 'font-weight', 'from', 'fx', 'fy', 'g1', 'g2', - 'glyph-name', 'gradientUnits', 'hanging', 'height', 'horiz-adv-x', - 'horiz-origin-x', 'id', 'ideographic', 'k', 'keyPoints', - 'keySplines', 'keyTimes', 'lang', 'marker-end', 'marker-mid', - 'marker-start', 'markerHeight', 'markerUnits', 'markerWidth', - 'mathematical', 'max', 'min', 'name', 'offset', 'opacity', 'orient', - 'origin', 'overline-position', 'overline-thickness', 'panose-1', - 'path', 'pathLength', 'points', 'preserveAspectRatio', 'r', 'refX', - 'refY', 'repeatCount', 'repeatDur', 'requiredExtensions', - 'requiredFeatures', 'restart', 'rotate', 'rx', 'ry', 'slope', - 'stemh', 'stemv', 'stop-color', 'stop-opacity', - 'strikethrough-position', 'strikethrough-thickness', 'stroke', - 'stroke-dasharray', 'stroke-dashoffset', 'stroke-linecap', - 'stroke-linejoin', 'stroke-miterlimit', 'stroke-opacity', - 'stroke-width', 'systemLanguage', 'target', 'text-anchor', 'to', - 'transform', 'type', 'u1', 'u2', 'underline-position', - 'underline-thickness', 'unicode', 'unicode-range', 'units-per-em', - 'values', 'version', 'viewBox', 'visibility', 'width', 'widths', 'x', - 'x-height', 'x1', 'x2', 'xlink:actuate', 'xlink:arcrole', - 'xlink:href', 'xlink:role', 'xlink:show', 'xlink:title', - 'xlink:type', 'xml:base', 'xml:lang', 'xml:space', 'xmlns', - 'xmlns:xlink', 'y', 'y1', 'y2', 'zoomAndPan', 'style' + 'accent-height', + 'accumulate', + 'additive', + 'alphabetic', + 'arabic-form', + 'ascent', + 'attributeName', + 'attributeType', + 'baseProfile', + 'bbox', + 'begin', + 'by', + 'calcMode', + 'cap-height', + 'class', + 'color', + 'color-rendering', + 'content', + 'cx', + 'cy', + 'd', + 'dx', + 'dy', + 'descent', + 'display', + 'dur', + 'end', + 'fill', + 'fill-opacity', + 'fill-rule', + 'font-family', + 'font-size', + 'font-stretch', + 'font-style', + 'font-variant', + 'font-weight', + 'from', + 'fx', + 'fy', + 'g1', + 'g2', + 'glyph-name', + 'gradientUnits', + 'hanging', + 'height', + 'horiz-adv-x', + 'horiz-origin-x', + 'id', + 'ideographic', + 'k', + 'keyPoints', + 'keySplines', + 'keyTimes', + 'lang', + 'marker-end', + 'marker-mid', + 'marker-start', + 'markerHeight', + 'markerUnits', + 'markerWidth', + 'mathematical', + 'max', + 'min', + 'name', + 'offset', + 'opacity', + 'orient', + 'origin', + 'overline-position', + 'overline-thickness', + 'panose-1', + 'path', + 'pathLength', + 'points', + 'preserveAspectRatio', + 'r', + 'refX', + 'refY', + 'repeatCount', + 'repeatDur', + 'requiredExtensions', + 'requiredFeatures', + 'restart', + 'rotate', + 'rx', + 'ry', + 'slope', + 'stemh', + 'stemv', + 'stop-color', + 'stop-opacity', + 'strikethrough-position', + 'strikethrough-thickness', + 'stroke', + 'stroke-dasharray', + 'stroke-dashoffset', + 'stroke-linecap', + 'stroke-linejoin', + 'stroke-miterlimit', + 'stroke-opacity', + 'stroke-width', + 'systemLanguage', + 'target', + 'text-anchor', + 'to', + 'transform', + 'type', + 'u1', + 'u2', + 'underline-position', + 'underline-thickness', + 'unicode', + 'unicode-range', + 'units-per-em', + 'values', + 'version', + 'viewBox', + 'visibility', + 'width', + 'widths', + 'x', + 'x-height', + 'x1', + 'x2', + 'xlink:actuate', + 'xlink:arcrole', + 'xlink:href', + 'xlink:role', + 'xlink:show', + 'xlink:title', + 'xlink:type', + 'xml:base', + 'xml:lang', + 'xml:space', + 'xmlns', + 'xmlns:xlink', + 'y', + 'y1', + 'y2', + 'zoomAndPan', + 'style', ] -def sanitize_svg(file_data, strip: bool = True, elements: str = ALLOWED_ELEMENTS_SVG, attributes: str = ALLOWED_ATTRIBUTES_SVG) -> str: +def sanitize_svg( + file_data, + strip: bool = True, + elements: str = ALLOWED_ELEMENTS_SVG, + attributes: str = ALLOWED_ATTRIBUTES_SVG, +) -> str: """Sanitize a SVG file. Args: @@ -65,7 +212,7 @@ def sanitize_svg(file_data, strip: bool = True, elements: str = ALLOWED_ELEMENTS attributes=attributes, strip=strip, strip_comments=strip, - css_sanitizer=CSSSanitizer() + css_sanitizer=CSSSanitizer(), ) return cleaned diff --git a/InvenTree/InvenTree/sentry.py b/InvenTree/InvenTree/sentry.py index b1ead7a5cb..cd265f199d 100644 --- a/InvenTree/InvenTree/sentry.py +++ b/InvenTree/InvenTree/sentry.py @@ -37,7 +37,7 @@ def sentry_ignore_errors(): def init_sentry(dsn, sample_rate, tags): """Initialize sentry.io error reporting""" - logger.info("Initializing sentry.io integration") + logger.info('Initializing sentry.io integration') sentry_sdk.init( dsn=dsn, @@ -46,7 +46,9 @@ def init_sentry(dsn, sample_rate, tags): send_default_pii=True, ignore_errors=sentry_ignore_errors(), release=InvenTree.version.INVENTREE_SW_VERSION, - environment='development' if InvenTree.version.isInvenTreeDevelopmentVersion() else 'production' + environment='development' + if InvenTree.version.isInvenTreeDevelopmentVersion() + else 'production', ) for key, val in tags.items(): @@ -62,11 +64,10 @@ def init_sentry(dsn, sample_rate, tags): def report_exception(exc): """Report an exception to sentry.io""" if settings.SENTRY_ENABLED and settings.SENTRY_DSN: - if not any(isinstance(exc, e) for e in sentry_ignore_errors()): - logger.info("Reporting exception to sentry.io: %s", exc) + logger.info('Reporting exception to sentry.io: %s', exc) try: sentry_sdk.capture_exception(exc) except Exception: - logger.warning("Failed to report exception to sentry.io") + logger.warning('Failed to report exception to sentry.io') diff --git a/InvenTree/InvenTree/serializers.py b/InvenTree/InvenTree/serializers.py index d0304bb313..824f11f25b 100644 --- a/InvenTree/InvenTree/serializers.py +++ b/InvenTree/InvenTree/serializers.py @@ -37,9 +37,9 @@ class InvenTreeMoneySerializer(MoneyField): def __init__(self, *args, **kwargs): """Override default values.""" - kwargs["max_digits"] = kwargs.get("max_digits", 19) - self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 6) - kwargs["required"] = kwargs.get("required", False) + kwargs['max_digits'] = kwargs.get('max_digits', 19) + self.decimal_places = kwargs['decimal_places'] = kwargs.get('decimal_places', 6) + kwargs['required'] = kwargs.get('required', False) super().__init__(*args, **kwargs) @@ -57,13 +57,18 @@ class InvenTreeMoneySerializer(MoneyField): amount = Decimal(amount) amount = round(amount, self.decimal_places) except Exception: - raise ValidationError({ - self.field_name: [_("Must be a valid number")], - }) + raise ValidationError({self.field_name: [_('Must be a valid number')]}) - currency = data.get(get_currency_field_name(self.field_name), self.default_currency) + currency = data.get( + get_currency_field_name(self.field_name), self.default_currency + ) - if currency and amount is not None and not isinstance(amount, MONEY_CLASSES) and amount is not empty: + if ( + currency + and amount is not None + and not isinstance(amount, MONEY_CLASSES) + and amount is not empty + ): return Money(amount, currency) return amount @@ -76,7 +81,9 @@ class InvenTreeCurrencySerializer(serializers.ChoiceField): """Initialize the currency serializer""" choices = currency_code_mappings() - allow_blank = kwargs.get('allow_blank', False) or kwargs.get('allow_null', False) + allow_blank = kwargs.get('allow_blank', False) or kwargs.get( + 'allow_null', False + ) if allow_blank: choices = [('', '---------')] + choices @@ -97,6 +104,7 @@ class InvenTreeCurrencySerializer(serializers.ChoiceField): class DependentField(serializers.Field): """A dependent field can be used to dynamically return child fields based on the value of other fields.""" + child = None def __init__(self, *args, depends_on, field_serializer, **kwargs): @@ -126,7 +134,7 @@ class DependentField(serializers.Field): def get_child(self, raise_exception=False): """This method tries to extract the child based on the provided data in the request by the client.""" - data = deepcopy(self.context["request"].data) + data = deepcopy(self.context['request'].data) def visit_parent(node): """Recursively extract the data for the parent field/serializer in reverse.""" @@ -136,8 +144,9 @@ class DependentField(serializers.Field): visit_parent(node.parent) # only do for composite fields and stop right before the current field - if hasattr(node, "child") and node is not self and isinstance(data, dict): + if hasattr(node, 'child') and node is not self and isinstance(data, dict): data = data.get(node.field_name, None) + visit_parent(self) # ensure that data is a dictionary and that a parent exists @@ -152,7 +161,9 @@ class DependentField(serializers.Field): # partially validate the data for options requests that set raise_exception while calling .get_child(...) if raise_exception: validation_data = {k: v for k, v in data.items() if k in self.depends_on} - serializer = self.parent.__class__(context=self.context, data=validation_data, partial=True) + serializer = self.parent.__class__( + context=self.context, data=validation_data, partial=True + ) serializer.is_valid(raise_exception=raise_exception) # try to get the field serializer @@ -196,7 +207,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): """Custom __init__ routine to ensure that *default* values (as specified in the ORM) are used by the DRF serializers, *if* the values are not provided by the user.""" # If instance is None, we are creating a new instance if instance is None and data is not empty: - if data is None: data = OrderedDict() else: @@ -211,7 +221,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): fields = model_meta.get_field_info(ModelClass) for field_name, field in fields.fields.items(): - """ Update the field IF (and ONLY IF): @@ -219,7 +228,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): - The field does not already have a value set """ if field.has_default() and field_name not in data: - value = field.default # Account for callable functions @@ -247,9 +255,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): fields = model_meta.get_field_info(ModelClass) for field_name, field in fields.fields.items(): - if field.has_default() and field_name not in initials: - value = field.default # Account for callable functions @@ -337,7 +343,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer): try: instance.full_clean() except (ValidationError, DjangoValidationError) as exc: - data = exc.message_dict # Change '__all__' key (django style) to 'non_field_errors' (DRF style) @@ -369,6 +374,7 @@ class InvenTreeTaggitSerializer(TaggitSerializer): class InvenTreeTagModelSerializer(InvenTreeTaggitSerializer, InvenTreeModelSerializer): """Combination of InvenTreeTaggitSerializer and InvenTreeModelSerializer.""" + pass @@ -377,38 +383,31 @@ class UserSerializer(InvenTreeModelSerializer): class Meta: """Metaclass defines serializer fields.""" - model = User - fields = [ - 'pk', - 'username', - 'first_name', - 'last_name', - 'email', - ] - read_only_fields = [ - 'username', - ] + model = User + fields = ['pk', 'username', 'first_name', 'last_name', 'email'] + + read_only_fields = ['username'] class ExendedUserSerializer(UserSerializer): """Serializer for a User with a bit more info.""" + from users.serializers import GroupSerializer groups = GroupSerializer(read_only=True, many=True) class Meta(UserSerializer.Meta): """Metaclass defines serializer fields.""" + fields = UserSerializer.Meta.fields + [ 'groups', 'is_staff', 'is_superuser', - 'is_active' + 'is_active', ] - read_only_fields = UserSerializer.Meta.read_only_fields + [ - 'groups', - ] + read_only_fields = UserSerializer.Meta.read_only_fields + ['groups'] def validate(self, attrs): """Expanded validation for changing user role.""" @@ -424,17 +423,20 @@ class ExendedUserSerializer(UserSerializer): # Staff can change any role except is_superuser pass else: - raise PermissionDenied(_("You do not have permission to change this user role.")) + raise PermissionDenied( + _('You do not have permission to change this user role.') + ) return super().validate(attrs) class UserCreateSerializer(ExendedUserSerializer): """Serializer for creating a new User.""" + def validate(self, attrs): """Expanded valiadation for auth.""" # Check that the user trying to create a new user is a superuser if not self.context['request'].user.is_superuser: - raise serializers.ValidationError(_("Only superusers can create new users")) + raise serializers.ValidationError(_('Only superusers can create new users')) # Generate a random password password = User.objects.make_random_password(length=14) @@ -451,8 +453,10 @@ class UserCreateSerializer(ExendedUserSerializer): current_site = Site.objects.get_current() domain = current_site.domain instance.email_user( - subject=_(f"Welcome to {current_site.name}"), - message=_(f"Your account has been created.\n\nPlease use the password reset function to get access (at https://{domain})."), + subject=_(f'Welcome to {current_site.name}'), + message=_( + f'Your account has been created.\n\nPlease use the password reset function to get access (at https://{domain}).' + ), ) return instance @@ -509,17 +513,11 @@ class InvenTreeAttachmentSerializer(InvenTreeModelSerializer): user_detail = UserSerializer(source='user', read_only=True, many=False) - attachment = InvenTreeAttachmentSerializerField( - required=False, - allow_null=False, - ) + attachment = InvenTreeAttachmentSerializerField(required=False, allow_null=False) # The 'filename' field must be present in the serializer filename = serializers.CharField( - label=_('Filename'), - required=False, - source='basename', - allow_blank=False, + label=_('Filename'), required=False, source='basename', allow_blank=False ) upload_date = serializers.DateField(read_only=True) @@ -553,7 +551,7 @@ class InvenTreeDecimalField(serializers.FloatField): try: return Decimal(str(data)) except Exception: - raise serializers.ValidationError(_("Invalid value")) + raise serializers.ValidationError(_('Invalid value')) class DataFileUploadSerializer(serializers.Serializer): @@ -570,13 +568,11 @@ class DataFileUploadSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'data_file', - ] + fields = ['data_file'] data_file = serializers.FileField( - label=_("Data File"), - help_text=_("Select data file for upload"), + label=_('Data File'), + help_text=_('Select data file for upload'), required=True, allow_empty_file=False, ) @@ -590,20 +586,16 @@ class DataFileUploadSerializer(serializers.Serializer): # Remove the leading . from the extension ext = ext[1:] - accepted_file_types = [ - 'xls', 'xlsx', - 'csv', 'tsv', - 'xml', - ] + accepted_file_types = ['xls', 'xlsx', 'csv', 'tsv', 'xml'] if ext not in accepted_file_types: - raise serializers.ValidationError(_("Unsupported file type")) + raise serializers.ValidationError(_('Unsupported file type')) # Impose a 50MB limit on uploaded BOM files max_upload_file_size = 50 * 1024 * 1024 if data_file.size > max_upload_file_size: - raise serializers.ValidationError(_("File is too large")) + raise serializers.ValidationError(_('File is too large')) # Read file data into memory (bytes object) try: @@ -624,10 +616,10 @@ class DataFileUploadSerializer(serializers.Serializer): raise serializers.ValidationError(str(e)) if len(self.dataset.headers) == 0: - raise serializers.ValidationError(_("No columns found in file")) + raise serializers.ValidationError(_('No columns found in file')) if len(self.dataset) == 0: - raise serializers.ValidationError(_("No data rows found in file")) + raise serializers.ValidationError(_('No data rows found in file')) return data_file @@ -721,24 +713,14 @@ class DataFileExtractSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'columns', - 'rows', - ] + fields = ['columns', 'rows'] # Mapping of columns - columns = serializers.ListField( - child=serializers.CharField( - allow_blank=True, - ), - ) + columns = serializers.ListField(child=serializers.CharField(allow_blank=True)) rows = serializers.ListField( child=serializers.ListField( - child=serializers.CharField( - allow_blank=True, - allow_null=True, - ), + child=serializers.CharField(allow_blank=True, allow_null=True) ) ) @@ -750,10 +732,10 @@ class DataFileExtractSerializer(serializers.Serializer): self.rows = data.get('rows', []) if len(self.rows) == 0: - raise serializers.ValidationError(_("No data rows provided")) + raise serializers.ValidationError(_('No data rows provided')) if len(self.columns) == 0: - raise serializers.ValidationError(_("No data columns supplied")) + raise serializers.ValidationError(_('No data columns supplied')) self.validate_extracted_columns() @@ -776,16 +758,9 @@ class DataFileExtractSerializer(serializers.Serializer): processed_row = self.process_row(self.row_to_dict(row)) if processed_row: - rows.append({ - "original": row, - "data": processed_row, - }) + rows.append({'original': row, 'data': processed_row}) - return { - 'fields': model_fields, - 'columns': self.columns, - 'rows': rows, - } + return {'fields': model_fields, 'columns': self.columns, 'rows': rows} def process_row(self, row): """Process a 'row' of data, which is a mapped column:value dict. @@ -799,12 +774,9 @@ class DataFileExtractSerializer(serializers.Serializer): def row_to_dict(self, row): """Convert a "row" to a named data dict.""" - row_dict = { - 'errors': {}, - } + row_dict = {'errors': {}} for idx, value in enumerate(row): - if idx < len(self.columns): col = self.columns[idx] @@ -824,16 +796,16 @@ class DataFileExtractSerializer(serializers.Serializer): cols_seen = set() for name, field in model_fields.items(): - required = field.get('required', False) # Check for missing required columns if required: if name not in self.columns: - raise serializers.ValidationError(_(f"Missing required column: '{name}'")) + raise serializers.ValidationError( + _(f"Missing required column: '{name}'") + ) for col in self.columns: - if not col: continue @@ -856,16 +828,14 @@ class RemoteImageMixin(metaclass=serializers.SerializerMetaclass): def skip_create_fields(self): """Ensure the 'remote_image' field is skipped when creating a new instance""" - return [ - 'remote_image', - ] + return ['remote_image'] remote_image = serializers.URLField( required=False, allow_blank=False, write_only=True, - label=_("Remote Image"), - help_text=_("URL of remote image file"), + label=_('Remote Image'), + help_text=_('URL of remote image file'), ) def validate_remote_image(self, url): @@ -877,8 +847,12 @@ class RemoteImageMixin(metaclass=serializers.SerializerMetaclass): if not url: return - if not common_models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_FROM_URL'): - raise ValidationError(_("Downloading images from remote URL is not enabled")) + if not common_models.InvenTreeSetting.get_setting( + 'INVENTREE_DOWNLOAD_FROM_URL' + ): + raise ValidationError( + _('Downloading images from remote URL is not enabled') + ) try: self.remote_image_file = download_image_from_url(url) diff --git a/InvenTree/InvenTree/settings.py b/InvenTree/InvenTree/settings.py index fb2bd0668a..8f46e5295e 100644 --- a/InvenTree/InvenTree/settings.py +++ b/InvenTree/InvenTree/settings.py @@ -38,9 +38,8 @@ INVENTREE_NEWS_URL = 'https://inventree.org/news/feed.atom' TESTING = 'test' in sys.argv or 'TESTING' in os.environ if TESTING: - # Use a weaker password hasher for testing (improves testing speed) - PASSWORD_HASHERS = ['django.contrib.auth.hashers.MD5PasswordHasher',] + PASSWORD_HASHERS = ['django.contrib.auth.hashers.MD5PasswordHasher'] # Enable slow-test-runner TEST_RUNNER = 'django_slowtests.testrunner.DiscoverSlowestTestsRunner' @@ -53,7 +52,7 @@ if TESTING: site_packages = '/usr/local/lib/python3.9/site-packages' if site_packages not in sys.path: - print("Adding missing site-packages path:", site_packages) + print('Adding missing site-packages path:', site_packages) sys.path.append(site_packages) # Are environment variables manipulated by tests? Needs to be set by testing code @@ -78,16 +77,17 @@ if version_file.exists(): # SECURITY WARNING: don't run with debug turned on in production! DEBUG = get_boolean_setting('INVENTREE_DEBUG', 'debug', True) -ENABLE_CLASSIC_FRONTEND = get_boolean_setting('INVENTREE_CLASSIC_FRONTEND', 'classic_frontend', True) -ENABLE_PLATFORM_FRONTEND = get_boolean_setting('INVENTREE_PLATFORM_FRONTEND', 'platform_frontend', True) +ENABLE_CLASSIC_FRONTEND = get_boolean_setting( + 'INVENTREE_CLASSIC_FRONTEND', 'classic_frontend', True +) +ENABLE_PLATFORM_FRONTEND = get_boolean_setting( + 'INVENTREE_PLATFORM_FRONTEND', 'platform_frontend', True +) # Configure logging settings log_level = get_setting('INVENTREE_LOG_LEVEL', 'log_level', 'WARNING') -logging.basicConfig( - level=log_level, - format="%(asctime)s %(levelname)s %(message)s", -) +logging.basicConfig(level=log_level, format='%(asctime)s %(levelname)s %(message)s') if log_level not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: log_level = 'WARNING' # pragma: no cover @@ -95,32 +95,21 @@ if log_level not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: LOGGING = { 'version': 1, 'disable_existing_loggers': False, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - }, - }, - 'root': { - 'handlers': ['console'], - 'level': log_level, - }, + 'handlers': {'console': {'class': 'logging.StreamHandler'}}, + 'root': {'handlers': ['console'], 'level': log_level}, 'filters': { 'require_not_maintenance_mode_503': { - '()': 'maintenance_mode.logging.RequireNotMaintenanceMode503', - }, + '()': 'maintenance_mode.logging.RequireNotMaintenanceMode503' + } }, } # Optionally add database-level logging if get_setting('INVENTREE_DB_LOGGING', 'db_logging', False): - LOGGING['loggers'] = { - 'django.db.backends': { - 'level': log_level or 'DEBUG', - }, - } + LOGGING['loggers'] = {'django.db.backends': {'level': log_level or 'DEBUG'}} # Get a logger instance for this setup file -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') # Load SECRET_KEY SECRET_KEY = config.get_secret_key() @@ -133,7 +122,7 @@ MEDIA_ROOT = config.get_media_dir() # List of allowed hosts (default = allow all) ALLOWED_HOSTS = get_setting( - "INVENTREE_ALLOWED_HOSTS", + 'INVENTREE_ALLOWED_HOSTS', config_key='allowed_hosts', default_value=['*'], typecast=list, @@ -146,13 +135,11 @@ CORS_URLS_REGEX = r'^/(api|media|static)/.*$' # Extract CORS options from configuration file CORS_ORIGIN_ALLOW_ALL = get_boolean_setting( - "INVENTREE_CORS_ORIGIN_ALLOW_ALL", - config_key='cors.allow_all', - default_value=False, + 'INVENTREE_CORS_ORIGIN_ALLOW_ALL', config_key='cors.allow_all', default_value=False ) CORS_ORIGIN_WHITELIST = get_setting( - "INVENTREE_CORS_ORIGIN_WHITELIST", + 'INVENTREE_CORS_ORIGIN_WHITELIST', config_key='cors.whitelist', default_value=[], typecast=list, @@ -173,9 +160,7 @@ STATICFILES_I18_TRG = BASE_DIR.joinpath('InvenTree', 'static_i18n') STATICFILES_DIRS.append(STATICFILES_I18_TRG) STATICFILES_I18_TRG = STATICFILES_I18_TRG.joinpath(STATICFILES_I18_PREFIX) -STATFILES_I18_PROCESSORS = [ - 'InvenTree.context.status_codes', -] +STATFILES_I18_PROCESSORS = ['InvenTree.context.status_codes'] # Color Themes Directory STATIC_COLOR_THEMES_DIR = STATIC_ROOT.joinpath('css', 'color-themes').resolve() @@ -189,38 +174,33 @@ DBBACKUP_SEND_EMAIL = False DBBACKUP_STORAGE = get_setting( 'INVENTREE_BACKUP_STORAGE', 'backup_storage', - 'django.core.files.storage.FileSystemStorage' + 'django.core.files.storage.FileSystemStorage', ) # Default backup configuration -DBBACKUP_STORAGE_OPTIONS = get_setting('INVENTREE_BACKUP_OPTIONS', 'backup_options', None) +DBBACKUP_STORAGE_OPTIONS = get_setting( + 'INVENTREE_BACKUP_OPTIONS', 'backup_options', None +) if DBBACKUP_STORAGE_OPTIONS is None: - DBBACKUP_STORAGE_OPTIONS = { - 'location': config.get_backup_dir(), - } + DBBACKUP_STORAGE_OPTIONS = {'location': config.get_backup_dir()} INVENTREE_ADMIN_ENABLED = get_boolean_setting( - 'INVENTREE_ADMIN_ENABLED', - config_key='admin_enabled', - default_value=True + 'INVENTREE_ADMIN_ENABLED', config_key='admin_enabled', default_value=True ) # Base URL for admin pages (default="admin") INVENTREE_ADMIN_URL = get_setting( - 'INVENTREE_ADMIN_URL', - config_key='admin_url', - default_value='admin' + 'INVENTREE_ADMIN_URL', config_key='admin_url', default_value='admin' ) INSTALLED_APPS = [ # Admin site integration 'django.contrib.admin', - # InvenTree apps 'build.apps.BuildConfig', 'common.apps.CommonConfig', 'company.apps.CompanyConfig', - 'plugin.apps.PluginAppConfig', # Plugin app runs before all apps that depend on the isPluginRegistryLoaded function + 'plugin.apps.PluginAppConfig', # Plugin app runs before all apps that depend on the isPluginRegistryLoaded function 'label.apps.LabelConfig', 'order.apps.OrderConfig', 'part.apps.PartConfig', @@ -229,109 +209,113 @@ INSTALLED_APPS = [ 'users.apps.UsersConfig', 'web', 'generic', - 'InvenTree.apps.InvenTreeConfig', # InvenTree app runs last - + 'InvenTree.apps.InvenTreeConfig', # InvenTree app runs last # Core django modules 'django.contrib.auth', 'django.contrib.contenttypes', - 'user_sessions', # db user sessions + 'user_sessions', # db user sessions 'django.contrib.messages', 'django.contrib.staticfiles', 'django.contrib.sites', - # Maintenance 'maintenance_mode', - # Third part add-ons - 'django_filters', # Extended filter functionality - 'rest_framework', # DRF (Django Rest Framework) - 'corsheaders', # Cross-origin Resource Sharing for DRF - 'crispy_forms', # Improved form rendering - 'import_export', # Import / export tables to file - 'django_cleanup.apps.CleanupConfig', # Automatically delete orphaned MEDIA files - 'mptt', # Modified Preorder Tree Traversal - 'markdownify', # Markdown template rendering - 'djmoney', # django-money integration - 'djmoney.contrib.exchange', # django-money exchange rates - 'error_report', # Error reporting in the admin interface + 'django_filters', # Extended filter functionality + 'rest_framework', # DRF (Django Rest Framework) + 'corsheaders', # Cross-origin Resource Sharing for DRF + 'crispy_forms', # Improved form rendering + 'import_export', # Import / export tables to file + 'django_cleanup.apps.CleanupConfig', # Automatically delete orphaned MEDIA files + 'mptt', # Modified Preorder Tree Traversal + 'markdownify', # Markdown template rendering + 'djmoney', # django-money integration + 'djmoney.contrib.exchange', # django-money exchange rates + 'error_report', # Error reporting in the admin interface 'django_q', - 'formtools', # Form wizard tools - 'dbbackup', # Backups - django-dbbackup - 'taggit', # Tagging - 'flags', # Flagging - django-flags - - 'allauth', # Base app for SSO - 'allauth.account', # Extend user with accounts - 'allauth.socialaccount', # Use 'social' providers - - 'django_otp', # OTP is needed for MFA - base package - 'django_otp.plugins.otp_totp', # Time based OTP - 'django_otp.plugins.otp_static', # Backup codes - - 'allauth_2fa', # MFA flow for allauth - 'dj_rest_auth', # Authentication APIs - dj-rest-auth - 'dj_rest_auth.registration', # Registration APIs - dj-rest-auth' - 'drf_spectacular', # API documentation - - 'django_ical', # For exporting calendars + 'formtools', # Form wizard tools + 'dbbackup', # Backups - django-dbbackup + 'taggit', # Tagging + 'flags', # Flagging - django-flags + 'allauth', # Base app for SSO + 'allauth.account', # Extend user with accounts + 'allauth.socialaccount', # Use 'social' providers + 'django_otp', # OTP is needed for MFA - base package + 'django_otp.plugins.otp_totp', # Time based OTP + 'django_otp.plugins.otp_static', # Backup codes + 'allauth_2fa', # MFA flow for allauth + 'dj_rest_auth', # Authentication APIs - dj-rest-auth + 'dj_rest_auth.registration', # Registration APIs - dj-rest-auth' + 'drf_spectacular', # API documentation + 'django_ical', # For exporting calendars ] -MIDDLEWARE = CONFIG.get('middleware', [ - 'django.middleware.security.SecurityMiddleware', - 'x_forwarded_for.middleware.XForwardedForMiddleware', - 'user_sessions.middleware.SessionMiddleware', # db user sessions - 'django.middleware.locale.LocaleMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'corsheaders.middleware.CorsMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'InvenTree.middleware.InvenTreeRemoteUserMiddleware', # Remote / proxy auth - 'django_otp.middleware.OTPMiddleware', # MFA support - 'InvenTree.middleware.CustomAllauthTwoFactorMiddleware', # Flow control for allauth - 'allauth.account.middleware.AccountMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'InvenTree.middleware.AuthRequiredMiddleware', - 'InvenTree.middleware.Check2FAMiddleware', # Check if the user should be forced to use MFA - 'maintenance_mode.middleware.MaintenanceModeMiddleware', - 'InvenTree.middleware.InvenTreeExceptionProcessor', # Error reporting -]) +MIDDLEWARE = CONFIG.get( + 'middleware', + [ + 'django.middleware.security.SecurityMiddleware', + 'x_forwarded_for.middleware.XForwardedForMiddleware', + 'user_sessions.middleware.SessionMiddleware', # db user sessions + 'django.middleware.locale.LocaleMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'corsheaders.middleware.CorsMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'InvenTree.middleware.InvenTreeRemoteUserMiddleware', # Remote / proxy auth + 'django_otp.middleware.OTPMiddleware', # MFA support + 'InvenTree.middleware.CustomAllauthTwoFactorMiddleware', # Flow control for allauth + 'allauth.account.middleware.AccountMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'InvenTree.middleware.AuthRequiredMiddleware', + 'InvenTree.middleware.Check2FAMiddleware', # Check if the user should be forced to use MFA + 'maintenance_mode.middleware.MaintenanceModeMiddleware', + 'InvenTree.middleware.InvenTreeExceptionProcessor', # Error reporting + ], +) -AUTHENTICATION_BACKENDS = CONFIG.get('authentication_backends', [ - 'django.contrib.auth.backends.RemoteUserBackend', # proxy login - 'django.contrib.auth.backends.ModelBackend', - 'allauth.account.auth_backends.AuthenticationBackend', # SSO login via external providers - "sesame.backends.ModelBackend", # Magic link login django-sesame -]) +AUTHENTICATION_BACKENDS = CONFIG.get( + 'authentication_backends', + [ + 'django.contrib.auth.backends.RemoteUserBackend', # proxy login + 'django.contrib.auth.backends.ModelBackend', + 'allauth.account.auth_backends.AuthenticationBackend', # SSO login via external providers + 'sesame.backends.ModelBackend', # Magic link login django-sesame + ], +) # LDAP support -LDAP_AUTH = get_boolean_setting("INVENTREE_LDAP_ENABLED", "ldap.enabled", False) +LDAP_AUTH = get_boolean_setting('INVENTREE_LDAP_ENABLED', 'ldap.enabled', False) if LDAP_AUTH: import ldap from django_auth_ldap.config import GroupOfUniqueNamesType, LDAPSearch - AUTHENTICATION_BACKENDS.append("django_auth_ldap.backend.LDAPBackend") + AUTHENTICATION_BACKENDS.append('django_auth_ldap.backend.LDAPBackend') # debug mode to troubleshoot configuration - LDAP_DEBUG = get_boolean_setting("INVENTREE_LDAP_DEBUG", "ldap.debug", False) + LDAP_DEBUG = get_boolean_setting('INVENTREE_LDAP_DEBUG', 'ldap.debug', False) if LDAP_DEBUG: - if "loggers" not in LOGGING: - LOGGING["loggers"] = {} - LOGGING["loggers"]["django_auth_ldap"] = {"level": "DEBUG", "handlers": ["console"]} + if 'loggers' not in LOGGING: + LOGGING['loggers'] = {} + LOGGING['loggers']['django_auth_ldap'] = { + 'level': 'DEBUG', + 'handlers': ['console'], + } # get global options from dict and use ldap.OPT_* as keys and values - global_options_dict = get_setting("INVENTREE_LDAP_GLOBAL_OPTIONS", "ldap.global_options", {}, dict) + global_options_dict = get_setting( + 'INVENTREE_LDAP_GLOBAL_OPTIONS', 'ldap.global_options', {}, dict + ) global_options = {} for k, v in global_options_dict.items(): # keys are always ldap.OPT_* constants k_attr = getattr(ldap, k, None) - if not k.startswith("OPT_") or k_attr is None: + if not k.startswith('OPT_') or k_attr is None: print(f"[LDAP] ldap.global_options, key '{k}' not found, skipping...") continue # values can also be other strings, e.g. paths v_attr = v - if v.startswith("OPT_"): + if v.startswith('OPT_'): v_attr = getattr(ldap, v, None) if v_attr is None: @@ -341,41 +325,60 @@ if LDAP_AUTH: global_options[k_attr] = v_attr AUTH_LDAP_GLOBAL_OPTIONS = global_options if LDAP_DEBUG: - print("[LDAP] ldap.global_options =", global_options) + print('[LDAP] ldap.global_options =', global_options) - AUTH_LDAP_SERVER_URI = get_setting("INVENTREE_LDAP_SERVER_URI", "ldap.server_uri") - AUTH_LDAP_START_TLS = get_boolean_setting("INVENTREE_LDAP_START_TLS", "ldap.start_tls", False) - AUTH_LDAP_BIND_DN = get_setting("INVENTREE_LDAP_BIND_DN", "ldap.bind_dn") - AUTH_LDAP_BIND_PASSWORD = get_setting("INVENTREE_LDAP_BIND_PASSWORD", "ldap.bind_password") - AUTH_LDAP_USER_SEARCH = LDAPSearch( - get_setting("INVENTREE_LDAP_SEARCH_BASE_DN", "ldap.search_base_dn"), - ldap.SCOPE_SUBTREE, - str(get_setting("INVENTREE_LDAP_SEARCH_FILTER_STR", "ldap.search_filter_str", "(uid= %(user)s)")) + AUTH_LDAP_SERVER_URI = get_setting('INVENTREE_LDAP_SERVER_URI', 'ldap.server_uri') + AUTH_LDAP_START_TLS = get_boolean_setting( + 'INVENTREE_LDAP_START_TLS', 'ldap.start_tls', False + ) + AUTH_LDAP_BIND_DN = get_setting('INVENTREE_LDAP_BIND_DN', 'ldap.bind_dn') + AUTH_LDAP_BIND_PASSWORD = get_setting( + 'INVENTREE_LDAP_BIND_PASSWORD', 'ldap.bind_password' + ) + AUTH_LDAP_USER_SEARCH = LDAPSearch( + get_setting('INVENTREE_LDAP_SEARCH_BASE_DN', 'ldap.search_base_dn'), + ldap.SCOPE_SUBTREE, + str( + get_setting( + 'INVENTREE_LDAP_SEARCH_FILTER_STR', + 'ldap.search_filter_str', + '(uid= %(user)s)', + ) + ), + ) + AUTH_LDAP_USER_DN_TEMPLATE = get_setting( + 'INVENTREE_LDAP_USER_DN_TEMPLATE', 'ldap.user_dn_template' + ) + AUTH_LDAP_USER_ATTR_MAP = get_setting( + 'INVENTREE_LDAP_USER_ATTR_MAP', + 'ldap.user_attr_map', + {'first_name': 'givenName', 'last_name': 'sn', 'email': 'mail'}, + dict, + ) + AUTH_LDAP_ALWAYS_UPDATE_USER = get_boolean_setting( + 'INVENTREE_LDAP_ALWAYS_UPDATE_USER', 'ldap.always_update_user', True + ) + AUTH_LDAP_CACHE_TIMEOUT = get_setting( + 'INVENTREE_LDAP_CACHE_TIMEOUT', 'ldap.cache_timeout', 3600, int ) - AUTH_LDAP_USER_DN_TEMPLATE = get_setting("INVENTREE_LDAP_USER_DN_TEMPLATE", "ldap.user_dn_template") - AUTH_LDAP_USER_ATTR_MAP = get_setting("INVENTREE_LDAP_USER_ATTR_MAP", "ldap.user_attr_map", { - 'first_name': 'givenName', - 'last_name': 'sn', - 'email': 'mail', - }, dict) - AUTH_LDAP_ALWAYS_UPDATE_USER = get_boolean_setting("INVENTREE_LDAP_ALWAYS_UPDATE_USER", "ldap.always_update_user", True) - AUTH_LDAP_CACHE_TIMEOUT = get_setting("INVENTREE_LDAP_CACHE_TIMEOUT", "ldap.cache_timeout", 3600, int) AUTH_LDAP_GROUP_SEARCH = LDAPSearch( - get_setting("INVENTREE_LDAP_GROUP_SEARCH", "ldap.group_search"), + get_setting('INVENTREE_LDAP_GROUP_SEARCH', 'ldap.group_search'), ldap.SCOPE_SUBTREE, - "(objectClass=groupOfUniqueNames)", + '(objectClass=groupOfUniqueNames)', + ) + AUTH_LDAP_GROUP_TYPE = GroupOfUniqueNamesType(name_attr='cn') + AUTH_LDAP_REQUIRE_GROUP = get_setting( + 'INVENTREE_LDAP_REQUIRE_GROUP', 'ldap.require_group' + ) + AUTH_LDAP_DENY_GROUP = get_setting('INVENTREE_LDAP_DENY_GROUP', 'ldap.deny_group') + AUTH_LDAP_USER_FLAGS_BY_GROUP = get_setting( + 'INVENTREE_LDAP_USER_FLAGS_BY_GROUP', 'ldap.user_flags_by_group', {}, dict ) - AUTH_LDAP_GROUP_TYPE = GroupOfUniqueNamesType(name_attr="cn") - AUTH_LDAP_REQUIRE_GROUP = get_setting("INVENTREE_LDAP_REQUIRE_GROUP", "ldap.require_group") - AUTH_LDAP_DENY_GROUP = get_setting("INVENTREE_LDAP_DENY_GROUP", "ldap.deny_group") - AUTH_LDAP_USER_FLAGS_BY_GROUP = get_setting("INVENTREE_LDAP_USER_FLAGS_BY_GROUP", "ldap.user_flags_by_group", {}, dict) AUTH_LDAP_FIND_GROUP_PERMS = True # Internal IP addresses allowed to see the debug toolbar -INTERNAL_IPS = [ - '127.0.0.1', -] +INTERNAL_IPS = ['127.0.0.1'] # Internal flag to determine if we are running in docker mode DOCKER = get_boolean_setting('INVENTREE_DOCKER', default_value=False) @@ -383,7 +386,10 @@ DOCKER = get_boolean_setting('INVENTREE_DOCKER', default_value=False) if DOCKER: # pragma: no cover # Internal IP addresses are different when running under docker hostname, ___, ips = socket.gethostbyname_ex(socket.gethostname()) - INTERNAL_IPS = [ip[: ip.rfind(".")] + ".1" for ip in ips] + ["127.0.0.1", "10.0.2.2"] + INTERNAL_IPS = [ip[: ip.rfind('.')] + '.1' for ip in ips] + [ + '127.0.0.1', + '10.0.2.2', + ] # Allow secure http developer server in debug mode if DEBUG: @@ -413,15 +419,18 @@ TEMPLATES = [ 'InvenTree.context.status_codes', 'InvenTree.context.user_roles', ], - 'loaders': [( - 'InvenTree.template.InvenTreeTemplateLoader', [ - 'plugin.template.PluginTemplateLoader', - 'django.template.loaders.filesystem.Loader', - 'django.template.loaders.app_directories.Loader', - ]) + 'loaders': [ + ( + 'InvenTree.template.InvenTreeTemplateLoader', + [ + 'plugin.template.PluginTemplateLoader', + 'django.template.loaders.filesystem.Loader', + 'django.template.loaders.app_directories.Loader', + ], + ) ], }, - }, + } ] REST_FRAMEWORK = { @@ -440,22 +449,24 @@ REST_FRAMEWORK = { ), 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', 'DEFAULT_METADATA_CLASS': 'InvenTree.metadata.InvenTreeMetadata', - 'DEFAULT_RENDERER_CLASSES': [ - 'rest_framework.renderers.JSONRenderer', - ], + 'DEFAULT_RENDERER_CLASSES': ['rest_framework.renderers.JSONRenderer'], 'TOKEN_MODEL': 'users.models.ApiToken', } if DEBUG: # Enable browsable API if in DEBUG mode - REST_FRAMEWORK['DEFAULT_RENDERER_CLASSES'].append('rest_framework.renderers.BrowsableAPIRenderer') + REST_FRAMEWORK['DEFAULT_RENDERER_CLASSES'].append( + 'rest_framework.renderers.BrowsableAPIRenderer' + ) # dj-rest-auth # JWT switch USE_JWT = get_boolean_setting('INVENTREE_USE_JWT', 'use_jwt', False) REST_USE_JWT = USE_JWT OLD_PASSWORD_FIELD_ENABLED = True -REST_AUTH_REGISTER_SERIALIZERS = {'REGISTER_SERIALIZER': 'InvenTree.forms.CustomRegisterSerializer'} +REST_AUTH_REGISTER_SERIALIZERS = { + 'REGISTER_SERIALIZER': 'InvenTree.forms.CustomRegisterSerializer' +} # JWT settings - rest_framework_simplejwt if USE_JWT: @@ -471,7 +482,10 @@ SPECTACULAR_SETTINGS = { 'TITLE': 'InvenTree API', 'DESCRIPTION': 'API for InvenTree - the intuitive open source inventory management system', 'LICENSE': {'MIT': 'https://github.com/inventree/InvenTree/blob/master/LICENSE'}, - 'EXTERNAL_DOCS': {'docs': 'https://docs.inventree.org', 'web': 'https://inventree.org'}, + 'EXTERNAL_DOCS': { + 'docs': 'https://docs.inventree.org', + 'web': 'https://inventree.org', + }, 'VERSION': inventreeApiVersion(), 'SERVE_INCLUDE_SCHEMA': False, } @@ -486,7 +500,7 @@ Configure the database backend based on the user-specified values. - The following code lets the user "mix and match" database configuration """ -logger.debug("Configuring database backend:") +logger.debug('Configuring database backend:') # Extract database configuration from the config.yaml file db_config = CONFIG.get('database', {}) @@ -500,7 +514,7 @@ db_keys = ['ENGINE', 'NAME', 'USER', 'PASSWORD', 'HOST', 'PORT'] for key in db_keys: # First, check the environment variables - env_key = f"INVENTREE_DB_{key}" + env_key = f'INVENTREE_DB_{key}' env_var = os.environ.get(env_key, None) if env_var: @@ -509,7 +523,7 @@ for key in db_keys: try: env_var = int(env_var) except ValueError: - logger.exception("Invalid number for %s: %s", env_key, env_var) + logger.exception('Invalid number for %s: %s', env_key, env_var) # Override configuration value db_config[key] = env_var @@ -550,9 +564,9 @@ if 'sqlite' in db_engine: db_name = str(Path(db_name).resolve()) db_config['NAME'] = db_name -logger.info("DB_ENGINE: %s", db_engine) -logger.info("DB_NAME: %s", db_name) -logger.info("DB_HOST: %s", db_host) +logger.info('DB_ENGINE: %s', db_engine) +logger.info('DB_NAME: %s', db_name) +logger.info('DB_HOST: %s', db_host) """ In addition to base-level database configuration, we may wish to specify specific options to the database backend @@ -565,19 +579,21 @@ Ref: https://docs.djangoproject.com/en/3.2/ref/settings/#std:setting-OPTIONS # connecting to the database server (such as a replica failover) don't sit and # wait for possibly an hour or more, just tell the client something went wrong # and let the client retry when they want to. -db_options = db_config.get("OPTIONS", db_config.get("options", {})) +db_options = db_config.get('OPTIONS', db_config.get('options', {})) # Specific options for postgres backend -if "postgres" in db_engine: # pragma: no cover - from psycopg2.extensions import (ISOLATION_LEVEL_READ_COMMITTED, - ISOLATION_LEVEL_SERIALIZABLE) +if 'postgres' in db_engine: # pragma: no cover + from psycopg2.extensions import ( + ISOLATION_LEVEL_READ_COMMITTED, + ISOLATION_LEVEL_SERIALIZABLE, + ) # Connection timeout - if "connect_timeout" not in db_options: + if 'connect_timeout' not in db_options: # The DB server is in the same data center, it should not take very # long to connect to the database server # # seconds, 2 is minimum allowed by libpq - db_options["connect_timeout"] = int( + db_options['connect_timeout'] = int( get_setting('INVENTREE_DB_TIMEOUT', 'database.timeout', 2) ) @@ -587,27 +603,37 @@ if "postgres" in db_engine: # pragma: no cover # issue to resolve itself. It it that doesn't happen whatever happened # is probably fatal and no amount of waiting is going to fix it. # # 0 - TCP Keepalives disabled; 1 - enabled - if "keepalives" not in db_options: - db_options["keepalives"] = int( + if 'keepalives' not in db_options: + db_options['keepalives'] = int( get_setting('INVENTREE_DB_TCP_KEEPALIVES', 'database.tcp_keepalives', 1) ) # Seconds after connection is idle to send keep alive - if "keepalives_idle" not in db_options: - db_options["keepalives_idle"] = int( - get_setting('INVENTREE_DB_TCP_KEEPALIVES_IDLE', 'database.tcp_keepalives_idle', 1) + if 'keepalives_idle' not in db_options: + db_options['keepalives_idle'] = int( + get_setting( + 'INVENTREE_DB_TCP_KEEPALIVES_IDLE', 'database.tcp_keepalives_idle', 1 + ) ) # Seconds after missing ACK to send another keep alive - if "keepalives_interval" not in db_options: - db_options["keepalives_interval"] = int( - get_setting("INVENTREE_DB_TCP_KEEPALIVES_INTERVAL", "database.tcp_keepalives_internal", "1") + if 'keepalives_interval' not in db_options: + db_options['keepalives_interval'] = int( + get_setting( + 'INVENTREE_DB_TCP_KEEPALIVES_INTERVAL', + 'database.tcp_keepalives_internal', + '1', + ) ) # Number of missing ACKs before we close the connection - if "keepalives_count" not in db_options: - db_options["keepalives_count"] = int( - get_setting("INVENTREE_DB_TCP_KEEPALIVES_COUNT", "database.tcp_keepalives_count", "5") + if 'keepalives_count' not in db_options: + db_options['keepalives_count'] = int( + get_setting( + 'INVENTREE_DB_TCP_KEEPALIVES_COUNT', + 'database.tcp_keepalives_count', + '5', + ) ) # # Milliseconds for how long pending data should remain unacked @@ -621,12 +647,18 @@ if "postgres" in db_engine: # pragma: no cover # protect against simultaneous changes. # https://www.postgresql.org/docs/devel/transaction-iso.html # https://docs.djangoproject.com/en/3.2/ref/databases/#isolation-level - if "isolation_level" not in db_options: - serializable = get_boolean_setting('INVENTREE_DB_ISOLATION_SERIALIZABLE', 'database.serializable', False) - db_options["isolation_level"] = ISOLATION_LEVEL_SERIALIZABLE if serializable else ISOLATION_LEVEL_READ_COMMITTED + if 'isolation_level' not in db_options: + serializable = get_boolean_setting( + 'INVENTREE_DB_ISOLATION_SERIALIZABLE', 'database.serializable', False + ) + db_options['isolation_level'] = ( + ISOLATION_LEVEL_SERIALIZABLE + if serializable + else ISOLATION_LEVEL_READ_COMMITTED + ) # Specific options for MySql / MariaDB backend -elif "mysql" in db_engine: # pragma: no cover +elif 'mysql' in db_engine: # pragma: no cover # TODO TCP time outs and keepalives # MariaDB's default isolation level is Repeatable Read which is @@ -635,12 +667,16 @@ elif "mysql" in db_engine: # pragma: no cover # protect against siumltaneous changes. # https://mariadb.com/kb/en/mariadb-transactions-and-isolation-levels-for-sql-server-users/#changing-the-isolation-level # https://docs.djangoproject.com/en/3.2/ref/databases/#mysql-isolation-level - if "isolation_level" not in db_options: - serializable = get_boolean_setting('INVENTREE_DB_ISOLATION_SERIALIZABLE', 'database.serializable', False) - db_options["isolation_level"] = "serializable" if serializable else "read committed" + if 'isolation_level' not in db_options: + serializable = get_boolean_setting( + 'INVENTREE_DB_ISOLATION_SERIALIZABLE', 'database.serializable', False + ) + db_options['isolation_level'] = ( + 'serializable' if serializable else 'read committed' + ) # Specific options for sqlite backend -elif "sqlite" in db_engine: +elif 'sqlite' in db_engine: # TODO: Verify timeouts are not an issue because no network is involved for SQLite # SQLite's default isolation level is Serializable due to SQLite's @@ -653,31 +689,34 @@ elif "sqlite" in db_engine: db_config['OPTIONS'] = db_options # Set testing options for the database -db_config['TEST'] = { - 'CHARSET': 'utf8', -} +db_config['TEST'] = {'CHARSET': 'utf8'} # Set collation option for mysql test database if 'mysql' in db_engine: db_config['TEST']['COLLATION'] = 'utf8_general_ci' # pragma: no cover -DATABASES = { - 'default': db_config -} +DATABASES = {'default': db_config} # login settings -REMOTE_LOGIN = get_boolean_setting('INVENTREE_REMOTE_LOGIN', 'remote_login_enabled', False) -REMOTE_LOGIN_HEADER = get_setting('INVENTREE_REMOTE_LOGIN_HEADER', 'remote_login_header', 'REMOTE_USER') +REMOTE_LOGIN = get_boolean_setting( + 'INVENTREE_REMOTE_LOGIN', 'remote_login_enabled', False +) +REMOTE_LOGIN_HEADER = get_setting( + 'INVENTREE_REMOTE_LOGIN_HEADER', 'remote_login_header', 'REMOTE_USER' +) # sentry.io integration for error reporting -SENTRY_ENABLED = get_boolean_setting('INVENTREE_SENTRY_ENABLED', 'sentry_enabled', False) +SENTRY_ENABLED = get_boolean_setting( + 'INVENTREE_SENTRY_ENABLED', 'sentry_enabled', False +) # Default Sentry DSN (can be overridden if user wants custom sentry integration) SENTRY_DSN = get_setting('INVENTREE_SENTRY_DSN', 'sentry_dsn', default_sentry_dsn()) -SENTRY_SAMPLE_RATE = float(get_setting('INVENTREE_SENTRY_SAMPLE_RATE', 'sentry_sample_rate', 0.1)) +SENTRY_SAMPLE_RATE = float( + get_setting('INVENTREE_SENTRY_SAMPLE_RATE', 'sentry_sample_rate', 0.1) +) if SENTRY_ENABLED and SENTRY_DSN: # pragma: no cover - inventree_tags = { 'testing': TESTING, 'docker': DOCKER, @@ -696,53 +735,47 @@ if cache_host: # pragma: no cover # so don't wait too long for the cache as nothing in the cache should be # irreplaceable. _cache_options = { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - "SOCKET_CONNECT_TIMEOUT": int(os.getenv("CACHE_CONNECT_TIMEOUT", "2")), - "SOCKET_TIMEOUT": int(os.getenv("CACHE_SOCKET_TIMEOUT", "2")), - "CONNECTION_POOL_KWARGS": { - "socket_keepalive": config.is_true( - os.getenv("CACHE_TCP_KEEPALIVE", "1") - ), - "socket_keepalive_options": { - socket.TCP_KEEPCNT: int( - os.getenv("CACHE_KEEPALIVES_COUNT", "5") - ), - socket.TCP_KEEPIDLE: int( - os.getenv("CACHE_KEEPALIVES_IDLE", "1") - ), - socket.TCP_KEEPINTVL: int( - os.getenv("CACHE_KEEPALIVES_INTERVAL", "1") - ), + 'CLIENT_CLASS': 'django_redis.client.DefaultClient', + 'SOCKET_CONNECT_TIMEOUT': int(os.getenv('CACHE_CONNECT_TIMEOUT', '2')), + 'SOCKET_TIMEOUT': int(os.getenv('CACHE_SOCKET_TIMEOUT', '2')), + 'CONNECTION_POOL_KWARGS': { + 'socket_keepalive': config.is_true(os.getenv('CACHE_TCP_KEEPALIVE', '1')), + 'socket_keepalive_options': { + socket.TCP_KEEPCNT: int(os.getenv('CACHE_KEEPALIVES_COUNT', '5')), + socket.TCP_KEEPIDLE: int(os.getenv('CACHE_KEEPALIVES_IDLE', '1')), + socket.TCP_KEEPINTVL: int(os.getenv('CACHE_KEEPALIVES_INTERVAL', '1')), socket.TCP_USER_TIMEOUT: int( - os.getenv("CACHE_TCP_USER_TIMEOUT", "1000") + os.getenv('CACHE_TCP_USER_TIMEOUT', '1000') ), }, }, } CACHES = { - "default": { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": f"redis://{cache_host}:{cache_port}/0", - "OPTIONS": _cache_options, - }, + 'default': { + 'BACKEND': 'django_redis.cache.RedisCache', + 'LOCATION': f'redis://{cache_host}:{cache_port}/0', + 'OPTIONS': _cache_options, + } } else: - CACHES = { - "default": { - "BACKEND": "django.core.cache.backends.locmem.LocMemCache", - }, - } + CACHES = {'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}} -_q_worker_timeout = int(get_setting('INVENTREE_BACKGROUND_TIMEOUT', 'background.timeout', 90)) +_q_worker_timeout = int( + get_setting('INVENTREE_BACKGROUND_TIMEOUT', 'background.timeout', 90) +) # django-q background worker configuration Q_CLUSTER = { 'name': 'InvenTree', 'label': 'Background Tasks', - 'workers': int(get_setting('INVENTREE_BACKGROUND_WORKERS', 'background.workers', 4)), + 'workers': int( + get_setting('INVENTREE_BACKGROUND_WORKERS', 'background.workers', 4) + ), 'timeout': _q_worker_timeout, 'retry': min(120, _q_worker_timeout + 30), - 'max_attempts': int(get_setting('INVENTREE_BACKGROUND_MAX_ATTEMPTS', 'background.max_attempts', 5)), + 'max_attempts': int( + get_setting('INVENTREE_BACKGROUND_MAX_ATTEMPTS', 'background.max_attempts', 5) + ), 'queue_limit': 50, 'catch_up': False, 'bulk': 10, @@ -754,40 +787,30 @@ Q_CLUSTER = { # Configure django-q sentry integration if SENTRY_ENABLED and SENTRY_DSN: - Q_CLUSTER['error_reporter'] = { - 'sentry': { - 'dsn': SENTRY_DSN - } - } + Q_CLUSTER['error_reporter'] = {'sentry': {'dsn': SENTRY_DSN}} if cache_host: # pragma: no cover # If using external redis cache, make the cache the broker for Django Q # as well - Q_CLUSTER["django_redis"] = "worker" + Q_CLUSTER['django_redis'] = 'worker' # database user sessions SESSION_ENGINE = 'user_sessions.backends.db' -LOGOUT_REDIRECT_URL = get_setting('INVENTREE_LOGOUT_REDIRECT_URL', 'logout_redirect_url', 'index') -SILENCED_SYSTEM_CHECKS = [ - 'admin.E410', -] +LOGOUT_REDIRECT_URL = get_setting( + 'INVENTREE_LOGOUT_REDIRECT_URL', 'logout_redirect_url', 'index' +) +SILENCED_SYSTEM_CHECKS = ['admin.E410'] # Password validation # https://docs.djangoproject.com/en/1.10/ref/settings/#auth-password-validators AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator' }, + {'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator'}, + {'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator'}, + {'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator'}, ] # Extra (optional) URL validators @@ -796,7 +819,7 @@ AUTH_PASSWORD_VALIDATORS = [ EXTRA_URL_SCHEMES = get_setting('INVENTREE_EXTRA_URL_SCHEMES', 'extra_url_schemes', []) if type(EXTRA_URL_SCHEMES) not in [list]: # pragma: no cover - logger.warning("extra_url_schemes not correctly formatted") + logger.warning('extra_url_schemes not correctly formatted') EXTRA_URL_SCHEMES = [] # Internationalization @@ -854,26 +877,21 @@ if get_boolean_setting('TEST_TRANSLATIONS', default_value=False): # pragma: no LANGUAGES.append(('xx', 'Test')) # Add custom languages not provided by Django - EXTRA_LANG_INFO = { - 'xx': { - 'code': 'xx', - 'name': 'Test', - 'name_local': 'Test' - }, - } + EXTRA_LANG_INFO = {'xx': {'code': 'xx', 'name': 'Test', 'name_local': 'Test'}} LANG_INFO = dict(django.conf.locale.LANG_INFO, **EXTRA_LANG_INFO) django.conf.locale.LANG_INFO = LANG_INFO # Currencies available for use CURRENCIES = get_setting( - 'INVENTREE_CURRENCIES', 'currencies', + 'INVENTREE_CURRENCIES', + 'currencies', ['AUD', 'CAD', 'CNY', 'EUR', 'GBP', 'JPY', 'NZD', 'USD'], typecast=list, ) # Ensure that at least one currency value is available if len(CURRENCIES) == 0: # pragma: no cover - logger.warning("No currencies selected: Defaulting to USD") + logger.warning('No currencies selected: Defaulting to USD') CURRENCIES = ['USD'] # Maximum number of decimal places for currency rendering @@ -889,12 +907,18 @@ for currency in CURRENCIES: EXCHANGE_BACKEND = 'InvenTree.exchange.InvenTreeExchange' # Email configuration options -EMAIL_BACKEND = get_setting('INVENTREE_EMAIL_BACKEND', 'email.backend', 'django.core.mail.backends.smtp.EmailBackend') +EMAIL_BACKEND = get_setting( + 'INVENTREE_EMAIL_BACKEND', + 'email.backend', + 'django.core.mail.backends.smtp.EmailBackend', +) EMAIL_HOST = get_setting('INVENTREE_EMAIL_HOST', 'email.host', '') EMAIL_PORT = get_setting('INVENTREE_EMAIL_PORT', 'email.port', 25, typecast=int) EMAIL_HOST_USER = get_setting('INVENTREE_EMAIL_USERNAME', 'email.username', '') EMAIL_HOST_PASSWORD = get_setting('INVENTREE_EMAIL_PASSWORD', 'email.password', '') -EMAIL_SUBJECT_PREFIX = get_setting('INVENTREE_EMAIL_PREFIX', 'email.prefix', '[InvenTree] ') +EMAIL_SUBJECT_PREFIX = get_setting( + 'INVENTREE_EMAIL_PREFIX', 'email.prefix', '[InvenTree] ' +) EMAIL_USE_TLS = get_boolean_setting('INVENTREE_EMAIL_TLS', 'email.tls', False) EMAIL_USE_SSL = get_boolean_setting('INVENTREE_EMAIL_SSL', 'email.ssl', False) @@ -907,9 +931,7 @@ if not DEFAULT_FROM_EMAIL: EMAIL_USE_LOCALTIME = False EMAIL_TIMEOUT = 60 -LOCALE_PATHS = ( - BASE_DIR.joinpath('locale/'), -) +LOCALE_PATHS = (BASE_DIR.joinpath('locale/'),) TIME_ZONE = get_setting('INVENTREE_TIMEZONE', 'timezone', 'UTC') @@ -922,9 +944,7 @@ USE_L10N = True if not TESTING: USE_TZ = True # pragma: no cover -DATE_INPUT_FORMATS = [ - "%Y-%m-%d", -] +DATE_INPUT_FORMATS = ['%Y-%m-%d'] # crispy forms use the bootstrap templates CRISPY_TEMPLATE_PACK = 'bootstrap4' @@ -935,10 +955,11 @@ IMPORT_EXPORT_USE_TRANSACTIONS = True SITE_ID = 1 # Load the allauth social backends -SOCIAL_BACKENDS = get_setting('INVENTREE_SOCIAL_BACKENDS', 'social_backends', [], typecast=list) +SOCIAL_BACKENDS = get_setting( + 'INVENTREE_SOCIAL_BACKENDS', 'social_backends', [], typecast=list +) for app in SOCIAL_BACKENDS: - # Ensure that the app starts with 'allauth.socialaccount.providers' social_prefix = 'allauth.socialaccount.providers.' @@ -947,14 +968,22 @@ for app in SOCIAL_BACKENDS: INSTALLED_APPS.append(app) # pragma: no cover -SOCIALACCOUNT_PROVIDERS = get_setting('INVENTREE_SOCIAL_PROVIDERS', 'social_providers', None, typecast=dict) +SOCIALACCOUNT_PROVIDERS = get_setting( + 'INVENTREE_SOCIAL_PROVIDERS', 'social_providers', None, typecast=dict +) SOCIALACCOUNT_STORE_TOKENS = True # settings for allauth -ACCOUNT_EMAIL_CONFIRMATION_EXPIRE_DAYS = get_setting('INVENTREE_LOGIN_CONFIRM_DAYS', 'login_confirm_days', 3, typecast=int) -ACCOUNT_LOGIN_ATTEMPTS_LIMIT = get_setting('INVENTREE_LOGIN_ATTEMPTS', 'login_attempts', 5, typecast=int) -ACCOUNT_DEFAULT_HTTP_PROTOCOL = get_setting('INVENTREE_LOGIN_DEFAULT_HTTP_PROTOCOL', 'login_default_protocol', 'http') +ACCOUNT_EMAIL_CONFIRMATION_EXPIRE_DAYS = get_setting( + 'INVENTREE_LOGIN_CONFIRM_DAYS', 'login_confirm_days', 3, typecast=int +) +ACCOUNT_LOGIN_ATTEMPTS_LIMIT = get_setting( + 'INVENTREE_LOGIN_ATTEMPTS', 'login_attempts', 5, typecast=int +) +ACCOUNT_DEFAULT_HTTP_PROTOCOL = get_setting( + 'INVENTREE_LOGIN_DEFAULT_HTTP_PROTOCOL', 'login_default_protocol', 'http' +) ACCOUNT_LOGOUT_ON_PASSWORD_CHANGE = True ACCOUNT_PREVENT_ENUMERATION = True # 2FA @@ -981,21 +1010,17 @@ ACCOUNT_ADAPTER = 'InvenTree.forms.CustomAccountAdapter' MARKDOWNIFY = { 'default': { 'BLEACH': True, - 'WHITELIST_ATTRS': [ - 'href', - 'src', - 'alt', - ], - 'MARKDOWN_EXTENSIONS': [ - 'markdown.extensions.extra' - ], + 'WHITELIST_ATTRS': ['href', 'src', 'alt'], + 'MARKDOWN_EXTENSIONS': ['markdown.extensions.extra'], 'WHITELIST_TAGS': [ 'a', 'abbr', 'b', 'blockquote', 'em', - 'h1', 'h2', 'h3', + 'h1', + 'h2', + 'h3', 'i', 'img', 'li', @@ -1008,46 +1033,55 @@ MARKDOWNIFY = { 'tbody', 'th', 'tr', - 'td' + 'td', ], } } # Ignore these error typeps for in-database error logging -IGNORED_ERRORS = [ - Http404, - django.core.exceptions.PermissionDenied, -] +IGNORED_ERRORS = [Http404, django.core.exceptions.PermissionDenied] # Maintenance mode MAINTENANCE_MODE_RETRY_AFTER = 60 MAINTENANCE_MODE_STATE_BACKEND = 'maintenance_mode.backends.StaticStorageBackend' # Are plugins enabled? -PLUGINS_ENABLED = get_boolean_setting('INVENTREE_PLUGINS_ENABLED', 'plugins_enabled', False) +PLUGINS_ENABLED = get_boolean_setting( + 'INVENTREE_PLUGINS_ENABLED', 'plugins_enabled', False +) PLUGIN_FILE = config.get_plugin_file() # Plugin test settings -PLUGIN_TESTING = get_setting('INVENTREE_PLUGIN_TESTING', 'PLUGIN_TESTING', TESTING) # Are plugins being tested? -PLUGIN_TESTING_SETUP = get_setting('INVENTREE_PLUGIN_TESTING_SETUP', 'PLUGIN_TESTING_SETUP', False) # Load plugins from setup hooks in testing? -PLUGIN_TESTING_EVENTS = False # Flag if events are tested right now -PLUGIN_RETRY = get_setting('INVENTREE_PLUGIN_RETRY', 'PLUGIN_RETRY', 5) # How often should plugin loading be tried? -PLUGIN_FILE_CHECKED = False # Was the plugin file checked? +PLUGIN_TESTING = get_setting( + 'INVENTREE_PLUGIN_TESTING', 'PLUGIN_TESTING', TESTING +) # Are plugins being tested? +PLUGIN_TESTING_SETUP = get_setting( + 'INVENTREE_PLUGIN_TESTING_SETUP', 'PLUGIN_TESTING_SETUP', False +) # Load plugins from setup hooks in testing? +PLUGIN_TESTING_EVENTS = False # Flag if events are tested right now +PLUGIN_RETRY = get_setting( + 'INVENTREE_PLUGIN_RETRY', 'PLUGIN_RETRY', 5 +) # How often should plugin loading be tried? +PLUGIN_FILE_CHECKED = False # Was the plugin file checked? # Site URL can be specified statically, or via a run-time setting SITE_URL = get_setting('INVENTREE_SITE_URL', 'site_url', None) if SITE_URL: - logger.info("Site URL: %s", SITE_URL) + logger.info('Site URL: %s', SITE_URL) # Check that the site URL is valid validator = URLValidator() validator(SITE_URL) # User interface customization values -CUSTOM_LOGO = get_custom_file('INVENTREE_CUSTOM_LOGO', 'customize.logo', 'custom logo', lookup_media=True) -CUSTOM_SPLASH = get_custom_file('INVENTREE_CUSTOM_SPLASH', 'customize.splash', 'custom splash') +CUSTOM_LOGO = get_custom_file( + 'INVENTREE_CUSTOM_LOGO', 'customize.logo', 'custom logo', lookup_media=True +) +CUSTOM_SPLASH = get_custom_file( + 'INVENTREE_CUSTOM_SPLASH', 'customize.splash', 'custom splash' +) CUSTOMIZE = get_setting('INVENTREE_CUSTOMIZE', 'customize', {}) @@ -1056,7 +1090,7 @@ FRONTEND_SETTINGS = config.get_frontend_settings(debug=DEBUG) FRONTEND_URL_BASE = FRONTEND_SETTINGS.get('base_url', 'platform') if DEBUG: - logger.info("InvenTree running with DEBUG enabled") + logger.info('InvenTree running with DEBUG enabled') logger.info("MEDIA_ROOT: '%s'", MEDIA_ROOT) logger.info("STATIC_ROOT: '%s'", STATIC_ROOT) @@ -1068,7 +1102,7 @@ FLAGS = { {'condition': 'parameter', 'value': 'experimental='}, ], # Should experimental features be turned on? 'NEXT_GEN': [ - {'condition': 'parameter', 'value': 'ngen='}, + {'condition': 'parameter', 'value': 'ngen='} ], # Should next-gen features be turned on? } @@ -1076,12 +1110,12 @@ FLAGS = { CUSTOM_FLAGS = get_setting('INVENTREE_FLAGS', 'flags', None, typecast=dict) if CUSTOM_FLAGS: if not isinstance(CUSTOM_FLAGS, dict): - logger.error("Invalid custom flags, must be valid dict: %s", str(CUSTOM_FLAGS)) + logger.error('Invalid custom flags, must be valid dict: %s', str(CUSTOM_FLAGS)) else: - logger.info("Custom flags: %s", str(CUSTOM_FLAGS)) + logger.info('Custom flags: %s', str(CUSTOM_FLAGS)) FLAGS.update(CUSTOM_FLAGS) # Magic login django-sesame SESAME_MAX_AGE = 300 # LOGIN_REDIRECT_URL = f"/{FRONTEND_URL_BASE}/logged-in/" -LOGIN_REDIRECT_URL = "/index/" +LOGIN_REDIRECT_URL = '/index/' diff --git a/InvenTree/InvenTree/social_auth_urls.py b/InvenTree/InvenTree/social_auth_urls.py index 5eca51210e..e323b1892d 100644 --- a/InvenTree/InvenTree/social_auth_urls.py +++ b/InvenTree/InvenTree/social_auth_urls.py @@ -1,4 +1,5 @@ """API endpoints for social authentication with allauth.""" + import logging from importlib import import_module @@ -6,8 +7,7 @@ from django.urls import NoReverseMatch, include, path, reverse from allauth.account.models import EmailAddress from allauth.socialaccount import providers -from allauth.socialaccount.providers.oauth2.views import (OAuth2Adapter, - OAuth2LoginView) +from allauth.socialaccount.providers.oauth2.views import OAuth2Adapter, OAuth2LoginView from drf_spectacular.utils import OpenApiResponse, extend_schema from rest_framework.exceptions import NotFound from rest_framework.permissions import AllowAny, IsAuthenticated @@ -23,6 +23,7 @@ logger = logging.getLogger('inventree') class GenericOAuth2ApiLoginView(OAuth2LoginView): """Api view to login a user with a social account""" + def dispatch(self, request, *args, **kwargs): """Dispatch the regular login view directly.""" return self.login(request, *args, **kwargs) @@ -44,8 +45,16 @@ class GenericOAuth2ApiConnectView(GenericOAuth2ApiLoginView): def handle_oauth2(adapter: OAuth2Adapter): """Define urls for oauth2 endpoints.""" return [ - path('login/', GenericOAuth2ApiLoginView.adapter_view(adapter), name=f'{provider.id}_api_login'), - path('connect/', GenericOAuth2ApiConnectView.adapter_view(adapter), name=f'{provider.id}_api_connect'), + path( + 'login/', + GenericOAuth2ApiLoginView.adapter_view(adapter), + name=f'{provider.id}_api_login', + ), + path( + 'connect/', + GenericOAuth2ApiConnectView.adapter_view(adapter), + name=f'{provider.id}_api_connect', + ), ] @@ -64,15 +73,20 @@ social_auth_urlpatterns = [] provider_urlpatterns = [] for name, provider in providers.registry.provider_map.items(): - try: - prov_mod = import_module(provider.get_package() + ".views") + prov_mod = import_module(provider.get_package() + '.views') except ImportError: - logger.exception("Could not import authentication provider %s", name) + logger.exception('Could not import authentication provider %s', name) continue # Try to extract the adapter class - adapters = [cls for cls in prov_mod.__dict__.values() if isinstance(cls, type) and not cls == OAuth2Adapter and issubclass(cls, OAuth2Adapter)] + adapters = [ + cls + for cls in prov_mod.__dict__.values() + if isinstance(cls, type) + and not cls == OAuth2Adapter + and issubclass(cls, OAuth2Adapter) + ] # Get urls urls = [] @@ -80,10 +94,17 @@ for name, provider in providers.registry.provider_map.items(): urls = handle_oauth2(adapter=adapters[0]) else: if provider.id in legacy: - logger.warning('`%s` is not supported on platform UI. Use `%s` instead.', provider.id, legacy[provider.id]) + logger.warning( + '`%s` is not supported on platform UI. Use `%s` instead.', + provider.id, + legacy[provider.id], + ) continue else: - logger.error('Found handler that is not yet ready for platform UI: `%s`. Open an feature request on GitHub if you need it implemented.', provider.id) + logger.error( + 'Found handler that is not yet ready for platform UI: `%s`. Open an feature request on GitHub if you need it implemented.', + provider.id, + ) continue provider_urlpatterns += [path(f'{provider.id}/', include(urls))] @@ -93,6 +114,7 @@ social_auth_urlpatterns += provider_urlpatterns class SocialProviderListView(ListAPI): """List of available social providers.""" + permission_classes = (AllowAny,) def get(self, request, *args, **kwargs): @@ -102,21 +124,27 @@ class SocialProviderListView(ListAPI): provider_data = { 'id': provider.id, 'name': provider.name, - 'configured': False + 'configured': False, } try: - provider_data['login'] = request.build_absolute_uri(reverse(f'{provider.id}_api_login')) + provider_data['login'] = request.build_absolute_uri( + reverse(f'{provider.id}_api_login') + ) except NoReverseMatch: provider_data['login'] = None try: - provider_data['connect'] = request.build_absolute_uri(reverse(f'{provider.id}_api_connect')) + provider_data['connect'] = request.build_absolute_uri( + reverse(f'{provider.id}_api_connect') + ) except NoReverseMatch: provider_data['connect'] = None provider_data['configured'] = InvenTree.sso.check_provider(provider) - provider_data['display_name'] = InvenTree.sso.provider_display_name(provider) + provider_data['display_name'] = InvenTree.sso.provider_display_name( + provider + ) provider_list.append(provider_data) @@ -124,7 +152,7 @@ class SocialProviderListView(ListAPI): 'sso_enabled': InvenTree.sso.login_enabled(), 'sso_registration': InvenTree.sso.registration_enabled(), 'mfa_required': InvenTreeSetting.get_setting('LOGIN_ENFORCE_MFA'), - 'providers': provider_list + 'providers': provider_list, } return Response(data) @@ -151,6 +179,7 @@ class EmptyEmailAddressSerializer(InvenTreeModelSerializer): class EmailListView(ListCreateAPI): """List of registered email addresses for current users.""" + permission_classes = (IsAuthenticated,) serializer_class = EmailAddressSerializer @@ -161,12 +190,15 @@ class EmailListView(ListCreateAPI): class EmailActionMixin(CreateAPI): """Mixin to modify email addresses for current users.""" + serializer_class = EmptyEmailAddressSerializer permission_classes = (IsAuthenticated,) def get_queryset(self): """Filter queryset for current user.""" - return EmailAddress.objects.filter(user=self.request.user, pk=self.kwargs['pk']).first() + return EmailAddress.objects.filter( + user=self.request.user, pk=self.kwargs['pk'] + ).first() @extend_schema(responses={200: OpenApiResponse(response=EmailAddressSerializer)}) def post(self, request, *args, **kwargs): diff --git a/InvenTree/InvenTree/sso.py b/InvenTree/InvenTree/sso.py index 3fc126e3e4..d77a7dfc1b 100644 --- a/InvenTree/InvenTree/sso.py +++ b/InvenTree/InvenTree/sso.py @@ -1,6 +1,5 @@ """Helper functions for Single Sign On functionality""" - import logging from common.models import InvenTreeSetting @@ -49,7 +48,7 @@ def check_provider(provider, raise_error=False): if allauth.app_settings.SITES_ENABLED: # At least one matching site must be specified if not app.sites.exists(): - logger.error("SocialApp %s has no sites configured", app) + logger.error('SocialApp %s has no sites configured', app) return False # At this point, we assume that the provider is correctly configured diff --git a/InvenTree/InvenTree/status.py b/InvenTree/InvenTree/status.py index 767f091735..cd6d357928 100644 --- a/InvenTree/InvenTree/status.py +++ b/InvenTree/InvenTree/status.py @@ -13,7 +13,7 @@ from django_q.status import Stat import InvenTree.email import InvenTree.ready -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') def is_worker_running(**kwargs): @@ -33,9 +33,7 @@ def is_worker_running(**kwargs): now = timezone.now() past = now - timedelta(minutes=10) - results = Success.objects.filter( - started__gte=past - ) + results = Success.objects.filter(started__gte=past) # If any results are returned, then the background worker is running! try: @@ -65,13 +63,13 @@ def check_system_health(**kwargs): if not is_worker_running(**kwargs): # pragma: no cover result = False - logger.warning(_("Background worker check failed")) + logger.warning(_('Background worker check failed')) if not InvenTree.email.is_email_configured(): # pragma: no cover result = False - logger.warning(_("Email backend not configured")) + logger.warning(_('Email backend not configured')) if not result: # pragma: no cover - logger.warning(_("InvenTree system health checks failed")) + logger.warning(_('InvenTree system health checks failed')) return result diff --git a/InvenTree/InvenTree/status_codes.py b/InvenTree/InvenTree/status_codes.py index 956773a9e6..65a02fb4b2 100644 --- a/InvenTree/InvenTree/status_codes.py +++ b/InvenTree/InvenTree/status_codes.py @@ -9,68 +9,68 @@ class PurchaseOrderStatus(StatusCode): """Defines a set of status codes for a PurchaseOrder.""" # Order status codes - PENDING = 10, _("Pending"), 'secondary' # Order is pending (not yet placed) - PLACED = 20, _("Placed"), 'primary' # Order has been placed with supplier - COMPLETE = 30, _("Complete"), 'success' # Order has been completed - CANCELLED = 40, _("Cancelled"), 'danger' # Order was cancelled - LOST = 50, _("Lost"), 'warning' # Order was lost - RETURNED = 60, _("Returned"), 'warning' # Order was returned + PENDING = 10, _('Pending'), 'secondary' # Order is pending (not yet placed) + PLACED = 20, _('Placed'), 'primary' # Order has been placed with supplier + COMPLETE = 30, _('Complete'), 'success' # Order has been completed + CANCELLED = 40, _('Cancelled'), 'danger' # Order was cancelled + LOST = 50, _('Lost'), 'warning' # Order was lost + RETURNED = 60, _('Returned'), 'warning' # Order was returned class PurchaseOrderStatusGroups: """Groups for PurchaseOrderStatus codes.""" # Open orders - OPEN = [ - PurchaseOrderStatus.PENDING.value, - PurchaseOrderStatus.PLACED.value, - ] + OPEN = [PurchaseOrderStatus.PENDING.value, PurchaseOrderStatus.PLACED.value] # Failed orders FAILED = [ PurchaseOrderStatus.CANCELLED.value, PurchaseOrderStatus.LOST.value, - PurchaseOrderStatus.RETURNED.value + PurchaseOrderStatus.RETURNED.value, ] class SalesOrderStatus(StatusCode): """Defines a set of status codes for a SalesOrder.""" - PENDING = 10, _("Pending"), 'secondary' # Order is pending - IN_PROGRESS = 15, _("In Progress"), 'primary' # Order has been issued, and is in progress - SHIPPED = 20, _("Shipped"), 'success' # Order has been shipped to customer - CANCELLED = 40, _("Cancelled"), 'danger' # Order has been cancelled - LOST = 50, _("Lost"), 'warning' # Order was lost - RETURNED = 60, _("Returned"), 'warning' # Order was returned + PENDING = 10, _('Pending'), 'secondary' # Order is pending + IN_PROGRESS = ( + 15, + _('In Progress'), + 'primary', + ) # Order has been issued, and is in progress + SHIPPED = 20, _('Shipped'), 'success' # Order has been shipped to customer + CANCELLED = 40, _('Cancelled'), 'danger' # Order has been cancelled + LOST = 50, _('Lost'), 'warning' # Order was lost + RETURNED = 60, _('Returned'), 'warning' # Order was returned class SalesOrderStatusGroups: """Groups for SalesOrderStatus codes.""" # Open orders - OPEN = [ - SalesOrderStatus.PENDING.value, - SalesOrderStatus.IN_PROGRESS.value, - ] + OPEN = [SalesOrderStatus.PENDING.value, SalesOrderStatus.IN_PROGRESS.value] # Completed orders - COMPLETE = [ - SalesOrderStatus.SHIPPED.value, - ] + COMPLETE = [SalesOrderStatus.SHIPPED.value] class StockStatus(StatusCode): """Status codes for Stock.""" - OK = 10, _("OK"), 'success' # Item is OK - ATTENTION = 50, _("Attention needed"), 'warning' # Item requires attention - DAMAGED = 55, _("Damaged"), 'warning' # Item is damaged - DESTROYED = 60, _("Destroyed"), 'danger' # Item is destroyed - REJECTED = 65, _("Rejected"), 'danger' # Item is rejected - LOST = 70, _("Lost"), 'dark' # Item has been lost - QUARANTINED = 75, _("Quarantined"), 'info' # Item has been quarantined and is unavailable - RETURNED = 85, _("Returned"), 'warning' # Item has been returned from a customer + OK = 10, _('OK'), 'success' # Item is OK + ATTENTION = 50, _('Attention needed'), 'warning' # Item requires attention + DAMAGED = 55, _('Damaged'), 'warning' # Item is damaged + DESTROYED = 60, _('Destroyed'), 'danger' # Item is destroyed + REJECTED = 65, _('Rejected'), 'danger' # Item is rejected + LOST = 70, _('Lost'), 'dark' # Item has been lost + QUARANTINED = ( + 75, + _('Quarantined'), + 'info', + ) # Item has been quarantined and is unavailable + RETURNED = 85, _('Returned'), 'warning' # Item has been returned from a customer class StockStatusGroups: @@ -129,7 +129,7 @@ class StockHistoryCode(StatusCode): BUILD_CONSUMED = 57, _('Consumed by build order') # Sales order codes - SHIPPED_AGAINST_SALES_ORDER = 60, _("Shipped against Sales Order") + SHIPPED_AGAINST_SALES_ORDER = 60, _('Shipped against Sales Order') # Purchase order codes RECEIVED_AGAINST_PURCHASE_ORDER = 70, _('Received against Purchase Order') @@ -145,59 +145,53 @@ class StockHistoryCode(StatusCode): class BuildStatus(StatusCode): """Build status codes.""" - PENDING = 10, _("Pending"), 'secondary' # Build is pending / active - PRODUCTION = 20, _("Production"), 'primary' # BuildOrder is in production - CANCELLED = 30, _("Cancelled"), 'danger' # Build was cancelled - COMPLETE = 40, _("Complete"), 'success' # Build is complete + PENDING = 10, _('Pending'), 'secondary' # Build is pending / active + PRODUCTION = 20, _('Production'), 'primary' # BuildOrder is in production + CANCELLED = 30, _('Cancelled'), 'danger' # Build was cancelled + COMPLETE = 40, _('Complete'), 'success' # Build is complete class BuildStatusGroups: """Groups for BuildStatus codes.""" - ACTIVE_CODES = [ - BuildStatus.PENDING.value, - BuildStatus.PRODUCTION.value, - ] + ACTIVE_CODES = [BuildStatus.PENDING.value, BuildStatus.PRODUCTION.value] class ReturnOrderStatus(StatusCode): """Defines a set of status codes for a ReturnOrder""" # Order is pending, waiting for receipt of items - PENDING = 10, _("Pending"), 'secondary' + PENDING = 10, _('Pending'), 'secondary' # Items have been received, and are being inspected - IN_PROGRESS = 20, _("In Progress"), 'primary' + IN_PROGRESS = 20, _('In Progress'), 'primary' - COMPLETE = 30, _("Complete"), 'success' - CANCELLED = 40, _("Cancelled"), 'danger' + COMPLETE = 30, _('Complete'), 'success' + CANCELLED = 40, _('Cancelled'), 'danger' class ReturnOrderStatusGroups: """Groups for ReturnOrderStatus codes.""" - OPEN = [ - ReturnOrderStatus.PENDING.value, - ReturnOrderStatus.IN_PROGRESS.value, - ] + OPEN = [ReturnOrderStatus.PENDING.value, ReturnOrderStatus.IN_PROGRESS.value] class ReturnOrderLineStatus(StatusCode): """Defines a set of status codes for a ReturnOrderLineItem""" - PENDING = 10, _("Pending"), 'secondary' + PENDING = 10, _('Pending'), 'secondary' # Item is to be returned to customer, no other action - RETURN = 20, _("Return"), 'success' + RETURN = 20, _('Return'), 'success' # Item is to be repaired, and returned to customer - REPAIR = 30, _("Repair"), 'primary' + REPAIR = 30, _('Repair'), 'primary' # Item is to be replaced (new item shipped) - REPLACE = 40, _("Replace"), 'warning' + REPLACE = 40, _('Replace'), 'warning' # Item is to be refunded (cannot be repaired) - REFUND = 50, _("Refund"), 'info' + REFUND = 50, _('Refund'), 'info' # Item is rejected - REJECT = 60, _("Reject"), 'danger' + REJECT = 60, _('Reject'), 'danger' diff --git a/InvenTree/InvenTree/tasks.py b/InvenTree/InvenTree/tasks.py index 3bbc36be12..7f67ab9f63 100644 --- a/InvenTree/InvenTree/tasks.py +++ b/InvenTree/InvenTree/tasks.py @@ -16,20 +16,22 @@ from django.core.exceptions import AppRegistryNotReady from django.core.management import call_command from django.db import DEFAULT_DB_ALIAS, connections from django.db.migrations.executor import MigrationExecutor -from django.db.utils import (NotSupportedError, OperationalError, - ProgrammingError) +from django.db.utils import NotSupportedError, OperationalError, ProgrammingError from django.utils import timezone import requests -from maintenance_mode.core import (get_maintenance_mode, maintenance_mode_on, - set_maintenance_mode) +from maintenance_mode.core import ( + get_maintenance_mode, + maintenance_mode_on, + set_maintenance_mode, +) from InvenTree.config import get_setting from plugin import registry from .version import isInvenTreeUpToDate -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') def schedule_task(taskname, **kwargs): @@ -44,7 +46,7 @@ def schedule_task(taskname, **kwargs): try: from django_q.models import Schedule except AppRegistryNotReady: # pragma: no cover - logger.info("Could not start background tasks - App registry not ready") + logger.info('Could not start background tasks - App registry not ready') return try: @@ -57,11 +59,7 @@ def schedule_task(taskname, **kwargs): else: logger.info("Creating scheduled task '%s'", taskname) - Schedule.objects.create( - name=taskname, - func=taskname, - **kwargs - ) + Schedule.objects.create(name=taskname, func=taskname, **kwargs) except (OperationalError, ProgrammingError): # pragma: no cover # Required if the DB is not ready yet pass @@ -96,7 +94,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool: from InvenTree.ready import isInTestMode if n_days <= 0: - logger.info("Specified interval for task '%s' < 1 - task will not run", task_name) + logger.info( + "Specified interval for task '%s' < 1 - task will not run", task_name + ) return False # Sleep a random number of seconds to prevent worker conflict @@ -119,7 +119,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool: threshold = datetime.now() - timedelta(days=n_days) if last_success > threshold: - logger.info("Last successful run for '%s' was too recent - skipping task", task_name) + logger.info( + "Last successful run for '%s' was too recent - skipping task", task_name + ) return False # Check for any information we have about this task @@ -136,7 +138,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool: threshold = datetime.now() - timedelta(hours=12) if last_attempt > threshold: - logger.info("Last attempt for '%s' was too recent - skipping task", task_name) + logger.info( + "Last attempt for '%s' was too recent - skipping task", task_name + ) return False # Record this attempt @@ -152,17 +156,23 @@ def record_task_attempt(task_name: str): logger.info("Logging task attempt for '%s'", task_name) - InvenTreeSetting.set_setting(f'_{task_name}_ATTEMPT', datetime.now().isoformat(), None) + InvenTreeSetting.set_setting( + f'_{task_name}_ATTEMPT', datetime.now().isoformat(), None + ) def record_task_success(task_name: str): """Record that a multi-day task was successful *now*""" from common.models import InvenTreeSetting - InvenTreeSetting.set_setting(f'_{task_name}_SUCCESS', datetime.now().isoformat(), None) + InvenTreeSetting.set_setting( + f'_{task_name}_SUCCESS', datetime.now().isoformat(), None + ) -def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs) -> bool: +def offload_task( + taskname, *args, force_async=False, force_sync=False, **kwargs +) -> bool: """Create an AsyncTask if workers are running. This is different to a 'scheduled' task, in that it only runs once! If workers are not running or force_sync flag, is set then the task is ran synchronously. @@ -205,7 +215,6 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs) raise_warning(f"WARNING: '{taskname}' not offloaded due to {str(exc)}") return False else: - if callable(taskname): # function was passed - use that _func = taskname @@ -215,14 +224,18 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs) app, mod, func = taskname.split('.') app_mod = app + '.' + mod except ValueError: - raise_warning(f"WARNING: '{taskname}' not started - Malformed function path") + raise_warning( + f"WARNING: '{taskname}' not started - Malformed function path" + ) return False # Import module from app try: _mod = importlib.import_module(app_mod) except ModuleNotFoundError: - raise_warning(f"WARNING: '{taskname}' not started - No module named '{app_mod}'") + raise_warning( + f"WARNING: '{taskname}' not started - No module named '{app_mod}'" + ) return False # Retrieve function @@ -236,7 +249,9 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs) if not _func: _func = eval(func) # pragma: no cover except NameError: - raise_warning(f"WARNING: '{taskname}' not started - No function named '{func}'") + raise_warning( + f"WARNING: '{taskname}' not started - No function named '{func}'" + ) return False # Workers are not running: run it as synchronous task @@ -263,18 +278,19 @@ class ScheduledTask: interval: str minutes: int = None - MINUTES = "I" - HOURLY = "H" - DAILY = "D" - WEEKLY = "W" - MONTHLY = "M" - QUARTERLY = "Q" - YEARLY = "Y" + MINUTES = 'I' + HOURLY = 'H' + DAILY = 'D' + WEEKLY = 'W' + MONTHLY = 'M' + QUARTERLY = 'Q' + YEARLY = 'Y' TYPE = [MINUTES, HOURLY, DAILY, WEEKLY, MONTHLY, QUARTERLY, YEARLY] class TaskRegister: """Registry for periodic tasks.""" + task_list: List[ScheduledTask] = [] def register(self, task, schedule, minutes: int = None): @@ -320,6 +336,7 @@ def scheduled_task(interval: str, minutes: int = None, tasklist: TaskRegister = _tasks.register(admin_class, interval, minutes=minutes) return admin_class + return _task_wrapper @@ -332,7 +349,7 @@ def heartbeat(): try: from django_q.models import Success except AppRegistryNotReady: # pragma: no cover - logger.info("Could not perform heartbeat task - App registry not ready") + logger.info('Could not perform heartbeat task - App registry not ready') return threshold = timezone.now() - timedelta(minutes=30) @@ -340,8 +357,7 @@ def heartbeat(): # Delete heartbeat results more than half an hour old, # otherwise they just create extra noise heartbeats = Success.objects.filter( - func='InvenTree.tasks.heartbeat', - started__lte=threshold + func='InvenTree.tasks.heartbeat', started__lte=threshold ) heartbeats.delete() @@ -359,16 +375,16 @@ def delete_successful_tasks(): threshold = timezone.now() - timedelta(days=days) # Delete successful tasks - results = Success.objects.filter( - started__lte=threshold - ) + results = Success.objects.filter(started__lte=threshold) if results.count() > 0: - logger.info("Deleting %s successful task records", results.count()) + logger.info('Deleting %s successful task records', results.count()) results.delete() except AppRegistryNotReady: # pragma: no cover - logger.info("Could not perform 'delete_successful_tasks' - App registry not ready") + logger.info( + "Could not perform 'delete_successful_tasks' - App registry not ready" + ) @scheduled_task(ScheduledTask.DAILY) @@ -383,12 +399,10 @@ def delete_failed_tasks(): threshold = timezone.now() - timedelta(days=days) # Delete failed tasks - results = Failure.objects.filter( - started__lte=threshold - ) + results = Failure.objects.filter(started__lte=threshold) if results.count() > 0: - logger.info("Deleting %s failed task records", results.count()) + logger.info('Deleting %s failed task records', results.count()) results.delete() except AppRegistryNotReady: # pragma: no cover @@ -406,47 +420,48 @@ def delete_old_error_logs(): days = InvenTreeSetting.get_setting('INVENTREE_DELETE_ERRORS_DAYS', 30) threshold = timezone.now() - timedelta(days=days) - errors = Error.objects.filter( - when__lte=threshold, - ) + errors = Error.objects.filter(when__lte=threshold) if errors.count() > 0: - logger.info("Deleting %s old error logs", errors.count()) + logger.info('Deleting %s old error logs', errors.count()) errors.delete() except AppRegistryNotReady: # pragma: no cover # Apps not yet loaded - logger.info("Could not perform 'delete_old_error_logs' - App registry not ready") + logger.info( + "Could not perform 'delete_old_error_logs' - App registry not ready" + ) @scheduled_task(ScheduledTask.DAILY) def delete_old_notifications(): """Delete old notification logs""" try: - from common.models import (InvenTreeSetting, NotificationEntry, - NotificationMessage) + from common.models import ( + InvenTreeSetting, + NotificationEntry, + NotificationMessage, + ) days = InvenTreeSetting.get_setting('INVENTREE_DELETE_NOTIFICATIONS_DAYS', 30) threshold = timezone.now() - timedelta(days=days) - items = NotificationEntry.objects.filter( - updated__lte=threshold - ) + items = NotificationEntry.objects.filter(updated__lte=threshold) if items.count() > 0: - logger.info("Deleted %s old notification entries", items.count()) + logger.info('Deleted %s old notification entries', items.count()) items.delete() - items = NotificationMessage.objects.filter( - creation__lte=threshold - ) + items = NotificationMessage.objects.filter(creation__lte=threshold) if items.count() > 0: - logger.info("Deleted %s old notification messages", items.count()) + logger.info('Deleted %s old notification messages', items.count()) items.delete() except AppRegistryNotReady: - logger.info("Could not perform 'delete_old_notifications' - App registry not ready") + logger.info( + "Could not perform 'delete_old_notifications' - App registry not ready" + ) @scheduled_task(ScheduledTask.DAILY) @@ -460,13 +475,17 @@ def check_for_updates(): logger.info("Could not perform 'check_for_updates' - App registry not ready") return - interval = int(common.models.InvenTreeSetting.get_setting('INVENTREE_UPDATE_CHECK_INTERVAL', 7, cache=False)) + interval = int( + common.models.InvenTreeSetting.get_setting( + 'INVENTREE_UPDATE_CHECK_INTERVAL', 7, cache=False + ) + ) # Check if we should check for updates *today* if not check_daily_holdoff('check_for_updates', interval): return - logger.info("Checking for InvenTree software updates") + logger.info('Checking for InvenTree software updates') headers = {} @@ -475,15 +494,17 @@ def check_for_updates(): token = os.getenv('GITHUB_TOKEN', None) if token: - headers['Authorization'] = f"Bearer {token}" + headers['Authorization'] = f'Bearer {token}' response = requests.get( 'https://api.github.com/repos/inventree/inventree/releases/latest', - headers=headers + headers=headers, ) if response.status_code != 200: - raise ValueError(f'Unexpected status code from GitHub API: {response.status_code}') # pragma: no cover + raise ValueError( + f'Unexpected status code from GitHub API: {response.status_code}' + ) # pragma: no cover data = json.loads(response.text) @@ -492,7 +513,7 @@ def check_for_updates(): if not tag: raise ValueError("'tag_name' missing from GitHub response") # pragma: no cover - match = re.match(r"^.*(\d+)\.(\d+)\.(\d+).*$", tag) + match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', tag) if len(match.groups()) != 3: # pragma: no cover logger.warning("Version '%s' did not match expected pattern", tag) @@ -506,31 +527,26 @@ def check_for_updates(): logger.info("Latest InvenTree version: '%s'", tag) # Save the version to the database - common.models.InvenTreeSetting.set_setting( - '_INVENTREE_LATEST_VERSION', - tag, - None - ) + common.models.InvenTreeSetting.set_setting('_INVENTREE_LATEST_VERSION', tag, None) # Record that this task was successful record_task_success('check_for_updates') # Send notification if there is a new version if not isInvenTreeUpToDate(): - logger.warning("InvenTree is not up-to-date, sending notification") + logger.warning('InvenTree is not up-to-date, sending notification') plg = registry.get_plugin('InvenTreeCoreNotificationsPlugin') if not plg: - logger.warning("Cannot send notification - plugin not found") + logger.warning('Cannot send notification - plugin not found') return plg = plg.plugin_config() if not plg: - logger.warning("Cannot send notification - plugin config not found") + logger.warning('Cannot send notification - plugin config not found') return # Send notification trigger_superuser_notification( - plg, - f'An update for InvenTree to version {tag} is available', + plg, f'An update for InvenTree to version {tag} is available' ) @@ -549,17 +565,21 @@ def update_exchange_rates(force: bool = False): from InvenTree.exchange import InvenTreeExchange except AppRegistryNotReady: # pragma: no cover # Apps not yet loaded! - logger.info("Could not perform 'update_exchange_rates' - App registry not ready") + logger.info( + "Could not perform 'update_exchange_rates' - App registry not ready" + ) return except Exception as exc: # pragma: no cover logger.info("Could not perform 'update_exchange_rates' - %s", exc) return if not force: - interval = int(InvenTreeSetting.get_setting('CURRENCY_UPDATE_INTERVAL', 1, cache=False)) + interval = int( + InvenTreeSetting.get_setting('CURRENCY_UPDATE_INTERVAL', 1, cache=False) + ) if not check_daily_holdoff('update_exchange_rates', interval): - logger.info("Skipping exchange rate update (interval not reached)") + logger.info('Skipping exchange rate update (interval not reached)') return backend = InvenTreeExchange() @@ -570,15 +590,17 @@ def update_exchange_rates(force: bool = False): backend.update_rates(base_currency=base) # Remove any exchange rates which are not in the provided currencies - Rate.objects.filter(backend="InvenTreeExchange").exclude(currency__in=currency_codes()).delete() + Rate.objects.filter(backend='InvenTreeExchange').exclude( + currency__in=currency_codes() + ).delete() # Record successful task execution record_task_success('update_exchange_rates') except (AppRegistryNotReady, OperationalError, ProgrammingError): - logger.warning("Could not update exchange rates - database not ready") + logger.warning('Could not update exchange rates - database not ready') except Exception as e: # pragma: no cover - logger.exception("Error updating exchange rates: %s", str(type(e))) + logger.exception('Error updating exchange rates: %s', str(type(e))) @scheduled_task(ScheduledTask.DAILY) @@ -590,16 +612,20 @@ def run_backup(): # Backups are not enabled - exit early return - interval = int(InvenTreeSetting.get_setting('INVENTREE_BACKUP_DAYS', 1, cache=False)) + interval = int( + InvenTreeSetting.get_setting('INVENTREE_BACKUP_DAYS', 1, cache=False) + ) # Check if should run this task *today* if not check_daily_holdoff('run_backup', interval): return - logger.info("Performing automated database backup task") + logger.info('Performing automated database backup task') - call_command("dbbackup", noinput=True, clean=True, compress=True, interactive=False) - call_command("mediabackup", noinput=True, clean=True, compress=True, interactive=False) + call_command('dbbackup', noinput=True, clean=True, compress=True, interactive=False) + call_command( + 'mediabackup', noinput=True, clean=True, compress=True, interactive=False + ) # Record that this task was successful record_task_success('run_backup') @@ -627,7 +653,7 @@ def check_for_migrations(): logger.info('There are %s pending migrations', n) InvenTreeSetting.set_setting('_PENDING_MIGRATIONS', n, None) - logger.info("Checking for pending database migrations") + logger.info('Checking for pending database migrations') # Force plugin registry reload registry.check_reload() @@ -645,12 +671,12 @@ def check_for_migrations(): # Test if auto-updates are enabled if not get_setting('INVENTREE_AUTO_UPDATE', 'auto_update'): - logger.info("Auto-update is disabled - skipping migrations") + logger.info('Auto-update is disabled - skipping migrations') return # Log open migrations for migration in plan: - logger.info("- %s", str(migration[0])) + logger.info('- %s', str(migration[0])) # Set the application to maintenance mode - no access from now on. set_maintenance_mode(True) @@ -668,13 +694,13 @@ def check_for_migrations(): else: set_pending_migrations(0) - logger.info("Completed %s migrations", n) + logger.info('Completed %s migrations', n) # Make sure we are out of maintenance mode if get_maintenance_mode(): - logger.warning("Maintenance mode was not disabled - forcing it now") + logger.warning('Maintenance mode was not disabled - forcing it now') set_maintenance_mode(False) - logger.info("Manually released maintenance mode") + logger.info('Manually released maintenance mode') # We should be current now - triggering full reload to make sure all models # are loaded fully in their new state. diff --git a/InvenTree/InvenTree/test_api.py b/InvenTree/InvenTree/test_api.py index aa7907d926..72ea009853 100644 --- a/InvenTree/InvenTree/test_api.py +++ b/InvenTree/InvenTree/test_api.py @@ -17,6 +17,7 @@ class HTMLAPITests(InvenTreeTestCase): which raised an AssertionError when using the HTML API interface, while the regular JSON interface continued to work as expected. """ + roles = 'all' def test_part_api(self): @@ -60,12 +61,7 @@ class HTMLAPITests(InvenTreeTestCase): class APITests(InvenTreeAPITestCase): """Tests for the InvenTree API.""" - fixtures = [ - 'location', - 'category', - 'part', - 'stock' - ] + fixtures = ['location', 'category', 'part', 'stock'] token = None auto_login = False @@ -73,11 +69,11 @@ class APITests(InvenTreeAPITestCase): """Helper function to use basic auth.""" # Use basic authentication - authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii") + authstring = bytes('{u}:{p}'.format(u=self.username, p=self.password), 'ascii') # Use "basic" auth by default - auth = b64encode(authstring).decode("ascii") - self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth)) + auth = b64encode(authstring).decode('ascii') + self.client.credentials(HTTP_AUTHORIZATION='Basic {auth}'.format(auth=auth)) def tokenAuth(self): """Helper function to use token auth.""" @@ -270,33 +266,20 @@ class BulkDeleteTests(InvenTreeAPITestCase): url = reverse('api-stock-test-result-list') # DELETE without any of the required fields - response = self.delete( - url, - {}, - expected_code=400 - ) + response = self.delete(url, {}, expected_code=400) - self.assertIn('List of items or filters must be provided for bulk deletion', str(response.data)) + self.assertIn( + 'List of items or filters must be provided for bulk deletion', + str(response.data), + ) # DELETE with invalid 'items' - response = self.delete( - url, - { - 'items': {"hello": "world"}, - }, - expected_code=400, - ) + response = self.delete(url, {'items': {'hello': 'world'}}, expected_code=400) self.assertIn("'items' must be supplied as a list object", str(response.data)) # DELETE with invalid 'filters' - response = self.delete( - url, - { - 'filters': [1, 2, 3], - }, - expected_code=400, - ) + response = self.delete(url, {'filters': [1, 2, 3]}, expected_code=400) self.assertIn("'filters' must be supplied as a dict object", str(response.data)) @@ -317,11 +300,7 @@ class SearchTests(InvenTreeAPITestCase): def test_empty(self): """Test empty request""" - data = [ - '', - None, - {}, - ] + data = ['', None, {}] for d in data: response = self.post(reverse('api-search'), d, expected_code=400) @@ -331,13 +310,8 @@ class SearchTests(InvenTreeAPITestCase): """Test individual result types""" response = self.post( reverse('api-search'), - { - 'search': 'chair', - 'limit': 3, - 'part': {}, - 'build': {}, - }, - expected_code=200 + {'search': 'chair', 'limit': 3, 'part': {}, 'build': {}}, + expected_code=200, ) # No build results @@ -354,12 +328,7 @@ class SearchTests(InvenTreeAPITestCase): # Search for orders response = self.post( reverse('api-search'), - { - 'search': '01', - 'limit': 2, - 'purchaseorder': {}, - 'salesorder': {}, - }, + {'search': '01', 'limit': 2, 'purchaseorder': {}, 'salesorder': {}}, expected_code=200, ) @@ -392,33 +361,25 @@ class SearchTests(InvenTreeAPITestCase): 'salesorder', ] - query = { - 'search': 'c', - 'limit': 3, - } + query = {'search': 'c', 'limit': 3} for mdl in models: query[mdl] = {} - response = self.post( - reverse('api-search'), - query, - expected_code=200 - ) + response = self.post(reverse('api-search'), query, expected_code=200) # Check for 'permission denied' error for mdl in models: - self.assertEqual(response.data[mdl]['error'], 'User does not have permission to view this model') + self.assertEqual( + response.data[mdl]['error'], + 'User does not have permission to view this model', + ) # Assign view roles for some parts self.assignRole('build.view') self.assignRole('part.view') - response = self.post( - reverse('api-search'), - query, - expected_code=200 - ) + response = self.post(reverse('api-search'), query, expected_code=200) # Check for expected results, based on permissions # We expect results to be returned for the following model types @@ -438,4 +399,6 @@ class SearchTests(InvenTreeAPITestCase): self.assertIn('count', result) else: self.assertIn('error', result) - self.assertEqual(result['error'], 'User does not have permission to view this model') + self.assertEqual( + result['error'], 'User does not have permission to view this model' + ) diff --git a/InvenTree/InvenTree/test_api_version.py b/InvenTree/InvenTree/test_api_version.py index d6a2c38fb6..4b2a3b49b5 100644 --- a/InvenTree/InvenTree/test_api_version.py +++ b/InvenTree/InvenTree/test_api_version.py @@ -1,6 +1,5 @@ """Tests for api_version.""" - from django.urls import reverse from InvenTree.api_version import INVENTREE_API_VERSION diff --git a/InvenTree/InvenTree/test_middleware.py b/InvenTree/InvenTree/test_middleware.py index f6928a798b..f750a339b2 100644 --- a/InvenTree/InvenTree/test_middleware.py +++ b/InvenTree/InvenTree/test_middleware.py @@ -66,6 +66,7 @@ class MiddlewareTests(InvenTreeTestCase): def test_error_exceptions(self): """Test that ignored errors are not logged.""" + def check(excpected_nbr=0): # Check that errors are empty errors = Error.objects.all() diff --git a/InvenTree/InvenTree/test_tasks.py b/InvenTree/InvenTree/test_tasks.py index cb14dfc962..eec5b39a69 100644 --- a/InvenTree/InvenTree/test_tasks.py +++ b/InvenTree/InvenTree/test_tasks.py @@ -67,15 +67,23 @@ class InvenTreeTaskTests(TestCase): # Error runs # Malformed taskname - with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTree' not started - Malformed function path"): + with self.assertWarnsMessage( + UserWarning, "WARNING: 'InvenTree' not started - Malformed function path" + ): InvenTree.tasks.offload_task('InvenTree') # Non existent app - with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTreeABC.test_tasks.doesnotmatter' not started - No module named 'InvenTreeABC.test_tasks'"): + with self.assertWarnsMessage( + UserWarning, + "WARNING: 'InvenTreeABC.test_tasks.doesnotmatter' not started - No module named 'InvenTreeABC.test_tasks'", + ): InvenTree.tasks.offload_task('InvenTreeABC.test_tasks.doesnotmatter') # Non existent function - with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTree.test_tasks.doesnotexist' not started - No function named 'doesnotexist'"): + with self.assertWarnsMessage( + UserWarning, + "WARNING: 'InvenTree.test_tasks.doesnotexist' not started - No function named 'doesnotexist'", + ): InvenTree.tasks.offload_task('InvenTree.test_tasks.doesnotexist') def test_task_hearbeat(self): @@ -86,7 +94,9 @@ class InvenTreeTaskTests(TestCase): """Test the task delete_successful_tasks.""" from django_q.models import Success - Success.objects.create(name='abc', func='abc', stopped=threshold, started=threshold_low) + Success.objects.create( + name='abc', func='abc', stopped=threshold, started=threshold_low + ) InvenTree.tasks.offload_task(InvenTree.tasks.delete_successful_tasks) results = Success.objects.filter(started__lte=threshold) self.assertEqual(len(results), 0) @@ -99,14 +109,14 @@ class InvenTreeTaskTests(TestCase): error_obj.save() # Check that it is not empty - errors = Error.objects.filter(when__lte=threshold,) + errors = Error.objects.filter(when__lte=threshold) self.assertNotEqual(len(errors), 0) # Run action InvenTree.tasks.offload_task(InvenTree.tasks.delete_old_error_logs) # Check that it is empty again - errors = Error.objects.filter(when__lte=threshold,) + errors = Error.objects.filter(when__lte=threshold) self.assertEqual(len(errors), 0) def test_task_check_for_updates(self): @@ -146,7 +156,9 @@ class InvenTreeTaskTests(TestCase): # Cleanup try: migration_name = InvenTree.tasks.get_migration_plan()[0][0].name + '.py' - migration_path = settings.BASE_DIR / 'InvenTree' / 'migrations' / migration_name + migration_path = ( + settings.BASE_DIR / 'InvenTree' / 'migrations' / migration_name + ) migration_path.unlink() except IndexError: # pragma: no cover pass diff --git a/InvenTree/InvenTree/test_urls.py b/InvenTree/InvenTree/test_urls.py index 2f2db7e2a9..d15493a750 100644 --- a/InvenTree/InvenTree/test_urls.py +++ b/InvenTree/InvenTree/test_urls.py @@ -44,7 +44,7 @@ class URLTest(TestCase): ('part', 'templates'), ('report', 'templates'), ('stock', 'templates'), - ('templates', ), + ('templates',), ] template_files = [] @@ -53,11 +53,9 @@ class URLTest(TestCase): tld = os.path.join(here, '..') for directory in template_dirs: - template_dir = os.path.join(tld, *directory) for path in Path(template_dir).rglob(suffix): - f = os.path.abspath(path) if f not in template_files: @@ -69,25 +67,18 @@ class URLTest(TestCase): """Search for all instances of {% url %} in supplied template file.""" urls = [] - pattern = "{% url ['\"]([^'\"]+)['\"]([^%]*)%}" + pattern = '{% url [\'"]([^\'"]+)[\'"]([^%]*)%}' with open(input_file, 'r') as f: - data = f.read() results = re.findall(pattern, data) for result in results: if len(result) == 2: - urls.append([ - result[0].strip(), - result[1].strip() - ]) + urls.append([result[0].strip(), result[1].strip()]) elif len(result) == 1: # pragma: no cover - urls.append([ - result[0].strip(), - '' - ]) + urls.append([result[0].strip(), '']) return urls @@ -100,16 +91,16 @@ class URLTest(TestCase): pk = None # TODO: Handle reverse lookup of admin URLs! - if url.startswith("admin:"): + if url.startswith('admin:'): return # TODO can this be more elegant? - if url.startswith("account_"): + if url.startswith('account_'): return if pk: # We will assume that there is at least one item in the database - reverse(url, kwargs={"pk": 1}) + reverse(url, kwargs={'pk': 1}) else: reverse(url) @@ -122,14 +113,14 @@ class URLTest(TestCase): def test_html_templates(self): """Test all HTML templates for broken url tags.""" - template_files = self.find_files("*.html") + template_files = self.find_files('*.html') for f in template_files: self.check_file(f) def test_js_templates(self): """Test all JS templates for broken url tags.""" - template_files = self.find_files("*.js") + template_files = self.find_files('*.js') for f in template_files: self.check_file(f) diff --git a/InvenTree/InvenTree/test_views.py b/InvenTree/InvenTree/test_views.py index 708304c68d..48c8db6924 100644 --- a/InvenTree/InvenTree/test_views.py +++ b/InvenTree/InvenTree/test_views.py @@ -23,13 +23,13 @@ class ViewTests(InvenTreeTestCase): def test_index_redirect(self): """Top-level URL should redirect to "index" page.""" - response = self.client.get("/") + response = self.client.get('/') self.assertEqual(response.status_code, 302) def get_index_page(self): """Retrieve the index page (used for subsequent unit tests)""" - response = self.client.get("/index/") + response = self.client.get('/index/') self.assertEqual(response.status_code, 200) @@ -60,35 +60,20 @@ class ViewTests(InvenTreeTestCase): self.assertEqual(response.status_code, 200) content = response.content.decode() - user_panels = [ - 'account', - 'user-display', - 'user-home', - 'user-reports', - ] + user_panels = ['account', 'user-display', 'user-home', 'user-reports'] - staff_panels = [ - 'server', - 'login', - 'barcodes', - 'pricing', - 'parts', - 'stock', - ] + staff_panels = ['server', 'login', 'barcodes', 'pricing', 'parts', 'stock'] - plugin_panels = [ - 'plugin', - ] + plugin_panels = ['plugin'] # Default user has staff access, so all panels will be present for panel in user_panels + staff_panels + plugin_panels: - self.assertIn(f"select-{panel}", content) - self.assertIn(f"panel-{panel}", content) + self.assertIn(f'select-{panel}', content) + self.assertIn(f'panel-{panel}', content) # Now create a user who does not have staff access pleb_user = get_user_model().objects.create_user( - username='pleb', - password='notstaff', + username='pleb', password='notstaff' ) pleb_user.groups.add(self.group) @@ -98,10 +83,7 @@ class ViewTests(InvenTreeTestCase): self.client.logout() - result = self.client.login( - username='pleb', - password='notstaff', - ) + result = self.client.login(username='pleb', password='notstaff') self.assertTrue(result) @@ -111,22 +93,24 @@ class ViewTests(InvenTreeTestCase): # Normal user still has access to user-specific panels for panel in user_panels: - self.assertIn(f"select-{panel}", content) - self.assertIn(f"panel-{panel}", content) + self.assertIn(f'select-{panel}', content) + self.assertIn(f'panel-{panel}', content) # Normal user does NOT have access to global or plugin settings for panel in staff_panels + plugin_panels: - self.assertNotIn(f"select-{panel}", content) - self.assertNotIn(f"panel-{panel}", content) + self.assertNotIn(f'select-{panel}', content) + self.assertNotIn(f'panel-{panel}', content) def test_url_login(self): """Test logging in via arguments""" # Log out self.client.logout() - response = self.client.get("/index/") + response = self.client.get('/index/') self.assertEqual(response.status_code, 302) # Try login with url - response = self.client.get(f"/accounts/login/?next=/&login={self.username}&password={self.password}") + response = self.client.get( + f'/accounts/login/?next=/&login={self.username}&password={self.password}' + ) self.assertEqual(response.status_code, 302) self.assertEqual(response.url, '/') diff --git a/InvenTree/InvenTree/tests.py b/InvenTree/InvenTree/tests.py index 5c7fe65c02..2488c07685 100644 --- a/InvenTree/InvenTree/tests.py +++ b/InvenTree/InvenTree/tests.py @@ -45,12 +45,12 @@ class ConversionTest(TestCase): def test_prefixes(self): """Test inputs where prefixes are used""" tests = { - "3": 3, - "3m": 3, - "3mm": 0.003, - "3k": 3000, - "3u": 0.000003, - "3 inch": 0.0762, + '3': 3, + '3m': 3, + '3mm': 0.003, + '3k': 3000, + '3u': 0.000003, + '3 inch': 0.0762, } for val, expected in tests.items(): @@ -60,13 +60,13 @@ class ConversionTest(TestCase): def test_base_units(self): """Test conversion to specified base units""" tests = { - "3": 3, - "3 dozen": 36, - "50 dozen kW": 600000, - "1 / 10": 0.1, - "1/2 kW": 500, - "1/2 dozen kW": 6000, - "0.005 MW": 5000, + '3': 3, + '3 dozen': 36, + '50 dozen kW': 600000, + '1 / 10': 0.1, + '1/2 kW': 500, + '1/2 dozen kW': 6000, + '0.005 MW': 5000, } for val, expected in tests.items(): @@ -104,13 +104,7 @@ class ConversionTest(TestCase): def test_invalid_units(self): """Test conversion with bad units""" - tests = { - '3': '10', - '13': '-?-', - '-3': 'xyz', - '-12': '-12', - '1/0': '1/0', - } + tests = {'3': '10', '13': '-?-', '-3': 'xyz', '-12': '-12', '1/0': '1/0'} for val, unit in tests.items(): with self.assertRaises(ValidationError): @@ -118,12 +112,7 @@ class ConversionTest(TestCase): def test_invalid_values(self): """Test conversion of invalid inputs""" - inputs = [ - '-x', - '1/0', - 'xyz', - '12B45C' - ] + inputs = ['-x', '1/0', 'xyz', '12B45C'] for val in inputs: # Test with a provided unit @@ -148,9 +137,7 @@ class ConversionTest(TestCase): # Create a new custom unit CustomUnit.objects.create( - name='fanciful_unit', - definition='henry / mm', - symbol='hpmm', + name='fanciful_unit', definition='henry / mm', symbol='hpmm' ) # Reload registry @@ -171,7 +158,9 @@ class ConversionTest(TestCase): for val, expected in tests.items(): # Convert, and leave units - q = InvenTree.conversion.convert_physical_value(val, 'henry / km', strip_units=False) + q = InvenTree.conversion.convert_physical_value( + val, 'henry / km', strip_units=False + ) self.assertAlmostEqual(float(q.magnitude), expected, 2) # Convert and strip units @@ -184,24 +173,24 @@ class ValidatorTest(TestCase): def test_overage(self): """Test overage validator.""" - validate_overage("100%") - validate_overage("10") - validate_overage("45.2 %") + validate_overage('100%') + validate_overage('10') + validate_overage('45.2 %') with self.assertRaises(django_exceptions.ValidationError): - validate_overage("-1") + validate_overage('-1') with self.assertRaises(django_exceptions.ValidationError): - validate_overage("-2.04 %") + validate_overage('-2.04 %') with self.assertRaises(django_exceptions.ValidationError): - validate_overage("105%") + validate_overage('105%') with self.assertRaises(django_exceptions.ValidationError): - validate_overage("xxx %") + validate_overage('xxx %') with self.assertRaises(django_exceptions.ValidationError): - validate_overage("aaaa") + validate_overage('aaaa') def test_url_validation(self): """Test for AllowedURLValidator""" @@ -241,7 +230,7 @@ class FormatTest(TestCase): def test_parse(self): """Tests for the 'parse_format_string' function""" # Extract data from a valid format string - fmt = "PO-{abc:02f}-{ref:04d}-{date}-???" + fmt = 'PO-{abc:02f}-{ref:04d}-{date}-???' info = InvenTree.format.parse_format_string(fmt) @@ -250,22 +239,17 @@ class FormatTest(TestCase): self.assertIn('date', info) # Try with invalid strings - for fmt in [ - 'PO-{{xyz}', - 'PO-{xyz}}', - 'PO-{xyz}-{', - ]: - + for fmt in ['PO-{{xyz}', 'PO-{xyz}}', 'PO-{xyz}-{']: with self.assertRaises(ValueError): InvenTree.format.parse_format_string(fmt) def test_create_regex(self): """Test function for creating a regex from a format string""" tests = { - "PO-123-{ref:04f}": r"^PO\-123\-(?P.+)$", - "{PO}-???-{ref}-{date}-22": r"^(?P.+)\-...\-(?P.+)\-(?P.+)\-22$", - "ABC-123-###-{ref}": r"^ABC\-123\-\d\d\d\-(?P.+)$", - "ABC-123": r"^ABC\-123$", + 'PO-123-{ref:04f}': r'^PO\-123\-(?P.+)$', + '{PO}-???-{ref}-{date}-22': r'^(?P.+)\-...\-(?P.+)\-(?P.+)\-22$', + 'ABC-123-###-{ref}': r'^ABC\-123\-\d\d\d\-(?P.+)$', + 'ABC-123': r'^ABC\-123$', } for fmt, reg in tests.items(): @@ -275,55 +259,44 @@ class FormatTest(TestCase): """Test that string validation works as expected""" # These tests should pass for value, pattern in { - "ABC-hello-123": "???-{q}-###", - "BO-1234": "BO-{ref}", - "111.222.fred.china": "???.###.{name}.{place}", - "PO-1234": "PO-{ref:04d}" + 'ABC-hello-123': '???-{q}-###', + 'BO-1234': 'BO-{ref}', + '111.222.fred.china': '???.###.{name}.{place}', + 'PO-1234': 'PO-{ref:04d}', }.items(): self.assertTrue(InvenTree.format.validate_string(value, pattern)) # These tests should fail for value, pattern in { - "ABC-hello-123": "###-{q}-???", - "BO-1234": "BO.{ref}", - "BO-####": "BO-{pattern}-{next}", - "BO-123d": "BO-{ref:04d}" + 'ABC-hello-123': '###-{q}-???', + 'BO-1234': 'BO.{ref}', + 'BO-####': 'BO-{pattern}-{next}', + 'BO-123d': 'BO-{ref:04d}', }.items(): self.assertFalse(InvenTree.format.validate_string(value, pattern)) def test_extract_value(self): """Test that we can extract named values based on a format string""" # Simple tests based on a straight-forward format string - fmt = "PO-###-{ref:04d}" + fmt = 'PO-###-{ref:04d}' - tests = { - "123": "PO-123-123", - "456": "PO-123-456", - "789": "PO-123-789", - } + tests = {'123': 'PO-123-123', '456': 'PO-123-456', '789': 'PO-123-789'} for k, v in tests.items(): self.assertEqual(InvenTree.format.extract_named_group('ref', v, fmt), k) # However these ones should fail - tests = { - 'abc': 'PO-123-abc', - 'xyz': 'PO-123-xyz', - } + tests = {'abc': 'PO-123-abc', 'xyz': 'PO-123-xyz'} for v in tests.values(): with self.assertRaises(ValueError): InvenTree.format.extract_named_group('ref', v, fmt) # More complex tests - fmt = "PO-{date}-{test}-???-{ref}-###" - val = "PO-2022-02-01-hello-ABC-12345-222" + fmt = 'PO-{date}-{test}-???-{ref}-###' + val = 'PO-2022-02-01-hello-ABC-12345-222' - data = { - 'date': '2022-02-01', - 'test': 'hello', - 'ref': '12345', - } + data = {'date': '2022-02-01', 'test': 'hello', 'ref': '12345'} for k, v in data.items(): self.assertEqual(InvenTree.format.extract_named_group(k, val, fmt), v) @@ -332,60 +305,46 @@ class FormatTest(TestCase): # Raises a ValueError as the format string is bad with self.assertRaises(ValueError): - InvenTree.format.extract_named_group( - "test", - "PO-1234-5", - "PO-{test}-{" - ) + InvenTree.format.extract_named_group('test', 'PO-1234-5', 'PO-{test}-{') # Raises a NameError as the named group does not exist in the format string with self.assertRaises(NameError): - InvenTree.format.extract_named_group( - "missing", - "PO-12345", - "PO-{test}", - ) + InvenTree.format.extract_named_group('missing', 'PO-12345', 'PO-{test}') # Raises a ValueError as the value does not match the format string with self.assertRaises(ValueError): - InvenTree.format.extract_named_group( - "test", - "PO-1234", - "PO-{test}-1234", - ) + InvenTree.format.extract_named_group('test', 'PO-1234', 'PO-{test}-1234') with self.assertRaises(ValueError): - InvenTree.format.extract_named_group( - "test", - "PO-ABC-xyz", - "PO-###-{test}", - ) + InvenTree.format.extract_named_group('test', 'PO-ABC-xyz', 'PO-###-{test}') def test_currency_formatting(self): """Test that currency formatting works correctly for multiple currencies""" test_data = ( - (Money( 3651.285718, "USD"), 4, "$3,651.2857" ), # noqa: E201,E202 - (Money(487587.849178, "CAD"), 5, "CA$487,587.84918"), # noqa: E201,E202 - (Money( 0.348102, "EUR"), 1, "€0.3" ), # noqa: E201,E202 - (Money( 0.916530, "GBP"), 1, "£0.9" ), # noqa: E201,E202 - (Money( 61.031024, "JPY"), 3, "¥61.031" ), # noqa: E201,E202 - (Money( 49609.694602, "JPY"), 1, "¥49,609.7" ), # noqa: E201,E202 - (Money(155565.264777, "AUD"), 2, "A$155,565.26" ), # noqa: E201,E202 - (Money( 0.820437, "CNY"), 4, "CN¥0.8204" ), # noqa: E201,E202 - (Money( 7587.849178, "EUR"), 0, "€7,588" ), # noqa: E201,E202 - (Money( 0.348102, "GBP"), 3, "£0.348" ), # noqa: E201,E202 - (Money( 0.652923, "CHF"), 0, "CHF1" ), # noqa: E201,E202 - (Money( 0.820437, "CNY"), 1, "CN¥0.8" ), # noqa: E201,E202 - (Money(98789.5295680, "CHF"), 0, "CHF98,790" ), # noqa: E201,E202 - (Money( 0.585787, "USD"), 1, "$0.6" ), # noqa: E201,E202 - (Money( 0.690541, "CAD"), 3, "CA$0.691" ), # noqa: E201,E202 - (Money( 427.814104, "AUD"), 5, "A$427.81410" ), # noqa: E201,E202 + (Money(3651.285718, 'USD'), 4, '$3,651.2857'), # noqa: E201,E202 + (Money(487587.849178, 'CAD'), 5, 'CA$487,587.84918'), # noqa: E201,E202 + (Money(0.348102, 'EUR'), 1, '€0.3'), # noqa: E201,E202 + (Money(0.916530, 'GBP'), 1, '£0.9'), # noqa: E201,E202 + (Money(61.031024, 'JPY'), 3, '¥61.031'), # noqa: E201,E202 + (Money(49609.694602, 'JPY'), 1, '¥49,609.7'), # noqa: E201,E202 + (Money(155565.264777, 'AUD'), 2, 'A$155,565.26'), # noqa: E201,E202 + (Money(0.820437, 'CNY'), 4, 'CN¥0.8204'), # noqa: E201,E202 + (Money(7587.849178, 'EUR'), 0, '€7,588'), # noqa: E201,E202 + (Money(0.348102, 'GBP'), 3, '£0.348'), # noqa: E201,E202 + (Money(0.652923, 'CHF'), 0, 'CHF1'), # noqa: E201,E202 + (Money(0.820437, 'CNY'), 1, 'CN¥0.8'), # noqa: E201,E202 + (Money(98789.5295680, 'CHF'), 0, 'CHF98,790'), # noqa: E201,E202 + (Money(0.585787, 'USD'), 1, '$0.6'), # noqa: E201,E202 + (Money(0.690541, 'CAD'), 3, 'CA$0.691'), # noqa: E201,E202 + (Money(427.814104, 'AUD'), 5, 'A$427.81410'), # noqa: E201,E202 ) - with self.settings(LANGUAGE_CODE="en-us"): + with self.settings(LANGUAGE_CODE='en-us'): for value, decimal_places, expected_result in test_data: - result = InvenTree.format.format_money(value, decimal_places=decimal_places) + result = InvenTree.format.format_money( + value, decimal_places=decimal_places + ) assert result == expected_result @@ -394,30 +353,35 @@ class TestHelpers(TestCase): def test_absolute_url(self): """Test helper function for generating an absolute URL""" - base = "https://demo.inventree.org:12345" + base = 'https://demo.inventree.org:12345' InvenTreeSetting.set_setting('INVENTREE_BASE_URL', base, change_user=None) tests = { - "": base, - "api/": base + "/api/", - "/api/": base + "/api/", - "api": base + "/api", - "media/label/output/": base + "/media/label/output/", - "static/logo.png": base + "/static/logo.png", - "https://www.google.com": "https://www.google.com", - "https://demo.inventree.org:12345/out.html": "https://demo.inventree.org:12345/out.html", - "https://demo.inventree.org/test.html": "https://demo.inventree.org/test.html", - "http://www.cwi.nl:80/%7Eguido/Python.html": "http://www.cwi.nl:80/%7Eguido/Python.html", - "test.org": base + "/test.org", + '': base, + 'api/': base + '/api/', + '/api/': base + '/api/', + 'api': base + '/api', + 'media/label/output/': base + '/media/label/output/', + 'static/logo.png': base + '/static/logo.png', + 'https://www.google.com': 'https://www.google.com', + 'https://demo.inventree.org:12345/out.html': 'https://demo.inventree.org:12345/out.html', + 'https://demo.inventree.org/test.html': 'https://demo.inventree.org/test.html', + 'http://www.cwi.nl:80/%7Eguido/Python.html': 'http://www.cwi.nl:80/%7Eguido/Python.html', + 'test.org': base + '/test.org', } for url, expected in tests.items(): # Test with supplied base URL - self.assertEqual(InvenTree.helpers_model.construct_absolute_url(url, site_url=base), expected) + self.assertEqual( + InvenTree.helpers_model.construct_absolute_url(url, site_url=base), + expected, + ) # Test without supplied base URL - self.assertEqual(InvenTree.helpers_model.construct_absolute_url(url), expected) + self.assertEqual( + InvenTree.helpers_model.construct_absolute_url(url), expected + ) def test_image_url(self): """Test if a filename looks like an image.""" @@ -453,7 +417,9 @@ class TestHelpers(TestCase): """Test static url helpers.""" self.assertEqual(helpers.getStaticUrl('test.jpg'), '/static/test.jpg') self.assertEqual(helpers.getBlankImage(), '/static/img/blank_image.png') - self.assertEqual(helpers.getBlankThumbnail(), '/static/img/blank_image.thumbnail.png') + self.assertEqual( + helpers.getBlankThumbnail(), '/static/img/blank_image.thumbnail.png' + ) def testMediaUrl(self): """Test getMediaUrl.""" @@ -476,12 +442,7 @@ class TestHelpers(TestCase): def test_download_image(self): """Test function for downloading image from remote URL""" # Run check with a sequence of bad URLs - for url in [ - "blog", - "htp://test.com/?", - "google", - "\\invalid-url" - ]: + for url in ['blog', 'htp://test.com/?', 'google', '\\invalid-url']: with self.assertRaises(django_exceptions.ValidationError): InvenTree.helpers_model.download_image_from_url(url) @@ -496,16 +457,17 @@ class TestHelpers(TestCase): with self.assertRaises(expected_error): while tries < retries: - try: - InvenTree.helpers_model.download_image_from_url(url, timeout=timeout) + InvenTree.helpers_model.download_image_from_url( + url, timeout=timeout + ) break except Exception as exc: if type(exc) is expected_error: # Re-throw this error raise exc else: - print("Unexpected error:", type(exc), exc) + print('Unexpected error:', type(exc), exc) tries += 1 time.sleep(10 * tries) @@ -518,16 +480,20 @@ class TestHelpers(TestCase): # TODO: Re-implement this test when we are happier with the external service # dl_helper("https://httpstat.us/200?sleep=5000", requests.exceptions.ReadTimeout, timeout=1) - large_img = "https://github.com/inventree/InvenTree/raw/master/InvenTree/InvenTree/static/img/paper_splash_large.jpg" + large_img = 'https://github.com/inventree/InvenTree/raw/master/InvenTree/InvenTree/static/img/paper_splash_large.jpg' - InvenTreeSetting.set_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 1, change_user=None) + InvenTreeSetting.set_setting( + 'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 1, change_user=None + ) # Attempt to download an image which is too large with self.assertRaises(ValueError): InvenTree.helpers_model.download_image_from_url(large_img, timeout=10) # Increase allowable download size - InvenTreeSetting.set_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 5, change_user=None) + InvenTreeSetting.set_setting( + 'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 5, change_user=None + ) # Download a valid image (should not throw an error) InvenTree.helpers_model.download_image_from_url(large_img, timeout=10) @@ -561,14 +527,14 @@ class TestIncrement(TestCase): def tests(self): """Test 'intelligent' incrementing function.""" tests = [ - ("", '1'), - (1, "2"), - ("001", "002"), - ("1001", "1002"), - ("ABC123", "ABC124"), - ("XYZ0", "XYZ1"), - ("123Q", "123Q"), - ("QQQ", "QQQ"), + ('', '1'), + (1, '2'), + ('001', '002'), + ('1001', '1002'), + ('ABC123', 'ABC124'), + ('XYZ0', 'XYZ1'), + ('123Q', '123Q'), + ('QQQ', 'QQQ'), ] for test in tests: @@ -584,13 +550,7 @@ class TestMakeBarcode(TestCase): def test_barcode_extended(self): """Test creation of barcode with extended data.""" bc = helpers.MakeBarcode( - "part", - 3, - { - "id": 3, - "url": "www.google.com", - }, - brief=False + 'part', 3, {'id': 3, 'url': 'www.google.com'}, brief=False ) self.assertIn('part', bc) @@ -604,10 +564,7 @@ class TestMakeBarcode(TestCase): def test_barcode_brief(self): """Test creation of simple barcode.""" - bc = helpers.MakeBarcode( - "stockitem", - 7, - ) + bc = helpers.MakeBarcode('stockitem', 7) data = json.loads(bc) self.assertEqual(len(data), 1) @@ -619,16 +576,14 @@ class TestDownloadFile(TestCase): def test_download(self): """Tests for DownloadFile.""" - helpers.DownloadFile("hello world", "out.txt") - helpers.DownloadFile(bytes(b"hello world"), "out.bin") + helpers.DownloadFile('hello world', 'out.txt') + helpers.DownloadFile(bytes(b'hello world'), 'out.bin') class TestMPTT(TestCase): """Tests for the MPTT tree models.""" - fixtures = [ - 'location', - ] + fixtures = ['location'] @classmethod def setUpTestData(cls): @@ -681,62 +636,62 @@ class TestSerialNumberExtraction(TestCase): e = helpers.extract_serial_numbers # Test a range of numbers - sn = e("1-5", 5, 1) + sn = e('1-5', 5, 1) self.assertEqual(len(sn), 5) for i in range(1, 6): self.assertIn(str(i), sn) - sn = e("11-30", 20, 1) + sn = e('11-30', 20, 1) self.assertEqual(len(sn), 20) - sn = e("1, 2, 3, 4, 5", 5, 1) + sn = e('1, 2, 3, 4, 5', 5, 1) self.assertEqual(len(sn), 5) # Test partially specifying serials - sn = e("1, 2, 4+", 5, 1) + sn = e('1, 2, 4+', 5, 1) self.assertEqual(len(sn), 5) self.assertEqual(sn, ['1', '2', '4', '5', '6']) # Test groups are not interpolated if enough serials are supplied - sn = e("1, 2, 3, AF5-69H, 5", 5, 1) + sn = e('1, 2, 3, AF5-69H, 5', 5, 1) self.assertEqual(len(sn), 5) self.assertEqual(sn, ['1', '2', '3', 'AF5-69H', '5']) # Test groups are not interpolated with more than one hyphen in a word - sn = e("1, 2, TG-4SR-92, 4+", 5, 1) + sn = e('1, 2, TG-4SR-92, 4+', 5, 1) self.assertEqual(len(sn), 5) - self.assertEqual(sn, ['1', '2', "TG-4SR-92", '4', '5']) + self.assertEqual(sn, ['1', '2', 'TG-4SR-92', '4', '5']) # Test multiple placeholders - sn = e("1 2 ~ ~ ~", 5, 2) + sn = e('1 2 ~ ~ ~', 5, 2) self.assertEqual(len(sn), 5) self.assertEqual(sn, ['1', '2', '3', '4', '5']) - sn = e("1-5, 10-15", 11, 1) + sn = e('1-5, 10-15', 11, 1) self.assertIn('3', sn) self.assertIn('13', sn) - sn = e("1+", 10, 1) + sn = e('1+', 10, 1) self.assertEqual(len(sn), 10) self.assertEqual(sn, [str(_) for _ in range(1, 11)]) - sn = e("4, 1+2", 4, 1) + sn = e('4, 1+2', 4, 1) self.assertEqual(len(sn), 4) self.assertEqual(sn, ['4', '1', '2', '3']) - sn = e("~", 1, 1) + sn = e('~', 1, 1) self.assertEqual(len(sn), 1) self.assertEqual(sn, ['2']) - sn = e("~", 1, 3) + sn = e('~', 1, 3) self.assertEqual(len(sn), 1) self.assertEqual(sn, ['4']) - sn = e("~+", 2, 4) + sn = e('~+', 2, 4) self.assertEqual(len(sn), 2) self.assertEqual(sn, ['5', '6']) - sn = e("~+3", 4, 4) + sn = e('~+3', 4, 4) self.assertEqual(len(sn), 4) self.assertEqual(sn, ['5', '6', '7', '8']) @@ -746,70 +701,70 @@ class TestSerialNumberExtraction(TestCase): # Test duplicates with self.assertRaises(ValidationError): - e("1,2,3,3,3", 5, 1) + e('1,2,3,3,3', 5, 1) # Test invalid length with self.assertRaises(ValidationError): - e("1,2,3", 5, 1) + e('1,2,3', 5, 1) # Test empty string with self.assertRaises(ValidationError): - e(", , ,", 0, 1) + e(', , ,', 0, 1) # Test incorrect sign in group with self.assertRaises(ValidationError): - e("10-2", 8, 1) + e('10-2', 8, 1) # Test invalid group with self.assertRaises(ValidationError): - e("1-5-10", 10, 1) + e('1-5-10', 10, 1) with self.assertRaises(ValidationError): - e("10, a, 7-70j", 4, 1) + e('10, a, 7-70j', 4, 1) # Test groups are not interpolated with word characters with self.assertRaises(ValidationError): - e("1, 2, 3, E-5", 5, 1) + e('1, 2, 3, E-5', 5, 1) # Extract a range of values with a smaller range with self.assertRaises(ValidationError) as exc: - e("11-50", 10, 1) + e('11-50', 10, 1) self.assertIn('Range quantity exceeds 10', str(exc)) # Test groups are not interpolated with alpha characters with self.assertRaises(ValidationError) as exc: - e("1, A-2, 3+", 5, 1) + e('1, A-2, 3+', 5, 1) self.assertIn('Invalid group range: A-2', str(exc)) def test_combinations(self): """Test complex serial number combinations.""" e = helpers.extract_serial_numbers - sn = e("1 3-5 9+2", 7, 1) + sn = e('1 3-5 9+2', 7, 1) self.assertEqual(len(sn), 7) self.assertEqual(sn, ['1', '3', '4', '5', '9', '10', '11']) - sn = e("1,3-5,9+2", 7, 1) + sn = e('1,3-5,9+2', 7, 1) self.assertEqual(len(sn), 7) self.assertEqual(sn, ['1', '3', '4', '5', '9', '10', '11']) - sn = e("~+2", 3, 13) + sn = e('~+2', 3, 13) self.assertEqual(len(sn), 3) self.assertEqual(sn, ['14', '15', '16']) - sn = e("~+", 2, 13) + sn = e('~+', 2, 13) self.assertEqual(len(sn), 2) self.assertEqual(sn, ['14', '15']) # Test multiple increment groups - sn = e("~+4, 20+4, 30+4", 15, 10) + sn = e('~+4, 20+4, 30+4', 15, 10) self.assertEqual(len(sn), 15) for v in [14, 24, 34]: self.assertIn(str(v), sn) # Test multiple range groups - sn = e("11-20, 41-50, 91-100", 30, 1) + sn = e('11-20, 41-50, 91-100', 30, 1) self.assertEqual(len(sn), 30) for v in range(11, 21): @@ -848,7 +803,7 @@ class TestVersionNumber(TestCase): """Test that the git commit information is extracted successfully.""" envs = { 'INVENTREE_COMMIT_HASH': 'abcdef', - 'INVENTREE_COMMIT_DATE': '2022-12-31' + 'INVENTREE_COMMIT_DATE': '2022-12-31', } # Check that the environment variables take priority @@ -861,10 +816,16 @@ class TestVersionNumber(TestCase): # Check that the current .git values work too - hash = str(subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8').strip() + hash = str( + subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8' + ).strip() self.assertEqual(hash, version.inventreeCommitHash()) - d = str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8').strip().split(' ')[0] + d = ( + str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8') + .strip() + .split(' ')[0] + ) self.assertEqual(d, version.inventreeCommitDate()) @@ -898,7 +859,7 @@ class CurrencyTests(TestCase): break else: # pragma: no cover - print("Exchange rate update failed - retrying") + print('Exchange rate update failed - retrying') print(f'Expected {currency_codes()}, got {[a.currency for a in rates]}') time.sleep(1) @@ -975,16 +936,14 @@ class TestSettings(InvenTreeTestCase): self.assertEqual(user_count(), 1) # not enough set - self.run_reload({ - 'INVENTREE_ADMIN_USER': 'admin' - }) + self.run_reload({'INVENTREE_ADMIN_USER': 'admin'}) self.assertEqual(user_count(), 1) # enough set self.run_reload({ 'INVENTREE_ADMIN_USER': 'admin', # set username 'INVENTREE_ADMIN_EMAIL': 'info@example.com', # set email - 'INVENTREE_ADMIN_PASSWORD': 'password123' # set password + 'INVENTREE_ADMIN_PASSWORD': 'password123', # set password }) self.assertEqual(user_count(), 2) @@ -1027,27 +986,27 @@ class TestSettings(InvenTreeTestCase): """Test get_config_file.""" # normal run - not configured - valid = [ - 'inventree/config.yaml', - 'inventree/data/config.yaml', - ] + valid = ['inventree/config.yaml', 'inventree/data/config.yaml'] - self.assertTrue(any(opt in str(config.get_config_file()).lower() for opt in valid)) + self.assertTrue( + any(opt in str(config.get_config_file()).lower() for opt in valid) + ) # with env set with self.in_env_context({'INVENTREE_CONFIG_FILE': 'my_special_conf.yaml'}): - self.assertIn('inventree/my_special_conf.yaml', str(config.get_config_file()).lower()) + self.assertIn( + 'inventree/my_special_conf.yaml', str(config.get_config_file()).lower() + ) def test_helpers_plugin_file(self): """Test get_plugin_file.""" # normal run - not configured - valid = [ - 'inventree/plugins.txt', - 'inventree/data/plugins.txt', - ] + valid = ['inventree/plugins.txt', 'inventree/data/plugins.txt'] - self.assertTrue(any(opt in str(config.get_plugin_file()).lower() for opt in valid)) + self.assertTrue( + any(opt in str(config.get_plugin_file()).lower() for opt in valid) + ) # with env set with self.in_env_context({'INVENTREE_PLUGIN_FILE': 'my_special_plugins.txt'}): @@ -1064,11 +1023,15 @@ class TestSettings(InvenTreeTestCase): self.assertEqual(config.get_setting(TEST_ENV_NAME, None), '321') # test typecasting to dict - None should be mapped to empty dict - self.assertEqual(config.get_setting(TEST_ENV_NAME, None, None, typecast=dict), {}) + self.assertEqual( + config.get_setting(TEST_ENV_NAME, None, None, typecast=dict), {} + ) # test typecasting to dict - valid JSON string should be mapped to corresponding dict with self.in_env_context({TEST_ENV_NAME: '{"a": 1}'}): - self.assertEqual(config.get_setting(TEST_ENV_NAME, None, typecast=dict), {"a": 1}) + self.assertEqual( + config.get_setting(TEST_ENV_NAME, None, typecast=dict), {'a': 1} + ) # test typecasting to dict - invalid JSON string should be mapped to empty dict with self.in_env_context({TEST_ENV_NAME: "{'a': 1}"}): @@ -1084,8 +1047,8 @@ class TestInstanceName(InvenTreeTestCase): self.assertEqual(version.inventreeInstanceTitle(), 'InvenTree') # set up required setting - InvenTreeSetting.set_setting("INVENTREE_INSTANCE_TITLE", True, self.user) - InvenTreeSetting.set_setting("INVENTREE_INSTANCE", "Testing title", self.user) + InvenTreeSetting.set_setting('INVENTREE_INSTANCE_TITLE', True, self.user) + InvenTreeSetting.set_setting('INVENTREE_INSTANCE', 'Testing title', self.user) self.assertEqual(version.inventreeInstanceTitle(), 'Testing title') @@ -1096,7 +1059,9 @@ class TestInstanceName(InvenTreeTestCase): def test_instance_url(self): """Test instance url settings.""" # Set up required setting - InvenTreeSetting.set_setting("INVENTREE_BASE_URL", "http://127.1.2.3", self.user) + InvenTreeSetting.set_setting( + 'INVENTREE_BASE_URL', 'http://127.1.2.3', self.user + ) # The site should also be changed site_obj = Site.objects.all().order_by('id').first() @@ -1106,12 +1071,7 @@ class TestInstanceName(InvenTreeTestCase): class TestOffloadTask(InvenTreeTestCase): """Tests for offloading tasks to the background worker""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock', - ] + fixtures = ['category', 'part', 'location', 'stock'] def test_offload_tasks(self): """Test that we can offload various tasks to the background worker thread. @@ -1127,38 +1087,38 @@ class TestOffloadTask(InvenTreeTestCase): Ref: https://github.com/inventree/InvenTree/pull/3273 """ - self.assertTrue(offload_task( - 'dummy_tasks.stock', - item=StockItem.objects.get(pk=1), - loc=StockLocation.objects.get(pk=1), - force_async=True - )) + self.assertTrue( + offload_task( + 'dummy_tasks.stock', + item=StockItem.objects.get(pk=1), + loc=StockLocation.objects.get(pk=1), + force_async=True, + ) + ) - self.assertTrue(offload_task( - 'dummy_task.numbers', - 1, 2, 3, 4, 5, - force_async=True - )) + self.assertTrue( + offload_task('dummy_task.numbers', 1, 2, 3, 4, 5, force_async=True) + ) # Offload a dummy task, but force sync # This should fail, because the function does not exist with self.assertLogs(logger='inventree', level='WARNING') as log: - self.assertFalse(offload_task( - 'dummy_task.numbers', - 1, 1, 1, - force_sync=True - )) + self.assertFalse( + offload_task('dummy_task.numbers', 1, 1, 1, force_sync=True) + ) - self.assertIn("Malformed function path", str(log.output)) + self.assertIn('Malformed function path', str(log.output)) # Offload dummy task with a Part instance # This should succeed, ensuring that the Part instance is correctly pickled - self.assertTrue(offload_task( - 'dummy_tasks.parts', - part=Part.objects.get(pk=1), - cat=PartCategory.objects.get(pk=1), - force_async=True - )) + self.assertTrue( + offload_task( + 'dummy_tasks.parts', + part=Part.objects.get(pk=1), + cat=PartCategory.objects.get(pk=1), + force_async=True, + ) + ) def test_daily_holdoff(self): """Tests for daily task holdoff helper functions""" @@ -1180,7 +1140,9 @@ class TestOffloadTask(InvenTreeTestCase): # An attempt has been logged, but it is too recent result = InvenTree.tasks.check_daily_holdoff('dummy_task') self.assertFalse(result) - self.assertIn("Last attempt for 'dummy_task' was too recent", str(cm.output)) + self.assertIn( + "Last attempt for 'dummy_task' was too recent", str(cm.output) + ) # Mark last attempt a few days ago - should now return True t_old = datetime.now() - timedelta(days=3) @@ -1191,13 +1153,17 @@ class TestOffloadTask(InvenTreeTestCase): self.assertTrue(result) # Last attempt should have been updated - self.assertNotEqual(t_old, InvenTreeSetting.get_setting('_dummy_task_ATTEMPT', '', cache=False)) + self.assertNotEqual( + t_old, InvenTreeSetting.get_setting('_dummy_task_ATTEMPT', '', cache=False) + ) # Last attempt should prevent us now with self.assertLogs(logger='inventree', level='INFO') as cm: result = InvenTree.tasks.check_daily_holdoff('dummy_task') self.assertFalse(result) - self.assertIn("Last attempt for 'dummy_task' was too recent", str(cm.output)) + self.assertIn( + "Last attempt for 'dummy_task' was too recent", str(cm.output) + ) # Configure so a task was successful too recently InvenTreeSetting.set_setting('_dummy_task_ATTEMPT', t_old, None) diff --git a/InvenTree/InvenTree/translation.py b/InvenTree/InvenTree/translation.py index d07f5c8ed7..3615a6d05e 100644 --- a/InvenTree/InvenTree/translation.py +++ b/InvenTree/InvenTree/translation.py @@ -24,7 +24,7 @@ def reload_translation_stats(): keys = _translation_stats.keys() # Note that the names used in the stats file may not align 100% - for (code, _lang) in settings.LANGUAGES: + for code, _lang in settings.LANGUAGES: if code in keys: # Direct match, move on continue diff --git a/InvenTree/InvenTree/unit_test.py b/InvenTree/InvenTree/unit_test.py index 43a2d66cbf..85af6d9dd2 100644 --- a/InvenTree/InvenTree/unit_test.py +++ b/InvenTree/InvenTree/unit_test.py @@ -39,7 +39,7 @@ def getMigrationFileNames(app): files = local_dir.joinpath('..', app, 'migrations').iterdir() # Regex pattern for migration files - regex = re.compile(r"^[\d]+_.*\.py$") + regex = re.compile(r'^[\d]+_.*\.py$') migration_files = [] @@ -56,7 +56,6 @@ def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True): oldest_file = None for f in getMigrationFileNames(app): - if ignore_initial and f.startswith('0001_initial'): continue @@ -115,9 +114,7 @@ class UserMixin: # Create a user to log in with cls.user = get_user_model().objects.create_user( - username=cls.username, - password=cls.password, - email=cls.email + username=cls.username, password=cls.password, email=cls.email ) # Create a group for the user @@ -163,15 +160,15 @@ class UserMixin: raise TypeError('assignRole: assign_all must be a boolean value') if not role and not assign_all: - raise ValueError('assignRole: either role must be provided, or assign_all must be set') + raise ValueError( + 'assignRole: either role must be provided, or assign_all must be set' + ) if not assign_all and role: rule, perm = role.split('.') for ruleset in group.rule_sets.all(): - if assign_all or ruleset.name == rule: - if assign_all or perm == 'view': ruleset.can_view = True elif assign_all or perm == 'change': @@ -205,37 +202,24 @@ class ExchangeRateMixin: def generate_exchange_rates(self): """Helper function which generates some exchange rates to work with""" - rates = { - 'AUD': 1.5, - 'CAD': 1.7, - 'GBP': 0.9, - 'USD': 1.0, - } + rates = {'AUD': 1.5, 'CAD': 1.7, 'GBP': 0.9, 'USD': 1.0} # Create a dummy backend - ExchangeBackend.objects.create( - name='InvenTreeExchange', - base_currency='USD', - ) + ExchangeBackend.objects.create(name='InvenTreeExchange', base_currency='USD') backend = ExchangeBackend.objects.get(name='InvenTreeExchange') items = [] for currency, rate in rates.items(): - items.append( - Rate( - currency=currency, - value=rate, - backend=backend, - ) - ) + items.append(Rate(currency=currency, value=rate, backend=backend)) Rate.objects.bulk_create(items) class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase): """Testcase with user setup buildin.""" + pass @@ -243,7 +227,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): """Base class for running InvenTree API tests.""" @contextmanager - def assertNumQueriesLessThan(self, value, using='default', verbose=False, debug=False): + def assertNumQueriesLessThan( + self, value, using='default', verbose=False, debug=False + ): """Context manager to check that the number of queries is less than a certain value. Example: @@ -252,17 +238,17 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): Ref: https://stackoverflow.com/questions/1254170/django-is-there-a-way-to-count-sql-queries-from-an-unit-test/59089020#59089020 """ with CaptureQueriesContext(connections[using]) as context: - yield # your test will be run here + yield # your test will be run here if verbose: - msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4) + msg = '\r\n%s' % json.dumps(context.captured_queries, indent=4) else: msg = None n = len(context.captured_queries) if debug: - print(f"Expected less than {value} queries, got {n} queries") + print(f'Expected less than {value} queries, got {n} queries') self.assertLess(n, value, msg=msg) @@ -273,8 +259,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): return if expected_code != response.status_code: - - print(f"Unexpected {method} response at '{url}': status_code = {response.status_code}") + print( + f"Unexpected {method} response at '{url}': status_code = {response.status_code}" + ) if hasattr(response, 'data'): print('data:', response.data) @@ -359,7 +346,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): return response - def download_file(self, url, data, expected_code=None, expected_fn=None, decode=True): + def download_file( + self, url, data, expected_code=None, expected_fn=None, decode=True + ): """Download a file from the server, and return an in-memory file.""" response = self.client.get(url, data=data, format='json') @@ -367,7 +356,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): # Check that the response is of the correct type if not isinstance(response, StreamingHttpResponse): - raise ValueError("Response is not a StreamingHttpResponse object as expected") + raise ValueError( + 'Response is not a StreamingHttpResponse object as expected' + ) # Extract filename disposition = response.headers['Content-Disposition'] @@ -394,7 +385,14 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): return file - def process_csv(self, file_object, delimiter=',', required_cols=None, excluded_cols=None, required_rows=None): + def process_csv( + self, + file_object, + delimiter=',', + required_cols=None, + excluded_cols=None, + required_rows=None, + ): """Helper function to process and validate a downloaded csv file.""" # Check that the correct object type has been passed self.assertTrue(isinstance(file_object, io.StringIO)) diff --git a/InvenTree/InvenTree/urls.py b/InvenTree/InvenTree/urls.py index 2baade9a2a..0b40ede9cc 100644 --- a/InvenTree/InvenTree/urls.py +++ b/InvenTree/InvenTree/urls.py @@ -10,9 +10,11 @@ from django.urls import include, path, re_path from django.views.decorators.csrf import csrf_exempt from django.views.generic.base import RedirectView -from dj_rest_auth.registration.views import (ConfirmEmailView, - SocialAccountDisconnectView, - SocialAccountListView) +from dj_rest_auth.registration.views import ( + ConfirmEmailView, + SocialAccountDisconnectView, + SocialAccountListView, +) from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView from sesame.views import LoginView @@ -28,36 +30,49 @@ import stock.api import users.api from build.urls import build_urls from common.urls import common_urls -from company.urls import (company_urls, manufacturer_part_urls, - supplier_part_urls) +from company.urls import company_urls, manufacturer_part_urls, supplier_part_urls from order.urls import order_urls from part.urls import part_urls from plugin.urls import get_plugin_urls from stock.urls import stock_urls from web.urls import urlpatterns as platform_urls -from .api import (APISearchView, InfoView, NotFoundView, VersionTextView, - VersionView) +from .api import APISearchView, InfoView, NotFoundView, VersionTextView, VersionView from .magic_login import GetSimpleLoginView -from .social_auth_urls import (EmailListView, EmailPrimaryView, - EmailRemoveView, EmailVerifyView, - SocialProviderListView, social_auth_urlpatterns) -from .views import (AboutView, AppearanceSelectView, CustomConnectionsView, - CustomEmailView, CustomLoginView, - CustomPasswordResetFromKeyView, - CustomSessionDeleteOtherView, CustomSessionDeleteView, - DatabaseStatsView, DynamicJsView, EditUserView, IndexView, - NotificationsView, SearchView, SetPasswordView, - SettingsView, auth_request) +from .social_auth_urls import ( + EmailListView, + EmailPrimaryView, + EmailRemoveView, + EmailVerifyView, + SocialProviderListView, + social_auth_urlpatterns, +) +from .views import ( + AboutView, + AppearanceSelectView, + CustomConnectionsView, + CustomEmailView, + CustomLoginView, + CustomPasswordResetFromKeyView, + CustomSessionDeleteOtherView, + CustomSessionDeleteView, + DatabaseStatsView, + DynamicJsView, + EditUserView, + IndexView, + NotificationsView, + SearchView, + SetPasswordView, + SettingsView, + auth_request, +) -admin.site.site_header = "InvenTree Admin" +admin.site.site_header = 'InvenTree Admin' apipatterns = [ - # Global search path('search/', APISearchView.as_view(), name='api-search'), - re_path(r'^settings/', include(common.api.settings_api_urls)), re_path(r'^part/', include(part.api.part_api_urls)), re_path(r'^bom/', include(part.api.bom_api_urls)), @@ -69,117 +84,288 @@ apipatterns = [ re_path(r'^report/', include(report.api.report_api_urls)), re_path(r'^user/', include(users.api.user_urls)), re_path(r'^admin/', include(common.api.admin_api_urls)), - # Plugin endpoints path('', include(plugin.api.plugin_api_urls)), - # Common endpoints endpoint path('', include(common.api.common_api_urls)), - # OpenAPI Schema - re_path('schema/', SpectacularAPIView.as_view(custom_settings={'SCHEMA_PATH_PREFIX': '/api/'}), name='schema'), - + re_path( + 'schema/', + SpectacularAPIView.as_view(custom_settings={'SCHEMA_PATH_PREFIX': '/api/'}), + name='schema', + ), # InvenTree information endpoints - path("version-text", VersionTextView.as_view(), name="api-version-text"), # version text + path( + 'version-text', VersionTextView.as_view(), name='api-version-text' + ), # version text path('version/', VersionView.as_view(), name='api-version'), # version info path('', InfoView.as_view(), name='api-inventree-info'), # server info - # Auth API endpoints - path('auth/', include([ - re_path(r'^registration/account-confirm-email/(?P[-:\w]+)/$', ConfirmEmailView.as_view(), name='account_confirm_email'), - path('registration/', include('dj_rest_auth.registration.urls')), - path('providers/', SocialProviderListView.as_view(), name='social_providers'), - path('emails/', include([path('/', include([ - path('primary/', EmailPrimaryView.as_view(), name='email-primary'), - path('verify/', EmailVerifyView.as_view(), name='email-verify'), - path('remove/', EmailRemoveView().as_view(), name='email-remove'),])), - path('', EmailListView.as_view(), name='email-list') - ])), - path('social/', include(social_auth_urlpatterns)), - path('social/', SocialAccountListView.as_view(), name='social_account_list'), - path('social//disconnect/', SocialAccountDisconnectView.as_view(), name='social_account_disconnect'), - path('', include('dj_rest_auth.urls')), - ])), - + path( + 'auth/', + include([ + re_path( + r'^registration/account-confirm-email/(?P[-:\w]+)/$', + ConfirmEmailView.as_view(), + name='account_confirm_email', + ), + path('registration/', include('dj_rest_auth.registration.urls')), + path( + 'providers/', SocialProviderListView.as_view(), name='social_providers' + ), + path( + 'emails/', + include([ + path( + '/', + include([ + path( + 'primary/', + EmailPrimaryView.as_view(), + name='email-primary', + ), + path( + 'verify/', + EmailVerifyView.as_view(), + name='email-verify', + ), + path( + 'remove/', + EmailRemoveView().as_view(), + name='email-remove', + ), + ]), + ), + path('', EmailListView.as_view(), name='email-list'), + ]), + ), + path('social/', include(social_auth_urlpatterns)), + path( + 'social/', SocialAccountListView.as_view(), name='social_account_list' + ), + path( + 'social//disconnect/', + SocialAccountDisconnectView.as_view(), + name='social_account_disconnect', + ), + path('', include('dj_rest_auth.urls')), + ]), + ), # Magic login URLs - path("email/generate/", csrf_exempt(GetSimpleLoginView().as_view()), name="sesame-generate",), - path("email/login/", LoginView.as_view(), name="sesame-login"), - + path( + 'email/generate/', + csrf_exempt(GetSimpleLoginView().as_view()), + name='sesame-generate', + ), + path('email/login/', LoginView.as_view(), name='sesame-login'), # Unknown endpoint re_path(r'^.*$', NotFoundView.as_view(), name='api-404'), ] settings_urls = [ - re_path(r'^i18n/?', include('django.conf.urls.i18n')), - - re_path(r'^appearance/?', AppearanceSelectView.as_view(), name='settings-appearance'), - + re_path( + r'^appearance/?', AppearanceSelectView.as_view(), name='settings-appearance' + ), # Catch any other urls - re_path(r'^.*$', SettingsView.as_view(template_name='InvenTree/settings/settings.html'), name='settings'), + re_path( + r'^.*$', + SettingsView.as_view(template_name='InvenTree/settings/settings.html'), + name='settings', + ), ] notifications_urls = [ - # Catch any other urls - re_path(r'^.*$', NotificationsView.as_view(), name='notifications'), + re_path(r'^.*$', NotificationsView.as_view(), name='notifications') ] # These javascript files are served "dynamically" - i.e. rendered on demand dynamic_javascript_urls = [ - re_path(r'^calendar.js', DynamicJsView.as_view(template_name='js/dynamic/calendar.js'), name='calendar.js'), - re_path(r'^nav.js', DynamicJsView.as_view(template_name='js/dynamic/nav.js'), name='nav.js'), - re_path(r'^permissions.js', DynamicJsView.as_view(template_name='js/dynamic/permissions.js'), name='permissions.js'), - re_path(r'^settings.js', DynamicJsView.as_view(template_name='js/dynamic/settings.js'), name='settings.js'), + re_path( + r'^calendar.js', + DynamicJsView.as_view(template_name='js/dynamic/calendar.js'), + name='calendar.js', + ), + re_path( + r'^nav.js', + DynamicJsView.as_view(template_name='js/dynamic/nav.js'), + name='nav.js', + ), + re_path( + r'^permissions.js', + DynamicJsView.as_view(template_name='js/dynamic/permissions.js'), + name='permissions.js', + ), + re_path( + r'^settings.js', + DynamicJsView.as_view(template_name='js/dynamic/settings.js'), + name='settings.js', + ), ] # These javascript files are passed through the Django translation layer translated_javascript_urls = [ - re_path(r'^api.js', DynamicJsView.as_view(template_name='js/translated/api.js'), name='api.js'), - re_path(r'^attachment.js', DynamicJsView.as_view(template_name='js/translated/attachment.js'), name='attachment.js'), - re_path(r'^barcode.js', DynamicJsView.as_view(template_name='js/translated/barcode.js'), name='barcode.js'), - re_path(r'^bom.js', DynamicJsView.as_view(template_name='js/translated/bom.js'), name='bom.js'), - re_path(r'^build.js', DynamicJsView.as_view(template_name='js/translated/build.js'), name='build.js'), - re_path(r'^charts.js', DynamicJsView.as_view(template_name='js/translated/charts.js'), name='charts.js'), - re_path(r'^company.js', DynamicJsView.as_view(template_name='js/translated/company.js'), name='company.js'), - re_path(r'^filters.js', DynamicJsView.as_view(template_name='js/translated/filters.js'), name='filters.js'), - re_path(r'^forms.js', DynamicJsView.as_view(template_name='js/translated/forms.js'), name='forms.js'), - re_path(r'^helpers.js', DynamicJsView.as_view(template_name='js/translated/helpers.js'), name='helpers.js'), - re_path(r'^index.js', DynamicJsView.as_view(template_name='js/translated/index.js'), name='index.js'), - re_path(r'^label.js', DynamicJsView.as_view(template_name='js/translated/label.js'), name='label.js'), - re_path(r'^model_renderers.js', DynamicJsView.as_view(template_name='js/translated/model_renderers.js'), name='model_renderers.js'), - re_path(r'^modals.js', DynamicJsView.as_view(template_name='js/translated/modals.js'), name='modals.js'), - re_path(r'^order.js', DynamicJsView.as_view(template_name='js/translated/order.js'), name='order.js'), - re_path(r'^part.js', DynamicJsView.as_view(template_name='js/translated/part.js'), name='part.js'), - re_path(r'^purchase_order.js', DynamicJsView.as_view(template_name='js/translated/purchase_order.js'), name='purchase_order.js'), - re_path(r'^return_order.js', DynamicJsView.as_view(template_name='js/translated/return_order.js'), name='return_order.js'), - re_path(r'^report.js', DynamicJsView.as_view(template_name='js/translated/report.js'), name='report.js'), - re_path(r'^sales_order.js', DynamicJsView.as_view(template_name='js/translated/sales_order.js'), name='sales_order.js'), - re_path(r'^search.js', DynamicJsView.as_view(template_name='js/translated/search.js'), name='search.js'), - re_path(r'^stock.js', DynamicJsView.as_view(template_name='js/translated/stock.js'), name='stock.js'), - re_path(r'^status_codes.js', DynamicJsView.as_view(template_name='js/translated/status_codes.js'), name='status_codes.js'), - re_path(r'^plugin.js', DynamicJsView.as_view(template_name='js/translated/plugin.js'), name='plugin.js'), - re_path(r'^pricing.js', DynamicJsView.as_view(template_name='js/translated/pricing.js'), name='pricing.js'), - re_path(r'^news.js', DynamicJsView.as_view(template_name='js/translated/news.js'), name='news.js'), - re_path(r'^tables.js', DynamicJsView.as_view(template_name='js/translated/tables.js'), name='tables.js'), - re_path(r'^table_filters.js', DynamicJsView.as_view(template_name='js/translated/table_filters.js'), name='table_filters.js'), - re_path(r'^notification.js', DynamicJsView.as_view(template_name='js/translated/notification.js'), name='notification.js'), + re_path( + r'^api.js', + DynamicJsView.as_view(template_name='js/translated/api.js'), + name='api.js', + ), + re_path( + r'^attachment.js', + DynamicJsView.as_view(template_name='js/translated/attachment.js'), + name='attachment.js', + ), + re_path( + r'^barcode.js', + DynamicJsView.as_view(template_name='js/translated/barcode.js'), + name='barcode.js', + ), + re_path( + r'^bom.js', + DynamicJsView.as_view(template_name='js/translated/bom.js'), + name='bom.js', + ), + re_path( + r'^build.js', + DynamicJsView.as_view(template_name='js/translated/build.js'), + name='build.js', + ), + re_path( + r'^charts.js', + DynamicJsView.as_view(template_name='js/translated/charts.js'), + name='charts.js', + ), + re_path( + r'^company.js', + DynamicJsView.as_view(template_name='js/translated/company.js'), + name='company.js', + ), + re_path( + r'^filters.js', + DynamicJsView.as_view(template_name='js/translated/filters.js'), + name='filters.js', + ), + re_path( + r'^forms.js', + DynamicJsView.as_view(template_name='js/translated/forms.js'), + name='forms.js', + ), + re_path( + r'^helpers.js', + DynamicJsView.as_view(template_name='js/translated/helpers.js'), + name='helpers.js', + ), + re_path( + r'^index.js', + DynamicJsView.as_view(template_name='js/translated/index.js'), + name='index.js', + ), + re_path( + r'^label.js', + DynamicJsView.as_view(template_name='js/translated/label.js'), + name='label.js', + ), + re_path( + r'^model_renderers.js', + DynamicJsView.as_view(template_name='js/translated/model_renderers.js'), + name='model_renderers.js', + ), + re_path( + r'^modals.js', + DynamicJsView.as_view(template_name='js/translated/modals.js'), + name='modals.js', + ), + re_path( + r'^order.js', + DynamicJsView.as_view(template_name='js/translated/order.js'), + name='order.js', + ), + re_path( + r'^part.js', + DynamicJsView.as_view(template_name='js/translated/part.js'), + name='part.js', + ), + re_path( + r'^purchase_order.js', + DynamicJsView.as_view(template_name='js/translated/purchase_order.js'), + name='purchase_order.js', + ), + re_path( + r'^return_order.js', + DynamicJsView.as_view(template_name='js/translated/return_order.js'), + name='return_order.js', + ), + re_path( + r'^report.js', + DynamicJsView.as_view(template_name='js/translated/report.js'), + name='report.js', + ), + re_path( + r'^sales_order.js', + DynamicJsView.as_view(template_name='js/translated/sales_order.js'), + name='sales_order.js', + ), + re_path( + r'^search.js', + DynamicJsView.as_view(template_name='js/translated/search.js'), + name='search.js', + ), + re_path( + r'^stock.js', + DynamicJsView.as_view(template_name='js/translated/stock.js'), + name='stock.js', + ), + re_path( + r'^status_codes.js', + DynamicJsView.as_view(template_name='js/translated/status_codes.js'), + name='status_codes.js', + ), + re_path( + r'^plugin.js', + DynamicJsView.as_view(template_name='js/translated/plugin.js'), + name='plugin.js', + ), + re_path( + r'^pricing.js', + DynamicJsView.as_view(template_name='js/translated/pricing.js'), + name='pricing.js', + ), + re_path( + r'^news.js', + DynamicJsView.as_view(template_name='js/translated/news.js'), + name='news.js', + ), + re_path( + r'^tables.js', + DynamicJsView.as_view(template_name='js/translated/tables.js'), + name='tables.js', + ), + re_path( + r'^table_filters.js', + DynamicJsView.as_view(template_name='js/translated/table_filters.js'), + name='table_filters.js', + ), + re_path( + r'^notification.js', + DynamicJsView.as_view(template_name='js/translated/notification.js'), + name='notification.js', + ), ] backendpatterns = [ # "Dynamic" javascript files which are rendered using InvenTree templating. re_path(r'^js/dynamic/', include(dynamic_javascript_urls)), re_path(r'^js/i18n/', include(translated_javascript_urls)), - re_path(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), re_path(r'^auth/?', auth_request), - re_path(r'^api/', include(apipatterns)), - re_path(r'^api-doc/', SpectacularRedocView.as_view(url_name='schema'), name='api-doc'), + re_path( + r'^api-doc/', SpectacularRedocView.as_view(url_name='schema'), name='api-doc' + ), ] classic_frontendpatterns = [ - # Apps re_path(r'^build/', include(build_urls)), re_path(r'^common/', include(common_urls)), @@ -189,38 +375,48 @@ classic_frontendpatterns = [ re_path(r'^part/', include(part_urls)), re_path(r'^stock/', include(stock_urls)), re_path(r'^supplier-part/', include(supplier_part_urls)), - re_path(r'^edit-user/', EditUserView.as_view(), name='edit-user'), re_path(r'^set-password/', SetPasswordView.as_view(), name='set-password'), - re_path(r'^index/', IndexView.as_view(), name='index'), re_path(r'^notifications/', include(notifications_urls)), re_path(r'^search/', SearchView.as_view(), name='search'), re_path(r'^settings/', include(settings_urls)), re_path(r'^about/', AboutView.as_view(), name='about'), re_path(r'^stats/', DatabaseStatsView.as_view(), name='stats'), - # DB user sessions - path('accounts/sessions/other/delete/', view=CustomSessionDeleteOtherView.as_view(), name='session_delete_other', ), - re_path(r'^accounts/sessions/(?P\w+)/delete/$', view=CustomSessionDeleteView.as_view(), name='session_delete', ), - + path( + 'accounts/sessions/other/delete/', + view=CustomSessionDeleteOtherView.as_view(), + name='session_delete_other', + ), + re_path( + r'^accounts/sessions/(?P\w+)/delete/$', + view=CustomSessionDeleteView.as_view(), + name='session_delete', + ), # Single Sign On / allauth # overrides of urlpatterns re_path(r'^accounts/email/', CustomEmailView.as_view(), name='account_email'), - re_path(r'^accounts/social/connections/', CustomConnectionsView.as_view(), name='socialaccount_connections'), - re_path(r"^accounts/password/reset/key/(?P[0-9A-Za-z]+)-(?P.+)/$", CustomPasswordResetFromKeyView.as_view(), name="account_reset_password_from_key"), - + re_path( + r'^accounts/social/connections/', + CustomConnectionsView.as_view(), + name='socialaccount_connections', + ), + re_path( + r'^accounts/password/reset/key/(?P[0-9A-Za-z]+)-(?P.+)/$', + CustomPasswordResetFromKeyView.as_view(), + name='account_reset_password_from_key', + ), # Override login page - re_path("accounts/login/", CustomLoginView.as_view(), name="account_login"), - - re_path(r'^accounts/', include('allauth_2fa.urls')), # MFA support - re_path(r'^accounts/', include('allauth.urls')), # included urlpatterns + re_path('accounts/login/', CustomLoginView.as_view(), name='account_login'), + re_path(r'^accounts/', include('allauth_2fa.urls')), # MFA support + re_path(r'^accounts/', include('allauth.urls')), # included urlpatterns ] urlpatterns = [] if settings.INVENTREE_ADMIN_ENABLED: - admin_url = settings.INVENTREE_ADMIN_URL, + admin_url = (settings.INVENTREE_ADMIN_URL,) urlpatterns += [ path(f'{admin_url}/error_log/', include('error_report.urls')), path(f'{admin_url}/', admin.site.urls, name='inventree-admin'), @@ -251,8 +447,13 @@ if settings.DEBUG: # Redirect for favicon.ico urlpatterns.append( - path('favicon.ico', RedirectView.as_view(url=f'{settings.STATIC_URL}img/favicon/favicon.ico')) + path( + 'favicon.ico', + RedirectView.as_view(url=f'{settings.STATIC_URL}img/favicon/favicon.ico'), + ) ) # Send any unknown URLs to the parts page -urlpatterns += [re_path(r'^.*$', RedirectView.as_view(url='/index/', permanent=False), name='index')] +urlpatterns += [ + re_path(r'^.*$', RedirectView.as_view(url='/index/', permanent=False), name='index') +] diff --git a/InvenTree/InvenTree/validators.py b/InvenTree/InvenTree/validators.py index ea9111585d..74f44dc9b2 100644 --- a/InvenTree/InvenTree/validators.py +++ b/InvenTree/InvenTree/validators.py @@ -69,7 +69,9 @@ class AllowedURLValidator(validators.URLValidator): self.schemes = allowable_url_schemes() # Determine if 'strict' URL validation is required (i.e. if the URL must have a schema prefix) - strict_urls = common.models.InvenTreeSetting.get_setting('INVENTREE_STRICT_URLS', True, cache=False) + strict_urls = common.models.InvenTreeSetting.get_setting( + 'INVENTREE_STRICT_URLS', True, cache=False + ) if not strict_urls: # Allow URLs which do not have a provided schema @@ -117,7 +119,7 @@ def validate_overage(value): i = Decimal(value) if i < 0: - raise ValidationError(_("Overage value must not be negative")) + raise ValidationError(_('Overage value must not be negative')) # Looks like a number return True @@ -133,17 +135,15 @@ def validate_overage(value): f = float(v) if f < 0: - raise ValidationError(_("Overage value must not be negative")) + raise ValidationError(_('Overage value must not be negative')) elif f > 100: - raise ValidationError(_("Overage must not exceed 100%")) + raise ValidationError(_('Overage must not exceed 100%')) return True except ValueError: pass - raise ValidationError( - _("Invalid value for overage") - ) + raise ValidationError(_('Invalid value for overage')) def validate_part_name_format(value): @@ -179,8 +179,6 @@ def validate_part_name_format(value): try: Template(value).render({'part': p}) except Exception as exc: - raise ValidationError({ - 'value': str(exc) - }) + raise ValidationError({'value': str(exc)}) return True diff --git a/InvenTree/InvenTree/version.py b/InvenTree/InvenTree/version.py index 52c133dd2a..f7dfb6cb49 100644 --- a/InvenTree/InvenTree/version.py +++ b/InvenTree/InvenTree/version.py @@ -19,7 +19,7 @@ from dulwich.repo import NotGitRepository, Repo from .api_version import INVENTREE_API_TEXT, INVENTREE_API_VERSION # InvenTree software version -INVENTREE_SW_VERSION = "0.14.0 dev" +INVENTREE_SW_VERSION = '0.14.0 dev' # Discover git try: @@ -32,8 +32,8 @@ except (NotGitRepository, FileNotFoundError): def checkMinPythonVersion(): """Check that the Python version is at least 3.9""" - version = sys.version.split(" ")[0] - docs = "https://docs.inventree.org/en/stable/start/intro/#python-requirements" + version = sys.version.split(' ')[0] + docs = 'https://docs.inventree.org/en/stable/start/intro/#python-requirements' msg = f""" InvenTree requires Python 3.9 or above - you are running version {version}. @@ -47,22 +47,22 @@ def checkMinPythonVersion(): if sys.version_info.major == 3 and sys.version_info.minor < 9: raise RuntimeError(msg) - print(f"Python version {version} - {sys.executable}") + print(f'Python version {version} - {sys.executable}') def inventreeInstanceName(): """Returns the InstanceName settings for the current database.""" import common.models - return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "") + return common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE', '') def inventreeInstanceTitle(): """Returns the InstanceTitle for the current database.""" import common.models - if common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE_TITLE", False): - return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "") + if common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE_TITLE', False): + return common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE', '') return 'InvenTree' @@ -76,7 +76,7 @@ def inventreeVersionTuple(version=None): if version is None: version = INVENTREE_SW_VERSION - match = re.match(r"^.*(\d+)\.(\d+)\.(\d+).*$", str(version)) + match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', str(version)) return [int(g) for g in match.groups()] @@ -93,14 +93,14 @@ def inventreeDocsVersion(): Release -> "major.minor.sub" e.g. "0.5.2" """ if isInvenTreeDevelopmentVersion(): - return "latest" + return 'latest' return INVENTREE_SW_VERSION # pragma: no cover def inventreeDocUrl(): """Return URL for InvenTree documentation site.""" tag = inventreeDocsVersion() - return f"https://docs.inventree.org/en/{tag}" + return f'https://docs.inventree.org/en/{tag}' def inventreeAppUrl(): @@ -110,12 +110,12 @@ def inventreeAppUrl(): def inventreeCreditsUrl(): """Return URL for InvenTree credits site.""" - return "https://docs.inventree.org/en/latest/credits/" + return 'https://docs.inventree.org/en/latest/credits/' def inventreeGithubUrl(): """Return URL for InvenTree github site.""" - return "https://github.com/InvenTree/InvenTree/" + return 'https://github.com/InvenTree/InvenTree/' def isInvenTreeUpToDate(): @@ -124,7 +124,10 @@ def isInvenTreeUpToDate(): A background task periodically queries GitHub for latest version, and stores it to the database as "_INVENTREE_LATEST_VERSION" """ import common.models - latest = common.models.InvenTreeSetting.get_setting('_INVENTREE_LATEST_VERSION', backup_value=None, create=False) + + latest = common.models.InvenTreeSetting.get_setting( + '_INVENTREE_LATEST_VERSION', backup_value=None, create=False + ) # No record for "latest" version - we must assume we are up to date! if not latest: @@ -144,24 +147,26 @@ def inventreeApiVersion(): def parse_version_text(): """Parse the version text to structured data.""" - patched_data = INVENTREE_API_TEXT.split("\n\n") + patched_data = INVENTREE_API_TEXT.split('\n\n') # Remove first newline on latest version - patched_data[0] = patched_data[0].replace("\n", "", 1) + patched_data[0] = patched_data[0].replace('\n', '', 1) version_data = {} for version in patched_data: - data = version.split("\n") + data = version.split('\n') version_split = data[0].split(' -> ') - version_detail = version_split[1].split(':', 1) if len(version_split) > 1 else ['', ] + version_detail = ( + version_split[1].split(':', 1) if len(version_split) > 1 else [''] + ) new_data = { - "version": version_split[0].strip(), - "date": version_detail[0].strip(), - "gh": version_detail[1].strip() if len(version_detail) > 1 else None, - "text": data[1:], - "latest": False, + 'version': version_split[0].strip(), + 'date': version_detail[0].strip(), + 'gh': version_detail[1].strip() if len(version_detail) > 1 else None, + 'text': data[1:], + 'latest': False, } - version_data[new_data["version"]] = new_data + version_data[new_data['version']] = new_data return version_data @@ -183,7 +188,7 @@ def inventreeApiText(versions: int = 10, start_version: int = 0): start_version = INVENTREE_API_VERSION - versions return { - f"v{a}": version_data.get(f"v{a}", None) + f'v{a}': version_data.get(f'v{a}', None) for a in range(start_version, start_version + versions) } @@ -222,7 +227,9 @@ def inventreeCommitDate(): if main_commit is None: return None - commit_dt = dt.fromtimestamp(main_commit.commit_time) + td(seconds=main_commit.commit_timezone) + commit_dt = dt.fromtimestamp(main_commit.commit_time) + td( + seconds=main_commit.commit_timezone + ) return str(commit_dt.date()) diff --git a/InvenTree/InvenTree/views.py b/InvenTree/InvenTree/views.py index 9a6f4e95f2..a4e318591e 100644 --- a/InvenTree/InvenTree/views.py +++ b/InvenTree/InvenTree/views.py @@ -5,8 +5,7 @@ as JSON objects and passing them to modal forms (using jQuery / bootstrap). """ from django.contrib.auth import password_validation -from django.contrib.auth.mixins import (LoginRequiredMixin, - PermissionRequiredMixin) +from django.contrib.auth.mixins import LoginRequiredMixin, PermissionRequiredMixin from django.core.exceptions import ValidationError from django.http import HttpResponse, HttpResponseRedirect, JsonResponse from django.shortcuts import redirect @@ -20,8 +19,7 @@ from django.views.generic.base import RedirectView, TemplateView from allauth.account.forms import AddEmailForm from allauth.account.models import EmailAddress -from allauth.account.views import (EmailView, LoginView, - PasswordResetFromKeyView) +from allauth.account.views import EmailView, LoginView, PasswordResetFromKeyView from allauth.socialaccount.forms import DisconnectForm from allauth.socialaccount.views import ConnectionsView from djmoney.contrib.exchange.models import ExchangeBackend, Rate @@ -104,7 +102,6 @@ class InvenTreeRoleMixin(PermissionRequiredMixin): return True for required in roles_required: - (role, permission) = required.split('.') if role not in RuleSet.RULESET_NAMES: @@ -138,12 +135,14 @@ class InvenTreeRoleMixin(PermissionRequiredMixin): app_label = model._meta.app_label model_name = model._meta.model_name - table = f"{app_label}_{model_name}" + table = f'{app_label}_{model_name}' permission = self.get_permission_class() if not permission: - raise AttributeError(f"permission_class not defined for {type(self).__name__}") + raise AttributeError( + f'permission_class not defined for {type(self).__name__}' + ) # Check if the user has the required permission return RuleSet.check_table_permission(user, table, permission) @@ -182,7 +181,6 @@ class InvenTreeRoleMixin(PermissionRequiredMixin): } for view_class in permission_map.keys(): - if issubclass(type(self), view_class): return permission_map[view_class] @@ -282,9 +280,7 @@ class AjaxMixin(InvenTreeRoleMixin): data['title'] = self.get_form_title() data['html_form'] = render_to_string( - self.ajax_template_name, - context, - request=request + self.ajax_template_name, context, request=request ) # Custom feedback`data @@ -329,7 +325,9 @@ class AjaxUpdateView(AjaxMixin, UpdateView): """ super(UpdateView, self).get(request, *args, **kwargs) - return self.renderJsonResponse(request, self.get_form(), context=self.get_context_data()) + return self.renderJsonResponse( + request, self.get_form(), context=self.get_context_data() + ) def save(self, object, form, **kwargs): """Method for updating the object in the database. Default implementation is very simple, but can be overridden if required. @@ -379,7 +377,6 @@ class AjaxUpdateView(AjaxMixin, UpdateView): data[key] = value if valid: - # Save the updated object to the database self.save(self.object, form) @@ -399,8 +396,8 @@ class AjaxUpdateView(AjaxMixin, UpdateView): class EditUserView(AjaxUpdateView): """View for editing user information.""" - ajax_template_name = "modal_form.html" - ajax_form_title = _("Edit User Information") + ajax_template_name = 'modal_form.html' + ajax_form_title = _('Edit User Information') form_class = EditUserForm def get_object(self): @@ -411,8 +408,8 @@ class EditUserView(AjaxUpdateView): class SetPasswordView(AjaxUpdateView): """View for setting user password.""" - ajax_template_name = "InvenTree/password.html" - ajax_form_title = _("Set Password") + ajax_template_name = 'InvenTree/password.html' + ajax_form_title = _('Set Password') form_class = SetPasswordForm def get_object(self): @@ -494,14 +491,14 @@ class SearchView(TemplateView): class DynamicJsView(TemplateView): """View for returning javacsript files, which instead of being served dynamically, are passed through the django translation engine!""" - template_name = "" + template_name = '' content_type = 'text/javascript' class SettingsView(TemplateView): """View for configuring User settings.""" - template_name = "InvenTree/settings/settings.html" + template_name = 'InvenTree/settings/settings.html' def get_context_data(self, **kwargs): """Add data for template.""" @@ -509,33 +506,37 @@ class SettingsView(TemplateView): ctx['settings'] = common_models.InvenTreeSetting.objects.all().order_by('key') - ctx["base_currency"] = common_settings.currency_code_default() - ctx["currencies"] = common_settings.currency_codes + ctx['base_currency'] = common_settings.currency_code_default() + ctx['currencies'] = common_settings.currency_codes - ctx["rates"] = Rate.objects.filter(backend="InvenTreeExchange") + ctx['rates'] = Rate.objects.filter(backend='InvenTreeExchange') - ctx["categories"] = PartCategory.objects.all().order_by('tree_id', 'lft', 'name') + ctx['categories'] = PartCategory.objects.all().order_by( + 'tree_id', 'lft', 'name' + ) # When were the rates last updated? try: backend = ExchangeBackend.objects.filter(name='InvenTreeExchange') if backend.exists(): backend = backend.first() - ctx["rates_updated"] = backend.last_update + ctx['rates_updated'] = backend.last_update except Exception: - ctx["rates_updated"] = None + ctx['rates_updated'] = None # Forms and context for allauth ctx['add_email_form'] = AddEmailForm - ctx["can_add_email"] = EmailAddress.objects.can_add_email(self.request.user) + ctx['can_add_email'] = EmailAddress.objects.can_add_email(self.request.user) # Form and context for allauth social-accounts - ctx["request"] = self.request + ctx['request'] = self.request ctx['social_form'] = DisconnectForm(request=self.request) # user db sessions ctx['session_key'] = self.request.session.session_key - ctx['session_list'] = self.request.user.session_set.filter(expire_date__gt=now()).order_by('-last_activity') + ctx['session_list'] = self.request.user.session_set.filter( + expire_date__gt=now() + ).order_by('-last_activity') return ctx @@ -550,20 +551,23 @@ class AllauthOverrides(LoginRequiredMixin): class CustomEmailView(AllauthOverrides, EmailView): """Override of allauths EmailView to always show the settings but leave the functions allow.""" - success_url = reverse_lazy("settings") + + success_url = reverse_lazy('settings') class CustomConnectionsView(AllauthOverrides, ConnectionsView): """Override of allauths ConnectionsView to always show the settings but leave the functions allow.""" - success_url = reverse_lazy("settings") + + success_url = reverse_lazy('settings') class CustomPasswordResetFromKeyView(PasswordResetFromKeyView): """Override of allauths PasswordResetFromKeyView to always show the settings but leave the functions allow.""" - success_url = reverse_lazy("account_login") + + success_url = reverse_lazy('account_login') -class UserSessionOverride(): +class UserSessionOverride: """Overrides sucessurl to lead to settings.""" def get_success_url(self): @@ -573,11 +577,13 @@ class UserSessionOverride(): class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView): """Revert to settings after session delete.""" + pass class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView): """Revert to settings after session delete.""" + pass @@ -607,7 +613,9 @@ class AppearanceSelectView(RedirectView): def get_user_theme(self): """Get current user color theme.""" try: - user_theme = common_models.ColorTheme.objects.filter(user=self.request.user).get() + user_theme = common_models.ColorTheme.objects.filter( + user=self.request.user + ).get() except common_models.ColorTheme.DoesNotExist: user_theme = None @@ -638,18 +646,18 @@ class AppearanceSelectView(RedirectView): class DatabaseStatsView(AjaxView): """View for displaying database statistics.""" - ajax_template_name = "stats.html" - ajax_form_title = _("System Information") + ajax_template_name = 'stats.html' + ajax_form_title = _('System Information') class AboutView(AjaxView): """A view for displaying InvenTree version information""" - ajax_template_name = "about.html" - ajax_form_title = _("About InvenTree") + ajax_template_name = 'about.html' + ajax_form_title = _('About InvenTree') class NotificationsView(TemplateView): """View for showing notifications.""" - template_name = "InvenTree/notifications/notifications.html" + template_name = 'InvenTree/notifications/notifications.html' diff --git a/InvenTree/InvenTree/wsgi.py b/InvenTree/InvenTree/wsgi.py index dfced329a8..4630c31182 100644 --- a/InvenTree/InvenTree/wsgi.py +++ b/InvenTree/InvenTree/wsgi.py @@ -10,6 +10,8 @@ import os # pragma: no cover from django.core.wsgi import get_wsgi_application # pragma: no cover -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "InvenTree.settings") # pragma: no cover +os.environ.setdefault( + 'DJANGO_SETTINGS_MODULE', 'InvenTree.settings' +) # pragma: no cover application = get_wsgi_application() # pragma: no cover diff --git a/InvenTree/common/admin.py b/InvenTree/common/admin.py index cbc5b6621b..528f705c3a 100644 --- a/InvenTree/common/admin.py +++ b/InvenTree/common/admin.py @@ -22,7 +22,7 @@ class SettingsAdmin(ImportExportModelAdmin): class UserSettingsAdmin(ImportExportModelAdmin): """Admin settings for InvenTreeUserSetting.""" - list_display = ('key', 'value', 'user', ) + list_display = ('key', 'value', 'user') def get_readonly_fields(self, request, obj=None): # pragma: no cover """Prevent the 'key' field being edited once the setting is created.""" @@ -40,23 +40,31 @@ class WebhookAdmin(ImportExportModelAdmin): class NotificationEntryAdmin(admin.ModelAdmin): """Admin settings for NotificationEntry.""" - list_display = ('key', 'uid', 'updated', ) + list_display = ('key', 'uid', 'updated') class NotificationMessageAdmin(admin.ModelAdmin): """Admin settings for NotificationMessage.""" - list_display = ('age_human', 'user', 'category', 'name', 'read', 'target_object', 'source_object', ) + list_display = ( + 'age_human', + 'user', + 'category', + 'name', + 'read', + 'target_object', + 'source_object', + ) - list_filter = ('category', 'read', 'user', ) + list_filter = ('category', 'read', 'user') - search_fields = ('name', 'category', 'message', ) + search_fields = ('name', 'category', 'message') class NewsFeedEntryAdmin(admin.ModelAdmin): """Admin settings for NewsFeedEntry.""" - list_display = ('title', 'author', 'published', 'summary', ) + list_display = ('title', 'author', 'published', 'summary') admin.site.register(common.models.InvenTreeSetting, SettingsAdmin) diff --git a/InvenTree/common/api.py b/InvenTree/common/api.py index 3b9e15bbdc..9ccfc4b9d2 100644 --- a/InvenTree/common/api.py +++ b/InvenTree/common/api.py @@ -23,8 +23,13 @@ from InvenTree.api import BulkDeleteMixin, MetadataView from InvenTree.config import CONFIG_LOOKUPS from InvenTree.filters import ORDER_FILTER, SEARCH_ORDER_FILTER from InvenTree.helpers import inheritors -from InvenTree.mixins import (ListAPI, ListCreateAPI, RetrieveAPI, - RetrieveUpdateAPI, RetrieveUpdateDestroyAPI) +from InvenTree.mixins import ( + ListAPI, + ListCreateAPI, + RetrieveAPI, + RetrieveUpdateAPI, + RetrieveUpdateDestroyAPI, +) from InvenTree.permissions import IsStaffOrReadOnly, IsSuperuser from plugin.models import NotificationUserSetting from plugin.serializers import NotificationUserSettingSerializer @@ -41,6 +46,7 @@ class CsrfExemptMixin(object): class WebhookView(CsrfExemptMixin, APIView): """Endpoint for receiving webhooks.""" + authentication_classes = [] permission_classes = [] model_class = common.models.WebhookEndpoint @@ -66,8 +72,7 @@ class WebhookView(CsrfExemptMixin, APIView): async_task(self._process_payload, message.id) else: self._process_result( - self.webhook.process_payload(message, payload, headers), - message, + self.webhook.process_payload(message, payload, headers), message ) data = self.webhook.get_return(payload, headers, request) @@ -76,8 +81,7 @@ class WebhookView(CsrfExemptMixin, APIView): def _process_payload(self, message_id): message = common.models.WebhookMessage.objects.get(message_id=message_id) self._process_result( - self.webhook.process_payload(message, message.body, message.header), - message, + self.webhook.process_payload(message, message.body, message.header), message ) def _process_result(self, result, message): @@ -108,9 +112,7 @@ class WebhookView(CsrfExemptMixin, APIView): class CurrencyExchangeView(APIView): """API endpoint for displaying currency information""" - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] def get(self, request, format=None): """Return information on available currency conversions""" @@ -133,7 +135,9 @@ class CurrencyExchangeView(APIView): updated = None response = { - 'base_currency': common.models.InvenTreeSetting.get_setting('INVENTREE_DEFAULT_CURRENCY', 'USD'), + 'base_currency': common.models.InvenTreeSetting.get_setting( + 'INVENTREE_DEFAULT_CURRENCY', 'USD' + ), 'exchange_rates': {}, 'updated': updated, } @@ -150,10 +154,7 @@ class CurrencyRefreshView(APIView): User must be a 'staff' user to access this endpoint """ - permission_classes = [ - permissions.IsAuthenticated, - permissions.IsAdminUser, - ] + permission_classes = [permissions.IsAuthenticated, permissions.IsAdminUser] def post(self, request, *args, **kwargs): """Performing a POST request will update currency exchange rates""" @@ -161,9 +162,7 @@ class CurrencyRefreshView(APIView): update_exchange_rates(force=True) - return Response({ - 'success': 'Exchange rates updated', - }) + return Response({'success': 'Exchange rates updated'}) class SettingsList(ListAPI): @@ -174,21 +173,15 @@ class SettingsList(ListAPI): filter_backends = SEARCH_ORDER_FILTER - ordering_fields = [ - 'pk', - 'key', - 'name', - ] + ordering_fields = ['pk', 'key', 'name'] - search_fields = [ - 'key', - ] + search_fields = ['key'] class GlobalSettingsList(SettingsList): """API endpoint for accessing a list of global settings objects.""" - queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith="_") + queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith='_') serializer_class = common.serializers.GlobalSettingsSerializer def list(self, request, *args, **kwargs): @@ -221,25 +214,24 @@ class GlobalSettingsDetail(RetrieveUpdateAPI): """ lookup_field = 'key' - queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith="_") + queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith='_') serializer_class = common.serializers.GlobalSettingsSerializer def get_object(self): """Attempt to find a global setting object with the provided key.""" key = str(self.kwargs['key']).upper() - if key.startswith('_') or key not in common.models.InvenTreeSetting.SETTINGS.keys(): + if ( + key.startswith('_') + or key not in common.models.InvenTreeSetting.SETTINGS.keys() + ): raise NotFound() return common.models.InvenTreeSetting.get_setting_object( - key, - cache=False, create=True + key, cache=False, create=True ) - permission_classes = [ - permissions.IsAuthenticated, - GlobalSettingsPermissions, - ] + permission_classes = [permissions.IsAuthenticated, GlobalSettingsPermissions] class UserSettingsList(SettingsList): @@ -294,18 +286,17 @@ class UserSettingsDetail(RetrieveUpdateAPI): """Attempt to find a user setting object with the provided key.""" key = str(self.kwargs['key']).upper() - if key.startswith('_') or key not in common.models.InvenTreeUserSetting.SETTINGS.keys(): + if ( + key.startswith('_') + or key not in common.models.InvenTreeUserSetting.SETTINGS.keys() + ): raise NotFound() return common.models.InvenTreeUserSetting.get_setting_object( - key, - user=self.request.user, - cache=False, create=True + key, user=self.request.user, cache=False, create=True ) - permission_classes = [ - UserSettingsPermissions, - ] + permission_classes = [UserSettingsPermissions] class NotificationUserSettingsList(SettingsList): @@ -334,39 +325,29 @@ class NotificationUserSettingsDetail(RetrieveUpdateAPI): queryset = NotificationUserSetting.objects.all() serializer_class = NotificationUserSettingSerializer - permission_classes = [UserSettingsPermissions, ] + permission_classes = [UserSettingsPermissions] class NotificationMessageMixin: """Generic mixin for NotificationMessage.""" + queryset = common.models.NotificationMessage.objects.all() serializer_class = common.serializers.NotificationMessageSerializer - permission_classes = [UserSettingsPermissions, ] + permission_classes = [UserSettingsPermissions] class NotificationList(NotificationMessageMixin, BulkDeleteMixin, ListAPI): """List view for all notifications of the current user.""" - permission_classes = [permissions.IsAuthenticated, ] + permission_classes = [permissions.IsAuthenticated] filter_backends = SEARCH_ORDER_FILTER - ordering_fields = [ - 'category', - 'name', - 'read', - 'creation', - ] + ordering_fields = ['category', 'name', 'read', 'creation'] - search_fields = [ - 'name', - 'message', - ] + search_fields = ['name', 'message'] - filterset_fields = [ - 'category', - 'read', - ] + filterset_fields = ['category', 'read'] def filter_queryset(self, queryset): """Only list notifications which apply to the current user.""" @@ -401,29 +382,27 @@ class NotificationReadAll(NotificationMessageMixin, RetrieveAPI): self.queryset.filter(user=request.user, read=False).update(read=True) return Response({'status': 'ok'}) except Exception as exc: - raise serializers.ValidationError(detail=serializers.as_serializer_error(exc)) + raise serializers.ValidationError( + detail=serializers.as_serializer_error(exc) + ) class NewsFeedMixin: """Generic mixin for NewsFeedEntry.""" + queryset = common.models.NewsFeedEntry.objects.all() serializer_class = common.serializers.NewsFeedEntrySerializer - permission_classes = [IsAdminUser, ] + permission_classes = [IsAdminUser] class NewsFeedEntryList(NewsFeedMixin, BulkDeleteMixin, ListAPI): """List view for all news items.""" + filter_backends = ORDER_FILTER - ordering_fields = [ - 'published', - 'author', - 'read', - ] + ordering_fields = ['published', 'author', 'read'] - filterset_fields = [ - 'read', - ] + filterset_fields = ['read'] class NewsFeedEntryDetail(NewsFeedMixin, RetrieveUpdateDestroyAPI): @@ -435,14 +414,14 @@ class ConfigList(ListAPI): queryset = CONFIG_LOOKUPS serializer_class = common.serializers.ConfigSerializer - permission_classes = [IsSuperuser, ] + permission_classes = [IsSuperuser] class ConfigDetail(RetrieveAPI): """Detail view for an individual configuration.""" serializer_class = common.serializers.ConfigSerializer - permission_classes = [IsSuperuser, ] + permission_classes = [IsSuperuser] def get_object(self): """Attempt to find a config object with the provided key.""" @@ -458,7 +437,7 @@ class NotesImageList(ListCreateAPI): queryset = common.models.NotesImage.objects.all() serializer_class = common.serializers.NotesImageSerializer - permission_classes = [permissions.IsAuthenticated, ] + permission_classes = [permissions.IsAuthenticated] def perform_create(self, serializer): """Create (upload) a new notes image""" @@ -475,14 +454,9 @@ class ProjectCodeList(ListCreateAPI): permission_classes = [permissions.IsAuthenticated, IsStaffOrReadOnly] filter_backends = SEARCH_ORDER_FILTER - ordering_fields = [ - 'code', - ] + ordering_fields = ['code'] - search_fields = [ - 'code', - 'description', - ] + search_fields = ['code', 'description'] class ProjectCodeDetail(RetrieveUpdateDestroyAPI): @@ -515,14 +489,14 @@ class FlagList(ListAPI): queryset = settings.FLAGS serializer_class = common.serializers.FlagSerializer - permission_classes = [permissions.AllowAny, ] + permission_classes = [permissions.AllowAny] class FlagDetail(RetrieveAPI): """Detail view for an individual feature flag.""" serializer_class = common.serializers.FlagSerializer - permission_classes = [permissions.AllowAny, ] + permission_classes = [permissions.AllowAny] def get_object(self): """Attempt to find a config object with the provided key.""" @@ -535,97 +509,175 @@ class FlagDetail(RetrieveAPI): settings_api_urls = [ # User settings - re_path(r'^user/', include([ - # User Settings Detail - re_path(r'^(?P\w+)/', UserSettingsDetail.as_view(), name='api-user-setting-detail'), - - # User Settings List - re_path(r'^.*$', UserSettingsList.as_view(), name='api-user-setting-list'), - ])), - + re_path( + r'^user/', + include([ + # User Settings Detail + re_path( + r'^(?P\w+)/', + UserSettingsDetail.as_view(), + name='api-user-setting-detail', + ), + # User Settings List + re_path(r'^.*$', UserSettingsList.as_view(), name='api-user-setting-list'), + ]), + ), # Notification settings - re_path(r'^notification/', include([ - # Notification Settings Detail - path(r'/', NotificationUserSettingsDetail.as_view(), name='api-notification-setting-detail'), - - # Notification Settings List - re_path(r'^.*$', NotificationUserSettingsList.as_view(), name='api-notification-setting-list'), - ])), - + re_path( + r'^notification/', + include([ + # Notification Settings Detail + path( + r'/', + NotificationUserSettingsDetail.as_view(), + name='api-notification-setting-detail', + ), + # Notification Settings List + re_path( + r'^.*$', + NotificationUserSettingsList.as_view(), + name='api-notification-setting-list', + ), + ]), + ), # Global settings - re_path(r'^global/', include([ - # Global Settings Detail - re_path(r'^(?P\w+)/', GlobalSettingsDetail.as_view(), name='api-global-setting-detail'), - - # Global Settings List - re_path(r'^.*$', GlobalSettingsList.as_view(), name='api-global-setting-list'), - ])), + re_path( + r'^global/', + include([ + # Global Settings Detail + re_path( + r'^(?P\w+)/', + GlobalSettingsDetail.as_view(), + name='api-global-setting-detail', + ), + # Global Settings List + re_path( + r'^.*$', GlobalSettingsList.as_view(), name='api-global-setting-list' + ), + ]), + ), ] common_api_urls = [ # Webhooks path('webhook//', WebhookView.as_view(), name='api-webhook'), - # Uploaded images for notes - re_path(r'^notes-image-upload/', NotesImageList.as_view(), name='api-notes-image-list'), - + re_path( + r'^notes-image-upload/', NotesImageList.as_view(), name='api-notes-image-list' + ), # Project codes - re_path(r'^project-code/', include([ - path(r'/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': common.models.ProjectCode}, name='api-project-code-metadata'), - re_path(r'^.*$', ProjectCodeDetail.as_view(), name='api-project-code-detail'), - ])), - re_path(r'^.*$', ProjectCodeList.as_view(), name='api-project-code-list'), - ])), - + re_path( + r'^project-code/', + include([ + path( + r'/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': common.models.ProjectCode}, + name='api-project-code-metadata', + ), + re_path( + r'^.*$', + ProjectCodeDetail.as_view(), + name='api-project-code-detail', + ), + ]), + ), + re_path(r'^.*$', ProjectCodeList.as_view(), name='api-project-code-list'), + ]), + ), # Custom physical units - re_path(r'^units/', include([ - path(r'/', include([ - re_path(r'^.*$', CustomUnitDetail.as_view(), name='api-custom-unit-detail'), - ])), - re_path(r'^.*$', CustomUnitList.as_view(), name='api-custom-unit-list'), - ])), - + re_path( + r'^units/', + include([ + path( + r'/', + include([ + re_path( + r'^.*$', + CustomUnitDetail.as_view(), + name='api-custom-unit-detail', + ) + ]), + ), + re_path(r'^.*$', CustomUnitList.as_view(), name='api-custom-unit-list'), + ]), + ), # Currencies - re_path(r'^currency/', include([ - re_path(r'^exchange/', CurrencyExchangeView.as_view(), name='api-currency-exchange'), - re_path(r'^refresh/', CurrencyRefreshView.as_view(), name='api-currency-refresh'), - ])), - + re_path( + r'^currency/', + include([ + re_path( + r'^exchange/', + CurrencyExchangeView.as_view(), + name='api-currency-exchange', + ), + re_path( + r'^refresh/', CurrencyRefreshView.as_view(), name='api-currency-refresh' + ), + ]), + ), # Notifications - re_path(r'^notifications/', include([ - # Individual purchase order detail URLs - path(r'/', include([ - re_path(r'.*$', NotificationDetail.as_view(), name='api-notifications-detail'), - ])), - # Read all - re_path(r'^readall/', NotificationReadAll.as_view(), name='api-notifications-readall'), - - # Notification messages list - re_path(r'^.*$', NotificationList.as_view(), name='api-notifications-list'), - ])), - + re_path( + r'^notifications/', + include([ + # Individual purchase order detail URLs + path( + r'/', + include([ + re_path( + r'.*$', + NotificationDetail.as_view(), + name='api-notifications-detail', + ) + ]), + ), + # Read all + re_path( + r'^readall/', + NotificationReadAll.as_view(), + name='api-notifications-readall', + ), + # Notification messages list + re_path(r'^.*$', NotificationList.as_view(), name='api-notifications-list'), + ]), + ), # News - re_path(r'^news/', include([ - path(r'/', include([ - re_path(r'.*$', NewsFeedEntryDetail.as_view(), name='api-news-detail'), - ])), - re_path(r'^.*$', NewsFeedEntryList.as_view(), name='api-news-list'), - ])), - + re_path( + r'^news/', + include([ + path( + r'/', + include([ + re_path( + r'.*$', NewsFeedEntryDetail.as_view(), name='api-news-detail' + ) + ]), + ), + re_path(r'^.*$', NewsFeedEntryList.as_view(), name='api-news-list'), + ]), + ), # Flags - path('flags/', include([ - path('/', FlagDetail.as_view(), name='api-flag-detail'), - re_path(r'^.*$', FlagList.as_view(), name='api-flag-list'), - ])), - + path( + 'flags/', + include([ + path('/', FlagDetail.as_view(), name='api-flag-detail'), + re_path(r'^.*$', FlagList.as_view(), name='api-flag-list'), + ]), + ), # Status - path('generic/status/', include([ - path(f'/', include([ - path('', StatusView.as_view(), name='api-status'), - ])), - path('', AllStatusViews.as_view(), name='api-status-all'), - ])), + path( + 'generic/status/', + include([ + path( + f'/', + include([path('', StatusView.as_view(), name='api-status')]), + ), + path('', AllStatusViews.as_view(), name='api-status-all'), + ]), + ), ] admin_api_urls = [ diff --git a/InvenTree/common/apps.py b/InvenTree/common/apps.py index e43789e99e..558d817793 100644 --- a/InvenTree/common/apps.py +++ b/InvenTree/common/apps.py @@ -30,10 +30,14 @@ class CommonConfig(AppConfig): try: import common.models - if common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED', backup_value=False, create=False, cache=False): - logger.info("Clearing SERVER_RESTART_REQUIRED flag") + if common.models.InvenTreeSetting.get_setting( + 'SERVER_RESTART_REQUIRED', backup_value=False, create=False, cache=False + ): + logger.info('Clearing SERVER_RESTART_REQUIRED flag') if not InvenTree.ready.isImportingData(): - common.models.InvenTreeSetting.set_setting('SERVER_RESTART_REQUIRED', False, None) + common.models.InvenTreeSetting.set_setting( + 'SERVER_RESTART_REQUIRED', False, None + ) except Exception: pass diff --git a/InvenTree/common/files.py b/InvenTree/common/files.py index d18ca34c98..37d85f46aa 100644 --- a/InvenTree/common/files.py +++ b/InvenTree/common/files.py @@ -49,12 +49,12 @@ class FileManager: ext = os.path.splitext(file.name)[-1].lower().replace('.', '') try: - if ext in ['csv', 'tsv', ]: + if ext in ['csv', 'tsv']: # These file formats need string decoding raw_data = file.read().decode('utf-8') # Reset stream position to beginning of file file.seek(0) - elif ext in ['xls', 'xlsx', 'json', 'yaml', ]: + elif ext in ['xls', 'xlsx', 'json', 'yaml']: raw_data = file.read() # Reset stream position to beginning of file file.seek(0) @@ -81,7 +81,12 @@ class FileManager: def update_headers(self): """Update headers.""" - self.HEADERS = self.REQUIRED_HEADERS + self.ITEM_MATCH_HEADERS + self.OPTIONAL_MATCH_HEADERS + self.OPTIONAL_HEADERS + self.HEADERS = ( + self.REQUIRED_HEADERS + + self.ITEM_MATCH_HEADERS + + self.OPTIONAL_MATCH_HEADERS + + self.OPTIONAL_HEADERS + ) def setup(self): """Setup headers should be overridden in usage to set the Different Headers.""" @@ -149,15 +154,9 @@ class FileManager: break if not guess_exists: - headers.append({ - 'name': header, - 'guess': guess - }) + headers.append({'name': header, 'guess': guess}) else: - headers.append({ - 'name': header, - 'guess': None - }) + headers.append({'name': header, 'guess': None}) return headers @@ -180,7 +179,6 @@ class FileManager: rows = [] for i in range(self.row_count()): - data = list(self.get_row_data(i)) # Is the row completely empty? Skip! @@ -203,10 +201,7 @@ class FileManager: if empty: continue - row = { - 'data': data, - 'index': i - } + row = {'data': data, 'index': i} rows.append(row) diff --git a/InvenTree/common/forms.py b/InvenTree/common/forms.py index b58f1ff352..b991eec06b 100644 --- a/InvenTree/common/forms.py +++ b/InvenTree/common/forms.py @@ -9,10 +9,7 @@ from .files import FileManager class UploadFileForm(forms.Form): """Step 1 of FileManagementFormView.""" - file = forms.FileField( - label=_('File'), - help_text=_('Select file to upload'), - ) + file = forms.FileField(label=_('File'), help_text=_('Select file to upload')) def __init__(self, *args, **kwargs): """Update label and help_text.""" @@ -67,9 +64,7 @@ class MatchFieldForm(forms.Form): self.fields[field_name] = forms.ChoiceField( choices=[('', '-' * 10)] + headers_choices, required=False, - widget=forms.Select(attrs={ - 'class': 'select fieldselect', - }) + widget=forms.Select(attrs={'class': 'select fieldselect'}), ) if col['guess']: self.fields[field_name].initial = col['guess'] @@ -107,7 +102,9 @@ class MatchItemForm(forms.Form): field_name = col_guess.lower() + '-' + str(row['index']) # check if field def was overridden - overriden_field = self.get_special_field(col_guess, row, file_manager) + overriden_field = self.get_special_field( + col_guess, row, file_manager + ) if overriden_field: self.fields[field_name] = overriden_field @@ -117,23 +114,23 @@ class MatchItemForm(forms.Form): value = row.get(col_guess.lower(), '') # Set field input box self.fields[field_name] = forms.CharField( - required=True, - initial=value, + required=True, initial=value ) # Create item selection box elif col_guess in file_manager.OPTIONAL_MATCH_HEADERS: # Get item options - item_options = [(option.id, option) for option in row['match_options_' + col_guess]] + item_options = [ + (option.id, option) + for option in row['match_options_' + col_guess] + ] # Get item match item_match = row['match_' + col_guess] # Set field select box self.fields[field_name] = forms.ChoiceField( choices=[('', '-' * 10)] + item_options, required=False, - widget=forms.Select(attrs={ - 'class': 'select bomselect', - }) + widget=forms.Select(attrs={'class': 'select bomselect'}), ) # Update select box when match was found if item_match: @@ -142,7 +139,9 @@ class MatchItemForm(forms.Form): # Create item selection box elif col_guess in file_manager.ITEM_MATCH_HEADERS: # Get item options - item_options = [(option.id, option) for option in row['item_options']] + item_options = [ + (option.id, option) for option in row['item_options'] + ] # Get item match item_match = row['item_match'] # Set field name @@ -151,9 +150,7 @@ class MatchItemForm(forms.Form): self.fields[field_name] = forms.ChoiceField( choices=[('', '-' * 10)] + item_options, required=False, - widget=forms.Select(attrs={ - 'class': 'select bomselect', - }) + widget=forms.Select(attrs={'class': 'select bomselect'}), ) # Update select box when match was found if item_match: @@ -169,8 +166,7 @@ class MatchItemForm(forms.Form): value = row.get(col_guess.lower(), '') # Set field input box self.fields[field_name] = forms.CharField( - required=False, - initial=value, + required=False, initial=value ) def get_special_field(self, col_guess, row, file_manager): diff --git a/InvenTree/common/models.py b/InvenTree/common/models.py index 620f598970..71cb025302 100644 --- a/InvenTree/common/models.py +++ b/InvenTree/common/models.py @@ -27,8 +27,7 @@ from django.contrib.humanize.templatetags.humanize import naturaltime from django.contrib.sites.models import Site from django.core.cache import cache from django.core.exceptions import AppRegistryNotReady, ValidationError -from django.core.validators import (MaxValueValidator, MinValueValidator, - URLValidator) +from django.core.validators import MaxValueValidator, MinValueValidator, URLValidator from django.db import models, transaction from django.db.models.signals import post_delete, post_save from django.db.utils import IntegrityError, OperationalError, ProgrammingError @@ -66,6 +65,7 @@ class MetaMixin(models.Model): class Meta: """Meta options for MetaMixin.""" + abstract = True updated = models.DateTimeField( @@ -88,7 +88,13 @@ class BaseURLValidator(URLValidator): super().__init__(schemes, **kwargs) # Override default host_re value - allow optional tld regex - self.host_re = '(' + self.hostname_re + self.domain_re + f'({self.tld_re})?' + '|localhost)' + self.host_re = ( + '(' + + self.hostname_re + + self.domain_re + + f'({self.tld_re})?' + + '|localhost)' + ) def __call__(self, value): """Make sure empty values pass.""" @@ -130,7 +136,8 @@ class ProjectCode(InvenTree.models.MetadataMixin, models.Model): responsible = models.ForeignKey( users.models.Owner, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Responsible'), help_text=_('User or group responsible for this project'), related_name='project_codes', @@ -213,7 +220,7 @@ class BaseInvenTreeSetting(models.Model): If a particular setting is not present, create it with the default value """ - cache_key = f"BUILD_DEFAULT_VALUES:{str(cls.__name__)}" + cache_key = f'BUILD_DEFAULT_VALUES:{str(cls.__name__)}' if InvenTree.helpers.str2bool(cache.get(cache_key, False)): # Already built default values @@ -226,16 +233,18 @@ class BaseInvenTreeSetting(models.Model): missing_keys = set(settings_keys) - set(existing_keys) if len(missing_keys) > 0: - logger.info("Building %s default values for %s", len(missing_keys), str(cls)) + logger.info( + 'Building %s default values for %s', len(missing_keys), str(cls) + ) cls.objects.bulk_create([ - cls( - key=key, - value=cls.get_setting_default(key), - **kwargs - ) for key in missing_keys if not key.startswith('_') + cls(key=key, value=cls.get_setting_default(key), **kwargs) + for key in missing_keys + if not key.startswith('_') ]) except Exception as exc: - logger.exception("Failed to build default values for %s (%s)", str(cls), str(type(exc))) + logger.exception( + 'Failed to build default values for %s (%s)', str(cls), str(type(exc)) + ) pass cache.set(cache_key, True, timeout=3600) @@ -249,7 +258,9 @@ class BaseInvenTreeSetting(models.Model): kwargs: Keyword arguments to pass to the function """ # Get action - setting = self.get_setting_definition(self.key, *args, **{**self.get_filters_for_instance(), **kwargs}) + setting = self.get_setting_definition( + self.key, *args, **{**self.get_filters_for_instance(), **kwargs} + ) settings_fnc = setting.get(reference, None) # Execute if callable @@ -259,7 +270,9 @@ class BaseInvenTreeSetting(models.Model): @property def cache_key(self): """Generate a unique cache key for this settings object""" - return self.__class__.create_cache_key(self.key, **self.get_filters_for_instance()) + return self.__class__.create_cache_key( + self.key, **self.get_filters_for_instance() + ) def save_to_cache(self): """Save this setting object to cache""" @@ -272,11 +285,7 @@ class BaseInvenTreeSetting(models.Model): logger.debug("Saving setting '%s' to cache", ckey) try: - cache.set( - ckey, - self, - timeout=3600 - ) + cache.set(ckey, self, timeout=3600) except TypeError: # Some characters cause issues with caching; ignore and move on pass @@ -290,24 +299,38 @@ class BaseInvenTreeSetting(models.Model): - The unique KEY string - Any key:value kwargs associated with the particular setting type (e.g. user-id) """ - key = f"{str(cls.__name__)}:{setting_key}" + key = f'{str(cls.__name__)}:{setting_key}' for k, v in kwargs.items(): - key += f"_{k}:{v}" + key += f'_{k}:{v}' - return key.replace(" ", "") + return key.replace(' ', '') @classmethod def get_filters(cls, **kwargs): """Enable to filter by other kwargs defined in cls.extra_unique_fields""" - return {key: value for key, value in kwargs.items() if key in cls.extra_unique_fields} + return { + key: value + for key, value in kwargs.items() + if key in cls.extra_unique_fields + } def get_filters_for_instance(self): """Enable to filter by other fields defined in self.extra_unique_fields""" - return {key: getattr(self, key, None) for key in self.extra_unique_fields if hasattr(self, key)} + return { + key: getattr(self, key, None) + for key in self.extra_unique_fields + if hasattr(self, key) + } @classmethod - def all_settings(cls, *, exclude_hidden=False, settings_definition: Union[Dict[str, SettingsKeyType], None] = None, **kwargs): + def all_settings( + cls, + *, + exclude_hidden=False, + settings_definition: Union[Dict[str, SettingsKeyType], None] = None, + **kwargs, + ): """Return a list of "all" defined settings. This performs a single database lookup, @@ -339,18 +362,18 @@ class BaseInvenTreeSetting(models.Model): settings[key.upper()] = cls( key=key.upper(), value=cls.get_setting_default(key, **filters), - **filters + **filters, ) # remove any hidden settings - if exclude_hidden and setting.get("hidden", False): + if exclude_hidden and setting.get('hidden', False): del settings[key.upper()] # format settings values and remove protected for key, setting in settings.items(): validator = cls.get_setting_validator(key, **filters) - if cls.is_protected(key, **filters) and setting.value != "": + if cls.is_protected(key, **filters) and setting.value != '': setting.value = '***' elif cls.validator_is_bool(validator): setting.value = InvenTree.helpers.str2bool(setting.value) @@ -363,14 +386,24 @@ class BaseInvenTreeSetting(models.Model): return settings @classmethod - def allValues(cls, *, exclude_hidden=False, settings_definition: Union[Dict[str, SettingsKeyType], None] = None, **kwargs): + def allValues( + cls, + *, + exclude_hidden=False, + settings_definition: Union[Dict[str, SettingsKeyType], None] = None, + **kwargs, + ): """Return a dict of "all" defined global settings. This performs a single database lookup, and then any settings which are not *in* the database are assigned their default values """ - all_settings = cls.all_settings(exclude_hidden=exclude_hidden, settings_definition=settings_definition, **kwargs) + all_settings = cls.all_settings( + exclude_hidden=exclude_hidden, + settings_definition=settings_definition, + **kwargs, + ) settings: Dict[str, Any] = {} @@ -380,14 +413,24 @@ class BaseInvenTreeSetting(models.Model): return settings @classmethod - def check_all_settings(cls, *, exclude_hidden=False, settings_definition: Union[Dict[str, SettingsKeyType], None] = None, **kwargs): + def check_all_settings( + cls, + *, + exclude_hidden=False, + settings_definition: Union[Dict[str, SettingsKeyType], None] = None, + **kwargs, + ): """Check if all required settings are set by definition. Returns: is_valid: Are all required settings defined missing_settings: List of all settings that are missing (empty if is_valid is 'True') """ - all_settings = cls.all_settings(exclude_hidden=exclude_hidden, settings_definition=settings_definition, **kwargs) + all_settings = cls.all_settings( + exclude_hidden=exclude_hidden, + settings_definition=settings_definition, + **kwargs, + ) missing_settings: List[str] = [] @@ -395,7 +438,7 @@ class BaseInvenTreeSetting(models.Model): if setting.required: value = setting.value or cls.get_setting_default(setting.key, **kwargs) - if value == "": + if value == '': missing_settings.append(setting.key.upper()) return len(missing_settings) == 0, missing_settings @@ -496,7 +539,6 @@ class BaseInvenTreeSetting(models.Model): filters = { 'key__iexact': key, - # Optionally filter by other keys **cls.get_filters(**kwargs), } @@ -539,9 +581,13 @@ class BaseInvenTreeSetting(models.Model): # Setting does not exist! (Try to create it) if not setting: - # Prevent creation of new settings objects when importing data - if InvenTree.ready.isImportingData() or not InvenTree.ready.canAppAccessDatabase(allow_test=True, allow_shell=True): + if ( + InvenTree.ready.isImportingData() + or not InvenTree.ready.canAppAccessDatabase( + allow_test=True, allow_shell=True + ) + ): create = False if create: @@ -549,11 +595,7 @@ class BaseInvenTreeSetting(models.Model): default_value = cls.get_setting_default(key, **kwargs) - setting = cls( - key=key, - value=default_value, - **kwargs - ) + setting = cls(key=key, value=default_value, **kwargs) try: # Wrap this statement in "atomic", so it can be rolled back if it fails @@ -620,7 +662,6 @@ class BaseInvenTreeSetting(models.Model): filters = { 'key__iexact': key, - # Optionally filter by other keys **cls.get_filters(**kwargs), } @@ -628,7 +669,6 @@ class BaseInvenTreeSetting(models.Model): try: setting = cls.objects.get(**filters) except cls.DoesNotExist: - if create: setting = cls(key=key, **kwargs) else: @@ -648,36 +688,62 @@ class BaseInvenTreeSetting(models.Model): # Likely a race condition has caused a duplicate entry to be created if attempts > 0: # Try again - logger.info("Duplicate setting key '%s' for %s - trying again", key, str(cls)) - cls.set_setting(key, value, change_user, create=create, attempts=attempts - 1, **kwargs) + logger.info( + "Duplicate setting key '%s' for %s - trying again", key, str(cls) + ) + cls.set_setting( + key, + value, + change_user, + create=create, + attempts=attempts - 1, + **kwargs, + ) except Exception as exc: # Some other error - logger.exception("Error setting setting '%s' for %s: %s", key, str(cls), str(type(exc))) + logger.exception( + "Error setting setting '%s' for %s: %s", key, str(cls), str(type(exc)) + ) pass - key = models.CharField(max_length=50, blank=False, unique=False, help_text=_('Settings key (must be unique - case insensitive)')) + key = models.CharField( + max_length=50, + blank=False, + unique=False, + help_text=_('Settings key (must be unique - case insensitive)'), + ) - value = models.CharField(max_length=2000, blank=True, unique=False, help_text=_('Settings value')) + value = models.CharField( + max_length=2000, blank=True, unique=False, help_text=_('Settings value') + ) @property def name(self): """Return name for setting.""" - return self.__class__.get_setting_name(self.key, **self.get_filters_for_instance()) + return self.__class__.get_setting_name( + self.key, **self.get_filters_for_instance() + ) @property def default_value(self): """Return default_value for setting.""" - return self.__class__.get_setting_default(self.key, **self.get_filters_for_instance()) + return self.__class__.get_setting_default( + self.key, **self.get_filters_for_instance() + ) @property def description(self): """Return description for setting.""" - return self.__class__.get_setting_description(self.key, **self.get_filters_for_instance()) + return self.__class__.get_setting_description( + self.key, **self.get_filters_for_instance() + ) @property def units(self): """Return units for setting.""" - return self.__class__.get_setting_units(self.key, **self.get_filters_for_instance()) + return self.__class__.get_setting_units( + self.key, **self.get_filters_for_instance() + ) def clean(self): """If a validator (or multiple validators) are defined for a particular setting key, run them against the 'value' field.""" @@ -690,7 +756,9 @@ class BaseInvenTreeSetting(models.Model): elif self.is_bool(): self.value = self.as_bool() - validator = self.__class__.get_setting_validator(self.key, **self.get_filters_for_instance()) + validator = self.__class__.get_setting_validator( + self.key, **self.get_filters_for_instance() + ) if validator is not None: self.run_validator(validator) @@ -698,7 +766,7 @@ class BaseInvenTreeSetting(models.Model): options = self.valid_options() if options and self.value not in options: - raise ValidationError(_("Chosen value is not a valid option")) + raise ValidationError(_('Chosen value is not a valid option')) def run_validator(self, validator): """Run a validator against the 'value' field for this InvenTreeSetting object.""" @@ -714,20 +782,15 @@ class BaseInvenTreeSetting(models.Model): # Coerce into either "True" or "False" value = InvenTree.helpers.str2bool(value) else: - raise ValidationError({ - 'value': _('Value must be a boolean value') - }) + raise ValidationError({'value': _('Value must be a boolean value')}) # Integer validator if validator is int: - try: # Coerce into an integer value value = int(value) except (ValueError, TypeError): - raise ValidationError({ - 'value': _('Value must be an integer value'), - }) + raise ValidationError({'value': _('Value must be an integer value')}) # If a list of validators is supplied, iterate through each one if type(validator) in [list, tuple]: @@ -755,7 +818,6 @@ class BaseInvenTreeSetting(models.Model): filters = { 'key__iexact': self.key, - # Optionally filter by other keys **self.get_filters_for_instance(), } @@ -772,7 +834,9 @@ class BaseInvenTreeSetting(models.Model): def choices(self): """Return the available choices for this setting (or None if no choices are defined).""" - return self.__class__.get_setting_choices(self.key, **self.get_filters_for_instance()) + return self.__class__.get_setting_choices( + self.key, **self.get_filters_for_instance() + ) def valid_options(self): """Return a list of valid options for this setting.""" @@ -785,7 +849,12 @@ class BaseInvenTreeSetting(models.Model): def is_choice(self): """Check if this setting is a "choice" field.""" - return self.__class__.get_setting_choices(self.key, **self.get_filters_for_instance()) is not None + return ( + self.__class__.get_setting_choices( + self.key, **self.get_filters_for_instance() + ) + is not None + ) def as_choice(self): """Render this setting as the "display" value of a choice field. @@ -812,7 +881,9 @@ class BaseInvenTreeSetting(models.Model): def model_name(self): """Return the model name associated with this setting.""" - setting = self.get_setting_definition(self.key, **self.get_filters_for_instance()) + setting = self.get_setting_definition( + self.key, **self.get_filters_for_instance() + ) return setting.get('model', None) @@ -832,19 +903,31 @@ class BaseInvenTreeSetting(models.Model): try: (app, mdl) = model_name.strip().split('.') except ValueError: - logger.exception("Invalid 'model' parameter for setting '%s': '%s'", self.key, model_name) + logger.exception( + "Invalid 'model' parameter for setting '%s': '%s'", self.key, model_name + ) return None app_models = apps.all_models.get(app, None) if app_models is None: - logger.error("Error retrieving model class '%s' for setting '%s' - no app named '%s'", model_name, self.key, app) + logger.error( + "Error retrieving model class '%s' for setting '%s' - no app named '%s'", + model_name, + self.key, + app, + ) return None model = app_models.get(mdl, None) if model is None: - logger.error("Error retrieving model class '%s' for setting '%s' - no model named '%s'", model_name, self.key, mdl) + logger.error( + "Error retrieving model class '%s' for setting '%s' - no model named '%s'", + model_name, + self.key, + mdl, + ) return None # Looks like we have found a model! @@ -867,7 +950,9 @@ class BaseInvenTreeSetting(models.Model): 'auth.group': 'api-group-list', } - model_table = f'{model_class._meta.app_label}.{model_class._meta.model_name}' + model_table = ( + f'{model_class._meta.app_label}.{model_class._meta.model_name}' + ) if url := hardcoded_models[model_table]: return reverse(url) @@ -876,7 +961,9 @@ class BaseInvenTreeSetting(models.Model): def is_bool(self): """Check if this setting is required to be a boolean value.""" - validator = self.__class__.get_setting_validator(self.key, **self.get_filters_for_instance()) + validator = self.__class__.get_setting_validator( + self.key, **self.get_filters_for_instance() + ) return self.__class__.validator_is_bool(validator) @@ -912,9 +999,11 @@ class BaseInvenTreeSetting(models.Model): return False - def is_int(self,): + def is_int(self): """Check if the setting is required to be an integer value.""" - validator = self.__class__.get_setting_validator(self.key, **self.get_filters_for_instance()) + validator = self.__class__.get_setting_validator( + self.key, **self.get_filters_for_instance() + ) return self.__class__.validator_is_int(validator) @@ -960,7 +1049,7 @@ class BaseInvenTreeSetting(models.Model): """Check if this setting value is required.""" setting = cls.get_setting_definition(key, **cls.get_filters(**kwargs)) - return setting.get("required", False) + return setting.get('required', False) @property def required(self): @@ -1004,13 +1093,12 @@ def currency_exchange_plugins(): """Return a set of plugin choices which can be used for currency exchange""" try: from plugin import registry + plugs = registry.with_mixin('currencyexchange', active=True) except Exception: plugs = [] - return [ - ('', _('No plugin')), - ] + [(plug.slug, plug.human_name) for plug in plugs] + return [('', _('No plugin'))] + [(plug.slug, plug.human_name) for plug in plugs] def update_exchange_rates(setting): @@ -1043,8 +1131,8 @@ class InvenTreeSetting(BaseInvenTreeSetting): class Meta: """Meta options for InvenTreeSetting.""" - verbose_name = "InvenTree Setting" - verbose_name_plural = "InvenTree Settings" + verbose_name = 'InvenTree Setting' + verbose_name_plural = 'InvenTree Settings' def save(self, *args, **kwargs): """When saving a global setting, check to see if it requires a server restart. @@ -1073,49 +1161,44 @@ class InvenTreeSetting(BaseInvenTreeSetting): """ SETTINGS = { - 'SERVER_RESTART_REQUIRED': { 'name': _('Restart required'), - 'description': _('A setting has been changed which requires a server restart'), + 'description': _( + 'A setting has been changed which requires a server restart' + ), 'default': False, 'validator': bool, 'hidden': True, }, - '_PENDING_MIGRATIONS': { 'name': _('Pending migrations'), 'description': _('Number of pending database migrations'), 'default': 0, 'validator': int, }, - 'INVENTREE_INSTANCE': { 'name': _('Server Instance Name'), 'default': 'InvenTree', 'description': _('String descriptor for the server instance'), 'after_save': update_instance_name, }, - 'INVENTREE_INSTANCE_TITLE': { 'name': _('Use instance name'), 'description': _('Use the instance name in the title-bar'), 'validator': bool, 'default': False, }, - 'INVENTREE_RESTRICT_ABOUT': { 'name': _('Restrict showing `about`'), 'description': _('Show the `about` modal only to superusers'), 'validator': bool, 'default': False, }, - 'INVENTREE_COMPANY_NAME': { 'name': _('Company name'), 'description': _('Internal company name'), 'default': 'My company name', }, - 'INVENTREE_BASE_URL': { 'name': _('Base URL'), 'description': _('Base URL for server instance'), @@ -1123,7 +1206,6 @@ class InvenTreeSetting(BaseInvenTreeSetting): 'default': '', 'after_save': update_instance_url, }, - 'INVENTREE_DEFAULT_CURRENCY': { 'name': _('Default Currency'), 'description': _('Select base currency for pricing calculations'), @@ -1131,597 +1213,534 @@ class InvenTreeSetting(BaseInvenTreeSetting): 'choices': CURRENCY_CHOICES, 'after_save': update_exchange_rates, }, - 'CURRENCY_UPDATE_INTERVAL': { 'name': _('Currency Update Interval'), - 'description': _('How often to update exchange rates (set to zero to disable)'), + 'description': _( + 'How often to update exchange rates (set to zero to disable)' + ), 'default': 1, 'units': _('days'), - 'validator': [ - int, - MinValueValidator(0), - ], + 'validator': [int, MinValueValidator(0)], }, - 'CURRENCY_UPDATE_PLUGIN': { 'name': _('Currency Update Plugin'), 'description': _('Currency update plugin to use'), 'choices': currency_exchange_plugins, - 'default': 'inventreecurrencyexchange' + 'default': 'inventreecurrencyexchange', }, - 'INVENTREE_DOWNLOAD_FROM_URL': { 'name': _('Download from URL'), - 'description': _('Allow download of remote images and files from external URL'), + 'description': _( + 'Allow download of remote images and files from external URL' + ), 'validator': bool, 'default': False, }, - 'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE': { 'name': _('Download Size Limit'), 'description': _('Maximum allowable download size for remote image'), 'units': 'MB', 'default': 1, - 'validator': [ - int, - MinValueValidator(1), - MaxValueValidator(25), - ] + 'validator': [int, MinValueValidator(1), MaxValueValidator(25)], }, - 'INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT': { 'name': _('User-agent used to download from URL'), - 'description': _('Allow to override the user-agent used to download images and files from external URL (leave blank for the default)'), + 'description': _( + 'Allow to override the user-agent used to download images and files from external URL (leave blank for the default)' + ), 'default': '', }, - 'INVENTREE_STRICT_URLS': { 'name': _('Strict URL Validation'), 'description': _('Require schema specification when validating URLs'), 'validator': bool, 'default': True, }, - 'INVENTREE_REQUIRE_CONFIRM': { 'name': _('Require confirm'), 'description': _('Require explicit user confirmation for certain action.'), 'validator': bool, 'default': True, }, - 'INVENTREE_TREE_DEPTH': { 'name': _('Tree Depth'), - 'description': _('Default tree depth for treeview. Deeper levels can be lazy loaded as they are needed.'), + 'description': _( + 'Default tree depth for treeview. Deeper levels can be lazy loaded as they are needed.' + ), 'default': 1, - 'validator': [ - int, - MinValueValidator(0), - ] + 'validator': [int, MinValueValidator(0)], }, - 'INVENTREE_UPDATE_CHECK_INTERVAL': { 'name': _('Update Check Interval'), 'description': _('How often to check for updates (set to zero to disable)'), - 'validator': [ - int, - MinValueValidator(0), - ], + 'validator': [int, MinValueValidator(0)], 'default': 7, 'units': _('days'), }, - 'INVENTREE_BACKUP_ENABLE': { 'name': _('Automatic Backup'), 'description': _('Enable automatic backup of database and media files'), 'validator': bool, 'default': False, }, - 'INVENTREE_BACKUP_DAYS': { 'name': _('Auto Backup Interval'), 'description': _('Specify number of days between automated backup events'), - 'validator': [ - int, - MinValueValidator(1), - ], + 'validator': [int, MinValueValidator(1)], 'default': 1, 'units': _('days'), }, - 'INVENTREE_DELETE_TASKS_DAYS': { 'name': _('Task Deletion Interval'), - 'description': _('Background task results will be deleted after specified number of days'), + 'description': _( + 'Background task results will be deleted after specified number of days' + ), 'default': 30, 'units': _('days'), - 'validator': [ - int, - MinValueValidator(7), - ] + 'validator': [int, MinValueValidator(7)], }, - 'INVENTREE_DELETE_ERRORS_DAYS': { 'name': _('Error Log Deletion Interval'), - 'description': _('Error logs will be deleted after specified number of days'), + 'description': _( + 'Error logs will be deleted after specified number of days' + ), 'default': 30, 'units': _('days'), - 'validator': [ - int, - MinValueValidator(7) - ] + 'validator': [int, MinValueValidator(7)], }, - 'INVENTREE_DELETE_NOTIFICATIONS_DAYS': { 'name': _('Notification Deletion Interval'), - 'description': _('User notifications will be deleted after specified number of days'), + 'description': _( + 'User notifications will be deleted after specified number of days' + ), 'default': 30, 'units': _('days'), - 'validator': [ - int, - MinValueValidator(7), - ] + 'validator': [int, MinValueValidator(7)], }, - 'BARCODE_ENABLE': { 'name': _('Barcode Support'), 'description': _('Enable barcode scanner support in the web interface'), 'default': True, 'validator': bool, }, - 'BARCODE_INPUT_DELAY': { 'name': _('Barcode Input Delay'), 'description': _('Barcode input processing delay time'), 'default': 50, - 'validator': [ - int, - MinValueValidator(1), - ], + 'validator': [int, MinValueValidator(1)], 'units': 'ms', }, - 'BARCODE_WEBCAM_SUPPORT': { 'name': _('Barcode Webcam Support'), 'description': _('Allow barcode scanning via webcam in browser'), 'default': True, 'validator': bool, }, - 'PART_ENABLE_REVISION': { 'name': _('Part Revisions'), 'description': _('Enable revision field for Part'), 'validator': bool, 'default': True, }, - 'PART_IPN_REGEX': { 'name': _('IPN Regex'), - 'description': _('Regular expression pattern for matching Part IPN') + 'description': _('Regular expression pattern for matching Part IPN'), }, - 'PART_ALLOW_DUPLICATE_IPN': { 'name': _('Allow Duplicate IPN'), 'description': _('Allow multiple parts to share the same IPN'), 'default': True, 'validator': bool, }, - 'PART_ALLOW_EDIT_IPN': { 'name': _('Allow Editing IPN'), 'description': _('Allow changing the IPN value while editing a part'), 'default': True, 'validator': bool, }, - 'PART_COPY_BOM': { 'name': _('Copy Part BOM Data'), 'description': _('Copy BOM data by default when duplicating a part'), 'default': True, 'validator': bool, }, - 'PART_COPY_PARAMETERS': { 'name': _('Copy Part Parameter Data'), 'description': _('Copy parameter data by default when duplicating a part'), 'default': True, 'validator': bool, }, - 'PART_COPY_TESTS': { 'name': _('Copy Part Test Data'), 'description': _('Copy test data by default when duplicating a part'), 'default': True, - 'validator': bool + 'validator': bool, }, - 'PART_CATEGORY_PARAMETERS': { 'name': _('Copy Category Parameter Templates'), 'description': _('Copy category parameter templates when creating a part'), 'default': True, - 'validator': bool + 'validator': bool, }, - 'PART_TEMPLATE': { 'name': _('Template'), 'description': _('Parts are templates by default'), 'default': False, 'validator': bool, }, - 'PART_ASSEMBLY': { 'name': _('Assembly'), 'description': _('Parts can be assembled from other components by default'), 'default': False, 'validator': bool, }, - 'PART_COMPONENT': { 'name': _('Component'), 'description': _('Parts can be used as sub-components by default'), 'default': True, 'validator': bool, }, - 'PART_PURCHASEABLE': { 'name': _('Purchaseable'), 'description': _('Parts are purchaseable by default'), 'default': True, 'validator': bool, }, - 'PART_SALABLE': { 'name': _('Salable'), 'description': _('Parts are salable by default'), 'default': False, 'validator': bool, }, - 'PART_TRACKABLE': { 'name': _('Trackable'), 'description': _('Parts are trackable by default'), 'default': False, 'validator': bool, }, - 'PART_VIRTUAL': { 'name': _('Virtual'), 'description': _('Parts are virtual by default'), 'default': False, 'validator': bool, }, - 'PART_SHOW_IMPORT': { 'name': _('Show Import in Views'), 'description': _('Display the import wizard in some part views'), 'default': False, 'validator': bool, }, - 'PART_SHOW_RELATED': { 'name': _('Show related parts'), 'description': _('Display related parts for a part'), 'default': True, 'validator': bool, }, - 'PART_CREATE_INITIAL': { 'name': _('Initial Stock Data'), 'description': _('Allow creation of initial stock when adding a new part'), 'default': False, 'validator': bool, }, - 'PART_CREATE_SUPPLIER': { 'name': _('Initial Supplier Data'), - 'description': _('Allow creation of initial supplier data when adding a new part'), + 'description': _( + 'Allow creation of initial supplier data when adding a new part' + ), 'default': True, 'validator': bool, }, - 'PART_NAME_FORMAT': { 'name': _('Part Name Display Format'), 'description': _('Format to display the part name'), 'default': "{{ part.IPN if part.IPN }}{{ ' | ' if part.IPN }}{{ part.name }}{{ ' | ' if part.revision }}" - "{{ part.revision if part.revision }}", - 'validator': InvenTree.validators.validate_part_name_format + '{{ part.revision if part.revision }}', + 'validator': InvenTree.validators.validate_part_name_format, }, - 'PART_CATEGORY_DEFAULT_ICON': { 'name': _('Part Category Default Icon'), 'description': _('Part category default icon (empty means no icon)'), 'default': '', }, - 'PART_PARAMETER_ENFORCE_UNITS': { 'name': _('Enforce Parameter Units'), - 'description': _('If units are provided, parameter values must match the specified units'), + 'description': _( + 'If units are provided, parameter values must match the specified units' + ), 'default': True, 'validator': bool, }, - 'PRICING_DECIMAL_PLACES_MIN': { 'name': _('Minimum Pricing Decimal Places'), - 'description': _('Minimum number of decimal places to display when rendering pricing data'), + 'description': _( + 'Minimum number of decimal places to display when rendering pricing data' + ), 'default': 0, - 'validator': [ - int, - MinValueValidator(0), - MaxValueValidator(4), - ] + 'validator': [int, MinValueValidator(0), MaxValueValidator(4)], }, - 'PRICING_DECIMAL_PLACES': { 'name': _('Maximum Pricing Decimal Places'), - 'description': _('Maximum number of decimal places to display when rendering pricing data'), + 'description': _( + 'Maximum number of decimal places to display when rendering pricing data' + ), 'default': 6, - 'validator': [ - int, - MinValueValidator(2), - MaxValueValidator(6) - ] + 'validator': [int, MinValueValidator(2), MaxValueValidator(6)], }, - 'PRICING_USE_SUPPLIER_PRICING': { 'name': _('Use Supplier Pricing'), - 'description': _('Include supplier price breaks in overall pricing calculations'), + 'description': _( + 'Include supplier price breaks in overall pricing calculations' + ), 'default': True, 'validator': bool, }, - 'PRICING_PURCHASE_HISTORY_OVERRIDES_SUPPLIER': { 'name': _('Purchase History Override'), - 'description': _('Historical purchase order pricing overrides supplier price breaks'), + 'description': _( + 'Historical purchase order pricing overrides supplier price breaks' + ), 'default': False, 'validator': bool, }, - 'PRICING_USE_STOCK_PRICING': { 'name': _('Use Stock Item Pricing'), - 'description': _('Use pricing from manually entered stock data for pricing calculations'), + 'description': _( + 'Use pricing from manually entered stock data for pricing calculations' + ), 'default': True, 'validator': bool, }, - 'PRICING_STOCK_ITEM_AGE_DAYS': { 'name': _('Stock Item Pricing Age'), - 'description': _('Exclude stock items older than this number of days from pricing calculations'), + 'description': _( + 'Exclude stock items older than this number of days from pricing calculations' + ), 'default': 0, 'units': _('days'), - 'validator': [ - int, - MinValueValidator(0), - ] + 'validator': [int, MinValueValidator(0)], }, - 'PRICING_USE_VARIANT_PRICING': { 'name': _('Use Variant Pricing'), 'description': _('Include variant pricing in overall pricing calculations'), 'default': True, 'validator': bool, }, - 'PRICING_ACTIVE_VARIANTS': { 'name': _('Active Variants Only'), - 'description': _('Only use active variant parts for calculating variant pricing'), + 'description': _( + 'Only use active variant parts for calculating variant pricing' + ), 'default': False, 'validator': bool, }, - 'PRICING_UPDATE_DAYS': { 'name': _('Pricing Rebuild Interval'), - 'description': _('Number of days before part pricing is automatically updated'), + 'description': _( + 'Number of days before part pricing is automatically updated' + ), 'units': _('days'), 'default': 30, - 'validator': [ - int, - MinValueValidator(10), - ] + 'validator': [int, MinValueValidator(10)], }, - 'PART_INTERNAL_PRICE': { 'name': _('Internal Prices'), 'description': _('Enable internal prices for parts'), 'default': False, - 'validator': bool + 'validator': bool, }, - 'PART_BOM_USE_INTERNAL_PRICE': { 'name': _('Internal Price Override'), - 'description': _('If available, internal prices override price range calculations'), + 'description': _( + 'If available, internal prices override price range calculations' + ), 'default': False, - 'validator': bool + 'validator': bool, }, - 'LABEL_ENABLE': { 'name': _('Enable label printing'), 'description': _('Enable label printing from the web interface'), 'default': True, 'validator': bool, }, - 'LABEL_DPI': { 'name': _('Label Image DPI'), - 'description': _('DPI resolution when generating image files to supply to label printing plugins'), + 'description': _( + 'DPI resolution when generating image files to supply to label printing plugins' + ), 'default': 300, - 'validator': [ - int, - MinValueValidator(100), - ] + 'validator': [int, MinValueValidator(100)], }, - 'REPORT_ENABLE': { 'name': _('Enable Reports'), 'description': _('Enable generation of reports'), 'default': False, 'validator': bool, }, - 'REPORT_DEBUG_MODE': { 'name': _('Debug Mode'), 'description': _('Generate reports in debug mode (HTML output)'), 'default': False, 'validator': bool, }, - 'REPORT_DEFAULT_PAGE_SIZE': { 'name': _('Page Size'), 'description': _('Default page size for PDF reports'), 'default': 'A4', 'choices': report.helpers.report_page_size_options, }, - 'REPORT_ENABLE_TEST_REPORT': { 'name': _('Enable Test Reports'), 'description': _('Enable generation of test reports'), 'default': True, 'validator': bool, }, - 'REPORT_ATTACH_TEST_REPORT': { 'name': _('Attach Test Reports'), - 'description': _('When printing a Test Report, attach a copy of the Test Report to the associated Stock Item'), + 'description': _( + 'When printing a Test Report, attach a copy of the Test Report to the associated Stock Item' + ), 'default': False, 'validator': bool, }, - 'SERIAL_NUMBER_GLOBALLY_UNIQUE': { 'name': _('Globally Unique Serials'), 'description': _('Serial numbers for stock items must be globally unique'), 'default': False, 'validator': bool, }, - 'SERIAL_NUMBER_AUTOFILL': { 'name': _('Autofill Serial Numbers'), 'description': _('Autofill serial numbers in forms'), 'default': False, 'validator': bool, }, - 'STOCK_DELETE_DEPLETED_DEFAULT': { 'name': _('Delete Depleted Stock'), - 'description': _('Determines default behaviour when a stock item is depleted'), + 'description': _( + 'Determines default behaviour when a stock item is depleted' + ), 'default': True, 'validator': bool, }, - 'STOCK_BATCH_CODE_TEMPLATE': { 'name': _('Batch Code Template'), - 'description': _('Template for generating default batch codes for stock items'), + 'description': _( + 'Template for generating default batch codes for stock items' + ), 'default': '', }, - 'STOCK_ENABLE_EXPIRY': { 'name': _('Stock Expiry'), 'description': _('Enable stock expiry functionality'), 'default': False, 'validator': bool, }, - 'STOCK_ALLOW_EXPIRED_SALE': { 'name': _('Sell Expired Stock'), 'description': _('Allow sale of expired stock'), 'default': False, 'validator': bool, }, - 'STOCK_STALE_DAYS': { 'name': _('Stock Stale Time'), - 'description': _('Number of days stock items are considered stale before expiring'), + 'description': _( + 'Number of days stock items are considered stale before expiring' + ), 'default': 0, 'units': _('days'), 'validator': [int], }, - 'STOCK_ALLOW_EXPIRED_BUILD': { 'name': _('Build Expired Stock'), 'description': _('Allow building with expired stock'), 'default': False, 'validator': bool, }, - 'STOCK_OWNERSHIP_CONTROL': { 'name': _('Stock Ownership Control'), 'description': _('Enable ownership control over stock locations and items'), 'default': False, 'validator': bool, }, - 'STOCK_LOCATION_DEFAULT_ICON': { 'name': _('Stock Location Default Icon'), 'description': _('Stock location default icon (empty means no icon)'), 'default': '', }, - 'STOCK_SHOW_INSTALLED_ITEMS': { 'name': _('Show Installed Stock Items'), 'description': _('Display installed stock items in stock tables'), 'default': False, 'validator': bool, }, - 'BUILDORDER_REFERENCE_PATTERN': { 'name': _('Build Order Reference Pattern'), - 'description': _('Required pattern for generating Build Order reference field'), + 'description': _( + 'Required pattern for generating Build Order reference field' + ), 'default': 'BO-{ref:04d}', 'validator': build.validators.validate_build_order_reference_pattern, }, - 'RETURNORDER_ENABLED': { 'name': _('Enable Return Orders'), 'description': _('Enable return order functionality in the user interface'), 'validator': bool, 'default': False, }, - 'RETURNORDER_REFERENCE_PATTERN': { 'name': _('Return Order Reference Pattern'), - 'description': _('Required pattern for generating Return Order reference field'), + 'description': _( + 'Required pattern for generating Return Order reference field' + ), 'default': 'RMA-{ref:04d}', 'validator': order.validators.validate_return_order_reference_pattern, }, - 'RETURNORDER_EDIT_COMPLETED_ORDERS': { 'name': _('Edit Completed Return Orders'), - 'description': _('Allow editing of return orders after they have been completed'), + 'description': _( + 'Allow editing of return orders after they have been completed' + ), 'default': False, 'validator': bool, }, - 'SALESORDER_REFERENCE_PATTERN': { 'name': _('Sales Order Reference Pattern'), - 'description': _('Required pattern for generating Sales Order reference field'), + 'description': _( + 'Required pattern for generating Sales Order reference field' + ), 'default': 'SO-{ref:04d}', 'validator': order.validators.validate_sales_order_reference_pattern, }, - 'SALESORDER_DEFAULT_SHIPMENT': { 'name': _('Sales Order Default Shipment'), 'description': _('Enable creation of default shipment with sales orders'), 'default': False, 'validator': bool, }, - 'SALESORDER_EDIT_COMPLETED_ORDERS': { 'name': _('Edit Completed Sales Orders'), - 'description': _('Allow editing of sales orders after they have been shipped or completed'), + 'description': _( + 'Allow editing of sales orders after they have been shipped or completed' + ), 'default': False, 'validator': bool, }, - 'PURCHASEORDER_REFERENCE_PATTERN': { 'name': _('Purchase Order Reference Pattern'), - 'description': _('Required pattern for generating Purchase Order reference field'), + 'description': _( + 'Required pattern for generating Purchase Order reference field' + ), 'default': 'PO-{ref:04d}', 'validator': order.validators.validate_purchase_order_reference_pattern, }, - 'PURCHASEORDER_EDIT_COMPLETED_ORDERS': { 'name': _('Edit Completed Purchase Orders'), - 'description': _('Allow editing of purchase orders after they have been shipped or completed'), + 'description': _( + 'Allow editing of purchase orders after they have been shipped or completed' + ), 'default': False, 'validator': bool, }, - # login / SSO 'LOGIN_ENABLE_PWD_FORGOT': { 'name': _('Enable password forgot'), @@ -1729,85 +1748,82 @@ class InvenTreeSetting(BaseInvenTreeSetting): 'default': True, 'validator': bool, }, - 'LOGIN_ENABLE_REG': { 'name': _('Enable registration'), 'description': _('Enable self-registration for users on the login pages'), 'default': False, 'validator': bool, }, - 'LOGIN_ENABLE_SSO': { 'name': _('Enable SSO'), 'description': _('Enable SSO on the login pages'), 'default': False, 'validator': bool, }, - 'LOGIN_ENABLE_SSO_REG': { 'name': _('Enable SSO registration'), - 'description': _('Enable self-registration via SSO for users on the login pages'), + 'description': _( + 'Enable self-registration via SSO for users on the login pages' + ), 'default': False, 'validator': bool, }, - 'LOGIN_MAIL_REQUIRED': { 'name': _('Email required'), 'description': _('Require user to supply mail on signup'), 'default': False, 'validator': bool, }, - 'LOGIN_SIGNUP_SSO_AUTO': { 'name': _('Auto-fill SSO users'), - 'description': _('Automatically fill out user-details from SSO account-data'), + 'description': _( + 'Automatically fill out user-details from SSO account-data' + ), 'default': True, 'validator': bool, }, - 'LOGIN_SIGNUP_MAIL_TWICE': { 'name': _('Mail twice'), 'description': _('On signup ask users twice for their mail'), 'default': False, 'validator': bool, }, - 'LOGIN_SIGNUP_PWD_TWICE': { 'name': _('Password twice'), 'description': _('On signup ask users twice for their password'), 'default': True, 'validator': bool, }, - 'LOGIN_SIGNUP_MAIL_RESTRICTION': { 'name': _('Allowed domains'), - 'description': _('Restrict signup to certain domains (comma-separated, starting with @)'), + 'description': _( + 'Restrict signup to certain domains (comma-separated, starting with @)' + ), 'default': '', 'before_save': validate_email_domains, }, - 'SIGNUP_GROUP': { 'name': _('Group on signup'), 'description': _('Group to which new users are assigned on registration'), 'default': '', - 'choices': settings_group_options + 'choices': settings_group_options, }, - 'LOGIN_ENFORCE_MFA': { 'name': _('Enforce MFA'), 'description': _('Users must use multifactor security.'), 'default': False, 'validator': bool, }, - 'PLUGIN_ON_STARTUP': { 'name': _('Check plugins on startup'), - 'description': _('Check that all plugins are installed on startup - enable in container environments'), - 'default': str(os.getenv('INVENTREE_DOCKER', False)).lower() in ['1', 'true'], + 'description': _( + 'Check that all plugins are installed on startup - enable in container environments' + ), + 'default': str(os.getenv('INVENTREE_DOCKER', False)).lower() + in ['1', 'true'], 'validator': bool, 'requires_restart': True, }, - # Settings for plugin mixin features 'ENABLE_PLUGINS_URL': { 'name': _('Enable URL integration'), @@ -1816,7 +1832,6 @@ class InvenTreeSetting(BaseInvenTreeSetting): 'validator': bool, 'after_save': reload_plugin_registry, }, - 'ENABLE_PLUGINS_NAVIGATION': { 'name': _('Enable navigation integration'), 'description': _('Enable plugins to integrate into navigation'), @@ -1824,7 +1839,6 @@ class InvenTreeSetting(BaseInvenTreeSetting): 'validator': bool, 'after_save': reload_plugin_registry, }, - 'ENABLE_PLUGINS_APP': { 'name': _('Enable app integration'), 'description': _('Enable plugins to add apps'), @@ -1832,7 +1846,6 @@ class InvenTreeSetting(BaseInvenTreeSetting): 'validator': bool, 'after_save': reload_plugin_registry, }, - 'ENABLE_PLUGINS_SCHEDULE': { 'name': _('Enable schedule integration'), 'description': _('Enable plugins to run scheduled tasks'), @@ -1840,7 +1853,6 @@ class InvenTreeSetting(BaseInvenTreeSetting): 'validator': bool, 'after_save': reload_plugin_registry, }, - 'ENABLE_PLUGINS_EVENTS': { 'name': _('Enable event integration'), 'description': _('Enable plugins to respond to internal events'), @@ -1848,55 +1860,51 @@ class InvenTreeSetting(BaseInvenTreeSetting): 'validator': bool, 'after_save': reload_plugin_registry, }, - - "PROJECT_CODES_ENABLED": { + 'PROJECT_CODES_ENABLED': { 'name': _('Enable project codes'), 'description': _('Enable project codes for tracking projects'), 'default': False, 'validator': bool, }, - 'STOCKTAKE_ENABLE': { 'name': _('Stocktake Functionality'), - 'description': _('Enable stocktake functionality for recording stock levels and calculating stock value'), + 'description': _( + 'Enable stocktake functionality for recording stock levels and calculating stock value' + ), 'validator': bool, 'default': False, }, - 'STOCKTAKE_EXCLUDE_EXTERNAL': { 'name': _('Exclude External Locations'), - 'description': _('Exclude stock items in external locations from stocktake calculations'), + 'description': _( + 'Exclude stock items in external locations from stocktake calculations' + ), 'validator': bool, 'default': False, }, - 'STOCKTAKE_AUTO_DAYS': { 'name': _('Automatic Stocktake Period'), - 'description': _('Number of days between automatic stocktake recording (set to zero to disable)'), - 'validator': [ - int, - MinValueValidator(0), - ], + 'description': _( + 'Number of days between automatic stocktake recording (set to zero to disable)' + ), + 'validator': [int, MinValueValidator(0)], 'default': 0, }, - 'STOCKTAKE_DELETE_REPORT_DAYS': { 'name': _('Report Deletion Interval'), - 'description': _('Stocktake reports will be deleted after specified number of days'), + 'description': _( + 'Stocktake reports will be deleted after specified number of days' + ), 'default': 30, 'units': _('days'), - 'validator': [ - int, - MinValueValidator(7), - ] + 'validator': [int, MinValueValidator(7)], }, - 'DISPLAY_FULL_NAMES': { 'name': _('Display Users full names'), 'description': _('Display Users full names instead of usernames'), 'default': False, - 'validator': bool - } + 'validator': bool, + }, } typ = 'inventree' @@ -1926,7 +1934,9 @@ def label_printer_options(): printers = [] label_printer_plugins = registry.with_mixin('labels') if label_printer_plugins: - printers.extend([(p.slug, p.name + ' - ' + p.human_name) for p in label_printer_plugins]) + printers.extend([ + (p.slug, p.name + ' - ' + p.human_name) for p in label_printer_plugins + ]) return printers @@ -1936,322 +1946,295 @@ class InvenTreeUserSetting(BaseInvenTreeSetting): class Meta: """Meta options for InvenTreeUserSetting.""" - verbose_name = "InvenTree User Setting" - verbose_name_plural = "InvenTree User Settings" + verbose_name = 'InvenTree User Setting' + verbose_name_plural = 'InvenTree User Settings' constraints = [ models.UniqueConstraint(fields=['key', 'user'], name='unique key and user') ] SETTINGS = { - 'HOMEPAGE_HIDE_INACTIVE': { 'name': _('Hide inactive parts'), - 'description': _('Hide inactive parts in results displayed on the homepage'), + 'description': _( + 'Hide inactive parts in results displayed on the homepage' + ), 'default': True, 'validator': bool, }, - 'HOMEPAGE_PART_STARRED': { 'name': _('Show subscribed parts'), 'description': _('Show subscribed parts on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_CATEGORY_STARRED': { 'name': _('Show subscribed categories'), 'description': _('Show subscribed part categories on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_PART_LATEST': { 'name': _('Show latest parts'), 'description': _('Show latest parts on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_BOM_REQUIRES_VALIDATION': { 'name': _('Show unvalidated BOMs'), 'description': _('Show BOMs that await validation on the homepage'), 'default': False, 'validator': bool, }, - 'HOMEPAGE_STOCK_RECENT': { 'name': _('Show recent stock changes'), 'description': _('Show recently changed stock items on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_STOCK_LOW': { 'name': _('Show low stock'), 'description': _('Show low stock items on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_SHOW_STOCK_DEPLETED': { 'name': _('Show depleted stock'), 'description': _('Show depleted stock items on the homepage'), 'default': False, 'validator': bool, }, - 'HOMEPAGE_BUILD_STOCK_NEEDED': { 'name': _('Show needed stock'), 'description': _('Show stock items needed for builds on the homepage'), 'default': False, 'validator': bool, }, - 'HOMEPAGE_STOCK_EXPIRED': { 'name': _('Show expired stock'), 'description': _('Show expired stock items on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_STOCK_STALE': { 'name': _('Show stale stock'), 'description': _('Show stale stock items on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_BUILD_PENDING': { 'name': _('Show pending builds'), 'description': _('Show pending builds on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_BUILD_OVERDUE': { 'name': _('Show overdue builds'), 'description': _('Show overdue builds on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_PO_OUTSTANDING': { 'name': _('Show outstanding POs'), 'description': _('Show outstanding POs on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_PO_OVERDUE': { 'name': _('Show overdue POs'), 'description': _('Show overdue POs on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_SO_OUTSTANDING': { 'name': _('Show outstanding SOs'), 'description': _('Show outstanding SOs on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_SO_OVERDUE': { 'name': _('Show overdue SOs'), 'description': _('Show overdue SOs on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_SO_SHIPMENTS_PENDING': { 'name': _('Show pending SO shipments'), 'description': _('Show pending SO shipments on the homepage'), 'default': True, 'validator': bool, }, - 'HOMEPAGE_NEWS': { 'name': _('Show News'), 'description': _('Show news on the homepage'), 'default': False, 'validator': bool, }, - - "LABEL_INLINE": { + 'LABEL_INLINE': { 'name': _('Inline label display'), - 'description': _('Display PDF labels in the browser, instead of downloading as a file'), + 'description': _( + 'Display PDF labels in the browser, instead of downloading as a file' + ), 'default': True, 'validator': bool, }, - - "LABEL_DEFAULT_PRINTER": { + 'LABEL_DEFAULT_PRINTER': { 'name': _('Default label printer'), - 'description': _('Configure which label printer should be selected by default'), + 'description': _( + 'Configure which label printer should be selected by default' + ), 'default': '', - 'choices': label_printer_options + 'choices': label_printer_options, }, - - "REPORT_INLINE": { + 'REPORT_INLINE': { 'name': _('Inline report display'), - 'description': _('Display PDF reports in the browser, instead of downloading as a file'), + 'description': _( + 'Display PDF reports in the browser, instead of downloading as a file' + ), 'default': False, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_PARTS': { 'name': _('Search Parts'), 'description': _('Display parts in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_SUPPLIER_PARTS': { 'name': _('Search Supplier Parts'), 'description': _('Display supplier parts in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_MANUFACTURER_PARTS': { 'name': _('Search Manufacturer Parts'), 'description': _('Display manufacturer parts in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_HIDE_INACTIVE_PARTS': { - 'name': _("Hide Inactive Parts"), + 'name': _('Hide Inactive Parts'), 'description': _('Excluded inactive parts from search preview window'), 'default': False, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_CATEGORIES': { 'name': _('Search Categories'), 'description': _('Display part categories in search preview window'), 'default': False, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_STOCK': { 'name': _('Search Stock'), 'description': _('Display stock items in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_HIDE_UNAVAILABLE_STOCK': { 'name': _('Hide Unavailable Stock Items'), - 'description': _('Exclude stock items which are not available from the search preview window'), + 'description': _( + 'Exclude stock items which are not available from the search preview window' + ), 'validator': bool, 'default': False, }, - 'SEARCH_PREVIEW_SHOW_LOCATIONS': { 'name': _('Search Locations'), 'description': _('Display stock locations in search preview window'), 'default': False, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_COMPANIES': { 'name': _('Search Companies'), 'description': _('Display companies in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_BUILD_ORDERS': { 'name': _('Search Build Orders'), 'description': _('Display build orders in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_PURCHASE_ORDERS': { 'name': _('Search Purchase Orders'), 'description': _('Display purchase orders in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_EXCLUDE_INACTIVE_PURCHASE_ORDERS': { 'name': _('Exclude Inactive Purchase Orders'), - 'description': _('Exclude inactive purchase orders from search preview window'), + 'description': _( + 'Exclude inactive purchase orders from search preview window' + ), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_SHOW_SALES_ORDERS': { 'name': _('Search Sales Orders'), 'description': _('Display sales orders in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_EXCLUDE_INACTIVE_SALES_ORDERS': { 'name': _('Exclude Inactive Sales Orders'), - 'description': _('Exclude inactive sales orders from search preview window'), + 'description': _( + 'Exclude inactive sales orders from search preview window' + ), 'validator': bool, 'default': True, }, - 'SEARCH_PREVIEW_SHOW_RETURN_ORDERS': { 'name': _('Search Return Orders'), 'description': _('Display return orders in search preview window'), 'default': True, 'validator': bool, }, - 'SEARCH_PREVIEW_EXCLUDE_INACTIVE_RETURN_ORDERS': { 'name': _('Exclude Inactive Return Orders'), - 'description': _('Exclude inactive return orders from search preview window'), + 'description': _( + 'Exclude inactive return orders from search preview window' + ), 'validator': bool, 'default': True, }, - 'SEARCH_PREVIEW_RESULTS': { 'name': _('Search Preview Results'), - 'description': _('Number of results to show in each section of the search preview window'), + 'description': _( + 'Number of results to show in each section of the search preview window' + ), 'default': 10, - 'validator': [int, MinValueValidator(1)] + 'validator': [int, MinValueValidator(1)], }, - 'SEARCH_REGEX': { 'name': _('Regex Search'), 'description': _('Enable regular expressions in search queries'), 'default': False, 'validator': bool, }, - 'SEARCH_WHOLE': { 'name': _('Whole Word Search'), 'description': _('Search queries return results for whole word matches'), 'default': False, 'validator': bool, }, - 'PART_SHOW_QUANTITY_IN_FORMS': { 'name': _('Show Quantity in Forms'), 'description': _('Display available part quantity in some forms'), 'default': True, 'validator': bool, }, - 'FORMS_CLOSE_USING_ESCAPE': { 'name': _('Escape Key Closes Forms'), 'description': _('Use the escape key to close modal forms'), 'default': False, 'validator': bool, }, - 'STICKY_HEADER': { 'name': _('Fixed Navbar'), 'description': _('The navbar position is fixed to the top of the screen'), 'default': False, 'validator': bool, }, - 'DATE_DISPLAY_FORMAT': { 'name': _('Date Format'), 'description': _('Preferred format for displaying dates'), @@ -2264,67 +2247,58 @@ class InvenTreeUserSetting(BaseInvenTreeSetting): ('MM-DD-YYYY', '02-22-2022'), ('MM/DD/YYYY', '02/22/2022'), ('MMM DD YYYY', 'Feb 22 2022'), - ] + ], }, - 'DISPLAY_SCHEDULE_TAB': { 'name': _('Part Scheduling'), 'description': _('Display part scheduling information'), 'default': True, 'validator': bool, }, - 'DISPLAY_STOCKTAKE_TAB': { 'name': _('Part Stocktake'), - 'description': _('Display part stocktake information (if stocktake functionality is enabled)'), + 'description': _( + 'Display part stocktake information (if stocktake functionality is enabled)' + ), 'default': True, 'validator': bool, }, - 'TABLE_STRING_MAX_LENGTH': { 'name': _('Table String Length'), - 'description': _('Maximum length limit for strings displayed in table views'), - 'validator': [ - int, - MinValueValidator(0), - ], + 'description': _( + 'Maximum length limit for strings displayed in table views' + ), + 'validator': [int, MinValueValidator(0)], 'default': 100, }, - 'DEFAULT_PART_LABEL_TEMPLATE': { 'name': _('Default part label template'), 'description': _('The part label template to be automatically selected'), - 'validator': [ - int, - ], + 'validator': [int], 'default': '', }, - 'DEFAULT_ITEM_LABEL_TEMPLATE': { 'name': _('Default stock item template'), - 'description': _('The stock item label template to be automatically selected'), - 'validator': [ - int, - ], + 'description': _( + 'The stock item label template to be automatically selected' + ), + 'validator': [int], 'default': '', }, - 'DEFAULT_LOCATION_LABEL_TEMPLATE': { 'name': _('Default stock location label template'), - 'description': _('The stock location label template to be automatically selected'), - 'validator': [ - int, - ], + 'description': _( + 'The stock location label template to be automatically selected' + ), + 'validator': [int], 'default': '', }, - 'NOTIFICATION_ERROR_REPORT': { 'name': _('Receive error reports'), 'description': _('Receive notifications for system errors'), 'default': True, 'validator': bool, - } - + }, } typ = 'user' @@ -2340,7 +2314,8 @@ class InvenTreeUserSetting(BaseInvenTreeSetting): user = models.ForeignKey( User, on_delete=models.CASCADE, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('User'), help_text=_('User'), ) @@ -2384,13 +2359,24 @@ class PriceBreak(MetaMixin): try: converted = convert_money(self.price, currency_code) except MissingRate: - logger.warning("No currency conversion rate available for %s -> %s", self.price_currency, currency_code) + logger.warning( + 'No currency conversion rate available for %s -> %s', + self.price_currency, + currency_code, + ) return self.price.amount return converted.amount -def get_price(instance, quantity, moq=True, multiples=True, currency=None, break_name: str = 'price_breaks'): +def get_price( + instance, + quantity, + moq=True, + multiples=True, + currency=None, + break_name: str = 'price_breaks', +): """Calculate the price based on quantity price breaks. - Don't forget to add in flat-fee cost (base_cost field) @@ -2409,7 +2395,7 @@ def get_price(instance, quantity, moq=True, multiples=True, currency=None, break return None # Check if quantity is fraction and disable multiples - multiples = (quantity % 1 == 0) + multiples = quantity % 1 == 0 # Order multiples if multiples: @@ -2461,12 +2447,10 @@ def get_price(instance, quantity, moq=True, multiples=True, currency=None, break class ColorTheme(models.Model): """Color Theme Setting.""" - name = models.CharField(max_length=20, - default='', - blank=True) - user = models.CharField(max_length=150, - unique=True) + name = models.CharField(max_length=20, default='', blank=True) + + user = models.CharField(max_length=150, unique=True) @classmethod def get_color_themes_choices(cls): @@ -2482,9 +2466,11 @@ class ColorTheme(models.Model): files_list.append([file.stem, file.suffix]) # Get color themes choices (CSS sheets) - choices = [(file_name.lower(), _(file_name.replace('-', ' ').title())) - for file_name, file_ext in files_list - if file_ext == '.css'] + choices = [ + (file_name.lower(), _(file_name.replace('-', ' ').title())) + for file_name, file_ext in files_list + if file_ext == '.css' + ] return choices @@ -2505,6 +2491,7 @@ class ColorTheme(models.Model): class VerificationMethod(Enum): """Class to hold method references.""" + NONE = 0 TOKEN = 1 HMAC = 2 @@ -2523,11 +2510,11 @@ class WebhookEndpoint(models.Model): """ # Token - TOKEN_NAME = "Token" + TOKEN_NAME = 'Token' VERIFICATION_METHOD = VerificationMethod.NONE - MESSAGE_OK = "Message was received." - MESSAGE_TOKEN_ERROR = "Incorrect token in header." + MESSAGE_OK = 'Message was received.' + MESSAGE_TOKEN_ERROR = 'Incorrect token in header.' endpoint_id = models.CharField( max_length=255, @@ -2539,28 +2526,29 @@ class WebhookEndpoint(models.Model): name = models.CharField( max_length=255, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Name'), - help_text=_('Name for this webhook') + help_text=_('Name for this webhook'), ) active = models.BooleanField( - default=True, - verbose_name=_('Active'), - help_text=_('Is this webhook active') + default=True, verbose_name=_('Active'), help_text=_('Is this webhook active') ) user = models.ForeignKey( User, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('User'), help_text=_('User'), ) token = models.CharField( max_length=255, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Token'), help_text=_('Token for access'), default=uuid.uuid4, @@ -2568,7 +2556,8 @@ class WebhookEndpoint(models.Model): secret = models.CharField( max_length=255, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Secret'), help_text=_('Shared secret for HMAC'), ) @@ -2600,7 +2589,7 @@ class WebhookEndpoint(models.Model): This can be overridden to create your own token validation method. """ - token = headers.get(self.TOKEN_NAME, "") + token = headers.get(self.TOKEN_NAME, '') # no token if self.verify == VerificationMethod.NONE: @@ -2614,7 +2603,9 @@ class WebhookEndpoint(models.Model): # hmac token elif self.verify == VerificationMethod.HMAC: - digest = hmac.new(self.secret.encode('utf-8'), request.body, hashlib.sha256).digest() + digest = hmac.new( + self.secret.encode('utf-8'), request.body, hashlib.sha256 + ).digest() computed_hmac = base64.b64encode(digest) if not hmac.compare_digest(computed_hmac, token.encode('utf-8')): raise PermissionDenied(self.MESSAGE_TOKEN_ERROR) @@ -2692,14 +2683,16 @@ class WebhookMessage(models.Model): header = models.CharField( max_length=255, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Header'), help_text=_('Header of this message'), editable=False, ) body = models.JSONField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Body'), help_text=_('Body of this message'), editable=False, @@ -2708,7 +2701,8 @@ class WebhookMessage(models.Model): endpoint = models.ForeignKey( WebhookEndpoint, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Endpoint'), help_text=_('Endpoint on which this message was received'), ) @@ -2734,38 +2728,25 @@ class NotificationEntry(MetaMixin): class Meta: """Meta options for NotificationEntry.""" - unique_together = [ - ('key', 'uid'), - ] + unique_together = [('key', 'uid')] - key = models.CharField( - max_length=250, - blank=False, - ) + key = models.CharField(max_length=250, blank=False) - uid = models.IntegerField( - ) + uid = models.IntegerField() @classmethod def check_recent(cls, key: str, uid: int, delta: timedelta): """Test if a particular notification has been sent in the specified time period.""" since = datetime.now().date() - delta - entries = cls.objects.filter( - key=key, - uid=uid, - updated__gte=since - ) + entries = cls.objects.filter(key=key, uid=uid, updated__gte=since) return entries.exists() @classmethod def notify(cls, key: str, uid: int): """Notify the database that a particular notification has been sent out.""" - entry, created = cls.objects.get_or_create( - key=key, - uid=uid - ) + entry, created = cls.objects.get_or_create(key=key, uid=uid) entry.save() @@ -2782,9 +2763,7 @@ class NotificationMessage(models.Model): # generic link to target target_content_type = models.ForeignKey( - ContentType, - on_delete=models.CASCADE, - related_name='notification_target', + ContentType, on_delete=models.CASCADE, related_name='notification_target' ) target_object_id = models.PositiveIntegerField() @@ -2800,10 +2779,7 @@ class NotificationMessage(models.Model): blank=True, ) - source_object_id = models.PositiveIntegerField( - null=True, - blank=True, - ) + source_object_id = models.PositiveIntegerField(null=True, blank=True) source_object = GenericForeignKey('source_content_type', 'source_object_id') @@ -2817,29 +2793,15 @@ class NotificationMessage(models.Model): blank=True, ) - category = models.CharField( - max_length=250, - blank=False, - ) + category = models.CharField(max_length=250, blank=False) - name = models.CharField( - max_length=250, - blank=False, - ) + name = models.CharField(max_length=250, blank=False) - message = models.CharField( - max_length=250, - blank=True, - null=True, - ) + message = models.CharField(max_length=250, blank=True, null=True) - creation = models.DateTimeField( - auto_now_add=True, - ) + creation = models.DateTimeField(auto_now_add=True) - read = models.BooleanField( - default=False, - ) + read = models.BooleanField(default=False) @staticmethod def get_api_url(): @@ -2869,41 +2831,20 @@ class NewsFeedEntry(models.Model): - read: Was this iteam already by a superuser? """ - feed_id = models.CharField( - verbose_name=_('Id'), - unique=True, - max_length=250, - ) + feed_id = models.CharField(verbose_name=_('Id'), unique=True, max_length=250) - title = models.CharField( - verbose_name=_('Title'), - max_length=250, - ) + title = models.CharField(verbose_name=_('Title'), max_length=250) - link = models.URLField( - verbose_name=_('Link'), - max_length=250, - ) + link = models.URLField(verbose_name=_('Link'), max_length=250) - published = models.DateTimeField( - verbose_name=_('Published'), - max_length=250, - ) + published = models.DateTimeField(verbose_name=_('Published'), max_length=250) - author = models.CharField( - verbose_name=_('Author'), - max_length=250, - ) + author = models.CharField(verbose_name=_('Author'), max_length=250) - summary = models.CharField( - verbose_name=_('Summary'), - max_length=250, - ) + summary = models.CharField(verbose_name=_('Summary'), max_length=250) read = models.BooleanField( - verbose_name=_('Read'), - help_text=_('Was this news item read?'), - default=False + verbose_name=_('Read'), help_text=_('Was this news item read?'), default=False ) @@ -2920,9 +2861,7 @@ class NotesImage(models.Model): """ image = models.ImageField( - upload_to=rename_notes_image, - verbose_name=_('Image'), - help_text=_('Image file'), + upload_to=rename_notes_image, verbose_name=_('Image'), help_text=_('Image file') ) user = models.ForeignKey(User, on_delete=models.SET_NULL, null=True, blank=True) @@ -2964,9 +2903,7 @@ class CustomUnit(models.Model): # Cannot be zero length if not self.name.isidentifier(): - raise ValidationError({ - 'name': _('Unit name must be a valid identifier') - }) + raise ValidationError({'name': _('Unit name must be a valid identifier')}) self.definition = self.definition.strip() @@ -2974,9 +2911,7 @@ class CustomUnit(models.Model): try: registry.Quantity(self.definition) except Exception as exc: - raise ValidationError({ - 'definition': str(exc) - }) + raise ValidationError({'definition': str(exc)}) # Finally, test that the entire custom unit definition is valid try: @@ -2988,14 +2923,16 @@ class CustomUnit(models.Model): max_length=50, verbose_name=_('Name'), help_text=_('Unit name'), - unique=True, blank=False, + unique=True, + blank=False, ) symbol = models.CharField( max_length=10, verbose_name=_('Symbol'), help_text=_('Optional unit symbol'), - unique=True, blank=True, + unique=True, + blank=True, ) definition = models.CharField( @@ -3012,4 +2949,5 @@ def after_custom_unit_updated(sender, instance, **kwargs): """Callback when a custom unit is updated or deleted""" # Force reload of the unit registry from InvenTree.conversion import reload_unit_registry + reload_unit_registry() diff --git a/InvenTree/common/notifications.py b/InvenTree/common/notifications.py index c32ee3ec82..fadcdd7f05 100644 --- a/InvenTree/common/notifications.py +++ b/InvenTree/common/notifications.py @@ -24,7 +24,7 @@ class NotificationMethod: METHOD_NAME = '' METHOD_ICON = None - CONTEXT_BUILTIN = ['name', 'message', ] + CONTEXT_BUILTIN = ['name', 'message'] CONTEXT_EXTRA = [] GLOBAL_SETTING = None USER_SETTING = None @@ -39,11 +39,15 @@ class NotificationMethod: """ # Check if a sending fnc is defined if (not hasattr(self, 'send')) and (not hasattr(self, 'send_bulk')): - raise NotImplementedError('A NotificationMethod must either define a `send` or a `send_bulk` method') + raise NotImplementedError( + 'A NotificationMethod must either define a `send` or a `send_bulk` method' + ) # No method name is no good if self.METHOD_NAME in ('', None): - raise NotImplementedError(f'The NotificationMethod {self.__class__} did not provide a METHOD_NAME') + raise NotImplementedError( + f'The NotificationMethod {self.__class__} did not provide a METHOD_NAME' + ) # Check if plugin is disabled - if so do not gather targets etc. if self.global_setting_disable(): @@ -61,9 +65,10 @@ class NotificationMethod: def check_context(self, context): """Check that all values defined in the methods CONTEXT were provided in the current context.""" + def check(ref, obj): # the obj is not accessible so we are on the end - if not isinstance(obj, (list, dict, tuple, )): + if not isinstance(obj, (list, dict, tuple)): return ref # check if the ref exists @@ -82,7 +87,9 @@ class NotificationMethod: return check(ref[1:], obj[ref[0]]) # other cases -> raise - raise NotImplementedError('This type can not be used as a context reference') + raise NotImplementedError( + 'This type can not be used as a context reference' + ) missing = [] for item in (*self.CONTEXT_BUILTIN, *self.CONTEXT_EXTRA): @@ -91,7 +98,9 @@ class NotificationMethod: missing.append(ret) if missing: - raise NotImplementedError(f'The `context` is missing the following items:\n{missing}') + raise NotImplementedError( + f'The `context` is missing the following items:\n{missing}' + ) return context @@ -142,7 +151,12 @@ class NotificationMethod: def usersetting(self, target): """Returns setting for this method for a given user.""" - return NotificationUserSetting.get_setting(f'NOTIFICATION_METHOD_{self.METHOD_NAME.upper()}', user=target, method=self.METHOD_NAME) + return NotificationUserSetting.get_setting( + f'NOTIFICATION_METHOD_{self.METHOD_NAME.upper()}', + user=target, + method=self.METHOD_NAME, + ) + # endregion @@ -160,6 +174,8 @@ class BulkNotificationMethod(NotificationMethod): def send_bulk(self): """This function must be overridden.""" raise NotImplementedError('The `send` method must be overridden!') + + # endregion @@ -181,17 +197,25 @@ class MethodStorageClass: selected_classes (class, optional): References to the classes that should be registered. Defaults to None. """ logger.debug('Collecting notification methods') - current_method = InvenTree.helpers.inheritors(NotificationMethod) - IGNORED_NOTIFICATION_CLS + current_method = ( + InvenTree.helpers.inheritors(NotificationMethod) - IGNORED_NOTIFICATION_CLS + ) # for testing selective loading is made available if selected_classes: - current_method = [item for item in current_method if item is selected_classes] + current_method = [ + item for item in current_method if item is selected_classes + ] # make sure only one of each method is added filtered_list = {} for item in current_method: plugin = item.get_plugin(item) - ref = f'{plugin.package_path}_{item.METHOD_NAME}' if plugin else item.METHOD_NAME + ref = ( + f'{plugin.package_path}_{item.METHOD_NAME}' + if plugin + else item.METHOD_NAME + ) item.plugin = plugin() if plugin else None filtered_list[ref] = item @@ -217,9 +241,7 @@ class MethodStorageClass: # make sure the setting exists self.user_settings[new_key] = item.USER_SETTING NotificationUserSetting.get_setting( - key=new_key, - user=user, - method=item.METHOD_NAME, + key=new_key, user=user, method=item.METHOD_NAME ) # save definition @@ -231,7 +253,7 @@ class MethodStorageClass: return methods -IGNORED_NOTIFICATION_CLS = {SingleNotificationMethod, BulkNotificationMethod, } +IGNORED_NOTIFICATION_CLS = {SingleNotificationMethod, BulkNotificationMethod} storage = MethodStorageClass() @@ -275,6 +297,7 @@ class NotificationBody: app_label: App label (slugified) of the model model_name': Name (slugified) of the model """ + name: str slug: str message: str @@ -286,24 +309,25 @@ class InvenTreeNotificationBodies: Contains regularly used notification bodies. """ + NewOrder = NotificationBody( - name=_("New {verbose_name}"), + name=_('New {verbose_name}'), slug='{app_label}.new_{model_name}', - message=_("A new order has been created and assigned to you"), + message=_('A new order has been created and assigned to you'), template='email/new_order_assigned.html', ) """Send when a new order (build, sale or purchase) was created.""" OrderCanceled = NotificationBody( - name=_("{verbose_name} canceled"), + name=_('{verbose_name} canceled'), slug='{app_label}.canceled_{model_name}', - message=_("A order that is assigned to you was canceled"), + message=_('A order that is assigned to you was canceled'), template='email/canceled_order_assigned.html', ) """Send when a order (sale, return or purchase) was canceled.""" ItemsReceived = NotificationBody( - name=_("Items Received"), + name=_('Items Received'), slug='purchase_order.items_received', message=_('Items have been received against a purchase order'), template='email/purchase_order_received.html', @@ -340,13 +364,19 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs): if not obj_ref_value: obj_ref_value = getattr(obj, 'id', None) if not obj_ref_value: - raise KeyError(f"Could not resolve an object reference for '{str(obj)}' with {obj_ref}, pk, id") + raise KeyError( + f"Could not resolve an object reference for '{str(obj)}' with {obj_ref}, pk, id" + ) # Check if we have notified recently... delta = timedelta(days=1) if common.models.NotificationEntry.check_recent(category, obj_ref_value, delta): - logger.info("Notification '%s' has recently been sent for '%s' - SKIPPING", category, str(obj)) + logger.info( + "Notification '%s' has recently been sent for '%s' - SKIPPING", + category, + str(obj), + ) return logger.info("Gathering users for notification '%s'", category) @@ -383,7 +413,9 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs): target_users.add(user) # Unhandled type else: - logger.error("Unknown target passed to trigger_notification method: %s", target) + logger.error( + 'Unknown target passed to trigger_notification method: %s', target + ) if target_users: logger.info("Sending notification '%s' for '%s'", category, str(obj)) @@ -392,7 +424,7 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs): if delivery_methods is None: delivery_methods = storage.liste else: - delivery_methods = (delivery_methods - IGNORED_NOTIFICATION_CLS) + delivery_methods = delivery_methods - IGNORED_NOTIFICATION_CLS for method in delivery_methods: logger.info("Triggering notification method '%s'", method.METHOD_NAME) @@ -422,17 +454,15 @@ def trigger_superuser_notification(plugin: PluginConfig, msg: str): trigger_notification( plugin, 'inventree.plugin', - context={ - 'error': plugin, - 'name': _('Error raised by plugin'), - 'message': msg, - }, + context={'error': plugin, 'name': _('Error raised by plugin'), 'message': msg}, targets=users, - delivery_methods={UIMessageNotification, }, + delivery_methods={UIMessageNotification}, ) -def deliver_notification(cls: NotificationMethod, obj, category: str, targets, context: dict): +def deliver_notification( + cls: NotificationMethod, obj, category: str, targets, context: dict +): """Send notification with the provided class. This: @@ -447,7 +477,12 @@ def deliver_notification(cls: NotificationMethod, obj, category: str, targets, c if method.targets and len(method.targets) > 0: # Log start - logger.info("Notify users via '%s' for notification '%s' for '%s'", method.METHOD_NAME, category, str(obj)) + logger.info( + "Notify users via '%s' for notification '%s' for '%s'", + method.METHOD_NAME, + category, + str(obj), + ) # Run setup for delivery method method.setup() @@ -472,6 +507,12 @@ def deliver_notification(cls: NotificationMethod, obj, category: str, targets, c method.cleanup() # Log results - logger.info("Notified %s users via '%s' for notification '%s' for '%s' successfully", success_count, method.METHOD_NAME, category, str(obj)) + logger.info( + "Notified %s users via '%s' for notification '%s' for '%s' successfully", + success_count, + method.METHOD_NAME, + category, + str(obj), + ) if not success: - logger.info("There were some problems") + logger.info('There were some problems') diff --git a/InvenTree/common/serializers.py b/InvenTree/common/serializers.py index 06a6d226e9..587595ecc8 100644 --- a/InvenTree/common/serializers.py +++ b/InvenTree/common/serializers.py @@ -1,6 +1,5 @@ """JSON serializers for common components.""" - from django.urls import reverse from flags.state import flag_state @@ -9,8 +8,10 @@ from rest_framework import serializers import common.models as common_models from InvenTree.helpers import get_objectreference from InvenTree.helpers_model import construct_absolute_url -from InvenTree.serializers import (InvenTreeImageSerializerField, - InvenTreeModelSerializer) +from InvenTree.serializers import ( + InvenTreeImageSerializerField, + InvenTreeModelSerializer, +) from users.serializers import OwnerSerializer @@ -62,10 +63,7 @@ class SettingsSerializer(InvenTreeModelSerializer): if choices: for choice in choices: - results.append({ - 'value': choice[0], - 'display_name': choice[1], - }) + results.append({'value': choice[0], 'display_name': choice[1]}) return results @@ -131,8 +129,10 @@ class GenericReferencedSettingSerializer(SettingsSerializer): def __init__(self, *args, **kwargs): """Init overrides the Meta class to make it dynamic.""" + class CustomMeta: """Scaffold for custom Meta class.""" + fields = [ 'pk', 'key', @@ -204,10 +204,12 @@ class NotificationMessageSerializer(InvenTreeModelSerializer): request = self.context['request'] if request.user and request.user.is_staff: meta = obj.target_object._meta - target['link'] = construct_absolute_url(reverse( - f'admin:{meta.db_table}_change', - kwargs={'object_id': obj.target_object_id} - )) + target['link'] = construct_absolute_url( + reverse( + f'admin:{meta.db_table}_change', + kwargs={'object_id': obj.target_object_id}, + ) + ) return target @@ -257,17 +259,9 @@ class NotesImageSerializer(InvenTreeModelSerializer): """Meta options for NotesImageSerializer.""" model = common_models.NotesImage - fields = [ - 'pk', - 'image', - 'user', - 'date', - ] + fields = ['pk', 'image', 'user', 'date'] - read_only_fields = [ - 'date', - 'user', - ] + read_only_fields = ['date', 'user'] image = InvenTreeImageSerializerField(required=True) @@ -279,13 +273,7 @@ class ProjectCodeSerializer(InvenTreeModelSerializer): """Meta options for ProjectCodeSerializer.""" model = common_models.ProjectCode - fields = [ - 'pk', - 'code', - 'description', - 'responsible', - 'responsible_detail', - ] + fields = ['pk', 'code', 'description', 'responsible', 'responsible_detail'] responsible_detail = OwnerSerializer(source='responsible', read_only=True) @@ -313,9 +301,4 @@ class CustomUnitSerializer(InvenTreeModelSerializer): """Meta options for CustomUnitSerializer.""" model = common_models.CustomUnit - fields = [ - 'pk', - 'name', - 'symbol', - 'definition', - ] + fields = ['pk', 'name', 'symbol', 'definition'] diff --git a/InvenTree/common/settings.py b/InvenTree/common/settings.py index b62f081a28..eca01311bd 100644 --- a/InvenTree/common/settings.py +++ b/InvenTree/common/settings.py @@ -20,7 +20,9 @@ def currency_code_default(): return cached_value try: - code = InvenTreeSetting.get_setting('INVENTREE_DEFAULT_CURRENCY', backup_value='', create=True, cache=True) + code = InvenTreeSetting.get_setting( + 'INVENTREE_DEFAULT_CURRENCY', backup_value='', create=True, cache=True + ) except Exception: # pragma: no cover # Database may not yet be ready, no need to throw an error here code = '' diff --git a/InvenTree/common/tasks.py b/InvenTree/common/tasks.py index 1725dce0ce..fc20b9dbfe 100644 --- a/InvenTree/common/tasks.py +++ b/InvenTree/common/tasks.py @@ -27,7 +27,9 @@ def delete_old_notifications(): try: from common.models import NotificationEntry except AppRegistryNotReady: # pragma: no cover - logger.info("Could not perform 'delete_old_notifications' - App registry not ready") + logger.info( + "Could not perform 'delete_old_notifications' - App registry not ready" + ) return before = timezone.now() - timedelta(days=90) @@ -49,7 +51,7 @@ def update_news_feed(): try: d = feedparser.parse(settings.INVENTREE_NEWS_URL) except Exception as entry: # pragma: no cover - logger.warning("update_news_feed: Error parsing the newsfeed", entry) + logger.warning('update_news_feed: Error parsing the newsfeed', entry) return # Get a reference list @@ -87,13 +89,15 @@ def delete_old_notes_images(): try: from common.models import NotesImage except AppRegistryNotReady: - logger.info("Could not perform 'delete_old_notes_images' - App registry not ready") + logger.info( + "Could not perform 'delete_old_notes_images' - App registry not ready" + ) return # Remove any notes which point to non-existent image files for note in NotesImage.objects.all(): if not os.path.exists(note.image.path): - logger.info("Deleting note %s - image file does not exist", note.image.path) + logger.info('Deleting note %s - image file does not exist', note.image.path) note.delete() note_classes = getModelsWithMixin(InvenTreeNotesMixin) @@ -112,7 +116,7 @@ def delete_old_notes_images(): break if not found: - logger.info("Deleting note %s - image file not linked to a note", img) + logger.info('Deleting note %s - image file not linked to a note', img) note.delete() # Finally, remove any images in the notes dir which are not linked to a note @@ -127,7 +131,6 @@ def delete_old_notes_images(): all_notes = NotesImage.objects.all() for image in images: - found = False for note in all_notes: img_path = os.path.basename(note.image.path) @@ -136,5 +139,5 @@ def delete_old_notes_images(): break if not found: - logger.info("Deleting note %s - image file not linked to a note", image) + logger.info('Deleting note %s - image file not linked to a note', image) os.remove(os.path.join(notes_dir, image)) diff --git a/InvenTree/common/test_notifications.py b/InvenTree/common/test_notifications.py index e61e60dfc0..a517d51e33 100644 --- a/InvenTree/common/test_notifications.py +++ b/InvenTree/common/test_notifications.py @@ -1,8 +1,12 @@ """Tests for basic notification methods and functions in InvenTree.""" import plugin.templatetags.plugin_extras as plugin_tags -from common.notifications import (BulkNotificationMethod, NotificationMethod, - SingleNotificationMethod, storage) +from common.notifications import ( + BulkNotificationMethod, + NotificationMethod, + SingleNotificationMethod, + storage, +) from part.test_part import BaseNotificationIntegrationTest from plugin.models import NotificationUserSetting @@ -23,37 +27,31 @@ class BaseNotificationTests(BaseNotificationIntegrationTest): """A comment so we do not need a pass.""" class NoNameNotificationMethod(NotificationMethod): - def send(self): """A comment so we do not need a pass.""" class WrongContextNotificationMethod(NotificationMethod): METHOD_NAME = 'WrongContextNotification' - CONTEXT_EXTRA = [ - 'aa', - ('aa', 'bb', ), - ('templates', 'ccc', ), - (123, ) - ] + CONTEXT_EXTRA = ['aa', ('aa', 'bb'), ('templates', 'ccc'), (123,)] def send(self): """A comment so we do not need a pass.""" # no send / send bulk with self.assertRaises(NotImplementedError): - FalseNotificationMethod('', '', '', '', ) + FalseNotificationMethod('', '', '', '') # no METHOD_NAME with self.assertRaises(NotImplementedError): - NoNameNotificationMethod('', '', '', '', ) + NoNameNotificationMethod('', '', '', '') # a not existent context check with self.assertRaises(NotImplementedError): - WrongContextNotificationMethod('', '', '', '', ) + WrongContextNotificationMethod('', '', '', '') # no get_targets with self.assertRaises(NotImplementedError): - AnotherFalseNotificationMethod('', '', '', {'name': 1, 'message': 2, }, ) + AnotherFalseNotificationMethod('', '', '', {'name': 1, 'message': 2}) def test_failing_passing(self): """Ensure that an error in one deliverymethod is not blocking all mehthods.""" @@ -67,7 +65,7 @@ class BaseNotificationTests(BaseNotificationIntegrationTest): METHOD_NAME = 'ErrorImplementation' def get_targets(self): - return [1, ] + return [1] def send(self, target): raise KeyError('This could be any error') @@ -91,7 +89,7 @@ class BulkNotificationMethodTests(BaseNotificationIntegrationTest): METHOD_NAME = 'WrongImplementationBulk' def get_targets(self): - return [1, ] + return [1] with self.assertLogs(logger='inventree', level='ERROR'): self._notification_run(WrongImplementation) @@ -113,11 +111,12 @@ class SingleNotificationMethodTests(BaseNotificationIntegrationTest): METHOD_NAME = 'WrongImplementationSingle' def get_targets(self): - return [1, ] + return [1] with self.assertLogs(logger='inventree', level='ERROR'): self._notification_run(WrongImplementation) + # A integration test for notifications is provided in test_part.PartNotificationTest @@ -144,7 +143,7 @@ class NotificationUserSettingTests(BaseNotificationIntegrationTest): } def get_targets(self): - return [1, ] + return [1] def send_bulk(self): return True @@ -158,10 +157,14 @@ class NotificationUserSettingTests(BaseNotificationIntegrationTest): # assertions for settings self.assertEqual(setting.name, 'Enable test notifications') self.assertEqual(setting.default_value, True) - self.assertEqual(setting.description, 'Allow sending of test for event notifications') + self.assertEqual( + setting.description, 'Allow sending of test for event notifications' + ) self.assertEqual(setting.units, 'alpha') # test tag and array - self.assertEqual(plugin_tags.notification_settings_list({'user': self.user}), array) + self.assertEqual( + plugin_tags.notification_settings_list({'user': self.user}), array + ) self.assertEqual(array[0]['key'], 'NOTIFICATION_METHOD_TEST') self.assertEqual(array[0]['method'], 'test') diff --git a/InvenTree/common/test_tasks.py b/InvenTree/common/test_tasks.py index d551bc82ac..55821dbbf2 100644 --- a/InvenTree/common/test_tasks.py +++ b/InvenTree/common/test_tasks.py @@ -15,4 +15,4 @@ class TaskTest(TestCase): """Test that the task `delete_old_notifications` runs through without errors.""" # check empty run self.assertEqual(NotificationEntry.objects.all().count(), 0) - offload_task(common_tasks.delete_old_notifications,) + offload_task(common_tasks.delete_old_notifications) diff --git a/InvenTree/common/tests.py b/InvenTree/common/tests.py index c8167aec3f..245074841b 100644 --- a/InvenTree/common/tests.py +++ b/InvenTree/common/tests.py @@ -17,16 +17,23 @@ from django.urls import reverse import PIL from InvenTree.helpers import str2bool -from InvenTree.unit_test import (InvenTreeAPITestCase, InvenTreeTestCase, - PluginMixin) +from InvenTree.unit_test import InvenTreeAPITestCase, InvenTreeTestCase, PluginMixin from plugin import registry from plugin.models import NotificationUserSetting from .api import WebhookView -from .models import (ColorTheme, CustomUnit, InvenTreeSetting, - InvenTreeUserSetting, NotesImage, NotificationEntry, - NotificationMessage, ProjectCode, WebhookEndpoint, - WebhookMessage) +from .models import ( + ColorTheme, + CustomUnit, + InvenTreeSetting, + InvenTreeUserSetting, + NotesImage, + NotificationEntry, + NotificationMessage, + ProjectCode, + WebhookEndpoint, + WebhookMessage, +) CONTENT_TYPE_JSON = 'application/json' @@ -34,9 +41,7 @@ CONTENT_TYPE_JSON = 'application/json' class SettingsTest(InvenTreeTestCase): """Tests for the 'settings' model.""" - fixtures = [ - 'settings', - ] + fixtures = ['settings'] def test_settings_objects(self): """Test fixture loading and lookup for settings.""" @@ -50,7 +55,9 @@ class SettingsTest(InvenTreeTestCase): self.assertEqual(instance_name.value, 'My very first InvenTree Instance') # Check object lookup (case insensitive) - self.assertEqual(InvenTreeSetting.get_setting_object('iNvEnTrEE_inSTanCE').pk, 1) + self.assertEqual( + InvenTreeSetting.get_setting_object('iNvEnTrEE_inSTanCE').pk, 1 + ) def test_settings_functions(self): """Test settings functions and properties.""" @@ -61,14 +68,25 @@ class SettingsTest(InvenTreeTestCase): stale_ref = 'STOCK_STALE_DAYS' stale_days = InvenTreeSetting.get_setting_object(stale_ref, cache=False) - report_size_obj = InvenTreeSetting.get_setting_object('REPORT_DEFAULT_PAGE_SIZE') - report_test_obj = InvenTreeSetting.get_setting_object('REPORT_ENABLE_TEST_REPORT') + report_size_obj = InvenTreeSetting.get_setting_object( + 'REPORT_DEFAULT_PAGE_SIZE' + ) + report_test_obj = InvenTreeSetting.get_setting_object( + 'REPORT_ENABLE_TEST_REPORT' + ) # check settings base fields self.assertEqual(instance_obj.name, 'Server Instance Name') - self.assertEqual(instance_obj.get_setting_name(instance_ref), 'Server Instance Name') - self.assertEqual(instance_obj.description, 'String descriptor for the server instance') - self.assertEqual(instance_obj.get_setting_description(instance_ref), 'String descriptor for the server instance') + self.assertEqual( + instance_obj.get_setting_name(instance_ref), 'Server Instance Name' + ) + self.assertEqual( + instance_obj.description, 'String descriptor for the server instance' + ) + self.assertEqual( + instance_obj.get_setting_description(instance_ref), + 'String descriptor for the server instance', + ) # check units self.assertEqual(instance_obj.units, '') @@ -90,7 +108,9 @@ class SettingsTest(InvenTreeTestCase): # check as_int self.assertEqual(stale_days.as_int(), 0) - self.assertEqual(instance_obj.as_int(), 'InvenTree') # not an int -> return default + self.assertEqual( + instance_obj.as_int(), 'InvenTree' + ) # not an int -> return default # check as_bool self.assertEqual(report_test_obj.as_bool(), True) @@ -116,62 +136,66 @@ class SettingsTest(InvenTreeTestCase): def test_all_settings(self): """Make sure that the all_settings function returns correctly""" result = InvenTreeSetting.all_settings() - self.assertIn("INVENTREE_INSTANCE", result) + self.assertIn('INVENTREE_INSTANCE', result) self.assertIsInstance(result['INVENTREE_INSTANCE'], InvenTreeSetting) - @mock.patch("common.models.InvenTreeSetting.get_setting_definition") + @mock.patch('common.models.InvenTreeSetting.get_setting_definition') def test_check_all_settings(self, get_setting_definition): """Make sure that the check_all_settings function returns correctly""" # define partial schema settings_definition = { - "AB": { # key that's has not already been accessed - "required": True, + 'AB': { # key that's has not already been accessed + 'required': True }, - "CD": { - "required": True, - "protected": True, - }, - "EF": {} + 'CD': {'required': True, 'protected': True}, + 'EF': {}, } def mocked(key, **kwargs): return settings_definition.get(key, {}) + get_setting_definition.side_effect = mocked - self.assertEqual(InvenTreeSetting.check_all_settings(settings_definition=settings_definition), (False, ["AB", "CD"])) - InvenTreeSetting.set_setting('AB', "hello", self.user) - InvenTreeSetting.set_setting('CD', "world", self.user) + self.assertEqual( + InvenTreeSetting.check_all_settings( + settings_definition=settings_definition + ), + (False, ['AB', 'CD']), + ) + InvenTreeSetting.set_setting('AB', 'hello', self.user) + InvenTreeSetting.set_setting('CD', 'world', self.user) self.assertEqual(InvenTreeSetting.check_all_settings(), (True, [])) - @mock.patch("common.models.InvenTreeSetting.get_setting_definition") + @mock.patch('common.models.InvenTreeSetting.get_setting_definition') def test_settings_validator(self, get_setting_definition): """Make sure that the validator function gets called on set setting.""" def validator(x): - if x == "hello": + if x == 'hello': return x - raise ValidationError(f"{x} is not valid") + raise ValidationError(f'{x} is not valid') mock_validator = mock.Mock(side_effect=validator) # define partial schema settings_definition = { - "AB": { # key that's has not already been accessed - "validator": mock_validator, - }, + 'AB': { # key that's has not already been accessed + 'validator': mock_validator + } } def mocked(key, **kwargs): return settings_definition.get(key, {}) + get_setting_definition.side_effect = mocked - InvenTreeSetting.set_setting("AB", "hello", self.user) - mock_validator.assert_called_with("hello") + InvenTreeSetting.set_setting('AB', 'hello', self.user) + mock_validator.assert_called_with('hello') with self.assertRaises(ValidationError): - InvenTreeSetting.set_setting("AB", "world", self.user) - mock_validator.assert_called_with("world") + InvenTreeSetting.set_setting('AB', 'world', self.user) + mock_validator.assert_called_with('world') def run_settings_check(self, key, setting): """Test that all settings are valid. @@ -194,7 +218,9 @@ class SettingsTest(InvenTreeTestCase): self.assertIn('django.utils.functional.lazy', str(type(description))) if key != key.upper(): - raise ValueError(f"Setting key '{key}' is not uppercase") # pragma: no cover + raise ValueError( + f"Setting key '{key}' is not uppercase" + ) # pragma: no cover # Check that only allowed keys are provided allowed_keys = [ @@ -232,7 +258,6 @@ class SettingsTest(InvenTreeTestCase): - Ensure that every setting has a description, which is translated """ for key, setting in InvenTreeSetting.SETTINGS.items(): - try: self.run_settings_check(key, setting) except Exception as exc: # pragma: no cover @@ -249,7 +274,6 @@ class SettingsTest(InvenTreeTestCase): def test_defaults(self): """Populate the settings with default values.""" for key in InvenTreeSetting.SETTINGS.keys(): - value = InvenTreeSetting.get_setting_default(key) InvenTreeSetting.set_setting(key, value, self.user) @@ -261,10 +285,14 @@ class SettingsTest(InvenTreeTestCase): if setting.is_bool(): if setting.default_value in ['', None]: - raise ValueError(f'Default value for boolean setting {key} not provided') # pragma: no cover + raise ValueError( + f'Default value for boolean setting {key} not provided' + ) # pragma: no cover if setting.default_value not in [True, False]: - raise ValueError(f'Non-boolean default value specified for {key}') # pragma: no cover + raise ValueError( + f'Non-boolean default value specified for {key}' + ) # pragma: no cover def test_global_setting_caching(self): """Test caching operations for the global settings class""" @@ -294,9 +322,7 @@ class SettingsTest(InvenTreeTestCase): # Generate a number of new users for idx in range(5): get_user_model().objects.create( - username=f"User_{idx}", - password="hunter42", - email="email@dot.com", + username=f'User_{idx}', password='hunter42', email='email@dot.com' ) key = 'SEARCH_PREVIEW_RESULTS' @@ -305,7 +331,10 @@ class SettingsTest(InvenTreeTestCase): for user in get_user_model().objects.all(): setting = InvenTreeUserSetting.get_setting_object(key, user=user) cache_key = setting.cache_key - self.assertEqual(cache_key, f"InvenTreeUserSetting:SEARCH_PREVIEW_RESULTS_user:{user.username}") + self.assertEqual( + cache_key, + f'InvenTreeUserSetting:SEARCH_PREVIEW_RESULTS_user:{user.username}', + ) InvenTreeUserSetting.set_setting(key, user.pk, None, user=user) self.assertIsNotNone(cache.get(cache_key)) @@ -333,7 +362,9 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase): response = self.get(url, expected_code=200) - n_public_settings = len([k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_')]) + n_public_settings = len([ + k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_') + ]) # Number of results should match the number of settings self.assertEqual(len(response.data), n_public_settings) @@ -358,13 +389,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase): # Test setting via the API for val in ['cat', 'hat', 'bat', 'mat']: - response = self.patch( - url, - { - 'value': val, - }, - expected_code=200 - ) + response = self.patch(url, {'value': val}, expected_code=200) self.assertEqual(response.data['value'], val) @@ -374,7 +399,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase): def test_api_detail(self): """Test that we can access the detail view for a setting based on the .""" # These keys are invalid, and should return 404 - for key in ["apple", "carrot", "dog"]: + for key in ['apple', 'carrot', 'dog']: response = self.get( reverse('api-global-setting-detail', kwargs={'key': key}), expected_code=404, @@ -394,13 +419,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase): self.assertEqual(response.data['value'], 'InvenTree') # Now, the object should have been created in the DB - self.patch( - url, - { - 'value': 'My new title', - }, - expected_code=200, - ) + self.patch(url, {'value': 'My new title'}, expected_code=200) setting = InvenTreeSetting.objects.get(key=key) @@ -451,8 +470,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): """Test a boolean user setting value.""" # Ensure we have a boolean setting available setting = InvenTreeUserSetting.get_setting_object( - 'SEARCH_PREVIEW_SHOW_PARTS', - user=self.user + 'SEARCH_PREVIEW_SHOW_PARTS', user=self.user ) # Check default values @@ -465,20 +483,16 @@ class UserSettingsApiTest(InvenTreeAPITestCase): self.assertEqual(response.data['pk'], setting.pk) self.assertEqual(response.data['key'], 'SEARCH_PREVIEW_SHOW_PARTS') - self.assertEqual(response.data['description'], 'Display parts in search preview window') + self.assertEqual( + response.data['description'], 'Display parts in search preview window' + ) self.assertEqual(response.data['type'], 'boolean') self.assertEqual(len(response.data['choices']), 0) self.assertTrue(str2bool(response.data['value'])) # Assign some truthy values for v in ['true', True, 1, 'y', 'TRUE']: - self.patch( - url, - { - 'value': str(v), - }, - expected_code=200, - ) + self.patch(url, {'value': str(v)}, expected_code=200) response = self.get(url, expected_code=200) @@ -486,13 +500,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): # Assign some false(ish) values for v in ['false', False, '0', 'n', 'FalSe']: - self.patch( - url, - { - 'value': str(v), - }, - expected_code=200, - ) + self.patch(url, {'value': str(v)}, expected_code=200) response = self.get(url, expected_code=200) @@ -500,13 +508,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): # Assign some invalid values for v in ['x', '', 'invalid', None, '-1', 'abcde']: - response = self.patch( - url, - { - 'value': str(v), - }, - expected_code=200 - ) + response = self.patch(url, {'value': str(v)}, expected_code=200) # Invalid values evaluate to False self.assertFalse(str2bool(response.data['value'])) @@ -514,8 +516,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): def test_user_setting_choice(self): """Test a user setting with choices.""" setting = InvenTreeUserSetting.get_setting_object( - 'DATE_DISPLAY_FORMAT', - user=self.user + 'DATE_DISPLAY_FORMAT', user=self.user ) url = reverse('api-user-setting-detail', kwargs={'key': setting.key}) @@ -525,37 +526,21 @@ class UserSettingsApiTest(InvenTreeAPITestCase): # Check that a valid option can be assigned via the API for opt in ['YYYY-MM-DD', 'DD-MM-YYYY', 'MM/DD/YYYY']: - - self.patch( - url, - { - 'value': opt, - }, - expected_code=200, - ) + self.patch(url, {'value': opt}, expected_code=200) setting.refresh_from_db() self.assertEqual(setting.value, opt) # Send an invalid option for opt in ['cat', 'dog', 12345]: - - response = self.patch( - url, - { - 'value': opt, - }, - expected_code=400, - ) + response = self.patch(url, {'value': opt}, expected_code=400) self.assertIn('Chosen value is not a valid option', str(response.data)) def test_user_setting_integer(self): """Test a integer user setting value.""" setting = InvenTreeUserSetting.get_setting_object( - 'SEARCH_PREVIEW_RESULTS', - user=self.user, - cache=False, + 'SEARCH_PREVIEW_RESULTS', user=self.user, cache=False ) url = reverse('api-user-setting-detail', kwargs={'key': setting.key}) @@ -573,13 +558,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): # Set valid options via the api for v in [5, 15, 25]: - self.patch( - url, - { - 'value': v, - }, - expected_code=200, - ) + self.patch(url, {'value': v}, expected_code=200) setting.refresh_from_db() self.assertEqual(setting.to_native_value(), v) @@ -587,14 +566,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase): # Set invalid options via the API # Note that this particular setting has a MinValueValidator(1) associated with it for v in [0, -1, -5]: - - response = self.patch( - url, - { - 'value': v, - }, - expected_code=400, - ) + response = self.patch(url, {'value': v}, expected_code=400) class NotificationUserSettingsApiTest(InvenTreeAPITestCase): @@ -608,9 +580,15 @@ class NotificationUserSettingsApiTest(InvenTreeAPITestCase): def test_setting(self): """Test the string name for NotificationUserSetting.""" - NotificationUserSetting.set_setting('NOTIFICATION_METHOD_MAIL', True, change_user=self.user, user=self.user) - test_setting = NotificationUserSetting.get_setting_object('NOTIFICATION_METHOD_MAIL', user=self.user) - self.assertEqual(str(test_setting), 'NOTIFICATION_METHOD_MAIL (for testuser): True') + NotificationUserSetting.set_setting( + 'NOTIFICATION_METHOD_MAIL', True, change_user=self.user, user=self.user + ) + test_setting = NotificationUserSetting.get_setting_object( + 'NOTIFICATION_METHOD_MAIL', user=self.user + ) + self.assertEqual( + str(test_setting), 'NOTIFICATION_METHOD_MAIL (for testuser): True' + ) class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase): @@ -638,26 +616,38 @@ class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase): registry.set_plugin_state('sample', True) # get data - url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'API_KEY'}) + url = reverse( + 'api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'API_KEY'} + ) response = self.get(url, expected_code=200) # check the right setting came through self.assertTrue(response.data['key'], 'API_KEY') self.assertTrue(response.data['plugin'], 'sample') self.assertTrue(response.data['type'], 'string') - self.assertTrue(response.data['description'], 'Key required for accessing external API') + self.assertTrue( + response.data['description'], 'Key required for accessing external API' + ) # Failure mode tests # Non-existent plugin - url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'doesnotexist', 'key': 'doesnotmatter'}) + url = reverse( + 'api-plugin-setting-detail', + kwargs={'plugin': 'doesnotexist', 'key': 'doesnotmatter'}, + ) response = self.get(url, expected_code=404) self.assertIn("Plugin 'doesnotexist' not installed", str(response.data)) # Wrong key - url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'doesnotexist'}) + url = reverse( + 'api-plugin-setting-detail', + kwargs={'plugin': 'sample', 'key': 'doesnotexist'}, + ) response = self.get(url, expected_code=404) - self.assertIn("Plugin 'sample' has no setting matching 'doesnotexist'", str(response.data)) + self.assertIn( + "Plugin 'sample' has no setting matching 'doesnotexist'", str(response.data) + ) def test_invalid_setting_key(self): """Test that an invalid setting key returns a 404.""" @@ -684,32 +674,30 @@ class WebhookMessageTests(TestCase): def test_missing_token(self): """Tests that token checks work.""" - response = self.client.post( - self.url, - content_type=CONTENT_TYPE_JSON, - ) + response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON) assert response.status_code == HTTPStatus.FORBIDDEN assert ( - json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR + json.loads(response.content)['detail'] + == WebhookView.model_class.MESSAGE_TOKEN_ERROR ) def test_bad_token(self): """Test that a wrong token is not working.""" response = self.client.post( - self.url, - content_type=CONTENT_TYPE_JSON, - **{'HTTP_TOKEN': '1234567fghj'}, + self.url, content_type=CONTENT_TYPE_JSON, **{'HTTP_TOKEN': '1234567fghj'} ) assert response.status_code == HTTPStatus.FORBIDDEN - assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR) + assert ( + json.loads(response.content)['detail'] + == WebhookView.model_class.MESSAGE_TOKEN_ERROR + ) def test_bad_url(self): """Test that a wrongly formed url is not working.""" response = self.client.post( - '/api/webhook/1234/', - content_type=CONTENT_TYPE_JSON, + '/api/webhook/1234/', content_type=CONTENT_TYPE_JSON ) assert response.status_code == HTTPStatus.NOT_FOUND @@ -725,7 +713,8 @@ class WebhookMessageTests(TestCase): assert response.status_code == HTTPStatus.NOT_ACCEPTABLE assert ( - json.loads(response.content)['detail'] == 'Expecting property name enclosed in double quotes' + json.loads(response.content)['detail'] + == 'Expecting property name enclosed in double quotes' ) def test_success_no_token_check(self): @@ -735,10 +724,7 @@ class WebhookMessageTests(TestCase): self.endpoint_def.save() # check - response = self.client.post( - self.url, - content_type=CONTENT_TYPE_JSON, - ) + response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON) assert response.status_code == HTTPStatus.OK assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK @@ -751,13 +737,13 @@ class WebhookMessageTests(TestCase): self.endpoint_def.save() # check - response = self.client.post( - self.url, - content_type=CONTENT_TYPE_JSON, - ) + response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON) assert response.status_code == HTTPStatus.FORBIDDEN - assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR) + assert ( + json.loads(response.content)['detail'] + == WebhookView.model_class.MESSAGE_TOKEN_ERROR + ) def test_success_hmac(self): """Test with a valid HMAC provided.""" @@ -783,7 +769,7 @@ class WebhookMessageTests(TestCase): """ response = self.client.post( self.url, - data={"this": "is a message"}, + data={'this': 'is a message'}, content_type=CONTENT_TYPE_JSON, **{'HTTP_TOKEN': str(self.endpoint_def.token)}, ) @@ -791,15 +777,13 @@ class WebhookMessageTests(TestCase): assert response.status_code == HTTPStatus.OK assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK message = WebhookMessage.objects.get() - assert message.body == {"this": "is a message"} + assert message.body == {'this': 'is a message'} class NotificationTest(InvenTreeAPITestCase): """Tests for NotificationEntry.""" - fixtures = [ - 'users', - ] + fixtures = ['users'] def test_check_notification_entries(self): """Test that notification entries can be created.""" @@ -832,7 +816,10 @@ class NotificationTest(InvenTreeAPITestCase): self.assertIn('GET', response.data['actions']) self.assertNotIn('POST', response.data['actions']) - self.assertEqual(response.data['description'], 'List view for all notifications of the current user.') + self.assertEqual( + response.data['description'], + 'List view for all notifications of the current user.', + ) # POST action should fail (not allowed) response = self.post(url, {}, expected_code=405) @@ -867,13 +854,7 @@ class NotificationTest(InvenTreeAPITestCase): ntf.save() # Read out via API again - response = self.get( - url, - { - 'read': True, - }, - expected_code=200 - ) + response = self.get(url, {'read': True}, expected_code=200) # Check validity of returned data self.assertEqual(len(response.data), 3) @@ -882,15 +863,7 @@ class NotificationTest(InvenTreeAPITestCase): # Now, let's bulk delete all 'unread' notifications via the API, # but only associated with the logged in user - response = self.delete( - url, - { - 'filters': { - 'read': False, - } - }, - expected_code=204, - ) + response = self.delete(url, {'filters': {'read': False}}, expected_code=204) # Only 7 notifications should have been deleted, # as the notifications associated with other users must remain untouched @@ -907,13 +880,17 @@ class CommonTest(InvenTreeAPITestCase): from plugin import registry # set flag true - common.models.InvenTreeSetting.set_setting('SERVER_RESTART_REQUIRED', True, None) + common.models.InvenTreeSetting.set_setting( + 'SERVER_RESTART_REQUIRED', True, None + ) # reload the app registry.reload_plugins() # now it should be false again - self.assertFalse(common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED')) + self.assertFalse( + common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED') + ) def test_config_api(self): """Test config URLs.""" @@ -926,8 +903,13 @@ class CommonTest(InvenTreeAPITestCase): # Successful checks data = [ - self.get(reverse('api-config-list'), expected_code=200).data[0], # list endpoint - self.get(reverse('api-config-detail', kwargs={'key': 'INVENTREE_DEBUG'}), expected_code=200).data, # detail endpoint + self.get(reverse('api-config-list'), expected_code=200).data[ + 0 + ], # list endpoint + self.get( + reverse('api-config-detail', kwargs={'key': 'INVENTREE_DEBUG'}), + expected_code=200, + ).data, # detail endpoint ] for item in data: @@ -956,21 +938,33 @@ class CommonTest(InvenTreeAPITestCase): self.assertEqual(response.data[0]['key'], 'EXPERIMENTAL') self.assertTrue(response.data[0]['conditions']) - response = self.get(reverse('api-flag-detail', kwargs={'key': 'EXPERIMENTAL'}), expected_code=200) + response = self.get( + reverse('api-flag-detail', kwargs={'key': 'EXPERIMENTAL'}), + expected_code=200, + ) self.assertEqual(len(response.data), 3) self.assertEqual(response.data['key'], 'EXPERIMENTAL') self.assertTrue(response.data['conditions']) # Try without param -> false - response = self.get(reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), expected_code=200) + response = self.get( + reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), expected_code=200 + ) self.assertFalse(response.data['state']) # Try with param -> true - response = self.get(reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), {'ngen': ''}, expected_code=200) + response = self.get( + reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), + {'ngen': ''}, + expected_code=200, + ) self.assertTrue(response.data['state']) # Try non existent flag - response = self.get(reverse('api-flag-detail', kwargs={'key': 'NON_EXISTENT'}), expected_code=404) + response = self.get( + reverse('api-flag-detail', kwargs={'key': 'NON_EXISTENT'}), + expected_code=404, + ) # Turn into normal user again self.user.is_superuser = False @@ -1038,7 +1032,7 @@ class CurrencyAPITests(InvenTreeAPITestCase): # Delay and try again time.sleep(10) - raise TimeoutError("Could not refresh currency exchange data after 5 attempts") + raise TimeoutError('Could not refresh currency exchange data after 5 attempts') class NotesImageTest(InvenTreeAPITestCase): @@ -1052,25 +1046,29 @@ class NotesImageTest(InvenTreeAPITestCase): response = self.post( reverse('api-notes-image-list'), data={ - 'image': SimpleUploadedFile('test.txt', b"this is not an image file", content_type='text/plain'), - }, - format='multipart', - expected_code=400 - ) - - self.assertIn("Upload a valid image", str(response.data['image'])) - - # Test upload of an invalid image file - response = self.post( - reverse('api-notes-image-list'), - data={ - 'image': SimpleUploadedFile('test.png', b"this is not an image file", content_type='image/png'), + 'image': SimpleUploadedFile( + 'test.txt', b'this is not an image file', content_type='text/plain' + ) }, format='multipart', expected_code=400, ) - self.assertIn("Upload a valid image", str(response.data['image'])) + self.assertIn('Upload a valid image', str(response.data['image'])) + + # Test upload of an invalid image file + response = self.post( + reverse('api-notes-image-list'), + data={ + 'image': SimpleUploadedFile( + 'test.png', b'this is not an image file', content_type='image/png' + ) + }, + format='multipart', + expected_code=400, + ) + + self.assertIn('Upload a valid image', str(response.data['image'])) # Check that no extra database entries have been created self.assertEqual(NotesImage.objects.count(), n) @@ -1089,10 +1087,12 @@ class NotesImageTest(InvenTreeAPITestCase): self.post( reverse('api-notes-image-list'), data={ - 'image': SimpleUploadedFile('test.png', contents, content_type='image/png'), + 'image': SimpleUploadedFile( + 'test.png', contents, content_type='image/png' + ) }, format='multipart', - expected_code=201 + expected_code=201, ) # Check that a new file has been created @@ -1136,7 +1136,7 @@ class ProjectCodesTest(InvenTreeAPITestCase): # Delete it self.delete( reverse('api-project-code-detail', kwargs={'pk': code.pk}), - expected_code=204 + expected_code=204, ) # Check it is gone @@ -1147,25 +1147,22 @@ class ProjectCodesTest(InvenTreeAPITestCase): # Create a new project code response = self.post( self.url, - data={ - 'code': 'PRJ-001', - 'description': 'Test project code', - }, - expected_code=400 + data={'code': 'PRJ-001', 'description': 'Test project code'}, + expected_code=400, ) - self.assertIn('project code with this Project Code already exists', str(response.data['code'])) + self.assertIn( + 'project code with this Project Code already exists', + str(response.data['code']), + ) def test_write_access(self): """Test that non-staff users have read-only access""" # By default user has staff access, can create a new project code response = self.post( self.url, - data={ - 'code': 'PRJ-xxx', - 'description': 'Test project code', - }, - expected_code=201 + data={'code': 'PRJ-xxx', 'description': 'Test project code'}, + expected_code=201, ) pk = response.data['pk'] @@ -1173,10 +1170,8 @@ class ProjectCodesTest(InvenTreeAPITestCase): # Test we can edit, also response = self.patch( reverse('api-project-code-detail', kwargs={'pk': pk}), - data={ - 'code': 'PRJ-999', - }, - expected_code=200 + data={'code': 'PRJ-999'}, + expected_code=200, ) self.assertEqual(response.data['code'], 'PRJ-999') @@ -1188,20 +1183,15 @@ class ProjectCodesTest(InvenTreeAPITestCase): # As user does not have staff access, should return 403 for list endpoint response = self.post( self.url, - data={ - 'code': 'PRJ-123', - 'description': 'Test project code' - }, - expected_code=403 + data={'code': 'PRJ-123', 'description': 'Test project code'}, + expected_code=403, ) # Should also return 403 for detail endpoint response = self.patch( reverse('api-project-code-detail', kwargs={'pk': pk}), - data={ - 'code': 'PRJ-999', - }, - expected_code=403 + data={'code': 'PRJ-999'}, + expected_code=403, ) @@ -1219,8 +1209,14 @@ class CustomUnitAPITest(InvenTreeAPITestCase): super().setUpTestData() units = [ - CustomUnit(name='metres_per_amp', definition='meter / ampere', symbol='m/A'), - CustomUnit(name='hectares_per_second', definition='hectares per second', symbol='ha/s'), + CustomUnit( + name='metres_per_amp', definition='meter / ampere', symbol='m/A' + ), + CustomUnit( + name='hectares_per_second', + definition='hectares per second', + symbol='ha/s', + ), ] CustomUnit.objects.bulk_create(units) @@ -1240,10 +1236,8 @@ class CustomUnitAPITest(InvenTreeAPITestCase): self.patch( reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}), - { - 'name': 'new_unit_name', - }, - expected_code=403 + {'name': 'new_unit_name'}, + expected_code=403, ) # Ok, what if we have permission? @@ -1252,9 +1246,7 @@ class CustomUnitAPITest(InvenTreeAPITestCase): self.patch( reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}), - { - 'name': 'new_unit_name', - }, + {'name': 'new_unit_name'}, # expected_code=200 ) @@ -1269,21 +1261,9 @@ class CustomUnitAPITest(InvenTreeAPITestCase): self.user.save() # Test invalid 'name' values (must be valid identifier) - invalid_name_values = [ - '1', - '1abc', - 'abc def', - 'abc-def', - 'abc.def', - ] + invalid_name_values = ['1', '1abc', 'abc def', 'abc-def', 'abc.def'] url = reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}) for name in invalid_name_values: - self.patch( - url, - { - 'name': name, - }, - expected_code=400 - ) + self.patch(url, {'name': name}, expected_code=400) diff --git a/InvenTree/common/urls.py b/InvenTree/common/urls.py index 4803d86236..95f89631df 100644 --- a/InvenTree/common/urls.py +++ b/InvenTree/common/urls.py @@ -1,4 +1,3 @@ """URL lookup for common views.""" -common_urls = [ -] +common_urls = [] diff --git a/InvenTree/common/views.py b/InvenTree/common/views.py index 2eeb48b5df..3d9cd53c22 100644 --- a/InvenTree/common/views.py +++ b/InvenTree/common/views.py @@ -81,11 +81,7 @@ class FileManagementFormView(MultiStepFormView): ('fields', forms.MatchFieldForm), ('items', forms.MatchItemForm), ] - form_steps_description = [ - _("Upload File"), - _("Match Fields"), - _("Match Items"), - ] + form_steps_description = [_('Upload File'), _('Match Fields'), _('Match Items')] media_folder = 'file_upload/' extra_context_data = {} @@ -95,8 +91,12 @@ class FileManagementFormView(MultiStepFormView): super().__init__(self, *args, **kwargs) # Check for file manager class - if not hasattr(self, 'file_manager_class') and not issubclass(self.file_manager_class, FileManager): - raise NotImplementedError('A subclass of a file manager class needs to be set!') + if not hasattr(self, 'file_manager_class') and not issubclass( + self.file_manager_class, FileManager + ): + raise NotImplementedError( + 'A subclass of a file manager class needs to be set!' + ) def get_context_data(self, form=None, **kwargs): """Handle context data.""" @@ -106,7 +106,6 @@ class FileManagementFormView(MultiStepFormView): context = super().get_context_data(form=form, **kwargs) if self.steps.current in ('fields', 'items'): - # Get columns and row data self.columns = self.file_manager.columns() self.rows = self.file_manager.rows() @@ -140,7 +139,9 @@ class FileManagementFormView(MultiStepFormView): # Get file file = upload_files.get('upload-file', None) if file: - self.file_manager = self.file_manager_class(file=file, name=self.name) + self.file_manager = self.file_manager_class( + file=file, name=self.name + ) def get_form_kwargs(self, step=None): """Update kwargs to dynamically build forms.""" @@ -150,15 +151,11 @@ class FileManagementFormView(MultiStepFormView): if step == 'upload': # Dynamically build upload form if self.name: - kwargs = { - 'name': self.name - } + kwargs = {'name': self.name} return kwargs elif step == 'fields': # Dynamically build match field form - kwargs = { - 'file_manager': self.file_manager - } + kwargs = {'file_manager': self.file_manager} return kwargs elif step == 'items': # Dynamically build match item form @@ -206,7 +203,6 @@ class FileManagementFormView(MultiStepFormView): self.row_data = {} for item, value in form_data.items(): - # Column names as passed as col_name_ where idx is an integer # Extract the column names @@ -220,7 +216,6 @@ class FileManagementFormView(MultiStepFormView): # Extract the column selections (in the 'select fields' view) if item.startswith('fields-'): - try: col_name = item.replace('fields-', '') except ValueError: @@ -258,10 +253,7 @@ class FileManagementFormView(MultiStepFormView): self.columns = [] for idx, value in self.column_names.items(): - header = ({ - 'name': value, - 'guess': self.column_selections.get(idx, ''), - }) + header = {'name': value, 'guess': self.column_selections.get(idx, '')} self.columns.append(header) if self.row_data: @@ -280,18 +272,10 @@ class FileManagementFormView(MultiStepFormView): 'guess': self.column_selections[idx], } - cell_data = { - 'cell': item, - 'idx': idx, - 'column': column_data, - } + cell_data = {'cell': item, 'idx': idx, 'column': column_data} data.append(cell_data) - row = { - 'index': row_idx, - 'data': data, - 'errors': {}, - } + row = {'index': row_idx, 'data': data, 'errors': {}} self.rows.append(row) @@ -344,11 +328,7 @@ class FileManagementFormView(MultiStepFormView): try: if idx not in items: # Insert into items - items.update({ - idx: { - self.form_field_map[field]: form_value, - } - }) + items.update({idx: {self.form_field_map[field]: form_value}}) else: # Update items items[idx][self.form_field_map[field]] = form_value @@ -383,14 +363,15 @@ class FileManagementFormView(MultiStepFormView): duplicates = [] for col in self.column_names: - if col in self.column_selections: guess = self.column_selections[col] else: guess = None if guess: - n = list(self.column_selections.values()).count(self.column_selections[col]) + n = list(self.column_selections.values()).count( + self.column_selections[col] + ) if n > 1 and self.column_selections[col] not in duplicates: duplicates.append(self.column_selections[col]) @@ -459,7 +440,9 @@ class FileManagementAjaxView(AjaxView): wizard_back = self.request.POST.get('act-btn_back', None) if wizard_back: back_step_index = self.get_step_index() - 1 - self.storage.current_step = list(self.get_form_list().keys())[back_step_index] + self.storage.current_step = list(self.get_form_list().keys())[ + back_step_index + ] return self.renderJsonResponse(request, data={'form_valid': None}) # validate form @@ -499,13 +482,19 @@ class FileManagementAjaxView(AjaxView): data = {} self.setTemplate() - return super().renderJsonResponse(request, form=form, data=data, context=context) + return super().renderJsonResponse( + request, form=form, data=data, context=context + ) def get_data(self) -> dict: """Get extra context data.""" data = super().get_data() data['hideErrorMessage'] = '1' # hide the error - buttons = [{'name': 'back', 'title': _('Previous Step')}] if self.get_step_index() > 0 else [] + buttons = ( + [{'name': 'back', 'title': _('Previous Step')}] + if self.get_step_index() > 0 + else [] + ) data['buttons'] = buttons # set buttons return data diff --git a/InvenTree/company/admin.py b/InvenTree/company/admin.py index 5cf7bbda87..ff25a2e086 100644 --- a/InvenTree/company/admin.py +++ b/InvenTree/company/admin.py @@ -9,9 +9,16 @@ from import_export.fields import Field from InvenTree.admin import InvenTreeResource from part.models import Part -from .models import (Address, Company, Contact, ManufacturerPart, - ManufacturerPartAttachment, ManufacturerPartParameter, - SupplierPart, SupplierPriceBreak) +from .models import ( + Address, + Company, + Contact, + ManufacturerPart, + ManufacturerPartAttachment, + ManufacturerPartParameter, + SupplierPart, + SupplierPriceBreak, +) class CompanyResource(InvenTreeResource): @@ -19,6 +26,7 @@ class CompanyResource(InvenTreeResource): class Meta: """Metaclass defines extra options""" + model = Company skip_unchanged = True report_skipped = False @@ -32,10 +40,7 @@ class CompanyAdmin(ImportExportModelAdmin): list_display = ('name', 'website', 'contact') - search_fields = [ - 'name', - 'description', - ] + search_fields = ['name', 'description'] class SupplierPartResource(InvenTreeResource): @@ -43,6 +48,7 @@ class SupplierPartResource(InvenTreeResource): class Meta: """Metaclass defines extra admin options""" + model = SupplierPart skip_unchanged = True report_skipped = True @@ -70,18 +76,11 @@ class SupplierPartAdmin(ImportExportModelAdmin): list_display = ('part', 'supplier', 'SKU') - search_fields = [ - 'supplier__name', - 'part__name', - 'manufacturer_part__MPN', - 'SKU', - ] + search_fields = ['supplier__name', 'part__name', 'manufacturer_part__MPN', 'SKU'] - inlines = [ - SupplierPriceBreakInline, - ] + inlines = [SupplierPriceBreakInline] - autocomplete_fields = ('part', 'supplier', 'manufacturer_part',) + autocomplete_fields = ('part', 'supplier', 'manufacturer_part') class ManufacturerPartResource(InvenTreeResource): @@ -89,6 +88,7 @@ class ManufacturerPartResource(InvenTreeResource): class Meta: """Metaclass defines extra admin options""" + model = ManufacturerPart skip_unchanged = True report_skipped = True @@ -98,7 +98,9 @@ class ManufacturerPartResource(InvenTreeResource): part_name = Field(attribute='part__full_name', readonly=True) - manufacturer = Field(attribute='manufacturer', widget=widgets.ForeignKeyWidget(Company)) + manufacturer = Field( + attribute='manufacturer', widget=widgets.ForeignKeyWidget(Company) + ) manufacturer_name = Field(attribute='manufacturer__name', readonly=True) @@ -110,13 +112,9 @@ class ManufacturerPartAdmin(ImportExportModelAdmin): list_display = ('part', 'manufacturer', 'MPN') - search_fields = [ - 'manufacturer__name', - 'part__name', - 'MPN', - ] + search_fields = ['manufacturer__name', 'part__name', 'MPN'] - autocomplete_fields = ('part', 'manufacturer',) + autocomplete_fields = ('part', 'manufacturer') class ManufacturerPartAttachmentAdmin(ImportExportModelAdmin): @@ -132,6 +130,7 @@ class ManufacturerPartParameterResource(InvenTreeResource): class Meta: """Metaclass defines extra admin options""" + model = ManufacturerPartParameter skip_unchanged = True report_skipped = True @@ -145,11 +144,7 @@ class ManufacturerPartParameterAdmin(ImportExportModelAdmin): list_display = ('manufacturer_part', 'name', 'value') - search_fields = [ - 'manufacturer_part__manufacturer__name', - 'name', - 'value' - ] + search_fields = ['manufacturer_part__manufacturer__name', 'name', 'value'] autocomplete_fields = ('manufacturer_part',) @@ -159,6 +154,7 @@ class SupplierPriceBreakResource(InvenTreeResource): class Meta: """Metaclass defines extra admin options""" + model = SupplierPriceBreak skip_unchanged = True report_skipped = False @@ -192,6 +188,7 @@ class AddressResource(InvenTreeResource): class Meta: """Metaclass defining extra options""" + model = Address skip_unchanged = True report_skipped = False @@ -207,11 +204,7 @@ class AddressAdmin(ImportExportModelAdmin): list_display = ('company', 'line1', 'postal_code', 'country') - search_fields = [ - 'company', - 'country', - 'postal_code', - ] + search_fields = ['company', 'country', 'postal_code'] class ContactResource(InvenTreeResource): @@ -219,6 +212,7 @@ class ContactResource(InvenTreeResource): class Meta: """Metaclass defining extra options""" + model = Contact skip_unchanged = True report_skipped = False @@ -234,11 +228,7 @@ class ContactAdmin(ImportExportModelAdmin): list_display = ('company', 'name', 'role', 'email', 'phone') - search_fields = [ - 'company', - 'name', - 'email', - ] + search_fields = ['company', 'name', 'email'] admin.site.register(Company, CompanyAdmin) diff --git a/InvenTree/company/api.py b/InvenTree/company/api.py index 51c27f2279..b35ce16d07 100644 --- a/InvenTree/company/api.py +++ b/InvenTree/company/api.py @@ -6,23 +6,37 @@ from django.urls import include, path, re_path from django_filters import rest_framework as rest_filters import part.models -from InvenTree.api import (AttachmentMixin, ListCreateDestroyAPIView, - MetadataView) -from InvenTree.filters import (ORDER_FILTER, SEARCH_ORDER_FILTER, - SEARCH_ORDER_FILTER_ALIAS) +from InvenTree.api import AttachmentMixin, ListCreateDestroyAPIView, MetadataView +from InvenTree.filters import ( + ORDER_FILTER, + SEARCH_ORDER_FILTER, + SEARCH_ORDER_FILTER_ALIAS, +) from InvenTree.helpers import str2bool from InvenTree.mixins import ListCreateAPI, RetrieveUpdateDestroyAPI -from .models import (Address, Company, CompanyAttachment, Contact, - ManufacturerPart, ManufacturerPartAttachment, - ManufacturerPartParameter, SupplierPart, - SupplierPriceBreak) -from .serializers import (AddressSerializer, CompanyAttachmentSerializer, - CompanySerializer, ContactSerializer, - ManufacturerPartAttachmentSerializer, - ManufacturerPartParameterSerializer, - ManufacturerPartSerializer, SupplierPartSerializer, - SupplierPriceBreakSerializer) +from .models import ( + Address, + Company, + CompanyAttachment, + Contact, + ManufacturerPart, + ManufacturerPartAttachment, + ManufacturerPartParameter, + SupplierPart, + SupplierPriceBreak, +) +from .serializers import ( + AddressSerializer, + CompanyAttachmentSerializer, + CompanySerializer, + ContactSerializer, + ManufacturerPartAttachmentSerializer, + ManufacturerPartParameterSerializer, + ManufacturerPartSerializer, + SupplierPartSerializer, + SupplierPriceBreakSerializer, +) class CompanyList(ListCreateAPI): @@ -46,24 +60,11 @@ class CompanyList(ListCreateAPI): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'is_customer', - 'is_manufacturer', - 'is_supplier', - 'name', - ] + filterset_fields = ['is_customer', 'is_manufacturer', 'is_supplier', 'name'] - search_fields = [ - 'name', - 'description', - 'website', - ] + search_fields = ['name', 'description', 'website'] - ordering_fields = [ - 'name', - 'parts_supplied', - 'parts_manufactured', - ] + ordering_fields = ['name', 'parts_supplied', 'parts_manufactured'] ordering = 'name' @@ -88,9 +89,7 @@ class CompanyAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): queryset = CompanyAttachment.objects.all() serializer_class = CompanyAttachmentSerializer - filterset_fields = [ - 'company', - ] + filterset_fields = ['company'] class CompanyAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): @@ -108,18 +107,11 @@ class ContactList(ListCreateDestroyAPIView): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'company', - ] + filterset_fields = ['company'] - search_fields = [ - 'company__name', - 'name', - ] + search_fields = ['company__name', 'name'] - ordering_fields = [ - 'name', - ] + ordering_fields = ['name'] ordering = 'name' @@ -139,13 +131,9 @@ class AddressList(ListCreateDestroyAPIView): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'company', - ] + filterset_fields = ['company'] - ordering_fields = [ - 'title', - ] + ordering_fields = ['title'] ordering = 'title' @@ -164,13 +152,7 @@ class ManufacturerPartFilter(rest_filters.FilterSet): """Metaclass options.""" model = ManufacturerPart - fields = [ - 'manufacturer', - 'MPN', - 'part', - 'tags__name', - 'tags__slug', - ] + fields = ['manufacturer', 'MPN', 'part', 'tags__name', 'tags__slug'] # Filter by 'active' status of linked part active = rest_filters.BooleanFilter(field_name='part__active') @@ -184,10 +166,7 @@ class ManufacturerPartList(ListCreateDestroyAPIView): """ queryset = ManufacturerPart.objects.all().prefetch_related( - 'part', - 'manufacturer', - 'supplier_parts', - 'tags', + 'part', 'manufacturer', 'supplier_parts', 'tags' ) serializer_class = ManufacturerPartSerializer @@ -200,7 +179,9 @@ class ManufacturerPartList(ListCreateDestroyAPIView): params = self.request.query_params kwargs['part_detail'] = str2bool(params.get('part_detail', None)) - kwargs['manufacturer_detail'] = str2bool(params.get('manufacturer_detail', None)) + kwargs['manufacturer_detail'] = str2bool( + params.get('manufacturer_detail', None) + ) kwargs['pretty'] = str2bool(params.get('pretty', None)) except AttributeError: pass @@ -241,9 +222,7 @@ class ManufacturerPartAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): queryset = ManufacturerPartAttachment.objects.all() serializer_class = ManufacturerPartAttachmentSerializer - filterset_fields = [ - 'manufacturer_part', - ] + filterset_fields = ['manufacturer_part'] class ManufacturerPartAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): @@ -258,17 +237,17 @@ class ManufacturerPartParameterFilter(rest_filters.FilterSet): class Meta: """Metaclass options""" + model = ManufacturerPartParameter - fields = [ - 'name', - 'value', - 'units', - 'manufacturer_part', - ] + fields = ['name', 'value', 'units', 'manufacturer_part'] - manufacturer = rest_filters.ModelChoiceFilter(queryset=Company.objects.all(), field_name='manufacturer_part__manufacturer') + manufacturer = rest_filters.ModelChoiceFilter( + queryset=Company.objects.all(), field_name='manufacturer_part__manufacturer' + ) - part = rest_filters.ModelChoiceFilter(queryset=part.models.Part.objects.all(), field_name='manufacturer_part__part') + part = rest_filters.ModelChoiceFilter( + queryset=part.models.Part.objects.all(), field_name='manufacturer_part__part' + ) class ManufacturerPartParameterList(ListCreateDestroyAPIView): @@ -284,9 +263,7 @@ class ManufacturerPartParameterList(ListCreateDestroyAPIView): try: params = self.request.query_params - optional_fields = [ - 'manufacturer_part_detail', - ] + optional_fields = ['manufacturer_part_detail'] for key in optional_fields: kwargs[key] = str2bool(params.get(key, None)) @@ -300,11 +277,7 @@ class ManufacturerPartParameterList(ListCreateDestroyAPIView): filter_backends = SEARCH_ORDER_FILTER - search_fields = [ - 'name', - 'value', - 'units', - ] + search_fields = ['name', 'value', 'units'] class ManufacturerPartParameterDetail(RetrieveUpdateDestroyAPI): @@ -337,7 +310,7 @@ class SupplierPartFilter(rest_filters.FilterSet): MPN = rest_filters.CharFilter( label='Manufacturer Part Number', field_name='manufacturer_part__MPN', - lookup_expr='iexact' + lookup_expr='iexact', ) @@ -348,9 +321,7 @@ class SupplierPartList(ListCreateDestroyAPIView): - POST: Create a new SupplierPart object """ - queryset = SupplierPart.objects.all().prefetch_related( - 'tags', - ) + queryset = SupplierPart.objects.all().prefetch_related('tags') filterset_class = SupplierPartFilter def get_queryset(self, *args, **kwargs): @@ -376,7 +347,9 @@ class SupplierPartList(ListCreateDestroyAPIView): company = params.get('company', None) if company is not None: - queryset = queryset.filter(Q(manufacturer_part__manufacturer=company) | Q(supplier=company)).distinct() + queryset = queryset.filter( + Q(manufacturer_part__manufacturer=company) | Q(supplier=company) + ).distinct() return queryset @@ -387,7 +360,9 @@ class SupplierPartList(ListCreateDestroyAPIView): params = self.request.query_params kwargs['part_detail'] = str2bool(params.get('part_detail', None)) kwargs['supplier_detail'] = str2bool(params.get('supplier_detail', True)) - kwargs['manufacturer_detail'] = str2bool(params.get('manufacturer_detail', None)) + kwargs['manufacturer_detail'] = str2bool( + params.get('manufacturer_detail', None) + ) kwargs['pretty'] = str2bool(params.get('pretty', None)) except AttributeError: pass @@ -446,8 +421,7 @@ class SupplierPartDetail(RetrieveUpdateDestroyAPI): queryset = SupplierPart.objects.all() serializer_class = SupplierPartSerializer - read_only_fields = [ - ] + read_only_fields = [] class SupplierPriceBreakFilter(rest_filters.FilterSet): @@ -457,10 +431,7 @@ class SupplierPriceBreakFilter(rest_filters.FilterSet): """Metaclass options""" model = SupplierPriceBreak - fields = [ - 'part', - 'quantity', - ] + fields = ['part', 'quantity'] base_part = rest_filters.ModelChoiceFilter( label='Base Part', @@ -469,9 +440,7 @@ class SupplierPriceBreakFilter(rest_filters.FilterSet): ) supplier = rest_filters.ModelChoiceFilter( - label='Supplier', - queryset=Company.objects.all(), - field_name='part__supplier', + label='Supplier', queryset=Company.objects.all(), field_name='part__supplier' ) @@ -503,9 +472,7 @@ class SupplierPriceBreakList(ListCreateAPI): filter_backends = ORDER_FILTER - ordering_fields = [ - 'quantity', - ] + ordering_fields = ['quantity'] ordering = 'quantity' @@ -518,37 +485,74 @@ class SupplierPriceBreakDetail(RetrieveUpdateDestroyAPI): manufacturer_part_api_urls = [ - # Base URL for ManufacturerPartAttachment API endpoints - re_path(r'^attachment/', include([ - path(r'/', ManufacturerPartAttachmentDetail.as_view(), name='api-manufacturer-part-attachment-detail'), - re_path(r'^$', ManufacturerPartAttachmentList.as_view(), name='api-manufacturer-part-attachment-list'), - ])), - - re_path(r'^parameter/', include([ - path(r'/', ManufacturerPartParameterDetail.as_view(), name='api-manufacturer-part-parameter-detail'), - - # Catch anything else - re_path(r'^.*$', ManufacturerPartParameterList.as_view(), name='api-manufacturer-part-parameter-list'), - ])), - - re_path(r'^(?P\d+)/?', include([ - re_path('^metadata/', MetadataView.as_view(), {'model': ManufacturerPart}, name='api-manufacturer-part-metadata'), - re_path('^.*$', ManufacturerPartDetail.as_view(), name='api-manufacturer-part-detail'), - ])), - + re_path( + r'^attachment/', + include([ + path( + r'/', + ManufacturerPartAttachmentDetail.as_view(), + name='api-manufacturer-part-attachment-detail', + ), + re_path( + r'^$', + ManufacturerPartAttachmentList.as_view(), + name='api-manufacturer-part-attachment-list', + ), + ]), + ), + re_path( + r'^parameter/', + include([ + path( + r'/', + ManufacturerPartParameterDetail.as_view(), + name='api-manufacturer-part-parameter-detail', + ), + # Catch anything else + re_path( + r'^.*$', + ManufacturerPartParameterList.as_view(), + name='api-manufacturer-part-parameter-list', + ), + ]), + ), + re_path( + r'^(?P\d+)/?', + include([ + re_path( + '^metadata/', + MetadataView.as_view(), + {'model': ManufacturerPart}, + name='api-manufacturer-part-metadata', + ), + re_path( + '^.*$', + ManufacturerPartDetail.as_view(), + name='api-manufacturer-part-detail', + ), + ]), + ), # Catch anything else re_path(r'^.*$', ManufacturerPartList.as_view(), name='api-manufacturer-part-list'), ] supplier_part_api_urls = [ - - re_path(r'^(?P\d+)/?', include([ - re_path('^metadata/', MetadataView.as_view(), {'model': SupplierPart}, name='api-supplier-part-metadata'), - re_path('^.*$', SupplierPartDetail.as_view(), name='api-supplier-part-detail'), - ])), - + re_path( + r'^(?P\d+)/?', + include([ + re_path( + '^metadata/', + MetadataView.as_view(), + {'model': SupplierPart}, + name='api-supplier-part-metadata', + ), + re_path( + '^.*$', SupplierPartDetail.as_view(), name='api-supplier-part-detail' + ), + ]), + ), # Catch anything else re_path(r'^.*$', SupplierPartList.as_view(), name='api-supplier-part-list'), ] @@ -556,39 +560,74 @@ supplier_part_api_urls = [ company_api_urls = [ re_path(r'^part/manufacturer/', include(manufacturer_part_api_urls)), - re_path(r'^part/', include(supplier_part_api_urls)), - # Supplier price breaks - re_path(r'^price-break/', include([ - - re_path(r'^(?P\d+)/?', SupplierPriceBreakDetail.as_view(), name='api-part-supplier-price-detail'), - re_path(r'^.*$', SupplierPriceBreakList.as_view(), name='api-part-supplier-price-list'), - ])), - - re_path(r'^(?P\d+)/?', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': Company}, name='api-company-metadata'), - re_path(r'^.*$', CompanyDetail.as_view(), name='api-company-detail'), - ])), - - re_path(r'^attachment/', include([ - path(r'/', CompanyAttachmentDetail.as_view(), name='api-company-attachment-detail'), - re_path(r'^$', CompanyAttachmentList.as_view(), name='api-company-attachment-list'), - ])), - - re_path(r'^contact/', include([ - re_path(r'^(?P\d+)/?', include([ - re_path('^metadata/', MetadataView.as_view(), {'model': Contact}, name='api-contact-metadata'), - re_path('^.*$', ContactDetail.as_view(), name='api-contact-detail'), - ])), - re_path(r'^.*$', ContactList.as_view(), name='api-contact-list'), - ])), - - re_path(r'^address/', include([ - path('/', AddressDetail.as_view(), name='api-address-detail'), - re_path(r'^.*$', AddressList.as_view(), name='api-address-list'), - ])), - + re_path( + r'^price-break/', + include([ + re_path( + r'^(?P\d+)/?', + SupplierPriceBreakDetail.as_view(), + name='api-part-supplier-price-detail', + ), + re_path( + r'^.*$', + SupplierPriceBreakList.as_view(), + name='api-part-supplier-price-list', + ), + ]), + ), + re_path( + r'^(?P\d+)/?', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': Company}, + name='api-company-metadata', + ), + re_path(r'^.*$', CompanyDetail.as_view(), name='api-company-detail'), + ]), + ), + re_path( + r'^attachment/', + include([ + path( + r'/', + CompanyAttachmentDetail.as_view(), + name='api-company-attachment-detail', + ), + re_path( + r'^$', + CompanyAttachmentList.as_view(), + name='api-company-attachment-list', + ), + ]), + ), + re_path( + r'^contact/', + include([ + re_path( + r'^(?P\d+)/?', + include([ + re_path( + '^metadata/', + MetadataView.as_view(), + {'model': Contact}, + name='api-contact-metadata', + ), + re_path('^.*$', ContactDetail.as_view(), name='api-contact-detail'), + ]), + ), + re_path(r'^.*$', ContactList.as_view(), name='api-contact-list'), + ]), + ), + re_path( + r'^address/', + include([ + path('/', AddressDetail.as_view(), name='api-address-detail'), + re_path(r'^.*$', AddressList.as_view(), name='api-address-list'), + ]), + ), re_path(r'^.*$', CompanyList.as_view(), name='api-company-list'), - ] diff --git a/InvenTree/company/models.py b/InvenTree/company/models.py index 154de622ca..9097708697 100644 --- a/InvenTree/company/models.py +++ b/InvenTree/company/models.py @@ -29,8 +29,12 @@ import InvenTree.tasks import InvenTree.validators from common.settings import currency_code_default from InvenTree.fields import InvenTreeURLField, RoundingDecimalField -from InvenTree.models import (InvenTreeAttachment, InvenTreeBarcodeMixin, - InvenTreeNotesMixin, MetadataMixin) +from InvenTree.models import ( + InvenTreeAttachment, + InvenTreeBarcodeMixin, + InvenTreeNotesMixin, + MetadataMixin, +) from InvenTree.status_codes import PurchaseOrderStatusGroups @@ -87,20 +91,24 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model): class Meta: """Metaclass defines extra model options""" - ordering = ['name', ] + + ordering = ['name'] constraints = [ UniqueConstraint(fields=['name', 'email'], name='unique_name_email_pair') ] - verbose_name_plural = "Companies" + verbose_name_plural = 'Companies' @staticmethod def get_api_url(): """Return the API URL associated with the Company model""" return reverse('api-company-list') - name = models.CharField(max_length=100, blank=False, - help_text=_('Company name'), - verbose_name=_('Company name')) + name = models.CharField( + max_length=100, + blank=False, + help_text=_('Company name'), + verbose_name=_('Company name'), + ) description = models.CharField( max_length=500, @@ -110,41 +118,62 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model): ) website = InvenTreeURLField( - blank=True, - verbose_name=_('Website'), - help_text=_('Company website URL') + blank=True, verbose_name=_('Website'), help_text=_('Company website URL') ) - phone = models.CharField(max_length=50, - verbose_name=_('Phone number'), - blank=True, help_text=_('Contact phone number')) + phone = models.CharField( + max_length=50, + verbose_name=_('Phone number'), + blank=True, + help_text=_('Contact phone number'), + ) - email = models.EmailField(blank=True, null=True, - verbose_name=_('Email'), help_text=_('Contact email address')) + email = models.EmailField( + blank=True, + null=True, + verbose_name=_('Email'), + help_text=_('Contact email address'), + ) - contact = models.CharField(max_length=100, - verbose_name=_('Contact'), - blank=True, help_text=_('Point of contact')) + contact = models.CharField( + max_length=100, + verbose_name=_('Contact'), + blank=True, + help_text=_('Point of contact'), + ) - link = InvenTreeURLField(blank=True, verbose_name=_('Link'), help_text=_('Link to external company information')) + link = InvenTreeURLField( + blank=True, + verbose_name=_('Link'), + help_text=_('Link to external company information'), + ) image = StdImageField( upload_to=rename_company_image, null=True, blank=True, - variations={ - 'thumbnail': (128, 128), - 'preview': (256, 256), - }, + variations={'thumbnail': (128, 128), 'preview': (256, 256)}, delete_orphans=True, verbose_name=_('Image'), ) - is_customer = models.BooleanField(default=False, verbose_name=_('is customer'), help_text=_('Do you sell items to this company?')) + is_customer = models.BooleanField( + default=False, + verbose_name=_('is customer'), + help_text=_('Do you sell items to this company?'), + ) - is_supplier = models.BooleanField(default=True, verbose_name=_('is supplier'), help_text=_('Do you purchase items from this company?')) + is_supplier = models.BooleanField( + default=True, + verbose_name=_('is supplier'), + help_text=_('Do you purchase items from this company?'), + ) - is_manufacturer = models.BooleanField(default=False, verbose_name=_('is manufacturer'), help_text=_('Does this company manufacture parts?')) + is_manufacturer = models.BooleanField( + default=False, + verbose_name=_('is manufacturer'), + help_text=_('Does this company manufacture parts?'), + ) currency = models.CharField( max_length=3, @@ -186,7 +215,7 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model): def __str__(self): """Get string representation of a Company.""" - return f"{self.name} - {self.description}" + return f'{self.name} - {self.description}' def get_absolute_url(self): """Get the web URL for the detail view for this Company.""" @@ -207,13 +236,18 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model): @property def parts(self): """Return SupplierPart objects which are supplied or manufactured by this company.""" - return SupplierPart.objects.filter(Q(supplier=self.id) | Q(manufacturer_part__manufacturer=self.id)).distinct() + return SupplierPart.objects.filter( + Q(supplier=self.id) | Q(manufacturer_part__manufacturer=self.id) + ).distinct() @property def stock_items(self): """Return a list of all stock items supplied or manufactured by this company.""" stock = apps.get_model('stock', 'StockItem') - return stock.objects.filter(Q(supplier_part__supplier=self.id) | Q(supplier_part__manufacturer_part__manufacturer=self.id)).distinct() + return stock.objects.filter( + Q(supplier_part__supplier=self.id) + | Q(supplier_part__manufacturer_part__manufacturer=self.id) + ).distinct() class CompanyAttachment(InvenTreeAttachment): @@ -229,7 +263,8 @@ class CompanyAttachment(InvenTreeAttachment): return os.path.join('company_files', str(self.company.pk)) company = models.ForeignKey( - Company, on_delete=models.CASCADE, + Company, + on_delete=models.CASCADE, verbose_name=_('Company'), related_name='attachments', ) @@ -251,8 +286,9 @@ class Contact(MetadataMixin, models.Model): """Return the API URL associated with the Contcat model""" return reverse('api-contact-list') - company = models.ForeignKey(Company, related_name='contacts', - on_delete=models.CASCADE) + company = models.ForeignKey( + Company, related_name='contacts', on_delete=models.CASCADE + ) name = models.CharField(max_length=100) @@ -281,7 +317,8 @@ class Address(models.Model): class Meta: """Metaclass defines extra model options""" - verbose_name_plural = "Addresses" + + verbose_name_plural = 'Addresses' def __init__(self, *args, **kwargs): """Custom init function""" @@ -289,27 +326,30 @@ class Address(models.Model): def __str__(self): """Defines string representation of address to supple a one-line to API calls""" - available_lines = [self.line1, - self.line2, - self.postal_code, - self.postal_city, - self.province, - self.country - ] + available_lines = [ + self.line1, + self.line2, + self.postal_code, + self.postal_city, + self.province, + self.country, + ] populated_lines = [] for line in available_lines: if len(line) > 0: populated_lines.append(line) - return ", ".join(populated_lines) + return ', '.join(populated_lines) def save(self, *args, **kwargs): """Run checks when saving an address: - If this address is marked as "primary", ensure that all other addresses for this company are marked as non-primary """ - others = list(Address.objects.filter(company=self.company).exclude(pk=self.pk).all()) + others = list( + Address.objects.filter(company=self.company).exclude(pk=self.pk).all() + ) # If this is the *only* address for this company, make it the primary one if len(others) == 0: @@ -329,63 +369,88 @@ class Address(models.Model): """Return the API URL associated with the Contcat model""" return reverse('api-address-list') - company = models.ForeignKey(Company, related_name='addresses', - on_delete=models.CASCADE, - verbose_name=_('Company'), - help_text=_('Select company')) + company = models.ForeignKey( + Company, + related_name='addresses', + on_delete=models.CASCADE, + verbose_name=_('Company'), + help_text=_('Select company'), + ) - title = models.CharField(max_length=100, - verbose_name=_('Address title'), - help_text=_('Title describing the address entry'), - blank=False) + title = models.CharField( + max_length=100, + verbose_name=_('Address title'), + help_text=_('Title describing the address entry'), + blank=False, + ) - primary = models.BooleanField(default=False, - verbose_name=_('Primary address'), - help_text=_('Set as primary address')) + primary = models.BooleanField( + default=False, + verbose_name=_('Primary address'), + help_text=_('Set as primary address'), + ) - line1 = models.CharField(max_length=50, - verbose_name=_('Line 1'), - help_text=_('Address line 1'), - blank=True) + line1 = models.CharField( + max_length=50, + verbose_name=_('Line 1'), + help_text=_('Address line 1'), + blank=True, + ) - line2 = models.CharField(max_length=50, - verbose_name=_('Line 2'), - help_text=_('Address line 2'), - blank=True) + line2 = models.CharField( + max_length=50, + verbose_name=_('Line 2'), + help_text=_('Address line 2'), + blank=True, + ) - postal_code = models.CharField(max_length=10, - verbose_name=_('Postal code'), - help_text=_('Postal code'), - blank=True) + postal_code = models.CharField( + max_length=10, + verbose_name=_('Postal code'), + help_text=_('Postal code'), + blank=True, + ) - postal_city = models.CharField(max_length=50, - verbose_name=_('City/Region'), - help_text=_('Postal code city/region'), - blank=True) + postal_city = models.CharField( + max_length=50, + verbose_name=_('City/Region'), + help_text=_('Postal code city/region'), + blank=True, + ) - province = models.CharField(max_length=50, - verbose_name=_('State/Province'), - help_text=_('State or province'), - blank=True) + province = models.CharField( + max_length=50, + verbose_name=_('State/Province'), + help_text=_('State or province'), + blank=True, + ) - country = models.CharField(max_length=50, - verbose_name=_('Country'), - help_text=_('Address country'), - blank=True) + country = models.CharField( + max_length=50, + verbose_name=_('Country'), + help_text=_('Address country'), + blank=True, + ) - shipping_notes = models.CharField(max_length=100, - verbose_name=_('Courier shipping notes'), - help_text=_('Notes for shipping courier'), - blank=True) + shipping_notes = models.CharField( + max_length=100, + verbose_name=_('Courier shipping notes'), + help_text=_('Notes for shipping courier'), + blank=True, + ) - internal_shipping_notes = models.CharField(max_length=100, - verbose_name=_('Internal shipping notes'), - help_text=_('Shipping notes for internal use'), - blank=True) + internal_shipping_notes = models.CharField( + max_length=100, + verbose_name=_('Internal shipping notes'), + help_text=_('Shipping notes for internal use'), + blank=True, + ) - link = InvenTreeURLField(blank=True, - verbose_name=_('Link'), - help_text=_('Link to address information (external)')) + link = InvenTreeURLField( + blank=True, + verbose_name=_('Link'), + help_text=_('Link to address information (external)'), + ) class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model): @@ -401,6 +466,7 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model): class Meta: """Metaclass defines extra model options""" + unique_together = ('part', 'manufacturer', 'MPN') @staticmethod @@ -408,23 +474,21 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model): """Return the API URL associated with the ManufacturerPart instance""" return reverse('api-manufacturer-part-list') - part = models.ForeignKey('part.Part', on_delete=models.CASCADE, - related_name='manufacturer_parts', - verbose_name=_('Base Part'), - limit_choices_to={ - 'purchaseable': True, - }, - help_text=_('Select part'), - ) + part = models.ForeignKey( + 'part.Part', + on_delete=models.CASCADE, + related_name='manufacturer_parts', + verbose_name=_('Base Part'), + limit_choices_to={'purchaseable': True}, + help_text=_('Select part'), + ) manufacturer = models.ForeignKey( Company, on_delete=models.CASCADE, null=True, related_name='manufactured_parts', - limit_choices_to={ - 'is_manufacturer': True - }, + limit_choices_to={'is_manufacturer': True}, verbose_name=_('Manufacturer'), help_text=_('Select manufacturer'), ) @@ -433,19 +497,22 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model): null=True, max_length=100, verbose_name=_('MPN'), - help_text=_('Manufacturer Part Number') + help_text=_('Manufacturer Part Number'), ) link = InvenTreeURLField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Link'), - help_text=_('URL for external manufacturer part link') + help_text=_('URL for external manufacturer part link'), ) description = models.CharField( - max_length=250, blank=True, null=True, + max_length=250, + blank=True, + null=True, verbose_name=_('Description'), - help_text=_('Manufacturer part description') + help_text=_('Manufacturer part description'), ) tags = TaggableManager(blank=True) @@ -456,12 +523,20 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model): manufacturer_part = None try: - manufacturer_part = ManufacturerPart.objects.get(part=part, manufacturer=manufacturer, MPN=mpn) + manufacturer_part = ManufacturerPart.objects.get( + part=part, manufacturer=manufacturer, MPN=mpn + ) except ManufacturerPart.DoesNotExist: pass if not manufacturer_part: - manufacturer_part = ManufacturerPart(part=part, manufacturer=manufacturer, MPN=mpn, description=description, link=link) + manufacturer_part = ManufacturerPart( + part=part, + manufacturer=manufacturer, + MPN=mpn, + description=description, + link=link, + ) manufacturer_part.save() return manufacturer_part @@ -489,10 +564,14 @@ class ManufacturerPartAttachment(InvenTreeAttachment): def getSubdir(self): """Return the subdirectory where attachment files for the ManufacturerPart model are located""" - return os.path.join("manufacturer_part_files", str(self.manufacturer_part.id)) + return os.path.join('manufacturer_part_files', str(self.manufacturer_part.id)) - manufacturer_part = models.ForeignKey(ManufacturerPart, on_delete=models.CASCADE, - verbose_name=_('Manufacturer Part'), related_name='attachments') + manufacturer_part = models.ForeignKey( + ManufacturerPart, + on_delete=models.CASCADE, + verbose_name=_('Manufacturer Part'), + related_name='attachments', + ) class ManufacturerPartParameter(models.Model): @@ -505,6 +584,7 @@ class ManufacturerPartParameter(models.Model): class Meta: """Metaclass defines extra model options""" + unique_together = ('manufacturer_part', 'name') @staticmethod @@ -523,21 +603,22 @@ class ManufacturerPartParameter(models.Model): max_length=500, blank=False, verbose_name=_('Name'), - help_text=_('Parameter name') + help_text=_('Parameter name'), ) value = models.CharField( max_length=500, blank=False, verbose_name=_('Value'), - help_text=_('Parameter value') + help_text=_('Parameter value'), ) units = models.CharField( max_length=64, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Units'), - help_text=_('Parameter units') + help_text=_('Parameter units'), ) @@ -551,10 +632,10 @@ class SupplierPartManager(models.Manager): def get_queryset(self): """Prefetch related fields when querying against the SupplierPart model""" # Always prefetch related models - return super().get_queryset().prefetch_related( - 'part', - 'supplier', - 'manufacturer_part__manufacturer', + return ( + super() + .get_queryset() + .prefetch_related('part', 'supplier', 'manufacturer_part__manufacturer') ) @@ -580,6 +661,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin class Meta: """Metaclass defines extra model options""" + unique_together = ('part', 'supplier', 'SKU') # This model was moved from the 'Part' app @@ -600,11 +682,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin def api_instance_filters(self): """Return custom API filters for this particular instance""" - return { - 'manufacturer_part': { - 'part': self.part.pk - } - } + return {'manufacturer_part': {'part': self.part.pk}} def clean(self): """Custom clean action for the SupplierPart model: @@ -624,36 +702,38 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin try: # Attempt conversion to specified unit native_value = InvenTree.conversion.convert_physical_value( - self.pack_quantity, self.part.units, - strip_units=False + self.pack_quantity, self.part.units, strip_units=False ) # If part units are not provided, value must be dimensionless - if not self.part.units and not InvenTree.conversion.is_dimensionless(native_value): + if not self.part.units and not InvenTree.conversion.is_dimensionless( + native_value + ): raise ValidationError({ - 'pack_quantity': _("Pack units must be compatible with the base part units") + 'pack_quantity': _( + 'Pack units must be compatible with the base part units' + ) }) # Native value must be greater than zero if float(native_value.magnitude) <= 0: raise ValidationError({ - 'pack_quantity': _("Pack units must be greater than zero") + 'pack_quantity': _('Pack units must be greater than zero') }) # Update native pack units value self.pack_quantity_native = Decimal(native_value.magnitude) except ValidationError as e: - raise ValidationError({ - 'pack_quantity': e.messages - }) + raise ValidationError({'pack_quantity': e.messages}) # Ensure that the linked manufacturer_part points to the same part! if self.manufacturer_part and self.part: - if self.manufacturer_part.part != self.part: raise ValidationError({ - 'manufacturer_part': _("Linked manufacturer part must reference the same base part"), + 'manufacturer_part': _( + 'Linked manufacturer part must reference the same base part' + ) }) def save(self, *args, **kwargs): @@ -666,7 +746,9 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin # Retrieve manufacturer part try: - manufacturer_part = ManufacturerPart.objects.get(manufacturer__name=manufacturer_name, MPN=MPN) + manufacturer_part = ManufacturerPart.objects.get( + manufacturer__name=manufacturer_name, MPN=MPN + ) except (ValueError, Company.DoesNotExist): # ManufacturerPart does not exist pass @@ -676,7 +758,9 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin # Connect ManufacturerPart to SupplierPart self.manufacturer_part = manufacturer_part else: - raise ValidationError(f'SupplierPart {self.__str__} is already linked to {self.manufacturer_part}') + raise ValidationError( + f'SupplierPart {self.__str__} is already linked to {self.manufacturer_part}' + ) self.clean() self.validate_unique() @@ -684,17 +768,17 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin super().save(*args, **kwargs) part = models.ForeignKey( - 'part.Part', on_delete=models.CASCADE, + 'part.Part', + on_delete=models.CASCADE, related_name='supplier_parts', verbose_name=_('Base Part'), - limit_choices_to={ - 'purchaseable': True, - }, + limit_choices_to={'purchaseable': True}, help_text=_('Select part'), ) supplier = models.ForeignKey( - Company, on_delete=models.CASCADE, + Company, + on_delete=models.CASCADE, related_name='supplied_parts', limit_choices_to={'is_supplier': True}, verbose_name=_('Supplier'), @@ -703,50 +787,71 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin SKU = models.CharField( max_length=100, - verbose_name=__("SKU = Stock Keeping Unit (supplier part number)", 'SKU'), - help_text=_('Supplier stock keeping unit') + verbose_name=__('SKU = Stock Keeping Unit (supplier part number)', 'SKU'), + help_text=_('Supplier stock keeping unit'), ) manufacturer_part = models.ForeignKey( - ManufacturerPart, on_delete=models.CASCADE, - blank=True, null=True, + ManufacturerPart, + on_delete=models.CASCADE, + blank=True, + null=True, related_name='supplier_parts', verbose_name=_('Manufacturer Part'), help_text=_('Select manufacturer part'), ) link = InvenTreeURLField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Link'), - help_text=_('URL for external supplier part link') + help_text=_('URL for external supplier part link'), ) description = models.CharField( - max_length=250, blank=True, null=True, + max_length=250, + blank=True, + null=True, verbose_name=_('Description'), - help_text=_('Supplier part description') + help_text=_('Supplier part description'), ) note = models.CharField( - max_length=100, blank=True, null=True, + max_length=100, + blank=True, + null=True, verbose_name=_('Note'), - help_text=_('Notes') + help_text=_('Notes'), ) - base_cost = models.DecimalField(max_digits=10, decimal_places=3, default=0, validators=[MinValueValidator(0)], verbose_name=_('base cost'), help_text=_('Minimum charge (e.g. stocking fee)')) + base_cost = models.DecimalField( + max_digits=10, + decimal_places=3, + default=0, + validators=[MinValueValidator(0)], + verbose_name=_('base cost'), + help_text=_('Minimum charge (e.g. stocking fee)'), + ) - packaging = models.CharField(max_length=50, blank=True, null=True, verbose_name=_('Packaging'), help_text=_('Part packaging')) + packaging = models.CharField( + max_length=50, + blank=True, + null=True, + verbose_name=_('Packaging'), + help_text=_('Part packaging'), + ) pack_quantity = models.CharField( max_length=25, verbose_name=_('Pack Quantity'), - help_text=_('Total quantity supplied in a single pack. Leave empty for single items.'), + help_text=_( + 'Total quantity supplied in a single pack. Leave empty for single items.' + ), blank=True, ) pack_quantity_native = RoundingDecimalField( - max_digits=20, decimal_places=10, default=1, - null=True, + max_digits=20, decimal_places=10, default=1, null=True ) def base_quantity(self, quantity=1) -> Decimal: @@ -756,20 +861,29 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin return q - multiple = models.PositiveIntegerField(default=1, validators=[MinValueValidator(1)], verbose_name=_('multiple'), help_text=_('Order multiple')) + multiple = models.PositiveIntegerField( + default=1, + validators=[MinValueValidator(1)], + verbose_name=_('multiple'), + help_text=_('Order multiple'), + ) # TODO - Reimplement lead-time as a charfield with special validation (pattern matching). # lead_time = models.DurationField(blank=True, null=True) available = models.DecimalField( - max_digits=10, decimal_places=3, default=0, + max_digits=10, + decimal_places=3, + default=0, validators=[MinValueValidator(0)], verbose_name=_('Available'), help_text=_('Quantity available from supplier'), ) availability_updated = models.DateTimeField( - null=True, blank=True, verbose_name=_('Availability Updated'), + null=True, + blank=True, + verbose_name=_('Availability Updated'), help_text=_('Date of last update of availability data'), ) @@ -821,17 +935,15 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin if self.price_breaks.filter(quantity=quantity, part=self.pk).exists(): return - SupplierPriceBreak.objects.create( - part=self, - quantity=quantity, - price=price - ) + SupplierPriceBreak.objects.create(part=self, quantity=quantity, price=price) get_price = common.models.get_price def open_orders(self): """Return a database query for PurchaseOrder line items for this SupplierPart, limited to purchase orders that are open / outstanding.""" - return self.purchase_order_line_items.prefetch_related('order').filter(order__status__in=PurchaseOrderStatusGroups.OPEN) + return self.purchase_order_line_items.prefetch_related('order').filter( + order__status__in=PurchaseOrderStatusGroups.OPEN + ) def on_order(self): """Return the total quantity of items currently on order. @@ -852,7 +964,10 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin def purchase_orders(self): """Returns a list of purchase orders relating to this supplier part.""" - return [line.order for line in self.purchase_order_line_items.all().prefetch_related('order')] + return [ + line.order + for line in self.purchase_order_line_items.all().prefetch_related('order') + ] @property def pretty_name(self): @@ -891,7 +1006,8 @@ class SupplierPriceBreak(common.models.PriceBreak): class Meta: """Metaclass defines extra model options""" - unique_together = ("part", "quantity") + + unique_together = ('part', 'quantity') # This model was moved from the 'Part' app db_table = 'part_supplierpricebreak' @@ -905,22 +1021,31 @@ class SupplierPriceBreak(common.models.PriceBreak): """Return the API URL associated with the SupplierPriceBreak model""" return reverse('api-part-supplier-price-list') - part = models.ForeignKey(SupplierPart, on_delete=models.CASCADE, related_name='pricebreaks', verbose_name=_('Part'),) + part = models.ForeignKey( + SupplierPart, + on_delete=models.CASCADE, + related_name='pricebreaks', + verbose_name=_('Part'), + ) -@receiver(post_save, sender=SupplierPriceBreak, dispatch_uid='post_save_supplier_price_break') +@receiver( + post_save, sender=SupplierPriceBreak, dispatch_uid='post_save_supplier_price_break' +) def after_save_supplier_price(sender, instance, created, **kwargs): """Callback function when a SupplierPriceBreak is created or updated""" if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): - if instance.part and instance.part.part: instance.part.part.schedule_pricing_update(create=True) -@receiver(post_delete, sender=SupplierPriceBreak, dispatch_uid='post_delete_supplier_price_break') +@receiver( + post_delete, + sender=SupplierPriceBreak, + dispatch_uid='post_delete_supplier_price_break', +) def after_delete_supplier_price(sender, instance, **kwargs): """Callback function when a SupplierPriceBreak is deleted""" if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): - if instance.part and instance.part.part: instance.part.part.schedule_pricing_update(create=False) diff --git a/InvenTree/company/serializers.py b/InvenTree/company/serializers.py index 6d5911749b..192187a3b5 100644 --- a/InvenTree/company/serializers.py +++ b/InvenTree/company/serializers.py @@ -10,20 +10,29 @@ from sql_util.utils import SubqueryCount from taggit.serializers import TagListSerializerField import part.filters -from InvenTree.serializers import (InvenTreeAttachmentSerializer, - InvenTreeCurrencySerializer, - InvenTreeDecimalField, - InvenTreeImageSerializerField, - InvenTreeModelSerializer, - InvenTreeMoneySerializer, - InvenTreeTagModelSerializer, - RemoteImageMixin) +from InvenTree.serializers import ( + InvenTreeAttachmentSerializer, + InvenTreeCurrencySerializer, + InvenTreeDecimalField, + InvenTreeImageSerializerField, + InvenTreeModelSerializer, + InvenTreeMoneySerializer, + InvenTreeTagModelSerializer, + RemoteImageMixin, +) from part.serializers import PartBriefSerializer -from .models import (Address, Company, CompanyAttachment, Contact, - ManufacturerPart, ManufacturerPartAttachment, - ManufacturerPartParameter, SupplierPart, - SupplierPriceBreak) +from .models import ( + Address, + Company, + CompanyAttachment, + Contact, + ManufacturerPart, + ManufacturerPartAttachment, + ManufacturerPartParameter, + SupplierPart, + SupplierPriceBreak, +) class CompanyBriefSerializer(InvenTreeModelSerializer): @@ -33,13 +42,7 @@ class CompanyBriefSerializer(InvenTreeModelSerializer): """Metaclass options.""" model = Company - fields = [ - 'pk', - 'url', - 'name', - 'description', - 'image', - ] + fields = ['pk', 'url', 'name', 'description', 'image'] url = serializers.CharField(source='get_absolute_url', read_only=True) @@ -86,7 +89,7 @@ class AddressBriefSerializer(InvenTreeModelSerializer): 'province', 'country', 'shipping_notes', - 'internal_shipping_notes' + 'internal_shipping_notes', ] @@ -119,7 +122,7 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer): 'parts_manufactured', 'remote_image', 'address_count', - 'primary_address' + 'primary_address', ] @staticmethod @@ -130,13 +133,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer): parts_manufactured=SubqueryCount('manufactured_parts') ) - queryset = queryset.annotate( - parts_supplied=SubqueryCount('supplied_parts') - ) + queryset = queryset.annotate(parts_supplied=SubqueryCount('supplied_parts')) - queryset = queryset.annotate( - address_count=SubqueryCount('addresses') - ) + queryset = queryset.annotate(address_count=SubqueryCount('addresses')) return queryset @@ -150,7 +149,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer): parts_manufactured = serializers.IntegerField(read_only=True) address_count = serializers.IntegerField(read_only=True) - currency = InvenTreeCurrencySerializer(help_text=_('Default currency used for this supplier'), required=True) + currency = InvenTreeCurrencySerializer( + help_text=_('Default currency used for this supplier'), required=True + ) def save(self): """Save the Company instance""" @@ -167,12 +168,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer): remote_img.save(buffer, format=fmt) # Construct a simplified name for the image - filename = f"company_{company.pk}_image.{fmt.lower()}" + filename = f'company_{company.pk}_image.{fmt.lower()}' - company.image.save( - filename, - ContentFile(buffer.getvalue()), - ) + company.image.save(filename, ContentFile(buffer.getvalue())) return self.instance @@ -182,11 +180,10 @@ class CompanyAttachmentSerializer(InvenTreeAttachmentSerializer): class Meta: """Metaclass defines serializer options""" + model = CompanyAttachment - fields = InvenTreeAttachmentSerializer.attachment_fields([ - 'company', - ]) + fields = InvenTreeAttachmentSerializer.attachment_fields(['company']) class ContactSerializer(InvenTreeModelSerializer): @@ -196,14 +193,7 @@ class ContactSerializer(InvenTreeModelSerializer): """Metaclass options""" model = Contact - fields = [ - 'pk', - 'company', - 'name', - 'phone', - 'email', - 'role', - ] + fields = ['pk', 'company', 'name', 'phone', 'email', 'role'] class ManufacturerPartSerializer(InvenTreeTagModelSerializer): @@ -224,7 +214,6 @@ class ManufacturerPartSerializer(InvenTreeTagModelSerializer): 'MPN', 'link', 'barcode_hash', - 'tags', ] @@ -249,11 +238,15 @@ class ManufacturerPartSerializer(InvenTreeTagModelSerializer): part_detail = PartBriefSerializer(source='part', many=False, read_only=True) - manufacturer_detail = CompanyBriefSerializer(source='manufacturer', many=False, read_only=True) + manufacturer_detail = CompanyBriefSerializer( + source='manufacturer', many=False, read_only=True + ) pretty_name = serializers.CharField(read_only=True) - manufacturer = serializers.PrimaryKeyRelatedField(queryset=Company.objects.filter(is_manufacturer=True)) + manufacturer = serializers.PrimaryKeyRelatedField( + queryset=Company.objects.filter(is_manufacturer=True) + ) class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer): @@ -264,9 +257,7 @@ class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer): model = ManufacturerPartAttachment - fields = InvenTreeAttachmentSerializer.attachment_fields([ - 'manufacturer_part', - ]) + fields = InvenTreeAttachmentSerializer.attachment_fields(['manufacturer_part']) class ManufacturerPartParameterSerializer(InvenTreeModelSerializer): @@ -295,7 +286,9 @@ class ManufacturerPartParameterSerializer(InvenTreeModelSerializer): if not man_detail: self.fields.pop('manufacturer_part_detail') - manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', many=False, read_only=True) + manufacturer_part_detail = ManufacturerPartSerializer( + source='manufacturer_part', many=False, read_only=True + ) class SupplierPartSerializer(InvenTreeTagModelSerializer): @@ -330,14 +323,10 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer): 'supplier_detail', 'url', 'updated', - 'tags', ] - read_only_fields = [ - 'availability_updated', - 'barcode_hash', - ] + read_only_fields = ['availability_updated', 'barcode_hash'] tags = TagListSerializerField(required=False) @@ -377,19 +366,27 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer): part_detail = PartBriefSerializer(source='part', many=False, read_only=True) - supplier_detail = CompanyBriefSerializer(source='supplier', many=False, read_only=True) + supplier_detail = CompanyBriefSerializer( + source='supplier', many=False, read_only=True + ) - manufacturer_detail = CompanyBriefSerializer(source='manufacturer_part.manufacturer', many=False, read_only=True) + manufacturer_detail = CompanyBriefSerializer( + source='manufacturer_part.manufacturer', many=False, read_only=True + ) pretty_name = serializers.CharField(read_only=True) - supplier = serializers.PrimaryKeyRelatedField(queryset=Company.objects.filter(is_supplier=True)) + supplier = serializers.PrimaryKeyRelatedField( + queryset=Company.objects.filter(is_supplier=True) + ) manufacturer = serializers.CharField(read_only=True) MPN = serializers.CharField(read_only=True) - manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', part_detail=False, read_only=True) + manufacturer_part_detail = ManufacturerPartSerializer( + source='manufacturer_part', part_detail=False, read_only=True + ) url = serializers.CharField(source='get_absolute_url', read_only=True) @@ -403,9 +400,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer): Fields: in_stock: Current stock quantity for each SupplierPart """ - queryset = queryset.annotate( - in_stock=part.filters.annotate_total_stock() - ) + queryset = queryset.annotate(in_stock=part.filters.annotate_total_stock()) return queryset @@ -436,10 +431,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer): MPN = self.initial_data.get('MPN', None) if manufacturer and MPN: - kwargs = { - 'manufacturer': manufacturer, - 'MPN': MPN, - } + kwargs = {'manufacturer': manufacturer, 'MPN': MPN} supplier_part.save(**kwargs) return supplier_part @@ -479,17 +471,19 @@ class SupplierPriceBreakSerializer(InvenTreeModelSerializer): quantity = InvenTreeDecimalField() - price = InvenTreeMoneySerializer( - allow_null=True, - required=True, - label=_('Price'), - ) + price = InvenTreeMoneySerializer(allow_null=True, required=True, label=_('Price')) price_currency = InvenTreeCurrencySerializer() - supplier = serializers.PrimaryKeyRelatedField(source='part.supplier', many=False, read_only=True) + supplier = serializers.PrimaryKeyRelatedField( + source='part.supplier', many=False, read_only=True + ) - supplier_detail = CompanyBriefSerializer(source='part.supplier', many=False, read_only=True) + supplier_detail = CompanyBriefSerializer( + source='part.supplier', many=False, read_only=True + ) # Detail serializer for SupplierPart - part_detail = SupplierPartSerializer(source='part', brief=True, many=False, read_only=True) + part_detail = SupplierPartSerializer( + source='part', brief=True, many=False, read_only=True + ) diff --git a/InvenTree/company/test_api.py b/InvenTree/company/test_api.py index 33720e30c2..4bcae45bf6 100644 --- a/InvenTree/company/test_api.py +++ b/InvenTree/company/test_api.py @@ -12,10 +12,7 @@ from .models import Address, Company, Contact, ManufacturerPart, SupplierPart class CompanyTest(InvenTreeAPITestCase): """Series of tests for the Company DRF API.""" - roles = [ - 'purchase_order.add', - 'purchase_order.change', - ] + roles = ['purchase_order.add', 'purchase_order.change'] @classmethod def setUpTestData(cls): @@ -23,9 +20,18 @@ class CompanyTest(InvenTreeAPITestCase): super().setUpTestData() # Create some company objects to work with - cls.acme = Company.objects.create(name='ACME', description='Supplier', is_customer=False, is_supplier=True) - Company.objects.create(name='Drippy Cup Co.', description='Customer', is_customer=True, is_supplier=False) - Company.objects.create(name='Sippy Cup Emporium', description='Another supplier') + cls.acme = Company.objects.create( + name='ACME', description='Supplier', is_customer=False, is_supplier=True + ) + Company.objects.create( + name='Drippy Cup Co.', + description='Customer', + is_customer=True, + is_supplier=False, + ) + Company.objects.create( + name='Sippy Cup Emporium', description='Another supplier' + ) def test_company_list(self): """Test the list API endpoint for the Company model""" @@ -82,22 +88,13 @@ class CompanyTest(InvenTreeAPITestCase): url = reverse('api-company-list') # Name is required - response = self.post( - url, - { - 'description': 'A description!', - }, - expected_code=400 - ) + response = self.post(url, {'description': 'A description!'}, expected_code=400) # Minimal example, checking default values response = self.post( url, - { - 'name': 'My API Company', - 'description': 'A company created via the API', - }, - expected_code=201 + {'name': 'My API Company', 'description': 'A company created via the API'}, + expected_code=201, ) self.assertTrue(response.data['is_supplier']) @@ -110,14 +107,14 @@ class CompanyTest(InvenTreeAPITestCase): response = self.post( url, { - 'name': "Another Company", - 'description': "Also created via the API!", + 'name': 'Another Company', + 'description': 'Also created via the API!', 'currency': 'AUD', 'is_supplier': False, 'is_manufacturer': True, 'is_customer': True, }, - expected_code=201 + expected_code=201, ) self.assertEqual(response.data['currency'], 'AUD') @@ -128,12 +125,8 @@ class CompanyTest(InvenTreeAPITestCase): # Attempt to create with invalid currency response = self.post( url, - { - 'name': "A name", - 'description': 'A description', - 'currency': 'POQD', - }, - expected_code=400 + {'name': 'A name', 'description': 'A description', 'currency': 'POQD'}, + expected_code=400, ) self.assertTrue('currency' in response.data) @@ -151,10 +144,8 @@ class ContactTest(InvenTreeAPITestCase): # Create some companies companies = [ - Company( - name=f"Company {idx}", - description="Some company" - ) for idx in range(3) + Company(name=f'Company {idx}', description='Some company') + for idx in range(3) ] Company.objects.bulk_create(companies) @@ -164,10 +155,7 @@ class ContactTest(InvenTreeAPITestCase): # Create some contacts for cmp in Company.objects.all(): contacts += [ - Contact( - company=cmp, - name=f"My name {idx}", - ) for idx in range(3) + Contact(company=cmp, name=f'My name {idx}') for idx in range(3) ] Contact.objects.bulk_create(contacts) @@ -187,13 +175,7 @@ class ContactTest(InvenTreeAPITestCase): # Filter by particular company for cmp in Company.objects.all(): - response = self.get( - self.url, - { - 'company': cmp.pk, - }, - expected_code=200 - ) + response = self.get(self.url, {'company': cmp.pk}, expected_code=200) self.assertEqual(len(response.data), 3) @@ -205,23 +187,13 @@ class ContactTest(InvenTreeAPITestCase): # Without required permissions, creation should fail self.post( - self.url, - { - 'company': company.pk, - 'name': 'Joe Bloggs', - }, - expected_code=403 + self.url, {'company': company.pk, 'name': 'Joe Bloggs'}, expected_code=403 ) self.assignRole('return_order.add') self.post( - self.url, - { - 'company': company.pk, - 'name': 'Joe Bloggs', - }, - expected_code=201 + self.url, {'company': company.pk, 'name': 'Joe Bloggs'}, expected_code=201 ) self.assertEqual(Contact.objects.count(), n + 1) @@ -239,23 +211,11 @@ class ContactTest(InvenTreeAPITestCase): for key in ['pk', 'name', 'role']: self.assertIn(key, data) - self.patch( - url, - { - 'role': 'model', - }, - expected_code=403 - ) + self.patch(url, {'role': 'model'}, expected_code=403) self.assignRole('purchase_order.change') - self.patch( - url, - { - 'role': 'x', - }, - expected_code=200 - ) + self.patch(url, {'role': 'x'}, expected_code=200) # Get the contact again contact = Contact.objects.first() @@ -291,10 +251,8 @@ class AddressTest(InvenTreeAPITestCase): cls.num_addr = 3 # Create some companies companies = [ - Company( - name=f"Company {idx}", - description="Some company" - ) for idx in range(cls.num_companies) + Company(name=f'Company {idx}', description='Some company') + for idx in range(cls.num_companies) ] Company.objects.bulk_create(companies) @@ -304,10 +262,8 @@ class AddressTest(InvenTreeAPITestCase): # Create some contacts for cmp in Company.objects.all(): addresses += [ - Address( - company=cmp, - title=f"Address no. {idx}", - ) for idx in range(cls.num_addr) + Address(company=cmp, title=f'Address no. {idx}') + for idx in range(cls.num_addr) ] cls.url = reverse('api-address-list') @@ -332,21 +288,11 @@ class AddressTest(InvenTreeAPITestCase): """Test creating a new address""" company = Company.objects.first() - self.post(self.url, - { - 'company': company.pk, - 'title': 'HQ' - }, - expected_code=403) + self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=403) self.assignRole('purchase_order.add') - self.post(self.url, - { - 'company': company.pk, - 'title': 'HQ' - }, - expected_code=201) + self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=201) def test_get(self): """Test that objects are properly returned from a get""" @@ -357,7 +303,15 @@ class AddressTest(InvenTreeAPITestCase): self.assertEqual(response.data['pk'], addr.pk) - for key in ['title', 'line1', 'line2', 'postal_code', 'postal_city', 'province', 'country']: + for key in [ + 'title', + 'line1', + 'line2', + 'postal_code', + 'postal_city', + 'province', + 'country', + ]: self.assertIn(key, response.data) def test_edit(self): @@ -366,23 +320,11 @@ class AddressTest(InvenTreeAPITestCase): url = reverse('api-address-detail', kwargs={'pk': addr.pk}) - self.patch( - url, - { - 'title': 'Hello' - }, - expected_code=403 - ) + self.patch(url, {'title': 'Hello'}, expected_code=403) self.assignRole('purchase_order.change') - self.patch( - url, - { - 'title': 'World' - }, - expected_code=200 - ) + self.patch(url, {'title': 'World'}, expected_code=200) data = self.get(url, expected_code=200).data @@ -415,10 +357,7 @@ class ManufacturerTest(InvenTreeAPITestCase): 'supplier_part', ] - roles = [ - 'part.add', - 'part.change', - ] + roles = ['part.add', 'part.change'] def test_manufacturer_part_list(self): """Test the ManufacturerPart API list functionality""" @@ -429,11 +368,7 @@ class ManufacturerTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 3) # Create manufacturer part - data = { - 'part': 1, - 'manufacturer': 7, - 'MPN': 'MPN_TEST', - } + data = {'part': 1, 'manufacturer': 7, 'MPN': 'MPN_TEST'} response = self.client.post(url, data, format='json') self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data['MPN'], 'MPN_TEST') @@ -456,9 +391,7 @@ class ManufacturerTest(InvenTreeAPITestCase): self.assertEqual(response.data['MPN'], 'MPN123') # Change the MPN - data = { - 'MPN': 'MPN-TEST-123', - } + data = {'MPN': 'MPN-TEST-123'} response = self.client.patch(url, data, format='json') @@ -485,7 +418,7 @@ class ManufacturerTest(InvenTreeAPITestCase): 'MPN': 'PART_NUMBER', 'link': 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E', }, - expected_code=201 + expected_code=201, ) pk = response.data['pk'] @@ -503,10 +436,16 @@ class ManufacturerTest(InvenTreeAPITestCase): self.assertEqual(response.status_code, status.HTTP_201_CREATED) # Check link is not modified - self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E') + self.assertEqual( + response.data['link'], + 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E', + ) # Check link is not modified - self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E') + self.assertEqual( + response.data['link'], + 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E', + ) class SupplierPartTest(InvenTreeAPITestCase): @@ -521,12 +460,7 @@ class SupplierPartTest(InvenTreeAPITestCase): 'supplier_part', ] - roles = [ - 'part.add', - 'part.change', - 'part.add', - 'purchase_order.change', - ] + roles = ['part.add', 'part.change', 'part.add', 'purchase_order.change'] def test_supplier_part_list(self): """Test the SupplierPart API list functionality""" @@ -543,10 +477,7 @@ class SupplierPartTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), supplier.supplied_parts.count()) # Filter by Part reference - expected = { - 1: 4, - 25: 2, - } + expected = {1: 4, 25: 2} for pk, n in expected.items(): response = self.get(url, {'part': pk}, expected_code=200) @@ -559,12 +490,7 @@ class SupplierPartTest(InvenTreeAPITestCase): # Should fail when sending an invalid 'available' field response = self.post( url, - { - 'part': 1, - 'supplier': 2, - 'SKU': 'QQ', - 'available': 'not a number', - }, + {'part': 1, 'supplier': 2, 'SKU': 'QQ', 'available': 'not a number'}, expected_code=400, ) @@ -572,13 +498,7 @@ class SupplierPartTest(InvenTreeAPITestCase): # Create a SupplierPart without specifying available quantity response = self.post( - url, - { - 'part': 1, - 'supplier': 2, - 'SKU': 'QQ', - }, - expected_code=201 + url, {'part': 1, 'supplier': 2, 'SKU': 'QQ'}, expected_code=201 ) sp = SupplierPart.objects.get(pk=response.data['pk']) @@ -589,9 +509,7 @@ class SupplierPartTest(InvenTreeAPITestCase): # Now, *update* the available quantity via the API self.patch( reverse('api-supplier-part-detail', kwargs={'pk': sp.pk}), - { - 'available': 1234, - }, + {'available': 1234}, expected_code=200, ) @@ -602,12 +520,7 @@ class SupplierPartTest(InvenTreeAPITestCase): # We should also be able to create a SupplierPart with initial 'available' quantity response = self.post( url, - { - 'part': 1, - 'supplier': 2, - 'SKU': 'QQQ', - 'available': 999, - }, + {'part': 1, 'supplier': 2, 'SKU': 'QQQ', 'available': 999}, expected_code=201, ) @@ -629,11 +542,7 @@ class CompanyMetadataAPITest(InvenTreeAPITestCase): 'supplier_part', ] - roles = [ - 'company.change', - 'purchase_order.change', - 'part.change', - ] + roles = ['company.change', 'purchase_order.change', 'part.change'] def metatester(self, apikey, model): """Generic tester""" @@ -651,17 +560,15 @@ class CompanyMetadataAPITest(InvenTreeAPITestCase): self.patch( url, - { - 'metadata': { - f'abc-{numstr}': f'xyz-{apikey}-{numstr}', - } - }, - expected_code=200 + {'metadata': {f'abc-{numstr}': f'xyz-{apikey}-{numstr}'}}, + expected_code=200, ) # Refresh modeldata.refresh_from_db() - self.assertEqual(modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}') + self.assertEqual( + modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}' + ) def test_metadata(self): """Test all endpoints""" diff --git a/InvenTree/company/test_migrations.py b/InvenTree/company/test_migrations.py index d8285ee5a1..b4b291644a 100644 --- a/InvenTree/company/test_migrations.py +++ b/InvenTree/company/test_migrations.py @@ -16,9 +16,7 @@ class TestForwardMigrations(MigratorTestCase): Company = self.old_state.apps.get_model('company', 'company') Company.objects.create( - name='MSPC', - description='Michael Scotts Paper Company', - is_supplier=True + name='MSPC', description='Michael Scotts Paper Company', is_supplier=True ) def test_migrations(self): @@ -49,7 +47,10 @@ class TestManufacturerField(MigratorTestCase): part = Part.objects.create( name='Screw', description='A single screw', - level=0, tree_id=0, lft=0, rght=0 + level=0, + tree_id=0, + lft=0, + rght=0, ) # Create a company to act as the supplier @@ -62,17 +63,11 @@ class TestManufacturerField(MigratorTestCase): # Add some SupplierPart objects SupplierPart.objects.create( - part=part, - supplier=supplier, - SKU='SCREW.001', - manufacturer_name='ACME', + part=part, supplier=supplier, SKU='SCREW.001', manufacturer_name='ACME' ) SupplierPart.objects.create( - part=part, - supplier=supplier, - SKU='SCREW.002', - manufacturer_name='Zero Corp', + part=part, supplier=supplier, SKU='SCREW.002', manufacturer_name='Zero Corp' ) self.assertEqual(Company.objects.count(), 1) @@ -186,17 +181,12 @@ class TestManufacturerPart(MigratorTestCase): # No Manufacturer SupplierPart.objects.create( - part=part, - supplier=supplier_2, - SKU='CAP-CER-01UF-2', - MPN='MUR-CAP-123456', + part=part, supplier=supplier_2, SKU='CAP-CER-01UF-2', MPN='MUR-CAP-123456' ) # No Manufacturer data SupplierPart.objects.create( - part=part, - supplier=supplier_2, - SKU='CAP-CER-01UF-3', + part=part, supplier=supplier_2, SKU='CAP-CER-01UF-3' ) def test_manufacturer_part_objects(self): @@ -238,17 +228,20 @@ class TestCurrencyMigration(MigratorTestCase): Part = self.old_state.apps.get_model('part', 'part') part = Part.objects.create( - name="PART", description="A purchaseable part", + name='PART', + description='A purchaseable part', purchaseable=True, level=0, tree_id=0, lft=0, - rght=0 + rght=0, ) Company = self.old_state.apps.get_model('company', 'company') - supplier = Company.objects.create(name='Supplier', description='A supplier', is_supplier=True) + supplier = Company.objects.create( + name='Supplier', description='A supplier', is_supplier=True + ) SupplierPart = self.old_state.apps.get_model('company', 'supplierpart') @@ -256,8 +249,12 @@ class TestCurrencyMigration(MigratorTestCase): Currency = self.old_state.apps.get_model('common', 'currency') - aud = Currency.objects.create(symbol='$', suffix='AUD', description='Australian Dollars', value=1.0) - usd = Currency.objects.create(symbol='$', suffix='USD', description='US Dollars', value=1.0) + aud = Currency.objects.create( + symbol='$', suffix='AUD', description='Australian Dollars', value=1.0 + ) + usd = Currency.objects.create( + symbol='$', suffix='USD', description='US Dollars', value=1.0 + ) PB = self.old_state.apps.get_model('company', 'supplierpricebreak') @@ -312,7 +309,7 @@ class TestAddressMigration(MigratorTestCase): a2 = Address.objects.filter(company=c2.pk).first() self.assertEqual(a1.line1, self.short_l1) - self.assertEqual(a1.line2, "") + self.assertEqual(a1.line2, '') self.assertEqual(a2.line1, self.long_l1) self.assertEqual(a2.line2, self.l2) self.assertEqual(c1.address, '') @@ -332,12 +329,18 @@ class TestSupplierPartQuantity(MigratorTestCase): SupplierPart = self.old_state.apps.get_model('company', 'supplierpart') self.part = Part.objects.create( - name="PART", description="A purchaseable part", + name='PART', + description='A purchaseable part', purchaseable=True, - level=0, tree_id=0, lft=0, rght=0 + level=0, + tree_id=0, + lft=0, + rght=0, ) - self.supplier = Company.objects.create(name='Supplier', description='A supplier', is_supplier=True) + self.supplier = Company.objects.create( + name='Supplier', description='A supplier', is_supplier=True + ) self.supplier_parts = [] @@ -356,7 +359,6 @@ class TestSupplierPartQuantity(MigratorTestCase): SupplierPart = self.new_state.apps.get_model('company', 'supplierpart') for i, sp in enumerate(SupplierPart.objects.all()): - self.assertEqual(sp.pack_quantity, str(i + 1)) self.assertEqual(sp.pack_quantity_native, i + 1) diff --git a/InvenTree/company/test_supplier_parts.py b/InvenTree/company/test_supplier_parts.py index 21c34d9bbf..e00c78d68a 100644 --- a/InvenTree/company/test_supplier_parts.py +++ b/InvenTree/company/test_supplier_parts.py @@ -15,17 +15,15 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase): def test_pack_quantity_dimensionless(self): """Test valid values for the 'pack_quantity' field""" # Create a part without units (dimensionless) - part = Part.objects.create(name='Test Part', description='Test part description', component=True) + part = Part.objects.create( + name='Test Part', description='Test part description', component=True + ) # Create a supplier (company) company = Company.objects.create(name='Test Company', is_supplier=True) # Create a supplier part for this part - sp = SupplierPart.objects.create( - part=part, - supplier=company, - SKU='TEST-SKU' - ) + sp = SupplierPart.objects.create(part=part, supplier=company, SKU='TEST-SKU') # All these values are valid for a dimensionless part pass_tests = { @@ -37,14 +35,7 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase): } # All these values are invalid for a dimensionless part - fail_tests = [ - '1.2m', - '-1', - '0', - '0.0', - '100 feet', - '0 amps' - ] + fail_tests = ['1.2m', '-1', '0', '0.0', '100 feet', '0 amps'] for test, expected in pass_tests.items(): sp.pack_quantity = test @@ -59,17 +50,18 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase): def test_pack_quantity(self): """Test pack_quantity for a part with a specified dimension""" # Create a part with units 'm' - part = Part.objects.create(name='Test Part', description='Test part description', component=True, units='m') + part = Part.objects.create( + name='Test Part', + description='Test part description', + component=True, + units='m', + ) # Create a supplier (company) company = Company.objects.create(name='Test Company', is_supplier=True) # Create a supplier part for this part - sp = SupplierPart.objects.create( - part=part, - supplier=company, - SKU='TEST-SKU' - ) + sp = SupplierPart.objects.create(part=part, supplier=company, SKU='TEST-SKU') # All these values are valid for a part with dimension 'm' pass_tests = { @@ -87,23 +79,14 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase): # All these values are invalid for a part with dimension 'm' # Either the values are invalid, or the units are incompatible - fail_tests = [ - '-1', - '-1m', - '0', - '0m', - '12 deg', - '57 amps', - '-12 oz', - '17 yaks', - ] + fail_tests = ['-1', '-1m', '0', '0m', '12 deg', '57 amps', '-12 oz', '17 yaks'] for test, expected in pass_tests.items(): sp.pack_quantity = test sp.full_clean() self.assertEqual( round(Decimal(sp.pack_quantity_native), 10), - round(Decimal(str(expected)), 10) + round(Decimal(str(expected)), 10), ) for test in fail_tests: diff --git a/InvenTree/company/test_views.py b/InvenTree/company/test_views.py index 274162ccea..307cb8d6c8 100644 --- a/InvenTree/company/test_views.py +++ b/InvenTree/company/test_views.py @@ -36,7 +36,9 @@ class CompanyViewTest(InvenTreeTestCase): def test_manufacturer_part_detail_view(self): """Test the manufacturer part detail view.""" - response = self.client.get(reverse('manufacturer-part-detail', kwargs={'pk': 1})) + response = self.client.get( + reverse('manufacturer-part-detail', kwargs={'pk': 1}) + ) self.assertEqual(response.status_code, 200) self.assertContains(response, 'MPN123') diff --git a/InvenTree/company/tests.py b/InvenTree/company/tests.py index 1f6e7ab56d..9eb2190a98 100644 --- a/InvenTree/company/tests.py +++ b/InvenTree/company/tests.py @@ -8,8 +8,14 @@ from django.test import TestCase from part.models import Part -from .models import (Address, Company, Contact, ManufacturerPart, SupplierPart, - rename_company_image) +from .models import ( + Address, + Company, + Contact, + ManufacturerPart, + SupplierPart, + rename_company_image, +) class CompanySimpleTest(TestCase): @@ -31,11 +37,13 @@ class CompanySimpleTest(TestCase): """Perform initialization for the tests in this class""" super().setUpTestData() - Company.objects.create(name='ABC Co.', - description='Seller of ABC products', - website='www.abc-sales.com', - is_customer=False, - is_supplier=True) + Company.objects.create( + name='ABC Co.', + description='Seller of ABC products', + website='www.abc-sales.com', + is_customer=False, + is_supplier=True, + ) cls.acme0001 = SupplierPart.objects.get(SKU='ACME0001') cls.acme0002 = SupplierPart.objects.get(SKU='ACME0002') @@ -95,9 +103,9 @@ class CompanySimpleTest(TestCase): """Unit tests for supplier part pricing""" m2x4 = Part.objects.get(name='M2x4 LPHS') - self.assertEqual(m2x4.get_price_info(5.5), "38.5 - 41.25") - self.assertEqual(m2x4.get_price_info(10), "70 - 75") - self.assertEqual(m2x4.get_price_info(100), "125 - 350") + self.assertEqual(m2x4.get_price_info(5.5), '38.5 - 41.25') + self.assertEqual(m2x4.get_price_info(10), '70 - 75') + self.assertEqual(m2x4.get_price_info(100), '125 - 350') pmin, pmax = m2x4.get_price_range(5) self.assertEqual(pmin, 35) @@ -113,18 +121,14 @@ class CompanySimpleTest(TestCase): """Test validation for currency selection.""" # Create a company with a valid currency code (should pass) company = Company.objects.create( - name='Test', - description='Toast', - currency='AUD', + name='Test', description='Toast', currency='AUD' ) company.full_clean() # Create a company with an invalid currency code (should fail) company = Company.objects.create( - name='test', - description='Toasty', - currency='XZY', + name='test', description='Toasty', currency='XZY' ) with self.assertRaises(ValidationError): @@ -154,7 +158,9 @@ class ContactSimpleTest(TestCase): def setUp(self): """Initialization for the tests in this class""" # Create a simple company - self.c = Company.objects.create(name='Test Corp.', description='We make stuff good') + self.c = Company.objects.create( + name='Test Corp.', description='We make stuff good' + ) # Add some contacts Contact.objects.create(name='Joe Smith', company=self.c) @@ -178,7 +184,9 @@ class AddressTest(TestCase): def setUp(self): """Initialization for the tests in this class""" # Create a simple company - self.c = Company.objects.create(name='Test Corp.', description='We make stuff good') + self.c = Company.objects.create( + name='Test Corp.', description='We make stuff good' + ) def test_create(self): """Test that object creation with only company supplied is successful""" @@ -214,27 +222,28 @@ class AddressTest(TestCase): def test_model_str(self): """Test value of __str__""" - t = "Test address" - l1 = "Busy street 56" - l2 = "Red building" - pcd = "12345" - pct = "City" - pv = "Province" - cn = "COUNTRY" - addr = Address.objects.create(company=self.c, - title=t, - line1=l1, - line2=l2, - postal_code=pcd, - postal_city=pct, - province=pv, - country=cn) + t = 'Test address' + l1 = 'Busy street 56' + l2 = 'Red building' + pcd = '12345' + pct = 'City' + pv = 'Province' + cn = 'COUNTRY' + addr = Address.objects.create( + company=self.c, + title=t, + line1=l1, + line2=l2, + postal_code=pcd, + postal_city=pct, + province=pv, + country=cn, + ) self.assertEqual(str(addr), f'{l1}, {l2}, {pcd}, {pct}, {pv}, {cn}') - addr2 = Address.objects.create(company=self.c, - title=t, - line1=l1, - postal_code=pcd) + addr2 = Address.objects.create( + company=self.c, title=t, line1=l1, postal_code=pcd + ) self.assertEqual(str(addr2), f'{l1}, {pcd}') @@ -242,13 +251,7 @@ class AddressTest(TestCase): class ManufacturerPartSimpleTest(TestCase): """Unit tests for the ManufacturerPart model""" - fixtures = [ - 'category', - 'company', - 'location', - 'part', - 'manufacturer_part', - ] + fixtures = ['category', 'company', 'location', 'part', 'manufacturer_part'] def setUp(self): """Initialization for the unit tests in this class""" @@ -266,9 +269,7 @@ class ManufacturerPartSimpleTest(TestCase): # Create a supplier part supplier = Company.objects.get(pk=5) supplier_part = SupplierPart.objects.create( - part=self.part, - supplier=supplier, - SKU='SKU_TEST', + part=self.part, supplier=supplier, SKU='SKU_TEST' ) supplier_part.save() diff --git a/InvenTree/company/urls.py b/InvenTree/company/urls.py index 71985964fb..bddc745a72 100644 --- a/InvenTree/company/urls.py +++ b/InvenTree/company/urls.py @@ -5,29 +5,41 @@ from django.urls import include, path, re_path from . import views company_urls = [ - # Detail URLs for a specific Company instance - path(r'/', include([ - re_path(r'^.*$', views.CompanyDetail.as_view(), name='company-detail'), - ])), - + path( + r'/', + include([ + re_path(r'^.*$', views.CompanyDetail.as_view(), name='company-detail') + ]), + ), re_path(r'suppliers/', views.CompanyIndex.as_view(), name='supplier-index'), re_path(r'manufacturers/', views.CompanyIndex.as_view(), name='manufacturer-index'), re_path(r'customers/', views.CompanyIndex.as_view(), name='customer-index'), - # Redirect any other patterns to the 'company' index which displays all companies re_path(r'^.*$', views.CompanyIndex.as_view(), name='company-index'), ] manufacturer_part_urls = [ - - path(r'/', views.ManufacturerPartDetail.as_view(template_name='company/manufacturer_part.html'), name='manufacturer-part-detail'), + path( + r'/', + views.ManufacturerPartDetail.as_view( + template_name='company/manufacturer_part.html' + ), + name='manufacturer-part-detail', + ) ] supplier_part_urls = [ - path(r'/', include([ - re_path('^.*$', views.SupplierPartDetail.as_view(template_name='company/supplier_part.html'), name='supplier-part-detail'), - ])) - - + path( + r'/', + include([ + re_path( + '^.*$', + views.SupplierPartDetail.as_view( + template_name='company/supplier_part.html' + ), + name='supplier-part-detail', + ) + ]), + ) ] diff --git a/InvenTree/company/views.py b/InvenTree/company/views.py index c8b2b31dc0..a8c85ef7f8 100644 --- a/InvenTree/company/views.py +++ b/InvenTree/company/views.py @@ -44,14 +44,14 @@ class CompanyIndex(InvenTreeRoleMixin, ListView): 'button_text': _('New Customer'), 'filters': {'is_customer': 'true'}, 'pagetype': 'customers', - } + }, } default = { 'title': _('Companies'), 'button_text': _('New Company'), 'filters': {}, - 'pagetype': 'companies' + 'pagetype': 'companies', } context = None @@ -88,6 +88,7 @@ class CompanyIndex(InvenTreeRoleMixin, ListView): class CompanyDetail(InvenTreePluginViewMixin, DetailView): """Detail view for Company object.""" + context_obect_name = 'company' template_name = 'company/detail.html' queryset = Company.objects.all() @@ -97,6 +98,7 @@ class CompanyDetail(InvenTreePluginViewMixin, DetailView): class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView): """Detail view for ManufacturerPart.""" + model = ManufacturerPart template_name = 'company/manufacturer_part_detail.html' context_object_name = 'part' @@ -106,6 +108,7 @@ class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView): class SupplierPartDetail(InvenTreePluginViewMixin, DetailView): """Detail view for SupplierPart.""" + model = SupplierPart template_name = 'company/supplier_part_detail.html' context_object_name = 'part' diff --git a/InvenTree/generic/states/__init__.py b/InvenTree/generic/states/__init__.py index 772cec0ec0..8d6f1e870e 100644 --- a/InvenTree/generic/states/__init__.py +++ b/InvenTree/generic/states/__init__.py @@ -9,9 +9,4 @@ States can be extended with custom options for each InvenTree instance - those o from .states import StatusCode from .transition import StateTransitionMixin, TransitionMethod, storage -__all__ = [ - StatusCode, - storage, - TransitionMethod, - StateTransitionMixin, -] +__all__ = [StatusCode, storage, TransitionMethod, StateTransitionMixin] diff --git a/InvenTree/generic/states/api.py b/InvenTree/generic/states/api.py index e5c07d5128..1d7d3db59e 100644 --- a/InvenTree/generic/states/api.py +++ b/InvenTree/generic/states/api.py @@ -18,9 +18,7 @@ class StatusView(APIView): all available 'StockStatus' codes """ - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] # Override status_class for implementing subclass MODEL_REF = 'statusmodel' @@ -30,7 +28,9 @@ class StatusView(APIView): status_model = self.kwargs.get(self.MODEL_REF, None) if status_model is None: - raise ValidationError(f"StatusView view called without '{self.MODEL_REF}' parameter") + raise ValidationError( + f"StatusView view called without '{self.MODEL_REF}' parameter" + ) return status_model @@ -39,15 +39,12 @@ class StatusView(APIView): status_class = self.get_status_model() if not inspect.isclass(status_class): - raise NotImplementedError("`status_class` not a class") + raise NotImplementedError('`status_class` not a class') if not issubclass(status_class, StatusCode): - raise NotImplementedError("`status_class` not a valid StatusCode class") + raise NotImplementedError('`status_class` not a valid StatusCode class') - data = { - 'class': status_class.__name__, - 'values': status_class.dict(), - } + data = {'class': status_class.__name__, 'values': status_class.dict()} return Response(data) @@ -55,9 +52,7 @@ class StatusView(APIView): class AllStatusViews(StatusView): """Endpoint for listing all defined status models.""" - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] def get(self, request, *args, **kwargs): """Perform a GET request to learn information about status codes""" diff --git a/InvenTree/generic/states/states.py b/InvenTree/generic/states/states.py index 7d23ca359d..6638b3cc05 100644 --- a/InvenTree/generic/states/states.py +++ b/InvenTree/generic/states/states.py @@ -1,4 +1,5 @@ """Generic implementation of status for InvenTree models.""" + import enum import re @@ -148,12 +149,10 @@ class StatusCode(BaseEnum): @classmethod def dict(cls, key=None): """Return a dict representation containing all required information""" - return {x.name: { - 'color': x.color, - 'key': x.value, - 'label': x.label, - 'name': x.name, - } for x in cls.values(key)} + return { + x.name: {'color': x.color, 'key': x.value, 'label': x.label, 'name': x.name} + for x in cls.values(key) + } @classmethod def list(cls): diff --git a/InvenTree/generic/states/test_transition.py b/InvenTree/generic/states/test_transition.py index 7376b69ae9..a18314519d 100644 --- a/InvenTree/generic/states/test_transition.py +++ b/InvenTree/generic/states/test_transition.py @@ -35,8 +35,7 @@ class TransitionTests(InvenTreeTestCase): def test_class(self): """Ensure that the class itself works.""" - class ErrorImplementation(TransitionMethod): - ... + class ErrorImplementation(TransitionMethod): ... with self.assertRaises(NotImplementedError): ErrorImplementation() @@ -90,7 +89,6 @@ class TransitionTests(InvenTreeTestCase): class ValidImplementation(TransitionMethod): def transition(self, *args, **kwargs): - global raise_function if raise_function: @@ -103,7 +101,9 @@ class TransitionTests(InvenTreeTestCase): self.assertIn(ValidImplementation, storage.list) # Ensure that the function is called - self.assertEqual(StateTransitionMixin.handle_transition(0, 1, self, self, dflt), 1234) + self.assertEqual( + StateTransitionMixin.handle_transition(0, 1, self, self, dflt), 1234 + ) _clean_storage([ValidImplementationNoEffect, ValidImplementation]) diff --git a/InvenTree/generic/states/tests.py b/InvenTree/generic/states/tests.py index 81bc84732d..0d5f926f40 100644 --- a/InvenTree/generic/states/tests.py +++ b/InvenTree/generic/states/tests.py @@ -1,4 +1,5 @@ """Tests for the generic states module.""" + from django.test.client import RequestFactory from django.utils.translation import gettext_lazy as _ @@ -13,9 +14,9 @@ from .states import StatusCode class GeneralStatus(StatusCode): """Defines a set of status codes for tests.""" - PENDING = 10, _("Pending"), 'secondary' - PLACED = 20, _("Placed"), 'primary' - COMPLETE = 30, _("Complete"), 'success' + PENDING = 10, _('Pending'), 'secondary' + PLACED = 20, _('Placed'), 'primary' + COMPLETE = 30, _('Complete'), 'success' ABC = None # This should be ignored _DEF = None # This should be ignored jkl = None # This should be ignored @@ -27,6 +28,7 @@ class GeneralStatus(StatusCode): class GeneralStateTest(InvenTreeTestCase): """Test that the StatusCode class works.""" + def test_code_definition(self): """Test that the status code class has been defined correctly.""" self.assertEqual(GeneralStatus.PENDING, 10) @@ -36,36 +38,84 @@ class GeneralStateTest(InvenTreeTestCase): def test_code_functions(self): """Test that the status code class functions work correctly""" # render - self.assertEqual(GeneralStatus.render(10), "Pending") - self.assertEqual(GeneralStatus.render(20), "Placed") + self.assertEqual( + GeneralStatus.render(10), + "Pending", + ) + self.assertEqual( + GeneralStatus.render(20), + "Placed", + ) # render with invalid key self.assertEqual(GeneralStatus.render(100), 100) # list - self.assertEqual(GeneralStatus.list(), [{'color': 'secondary', 'key': 10, 'label': 'Pending', 'name': 'PENDING'}, {'color': 'primary', 'key': 20, 'label': 'Placed', 'name': 'PLACED'}, {'color': 'success', 'key': 30, 'label': 'Complete', 'name': 'COMPLETE'}]) + self.assertEqual( + GeneralStatus.list(), + [ + { + 'color': 'secondary', + 'key': 10, + 'label': 'Pending', + 'name': 'PENDING', + }, + {'color': 'primary', 'key': 20, 'label': 'Placed', 'name': 'PLACED'}, + { + 'color': 'success', + 'key': 30, + 'label': 'Complete', + 'name': 'COMPLETE', + }, + ], + ) # text self.assertEqual(GeneralStatus.text(10), 'Pending') self.assertEqual(GeneralStatus.text(20), 'Placed') # items - self.assertEqual(list(GeneralStatus.items()), [(10, 'Pending'), (20, 'Placed'), (30, 'Complete')]) + self.assertEqual( + list(GeneralStatus.items()), + [(10, 'Pending'), (20, 'Placed'), (30, 'Complete')], + ) # keys self.assertEqual(list(GeneralStatus.keys()), ([10, 20, 30])) # labels - self.assertEqual(list(GeneralStatus.labels()), ['Pending', 'Placed', 'Complete']) + self.assertEqual( + list(GeneralStatus.labels()), ['Pending', 'Placed', 'Complete'] + ) # names - self.assertEqual(GeneralStatus.names(), {'PENDING': 10, 'PLACED': 20, 'COMPLETE': 30}) + self.assertEqual( + GeneralStatus.names(), {'PENDING': 10, 'PLACED': 20, 'COMPLETE': 30} + ) # dict - self.assertEqual(GeneralStatus.dict(), { - 'PENDING': {'key': 10, 'name': 'PENDING', 'label': 'Pending', 'color': 'secondary'}, - 'PLACED': {'key': 20, 'name': 'PLACED', 'label': 'Placed', 'color': 'primary'}, - 'COMPLETE': {'key': 30, 'name': 'COMPLETE', 'label': 'Complete', 'color': 'success'}, - }) + self.assertEqual( + GeneralStatus.dict(), + { + 'PENDING': { + 'key': 10, + 'name': 'PENDING', + 'label': 'Pending', + 'color': 'secondary', + }, + 'PLACED': { + 'key': 20, + 'name': 'PLACED', + 'label': 'Placed', + 'color': 'primary', + }, + 'COMPLETE': { + 'key': 30, + 'name': 'COMPLETE', + 'label': 'Complete', + 'color': 'success', + }, + }, + ) # label self.assertEqual(GeneralStatus.label(10), 'Pending') @@ -74,7 +124,10 @@ class GeneralStateTest(InvenTreeTestCase): """Test that the status code tag functions.""" from .tags import status_label - self.assertEqual(status_label('general', 10), "Pending") + self.assertEqual( + status_label('general', 10), + "Pending", + ) # invalid type with self.assertRaises(ValueError) as e: @@ -87,24 +140,54 @@ class GeneralStateTest(InvenTreeTestCase): def test_api(self): """Test StatusView API view.""" view = StatusView.as_view() - rqst = RequestFactory().get('status/',) + rqst = RequestFactory().get('status/') force_authenticate(rqst, user=self.user) # Correct call resp = view(rqst, **{StatusView.MODEL_REF: GeneralStatus}) - self.assertEqual(resp.data, {'class': 'GeneralStatus', 'values': {'COMPLETE': {'key': 30, 'name': 'COMPLETE', 'label': 'Complete', 'color': 'success'}, 'PENDING': {'key': 10, 'name': 'PENDING', 'label': 'Pending', 'color': 'secondary'}, 'PLACED': {'key': 20, 'name': 'PLACED', 'label': 'Placed', 'color': 'primary'}}}) + self.assertEqual( + resp.data, + { + 'class': 'GeneralStatus', + 'values': { + 'COMPLETE': { + 'key': 30, + 'name': 'COMPLETE', + 'label': 'Complete', + 'color': 'success', + }, + 'PENDING': { + 'key': 10, + 'name': 'PENDING', + 'label': 'Pending', + 'color': 'secondary', + }, + 'PLACED': { + 'key': 20, + 'name': 'PLACED', + 'label': 'Placed', + 'color': 'primary', + }, + }, + }, + ) # No status defined resp = view(rqst, **{StatusView.MODEL_REF: None}) self.assertEqual(resp.status_code, 400) - self.assertEqual(str(resp.rendered_content, 'utf-8'), '["StatusView view called without \'statusmodel\' parameter"]') + self.assertEqual( + str(resp.rendered_content, 'utf-8'), + '["StatusView view called without \'statusmodel\' parameter"]', + ) # Invalid call - not a class with self.assertRaises(NotImplementedError) as e: resp = view(rqst, **{StatusView.MODEL_REF: 'invalid'}) - self.assertEqual(str(e.exception), "`status_class` not a class") + self.assertEqual(str(e.exception), '`status_class` not a class') # Invalid call - not the right class with self.assertRaises(NotImplementedError) as e: resp = view(rqst, **{StatusView.MODEL_REF: object}) - self.assertEqual(str(e.exception), "`status_class` not a valid StatusCode class") + self.assertEqual( + str(e.exception), '`status_class` not a valid StatusCode class' + ) diff --git a/InvenTree/generic/states/transition.py b/InvenTree/generic/states/transition.py index 882ace3ebb..a764db5a72 100644 --- a/InvenTree/generic/states/transition.py +++ b/InvenTree/generic/states/transition.py @@ -1,4 +1,5 @@ """Classes and functions for plugin controlled object state transitions.""" + import InvenTree.helpers @@ -15,8 +16,10 @@ class TransitionMethod: - The needed functions are implemented """ # Check if a sending fnc is defined - if (not hasattr(self, 'transition')): - raise NotImplementedError('A TransitionMethod must define a `transition` method') + if not hasattr(self, 'transition'): + raise NotImplementedError( + 'A TransitionMethod must define a `transition` method' + ) class TransitionMethodStorageClass: @@ -62,7 +65,9 @@ class StateTransitionMixin: ``` """ - def handle_transition(self, current_state, target_state, instance, default_action, **kwargs): + def handle_transition( + self, current_state, target_state, instance, default_action, **kwargs + ): """Handle a state transition for an object. Args: @@ -74,7 +79,9 @@ class StateTransitionMixin: # Check if there is a custom override function for this transition for override in storage.list: - rslt = override.transition(current_state, target_state, instance, default_action, **kwargs) + rslt = override.transition( + current_state, target_state, instance, default_action, **kwargs + ) if rslt: return rslt diff --git a/InvenTree/generic/templatetags/generic.py b/InvenTree/generic/templatetags/generic.py index af33e9d36b..e57bdec201 100644 --- a/InvenTree/generic/templatetags/generic.py +++ b/InvenTree/generic/templatetags/generic.py @@ -5,6 +5,4 @@ from django import template register = template.Library() from generic.states.tags import status_label # noqa: E402 -__all__ = [ - status_label, -] +__all__ = [status_label] diff --git a/InvenTree/gunicorn.conf.py b/InvenTree/gunicorn.conf.py index c1b403de12..5c26c15631 100644 --- a/InvenTree/gunicorn.conf.py +++ b/InvenTree/gunicorn.conf.py @@ -2,7 +2,7 @@ import multiprocessing -bind = "0.0.0.0:8000" +bind = '0.0.0.0:8000' workers = multiprocessing.cpu_count() * 2 + 1 diff --git a/InvenTree/label/admin.py b/InvenTree/label/admin.py index ec4cb3f7dd..4acad6054f 100644 --- a/InvenTree/label/admin.py +++ b/InvenTree/label/admin.py @@ -7,6 +7,7 @@ import label.models class LabelAdmin(admin.ModelAdmin): """Admin class for the various label models""" + list_display = ('name', 'description', 'label', 'filters', 'enabled') diff --git a/InvenTree/label/api.py b/InvenTree/label/api.py index 9421616d73..92a9034433 100644 --- a/InvenTree/label/api.py +++ b/InvenTree/label/api.py @@ -18,8 +18,7 @@ import label.models import label.serializers from InvenTree.api import MetadataView from InvenTree.filters import InvenTreeSearchFilter -from InvenTree.mixins import (ListCreateAPI, RetrieveAPI, - RetrieveUpdateDestroyAPI) +from InvenTree.mixins import ListCreateAPI, RetrieveAPI, RetrieveUpdateDestroyAPI from part.models import Part from plugin.builtin.labels.inventree_label import InvenTreeLabelPlugin from plugin.registry import registry @@ -59,7 +58,7 @@ class LabelFilterMixin: for id in ids: try: valid_ids.append(int(id)) - except (ValueError): + except ValueError: pass # Filter queryset by matching ID values @@ -120,34 +119,23 @@ class LabelListView(LabelFilterMixin, ListCreateAPI): return queryset - filter_backends = [ - DjangoFilterBackend, - InvenTreeSearchFilter - ] + filter_backends = [DjangoFilterBackend, InvenTreeSearchFilter] - filterset_fields = [ - 'enabled', - ] + filterset_fields = ['enabled'] - search_fields = [ - 'name', - 'description', - ] + search_fields = ['name', 'description'] @method_decorator(cache_page(5), name='dispatch') class LabelPrintMixin(LabelFilterMixin): """Mixin for printing labels.""" - rolemap = { - "GET": "view", - "POST": "view", - } + rolemap = {'GET': 'view', 'POST': 'view'} def check_permissions(self, request): """Override request method to GET so that also non superusers can print using a post request.""" - if request.method == "POST": - request = clone_request(request, "GET") + if request.method == 'POST': + request = clone_request(request, 'GET') return super().check_permissions(request) @method_decorator(never_cache) @@ -161,7 +149,9 @@ class LabelPrintMixin(LabelFilterMixin): plugin = self.get_plugin(self.request) kwargs.setdefault('context', self.get_serializer_context()) - serializer = plugin.get_printing_options_serializer(self.request, *args, **kwargs) + serializer = plugin.get_printing_options_serializer( + self.request, *args, **kwargs + ) # if no serializer is defined, return an empty serializer if not serializer: @@ -171,8 +161,12 @@ class LabelPrintMixin(LabelFilterMixin): def get(self, request, *args, **kwargs): """Perform a GET request against this endpoint to print labels""" - common.models.InvenTreeUserSetting.set_setting('DEFAULT_' + self.ITEM_KEY.upper() + '_LABEL_TEMPLATE', - self.get_object().pk, None, user=request.user) + common.models.InvenTreeUserSetting.set_setting( + 'DEFAULT_' + self.ITEM_KEY.upper() + '_LABEL_TEMPLATE', + self.get_object().pk, + None, + user=request.user, + ) return self.print(request, self.get_items()) def post(self, request, *args, **kwargs): @@ -205,8 +199,10 @@ class LabelPrintMixin(LabelFilterMixin): if not plugin.is_active(): raise ValidationError(f"Plugin '{plugin_key}' is not enabled") - if not plugin.mixin_enabled("labels"): - raise ValidationError(f"Plugin '{plugin_key}' is not a label printing plugin") + if not plugin.mixin_enabled('labels'): + raise ValidationError( + f"Plugin '{plugin_key}' is not a label printing plugin" + ) # Only return the plugin if it is enabled and has the label printing mixin return plugin @@ -228,18 +224,24 @@ class LabelPrintMixin(LabelFilterMixin): raise ValidationError('Label has invalid dimensions') # if the plugin returns a serializer, validate the data - if serializer := plugin.get_printing_options_serializer(request, data=request.data, context=self.get_serializer_context()): + if serializer := plugin.get_printing_options_serializer( + request, data=request.data, context=self.get_serializer_context() + ): serializer.is_valid(raise_exception=True) # At this point, we offload the label(s) to the selected plugin. # The plugin is responsible for handling the request and returning a response. - result = plugin.print_labels(label, items_to_print, request, printing_options=request.data) + result = plugin.print_labels( + label, items_to_print, request, printing_options=request.data + ) if isinstance(result, JsonResponse): result['plugin'] = plugin.plugin_slug() return result - raise ValidationError(f"Plugin '{plugin.plugin_slug()}' returned invalid response type '{type(result)}'") + raise ValidationError( + f"Plugin '{plugin.plugin_slug()}' returned invalid response type '{type(result)}'" + ) class StockItemLabelMixin: @@ -261,16 +263,19 @@ class StockItemLabelList(StockItemLabelMixin, LabelListView): - item: Filter by single stock item - items: Filter by list of stock items """ + pass class StockItemLabelDetail(StockItemLabelMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single StockItemLabel object.""" + pass class StockItemLabelPrint(StockItemLabelMixin, LabelPrintMixin, RetrieveAPI): """API endpoint for printing a StockItemLabel object.""" + pass @@ -293,21 +298,25 @@ class StockLocationLabelList(StockLocationLabelMixin, LabelListView): - location: Filter by a single stock location - locations: Filter by list of stock locations """ + pass class StockLocationLabelDetail(StockLocationLabelMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single StockLocationLabel object.""" + pass class StockLocationLabelPrint(StockLocationLabelMixin, LabelPrintMixin, RetrieveAPI): """API endpoint for printing a StockLocationLabel object.""" + pass class PartLabelMixin: """Mixin for PartLabel endpoints""" + queryset = label.models.PartLabel.objects.all() serializer_class = label.serializers.PartLabelSerializer @@ -317,16 +326,19 @@ class PartLabelMixin: class PartLabelList(PartLabelMixin, LabelListView): """API endpoint for viewing list of PartLabel objects.""" + pass class PartLabelDetail(PartLabelMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single PartLabel object.""" + pass class PartLabelPrint(PartLabelMixin, LabelPrintMixin, RetrieveAPI): """API endpoint for printing a PartLabel object.""" + pass @@ -342,70 +354,147 @@ class BuildLineLabelMixin: class BuildLineLabelList(BuildLineLabelMixin, LabelListView): """API endpoint for viewing a list of BuildLineLabel objects""" + pass class BuildLineLabelDetail(BuildLineLabelMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single BuildLineLabel object""" + pass class BuildLineLabelPrint(BuildLineLabelMixin, LabelPrintMixin, RetrieveAPI): """API endpoint for printing a BuildLineLabel object""" + pass label_api_urls = [ - # Stock item labels - re_path(r'stock/', include([ - # Detail views - path(r'/', include([ - re_path(r'print/?', StockItemLabelPrint.as_view(), name='api-stockitem-label-print'), - re_path(r'metadata/', MetadataView.as_view(), {'model': label.models.StockItemLabel}, name='api-stockitem-label-metadata'), - re_path(r'^.*$', StockItemLabelDetail.as_view(), name='api-stockitem-label-detail'), - ])), - - # List view - re_path(r'^.*$', StockItemLabelList.as_view(), name='api-stockitem-label-list'), - ])), - + re_path( + r'stock/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'print/?', + StockItemLabelPrint.as_view(), + name='api-stockitem-label-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'model': label.models.StockItemLabel}, + name='api-stockitem-label-metadata', + ), + re_path( + r'^.*$', + StockItemLabelDetail.as_view(), + name='api-stockitem-label-detail', + ), + ]), + ), + # List view + re_path( + r'^.*$', StockItemLabelList.as_view(), name='api-stockitem-label-list' + ), + ]), + ), # Stock location labels - re_path(r'location/', include([ - # Detail views - path(r'/', include([ - re_path(r'print/?', StockLocationLabelPrint.as_view(), name='api-stocklocation-label-print'), - re_path(r'metadata/', MetadataView.as_view(), {'model': label.models.StockLocationLabel}, name='api-stocklocation-label-metadata'), - re_path(r'^.*$', StockLocationLabelDetail.as_view(), name='api-stocklocation-label-detail'), - ])), - - # List view - re_path(r'^.*$', StockLocationLabelList.as_view(), name='api-stocklocation-label-list'), - ])), - + re_path( + r'location/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'print/?', + StockLocationLabelPrint.as_view(), + name='api-stocklocation-label-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'model': label.models.StockLocationLabel}, + name='api-stocklocation-label-metadata', + ), + re_path( + r'^.*$', + StockLocationLabelDetail.as_view(), + name='api-stocklocation-label-detail', + ), + ]), + ), + # List view + re_path( + r'^.*$', + StockLocationLabelList.as_view(), + name='api-stocklocation-label-list', + ), + ]), + ), # Part labels - re_path(r'^part/', include([ - # Detail views - path(r'/', include([ - re_path(r'^print/', PartLabelPrint.as_view(), name='api-part-label-print'), - re_path(r'^metadata/', MetadataView.as_view(), {'model': label.models.PartLabel}, name='api-part-label-metadata'), - re_path(r'^.*$', PartLabelDetail.as_view(), name='api-part-label-detail'), - ])), - - # List view - re_path(r'^.*$', PartLabelList.as_view(), name='api-part-label-list'), - ])), - + re_path( + r'^part/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'^print/', + PartLabelPrint.as_view(), + name='api-part-label-print', + ), + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': label.models.PartLabel}, + name='api-part-label-metadata', + ), + re_path( + r'^.*$', PartLabelDetail.as_view(), name='api-part-label-detail' + ), + ]), + ), + # List view + re_path(r'^.*$', PartLabelList.as_view(), name='api-part-label-list'), + ]), + ), # BuildLine labels - re_path(r'^buildline/', include([ - # Detail views - path(r'/', include([ - re_path(r'^print/', BuildLineLabelPrint.as_view(), name='api-buildline-label-print'), - re_path(r'^metadata/', MetadataView.as_view(), {'model': label.models.BuildLineLabel}, name='api-buildline-label-metadata'), - re_path(r'^.*$', BuildLineLabelDetail.as_view(), name='api-buildline-label-detail'), - ])), - - # List view - re_path(r'^.*$', BuildLineLabelList.as_view(), name='api-buildline-label-list'), - ])), + re_path( + r'^buildline/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'^print/', + BuildLineLabelPrint.as_view(), + name='api-buildline-label-print', + ), + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': label.models.BuildLineLabel}, + name='api-buildline-label-metadata', + ), + re_path( + r'^.*$', + BuildLineLabelDetail.as_view(), + name='api-buildline-label-detail', + ), + ]), + ), + # List view + re_path( + r'^.*$', BuildLineLabelList.as_view(), name='api-buildline-label-list' + ), + ]), + ), ] diff --git a/InvenTree/label/apps.py b/InvenTree/label/apps.py index b5d16eada2..862182a47b 100644 --- a/InvenTree/label/apps.py +++ b/InvenTree/label/apps.py @@ -14,7 +14,7 @@ from django.db.utils import IntegrityError, OperationalError, ProgrammingError import InvenTree.ready -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') def hashFile(filename): @@ -36,23 +36,37 @@ class LabelConfig(AppConfig): def ready(self): """This function is called whenever the label app is loaded.""" # skip loading if plugin registry is not loaded or we run in a background thread - if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread(): + if ( + not InvenTree.ready.isPluginRegistryLoaded() + or not InvenTree.ready.isInMainThread() + ): return if InvenTree.ready.isRunningMigrations(): return - if InvenTree.ready.canAppAccessDatabase(allow_test=False) and not InvenTree.ready.isImportingData(): + if ( + InvenTree.ready.canAppAccessDatabase(allow_test=False) + and not InvenTree.ready.isImportingData() + ): try: self.create_labels() # pragma: no cover - except (AppRegistryNotReady, IntegrityError, OperationalError, ProgrammingError): + except ( + AppRegistryNotReady, + IntegrityError, + OperationalError, + ProgrammingError, + ): # Database might not yet be ready - warnings.warn('Database was not ready for creating labels', stacklevel=2) + warnings.warn( + 'Database was not ready for creating labels', stacklevel=2 + ) def create_labels(self): """Create all default templates.""" # Test if models are ready import label.models + assert bool(label.models.StockLocationLabel is not None) # Create the categories @@ -66,7 +80,7 @@ class LabelConfig(AppConfig): 'description': 'Simple QR code label', 'width': 24, 'height': 24, - }, + } ], ) @@ -87,8 +101,8 @@ class LabelConfig(AppConfig): 'description': 'Label with QR code and name of location', 'width': 50, 'height': 24, - } - ] + }, + ], ) self.create_labels_category( @@ -109,7 +123,7 @@ class LabelConfig(AppConfig): 'width': 70, 'height': 24, }, - ] + ], ) self.create_labels_category( @@ -122,24 +136,16 @@ class LabelConfig(AppConfig): 'description': 'Example build line label', 'width': 125, 'height': 48, - }, - ] + } + ], ) def create_labels_category(self, model, ref_name, labels): """Create folder and database entries for the default templates, if they do not already exist.""" # Create root dir for templates - src_dir = Path(__file__).parent.joinpath( - 'templates', - 'label', - ref_name, - ) + src_dir = Path(__file__).parent.joinpath('templates', 'label', ref_name) - dst_dir = settings.MEDIA_ROOT.joinpath( - 'label', - 'inventree', - ref_name, - ) + dst_dir = settings.MEDIA_ROOT.joinpath('label', 'inventree', ref_name) if not dst_dir.exists(): logger.info("Creating required directory: '%s'", dst_dir) @@ -151,12 +157,7 @@ class LabelConfig(AppConfig): def create_template_label(self, model, src_dir, ref_name, label): """Ensure a label template is in place.""" - filename = os.path.join( - 'label', - 'inventree', - ref_name, - label['file'] - ) + filename = os.path.join('label', 'inventree', ref_name, label['file']) src_file = src_dir.joinpath(label['file']) dst_file = settings.MEDIA_ROOT.joinpath(filename) @@ -187,7 +188,10 @@ class LabelConfig(AppConfig): if model.objects.filter(label=filename).exists(): return # pragma: no cover except Exception: - logger.exception("Failed to query label for '%s' - you should run 'invoke update' first!", filename) + logger.exception( + "Failed to query label for '%s' - you should run 'invoke update' first!", + filename, + ) logger.info("Creating entry for %s '%s'", model, label['name']) diff --git a/InvenTree/label/models.py b/InvenTree/label/models.py index 65a0306ce1..ccc5b03843 100644 --- a/InvenTree/label/models.py +++ b/InvenTree/label/models.py @@ -25,12 +25,12 @@ from plugin.registry import registry try: from django_weasyprint import WeasyTemplateResponseMixin except OSError as err: # pragma: no cover - print(f"OSError: {err}") - print("You may require some further system packages to be installed.") + print(f'OSError: {err}') + print('You may require some further system packages to be installed.') sys.exit(1) -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') def rename_label(instance, filename): @@ -97,7 +97,7 @@ class LabelTemplate(MetadataMixin, models.Model): abstract = True # Each class of label files will be stored in a separate subdirectory - SUBDIR = "label" + SUBDIR = 'label' # Object we will be printing against (will be filled out later) object_to_print = None @@ -109,17 +109,16 @@ class LabelTemplate(MetadataMixin, models.Model): def __str__(self): """Format a string representation of a label instance""" - return f"{self.name} - {self.description}" + return f'{self.name} - {self.description}' name = models.CharField( - blank=False, max_length=100, - verbose_name=_('Name'), - help_text=_('Label name'), + blank=False, max_length=100, verbose_name=_('Name'), help_text=_('Label name') ) description = models.CharField( max_length=250, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Description'), help_text=_('Label description'), ) @@ -127,7 +126,8 @@ class LabelTemplate(MetadataMixin, models.Model): label = models.FileField( upload_to=rename_label, unique=True, - blank=False, null=False, + blank=False, + null=False, verbose_name=_('Label'), help_text=_('Label template file'), validators=[FileExtensionValidator(allowed_extensions=['html'])], @@ -143,18 +143,18 @@ class LabelTemplate(MetadataMixin, models.Model): default=50, verbose_name=_('Width [mm]'), help_text=_('Label width, specified in mm'), - validators=[MinValueValidator(2)] + validators=[MinValueValidator(2)], ) height = models.FloatField( default=20, verbose_name=_('Height [mm]'), help_text=_('Label height, specified in mm'), - validators=[MinValueValidator(2)] + validators=[MinValueValidator(2)], ) filename_pattern = models.CharField( - default="label.pdf", + default='label.pdf', verbose_name=_('Filename Pattern'), help_text=_('Pattern for generating label filenames'), max_length=100, @@ -249,11 +249,7 @@ class LabelTemplate(MetadataMixin, models.Model): context = self.context(request, **kwargs) - return render_to_string( - self.template_name, - context, - request - ) + return render_to_string(self.template_name, context, request) def render(self, request, target_object=None, **kwargs): """Render the label template to a PDF file. @@ -269,16 +265,13 @@ class LabelTemplate(MetadataMixin, models.Model): wp = WeasyprintLabelMixin( request, self.template_name, - base_url=request.build_absolute_uri("/"), + base_url=request.build_absolute_uri('/'), presentational_hints=True, filename=self.generate_filename(request), - **kwargs + **kwargs, ) - return wp.render_to_response( - context, - **kwargs - ) + return wp.render_to_response(context, **kwargs) class LabelOutput(models.Model): @@ -293,22 +286,14 @@ class LabelOutput(models.Model): # File will be stored in a subdirectory label = models.FileField( - upload_to=rename_label_output, - unique=True, blank=False, null=False, + upload_to=rename_label_output, unique=True, blank=False, null=False ) # Creation date of label output - created = models.DateField( - auto_now_add=True, - editable=False, - ) + created = models.DateField(auto_now_add=True, editable=False) # User who generated the label - user = models.ForeignKey( - User, - on_delete=models.SET_NULL, - blank=True, null=True, - ) + user = models.ForeignKey(User, on_delete=models.SET_NULL, blank=True, null=True) class StockItemLabel(LabelTemplate): @@ -319,15 +304,14 @@ class StockItemLabel(LabelTemplate): """Return the API URL associated with the StockItemLabel model""" return reverse('api-stockitem-label-list') # pragma: no cover - SUBDIR = "stockitem" + SUBDIR = 'stockitem' filters = models.CharField( - blank=True, max_length=250, + blank=True, + max_length=250, help_text=_('Query filters (comma-separated list of key=value pairs)'), verbose_name=_('Filters'), - validators=[ - validate_stock_item_filters - ] + validators=[validate_stock_item_filters], ) def get_context_data(self, request): @@ -348,7 +332,6 @@ class StockItemLabel(LabelTemplate): 'qr_url': request.build_absolute_uri(stock_item.get_absolute_url()), 'tests': stock_item.testResultMap(), 'parameters': stock_item.part.parameters_map(), - } @@ -360,24 +343,21 @@ class StockLocationLabel(LabelTemplate): """Return the API URL associated with the StockLocationLabel model""" return reverse('api-stocklocation-label-list') # pragma: no cover - SUBDIR = "stocklocation" + SUBDIR = 'stocklocation' filters = models.CharField( - blank=True, max_length=250, + blank=True, + max_length=250, help_text=_('Query filters (comma-separated list of key=value pairs)'), verbose_name=_('Filters'), - validators=[ - validate_stock_location_filters] + validators=[validate_stock_location_filters], ) def get_context_data(self, request): """Generate context data for each provided StockLocation.""" location = self.object_to_print - return { - 'location': location, - 'qr_data': location.format_barcode(brief=True), - } + return {'location': location, 'qr_data': location.format_barcode(brief=True)} class PartLabel(LabelTemplate): @@ -391,12 +371,11 @@ class PartLabel(LabelTemplate): SUBDIR = 'part' filters = models.CharField( - blank=True, max_length=250, + blank=True, + max_length=250, help_text=_('Query filters (comma-separated list of key=value pairs)'), verbose_name=_('Filters'), - validators=[ - validate_part_filters - ] + validators=[validate_part_filters], ) def get_context_data(self, request): @@ -427,12 +406,11 @@ class BuildLineLabel(LabelTemplate): SUBDIR = 'buildline' filters = models.CharField( - blank=True, max_length=250, + blank=True, + max_length=250, help_text=_('Query filters (comma-separated list of key=value pairs)'), verbose_name=_('Filters'), - validators=[ - validate_build_line_filters - ] + validators=[validate_build_line_filters], ) def get_context_data(self, request): diff --git a/InvenTree/label/serializers.py b/InvenTree/label/serializers.py index ecceb8ed69..c2c0d3fda0 100644 --- a/InvenTree/label/serializers.py +++ b/InvenTree/label/serializers.py @@ -1,8 +1,10 @@ """API serializers for the label app""" import label.models -from InvenTree.serializers import (InvenTreeAttachmentSerializerField, - InvenTreeModelSerializer) +from InvenTree.serializers import ( + InvenTreeAttachmentSerializerField, + InvenTreeModelSerializer, +) class LabelSerializerBase(InvenTreeModelSerializer): @@ -13,14 +15,7 @@ class LabelSerializerBase(InvenTreeModelSerializer): @staticmethod def label_fields(): """Generic serializer fields for a label template""" - return [ - 'pk', - 'name', - 'description', - 'label', - 'filters', - 'enabled', - ] + return ['pk', 'name', 'description', 'label', 'filters', 'enabled'] class StockItemLabelSerializer(LabelSerializerBase): diff --git a/InvenTree/label/test_api.py b/InvenTree/label/test_api.py index 2bba71f723..fe44e84e0e 100644 --- a/InvenTree/label/test_api.py +++ b/InvenTree/label/test_api.py @@ -8,17 +8,9 @@ from InvenTree.unit_test import InvenTreeAPITestCase class TestReportTests(InvenTreeAPITestCase): """Tests for the StockItem TestReport templates.""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock', - ] + fixtures = ['category', 'part', 'location', 'stock'] - roles = [ - 'stock.view', - 'stock_location.view', - ] + roles = ['stock.view', 'stock_location.view'] list_url = reverse('api-stockitem-testreport-list') @@ -42,22 +34,10 @@ class TestReportTests(InvenTreeAPITestCase): self.assertEqual(len(response), 0) # TODO - Add some tests to this response - response = self.do_list( - { - 'item': 10, - } - ) + response = self.do_list({'item': 10}) # TODO - Add some tests to this response - response = self.do_list( - { - 'item': 100000, - } - ) + response = self.do_list({'item': 100000}) # TODO - Add some tests to this response - response = self.do_list( - { - 'items': [10, 11, 12], - } - ) + response = self.do_list({'items': [10, 11, 12]}) diff --git a/InvenTree/label/tests.py b/InvenTree/label/tests.py index ac82639388..3b64a7e68d 100644 --- a/InvenTree/label/tests.py +++ b/InvenTree/label/tests.py @@ -24,12 +24,7 @@ from .models import PartLabel, StockItemLabel, StockLocationLabel class LabelTest(InvenTreeAPITestCase): """Unit test class for label models""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock' - ] + fixtures = ['category', 'part', 'location', 'stock'] @classmethod def setUpTestData(cls): @@ -49,12 +44,9 @@ class LabelTest(InvenTreeAPITestCase): def test_default_files(self): """Test that label files exist in the MEDIA directory.""" + def test_subdir(ref_name): - item_dir = settings.MEDIA_ROOT.joinpath( - 'label', - 'inventree', - ref_name, - ) + item_dir = settings.MEDIA_ROOT.joinpath('label', 'inventree', ref_name) self.assertTrue(len([item_dir.iterdir()]) > 0) test_subdir('stockitem') @@ -63,13 +55,13 @@ class LabelTest(InvenTreeAPITestCase): def test_filters(self): """Test the label filters.""" - filter_string = "part__pk=10" + filter_string = 'part__pk=10' filters = validateFilterString(filter_string, model=StockItem) self.assertEqual(type(filters), dict) - bad_filter_string = "part_pk=10" + bad_filter_string = 'part_pk=10' with self.assertRaises(ValidationError): validateFilterString(bad_filter_string, model=StockItem) @@ -115,14 +107,11 @@ class LabelTest(InvenTreeAPITestCase): buffer = io.StringIO() buffer.write(label_data) - template = ContentFile(buffer.getvalue(), "label.html") + template = ContentFile(buffer.getvalue(), 'label.html') # Construct a label template label = PartLabel.objects.create( - name='test', - description='Test label', - enabled=True, - label=template, + name='test', description='Test label', enabled=True, label=template ) # Ensure we are in "debug" mode (so the report is generated as HTML) @@ -151,7 +140,7 @@ class LabelTest(InvenTreeAPITestCase): content = f.read() # Test that each element has been rendered correctly - self.assertIn(f"part: {part_pk} - {part_name}", content) + self.assertIn(f'part: {part_pk} - {part_name}', content) self.assertIn(f'data: {{"part": {part_pk}}}', content) self.assertIn(f'http://testserver/part/{part_pk}/', content) diff --git a/InvenTree/manage.py b/InvenTree/manage.py index d2d21b0b23..9770d6ea35 100755 --- a/InvenTree/manage.py +++ b/InvenTree/manage.py @@ -7,17 +7,17 @@ import sys def main(): """Run administrative tasks.""" - os.environ.setdefault("DJANGO_SETTINGS_MODULE", "InvenTree.settings") + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'InvenTree.settings') try: from django.core.management import execute_from_command_line except ImportError as exc: # pragma: no cover raise ImportError( "Couldn't import Django. Are you sure it's installed and " - "available on your PYTHONPATH environment variable? Did you " - "forget to activate a virtual environment?" + 'available on your PYTHONPATH environment variable? Did you ' + 'forget to activate a virtual environment?' ) from exc execute_from_command_line(sys.argv) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/InvenTree/order/admin.py b/InvenTree/order/admin.py index 842ea2045e..30ec3d2f06 100644 --- a/InvenTree/order/admin.py +++ b/InvenTree/order/admin.py @@ -50,52 +50,54 @@ class PriceResourceMixin: # region general classes class GeneralExtraLineAdmin: """Admin class template for the 'ExtraLineItem' models""" - list_display = ( - 'order', - 'quantity', - 'reference' - ) - search_fields = [ - 'order__reference', - 'order__customer__name', - 'reference', - ] + list_display = ('order', 'quantity', 'reference') - autocomplete_fields = ('order', ) + search_fields = ['order__reference', 'order__customer__name', 'reference'] + + autocomplete_fields = ('order',) class GeneralExtraLineMeta: """Metaclass template for the 'ExtraLineItem' models""" + skip_unchanged = True report_skipped = False clean_model_instances = True + + # endregion class PurchaseOrderLineItemInlineAdmin(admin.StackedInline): """Inline admin class for the PurchaseOrderLineItem model""" + model = models.PurchaseOrderLineItem extra = 0 -class PurchaseOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource): +class PurchaseOrderResource( + ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource +): """Class for managing import / export of PurchaseOrder data.""" class Meta: """Metaclass""" + model = models.PurchaseOrder skip_unchanged = True clean_model_instances = True - exclude = [ - 'metadata', - ] + exclude = ['metadata'] # Add number of line items - line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True) + line_items = Field( + attribute='line_count', widget=widgets.IntegerWidget(), readonly=True + ) # Is this order overdue? - overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True) + overdue = Field( + attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True + ) class PurchaseOrderAdmin(ImportExportModelAdmin): @@ -103,48 +105,39 @@ class PurchaseOrderAdmin(ImportExportModelAdmin): resource_class = PurchaseOrderResource - exclude = [ - 'reference_int', - ] + exclude = ['reference_int'] - list_display = ( - 'reference', - 'supplier', - 'status', - 'description', - 'creation_date' - ) + list_display = ('reference', 'supplier', 'status', 'description', 'creation_date') - search_fields = [ - 'reference', - 'supplier__name', - 'description', - ] + search_fields = ['reference', 'supplier__name', 'description'] - inlines = [ - PurchaseOrderLineItemInlineAdmin - ] + inlines = [PurchaseOrderLineItemInlineAdmin] autocomplete_fields = ('supplier',) -class SalesOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource): +class SalesOrderResource( + ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource +): """Class for managing import / export of SalesOrder data.""" class Meta: """Metaclass options""" + model = models.SalesOrder skip_unchanged = True clean_model_instances = True - exclude = [ - 'metadata', - ] + exclude = ['metadata'] # Add number of line items - line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True) + line_items = Field( + attribute='line_count', widget=widgets.IntegerWidget(), readonly=True + ) # Is this order overdue? - overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True) + overdue = Field( + attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True + ) class SalesOrderAdmin(ImportExportModelAdmin): @@ -152,23 +145,11 @@ class SalesOrderAdmin(ImportExportModelAdmin): resource_class = SalesOrderResource - exclude = [ - 'reference_int', - ] + exclude = ['reference_int'] - list_display = ( - 'reference', - 'customer', - 'status', - 'description', - 'creation_date', - ) + list_display = ('reference', 'customer', 'status', 'description', 'creation_date') - search_fields = [ - 'reference', - 'customer__name', - 'description', - ] + search_fields = ['reference', 'customer__name', 'description'] autocomplete_fields = ('customer',) @@ -178,6 +159,7 @@ class PurchaseOrderLineItemResource(PriceResourceMixin, InvenTreeResource): class Meta: """Metaclass""" + model = models.PurchaseOrderLineItem skip_unchanged = True report_skipped = False @@ -212,6 +194,7 @@ class SalesOrderLineItemResource(PriceResourceMixin, InvenTreeResource): class Meta: """Metaclass options""" + model = models.SalesOrderLineItem skip_unchanged = True report_skipped = False @@ -240,6 +223,7 @@ class SalesOrderExtraLineResource(PriceResourceMixin, InvenTreeResource): class Meta(GeneralExtraLineMeta): """Metaclass options.""" + model = models.SalesOrderExtraLine @@ -248,20 +232,16 @@ class PurchaseOrderLineItemAdmin(ImportExportModelAdmin): resource_class = PurchaseOrderLineItemResource - list_display = ( - 'order', - 'part', - 'quantity', - 'reference' - ) + list_display = ('order', 'part', 'quantity', 'reference') search_fields = ('reference',) - autocomplete_fields = ('order', 'part', 'destination',) + autocomplete_fields = ('order', 'part', 'destination') class PurchaseOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): """Admin class for the PurchaseOrderExtraLine model""" + resource_class = PurchaseOrderExtraLineResource @@ -270,12 +250,7 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin): resource_class = SalesOrderLineItemResource - list_display = ( - 'order', - 'part', - 'quantity', - 'reference' - ) + list_display = ('order', 'part', 'quantity', 'reference') search_fields = [ 'part__name', @@ -284,28 +259,21 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin): 'reference', ] - autocomplete_fields = ('order', 'part',) + autocomplete_fields = ('order', 'part') class SalesOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): """Admin class for the SalesOrderExtraLine model""" + resource_class = SalesOrderExtraLineResource class SalesOrderShipmentAdmin(ImportExportModelAdmin): """Admin class for the SalesOrderShipment model""" - list_display = [ - 'order', - 'shipment_date', - 'reference', - ] + list_display = ['order', 'shipment_date', 'reference'] - search_fields = [ - 'reference', - 'order__reference', - 'order__customer__name', - ] + search_fields = ['reference', 'order__reference', 'order__customer__name'] autocomplete_fields = ('order',) @@ -313,26 +281,23 @@ class SalesOrderShipmentAdmin(ImportExportModelAdmin): class SalesOrderAllocationAdmin(ImportExportModelAdmin): """Admin class for the SalesOrderAllocation model""" - list_display = ( - 'line', - 'item', - 'quantity' - ) + list_display = ('line', 'item', 'quantity') - autocomplete_fields = ('line', 'shipment', 'item',) + autocomplete_fields = ('line', 'shipment', 'item') -class ReturnOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource): +class ReturnOrderResource( + ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource +): """Class for managing import / export of ReturnOrder data""" class Meta: """Metaclass options""" + model = models.ReturnOrder skip_unchanged = True clean_model_instances = True - exclude = [ - 'metadata', - ] + exclude = ['metadata'] class ReturnOrderAdmin(ImportExportModelAdmin): @@ -340,25 +305,13 @@ class ReturnOrderAdmin(ImportExportModelAdmin): resource_class = ReturnOrderResource - exclude = [ - 'reference_int', - ] + exclude = ['reference_int'] - list_display = [ - 'reference', - 'customer', - 'status', - ] + list_display = ['reference', 'customer', 'status'] - search_fields = [ - 'reference', - 'customer__name', - 'description', - ] + search_fields = ['reference', 'customer__name', 'description'] - autocomplete_fields = [ - 'customer', - ] + autocomplete_fields = ['customer'] class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource): @@ -366,6 +319,7 @@ class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource): class Meta: """Metaclass options""" + model = models.ReturnOrderLineItem skip_unchanged = True report_skipped = False @@ -377,11 +331,7 @@ class ReturnOrderLineItemAdmin(ImportExportModelAdmin): resource_class = ReturnOrderLineItemResource - list_display = [ - 'order', - 'item', - 'reference', - ] + list_display = ['order', 'item', 'reference'] class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource): @@ -389,11 +339,13 @@ class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource): class Meta(GeneralExtraLineMeta): """Metaclass options""" + model = models.ReturnOrderExtraLine class ReturnOrdeerExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): """Admin class for the ReturnOrderExtraLine model""" + resource_class = ReturnOrderExtraLineClass diff --git a/InvenTree/order/api.py b/InvenTree/order/api.py index fd7bade3d7..622cb32cae 100644 --- a/InvenTree/order/api.py +++ b/InvenTree/order/api.py @@ -17,22 +17,34 @@ import common.models as common_models from common.settings import settings from company.models import SupplierPart from generic.states.api import StatusView -from InvenTree.api import (APIDownloadMixin, AttachmentMixin, - ListCreateDestroyAPIView, MetadataView) +from InvenTree.api import ( + APIDownloadMixin, + AttachmentMixin, + ListCreateDestroyAPIView, + MetadataView, +) from InvenTree.filters import SEARCH_ORDER_FILTER, SEARCH_ORDER_FILTER_ALIAS from InvenTree.helpers import DownloadFile, str2bool from InvenTree.helpers_model import construct_absolute_url, get_base_url -from InvenTree.mixins import (CreateAPI, ListAPI, ListCreateAPI, - RetrieveUpdateDestroyAPI) -from InvenTree.status_codes import (PurchaseOrderStatus, - PurchaseOrderStatusGroups, - ReturnOrderLineStatus, ReturnOrderStatus, - SalesOrderStatus, SalesOrderStatusGroups) +from InvenTree.mixins import CreateAPI, ListAPI, ListCreateAPI, RetrieveUpdateDestroyAPI +from InvenTree.status_codes import ( + PurchaseOrderStatus, + PurchaseOrderStatusGroups, + ReturnOrderLineStatus, + ReturnOrderStatus, + SalesOrderStatus, + SalesOrderStatusGroups, +) from order import models, serializers -from order.admin import (PurchaseOrderExtraLineResource, - PurchaseOrderLineItemResource, PurchaseOrderResource, - ReturnOrderResource, SalesOrderExtraLineResource, - SalesOrderLineItemResource, SalesOrderResource) +from order.admin import ( + PurchaseOrderExtraLineResource, + PurchaseOrderLineItemResource, + PurchaseOrderResource, + ReturnOrderResource, + SalesOrderExtraLineResource, + SalesOrderLineItemResource, + SalesOrderResource, +) from part.models import Part from users.models import Owner @@ -57,37 +69,24 @@ class GeneralExtraLineList(APIDownloadMixin): """Return the annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) - queryset = queryset.prefetch_related( - 'order', - ) + queryset = queryset.prefetch_related('order') return queryset filter_backends = SEARCH_ORDER_FILTER - ordering_fields = [ - 'quantity', - 'note', - 'reference', - ] + ordering_fields = ['quantity', 'note', 'reference'] - search_fields = [ - 'quantity', - 'note', - 'reference', - 'description', - ] + search_fields = ['quantity', 'note', 'reference', 'description'] - filterset_fields = [ - 'order', - ] + filterset_fields = ['order'] class OrderFilter(rest_filters.FilterSet): """Base class for custom API filters for the OrderList endpoint.""" # Filter against order status - status = rest_filters.NumberFilter(label="Order Status", method='filter_status') + status = rest_filters.NumberFilter(label='Order Status', method='filter_status') def filter_status(self, queryset, name, value): """Filter by integer status code""" @@ -95,12 +94,12 @@ class OrderFilter(rest_filters.FilterSet): # Exact match for reference reference = rest_filters.CharFilter( - label='Filter by exact reference', - field_name='reference', - lookup_expr="iexact" + label='Filter by exact reference', field_name='reference', lookup_expr='iexact' ) - assigned_to_me = rest_filters.BooleanFilter(label='assigned_to_me', method='filter_assigned_to_me') + assigned_to_me = rest_filters.BooleanFilter( + label='assigned_to_me', method='filter_assigned_to_me' + ) def filter_assigned_to_me(self, queryset, name, value): """Filter by orders which are assigned to the current user.""" @@ -122,7 +121,9 @@ class OrderFilter(rest_filters.FilterSet): return queryset.filter(self.Meta.model.overdue_filter()) return queryset.exclude(self.Meta.model.overdue_filter()) - outstanding = rest_filters.BooleanFilter(label='outstanding', method='filter_outstanding') + outstanding = rest_filters.BooleanFilter( + label='outstanding', method='filter_outstanding' + ) def filter_outstanding(self, queryset, name, value): """Generic filter for determining if an order is 'outstanding'""" @@ -131,11 +132,12 @@ class OrderFilter(rest_filters.FilterSet): return queryset.exclude(status__in=self.Meta.model.get_status_class().OPEN) project_code = rest_filters.ModelChoiceFilter( - queryset=common_models.ProjectCode.objects.all(), - field_name='project_code' + queryset=common_models.ProjectCode.objects.all(), field_name='project_code' ) - has_project_code = rest_filters.BooleanFilter(label='has_project_code', method='filter_has_project_code') + has_project_code = rest_filters.BooleanFilter( + label='has_project_code', method='filter_has_project_code' + ) def filter_has_project_code(self, queryset, name, value): """Filter by whether or not the order has a project code""" @@ -148,9 +150,13 @@ class LineItemFilter(rest_filters.FilterSet): """Base class for custom API filters for order line item list(s)""" # Filter by order status - order_status = rest_filters.NumberFilter(label='order_status', field_name='order__status') + order_status = rest_filters.NumberFilter( + label='order_status', field_name='order__status' + ) - has_pricing = rest_filters.BooleanFilter(label="Has Pricing", method='filter_has_pricing') + has_pricing = rest_filters.BooleanFilter( + label='Has Pricing', method='filter_has_pricing' + ) def filter_has_pricing(self, queryset, name, value): """Filter by whether or not the line item has pricing information""" @@ -168,9 +174,7 @@ class PurchaseOrderFilter(OrderFilter): """Metaclass options.""" model = models.PurchaseOrder - fields = [ - 'supplier', - ] + fields = ['supplier'] class PurchaseOrderMixin: @@ -182,7 +186,9 @@ class PurchaseOrderMixin: def get_serializer(self, *args, **kwargs): """Return the serializer instance for this endpoint""" try: - kwargs['supplier_detail'] = str2bool(self.request.query_params.get('supplier_detail', False)) + kwargs['supplier_detail'] = str2bool( + self.request.query_params.get('supplier_detail', False) + ) except AttributeError: pass @@ -195,10 +201,7 @@ class PurchaseOrderMixin: """Return the annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) - queryset = queryset.prefetch_related( - 'supplier', - 'lines', - ) + queryset = queryset.prefetch_related('supplier', 'lines') queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset) @@ -227,7 +230,7 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI): duplicate_order = models.PurchaseOrder.objects.get(pk=duplicate_order) except (ValueError, models.PurchaseOrder.DoesNotExist): raise ValidationError({ - 'duplicate_order': [_('No matching purchase order found')], + 'duplicate_order': [_('No matching purchase order found')] }) serializer = self.get_serializer(data=data) @@ -240,7 +243,6 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI): # Duplicate line items from other order if required if duplicate_order is not None: - if duplicate_line_items: for line in duplicate_order.lines.all(): # Copy the line across to the new order @@ -259,7 +261,9 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI): line.save() headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) def download_queryset(self, queryset, export_format): """Download the filtered queryset as a file""" @@ -267,7 +271,7 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI): filedata = dataset.export(export_format) - filename = f"InvenTree_PurchaseOrders.{export_format}" + filename = f'InvenTree_PurchaseOrders.{export_format}' return DownloadFile(filedata, filename) @@ -284,7 +288,9 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI): if part is not None: try: part = Part.objects.get(pk=part) - queryset = queryset.filter(id__in=[p.id for p in part.purchase_orders()]) + queryset = queryset.filter( + id__in=[p.id for p in part.purchase_orders()] + ) except (Part.DoesNotExist, ValueError): pass @@ -294,7 +300,9 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI): if supplier_part is not None: try: supplier_part = SupplierPart.objects.get(pk=supplier_part) - queryset = queryset.filter(id__in=[p.id for p in supplier_part.purchase_orders()]) + queryset = queryset.filter( + id__in=[p.id for p in supplier_part.purchase_orders()] + ) except (ValueError, SupplierPart.DoesNotExist): pass @@ -339,6 +347,7 @@ class PurchaseOrderList(PurchaseOrderMixin, APIDownloadMixin, ListCreateAPI): class PurchaseOrderDetail(PurchaseOrderMixin, RetrieveUpdateDestroyAPI): """API endpoint for detail view of a PurchaseOrder object.""" + pass @@ -353,7 +362,9 @@ class PurchaseOrderContextMixin: # Pass the purchase order through to the serializer for validation try: - context['order'] = models.PurchaseOrder.objects.get(pk=self.kwargs.get('pk', None)) + context['order'] = models.PurchaseOrder.objects.get( + pk=self.kwargs.get('pk', None) + ) except Exception: pass @@ -408,12 +419,10 @@ class PurchaseOrderLineItemFilter(LineItemFilter): class Meta: """Metaclass options.""" + price_field = 'purchase_price' model = models.PurchaseOrderLineItem - fields = [ - 'order', - 'part', - ] + fields = ['order', 'part'] pending = rest_filters.BooleanFilter(label='pending', method='filter_pending') @@ -435,7 +444,9 @@ class PurchaseOrderLineItemFilter(LineItemFilter): if str2bool(value): return queryset.filter(q) # Only count "pending" orders - return queryset.exclude(q).filter(order__status__in=PurchaseOrderStatusGroups.OPEN) + return queryset.exclude(q).filter( + order__status__in=PurchaseOrderStatusGroups.OPEN + ) class PurchaseOrderLineItemMixin: @@ -448,15 +459,21 @@ class PurchaseOrderLineItemMixin: """Return annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) - queryset = serializers.PurchaseOrderLineItemSerializer.annotate_queryset(queryset) + queryset = serializers.PurchaseOrderLineItemSerializer.annotate_queryset( + queryset + ) return queryset def get_serializer(self, *args, **kwargs): """Return serializer instance for this endpoint""" try: - kwargs['part_detail'] = str2bool(self.request.query_params.get('part_detail', False)) - kwargs['order_detail'] = str2bool(self.request.query_params.get('order_detail', False)) + kwargs['part_detail'] = str2bool( + self.request.query_params.get('part_detail', False) + ) + kwargs['order_detail'] = str2bool( + self.request.query_params.get('order_detail', False) + ) except AttributeError: pass @@ -465,7 +482,9 @@ class PurchaseOrderLineItemMixin: return self.serializer_class(*args, **kwargs) -class PurchaseOrderLineItemList(PurchaseOrderLineItemMixin, APIDownloadMixin, ListCreateDestroyAPIView): +class PurchaseOrderLineItemList( + PurchaseOrderLineItemMixin, APIDownloadMixin, ListCreateDestroyAPIView +): """API endpoint for accessing a list of PurchaseOrderLineItem objects. - GET: Return a list of PurchaseOrder Line Item objects @@ -499,7 +518,7 @@ class PurchaseOrderLineItemList(PurchaseOrderLineItemMixin, APIDownloadMixin, Li filedata = dataset.export(export_format) - filename = f"InvenTree_PurchaseOrderItems.{export_format}" + filename = f'InvenTree_PurchaseOrderItems.{export_format}' return DownloadFile(filedata, filename) @@ -534,6 +553,7 @@ class PurchaseOrderLineItemList(PurchaseOrderLineItemMixin, APIDownloadMixin, Li class PurchaseOrderLineItemDetail(PurchaseOrderLineItemMixin, RetrieveUpdateDestroyAPI): """Detail API endpoint for PurchaseOrderLineItem object.""" + pass @@ -547,7 +567,7 @@ class PurchaseOrderExtraLineList(GeneralExtraLineList, ListCreateAPI): """Download this queryset as a file""" dataset = PurchaseOrderExtraLineResource().export(queryset=queryset) filedata = dataset.export(export_format) - filename = f"InvenTree_ExtraPurchaseOrderLines.{export_format}" + filename = f'InvenTree_ExtraPurchaseOrderLines.{export_format}' return DownloadFile(filedata, filename) @@ -565,9 +585,7 @@ class SalesOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): queryset = models.SalesOrderAttachment.objects.all() serializer_class = serializers.SalesOrderAttachmentSerializer - filterset_fields = [ - 'order', - ] + filterset_fields = ['order'] class SalesOrderAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): @@ -584,9 +602,7 @@ class SalesOrderFilter(OrderFilter): """Metaclass options.""" model = models.SalesOrder - fields = [ - 'customer', - ] + fields = ['customer'] class SalesOrderMixin: @@ -598,7 +614,9 @@ class SalesOrderMixin: def get_serializer(self, *args, **kwargs): """Return serializer instance for this endpoint""" try: - kwargs['customer_detail'] = str2bool(self.request.query_params.get('customer_detail', False)) + kwargs['customer_detail'] = str2bool( + self.request.query_params.get('customer_detail', False) + ) except AttributeError: pass @@ -611,10 +629,7 @@ class SalesOrderMixin: """Return annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) - queryset = queryset.prefetch_related( - 'customer', - 'lines' - ) + queryset = queryset.prefetch_related('customer', 'lines') queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset) @@ -640,7 +655,9 @@ class SalesOrderList(SalesOrderMixin, APIDownloadMixin, ListCreateAPI): item.save() headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) def download_queryset(self, queryset, export_format): """Download this queryset as a file""" @@ -648,7 +665,7 @@ class SalesOrderList(SalesOrderMixin, APIDownloadMixin, ListCreateAPI): filedata = dataset.export(export_format) - filename = f"InvenTree_SalesOrders.{export_format}" + filename = f'InvenTree_SalesOrders.{export_format}' return DownloadFile(filedata, filename) @@ -685,9 +702,7 @@ class SalesOrderList(SalesOrderMixin, APIDownloadMixin, ListCreateAPI): 'project_code': ['project_code__code'], } - filterset_fields = [ - 'customer', - ] + filterset_fields = ['customer'] ordering_fields = [ 'creation_date', @@ -715,6 +730,7 @@ class SalesOrderList(SalesOrderMixin, APIDownloadMixin, ListCreateAPI): class SalesOrderDetail(SalesOrderMixin, RetrieveUpdateDestroyAPI): """API endpoint for detail view of a SalesOrder object.""" + pass @@ -723,12 +739,10 @@ class SalesOrderLineItemFilter(LineItemFilter): class Meta: """Metaclass options.""" + price_field = 'sale_price' model = models.SalesOrderLineItem - fields = [ - 'order', - 'part', - ] + fields = ['order', 'part'] completed = rest_filters.BooleanFilter(label='completed', method='filter_completed') @@ -795,28 +809,20 @@ class SalesOrderLineItemList(SalesOrderLineItemMixin, APIDownloadMixin, ListCrea dataset = SalesOrderLineItemResource().export(queryset=queryset) filedata = dataset.export(export_format) - filename = f"InvenTree_SalesOrderItems.{export_format}" + filename = f'InvenTree_SalesOrderItems.{export_format}' return DownloadFile(filedata, filename) filter_backends = SEARCH_ORDER_FILTER - ordering_fields = [ - 'part__name', - 'quantity', - 'reference', - 'target_date', - ] + ordering_fields = ['part__name', 'quantity', 'reference', 'target_date'] - search_fields = [ - 'part__name', - 'quantity', - 'reference', - ] + search_fields = ['part__name', 'quantity', 'reference'] class SalesOrderLineItemDetail(SalesOrderLineItemMixin, RetrieveUpdateDestroyAPI): """API endpoint for detail view of a SalesOrderLineItem object.""" + pass @@ -830,7 +836,7 @@ class SalesOrderExtraLineList(GeneralExtraLineList, ListCreateAPI): """Download this queryset as a file""" dataset = SalesOrderExtraLineResource().export(queryset=queryset) filedata = dataset.export(export_format) - filename = f"InvenTree_ExtraSalesOrderLines.{export_format}" + filename = f'InvenTree_ExtraSalesOrderLines.{export_format}' return DownloadFile(filedata, filename) @@ -863,11 +869,13 @@ class SalesOrderContextMixin: class SalesOrderCancel(SalesOrderContextMixin, CreateAPI): """API endpoint to cancel a SalesOrder""" + serializer_class = serializers.SalesOrderCancelSerializer class SalesOrderIssue(SalesOrderContextMixin, CreateAPI): """API endpoint to issue a SalesOrder""" + serializer_class = serializers.SalesOrderIssueSerializer @@ -967,14 +975,12 @@ class SalesOrderAllocationList(ListAPI): else: queryset = queryset.exclude( line__order__status__in=SalesOrderStatusGroups.OPEN, - shipment__shipment_date=None + shipment__shipment_date=None, ) return queryset - filter_backends = [ - rest_filters.DjangoFilterBackend, - ] + filter_backends = [rest_filters.DjangoFilterBackend] class SalesOrderShipmentFilter(rest_filters.FilterSet): @@ -984,9 +990,7 @@ class SalesOrderShipmentFilter(rest_filters.FilterSet): """Metaclass options.""" model = models.SalesOrderShipment - fields = [ - 'order', - ] + fields = ['order'] shipped = rest_filters.BooleanFilter(label='shipped', method='filter_shipped') @@ -1012,9 +1016,7 @@ class SalesOrderShipmentList(ListCreateAPI): serializer_class = serializers.SalesOrderShipmentSerializer filterset_class = SalesOrderShipmentFilter - filter_backends = [ - rest_filters.DjangoFilterBackend, - ] + filter_backends = [rest_filters.DjangoFilterBackend] class SalesOrderShipmentDetail(RetrieveUpdateDestroyAPI): @@ -1051,9 +1053,7 @@ class PurchaseOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): queryset = models.PurchaseOrderAttachment.objects.all() serializer_class = serializers.PurchaseOrderAttachmentSerializer - filterset_fields = [ - 'order', - ] + filterset_fields = ['order'] class PurchaseOrderAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): @@ -1070,9 +1070,7 @@ class ReturnOrderFilter(OrderFilter): """Metaclass options""" model = models.ReturnOrder - fields = [ - 'customer', - ] + fields = ['customer'] class ReturnOrderMixin: @@ -1099,9 +1097,7 @@ class ReturnOrderMixin: """Return annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) - queryset = queryset.prefetch_related( - 'customer', - ) + queryset = queryset.prefetch_related('customer') queryset = serializers.ReturnOrderSerializer.annotate_queryset(queryset) @@ -1123,13 +1119,15 @@ class ReturnOrderList(ReturnOrderMixin, APIDownloadMixin, ListCreateAPI): item.save() headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) def download_queryset(self, queryset, export_format): """Download this queryset as a file""" dataset = ReturnOrderResource().export(queryset=queryset) filedata = dataset.export(export_format) - filename = f"InvenTree_ReturnOrders.{export_format}" + filename = f'InvenTree_ReturnOrders.{export_format}' return DownloadFile(filedata, filename) @@ -1164,6 +1162,7 @@ class ReturnOrderList(ReturnOrderMixin, APIDownloadMixin, ListCreateAPI): class ReturnOrderDetail(ReturnOrderMixin, RetrieveUpdateDestroyAPI): """API endpoint for detail view of a single ReturnOrder object""" + pass @@ -1178,7 +1177,9 @@ class ReturnOrderContextMixin: # Pass the ReturnOrder instance through to the serializer for validation try: - context['order'] = models.ReturnOrder.objects.get(pk=self.kwargs.get('pk', None)) + context['order'] = models.ReturnOrder.objects.get( + pk=self.kwargs.get('pk', None) + ) except Exception: pass @@ -1189,16 +1190,19 @@ class ReturnOrderContextMixin: class ReturnOrderCancel(ReturnOrderContextMixin, CreateAPI): """API endpoint to cancel a ReturnOrder""" + serializer_class = serializers.ReturnOrderCancelSerializer class ReturnOrderComplete(ReturnOrderContextMixin, CreateAPI): """API endpoint to complete a ReturnOrder""" + serializer_class = serializers.ReturnOrderCompleteSerializer class ReturnOrderIssue(ReturnOrderContextMixin, CreateAPI): """API endpoint to issue (place) a ReturnOrder""" + serializer_class = serializers.ReturnOrderIssueSerializer @@ -1214,12 +1218,10 @@ class ReturnOrderLineItemFilter(LineItemFilter): class Meta: """Metaclass options""" + price_field = 'price' model = models.ReturnOrderLineItem - fields = [ - 'order', - 'item', - ] + fields = ['order', 'item'] outcome = rest_filters.NumberFilter(label='outcome') @@ -1257,31 +1259,27 @@ class ReturnOrderLineItemMixin: """Return annotated queryset for this endpoint""" queryset = super().get_queryset(*args, **kwargs) - queryset = queryset.prefetch_related( - 'order', - 'item', - 'item__part', - ) + queryset = queryset.prefetch_related('order', 'item', 'item__part') return queryset -class ReturnOrderLineItemList(ReturnOrderLineItemMixin, APIDownloadMixin, ListCreateAPI): +class ReturnOrderLineItemList( + ReturnOrderLineItemMixin, APIDownloadMixin, ListCreateAPI +): """API endpoint for accessing a list of ReturnOrderLineItemList objects""" filterset_class = ReturnOrderLineItemFilter def download_queryset(self, queryset, export_format): """Download the requested queryset as a file""" - raise NotImplementedError("download_queryset not yet implemented for this endpoint") + raise NotImplementedError( + 'download_queryset not yet implemented for this endpoint' + ) filter_backends = SEARCH_ORDER_FILTER - ordering_fields = [ - 'reference', - 'target_date', - 'received_date', - ] + ordering_fields = ['reference', 'target_date', 'received_date'] search_fields = [ 'item_serial', @@ -1293,6 +1291,7 @@ class ReturnOrderLineItemList(ReturnOrderLineItemMixin, APIDownloadMixin, ListCr class ReturnOrderLineItemDetail(ReturnOrderLineItemMixin, RetrieveUpdateDestroyAPI): """API endpoint for detail view of a ReturnOrderLineItem object""" + pass @@ -1304,7 +1303,7 @@ class ReturnOrderExtraLineList(GeneralExtraLineList, ListCreateAPI): def download_queryset(self, queryset, export_format): """Download this queryset as a file""" - raise NotImplementedError("download_queryset not yet implemented") + raise NotImplementedError('download_queryset not yet implemented') class ReturnOrderExtraLineDetail(RetrieveUpdateDestroyAPI): @@ -1320,9 +1319,7 @@ class ReturnOrderAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): queryset = models.ReturnOrderAttachment.objects.all() serializer_class = serializers.ReturnOrderAttachmentSerializer - filterset_fields = [ - 'order', - ] + filterset_fields = ['order'] class ReturnOrderAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): @@ -1342,9 +1339,9 @@ class OrderCalendarExport(ICalFeed): instance_url = get_base_url() - instance_url = instance_url.replace("http://", "").replace("https://", "") + instance_url = instance_url.replace('http://', '').replace('https://', '') timezone = settings.TIME_ZONE - file_name = "calendar.ics" + file_name = 'calendar.ics' def __call__(self, request, *args, **kwargs): """Overload call in order to check for authentication. @@ -1370,8 +1367,8 @@ class OrderCalendarExport(ICalFeed): if len(auth) == 2: # NOTE: We are only support basic authentication for now. # - if auth[0].lower() == "basic": - uname, passwd = base64.b64decode(auth[1]).decode("ascii").split(':') + if auth[0].lower() == 'basic': + uname, passwd = base64.b64decode(auth[1]).decode('ascii').split(':') user = authenticate(username=uname, password=passwd) if user is not None: if user.is_active: @@ -1385,7 +1382,9 @@ class OrderCalendarExport(ICalFeed): # Still nothing - return Unauth. header with info on how to authenticate # Information is needed by client, eg Thunderbird - response = JsonResponse({"detail": "Authentication credentials were not provided."}) + response = JsonResponse({ + 'detail': 'Authentication credentials were not provided.' + }) response['WWW-Authenticate'] = 'Basic realm="api"' response.status_code = 401 return response @@ -1403,11 +1402,11 @@ class OrderCalendarExport(ICalFeed): def title(self, obj): """Return calendar title.""" - if obj["ordertype"] == 'purchase-order': + if obj['ordertype'] == 'purchase-order': ordertype_title = _('Purchase Order') - elif obj["ordertype"] == 'sales-order': + elif obj['ordertype'] == 'sales-order': ordertype_title = _('Sales Order') - elif obj["ordertype"] == 'return-order': + elif obj['ordertype'] == 'return-order': ordertype_title = _('Return Order') else: ordertype_title = _('Unknown') @@ -1429,21 +1428,27 @@ class OrderCalendarExport(ICalFeed): if obj['include_completed'] is False: # Do not include completed orders from list in this case # Completed status = 30 - outlist = models.PurchaseOrder.objects.filter(target_date__isnull=False).filter(status__lt=PurchaseOrderStatus.COMPLETE.value) + outlist = models.PurchaseOrder.objects.filter( + target_date__isnull=False + ).filter(status__lt=PurchaseOrderStatus.COMPLETE.value) else: outlist = models.PurchaseOrder.objects.filter(target_date__isnull=False) - elif obj["ordertype"] == 'sales-order': + elif obj['ordertype'] == 'sales-order': if obj['include_completed'] is False: # Do not include completed (=shipped) orders from list in this case # Shipped status = 20 - outlist = models.SalesOrder.objects.filter(target_date__isnull=False).filter(status__lt=SalesOrderStatus.SHIPPED.value) + outlist = models.SalesOrder.objects.filter( + target_date__isnull=False + ).filter(status__lt=SalesOrderStatus.SHIPPED.value) else: outlist = models.SalesOrder.objects.filter(target_date__isnull=False) - elif obj["ordertype"] == 'return-order': + elif obj['ordertype'] == 'return-order': if obj['include_completed'] is False: # Do not include completed orders from list in this case # Complete status = 30 - outlist = models.ReturnOrder.objects.filter(target_date__isnull=False).filter(status__lt=ReturnOrderStatus.COMPLETE.value) + outlist = models.ReturnOrder.objects.filter( + target_date__isnull=False + ).filter(status__lt=ReturnOrderStatus.COMPLETE.value) else: outlist = models.ReturnOrder.objects.filter(target_date__isnull=False) else: @@ -1453,11 +1458,11 @@ class OrderCalendarExport(ICalFeed): def item_title(self, item): """Set the event title to the order reference""" - return f"{item.reference}" + return f'{item.reference}' def item_description(self, item): """Set the event description""" - return f"Company: {item.company.name}\nStatus: {item.get_status_display()}\nDescription: {item.description}" + return f'Company: {item.company.name}\nStatus: {item.get_status_display()}\nDescription: {item.description}' def item_start_datetime(self, item): """Set event start to target date. Goal is all-day event.""" @@ -1485,160 +1490,405 @@ class OrderCalendarExport(ICalFeed): order_api_urls = [ - # API endpoints for purchase orders - re_path(r'^po/', include([ - - # Purchase order attachments - re_path(r'attachment/', include([ - path('/', PurchaseOrderAttachmentDetail.as_view(), name='api-po-attachment-detail'), - re_path(r'^.*$', PurchaseOrderAttachmentList.as_view(), name='api-po-attachment-list'), - ])), - - # Individual purchase order detail URLs - path(r'/', include([ - re_path(r'^cancel/', PurchaseOrderCancel.as_view(), name='api-po-cancel'), - re_path(r'^complete/', PurchaseOrderComplete.as_view(), name='api-po-complete'), - re_path(r'^issue/', PurchaseOrderIssue.as_view(), name='api-po-issue'), - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.PurchaseOrder}, name='api-po-metadata'), - re_path(r'^receive/', PurchaseOrderReceive.as_view(), name='api-po-receive'), - - # PurchaseOrder detail API endpoint - re_path(r'.*$', PurchaseOrderDetail.as_view(), name='api-po-detail'), - ])), - - # Purchase order status code information - re_path(r'status/', StatusView.as_view(), {StatusView.MODEL_REF: PurchaseOrderStatus}, name='api-po-status-codes'), - - # Purchase order list - re_path(r'^.*$', PurchaseOrderList.as_view(), name='api-po-list'), - ])), - + re_path( + r'^po/', + include([ + # Purchase order attachments + re_path( + r'attachment/', + include([ + path( + '/', + PurchaseOrderAttachmentDetail.as_view(), + name='api-po-attachment-detail', + ), + re_path( + r'^.*$', + PurchaseOrderAttachmentList.as_view(), + name='api-po-attachment-list', + ), + ]), + ), + # Individual purchase order detail URLs + path( + r'/', + include([ + re_path( + r'^cancel/', PurchaseOrderCancel.as_view(), name='api-po-cancel' + ), + re_path( + r'^complete/', + PurchaseOrderComplete.as_view(), + name='api-po-complete', + ), + re_path( + r'^issue/', PurchaseOrderIssue.as_view(), name='api-po-issue' + ), + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.PurchaseOrder}, + name='api-po-metadata', + ), + re_path( + r'^receive/', + PurchaseOrderReceive.as_view(), + name='api-po-receive', + ), + # PurchaseOrder detail API endpoint + re_path( + r'.*$', PurchaseOrderDetail.as_view(), name='api-po-detail' + ), + ]), + ), + # Purchase order status code information + re_path( + r'status/', + StatusView.as_view(), + {StatusView.MODEL_REF: PurchaseOrderStatus}, + name='api-po-status-codes', + ), + # Purchase order list + re_path(r'^.*$', PurchaseOrderList.as_view(), name='api-po-list'), + ]), + ), # API endpoints for purchase order line items - re_path(r'^po-line/', include([ - path('/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.PurchaseOrderLineItem}, name='api-po-line-metadata'), - re_path(r'^.*$', PurchaseOrderLineItemDetail.as_view(), name='api-po-line-detail'), - ])), - re_path(r'^.*$', PurchaseOrderLineItemList.as_view(), name='api-po-line-list'), - ])), - + re_path( + r'^po-line/', + include([ + path( + '/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.PurchaseOrderLineItem}, + name='api-po-line-metadata', + ), + re_path( + r'^.*$', + PurchaseOrderLineItemDetail.as_view(), + name='api-po-line-detail', + ), + ]), + ), + re_path( + r'^.*$', PurchaseOrderLineItemList.as_view(), name='api-po-line-list' + ), + ]), + ), # API endpoints for purchase order extra line - re_path(r'^po-extra-line/', include([ - path('/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.PurchaseOrderExtraLine}, name='api-po-extra-line-metadata'), - re_path(r'^.*$', PurchaseOrderExtraLineDetail.as_view(), name='api-po-extra-line-detail'), - ])), - path('', PurchaseOrderExtraLineList.as_view(), name='api-po-extra-line-list'), - ])), - + re_path( + r'^po-extra-line/', + include([ + path( + '/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.PurchaseOrderExtraLine}, + name='api-po-extra-line-metadata', + ), + re_path( + r'^.*$', + PurchaseOrderExtraLineDetail.as_view(), + name='api-po-extra-line-detail', + ), + ]), + ), + path( + '', PurchaseOrderExtraLineList.as_view(), name='api-po-extra-line-list' + ), + ]), + ), # API endpoints for sales ordesr - re_path(r'^so/', include([ - re_path(r'attachment/', include([ - path('/', SalesOrderAttachmentDetail.as_view(), name='api-so-attachment-detail'), - re_path(r'^.*$', SalesOrderAttachmentList.as_view(), name='api-so-attachment-list'), - ])), - - re_path(r'^shipment/', include([ - path(r'/', include([ - path('ship/', SalesOrderShipmentComplete.as_view(), name='api-so-shipment-ship'), - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.SalesOrderShipment}, name='api-so-shipment-metadata'), - re_path(r'^.*$', SalesOrderShipmentDetail.as_view(), name='api-so-shipment-detail'), - ])), - re_path(r'^.*$', SalesOrderShipmentList.as_view(), name='api-so-shipment-list'), - ])), - - # Sales order detail view - path(r'/', include([ - re_path(r'^allocate/', SalesOrderAllocate.as_view(), name='api-so-allocate'), - re_path(r'^allocate-serials/', SalesOrderAllocateSerials.as_view(), name='api-so-allocate-serials'), - re_path(r'^cancel/', SalesOrderCancel.as_view(), name='api-so-cancel'), - re_path(r'^issue/', SalesOrderIssue.as_view(), name='api-so-issue'), - re_path(r'^complete/', SalesOrderComplete.as_view(), name='api-so-complete'), - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.SalesOrder}, name='api-so-metadata'), - - # SalesOrder detail endpoint - re_path(r'^.*$', SalesOrderDetail.as_view(), name='api-so-detail'), - ])), - - # Sales order status code information - re_path(r'status/', StatusView.as_view(), {StatusView.MODEL_REF: SalesOrderStatus}, name='api-so-status-codes'), - - # Sales order list view - re_path(r'^.*$', SalesOrderList.as_view(), name='api-so-list'), - ])), - + re_path( + r'^so/', + include([ + re_path( + r'attachment/', + include([ + path( + '/', + SalesOrderAttachmentDetail.as_view(), + name='api-so-attachment-detail', + ), + re_path( + r'^.*$', + SalesOrderAttachmentList.as_view(), + name='api-so-attachment-list', + ), + ]), + ), + re_path( + r'^shipment/', + include([ + path( + r'/', + include([ + path( + 'ship/', + SalesOrderShipmentComplete.as_view(), + name='api-so-shipment-ship', + ), + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.SalesOrderShipment}, + name='api-so-shipment-metadata', + ), + re_path( + r'^.*$', + SalesOrderShipmentDetail.as_view(), + name='api-so-shipment-detail', + ), + ]), + ), + re_path( + r'^.*$', + SalesOrderShipmentList.as_view(), + name='api-so-shipment-list', + ), + ]), + ), + # Sales order detail view + path( + r'/', + include([ + re_path( + r'^allocate/', + SalesOrderAllocate.as_view(), + name='api-so-allocate', + ), + re_path( + r'^allocate-serials/', + SalesOrderAllocateSerials.as_view(), + name='api-so-allocate-serials', + ), + re_path( + r'^cancel/', SalesOrderCancel.as_view(), name='api-so-cancel' + ), + re_path(r'^issue/', SalesOrderIssue.as_view(), name='api-so-issue'), + re_path( + r'^complete/', + SalesOrderComplete.as_view(), + name='api-so-complete', + ), + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.SalesOrder}, + name='api-so-metadata', + ), + # SalesOrder detail endpoint + re_path(r'^.*$', SalesOrderDetail.as_view(), name='api-so-detail'), + ]), + ), + # Sales order status code information + re_path( + r'status/', + StatusView.as_view(), + {StatusView.MODEL_REF: SalesOrderStatus}, + name='api-so-status-codes', + ), + # Sales order list view + re_path(r'^.*$', SalesOrderList.as_view(), name='api-so-list'), + ]), + ), # API endpoints for sales order line items - re_path(r'^so-line/', include([ - path('/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.SalesOrderLineItem}, name='api-so-line-metadata'), - re_path(r'^.*$', SalesOrderLineItemDetail.as_view(), name='api-so-line-detail'), - ])), - path('', SalesOrderLineItemList.as_view(), name='api-so-line-list'), - ])), - + re_path( + r'^so-line/', + include([ + path( + '/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.SalesOrderLineItem}, + name='api-so-line-metadata', + ), + re_path( + r'^.*$', + SalesOrderLineItemDetail.as_view(), + name='api-so-line-detail', + ), + ]), + ), + path('', SalesOrderLineItemList.as_view(), name='api-so-line-list'), + ]), + ), # API endpoints for sales order extra line - re_path(r'^so-extra-line/', include([ - path('/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.SalesOrderExtraLine}, name='api-so-extra-line-metadata'), - re_path(r'^.*$', SalesOrderExtraLineDetail.as_view(), name='api-so-extra-line-detail'), - ])), - path('', SalesOrderExtraLineList.as_view(), name='api-so-extra-line-list'), - ])), - + re_path( + r'^so-extra-line/', + include([ + path( + '/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.SalesOrderExtraLine}, + name='api-so-extra-line-metadata', + ), + re_path( + r'^.*$', + SalesOrderExtraLineDetail.as_view(), + name='api-so-extra-line-detail', + ), + ]), + ), + path('', SalesOrderExtraLineList.as_view(), name='api-so-extra-line-list'), + ]), + ), # API endpoints for sales order allocations - re_path(r'^so-allocation/', include([ - path('/', SalesOrderAllocationDetail.as_view(), name='api-so-allocation-detail'), - re_path(r'^.*$', SalesOrderAllocationList.as_view(), name='api-so-allocation-list'), - ])), - + re_path( + r'^so-allocation/', + include([ + path( + '/', + SalesOrderAllocationDetail.as_view(), + name='api-so-allocation-detail', + ), + re_path( + r'^.*$', + SalesOrderAllocationList.as_view(), + name='api-so-allocation-list', + ), + ]), + ), # API endpoints for return orders - re_path(r'^ro/', include([ - - re_path(r'^attachment/', include([ - path('/', ReturnOrderAttachmentDetail.as_view(), name='api-return-order-attachment-detail'), - re_path(r'^.*$', ReturnOrderAttachmentList.as_view(), name='api-return-order-attachment-list'), - ])), - - # Return Order detail endpoints - path('/', include([ - re_path(r'cancel/', ReturnOrderCancel.as_view(), name='api-return-order-cancel'), - re_path(r'complete/', ReturnOrderComplete.as_view(), name='api-return-order-complete'), - re_path(r'issue/', ReturnOrderIssue.as_view(), name='api-return-order-issue'), - re_path(r'receive/', ReturnOrderReceive.as_view(), name='api-return-order-receive'), - re_path(r'metadata/', MetadataView.as_view(), {'model': models.ReturnOrder}, name='api-return-order-metadata'), - re_path(r'.*$', ReturnOrderDetail.as_view(), name='api-return-order-detail'), - ])), - - # Return order status code information - re_path(r'status/', StatusView.as_view(), {StatusView.MODEL_REF: ReturnOrderStatus}, name='api-return-order-status-codes'), - - # Return Order list - re_path(r'^.*$', ReturnOrderList.as_view(), name='api-return-order-list'), - ])), - + re_path( + r'^ro/', + include([ + re_path( + r'^attachment/', + include([ + path( + '/', + ReturnOrderAttachmentDetail.as_view(), + name='api-return-order-attachment-detail', + ), + re_path( + r'^.*$', + ReturnOrderAttachmentList.as_view(), + name='api-return-order-attachment-list', + ), + ]), + ), + # Return Order detail endpoints + path( + '/', + include([ + re_path( + r'cancel/', + ReturnOrderCancel.as_view(), + name='api-return-order-cancel', + ), + re_path( + r'complete/', + ReturnOrderComplete.as_view(), + name='api-return-order-complete', + ), + re_path( + r'issue/', + ReturnOrderIssue.as_view(), + name='api-return-order-issue', + ), + re_path( + r'receive/', + ReturnOrderReceive.as_view(), + name='api-return-order-receive', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'model': models.ReturnOrder}, + name='api-return-order-metadata', + ), + re_path( + r'.*$', + ReturnOrderDetail.as_view(), + name='api-return-order-detail', + ), + ]), + ), + # Return order status code information + re_path( + r'status/', + StatusView.as_view(), + {StatusView.MODEL_REF: ReturnOrderStatus}, + name='api-return-order-status-codes', + ), + # Return Order list + re_path(r'^.*$', ReturnOrderList.as_view(), name='api-return-order-list'), + ]), + ), # API endpoints for return order lines - re_path(r'^ro-line/', include([ - path('/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.ReturnOrderLineItem}, name='api-return-order-line-metadata'), - re_path(r'^.*$', ReturnOrderLineItemDetail.as_view(), name='api-return-order-line-detail'), - ])), - - # Return order line item status code information - re_path(r'status/', StatusView.as_view(), {StatusView.MODEL_REF: ReturnOrderLineStatus}, name='api-return-order-line-status-codes'), - - path('', ReturnOrderLineItemList.as_view(), name='api-return-order-line-list'), - ])), - + re_path( + r'^ro-line/', + include([ + path( + '/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.ReturnOrderLineItem}, + name='api-return-order-line-metadata', + ), + re_path( + r'^.*$', + ReturnOrderLineItemDetail.as_view(), + name='api-return-order-line-detail', + ), + ]), + ), + # Return order line item status code information + re_path( + r'status/', + StatusView.as_view(), + {StatusView.MODEL_REF: ReturnOrderLineStatus}, + name='api-return-order-line-status-codes', + ), + path( + '', ReturnOrderLineItemList.as_view(), name='api-return-order-line-list' + ), + ]), + ), # API endpoints for return order extra line - re_path(r'^ro-extra-line/', include([ - path('/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': models.ReturnOrderExtraLine}, name='api-return-order-extra-line-metadata'), - re_path(r'^.*$', ReturnOrderExtraLineDetail.as_view(), name='api-return-order-extra-line-detail'), - ])), - path('', ReturnOrderExtraLineList.as_view(), name='api-return-order-extra-line-list'), - ])), - + re_path( + r'^ro-extra-line/', + include([ + path( + '/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': models.ReturnOrderExtraLine}, + name='api-return-order-extra-line-metadata', + ), + re_path( + r'^.*$', + ReturnOrderExtraLineDetail.as_view(), + name='api-return-order-extra-line-detail', + ), + ]), + ), + path( + '', + ReturnOrderExtraLineList.as_view(), + name='api-return-order-extra-line-list', + ), + ]), + ), # API endpoint for subscribing to ICS calendar of purchase/sales/return orders - re_path(r'^calendar/(?Ppurchase-order|sales-order|return-order)/calendar.ics', OrderCalendarExport(), name='api-po-so-calendar'), + re_path( + r'^calendar/(?Ppurchase-order|sales-order|return-order)/calendar.ics', + OrderCalendarExport(), + name='api-po-so-calendar', + ), ] diff --git a/InvenTree/order/apps.py b/InvenTree/order/apps.py index e3bc0fc6d2..65a86fdb04 100644 --- a/InvenTree/order/apps.py +++ b/InvenTree/order/apps.py @@ -5,4 +5,5 @@ from django.apps import AppConfig class OrderConfig(AppConfig): """Configuration class for the 'order' app""" + name = 'order' diff --git a/InvenTree/order/forms.py b/InvenTree/order/forms.py index c38f50ed56..144057818a 100644 --- a/InvenTree/order/forms.py +++ b/InvenTree/order/forms.py @@ -17,14 +17,16 @@ class OrderMatchItemForm(MatchItemForm): if 'quantity' in col_guess.lower(): return forms.CharField( required=False, - widget=forms.NumberInput(attrs={ - 'name': 'quantity' + str(row['index']), - 'class': 'numberinput', - 'type': 'number', - 'min': '0', - 'step': 'any', - 'value': clean_decimal(row.get('quantity', '')), - }) + widget=forms.NumberInput( + attrs={ + 'name': 'quantity' + str(row['index']), + 'class': 'numberinput', + 'type': 'number', + 'min': '0', + 'step': 'any', + 'value': clean_decimal(row.get('quantity', '')), + } + ), ) # set price field elif 'price' in col_guess.lower(): diff --git a/InvenTree/order/models.py b/InvenTree/order/models.py index 20f7432fab..e139975f73 100644 --- a/InvenTree/order/models.py +++ b/InvenTree/order/models.py @@ -35,19 +35,31 @@ from common.settings import currency_code_default from company.models import Address, Company, Contact, SupplierPart from generic.states import StateTransitionMixin from InvenTree.exceptions import log_error -from InvenTree.fields import (InvenTreeModelMoneyField, InvenTreeURLField, - RoundingDecimalField) +from InvenTree.fields import ( + InvenTreeModelMoneyField, + InvenTreeURLField, + RoundingDecimalField, +) from InvenTree.helpers import decimal2string from InvenTree.helpers_model import getSetting, notify_responsible -from InvenTree.models import (InvenTreeAttachment, InvenTreeBarcodeMixin, - InvenTreeNotesMixin, MetadataMixin, - ReferenceIndexingMixin) -from InvenTree.status_codes import (PurchaseOrderStatus, - PurchaseOrderStatusGroups, - ReturnOrderLineStatus, ReturnOrderStatus, - ReturnOrderStatusGroups, SalesOrderStatus, - SalesOrderStatusGroups, StockHistoryCode, - StockStatus) +from InvenTree.models import ( + InvenTreeAttachment, + InvenTreeBarcodeMixin, + InvenTreeNotesMixin, + MetadataMixin, + ReferenceIndexingMixin, +) +from InvenTree.status_codes import ( + PurchaseOrderStatus, + PurchaseOrderStatusGroups, + ReturnOrderLineStatus, + ReturnOrderStatus, + ReturnOrderStatusGroups, + SalesOrderStatus, + SalesOrderStatusGroups, + StockHistoryCode, + StockStatus, +) from part import models as PartModels from plugin.events import trigger_event @@ -59,6 +71,7 @@ class TotalPriceMixin(models.Model): class Meta: """Meta for MetadataMixin.""" + abstract = True def save(self, *args, **kwargs): @@ -68,18 +81,20 @@ class TotalPriceMixin(models.Model): super().save(*args, **kwargs) total_price = InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, allow_negative=False, verbose_name=_('Total Price'), - help_text=_('Total price for this order') + help_text=_('Total price for this order'), ) order_currency = models.CharField( max_length=3, verbose_name=_('Order Currency'), - blank=True, null=True, + blank=True, + null=True, help_text=_('Currency for this order (leave blank to use company default)'), - validators=[InvenTree.validators.validate_currency_code] + validators=[InvenTree.validators.validate_currency_code], ) @property @@ -122,7 +137,6 @@ class TotalPriceMixin(models.Model): # order items for line in self.lines.all(): - if not line.price: continue @@ -140,7 +154,6 @@ class TotalPriceMixin(models.Model): # extra items for line in self.extra_lines.all(): - if not line.price: continue @@ -161,7 +174,13 @@ class TotalPriceMixin(models.Model): return total -class Order(StateTransitionMixin, InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, ReferenceIndexingMixin): +class Order( + StateTransitionMixin, + InvenTreeBarcodeMixin, + InvenTreeNotesMixin, + MetadataMixin, + ReferenceIndexingMixin, +): """Abstract model for an order. Instances of this class: @@ -182,6 +201,7 @@ class Order(StateTransitionMixin, InvenTreeBarcodeMixin, InvenTreeNotesMixin, Me class Meta: """Metaclass options. Abstract ensures no database table is created.""" + abstract = True def save(self, *args, **kwargs): @@ -204,7 +224,7 @@ class Order(StateTransitionMixin, InvenTreeBarcodeMixin, InvenTreeNotesMixin, Me if self.company and self.contact: if self.contact.company != self.company: raise ValidationError({ - "contact": _("Contact does not match selected company") + 'contact': _('Contact does not match selected company') }) @classmethod @@ -214,7 +234,11 @@ class Order(StateTransitionMixin, InvenTreeBarcodeMixin, InvenTreeNotesMixin, Me It requires any subclasses to implement the get_status_class() class method """ today = datetime.now().date() - return Q(status__in=cls.get_status_class().OPEN) & ~Q(target_date=None) & Q(target_date__lt=today) + return ( + Q(status__in=cls.get_status_class().OPEN) + & ~Q(target_date=None) + & Q(target_date__lt=today) + ) @property def is_overdue(self): @@ -222,37 +246,59 @@ class Order(StateTransitionMixin, InvenTreeBarcodeMixin, InvenTreeNotesMixin, Me Makes use of the overdue_filter() method to avoid code duplication """ - return self.__class__.objects.filter(pk=self.pk).filter(self.__class__.overdue_filter()).exists() + return ( + self.__class__.objects.filter(pk=self.pk) + .filter(self.__class__.overdue_filter()) + .exists() + ) - description = models.CharField(max_length=250, blank=True, verbose_name=_('Description'), help_text=_('Order description (optional)')) + description = models.CharField( + max_length=250, + blank=True, + verbose_name=_('Description'), + help_text=_('Order description (optional)'), + ) project_code = models.ForeignKey( - common_models.ProjectCode, on_delete=models.SET_NULL, - blank=True, null=True, - verbose_name=_('Project Code'), help_text=_('Select project code for this order') + common_models.ProjectCode, + on_delete=models.SET_NULL, + blank=True, + null=True, + verbose_name=_('Project Code'), + help_text=_('Select project code for this order'), ) - link = InvenTreeURLField(blank=True, verbose_name=_('Link'), help_text=_('Link to external page')) + link = InvenTreeURLField( + blank=True, verbose_name=_('Link'), help_text=_('Link to external page') + ) target_date = models.DateField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Target Date'), - help_text=_('Expected date for order delivery. Order will be overdue after this date.'), + help_text=_( + 'Expected date for order delivery. Order will be overdue after this date.' + ), ) - creation_date = models.DateField(blank=True, null=True, verbose_name=_('Creation Date')) + creation_date = models.DateField( + blank=True, null=True, verbose_name=_('Creation Date') + ) - created_by = models.ForeignKey(User, - on_delete=models.SET_NULL, - blank=True, null=True, - related_name='+', - verbose_name=_('Created By') - ) + created_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name='+', + verbose_name=_('Created By'), + ) responsible = models.ForeignKey( UserModels.Owner, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, help_text=_('User or group responsible for this order'), verbose_name=_('Responsible'), related_name='+', @@ -261,7 +307,8 @@ class Order(StateTransitionMixin, InvenTreeBarcodeMixin, InvenTreeNotesMixin, Me contact = models.ForeignKey( Contact, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Contact'), help_text=_('Point of contact for this order'), related_name='+', @@ -270,7 +317,8 @@ class Order(StateTransitionMixin, InvenTreeBarcodeMixin, InvenTreeNotesMixin, Me address = models.ForeignKey( Address, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Address'), help_text=_('Company address for this order'), related_name='+', @@ -279,7 +327,7 @@ class Order(StateTransitionMixin, InvenTreeBarcodeMixin, InvenTreeNotesMixin, Me @classmethod def get_status_class(cls): """Return the enumeration class which represents the 'status' field for this model""" - raise NotImplementedError(f"get_status_class() not implemented for {__class__}") + raise NotImplementedError(f'get_status_class() not implemented for {__class__}') class PurchaseOrder(TotalPriceMixin, Order): @@ -310,7 +358,7 @@ class PurchaseOrder(TotalPriceMixin, Order): def api_defaults(cls, request): """Return default values for this model when issuing an API OPTIONS request""" defaults = { - 'reference': order.validators.generate_next_purchase_order_reference(), + 'reference': order.validators.generate_next_purchase_order_reference() } return defaults @@ -342,10 +390,19 @@ class PurchaseOrder(TotalPriceMixin, Order): return queryset # Construct a queryset for "received" orders within the range - received = Q(status=PurchaseOrderStatus.COMPLETE.value) & Q(complete_date__gte=min_date) & Q(complete_date__lte=max_date) + received = ( + Q(status=PurchaseOrderStatus.COMPLETE.value) + & Q(complete_date__gte=min_date) + & Q(complete_date__lte=max_date) + ) # Construct a queryset for "pending" orders within the range - pending = Q(status__in=PurchaseOrderStatusGroups.OPEN) & ~Q(target_date=None) & Q(target_date__gte=min_date) & Q(target_date__lte=max_date) + pending = ( + Q(status__in=PurchaseOrderStatusGroups.OPEN) + & ~Q(target_date=None) + & Q(target_date__gte=min_date) + & Q(target_date__lte=max_date) + ) # TODO - Construct a queryset for "overdue" orders within the range @@ -364,13 +421,14 @@ class PurchaseOrder(TotalPriceMixin, Order): verbose_name=_('Reference'), help_text=_('Order reference'), default=order.validators.generate_next_purchase_order_reference, - validators=[ - order.validators.validate_purchase_order_reference, - ] + validators=[order.validators.validate_purchase_order_reference], ) - status = models.PositiveIntegerField(default=PurchaseOrderStatus.PENDING.value, choices=PurchaseOrderStatus.items(), - help_text=_('Purchase order status')) + status = models.PositiveIntegerField( + default=PurchaseOrderStatus.PENDING.value, + choices=PurchaseOrderStatus.items(), + help_text=_('Purchase order status'), + ) @property def status_text(self): @@ -378,14 +436,13 @@ class PurchaseOrder(TotalPriceMixin, Order): return PurchaseOrderStatus.text(self.status) supplier = models.ForeignKey( - Company, on_delete=models.SET_NULL, + Company, + on_delete=models.SET_NULL, null=True, - limit_choices_to={ - 'is_supplier': True, - }, + limit_choices_to={'is_supplier': True}, related_name='purchase_orders', verbose_name=_('Supplier'), - help_text=_('Company from which the items are being ordered') + help_text=_('Company from which the items are being ordered'), ) @property @@ -393,30 +450,45 @@ class PurchaseOrder(TotalPriceMixin, Order): """Accessor helper for Order base class""" return self.supplier - supplier_reference = models.CharField(max_length=64, blank=True, verbose_name=_('Supplier Reference'), help_text=_("Supplier order reference code")) + supplier_reference = models.CharField( + max_length=64, + blank=True, + verbose_name=_('Supplier Reference'), + help_text=_('Supplier order reference code'), + ) received_by = models.ForeignKey( User, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, related_name='+', - verbose_name=_('received by') + verbose_name=_('received by'), ) issue_date = models.DateField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Issue Date'), - help_text=_('Date order was issued') + help_text=_('Date order was issued'), ) complete_date = models.DateField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Completion Date'), - help_text=_('Date order was completed') + help_text=_('Date order was completed'), ) @transaction.atomic - def add_line_item(self, supplier_part, quantity, group: bool = True, reference: str = '', purchase_price=None): + def add_line_item( + self, + supplier_part, + quantity, + group: bool = True, + reference: str = '', + purchase_price=None, + ): """Add a new line item to this purchase order. This function will check that: @@ -442,12 +514,15 @@ class PurchaseOrder(TotalPriceMixin, Order): quantity = int(quantity) if quantity <= 0: raise ValidationError({ - 'quantity': _("Quantity must be greater than zero")}) + 'quantity': _('Quantity must be greater than zero') + }) except ValueError: - raise ValidationError({'quantity': _("Invalid quantity provided")}) + raise ValidationError({'quantity': _('Invalid quantity provided')}) if supplier_part.supplier != self.supplier: - raise ValidationError({'supplier': _("Part supplier must match PO supplier")}) + raise ValidationError({ + 'supplier': _('Part supplier must match PO supplier') + }) if group: # Check if there is already a matching line item (for this PurchaseOrder) @@ -498,7 +573,7 @@ class PurchaseOrder(TotalPriceMixin, Order): self, PurchaseOrder, exclude=self.created_by, - content=InvenTreeNotificationBodies.NewOrder + content=InvenTreeNotificationBodies.NewOrder, ) def _action_complete(self, *args, **kwargs): @@ -522,17 +597,23 @@ class PurchaseOrder(TotalPriceMixin, Order): @transaction.atomic def place_order(self): """Attempt to transition to PLACED status.""" - return self.handle_transition(self.status, PurchaseOrderStatus.PLACED.value, self, self._action_place) + return self.handle_transition( + self.status, PurchaseOrderStatus.PLACED.value, self, self._action_place + ) @transaction.atomic def complete_order(self): """Attempt to transition to COMPLETE status.""" - return self.handle_transition(self.status, PurchaseOrderStatus.COMPLETE.value, self, self._action_complete) + return self.handle_transition( + self.status, PurchaseOrderStatus.COMPLETE.value, self, self._action_complete + ) @transaction.atomic def cancel_order(self): """Attempt to transition to CANCELLED status.""" - return self.handle_transition(self.status, PurchaseOrderStatus.CANCELLED.value, self, self._action_cancel) + return self.handle_transition( + self.status, PurchaseOrderStatus.CANCELLED.value, self, self._action_cancel + ) @property def is_pending(self): @@ -553,7 +634,7 @@ class PurchaseOrder(TotalPriceMixin, Order): """ return self.status in [ PurchaseOrderStatus.PLACED.value, - PurchaseOrderStatus.PENDING.value + PurchaseOrderStatus.PENDING.value, ] def _action_cancel(self, *args, **kwargs): @@ -569,8 +650,9 @@ class PurchaseOrder(TotalPriceMixin, Order): self, PurchaseOrder, exclude=self.created_by, - content=InvenTreeNotificationBodies.OrderCanceled + content=InvenTreeNotificationBodies.OrderCanceled, ) + # endregion def pending_line_items(self): @@ -605,7 +687,9 @@ class PurchaseOrder(TotalPriceMixin, Order): return self.lines.count() > 0 and self.pending_line_items().count() == 0 @transaction.atomic - def receive_line_item(self, line, location, quantity, user, status=StockStatus.OK.value, **kwargs): + def receive_line_item( + self, line, location, quantity, user, status=StockStatus.OK.value, **kwargs + ): """Receive a line item (or partial line item) against this PurchaseOrder.""" # Extract optional batch code for the new stock item batch_code = kwargs.get('batch_code', '') @@ -631,17 +715,14 @@ class PurchaseOrder(TotalPriceMixin, Order): try: if quantity < 0: raise ValidationError({ - "quantity": _("Quantity must be a positive number") + 'quantity': _('Quantity must be a positive number') }) quantity = InvenTree.helpers.clean_decimal(quantity) except TypeError: - raise ValidationError({ - "quantity": _("Invalid quantity provided") - }) + raise ValidationError({'quantity': _('Invalid quantity provided')}) # Create a new stock item if line.part and quantity > 0: - # Calculate received quantity in base units stock_quantity = line.part.base_quantity(quantity) @@ -660,7 +741,6 @@ class PurchaseOrder(TotalPriceMixin, Order): serials = [None] for sn in serials: - item = stock.models.StockItem( part=line.part.part, supplier_part=line.part, @@ -670,22 +750,16 @@ class PurchaseOrder(TotalPriceMixin, Order): status=status, batch=batch_code, serial=sn, - purchase_price=unit_purchase_price + purchase_price=unit_purchase_price, ) # Assign the provided barcode if barcode: - item.assign_barcode( - barcode_data=barcode, - save=False - ) + item.assign_barcode(barcode_data=barcode, save=False) item.save(add_note=False) - tracking_info = { - 'status': status, - 'purchaseorder': self.pk, - } + tracking_info = {'status': status, 'purchaseorder': self.pk} item.add_tracking_entry( StockHistoryCode.RECEIVED_AGAINST_PURCHASE_ORDER, @@ -694,7 +768,7 @@ class PurchaseOrder(TotalPriceMixin, Order): deltas=tracking_info, location=location, purchaseorder=self, - quantity=quantity + quantity=quantity, ) # Update the number of parts received against the particular line item @@ -704,7 +778,6 @@ class PurchaseOrder(TotalPriceMixin, Order): # Has this order been completed? if len(self.pending_line_items()) == 0: - self.received_by = user self.complete_order() # This will save the model @@ -737,9 +810,7 @@ class SalesOrder(TotalPriceMixin, Order): @classmethod def api_defaults(cls, request): """Return default values for this model when issuing an API OPTIONS request""" - defaults = { - 'reference': order.validators.generate_next_sales_order_reference(), - } + defaults = {'reference': order.validators.generate_next_sales_order_reference()} return defaults @@ -770,10 +841,19 @@ class SalesOrder(TotalPriceMixin, Order): return queryset # Construct a queryset for "completed" orders within the range - completed = Q(status__in=SalesOrderStatusGroups.COMPLETE) & Q(shipment_date__gte=min_date) & Q(shipment_date__lte=max_date) + completed = ( + Q(status__in=SalesOrderStatusGroups.COMPLETE) + & Q(shipment_date__gte=min_date) + & Q(shipment_date__lte=max_date) + ) # Construct a queryset for "pending" orders within the range - pending = Q(status__in=SalesOrderStatusGroups.OPEN) & ~Q(target_date=None) & Q(target_date__gte=min_date) & Q(target_date__lte=max_date) + pending = ( + Q(status__in=SalesOrderStatusGroups.OPEN) + & ~Q(target_date=None) + & Q(target_date__gte=min_date) + & Q(target_date__lte=max_date) + ) # TODO: Construct a queryset for "overdue" orders within the range @@ -792,9 +872,7 @@ class SalesOrder(TotalPriceMixin, Order): verbose_name=_('Reference'), help_text=_('Order reference'), default=order.validators.generate_next_sales_order_reference, - validators=[ - order.validators.validate_sales_order_reference, - ] + validators=[order.validators.validate_sales_order_reference], ) customer = models.ForeignKey( @@ -804,7 +882,7 @@ class SalesOrder(TotalPriceMixin, Order): limit_choices_to={'is_customer': True}, related_name='return_orders', verbose_name=_('Customer'), - help_text=_("Company to which the items are being sold"), + help_text=_('Company to which the items are being sold'), ) @property @@ -815,7 +893,8 @@ class SalesOrder(TotalPriceMixin, Order): status = models.PositiveIntegerField( default=SalesOrderStatus.PENDING.value, choices=SalesOrderStatus.items(), - verbose_name=_('Status'), help_text=_('Purchase order status') + verbose_name=_('Status'), + help_text=_('Purchase order status'), ) @property @@ -823,16 +902,24 @@ class SalesOrder(TotalPriceMixin, Order): """Return the text representation of the status field""" return SalesOrderStatus.text(self.status) - customer_reference = models.CharField(max_length=64, blank=True, verbose_name=_('Customer Reference '), help_text=_("Customer order reference code")) + customer_reference = models.CharField( + max_length=64, + blank=True, + verbose_name=_('Customer Reference '), + help_text=_('Customer order reference code'), + ) - shipment_date = models.DateField(blank=True, null=True, verbose_name=_('Shipment Date')) + shipment_date = models.DateField( + blank=True, null=True, verbose_name=_('Shipment Date') + ) shipped_by = models.ForeignKey( User, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, related_name='+', - verbose_name=_('shipped by') + verbose_name=_('shipped by'), ) @property @@ -870,7 +957,9 @@ class SalesOrder(TotalPriceMixin, Order): def is_completed(self): """Check if this order is "shipped" (all line items delivered).""" - return self.lines.count() > 0 and all(line.is_completed() for line in self.lines.all()) + return self.lines.count() > 0 and all( + line.is_completed() for line in self.lines.all() + ) def can_complete(self, raise_error=False, allow_incomplete_lines=False): """Test if this SalesOrder can be completed. @@ -878,23 +967,27 @@ class SalesOrder(TotalPriceMixin, Order): Throws a ValidationError if cannot be completed. """ try: - # Order without line items cannot be completed if self.lines.count() == 0: - raise ValidationError(_('Order cannot be completed as no parts have been assigned')) + raise ValidationError( + _('Order cannot be completed as no parts have been assigned') + ) # Only an open order can be marked as shipped elif not self.is_open: raise ValidationError(_('Only an open order can be marked as complete')) elif self.pending_shipment_count > 0: - raise ValidationError(_("Order cannot be completed as there are incomplete shipments")) + raise ValidationError( + _('Order cannot be completed as there are incomplete shipments') + ) elif not allow_incomplete_lines and self.pending_line_count > 0: - raise ValidationError(_("Order cannot be completed as there are incomplete line items")) + raise ValidationError( + _('Order cannot be completed as there are incomplete line items') + ) except ValidationError as e: - if raise_error: raise e else: @@ -966,7 +1059,7 @@ class SalesOrder(TotalPriceMixin, Order): self, SalesOrder, exclude=self.created_by, - content=InvenTreeNotificationBodies.OrderCanceled + content=InvenTreeNotificationBodies.OrderCanceled, ) return True @@ -974,17 +1067,29 @@ class SalesOrder(TotalPriceMixin, Order): @transaction.atomic def issue_order(self): """Attempt to transition to IN_PROGRESS status.""" - return self.handle_transition(self.status, SalesOrderStatus.IN_PROGRESS.value, self, self._action_place) + return self.handle_transition( + self.status, SalesOrderStatus.IN_PROGRESS.value, self, self._action_place + ) @transaction.atomic def complete_order(self, user, **kwargs): """Attempt to transition to SHIPPED status.""" - return self.handle_transition(self.status, SalesOrderStatus.SHIPPED.value, self, self._action_complete, user=user, **kwargs) + return self.handle_transition( + self.status, + SalesOrderStatus.SHIPPED.value, + self, + self._action_complete, + user=user, + **kwargs, + ) @transaction.atomic def cancel_order(self): """Attempt to transition to CANCELLED status.""" - return self.handle_transition(self.status, SalesOrderStatus.CANCELLED.value, self, self._action_cancel) + return self.handle_transition( + self.status, SalesOrderStatus.CANCELLED.value, self, self._action_cancel + ) + # endregion @property @@ -1042,7 +1147,10 @@ def after_save_sales_order(sender, instance: SalesOrder, created: bool, **kwargs - Ignore if the database is not ready for access - Ignore if data import is active """ - if not InvenTree.ready.canAppAccessDatabase(allow_test=True) or InvenTree.ready.isImportingData(): + if ( + not InvenTree.ready.canAppAccessDatabase(allow_test=True) + or InvenTree.ready.isImportingData() + ): return if created: @@ -1050,10 +1158,7 @@ def after_save_sales_order(sender, instance: SalesOrder, created: bool, **kwargs if getSetting('SALESORDER_DEFAULT_SHIPMENT'): # Create default shipment - SalesOrderShipment.objects.create( - order=instance, - reference='1', - ) + SalesOrderShipment.objects.create(order=instance, reference='1') # Notify the responsible users that the sales order has been created notify_responsible(instance, sender, exclude=instance.created_by) @@ -1069,9 +1174,11 @@ class PurchaseOrderAttachment(InvenTreeAttachment): def getSubdir(self): """Return the directory path where PurchaseOrderAttachment files are located""" - return os.path.join("po_files", str(self.order.id)) + return os.path.join('po_files', str(self.order.id)) - order = models.ForeignKey(PurchaseOrder, on_delete=models.CASCADE, related_name="attachments") + order = models.ForeignKey( + PurchaseOrder, on_delete=models.CASCADE, related_name='attachments' + ) class SalesOrderAttachment(InvenTreeAttachment): @@ -1084,9 +1191,11 @@ class SalesOrderAttachment(InvenTreeAttachment): def getSubdir(self): """Return the directory path where SalesOrderAttachment files are located""" - return os.path.join("so_files", str(self.order.id)) + return os.path.join('so_files', str(self.order.id)) - order = models.ForeignKey(SalesOrder, on_delete=models.CASCADE, related_name='attachments') + order = models.ForeignKey( + SalesOrder, on_delete=models.CASCADE, related_name='attachments' + ) class OrderLineItem(MetadataMixin, models.Model): @@ -1101,6 +1210,7 @@ class OrderLineItem(MetadataMixin, models.Model): class Meta: """Metaclass options. Abstract ensures no database table is created.""" + abstract = True def save(self, *args, **kwargs): @@ -1123,7 +1233,8 @@ class OrderLineItem(MetadataMixin, models.Model): verbose_name=_('Quantity'), help_text=_('Item quantity'), default=1, - max_digits=15, decimal_places=5, + max_digits=15, + decimal_places=5, validators=[MinValueValidator(0)], ) @@ -1133,20 +1244,31 @@ class OrderLineItem(MetadataMixin, models.Model): if self.price: return self.quantity * self.price - reference = models.CharField(max_length=100, blank=True, verbose_name=_('Reference'), help_text=_('Line item reference')) + reference = models.CharField( + max_length=100, + blank=True, + verbose_name=_('Reference'), + help_text=_('Line item reference'), + ) - notes = models.CharField(max_length=500, blank=True, verbose_name=_('Notes'), help_text=_('Line item notes')) + notes = models.CharField( + max_length=500, + blank=True, + verbose_name=_('Notes'), + help_text=_('Line item notes'), + ) link = InvenTreeURLField( - blank=True, - verbose_name=_('Link'), - help_text=_('Link to external page') + blank=True, verbose_name=_('Link'), help_text=_('Link to external page') ) target_date = models.DateField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Target Date'), - help_text=_('Target date for this line item (leave blank to use the target date from the order)'), + help_text=_( + 'Target date for this line item (leave blank to use the target date from the order)' + ), ) @@ -1159,16 +1281,19 @@ class OrderExtraLine(OrderLineItem): class Meta: """Metaclass options. Abstract ensures no database table is created.""" + abstract = True description = models.CharField( - max_length=250, blank=True, + max_length=250, + blank=True, verbose_name=_('Description'), - help_text=_('Line item description (optional)') + help_text=_('Line item description (optional)'), ) context = models.JSONField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Context'), help_text=_('Additional context for this line'), ) @@ -1176,7 +1301,8 @@ class OrderExtraLine(OrderLineItem): price = InvenTreeModelMoneyField( max_digits=19, decimal_places=6, - null=True, blank=True, + null=True, + blank=True, allow_negative=True, verbose_name=_('Price'), help_text=_('Unit price'), @@ -1191,7 +1317,11 @@ class PurchaseOrderLineItem(OrderLineItem): """ # Filter for determining if a particular PurchaseOrderLineItem is overdue - OVERDUE_FILTER = Q(received__lt=F('quantity')) & ~Q(target_date=None) & Q(target_date__lt=datetime.now().date()) + OVERDUE_FILTER = ( + Q(received__lt=F('quantity')) + & ~Q(target_date=None) + & Q(target_date__lt=datetime.now().date()) + ) @staticmethod def get_api_url(): @@ -1208,23 +1338,23 @@ class PurchaseOrderLineItem(OrderLineItem): if self.order.supplier and self.part: # Supplier part *must* point to the same supplier! if self.part.supplier != self.order.supplier: - raise ValidationError({ - 'part': _('Supplier part must match supplier') - }) + raise ValidationError({'part': _('Supplier part must match supplier')}) def __str__(self): """Render a string representation of a PurchaseOrderLineItem instance""" - return "{n} x {part} from {supplier} (for {po})".format( + return '{n} x {part} from {supplier} (for {po})'.format( n=decimal2string(self.quantity), part=self.part.SKU if self.part else 'unknown part', supplier=self.order.supplier.name if self.order.supplier else _('deleted'), - po=self.order) + po=self.order, + ) order = models.ForeignKey( - PurchaseOrder, on_delete=models.CASCADE, + PurchaseOrder, + on_delete=models.CASCADE, related_name='lines', verbose_name=_('Order'), - help_text=_('Purchase Order') + help_text=_('Purchase Order'), ) def get_base_part(self): @@ -1237,11 +1367,13 @@ class PurchaseOrderLineItem(OrderLineItem): return self.part.part part = models.ForeignKey( - SupplierPart, on_delete=models.SET_NULL, - blank=False, null=True, + SupplierPart, + on_delete=models.SET_NULL, + blank=False, + null=True, related_name='purchase_order_line_items', verbose_name=_('Part'), - help_text=_("Supplier part"), + help_text=_('Supplier part'), ) received = models.DecimalField( @@ -1249,13 +1381,14 @@ class PurchaseOrderLineItem(OrderLineItem): max_digits=15, default=0, verbose_name=_('Received'), - help_text=_('Number of items received') + help_text=_('Number of items received'), ) purchase_price = InvenTreeModelMoneyField( max_digits=19, decimal_places=6, - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Purchase Price'), help_text=_('Unit purchase price'), ) @@ -1266,11 +1399,13 @@ class PurchaseOrderLineItem(OrderLineItem): return self.purchase_price destination = TreeForeignKey( - 'stock.StockLocation', on_delete=models.SET_NULL, + 'stock.StockLocation', + on_delete=models.SET_NULL, verbose_name=_('Destination'), related_name='po_lines', - blank=True, null=True, - help_text=_('Where does the Purchaser want this item to be stored?') + blank=True, + null=True, + help_text=_('Where does the Purchaser want this item to be stored?'), ) def get_destination(self): @@ -1280,7 +1415,9 @@ class PurchaseOrderLineItem(OrderLineItem): stock items location will be reported as the location for the entire line. """ - for item in stock.models.StockItem.objects.filter(supplier_part=self.part, purchase_order=self.order): + for item in stock.models.StockItem.objects.filter( + supplier_part=self.part, purchase_order=self.order + ): if item.location: return item.location if self.destination: @@ -1302,12 +1439,19 @@ class PurchaseOrderExtraLine(OrderExtraLine): title: title of line price: The unit price for this OrderLine """ + @staticmethod def get_api_url(): """Return the API URL associated with the PurchaseOrderExtraLine model""" return reverse('api-po-extra-line-list') - order = models.ForeignKey(PurchaseOrder, on_delete=models.CASCADE, related_name='extra_lines', verbose_name=_('Order'), help_text=_('Purchase Order')) + order = models.ForeignKey( + PurchaseOrder, + on_delete=models.CASCADE, + related_name='extra_lines', + verbose_name=_('Order'), + help_text=_('Purchase Order'), + ) class SalesOrderLineItem(OrderLineItem): @@ -1321,7 +1465,11 @@ class SalesOrderLineItem(OrderLineItem): """ # Filter for determining if a particular SalesOrderLineItem is overdue - OVERDUE_FILTER = Q(shipped__lt=F('quantity')) & ~Q(target_date=None) & Q(target_date__lt=datetime.now().date()) + OVERDUE_FILTER = ( + Q(shipped__lt=F('quantity')) + & ~Q(target_date=None) + & Q(target_date__lt=datetime.now().date()) + ) @staticmethod def get_api_url(): @@ -1335,12 +1483,12 @@ class SalesOrderLineItem(OrderLineItem): if self.part: if self.part.virtual: raise ValidationError({ - 'part': _("Virtual part cannot be assigned to a sales order") + 'part': _('Virtual part cannot be assigned to a sales order') }) if not self.part.salable: raise ValidationError({ - 'part': _("Only salable parts can be assigned to a sales order") + 'part': _('Only salable parts can be assigned to a sales order') }) order = models.ForeignKey( @@ -1348,24 +1496,24 @@ class SalesOrderLineItem(OrderLineItem): on_delete=models.CASCADE, related_name='lines', verbose_name=_('Order'), - help_text=_('Sales Order') + help_text=_('Sales Order'), ) part = models.ForeignKey( - 'part.Part', on_delete=models.SET_NULL, + 'part.Part', + on_delete=models.SET_NULL, related_name='sales_order_line_items', null=True, verbose_name=_('Part'), help_text=_('Part'), - limit_choices_to={ - 'salable': True, - 'virtual': False, - }) + limit_choices_to={'salable': True, 'virtual': False}, + ) sale_price = InvenTreeModelMoneyField( max_digits=19, decimal_places=6, - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Sale Price'), help_text=_('Unit sale price'), ) @@ -1379,13 +1527,16 @@ class SalesOrderLineItem(OrderLineItem): verbose_name=_('Shipped'), help_text=_('Shipped quantity'), default=0, - max_digits=15, decimal_places=5, - validators=[MinValueValidator(0)] + max_digits=15, + decimal_places=5, + validators=[MinValueValidator(0)], ) def fulfilled_quantity(self): """Return the total stock quantity fulfilled against this line item.""" - query = self.order.stock_items.filter(part=self.part).aggregate(fulfilled=Coalesce(Sum('quantity'), Decimal(0))) + query = self.order.stock_items.filter(part=self.part).aggregate( + fulfilled=Coalesce(Sum('quantity'), Decimal(0)) + ) return query['fulfilled'] @@ -1394,7 +1545,9 @@ class SalesOrderLineItem(OrderLineItem): This is a summation of the quantity of each attached StockItem """ - query = self.allocations.aggregate(allocated=Coalesce(Sum('quantity'), Decimal(0))) + query = self.allocations.aggregate( + allocated=Coalesce(Sum('quantity'), Decimal(0)) + ) return query['allocated'] @@ -1431,10 +1584,9 @@ class SalesOrderShipment(InvenTreeNotesMixin, MetadataMixin, models.Model): class Meta: """Metaclass defines extra model options""" + # Shipment reference must be unique for a given sales order - unique_together = [ - 'order', 'reference', - ] + unique_together = ['order', 'reference'] @staticmethod def get_api_url(): @@ -1444,20 +1596,23 @@ class SalesOrderShipment(InvenTreeNotesMixin, MetadataMixin, models.Model): order = models.ForeignKey( SalesOrder, on_delete=models.CASCADE, - blank=False, null=False, + blank=False, + null=False, related_name='shipments', verbose_name=_('Order'), help_text=_('Sales Order'), ) shipment_date = models.DateField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Shipment Date'), help_text=_('Date of shipment'), ) delivery_date = models.DateField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Delivery Date'), help_text=_('Date of delivery of shipment'), ) @@ -1465,7 +1620,8 @@ class SalesOrderShipment(InvenTreeNotesMixin, MetadataMixin, models.Model): checked_by = models.ForeignKey( User, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Checked By'), help_text=_('User who checked this shipment'), related_name='+', @@ -1496,9 +1652,7 @@ class SalesOrderShipment(InvenTreeNotesMixin, MetadataMixin, models.Model): ) link = InvenTreeURLField( - blank=True, - verbose_name=_('Link'), - help_text=_('Link to external page') + blank=True, verbose_name=_('Link'), help_text=_('Link to external page') ) def is_complete(self): @@ -1514,10 +1668,10 @@ class SalesOrderShipment(InvenTreeNotesMixin, MetadataMixin, models.Model): try: if self.shipment_date: # Shipment has already been sent! - raise ValidationError(_("Shipment has already been sent")) + raise ValidationError(_('Shipment has already been sent')) if self.allocations.count() == 0: - raise ValidationError(_("Shipment has no allocated stock items")) + raise ValidationError(_('Shipment has no allocated stock items')) except ValidationError as e: if raise_error: @@ -1587,15 +1741,18 @@ class SalesOrderExtraLine(OrderExtraLine): title: title of line price: The unit price for this OrderLine """ + @staticmethod def get_api_url(): """Return the API URL associated with the SalesOrderExtraLine model""" return reverse('api-so-extra-line-list') order = models.ForeignKey( - SalesOrder, on_delete=models.CASCADE, + SalesOrder, + on_delete=models.CASCADE, related_name='extra_lines', - verbose_name=_('Order'), help_text=_('Sales Order') + verbose_name=_('Order'), + help_text=_('Sales Order'), ) @@ -1638,7 +1795,9 @@ class SalesOrderAllocation(models.Model): if self.line.part != self.item.part: variants = self.line.part.get_descendants(include_self=True) if self.line.part not in variants: - errors['item'] = _('Cannot allocate stock item to a line with a different part') + errors['item'] = _( + 'Cannot allocate stock item to a line with a different part' + ) except PartModels.Part.DoesNotExist: errors['line'] = _('Cannot allocate stock to a line without a part') @@ -1648,12 +1807,12 @@ class SalesOrderAllocation(models.Model): # Ensure that we do not 'over allocate' a stock item build_allocation_count = self.item.build_allocation_count() sales_allocation_count = self.item.sales_order_allocation_count( - exclude_allocations={ - "pk": self.pk, - } + exclude_allocations={'pk': self.pk} ) - total_allocation = build_allocation_count + sales_allocation_count + self.quantity + total_allocation = ( + build_allocation_count + sales_allocation_count + self.quantity + ) if total_allocation > self.item.quantity: errors['quantity'] = _('Stock item is over-allocated') @@ -1675,7 +1834,7 @@ class SalesOrderAllocation(models.Model): SalesOrderLineItem, on_delete=models.CASCADE, verbose_name=_('Line'), - related_name='allocations' + related_name='allocations', ) shipment = models.ForeignKey( @@ -1697,10 +1856,17 @@ class SalesOrderAllocation(models.Model): 'sales_order': None, }, verbose_name=_('Item'), - help_text=_('Select stock item to allocate') + help_text=_('Select stock item to allocate'), ) - quantity = RoundingDecimalField(max_digits=15, decimal_places=5, validators=[MinValueValidator(0)], default=1, verbose_name=_('Quantity'), help_text=_('Enter stock allocation quantity')) + quantity = RoundingDecimalField( + max_digits=15, + decimal_places=5, + validators=[MinValueValidator(0)], + default=1, + verbose_name=_('Quantity'), + help_text=_('Enter stock allocation quantity'), + ) def get_location(self): """Return the value of the location associated with this allocation""" @@ -1720,10 +1886,7 @@ class SalesOrderAllocation(models.Model): order = self.line.order item = self.item.allocateToCustomer( - order.customer, - quantity=self.quantity, - order=order, - user=user + order.customer, quantity=self.quantity, order=order, user=user ) # Update the 'shipped' quantity @@ -1763,7 +1926,7 @@ class ReturnOrder(TotalPriceMixin, Order): def api_defaults(cls, request): """Return default values for this model when issuing an API OPTIONS request""" defaults = { - 'reference': order.validators.generate_next_return_order_reference(), + 'reference': order.validators.generate_next_return_order_reference() } return defaults @@ -1781,9 +1944,7 @@ class ReturnOrder(TotalPriceMixin, Order): verbose_name=_('Reference'), help_text=_('Return Order reference'), default=order.validators.generate_next_return_order_reference, - validators=[ - order.validators.validate_return_order_reference, - ] + validators=[order.validators.validate_return_order_reference], ) customer = models.ForeignKey( @@ -1793,7 +1954,7 @@ class ReturnOrder(TotalPriceMixin, Order): limit_choices_to={'is_customer': True}, related_name='sales_orders', verbose_name=_('Customer'), - help_text=_("Company from which items are being returned"), + help_text=_('Company from which items are being returned'), ) @property @@ -1804,25 +1965,29 @@ class ReturnOrder(TotalPriceMixin, Order): status = models.PositiveIntegerField( default=ReturnOrderStatus.PENDING.value, choices=ReturnOrderStatus.items(), - verbose_name=_('Status'), help_text=_('Return order status') + verbose_name=_('Status'), + help_text=_('Return order status'), ) customer_reference = models.CharField( - max_length=64, blank=True, + max_length=64, + blank=True, verbose_name=_('Customer Reference '), - help_text=_("Customer order reference code") + help_text=_('Customer order reference code'), ) issue_date = models.DateField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Issue Date'), - help_text=_('Date order was issued') + help_text=_('Date order was issued'), ) complete_date = models.DateField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Completion Date'), - help_text=_('Date order was completed') + help_text=_('Date order was completed'), ) # region state changes @@ -1854,7 +2019,7 @@ class ReturnOrder(TotalPriceMixin, Order): self, ReturnOrder, exclude=self.created_by, - content=InvenTreeNotificationBodies.OrderCanceled + content=InvenTreeNotificationBodies.OrderCanceled, ) def _action_complete(self, *args, **kwargs): @@ -1882,17 +2047,24 @@ class ReturnOrder(TotalPriceMixin, Order): @transaction.atomic def issue_order(self): """Attempt to transition to IN_PROGRESS status.""" - return self.handle_transition(self.status, ReturnOrderStatus.IN_PROGRESS.value, self, self._action_place) + return self.handle_transition( + self.status, ReturnOrderStatus.IN_PROGRESS.value, self, self._action_place + ) @transaction.atomic def complete_order(self): """Attempt to transition to COMPLETE status.""" - return self.handle_transition(self.status, ReturnOrderStatus.COMPLETE.value, self, self._action_complete) + return self.handle_transition( + self.status, ReturnOrderStatus.COMPLETE.value, self, self._action_complete + ) @transaction.atomic def cancel_order(self): """Attempt to transition to CANCELLED status.""" - return self.handle_transition(self.status, ReturnOrderStatus.CANCELLED.value, self, self._action_cancel) + return self.handle_transition( + self.status, ReturnOrderStatus.CANCELLED.value, self, self._action_cancel + ) + # endregion @transaction.atomic @@ -1906,7 +2078,7 @@ class ReturnOrder(TotalPriceMixin, Order): """ # Prevent an item from being "received" multiple times if line.received_date is not None: - logger.warning("receive_line_item called with item already returned") + logger.warning('receive_line_item called with item already returned') return stock_item = line.item @@ -1959,9 +2131,7 @@ class ReturnOrderLineItem(OrderLineItem): class Meta: """Metaclass options for this model""" - unique_together = [ - ('order', 'item'), - ] + unique_together = [('order', 'item')] @staticmethod def get_api_url(): @@ -1974,7 +2144,7 @@ class ReturnOrderLineItem(OrderLineItem): if self.item and not self.item.serialized: raise ValidationError({ - 'item': _("Only serialized items can be assigned to a Return Order"), + 'item': _('Only serialized items can be assigned to a Return Order') }) order = models.ForeignKey( @@ -1990,11 +2160,12 @@ class ReturnOrderLineItem(OrderLineItem): on_delete=models.CASCADE, related_name='return_order_lines', verbose_name=_('Item'), - help_text=_('Select item to return from customer') + help_text=_('Select item to return from customer'), ) received_date = models.DateField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Received Date'), help_text=_('The date this this return item was received'), ) @@ -2007,11 +2178,13 @@ class ReturnOrderLineItem(OrderLineItem): outcome = models.PositiveIntegerField( default=ReturnOrderLineStatus.PENDING.value, choices=ReturnOrderLineStatus.items(), - verbose_name=_('Outcome'), help_text=_('Outcome for this line item') + verbose_name=_('Outcome'), + help_text=_('Outcome for this line item'), ) price = InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Price'), help_text=_('Cost associated with return or repair for this line item'), ) @@ -2026,9 +2199,11 @@ class ReturnOrderExtraLine(OrderExtraLine): return reverse('api-return-order-extra-line-list') order = models.ForeignKey( - ReturnOrder, on_delete=models.CASCADE, + ReturnOrder, + on_delete=models.CASCADE, related_name='extra_lines', - verbose_name=_('Order'), help_text=_('Return Order') + verbose_name=_('Order'), + help_text=_('Return Order'), ) @@ -2045,7 +2220,5 @@ class ReturnOrderAttachment(InvenTreeAttachment): return os.path.join('return_files', str(self.order.id)) order = models.ForeignKey( - ReturnOrder, - on_delete=models.CASCADE, - related_name='attachments', + ReturnOrder, on_delete=models.CASCADE, related_name='attachments' ) diff --git a/InvenTree/order/serializers.py b/InvenTree/order/serializers.py index 2d2b3a53d0..81630af7e6 100644 --- a/InvenTree/order/serializers.py +++ b/InvenTree/order/serializers.py @@ -5,8 +5,7 @@ from decimal import Decimal from django.core.exceptions import ValidationError as DjangoValidationError from django.db import models, transaction -from django.db.models import (BooleanField, Case, ExpressionWrapper, F, Q, - Value, When) +from django.db.models import BooleanField, Case, ExpressionWrapper, F, Q, Value, When from django.utils.translation import gettext_lazy as _ from rest_framework import serializers @@ -18,19 +17,27 @@ import part.filters import stock.models import stock.serializers from common.serializers import ProjectCodeSerializer -from company.serializers import (AddressBriefSerializer, - CompanyBriefSerializer, ContactSerializer, - SupplierPartSerializer) -from InvenTree.helpers import (extract_serial_numbers, hash_barcode, normalize, - str2bool) -from InvenTree.serializers import (InvenTreeAttachmentSerializer, - InvenTreeCurrencySerializer, - InvenTreeDecimalField, - InvenTreeModelSerializer, - InvenTreeMoneySerializer) -from InvenTree.status_codes import (PurchaseOrderStatusGroups, - ReturnOrderLineStatus, ReturnOrderStatus, - SalesOrderStatusGroups, StockStatus) +from company.serializers import ( + AddressBriefSerializer, + CompanyBriefSerializer, + ContactSerializer, + SupplierPartSerializer, +) +from InvenTree.helpers import extract_serial_numbers, hash_barcode, normalize, str2bool +from InvenTree.serializers import ( + InvenTreeAttachmentSerializer, + InvenTreeCurrencySerializer, + InvenTreeDecimalField, + InvenTreeModelSerializer, + InvenTreeMoneySerializer, +) +from InvenTree.status_codes import ( + PurchaseOrderStatusGroups, + ReturnOrderLineStatus, + ReturnOrderStatus, + SalesOrderStatusGroups, + StockStatus, +) from part.serializers import PartBriefSerializer from users.serializers import OwnerSerializer @@ -38,10 +45,7 @@ from users.serializers import OwnerSerializer class TotalPriceMixin(serializers.Serializer): """Serializer mixin which provides total price fields""" - total_price = InvenTreeMoneySerializer( - allow_null=True, - read_only=True, - ) + total_price = InvenTreeMoneySerializer(allow_null=True, read_only=True) order_currency = InvenTreeCurrencySerializer( allow_blank=True, @@ -74,13 +78,19 @@ class AbstractOrderSerializer(serializers.Serializer): contact_detail = ContactSerializer(source='contact', many=False, read_only=True) # Detail for responsible field - responsible_detail = OwnerSerializer(source='responsible', read_only=True, many=False) + responsible_detail = OwnerSerializer( + source='responsible', read_only=True, many=False + ) # Detail for project code field - project_code_detail = ProjectCodeSerializer(source='project_code', read_only=True, many=False) + project_code_detail = ProjectCodeSerializer( + source='project_code', read_only=True, many=False + ) # Detail for address field - address_detail = AddressBriefSerializer(source='address', many=False, read_only=True) + address_detail = AddressBriefSerializer( + source='address', many=False, read_only=True + ) # Boolean field indicating if this order is overdue (Note: must be annotated) overdue = serializers.BooleanField(required=False, read_only=True) @@ -95,9 +105,7 @@ class AbstractOrderSerializer(serializers.Serializer): @staticmethod def annotate_queryset(queryset): """Add extra information to the queryset""" - queryset = queryset.annotate( - line_items=SubqueryCount('lines') - ) + queryset = queryset.annotate(line_items=SubqueryCount('lines')) return queryset @@ -143,9 +151,7 @@ class AbstractExtraLineSerializer(serializers.Serializer): quantity = serializers.FloatField() - price = InvenTreeMoneySerializer( - allow_null=True - ) + price = InvenTreeMoneySerializer(allow_null=True) price_currency = InvenTreeCurrencySerializer() @@ -168,7 +174,9 @@ class AbstractExtraLineMeta: ] -class PurchaseOrderSerializer(TotalPriceMixin, AbstractOrderSerializer, InvenTreeModelSerializer): +class PurchaseOrderSerializer( + TotalPriceMixin, AbstractOrderSerializer, InvenTreeModelSerializer +): """Serializer for a PurchaseOrder object.""" class Meta: @@ -186,11 +194,7 @@ class PurchaseOrderSerializer(TotalPriceMixin, AbstractOrderSerializer, InvenTre 'order_currency', ]) - read_only_fields = [ - 'issue_date', - 'complete_date', - 'creation_date', - ] + read_only_fields = ['issue_date', 'complete_date', 'creation_date'] extra_kwargs = { 'supplier': {'required': True}, @@ -216,7 +220,9 @@ class PurchaseOrderSerializer(TotalPriceMixin, AbstractOrderSerializer, InvenTre queryset = AbstractOrderSerializer.annotate_queryset(queryset) queryset = queryset.annotate( - completed_lines=SubqueryCount('lines', filter=Q(quantity__lte=F('received'))) + completed_lines=SubqueryCount( + 'lines', filter=Q(quantity__lte=F('received')) + ) ) queryset = queryset.annotate( @@ -225,13 +231,15 @@ class PurchaseOrderSerializer(TotalPriceMixin, AbstractOrderSerializer, InvenTre order.models.PurchaseOrder.overdue_filter(), then=Value(True, output_field=BooleanField()), ), - default=Value(False, output_field=BooleanField()) + default=Value(False, output_field=BooleanField()), ) ) return queryset - supplier_detail = CompanyBriefSerializer(source='supplier', many=False, read_only=True) + supplier_detail = CompanyBriefSerializer( + source='supplier', many=False, read_only=True + ) class PurchaseOrderCancelSerializer(serializers.Serializer): @@ -240,22 +248,20 @@ class PurchaseOrderCancelSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [], + fields = ([],) def get_context_data(self): """Return custom context information about the order.""" self.order = self.context['order'] - return { - 'can_cancel': self.order.can_cancel, - } + return {'can_cancel': self.order.can_cancel} def save(self): """Save the serializer to 'cancel' the order""" order = self.context['order'] if not order.can_cancel: - raise ValidationError(_("Order cannot be cancelled")) + raise ValidationError(_('Order cannot be cancelled')) order.cancel_order() @@ -280,7 +286,7 @@ class PurchaseOrderCompleteSerializer(serializers.Serializer): order = self.context['order'] if not value and not order.is_complete: - raise ValidationError(_("Order has incomplete line items")) + raise ValidationError(_('Order has incomplete line items')) return value @@ -288,9 +294,7 @@ class PurchaseOrderCompleteSerializer(serializers.Serializer): """Custom context information for this serializer.""" order = self.context['order'] - return { - 'is_complete': order.is_complete, - } + return {'is_complete': order.is_complete} def save(self): """Save the serializer to 'complete' the order""" @@ -365,15 +369,15 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): """ queryset = queryset.annotate( total_price=ExpressionWrapper( - F('purchase_price') * F('quantity'), - output_field=models.DecimalField() + F('purchase_price') * F('quantity'), output_field=models.DecimalField() ) ) queryset = queryset.annotate( overdue=Case( When( - order.models.PurchaseOrderLineItem.OVERDUE_FILTER, then=Value(True, output_field=BooleanField()) + order.models.PurchaseOrderLineItem.OVERDUE_FILTER, + then=Value(True, output_field=BooleanField()), ), default=Value(False, output_field=BooleanField()), ) @@ -386,7 +390,7 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): def validate_quantity(self, quantity): """Validation for the 'quantity' field""" if quantity <= 0: - raise ValidationError(_("Quantity must be greater than zero")) + raise ValidationError(_('Quantity must be greater than zero')) return quantity @@ -403,15 +407,23 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): total_price = serializers.FloatField(read_only=True) - part_detail = PartBriefSerializer(source='get_base_part', many=False, read_only=True) + part_detail = PartBriefSerializer( + source='get_base_part', many=False, read_only=True + ) - supplier_part_detail = SupplierPartSerializer(source='part', many=False, read_only=True) + supplier_part_detail = SupplierPartSerializer( + source='part', many=False, read_only=True + ) purchase_price = InvenTreeMoneySerializer(allow_null=True) - destination_detail = stock.serializers.LocationBriefSerializer(source='get_destination', read_only=True) + destination_detail = stock.serializers.LocationBriefSerializer( + source='get_destination', read_only=True + ) - purchase_price_currency = InvenTreeCurrencySerializer(help_text=_('Purchase price currency')) + purchase_price_currency = InvenTreeCurrencySerializer( + help_text=_('Purchase price currency') + ) order_detail = PurchaseOrderSerializer(source='order', read_only=True, many=False) @@ -428,17 +440,16 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): purchase_order = data.get('order', None) if not supplier_part: - raise ValidationError({ - 'part': _('Supplier part must be specified'), - }) + raise ValidationError({'part': _('Supplier part must be specified')}) if not purchase_order: - raise ValidationError({ - 'order': _('Purchase order must be specified'), - }) + raise ValidationError({'order': _('Purchase order must be specified')}) # Check that the supplier part and purchase order match - if supplier_part is not None and supplier_part.supplier != purchase_order.supplier: + if ( + supplier_part is not None + and supplier_part.supplier != purchase_order.supplier + ): raise ValidationError({ 'part': _('Supplier must match purchase order'), 'order': _('Purchase order must match supplier'), @@ -447,7 +458,9 @@ class PurchaseOrderLineItemSerializer(InvenTreeModelSerializer): return data -class PurchaseOrderExtraLineSerializer(AbstractExtraLineSerializer, InvenTreeModelSerializer): +class PurchaseOrderExtraLineSerializer( + AbstractExtraLineSerializer, InvenTreeModelSerializer +): """Serializer for a PurchaseOrderExtraLine object.""" order_detail = PurchaseOrderSerializer(source='order', many=False, read_only=True) @@ -470,8 +483,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): 'location', 'quantity', 'status', - 'batch_code' - 'serial_numbers', + 'batch_code' 'serial_numbers', ] line_item = serializers.PrimaryKeyRelatedField( @@ -499,16 +511,13 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): ) quantity = serializers.DecimalField( - max_digits=15, - decimal_places=5, - min_value=0, - required=True, + max_digits=15, decimal_places=5, min_value=0, required=True ) def validate_quantity(self, quantity): """Validation for the 'quantity' field""" if quantity <= 0: - raise ValidationError(_("Quantity must be greater than zero")) + raise ValidationError(_('Quantity must be greater than zero')) return quantity @@ -529,9 +538,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): ) status = serializers.ChoiceField( - choices=StockStatus.items(), - default=StockStatus.OK.value, - label=_('Status'), + choices=StockStatus.items(), default=StockStatus.OK.value, label=_('Status') ) barcode = serializers.CharField( @@ -573,10 +580,11 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): # Does the quantity need to be "integer" (for trackable parts?) if base_part.trackable: - if Decimal(base_quantity) != int(base_quantity): raise ValidationError({ - 'quantity': _('An integer quantity must be provided for trackable parts'), + 'quantity': _( + 'An integer quantity must be provided for trackable parts' + ) }) # If serial numbers are provided @@ -584,14 +592,10 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer): try: # Pass the serial numbers through to the parent serializer once validated data['serials'] = extract_serial_numbers( - serial_numbers, - base_quantity, - base_part.get_latest_serial_number() + serial_numbers, base_quantity, base_part.get_latest_serial_number() ) except DjangoValidationError as e: - raise ValidationError({ - 'serial_numbers': e.messages, - }) + raise ValidationError({'serial_numbers': e.messages}) return data @@ -602,10 +606,7 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'items', - 'location', - ] + fields = ['items', 'location'] items = PurchaseOrderLineItemReceiveSerializer(many=True) @@ -634,7 +635,6 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer): # Check if the location is not specified for any particular item for item in items: - line = item['line_item'] if not item.get('location', None): @@ -647,7 +647,7 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer): if not item['location']: raise ValidationError({ - 'location': _("Destination location must be specified"), + 'location': _('Destination location must be specified') }) # Ensure barcodes are unique @@ -677,9 +677,12 @@ class PurchaseOrderReceiveSerializer(serializers.Serializer): # Now we can actually receive the items into stock with transaction.atomic(): for item in items: - # Select location (in descending order of priority) - loc = location or item.get('location', None) or item['line_item'].get_destination() + loc = ( + location + or item.get('location', None) + or item['line_item'].get_destination() + ) try: order.receive_line_item( @@ -705,12 +708,12 @@ class PurchaseOrderAttachmentSerializer(InvenTreeAttachmentSerializer): model = order.models.PurchaseOrderAttachment - fields = InvenTreeAttachmentSerializer.attachment_fields([ - 'order', - ]) + fields = InvenTreeAttachmentSerializer.attachment_fields(['order']) -class SalesOrderSerializer(TotalPriceMixin, AbstractOrderSerializer, InvenTreeModelSerializer): +class SalesOrderSerializer( + TotalPriceMixin, AbstractOrderSerializer, InvenTreeModelSerializer +): """Serializer for the SalesOrder model class""" class Meta: @@ -727,15 +730,9 @@ class SalesOrderSerializer(TotalPriceMixin, AbstractOrderSerializer, InvenTreeMo 'order_currency', ]) - read_only_fields = [ - 'status', - 'creation_date', - 'shipment_date', - ] + read_only_fields = ['status', 'creation_date', 'shipment_date'] - extra_kwargs = { - 'order_currency': {'required': False}, - } + extra_kwargs = {'order_currency': {'required': False}} def __init__(self, *args, **kwargs): """Initialization routine for the serializer""" @@ -766,13 +763,15 @@ class SalesOrderSerializer(TotalPriceMixin, AbstractOrderSerializer, InvenTreeMo order.models.SalesOrder.overdue_filter(), then=Value(True, output_field=BooleanField()), ), - default=Value(False, output_field=BooleanField()) + default=Value(False, output_field=BooleanField()), ) ) return queryset - customer_detail = CompanyBriefSerializer(source='customer', many=False, read_only=True) + customer_detail = CompanyBriefSerializer( + source='customer', many=False, read_only=True + ) class SalesOrderIssueSerializer(serializers.Serializer): @@ -780,6 +779,7 @@ class SalesOrderIssueSerializer(serializers.Serializer): class Meta: """Metaclass options""" + fields = [] def save(self): @@ -843,19 +843,31 @@ class SalesOrderAllocationSerializer(InvenTreeModelSerializer): self.fields.pop('customer_detail') part = serializers.PrimaryKeyRelatedField(source='item.part', read_only=True) - order = serializers.PrimaryKeyRelatedField(source='line.order', many=False, read_only=True) + order = serializers.PrimaryKeyRelatedField( + source='line.order', many=False, read_only=True + ) serial = serializers.CharField(source='get_serial', read_only=True) quantity = serializers.FloatField(read_only=False) - location = serializers.PrimaryKeyRelatedField(source='item.location', many=False, read_only=True) + location = serializers.PrimaryKeyRelatedField( + source='item.location', many=False, read_only=True + ) # Extra detail fields order_detail = SalesOrderSerializer(source='line.order', many=False, read_only=True) part_detail = PartBriefSerializer(source='item.part', many=False, read_only=True) - item_detail = stock.serializers.StockItemSerializer(source='item', many=False, read_only=True) - location_detail = stock.serializers.LocationSerializer(source='item.location', many=False, read_only=True) - customer_detail = CompanyBriefSerializer(source='line.order.customer', many=False, read_only=True) + item_detail = stock.serializers.StockItemSerializer( + source='item', many=False, read_only=True + ) + location_detail = stock.serializers.LocationSerializer( + source='item.location', many=False, read_only=True + ) + customer_detail = CompanyBriefSerializer( + source='line.order.customer', many=False, read_only=True + ) - shipment_date = serializers.DateField(source='shipment.shipment_date', read_only=True) + shipment_date = serializers.DateField( + source='shipment.shipment_date', read_only=True + ) class SalesOrderLineItemSerializer(InvenTreeModelSerializer): @@ -922,7 +934,9 @@ class SalesOrderLineItemSerializer(InvenTreeModelSerializer): queryset = queryset.annotate( overdue=Case( When( - Q(order__status__in=SalesOrderStatusGroups.OPEN) & order.models.SalesOrderLineItem.OVERDUE_FILTER, then=Value(True, output_field=BooleanField()), + Q(order__status__in=SalesOrderStatusGroups.OPEN) + & order.models.SalesOrderLineItem.OVERDUE_FILTER, + then=Value(True, output_field=BooleanField()), ), default=Value(False, output_field=BooleanField()), ) @@ -932,43 +946,60 @@ class SalesOrderLineItemSerializer(InvenTreeModelSerializer): # To do this, we need to look at the total stock and any allocations queryset = queryset.alias( total_stock=part.filters.annotate_total_stock(reference='part__'), - allocated_to_sales_orders=part.filters.annotate_sales_order_allocations(reference='part__'), - allocated_to_build_orders=part.filters.annotate_build_order_allocations(reference='part__'), + allocated_to_sales_orders=part.filters.annotate_sales_order_allocations( + reference='part__' + ), + allocated_to_build_orders=part.filters.annotate_build_order_allocations( + reference='part__' + ), ) queryset = queryset.annotate( available_stock=ExpressionWrapper( - F('total_stock') - F('allocated_to_sales_orders') - F('allocated_to_build_orders'), - output_field=models.DecimalField() + F('total_stock') + - F('allocated_to_sales_orders') + - F('allocated_to_build_orders'), + output_field=models.DecimalField(), ) ) # Filter for "variant" stock: Variant stock items must be salable and active - variant_stock_query = part.filters.variant_stock_query(reference='part__').filter( - part__salable=True, - part__active=True - ) + variant_stock_query = part.filters.variant_stock_query( + reference='part__' + ).filter(part__salable=True, part__active=True) # Also add in available "variant" stock queryset = queryset.alias( - variant_stock_total=part.filters.annotate_variant_quantity(variant_stock_query, reference='quantity'), - variant_bo_allocations=part.filters.annotate_variant_quantity(variant_stock_query, reference='sales_order_allocations__quantity'), - variant_so_allocations=part.filters.annotate_variant_quantity(variant_stock_query, reference='allocations__quantity'), + variant_stock_total=part.filters.annotate_variant_quantity( + variant_stock_query, reference='quantity' + ), + variant_bo_allocations=part.filters.annotate_variant_quantity( + variant_stock_query, reference='sales_order_allocations__quantity' + ), + variant_so_allocations=part.filters.annotate_variant_quantity( + variant_stock_query, reference='allocations__quantity' + ), ) queryset = queryset.annotate( available_variant_stock=ExpressionWrapper( - F('variant_stock_total') - F('variant_bo_allocations') - F('variant_so_allocations'), + F('variant_stock_total') + - F('variant_bo_allocations') + - F('variant_so_allocations'), output_field=models.DecimalField(), ) ) return queryset - customer_detail = CompanyBriefSerializer(source='order.customer', many=False, read_only=True) + customer_detail = CompanyBriefSerializer( + source='order.customer', many=False, read_only=True + ) order_detail = SalesOrderSerializer(source='order', many=False, read_only=True) part_detail = PartBriefSerializer(source='part', many=False, read_only=True) - allocations = SalesOrderAllocationSerializer(many=True, read_only=True, location_detail=True) + allocations = SalesOrderAllocationSerializer( + many=True, read_only=True, location_detail=True + ) # Annotated fields overdue = serializers.BooleanField(required=False, read_only=True) @@ -983,7 +1014,9 @@ class SalesOrderLineItemSerializer(InvenTreeModelSerializer): sale_price = InvenTreeMoneySerializer(allow_null=True) - sale_price_currency = InvenTreeCurrencySerializer(help_text=_('Sale price currency')) + sale_price_currency = InvenTreeCurrencySerializer( + help_text=_('Sale price currency') + ) class SalesOrderShipmentSerializer(InvenTreeModelSerializer): @@ -1009,7 +1042,9 @@ class SalesOrderShipmentSerializer(InvenTreeModelSerializer): 'notes', ] - allocations = SalesOrderAllocationSerializer(many=True, read_only=True, location_detail=True) + allocations = SalesOrderAllocationSerializer( + many=True, read_only=True, location_detail=True + ) order_detail = SalesOrderSerializer(source='order', read_only=True, many=False) @@ -1040,7 +1075,7 @@ class SalesOrderShipmentCompleteSerializer(serializers.ModelSerializer): shipment = self.context.get('shipment', None) if not shipment: - raise ValidationError(_("No shipment details provided")) + raise ValidationError(_('No shipment details provided')) shipment.check_can_complete(raise_error=True) @@ -1081,11 +1116,7 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'line_item', - 'stock_item', - 'quantity', - ] + fields = ['line_item', 'stock_item', 'quantity'] line_item = serializers.PrimaryKeyRelatedField( queryset=order.models.SalesOrderLineItem.objects.all(), @@ -1104,7 +1135,7 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer): # Ensure that the line item points to the correct order if line_item.order != order: - raise ValidationError(_("Line item is not associated with this order")) + raise ValidationError(_('Line item is not associated with this order')) return line_item @@ -1117,16 +1148,13 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer): ) quantity = serializers.DecimalField( - max_digits=15, - decimal_places=5, - min_value=0, - required=True + max_digits=15, decimal_places=5, min_value=0, required=True ) def validate_quantity(self, quantity): """Custom validation for the 'quantity' field""" if quantity <= 0: - raise ValidationError(_("Quantity must be positive")) + raise ValidationError(_('Quantity must be positive')) return quantity @@ -1143,15 +1171,13 @@ class SalesOrderShipmentAllocationItemSerializer(serializers.Serializer): if stock_item.serialized and quantity != 1: raise ValidationError({ - 'quantity': _("Quantity must be 1 for serialized stock item"), + 'quantity': _('Quantity must be 1 for serialized stock item') }) q = normalize(stock_item.unallocated_quantity()) if quantity > q: - raise ValidationError({ - 'quantity': _(f"Available quantity ({q}) exceeded") - }) + raise ValidationError({'quantity': _(f'Available quantity ({q}) exceeded')}) return data @@ -1171,7 +1197,7 @@ class SalesOrderCompleteSerializer(serializers.Serializer): order = self.context['order'] if not value and not order.is_completed(): - raise ValidationError(_("Order has incomplete line items")) + raise ValidationError(_('Order has incomplete line items')) return value @@ -1206,8 +1232,7 @@ class SalesOrderCompleteSerializer(serializers.Serializer): user = getattr(request, 'user', None) order.complete_order( - user, - allow_incomplete_lines=str2bool(data.get('accept_incomplete', False)), + user, allow_incomplete_lines=str2bool(data.get('accept_incomplete', False)) ) @@ -1218,9 +1243,7 @@ class SalesOrderCancelSerializer(serializers.Serializer): """Add extra context data to the serializer""" order = self.context['order'] - return { - 'can_cancel': order.can_cancel, - } + return {'can_cancel': order.can_cancel} def save(self): """Save the serializer to cancel the order""" @@ -1235,12 +1258,7 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'line_item', - 'quantity', - 'serial_numbers', - 'shipment', - ] + fields = ['line_item', 'quantity', 'serial_numbers', 'shipment'] line_item = serializers.PrimaryKeyRelatedField( queryset=order.models.SalesOrderLineItem.objects.all(), @@ -1256,20 +1274,17 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): # Ensure that the line item points to the correct order if line_item.order != order: - raise ValidationError(_("Line item is not associated with this order")) + raise ValidationError(_('Line item is not associated with this order')) return line_item quantity = serializers.IntegerField( - min_value=1, - required=True, - allow_null=False, - label=_('Quantity'), + min_value=1, required=True, allow_null=False, label=_('Quantity') ) serial_numbers = serializers.CharField( - label=_("Serial Numbers"), - help_text=_("Enter serial numbers to allocate"), + label=_('Serial Numbers'), + help_text=_('Enter serial numbers to allocate'), required=True, allow_blank=False, ) @@ -1291,10 +1306,10 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): order = self.context['order'] if shipment.shipment_date is not None: - raise ValidationError(_("Shipment has already been shipped")) + raise ValidationError(_('Shipment has already been shipped')) if shipment.order != order: - raise ValidationError(_("Shipment is not associated with this order")) + raise ValidationError(_('Shipment is not associated with this order')) return shipment @@ -1315,14 +1330,10 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): try: data['serials'] = extract_serial_numbers( - serial_numbers, - quantity, - part.get_latest_serial_number() + serial_numbers, quantity, part.get_latest_serial_number() ) except DjangoValidationError as e: - raise ValidationError({ - 'serial_numbers': e.messages, - }) + raise ValidationError({'serial_numbers': e.messages}) serials_not_exist = [] serials_allocated = [] @@ -1330,9 +1341,7 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): for serial in data['serials']: items = stock.models.StockItem.objects.filter( - part=part, - serial=serial, - quantity=1, + part=part, serial=serial, quantity=1 ) if not items.exists(): @@ -1347,24 +1356,18 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): serials_allocated.append(str(serial)) if len(serials_not_exist) > 0: + error_msg = _('No match found for the following serial numbers') + error_msg += ': ' + error_msg += ','.join(serials_not_exist) - error_msg = _("No match found for the following serial numbers") - error_msg += ": " - error_msg += ",".join(serials_not_exist) - - raise ValidationError({ - 'serial_numbers': error_msg - }) + raise ValidationError({'serial_numbers': error_msg}) if len(serials_allocated) > 0: + error_msg = _('The following serial numbers are already allocated') + error_msg += ': ' + error_msg += ','.join(serials_allocated) - error_msg = _("The following serial numbers are already allocated") - error_msg += ": " - error_msg += ",".join(serials_allocated) - - raise ValidationError({ - 'serial_numbers': error_msg, - }) + raise ValidationError({'serial_numbers': error_msg}) data['stock_items'] = stock_items_to_allocate @@ -1382,10 +1385,7 @@ class SalesOrderSerialAllocationSerializer(serializers.Serializer): for stock_item in stock_items: # Create a new SalesOrderAllocation order.models.SalesOrderAllocation.objects.create( - line=line_item, - item=stock_item, - quantity=1, - shipment=shipment + line=line_item, item=stock_item, quantity=1, shipment=shipment ) @@ -1395,10 +1395,7 @@ class SalesOrderShipmentAllocationSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'items', - 'shipment', - ] + fields = ['items', 'shipment'] items = SalesOrderShipmentAllocationItemSerializer(many=True) @@ -1415,10 +1412,10 @@ class SalesOrderShipmentAllocationSerializer(serializers.Serializer): order = self.context['order'] if shipment.shipment_date is not None: - raise ValidationError(_("Shipment has already been shipped")) + raise ValidationError(_('Shipment has already been shipped')) if shipment.order != order: - raise ValidationError(_("Shipment is not associated with this order")) + raise ValidationError(_('Shipment is not associated with this order')) return shipment @@ -1445,7 +1442,6 @@ class SalesOrderShipmentAllocationSerializer(serializers.Serializer): with transaction.atomic(): for entry in items: - # Create a new SalesOrderAllocation allocation = order.models.SalesOrderAllocation( line=entry.get('line_item'), @@ -1458,7 +1454,9 @@ class SalesOrderShipmentAllocationSerializer(serializers.Serializer): allocation.save() -class SalesOrderExtraLineSerializer(AbstractExtraLineSerializer, InvenTreeModelSerializer): +class SalesOrderExtraLineSerializer( + AbstractExtraLineSerializer, InvenTreeModelSerializer +): """Serializer for a SalesOrderExtraLine object.""" class Meta(AbstractExtraLineMeta): @@ -1477,12 +1475,12 @@ class SalesOrderAttachmentSerializer(InvenTreeAttachmentSerializer): model = order.models.SalesOrderAttachment - fields = InvenTreeAttachmentSerializer.attachment_fields([ - 'order', - ]) + fields = InvenTreeAttachmentSerializer.attachment_fields(['order']) -class ReturnOrderSerializer(AbstractOrderSerializer, TotalPriceMixin, InvenTreeModelSerializer): +class ReturnOrderSerializer( + AbstractOrderSerializer, TotalPriceMixin, InvenTreeModelSerializer +): """Serializer for the ReturnOrder model class""" class Meta: @@ -1498,9 +1496,7 @@ class ReturnOrderSerializer(AbstractOrderSerializer, TotalPriceMixin, InvenTreeM 'total_price', ]) - read_only_fields = [ - 'creation_date', - ] + read_only_fields = ['creation_date'] def __init__(self, *args, **kwargs): """Initialization routine for the serializer""" @@ -1517,7 +1513,9 @@ class ReturnOrderSerializer(AbstractOrderSerializer, TotalPriceMixin, InvenTreeM queryset = AbstractOrderSerializer.annotate_queryset(queryset) queryset = queryset.annotate( - completed_lines=SubqueryCount('lines', filter=~Q(outcome=ReturnOrderLineStatus.PENDING.value)) + completed_lines=SubqueryCount( + 'lines', filter=~Q(outcome=ReturnOrderLineStatus.PENDING.value) + ) ) queryset = queryset.annotate( @@ -1526,13 +1524,15 @@ class ReturnOrderSerializer(AbstractOrderSerializer, TotalPriceMixin, InvenTreeM order.models.ReturnOrder.overdue_filter(), then=Value(True, output_field=BooleanField()), ), - default=Value(False, output_field=BooleanField()) + default=Value(False, output_field=BooleanField()), ) ) return queryset - customer_detail = CompanyBriefSerializer(source='customer', many=False, read_only=True) + customer_detail = CompanyBriefSerializer( + source='customer', many=False, read_only=True + ) class ReturnOrderIssueSerializer(serializers.Serializer): @@ -1540,6 +1540,7 @@ class ReturnOrderIssueSerializer(serializers.Serializer): class Meta: """Metaclass options""" + fields = [] def save(self): @@ -1553,6 +1554,7 @@ class ReturnOrderCancelSerializer(serializers.Serializer): class Meta: """Metaclass options""" + fields = [] def save(self): @@ -1566,6 +1568,7 @@ class ReturnOrderCompleteSerializer(serializers.Serializer): class Meta: """Metaclass options""" + fields = [] def save(self): @@ -1579,9 +1582,8 @@ class ReturnOrderLineItemReceiveSerializer(serializers.Serializer): class Meta: """Metaclass options""" - fields = [ - 'item', - ] + + fields = ['item'] item = serializers.PrimaryKeyRelatedField( queryset=order.models.ReturnOrderLineItem.objects.all(), @@ -1594,10 +1596,10 @@ class ReturnOrderLineItemReceiveSerializer(serializers.Serializer): def validate_line_item(self, item): """Validation for a single line item""" if item.order != self.context['order']: - raise ValidationError(_("Line item does not match return order")) + raise ValidationError(_('Line item does not match return order')) if item.received: - raise ValidationError(_("Line item has already been received")) + raise ValidationError(_('Line item has already been received')) return item @@ -1608,10 +1610,7 @@ class ReturnOrderReceiveSerializer(serializers.Serializer): class Meta: """Metaclass options""" - fields = [ - 'items', - 'location', - ] + fields = ['items', 'location'] items = ReturnOrderLineItemReceiveSerializer(many=True) @@ -1628,14 +1627,16 @@ class ReturnOrderReceiveSerializer(serializers.Serializer): """Perform data validation for this serializer""" order = self.context['order'] if order.status != ReturnOrderStatus.IN_PROGRESS: - raise ValidationError(_("Items can only be received against orders which are in progress")) + raise ValidationError( + _('Items can only be received against orders which are in progress') + ) data = super().validate(data) items = data.get('items', []) if len(items) == 0: - raise ValidationError(_("Line items must be provided")) + raise ValidationError(_('Line items must be provided')) return data @@ -1652,11 +1653,7 @@ class ReturnOrderReceiveSerializer(serializers.Serializer): with transaction.atomic(): for item in items: line_item = item['item'] - order.receive_line_item( - line_item, - location, - request.user - ) + order.receive_line_item(line_item, location, request.user) class ReturnOrderLineItemSerializer(InvenTreeModelSerializer): @@ -1703,18 +1700,23 @@ class ReturnOrderLineItemSerializer(InvenTreeModelSerializer): self.fields.pop('part_detail') order_detail = ReturnOrderSerializer(source='order', many=False, read_only=True) - item_detail = stock.serializers.StockItemSerializer(source='item', many=False, read_only=True) + item_detail = stock.serializers.StockItemSerializer( + source='item', many=False, read_only=True + ) part_detail = PartBriefSerializer(source='item.part', many=False, read_only=True) price = InvenTreeMoneySerializer(allow_null=True) price_currency = InvenTreeCurrencySerializer(help_text=_('Line price currency')) -class ReturnOrderExtraLineSerializer(AbstractExtraLineSerializer, InvenTreeModelSerializer): +class ReturnOrderExtraLineSerializer( + AbstractExtraLineSerializer, InvenTreeModelSerializer +): """Serializer for a ReturnOrderExtraLine object""" class Meta(AbstractExtraLineMeta): """Metaclass options""" + model = order.models.ReturnOrderExtraLine order_detail = ReturnOrderSerializer(source='order', many=False, read_only=True) @@ -1728,6 +1730,4 @@ class ReturnOrderAttachmentSerializer(InvenTreeAttachmentSerializer): model = order.models.ReturnOrderAttachment - fields = InvenTreeAttachmentSerializer.attachment_fields([ - 'order', - ]) + fields = InvenTreeAttachmentSerializer.attachment_fields(['order']) diff --git a/InvenTree/order/tasks.py b/InvenTree/order/tasks.py index 6d6a2eb7d6..da621626c7 100644 --- a/InvenTree/order/tasks.py +++ b/InvenTree/order/tasks.py @@ -7,8 +7,7 @@ from django.utils.translation import gettext_lazy as _ import common.notifications import InvenTree.helpers_model import order.models -from InvenTree.status_codes import (PurchaseOrderStatusGroups, - SalesOrderStatusGroups) +from InvenTree.status_codes import PurchaseOrderStatusGroups, SalesOrderStatusGroups from InvenTree.tasks import ScheduledTask, scheduled_task from plugin.events import trigger_event @@ -29,30 +28,19 @@ def notify_overdue_purchase_order(po: order.models.PurchaseOrder): 'order': po, 'name': name, 'message': _(f'Purchase order {po} is now overdue'), - 'link': InvenTree.helpers_model.construct_absolute_url( - po.get_absolute_url(), - ), - 'template': { - 'html': 'email/overdue_purchase_order.html', - 'subject': name, - } + 'link': InvenTree.helpers_model.construct_absolute_url(po.get_absolute_url()), + 'template': {'html': 'email/overdue_purchase_order.html', 'subject': name}, } event_name = 'order.overdue_purchase_order' # Send a notification to the appropriate users common.notifications.trigger_notification( - po, - event_name, - targets=targets, - context=context, + po, event_name, targets=targets, context=context ) # Register a matching event to the plugin system - trigger_event( - event_name, - purchase_order=po.pk, - ) + trigger_event(event_name, purchase_order=po.pk) @scheduled_task(ScheduledTask.DAILY) @@ -66,8 +54,7 @@ def check_overdue_purchase_orders(): yesterday = datetime.now().date() - timedelta(days=1) overdue_orders = order.models.PurchaseOrder.objects.filter( - target_date=yesterday, - status__in=PurchaseOrderStatusGroups.OPEN, + target_date=yesterday, status__in=PurchaseOrderStatusGroups.OPEN ) for po in overdue_orders: @@ -89,31 +76,20 @@ def notify_overdue_sales_order(so: order.models.SalesOrder): context = { 'order': so, 'name': name, - 'message': _(f"Sales order {so} is now overdue"), - 'link': InvenTree.helpers_model.construct_absolute_url( - so.get_absolute_url(), - ), - 'template': { - 'html': 'email/overdue_sales_order.html', - 'subject': name, - } + 'message': _(f'Sales order {so} is now overdue'), + 'link': InvenTree.helpers_model.construct_absolute_url(so.get_absolute_url()), + 'template': {'html': 'email/overdue_sales_order.html', 'subject': name}, } event_name = 'order.overdue_sales_order' # Send a notification to the appropriate users common.notifications.trigger_notification( - so, - event_name, - targets=targets, - context=context, + so, event_name, targets=targets, context=context ) # Register a matching event to the plugin system - trigger_event( - event_name, - sales_order=so.pk, - ) + trigger_event(event_name, sales_order=so.pk) @scheduled_task(ScheduledTask.DAILY) @@ -127,8 +103,7 @@ def check_overdue_sales_orders(): yesterday = datetime.now().date() - timedelta(days=1) overdue_orders = order.models.SalesOrder.objects.filter( - target_date=yesterday, - status__in=SalesOrderStatusGroups.OPEN, + target_date=yesterday, status__in=SalesOrderStatusGroups.OPEN ) for po in overdue_orders: diff --git a/InvenTree/order/test_api.py b/InvenTree/order/test_api.py index 2b7aa76782..906a5c229a 100644 --- a/InvenTree/order/test_api.py +++ b/InvenTree/order/test_api.py @@ -15,9 +15,14 @@ from rest_framework import status from common.settings import currency_codes from company.models import Company -from InvenTree.status_codes import (PurchaseOrderStatus, ReturnOrderLineStatus, - ReturnOrderStatus, SalesOrderStatus, - SalesOrderStatusGroups, StockStatus) +from InvenTree.status_codes import ( + PurchaseOrderStatus, + ReturnOrderLineStatus, + ReturnOrderStatus, + SalesOrderStatus, + SalesOrderStatusGroups, + StockStatus, +) from InvenTree.unit_test import InvenTreeAPITestCase from order import models from part.models import Part @@ -26,6 +31,7 @@ from stock.models import StockItem class OrderTest(InvenTreeAPITestCase): """Base class for order API unit testing""" + fixtures = [ 'category', 'part', @@ -37,17 +43,11 @@ class OrderTest(InvenTreeAPITestCase): 'sales_order', ] - roles = [ - 'purchase_order.change', - 'sales_order.change', - ] + roles = ['purchase_order.change', 'sales_order.change'] def filter(self, filters, count): """Test API filters.""" - response = self.get( - self.LIST_URL, - filters - ) + response = self.get(self.LIST_URL, filters) self.assertEqual(response.status_code, 200) self.assertEqual(len(response.data), count) @@ -80,28 +80,36 @@ class PurchaseOrderTest(OrderTest): self.assertEqual(field_data[k], v) # Checks for the 'order_currency' field - check_options(post, 'order_currency', { - 'type': 'choice', - 'required': False, - 'read_only': False, - 'label': 'Order Currency', - 'help_text': 'Currency for this order (leave blank to use company default)', - }) + check_options( + post, + 'order_currency', + { + 'type': 'choice', + 'required': False, + 'read_only': False, + 'label': 'Order Currency', + 'help_text': 'Currency for this order (leave blank to use company default)', + }, + ) # Checks for the 'reference' field - check_options(post, 'reference', { - 'type': 'string', - 'required': True, - 'read_only': False, - 'label': 'Reference', - }) + check_options( + post, + 'reference', + { + 'type': 'string', + 'required': True, + 'read_only': False, + 'label': 'Reference', + }, + ) # Checks for the 'supplier' field - check_options(post, 'supplier', { - 'type': 'related field', - 'required': True, - 'api_url': '/api/company/', - }) + check_options( + post, + 'supplier', + {'type': 'related field', 'required': True, 'api_url': '/api/company/'}, + ) def test_po_list(self): """Test the PurchaseOrder list API endpoint""" @@ -157,10 +165,7 @@ class PurchaseOrderTest(OrderTest): for supplier in Company.objects.filter(is_supplier=True): for _idx in range(10): new_orders.append( - models.PurchaseOrder( - supplier=supplier, - reference=f'PO-{idx + 100}' - ) + models.PurchaseOrder(supplier=supplier, reference=f'PO-{idx + 100}') ) idx += 1 @@ -190,7 +195,9 @@ class PurchaseOrderTest(OrderTest): # List all purchase orders for limit in [1, 5, 10, 100]: with CaptureQueriesContext(connection) as ctx: - response = self.get(self.LIST_URL, data={'limit': limit}, expected_code=200) + response = self.get( + self.LIST_URL, data={'limit': limit}, expected_code=200 + ) # Total database queries must be below 15, independent of the number of results self.assertLess(len(ctx), 15) @@ -230,7 +237,7 @@ class PurchaseOrderTest(OrderTest): self.assignRole('purchase_order.add') url = reverse('api-po-list') - huge_number = "PO-92233720368547758089999999999999999" + huge_number = 'PO-92233720368547758089999999999999999' response = self.post( url, @@ -245,7 +252,7 @@ class PurchaseOrderTest(OrderTest): order = models.PurchaseOrder.objects.get(pk=response.data['pk']) self.assertEqual(order.reference, 'PO-92233720368547758089999999999999999') - self.assertEqual(order.reference_int, 0x7fffffff) + self.assertEqual(order.reference_int, 0x7FFFFFFF) def test_po_attachments(self): """Test the list endpoint for the PurchaseOrderAttachment model""" @@ -270,7 +277,7 @@ class PurchaseOrderTest(OrderTest): 'reference': '123456789-xyz', 'description': 'PO created via the API', }, - expected_code=403 + expected_code=403, ) # And no new PurchaseOrder objects should have been created @@ -288,7 +295,7 @@ class PurchaseOrderTest(OrderTest): 'reference': 'PO-123456789', 'description': 'PO created via the API', }, - expected_code=201 + expected_code=201, ) self.assertEqual(models.PurchaseOrder.objects.count(), n + 1) @@ -303,7 +310,7 @@ class PurchaseOrderTest(OrderTest): 'reference': '123456789-xyz', 'description': 'A different description', }, - expected_code=400 + expected_code=400, ) self.assertEqual(models.PurchaseOrder.objects.count(), n + 1) @@ -316,13 +323,7 @@ class PurchaseOrderTest(OrderTest): self.assertEqual(response.data['reference'], 'PO-123456789') # Try to alter (edit) the PurchaseOrder - response = self.patch( - url, - { - 'reference': 'PO-12345', - }, - expected_code=200 - ) + response = self.patch(url, {'reference': 'PO-12345'}, expected_code=200) # Reference should have changed self.assertEqual(response.data['reference'], 'PO-12345') @@ -332,7 +333,7 @@ class PurchaseOrderTest(OrderTest): response = self.delete(url, expected_code=403) # Now, add the "delete" permission! - self.assignRole("purchase_order.delete") + self.assignRole('purchase_order.delete') response = self.delete(url, expected_code=204) @@ -353,7 +354,7 @@ class PurchaseOrderTest(OrderTest): 'supplier': 1, 'description': 'A test purchase order', }, - expected_code=201 + expected_code=201, ) def test_po_duplicate(self): @@ -368,11 +369,11 @@ class PurchaseOrderTest(OrderTest): # Add some extra line items to this order for idx in range(5): - lines.append(models.PurchaseOrderExtraLine( - order=po, - quantity=idx + 10, - reference='some reference', - )) + lines.append( + models.PurchaseOrderExtraLine( + order=po, quantity=idx + 10, reference='some reference' + ) + ) # bulk create orders models.PurchaseOrderExtraLine.objects.bulk_create(lines) @@ -390,11 +391,7 @@ class PurchaseOrderTest(OrderTest): data['reference'] = 'PO-9999' # Duplicate via the API - response = self.post( - reverse('api-po-list'), - data, - expected_code=400 - ) + response = self.post(reverse('api-po-list'), data, expected_code=400) data['duplicate_order'] = 1 data['duplicate_line_items'] = True @@ -403,11 +400,7 @@ class PurchaseOrderTest(OrderTest): data['reference'] = 'PO-9999' # Duplicate via the API - response = self.post( - reverse('api-po-list'), - data, - expected_code=201 - ) + response = self.post(reverse('api-po-list'), data, expected_code=201) # Order is for the same supplier self.assertEqual(response.data['supplier'], po.supplier.pk) @@ -421,11 +414,7 @@ class PurchaseOrderTest(OrderTest): data['duplicate_line_items'] = False data['duplicate_extra_lines'] = True - response = self.post( - reverse('api-po-list'), - data, - expected_code=201, - ) + response = self.post(reverse('api-po-list'), data, expected_code=201) po_dup = models.PurchaseOrder.objects.get(pk=response.data['pk']) @@ -445,11 +434,7 @@ class PurchaseOrderTest(OrderTest): self.assignRole('purchase_order.add') - self.post( - url, - {}, - expected_code=201, - ) + self.post(url, {}, expected_code=201) po.refresh_from_db() @@ -474,16 +459,12 @@ class PurchaseOrderTest(OrderTest): # Should fail due to incomplete lines response = self.post(url, {}, expected_code=400) - self.assertIn('Order has incomplete line items', str(response.data['accept_incomplete'])) + self.assertIn( + 'Order has incomplete line items', str(response.data['accept_incomplete']) + ) # Post again, accepting incomplete line items - self.post( - url, - { - 'accept_incomplete': True, - }, - expected_code=201 - ) + self.post(url, {'accept_incomplete': True}, expected_code=201) po.refresh_from_db() @@ -520,37 +501,36 @@ class PurchaseOrderTest(OrderTest): 'description': f'Calendar PO {i}', 'target_date': f'2024-12-{i:02d}', }, - expected_code=201 + expected_code=201, ) # Get some of these orders with target date, complete or cancel them for po in models.PurchaseOrder.objects.filter(target_date__isnull=False): - if po.reference in ['PO-11000001', 'PO-11000002', 'PO-11000003', 'PO-11000004']: + if po.reference in [ + 'PO-11000001', + 'PO-11000002', + 'PO-11000003', + 'PO-11000004', + ]: # Set issued status for these POs self.post( - reverse('api-po-issue', kwargs={'pk': po.pk}), - {}, - expected_code=201 + reverse('api-po-issue', kwargs={'pk': po.pk}), {}, expected_code=201 ) if po.reference in ['PO-11000001', 'PO-11000002']: # Set complete status for these POs self.post( reverse('api-po-complete', kwargs={'pk': po.pk}), - { - 'accept_incomplete': True, - }, - expected_code=201 + {'accept_incomplete': True}, + expected_code=201, ) elif po.reference in ['PO-11000005', 'PO-11000006']: # Set cancel status for these POs self.post( reverse('api-po-cancel', kwargs={'pk': po.pk}), - { - 'accept_incomplete': True, - }, - expected_code=201 + {'accept_incomplete': True}, + expected_code=201, ) url = reverse('api-po-so-calendar', kwargs={'ordertype': 'purchase-order'}) @@ -558,7 +538,11 @@ class PurchaseOrderTest(OrderTest): # Test without completed orders response = self.get(url, expected_code=200, format=None) - number_orders = len(models.PurchaseOrder.objects.filter(target_date__isnull=False).filter(status__lt=PurchaseOrderStatus.COMPLETE.value)) + number_orders = len( + models.PurchaseOrder.objects.filter(target_date__isnull=False).filter( + status__lt=PurchaseOrderStatus.COMPLETE.value + ) + ) # Transform content to a Calendar object calendar = Calendar.from_ical(response.content) @@ -572,9 +556,13 @@ class PurchaseOrderTest(OrderTest): self.assertEqual(number_orders, n_events) # Test with completed orders - response = self.get(url, data={'include_completed': 'True'}, expected_code=200, format=None) + response = self.get( + url, data={'include_completed': 'True'}, expected_code=200, format=None + ) - number_orders_incl_completed = len(models.PurchaseOrder.objects.filter(target_date__isnull=False)) + number_orders_incl_completed = len( + models.PurchaseOrder.objects.filter(target_date__isnull=False) + ) self.assertGreater(number_orders_incl_completed, number_orders) @@ -592,21 +580,28 @@ class PurchaseOrderTest(OrderTest): def test_po_calendar_noauth(self): """Test accessing calendar without authorization""" self.client.logout() - response = self.client.get(reverse('api-po-so-calendar', kwargs={'ordertype': 'purchase-order'}), format='json') + response = self.client.get( + reverse('api-po-so-calendar', kwargs={'ordertype': 'purchase-order'}), + format='json', + ) self.assertEqual(response.status_code, 401) resp_dict = response.json() - self.assertEqual(resp_dict['detail'], "Authentication credentials were not provided.") + self.assertEqual( + resp_dict['detail'], 'Authentication credentials were not provided.' + ) def test_po_calendar_auth(self): """Test accessing calendar with header authorization""" self.client.logout() - base64_token = base64.b64encode(f'{self.username}:{self.password}'.encode('ascii')).decode('ascii') + base64_token = base64.b64encode( + f'{self.username}:{self.password}'.encode('ascii') + ).decode('ascii') response = self.client.get( reverse('api-po-so-calendar', kwargs={'ordertype': 'purchase-order'}), format='json', - HTTP_AUTHORIZATION=f'basic {base64_token}' + HTTP_AUTHORIZATION=f'basic {base64_token}', ) self.assertEqual(response.status_code, 200) @@ -642,13 +637,7 @@ class PurchaseOrderLineItemTest(OrderTest): url = reverse('api-po-line-list') # Try to delete a set of line items via their IDs - self.delete( - url, - { - 'items': [1, 2], - }, - expected_code=204, - ) + self.delete(url, {'items': [1, 2]}, expected_code=204) # We should have 2 less PurchaseOrderLineItems after deletign them self.assertEqual(models.PurchaseOrderLineItem.objects.count(), n - 2) @@ -668,38 +657,28 @@ class PurchaseOrderDownloadTest(OrderTest): 'supplier_reference', ] - excluded_cols = [ - 'metadata', - ] + excluded_cols = ['metadata'] def test_download_wrong_format(self): """Incorrect format should default raise an error.""" url = reverse('api-po-list') with self.assertRaises(ValueError): - self.download_file( - url, - { - 'export': 'xyz', - } - ) + self.download_file(url, {'export': 'xyz'}) def test_download_csv(self): """Download PurchaseOrder data as .csv.""" with self.download_file( reverse('api-po-list'), - { - 'export': 'csv', - }, + {'export': 'csv'}, expected_code=200, expected_fn='InvenTree_PurchaseOrders.csv', ) as file: - data = self.process_csv( file, required_cols=self.required_cols, excluded_cols=self.excluded_cols, - required_rows=models.PurchaseOrder.objects.count() + required_rows=models.PurchaseOrder.objects.count(), ) for row in data: @@ -712,14 +691,11 @@ class PurchaseOrderDownloadTest(OrderTest): """Test that the PurchaseOrderLineItems can be downloaded to a file""" with self.download_file( reverse('api-po-line-list'), - { - 'export': 'xlsx', - }, + {'export': 'xlsx'}, decode=False, expected_code=200, expected_fn='InvenTree_PurchaseOrderItems.xlsx', ) as file: - self.assertTrue(isinstance(file, io.BytesIO)) @@ -755,12 +731,7 @@ class PurchaseOrderReceiveTest(OrderTest): def test_no_items(self): """Test with an empty list of items.""" data = self.post( - self.url, - { - "items": [], - "location": None, - }, - expected_code=400 + self.url, {'items': [], 'location': None}, expected_code=400 ).data self.assertIn('Line items must be provided', str(data)) @@ -772,21 +743,14 @@ class PurchaseOrderReceiveTest(OrderTest): """Test than errors are returned as expected for invalid data.""" data = self.post( self.url, - { - "items": [ - { - "line_item": 12345, - "location": 12345 - } - ] - }, - expected_code=400 + {'items': [{'line_item': 12345, 'location': 12345}]}, + expected_code=400, ).data items = data['items'][0] self.assertIn('Invalid pk "12345"', str(items['line_item'])) - self.assertIn("object does not exist", str(items['location'])) + self.assertIn('object does not exist', str(items['location'])) # No new stock items have been created self.assertEqual(self.n, StockItem.objects.count()) @@ -796,16 +760,11 @@ class PurchaseOrderReceiveTest(OrderTest): data = self.post( self.url, { - "items": [ - { - "line_item": 22, - "location": 1, - "status": 99999, - "quantity": 5, - } + 'items': [ + {'line_item': 22, 'location': 1, 'status': 99999, 'quantity': 5} ] }, - expected_code=400 + expected_code=400, ).data self.assertIn('"99999" is not a valid choice.', str(data)) @@ -818,16 +777,10 @@ class PurchaseOrderReceiveTest(OrderTest): data = self.post( self.url, { - 'items': [ - { - 'line_item': 22, - 'quantity': 123, - 'location': 1, - } - ], + 'items': [{'line_item': 22, 'quantity': 123, 'location': 1}], 'location': None, }, - expected_code=400 + expected_code=400, ).data self.assertIn('Line item does not match purchase order', str(data)) @@ -845,16 +798,10 @@ class PurchaseOrderReceiveTest(OrderTest): self.post( self.url, { - 'items': [ - { - 'line_item': 1, - 'quantity': 50, - 'barcode': None, - } - ], + 'items': [{'line_item': 1, 'quantity': 50, 'barcode': None}], 'location': 1, }, - expected_code=201 + expected_code=201, ) def test_invalid_barcodes(self): @@ -871,15 +818,11 @@ class PurchaseOrderReceiveTest(OrderTest): self.url, { 'items': [ - { - 'line_item': 1, - 'quantity': 50, - 'barcode': 'MY-BARCODE-HASH', - } + {'line_item': 1, 'quantity': 50, 'barcode': 'MY-BARCODE-HASH'} ], 'location': 1, }, - expected_code=400 + expected_code=400, ) self.assertIn('Barcode is already in use', str(response.data)) @@ -888,20 +831,12 @@ class PurchaseOrderReceiveTest(OrderTest): self.url, { 'items': [ - { - 'line_item': 1, - 'quantity': 5, - 'barcode': 'MY-BARCODE-HASH-1', - }, - { - 'line_item': 1, - 'quantity': 5, - 'barcode': 'MY-BARCODE-HASH-1' - }, + {'line_item': 1, 'quantity': 5, 'barcode': 'MY-BARCODE-HASH-1'}, + {'line_item': 1, 'quantity': 5, 'barcode': 'MY-BARCODE-HASH-1'}, ], 'location': 1, }, - expected_code=400 + expected_code=400, ) self.assertIn('barcode values must be unique', str(response.data)) @@ -922,17 +857,13 @@ class PurchaseOrderReceiveTest(OrderTest): valid_data = { 'items': [ - { - 'line_item': 1, - 'quantity': 50, - 'barcode': 'MY-UNIQUE-BARCODE-123', - }, + {'line_item': 1, 'quantity': 50, 'barcode': 'MY-UNIQUE-BARCODE-123'}, { 'line_item': 2, 'quantity': 200, 'location': 2, # Explicit location 'barcode': 'MY-UNIQUE-BARCODE-456', - } + }, ], 'location': 1, # Default location } @@ -943,11 +874,7 @@ class PurchaseOrderReceiveTest(OrderTest): order.status = PurchaseOrderStatus.PENDING.value order.save() - response = self.post( - self.url, - valid_data, - expected_code=400 - ) + response = self.post(self.url, valid_data, expected_code=400) self.assertIn('can only be received against', str(response.data)) @@ -956,11 +883,7 @@ class PurchaseOrderReceiveTest(OrderTest): order.save() # Receive two separate line items against this order - self.post( - self.url, - valid_data, - expected_code=201, - ) + self.post(self.url, valid_data, expected_code=201) # There should be two newly created stock items self.assertEqual(self.n + 2, StockItem.objects.count()) @@ -983,8 +906,12 @@ class PurchaseOrderReceiveTest(OrderTest): self.assertEqual(stock_2.last().location.pk, 1) # Barcodes should have been assigned to the stock items - self.assertTrue(StockItem.objects.filter(barcode_data='MY-UNIQUE-BARCODE-123').exists()) - self.assertTrue(StockItem.objects.filter(barcode_data='MY-UNIQUE-BARCODE-456').exists()) + self.assertTrue( + StockItem.objects.filter(barcode_data='MY-UNIQUE-BARCODE-123').exists() + ) + self.assertTrue( + StockItem.objects.filter(barcode_data='MY-UNIQUE-BARCODE-456').exists() + ) def test_batch_code(self): """Test that we can supply a 'batch code' when receiving items.""" @@ -996,27 +923,15 @@ class PurchaseOrderReceiveTest(OrderTest): data = { 'items': [ - { - 'line_item': 1, - 'quantity': 10, - 'batch_code': 'B-abc-123', - }, - { - 'line_item': 2, - 'quantity': 10, - 'batch_code': 'B-xyz-789', - } + {'line_item': 1, 'quantity': 10, 'batch_code': 'B-abc-123'}, + {'line_item': 2, 'quantity': 10, 'batch_code': 'B-xyz-789'}, ], 'location': 1, } n = StockItem.objects.count() - self.post( - self.url, - data, - expected_code=201, - ) + self.post(self.url, data, expected_code=201) # Check that two new stock items have been created! self.assertEqual(n + 2, StockItem.objects.count()) @@ -1043,28 +958,22 @@ class PurchaseOrderReceiveTest(OrderTest): 'batch_code': 'B-abc-123', 'serial_numbers': '100+', }, - { - 'line_item': 2, - 'quantity': 10, - 'batch_code': 'B-xyz-789', - } + {'line_item': 2, 'quantity': 10, 'batch_code': 'B-xyz-789'}, ], 'location': 1, } n = StockItem.objects.count() - self.post( - self.url, - data, - expected_code=201, - ) + self.post(self.url, data, expected_code=201) # Check that the expected number of stock items has been created self.assertEqual(n + 11, StockItem.objects.count()) # 10 serialized stock items created for the first line item - self.assertEqual(StockItem.objects.filter(supplier_part=line_1.part).count(), 10) + self.assertEqual( + StockItem.objects.filter(supplier_part=line_1.part).count(), 10 + ) # Check that the correct serial numbers have been allocated for i in range(100, 110): @@ -1129,10 +1038,7 @@ class SalesOrderTest(OrderTest): for customer in Company.objects.filter(is_customer=True): for _idx in range(10): new_orders.append( - models.SalesOrder( - customer=customer, - reference=f'SO-{idx + 100}', - ) + models.SalesOrder(customer=customer, reference=f'SO-{idx + 100}') ) idx += 1 @@ -1153,7 +1059,7 @@ class SalesOrderTest(OrderTest): order=so, part=p, quantity=idx + 1, - sale_price=Money((idx + 1) / 5, currencies[idx % n]) + sale_price=Money((idx + 1) / 5, currencies[idx % n]), ) ) @@ -1163,9 +1069,7 @@ class SalesOrderTest(OrderTest): for _ in range(3): extra_lines.append( models.SalesOrderExtraLine( - order=so, - quantity=(idx + 2) % 10, - price=Money(10, 'CAD'), + order=so, quantity=(idx + 2) % 10, price=Money(10, 'CAD') ) ) @@ -1175,7 +1079,9 @@ class SalesOrderTest(OrderTest): # List all SalesOrder objects and count queries for limit in [1, 5, 10, 100]: with CaptureQueriesContext(connection) as ctx: - response = self.get(self.LIST_URL, data={'limit': limit}, expected_code=200) + response = self.get( + self.LIST_URL, data={'limit': limit}, expected_code=200 + ) # Total database queries must be less than 15 self.assertLess(len(ctx), 15) @@ -1225,11 +1131,7 @@ class SalesOrderTest(OrderTest): # so this POST request should return 403 (denied) response = self.post( url, - { - 'customer': 4, - 'reference': '12345', - 'description': 'Sales order', - }, + {'customer': 4, 'reference': '12345', 'description': 'Sales order'}, expected_code=403, ) @@ -1238,12 +1140,8 @@ class SalesOrderTest(OrderTest): # Now we should be able to create a SalesOrder via the API response = self.post( url, - { - 'customer': 4, - 'reference': 'SO-12345', - 'description': 'Sales order', - }, - expected_code=201 + {'customer': 4, 'reference': 'SO-12345', 'description': 'Sales order'}, + expected_code=201, ) # Check that the new order has been created @@ -1260,7 +1158,7 @@ class SalesOrderTest(OrderTest): 'reference': 'SO-12345', 'description': 'Another sales order', }, - expected_code=400 + expected_code=400, ) url = reverse('api-so-detail', kwargs={'pk': pk}) @@ -1271,21 +1169,9 @@ class SalesOrderTest(OrderTest): # Try to alter (edit) the SalesOrder # Initially try with an invalid reference field value - response = self.patch( - url, - { - 'reference': 'SO-12345-a', - }, - expected_code=400 - ) + response = self.patch(url, {'reference': 'SO-12345-a'}, expected_code=400) - response = self.patch( - url, - { - 'reference': 'SO-12346', - }, - expected_code=200 - ) + response = self.patch(url, {'reference': 'SO-12346'}, expected_code=200) # Reference should have changed self.assertEqual(response.data['reference'], 'SO-12346') @@ -1321,7 +1207,9 @@ class SalesOrderTest(OrderTest): expected_code=400, ) - self.assertIn('Reference must match required pattern', str(response.data['reference'])) + self.assertIn( + 'Reference must match required pattern', str(response.data['reference']) + ) self.post( url, @@ -1366,15 +1254,19 @@ class SalesOrderTest(OrderTest): 'description': f'Calendar SO {i}', 'target_date': f'2024-12-{i:02d}', }, - expected_code=201 + expected_code=201, ) # Cancel a few orders - these will not show in incomplete view below for so in models.SalesOrder.objects.filter(target_date__isnull=False): - if so.reference in ['SO-11000006', 'SO-11000007', 'SO-11000008', 'SO-11000009']: + if so.reference in [ + 'SO-11000006', + 'SO-11000007', + 'SO-11000008', + 'SO-11000009', + ]: self.post( - reverse('api-so-cancel', kwargs={'pk': so.pk}), - expected_code=201 + reverse('api-so-cancel', kwargs={'pk': so.pk}), expected_code=201 ) url = reverse('api-po-so-calendar', kwargs={'ordertype': 'sales-order'}) @@ -1382,7 +1274,11 @@ class SalesOrderTest(OrderTest): # Test without completed orders response = self.get(url, expected_code=200, format=None) - number_orders = len(models.SalesOrder.objects.filter(target_date__isnull=False).filter(status__lt=SalesOrderStatus.SHIPPED.value)) + number_orders = len( + models.SalesOrder.objects.filter(target_date__isnull=False).filter( + status__lt=SalesOrderStatus.SHIPPED.value + ) + ) # Transform content to a Calendar object calendar = Calendar.from_ical(response.content) @@ -1396,9 +1292,13 @@ class SalesOrderTest(OrderTest): self.assertEqual(number_orders, n_events) # Test with completed orders - response = self.get(url, data={'include_completed': 'True'}, expected_code=200, format=None) + response = self.get( + url, data={'include_completed': 'True'}, expected_code=200, format=None + ) - number_orders_incl_complete = len(models.SalesOrder.objects.filter(target_date__isnull=False)) + number_orders_incl_complete = len( + models.SalesOrder.objects.filter(target_date__isnull=False) + ) self.assertGreater(number_orders_incl_complete, number_orders) # Transform content to a Calendar object @@ -1430,7 +1330,7 @@ class SalesOrderTest(OrderTest): {'export': fmt}, decode=True if fmt == 'csv' else False, expected_code=200, - expected_fn=f"InvenTree_SalesOrders.{fmt}" + expected_fn=f'InvenTree_SalesOrders.{fmt}', ) @@ -1451,14 +1351,13 @@ class SalesOrderLineItemTest(OrderTest): # Create a bunch of SalesOrderLineItems for each order for idx, so in enumerate(models.SalesOrder.objects.all()): - for part in parts: lines.append( models.SalesOrderLineItem( order=so, part=part, quantity=(idx + 1) * 5, - reference=f"Order {so.reference} - line {idx}", + reference=f'Order {so.reference} - line {idx}', ) ) @@ -1469,11 +1368,7 @@ class SalesOrderLineItemTest(OrderTest): def test_so_line_list(self): """Test list endpoint""" - response = self.get( - self.url, - {}, - expected_code=200, - ) + response = self.get(self.url, {}, expected_code=200) n = models.SalesOrderLineItem.objects.count() @@ -1481,13 +1376,7 @@ class SalesOrderLineItemTest(OrderTest): self.assertEqual(len(response.data), n) # List *all* lines, but paginate - response = self.get( - self.url, - { - "limit": 5, - }, - expected_code=200, - ) + response = self.get(self.url, {'limit': 5}, expected_code=200) self.assertEqual(response.data['count'], n) self.assertEqual(len(response.data['results']), 5) @@ -1497,25 +1386,13 @@ class SalesOrderLineItemTest(OrderTest): # List by part for part in Part.objects.filter(salable=True)[:3]: - response = self.get( - self.url, - { - 'part': part.pk, - 'limit': 10, - } - ) + response = self.get(self.url, {'part': part.pk, 'limit': 10}) self.assertEqual(response.data['count'], n_orders) # List by order for order in models.SalesOrder.objects.all()[:3]: - response = self.get( - self.url, - { - 'order': order.pk, - 'limit': 10, - } - ) + response = self.get(self.url, {'order': order.pk, 'limit': 10}) self.assertEqual(response.data['count'], n_parts) @@ -1545,9 +1422,7 @@ class SalesOrderDownloadTest(OrderTest): # Download .xls file with self.download_file( url, - { - 'export': 'xls', - }, + {'export': 'xls'}, expected_code=200, expected_fn='InvenTree_SalesOrders.xls', decode=False, @@ -1569,30 +1444,24 @@ class SalesOrderDownloadTest(OrderTest): 'description', ] - excluded_cols = [ - 'metadata' - ] + excluded_cols = ['metadata'] # Download .xls file with self.download_file( url, - { - 'export': 'csv', - }, + {'export': 'csv'}, expected_code=200, expected_fn='InvenTree_SalesOrders.csv', - decode=True + decode=True, ) as file: - data = self.process_csv( file, required_cols=required_cols, excluded_cols=excluded_cols, - required_rows=models.SalesOrder.objects.count() + required_rows=models.SalesOrder.objects.count(), ) for line in data: - order = models.SalesOrder.objects.get(pk=line['id']) self.assertEqual(line['description'], order.description) @@ -1601,20 +1470,18 @@ class SalesOrderDownloadTest(OrderTest): # Download only outstanding sales orders with self.download_file( url, - { - 'export': 'tsv', - 'outstanding': True, - }, + {'export': 'tsv', 'outstanding': True}, expected_code=200, expected_fn='InvenTree_SalesOrders.tsv', decode=True, ) as file: - self.process_csv( file, required_cols=required_cols, excluded_cols=excluded_cols, - required_rows=models.SalesOrder.objects.filter(status__in=SalesOrderStatusGroups.OPEN).count(), + required_rows=models.SalesOrder.objects.filter( + status__in=SalesOrderStatusGroups.OPEN + ).count(), delimiter='\t', ) @@ -1636,24 +1503,16 @@ class SalesOrderAllocateTest(OrderTest): parts = Part.objects.filter(salable=True) for part in parts: - # Create a new line item models.SalesOrderLineItem.objects.create( - order=self.order, - part=part, - quantity=5, + order=self.order, part=part, quantity=5 ) # Ensure we have stock! - StockItem.objects.create( - part=part, - quantity=100, - ) + StockItem.objects.create(part=part, quantity=100) # Create a new shipment against this SalesOrder - self.shipment = models.SalesOrderShipment.objects.create( - order=self.order, - ) + self.shipment = models.SalesOrderShipment.objects.create(order=self.order) def test_invalid(self): """Test POST with invalid data.""" @@ -1669,11 +1528,13 @@ class SalesOrderAllocateTest(OrderTest): # Valid stock_item, but quantity is invalid data = { - 'items': [{ - "line_item": line.pk, - "stock_item": part.stock_items.last().pk, - "quantity": 0, - }], + 'items': [ + { + 'line_item': line.pk, + 'stock_item': part.stock_items.last().pk, + 'quantity': 0, + } + ] } response = self.post(self.url, data, expected_code=400) @@ -1699,33 +1560,32 @@ class SalesOrderAllocateTest(OrderTest): # Valid shipment, but points to the wrong order shipment = models.SalesOrderShipment.objects.create( - order=models.SalesOrder.objects.get(pk=2), + order=models.SalesOrder.objects.get(pk=2) ) data['shipment'] = shipment.pk response = self.post(self.url, data, expected_code=400) - self.assertIn('Shipment is not associated with this order', str(response.data['shipment'])) + self.assertIn( + 'Shipment is not associated with this order', str(response.data['shipment']) + ) def test_allocate(self): """Test that the allocation endpoint acts as expected, when provided with valid data!""" # First, check that there are no line items allocated against this SalesOrder self.assertEqual(self.order.stock_allocations.count(), 0) - data = { - "items": [], - "shipment": self.shipment.pk, - } + data = {'items': [], 'shipment': self.shipment.pk} for line in self.order.lines.all(): stock_item = line.part.stock_items.last() # Fully-allocate each line data['items'].append({ - "line_item": line.pk, - "stock_item": stock_item.pk, - "quantity": 5 + 'line_item': line.pk, + 'stock_item': stock_item.pk, + 'quantity': 5, }) self.post(self.url, data, expected_code=201) @@ -1743,16 +1603,12 @@ class SalesOrderAllocateTest(OrderTest): # First, check that there are no line items allocated against this SalesOrder self.assertEqual(self.order.stock_allocations.count(), 0) - data = { - "items": [], - "shipment": self.shipment.pk, - } + data = {'items': [], 'shipment': self.shipment.pk} def check_template(line_item): return line_item.part.is_template for line in filter(check_template, self.order.lines.all()): - stock_item = None # Allocate a matching variant @@ -1763,9 +1619,9 @@ class SalesOrderAllocateTest(OrderTest): # Fully-allocate each line data['items'].append({ - "line_item": line.pk, - "stock_item": stock_item.pk, - "quantity": 5 + 'line_item': line.pk, + 'stock_item': stock_item.pk, + 'quantity': 5, }) self.post(self.url, data, expected_code=201) @@ -1786,10 +1642,7 @@ class SalesOrderAllocateTest(OrderTest): self.shipment.check_can_complete() # Attempting to complete this shipment via the API should fail - response = self.post( - url, {}, - expected_code=400 - ) + response = self.post(url, {}, expected_code=400) self.assertIn('Shipment has no allocated stock items', str(response.data)) @@ -1798,23 +1651,14 @@ class SalesOrderAllocateTest(OrderTest): part = line.part models.SalesOrderAllocation.objects.create( - shipment=self.shipment, - line=line, - item=part.stock_items.last(), - quantity=5 + shipment=self.shipment, line=line, item=part.stock_items.last(), quantity=5 ) # Shipment should now be able to be completed self.assertTrue(self.shipment.check_can_complete()) # Attempt with an invalid date - response = self.post( - url, - { - 'shipment_date': 'asfasd', - }, - expected_code=400, - ) + response = self.post(url, {'shipment_date': 'asfasd'}, expected_code=400) self.assertIn('Date has wrong format', str(response.data)) @@ -1844,35 +1688,17 @@ class SalesOrderAllocateTest(OrderTest): url = reverse('api-so-shipment-detail', kwargs={'pk': self.shipment.pk}) # Attempt remove delivery_date from shipment - response = self.patch( - url, - { - 'delivery_date': None, - }, - expected_code=200, - ) + response = self.patch(url, {'delivery_date': None}, expected_code=200) # Shipment should not be marked as delivered self.assertFalse(self.shipment.is_delivered()) # Attempt to set delivery date - response = self.patch( - url, - { - 'delivery_date': 'asfasd', - }, - expected_code=400, - ) + response = self.patch(url, {'delivery_date': 'asfasd'}, expected_code=400) self.assertIn('Date has wrong format', str(response.data)) - response = self.patch( - url, - { - 'delivery_date': '2023-05-15', - }, - expected_code=200, - ) + response = self.patch(url, {'delivery_date': '2023-05-15'}, expected_code=200) self.shipment.refresh_from_db() # Shipment should now be marked as delivered @@ -1888,26 +1714,19 @@ class SalesOrderAllocateTest(OrderTest): # Create some new shipments via the API for order in models.SalesOrder.objects.all(): - for idx in range(3): self.post( url, { 'order': order.pk, - 'reference': f"SH{idx + 1}", - 'tracking_number': f"TRK_{order.pk}_{idx}" + 'reference': f'SH{idx + 1}', + 'tracking_number': f'TRK_{order.pk}_{idx}', }, - expected_code=201 + expected_code=201, ) # Filter API by order - response = self.get( - url, - { - 'order': order.pk, - }, - expected_code=200, - ) + response = self.get(url, {'order': order.pk}, expected_code=200) # 3 shipments returned for each SalesOrder instance self.assertGreaterEqual(len(response.data), 3) @@ -1915,7 +1734,9 @@ class SalesOrderAllocateTest(OrderTest): # List *all* shipments response = self.get(url, expected_code=200) - self.assertEqual(len(response.data), count_before + 3 * models.SalesOrder.objects.count()) + self.assertEqual( + len(response.data), count_before + 3 * models.SalesOrder.objects.count() + ) class ReturnOrderTests(InvenTreeAPITestCase): @@ -1959,12 +1780,8 @@ class ReturnOrderTests(InvenTreeAPITestCase): # Paginated query data = self.get( url, - { - 'limit': 1, - 'ordering': 'reference', - 'customer_detail': True, - }, - expected_code=200 + {'limit': 1, 'ordering': 'reference', 'customer_detail': True}, + expected_code=200, ).data self.assertEqual(data['count'], 6) @@ -1974,25 +1791,13 @@ class ReturnOrderTests(InvenTreeAPITestCase): self.assertEqual(result['customer_detail']['name'], 'A customer') # Reverse ordering - data = self.get( - url, - { - 'ordering': '-reference', - }, - expected_code=200 - ).data + data = self.get(url, {'ordering': '-reference'}, expected_code=200).data self.assertEqual(data[0]['reference'], 'RMA-006') # Filter by customer for cmp_id in [4, 5]: - data = self.get( - url, - { - 'customer': cmp_id, - }, - expected_code=200 - ).data + data = self.get(url, {'customer': cmp_id}, expected_code=200).data self.assertEqual(len(data), 3) @@ -2000,13 +1805,7 @@ class ReturnOrderTests(InvenTreeAPITestCase): self.assertEqual(result['customer'], cmp_id) # Filter by status - data = self.get( - url, - { - 'status': 20, - }, - expected_code=200 - ).data + data = self.get(url, {'status': 20}, expected_code=200).data self.assertEqual(len(data), 2) @@ -2019,12 +1818,7 @@ class ReturnOrderTests(InvenTreeAPITestCase): # Do not have required permissions yet self.post( - url, - { - 'customer': 1, - 'description': 'a return order', - }, - expected_code=403 + url, {'customer': 1, 'description': 'a return order'}, expected_code=403 ) self.assignRole('return_order.add') @@ -2036,7 +1830,7 @@ class ReturnOrderTests(InvenTreeAPITestCase): 'customer_reference': 'cr', 'description': 'a return order', }, - expected_code=201 + expected_code=201, ).data # Reference automatically generated @@ -2054,22 +1848,12 @@ class ReturnOrderTests(InvenTreeAPITestCase): # Attempt to update, incorrect permissions self.patch( - url, - { - 'customer_reference': 'My customer reference', - }, - expected_code=403 + url, {'customer_reference': 'My customer reference'}, expected_code=403 ) self.assignRole('return_order.change') - self.patch( - url, - { - 'customer_reference': 'customer ref', - }, - expected_code=200 - ) + self.patch(url, {'customer_reference': 'customer ref'}, expected_code=200) rma = models.ReturnOrder.objects.get(pk=1) self.assertEqual(rma.customer_reference, 'customer ref') @@ -2098,8 +1882,7 @@ class ReturnOrderTests(InvenTreeAPITestCase): # Create an order rma = models.ReturnOrder.objects.create( - customer=customer, - description='A return order', + customer=customer, description='A return order' ) self.assertEqual(rma.reference, 'RMA-0007') @@ -2108,13 +1891,11 @@ class ReturnOrderTests(InvenTreeAPITestCase): part = Part.objects.get(pk=25) for idx in range(3): stock_item = StockItem.objects.create( - part=part, customer=customer, - quantity=1, serial=idx + part=part, customer=customer, quantity=1, serial=idx ) line_item = models.ReturnOrderLineItem.objects.create( - order=rma, - item=stock_item, + order=rma, item=stock_item ) self.assertEqual(line_item.outcome, ReturnOrderLineStatus.PENDING) @@ -2128,12 +1909,7 @@ class ReturnOrderTests(InvenTreeAPITestCase): url = reverse('api-return-order-receive', kwargs={'pk': rma.pk}) response = self.post( - url, - { - 'items': items, - 'location': location, - }, - expected_code=expected_code + url, {'items': items, 'location': location}, expected_code=expected_code ) return response.data @@ -2149,13 +1925,14 @@ class ReturnOrderTests(InvenTreeAPITestCase): # Receive, with incorrect order code data = receive([], 1, expected_code=400) - self.assertIn('Items can only be received against orders which are in progress', str(data)) + self.assertIn( + 'Items can only be received against orders which are in progress', str(data) + ) # Issue the order (via the API) self.assertIsNone(rma.issue_date) self.post( - reverse("api-return-order-issue", kwargs={"pk": rma.pk}), - expected_code=201, + reverse('api-return-order-issue', kwargs={'pk': rma.pk}), expected_code=201 ) rma.refresh_from_db() @@ -2173,9 +1950,7 @@ class ReturnOrderTests(InvenTreeAPITestCase): # Receive items successfully data = receive( - [{'item': line.pk} for line in rma.lines.all()], - 1, - expected_code=201 + [{'item': line.pk} for line in rma.lines.all()], 1, expected_code=201 ) # Check that all line items have been received @@ -2221,11 +1996,7 @@ class OrderMetadataAPITest(InvenTreeAPITestCase): 'return_order', ] - roles = [ - 'purchase_order.change', - 'sales_order.change', - 'return_order.change', - ] + roles = ['purchase_order.change', 'sales_order.change', 'return_order.change'] def metatester(self, apikey, model): """Generic tester""" @@ -2243,17 +2014,15 @@ class OrderMetadataAPITest(InvenTreeAPITestCase): self.patch( url, - { - 'metadata': { - f'abc-{numstr}': f'xyz-{apikey}-{numstr}', - } - }, - expected_code=200 + {'metadata': {f'abc-{numstr}': f'xyz-{apikey}-{numstr}'}}, + expected_code=200, ) # Refresh modeldata.refresh_from_db() - self.assertEqual(modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}') + self.assertEqual( + modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}' + ) def test_metadata(self): """Test all endpoints""" diff --git a/InvenTree/order/test_migrations.py b/InvenTree/order/test_migrations.py index 8bcdbeae8e..407468cd51 100644 --- a/InvenTree/order/test_migrations.py +++ b/InvenTree/order/test_migrations.py @@ -28,11 +28,10 @@ class TestRefIntMigrations(MigratorTestCase): # Create some orders for ii in range(10): - order = PurchaseOrder.objects.create( supplier=supplier, - reference=f"{ii}-abcde", - description="Just a test order" + reference=f'{ii}-abcde', + description='Just a test order', ) # Initially, the 'reference_int' field is unavailable @@ -41,8 +40,8 @@ class TestRefIntMigrations(MigratorTestCase): sales_order = SalesOrder.objects.create( customer=supplier, - reference=f"{ii}-xyz", - description="A test sales order", + reference=f'{ii}-xyz', + description='A test sales order', ) # Initially, the 'reference_int' field is unavailable @@ -68,9 +67,8 @@ class TestRefIntMigrations(MigratorTestCase): SalesOrder = self.new_state.apps.get_model('order', 'salesorder') for ii in range(10): - - po = PurchaseOrder.objects.get(reference=f"{ii}-abcde") - so = SalesOrder.objects.get(reference=f"{ii}-xyz") + po = PurchaseOrder.objects.get(reference=f'{ii}-abcde') + so = SalesOrder.objects.get(reference=f'{ii}-xyz') # The integer reference field must have been correctly updated self.assertEqual(po.reference_int, ii) @@ -79,11 +77,11 @@ class TestRefIntMigrations(MigratorTestCase): # Tests for orders with overly large reference values po = PurchaseOrder.objects.get(pk=self.po_pk) self.assertEqual(po.reference, '999999999999999999999999999999999') - self.assertEqual(po.reference_int, 0x7fffffff) + self.assertEqual(po.reference_int, 0x7FFFFFFF) so = SalesOrder.objects.get(pk=self.so_pk) self.assertEqual(so.reference, '999999999999999999999999999999999') - self.assertEqual(so.reference_int, 0x7fffffff) + self.assertEqual(so.reference_int, 0x7FFFFFFF) class TestShipmentMigration(MigratorTestCase): @@ -99,7 +97,7 @@ class TestShipmentMigration(MigratorTestCase): customer = Company.objects.create( name='My customer', description='A customer we sell stuff too', - is_customer=True + is_customer=True, ) SalesOrder = self.old_state.apps.get_model('order', 'salesorder') @@ -162,28 +160,17 @@ class TestAdditionalLineMigration(MigratorTestCase): lft=0, rght=0, ) - supplierpart = Supplierpart.objects.create( - part=part, - supplier=supplier - ) + supplierpart = Supplierpart.objects.create(part=part, supplier=supplier) # Create some orders for ii in range(10): - order = PurchaseOrder.objects.create( supplier=supplier, - reference=f"{ii}-abcde", - description="Just a test order" - ) - order.lines.create( - part=supplierpart, - quantity=12, - received=1 - ) - order.lines.create( - quantity=12, - received=1 + reference=f'{ii}-abcde', + description='Just a test order', ) + order.lines.create(part=supplierpart, quantity=12, received=1) + order.lines.create(quantity=12, received=1) # TODO @matmair fix this test!!! # sales_order = SalesOrder.objects.create( @@ -201,14 +188,13 @@ class TestAdditionalLineMigration(MigratorTestCase): """Test that the the PO lines where converted correctly.""" PurchaseOrder = self.new_state.apps.get_model('order', 'purchaseorder') for ii in range(10): - - po = PurchaseOrder.objects.get(reference=f"{ii}-abcde") + po = PurchaseOrder.objects.get(reference=f'{ii}-abcde') self.assertEqual(po.extra_lines.count(), 1) self.assertEqual(po.lines.count(), 1) # TODO @matmair fix this test!!! # SalesOrder = self.new_state.apps.get_model('order', 'salesorder') # for ii in range(10): - # so = SalesOrder.objects.get(reference=f"{ii}-xyz") - # self.assertEqual(so.extra_lines, 1) - # self.assertEqual(so.lines.count(), 1) + # so = SalesOrder.objects.get(reference=f"{ii}-xyz") + # self.assertEqual(so.extra_lines, 1) + # self.assertEqual(so.lines.count(), 1) diff --git a/InvenTree/order/test_sales_order.py b/InvenTree/order/test_sales_order.py index 3bbf33c46c..9fda6261de 100644 --- a/InvenTree/order/test_sales_order.py +++ b/InvenTree/order/test_sales_order.py @@ -11,9 +11,13 @@ import order.tasks from common.models import InvenTreeSetting, NotificationMessage from company.models import Company from InvenTree import status_codes as status -from order.models import (SalesOrder, SalesOrderAllocation, - SalesOrderExtraLine, SalesOrderLineItem, - SalesOrderShipment) +from order.models import ( + SalesOrder, + SalesOrderAllocation, + SalesOrderExtraLine, + SalesOrderLineItem, + SalesOrderShipment, +) from part.models import Part from stock.models import StockItem from users.models import Owner @@ -22,15 +26,15 @@ from users.models import Owner class SalesOrderTest(TestCase): """Run tests to ensure that the SalesOrder model is working correctly.""" - fixtures = [ - 'users', - ] + fixtures = ['users'] @classmethod def setUpTestData(cls): """Initial setup for this set of unit tests""" # Create a Company to ship the goods to - cls.customer = Company.objects.create(name="ABC Co", description="My customer", is_customer=True) + cls.customer = Company.objects.create( + name='ABC Co', description='My customer', is_customer=True + ) # Create a Part to ship cls.part = Part.objects.create( @@ -53,22 +57,23 @@ class SalesOrderTest(TestCase): # Create a SalesOrder to ship against cls.order = SalesOrder.objects.create( - customer=cls.customer, - reference='SO-1234', - customer_reference='ABC 55555' + customer=cls.customer, reference='SO-1234', customer_reference='ABC 55555' ) # Create a Shipment against this SalesOrder cls.shipment = SalesOrderShipment.objects.create( - order=cls.order, - reference='SO-001', + order=cls.order, reference='SO-001' ) # Create a line item - cls.line = SalesOrderLineItem.objects.create(quantity=50, order=cls.order, part=cls.part) + cls.line = SalesOrderLineItem.objects.create( + quantity=50, order=cls.order, part=cls.part + ) # Create an extra line - cls.extraline = SalesOrderExtraLine.objects.create(quantity=1, order=cls.order, reference="Extra line") + cls.extraline = SalesOrderExtraLine.objects.create( + quantity=1, order=cls.order, reference='Extra line' + ) def test_so_reference(self): """Unit tests for sales order generation""" @@ -120,7 +125,9 @@ class SalesOrderTest(TestCase): def test_add_duplicate_line_item(self): """Adding a duplicate line item to a SalesOrder is accepted""" for ii in range(1, 5): - SalesOrderLineItem.objects.create(order=self.order, part=self.part, quantity=ii) + SalesOrderLineItem.objects.create( + order=self.order, part=self.part, quantity=ii + ) def allocate_stock(self, full=True): """Allocate stock to the order""" @@ -128,13 +135,14 @@ class SalesOrderTest(TestCase): line=self.line, shipment=self.shipment, item=StockItem.objects.get(pk=self.Sa.pk), - quantity=25) + quantity=25, + ) SalesOrderAllocation.objects.create( line=self.line, shipment=self.shipment, item=StockItem.objects.get(pk=self.Sb.pk), - quantity=25 if full else 20 + quantity=25 if full else 20, ) def test_over_allocate(self): @@ -145,10 +153,7 @@ class SalesOrderTest(TestCase): # First three allocations should succeed for _i in range(3): allocation = SalesOrderAllocation.objects.create( - line=self.line, - item=SA, - quantity=3, - shipment=self.shipment + line=self.line, item=SA, quantity=3, shipment=self.shipment ) # Editing an existing allocation with a larger quantity should fail @@ -160,10 +165,7 @@ class SalesOrderTest(TestCase): # Next allocation should fail with self.assertRaises(ValidationError): allocation = SalesOrderAllocation.objects.create( - line=self.line, - item=SA, - quantity=3, - shipment=self.shipment + line=self.line, item=SA, quantity=3, shipment=self.shipment ) allocation.clean() @@ -191,7 +193,7 @@ class SalesOrderTest(TestCase): line=self.line, shipment=self.shipment, item=StockItem.objects.get(pk=self.Sc.pk), - quantity=50 + quantity=50, ) self.assertEqual(self.line.allocated_quantity(), 50) @@ -279,13 +281,13 @@ class SalesOrderTest(TestCase): def test_default_shipment(self): """Test sales order default shipment creation""" # Default setting value should be False - self.assertEqual(False, InvenTreeSetting.get_setting('SALESORDER_DEFAULT_SHIPMENT')) + self.assertEqual( + False, InvenTreeSetting.get_setting('SALESORDER_DEFAULT_SHIPMENT') + ) # Create an order order_1 = SalesOrder.objects.create( - customer=self.customer, - reference='1235', - customer_reference='ABC 55556' + customer=self.customer, reference='1235', customer_reference='ABC 55556' ) # Order should have no shipments when setting is False @@ -293,13 +295,13 @@ class SalesOrderTest(TestCase): # Update setting to True InvenTreeSetting.set_setting('SALESORDER_DEFAULT_SHIPMENT', True, None) - self.assertEqual(True, InvenTreeSetting.get_setting('SALESORDER_DEFAULT_SHIPMENT')) + self.assertEqual( + True, InvenTreeSetting.get_setting('SALESORDER_DEFAULT_SHIPMENT') + ) # Create a second order order_2 = SalesOrder.objects.create( - customer=self.customer, - reference='1236', - customer_reference='ABC 55557' + customer=self.customer, reference='1236', customer_reference='ABC 55557' ) # Order should have one shipment @@ -326,7 +328,7 @@ class SalesOrderTest(TestCase): order.tasks.check_overdue_sales_orders() messages = NotificationMessage.objects.filter( - category='order.overdue_sales_order', + category='order.overdue_sales_order' ) self.assertEqual(len(messages), 1) @@ -341,12 +343,10 @@ class SalesOrderTest(TestCase): customer=self.customer, reference='1234567', created_by=get_user_model().objects.get(pk=3), - responsible=Owner.create(obj=Group.objects.get(pk=3)) + responsible=Owner.create(obj=Group.objects.get(pk=3)), ) - messages = NotificationMessage.objects.filter( - category='order.new_salesorder', - ) + messages = NotificationMessage.objects.filter(category='order.new_salesorder') # A notification should have been generated for user 4 (who is a member of group 3) self.assertTrue(messages.filter(user__pk=4).exists()) @@ -356,7 +356,12 @@ class SalesOrderTest(TestCase): def test_metadata(self): """Unit tests for the metadata field.""" - for model in [SalesOrder, SalesOrderLineItem, SalesOrderExtraLine, SalesOrderShipment]: + for model in [ + SalesOrder, + SalesOrderLineItem, + SalesOrderExtraLine, + SalesOrderShipment, + ]: p = model.objects.first() self.assertIsNone(p.get_metadata('test')) diff --git a/InvenTree/order/test_views.py b/InvenTree/order/test_views.py index bb1e5bddf8..f6ebdf6272 100644 --- a/InvenTree/order/test_views.py +++ b/InvenTree/order/test_views.py @@ -7,6 +7,7 @@ from InvenTree.unit_test import InvenTreeTestCase class OrderViewTestCase(InvenTreeTestCase): """Base unit test class for order views""" + fixtures = [ 'category', 'part', @@ -35,6 +36,7 @@ class OrderViewTestCase(InvenTreeTestCase): class PurchaseOrderListTest(OrderViewTestCase): """Unit tests for the PurchaseOrder index page""" + def test_order_list(self): """Tests for the PurchaseOrder index page""" response = self.client.get(reverse('purchase-order-index')) @@ -54,7 +56,9 @@ class PurchaseOrderTests(OrderViewTestCase): def test_po_export(self): """Export PurchaseOrder.""" - response = self.client.get(reverse('po-export', args=(1,)), HTTP_X_REQUESTED_WITH='XMLHttpRequest') + response = self.client.get( + reverse('po-export', args=(1,)), HTTP_X_REQUESTED_WITH='XMLHttpRequest' + ) # Response should be streaming-content (file download) self.assertIn('streaming_content', dir(response)) diff --git a/InvenTree/order/tests.py b/InvenTree/order/tests.py index 5b9ee89b0d..bdfc327f86 100644 --- a/InvenTree/order/tests.py +++ b/InvenTree/order/tests.py @@ -18,8 +18,7 @@ from part.models import Part from stock.models import StockItem, StockLocation from users.models import Owner -from .models import (PurchaseOrder, PurchaseOrderExtraLine, - PurchaseOrderLineItem) +from .models import PurchaseOrder, PurchaseOrderExtraLine, PurchaseOrderLineItem class OrderTest(TestCase): @@ -40,7 +39,6 @@ class OrderTest(TestCase): def test_basics(self): """Basic tests e.g. repr functions etc.""" for pk in range(1, 8): - order = PurchaseOrder.objects.get(pk=pk) self.assertEqual(order.get_absolute_url(), f'/order/purchase-order/{pk}/') @@ -48,7 +46,7 @@ class OrderTest(TestCase): self.assertEqual(order.reference, f'PO-{pk:04d}') line = PurchaseOrderLineItem.objects.get(pk=1) - self.assertEqual(str(line), "100 x ACME0001 from ACME (for PO-0001 - ACME)") + self.assertEqual(str(line), '100 x ACME0001 from ACME (for PO-0001 - ACME)') def test_rebuild_reference(self): """Test that the reference_int field is correctly updated when the model is saved""" @@ -222,18 +220,12 @@ class OrderTest(TestCase): # Create a new supplier part with larger pack size sp_1 = SupplierPart.objects.create( - part=prt, - supplier=sup, - SKU='SKUx10', - pack_quantity='10', + part=prt, supplier=sup, SKU='SKUx10', pack_quantity='10' ) # Create a new supplier part with smaller pack size sp_2 = SupplierPart.objects.create( - part=prt, - supplier=sup, - SKU='SKUx0.1', - pack_quantity='0.1', + part=prt, supplier=sup, SKU='SKUx0.1', pack_quantity='0.1' ) # Record values before we start @@ -244,9 +236,7 @@ class OrderTest(TestCase): # Create a new PurchaseOrder po = PurchaseOrder.objects.create( - supplier=sup, - reference=f"PO-{n + 1}", - description='Some PO', + supplier=sup, reference=f'PO-{n + 1}', description='Some PO' ) # Add line items @@ -291,15 +281,11 @@ class OrderTest(TestCase): # The 'on_order' quantity should have decreased by 10.5 self.assertEqual( - prt.on_order, - round(on_order + Decimal(31.3) - Decimal(10.5), 1) + prt.on_order, round(on_order + Decimal(31.3) - Decimal(10.5), 1) ) # The 'in_stock' quantity should have increased by 10.5 - self.assertEqual( - prt.total_stock, - round(in_stock + Decimal(10.5), 1) - ) + self.assertEqual(prt.total_stock, round(in_stock + Decimal(10.5), 1)) # Check that the unit purchase price value has been updated correctly si = StockItem.objects.filter(supplier_part=sp_1) @@ -341,8 +327,7 @@ class OrderTest(TestCase): for user_id in [2, 3, 4]: messages = common.models.NotificationMessage.objects.filter( - category='order.overdue_purchase_order', - user__id=user_id, + category='order.overdue_purchase_order', user__id=user_id ) # User ID 3 is inactive, and thus should not receive notifications @@ -373,7 +358,7 @@ class OrderTest(TestCase): # Initially, no notifications messages = common.models.NotificationMessage.objects.filter( - category='order.new_purchaseorder', + category='order.new_purchaseorder' ) self.assertEqual(messages.count(), 0) diff --git a/InvenTree/order/urls.py b/InvenTree/order/urls.py index 7305bf8543..f1ae8b8446 100644 --- a/InvenTree/order/urls.py +++ b/InvenTree/order/urls.py @@ -9,34 +9,27 @@ from django.urls import include, path, re_path from . import views purchase_order_detail_urls = [ - re_path(r'^upload/', views.PurchaseOrderUpload.as_view(), name='po-upload'), re_path(r'^export/', views.PurchaseOrderExport.as_view(), name='po-export'), - re_path(r'^.*$', views.PurchaseOrderDetail.as_view(), name='po-detail'), ] purchase_order_urls = [ - re_path(r'^pricing/', views.LineItemPricing.as_view(), name='line-pricing'), - # Display detail view for a single purchase order path(r'/', include(purchase_order_detail_urls)), - # Display complete list of purchase orders re_path(r'^.*$', views.PurchaseOrderIndex.as_view(), name='purchase-order-index'), ] sales_order_detail_urls = [ re_path(r'^export/', views.SalesOrderExport.as_view(), name='so-export'), - re_path(r'^.*$', views.SalesOrderDetail.as_view(), name='so-detail'), ] sales_order_urls = [ # Display detail view for a single SalesOrder path(r'/', include(sales_order_detail_urls)), - # Display list of all sales orders re_path(r'^.*$', views.SalesOrderIndex.as_view(), name='sales-order-index'), ] @@ -44,7 +37,6 @@ sales_order_urls = [ return_order_urls = [ path(r'/', views.ReturnOrderDetail.as_view(), name='return-order-detail'), - # Display list of all return orders re_path(r'^.*$', views.ReturnOrderIndex.as_view(), name='return-order-index'), ] diff --git a/InvenTree/order/views.py b/InvenTree/order/views.py index 4130a06eaa..4c77374535 100644 --- a/InvenTree/order/views.py +++ b/InvenTree/order/views.py @@ -24,10 +24,15 @@ from plugin.views import InvenTreePluginViewMixin from . import forms as order_forms from .admin import PurchaseOrderLineItemResource, SalesOrderLineItemResource -from .models import (PurchaseOrder, PurchaseOrderLineItem, ReturnOrder, - SalesOrder, SalesOrderLineItem) +from .models import ( + PurchaseOrder, + PurchaseOrderLineItem, + ReturnOrder, + SalesOrder, + SalesOrderLineItem, +) -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') class PurchaseOrderIndex(InvenTreeRoleMixin, ListView): @@ -46,6 +51,7 @@ class PurchaseOrderIndex(InvenTreeRoleMixin, ListView): class SalesOrderIndex(InvenTreeRoleMixin, ListView): """SalesOrder index (list) view class""" + model = SalesOrder template_name = 'order/sales_orders.html' context_object_name = 'orders' @@ -71,7 +77,9 @@ class SalesOrderDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView) """Detail view for a SalesOrder object.""" context_object_name = 'order' - queryset = SalesOrder.objects.all().prefetch_related('lines__allocations__item__purchase_order') + queryset = SalesOrder.objects.all().prefetch_related( + 'lines__allocations__item__purchase_order' + ) template_name = 'order/sales_order_detail.html' @@ -88,20 +96,12 @@ class PurchaseOrderUpload(FileManagementFormView): class OrderFileManager(FileManager): """Specify required fields""" - REQUIRED_HEADERS = [ - 'Quantity', - ] - ITEM_MATCH_HEADERS = [ - 'Manufacturer_MPN', - 'Supplier_SKU', - ] + REQUIRED_HEADERS = ['Quantity'] - OPTIONAL_HEADERS = [ - 'Purchase_Price', - 'Reference', - 'Notes', - ] + ITEM_MATCH_HEADERS = ['Manufacturer_MPN', 'Supplier_SKU'] + + OPTIONAL_HEADERS = ['Purchase_Price', 'Reference', 'Notes'] name = 'order' form_list = [ @@ -115,9 +115,9 @@ class PurchaseOrderUpload(FileManagementFormView): 'order/order_wizard/match_parts.html', ] form_steps_description = [ - _("Upload File"), - _("Match Fields"), - _("Match Supplier Parts"), + _('Upload File'), + _('Match Fields'), + _('Match Supplier Parts'), ] form_field_map = { 'item_select': 'part', @@ -150,7 +150,9 @@ class PurchaseOrderUpload(FileManagementFormView): """ order = self.get_order() - self.allowed_items = SupplierPart.objects.filter(supplier=order.supplier).prefetch_related('manufacturer_part') + self.allowed_items = SupplierPart.objects.filter( + supplier=order.supplier + ).prefetch_related('manufacturer_part') # Fields prefixed with "Part_" can be used to do "smart matching" against Part objects in the database q_idx = self.get_column_index('Quantity') @@ -161,7 +163,6 @@ class PurchaseOrderUpload(FileManagementFormView): n_idx = self.get_column_index('Notes') for row in self.rows: - # Initially use a quantity of zero quantity = Decimal(0) @@ -191,7 +192,11 @@ class PurchaseOrderUpload(FileManagementFormView): try: # Attempt SupplierPart lookup based on SKU value exact_match_part = self.allowed_items.get(SKU__contains=sku) - except (ValueError, SupplierPart.DoesNotExist, SupplierPart.MultipleObjectsReturned): + except ( + ValueError, + SupplierPart.DoesNotExist, + SupplierPart.MultipleObjectsReturned, + ): exact_match_part = None # Check if there is a column corresponding to "Manufacturer MPN" and no exact match found yet @@ -200,8 +205,14 @@ class PurchaseOrderUpload(FileManagementFormView): try: # Attempt SupplierPart lookup based on MPN value - exact_match_part = self.allowed_items.get(manufacturer_part__MPN__contains=mpn) - except (ValueError, SupplierPart.DoesNotExist, SupplierPart.MultipleObjectsReturned): + exact_match_part = self.allowed_items.get( + manufacturer_part__MPN__contains=mpn + ) + except ( + ValueError, + SupplierPart.DoesNotExist, + SupplierPart.MultipleObjectsReturned, + ): exact_match_part = None # Supply list of part options for each row, sorted by how closely they match the part name @@ -239,7 +250,9 @@ class PurchaseOrderUpload(FileManagementFormView): # Create PurchaseOrderLineItem instances for purchase_order_item in items.values(): try: - supplier_part = SupplierPart.objects.get(pk=int(purchase_order_item['part'])) + supplier_part = SupplierPart.objects.get( + pk=int(purchase_order_item['part']) + ) except (ValueError, SupplierPart.DoesNotExist): continue @@ -259,7 +272,9 @@ class PurchaseOrderUpload(FileManagementFormView): # PurchaseOrderLineItem already exists pass - return HttpResponseRedirect(reverse('po-detail', kwargs={'pk': self.kwargs['pk']})) + return HttpResponseRedirect( + reverse('po-detail', kwargs={'pk': self.kwargs['pk']}) + ) class SalesOrderExport(AjaxView): @@ -279,7 +294,7 @@ class SalesOrderExport(AjaxView): export_format = request.GET.get('format', 'csv') - filename = f"{str(order)} - {order.customer.name}.{export_format}" + filename = f'{str(order)} - {order.customer.name}.{export_format}' dataset = SalesOrderLineItemResource().export(queryset=order.lines.all()) @@ -320,6 +335,7 @@ class LineItemPricing(PartPricing): class EnhancedForm(PartPricing.form_class): """Extra form options""" + pk = IntegerField(widget=HiddenInput()) so_line = IntegerField(widget=HiddenInput()) @@ -401,7 +417,9 @@ class LineItemPricing(PartPricing): # check qunatity and update if different if so_line.quantity != quantity: so_line.quantity = quantity - note = _('Updated {part} unit-price to {price} and quantity to {qty}') + note = _( + 'Updated {part} unit-price to {price} and quantity to {qty}' + ) # update sale_price so_line.sale_price = price @@ -410,7 +428,11 @@ class LineItemPricing(PartPricing): # parse response data = { 'form_valid': True, - 'success': note.format(part=str(so_line.part), price=str(so_line.sale_price), qty=quantity) + 'success': note.format( + part=str(so_line.part), + price=str(so_line.sale_price), + qty=quantity, + ), } return JsonResponse(data=data) diff --git a/InvenTree/part/admin.py b/InvenTree/part/admin.py index c41622c5c6..9b4e81f75a 100644 --- a/InvenTree/part/admin.py +++ b/InvenTree/part/admin.py @@ -18,53 +18,145 @@ class PartResource(InvenTreeResource): class Meta: """Metaclass definition""" + model = models.Part skip_unchanged = True report_skipped = False clean_model_instances = True exclude = [ - 'bom_checksum', 'bom_checked_by', 'bom_checked_date', - 'lft', 'rght', 'tree_id', 'level', + 'bom_checksum', + 'bom_checked_by', + 'bom_checked_date', + 'lft', + 'rght', + 'tree_id', + 'level', 'metadata', - 'barcode_data', 'barcode_hash', + 'barcode_data', + 'barcode_hash', ] id = Field(attribute='pk', column_name=_('Part ID'), widget=widgets.IntegerWidget()) - name = Field(attribute='name', column_name=_('Part Name'), widget=widgets.CharWidget()) - description = Field(attribute='description', column_name=_('Part Description'), widget=widgets.CharWidget()) + name = Field( + attribute='name', column_name=_('Part Name'), widget=widgets.CharWidget() + ) + description = Field( + attribute='description', + column_name=_('Part Description'), + widget=widgets.CharWidget(), + ) IPN = Field(attribute='IPN', column_name=_('IPN'), widget=widgets.CharWidget()) - revision = Field(attribute='revision', column_name=_('Revision'), widget=widgets.CharWidget()) - keywords = Field(attribute='keywords', column_name=_('Keywords'), widget=widgets.CharWidget()) + revision = Field( + attribute='revision', column_name=_('Revision'), widget=widgets.CharWidget() + ) + keywords = Field( + attribute='keywords', column_name=_('Keywords'), widget=widgets.CharWidget() + ) link = Field(attribute='link', column_name=_('Link'), widget=widgets.CharWidget()) - units = Field(attribute='units', column_name=_('Units'), widget=widgets.CharWidget()) + units = Field( + attribute='units', column_name=_('Units'), widget=widgets.CharWidget() + ) notes = Field(attribute='notes', column_name=_('Notes')) image = Field(attribute='image', column_name=_('Part Image'), readonly=True) - category = Field(attribute='category', column_name=_('Category ID'), widget=widgets.ForeignKeyWidget(models.PartCategory)) - category_name = Field(attribute='category__name', column_name=_('Category Name'), readonly=True) - default_location = Field(attribute='default_location', column_name=_('Default Location ID'), widget=widgets.ForeignKeyWidget(StockLocation)) - default_supplier = Field(attribute='default_supplier', column_name=_('Default Supplier ID'), widget=widgets.ForeignKeyWidget(SupplierPart)) - variant_of = Field(attribute='variant_of', column_name=_('Variant Of'), widget=widgets.ForeignKeyWidget(models.Part)) + category = Field( + attribute='category', + column_name=_('Category ID'), + widget=widgets.ForeignKeyWidget(models.PartCategory), + ) + category_name = Field( + attribute='category__name', column_name=_('Category Name'), readonly=True + ) + default_location = Field( + attribute='default_location', + column_name=_('Default Location ID'), + widget=widgets.ForeignKeyWidget(StockLocation), + ) + default_supplier = Field( + attribute='default_supplier', + column_name=_('Default Supplier ID'), + widget=widgets.ForeignKeyWidget(SupplierPart), + ) + variant_of = Field( + attribute='variant_of', + column_name=_('Variant Of'), + widget=widgets.ForeignKeyWidget(models.Part), + ) minimum_stock = Field(attribute='minimum_stock', column_name=_('Minimum Stock')) # Part Attributes - active = Field(attribute='active', column_name=_('Active'), widget=widgets.BooleanWidget()) - assembly = Field(attribute='assembly', column_name=_('Assembly'), widget=widgets.BooleanWidget()) - component = Field(attribute='component', column_name=_('Component'), widget=widgets.BooleanWidget()) - purchaseable = Field(attribute='purchaseable', column_name=_('Purchaseable'), widget=widgets.BooleanWidget()) - salable = Field(attribute='salable', column_name=_('Salable'), widget=widgets.BooleanWidget()) - is_template = Field(attribute='is_template', column_name=_('Template'), widget=widgets.BooleanWidget()) - trackable = Field(attribute='trackable', column_name=_('Trackable'), widget=widgets.BooleanWidget()) - virtual = Field(attribute='virtual', column_name=_('Virtual'), widget=widgets.BooleanWidget()) + active = Field( + attribute='active', column_name=_('Active'), widget=widgets.BooleanWidget() + ) + assembly = Field( + attribute='assembly', column_name=_('Assembly'), widget=widgets.BooleanWidget() + ) + component = Field( + attribute='component', + column_name=_('Component'), + widget=widgets.BooleanWidget(), + ) + purchaseable = Field( + attribute='purchaseable', + column_name=_('Purchaseable'), + widget=widgets.BooleanWidget(), + ) + salable = Field( + attribute='salable', column_name=_('Salable'), widget=widgets.BooleanWidget() + ) + is_template = Field( + attribute='is_template', + column_name=_('Template'), + widget=widgets.BooleanWidget(), + ) + trackable = Field( + attribute='trackable', + column_name=_('Trackable'), + widget=widgets.BooleanWidget(), + ) + virtual = Field( + attribute='virtual', column_name=_('Virtual'), widget=widgets.BooleanWidget() + ) # Extra calculated meta-data (readonly) - suppliers = Field(attribute='supplier_count', column_name=_('Suppliers'), readonly=True) - in_stock = Field(attribute='total_stock', column_name=_('In Stock'), readonly=True, widget=widgets.IntegerWidget()) - on_order = Field(attribute='on_order', column_name=_('On Order'), readonly=True, widget=widgets.IntegerWidget()) - used_in = Field(attribute='used_in_count', column_name=_('Used In'), readonly=True, widget=widgets.IntegerWidget()) - allocated = Field(attribute='allocation_count', column_name=_('Allocated'), readonly=True, widget=widgets.IntegerWidget()) - building = Field(attribute='quantity_being_built', column_name=_('Building'), readonly=True, widget=widgets.IntegerWidget()) - min_cost = Field(attribute='pricing__overall_min', column_name=_('Minimum Cost'), readonly=True) - max_cost = Field(attribute='pricing__overall_max', column_name=_('Maximum Cost'), readonly=True) + suppliers = Field( + attribute='supplier_count', column_name=_('Suppliers'), readonly=True + ) + in_stock = Field( + attribute='total_stock', + column_name=_('In Stock'), + readonly=True, + widget=widgets.IntegerWidget(), + ) + on_order = Field( + attribute='on_order', + column_name=_('On Order'), + readonly=True, + widget=widgets.IntegerWidget(), + ) + used_in = Field( + attribute='used_in_count', + column_name=_('Used In'), + readonly=True, + widget=widgets.IntegerWidget(), + ) + allocated = Field( + attribute='allocation_count', + column_name=_('Allocated'), + readonly=True, + widget=widgets.IntegerWidget(), + ) + building = Field( + attribute='quantity_being_built', + column_name=_('Building'), + readonly=True, + widget=widgets.IntegerWidget(), + ) + min_cost = Field( + attribute='pricing__overall_min', column_name=_('Minimum Cost'), readonly=True + ) + max_cost = Field( + attribute='pricing__overall_max', column_name=_('Maximum Cost'), readonly=True + ) def dehydrate_min_cost(self, part): """Render minimum cost value for this Part""" @@ -88,7 +180,7 @@ class PartResource(InvenTreeResource): 'used_in', 'builds', 'supplier_parts__purchase_order_line_items', - 'stock_items__allocations' + 'stock_items__allocations', ) return query @@ -106,21 +198,33 @@ class PartImportResource(InvenTreeResource): class Meta(PartResource.Meta): """Metaclass definition""" + skip_unchanged = True report_skipped = False clean_model_instances = True exclude = [ - 'id', 'category__name', 'creation_date', 'creation_user', - 'pricing__overall_min', 'pricing__overall_max', - 'bom_checksum', 'bom_checked_by', 'bom_checked_date', - 'lft', 'rght', 'tree_id', 'level', + 'id', + 'category__name', + 'creation_date', + 'creation_user', + 'pricing__overall_min', + 'pricing__overall_max', + 'bom_checksum', + 'bom_checked_by', + 'bom_checked_date', + 'lft', + 'rght', + 'tree_id', + 'level', 'metadata', - 'barcode_data', 'barcode_hash', + 'barcode_data', + 'barcode_hash', ] class PartParameterInline(admin.TabularInline): """Inline for part parameter data""" + model = models.PartParameter @@ -133,7 +237,13 @@ class PartAdmin(ImportExportModelAdmin): list_filter = ('active', 'assembly', 'is_template', 'virtual') - search_fields = ('name', 'description', 'category__name', 'category__description', 'IPN') + search_fields = ( + 'name', + 'description', + 'category__name', + 'category__description', + 'IPN', + ) autocomplete_fields = [ 'variant_of', @@ -142,9 +252,7 @@ class PartAdmin(ImportExportModelAdmin): 'default_supplier', ] - inlines = [ - PartParameterInline, - ] + inlines = [PartParameterInline] class PartPricingAdmin(admin.ModelAdmin): @@ -152,9 +260,7 @@ class PartPricingAdmin(admin.ModelAdmin): list_display = ('part', 'overall_min', 'overall_max') - autcomplete_fields = [ - 'part', - ] + autcomplete_fields = ['part'] class PartStocktakeAdmin(admin.ModelAdmin): @@ -174,6 +280,7 @@ class PartCategoryResource(InvenTreeResource): class Meta: """Metaclass definition""" + model = models.PartCategory skip_unchanged = True report_skipped = False @@ -181,22 +288,42 @@ class PartCategoryResource(InvenTreeResource): exclude = [ # Exclude MPTT internal model fields - 'lft', 'rght', 'tree_id', 'level', + 'lft', + 'rght', + 'tree_id', + 'level', 'metadata', 'icon', ] - id = Field(attribute='pk', column_name=_('Category ID'), widget=widgets.IntegerWidget()) + id = Field( + attribute='pk', column_name=_('Category ID'), widget=widgets.IntegerWidget() + ) name = Field(attribute='name', column_name=_('Category Name')) description = Field(attribute='description', column_name=_('Description')) - parent = Field(attribute='parent', column_name=_('Parent ID'), widget=widgets.ForeignKeyWidget(models.PartCategory)) - parent_name = Field(attribute='parent__name', column_name=_('Parent Name'), readonly=True) - default_location = Field(attribute='default_location', column_name=_('Default Location ID'), widget=widgets.ForeignKeyWidget(StockLocation)) + parent = Field( + attribute='parent', + column_name=_('Parent ID'), + widget=widgets.ForeignKeyWidget(models.PartCategory), + ) + parent_name = Field( + attribute='parent__name', column_name=_('Parent Name'), readonly=True + ) + default_location = Field( + attribute='default_location', + column_name=_('Default Location ID'), + widget=widgets.ForeignKeyWidget(StockLocation), + ) default_keywords = Field(attribute='default_keywords', column_name=_('Keywords')) pathstring = Field(attribute='pathstring', column_name=_('Category Path')) # Calculated fields - parts = Field(attribute='item_count', column_name=_('Parts'), widget=widgets.IntegerWidget(), readonly=True) + parts = Field( + attribute='item_count', + column_name=_('Parts'), + widget=widgets.IntegerWidget(), + readonly=True, + ) def after_import(self, dataset, result, using_transactions, dry_run, **kwargs): """Rebuild MPTT tree structure after importing PartCategory data""" @@ -215,7 +342,7 @@ class PartCategoryAdmin(ImportExportModelAdmin): search_fields = ('name', 'description') - autocomplete_fields = ('parent', 'default_location',) + autocomplete_fields = ('parent', 'default_location') class PartRelatedAdmin(admin.ModelAdmin): @@ -245,38 +372,63 @@ class BomItemResource(InvenTreeResource): class Meta: """Metaclass definition""" + model = models.BomItem skip_unchanged = True report_skipped = False clean_model_instances = True - exclude = [ - 'checksum', - 'id', - 'part', - 'sub_part', - 'validated', - ] + exclude = ['checksum', 'id', 'part', 'sub_part', 'validated'] level = Field(attribute='level', column_name=_('BOM Level'), readonly=True) - bom_id = Field(attribute='pk', column_name=_('BOM Item ID'), widget=widgets.IntegerWidget()) + bom_id = Field( + attribute='pk', column_name=_('BOM Item ID'), widget=widgets.IntegerWidget() + ) # ID of the parent part - parent_part_id = Field(attribute='part', column_name=_('Parent ID'), widget=widgets.ForeignKeyWidget(models.Part)) - parent_part_ipn = Field(attribute='part__IPN', column_name=_('Parent IPN'), readonly=True) - parent_part_name = Field(attribute='part__name', column_name=_('Parent Name'), readonly=True) - part_id = Field(attribute='sub_part', column_name=_('Part ID'), widget=widgets.ForeignKeyWidget(models.Part)) - part_ipn = Field(attribute='sub_part__IPN', column_name=_('Part IPN'), readonly=True) - part_name = Field(attribute='sub_part__name', column_name=_('Part Name'), readonly=True) - part_description = Field(attribute='sub_part__description', column_name=_('Description'), readonly=True) + parent_part_id = Field( + attribute='part', + column_name=_('Parent ID'), + widget=widgets.ForeignKeyWidget(models.Part), + ) + parent_part_ipn = Field( + attribute='part__IPN', column_name=_('Parent IPN'), readonly=True + ) + parent_part_name = Field( + attribute='part__name', column_name=_('Parent Name'), readonly=True + ) + part_id = Field( + attribute='sub_part', + column_name=_('Part ID'), + widget=widgets.ForeignKeyWidget(models.Part), + ) + part_ipn = Field( + attribute='sub_part__IPN', column_name=_('Part IPN'), readonly=True + ) + part_name = Field( + attribute='sub_part__name', column_name=_('Part Name'), readonly=True + ) + part_description = Field( + attribute='sub_part__description', column_name=_('Description'), readonly=True + ) quantity = Field(attribute='quantity', column_name=_('Quantity')) reference = Field(attribute='reference', column_name=_('Reference')) note = Field(attribute='note', column_name=_('Note')) - min_cost = Field(attribute='sub_part__pricing__overall_min', column_name=_('Minimum Price'), readonly=True) - max_cost = Field(attribute='sub_part__pricing__overall_max', column_name=_('Maximum Price'), readonly=True) + min_cost = Field( + attribute='sub_part__pricing__overall_min', + column_name=_('Minimum Price'), + readonly=True, + ) + max_cost = Field( + attribute='sub_part__pricing__overall_max', + column_name=_('Maximum Price'), + readonly=True, + ) - sub_assembly = Field(attribute='sub_part__assembly', column_name=_('Assembly'), readonly=True) + sub_assembly = Field( + attribute='sub_part__assembly', column_name=_('Assembly'), readonly=True + ) def dehydrate_min_cost(self, item): """Render minimum cost value for the BOM line item""" @@ -329,13 +481,12 @@ class BomItemResource(InvenTreeResource): 'part__name', 'sub_part__name', 'sub_part__description', - 'sub_part__assembly' + 'sub_part__assembly', ] idx = 0 while idx < len(fields): - if fields[idx].attribute in to_remove: del fields[idx] else: @@ -351,30 +502,31 @@ class BomItemAdmin(ImportExportModelAdmin): list_display = ('part', 'sub_part', 'quantity') - search_fields = ('part__name', 'part__description', 'sub_part__name', 'sub_part__description') + search_fields = ( + 'part__name', + 'part__description', + 'sub_part__name', + 'sub_part__description', + ) - autocomplete_fields = ('part', 'sub_part',) + autocomplete_fields = ('part', 'sub_part') class ParameterTemplateResource(InvenTreeResource): """Class for managing ParameterTemplate import/export""" # The following fields will be converted from None to '' - CONVERT_NULL_FIELDS = [ - 'choices', - 'units' - ] + CONVERT_NULL_FIELDS = ['choices', 'units'] class Meta: """Metaclass definition""" + model = models.PartParameterTemplate skip_unchanged = True report_skipped = False clean_model_instances = True - exclude = [ - 'metadata', - ] + exclude = ['metadata'] class ParameterTemplateAdmin(ImportExportModelAdmin): @@ -392,6 +544,7 @@ class ParameterResource(InvenTreeResource): class Meta: """Metaclass definition""" + model = models.PartParameter skip_unchanged = True report_skipped = False @@ -401,7 +554,10 @@ class ParameterResource(InvenTreeResource): part_name = Field(attribute='part__name', readonly=True) - template = Field(attribute='template', widget=widgets.ForeignKeyWidget(models.PartParameterTemplate)) + template = Field( + attribute='template', + widget=widgets.ForeignKeyWidget(models.PartParameterTemplate), + ) template_name = Field(attribute='template__name', readonly=True) @@ -419,7 +575,7 @@ class ParameterAdmin(ImportExportModelAdmin): class PartCategoryParameterAdmin(admin.ModelAdmin): """Admin class for the PartCategoryParameterTemplate model""" - autocomplete_fields = ('category', 'parameter_template',) + autocomplete_fields = ('category', 'parameter_template') class PartSellPriceBreakAdmin(admin.ModelAdmin): @@ -427,9 +583,10 @@ class PartSellPriceBreakAdmin(admin.ModelAdmin): class Meta: """Metaclass definition""" + model = models.PartSellPriceBreak - list_display = ('part', 'quantity', 'price',) + list_display = ('part', 'quantity', 'price') class PartInternalPriceBreakAdmin(admin.ModelAdmin): @@ -437,9 +594,10 @@ class PartInternalPriceBreakAdmin(admin.ModelAdmin): class Meta: """Metaclass definition""" + model = models.PartInternalPriceBreak - list_display = ('part', 'quantity', 'price',) + list_display = ('part', 'quantity', 'price') autocomplete_fields = ('part',) diff --git a/InvenTree/part/api.py b/InvenTree/part/api.py index f5664f4533..62e856af6b 100644 --- a/InvenTree/part/api.py +++ b/InvenTree/part/api.py @@ -17,35 +17,68 @@ from rest_framework.response import Response import order.models import part.filters from build.models import Build, BuildItem -from InvenTree.api import (APIDownloadMixin, AttachmentMixin, - ListCreateDestroyAPIView, MetadataView) -from InvenTree.filters import (ORDER_FILTER, SEARCH_ORDER_FILTER, - SEARCH_ORDER_FILTER_ALIAS, InvenTreeDateFilter, - InvenTreeSearchFilter) -from InvenTree.helpers import (DownloadFile, increment_serial_number, isNull, - str2bool, str2int) -from InvenTree.mixins import (CreateAPI, CustomRetrieveUpdateDestroyAPI, - ListAPI, ListCreateAPI, RetrieveAPI, - RetrieveUpdateAPI, RetrieveUpdateDestroyAPI, - UpdateAPI) +from InvenTree.api import ( + APIDownloadMixin, + AttachmentMixin, + ListCreateDestroyAPIView, + MetadataView, +) +from InvenTree.filters import ( + ORDER_FILTER, + SEARCH_ORDER_FILTER, + SEARCH_ORDER_FILTER_ALIAS, + InvenTreeDateFilter, + InvenTreeSearchFilter, +) +from InvenTree.helpers import ( + DownloadFile, + increment_serial_number, + isNull, + str2bool, + str2int, +) +from InvenTree.mixins import ( + CreateAPI, + CustomRetrieveUpdateDestroyAPI, + ListAPI, + ListCreateAPI, + RetrieveAPI, + RetrieveUpdateAPI, + RetrieveUpdateDestroyAPI, + UpdateAPI, +) from InvenTree.permissions import RolePermission -from InvenTree.status_codes import (BuildStatusGroups, - PurchaseOrderStatusGroups, - SalesOrderStatusGroups) +from InvenTree.status_codes import ( + BuildStatusGroups, + PurchaseOrderStatusGroups, + SalesOrderStatusGroups, +) from part.admin import PartCategoryResource, PartResource from stock.models import StockLocation from . import serializers as part_serializers from . import views -from .models import (BomItem, BomItemSubstitute, Part, PartAttachment, - PartCategory, PartCategoryParameterTemplate, - PartInternalPriceBreak, PartParameter, - PartParameterTemplate, PartRelated, PartSellPriceBreak, - PartStocktake, PartStocktakeReport, PartTestTemplate) +from .models import ( + BomItem, + BomItemSubstitute, + Part, + PartAttachment, + PartCategory, + PartCategoryParameterTemplate, + PartInternalPriceBreak, + PartParameter, + PartParameterTemplate, + PartRelated, + PartSellPriceBreak, + PartStocktake, + PartStocktakeReport, + PartTestTemplate, +) class CategoryMixin: """Mixin class for PartCategory endpoints""" + serializer_class = part_serializers.CategorySerializer queryset = PartCategory.objects.all() @@ -60,7 +93,9 @@ class CategoryMixin: ctx = super().get_serializer_context() try: - ctx['starred_categories'] = [star.category for star in self.request.user.starred_categories.all()] + ctx['starred_categories'] = [ + star.category for star in self.request.user.starred_categories.all() + ] except AttributeError: # Error is thrown if the view does not have an associated request ctx['starred_categories'] = [] @@ -79,7 +114,7 @@ class CategoryList(CategoryMixin, APIDownloadMixin, ListCreateAPI): """Download the filtered queryset as a data file""" dataset = PartCategoryResource().export(queryset=queryset) filedata = dataset.export(export_format) - filename = f"InvenTree_Categories.{export_format}" + filename = f'InvenTree_Categories.{export_format}' return DownloadFile(filedata, filename) @@ -103,7 +138,6 @@ class CategoryList(CategoryMixin, APIDownloadMixin, ListCreateAPI): pass # Look for top-level categories elif isNull(cat_id): - if not cascade: queryset = queryset.filter(parent=None) @@ -147,7 +181,9 @@ class CategoryList(CategoryMixin, APIDownloadMixin, ListCreateAPI): if starred is not None: starred = str2bool(starred) - starred_categories = [star.category.pk for star in self.request.user.starred_categories.all()] + starred_categories = [ + star.category.pk for star in self.request.user.starred_categories.all() + ] if starred: queryset = queryset.filter(pk__in=starred_categories) @@ -158,32 +194,14 @@ class CategoryList(CategoryMixin, APIDownloadMixin, ListCreateAPI): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'name', - 'description', - 'structural' - ] + filterset_fields = ['name', 'description', 'structural'] - ordering_fields = [ - 'name', - 'pathstring', - 'level', - 'tree_id', - 'lft', - 'part_count', - ] + ordering_fields = ['name', 'pathstring', 'level', 'tree_id', 'lft', 'part_count'] # Use hierarchical ordering by default - ordering = [ - 'tree_id', - 'lft', - 'name' - ] + ordering = ['tree_id', 'lft', 'name'] - search_fields = [ - 'name', - 'description', - ] + search_fields = ['name', 'description'] class CategoryDetail(CategoryMixin, CustomRetrieveUpdateDestroyAPI): @@ -216,13 +234,22 @@ class CategoryDetail(CategoryMixin, CustomRetrieveUpdateDestroyAPI): def destroy(self, request, *args, **kwargs): """Delete a Part category instance via the API""" - delete_parts = 'delete_parts' in request.data and request.data['delete_parts'] == '1' - delete_child_categories = 'delete_child_categories' in request.data and request.data['delete_child_categories'] == '1' - return super().destroy(request, - *args, - **dict(kwargs, - delete_parts=delete_parts, - delete_child_categories=delete_child_categories)) + delete_parts = ( + 'delete_parts' in request.data and request.data['delete_parts'] == '1' + ) + delete_child_categories = ( + 'delete_child_categories' in request.data + and request.data['delete_child_categories'] == '1' + ) + return super().destroy( + request, + *args, + **dict( + kwargs, + delete_parts=delete_parts, + delete_child_categories=delete_child_categories, + ), + ) class CategoryTree(ListAPI): @@ -261,7 +288,6 @@ class CategoryParameterList(ListCreateAPI): if category is not None: try: - category = PartCategory.objects.get(pk=category) fetch_parent = str2bool(params.get('fetch_parent', True)) @@ -298,13 +324,9 @@ class PartSalePriceList(ListCreateAPI): queryset = PartSellPriceBreak.objects.all() serializer_class = part_serializers.PartSalePriceSerializer - filter_backends = [ - DjangoFilterBackend - ] + filter_backends = [DjangoFilterBackend] - filterset_fields = [ - 'part', - ] + filterset_fields = ['part'] class PartInternalPriceDetail(RetrieveUpdateDestroyAPI): @@ -321,13 +343,9 @@ class PartInternalPriceList(ListCreateAPI): serializer_class = part_serializers.PartInternalPriceSerializer permission_required = 'roles.sales_order.show' - filter_backends = [ - DjangoFilterBackend - ] + filter_backends = [DjangoFilterBackend] - filterset_fields = [ - 'part', - ] + filterset_fields = ['part'] class PartAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): @@ -336,9 +354,7 @@ class PartAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): queryset = PartAttachment.objects.all() serializer_class = part_serializers.PartAttachmentSerializer - filterset_fields = [ - 'part', - ] + filterset_fields = ['part'] class PartAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): @@ -379,7 +395,9 @@ class PartTestTemplateList(ListCreateAPI): if part: try: part = Part.objects.get(pk=part) - queryset = queryset.filter(part__in=part.get_ancestors(include_self=True)) + queryset = queryset.filter( + part__in=part.get_ancestors(include_self=True) + ) except (ValueError, Part.DoesNotExist): pass @@ -417,15 +435,13 @@ class PartThumbs(ListAPI): queryset = self.filter_queryset(self.get_queryset()) # Return the most popular parts first - data = queryset.values( - 'image', - ).annotate(count=Count('image')).order_by('-count') + data = ( + queryset.values('image').annotate(count=Count('image')).order_by('-count') + ) return Response(data) - filter_backends = [ - InvenTreeSearchFilter, - ] + filter_backends = [InvenTreeSearchFilter] search_fields = [ 'name', @@ -443,9 +459,7 @@ class PartThumbsUpdate(RetrieveUpdateAPI): queryset = Part.objects.all() serializer_class = part_serializers.PartThumbSerializerUpdate - filter_backends = [ - DjangoFilterBackend - ] + filter_backends = [DjangoFilterBackend] class PartScheduling(RetrieveAPI): @@ -468,7 +482,9 @@ class PartScheduling(RetrieveAPI): schedule = [] - def add_schedule_entry(date, quantity, title, label, url, speculative_quantity=0): + def add_schedule_entry( + date, quantity, title, label, url, speculative_quantity=0 + ): """Check if a scheduled entry should be added: - date must be non-null @@ -487,12 +503,10 @@ class PartScheduling(RetrieveAPI): # Add purchase order (incoming stock) information po_lines = order.models.PurchaseOrderLineItem.objects.filter( - part__part=part, - order__status__in=PurchaseOrderStatusGroups.OPEN, + part__part=part, order__status__in=PurchaseOrderStatusGroups.OPEN ) for line in po_lines: - target_date = line.target_date or line.order.target_date line_quantity = max(line.quantity - line.received, 0) @@ -505,17 +519,15 @@ class PartScheduling(RetrieveAPI): quantity, _('Incoming Purchase Order'), str(line.order), - line.order.get_absolute_url() + line.order.get_absolute_url(), ) # Add sales order (outgoing stock) information so_lines = order.models.SalesOrderLineItem.objects.filter( - part=part, - order__status__in=SalesOrderStatusGroups.OPEN, + part=part, order__status__in=SalesOrderStatusGroups.OPEN ) for line in so_lines: - target_date = line.target_date or line.order.target_date quantity = max(line.quantity - line.shipped, 0) @@ -530,12 +542,10 @@ class PartScheduling(RetrieveAPI): # Add build orders (incoming stock) information build_orders = Build.objects.filter( - part=part, - status__in=BuildStatusGroups.ACTIVE_CODES + part=part, status__in=BuildStatusGroups.ACTIVE_CODES ) for build in build_orders: - quantity = max(build.quantity - build.completed, 0) add_schedule_entry( @@ -576,17 +586,14 @@ class PartScheduling(RetrieveAPI): # An "inherited" BOM item filters down to variant parts also children = bom_item.part.get_descendants(include_self=True) builds = Build.objects.filter( - status__in=BuildStatusGroups.ACTIVE_CODES, - part__in=children, + status__in=BuildStatusGroups.ACTIVE_CODES, part__in=children ) else: builds = Build.objects.filter( - status__in=BuildStatusGroups.ACTIVE_CODES, - part=bom_item.part, + status__in=BuildStatusGroups.ACTIVE_CODES, part=bom_item.part ) for build in builds: - # Ensure we don't double-count any builds if build in seen_builds: continue @@ -602,8 +609,7 @@ class PartScheduling(RetrieveAPI): # Grab all allocations against the specified BomItem allocations = BuildItem.objects.filter( - build_line__bom_item=bom_item, - build_line__build=build, + build_line__bom_item=bom_item, build_line__build=build ) # Total allocated for *this* part @@ -622,7 +628,9 @@ class PartScheduling(RetrieveAPI): # Consider the case where the build order is *not* fully allocated if required_quantity > total_allocated_quantity: - speculative_quantity = -1 * (required_quantity - total_allocated_quantity) + speculative_quantity = -1 * ( + required_quantity - total_allocated_quantity + ) add_schedule_entry( build.target_date, @@ -630,7 +638,7 @@ class PartScheduling(RetrieveAPI): _('Stock required for Build Order'), str(build), build.get_absolute_url(), - speculative_quantity=speculative_quantity + speculative_quantity=speculative_quantity, ) def compare(entry_1, entry_2): @@ -674,16 +682,24 @@ class PartRequirements(RetrieveAPI): part = self.get_object() data = { - "available_stock": part.available_stock, - "on_order": part.on_order, - "required_build_order_quantity": part.required_build_order_quantity(), - "allocated_build_order_quantity": part.build_order_allocation_count(), - "required_sales_order_quantity": part.required_sales_order_quantity(), - "allocated_sales_order_quantity": part.sales_order_allocation_count(pending=True), + 'available_stock': part.available_stock, + 'on_order': part.on_order, + 'required_build_order_quantity': part.required_build_order_quantity(), + 'allocated_build_order_quantity': part.build_order_allocation_count(), + 'required_sales_order_quantity': part.required_sales_order_quantity(), + 'allocated_sales_order_quantity': part.sales_order_allocation_count( + pending=True + ), } - data["allocated"] = data["allocated_build_order_quantity"] + data["allocated_sales_order_quantity"] - data["required"] = data["required_build_order_quantity"] + data["required_sales_order_quantity"] + data['allocated'] = ( + data['allocated_build_order_quantity'] + + data['allocated_sales_order_quantity'] + ) + data['required'] = ( + data['required_build_order_quantity'] + + data['required_sales_order_quantity'] + ) return Response(data) @@ -719,9 +735,7 @@ class PartSerialNumberDetail(RetrieveAPI): # Calculate the "latest" serial number latest = part.get_latest_serial_number() - data = { - 'latest': latest, - } + data = {'latest': latest} if latest is not None: next_serial = increment_serial_number(latest) @@ -758,16 +772,11 @@ class PartValidateBOM(RetrieveUpdateAPI): class Meta: """Metaclass defines serializer fields""" - model = Part - fields = [ - 'checksum', - 'valid', - ] - checksum = serializers.CharField( - read_only=True, - source='bom_checksum', - ) + model = Part + fields = ['checksum', 'valid'] + + checksum = serializers.CharField(read_only=True, source='bom_checksum') valid = serializers.BooleanField( write_only=True, @@ -799,9 +808,7 @@ class PartValidateBOM(RetrieveUpdateAPI): part.validate_bom(request.user) - return Response({ - 'checksum': part.bom_checksum, - }) + return Response({'checksum': part.bom_checksum}) class PartFilter(rest_filters.FilterSet): @@ -812,6 +819,7 @@ class PartFilter(rest_filters.FilterSet): class Meta: """Metaclass options for this filter set""" + model = Part fields = [] @@ -834,17 +842,23 @@ class PartFilter(rest_filters.FilterSet): return queryset.filter(IPN='') # Regex filter for name - name_regex = rest_filters.CharFilter(label='Filter by name (regex)', field_name='name', lookup_expr='iregex') + name_regex = rest_filters.CharFilter( + label='Filter by name (regex)', field_name='name', lookup_expr='iregex' + ) # Exact match for IPN IPN = rest_filters.CharFilter( label='Filter by exact IPN (internal part number)', field_name='IPN', - lookup_expr="iexact" + lookup_expr='iexact', ) # Regex match for IPN - IPN_regex = rest_filters.CharFilter(label='Filter by regex on IPN (internal part number)', field_name='IPN', lookup_expr='iregex') + IPN_regex = rest_filters.CharFilter( + label='Filter by regex on IPN (internal part number)', + field_name='IPN', + lookup_expr='iregex', + ) # low_stock filter low_stock = rest_filters.BooleanFilter(label='Low stock', method='filter_low_stock') @@ -854,7 +868,9 @@ class PartFilter(rest_filters.FilterSet): if str2bool(value): # Ignore any parts which do not have a specified 'minimum_stock' level # Filter items which have an 'in_stock' level lower than 'minimum_stock' - return queryset.exclude(minimum_stock=0).filter(Q(total_in_stock__lt=F('minimum_stock'))) + return queryset.exclude(minimum_stock=0).filter( + Q(total_in_stock__lt=F('minimum_stock')) + ) # Filter items which have an 'in_stock' level higher than 'minimum_stock' return queryset.filter(Q(total_in_stock__gte=F('minimum_stock'))) @@ -868,7 +884,9 @@ class PartFilter(rest_filters.FilterSet): return queryset.filter(Q(in_stock__lte=0)) # unallocated_stock filter - unallocated_stock = rest_filters.BooleanFilter(label='Unallocated stock', method='filter_unallocated_stock') + unallocated_stock = rest_filters.BooleanFilter( + label='Unallocated stock', method='filter_unallocated_stock' + ) def filter_unallocated_stock(self, queryset, name, value): """Filter by whether the Part has unallocated stock""" @@ -876,7 +894,11 @@ class PartFilter(rest_filters.FilterSet): return queryset.filter(Q(unallocated_stock__gt=0)) return queryset.filter(Q(unallocated_stock__lte=0)) - convert_from = rest_filters.ModelChoiceFilter(label="Can convert from", queryset=Part.objects.all(), method='filter_convert_from') + convert_from = rest_filters.ModelChoiceFilter( + label='Can convert from', + queryset=Part.objects.all(), + method='filter_convert_from', + ) def filter_convert_from(self, queryset, name, part): """Limit the queryset to valid conversion options for the specified part""" @@ -886,7 +908,11 @@ class PartFilter(rest_filters.FilterSet): return queryset - exclude_tree = rest_filters.ModelChoiceFilter(label="Exclude Part tree", queryset=Part.objects.all(), method='filter_exclude_tree') + exclude_tree = rest_filters.ModelChoiceFilter( + label='Exclude Part tree', + queryset=Part.objects.all(), + method='filter_exclude_tree', + ) def filter_exclude_tree(self, queryset, name, part): """Exclude all parts and variants 'down' from the specified part from the queryset""" @@ -894,27 +920,35 @@ class PartFilter(rest_filters.FilterSet): return queryset.exclude(id__in=children) - ancestor = rest_filters.ModelChoiceFilter(label='Ancestor', queryset=Part.objects.all(), method='filter_ancestor') + ancestor = rest_filters.ModelChoiceFilter( + label='Ancestor', queryset=Part.objects.all(), method='filter_ancestor' + ) def filter_ancestor(self, queryset, name, part): """Limit queryset to descendants of the specified ancestor part""" descendants = part.get_descendants(include_self=False) return queryset.filter(id__in=descendants) - variant_of = rest_filters.ModelChoiceFilter(label='Variant Of', queryset=Part.objects.all(), method='filter_variant_of') + variant_of = rest_filters.ModelChoiceFilter( + label='Variant Of', queryset=Part.objects.all(), method='filter_variant_of' + ) def filter_variant_of(self, queryset, name, part): """Limit queryset to direct children (variants) of the specified part""" return queryset.filter(id__in=part.get_children()) - in_bom_for = rest_filters.ModelChoiceFilter(label='In BOM Of', queryset=Part.objects.all(), method='filter_in_bom') + in_bom_for = rest_filters.ModelChoiceFilter( + label='In BOM Of', queryset=Part.objects.all(), method='filter_in_bom' + ) def filter_in_bom(self, queryset, name, part): """Limit queryset to parts in the BOM for the specified part""" bom_parts = part.get_parts_in_bom() return queryset.filter(id__in=[p.pk for p in bom_parts]) - has_pricing = rest_filters.BooleanFilter(label="Has Pricing", method="filter_has_pricing") + has_pricing = rest_filters.BooleanFilter( + label='Has Pricing', method='filter_has_pricing' + ) def filter_has_pricing(self, queryset, name, value): """Filter the queryset based on whether pricing information is available for the sub_part""" @@ -926,7 +960,9 @@ class PartFilter(rest_filters.FilterSet): return queryset.filter(q_a | q_b).distinct() - stocktake = rest_filters.BooleanFilter(label="Has stocktake", method='filter_has_stocktake') + stocktake = rest_filters.BooleanFilter( + label='Has stocktake', method='filter_has_stocktake' + ) def filter_has_stocktake(self, queryset, name, value): """Filter the queryset based on whether stocktake data is available""" @@ -934,17 +970,25 @@ class PartFilter(rest_filters.FilterSet): return queryset.exclude(last_stocktake=None) return queryset.filter(last_stocktake=None) - stock_to_build = rest_filters.BooleanFilter(label='Required for Build Order', method='filter_stock_to_build') + stock_to_build = rest_filters.BooleanFilter( + label='Required for Build Order', method='filter_stock_to_build' + ) def filter_stock_to_build(self, queryset, name, value): """Filter the queryset based on whether part stock is required for a pending BuildOrder""" if str2bool(value): # Return parts which are required for a build order, but have not yet been allocated - return queryset.filter(required_for_build_orders__gt=F('allocated_to_build_orders')) + return queryset.filter( + required_for_build_orders__gt=F('allocated_to_build_orders') + ) # Return parts which are not required for a build order, or have already been allocated - return queryset.filter(required_for_build_orders__lte=F('allocated_to_build_orders')) + return queryset.filter( + required_for_build_orders__lte=F('allocated_to_build_orders') + ) - depleted_stock = rest_filters.BooleanFilter(label='Depleted Stock', method='filter_depleted_stock') + depleted_stock = rest_filters.BooleanFilter( + label='Depleted Stock', method='filter_depleted_stock' + ) def filter_depleted_stock(self, queryset, name, value): """Filter the queryset based on whether the part is fully depleted of stock""" @@ -952,7 +996,9 @@ class PartFilter(rest_filters.FilterSet): return queryset.filter(Q(in_stock=0) & ~Q(stock_item_count=0)) return queryset.exclude(Q(in_stock=0) & ~Q(stock_item_count=0)) - default_location = rest_filters.ModelChoiceFilter(label="Default Location", queryset=StockLocation.objects.all()) + default_location = rest_filters.ModelChoiceFilter( + label='Default Location', queryset=StockLocation.objects.all() + ) is_template = rest_filters.BooleanFilter() @@ -975,12 +1021,17 @@ class PartFilter(rest_filters.FilterSet): tags_slug = rest_filters.CharFilter(field_name='tags__slug', lookup_expr='iexact') # Created date filters - created_before = InvenTreeDateFilter(label='Updated before', field_name='creation_date', lookup_expr='lte') - created_after = InvenTreeDateFilter(label='Updated after', field_name='creation_date', lookup_expr='gte') + created_before = InvenTreeDateFilter( + label='Updated before', field_name='creation_date', lookup_expr='lte' + ) + created_after = InvenTreeDateFilter( + label='Updated after', field_name='creation_date', lookup_expr='gte' + ) class PartMixin: """Mixin class for Part API endpoints""" + serializer_class = part_serializers.PartSerializer queryset = Part.objects.all() @@ -1007,7 +1058,9 @@ class PartMixin: # Pass a list of "starred" parts to the current user to the serializer # We do this to reduce the number of database queries required! if self.starred_parts is None and self.request is not None: - self.starred_parts = [star.part for star in self.request.user.starred_parts.all()] + self.starred_parts = [ + star.part for star in self.request.user.starred_parts.all() + ] kwargs['starred_parts'] = self.starred_parts @@ -1042,7 +1095,7 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI): dataset = PartResource().export(queryset=queryset) filedata = dataset.export(export_format) - filename = f"InvenTree_Parts.{export_format}" + filename = f'InvenTree_Parts.{export_format}' return DownloadFile(filedata, filename) @@ -1087,7 +1140,6 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI): exclude_id += params.getlist(key, []) if exclude_id: - id_values = [] for val in exclude_id: @@ -1107,7 +1159,6 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI): # TODO: (It needs to be profiled!) # TODO: It might be worth caching the bom_valid status to a database column if bom_valid is not None: - bom_valid = str2bool(bom_valid) # Limit queryset to active assemblies @@ -1138,7 +1189,6 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI): relation_filter = Q(part_1=related_part) | Q(part_2=related_part) for relation in PartRelated.objects.filter(relation_filter).distinct(): - if relation.part_1.pk != pk: part_ids.add(relation.part_1.pk) @@ -1160,7 +1210,9 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI): if starred is not None: starred = str2bool(starred) - starred_parts = [star.part.pk for star in self.request.user.starred_parts.all()] + starred_parts = [ + star.part.pk for star in self.request.user.starred_parts.all() + ] if starred: queryset = queryset.filter(pk__in=starred_parts) @@ -1187,7 +1239,9 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI): # If '?cascade=true' then include parts which exist in sub-categories if cascade: - queryset = queryset.filter(category__in=category.getUniqueChildren()) + queryset = queryset.filter( + category__in=category.getUniqueChildren() + ) # Just return parts directly in the requested category else: queryset = queryset.filter(category=cat_id) @@ -1219,7 +1273,9 @@ class PartList(PartMixin, APIDownloadMixin, ListCreateAPI): if result: template_id = result.group(1) ascending = not ordering.startswith('-') - queryset = part.filters.order_by_parameter(queryset, template_id, ascending) + queryset = part.filters.order_by_parameter( + queryset, template_id, ascending + ) return queryset @@ -1339,14 +1395,10 @@ class PartParameterTemplateFilter(rest_filters.FilterSet): model = PartParameterTemplate # Simple filter fields - fields = [ - 'units', - 'checkbox', - ] + fields = ['units', 'checkbox'] has_choices = rest_filters.BooleanFilter( - method='filter_has_choices', - label='Has Choice', + method='filter_has_choices', label='Has Choice' ) def filter_has_choices(self, queryset, name, value): @@ -1356,10 +1408,7 @@ class PartParameterTemplateFilter(rest_filters.FilterSet): return queryset.filter(Q(choices=None) | Q(choices='')).distinct() - has_units = rest_filters.BooleanFilter( - method='filter_has_units', - label='Has Units', - ) + has_units = rest_filters.BooleanFilter(method='filter_has_units', label='Has Units') def filter_has_units(self, queryset, name, value): """Filter queryset to include only PartParameterTemplates with units.""" @@ -1382,20 +1431,11 @@ class PartParameterTemplateList(ListCreateAPI): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'name', - ] + filterset_fields = ['name'] - search_fields = [ - 'name', - 'description', - ] + search_fields = ['name', 'description'] - ordering_fields = [ - 'name', - 'units', - 'checkbox', - ] + ordering_fields = ['name', 'units', 'checkbox'] def filter_queryset(self, queryset): """Custom filtering for the PartParameterTemplate API.""" @@ -1407,7 +1447,6 @@ class PartParameterTemplateList(ListCreateAPI): part = params.get('part', None) if part is not None: - try: part = Part.objects.get(pk=part) parameters = PartParameter.objects.filter(part=part) @@ -1420,7 +1459,6 @@ class PartParameterTemplateList(ListCreateAPI): category = params.get('category', None) if category is not None: - try: category = PartCategory.objects.get(pk=category) cats = category.get_descendants(include_self=True) @@ -1442,6 +1480,7 @@ class PartParameterTemplateDetail(RetrieveUpdateDestroyAPI): class PartParameterAPIMixin: """Mixin class for PartParameter API endpoints.""" + queryset = PartParameter.objects.all() serializer_class = part_serializers.PartParameterSerializer @@ -1460,7 +1499,9 @@ class PartParameterAPIMixin: """ try: kwargs['part_detail'] = str2bool(self.request.GET.get('part_detail', False)) - kwargs['template_detail'] = str2bool(self.request.GET.get('template_detail', True)) + kwargs['template_detail'] = str2bool( + self.request.GET.get('template_detail', True) + ) except AttributeError: pass @@ -1472,12 +1513,13 @@ class PartParameterFilter(rest_filters.FilterSet): class Meta: """Metaclass options for the filterset""" - model = PartParameter - fields = [ - 'template' - ] - part = rest_filters.ModelChoiceFilter(queryset=Part.objects.all(), method='filter_part') + model = PartParameter + fields = ['template'] + + part = rest_filters.ModelChoiceFilter( + queryset=Part.objects.all(), method='filter_part' + ) def filter_part(self, queryset, name, part): """Filter against the provided part. @@ -1506,12 +1548,7 @@ class PartParameterList(PartParameterAPIMixin, ListCreateAPI): filter_backends = SEARCH_ORDER_FILTER_ALIAS - ordering_fields = [ - 'name', - 'data', - 'part', - 'template', - ] + ordering_fields = ['name', 'data', 'part', 'template'] ordering_field_aliases = { 'name': 'template__name', @@ -1528,6 +1565,7 @@ class PartParameterList(PartParameterAPIMixin, ListCreateAPI): class PartParameterDetail(PartParameterAPIMixin, RetrieveUpdateDestroyAPI): """API endpoint for detail view of a single PartParameter object.""" + pass @@ -1538,10 +1576,7 @@ class PartStocktakeFilter(rest_filters.FilterSet): """Metaclass options""" model = PartStocktake - fields = [ - 'part', - 'user', - ] + fields = ['part', 'user'] class PartStocktakeList(ListCreateAPI): @@ -1560,14 +1595,7 @@ class PartStocktakeList(ListCreateAPI): filter_backends = ORDER_FILTER - ordering_fields = [ - 'part', - 'item_count', - 'quantity', - 'date', - 'user', - 'pk', - ] + ordering_fields = ['part', 'item_count', 'quantity', 'date', 'user', 'pk'] # Reverse date ordering by default ordering = '-pk' @@ -1591,10 +1619,7 @@ class PartStocktakeReportList(ListAPI): filter_backends = ORDER_FILTER - ordering_fields = [ - 'date', - 'pk', - ] + ordering_fields = ['date', 'pk'] # Newest first, by default ordering = '-pk' @@ -1605,10 +1630,7 @@ class PartStocktakeReportGenerate(CreateAPI): serializer_class = part_serializers.PartStocktakeReportGenerateSerializer - permission_classes = [ - permissions.IsAuthenticated, - RolePermission, - ] + permission_classes = [permissions.IsAuthenticated, RolePermission] role_required = 'stocktake' @@ -1627,23 +1649,27 @@ class BomFilter(rest_filters.FilterSet): """Metaclass options""" model = BomItem - fields = [ - 'optional', - 'consumable', - 'inherited', - 'allow_variants', - 'validated', - ] + fields = ['optional', 'consumable', 'inherited', 'allow_variants', 'validated'] # Filters for linked 'part' - part_active = rest_filters.BooleanFilter(label='Master part is active', field_name='part__active') - part_trackable = rest_filters.BooleanFilter(label='Master part is trackable', field_name='part__trackable') + part_active = rest_filters.BooleanFilter( + label='Master part is active', field_name='part__active' + ) + part_trackable = rest_filters.BooleanFilter( + label='Master part is trackable', field_name='part__trackable' + ) # Filters for linked 'sub_part' - sub_part_trackable = rest_filters.BooleanFilter(label='Sub part is trackable', field_name='sub_part__trackable') - sub_part_assembly = rest_filters.BooleanFilter(label='Sub part is an assembly', field_name='sub_part__assembly') + sub_part_trackable = rest_filters.BooleanFilter( + label='Sub part is trackable', field_name='sub_part__trackable' + ) + sub_part_assembly = rest_filters.BooleanFilter( + label='Sub part is an assembly', field_name='sub_part__assembly' + ) - available_stock = rest_filters.BooleanFilter(label="Has available stock", method="filter_available_stock") + available_stock = rest_filters.BooleanFilter( + label='Has available stock', method='filter_available_stock' + ) def filter_available_stock(self, queryset, name, value): """Filter the queryset based on whether each line item has any available stock""" @@ -1651,7 +1677,7 @@ class BomFilter(rest_filters.FilterSet): return queryset.filter(available_stock__gt=0) return queryset.filter(available_stock=0) - on_order = rest_filters.BooleanFilter(label="On order", method="filter_on_order") + on_order = rest_filters.BooleanFilter(label='On order', method='filter_on_order') def filter_on_order(self, queryset, name, value): """Filter the queryset based on whether each line item has any stock on order""" @@ -1659,12 +1685,17 @@ class BomFilter(rest_filters.FilterSet): return queryset.filter(on_order__gt=0) return queryset.filter(on_order=0) - has_pricing = rest_filters.BooleanFilter(label="Has Pricing", method="filter_has_pricing") + has_pricing = rest_filters.BooleanFilter( + label='Has Pricing', method='filter_has_pricing' + ) def filter_has_pricing(self, queryset, name, value): """Filter the queryset based on whether pricing information is available for the sub_part""" q_a = Q(sub_part__pricing_data=None) - q_b = Q(sub_part__pricing_data__overall_min=None, sub_part__pricing_data__overall_max=None) + q_b = Q( + sub_part__pricing_data__overall_min=None, + sub_part__pricing_data__overall_max=None, + ) if str2bool(value): return queryset.exclude(q_a | q_b) @@ -1692,7 +1723,9 @@ class BomMixin: pass try: - kwargs['sub_part_detail'] = str2bool(self.request.GET.get('sub_part_detail', None)) + kwargs['sub_part_detail'] = str2bool( + self.request.GET.get('sub_part_detail', None) + ) except AttributeError: pass @@ -1789,7 +1822,6 @@ class BomList(BomMixin, ListCreateDestroyAPIView): uses = params.get('uses', None) if uses is not None: - try: # Extract the part we are interested in uses_part = Part.objects.get(pk=uses) @@ -1823,13 +1855,12 @@ class BomList(BomMixin, ListCreateDestroyAPIView): 'consumable', ] - ordering_field_aliases = { - 'sub_part': 'sub_part__name', - } + ordering_field_aliases = {'sub_part': 'sub_part__name'} class BomDetail(BomMixin, RetrieveUpdateDestroyAPI): """API endpoint for detail view of a single BomItem object.""" + pass @@ -1876,6 +1907,7 @@ class BomItemValidate(UpdateAPI): class BomItemValidationSerializer(serializers.Serializer): """Simple serializer for passing a single boolean field""" + valid = serializers.BooleanField(default=False) queryset = BomItem.objects.all() @@ -1908,10 +1940,7 @@ class BomItemSubstituteList(ListCreateAPI): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'part', - 'bom_item', - ] + filterset_fields = ['part', 'bom_item'] class BomItemSubstituteDetail(RetrieveUpdateDestroyAPI): @@ -1922,167 +1951,351 @@ class BomItemSubstituteDetail(RetrieveUpdateDestroyAPI): part_api_urls = [ - # Base URL for PartCategory API endpoints - re_path(r'^category/', include([ - re_path(r'^tree/', CategoryTree.as_view(), name='api-part-category-tree'), - - re_path(r'^parameters/', include([ - re_path(r'^(?P\d+)/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': PartCategoryParameterTemplate}, name='api-part-category-parameter-metadata'), - re_path(r'^.*$', CategoryParameterDetail.as_view(), name='api-part-category-parameter-detail'), - ])), - re_path(r'^.*$', CategoryParameterList.as_view(), name='api-part-category-parameter-list'), - ])), - - # Category detail endpoints - path(r'/', include([ - - re_path(r'^metadata/', MetadataView.as_view(), {'model': PartCategory}, name='api-part-category-metadata'), - - # PartCategory detail endpoint - re_path(r'^.*$', CategoryDetail.as_view(), name='api-part-category-detail'), - ])), - - path('', CategoryList.as_view(), name='api-part-category-list'), - ])), - + re_path( + r'^category/', + include([ + re_path(r'^tree/', CategoryTree.as_view(), name='api-part-category-tree'), + re_path( + r'^parameters/', + include([ + re_path( + r'^(?P\d+)/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': PartCategoryParameterTemplate}, + name='api-part-category-parameter-metadata', + ), + re_path( + r'^.*$', + CategoryParameterDetail.as_view(), + name='api-part-category-parameter-detail', + ), + ]), + ), + re_path( + r'^.*$', + CategoryParameterList.as_view(), + name='api-part-category-parameter-list', + ), + ]), + ), + # Category detail endpoints + path( + r'/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': PartCategory}, + name='api-part-category-metadata', + ), + # PartCategory detail endpoint + re_path( + r'^.*$', + CategoryDetail.as_view(), + name='api-part-category-detail', + ), + ]), + ), + path('', CategoryList.as_view(), name='api-part-category-list'), + ]), + ), # Base URL for PartTestTemplate API endpoints - re_path(r'^test-template/', include([ - path(r'/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': PartTestTemplate}, name='api-part-test-template-metadata'), - re_path(r'^.*$', PartTestTemplateDetail.as_view(), name='api-part-test-template-detail'), - ])), - path('', PartTestTemplateList.as_view(), name='api-part-test-template-list'), - ])), - + re_path( + r'^test-template/', + include([ + path( + r'/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': PartTestTemplate}, + name='api-part-test-template-metadata', + ), + re_path( + r'^.*$', + PartTestTemplateDetail.as_view(), + name='api-part-test-template-detail', + ), + ]), + ), + path( + '', PartTestTemplateList.as_view(), name='api-part-test-template-list' + ), + ]), + ), # Base URL for PartAttachment API endpoints - re_path(r'^attachment/', include([ - path(r'/', PartAttachmentDetail.as_view(), name='api-part-attachment-detail'), - path('', PartAttachmentList.as_view(), name='api-part-attachment-list'), - ])), - + re_path( + r'^attachment/', + include([ + path( + r'/', + PartAttachmentDetail.as_view(), + name='api-part-attachment-detail', + ), + path('', PartAttachmentList.as_view(), name='api-part-attachment-list'), + ]), + ), # Base URL for part sale pricing - re_path(r'^sale-price/', include([ - path(r'/', PartSalePriceDetail.as_view(), name='api-part-sale-price-detail'), - re_path(r'^.*$', PartSalePriceList.as_view(), name='api-part-sale-price-list'), - ])), - + re_path( + r'^sale-price/', + include([ + path( + r'/', + PartSalePriceDetail.as_view(), + name='api-part-sale-price-detail', + ), + re_path( + r'^.*$', PartSalePriceList.as_view(), name='api-part-sale-price-list' + ), + ]), + ), # Base URL for part internal pricing - re_path(r'^internal-price/', include([ - path(r'/', PartInternalPriceDetail.as_view(), name='api-part-internal-price-detail'), - re_path(r'^.*$', PartInternalPriceList.as_view(), name='api-part-internal-price-list'), - ])), - + re_path( + r'^internal-price/', + include([ + path( + r'/', + PartInternalPriceDetail.as_view(), + name='api-part-internal-price-detail', + ), + re_path( + r'^.*$', + PartInternalPriceList.as_view(), + name='api-part-internal-price-list', + ), + ]), + ), # Base URL for PartRelated API endpoints - re_path(r'^related/', include([ - path(r'/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': PartRelated}, name='api-part-related-metadata'), - re_path(r'^.*$', PartRelatedDetail.as_view(), name='api-part-related-detail'), - ])), - re_path(r'^.*$', PartRelatedList.as_view(), name='api-part-related-list'), - ])), - + re_path( + r'^related/', + include([ + path( + r'/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': PartRelated}, + name='api-part-related-metadata', + ), + re_path( + r'^.*$', + PartRelatedDetail.as_view(), + name='api-part-related-detail', + ), + ]), + ), + re_path(r'^.*$', PartRelatedList.as_view(), name='api-part-related-list'), + ]), + ), # Base URL for PartParameter API endpoints - re_path(r'^parameter/', include([ - path('template/', include([ - re_path(r'^(?P\d+)/', include([ - re_path(r'^metadata/?', MetadataView.as_view(), {'model': PartParameterTemplate}, name='api-part-parameter-template-metadata'), - re_path(r'^.*$', PartParameterTemplateDetail.as_view(), name='api-part-parameter-template-detail'), - ])), - re_path(r'^.*$', PartParameterTemplateList.as_view(), name='api-part-parameter-template-list'), - ])), - - path(r'/', include([ - re_path(r'^metadata/?', MetadataView.as_view(), {'model': PartParameter}, name='api-part-parameter-metadata'), - re_path(r'^.*$', PartParameterDetail.as_view(), name='api-part-parameter-detail'), - ])), - re_path(r'^.*$', PartParameterList.as_view(), name='api-part-parameter-list'), - ])), - + re_path( + r'^parameter/', + include([ + path( + 'template/', + include([ + re_path( + r'^(?P\d+)/', + include([ + re_path( + r'^metadata/?', + MetadataView.as_view(), + {'model': PartParameterTemplate}, + name='api-part-parameter-template-metadata', + ), + re_path( + r'^.*$', + PartParameterTemplateDetail.as_view(), + name='api-part-parameter-template-detail', + ), + ]), + ), + re_path( + r'^.*$', + PartParameterTemplateList.as_view(), + name='api-part-parameter-template-list', + ), + ]), + ), + path( + r'/', + include([ + re_path( + r'^metadata/?', + MetadataView.as_view(), + {'model': PartParameter}, + name='api-part-parameter-metadata', + ), + re_path( + r'^.*$', + PartParameterDetail.as_view(), + name='api-part-parameter-detail', + ), + ]), + ), + re_path( + r'^.*$', PartParameterList.as_view(), name='api-part-parameter-list' + ), + ]), + ), # Part stocktake data - re_path(r'^stocktake/', include([ - - path(r'report/', include([ - path('generate/', PartStocktakeReportGenerate.as_view(), name='api-part-stocktake-report-generate'), - re_path(r'^.*$', PartStocktakeReportList.as_view(), name='api-part-stocktake-report-list'), - ])), - - path(r'/', PartStocktakeDetail.as_view(), name='api-part-stocktake-detail'), - re_path(r'^.*$', PartStocktakeList.as_view(), name='api-part-stocktake-list'), - ])), - - re_path(r'^thumbs/', include([ - path('', PartThumbs.as_view(), name='api-part-thumbs'), - re_path(r'^(?P\d+)/?', PartThumbsUpdate.as_view(), name='api-part-thumbs-update'), - ])), - + re_path( + r'^stocktake/', + include([ + path( + r'report/', + include([ + path( + 'generate/', + PartStocktakeReportGenerate.as_view(), + name='api-part-stocktake-report-generate', + ), + re_path( + r'^.*$', + PartStocktakeReportList.as_view(), + name='api-part-stocktake-report-list', + ), + ]), + ), + path( + r'/', + PartStocktakeDetail.as_view(), + name='api-part-stocktake-detail', + ), + re_path( + r'^.*$', PartStocktakeList.as_view(), name='api-part-stocktake-list' + ), + ]), + ), + re_path( + r'^thumbs/', + include([ + path('', PartThumbs.as_view(), name='api-part-thumbs'), + re_path( + r'^(?P\d+)/?', + PartThumbsUpdate.as_view(), + name='api-part-thumbs-update', + ), + ]), + ), # BOM template - re_path(r'^bom_template/?', views.BomUploadTemplate.as_view(), name='api-bom-upload-template'), - - path(r'/', include([ - - # Endpoint for extra serial number information - re_path(r'^serial-numbers/', PartSerialNumberDetail.as_view(), name='api-part-serial-number-detail'), - - # Endpoint for future scheduling information - re_path(r'^scheduling/', PartScheduling.as_view(), name='api-part-scheduling'), - - re_path(r'^requirements/', PartRequirements.as_view(), name='api-part-requirements'), - - # Endpoint for duplicating a BOM for the specific Part - re_path(r'^bom-copy/', PartCopyBOM.as_view(), name='api-part-bom-copy'), - - # Endpoint for validating a BOM for the specific Part - re_path(r'^bom-validate/', PartValidateBOM.as_view(), name='api-part-bom-validate'), - - # Part metadata - re_path(r'^metadata/', MetadataView.as_view(), {'model': Part}, name='api-part-metadata'), - - # Part pricing - re_path(r'^pricing/', PartPricingDetail.as_view(), name='api-part-pricing'), - - # BOM download - re_path(r'^bom-download/?', views.BomDownload.as_view(), name='api-bom-download'), - - # Old pricing endpoint - re_path(r'^pricing2/', views.PartPricing.as_view(), name='part-pricing'), - - # Part detail endpoint - re_path(r'^.*$', PartDetail.as_view(), name='api-part-detail'), - ])), - - re_path(r'^change_category/', PartChangeCategory.as_view(), name='api-part-change-category'), - + re_path( + r'^bom_template/?', + views.BomUploadTemplate.as_view(), + name='api-bom-upload-template', + ), + path( + r'/', + include([ + # Endpoint for extra serial number information + re_path( + r'^serial-numbers/', + PartSerialNumberDetail.as_view(), + name='api-part-serial-number-detail', + ), + # Endpoint for future scheduling information + re_path( + r'^scheduling/', PartScheduling.as_view(), name='api-part-scheduling' + ), + re_path( + r'^requirements/', + PartRequirements.as_view(), + name='api-part-requirements', + ), + # Endpoint for duplicating a BOM for the specific Part + re_path(r'^bom-copy/', PartCopyBOM.as_view(), name='api-part-bom-copy'), + # Endpoint for validating a BOM for the specific Part + re_path( + r'^bom-validate/', + PartValidateBOM.as_view(), + name='api-part-bom-validate', + ), + # Part metadata + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': Part}, + name='api-part-metadata', + ), + # Part pricing + re_path(r'^pricing/', PartPricingDetail.as_view(), name='api-part-pricing'), + # BOM download + re_path( + r'^bom-download/?', views.BomDownload.as_view(), name='api-bom-download' + ), + # Old pricing endpoint + re_path(r'^pricing2/', views.PartPricing.as_view(), name='part-pricing'), + # Part detail endpoint + re_path(r'^.*$', PartDetail.as_view(), name='api-part-detail'), + ]), + ), + re_path( + r'^change_category/', + PartChangeCategory.as_view(), + name='api-part-change-category', + ), re_path(r'^.*$', PartList.as_view(), name='api-part-list'), ] bom_api_urls = [ - - re_path(r'^substitute/', include([ - - # Detail view - path(r'/', include([ - re_path(r'^metadata/?', MetadataView.as_view(), {'model': BomItemSubstitute}, name='api-bom-substitute-metadata'), - re_path(r'^.*$', BomItemSubstituteDetail.as_view(), name='api-bom-substitute-detail'), - ])), - - # Catch all - re_path(r'^.*$', BomItemSubstituteList.as_view(), name='api-bom-substitute-list'), - ])), - + re_path( + r'^substitute/', + include([ + # Detail view + path( + r'/', + include([ + re_path( + r'^metadata/?', + MetadataView.as_view(), + {'model': BomItemSubstitute}, + name='api-bom-substitute-metadata', + ), + re_path( + r'^.*$', + BomItemSubstituteDetail.as_view(), + name='api-bom-substitute-detail', + ), + ]), + ), + # Catch all + re_path( + r'^.*$', BomItemSubstituteList.as_view(), name='api-bom-substitute-list' + ), + ]), + ), # BOM Item Detail - path(r'/', include([ - re_path(r'^validate/?', BomItemValidate.as_view(), name='api-bom-item-validate'), - re_path(r'^metadata/?', MetadataView.as_view(), {'model': BomItem}, name='api-bom-item-metadata'), - re_path(r'^.*$', BomDetail.as_view(), name='api-bom-item-detail'), - ])), - + path( + r'/', + include([ + re_path( + r'^validate/?', BomItemValidate.as_view(), name='api-bom-item-validate' + ), + re_path( + r'^metadata/?', + MetadataView.as_view(), + {'model': BomItem}, + name='api-bom-item-metadata', + ), + re_path(r'^.*$', BomDetail.as_view(), name='api-bom-item-detail'), + ]), + ), # API endpoint URLs for importing BOM data - re_path(r'^import/upload/', BomImportUpload.as_view(), name='api-bom-import-upload'), - re_path(r'^import/extract/', BomImportExtract.as_view(), name='api-bom-import-extract'), - re_path(r'^import/submit/', BomImportSubmit.as_view(), name='api-bom-import-submit'), - + re_path( + r'^import/upload/', BomImportUpload.as_view(), name='api-bom-import-upload' + ), + re_path( + r'^import/extract/', BomImportExtract.as_view(), name='api-bom-import-extract' + ), + re_path( + r'^import/submit/', BomImportSubmit.as_view(), name='api-bom-import-submit' + ), # Catch-all re_path(r'^.*$', BomList.as_view(), name='api-bom-list'), ] diff --git a/InvenTree/part/apps.py b/InvenTree/part/apps.py index 209cf60f9b..0f6c108427 100644 --- a/InvenTree/part/apps.py +++ b/InvenTree/part/apps.py @@ -7,17 +7,21 @@ from django.db.utils import OperationalError, ProgrammingError import InvenTree.ready -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') class PartConfig(AppConfig): """Config class for the 'part' app""" + name = 'part' def ready(self): """This function is called whenever the Part app is loaded.""" # skip loading if plugin registry is not loaded or we run in a background thread - if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread(): + if ( + not InvenTree.ready.isPluginRegistryLoaded() + or not InvenTree.ready.isInMainThread() + ): return if InvenTree.ready.isRunningMigrations(): @@ -35,7 +39,9 @@ class PartConfig(AppConfig): from .models import BomItem try: - items = BomItem.objects.filter(part__trackable=False, sub_part__trackable=True) + items = BomItem.objects.filter( + part__trackable=False, sub_part__trackable=True + ) for item in items: logger.info("Marking part '%s' as trackable", item.part.name) @@ -61,10 +67,12 @@ class PartConfig(AppConfig): if items.count() > 0: # Find any pricing objects which have the 'scheduled_for_update' flag set - logger.info("Resetting update flags for %s pricing objects...", items.count()) + logger.info( + 'Resetting update flags for %s pricing objects...', items.count() + ) for pricing in items: pricing.scheduled_for_update = False pricing.save() except Exception: - logger.exception("Failed to reset pricing flags - database not ready") + logger.exception('Failed to reset pricing flags - database not ready') diff --git a/InvenTree/part/bom.py b/InvenTree/part/bom.py index 525e60987b..88b67d17e0 100644 --- a/InvenTree/part/bom.py +++ b/InvenTree/part/bom.py @@ -8,8 +8,7 @@ from collections import OrderedDict from django.utils.translation import gettext as _ from company.models import ManufacturerPart, SupplierPart -from InvenTree.helpers import (DownloadFile, GetExportFormats, normalize, - str2bool) +from InvenTree.helpers import DownloadFile, GetExportFormats, normalize, str2bool from .admin import BomItemResource from .models import BomItem, BomItemSubstitute, Part @@ -31,10 +30,7 @@ def MakeBomTemplate(fmt): # This will then export just the row headers! query = BomItem.objects.filter(pk=None) - dataset = BomItemResource().export( - queryset=query, - importing=True - ) + dataset = BomItemResource().export(queryset=query, importing=True) data = dataset.export(fmt) @@ -43,7 +39,9 @@ def MakeBomTemplate(fmt): return DownloadFile(data, filename) -def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = None, **kwargs): +def ExportBom( + part: Part, fmt='csv', cascade: bool = False, max_levels: int = None, **kwargs +): """Export a BOM (Bill of Materials) for a given part. Args: @@ -80,7 +78,6 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No def add_items(items, level, cascade=True): # Add items at a given layer for item in items: - item.level = str(int(level)) # Avoid circular BOM references @@ -98,9 +95,7 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No add_items(top_level_items, 1, cascade) dataset = BomItemResource().export( - queryset=bom_items, - cascade=cascade, - include_pricing=pricing_data, + queryset=bom_items, cascade=cascade, include_pricing=pricing_data ) def add_columns_to_dataset(columns, column_size): @@ -122,7 +117,6 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No for bom_item in bom_items: substitutes = BomItemSubstitute.objects.filter(bom_item=bom_item) for s_idx, substitute in enumerate(substitutes): - """Create substitute part name column""" name = f'{_("Substitute Part")}{s_idx + 1}' value = substitute.part.name @@ -164,7 +158,9 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No parameter_cols[name] = {b_idx: value} # Add parameter columns to dataset - parameter_cols_ordered = OrderedDict(sorted(parameter_cols.items(), key=lambda x: x[0])) + parameter_cols_ordered = OrderedDict( + sorted(parameter_cols.items(), key=lambda x: x[0]) + ) add_columns_to_dataset(parameter_cols_ordered, len(bom_items)) if stock_data: @@ -180,7 +176,6 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No stock_cols = {} for b_idx, bom_item in enumerate(bom_items): - stock_data = [] sub_part = bom_item.sub_part @@ -197,19 +192,13 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No stock_data.append('') # Total "in stock" quantity for this part - stock_data.append( - str(normalize(sub_part.total_stock)) - ) + stock_data.append(str(normalize(sub_part.total_stock))) # Total "available stock" quantity for this part - stock_data.append( - str(normalize(sub_part.available_stock)) - ) + stock_data.append(str(normalize(sub_part.available_stock))) # Total "on order" quantity for this part - stock_data.append( - str(normalize(sub_part.on_order)) - ) + stock_data.append(str(normalize(sub_part.on_order))) for s_idx, header in enumerate(stock_headers): try: @@ -234,12 +223,12 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No # Include manufacturer data for each BOM item if manufacturer_data: - # Filter manufacturer parts - manufacturer_parts = ManufacturerPart.objects.filter(part__pk=b_part.pk).prefetch_related('supplier_parts') + manufacturer_parts = ManufacturerPart.objects.filter( + part__pk=b_part.pk + ).prefetch_related('supplier_parts') for mp_idx, mp_part in enumerate(manufacturer_parts): - # Extract the "name" field of the Manufacturer (Company) if mp_part and mp_part.manufacturer: manufacturer_name = mp_part.manufacturer.name @@ -265,9 +254,7 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No # We wish to include supplier data for this manufacturer part if supplier_data: - for sp_idx, sp_part in enumerate(mp_part.supplier_parts.all()): - supplier_parts_used.add(sp_part) if sp_part.supplier: @@ -281,11 +268,21 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No supplier_sku = '' # Generate column names for this supplier - k_sup = str(_("Supplier")) + "_" + str(mp_idx) + "_" + str(sp_idx) - k_sku = str(_("SKU")) + "_" + str(mp_idx) + "_" + str(sp_idx) + k_sup = ( + str(_('Supplier')) + + '_' + + str(mp_idx) + + '_' + + str(sp_idx) + ) + k_sku = ( + str(_('SKU')) + '_' + str(mp_idx) + '_' + str(sp_idx) + ) try: - manufacturer_cols[k_sup].update({bom_idx: supplier_name}) + manufacturer_cols[k_sup].update({ + bom_idx: supplier_name + }) manufacturer_cols[k_sku].update({bom_idx: supplier_sku}) except KeyError: manufacturer_cols[k_sup] = {bom_idx: supplier_name} @@ -294,8 +291,9 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No if supplier_data: # Add in any extra supplier parts, which are not associated with a manufacturer part - for sp_idx, sp_part in enumerate(SupplierPart.objects.filter(part__pk=b_part.pk)): - + for sp_idx, sp_part in enumerate( + SupplierPart.objects.filter(part__pk=b_part.pk) + ): if sp_part in supplier_parts_used: continue @@ -309,8 +307,8 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No supplier_sku = sp_part.SKU # Generate column names for this supplier - k_sup = str(_("Supplier")) + "_" + str(sp_idx) - k_sku = str(_("SKU")) + "_" + str(sp_idx) + k_sup = str(_('Supplier')) + '_' + str(sp_idx) + k_sku = str(_('SKU')) + '_' + str(sp_idx) try: manufacturer_cols[k_sup].update({bom_idx: supplier_name}) @@ -324,6 +322,6 @@ def ExportBom(part: Part, fmt='csv', cascade: bool = False, max_levels: int = No data = dataset.export(fmt) - filename = f"{part.full_name}_BOM.{fmt}" + filename = f'{part.full_name}_BOM.{fmt}' return DownloadFile(data, filename) diff --git a/InvenTree/part/filters.py b/InvenTree/part/filters.py index e48909dd87..0877882142 100644 --- a/InvenTree/part/filters.py +++ b/InvenTree/part/filters.py @@ -19,18 +19,32 @@ Relevant PRs: from decimal import Decimal from django.db import models -from django.db.models import (Case, DecimalField, Exists, ExpressionWrapper, F, - FloatField, Func, IntegerField, OuterRef, Q, - Subquery, Value, When) +from django.db.models import ( + Case, + DecimalField, + Exists, + ExpressionWrapper, + F, + FloatField, + Func, + IntegerField, + OuterRef, + Q, + Subquery, + Value, + When, +) from django.db.models.functions import Coalesce from sql_util.utils import SubquerySum import part.models import stock.models -from InvenTree.status_codes import (BuildStatusGroups, - PurchaseOrderStatusGroups, - SalesOrderStatusGroups) +from InvenTree.status_codes import ( + BuildStatusGroups, + PurchaseOrderStatusGroups, + SalesOrderStatusGroups, +) def annotate_on_order_quantity(reference: str = ''): @@ -46,27 +60,28 @@ def annotate_on_order_quantity(reference: str = ''): # Filter only 'active' purhase orders # Filter only line with outstanding quantity order_filter = Q( - order__status__in=PurchaseOrderStatusGroups.OPEN, - quantity__gt=F('received'), + order__status__in=PurchaseOrderStatusGroups.OPEN, quantity__gt=F('received') ) return Coalesce( SubquerySum( ExpressionWrapper( - F(f'{reference}supplier_parts__purchase_order_line_items__quantity') * F(f'{reference}supplier_parts__pack_quantity_native'), + F(f'{reference}supplier_parts__purchase_order_line_items__quantity') + * F(f'{reference}supplier_parts__pack_quantity_native'), output_field=DecimalField(), ), - filter=order_filter + filter=order_filter, ), Decimal(0), - output_field=DecimalField() + output_field=DecimalField(), ) - Coalesce( SubquerySum( ExpressionWrapper( - F(f'{reference}supplier_parts__purchase_order_line_items__received') * F(f'{reference}supplier_parts__pack_quantity_native'), + F(f'{reference}supplier_parts__purchase_order_line_items__received') + * F(f'{reference}supplier_parts__pack_quantity_native'), output_field=DecimalField(), ), - filter=order_filter + filter=order_filter, ), Decimal(0), output_field=DecimalField(), @@ -88,10 +103,7 @@ def annotate_total_stock(reference: str = ''): stock_filter = stock.models.StockItem.IN_STOCK_FILTER return Coalesce( - SubquerySum( - f'{reference}stock_items__quantity', - filter=stock_filter, - ), + SubquerySum(f'{reference}stock_items__quantity', filter=stock_filter), Decimal(0), output_field=models.DecimalField(), ) @@ -109,10 +121,7 @@ def annotate_build_order_requirements(reference: str = ''): build_filter = Q(build__status__in=BuildStatusGroups.ACTIVE_CODES) return Coalesce( - SubquerySum( - f'{reference}used_in__build_lines__quantity', - filter=build_filter, - ), + SubquerySum(f'{reference}used_in__build_lines__quantity', filter=build_filter), Decimal(0), output_field=models.DecimalField(), ) @@ -134,8 +143,7 @@ def annotate_build_order_allocations(reference: str = ''): return Coalesce( SubquerySum( - f'{reference}stock_items__allocations__quantity', - filter=build_filter, + f'{reference}stock_items__allocations__quantity', filter=build_filter ), Decimal(0), output_field=models.DecimalField(), @@ -169,7 +177,9 @@ def annotate_sales_order_allocations(reference: str = ''): ) -def variant_stock_query(reference: str = '', filter: Q = stock.models.StockItem.IN_STOCK_FILTER): +def variant_stock_query( + reference: str = '', filter: Q = stock.models.StockItem.IN_STOCK_FILTER +): """Create a queryset to retrieve all stock items for variant parts under the specified part - Useful for annotating a queryset with aggregated information about variant parts @@ -221,10 +231,10 @@ def annotate_category_parts(): Subquery( subquery.annotate( total=Func(F('pk'), function='COUNT', output_field=IntegerField()) - ).values('total'), + ).values('total') ), 0, - output_field=IntegerField() + output_field=IntegerField(), ) @@ -259,31 +269,33 @@ def order_by_parameter(queryset, template_id: int, ascending=True): A queryset of Part objects ordered by the given parameter """ template_filter = part.models.PartParameter.objects.filter( - template__id=template_id, - part_id=OuterRef('id'), + template__id=template_id, part_id=OuterRef('id') ) # Annotate the queryset with the parameter value, and whether it exists - queryset = queryset.annotate( - parameter_exists=Exists(template_filter) - ) + queryset = queryset.annotate(parameter_exists=Exists(template_filter)) # Annotate the text data value queryset = queryset.annotate( parameter_value=Case( When( parameter_exists=True, - then=Subquery(template_filter.values('data')[:1], output_field=models.CharField()), + then=Subquery( + template_filter.values('data')[:1], output_field=models.CharField() + ), ), default=Value('', output_field=models.CharField()), ), parameter_value_numeric=Case( When( parameter_exists=True, - then=Subquery(template_filter.values('data_numeric')[:1], output_field=models.FloatField()), + then=Subquery( + template_filter.values('data_numeric')[:1], + output_field=models.FloatField(), + ), ), default=Value(0, output_field=models.FloatField()), - ) + ), ) prefix = '' if ascending else '-' diff --git a/InvenTree/part/forms.py b/InvenTree/part/forms.py index f1934bc105..b52f335d24 100644 --- a/InvenTree/part/forms.py +++ b/InvenTree/part/forms.py @@ -18,14 +18,16 @@ class BomMatchItemForm(MatchItemForm): if 'quantity' in col_guess.lower(): return forms.CharField( required=False, - widget=forms.NumberInput(attrs={ - 'name': 'quantity' + str(row['index']), - 'class': 'numberinput', - 'type': 'number', - 'min': '0', - 'step': 'any', - 'value': clean_decimal(row.get('quantity', '')), - }) + widget=forms.NumberInput( + attrs={ + 'name': 'quantity' + str(row['index']), + 'class': 'numberinput', + 'type': 'number', + 'min': '0', + 'step': 'any', + 'value': clean_decimal(row.get('quantity', '')), + } + ), ) return super().get_special_field(col_guess, row, file_manager) @@ -36,14 +38,13 @@ class PartPriceForm(forms.Form): class Meta: """Metaclass defines fields for this form""" + model = Part - fields = [ - 'quantity', - ] + fields = ['quantity'] quantity = forms.IntegerField( required=True, initial=1, label=_('Quantity'), - help_text=_('Input quantity for price calculation') + help_text=_('Input quantity for price calculation'), ) diff --git a/InvenTree/part/helpers.py b/InvenTree/part/helpers.py index ca1150b2ee..e8960b29a7 100644 --- a/InvenTree/part/helpers.py +++ b/InvenTree/part/helpers.py @@ -29,7 +29,10 @@ def compile_full_name_template(*args, **kwargs): template_string = InvenTreeSetting.get_setting('PART_NAME_FORMAT', '') # Skip if the template string has not changed - if template_string == _part_full_name_template_string and _part_full_name_template is not None: + if ( + template_string == _part_full_name_template_string + and _part_full_name_template is not None + ): return _part_full_name_template # Cache the template string @@ -38,7 +41,7 @@ def compile_full_name_template(*args, **kwargs): env = Environment( autoescape=select_autoescape(default_for_string=False, default=False), variable_start_string='{{', - variable_end_string='}}' + variable_end_string='}}', ) # Compile the template @@ -65,7 +68,11 @@ def render_part_full_name(part) -> str: try: return template.render(part=part) except Exception as e: - logger.warning("exception while trying to create full name for part %s: %s", part.name, e) + logger.warning( + 'exception while trying to create full name for part %s: %s', + part.name, + e, + ) # Fallback to the default format elements = [el for el in [part.IPN, part.name, part.revision] if el] @@ -73,7 +80,7 @@ def render_part_full_name(part) -> str: # Subdirectory for storing part images -PART_IMAGE_DIR = "part_images" +PART_IMAGE_DIR = 'part_images' def get_part_image_directory() -> str: @@ -85,10 +92,9 @@ def get_part_image_directory() -> str: TODO: Future work may be needed here to support other storage backends, such as S3 """ - part_image_directory = os.path.abspath(os.path.join( - settings.MEDIA_ROOT, - PART_IMAGE_DIR, - )) + part_image_directory = os.path.abspath( + os.path.join(settings.MEDIA_ROOT, PART_IMAGE_DIR) + ) # Create the directory if it does not exist if not os.path.exists(part_image_directory): diff --git a/InvenTree/part/models.py b/InvenTree/part/models.py index be9fa33170..df6830211f 100644 --- a/InvenTree/part/models.py +++ b/InvenTree/part/models.py @@ -48,18 +48,26 @@ from common.settings import currency_code_default from company.models import SupplierPart from InvenTree import helpers, validators from InvenTree.fields import InvenTreeURLField -from InvenTree.helpers import (decimal2money, decimal2string, normalize, - str2bool) -from InvenTree.models import (DataImportMixin, InvenTreeAttachment, - InvenTreeBarcodeMixin, InvenTreeNotesMixin, - InvenTreeTree, MetadataMixin) -from InvenTree.status_codes import (BuildStatusGroups, PurchaseOrderStatus, - PurchaseOrderStatusGroups, - SalesOrderStatus, SalesOrderStatusGroups) +from InvenTree.helpers import decimal2money, decimal2string, normalize, str2bool +from InvenTree.models import ( + DataImportMixin, + InvenTreeAttachment, + InvenTreeBarcodeMixin, + InvenTreeNotesMixin, + InvenTreeTree, + MetadataMixin, +) +from InvenTree.status_codes import ( + BuildStatusGroups, + PurchaseOrderStatus, + PurchaseOrderStatusGroups, + SalesOrderStatus, + SalesOrderStatusGroups, +) from order import models as OrderModels from stock import models as StockModels -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') class PartCategory(MetadataMixin, InvenTreeTree): @@ -76,8 +84,9 @@ class PartCategory(MetadataMixin, InvenTreeTree): class Meta: """Metaclass defines extra model properties""" - verbose_name = _("Part Category") - verbose_name_plural = _("Part Categories") + + verbose_name = _('Part Category') + verbose_name_plural = _('Part Categories') def delete(self, *args, **kwargs): """Custom model deletion routine, which updates any child categories or parts. @@ -91,11 +100,13 @@ class PartCategory(MetadataMixin, InvenTreeTree): ) default_location = TreeForeignKey( - 'stock.StockLocation', related_name="default_categories", - null=True, blank=True, + 'stock.StockLocation', + related_name='default_categories', + null=True, + blank=True, on_delete=models.SET_NULL, verbose_name=_('Default Location'), - help_text=_('Default location for parts in this category') + help_text=_('Default location for parts in this category'), ) structural = models.BooleanField( @@ -103,16 +114,23 @@ class PartCategory(MetadataMixin, InvenTreeTree): verbose_name=_('Structural'), help_text=_( 'Parts may not be directly assigned to a structural category, ' - 'but may be assigned to child categories.'), + 'but may be assigned to child categories.' + ), ) - default_keywords = models.CharField(null=True, blank=True, max_length=250, verbose_name=_('Default keywords'), help_text=_('Default keywords for parts in this category')) + default_keywords = models.CharField( + null=True, + blank=True, + max_length=250, + verbose_name=_('Default keywords'), + help_text=_('Default keywords for parts in this category'), + ) icon = models.CharField( blank=True, max_length=100, - verbose_name=_("Icon"), - help_text=_("Icon (optional)") + verbose_name=_('Icon'), + help_text=_('Icon (optional)'), ) @staticmethod @@ -131,8 +149,11 @@ class PartCategory(MetadataMixin, InvenTreeTree): """ if self.pk and self.structural and self.partcount(False, False) > 0: raise ValidationError( - _("You cannot make this part category structural because some parts " - "are already assigned to it!")) + _( + 'You cannot make this part category structural because some parts ' + 'are already assigned to it!' + ) + ) super().clean() def get_parts(self, cascade=True) -> set[Part]: @@ -146,7 +167,9 @@ class PartCategory(MetadataMixin, InvenTreeTree): """ if cascade: """Select any parts which exist in this category or any child categories.""" - queryset = Part.objects.filter(category__in=self.getUniqueChildren(include_self=True)) + queryset = Part.objects.filter( + category__in=self.getUniqueChildren(include_self=True) + ) else: queryset = Part.objects.filter(category=self.pk) @@ -172,7 +195,11 @@ class PartCategory(MetadataMixin, InvenTreeTree): def prefetch_parts_parameters(self, cascade=True): """Prefectch parts parameters.""" - return self.get_parts(cascade=cascade).prefetch_related('parameters', 'parameters__template').all() + return ( + self.get_parts(cascade=cascade) + .prefetch_related('parameters', 'parameters__template') + .all() + ) def get_unique_parameters(self, cascade=True, prefetch=None): """Get all unique parameter names for all parts from this category.""" @@ -233,7 +260,9 @@ class PartCategory(MetadataMixin, InvenTreeTree): def get_parameter_templates(self): """Return parameter templates associated to category.""" - prefetch = PartCategoryParameterTemplate.objects.prefetch_related('category', 'parameter_template') + prefetch = PartCategoryParameterTemplate.objects.prefetch_related( + 'category', 'parameter_template' + ) return prefetch.filter(category=self.id) @@ -244,13 +273,9 @@ class PartCategory(MetadataMixin, InvenTreeTree): subscribers = set() if include_parents: - queryset = PartCategoryStar.objects.filter( - category__in=cats, - ) + queryset = PartCategoryStar.objects.filter(category__in=cats) else: - queryset = PartCategoryStar.objects.filter( - category=self, - ) + queryset = PartCategoryStar.objects.filter(category=self) for result in queryset: subscribers.add(result.user) @@ -270,17 +295,11 @@ class PartCategory(MetadataMixin, InvenTreeTree): return if status: - PartCategoryStar.objects.create( - category=self, - user=user - ) + PartCategoryStar.objects.create(category=self, user=user) else: # Note that this won't actually stop the user being subscribed, # if the user is subscribed to a parent category - PartCategoryStar.objects.filter( - category=self, - user=user, - ).delete() + PartCategoryStar.objects.filter(category=self, user=user).delete() def rename_part_image(instance, filename): @@ -309,13 +328,17 @@ class PartManager(TreeManager): def get_queryset(self): """Perform default prefetch operations when accessing Part model from the database""" - return super().get_queryset().prefetch_related( - 'category', - 'pricing_data', - 'category__parent', - 'stock_items', - 'builds', - 'tags', + return ( + super() + .get_queryset() + .prefetch_related( + 'category', + 'pricing_data', + 'category__parent', + 'stock_items', + 'builds', + 'tags', + ) ) @@ -363,15 +386,17 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) class Meta: """Metaclass defines extra model properties""" - verbose_name = _("Part") - verbose_name_plural = _("Parts") - ordering = ['name', ] + + verbose_name = _('Part') + verbose_name_plural = _('Parts') + ordering = ['name'] constraints = [ UniqueConstraint(fields=['name', 'IPN', 'revision'], name='unique_part') ] class MPTTMeta: """MPTT metaclass definitions""" + # For legacy reasons the 'variant_of' field is used to indicate the MPTT parent parent_attr = 'variant_of' @@ -382,11 +407,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) def api_instance_filters(self): """Return API query filters for limiting field results against this instance.""" - return { - 'variant_of': { - 'exclude_tree': self.pk, - } - } + return {'variant_of': {'exclude_tree': self.pk}} def get_context_data(self, request, **kwargs): """Return some useful context data about this part for template rendering.""" @@ -397,9 +418,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # Subscription status context['starred'] = self.is_starred_by(request.user) context['starred_directly'] = context['starred'] and self.is_starred_by( - request.user, - include_variants=False, - include_categories=False + request.user, include_variants=False, include_categories=False ) # Pre-calculate complex queries so they only need to be performed once @@ -411,13 +430,21 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) context['allocated_build_order_quantity'] = self.build_order_allocation_count() context['required_sales_order_quantity'] = self.required_sales_order_quantity() - context['allocated_sales_order_quantity'] = self.sales_order_allocation_count(pending=True) + context['allocated_sales_order_quantity'] = self.sales_order_allocation_count( + pending=True + ) context['available'] = self.available_stock context['on_order'] = self.on_order - context['required'] = context['required_build_order_quantity'] + context['required_sales_order_quantity'] - context['allocated'] = context['allocated_build_order_quantity'] + context['allocated_sales_order_quantity'] + context['required'] = ( + context['required_build_order_quantity'] + + context['required_sales_order_quantity'] + ) + context['allocated'] = ( + context['allocated_build_order_quantity'] + + context['allocated_sales_order_quantity'] + ) return context @@ -433,9 +460,12 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # Image has been changed if previous.image is not None and self.image != previous.image: - # Are there any (other) parts which reference the image? - n_refs = Part.objects.filter(image=previous.image).exclude(pk=self.pk).count() + n_refs = ( + Part.objects.filter(image=previous.image) + .exclude(pk=self.pk) + .count() + ) if n_refs == 0: logger.info("Deleting unused image file '%s'", previous.image) @@ -448,13 +478,11 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) try: super().save(*args, **kwargs) except InvalidMove: - raise ValidationError({ - 'variant_of': _('Invalid choice for parent part'), - }) + raise ValidationError({'variant_of': _('Invalid choice for parent part')}) def __str__(self): """Return a string representation of the Part (for use in the admin interface)""" - return f"{self.full_name} - {self.description}" + return f'{self.full_name} - {self.description}' def get_parts_in_bom(self, **kwargs): """Return a list of all parts in the BOM for this part. @@ -492,26 +520,35 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) try: if self.pk == parent.pk: - raise ValidationError({'sub_part': _(f"Part '{self}' cannot be used in BOM for '{parent}' (recursive)")}) + raise ValidationError({ + 'sub_part': _( + f"Part '{self}' cannot be used in BOM for '{parent}' (recursive)" + ) + }) if self.tree_id == parent.tree_id: - raise ValidationError({'sub_part': _(f"Part '{self}' cannot be used in BOM for '{parent}' (recursive)")}) + raise ValidationError({ + 'sub_part': _( + f"Part '{self}' cannot be used in BOM for '{parent}' (recursive)" + ) + }) bom_items = self.get_bom_items() # Ensure that the parent part does not appear under any child BOM item! for item in bom_items.all(): - # Check for simple match if item.sub_part == parent: - raise ValidationError({'sub_part': _(f"Part '{parent}' is used in BOM for '{self}' (recursive)")}) + raise ValidationError({ + 'sub_part': _( + f"Part '{parent}' is used in BOM for '{self}' (recursive)" + ) + }) # And recursively check too if recursive: result = result and item.sub_part.check_add_to_bom( - parent, - recursive=True, - raise_error=raise_error + parent, recursive=True, raise_error=raise_error ) except ValidationError as e: @@ -539,9 +576,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) return except ValidationError as exc: if raise_error: - raise ValidationError({ - 'name': exc.message, - }) + raise ValidationError({'name': exc.message}) def validate_ipn(self, raise_error=True): """Ensure that the IPN (internal part number) is valid for this Part" @@ -560,12 +595,12 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) break except ValidationError as exc: if raise_error: - raise ValidationError({ - 'IPN': exc.message - }) + raise ValidationError({'IPN': exc.message}) # If we get to here, none of the plugins have raised an error - pattern = common.models.InvenTreeSetting.get_setting('PART_IPN_REGEX', '', create=False).strip() + pattern = common.models.InvenTreeSetting.get_setting( + 'PART_IPN_REGEX', '', create=False + ).strip() if pattern: match = re.search(pattern, self.IPN) @@ -573,7 +608,14 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) if match is None: raise ValidationError(_(f'IPN must match regex pattern {pattern}')) - def validate_serial_number(self, serial: str, stock_item=None, check_duplicates=True, raise_error=False, **kwargs): + def validate_serial_number( + self, + serial: str, + stock_item=None, + check_duplicates=True, + raise_error=False, + **kwargs, + ): """Validate a serial number against this Part instance. Note: This function is exposed to any Validation plugins, and thus can be customized. @@ -626,7 +668,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) from part.models import Part from stock.models import StockItem - if common.models.InvenTreeSetting.get_setting('SERIAL_NUMBER_GLOBALLY_UNIQUE', False): + if common.models.InvenTreeSetting.get_setting( + 'SERIAL_NUMBER_GLOBALLY_UNIQUE', False + ): # Serial number must be unique across *all* parts parts = Part.objects.all() else: @@ -641,7 +685,11 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) if stock.exists(): if raise_error: - raise ValidationError(_("Stock item with this serial number already exists") + ": " + serial) + raise ValidationError( + _('Stock item with this serial number already exists') + + ': ' + + serial + ) else: return False else: @@ -671,10 +719,14 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) Returns: The latest serial number specified for this part, or None """ - stock = StockModels.StockItem.objects.all().exclude(serial=None).exclude(serial='') + stock = ( + StockModels.StockItem.objects.all().exclude(serial=None).exclude(serial='') + ) # Generate a query for any stock items for this part variant tree with non-empty serial numbers - if common.models.InvenTreeSetting.get_setting('SERIAL_NUMBER_GLOBALLY_UNIQUE', False): + if common.models.InvenTreeSetting.get_setting( + 'SERIAL_NUMBER_GLOBALLY_UNIQUE', False + ): # Serial numbers are unique across all parts pass else: @@ -727,7 +779,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) super().validate_unique(exclude) # User can decide whether duplicate IPN (Internal Part Number) values are allowed - allow_duplicate_ipn = common.models.InvenTreeSetting.get_setting('PART_ALLOW_DUPLICATE_IPN') + allow_duplicate_ipn = common.models.InvenTreeSetting.get_setting( + 'PART_ALLOW_DUPLICATE_IPN' + ) # Raise an error if an IPN is set, and it is a duplicate if self.IPN and not allow_duplicate_ipn: @@ -736,12 +790,18 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) if parts.exists(): raise ValidationError({ - 'IPN': _('Duplicate IPN not allowed in part settings'), + 'IPN': _('Duplicate IPN not allowed in part settings') }) # Ensure unique across (Name, revision, IPN) (as specified) - if Part.objects.exclude(pk=self.pk).filter(name=self.name, revision=self.revision, IPN=self.IPN).exists(): - raise ValidationError(_("Part with this Name, IPN and Revision already exists.")) + if ( + Part.objects.exclude(pk=self.pk) + .filter(name=self.name, revision=self.revision, IPN=self.IPN) + .exists() + ): + raise ValidationError( + _('Part with this Name, IPN and Revision already exists.') + ) def clean(self): """Perform cleaning operations for the Part model. @@ -754,8 +814,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) then we will force the parent part to be trackable. """ if self.category is not None and self.category.structural: - raise ValidationError( - {'category': _("Parts cannot be assigned to structural part categories!")}) + raise ValidationError({ + 'category': _('Parts cannot be assigned to structural part categories!') + }) super().clean() @@ -771,81 +832,85 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) if self.trackable: for part in self.get_used_in(): - if not part.trackable: part.trackable = True part.clean() part.save() name = models.CharField( - max_length=100, blank=False, - help_text=_('Part name'), - verbose_name=_('Name'), + max_length=100, blank=False, help_text=_('Part name'), verbose_name=_('Name') ) is_template = models.BooleanField( default=part_settings.part_template_default, verbose_name=_('Is Template'), - help_text=_('Is this part a template part?') + help_text=_('Is this part a template part?'), ) variant_of = models.ForeignKey( - 'part.Part', related_name='variants', - null=True, blank=True, - limit_choices_to={ - 'is_template': True, - }, + 'part.Part', + related_name='variants', + null=True, + blank=True, + limit_choices_to={'is_template': True}, on_delete=models.SET_NULL, help_text=_('Is this part a variant of another part?'), verbose_name=_('Variant Of'), ) description = models.CharField( - max_length=250, blank=True, + max_length=250, + blank=True, verbose_name=_('Description'), - help_text=_('Part description (optional)') + help_text=_('Part description (optional)'), ) keywords = models.CharField( - max_length=250, blank=True, null=True, + max_length=250, + blank=True, + null=True, verbose_name=_('Keywords'), - help_text=_('Part keywords to improve visibility in search results') + help_text=_('Part keywords to improve visibility in search results'), ) category = TreeForeignKey( - PartCategory, related_name='parts', - null=True, blank=True, + PartCategory, + related_name='parts', + null=True, + blank=True, on_delete=models.DO_NOTHING, verbose_name=_('Category'), - help_text=_('Part category') + help_text=_('Part category'), ) IPN = models.CharField( - max_length=100, blank=True, null=True, + max_length=100, + blank=True, + null=True, verbose_name=_('IPN'), help_text=_('Internal Part Number'), ) revision = models.CharField( - max_length=100, blank=True, null=True, + max_length=100, + blank=True, + null=True, help_text=_('Part revision or version number'), verbose_name=_('Revision'), ) link = InvenTreeURLField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Link'), - help_text=_('Link to external URL') + help_text=_('Link to external URL'), ) image = StdImageField( upload_to=rename_part_image, null=True, blank=True, - variations={ - 'thumbnail': (128, 128), - 'preview': (256, 256), - }, + variations={'thumbnail': (128, 128), 'preview': (256, 256)}, delete_orphans=False, verbose_name=_('Image'), ) @@ -853,7 +918,8 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) default_location = TreeForeignKey( 'stock.StockLocation', on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, help_text=_('Where is this item normally stored?'), related_name='default_parts', verbose_name=_('Default Location'), @@ -898,10 +964,11 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) default_supplier = models.ForeignKey( SupplierPart, on_delete=models.SET_NULL, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Default Supplier'), help_text=_('Default supplier part'), - related_name='default_parts' + related_name='default_parts', ) default_expiry = models.PositiveIntegerField( @@ -912,81 +979,113 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) ) minimum_stock = models.DecimalField( - max_digits=19, decimal_places=6, - default=0, validators=[MinValueValidator(0)], + max_digits=19, + decimal_places=6, + default=0, + validators=[MinValueValidator(0)], verbose_name=_('Minimum Stock'), - help_text=_('Minimum allowed stock level') + help_text=_('Minimum allowed stock level'), ) units = models.CharField( - max_length=20, default="", - blank=True, null=True, + max_length=20, + default='', + blank=True, + null=True, verbose_name=_('Units'), help_text=_('Units of measure for this part'), - validators=[ - validators.validate_physical_units, - ] + validators=[validators.validate_physical_units], ) assembly = models.BooleanField( default=part_settings.part_assembly_default, verbose_name=_('Assembly'), - help_text=_('Can this part be built from other parts?') + help_text=_('Can this part be built from other parts?'), ) component = models.BooleanField( default=part_settings.part_component_default, verbose_name=_('Component'), - help_text=_('Can this part be used to build other parts?') + help_text=_('Can this part be used to build other parts?'), ) trackable = models.BooleanField( default=part_settings.part_trackable_default, verbose_name=_('Trackable'), - help_text=_('Does this part have tracking for unique items?')) + help_text=_('Does this part have tracking for unique items?'), + ) purchaseable = models.BooleanField( default=part_settings.part_purchaseable_default, verbose_name=_('Purchaseable'), - help_text=_('Can this part be purchased from external suppliers?')) + help_text=_('Can this part be purchased from external suppliers?'), + ) salable = models.BooleanField( default=part_settings.part_salable_default, verbose_name=_('Salable'), - help_text=_("Can this part be sold to customers?")) + help_text=_('Can this part be sold to customers?'), + ) active = models.BooleanField( - default=True, - verbose_name=_('Active'), - help_text=_('Is this part active?')) + default=True, verbose_name=_('Active'), help_text=_('Is this part active?') + ) virtual = models.BooleanField( default=part_settings.part_virtual_default, verbose_name=_('Virtual'), - help_text=_('Is this a virtual part, such as a software product or license?')) + help_text=_('Is this a virtual part, such as a software product or license?'), + ) - bom_checksum = models.CharField(max_length=128, blank=True, verbose_name=_('BOM checksum'), help_text=_('Stored BOM checksum')) + bom_checksum = models.CharField( + max_length=128, + blank=True, + verbose_name=_('BOM checksum'), + help_text=_('Stored BOM checksum'), + ) - bom_checked_by = models.ForeignKey(User, on_delete=models.SET_NULL, blank=True, null=True, - verbose_name=_('BOM checked by'), related_name='boms_checked') + bom_checked_by = models.ForeignKey( + User, + on_delete=models.SET_NULL, + blank=True, + null=True, + verbose_name=_('BOM checked by'), + related_name='boms_checked', + ) - bom_checked_date = models.DateField(blank=True, null=True, verbose_name=_('BOM checked date')) + bom_checked_date = models.DateField( + blank=True, null=True, verbose_name=_('BOM checked date') + ) - creation_date = models.DateField(auto_now_add=True, editable=False, blank=True, null=True, verbose_name=_('Creation Date')) + creation_date = models.DateField( + auto_now_add=True, + editable=False, + blank=True, + null=True, + verbose_name=_('Creation Date'), + ) - creation_user = models.ForeignKey(User, on_delete=models.SET_NULL, blank=True, null=True, verbose_name=_('Creation User'), related_name='parts_created') + creation_user = models.ForeignKey( + User, + on_delete=models.SET_NULL, + blank=True, + null=True, + verbose_name=_('Creation User'), + related_name='parts_created', + ) responsible_owner = models.ForeignKey( - users.models.Owner, on_delete=models.SET_NULL, - blank=True, null=True, + users.models.Owner, + on_delete=models.SET_NULL, + blank=True, + null=True, verbose_name=_('Responsible'), help_text=_('Owner responsible for this part'), - related_name='parts_responsible' + related_name='parts_responsible', ) last_stocktake = models.DateField( - blank=True, null=True, - verbose_name=_('Last Stocktake'), + blank=True, null=True, verbose_name=_('Last Stocktake') ) @property @@ -1013,8 +1112,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # Now, get a list of outstanding build orders which require this part builds = BuildModels.Build.objects.filter( - part__in=self.get_used_in(), - status__in=BuildStatusGroups.ACTIVE_CODES + part__in=self.get_used_in(), status__in=BuildStatusGroups.ACTIVE_CODES ) return builds @@ -1027,7 +1125,6 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) quantity = 0 for build in builds: - bom_item = None # List the bom lines required to make the build (including inherited ones!) @@ -1035,7 +1132,6 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # Match BOM item to build for bom_item in bom_items: - build_quantity = build.quantity * bom_item.quantity quantity += build_quantity @@ -1048,8 +1144,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # Get a list of line items for open orders which match this part open_lines = OrderModels.SalesOrderLineItem.objects.filter( - order__status__in=SalesOrderStatusGroups.OPEN, - part=self + order__status__in=SalesOrderStatusGroups.OPEN, part=self ) for line in open_lines: @@ -1061,8 +1156,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) """Return the quantity of this part required for active sales orders.""" # Get a list of line items for open orders which match this part open_lines = OrderModels.SalesOrderLineItem.objects.filter( - order__status__in=SalesOrderStatusGroups.OPEN, - part=self + order__status__in=SalesOrderStatusGroups.OPEN, part=self ) quantity = 0 @@ -1076,7 +1170,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) def required_order_quantity(self): """Return total required to fulfil orders.""" - return self.required_build_order_quantity() + self.required_sales_order_quantity() + return ( + self.required_build_order_quantity() + self.required_sales_order_quantity() + ) @property def quantity_to_order(self): @@ -1136,9 +1232,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) queryset = PartStar.objects.all() if include_variants: - queryset = queryset.filter( - part__in=self.get_ancestors(include_self=True), - ) + queryset = queryset.filter(part__in=self.get_ancestors(include_self=True)) else: queryset = queryset.filter(part=self) @@ -1146,7 +1240,6 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) subscribers.add(star.user) if include_categories and self.category: - for sub in self.category.get_subscribers(): subscribers.add(sub) @@ -1217,13 +1310,19 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) ref = 'substitutes__part__' queryset = queryset.alias( sub_total_stock=part.filters.annotate_total_stock(reference=ref), - sub_so_allocations=part.filters.annotate_sales_order_allocations(reference=ref), - sub_bo_allocations=part.filters.annotate_build_order_allocations(reference=ref), + sub_so_allocations=part.filters.annotate_sales_order_allocations( + reference=ref + ), + sub_bo_allocations=part.filters.annotate_build_order_allocations( + reference=ref + ), ) queryset = queryset.annotate( substitute_stock=ExpressionWrapper( - F('sub_total_stock') - F('sub_so_allocations') - F('sub_bo_allocations'), + F('sub_total_stock') + - F('sub_so_allocations') + - F('sub_bo_allocations'), output_field=models.DecimalField(), ) ) @@ -1232,20 +1331,27 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) variant_stock_query = part.filters.variant_stock_query(reference='sub_part__') queryset = queryset.alias( - var_total_stock=part.filters.annotate_variant_quantity(variant_stock_query, reference='quantity'), - var_bo_allocations=part.filters.annotate_variant_quantity(variant_stock_query, reference='allocations__quantity'), - var_so_allocations=part.filters.annotate_variant_quantity(variant_stock_query, reference='sales_order_allocations__quantity'), + var_total_stock=part.filters.annotate_variant_quantity( + variant_stock_query, reference='quantity' + ), + var_bo_allocations=part.filters.annotate_variant_quantity( + variant_stock_query, reference='allocations__quantity' + ), + var_so_allocations=part.filters.annotate_variant_quantity( + variant_stock_query, reference='sales_order_allocations__quantity' + ), ) queryset = queryset.annotate( variant_stock=ExpressionWrapper( - F('var_total_stock') - F('var_bo_allocations') - F('var_so_allocations'), + F('var_total_stock') + - F('var_bo_allocations') + - F('var_so_allocations'), output_field=models.DecimalField(), ) ) for item in queryset.all(): - if item.quantity <= 0: # Ignore zero-quantity items continue @@ -1299,9 +1405,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) if include_variants: variants = self.get_descendants(include_self=True) - queryset = queryset.filter( - stock_item__part__in=variants, - ) + queryset = queryset.filter(stock_item__part__in=variants) else: queryset = queryset.filter(stock_item__part=self) @@ -1311,10 +1415,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) """Return the total amount of this part allocated to build orders.""" query = self.build_order_allocations(**kwargs).aggregate( total=Coalesce( - Sum( - 'quantity', - output_field=models.DecimalField() - ), + Sum('quantity', output_field=models.DecimalField()), 0, output_field=models.DecimalField(), ) @@ -1331,9 +1432,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) if include_variants: # Include allocations for all variants variants = self.get_descendants(include_self=True) - queryset = queryset.filter( - item__part__in=variants, - ) + queryset = queryset.filter(item__part__in=variants) else: # Only look at this part queryset = queryset.filter(item__part=self) @@ -1360,10 +1459,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) """Return the total quantity of this part allocated to sales orders.""" query = self.sales_order_allocations(**kwargs).aggregate( total=Coalesce( - Sum( - 'quantity', - output_field=models.DecimalField(), - ), + Sum('quantity', output_field=models.DecimalField()), 0, output_field=models.DecimalField(), ) @@ -1377,12 +1473,10 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # If this instance has not been saved, foreign-key lookups will fail return 0 - return sum( - [ - self.build_order_allocation_count(**kwargs), - self.sales_order_allocation_count(**kwargs), - ], - ) + return sum([ + self.build_order_allocation_count(**kwargs), + self.sales_order_allocation_count(**kwargs), + ]) def stock_entries(self, include_variants=True, in_stock=None, location=None): """Return all stock entries for this Part. @@ -1393,7 +1487,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) location: If set, filter by stock entries in the specified location """ if include_variants: - query = StockModels.StockItem.objects.filter(part__in=self.get_descendants(include_self=True)) + query = StockModels.StockItem.objects.filter( + part__in=self.get_descendants(include_self=True) + ) else: query = self.stock_items @@ -1448,11 +1544,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # There are parents available if parents.exists(): - - parent_filter = Q( - part__in=parents, - inherited=True - ) + parent_filter = Q(part__in=parents, inherited=True) # OR the filters together bom_filter |= parent_filter @@ -1464,11 +1556,15 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) By default, will include inherited BOM items """ - queryset = BomItem.objects.filter(self.get_bom_item_filter(include_inherited=include_inherited)) + queryset = BomItem.objects.filter( + self.get_bom_item_filter(include_inherited=include_inherited) + ) return queryset.prefetch_related('sub_part') - def get_installed_part_options(self, include_inherited: bool = True, include_variants: bool = True): + def get_installed_part_options( + self, include_inherited: bool = True, include_variants: bool = True + ): """Return a set of all Parts which can be "installed" into this part, based on the BOM. Arguments: @@ -1478,7 +1574,6 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) parts = set() for bom_item in self.get_bom_items(include_inherited=include_inherited): - if include_variants and bom_item.allow_variants: for part in bom_item.sub_part.get_descendants(include_self=True): parts.add(part) @@ -1487,7 +1582,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) return parts - def get_used_in_bom_item_filter(self, include_variants=True, include_substitutes=True): + def get_used_in_bom_item_filter( + self, include_variants=True, include_substitutes=True + ): """Return a BomItem queryset which returns all BomItem instances which refer to *this* part. As the BOM allocation logic is somewhat complicted, there are some considerations: @@ -1506,26 +1603,18 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) parents = [] # Case A: This part is directly specified in a BomItem (we always use this case) - query = Q( - sub_part=self, - ) + query = Q(sub_part=self) if include_variants: # Case B: This part is a *variant* of a part which is specified in a BomItem which allows variants - query |= Q( - allow_variants=True, - sub_part__in=parents, - ) + query |= Q(allow_variants=True, sub_part__in=parents) # Case C: This part is a *substitute* of a part which is directly specified in a BomItem if include_substitutes: - # Grab a list of BomItem substitutes which reference this part substitutes = self.substitute_items.all() - query |= Q( - pk__in=[substitute.bom_item.pk for substitute in substitutes], - ) + query |= Q(pk__in=[substitute.bom_item.pk for substitute in substitutes]) return query @@ -1536,9 +1625,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) """ # Grab a queryset of all BomItem objects which "require" this part bom_items = BomItem.objects.filter( - self.get_used_in_bom_item_filter( - include_substitutes=include_substitutes - ) + self.get_used_in_bom_item_filter(include_substitutes=include_substitutes) ) # Iterate through the returned items and construct a set of @@ -1653,11 +1740,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) bom_items = self.get_bom_items() for bom_item in bom_items: - sub_part = bom_item.sub_part if sub_part not in parts: - parts.add(sub_part) if recursive: @@ -1738,7 +1823,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) min_price = normalize(min_price) max_price = normalize(max_price) - return f"{min_price} - {max_price}" + return f'{min_price} - {max_price}' def get_supplier_price_range(self, quantity=1): """Return the supplier price range of this part: @@ -1756,7 +1841,6 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) max_price = None for supplier in self.supplier_parts.all(): - price = supplier.get_price(quantity) if price is None: @@ -1787,15 +1871,16 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) max_price = None for item in self.get_bom_items().select_related('sub_part'): - if item.sub_part.pk == self.pk: - logger.warning("WARNING: BomItem ID %s contains itself in BOM", item.pk) + logger.warning('WARNING: BomItem ID %s contains itself in BOM', item.pk) continue q = decimal.Decimal(quantity) i = decimal.Decimal(item.quantity) - prices = item.sub_part.get_price_range(q * i, internal=internal, purchase=purchase) + prices = item.sub_part.get_price_range( + q * i, internal=internal, purchase=purchase + ) if prices is None: continue @@ -1819,7 +1904,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) return (min_price, max_price) - def get_price_range(self, quantity=1, buy=True, bom=True, internal=False, purchase=False): + def get_price_range( + self, quantity=1, buy=True, bom=True, internal=False, purchase=False + ): """Return the price range for this part. This price can be either: @@ -1843,7 +1930,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) return purchase_price buy_price_range = self.get_supplier_price_range(quantity) if buy else None - bom_price_range = self.get_bom_price_range(quantity, internal=internal) if bom else None + bom_price_range = ( + self.get_bom_price_range(quantity, internal=internal) if bom else None + ) if buy_price_range is None: return bom_price_range @@ -1852,12 +1941,24 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) return buy_price_range return ( min(buy_price_range[0], bom_price_range[0]), - max(buy_price_range[1], bom_price_range[1]) + max(buy_price_range[1], bom_price_range[1]), ) - base_cost = models.DecimalField(max_digits=19, decimal_places=6, default=0, validators=[MinValueValidator(0)], verbose_name=_('base cost'), help_text=_('Minimum charge (e.g. stocking fee)')) + base_cost = models.DecimalField( + max_digits=19, + decimal_places=6, + default=0, + validators=[MinValueValidator(0)], + verbose_name=_('base cost'), + help_text=_('Minimum charge (e.g. stocking fee)'), + ) - multiple = models.PositiveIntegerField(default=1, validators=[MinValueValidator(1)], verbose_name=_('multiple'), help_text=_('Sell multiple')) + multiple = models.PositiveIntegerField( + default=1, + validators=[MinValueValidator(1)], + verbose_name=_('multiple'), + help_text=_('Sell multiple'), + ) get_price = common.models.get_price @@ -1887,15 +1988,13 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) if self.price_breaks.filter(quantity=quantity, part=self.pk).exists(): return - PartSellPriceBreak.objects.create( - part=self, - quantity=quantity, - price=price - ) + PartSellPriceBreak.objects.create(part=self, quantity=quantity, price=price) def get_internal_price(self, quantity, moq=True, multiples=True, currency=None): """Return the internal price of this Part at the specified quantity""" - return common.models.get_price(self, quantity, moq, multiples, currency, break_name='internal_price_breaks') + return common.models.get_price( + self, quantity, moq, multiples, currency, break_name='internal_price_breaks' + ) @property def has_internal_price_breaks(self): @@ -1915,7 +2014,11 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) """ currency = currency_code_default() try: - prices = [convert_money(item.purchase_price, currency).amount for item in self.stock_items.all() if item.purchase_price] + prices = [ + convert_money(item.purchase_price, currency).amount + for item in self.stock_items.all() + if item.purchase_price + ] except MissingRate: prices = None @@ -1995,14 +2098,15 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) self.get_parameters().delete() for parameter in other.get_parameters(): - # If this part already has a parameter pointing to the same template, # delete that parameter from this part first! try: - existing = PartParameter.objects.get(part=self, template=parameter.template) + existing = PartParameter.objects.get( + part=self, template=parameter.template + ) existing.delete() - except (PartParameter.DoesNotExist): + except PartParameter.DoesNotExist: pass parameter.part = self @@ -2020,7 +2124,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) include_parent: Set to True to traverse upwards """ if include_parent: - tests = PartTestTemplate.objects.filter(part__in=self.get_ancestors(include_self=True)) + tests = PartTestTemplate.objects.filter( + part__in=self.get_ancestors(include_self=True) + ) else: tests = self.test_templates @@ -2074,7 +2180,9 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) """Return a list of purchase orders which reference this part.""" orders = [] - for part in self.supplier_parts.all().prefetch_related('purchase_order_line_items'): + for part in self.supplier_parts.all().prefetch_related( + 'purchase_order_line_items' + ): for order in part.purchase_orders(): if order not in orders: orders.append(order) @@ -2092,7 +2200,6 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # Iterate through all supplier parts for sp in self.supplier_parts.all(): - # Look at any incomplete line item for open orders lines = sp.purchase_order_line_items.filter( order__status__in=PurchaseOrderStatusGroups.OPEN, @@ -2191,10 +2298,7 @@ class Part(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, MPTTModel) # Ensure this part is not in the queryset, somehow filtered_parts = filtered_parts.exclude(pk=self.pk) - filtered_parts = filtered_parts.filter( - active=True, - virtual=False, - ) + filtered_parts = filtered_parts.filter(active=True, virtual=False) return filtered_parts @@ -2238,16 +2342,16 @@ def after_save_part(sender, instance: Part, created, **kwargs): # Run this check in the background try: - InvenTree.tasks.offload_task(part_tasks.notify_low_stock_if_required, instance) + InvenTree.tasks.offload_task( + part_tasks.notify_low_stock_if_required, instance + ) except PicklingError: # Can sometimes occur if the referenced Part has issues pass # Schedule a background task to rebuild any supplier parts InvenTree.tasks.offload_task( - part_tasks.rebuild_supplier_parts, - instance.pk, - force_async=True + part_tasks.rebuild_supplier_parts, instance.pk, force_async=True ) @@ -2297,7 +2401,11 @@ class PartPricing(common.models.MetaMixin): try: result = convert_money(money, target_currency) except MissingRate: - logger.warning("No currency conversion rate available for %s -> %s", money.currency, target_currency) + logger.warning( + 'No currency conversion rate available for %s -> %s', + money.currency, + target_currency, + ) result = None return result @@ -2318,8 +2426,14 @@ class PartPricing(common.models.MetaMixin): if InvenTree.ready.isRunningMigrations(): return - if not self.part or not self.part.pk or not Part.objects.filter(pk=self.part.pk).exists(): - logger.warning("Referenced part instance does not exist - skipping pricing update.") + if ( + not self.part + or not self.part.pk + or not Part.objects.filter(pk=self.part.pk).exists() + ): + logger.warning( + 'Referenced part instance does not exist - skipping pricing update.' + ) return try: @@ -2327,7 +2441,9 @@ class PartPricing(common.models.MetaMixin): self.refresh_from_db() except (PartPricing.DoesNotExist, IntegrityError): # Error thrown if this PartPricing instance has already been removed - logger.warning("Error refreshing PartPricing instance for part '%s'", self.part) + logger.warning( + "Error refreshing PartPricing instance for part '%s'", self.part + ) return # Ensure that the referenced part still exists in the database @@ -2335,17 +2451,21 @@ class PartPricing(common.models.MetaMixin): p = self.part p.refresh_from_db() except IntegrityError: - logger.exception("Could not update PartPricing as Part '%s' does not exist", self.part) + logger.exception( + "Could not update PartPricing as Part '%s' does not exist", self.part + ) return if self.scheduled_for_update: # Ignore if the pricing is already scheduled to be updated - logger.debug("Pricing for %s already scheduled for update - skipping", p) + logger.debug('Pricing for %s already scheduled for update - skipping', p) return if counter > 25: # Prevent infinite recursion / stack depth issues - logger.debug(counter, f"Skipping pricing update for {p} - maximum depth exceeded") + logger.debug( + counter, f'Skipping pricing update for {p} - maximum depth exceeded' + ) return try: @@ -2353,7 +2473,9 @@ class PartPricing(common.models.MetaMixin): self.save() except IntegrityError: # An IntegrityError here likely indicates that the referenced part has already been deleted - logger.exception("Could not save PartPricing for part '%s' to the database", self.part) + logger.exception( + "Could not save PartPricing for part '%s' to the database", self.part + ) return import part.tasks as part_tasks @@ -2361,10 +2483,7 @@ class PartPricing(common.models.MetaMixin): # Offload task to update the pricing # Force async, to prevent running in the foreground InvenTree.tasks.offload_task( - part_tasks.update_part_pricing, - self, - counter=counter, - force_async=True + part_tasks.update_part_pricing, self, counter=counter, force_async=True ) def update_pricing(self, counter: int = 0, cascade: bool = True): @@ -2515,7 +2634,10 @@ class PartPricing(common.models.MetaMixin): else: self.bom_cost_max = None - if old_bom_cost_min != self.bom_cost_min or old_bom_cost_max != self.bom_cost_max: + if ( + old_bom_cost_min != self.bom_cost_min + or old_bom_cost_max != self.bom_cost_max + ): self.price_modified = True if save: @@ -2540,12 +2662,13 @@ class PartPricing(common.models.MetaMixin): purchase_max = None for line in line_items: - if line.purchase_price is None: continue # Take supplier part pack size into account - purchase_cost = self.convert(line.purchase_price / line.part.pack_quantity_native) + purchase_cost = self.convert( + line.purchase_price / line.part.pack_quantity_native + ) if purchase_cost is None: continue @@ -2558,11 +2681,14 @@ class PartPricing(common.models.MetaMixin): # Also check if manual stock item pricing is included if InvenTreeSetting.get_setting('PRICING_USE_STOCK_PRICING', True, cache=False): - items = self.part.stock_items.all() # Limit to stock items updated within a certain window - days = int(InvenTreeSetting.get_setting('PRICING_STOCK_ITEM_AGE_DAYS', 0, cache=False)) + days = int( + InvenTreeSetting.get_setting( + 'PRICING_STOCK_ITEM_AGE_DAYS', 0, cache=False + ) + ) if days > 0: date_threshold = datetime.now().date() - timedelta(days=days) @@ -2581,7 +2707,10 @@ class PartPricing(common.models.MetaMixin): if purchase_max is None or cost > purchase_max: purchase_max = cost - if self.purchase_cost_min != purchase_min or self.purchase_cost_max != purchase_max: + if ( + self.purchase_cost_min != purchase_min + or self.purchase_cost_max != purchase_max + ): self.price_modified = True self.purchase_cost_min = purchase_min @@ -2610,7 +2739,10 @@ class PartPricing(common.models.MetaMixin): if max_int_cost is None or cost > max_int_cost: max_int_cost = cost - if self.internal_cost_min != min_int_cost or self.internal_cost_max != max_int_cost: + if ( + self.internal_cost_min != min_int_cost + or self.internal_cost_max != max_int_cost + ): self.price_modified = True self.internal_cost_min = min_int_cost @@ -2629,13 +2761,10 @@ class PartPricing(common.models.MetaMixin): max_sup_cost = None if self.part.purchaseable: - # Iterate through each available SupplierPart instance for sp in self.part.supplier_parts.all(): - # Iterate through each available SupplierPriceBreak instance for pb in sp.pricebreaks.all(): - if pb.price is None: continue @@ -2651,7 +2780,10 @@ class PartPricing(common.models.MetaMixin): if max_sup_cost is None or cost > max_sup_cost: max_sup_cost = cost - if self.supplier_price_min != min_sup_cost or self.supplier_price_max != max_sup_cost: + if ( + self.supplier_price_min != min_sup_cost + or self.supplier_price_max != max_sup_cost + ): self.price_modified = True self.supplier_price_min = min_sup_cost @@ -2674,7 +2806,6 @@ class PartPricing(common.models.MetaMixin): variants = self.part.get_descendants(include_self=False) for v in variants: - if active_only and not v.active: # Ignore inactive variant parts continue @@ -2708,21 +2839,17 @@ class PartPricing(common.models.MetaMixin): overall_min = None overall_max = None - min_costs = [ - self.bom_cost_min, - self.purchase_cost_min, - self.internal_cost_min, - ] + min_costs = [self.bom_cost_min, self.purchase_cost_min, self.internal_cost_min] - max_costs = [ - self.bom_cost_max, - self.purchase_cost_max, - self.internal_cost_max, - ] + max_costs = [self.bom_cost_max, self.purchase_cost_max, self.internal_cost_max] - purchase_history_override = InvenTreeSetting.get_setting('PRICING_PURCHASE_HISTORY_OVERRIDES_SUPPLIER', False, cache=False) + purchase_history_override = InvenTreeSetting.get_setting( + 'PRICING_PURCHASE_HISTORY_OVERRIDES_SUPPLIER', False, cache=False + ) - if InvenTreeSetting.get_setting('PRICING_USE_SUPPLIER_PRICING', True, cache=False): + if InvenTreeSetting.get_setting( + 'PRICING_USE_SUPPLIER_PRICING', True, cache=False + ): # Add supplier pricing data, *unless* historical pricing information should override if self.purchase_cost_min is None or not purchase_history_override: min_costs.append(self.supplier_price_min) @@ -2730,7 +2857,9 @@ class PartPricing(common.models.MetaMixin): if self.purchase_cost_max is None or not purchase_history_override: max_costs.append(self.supplier_price_max) - if InvenTreeSetting.get_setting('PRICING_USE_VARIANT_PRICING', True, cache=False): + if InvenTreeSetting.get_setting( + 'PRICING_USE_VARIANT_PRICING', True, cache=False + ): # Include variant pricing in overall calculations min_costs.append(self.variant_cost_min) max_costs.append(self.variant_cost_max) @@ -2757,7 +2886,9 @@ class PartPricing(common.models.MetaMixin): if overall_max is None or cost > overall_max: overall_max = cost - if InvenTreeSetting.get_setting('PART_BOM_USE_INTERNAL_PRICE', False, cache=False): + if InvenTreeSetting.get_setting( + 'PART_BOM_USE_INTERNAL_PRICE', False, cache=False + ): # Check if internal pricing should override other pricing if self.internal_cost_min is not None: overall_min = self.internal_cost_min @@ -2782,7 +2913,6 @@ class PartPricing(common.models.MetaMixin): max_sell_price = None for pb in self.part.salepricebreaks.all(): - cost = self.convert(pb.price) if cost is None: @@ -2803,15 +2933,13 @@ class PartPricing(common.models.MetaMixin): # Find all line items for shipped sales orders which reference this part line_items = OrderModels.SalesOrderLineItem.objects.filter( - order__status=SalesOrderStatus.SHIPPED, - part=self.part + order__status=SalesOrderStatus.SHIPPED, part=self.part ) # Exclude line items which do not have associated pricing data line_items = line_items.exclude(sale_price=None) for line in line_items: - cost = self.convert(line.sale_price) if cost is None: @@ -2823,7 +2951,10 @@ class PartPricing(common.models.MetaMixin): if max_sell_history is None or cost > max_sell_history: max_sell_history = cost - if self.sale_history_min != min_sell_history or self.sale_history_max != max_sell_history: + if ( + self.sale_history_min != min_sell_history + or self.sale_history_max != max_sell_history + ): self.price_modified = True self.sale_history_min = min_sell_history @@ -2840,9 +2971,7 @@ class PartPricing(common.models.MetaMixin): choices=common.settings.currency_code_mappings(), ) - scheduled_for_update = models.BooleanField( - default=False, - ) + scheduled_for_update = models.BooleanField(default=False) part = models.OneToOneField( Part, @@ -2852,109 +2981,127 @@ class PartPricing(common.models.MetaMixin): ) bom_cost_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum BOM Cost'), - help_text=_('Minimum cost of component parts') + help_text=_('Minimum cost of component parts'), ) bom_cost_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum BOM Cost'), help_text=_('Maximum cost of component parts'), ) purchase_cost_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Purchase Cost'), help_text=_('Minimum historical purchase cost'), ) purchase_cost_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Purchase Cost'), help_text=_('Maximum historical purchase cost'), ) internal_cost_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Internal Price'), help_text=_('Minimum cost based on internal price breaks'), ) internal_cost_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Internal Price'), help_text=_('Maximum cost based on internal price breaks'), ) supplier_price_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Supplier Price'), help_text=_('Minimum price of part from external suppliers'), ) supplier_price_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Supplier Price'), help_text=_('Maximum price of part from external suppliers'), ) variant_cost_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Variant Cost'), help_text=_('Calculated minimum cost of variant parts'), ) variant_cost_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Variant Cost'), help_text=_('Calculated maximum cost of variant parts'), ) override_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Cost'), help_text=_('Override minimum cost'), ) override_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Cost'), help_text=_('Override maximum cost'), ) overall_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Cost'), help_text=_('Calculated overall minimum cost'), ) overall_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Cost'), help_text=_('Calculated overall maximum cost'), ) sale_price_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Sale Price'), help_text=_('Minimum sale price based on price breaks'), ) sale_price_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Sale Price'), help_text=_('Maximum sale price based on price breaks'), ) sale_history_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Sale Cost'), help_text=_('Minimum historical sale price'), ) sale_history_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Sale Cost'), help_text=_('Maximum historical sale price'), ) @@ -2985,7 +3132,8 @@ class PartStocktake(models.Model): ) quantity = models.DecimalField( - max_digits=19, decimal_places=5, + max_digits=19, + decimal_places=5, validators=[MinValueValidator(0)], verbose_name=_('Quantity'), help_text=_('Total available stock at time of stocktake'), @@ -2994,7 +3142,7 @@ class PartStocktake(models.Model): date = models.DateField( verbose_name=_('Date'), help_text=_('Date stocktake was performed'), - auto_now_add=True + auto_now_add=True, ) note = models.CharField( @@ -3005,7 +3153,9 @@ class PartStocktake(models.Model): ) user = models.ForeignKey( - User, blank=True, null=True, + User, + blank=True, + null=True, on_delete=models.SET_NULL, related_name='part_stocktakes', verbose_name=_('User'), @@ -3013,13 +3163,15 @@ class PartStocktake(models.Model): ) cost_min = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Minimum Stock Cost'), help_text=_('Estimated minimum cost of stock on hand'), ) cost_max = InvenTree.fields.InvenTreeModelMoneyField( - null=True, blank=True, + null=True, + blank=True, verbose_name=_('Maximum Stock Cost'), help_text=_('Estimated maximum cost of stock on hand'), ) @@ -3071,14 +3223,12 @@ class PartStocktakeReport(models.Model): return self.report.url return None - date = models.DateField( - verbose_name=_('Date'), - auto_now_add=True - ) + date = models.DateField(verbose_name=_('Date'), auto_now_add=True) report = models.FileField( upload_to=save_stocktake_report, - unique=False, blank=False, + unique=False, + blank=False, verbose_name=_('Report'), help_text=_('Stocktake report file (generated internally)'), ) @@ -3090,7 +3240,9 @@ class PartStocktakeReport(models.Model): ) user = models.ForeignKey( - User, blank=True, null=True, + User, + blank=True, + null=True, on_delete=models.SET_NULL, related_name='stocktake_reports', verbose_name=_('User'), @@ -3108,10 +3260,14 @@ class PartAttachment(InvenTreeAttachment): def getSubdir(self): """Returns the media subdirectory where part attachments are stored""" - return os.path.join("part_files", str(self.part.id)) + return os.path.join('part_files', str(self.part.id)) - part = models.ForeignKey(Part, on_delete=models.CASCADE, - verbose_name=_('Part'), related_name='attachments') + part = models.ForeignKey( + Part, + on_delete=models.CASCADE, + verbose_name=_('Part'), + related_name='attachments', + ) class PartSellPriceBreak(common.models.PriceBreak): @@ -3119,6 +3275,7 @@ class PartSellPriceBreak(common.models.PriceBreak): class Meta: """Metaclass providing extra model definition""" + unique_together = ('part', 'quantity') @staticmethod @@ -3127,10 +3284,11 @@ class PartSellPriceBreak(common.models.PriceBreak): return reverse('api-part-sale-price-list') part = models.ForeignKey( - Part, on_delete=models.CASCADE, + Part, + on_delete=models.CASCADE, related_name='salepricebreaks', limit_choices_to={'salable': True}, - verbose_name=_('Part') + verbose_name=_('Part'), ) @@ -3139,6 +3297,7 @@ class PartInternalPriceBreak(common.models.PriceBreak): class Meta: """Metaclass providing extra model definition""" + unique_together = ('part', 'quantity') @staticmethod @@ -3147,9 +3306,10 @@ class PartInternalPriceBreak(common.models.PriceBreak): return reverse('api-part-internal-price-list') part = models.ForeignKey( - Part, on_delete=models.CASCADE, + Part, + on_delete=models.CASCADE, related_name='internalpricebreaks', - verbose_name=_('Part') + verbose_name=_('Part'), ) @@ -3165,14 +3325,22 @@ class PartStar(models.Model): class Meta: """Metaclass providing extra model definition""" - unique_together = [ - 'part', - 'user' - ] - part = models.ForeignKey(Part, on_delete=models.CASCADE, verbose_name=_('Part'), related_name='starred_users') + unique_together = ['part', 'user'] - user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name=_('User'), related_name='starred_parts') + part = models.ForeignKey( + Part, + on_delete=models.CASCADE, + verbose_name=_('Part'), + related_name='starred_users', + ) + + user = models.ForeignKey( + User, + on_delete=models.CASCADE, + verbose_name=_('User'), + related_name='starred_parts', + ) class PartCategoryStar(models.Model): @@ -3185,14 +3353,22 @@ class PartCategoryStar(models.Model): class Meta: """Metaclass providing extra model definition""" - unique_together = [ - 'category', - 'user', - ] - category = models.ForeignKey(PartCategory, on_delete=models.CASCADE, verbose_name=_('Category'), related_name='starred_users') + unique_together = ['category', 'user'] - user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name=_('User'), related_name='starred_categories') + category = models.ForeignKey( + PartCategory, + on_delete=models.CASCADE, + verbose_name=_('Category'), + related_name='starred_users', + ) + + user = models.ForeignKey( + User, + on_delete=models.CASCADE, + verbose_name=_('User'), + related_name='starred_categories', + ) class PartTestTemplate(MetadataMixin, models.Model): @@ -3247,7 +3423,7 @@ class PartTestTemplate(MetadataMixin, models.Model): for test in tests: if test.key == key: raise ValidationError({ - 'test_name': _("Test with this name already exists for this part") + 'test_name': _('Test with this name already exists for this part') }) super().validate_unique(exclude) @@ -3266,33 +3442,38 @@ class PartTestTemplate(MetadataMixin, models.Model): ) test_name = models.CharField( - blank=False, max_length=100, - verbose_name=_("Test Name"), - help_text=_("Enter a name for the test") + blank=False, + max_length=100, + verbose_name=_('Test Name'), + help_text=_('Enter a name for the test'), ) description = models.CharField( - blank=False, null=True, max_length=100, - verbose_name=_("Test Description"), - help_text=_("Enter description for this test") + blank=False, + null=True, + max_length=100, + verbose_name=_('Test Description'), + help_text=_('Enter description for this test'), ) required = models.BooleanField( default=True, - verbose_name=_("Required"), - help_text=_("Is this test required to pass?") + verbose_name=_('Required'), + help_text=_('Is this test required to pass?'), ) requires_value = models.BooleanField( default=False, - verbose_name=_("Requires Value"), - help_text=_("Does this test require a value when adding a test result?") + verbose_name=_('Requires Value'), + help_text=_('Does this test require a value when adding a test result?'), ) requires_attachment = models.BooleanField( default=False, - verbose_name=_("Requires Attachment"), - help_text=_("Does this test require a file attachment when adding a test result?") + verbose_name=_('Requires Attachment'), + help_text=_( + 'Does this test require a file attachment when adding a test result?' + ), ) @@ -3322,7 +3503,7 @@ class PartParameterTemplate(MetadataMixin, models.Model): """Return a string representation of a PartParameterTemplate instance""" s = str(self.name) if self.units: - s += f" ({self.units})" + s += f' ({self.units})' return s def clean(self): @@ -3362,9 +3543,7 @@ class PartParameterTemplate(MetadataMixin, models.Model): continue if choice in choice_set: - raise ValidationError({ - 'choices': _('Choices must be unique') - }) + raise ValidationError({'choices': _('Choices must be unique')}) choice_set.add(choice) @@ -3376,11 +3555,13 @@ class PartParameterTemplate(MetadataMixin, models.Model): super().validate_unique(exclude) try: - others = PartParameterTemplate.objects.filter(name__iexact=self.name).exclude(pk=self.pk) + others = PartParameterTemplate.objects.filter( + name__iexact=self.name + ).exclude(pk=self.pk) if others.exists(): - msg = _("Parameter template name must be unique") - raise ValidationError({"name": msg}) + msg = _('Parameter template name must be unique') + raise ValidationError({'name': msg}) except PartParameterTemplate.DoesNotExist: pass @@ -3395,16 +3576,15 @@ class PartParameterTemplate(MetadataMixin, models.Model): max_length=100, verbose_name=_('Name'), help_text=_('Parameter Name'), - unique=True + unique=True, ) units = models.CharField( max_length=25, - verbose_name=_('Units'), help_text=_('Physical units for this parameter'), + verbose_name=_('Units'), + help_text=_('Physical units for this parameter'), blank=True, - validators=[ - validators.validate_physical_units, - ] + validators=[validators.validate_physical_units], ) description = models.CharField( @@ -3417,7 +3597,7 @@ class PartParameterTemplate(MetadataMixin, models.Model): checkbox = models.BooleanField( default=False, verbose_name=_('Checkbox'), - help_text=_('Is this parameter a checkbox?') + help_text=_('Is this parameter a checkbox?'), ) choices = models.CharField( @@ -3428,19 +3608,20 @@ class PartParameterTemplate(MetadataMixin, models.Model): ) -@receiver(post_save, sender=PartParameterTemplate, dispatch_uid='post_save_part_parameter_template') +@receiver( + post_save, + sender=PartParameterTemplate, + dispatch_uid='post_save_part_parameter_template', +) def post_save_part_parameter_template(sender, instance, created, **kwargs): """Callback function when a PartParameterTemplate is created or saved""" import part.tasks as part_tasks if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): - if not created: # Schedule a background task to rebuild the parameters against this template InvenTree.tasks.offload_task( - part_tasks.rebuild_parameters, - instance.pk, - force_async=True + part_tasks.rebuild_parameters, instance.pk, force_async=True ) @@ -3455,6 +3636,7 @@ class PartParameter(MetadataMixin, models.Model): class Meta: """Metaclass providing extra model definition""" + # Prevent multiple instances of a parameter for a single part unique_together = ('part', 'template') @@ -3465,7 +3647,7 @@ class PartParameter(MetadataMixin, models.Model): def __str__(self): """String representation of a PartParameter (used in the admin interface)""" - return f"{self.part.full_name} : {self.template.name} = {self.data} ({self.template.units})" + return f'{self.part.full_name} : {self.template.name} = {self.data} ({self.template.units})' def save(self, *args, **kwargs): """Custom save method for the PartParameter model.""" @@ -3484,21 +3666,21 @@ class PartParameter(MetadataMixin, models.Model): super().clean() # Validate the parameter data against the template units - if InvenTreeSetting.get_setting('PART_PARAMETER_ENFORCE_UNITS', True, cache=False, create=False): + if InvenTreeSetting.get_setting( + 'PART_PARAMETER_ENFORCE_UNITS', True, cache=False, create=False + ): if self.template.units: try: - InvenTree.conversion.convert_physical_value(self.data, self.template.units) + InvenTree.conversion.convert_physical_value( + self.data, self.template.units + ) except ValidationError as e: - raise ValidationError({ - 'data': e.message - }) + raise ValidationError({'data': e.message}) # Validate the parameter data against the template choices if choices := self.template.get_choices(): if self.data not in choices: - raise ValidationError({ - 'data': _('Invalid choice for parameter value') - }) + raise ValidationError({'data': _('Invalid choice for parameter value')}) self.calculate_numeric_value() @@ -3506,7 +3688,6 @@ class PartParameter(MetadataMixin, models.Model): from plugin.registry import registry for plugin in registry.with_mixin('validation'): - # Note: The validate_part_parameter function may raise a ValidationError try: result = plugin.validate_part_parameter(self, self.data) @@ -3514,9 +3695,7 @@ class PartParameter(MetadataMixin, models.Model): break except ValidationError as exc: # Re-throw the ValidationError against the 'data' field - raise ValidationError({ - 'data': exc.message - }) + raise ValidationError({'data': exc.message}) def calculate_numeric_value(self): """Calculate a numeric value for the parameter data. @@ -3526,7 +3705,9 @@ class PartParameter(MetadataMixin, models.Model): """ if self.template.units: try: - self.data_numeric = InvenTree.conversion.convert_physical_value(self.data, self.template.units) + self.data_numeric = InvenTree.conversion.convert_physical_value( + self.data, self.template.units + ) except (ValidationError, ValueError): self.data_numeric = None @@ -3538,28 +3719,29 @@ class PartParameter(MetadataMixin, models.Model): self.data_numeric = None part = models.ForeignKey( - Part, on_delete=models.CASCADE, related_name='parameters', - verbose_name=_('Part'), help_text=_('Parent Part') + Part, + on_delete=models.CASCADE, + related_name='parameters', + verbose_name=_('Part'), + help_text=_('Parent Part'), ) template = models.ForeignKey( - PartParameterTemplate, on_delete=models.CASCADE, related_name='instances', - verbose_name=_('Template'), help_text=_('Parameter Template') + PartParameterTemplate, + on_delete=models.CASCADE, + related_name='instances', + verbose_name=_('Template'), + help_text=_('Parameter Template'), ) data = models.CharField( max_length=500, - verbose_name=_('Data'), help_text=_('Parameter Value'), - validators=[ - MinLengthValidator(1), - ] + verbose_name=_('Data'), + help_text=_('Parameter Value'), + validators=[MinLengthValidator(1)], ) - data_numeric = models.FloatField( - default=None, - null=True, - blank=True, - ) + data_numeric = models.FloatField(default=None, null=True, blank=True) @property def units(self): @@ -3599,9 +3781,12 @@ class PartCategoryParameterTemplate(MetadataMixin, models.Model): class Meta: """Metaclass providing extra model definition""" + constraints = [ - UniqueConstraint(fields=['category', 'parameter_template'], - name='unique_category_parameter_template_pair') + UniqueConstraint( + fields=['category', 'parameter_template'], + name='unique_category_parameter_template_pair', + ) ] def __str__(self): @@ -3610,22 +3795,28 @@ class PartCategoryParameterTemplate(MetadataMixin, models.Model): return f'{self.category.name} | {self.parameter_template.name} | {self.default_value}' return f'{self.category.name} | {self.parameter_template.name}' - category = models.ForeignKey(PartCategory, - on_delete=models.CASCADE, - related_name='parameter_templates', - verbose_name=_('Category'), - help_text=_('Part Category')) + category = models.ForeignKey( + PartCategory, + on_delete=models.CASCADE, + related_name='parameter_templates', + verbose_name=_('Category'), + help_text=_('Part Category'), + ) - parameter_template = models.ForeignKey(PartParameterTemplate, - on_delete=models.CASCADE, - related_name='part_categories', - verbose_name=_('Parameter Template'), - help_text=_('Parameter Template')) + parameter_template = models.ForeignKey( + PartParameterTemplate, + on_delete=models.CASCADE, + related_name='part_categories', + verbose_name=_('Parameter Template'), + help_text=_('Parameter Template'), + ) - default_value = models.CharField(max_length=500, - blank=True, - verbose_name=_('Default Value'), - help_text=_('Default Parameter Value')) + default_value = models.CharField( + max_length=500, + blank=True, + verbose_name=_('Default Value'), + help_text=_('Default Parameter Value'), + ) class BomItem(DataImportMixin, MetadataMixin, models.Model): @@ -3650,9 +3841,7 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): # Fields available for bulk import IMPORT_FIELDS = { - 'quantity': { - 'required': True - }, + 'quantity': {'required': True}, 'reference': {}, 'overage': {}, 'allow_variants': {}, @@ -3660,35 +3849,21 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): 'optional': {}, 'consumable': {}, 'note': {}, - 'part': { - 'label': _('Part'), - 'help_text': _('Part ID or part name'), - }, - 'part_id': { - 'label': _('Part ID'), - 'help_text': _('Unique part ID value') - }, - 'part_name': { - 'label': _('Part Name'), - 'help_text': _('Part name'), - }, - 'part_ipn': { - 'label': _('Part IPN'), - 'help_text': _('Part IPN value'), - }, - 'level': { - 'label': _('Level'), - 'help_text': _('BOM level'), - } + 'part': {'label': _('Part'), 'help_text': _('Part ID or part name')}, + 'part_id': {'label': _('Part ID'), 'help_text': _('Unique part ID value')}, + 'part_name': {'label': _('Part Name'), 'help_text': _('Part name')}, + 'part_ipn': {'label': _('Part IPN'), 'help_text': _('Part IPN value')}, + 'level': {'label': _('Level'), 'help_text': _('BOM level')}, } class Meta: """Metaclass providing extra model definition""" - verbose_name = _("BOM Item") + + verbose_name = _('BOM Item') def __str__(self): """Return a string representation of this BomItem instance""" - return f"{decimal2string(self.quantity)} x {self.sub_part.full_name} to make {self.part.full_name}" + return f'{decimal2string(self.quantity)} x {self.sub_part.full_name} to make {self.part.full_name}' @staticmethod def get_api_url(): @@ -3700,13 +3875,13 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): assemblies = [self.part] if self.inherited: - assemblies += list( - self.part.get_descendants(include_self=False) - ) + assemblies += list(self.part.get_descendants(include_self=False)) return assemblies - def get_valid_parts_for_allocation(self, allow_variants=True, allow_substitutes=True): + def get_valid_parts_for_allocation( + self, allow_variants=True, allow_substitutes=True + ): """Return a list of valid parts which can be allocated against this BomItem. Includes: @@ -3763,56 +3938,82 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): # A link to the parent part # Each part will get a reverse lookup field 'bom_items' - part = models.ForeignKey(Part, on_delete=models.CASCADE, related_name='bom_items', - verbose_name=_('Part'), - help_text=_('Select parent part'), - limit_choices_to={ - 'assembly': True, - }) + part = models.ForeignKey( + Part, + on_delete=models.CASCADE, + related_name='bom_items', + verbose_name=_('Part'), + help_text=_('Select parent part'), + limit_choices_to={'assembly': True}, + ) # A link to the child item (sub-part) # Each part will get a reverse lookup field 'used_in' - sub_part = models.ForeignKey(Part, on_delete=models.CASCADE, related_name='used_in', - verbose_name=_('Sub part'), - help_text=_('Select part to be used in BOM'), - limit_choices_to={ - 'component': True, - }) + sub_part = models.ForeignKey( + Part, + on_delete=models.CASCADE, + related_name='used_in', + verbose_name=_('Sub part'), + help_text=_('Select part to be used in BOM'), + limit_choices_to={'component': True}, + ) # Quantity required - quantity = models.DecimalField(default=1.0, max_digits=15, decimal_places=5, validators=[MinValueValidator(0)], verbose_name=_('Quantity'), help_text=_('BOM quantity for this BOM item')) + quantity = models.DecimalField( + default=1.0, + max_digits=15, + decimal_places=5, + validators=[MinValueValidator(0)], + verbose_name=_('Quantity'), + help_text=_('BOM quantity for this BOM item'), + ) optional = models.BooleanField( default=False, verbose_name=_('Optional'), - help_text=_("This BOM item is optional") + help_text=_('This BOM item is optional'), ) consumable = models.BooleanField( default=False, verbose_name=_('Consumable'), - help_text=_("This BOM item is consumable (it is not tracked in build orders)") + help_text=_('This BOM item is consumable (it is not tracked in build orders)'), ) - overage = models.CharField(max_length=24, blank=True, validators=[validators.validate_overage], - verbose_name=_('Overage'), - help_text=_('Estimated build wastage quantity (absolute or percentage)') - ) + overage = models.CharField( + max_length=24, + blank=True, + validators=[validators.validate_overage], + verbose_name=_('Overage'), + help_text=_('Estimated build wastage quantity (absolute or percentage)'), + ) - reference = models.CharField(max_length=5000, blank=True, verbose_name=_('Reference'), help_text=_('BOM item reference')) + reference = models.CharField( + max_length=5000, + blank=True, + verbose_name=_('Reference'), + help_text=_('BOM item reference'), + ) # Note attached to this BOM line item - note = models.CharField(max_length=500, blank=True, verbose_name=_('Note'), help_text=_('BOM item notes')) + note = models.CharField( + max_length=500, + blank=True, + verbose_name=_('Note'), + help_text=_('BOM item notes'), + ) checksum = models.CharField( - max_length=128, blank=True, - verbose_name=_('Checksum'), help_text=_('BOM line checksum') + max_length=128, + blank=True, + verbose_name=_('Checksum'), + help_text=_('BOM line checksum'), ) validated = models.BooleanField( default=False, verbose_name=_('Validated'), - help_text=_('This BOM item has been validated') + help_text=_('This BOM item has been validated'), ) inherited = models.BooleanField( @@ -3824,7 +4025,7 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): allow_variants = models.BooleanField( default=False, verbose_name=_('Allow Variants'), - help_text=_('Stock items for variant parts can be used for this BOM item') + help_text=_('Stock items for variant parts can be used for this BOM item'), ) def get_item_hash(self): @@ -3852,7 +4053,7 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): self.optional, self.inherited, self.consumable, - self.allow_variants + self.allow_variants, ] for component in components: @@ -3897,9 +4098,7 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): try: self.quantity = Decimal(self.quantity) except InvalidOperation: - raise ValidationError({ - 'quantity': _('Must be a valid number') - }) + raise ValidationError({'quantity': _('Must be a valid number')}) try: # Check for circular BOM references @@ -3910,7 +4109,9 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): if self.sub_part.trackable: if self.quantity != int(self.quantity): raise ValidationError({ - "quantity": _("Quantity must be integer value for trackable parts") + 'quantity': _( + 'Quantity must be integer value for trackable parts' + ) }) # Force the upstream part to be trackable if the sub_part is trackable @@ -3987,8 +4188,12 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): def price_range(self, internal=False): """Return the price-range for this BOM item.""" # get internal price setting - use_internal = common.models.InvenTreeSetting.get_setting('PART_BOM_USE_INTERNAL_PRICE', False, cache=False) - prange = self.sub_part.get_price_range(self.quantity, internal=use_internal and internal) + use_internal = common.models.InvenTreeSetting.get_setting( + 'PART_BOM_USE_INTERNAL_PRICE', False, cache=False + ) + prange = self.sub_part.get_price_range( + self.quantity, internal=use_internal and internal + ) if prange is None: return prange @@ -4002,7 +4207,7 @@ class BomItem(DataImportMixin, MetadataMixin, models.Model): pmin = decimal2money(pmin) pmax = decimal2money(pmax) - return f"{pmin} to {pmax}" + return f'{pmin} to {pmax}' @receiver(post_save, sender=BomItem, dispatch_uid='update_bom_build_lines') @@ -4010,15 +4215,19 @@ def update_bom_build_lines(sender, instance, created, **kwargs): """Update existing build orders when a BomItem is created or edited""" if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): import build.tasks - InvenTree.tasks.offload_task( - build.tasks.update_build_order_lines, - instance.pk - ) + + InvenTree.tasks.offload_task(build.tasks.update_build_order_lines, instance.pk) @receiver(post_save, sender=BomItem, dispatch_uid='post_save_bom_item') -@receiver(post_save, sender=PartSellPriceBreak, dispatch_uid='post_save_sale_price_break') -@receiver(post_save, sender=PartInternalPriceBreak, dispatch_uid='post_save_internal_price_break') +@receiver( + post_save, sender=PartSellPriceBreak, dispatch_uid='post_save_sale_price_break' +) +@receiver( + post_save, + sender=PartInternalPriceBreak, + dispatch_uid='post_save_internal_price_break', +) def update_pricing_after_edit(sender, instance, created, **kwargs): """Callback function when a part price break is created or updated""" # Update part pricing *unless* we are importing data @@ -4027,8 +4236,14 @@ def update_pricing_after_edit(sender, instance, created, **kwargs): @receiver(post_delete, sender=BomItem, dispatch_uid='post_delete_bom_item') -@receiver(post_delete, sender=PartSellPriceBreak, dispatch_uid='post_delete_sale_price_break') -@receiver(post_delete, sender=PartInternalPriceBreak, dispatch_uid='post_delete_internal_price_break') +@receiver( + post_delete, sender=PartSellPriceBreak, dispatch_uid='post_delete_sale_price_break' +) +@receiver( + post_delete, + sender=PartInternalPriceBreak, + dispatch_uid='post_delete_internal_price_break', +) def update_pricing_after_delete(sender, instance, **kwargs): """Callback function when a part price break is deleted""" # Update part pricing *unless* we are importing data @@ -4046,7 +4261,8 @@ class BomItemSubstitute(MetadataMixin, models.Model): class Meta: """Metaclass providing extra model definition""" - verbose_name = _("BOM Item Substitute") + + verbose_name = _('BOM Item Substitute') # Prevent duplication of substitute parts unique_together = ('part', 'bom_item') @@ -4067,7 +4283,7 @@ class BomItemSubstitute(MetadataMixin, models.Model): if self.part == self.bom_item.sub_part: raise ValidationError({ - "part": _("Substitute part cannot be the same as the master part"), + 'part': _('Substitute part cannot be the same as the master part') }) @staticmethod @@ -4089,9 +4305,7 @@ class BomItemSubstitute(MetadataMixin, models.Model): related_name='substitute_items', verbose_name=_('Part'), help_text=_('Substitute part'), - limit_choices_to={ - 'component': True, - } + limit_choices_to={'component': True}, ) @@ -4100,14 +4314,23 @@ class PartRelated(MetadataMixin, models.Model): class Meta: """Metaclass defines extra model properties""" + unique_together = ('part_1', 'part_2') - part_1 = models.ForeignKey(Part, related_name='related_parts_1', - verbose_name=_('Part 1'), on_delete=models.CASCADE) + part_1 = models.ForeignKey( + Part, + related_name='related_parts_1', + verbose_name=_('Part 1'), + on_delete=models.CASCADE, + ) - part_2 = models.ForeignKey(Part, related_name='related_parts_2', - on_delete=models.CASCADE, - verbose_name=_('Part 2'), help_text=_('Select Related Part')) + part_2 = models.ForeignKey( + Part, + related_name='related_parts_2', + on_delete=models.CASCADE, + verbose_name=_('Part 2'), + help_text=_('Select Related Part'), + ) def __str__(self): """Return a string representation of this Part-Part relationship""" @@ -4124,8 +4347,10 @@ class PartRelated(MetadataMixin, models.Model): super().clean() if self.part_1 == self.part_2: - raise ValidationError(_("Part relationship cannot be created between a part and itself")) + raise ValidationError( + _('Part relationship cannot be created between a part and itself') + ) # Check for inverse relationship if PartRelated.objects.filter(part_1=self.part_2, part_2=self.part_1).exists(): - raise ValidationError(_("Duplicate relationship already exists")) + raise ValidationError(_('Duplicate relationship already exists')) diff --git a/InvenTree/part/part.py b/InvenTree/part/part.py index fc85218092..bcee84167c 100644 --- a/InvenTree/part/part.py +++ b/InvenTree/part/part.py @@ -25,10 +25,7 @@ def MakePartTemplate(fmt): # This will then export just the row headers! query = Part.objects.filter(pk=None) - dataset = PartImportResource().export( - queryset=query, - importing=True - ) + dataset = PartImportResource().export(queryset=query, importing=True) data = dataset.export(fmt) diff --git a/InvenTree/part/serializers.py b/InvenTree/part/serializers.py index db347f347a..4e978e086f 100644 --- a/InvenTree/part/serializers.py +++ b/InvenTree/part/serializers.py @@ -36,14 +36,26 @@ import users.models from InvenTree.status_codes import BuildStatusGroups from InvenTree.tasks import offload_task -from .models import (BomItem, BomItemSubstitute, Part, PartAttachment, - PartCategory, PartCategoryParameterTemplate, - PartInternalPriceBreak, PartParameter, - PartParameterTemplate, PartPricing, PartRelated, - PartSellPriceBreak, PartStar, PartStocktake, - PartStocktakeReport, PartTestTemplate) +from .models import ( + BomItem, + BomItemSubstitute, + Part, + PartAttachment, + PartCategory, + PartCategoryParameterTemplate, + PartInternalPriceBreak, + PartParameter, + PartParameterTemplate, + PartPricing, + PartRelated, + PartSellPriceBreak, + PartStar, + PartStocktake, + PartStocktakeReport, + PartTestTemplate, +) -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') class CategorySerializer(InvenTree.serializers.InvenTreeModelSerializer): @@ -51,6 +63,7 @@ class CategorySerializer(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = PartCategory fields = [ 'pk', @@ -86,9 +99,7 @@ class CategorySerializer(InvenTree.serializers.InvenTreeModelSerializer): def annotate_queryset(queryset): """Annotate extra information to the queryset""" # Annotate the number of 'parts' which exist in each category (including subcategories!) - queryset = queryset.annotate( - part_count=part.filters.annotate_category_parts() - ) + queryset = queryset.annotate(part_count=part.filters.annotate_category_parts()) return queryset @@ -101,9 +112,7 @@ class CategorySerializer(InvenTree.serializers.InvenTreeModelSerializer): starred = serializers.SerializerMethodField() path = serializers.ListField( - child=serializers.DictField(), - source='get_path', - read_only=True, + child=serializers.DictField(), source='get_path', read_only=True ) @@ -112,14 +121,9 @@ class CategoryTree(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = PartCategory - fields = [ - 'pk', - 'name', - 'parent', - 'icon', - 'structural', - ] + fields = ['pk', 'name', 'parent', 'icon', 'structural'] class PartAttachmentSerializer(InvenTree.serializers.InvenTreeAttachmentSerializer): @@ -127,10 +131,11 @@ class PartAttachmentSerializer(InvenTree.serializers.InvenTreeAttachmentSerializ class Meta: """Metaclass defining serializer fields""" + model = PartAttachment fields = InvenTree.serializers.InvenTreeAttachmentSerializer.attachment_fields([ - 'part', + 'part' ]) @@ -139,6 +144,7 @@ class PartTestTemplateSerializer(InvenTree.serializers.InvenTreeModelSerializer) class Meta: """Metaclass defining serializer fields""" + model = PartTestTemplate fields = [ @@ -160,20 +166,17 @@ class PartSalePriceSerializer(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = PartSellPriceBreak - fields = [ - 'pk', - 'part', - 'quantity', - 'price', - 'price_currency', - ] + fields = ['pk', 'part', 'quantity', 'price', 'price_currency'] quantity = InvenTree.serializers.InvenTreeDecimalField() price = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True) - price_currency = InvenTree.serializers.InvenTreeCurrencySerializer(help_text=_('Purchase currency of this stock item')) + price_currency = InvenTree.serializers.InvenTreeCurrencySerializer( + help_text=_('Purchase currency of this stock item') + ) class PartInternalPriceSerializer(InvenTree.serializers.InvenTreeModelSerializer): @@ -181,22 +184,17 @@ class PartInternalPriceSerializer(InvenTree.serializers.InvenTreeModelSerializer class Meta: """Metaclass defining serializer fields""" + model = PartInternalPriceBreak - fields = [ - 'pk', - 'part', - 'quantity', - 'price', - 'price_currency', - ] + fields = ['pk', 'part', 'quantity', 'price', 'price_currency'] quantity = InvenTree.serializers.InvenTreeDecimalField() - price = InvenTree.serializers.InvenTreeMoneySerializer( - allow_null=True - ) + price = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True) - price_currency = InvenTree.serializers.InvenTreeCurrencySerializer(help_text=_('Purchase currency of this stock item')) + price_currency = InvenTree.serializers.InvenTreeCurrencySerializer( + help_text=_('Purchase currency of this stock item') + ) class PartThumbSerializer(serializers.Serializer): @@ -214,16 +212,15 @@ class PartThumbSerializerUpdate(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = Part - fields = [ - 'image', - ] + fields = ['image'] def validate_image(self, value): """Check that file is an image.""" validate = imghdr.what(value) if not validate: - raise serializers.ValidationError("File is not an image") + raise serializers.ValidationError('File is not an image') return value image = InvenTree.serializers.InvenTreeAttachmentSerializerField(required=True) @@ -234,15 +231,9 @@ class PartParameterTemplateSerializer(InvenTree.serializers.InvenTreeModelSerial class Meta: """Metaclass defining serializer fields""" + model = PartParameterTemplate - fields = [ - 'pk', - 'name', - 'units', - 'description', - 'checkbox', - 'choices', - ] + fields = ['pk', 'name', 'units', 'description', 'checkbox', 'choices'] class PartBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): @@ -250,6 +241,7 @@ class PartBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = Part fields = [ 'pk', @@ -273,9 +265,7 @@ class PartBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): 'pricing_max', ] - read_only_fields = [ - 'barcode_hash', - ] + read_only_fields = ['barcode_hash'] def __init__(self, *args, **kwargs): """Custom initialization routine for the PartBrief serializer""" @@ -290,8 +280,12 @@ class PartBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): thumbnail = serializers.CharField(source='get_thumbnail_url', read_only=True) # Pricing fields - pricing_min = InvenTree.serializers.InvenTreeMoneySerializer(source='pricing_data.overall_min', allow_null=True, read_only=True) - pricing_max = InvenTree.serializers.InvenTreeMoneySerializer(source='pricing_data.overall_max', allow_null=True, read_only=True) + pricing_min = InvenTree.serializers.InvenTreeMoneySerializer( + source='pricing_data.overall_min', allow_null=True, read_only=True + ) + pricing_max = InvenTree.serializers.InvenTreeMoneySerializer( + source='pricing_data.overall_max', allow_null=True, read_only=True + ) class PartParameterSerializer(InvenTree.serializers.InvenTreeModelSerializer): @@ -299,6 +293,7 @@ class PartParameterSerializer(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = PartParameter fields = [ 'pk', @@ -327,7 +322,9 @@ class PartParameterSerializer(InvenTree.serializers.InvenTreeModelSerializer): self.fields.pop('template_detail') part_detail = PartBriefSerializer(source='part', many=False, read_only=True) - template_detail = PartParameterTemplateSerializer(source='template', many=False, read_only=True) + template_detail = PartParameterTemplateSerializer( + source='template', many=False, read_only=True + ) class PartSetCategorySerializer(serializers.Serializer): @@ -335,29 +332,31 @@ class PartSetCategorySerializer(serializers.Serializer): class Meta: """Metaclass options""" - fields = [ - 'parts', - 'category', - ] + + fields = ['parts', 'category'] parts = serializers.PrimaryKeyRelatedField( queryset=Part.objects.all(), - many=True, required=True, allow_null=False, + many=True, + required=True, + allow_null=False, label=_('Parts'), ) def validate_parts(self, parts): """Validate the selected parts""" if len(parts) == 0: - raise serializers.ValidationError(_("No parts selected")) + raise serializers.ValidationError(_('No parts selected')) return parts category = serializers.PrimaryKeyRelatedField( queryset=PartCategory.objects.filter(structural=False), - many=False, required=True, allow_null=False, + many=False, + required=True, + allow_null=False, label=_('Category'), - help_text=_('Select category',) + help_text=_('Select category'), ) @transaction.atomic @@ -387,28 +386,37 @@ class DuplicatePartSerializer(serializers.Serializer): part = serializers.PrimaryKeyRelatedField( queryset=Part.objects.all(), - label=_('Original Part'), help_text=_('Select original part to duplicate'), + label=_('Original Part'), + help_text=_('Select original part to duplicate'), required=True, ) copy_image = serializers.BooleanField( - label=_('Copy Image'), help_text=_('Copy image from original part'), - required=False, default=False, + label=_('Copy Image'), + help_text=_('Copy image from original part'), + required=False, + default=False, ) copy_bom = serializers.BooleanField( - label=_('Copy BOM'), help_text=_('Copy bill of materials from original part'), - required=False, default=False, + label=_('Copy BOM'), + help_text=_('Copy bill of materials from original part'), + required=False, + default=False, ) copy_parameters = serializers.BooleanField( - label=_('Copy Parameters'), help_text=_('Copy parameter data from original part'), - required=False, default=False, + label=_('Copy Parameters'), + help_text=_('Copy parameter data from original part'), + required=False, + default=False, ) copy_notes = serializers.BooleanField( - label=_('Copy Notes'), help_text=_('Copy notes from original part'), - required=False, default=True, + label=_('Copy Notes'), + help_text=_('Copy notes from original part'), + required=False, + default=True, ) @@ -416,15 +424,22 @@ class InitialStockSerializer(serializers.Serializer): """Serializer for creating initial stock quantity.""" quantity = serializers.DecimalField( - max_digits=15, decimal_places=5, validators=[MinValueValidator(0)], - label=_('Initial Stock Quantity'), help_text=_('Specify initial stock quantity for this Part. If quantity is zero, no stock is added.'), + max_digits=15, + decimal_places=5, + validators=[MinValueValidator(0)], + label=_('Initial Stock Quantity'), + help_text=_( + 'Specify initial stock quantity for this Part. If quantity is zero, no stock is added.' + ), required=True, ) location = serializers.PrimaryKeyRelatedField( queryset=stock.models.StockLocation.objects.all(), - label=_('Initial Stock Location'), help_text=_('Specify initial stock location for this Part'), - allow_null=True, required=False, + label=_('Initial Stock Location'), + help_text=_('Specify initial stock location for this Part'), + allow_null=True, + required=False, ) @@ -433,53 +448,65 @@ class InitialSupplierSerializer(serializers.Serializer): supplier = serializers.PrimaryKeyRelatedField( queryset=company.models.Company.objects.all(), - label=_('Supplier'), help_text=_('Select supplier (or leave blank to skip)'), - allow_null=True, required=False, + label=_('Supplier'), + help_text=_('Select supplier (or leave blank to skip)'), + allow_null=True, + required=False, ) sku = serializers.CharField( - max_length=100, required=False, allow_blank=True, - label=_('SKU'), help_text=_('Supplier stock keeping unit'), + max_length=100, + required=False, + allow_blank=True, + label=_('SKU'), + help_text=_('Supplier stock keeping unit'), ) manufacturer = serializers.PrimaryKeyRelatedField( queryset=company.models.Company.objects.all(), - label=_('Manufacturer'), help_text=_('Select manufacturer (or leave blank to skip)'), - allow_null=True, required=False, + label=_('Manufacturer'), + help_text=_('Select manufacturer (or leave blank to skip)'), + allow_null=True, + required=False, ) mpn = serializers.CharField( - max_length=100, required=False, allow_blank=True, - label=_('MPN'), help_text=_('Manufacturer part number'), + max_length=100, + required=False, + allow_blank=True, + label=_('MPN'), + help_text=_('Manufacturer part number'), ) def validate_supplier(self, company): """Validation for the provided Supplier""" if company and not company.is_supplier: - raise serializers.ValidationError(_('Selected company is not a valid supplier')) + raise serializers.ValidationError( + _('Selected company is not a valid supplier') + ) return company def validate_manufacturer(self, company): """Validation for the provided Manufacturer""" if company and not company.is_manufacturer: - raise serializers.ValidationError(_('Selected company is not a valid manufacturer')) + raise serializers.ValidationError( + _('Selected company is not a valid manufacturer') + ) return company def validate(self, data): """Extra validation for this serializer""" if company.models.ManufacturerPart.objects.filter( - manufacturer=data.get('manufacturer', None), - MPN=data.get('mpn', '') + manufacturer=data.get('manufacturer', None), MPN=data.get('mpn', '') ).exists(): raise serializers.ValidationError({ 'mpn': _('Manufacturer part matching this MPN already exists') }) if company.models.SupplierPart.objects.filter( - supplier=data.get('supplier', None), - SKU=data.get('sku', '') + supplier=data.get('supplier', None), SKU=data.get('sku', '') ).exists(): raise serializers.ValidationError({ 'sku': _('Supplier part matching this SKU already exists') @@ -488,7 +515,10 @@ class InitialSupplierSerializer(serializers.Serializer): return data -class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serializers.InvenTreeTagModelSerializer): +class PartSerializer( + InvenTree.serializers.RemoteImageMixin, + InvenTree.serializers.InvenTreeTagModelSerializer, +): """Serializer for complete detail information of a part. Used when displaying all details of a single component. @@ -496,6 +526,7 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize class Meta: """Metaclass defining serializer fields""" + model = Part partial = True fields = [ @@ -537,7 +568,6 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize 'pricing_min', 'pricing_max', 'responsible', - # Annotated fields 'allocated_to_build_orders', 'allocated_to_sales_orders', @@ -550,20 +580,15 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize 'total_in_stock', 'unallocated_stock', 'variant_stock', - # Fields only used for Part creation 'duplicate', 'initial_stock', 'initial_supplier', 'copy_category_parameters', - 'tags', ] - read_only_fields = [ - 'barcode_hash', - 'creation_date', - ] + read_only_fields = ['barcode_hash', 'creation_date'] tags = TagListSerializerField(required=False) @@ -624,21 +649,19 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize Performing database queries as efficiently as possible, to reduce database trips. """ # Annotate with the total number of stock items - queryset = queryset.annotate( - stock_item_count=SubqueryCount('stock_items') - ) + queryset = queryset.annotate(stock_item_count=SubqueryCount('stock_items')) # Annotate with the total variant stock quantity variant_query = part.filters.variant_stock_query() queryset = queryset.annotate( - variant_stock=part.filters.annotate_variant_quantity(variant_query, reference='quantity'), + variant_stock=part.filters.annotate_variant_quantity( + variant_query, reference='quantity' + ) ) # Filter to limit builds to "active" - build_filter = Q( - status__in=BuildStatusGroups.ACTIVE_CODES - ) + build_filter = Q(status__in=BuildStatusGroups.ACTIVE_CODES) # Annotate with the total 'building' quantity queryset = queryset.annotate( @@ -655,7 +678,7 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize SubqueryCount('supplier_parts'), Decimal(0), output_field=models.DecimalField(), - ), + ) ) # TODO: This could do with some refactoring @@ -672,8 +695,7 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize # This is the 'in_stock' quantity summed with the 'variant_stock' quantity queryset = queryset.annotate( total_in_stock=ExpressionWrapper( - F('in_stock') + F('variant_stock'), - output_field=models.DecimalField(), + F('in_stock') + F('variant_stock'), output_field=models.DecimalField() ) ) @@ -681,14 +703,16 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize # This is the current stock, minus any allocations queryset = queryset.annotate( unallocated_stock=ExpressionWrapper( - F('total_in_stock') - F('allocated_to_sales_orders') - F('allocated_to_build_orders'), + F('total_in_stock') + - F('allocated_to_sales_orders') + - F('allocated_to_build_orders'), output_field=models.DecimalField(), ) ) # Annotate with the total 'required for builds' quantity queryset = queryset.annotate( - required_for_build_orders=part.filters.annotate_build_order_requirements(), + required_for_build_orders=part.filters.annotate_build_order_requirements() ) return queryset @@ -701,14 +725,13 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize category_detail = CategorySerializer(source='category', many=False, read_only=True) category_path = serializers.ListField( - child=serializers.DictField(), - source='category.get_path', - read_only=True, + child=serializers.DictField(), source='category.get_path', read_only=True ) responsible = serializers.PrimaryKeyRelatedField( queryset=users.models.Owner.objects.all(), - required=False, allow_null=True, + required=False, + allow_null=True, source='responsible_owner', ) @@ -727,7 +750,9 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize minimum_stock = serializers.FloatField() - image = InvenTree.serializers.InvenTreeImageSerializerField(required=False, allow_null=True) + image = InvenTree.serializers.InvenTreeImageSerializerField( + required=False, allow_null=True + ) thumbnail = serializers.CharField(source='get_thumbnail_url', read_only=True) starred = serializers.SerializerMethodField() @@ -735,32 +760,40 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize category = serializers.PrimaryKeyRelatedField(queryset=PartCategory.objects.all()) # Pricing fields - pricing_min = InvenTree.serializers.InvenTreeMoneySerializer(source='pricing_data.overall_min', allow_null=True, read_only=True) - pricing_max = InvenTree.serializers.InvenTreeMoneySerializer(source='pricing_data.overall_max', allow_null=True, read_only=True) - - parameters = PartParameterSerializer( - many=True, - read_only=True, + pricing_min = InvenTree.serializers.InvenTreeMoneySerializer( + source='pricing_data.overall_min', allow_null=True, read_only=True ) + pricing_max = InvenTree.serializers.InvenTreeMoneySerializer( + source='pricing_data.overall_max', allow_null=True, read_only=True + ) + + parameters = PartParameterSerializer(many=True, read_only=True) # Extra fields used only for creation of a new Part instance duplicate = DuplicatePartSerializer( - label=_('Duplicate Part'), help_text=_('Copy initial data from another Part'), - write_only=True, required=False + label=_('Duplicate Part'), + help_text=_('Copy initial data from another Part'), + write_only=True, + required=False, ) initial_stock = InitialStockSerializer( - label=_('Initial Stock'), help_text=_('Create Part with initial stock quantity'), - write_only=True, required=False, + label=_('Initial Stock'), + help_text=_('Create Part with initial stock quantity'), + write_only=True, + required=False, ) initial_supplier = InitialSupplierSerializer( - label=_('Supplier Information'), help_text=_('Add initial supplier information for this part'), - write_only=True, required=False, + label=_('Supplier Information'), + help_text=_('Add initial supplier information for this part'), + write_only=True, + required=False, ) copy_category_parameters = serializers.BooleanField( - default=True, required=False, + default=True, + required=False, label=_('Copy Category Parameters'), help_text=_('Copy parameter templates from selected part category'), ) @@ -782,10 +815,7 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize img = img.split(os.path.sep)[-1] # Ensure that the file actually exists - img_path = os.path.join( - part_helpers.get_part_image_directory(), - img - ) + img_path = os.path.join(part_helpers.get_part_image_directory(), img) if not os.path.exists(img_path) or not os.path.isfile(img_path): raise ValidationError(_('Image file does not exist')) @@ -838,8 +868,7 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize for template in templates: # First ensure that the part doesn't have that parameter if PartParameter.objects.filter( - part=instance, - template=template.parameter_template + part=instance, template=template.parameter_template ).exists(): continue @@ -848,10 +877,12 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize part=instance, template=template.parameter_template, data=template.default_value, - save=True + save=True, ) except IntegrityError: - logger.exception("Could not create new PartParameter for part %s", instance) + logger.exception( + 'Could not create new PartParameter for part %s', instance + ) # Create initial stock entry if initial_stock: @@ -860,24 +891,19 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize if quantity > 0: stockitem = stock.models.StockItem( - part=instance, - quantity=quantity, - location=location, + part=instance, quantity=quantity, location=location ) stockitem.save(user=self.context['request'].user) # Create initial supplier information if initial_supplier: - manufacturer = initial_supplier.get('manufacturer', None) mpn = initial_supplier.get('mpn', '') if manufacturer and mpn: manu_part = company.models.ManufacturerPart.objects.create( - part=instance, - manufacturer=manufacturer, - MPN=mpn + part=instance, manufacturer=manufacturer, MPN=mpn ) else: manu_part = None @@ -905,10 +931,7 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize existing_image = data.pop('existing_image', None) if existing_image: - img_path = os.path.join( - part_helpers.PART_IMAGE_DIR, - existing_image - ) + img_path = os.path.join(part_helpers.PART_IMAGE_DIR, existing_image) part.image = img_path part.save() @@ -922,12 +945,9 @@ class PartSerializer(InvenTree.serializers.RemoteImageMixin, InvenTree.serialize remote_img.save(buffer, format=fmt) # Construct a simplified name for the image - filename = f"part_{part.pk}_image.{fmt.lower()}" + filename = f'part_{part.pk}_image.{fmt.lower()}' - part.image.save( - filename, - ContentFile(buffer.getvalue()), - ) + part.image.save(filename, ContentFile(buffer.getvalue())) return self.instance @@ -954,14 +974,13 @@ class PartStocktakeSerializer(InvenTree.serializers.InvenTreeModelSerializer): 'user_detail', ] - read_only_fields = [ - 'date', - 'user', - ] + read_only_fields = ['date', 'user'] quantity = serializers.FloatField() - user_detail = InvenTree.serializers.UserSerializer(source='user', read_only=True, many=False) + user_detail = InvenTree.serializers.UserSerializer( + source='user', read_only=True, many=False + ) cost_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True) cost_min_currency = InvenTree.serializers.InvenTreeCurrencySerializer() @@ -987,16 +1006,11 @@ class PartStocktakeReportSerializer(InvenTree.serializers.InvenTreeModelSerializ """Metaclass defines serializer fields""" model = PartStocktakeReport - fields = [ - 'pk', - 'date', - 'report', - 'part_count', - 'user', - 'user_detail', - ] + fields = ['pk', 'date', 'report', 'part_count', 'user', 'user_detail'] - user_detail = InvenTree.serializers.UserSerializer(source='user', read_only=True, many=False) + user_detail = InvenTree.serializers.UserSerializer( + source='user', read_only=True, many=False + ) report = InvenTree.serializers.InvenTreeAttachmentSerializerField(read_only=True) @@ -1006,26 +1020,38 @@ class PartStocktakeReportGenerateSerializer(serializers.Serializer): part = serializers.PrimaryKeyRelatedField( queryset=Part.objects.all(), - required=False, allow_null=True, - label=_('Part'), help_text=_('Limit stocktake report to a particular part, and any variant parts') + required=False, + allow_null=True, + label=_('Part'), + help_text=_( + 'Limit stocktake report to a particular part, and any variant parts' + ), ) category = serializers.PrimaryKeyRelatedField( queryset=PartCategory.objects.all(), - required=False, allow_null=True, - label=_('Category'), help_text=_('Limit stocktake report to a particular part category, and any child categories'), + required=False, + allow_null=True, + label=_('Category'), + help_text=_( + 'Limit stocktake report to a particular part category, and any child categories' + ), ) location = serializers.PrimaryKeyRelatedField( queryset=stock.models.StockLocation.objects.all(), - required=False, allow_null=True, - label=_('Location'), help_text=_('Limit stocktake report to a particular stock location, and any child locations') + required=False, + allow_null=True, + label=_('Location'), + help_text=_( + 'Limit stocktake report to a particular stock location, and any child locations' + ), ) exclude_external = serializers.BooleanField( default=True, label=_('Exclude External Stock'), - help_text=_('Exclude stock items in external locations') + help_text=_('Exclude stock items in external locations'), ) generate_report = serializers.BooleanField( @@ -1037,18 +1063,20 @@ class PartStocktakeReportGenerateSerializer(serializers.Serializer): update_parts = serializers.BooleanField( default=True, label=_('Update Parts'), - help_text=_('Update specified parts with calculated stocktake data') + help_text=_('Update specified parts with calculated stocktake data'), ) def validate(self, data): """Custom validation for this serializer""" # Stocktake functionality must be enabled if not common.models.InvenTreeSetting.get_setting('STOCKTAKE_ENABLE', False): - raise serializers.ValidationError(_("Stocktake functionality is not enabled")) + raise serializers.ValidationError( + _('Stocktake functionality is not enabled') + ) # Check that background worker is running if not InvenTree.status.is_worker_running(): - raise serializers.ValidationError(_("Background worker check failed")) + raise serializers.ValidationError(_('Background worker check failed')) return data @@ -1076,6 +1104,7 @@ class PartPricingSerializer(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = PartPricing fields = [ 'currency', @@ -1111,59 +1140,99 @@ class PartPricingSerializer(InvenTree.serializers.InvenTreeModelSerializer): scheduled_for_update = serializers.BooleanField(read_only=True) # Custom serializers - bom_cost_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) - bom_cost_max = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) + bom_cost_min = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) + bom_cost_max = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) - purchase_cost_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) - purchase_cost_max = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) + purchase_cost_min = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) + purchase_cost_max = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) - internal_cost_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) - internal_cost_max = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) + internal_cost_min = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) + internal_cost_max = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) - supplier_price_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) - supplier_price_max = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) + supplier_price_min = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) + supplier_price_max = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) - variant_cost_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) - variant_cost_max = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) + variant_cost_min = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) + variant_cost_max = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) override_min = InvenTree.serializers.InvenTreeMoneySerializer( label=_('Minimum Price'), help_text=_('Override calculated value for minimum price'), - allow_null=True, read_only=False, required=False, + allow_null=True, + read_only=False, + required=False, ) override_min_currency = serializers.ChoiceField( label=_('Minimum price currency'), - read_only=False, required=False, + read_only=False, + required=False, choices=common.settings.currency_code_mappings(), ) override_max = InvenTree.serializers.InvenTreeMoneySerializer( label=_('Maximum Price'), help_text=_('Override calculated value for maximum price'), - allow_null=True, read_only=False, required=False, + allow_null=True, + read_only=False, + required=False, ) override_max_currency = serializers.ChoiceField( label=_('Maximum price currency'), - read_only=False, required=False, + read_only=False, + required=False, choices=common.settings.currency_code_mappings(), ) - overall_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) - overall_max = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) + overall_min = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) + overall_max = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) - sale_price_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) - sale_price_max = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) + sale_price_min = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) + sale_price_max = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) - sale_history_min = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) - sale_history_max = InvenTree.serializers.InvenTreeMoneySerializer(allow_null=True, read_only=True) + sale_history_min = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) + sale_history_max = InvenTree.serializers.InvenTreeMoneySerializer( + allow_null=True, read_only=True + ) update = serializers.BooleanField( write_only=True, label=_('Update'), help_text=_('Update pricing for this part'), - default=False, required=False, allow_null=True, + default=False, + required=False, + allow_null=True, ) def validate(self, data): @@ -1178,17 +1247,24 @@ class PartPricingSerializer(InvenTree.serializers.InvenTreeModelSerializer): default_currency = common.settings.currency_code_default() if override_min is not None and override_max is not None: - try: override_min = convert_money(override_min, default_currency) override_max = convert_money(override_max, default_currency) except MissingRate: - raise ValidationError(_(f'Could not convert from provided currencies to {default_currency}')) + raise ValidationError( + _( + f'Could not convert from provided currencies to {default_currency}' + ) + ) if override_min > override_max: raise ValidationError({ - 'override_min': _('Minimum price must not be greater than maximum price'), - 'override_max': _('Maximum price must not be less than minimum price') + 'override_min': _( + 'Minimum price must not be greater than maximum price' + ), + 'override_max': _( + 'Maximum price must not be less than minimum price' + ), }) return data @@ -1208,14 +1284,9 @@ class PartRelationSerializer(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = PartRelated - fields = [ - 'pk', - 'part_1', - 'part_1_detail', - 'part_2', - 'part_2_detail', - ] + fields = ['pk', 'part_1', 'part_1_detail', 'part_2', 'part_2_detail'] part_1_detail = PartSerializer(source='part_1', read_only=True, many=False) part_2_detail = PartSerializer(source='part_2', read_only=True, many=False) @@ -1226,14 +1297,9 @@ class PartStarSerializer(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = PartStar - fields = [ - 'pk', - 'part', - 'partname', - 'user', - 'username', - ] + fields = ['pk', 'part', 'partname', 'user', 'username'] partname = serializers.CharField(source='part.full_name', read_only=True) username = serializers.CharField(source='user.username', read_only=True) @@ -1244,15 +1310,13 @@ class BomItemSubstituteSerializer(InvenTree.serializers.InvenTreeModelSerializer class Meta: """Metaclass defining serializer fields""" - model = BomItemSubstitute - fields = [ - 'pk', - 'bom_item', - 'part', - 'part_detail', - ] - part_detail = PartBriefSerializer(source='part', read_only=True, many=False, pricing=False) + model = BomItemSubstitute + fields = ['pk', 'bom_item', 'part', 'part_detail'] + + part_detail = PartBriefSerializer( + source='part', read_only=True, many=False, pricing=False + ) class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): @@ -1260,6 +1324,7 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): class Meta: """Metaclass defining serializer fields""" + model = BomItem fields = [ 'allow_variants', @@ -1279,15 +1344,12 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): 'sub_part_detail', 'substitutes', 'validated', - # Annotated fields describing available quantity 'available_stock', 'available_substitute_stock', 'available_variant_stock', - # Annotated field describing quantity on order 'on_order', - # Annotated field describing quantity being built 'building', ] @@ -1319,17 +1381,21 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): def validate_quantity(self, quantity): """Perform validation for the BomItem quantity field""" if quantity <= 0: - raise serializers.ValidationError(_("Quantity must be greater than zero")) + raise serializers.ValidationError(_('Quantity must be greater than zero')) return quantity - part = serializers.PrimaryKeyRelatedField(queryset=Part.objects.filter(assembly=True)) + part = serializers.PrimaryKeyRelatedField( + queryset=Part.objects.filter(assembly=True) + ) substitutes = BomItemSubstituteSerializer(many=True, read_only=True) part_detail = PartBriefSerializer(source='part', many=False, read_only=True) - sub_part = serializers.PrimaryKeyRelatedField(queryset=Part.objects.filter(component=True)) + sub_part = serializers.PrimaryKeyRelatedField( + queryset=Part.objects.filter(component=True) + ) sub_part_detail = PartBriefSerializer(source='sub_part', many=False, read_only=True) @@ -1337,8 +1403,12 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): building = serializers.FloatField(read_only=True) # Cached pricing fields - pricing_min = InvenTree.serializers.InvenTreeMoneySerializer(source='sub_part.pricing.overall_min', allow_null=True, read_only=True) - pricing_max = InvenTree.serializers.InvenTreeMoneySerializer(source='sub_part.pricing.overall_max', allow_null=True, read_only=True) + pricing_min = InvenTree.serializers.InvenTreeMoneySerializer( + source='sub_part.pricing.overall_min', allow_null=True, read_only=True + ) + pricing_max = InvenTree.serializers.InvenTreeMoneySerializer( + source='sub_part.pricing.overall_max', allow_null=True, read_only=True + ) # Annotated fields for available stock available_stock = serializers.FloatField(read_only=True) @@ -1362,13 +1432,10 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): ) queryset = queryset.prefetch_related( - 'substitutes', - 'substitutes__part__stock_items', + 'substitutes', 'substitutes__part__stock_items' ) - queryset = queryset.prefetch_related( - 'sub_part__builds', - ) + queryset = queryset.prefetch_related('sub_part__builds') return queryset @@ -1388,7 +1455,7 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): # Annotate with the total "on order" amount for the sub-part queryset = queryset.annotate( - on_order=part.filters.annotate_on_order_quantity(ref), + on_order=part.filters.annotate_on_order_quantity(ref) ) # Annotate with the total "building" amount for the sub-part @@ -1408,14 +1475,20 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): # Note that these fields are only aliased, not annotated queryset = queryset.alias( total_stock=part.filters.annotate_total_stock(reference=ref), - allocated_to_sales_orders=part.filters.annotate_sales_order_allocations(reference=ref), - allocated_to_build_orders=part.filters.annotate_build_order_allocations(reference=ref), + allocated_to_sales_orders=part.filters.annotate_sales_order_allocations( + reference=ref + ), + allocated_to_build_orders=part.filters.annotate_build_order_allocations( + reference=ref + ), ) # Calculate 'available_stock' based on previously annotated fields queryset = queryset.annotate( available_stock=ExpressionWrapper( - F('total_stock') - F('allocated_to_sales_orders') - F('allocated_to_build_orders'), + F('total_stock') + - F('allocated_to_sales_orders') + - F('allocated_to_build_orders'), output_field=models.DecimalField(), ) ) @@ -1425,14 +1498,20 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): # Extract similar information for any 'substitute' parts queryset = queryset.alias( substitute_stock=part.filters.annotate_total_stock(reference=ref), - substitute_build_allocations=part.filters.annotate_build_order_allocations(reference=ref), - substitute_sales_allocations=part.filters.annotate_sales_order_allocations(reference=ref) + substitute_build_allocations=part.filters.annotate_build_order_allocations( + reference=ref + ), + substitute_sales_allocations=part.filters.annotate_sales_order_allocations( + reference=ref + ), ) # Calculate 'available_substitute_stock' field queryset = queryset.annotate( available_substitute_stock=ExpressionWrapper( - F('substitute_stock') - F('substitute_build_allocations') - F('substitute_sales_allocations'), + F('substitute_stock') + - F('substitute_build_allocations') + - F('substitute_sales_allocations'), output_field=models.DecimalField(), ) ) @@ -1441,14 +1520,22 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): variant_stock_query = part.filters.variant_stock_query(reference='sub_part__') queryset = queryset.alias( - variant_stock_total=part.filters.annotate_variant_quantity(variant_stock_query, reference='quantity'), - variant_bo_allocations=part.filters.annotate_variant_quantity(variant_stock_query, reference='sales_order_allocations__quantity'), - variant_so_allocations=part.filters.annotate_variant_quantity(variant_stock_query, reference='allocations__quantity'), + variant_stock_total=part.filters.annotate_variant_quantity( + variant_stock_query, reference='quantity' + ), + variant_bo_allocations=part.filters.annotate_variant_quantity( + variant_stock_query, reference='sales_order_allocations__quantity' + ), + variant_so_allocations=part.filters.annotate_variant_quantity( + variant_stock_query, reference='allocations__quantity' + ), ) queryset = queryset.annotate( available_variant_stock=ExpressionWrapper( - F('variant_stock_total') - F('variant_bo_allocations') - F('variant_so_allocations'), + F('variant_stock_total') + - F('variant_bo_allocations') + - F('variant_so_allocations'), output_field=FloatField(), ) ) @@ -1456,11 +1543,14 @@ class BomItemSerializer(InvenTree.serializers.InvenTreeModelSerializer): return queryset -class CategoryParameterTemplateSerializer(InvenTree.serializers.InvenTreeModelSerializer): +class CategoryParameterTemplateSerializer( + InvenTree.serializers.InvenTreeModelSerializer +): """Serializer for the PartCategoryParameterTemplate model.""" class Meta: """Metaclass defining serializer fields""" + model = PartCategoryParameterTemplate fields = [ 'pk', @@ -1471,7 +1561,9 @@ class CategoryParameterTemplateSerializer(InvenTree.serializers.InvenTreeModelSe 'default_value', ] - parameter_template_detail = PartParameterTemplateSerializer(source='parameter_template', many=False, read_only=True) + parameter_template_detail = PartParameterTemplateSerializer( + source='parameter_template', many=False, read_only=True + ) category_detail = CategorySerializer(source='category', many=False, read_only=True) @@ -1481,6 +1573,7 @@ class PartCopyBOMSerializer(serializers.Serializer): class Meta: """Metaclass defining serializer fields""" + fields = [ 'part', 'remove_existing', @@ -1548,22 +1641,16 @@ class BomImportUploadSerializer(InvenTree.serializers.DataFileUploadSerializer): class Meta: """Metaclass defining serializer fields""" - fields = [ - 'data_file', - 'part', - 'clear_existing_bom', - ] + + fields = ['data_file', 'part', 'clear_existing_bom'] part = serializers.PrimaryKeyRelatedField( - queryset=Part.objects.all(), - required=True, - allow_null=False, - many=False, + queryset=Part.objects.all(), required=True, allow_null=False, many=False ) clear_existing_bom = serializers.BooleanField( label=_('Clear Existing BOM'), - help_text=_('Delete existing BOM items before uploading') + help_text=_('Delete existing BOM items before uploading'), ) def save(self): @@ -1593,7 +1680,7 @@ class BomImportExtractSerializer(InvenTree.serializers.DataFileExtractSerializer if not any(col in self.columns for col in part_columns): # At least one part column is required! - raise serializers.ValidationError(_("No part column specified")) + raise serializers.ValidationError(_('No part column specified')) @staticmethod def process_row(row): @@ -1681,7 +1768,7 @@ class BomImportSubmitSerializer(serializers.Serializer): items = data['items'] if len(items) == 0: - raise serializers.ValidationError(_("At least one BOM item is required")) + raise serializers.ValidationError(_('At least one BOM item is required')) data = super().validate(data) @@ -1700,9 +1787,7 @@ class BomImportSubmitSerializer(serializers.Serializer): bom_items = [] try: - for item in items: - part = item['part'] sub_part = item['sub_part'] @@ -1710,12 +1795,10 @@ class BomImportSubmitSerializer(serializers.Serializer): if BomItem.objects.filter(part=part, sub_part=sub_part).exists(): continue - bom_items.append( - BomItem(**item) - ) + bom_items.append(BomItem(**item)) if len(bom_items) > 0: - logger.info("Importing %s BOM items", len(bom_items)) + logger.info('Importing %s BOM items', len(bom_items)) BomItem.objects.bulk_create(bom_items) except Exception as e: diff --git a/InvenTree/part/stocktake.py b/InvenTree/part/stocktake.py index 7ccd6cd185..4d89b263ec 100644 --- a/InvenTree/part/stocktake.py +++ b/InvenTree/part/stocktake.py @@ -21,7 +21,9 @@ import stock.models logger = logging.getLogger('inventree') -def perform_stocktake(target: part.models.Part, user: User, note: str = '', commit=True, **kwargs): +def perform_stocktake( + target: part.models.Part, user: User, note: str = '', commit=True, **kwargs +): """Perform stocktake action on a single part. Arguments: @@ -60,7 +62,7 @@ def perform_stocktake(target: part.models.Part, user: User, note: str = '', comm if not pricing.is_valid: # If pricing is not valid, let's update - logger.info("Pricing not valid for %s - updating", target) + logger.info('Pricing not valid for %s - updating', target) pricing.update_pricing(cascade=False) pricing.refresh_from_db() @@ -78,7 +80,6 @@ def perform_stocktake(target: part.models.Part, user: User, note: str = '', comm location_cost_max = Money(0, base_currency) for entry in stock_entries: - entry_cost_min = None entry_cost_max = None @@ -94,10 +95,13 @@ def perform_stocktake(target: part.models.Part, user: User, note: str = '', comm # Convert to base currency try: - entry_cost_min = convert_money(entry_cost_min, base_currency) * entry.quantity - entry_cost_max = convert_money(entry_cost_max, base_currency) * entry.quantity + entry_cost_min = ( + convert_money(entry_cost_min, base_currency) * entry.quantity + ) + entry_cost_max = ( + convert_money(entry_cost_max, base_currency) * entry.quantity + ) except Exception: - entry_cost_min = Money(0, base_currency) entry_cost_max = Money(0, base_currency) @@ -160,7 +164,7 @@ def generate_stocktake_report(**kwargs): # Determine if external locations should be excluded exclude_external = kwargs.get( 'exclude_exernal', - common.models.InvenTreeSetting.get_setting('STOCKTAKE_EXCLUDE_EXTERNAL', False) + common.models.InvenTreeSetting.get_setting('STOCKTAKE_EXCLUDE_EXTERNAL', False), ) parts = part.models.Part.objects.all() @@ -172,9 +176,7 @@ def generate_stocktake_report(**kwargs): # Filter by 'Part' instance if p := kwargs.get('part', None): variants = p.get_descendants(include_self=True) - parts = parts.filter( - pk__in=[v.pk for v in variants] - ) + parts = parts.filter(pk__in=[v.pk for v in variants]) # Filter by 'Category' instance (cascading) if category := kwargs.get('category', None): @@ -196,18 +198,16 @@ def generate_stocktake_report(**kwargs): # List of parts which exist within these locations unique_parts = items.order_by().values('part').distinct() - parts = parts.filter( - pk__in=[result['part'] for result in unique_parts] - ) + parts = parts.filter(pk__in=[result['part'] for result in unique_parts]) # Exit if filters removed all parts n_parts = parts.count() if n_parts == 0: - logger.info("No parts selected for stocktake report - exiting") + logger.info('No parts selected for stocktake report - exiting') return - logger.info("Generating new stocktake report for %s parts", n_parts) + logger.info('Generating new stocktake report for %s parts', n_parts) base_currency = common.settings.currency_code_default() @@ -239,12 +239,9 @@ def generate_stocktake_report(**kwargs): # Iterate through each Part which matches the filters above for p in parts: - # Create a new stocktake for this part (do not commit, this will take place later on) stocktake = perform_stocktake( - p, user, commit=False, - exclude_external=exclude_external, - location=location, + p, user, commit=False, exclude_external=exclude_external, location=location ) total_parts += 1 @@ -269,19 +266,16 @@ def generate_stocktake_report(**kwargs): buffer.write(dataset.export('csv')) today = datetime.now().date().isoformat() - filename = f"InvenTree_Stocktake_{today}.csv" + filename = f'InvenTree_Stocktake_{today}.csv' report_file = ContentFile(buffer.getvalue(), name=filename) if generate_report: report_instance = part.models.PartStocktakeReport.objects.create( - report=report_file, - part_count=total_parts, - user=user + report=report_file, part_count=total_parts, user=user ) # Notify the requesting user if user: - common.notifications.trigger_notification( report_instance, category='generate_stocktake_report', @@ -289,18 +283,19 @@ def generate_stocktake_report(**kwargs): 'name': _('Stocktake Report Available'), 'message': _('A new stocktake report is available for download'), }, - targets=[ - user, - ] + targets=[user], ) # If 'update_parts' is set, we save stocktake entries for each individual part if update_parts: # Use bulk_create for efficient insertion of stocktake part.models.PartStocktake.objects.bulk_create( - stocktake_instances, - batch_size=500, + stocktake_instances, batch_size=500 ) t_stocktake = time.time() - t_start - logger.info("Generated stocktake report for %s parts in %ss", total_parts, round(t_stocktake, 2)) + logger.info( + 'Generated stocktake report for %s parts in %ss', + total_parts, + round(t_stocktake, 2), + ) diff --git a/InvenTree/part/tasks.py b/InvenTree/part/tasks.py index a305adeba8..604baaa907 100644 --- a/InvenTree/part/tasks.py +++ b/InvenTree/part/tasks.py @@ -1,6 +1,5 @@ """Background task definitions for the 'part' app""" - import logging import random import time @@ -18,10 +17,14 @@ import InvenTree.helpers_model import InvenTree.tasks import part.models import part.stocktake -from InvenTree.tasks import (ScheduledTask, check_daily_holdoff, - record_task_success, scheduled_task) +from InvenTree.tasks import ( + ScheduledTask, + check_daily_holdoff, + record_task_success, + scheduled_task, +) -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') def notify_low_stock(part: part.models.Part): @@ -30,24 +33,20 @@ def notify_low_stock(part: part.models.Part): - Triggered when the available stock for a given part falls be low the configured threhsold - A notification is delivered to any users who are 'subscribed' to this part """ - name = _("Low stock notification") - message = _(f'The available stock for {part.name} has fallen below the configured minimum level') + name = _('Low stock notification') + message = _( + f'The available stock for {part.name} has fallen below the configured minimum level' + ) context = { 'part': part, 'name': name, 'message': message, 'link': InvenTree.helpers_model.construct_absolute_url(part.get_absolute_url()), - 'template': { - 'html': 'email/low_stock_notification.html', - 'subject': name, - }, + 'template': {'html': 'email/low_stock_notification.html', 'subject': name}, } common.notifications.trigger_notification( - part, - 'part.notify_low_stock', - target_fnc=part.get_subscribers, - context=context, + part, 'part.notify_low_stock', target_fnc=part.get_subscribers, context=context ) @@ -61,10 +60,7 @@ def notify_low_stock_if_required(part: part.models.Part): for p in parts: if p.is_part_low_on_stock(): - InvenTree.tasks.offload_task( - notify_low_stock, - p - ) + InvenTree.tasks.offload_task(notify_low_stock, p) def update_part_pricing(pricing: part.models.PartPricing, counter: int = 0): @@ -74,7 +70,7 @@ def update_part_pricing(pricing: part.models.PartPricing, counter: int = 0): pricing: The target PartPricing instance to be updated counter: How many times this function has been called in sequence """ - logger.info("Updating part pricing for %s", pricing.part) + logger.info('Updating part pricing for %s', pricing.part) pricing.update_pricing(counter=counter) @@ -94,7 +90,7 @@ def check_missing_pricing(limit=250): results = part.models.PartPricing.objects.filter(updated=None)[:limit] if results.count() > 0: - logger.info("Found %s parts with empty pricing", results.count()) + logger.info('Found %s parts with empty pricing', results.count()) for pp in results: pp.schedule_for_update() @@ -106,7 +102,7 @@ def check_missing_pricing(limit=250): results = part.models.PartPricing.objects.filter(updated__lte=stale_date)[:limit] if results.count() > 0: - logger.info("Found %s stale pricing entries", results.count()) + logger.info('Found %s stale pricing entries', results.count()) for pp in results: pp.schedule_for_update() @@ -116,7 +112,7 @@ def check_missing_pricing(limit=250): results = part.models.PartPricing.objects.exclude(currency=currency) if results.count() > 0: - logger.info("Found %s pricing entries in the wrong currency", results.count()) + logger.info('Found %s pricing entries in the wrong currency', results.count()) for pp in results: pp.schedule_for_update() @@ -125,7 +121,7 @@ def check_missing_pricing(limit=250): results = part.models.Part.objects.filter(pricing_data=None)[:limit] if results.count() > 0: - logger.info("Found %s parts without pricing", results.count()) + logger.info('Found %s parts without pricing', results.count()) for p in results: pricing = p.pricing @@ -146,27 +142,37 @@ def scheduled_stocktake_reports(): time.sleep(random.randint(1, 5)) # First let's delete any old stocktake reports - delete_n_days = int(common.models.InvenTreeSetting.get_setting('STOCKTAKE_DELETE_REPORT_DAYS', 30, cache=False)) + delete_n_days = int( + common.models.InvenTreeSetting.get_setting( + 'STOCKTAKE_DELETE_REPORT_DAYS', 30, cache=False + ) + ) threshold = datetime.now() - timedelta(days=delete_n_days) old_reports = part.models.PartStocktakeReport.objects.filter(date__lt=threshold) if old_reports.count() > 0: - logger.info("Deleting %s stale stocktake reports", old_reports.count()) + logger.info('Deleting %s stale stocktake reports', old_reports.count()) old_reports.delete() # Next, check if stocktake functionality is enabled - if not common.models.InvenTreeSetting.get_setting('STOCKTAKE_ENABLE', False, cache=False): - logger.info("Stocktake functionality is not enabled - exiting") + if not common.models.InvenTreeSetting.get_setting( + 'STOCKTAKE_ENABLE', False, cache=False + ): + logger.info('Stocktake functionality is not enabled - exiting') return - report_n_days = int(common.models.InvenTreeSetting.get_setting('STOCKTAKE_AUTO_DAYS', 0, cache=False)) + report_n_days = int( + common.models.InvenTreeSetting.get_setting( + 'STOCKTAKE_AUTO_DAYS', 0, cache=False + ) + ) if report_n_days < 1: - logger.info("Stocktake auto reports are disabled, exiting") + logger.info('Stocktake auto reports are disabled, exiting') return if not check_daily_holdoff('STOCKTAKE_RECENT_REPORT', report_n_days): - logger.info("Stocktake report was recently generated - exiting") + logger.info('Stocktake report was recently generated - exiting') return # Let's start a new stocktake report for all parts diff --git a/InvenTree/part/templatetags/i18n.py b/InvenTree/part/templatetags/i18n.py index 54af6040f6..ab2d440a32 100644 --- a/InvenTree/part/templatetags/i18n.py +++ b/InvenTree/part/templatetags/i18n.py @@ -42,15 +42,15 @@ class CustomTranslateNode(TranslateNode): result = result.replace(c, '') # Escape any quotes contained in the string - result = result.replace("'", r"\'") + result = result.replace("'", r'\'') result = result.replace('"', r'\"') # Return the 'clean' resulting string return result -@register.tag("translate") -@register.tag("trans") +@register.tag('translate') +@register.tag('trans') def do_translate(parser, token): """Custom translation function, lifted from https://github.com/django/django/blob/main/django/templatetags/i18n.py @@ -66,17 +66,17 @@ def do_translate(parser, token): asvar = None message_context = None seen = set() - invalid_context = {"as", "noop"} + invalid_context = {'as', 'noop'} while remaining: option = remaining.pop(0) if option in seen: raise TemplateSyntaxError( - "The '%s' option was specified more than once." % option, + "The '%s' option was specified more than once." % option ) - elif option == "noop": + elif option == 'noop': noop = True - elif option == "context": + elif option == 'context': try: value = remaining.pop(0) except IndexError: @@ -87,10 +87,10 @@ def do_translate(parser, token): if value in invalid_context: raise TemplateSyntaxError( "Invalid argument '%s' provided to the '%s' tag for the context " - "option" % (value, bits[0]), + 'option' % (value, bits[0]) ) message_context = parser.compile_filter(value) - elif option == "as": + elif option == 'as': try: value = remaining.pop(0) except IndexError: @@ -102,10 +102,7 @@ def do_translate(parser, token): raise TemplateSyntaxError( "Unknown argument for '%s' tag: '%s'. The only options " "available are 'noop', 'context' \"xxx\", and 'as VAR'." - % ( - bits[0], - option, - ) + % (bits[0], option) ) seen.add(option) @@ -113,18 +110,26 @@ def do_translate(parser, token): # Re-register tags which we have not explicitly overridden -register.tag("blocktrans", django.templatetags.i18n.do_block_translate) -register.tag("blocktranslate", django.templatetags.i18n.do_block_translate) +register.tag('blocktrans', django.templatetags.i18n.do_block_translate) +register.tag('blocktranslate', django.templatetags.i18n.do_block_translate) -register.tag("language", django.templatetags.i18n.language) +register.tag('language', django.templatetags.i18n.language) -register.tag("get_available_languages", django.templatetags.i18n.do_get_available_languages) -register.tag("get_language_info", django.templatetags.i18n.do_get_language_info) -register.tag("get_language_info_list", django.templatetags.i18n.do_get_language_info_list) -register.tag("get_current_language", django.templatetags.i18n.do_get_current_language) -register.tag("get_current_language_bidi", django.templatetags.i18n.do_get_current_language_bidi) +register.tag( + 'get_available_languages', django.templatetags.i18n.do_get_available_languages +) +register.tag('get_language_info', django.templatetags.i18n.do_get_language_info) +register.tag( + 'get_language_info_list', django.templatetags.i18n.do_get_language_info_list +) +register.tag('get_current_language', django.templatetags.i18n.do_get_current_language) +register.tag( + 'get_current_language_bidi', django.templatetags.i18n.do_get_current_language_bidi +) -register.filter("language_name", django.templatetags.i18n.language_name) -register.filter("language_name_translated", django.templatetags.i18n.language_name_translated) -register.filter("language_name_local", django.templatetags.i18n.language_name_local) -register.filter("language_bidi", django.templatetags.i18n.language_bidi) +register.filter('language_name', django.templatetags.i18n.language_name) +register.filter( + 'language_name_translated', django.templatetags.i18n.language_name_translated +) +register.filter('language_name_local', django.templatetags.i18n.language_name_local) +register.filter('language_bidi', django.templatetags.i18n.language_bidi) diff --git a/InvenTree/part/templatetags/inventree_extras.py b/InvenTree/part/templatetags/inventree_extras.py index cdb63aacd1..2b1b2633a7 100644 --- a/InvenTree/part/templatetags/inventree_extras.py +++ b/InvenTree/part/templatetags/inventree_extras.py @@ -55,7 +55,6 @@ def render_date(context, date_object): return None if isinstance(date_object, str): - date_object = date_object.strip() # Check for empty string @@ -66,29 +65,25 @@ def render_date(context, date_object): try: date_object = date.fromisoformat(date_object) except ValueError: - logger.warning("Tried to convert invalid date string: %s", date_object) + logger.warning('Tried to convert invalid date string: %s', date_object) return None # We may have already pre-cached the date format by calling this already! user_date_format = context.get('user_date_format', None) if user_date_format is None: - user = context.get('user', None) if user and user.is_authenticated: # User is specified - look for their date display preference - user_date_format = common.models.InvenTreeUserSetting.get_setting('DATE_DISPLAY_FORMAT', user=user) + user_date_format = common.models.InvenTreeUserSetting.get_setting( + 'DATE_DISPLAY_FORMAT', user=user + ) else: user_date_format = 'YYYY-MM-DD' # Convert the format string to Pythonic equivalent - replacements = [ - ('YYYY', '%Y'), - ('MMM', '%b'), - ('MM', '%m'), - ('DD', '%d'), - ] + replacements = [('YYYY', '%Y'), ('MMM', '%b'), ('MM', '%m'), ('DD', '%d')] for o, n in replacements: user_date_format = user_date_format.replace(o, n) @@ -171,11 +166,8 @@ def plugins_info(*args, **kwargs): plug_list = [plg for plg in registry.plugins.values() if plg.plugin_config().active] # Format list return [ - { - 'name': plg.name, - 'slug': plg.slug, - 'version': plg.version - } for plg in plug_list + {'name': plg.name, 'slug': plg.slug, 'version': plg.version} + for plg in plug_list ] @@ -228,7 +220,7 @@ def python_version(*args, **kwargs): def inventree_version(shortstring=False, *args, **kwargs): """Return InvenTree version string.""" if shortstring: - return _(f"{version.inventreeInstanceTitle()} v{version.inventreeVersion()}") + return _(f'{version.inventreeInstanceTitle()} v{version.inventreeVersion()}') return version.inventreeVersion() @@ -344,13 +336,19 @@ def setting_object(key, *args, **kwargs): if issubclass(plg.__class__, InvenTreePlugin): plg = plg.plugin_config() - return plugin.models.PluginSetting.get_setting_object(key, plugin=plg, cache=cache) + return plugin.models.PluginSetting.get_setting_object( + key, plugin=plg, cache=cache + ) elif 'method' in kwargs: - return plugin.models.NotificationUserSetting.get_setting_object(key, user=kwargs['user'], method=kwargs['method'], cache=cache) + return plugin.models.NotificationUserSetting.get_setting_object( + key, user=kwargs['user'], method=kwargs['method'], cache=cache + ) elif 'user' in kwargs: - return common.models.InvenTreeUserSetting.get_setting_object(key, user=kwargs['user'], cache=cache) + return common.models.InvenTreeUserSetting.get_setting_object( + key, user=kwargs['user'], cache=cache + ) else: return common.models.InvenTreeSetting.get_setting_object(key, cache=cache) @@ -360,7 +358,9 @@ def setting_object(key, *args, **kwargs): def settings_value(key, *args, **kwargs): """Return a settings value specified by the given key.""" if 'user' in kwargs: - if not kwargs['user'] or (kwargs['user'] and kwargs['user'].is_authenticated is False): + if not kwargs['user'] or ( + kwargs['user'] and kwargs['user'].is_authenticated is False + ): return common.models.InvenTreeUserSetting.get_setting(key) return common.models.InvenTreeUserSetting.get_setting(key, user=kwargs['user']) @@ -465,10 +465,7 @@ def get_available_themes(*args, **kwargs): from common.models import ColorTheme for key, name in ColorTheme.get_color_themes_choices(): - themes.append({ - 'key': key, - 'name': name - }) + themes.append({'key': key, 'name': name}) return themes @@ -573,7 +570,6 @@ class I18nStaticNode(StaticNode): self.original = self.path.var if hasattr(context, 'request'): - # Convert the "requested" language code to a standard format language_code = context.request.LANGUAGE_CODE.lower().strip() language_code = language_code.replace('_', '-') @@ -582,16 +578,11 @@ class I18nStaticNode(StaticNode): # - First, try the original requested code, e.g. 'pt-br' # - Next, try a simpler version of the code e.g. 'pt' # - Finally, fall back to english - options = [ - language_code, - language_code.split('-')[0], - 'en', - ] + options = [language_code, language_code.split('-')[0], 'en'] for lng in options: lng_file = os.path.join( - djangosettings.STATIC_ROOT, - self.original.format(lng=lng) + djangosettings.STATIC_ROOT, self.original.format(lng=lng) ) if os.path.exists(lng_file): @@ -657,18 +648,18 @@ def admin_url(user, table, pk): from django.urls import reverse if not djangosettings.INVENTREE_ADMIN_ENABLED: - return "" + return '' if not user.is_staff: - return "" + return '' # Check the user has the correct permission - perm_string = f"{app}.change_{model}" + perm_string = f'{app}.change_{model}' if not user.has_perm(perm_string): return '' # Fallback URL - url = reverse(f"admin:{app}_{model}_changelist") + url = reverse(f'admin:{app}_{model}_changelist') if pk: try: diff --git a/InvenTree/part/test_api.py b/InvenTree/part/test_api.py index 78ab67a9ef..1b32f5f914 100644 --- a/InvenTree/part/test_api.py +++ b/InvenTree/part/test_api.py @@ -20,13 +20,20 @@ import company.models import order.models from common.models import InvenTreeSetting from company.models import Company, SupplierPart -from InvenTree.status_codes import (BuildStatus, PurchaseOrderStatusGroups, - StockStatus) +from InvenTree.status_codes import BuildStatus, PurchaseOrderStatusGroups, StockStatus from InvenTree.unit_test import InvenTreeAPITestCase -from part.models import (BomItem, BomItemSubstitute, Part, PartCategory, - PartCategoryParameterTemplate, PartParameter, - PartParameterTemplate, PartRelated, PartStocktake, - PartTestTemplate) +from part.models import ( + BomItem, + BomItemSubstitute, + Part, + PartCategory, + PartCategoryParameterTemplate, + PartParameter, + PartParameterTemplate, + PartRelated, + PartStocktake, + PartTestTemplate, +) from stock.models import StockItem, StockLocation @@ -72,19 +79,67 @@ class PartCategoryAPITest(InvenTreeAPITestCase): ({'parent': 1, 'cascade': True}, 5, 'Filter by parent, cascading'), ({'cascade': True, 'depth': 0}, 8, 'Cascade with no parent, depth=0'), ({'cascade': False, 'depth': 10}, 8, 'Cascade with no parent, depth=0'), - ({'parent': 'null', 'cascade': True, 'depth': 0}, 2, 'Cascade with null parent, depth=0'), - ({'parent': 'null', 'cascade': True, 'depth': 10}, 8, 'Cascade with null parent and bigger depth'), - ({'parent': 'null', 'cascade': False, 'depth': 10}, 2, 'No cascade even with depth specified with null parent'), - ({'parent': 1, 'cascade': False, 'depth': 0}, 3, 'Dont cascade with depth=0 and parent'), - ({'parent': 1, 'cascade': True, 'depth': 0}, 3, 'Cascade with depth=0 and parent'), - ({'parent': 1, 'cascade': False, 'depth': 1}, 3, 'Dont cascade even with depth=1 specified with parent'), - ({'parent': 1, 'cascade': True, 'depth': 1}, 5, 'Cascade with depth=1 with parent'), - ({'parent': 1, 'cascade': True, 'depth': 'abcdefg'}, 5, 'Cascade with invalid depth and parent'), + ( + {'parent': 'null', 'cascade': True, 'depth': 0}, + 2, + 'Cascade with null parent, depth=0', + ), + ( + {'parent': 'null', 'cascade': True, 'depth': 10}, + 8, + 'Cascade with null parent and bigger depth', + ), + ( + {'parent': 'null', 'cascade': False, 'depth': 10}, + 2, + 'No cascade even with depth specified with null parent', + ), + ( + {'parent': 1, 'cascade': False, 'depth': 0}, + 3, + 'Dont cascade with depth=0 and parent', + ), + ( + {'parent': 1, 'cascade': True, 'depth': 0}, + 3, + 'Cascade with depth=0 and parent', + ), + ( + {'parent': 1, 'cascade': False, 'depth': 1}, + 3, + 'Dont cascade even with depth=1 specified with parent', + ), + ( + {'parent': 1, 'cascade': True, 'depth': 1}, + 5, + 'Cascade with depth=1 with parent', + ), + ( + {'parent': 1, 'cascade': True, 'depth': 'abcdefg'}, + 5, + 'Cascade with invalid depth and parent', + ), ({'parent': 42}, 8, 'Should return everything if parent_pk is not valid'), - ({'parent': 'null', 'exclude_tree': 1, 'cascade': True}, 2, 'Should return everything from except tree with pk=1'), - ({'parent': 'null', 'exclude_tree': 42, 'cascade': True}, 8, 'Should return everything because exclude_tree=42 is no valid pk'), - ({'parent': 1, 'starred': True, 'cascade': True}, 2, 'Should return the starred categories for the current user within the pk=1 tree'), - ({'parent': 1, 'starred': False, 'cascade': True}, 3, 'Should return the not starred categories for the current user within the pk=1 tree'), + ( + {'parent': 'null', 'exclude_tree': 1, 'cascade': True}, + 2, + 'Should return everything from except tree with pk=1', + ), + ( + {'parent': 'null', 'exclude_tree': 42, 'cascade': True}, + 8, + 'Should return everything because exclude_tree=42 is no valid pk', + ), + ( + {'parent': 1, 'starred': True, 'cascade': True}, + 2, + 'Should return the starred categories for the current user within the pk=1 tree', + ), + ( + {'parent': 1, 'starred': False, 'cascade': True}, + 3, + 'Should return the not starred categories for the current user within the pk=1 tree', + ), ] for params, res_len, description in test_cases: @@ -101,13 +156,15 @@ class PartCategoryAPITest(InvenTreeAPITestCase): 'parent', 'part_count', 'pathstring', - 'url' + 'url', ] response = self.get(url, expected_code=200) for result in response.data: for f in fields: - self.assertIn(f, result, f'"{f}" is missing in result of PartCategory list') + self.assertIn( + f, result, f'"{f}" is missing in result of PartCategory list' + ) def test_part_count(self): """Test that the 'part_count' field is annotated correctly""" @@ -115,35 +172,25 @@ class PartCategoryAPITest(InvenTreeAPITestCase): # Create a parent category cat = PartCategory.objects.create( - name='Parent Cat', - description='Some name', - parent=None + name='Parent Cat', description='Some name', parent=None ) # Create child categories for ii in range(10): child = PartCategory.objects.create( - name=f"Child cat {ii}", - description="A child category", - parent=cat + name=f'Child cat {ii}', description='A child category', parent=cat ) # Create parts in this category for jj in range(10): Part.objects.create( - name=f"Part xyz {jj}_{ii}", - description="A test part with a description", - category=child + name=f'Part xyz {jj}_{ii}', + description='A test part with a description', + category=child, ) # Filter by parent category - response = self.get( - url, - { - 'parent': cat.pk, - }, - expected_code=200 - ) + response = self.get(url, {'parent': cat.pk}, expected_code=200) # 10 child categories self.assertEqual(len(response.data), 10) @@ -153,10 +200,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase): self.assertEqual(result['part_count'], 10) # Detail view for parent category - response = self.get( - f'/api/part/category/{cat.pk}/', - expected_code=200 - ) + response = self.get(f'/api/part/category/{cat.pk}/', expected_code=200) # Annotation should include parts from all sub-categories self.assertEqual(response.data['part_count'], 100) @@ -179,7 +223,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase): 'category': 2, 'parameter_template': template.pk, 'default_value': 'xyz', - } + }, ) # Total number of category templates should have increased @@ -187,22 +231,26 @@ class PartCategoryAPITest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 2 + n) # Filter by category - response = self.get( - url, - { - 'category': 2, - } - ) + response = self.get(url, {'category': 2}) self.assertEqual(len(response.data), n) # Test that we can retrieve individual templates via the API for template in PartCategoryParameterTemplate.objects.all(): - url = reverse('api-part-category-parameter-detail', kwargs={'pk': template.pk}) + url = reverse( + 'api-part-category-parameter-detail', kwargs={'pk': template.pk} + ) data = self.get(url, {}, expected_code=200).data - for key in ['pk', 'category', 'category_detail', 'parameter_template', 'parameter_template_detail', 'default_value']: + for key in [ + 'pk', + 'category', + 'category_detail', + 'parameter_template', + 'parameter_template_detail', + 'default_value', + ]: self.assertIn(key, data.keys()) # Test that we can delete via the API also @@ -227,31 +275,15 @@ class PartCategoryAPITest(InvenTreeAPITestCase): ] for v in invalid_values: - response = self.patch( - url, - { - 'description': v - }, - expected_code=400 - ) + response = self.patch(url, {'description': v}, expected_code=400) self.assertIn('Remove HTML tags', str(response.data)) # Raw characters should be allowed - allowed = [ - '<< hello', - 'Alpha & Omega', - 'A > B > C', - ] + allowed = ['<< hello', 'Alpha & Omega', 'A > B > C'] for val in allowed: - response = self.patch( - url, - { - 'description': val, - }, - expected_code=200, - ) + response = self.patch(url, {'description': val}, expected_code=200) self.assertEqual(response.data['description'], val) @@ -263,18 +295,11 @@ class PartCategoryAPITest(InvenTreeAPITestCase): 'A part\n category\n\t', 'A\t part\t category\t', 'A pa\rrt cat\r\r\regory', - 'A part\u200e catego\u200fry\u202e' + 'A part\u200e catego\u200fry\u202e', ] for val in values: - - response = self.patch( - url, - { - 'description': val, - }, - expected_code=200, - ) + response = self.patch(url, {'description': val}, expected_code=200) self.assertEqual(response.data['description'], 'A part category') @@ -282,25 +307,31 @@ class PartCategoryAPITest(InvenTreeAPITestCase): """Test category deletion with different parameters""" class Target(IntEnum): - move_subcategories_to_parent_move_parts_to_parent = 0, - move_subcategories_to_parent_delete_parts = 1, - delete_subcategories_move_parts_to_parent = 2, - delete_subcategories_delete_parts = 3, + move_subcategories_to_parent_move_parts_to_parent = (0,) + move_subcategories_to_parent_delete_parts = (1,) + delete_subcategories_move_parts_to_parent = (2,) + delete_subcategories_delete_parts = (3,) for i in range(4): delete_child_categories: bool = False delete_parts: bool = False - if i in (Target.move_subcategories_to_parent_delete_parts, Target.delete_subcategories_delete_parts): + if i in ( + Target.move_subcategories_to_parent_delete_parts, + Target.delete_subcategories_delete_parts, + ): delete_parts = True - if i in (Target.delete_subcategories_move_parts_to_parent, Target.delete_subcategories_delete_parts): + if i in ( + Target.delete_subcategories_move_parts_to_parent, + Target.delete_subcategories_delete_parts, + ): delete_child_categories = True # Create a parent category parent_category = PartCategory.objects.create( name='Parent category', description='This is the parent category where the child categories and parts are moved to', - parent=None + parent=None, ) category_count_before = PartCategory.objects.count() @@ -310,7 +341,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase): cat_to_delete = PartCategory.objects.create( name='Category to delete', description='This is the category to be deleted', - parent=parent_category + parent=parent_category, ) url = reverse('api-part-category-detail', kwargs={'pk': cat_to_delete.id}) @@ -318,30 +349,34 @@ class PartCategoryAPITest(InvenTreeAPITestCase): parts = [] # Create parts in the category to be deleted for jj in range(3): - parts.append(Part.objects.create( - name=f"Part xyz {i}_{jj}", - description="Child part of the deleted category", - category=cat_to_delete - )) + parts.append( + Part.objects.create( + name=f'Part xyz {i}_{jj}', + description='Child part of the deleted category', + category=cat_to_delete, + ) + ) child_categories = [] child_categories_parts = [] # Create child categories under the category to be deleted for ii in range(3): child = PartCategory.objects.create( - name=f"Child parent_cat {i}_{ii}", - description="A child category of the deleted category", - parent=cat_to_delete + name=f'Child parent_cat {i}_{ii}', + description='A child category of the deleted category', + parent=cat_to_delete, ) child_categories.append(child) # Create parts in the child categories for jj in range(3): - child_categories_parts.append(Part.objects.create( - name=f"Part xyz {i}_{jj}_{ii}", - description="Child part in the child category of the deleted category", - category=child - )) + child_categories_parts.append( + Part.objects.create( + name=f'Part xyz {i}_{jj}_{ii}', + description='Child part in the child category of the deleted category', + category=child, + ) + ) # Delete the created category (sub categories and their parts will be moved under the parent) params = {} @@ -349,11 +384,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase): params['delete_parts'] = '1' if delete_child_categories: params['delete_child_categories'] = '1' - response = self.delete( - url, - params, - expected_code=204, - ) + response = self.delete(url, params, expected_code=204) self.assertEqual(response.status_code, 204) @@ -363,7 +394,10 @@ class PartCategoryAPITest(InvenTreeAPITestCase): self.assertEqual(Part.objects.count(), part_count_before) elif i == Target.move_subcategories_to_parent_delete_parts: # Check if all parts deleted - self.assertEqual(Part.objects.count(), part_count_before + len(child_categories_parts)) + self.assertEqual( + Part.objects.count(), + part_count_before + len(child_categories_parts), + ) else: # parts moved to the parent category for part in parts: @@ -396,7 +430,7 @@ class PartCategoryAPITest(InvenTreeAPITestCase): name='Structural category', description='This is the structural category', parent=None, - structural=True + structural=True, ) part_count_before = Part.objects.count() @@ -404,9 +438,9 @@ class PartCategoryAPITest(InvenTreeAPITestCase): # Make sure that we get an error if we try to create part in the structural category with self.assertRaises(ValidationError): part = Part.objects.create( - name="-", - description="Part which shall not be created", - category=structural_category + name='-', + description='Part which shall not be created', + category=structural_category, ) # Ensure that the part really did not get created in the structural category @@ -417,14 +451,14 @@ class PartCategoryAPITest(InvenTreeAPITestCase): name='Non-structural category', description='This is a non-structural category', parent=None, - structural=False + structural=False, ) # Create the test part assigned to a non-structural category part = Part.objects.create( - name="-", - description="Part which category will be changed to structural", - category=non_structural_category + name='-', + description='Part which category will be changed to structural', + category=non_structural_category, ) # Assign the test part to a structural category and make sure it gives an error @@ -441,25 +475,13 @@ class PartCategoryAPITest(InvenTreeAPITestCase): url = reverse('api-part-category-detail', kwargs={'pk': 5}) # First, request without path detail - response = self.get( - url, - { - 'path_detail': False, - }, - expected_code=200 - ) + response = self.get(url, {'path_detail': False}, expected_code=200) # Check that the path detail information is not included self.assertFalse('path' in response.data.keys()) # Now, request *with* path detail - response = self.get( - url, - { - 'path_detail': True, - }, - expected_code=200 - ) + response = self.get(url, {'path_detail': True}, expected_code=200) self.assertTrue('path' in response.data.keys()) @@ -477,9 +499,7 @@ class PartOptionsAPITest(InvenTreeAPITestCase): Ensure that the required field details are provided! """ - roles = [ - 'part.add', - ] + roles = ['part.add'] def test_part(self): """Test the Part API OPTIONS.""" @@ -611,43 +631,23 @@ class PartAPITest(PartAPITestBase): self.assertEqual(len(response.data), 8) # Request top-level part categories only - response = self.get( - url, - { - 'parent': 'null', - } - ) + response = self.get(url, {'parent': 'null'}) self.assertEqual(len(response.data), 2) # Children of PartCategory<1>, cascade - response = self.get( - url, - { - 'parent': 1, - 'cascade': 'true', - } - ) + response = self.get(url, {'parent': 1, 'cascade': 'true'}) self.assertEqual(len(response.data), 5) # Children of PartCategory<1>, do not cascade - response = self.get( - url, - { - 'parent': 1, - 'cascade': 'false', - } - ) + response = self.get(url, {'parent': 1, 'cascade': 'false'}) self.assertEqual(len(response.data), 3) def test_add_categories(self): """Check that we can add categories.""" - data = { - 'name': 'Animals', - 'description': 'All animals go here' - } + data = {'name': 'Animals', 'description': 'All animals go here'} url = reverse('api-part-category-list') response = self.post(url, data) @@ -657,11 +657,7 @@ class PartAPITest(PartAPITestBase): # Add some sub-categories to the top-level 'Animals' category for animal in ['cat', 'dog', 'zebra']: - data = { - 'name': animal, - 'description': 'A sort of animal', - 'parent': parent, - } + data = {'name': animal, 'description': 'A sort of animal', 'parent': parent} response = self.post(url, data) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data['parent'], parent) @@ -714,13 +710,7 @@ class PartAPITest(PartAPITestBase): """Test that we can filter part list by the 'in_bom_for' parameter""" url = reverse('api-part-list') - response = self.get( - url, - { - 'in_bom_for': 100, - }, - expected_code=200, - ) + response = self.get(url, {'in_bom_for': 100}, expected_code=200) self.assertEqual(len(response.data), 4) @@ -734,13 +724,11 @@ class PartAPITest(PartAPITestBase): # Add some relationships PartRelated.objects.create( - part_1=Part.objects.get(pk=1), - part_2=Part.objects.get(pk=2), + part_1=Part.objects.get(pk=1), part_2=Part.objects.get(pk=2) ) PartRelated.objects.create( - part_2=Part.objects.get(pk=1), - part_1=Part.objects.get(pk=3) + part_2=Part.objects.get(pk=1), part_1=Part.objects.get(pk=3) ) response = self.get(url, {'related': 1}, expected_code=200) @@ -752,7 +740,8 @@ class PartAPITest(PartAPITestBase): # First, construct a set of template / variant parts master_part = Part.objects.create( - name='Master', description='Master part which has some variants', + name='Master', + description='Master part which has some variants', category=category, is_template=True, ) @@ -761,33 +750,24 @@ class PartAPITest(PartAPITestBase): variants = [] for color in ['Red', 'Green', 'Blue', 'Yellow', 'Pink', 'Black']: - variants.append(Part.objects.create( - name=f"{color} Variant", description="Variant part with a specific color", - variant_of=master_part, - category=category, - )) + variants.append( + Part.objects.create( + name=f'{color} Variant', + description='Variant part with a specific color', + variant_of=master_part, + category=category, + ) + ) url = reverse('api-part-list') # An invalid part ID will return an error - response = self.get( - url, - { - 'convert_from': 999999, - }, - expected_code=400 - ) + response = self.get(url, {'convert_from': 999999}, expected_code=400) self.assertIn('Select a valid choice', str(response.data['convert_from'])) for variant in variants: - response = self.get( - url, - { - 'convert_from': variant.pk, - }, - expected_code=200, - ) + response = self.get(url, {'convert_from': variant.pk}, expected_code=200) # There should be the same number of results for each request self.assertEqual(len(response.data), 6) @@ -838,11 +818,7 @@ class PartAPITest(PartAPITestBase): # Try to post a new object (missing description) response = self.post( url, - data={ - 'part': 10000, - 'test_name': 'My very first test', - 'required': False, - } + data={'part': 10000, 'test_name': 'My very first test', 'required': False}, ) self.assertEqual(response.status_code, 400) @@ -854,8 +830,8 @@ class PartAPITest(PartAPITestBase): 'part': 10000, 'test_name': 'New Test', 'required': True, - 'description': 'a test description' - } + 'description': 'a test description', + }, ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -863,23 +839,13 @@ class PartAPITest(PartAPITestBase): # Try to post a new test with the same name (should fail) response = self.post( url, - data={ - 'part': 10004, - 'test_name': " newtest", - 'description': 'dafsdf', - } + data={'part': 10004, 'test_name': ' newtest', 'description': 'dafsdf'}, ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) # Try to post a new test against a non-trackable part (should fail) - response = self.post( - url, - data={ - 'part': 1, - 'test_name': 'A simple test', - } - ) + response = self.post(url, data={'part': 1, 'test_name': 'A simple test'}) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @@ -916,46 +882,22 @@ class PartAPITest(PartAPITestBase): url = reverse('api-part-list') - response = self.get( - url, - { - 'variant_of': 10000, - }, - expected_code=200 - ) + response = self.get(url, {'variant_of': 10000}, expected_code=200) # 3 direct children of template part self.assertEqual(len(response.data), 3) - response = self.get( - url, - { - 'ancestor': 10000, - }, - expected_code=200, - ) + response = self.get(url, {'ancestor': 10000}, expected_code=200) # 4 total descendants self.assertEqual(len(response.data), 4) # Use the 'green chair' as our reference - response = self.get( - url, - { - 'variant_of': 10003, - }, - expected_code=200, - ) + response = self.get(url, {'variant_of': 10003}, expected_code=200) self.assertEqual(len(response.data), 1) - response = self.get( - url, - { - 'ancestor': 10003, - }, - expected_code=200, - ) + response = self.get(url, {'ancestor': 10003}, expected_code=200) self.assertEqual(len(response.data), 1) @@ -971,24 +913,12 @@ class PartAPITest(PartAPITestBase): ) # There should still be only one direct variant - response = self.get( - url, - { - 'variant_of': 10003, - }, - expected_code=200, - ) + response = self.get(url, {'variant_of': 10003}, expected_code=200) self.assertEqual(len(response.data), 1) # However, now should be 101 descendants - response = self.get( - url, - { - 'ancestor': 10003, - }, - expected_code=200, - ) + response = self.get(url, {'ancestor': 10003}, expected_code=200) self.assertEqual(len(response.data), 101) @@ -1008,13 +938,7 @@ class PartAPITest(PartAPITestBase): self.assertEqual(response.data['variant_stock'], 0) # Grab a list of all variant chairs *under* the master template - response = self.get( - url, - { - 'ancestor': 10000, - }, - expected_code=200, - ) + response = self.get(url, {'ancestor': 10000}, expected_code=200) # 4 total descendants self.assertEqual(len(response.data), 4) @@ -1024,10 +948,7 @@ class PartAPITest(PartAPITestBase): # Now, let's make some variant stock for variant in Part.objects.get(pk=10000).get_descendants(include_self=False): - StockItem.objects.create( - part=variant, - quantity=100, - ) + StockItem.objects.create(part=variant, quantity=100) response = self.get('/api/part/10000/', {}, expected_code=200) @@ -1035,39 +956,27 @@ class PartAPITest(PartAPITestBase): self.assertEqual(response.data['variant_stock'], 400) # Check that each variant reports the correct stock quantities - response = self.get( - url, - { - 'ancestor': 10000, - }, - expected_code=200, - ) + response = self.get(url, {'ancestor': 10000}, expected_code=200) - expected_variant_stock = { - 10001: 0, - 10002: 0, - 10003: 100, - 10004: 0, - } + expected_variant_stock = {10001: 0, 10002: 0, 10003: 100, 10004: 0} for variant in response.data: self.assertEqual(variant['in_stock'], 100) - self.assertEqual(variant['variant_stock'], expected_variant_stock[variant['pk']]) + self.assertEqual( + variant['variant_stock'], expected_variant_stock[variant['pk']] + ) # Add some 'sub variants' for the green chair variant green_chair = Part.objects.get(pk=10004) for i in range(10): gcv = Part.objects.create( - name=f"GC Var {i}", - description="Green chair variant", + name=f'GC Var {i}', + description='Green chair variant', variant_of=green_chair, ) - StockItem.objects.create( - part=gcv, - quantity=50, - ) + StockItem.objects.create(part=gcv, quantity=50) # Spot check of some values response = self.get('/api/part/10000/', {}) @@ -1095,19 +1004,11 @@ class PartAPITest(PartAPITestBase): 'creation_date', ] - excluded_cols = [ - 'lft', 'rght', 'level', 'tree_id', - 'metadata', - ] + excluded_cols = ['lft', 'rght', 'level', 'tree_id', 'metadata'] with self.download_file( - url, - { - 'export': 'csv', - }, - expected_fn='InvenTree_Parts.csv', + url, {'export': 'csv'}, expected_fn='InvenTree_Parts.csv' ) as file: - data = self.process_csv( file, excluded_cols=excluded_cols, @@ -1139,13 +1040,7 @@ class PartAPITest(PartAPITestBase): date_compare = datetime.fromisoformat('2019-01-01') # Filter by creation date - response = self.get( - url, - { - 'created_before': '2019-01-01', - }, - expected_code=200 - ) + response = self.get(url, {'created_before': '2019-01-01'}, expected_code=200) self.assertTrue(len(response.data) < n) self.assertTrue(len(response.data) > 0) @@ -1156,13 +1051,7 @@ class PartAPITest(PartAPITestBase): date = datetime.fromisoformat(item['creation_date']) self.assertLessEqual(date, date_compare) - response = self.get( - url, - { - 'created_after': '2019-01-01', - }, - expected_code=200 - ) + response = self.get(url, {'created_after': '2019-01-01'}, expected_code=200) self.assertTrue(len(response.data) < n) self.assertTrue(len(response.data) > 0) @@ -1187,11 +1076,7 @@ class PartCreationTests(PartAPITestBase): response = self.post( url, - { - 'name': 'all defaults', - 'description': 'my test part', - 'category': 1, - }, + {'name': 'all defaults', 'description': 'my test part', 'category': 1}, expected_code=201, ) @@ -1205,19 +1090,11 @@ class PartCreationTests(PartAPITestBase): self.assertTrue(data['purchaseable']) # Set the default 'purchaseable' status to True - InvenTreeSetting.set_setting( - 'PART_PURCHASEABLE', - True, - self.user - ) + InvenTreeSetting.set_setting('PART_PURCHASEABLE', True, self.user) response = self.post( url, - { - 'name': 'all defaults 2', - 'description': 'my test part 2', - 'category': 1, - }, + {'name': 'all defaults 2', 'description': 'my test part 2', 'category': 1}, expected_code=201, ) @@ -1234,7 +1111,7 @@ class PartCreationTests(PartAPITestBase): 'active': False, 'purchaseable': False, }, - expected_code=201 + expected_code=201, ) self.assertFalse(response.data['active']) @@ -1254,9 +1131,7 @@ class PartCreationTests(PartAPITestBase): data['initial_stock'] = stock_data response = self.post( - reverse('api-part-list'), - data, - expected_code=expected_code + reverse('api-part-list'), data, expected_code=expected_code ) return response.data @@ -1266,19 +1141,18 @@ class PartCreationTests(PartAPITestBase): # Submit with empty data response = submit({}, expected_code=400) - self.assertIn('This field is required', str(response['initial_stock']['quantity'])) + self.assertIn( + 'This field is required', str(response['initial_stock']['quantity']) + ) # Submit with invalid quantity - response = submit({ - 'quantity': 'ax', - }, expected_code=400) - self.assertIn('A valid number is required', str(response['initial_stock']['quantity'])) + response = submit({'quantity': 'ax'}, expected_code=400) + self.assertIn( + 'A valid number is required', str(response['initial_stock']['quantity']) + ) # Submit with valid data - response = submit({ - 'quantity': 50, - 'location': 1, - }, expected_code=201) + response = submit({'quantity': 50, 'location': 1}, expected_code=201) part = Part.objects.get(pk=response['pk']) self.assertEqual(part.total_stock, 50) @@ -1298,9 +1172,7 @@ class PartCreationTests(PartAPITestBase): data['initial_supplier'] = supplier_data response = self.post( - reverse('api-part-list'), - data, - expected_code=expected_code + reverse('api-part-list'), data, expected_code=expected_code ) return response.data @@ -1310,46 +1182,41 @@ class PartCreationTests(PartAPITestBase): n_sp = company.models.SupplierPart.objects.count() # Submit with an invalid manufacturer - response = submit({ - 'manufacturer': 99999, - }) + response = submit({'manufacturer': 99999}) - self.assertIn('object does not exist', str(response['initial_supplier']['manufacturer'])) + self.assertIn( + 'object does not exist', str(response['initial_supplier']['manufacturer']) + ) - response = submit({ - 'manufacturer': 8 - }) + response = submit({'manufacturer': 8}) - self.assertIn('Selected company is not a valid manufacturer', str(response['initial_supplier']['manufacturer'])) + self.assertIn( + 'Selected company is not a valid manufacturer', + str(response['initial_supplier']['manufacturer']), + ) # Submit with an invalid supplier - response = submit({ - 'supplier': 8, - }) + response = submit({'supplier': 8}) - self.assertIn('Selected company is not a valid supplier', str(response['initial_supplier']['supplier'])) + self.assertIn( + 'Selected company is not a valid supplier', + str(response['initial_supplier']['supplier']), + ) # Test for duplicate MPN - response = submit({ - 'manufacturer': 6, - 'mpn': 'MPN123', - }) + response = submit({'manufacturer': 6, 'mpn': 'MPN123'}) - self.assertIn('Manufacturer part matching this MPN already exists', str(response)) + self.assertIn( + 'Manufacturer part matching this MPN already exists', str(response) + ) # Test for duplicate SKU - response = submit({ - 'supplier': 2, - 'sku': 'MPN456-APPEL', - }) + response = submit({'supplier': 2, 'sku': 'MPN456-APPEL'}) self.assertIn('Supplier part matching this SKU already exists', str(response)) # Test fields which are too long - response = submit({ - 'sku': 'abc' * 100, - 'mpn': 'xyz' * 100, - }) + response = submit({'sku': 'abc' * 100, 'mpn': 'xyz' * 100}) too_long = 'Ensure this field has no more than 100 characters' @@ -1358,13 +1225,8 @@ class PartCreationTests(PartAPITestBase): # Finally, submit a valid set of information response = submit( - { - 'supplier': 2, - 'sku': 'ABCDEFG', - 'manufacturer': 6, - 'mpn': 'QWERTY' - }, - expected_code=201 + {'supplier': 2, 'sku': 'ABCDEFG', 'manufacturer': 6, 'mpn': 'QWERTY'}, + expected_code=201, ) self.assertEqual(n_part + 1, Part.objects.count()) @@ -1375,14 +1237,10 @@ class PartCreationTests(PartAPITestBase): """Test that non-standard ASCII chars are accepted.""" url = reverse('api-part-list') - name = "Kaltgerätestecker" - description = "Gerät Kaltgerätestecker strange chars should get through" + name = 'Kaltgerätestecker' + description = 'Gerät Kaltgerätestecker strange chars should get through' - data = { - "name": name, - "description": description, - "category": 2 - } + data = {'name': name, 'description': description, 'category': 2} response = self.post(url, data, expected_code=201) @@ -1406,7 +1264,7 @@ class PartCreationTests(PartAPITestBase): 'copy_bom': bom, 'copy_image': img, 'copy_parameters': params, - } + }, }, expected_code=201, ) @@ -1426,7 +1284,7 @@ class PartCreationTests(PartAPITestBase): PartCategoryParameterTemplate.objects.create( parameter_template=PartParameterTemplate.objects.get(pk=pk), category=cat, - default_value=f"Value {pk}" + default_value=f'Value {pk}', ) self.assertEqual(cat.parameter_templates.count(), 3) @@ -1487,7 +1345,7 @@ class PartDetailTests(PartAPITestBase): 'description': 'a part created with the API', 'category': 1, 'tags': '["tag1", "tag2"]', - } + }, ) self.assertEqual(response.status_code, 201) @@ -1508,9 +1366,7 @@ class PartDetailTests(PartAPITestBase): # Let's change the name of the part - response = self.patch(url, { - 'name': 'a new better name', - }) + response = self.patch(url, {'name': 'a new better name'}) self.assertEqual(response.status_code, 200) self.assertEqual(response.data['pk'], pk) @@ -1526,16 +1382,12 @@ class PartDetailTests(PartAPITestBase): # Now, try to set the name to the *same* value # 2021-06-22 this test is to check that the "duplicate part" checks don't do strange things - response = self.patch(url, { - 'name': 'a new better name', - }) + response = self.patch(url, {'name': 'a new better name'}) self.assertEqual(response.status_code, 200) # Try to remove a tag - response = self.patch(url, { - 'tags': ['tag1',], - }) + response = self.patch(url, {'tags': ['tag1']}) self.assertEqual(response.status_code, 200) self.assertEqual(response.data['tags'], ['tag1']) @@ -1557,34 +1409,38 @@ class PartDetailTests(PartAPITestBase): def test_duplicates(self): """Check that trying to create 'duplicate' parts results in errors.""" # Create a part - response = self.post(reverse('api-part-list'), { - 'name': 'part', - 'description': 'description', - 'IPN': 'IPN-123', - 'category': 1, - 'revision': 'A', - }) + response = self.post( + reverse('api-part-list'), + { + 'name': 'part', + 'description': 'description', + 'IPN': 'IPN-123', + 'category': 1, + 'revision': 'A', + }, + ) self.assertEqual(response.status_code, 201) n = Part.objects.count() # Check that we cannot create a duplicate in a different category - response = self.post(reverse('api-part-list'), { - 'name': 'part', - 'description': 'description', - 'IPN': 'IPN-123', - 'category': 2, - 'revision': 'A', - }) + response = self.post( + reverse('api-part-list'), + { + 'name': 'part', + 'description': 'description', + 'IPN': 'IPN-123', + 'category': 2, + 'revision': 'A', + }, + ) self.assertEqual(response.status_code, 400) # Check that only 1 matching part exists parts = Part.objects.filter( - name='part', - description='description', - IPN='IPN-123' + name='part', description='description', IPN='IPN-123' ) self.assertEqual(parts.count(), 1) @@ -1593,13 +1449,16 @@ class PartDetailTests(PartAPITestBase): self.assertEqual(Part.objects.count(), n) # But a different 'revision' *can* be created - response = self.post(reverse('api-part-list'), { - 'name': 'part', - 'description': 'description', - 'IPN': 'IPN-123', - 'category': 2, - 'revision': 'B', - }) + response = self.post( + reverse('api-part-list'), + { + 'name': 'part', + 'description': 'description', + 'IPN': 'IPN-123', + 'category': 2, + 'revision': 'B', + }, + ) self.assertEqual(response.status_code, 201) self.assertEqual(Part.objects.count(), n + 1) @@ -1610,22 +1469,12 @@ class PartDetailTests(PartAPITestBase): url = reverse('api-part-detail', kwargs={'pk': pk}) # Attempt to alter the revision code - response = self.patch( - url, - { - 'revision': 'A', - } - ) + response = self.patch(url, {'revision': 'A'}) self.assertEqual(response.status_code, 400) # But we *can* change it to a unique revision code - response = self.patch( - url, - { - 'revision': 'C', - } - ) + response = self.patch(url, {'revision': 'C'}) self.assertEqual(response.status_code, 200) @@ -1636,12 +1485,8 @@ class PartDetailTests(PartAPITestBase): # Create a new part response = self.post( reverse('api-part-list'), - { - 'name': 'imagine', - 'description': 'All the people', - 'category': 1, - }, - expected_code=201 + {'name': 'imagine', 'description': 'All the people', 'category': 1}, + expected_code=201, ) pk = response.data['pk'] @@ -1660,11 +1505,7 @@ class PartDetailTests(PartAPITestBase): with open('dummy_image.txt', 'rb') as dummy_image: response = self.upload_client.patch( - url, - { - 'image': dummy_image, - }, - format='multipart', + url, {'image': dummy_image}, format='multipart' ) self.assertEqual(response.status_code, 400) @@ -1679,11 +1520,7 @@ class PartDetailTests(PartAPITestBase): with open(fn, 'rb') as dummy_image: response = self.upload_client.patch( - url, - { - 'image': dummy_image, - }, - format='multipart', + url, {'image': dummy_image}, format='multipart' ) self.assertEqual(response.status_code, 200) @@ -1705,10 +1542,7 @@ class PartDetailTests(PartAPITestBase): with open(fn, 'rb') as img_file: response = self.upload_client.patch( - reverse('api-part-detail', kwargs={'pk': p.pk}), - { - 'image': img_file, - }, + reverse('api-part-detail', kwargs={'pk': p.pk}), {'image': img_file} ) self.assertEqual(response.status_code, 200) @@ -1724,7 +1558,7 @@ class PartDetailTests(PartAPITestBase): 'category': 1, 'existing_image': 'does_not_exist.png', }, - expected_code=400 + expected_code=400, ) # Now, create a new part and assign the same image @@ -1734,7 +1568,7 @@ class PartDetailTests(PartAPITestBase): 'name': 'New part', 'description': 'New part description', 'category': 1, - 'existing_image': image_name.split(os.path.sep)[-1] + 'existing_image': image_name.split(os.path.sep)[-1], }, expected_code=201, ) @@ -1751,8 +1585,7 @@ class PartDetailTests(PartAPITestBase): # How many parts are 'on order' for this part? lines = order.models.PurchaseOrderLineItem.objects.filter( - part__part__pk=1, - order__status__in=PurchaseOrderStatusGroups.OPEN, + part__part__pk=1, order__status__in=PurchaseOrderStatusGroups.OPEN ) on_order = 0 @@ -1774,9 +1607,7 @@ class PartDetailTests(PartAPITestBase): """Check that path_detail can be requested against the serializer""" response = self.get( reverse('api-part-detail', kwargs={'pk': 1}), - { - 'path_detail': True, - }, + {'path_detail': True}, expected_code=200, ) @@ -1789,13 +1620,7 @@ class PartListTests(PartAPITestBase): def test_query_count(self): """Test that the query count is unchanged, independent of query results""" - queries = [ - {'limit': 1}, - {'limit': 10}, - {'limit': 50}, - {'category': 1}, - {}, - ] + queries = [{'limit': 1}, {'limit': 10}, {'limit': 50}, {'category': 1}, {}] url = reverse('api-part-list') @@ -1803,17 +1628,20 @@ class PartListTests(PartAPITestBase): parts = [] for ii in range(100): - parts.append(Part( - name=f"Extra part {ii}", - description="A new part which will appear via the API", - level=0, tree_id=0, - lft=0, rght=0, - )) + parts.append( + Part( + name=f'Extra part {ii}', + description='A new part which will appear via the API', + level=0, + tree_id=0, + lft=0, + rght=0, + ) + ) Part.objects.bulk_create(parts) for query in queries: - with CaptureQueriesContext(connection) as ctx: self.get(url, query, expected_code=200) @@ -1824,9 +1652,7 @@ class PartListTests(PartAPITestBase): for b in [False, True]: with CaptureQueriesContext(connection) as ctx: results = self.get( - reverse('api-part-list'), - {'category_detail': b}, - expected_code=200 + reverse('api-part-list'), {'category_detail': b}, expected_code=200 ) for result in results.data: @@ -1840,32 +1666,21 @@ class PartListTests(PartAPITestBase): class PartNotesTests(InvenTreeAPITestCase): """Tests for the 'notes' field (markdown field)""" - fixtures = [ - 'category', - 'part', - 'location', - 'company', - ] + fixtures = ['category', 'part', 'location', 'company'] - roles = [ - 'part.change', - 'part.add', - ] + roles = ['part.change', 'part.add'] def test_long_notes(self): """Test that very long notes field is rejected""" # Ensure that we cannot upload a very long piece of text url = reverse('api-part-detail', kwargs={'pk': 1}) - response = self.patch( - url, - { - 'notes': 'abcde' * 10001 - }, - expected_code=400 - ) + response = self.patch(url, {'notes': 'abcde' * 10001}, expected_code=400) - self.assertIn('Ensure this field has no more than 50000 characters', str(response.data['notes'])) + self.assertIn( + 'Ensure this field has no more than 50000 characters', + str(response.data['notes']), + ) def test_multiline_formatting(self): """Ensure that markdown formatting is retained""" @@ -1882,13 +1697,7 @@ class PartNotesTests(InvenTreeAPITestCase): """ - response = self.patch( - url, - { - 'notes': notes, - }, - expected_code=200 - ) + response = self.patch(url, {'notes': notes}, expected_code=200) # Ensure that newline chars have not been removed self.assertIn('\n', response.data['notes']) @@ -1900,15 +1709,9 @@ class PartNotesTests(InvenTreeAPITestCase): class PartPricingDetailTests(InvenTreeAPITestCase): """Tests for the part pricing API endpoint""" - fixtures = [ - 'category', - 'part', - 'location', - ] + fixtures = ['category', 'part', 'location'] - roles = [ - 'part.change', - ] + roles = ['part.change'] def url(self, pk): """Construct a pricing URL""" @@ -1916,10 +1719,7 @@ class PartPricingDetailTests(InvenTreeAPITestCase): def test_pricing_detail(self): """Test an empty pricing detail""" - response = self.get( - self.url(1), - expected_code=200 - ) + response = self.get(self.url(1), expected_code=200) # Check for expected fields expected_fields = [ @@ -1961,10 +1761,7 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): 'sales_order', ] - roles = [ - 'part.view', - 'part.change', - ] + roles = ['part.view', 'part.change'] @classmethod def setUpTestData(cls): @@ -1989,7 +1786,9 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): StockItem.objects.create(part=cls.part, quantity=300) # Now create another 400 units which are LOST - StockItem.objects.create(part=cls.part, quantity=400, status=StockStatus.LOST.value) + StockItem.objects.create( + part=cls.part, quantity=400, status=StockStatus.LOST.value + ) def get_part_data(self): """Helper function for retrieving part data""" @@ -2056,26 +1855,21 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): # Now, let's create a sales order, and allocate some stock so = order.models.SalesOrder.objects.create( - reference='001', - customer=Company.objects.get(pk=1), + reference='001', customer=Company.objects.get(pk=1) ) # We wish to send 50 units of "Bob" against this sales order line = order.models.SalesOrderLineItem.objects.create( - quantity=50, - order=so, - part=part, + quantity=50, order=so, part=part ) # Create a shipment against the order shipment_1 = order.models.SalesOrderShipment.objects.create( - order=so, - reference='001', + order=so, reference='001' ) shipment_2 = order.models.SalesOrderShipment.objects.create( - order=so, - reference='002', + order=so, reference='002' ) # Allocate stock items to this order, against multiple shipments @@ -2083,14 +1877,14 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): line=line, shipment=shipment_1, item=StockItem.objects.get(pk=1007), - quantity=17 + quantity=17, ) order.models.SalesOrderAllocation.objects.create( line=line, shipment=shipment_1, item=StockItem.objects.get(pk=1008), - quantity=18 + quantity=18, ) order.models.SalesOrderAllocation.objects.create( @@ -2129,16 +1923,11 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): bom_item = BomItem.objects.get(pk=6) - line = build.models.BuildLine.objects.get( - bom_item=bom_item, - build=bo, - ) + line = build.models.BuildLine.objects.get(bom_item=bom_item, build=bo) # Allocate multiple stock items against this build order build.models.BuildItem.objects.create( - build_line=line, - stock_item=StockItem.objects.get(pk=1000), - quantity=10, + build_line=line, stock_item=StockItem.objects.get(pk=1000), quantity=10 ) # Request data once more @@ -2157,9 +1946,7 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): # Allocate further stock against the build build.models.BuildItem.objects.create( - build_line=line, - stock_item=StockItem.objects.get(pk=1001), - quantity=10, + build_line=line, stock_item=StockItem.objects.get(pk=1001), quantity=10 ) # Request data once more @@ -2183,24 +1970,20 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): and should also use the 'pack_size' of the supplier part objects. """ supplier = Company.objects.create( - name='Paint Supplies', - description='A supplier of paints', - is_supplier=True + name='Paint Supplies', description='A supplier of paints', is_supplier=True ) # First, create some parts paint = PartCategory.objects.create( - parent=None, - name="Paint", - description="Paints and such", + parent=None, name='Paint', description='Paints and such' ) for color in ['Red', 'Green', 'Blue', 'Orange', 'Yellow']: p = Part.objects.create( category=paint, units='litres', - name=f"{color} Paint", - description=f"Paint which is {color} in color" + name=f'{color} Paint', + description=f'Paint which is {color} in color', ) # Create multiple supplier parts in different sizes @@ -2208,7 +1991,7 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): sp = SupplierPart.objects.create( part=p, supplier=supplier, - SKU=f"PNT-{color}-{pk_sz}L", + SKU=f'PNT-{color}-{pk_sz}L', pack_quantity=str(pk_sz), ) @@ -2222,15 +2005,12 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): # Create multiple orders for _ii in range(5): - po = order.models.PurchaseOrder.objects.create( - supplier=supplier, - description='ordering some paint', + supplier=supplier, description='ordering some paint' ) # Order an assortment of items for sp in supplier_parts: - # Generate random quantity to order quantity = randint(10, 20) @@ -2239,19 +2019,12 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): # Add a line item item = order.models.PurchaseOrderLineItem.objects.create( - part=sp, - order=po, - quantity=quantity, - received=received, + part=sp, order=po, quantity=quantity, received=received ) # Now grab a list of parts from the API response = self.get( - reverse('api-part-list'), - { - 'category': paint.pk, - }, - expected_code=200, + reverse('api-part-list'), {'category': paint.pk}, expected_code=200 ) # Check that the correct number of items have been returned @@ -2283,20 +2056,9 @@ class PartAPIAggregationTest(InvenTreeAPITestCase): class BomItemTest(InvenTreeAPITestCase): """Unit tests for the BomItem API.""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock', - 'bom', - 'company', - ] + fixtures = ['category', 'part', 'location', 'stock', 'bom', 'company'] - roles = [ - 'part.add', - 'part.change', - 'part.delete', - ] + roles = ['part.add', 'part.change', 'part.delete'] def setUp(self): """Set up the test case""" @@ -2315,34 +2077,16 @@ class BomItemTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), n) # Now, filter by part - response = self.get( - url, - data={ - 'part': 100, - }, - expected_code=200 - ) + response = self.get(url, data={'part': 100}, expected_code=200) # Filter by "validated" - response = self.get( - url, - data={ - 'validated': True, - }, - expected_code=200, - ) + response = self.get(url, data={'validated': True}, expected_code=200) # Should be zero validated results self.assertEqual(len(response.data), 0) # Now filter by "not validated" - response = self.get( - url, - data={ - 'validated': False, - }, - expected_code=200 - ) + response = self.get(url, data={'validated': False}, expected_code=200) # There should be at least one non-validated item self.assertTrue(len(response.data) > 0) @@ -2352,13 +2096,7 @@ class BomItemTest(InvenTreeAPITestCase): bom_item.validate_hash() - response = self.get( - url, - data={ - 'validated': True, - }, - expected_code=200, - ) + response = self.get(url, data={'validated': True}, expected_code=200) # Check that the expected response is returned self.assertEqual(len(response.data), 1) @@ -2366,7 +2104,6 @@ class BomItemTest(InvenTreeAPITestCase): # Each item in response should contain expected keys for el in response.data: - for key in ['available_stock', 'available_substitute_stock']: self.assertTrue(key in el) @@ -2379,37 +2116,19 @@ class BomItemTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 6) # Limit the results with a search term - response = self.get( - url, - { - 'search': '0805', - }, - expected_code=200, - ) + response = self.get(url, {'search': '0805'}, expected_code=200) self.assertEqual(len(response.data), 3) # Search by 'reference' field for q in ['ABCDE', 'LMNOP', 'VWXYZ']: - response = self.get( - url, - { - 'search': q, - }, - expected_code=200 - ) + response = self.get(url, {'search': q}, expected_code=200) self.assertEqual(len(response.data), 1) self.assertEqual(response.data[0]['reference'], q) # Search by nonsense data - response = self.get( - url, - { - 'search': 'xxxxxxxxxxxxxxxxx', - }, - expected_code=200 - ) + response = self.get(url, {'search': 'xxxxxxxxxxxxxxxxx'}, expected_code=200) self.assertEqual(len(response.data), 0) @@ -2418,10 +2137,7 @@ class BomItemTest(InvenTreeAPITestCase): url = reverse('api-bom-list') # Order by increasing quantity - response = self.get( - f"{url}?ordering=+quantity", - expected_code=200 - ) + response = self.get(f'{url}?ordering=+quantity', expected_code=200) self.assertEqual(len(response.data), 6) @@ -2431,34 +2147,21 @@ class BomItemTest(InvenTreeAPITestCase): self.assertTrue(q1 < q2) # Order by decreasing quantity - response = self.get( - f"{url}?ordering=-quantity", - expected_code=200, - ) + response = self.get(f'{url}?ordering=-quantity', expected_code=200) self.assertEqual(q1, response.data[-1]['quantity']) self.assertEqual(q2, response.data[0]['quantity']) # Now test ordering by 'sub_part' (which is actually 'sub_part__name') response = self.get( - url, - { - 'ordering': 'sub_part', - 'sub_part_detail': True, - }, - expected_code=200, + url, {'ordering': 'sub_part', 'sub_part_detail': True}, expected_code=200 ) n1 = response.data[0]['sub_part_detail']['name'] n2 = response.data[-1]['sub_part_detail']['name'] response = self.get( - url, - { - 'ordering': '-sub_part', - 'sub_part_detail': True, - }, - expected_code=200, + url, {'ordering': '-sub_part', 'sub_part_detail': True}, expected_code=200 ) self.assertEqual(n1, response.data[-1]['sub_part_detail']['name']) @@ -2507,11 +2210,7 @@ class BomItemTest(InvenTreeAPITestCase): """Test that we can create a new BomItem via the API.""" url = reverse('api-bom-list') - data = { - 'part': 100, - 'sub_part': 4, - 'quantity': 777, - } + data = {'part': 100, 'sub_part': 4, 'quantity': 777} self.post(url, data, expected_code=201) @@ -2537,13 +2236,7 @@ class BomItemTest(InvenTreeAPITestCase): sub_part.save() # How many stock items are initially available for this part? - response = self.get( - stock_url, - { - 'bom_item': bom_item.pk, - }, - expected_code=200 - ) + response = self.get(stock_url, {'bom_item': bom_item.pk}, expected_code=200) n_items = len(response.data) self.assertEqual(n_items, 2) @@ -2552,13 +2245,12 @@ class BomItemTest(InvenTreeAPITestCase): # Now we will create some variant parts and stock for ii in range(5): - # Create a variant part! variant = Part.objects.create( - name=f"Variant_{ii}", - description="A variant part, with a description", + name=f'Variant_{ii}', + description='A variant part, with a description', component=True, - variant_of=sub_part + variant_of=sub_part, ) variant.save() @@ -2567,23 +2259,13 @@ class BomItemTest(InvenTreeAPITestCase): # Create some stock items for this new part for _ in range(ii): - StockItem.objects.create( - part=variant, - location=loc, - quantity=100 - ) + StockItem.objects.create(part=variant, location=loc, quantity=100) # Keep track of running total n_items += ii # Now, there should be more stock items available! - response = self.get( - stock_url, - { - 'bom_item': bom_item.pk, - }, - expected_code=200 - ) + response = self.get(stock_url, {'bom_item': bom_item.pk}, expected_code=200) self.assertEqual(len(response.data), n_items) @@ -2592,13 +2274,7 @@ class BomItemTest(InvenTreeAPITestCase): bom_item.save() # There should now only be 2 stock items available again - response = self.get( - stock_url, - { - 'bom_item': bom_item.pk, - }, - expected_code=200 - ) + response = self.get(stock_url, {'bom_item': bom_item.pk}, expected_code=200) self.assertEqual(len(response.data), 2) @@ -2619,13 +2295,7 @@ class BomItemTest(InvenTreeAPITestCase): bom_item = BomItem.objects.get(pk=1) # Filter stock items which can be assigned against this stock item - response = self.get( - stock_url, - { - "bom_item": bom_item.pk, - }, - expected_code=200 - ) + response = self.get(stock_url, {'bom_item': bom_item.pk}, expected_code=200) n_items = len(response.data) @@ -2634,38 +2304,25 @@ class BomItemTest(InvenTreeAPITestCase): # Let's make some! for ii in range(5): sub_part = Part.objects.create( - name=f"Substitute {ii}", - description="A substitute part", + name=f'Substitute {ii}', + description='A substitute part', component=True, is_template=False, - assembly=False + assembly=False, ) # Create a new StockItem for this Part - StockItem.objects.create( - part=sub_part, - quantity=1000, - location=loc, - ) + StockItem.objects.create(part=sub_part, quantity=1000, location=loc) # Now, create an "alternative" for the BOM Item - BomItemSubstitute.objects.create( - bom_item=bom_item, - part=sub_part - ) + BomItemSubstitute.objects.create(bom_item=bom_item, part=sub_part) # We should be able to filter the API list to just return this new part response = self.get(url, data={'part': sub_part.pk}, expected_code=200) self.assertEqual(len(response.data), 1) # We should also have more stock available to allocate against this BOM item! - response = self.get( - stock_url, - { - "bom_item": bom_item.pk, - }, - expected_code=200 - ) + response = self.get(stock_url, {'bom_item': bom_item.pk}, expected_code=200) self.assertEqual(len(response.data), n_items + ii + 1) @@ -2676,7 +2333,7 @@ class BomItemTest(InvenTreeAPITestCase): # The BomItem detail endpoint should now also reflect the substitute data data = self.get( reverse('api-bom-item-detail', kwargs={'pk': bom_item.pk}), - expected_code=200 + expected_code=200, ).data # 5 substitute parts @@ -2698,10 +2355,10 @@ class BomItemTest(InvenTreeAPITestCase): for i in range(5): assy = Part.objects.create( - name=f"Assy_{i}", - description="An assembly made of other parts", + name=f'Assy_{i}', + description='An assembly made of other parts', active=True, - assembly=True + assembly=True, ) assemblies.append(assy) @@ -2710,31 +2367,20 @@ class BomItemTest(InvenTreeAPITestCase): # Create some sub-components for i in range(5): - cmp = Part.objects.create( - name=f"Component_{i}", - description="A sub component", + name=f'Component_{i}', + description='A sub component', active=True, - component=True + component=True, ) for j in range(i): # Create a BOM item - BomItem.objects.create( - quantity=10, - part=assemblies[j], - sub_part=cmp, - ) + BomItem.objects.create(quantity=10, part=assemblies[j], sub_part=cmp) components.append(cmp) - response = self.get( - url, - { - 'uses': cmp.pk, - }, - expected_code=200, - ) + response = self.get(url, {'uses': cmp.pk}, expected_code=200) self.assertEqual(len(response.data), i) @@ -2757,16 +2403,13 @@ class BomItemTest(InvenTreeAPITestCase): for i in range(10): # Create a variant part vp = Part.objects.create( - name=f"Var {i}", - description="Variant part description field", + name=f'Var {i}', + description='Variant part description field', variant_of=bom_item.sub_part, ) # Create a stock item - StockItem.objects.create( - part=vp, - quantity=100, - ) + StockItem.objects.create(part=vp, quantity=100) # There should now be variant stock available response = self.get('/api/bom/1/', {}, expected_code=200) @@ -2777,45 +2420,26 @@ class BomItemTest(InvenTreeAPITestCase): class PartAttachmentTest(InvenTreeAPITestCase): """Unit tests for the PartAttachment API endpoint""" - fixtures = [ - 'category', - 'part', - 'location', - ] + fixtures = ['category', 'part', 'location'] def test_add_attachment(self): """Test that we can create a new PartAttachment via the API""" url = reverse('api-part-attachment-list') # Upload without permission - response = self.post( - url, - {}, - expected_code=403, - ) + response = self.post(url, {}, expected_code=403) # Add required permission self.assignRole('part.add') # Upload without specifying part (will fail) - response = self.post( - url, - { - 'comment': 'Hello world', - }, - expected_code=400 - ) + response = self.post(url, {'comment': 'Hello world'}, expected_code=400) self.assertIn('This field is required', str(response.data['part'])) # Upload without file OR link (will fail) response = self.post( - url, - { - 'part': 1, - 'comment': 'Hello world', - }, - expected_code=400 + url, {'part': 1, 'comment': 'Hello world'}, expected_code=400 ) self.assertIn('Missing file', str(response.data['attachment'])) @@ -2823,12 +2447,7 @@ class PartAttachmentTest(InvenTreeAPITestCase): # Upload an invalid link (will fail) response = self.post( - url, - { - 'part': 1, - 'link': 'not-a-link.py', - }, - expected_code=400 + url, {'part': 1, 'link': 'not-a-link.py'}, expected_code=400 ) self.assertIn('Enter a valid URL', str(response.data['link'])) @@ -2837,13 +2456,7 @@ class PartAttachmentTest(InvenTreeAPITestCase): # Upload a valid link (will pass) response = self.post( - url, - { - 'part': 1, - 'link': link, - 'comment': 'Hello world', - }, - expected_code=201 + url, {'part': 1, 'link': link, 'comment': 'Hello world'}, expected_code=201 ) data = response.data @@ -2889,29 +2502,17 @@ class PartInternalPriceBreakTest(InvenTreeAPITestCase): (1.5, 90.999999999), (1.756, 89), (2, 86), - (25, 80) + (25, 80), ] for q, p in breaks: data = self.post( - url, - { - 'part': 1, - 'quantity': q, - 'price': p, - }, - expected_code=201 + url, {'part': 1, 'quantity': q, 'price': p}, expected_code=201 ).data self.assertEqual(data['part'], 1) - self.assertEqual( - round(Decimal(data['quantity']), 4), - round(Decimal(q), 4) - ) - self.assertEqual( - round(Decimal(data['price']), 4), - round(Decimal(p), 4) - ) + self.assertEqual(round(Decimal(data['quantity']), 4), round(Decimal(q), 4)) + self.assertEqual(round(Decimal(data['price']), 4), round(Decimal(p), 4)) # Now, ensure that we can delete the Part via the API # In particular this test checks that there are no circular post_delete relationships @@ -2922,7 +2523,7 @@ class PartInternalPriceBreakTest(InvenTreeAPITestCase): p.active = False p.save() - response = self.delete(reverse("api-part-detail", kwargs={"pk": 1})) + response = self.delete(reverse('api-part-detail', kwargs={'pk': 1})) self.assertEqual(response.status_code, 204) with self.assertRaises(Part.DoesNotExist): @@ -2935,12 +2536,7 @@ class PartStocktakeTest(InvenTreeAPITestCase): superuser = False is_staff = False - fixtures = [ - 'category', - 'part', - 'location', - 'stock', - ] + fixtures = ['category', 'part', 'location', 'stock'] def test_list_endpoint(self): """Test the list endpoint for the stocktake data""" @@ -2957,19 +2553,13 @@ class PartStocktakeTest(InvenTreeAPITestCase): # Iterate over (up to) 5 parts in the database for p in Part.objects.all()[:5]: - # Create some entries to_create = [] n = p.pk % 10 for idx in range(n): - to_create.append( - PartStocktake( - part=p, - quantity=(idx + 1) * 100, - ) - ) + to_create.append(PartStocktake(part=p, quantity=(idx + 1) * 100)) total += 1 @@ -2977,13 +2567,7 @@ class PartStocktakeTest(InvenTreeAPITestCase): PartStocktake.objects.bulk_create(to_create) # Query list endpoint - response = self.get( - url, - { - 'part': p.pk, - }, - expected_code=200, - ) + response = self.get(url, {'part': p.pk}, expected_code=200) # Check that the expected number of PartStocktake instances has been created self.assertEqual(len(response.data), n) @@ -3001,20 +2585,15 @@ class PartStocktakeTest(InvenTreeAPITestCase): self.assignRole('stocktake.view') for p in Part.objects.all(): - # Initially no stocktake information available self.assertIsNone(p.latest_stocktake) - note = f"Note {p.pk}" + note = f'Note {p.pk}' quantity = p.pk + 5 self.post( url, - { - 'part': p.pk, - 'quantity': quantity, - 'note': note, - }, + {'part': p.pk, 'quantity': quantity, 'note': note}, expected_code=201, ) @@ -3042,25 +2621,13 @@ class PartStocktakeTest(InvenTreeAPITestCase): self.get(url, expected_code=200) # Try to edit data - self.patch( - url, - { - 'note': 'Another edit', - }, - expected_code=403 - ) + self.patch(url, {'note': 'Another edit'}, expected_code=403) # Assign 'edit' role permission self.assignRole('stocktake.change') # Try again - self.patch( - url, - { - 'note': 'Editing note field again', - }, - expected_code=200, - ) + self.patch(url, {'note': 'Editing note field again'}, expected_code=200) # Try to delete self.delete(url, expected_code=403) @@ -3090,7 +2657,9 @@ class PartStocktakeTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), n) # Stocktake report should be available via the API, also - response = self.get(reverse('api-part-stocktake-report-list'), expected_code=200) + response = self.get( + reverse('api-part-stocktake-report-list'), expected_code=200 + ) self.assertEqual(len(response.data), 1) @@ -3137,10 +2706,7 @@ class PartMetadataAPITest(InvenTreeAPITestCase): 'stock', ] - roles = [ - 'part.change', - 'part_category.change', - ] + roles = ['part.change', 'part_category.change'] def setUp(self): """Setup unit tets""" @@ -3163,17 +2729,15 @@ class PartMetadataAPITest(InvenTreeAPITestCase): self.patch( url, - { - 'metadata': { - f'abc-{numstr}': f'xyz-{apikey}-{numstr}', - } - }, - expected_code=200 + {'metadata': {f'abc-{numstr}': f'xyz-{apikey}-{numstr}'}}, + expected_code=200, ) # Refresh modeldata.refresh_from_db() - self.assertEqual(modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}') + self.assertEqual( + modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}' + ) def test_metadata(self): """Test all endpoints""" @@ -3196,9 +2760,7 @@ class PartSchedulingTest(PartAPITestBase): def test_get_schedule(self): """Test that the scheduling endpoint returns OK""" - part_ids = [ - 1, 3, 100, 101, - ] + part_ids = [1, 3, 100, 101] for pk in part_ids: url = reverse('api-part-scheduling', kwargs={'pk': pk}) diff --git a/InvenTree/part/test_bom_export.py b/InvenTree/part/test_bom_export.py index 1027e0c89e..46fcd199fe 100644 --- a/InvenTree/part/test_bom_export.py +++ b/InvenTree/part/test_bom_export.py @@ -11,12 +11,7 @@ from InvenTree.unit_test import InvenTreeTestCase class BomExportTest(InvenTreeTestCase): """Class for performing unit testing of BOM export functionality""" - fixtures = [ - 'category', - 'part', - 'location', - 'bom', - ] + fixtures = ['category', 'part', 'location', 'bom'] roles = 'all' @@ -37,7 +32,7 @@ class BomExportTest(InvenTreeTestCase): self.assertEqual(response.status_code, 200) self.assertEqual( response.headers['Content-Disposition'], - 'attachment; filename="InvenTree_BOM_Template.xls"' + 'attachment; filename="InvenTree_BOM_Template.xls"', ) # Return a simple CSV template @@ -45,7 +40,7 @@ class BomExportTest(InvenTreeTestCase): self.assertEqual(response.status_code, 200) self.assertEqual( response.headers['Content-Disposition'], - 'attachment; filename="InvenTree_BOM_Template.csv"' + 'attachment; filename="InvenTree_BOM_Template.csv"', ) filename = '_tmp.csv' diff --git a/InvenTree/part/test_bom_import.py b/InvenTree/part/test_bom_import.py index ae24418499..2a1c03d78c 100644 --- a/InvenTree/part/test_bom_import.py +++ b/InvenTree/part/test_bom_import.py @@ -12,10 +12,7 @@ from part.models import Part class BomUploadTest(InvenTreeAPITestCase): """Test BOM file upload API endpoint.""" - roles = [ - 'part.add', - 'part.change', - ] + roles = ['part.add', 'part.change'] @classmethod def setUpTestData(cls): @@ -36,34 +33,37 @@ class BomUploadTest(InvenTreeAPITestCase): for i in range(10): parts.append( Part( - name=f"Component {i}", - IPN=f"CMP_{i}", - description="A subcomponent that can be used in a BOM", + name=f'Component {i}', + IPN=f'CMP_{i}', + description='A subcomponent that can be used in a BOM', component=True, assembly=False, - lft=0, rght=0, - level=0, tree_id=0, + lft=0, + rght=0, + level=0, + tree_id=0, ) ) Part.objects.bulk_create(parts) - def post_bom(self, filename, file_data, clear_existing=None, expected_code=None, content_type='text/plain'): + def post_bom( + self, + filename, + file_data, + clear_existing=None, + expected_code=None, + content_type='text/plain', + ): """Helper function for submitting a BOM file""" - bom_file = SimpleUploadedFile( - filename, - file_data, - content_type=content_type, - ) + bom_file = SimpleUploadedFile(filename, file_data, content_type=content_type) if clear_existing is None: clear_existing = False response = self.post( reverse('api-bom-import-upload'), - data={ - 'data_file': bom_file, - }, + data={'data_file': bom_file}, expected_code=expected_code, format='multipart', ) @@ -73,30 +73,20 @@ class BomUploadTest(InvenTreeAPITestCase): def test_missing_file(self): """POST without a file.""" response = self.post( - reverse('api-bom-import-upload'), - data={}, - expected_code=400 + reverse('api-bom-import-upload'), data={}, expected_code=400 ) self.assertIn('No file was submitted', str(response.data['data_file'])) def test_unsupported_file(self): """POST with an unsupported file type.""" - response = self.post_bom( - 'sample.txt', - b'hello world', - expected_code=400, - ) + response = self.post_bom('sample.txt', b'hello world', expected_code=400) self.assertIn('Unsupported file type', str(response.data['data_file'])) def test_broken_file(self): """Test upload with broken (corrupted) files.""" - response = self.post_bom( - 'sample.csv', - b'', - expected_code=400, - ) + response = self.post_bom('sample.csv', b'', expected_code=400) self.assertIn('The submitted file is empty', str(response.data['data_file'])) @@ -107,16 +97,15 @@ class BomUploadTest(InvenTreeAPITestCase): content_type='application/xls', ) - self.assertIn('Unsupported format, or corrupt file', str(response.data['data_file'])) + self.assertIn( + 'Unsupported format, or corrupt file', str(response.data['data_file']) + ) def test_missing_rows(self): """Test upload of an invalid file (without data rows)""" dataset = tablib.Dataset() - dataset.headers = [ - 'apple', - 'banana', - ] + dataset.headers = ['apple', 'banana'] response = self.post_bom( 'test.csv', @@ -129,10 +118,7 @@ class BomUploadTest(InvenTreeAPITestCase): # Try again, with an .xlsx file response = self.post_bom( - 'bom.xlsx', - dataset.xlsx, - content_type='application/xlsx', - expected_code=400, + 'bom.xlsx', dataset.xlsx, content_type='application/xlsx', expected_code=400 ) self.assertIn('No data rows found in file', str(response.data)) @@ -141,50 +127,28 @@ class BomUploadTest(InvenTreeAPITestCase): """Upload extracted data, but with missing columns.""" url = reverse('api-bom-import-extract') - rows = [ - ['1', 'test'], - ['2', 'test'], - ] + rows = [['1', 'test'], ['2', 'test']] # Post without columns - response = self.post( - url, - {}, - expected_code=400, - ) + response = self.post(url, {}, expected_code=400) self.assertIn('This field is required', str(response.data['rows'])) self.assertIn('This field is required', str(response.data['columns'])) response = self.post( - url, - { - 'rows': rows, - 'columns': ['part', 'reference'], - }, - expected_code=400 + url, {'rows': rows, 'columns': ['part', 'reference']}, expected_code=400 ) self.assertIn("Missing required column: 'quantity'", str(response.data)) response = self.post( - url, - { - 'rows': rows, - 'columns': ['quantity', 'reference'], - }, - expected_code=400, + url, {'rows': rows, 'columns': ['quantity', 'reference']}, expected_code=400 ) self.assertIn('No part column specified', str(response.data)) self.post( - url, - { - 'rows': rows, - 'columns': ['quantity', 'part'], - }, - expected_code=201, + url, {'rows': rows, 'columns': ['quantity', 'part']}, expected_code=201 ) def test_invalid_data(self): @@ -197,7 +161,6 @@ class BomUploadTest(InvenTreeAPITestCase): components = Part.objects.filter(component=True) for idx, cmp in enumerate(components): - if idx == 5: cmp.component = False cmp.save() @@ -206,13 +169,7 @@ class BomUploadTest(InvenTreeAPITestCase): url = reverse('api-bom-import-extract') - response = self.post( - url, - { - 'columns': dataset.headers, - 'rows': list(dataset), - }, - ) + response = self.post(url, {'columns': dataset.headers, 'rows': list(dataset)}) rows = response.data['rows'] @@ -227,8 +184,12 @@ class BomUploadTest(InvenTreeAPITestCase): data['part'] = cmp.pk # Check some specific error messages - self.assertEqual(rows[0]['data']['errors']['quantity'], 'Quantity must be greater than zero') - self.assertEqual(rows[5]['data']['errors']['part'], 'Part is not designated as a component') + self.assertEqual( + rows[0]['data']['errors']['quantity'], 'Quantity must be greater than zero' + ) + self.assertEqual( + rows[5]['data']['errors']['part'], 'Part is not designated as a component' + ) def test_part_guess(self): """Test part 'guessing' when PK values are not supplied.""" @@ -240,20 +201,12 @@ class BomUploadTest(InvenTreeAPITestCase): components = Part.objects.filter(component=True) for component in components: - dataset.append([ - component.name, - 10, - ]) + dataset.append([component.name, 10]) url = reverse('api-bom-import-extract') response = self.post( - url, - { - 'columns': dataset.headers, - 'rows': list(dataset), - }, - expected_code=201, + url, {'columns': dataset.headers, 'rows': list(dataset)}, expected_code=201 ) rows = response.data['rows'] @@ -269,18 +222,10 @@ class BomUploadTest(InvenTreeAPITestCase): dataset.headers = ['part_ipn', 'quantity'] for component in components: - dataset.append([ - component.IPN, - 10, - ]) + dataset.append([component.IPN, 10]) response = self.post( - url, - { - 'columns': dataset.headers, - 'rows': list(dataset), - }, - expected_code=201, + url, {'columns': dataset.headers, 'rows': list(dataset)}, expected_code=201 ) rows = response.data['rows'] @@ -301,19 +246,10 @@ class BomUploadTest(InvenTreeAPITestCase): components = Part.objects.filter(component=True) for idx, cmp in enumerate(components): - dataset.append([ - idx % 3, - cmp.pk, - 2, - ]) + dataset.append([idx % 3, cmp.pk, 2]) response = self.post( - url, - { - 'rows': list(dataset), - 'columns': dataset.headers, - }, - expected_code=201, + url, {'rows': list(dataset), 'columns': dataset.headers}, expected_code=201 ) rows = response.data['rows'] diff --git a/InvenTree/part/test_bom_item.py b/InvenTree/part/test_bom_item.py index fecd0abbb1..4c2a7b7bb5 100644 --- a/InvenTree/part/test_bom_item.py +++ b/InvenTree/part/test_bom_item.py @@ -1,4 +1,3 @@ - """Unit tests for the BomItem model""" from decimal import Decimal @@ -70,7 +69,9 @@ class BomItemTest(TestCase): def test_integer_quantity(self): """Test integer validation for BomItem.""" - p = Part.objects.create(name="test", description="part description", component=True, trackable=True) + p = Part.objects.create( + name='test', description='part description', component=True, trackable=True + ) # Creation of a BOMItem with a non-integer quantity of a trackable Part should fail with self.assertRaises(django_exceptions.ValidationError): @@ -134,22 +135,19 @@ class BomItemTest(TestCase): self.bob.get_price(1) self.assertEqual( self.bob.get_bom_price_range(1, internal=True), - (Decimal(29.5), Decimal(89.5)) + (Decimal(29.5), Decimal(89.5)), ) # remove internal price for R_2K2_0805 self.r1.internal_price_breaks.delete() self.assertEqual( self.bob.get_bom_price_range(1, internal=True), - (Decimal(27.5), Decimal(87.5)) + (Decimal(27.5), Decimal(87.5)), ) def test_substitutes(self): """Tests for BOM item substitutes.""" # We will make some substitute parts for the "orphan" part - bom_item = BomItem.objects.get( - part=self.bob, - sub_part=self.orphan - ) + bom_item = BomItem.objects.get(part=self.bob, sub_part=self.orphan) # No substitute parts available self.assertEqual(bom_item.substitutes.count(), 0) @@ -157,11 +155,10 @@ class BomItemTest(TestCase): subs = [] for ii in range(5): - # Create a new part sub_part = Part.objects.create( - name=f"Orphan {ii}", - description="A substitute part for the orphan part", + name=f'Orphan {ii}', + description='A substitute part for the orphan part', component=True, is_template=False, assembly=False, @@ -170,28 +167,19 @@ class BomItemTest(TestCase): subs.append(sub_part) # Link it as a substitute part - BomItemSubstitute.objects.create( - bom_item=bom_item, - part=sub_part - ) + BomItemSubstitute.objects.create(bom_item=bom_item, part=sub_part) # Try to link it again (this should fail as it is a duplicate substitute) with self.assertRaises(django_exceptions.ValidationError): with transaction.atomic(): - BomItemSubstitute.objects.create( - bom_item=bom_item, - part=sub_part - ) + BomItemSubstitute.objects.create(bom_item=bom_item, part=sub_part) # There should be now 5 substitute parts available self.assertEqual(bom_item.substitutes.count(), 5) # Try to create a substitute which points to the same sub-part (should fail) with self.assertRaises(django_exceptions.ValidationError): - BomItemSubstitute.objects.create( - bom_item=bom_item, - part=self.orphan, - ) + BomItemSubstitute.objects.create(bom_item=bom_item, part=self.orphan) # Remove one substitute part bom_item.substitutes.last().delete() @@ -207,45 +195,42 @@ class BomItemTest(TestCase): def test_consumable(self): """Tests for the 'consumable' BomItem field""" # Create an assembly part - assembly = Part.objects.create(name="An assembly", description="Made with parts", assembly=True) + assembly = Part.objects.create( + name='An assembly', description='Made with parts', assembly=True + ) # No BOM information initially self.assertEqual(assembly.can_build, 0) # Create some component items - c1 = Part.objects.create(name="C1", description="Part C1 - this is just the part description") - c2 = Part.objects.create(name="C2", description="Part C2 - this is just the part description") - c3 = Part.objects.create(name="C3", description="Part C3 - this is just the part description") - c4 = Part.objects.create(name="C4", description="Part C4 - this is just the part description") + c1 = Part.objects.create( + name='C1', description='Part C1 - this is just the part description' + ) + c2 = Part.objects.create( + name='C2', description='Part C2 - this is just the part description' + ) + c3 = Part.objects.create( + name='C3', description='Part C3 - this is just the part description' + ) + c4 = Part.objects.create( + name='C4', description='Part C4 - this is just the part description' + ) for p in [c1, c2, c3, c4]: # Ensure we have stock stock.models.StockItem.objects.create(part=p, quantity=1000) # Create some BOM items - BomItem.objects.create( - part=assembly, - sub_part=c1, - quantity=10 - ) + BomItem.objects.create(part=assembly, sub_part=c1, quantity=10) self.assertEqual(assembly.can_build, 100) - BomItem.objects.create( - part=assembly, - sub_part=c2, - quantity=50, - consumable=True - ) + BomItem.objects.create(part=assembly, sub_part=c2, quantity=50, consumable=True) # A 'consumable' BomItem does not alter the can_build calculation self.assertEqual(assembly.can_build, 100) - BomItem.objects.create( - part=assembly, - sub_part=c3, - quantity=50, - ) + BomItem.objects.create(part=assembly, sub_part=c3, quantity=50) self.assertEqual(assembly.can_build, 20) @@ -271,16 +256,28 @@ class BomItemTest(TestCase): # First test: A BOM item which points to itself with self.assertRaises(django_exceptions.ValidationError): - BomItem.objects.create( - part=self.bob, - sub_part=self.bob, - quantity=1 - ) + BomItem.objects.create(part=self.bob, sub_part=self.bob, quantity=1) # Second test: A recursive BOM - part_a = Part.objects.create(name='Part A', description="A part which is called A", assembly=True, is_template=True, component=True) - part_b = Part.objects.create(name='Part B', description="A part which is called B", assembly=True, component=True) - part_c = Part.objects.create(name='Part C', description="A part which is called C", assembly=True, component=True) + part_a = Part.objects.create( + name='Part A', + description='A part which is called A', + assembly=True, + is_template=True, + component=True, + ) + part_b = Part.objects.create( + name='Part B', + description='A part which is called B', + assembly=True, + component=True, + ) + part_c = Part.objects.create( + name='Part C', + description='A part which is called C', + assembly=True, + component=True, + ) BomItem.objects.create(part=part_a, sub_part=part_b, quantity=10) BomItem.objects.create(part=part_b, sub_part=part_c, quantity=10) @@ -292,7 +289,13 @@ class BomItemTest(TestCase): BomItem.objects.create(part=part_c, sub_part=part_b, quantity=10) # Third test: A recursive BOM with a variant part - part_v = Part.objects.create(name='Part V', description='A part which is called V', variant_of=part_a, assembly=True, component=True) + part_v = Part.objects.create( + name='Part V', + description='A part which is called V', + variant_of=part_a, + assembly=True, + component=True, + ) with self.assertRaises(django_exceptions.ValidationError): BomItem.objects.create(part=part_a, sub_part=part_v, quantity=10) diff --git a/InvenTree/part/test_category.py b/InvenTree/part/test_category.py index 714c192449..bef67c71e2 100644 --- a/InvenTree/part/test_category.py +++ b/InvenTree/part/test_category.py @@ -12,12 +12,8 @@ class CategoryTest(TestCase): Loads the following test fixtures: - category.yaml """ - fixtures = [ - 'category', - 'part', - 'location', - 'params', - ] + + fixtures = ['category', 'part', 'location', 'params'] @classmethod def setUpTestData(cls): @@ -72,13 +68,15 @@ class CategoryTest(TestCase): self.transceivers.save() self.assertEqual(str(self.resistors), 'Electronics/Resistors - Resistors') - self.assertEqual(str(self.transceivers.pathstring), 'Electronics/IC/Transceivers') + self.assertEqual( + str(self.transceivers.pathstring), 'Electronics/IC/Transceivers' + ) # Create a new subcategory subcat = PartCategory.objects.create( name='Subcategory', description='My little sub category', - parent=self.transceivers + parent=self.transceivers, ) # Pathstring should have been updated correctly @@ -99,9 +97,7 @@ class CategoryTest(TestCase): # Construct a very long pathstring and ensure it gets updated correctly cat = PartCategory.objects.create( - name='Cat', - description='A long running category', - parent=None + name='Cat', description='A long running category', parent=None ) parent = cat @@ -110,9 +106,7 @@ class CategoryTest(TestCase): letter = chr(ord('A') + idx) child = PartCategory.objects.create( - name=letter * 10, - description=f"Subcategory {letter}", - parent=parent + name=letter * 10, description=f'Subcategory {letter}', parent=parent ) parent = child @@ -120,7 +114,7 @@ class CategoryTest(TestCase): self.assertTrue(len(child.path), 26) self.assertEqual( child.pathstring, - "Cat/AAAAAAAAAA/BBBBBBBBBB/CCCCCCCCCC/DDDDDDDDDD/EEEEEEEEEE/FFFFFFFFFF/GGGGGGGGGG/HHHHHHHHHH/IIIIIIIIII/JJJJJJJJJJ/KKKKKKKKK...OO/PPPPPPPPPP/QQQQQQQQQQ/RRRRRRRRRR/SSSSSSSSSS/TTTTTTTTTT/UUUUUUUUUU/VVVVVVVVVV/WWWWWWWWWW/XXXXXXXXXX/YYYYYYYYYY/ZZZZZZZZZZ" + 'Cat/AAAAAAAAAA/BBBBBBBBBB/CCCCCCCCCC/DDDDDDDDDD/EEEEEEEEEE/FFFFFFFFFF/GGGGGGGGGG/HHHHHHHHHH/IIIIIIIIII/JJJJJJJJJJ/KKKKKKKKK...OO/PPPPPPPPPP/QQQQQQQQQQ/RRRRRRRRRR/SSSSSSSSSS/TTTTTTTTTT/UUUUUUUUUU/VVVVVVVVVV/WWWWWWWWWW/XXXXXXXXXX/YYYYYYYYYY/ZZZZZZZZZZ', ) self.assertTrue(len(child.pathstring) <= 250) @@ -168,7 +162,9 @@ class CategoryTest(TestCase): self.assertIsInstance(parameter.template, PartParameterTemplate) # Test number of unique parameters - self.assertEqual(len(self.fasteners.get_unique_parameters(prefetch=fasteners)), 1) + self.assertEqual( + len(self.fasteners.get_unique_parameters(prefetch=fasteners)), 1 + ) # Test number of parameters found for each part parts_parameters = self.fasteners.get_parts_parameters(prefetch=fasteners) part_infos = ['pk', 'name', 'description'] @@ -202,7 +198,9 @@ class CategoryTest(TestCase): """Test traversal for default locations.""" self.assertIsNotNone(self.fasteners.default_location) self.fasteners.default_location.save() - self.assertEqual(str(self.fasteners.default_location), 'Office/Drawer_1 - In my desk') + self.assertEqual( + str(self.fasteners.default_location), 'Office/Drawer_1 - In my desk' + ) # Any part under electronics should default to 'Home' r1 = Part.objects.get(name='R_2K2_0805') @@ -229,10 +227,7 @@ class CategoryTest(TestCase): PartCategory.objects.rebuild() # First, create a structured tree of part categories - A = PartCategory.objects.create( - name='A', - description='Top level category', - ) + A = PartCategory.objects.create(name='A', description='Top level category') B1 = PartCategory.objects.create(name='B1', parent=A) B2 = PartCategory.objects.create(name='B2', parent=A) @@ -280,9 +275,7 @@ class CategoryTest(TestCase): for i in range(10): Part.objects.create( - name=f'Part {i}', - description='A test part', - category=B3, + name=f'Part {i}', description='A test part', category=B3 ) self.assertEqual(Part.objects.filter(category=B3).count(), 10) @@ -395,9 +388,7 @@ class CategoryTest(TestCase): # Create some sample parts under D32 for ii in range(10): Part.objects.create( - name=f'Part D32 {ii}', - description='A test part', - category=D32, + name=f'Part D32 {ii}', description='A test part', category=D32 ) self.assertEqual(Part.objects.filter(category=D32).count(), 10) diff --git a/InvenTree/part/test_migrations.py b/InvenTree/part/test_migrations.py index 59c1a650c4..cbff8e28f9 100644 --- a/InvenTree/part/test_migrations.py +++ b/InvenTree/part/test_migrations.py @@ -45,7 +45,7 @@ class TestForwardMigrations(MigratorTestCase): for name in ['A', 'C', 'E']: part = Part.objects.get(name=name) - self.assertEqual(part.description, f"My part {name}") + self.assertEqual(part.description, f'My part {name}') class TestBomItemMigrations(MigratorTestCase): @@ -92,17 +92,17 @@ class TestParameterMigrations(MigratorTestCase): """Create some parts, and templates with parameters""" Part = self.old_state.apps.get_model('part', 'part') PartParameter = self.old_state.apps.get_model('part', 'partparameter') - PartParameterTemlate = self.old_state.apps.get_model('part', 'partparametertemplate') + PartParameterTemlate = self.old_state.apps.get_model( + 'part', 'partparametertemplate' + ) # Create some parts a = Part.objects.create( - name='Part A', description='My part A', - level=0, lft=0, rght=0, tree_id=0, + name='Part A', description='My part A', level=0, lft=0, rght=0, tree_id=0 ) b = Part.objects.create( - name='Part B', description='My part B', - level=0, lft=0, rght=0, tree_id=0, + name='Part B', description='My part B', level=0, lft=0, rght=0, tree_id=0 ) # Create some templates @@ -111,7 +111,7 @@ class TestParameterMigrations(MigratorTestCase): # Create some parameter values PartParameter.objects.create(part=a, template=t1, data='1.0') - PartParameter.objects.create(part=a, template=t2, data='-2mA',) + PartParameter.objects.create(part=a, template=t2, data='-2mA') PartParameter.objects.create(part=b, template=t1, data='1/10 inch') PartParameter.objects.create(part=b, template=t2, data='abc') @@ -120,7 +120,9 @@ class TestParameterMigrations(MigratorTestCase): """Test that the template units and values have been updated correctly""" Part = self.new_state.apps.get_model('part', 'part') PartParameter = self.new_state.apps.get_model('part', 'partparameter') - PartParameterTemlate = self.new_state.apps.get_model('part', 'partparametertemplate') + PartParameterTemlate = self.new_state.apps.get_model( + 'part', 'partparametertemplate' + ) # Extract the parts a = Part.objects.get(name='Part A') @@ -166,8 +168,13 @@ class PartUnitsMigrationTest(MigratorTestCase): for idx, unit in enumerate(units): Part.objects.create( - name=f'Part {idx + 1}', description=f'My part at index {idx}', units=unit, - level=0, lft=0, rght=0, tree_id=0, + name=f'Part {idx + 1}', + description=f'My part at index {idx}', + units=unit, + level=0, + lft=0, + rght=0, + tree_id=0, ) def test_units_migration(self): @@ -196,10 +203,14 @@ class TestPartParameterTemplateMigration(MigratorTestCase): def prepare(self): """Prepare some parts with units""" - PartParameterTemplate = self.old_state.apps.get_model('part', 'partparametertemplate') + PartParameterTemplate = self.old_state.apps.get_model( + 'part', 'partparametertemplate' + ) # Create a test template - template = PartParameterTemplate.objects.create(name='Template 1', description='a part parameter template') + template = PartParameterTemplate.objects.create( + name='Template 1', description='a part parameter template' + ) # Ensure that the 'choices' and 'checkbox' fields do not exist with self.assertRaises(AttributeError): @@ -210,7 +221,9 @@ class TestPartParameterTemplateMigration(MigratorTestCase): def test_units_migration(self): """Test that the new fields have been added correctly""" - PartParameterTemplate = self.new_state.apps.get_model('part', 'partparametertemplate') + PartParameterTemplate = self.new_state.apps.get_model( + 'part', 'partparametertemplate' + ) template = PartParameterTemplate.objects.get(name='Template 1') diff --git a/InvenTree/part/test_param.py b/InvenTree/part/test_param.py index 827b8a369b..c6f6639353 100644 --- a/InvenTree/part/test_param.py +++ b/InvenTree/part/test_param.py @@ -7,19 +7,19 @@ from django.urls import reverse from common.models import InvenTreeSetting from InvenTree.unit_test import InvenTreeAPITestCase -from .models import (Part, PartCategory, PartCategoryParameterTemplate, - PartParameter, PartParameterTemplate) +from .models import ( + Part, + PartCategory, + PartCategoryParameterTemplate, + PartParameter, + PartParameterTemplate, +) class TestParams(TestCase): """Unit test class for testing the PartParameter model""" - fixtures = [ - 'location', - 'category', - 'part', - 'params' - ] + fixtures = ['location', 'category', 'part', 'params'] def test_str(self): """Test the str representation of the PartParameterTemplate model""" @@ -81,12 +81,7 @@ class TestParams(TestCase): class TestCategoryTemplates(TransactionTestCase): """Test class for PartCategoryParameterTemplate model""" - fixtures = [ - 'location', - 'category', - 'part', - 'params' - ] + fixtures = ['location', 'category', 'part', 'params'] def test_validate(self): """Test that category templates are correctly applied to Part instances""" @@ -97,9 +92,9 @@ class TestCategoryTemplates(TransactionTestCase): category = PartCategory.objects.get(pk=8) t1 = PartParameterTemplate.objects.get(pk=2) - c1 = PartCategoryParameterTemplate(category=category, - parameter_template=t1, - default_value='xyz') + c1 = PartCategoryParameterTemplate( + category=category, parameter_template=t1, default_value='xyz' + ) c1.save() n = PartCategoryParameterTemplate.objects.all().count() @@ -109,19 +104,14 @@ class TestCategoryTemplates(TransactionTestCase): class ParameterTests(TestCase): """Unit tests for parameter validation""" - fixtures = [ - 'location', - 'category', - 'part', - 'params' - ] + fixtures = ['location', 'category', 'part', 'params'] def test_choice_validation(self): """Test that parameter choices are correctly validated""" template = PartParameterTemplate.objects.create( name='My Template', description='A template with choices', - choices='red, blue, green' + choices='red, blue, green', ) pass_values = ['red', 'blue', 'green'] @@ -141,7 +131,21 @@ class ParameterTests(TestCase): def test_unit_validation(self): """Test validation of 'units' field for PartParameterTemplate""" # Test that valid units pass - for unit in [None, '', '%', 'mm', 'A', 'm^2', 'Pa', 'V', 'C', 'F', 'uF', 'mF', 'millifarad']: + for unit in [ + None, + '', + '%', + 'mm', + 'A', + 'm^2', + 'Pa', + 'V', + 'C', + 'F', + 'uF', + 'mF', + 'millifarad', + ]: tmp = PartParameterTemplate(name='test', units=unit) tmp.full_clean() @@ -153,31 +157,40 @@ class ParameterTests(TestCase): def test_param_unit_validation(self): """Test that parameters are correctly validated against template units""" - template = PartParameterTemplate.objects.create( - name='My Template', - units='m', - ) + template = PartParameterTemplate.objects.create(name='My Template', units='m') prt = Part.objects.get(pk=1) # Test that valid parameters pass - for value in ['1', '1m', 'm', '-4m', -2, '2.032mm', '99km', '-12 mile', 'foot', '3 yards']: + for value in [ + '1', + '1m', + 'm', + '-4m', + -2, + '2.032mm', + '99km', + '-12 mile', + 'foot', + '3 yards', + ]: param = PartParameter(part=prt, template=template, data=value) param.full_clean() # Test that percent unit is working template2 = PartParameterTemplate.objects.create( - name='My Template 2', - units='%', + name='My Template 2', units='%' ) - for value in ["1", "1%", "1 percent"]: + for value in ['1', '1%', '1 percent']: param = PartParameter(part=prt, template=template2, data=value) param.full_clean() bad_values = ['3 Amps', '-3 zogs', '3.14F'] # Disable enforcing of part parameter units - InvenTreeSetting.set_setting('PART_PARAMETER_ENFORCE_UNITS', False, change_user=None) + InvenTreeSetting.set_setting( + 'PART_PARAMETER_ENFORCE_UNITS', False, change_user=None + ) # Invalid units also pass, but will be converted to the template units for value in bad_values: @@ -185,7 +198,9 @@ class ParameterTests(TestCase): param.full_clean() # Enable enforcing of part parameter units - InvenTreeSetting.set_setting('PART_PARAMETER_ENFORCE_UNITS', True, change_user=None) + InvenTreeSetting.set_setting( + 'PART_PARAMETER_ENFORCE_UNITS', True, change_user=None + ) for value in bad_values: param = PartParameter(part=prt, template=template, data=value) @@ -194,10 +209,7 @@ class ParameterTests(TestCase): def test_param_unit_conversion(self): """Test that parameters are correctly converted to template units""" - template = PartParameterTemplate.objects.create( - name='My Template', - units='m', - ) + template = PartParameterTemplate.objects.create(name='My Template', units='m') tests = { '1': 1.0, @@ -219,14 +231,10 @@ class ParameterTests(TestCase): class PartParameterTest(InvenTreeAPITestCase): """Tests for the ParParameter API.""" + superuser = True - fixtures = [ - 'category', - 'part', - 'location', - 'params', - ] + fixtures = ['category', 'part', 'location', 'params'] def test_list_params(self): """Test for listing part parameters.""" @@ -237,22 +245,12 @@ class PartParameterTest(InvenTreeAPITestCase): self.assertEqual(len(response.data), 7) # Filter by part - response = self.get( - url, - { - 'part': 3, - } - ) + response = self.get(url, {'part': 3}) self.assertEqual(len(response.data), 3) # Filter by template - response = self.get( - url, - { - 'template': 1, - } - ) + response = self.get(url, {'template': 1}) self.assertEqual(len(response.data), 4) @@ -261,10 +259,7 @@ class PartParameterTest(InvenTreeAPITestCase): # Checkbox parameter cannot have "units" specified with self.assertRaises(django_exceptions.ValidationError): template = PartParameterTemplate( - name='test', - description='My description', - units='mm', - checkbox=True + name='test', description='My description', units='mm', checkbox=True ) template.clean() @@ -275,7 +270,7 @@ class PartParameterTest(InvenTreeAPITestCase): name='test', description='My description', choices='a,b,c', - checkbox=True + checkbox=True, ) template.clean() @@ -283,9 +278,7 @@ class PartParameterTest(InvenTreeAPITestCase): # Choices must be 'unique' with self.assertRaises(django_exceptions.ValidationError): template = PartParameterTemplate( - name='test', - description='My description', - choices='a,a,b', + name='test', description='My description', choices='a,a,b' ) template.clean() @@ -294,14 +287,7 @@ class PartParameterTest(InvenTreeAPITestCase): """Test that we can create a param via the API.""" url = reverse('api-part-parameter-list') - response = self.post( - url, - { - 'part': '2', - 'template': '3', - 'data': 70 - } - ) + response = self.post(url, {'part': '2', 'template': '3', 'data': 70}) self.assertEqual(response.status_code, 201) @@ -351,9 +337,7 @@ class PartParameterTest(InvenTreeAPITestCase): # Create a new parameter template template = PartParameterTemplate.objects.create( - name='Test Template', - description='My test template', - units='m' + name='Test Template', description='My test template', units='m' ) # Create parameters for each existing part @@ -362,7 +346,6 @@ class PartParameterTest(InvenTreeAPITestCase): parts = Part.objects.all().order_by('pk') for idx, part in enumerate(parts): - # Skip parts every now and then if idx % 10 == 7: continue @@ -371,9 +354,7 @@ class PartParameterTest(InvenTreeAPITestCase): params.append( PartParameter.objects.create( - part=part, - template=template, - data=f'{idx}{suffix}' + part=part, template=template, data=f'{idx}{suffix}' ) ) @@ -382,25 +363,15 @@ class PartParameterTest(InvenTreeAPITestCase): response = self.get( url, - { - 'ordering': 'parameter_{pk}'.format(pk=template.pk), - 'parameters': 'true', - }, - expected_code=200 + {'ordering': 'parameter_{pk}'.format(pk=template.pk), 'parameters': 'true'}, + expected_code=200, ) # All parts should be returned self.assertEqual(len(response.data), len(parts)) # Check that the parts are ordered correctly (in increasing order) - expectation = { - 0: '0mm', - 1: '3mm', - 7: '4m', - 9: '8m', - -2: '13m', - -1: None, - } + expectation = {0: '0mm', 1: '3mm', 7: '4m', 9: '8m', -2: '13m', -1: None} for idx, expected in expectation.items(): actual = get_param_value(response, template.pk, idx) @@ -413,16 +384,10 @@ class PartParameterTest(InvenTreeAPITestCase): 'ordering': '-parameter_{pk}'.format(pk=template.pk), 'parameters': 'true', }, - expected_code=200 + expected_code=200, ) - expectation = { - 0: '13m', - 1: '11m', - -3: '3mm', - -2: '0mm', - -1: None, - } + expectation = {0: '13m', 1: '11m', -3: '3mm', -2: '0mm', -1: None} for idx, expected in expectation.items(): actual = get_param_value(response, template.pk, idx) diff --git a/InvenTree/part/test_part.py b/InvenTree/part/test_part.py index bc9857c614..30da66c7f5 100644 --- a/InvenTree/part/test_part.py +++ b/InvenTree/part/test_part.py @@ -10,15 +10,26 @@ from django.test import TestCase from allauth.account.models import EmailAddress import part.settings -from common.models import (InvenTreeSetting, InvenTreeUserSetting, - NotificationEntry, NotificationMessage) +from common.models import ( + InvenTreeSetting, + InvenTreeUserSetting, + NotificationEntry, + NotificationMessage, +) from common.notifications import UIMessageNotification, storage from InvenTree import version from InvenTree.unit_test import InvenTreeTestCase -from .models import (Part, PartCategory, PartCategoryStar, PartRelated, - PartStar, PartStocktake, PartTestTemplate, - rename_part_image) +from .models import ( + Part, + PartCategory, + PartCategoryStar, + PartRelated, + PartStar, + PartStocktake, + PartTestTemplate, + rename_part_image, +) from .templatetags import inventree_extras @@ -54,11 +65,16 @@ class TemplateTagTest(InvenTreeTestCase): def test_inventree_is_release(self): """Test that the release version check functions as expected""" - self.assertEqual(inventree_extras.inventree_is_release(), not version.isInvenTreeDevelopmentVersion()) + self.assertEqual( + inventree_extras.inventree_is_release(), + not version.isInvenTreeDevelopmentVersion(), + ) def test_inventree_docs_version(self): """Test that the documentation version template tag returns correctly""" - self.assertEqual(inventree_extras.inventree_docs_version(), version.inventreeDocsVersion()) + self.assertEqual( + inventree_extras.inventree_docs_version(), version.inventreeDocsVersion() + ) def test_hash(self): """Test that the commit hash template tag returns correctly""" @@ -127,12 +143,7 @@ class TemplateTagTest(InvenTreeTestCase): class PartTest(TestCase): """Tests for the Part model.""" - fixtures = [ - 'category', - 'part', - 'location', - 'part_pricebreaks' - ] + fixtures = ['category', 'part', 'location', 'part_pricebreaks'] @classmethod def setUpTestData(cls): @@ -169,7 +180,7 @@ class PartTest(TestCase): def test_str(self): """Test string representation of a Part""" p = Part.objects.get(pk=100) - self.assertEqual(str(p), "BOB | Bob | A2 - Can we build it? Yes we can!") + self.assertEqual(str(p), 'BOB | Bob | A2 - Can we build it? Yes we can!') def test_duplicate(self): """Test that we cannot create a "duplicate" Part.""" @@ -352,12 +363,7 @@ class PartTest(TestCase): class TestTemplateTest(TestCase): """Unit test for the TestTemplate class""" - fixtures = [ - 'category', - 'part', - 'location', - 'test_templates', - ] + fixtures = ['category', 'part', 'location', 'test_templates'] def test_template_count(self): """Tests for the test template functions""" @@ -381,22 +387,17 @@ class TestTemplateTest(TestCase): variant = Part.objects.get(pk=10004) with self.assertRaises(ValidationError): - PartTestTemplate.objects.create( - part=variant, - test_name='Record weight' - ) + PartTestTemplate.objects.create(part=variant, test_name='Record weight') with self.assertRaises(ValidationError): PartTestTemplate.objects.create( - part=variant, - test_name='Check that chair is especially green' + part=variant, test_name='Check that chair is especially green' ) # Also should fail if we attempt to create a test that would generate the same key with self.assertRaises(ValidationError): PartTestTemplate.objects.create( - part=variant, - test_name='ReCoRD weiGHT ' + part=variant, test_name='ReCoRD weiGHT ' ) # But we should be able to create a new one! @@ -418,9 +419,7 @@ class PartSettingsTest(InvenTreeTestCase): cache.clear() part = Part.objects.create( - name='Test Part', - description='I am but a humble test part', - IPN='IPN-123', + name='Test Part', description='I am but a humble test part', IPN='IPN-123' ) return part @@ -474,7 +473,9 @@ class PartSettingsTest(InvenTreeTestCase): def test_duplicate_ipn(self): """Test the setting which controls duplicate IPN values.""" # Create a part - Part.objects.create(name='Hello', description='A thing', IPN='IPN123', revision='A') + Part.objects.create( + name='Hello', description='A thing', IPN='IPN123', revision='A' + ) # Attempt to create a duplicate item (should fail) with self.assertRaises(ValidationError): @@ -482,7 +483,9 @@ class PartSettingsTest(InvenTreeTestCase): part.validate_unique() # Attempt to create item with duplicate IPN (should be allowed by default) - Part.objects.create(name='Hello', description='A thing', IPN='IPN123', revision='B') + Part.objects.create( + name='Hello', description='A thing', IPN='IPN123', revision='B' + ) # And attempt again with the same values (should fail) with self.assertRaises(ValidationError): @@ -497,12 +500,16 @@ class PartSettingsTest(InvenTreeTestCase): part.full_clean() # Any duplicate IPN should raise an error - Part.objects.create(name='xyz', revision='1', description='A part', IPN='UNIQUE') + Part.objects.create( + name='xyz', revision='1', description='A part', IPN='UNIQUE' + ) # Case insensitive, so variations on spelling should throw an error for ipn in ['UNiquE', 'uniQuE', 'unique']: with self.assertRaises(ValidationError): - Part.objects.create(name='xyz', revision='2', description='A part', IPN=ipn) + Part.objects.create( + name='xyz', revision='2', description='A part', IPN=ipn + ) with self.assertRaises(ValidationError): Part.objects.create(name='zyx', description='A part', IPN='UNIQUE') @@ -520,11 +527,7 @@ class PartSettingsTest(InvenTreeTestCase): class PartSubscriptionTests(InvenTreeTestCase): """Unit tests for part 'subscription'""" - fixtures = [ - 'location', - 'category', - 'part', - ] + fixtures = ['location', 'category', 'part'] @classmethod def setUpTestData(cls): @@ -562,9 +565,7 @@ class PartSubscriptionTests(InvenTreeTestCase): """Test subscription against a parent part.""" # Construct a sub-part to star against sub_part = Part.objects.create( - name='sub_part', - description='a sub part', - variant_of=self.part, + name='sub_part', description='a sub part', variant_of=self.part ) self.assertFalse(sub_part.is_starred_by(self.user)) @@ -620,12 +621,7 @@ class PartSubscriptionTests(InvenTreeTestCase): class BaseNotificationIntegrationTest(InvenTreeTestCase): """Integration test for notifications.""" - fixtures = [ - 'location', - 'category', - 'part', - 'stock' - ] + fixtures = ['location', 'category', 'part', 'stock'] @classmethod def setUpTestData(cls): @@ -652,7 +648,9 @@ class BaseNotificationIntegrationTest(InvenTreeTestCase): self.assertEqual(NotificationEntry.objects.all().count(), 0) # Test that notifications run through without errors - self.part.minimum_stock = self.part.get_stock_count() + 1 # make sure minimum is one higher than current count + self.part.minimum_stock = ( + self.part.get_stock_count() + 1 + ) # make sure minimum is one higher than current count self.part.save() # There should be no notification as no-one is subscribed diff --git a/InvenTree/part/test_pricing.py b/InvenTree/part/test_pricing.py index e6a348f413..a434c3c233 100644 --- a/InvenTree/part/test_pricing.py +++ b/InvenTree/part/test_pricing.py @@ -31,15 +31,14 @@ class PartPricingTests(InvenTreeTestCase): name='PP', description='A part with pricing, measured in metres', assembly=True, - units='m' + units='m', ) def create_price_breaks(self): """Create some price breaks for the part, in various currencies""" # First supplier part (CAD) self.supplier_1 = company.models.Company.objects.create( - name='Supplier 1', - is_supplier=True + name='Supplier 1', is_supplier=True ) self.sp_1 = company.models.SupplierPart.objects.create( @@ -53,23 +52,16 @@ class PartPricingTests(InvenTreeTestCase): self.assertEqual(self.sp_1.pack_quantity_native, 2) company.models.SupplierPriceBreak.objects.create( - part=self.sp_1, - quantity=1, - price=10.4, - price_currency='CAD', + part=self.sp_1, quantity=1, price=10.4, price_currency='CAD' ) # Second supplier part (AUD) self.supplier_2 = company.models.Company.objects.create( - name='Supplier 2', - is_supplier=True + name='Supplier 2', is_supplier=True ) self.sp_2 = company.models.SupplierPart.objects.create( - supplier=self.supplier_2, - part=self.part, - SKU='SUP_2', - pack_quantity='2.5', + supplier=self.supplier_2, part=self.part, SKU='SUP_2', pack_quantity='2.5' ) # Native pack quantity should be 2.5m @@ -86,18 +78,12 @@ class PartPricingTests(InvenTreeTestCase): self.assertEqual(self.sp_3.pack_quantity_native, 0.254) company.models.SupplierPriceBreak.objects.create( - part=self.sp_2, - quantity=5, - price=7.555, - price_currency='AUD', + part=self.sp_2, quantity=5, price=7.555, price_currency='AUD' ) # Third supplier part (GBP) company.models.SupplierPriceBreak.objects.create( - part=self.sp_2, - quantity=10, - price=4.55, - price_currency='GBP', + part=self.sp_2, quantity=10, price=4.55, price_currency='GBP' ) def test_pricing_data(self): @@ -198,10 +184,7 @@ class PartPricingTests(InvenTreeTestCase): for ii in range(5): # Let's add some internal price breaks part.models.PartInternalPriceBreak.objects.create( - part=self.part, - quantity=ii + 1, - price=10 - ii, - price_currency=currency + part=self.part, quantity=ii + 1, price=10 - ii, price_currency=currency ) pricing.update_internal_cost() @@ -226,23 +209,20 @@ class PartPricingTests(InvenTreeTestCase): ) # Create some stock items - prices = [ - (10, 'AUD'), - (5, 'USD'), - (2, 'CAD'), - ] + prices = [(10, 'AUD'), (5, 'USD'), (2, 'CAD')] for price, currency in prices: - stock.models.StockItem.objects.create( part=p, quantity=10, purchase_price=price, - purchase_price_currency=currency + purchase_price_currency=currency, ) # Ensure that initially, stock item pricing is disabled - common.models.InvenTreeSetting.set_setting('PRICING_USE_STOCK_PRICING', False, None) + common.models.InvenTreeSetting.set_setting( + 'PRICING_USE_STOCK_PRICING', False, None + ) pricing = p.pricing pricing.update_pricing() @@ -254,7 +234,9 @@ class PartPricingTests(InvenTreeTestCase): self.assertIsNone(pricing.overall_max) # Turn on stock pricing - common.models.InvenTreeSetting.set_setting('PRICING_USE_STOCK_PRICING', True, None) + common.models.InvenTreeSetting.set_setting( + 'PRICING_USE_STOCK_PRICING', True, None + ) pricing.update_pricing() @@ -276,8 +258,8 @@ class PartPricingTests(InvenTreeTestCase): for ii in range(10): # Create a new part for the BOM sub_part = part.models.Part.objects.create( - name=f"Sub Part {ii}", - description="A sub part for use in a BOM", + name=f'Sub Part {ii}', + description='A sub part for use in a BOM', component=True, assembly=False, ) @@ -291,9 +273,7 @@ class PartPricingTests(InvenTreeTestCase): sub_part_pricing.save() part.models.BomItem.objects.create( - part=self.part, - sub_part=sub_part, - quantity=5, + part=self.part, sub_part=sub_part, quantity=5 ) pricing.update_bom_cost() @@ -318,14 +298,15 @@ class PartPricingTests(InvenTreeTestCase): # Generate some purchase orders po = order.models.PurchaseOrder.objects.create( - supplier=self.supplier_2, - reference='PO-009', + supplier=self.supplier_2, reference='PO-009' ) # Add some line items to the order # $5 AUD each @ 2.5m per unit = $2 AUD per metre - line_1 = po.add_line_item(self.sp_2, quantity=10, purchase_price=Money(5, 'AUD')) + line_1 = po.add_line_item( + self.sp_2, quantity=10, purchase_price=Money(5, 'AUD') + ) # $3 CAD each @ 10 inches per unit = $0.3 CAD per inch = $11.81 CAD per metre line_2 = po.add_line_item(self.sp_3, quantity=5, purchase_price=Money(3, 'CAD')) @@ -361,7 +342,9 @@ class PartPricingTests(InvenTreeTestCase): self.assertAlmostEqual(float(min_cost_aud.amount), 2, places=2) # Min cost in USD - self.assertAlmostEqual(float(pricing.purchase_cost_min.amount), 1.3333, places=2) + self.assertAlmostEqual( + float(pricing.purchase_cost_min.amount), 1.3333, places=2 + ) # Max cost in CAD = $11.81 CAD per metre self.assertAlmostEqual(float(max_cost_cad.amount), 11.81, places=2) @@ -420,8 +403,7 @@ class PartPricingTests(InvenTreeTestCase): # Create some parts for ii in range(100): part.models.Part.objects.create( - name=f"Part_{ii}", - description="A test part", + name=f'Part_{ii}', description='A test part' ) # Ensure there is no pricing data @@ -442,17 +424,13 @@ class PartPricingTests(InvenTreeTestCase): but it pointed to a Part instance which was slated to be deleted inside an atomic transaction. """ p = part.models.Part.objects.create( - name="my part", - description="my part description", - active=False, + name='my part', description='my part description', active=False ) # Create some stock items for _idx in range(3): stock.models.StockItem.objects.create( - part=p, - quantity=10, - purchase_price=Money(10, 'USD') + part=p, quantity=10, purchase_price=Money(10, 'USD') ) # Manually schedule a pricing update (does not happen automatically in testing) diff --git a/InvenTree/part/test_views.py b/InvenTree/part/test_views.py index a43bb91134..6674f6d0f5 100644 --- a/InvenTree/part/test_views.py +++ b/InvenTree/part/test_views.py @@ -10,14 +10,7 @@ from .models import Part class PartViewTestCase(InvenTreeTestCase): """Base class for unit testing the various Part views""" - fixtures = [ - 'category', - 'part', - 'bom', - 'location', - 'company', - 'supplier_part', - ] + fixtures = ['category', 'part', 'bom', 'location', 'company', 'supplier_part'] roles = 'all' superuser = True @@ -72,7 +65,9 @@ class PartDetailTest(PartViewTestCase): index_redirect = False detail_redirect = False - response = self.client.get(reverse('part-detail-from-ipn', args=(ipn_test,))) + response = self.client.get( + reverse('part-detail-from-ipn', args=(ipn_test,)) + ) # Check for PartIndex redirect try: @@ -110,6 +105,9 @@ class PartDetailTest(PartViewTestCase): def test_bom_download(self): """Test downloading a BOM for a valid part.""" - response = self.client.get(reverse('api-bom-download', args=(1,)), HTTP_X_REQUESTED_WITH='XMLHttpRequest') + response = self.client.get( + reverse('api-bom-download', args=(1,)), + HTTP_X_REQUESTED_WITH='XMLHttpRequest', + ) self.assertEqual(response.status_code, 200) self.assertIn('streaming_content', dir(response)) diff --git a/InvenTree/part/urls.py b/InvenTree/part/urls.py index 6654c3a9ea..7e6a4e1f8f 100644 --- a/InvenTree/part/urls.py +++ b/InvenTree/part/urls.py @@ -12,37 +12,37 @@ from . import views part_detail_urls = [ re_path(r'^bom-upload/?', views.BomUpload.as_view(), name='upload-bom'), - # Normal thumbnail with form - re_path(r'^thumb-select/?', views.PartImageSelect.as_view(), name='part-image-select'), - + re_path( + r'^thumb-select/?', views.PartImageSelect.as_view(), name='part-image-select' + ), # Any other URLs go to the part detail page re_path(r'^.*$', views.PartDetail.as_view(), name='part-detail'), ] category_urls = [ - # Category detail views - re_path(r'(?P\d+)/', views.CategoryDetail.as_view(), name='category-detail'), + re_path(r'(?P\d+)/', views.CategoryDetail.as_view(), name='category-detail') ] # URL list for part web interface part_urls = [ - # Upload a part re_path(r'^import/$', views.PartImport.as_view(), name='part-import'), - re_path(r'^import/?', views.PartImportTemplate.as_view(), name='part-template-download'), + re_path( + r'^import/?', views.PartImportTemplate.as_view(), name='part-template-download' + ), re_path(r'^import-api/', views.PartImportAjax.as_view(), name='api-part-import'), - # Individual part using pk path(r'/', include(part_detail_urls)), - # Part category re_path(r'^category/', include(category_urls)), - # Individual part using IPN as slug - re_path(r'^(?P[-\w]+)/', views.PartDetailFromIPN.as_view(), name='part-detail-from-ipn'), - + re_path( + r'^(?P[-\w]+)/', + views.PartDetailFromIPN.as_view(), + name='part-detail-from-ipn', + ), # Top level part list (display top level parts and categories) re_path(r'^.*$', views.PartIndex.as_view(), name='part-index'), ] diff --git a/InvenTree/part/views.py b/InvenTree/part/views.py index b8518ea64f..44993db71b 100644 --- a/InvenTree/part/views.py +++ b/InvenTree/part/views.py @@ -60,14 +60,13 @@ class PartIndex(InvenTreeRoleMixin, InvenTreePluginViewMixin, ListView): class PartImport(FileManagementFormView): """Part: Upload file, match to fields and import parts(using multi-Step form)""" + permission_required = 'part.add' class PartFileManager(FileManager): """Import field definitions""" - REQUIRED_HEADERS = [ - 'Name', - 'Description', - ] + + REQUIRED_HEADERS = ['Name', 'Description'] OPTIONAL_MATCH_HEADERS = [ 'Category', @@ -106,9 +105,9 @@ class PartImport(FileManagementFormView): 'part/import_wizard/match_references.html', ] form_steps_description = [ - _("Upload File"), - _("Match Fields"), - _("Match References"), + _('Upload File'), + _('Match Fields'), + _('Match References'), ] form_field_map = { @@ -147,9 +146,13 @@ class PartImport(FileManagementFormView): self.allowed_items = {} self.matches = {} - self.allowed_items['Category'] = PartCategory.objects.all().exclude(structural=True) + self.allowed_items['Category'] = PartCategory.objects.all().exclude( + structural=True + ) self.matches['Category'] = ['name__icontains'] - self.allowed_items['default_location'] = StockLocation.objects.all().exclude(structural=True) + self.allowed_items['default_location'] = StockLocation.objects.all().exclude( + structural=True + ) self.matches['default_location'] = ['name__icontains'] self.allowed_items['default_supplier'] = SupplierPart.objects.all() self.matches['default_supplier'] = ['SKU__icontains'] @@ -169,7 +172,6 @@ class PartImport(FileManagementFormView): for row in self.rows: # check each submitted column for idx in col_ids: - try: data = row['data'][col_ids[idx]]['cell'] except (IndexError, TypeError): @@ -177,8 +179,14 @@ class PartImport(FileManagementFormView): if idx in self.file_manager.OPTIONAL_MATCH_HEADERS: try: - exact_match = self.allowed_items[idx].get(**{a: data for a in self.matches[idx]}) - except (ValueError, self.allowed_items[idx].model.DoesNotExist, self.allowed_items[idx].model.MultipleObjectsReturned): + exact_match = self.allowed_items[idx].get(**{ + a: data for a in self.matches[idx] + }) + except ( + ValueError, + self.allowed_items[idx].model.DoesNotExist, + self.allowed_items[idx].model.MultipleObjectsReturned, + ): exact_match = None row['match_options_' + idx] = self.allowed_items[idx] @@ -197,14 +205,19 @@ class PartImport(FileManagementFormView): # Create Part instances for part_data in items.values(): - # set related parts optional_matches = {} for idx in self.file_manager.OPTIONAL_MATCH_HEADERS: if idx.lower() in part_data: try: - optional_matches[idx] = self.allowed_items[idx].get(pk=int(part_data[idx.lower()])) - except (ValueError, self.allowed_items[idx].model.DoesNotExist, self.allowed_items[idx].model.MultipleObjectsReturned): + optional_matches[idx] = self.allowed_items[idx].get( + pk=int(part_data[idx.lower()]) + ) + except ( + ValueError, + self.allowed_items[idx].model.DoesNotExist, + self.allowed_items[idx].model.MultipleObjectsReturned, + ): optional_matches[idx] = None else: optional_matches[idx] = None @@ -228,19 +241,39 @@ class PartImport(FileManagementFormView): active=str2bool(part_data.get('active', True)), base_cost=str2int(part_data.get('base_cost'), 0), multiple=str2int(part_data.get('multiple'), 1), - assembly=str2bool(part_data.get('assembly', part_settings.part_assembly_default())), - component=str2bool(part_data.get('component', part_settings.part_component_default())), - is_template=str2bool(part_data.get('is_template', part_settings.part_template_default())), - purchaseable=str2bool(part_data.get('purchaseable', part_settings.part_purchaseable_default())), - salable=str2bool(part_data.get('salable', part_settings.part_salable_default())), - trackable=str2bool(part_data.get('trackable', part_settings.part_trackable_default())), - virtual=str2bool(part_data.get('virtual', part_settings.part_virtual_default())), + assembly=str2bool( + part_data.get('assembly', part_settings.part_assembly_default()) + ), + component=str2bool( + part_data.get('component', part_settings.part_component_default()) + ), + is_template=str2bool( + part_data.get('is_template', part_settings.part_template_default()) + ), + purchaseable=str2bool( + part_data.get( + 'purchaseable', part_settings.part_purchaseable_default() + ) + ), + salable=str2bool( + part_data.get('salable', part_settings.part_salable_default()) + ), + trackable=str2bool( + part_data.get('trackable', part_settings.part_trackable_default()) + ), + virtual=str2bool( + part_data.get('virtual', part_settings.part_virtual_default()) + ), image=part_data.get('image', None), ) # check if there's a category assigned, if not skip this part or else bad things happen if not optional_matches['Category']: - import_error.append(_(f"Can't import part {new_part.name} because there is no category assigned")) + import_error.append( + _( + f"Can't import part {new_part.name} because there is no category assigned" + ) + ) continue try: @@ -264,8 +297,14 @@ class PartImport(FileManagementFormView): alert = f"{_('Part-Import')}
{_(f'Imported {import_done} parts')}" messages.success(self.request, alert) if import_error: - error_text = '\n'.join([f'
  • {import_error.count(a)}: {a}
  • ' for a in set(import_error)]) - messages.error(self.request, f"{_('Some errors occurred:')}
      {error_text}
    ") + error_text = '\n'.join([ + f'
  • {import_error.count(a)}: {a}
  • ' + for a in set(import_error) + ]) + messages.error( + self.request, + f"{_('Some errors occurred:')}
      {error_text}
    ", + ) return HttpResponseRedirect(reverse('part-index')) @@ -285,6 +324,7 @@ class PartImportTemplate(AjaxView): class PartImportAjax(FileManagementAjaxView, PartImport): """Multi-step form wizard for importing Part data""" + ajax_form_steps_template = [ 'part/import_wizard/ajax_part_upload.html', 'part/import_wizard/ajax_match_fields.html', @@ -383,9 +423,7 @@ class PartImageSelect(AjaxUpdateView): ajax_template_name = 'part/select_image.html' ajax_form_title = _('Select Part Image') - fields = [ - 'image', - ] + fields = ['image'] def post(self, request, *args, **kwargs): """Perform POST action to assign selected image to the Part instance""" @@ -403,7 +441,6 @@ class PartImageSelect(AjaxUpdateView): # Ensure that the image already exists if os.path.exists(img_path): - part.image = os.path.join(PART_IMAGE_DIR, img) part.save() @@ -481,31 +518,30 @@ class BomDownload(AjaxView): if not IsValidBOMFormat(export_format): export_format = 'csv' - return ExportBom(part, - fmt=export_format, - cascade=cascade, - max_levels=levels, - parameter_data=parameter_data, - stock_data=stock_data, - supplier_data=supplier_data, - manufacturer_data=manufacturer_data, - pricing_data=pricing_data, - substitute_part_data=substitute_part_data, - ) + return ExportBom( + part, + fmt=export_format, + cascade=cascade, + max_levels=levels, + parameter_data=parameter_data, + stock_data=stock_data, + supplier_data=supplier_data, + manufacturer_data=manufacturer_data, + pricing_data=pricing_data, + substitute_part_data=substitute_part_data, + ) def get_data(self): """Return a custom message""" - return { - 'info': 'Exported BOM' - } + return {'info': 'Exported BOM'} class PartPricing(AjaxView): """View for inspecting part pricing information.""" model = Part - ajax_template_name = "part/part_pricing.html" - ajax_form_title = _("Part Pricing") + ajax_template_name = 'part/part_pricing.html' + ajax_form_title = _('Part Pricing') form_class = part_forms.PartPriceForm role_required = ['sales_order.view', 'part.view'] @@ -534,11 +570,7 @@ class PartPricing(AjaxView): part = self.get_part() - ctx = { - 'part': part, - 'quantity': quantity, - 'currency': currency, - } + ctx = {'part': part, 'quantity': quantity, 'currency': currency} if part is None: return ctx @@ -569,8 +601,9 @@ class PartPricing(AjaxView): # BOM pricing information if part.bom_count > 0: - - use_internal = InvenTreeSetting.get_setting('PART_BOM_USE_INTERNAL_PRICE', False) + use_internal = InvenTreeSetting.get_setting( + 'PART_BOM_USE_INTERNAL_PRICE', False + ) bom_price = part.get_bom_price_range(quantity, internal=use_internal) purchase_price = part.get_bom_price_range(quantity, purchase=True) @@ -594,12 +627,20 @@ class PartPricing(AjaxView): min_bom_purchase_price /= scaler max_bom_purchase_price /= scaler if min_bom_purchase_price: - ctx['min_total_bom_purchase_price'] = round(min_bom_purchase_price, 3) - ctx['min_unit_bom_purchase_price'] = round(min_bom_purchase_price / quantity, 3) + ctx['min_total_bom_purchase_price'] = round( + min_bom_purchase_price, 3 + ) + ctx['min_unit_bom_purchase_price'] = round( + min_bom_purchase_price / quantity, 3 + ) if max_bom_purchase_price: - ctx['max_total_bom_purchase_price'] = round(max_bom_purchase_price, 3) - ctx['max_unit_bom_purchase_price'] = round(max_bom_purchase_price / quantity, 3) + ctx['max_total_bom_purchase_price'] = round( + max_bom_purchase_price, 3 + ) + ctx['max_unit_bom_purchase_price'] = round( + max_bom_purchase_price / quantity, 3 + ) # internal part pricing information internal_part_price = part.get_internal_price(quantity) @@ -624,7 +665,9 @@ class PartPricing(AjaxView): init = self.get_initials() qty = self.get_quantity() - return self.renderJsonResponse(request, self.form_class(initial=init), context=self.get_pricing(qty)) + return self.renderJsonResponse( + request, self.form_class(initial=init), context=self.get_pricing(qty) + ) def post(self, request, *args, **kwargs): """Perform custom POST action for this view""" @@ -647,7 +690,9 @@ class PartPricing(AjaxView): # Always mark the form as 'invalid' (the user may wish to keep getting pricing data) data['form_valid'] = False - return self.renderJsonResponse(request, form, data=data, context=self.get_pricing(quantity, currency)) + return self.renderJsonResponse( + request, form, data=data, context=self.get_pricing(quantity, currency) + ) class CategoryDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): @@ -676,11 +721,9 @@ class CategoryDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): category = kwargs.get('object', None) if category: - # Insert "starred" information context['starred_directly'] = category.is_starred_by( - self.request.user, - include_parents=False, + self.request.user, include_parents=False ) if context['starred_directly']: diff --git a/InvenTree/plugin/__init__.py b/InvenTree/plugin/__init__.py index db8a8ea10a..63a634fe20 100644 --- a/InvenTree/plugin/__init__.py +++ b/InvenTree/plugin/__init__.py @@ -6,7 +6,6 @@ from .registry import registry __all__ = [ 'registry', - 'InvenTreePlugin', 'MixinNotImplementedError', 'MixinImplementationError', diff --git a/InvenTree/plugin/admin.py b/InvenTree/plugin/admin.py index 42f02c0968..102a599da9 100644 --- a/InvenTree/plugin/admin.py +++ b/InvenTree/plugin/admin.py @@ -39,9 +39,7 @@ class PluginSettingInline(admin.TabularInline): model = models.PluginSetting - read_only_fields = [ - 'key', - ] + read_only_fields = ['key'] def has_add_permission(self, request, obj): """The plugin settings should not be meddled with manually.""" @@ -51,12 +49,20 @@ class PluginSettingInline(admin.TabularInline): class PluginConfigAdmin(admin.ModelAdmin): """Custom admin with restricted id fields.""" - readonly_fields = ["key", "name", ] - list_display = ['name', 'key', '__str__', 'active', 'is_builtin', 'is_sample', 'is_installed'] + readonly_fields = ['key', 'name'] + list_display = [ + 'name', + 'key', + '__str__', + 'active', + 'is_builtin', + 'is_sample', + 'is_installed', + ] list_filter = ['active'] - actions = [plugin_activate, plugin_deactivate, ] - inlines = [PluginSettingInline, ] - exclude = ['metadata', ] + actions = [plugin_activate, plugin_deactivate] + inlines = [PluginSettingInline] + exclude = ['metadata'] class NotificationUserSettingAdmin(admin.ModelAdmin): @@ -64,9 +70,7 @@ class NotificationUserSettingAdmin(admin.ModelAdmin): model = models.NotificationUserSetting - read_only_fields = [ - 'key', - ] + read_only_fields = ['key'] def has_add_permission(self, request): """Notifications should not be changed.""" diff --git a/InvenTree/plugin/api.py b/InvenTree/plugin/api.py index 4e04b0b5b3..8f5d7de655 100644 --- a/InvenTree/plugin/api.py +++ b/InvenTree/plugin/api.py @@ -14,8 +14,13 @@ from common.api import GlobalSettingsPermissions from InvenTree.api import MetadataView from InvenTree.filters import SEARCH_ORDER_FILTER from InvenTree.helpers import str2bool -from InvenTree.mixins import (CreateAPI, ListAPI, RetrieveUpdateAPI, - RetrieveUpdateDestroyAPI, UpdateAPI) +from InvenTree.mixins import ( + CreateAPI, + ListAPI, + RetrieveUpdateAPI, + RetrieveUpdateDestroyAPI, + UpdateAPI, +) from InvenTree.permissions import IsSuperuser from plugin import registry from plugin.base.action.api import ActionPluginView @@ -102,26 +107,13 @@ class PluginList(ListAPI): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'active', - ] + filterset_fields = ['active'] - ordering_fields = [ - 'key', - 'name', - 'active', - ] + ordering_fields = ['key', 'name', 'active'] - ordering = [ - '-active', - 'name', - 'key', - ] + ordering = ['-active', 'name', 'key'] - search_fields = [ - 'key', - 'name', - ] + search_fields = ['key', 'name'] class PluginDetail(RetrieveUpdateDestroyAPI): @@ -176,7 +168,7 @@ class PluginActivate(UpdateAPI): queryset = PluginConfig.objects.all() serializer_class = PluginSerializers.PluginActivateSerializer - permission_classes = [IsSuperuser, ] + permission_classes = [IsSuperuser] def get_object(self): """Returns the object for the view.""" @@ -194,7 +186,7 @@ class PluginReload(CreateAPI): queryset = PluginConfig.objects.none() serializer_class = PluginSerializers.PluginReloadSerializer - permission_classes = [IsSuperuser,] + permission_classes = [IsSuperuser] def perform_create(self, serializer): """Saving the serializer instance performs plugin installation""" @@ -211,18 +203,11 @@ class PluginSettingList(ListAPI): queryset = PluginSetting.objects.all() serializer_class = PluginSerializers.PluginSettingSerializer - permission_classes = [ - GlobalSettingsPermissions, - ] + permission_classes = [GlobalSettingsPermissions] - filter_backends = [ - DjangoFilterBackend, - ] + filter_backends = [DjangoFilterBackend] - filterset_fields = [ - 'plugin__active', - 'plugin__key', - ] + filterset_fields = ['plugin__active', 'plugin__key'] def check_plugin(plugin_slug: str, plugin_pk: int) -> InvenTreePlugin: @@ -242,7 +227,7 @@ def check_plugin(plugin_slug: str, plugin_pk: int) -> InvenTreePlugin: """ # Make sure that a plugin reference is specified if plugin_slug is None and plugin_pk is None: - raise NotFound(detail="Plugin not specified") + raise NotFound(detail='Plugin not specified') # Define filter filter = {} @@ -282,7 +267,9 @@ class PluginAllSettingList(APIView): permission_classes = [GlobalSettingsPermissions] - @extend_schema(responses={200: PluginSerializers.PluginSettingSerializer(many=True)}) + @extend_schema( + responses={200: PluginSerializers.PluginSettingSerializer(many=True)} + ) def get(self, request, pk): """Get all settings for a plugin config.""" @@ -291,9 +278,13 @@ class PluginAllSettingList(APIView): settings = getattr(plugin, 'settings', {}) - settings_dict = PluginSetting.all_settings(settings_definition=settings, plugin=plugin.plugin_config()) + settings_dict = PluginSetting.all_settings( + settings_definition=settings, plugin=plugin.plugin_config() + ) - results = PluginSerializers.PluginSettingSerializer(list(settings_dict.values()), many=True).data + results = PluginSerializers.PluginSettingSerializer( + list(settings_dict.values()), many=True + ).data return Response(results) @@ -315,19 +306,21 @@ class PluginSettingDetail(RetrieveUpdateAPI): key = self.kwargs['key'] # Look up plugin - plugin = check_plugin(plugin_slug=self.kwargs.get('plugin'), plugin_pk=self.kwargs.get('pk')) + plugin = check_plugin( + plugin_slug=self.kwargs.get('plugin'), plugin_pk=self.kwargs.get('pk') + ) settings = getattr(plugin, 'settings', {}) if key not in settings: - raise NotFound(detail=f"Plugin '{plugin.slug}' has no setting matching '{key}'") + raise NotFound( + detail=f"Plugin '{plugin.slug}' has no setting matching '{key}'" + ) return PluginSetting.get_setting_object(key, plugin=plugin.plugin_config()) # Staff permission required - permission_classes = [ - GlobalSettingsPermissions, - ] + permission_classes = [GlobalSettingsPermissions] class RegistryStatusView(APIView): @@ -336,7 +329,7 @@ class RegistryStatusView(APIView): - GET: Provide status data for the plugin registry """ - permission_classes = [IsSuperuser, ] + permission_classes = [IsSuperuser] serializer_class = PluginSerializers.PluginRegistryStatusSerializer @@ -349,13 +342,13 @@ class RegistryStatusView(APIView): for error_detail in errors: for name, message in error_detail.items(): error_list.append({ - "stage": stage, - "name": name, - "message": message, + 'stage': stage, + 'name': name, + 'message': message, }) result = PluginSerializers.PluginRegistryStatusSerializer({ - "registry_errors": error_list, + 'registry_errors': error_list }).data return Response(result) @@ -365,35 +358,73 @@ plugin_api_urls = [ re_path(r'^action/', ActionPluginView.as_view(), name='api-action-plugin'), re_path(r'^barcode/', include(barcode_api_urls)), re_path(r'^locate/', LocatePluginView.as_view(), name='api-locate-plugin'), - re_path(r'^plugins/', include([ - # Plugin settings URLs - re_path(r'^settings/', include([ - re_path(r'^(?P[-\w]+)/(?P\w+)/', PluginSettingDetail.as_view(), name='api-plugin-setting-detail'), # Used for admin interface - re_path(r'^.*$', PluginSettingList.as_view(), name='api-plugin-setting-list'), - ])), - - # Detail views for a single PluginConfig item - path(r'/', include([ - re_path(r"^settings/", include([ - re_path(r'^(?P\w+)/', PluginSettingDetail.as_view(), name='api-plugin-setting-detail-pk'), - re_path(r"^.*$", PluginAllSettingList.as_view(), name="api-plugin-settings"), - ])), - re_path(r'^activate/', PluginActivate.as_view(), name='api-plugin-detail-activate'), - re_path(r'^.*$', PluginDetail.as_view(), name='api-plugin-detail'), - ])), - - # Metadata - re_path('^metadata/', MetadataView.as_view(), {'model': PluginConfig}, name='api-plugin-metadata'), - - # Plugin management - re_path(r'^reload/', PluginReload.as_view(), name='api-plugin-reload'), - re_path(r'^install/', PluginInstall.as_view(), name='api-plugin-install'), - re_path(r'^activate/', PluginActivate.as_view(), name='api-plugin-activate'), - - # Registry status - re_path(r"^status/", RegistryStatusView.as_view(), name="api-plugin-registry-status"), - - # Anything else - re_path(r'^.*$', PluginList.as_view(), name='api-plugin-list'), - ])) + re_path( + r'^plugins/', + include([ + # Plugin settings URLs + re_path( + r'^settings/', + include([ + re_path( + r'^(?P[-\w]+)/(?P\w+)/', + PluginSettingDetail.as_view(), + name='api-plugin-setting-detail', + ), # Used for admin interface + re_path( + r'^.*$', + PluginSettingList.as_view(), + name='api-plugin-setting-list', + ), + ]), + ), + # Detail views for a single PluginConfig item + path( + r'/', + include([ + re_path( + r'^settings/', + include([ + re_path( + r'^(?P\w+)/', + PluginSettingDetail.as_view(), + name='api-plugin-setting-detail-pk', + ), + re_path( + r'^.*$', + PluginAllSettingList.as_view(), + name='api-plugin-settings', + ), + ]), + ), + re_path( + r'^activate/', + PluginActivate.as_view(), + name='api-plugin-detail-activate', + ), + re_path(r'^.*$', PluginDetail.as_view(), name='api-plugin-detail'), + ]), + ), + # Metadata + re_path( + '^metadata/', + MetadataView.as_view(), + {'model': PluginConfig}, + name='api-plugin-metadata', + ), + # Plugin management + re_path(r'^reload/', PluginReload.as_view(), name='api-plugin-reload'), + re_path(r'^install/', PluginInstall.as_view(), name='api-plugin-install'), + re_path( + r'^activate/', PluginActivate.as_view(), name='api-plugin-activate' + ), + # Registry status + re_path( + r'^status/', + RegistryStatusView.as_view(), + name='api-plugin-registry-status', + ), + # Anything else + re_path(r'^.*$', PluginList.as_view(), name='api-plugin-list'), + ]), + ), ] diff --git a/InvenTree/plugin/apps.py b/InvenTree/plugin/apps.py index 37d7dafab2..affc62587d 100644 --- a/InvenTree/plugin/apps.py +++ b/InvenTree/plugin/apps.py @@ -27,8 +27,10 @@ class PluginAppConfig(AppConfig): if not isInMainThread(): return - if not canAppAccessDatabase(allow_test=True, allow_plugins=True, allow_shell=True): - logger.info("Skipping plugin loading sequence") # pragma: no cover + if not canAppAccessDatabase( + allow_test=True, allow_plugins=True, allow_shell=True + ): + logger.info('Skipping plugin loading sequence') # pragma: no cover else: logger.info('Loading InvenTree plugins') @@ -36,14 +38,19 @@ class PluginAppConfig(AppConfig): # this is the first startup try: from common.models import InvenTreeSetting - if InvenTreeSetting.get_setting('PLUGIN_ON_STARTUP', create=False, cache=False): + + if InvenTreeSetting.get_setting( + 'PLUGIN_ON_STARTUP', create=False, cache=False + ): # make sure all plugins are installed registry.install_plugin_file() except Exception: # pragma: no cover pass # Perform a full reload of the plugin registry - registry.reload_plugins(full_reload=True, force_reload=True, collect=True) + registry.reload_plugins( + full_reload=True, force_reload=True, collect=True + ) # drop out of maintenance # makes sure we did not have an error in reloading and maintenance is still active diff --git a/InvenTree/plugin/base/action/api.py b/InvenTree/plugin/base/action/api.py index 735b46f65b..37dabd9eda 100644 --- a/InvenTree/plugin/base/action/api.py +++ b/InvenTree/plugin/base/action/api.py @@ -12,9 +12,7 @@ from plugin import registry class ActionPluginView(APIView): """Endpoint for running custom action plugins.""" - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] def post(self, request, *args, **kwargs): """This function checks if all required info was submitted and then performs a plugin_action or returns an error.""" @@ -23,9 +21,7 @@ class ActionPluginView(APIView): data = request.data.get('data', None) if action is None: - return Response({ - 'error': _("No action specified") - }) + return Response({'error': _('No action specified')}) action_plugins = registry.with_mixin('action') for plugin in action_plugins: @@ -34,7 +30,4 @@ class ActionPluginView(APIView): return Response(plugin.get_response(request.user, data=data)) # If we got to here, no matching action was found - return Response({ - 'error': _("No matching action found"), - "action": action, - }) + return Response({'error': _('No matching action found'), 'action': action}) diff --git a/InvenTree/plugin/base/action/mixins.py b/InvenTree/plugin/base/action/mixins.py index b978a69189..e4931c1012 100644 --- a/InvenTree/plugin/base/action/mixins.py +++ b/InvenTree/plugin/base/action/mixins.py @@ -4,7 +4,7 @@ class ActionMixin: """Mixin that enables custom actions.""" - ACTION_NAME = "" + ACTION_NAME = '' class MixinMeta: """Meta options for this mixin.""" @@ -47,7 +47,7 @@ class ActionMixin: Default implementation is a simple response which can be overridden. """ return { - "action": self.action_name(), - "result": self.get_result(user, data), - "info": self.get_info(user, data), + 'action': self.action_name(), + 'result': self.get_result(user, data), + 'info': self.get_info(user, data), } diff --git a/InvenTree/plugin/base/action/test_action.py b/InvenTree/plugin/base/action/test_action.py index 2a834ae63f..b59b82548e 100644 --- a/InvenTree/plugin/base/action/test_action.py +++ b/InvenTree/plugin/base/action/test_action.py @@ -17,12 +17,15 @@ class ActionMixinTests(TestCase): Contains multiple sample plugins that are used in the tests """ + class SimplePlugin(ActionMixin, InvenTreePlugin): pass + self.plugin = SimplePlugin() class TestActionPlugin(ActionMixin, InvenTreePlugin): """An action plugin.""" + ACTION_NAME = 'abc123' def perform_action(self, user=None, data=None): @@ -53,21 +56,24 @@ class ActionMixinTests(TestCase): self.assertIsNone(self.plugin.perform_action()) self.assertEqual(self.plugin.get_result(), False) self.assertIsNone(self.plugin.get_info()) - self.assertEqual(self.plugin.get_response(), { - "action": '', - "result": False, - "info": None, - }) + self.assertEqual( + self.plugin.get_response(), {'action': '', 'result': False, 'info': None} + ) # overridden functions - self.assertEqual(self.action_plugin.perform_action(), self.ACTION_RETURN + 'action') + self.assertEqual( + self.action_plugin.perform_action(), self.ACTION_RETURN + 'action' + ) self.assertEqual(self.action_plugin.get_result(), self.ACTION_RETURN + 'result') self.assertEqual(self.action_plugin.get_info(), self.ACTION_RETURN + 'info') - self.assertEqual(self.action_plugin.get_response(), { - "action": 'abc123', - "result": self.ACTION_RETURN + 'result', - "info": self.ACTION_RETURN + 'info', - }) + self.assertEqual( + self.action_plugin.get_response(), + { + 'action': 'abc123', + 'result': self.ACTION_RETURN + 'result', + 'info': self.ACTION_RETURN + 'info', + }, + ) class APITests(InvenTreeTestCase): @@ -78,15 +84,12 @@ class APITests(InvenTreeTestCase): # Test empty request response = self.client.post('/api/action/') self.assertEqual(response.status_code, 200) - self.assertEqual( - response.data, - {'error': 'No action specified'} - ) + self.assertEqual(response.data, {'error': 'No action specified'}) # Test non-existing action - response = self.client.post('/api/action/', data={'action': "nonexisting"}) + response = self.client.post('/api/action/', data={'action': 'nonexisting'}) self.assertEqual(response.status_code, 200) self.assertEqual( response.data, - {'error': 'No matching action found', 'action': 'nonexisting'} + {'error': 'No matching action found', 'action': 'nonexisting'}, ) diff --git a/InvenTree/plugin/base/barcodes/api.py b/InvenTree/plugin/base/barcodes/api.py index e46253bd62..3a829b9752 100644 --- a/InvenTree/plugin/base/barcodes/api.py +++ b/InvenTree/plugin/base/barcodes/api.py @@ -15,8 +15,7 @@ import order.models import stock.models from InvenTree.helpers import hash_barcode from plugin import registry -from plugin.builtin.barcodes.inventree_barcode import \ - InvenTreeInternalBarcodePlugin +from plugin.builtin.barcodes.inventree_barcode import InvenTreeInternalBarcodePlugin from users.models import RuleSet from . import serializers as barcode_serializers @@ -35,9 +34,7 @@ class BarcodeView(CreateAPIView): return None # Default permission classes (can be overridden) - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] def create(self, request, *args, **kwargs): """Handle create method - override default create""" @@ -60,7 +57,9 @@ class BarcodeView(CreateAPIView): kwargs: Any custom fields passed by the specific serializer """ - raise NotImplementedError(f"handle_barcode not implemented for {self.__class__}") + raise NotImplementedError( + f'handle_barcode not implemented for {self.__class__}' + ) def scan_barcode(self, barcode: str, request, **kwargs): """Perform a generic 'scan' of the provided barcode data. @@ -75,15 +74,17 @@ class BarcodeView(CreateAPIView): response = {} for current_plugin in plugins: - result = current_plugin.scan(barcode) if result is None: continue - if "error" in result: - logger.info("%s.scan(...) returned an error: %s", - current_plugin.__class__.__name__, result["error"]) + if 'error' in result: + logger.info( + '%s.scan(...) returned an error: %s', + current_plugin.__class__.__name__, + result['error'], + ) if not response: plugin = current_plugin response = result @@ -154,9 +155,9 @@ class BarcodeAssign(BarcodeView): result = plugin.scan(barcode) if result is not None: - result["error"] = _("Barcode matches existing item") - result["plugin"] = plugin.name - result["barcode_data"] = barcode + result['error'] = _('Barcode matches existing item') + result['plugin'] = plugin.name + result['barcode_data'] = barcode raise ValidationError(result) @@ -169,35 +170,29 @@ class BarcodeAssign(BarcodeView): valid_labels.append(label) if instance := kwargs.get(label, None): - # Check that the user has the required permission app_label = model._meta.app_label model_name = model._meta.model_name - table = f"{app_label}_{model_name}" + table = f'{app_label}_{model_name}' - if not RuleSet.check_table_permission(request.user, table, "change"): + if not RuleSet.check_table_permission(request.user, table, 'change'): raise PermissionDenied({ - "error": f"You do not have the required permissions for {table}" + 'error': f'You do not have the required permissions for {table}' }) - instance.assign_barcode( - barcode_data=barcode, - barcode_hash=barcode_hash, - ) + instance.assign_barcode(barcode_data=barcode, barcode_hash=barcode_hash) return Response({ - 'success': f"Assigned barcode to {label} instance", - label: { - 'pk': instance.pk, - }, - "barcode_data": barcode, - "barcode_hash": barcode_hash, + 'success': f'Assigned barcode to {label} instance', + label: {'pk': instance.pk}, + 'barcode_data': barcode, + 'barcode_hash': barcode_hash, }) # If we got here, it means that no valid model types were provided raise ValidationError({ - 'error': f"Missing data: provide one of '{valid_labels}'", + 'error': f"Missing data: provide one of '{valid_labels}'" }) @@ -231,38 +226,34 @@ class BarcodeUnassign(BarcodeView): if len(matched_labels) > 1: raise ValidationError({ - 'error': f"Multiple conflicting fields: '{model_names}'", + 'error': f"Multiple conflicting fields: '{model_names}'" }) # At this stage, we know that we have received a single valid field for model in supported_models: - label = model.barcode_model_type() if instance := data.get(label, None): - # Check that the user has the required permission app_label = model._meta.app_label model_name = model._meta.model_name - table = f"{app_label}_{model_name}" + table = f'{app_label}_{model_name}' - if not RuleSet.check_table_permission(request.user, table, "change"): + if not RuleSet.check_table_permission(request.user, table, 'change'): raise PermissionDenied({ - "error": f"You do not have the required permissions for {table}" + 'error': f'You do not have the required permissions for {table}' }) # Unassign the barcode data from the model instance instance.unassign_barcode() return Response({ - 'success': f'Barcode unassigned from {label} instance', + 'success': f'Barcode unassigned from {label} instance' }) # If we get to this point, something has gone wrong! - raise ValidationError({ - 'error': 'Could not unassign barcode', - }) + raise ValidationError({'error': 'Could not unassign barcode'}) class BarcodePOAllocate(BarcodeView): @@ -275,13 +266,13 @@ class BarcodePOAllocate(BarcodeView): - A SupplierPart object """ - role_required = [ - 'purchase_order.add' - ] + role_required = ['purchase_order.add'] serializer_class = barcode_serializers.BarcodePOAllocateSerializer - def get_supplier_part(self, purchase_order, part=None, supplier_part=None, manufacturer_part=None): + def get_supplier_part( + self, purchase_order, part=None, supplier_part=None, manufacturer_part=None + ): """Return a single matching SupplierPart (or else raise an exception) Arguments: @@ -305,9 +296,7 @@ class BarcodePOAllocate(BarcodeView): supplier_parts = company.models.SupplierPart.objects.filter(supplier=supplier) if not part and not supplier_part and not manufacturer_part: - raise ValidationError({ - 'error': _('No matching part data found'), - }) + raise ValidationError({'error': _('No matching part data found')}) if part: if part_id := part.get('pk', None): @@ -319,16 +308,16 @@ class BarcodePOAllocate(BarcodeView): if manufacturer_part: if manufacturer_part_id := manufacturer_part.get('pk', None): - supplier_parts = supplier_parts.filter(manufacturer_part__pk=manufacturer_part_id) + supplier_parts = supplier_parts.filter( + manufacturer_part__pk=manufacturer_part_id + ) if supplier_parts.count() == 0: - raise ValidationError({ - "error": _("No matching supplier parts found") - }) + raise ValidationError({'error': _('No matching supplier parts found')}) if supplier_parts.count() > 1: raise ValidationError({ - "error": _("Multiple matching supplier parts found") + 'error': _('Multiple matching supplier parts found') }) # At this stage, we have a single matching supplier part @@ -353,7 +342,7 @@ class BarcodePOAllocate(BarcodeView): manufacturer_part=result.get('manufacturerpart', None), ) - result['success'] = _("Matched supplier part") + result['success'] = _('Matched supplier part') result['supplierpart'] = supplier_part.format_matched_response() # TODO: Determine the 'quantity to order' for the supplier part @@ -377,9 +366,7 @@ class BarcodePOReceive(BarcodeView): - location: The destination location for the received item (optional) """ - role_required = [ - 'purchase_order.add' - ] + role_required = ['purchase_order.add'] serializer_class = barcode_serializers.BarcodePOReceiveSerializer @@ -392,45 +379,41 @@ class BarcodePOReceive(BarcodeView): purchase_order = kwargs.get('purchase_order', None) location = kwargs.get('location', None) - plugins = registry.with_mixin("barcode") + plugins = registry.with_mixin('barcode') # Look for a barcode plugin which knows how to deal with this barcode plugin = None - response = { - "barcode_data": barcode, - "barcode_hash": hash_barcode(barcode) - } + response = {'barcode_data': barcode, 'barcode_hash': hash_barcode(barcode)} - internal_barcode_plugin = next(filter( - lambda plugin: plugin.name == "InvenTreeBarcode", plugins - )) + internal_barcode_plugin = next( + filter(lambda plugin: plugin.name == 'InvenTreeBarcode', plugins) + ) if result := internal_barcode_plugin.scan(barcode): if 'stockitem' in result: - response["error"] = _("Item has already been received") + response['error'] = _('Item has already been received') raise ValidationError(response) # Now, look just for "supplier-barcode" plugins - plugins = registry.with_mixin("supplier-barcode") + plugins = registry.with_mixin('supplier-barcode') plugin_response = None for current_plugin in plugins: - result = current_plugin.scan_receive_item( - barcode, - request.user, - purchase_order=purchase_order, - location=location, + barcode, request.user, purchase_order=purchase_order, location=location ) if result is None: continue - if "error" in result: - logger.info("%s.scan_receive_item(...) returned an error: %s", - current_plugin.__class__.__name__, result["error"]) + if 'error' in result: + logger.info( + '%s.scan_receive_item(...) returned an error: %s', + current_plugin.__class__.__name__, + result['error'], + ) if not plugin_response: plugin = current_plugin plugin_response = result @@ -446,9 +429,9 @@ class BarcodePOReceive(BarcodeView): # A plugin has not been found! if plugin is None: - response["error"] = _("No match for supplier barcode") + response['error'] = _('No match for supplier barcode') raise ValidationError(response) - elif "error" in response: + elif 'error' in response: raise ValidationError(response) else: return Response(response) @@ -467,9 +450,7 @@ class BarcodeSOAllocate(BarcodeView): - Quantity """ - role_required = [ - 'sales_order.add', - ] + role_required = ['sales_order.add'] serializer_class = barcode_serializers.BarcodeSOAllocateSerializer @@ -488,20 +469,14 @@ class BarcodeSOAllocate(BarcodeView): # Find any matching line items for the stock item lines = order.models.SalesOrderLineItem.objects.filter( - order=sales_order, - part__in=parts, - shipped__lte=F('quantity'), + order=sales_order, part__in=parts, shipped__lte=F('quantity') ) if lines.count() > 1: - raise ValidationError({ - 'error': _('Multiple matching line items found'), - }) + raise ValidationError({'error': _('Multiple matching line items found')}) if lines.count() == 0: - raise ValidationError({ - 'error': _('No matching line item found'), - }) + raise ValidationError({'error': _('No matching line item found')}) return lines.first() @@ -513,14 +488,13 @@ class BarcodeSOAllocate(BarcodeView): if shipment := kwargs.get('shipment', None): if shipment.order != sales_order: raise ValidationError({ - 'error': _('Shipment does not match sales order'), + 'error': _('Shipment does not match sales order') }) return shipment shipments = order.models.SalesOrderShipment.objects.filter( - order=sales_order, - delivery_date=None + order=sales_order, delivery_date=None ) if shipments.count() == 1: @@ -579,7 +553,7 @@ class BarcodeSOAllocate(BarcodeView): 'sales_order': sales_order.pk if sales_order else None, 'line_item': line_item.pk if line_item else None, 'shipment': shipment.pk if shipment else None, - 'quantity': quantity + 'quantity': quantity, } if stock_item is not None and quantity is not None: @@ -590,10 +564,7 @@ class BarcodeSOAllocate(BarcodeView): # If we have sufficient information, we can allocate the stock item if all((x is not None for x in [line_item, sales_order, shipment, quantity])): order.models.SalesOrderAllocation.objects.create( - line=line_item, - shipment=shipment, - item=stock_item, - quantity=quantity, + line=line_item, shipment=shipment, item=stock_item, quantity=quantity ) response['success'] = _('Stock item allocated to sales order') @@ -609,19 +580,14 @@ class BarcodeSOAllocate(BarcodeView): barcode_api_urls = [ # Link a third-party barcode to an item (e.g. Part / StockItem / etc) path('link/', BarcodeAssign.as_view(), name='api-barcode-link'), - # Unlink a third-party barcode from an item path('unlink/', BarcodeUnassign.as_view(), name='api-barcode-unlink'), - # Receive a purchase order item by scanning its barcode - path("po-receive/", BarcodePOReceive.as_view(), name="api-barcode-po-receive"), - + path('po-receive/', BarcodePOReceive.as_view(), name='api-barcode-po-receive'), # Allocate parts to a purchase order by scanning their barcode - path("po-allocate/", BarcodePOAllocate.as_view(), name="api-barcode-po-allocate"), - + path('po-allocate/', BarcodePOAllocate.as_view(), name='api-barcode-po-allocate'), # Allocate stock to a sales order by scanning barcode - path("so-allocate/", BarcodeSOAllocate.as_view(), name="api-barcode-so-allocate"), - + path('so-allocate/', BarcodeSOAllocate.as_view(), name='api-barcode-so-allocate'), # Catch-all performs barcode 'scan' re_path(r'^.*$', BarcodeScan.as_view(), name='api-barcode-scan'), ] diff --git a/InvenTree/plugin/base/barcodes/mixins.py b/InvenTree/plugin/base/barcodes/mixins.py index 9576aef3a3..d6d60a30f2 100644 --- a/InvenTree/plugin/base/barcodes/mixins.py +++ b/InvenTree/plugin/base/barcodes/mixins.py @@ -23,7 +23,7 @@ class BarcodeMixin: Custom barcode plugins should use and extend this mixin as necessary. """ - ACTION_NAME = "" + ACTION_NAME = '' class MixinMeta: """Meta options for this mixin.""" @@ -62,19 +62,19 @@ class SupplierBarcodeMixin(BarcodeMixin): """ # Set of standard field names which can be extracted from the barcode - CUSTOMER_ORDER_NUMBER = "customer_order_number" - SUPPLIER_ORDER_NUMBER = "supplier_order_number" - PACKING_LIST_NUMBER = "packing_list_number" - SHIP_DATE = "ship_date" - CUSTOMER_PART_NUMBER = "customer_part_number" - SUPPLIER_PART_NUMBER = "supplier_part_number" - PURCHASE_ORDER_LINE = "purchase_order_line" - QUANTITY = "quantity" - DATE_CODE = "date_code" - LOT_CODE = "lot_code" - COUNTRY_OF_ORIGIN = "country_of_origin" - MANUFACTURER = "manufacturer" - MANUFACTURER_PART_NUMBER = "manufacturer_part_number" + CUSTOMER_ORDER_NUMBER = 'customer_order_number' + SUPPLIER_ORDER_NUMBER = 'supplier_order_number' + PACKING_LIST_NUMBER = 'packing_list_number' + SHIP_DATE = 'ship_date' + CUSTOMER_PART_NUMBER = 'customer_part_number' + SUPPLIER_PART_NUMBER = 'supplier_part_number' + PURCHASE_ORDER_LINE = 'purchase_order_line' + QUANTITY = 'quantity' + DATE_CODE = 'date_code' + LOT_CODE = 'lot_code' + COUNTRY_OF_ORIGIN = 'country_of_origin' + MANUFACTURER = 'manufacturer' + MANUFACTURER_PART_NUMBER = 'manufacturer_part_number' def __init__(self): """Register mixin.""" @@ -83,7 +83,7 @@ class SupplierBarcodeMixin(BarcodeMixin): def get_field_value(self, key, backup_value=None): """Return the value of a barcode field.""" - fields = getattr(self, "barcode_fields", None) or {} + fields = getattr(self, 'barcode_fields', None) or {} return fields.get(key, backup_value) @@ -124,7 +124,9 @@ class SupplierBarcodeMixin(BarcodeMixin): A dict object containing the barcode fields. """ - raise NotImplementedError("extract_barcode_fields must be implemented by each plugin") + raise NotImplementedError( + 'extract_barcode_fields must be implemented by each plugin' + ) def scan(self, barcode_data): """Try to match a supplier barcode to a supplier part.""" @@ -143,21 +145,19 @@ class SupplierBarcodeMixin(BarcodeMixin): ) if len(supplier_parts) > 1: - return {"error": _("Found multiple matching supplier parts for barcode")} + return {'error': _('Found multiple matching supplier parts for barcode')} elif not supplier_parts: return None supplier_part = supplier_parts[0] data = { - "pk": supplier_part.pk, - "api_url": f"{SupplierPart.get_api_url()}{supplier_part.pk}/", - "web_url": supplier_part.get_absolute_url(), + 'pk': supplier_part.pk, + 'api_url': f'{SupplierPart.get_api_url()}{supplier_part.pk}/', + 'web_url': supplier_part.get_absolute_url(), } - return { - SupplierPart.barcode_model_type(): data - } + return {SupplierPart.barcode_model_type(): data} def scan_receive_item(self, barcode_data, user, purchase_order=None, location=None): """Try to scan a supplier barcode to receive a purchase order item.""" @@ -178,7 +178,7 @@ class SupplierBarcodeMixin(BarcodeMixin): ) if len(supplier_parts) > 1: - return {"error": _("Found multiple matching supplier parts for barcode")} + return {'error': _('Found multiple matching supplier parts for barcode')} elif not supplier_parts: return None @@ -195,16 +195,18 @@ class SupplierBarcodeMixin(BarcodeMixin): order = self.customer_order_number or self.supplier_order_number if len(matching_orders) > 1: - return {"error": _(f"Found multiple purchase orders matching '{order}'")} + return { + 'error': _(f"Found multiple purchase orders matching '{order}'") + } if len(matching_orders) == 0: - return {"error": _(f"No matching purchase order for '{order}'")} + return {'error': _(f"No matching purchase order for '{order}'")} purchase_order = matching_orders.first() if supplier and purchase_order: if purchase_order.supplier != supplier: - return {"error": _("Purchase order does not match supplier")} + return {'error': _('Purchase order does not match supplier')} return self.receive_purchase_order_item( supplier_part, @@ -224,25 +226,27 @@ class SupplierBarcodeMixin(BarcodeMixin): if not isinstance(self, SettingsMixin): return None - if supplier_pk := self.get_setting("SUPPLIER_ID"): - if (supplier := Company.objects.get(pk=supplier_pk)): + if supplier_pk := self.get_setting('SUPPLIER_ID'): + if supplier := Company.objects.get(pk=supplier_pk): return supplier else: logger.error( - "No company with pk %d (set \"SUPPLIER_ID\" setting to a valid value)", - supplier_pk + 'No company with pk %d (set "SUPPLIER_ID" setting to a valid value)', + supplier_pk, ) return None - if not (supplier_name := getattr(self, "DEFAULT_SUPPLIER_NAME", None)): + if not (supplier_name := getattr(self, 'DEFAULT_SUPPLIER_NAME', None)): return None - suppliers = Company.objects.filter(name__icontains=supplier_name, is_supplier=True) + suppliers = Company.objects.filter( + name__icontains=supplier_name, is_supplier=True + ) if len(suppliers) != 1: return None - self.set_setting("SUPPLIER_ID", suppliers.first().pk) + self.set_setting('SUPPLIER_ID', suppliers.first().pk) return suppliers.first() @@ -256,21 +260,21 @@ class SupplierBarcodeMixin(BarcodeMixin): if it does not use the standard field names. """ return { - "K": cls.CUSTOMER_ORDER_NUMBER, - "1K": cls.SUPPLIER_ORDER_NUMBER, - "11K": cls.PACKING_LIST_NUMBER, - "6D": cls.SHIP_DATE, - "9D": cls.DATE_CODE, - "10D": cls.DATE_CODE, - "4K": cls.PURCHASE_ORDER_LINE, - "14K": cls.PURCHASE_ORDER_LINE, - "P": cls.SUPPLIER_PART_NUMBER, - "1P": cls.MANUFACTURER_PART_NUMBER, - "30P": cls.SUPPLIER_PART_NUMBER, - "1T": cls.LOT_CODE, - "4L": cls.COUNTRY_OF_ORIGIN, - "1V": cls.MANUFACTURER, - "Q": cls.QUANTITY, + 'K': cls.CUSTOMER_ORDER_NUMBER, + '1K': cls.SUPPLIER_ORDER_NUMBER, + '11K': cls.PACKING_LIST_NUMBER, + '6D': cls.SHIP_DATE, + '9D': cls.DATE_CODE, + '10D': cls.DATE_CODE, + '4K': cls.PURCHASE_ORDER_LINE, + '14K': cls.PURCHASE_ORDER_LINE, + 'P': cls.SUPPLIER_PART_NUMBER, + '1P': cls.MANUFACTURER_PART_NUMBER, + '30P': cls.SUPPLIER_PART_NUMBER, + '1T': cls.LOT_CODE, + '4L': cls.COUNTRY_OF_ORIGIN, + '1V': cls.MANUFACTURER, + 'Q': cls.QUANTITY, } @classmethod @@ -297,20 +301,22 @@ class SupplierBarcodeMixin(BarcodeMixin): for field in fields: for identifier, field_name in cls.ecia_field_map().items(): if field.startswith(identifier): - barcode_fields[field_name] = field[len(identifier):] + barcode_fields[field_name] = field[len(identifier) :] break return barcode_fields @staticmethod - def split_fields(barcode_data: str, delimiter: str = ',', header: str = '', trailer: str = '') -> list[str]: + def split_fields( + barcode_data: str, delimiter: str = ',', header: str = '', trailer: str = '' + ) -> list[str]: """Generic method for splitting barcode data into separate fields""" if header and barcode_data.startswith(header): - barcode_data = barcode_data[len(header):] + barcode_data = barcode_data[len(header) :] if trailer and barcode_data.endswith(trailer): - barcode_data = barcode_data[:-len(trailer)] + barcode_data = barcode_data[: -len(trailer)] return barcode_data.split(delimiter) @@ -318,10 +324,10 @@ class SupplierBarcodeMixin(BarcodeMixin): def parse_isoiec_15434_barcode2d(barcode_data: str) -> list[str]: """Parse a ISO/IEC 15434 barcode, returning the split data section.""" - OLD_MOUSER_HEADER = ">[)>06\x1D" - HEADER = "[)>\x1E06\x1D" - TRAILER = "\x1E\x04" - DELIMITER = "\x1D" + OLD_MOUSER_HEADER = '>[)>06\x1d' + HEADER = '[)>\x1e06\x1d' + TRAILER = '\x1e\x04' + DELIMITER = '\x1d' # Some old mouser barcodes start with this messed up header if barcode_data.startswith(OLD_MOUSER_HEADER): @@ -332,14 +338,13 @@ class SupplierBarcodeMixin(BarcodeMixin): return return SupplierBarcodeMixin.split_fields( - barcode_data, - delimiter=DELIMITER, - header=HEADER, - trailer=TRAILER, + barcode_data, delimiter=DELIMITER, header=HEADER, trailer=TRAILER ) @staticmethod - def get_purchase_orders(customer_order_number, supplier_order_number, supplier: Company = None): + def get_purchase_orders( + customer_order_number, supplier_order_number, supplier: Company = None + ): """Attempt to find a purchase order from the extracted customer and supplier order numbers""" orders = PurchaseOrder.objects.filter(status=PurchaseOrderStatus.PLACED.value) @@ -356,7 +361,9 @@ class SupplierBarcodeMixin(BarcodeMixin): if orders_union.count() == 1: return orders_union else: - orders_intersection = orders.filter(reference_filter & supplier_reference_filter) + orders_intersection = orders.filter( + reference_filter & supplier_reference_filter + ) return orders_intersection if orders_intersection else orders_union @staticmethod @@ -394,12 +401,12 @@ class SupplierBarcodeMixin(BarcodeMixin): @staticmethod def receive_purchase_order_item( - supplier_part: SupplierPart, - user: User, - quantity: Decimal | str = None, - purchase_order: PurchaseOrder = None, - location: StockLocation = None, - barcode: str = None, + supplier_part: SupplierPart, + user: User, + quantity: Decimal | str = None, + purchase_order: PurchaseOrder = None, + location: StockLocation = None, + barcode: str = None, ) -> dict: """Try to receive a purchase order item. @@ -419,7 +426,8 @@ class SupplierBarcodeMixin(BarcodeMixin): # find incomplete line_items that match the supplier_part line_items = purchase_order.lines.filter( - part=supplier_part.pk, quantity__gt=F("received")) + part=supplier_part.pk, quantity__gt=F('received') + ) if len(line_items) == 1 or not quantity: line_item = line_items[0] else: @@ -438,7 +446,7 @@ class SupplierBarcodeMixin(BarcodeMixin): line_item = line_items.first() if not line_item: - return {"error": _("Failed to find pending line item for supplier part")} + return {'error': _('Failed to find pending line item for supplier part')} no_stock_locations = False if not location: @@ -456,31 +464,26 @@ class SupplierBarcodeMixin(BarcodeMixin): no_stock_locations = True response = { - "lineitem": { - "pk": line_item.pk, - "purchase_order": purchase_order.pk, - } + 'lineitem': {'pk': line_item.pk, 'purchase_order': purchase_order.pk} } if quantity: - response["lineitem"]["quantity"] = quantity + response['lineitem']['quantity'] = quantity if location: - response["lineitem"]["location"] = location.pk + response['lineitem']['location'] = location.pk # if either the quantity is missing or no location is defined/found # -> return the line_item found, so the client can gather the missing # information and complete the action with an 'api-po-receive' call if not quantity or (not location and not no_stock_locations): - response["action_required"] = _("Further information required to receive line item") + response['action_required'] = _( + 'Further information required to receive line item' + ) return response purchase_order.receive_line_item( - line_item, - location, - quantity, - user, - barcode=barcode, + line_item, location, quantity, user, barcode=barcode ) - response["success"] = _("Received purchase order line item") + response['success'] = _('Received purchase order line item') return response diff --git a/InvenTree/plugin/base/barcodes/serializers.py b/InvenTree/plugin/base/barcodes/serializers.py index ef0b17d3d0..9482969545 100644 --- a/InvenTree/plugin/base/barcodes/serializers.py +++ b/InvenTree/plugin/base/barcodes/serializers.py @@ -8,8 +8,7 @@ from rest_framework import serializers import order.models import stock.models from InvenTree.status_codes import PurchaseOrderStatus, SalesOrderStatus -from plugin.builtin.barcodes.inventree_barcode import \ - InvenTreeInternalBarcodePlugin +from plugin.builtin.barcodes.inventree_barcode import InvenTreeInternalBarcodePlugin class BarcodeSerializer(serializers.Serializer): @@ -18,7 +17,8 @@ class BarcodeSerializer(serializers.Serializer): MAX_BARCODE_LENGTH = 4095 barcode = serializers.CharField( - required=True, help_text=_('Scanned barcode data'), + required=True, + help_text=_('Scanned barcode data'), max_length=MAX_BARCODE_LENGTH, ) @@ -32,17 +32,21 @@ class BarcodeAssignMixin(serializers.Serializer): super().__init__(*args, **kwargs) for model in InvenTreeInternalBarcodePlugin.get_supported_barcode_models(): - self.fields[model.barcode_model_type()] = serializers.PrimaryKeyRelatedField( - queryset=model.objects.all(), - required=False, allow_null=True, - label=model._meta.verbose_name, + self.fields[model.barcode_model_type()] = ( + serializers.PrimaryKeyRelatedField( + queryset=model.objects.all(), + required=False, + allow_null=True, + label=model._meta.verbose_name, + ) ) @staticmethod def get_model_fields(): """Return a list of model fields""" fields = [ - model.barcode_model_type() for model in InvenTreeInternalBarcodePlugin.get_supported_barcode_models() + model.barcode_model_type() + for model in InvenTreeInternalBarcodePlugin.get_supported_barcode_models() ] return fields @@ -54,10 +58,7 @@ class BarcodeAssignSerializer(BarcodeAssignMixin, BarcodeSerializer): class Meta: """Meta class for BarcodeAssignSerializer""" - fields = [ - 'barcode', - *BarcodeAssignMixin.get_model_fields() - ] + fields = ['barcode', *BarcodeAssignMixin.get_model_fields()] class BarcodeUnassignSerializer(BarcodeAssignMixin): @@ -85,7 +86,7 @@ class BarcodePOAllocateSerializer(BarcodeSerializer): """Validate the provided order""" if order.status != PurchaseOrderStatus.PENDING.value: - raise ValidationError(_("Purchase order is not pending")) + raise ValidationError(_('Purchase order is not pending')) return order @@ -101,7 +102,8 @@ class BarcodePOReceiveSerializer(BarcodeSerializer): purchase_order = serializers.PrimaryKeyRelatedField( queryset=order.models.PurchaseOrder.objects.all(), - required=False, allow_null=True, + required=False, + allow_null=True, help_text=_('PurchaseOrder to receive items against'), ) @@ -109,13 +111,14 @@ class BarcodePOReceiveSerializer(BarcodeSerializer): """Validate the provided order""" if order and order.status != PurchaseOrderStatus.PLACED.value: - raise ValidationError(_("Purchase order has not been placed")) + raise ValidationError(_('Purchase order has not been placed')) return order location = serializers.PrimaryKeyRelatedField( queryset=stock.models.StockLocation.objects.all(), - required=False, allow_null=True, + required=False, + allow_null=True, help_text=_('Location to receive items into'), ) @@ -123,7 +126,7 @@ class BarcodePOReceiveSerializer(BarcodeSerializer): """Validate the provided location""" if location and location.structural: - raise ValidationError(_("Cannot select a structural location")) + raise ValidationError(_('Cannot select a structural location')) return location @@ -144,19 +147,21 @@ class BarcodeSOAllocateSerializer(BarcodeSerializer): """Validate the provided order""" if order and order.status != SalesOrderStatus.PENDING.value: - raise ValidationError(_("Sales order is not pending")) + raise ValidationError(_('Sales order is not pending')) return order line = serializers.PrimaryKeyRelatedField( queryset=order.models.SalesOrderLineItem.objects.all(), - required=False, allow_null=True, + required=False, + allow_null=True, help_text=_('Sales order line item to allocate items against'), ) shipment = serializers.PrimaryKeyRelatedField( queryset=order.models.SalesOrderShipment.objects.all(), - required=False, allow_null=True, + required=False, + allow_null=True, help_text=_('Sales order shipment to allocate items against'), ) @@ -164,11 +169,10 @@ class BarcodeSOAllocateSerializer(BarcodeSerializer): """Validate the provided shipment""" if shipment and shipment.is_delivered(): - raise ValidationError(_("Shipment has already been delivered")) + raise ValidationError(_('Shipment has already been delivered')) return shipment quantity = serializers.IntegerField( - required=False, - help_text=_('Quantity to allocate'), + required=False, help_text=_('Quantity to allocate') ) diff --git a/InvenTree/plugin/base/barcodes/test_barcode.py b/InvenTree/plugin/base/barcodes/test_barcode.py index b828e12848..3b5cedba6f 100644 --- a/InvenTree/plugin/base/barcodes/test_barcode.py +++ b/InvenTree/plugin/base/barcodes/test_barcode.py @@ -12,12 +12,7 @@ from stock.models import StockItem class BarcodeAPITest(InvenTreeAPITestCase): """Tests for barcode api.""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock' - ] + fixtures = ['category', 'part', 'location', 'stock'] def setUp(self): """Setup for all tests.""" @@ -29,7 +24,12 @@ class BarcodeAPITest(InvenTreeAPITestCase): def postBarcode(self, url, barcode, expected_code=None): """Post barcode and return results.""" - return self.post(url, format='json', data={'barcode': str(barcode)}, expected_code=expected_code) + return self.post( + url, + format='json', + data={'barcode': str(barcode)}, + expected_code=expected_code, + ) def test_invalid(self): """Test that invalid requests fail.""" @@ -38,8 +38,15 @@ class BarcodeAPITest(InvenTreeAPITestCase): # test wrong assign urls self.post(self.assign_url, format='json', data={}, expected_code=400) - self.post(self.assign_url, format='json', data={'barcode': '123'}, expected_code=400) - self.post(self.assign_url, format='json', data={'barcode': '123', 'stockitem': '123'}, expected_code=400) + self.post( + self.assign_url, format='json', data={'barcode': '123'}, expected_code=400 + ) + self.post( + self.assign_url, + format='json', + data={'barcode': '123', 'stockitem': '123'}, + expected_code=400, + ) def test_empty(self): """Test an empty barcode scan. @@ -59,11 +66,7 @@ class BarcodeAPITest(InvenTreeAPITestCase): part = Part.objects.first() response = self.post( - self.scan_url, - { - 'barcode': f'{{"part": {part.pk}}}', - }, - expected_code=200 + self.scan_url, {'barcode': f'{{"part": {part.pk}}}'}, expected_code=200 ) self.assertIn('part', response.data) @@ -73,11 +76,7 @@ class BarcodeAPITest(InvenTreeAPITestCase): def test_invalid_part(self): """Test response for invalid part.""" response = self.post( - self.scan_url, - { - 'barcode': '{"part": 999999999}' - }, - expected_code=400 + self.scan_url, {'barcode': '{"part": 999999999}'}, expected_code=400 ) self.assertIn('error', response.data) @@ -88,11 +87,7 @@ class BarcodeAPITest(InvenTreeAPITestCase): item = StockItem.objects.first() response = self.post( - self.scan_url, - { - 'barcode': item.format_barcode(), - }, - expected_code=200 + self.scan_url, {'barcode': item.format_barcode()}, expected_code=200 ) self.assertIn('stockitem', response.data) @@ -102,11 +97,7 @@ class BarcodeAPITest(InvenTreeAPITestCase): def test_invalid_item(self): """Test response for invalid stock item.""" response = self.post( - self.scan_url, - { - 'barcode': '{"stockitem": 999999999}' - }, - expected_code=400 + self.scan_url, {'barcode': '{"stockitem": 999999999}'}, expected_code=400 ) self.assertIn('error', response.data) @@ -114,11 +105,7 @@ class BarcodeAPITest(InvenTreeAPITestCase): def test_find_location(self): """Test that we can lookup a stock location based on ID.""" response = self.post( - self.scan_url, - { - 'barcode': '{"stocklocation": 1}', - }, - expected_code=200 + self.scan_url, {'barcode': '{"stocklocation": 1}'}, expected_code=200 ) self.assertIn('stocklocation', response.data) @@ -129,10 +116,8 @@ class BarcodeAPITest(InvenTreeAPITestCase): """Test response for an invalid location.""" response = self.post( self.scan_url, - { - 'barcode': '{"stocklocation": 999999999}' - }, - expected_code=400 + {'barcode': '{"stocklocation": 999999999}'}, + expected_code=400, ) self.assertIn('error', response.data) @@ -155,7 +140,9 @@ class BarcodeAPITest(InvenTreeAPITestCase): """Test that a barcode is generated with a scan.""" item = StockItem.objects.get(pk=522) - response = self.postBarcode(self.scan_url, item.format_barcode(), expected_code=200) + response = self.postBarcode( + self.scan_url, item.format_barcode(), expected_code=200 + ) data = response.data self.assertIn('stockitem', data) @@ -175,12 +162,10 @@ class BarcodeAPITest(InvenTreeAPITestCase): barcode_data = 'A-TEST-BARCODE-STRING' response = self.post( - self.assign_url, format='json', - data={ - 'barcode': barcode_data, - 'stockitem': item.pk - }, - expected_code=200 + self.assign_url, + format='json', + data={'barcode': barcode_data, 'stockitem': item.pk}, + expected_code=200, ) data = response.data @@ -197,12 +182,10 @@ class BarcodeAPITest(InvenTreeAPITestCase): # Ensure that the same barcode hash cannot be assigned to a different stock item! response = self.post( - self.assign_url, format='json', - data={ - 'barcode': barcode_data, - 'stockitem': 521 - }, - expected_code=400 + self.assign_url, + format='json', + data={'barcode': barcode_data, 'stockitem': 521}, + expected_code=400, ) self.assertIn('error', response.data) @@ -210,11 +193,7 @@ class BarcodeAPITest(InvenTreeAPITestCase): # Check that we can now unassign a barcode response = self.post( - self.unassign_url, - { - 'stockitem': item.pk, - }, - expected_code=200 + self.unassign_url, {'stockitem': item.pk}, expected_code=200 ) item.refresh_from_db() @@ -222,11 +201,7 @@ class BarcodeAPITest(InvenTreeAPITestCase): # Check that the 'unassign' endpoint fails if the stockitem is invalid response = self.post( - self.unassign_url, - { - 'stockitem': 999999999, - }, - expected_code=400 + self.unassign_url, {'stockitem': 999999999}, expected_code=400 ) def test_unassign_endpoint(self): @@ -236,41 +211,23 @@ class BarcodeAPITest(InvenTreeAPITestCase): # Invalid key should fail for k in invalid_keys: - response = self.post( - self.unassign_url, - { - k: 123 - }, - expected_code=400 - ) + response = self.post(self.unassign_url, {k: 123}, expected_code=400) - self.assertIn("Missing data: Provide one of", str(response.data['error'])) + self.assertIn('Missing data: Provide one of', str(response.data['error'])) valid_keys = ['build', 'salesorder', 'part'] # Valid key but invalid pk should fail for k in valid_keys: - response = self.post( - self.unassign_url, - { - k: 999999999 - }, - expected_code=400 - ) + response = self.post(self.unassign_url, {k: 999999999}, expected_code=400) - self.assertIn("object does not exist", str(response.data[k])) + self.assertIn('object does not exist', str(response.data[k])) class SOAllocateTest(InvenTreeAPITestCase): """Unit tests for the barcode endpoint for allocating items to a sales order""" - fixtures = [ - 'category', - 'company', - 'part', - 'location', - 'stock', - ] + fixtures = ['category', 'company', 'part', 'location', 'stock'] @classmethod def setUpTestData(cls): @@ -285,22 +242,15 @@ class SOAllocateTest(InvenTreeAPITestCase): cls.part = Part.objects.filter(salable=True).first() # Make a stock item - cls.stock_item = StockItem.objects.create( - part=cls.part, - quantity=100 - ) + cls.stock_item = StockItem.objects.create(part=cls.part, quantity=100) cls.stock_item.assign_barcode(barcode_data='barcode') # Find a customer - cls.customer = company.models.Company.objects.filter( - is_customer=True - ).first() + cls.customer = company.models.Company.objects.filter(is_customer=True).first() # Create a sales order - cls.sales_order = order.models.SalesOrder.objects.create( - customer=cls.customer - ) + cls.sales_order = order.models.SalesOrder.objects.create(customer=cls.customer) # Create a shipment cls.shipment = order.models.SalesOrderShipment.objects.create( @@ -309,9 +259,7 @@ class SOAllocateTest(InvenTreeAPITestCase): # Create a line item cls.line_item = order.models.SalesOrderLineItem.objects.create( - order=cls.sales_order, - part=cls.part, - quantity=10, + order=cls.sales_order, part=cls.part, quantity=10 ) def setUp(self): @@ -321,15 +269,10 @@ class SOAllocateTest(InvenTreeAPITestCase): def postBarcode(self, barcode, expected_code=None, **kwargs): """Post barcode and return results.""" - data = { - 'barcode': barcode, - **kwargs - } + data = {'barcode': barcode, **kwargs} response = self.post( - reverse('api-barcode-so-allocate'), - data=data, - expected_code=expected_code, + reverse('api-barcode-so-allocate'), data=data, expected_code=expected_code ) return response.data @@ -346,11 +289,7 @@ class SOAllocateTest(InvenTreeAPITestCase): """Test when an invalid sales order is provided""" # Test with an invalid sales order ID - result = self.postBarcode( - '', - sales_order=999999999, - expected_code=400 - ) + result = self.postBarcode('', sales_order=999999999, expected_code=400) self.assertIn('object does not exist', str(result['sales_order'])) @@ -359,9 +298,7 @@ class SOAllocateTest(InvenTreeAPITestCase): # Test with an invalid barcode result = self.postBarcode( - '123456789', - sales_order=self.sales_order.pk, - expected_code=400 + '123456789', sales_order=self.sales_order.pk, expected_code=400 ) self.assertIn('No match found for barcode', str(result['error'])) @@ -371,9 +308,7 @@ class SOAllocateTest(InvenTreeAPITestCase): item.assign_barcode(barcode_data='123456789') result = self.postBarcode( - '123456789', - sales_order=self.sales_order.pk, - expected_code=400 + '123456789', sales_order=self.sales_order.pk, expected_code=400 ) self.assertIn('No matching line item found', str(result['error'])) @@ -382,9 +317,7 @@ class SOAllocateTest(InvenTreeAPITestCase): item.part.assign_barcode(barcode_data='abcde') result = self.postBarcode( - 'abcde', - sales_order=self.sales_order.pk, - expected_code=400 + 'abcde', sales_order=self.sales_order.pk, expected_code=400 ) self.assertIn('does not match an existing stock item', str(result['error'])) @@ -393,9 +326,7 @@ class SOAllocateTest(InvenTreeAPITestCase): """Test data submission""" # Create a shipment for a different order - other_order = order.models.SalesOrder.objects.create( - customer=self.customer - ) + other_order = order.models.SalesOrder.objects.create(customer=self.customer) other_shipment = order.models.SalesOrderShipment.objects.create( order=other_order @@ -406,7 +337,7 @@ class SOAllocateTest(InvenTreeAPITestCase): self.stock_item.format_barcode(), sales_order=self.sales_order.pk, shipment=other_shipment.pk, - expected_code=400 + expected_code=400, ) self.assertIn('Shipment does not match sales order', str(response['error'])) @@ -418,7 +349,7 @@ class SOAllocateTest(InvenTreeAPITestCase): response = self.postBarcode( self.stock_item.format_barcode(), sales_order=self.sales_order.pk, - expected_code=200 + expected_code=200, ) # Check that the right data has been extracted diff --git a/InvenTree/plugin/base/event/events.py b/InvenTree/plugin/base/event/events.py index 01502b528c..c1699fbeb0 100644 --- a/InvenTree/plugin/base/event/events.py +++ b/InvenTree/plugin/base/event/events.py @@ -31,7 +31,10 @@ def trigger_event(event, *args, **kwargs): return # Make sure the database can be accessed and is not being tested rn - if not canAppAccessDatabase(allow_shell=True) and not settings.PLUGIN_TESTING_EVENTS: + if ( + not canAppAccessDatabase(allow_shell=True) + and not settings.PLUGIN_TESTING_EVENTS + ): logger.debug("Ignoring triggered event '%s' - database not ready", event) return @@ -41,12 +44,7 @@ def trigger_event(event, *args, **kwargs): if 'force_async' not in kwargs and not settings.PLUGIN_TESTING_EVENTS: kwargs['force_async'] = True - offload_task( - register_event, - event, - *args, - **kwargs - ) + offload_task(register_event, event, *args, **kwargs) def register_event(event, *args, **kwargs): @@ -61,9 +59,7 @@ def register_event(event, *args, **kwargs): # Determine if there are any plugins which are interested in responding if settings.PLUGIN_TESTING or InvenTreeSetting.get_setting('ENABLE_PLUGINS_EVENTS'): - with transaction.atomic(): - for slug, plugin in registry.plugins.items(): if not plugin.mixin_enabled('events'): continue @@ -84,13 +80,7 @@ def register_event(event, *args, **kwargs): kwargs['force_async'] = True # Offload a separate task for each plugin - offload_task( - process_event, - slug, - event, - *args, - **kwargs - ) + offload_task(process_event, slug, event, *args, **kwargs) def process_event(plugin_slug, event, *args, **kwargs): @@ -172,17 +162,9 @@ def after_save(sender, instance, created, **kwargs): return if created: - trigger_event( - f'{table}.created', - id=instance.id, - model=sender.__name__, - ) + trigger_event(f'{table}.created', id=instance.id, model=sender.__name__) else: - trigger_event( - f'{table}.saved', - id=instance.id, - model=sender.__name__, - ) + trigger_event(f'{table}.saved', id=instance.id, model=sender.__name__) @receiver(post_delete) @@ -193,7 +175,4 @@ def after_delete(sender, instance, **kwargs): if not allow_table_event(table): return - trigger_event( - f'{table}.deleted', - model=sender.__name__, - ) + trigger_event(f'{table}.deleted', model=sender.__name__) diff --git a/InvenTree/plugin/base/integration/APICallMixin.py b/InvenTree/plugin/base/integration/APICallMixin.py index 1dcef4dc61..8be6218a32 100644 --- a/InvenTree/plugin/base/integration/APICallMixin.py +++ b/InvenTree/plugin/base/integration/APICallMixin.py @@ -1,6 +1,5 @@ """Mixin class for making calls to an external API""" - import json as json_pkg import logging @@ -56,6 +55,7 @@ class APICallMixin: return self.api_call('api/users/2') ``` """ + API_METHOD = 'https' API_URL_SETTING = None API_TOKEN_SETTING = None @@ -64,6 +64,7 @@ class APICallMixin: class MixinMeta: """Meta options for this mixin.""" + MIXIN_NAME = 'API calls' def __init__(self): @@ -75,9 +76,9 @@ class APICallMixin: def has_api_call(self): """Is the mixin ready to call external APIs?""" if not bool(self.API_URL_SETTING): - raise MixinNotImplementedError("API_URL_SETTING must be defined") + raise MixinNotImplementedError('API_URL_SETTING must be defined') if not bool(self.API_TOKEN_SETTING): - raise MixinNotImplementedError("API_TOKEN_SETTING must be defined") + raise MixinNotImplementedError('API_TOKEN_SETTING must be defined') return True @property @@ -98,7 +99,7 @@ class APICallMixin: if token: headers[self.API_TOKEN] = token - headers['Authorization'] = f"{self.API_TOKEN} {token}" + headers['Authorization'] = f'{self.API_TOKEN} {token}' return headers @@ -109,7 +110,17 @@ class APICallMixin: groups.append(f'{key}={",".join([str(a) for a in val])}') return f'?{"&".join(groups)}' - def api_call(self, endpoint: str, method: str = 'GET', url_args: dict = None, data=None, json=None, headers: dict = None, simple_response: bool = True, endpoint_is_url: bool = False): + def api_call( + self, + endpoint: str, + method: str = 'GET', + url_args: dict = None, + data=None, + json=None, + headers: dict = None, + simple_response: bool = True, + endpoint_is_url: bool = False, + ): """Do an API call. Simplest call example: @@ -140,17 +151,13 @@ class APICallMixin: if endpoint_is_url: url = endpoint else: - if endpoint.startswith('/'): endpoint = endpoint[1:] url = f'{self.api_url}/{endpoint}' # build kwargs for call - kwargs = { - 'url': url, - 'headers': headers, - } + kwargs = {'url': url, 'headers': headers} if data and json: raise ValueError('You can either pass `data` or `json` to this function.') diff --git a/InvenTree/plugin/base/integration/AppMixin.py b/InvenTree/plugin/base/integration/AppMixin.py index 700f94a98e..196b941253 100644 --- a/InvenTree/plugin/base/integration/AppMixin.py +++ b/InvenTree/plugin/base/integration/AppMixin.py @@ -1,4 +1,5 @@ """Plugin mixin class for AppMixin.""" + import logging from importlib import reload from pathlib import Path @@ -26,7 +27,9 @@ class AppMixin: self.add_mixin('app', 'has_app', __class__) @classmethod - def _activate_mixin(cls, registry, plugins, force_reload=False, full_reload: bool = False): + def _activate_mixin( + cls, registry, plugins, force_reload=False, full_reload: bool = False + ): """Activate AppMixin plugins - add custom apps and reload. Args: @@ -37,7 +40,9 @@ class AppMixin: """ from common.models import InvenTreeSetting - if settings.PLUGIN_TESTING or InvenTreeSetting.get_setting('ENABLE_PLUGINS_APP'): + if settings.PLUGIN_TESTING or InvenTreeSetting.get_setting( + 'ENABLE_PLUGINS_APP' + ): logger.info('Registering IntegrationPlugin apps') apps_changed = False @@ -177,7 +182,7 @@ class AppMixin: return plugin_path -# endregion + # endregion @property def has_app(self): diff --git a/InvenTree/plugin/base/integration/CurrencyExchangeMixin.py b/InvenTree/plugin/base/integration/CurrencyExchangeMixin.py index 3ed67f5707..670e2efeb1 100644 --- a/InvenTree/plugin/base/integration/CurrencyExchangeMixin.py +++ b/InvenTree/plugin/base/integration/CurrencyExchangeMixin.py @@ -1,6 +1,5 @@ """Plugin mixin class for supporting currency exchange data""" - from plugin.helpers import MixinNotImplementedError @@ -17,7 +16,7 @@ class CurrencyExchangeMixin: class MixinMeta: """Meta options for this mixin class""" - MIXIN_NAME = "CurrentExchange" + MIXIN_NAME = 'CurrentExchange' def __init__(self): """Register the mixin""" @@ -39,4 +38,6 @@ class CurrencyExchangeMixin: Raises: Can raise any exception if the update fails """ - raise MixinNotImplementedError("Plugin must implement update_exchange_rates method") + raise MixinNotImplementedError( + 'Plugin must implement update_exchange_rates method' + ) diff --git a/InvenTree/plugin/base/integration/ScheduleMixin.py b/InvenTree/plugin/base/integration/ScheduleMixin.py index 1db3a339de..3b06c88315 100644 --- a/InvenTree/plugin/base/integration/ScheduleMixin.py +++ b/InvenTree/plugin/base/integration/ScheduleMixin.py @@ -1,4 +1,5 @@ """Plugin mixin class for ScheduleMixin.""" + import logging from django.conf import settings @@ -64,26 +65,27 @@ class ScheduleMixin: # List of tasks we have activated task_keys = [] - if settings.PLUGIN_TESTING or InvenTreeSetting.get_setting('ENABLE_PLUGINS_SCHEDULE'): - + if settings.PLUGIN_TESTING or InvenTreeSetting.get_setting( + 'ENABLE_PLUGINS_SCHEDULE' + ): for _key, plugin in plugins: - if plugin.mixin_enabled('schedule'): - if plugin.is_active(): # Only active tasks for plugins which are enabled plugin.register_tasks() task_keys += plugin.get_task_names() if len(task_keys) > 0: - logger.info("Activated %s scheduled tasks", len(task_keys)) + logger.info('Activated %s scheduled tasks', len(task_keys)) # Remove any scheduled tasks which do not match # This stops 'old' plugin tasks from accumulating try: from django_q.models import Schedule - scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith="plugin.") + scheduled_plugin_tasks = Schedule.objects.filter( + name__istartswith='plugin.' + ) deleted_count = 0 @@ -93,10 +95,12 @@ class ScheduleMixin: deleted_count += 1 if deleted_count > 0: - logger.info("Removed %s old scheduled tasks", deleted_count) # pragma: no cover + logger.info( + 'Removed %s old scheduled tasks', deleted_count + ) # pragma: no cover except (ProgrammingError, OperationalError): # Database might not yet be ready - logger.warning("activate_integration_schedule failed, database not ready") + logger.warning('activate_integration_schedule failed, database not ready') def get_scheduled_tasks(self): """Returns `SCHEDULED_TASKS` context. @@ -113,30 +117,37 @@ class ScheduleMixin: def validate_scheduled_tasks(self): """Check that the provided scheduled tasks are valid.""" if not self.has_scheduled_tasks: - raise MixinImplementationError("SCHEDULED_TASKS not defined") + raise MixinImplementationError('SCHEDULED_TASKS not defined') for key, task in self.scheduled_tasks.items(): - if 'func' not in task: - raise MixinImplementationError(f"Task '{key}' is missing 'func' parameter") + raise MixinImplementationError( + f"Task '{key}' is missing 'func' parameter" + ) if 'schedule' not in task: - raise MixinImplementationError(f"Task '{key}' is missing 'schedule' parameter") + raise MixinImplementationError( + f"Task '{key}' is missing 'schedule' parameter" + ) schedule = task['schedule'].upper().strip() if schedule not in self.ALLOWABLE_SCHEDULE_TYPES: - raise MixinImplementationError(f"Task '{key}': Schedule '{schedule}' is not a valid option") + raise MixinImplementationError( + f"Task '{key}': Schedule '{schedule}' is not a valid option" + ) # If 'minutes' is selected, it must be provided! if schedule == 'I' and 'minutes' not in task: - raise MixinImplementationError(f"Task '{key}' is missing 'minutes' parameter") + raise MixinImplementationError( + f"Task '{key}' is missing 'minutes' parameter" + ) def get_task_name(self, key): """Task name for key.""" # Generate a 'unique' task name slug = self.plugin_slug() - return f"plugin.{slug}.{key}" + return f'plugin.{slug}.{key}' def get_task_names(self): """All defined task names.""" @@ -149,7 +160,6 @@ class ScheduleMixin: from django_q.models import Schedule for key, task in self.scheduled_tasks.items(): - task_name = self.get_task_name(key) obj = { @@ -184,7 +194,7 @@ class ScheduleMixin: except (ProgrammingError, OperationalError): # pragma: no cover # Database might not yet be ready - logger.warning("register_tasks failed, database not ready") + logger.warning('register_tasks failed, database not ready') def unregister_tasks(self): """Deregister the tasks with the database.""" @@ -192,7 +202,6 @@ class ScheduleMixin: from django_q.models import Schedule for key, _ in self.scheduled_tasks.items(): - task_name = self.get_task_name(key) try: @@ -202,4 +211,4 @@ class ScheduleMixin: pass except (ProgrammingError, OperationalError): # pragma: no cover # Database might not yet be ready - logger.warning("unregister_tasks failed, database not ready") + logger.warning('unregister_tasks failed, database not ready') diff --git a/InvenTree/plugin/base/integration/SettingsMixin.py b/InvenTree/plugin/base/integration/SettingsMixin.py index 21847c2009..c258f80103 100644 --- a/InvenTree/plugin/base/integration/SettingsMixin.py +++ b/InvenTree/plugin/base/integration/SettingsMixin.py @@ -1,4 +1,5 @@ """Plugin mixin class for SettingsMixin.""" + import logging from typing import TYPE_CHECKING, Dict @@ -10,8 +11,10 @@ logger = logging.getLogger('inventree') if TYPE_CHECKING: from common.models import SettingsKeyType else: + class SettingsKeyType: """Dummy class, so that python throws no error""" + pass @@ -22,6 +25,7 @@ class SettingsMixin: class MixinMeta: """Meta for mixin.""" + MIXIN_NAME = 'Settings' def __init__(self): @@ -74,7 +78,9 @@ class SettingsMixin: from plugin.models import PluginConfig, PluginSetting try: - plugin, _ = PluginConfig.objects.get_or_create(key=self.plugin_slug(), name=self.plugin_name()) + plugin, _ = PluginConfig.objects.get_or_create( + key=self.plugin_slug(), name=self.plugin_name() + ) except (OperationalError, ProgrammingError): # pragma: no cover plugin = None @@ -96,4 +102,6 @@ class SettingsMixin: """ from plugin.models import PluginSetting - return PluginSetting.check_all_settings(settings_definition=self.settings, plugin=self.plugin_config()) + return PluginSetting.check_all_settings( + settings_definition=self.settings, plugin=self.plugin_config() + ) diff --git a/InvenTree/plugin/base/integration/UrlsMixin.py b/InvenTree/plugin/base/integration/UrlsMixin.py index 67f87455d1..7a5972c0d8 100644 --- a/InvenTree/plugin/base/integration/UrlsMixin.py +++ b/InvenTree/plugin/base/integration/UrlsMixin.py @@ -1,4 +1,5 @@ """Plugin mixin class for UrlsMixin.""" + import logging from django.conf import settings @@ -24,7 +25,9 @@ class UrlsMixin: self.urls = self.setup_urls() @classmethod - def _activate_mixin(cls, registry, plugins, force_reload=False, full_reload: bool = False): + def _activate_mixin( + cls, registry, plugins, force_reload=False, full_reload: bool = False + ): """Activate UrlsMixin plugins - add custom urls . Args: @@ -34,7 +37,10 @@ class UrlsMixin: full_reload (bool, optional): Reload everything - including plugin mechanism. Defaults to False. """ from common.models import InvenTreeSetting - if settings.PLUGIN_TESTING or InvenTreeSetting.get_setting('ENABLE_PLUGINS_URL'): + + if settings.PLUGIN_TESTING or InvenTreeSetting.get_setting( + 'ENABLE_PLUGINS_URL' + ): logger.info('Registering UrlsMixin Plugin') urls_changed = False # check whether an activated plugin extends UrlsMixin @@ -64,7 +70,9 @@ class UrlsMixin: def urlpatterns(self): """Urlpatterns for this plugin.""" if self.has_urls: - return re_path(f'^{self.slug}/', include((self.urls, self.slug)), name=self.slug) + return re_path( + f'^{self.slug}/', include((self.urls, self.slug)), name=self.slug + ) return None @property diff --git a/InvenTree/plugin/base/integration/ValidationMixin.py b/InvenTree/plugin/base/integration/ValidationMixin.py index caf015aa4c..cee5030b70 100644 --- a/InvenTree/plugin/base/integration/ValidationMixin.py +++ b/InvenTree/plugin/base/integration/ValidationMixin.py @@ -1,6 +1,5 @@ """Validation mixin class definition""" - import part.models import stock.models @@ -33,7 +32,8 @@ class ValidationMixin: class MixinMeta: """Metaclass for this mixin""" - MIXIN_NAME = "Validation" + + MIXIN_NAME = 'Validation' def __init__(self): """Register the mixin""" diff --git a/InvenTree/plugin/base/integration/mixins.py b/InvenTree/plugin/base/integration/mixins.py index cc6276a5a0..0421472483 100644 --- a/InvenTree/plugin/base/integration/mixins.py +++ b/InvenTree/plugin/base/integration/mixins.py @@ -3,8 +3,7 @@ import logging from InvenTree.helpers import generateTestKey -from plugin.helpers import (MixinNotImplementedError, render_template, - render_text) +from plugin.helpers import MixinNotImplementedError, render_template, render_text logger = logging.getLogger('inventree') @@ -13,7 +12,7 @@ class NavigationMixin: """Mixin that enables custom navigation links with the plugin.""" NAVIGATION_TAB_NAME = None - NAVIGATION_TAB_ICON = "fas fa-question" + NAVIGATION_TAB_ICON = 'fas fa-question' class MixinMeta: """Meta options for this mixin.""" @@ -32,7 +31,7 @@ class NavigationMixin: if nav_links: # check if needed values are configured for link in nav_links: - if False in [a in link for a in ('link', 'name', )]: + if False in [a in link for a in ('link', 'name')]: raise MixinNotImplementedError('Wrong Link definition', link) return nav_links @@ -52,7 +51,7 @@ class NavigationMixin: @property def navigation_icon(self): """Icon-name for navigation tab.""" - return getattr(self, 'NAVIGATION_TAB_ICON', "fas fa-question") + return getattr(self, 'NAVIGATION_TAB_ICON', 'fas fa-question') class PanelMixin: @@ -111,7 +110,9 @@ class PanelMixin: def get_custom_panels(self, view, request): """This method *must* be implemented by the plugin class.""" - raise MixinNotImplementedError(f"{__class__} is missing the 'get_custom_panels' method") + raise MixinNotImplementedError( + f"{__class__} is missing the 'get_custom_panels' method" + ) def get_panel_context(self, view, request, context): """Build the context data to be used for template rendering. @@ -152,7 +153,6 @@ class PanelMixin: custom_panels = self.get_custom_panels(view, request) or [] for panel in custom_panels: - content_template = panel.get('content_template', None) javascript_template = panel.get('javascript_template', None) @@ -174,7 +174,10 @@ class PanelMixin: required_keys = ['title', 'content'] if any(key not in panel for key in required_keys): - logger.warning("Custom panel for plugin %s is missing a required parameter", __class__) + logger.warning( + 'Custom panel for plugin %s is missing a required parameter', + __class__, + ) continue # Add some information on this plugin @@ -207,4 +210,6 @@ class SettingsContentMixin: def get_settings_content(self, view, request): """This method *must* be implemented by the plugin class.""" - raise MixinNotImplementedError(f"{__class__} is missing the 'get_settings_content' method") + raise MixinNotImplementedError( + f"{__class__} is missing the 'get_settings_content' method" + ) diff --git a/InvenTree/plugin/base/integration/test_mixins.py b/InvenTree/plugin/base/integration/test_mixins.py index fde9df5352..e512fb9cea 100644 --- a/InvenTree/plugin/base/integration/test_mixins.py +++ b/InvenTree/plugin/base/integration/test_mixins.py @@ -12,8 +12,13 @@ from InvenTree.unit_test import InvenTreeTestCase from plugin import InvenTreePlugin from plugin.base.integration.mixins import PanelMixin from plugin.helpers import MixinNotImplementedError -from plugin.mixins import (APICallMixin, AppMixin, NavigationMixin, - SettingsMixin, UrlsMixin) +from plugin.mixins import ( + APICallMixin, + AppMixin, + NavigationMixin, + SettingsMixin, + UrlsMixin, +) from plugin.registry import registry from plugin.urls import PLUGIN_BASE @@ -24,9 +29,15 @@ class BaseMixinDefinition: def test_mixin_name(self): """Test that the mixin registers itseld correctly.""" # mixin name - self.assertIn(self.MIXIN_NAME, {item['key'] for item in self.mixin.registered_mixins.values()}) + self.assertIn( + self.MIXIN_NAME, + {item['key'] for item in self.mixin.registered_mixins.values()}, + ) # human name - self.assertIn(self.MIXIN_HUMAN_NAME, {item['human_name'] for item in self.mixin.registered_mixins.values()}) + self.assertIn( + self.MIXIN_HUMAN_NAME, + {item['human_name'] for item in self.mixin.registered_mixins.values()}, + ) class SettingsMixinTest(BaseMixinDefinition, InvenTreeTestCase): @@ -36,16 +47,19 @@ class SettingsMixinTest(BaseMixinDefinition, InvenTreeTestCase): MIXIN_NAME = 'settings' MIXIN_ENABLE_CHECK = 'has_settings' - TEST_SETTINGS = {'SETTING1': {'default': '123', }} + TEST_SETTINGS = {'SETTING1': {'default': '123'}} def setUp(self): """Setup for all tests.""" + class SettingsCls(SettingsMixin, InvenTreePlugin): SETTINGS = self.TEST_SETTINGS + self.mixin = SettingsCls() class NoSettingsCls(SettingsMixin, InvenTreePlugin): pass + self.mixin_nothing = NoSettingsCls() super().setUp() @@ -77,14 +91,18 @@ class UrlsMixinTest(BaseMixinDefinition, TestCase): def setUp(self): """Setup for all tests.""" + class UrlsCls(UrlsMixin, InvenTreePlugin): def test(): return 'ccc' - URLS = [re_path('testpath', test, name='test'), ] + + URLS = [re_path('testpath', test, name='test')] + self.mixin = UrlsCls() class NoUrlsCls(UrlsMixin, InvenTreePlugin): pass + self.mixin_nothing = NoUrlsCls() def test_function(self): @@ -96,8 +114,12 @@ class UrlsMixinTest(BaseMixinDefinition, TestCase): self.assertEqual(self.mixin.base_url, target_url) # urlpattern - target_pattern = re_path(f'^{plg_name}/', include((self.mixin.urls, plg_name)), name=plg_name) - self.assertEqual(self.mixin.urlpatterns.reverse_dict, target_pattern.reverse_dict) + target_pattern = re_path( + f'^{plg_name}/', include((self.mixin.urls, plg_name)), name=plg_name + ) + self.assertEqual( + self.mixin.urlpatterns.reverse_dict, target_pattern.reverse_dict + ) # resolve the view self.assertEqual(self.mixin.urlpatterns.resolve('/testpath').func(), 'ccc') @@ -120,8 +142,10 @@ class AppMixinTest(BaseMixinDefinition, TestCase): def setUp(self): """Setup for all tests.""" + class TestCls(AppMixin, InvenTreePlugin): pass + self.mixin = TestCls() def test_function(self): @@ -138,21 +162,24 @@ class NavigationMixinTest(BaseMixinDefinition, TestCase): def setUp(self): """Setup for all tests.""" + class NavigationCls(NavigationMixin, InvenTreePlugin): - NAVIGATION = [ - {'name': 'aa', 'link': 'plugin:test:test_view'}, - ] + NAVIGATION = [{'name': 'aa', 'link': 'plugin:test:test_view'}] NAVIGATION_TAB_NAME = 'abcd1' + self.mixin = NavigationCls() class NothingNavigationCls(NavigationMixin, InvenTreePlugin): pass + self.nothing_mixin = NothingNavigationCls() def test_function(self): """Test that a correct configuration functions.""" # check right configuration - self.assertEqual(self.mixin.navigation, [{'name': 'aa', 'link': 'plugin:test:test_view'}, ]) + self.assertEqual( + self.mixin.navigation, [{'name': 'aa', 'link': 'plugin:test:test_view'}] + ) # navigation name self.assertEqual(self.mixin.navigation_name, 'abcd1') @@ -161,8 +188,10 @@ class NavigationMixinTest(BaseMixinDefinition, TestCase): def test_fail(self): """Test that wrong links fail.""" with self.assertRaises(NotImplementedError): + class NavigationCls(NavigationMixin, InvenTreePlugin): NAVIGATION = ['aa', 'aa'] + NavigationCls() @@ -177,13 +206,10 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): """Setup for all tests.""" class MixinCls(APICallMixin, SettingsMixin, InvenTreePlugin): - NAME = "Sample API Caller" + NAME = 'Sample API Caller' SETTINGS = { - 'API_TOKEN': { - 'name': 'API Token', - 'protected': True, - }, + 'API_TOKEN': {'name': 'API Token', 'protected': True}, 'API_URL': { 'name': 'External URL', 'description': 'Where is your API located?', @@ -197,7 +223,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): @property def api_url(self): """Override API URL for this test""" - return "https://api.github.com" + return 'https://api.github.com' def get_external_url(self, simple: bool = True): """Returns data from the sample endpoint.""" @@ -243,7 +269,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): result = self.mixin.api_build_url_args({'a': 'b', 'c': 'd'}) self.assertEqual(result, '?a=b&c=d') # list args - result = self.mixin.api_build_url_args({'a': 'b', 'c': ['d', 'e', 'f', ]}) + result = self.mixin.api_build_url_args({'a': 'b', 'c': ['d', 'e', 'f']}) self.assertEqual(result, '?a=b&c=d,e,f') def test_api_call(self): @@ -263,7 +289,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): # api_call with post and data result = self.mixin.api_call( 'https://reqres.in/api/users/', - json={"name": "morpheus", "job": "leader"}, + json={'name': 'morpheus', 'job': 'leader'}, method='POST', endpoint_is_url=True, ) @@ -277,7 +303,9 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): self.assertEqual(result.reason, 'OK') # api_call with filter - result = self.mixin.api_call('repos/inventree/InvenTree/stargazers', url_args={'page': '2'}) + result = self.mixin.api_call( + 'repos/inventree/InvenTree/stargazers', url_args={'page': '2'} + ) self.assertTrue(result) def test_function_errors(self): @@ -293,17 +321,16 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): # Too many data arguments with self.assertRaises(ValueError): self.mixin.api_call( - 'https://reqres.in/api/users/', - json={"a": 1, }, data={"a": 1}, + 'https://reqres.in/api/users/', json={'a': 1}, data={'a': 1} ) # Sending a request with a wrong data format should result in 40 result = self.mixin.api_call( 'https://reqres.in/api/users/', - data={"name": "morpheus", "job": "leader"}, + data={'name': 'morpheus', 'job': 'leader'}, method='POST', endpoint_is_url=True, - simple_response=False + simple_response=False, ) self.assertEqual(result.status_code, 400) @@ -313,12 +340,7 @@ class APICallMixinTest(BaseMixinDefinition, TestCase): class PanelMixinTests(InvenTreeTestCase): """Test that the PanelMixin plugin operates correctly.""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock', - ] + fixtures = ['category', 'part', 'location', 'stock'] roles = 'all' @@ -352,9 +374,7 @@ class PanelMixinTests(InvenTreeTestCase): reverse('stock-item-detail', kwargs={'pk': 2}), reverse('stock-location-detail', kwargs={'pk': 1}), ]: - response = self.client.get( - url - ) + response = self.client.get(url) self.assertEqual(response.status_code, 200) @@ -440,6 +460,7 @@ class PanelMixinTests(InvenTreeTestCase): def test_mixin(self): """Test that ImplementationError is raised.""" with self.assertRaises(MixinNotImplementedError): + class Wrong(PanelMixin, InvenTreePlugin): pass diff --git a/InvenTree/plugin/base/label/label.py b/InvenTree/plugin/base/label/label.py index ea4d831cc4..aadd4be209 100644 --- a/InvenTree/plugin/base/label/label.py +++ b/InvenTree/plugin/base/label/label.py @@ -36,17 +36,16 @@ def print_label(plugin_slug: str, **kwargs): plugin.print_label(**kwargs) except Exception as e: # pragma: no cover # Plugin threw an error - notify the user who attempted to print - ctx = { - 'name': _('Label printing failed'), - 'message': str(e), - } + ctx = {'name': _('Label printing failed'), 'message': str(e)} user = kwargs.get('user', None) if user: # Log an error message to the database log_error('plugin.print_label') - logger.exception("Label printing failed: Sending notification to user '%s'", user) # pragma: no cover + logger.exception( + "Label printing failed: Sending notification to user '%s'", user + ) # pragma: no cover # Throw an error against the plugin instance common.notifications.trigger_notification( @@ -54,7 +53,7 @@ def print_label(plugin_slug: str, **kwargs): 'label.printing_failed', targets=[user], context=ctx, - delivery_methods={common.notifications.UIMessageNotification, }, + delivery_methods={common.notifications.UIMessageNotification}, ) if settings.TESTING: diff --git a/InvenTree/plugin/base/label/mixins.py b/InvenTree/plugin/base/label/mixins.py index 56194c29a4..8c62b0b09f 100644 --- a/InvenTree/plugin/base/label/mixins.py +++ b/InvenTree/plugin/base/label/mixins.py @@ -32,6 +32,7 @@ class LabelPrintingMixin: class MixinMeta: """Meta options for this mixin.""" + MIXIN_NAME = 'Label printing' def __init__(self): # pragma: no cover @@ -63,18 +64,24 @@ class LabelPrintingMixin: pdf_data = kwargs.get('pdf_data', None) if not pdf_data: - pdf_data = self.render_to_pdf(label, request, **kwargs).get_document().write_pdf() + pdf_data = ( + self.render_to_pdf(label, request, **kwargs).get_document().write_pdf() + ) - dpi = kwargs.get( - 'dpi', - InvenTreeSetting.get_setting('LABEL_DPI', 300) - ) + dpi = kwargs.get('dpi', InvenTreeSetting.get_setting('LABEL_DPI', 300)) # Convert to png data png = pdf2image.convert_from_bytes(pdf_data, dpi=dpi)[0] return png - def print_labels(self, label: LabelTemplate, items: list, request: Request, printing_options: dict, **kwargs): + def print_labels( + self, + label: LabelTemplate, + items: list, + request: Request, + printing_options: dict, + **kwargs, + ): """Print one or more labels with the provided template and items. Arguments: @@ -149,7 +156,9 @@ class LabelPrintingMixin: Note that the supplied kwargs may be different if the plugin overrides the print_labels() method. """ # Unimplemented (to be implemented by the particular plugin class) - raise MixinNotImplementedError('This Plugin must implement a `print_label` method') + raise MixinNotImplementedError( + 'This Plugin must implement a `print_label` method' + ) def offload_label(self, **kwargs): """Offload a single label (non-blocking) @@ -162,13 +171,11 @@ class LabelPrintingMixin: # Exclude the 'pdf_file' object - cannot be pickled kwargs.pop('pdf_file', None) - offload_task( - plugin_label.print_label, - self.plugin_slug(), - **kwargs - ) + offload_task(plugin_label.print_label, self.plugin_slug(), **kwargs) - def get_printing_options_serializer(self, request: Request, *args, **kwargs) -> Union[serializers.Serializer, None]: + def get_printing_options_serializer( + self, request: Request, *args, **kwargs + ) -> Union[serializers.Serializer, None]: """Return a serializer class instance with dynamic printing options. Arguments: @@ -179,7 +186,7 @@ class LabelPrintingMixin: A class instance of a DRF serializer class, by default this an instance of self.PrintingOptionsSerializer using the *args, **kwargs if existing for this plugin """ - serializer = getattr(self, "PrintingOptionsSerializer", None) + serializer = getattr(self, 'PrintingOptionsSerializer', None) if not serializer: return None diff --git a/InvenTree/plugin/base/label/test_label_mixin.py b/InvenTree/plugin/base/label/test_label_mixin.py index bb6389b80a..dbb73713f5 100644 --- a/InvenTree/plugin/base/label/test_label_mixin.py +++ b/InvenTree/plugin/base/label/test_label_mixin.py @@ -23,12 +23,7 @@ from stock.models import StockItem, StockLocation class LabelMixinTests(InvenTreeAPITestCase): """Test that the Label mixin operates correctly.""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock', - ] + fixtures = ['category', 'part', 'location', 'stock'] roles = 'all' @@ -38,12 +33,20 @@ class LabelMixinTests(InvenTreeAPITestCase): config.active = True config.save() - def do_url(self, parts, plugin_ref, label, url_name: str = 'api-part-label-print', url_single: str = 'part', invalid: bool = False): + def do_url( + self, + parts, + plugin_ref, + label, + url_name: str = 'api-part-label-print', + url_single: str = 'part', + invalid: bool = False, + ): """Generate an URL to print a label.""" # Construct URL kwargs = {} if label: - kwargs["pk"] = label.pk + kwargs['pk'] = label.pk url = reverse(url_name, kwargs=kwargs) @@ -93,37 +96,19 @@ class LabelMixinTests(InvenTreeAPITestCase): response = self.client.post(url, {}) self.assertEqual(response.status_code, 405) - response = self.client.get( - url, - { - 'mixin': 'labels', - 'active': True, - } - ) + response = self.client.get(url, {'mixin': 'labels', 'active': True}) # No results matching this query! self.assertEqual(len(response.data), 0) # What about inactive? - response = self.client.get( - url, - { - 'mixin': 'labels', - 'active': False, - } - ) + response = self.client.get(url, {'mixin': 'labels', 'active': False}) self.assertEqual(len(response.data), 0) self.do_activate_plugin() # Should be available via the API now - response = self.client.get( - url, - { - 'mixin': 'labels', - 'active': True, - } - ) + response = self.client.get(url, {'mixin': 'labels', 'active': True}) self.assertEqual(len(response.data), 3) @@ -147,11 +132,15 @@ class LabelMixinTests(InvenTreeAPITestCase): # Non-existing plugin response = self.get(f'{url}123', expected_code=404) - self.assertIn(f'Plugin \'{plugin_ref}123\' not found', str(response.content, 'utf8')) + self.assertIn( + f"Plugin '{plugin_ref}123' not found", str(response.content, 'utf8') + ) # Inactive plugin response = self.get(url, expected_code=400) - self.assertIn(f'Plugin \'{plugin_ref}\' is not enabled', str(response.content, 'utf8')) + self.assertIn( + f"Plugin '{plugin_ref}' is not enabled", str(response.content, 'utf8') + ) # Active plugin self.do_activate_plugin() @@ -202,19 +191,31 @@ class LabelMixinTests(InvenTreeAPITestCase): self.do_activate_plugin() # test options response - options = self.options(self.do_url(parts, plugin_ref, label), expected_code=200).json() - self.assertTrue("amount" in options["actions"]["POST"]) + options = self.options( + self.do_url(parts, plugin_ref, label), expected_code=200 + ).json() + self.assertTrue('amount' in options['actions']['POST']) plg = registry.get_plugin(plugin_ref) - with mock.patch.object(plg, "print_label") as print_label: + with mock.patch.object(plg, 'print_label') as print_label: # wrong value type - res = self.post(self.do_url(parts, plugin_ref, label), data={"amount": "-no-valid-int-"}, expected_code=400).json() - self.assertTrue("amount" in res) + res = self.post( + self.do_url(parts, plugin_ref, label), + data={'amount': '-no-valid-int-'}, + expected_code=400, + ).json() + self.assertTrue('amount' in res) print_label.assert_not_called() # correct value type - self.post(self.do_url(parts, plugin_ref, label), data={"amount": 13}, expected_code=200).json() - self.assertEqual(print_label.call_args.kwargs["printing_options"], {"amount": 13}) + self.post( + self.do_url(parts, plugin_ref, label), + data={'amount': 13}, + expected_code=200, + ).json() + self.assertEqual( + print_label.call_args.kwargs['printing_options'], {'amount': 13} + ) def test_printing_endpoints(self): """Cover the endpoints not covered by `test_printing_process`.""" @@ -237,22 +238,38 @@ class LabelMixinTests(InvenTreeAPITestCase): qs = qs.objects.all() # List endpoint - self.get(self.do_url(None, None, None, f'{url_name}-list', url_single), expected_code=200) + self.get( + self.do_url(None, None, None, f'{url_name}-list', url_single), + expected_code=200, + ) # List endpoint with filter - self.get(self.do_url(qs[:2], None, None, f'{url_name}-list', url_single, invalid=True), expected_code=200) + self.get( + self.do_url( + qs[:2], None, None, f'{url_name}-list', url_single, invalid=True + ), + expected_code=200, + ) # Single page printing - self.get(self.do_url(qs[:1], plugin_ref, label, f'{url_name}-print', url_single), expected_code=200) + self.get( + self.do_url(qs[:1], plugin_ref, label, f'{url_name}-print', url_single), + expected_code=200, + ) # Multi page printing - self.get(self.do_url(qs[:2], plugin_ref, label, f'{url_name}-print', url_single), expected_code=200) + self.get( + self.do_url(qs[:2], plugin_ref, label, f'{url_name}-print', url_single), + expected_code=200, + ) # Test StockItemLabels run_print_test(StockItemLabel, StockItem, 'api-stockitem-label', 'item') # Test StockLocationLabels - run_print_test(StockLocationLabel, StockLocation, 'api-stocklocation-label', 'location') + run_print_test( + StockLocationLabel, StockLocation, 'api-stocklocation-label', 'location' + ) # Test PartLabels run_print_test(PartLabel, Part, 'api-part-label', 'part') diff --git a/InvenTree/plugin/base/locate/api.py b/InvenTree/plugin/base/locate/api.py index a7d6f5d86e..c3f921d010 100644 --- a/InvenTree/plugin/base/locate/api.py +++ b/InvenTree/plugin/base/locate/api.py @@ -13,9 +13,7 @@ from stock.models import StockItem, StockLocation class LocatePluginView(APIView): """Endpoint for using a custom plugin to identify or 'locate' a stock item or location.""" - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] def post(self, request, *args, **kwargs): """Check inputs and offload the task to the plugin.""" @@ -29,7 +27,9 @@ class LocatePluginView(APIView): plugins = registry.with_mixin('locate') if plugin not in [p.slug for p in plugins]: - raise ParseError(f"Plugin '{plugin}' is not installed, or does not support the location mixin") + raise ParseError( + f"Plugin '{plugin}' is not installed, or does not support the location mixin" + ) # StockItem to identify item_pk = request.data.get('item', None) @@ -37,17 +37,16 @@ class LocatePluginView(APIView): # StockLocation to identify location_pk = request.data.get('location', None) - data = { - "success": "Identification plugin activated", - "plugin": plugin, - } + data = {'success': 'Identification plugin activated', 'plugin': plugin} # StockItem takes priority if item_pk: try: StockItem.objects.get(pk=item_pk) - offload_task(registry.call_plugin_function, plugin, 'locate_stock_item', item_pk) + offload_task( + registry.call_plugin_function, plugin, 'locate_stock_item', item_pk + ) data['item'] = item_pk @@ -60,7 +59,12 @@ class LocatePluginView(APIView): try: StockLocation.objects.get(pk=location_pk) - offload_task(registry.call_plugin_function, plugin, 'locate_stock_location', location_pk) + offload_task( + registry.call_plugin_function, + plugin, + 'locate_stock_location', + location_pk, + ) data['location'] = location_pk diff --git a/InvenTree/plugin/base/locate/mixins.py b/InvenTree/plugin/base/locate/mixins.py index e6404af9c3..005e9a8fcd 100644 --- a/InvenTree/plugin/base/locate/mixins.py +++ b/InvenTree/plugin/base/locate/mixins.py @@ -25,7 +25,8 @@ class LocateMixin: class MixinMeta: """Meta for mixin.""" - MIXIN_NAME = "Locate" + + MIXIN_NAME = 'Locate' def __init__(self): """Register the mixin.""" @@ -45,7 +46,7 @@ class LocateMixin: Note: A custom implementation could always change this behaviour """ - logger.info("LocateMixin: Attempting to locate StockItem pk=%s", item_pk) + logger.info('LocateMixin: Attempting to locate StockItem pk=%s', item_pk) from stock.models import StockItem @@ -56,7 +57,7 @@ class LocateMixin: self.locate_stock_location(item.location.pk) except StockItem.DoesNotExist: # pragma: no cover - logger.warning("LocateMixin: StockItem pk={item_pk} not found") + logger.warning('LocateMixin: StockItem pk={item_pk} not found') pass def locate_stock_location(self, location_pk): diff --git a/InvenTree/plugin/base/locate/test_locate.py b/InvenTree/plugin/base/locate/test_locate.py index 227f7553ec..e47d4a6cb5 100644 --- a/InvenTree/plugin/base/locate/test_locate.py +++ b/InvenTree/plugin/base/locate/test_locate.py @@ -11,12 +11,7 @@ from stock.models import StockItem, StockLocation class LocatePluginTests(InvenTreeAPITestCase): """Tests for LocateMixin.""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock', - ] + fixtures = ['category', 'part', 'location', 'stock'] def test_installed(self): """Test that a locate plugin is actually installed.""" @@ -31,60 +26,41 @@ class LocatePluginTests(InvenTreeAPITestCase): url = reverse('api-locate-plugin') # Post without a plugin - response = self.post( - url, - {}, - expected_code=400 - ) + response = self.post(url, {}, expected_code=400) self.assertIn("'plugin' field must be supplied", str(response.data)) # Post with a plugin that does not exist, or is invalid for slug in ['xyz', 'event', 'plugin']: - response = self.post( - url, - { - 'plugin': slug, - }, - expected_code=400, + response = self.post(url, {'plugin': slug}, expected_code=400) + + self.assertIn( + f"Plugin '{slug}' is not installed, or does not support the location mixin", + str(response.data), ) - self.assertIn(f"Plugin '{slug}' is not installed, or does not support the location mixin", str(response.data)) - # Post with a valid plugin, but no other data - response = self.post( - url, - { - 'plugin': 'samplelocate', - }, - expected_code=400 - ) + response = self.post(url, {'plugin': 'samplelocate'}, expected_code=400) - self.assertIn("Must supply either 'item' or 'location' parameter", str(response.data)) + self.assertIn( + "Must supply either 'item' or 'location' parameter", str(response.data) + ) # Post with valid plugin, invalid item or location for pk in ['qq', 99999, -42]: response = self.post( - url, - { - 'plugin': 'samplelocate', - 'item': pk, - }, - expected_code=404 + url, {'plugin': 'samplelocate', 'item': pk}, expected_code=404 ) self.assertIn(f"StockItem matching PK '{pk}' not found", str(response.data)) response = self.post( - url, - { - 'plugin': 'samplelocate', - 'location': pk, - }, - expected_code=404, + url, {'plugin': 'samplelocate', 'location': pk}, expected_code=404 ) - self.assertIn(f"StockLocation matching PK '{pk}' not found", str(response.data)) + self.assertIn( + f"StockLocation matching PK '{pk}' not found", str(response.data) + ) def test_locate_item(self): """Test that the plugin correctly 'locates' a StockItem. @@ -100,12 +76,7 @@ class LocatePluginTests(InvenTreeAPITestCase): item.set_metadata('located', False) response = self.post( - url, - { - 'plugin': 'samplelocate', - 'item': 1, - }, - expected_code=200 + url, {'plugin': 'samplelocate', 'item': 1}, expected_code=200 ) self.assertEqual(response.data['item'], 1) @@ -120,16 +91,12 @@ class LocatePluginTests(InvenTreeAPITestCase): url = reverse('api-locate-plugin') for location in StockLocation.objects.all(): - location.set_metadata('located', False) response = self.post( url, - { - 'plugin': 'samplelocate', - 'location': location.pk, - }, - expected_code=200 + {'plugin': 'samplelocate', 'location': location.pk}, + expected_code=200, ) self.assertEqual(response.data['location'], location.pk) @@ -141,6 +108,7 @@ class LocatePluginTests(InvenTreeAPITestCase): def test_mixin_locate(self): """Test the sample mixin redirection.""" + class SamplePlugin(LocateMixin, InvenTreePlugin): pass diff --git a/InvenTree/plugin/broken/broken_sample.py b/InvenTree/plugin/broken/broken_sample.py index b873b93fa6..30f9b1a1ea 100644 --- a/InvenTree/plugin/broken/broken_sample.py +++ b/InvenTree/plugin/broken/broken_sample.py @@ -1,4 +1,5 @@ """Sample of a broken plugin.""" + from plugin import InvenTreePlugin diff --git a/InvenTree/plugin/builtin/barcodes/inventree_barcode.py b/InvenTree/plugin/builtin/barcodes/inventree_barcode.py index 7b0e7a4f7f..83c7b48f17 100644 --- a/InvenTree/plugin/builtin/barcodes/inventree_barcode.py +++ b/InvenTree/plugin/builtin/barcodes/inventree_barcode.py @@ -21,11 +21,11 @@ from plugin.mixins import BarcodeMixin class InvenTreeInternalBarcodePlugin(BarcodeMixin, InvenTreePlugin): """Builtin BarcodePlugin for matching and generating internal barcodes.""" - NAME = "InvenTreeBarcode" - TITLE = _("InvenTree Barcodes") - DESCRIPTION = _("Provides native support for barcodes") - VERSION = "2.0.0" - AUTHOR = _("InvenTree contributors") + NAME = 'InvenTreeBarcode' + TITLE = _('InvenTree Barcodes') + DESCRIPTION = _('Provides native support for barcodes') + VERSION = '2.0.0' + AUTHOR = _('InvenTree contributors') @staticmethod def get_supported_barcode_models(): diff --git a/InvenTree/plugin/builtin/barcodes/test_inventree_barcode.py b/InvenTree/plugin/builtin/barcodes/test_inventree_barcode.py index 2635471a3f..388376f002 100644 --- a/InvenTree/plugin/builtin/barcodes/test_inventree_barcode.py +++ b/InvenTree/plugin/builtin/barcodes/test_inventree_barcode.py @@ -10,25 +10,16 @@ from InvenTree.unit_test import InvenTreeAPITestCase class TestInvenTreeBarcode(InvenTreeAPITestCase): """Tests for the integrated InvenTreeBarcode barcode plugin.""" - fixtures = [ - 'category', - 'part', - 'location', - 'stock', - 'company', - 'supplier_part', - ] + fixtures = ['category', 'part', 'location', 'stock', 'company', 'supplier_part'] def test_assign_errors(self): """Test error cases for assignment action.""" def test_assert_error(barcode_data): response = self.post( - reverse('api-barcode-link'), format='json', - data={ - 'barcode': barcode_data, - 'stockitem': 521 - }, + reverse('api-barcode-link'), + format='json', + data={'barcode': barcode_data, 'stockitem': 521}, expected_code=400, ) @@ -46,34 +37,25 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): def assign(self, data, expected_code=None): """Perform a 'barcode assign' request""" return self.post( - reverse('api-barcode-link'), - data=data, - expected_code=expected_code + reverse('api-barcode-link'), data=data, expected_code=expected_code ) def unassign(self, data, expected_code=None): """Perform a 'barcode unassign' request""" return self.post( - reverse('api-barcode-unlink'), - data=data, - expected_code=expected_code, + reverse('api-barcode-unlink'), data=data, expected_code=expected_code ) def scan(self, data, expected_code=None): """Perform a 'scan' operation""" return self.post( - reverse('api-barcode-scan'), - data=data, - expected_code=expected_code + reverse('api-barcode-scan'), data=data, expected_code=expected_code ) def test_unassign_errors(self): """Test various error conditions for the barcode unassign endpoint""" # Fail without any fields provided - response = self.unassign( - {}, - expected_code=400, - ) + response = self.unassign({}, expected_code=400) self.assertIn('Missing data: Provide one of', str(response.data['error'])) @@ -89,45 +71,25 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): self.assertIn('Multiple conflicting fields:', str(response.data['error'])) # Fail with an invalid StockItem instance - response = self.unassign( - { - 'stockitem': 'invalid', - }, - expected_code=400, - ) + response = self.unassign({'stockitem': 'invalid'}, expected_code=400) self.assertIn('Incorrect type', str(response.data['stockitem'])) # Fail with an invalid Part instance - response = self.unassign( - { - 'part': 99999999999, - }, - expected_code=400, - ) + response = self.unassign({'part': 99999999999}, expected_code=400) self.assertIn('object does not exist', str(response.data['part'])) def test_assign_to_stock_item(self): """Test that we can assign a unique barcode to a StockItem object""" # Test without providing any fields - response = self.assign( - { - 'barcode': 'abcde', - }, - expected_code=400 - ) + response = self.assign({'barcode': 'abcde'}, expected_code=400) self.assertIn('Missing data:', str(response.data)) # Permission error check response = self.assign( - { - 'barcode': 'abcdefg', - 'part': 1, - 'stockitem': 1, - }, - expected_code=403 + {'barcode': 'abcdefg', 'part': 1, 'stockitem': 1}, expected_code=403 ) self.assignRole('part.change') @@ -135,12 +97,7 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): # Provide too many fields response = self.assign( - { - 'barcode': 'abcdefg', - 'part': 1, - 'stockitem': 1, - }, - expected_code=200 + {'barcode': 'abcdefg', 'part': 1, 'stockitem': 1}, expected_code=200 ) self.assertIn('Assigned barcode to part instance', str(response.data)) @@ -150,11 +107,7 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): # Assign a barcode to a StockItem instance response = self.assign( - data={ - 'barcode': bc_data, - 'stockitem': 521, - }, - expected_code=200, + data={'barcode': bc_data, 'stockitem': 521}, expected_code=200 ) data = response.data @@ -165,26 +118,17 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): si = stock.models.StockItem.objects.get(pk=521) self.assertEqual(si.barcode_data, bc_data) - self.assertEqual(si.barcode_hash, "2f5dba5c83a360599ba7665b2a4131c6") + self.assertEqual(si.barcode_hash, '2f5dba5c83a360599ba7665b2a4131c6') # Now test that we cannot assign this barcode to something else response = self.assign( - data={ - 'barcode': bc_data, - 'stockitem': 1, - }, - expected_code=400 + data={'barcode': bc_data, 'stockitem': 1}, expected_code=400 ) self.assertIn('Barcode matches existing item', str(response.data)) # Next, test that we can 'unassign' the barcode via the API - response = self.unassign( - { - 'stockitem': 521, - }, - expected_code=200, - ) + response = self.unassign({'stockitem': 521}, expected_code=200) si.refresh_from_db() @@ -198,34 +142,19 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): self.assignRole('part.change') # Test that an initial scan yields no results - response = self.scan( - { - 'barcode': barcode, - }, - expected_code=400 - ) + response = self.scan({'barcode': barcode}, expected_code=400) self.assignRole('part.change') # Attempt to assign to an invalid part ID response = self.assign( - { - 'barcode': barcode, - 'part': 99999999, - }, - expected_code=400, + {'barcode': barcode, 'part': 99999999}, expected_code=400 ) self.assertIn('object does not exist', str(response.data['part'])) # Test assigning to a valid part (should pass) - response = self.assign( - { - 'barcode': barcode, - 'part': 1, - }, - expected_code=200, - ) + response = self.assign({'barcode': barcode, 'part': 1}, expected_code=200) self.assertEqual(response.data['part']['pk'], 1) self.assertEqual(response.data['success'], 'Assigned barcode to part instance') @@ -236,37 +165,21 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): self.assertEqual(p.barcode_hash, 'bc39d07e9a395c7b5658c231bf910fae') # Scanning the barcode should now reveal the 'Part' instance - response = self.scan( - { - 'barcode': barcode, - }, - expected_code=200, - ) + response = self.scan({'barcode': barcode}, expected_code=200) self.assertIn('success', response.data) self.assertEqual(response.data['plugin'], 'InvenTreeBarcode') self.assertEqual(response.data['part']['pk'], 1) # Attempting to assign the same barcode to a different part should result in an error - response = self.assign( - { - 'barcode': barcode, - 'part': 2, - }, - expected_code=400, - ) + response = self.assign({'barcode': barcode, 'part': 2}, expected_code=400) self.assertIn('Barcode matches existing item', str(response.data['error'])) self.assignRole('part.change') # Now test that we can unassign the barcode data also - response = self.unassign( - { - 'part': 1, - }, - expected_code=200, - ) + response = self.unassign({'part': 1}, expected_code=200) p.refresh_from_db() @@ -279,22 +192,14 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): # Assign random barcode data to a StockLocation instance response = self.assign( - data={ - 'barcode': barcode, - 'stocklocation': 1, - }, - expected_code=403, + data={'barcode': barcode, 'stocklocation': 1}, expected_code=403 ) self.assignRole('stock_location.change') # Assign random barcode data to a StockLocation instance response = self.assign( - data={ - 'barcode': barcode, - 'stocklocation': 1, - }, - expected_code=200, + data={'barcode': barcode, 'stocklocation': 1}, expected_code=200 ) self.assertIn('success', response.data) @@ -308,22 +213,13 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): # Check that an error is thrown if we try to assign the same value again response = self.assign( - data={ - 'barcode': barcode, - 'stocklocation': 2, - }, - expected_code=400 + data={'barcode': barcode, 'stocklocation': 2}, expected_code=400 ) self.assertIn('Barcode matches existing item', str(response.data['error'])) # Now, unassign the barcode - response = self.unassign( - { - 'stocklocation': 1, - }, - expected_code=200, - ) + response = self.unassign({'stocklocation': 1}, expected_code=200) loc.refresh_from_db() self.assertEqual(loc.barcode_data, '') @@ -347,12 +243,7 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): for barcode in ['abcde', 'ABCDE', '12345']: si.assign_barcode(barcode_data=barcode) - response = self.scan( - { - 'barcode': barcode, - }, - expected_code=200, - ) + response = self.scan({'barcode': barcode}, expected_code=200) self.assertIn('success', response.data) self.assertEqual(response.data['stockitem']['pk'], 1) @@ -360,58 +251,37 @@ class TestInvenTreeBarcode(InvenTreeAPITestCase): def test_scan_inventree(self): """Test scanning of first-party barcodes""" # Scan a StockItem object (which does not exist) - response = self.scan( - { - 'barcode': '{"stockitem": 5}', - }, - expected_code=400, - ) + response = self.scan({'barcode': '{"stockitem": 5}'}, expected_code=400) self.assertIn('No match found for barcode data', str(response.data)) # Scan a StockItem object (which does exist) - response = self.scan( - { - 'barcode': '{"stockitem": 1}', - }, - expected_code=200 - ) + response = self.scan({'barcode': '{"stockitem": 1}'}, expected_code=200) self.assertIn('success', response.data) self.assertIn('stockitem', response.data) self.assertEqual(response.data['stockitem']['pk'], 1) # Scan a StockLocation object - response = self.scan( - { - 'barcode': '{"stocklocation": 5}', - }, - expected_code=200, - ) + response = self.scan({'barcode': '{"stocklocation": 5}'}, expected_code=200) self.assertIn('success', response.data) self.assertEqual(response.data['stocklocation']['pk'], 5) - self.assertEqual(response.data['stocklocation']['api_url'], '/api/stock/location/5/') - self.assertEqual(response.data['stocklocation']['web_url'], '/stock/location/5/') + self.assertEqual( + response.data['stocklocation']['api_url'], '/api/stock/location/5/' + ) + self.assertEqual( + response.data['stocklocation']['web_url'], '/stock/location/5/' + ) self.assertEqual(response.data['plugin'], 'InvenTreeBarcode') # Scan a Part object - response = self.scan( - { - 'barcode': '{"part": 5}' - }, - expected_code=200, - ) + response = self.scan({'barcode': '{"part": 5}'}, expected_code=200) self.assertEqual(response.data['part']['pk'], 5) # Scan a SupplierPart instance - response = self.scan( - { - 'barcode': '{"supplierpart": 1}', - }, - expected_code=200 - ) + response = self.scan({'barcode': '{"supplierpart": 1}'}, expected_code=200) self.assertEqual(response.data['supplierpart']['pk'], 1) self.assertEqual(response.data['plugin'], 'InvenTreeBarcode') diff --git a/InvenTree/plugin/builtin/integration/core_notifications.py b/InvenTree/plugin/builtin/integration/core_notifications.py index a988cf8287..e882562ba2 100644 --- a/InvenTree/plugin/builtin/integration/core_notifications.py +++ b/InvenTree/plugin/builtin/integration/core_notifications.py @@ -11,8 +11,7 @@ import InvenTree.email import InvenTree.helpers import InvenTree.tasks from plugin import InvenTreePlugin, registry -from plugin.mixins import (BulkNotificationMethod, SettingsContentMixin, - SettingsMixin) +from plugin.mixins import BulkNotificationMethod, SettingsContentMixin, SettingsMixin class PlgMixin: @@ -26,14 +25,16 @@ class PlgMixin: return InvenTreeCoreNotificationsPlugin -class InvenTreeCoreNotificationsPlugin(SettingsContentMixin, SettingsMixin, InvenTreePlugin): +class InvenTreeCoreNotificationsPlugin( + SettingsContentMixin, SettingsMixin, InvenTreePlugin +): """Core notification methods for InvenTree.""" - NAME = "InvenTreeCoreNotificationsPlugin" - TITLE = _("InvenTree Notifications") + NAME = 'InvenTreeCoreNotificationsPlugin' + TITLE = _('InvenTree Notifications') AUTHOR = _('InvenTree contributors') DESCRIPTION = _('Integrated outgoing notification methods') - VERSION = "1.0.0" + VERSION = '1.0.0' SETTINGS = { 'ENABLE_NOTIFICATION_EMAILS': { @@ -44,7 +45,9 @@ class InvenTreeCoreNotificationsPlugin(SettingsContentMixin, SettingsMixin, Inve }, 'ENABLE_NOTIFICATION_SLACK': { 'name': _('Enable slack notifications'), - 'description': _('Allow sending of slack channel messages for event notifications'), + 'description': _( + 'Allow sending of slack channel messages for event notifications' + ), 'default': False, 'validator': bool, }, @@ -71,11 +74,7 @@ class InvenTreeCoreNotificationsPlugin(SettingsContentMixin, SettingsMixin, Inve METHOD_NAME = 'mail' METHOD_ICON = 'fa-envelope' - CONTEXT_EXTRA = [ - ('template', ), - ('template', 'html', ), - ('template', 'subject', ), - ] + CONTEXT_EXTRA = [('template',), ('template', 'html'), ('template', 'subject')] GLOBAL_SETTING = 'ENABLE_NOTIFICATION_EMAILS' USER_SETTING = { 'name': _('Enable email notifications'), @@ -89,7 +88,6 @@ class InvenTreeCoreNotificationsPlugin(SettingsContentMixin, SettingsMixin, Inve allowed_users = [] for user in self.targets: - if not user.is_active: # Ignore any users who have been deactivated continue @@ -99,17 +97,19 @@ class InvenTreeCoreNotificationsPlugin(SettingsContentMixin, SettingsMixin, Inve if allows_emails: allowed_users.append(user) - return EmailAddress.objects.filter( - user__in=allowed_users, - ) + return EmailAddress.objects.filter(user__in=allowed_users) def send_bulk(self): """Send the notifications out via email.""" - html_message = render_to_string(self.context['template']['html'], self.context) + html_message = render_to_string( + self.context['template']['html'], self.context + ) targets = self.targets.values_list('email', flat=True) # Prefix the 'instance title' to the email subject - instance_title = common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE') + instance_title = common.models.InvenTreeSetting.get_setting( + 'INVENTREE_INSTANCE' + ) subject = self.context['template'].get('subject', '') @@ -139,32 +139,37 @@ class InvenTreeCoreNotificationsPlugin(SettingsContentMixin, SettingsMixin, Inve if not url: return False - ret = requests.post(url, json={ - 'text': str(self.context['message']), - 'blocks': [ - { - "type": "section", - "text": { - "type": "plain_text", - "text": str(self.context['name']) - } - }, - { - "type": "section", - "text": { - "type": "mrkdwn", - "text": str(self.context['message']) - }, - "accessory": { - "type": "button", - "text": { - "type": "plain_text", - "text": str(_("Open link")), "emoji": True + ret = requests.post( + url, + json={ + 'text': str(self.context['message']), + 'blocks': [ + { + 'type': 'section', + 'text': { + 'type': 'plain_text', + 'text': str(self.context['name']), }, - "value": f'{self.category}_{self.obj.pk}', - "url": self.context['link'], - "action_id": "button-action" - } - }] - }) + }, + { + 'type': 'section', + 'text': { + 'type': 'mrkdwn', + 'text': str(self.context['message']), + }, + 'accessory': { + 'type': 'button', + 'text': { + 'type': 'plain_text', + 'text': str(_('Open link')), + 'emoji': True, + }, + 'value': f'{self.category}_{self.obj.pk}', + 'url': self.context['link'], + 'action_id': 'button-action', + }, + }, + ], + }, + ) return ret.ok diff --git a/InvenTree/plugin/builtin/integration/currency_exchange.py b/InvenTree/plugin/builtin/integration/currency_exchange.py index 33788c0ba2..7a57371ce4 100644 --- a/InvenTree/plugin/builtin/integration/currency_exchange.py +++ b/InvenTree/plugin/builtin/integration/currency_exchange.py @@ -1,6 +1,5 @@ """Builtin plugin for requesting exchange rates from an external API.""" - import logging from django.utils.translation import gettext_lazy as _ @@ -17,31 +16,31 @@ class InvenTreeCurrencyExchange(APICallMixin, CurrencyExchangeMixin, InvenTreePl Fetches exchange rate information from frankfurter.app """ - NAME = "InvenTreeCurrencyExchange" - SLUG = "inventreecurrencyexchange" + NAME = 'InvenTreeCurrencyExchange' + SLUG = 'inventreecurrencyexchange' AUTHOR = _('InvenTree contributors') - TITLE = _("InvenTree Currency Exchange") - DESCRIPTION = _("Default currency exchange integration") - VERSION = "1.0.0" + TITLE = _('InvenTree Currency Exchange') + DESCRIPTION = _('Default currency exchange integration') + VERSION = '1.0.0' def update_exchange_rates(self, base_currency: str, symbols: list[str]) -> dict: """Request exchange rate data from external API""" response = self.api_call( 'latest', - url_args={ - 'from': [base_currency], - 'to': symbols, - }, - simple_response=False + url_args={'from': [base_currency], 'to': symbols}, + simple_response=False, ) if response.status_code == 200: - rates = response.json().get('rates', {}) rates[base_currency] = 1.00 return rates - logger.warning("Failed to update exchange rates from %s: Server returned status %s", self.api_url, response.status_code) + logger.warning( + 'Failed to update exchange rates from %s: Server returned status %s', + self.api_url, + response.status_code, + ) return None @property diff --git a/InvenTree/plugin/builtin/integration/test_core_notifications.py b/InvenTree/plugin/builtin/integration/test_core_notifications.py index 60ed1d6d35..2c6e9f9455 100644 --- a/InvenTree/plugin/builtin/integration/test_core_notifications.py +++ b/InvenTree/plugin/builtin/integration/test_core_notifications.py @@ -4,8 +4,9 @@ from django.core import mail from part.test_part import BaseNotificationIntegrationTest from plugin import registry -from plugin.builtin.integration.core_notifications import \ - InvenTreeCoreNotificationsPlugin +from plugin.builtin.integration.core_notifications import ( + InvenTreeCoreNotificationsPlugin, +) from plugin.models import NotificationUserSetting @@ -25,7 +26,7 @@ class CoreNotificationTestTests(BaseNotificationIntegrationTest): value=True, change_user=self.user, user=self.user, - method=InvenTreeCoreNotificationsPlugin.EmailNotification.METHOD_NAME + method=InvenTreeCoreNotificationsPlugin.EmailNotification.METHOD_NAME, ) # run through diff --git a/InvenTree/plugin/builtin/labels/inventree_label.py b/InvenTree/plugin/builtin/labels/inventree_label.py index 8fa02c9042..f4d18c0a73 100644 --- a/InvenTree/plugin/builtin/labels/inventree_label.py +++ b/InvenTree/plugin/builtin/labels/inventree_label.py @@ -16,11 +16,11 @@ class InvenTreeLabelPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlugin): which is made available for download. """ - NAME = "InvenTreeLabel" - TITLE = _("InvenTree PDF label printer") - DESCRIPTION = _("Provides native support for printing PDF labels") - VERSION = "1.0.0" - AUTHOR = _("InvenTree contributors") + NAME = 'InvenTreeLabel' + TITLE = _('InvenTree PDF label printer') + DESCRIPTION = _('Provides native support for printing PDF labels') + VERSION = '1.0.0' + AUTHOR = _('InvenTree contributors') BLOCKING_PRINT = True @@ -30,7 +30,7 @@ class InvenTreeLabelPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlugin): 'description': _('Enable debug mode - returns raw HTML instead of PDF'), 'validator': bool, 'default': False, - }, + } } def print_labels(self, label: LabelTemplate, items: list, request, **kwargs): @@ -45,7 +45,6 @@ class InvenTreeLabelPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlugin): output_file = None for item in items: - label.object_to_print = item outputs.append(self.print_label(label, request, debug=debug, **kwargs)) @@ -70,15 +69,12 @@ class InvenTreeLabelPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlugin): output_file = ContentFile(pdf, 'labels.pdf') # Save the generated file to the database - output = LabelOutput.objects.create( - label=output_file, - user=request.user - ) + output = LabelOutput.objects.create(label=output_file, user=request.user) return JsonResponse({ 'file': output.label.url, 'success': True, - 'message': f'{len(items)} labels generated' + 'message': f'{len(items)} labels generated', }) def print_label(self, label: LabelTemplate, request, **kwargs): diff --git a/InvenTree/plugin/builtin/labels/label_sheet.py b/InvenTree/plugin/builtin/labels/label_sheet.py index 4a48edbe25..1d16aca1b9 100644 --- a/InvenTree/plugin/builtin/labels/label_sheet.py +++ b/InvenTree/plugin/builtin/labels/label_sheet.py @@ -26,26 +26,26 @@ class LabelPrintingOptionsSerializer(serializers.Serializer): choices=report.helpers.report_page_size_options(), default='A4', label=_('Page Size'), - help_text=_('Page size for the label sheet') + help_text=_('Page size for the label sheet'), ) skip = serializers.IntegerField( default=0, label=_('Skip Labels'), help_text=_('Skip this number of labels when printing label sheets'), - min_value=0 + min_value=0, ) border = serializers.BooleanField( default=False, label=_('Border'), - help_text=_('Print a border around each label') + help_text=_('Print a border around each label'), ) landscape = serializers.BooleanField( default=False, label=_('Landscape'), - help_text=_('Print the label sheet in landscape mode') + help_text=_('Print the label sheet in landscape mode'), ) @@ -56,11 +56,11 @@ class InvenTreeLabelSheetPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlug and returns the resulting PDF file. """ - NAME = "InvenTreeLabelSheet" - TITLE = _("InvenTree Label Sheet Printer") - DESCRIPTION = _("Arrays multiple labels onto a single sheet") - VERSION = "1.0.0" - AUTHOR = _("InvenTree contributors") + NAME = 'InvenTreeLabelSheet' + TITLE = _('InvenTree Label Sheet Printer') + DESCRIPTION = _('Arrays multiple labels onto a single sheet') + VERSION = '1.0.0' + AUTHOR = _('InvenTree contributors') BLOCKING_PRINT = True @@ -92,7 +92,7 @@ class InvenTreeLabelSheetPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlug n_cells = n_cols * n_rows if n_cells == 0: - raise ValidationError(_("Label is too large for page size")) + raise ValidationError(_('Label is too large for page size')) # Prepend the required number of skipped null labels items = [None] * skip + list(items) @@ -101,16 +101,16 @@ class InvenTreeLabelSheetPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlug # Data to pass through to each page document_data = { - "border": border, - "landscape": landscape, - "page_width": page_width, - "page_height": page_height, - "label_width": label.width, - "label_height": label.height, - "n_labels": n_labels, - "n_pages": math.ceil(n_labels / n_cells), - "n_cols": n_cols, - "n_rows": n_rows, + 'border': border, + 'landscape': landscape, + 'page_width': page_width, + 'page_height': page_height, + 'label_width': label.width, + 'label_height': label.height, + 'n_labels': n_labels, + 'n_pages': math.ceil(n_labels / n_cells), + 'n_cols': n_cols, + 'n_rows': n_rows, } pages = [] @@ -118,14 +118,15 @@ class InvenTreeLabelSheetPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlug idx = 0 while idx < n_labels: - - if page := self.print_page(label, items[idx:idx + n_cells], request, **document_data): + if page := self.print_page( + label, items[idx : idx + n_cells], request, **document_data + ): pages.append(page) idx += n_cells if len(pages) == 0: - raise ValidationError(_("No labels were generated")) + raise ValidationError(_('No labels were generated')) # Render to a single HTML document html_data = self.wrap_pages(pages, **document_data) @@ -136,15 +137,12 @@ class InvenTreeLabelSheetPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlug output_file = ContentFile(document, 'labels.pdf') - output = LabelOutput.objects.create( - label=output_file, - user=request.user - ) + output = LabelOutput.objects.create(label=output_file, user=request.user) return JsonResponse({ 'file': output.label.url, 'success': True, - 'message': f'{len(items)} labels generated' + 'message': f'{len(items)} labels generated', }) def print_page(self, label: LabelTemplate, items: list, request, **kwargs): @@ -173,7 +171,6 @@ class InvenTreeLabelSheetPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlug html += "" for col in range(n_cols): - # Cell index idx = row * n_cols + col @@ -190,22 +187,20 @@ class InvenTreeLabelSheetPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlug # Render the individual label template # Note that we disable @page styling for this cell = label.render_as_string( - request, - target_object=items[idx], - insert_page_style=False + request, target_object=items[idx], insert_page_style=False ) html += cell except Exception as exc: - logger.exception("Error rendering label: %s", str(exc)) + logger.exception('Error rendering label: %s', str(exc)) html += """
    """ - html += "" + html += '' - html += "" + html += '' - html += "" + html += '' return html @@ -229,20 +224,24 @@ class InvenTreeLabelSheetPlugin(LabelPrintingMixin, SettingsMixin, InvenTreePlug cell_styles = [] for row in range(n_rows): - cell_styles.append(f""" + cell_styles.append( + f""" .label-sheet-row-{row} {{ top: {row * label_height}mm; }} - """) + """ + ) for col in range(n_cols): - cell_styles.append(f""" + cell_styles.append( + f""" .label-sheet-col-{col} {{ left: {col * label_width}mm; }} - """) + """ + ) - cell_styles = "\n".join(cell_styles) + cell_styles = '\n'.join(cell_styles) return f""" diff --git a/InvenTree/plugin/builtin/suppliers/digikey.py b/InvenTree/plugin/builtin/suppliers/digikey.py index bf8da46dd1..c441f9a9d9 100644 --- a/InvenTree/plugin/builtin/suppliers/digikey.py +++ b/InvenTree/plugin/builtin/suppliers/digikey.py @@ -12,19 +12,19 @@ from plugin.mixins import SettingsMixin, SupplierBarcodeMixin class DigiKeyPlugin(SupplierBarcodeMixin, SettingsMixin, InvenTreePlugin): """Plugin to integrate the DigiKey API into Inventree.""" - NAME = "DigiKeyPlugin" - TITLE = _("Supplier Integration - DigiKey") - DESCRIPTION = _("Provides support for scanning DigiKey barcodes") - VERSION = "1.0.0" - AUTHOR = _("InvenTree contributors") + NAME = 'DigiKeyPlugin' + TITLE = _('Supplier Integration - DigiKey') + DESCRIPTION = _('Provides support for scanning DigiKey barcodes') + VERSION = '1.0.0' + AUTHOR = _('InvenTree contributors') - DEFAULT_SUPPLIER_NAME = "DigiKey" + DEFAULT_SUPPLIER_NAME = 'DigiKey' SETTINGS = { - "SUPPLIER_ID": { - "name": _("Supplier"), - "description": _("The Supplier which acts as 'DigiKey'"), - "model": "company.company", + 'SUPPLIER_ID': { + 'name': _('Supplier'), + 'description': _("The Supplier which acts as 'DigiKey'"), + 'model': 'company.company', } } diff --git a/InvenTree/plugin/builtin/suppliers/lcsc.py b/InvenTree/plugin/builtin/suppliers/lcsc.py index d8a820ac15..5228b80fed 100644 --- a/InvenTree/plugin/builtin/suppliers/lcsc.py +++ b/InvenTree/plugin/builtin/suppliers/lcsc.py @@ -14,29 +14,29 @@ from plugin.mixins import SettingsMixin, SupplierBarcodeMixin class LCSCPlugin(SupplierBarcodeMixin, SettingsMixin, InvenTreePlugin): """Plugin to integrate the LCSC API into Inventree.""" - NAME = "LCSCPlugin" - TITLE = _("Supplier Integration - LCSC") - DESCRIPTION = _("Provides support for scanning LCSC barcodes") - VERSION = "1.0.0" - AUTHOR = _("InvenTree contributors") + NAME = 'LCSCPlugin' + TITLE = _('Supplier Integration - LCSC') + DESCRIPTION = _('Provides support for scanning LCSC barcodes') + VERSION = '1.0.0' + AUTHOR = _('InvenTree contributors') - DEFAULT_SUPPLIER_NAME = "LCSC" + DEFAULT_SUPPLIER_NAME = 'LCSC' SETTINGS = { - "SUPPLIER_ID": { - "name": _("Supplier"), - "description": _("The Supplier which acts as 'LCSC'"), - "model": "company.company", + 'SUPPLIER_ID': { + 'name': _('Supplier'), + 'description': _("The Supplier which acts as 'LCSC'"), + 'model': 'company.company', } } - LCSC_BARCODE_REGEX = re.compile(r"^{((?:[^:,]+:[^:,]*,)*(?:[^:,]+:[^:,]*))}$") + LCSC_BARCODE_REGEX = re.compile(r'^{((?:[^:,]+:[^:,]*,)*(?:[^:,]+:[^:,]*))}$') # Custom field mapping for LCSC barcodes LCSC_FIELDS = { - "pm": SupplierBarcodeMixin.MANUFACTURER_PART_NUMBER, - "pc": SupplierBarcodeMixin.SUPPLIER_PART_NUMBER, - "qty": SupplierBarcodeMixin.QUANTITY, - "on": SupplierBarcodeMixin.SUPPLIER_ORDER_NUMBER, + 'pm': SupplierBarcodeMixin.MANUFACTURER_PART_NUMBER, + 'pc': SupplierBarcodeMixin.SUPPLIER_PART_NUMBER, + 'qty': SupplierBarcodeMixin.QUANTITY, + 'on': SupplierBarcodeMixin.SUPPLIER_ORDER_NUMBER, } def extract_barcode_fields(self, barcode_data: str) -> dict[str, str]: @@ -50,13 +50,10 @@ class LCSCPlugin(SupplierBarcodeMixin, SettingsMixin, InvenTreePlugin): # Extract fields fields = SupplierBarcodeMixin.split_fields( - barcode_data, - delimiter=',', - header='{', - trailer='}', + barcode_data, delimiter=',', header='{', trailer='}' ) - fields = dict(pair.split(":") for pair in fields) + fields = dict(pair.split(':') for pair in fields) barcode_fields = {} diff --git a/InvenTree/plugin/builtin/suppliers/mouser.py b/InvenTree/plugin/builtin/suppliers/mouser.py index 0c2dd80022..e9fd53a052 100644 --- a/InvenTree/plugin/builtin/suppliers/mouser.py +++ b/InvenTree/plugin/builtin/suppliers/mouser.py @@ -12,18 +12,18 @@ from plugin.mixins import SettingsMixin, SupplierBarcodeMixin class MouserPlugin(SupplierBarcodeMixin, SettingsMixin, InvenTreePlugin): """Plugin to integrate the Mouser API into Inventree.""" - NAME = "MouserPlugin" - TITLE = _("Supplier Integration - Mouser") - DESCRIPTION = _("Provides support for scanning Mouser barcodes") - VERSION = "1.0.0" - AUTHOR = _("InvenTree contributors") + NAME = 'MouserPlugin' + TITLE = _('Supplier Integration - Mouser') + DESCRIPTION = _('Provides support for scanning Mouser barcodes') + VERSION = '1.0.0' + AUTHOR = _('InvenTree contributors') - DEFAULT_SUPPLIER_NAME = "Mouser" + DEFAULT_SUPPLIER_NAME = 'Mouser' SETTINGS = { - "SUPPLIER_ID": { - "name": _("Supplier"), - "description": _("The Supplier which acts as 'Mouser'"), - "model": "company.company", + 'SUPPLIER_ID': { + 'name': _('Supplier'), + 'description': _("The Supplier which acts as 'Mouser'"), + 'model': 'company.company', } } diff --git a/InvenTree/plugin/builtin/suppliers/test_supplier_barcodes.py b/InvenTree/plugin/builtin/suppliers/test_supplier_barcodes.py index 73f19d3fec..afd1dea435 100644 --- a/InvenTree/plugin/builtin/suppliers/test_supplier_barcodes.py +++ b/InvenTree/plugin/builtin/suppliers/test_supplier_barcodes.py @@ -12,32 +12,35 @@ from stock.models import StockItem, StockLocation class SupplierBarcodeTests(InvenTreeAPITestCase): """Tests barcode parsing for all suppliers.""" - SCAN_URL = reverse("api-barcode-scan") + SCAN_URL = reverse('api-barcode-scan') @classmethod def setUpTestData(cls): """Create supplier parts for barcodes.""" super().setUpTestData() - part = Part.objects.create(name="Test Part", description="Test Part") + part = Part.objects.create(name='Test Part', description='Test Part') manufacturer = Company.objects.create( - name="Test Manufacturer", is_manufacturer=True) + name='Test Manufacturer', is_manufacturer=True + ) mpart1 = ManufacturerPart.objects.create( - part=part, manufacturer=manufacturer, MPN="MC34063ADR") + part=part, manufacturer=manufacturer, MPN='MC34063ADR' + ) mpart2 = ManufacturerPart.objects.create( - part=part, manufacturer=manufacturer, MPN="LDK320ADU33R") + part=part, manufacturer=manufacturer, MPN='LDK320ADU33R' + ) - supplier = Company.objects.create(name="Supplier", is_supplier=True) - mouser = Company.objects.create(name="Mouser Test", is_supplier=True) + supplier = Company.objects.create(name='Supplier', is_supplier=True) + mouser = Company.objects.create(name='Mouser Test', is_supplier=True) supplier_parts = [ - SupplierPart(SKU="296-LM358BIDDFRCT-ND", part=part, supplier=supplier), - SupplierPart(SKU="1", part=part, manufacturer_part=mpart1, supplier=mouser), - SupplierPart(SKU="2", part=part, manufacturer_part=mpart2, supplier=mouser), - SupplierPart(SKU="C312270", part=part, supplier=supplier), - SupplierPart(SKU="WBP-302", part=part, supplier=supplier), + SupplierPart(SKU='296-LM358BIDDFRCT-ND', part=part, supplier=supplier), + SupplierPart(SKU='1', part=part, manufacturer_part=mpart1, supplier=mouser), + SupplierPart(SKU='2', part=part, manufacturer_part=mpart2, supplier=mouser), + SupplierPart(SKU='C312270', part=part, supplier=supplier), + SupplierPart(SKU='WBP-302', part=part, supplier=supplier), ] SupplierPart.objects.bulk_create(supplier_parts) @@ -45,87 +48,101 @@ class SupplierBarcodeTests(InvenTreeAPITestCase): def test_digikey_barcode(self): """Test digikey barcode""" - result = self.post(self.SCAN_URL, data={"barcode": DIGIKEY_BARCODE}, expected_code=200) + result = self.post( + self.SCAN_URL, data={'barcode': DIGIKEY_BARCODE}, expected_code=200 + ) self.assertEqual(result.data['plugin'], 'DigiKeyPlugin') - supplier_part_data = result.data.get("supplierpart") + supplier_part_data = result.data.get('supplierpart') self.assertIn('pk', supplier_part_data) - supplier_part = SupplierPart.objects.get(pk=supplier_part_data["pk"]) - self.assertEqual(supplier_part.SKU, "296-LM358BIDDFRCT-ND") + supplier_part = SupplierPart.objects.get(pk=supplier_part_data['pk']) + self.assertEqual(supplier_part.SKU, '296-LM358BIDDFRCT-ND') def test_digikey_2_barcode(self): """Test digikey barcode which uses 30P instead of P""" - result = self.post(self.SCAN_URL, data={"barcode": DIGIKEY_BARCODE_2}, expected_code=200) + result = self.post( + self.SCAN_URL, data={'barcode': DIGIKEY_BARCODE_2}, expected_code=200 + ) self.assertEqual(result.data['plugin'], 'DigiKeyPlugin') - supplier_part_data = result.data.get("supplierpart") + supplier_part_data = result.data.get('supplierpart') self.assertIn('pk', supplier_part_data) - supplier_part = SupplierPart.objects.get(pk=supplier_part_data["pk"]) - self.assertEqual(supplier_part.SKU, "296-LM358BIDDFRCT-ND") + supplier_part = SupplierPart.objects.get(pk=supplier_part_data['pk']) + self.assertEqual(supplier_part.SKU, '296-LM358BIDDFRCT-ND') def test_digikey_3_barcode(self): """Test digikey barcode which is invalid""" - self.post(self.SCAN_URL, data={"barcode": DIGIKEY_BARCODE_3}, expected_code=400) + self.post(self.SCAN_URL, data={'barcode': DIGIKEY_BARCODE_3}, expected_code=400) def test_mouser_barcode(self): """Test mouser barcode with custom order number.""" - result = self.post(self.SCAN_URL, data={"barcode": MOUSER_BARCODE}, expected_code=200) + result = self.post( + self.SCAN_URL, data={'barcode': MOUSER_BARCODE}, expected_code=200 + ) - supplier_part_data = result.data.get("supplierpart") + supplier_part_data = result.data.get('supplierpart') self.assertIn('pk', supplier_part_data) - supplier_part = SupplierPart.objects.get(pk=supplier_part_data["pk"]) + supplier_part = SupplierPart.objects.get(pk=supplier_part_data['pk']) self.assertEqual(supplier_part.SKU, '1') def test_old_mouser_barcode(self): """Test old mouser barcode with messed up header.""" - result = self.post(self.SCAN_URL, data={"barcode": MOUSER_BARCODE_OLD}, expected_code=200) + result = self.post( + self.SCAN_URL, data={'barcode': MOUSER_BARCODE_OLD}, expected_code=200 + ) - supplier_part_data = result.data.get("supplierpart") + supplier_part_data = result.data.get('supplierpart') self.assertIn('pk', supplier_part_data) - supplier_part = SupplierPart.objects.get(pk=supplier_part_data["pk"]) + supplier_part = SupplierPart.objects.get(pk=supplier_part_data['pk']) self.assertEqual(supplier_part.SKU, '2') def test_lcsc_barcode(self): """Test LCSC barcode.""" - result = self.post(self.SCAN_URL, data={"barcode": LCSC_BARCODE}, expected_code=200) + result = self.post( + self.SCAN_URL, data={'barcode': LCSC_BARCODE}, expected_code=200 + ) self.assertEqual(result.data['plugin'], 'LCSCPlugin') - supplier_part_data = result.data.get("supplierpart") + supplier_part_data = result.data.get('supplierpart') self.assertIn('pk', supplier_part_data) - supplier_part = SupplierPart.objects.get(pk=supplier_part_data["pk"]) + supplier_part = SupplierPart.objects.get(pk=supplier_part_data['pk']) self.assertEqual(supplier_part.SKU, 'C312270') def test_tme_qrcode(self): """Test TME QR-Code.""" - result = self.post(self.SCAN_URL, data={"barcode": TME_QRCODE}, expected_code=200) + result = self.post( + self.SCAN_URL, data={'barcode': TME_QRCODE}, expected_code=200 + ) self.assertEqual(result.data['plugin'], 'TMEPlugin') - supplier_part_data = result.data.get("supplierpart") + supplier_part_data = result.data.get('supplierpart') self.assertIn('pk', supplier_part_data) - supplier_part = SupplierPart.objects.get(pk=supplier_part_data["pk"]) + supplier_part = SupplierPart.objects.get(pk=supplier_part_data['pk']) self.assertEqual(supplier_part.SKU, 'WBP-302') def test_tme_barcode2d(self): """Test TME DataMatrix-Code.""" - result = self.post(self.SCAN_URL, data={"barcode": TME_DATAMATRIX_CODE}, expected_code=200) + result = self.post( + self.SCAN_URL, data={'barcode': TME_DATAMATRIX_CODE}, expected_code=200 + ) self.assertEqual(result.data['plugin'], 'TMEPlugin') - supplier_part_data = result.data.get("supplierpart") + supplier_part_data = result.data.get('supplierpart') self.assertIn('pk', supplier_part_data) - supplier_part = SupplierPart.objects.get(pk=supplier_part_data["pk"]) + supplier_part = SupplierPart.objects.get(pk=supplier_part_data['pk']) self.assertEqual(supplier_part.SKU, 'WBP-302') @@ -136,28 +153,27 @@ class SupplierBarcodePOReceiveTests(InvenTreeAPITestCase): """Create supplier part and purchase_order.""" super().setUp() - part = Part.objects.create(name="Test Part", description="Test Part") - supplier = Company.objects.create(name="Supplier", is_supplier=True) + part = Part.objects.create(name='Test Part', description='Test Part') + supplier = Company.objects.create(name='Supplier', is_supplier=True) manufacturer = Company.objects.create( - name="Test Manufacturer", is_manufacturer=True + name='Test Manufacturer', is_manufacturer=True ) - mouser = Company.objects.create(name="Mouser Test", is_supplier=True) + mouser = Company.objects.create(name='Mouser Test', is_supplier=True) mpart = ManufacturerPart.objects.create( - part=part, manufacturer=manufacturer, MPN="MC34063ADR" + part=part, manufacturer=manufacturer, MPN='MC34063ADR' ) self.purchase_order1 = PurchaseOrder.objects.create( - supplier_reference="72991337", supplier=supplier + supplier_reference='72991337', supplier=supplier ) supplier_parts1 = [ - SupplierPart(SKU=f"1_{i}", part=part, supplier=supplier) - for i in range(6) + SupplierPart(SKU=f'1_{i}', part=part, supplier=supplier) for i in range(6) ] supplier_parts1.insert( - 2, SupplierPart(SKU="296-LM358BIDDFRCT-ND", part=part, supplier=supplier) + 2, SupplierPart(SKU='296-LM358BIDDFRCT-ND', part=part, supplier=supplier) ) for supplier_part in supplier_parts1: @@ -165,18 +181,18 @@ class SupplierBarcodePOReceiveTests(InvenTreeAPITestCase): self.purchase_order1.add_line_item(supplier_part, 8) self.purchase_order2 = PurchaseOrder.objects.create( - reference="P0-1337", supplier=mouser + reference='P0-1337', supplier=mouser ) self.purchase_order2.place_order() supplier_parts2 = [ - SupplierPart(SKU=f"2_{i}", part=part, supplier=mouser) - for i in range(6) + SupplierPart(SKU=f'2_{i}', part=part, supplier=mouser) for i in range(6) ] - supplier_parts2.insert(3, SupplierPart( - SKU="42", part=part, manufacturer_part=mpart, supplier=mouser - )) + supplier_parts2.insert( + 3, + SupplierPart(SKU='42', part=part, manufacturer_part=mpart, supplier=mouser), + ) for supplier_part in supplier_parts2: supplier_part.save() @@ -185,168 +201,185 @@ class SupplierBarcodePOReceiveTests(InvenTreeAPITestCase): def test_receive(self): """Test receiving an item from a barcode.""" - url = reverse("api-barcode-po-receive") + url = reverse('api-barcode-po-receive') - result1 = self.post(url, data={"barcode": DIGIKEY_BARCODE}, expected_code=400) + result1 = self.post(url, data={'barcode': DIGIKEY_BARCODE}, expected_code=400) - assert result1.data["error"].startswith("No matching purchase order") + assert result1.data['error'].startswith('No matching purchase order') self.purchase_order1.place_order() - result2 = self.post(url, data={"barcode": DIGIKEY_BARCODE}, expected_code=200) - self.assertIn("success", result2.data) + result2 = self.post(url, data={'barcode': DIGIKEY_BARCODE}, expected_code=200) + self.assertIn('success', result2.data) - result3 = self.post(url, data={"barcode": DIGIKEY_BARCODE}, expected_code=400) - self.assertEqual(result3.data['error'], "Item has already been received") + result3 = self.post(url, data={'barcode': DIGIKEY_BARCODE}, expected_code=400) + self.assertEqual(result3.data['error'], 'Item has already been received') - result4 = self.post(url, data={"barcode": DIGIKEY_BARCODE[:-1]}, expected_code=400) - assert result4.data["error"].startswith( - "Failed to find pending line item for supplier part") + result4 = self.post( + url, data={'barcode': DIGIKEY_BARCODE[:-1]}, expected_code=400 + ) + assert result4.data['error'].startswith( + 'Failed to find pending line item for supplier part' + ) - result5 = self.post(reverse("api-barcode-scan"), data={"barcode": DIGIKEY_BARCODE}, expected_code=200) - stock_item = StockItem.objects.get(pk=result5.data["stockitem"]["pk"]) - assert stock_item.supplier_part.SKU == "296-LM358BIDDFRCT-ND" + result5 = self.post( + reverse('api-barcode-scan'), + data={'barcode': DIGIKEY_BARCODE}, + expected_code=200, + ) + stock_item = StockItem.objects.get(pk=result5.data['stockitem']['pk']) + assert stock_item.supplier_part.SKU == '296-LM358BIDDFRCT-ND' assert stock_item.quantity == 10 assert stock_item.location is None def test_receive_custom_order_number(self): """Test receiving an item from a barcode with a custom order number.""" - url = reverse("api-barcode-po-receive") - result1 = self.post(url, data={"barcode": MOUSER_BARCODE}) - assert "success" in result1.data + url = reverse('api-barcode-po-receive') + result1 = self.post(url, data={'barcode': MOUSER_BARCODE}) + assert 'success' in result1.data - result2 = self.post(reverse("api-barcode-scan"), data={"barcode": MOUSER_BARCODE}) - stock_item = StockItem.objects.get(pk=result2.data["stockitem"]["pk"]) - assert stock_item.supplier_part.SKU == "42" - assert stock_item.supplier_part.manufacturer_part.MPN == "MC34063ADR" + result2 = self.post( + reverse('api-barcode-scan'), data={'barcode': MOUSER_BARCODE} + ) + stock_item = StockItem.objects.get(pk=result2.data['stockitem']['pk']) + assert stock_item.supplier_part.SKU == '42' + assert stock_item.supplier_part.manufacturer_part.MPN == 'MC34063ADR' assert stock_item.quantity == 3 assert stock_item.location is None def test_receive_one_stock_location(self): """Test receiving an item when only one stock location exists""" - stock_location = StockLocation.objects.create(name="Test Location") + stock_location = StockLocation.objects.create(name='Test Location') - url = reverse("api-barcode-po-receive") - result1 = self.post(url, data={"barcode": MOUSER_BARCODE}) - assert "success" in result1.data + url = reverse('api-barcode-po-receive') + result1 = self.post(url, data={'barcode': MOUSER_BARCODE}) + assert 'success' in result1.data - result2 = self.post(reverse("api-barcode-scan"), data={"barcode": MOUSER_BARCODE}) - stock_item = StockItem.objects.get(pk=result2.data["stockitem"]["pk"]) + result2 = self.post( + reverse('api-barcode-scan'), data={'barcode': MOUSER_BARCODE} + ) + stock_item = StockItem.objects.get(pk=result2.data['stockitem']['pk']) assert stock_item.location == stock_location def test_receive_default_line_item_location(self): """Test receiving an item into the default line_item location""" - StockLocation.objects.create(name="Test Location 1") - stock_location2 = StockLocation.objects.create(name="Test Location 2") + StockLocation.objects.create(name='Test Location 1') + stock_location2 = StockLocation.objects.create(name='Test Location 2') - line_item = PurchaseOrderLineItem.objects.filter(part__SKU="42")[0] + line_item = PurchaseOrderLineItem.objects.filter(part__SKU='42')[0] line_item.destination = stock_location2 line_item.save() - url = reverse("api-barcode-po-receive") - result1 = self.post(url, data={"barcode": MOUSER_BARCODE}) - assert "success" in result1.data + url = reverse('api-barcode-po-receive') + result1 = self.post(url, data={'barcode': MOUSER_BARCODE}) + assert 'success' in result1.data - result2 = self.post(reverse("api-barcode-scan"), data={"barcode": MOUSER_BARCODE}) - stock_item = StockItem.objects.get(pk=result2.data["stockitem"]["pk"]) + result2 = self.post( + reverse('api-barcode-scan'), data={'barcode': MOUSER_BARCODE} + ) + stock_item = StockItem.objects.get(pk=result2.data['stockitem']['pk']) assert stock_item.location == stock_location2 def test_receive_default_part_location(self): """Test receiving an item into the default part location""" - StockLocation.objects.create(name="Test Location 1") - stock_location2 = StockLocation.objects.create(name="Test Location 2") + StockLocation.objects.create(name='Test Location 1') + stock_location2 = StockLocation.objects.create(name='Test Location 2') part = Part.objects.all()[0] part.default_location = stock_location2 part.save() - url = reverse("api-barcode-po-receive") - result1 = self.post(url, data={"barcode": MOUSER_BARCODE}) - assert "success" in result1.data + url = reverse('api-barcode-po-receive') + result1 = self.post(url, data={'barcode': MOUSER_BARCODE}) + assert 'success' in result1.data - result2 = self.post(reverse("api-barcode-scan"), data={"barcode": MOUSER_BARCODE}) - stock_item = StockItem.objects.get(pk=result2.data["stockitem"]["pk"]) + result2 = self.post( + reverse('api-barcode-scan'), data={'barcode': MOUSER_BARCODE} + ) + stock_item = StockItem.objects.get(pk=result2.data['stockitem']['pk']) assert stock_item.location == stock_location2 def test_receive_specific_order_and_location(self): """Test receiving an item from a specific order into a specific location""" - StockLocation.objects.create(name="Test Location 1") - stock_location2 = StockLocation.objects.create(name="Test Location 2") + StockLocation.objects.create(name='Test Location 1') + stock_location2 = StockLocation.objects.create(name='Test Location 2') - url = reverse("api-barcode-po-receive") - barcode = MOUSER_BARCODE.replace("\x1dKP0-1337", "") - result1 = self.post(url, data={ - "barcode": barcode, - "purchase_order": self.purchase_order2.pk, - "location": stock_location2.pk, - }) - assert "success" in result1.data + url = reverse('api-barcode-po-receive') + barcode = MOUSER_BARCODE.replace('\x1dKP0-1337', '') + result1 = self.post( + url, + data={ + 'barcode': barcode, + 'purchase_order': self.purchase_order2.pk, + 'location': stock_location2.pk, + }, + ) + assert 'success' in result1.data - result2 = self.post(reverse("api-barcode-scan"), data={"barcode": barcode}) - stock_item = StockItem.objects.get(pk=result2.data["stockitem"]["pk"]) + result2 = self.post(reverse('api-barcode-scan'), data={'barcode': barcode}) + stock_item = StockItem.objects.get(pk=result2.data['stockitem']['pk']) assert stock_item.location == stock_location2 def test_receive_missing_quantity(self): """Test receiving an with missing quantity information""" - url = reverse("api-barcode-po-receive") - barcode = MOUSER_BARCODE.replace("\x1dQ3", "") - response = self.post(url, data={"barcode": barcode}, expected_code=200) + url = reverse('api-barcode-po-receive') + barcode = MOUSER_BARCODE.replace('\x1dQ3', '') + response = self.post(url, data={'barcode': barcode}, expected_code=200) - assert "lineitem" in response.data - assert "quantity" not in response.data["lineitem"] + assert 'lineitem' in response.data + assert 'quantity' not in response.data['lineitem'] DIGIKEY_BARCODE = ( - "[)>\x1e06\x1dP296-LM358BIDDFRCT-ND\x1d1PLM358BIDDFR\x1dK\x1d1K72991337\x1d" - "10K85781337\x1d11K1\x1d4LPH\x1dQ10\x1d11ZPICK\x1d12Z15221337\x1d13Z361337" - "\x1d20Z0000000000000000000000000000000000000000000000000000000000000000000" - "00000000000000000000000000000000000000000000000000000000000000000000000000" - "0000000000000000000000000000000000" + '[)>\x1e06\x1dP296-LM358BIDDFRCT-ND\x1d1PLM358BIDDFR\x1dK\x1d1K72991337\x1d' + '10K85781337\x1d11K1\x1d4LPH\x1dQ10\x1d11ZPICK\x1d12Z15221337\x1d13Z361337' + '\x1d20Z0000000000000000000000000000000000000000000000000000000000000000000' + '00000000000000000000000000000000000000000000000000000000000000000000000000' + '0000000000000000000000000000000000' ) # Uses 30P instead of P DIGIKEY_BARCODE_2 = ( - "[)>\x1e06\x1d30P296-LM358BIDDFRCT-ND\x1dK\x1d1K72991337\x1d" - "10K85781337\x1d11K1\x1d4LPH\x1dQ10\x1d11ZPICK\x1d12Z15221337\x1d13Z361337" - "\x1d20Z0000000000000000000000000000000000000000000000000000000000000000000" - "00000000000000000000000000000000000000000000000000000000000000000000000000" - "0000000000000000000000000000000000" + '[)>\x1e06\x1d30P296-LM358BIDDFRCT-ND\x1dK\x1d1K72991337\x1d' + '10K85781337\x1d11K1\x1d4LPH\x1dQ10\x1d11ZPICK\x1d12Z15221337\x1d13Z361337' + '\x1d20Z0000000000000000000000000000000000000000000000000000000000000000000' + '00000000000000000000000000000000000000000000000000000000000000000000000000' + '0000000000000000000000000000000000' ) # Invalid code DIGIKEY_BARCODE_3 = ( - "[)>\x1e06\x1dPnonsense\x1d30Pnonsense\x1d1Pnonsense\x1dK\x1d1K72991337\x1d" - "10K85781337\x1d11K1\x1d4LPH\x1dQ10\x1d11ZPICK\x1d12Z15221337\x1d13Z361337" - "\x1d20Z0000000000000000000000000000000000000000000000000000000000000000000" - "00000000000000000000000000000000000000000000000000000000000000000000000000" - "0000000000000000000000000000000000" - + '[)>\x1e06\x1dPnonsense\x1d30Pnonsense\x1d1Pnonsense\x1dK\x1d1K72991337\x1d' + '10K85781337\x1d11K1\x1d4LPH\x1dQ10\x1d11ZPICK\x1d12Z15221337\x1d13Z361337' + '\x1d20Z0000000000000000000000000000000000000000000000000000000000000000000' + '00000000000000000000000000000000000000000000000000000000000000000000000000' + '0000000000000000000000000000000000' ) MOUSER_BARCODE = ( - "[)>\x1e06\x1dKP0-1337\x1d14K011\x1d1PMC34063ADR\x1dQ3\x1d11K073121337\x1d4" - "LMX\x1d1VTI\x1e\x04" + '[)>\x1e06\x1dKP0-1337\x1d14K011\x1d1PMC34063ADR\x1dQ3\x1d11K073121337\x1d4' + 'LMX\x1d1VTI\x1e\x04' ) MOUSER_BARCODE_OLD = ( - ">[)>06\x1dK21421337\x1d14K033\x1d1PLDK320ADU33R\x1dQ32\x1d11K060931337\x1d" - "4LCN\x1d1VSTMicro" + '>[)>06\x1dK21421337\x1d14K033\x1d1PLDK320ADU33R\x1dQ32\x1d11K060931337\x1d' + '4LCN\x1d1VSTMicro' ) LCSC_BARCODE = ( - "{pbn:PICK2009291337,on:SO2009291337,pc:C312270,pm:ST-1-102-A01-T000-RS,qty" - ":2,mc:,cc:1,pdi:34421807}" + '{pbn:PICK2009291337,on:SO2009291337,pc:C312270,pm:ST-1-102-A01-T000-RS,qty' + ':2,mc:,cc:1,pdi:34421807}' ) TME_QRCODE = ( - "QTY:1 PN:WBP-302 PO:19361337/1 CPO:PO-2023-06-08-001337 MFR:WISHERENTERPRI" - "SE MPN:WBP-302 RoHS https://www.tme.eu/details/WBP-302" + 'QTY:1 PN:WBP-302 PO:19361337/1 CPO:PO-2023-06-08-001337 MFR:WISHERENTERPRI' + 'SE MPN:WBP-302 RoHS https://www.tme.eu/details/WBP-302' ) -TME_DATAMATRIX_CODE = "PWBP-302 1PMPNWBP-302 Q1 K19361337/1" +TME_DATAMATRIX_CODE = 'PWBP-302 1PMPNWBP-302 Q1 K19361337/1' diff --git a/InvenTree/plugin/builtin/suppliers/tme.py b/InvenTree/plugin/builtin/suppliers/tme.py index 5bd3e266ed..cbb9d1bf63 100644 --- a/InvenTree/plugin/builtin/suppliers/tme.py +++ b/InvenTree/plugin/builtin/suppliers/tme.py @@ -14,31 +14,31 @@ from plugin.mixins import SettingsMixin, SupplierBarcodeMixin class TMEPlugin(SupplierBarcodeMixin, SettingsMixin, InvenTreePlugin): """Plugin to integrate the TME API into Inventree.""" - NAME = "TMEPlugin" - TITLE = _("Supplier Integration - TME") - DESCRIPTION = _("Provides support for scanning TME barcodes") - VERSION = "1.0.0" - AUTHOR = _("InvenTree contributors") + NAME = 'TMEPlugin' + TITLE = _('Supplier Integration - TME') + DESCRIPTION = _('Provides support for scanning TME barcodes') + VERSION = '1.0.0' + AUTHOR = _('InvenTree contributors') - DEFAULT_SUPPLIER_NAME = "TME" + DEFAULT_SUPPLIER_NAME = 'TME' SETTINGS = { - "SUPPLIER_ID": { - "name": _("Supplier"), - "description": _("The Supplier which acts as 'TME'"), - "model": "company.company", + 'SUPPLIER_ID': { + 'name': _('Supplier'), + 'description': _("The Supplier which acts as 'TME'"), + 'model': 'company.company', } } - TME_IS_QRCODE_REGEX = re.compile(r"([^\s:]+:[^\s:]+\s+)+(\S+(\s|$)+)+") - TME_IS_BARCODE2D_REGEX = re.compile(r"(([^\s]+)(\s+|$))+") + TME_IS_QRCODE_REGEX = re.compile(r'([^\s:]+:[^\s:]+\s+)+(\S+(\s|$)+)+') + TME_IS_BARCODE2D_REGEX = re.compile(r'(([^\s]+)(\s+|$))+') # Custom field mapping TME_QRCODE_FIELDS = { - "PN": SupplierBarcodeMixin.SUPPLIER_PART_NUMBER, - "CPO": SupplierBarcodeMixin.CUSTOMER_ORDER_NUMBER, - "PO": SupplierBarcodeMixin.SUPPLIER_ORDER_NUMBER, - "MPN": SupplierBarcodeMixin.MANUFACTURER_PART_NUMBER, - "QTY": SupplierBarcodeMixin.QUANTITY, + 'PN': SupplierBarcodeMixin.SUPPLIER_PART_NUMBER, + 'CPO': SupplierBarcodeMixin.CUSTOMER_ORDER_NUMBER, + 'PO': SupplierBarcodeMixin.SUPPLIER_ORDER_NUMBER, + 'MPN': SupplierBarcodeMixin.MANUFACTURER_PART_NUMBER, + 'QTY': SupplierBarcodeMixin.QUANTITY, } def extract_barcode_fields(self, barcode_data: str) -> dict[str, str]: @@ -48,9 +48,9 @@ class TMEPlugin(SupplierBarcodeMixin, SettingsMixin, InvenTreePlugin): if self.TME_IS_QRCODE_REGEX.fullmatch(barcode_data): # Custom QR Code format e.g. "QTY: 1 PN:12345" - for item in barcode_data.split(" "): - if ":" in item: - key, value = item.split(":") + for item in barcode_data.split(' '): + if ':' in item: + key, value = item.split(':') if key in self.TME_QRCODE_FIELDS: barcode_fields[self.TME_QRCODE_FIELDS[key]] = value @@ -58,17 +58,17 @@ class TMEPlugin(SupplierBarcodeMixin, SettingsMixin, InvenTreePlugin): elif self.TME_IS_BARCODE2D_REGEX.fullmatch(barcode_data): # 2D Barcode format e.g. "PWBP-302 1PMPNWBP-302 Q1 K19361337/1" - for item in barcode_data.split(" "): + for item in barcode_data.split(' '): for k, v in self.ecia_field_map().items(): if item.startswith(k): - barcode_fields[v] = item[len(k):] + barcode_fields[v] = item[len(k) :] else: return {} # Custom handling for order number if SupplierBarcodeMixin.CUSTOMER_ORDER_NUMBER in barcode_fields: order_number = barcode_fields[SupplierBarcodeMixin.CUSTOMER_ORDER_NUMBER] - order_number = order_number.split("/")[0] + order_number = order_number.split('/')[0] barcode_fields[SupplierBarcodeMixin.CUSTOMER_ORDER_NUMBER] = order_number return barcode_fields diff --git a/InvenTree/plugin/events.py b/InvenTree/plugin/events.py index 800233bd9a..4ac5ff0544 100644 --- a/InvenTree/plugin/events.py +++ b/InvenTree/plugin/events.py @@ -1,10 +1,5 @@ """Import helper for events.""" -from plugin.base.event.events import (process_event, register_event, - trigger_event) +from plugin.base.event.events import process_event, register_event, trigger_event -__all__ = [ - 'process_event', - 'register_event', - 'trigger_event', -] +__all__ = ['process_event', 'register_event', 'trigger_event'] diff --git a/InvenTree/plugin/helpers.py b/InvenTree/plugin/helpers.py index d896ce10c2..211641514a 100644 --- a/InvenTree/plugin/helpers.py +++ b/InvenTree/plugin/helpers.py @@ -41,11 +41,13 @@ class MixinImplementationError(ValueError): Mostly raised if constant is missing """ + pass class MixinNotImplementedError(NotImplementedError): """Error if necessary mixin function was not overwritten.""" + pass @@ -64,7 +66,7 @@ def log_error(error, reference: str = 'general'): def handle_error(error, do_raise: bool = True, do_log: bool = True, log_name: str = ''): """Handles an error and casts it as an IntegrationPluginError.""" package_path = traceback.extract_tb(error.__traceback__)[-1].filename - install_path = sysconfig.get_paths()["purelib"] + install_path = sysconfig.get_paths()['purelib'] try: package_name = pathlib.Path(package_path).relative_to(install_path).parts[0] @@ -73,7 +75,9 @@ def handle_error(error, do_raise: bool = True, do_log: bool = True, log_name: st try: path_obj = pathlib.Path(package_path).relative_to(settings.BASE_DIR) path_parts = [*path_obj.parts] - path_parts[-1] = path_parts[-1].replace(path_obj.suffix, '') # remove suffix + path_parts[-1] = path_parts[-1].replace( + path_obj.suffix, '' + ) # remove suffix # remove path prefixes if path_parts[0] == 'plugin': @@ -94,7 +98,11 @@ def handle_error(error, do_raise: bool = True, do_log: bool = True, log_name: st if do_raise: # do a straight raise if we are playing with environment variables at execution time, ignore the broken sample - if settings.TESTING_ENV and package_name != 'integration.broken_sample' and isinstance(error, IntegrityError): + if ( + settings.TESTING_ENV + and package_name != 'integration.broken_sample' + and isinstance(error, IntegrityError) + ): raise error # pragma: no cover raise IntegrationPluginError(package_name, str(error)) @@ -103,6 +111,8 @@ def handle_error(error, do_raise: bool = True, do_log: bool = True, log_name: st def get_entrypoints(): """Returns list for entrypoints for InvenTree plugins.""" return entry_points().get('inventree_plugins', []) + + # endregion @@ -123,7 +133,6 @@ def get_git_log(path): # only do this if we are not in test mode if not isInTestMode(): # pragma: no cover - try: repo = Repo(path) head = repo.head() @@ -133,7 +142,7 @@ def get_git_log(path): head.decode(), commit.author.decode().split('<')[0][:-1], commit.author.decode().split('<')[1][:-1], - datetime.datetime.fromtimestamp(commit.author_time, ).isoformat(), + datetime.datetime.fromtimestamp(commit.author_time).isoformat(), commit.message.decode().split('\n')[0], ] except NotGitRepository: @@ -142,7 +151,15 @@ def get_git_log(path): if not output: output = 5 * [''] # pragma: no cover - return {'hash': output[0], 'author': output[1], 'mail': output[2], 'date': output[3], 'message': output[4]} + return { + 'hash': output[0], + 'author': output[1], + 'mail': output[2], + 'date': output[3], + 'message': output[4], + } + + # endregion @@ -199,6 +216,8 @@ def get_plugins(pkg, baseclass, path=None): plugins.append(plugin) return plugins + + # endregion @@ -208,7 +227,9 @@ def render_template(plugin, template_file, context=None): try: tmp = template.loader.get_template(template_file) except template.TemplateDoesNotExist: - logger.exception("Plugin %s could not locate template '%s'", plugin.slug, template_file) + logger.exception( + "Plugin %s could not locate template '%s'", plugin.slug, template_file + ) return f"""
    @@ -228,4 +249,5 @@ def render_text(text, context=None): return template.Template(text).render(ctx) + # endregion diff --git a/InvenTree/plugin/installer.py b/InvenTree/plugin/installer.py index 885114ab83..72b7ad4aa1 100644 --- a/InvenTree/plugin/installer.py +++ b/InvenTree/plugin/installer.py @@ -25,12 +25,10 @@ def pip_command(*args): command = [str(x) for x in command] - logger.info("Running pip command: %s", ' '.join(command)) + logger.info('Running pip command: %s', ' '.join(command)) return subprocess.check_output( - command, - cwd=settings.BASE_DIR.parent, - stderr=subprocess.STDOUT, + command, cwd=settings.BASE_DIR.parent, stderr=subprocess.STDOUT ) @@ -40,7 +38,7 @@ def check_package_path(packagename: str): - If installed, return the installation path - If not installed, return False """ - logger.debug("check_package_path: %s", packagename) + logger.debug('check_package_path: %s', packagename) # Remove version information for c in '<>=! ': @@ -49,7 +47,7 @@ def check_package_path(packagename: str): try: result = pip_command('show', packagename) - output = result.decode('utf-8').split("\n") + output = result.decode('utf-8').split('\n') for line in output: # Check if line matches pattern "Location: ..." @@ -60,7 +58,7 @@ def check_package_path(packagename: str): except subprocess.CalledProcessError as error: output = error.output.decode('utf-8') - logger.exception("Plugin lookup failed: %s", str(output)) + logger.exception('Plugin lookup failed: %s', str(output)) return False # If we get here, the package is not installed @@ -69,22 +67,22 @@ def check_package_path(packagename: str): def install_plugins_file(): """Install plugins from the plugins file""" - logger.info("Installing plugins from plugins file") + logger.info('Installing plugins from plugins file') pf = settings.PLUGIN_FILE if not pf or not pf.exists(): - logger.warning("Plugin file %s does not exist", str(pf)) + logger.warning('Plugin file %s does not exist', str(pf)) return try: pip_command('install', '-r', str(pf)) except subprocess.CalledProcessError as error: output = error.output.decode('utf-8') - logger.exception("Plugin file installation failed: %s", str(output)) + logger.exception('Plugin file installation failed: %s', str(output)) return False except Exception as exc: - logger.exception("Plugin file installation failed: %s", exc) + logger.exception('Plugin file installation failed: %s', exc) return False # At this point, the plugins file has been installed @@ -93,12 +91,12 @@ def install_plugins_file(): def add_plugin_to_file(install_name): """Add a plugin to the plugins file""" - logger.info("Adding plugin to plugins file: %s", install_name) + logger.info('Adding plugin to plugins file: %s', install_name) pf = settings.PLUGIN_FILE if not pf or not pf.exists(): - logger.warning("Plugin file %s does not exist", str(pf)) + logger.warning('Plugin file %s does not exist', str(pf)) return # First, read in existing plugin file @@ -106,13 +104,13 @@ def add_plugin_to_file(install_name): with pf.open(mode='r') as f: lines = f.readlines() except Exception as exc: - logger.exception("Failed to read plugins file: %s", str(exc)) + logger.exception('Failed to read plugins file: %s', str(exc)) return # Check if plugin is already in file for line in lines: if line.strip() == install_name: - logger.debug("Plugin already exists in file") + logger.debug('Plugin already exists in file') return # Append plugin to file @@ -127,7 +125,7 @@ def add_plugin_to_file(install_name): if not line.endswith('\n'): f.write('\n') except Exception as exc: - logger.exception("Failed to add plugin to plugins file: %s", str(exc)) + logger.exception('Failed to add plugin to plugins file: %s', str(exc)) def install_plugin(url=None, packagename=None, user=None): @@ -137,16 +135,18 @@ def install_plugin(url=None, packagename=None, user=None): - We must detect that we are running within a virtual environment """ if user and not user.is_staff: - raise ValidationError(_("Permission denied: only staff users can install plugins")) + raise ValidationError( + _('Permission denied: only staff users can install plugins') + ) - logger.debug("install_plugin: %s, %s", url, packagename) + logger.debug('install_plugin: %s, %s', url, packagename) # Check if we are running in a virtual environment # For now, just log a warning in_venv = sys.prefix != sys.base_prefix if not in_venv: - logger.warning("InvenTree is not running in a virtual environment") + logger.warning('InvenTree is not running in a virtual environment') # build up the command install_name = ['install', '-U'] @@ -155,7 +155,9 @@ def install_plugin(url=None, packagename=None, user=None): if url: # use custom registration / VCS - if True in [identifier in url for identifier in ['git+https', 'hg+https', 'svn+svn', ]]: + if True in [ + identifier in url for identifier in ['git+https', 'hg+https', 'svn+svn'] + ]: # using a VCS provider if packagename: full_pkg = f'{packagename}@{url}' @@ -183,25 +185,23 @@ def install_plugin(url=None, packagename=None, user=None): try: result = pip_command(*install_name) - ret['result'] = ret['success'] = _("Installed plugin successfully") + ret['result'] = ret['success'] = _('Installed plugin successfully') ret['output'] = str(result, 'utf-8') if packagename: if path := check_package_path(packagename): # Override result information - ret['result'] = _(f"Installed plugin into {path}") + ret['result'] = _(f'Installed plugin into {path}') except subprocess.CalledProcessError as error: # If an error was thrown, we need to parse the output output = error.output.decode('utf-8') - logger.exception("Plugin installation failed: %s", str(output)) + logger.exception('Plugin installation failed: %s', str(output)) - errors = [ - _("Plugin installation failed"), - ] + errors = [_('Plugin installation failed')] - for msg in output.split("\n"): + for msg in output.split('\n'): msg = msg.strip() if msg: @@ -217,6 +217,7 @@ def install_plugin(url=None, packagename=None, user=None): # Reload the plugin registry, to discover the new plugin from plugin.registry import registry + registry.reload_plugins(full_reload=True, force_reload=True, collect=True) return ret diff --git a/InvenTree/plugin/mixins/__init__.py b/InvenTree/plugin/mixins/__init__.py index 07f6eed485..0e39462ad6 100644 --- a/InvenTree/plugin/mixins/__init__.py +++ b/InvenTree/plugin/mixins/__init__.py @@ -1,15 +1,17 @@ """Utility class to enable simpler imports.""" -from common.notifications import (BulkNotificationMethod, - SingleNotificationMethod) +from common.notifications import BulkNotificationMethod, SingleNotificationMethod from plugin.base.action.mixins import ActionMixin from plugin.base.barcodes.mixins import BarcodeMixin, SupplierBarcodeMixin from plugin.base.event.mixins import EventMixin from plugin.base.integration.APICallMixin import APICallMixin from plugin.base.integration.AppMixin import AppMixin from plugin.base.integration.CurrencyExchangeMixin import CurrencyExchangeMixin -from plugin.base.integration.mixins import (NavigationMixin, PanelMixin, - SettingsContentMixin) +from plugin.base.integration.mixins import ( + NavigationMixin, + PanelMixin, + SettingsContentMixin, +) from plugin.base.integration.ReportMixin import ReportMixin from plugin.base.integration.ScheduleMixin import ScheduleMixin from plugin.base.integration.SettingsMixin import SettingsMixin diff --git a/InvenTree/plugin/mock/simple.py b/InvenTree/plugin/mock/simple.py index 8130aa7eed..3160fed4b8 100644 --- a/InvenTree/plugin/mock/simple.py +++ b/InvenTree/plugin/mock/simple.py @@ -7,4 +7,4 @@ class SimplePlugin(InvenTreePlugin): """A very simple plugin.""" NAME = 'SimplePlugin' - SLUG = "simple" + SLUG = 'simple' diff --git a/InvenTree/plugin/models.py b/InvenTree/plugin/models.py index fef716b86a..c8df8987c9 100644 --- a/InvenTree/plugin/models.py +++ b/InvenTree/plugin/models.py @@ -25,14 +25,12 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model): class Meta: """Meta for PluginConfig.""" - verbose_name = _("Plugin Configuration") - verbose_name_plural = _("Plugin Configurations") + + verbose_name = _('Plugin Configuration') + verbose_name_plural = _('Plugin Configurations') key = models.CharField( - unique=True, - max_length=255, - verbose_name=_('Key'), - help_text=_('Key of plugin'), + unique=True, max_length=255, verbose_name=_('Key'), help_text=_('Key of plugin') ) name = models.CharField( @@ -44,9 +42,7 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model): ) active = models.BooleanField( - default=False, - verbose_name=_('Active'), - help_text=_('Is the plugin active'), + default=False, verbose_name=_('Active'), help_text=_('Is the plugin active') ) def __str__(self) -> str: @@ -61,7 +57,9 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model): """Returns all registered mixins.""" try: if inspect.isclass(self.plugin): - return self.plugin.get_registered_mixins(self, with_base=True, with_cls=False) + return self.plugin.get_registered_mixins( + self, with_base=True, with_cls=False + ) return self.plugin.get_registered_mixins(with_base=True, with_cls=False) except (AttributeError, ValueError): # pragma: no cover return {} @@ -94,9 +92,19 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model): return result self.meta = { - key: get_plugin_meta(key) for key in ['slug', 'human_name', 'description', 'author', - 'pub_date', 'version', 'website', 'license', - 'package_path', 'settings_url', ] + key: get_plugin_meta(key) + for key in [ + 'slug', + 'human_name', + 'description', + 'author', + 'pub_date', + 'version', + 'website', + 'license', + 'package_path', + 'settings_url', + ] } # Save plugin @@ -105,7 +113,9 @@ class PluginConfig(InvenTree.models.MetadataMixin, models.Model): def __getstate__(self): """Customize pickling behavior.""" state = super().__getstate__() - state.pop("plugin", None) # plugin cannot be pickled in some circumstances when used with drf views, remove it (#5408) + state.pop( + 'plugin', None + ) # plugin cannot be pickled in some circumstances when used with drf views, remove it (#5408) return state def save(self, force_insert=False, force_update=False, *args, **kwargs): @@ -160,9 +170,8 @@ class PluginSetting(common.models.BaseInvenTreeSetting): class Meta: """Meta for PluginSetting.""" - unique_together = [ - ('plugin', 'key'), - ] + + unique_together = [('plugin', 'key')] plugin = models.ForeignKey( PluginConfig, @@ -185,7 +194,6 @@ class PluginSetting(common.models.BaseInvenTreeSetting): (if the plugin is specified!) """ if 'settings' not in kwargs: - plugin = kwargs.pop('plugin', None) if plugin: @@ -204,9 +212,8 @@ class NotificationUserSetting(common.models.BaseInvenTreeSetting): class Meta: """Meta for NotificationUserSetting.""" - unique_together = [ - ('method', 'user', 'key'), - ] + + unique_together = [('method', 'user', 'key')] @classmethod def get_setting_definition(cls, key, **kwargs): @@ -217,15 +224,13 @@ class NotificationUserSetting(common.models.BaseInvenTreeSetting): return super().get_setting_definition(key, **kwargs) - method = models.CharField( - max_length=255, - verbose_name=_('Method'), - ) + method = models.CharField(max_length=255, verbose_name=_('Method')) user = models.ForeignKey( User, on_delete=models.CASCADE, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('User'), help_text=_('User'), ) diff --git a/InvenTree/plugin/plugin.py b/InvenTree/plugin/plugin.py index 1938f59a2a..54fc1ff8d0 100644 --- a/InvenTree/plugin/plugin.py +++ b/InvenTree/plugin/plugin.py @@ -16,7 +16,7 @@ from django.utils.translation import gettext_lazy as _ from plugin.helpers import get_git_log -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') class MetaBase: @@ -46,7 +46,11 @@ class MetaBase: # Sound of a warning if old_key worked if value: - warnings.warn(f'Usage of {old_key} was depreciated in 0.7.0 in favour of {key}', DeprecationWarning, stacklevel=2) + warnings.warn( + f'Usage of {old_key} was depreciated in 0.7.0 in favour of {key}', + DeprecationWarning, + stacklevel=2, + ) # Use __default if still nothing set if (value is None) and __default: @@ -96,8 +100,7 @@ class MetaBase: import plugin.models cfg, _ = plugin.models.PluginConfig.objects.get_or_create( - key=self.plugin_slug(), - name=self.plugin_name(), + key=self.plugin_slug(), name=self.plugin_name() ) except (OperationalError, ProgrammingError): cfg = None @@ -153,14 +156,14 @@ class MixinBase: def setup_mixin(self, key, cls=None): """Define mixin details for the current mixin -> provides meta details for all active mixins.""" # get human name - human_name = getattr(cls.MixinMeta, 'MIXIN_NAME', key) if cls and hasattr(cls, 'MixinMeta') else key + human_name = ( + getattr(cls.MixinMeta, 'MIXIN_NAME', key) + if cls and hasattr(cls, 'MixinMeta') + else key + ) # register - self._mixinreg[key] = { - 'key': key, - 'human_name': human_name, - 'cls': cls, - } + self._mixinreg[key] = {'key': key, 'human_name': human_name, 'cls': cls} def get_registered_mixins(self, with_base: bool = False, with_cls: bool = True): """Get all registered mixins for the plugin.""" @@ -175,7 +178,10 @@ class MixinBase: # Do not return the mixin class if flas is set if not with_cls: - return {key: {k: v for k, v in mixin.items() if k != 'cls'} for key, mixin in mixins.items()} + return { + key: {k: v for k, v in mixin.items() if k != 'cls'} + for key, mixin in mixins.items() + } return mixins @property @@ -290,6 +296,7 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase): def license(self): """License of plugin.""" return self._get_value('LICENSE', 'license') + # endregion @classmethod @@ -363,7 +370,6 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase): meta = metadata(cls.__name__) # Simple lookup did not work - get data from module except PackageNotFoundError: - try: meta = metadata(cls.__module__.split('.')[0]) except PackageNotFoundError: @@ -372,7 +378,7 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase): try: website = meta['Project-URL'].split(', ')[1] - except (ValueError, IndexError, AttributeError, ): + except (ValueError, IndexError, AttributeError): website = meta['Project-URL'] return { @@ -380,13 +386,17 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase): 'description': meta['Summary'], 'version': meta['Version'], 'website': website, - 'license': meta['License'] + 'license': meta['License'], } def define_package(self): """Add package info of the plugin into plugins context.""" try: - package = self._get_package_metadata() if self._is_package else self._get_package_commit() + package = ( + self._get_package_metadata() + if self._is_package + else self._get_package_commit() + ) except TypeError: package = {} @@ -396,4 +406,5 @@ class InvenTreePlugin(VersionMixin, MixinBase, MetaBase): # set variables self.package = package + # endregion diff --git a/InvenTree/plugin/registry.py b/InvenTree/plugin/registry.py index b6b749f168..2b3dc354b8 100644 --- a/InvenTree/plugin/registry.py +++ b/InvenTree/plugin/registry.py @@ -21,14 +21,22 @@ from django.urls import clear_url_caches, path from django.utils.text import slugify from django.utils.translation import gettext_lazy as _ -from maintenance_mode.core import (get_maintenance_mode, maintenance_mode_on, - set_maintenance_mode) +from maintenance_mode.core import ( + get_maintenance_mode, + maintenance_mode_on, + set_maintenance_mode, +) from InvenTree.config import get_plugin_dir from InvenTree.ready import canAppAccessDatabase -from .helpers import (IntegrationPluginError, get_entrypoints, get_plugins, - handle_error, log_error) +from .helpers import ( + IntegrationPluginError, + get_entrypoints, + get_plugins, + handle_error, + log_error, +) from .plugin import InvenTreePlugin logger = logging.getLogger('inventree') @@ -50,25 +58,31 @@ class PluginsRegistry: Set up all needed references for internal and external states. """ # plugin registry - self.plugins: Dict[str, InvenTreePlugin] = {} # List of active instances - self.plugins_inactive: Dict[str, InvenTreePlugin] = {} # List of inactive instances - self.plugins_full: Dict[str, InvenTreePlugin] = {} # List of all plugin instances + self.plugins: Dict[str, InvenTreePlugin] = {} # List of active instances + self.plugins_inactive: Dict[ + str, InvenTreePlugin + ] = {} # List of inactive instances + self.plugins_full: Dict[ + str, InvenTreePlugin + ] = {} # List of all plugin instances # Keep an internal hash of the plugin registry state self.registry_hash = None - self.plugin_modules: List[InvenTreePlugin] = [] # Holds all discovered plugins - self.mixin_modules: Dict[str, Any] = {} # Holds all discovered mixins + self.plugin_modules: List[InvenTreePlugin] = [] # Holds all discovered plugins + self.mixin_modules: Dict[str, Any] = {} # Holds all discovered mixins - self.errors = {} # Holds discovering errors + self.errors = {} # Holds discovering errors - self.loading_lock = Lock() # Lock to prevent multiple loading at the same time + self.loading_lock = Lock() # Lock to prevent multiple loading at the same time # flags - self.plugins_loaded = False # Marks if the registry fully loaded and all django apps are reloaded - self.apps_loading = True # Marks if apps were reloaded yet + self.plugins_loaded = ( + False # Marks if the registry fully loaded and all django apps are reloaded + ) + self.apps_loading = True # Marks if apps were reloaded yet - self.installed_apps = [] # Holds all added plugin_paths + self.installed_apps = [] # Holds all added plugin_paths @property def is_loading(self): @@ -134,7 +148,6 @@ class PluginsRegistry: for plugin in self.plugins.values(): if plugin.mixin_enabled(mixin): - if active is not None: # Filter by 'active' status of plugin if active != plugin.is_active(): @@ -148,6 +161,7 @@ class PluginsRegistry: result.append(plugin) return result + # endregion # region loading / unloading @@ -179,7 +193,11 @@ class PluginsRegistry: logger.info('Database not accessible while loading plugins') break except IntegrationPluginError as error: - logger.exception('[PLUGIN] Encountered an error with %s:\n%s', error.path, error.message) + logger.exception( + '[PLUGIN] Encountered an error with %s:\n%s', + error.path, + error.message, + ) log_error({error.path: error.message}, 'load') blocked_plugin = error.path # we will not try to load this app again @@ -196,7 +214,9 @@ class PluginsRegistry: print('[PLUGIN] Max retries, breaking loading') break if settings.PLUGIN_TESTING: - print(f'[PLUGIN] Above error occurred during testing - {retry_counter}/{settings.PLUGIN_RETRY} retries left') + print( + f'[PLUGIN] Above error occurred during testing - {retry_counter}/{settings.PLUGIN_RETRY} retries left' + ) # now the loading will re-start up with init @@ -216,6 +236,7 @@ class PluginsRegistry: # Trigger plugins_loaded event if canAppAccessDatabase(): from plugin.events import trigger_event + trigger_event('plugins_loaded') def _unload_plugins(self, force_reload: bool = False): @@ -243,7 +264,12 @@ class PluginsRegistry: logger.info('Finished unloading plugins') - def reload_plugins(self, full_reload: bool = False, force_reload: bool = False, collect: bool = False): + def reload_plugins( + self, + full_reload: bool = False, + force_reload: bool = False, + collect: bool = False, + ): """Reload the plugin registry. This should be considered the single point of entry for loading plugins! @@ -255,12 +281,16 @@ class PluginsRegistry: """ # Do not reload when currently loading if self.is_loading: - logger.debug("Skipping reload - plugin registry is currently loading") + logger.debug('Skipping reload - plugin registry is currently loading') return if self.loading_lock.acquire(blocking=False): - - logger.info('Plugin Registry: Reloading plugins - Force: %s, Full: %s, Collect: %s', force_reload, full_reload, collect) + logger.info( + 'Plugin Registry: Reloading plugins - Force: %s, Full: %s, Collect: %s', + force_reload, + full_reload, + collect, + ) with maintenance_mode_on(): if collect: @@ -280,7 +310,7 @@ class PluginsRegistry: def plugin_dirs(self): """Construct a list of directories from where plugins can be loaded""" # Builtin plugins are *always* loaded - dirs = ['plugin.builtin', ] + dirs = ['plugin.builtin'] if settings.PLUGINS_ENABLED: # Any 'external' plugins are only loaded if PLUGINS_ENABLED is set to True @@ -307,7 +337,9 @@ class PluginsRegistry: try: pd.mkdir(exist_ok=True) except Exception: # pragma: no cover - logger.exception("Could not create plugin directory '%s'", pd) + logger.exception( + "Could not create plugin directory '%s'", pd + ) continue # Ensure the directory has an __init__.py file @@ -315,9 +347,11 @@ class PluginsRegistry: if not init_filename.exists(): try: - init_filename.write_text("# InvenTree plugin directory\n") + init_filename.write_text('# InvenTree plugin directory\n') except Exception: # pragma: no cover - logger.exception("Could not create file '%s'", init_filename) + logger.exception( + "Could not create file '%s'", init_filename + ) continue # By this point, we have confirmed that the directory at least exists @@ -340,7 +374,6 @@ class PluginsRegistry: # Collect plugins from paths for plugin in self.plugin_dirs(): - logger.debug("Loading plugins from directory '%s'", plugin) parent_path = None @@ -354,7 +387,9 @@ class PluginsRegistry: # Gather Modules if parent_path: - raw_module = imp.load_source(plugin, str(parent_obj.joinpath('__init__.py'))) + raw_module = imp.load_source( + plugin, str(parent_obj.joinpath('__init__.py')) + ) else: raw_module = importlib.import_module(plugin) modules = get_plugins(raw_module, InvenTreePlugin, path=parent_path) @@ -364,9 +399,10 @@ class PluginsRegistry: # From this point any plugins are considered "external" and only loaded if plugins are explicitly enabled if settings.PLUGINS_ENABLED: - # Check if not running in testing mode and apps should be loaded from hooks - if (not settings.PLUGIN_TESTING) or (settings.PLUGIN_TESTING and settings.PLUGIN_TESTING_SETUP): + if (not settings.PLUGIN_TESTING) or ( + settings.PLUGIN_TESTING and settings.PLUGIN_TESTING_SETUP + ): # Collect plugins from setup entry points for entry in get_entrypoints(): try: @@ -379,7 +415,7 @@ class PluginsRegistry: # Log collected plugins logger.info('Collected %s plugins', len(collected_plugins)) - logger.debug(", ".join([a.__module__ for a in collected_plugins])) + logger.debug(', '.join([a.__module__ for a in collected_plugins])) return collected_plugins @@ -436,15 +472,15 @@ class PluginsRegistry: logger.debug('Starting plugin initialization') # Fetch and cache list of existing plugin configuration instances - plugin_configs = { - cfg.key: cfg for cfg in PluginConfig.objects.all() - } + plugin_configs = {cfg.key: cfg for cfg in PluginConfig.objects.all()} # Initialize plugins for plg in self.plugin_modules: # These checks only use attributes - never use plugin supplied functions -> that would lead to arbitrary code execution!! plg_name = plg.NAME - plg_key = slugify(plg.SLUG if getattr(plg, 'SLUG', None) else plg_name) # keys are slugs! + plg_key = slugify( + plg.SLUG if getattr(plg, 'SLUG', None) else plg_name + ) # keys are slugs! try: if plg_key in plugin_configs: @@ -452,14 +488,16 @@ class PluginsRegistry: plg_db = plugin_configs[plg_key] else: # Configuration needs to be created - plg_db, _created = PluginConfig.objects.get_or_create(key=plg_key, name=plg_name) + plg_db, _created = PluginConfig.objects.get_or_create( + key=plg_key, name=plg_name + ) except (OperationalError, ProgrammingError) as error: # Exception if the database has not been migrated yet - check if test are running - raise if not if not settings.PLUGIN_TESTING: raise error # pragma: no cover plg_db = None - except (IntegrityError) as error: # pragma: no cover - logger.exception("Error initializing plugin `%s`: %s", plg_name, error) + except IntegrityError as error: # pragma: no cover + logger.exception('Error initializing plugin `%s`: %s', plg_name, error) handle_error(error, log_name='init') # Append reference to plugin @@ -492,8 +530,10 @@ class PluginsRegistry: dt = time.time() - t_start logger.debug('Loaded plugin `%s` in %.3fs', plg_name, dt) except Exception as error: - handle_error(error, log_name='init') # log error and raise it -> disable plugin - logger.warning("Plugin `%s` could not be loaded", plg_name) + handle_error( + error, log_name='init' + ) # log error and raise it -> disable plugin + logger.warning('Plugin `%s` could not be loaded', plg_name) # Safe extra attributes plg_i.is_package = getattr(plg_i, 'is_package', False) @@ -501,13 +541,17 @@ class PluginsRegistry: plg_i.db = plg_db # Run version check for plugin - if (plg_i.MIN_VERSION or plg_i.MAX_VERSION) and not plg_i.check_version(): + if ( + plg_i.MIN_VERSION or plg_i.MAX_VERSION + ) and not plg_i.check_version(): # Disable plugin safe_reference(plugin=plg_i, key=plg_key, active=False) p = plg_name v = version.inventreeVersion() - _msg = _(f"Plugin '{p}' is not compatible with the current InvenTree version {v}") + _msg = _( + f"Plugin '{p}' is not compatible with the current InvenTree version {v}" + ) if v := plg_i.MIN_VERSION: _msg += _(f'Plugin requires at least version {v}') if v := plg_i.MAX_VERSION: @@ -525,7 +569,11 @@ class PluginsRegistry: order = self.DEFAULT_MIXIN_ORDER # Append mixins that are not defined in the default list - order += [m.get('cls') for m in self.mixin_modules.values() if m.get('cls') not in order] + order += [ + m.get('cls') + for m in self.mixin_modules.values() + if m.get('cls') not in order + ] # Final list of mixins return order @@ -546,7 +594,9 @@ class PluginsRegistry: for mixin in self.__get_mixin_order(): if hasattr(mixin, '_activate_mixin'): - mixin._activate_mixin(self, plugins, force_reload=force_reload, full_reload=full_reload) + mixin._activate_mixin( + self, plugins, force_reload=force_reload, full_reload=full_reload + ) logger.debug('Done activating') @@ -561,6 +611,7 @@ class PluginsRegistry: mixin._deactivate_mixin(self, force_reload=force_reload) logger.debug('Finished deactivating plugins') + # endregion # region mixin specific loading ... @@ -617,13 +668,14 @@ class PluginsRegistry: from plugin.urls import get_plugin_urls for index, url in enumerate(urlpatterns): - app_name = getattr(url, 'app_name', None) admin_url = settings.INVENTREE_ADMIN_URL if app_name == 'admin': - urlpatterns[index] = path(f'{admin_url}/', admin.site.urls, name='inventree-admin') + urlpatterns[index] = path( + f'{admin_url}/', admin.site.urls, name='inventree-admin' + ) if app_name == 'plugin': urlpatterns[index] = get_plugin_urls() @@ -641,20 +693,26 @@ class PluginsRegistry: self.registry_hash = self.calculate_plugin_hash() try: - old_hash = InvenTreeSetting.get_setting("_PLUGIN_REGISTRY_HASH", "", create=False, cache=False) + old_hash = InvenTreeSetting.get_setting( + '_PLUGIN_REGISTRY_HASH', '', create=False, cache=False + ) except Exception: - old_hash = "" + old_hash = '' if old_hash != self.registry_hash: try: - logger.debug("Updating plugin registry hash: %s", str(self.registry_hash)) - InvenTreeSetting.set_setting("_PLUGIN_REGISTRY_HASH", self.registry_hash, change_user=None) + logger.debug( + 'Updating plugin registry hash: %s', str(self.registry_hash) + ) + InvenTreeSetting.set_setting( + '_PLUGIN_REGISTRY_HASH', self.registry_hash, change_user=None + ) except (OperationalError, ProgrammingError): # Exception if the database has not been migrated yet, or is not ready pass except Exception as exc: # Some other exception, we want to know about it - logger.exception("Failed to update plugin registry hash: %s", str(exc)) + logger.exception('Failed to update plugin registry hash: %s', str(exc)) def calculate_plugin_hash(self): """Calculate a 'hash' value for the current registry @@ -680,12 +738,18 @@ class PluginsRegistry: 'ENABLE_PLUGINS_NAVIGATION', 'ENABLE_PLUGINS_APP', 'ENABLE_PLUGINS_SCHEDULE', - 'ENABLE_PLUGINS_EVENTS' + 'ENABLE_PLUGINS_EVENTS', ] for k in keys: try: - data.update(str(InvenTreeSetting.get_setting(k, False, cache=False, create=False)).encode()) + data.update( + str( + InvenTreeSetting.get_setting( + k, False, cache=False, create=False + ) + ).encode() + ) except Exception: pass @@ -703,20 +767,22 @@ class PluginsRegistry: # Skip check if database cannot be accessed return - logger.debug("Checking plugin registry hash") + logger.debug('Checking plugin registry hash') # If not already cached, calculate the hash if not self.registry_hash: self.registry_hash = self.calculate_plugin_hash() try: - reg_hash = InvenTreeSetting.get_setting("_PLUGIN_REGISTRY_HASH", "", create=False, cache=False) + reg_hash = InvenTreeSetting.get_setting( + '_PLUGIN_REGISTRY_HASH', '', create=False, cache=False + ) except Exception as exc: - logger.exception("Failed to retrieve plugin registry hash: %s", str(exc)) + logger.exception('Failed to retrieve plugin registry hash: %s', str(exc)) return if reg_hash and reg_hash != self.registry_hash: - logger.info("Plugin registry hash has changed - reloading") + logger.info('Plugin registry hash has changed - reloading') self.reload_plugins(full_reload=True, force_reload=True, collect=True) # endregion diff --git a/InvenTree/plugin/samples/event/event_sample.py b/InvenTree/plugin/samples/event/event_sample.py index 2ae21dcbb5..1b2bf59615 100644 --- a/InvenTree/plugin/samples/event/event_sample.py +++ b/InvenTree/plugin/samples/event/event_sample.py @@ -13,15 +13,15 @@ logger = logging.getLogger('inventree') class EventPluginSample(EventMixin, InvenTreePlugin): """A sample plugin which provides supports for triggered events.""" - NAME = "EventPlugin" - SLUG = "sampleevent" - TITLE = "Triggered Events" + NAME = 'EventPlugin' + SLUG = 'sampleevent' + TITLE = 'Triggered Events' def process_event(self, event, *args, **kwargs): """Custom event processing.""" print(f"Processing triggered event: '{event}'") - print("args:", str(args)) - print("kwargs:", str(kwargs)) + print('args:', str(args)) + print('kwargs:', str(kwargs)) # Issue warning that we can test for if settings.PLUGIN_TESTING: diff --git a/InvenTree/plugin/samples/event/filtered_event_sample.py b/InvenTree/plugin/samples/event/filtered_event_sample.py index 4267f3be1b..7380d0622a 100644 --- a/InvenTree/plugin/samples/event/filtered_event_sample.py +++ b/InvenTree/plugin/samples/event/filtered_event_sample.py @@ -13,19 +13,19 @@ logger = logging.getLogger('inventree') class FilteredEventPluginSample(EventMixin, InvenTreePlugin): """A sample plugin which provides supports for triggered events.""" - NAME = "FilteredEventPlugin" - SLUG = "filteredsampleevent" - TITLE = "Triggered by test.event only" + NAME = 'FilteredEventPlugin' + SLUG = 'filteredsampleevent' + TITLE = 'Triggered by test.event only' def wants_process_event(self, event): """Return whether given event should be processed or not.""" - return event == "test.event" + return event == 'test.event' def process_event(self, event, *args, **kwargs): """Custom event processing.""" print(f"Processing triggered event: '{event}'") - print("args:", str(args)) - print("kwargs:", str(kwargs)) + print('args:', str(args)) + print('kwargs:', str(kwargs)) # Issue warning that we can test for if settings.PLUGIN_TESTING: diff --git a/InvenTree/plugin/samples/event/test_event_sample.py b/InvenTree/plugin/samples/event/test_event_sample.py index e466aaa0eb..0e5eb9b86d 100644 --- a/InvenTree/plugin/samples/event/test_event_sample.py +++ b/InvenTree/plugin/samples/event/test_event_sample.py @@ -27,9 +27,11 @@ class EventPluginSampleTests(TestCase): # Enable event testing settings.PLUGIN_TESTING_EVENTS = True # Check that an event is issued - with self.assertLogs(logger=logger, level="DEBUG") as cm: + with self.assertLogs(logger=logger, level='DEBUG') as cm: trigger_event('test.event') - self.assertIn('DEBUG:inventree:Event `test.event` triggered in sample plugin', cm[1]) + self.assertIn( + 'DEBUG:inventree:Event `test.event` triggered in sample plugin', cm[1] + ) # Disable again settings.PLUGIN_TESTING_EVENTS = False @@ -37,6 +39,7 @@ class EventPluginSampleTests(TestCase): def test_mixin(self): """Test that MixinNotImplementedError is raised.""" with self.assertRaises(MixinNotImplementedError): + class Wrong(EventMixin, InvenTreePlugin): pass diff --git a/InvenTree/plugin/samples/event/test_filtered_event_sample.py b/InvenTree/plugin/samples/event/test_filtered_event_sample.py index efcb5c0290..1da5a3b499 100644 --- a/InvenTree/plugin/samples/event/test_filtered_event_sample.py +++ b/InvenTree/plugin/samples/event/test_filtered_event_sample.py @@ -25,9 +25,11 @@ class FilteredEventPluginSampleTests(TestCase): # Enable event testing settings.PLUGIN_TESTING_EVENTS = True # Check that an event is issued - with self.assertLogs(logger=logger, level="DEBUG") as cm: + with self.assertLogs(logger=logger, level='DEBUG') as cm: trigger_event('test.event') - self.assertIn('DEBUG:inventree:Event `test.event` triggered in sample plugin', cm[1]) + self.assertIn( + 'DEBUG:inventree:Event `test.event` triggered in sample plugin', cm[1] + ) # Disable again settings.PLUGIN_TESTING_EVENTS = False @@ -44,9 +46,12 @@ class FilteredEventPluginSampleTests(TestCase): # Enable event testing settings.PLUGIN_TESTING_EVENTS = True # Check that an event is issued - with self.assertLogs(logger=logger, level="DEBUG") as cm: + with self.assertLogs(logger=logger, level='DEBUG') as cm: trigger_event('test.some.other.event') - self.assertNotIn('DEBUG:inventree:Event `test.some.other.event` triggered in sample plugin', cm[1]) + self.assertNotIn( + 'DEBUG:inventree:Event `test.some.other.event` triggered in sample plugin', + cm[1], + ) # Disable again settings.PLUGIN_TESTING_EVENTS = False diff --git a/InvenTree/plugin/samples/integration/another_sample.py b/InvenTree/plugin/samples/integration/another_sample.py index 14f92bf583..95cff15d2f 100644 --- a/InvenTree/plugin/samples/integration/another_sample.py +++ b/InvenTree/plugin/samples/integration/another_sample.py @@ -1,4 +1,5 @@ """Sample implementation for IntegrationPlugin.""" + from plugin import InvenTreePlugin from plugin.mixins import UrlsMixin @@ -6,10 +7,10 @@ from plugin.mixins import UrlsMixin class NoIntegrationPlugin(InvenTreePlugin): """A basic plugin.""" - NAME = "NoIntegrationPlugin" + NAME = 'NoIntegrationPlugin' class WrongIntegrationPlugin(UrlsMixin, InvenTreePlugin): """A basic wrong plugin with urls.""" - NAME = "WrongIntegrationPlugin" + NAME = 'WrongIntegrationPlugin' diff --git a/InvenTree/plugin/samples/integration/api_caller.py b/InvenTree/plugin/samples/integration/api_caller.py index 0968780e8b..ea78f175c5 100644 --- a/InvenTree/plugin/samples/integration/api_caller.py +++ b/InvenTree/plugin/samples/integration/api_caller.py @@ -1,4 +1,5 @@ """Sample plugin for calling an external API.""" + from plugin import InvenTreePlugin from plugin.mixins import APICallMixin, SettingsMixin @@ -6,13 +7,10 @@ from plugin.mixins import APICallMixin, SettingsMixin class SampleApiCallerPlugin(APICallMixin, SettingsMixin, InvenTreePlugin): """A small api call sample.""" - NAME = "Sample API Caller" + NAME = 'Sample API Caller' SETTINGS = { - 'API_TOKEN': { - 'name': 'API Token', - 'protected': True, - }, + 'API_TOKEN': {'name': 'API Token', 'protected': True}, 'API_URL': { 'name': 'External URL', 'description': 'Where is your API located?', diff --git a/InvenTree/plugin/samples/integration/custom_panel_sample.py b/InvenTree/plugin/samples/integration/custom_panel_sample.py index 0ef2086029..9f6ccabb8c 100644 --- a/InvenTree/plugin/samples/integration/custom_panel_sample.py +++ b/InvenTree/plugin/samples/integration/custom_panel_sample.py @@ -9,11 +9,11 @@ from stock.views import StockLocationDetail class CustomPanelSample(PanelMixin, SettingsMixin, InvenTreePlugin): """A sample plugin which renders some custom panels.""" - NAME = "CustomPanelExample" - SLUG = "samplepanel" - TITLE = "Custom Panel Example" - DESCRIPTION = "An example plugin demonstrating how custom panels can be added to the user interface" - VERSION = "0.1" + NAME = 'CustomPanelExample' + SLUG = 'samplepanel' + TITLE = 'Custom Panel Example' + DESCRIPTION = 'An example plugin demonstrating how custom panels can be added to the user interface' + VERSION = '0.1' SETTINGS = { 'ENABLE_HELLO_WORLD': { @@ -27,7 +27,7 @@ class CustomPanelSample(PanelMixin, SettingsMixin, InvenTreePlugin): 'description': 'Enable a panel with rendering issues', 'default': False, 'validator': bool, - } + }, } def get_panel_context(self, view, request, context): @@ -51,12 +51,11 @@ class CustomPanelSample(PanelMixin, SettingsMixin, InvenTreePlugin): panels = [ { # Simple panel without any actual content - 'title': 'No Content', + 'title': 'No Content' } ] if self.get_setting('ENABLE_HELLO_WORLD'): - # We can use template rendering in the raw content content = """ Hello world! @@ -81,7 +80,6 @@ class CustomPanelSample(PanelMixin, SettingsMixin, InvenTreePlugin): }) if self.get_setting('ENABLE_BROKEN_PANEL'): - # Enabling this panel will cause panel rendering to break, # due to the invalid tags panels.append({ @@ -103,7 +101,6 @@ class CustomPanelSample(PanelMixin, SettingsMixin, InvenTreePlugin): # This panel will *only* display on the StockLocation view, # and *only* if the StockLocation has *no* child locations if isinstance(view, StockLocationDetail): - try: loc = view.get_object() diff --git a/InvenTree/plugin/samples/integration/label_sample.py b/InvenTree/plugin/samples/integration/label_sample.py index e35f6ee6f1..5ac88f532b 100644 --- a/InvenTree/plugin/samples/integration/label_sample.py +++ b/InvenTree/plugin/samples/integration/label_sample.py @@ -12,15 +12,16 @@ from plugin.mixins import LabelPrintingMixin class SampleLabelPrinter(LabelPrintingMixin, InvenTreePlugin): """Sample plugin which provides a 'fake' label printer endpoint.""" - NAME = "Sample Label Printer" - SLUG = "samplelabelprinter" - TITLE = "Sample Label Printer" - DESCRIPTION = "A sample plugin which provides a (fake) label printer interface" - AUTHOR = "InvenTree contributors" - VERSION = "0.3.0" + NAME = 'Sample Label Printer' + SLUG = 'samplelabelprinter' + TITLE = 'Sample Label Printer' + DESCRIPTION = 'A sample plugin which provides a (fake) label printer interface' + AUTHOR = 'InvenTree contributors' + VERSION = '0.3.0' class PrintingOptionsSerializer(serializers.Serializer): """Serializer to return printing options.""" + amount = serializers.IntegerField(required=False, default=1) def print_label(self, **kwargs): diff --git a/InvenTree/plugin/samples/integration/report_plugin_sample.py b/InvenTree/plugin/samples/integration/report_plugin_sample.py index 5b143ab307..e14cadbfda 100644 --- a/InvenTree/plugin/samples/integration/report_plugin_sample.py +++ b/InvenTree/plugin/samples/integration/report_plugin_sample.py @@ -10,11 +10,11 @@ from report.models import PurchaseOrderReport class SampleReportPlugin(ReportMixin, InvenTreePlugin): """Sample plugin which provides extra context data to a report""" - NAME = "Sample Report Plugin" - SLUG = "samplereport" - TITLE = "Sample Report Plugin" - DESCRIPTION = "A sample plugin which provides extra context data to a report" - VERSION = "1.0" + NAME = 'Sample Report Plugin' + SLUG = 'samplereport' + TITLE = 'Sample Report Plugin' + DESCRIPTION = 'A sample plugin which provides extra context data to a report' + VERSION = '1.0' def some_custom_function(self): """Some custom function which is not required for the plugin to function""" @@ -24,7 +24,9 @@ class SampleReportPlugin(ReportMixin, InvenTreePlugin): """Add example content to the report instance""" # We can add any extra context data we want to the report # Generate a random string of data - context['random_text'] = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz', k=20)) + context['random_text'] = ''.join( + random.choices('abcdefghijklmnopqrstuvwxyz', k=20) + ) # Call a custom method context['random_int'] = self.some_custom_function() diff --git a/InvenTree/plugin/samples/integration/sample.py b/InvenTree/plugin/samples/integration/sample.py index 3f58261ec3..894d5f51e3 100644 --- a/InvenTree/plugin/samples/integration/sample.py +++ b/InvenTree/plugin/samples/integration/sample.py @@ -19,14 +19,16 @@ def validate_json(value): raise ValidationError(str(e)) -class SampleIntegrationPlugin(AppMixin, SettingsMixin, UrlsMixin, NavigationMixin, InvenTreePlugin): +class SampleIntegrationPlugin( + AppMixin, SettingsMixin, UrlsMixin, NavigationMixin, InvenTreePlugin +): """A full plugin example.""" - NAME = "SampleIntegrationPlugin" - SLUG = "sample" - TITLE = "Sample Plugin" + NAME = 'SampleIntegrationPlugin' + SLUG = 'sample' + TITLE = 'Sample Plugin' - NAVIGATION_TAB_NAME = "Sample Nav" + NAVIGATION_TAB_NAME = 'Sample Nav' NAVIGATION_TAB_ICON = 'fas fa-plus' def view_test(self, request): @@ -64,14 +66,9 @@ class SampleIntegrationPlugin(AppMixin, SettingsMixin, UrlsMixin, NavigationMixi 'default': 123, }, 'CHOICE_SETTING': { - 'name': _("Choice Setting"), + 'name': _('Choice Setting'), 'description': _('A setting with multiple choices'), - 'choices': [ - ('A', 'Anaconda'), - ('B', 'Bat'), - ('C', 'Cat'), - ('D', 'Dog'), - ], + 'choices': [('A', 'Anaconda'), ('B', 'Bat'), ('C', 'Cat'), ('D', 'Dog')], 'default': 'A', }, 'SELECT_COMPANY': { @@ -94,9 +91,7 @@ class SampleIntegrationPlugin(AppMixin, SettingsMixin, UrlsMixin, NavigationMixi 'name': 'JSON validator Setting', 'description': 'A setting using a JSON validator', 'validator': validate_json, - } + }, } - NAVIGATION = [ - {'name': 'SampleIntegration', 'link': 'plugin:sample:hi'}, - ] + NAVIGATION = [{'name': 'SampleIntegration', 'link': 'plugin:sample:hi'}] diff --git a/InvenTree/plugin/samples/integration/sample_currency_exchange.py b/InvenTree/plugin/samples/integration/sample_currency_exchange.py index 7a490543da..a60a2eb11c 100644 --- a/InvenTree/plugin/samples/integration/sample_currency_exchange.py +++ b/InvenTree/plugin/samples/integration/sample_currency_exchange.py @@ -11,17 +11,15 @@ from plugin.mixins import CurrencyExchangeMixin class SampleCurrencyExchangePlugin(CurrencyExchangeMixin, InvenTreePlugin): """Dummy currency exchange plugin which provides fake exchange rates""" - NAME = "Sample Exchange" - DESCRIPTION = _("Sample currency exchange plugin") - SLUG = "samplecurrencyexchange" - VERSION = "0.1.0" - AUTHOR = _("InvenTree Contributors") + NAME = 'Sample Exchange' + DESCRIPTION = _('Sample currency exchange plugin') + SLUG = 'samplecurrencyexchange' + VERSION = '0.1.0' + AUTHOR = _('InvenTree Contributors') def update_exchange_rates(self, base_currency: str, symbols: list[str]) -> dict: """Return dummy data for some currencies""" - rates = { - base_currency: 1.00, - } + rates = {base_currency: 1.00} for symbol in symbols: rates[symbol] = random.randrange(5, 15) * 0.1 diff --git a/InvenTree/plugin/samples/integration/scheduled_task.py b/InvenTree/plugin/samples/integration/scheduled_task.py index 50c7fd8d97..3e8a8f574a 100644 --- a/InvenTree/plugin/samples/integration/scheduled_task.py +++ b/InvenTree/plugin/samples/integration/scheduled_task.py @@ -10,7 +10,7 @@ def print_hello(): Contents do not matter - therefore no coverage. """ - print("Hello") # pragma: no cover + print('Hello') # pragma: no cover def print_world(): @@ -18,23 +18,19 @@ def print_world(): Contents do not matter - therefore no coverage. """ - print("World") # pragma: no cover + print('World') # pragma: no cover class ScheduledTaskPlugin(ScheduleMixin, SettingsMixin, InvenTreePlugin): """A sample plugin which provides support for scheduled tasks.""" - NAME = "ScheduledTasksPlugin" - SLUG = "schedule" - TITLE = "Scheduled Tasks" - VERSION = "0.2.0" + NAME = 'ScheduledTasksPlugin' + SLUG = 'schedule' + TITLE = 'Scheduled Tasks' + VERSION = '0.2.0' SCHEDULED_TASKS = { - 'member': { - 'func': 'member_func', - 'schedule': 'I', - 'minutes': 30, - }, + 'member': {'func': 'member_func', 'schedule': 'I', 'minutes': 30}, 'hello': { 'func': 'plugin.samples.integration.scheduled_task.print_hello', 'schedule': 'I', @@ -52,12 +48,12 @@ class ScheduledTaskPlugin(ScheduleMixin, SettingsMixin, InvenTreePlugin): 'description': 'Print true or false when running the periodic task', 'validator': bool, 'default': False, - }, + } } def member_func(self, *args, **kwargs): """A simple member function to demonstrate functionality.""" t_or_f = self.get_setting('T_OR_F') - print(f"Called member_func - value is {t_or_f}") + print(f'Called member_func - value is {t_or_f}') return t_or_f diff --git a/InvenTree/plugin/samples/integration/simpleactionplugin.py b/InvenTree/plugin/samples/integration/simpleactionplugin.py index af4b404764..bf892ba577 100644 --- a/InvenTree/plugin/samples/integration/simpleactionplugin.py +++ b/InvenTree/plugin/samples/integration/simpleactionplugin.py @@ -7,19 +7,16 @@ from plugin.mixins import ActionMixin class SimpleActionPlugin(ActionMixin, InvenTreePlugin): """An EXTREMELY simple action plugin which demonstrates the capability of the ActionMixin class.""" - NAME = "SimpleActionPlugin" - ACTION_NAME = "simple" + NAME = 'SimpleActionPlugin' + ACTION_NAME = 'simple' def perform_action(self, user=None, data=None): """Sample method.""" - print("Action plugin in action!") + print('Action plugin in action!') def get_info(self, user, data=None): """Sample method.""" - return { - "user": user.username, - "hello": "world", - } + return {'user': user.username, 'hello': 'world'} def get_result(self, user=None, data=None): """Sample method.""" diff --git a/InvenTree/plugin/samples/integration/test_api_caller.py b/InvenTree/plugin/samples/integration/test_api_caller.py index 28b51835c8..ff07af0bca 100644 --- a/InvenTree/plugin/samples/integration/test_api_caller.py +++ b/InvenTree/plugin/samples/integration/test_api_caller.py @@ -18,4 +18,4 @@ class SampleApiCallerPluginTests(TestCase): # do an api call result = plg.get_external_url() self.assertTrue(result) - self.assertIn('data', result,) + self.assertIn('data', result) diff --git a/InvenTree/plugin/samples/integration/test_sample.py b/InvenTree/plugin/samples/integration/test_sample.py index 66a6106f76..0d7ac58893 100644 --- a/InvenTree/plugin/samples/integration/test_sample.py +++ b/InvenTree/plugin/samples/integration/test_sample.py @@ -46,10 +46,11 @@ class SampleIntegrationPluginTests(InvenTreeTestCase): # check settings self.assertEqual(plugin.check_settings(), (False, ['API_KEY'])) - plugin.set_setting('API_KEY', "dsfiodsfjsfdjsf") + plugin.set_setting('API_KEY', 'dsfiodsfjsfdjsf') self.assertEqual(plugin.check_settings(), (True, [])) # validator + def test_settings_validator(self): """Test settings validator for plugins.""" plugin = registry.get_plugin('sample') diff --git a/InvenTree/plugin/samples/integration/test_scheduled_task.py b/InvenTree/plugin/samples/integration/test_scheduled_task.py index 78a05e52e9..11521a8252 100644 --- a/InvenTree/plugin/samples/integration/test_scheduled_task.py +++ b/InvenTree/plugin/samples/integration/test_scheduled_task.py @@ -22,13 +22,21 @@ class ExampleScheduledTaskPluginTests(TestCase): self.assertEqual(plg.member_func(), False) # check that the tasks are defined - self.assertEqual(plg.get_task_names(), ['plugin.schedule.member', 'plugin.schedule.hello', 'plugin.schedule.world']) + self.assertEqual( + plg.get_task_names(), + [ + 'plugin.schedule.member', + 'plugin.schedule.hello', + 'plugin.schedule.world', + ], + ) # register plg.register_tasks() # check that schedule was registers from django_q.models import Schedule - scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith="plugin.") + + scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith='plugin.') self.assertEqual(len(scheduled_plugin_tasks), 3) # test updating the schedule @@ -40,7 +48,7 @@ class ExampleScheduledTaskPluginTests(TestCase): # Check that the schedule was updated hello_schedule = Schedule.objects.get(name='plugin.schedule.hello') - scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith="plugin.") + scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith='plugin.') self.assertEqual(hello_schedule.minutes, 15) self.assertEqual(len(scheduled_plugin_tasks), 3) @@ -48,12 +56,12 @@ class ExampleScheduledTaskPluginTests(TestCase): # this is to check the system also deals with disappearing tasks scheduled_plugin_tasks[1].delete() # there should be one less now - scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith="plugin.") + scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith='plugin.') self.assertEqual(len(scheduled_plugin_tasks), 2) # test unregistering plg.unregister_tasks() - scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith="plugin.") + scheduled_plugin_tasks = Schedule.objects.filter(name__istartswith='plugin.') self.assertEqual(len(scheduled_plugin_tasks), 0) def test_calling(self): @@ -70,11 +78,13 @@ class ScheduledTaskPluginTests(TestCase): def test_init(self): """Check that all MixinImplementationErrors raise.""" + class Base(ScheduleMixin, InvenTreePlugin): NAME = 'APlugin' class NoSchedules(Base): """Plugin without schedules.""" + pass with self.assertRaises(MixinImplementationError): @@ -86,12 +96,7 @@ class ScheduledTaskPluginTests(TestCase): This plugin is missing a func """ - SCHEDULED_TASKS = { - 'test': { - 'schedule': 'I', - 'minutes': 30, - }, - } + SCHEDULED_TASKS = {'test': {'schedule': 'I', 'minutes': 30}} def test(self): pass # pragma: no cover @@ -105,12 +110,7 @@ class ScheduledTaskPluginTests(TestCase): This plugin is missing a schedule """ - SCHEDULED_TASKS = { - 'test': { - 'func': 'test', - 'minutes': 30, - }, - } + SCHEDULED_TASKS = {'test': {'func': 'test', 'minutes': 30}} with self.assertRaises(MixinImplementationError): WrongFuncSchedules1() @@ -121,12 +121,7 @@ class ScheduledTaskPluginTests(TestCase): This plugin is missing a schedule """ - SCHEDULED_TASKS = { - 'test': { - 'func': 'test', - 'minutes': 30, - }, - } + SCHEDULED_TASKS = {'test': {'func': 'test', 'minutes': 30}} with self.assertRaises(MixinImplementationError): WrongFuncSchedules2() @@ -138,11 +133,7 @@ class ScheduledTaskPluginTests(TestCase): """ SCHEDULED_TASKS = { - 'test': { - 'func': 'test', - 'schedule': 'XX', - 'minutes': 30, - }, + 'test': {'func': 'test', 'schedule': 'XX', 'minutes': 30} } with self.assertRaises(MixinImplementationError): @@ -154,12 +145,7 @@ class ScheduledTaskPluginTests(TestCase): This plugin is missing a minute marker for its schedule """ - SCHEDULED_TASKS = { - 'test': { - 'func': 'test', - 'schedule': 'I', - }, - } + SCHEDULED_TASKS = {'test': {'func': 'test', 'schedule': 'I'}} with self.assertRaises(MixinImplementationError): WrongFuncSchedules4() diff --git a/InvenTree/plugin/samples/integration/test_simpleactionplugin.py b/InvenTree/plugin/samples/integration/test_simpleactionplugin.py index 2599984a66..d1542c4c2d 100644 --- a/InvenTree/plugin/samples/integration/test_simpleactionplugin.py +++ b/InvenTree/plugin/samples/integration/test_simpleactionplugin.py @@ -15,22 +15,21 @@ class SimpleActionPluginTests(InvenTreeTestCase): def test_name(self): """Check plugn names.""" - self.assertEqual(self.plugin.plugin_name(), "SimpleActionPlugin") - self.assertEqual(self.plugin.action_name(), "simple") + self.assertEqual(self.plugin.plugin_name(), 'SimpleActionPlugin') + self.assertEqual(self.plugin.action_name(), 'simple') def test_function(self): """Check if functions work.""" # test functions - response = self.client.post('/api/action/', data={'action': "simple", 'data': {'foo': "bar", }}) + response = self.client.post( + '/api/action/', data={'action': 'simple', 'data': {'foo': 'bar'}} + ) self.assertEqual(response.status_code, 200) self.assertJSONEqual( str(response.content, encoding='utf8'), { - "action": 'simple', - "result": True, - "info": { - "user": self.username, - "hello": "world", - }, - } + 'action': 'simple', + 'result': True, + 'info': {'user': self.username, 'hello': 'world'}, + }, ) diff --git a/InvenTree/plugin/samples/integration/transition.py b/InvenTree/plugin/samples/integration/transition.py index 4de3b140a9..97166506bd 100644 --- a/InvenTree/plugin/samples/integration/transition.py +++ b/InvenTree/plugin/samples/integration/transition.py @@ -10,7 +10,7 @@ from plugin import InvenTreePlugin class SampleTransitionPlugin(InvenTreePlugin): """A sample plugin which shows how state transitions might be implemented.""" - NAME = "SampleTransitionPlugin" + NAME = 'SampleTransitionPlugin' class ReturnChangeHandler(TransitionMethod): """Transition method for PurchaseOrder objects.""" @@ -18,7 +18,9 @@ class SampleTransitionPlugin(InvenTreePlugin): def transition(current_state, target_state, instance, default_action, **kwargs): # noqa: N805 """Example override function for state transition.""" # Only act on ReturnOrders that should be completed - if not isinstance(instance, ReturnOrder) or not (target_state == ReturnOrderStatus.COMPLETE.value): + if not isinstance(instance, ReturnOrder) or not ( + target_state == ReturnOrderStatus.COMPLETE.value + ): return False # Only allow proceeding if the return order has a responsible user assigned @@ -28,8 +30,10 @@ class SampleTransitionPlugin(InvenTreePlugin): trigger_notification( instance, 'sampel_123_456', - targets=[instance.created_by, ], - context={'message': "Return order without responsible owner can not be completed!"}, + targets=[instance.created_by], + context={ + 'message': 'Return order without responsible owner can not be completed!' + }, ) return True # True means nothing will happen return False # Do not act diff --git a/InvenTree/plugin/samples/integration/validation_sample.py b/InvenTree/plugin/samples/integration/validation_sample.py index 63e8d0eb9b..585ab59ed6 100644 --- a/InvenTree/plugin/samples/integration/validation_sample.py +++ b/InvenTree/plugin/samples/integration/validation_sample.py @@ -14,17 +14,17 @@ class CustomValidationMixin(SettingsMixin, ValidationMixin, InvenTreePlugin): Simple of examples of custom validator code. """ - NAME = "CustomValidator" - SLUG = "validator" - TITLE = "Custom Validator Plugin" - DESCRIPTION = "A sample plugin for demonstrating custom validation functionality" - VERSION = "0.3.0" + NAME = 'CustomValidator' + SLUG = 'validator' + TITLE = 'Custom Validator Plugin' + DESCRIPTION = 'A sample plugin for demonstrating custom validation functionality' + VERSION = '0.3.0' SETTINGS = { 'ILLEGAL_PART_CHARS': { 'name': 'Illegal Part Characters', 'description': 'Characters which are not allowed to appear in Part names', - 'default': '!@#$%^&*()~`' + 'default': '!@#$%^&*()~`', }, 'IPN_MUST_CONTAIN_Q': { 'name': 'IPN Q Requirement', @@ -60,7 +60,7 @@ class CustomValidationMixin(SettingsMixin, ValidationMixin, InvenTreePlugin): These examples are silly, but serve to demonstrate how the feature could be used """ if len(part.description) < len(name): - raise ValidationError("Part description cannot be shorter than the name") + raise ValidationError('Part description cannot be shorter than the name') illegal_chars = self.get_setting('ILLEGAL_PART_CHARS') @@ -84,7 +84,7 @@ class CustomValidationMixin(SettingsMixin, ValidationMixin, InvenTreePlugin): if parameter.template.name.lower() in ['length', 'width']: d = int(data) if d >= 100: - raise ValidationError("Value must be less than 100") + raise ValidationError('Value must be less than 100') def validate_serial_number(self, serial: str, part): """Validate serial number for a given StockItem @@ -93,12 +93,14 @@ class CustomValidationMixin(SettingsMixin, ValidationMixin, InvenTreePlugin): """ if self.get_setting('SERIAL_MUST_BE_PALINDROME'): if serial != serial[::-1]: - raise ValidationError("Serial must be a palindrome") + raise ValidationError('Serial must be a palindrome') if self.get_setting('SERIAL_MUST_MATCH_PART'): # Serial must start with the same letter as the linked part, for some reason if serial[0] != part.name[0]: - raise ValidationError("Serial number must start with same letter as part") + raise ValidationError( + 'Serial number must start with same letter as part' + ) def validate_batch_code(self, batch_code: str, item): """Ensure that a particular batch code meets specification. @@ -114,4 +116,4 @@ class CustomValidationMixin(SettingsMixin, ValidationMixin, InvenTreePlugin): def generate_batch_code(self): """Generate a new batch code.""" now = datetime.now() - return f"BATCH-{now.year}:{now.month}:{now.day}" + return f'BATCH-{now.year}:{now.month}:{now.day}' diff --git a/InvenTree/plugin/samples/integration/version.py b/InvenTree/plugin/samples/integration/version.py index d0f2183310..eda28a5431 100644 --- a/InvenTree/plugin/samples/integration/version.py +++ b/InvenTree/plugin/samples/integration/version.py @@ -1,12 +1,13 @@ """Sample plugin for versioning.""" + from plugin import InvenTreePlugin class VersionPlugin(InvenTreePlugin): """A small version sample.""" - SLUG = "sampleversion" - NAME = "Sample Version Plugin" - DESCRIPTION = "A simple plugin which shows how to use the version limits" + SLUG = 'sampleversion' + NAME = 'Sample Version Plugin' + DESCRIPTION = 'A simple plugin which shows how to use the version limits' MIN_VERSION = '0.1.0' MAX_VERSION = '1.0.0' diff --git a/InvenTree/plugin/samples/locate/locate_sample.py b/InvenTree/plugin/samples/locate/locate_sample.py index 13c7f4e304..50f3263f98 100644 --- a/InvenTree/plugin/samples/locate/locate_sample.py +++ b/InvenTree/plugin/samples/locate/locate_sample.py @@ -17,11 +17,11 @@ class SampleLocatePlugin(LocateMixin, InvenTreePlugin): This plugin class simply prints location information to the logger. """ - NAME = "SampleLocatePlugin" - SLUG = "samplelocate" - TITLE = "Sample plugin for locating items" + NAME = 'SampleLocatePlugin' + SLUG = 'samplelocate' + TITLE = 'Sample plugin for locating items' - VERSION = "0.2" + VERSION = '0.2' def locate_stock_item(self, item_pk): """Locate a StockItem. @@ -31,17 +31,17 @@ class SampleLocatePlugin(LocateMixin, InvenTreePlugin): """ from stock.models import StockItem - logger.info("SampleLocatePlugin attempting to locate item ID %s", item_pk) + logger.info('SampleLocatePlugin attempting to locate item ID %s', item_pk) try: item = StockItem.objects.get(pk=item_pk) - logger.info("StockItem %s located!", item_pk) + logger.info('StockItem %s located!', item_pk) # Tag metadata item.set_metadata('located', True) except (ValueError, StockItem.DoesNotExist): # pragma: no cover - logger.exception("StockItem ID %s does not exist!", item_pk) + logger.exception('StockItem ID %s does not exist!', item_pk) def locate_stock_location(self, location_pk): """Locate a StockLocation. @@ -51,7 +51,9 @@ class SampleLocatePlugin(LocateMixin, InvenTreePlugin): """ from stock.models import StockLocation - logger.info("SampleLocatePlugin attempting to locate location ID %s", location_pk) + logger.info( + 'SampleLocatePlugin attempting to locate location ID %s', location_pk + ) try: location = StockLocation.objects.get(pk=location_pk) @@ -61,4 +63,4 @@ class SampleLocatePlugin(LocateMixin, InvenTreePlugin): location.set_metadata('located', True) except (ValueError, StockLocation.DoesNotExist): # pragma: no cover - logger.exception("Location ID %s does not exist!", location_pk) + logger.exception('Location ID %s does not exist!', location_pk) diff --git a/InvenTree/plugin/samples/locate/test_locate_sample.py b/InvenTree/plugin/samples/locate/test_locate_sample.py index 04f1d674bc..e34d0de0ae 100644 --- a/InvenTree/plugin/samples/locate/test_locate_sample.py +++ b/InvenTree/plugin/samples/locate/test_locate_sample.py @@ -11,12 +11,7 @@ from plugin.mixins import LocateMixin class SampleLocatePlugintests(InvenTreeAPITestCase): """Tests for SampleLocatePlugin.""" - fixtures = [ - 'location', - 'category', - 'part', - 'stock' - ] + fixtures = ['location', 'category', 'part', 'stock'] def test_run_locator(self): """Check if the event is issued.""" @@ -53,6 +48,7 @@ class SampleLocatePlugintests(InvenTreeAPITestCase): """Test that MixinNotImplementedError is raised.""" # Test location locator with self.assertRaises(MixinNotImplementedError): + class Wrong(LocateMixin, InvenTreePlugin): pass diff --git a/InvenTree/plugin/serializers.py b/InvenTree/plugin/serializers.py index e86dc2b3fe..9d123ee952 100644 --- a/InvenTree/plugin/serializers.py +++ b/InvenTree/plugin/serializers.py @@ -17,9 +17,7 @@ class MetadataSerializer(serializers.ModelSerializer): class Meta: """Metaclass options.""" - fields = [ - 'metadata', - ] + fields = ['metadata'] def __init__(self, model_type, *args, **kwargs): """Initialize the metadata serializer with information on the model type""" @@ -46,6 +44,7 @@ class PluginConfigSerializer(serializers.ModelSerializer): class Meta: """Meta for serializer.""" + model = PluginConfig fields = [ 'pk', @@ -59,12 +58,7 @@ class PluginConfigSerializer(serializers.ModelSerializer): 'is_installed', ] - read_only_fields = [ - 'key', - 'is_builtin', - 'is_sample', - 'is_installed', - ] + read_only_fields = ['key', 'is_builtin', 'is_sample', 'is_installed'] meta = serializers.DictField(read_only=True) mixins = serializers.DictField(read_only=True) @@ -75,27 +69,30 @@ class PluginConfigInstallSerializer(serializers.Serializer): class Meta: """Meta for serializer.""" - fields = [ - 'url', - 'packagename', - 'confirm', - ] + + fields = ['url', 'packagename', 'confirm'] url = serializers.CharField( required=False, allow_blank=True, label=_('Source URL'), - help_text=_('Source for the package - this can be a custom registry or a VCS path') + help_text=_( + 'Source for the package - this can be a custom registry or a VCS path' + ), ) packagename = serializers.CharField( required=False, allow_blank=True, label=_('Package Name'), - help_text=_('Name for the Plugin Package - can also contain a version indicator'), + help_text=_( + 'Name for the Plugin Package - can also contain a version indicator' + ), ) confirm = serializers.BooleanField( label=_('Confirm plugin installation'), - help_text=_('This will install this plugin now into the current instance. The instance will go into maintenance.') + help_text=_( + 'This will install this plugin now into the current instance. The instance will go into maintenance.' + ), ) def validate(self, data): @@ -128,6 +125,7 @@ class PluginConfigInstallSerializer(serializers.Serializer): class PluginConfigEmptySerializer(serializers.Serializer): """Serializer for a PluginConfig.""" + ... @@ -135,26 +133,32 @@ class PluginReloadSerializer(serializers.Serializer): """Serializer for remotely forcing plugin registry reload""" full_reload = serializers.BooleanField( - required=False, default=False, - label=_("Full reload"), - help_text=_("Perform a full reload of the plugin registry") + required=False, + default=False, + label=_('Full reload'), + help_text=_('Perform a full reload of the plugin registry'), ) force_reload = serializers.BooleanField( - required=False, default=False, - label=_("Force reload"), - help_text=_("Force a reload of the plugin registry, even if it is already loaded") + required=False, + default=False, + label=_('Force reload'), + help_text=_( + 'Force a reload of the plugin registry, even if it is already loaded' + ), ) collect_plugins = serializers.BooleanField( - required=False, default=False, - label=_("Collect plugins"), - help_text=_("Collect plugins and add them to the registry") + required=False, + default=False, + label=_('Collect plugins'), + help_text=_('Collect plugins and add them to the registry'), ) def save(self): """Reload the plugin registry.""" from plugin.registry import registry + registry.reload_plugins( full_reload=self.validated_data.get('full_reload', False), force_reload=self.validated_data.get('force_reload', False), @@ -168,9 +172,10 @@ class PluginActivateSerializer(serializers.Serializer): model = PluginConfig active = serializers.BooleanField( - required=False, default=True, + required=False, + default=True, label=_('Activate Plugin'), - help_text=_('Activate this plugin') + help_text=_('Activate this plugin'), ) def update(self, instance, validated_data): @@ -191,9 +196,7 @@ class PluginSettingSerializer(GenericReferencedSettingSerializer): """Serializer for the PluginSetting model.""" MODEL = PluginSetting - EXTRA_FIELDS = [ - 'plugin', - ] + EXTRA_FIELDS = ['plugin'] plugin = serializers.CharField(source='plugin.key', read_only=True) @@ -202,7 +205,7 @@ class NotificationUserSettingSerializer(GenericReferencedSettingSerializer): """Serializer for the PluginSetting model.""" MODEL = NotificationUserSetting - EXTRA_FIELDS = ['method', ] + EXTRA_FIELDS = ['method'] method = serializers.CharField(read_only=True) diff --git a/InvenTree/plugin/templatetags/plugin_extras.py b/InvenTree/plugin/templatetags/plugin_extras.py index 02fc42f8b8..0483b1588e 100644 --- a/InvenTree/plugin/templatetags/plugin_extras.py +++ b/InvenTree/plugin/templatetags/plugin_extras.py @@ -85,11 +85,14 @@ def notification_settings_list(context, *args, **kwargs): @register.simple_tag(takes_context=True) def notification_list(context, *args, **kwargs): """List of all notification methods.""" - return [{ - 'slug': a.METHOD_NAME, - 'icon': a.METHOD_ICON, - 'setting': a.GLOBAL_SETTING, - 'plugin': a.plugin, - 'description': a.__doc__, - 'name': a.__name__ - } for a in storage.liste] + return [ + { + 'slug': a.METHOD_NAME, + 'icon': a.METHOD_ICON, + 'setting': a.GLOBAL_SETTING, + 'plugin': a.plugin, + 'description': a.__doc__, + 'name': a.__name__, + } + for a in storage.liste + ] diff --git a/InvenTree/plugin/test_api.py b/InvenTree/plugin/test_api.py index 442a493d48..bf094bde49 100644 --- a/InvenTree/plugin/test_api.py +++ b/InvenTree/plugin/test_api.py @@ -12,12 +12,7 @@ from plugin.models import PluginConfig class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): """Tests the plugin API endpoints""" - roles = [ - 'admin.add', - 'admin.view', - 'admin.change', - 'admin.delete', - ] + roles = ['admin.add', 'admin.view', 'admin.change', 'admin.delete'] def setUp(self): """Setup for all tests.""" @@ -36,31 +31,21 @@ class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): url, { 'confirm': True, - 'packagename': 'invalid_package_name-asdads-asfd-asdf-asdf-asdf' + 'packagename': 'invalid_package_name-asdads-asfd-asdf-asdf-asdf', }, - expected_code=400 + expected_code=400, ) # valid - Pypi data = self.post( - url, - { - 'confirm': True, - 'packagename': self.PKG_NAME - }, - expected_code=201, + url, {'confirm': True, 'packagename': self.PKG_NAME}, expected_code=201 ).data self.assertEqual(data['success'], 'Installed plugin successfully') # valid - github url data = self.post( - url, - { - 'confirm': True, - 'url': self.PKG_URL - }, - expected_code=201, + url, {'confirm': True, 'url': self.PKG_URL}, expected_code=201 ).data self.assertEqual(data['success'], 'Installed plugin successfully') @@ -68,11 +53,7 @@ class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): # valid - github url and package name data = self.post( url, - { - 'confirm': True, - 'url': self.PKG_URL, - 'packagename': self.PKG_NAME, - }, + {'confirm': True, 'url': self.PKG_URL, 'packagename': self.PKG_NAME}, expected_code=201, ).data self.assertEqual(data['success'], 'Installed plugin successfully') @@ -82,24 +63,23 @@ class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): self.post(url, {}, expected_code=400) # no package info - data = self.post(url, { - 'confirm': True, - }, expected_code=400).data + data = self.post(url, {'confirm': True}, expected_code=400).data self.assertEqual(data['url'][0].title().upper(), self.MSG_NO_PKG.upper()) - self.assertEqual(data['packagename'][0].title().upper(), self.MSG_NO_PKG.upper()) + self.assertEqual( + data['packagename'][0].title().upper(), self.MSG_NO_PKG.upper() + ) # not confirmed - self.post(url, { - 'packagename': self.PKG_NAME - }, expected_code=400) + self.post(url, {'packagename': self.PKG_NAME}, expected_code=400) - data = self.post(url, { - 'packagename': self.PKG_NAME, - 'confirm': False, - }, expected_code=400).data + data = self.post( + url, {'packagename': self.PKG_NAME, 'confirm': False}, expected_code=400 + ).data - self.assertEqual(data['confirm'][0].title().upper(), 'Installation not confirmed'.upper()) + self.assertEqual( + data['confirm'][0].title().upper(), 'Installation not confirmed'.upper() + ) def test_plugin_activate(self): """Test the plugin activate.""" @@ -122,7 +102,11 @@ class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): # Activate plugin with detail url assert_plugin_active(self, False) - response = self.client.patch(reverse('api-plugin-detail-activate', kwargs={'pk': test_plg.id}), {}, follow=True) + response = self.client.patch( + reverse('api-plugin-detail-activate', kwargs={'pk': test_plg.id}), + {}, + follow=True, + ) self.assertEqual(response.status_code, 200) assert_plugin_active(self, True) @@ -132,7 +116,9 @@ class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): # Activate plugin assert_plugin_active(self, False) - response = self.client.patch(reverse('api-plugin-activate'), {'pk': test_plg.pk}, follow=True) + response = self.client.patch( + reverse('api-plugin-activate'), {'pk': test_plg.pk}, follow=True + ) self.assertEqual(response.status_code, 200) assert_plugin_active(self, True) @@ -142,33 +128,47 @@ class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): test_plg = self.plugin_confs.first() # deactivate plugin - response = self.client.post(url, { - 'action': 'plugin_deactivate', - 'index': 0, - '_selected_action': [test_plg.pk], - }, follow=True) + response = self.client.post( + url, + { + 'action': 'plugin_deactivate', + 'index': 0, + '_selected_action': [test_plg.pk], + }, + follow=True, + ) self.assertEqual(response.status_code, 200) # deactivate plugin - deactivate again -> nothing will happen but the nothing 'changed' function is triggered - response = self.client.post(url, { - 'action': 'plugin_deactivate', - 'index': 0, - '_selected_action': [test_plg.pk], - }, follow=True) + response = self.client.post( + url, + { + 'action': 'plugin_deactivate', + 'index': 0, + '_selected_action': [test_plg.pk], + }, + follow=True, + ) self.assertEqual(response.status_code, 200) # activate plugin - response = self.client.post(url, { - 'action': 'plugin_activate', - 'index': 0, - '_selected_action': [test_plg.pk], - }, follow=True) + response = self.client.post( + url, + { + 'action': 'plugin_activate', + 'index': 0, + '_selected_action': [test_plg.pk], + }, + follow=True, + ) self.assertEqual(response.status_code, 200) # save to deactivate a plugin - response = self.client.post(reverse('admin:plugin_pluginconfig_change', args=(test_plg.pk, )), { - '_save': 'Save', - }, follow=True) + response = self.client.post( + reverse('admin:plugin_pluginconfig_change', args=(test_plg.pk,)), + {'_save': 'Save'}, + follow=True, + ) self.assertEqual(response.status_code, 200) def test_model(self): @@ -177,7 +177,9 @@ class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): plg = self.plugin_confs.first() mixin_dict = plg.mixins() self.assertIn('base', mixin_dict) - self.assertDictContainsSubset({'base': {'key': 'base', 'human_name': 'base'}}, mixin_dict) + self.assertDictContainsSubset( + {'base': {'key': 'base', 'human_name': 'base'}}, mixin_dict + ) # check reload on save with self.assertWarns(Warning) as cm: @@ -215,62 +217,56 @@ class PluginDetailAPITest(PluginMixin, InvenTreeAPITestCase): self.client.patch(url, {}, expected_code=200) # Valid plugin settings endpoints - valid_settings = [ - 'SELECT_PART', - 'API_KEY', - 'NUMERICAL_SETTING', - ] + valid_settings = ['SELECT_PART', 'API_KEY', 'NUMERICAL_SETTING'] for key in valid_settings: response = self.get( - reverse('api-plugin-setting-detail', kwargs={ - 'plugin': 'sample', - 'key': key - })) + reverse( + 'api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': key} + ) + ) self.assertEqual(response.data['key'], key) # Test that an invalid setting key raises a 404 error response = self.get( - reverse('api-plugin-setting-detail', kwargs={ - 'plugin': 'sample', - 'key': 'INVALID_SETTING' - }), - expected_code=404 + reverse( + 'api-plugin-setting-detail', + kwargs={'plugin': 'sample', 'key': 'INVALID_SETTING'}, + ), + expected_code=404, ) # Test that a protected setting returns hidden value response = self.get( - reverse('api-plugin-setting-detail', kwargs={ - 'plugin': 'sample', - 'key': 'PROTECTED_SETTING' - }), - expected_code=200 + reverse( + 'api-plugin-setting-detail', + kwargs={'plugin': 'sample', 'key': 'PROTECTED_SETTING'}, + ), + expected_code=200, ) self.assertEqual(response.data['value'], '***') # Test that we can update a setting value response = self.patch( - reverse('api-plugin-setting-detail', kwargs={ - 'plugin': 'sample', - 'key': 'NUMERICAL_SETTING' - }), - { - 'value': 456 - }, - expected_code=200 + reverse( + 'api-plugin-setting-detail', + kwargs={'plugin': 'sample', 'key': 'NUMERICAL_SETTING'}, + ), + {'value': 456}, + expected_code=200, ) self.assertEqual(response.data['value'], '456') # Retrieve the value again response = self.get( - reverse('api-plugin-setting-detail', kwargs={ - 'plugin': 'sample', - 'key': 'NUMERICAL_SETTING' - }), - expected_code=200 + reverse( + 'api-plugin-setting-detail', + kwargs={'plugin': 'sample', 'key': 'NUMERICAL_SETTING'}, + ), + expected_code=200, ) self.assertEqual(response.data['value'], '456') diff --git a/InvenTree/plugin/test_helpers.py b/InvenTree/plugin/test_helpers.py index 53b2622592..4a32c02fb9 100644 --- a/InvenTree/plugin/test_helpers.py +++ b/InvenTree/plugin/test_helpers.py @@ -10,6 +10,7 @@ class HelperTests(TestCase): def test_render_template(self): """Check if render_template helper works.""" + class ErrorSource: slug = 'sampleplg' @@ -18,6 +19,8 @@ class HelperTests(TestCase): self.assertEqual(response, '

    123

    \n') # Wrong sample - response = render_template(ErrorSource(), 'sample/wrongsample.html', {'abc': 123}) + response = render_template( + ErrorSource(), 'sample/wrongsample.html', {'abc': 123} + ) self.assertTrue('lert alert-block alert-danger' in response) self.assertTrue('Template file sample/wrongsample.html' in response) diff --git a/InvenTree/plugin/test_plugin.py b/InvenTree/plugin/test_plugin.py index b0b91ed2bf..0612cb70bb 100644 --- a/InvenTree/plugin/test_plugin.py +++ b/InvenTree/plugin/test_plugin.py @@ -12,8 +12,10 @@ from django.test import TestCase, override_settings import plugin.templatetags.plugin_extras as plugin_tags from plugin import InvenTreePlugin, registry -from plugin.samples.integration.another_sample import (NoIntegrationPlugin, - WrongIntegrationPlugin) +from plugin.samples.integration.another_sample import ( + NoIntegrationPlugin, + WrongIntegrationPlugin, +) from plugin.samples.integration.sample import SampleIntegrationPlugin @@ -38,7 +40,7 @@ class PluginTagTests(TestCase): """Check all plugins are listed.""" self.assertEqual( plugin_tags.plugin_settings(self.sample), - registry.mixins_settings.get(self.sample) + registry.mixins_settings.get(self.sample), ) def test_tag_mixin_enabled(self): @@ -54,7 +56,9 @@ class PluginTagTests(TestCase): def test_tag_safe_url(self): """Test that the safe url tag works expected.""" # right url - self.assertEqual(plugin_tags.safe_url('api-plugin-install'), '/api/plugins/install/') + self.assertEqual( + plugin_tags.safe_url('api-plugin-install'), '/api/plugins/install/' + ) # wrong url self.assertEqual(plugin_tags.safe_url('indexas'), None) @@ -75,6 +79,7 @@ class InvenTreePluginTests(TestCase): class NamedPlugin(InvenTreePlugin): """a named plugin.""" + NAME = 'abc123' cls.named_plugin = NamedPlugin() @@ -93,7 +98,7 @@ class InvenTreePluginTests(TestCase): NAME = 'Aplugin' SLUG = 'a' TITLE = 'a title' - PUBLISH_DATE = "1111-11-11" + PUBLISH_DATE = '1111-11-11' AUTHOR = 'AA BB' DESCRIPTION = 'A description' VERSION = '1.2.3a' @@ -179,7 +184,9 @@ class InvenTreePluginTests(TestCase): with self.assertWarns(DeprecationWarning): self.assertEqual(self.plugin_old.slug, 'old') # check default value is used - self.assertEqual(self.plugin_old.get_meta_value('ABC', 'ABCD', '123'), '123') + self.assertEqual( + self.plugin_old.get_meta_value('ABC', 'ABCD', '123'), '123' + ) def test_version(self): """Test Version checks""" @@ -270,7 +277,12 @@ class RegistryTests(TestCase): self.assertEqual(len(registry.errors), 3) # There should be at least one discovery error in the module `broken_file` self.assertTrue(len(registry.errors.get('discovery')) > 0) - self.assertEqual(registry.errors.get('discovery')[0]['broken_file'], "name 'bb' is not defined") + self.assertEqual( + registry.errors.get('discovery')[0]['broken_file'], + "name 'bb' is not defined", + ) # There should be at least one load error with an intentional KeyError self.assertTrue(len(registry.errors.get('load')) > 0) - self.assertEqual(registry.errors.get('load')[0]['broken_sample'], "'This is a dummy error'") + self.assertEqual( + registry.errors.get('load')[0]['broken_sample'], "'This is a dummy error'" + ) diff --git a/InvenTree/plugin/urls.py b/InvenTree/plugin/urls.py index d337dee9ee..dc25226171 100644 --- a/InvenTree/plugin/urls.py +++ b/InvenTree/plugin/urls.py @@ -13,7 +13,9 @@ def get_plugin_urls(): urls = [] # Only allow custom routing if the setting is enabled - if InvenTreeSetting.get_setting('ENABLE_PLUGINS_URL', False, create=False, cache=False): + if InvenTreeSetting.get_setting( + 'ENABLE_PLUGINS_URL', False, create=False, cache=False + ): for plugin in registry.plugins.values(): if plugin.mixin_enabled('urls'): urls.append(plugin.urlpatterns) diff --git a/InvenTree/plugin/views.py b/InvenTree/plugin/views.py index 49a772f401..a332046f3e 100644 --- a/InvenTree/plugin/views.py +++ b/InvenTree/plugin/views.py @@ -21,13 +21,16 @@ class InvenTreePluginViewMixin: panels = [] for plug in registry.with_mixin('panel', active=True): - try: panels += plug.render_panels(self, self.request, ctx) except Exception: # Log the error to the database log_error(self.request.path) - logger.exception("Plugin '%s' could not render custom panels at '%s'", plug.slug, self.request.path) + logger.exception( + "Plugin '%s' could not render custom panels at '%s'", + plug.slug, + self.request.path, + ) return panels diff --git a/InvenTree/report/admin.py b/InvenTree/report/admin.py index 18b60392c5..fa359eee70 100644 --- a/InvenTree/report/admin.py +++ b/InvenTree/report/admin.py @@ -2,23 +2,34 @@ from django.contrib import admin -from .models import (BillOfMaterialsReport, BuildReport, PurchaseOrderReport, - ReportAsset, ReportSnippet, ReturnOrderReport, - SalesOrderReport, StockLocationReport, TestReport) +from .models import ( + BillOfMaterialsReport, + BuildReport, + PurchaseOrderReport, + ReportAsset, + ReportSnippet, + ReturnOrderReport, + SalesOrderReport, + StockLocationReport, + TestReport, +) class ReportTemplateAdmin(admin.ModelAdmin): """Admin class for the various reporting models""" + list_display = ('name', 'description', 'template', 'filters', 'enabled', 'revision') class ReportSnippetAdmin(admin.ModelAdmin): """Admin class for the ReportSnippet model""" + list_display = ('id', 'snippet', 'description') class ReportAssetAdmin(admin.ModelAdmin): """Admin class for the ReportAsset model""" + list_display = ('id', 'asset', 'description') diff --git a/InvenTree/report/api.py b/InvenTree/report/api.py index 876f0fff1b..8a5a6cbd10 100644 --- a/InvenTree/report/api.py +++ b/InvenTree/report/api.py @@ -20,36 +20,37 @@ import part.models from InvenTree.api import MetadataView from InvenTree.exceptions import log_error from InvenTree.filters import InvenTreeSearchFilter -from InvenTree.mixins import (ListCreateAPI, RetrieveAPI, - RetrieveUpdateDestroyAPI) +from InvenTree.mixins import ListCreateAPI, RetrieveAPI, RetrieveUpdateDestroyAPI from stock.models import StockItem, StockItemAttachment, StockLocation -from .models import (BillOfMaterialsReport, BuildReport, PurchaseOrderReport, - ReturnOrderReport, SalesOrderReport, StockLocationReport, - TestReport) -from .serializers import (BOMReportSerializer, BuildReportSerializer, - PurchaseOrderReportSerializer, - ReturnOrderReportSerializer, - SalesOrderReportSerializer, - StockLocationReportSerializer, TestReportSerializer) +from .models import ( + BillOfMaterialsReport, + BuildReport, + PurchaseOrderReport, + ReturnOrderReport, + SalesOrderReport, + StockLocationReport, + TestReport, +) +from .serializers import ( + BOMReportSerializer, + BuildReportSerializer, + PurchaseOrderReportSerializer, + ReturnOrderReportSerializer, + SalesOrderReportSerializer, + StockLocationReportSerializer, + TestReportSerializer, +) class ReportListView(ListCreateAPI): """Generic API class for report templates.""" - filter_backends = [ - DjangoFilterBackend, - InvenTreeSearchFilter, - ] + filter_backends = [DjangoFilterBackend, InvenTreeSearchFilter] - filterset_fields = [ - 'enabled', - ] + filterset_fields = ['enabled'] - search_fields = [ - 'name', - 'description', - ] + search_fields = ['name', 'description'] class ReportFilterMixin: @@ -71,7 +72,9 @@ class ReportFilterMixin: def get_items(self): """Return a list of database objects from query parameters""" if not self.ITEM_MODEL: - raise NotImplementedError(f"ITEM_MODEL attribute not defined for {__class__}") + raise NotImplementedError( + f'ITEM_MODEL attribute not defined for {__class__}' + ) ids = [] @@ -169,19 +172,19 @@ class ReportPrintMixin: """Print this report template against a number of pre-validated items.""" if len(items_to_print) == 0: # No valid items provided, return an error message - data = { - 'error': _('No valid objects provided to template'), - } + data = {'error': _('No valid objects provided to template')} return Response(data, status=400) outputs = [] # In debug mode, generate single HTML output, rather than PDF - debug_mode = common.models.InvenTreeSetting.get_setting('REPORT_DEBUG_MODE', cache=False) + debug_mode = common.models.InvenTreeSetting.get_setting( + 'REPORT_DEBUG_MODE', cache=False + ) # Start with a default report name - report_name = "report.pdf" + report_name = 'report.pdf' try: # Merge one or more PDF files into a single download @@ -207,7 +210,9 @@ class ReportPrintMixin: return Response( { - 'error': _(f"Template file '{template}' is missing or does not exist"), + 'error': _( + f"Template file '{template}' is missing or does not exist" + ) }, status=400, ) @@ -218,7 +223,7 @@ class ReportPrintMixin: if debug_mode: """Concatenate all rendered templates into a single HTML string, and return the string as a HTML response.""" - html = "\n".join(outputs) + html = '\n'.join(outputs) return HttpResponse(html) else: @@ -235,7 +240,6 @@ class ReportPrintMixin: pdf = outputs[0].get_document().copy(pages).write_pdf() except TemplateDoesNotExist as e: - template = str(e) if not template: @@ -243,18 +247,19 @@ class ReportPrintMixin: return Response( { - 'error': _(f"Template file '{template}' is missing or does not exist"), + 'error': _( + f"Template file '{template}' is missing or does not exist" + ) }, status=400, ) - inline = common.models.InvenTreeUserSetting.get_setting('REPORT_INLINE', user=request.user, cache=False) + inline = common.models.InvenTreeUserSetting.get_setting( + 'REPORT_INLINE', user=request.user, cache=False + ) return InvenTree.helpers.DownloadFile( - pdf, - report_name, - content_type='application/pdf', - inline=inline, + pdf, report_name, content_type='application/pdf', inline=inline ) except Exception as exc: @@ -294,11 +299,13 @@ class StockItemTestReportList(StockItemTestReportMixin, ReportListView): - enabled: Filter by enabled / disabled status - item: Filter by stock item(s) """ + pass class StockItemTestReportDetail(StockItemTestReportMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single TestReport object.""" + pass @@ -307,12 +314,13 @@ class StockItemTestReportPrint(StockItemTestReportMixin, ReportPrintMixin, Retri def report_callback(self, item, report, request): """Callback to (optionally) save a copy of the generated report""" - if common.models.InvenTreeSetting.get_setting('REPORT_ATTACH_TEST_REPORT', cache=False): - + if common.models.InvenTreeSetting.get_setting( + 'REPORT_ATTACH_TEST_REPORT', cache=False + ): # Construct a PDF file object try: pdf = report.get_document().write_pdf() - pdf_content = ContentFile(pdf, "test_report.pdf") + pdf_content = ContentFile(pdf, 'test_report.pdf') except TemplateDoesNotExist: return @@ -320,7 +328,7 @@ class StockItemTestReportPrint(StockItemTestReportMixin, ReportPrintMixin, Retri attachment=pdf_content, stock_item=item, user=request.user, - comment=_("Test report") + comment=_('Test report'), ) @@ -342,16 +350,19 @@ class BOMReportList(BOMReportMixin, ReportListView): - enabled: Filter by enabled / disabled status - part: Filter by part(s) """ + pass class BOMReportDetail(BOMReportMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single BillOfMaterialReport object.""" + pass class BOMReportPrint(BOMReportMixin, ReportPrintMixin, RetrieveAPI): """API endpoint for printing a BillOfMaterialReport object.""" + pass @@ -373,16 +384,19 @@ class BuildReportList(BuildReportMixin, ReportListView): - enabled: Filter by enabled / disabled status - build: Filter by Build object """ + pass class BuildReportDetail(BuildReportMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single BuildReport object.""" + pass class BuildReportPrint(BuildReportMixin, ReportPrintMixin, RetrieveAPI): """API endpoint for printing a BuildReport.""" + pass @@ -398,16 +412,19 @@ class PurchaseOrderReportMixin(ReportFilterMixin): class PurchaseOrderReportList(PurchaseOrderReportMixin, ReportListView): """API list endpoint for the PurchaseOrderReport model""" + pass class PurchaseOrderReportDetail(PurchaseOrderReportMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single PurchaseOrderReport object.""" + pass class PurchaseOrderReportPrint(PurchaseOrderReportMixin, ReportPrintMixin, RetrieveAPI): """API endpoint for printing a PurchaseOrderReport object.""" + pass @@ -423,16 +440,19 @@ class SalesOrderReportMixin(ReportFilterMixin): class SalesOrderReportList(SalesOrderReportMixin, ReportListView): """API list endpoint for the SalesOrderReport model""" + pass class SalesOrderReportDetail(SalesOrderReportMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single SalesOrderReport object.""" + pass class SalesOrderReportPrint(SalesOrderReportMixin, ReportPrintMixin, RetrieveAPI): """API endpoint for printing a PurchaseOrderReport object.""" + pass @@ -448,16 +468,19 @@ class ReturnOrderReportMixin(ReportFilterMixin): class ReturnOrderReportList(ReturnOrderReportMixin, ReportListView): """API list endpoint for the ReturnOrderReport model""" + pass class ReturnOrderReportDetail(ReturnOrderReportMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single ReturnOrderReport object""" + pass class ReturnOrderReportPrint(ReturnOrderReportMixin, ReportPrintMixin, RetrieveAPI): """API endpoint for printing a ReturnOrderReport object""" + pass @@ -472,107 +495,236 @@ class StockLocationReportMixin(ReportFilterMixin): class StockLocationReportList(StockLocationReportMixin, ReportListView): """API list endpoint for the StockLocationReportList model""" + pass class StockLocationReportDetail(StockLocationReportMixin, RetrieveUpdateDestroyAPI): """API endpoint for a single StockLocationReportDetail object.""" + pass class StockLocationReportPrint(StockLocationReportMixin, ReportPrintMixin, RetrieveAPI): """API endpoint for printing a StockLocationReportPrint object""" + pass report_api_urls = [ - # Purchase order reports - re_path(r'po/', include([ - # Detail views - path(r'/', include([ - re_path(r'print/', PurchaseOrderReportPrint.as_view(), name='api-po-report-print'), - re_path(r'metadata/', MetadataView.as_view(), {'model': PurchaseOrderReport}, name='api-po-report-metadata'), - path('', PurchaseOrderReportDetail.as_view(), name='api-po-report-detail'), - ])), - - # List view - path('', PurchaseOrderReportList.as_view(), name='api-po-report-list'), - ])), - + re_path( + r'po/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'print/', + PurchaseOrderReportPrint.as_view(), + name='api-po-report-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'model': PurchaseOrderReport}, + name='api-po-report-metadata', + ), + path( + '', + PurchaseOrderReportDetail.as_view(), + name='api-po-report-detail', + ), + ]), + ), + # List view + path('', PurchaseOrderReportList.as_view(), name='api-po-report-list'), + ]), + ), # Sales order reports - re_path(r'so/', include([ - # Detail views - path(r'/', include([ - re_path(r'print/', SalesOrderReportPrint.as_view(), name='api-so-report-print'), - re_path(r'metadata/', MetadataView.as_view(), {'model': SalesOrderReport}, name='api-so-report-metadata'), - path('', SalesOrderReportDetail.as_view(), name='api-so-report-detail'), - ])), - - path('', SalesOrderReportList.as_view(), name='api-so-report-list'), - ])), - + re_path( + r'so/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'print/', + SalesOrderReportPrint.as_view(), + name='api-so-report-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'model': SalesOrderReport}, + name='api-so-report-metadata', + ), + path( + '', + SalesOrderReportDetail.as_view(), + name='api-so-report-detail', + ), + ]), + ), + path('', SalesOrderReportList.as_view(), name='api-so-report-list'), + ]), + ), # Return order reports - re_path(r'ro/', include([ - path(r'/', include([ - path(r'print/', ReturnOrderReportPrint.as_view(), name='api-return-order-report-print'), - re_path(r'metadata/', MetadataView.as_view(), {'model': ReturnOrderReport}, name='api-so-report-metadata'), - path('', ReturnOrderReportDetail.as_view(), name='api-return-order-report-detail'), - ])), - path('', ReturnOrderReportList.as_view(), name='api-return-order-report-list'), - ])), - + re_path( + r'ro/', + include([ + path( + r'/', + include([ + path( + r'print/', + ReturnOrderReportPrint.as_view(), + name='api-return-order-report-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'model': ReturnOrderReport}, + name='api-so-report-metadata', + ), + path( + '', + ReturnOrderReportDetail.as_view(), + name='api-return-order-report-detail', + ), + ]), + ), + path( + '', ReturnOrderReportList.as_view(), name='api-return-order-report-list' + ), + ]), + ), # Build reports - re_path(r'build/', include([ - # Detail views - path(r'/', include([ - re_path(r'print/?', BuildReportPrint.as_view(), name='api-build-report-print'), - re_path(r'metadata/', MetadataView.as_view(), {'model': BuildReport}, name='api-build-report-metadata'), - re_path(r'^.$', BuildReportDetail.as_view(), name='api-build-report-detail'), - ])), - - # List view - re_path(r'^.*$', BuildReportList.as_view(), name='api-build-report-list'), - ])), - + re_path( + r'build/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'print/?', + BuildReportPrint.as_view(), + name='api-build-report-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'model': BuildReport}, + name='api-build-report-metadata', + ), + re_path( + r'^.$', + BuildReportDetail.as_view(), + name='api-build-report-detail', + ), + ]), + ), + # List view + re_path(r'^.*$', BuildReportList.as_view(), name='api-build-report-list'), + ]), + ), # Bill of Material reports - re_path(r'bom/', include([ - - # Detail views - path(r'/', include([ - re_path(r'print/?', BOMReportPrint.as_view(), name='api-bom-report-print'), - re_path(r'metadata/', MetadataView.as_view(), {'model': BillOfMaterialsReport}, name='api-bom-report-metadata'), - re_path(r'^.*$', BOMReportDetail.as_view(), name='api-bom-report-detail'), - ])), - - # List view - re_path(r'^.*$', BOMReportList.as_view(), name='api-bom-report-list'), - ])), - + re_path( + r'bom/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'print/?', + BOMReportPrint.as_view(), + name='api-bom-report-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'model': BillOfMaterialsReport}, + name='api-bom-report-metadata', + ), + re_path( + r'^.*$', BOMReportDetail.as_view(), name='api-bom-report-detail' + ), + ]), + ), + # List view + re_path(r'^.*$', BOMReportList.as_view(), name='api-bom-report-list'), + ]), + ), # Stock item test reports - re_path(r'test/', include([ - # Detail views - path(r'/', include([ - re_path(r'print/?', StockItemTestReportPrint.as_view(), name='api-stockitem-testreport-print'), - re_path(r'metadata/', MetadataView.as_view(), {'report': TestReport}, name='api-stockitem-testreport-metadata'), - re_path(r'^.*$', StockItemTestReportDetail.as_view(), name='api-stockitem-testreport-detail'), - ])), - - # List view - re_path(r'^.*$', StockItemTestReportList.as_view(), name='api-stockitem-testreport-list'), - ])), - + re_path( + r'test/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'print/?', + StockItemTestReportPrint.as_view(), + name='api-stockitem-testreport-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'report': TestReport}, + name='api-stockitem-testreport-metadata', + ), + re_path( + r'^.*$', + StockItemTestReportDetail.as_view(), + name='api-stockitem-testreport-detail', + ), + ]), + ), + # List view + re_path( + r'^.*$', + StockItemTestReportList.as_view(), + name='api-stockitem-testreport-list', + ), + ]), + ), # Stock Location reports (Stock Location Reports -> sir) - re_path(r'slr/', include([ - # Detail views - path(r'/', include([ - re_path(r'print/?', StockLocationReportPrint.as_view(), name='api-stocklocation-report-print'), - re_path(r'metadata/', MetadataView.as_view(), {'report': StockLocationReport}, name='api-stocklocation-report-metadata'), - re_path(r'^.*$', StockLocationReportDetail.as_view(), name='api-stocklocation-report-detail'), - ])), - - # List view - re_path(r'^.*$', StockLocationReportList.as_view(), name='api-stocklocation-report-list'), - ])), - + re_path( + r'slr/', + include([ + # Detail views + path( + r'/', + include([ + re_path( + r'print/?', + StockLocationReportPrint.as_view(), + name='api-stocklocation-report-print', + ), + re_path( + r'metadata/', + MetadataView.as_view(), + {'report': StockLocationReport}, + name='api-stocklocation-report-metadata', + ), + re_path( + r'^.*$', + StockLocationReportDetail.as_view(), + name='api-stocklocation-report-detail', + ), + ]), + ), + # List view + re_path( + r'^.*$', + StockLocationReportList.as_view(), + name='api-stocklocation-report-list', + ), + ]), + ), ] diff --git a/InvenTree/report/apps.py b/InvenTree/report/apps.py index 6523453df7..207412e33c 100644 --- a/InvenTree/report/apps.py +++ b/InvenTree/report/apps.py @@ -11,11 +11,12 @@ from django.conf import settings from django.core.exceptions import AppRegistryNotReady from django.db.utils import IntegrityError, OperationalError, ProgrammingError -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') class ReportConfig(AppConfig): """Configuration class for the 'report' app""" + name = 'report' def ready(self): @@ -23,7 +24,10 @@ class ReportConfig(AppConfig): import InvenTree.ready # skip loading if plugin registry is not loaded or we run in a background thread - if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread(): + if ( + not InvenTree.ready.isPluginRegistryLoaded() + or not InvenTree.ready.isInMainThread() + ): return if InvenTree.ready.isRunningMigrations(): @@ -34,8 +38,10 @@ class ReportConfig(AppConfig): logging.getLogger('weasyprint').setLevel(logging.WARNING) # Create entries for default report templates - if InvenTree.ready.canAppAccessDatabase(allow_test=False) and not InvenTree.ready.isImportingData(): - + if ( + InvenTree.ready.canAppAccessDatabase(allow_test=False) + and not InvenTree.ready.isImportingData() + ): try: self.create_default_test_reports() self.create_default_build_reports() @@ -44,24 +50,24 @@ class ReportConfig(AppConfig): self.create_default_sales_order_reports() self.create_default_return_order_reports() self.create_default_stock_location_reports() - except (AppRegistryNotReady, IntegrityError, OperationalError, ProgrammingError): + except ( + AppRegistryNotReady, + IntegrityError, + OperationalError, + ProgrammingError, + ): # Database might not yet be ready - warnings.warn('Database was not ready for creating reports', stacklevel=2) + warnings.warn( + 'Database was not ready for creating reports', stacklevel=2 + ) def create_default_reports(self, model, reports): """Copy default report files across to the media directory.""" # Source directory for report templates - src_dir = Path(__file__).parent.joinpath( - 'templates', - 'report', - ) + src_dir = Path(__file__).parent.joinpath('templates', 'report') # Destination directory - dst_dir = settings.MEDIA_ROOT.joinpath( - 'report', - 'inventree', - model.getSubdir(), - ) + dst_dir = settings.MEDIA_ROOT.joinpath('report', 'inventree', model.getSubdir()) if not dst_dir.exists(): logger.info("Creating missing directory: '%s'", dst_dir) @@ -69,13 +75,9 @@ class ReportConfig(AppConfig): # Copy each report template across (if required) for report in reports: - # Destination filename filename = os.path.join( - 'report', - 'inventree', - model.getSubdir(), - report['file'], + 'report', 'inventree', model.getSubdir(), report['file'] ) src_file = src_dir.joinpath(report['file']) @@ -96,7 +98,7 @@ class ReportConfig(AppConfig): name=report['name'], description=report['description'], template=filename, - enabled=True + enabled=True, ) except Exception: @@ -116,7 +118,7 @@ class ReportConfig(AppConfig): 'file': 'inventree_test_report.html', 'name': 'InvenTree Test Report', 'description': 'Stock item test report', - }, + } ] self.create_default_reports(TestReport, reports) diff --git a/InvenTree/report/helpers.py b/InvenTree/report/helpers.py index 2088913db5..d3594300e3 100644 --- a/InvenTree/report/helpers.py +++ b/InvenTree/report/helpers.py @@ -46,7 +46,7 @@ def report_page_size_default(): try: page_size = InvenTreeSetting.get_setting('REPORT_DEFAULT_PAGE_SIZE', 'A4') except Exception as exc: - logger.exception("Error getting default page size: %s", str(exc)) + logger.exception('Error getting default page size: %s', str(exc)) page_size = 'A4' return page_size @@ -70,4 +70,4 @@ def encode_image_base64(image, format: str = 'PNG'): img_str = base64.b64encode(buffered.getvalue()) - return f"data:image/{fmt};charset=utf-8;base64," + img_str.decode() + return f'data:image/{fmt};charset=utf-8;base64,' + img_str.decode() diff --git a/InvenTree/report/models.py b/InvenTree/report/models.py index 42085825eb..e64ae4abda 100644 --- a/InvenTree/report/models.py +++ b/InvenTree/report/models.py @@ -28,12 +28,12 @@ from plugin.registry import registry try: from django_weasyprint import WeasyTemplateResponseMixin except OSError as err: # pragma: no cover - print(f"OSError: {err}") - print("You may require some further system packages to be installed.") + print(f'OSError: {err}') + print('You may require some further system packages to be installed.') sys.exit(1) -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') def rename_template(instance, filename): @@ -105,7 +105,9 @@ class ReportBase(models.Model): """Initialize the particular report instance""" super().__init__(*args, **kwargs) - self._meta.get_field('page_size').choices = report.helpers.report_page_size_options() + self._meta.get_field( + 'page_size' + ).choices = report.helpers.report_page_size_options() def save(self, *args, **kwargs): """Perform additional actions when the report is saved""" @@ -116,7 +118,7 @@ class ReportBase(models.Model): def __str__(self): """Format a string representation of a report instance""" - return f"{self.name} - {self.description}" + return f'{self.name} - {self.description}' @classmethod def getSubdir(cls): @@ -134,7 +136,6 @@ class ReportBase(models.Model): # If the report file is the *same* filename as the one being uploaded, # remove the original one from the media directory if str(filename) == str(self.template): - if fullpath.exists(): logger.info("Deleting existing report template: '%s'", filename) os.remove(fullpath) @@ -166,7 +167,8 @@ class ReportBase(models.Model): return template name = models.CharField( - blank=False, max_length=100, + blank=False, + max_length=100, verbose_name=_('Name'), help_text=_('Template name'), ) @@ -174,20 +176,20 @@ class ReportBase(models.Model): template = models.FileField( upload_to=rename_template, verbose_name=_('Template'), - help_text=_("Report template file"), + help_text=_('Report template file'), validators=[FileExtensionValidator(allowed_extensions=['html', 'htm'])], ) description = models.CharField( max_length=250, verbose_name=_('Description'), - help_text=_("Report template description") + help_text=_('Report template description'), ) revision = models.PositiveIntegerField( default=1, - verbose_name=_("Revision"), - help_text=_("Report revision number (auto-increments)"), + verbose_name=_('Revision'), + help_text=_('Report revision number (auto-increments)'), editable=False, ) @@ -213,6 +215,7 @@ class ReportTemplateBase(MetadataMixin, ReportBase): class Meta: """Metaclass options. Abstract ensures no database table is created.""" + abstract = True # Pass a single top-level object to the report template @@ -225,7 +228,9 @@ class ReportTemplateBase(MetadataMixin, ReportBase): def get_report_size(self): """Return the printable page size for this report""" try: - page_size_default = common.models.InvenTreeSetting.get_setting('REPORT_DEFAULT_PAGE_SIZE', 'A4') + page_size_default = common.models.InvenTreeSetting.get_setting( + 'REPORT_DEFAULT_PAGE_SIZE', 'A4' + ) except Exception: page_size_default = 'A4' @@ -290,17 +295,16 @@ class ReportTemplateBase(MetadataMixin, ReportBase): wp = WeasyprintReportMixin( request, self.template_name, - base_url=request.build_absolute_uri("/"), + base_url=request.build_absolute_uri('/'), presentational_hints=True, filename=self.generate_filename(request), - **kwargs) + **kwargs, + ) - return wp.render_to_response( - self.context(request), - **kwargs) + return wp.render_to_response(self.context(request), **kwargs) filename_pattern = models.CharField( - default="report.pdf", + default='report.pdf', verbose_name=_('Filename Pattern'), help_text=_('Pattern for generating report filenames'), max_length=100, @@ -330,16 +334,18 @@ class TestReport(ReportTemplateBase): blank=True, max_length=250, verbose_name=_('Filters'), - help_text=_("StockItem query filters (comma-separated list of key=value pairs)"), - validators=[ - validate_stock_item_report_filters - ] + help_text=_( + 'StockItem query filters (comma-separated list of key=value pairs)' + ), + validators=[validate_stock_item_report_filters], ) include_installed = models.BooleanField( default=False, verbose_name=_('Include Installed Tests'), - help_text=_('Include test results for stock items installed inside assembled item') + help_text=_( + 'Include test results for stock items installed inside assembled item' + ), ) def get_test_keys(self, stock_item): @@ -359,7 +365,9 @@ class TestReport(ReportTemplateBase): if test.key not in keys: keys.append(test.key) - for result in stock_item.testResultList(include_installed=self.include_installed): + for result in stock_item.testResultList( + include_installed=self.include_installed + ): if result.key not in keys: keys.append(result.key) @@ -377,8 +385,12 @@ class TestReport(ReportTemplateBase): 'test_keys': self.get_test_keys(stock_item), 'test_template_list': stock_item.part.getTestTemplates(), 'test_template_map': stock_item.part.getTestTemplateMap(), - 'results': stock_item.testResultMap(include_installed=self.include_installed), - 'result_list': stock_item.testResultList(include_installed=self.include_installed), + 'results': stock_item.testResultMap( + include_installed=self.include_installed + ), + 'result_list': stock_item.testResultList( + include_installed=self.include_installed + ), 'installed_items': stock_item.get_installed_items(cascade=True), } @@ -401,9 +413,7 @@ class BuildReport(ReportTemplateBase): max_length=250, verbose_name=_('Build Filters'), help_text=_('Build query filters (comma-separated list of key=value pairs'), - validators=[ - validate_build_report_filters, - ] + validators=[validate_build_report_filters], ) def get_context_data(self, request): @@ -443,9 +453,7 @@ class BillOfMaterialsReport(ReportTemplateBase): max_length=250, verbose_name=_('Part Filters'), help_text=_('Part query filters (comma-separated list of key=value pairs'), - validators=[ - validate_part_report_filters - ] + validators=[validate_part_report_filters], ) def get_context_data(self, request): @@ -477,9 +485,7 @@ class PurchaseOrderReport(ReportTemplateBase): max_length=250, verbose_name=_('Filters'), help_text=_('Purchase order query filters'), - validators=[ - validate_purchase_order_filters, - ] + validators=[validate_purchase_order_filters], ) def get_context_data(self, request): @@ -515,9 +521,7 @@ class SalesOrderReport(ReportTemplateBase): max_length=250, verbose_name=_('Filters'), help_text=_('Sales order query filters'), - validators=[ - validate_sales_order_filters - ] + validators=[validate_sales_order_filters], ) def get_context_data(self, request): @@ -553,9 +557,7 @@ class ReturnOrderReport(ReportTemplateBase): max_length=250, verbose_name=_('Filters'), help_text=_('Return order query filters'), - validators=[ - validate_return_order_filters, - ] + validators=[validate_return_order_filters], ) def get_context_data(self, request): @@ -584,7 +586,6 @@ def rename_snippet(instance, filename): # If the snippet file is the *same* filename as the one being uploaded, # delete the original one from the media directory if str(filename) == str(instance.snippet): - if fullpath.exists(): logger.info("Deleting existing snippet file: '%s'", filename) os.remove(fullpath) @@ -608,7 +609,11 @@ class ReportSnippet(models.Model): validators=[FileExtensionValidator(allowed_extensions=['html', 'htm'])], ) - description = models.CharField(max_length=250, verbose_name=_('Description'), help_text=_("Snippet file description")) + description = models.CharField( + max_length=250, + verbose_name=_('Description'), + help_text=_('Snippet file description'), + ) def rename_asset(instance, filename): @@ -645,14 +650,14 @@ class ReportAsset(models.Model): asset = models.FileField( upload_to=rename_asset, verbose_name=_('Asset'), - help_text=_("Report asset file"), + help_text=_('Report asset file'), ) # Asset description (user facing string, not used internally) description = models.CharField( max_length=250, verbose_name=_('Description'), - help_text=_("Asset file description") + help_text=_('Asset file description'), ) @@ -673,10 +678,10 @@ class StockLocationReport(ReportTemplateBase): blank=True, max_length=250, verbose_name=_('Filters'), - help_text=_("stock location query filters (comma-separated list of key=value pairs)"), - validators=[ - validate_stock_location_report_filters - ] + help_text=_( + 'stock location query filters (comma-separated list of key=value pairs)' + ), + validators=[validate_stock_location_report_filters], ) def get_context_data(self, request): @@ -684,7 +689,10 @@ class StockLocationReport(ReportTemplateBase): stock_location = self.object_to_print if not isinstance(stock_location, stock.models.StockLocation): - raise TypeError('Provided model is not a StockLocation object -> ' + str(type(stock_location))) + raise TypeError( + 'Provided model is not a StockLocation object -> ' + + str(type(stock_location)) + ) return { 'stock_location': stock_location, diff --git a/InvenTree/report/serializers.py b/InvenTree/report/serializers.py index 2e690b6791..5dcd9b49a7 100644 --- a/InvenTree/report/serializers.py +++ b/InvenTree/report/serializers.py @@ -1,11 +1,19 @@ """API serializers for the reporting models""" -from InvenTree.serializers import (InvenTreeAttachmentSerializerField, - InvenTreeModelSerializer) +from InvenTree.serializers import ( + InvenTreeAttachmentSerializerField, + InvenTreeModelSerializer, +) -from .models import (BillOfMaterialsReport, BuildReport, PurchaseOrderReport, - ReturnOrderReport, SalesOrderReport, StockLocationReport, - TestReport) +from .models import ( + BillOfMaterialsReport, + BuildReport, + PurchaseOrderReport, + ReturnOrderReport, + SalesOrderReport, + StockLocationReport, + TestReport, +) class ReportSerializerBase(InvenTreeModelSerializer): @@ -16,14 +24,7 @@ class ReportSerializerBase(InvenTreeModelSerializer): @staticmethod def report_fields(): """Generic serializer fields for a report template""" - return [ - 'pk', - 'name', - 'description', - 'template', - 'filters', - 'enabled', - ] + return ['pk', 'name', 'description', 'template', 'filters', 'enabled'] class TestReportSerializer(ReportSerializerBase): diff --git a/InvenTree/report/templatetags/barcode.py b/InvenTree/report/templatetags/barcode.py index 723064f984..451136eb0f 100644 --- a/InvenTree/report/templatetags/barcode.py +++ b/InvenTree/report/templatetags/barcode.py @@ -35,11 +35,7 @@ def qrcode(data, **kwargs): """ # Construct "default" values - params = { - "box_size": 20, - "border": 1, - "version": 1, - } + params = {'box_size': 20, 'border': 1, 'version': 1} fill_color = kwargs.pop('fill_color', 'black') back_color = kwargs.pop('back_color', 'white') @@ -53,10 +49,7 @@ def qrcode(data, **kwargs): qr.add_data(data, optimize=20) qr.make(fit=True) - qri = qr.make_image( - fill_color=fill_color, - back_color=back_color - ) + qri = qr.make_image(fill_color=fill_color, back_color=back_color) # Render to byte-encoded image return image_data(qri, fmt=format) diff --git a/InvenTree/report/templatetags/report.py b/InvenTree/report/templatetags/report.py index 726703312d..c2a096136f 100644 --- a/InvenTree/report/templatetags/report.py +++ b/InvenTree/report/templatetags/report.py @@ -63,7 +63,7 @@ def getkey(container: dict, key): key: The 'key' to be found within the dict """ if type(container) is not dict: - logger.warning("getkey() called with non-dict object") + logger.warning('getkey() called with non-dict object') return None if key in container: @@ -92,15 +92,21 @@ def asset(filename): full_path = settings.MEDIA_ROOT.joinpath('report', 'assets', filename).resolve() if not full_path.exists() or not full_path.is_file(): - raise FileNotFoundError(_("Asset file does not exist") + f": '{filename}'") + raise FileNotFoundError(_('Asset file does not exist') + f": '{filename}'") if debug_mode: return os.path.join(settings.MEDIA_URL, 'report', 'assets', filename) - return f"file://{full_path}" + return f'file://{full_path}' @register.simple_tag() -def uploaded_image(filename, replace_missing=True, replacement_file='blank_image.png', validate=True, **kwargs): +def uploaded_image( + filename, + replace_missing=True, + replacement_file='blank_image.png', + validate=True, + **kwargs, +): """Return a fully-qualified path for an 'uploaded' image. Arguments: @@ -141,7 +147,7 @@ def uploaded_image(filename, replace_missing=True, replacement_file='blank_image exists = False if not exists and not replace_missing: - raise FileNotFoundError(_("Image file not found") + f": '{filename}'") + raise FileNotFoundError(_('Image file not found') + f": '{filename}'") if debug_mode: # In debug mode, return a web path (rather than an encoded image blob) @@ -167,12 +173,12 @@ def uploaded_image(filename, replace_missing=True, replacement_file='blank_image img = img.resize((width, height)) elif width is not None: # Resize the image, width only - wpercent = (width / float(img.size[0])) + wpercent = width / float(img.size[0]) hsize = int((float(img.size[1]) * float(wpercent))) img = img.resize((width, hsize)) elif height is not None: # Resize the image, height only - hpercent = (height / float(img.size[1])) + hpercent = height / float(img.size[1]) wsize = int((float(img.size[0]) * float(hpercent))) img = img.resize((wsize, height)) @@ -206,14 +212,16 @@ def encode_svg_image(filename): exists = False if not exists: - raise FileNotFoundError(_("Image file not found") + f": '{filename}'") + raise FileNotFoundError(_('Image file not found') + f": '{filename}'") # Read the file data with open(full_path, 'rb') as f: data = f.read() # Return the base64-encoded data - return "data:image/svg+xml;charset=utf-8;base64," + base64.b64encode(data).decode('utf-8') + return 'data:image/svg+xml;charset=utf-8;base64,' + base64.b64encode(data).decode( + 'utf-8' + ) @register.simple_tag() @@ -227,7 +235,7 @@ def part_image(part: Part, preview=False, thumbnail=False, **kwargs): TypeError if provided part is not a Part instance """ if type(part) is not Part: - raise TypeError(_("part_image tag requires a Part instance")) + raise TypeError(_('part_image tag requires a Part instance')) if preview: img = part.image.preview.name @@ -266,7 +274,7 @@ def company_image(company, preview=False, thumbnail=False, **kwargs): TypeError if provided company is not a Company instance """ if type(company) is not Company: - raise TypeError(_("company_image tag requires a Company instance")) + raise TypeError(_('company_image tag requires a Company instance')) if preview: img = company.image.preview.name diff --git a/InvenTree/report/tests.py b/InvenTree/report/tests.py index 4f81623fc4..96f64a22a2 100644 --- a/InvenTree/report/tests.py +++ b/InvenTree/report/tests.py @@ -44,12 +44,7 @@ class ReportTagTest(TestCase): def test_getkey(self): """Tests for the 'getkey' template tag""" - data = { - 'hello': 'world', - 'foo': 'bar', - 'with spaces': 'withoutspaces', - 1: 2, - } + data = {'hello': 'world', 'foo': 'bar', 'with spaces': 'withoutspaces', 1: 2} for k, v in data.items(): self.assertEqual(report_tags.getkey(data, k), v) @@ -61,14 +56,14 @@ class ReportTagTest(TestCase): self.debug_mode(b) with self.assertRaises(FileNotFoundError): - report_tags.asset("bad_file.txt") + report_tags.asset('bad_file.txt') # Create an asset file asset_dir = settings.MEDIA_ROOT.joinpath('report', 'assets') asset_dir.mkdir(parents=True, exist_ok=True) asset_path = asset_dir.joinpath('test.txt') - asset_path.write_text("dummy data") + asset_path.write_text('dummy data') self.debug_mode(True) asset = report_tags.asset('test.txt') @@ -89,7 +84,9 @@ class ReportTagTest(TestCase): self.debug_mode(b) with self.assertRaises(FileNotFoundError): - report_tags.uploaded_image('/part/something/test.png', replace_missing=False) + report_tags.uploaded_image( + '/part/something/test.png', replace_missing=False + ) img = str(report_tags.uploaded_image('/part/something/other.png')) @@ -104,7 +101,7 @@ class ReportTagTest(TestCase): img_file = img_path.joinpath('test.jpg') img_path.mkdir(parents=True, exist_ok=True) - img_file.write_text("dummy data") + img_file.write_text('dummy data') # Test in debug mode. Returns blank image as dummy file is not a valid image self.debug_mode(True) @@ -161,7 +158,7 @@ class BarcodeTagTest(TestCase): def test_barcode(self): """Test the barcode generation tag""" - barcode = barcode_tags.barcode("12345") + barcode = barcode_tags.barcode('12345') self.assertTrue(isinstance(barcode, str)) self.assertTrue(barcode.startswith('data:image/png;')) @@ -174,17 +171,14 @@ class BarcodeTagTest(TestCase): def test_qrcode(self): """Test the qrcode generation tag""" # Test with default settings - qrcode = barcode_tags.qrcode("hello world") + qrcode = barcode_tags.qrcode('hello world') self.assertTrue(isinstance(qrcode, str)) self.assertTrue(qrcode.startswith('data:image/png;')) self.assertEqual(len(qrcode), 700) # Generate a much larger qrcode qrcode = barcode_tags.qrcode( - "hello_world", - version=2, - box_size=50, - format='BMP', + 'hello_world', version=2, box_size=50, format='BMP' ) self.assertTrue(isinstance(qrcode, str)) self.assertTrue(qrcode.startswith('data:image/bmp;')) @@ -193,6 +187,7 @@ class BarcodeTagTest(TestCase): class ReportTest(InvenTreeAPITestCase): """Base class for unit testing reporting models""" + fixtures = [ 'category', 'part', @@ -217,16 +212,9 @@ class ReportTest(InvenTreeAPITestCase): def copyReportTemplate(self, filename, description): """Copy the provided report template into the required media directory.""" - src_dir = Path(__file__).parent.joinpath( - 'templates', - 'report' - ) + src_dir = Path(__file__).parent.joinpath('templates', 'report') - template_dir = os.path.join( - 'report', - 'inventree', - self.model.getSubdir(), - ) + template_dir = os.path.join('report', 'inventree', self.model.getSubdir()) dst_dir = settings.MEDIA_ROOT.joinpath(template_dir) @@ -240,17 +228,14 @@ class ReportTest(InvenTreeAPITestCase): shutil.copyfile(src_file, dst_file) # Convert to an "internal" filename - db_filename = os.path.join( - template_dir, - filename - ) + db_filename = os.path.join(template_dir, filename) # Create a database entry for this report template! self.model.objects.create( name=os.path.splitext(filename)[0], description=description, template=db_filename, - enabled=True + enabled=True, ) def test_list_endpoint(self): @@ -310,6 +295,7 @@ class ReportTest(InvenTreeAPITestCase): class TestReportTest(ReportTest): """Unit testing class for the stock item TestReport model""" + model = report_models.TestReport list_url = 'api-stockitem-testreport-list' @@ -363,6 +349,7 @@ class TestReportTest(ReportTest): class BuildReportTest(ReportTest): """Unit test class for the BuildReport model""" + model = report_models.BuildReport list_url = 'api-build-report-list' @@ -398,21 +385,28 @@ class BuildReportTest(ReportTest): headers = response.headers self.assertEqual(headers['Content-Type'], 'application/pdf') - self.assertEqual(headers['Content-Disposition'], 'attachment; filename="report.pdf"') + self.assertEqual( + headers['Content-Disposition'], 'attachment; filename="report.pdf"' + ) # Now, set the download type to be "inline" - inline = InvenTreeUserSetting.get_setting_object('REPORT_INLINE', cache=False, user=self.user) + inline = InvenTreeUserSetting.get_setting_object( + 'REPORT_INLINE', cache=False, user=self.user + ) inline.value = True inline.save() response = self.get(url, {'build': 1}) headers = response.headers self.assertEqual(headers['Content-Type'], 'application/pdf') - self.assertEqual(headers['Content-Disposition'], 'inline; filename="report.pdf"') + self.assertEqual( + headers['Content-Disposition'], 'inline; filename="report.pdf"' + ) class BOMReportTest(ReportTest): """Unit test class for the BillOfMaterialsReport model""" + model = report_models.BillOfMaterialsReport list_url = 'api-bom-report-list' @@ -421,13 +415,16 @@ class BOMReportTest(ReportTest): def setUp(self): """Setup function for the bill of materials Report""" - self.copyReportTemplate('inventree_bill_of_materials_report.html', 'bill of materials report') + self.copyReportTemplate( + 'inventree_bill_of_materials_report.html', 'bill of materials report' + ) return super().setUp() class PurchaseOrderReportTest(ReportTest): """Unit test class for the PurchaseOrderReport model""" + model = report_models.PurchaseOrderReport list_url = 'api-po-report-list' @@ -443,6 +440,7 @@ class PurchaseOrderReportTest(ReportTest): class SalesOrderReportTest(ReportTest): """Unit test class for the SalesOrderReport model""" + model = report_models.SalesOrderReport list_url = 'api-so-report-list' @@ -466,7 +464,9 @@ class ReturnOrderReportTest(ReportTest): def setUp(self): """Setup function for the ReturnOrderReport tests""" - self.copyReportTemplate('inventree_return_order_report.html', 'return order report') + self.copyReportTemplate( + 'inventree_return_order_report.html', 'return order report' + ) return super().setUp() diff --git a/InvenTree/script/translation_stats.py b/InvenTree/script/translation_stats.py index e46d682fae..f7a848f042 100644 --- a/InvenTree/script/translation_stats.py +++ b/InvenTree/script/translation_stats.py @@ -15,11 +15,10 @@ def calculate_coverage(filename): lines_uncovered = 0 for line in lines: - - if line.startswith("msgid "): + if line.startswith('msgid '): lines_count += 1 - elif line.startswith("msgstr"): + elif line.startswith('msgstr'): if line.startswith('msgstr ""') or line.startswith("msgstr ''"): lines_uncovered += 1 else: @@ -30,10 +29,11 @@ def calculate_coverage(filename): if __name__ == '__main__': - MY_DIR = os.path.dirname(os.path.realpath(__file__)) LC_DIR = os.path.abspath(os.path.join(MY_DIR, '..', 'locale')) - STAT_FILE = os.path.abspath(os.path.join(MY_DIR, '..', 'InvenTree/locale_stats.json')) + STAT_FILE = os.path.abspath( + os.path.join(MY_DIR, '..', 'InvenTree/locale_stats.json') + ) locales = {} locales_perc = {} @@ -43,14 +43,13 @@ if __name__ == '__main__': for locale in os.listdir(LC_DIR): path = os.path.join(LC_DIR, locale) if os.path.exists(path) and os.path.isdir(path): - locale_file = os.path.join(path, 'LC_MESSAGES', 'django.po') if os.path.exists(locale_file) and os.path.isfile(locale_file): locales[locale] = locale_file if verbose: - print("-" * 16) + print('-' * 16) percentages = [] @@ -73,7 +72,7 @@ if __name__ == '__main__': percentages.append(percentage) if verbose: - print("-" * 16) + print('-' * 16) # write locale stats with open(STAT_FILE, 'w') as target: @@ -84,4 +83,4 @@ if __name__ == '__main__': else: avg = 0 - print(f"InvenTree translation coverage: {avg}%") + print(f'InvenTree translation coverage: {avg}%') diff --git a/InvenTree/stock/admin.py b/InvenTree/stock/admin.py index 939e074369..d31bf9e1a4 100644 --- a/InvenTree/stock/admin.py +++ b/InvenTree/stock/admin.py @@ -14,8 +14,14 @@ from InvenTree.admin import InvenTreeResource from order.models import PurchaseOrder, SalesOrder from part.models import Part -from .models import (StockItem, StockItemAttachment, StockItemTestResult, - StockItemTracking, StockLocation, StockLocationType) +from .models import ( + StockItem, + StockItemAttachment, + StockItemTestResult, + StockItemTracking, + StockLocation, + StockLocationType, +) class LocationResource(InvenTreeResource): @@ -31,21 +37,39 @@ class LocationResource(InvenTreeResource): exclude = [ # Exclude MPTT internal model fields - 'lft', 'rght', 'tree_id', 'level', + 'lft', + 'rght', + 'tree_id', + 'level', 'metadata', - 'barcode_data', 'barcode_hash', - 'owner', 'icon', + 'barcode_data', + 'barcode_hash', + 'owner', + 'icon', ] - id = Field(attribute='id', column_name=_('Location ID'), widget=widgets.IntegerWidget()) + id = Field( + attribute='id', column_name=_('Location ID'), widget=widgets.IntegerWidget() + ) name = Field(attribute='name', column_name=_('Location Name')) description = Field(attribute='description', column_name=_('Description')) - parent = Field(attribute='parent', column_name=_('Parent ID'), widget=widgets.ForeignKeyWidget(StockLocation)) - parent_name = Field(attribute='parent__name', column_name=_('Parent Name'), readonly=True) + parent = Field( + attribute='parent', + column_name=_('Parent ID'), + widget=widgets.ForeignKeyWidget(StockLocation), + ) + parent_name = Field( + attribute='parent__name', column_name=_('Parent Name'), readonly=True + ) pathstring = Field(attribute='pathstring', column_name=_('Location Path')) # Calculated fields - items = Field(attribute='item_count', column_name=_('Stock Items'), widget=widgets.IntegerWidget(), readonly=True) + items = Field( + attribute='item_count', + column_name=_('Stock Items'), + widget=widgets.IntegerWidget(), + readonly=True, + ) def after_import(self, dataset, result, using_transactions, dry_run, **kwargs): """Rebuild after import to keep tree intact.""" @@ -57,6 +81,7 @@ class LocationResource(InvenTreeResource): class LocationInline(admin.TabularInline): """Inline for sub-locations.""" + model = StockLocation @@ -69,25 +94,23 @@ class LocationAdmin(ImportExportModelAdmin): search_fields = ('name', 'description') - inlines = [ - LocationInline, - ] + inlines = [LocationInline] - autocomplete_fields = [ - 'parent', - ] + autocomplete_fields = ['parent'] class LocationTypeAdmin(admin.ModelAdmin): """Admin class for StockLocationType.""" list_display = ('name', 'description', 'icon', 'location_count') - readonly_fields = ('location_count', ) + readonly_fields = ('location_count',) def get_queryset(self, request): """Annotate queryset to fetch location count.""" - return super().get_queryset(request).annotate( - location_count=Count("stock_locations"), + return ( + super() + .get_queryset(request) + .annotate(location_count=Count('stock_locations')) ) def location_count(self, obj): @@ -108,45 +131,129 @@ class StockItemResource(InvenTreeResource): exclude = [ # Exclude MPTT internal model fields - 'lft', 'rght', 'tree_id', 'level', + 'lft', + 'rght', + 'tree_id', + 'level', # Exclude internal fields - 'serial_int', 'metadata', - 'barcode_hash', 'barcode_data', + 'serial_int', + 'metadata', + 'barcode_hash', + 'barcode_data', 'owner', ] - id = Field(attribute='pk', column_name=_('Stock Item ID'), widget=widgets.IntegerWidget()) - part = Field(attribute='part', column_name=_('Part ID'), widget=widgets.ForeignKeyWidget(Part)) - part_name = Field(attribute='part__full_name', column_name=_('Part Name'), readonly=True) - quantity = Field(attribute='quantity', column_name=_('Quantity'), widget=widgets.DecimalWidget()) + id = Field( + attribute='pk', column_name=_('Stock Item ID'), widget=widgets.IntegerWidget() + ) + part = Field( + attribute='part', + column_name=_('Part ID'), + widget=widgets.ForeignKeyWidget(Part), + ) + part_name = Field( + attribute='part__full_name', column_name=_('Part Name'), readonly=True + ) + quantity = Field( + attribute='quantity', column_name=_('Quantity'), widget=widgets.DecimalWidget() + ) serial = Field(attribute='serial', column_name=_('Serial')) batch = Field(attribute='batch', column_name=_('Batch')) - status_label = Field(attribute='status_label', column_name=_('Status'), readonly=True) - status = Field(attribute='status', column_name=_('Status Code'), widget=widgets.IntegerWidget()) - location = Field(attribute='location', column_name=_('Location ID'), widget=widgets.ForeignKeyWidget(StockLocation)) - location_name = Field(attribute='location__name', column_name=_('Location Name'), readonly=True) - supplier_part = Field(attribute='supplier_part', column_name=_('Supplier Part ID'), widget=widgets.ForeignKeyWidget(SupplierPart)) - supplier = Field(attribute='supplier_part__supplier__id', column_name=_('Supplier ID'), readonly=True, widget=widgets.IntegerWidget()) - supplier_name = Field(attribute='supplier_part__supplier__name', column_name=_('Supplier Name'), readonly=True) - customer = Field(attribute='customer', column_name=_('Customer ID'), widget=widgets.ForeignKeyWidget(Company)) - belongs_to = Field(attribute='belongs_to', column_name=_('Installed In'), widget=widgets.ForeignKeyWidget(StockItem)) - build = Field(attribute='build', column_name=_('Build ID'), widget=widgets.ForeignKeyWidget(Build)) - parent = Field(attribute='parent', column_name=_('Parent ID'), widget=widgets.ForeignKeyWidget(StockItem)) - sales_order = Field(attribute='sales_order', column_name=_('Sales Order ID'), widget=widgets.ForeignKeyWidget(SalesOrder)) - purchase_order = Field(attribute='purchase_order', column_name=_('Purchase Order ID'), widget=widgets.ForeignKeyWidget(PurchaseOrder)) + status_label = Field( + attribute='status_label', column_name=_('Status'), readonly=True + ) + status = Field( + attribute='status', column_name=_('Status Code'), widget=widgets.IntegerWidget() + ) + location = Field( + attribute='location', + column_name=_('Location ID'), + widget=widgets.ForeignKeyWidget(StockLocation), + ) + location_name = Field( + attribute='location__name', column_name=_('Location Name'), readonly=True + ) + supplier_part = Field( + attribute='supplier_part', + column_name=_('Supplier Part ID'), + widget=widgets.ForeignKeyWidget(SupplierPart), + ) + supplier = Field( + attribute='supplier_part__supplier__id', + column_name=_('Supplier ID'), + readonly=True, + widget=widgets.IntegerWidget(), + ) + supplier_name = Field( + attribute='supplier_part__supplier__name', + column_name=_('Supplier Name'), + readonly=True, + ) + customer = Field( + attribute='customer', + column_name=_('Customer ID'), + widget=widgets.ForeignKeyWidget(Company), + ) + belongs_to = Field( + attribute='belongs_to', + column_name=_('Installed In'), + widget=widgets.ForeignKeyWidget(StockItem), + ) + build = Field( + attribute='build', + column_name=_('Build ID'), + widget=widgets.ForeignKeyWidget(Build), + ) + parent = Field( + attribute='parent', + column_name=_('Parent ID'), + widget=widgets.ForeignKeyWidget(StockItem), + ) + sales_order = Field( + attribute='sales_order', + column_name=_('Sales Order ID'), + widget=widgets.ForeignKeyWidget(SalesOrder), + ) + purchase_order = Field( + attribute='purchase_order', + column_name=_('Purchase Order ID'), + widget=widgets.ForeignKeyWidget(PurchaseOrder), + ) packaging = Field(attribute='packaging', column_name=_('Packaging')) link = Field(attribute='link', column_name=_('Link')) notes = Field(attribute='notes', column_name=_('Notes')) # Status fields (note that IntegerWidget exports better to excel than BooleanWidget) - is_building = Field(attribute='is_building', column_name=_('Building'), widget=widgets.IntegerWidget()) - review_needed = Field(attribute='review_needed', column_name=_('Review Needed'), widget=widgets.IntegerWidget()) - delete_on_deplete = Field(attribute='delete_on_deplete', column_name=_('Delete on Deplete'), widget=widgets.IntegerWidget()) + is_building = Field( + attribute='is_building', + column_name=_('Building'), + widget=widgets.IntegerWidget(), + ) + review_needed = Field( + attribute='review_needed', + column_name=_('Review Needed'), + widget=widgets.IntegerWidget(), + ) + delete_on_deplete = Field( + attribute='delete_on_deplete', + column_name=_('Delete on Deplete'), + widget=widgets.IntegerWidget(), + ) # Date management - updated = Field(attribute='updated', column_name=_('Last Updated'), widget=widgets.DateWidget()) - stocktake_date = Field(attribute='stocktake_date', column_name=_('Stocktake'), widget=widgets.DateWidget()) - expiry_date = Field(attribute='expiry_date', column_name=_('Expiry Date'), widget=widgets.DateWidget()) + updated = Field( + attribute='updated', column_name=_('Last Updated'), widget=widgets.DateWidget() + ) + stocktake_date = Field( + attribute='stocktake_date', + column_name=_('Stocktake'), + widget=widgets.DateWidget(), + ) + expiry_date = Field( + attribute='expiry_date', + column_name=_('Expiry Date'), + widget=widgets.DateWidget(), + ) def dehydrate_purchase_price(self, item): """Render purchase pric as float""" @@ -169,12 +276,7 @@ class StockItemAdmin(ImportExportModelAdmin): list_display = ('part', 'quantity', 'location', 'status', 'updated') # A list of search fields which can be used for lookup on matching 'autocomplete' fields - search_fields = [ - 'part__name', - 'part__description', - 'serial', - 'batch', - ] + search_fields = ['part__name', 'part__description', 'serial', 'batch'] autocomplete_fields = [ 'belongs_to', @@ -195,9 +297,7 @@ class StockAttachmentAdmin(admin.ModelAdmin): list_display = ('stock_item', 'attachment', 'comment') - autocomplete_fields = [ - 'stock_item', - ] + autocomplete_fields = ['stock_item'] class StockTrackingAdmin(ImportExportModelAdmin): @@ -205,9 +305,7 @@ class StockTrackingAdmin(ImportExportModelAdmin): list_display = ('item', 'date', 'label') - autocomplete_fields = [ - 'item', - ] + autocomplete_fields = ['item'] class StockItemTestResultAdmin(admin.ModelAdmin): @@ -215,9 +313,7 @@ class StockItemTestResultAdmin(admin.ModelAdmin): list_display = ('stock_item', 'test', 'result', 'value') - autocomplete_fields = [ - 'stock_item', - ] + autocomplete_fields = ['stock_item'] admin.site.register(StockLocation, LocationAdmin) diff --git a/InvenTree/stock/api.py b/InvenTree/stock/api.py index a6b5bf6f38..a01e4e9fbf 100644 --- a/InvenTree/stock/api.py +++ b/InvenTree/stock/api.py @@ -23,25 +23,51 @@ from build.serializers import BuildSerializer from company.models import Company, SupplierPart from company.serializers import CompanySerializer from generic.states.api import StatusView -from InvenTree.api import (APIDownloadMixin, AttachmentMixin, - ListCreateDestroyAPIView, MetadataView) -from InvenTree.filters import (ORDER_FILTER, SEARCH_ORDER_FILTER, - SEARCH_ORDER_FILTER_ALIAS, InvenTreeDateFilter) -from InvenTree.helpers import (DownloadFile, extract_serial_numbers, isNull, - str2bool, str2int) -from InvenTree.mixins import (CreateAPI, CustomRetrieveUpdateDestroyAPI, - ListAPI, ListCreateAPI, RetrieveAPI, - RetrieveUpdateDestroyAPI) +from InvenTree.api import ( + APIDownloadMixin, + AttachmentMixin, + ListCreateDestroyAPIView, + MetadataView, +) +from InvenTree.filters import ( + ORDER_FILTER, + SEARCH_ORDER_FILTER, + SEARCH_ORDER_FILTER_ALIAS, + InvenTreeDateFilter, +) +from InvenTree.helpers import ( + DownloadFile, + extract_serial_numbers, + isNull, + str2bool, + str2int, +) +from InvenTree.mixins import ( + CreateAPI, + CustomRetrieveUpdateDestroyAPI, + ListAPI, + ListCreateAPI, + RetrieveAPI, + RetrieveUpdateDestroyAPI, +) from InvenTree.status_codes import StockHistoryCode, StockStatus -from order.models import (PurchaseOrder, ReturnOrder, SalesOrder, - SalesOrderAllocation) -from order.serializers import (PurchaseOrderSerializer, ReturnOrderSerializer, - SalesOrderSerializer) +from order.models import PurchaseOrder, ReturnOrder, SalesOrder, SalesOrderAllocation +from order.serializers import ( + PurchaseOrderSerializer, + ReturnOrderSerializer, + SalesOrderSerializer, +) from part.models import BomItem, Part, PartCategory from part.serializers import PartBriefSerializer from stock.admin import LocationResource, StockItemResource -from stock.models import (StockItem, StockItemAttachment, StockItemTestResult, - StockItemTracking, StockLocation, StockLocationType) +from stock.models import ( + StockItem, + StockItemAttachment, + StockItemTestResult, + StockItemTracking, + StockLocation, + StockLocationType, +) class StockDetail(RetrieveUpdateDestroyAPI): @@ -83,7 +109,9 @@ class StockDetail(RetrieveUpdateDestroyAPI): kwargs['part_detail'] = str2bool(params.get('part_detail', True)) kwargs['location_detail'] = str2bool(params.get('location_detail', True)) - kwargs['supplier_part_detail'] = str2bool(params.get('supplier_part_detail', True)) + kwargs['supplier_part_detail'] = str2bool( + params.get('supplier_part_detail', True) + ) kwargs['path_detail'] = str2bool(params.get('path_detail', False)) except AttributeError: pass @@ -226,11 +254,12 @@ class StockLocationFilter(rest_filters.FilterSet): """Base class for custom API filters for the StockLocation endpoint.""" location_type = rest_filters.ModelChoiceFilter( - queryset=StockLocationType.objects.all(), - field_name='location_type' + queryset=StockLocationType.objects.all(), field_name='location_type' ) - has_location_type = rest_filters.BooleanFilter(label='has_location_type', method='filter_has_location_type') + has_location_type = rest_filters.BooleanFilter( + label='has_location_type', method='filter_has_location_type' + ) def filter_has_location_type(self, queryset, name, value): """Filter by whether or not the location has a location type""" @@ -246,9 +275,7 @@ class StockLocationList(APIDownloadMixin, ListCreateAPI): - POST: Create a new StockLocation """ - queryset = StockLocation.objects.all().prefetch_related( - 'tags', - ) + queryset = StockLocation.objects.all().prefetch_related('tags') serializer_class = StockSerializers.LocationSerializer filterset_class = StockLocationFilter @@ -256,7 +283,7 @@ class StockLocationList(APIDownloadMixin, ListCreateAPI): """Download the filtered queryset as a data file""" dataset = LocationResource().export(queryset=queryset) filedata = dataset.export(export_format) - filename = f"InvenTree_Locations.{export_format}" + filename = f'InvenTree_Locations.{export_format}' return DownloadFile(filedata, filename) @@ -283,7 +310,6 @@ class StockLocationList(APIDownloadMixin, ListCreateAPI): pass # Look for top-level locations elif isNull(loc_id): - # If we allow "cascade" at the top-level, this essentially means *all* locations if not cascade: queryset = queryset.filter(parent=None) @@ -292,7 +318,6 @@ class StockLocationList(APIDownloadMixin, ListCreateAPI): queryset = queryset.filter(level__lte=depth) else: - try: location = StockLocation.objects.get(pk=loc_id) @@ -319,7 +344,9 @@ class StockLocationList(APIDownloadMixin, ListCreateAPI): loc = StockLocation.objects.get(pk=exclude_tree) queryset = queryset.exclude( - pk__in=[subloc.pk for subloc in loc.get_descendants(include_self=True)] + pk__in=[ + subloc.pk for subloc in loc.get_descendants(include_self=True) + ] ) except (ValueError, StockLocation.DoesNotExist): @@ -329,35 +356,13 @@ class StockLocationList(APIDownloadMixin, ListCreateAPI): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'name', - 'structural', - 'external', - 'tags__name', - 'tags__slug', - ] + filterset_fields = ['name', 'structural', 'external', 'tags__name', 'tags__slug'] - search_fields = [ - 'name', - 'description', - 'tags__name', - 'tags__slug', - ] + search_fields = ['name', 'description', 'tags__name', 'tags__slug'] - ordering_fields = [ - 'name', - 'pathstring', - 'items', - 'level', - 'tree_id', - 'lft', - ] + ordering_fields = ['name', 'pathstring', 'items', 'level', 'tree_id', 'lft'] - ordering = [ - 'tree_id', - 'lft', - 'name', - ] + ordering = ['tree_id', 'lft', 'name'] class StockLocationTree(ListAPI): @@ -384,24 +389,18 @@ class StockLocationTypeList(ListCreateAPI): filter_backends = SEARCH_ORDER_FILTER - ordering_fields = [ - "name", - "location_count", - "icon", - ] + ordering_fields = ['name', 'location_count', 'icon'] - ordering = [ - "-location_count", - ] + ordering = ['-location_count'] - search_fields = [ - "name", - ] + search_fields = ['name'] def get_queryset(self): """Override the queryset method to include location count.""" queryset = super().get_queryset() - queryset = StockSerializers.StockLocationTypeSerializer.annotate_queryset(queryset) + queryset = StockSerializers.StockLocationTypeSerializer.annotate_queryset( + queryset + ) return queryset @@ -421,7 +420,9 @@ class StockLocationTypeDetail(RetrieveUpdateDestroyAPI): def get_queryset(self): """Override the queryset method to include location count.""" queryset = super().get_queryset() - queryset = StockSerializers.StockLocationTypeSerializer.annotate_queryset(queryset) + queryset = StockSerializers.StockLocationTypeSerializer.annotate_queryset( + queryset + ) return queryset @@ -448,26 +449,58 @@ class StockFilter(rest_filters.FilterSet): ] # Relationship filters - manufacturer = rest_filters.ModelChoiceFilter(label='Manufacturer', queryset=Company.objects.filter(is_manufacturer=True), field_name='manufacturer_part__manufacturer') - supplier = rest_filters.ModelChoiceFilter(label='Supplier', queryset=Company.objects.filter(is_supplier=True), field_name='supplier_part__supplier') + manufacturer = rest_filters.ModelChoiceFilter( + label='Manufacturer', + queryset=Company.objects.filter(is_manufacturer=True), + field_name='manufacturer_part__manufacturer', + ) + supplier = rest_filters.ModelChoiceFilter( + label='Supplier', + queryset=Company.objects.filter(is_supplier=True), + field_name='supplier_part__supplier', + ) # Part name filters - name = rest_filters.CharFilter(label='Part name (case insensitive)', field_name='part__name', lookup_expr='iexact') - name_contains = rest_filters.CharFilter(label='Part name contains (case insensitive)', field_name='part__name', lookup_expr='icontains') - name_regex = rest_filters.CharFilter(label='Part name (regex)', field_name='part__name', lookup_expr='iregex') + name = rest_filters.CharFilter( + label='Part name (case insensitive)', + field_name='part__name', + lookup_expr='iexact', + ) + name_contains = rest_filters.CharFilter( + label='Part name contains (case insensitive)', + field_name='part__name', + lookup_expr='icontains', + ) + name_regex = rest_filters.CharFilter( + label='Part name (regex)', field_name='part__name', lookup_expr='iregex' + ) # Part IPN filters - IPN = rest_filters.CharFilter(label='Part IPN (case insensitive)', field_name='part__IPN', lookup_expr='iexact') - IPN_contains = rest_filters.CharFilter(label='Part IPN contains (case insensitive)', field_name='part__IPN', lookup_expr='icontains') - IPN_regex = rest_filters.CharFilter(label='Part IPN (regex)', field_name='part__IPN', lookup_expr='iregex') + IPN = rest_filters.CharFilter( + label='Part IPN (case insensitive)', + field_name='part__IPN', + lookup_expr='iexact', + ) + IPN_contains = rest_filters.CharFilter( + label='Part IPN contains (case insensitive)', + field_name='part__IPN', + lookup_expr='icontains', + ) + IPN_regex = rest_filters.CharFilter( + label='Part IPN (regex)', field_name='part__IPN', lookup_expr='iregex' + ) # Part attribute filters - assembly = rest_filters.BooleanFilter(label="Assembly", field_name='part__assembly') - active = rest_filters.BooleanFilter(label="Active", field_name='part__active') - salable = rest_filters.BooleanFilter(label="Salable", field_name='part__salable') + assembly = rest_filters.BooleanFilter(label='Assembly', field_name='part__assembly') + active = rest_filters.BooleanFilter(label='Active', field_name='part__active') + salable = rest_filters.BooleanFilter(label='Salable', field_name='part__salable') - min_stock = rest_filters.NumberFilter(label='Minimum stock', field_name='quantity', lookup_expr='gte') - max_stock = rest_filters.NumberFilter(label='Maximum stock', field_name='quantity', lookup_expr='lte') + min_stock = rest_filters.NumberFilter( + label='Minimum stock', field_name='quantity', lookup_expr='gte' + ) + max_stock = rest_filters.NumberFilter( + label='Maximum stock', field_name='quantity', lookup_expr='lte' + ) status = rest_filters.NumberFilter(label='Status Code', method='filter_status') @@ -475,15 +508,21 @@ class StockFilter(rest_filters.FilterSet): """Filter by integer status code""" return queryset.filter(status=value) - allocated = rest_filters.BooleanFilter(label='Is Allocated', method='filter_allocated') + allocated = rest_filters.BooleanFilter( + label='Is Allocated', method='filter_allocated' + ) def filter_allocated(self, queryset, name, value): """Filter by whether or not the stock item is 'allocated'""" if str2bool(value): # Filter StockItem with either build allocations or sales order allocations - return queryset.filter(Q(sales_order_allocations__isnull=False) | Q(allocations__isnull=False)).distinct() + return queryset.filter( + Q(sales_order_allocations__isnull=False) | Q(allocations__isnull=False) + ).distinct() # Filter StockItem without build allocations or sales order allocations - return queryset.filter(Q(sales_order_allocations__isnull=True) & Q(allocations__isnull=True)) + return queryset.filter( + Q(sales_order_allocations__isnull=True) & Q(allocations__isnull=True) + ) expired = rest_filters.BooleanFilter(label='Expired', method='filter_expired') @@ -496,7 +535,9 @@ class StockFilter(rest_filters.FilterSet): return queryset.filter(StockItem.EXPIRED_FILTER) return queryset.exclude(StockItem.EXPIRED_FILTER) - external = rest_filters.BooleanFilter(label=_('External Location'), method='filter_external') + external = rest_filters.BooleanFilter( + label=_('External Location'), method='filter_external' + ) def filter_external(self, queryset, name, value): """Filter by whether or not the stock item is located in an external location""" @@ -522,23 +563,37 @@ class StockFilter(rest_filters.FilterSet): if str2bool(value): # The 'quantity' field is greater than the calculated 'allocated' field # Note that the item must also be "in stock" - return queryset.filter(StockItem.IN_STOCK_FILTER).filter(Q(quantity__gt=F('allocated'))) + return queryset.filter(StockItem.IN_STOCK_FILTER).filter( + Q(quantity__gt=F('allocated')) + ) # The 'quantity' field is less than (or equal to) the calculated 'allocated' field return queryset.filter(Q(quantity__lte=F('allocated'))) - batch = rest_filters.CharFilter(label="Batch code filter (case insensitive)", lookup_expr='iexact') + batch = rest_filters.CharFilter( + label='Batch code filter (case insensitive)', lookup_expr='iexact' + ) - batch_regex = rest_filters.CharFilter(label="Batch code filter (regex)", field_name='batch', lookup_expr='iregex') + batch_regex = rest_filters.CharFilter( + label='Batch code filter (regex)', field_name='batch', lookup_expr='iregex' + ) - is_building = rest_filters.BooleanFilter(label="In production") + is_building = rest_filters.BooleanFilter(label='In production') # Serial number filtering - serial_gte = rest_filters.NumberFilter(label='Serial number GTE', field_name='serial_int', lookup_expr='gte') - serial_lte = rest_filters.NumberFilter(label='Serial number LTE', field_name='serial_int', lookup_expr='lte') + serial_gte = rest_filters.NumberFilter( + label='Serial number GTE', field_name='serial_int', lookup_expr='gte' + ) + serial_lte = rest_filters.NumberFilter( + label='Serial number LTE', field_name='serial_int', lookup_expr='lte' + ) - serial = rest_filters.CharFilter(label='Serial number', field_name='serial', lookup_expr='exact') + serial = rest_filters.CharFilter( + label='Serial number', field_name='serial', lookup_expr='exact' + ) - serialized = rest_filters.BooleanFilter(label='Has serial number', method='filter_serialized') + serialized = rest_filters.BooleanFilter( + label='Has serial number', method='filter_serialized' + ) def filter_serialized(self, queryset, name, value): """Filter by whether the StockItem has a serial number (or not).""" @@ -549,7 +604,9 @@ class StockFilter(rest_filters.FilterSet): return queryset.filter(q).distinct() - has_batch = rest_filters.BooleanFilter(label='Has batch code', method='filter_has_batch') + has_batch = rest_filters.BooleanFilter( + label='Has batch code', method='filter_has_batch' + ) def filter_has_batch(self, queryset, name, value): """Filter by whether the StockItem has a batch code (or not).""" @@ -577,7 +634,9 @@ class StockFilter(rest_filters.FilterSet): return queryset.filter(q_batch).filter(q_serial).distinct() - installed = rest_filters.BooleanFilter(label='Installed in other stock item', method='filter_installed') + installed = rest_filters.BooleanFilter( + label='Installed in other stock item', method='filter_installed' + ) def filter_installed(self, queryset, name, value): """Filter stock items by "belongs_to" field being empty.""" @@ -585,7 +644,9 @@ class StockFilter(rest_filters.FilterSet): return queryset.exclude(belongs_to=None) return queryset.filter(belongs_to=None) - has_installed_items = rest_filters.BooleanFilter(label='Has installed items', method='filter_has_installed') + has_installed_items = rest_filters.BooleanFilter( + label='Has installed items', method='filter_has_installed' + ) def filter_has_installed(self, queryset, name, value): """Filter stock items by "belongs_to" field being empty.""" @@ -593,7 +654,9 @@ class StockFilter(rest_filters.FilterSet): return queryset.filter(installed_items__gt=0) return queryset.filter(installed_items=0) - sent_to_customer = rest_filters.BooleanFilter(label='Sent to customer', method='filter_sent_to_customer') + sent_to_customer = rest_filters.BooleanFilter( + label='Sent to customer', method='filter_sent_to_customer' + ) def filter_sent_to_customer(self, queryset, name, value): """Filter by sent to customer.""" @@ -609,7 +672,9 @@ class StockFilter(rest_filters.FilterSet): return queryset.filter(quantity__lte=0) return queryset.exclude(quantity__lte=0) - has_purchase_price = rest_filters.BooleanFilter(label='Has purchase price', method='filter_has_purchase_price') + has_purchase_price = rest_filters.BooleanFilter( + label='Has purchase price', method='filter_has_purchase_price' + ) def filter_has_purchase_price(self, queryset, name, value): """Filter by having a purchase price.""" @@ -618,21 +683,17 @@ class StockFilter(rest_filters.FilterSet): return queryset.filter(purchase_price=None) ancestor = rest_filters.ModelChoiceFilter( - label='Ancestor', - queryset=StockItem.objects.all(), - method='filter_ancestor' + label='Ancestor', queryset=StockItem.objects.all(), method='filter_ancestor' ) def filter_ancestor(self, queryset, name, ancestor): """Filter based on ancestor stock item""" - return queryset.filter( - parent__in=ancestor.get_descendants(include_self=True) - ) + return queryset.filter(parent__in=ancestor.get_descendants(include_self=True)) category = rest_filters.ModelChoiceFilter( label=_('Category'), queryset=PartCategory.objects.all(), - method='filter_category' + method='filter_category', ) def filter_category(self, queryset, name, category): @@ -640,14 +701,10 @@ class StockFilter(rest_filters.FilterSet): child_categories = category.get_descendants(include_self=True) - return queryset.filter( - part__category__in=child_categories, - ) + return queryset.filter(part__category__in=child_categories) bom_item = rest_filters.ModelChoiceFilter( - label=_('BOM Item'), - queryset=BomItem.objects.all(), - method='filter_bom_item' + label=_('BOM Item'), queryset=BomItem.objects.all(), method='filter_bom_item' ) def filter_bom_item(self, queryset, name, bom_item): @@ -656,44 +713,39 @@ class StockFilter(rest_filters.FilterSet): return queryset.filter(bom_item.get_stock_filter()) part_tree = rest_filters.ModelChoiceFilter( - label=_('Part Tree'), - queryset=Part.objects.all(), - method='filter_part_tree' + label=_('Part Tree'), queryset=Part.objects.all(), method='filter_part_tree' ) def filter_part_tree(self, queryset, name, part_tree): """Filter based on part tree""" - return queryset.filter( - part__tree_id=part_tree.tree_id - ) + return queryset.filter(part__tree_id=part_tree.tree_id) company = rest_filters.ModelChoiceFilter( - label=_('Company'), - queryset=Company.objects.all(), - method='filter_company' + label=_('Company'), queryset=Company.objects.all(), method='filter_company' ) def filter_company(self, queryset, name, company): """Filter by company (either manufacturer or supplier)""" return queryset.filter( - Q(supplier_part__supplier=company) | Q(supplier_part__manufacturer_part__manufacturer=company) + Q(supplier_part__supplier=company) + | Q(supplier_part__manufacturer_part__manufacturer=company) ).distinct() # Update date filters - updated_before = InvenTreeDateFilter(label='Updated before', field_name='updated', lookup_expr='lte') - updated_after = InvenTreeDateFilter(label='Updated after', field_name='updated', lookup_expr='gte') + updated_before = InvenTreeDateFilter( + label='Updated before', field_name='updated', lookup_expr='lte' + ) + updated_after = InvenTreeDateFilter( + label='Updated after', field_name='updated', lookup_expr='gte' + ) # Stock "expiry" filters expiry_date_lte = InvenTreeDateFilter( - label=_("Expiry date before"), - field_name='expiry_date', - lookup_expr='lte', + label=_('Expiry date before'), field_name='expiry_date', lookup_expr='lte' ) expiry_date_gte = InvenTreeDateFilter( - label=_('Expiry date after'), - field_name='expiry_date', - lookup_expr='gte', + label=_('Expiry date after'), field_name='expiry_date', lookup_expr='gte' ) stale = rest_filters.BooleanFilter(label=_('Stale'), method='filter_stale') @@ -708,7 +760,11 @@ class StockFilter(rest_filters.FilterSet): return queryset stale_date = datetime.now().date() + timedelta(days=stale_days) - stale_filter = StockItem.IN_STOCK_FILTER & ~Q(expiry_date=None) & Q(expiry_date__lt=stale_date) + stale_filter = ( + StockItem.IN_STOCK_FILTER + & ~Q(expiry_date=None) + & Q(expiry_date__lt=stale_date) + ) if str2bool(value): return queryset.filter(stale_filter) @@ -740,7 +796,12 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): try: params = self.request.query_params - for key in ['part_detail', 'location_detail', 'supplier_part_detail', 'tests']: + for key in [ + 'part_detail', + 'location_detail', + 'supplier_part_detail', + 'tests', + ]: kwargs[key] = str2bool(params.get(key, False)) except AttributeError: pass @@ -774,17 +835,13 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): quantity = data.get('quantity', None) if quantity is None: - raise ValidationError({ - 'quantity': _('Quantity is required'), - }) + raise ValidationError({'quantity': _('Quantity is required')}) try: Part.objects.prefetch_related(None) part = Part.objects.get(pk=data.get('part', None)) except (ValueError, Part.DoesNotExist): - raise ValidationError({ - 'part': _('Valid part must be supplied'), - }) + raise ValidationError({'part': _('Valid part must be supplied')}) # Set default location (if not provided) if 'location' not in data: @@ -797,7 +854,9 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): # An expiry date was *not* specified - try to infer it! if expiry_date is None and part.default_expiry > 0: - data['expiry_date'] = datetime.now().date() + timedelta(days=part.default_expiry) + data['expiry_date'] = datetime.now().date() + timedelta( + days=part.default_expiry + ) # Attempt to extract serial numbers from submitted data serials = None @@ -808,10 +867,12 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): # Check if the supplier_part has a package size defined, which is not 1 if 'supplier_part' in data and data['supplier_part'] is not None: try: - supplier_part = SupplierPart.objects.get(pk=data.get('supplier_part', None)) + supplier_part = SupplierPart.objects.get( + pk=data.get('supplier_part', None) + ) except (ValueError, SupplierPart.DoesNotExist): raise ValidationError({ - 'supplier_part': _('The given supplier part does not exist'), + 'supplier_part': _('The given supplier part does not exist') }) if supplier_part.base_quantity() != 1: @@ -820,16 +881,25 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): # use_pack_size = False -> Use quantity as is if 'use_pack_size' not in data: raise ValidationError({ - 'use_pack_size': _('The supplier part has a pack size defined, but flag use_pack_size not set'), + 'use_pack_size': _( + 'The supplier part has a pack size defined, but flag use_pack_size not set' + ) }) else: if bool(data.get('use_pack_size')): - quantity = data['quantity'] = supplier_part.base_quantity(quantity) + quantity = data['quantity'] = supplier_part.base_quantity( + quantity + ) # Divide purchase price by pack size, to save correct price per stock item - if data['purchase_price'] and supplier_part.pack_quantity_native: + if ( + data['purchase_price'] + and supplier_part.pack_quantity_native + ): try: - data['purchase_price'] = float(data['purchase_price']) / float(supplier_part.pack_quantity_native) + data['purchase_price'] = float( + data['purchase_price'] + ) / float(supplier_part.pack_quantity_native) except ValueError: pass @@ -840,18 +910,17 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): # Assign serial numbers for a trackable part if serial_numbers: - if not part.trackable: raise ValidationError({ - 'serial_numbers': [_("Serial numbers cannot be supplied for a non-trackable part")] + 'serial_numbers': [ + _('Serial numbers cannot be supplied for a non-trackable part') + ] }) # If serial numbers are specified, check that they match! try: serials = extract_serial_numbers( - serial_numbers, - quantity, - part.get_latest_serial_number() + serial_numbers, quantity, part.get_latest_serial_number() ) # Determine if any of the specified serial numbers are invalid @@ -870,14 +939,11 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): errors.append(exc.message) if len(errors) > 0: + msg = _('The following serial numbers already exist or are invalid') + msg += ' : ' + msg += ','.join([str(e) for e in invalid]) - msg = _("The following serial numbers already exist or are invalid") - msg += " : " - msg += ",".join([str(e) for e in invalid]) - - raise ValidationError({ - 'serial_numbers': errors + [msg] - }) + raise ValidationError({'serial_numbers': errors + [msg]}) except DjangoValidationError as e: raise ValidationError({ @@ -894,7 +960,6 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): serializer.is_valid(raise_exception=True) with transaction.atomic(): - # Create an initial StockItem object item = serializer.save() @@ -907,22 +972,22 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): if serials: for serial in serials[1:]: - # Create a duplicate stock item with the next serial number item.pk = None item.serial = serial item.save(user=user) - response_data = { - 'quantity': quantity, - 'serial_numbers': serials, - } + response_data = {'quantity': quantity, 'serial_numbers': serials} else: response_data = serializer.data - return Response(response_data, status=status.HTTP_201_CREATED, headers=self.get_success_headers(serializer.data)) + return Response( + response_data, + status=status.HTTP_201_CREATED, + headers=self.get_success_headers(serializer.data), + ) def download_queryset(self, queryset, export_format): """Download this queryset as a file. @@ -999,23 +1064,16 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): exclude_so_allocation = params.get('exclude_so_allocation', None) if exclude_so_allocation is not None: - try: order = SalesOrder.objects.get(pk=exclude_so_allocation) # Grab all the active SalesOrderAllocations for this order allocations = SalesOrderAllocation.objects.filter( - line__pk__in=[ - line.pk for line in order.lines.all() - ] + line__pk__in=[line.pk for line in order.lines.all()] ) # Exclude any stock item which is already allocated to the sales order - queryset = queryset.exclude( - pk__in=[ - a.item.pk for a in allocations - ] - ) + queryset = queryset.exclude(pk__in=[a.item.pk for a in allocations]) except (ValueError, SalesOrder.DoesNotExist): pass @@ -1040,7 +1098,7 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): queryset = queryset.filter(part=part) except (ValueError, Part.DoesNotExist): - raise ValidationError({"part": "Invalid Part ID specified"}) + raise ValidationError({'part': 'Invalid Part ID specified'}) # Does the client wish to filter by stock location? loc_id = params.get('location', None) @@ -1048,7 +1106,6 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): cascade = str2bool(params.get('cascade', True)) if loc_id is not None: - # Filter by 'null' location (i.e. top-level items) if isNull(loc_id): if not cascade: @@ -1058,7 +1115,9 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): # If '?cascade=true' then include items which exist in sub-locations if cascade: location = StockLocation.objects.get(pk=loc_id) - queryset = queryset.filter(location__in=location.getUniqueChildren()) + queryset = queryset.filter( + location__in=location.getUniqueChildren() + ) else: queryset = queryset.filter(location=loc_id) @@ -1089,11 +1148,7 @@ class StockList(APIDownloadMixin, ListCreateDestroyAPIView): 'SKU', ] - ordering = [ - 'part__name', - 'quantity', - 'location', - ] + ordering = ['part__name', 'quantity', 'location'] search_fields = [ 'serial', @@ -1113,9 +1168,7 @@ class StockAttachmentList(AttachmentMixin, ListCreateDestroyAPIView): queryset = StockItemAttachment.objects.all() serializer_class = StockSerializers.StockItemAttachmentSerializer - filterset_fields = [ - 'stock_item', - ] + filterset_fields = ['stock_item'] class StockAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): @@ -1140,12 +1193,7 @@ class StockItemTestResultList(ListCreateDestroyAPIView): filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'test', - 'user', - 'result', - 'value', - ] + filterset_fields = ['test', 'user', 'result', 'value'] ordering = 'date' @@ -1159,7 +1207,6 @@ class StockItemTestResultList(ListCreateDestroyAPIView): build = params.get('build', None) if build is not None: - try: build = Build.objects.get(pk=build) @@ -1197,7 +1244,9 @@ class StockItemTestResultList(ListCreateDestroyAPIView): def get_serializer(self, *args, **kwargs): """Set context before returning serializer.""" try: - kwargs['user_detail'] = str2bool(self.request.query_params.get('user_detail', False)) + kwargs['user_detail'] = str2bool( + self.request.query_params.get('user_detail', False) + ) except Exception: pass @@ -1239,12 +1288,16 @@ class StockTrackingList(ListAPI): def get_serializer(self, *args, **kwargs): """Set context before returning serializer.""" try: - kwargs['item_detail'] = str2bool(self.request.query_params.get('item_detail', False)) + kwargs['item_detail'] = str2bool( + self.request.query_params.get('item_detail', False) + ) except Exception: pass try: - kwargs['user_detail'] = str2bool(self.request.query_params.get('user_detail', False)) + kwargs['user_detail'] = str2bool( + self.request.query_params.get('user_detail', False) + ) except Exception: pass @@ -1372,25 +1425,19 @@ class StockTrackingList(ListAPI): item.save() headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) filter_backends = SEARCH_ORDER_FILTER - filterset_fields = [ - 'item', - 'user', - ] + filterset_fields = ['item', 'user'] ordering = '-date' - ordering_fields = [ - 'date', - ] + ordering_fields = ['date'] - search_fields = [ - 'title', - 'notes', - ] + search_fields = ['title', 'notes'] class LocationDetail(CustomRetrieveUpdateDestroyAPI): @@ -1435,36 +1482,59 @@ class LocationDetail(CustomRetrieveUpdateDestroyAPI): **dict( kwargs, delete_sub_locations=delete_sub_locations, - delete_stock_items=delete_stock_items - ) + delete_stock_items=delete_stock_items, + ), ) stock_api_urls = [ - re_path(r'^location/', include([ - - re_path(r'^tree/', StockLocationTree.as_view(), name='api-location-tree'), - - # Stock location detail endpoints - path(r'/', include([ - - re_path(r'^metadata/', MetadataView.as_view(), {'model': StockLocation}, name='api-location-metadata'), - - re_path(r'^.*$', LocationDetail.as_view(), name='api-location-detail'), - ])), - - re_path(r'^.*$', StockLocationList.as_view(), name='api-location-list'), - ])), - + re_path( + r'^location/', + include([ + re_path(r'^tree/', StockLocationTree.as_view(), name='api-location-tree'), + # Stock location detail endpoints + path( + r'/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': StockLocation}, + name='api-location-metadata', + ), + re_path( + r'^.*$', LocationDetail.as_view(), name='api-location-detail' + ), + ]), + ), + re_path(r'^.*$', StockLocationList.as_view(), name='api-location-list'), + ]), + ), # Stock location type endpoints - re_path(r'^location-type/', include([ - path(r'/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': StockLocationType}, name='api-location-type-metadata'), - re_path(r'^.*$', StockLocationTypeDetail.as_view(), name='api-location-type-detail'), - ])), - re_path(r'^.*$', StockLocationTypeList.as_view(), name="api-location-type-list"), - ])), - + re_path( + r'^location-type/', + include([ + path( + r'/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': StockLocationType}, + name='api-location-type-metadata', + ), + re_path( + r'^.*$', + StockLocationTypeDetail.as_view(), + name='api-location-type-detail', + ), + ]), + ), + re_path( + r'^.*$', StockLocationTypeList.as_view(), name='api-location-type-list' + ), + ]), + ), # Endpoints for bulk stock adjustment actions re_path(r'^count/', StockCount.as_view(), name='api-stock-count'), re_path(r'^add/', StockAdd.as_view(), name='api-stock-add'), @@ -1472,47 +1542,108 @@ stock_api_urls = [ re_path(r'^transfer/', StockTransfer.as_view(), name='api-stock-transfer'), re_path(r'^assign/', StockAssign.as_view(), name='api-stock-assign'), re_path(r'^merge/', StockMerge.as_view(), name='api-stock-merge'), - re_path(r'^change_status/', StockChangeStatus.as_view(), name='api-stock-change-status'), - + re_path( + r'^change_status/', StockChangeStatus.as_view(), name='api-stock-change-status' + ), # StockItemAttachment API endpoints - re_path(r'^attachment/', include([ - path(r'/', StockAttachmentDetail.as_view(), name='api-stock-attachment-detail'), - path('', StockAttachmentList.as_view(), name='api-stock-attachment-list'), - ])), - + re_path( + r'^attachment/', + include([ + path( + r'/', + StockAttachmentDetail.as_view(), + name='api-stock-attachment-detail', + ), + path('', StockAttachmentList.as_view(), name='api-stock-attachment-list'), + ]), + ), # StockItemTestResult API endpoints - re_path(r'^test/', include([ - path(r'/', include([ - re_path(r'^metadata/', MetadataView.as_view(), {'model': StockItemTestResult}, name='api-stock-test-result-metadata'), - re_path(r'^.*$', StockItemTestResultDetail.as_view(), name='api-stock-test-result-detail'), - ])), - re_path(r'^.*$', StockItemTestResultList.as_view(), name='api-stock-test-result-list'), - ])), - + re_path( + r'^test/', + include([ + path( + r'/', + include([ + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': StockItemTestResult}, + name='api-stock-test-result-metadata', + ), + re_path( + r'^.*$', + StockItemTestResultDetail.as_view(), + name='api-stock-test-result-detail', + ), + ]), + ), + re_path( + r'^.*$', + StockItemTestResultList.as_view(), + name='api-stock-test-result-list', + ), + ]), + ), # StockItemTracking API endpoints - re_path(r'^track/', include([ - path(r'/', StockTrackingDetail.as_view(), name='api-stock-tracking-detail'), - - # Stock tracking status code information - re_path(r'status/', StatusView.as_view(), {StatusView.MODEL_REF: StockHistoryCode}, name='api-stock-tracking-status-codes'), - - re_path(r'^.*$', StockTrackingList.as_view(), name='api-stock-tracking-list'), - ])), - + re_path( + r'^track/', + include([ + path( + r'/', + StockTrackingDetail.as_view(), + name='api-stock-tracking-detail', + ), + # Stock tracking status code information + re_path( + r'status/', + StatusView.as_view(), + {StatusView.MODEL_REF: StockHistoryCode}, + name='api-stock-tracking-status-codes', + ), + re_path( + r'^.*$', StockTrackingList.as_view(), name='api-stock-tracking-list' + ), + ]), + ), # Detail views for a single stock item - path(r'/', include([ - re_path(r'^convert/', StockItemConvert.as_view(), name='api-stock-item-convert'), - re_path(r'^install/', StockItemInstall.as_view(), name='api-stock-item-install'), - re_path(r'^metadata/', MetadataView.as_view(), {'model': StockItem}, name='api-stock-item-metadata'), - re_path(r'^return/', StockItemReturn.as_view(), name='api-stock-item-return'), - re_path(r'^serialize/', StockItemSerialize.as_view(), name='api-stock-item-serialize'), - re_path(r'^uninstall/', StockItemUninstall.as_view(), name='api-stock-item-uninstall'), - re_path(r'^.*$', StockDetail.as_view(), name='api-stock-detail'), - ])), - + path( + r'/', + include([ + re_path( + r'^convert/', StockItemConvert.as_view(), name='api-stock-item-convert' + ), + re_path( + r'^install/', StockItemInstall.as_view(), name='api-stock-item-install' + ), + re_path( + r'^metadata/', + MetadataView.as_view(), + {'model': StockItem}, + name='api-stock-item-metadata', + ), + re_path( + r'^return/', StockItemReturn.as_view(), name='api-stock-item-return' + ), + re_path( + r'^serialize/', + StockItemSerialize.as_view(), + name='api-stock-item-serialize', + ), + re_path( + r'^uninstall/', + StockItemUninstall.as_view(), + name='api-stock-item-uninstall', + ), + re_path(r'^.*$', StockDetail.as_view(), name='api-stock-detail'), + ]), + ), # Stock item status code information - re_path(r'status/', StatusView.as_view(), {StatusView.MODEL_REF: StockStatus}, name='api-stock-status-codes'), - + re_path( + r'status/', + StatusView.as_view(), + {StatusView.MODEL_REF: StockStatus}, + name='api-stock-status-codes', + ), # Anything else re_path(r'^.*$', StockList.as_view(), name='api-stock-list'), ] diff --git a/InvenTree/stock/apps.py b/InvenTree/stock/apps.py index 5c1d48545d..14ebab890f 100644 --- a/InvenTree/stock/apps.py +++ b/InvenTree/stock/apps.py @@ -5,4 +5,5 @@ from django.apps import AppConfig class StockConfig(AppConfig): """AppConfig for stock app.""" + name = 'stock' diff --git a/InvenTree/stock/filters.py b/InvenTree/stock/filters.py index 0d5207c801..57d0f0461b 100644 --- a/InvenTree/stock/filters.py +++ b/InvenTree/stock/filters.py @@ -31,5 +31,5 @@ def annotate_location_items(filter: Q = None): ).values('total') ), 0, - output_field=IntegerField() + output_field=IntegerField(), ) diff --git a/InvenTree/stock/models.py b/InvenTree/stock/models.py index 8fbecd495c..af29e74fac 100644 --- a/InvenTree/stock/models.py +++ b/InvenTree/stock/models.py @@ -31,11 +31,20 @@ import label.models import report.models from company import models as CompanyModels from InvenTree.fields import InvenTreeModelMoneyField, InvenTreeURLField -from InvenTree.models import (InvenTreeAttachment, InvenTreeBarcodeMixin, - InvenTreeNotesMixin, InvenTreeTree, - MetadataMixin, extract_int) -from InvenTree.status_codes import (SalesOrderStatusGroups, StockHistoryCode, - StockStatus, StockStatusGroups) +from InvenTree.models import ( + InvenTreeAttachment, + InvenTreeBarcodeMixin, + InvenTreeNotesMixin, + InvenTreeTree, + MetadataMixin, + extract_int, +) +from InvenTree.status_codes import ( + SalesOrderStatusGroups, + StockHistoryCode, + StockStatus, + StockStatusGroups, +) from part import models as PartModels from plugin.events import trigger_event from users.models import Owner @@ -53,8 +62,8 @@ class StockLocationType(MetadataMixin, models.Model): class Meta: """Metaclass defines extra model properties.""" - verbose_name = _("Stock Location type") - verbose_name_plural = _("Stock Location types") + verbose_name = _('Stock Location type') + verbose_name_plural = _('Stock Location types') @staticmethod def get_api_url(): @@ -66,24 +75,21 @@ class StockLocationType(MetadataMixin, models.Model): return self.name name = models.CharField( - blank=False, - max_length=100, - verbose_name=_("Name"), - help_text=_("Name"), + blank=False, max_length=100, verbose_name=_('Name'), help_text=_('Name') ) description = models.CharField( blank=True, max_length=250, - verbose_name=_("Description"), - help_text=_("Description (optional)") + verbose_name=_('Description'), + help_text=_('Description (optional)'), ) icon = models.CharField( blank=True, max_length=100, - verbose_name=_("Icon"), - help_text=_("Default icon for all locations that have no icon set (optional)") + verbose_name=_('Icon'), + help_text=_('Default icon for all locations that have no icon set (optional)'), ) @@ -98,7 +104,7 @@ class StockLocationManager(TreeManager): - Joins the StockLocationType by default for speedier icon access """ - return super().get_queryset().select_related("location_type") + return super().get_queryset().select_related('location_type') class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): @@ -139,37 +145,44 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): custom_icon = models.CharField( blank=True, max_length=100, - verbose_name=_("Icon"), - help_text=_("Icon (optional)"), - db_column="icon", + verbose_name=_('Icon'), + help_text=_('Icon (optional)'), + db_column='icon', ) - owner = models.ForeignKey(Owner, on_delete=models.SET_NULL, blank=True, null=True, - verbose_name=_('Owner'), - help_text=_('Select Owner'), - related_name='stock_locations') + owner = models.ForeignKey( + Owner, + on_delete=models.SET_NULL, + blank=True, + null=True, + verbose_name=_('Owner'), + help_text=_('Select Owner'), + related_name='stock_locations', + ) structural = models.BooleanField( default=False, verbose_name=_('Structural'), help_text=_( 'Stock items may not be directly located into a structural stock locations, ' - 'but may be located to child locations.'), + 'but may be located to child locations.' + ), ) external = models.BooleanField( default=False, verbose_name=_('External'), - help_text=_('This is an external stock location') + help_text=_('This is an external stock location'), ) location_type = models.ForeignKey( StockLocationType, on_delete=models.SET_NULL, - verbose_name=_("Location type"), - related_name="stock_locations", - null=True, blank=True, - help_text=_("Stock location type of this location"), + verbose_name=_('Location type'), + related_name='stock_locations', + null=True, + blank=True, + help_text=_('Stock location type of this location'), ) @property @@ -184,7 +197,7 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): if self.location_type: return self.location_type.icon - return "" + return '' @icon.setter def icon(self, value): @@ -212,7 +225,9 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): if user.is_superuser: return True - ownership_enabled = common.models.InvenTreeSetting.get_setting('STOCK_OWNERSHIP_CONTROL') + ownership_enabled = common.models.InvenTreeSetting.get_setting( + 'STOCK_OWNERSHIP_CONTROL' + ) if not ownership_enabled: # Location ownership function is not enabled, so return True @@ -234,8 +249,11 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): """ if self.pk and self.structural and self.stock_item_count(False) > 0: raise ValidationError( - _("You cannot make this stock location structural because some stock items " - "are already located into it!")) + _( + 'You cannot make this stock location structural because some stock items ' + 'are already located into it!' + ) + ) super().clean() def get_absolute_url(self): @@ -249,7 +267,9 @@ class StockLocation(InvenTreeBarcodeMixin, MetadataMixin, InvenTreeTree): cascade: If True, also look under sublocations (default = True) """ if cascade: - query = StockItem.objects.filter(location__in=self.getUniqueChildren(include_self=True)) + query = StockItem.objects.filter( + location__in=self.getUniqueChildren(include_self=True) + ) else: query = StockItem.objects.filter(location=self.pk) @@ -292,7 +312,9 @@ def generate_batch_code(): return batch # If we get to this point, no plugin was able to generate a new batch code - batch_template = common.models.InvenTreeSetting.get_setting('STOCK_BATCH_CODE_TEMPLATE', '') + batch_template = common.models.InvenTreeSetting.get_setting( + 'STOCK_BATCH_CODE_TEMPLATE', '' + ) now = datetime.now() @@ -317,13 +339,21 @@ def default_delete_on_deplete(): Now, there is a user-configurable setting to govern default behaviour. """ try: - return common.models.InvenTreeSetting.get_setting('STOCK_DELETE_DEPLETED_DEFAULT', True) + return common.models.InvenTreeSetting.get_setting( + 'STOCK_DELETE_DEPLETED_DEFAULT', True + ) except (IntegrityError, OperationalError): # Revert to original default behaviour return True -class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, common.models.MetaMixin, MPTTModel): +class StockItem( + InvenTreeBarcodeMixin, + InvenTreeNotesMixin, + MetadataMixin, + common.models.MetaMixin, + MPTTModel, +): """A StockItem object represents a quantity of physical instances of a part. Attributes: @@ -358,11 +388,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo def api_instance_filters(self): """Custom API instance filters.""" - return { - 'parent': { - 'exclude_tree': self.pk, - } - } + return {'parent': {'exclude_tree': self.pk}} tags = TaggableManager(blank=True) @@ -374,11 +400,15 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo customer=None, consumed_by=None, is_building=False, - status__in=StockStatusGroups.AVAILABLE_CODES + status__in=StockStatusGroups.AVAILABLE_CODES, ) # A query filter which can be used to filter StockItem objects which have expired - EXPIRED_FILTER = IN_STOCK_FILTER & ~Q(expiry_date=None) & Q(expiry_date__lt=datetime.now().date()) + EXPIRED_FILTER = ( + IN_STOCK_FILTER + & ~Q(expiry_date=None) + & Q(expiry_date__lt=datetime.now().date()) + ) def update_serial_number(self): """Update the 'serial_int' field, to be an integer representation of the serial number. @@ -400,7 +430,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo if serial_int is not None: # Save the first returned result # Ensure that it is clipped within a range allowed in the database schema - clip = 0x7fffffff + clip = 0x7FFFFFFF serial_int = abs(serial_int) @@ -505,10 +535,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo if add_note and len(deltas) > 0: self.add_tracking_entry( - StockHistoryCode.EDITED, - user, - deltas=deltas, - notes=notes, + StockHistoryCode.EDITED, user, deltas=deltas, notes=notes ) except (ValueError, StockItem.DoesNotExist): @@ -518,10 +545,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo # If user information is provided, and no existing note exists, create one! if user and self.tracking_info.count() == 0: - - tracking_info = { - 'status': self.status, - } + tracking_info = {'status': self.status} self.add_tracking_entry( StockHistoryCode.CREATED, @@ -540,7 +564,11 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo @property def serialized(self): """Return True if this StockItem is serialized.""" - return self.serial is not None and len(str(self.serial).strip()) > 0 and self.quantity == 1 + return ( + self.serial is not None + and len(str(self.serial).strip()) > 0 + and self.quantity == 1 + ) def validate_unique(self, exclude=None): """Test that this StockItem is "unique". @@ -552,15 +580,12 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo # If the serial number is set, make sure it is not a duplicate if self.serial: - self.serial = str(self.serial).strip() try: self.part.validate_serial_number(self.serial, self, raise_error=True) except ValidationError as exc: - raise ValidationError({ - 'serial': exc.message, - }) + raise ValidationError({'serial': exc.message}) def validate_batch_code(self): """Ensure that the batch code is valid for this StockItem. @@ -574,9 +599,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo try: plugin.validate_batch_code(self.batch, self) except ValidationError as exc: - raise ValidationError({ - 'batch': exc.message - }) + raise ValidationError({'batch': exc.message}) def clean(self): """Validate the StockItem object (separate to field validation). @@ -589,8 +612,11 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo - Quantity must be 1 if the StockItem has a serial number """ if self.location is not None and self.location.structural: - raise ValidationError( - {'location': _("Stock items cannot be located into structural stock locations!")}) + raise ValidationError({ + 'location': _( + 'Stock items cannot be located into structural stock locations!' + ) + }) super().clean() @@ -610,13 +636,15 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo if self.part.trackable: if self.quantity != int(self.quantity): raise ValidationError({ - 'quantity': _('Quantity must be integer value for trackable parts') + 'quantity': _( + 'Quantity must be integer value for trackable parts' + ) }) # Virtual parts cannot have stock items created against them if self.part.virtual: raise ValidationError({ - 'part': _("Stock item cannot be created for virtual parts"), + 'part': _('Stock item cannot be created for virtual parts') }) except PartModels.Part.DoesNotExist: # For some reason the 'clean' process sometimes throws errors because self.part does not exist @@ -625,24 +653,29 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo pass if self.quantity < 0: - raise ValidationError({ - 'quantity': _('Quantity must be greater than zero') - }) + raise ValidationError({'quantity': _('Quantity must be greater than zero')}) # The 'supplier_part' field must point to the same part! try: if self.supplier_part is not None: if self.supplier_part.part != self.part: - raise ValidationError({'supplier_part': _(f"Part type ('{self.supplier_part.part}') must be {self.part}") - }) + raise ValidationError({ + 'supplier_part': _( + f"Part type ('{self.supplier_part.part}') must be {self.part}" + ) + }) if self.part is not None: # A part with a serial number MUST have the quantity set to 1 if self.serial: if self.quantity > 1: raise ValidationError({ - 'quantity': _('Quantity must be 1 for item with a serial number'), - 'serial': _('Serial number cannot be set if quantity greater than 1') + 'quantity': _( + 'Quantity must be 1 for item with a serial number' + ), + 'serial': _( + 'Serial number cannot be set if quantity greater than 1' + ), }) if self.quantity == 0: @@ -650,7 +683,9 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo elif self.quantity > 1: raise ValidationError({ - 'quantity': _('Quantity must be 1 for item with a serial number') + 'quantity': _( + 'Quantity must be 1 for item with a serial number' + ) }) # Serial numbered items cannot be deleted on depletion @@ -663,19 +698,16 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo # Ensure that the item cannot be assigned to itself if self.belongs_to and self.belongs_to.pk == self.pk: - raise ValidationError({ - 'belongs_to': _('Item cannot belong to itself') - }) + raise ValidationError({'belongs_to': _('Item cannot belong to itself')}) # If the item is marked as "is_building", it must point to a build! if self.is_building and not self.build: raise ValidationError({ - 'build': _("Item must have a build reference if is_building=True") + 'build': _('Item must have a build reference if is_building=True') }) # If the item points to a build, check that the Part references match if self.build: - if self.part == self.build.part: # Part references match exactly pass @@ -684,7 +716,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo pass else: raise ValidationError({ - 'build': _("Build reference does not point to the same part object") + 'build': _('Build reference does not point to the same part object') }) def get_absolute_url(self): @@ -700,21 +732,25 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo 'self', verbose_name=_('Parent Stock Item'), on_delete=models.DO_NOTHING, - blank=True, null=True, - related_name='children' + blank=True, + null=True, + related_name='children', ) part = models.ForeignKey( - 'part.Part', on_delete=models.CASCADE, + 'part.Part', + on_delete=models.CASCADE, verbose_name=_('Base Part'), related_name='stock_items', help_text=_('Base part'), - limit_choices_to={ - 'virtual': False - }) + limit_choices_to={'virtual': False}, + ) supplier_part = models.ForeignKey( - 'company.SupplierPart', blank=True, null=True, on_delete=models.SET_NULL, + 'company.SupplierPart', + blank=True, + null=True, + on_delete=models.SET_NULL, verbose_name=_('Supplier Part'), help_text=_('Select a matching supplier part for this stock item'), related_name='stock_items', @@ -722,18 +758,21 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo # Note: When a StockLocation is deleted, stock items are updated via a signal location = TreeForeignKey( - StockLocation, on_delete=models.DO_NOTHING, + StockLocation, + on_delete=models.DO_NOTHING, verbose_name=_('Stock Location'), related_name='stock_items', - blank=True, null=True, - help_text=_('Where is this stock item located?') + blank=True, + null=True, + help_text=_('Where is this stock item located?'), ) packaging = models.CharField( max_length=50, - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Packaging'), - help_text=_('Packaging this stock item is stored in') + help_text=_('Packaging this stock item is stored in'), ) # When deleting a stock item with installed items, those installed items are also installed @@ -741,108 +780,127 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo 'self', verbose_name=_('Installed In'), on_delete=models.CASCADE, - related_name='installed_parts', blank=True, null=True, - help_text=_('Is this item installed in another item?') + related_name='installed_parts', + blank=True, + null=True, + help_text=_('Is this item installed in another item?'), ) customer = models.ForeignKey( CompanyModels.Company, on_delete=models.SET_NULL, - null=True, blank=True, + null=True, + blank=True, limit_choices_to={'is_customer': True}, related_name='assigned_stock', - help_text=_("Customer"), - verbose_name=_("Customer"), + help_text=_('Customer'), + verbose_name=_('Customer'), ) serial = models.CharField( verbose_name=_('Serial Number'), - max_length=100, blank=True, null=True, - help_text=_('Serial number for this item') + max_length=100, + blank=True, + null=True, + help_text=_('Serial number for this item'), ) serial_int = models.IntegerField(default=0) link = InvenTreeURLField( - verbose_name=_('External Link'), - blank=True, - help_text=_("Link to external URL") + verbose_name=_('External Link'), blank=True, help_text=_('Link to external URL') ) batch = models.CharField( verbose_name=_('Batch Code'), - max_length=100, blank=True, null=True, + max_length=100, + blank=True, + null=True, help_text=_('Batch code for this stock item'), default=generate_batch_code, ) quantity = models.DecimalField( - verbose_name=_("Stock Quantity"), - max_digits=15, decimal_places=5, validators=[MinValueValidator(0)], - default=1 + verbose_name=_('Stock Quantity'), + max_digits=15, + decimal_places=5, + validators=[MinValueValidator(0)], + default=1, ) build = models.ForeignKey( - 'build.Build', on_delete=models.SET_NULL, + 'build.Build', + on_delete=models.SET_NULL, verbose_name=_('Source Build'), - blank=True, null=True, + blank=True, + null=True, help_text=_('Build for this stock item'), related_name='build_outputs', ) consumed_by = models.ForeignKey( - 'build.Build', on_delete=models.CASCADE, + 'build.Build', + on_delete=models.CASCADE, verbose_name=_('Consumed By'), - blank=True, null=True, + blank=True, + null=True, help_text=_('Build order which consumed this stock item'), related_name='consumed_stock', ) - is_building = models.BooleanField( - default=False, - ) + is_building = models.BooleanField(default=False) purchase_order = models.ForeignKey( 'order.PurchaseOrder', on_delete=models.SET_NULL, verbose_name=_('Source Purchase Order'), related_name='stock_items', - blank=True, null=True, - help_text=_('Purchase order for this stock item') + blank=True, + null=True, + help_text=_('Purchase order for this stock item'), ) sales_order = models.ForeignKey( 'order.SalesOrder', on_delete=models.SET_NULL, - verbose_name=_("Destination Sales Order"), + verbose_name=_('Destination Sales Order'), related_name='stock_items', - null=True, blank=True) + null=True, + blank=True, + ) expiry_date = models.DateField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Expiry Date'), - help_text=_('Expiry date for stock item. Stock will be considered expired after this date'), + help_text=_( + 'Expiry date for stock item. Stock will be considered expired after this date' + ), ) stocktake_date = models.DateField(blank=True, null=True) stocktake_user = models.ForeignKey( - User, on_delete=models.SET_NULL, - blank=True, null=True, - related_name='stocktake_stock' + User, + on_delete=models.SET_NULL, + blank=True, + null=True, + related_name='stocktake_stock', ) review_needed = models.BooleanField(default=False) delete_on_deplete = models.BooleanField( default=default_delete_on_deplete, - verbose_name=_('Delete on deplete'), help_text=_('Delete this Stock Item when stock is depleted') + verbose_name=_('Delete on deplete'), + help_text=_('Delete this Stock Item when stock is depleted'), ) status = models.PositiveIntegerField( default=StockStatus.OK.value, choices=StockStatus.items(), - validators=[MinValueValidator(0)]) + validators=[MinValueValidator(0)], + ) @property def status_text(self): @@ -858,10 +916,15 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo help_text=_('Single unit purchase price at time of purchase'), ) - owner = models.ForeignKey(Owner, on_delete=models.SET_NULL, blank=True, null=True, - verbose_name=_('Owner'), - help_text=_('Select Owner'), - related_name='stock_items') + owner = models.ForeignKey( + Owner, + on_delete=models.SET_NULL, + blank=True, + null=True, + verbose_name=_('Owner'), + help_text=_('Select Owner'), + related_name='stock_items', + ) @transaction.atomic def convert_to_variant(self, variant, user, notes=None): @@ -880,9 +943,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo self.add_tracking_entry( StockHistoryCode.CONVERTED_TO_VARIANT, user, - deltas={ - 'part': variant.pk, - }, + deltas={'part': variant.pk}, notes=_('Converted to part') + ': ' + variant.full_name, ) @@ -910,7 +971,9 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo if user.is_superuser: return True - ownership_enabled = common.models.InvenTreeSetting.get_setting('STOCK_OWNERSHIP_CONTROL') + ownership_enabled = common.models.InvenTreeSetting.get_setting( + 'STOCK_OWNERSHIP_CONTROL' + ) if not ownership_enabled: # Location ownership function is not enabled, so return True @@ -979,7 +1042,9 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo # Delete outstanding BuildOrder allocations self.allocations.all().delete() - def allocateToCustomer(self, customer, quantity=None, order=None, user=None, notes=None): + def allocateToCustomer( + self, customer, quantity=None, order=None, user=None, notes=None + ): """Allocate a StockItem to a customer. This action can be called by the following processes: @@ -1020,18 +1085,9 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo code = StockHistoryCode.SHIPPED_AGAINST_SALES_ORDER deltas['salesorder'] = order.pk - item.add_tracking_entry( - code, - user, - deltas, - notes=notes, - ) + item.add_tracking_entry(code, user, deltas, notes=notes) - trigger_event( - 'stockitem.assignedtocustomer', - id=self.id, - customer=customer.id, - ) + trigger_event('stockitem.assignedtocustomer', id=self.id, customer=customer.id) # Return the reference to the stock item return item @@ -1045,9 +1101,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo """ notes = kwargs.get('notes', '') - tracking_info = { - 'location': location.pk, - } + tracking_info = {'location': location.pk} if self.customer: tracking_info['customer'] = self.customer.id @@ -1058,7 +1112,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo user, notes=notes, deltas=tracking_info, - location=location + location=location, ) self.customer = None @@ -1066,18 +1120,12 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo self.sales_order = None self.location = location - trigger_event( - 'stockitem.returnedfromcustomer', - id=self.id, - ) + trigger_event('stockitem.returnedfromcustomer', id=self.id) """If new location is the same as the parent location, merge this stock back in the parent""" if self.parent and self.location == self.parent.location: self.parent.merge_stock_items( - {self}, - user=user, - location=location, - notes=notes + {self}, user=user, location=location, notes=notes ) else: self.save(add_note=False) @@ -1123,14 +1171,12 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo if active is True: query = query.filter( line__order__status__in=SalesOrderStatusGroups.OPEN, - shipment__shipment_date=None + shipment__shipment_date=None, ) elif active is False: query = query.exclude( - line__order__status__in=SalesOrderStatusGroups.OPEN, - ).exclude( - shipment__shipment_date=None - ) + line__order__status__in=SalesOrderStatusGroups.OPEN + ).exclude(shipment__shipment_date=None) return query @@ -1190,7 +1236,6 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo items = StockItem.objects.filter(belongs_to=self) for item in items: - # Prevent duplication or recursion if item == self or item in installed: continue @@ -1201,7 +1246,6 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo sub_items = item.get_installed_items(cascade=True) for sub_item in sub_items: - # Prevent recursion if sub_item == self or sub_item in installed: continue @@ -1236,19 +1280,14 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo stock_item.consumed_by = build stock_item.save(add_note=False) - deltas = { - 'stockitem': self.pk, - } + deltas = {'stockitem': self.pk} if build is not None: deltas['buildorder'] = build.pk # Add a transaction note to the other item stock_item.add_tracking_entry( - StockHistoryCode.INSTALLED_INTO_ASSEMBLY, - user, - notes=notes, - deltas=deltas, + StockHistoryCode.INSTALLED_INTO_ASSEMBLY, user, notes=notes, deltas=deltas ) # Add a transaction note to this item (the assembly) @@ -1256,9 +1295,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo StockHistoryCode.INSTALLED_CHILD_ITEM, user, notes=notes, - deltas={ - 'stockitem': stock_item.pk, - } + deltas={'stockitem': stock_item.pk}, ) @transaction.atomic @@ -1280,15 +1317,11 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo self.belongs_to.add_tracking_entry( StockHistoryCode.REMOVED_CHILD_ITEM, user, - deltas={ - 'stockitem': self.pk, - }, + deltas={'stockitem': self.pk}, notes=notes, ) - tracking_info = { - 'stockitem': self.belongs_to.pk - } + tracking_info = {'stockitem': self.belongs_to.pk} self.add_tracking_entry( StockHistoryCode.REMOVED_FROM_ASSEMBLY, @@ -1359,7 +1392,14 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo """Is tracking info available?""" return self.tracking_info_count > 0 - def add_tracking_entry(self, entry_type: int, user: User, deltas: dict = None, notes: str = '', **kwargs): + def add_tracking_entry( + self, + entry_type: int, + user: User, + deltas: dict = None, + notes: str = '', + **kwargs, + ): """Add a history tracking entry for this StockItem. Args: @@ -1414,37 +1454,44 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo return if not self.part.trackable: - raise ValidationError({"part": _("Part is not set as trackable")}) + raise ValidationError({'part': _('Part is not set as trackable')}) # Quantity must be a valid integer value try: quantity = int(quantity) except ValueError: - raise ValidationError({"quantity": _("Quantity must be integer")}) + raise ValidationError({'quantity': _('Quantity must be integer')}) if quantity <= 0: - raise ValidationError({"quantity": _("Quantity must be greater than zero")}) + raise ValidationError({'quantity': _('Quantity must be greater than zero')}) if quantity > self.quantity: - raise ValidationError({"quantity": _(f"Quantity must not exceed available stock quantity ({self.quantity})")}) + raise ValidationError({ + 'quantity': _( + f'Quantity must not exceed available stock quantity ({self.quantity})' + ) + }) if type(serials) not in [list, tuple]: - raise ValidationError({"serial_numbers": _("Serial numbers must be a list of integers")}) + raise ValidationError({ + 'serial_numbers': _('Serial numbers must be a list of integers') + }) if quantity != len(serials): - raise ValidationError({"quantity": _("Quantity does not match serial numbers")}) + raise ValidationError({ + 'quantity': _('Quantity does not match serial numbers') + }) # Test if each of the serial numbers are valid existing = self.part.find_conflicting_serial_numbers(serials) if len(existing) > 0: exists = ','.join([str(x) for x in existing]) - msg = _("Serial numbers already exist") + f": {exists}" - raise ValidationError({"serial_numbers": msg}) + msg = _('Serial numbers already exist') + f': {exists}' + raise ValidationError({'serial_numbers': msg}) # Create a new stock item for each unique serial number for serial in serials: - # Create a copy of this StockItem new_item = StockItem.objects.get(pk=self.pk) new_item.quantity = 1 @@ -1469,10 +1516,8 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo StockHistoryCode.ASSIGNED_SERIAL, user, notes=notes, - deltas={ - 'serial': serial, - }, - location=location + deltas={'serial': serial}, + location=location, ) # Remove the equivalent number of items @@ -1482,7 +1527,6 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo def copyHistoryFrom(self, other): """Copy stock history from another StockItem.""" for item in other.tracking_info.all(): - item.item = self item.pk = None item.save() @@ -1495,7 +1539,6 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo filters = {} for result in other.test_results.all().filter(**filters): - # Create a copy of the test result by nulling-out the pk result.pk = None result.stock_item = self @@ -1510,7 +1553,9 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo try: # Generic checks (do not rely on the 'other' part) if self.sales_order: - raise ValidationError(_('Stock item has been assigned to a sales order')) + raise ValidationError( + _('Stock item has been assigned to a sales order') + ) if self.belongs_to: raise ValidationError(_('Stock item is installed in another item')) @@ -1525,7 +1570,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo raise ValidationError(_('Stock item is currently in production')) if self.serialized: - raise ValidationError(_("Serialized stock cannot be merged")) + raise ValidationError(_('Serialized stock cannot be merged')) if other: # Specific checks (rely on the 'other' part) @@ -1536,15 +1581,20 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo # Base part must match if self.part != other.part: - raise ValidationError(_("Stock items must refer to the same part")) + raise ValidationError(_('Stock items must refer to the same part')) # Check if supplier part references match - if self.supplier_part != other.supplier_part and not allow_mismatched_suppliers: - raise ValidationError(_("Stock items must refer to the same supplier part")) + if ( + self.supplier_part != other.supplier_part + and not allow_mismatched_suppliers + ): + raise ValidationError( + _('Stock items must refer to the same supplier part') + ) # Check if stock status codes match if self.status != other.status and not allow_mismatched_status: - raise ValidationError(_("Stock status codes must match")) + raise ValidationError(_('Stock status codes must match')) except ValidationError as e: if raise_error: @@ -1579,18 +1629,15 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo return for other in other_items: - self.quantity += other.quantity # Any "build order allocations" for the other item must be assigned to this one for allocation in other.allocations.all(): - allocation.stock_item = self allocation.save() # Any "sales order allocations" for the other item must be assigned to this one for allocation in other.sales_order_allocations.all(): - allocation.stock_item = self() allocation.save() @@ -1606,9 +1653,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo user, quantity=self.quantity, notes=notes, - deltas={ - 'location': location.pk, - } + deltas={'location': location.pk}, ) self.location = location @@ -1669,9 +1714,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo else: new_stock.location = self.location - deltas = { - 'stockitem': self.pk, - } + deltas = {'stockitem': self.pk} # Optional fields which can be supplied in a 'move' call for field in StockItem.optional_transfer_fields(): @@ -1738,7 +1781,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo return False if not self.in_stock: - raise ValidationError(_("StockItem cannot be moved as it is not in stock")) + raise ValidationError(_('StockItem cannot be moved as it is not in stock')) if quantity <= 0: return False @@ -1778,12 +1821,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo setattr(self, field, kwargs[field]) tracking_info[field] = kwargs[field] - self.add_tracking_entry( - tracking_code, - user, - notes=notes, - deltas=tracking_info, - ) + self.add_tracking_entry(tracking_code, user, notes=notes, deltas=tracking_info) self.save() @@ -1839,14 +1877,11 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo self.stocktake_user = user if self.updateQuantity(count): - self.add_tracking_entry( StockHistoryCode.STOCK_COUNT, user, notes=notes, - deltas={ - 'quantity': float(self.quantity), - } + deltas={'quantity': float(self.quantity)}, ) return True @@ -1872,21 +1907,19 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo return False if self.updateQuantity(self.quantity + quantity): - self.add_tracking_entry( StockHistoryCode.STOCK_ADD, user, notes=notes, - deltas={ - 'added': float(quantity), - 'quantity': float(self.quantity), - } + deltas={'added': float(quantity), 'quantity': float(self.quantity)}, ) return True @transaction.atomic - def take_stock(self, quantity, user, notes='', code=StockHistoryCode.STOCK_REMOVE, **kwargs): + def take_stock( + self, quantity, user, notes='', code=StockHistoryCode.STOCK_REMOVE, **kwargs + ): """Remove items from stock.""" # Cannot remove items from a serialized part if self.serialized: @@ -1901,11 +1934,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo return False if self.updateQuantity(self.quantity - quantity): - - deltas = { - 'removed': float(quantity), - 'quantity': float(self.quantity), - } + deltas = {'removed': float(quantity), 'quantity': float(self.quantity)} if location := kwargs.get('location', None): deltas['location'] = location.pk @@ -1913,12 +1942,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo if stockitem := kwargs.get('stockitem', None): deltas['stockitem'] = stockitem.pk - self.add_tracking_entry( - code, - user, - notes=notes, - deltas=deltas, - ) + self.add_tracking_entry(code, user, notes=notes, deltas=deltas) return True @@ -1933,7 +1957,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo s += f' @ {self.location.name}' if self.purchase_order: - s += f" ({self.purchase_order})" + s += f' ({self.purchase_order})' return s @@ -2042,11 +2066,7 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo else: failed += 1 - return { - 'total': total, - 'passed': passed, - 'failed': failed, - } + return {'total': total, 'passed': passed, 'failed': failed} @property def required_test_count(self): @@ -2070,7 +2090,6 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo item_query = StockItem.objects.filter(pk=self.pk) for test_report in report.models.TestReport.objects.filter(enabled=True): - # Attempt to validate report filter (skip if invalid) try: filters = InvenTree.helpers.validateFilterString(test_report.filters) @@ -2093,7 +2112,6 @@ class StockItem(InvenTreeBarcodeMixin, InvenTreeNotesMixin, MetadataMixin, commo item_query = StockItem.objects.filter(pk=self.pk) for lbl in label.models.StockItemLabel.objects.filter(enabled=True): - try: filters = InvenTree.helpers.validateFilterString(lbl.filters) @@ -2130,7 +2148,9 @@ def after_delete_stock_item(sender, instance: StockItem, **kwargs): if not InvenTree.ready.isImportingData(): # Run this check in the background - InvenTree.tasks.offload_task(part_tasks.notify_low_stock_if_required, instance.part) + InvenTree.tasks.offload_task( + part_tasks.notify_low_stock_if_required, instance.part + ) # Schedule an update on parent part pricing if InvenTree.ready.canAppAccessDatabase(allow_test=True): @@ -2144,7 +2164,9 @@ def after_save_stock_item(sender, instance: StockItem, created, **kwargs): if created and not InvenTree.ready.isImportingData(): # Run this check in the background - InvenTree.tasks.offload_task(part_tasks.notify_low_stock_if_required, instance.part) + InvenTree.tasks.offload_task( + part_tasks.notify_low_stock_if_required, instance.part + ) if InvenTree.ready.canAppAccessDatabase(allow_test=True): instance.part.schedule_pricing_update(create=True) @@ -2160,12 +2182,10 @@ class StockItemAttachment(InvenTreeAttachment): def getSubdir(self): """Override attachment location.""" - return os.path.join("stock_files", str(self.stock_item.id)) + return os.path.join('stock_files', str(self.stock_item.id)) stock_item = models.ForeignKey( - StockItem, - on_delete=models.CASCADE, - related_name='attachments' + StockItem, on_delete=models.CASCADE, related_name='attachments' ) @@ -2205,23 +2225,20 @@ class StockItemTracking(models.Model): return getattr(self, 'title', '') - tracking_type = models.IntegerField( - default=StockHistoryCode.LEGACY, - ) + tracking_type = models.IntegerField(default=StockHistoryCode.LEGACY) item = models.ForeignKey( - StockItem, - on_delete=models.CASCADE, - related_name='tracking_info' + StockItem, on_delete=models.CASCADE, related_name='tracking_info' ) date = models.DateTimeField(auto_now_add=True, editable=False) notes = models.CharField( - blank=True, null=True, + blank=True, + null=True, max_length=512, verbose_name=_('Notes'), - help_text=_('Entry notes') + help_text=_('Entry notes'), ) user = models.ForeignKey(User, on_delete=models.SET_NULL, blank=True, null=True) @@ -2231,7 +2248,9 @@ class StockItemTracking(models.Model): def rename_stock_item_test_result_attachment(instance, filename): """Rename test result.""" - return os.path.join('stock_files', str(instance.stock_item.pk), os.path.basename(filename)) + return os.path.join( + 'stock_files', str(instance.stock_item.pk), os.path.basename(filename) + ) class StockItemTestResult(MetadataMixin, models.Model): @@ -2275,17 +2294,16 @@ class StockItemTestResult(MetadataMixin, models.Model): for template in templates: if key == template.key: - if template.requires_value: if not self.value: raise ValidationError({ - "value": _("Value must be provided for this test"), + 'value': _('Value must be provided for this test') }) if template.requires_attachment: if not self.attachment: raise ValidationError({ - "attachment": _("Attachment must be uploaded for this test"), + 'attachment': _('Attachment must be uploaded for this test') }) break @@ -2296,49 +2314,36 @@ class StockItemTestResult(MetadataMixin, models.Model): return InvenTree.helpers.generateTestKey(self.test) stock_item = models.ForeignKey( - StockItem, - on_delete=models.CASCADE, - related_name='test_results' + StockItem, on_delete=models.CASCADE, related_name='test_results' ) test = models.CharField( - blank=False, max_length=100, - verbose_name=_('Test'), - help_text=_('Test name') + blank=False, max_length=100, verbose_name=_('Test'), help_text=_('Test name') ) result = models.BooleanField( - default=False, - verbose_name=_('Result'), - help_text=_('Test result') + default=False, verbose_name=_('Result'), help_text=_('Test result') ) value = models.CharField( - blank=True, max_length=500, + blank=True, + max_length=500, verbose_name=_('Value'), - help_text=_('Test output value') + help_text=_('Test output value'), ) attachment = models.FileField( - null=True, blank=True, + null=True, + blank=True, upload_to=rename_stock_item_test_result_attachment, verbose_name=_('Attachment'), help_text=_('Test result attachment'), ) notes = models.CharField( - blank=True, max_length=500, - verbose_name=_('Notes'), - help_text=_("Test notes"), + blank=True, max_length=500, verbose_name=_('Notes'), help_text=_('Test notes') ) - user = models.ForeignKey( - User, - on_delete=models.SET_NULL, - blank=True, null=True - ) + user = models.ForeignKey(User, on_delete=models.SET_NULL, blank=True, null=True) - date = models.DateTimeField( - auto_now_add=True, - editable=False - ) + date = models.DateTimeField(auto_now_add=True, editable=False) diff --git a/InvenTree/stock/serializers.py b/InvenTree/stock/serializers.py index 593065f9ef..f01b71f954 100644 --- a/InvenTree/stock/serializers.py +++ b/InvenTree/stock/serializers.py @@ -23,12 +23,17 @@ import part.models as part_models import stock.filters from company.serializers import SupplierPartSerializer from InvenTree.models import extract_int -from InvenTree.serializers import (InvenTreeCurrencySerializer, - InvenTreeDecimalField) +from InvenTree.serializers import InvenTreeCurrencySerializer, InvenTreeDecimalField from part.serializers import PartBriefSerializer -from .models import (StockItem, StockItemAttachment, StockItemTestResult, - StockItemTracking, StockLocation, StockLocationType) +from .models import ( + StockItem, + StockItemAttachment, + StockItemTestResult, + StockItemTracking, + StockLocation, + StockLocationType, +) class LocationBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): @@ -38,11 +43,7 @@ class LocationBriefSerializer(InvenTree.serializers.InvenTreeModelSerializer): """Metaclass options.""" model = StockLocation - fields = [ - 'pk', - 'name', - 'pathstring', - ] + fields = ['pk', 'name', 'pathstring'] class StockItemTestResultSerializer(InvenTree.serializers.InvenTreeModelSerializer): @@ -64,15 +65,11 @@ class StockItemTestResultSerializer(InvenTree.serializers.InvenTreeModelSerializ 'notes', 'user', 'user_detail', - 'date' - ] - - read_only_fields = [ - 'pk', - 'user', 'date', ] + read_only_fields = ['pk', 'user', 'date'] + def __init__(self, *args, **kwargs): """Add detail fields.""" user_detail = kwargs.pop('user_detail', False) @@ -86,7 +83,9 @@ class StockItemTestResultSerializer(InvenTree.serializers.InvenTreeModelSerializ key = serializers.CharField(read_only=True) - attachment = InvenTree.serializers.InvenTreeAttachmentSerializerField(required=False) + attachment = InvenTree.serializers.InvenTreeAttachmentSerializerField( + required=False + ) class StockItemSerializerBrief(InvenTree.serializers.InvenTreeModelSerializer): @@ -107,9 +106,7 @@ class StockItemSerializerBrief(InvenTree.serializers.InvenTreeModelSerializer): 'barcode_hash', ] - read_only_fields = [ - 'barcode_hash', - ] + read_only_fields = ['barcode_hash'] part_name = serializers.CharField(source='part.full_name', read_only=True) @@ -117,8 +114,8 @@ class StockItemSerializerBrief(InvenTree.serializers.InvenTreeModelSerializer): def validate_serial(self, value): """Make sure serial is not to big.""" - if abs(extract_int(value)) > 0x7fffffff: - raise serializers.ValidationError(_("Serial number is too large")) + if abs(extract_int(value)) > 0x7FFFFFFF: + raise serializers.ValidationError(_('Serial number is too large')) return value @@ -169,14 +166,12 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): 'purchase_price_currency', 'use_pack_size', 'tests', - # Annotated fields 'allocated', 'expired', 'installed_items', 'stale', 'tracking_items', - 'tags', ] @@ -195,21 +190,18 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): """ Fields used when creating a stock item """ - extra_kwargs = { - 'use_pack_size': {'write_only': True}, - } + extra_kwargs = {'use_pack_size': {'write_only': True}} part = serializers.PrimaryKeyRelatedField( queryset=part_models.Part.objects.all(), - many=False, allow_null=False, - help_text=_("Base Part"), - label=_("Part"), + many=False, + allow_null=False, + help_text=_('Base Part'), + label=_('Part'), ) location_path = serializers.ListField( - child=serializers.DictField(), - source='location.get_path', - read_only=True, + child=serializers.DictField(), source='location.get_path', read_only=True ) """ @@ -219,14 +211,16 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): write_only=True, required=False, allow_null=True, - help_text=_("Use pack size when adding: the quantity defined is the number of packs"), - label=("Use pack size"), + help_text=_( + 'Use pack size when adding: the quantity defined is the number of packs' + ), + label=('Use pack size'), ) def validate_part(self, part): """Ensure the provided Part instance is valid""" if part.virtual: - raise ValidationError(_("Stock item cannot be created for virtual parts")) + raise ValidationError(_('Stock item cannot be created for virtual parts')) return part @@ -257,54 +251,63 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): queryset = queryset.annotate( allocated=Coalesce( SubquerySum('sales_order_allocations__quantity'), Decimal(0) - ) + Coalesce( - SubquerySum('allocations__quantity'), Decimal(0) ) + + Coalesce(SubquerySum('allocations__quantity'), Decimal(0)) ) # Annotate the queryset with the number of tracking items - queryset = queryset.annotate( - tracking_items=SubqueryCount('tracking_info') - ) + queryset = queryset.annotate(tracking_items=SubqueryCount('tracking_info')) # Add flag to indicate if the StockItem has expired queryset = queryset.annotate( expired=Case( When( - StockItem.EXPIRED_FILTER, then=Value(True, output_field=BooleanField()), + StockItem.EXPIRED_FILTER, + then=Value(True, output_field=BooleanField()), ), - default=Value(False, output_field=BooleanField()) + default=Value(False, output_field=BooleanField()), ) ) # Add flag to indicate if the StockItem is stale stale_days = common.models.InvenTreeSetting.get_setting('STOCK_STALE_DAYS') stale_date = datetime.now().date() + timedelta(days=stale_days) - stale_filter = StockItem.IN_STOCK_FILTER & ~Q(expiry_date=None) & Q(expiry_date__lt=stale_date) + stale_filter = ( + StockItem.IN_STOCK_FILTER + & ~Q(expiry_date=None) + & Q(expiry_date__lt=stale_date) + ) queryset = queryset.annotate( stale=Case( - When( - stale_filter, then=Value(True, output_field=BooleanField()), - ), + When(stale_filter, then=Value(True, output_field=BooleanField())), default=Value(False, output_field=BooleanField()), ) ) # Annotate with the total number of "installed items" - queryset = queryset.annotate( - installed_items=SubqueryCount('installed_parts') - ) + queryset = queryset.annotate(installed_items=SubqueryCount('installed_parts')) return queryset status_text = serializers.CharField(source='get_status_display', read_only=True) # Optional detail fields, which can be appended via query parameters - supplier_part_detail = SupplierPartSerializer(source='supplier_part', supplier_detail=False, manufacturer_detail=False, part_detail=False, many=False, read_only=True) + supplier_part_detail = SupplierPartSerializer( + source='supplier_part', + supplier_detail=False, + manufacturer_detail=False, + part_detail=False, + many=False, + read_only=True, + ) part_detail = PartBriefSerializer(source='part', many=False, read_only=True) - location_detail = LocationBriefSerializer(source='location', many=False, read_only=True) - tests = StockItemTestResultSerializer(source='test_results', many=True, read_only=True) + location_detail = LocationBriefSerializer( + source='location', many=False, read_only=True + ) + tests = StockItemTestResultSerializer( + source='test_results', many=True, read_only=True + ) quantity = InvenTreeDecimalField() @@ -321,10 +324,16 @@ class StockItemSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): help_text=_('Purchase price of this stock item, per unit or pack'), ) - purchase_price_currency = InvenTreeCurrencySerializer(help_text=_('Purchase currency of this stock item')) + purchase_price_currency = InvenTreeCurrencySerializer( + help_text=_('Purchase currency of this stock item') + ) - purchase_order_reference = serializers.CharField(source='purchase_order.reference', read_only=True) - sales_order_reference = serializers.CharField(source='sales_order.reference', read_only=True) + purchase_order_reference = serializers.CharField( + source='purchase_order.reference', read_only=True + ) + sales_order_reference = serializers.CharField( + source='sales_order.reference', read_only=True + ) tags = TagListSerializerField(required=False) @@ -368,12 +377,7 @@ class SerializeStockItemSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'quantity', - 'serial_numbers', - 'destination', - 'notes', - ] + fields = ['quantity', 'serial_numbers', 'destination', 'notes'] quantity = serializers.IntegerField( min_value=0, @@ -387,11 +391,13 @@ class SerializeStockItemSerializer(serializers.Serializer): item = self.context['item'] if quantity < 0: - raise ValidationError(_("Quantity must be greater than zero")) + raise ValidationError(_('Quantity must be greater than zero')) if quantity > item.quantity: q = item.quantity - raise ValidationError(_(f"Quantity must not exceed available stock quantity ({q})")) + raise ValidationError( + _(f'Quantity must not exceed available stock quantity ({q})') + ) return quantity @@ -414,8 +420,8 @@ class SerializeStockItemSerializer(serializers.Serializer): notes = serializers.CharField( required=False, allow_blank=True, - label=_("Notes"), - help_text=_("Optional note field") + label=_('Notes'), + help_text=_('Optional note field'), ) def validate(self, data): @@ -425,7 +431,7 @@ class SerializeStockItemSerializer(serializers.Serializer): item = self.context['item'] if not item.part.trackable: - raise ValidationError(_("Serial numbers cannot be assigned to this part")) + raise ValidationError(_('Serial numbers cannot be assigned to this part')) # Ensure the serial numbers are valid! quantity = data['quantity'] @@ -433,24 +439,18 @@ class SerializeStockItemSerializer(serializers.Serializer): try: serials = InvenTree.helpers.extract_serial_numbers( - serial_numbers, - quantity, - item.part.get_latest_serial_number() + serial_numbers, quantity, item.part.get_latest_serial_number() ) except DjangoValidationError as e: - raise ValidationError({ - 'serial_numbers': e.messages, - }) + raise ValidationError({'serial_numbers': e.messages}) existing = item.part.find_conflicting_serial_numbers(serials) if len(existing) > 0: exists = ','.join([str(x) for x in existing]) - error = _('Serial numbers already exist') + ": " + exists + error = _('Serial numbers already exist') + ': ' + exists - raise ValidationError({ - 'serial_numbers': error, - }) + raise ValidationError({'serial_numbers': error}) return data @@ -465,7 +465,7 @@ class SerializeStockItemSerializer(serializers.Serializer): serials = InvenTree.helpers.extract_serial_numbers( data['serial_numbers'], data['quantity'], - item.part.get_latest_serial_number() + item.part.get_latest_serial_number(), ) item.serializeStock( @@ -508,7 +508,7 @@ class InstallStockItemSerializer(serializers.Serializer): """Validate the quantity value.""" if quantity < 1: - raise ValidationError(_("Quantity to install must be at least 1")) + raise ValidationError(_('Quantity to install must be at least 1')) return quantity @@ -516,14 +516,14 @@ class InstallStockItemSerializer(serializers.Serializer): """Validate the selected stock item.""" if not stock_item.in_stock: # StockItem must be in stock to be "installed" - raise ValidationError(_("Stock item is unavailable")) + raise ValidationError(_('Stock item is unavailable')) parent_item = self.context['item'] parent_part = parent_item.part # Check if the selected part is in the Bill of Materials of the parent item if not parent_part.check_if_part_in_bom(stock_item.part): - raise ValidationError(_("Selected part is not in the Bill of Materials")) + raise ValidationError(_('Selected part is not in the Bill of Materials')) return stock_item @@ -535,7 +535,9 @@ class InstallStockItemSerializer(serializers.Serializer): quantity = data.get('quantity', stock_item.quantity) if quantity > stock_item.quantity: - raise ValidationError(_("Quantity to install must not exceed available quantity")) + raise ValidationError( + _('Quantity to install must not exceed available quantity') + ) return data @@ -551,10 +553,7 @@ class InstallStockItemSerializer(serializers.Serializer): request = self.context['request'] parent_item.installStockItem( - stock_item, - quantity_to_install, - request.user, - note, + stock_item, quantity_to_install, request.user, note ) @@ -564,22 +563,22 @@ class UninstallStockItemSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'location', - 'note', - ] + fields = ['location', 'note'] location = serializers.PrimaryKeyRelatedField( queryset=StockLocation.objects.all(), - many=False, required=True, allow_null=False, + many=False, + required=True, + allow_null=False, label=_('Location'), - help_text=_('Destination location for uninstalled item') + help_text=_('Destination location for uninstalled item'), ) note = serializers.CharField( label=_('Notes'), help_text=_('Add transaction note (optional)'), - required=False, allow_blank=True, + required=False, + allow_blank=True, ) def save(self): @@ -593,11 +592,7 @@ class UninstallStockItemSerializer(serializers.Serializer): note = data.get('note', '') - item.uninstall_into_location( - location, - request.user, - note - ) + item.uninstall_into_location(location, request.user, note) class ConvertStockItemSerializer(serializers.Serializer): @@ -605,15 +600,16 @@ class ConvertStockItemSerializer(serializers.Serializer): class Meta: """Metaclass options""" - fields = [ - 'part', - ] + + fields = ['part'] part = serializers.PrimaryKeyRelatedField( queryset=part_models.Part.objects.all(), label=_('Part'), help_text=_('Select part to convert stock item into'), - many=False, required=True, allow_null=False + many=False, + required=True, + allow_null=False, ) def validate_part(self, part): @@ -622,7 +618,9 @@ class ConvertStockItemSerializer(serializers.Serializer): valid_options = stock_item.part.get_conversion_options() if part not in valid_options: - raise ValidationError(_("Selected part is not a valid option for conversion")) + raise ValidationError( + _('Selected part is not a valid option for conversion') + ) return part @@ -636,7 +634,9 @@ class ConvertStockItemSerializer(serializers.Serializer): stock_item = self.context['item'] if stock_item.supplier_part is not None: - raise ValidationError(_("Cannot convert stock item with assigned SupplierPart")) + raise ValidationError( + _('Cannot convert stock item with assigned SupplierPart') + ) return data @@ -658,14 +658,13 @@ class ReturnStockItemSerializer(serializers.Serializer): class Meta: """Metaclass options""" - fields = [ - 'location', - 'note', - ] + fields = ['location', 'note'] location = serializers.PrimaryKeyRelatedField( queryset=StockLocation.objects.all(), - many=False, required=True, allow_null=False, + many=False, + required=True, + allow_null=False, label=_('Location'), help_text=_('Destination location for returned item'), ) @@ -673,7 +672,8 @@ class ReturnStockItemSerializer(serializers.Serializer): notes = serializers.CharField( label=_('Notes'), help_text=_('Add transaction note (optional)'), - required=False, allow_blank=True, + required=False, + allow_blank=True, ) def save(self): @@ -686,11 +686,7 @@ class ReturnStockItemSerializer(serializers.Serializer): location = data['location'] notes = data.get('notes', '') - item.return_from_customer( - location, - user=request.user, - notes=notes - ) + item.return_from_customer(location, user=request.user, notes=notes) class StockChangeStatusSerializer(serializers.Serializer): @@ -698,11 +694,8 @@ class StockChangeStatusSerializer(serializers.Serializer): class Meta: """Metaclass options""" - fields = [ - 'items', - 'status', - 'note', - ] + + fields = ['items', 'status', 'note'] items = serializers.PrimaryKeyRelatedField( queryset=StockItem.objects.all(), @@ -716,7 +709,7 @@ class StockChangeStatusSerializer(serializers.Serializer): def validate_items(self, items): """Validate the selected stock items""" if len(items) == 0: - raise ValidationError(_("No stock items selected")) + raise ValidationError(_('No stock items selected')) return items @@ -729,7 +722,8 @@ class StockChangeStatusSerializer(serializers.Serializer): note = serializers.CharField( label=_('Notes'), help_text=_('Add transaction note (optional)'), - required=False, allow_blank=True, + required=False, + allow_blank=True, ) @transaction.atomic @@ -748,9 +742,7 @@ class StockChangeStatusSerializer(serializers.Serializer): items_to_update = [] transaction_notes = [] - deltas = { - 'status': status, - } + deltas = {'status': status} now = datetime.now() @@ -792,26 +784,16 @@ class StockLocationTypeSerializer(InvenTree.serializers.InvenTreeModelSerializer """Serializer metaclass.""" model = StockLocationType - fields = [ - "pk", - "name", - "description", - "icon", - "location_count", - ] + fields = ['pk', 'name', 'description', 'icon', 'location_count'] - read_only_fields = [ - "location_count", - ] + read_only_fields = ['location_count'] location_count = serializers.IntegerField(read_only=True) @staticmethod def annotate_queryset(queryset): """Add location count to each location type.""" - return queryset.annotate( - location_count=Count("stock_locations") - ) + return queryset.annotate(location_count=Count('stock_locations')) class LocationTreeSerializer(InvenTree.serializers.InvenTreeModelSerializer): @@ -821,13 +803,7 @@ class LocationTreeSerializer(InvenTree.serializers.InvenTreeModelSerializer): """Metaclass options.""" model = StockLocation - fields = [ - 'pk', - 'name', - 'parent', - 'icon', - 'structural', - ] + fields = ['pk', 'name', 'parent', 'icon', 'structural'] class LocationSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): @@ -858,10 +834,7 @@ class LocationSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): 'tags', ] - read_only_fields = [ - 'barcode_hash', - 'icon', - ] + read_only_fields = ['barcode_hash', 'icon'] def __init__(self, *args, **kwargs): """Optionally add or remove extra fields""" @@ -876,9 +849,7 @@ class LocationSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): def annotate_queryset(queryset): """Annotate extra information to the queryset""" # Annotate the number of stock items which exist in this category (including subcategories) - queryset = queryset.annotate( - items=stock.filters.annotate_location_items() - ) + queryset = queryset.annotate(items=stock.filters.annotate_location_items()) return queryset @@ -891,19 +862,21 @@ class LocationSerializer(InvenTree.serializers.InvenTreeTagModelSerializer): tags = TagListSerializerField(required=False) path = serializers.ListField( - child=serializers.DictField(), - source='get_path', - read_only=True, + child=serializers.DictField(), source='get_path', read_only=True ) # explicitly set this field, so it gets included for AutoSchema icon = serializers.CharField(read_only=True) # Detail for location type - location_type_detail = StockLocationTypeSerializer(source="location_type", read_only=True, many=False) + location_type_detail = StockLocationTypeSerializer( + source='location_type', read_only=True, many=False + ) -class StockItemAttachmentSerializer(InvenTree.serializers.InvenTreeAttachmentSerializer): +class StockItemAttachmentSerializer( + InvenTree.serializers.InvenTreeAttachmentSerializer +): """Serializer for StockItemAttachment model.""" class Meta: @@ -912,7 +885,7 @@ class StockItemAttachmentSerializer(InvenTree.serializers.InvenTreeAttachmentSer model = StockItemAttachment fields = InvenTree.serializers.InvenTreeAttachmentSerializer.attachment_fields([ - 'stock_item', + 'stock_item' ]) @@ -936,12 +909,7 @@ class StockTrackingSerializer(InvenTree.serializers.InvenTreeModelSerializer): 'user_detail', ] - read_only_fields = [ - 'date', - 'user', - 'label', - 'tracking_type', - ] + read_only_fields = ['date', 'user', 'label', 'tracking_type'] def __init__(self, *args, **kwargs): """Add detail fields.""" @@ -960,7 +928,9 @@ class StockTrackingSerializer(InvenTree.serializers.InvenTreeModelSerializer): item_detail = StockItemSerializerBrief(source='item', many=False, read_only=True) - user_detail = InvenTree.serializers.UserSerializer(source='user', many=False, read_only=True) + user_detail = InvenTree.serializers.UserSerializer( + source='user', many=False, read_only=True + ) deltas = serializers.JSONField(read_only=True) @@ -977,9 +947,7 @@ class StockAssignmentItemSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'item', - ] + fields = ['item'] item = serializers.PrimaryKeyRelatedField( queryset=StockItem.objects.all(), @@ -999,19 +967,19 @@ class StockAssignmentItemSerializer(serializers.Serializer): """ # The item must currently be "in stock" if not item.in_stock: - raise ValidationError(_("Item must be in stock")) + raise ValidationError(_('Item must be in stock')) # The base part must be "salable" if not item.part.salable: - raise ValidationError(_("Part must be salable")) + raise ValidationError(_('Part must be salable')) # The item must not be allocated to a sales order if item.sales_order_allocations.count() > 0: - raise ValidationError(_("Item is allocated to a sales order")) + raise ValidationError(_('Item is allocated to a sales order')) # The item must not be allocated to a build order if item.allocations.count() > 0: - raise ValidationError(_("Item is allocated to a build order")) + raise ValidationError(_('Item is allocated to a build order')) return item @@ -1025,16 +993,9 @@ class StockAssignmentSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'items', - 'customer', - 'notes', - ] + fields = ['items', 'customer', 'notes'] - items = StockAssignmentItemSerializer( - many=True, - required=True, - ) + items = StockAssignmentItemSerializer(many=True, required=True) customer = serializers.PrimaryKeyRelatedField( queryset=company.models.Company.objects.all(), @@ -1066,7 +1027,7 @@ class StockAssignmentSerializer(serializers.Serializer): items = data.get('items', []) if len(items) == 0: - raise ValidationError(_("A list of stock items must be provided")) + raise ValidationError(_('A list of stock items must be provided')) return data @@ -1084,14 +1045,9 @@ class StockAssignmentSerializer(serializers.Serializer): with transaction.atomic(): for item in items: - stock_item = item['item'] - stock_item.allocateToCustomer( - customer, - user=user, - notes=notes, - ) + stock_item.allocateToCustomer(customer, user=user, notes=notes) class StockMergeItemSerializer(serializers.Serializer): @@ -1103,9 +1059,7 @@ class StockMergeItemSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'item', - ] + fields = ['item'] item = serializers.PrimaryKeyRelatedField( queryset=StockItem.objects.all(), @@ -1137,10 +1091,7 @@ class StockMergeSerializer(serializers.Serializer): 'allow_mismatched_status', ] - items = StockMergeItemSerializer( - many=True, - required=True, - ) + items = StockMergeItemSerializer(many=True, required=True) location = serializers.PrimaryKeyRelatedField( queryset=StockLocation.objects.all(), @@ -1203,7 +1154,9 @@ class StockMergeSerializer(serializers.Serializer): item.can_merge( raise_error=True, other=base_item, - allow_mismatched_suppliers=data.get('allow_mismatched_suppliers', False), + allow_mismatched_suppliers=data.get( + 'allow_mismatched_suppliers', False + ), allow_mismatched_status=data.get('allow_mismatched_status', False), ) @@ -1233,7 +1186,7 @@ class StockMergeSerializer(serializers.Serializer): allow_mismatched_status=data.get('allow_mismatched_status', False), user=user, location=data['location'], - notes=data.get('notes', None) + notes=data.get('notes', None), ) @@ -1255,10 +1208,7 @@ class StockAdjustmentItemSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'item', - 'quantity' - ] + fields = ['item', 'quantity'] pk = serializers.PrimaryKeyRelatedField( queryset=StockItem.objects.all(), @@ -1266,19 +1216,17 @@ class StockAdjustmentItemSerializer(serializers.Serializer): allow_null=False, required=True, label='stock_item', - help_text=_('StockItem primary key value') + help_text=_('StockItem primary key value'), ) quantity = serializers.DecimalField( - max_digits=15, - decimal_places=5, - min_value=0, - required=True + max_digits=15, decimal_places=5, min_value=0, required=True ) batch = serializers.CharField( max_length=100, - required=False, allow_blank=True, + required=False, + allow_blank=True, label=_('Batch Code'), help_text=_('Batch code for this stock item'), ) @@ -1288,12 +1236,14 @@ class StockAdjustmentItemSerializer(serializers.Serializer): default=InvenTree.status_codes.StockStatus.OK.value, label=_('Status'), help_text=_('Stock item status code'), - required=False, allow_blank=True, + required=False, + allow_blank=True, ) packaging = serializers.CharField( max_length=50, - required=False, allow_blank=True, + required=False, + allow_blank=True, label=_('Packaging'), help_text=_('Packaging this stock item is stored in'), ) @@ -1305,18 +1255,15 @@ class StockAdjustmentSerializer(serializers.Serializer): class Meta: """Metaclass options.""" - fields = [ - 'items', - 'notes', - ] + fields = ['items', 'notes'] items = StockAdjustmentItemSerializer(many=True) notes = serializers.CharField( required=False, allow_blank=True, - label=_("Notes"), - help_text=_("Stock transaction notes"), + label=_('Notes'), + help_text=_('Stock transaction notes'), ) def validate(self, data): @@ -1326,7 +1273,7 @@ class StockAdjustmentSerializer(serializers.Serializer): items = data.get('items', []) if len(items) == 0: - raise ValidationError(_("A list of stock items must be provided")) + raise ValidationError(_('A list of stock items must be provided')) return data @@ -1344,15 +1291,10 @@ class StockCountSerializer(StockAdjustmentSerializer): with transaction.atomic(): for item in items: - stock_item = item['pk'] quantity = item['quantity'] - stock_item.stocktake( - quantity, - request.user, - notes=notes - ) + stock_item.stocktake(quantity, request.user, notes=notes) class StockAddSerializer(StockAdjustmentSerializer): @@ -1367,15 +1309,10 @@ class StockAddSerializer(StockAdjustmentSerializer): with transaction.atomic(): for item in data['items']: - stock_item = item['pk'] quantity = item['quantity'] - stock_item.add_stock( - quantity, - request.user, - notes=notes - ) + stock_item.add_stock(quantity, request.user, notes=notes) class StockRemoveSerializer(StockAdjustmentSerializer): @@ -1390,15 +1327,10 @@ class StockRemoveSerializer(StockAdjustmentSerializer): with transaction.atomic(): for item in data['items']: - stock_item = item['pk'] quantity = item['quantity'] - stock_item.take_stock( - quantity, - request.user, - notes=notes - ) + stock_item.take_stock(quantity, request.user, notes=notes) class StockTransferSerializer(StockAdjustmentSerializer): @@ -1407,11 +1339,7 @@ class StockTransferSerializer(StockAdjustmentSerializer): class Meta: """Metaclass options.""" - fields = [ - 'items', - 'notes', - 'location', - ] + fields = ['items', 'notes', 'location'] location = serializers.PrimaryKeyRelatedField( queryset=StockLocation.objects.all(), @@ -1434,7 +1362,6 @@ class StockTransferSerializer(StockAdjustmentSerializer): with transaction.atomic(): for item in items: - # Required fields stock_item = item['pk'] quantity = item['quantity'] @@ -1447,9 +1374,5 @@ class StockTransferSerializer(StockAdjustmentSerializer): kwargs[field_name] = item[field_name] stock_item.move( - location, - notes, - request.user, - quantity=quantity, - **kwargs + location, notes, request.user, quantity=quantity, **kwargs ) diff --git a/InvenTree/stock/test_api.py b/InvenTree/stock/test_api.py index 0243aaa405..349e4574e9 100644 --- a/InvenTree/stock/test_api.py +++ b/InvenTree/stock/test_api.py @@ -20,8 +20,12 @@ from common.models import InvenTreeSetting from InvenTree.status_codes import StockHistoryCode, StockStatus from InvenTree.unit_test import InvenTreeAPITestCase from part.models import Part -from stock.models import (StockItem, StockItemTestResult, StockLocation, - StockLocationType) +from stock.models import ( + StockItem, + StockItemTestResult, + StockLocation, + StockLocationType, +) class StockAPITestCase(InvenTreeAPITestCase): @@ -69,17 +73,57 @@ class StockLocationTest(StockAPITestCase): ({'parent': 1, 'cascade': True}, 2, 'Filter by parent, cascading'), ({'cascade': True, 'depth': 0}, 8, 'Cascade with no parent, depth=0'), ({'cascade': False, 'depth': 10}, 8, 'Cascade with no parent, depth=0'), - ({'parent': 'null', 'cascade': True, 'depth': 0}, 7, 'Cascade with null parent, depth=0'), - ({'parent': 'null', 'cascade': True, 'depth': 10}, 8, 'Cascade with null parent and bigger depth'), - ({'parent': 'null', 'cascade': False, 'depth': 10}, 3, 'No cascade even with depth specified with null parent'), - ({'parent': 1, 'cascade': False, 'depth': 0}, 2, 'Dont cascade with depth=0 and parent'), - ({'parent': 1, 'cascade': True, 'depth': 0}, 2, 'Cascade with depth=0 and parent'), - ({'parent': 1, 'cascade': False, 'depth': 1}, 2, 'Dont cascade even with depth=1 specified with parent'), - ({'parent': 1, 'cascade': True, 'depth': 1}, 2, 'Cascade with depth=1 with parent'), - ({'parent': 1, 'cascade': True, 'depth': 'abcdefg'}, 2, 'Cascade with invalid depth and parent'), + ( + {'parent': 'null', 'cascade': True, 'depth': 0}, + 7, + 'Cascade with null parent, depth=0', + ), + ( + {'parent': 'null', 'cascade': True, 'depth': 10}, + 8, + 'Cascade with null parent and bigger depth', + ), + ( + {'parent': 'null', 'cascade': False, 'depth': 10}, + 3, + 'No cascade even with depth specified with null parent', + ), + ( + {'parent': 1, 'cascade': False, 'depth': 0}, + 2, + 'Dont cascade with depth=0 and parent', + ), + ( + {'parent': 1, 'cascade': True, 'depth': 0}, + 2, + 'Cascade with depth=0 and parent', + ), + ( + {'parent': 1, 'cascade': False, 'depth': 1}, + 2, + 'Dont cascade even with depth=1 specified with parent', + ), + ( + {'parent': 1, 'cascade': True, 'depth': 1}, + 2, + 'Cascade with depth=1 with parent', + ), + ( + {'parent': 1, 'cascade': True, 'depth': 'abcdefg'}, + 2, + 'Cascade with invalid depth and parent', + ), ({'parent': 42}, 8, 'Should return everything if parent_pk is not valid'), - ({'parent': 'null', 'exclude_tree': 1, 'cascade': True}, 5, 'Should return everything except tree with pk=1'), - ({'parent': 'null', 'exclude_tree': 42, 'cascade': True}, 8, 'Should return everything because exclude_tree=42 is no valid pk'), + ( + {'parent': 'null', 'exclude_tree': 1, 'cascade': True}, + 5, + 'Should return everything except tree with pk=1', + ), + ( + {'parent': 'null', 'exclude_tree': 42, 'cascade': True}, + 8, + 'Should return everything because exclude_tree=42 is no valid pk', + ), ] for params, res_len, description in test_cases: @@ -105,7 +149,9 @@ class StockLocationTest(StockAPITestCase): response = self.get(self.list_url, expected_code=200) for result in response.data: for f in fields: - self.assertIn(f, result, f'"{f}" is missing in result of StockLocation list') + self.assertIn( + f, result, f'"{f}" is missing in result of StockLocation list' + ) def test_add(self): """Test adding StockLocation.""" @@ -113,7 +159,7 @@ class StockLocationTest(StockAPITestCase): data = { 'parent': 1, 'name': 'Location', - 'description': 'Another location for stock' + 'description': 'Another location for stock', } self.post(self.list_url, data, expected_code=201) @@ -122,14 +168,15 @@ class StockLocationTest(StockAPITestCase): """Test stock location deletion with different parameters""" class Target(IntEnum): - move_sub_locations_to_parent_move_stockitems_to_parent = 0, - move_sub_locations_to_parent_delete_stockitems = 1, - delete_sub_locations_move_stockitems_to_parent = 2, - delete_sub_locations_delete_stockitems = 3, + move_sub_locations_to_parent_move_stockitems_to_parent = (0,) + move_sub_locations_to_parent_delete_stockitems = (1,) + delete_sub_locations_move_stockitems_to_parent = (2,) + delete_sub_locations_delete_stockitems = (3,) # First, construct a set of template / variant parts part = Part.objects.create( - name='Part for stock item creation', description='Part for stock item creation', + name='Part for stock item creation', + description='Part for stock item creation', category=None, is_template=False, ) @@ -138,16 +185,22 @@ class StockLocationTest(StockAPITestCase): delete_sub_locations: bool = False delete_stock_items: bool = False - if i in (Target.move_sub_locations_to_parent_delete_stockitems, Target.delete_sub_locations_delete_stockitems): + if i in ( + Target.move_sub_locations_to_parent_delete_stockitems, + Target.delete_sub_locations_delete_stockitems, + ): delete_stock_items = True - if i in (Target.delete_sub_locations_move_stockitems_to_parent, Target.delete_sub_locations_delete_stockitems): + if i in ( + Target.delete_sub_locations_move_stockitems_to_parent, + Target.delete_sub_locations_delete_stockitems, + ): delete_sub_locations = True # Create a parent stock location parent_stock_location = StockLocation.objects.create( name='Parent stock location', description='This is the parent stock location where the sub categories and stock items are moved to', - parent=None + parent=None, ) stocklocation_count_before = StockLocation.objects.count() @@ -157,38 +210,42 @@ class StockLocationTest(StockAPITestCase): stock_location_to_delete = StockLocation.objects.create( name='Stock location to delete', description='This is the stock location to be deleted', - parent=parent_stock_location + parent=parent_stock_location, ) - url = reverse('api-location-detail', kwargs={'pk': stock_location_to_delete.id}) + url = reverse( + 'api-location-detail', kwargs={'pk': stock_location_to_delete.id} + ) stock_items = [] # Create stock items in the location to be deleted for jj in range(3): - stock_items.append(StockItem.objects.create( - batch=f"Batch xyz {jj}", - location=stock_location_to_delete, - part=part - )) + stock_items.append( + StockItem.objects.create( + batch=f'Batch xyz {jj}', + location=stock_location_to_delete, + part=part, + ) + ) child_stock_locations = [] child_stock_locations_items = [] # Create sub location under the stock location to be deleted for ii in range(3): child = StockLocation.objects.create( - name=f"Sub-location {ii}", - description="A sub-location of the deleted stock location", - parent=stock_location_to_delete + name=f'Sub-location {ii}', + description='A sub-location of the deleted stock location', + parent=stock_location_to_delete, ) child_stock_locations.append(child) # Create stock items in the sub locations for jj in range(3): - child_stock_locations_items.append(StockItem.objects.create( - batch=f"B xyz {jj}", - part=part, - location=child - )) + child_stock_locations_items.append( + StockItem.objects.create( + batch=f'B xyz {jj}', part=part, location=child + ) + ) # Delete the created stock location params = {} @@ -196,21 +253,22 @@ class StockLocationTest(StockAPITestCase): params['delete_stock_items'] = '1' if delete_sub_locations: params['delete_sub_locations'] = '1' - response = self.delete( - url, - params, - expected_code=204, - ) + response = self.delete(url, params, expected_code=204) self.assertEqual(response.status_code, 204) if delete_stock_items: if i == Target.delete_sub_locations_delete_stockitems: # Check if all sub-categories deleted - self.assertEqual(StockItem.objects.count(), stock_location_count_before) + self.assertEqual( + StockItem.objects.count(), stock_location_count_before + ) elif i == Target.move_sub_locations_to_parent_delete_stockitems: # Check if all stock locations deleted - self.assertEqual(StockItem.objects.count(), stock_location_count_before + len(child_stock_locations_items)) + self.assertEqual( + StockItem.objects.count(), + stock_location_count_before + len(child_stock_locations_items), + ) else: # Stock locations moved to the parent location for stock_item in stock_items: @@ -220,11 +278,15 @@ class StockLocationTest(StockAPITestCase): if delete_sub_locations: for child_stock_location_item in child_stock_locations_items: child_stock_location_item.refresh_from_db() - self.assertEqual(child_stock_location_item.location, parent_stock_location) + self.assertEqual( + child_stock_location_item.location, parent_stock_location + ) if delete_sub_locations: # Check if all sub-locations are deleted - self.assertEqual(StockLocation.objects.count(), stocklocation_count_before) + self.assertEqual( + StockLocation.objects.count(), stocklocation_count_before + ) else: # Check if all sub-locations moved to the parent for child in child_stock_locations: @@ -244,7 +306,7 @@ class StockLocationTest(StockAPITestCase): name='Structural stock location', description='This is the structural stock location', parent=None, - structural=True + structural=True, ) stock_item_count_before = StockItem.objects.count() @@ -252,8 +314,8 @@ class StockLocationTest(StockAPITestCase): # Make sure that we get an error if we try to create a stock item in the structural location with self.assertRaises(ValidationError): item = StockItem.objects.create( - batch="Stock item which shall not be created", - location=structural_location + batch='Stock item which shall not be created', + location=structural_location, ) # Ensure that the stock item really did not get created in the structural location @@ -264,21 +326,20 @@ class StockLocationTest(StockAPITestCase): name='Non-structural category', description='This is a non-structural category', parent=None, - structural=False + structural=False, ) # Construct a part for stock item creation part = Part.objects.create( - name='Part for stock item creation', description='Part for stock item creation', + name='Part for stock item creation', + description='Part for stock item creation', category=None, is_template=False, ) # Create the test stock item located to a non-structural category item = StockItem.objects.create( - batch="BBB", - location=non_structural_location, - part=part + batch='BBB', location=non_structural_location, part=part ) # Try to relocate it to a structural location @@ -297,46 +358,99 @@ class StockLocationTest(StockAPITestCase): def test_stock_location_icon(self): """Test stock location icon inheritance from StockLocationType.""" - parent_location = StockLocation.objects.create(name="Parent location") + parent_location = StockLocation.objects.create(name='Parent location') - location_type = StockLocationType.objects.create(name="Box", description="This is a very cool type of box", icon="fas fa-box") - location = StockLocation.objects.create(name="Test location", custom_icon="fas fa-microscope", location_type=location_type, parent=parent_location) + location_type = StockLocationType.objects.create( + name='Box', description='This is a very cool type of box', icon='fas fa-box' + ) + location = StockLocation.objects.create( + name='Test location', + custom_icon='fas fa-microscope', + location_type=location_type, + parent=parent_location, + ) - res = self.get(self.list_url, {"parent": str(parent_location.pk)}, expected_code=200).json() - self.assertEqual(res[0]["icon"], "fas fa-microscope", "Custom icon from location should be returned") + res = self.get( + self.list_url, {'parent': str(parent_location.pk)}, expected_code=200 + ).json() + self.assertEqual( + res[0]['icon'], + 'fas fa-microscope', + 'Custom icon from location should be returned', + ) - location.custom_icon = "" + location.custom_icon = '' location.save() - res = self.get(self.list_url, {"parent": str(parent_location.pk)}, expected_code=200).json() - self.assertEqual(res[0]["icon"], "fas fa-box", "Custom icon is None, therefore it should inherit the location type icon") + res = self.get( + self.list_url, {'parent': str(parent_location.pk)}, expected_code=200 + ).json() + self.assertEqual( + res[0]['icon'], + 'fas fa-box', + 'Custom icon is None, therefore it should inherit the location type icon', + ) - location_type.icon = "" + location_type.icon = '' location_type.save() - res = self.get(self.list_url, {"parent": str(parent_location.pk)}, expected_code=200).json() - self.assertEqual(res[0]["icon"], "", "Custom icon and location type icon is None, None should be returned") + res = self.get( + self.list_url, {'parent': str(parent_location.pk)}, expected_code=200 + ).json() + self.assertEqual( + res[0]['icon'], + '', + 'Custom icon and location type icon is None, None should be returned', + ) def test_stock_location_list_filter(self): """Test stock location list filters.""" - parent_location = StockLocation.objects.create(name="Parent location") + parent_location = StockLocation.objects.create(name='Parent location') - location_type = StockLocationType.objects.create(name="Box", description="This is a very cool type of box", icon="fas fa-box") - location_type2 = StockLocationType.objects.create(name="Shelf", description="This is a very cool type of shelf", icon="fas fa-shapes") - StockLocation.objects.create(name="Test location w. type", location_type=location_type, parent=parent_location) - StockLocation.objects.create(name="Test location w. type 2", parent=parent_location, location_type=location_type2) - StockLocation.objects.create(name="Test location wo type", parent=parent_location) + location_type = StockLocationType.objects.create( + name='Box', description='This is a very cool type of box', icon='fas fa-box' + ) + location_type2 = StockLocationType.objects.create( + name='Shelf', + description='This is a very cool type of shelf', + icon='fas fa-shapes', + ) + StockLocation.objects.create( + name='Test location w. type', + location_type=location_type, + parent=parent_location, + ) + StockLocation.objects.create( + name='Test location w. type 2', + parent=parent_location, + location_type=location_type2, + ) + StockLocation.objects.create( + name='Test location wo type', parent=parent_location + ) - res = self.get(self.list_url, {"parent": str(parent_location.pk), "has_location_type": "1"}, expected_code=200).json() + res = self.get( + self.list_url, + {'parent': str(parent_location.pk), 'has_location_type': '1'}, + expected_code=200, + ).json() self.assertEqual(len(res), 2) - self.assertEqual(res[0]["name"], "Test location w. type") - self.assertEqual(res[1]["name"], "Test location w. type 2") + self.assertEqual(res[0]['name'], 'Test location w. type') + self.assertEqual(res[1]['name'], 'Test location w. type 2') - res = self.get(self.list_url, {"parent": str(parent_location.pk), "location_type": str(location_type.pk)}, expected_code=200).json() + res = self.get( + self.list_url, + {'parent': str(parent_location.pk), 'location_type': str(location_type.pk)}, + expected_code=200, + ).json() self.assertEqual(len(res), 1) - self.assertEqual(res[0]["name"], "Test location w. type") + self.assertEqual(res[0]['name'], 'Test location w. type') - res = self.get(self.list_url, {"parent": str(parent_location.pk), "has_location_type": "0"}, expected_code=200).json() + res = self.get( + self.list_url, + {'parent': str(parent_location.pk), 'has_location_type': '0'}, + expected_code=200, + ).json() self.assertEqual(len(res), 1) - self.assertEqual(res[0]["name"], "Test location wo type") + self.assertEqual(res[0]['name'], 'Test location wo type') class StockLocationTypeTest(StockAPITestCase): @@ -347,35 +461,58 @@ class StockLocationTypeTest(StockAPITestCase): def test_list(self): """Test that the list endpoint works as expected.""" location_types = [ - StockLocationType.objects.create(name="Type 1", description="Type 1 desc", icon="fas fa-box"), - StockLocationType.objects.create(name="Type 2", description="Type 2 desc", icon="fas fa-box"), - StockLocationType.objects.create(name="Type 3", description="Type 3 desc", icon="fas fa-box"), + StockLocationType.objects.create( + name='Type 1', description='Type 1 desc', icon='fas fa-box' + ), + StockLocationType.objects.create( + name='Type 2', description='Type 2 desc', icon='fas fa-box' + ), + StockLocationType.objects.create( + name='Type 3', description='Type 3 desc', icon='fas fa-box' + ), ] - StockLocation.objects.create(name="Loc 1", location_type=location_types[0]) - StockLocation.objects.create(name="Loc 2", location_type=location_types[0]) - StockLocation.objects.create(name="Loc 3", location_type=location_types[1]) + StockLocation.objects.create(name='Loc 1', location_type=location_types[0]) + StockLocation.objects.create(name='Loc 2', location_type=location_types[0]) + StockLocation.objects.create(name='Loc 3', location_type=location_types[1]) res = self.get(self.list_url, expected_code=200).json() self.assertEqual(len(res), 3) - self.assertCountEqual([r["location_count"] for r in res], [2, 1, 0]) + self.assertCountEqual([r['location_count'] for r in res], [2, 1, 0]) def test_delete(self): """Test that we can delete a location type via API.""" - location_type = StockLocationType.objects.create(name="Type 1", description="Type 1 desc", icon="fas fa-box") - self.delete(reverse('api-location-type-detail', kwargs={"pk": location_type.pk}), expected_code=204) + location_type = StockLocationType.objects.create( + name='Type 1', description='Type 1 desc', icon='fas fa-box' + ) + self.delete( + reverse('api-location-type-detail', kwargs={'pk': location_type.pk}), + expected_code=204, + ) self.assertEqual(StockLocationType.objects.count(), 0) def test_create(self): """Test that we can create a location type via API.""" - self.post(self.list_url, {"name": "Test Type 1", "description": "Test desc 1", "icon": "fas fa-box"}, expected_code=201) - self.assertIsNotNone(StockLocationType.objects.filter(name="Test Type 1").first()) + self.post( + self.list_url, + {'name': 'Test Type 1', 'description': 'Test desc 1', 'icon': 'fas fa-box'}, + expected_code=201, + ) + self.assertIsNotNone( + StockLocationType.objects.filter(name='Test Type 1').first() + ) def test_update(self): """Test that we can update a location type via API.""" - location_type = StockLocationType.objects.create(name="Type 1", description="Type 1 desc", icon="fas fa-box") - res = self.patch(reverse('api-location-type-detail', kwargs={"pk": location_type.pk}), {"icon": "fas fa-shapes"}, expected_code=200).json() - self.assertEqual(res["icon"], "fas fa-shapes") + location_type = StockLocationType.objects.create( + name='Type 1', description='Type 1 desc', icon='fas fa-box' + ) + res = self.patch( + reverse('api-location-type-detail', kwargs={'pk': location_type.pk}), + {'icon': 'fas fa-shapes'}, + expected_code=200, + ).json() + self.assertEqual(res['icon'], 'fas fa-shapes') class StockItemListTest(StockAPITestCase): @@ -404,14 +541,18 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(len(response.data), StockItem.objects.count()) # Filter with "cascade=False" for the top-level location - response = self.get(self.list_url, {'location': 'null', 'cascade': False}, expected_code=200) + response = self.get( + self.list_url, {'location': 'null', 'cascade': False}, expected_code=200 + ) self.assertTrue(len(response.data) < StockItem.objects.count()) for result in response.data: self.assertIsNone(result['location']) # Filter with "cascade=True" - response = self.get(self.list_url, {'location': 'null', 'cascade': True}, expected_code=200) + response = self.get( + self.list_url, {'location': 'null', 'cascade': True}, expected_code=200 + ) self.assertEqual(len(response.data), StockItem.objects.count()) def test_get_stock_list(self): @@ -432,7 +573,7 @@ class StockItemListTest(StockAPITestCase): def test_filter_by_ipn(self): """Filter StockItem by IPN reference.""" - response = self.get_stock(IPN="R.CH") + response = self.get_stock(IPN='R.CH') self.assertEqual(len(response), 3) def test_filter_by_location(self): @@ -600,7 +741,9 @@ class StockItemListTest(StockAPITestCase): self.assertEqual(response.status_code, 200) - self.assertTrue(isinstance(response, django.http.response.StreamingHttpResponse)) + self.assertTrue( + isinstance(response, django.http.response.StreamingHttpResponse) + ) file_object = io.StringIO(response.getvalue().decode('utf-8')) @@ -629,9 +772,7 @@ class StockItemListTest(StockAPITestCase): for h in headers: self.assertIn(h, dataset.headers) - excluded_headers = [ - 'metadata', - ] + excluded_headers = ['metadata'] for h in excluded_headers: self.assertNotIn(h, dataset.headers) @@ -655,9 +796,15 @@ class StockItemListTest(StockAPITestCase): """ # Create a build order to allocate to - assembly = part.models.Part.objects.create(name='F Assembly', description='Assembly for filter test', assembly=True) - component = part.models.Part.objects.create(name='F Component', description='Component for filter test', component=True) - bom_item = part.models.BomItem.objects.create(part=assembly, sub_part=component, quantity=10) + assembly = part.models.Part.objects.create( + name='F Assembly', description='Assembly for filter test', assembly=True + ) + component = part.models.Part.objects.create( + name='F Component', description='Component for filter test', component=True + ) + bom_item = part.models.BomItem.objects.create( + part=assembly, sub_part=component, quantity=10 + ) # Create two build orders bo_1 = build.models.Build.objects.create(part=assembly, quantity=10) @@ -666,78 +813,67 @@ class StockItemListTest(StockAPITestCase): # Test that two distinct build line items are created automatically self.assertEqual(bo_1.build_lines.count(), 1) self.assertEqual(bo_2.build_lines.count(), 1) - self.assertEqual(build.models.BuildLine.objects.filter(bom_item=bom_item).count(), 2) + self.assertEqual( + build.models.BuildLine.objects.filter(bom_item=bom_item).count(), 2 + ) build_line_1 = bo_1.build_lines.first() build_line_2 = bo_2.build_lines.first() # Allocate stock location = StockLocation.objects.first() - stock_1 = StockItem.objects.create(part=component, quantity=100, location=location) - stock_2 = StockItem.objects.create(part=component, quantity=100, location=location) - stock_3 = StockItem.objects.create(part=component, quantity=100, location=location) + stock_1 = StockItem.objects.create( + part=component, quantity=100, location=location + ) + stock_2 = StockItem.objects.create( + part=component, quantity=100, location=location + ) + stock_3 = StockItem.objects.create( + part=component, quantity=100, location=location + ) # Allocate stock_1 to two build orders build.models.BuildItem.objects.create( - stock_item=stock_1, - build_line=build_line_1, - quantity=5 + stock_item=stock_1, build_line=build_line_1, quantity=5 ) build.models.BuildItem.objects.create( - stock_item=stock_1, - build_line=build_line_2, - quantity=5 + stock_item=stock_1, build_line=build_line_2, quantity=5 ) # Allocate stock_2 to 1 build orders build.models.BuildItem.objects.create( - stock_item=stock_2, - build_line=build_line_1, - quantity=5 + stock_item=stock_2, build_line=build_line_1, quantity=5 ) url = reverse('api-stock-list') # 3 items when just filtering by part response = self.get( - url, - { - "part": component.pk, - "in_stock": True - }, - expected_code=200 + url, {'part': component.pk, 'in_stock': True}, expected_code=200 ) self.assertEqual(len(response.data), 3) # 1 item when filtering by "not allocated" response = self.get( url, - { - "part": component.pk, - "in_stock": True, - "allocated": False, - }, - expected_code=200 + {'part': component.pk, 'in_stock': True, 'allocated': False}, + expected_code=200, ) self.assertEqual(len(response.data), 1) - self.assertEqual(response.data[0]["pk"], stock_3.pk) + self.assertEqual(response.data[0]['pk'], stock_3.pk) # 2 items when filtering by "allocated" response = self.get( url, - { - "part": component.pk, - "in_stock": True, - "allocated": True, - }, - expected_code=200 + {'part': component.pk, 'in_stock': True, 'allocated': True}, + expected_code=200, ) self.assertEqual(len(response.data), 2) - ids = [item["pk"] for item in response.data] + ids = [item['pk'] for item in response.data] self.assertIn(stock_1.pk, ids) self.assertIn(stock_2.pk, ids) @@ -755,11 +891,8 @@ class StockItemListTest(StockAPITestCase): prt = Part.objects.first() StockItem.objects.bulk_create([ - StockItem( - part=prt, - quantity=1, - level=0, tree_id=0, lft=0, rght=0, - ) for _ in range(100) + StockItem(part=prt, quantity=1, level=0, tree_id=0, lft=0, rght=0) + for _ in range(100) ]) # List *all* stock items @@ -798,12 +931,7 @@ class StockItemTest(StockAPITestCase): # The part 'R_4K7_0603' (pk=4) has a default location specified response = self.post( - self.list_url, - data={ - 'part': 4, - 'quantity': 10 - }, - expected_code=201 + self.list_url, data={'part': 4, 'quantity': 10}, expected_code=201 ) self.assertEqual(response.data['location'], 2) @@ -812,12 +940,8 @@ class StockItemTest(StockAPITestCase): response = self.post( self.list_url, - data={ - 'part': 4, - 'quantity': 20, - 'location': 1, - }, - expected_code=201 + data={'part': 4, 'quantity': 20, 'location': 1}, + expected_code=201, ) self.assertEqual(response.data['location'], 1) @@ -825,12 +949,8 @@ class StockItemTest(StockAPITestCase): response = self.post( self.list_url, - data={ - 'part': 4, - 'quantity': 20, - 'location': '', - }, - expected_code=201 + data={'part': 4, 'quantity': 20, 'location': ''}, + expected_code=201, ) self.assertEqual(response.data['location'], None) @@ -839,37 +959,29 @@ class StockItemTest(StockAPITestCase): """Test creation of a StockItem via the API.""" # POST with an empty part reference - response = self.client.post( - self.list_url, - data={ - 'quantity': 10, - 'location': 1 - } - ) + response = self.client.post(self.list_url, data={'quantity': 10, 'location': 1}) - self.assertContains(response, 'Valid part must be supplied', status_code=status.HTTP_400_BAD_REQUEST) + self.assertContains( + response, + 'Valid part must be supplied', + status_code=status.HTTP_400_BAD_REQUEST, + ) # POST with an invalid part reference response = self.client.post( - self.list_url, - data={ - 'quantity': 10, - 'location': 1, - 'part': 10000000, - } + self.list_url, data={'quantity': 10, 'location': 1, 'part': 10000000} ) - self.assertContains(response, 'Valid part must be supplied', status_code=status.HTTP_400_BAD_REQUEST) + self.assertContains( + response, + 'Valid part must be supplied', + status_code=status.HTTP_400_BAD_REQUEST, + ) # POST without quantity response = self.post( - self.list_url, - { - 'part': 1, - 'location': 1, - }, - expected_code=400 + self.list_url, {'part': 1, 'location': 1}, expected_code=400 ) self.assertIn('Quantity is required', str(response.data)) @@ -877,12 +989,8 @@ class StockItemTest(StockAPITestCase): # POST with quantity and part and location response = self.post( self.list_url, - data={ - 'part': 1, - 'location': 1, - 'quantity': 10, - }, - expected_code=201 + data={'part': 1, 'location': 1, 'quantity': 10}, + expected_code=201, ) def test_stock_item_create_withsupplierpart(self): @@ -890,13 +998,8 @@ class StockItemTest(StockAPITestCase): # POST with non-existent supplier part response = self.post( self.list_url, - data={ - 'part': 1, - 'location': 1, - 'quantity': 4, - 'supplier_part': 1000991 - }, - expected_code=400 + data={'part': 1, 'location': 1, 'quantity': 4, 'supplier_part': 1000991}, + expected_code=400, ) self.assertIn('The given supplier part does not exist', str(response.data)) @@ -915,7 +1018,7 @@ class StockItemTest(StockAPITestCase): 'purchase_price': 123.45, 'purchase_price_currency': 'USD', }, - expected_code=201 + expected_code=201, ) # Reload part, count stock again @@ -940,7 +1043,7 @@ class StockItemTest(StockAPITestCase): 'purchase_price': 123.45, 'purchase_price_currency': 'USD', }, - expected_code=201 + expected_code=201, ) # Reload part, count stock again part_4 = part.models.Part.objects.get(pk=4) @@ -962,7 +1065,7 @@ class StockItemTest(StockAPITestCase): 'purchase_price': 123.45, 'purchase_price_currency': 'USD', }, - expected_code=201 + expected_code=201, ) # Reload part, count stock again part_4 = part.models.Part.objects.get(pk=4) @@ -984,7 +1087,7 @@ class StockItemTest(StockAPITestCase): 'purchase_price': 123.45, 'purchase_price_currency': 'USD', }, - expected_code=201 + expected_code=201, ) # Reload part, count stock again part_4 = part.models.Part.objects.get(pk=4) @@ -1007,21 +1110,13 @@ class StockItemTest(StockAPITestCase): # This should fail, incorrect serial number count self.post( self.list_url, - data={ - 'part': trackable_part.pk, - 'quantity': 10, - 'serial_numbers': '1-20', - }, + data={'part': trackable_part.pk, 'quantity': 10, 'serial_numbers': '1-20'}, expected_code=400, ) response = self.post( self.list_url, - data={ - 'part': trackable_part.pk, - 'quantity': 10, - 'serial_numbers': '1-10', - }, + data={'part': trackable_part.pk, 'quantity': 10, 'serial_numbers': '1-10'}, expected_code=201, ) @@ -1036,10 +1131,7 @@ class StockItemTest(StockAPITestCase): # Check the unique stock item has been created - item = StockItem.objects.get( - part=trackable_part, - serial=str(i), - ) + item = StockItem.objects.get(part=trackable_part, serial=str(i)) # Item location should have been set automatically self.assertIsNotNone(item.location) @@ -1062,10 +1154,7 @@ class StockItemTest(StockAPITestCase): - Part <25> has a default_expiry of 10 days """ # First test - create a new StockItem without an expiry date - data = { - 'part': 4, - 'quantity': 10, - } + data = {'part': 4, 'quantity': 10} response = self.post(self.list_url, data, expected_code=201) @@ -1080,10 +1169,7 @@ class StockItemTest(StockAPITestCase): self.assertEqual(response.data['expiry_date'], '2022-12-12') # Third test - create a new StockItem for a Part which has a default expiry time - data = { - 'part': 25, - 'quantity': 10 - } + data = {'part': 25, 'quantity': 10} response = self.post(self.list_url, data, expected_code=201) @@ -1109,47 +1195,25 @@ class StockItemTest(StockAPITestCase): self.assertEqual(data['purchase_price_currency'], 'AUD') # Update just the amount - data = self.patch( - url, - { - 'purchase_price': 456 - }, - expected_code=200 - ).data + data = self.patch(url, {'purchase_price': 456}, expected_code=200).data self.assertEqual(data['purchase_price'], '456.000000') self.assertEqual(data['purchase_price_currency'], 'AUD') # Update the currency data = self.patch( - url, - { - 'purchase_price_currency': 'NZD', - }, - expected_code=200 + url, {'purchase_price_currency': 'NZD'}, expected_code=200 ).data self.assertEqual(data['purchase_price_currency'], 'NZD') # Clear the price field - data = self.patch( - url, - { - 'purchase_price': None, - }, - expected_code=200 - ).data + data = self.patch(url, {'purchase_price': None}, expected_code=200).data self.assertEqual(data['purchase_price'], None) # Invalid currency code - data = self.patch( - url, - { - 'purchase_price_currency': 'xyz', - }, - expected_code=400 - ) + data = self.patch(url, {'purchase_price_currency': 'xyz'}, expected_code=400) data = self.get(url).data self.assertEqual(data['purchase_price_currency'], 'NZD') @@ -1160,17 +1224,11 @@ class StockItemTest(StockAPITestCase): parent_part = part.models.Part.objects.get(pk=100) item = StockItem.objects.create( - part=parent_part, - serial='12345688-1230', - quantity=1, + part=parent_part, serial='12345688-1230', quantity=1 ) sub_part = part.models.Part.objects.get(pk=50) - sub_item = StockItem.objects.create( - part=sub_part, - serial='xyz-123', - quantity=1, - ) + sub_item = StockItem.objects.create(part=sub_part, serial='xyz-123', quantity=1) n_entries = sub_item.tracking_info.count() @@ -1185,18 +1243,17 @@ class StockItemTest(StockAPITestCase): 'stock_item': 520, 'note': 'This should fail, as Item #522 is not in the BOM', }, - expected_code=400 + expected_code=400, ) - self.assertIn('Selected part is not in the Bill of Materials', str(response.data)) + self.assertIn( + 'Selected part is not in the Bill of Materials', str(response.data) + ) # Now, try to install an item which *is* in the BOM for the parent part response = self.post( url, - { - 'stock_item': sub_item.pk, - 'note': "This time, it should be good!", - }, + {'stock_item': sub_item.pk, 'note': 'This time, it should be good!'}, expected_code=201, ) @@ -1209,10 +1266,7 @@ class StockItemTest(StockAPITestCase): # Try to install again - this time, should fail because the StockItem is not available! response = self.post( url, - { - 'stock_item': sub_item.pk, - 'note': 'Expectation: failure!', - }, + {'stock_item': sub_item.pk, 'note': 'Expectation: failure!'}, expected_code=400, ) @@ -1222,13 +1276,7 @@ class StockItemTest(StockAPITestCase): url = reverse('api-stock-item-uninstall', kwargs={'pk': sub_item.pk}) - self.post( - url, - { - 'location': 1, - }, - expected_code=201, - ) + self.post(url, {'location': 1}, expected_code=201) sub_item.refresh_from_db() @@ -1249,19 +1297,13 @@ class StockItemTest(StockAPITestCase): url = reverse('api-stock-item-return', kwargs={'pk': item.pk}) # Empty POST will fail - response = self.post( - url, {}, - expected_code=400 - ) + response = self.post(url, {}, expected_code=400) self.assertIn('This field is required', str(response.data['location'])) response = self.post( url, - { - 'location': '1', - 'notes': 'Returned from this customer for testing', - }, + {'location': '1', 'notes': 'Returned from this customer for testing'}, expected_code=201, ) @@ -1279,7 +1321,8 @@ class StockItemTest(StockAPITestCase): # First, construct a set of template / variant parts master_part = part.models.Part.objects.create( - name='Master', description='Master part which has variants', + name='Master', + description='Master part which has variants', category=category, is_template=True, ) @@ -1288,49 +1331,31 @@ class StockItemTest(StockAPITestCase): # Construct a set of variant parts for color in ['Red', 'Green', 'Blue', 'Yellow', 'Pink', 'Black']: - variants.append(part.models.Part.objects.create( - name=f"{color} Variant", description="Variant part with a specific color", - variant_of=master_part, - category=category, - )) + variants.append( + part.models.Part.objects.create( + name=f'{color} Variant', + description='Variant part with a specific color', + variant_of=master_part, + category=category, + ) + ) - stock_item = StockItem.objects.create( - part=master_part, - quantity=1000, - ) + stock_item = StockItem.objects.create(part=master_part, quantity=1000) url = reverse('api-stock-item-convert', kwargs={'pk': stock_item.pk}) # Attempt to convert to a part which does not exist - response = self.post( - url, - { - 'part': 999999, - }, - expected_code=400, - ) + response = self.post(url, {'part': 999999}, expected_code=400) self.assertIn('object does not exist', str(response.data['part'])) # Attempt to convert to a part which is not a valid option - response = self.post( - url, - { - 'part': 1, - }, - expected_code=400 - ) + response = self.post(url, {'part': 1}, expected_code=400) self.assertIn('Selected part is not a valid option', str(response.data['part'])) for variant in variants: - response = self.post( - url, - { - 'part': variant.pk, - }, - expected_code=201, - ) + response = self.post(url, {'part': variant.pk}, expected_code=201) stock_item.refresh_from_db() self.assertEqual(stock_item.part, variant) @@ -1342,9 +1367,7 @@ class StockItemTest(StockAPITestCase): prt = Part.objects.first() # Create a bunch of items - items = [ - StockItem.objects.create(part=prt, quantity=10) for _ in range(10) - ] + items = [StockItem.objects.create(part=prt, quantity=10) for _ in range(10)] for item in items: item.refresh_from_db() @@ -1386,7 +1409,6 @@ class StocktakeTest(StockAPITestCase): def test_action(self): """Test each stocktake action endpoint, for validation.""" for endpoint in ['api-stock-count', 'api-stock-add', 'api-stock-remove']: - url = reverse(endpoint) data = {} @@ -1394,11 +1416,9 @@ class StocktakeTest(StockAPITestCase): # POST with a valid action response = self.post(url, data) - self.assertIn("This field is required", str(response.data["items"])) + self.assertIn('This field is required', str(response.data['items'])) - data['items'] = [{ - 'no': 'aa' - }] + data['items'] = [{'no': 'aa'}] # POST without a PK response = self.post(url, data, expected_code=400) @@ -1406,50 +1426,51 @@ class StocktakeTest(StockAPITestCase): self.assertIn('This field is required', str(response.data)) # POST with an invalid PK - data['items'] = [{ - 'pk': 10 - }] + data['items'] = [{'pk': 10}] response = self.post(url, data, expected_code=400) - self.assertContains(response, 'object does not exist', status_code=status.HTTP_400_BAD_REQUEST) + self.assertContains( + response, + 'object does not exist', + status_code=status.HTTP_400_BAD_REQUEST, + ) # POST with missing quantity value - data['items'] = [{ - 'pk': 1234 - }] + data['items'] = [{'pk': 1234}] response = self.post(url, data, expected_code=400) - self.assertContains(response, 'This field is required', status_code=status.HTTP_400_BAD_REQUEST) + self.assertContains( + response, + 'This field is required', + status_code=status.HTTP_400_BAD_REQUEST, + ) # POST with an invalid quantity value - data['items'] = [{ - 'pk': 1234, - 'quantity': '10x0d' - }] + data['items'] = [{'pk': 1234, 'quantity': '10x0d'}] response = self.post(url, data) - self.assertContains(response, 'A valid number is required', status_code=status.HTTP_400_BAD_REQUEST) + self.assertContains( + response, + 'A valid number is required', + status_code=status.HTTP_400_BAD_REQUEST, + ) - data['items'] = [{ - 'pk': 1234, - 'quantity': "-1.234" - }] + data['items'] = [{'pk': 1234, 'quantity': '-1.234'}] response = self.post(url, data) - self.assertContains(response, 'Ensure this value is greater than or equal to 0', status_code=status.HTTP_400_BAD_REQUEST) + self.assertContains( + response, + 'Ensure this value is greater than or equal to 0', + status_code=status.HTTP_400_BAD_REQUEST, + ) def test_transfer(self): """Test stock transfers.""" data = { - 'items': [ - { - 'pk': 1234, - 'quantity': 10, - } - ], + 'items': [{'pk': 1234, 'quantity': 10}], 'location': 1, - 'notes': "Moving to a new location" + 'notes': 'Moving to a new location', } url = reverse('api-stock-transfer') @@ -1462,7 +1483,11 @@ class StocktakeTest(StockAPITestCase): response = self.post(url, data, expected_code=400) - self.assertContains(response, 'Incorrect type. Expected pk value', status_code=status.HTTP_400_BAD_REQUEST) + self.assertContains( + response, + 'Incorrect type. Expected pk value', + status_code=status.HTTP_400_BAD_REQUEST, + ) class StockItemDeletionTest(StockAPITestCase): @@ -1474,16 +1499,11 @@ class StockItemDeletionTest(StockAPITestCase): # Create and then delete a bunch of stock items for idx in range(10): - # Create new StockItem via the API response = self.post( reverse('api-stock-list'), - { - 'part': 1, - 'location': 1, - 'quantity': idx, - }, - expected_code=201 + {'part': 1, 'location': 1, 'quantity': idx}, + expected_code=201, ) pk = response.data['pk'] @@ -1492,8 +1512,7 @@ class StockItemDeletionTest(StockAPITestCase): # Request deletion via the API self.delete( - reverse('api-stock-detail', kwargs={'pk': pk}), - expected_code=204 + reverse('api-stock-detail', kwargs={'pk': pk}), expected_code=204 ) self.assertEqual(StockItem.objects.count(), n) @@ -1524,24 +1543,13 @@ class StockTestResultTest(StockAPITestCase): url = self.get_url() - self.post( - url, - data={ - 'test': 'A test', - 'result': True, - }, - expected_code=400 - ) + self.post(url, data={'test': 'A test', 'result': True}, expected_code=400) # This one should pass! self.post( url, - data={ - 'test': 'A test', - 'stock_item': 105, - 'result': True, - }, - expected_code=201 + data={'test': 'A test', 'stock_item': 105, 'result': True}, + expected_code=201, ) def test_post(self): @@ -1588,14 +1596,13 @@ class StockTestResultTest(StockAPITestCase): image_file = os.path.join(here, 'fixtures', 'test_image.bmp') with open(image_file, 'rb') as bitmap: - data = { 'stock_item': 105, 'test': 'Checked Steam Valve', 'result': False, 'value': '150kPa', 'notes': 'I guess there was just too much pressure?', - "attachment": bitmap, + 'attachment': bitmap, } response = self.client.post(self.get_url(), data) @@ -1618,11 +1625,11 @@ class StockTestResultTest(StockAPITestCase): url, { 'stock_item': 1, - 'test': f"Some test {_ii}", + 'test': f'Some test {_ii}', 'result': True, - 'value': 'Test result value' + 'value': 'Test result value', }, - expected_code=201 + expected_code=201, ) tests.append(response.data['pk']) @@ -1630,37 +1637,19 @@ class StockTestResultTest(StockAPITestCase): self.assertEqual(StockItemTestResult.objects.count(), n + 50) # Attempt a delete without providing items - self.delete( - url, - {}, - expected_code=400, - ) + self.delete(url, {}, expected_code=400) # Now, let's delete all the newly created items with a single API request # However, we will provide incorrect filters response = self.delete( - url, - { - 'items': tests, - 'filters': { - 'stock_item': 10, - } - }, - expected_code=204 + url, {'items': tests, 'filters': {'stock_item': 10}}, expected_code=204 ) self.assertEqual(StockItemTestResult.objects.count(), n + 50) # Try again, but with the correct filters this time response = self.delete( - url, - { - 'items': tests, - 'filters': { - 'stock_item': 1, - } - }, - expected_code=204 + url, {'items': tests, 'filters': {'stock_item': 1}}, expected_code=204 ) self.assertEqual(StockItemTestResult.objects.count(), n) @@ -1674,45 +1663,24 @@ class StockAssignTest(StockAPITestCase): def test_invalid(self): """Test invalid assign.""" # Test with empty data - response = self.post( - self.URL, - data={}, - expected_code=400, - ) + response = self.post(self.URL, data={}, expected_code=400) self.assertIn('This field is required', str(response.data['items'])) self.assertIn('This field is required', str(response.data['customer'])) # Test with an invalid customer - response = self.post( - self.URL, - data={ - 'customer': 999, - }, - expected_code=400, - ) + response = self.post(self.URL, data={'customer': 999}, expected_code=400) self.assertIn('object does not exist', str(response.data['customer'])) # Test with a company which is *not* a customer - response = self.post( - self.URL, - data={ - 'customer': 3, - }, - expected_code=400, - ) + response = self.post(self.URL, data={'customer': 3}, expected_code=400) self.assertIn('company is not a customer', str(response.data['customer'])) # Test with an empty items list response = self.post( - self.URL, - data={ - 'items': [], - 'customer': 4, - }, - expected_code=400, + self.URL, data={'items': [], 'customer': 4}, expected_code=400 ) self.assertIn('A list of stock items must be provided', str(response.data)) @@ -1725,14 +1693,7 @@ class StockAssignTest(StockAPITestCase): response = self.post( self.URL, - data={ - 'items': [ - { - 'item': stock_item.pk, - }, - ], - 'customer': 4, - }, + data={'items': [{'item': stock_item.pk}], 'customer': 4}, expected_code=400, ) @@ -1743,27 +1704,18 @@ class StockAssignTest(StockAPITestCase): stock_items = [] for i in range(5): - stock_item = StockItem.objects.create( - part=part.models.Part.objects.get(pk=25), - quantity=i + 5, + part=part.models.Part.objects.get(pk=25), quantity=i + 5 ) - stock_items.append({ - 'item': stock_item.pk - }) + stock_items.append({'item': stock_item.pk}) customer = company.models.Company.objects.get(pk=4) self.assertEqual(customer.assigned_stock.count(), 0) response = self.post( - self.URL, - data={ - 'items': stock_items, - 'customer': 4, - }, - expected_code=201, + self.URL, data={'items': stock_items, 'customer': 4}, expected_code=201 ) self.assertEqual(response.data['customer'], 4) @@ -1788,45 +1740,28 @@ class StockMergeTest(StockAPITestCase): cls.sp_2 = company.models.SupplierPart.objects.get(pk=101) cls.item_1 = StockItem.objects.create( - part=cls.part, - supplier_part=cls.sp_1, - quantity=100, + part=cls.part, supplier_part=cls.sp_1, quantity=100 ) cls.item_2 = StockItem.objects.create( - part=cls.part, - supplier_part=cls.sp_2, - quantity=100, + part=cls.part, supplier_part=cls.sp_2, quantity=100 ) cls.item_3 = StockItem.objects.create( - part=cls.part, - supplier_part=cls.sp_2, - quantity=50, + part=cls.part, supplier_part=cls.sp_2, quantity=50 ) def test_missing_data(self): """Test responses which are missing required data.""" # Post completely empty - data = self.post( - self.URL, - {}, - expected_code=400 - ).data + data = self.post(self.URL, {}, expected_code=400).data self.assertIn('This field is required', str(data['items'])) self.assertIn('This field is required', str(data['location'])) # Post with a location and empty items list - data = self.post( - self.URL, - { - 'items': [], - 'location': 1, - }, - expected_code=400 - ).data + data = self.post(self.URL, {'items': [], 'location': 1}, expected_code=400).data self.assertIn('At least two stock items', str(data)) @@ -1835,17 +1770,7 @@ class StockMergeTest(StockAPITestCase): # Serialized stock items should be rejected data = self.post( self.URL, - { - 'items': [ - { - 'item': 501, - }, - { - 'item': 502, - } - ], - 'location': 1, - }, + {'items': [{'item': 501}, {'item': 502}], 'location': 1}, expected_code=400, ).data @@ -1855,17 +1780,7 @@ class StockMergeTest(StockAPITestCase): data = self.post( self.URL, - { - 'items': [ - { - 'item': 11, - }, - { - 'item': 11, - } - ], - 'location': 1, - }, + {'items': [{'item': 11}, {'item': 11}], 'location': 1}, expected_code=400, ).data @@ -1874,17 +1789,7 @@ class StockMergeTest(StockAPITestCase): # Check for mismatching stock items data = self.post( self.URL, - { - 'items': [ - { - 'item': 1234, - }, - { - 'item': 11, - } - ], - 'location': 1, - }, + {'items': [{'item': 1234}, {'item': 11}], 'location': 1}, expected_code=400, ).data @@ -1892,22 +1797,11 @@ class StockMergeTest(StockAPITestCase): # Check for mismatching supplier parts payload = { - 'items': [ - { - 'item': self.item_1.pk, - }, - { - 'item': self.item_2.pk, - }, - ], + 'items': [{'item': self.item_1.pk}, {'item': self.item_2.pk}], 'location': 1, } - data = self.post( - self.URL, - payload, - expected_code=400, - ).data + data = self.post(self.URL, payload, expected_code=400).data self.assertIn('Stock items must refer to the same supplier part', str(data)) @@ -1919,25 +1813,15 @@ class StockMergeTest(StockAPITestCase): payload = { 'items': [ - { - 'item': self.item_1.pk, - }, - { - 'item': self.item_2.pk, - }, - { - 'item': self.item_3.pk, - }, + {'item': self.item_1.pk}, + {'item': self.item_2.pk}, + {'item': self.item_3.pk}, ], 'location': 1, 'allow_mismatched_suppliers': True, } - self.post( - self.URL, - payload, - expected_code=201, - ) + self.post(self.URL, payload, expected_code=201) self.item_1.refresh_from_db() @@ -1962,10 +1846,7 @@ class StockMetadataAPITest(InvenTreeAPITestCase): 'stock_tests', ] - roles = [ - 'stock.change', - 'stock_location.change', - ] + roles = ['stock.change', 'stock_location.change'] def metatester(self, apikey, model): """Generic tester""" @@ -1983,17 +1864,15 @@ class StockMetadataAPITest(InvenTreeAPITestCase): self.patch( url, - { - 'metadata': { - f'abc-{numstr}': f'xyz-{apikey}-{numstr}', - } - }, - expected_code=200 + {'metadata': {f'abc-{numstr}': f'xyz-{apikey}-{numstr}'}}, + expected_code=200, ) # Refresh modeldata.refresh_from_db() - self.assertEqual(modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}') + self.assertEqual( + modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}' + ) def test_metadata(self): """Test all endpoints""" diff --git a/InvenTree/stock/test_migrations.py b/InvenTree/stock/test_migrations.py index e153344e8f..db7cf26670 100644 --- a/InvenTree/stock/test_migrations.py +++ b/InvenTree/stock/test_migrations.py @@ -24,18 +24,14 @@ class TestSerialNumberMigration(MigratorTestCase): trackable=True, level=0, tree_id=0, - lft=0, rght=0 + lft=0, + rght=0, ) # Create some serialized stock items for sn in range(10, 20): StockItem.objects.create( - part=my_part, - quantity=1, - serial=sn, - level=0, - tree_id=0, - lft=0, rght=0 + part=my_part, quantity=1, serial=sn, level=0, tree_id=0, lft=0, rght=0 ) # Create a stock item with a very large serial number @@ -45,7 +41,8 @@ class TestSerialNumberMigration(MigratorTestCase): serial='9999999999999999999999999999999999999999999999999999999999999', level=0, tree_id=0, - lft=0, rght=0 + lft=0, + rght=0, ) self.big_ref_pk = item.pk @@ -63,8 +60,11 @@ class TestSerialNumberMigration(MigratorTestCase): big_ref_item = StockItem.objects.get(pk=self.big_ref_pk) # Check that the StockItem maximum serial number - self.assertEqual(big_ref_item.serial, '9999999999999999999999999999999999999999999999999999999999999') - self.assertEqual(big_ref_item.serial_int, 0x7fffffff) + self.assertEqual( + big_ref_item.serial, + '9999999999999999999999999999999999999999999999999999999999999', + ) + self.assertEqual(big_ref_item.serial_int, 0x7FFFFFFF) class TestScheduledForDeletionMigration(MigratorTestCase): @@ -83,17 +83,21 @@ class TestScheduledForDeletionMigration(MigratorTestCase): name=f'Part_{idx}', description='Just a part, nothing to see here', active=True, - level=0, tree_id=0, - lft=0, rght=0, + level=0, + tree_id=0, + lft=0, + rght=0, ) for jj in range(5): StockItem.objects.create( part=part, quantity=jj + 5, - level=0, tree_id=0, - lft=0, rght=0, - scheduled_for_deletion=True + level=0, + tree_id=0, + lft=0, + rght=0, + scheduled_for_deletion=True, ) # For extra points, create some parent-child relationships between stock items @@ -102,8 +106,10 @@ class TestScheduledForDeletionMigration(MigratorTestCase): item_1 = StockItem.objects.create( part=part, quantity=100, - level=0, tree_id=0, - lft=0, rght=0, + level=0, + tree_id=0, + lft=0, + rght=0, scheduled_for_deletion=True, ) @@ -111,8 +117,10 @@ class TestScheduledForDeletionMigration(MigratorTestCase): StockItem.objects.create( part=part, quantity=200, - level=0, tree_id=0, - lft=0, rght=0, + level=0, + tree_id=0, + lft=0, + rght=0, scheduled_for_deletion=False, parent=item_1, ) diff --git a/InvenTree/stock/test_views.py b/InvenTree/stock/test_views.py index 38a7afee01..c507b0558f 100644 --- a/InvenTree/stock/test_views.py +++ b/InvenTree/stock/test_views.py @@ -13,14 +13,7 @@ from users.models import Owner class StockViewTestCase(InvenTreeTestCase): """Mixin for Stockview tests.""" - fixtures = [ - 'category', - 'part', - 'company', - 'location', - 'supplier_part', - 'stock', - ] + fixtures = ['category', 'part', 'company', 'location', 'supplier_part', 'stock'] roles = 'all' @@ -59,13 +52,13 @@ class StockDetailTest(StockViewTestCase): # Actions to check actions = [ - "id=\\\'stock-count\\\' title=\\\'Count stock\\\'", - "id=\\\'stock-add\\\' title=\\\'Add stock\\\'", - "id=\\\'stock-remove\\\' title=\\\'Remove stock\\\'", - "id=\\\'stock-move\\\' title=\\\'Transfer stock\\\'", - "id=\\\'stock-duplicate\\\'", - "id=\\\'stock-edit\\\'", - "id=\\\'stock-delete\\\'", + "id=\\'stock-count\\' title=\\'Count stock\\'", + "id=\\'stock-add\\' title=\\'Add stock\\'", + "id=\\'stock-remove\\' title=\\'Remove stock\\'", + "id=\\'stock-move\\' title=\\'Transfer stock\\'", + "id=\\'stock-duplicate\\'", + "id=\\'stock-edit\\'", + "id=\\'stock-delete\\'", ] # Initially we should not have any of the required permissions @@ -86,6 +79,7 @@ class StockDetailTest(StockViewTestCase): class StockOwnershipTest(StockViewTestCase): """Tests for stock ownership views.""" + test_item_id = 11 test_location_id = 1 @@ -135,9 +129,13 @@ class StockOwnershipTest(StockViewTestCase): location = StockLocation.objects.get(pk=self.test_location_id) # Check that user is not allowed to change item - self.assertTrue(item.check_ownership(self.user)) # No owner -> True - self.assertTrue(location.check_ownership(self.user)) # No owner -> True - self.assertContains(self.assert_api_change(), 'You do not have permission to perform this action.', status_code=403) + self.assertTrue(item.check_ownership(self.user)) # No owner -> True + self.assertTrue(location.check_ownership(self.user)) # No owner -> True + self.assertContains( + self.assert_api_change(), + 'You do not have permission to perform this action.', + status_code=403, + ) # Adjust group rules group = Group.objects.get(name='my_test_group') @@ -153,9 +151,13 @@ class StockOwnershipTest(StockViewTestCase): location.save() # Check that user is allowed to change item - self.assertTrue(item.check_ownership(self.user)) # Owner is group -> True - self.assertTrue(location.check_ownership(self.user)) # Owner is group -> True - self.assertContains(self.assert_api_change(), f'"status":{StockStatus.DAMAGED.value}', status_code=200) + self.assertTrue(item.check_ownership(self.user)) # Owner is group -> True + self.assertTrue(location.check_ownership(self.user)) # Owner is group -> True + self.assertContains( + self.assert_api_change(), + f'"status":{StockStatus.DAMAGED.value}', + status_code=200, + ) # Change group new_group = Group.objects.create(name='new_group') @@ -166,5 +168,9 @@ class StockOwnershipTest(StockViewTestCase): location.save() # Check that user is not allowed to change item - self.assertFalse(item.check_ownership(self.user)) # Owner is not in group -> False - self.assertFalse(location.check_ownership(self.user)) # Owner is not in group -> False + self.assertFalse( + item.check_ownership(self.user) + ) # Owner is not in group -> False + self.assertFalse( + location.check_ownership(self.user) + ) # Owner is not in group -> False diff --git a/InvenTree/stock/tests.py b/InvenTree/stock/tests.py index 6843ee8bce..8b44c9774d 100644 --- a/InvenTree/stock/tests.py +++ b/InvenTree/stock/tests.py @@ -14,8 +14,7 @@ from InvenTree.unit_test import InvenTreeTestCase from order.models import SalesOrder from part.models import Part -from .models import (StockItem, StockItemTestResult, StockItemTracking, - StockLocation) +from .models import StockItem, StockItemTestResult, StockItemTracking, StockLocation class StockTestBase(InvenTreeTestCase): @@ -55,10 +54,10 @@ class StockTest(StockTestBase): def test_pathstring(self): """Check that pathstring updates occur as expected""" - a = StockLocation.objects.create(name="A") - b = StockLocation.objects.create(name="B", parent=a) - c = StockLocation.objects.create(name="C", parent=b) - d = StockLocation.objects.create(name="D", parent=c) + a = StockLocation.objects.create(name='A') + b = StockLocation.objects.create(name='B', parent=a) + c = StockLocation.objects.create(name='C', parent=b) + d = StockLocation.objects.create(name='D', parent=c) def refresh(): a.refresh_from_db() @@ -67,56 +66,56 @@ class StockTest(StockTestBase): d.refresh_from_db() # Initial checks - self.assertEqual(a.pathstring, "A") - self.assertEqual(b.pathstring, "A/B") - self.assertEqual(c.pathstring, "A/B/C") - self.assertEqual(d.pathstring, "A/B/C/D") + self.assertEqual(a.pathstring, 'A') + self.assertEqual(b.pathstring, 'A/B') + self.assertEqual(c.pathstring, 'A/B/C') + self.assertEqual(d.pathstring, 'A/B/C/D') - c.name = "Cc" + c.name = 'Cc' c.save() refresh() - self.assertEqual(a.pathstring, "A") - self.assertEqual(b.pathstring, "A/B") - self.assertEqual(c.pathstring, "A/B/Cc") - self.assertEqual(d.pathstring, "A/B/Cc/D") + self.assertEqual(a.pathstring, 'A') + self.assertEqual(b.pathstring, 'A/B') + self.assertEqual(c.pathstring, 'A/B/Cc') + self.assertEqual(d.pathstring, 'A/B/Cc/D') - b.name = "Bb" + b.name = 'Bb' b.save() refresh() - self.assertEqual(a.pathstring, "A") - self.assertEqual(b.pathstring, "A/Bb") - self.assertEqual(c.pathstring, "A/Bb/Cc") - self.assertEqual(d.pathstring, "A/Bb/Cc/D") + self.assertEqual(a.pathstring, 'A') + self.assertEqual(b.pathstring, 'A/Bb') + self.assertEqual(c.pathstring, 'A/Bb/Cc') + self.assertEqual(d.pathstring, 'A/Bb/Cc/D') - a.name = "Aa" + a.name = 'Aa' a.save() refresh() - self.assertEqual(a.pathstring, "Aa") - self.assertEqual(b.pathstring, "Aa/Bb") - self.assertEqual(c.pathstring, "Aa/Bb/Cc") - self.assertEqual(d.pathstring, "Aa/Bb/Cc/D") + self.assertEqual(a.pathstring, 'Aa') + self.assertEqual(b.pathstring, 'Aa/Bb') + self.assertEqual(c.pathstring, 'Aa/Bb/Cc') + self.assertEqual(d.pathstring, 'Aa/Bb/Cc/D') - d.name = "Dd" + d.name = 'Dd' d.save() refresh() - self.assertEqual(a.pathstring, "Aa") - self.assertEqual(b.pathstring, "Aa/Bb") - self.assertEqual(c.pathstring, "Aa/Bb/Cc") - self.assertEqual(d.pathstring, "Aa/Bb/Cc/Dd") + self.assertEqual(a.pathstring, 'Aa') + self.assertEqual(b.pathstring, 'Aa/Bb') + self.assertEqual(c.pathstring, 'Aa/Bb/Cc') + self.assertEqual(d.pathstring, 'Aa/Bb/Cc/Dd') # Test a really long name # (it will be clipped to < 250 characters) - a.name = "A" * 100 + a.name = 'A' * 100 a.save() - b.name = "B" * 100 + b.name = 'B' * 100 b.save() - c.name = "C" * 100 + c.name = 'C' * 100 c.save() - d.name = "D" * 100 + d.name = 'D' * 100 d.save() refresh() @@ -125,19 +124,15 @@ class StockTest(StockTestBase): self.assertEqual(len(c.pathstring), 249) self.assertEqual(len(d.pathstring), 249) - self.assertTrue(d.pathstring.startswith("AAAAAAAA")) - self.assertTrue(d.pathstring.endswith("DDDDDDDD")) + self.assertTrue(d.pathstring.startswith('AAAAAAAA')) + self.assertTrue(d.pathstring.endswith('DDDDDDDD')) def test_link(self): """Test the link URL field validation""" item = StockItem.objects.get(pk=1) # Check that invalid URLs fail - for bad_url in [ - 'test.com', - 'httpx://abc.xyz', - 'https:google.com', - ]: + for bad_url in ['test.com', 'httpx://abc.xyz', 'https:google.com']: with self.assertRaises(ValidationError): item.link = bad_url item.save() @@ -179,41 +174,31 @@ class StockTest(StockTestBase): # Ensure that 'global uniqueness' setting is enabled InvenTreeSetting.set_setting('SERIAL_NUMBER_GLOBALLY_UNIQUE', True, self.user) - part_a = Part.objects.create(name='A', description='A part with a description', trackable=True) - part_b = Part.objects.create(name='B', description='B part with a description', trackable=True) + part_a = Part.objects.create( + name='A', description='A part with a description', trackable=True + ) + part_b = Part.objects.create( + name='B', description='B part with a description', trackable=True + ) # Create a StockItem for part_a - StockItem.objects.create( - part=part_a, - quantity=1, - serial='ABCDE', - ) + StockItem.objects.create(part=part_a, quantity=1, serial='ABCDE') # Create a StockItem for part_a (but, will error due to identical serial) with self.assertRaises(ValidationError): - StockItem.objects.create( - part=part_b, - quantity=1, - serial='ABCDE', - ) + StockItem.objects.create(part=part_b, quantity=1, serial='ABCDE') # Now, allow serial numbers to be duplicated between different parts InvenTreeSetting.set_setting('SERIAL_NUMBER_GLOBALLY_UNIQUE', False, self.user) - StockItem.objects.create( - part=part_b, - quantity=1, - serial='ABCDE', - ) + StockItem.objects.create(part=part_b, quantity=1, serial='ABCDE') def test_expiry(self): """Test expiry date functionality for StockItem model.""" today = datetime.datetime.now().date() item = StockItem.objects.create( - location=self.office, - part=Part.objects.get(pk=1), - quantity=10, + location=self.office, part=Part.objects.get(pk=1), quantity=10 ) # Without an expiry_date set, item should not be "expired" @@ -249,13 +234,14 @@ class StockTest(StockTestBase): # And there should be *no* items being build self.assertEqual(part.quantity_being_built, 0) - build = Build.objects.create(reference='BO-4444', part=part, title='A test build', quantity=1) + build = Build.objects.create( + reference='BO-4444', part=part, title='A test build', quantity=1 + ) # Add some stock items which are "building" for _ in range(10): StockItem.objects.create( - part=part, build=build, - quantity=10, is_building=True + part=part, build=build, quantity=10, is_building=True ) # The "is_building" quantity should not be counted here @@ -330,7 +316,10 @@ class StockTest(StockTestBase): # There should be 16 widgets "in stock" self.assertEqual( - StockItem.objects.filter(part=25).aggregate(Sum('quantity'))['quantity__sum'], 16 + StockItem.objects.filter(part=25).aggregate(Sum('quantity'))[ + 'quantity__sum' + ], + 16, ) def test_delete_location(self): @@ -339,7 +328,9 @@ class StockTest(StockTestBase): n_stock = StockItem.objects.count() # What parts are in drawer 3? - stock_ids = [part.id for part in StockItem.objects.filter(location=self.drawer3.id)] + stock_ids = [ + part.id for part in StockItem.objects.filter(location=self.drawer3.id) + ] # Delete location - parts should move to parent location self.drawer3.delete() @@ -361,7 +352,9 @@ class StockTest(StockTestBase): self.assertEqual(it.location, self.bathroom) # There now should be 2 lots of screws in the bathroom - self.assertEqual(StockItem.objects.filter(part=1, location=self.bathroom).count(), 2) + self.assertEqual( + StockItem.objects.filter(part=1, location=self.bathroom).count(), 2 + ) # Check that a tracking item was added track = StockItemTracking.objects.filter(item=it).latest('id') @@ -467,9 +460,11 @@ class StockTest(StockTestBase): it = StockItem.objects.get(pk=2) n = it.quantity an = n - 10 - customer = Company.objects.create(name="MyTestCompany") - order = SalesOrder.objects.create(description="Test order") - ait = it.allocateToCustomer(customer, quantity=an, order=order, user=None, notes='Allocated some stock') + customer = Company.objects.create(name='MyTestCompany') + order = SalesOrder.objects.create(description='Test order') + ait = it.allocateToCustomer( + customer, quantity=an, order=order, user=None, notes='Allocated some stock' + ) # Check if new stockitem is created self.assertTrue(ait) @@ -485,7 +480,9 @@ class StockTest(StockTestBase): # Check that a tracking item was added track = StockItemTracking.objects.filter(item=ait).latest('id') - self.assertEqual(track.tracking_type, StockHistoryCode.SHIPPED_AGAINST_SALES_ORDER) + self.assertEqual( + track.tracking_type, StockHistoryCode.SHIPPED_AGAINST_SALES_ORDER + ) self.assertIn('Allocated some stock', track.notes) def test_return_from_customer(self): @@ -493,21 +490,29 @@ class StockTest(StockTestBase): it = StockItem.objects.get(pk=2) # First establish total stock for this part - allstock_before = StockItem.objects.filter(part=it.part).aggregate(Sum("quantity"))["quantity__sum"] + allstock_before = StockItem.objects.filter(part=it.part).aggregate( + Sum('quantity') + )['quantity__sum'] n = it.quantity an = n - 10 - customer = Company.objects.create(name="MyTestCompany") - order = SalesOrder.objects.create(description="Test order") + customer = Company.objects.create(name='MyTestCompany') + order = SalesOrder.objects.create(description='Test order') - ait = it.allocateToCustomer(customer, quantity=an, order=order, user=None, notes='Allocated some stock') - ait.return_from_customer(it.location, None, notes="Stock removed from customer") + ait = it.allocateToCustomer( + customer, quantity=an, order=order, user=None, notes='Allocated some stock' + ) + ait.return_from_customer(it.location, None, notes='Stock removed from customer') # When returned stock is returned to its original (parent) location, check that the parent has correct quantity self.assertEqual(it.quantity, n) - ait = it.allocateToCustomer(customer, quantity=an, order=order, user=None, notes='Allocated some stock') - ait.return_from_customer(self.drawer3, None, notes="Stock removed from customer") + ait = it.allocateToCustomer( + customer, quantity=an, order=order, user=None, notes='Allocated some stock' + ) + ait.return_from_customer( + self.drawer3, None, notes='Stock removed from customer' + ) # Check correct assignment of the new location self.assertEqual(ait.location, self.drawer3) @@ -527,7 +532,9 @@ class StockTest(StockTestBase): self.assertIn('Stock removed from customer', track.notes) # Establish total stock for the part after remove from customer to check that we still have the correct quantity in stock - allstock_after = StockItem.objects.filter(part=it.part).aggregate(Sum("quantity"))["quantity__sum"] + allstock_after = StockItem.objects.filter(part=it.part).aggregate( + Sum('quantity') + )['quantity__sum'] self.assertEqual(allstock_before, allstock_after) def test_take_stock(self): @@ -578,10 +585,7 @@ class StockTest(StockTestBase): # Ensure we do not have unique serials enabled InvenTreeSetting.set_setting('SERIAL_NUMBER_GLOBALLY_UNIQUE', False, None) - item = StockItem.objects.create( - part=p, - quantity=1, - ) + item = StockItem.objects.create(part=p, quantity=1) self.assertFalse(item.serialized) @@ -609,10 +613,7 @@ class StockTest(StockTestBase): trackable=True, ) - item = StockItem.objects.create( - part=p, - quantity=1, - ) + item = StockItem.objects.create(part=p, quantity=1) for sn in [12345, '12345', ' 12345 ']: item.serial = sn @@ -620,7 +621,7 @@ class StockTest(StockTestBase): self.assertEqual(item.serial_int, 12345) - item.serial = "-123" + item.serial = '-123' item.save() # Negative number should map to positive value @@ -631,7 +632,7 @@ class StockTest(StockTestBase): item.save() # The 'integer' portion has been clipped to a maximum value - self.assertEqual(item.serial_int, 0x7fffffff) + self.assertEqual(item.serial_int, 0x7FFFFFFF) # Non-numeric values should encode to zero for sn in ['apple', 'banana', 'carrot']: @@ -644,30 +645,18 @@ class StockTest(StockTestBase): item.serial = 100 item.save() - item_next = StockItem.objects.create( - part=p, - serial=150, - quantity=1 - ) + item_next = StockItem.objects.create(part=p, serial=150, quantity=1) self.assertEqual(item.get_next_serialized_item(), item_next) - item_prev = StockItem.objects.create( - part=p, - serial=' 57', - quantity=1, - ) + item_prev = StockItem.objects.create(part=p, serial=' 57', quantity=1) self.assertEqual(item.get_next_serialized_item(reverse=True), item_prev) # Create a number of serialized stock items around the current item for i in range(75, 125): try: - StockItem.objects.create( - part=p, - serial=i, - quantity=1, - ) + StockItem.objects.create(part=p, serial=i, quantity=1) except Exception: pass @@ -696,14 +685,14 @@ class StockTest(StockTestBase): # Try an invalid quantity with self.assertRaises(ValidationError): - item.serializeStock("k", [], self.user) + item.serializeStock('k', [], self.user) with self.assertRaises(ValidationError): item.serializeStock(-1, [], self.user) # Not enough serial numbers for all stock items. with self.assertRaises(ValidationError): - item.serializeStock(3, "hello", self.user) + item.serializeStock(3, 'hello', self.user) def test_serialize_stock_valid(self): """Perform valid stock serializations.""" @@ -755,55 +744,25 @@ class StockTest(StockTestBase): """ # First, we will create a stock location structure - A = StockLocation.objects.create( - name='A', - description='Top level location' - ) + A = StockLocation.objects.create(name='A', description='Top level location') - B1 = StockLocation.objects.create( - name='B1', - parent=A - ) + B1 = StockLocation.objects.create(name='B1', parent=A) - B2 = StockLocation.objects.create( - name='B2', - parent=A - ) + B2 = StockLocation.objects.create(name='B2', parent=A) - B3 = StockLocation.objects.create( - name='B3', - parent=A - ) + B3 = StockLocation.objects.create(name='B3', parent=A) - C11 = StockLocation.objects.create( - name='C11', - parent=B1, - ) + C11 = StockLocation.objects.create(name='C11', parent=B1) - C12 = StockLocation.objects.create( - name='C12', - parent=B1, - ) + C12 = StockLocation.objects.create(name='C12', parent=B1) - C21 = StockLocation.objects.create( - name='C21', - parent=B2, - ) + C21 = StockLocation.objects.create(name='C21', parent=B2) - C22 = StockLocation.objects.create( - name='C22', - parent=B2, - ) + C22 = StockLocation.objects.create(name='C22', parent=B2) - C31 = StockLocation.objects.create( - name='C31', - parent=B3, - ) + C31 = StockLocation.objects.create(name='C31', parent=B3) - C32 = StockLocation.objects.create( - name='C32', - parent=B3 - ) + C32 = StockLocation.objects.create(name='C32', parent=B3) # Check that the tree_id is correct for each sublocation for loc in [B1, B2, B3, C11, C12, C21, C22, C31, C32]: @@ -850,9 +809,7 @@ class StockTest(StockTestBase): # Add some stock items to B3 for _ in range(10): StockItem.objects.create( - part=Part.objects.get(pk=1), - quantity=10, - location=B3 + part=Part.objects.get(pk=1), quantity=10, location=B3 ) self.assertEqual(StockItem.objects.filter(location=B3).count(), 10) @@ -982,7 +939,10 @@ class VariantTest(StockTestBase): chair = Part.objects.get(pk=10000) # Operations on the top-level object - [self.assertFalse(chair.validate_serial_number(i)) for i in [1, 2, 3, 4, 5, 20, 21, 22]] + [ + self.assertFalse(chair.validate_serial_number(i)) + for i in [1, 2, 3, 4, 5, 20, 21, 22] + ] self.assertFalse(chair.validate_serial_number(20)) self.assertFalse(chair.validate_serial_number(21)) @@ -1006,11 +966,7 @@ class VariantTest(StockTestBase): # Create a new serial number n = variant.get_latest_serial_number() - item = StockItem( - part=variant, - quantity=1, - serial=n - ) + item = StockItem(part=variant, quantity=1, serial=n) # This should fail with self.assertRaises(ValidationError): @@ -1040,7 +996,7 @@ class TestResultTest(StockTestBase): tests = item.test_results self.assertEqual(tests.count(), 4) - results = item.getTestResults(test="Temperature Test") + results = item.getTestResults(test='Temperature Test') self.assertEqual(results.count(), 2) # Passing tests @@ -1074,9 +1030,7 @@ class TestResultTest(StockTestBase): test.save() StockItemTestResult.objects.create( - stock_item=item, - test='sew cushion', - result=True + stock_item=item, test='sew cushion', result=True ) # Still should be failing at this point, @@ -1088,7 +1042,7 @@ class TestResultTest(StockTestBase): stock_item=item, test='apply paint', date=datetime.datetime(2022, 12, 12), - result=True + result=True, ) self.assertTrue(item.passedAllRequiredTests()) @@ -1103,32 +1057,25 @@ class TestResultTest(StockTestBase): item.quantity = 50 # Try with an invalid batch code (according to sample validatoin plugin) - item.batch = "X234" + item.batch = 'X234' with self.assertRaises(ValidationError): item.save() - item.batch = "B123" + item.batch = 'B123' item.save() # Do some tests! StockItemTestResult.objects.create( - stock_item=item, - test="Firmware", - result=True + stock_item=item, test='Firmware', result=True ) StockItemTestResult.objects.create( - stock_item=item, - test="Paint Color", - result=True, - value="Red" + stock_item=item, test='Paint Color', result=True, value='Red' ) StockItemTestResult.objects.create( - stock_item=item, - test="Applied Sticker", - result=False + stock_item=item, test='Applied Sticker', result=False ) self.assertEqual(item.test_results.count(), 3) @@ -1142,10 +1089,7 @@ class TestResultTest(StockTestBase): self.assertEqual(item.test_results.count(), 3) self.assertEqual(item2.test_results.count(), 3) - StockItemTestResult.objects.create( - stock_item=item2, - test='A new test' - ) + StockItemTestResult.objects.create(stock_item=item2, test='A new test') self.assertEqual(item.test_results.count(), 3) self.assertEqual(item2.test_results.count(), 4) @@ -1154,10 +1098,7 @@ class TestResultTest(StockTestBase): item2.serializeStock(1, [100], self.user) # Add a test result to the parent *after* serialization - StockItemTestResult.objects.create( - stock_item=item2, - test='abcde' - ) + StockItemTestResult.objects.create(stock_item=item2, test='abcde') self.assertEqual(item2.test_results.count(), 5) @@ -1182,10 +1123,7 @@ class TestResultTest(StockTestBase): # Create a stock item which is installed *inside* the master item sub_item = StockItem.objects.create( - part=item.part, - quantity=1, - belongs_to=item, - location=None + part=item.part, quantity=1, belongs_to=item, location=None ) # Now, create some test results against the sub item @@ -1195,7 +1133,7 @@ class TestResultTest(StockTestBase): stock_item=sub_item, test='firmware version', date=datetime.datetime.now().date(), - result=True + result=True, ) # Should return the same number of tests as before diff --git a/InvenTree/stock/urls.py b/InvenTree/stock/urls.py index 37ea1ce6d9..337bbc4c4b 100644 --- a/InvenTree/stock/urls.py +++ b/InvenTree/stock/urls.py @@ -5,26 +5,29 @@ from django.urls import include, path, re_path from stock import views location_urls = [ - - path(r'/', include([ - # Anything else - direct to the location detail view - re_path('^.*$', views.StockLocationDetail.as_view(), name='stock-location-detail'), - ])), - + path( + r'/', + include([ + # Anything else - direct to the location detail view + re_path( + '^.*$', + views.StockLocationDetail.as_view(), + name='stock-location-detail', + ) + ]), + ) ] stock_item_detail_urls = [ # Anything else - direct to the item detail view - re_path('^.*$', views.StockItemDetail.as_view(), name='stock-item-detail'), + re_path('^.*$', views.StockItemDetail.as_view(), name='stock-item-detail') ] stock_urls = [ # Stock location re_path(r'^location/', include(location_urls)), - # Individual stock items re_path(r'^item/(?P\d+)/', include(stock_item_detail_urls)), - # Default to the stock index page re_path(r'^.*$', views.StockIndex.as_view(), name='stock-index'), ] diff --git a/InvenTree/stock/views.py b/InvenTree/stock/views.py index 22d3964e30..3fb10d968f 100644 --- a/InvenTree/stock/views.py +++ b/InvenTree/stock/views.py @@ -34,7 +34,9 @@ class StockIndex(InvenTreeRoleMixin, InvenTreePluginViewMixin, ListView): # No 'ownership' checks are necessary for the top-level StockLocation view context['user_owns_location'] = True context['location_owner'] = None - context['ownership_enabled'] = common.models.InvenTreeSetting.get_setting('STOCK_OWNERSHIP_CONTROL') + context['ownership_enabled'] = common.models.InvenTreeSetting.get_setting( + 'STOCK_OWNERSHIP_CONTROL' + ) return context @@ -51,9 +53,13 @@ class StockLocationDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailVi """Extend template context.""" context = super().get_context_data(**kwargs) - context['ownership_enabled'] = common.models.InvenTreeSetting.get_setting('STOCK_OWNERSHIP_CONTROL') + context['ownership_enabled'] = common.models.InvenTreeSetting.get_setting( + 'STOCK_OWNERSHIP_CONTROL' + ) context['location_owner'] = context['location'].get_location_owner() - context['user_owns_location'] = context['location'].check_ownership(self.request.user) + context['user_owns_location'] = context['location'].check_ownership( + self.request.user + ) return context @@ -74,14 +80,18 @@ class StockItemDetail(InvenTreeRoleMixin, InvenTreePluginViewMixin, DetailView): data['previous'] = self.object.get_next_serialized_item(reverse=True) data['next'] = self.object.get_next_serialized_item() - data['ownership_enabled'] = common.models.InvenTreeSetting.get_setting('STOCK_OWNERSHIP_CONTROL') + data['ownership_enabled'] = common.models.InvenTreeSetting.get_setting( + 'STOCK_OWNERSHIP_CONTROL' + ) data['item_owner'] = self.object.get_item_owner() data['user_owns_item'] = self.object.check_ownership(self.request.user) # Allocation information data['allocated_to_sales_orders'] = self.object.sales_order_allocation_count() data['allocated_to_build_orders'] = self.object.build_allocation_count() - data['allocated_to_orders'] = data['allocated_to_sales_orders'] + data['allocated_to_build_orders'] + data['allocated_to_orders'] = ( + data['allocated_to_sales_orders'] + data['allocated_to_build_orders'] + ) data['available'] = max(0, self.object.quantity - data['allocated_to_orders']) return data diff --git a/InvenTree/users/admin.py b/InvenTree/users/admin.py index 2b77421ce8..f7758206a9 100644 --- a/InvenTree/users/admin.py +++ b/InvenTree/users/admin.py @@ -18,17 +18,34 @@ class ApiTokenAdmin(admin.ModelAdmin): list_display = ('token', 'user', 'name', 'expiry', 'active') list_filter = ('user', 'revoked') - fields = ('token', 'user', 'name', 'created', 'last_seen', 'revoked', 'expiry', 'metadata') + fields = ( + 'token', + 'user', + 'name', + 'created', + 'last_seen', + 'revoked', + 'expiry', + 'metadata', + ) def get_fields(self, request, obj=None): """Return list of fields to display.""" if obj: - fields = ['token',] + fields = ['token'] else: - fields = ['key',] + fields = ['key'] - fields += ['user', 'name', 'created', 'last_seen', 'revoked', 'expiry', 'metadata'] + fields += [ + 'user', + 'name', + 'created', + 'last_seen', + 'revoked', + 'expiry', + 'metadata', + ] return fields @@ -67,12 +84,10 @@ class InvenTreeGroupAdminForm(forms.ModelForm): class Meta: """Metaclass defines extra fields""" + model = Group exclude = [] - fields = [ - 'name', - 'users', - ] + fields = ['name', 'users'] def __init__(self, *args, **kwargs): # pragma: no cover """Populate the 'users' field with the users in the current group""" @@ -88,7 +103,7 @@ class InvenTreeGroupAdminForm(forms.ModelForm): required=False, widget=FilteredSelectMultiple('users', False), label=_('Users'), - help_text=_('Select which users are assigned to this group') + help_text=_('Select which users are assigned to this group'), ) def save_m2m(self): # pragma: no cover @@ -109,12 +124,21 @@ class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover form = InvenTreeGroupAdminForm - inlines = [ - RuleSetInline, - ] + inlines = [RuleSetInline] - list_display = ('name', 'admin', 'part_category', 'part', 'stocktake', 'stock_location', - 'stock_item', 'build', 'purchase_order', 'sales_order', 'return_order') + list_display = ( + 'name', + 'admin', + 'part_category', + 'part', + 'stocktake', + 'stock_location', + 'stock_item', + 'build', + 'purchase_order', + 'sales_order', + 'return_order', + ) def get_rule_set(self, obj, rule_set_type): """Return list of permissions for the given ruleset.""" @@ -222,15 +246,14 @@ class RoleGroupAdmin(admin.ModelAdmin): # pragma: no cover # If any, display warning message when group is saved if len(multiple_group_users) > 0: - - msg = _("The following users are members of multiple groups") + ": " + ", ".join(multiple_group_users) - - messages.add_message( - request, - messages.WARNING, - msg + msg = ( + _('The following users are members of multiple groups') + + ': ' + + ', '.join(multiple_group_users) ) + messages.add_message(request, messages.WARNING, msg) + def save_formset(self, request, form, formset, change): """Save the inline formset""" # Save inline Rulesets @@ -247,19 +270,29 @@ class InvenTreeUserAdmin(UserAdmin): (And it's confusing!) """ - list_display = ('username', 'email', 'first_name', 'last_name', 'is_staff', 'last_login') # display last connection for each user in user admin panel. + + list_display = ( + 'username', + 'email', + 'first_name', + 'last_name', + 'is_staff', + 'last_login', + ) # display last connection for each user in user admin panel. fieldsets = ( (None, {'fields': ('username', 'password')}), (_('Personal info'), {'fields': ('first_name', 'last_name', 'email')}), - (_('Permissions'), { - 'fields': ('is_active', 'is_staff', 'is_superuser', 'groups'), - }), + ( + _('Permissions'), + {'fields': ('is_active', 'is_staff', 'is_superuser', 'groups')}, + ), (_('Important dates'), {'fields': ('last_login', 'date_joined')}), ) class OwnerAdmin(admin.ModelAdmin): """Custom admin interface for the Owner model.""" + pass diff --git a/InvenTree/users/api.py b/InvenTree/users/api.py index 3486bf6ebb..1c4f6611d5 100644 --- a/InvenTree/users/api.py +++ b/InvenTree/users/api.py @@ -12,8 +12,13 @@ from rest_framework.views import APIView import InvenTree.helpers from InvenTree.filters import SEARCH_ORDER_FILTER -from InvenTree.mixins import (ListAPI, ListCreateAPI, RetrieveAPI, - RetrieveUpdateAPI, RetrieveUpdateDestroyAPI) +from InvenTree.mixins import ( + ListAPI, + ListCreateAPI, + RetrieveAPI, + RetrieveUpdateAPI, + RetrieveUpdateDestroyAPI, +) from InvenTree.serializers import ExendedUserSerializer, UserCreateSerializer from users.models import ApiToken, Owner, RuleSet, check_user_role from users.serializers import GroupSerializer, OwnerSerializer @@ -52,10 +57,11 @@ class OwnerList(ListAPI): # Get a list of all matching users, depending on the *is_active* flag if is_active is not None: is_active = InvenTree.helpers.str2bool(is_active) - matching_user_ids = User.objects.filter(is_active=is_active).values_list('pk', flat=True) + matching_user_ids = User.objects.filter(is_active=is_active).values_list( + 'pk', flat=True + ) for result in queryset.all(): - name = str(result.name()).lower().strip() search_match = True @@ -71,7 +77,10 @@ class OwnerList(ListAPI): if is_active is not None: # Skip any users which do not match the required *is_active* value - if result.owner_type.name == 'user' and result.owner_id not in matching_user_ids: + if ( + result.owner_type.name == 'user' + and result.owner_id not in matching_user_ids + ): continue # If we get here, there is no reason *not* to include this result @@ -96,9 +105,7 @@ class RoleDetails(APIView): (Requires authentication) """ - permission_classes = [ - permissions.IsAuthenticated - ] + permission_classes = [permissions.IsAuthenticated] def get(self, request, *args, **kwargs): """Return the list of roles / permissions available to the current user""" @@ -107,14 +114,12 @@ class RoleDetails(APIView): roles = {} for ruleset in RuleSet.RULESET_CHOICES: - role, _text = ruleset permissions = [] for permission in RuleSet.RULESET_PERMISSIONS: if check_user_role(user, role, permission): - permissions.append(permission) if len(permissions) > 0: @@ -138,9 +143,7 @@ class UserDetail(RetrieveUpdateDestroyAPI): queryset = User.objects.all() serializer_class = ExendedUserSerializer - permission_classes = [ - permissions.IsAuthenticated - ] + permission_classes = [permissions.IsAuthenticated] class MeUserDetail(RetrieveUpdateAPI, UserDetail): @@ -156,16 +159,10 @@ class UserList(ListCreateAPI): queryset = User.objects.all() serializer_class = UserCreateSerializer - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] filter_backends = SEARCH_ORDER_FILTER - search_fields = [ - 'first_name', - 'last_name', - 'username', - ] + search_fields = ['first_name', 'last_name', 'username'] ordering_fields = [ 'email', @@ -177,11 +174,7 @@ class UserList(ListCreateAPI): 'is_active', ] - filterset_fields = [ - 'is_staff', - 'is_active', - 'is_superuser', - ] + filterset_fields = ['is_staff', 'is_active', 'is_superuser'] class GroupDetail(RetrieveUpdateDestroyAPI): @@ -189,9 +182,7 @@ class GroupDetail(RetrieveUpdateDestroyAPI): queryset = Group.objects.all() serializer_class = GroupSerializer - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] class GroupList(ListCreateAPI): @@ -199,27 +190,19 @@ class GroupList(ListCreateAPI): queryset = Group.objects.all() serializer_class = GroupSerializer - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] filter_backends = SEARCH_ORDER_FILTER - search_fields = [ - 'name', - ] + search_fields = ['name'] - ordering_fields = [ - 'name', - ] + ordering_fields = ['name'] class GetAuthToken(APIView): """Return authentication token for an authenticated user.""" - permission_classes = [ - permissions.IsAuthenticated, - ] + permission_classes = [permissions.IsAuthenticated] def get(self, request, *args, **kwargs): """Return an API token if the user is authenticated @@ -230,7 +213,6 @@ class GetAuthToken(APIView): """ if request.user.is_authenticated: - user = request.user name = request.query_params.get('name', '') @@ -239,7 +221,9 @@ class GetAuthToken(APIView): today = datetime.date.today() # Find existing token, which has not expired - token = ApiToken.objects.filter(user=user, name=name, revoked=False, expiry__gte=today).first() + token = ApiToken.objects.filter( + user=user, name=name, revoked=False, expiry__gte=today + ).first() if not token: # User is authenticated, and requesting a token against the provided name. @@ -253,13 +237,11 @@ class GetAuthToken(APIView): token.set_metadata('server_name', request.META.get('SERVER_NAME', '')) token.set_metadata('server_port', request.META.get('SERVER_PORT', '')) - data = { - 'token': token.key, - 'name': token.name, - 'expiry': token.expiry, - } + data = {'token': token.key, 'name': token.name, 'expiry': token.expiry} - logger.info("Created new API token for user '%s' (name='%s')", user.username, name) + logger.info( + "Created new API token for user '%s' (name='%s')", user.username, name + ) return Response(data) @@ -268,22 +250,25 @@ class GetAuthToken(APIView): user_urls = [ - re_path(r'roles/?$', RoleDetails.as_view(), name='api-user-roles'), re_path(r'token/?$', GetAuthToken.as_view(), name='api-token'), re_path(r'^me/', MeUserDetail.as_view(), name='api-user-me'), - - re_path(r'^owner/', include([ - path('/', OwnerDetail.as_view(), name='api-owner-detail'), - re_path(r'^.*$', OwnerList.as_view(), name='api-owner-list'), - ])), - - re_path(r'^group/', include([ - re_path(r'^(?P[0-9]+)/?$', GroupDetail.as_view(), name='api-group-detail'), - re_path(r'^.*$', GroupList.as_view(), name='api-group-list'), - ])), - + re_path( + r'^owner/', + include([ + path('/', OwnerDetail.as_view(), name='api-owner-detail'), + re_path(r'^.*$', OwnerList.as_view(), name='api-owner-list'), + ]), + ), + re_path( + r'^group/', + include([ + re_path( + r'^(?P[0-9]+)/?$', GroupDetail.as_view(), name='api-group-detail' + ), + re_path(r'^.*$', GroupList.as_view(), name='api-group-list'), + ]), + ), re_path(r'^(?P[0-9]+)/?$', UserDetail.as_view(), name='api-user-detail'), - path('', UserList.as_view(), name='api-user-list'), ] diff --git a/InvenTree/users/apps.py b/InvenTree/users/apps.py index f607cf12de..8c16165056 100644 --- a/InvenTree/users/apps.py +++ b/InvenTree/users/apps.py @@ -18,7 +18,10 @@ class UsersConfig(AppConfig): def ready(self): """Called when the 'users' app is loaded at runtime""" # skip loading if plugin registry is not loaded or we run in a background thread - if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread(): + if ( + not InvenTree.ready.isPluginRegistryLoaded() + or not InvenTree.ready.isInMainThread() + ): return # Skip if running migrations @@ -26,7 +29,6 @@ class UsersConfig(AppConfig): return if InvenTree.ready.canAppAccessDatabase(allow_test=True): - try: self.assign_permissions() except (OperationalError, ProgrammingError): @@ -45,13 +47,14 @@ class UsersConfig(AppConfig): # First, delete any rule_set objects which have become outdated! for rule in RuleSet.objects.all(): - if rule.name not in RuleSet.RULESET_NAMES: # pragma: no cover # can not change ORM without the app being loaded - logger.info("Deleting outdated ruleset: %s", rule.name) + if ( + rule.name not in RuleSet.RULESET_NAMES + ): # pragma: no cover # can not change ORM without the app being loaded + logger.info('Deleting outdated ruleset: %s', rule.name) rule.delete() # Update group permission assignments for all groups for group in Group.objects.all(): - update_group_roles(group) def update_owners(self): diff --git a/InvenTree/users/authentication.py b/InvenTree/users/authentication.py index 659fafcc48..aad1cfdd5b 100644 --- a/InvenTree/users/authentication.py +++ b/InvenTree/users/authentication.py @@ -26,10 +26,10 @@ class ApiTokenAuthentication(TokenAuthentication): (user, token) = super().authenticate_credentials(key) if token.revoked: - raise exceptions.AuthenticationFailed(_("Token has been revoked")) + raise exceptions.AuthenticationFailed(_('Token has been revoked')) if token.expired: - raise exceptions.AuthenticationFailed(_("Token has expired")) + raise exceptions.AuthenticationFailed(_('Token has expired')) if token.last_seen != datetime.date.today(): # Update the last-seen date diff --git a/InvenTree/users/models.py b/InvenTree/users/models.py index 4cd01439ed..8b40192915 100644 --- a/InvenTree/users/models.py +++ b/InvenTree/users/models.py @@ -26,7 +26,7 @@ import InvenTree.helpers import InvenTree.models from InvenTree.ready import canAppAccessDatabase -logger = logging.getLogger("inventree") +logger = logging.getLogger('inventree') # OVERRIDE START @@ -41,7 +41,7 @@ def user_model_str(self): return self.username -User.add_to_class("__str__", user_model_str) # Overriding User.__str__ +User.add_to_class('__str__', user_model_str) # Overriding User.__str__ # OVERRIDE END @@ -67,6 +67,7 @@ class ApiToken(AuthToken, InvenTree.models.MetadataMixin): class Meta: """Metaclass defines model properties""" + verbose_name = _('API Token') verbose_name_plural = _('API Tokens') abstract = False @@ -88,11 +89,10 @@ class ApiToken(AuthToken, InvenTree.models.MetadataMixin): key = models.CharField( default=default_token, verbose_name=_('Key'), - db_index=True, unique=True, + db_index=True, + unique=True, max_length=100, - validators=[ - MinLengthValidator(50), - ] + validators=[MinLengthValidator(50)], ) # Override the 'user' field, to allow multiple tokens per user @@ -114,19 +114,19 @@ class ApiToken(AuthToken, InvenTree.models.MetadataMixin): default=default_token_expiry, verbose_name=_('Expiry Date'), help_text=_('Token expiry date'), - auto_now=False, auto_now_add=False, + auto_now=False, + auto_now_add=False, ) last_seen = models.DateField( - blank=True, null=True, + blank=True, + null=True, verbose_name=_('Last Seen'), help_text=_('Last time the token was used'), ) revoked = models.BooleanField( - default=False, - verbose_name=_('Revoked'), - help_text=_('Token has been revoked'), + default=False, verbose_name=_('Revoked'), help_text=_('Token has been revoked') ) @staticmethod @@ -136,7 +136,9 @@ class ApiToken(AuthToken, InvenTree.models.MetadataMixin): name = str(name).strip() # Remove any non-printable chars - name = InvenTree.helpers.remove_non_printable_characters(name, remove_newline=True) + name = InvenTree.helpers.remove_non_printable_characters( + name, remove_newline=True + ) name = InvenTree.helpers.strip_html_tags(name) name = name.replace(' ', '-') @@ -199,13 +201,9 @@ class RuleSet(models.Model): ('return_order', _('Return Orders')), ] - RULESET_NAMES = [ - choice[0] for choice in RULESET_CHOICES - ] + RULESET_NAMES = [choice[0] for choice in RULESET_CHOICES] - RULESET_PERMISSIONS = [ - 'view', 'add', 'change', 'delete', - ] + RULESET_PERMISSIONS = ['view', 'add', 'change', 'delete'] RULESET_MODELS = { 'admin': [ @@ -261,15 +259,12 @@ class RuleSet(models.Model): 'company_manufacturerpartattachment', 'label_partlabel', ], - 'stocktake': [ - 'part_partstocktake', - 'part_partstocktakereport', - ], + 'stocktake': ['part_partstocktake', 'part_partstocktakereport'], 'stock_location': [ 'stock_stocklocation', 'stock_stocklocationtype', 'label_stocklocationlabel', - 'report_stocklocationreport' + 'report_stocklocationreport', ], 'stock': [ 'stock_stockitem', @@ -331,7 +326,7 @@ class RuleSet(models.Model): 'order_returnorderextraline', 'order_returnorderattachment', 'report_returnorderreport', - ] + ], } # Database models we ignore permission sets for @@ -339,7 +334,6 @@ class RuleSet(models.Model): # Core django models (not user configurable) 'admin_logentry', 'contenttypes_contenttype', - # Models which currently do not require permissions 'common_colortheme', 'common_customunit', @@ -353,13 +347,11 @@ class RuleSet(models.Model): 'common_webhookmessage', 'label_labeloutput', 'users_owner', - # Third-party tables 'error_report_error', 'exchange_rate', 'exchange_exchangebackend', 'user_sessions_session', - # Django-q 'django_q_ormq', 'django_q_failure', @@ -368,46 +360,50 @@ class RuleSet(models.Model): 'django_q_success', ] - RULESET_CHANGE_INHERIT = [ - ('part', 'partparameter'), - ('part', 'bomitem'), - ] + RULESET_CHANGE_INHERIT = [('part', 'partparameter'), ('part', 'bomitem')] - RULE_OPTIONS = [ - 'can_view', - 'can_add', - 'can_change', - 'can_delete', - ] + RULE_OPTIONS = ['can_view', 'can_add', 'can_change', 'can_delete'] class Meta: """Metaclass defines additional model properties""" - unique_together = ( - ('name', 'group'), - ) + + unique_together = (('name', 'group'),) name = models.CharField( max_length=50, choices=RULESET_CHOICES, blank=False, - help_text=_('Permission set') + help_text=_('Permission set'), ) group = models.ForeignKey( Group, related_name='rule_sets', - blank=False, null=False, + blank=False, + null=False, on_delete=models.CASCADE, help_text=_('Group'), ) - can_view = models.BooleanField(verbose_name=_('View'), default=True, help_text=_('Permission to view items')) + can_view = models.BooleanField( + verbose_name=_('View'), default=True, help_text=_('Permission to view items') + ) - can_add = models.BooleanField(verbose_name=_('Add'), default=False, help_text=_('Permission to add items')) + can_add = models.BooleanField( + verbose_name=_('Add'), default=False, help_text=_('Permission to add items') + ) - can_change = models.BooleanField(verbose_name=_('Change'), default=False, help_text=_('Permissions to edit items')) + can_change = models.BooleanField( + verbose_name=_('Change'), + default=False, + help_text=_('Permissions to edit items'), + ) - can_delete = models.BooleanField(verbose_name=_('Delete'), default=False, help_text=_('Permission to delete items')) + can_delete = models.BooleanField( + verbose_name=_('Delete'), + default=False, + help_text=_('Permission to delete items'), + ) @classmethod def check_table_permission(cls, user, table, permission): @@ -423,12 +419,11 @@ class RuleSet(models.Model): # Work out which roles touch the given table for role in cls.RULESET_NAMES: if table in cls.RULESET_MODELS[role]: - if check_user_role(user, role, permission): return True # Check for children models which inherits from parent role - for (parent, child) in cls.RULESET_CHANGE_INHERIT: + for parent, child in cls.RULESET_CHANGE_INHERIT: # Get child model name parent_child_string = f'{parent}_{child}' @@ -439,7 +434,9 @@ class RuleSet(models.Model): # Print message instead of throwing an error name = getattr(user, 'name', user.pk) - logger.debug("User '%s' failed permission check for %s.%s", name, table, permission) + logger.debug( + "User '%s' failed permission check for %s.%s", name, table, permission + ) return False @@ -448,15 +445,17 @@ class RuleSet(models.Model): """Construct the correctly formatted permission string, given the app_model name, and the permission type.""" model, app = split_model(model) - return f"{app}.{permission}_{model}" + return f'{app}.{permission}_{model}' def __str__(self, debug=False): # pragma: no cover """Ruleset string representation.""" if debug: # Makes debugging easier - return f'{str(self.group).ljust(15)}: {self.name.title().ljust(15)} | ' \ - f'v: {str(self.can_view).ljust(5)} | a: {str(self.can_add).ljust(5)} | ' \ - f'c: {str(self.can_change).ljust(5)} | d: {str(self.can_delete).ljust(5)}' + return ( + f'{str(self.group).ljust(15)}: {self.name.title().ljust(15)} | ' + f'v: {str(self.can_view).ljust(5)} | a: {str(self.can_add).ljust(5)} | ' + f'c: {str(self.can_change).ljust(5)} | d: {str(self.can_delete).ljust(5)}' + ) return self.name def save(self, *args, **kwargs): @@ -526,7 +525,7 @@ def update_group_roles(group, debug=False): # and create a simplified permission key string for p in group.permissions.all().prefetch_related('content_type'): (permission, app, model) = p.natural_key() - permission_string = f"{app}.{permission}" + permission_string = f'{app}.{permission}' group_permissions.add(permission_string) # List of permissions which must be added to the group @@ -544,12 +543,11 @@ def update_group_roles(group, debug=False): allowed: Whether or not the action is allowed """ if action not in ['view', 'add', 'change', 'delete']: # pragma: no cover - raise ValueError(f"Action {action} is invalid") + raise ValueError(f'Action {action} is invalid') permission_string = RuleSet.get_model_permission_string(model, action) if allowed: - # An 'allowed' action is always preferenced over a 'forbidden' action if permission_string in permissions_to_delete: permissions_to_delete.remove(permission_string) @@ -557,7 +555,6 @@ def update_group_roles(group, debug=False): permissions_to_add.add(permission_string) else: - # A forbidden action will be ignored if we have already allowed it if permission_string not in permissions_to_add: permissions_to_delete.add(permission_string) @@ -569,7 +566,6 @@ def update_group_roles(group, debug=False): # Get all the rulesets associated with this group for r in RuleSet.RULESET_CHOICES: - rulename = r[0] if rulename in rulesets: @@ -605,16 +601,19 @@ def update_group_roles(group, debug=False): try: content_type = ContentType.objects.get(app_label=app, model=model) - permission = Permission.objects.get(content_type=content_type, codename=perm) + permission = Permission.objects.get( + content_type=content_type, codename=perm + ) except ContentType.DoesNotExist: # pragma: no cover - logger.warning("Error: Could not find permission matching '%s'", permission_string) + logger.warning( + "Error: Could not find permission matching '%s'", permission_string + ) permission = None return permission # Add any required permissions to the group for perm in permissions_to_add: - # Ignore if permission is already in the group if perm in group_permissions: continue @@ -625,11 +624,10 @@ def update_group_roles(group, debug=False): group.permissions.add(permission) if debug: # pragma: no cover - logger.debug("Adding permission %s to group %s", perm, group.name) + logger.debug('Adding permission %s to group %s', perm, group.name) # Remove any extra permissions from the group for perm in permissions_to_delete: - # Ignore if the permission is not already assigned if perm not in group_permissions: continue @@ -640,11 +638,11 @@ def update_group_roles(group, debug=False): group.permissions.remove(permission) if debug: # pragma: no cover - logger.debug("Removing permission %s from group %s", perm, group.name) + logger.debug('Removing permission %s from group %s', perm, group.name) # Enable all action permissions for certain children models # if parent model has 'change' permission - for (parent, child) in RuleSet.RULESET_CHANGE_INHERIT: + for parent, child in RuleSet.RULESET_CHANGE_INHERIT: parent_child_string = f'{parent}_{child}' # Check each type of permission @@ -662,7 +660,9 @@ def update_group_roles(group, debug=False): permission = get_permission_object(child_perm) if permission: group.permissions.add(permission) - logger.debug("Adding permission %s to group %s", child_perm, group.name) + logger.debug( + 'Adding permission %s to group %s', child_perm, group.name + ) def clear_user_role_cache(user): @@ -675,7 +675,7 @@ def clear_user_role_cache(user): """ for role in RuleSet.RULESET_MODELS.keys(): for perm in ['add', 'change', 'view', 'delete']: - key = f"role_{user}_{role}_{perm}" + key = f'role_{user}_{role}_{perm}' cache.delete(key) @@ -707,7 +707,7 @@ def check_user_role(user, role, permission): return True # First, check the cache - key = f"role_{user}_{role}_{permission}" + key = f'role_{user}_{role}_{permission}' result = cache.get(key) @@ -718,11 +718,8 @@ def check_user_role(user, role, permission): result = False for group in user.groups.all(): - for rule in group.rule_sets.all(): - if rule.name == role: - if permission == 'add' and rule.can_add: result = True break @@ -756,10 +753,10 @@ class Owner(models.Model): class Meta: """Metaclass defines extra model properties""" + # Ensure all owners are unique constraints = [ - UniqueConstraint(fields=['owner_type', 'owner_id'], - name='unique_owner') + UniqueConstraint(fields=['owner_type', 'owner_id'], name='unique_owner') ] @classmethod @@ -794,7 +791,9 @@ class Owner(models.Model): """Returns the API endpoint URL associated with the Owner model""" return reverse('api-owner-list') - owner_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, null=True, blank=True) + owner_type = models.ForeignKey( + ContentType, on_delete=models.CASCADE, null=True, blank=True + ) owner_id = models.PositiveIntegerField(null=True, blank=True) @@ -802,7 +801,10 @@ class Owner(models.Model): def __str__(self): """Defines the owner string representation.""" - if self.owner_type.name == 'user' and common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES'): + if ( + self.owner_type.name == 'user' + and common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES') + ): display_name = self.owner.get_full_name() else: display_name = str(self.owner) @@ -810,7 +812,10 @@ class Owner(models.Model): def name(self): """Return the 'name' of this owner.""" - if self.owner_type.name == 'user' and common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES'): + if ( + self.owner_type.name == 'user' + and common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES') + ): return self.owner.get_full_name() return str(self.owner) @@ -839,8 +844,10 @@ class Owner(models.Model): user_model = get_user_model() owner = None content_type_id = 0 - content_type_id_list = [ContentType.objects.get_for_model(Group).id, - ContentType.objects.get_for_model(user_model).id] + content_type_id_list = [ + ContentType.objects.get_for_model(Group).id, + ContentType.objects.get_for_model(user_model).id, + ] # If instance type is obvious: set content type if isinstance(user_or_group, Group): @@ -850,8 +857,9 @@ class Owner(models.Model): if content_type_id: try: - owner = Owner.objects.get(owner_id=user_or_group.id, - owner_type=content_type_id) + owner = Owner.objects.get( + owner_id=user_or_group.id, owner_type=content_type_id + ) except Owner.DoesNotExist: pass @@ -870,10 +878,18 @@ class Owner(models.Model): if include_group: # Include "group-type" owner in the query - query = Q(owner_id__in=users, owner_type=ContentType.objects.get_for_model(user_model).id) | \ - Q(owner_id=self.owner.id, owner_type=ContentType.objects.get_for_model(Group).id) + query = Q( + owner_id__in=users, + owner_type=ContentType.objects.get_for_model(user_model).id, + ) | Q( + owner_id=self.owner.id, + owner_type=ContentType.objects.get_for_model(Group).id, + ) else: - query = Q(owner_id__in=users, owner_type=ContentType.objects.get_for_model(user_model).id) + query = Q( + owner_id__in=users, + owner_type=ContentType.objects.get_for_model(user_model).id, + ) related_owners = Owner.objects.filter(query) diff --git a/InvenTree/users/serializers.py b/InvenTree/users/serializers.py index 9b4fd48353..0aa5502e14 100644 --- a/InvenTree/users/serializers.py +++ b/InvenTree/users/serializers.py @@ -14,13 +14,9 @@ class OwnerSerializer(InvenTreeModelSerializer): class Meta: """Metaclass defines serializer fields.""" + model = Owner - fields = [ - 'pk', - 'owner_id', - 'name', - 'label', - ] + fields = ['pk', 'owner_id', 'name', 'label'] name = serializers.CharField(read_only=True) @@ -34,7 +30,4 @@ class GroupSerializer(InvenTreeModelSerializer): """Metaclass defines serializer fields""" model = Group - fields = [ - 'pk', - 'name', - ] + fields = ['pk', 'name'] diff --git a/InvenTree/users/test_api.py b/InvenTree/users/test_api.py index 96039153db..5417a3920a 100644 --- a/InvenTree/users/test_api.py +++ b/InvenTree/users/test_api.py @@ -14,10 +14,7 @@ class UserAPITests(InvenTreeAPITestCase): def test_user_api(self): """Tests for User API endpoints""" - response = self.get( - reverse('api-user-list'), - expected_code=200 - ) + response = self.get(reverse('api-user-list'), expected_code=200) # Check the correct number of results was returned self.assertEqual(len(response.data), User.objects.count()) @@ -29,8 +26,7 @@ class UserAPITests(InvenTreeAPITestCase): pk = response.data[0]['pk'] response = self.get( - reverse('api-user-detail', kwargs={'pk': pk}), - expected_code=200 + reverse('api-user-detail', kwargs={'pk': pk}), expected_code=200 ) self.assertIn('pk', response.data) @@ -38,10 +34,7 @@ class UserAPITests(InvenTreeAPITestCase): def test_group_api(self): """Tests for the Group API endpoints""" - response = self.get( - reverse('api-group-list'), - expected_code=200, - ) + response = self.get(reverse('api-group-list'), expected_code=200) self.assertIn('name', response.data[0]) @@ -106,7 +99,14 @@ class UserTokenTests(InvenTreeAPITestCase): token.refresh_from_db() # Check that the metadata has been updated - keys = ['user_agent', 'remote_addr', 'remote_host', 'remote_user', 'server_name', 'server_port'] + keys = [ + 'user_agent', + 'remote_addr', + 'remote_host', + 'remote_user', + 'server_name', + 'server_port', + ] for k in keys: self.assertIn(k, token.metadata) @@ -115,7 +115,9 @@ class UserTokenTests(InvenTreeAPITestCase): """Test user token authentication""" # Create a new token - token_key = self.get(url=reverse('api-token'), data={'name': 'test'}, expected_code=200).data['token'] + token_key = self.get( + url=reverse('api-token'), data={'name': 'test'}, expected_code=200 + ).data['token'] # Check that we can use the token to authenticate self.client.logout() diff --git a/InvenTree/users/test_migrations.py b/InvenTree/users/test_migrations.py index 9cc3b2af24..707546ffb5 100644 --- a/InvenTree/users/test_migrations.py +++ b/InvenTree/users/test_migrations.py @@ -15,17 +15,9 @@ class TestForwardMigrations(MigratorTestCase): """Setup the initial state of the database before migrations""" User = self.old_state.apps.get_model('auth', 'user') - User.objects.create( - username='fred', - email='fred@fred.com', - password='password' - ) + User.objects.create(username='fred', email='fred@fred.com', password='password') - User.objects.create( - username='brad', - email='brad@fred.com', - password='password' - ) + User.objects.create(username='brad', email='brad@fred.com', password='password') def test_users_exist(self): """Test that users exist in the database""" diff --git a/InvenTree/users/tests.py b/InvenTree/users/tests.py index de1c443ef5..76eee1c0f6 100644 --- a/InvenTree/users/tests.py +++ b/InvenTree/users/tests.py @@ -21,25 +21,27 @@ class RuleSetModelTest(TestCase): missing = [name for name in RuleSet.RULESET_NAMES if name not in keys] if len(missing) > 0: # pragma: no cover - print("The following rulesets do not have models assigned:") + print('The following rulesets do not have models assigned:') for m in missing: - print("-", m) + print('-', m) # Check if models have been defined for a ruleset which is incorrect extra = [name for name in keys if name not in RuleSet.RULESET_NAMES] if len(extra) > 0: # pragma: no cover - print("The following rulesets have been improperly added to RULESET_MODELS:") + print( + 'The following rulesets have been improperly added to RULESET_MODELS:' + ) for e in extra: - print("-", e) + print('-', e) # Check that each ruleset has models assigned empty = [key for key in keys if len(RuleSet.RULESET_MODELS[key]) == 0] if len(empty) > 0: # pragma: no cover - print("The following rulesets have empty entries in RULESET_MODELS:") + print('The following rulesets have empty entries in RULESET_MODELS:') for e in empty: - print("-", e) + print('-', e) self.assertEqual(len(missing), 0) self.assertEqual(len(extra), 0) @@ -61,23 +63,25 @@ class RuleSetModelTest(TestCase): # Now check that each defined model is a valid table name for key in RuleSet.RULESET_MODELS.keys(): - models = RuleSet.RULESET_MODELS[key] for m in models: - assigned_models.add(m) missing_models = set() for model in available_tables: - if model not in assigned_models and model not in RuleSet.RULESET_IGNORE: # pragma: no cover + if ( + model not in assigned_models and model not in RuleSet.RULESET_IGNORE + ): # pragma: no cover missing_models.add(model) if len(missing_models) > 0: # pragma: no cover - print("The following database models are not covered by the defined RuleSet permissions:") + print( + 'The following database models are not covered by the defined RuleSet permissions:' + ) for m in missing_models: - print("-", m) + print('-', m) extra_models = set() @@ -94,9 +98,9 @@ class RuleSetModelTest(TestCase): extra_models.add(model) if len(extra_models) > 0: # pragma: no cover - print("The following RuleSet permissions do not match a database model:") + print('The following RuleSet permissions do not match a database model:') for m in extra_models: - print("-", m) + print('-', m) self.assertEqual(len(missing_models), 0) self.assertEqual(len(extra_models), 0) @@ -104,7 +108,7 @@ class RuleSetModelTest(TestCase): def test_permission_assign(self): """Test that the permission assigning works!""" # Create a new group - group = Group.objects.create(name="Test group") + group = Group.objects.create(name='Test group') rulesets = group.rule_sets.all() @@ -115,7 +119,6 @@ class RuleSetModelTest(TestCase): permission_set = set() for models in RuleSet.RULESET_MODELS.values(): - for model in models: permission_set.add(model) @@ -208,7 +211,9 @@ class OwnerModelTest(InvenTreeTestCase): # not authed self.do_request(reverse('api-owner-list'), {}, 401) - self.do_request(reverse('api-owner-detail', kwargs={'pk': self.user.id}), {}, 401) + self.do_request( + reverse('api-owner-detail', kwargs={'pk': self.user.id}), {}, 401 + ) self.client.login(username=self.username, password=self.password) # user list @@ -230,7 +235,9 @@ class OwnerModelTest(InvenTreeTestCase): # self.assertEqual(response['owner_id'], group.pk) # own user detail - response_detail = self.do_request(reverse('api-user-detail', kwargs={'pk': self.user.id}), {}, 200) + response_detail = self.do_request( + reverse('api-user-detail', kwargs={'pk': self.user.id}), {}, 200 + ) self.assertEqual(response_detail['username'], self.username) response_me = self.do_request(reverse('api-user-me'), {}, 200) @@ -251,5 +258,7 @@ class OwnerModelTest(InvenTreeTestCase): self.assertEqual(response['token'], token.first().key) # test user is associated with token - response = self.do_request(reverse('api-user-me'), {'name': 'another-token'}, 200) + response = self.do_request( + reverse('api-user-me'), {'name': 'another-token'}, 200 + ) self.assertEqual(response['username'], self.username) diff --git a/InvenTree/web/templatetags/spa_helper.py b/InvenTree/web/templatetags/spa_helper.py index 5da2a0add9..003682402a 100644 --- a/InvenTree/web/templatetags/spa_helper.py +++ b/InvenTree/web/templatetags/spa_helper.py @@ -9,7 +9,7 @@ from django import template from django.conf import settings from django.utils.safestring import mark_safe -logger = getLogger("InvenTree") +logger = getLogger('InvenTree') register = template.Library() FRONTEND_SETTINGS = json.dumps(settings.FRONTEND_SETTINGS) @@ -18,40 +18,45 @@ FRONTEND_SETTINGS = json.dumps(settings.FRONTEND_SETTINGS) @register.simple_tag def spa_bundle(manifest_path: Union[str, Path] = '', app: str = 'web'): """Render SPA bundle.""" + def get_url(file: str) -> str: """Get static url for file.""" - return f"{settings.STATIC_URL}{app}/{file}" + return f'{settings.STATIC_URL}{app}/{file}' if manifest_path == '': - manifest_path = Path(__file__).parent.parent.joinpath("static/web/manifest.json") + manifest_path = Path(__file__).parent.parent.joinpath( + 'static/web/manifest.json' + ) manifest = Path(manifest_path) if not manifest.exists(): - logger.error("Manifest file not found") + logger.error('Manifest file not found') return try: manifest_data = json.load(manifest.open()) except (TypeError, json.decoder.JSONDecodeError): - logger.exception("Failed to parse manifest file") + logger.exception('Failed to parse manifest file') return - return_string = "" + return_string = '' # CSS (based on index.css file as entrypoint) - css_index = manifest_data.get("index.css") + css_index = manifest_data.get('index.css') if css_index: - return_string += f'' + return_string += ( + f'' + ) # JS (based on index.html file as entrypoint) - index = manifest_data.get("index.html") - dynamic_files = index.get("dynamicImports", []) - imports_files = "".join( - [ - f'' - for file in dynamic_files - ] + index = manifest_data.get('index.html') + dynamic_files = index.get('dynamicImports', []) + imports_files = ''.join([ + f'' + for file in dynamic_files + ]) + return_string += ( + f'{imports_files}' ) - return_string += f'{imports_files}' return mark_safe(return_string) @@ -59,4 +64,6 @@ def spa_bundle(manifest_path: Union[str, Path] = '', app: str = 'web'): @register.simple_tag def spa_settings(): """Render settings for spa.""" - return mark_safe(f"""""") + return mark_safe( + f"""""" + ) diff --git a/InvenTree/web/tests.py b/InvenTree/web/tests.py index 7822781f80..1975bfd3f2 100644 --- a/InvenTree/web/tests.py +++ b/InvenTree/web/tests.py @@ -1,4 +1,5 @@ """Tests for PUI backend stuff.""" + import json import os from pathlib import Path @@ -23,23 +24,29 @@ class TemplateTagTest(InvenTreeTestCase): def test_spa_bundle(self): """Test the 'spa_bundle' template tag""" resp = spa_helper.spa_bundle() - self.assertTrue(resp.startswith('window.INVENTREE_SETTINGS=')) - settings_data_string = resp.replace('', '') + settings_data_string = resp.replace( + '', '') settings_data = json.loads(settings_data_string) self.assertSettings(settings_data) @@ -61,8 +68,8 @@ class TemplateTagTest(InvenTreeTestCase): self.assertTrue(rsp['show_server_selector']) # No debug, serverlist -> no selector - envs = {'INVENTREE_PUI_SETTINGS': json.dumps({'server_list': ['aa', 'bb',]})} + envs = {'INVENTREE_PUI_SETTINGS': json.dumps({'server_list': ['aa', 'bb']})} with mock.patch.dict(os.environ, envs): rsp = get_frontend_settings(False) self.assertFalse('show_server_selector' in rsp) - self.assertEqual(rsp['server_list'], ['aa', 'bb',]) + self.assertEqual(rsp['server_list'], ['aa', 'bb']) diff --git a/InvenTree/web/urls.py b/InvenTree/web/urls.py index ea3f893729..16cdb57f40 100644 --- a/InvenTree/web/urls.py +++ b/InvenTree/web/urls.py @@ -1,4 +1,5 @@ """URLs for web app.""" + from django.conf import settings from django.shortcuts import redirect from django.urls import include, path, re_path @@ -16,16 +17,23 @@ class RedirectAssetView(TemplateView): ) -spa_view = ensure_csrf_cookie(TemplateView.as_view(template_name="web/index.html")) -assets_path = path("assets/", RedirectAssetView.as_view()) +spa_view = ensure_csrf_cookie(TemplateView.as_view(template_name='web/index.html')) +assets_path = path('assets/', RedirectAssetView.as_view()) urlpatterns = [ - path(f"{settings.FRONTEND_URL_BASE}/", include([ - assets_path, - path("set-password?uid=&token=", spa_view, name="password_reset_confirm",), - re_path(".*", spa_view), - ])), + path( + f'{settings.FRONTEND_URL_BASE}/', + include([ + assets_path, + path( + 'set-password?uid=&token=', + spa_view, + name='password_reset_confirm', + ), + re_path('.*', spa_view), + ]), + ), assets_path, - path(settings.FRONTEND_URL_BASE, spa_view, name="platform"), + path(settings.FRONTEND_URL_BASE, spa_view, name='platform'), ] diff --git a/ci/check_api_endpoint.py b/ci/check_api_endpoint.py index ad57ec84fb..b216fce32f 100644 --- a/ci/check_api_endpoint.py +++ b/ci/check_api_endpoint.py @@ -5,24 +5,19 @@ import json import requests # We expect the server to be running on the local host -url = "http://localhost:8000/api/" +url = 'http://localhost:8000/api/' -print("Testing InvenTree API endpoint") +print('Testing InvenTree API endpoint') response = requests.get(url) assert response.status_code == 200 -print("- Response 200 OK") +print('- Response 200 OK') data = json.loads(response.text) -required_keys = [ - 'server', - 'version', - 'apiVersion', - 'worker_running', -] +required_keys = ['server', 'version', 'apiVersion', 'worker_running'] for key in required_keys: assert key in data @@ -31,6 +26,6 @@ for key in required_keys: # Check that the worker is running assert data['worker_running'] -print("- Background worker is operational") +print('- Background worker is operational') -print("API Endpoint Tests Passed OK") +print('API Endpoint Tests Passed OK') diff --git a/ci/check_js_templates.py b/ci/check_js_templates.py index 0b80901ffa..3c3cee8e48 100644 --- a/ci/check_js_templates.py +++ b/ci/check_js_templates.py @@ -18,25 +18,24 @@ js_dynamic_dir = os.path.join(template_dir, 'js', 'dynamic') errors = 0 -print("=================================") -print("Checking static javascript files:") -print("=================================") +print('=================================') +print('Checking static javascript files:') +print('=================================') def check_invalid_tag(data): """Check for invalid tags.""" - pattern = r"{%(\w+)" + pattern = r'{%(\w+)' err_count = 0 for idx, line in enumerate(data): - results = re.findall(pattern, line) for result in results: err_count += 1 - print(f" - Error on line {idx+1}: %{{{result[0]}") + print(f' - Error on line {idx+1}: %{{{result[0]}') return err_count @@ -56,14 +55,12 @@ def check_prohibited_tags(data): 'url', ] - pattern = r"{% (\w+)\s" + pattern = r'{% (\w+)\s' err_count = 0 for idx, line in enumerate(data): - for tag in re.findall(pattern, line): - if tag not in allowed_tags: print(f" > Line {idx+1} contains prohibited template tag '{tag}'") err_count += 1 @@ -72,7 +69,6 @@ def check_prohibited_tags(data): for filename in pathlib.Path(js_i18n_dir).rglob('*.js'): - print(f"Checking file 'translated/{os.path.basename(filename)}':") with open(filename, 'r') as js_file: @@ -82,7 +78,6 @@ for filename in pathlib.Path(js_i18n_dir).rglob('*.js'): errors += check_prohibited_tags(data) for filename in pathlib.Path(js_dynamic_dir).rglob('*.js'): - print(f"Checking file 'dynamic/{os.path.basename(filename)}':") # Check that the 'dynamic' files do not contains any translated strings @@ -94,15 +89,14 @@ for filename in pathlib.Path(js_dynamic_dir).rglob('*.js'): err_count = 0 for idx, line in enumerate(data): - results = re.findall(pattern, line) if len(results) > 0: errors += 1 - print(f" > prohibited {{% trans %}} tag found at line {idx + 1}") + print(f' > prohibited {{% trans %}} tag found at line {idx + 1}') if errors > 0: - print(f"Found {errors} incorrect template tags") + print(f'Found {errors} incorrect template tags') sys.exit(errors) diff --git a/ci/check_locale_files.py b/ci/check_locale_files.py index 3bfbd66943..d5e2b89fbe 100644 --- a/ci/check_locale_files.py +++ b/ci/check_locale_files.py @@ -3,7 +3,7 @@ import subprocess import sys -print("Checking for uncommitted locale files...") +print('Checking for uncommitted locale files...') cmd = ['git', 'status'] @@ -19,9 +19,9 @@ for line in str(out.decode()).split('\n'): locales.append(line) if len(locales) > 0: - print("There are {n} unstaged locale files:".format(n=len(locales))) + print('There are {n} unstaged locale files:'.format(n=len(locales))) for lang in locales: - print(" - {l}".format(l=lang)) + print(' - {l}'.format(l=lang)) sys.exit(len(locales)) diff --git a/ci/check_migration_files.py b/ci/check_migration_files.py index 24668350be..d224848f02 100644 --- a/ci/check_migration_files.py +++ b/ci/check_migration_files.py @@ -3,7 +3,7 @@ import subprocess import sys -print("Checking for unstaged migration files...") +print('Checking for unstaged migration files...') cmd = ['git', 'ls-files', '--exclude-standard', '--others'] @@ -20,9 +20,9 @@ for line in str(out.decode()).split('\n'): if len(migrations) == 0: sys.exit(0) -print("There are {n} unstaged migration files:".format(n=len(migrations))) +print('There are {n} unstaged migration files:'.format(n=len(migrations))) for m in migrations: - print(" - {m}".format(m=m)) + print(' - {m}'.format(m=m)) sys.exit(len(migrations)) diff --git a/ci/version_check.py b/ci/version_check.py index 0ea59b73f4..ee90e78949 100644 --- a/ci/version_check.py +++ b/ci/version_check.py @@ -14,6 +14,7 @@ import json import os import re import sys +from pathlib import Path import requests @@ -26,14 +27,16 @@ def get_existing_release_tags(): headers = None if token: - headers = { - "Authorization": f"Bearer {token}" - } + headers = {'Authorization': f'Bearer {token}'} - response = requests.get('https://api.github.com/repos/inventree/inventree/releases', headers=headers) + response = requests.get( + 'https://api.github.com/repos/inventree/inventree/releases', headers=headers + ) if response.status_code != 200: - raise ValueError(f'Unexpected status code from GitHub API: {response.status_code}') + raise ValueError( + f'Unexpected status code from GitHub API: {response.status_code}' + ) data = json.loads(response.text) @@ -42,7 +45,7 @@ def get_existing_release_tags(): for release in data: tag = release['tag_name'].strip() - match = re.match(r"^.*(\d+)\.(\d+)\.(\d+).*$", tag) + match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', tag) if len(match.groups()) != 3: print(f"Version '{tag}' did not match expected pattern") @@ -62,10 +65,12 @@ def check_version_number(version_string, allow_duplicate=False): print(f"Checking version '{version_string}'") # Check that the version string matches the required format - match = re.match(r"^(\d+)\.(\d+)\.(\d+)(?: dev)?$", version_string) + match = re.match(r'^(\d+)\.(\d+)\.(\d+)(?: dev)?$', version_string) if not match or len(match.groups()) != 3: - raise ValueError(f"Version string '{version_string}' did not match required pattern") + raise ValueError( + f"Version string '{version_string}' did not match required pattern" + ) version_tuple = [int(x) for x in match.groups()] @@ -81,15 +86,12 @@ def check_version_number(version_string, allow_duplicate=False): if release > version_tuple: highest_release = False - print(f"Found newer release: {str(release)}") + print(f'Found newer release: {str(release)}') return highest_release if __name__ == '__main__': - - here = os.path.abspath(os.path.dirname(__file__)) - # GITHUB_REF_TYPE may be either 'branch' or 'tag' GITHUB_REF_TYPE = os.environ['GITHUB_REF_TYPE'] @@ -101,24 +103,24 @@ if __name__ == '__main__': GITHUB_BASE_REF = os.environ['GITHUB_BASE_REF'] # Print out version information, makes debugging actions *much* easier! - print(f"GITHUB_REF: {GITHUB_REF}") - print(f"GITHUB_REF_NAME: {GITHUB_REF_NAME}") - print(f"GITHUB_REF_TYPE: {GITHUB_REF_TYPE}") - print(f"GITHUB_BASE_REF: {GITHUB_BASE_REF}") + print(f'GITHUB_REF: {GITHUB_REF}') + print(f'GITHUB_REF_NAME: {GITHUB_REF_NAME}') + print(f'GITHUB_REF_TYPE: {GITHUB_REF_TYPE}') + print(f'GITHUB_BASE_REF: {GITHUB_BASE_REF}') - version_file = os.path.join(here, '..', 'InvenTree', 'InvenTree', 'version.py') + here = Path(__file__).parent.absolute() + version_file = here.joinpath('..', 'InvenTree', 'InvenTree', 'version.py') version = None with open(version_file, 'r') as f: - text = f.read() # Extract the InvenTree software version - results = re.findall(r'INVENTREE_SW_VERSION = "(.*)"', text) + results = re.findall(r"""INVENTREE_SW_VERSION = '(.*)'""", text) if len(results) != 1: - print(f"Could not find INVENTREE_SW_VERSION in {version_file}") + print(f'Could not find INVENTREE_SW_VERSION in {version_file}') sys.exit(1) version = results[0] @@ -162,15 +164,15 @@ if __name__ == '__main__': docker_tags = ['latest'] else: - print("Unsupported branch / version combination:") - print(f"InvenTree Version: {version}") - print("GITHUB_REF_TYPE:", GITHUB_REF_TYPE) - print("GITHUB_BASE_REF:", GITHUB_BASE_REF) - print("GITHUB_REF:", GITHUB_REF) + print('Unsupported branch / version combination:') + print(f'InvenTree Version: {version}') + print('GITHUB_REF_TYPE:', GITHUB_REF_TYPE) + print('GITHUB_BASE_REF:', GITHUB_BASE_REF) + print('GITHUB_REF:', GITHUB_REF) sys.exit(1) if docker_tags is None: - print("Docker tags could not be determined") + print('Docker tags could not be determined') sys.exit(1) print(f"Version check passed for '{version}'!") @@ -178,11 +180,10 @@ if __name__ == '__main__': # Ref: https://getridbug.com/python/how-to-set-environment-variables-in-github-actions-using-python/ with open(os.getenv('GITHUB_ENV'), 'a') as env_file: - # Construct tag string - tags = ",".join([f"inventree/inventree:{tag}" for tag in docker_tags]) + tags = ','.join([f'inventree/inventree:{tag}' for tag in docker_tags]) - env_file.write(f"docker_tags={tags}\n") + env_file.write(f'docker_tags={tags}\n') if GITHUB_REF_TYPE == 'tag' and highest_release: - env_file.write("stable_release=true\n") + env_file.write('stable_release=true\n') diff --git a/docker/gunicorn.conf.py b/docker/gunicorn.conf.py index 256bd5c0b7..6d36a74269 100644 --- a/docker/gunicorn.conf.py +++ b/docker/gunicorn.conf.py @@ -33,7 +33,7 @@ if workers is not None: if workers is None: workers = multiprocessing.cpu_count() * 2 + 1 -logger.info("Starting gunicorn server with %s workers", workers) +logger.info('Starting gunicorn server with %s workers', workers) max_requests = 1000 max_requests_jitter = 50 diff --git a/docs/docs/hooks.py b/docs/docs/hooks.py index fdd6f97464..6b33829c20 100644 --- a/docs/docs/hooks.py +++ b/docs/docs/hooks.py @@ -11,7 +11,7 @@ import requests def fetch_rtd_versions(): """Get a list of RTD docs versions to build the version selector""" - print("Fetching documentation versions from ReadTheDocs") + print('Fetching documentation versions from ReadTheDocs') versions = [] @@ -20,7 +20,7 @@ def fetch_rtd_versions(): response = requests.get(url, headers=headers) if response.status_code != 200: - print(f"Error fetching RTD versions: {response.status_code}") + print(f'Error fetching RTD versions: {response.status_code}') return data = json.loads(response.text) @@ -39,11 +39,7 @@ def fetch_rtd_versions(): if version == 'latest': continue - versions.append({ - 'version': version, - 'title': version, - 'aliases': aliases, - }) + versions.append({'version': version, 'title': version, 'aliases': aliases}) if data['next']: make_request(data['next'], headers) @@ -52,21 +48,24 @@ def fetch_rtd_versions(): token = os.environ.get('RTD_TOKEN', None) if token: headers = {'Authorization': f'Token {token}'} - url = "https://readthedocs.org/api/v3/projects/inventree/versions/?active=true&limit=50" + url = 'https://readthedocs.org/api/v3/projects/inventree/versions/?active=true&limit=50' make_request(url, headers) else: - print("No RTD token found - skipping RTD version fetch") + print('No RTD token found - skipping RTD version fetch') # Sort versions by version number versions = sorted(versions, key=lambda x: StrictVersion(x['version']), reverse=True) # Add "latest" version first if not any((x['title'] == 'latest' for x in versions)): - versions.insert(0, { - 'title': 'Development', - 'version': 'latest', - 'aliases': ['main', 'latest', 'development',], - }) + versions.insert( + 0, + { + 'title': 'Development', + 'version': 'latest', + 'aliases': ['main', 'latest', 'development'], + }, + ) # Ensure we have the 'latest' version current_version = os.environ.get('READTHEDOCS_VERSION', None) @@ -80,7 +79,7 @@ def fetch_rtd_versions(): output_filename = os.path.join(os.path.dirname(__file__), 'versions.json') - print("Discovered the following versions:") + print('Discovered the following versions:') print(versions) with open(output_filename, 'w') as file: @@ -105,14 +104,14 @@ def get_release_data(): return json.loads(f.read()) # Download release information via the GitHub API - print("Fetching InvenTree release information from api.github.com:") + print('Fetching InvenTree release information from api.github.com:') releases = [] # Keep making API requests until we run out of results page = 1 while 1: - url = f"https://api.github.com/repos/inventree/inventree/releases?page={page}&per_page=150" + url = f'https://api.github.com/repos/inventree/inventree/releases?page={page}&per_page=150' response = requests.get(url, timeout=30) assert response.status_code == 200 @@ -164,12 +163,12 @@ def on_config(config, *args, **kwargs): rtd_version = os.environ['READTHEDOCS_VERSION'] rtd_language = os.environ['READTHEDOCS_LANGUAGE'] - site_url = f"https://docs.inventree.org/{rtd_language}/{rtd_version}" - assets_dir = f"/{rtd_language}/{rtd_version}/assets" + site_url = f'https://docs.inventree.org/{rtd_language}/{rtd_version}' + assets_dir = f'/{rtd_language}/{rtd_version}/assets' - print("Building within READTHEDOCS environment!") - print(f" - Version: {rtd_version}") - print(f" - Language: {rtd_language}") + print('Building within READTHEDOCS environment!') + print(f' - Version: {rtd_version}') + print(f' - Language: {rtd_language}') # Add *all* readthedocs related keys readthedocs = {} @@ -188,7 +187,7 @@ def on_config(config, *args, **kwargs): else: print("'READTHEDOCS' environment variable not found") - print("Building for localhost configuration!") + print('Building for localhost configuration!') assets_dir = '/assets' site_url = config['site_url'] @@ -206,7 +205,6 @@ def on_config(config, *args, **kwargs): releases = [] for item in release_data: - # Ignore draft releases if item['draft']: continue @@ -217,15 +215,11 @@ def on_config(config, *args, **kwargs): re.match(r'^\d+\.\d+\.\d+$', tag) if not re.match: - print(f"Found badly formatted release: {tag}") + print(f'Found badly formatted release: {tag}') continue # Check if there is a local file with release information - local_path = os.path.join( - os.path.dirname(__file__), - 'releases', - f'{tag}.md', - ) + local_path = os.path.join(os.path.dirname(__file__), 'releases', f'{tag}.md') if os.path.exists(local_path): item['local_path'] = local_path @@ -244,7 +238,7 @@ def on_config(config, *args, **kwargs): releases.append(item) - print(f"- found {len(releases)} releases.") + print(f'- found {len(releases)} releases.') # Sort releases by descending date config['releases'] = sorted(releases, key=lambda it: it['date'], reverse=True) diff --git a/docs/main.py b/docs/main.py index 8f7f59c36d..b99988f061 100644 --- a/docs/main.py +++ b/docs/main.py @@ -15,13 +15,9 @@ def define_env(env): assets = [] - allowed = [ - '.png', - '.jpg', - ] + allowed = ['.png', '.jpg'] for asset in os.listdir(directory): - if any(asset.endswith(x) for x in allowed): assets.append(os.path.join(subdir, asset)) diff --git a/pyproject.toml b/pyproject.toml index 60e4ece806..38b45743db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,17 +12,16 @@ exclude = [ ".venv", ".env", ] -src = ["../InvenTree"] +src = ["InvenTree"] # line-length = 120 [tool.ruff.extend-per-file-ignores] "__init__.py" = ["D104"] [tool.ruff.lint] -select = ["A", "B", "C4", "D", "N"] +select = ["A", "B", "C4", "D", "I", "N"] # Things that should be enabled in the future: # - LOG -# - I # - DJ # for Django stuff # - S # for security stuff (bandit) @@ -63,13 +62,12 @@ section-order = [ "first-party", "local-folder", ] -known-first-party = ["src", "plugin", "InvenTree", "common"] [tool.ruff.lint.isort.sections] "django" = ["django"] [tool.ruff.format] -quote-style = "preserve" +quote-style = "single" indent-style = "space" skip-magic-trailing-comma = true line-ending = "auto" diff --git a/tasks.py b/tasks.py index 2243690bc3..b05418fb3b 100644 --- a/tasks.py +++ b/tasks.py @@ -22,7 +22,7 @@ def checkPythonVersion(): REQ_MAJOR = 3 REQ_MINOR = 9 - version = sys.version.split(" ")[0] + version = sys.version.split(' ')[0] valid = True @@ -33,8 +33,8 @@ def checkPythonVersion(): valid = False if not valid: - print(f"The installed python version ({version}) is not supported!") - print(f"InvenTree requires Python {REQ_MAJOR}.{REQ_MINOR} or above") + print(f'The installed python version ({version}) is not supported!') + print(f'InvenTree requires Python {REQ_MAJOR}.{REQ_MINOR} or above') sys.exit(1) @@ -62,26 +62,26 @@ def apps(): def content_excludes(): """Returns a list of content types to exclude from import/export.""" excludes = [ - "contenttypes", - "auth.permission", - "users.apitoken", - "error_report.error", - "admin.logentry", - "django_q.schedule", - "django_q.task", - "django_q.ormq", - "users.owner", - "exchange.rate", - "exchange.exchangebackend", - "common.notificationentry", - "common.notificationmessage", - "user_sessions.session", + 'contenttypes', + 'auth.permission', + 'users.apitoken', + 'error_report.error', + 'admin.logentry', + 'django_q.schedule', + 'django_q.task', + 'django_q.ormq', + 'users.owner', + 'exchange.rate', + 'exchange.exchangebackend', + 'common.notificationentry', + 'common.notificationmessage', + 'user_sessions.session', ] - output = "" + output = '' for e in excludes: - output += f"--exclude {e} " + output += f'--exclude {e} ' return output @@ -113,10 +113,10 @@ def manage(c, cmd, pty: bool = False): cmd: Django command to run. pty (bool, optional): Run an interactive session. Defaults to False. """ - c.run('cd "{path}" && python3 manage.py {cmd}'.format( - path=managePyDir(), - cmd=cmd - ), pty=pty) + c.run( + 'cd "{path}" && python3 manage.py {cmd}'.format(path=managePyDir(), cmd=cmd), + pty=pty, + ) def yarn(c, cmd, pty: bool = False): @@ -133,6 +133,7 @@ def yarn(c, cmd, pty: bool = False): def node_available(versions: bool = False, bypass_yarn: bool = False): """Checks if the frontend environment (ie node and yarn in bash) is available.""" + def ret(val, val0=None, val1=None): if versions: return val, val0, val1 @@ -140,7 +141,10 @@ def node_available(versions: bool = False, bypass_yarn: bool = False): def check(cmd): try: - return str(subprocess.check_output([cmd], stderr=subprocess.STDOUT, shell=True), encoding='utf-8').strip() + return str( + subprocess.check_output([cmd], stderr=subprocess.STDOUT, shell=True), + encoding='utf-8', + ).strip() except subprocess.CalledProcessError: return None except FileNotFoundError: @@ -154,7 +158,9 @@ def node_available(versions: bool = False, bypass_yarn: bool = False): # Print a warning if node is available but yarn is not if node_version and not yarn_passes: - print('Node is available but yarn is not. Install yarn if you wish to build the frontend.') + print( + 'Node is available but yarn is not. Install yarn if you wish to build the frontend.' + ) # Return the result return ret(yarn_passes and node_version, node_version, yarn_version) @@ -168,11 +174,13 @@ def check_file_existance(filename: str, overwrite: bool = False): overwrite (bool, optional): Overwrite the file without asking. Defaults to False. """ if Path(filename).is_file() and overwrite is False: - response = input("Warning: file already exists. Do you want to overwrite? [y/N]: ") + response = input( + 'Warning: file already exists. Do you want to overwrite? [y/N]: ' + ) response = str(response).strip().lower() if response not in ['y', 'yes']: - print("Cancelled export operation") + print('Cancelled export operation') sys.exit(1) @@ -198,7 +206,9 @@ def install(c): # Install required Python packages with PIP c.run('pip3 install --upgrade pip') c.run('pip3 install --upgrade setuptools') - c.run('pip3 install --no-cache-dir --disable-pip-version-check -U -r requirements.txt') + c.run( + 'pip3 install --no-cache-dir --disable-pip-version-check -U -r requirements.txt' + ) @task(help={'tests': 'Set up test dataset at the end'}) @@ -210,12 +220,12 @@ def setup_dev(c, tests=False): c.run('pip3 install -U -r requirements-dev.txt') # Install pre-commit hook - print("Installing pre-commit for checks before git commits...") + print('Installing pre-commit for checks before git commits...') c.run('pre-commit install') # Update all the hooks c.run('pre-commit autoupdate') - print("pre-commit set up is done...") + print('pre-commit set up is done...') # Set up test-data if flag is set if tests: @@ -232,19 +242,19 @@ def superuser(c): @task def rebuild_models(c): """Rebuild database models with MPTT structures.""" - manage(c, "rebuild_models", pty=True) + manage(c, 'rebuild_models', pty=True) @task def rebuild_thumbnails(c): """Rebuild missing image thumbnails.""" - manage(c, "rebuild_thumbnails", pty=True) + manage(c, 'rebuild_thumbnails', pty=True) @task def clean_settings(c): """Clean the setting tables of old settings.""" - manage(c, "clean_settings") + manage(c, 'clean_settings') @task(help={'mail': "mail of the user who's MFA should be disabled"}) @@ -253,20 +263,16 @@ def remove_mfa(c, mail=''): if not mail: print('You must provide a users mail') - manage(c, f"remove_mfa {mail}") + manage(c, f'remove_mfa {mail}') -@task( - help={ - 'frontend': 'Build the frontend', - } -) +@task(help={'frontend': 'Build the frontend'}) def static(c, frontend=False): """Copies required static files to the STATIC_ROOT directory, as per Django requirements.""" - manage(c, "prerender") + manage(c, 'prerender') if frontend and node_available(): frontend_build(c) - manage(c, "collectstatic --no-input") + manage(c, 'collectstatic --no-input') @task @@ -280,7 +286,7 @@ def translate_stats(c): try: manage(c, 'compilemessages', pty=True) except Exception: - print("WARNING: Translation files could not be compiled:") + print('WARNING: Translation files could not be compiled:') path = Path('InvenTree', 'script', 'translation_stats.py') c.run(f'python3 {path}') @@ -294,8 +300,8 @@ def translate(c): it is performed as part of the InvenTree translation toolchain. """ # Translate applicable .py / .html / .js / .tsx files - manage(c, "makemessages --all -e py,html,js --no-wrap") - manage(c, "compilemessages") + manage(c, 'makemessages --all -e py,html,js --no-wrap') + manage(c, 'compilemessages') if node_available(): frontend_install(c) @@ -309,19 +315,19 @@ def translate(c): @task def backup(c): """Backup the database and media files.""" - print("Backing up InvenTree database...") - manage(c, "dbbackup --noinput --clean --compress") - print("Backing up InvenTree media files...") - manage(c, "mediabackup --noinput --clean --compress") + print('Backing up InvenTree database...') + manage(c, 'dbbackup --noinput --clean --compress') + print('Backing up InvenTree media files...') + manage(c, 'mediabackup --noinput --clean --compress') @task def restore(c): """Restore the database and media files.""" - print("Restoring InvenTree database...") - manage(c, "dbrestore --noinput --uncompress") - print("Restoring InvenTree media files...") - manage(c, "mediarestore --noinput --uncompress") + print('Restoring InvenTree database...') + manage(c, 'dbrestore --noinput --uncompress') + print('Restoring InvenTree media files...') + manage(c, 'mediarestore --noinput --uncompress') @task(post=[rebuild_models, rebuild_thumbnails]) @@ -330,16 +336,16 @@ def migrate(c): This is a critical step if the database schema have been altered! """ - print("Running InvenTree database migrations...") - print("========================================") + print('Running InvenTree database migrations...') + print('========================================') - manage(c, "makemigrations") - manage(c, "migrate --noinput") - manage(c, "migrate --run-syncdb") - manage(c, "check") + manage(c, 'makemigrations') + manage(c, 'migrate --noinput') + manage(c, 'migrate --run-syncdb') + manage(c, 'check') - print("========================================") - print("InvenTree database migrations completed!") + print('========================================') + print('InvenTree database migrations completed!') @task( @@ -347,8 +353,8 @@ def migrate(c): help={ 'skip_backup': 'Skip database backup step (advanced users)', 'frontend': 'Force frontend compilation/download step (ignores INVENTREE_DOCKER)', - 'no_frontend': 'Skip frontend compilation/download step' - } + 'no_frontend': 'Skip frontend compilation/download step', + }, ) def update(c, skip_backup=False, frontend: bool = False, no_frontend: bool = False): """Update InvenTree installation. @@ -390,13 +396,21 @@ def update(c, skip_backup=False, frontend: bool = False, no_frontend: bool = Fal # Data tasks -@task(help={ - 'filename': "Output filename (default = 'data.json')", - 'overwrite': "Overwrite existing files without asking first (default = off/False)", - 'include_permissions': "Include user and group permissions in the output file (filename) (default = off/False)", - 'delete_temp': "Delete temporary files (containing permissions) at end of run. Note that this will delete temporary files from previous runs as well. (default = off/False)" -}) -def export_records(c, filename='data.json', overwrite=False, include_permissions=False, delete_temp=False): +@task( + help={ + 'filename': "Output filename (default = 'data.json')", + 'overwrite': 'Overwrite existing files without asking first (default = off/False)', + 'include_permissions': 'Include user and group permissions in the output file (filename) (default = off/False)', + 'delete_temp': 'Delete temporary files (containing permissions) at end of run. Note that this will delete temporary files from previous runs as well. (default = off/False)', + } +) +def export_records( + c, + filename='data.json', + overwrite=False, + include_permissions=False, + delete_temp=False, +): """Export all database records to a file. Write data to the file defined by filename. @@ -422,43 +436,45 @@ def export_records(c, filename='data.json', overwrite=False, include_permissions check_file_existance(filename, overwrite) - tmpfile = f"{filename}.tmp" + tmpfile = f'{filename}.tmp' cmd = f"dumpdata --indent 2 --output '{tmpfile}' {content_excludes()}" # Dump data to temporary file manage(c, cmd, pty=True) - print("Running data post-processing step...") + print('Running data post-processing step...') # Post-process the file, to remove any "permissions" specified for a user or group - with open(tmpfile, "r") as f_in: + with open(tmpfile, 'r') as f_in: data = json.loads(f_in.read()) if include_permissions is False: for entry in data: - if "model" in entry: - + if 'model' in entry: # Clear out any permissions specified for a group - if entry["model"] == "auth.group": - entry["fields"]["permissions"] = [] + if entry['model'] == 'auth.group': + entry['fields']['permissions'] = [] # Clear out any permissions specified for a user - if entry["model"] == "auth.user": - entry["fields"]["user_permissions"] = [] + if entry['model'] == 'auth.user': + entry['fields']['user_permissions'] = [] # Write the processed data to file - with open(filename, "w") as f_out: + with open(filename, 'w') as f_out: f_out.write(json.dumps(data, indent=2)) - print("Data export completed") + print('Data export completed') if delete_temp is True: - print("Removing temporary file") + print('Removing temporary file') os.remove(tmpfile) -@task(help={'filename': 'Input filename', 'clear': 'Clear existing data before import'}, post=[rebuild_models, rebuild_thumbnails]) +@task( + help={'filename': 'Input filename', 'clear': 'Clear existing data before import'}, + post=[rebuild_models, rebuild_thumbnails], +) def import_records(c, filename='data.json', clear=False): """Import database records from a file.""" # Get an absolute path to the supplied filename @@ -475,31 +491,30 @@ def import_records(c, filename='data.json', clear=False): print(f"Importing database records from '{filename}'") # Pre-process the data, to remove any "permissions" specified for a user or group - tmpfile = f"{filename}.tmp.json" + tmpfile = f'{filename}.tmp.json' - with open(filename, "r") as f_in: + with open(filename, 'r') as f_in: data = json.loads(f_in.read()) for entry in data: - if "model" in entry: - + if 'model' in entry: # Clear out any permissions specified for a group - if entry["model"] == "auth.group": - entry["fields"]["permissions"] = [] + if entry['model'] == 'auth.group': + entry['fields']['permissions'] = [] # Clear out any permissions specified for a user - if entry["model"] == "auth.user": - entry["fields"]["user_permissions"] = [] + if entry['model'] == 'auth.user': + entry['fields']['user_permissions'] = [] # Write the processed data to the tmp file - with open(tmpfile, "w") as f_out: + with open(tmpfile, 'w') as f_out: f_out.write(json.dumps(data, indent=2)) cmd = f"loaddata '{tmpfile}' -i {content_excludes()}" manage(c, cmd, pty=True) - print("Data import completed") + print('Data import completed') @task @@ -508,7 +523,7 @@ def delete_data(c, force=False): Warning: This will REALLY delete all records in the database!! """ - print("Deleting all data from InvenTree database...") + print('Deleting all data from InvenTree database...') if force: manage(c, 'flush --noinput') @@ -530,32 +545,26 @@ def import_fixtures(c): fixtures = [ # Build model 'build', - # Common models 'settings', - # Company model 'company', 'price_breaks', 'supplier_part', - # Order model 'order', - # Part model 'bom', 'category', 'params', 'part', 'test_templates', - # Stock model 'location', 'stock_tests', 'stock', - # Users - 'users' + 'users', ] command = 'loaddata ' + ' '.join(fixtures) @@ -567,16 +576,16 @@ def import_fixtures(c): @task def wait(c): """Wait until the database connection is ready.""" - return manage(c, "wait_for_db") + return manage(c, 'wait_for_db') @task(pre=[wait], help={'address': 'Server address:port (default=127.0.0.1:8000)'}) -def server(c, address="127.0.0.1:8000"): +def server(c, address='127.0.0.1:8000'): """Launch a (development) server using Django's in-built webserver. Note: This is *not* sufficient for a production installation. """ - manage(c, "runserver {address}".format(address=address), pty=True) + manage(c, 'runserver {address}'.format(address=address), pty=True) @task(pre=[wait]) @@ -589,7 +598,7 @@ def worker(c): @task def render_js_files(c): """Render templated javascript files (used for static testing).""" - manage(c, "test InvenTree.ci_render_js") + manage(c, 'test InvenTree.ci_render_js') @task(post=[translate_stats, static, server]) @@ -607,40 +616,44 @@ def test_translations(c): django.setup() # Add language - print("Add dummy language...") - print("========================================") - manage(c, "makemessages -e py,html,js --no-wrap -l xx") + print('Add dummy language...') + print('========================================') + manage(c, 'makemessages -e py,html,js --no-wrap -l xx') # change translation - print("Fill in dummy translations...") - print("========================================") + print('Fill in dummy translations...') + print('========================================') file_path = pathlib.Path(settings.LOCALE_PATHS[0], 'xx', 'LC_MESSAGES', 'django.po') new_file_path = str(file_path) + '_new' # compile regex reg = re.compile( - r"[a-zA-Z0-9]{1}" + # match any single letter and number # noqa: W504 - r"(?![^{\(\<]*[}\)\>])" + # that is not inside curly brackets, brackets or a tag # noqa: W504 - r"(?])' # that is not inside curly brackets, brackets or a tag # noqa: W504 + + r'(? replace regex matches with x in the read in (multi)string file_new.write(f'msgstr "{reg.sub("x", last_string[7:-2])}"\n') - last_string = "" # reset (multi)string + last_string = '' # reset (multi)string elif line.startswith('msgid "'): - last_string = last_string + line # a new translatable string starts -> start append + last_string = ( + last_string + line + ) # a new translatable string starts -> start append file_new.write(line) else: if last_string: - last_string = last_string + line # a string is being read in -> continue appending + last_string = ( + last_string + line + ) # a string is being read in -> continue appending file_new.write(line) # change out translation files @@ -648,9 +661,9 @@ def test_translations(c): new_file_path.rename(file_path) # compile languages - print("Compile languages ...") - print("========================================") - manage(c, "compilemessages") + print('Compile languages ...') + print('========================================') + manage(c, 'compilemessages') # reset cwd os.chdir(base_path) @@ -668,7 +681,9 @@ def test_translations(c): 'coverage': 'Run code coverage analysis (requires coverage package)', } ) -def test(c, disable_pty=False, runtest='', migrations=False, report=False, coverage=False): +def test( + c, disable_pty=False, runtest='', migrations=False, report=False, coverage=False +): """Run unit-tests for InvenTree codebase. To run only certain test, use the argument --runtest. @@ -713,7 +728,7 @@ def test(c, disable_pty=False, runtest='', migrations=False, report=False, cover @task(help={'dev': 'Set up development environment at the end'}) -def setup_test(c, ignore_update=False, dev=False, path="inventree-demo-dataset"): +def setup_test(c, ignore_update=False, dev=False, path='inventree-demo-dataset'): """Setup a testing environment.""" from InvenTree.InvenTree.config import get_media_dir @@ -722,41 +737,43 @@ def setup_test(c, ignore_update=False, dev=False, path="inventree-demo-dataset") # Remove old data directory if os.path.exists(path): - print("Removing old data ...") + print('Removing old data ...') c.run(f'rm {path} -r') # Get test data - print("Cloning demo dataset ...") + print('Cloning demo dataset ...') c.run(f'git clone https://github.com/inventree/demo-dataset {path} -v --depth=1') - print("========================================") + print('========================================') # Make sure migrations are done - might have just deleted sqlite database if not ignore_update: migrate(c) # Load data - print("Loading database records ...") + print('Loading database records ...') import_records(c, filename=f'{path}/inventree_data.json', clear=True) # Copy media files - print("Copying media files ...") + print('Copying media files ...') src = Path(path).joinpath('media').resolve() dst = get_media_dir() shutil.copytree(src, dst, dirs_exist_ok=True) - print("Done setting up test environment...") - print("========================================") + print('Done setting up test environment...') + print('========================================') # Set up development setup if flag is set if dev: setup_dev(c) -@task(help={ - 'filename': "Output filename (default = 'schema.yml')", - 'overwrite': "Overwrite existing files without asking first (default = off/False)", -}) +@task( + help={ + 'filename': "Output filename (default = 'schema.yml')", + 'overwrite': 'Overwrite existing files without asking first (default = off/False)', + } +) def schema(c, filename='schema.yml', overwrite=False): """Export current API schema.""" check_file_existance(filename, overwrite) @@ -767,13 +784,17 @@ def schema(c, filename='schema.yml', overwrite=False): def version(c): """Show the current version of InvenTree.""" import InvenTree.InvenTree.version as InvenTreeVersion - from InvenTree.InvenTree.config import (get_config_file, get_media_dir, - get_static_dir) + from InvenTree.InvenTree.config import ( + get_config_file, + get_media_dir, + get_static_dir, + ) # Gather frontend version information _, node, yarn = node_available(versions=True) - print(f""" + print( + f""" InvenTree - inventree.org The Open-Source Inventory Management System\n @@ -792,13 +813,16 @@ Node {node if node else 'N/A'} Yarn {yarn if yarn else 'N/A'} Commit hash:{InvenTreeVersion.inventreeCommitHash()} -Commit date:{InvenTreeVersion.inventreeCommitDate()}""") +Commit date:{InvenTreeVersion.inventreeCommitDate()}""" + ) if len(sys.argv) == 1 and sys.argv[0].startswith('/opt/inventree/env/lib/python'): - print(""" + print( + """ You are probably running the package installer / single-line installer. Please mentioned that in any bug reports! Use '--list' for a list of available commands -Use '--help' for help on a specific command""") +Use '--help' for help on a specific command""" + ) @task() @@ -826,8 +850,8 @@ def frontend_install(c): Args: c: Context variable """ - print("Installing frontend dependencies") - yarn(c, "yarn install") + print('Installing frontend dependencies') + yarn(c, 'yarn install') @task @@ -837,9 +861,9 @@ def frontend_trans(c): Args: c: Context variable """ - print("Compiling frontend translations") - yarn(c, "yarn run extract") - yarn(c, "yarn run compile") + print('Compiling frontend translations') + yarn(c, 'yarn run extract') + yarn(c, 'yarn run compile') @task @@ -849,8 +873,8 @@ def frontend_build(c): Args: c: Context variable """ - print("Building frontend") - yarn(c, "yarn run build --emptyOutDir") + print('Building frontend') + yarn(c, 'yarn run build --emptyOutDir') @task @@ -860,19 +884,29 @@ def frontend_dev(c): Args: c: Context variable """ - print("Starting frontend development server") - yarn(c, "yarn run dev") + print('Starting frontend development server') + yarn(c, 'yarn run dev') -@task(help={ - 'ref': "git ref, default: current git ref", - 'tag': "git tag to look for release", - 'file': "destination to frontend-build.zip file", - 'repo': "GitHub repository, default: InvenTree/inventree", - 'extract': "Also extract and place at the correct destination, default: True", - 'clean': "Delete old files from InvenTree/web/static/web first, default: True", -}) -def frontend_download(c, ref=None, tag=None, file=None, repo="InvenTree/inventree", extract=True, clean=True): +@task( + help={ + 'ref': 'git ref, default: current git ref', + 'tag': 'git tag to look for release', + 'file': 'destination to frontend-build.zip file', + 'repo': 'GitHub repository, default: InvenTree/inventree', + 'extract': 'Also extract and place at the correct destination, default: True', + 'clean': 'Delete old files from InvenTree/web/static/web first, default: True', + } +) +def frontend_download( + c, + ref=None, + tag=None, + file=None, + repo='InvenTree/inventree', + extract=True, + clean=True, +): """Download a pre-build frontend from GitHub if you dont want to install nodejs on your machine. There are 3 possibilities to install the frontend: @@ -894,7 +928,7 @@ def frontend_download(c, ref=None, tag=None, file=None, repo="InvenTree/inventre import requests # globals - default_headers = {"Accept": "application/vnd.github.v3+json"} + default_headers = {'Accept': 'application/vnd.github.v3+json'} # helper functions def find_resource(resource, key, value): @@ -908,30 +942,34 @@ def frontend_download(c, ref=None, tag=None, file=None, repo="InvenTree/inventre if not extract: return - dest_path = Path(__file__).parent / "InvenTree/web/static/web" + dest_path = Path(__file__).parent / 'InvenTree/web/static/web' # if clean, delete static/web directory if clean: shutil.rmtree(dest_path, ignore_errors=True) os.makedirs(dest_path) - print(f"Cleaned directory: {dest_path}") + print(f'Cleaned directory: {dest_path}') # unzip build to static folder - with ZipFile(file, "r") as zip_ref: + with ZipFile(file, 'r') as zip_ref: zip_ref.extractall(dest_path) - print(f"Unzipped downloaded frontend build to: {dest_path}") + print(f'Unzipped downloaded frontend build to: {dest_path}') def handle_download(url): # download frontend-build.zip to temporary file - with requests.get(url, headers=default_headers, stream=True, allow_redirects=True) as response, NamedTemporaryFile(suffix=".zip") as dst: + with requests.get( + url, headers=default_headers, stream=True, allow_redirects=True + ) as response, NamedTemporaryFile(suffix='.zip') as dst: response.raise_for_status() # auto decode the gzipped raw data - response.raw.read = functools.partial(response.raw.read, decode_content=True) - with open(dst.name, "wb") as f: + response.raw.read = functools.partial( + response.raw.read, decode_content=True + ) + with open(dst.name, 'wb') as f: shutil.copyfileobj(response.raw, f) - print(f"Downloaded frontend build to temporary file: {dst.name}") + print(f'Downloaded frontend build to temporary file: {dst.name}') handle_extract(dst.name) @@ -942,51 +980,72 @@ def frontend_download(c, ref=None, tag=None, file=None, repo="InvenTree/inventre # check arguments if ref is not None and tag is not None: - print("[ERROR] Do not set ref and tag.") + print('[ERROR] Do not set ref and tag.') return if ref is None and tag is None: try: - ref = subprocess.check_output(["git", "rev-parse", "HEAD"], encoding="utf-8").strip() + ref = subprocess.check_output( + ['git', 'rev-parse', 'HEAD'], encoding='utf-8' + ).strip() except Exception: print("[ERROR] Cannot get current ref via 'git rev-parse HEAD'") return if ref is None and tag is None: - print("[ERROR] Either ref or tag needs to be set.") + print('[ERROR] Either ref or tag needs to be set.') if tag: - tag = tag.lstrip("v") + tag = tag.lstrip('v') try: - handle_download(f"https://github.com/{repo}/releases/download/{tag}/frontend-build.zip") + handle_download( + f'https://github.com/{repo}/releases/download/{tag}/frontend-build.zip' + ) except Exception as e: if not isinstance(e, requests.HTTPError): raise e - print(f"""[ERROR] An Error occurred. Unable to download frontend build, release or build does not exist, + print( + f"""[ERROR] An Error occurred. Unable to download frontend build, release or build does not exist, try downloading the frontend-build.zip yourself via: https://github.com/{repo}/releases -Then try continuing by running: invoke frontend-download --file """) +Then try continuing by running: invoke frontend-download --file """ + ) return if ref: # get workflow run from all workflow runs on that particular ref - workflow_runs = requests.get(f"https://api.github.com/repos/{repo}/actions/runs?head_sha={ref}", headers=default_headers).json() + workflow_runs = requests.get( + f'https://api.github.com/repos/{repo}/actions/runs?head_sha={ref}', + headers=default_headers, + ).json() - if not (qc_run := find_resource(workflow_runs["workflow_runs"], "name", "QC")): - print("[ERROR] Cannot find any workflow runs for current sha") + if not (qc_run := find_resource(workflow_runs['workflow_runs'], 'name', 'QC')): + print('[ERROR] Cannot find any workflow runs for current sha') return - print(f"Found workflow {qc_run['name']} (run {qc_run['run_number']}-{qc_run['run_attempt']})") + print( + f"Found workflow {qc_run['name']} (run {qc_run['run_number']}-{qc_run['run_attempt']})" + ) # get frontend-build artifact from all artifacts available for this workflow run - artifacts = requests.get(qc_run["artifacts_url"], headers=default_headers).json() - if not (frontend_artifact := find_resource(artifacts["artifacts"], "name", "frontend-build")): - print("[ERROR] Cannot find frontend-build.zip attachment for current sha") + artifacts = requests.get( + qc_run['artifacts_url'], headers=default_headers + ).json() + if not ( + frontend_artifact := find_resource( + artifacts['artifacts'], 'name', 'frontend-build' + ) + ): + print('[ERROR] Cannot find frontend-build.zip attachment for current sha') return - print(f"Found artifact {frontend_artifact['name']} with id {frontend_artifact['id']} ({frontend_artifact['size_in_bytes']/1e6:.2f}MB).") + print( + f"Found artifact {frontend_artifact['name']} with id {frontend_artifact['id']} ({frontend_artifact['size_in_bytes']/1e6:.2f}MB)." + ) - print(f""" + print( + f""" GitHub doesn't allow artifact downloads from anonymous users. Either download the following file via your signed in browser, or consider using a point release download via invoke frontend-download --tag Download: https://github.com/{repo}/suites/{qc_run['check_suite_id']}/artifacts/{frontend_artifact['id']} manually and - continue by running: invoke frontend-download --file """) + continue by running: invoke frontend-download --file """ + )