Enable more Ruff rules (#7930)

* bump version

* Add more checks

* add simplify rules

* Add RUF rules

* small perf imrpovements

* pylint checks

* more style fixes

* fix a number of A002 cases

* fix A001 cases

* disable unsafe fixes

* remove unneeded branches
fixes SIM102

* re-enable .keys for specific case

* Revert "remove unneeded branches"

This reverts commit f74d41bc07.

* fix reference
This commit is contained in:
Matthias Mair 2024-08-27 01:04:55 +02:00 committed by GitHub
parent bcbbae0a18
commit 1634258783
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
127 changed files with 526 additions and 740 deletions

View File

@ -71,7 +71,7 @@ def check_prohibited_tags(data):
for filename in pathlib.Path(js_i18n_dir).rglob('*.js'): for filename in pathlib.Path(js_i18n_dir).rglob('*.js'):
print(f"Checking file 'translated/{os.path.basename(filename)}':") print(f"Checking file 'translated/{os.path.basename(filename)}':")
with open(filename, 'r') as js_file: with open(filename, encoding='utf-8') as js_file:
data = js_file.readlines() data = js_file.readlines()
errors += check_invalid_tag(data) errors += check_invalid_tag(data)
@ -81,7 +81,7 @@ for filename in pathlib.Path(js_dynamic_dir).rglob('*.js'):
print(f"Checking file 'dynamic/{os.path.basename(filename)}':") print(f"Checking file 'dynamic/{os.path.basename(filename)}':")
# Check that the 'dynamic' files do not contains any translated strings # Check that the 'dynamic' files do not contains any translated strings
with open(filename, 'r') as js_file: with open(filename, encoding='utf-8') as js_file:
data = js_file.readlines() data = js_file.readlines()
invalid_tags = ['blocktrans', 'blocktranslate', 'trans', 'translate'] invalid_tags = ['blocktrans', 'blocktranslate', 'trans', 'translate']

View File

@ -20,9 +20,9 @@ for line in str(out.decode()).split('\n'):
if len(migrations) == 0: if len(migrations) == 0:
sys.exit(0) sys.exit(0)
print('There are {n} unstaged migration files:'.format(n=len(migrations))) print(f'There are {len(migrations)} unstaged migration files:')
for m in migrations: for m in migrations:
print(' - {m}'.format(m=m)) print(f' - {m}')
sys.exit(len(migrations)) sys.exit(len(migrations))

View File

@ -89,7 +89,7 @@ def check_version_number(version_string, allow_duplicate=False):
if release > version_tuple: if release > version_tuple:
highest_release = False highest_release = False
print(f'Found newer release: {str(release)}') print(f'Found newer release: {release!s}')
return highest_release return highest_release
@ -134,7 +134,7 @@ if __name__ == '__main__':
version = None version = None
with open(version_file, 'r') as f: with open(version_file, encoding='utf-8') as f:
text = f.read() text = f.read()
# Extract the InvenTree software version # Extract the InvenTree software version
@ -175,10 +175,7 @@ if __name__ == '__main__':
print(f"Version number '{version}' does not match tag '{version_tag}'") print(f"Version number '{version}' does not match tag '{version_tag}'")
sys.exit sys.exit
if highest_release: docker_tags = [version_tag, 'stable'] if highest_release else [version_tag]
docker_tags = [version_tag, 'stable']
else:
docker_tags = [version_tag]
elif GITHUB_REF_TYPE == 'branch': elif GITHUB_REF_TYPE == 'branch':
# Otherwise we know we are targeting the 'master' branch # Otherwise we know we are targeting the 'master' branch
@ -202,7 +199,7 @@ if __name__ == '__main__':
target_repos = [REPO.lower(), f'ghcr.io/{REPO.lower()}'] target_repos = [REPO.lower(), f'ghcr.io/{REPO.lower()}']
# Ref: https://getridbug.com/python/how-to-set-environment-variables-in-github-actions-using-python/ # Ref: https://getridbug.com/python/how-to-set-environment-variables-in-github-actions-using-python/
with open(os.getenv('GITHUB_ENV'), 'a') as env_file: with open(os.getenv('GITHUB_ENV'), 'a', encoding='utf-8') as env_file:
# Construct tag string # Construct tag string
tag_list = [[f'{r}:{t}' for t in docker_tags] for r in target_repos] tag_list = [[f'{r}:{t}' for t in docker_tags] for r in target_repos]
tags = ','.join(itertools.chain(*tag_list)) tags = ','.join(itertools.chain(*tag_list))

View File

@ -17,17 +17,18 @@ repos:
- id: check-yaml - id: check-yaml
- id: mixed-line-ending - id: mixed-line-ending
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.1 rev: v0.6.1
hooks: hooks:
- id: ruff-format - id: ruff-format
args: [--preview] args: [--preview]
- id: ruff - id: ruff
args: [ args: [
--fix, --fix,
# --unsafe-fixes,
--preview --preview
] ]
- repo: https://github.com/astral-sh/uv-pre-commit - repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.2.13 rev: 0.2.37
hooks: hooks:
- id: pip-compile - id: pip-compile
name: pip-compile requirements-dev.in name: pip-compile requirements-dev.in
@ -77,7 +78,7 @@ repos:
- "prettier@^2.4.1" - "prettier@^2.4.1"
- "@trivago/prettier-plugin-sort-imports" - "@trivago/prettier-plugin-sort-imports"
- repo: https://github.com/pre-commit/mirrors-eslint - repo: https://github.com/pre-commit/mirrors-eslint
rev: "v9.6.0" rev: "v9.9.0"
hooks: hooks:
- id: eslint - id: eslint
additional_dependencies: additional_dependencies:

View File

@ -10,7 +10,7 @@ tld = os.path.abspath(os.path.join(here, '..'))
config_file = os.path.join(tld, 'mkdocs.yml') config_file = os.path.join(tld, 'mkdocs.yml')
with open(config_file, 'r') as f: with open(config_file, encoding='utf-8') as f:
data = yaml.load(f, yaml.BaseLoader) data = yaml.load(f, yaml.BaseLoader)
assert data['strict'] == 'true' assert data['strict'] == 'true'

View File

@ -57,7 +57,7 @@ def fetch_rtd_versions():
versions = sorted(versions, key=lambda x: StrictVersion(x['version']), reverse=True) versions = sorted(versions, key=lambda x: StrictVersion(x['version']), reverse=True)
# Add "latest" version first # Add "latest" version first
if not any((x['title'] == 'latest' for x in versions)): if not any(x['title'] == 'latest' for x in versions):
versions.insert( versions.insert(
0, 0,
{ {
@ -70,7 +70,7 @@ def fetch_rtd_versions():
# Ensure we have the 'latest' version # Ensure we have the 'latest' version
current_version = os.environ.get('READTHEDOCS_VERSION', None) current_version = os.environ.get('READTHEDOCS_VERSION', None)
if current_version and not any((x['title'] == current_version for x in versions)): if current_version and not any(x['title'] == current_version for x in versions):
versions.append({ versions.append({
'version': current_version, 'version': current_version,
'title': current_version, 'title': current_version,
@ -82,7 +82,7 @@ def fetch_rtd_versions():
print('Discovered the following versions:') print('Discovered the following versions:')
print(versions) print(versions)
with open(output_filename, 'w') as file: with open(output_filename, 'w', encoding='utf-8') as file:
json.dump(versions, file, indent=2) json.dump(versions, file, indent=2)
@ -100,7 +100,7 @@ def get_release_data():
# Release information has been cached to file # Release information has been cached to file
print("Loading release information from 'releases.json'") print("Loading release information from 'releases.json'")
with open(json_file) as f: with open(json_file, encoding='utf-8') as f:
return json.loads(f.read()) return json.loads(f.read())
# Download release information via the GitHub API # Download release information via the GitHub API
@ -127,7 +127,7 @@ def get_release_data():
page += 1 page += 1
# Cache these results to file # Cache these results to file
with open(json_file, 'w') as f: with open(json_file, 'w', encoding='utf-8') as f:
print("Saving release information to 'releases.json'") print("Saving release information to 'releases.json'")
f.write(json.dumps(releases)) f.write(json.dumps(releases))
@ -173,7 +173,7 @@ def on_config(config, *args, **kwargs):
# Add *all* readthedocs related keys # Add *all* readthedocs related keys
readthedocs = {} readthedocs = {}
for key in os.environ.keys(): for key in os.environ:
if key.startswith('READTHEDOCS_'): if key.startswith('READTHEDOCS_'):
k = key.replace('READTHEDOCS_', '').lower() k = key.replace('READTHEDOCS_', '').lower()
readthedocs[k] = os.environ[key] readthedocs[k] = os.environ[key]

View File

@ -46,7 +46,7 @@ def top_level_path(path: str) -> str:
key = path.split('/')[1] key = path.split('/')[1]
if key in SPECIAL_PATHS.keys(): if key in SPECIAL_PATHS:
return key return key
return GENERAL_PATH return GENERAL_PATH
@ -54,9 +54,7 @@ def top_level_path(path: str) -> str:
def generate_schema_file(key: str) -> None: def generate_schema_file(key: str) -> None:
"""Generate a schema file for the provided key.""" """Generate a schema file for the provided key."""
description = ( description = SPECIAL_PATHS.get(key, 'General API Endpoints')
SPECIAL_PATHS[key] if key in SPECIAL_PATHS else 'General API Endpoints'
)
output = f""" output = f"""
--- ---
@ -75,7 +73,7 @@ def generate_schema_file(key: str) -> None:
print('Writing schema file to:', output_file) print('Writing schema file to:', output_file)
with open(output_file, 'w') as f: with open(output_file, 'w', encoding='utf-8') as f:
f.write(output) f.write(output)
@ -121,7 +119,7 @@ def generate_index_file(version: str):
print('Writing index file to:', output_file) print('Writing index file to:', output_file)
with open(output_file, 'w') as f: with open(output_file, 'w', encoding='utf-8') as f:
f.write(output) f.write(output)
@ -173,7 +171,7 @@ def parse_api_file(filename: str):
The intent is to make the API schema easier to peruse on the documentation. The intent is to make the API schema easier to peruse on the documentation.
""" """
with open(filename, 'r') as f: with open(filename, encoding='utf-8') as f:
data = yaml.safe_load(f) data = yaml.safe_load(f)
paths = data['paths'] paths = data['paths']
@ -213,7 +211,7 @@ def parse_api_file(filename: str):
output_file = os.path.abspath(output_file) output_file = os.path.abspath(output_file)
with open(output_file, 'w') as f: with open(output_file, 'w', encoding='utf-8') as f:
yaml.dump(output, f) yaml.dump(output, f)
# Generate a markdown file for the schema # Generate a markdown file for the schema

View File

@ -16,7 +16,7 @@ global USER_SETTINGS
here = os.path.dirname(__file__) here = os.path.dirname(__file__)
settings_file = os.path.join(here, 'inventree_settings.json') settings_file = os.path.join(here, 'inventree_settings.json')
with open(settings_file, 'r') as sf: with open(settings_file, encoding='utf-8') as sf:
settings = json.load(sf) settings = json.load(sf)
GLOBAL_SETTINGS = settings['global'] GLOBAL_SETTINGS = settings['global']
@ -27,7 +27,7 @@ def get_repo_url(raw=False):
"""Return the repository URL for the current project.""" """Return the repository URL for the current project."""
mkdocs_yml = os.path.join(os.path.dirname(__file__), 'mkdocs.yml') mkdocs_yml = os.path.join(os.path.dirname(__file__), 'mkdocs.yml')
with open(mkdocs_yml, 'r') as f: with open(mkdocs_yml, encoding='utf-8') as f:
mkdocs_config = yaml.safe_load(f) mkdocs_config = yaml.safe_load(f)
repo_name = mkdocs_config['repo_name'] repo_name = mkdocs_config['repo_name']
@ -47,7 +47,7 @@ def check_link(url) -> bool:
# Keep a local cache file of URLs we have already checked # Keep a local cache file of URLs we have already checked
if os.path.exists(CACHE_FILE): if os.path.exists(CACHE_FILE):
with open(CACHE_FILE, 'r') as f: with open(CACHE_FILE, encoding='utf-8') as f:
cache = f.read().splitlines() cache = f.read().splitlines()
if url in cache: if url in cache:
@ -59,7 +59,7 @@ def check_link(url) -> bool:
response = requests.head(url, timeout=5000) response = requests.head(url, timeout=5000)
if response.status_code == 200: if response.status_code == 200:
# Update the cache file # Update the cache file
with open(CACHE_FILE, 'a') as f: with open(CACHE_FILE, 'a', encoding='utf-8') as f:
f.write(f'{url}\n') f.write(f'{url}\n')
return True return True
@ -177,7 +177,7 @@ def define_env(env):
assert subprocess.call(command, shell=True) == 0 assert subprocess.call(command, shell=True) == 0
with open(output, 'r') as f: with open(output, encoding='utf-8') as f:
content = f.read() content = f.read()
return content return content
@ -200,12 +200,13 @@ def define_env(env):
return assets return assets
@env.macro @env.macro
def includefile(filename: str, title: str, format: str = ''): def includefile(filename: str, title: str, fmt: str = ''):
"""Include a file in the documentation, in a 'collapse' block. """Include a file in the documentation, in a 'collapse' block.
Arguments: Arguments:
- filename: The name of the file to include (relative to the top-level directory) - filename: The name of the file to include (relative to the top-level directory)
- title: - title:
- fmt:
""" """
here = os.path.dirname(__file__) here = os.path.dirname(__file__)
path = os.path.join(here, '..', filename) path = os.path.join(here, '..', filename)
@ -214,11 +215,11 @@ def define_env(env):
if not os.path.exists(path): if not os.path.exists(path):
raise FileNotFoundError(f'Required file {path} does not exist.') raise FileNotFoundError(f'Required file {path} does not exist.')
with open(path, 'r') as f: with open(path, encoding='utf-8') as f:
content = f.read() content = f.read()
data = f'??? abstract "{title}"\n\n' data = f'??? abstract "{title}"\n\n'
data += f' ```{format}\n' data += f' ```{fmt}\n'
data += textwrap.indent(content, ' ') data += textwrap.indent(content, ' ')
data += '\n\n' data += '\n\n'
data += ' ```\n\n' data += ' ```\n\n'
@ -233,15 +234,15 @@ def define_env(env):
'src', 'backend', 'InvenTree', 'report', 'templates', filename 'src', 'backend', 'InvenTree', 'report', 'templates', filename
) )
return includefile(fn, f'Template: {base}', format='html') return includefile(fn, f'Template: {base}', fmt='html')
@env.macro @env.macro
def rendersetting(setting: dict): def rendersetting(setting: dict):
"""Render a provided setting object into a table row.""" """Render a provided setting object into a table row."""
name = setting['name'] name = setting['name']
description = setting['description'] description = setting['description']
default = setting.get('default', None) default = setting.get('default')
units = setting.get('units', None) units = setting.get('units')
return f'| {name} | {description} | {default if default is not None else ""} | {units if units is not None else ""} |' return f'| {name} | {description} | {default if default is not None else ""} | {units if units is not None else ""} |'

View File

@ -20,13 +20,30 @@ src = ["src/backend/InvenTree"]
"__init__.py" = ["D104"] "__init__.py" = ["D104"]
[tool.ruff.lint] [tool.ruff.lint]
select = ["A", "B", "C4", "D", "I", "N", "F"] select = ["A", "B", "C", "C4", "D", "F", "I", "N", "SIM", "PIE", "PLE", "PLW", "RUF", "UP", "W"]
# Things that should be enabled in the future: # Things that should be enabled in the future:
# - LOG # - LOG
# - DJ # for Django stuff # - DJ # for Django stuff
# - S # for security stuff (bandit) # - S # for security stuff (bandit)
ignore = [ ignore = [
"PLE1205",
# - PLE1205 - Too many arguments for logging format string
"PLW2901",
# - PLW2901 - Outer {outer_kind} variable {name} overwritten by inner {inner_kind} target
"PLW0602","PLW0603","PLW0604", # global variable things
"RUF015",
# - RUF015 - Prefer next({iterable}) over single element slice
"RUF012",
# - RUF012 - Mutable class attributes should be annotated with typing.ClassVar
"SIM117",
# - SIM117 - Use a single with statement with multiple contexts instead of nested with statements
"SIM102",
# - SIM102 - Use a single if statement instead of nested if statements
"SIM105",
# - SIM105 - Use contextlib.suppress({exception}) instead of try-except-pass
"C901",
# - C901 - function is too complex
"N999", "N999",
# - N802 - function name should be lowercase # - N802 - function name should be lowercase
"N802", "N802",
@ -42,7 +59,7 @@ ignore = [
"B904", "B904",
# Remove fast # Remove fast
"A001", "A002","A003","B018" "A002", "B018"
] ]
[tool.ruff.lint.pydocstyle] [tool.ruff.lint.pydocstyle]

View File

@ -104,14 +104,16 @@ class InvenTreeResource(ModelResource):
attribute = getattr(field, 'attribute', field_name) attribute = getattr(field, 'attribute', field_name)
# Check if the associated database field is a non-nullable string # Check if the associated database field is a non-nullable string
if db_field := db_fields.get(attribute): if (
if ( (db_field := db_fields.get(attribute))
and (
isinstance(db_field, CharField) isinstance(db_field, CharField)
and db_field.blank and db_field.blank
and not db_field.null and not db_field.null
): )
if column not in self.CONVERT_NULL_FIELDS: and column not in self.CONVERT_NULL_FIELDS
self.CONVERT_NULL_FIELDS.append(column) ):
self.CONVERT_NULL_FIELDS.append(column)
return super().before_import(dataset, using_transactions, dry_run, **kwargs) return super().before_import(dataset, using_transactions, dry_run, **kwargs)

View File

@ -26,7 +26,7 @@ from part.models import Part
from plugin.serializers import MetadataSerializer from plugin.serializers import MetadataSerializer
from users.models import ApiToken from users.models import ApiToken
from .email import is_email_configured from .helpers_email import is_email_configured
from .mixins import ListAPI, RetrieveUpdateAPI from .mixins import ListAPI, RetrieveUpdateAPI
from .status import check_system_health, is_worker_running from .status import check_system_health, is_worker_running
from .version import inventreeApiText from .version import inventreeApiText
@ -77,7 +77,7 @@ class LicenseView(APIView):
# Ensure we do not have any duplicate 'name' values in the list # Ensure we do not have any duplicate 'name' values in the list
for entry in data: for entry in data:
name = None name = None
for key in entry.keys(): for key in entry:
if key.lower() == 'name': if key.lower() == 'name':
name = entry[key] name = entry[key]
break break
@ -321,7 +321,6 @@ class BulkDeleteMixin:
Raises: Raises:
ValidationError: If the deletion should not proceed ValidationError: If the deletion should not proceed
""" """
pass
def filter_delete_queryset(self, queryset, request): def filter_delete_queryset(self, queryset, request):
"""Provide custom filtering for the queryset *before* it is deleted. """Provide custom filtering for the queryset *before* it is deleted.
@ -398,8 +397,6 @@ class BulkDeleteMixin:
class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI): class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI):
"""Custom API endpoint which provides BulkDelete functionality in addition to List and Create.""" """Custom API endpoint which provides BulkDelete functionality in addition to List and Create."""
...
class APISearchViewSerializer(serializers.Serializer): class APISearchViewSerializer(serializers.Serializer):
"""Serializer for the APISearchView.""" """Serializer for the APISearchView."""

View File

@ -125,7 +125,7 @@ class InvenTreeConfig(AppConfig):
for task in tasks: for task in tasks:
ref_name = f'{task.func.__module__}.{task.func.__name__}' ref_name = f'{task.func.__module__}.{task.func.__name__}'
if ref_name in existing_tasks.keys(): if ref_name in existing_tasks:
# This task already exists - update the details if required # This task already exists - update the details if required
existing_task = existing_tasks[ref_name] existing_task = existing_tasks[ref_name]

View File

@ -131,7 +131,7 @@ def load_config_data(set_cache: bool = False) -> map:
cfg_file = get_config_file() cfg_file = get_config_file()
with open(cfg_file, 'r') as cfg: with open(cfg_file, encoding='utf-8') as cfg:
data = yaml.safe_load(cfg) data = yaml.safe_load(cfg)
# Set the cache if requested # Set the cache if requested

View File

@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
"""Provides extra global data to all templates.""" """Provides extra global data to all templates."""
import InvenTree.email import InvenTree.helpers_email
import InvenTree.ready import InvenTree.ready
import InvenTree.status import InvenTree.status
from generic.states.custom import get_custom_classes from generic.states.custom import get_custom_classes
@ -27,7 +25,7 @@ def health_status(request):
status = { status = {
'django_q_running': InvenTree.status.is_worker_running(), 'django_q_running': InvenTree.status.is_worker_running(),
'email_configured': InvenTree.email.is_email_configured(), 'email_configured': InvenTree.helpers_email.is_email_configured(),
} }
# The following keys are required to denote system health # The following keys are required to denote system health
@ -75,7 +73,7 @@ def user_roles(request):
roles = {} roles = {}
for role in RuleSet.get_ruleset_models().keys(): for role in RuleSet.get_ruleset_models():
permissions = {} permissions = {}
for perm in ['view', 'add', 'change', 'delete']: for perm in ['view', 'add', 'change', 'delete']:

View File

@ -2,6 +2,7 @@
import logging import logging
import re import re
from typing import Optional
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -95,7 +96,7 @@ def from_engineering_notation(value):
""" """
value = str(value).strip() value = str(value).strip()
pattern = '(\d+)([a-zA-Z]+)(\d+)(.*)' pattern = r'(\d+)([a-zA-Z]+)(\d+)(.*)'
if match := re.match(pattern, value): if match := re.match(pattern, value):
left, prefix, right, suffix = match.groups() left, prefix, right, suffix = match.groups()
@ -133,7 +134,7 @@ def convert_value(value, unit):
return value return value
def convert_physical_value(value: str, unit: str = None, strip_units=True): def convert_physical_value(value: str, unit: Optional[str] = None, strip_units=True):
"""Validate that the provided value is a valid physical quantity. """Validate that the provided value is a valid physical quantity.
Arguments: Arguments:
@ -245,8 +246,4 @@ def is_dimensionless(value):
if value.units == ureg.dimensionless: if value.units == ureg.dimensionless:
return True return True
if value.to_base_units().units == ureg.dimensionless: return value.to_base_units().units == ureg.dimensionless
return True
# At this point, the value is not dimensionless
return False

View File

@ -1,7 +1,6 @@
"""Custom exception handling for the DRF API.""" """Custom exception handling for the DRF API."""
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import unicode_literals
import logging import logging
import sys import sys
@ -41,10 +40,7 @@ def log_error(path, error_name=None, error_info=None, error_data=None):
if kind in settings.IGNORED_ERRORS: if kind in settings.IGNORED_ERRORS:
return return
if error_name: kind = error_name or getattr(kind, '__name__', 'Unknown Error')
kind = error_name
else:
kind = getattr(kind, '__name__', 'Unknown Error')
if error_info: if error_info:
info = error_info info = error_info

View File

@ -31,10 +31,7 @@ class InvenTreeExchange(SimpleExchangeBackend):
# Find the selected exchange rate plugin # Find the selected exchange rate plugin
slug = get_global_setting('CURRENCY_UPDATE_PLUGIN', create=False) slug = get_global_setting('CURRENCY_UPDATE_PLUGIN', create=False)
if slug: plugin = registry.get_plugin(slug) if slug else None
plugin = registry.get_plugin(slug)
else:
plugin = None
if not plugin: if not plugin:
# Find the first active currency exchange plugin # Find the first active currency exchange plugin

View File

@ -93,9 +93,8 @@ class InvenTreeModelMoneyField(ModelMoneyField):
allow_negative = kwargs.pop('allow_negative', False) allow_negative = kwargs.pop('allow_negative', False)
# If no validators are provided, add some "standard" ones # If no validators are provided, add some "standard" ones
if len(validators) == 0: if len(validators) == 0 and not allow_negative:
if not allow_negative: validators.append(MinMoneyValidator(0))
validators.append(MinMoneyValidator(0))
kwargs['validators'] = validators kwargs['validators'] = validators

View File

@ -118,10 +118,7 @@ class InvenTreeOrderingFilter(filters.OrderingFilter):
field = field[1:] field = field[1:]
# Are aliases defined for this field? # Are aliases defined for this field?
if field in aliases: alias = aliases.get(field, field)
alias = aliases[field]
else:
alias = field
""" """
Potentially, a single field could be "aliased" to multiple field, Potentially, a single field could be "aliased" to multiple field,

View File

@ -2,6 +2,7 @@
import re import re
import string import string
from typing import Optional
from django.conf import settings from django.conf import settings
from django.utils import translation from django.utils import translation
@ -106,10 +107,7 @@ def construct_format_regex(fmt_string: str) -> str:
# Add a named capture group for the format entry # Add a named capture group for the format entry
if name: if name:
# Check if integer values are required # Check if integer values are required
if _fmt.endswith('d'): c = '\\d' if _fmt.endswith('d') else '.'
c = '\d'
else:
c = '.'
# Specify width # Specify width
# TODO: Introspect required width # TODO: Introspect required width
@ -160,7 +158,7 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str:
""" """
info = parse_format_string(fmt_string) info = parse_format_string(fmt_string)
if name not in info.keys(): if name not in info:
raise NameError(_(f"Value '{name}' does not appear in pattern format")) raise NameError(_(f"Value '{name}' does not appear in pattern format"))
# Construct a regular expression for matching against the provided format string # Construct a regular expression for matching against the provided format string
@ -182,8 +180,8 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str:
def format_money( def format_money(
money: Money, money: Money,
decimal_places: int = None, decimal_places: Optional[int] = None,
format: str = None, fmt: Optional[str] = None,
include_symbol: bool = True, include_symbol: bool = True,
) -> str: ) -> str:
"""Format money object according to the currently set local. """Format money object according to the currently set local.
@ -191,7 +189,7 @@ def format_money(
Args: Args:
money (Money): The money object to format money (Money): The money object to format
decimal_places (int): Number of decimal places to use decimal_places (int): Number of decimal places to use
format (str): Format pattern according LDML / the babel format pattern syntax (https://babel.pocoo.org/en/latest/numbers.html) fmt (str): Format pattern according LDML / the babel format pattern syntax (https://babel.pocoo.org/en/latest/numbers.html)
Returns: Returns:
str: The formatted string str: The formatted string
@ -199,10 +197,10 @@ def format_money(
Raises: Raises:
ValueError: format string is incorrectly specified ValueError: format string is incorrectly specified
""" """
language = None and translation.get_language() or settings.LANGUAGE_CODE language = (None) or settings.LANGUAGE_CODE
locale = Locale.parse(translation.to_locale(language)) locale = Locale.parse(translation.to_locale(language))
if format: if fmt:
pattern = parse_pattern(format) pattern = parse_pattern(fmt)
else: else:
pattern = locale.currency_formats['standard'] pattern = locale.currency_formats['standard']
if decimal_places is not None: if decimal_places is not None:

View File

@ -266,9 +266,8 @@ class RegistratonMixin:
raise forms.ValidationError( raise forms.ValidationError(
_('The provided primary email address is not valid.') _('The provided primary email address is not valid.')
) )
else: elif split_email[1] == option[1:]:
if split_email[1] == option[1:]: return super().clean_email(email)
return super().clean_email(email)
logger.info('The provided email domain for %s is not approved', email) logger.info('The provided email domain for %s is not approved', email)
raise forms.ValidationError(_('The provided email domain is not approved.')) raise forms.ValidationError(_('The provided email domain is not approved.'))

View File

@ -9,15 +9,15 @@ import os.path
import re import re
from decimal import Decimal, InvalidOperation from decimal import Decimal, InvalidOperation
from pathlib import Path from pathlib import Path
from typing import TypeVar, Union from typing import Optional, TypeVar, Union
from wsgiref.util import FileWrapper from wsgiref.util import FileWrapper
import django.utils.timezone as timezone
from django.conf import settings from django.conf import settings
from django.contrib.staticfiles.storage import StaticFilesStorage from django.contrib.staticfiles.storage import StaticFilesStorage
from django.core.exceptions import FieldError, ValidationError from django.core.exceptions import FieldError, ValidationError
from django.core.files.storage import Storage, default_storage from django.core.files.storage import Storage, default_storage
from django.http import StreamingHttpResponse from django.http import StreamingHttpResponse
from django.utils import timezone
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
import pytz import pytz
@ -97,10 +97,7 @@ def generateTestKey(test_name: str) -> str:
if char.isidentifier(): if char.isidentifier():
return True return True
if char.isalnum(): return bool(char.isalnum())
return True
return False
# Remove any characters that cannot be used to represent a variable # Remove any characters that cannot be used to represent a variable
key = ''.join([c for c in key if valid_char(c)]) key = ''.join([c for c in key if valid_char(c)])
@ -490,10 +487,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
except ValueError: except ValueError:
raise ValidationError([_('Invalid quantity provided')]) raise ValidationError([_('Invalid quantity provided')])
if input_string: input_string = str(input_string).strip() if input_string else ''
input_string = str(input_string).strip()
else:
input_string = ''
if len(input_string) == 0: if len(input_string) == 0:
raise ValidationError([_('Empty serial number string')]) raise ValidationError([_('Empty serial number string')])
@ -800,10 +794,10 @@ def remove_non_printable_characters(
if remove_unicode: if remove_unicode:
# Remove Unicode control characters # Remove Unicode control characters
if remove_newline: if remove_newline:
cleaned = regex.sub('[^\P{C}]+', '', cleaned) cleaned = regex.sub(r'[^\P{C}]+', '', cleaned)
else: else:
# Use 'negative-lookahead' to exclude newline character # Use 'negative-lookahead' to exclude newline character
cleaned = regex.sub('(?![\x0a])[^\P{C}]+', '', cleaned) cleaned = regex.sub('(?![\x0a])[^\\P{C}]+', '', cleaned)
return cleaned return cleaned
@ -827,7 +821,7 @@ def hash_barcode(barcode_data):
def hash_file(filename: Union[str, Path], storage: Union[Storage, None] = None): def hash_file(filename: Union[str, Path], storage: Union[Storage, None] = None):
"""Return the MD5 hash of a file.""" """Return the MD5 hash of a file."""
content = ( content = (
open(filename, 'rb').read() open(filename, 'rb').read() # noqa: SIM115
if storage is None if storage is None
else storage.open(str(filename), 'rb').read() else storage.open(str(filename), 'rb').read()
) )
@ -865,7 +859,7 @@ def server_timezone() -> str:
return settings.TIME_ZONE return settings.TIME_ZONE
def to_local_time(time, target_tz: str = None): def to_local_time(time, target_tz: Optional[str] = None):
"""Convert the provided time object to the local timezone. """Convert the provided time object to the local timezone.
Arguments: Arguments:

View File

@ -62,7 +62,7 @@ def send_email(subject, body, recipients, from_email=None, html_message=None):
# If we are importing data, don't send emails # If we are importing data, don't send emails
return return
if not InvenTree.email.is_email_configured() and not settings.TESTING: if not is_email_configured() and not settings.TESTING:
# Email is not configured / enabled # Email is not configured / enabled
return return

View File

@ -114,10 +114,7 @@ def download_image_from_url(remote_url, timeout=2.5):
# Add user specified user-agent to request (if specified) # Add user specified user-agent to request (if specified)
user_agent = get_global_setting('INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT') user_agent = get_global_setting('INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT')
if user_agent: headers = {'User-Agent': user_agent} if user_agent else None
headers = {'User-Agent': user_agent}
else:
headers = None
try: try:
response = requests.get( response = requests.get(
@ -130,7 +127,7 @@ def download_image_from_url(remote_url, timeout=2.5):
# Throw an error if anything goes wrong # Throw an error if anything goes wrong
response.raise_for_status() response.raise_for_status()
except requests.exceptions.ConnectionError as exc: except requests.exceptions.ConnectionError as exc:
raise Exception(_('Connection error') + f': {str(exc)}') raise Exception(_('Connection error') + f': {exc!s}')
except requests.exceptions.Timeout as exc: except requests.exceptions.Timeout as exc:
raise exc raise exc
except requests.exceptions.HTTPError: except requests.exceptions.HTTPError:
@ -138,7 +135,7 @@ def download_image_from_url(remote_url, timeout=2.5):
_('Server responded with invalid status code') + f': {response.status_code}' _('Server responded with invalid status code') + f': {response.status_code}'
) )
except Exception as exc: except Exception as exc:
raise Exception(_('Exception occurred') + f': {str(exc)}') raise Exception(_('Exception occurred') + f': {exc!s}')
if response.status_code != 200: if response.status_code != 200:
raise Exception( raise Exception(

View File

@ -47,7 +47,7 @@ class Command(BaseCommand):
filename = kwargs.get('filename', 'inventree_settings.json') filename = kwargs.get('filename', 'inventree_settings.json')
with open(filename, 'w') as f: with open(filename, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=4) json.dump(settings, f, indent=4)
print(f"Exported InvenTree settings definitions to '{filename}'") print(f"Exported InvenTree settings definitions to '{filename}'")

View File

@ -103,14 +103,14 @@ class Command(BaseCommand):
}) })
self.stdout.write(f'Writing icon map for {len(icons.keys())} icons') self.stdout.write(f'Writing icon map for {len(icons.keys())} icons')
with open(kwargs['output_file'], 'w') as f: with open(kwargs['output_file'], 'w', encoding='utf-8') as f:
json.dump(icons, f, indent=2) json.dump(icons, f, indent=2)
self.stdout.write(f'Icon map written to {kwargs["output_file"]}') self.stdout.write(f'Icon map written to {kwargs["output_file"]}')
# Import icon map file # Import icon map file
if kwargs['input_file']: if kwargs['input_file']:
with open(kwargs['input_file'], 'r') as f: with open(kwargs['input_file'], encoding='utf-8') as f:
icons = json.load(f) icons = json.load(f)
self.stdout.write(f'Loaded icon map for {len(icons.keys())} icons') self.stdout.write(f'Loaded icon map for {len(icons.keys())} icons')

View File

@ -19,10 +19,11 @@ def render_file(file_name, source, target, locales, ctx):
target_file = os.path.join(target, locale + '.' + file_name) target_file = os.path.join(target, locale + '.' + file_name)
with open(target_file, 'w') as localised_file: with open(target_file, 'w', encoding='utf-8') as localised_file, lang_over(
with lang_over(locale): locale
rendered = render_to_string(os.path.join(source, file_name), ctx) ):
localised_file.write(rendered) rendered = render_to_string(os.path.join(source, file_name), ctx)
localised_file.write(rendered)
class Command(BaseCommand): class Command(BaseCommand):

View File

@ -35,7 +35,7 @@ class Command(BaseCommand):
img_paths.append(x.path) img_paths.append(x.path)
if len(img_paths) > 0: if len(img_paths) > 0:
if all((os.path.exists(path) for path in img_paths)): if all(os.path.exists(path) for path in img_paths):
# All images exist - skip further work # All images exist - skip further work
return return

View File

@ -35,4 +35,4 @@ class Command(BaseCommand):
mfa_user[0].staticdevice_set.all().delete() mfa_user[0].staticdevice_set.all().delete()
# TOTP tokens # TOTP tokens
mfa_user[0].totpdevice_set.all().delete() mfa_user[0].totpdevice_set.all().delete()
print(f'Removed all MFA methods for user {str(mfa_user[0])}') print(f'Removed all MFA methods for user {mfa_user[0]!s}')

View File

@ -204,7 +204,7 @@ class InvenTreeMetadata(SimpleMetadata):
# Iterate through simple fields # Iterate through simple fields
for name, field in model_fields.fields.items(): for name, field in model_fields.fields.items():
if name in serializer_info.keys(): if name in serializer_info:
if name in read_only_fields: if name in read_only_fields:
serializer_info[name]['read_only'] = True serializer_info[name]['read_only'] = True
@ -236,7 +236,7 @@ class InvenTreeMetadata(SimpleMetadata):
# Iterate through relations # Iterate through relations
for name, relation in model_fields.relations.items(): for name, relation in model_fields.relations.items():
if name not in serializer_info.keys(): if name not in serializer_info:
# Skip relation not defined in serializer # Skip relation not defined in serializer
continue continue
@ -307,12 +307,12 @@ class InvenTreeMetadata(SimpleMetadata):
instance_filters = instance.api_instance_filters() instance_filters = instance.api_instance_filters()
for field_name, field_filters in instance_filters.items(): for field_name, field_filters in instance_filters.items():
if field_name not in serializer_info.keys(): if field_name not in serializer_info:
# The field might be missing, but is added later on # The field might be missing, but is added later on
# This function seems to get called multiple times? # This function seems to get called multiple times?
continue continue
if 'instance_filters' not in serializer_info[field_name].keys(): if 'instance_filters' not in serializer_info[field_name]:
serializer_info[field_name]['instance_filters'] = {} serializer_info[field_name]['instance_filters'] = {}
for key, value in field_filters.items(): for key, value in field_filters.items():

View File

@ -36,7 +36,7 @@ def get_token_from_request(request):
return None return None
class AuthRequiredMiddleware(object): class AuthRequiredMiddleware:
"""Check for user to be authenticated.""" """Check for user to be authenticated."""
def __init__(self, get_response): def __init__(self, get_response):
@ -92,23 +92,18 @@ class AuthRequiredMiddleware(object):
# Allow static files to be accessed without auth # Allow static files to be accessed without auth
# Important for e.g. login page # Important for e.g. login page
if request.path_info.startswith('/static/'): if (
authorized = True request.path_info.startswith('/static/')
or request.path_info.startswith('/accounts/')
# Unauthorized users can access the login page or (
elif request.path_info.startswith('/accounts/'): request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/')
authorized = True or request.path_info.startswith('/assets/')
or request.path_info == f'/{settings.FRONTEND_URL_BASE}'
elif ( )
request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/') or self.check_token(request)
or request.path_info.startswith('/assets/')
or request.path_info == f'/{settings.FRONTEND_URL_BASE}'
): ):
authorized = True authorized = True
elif self.check_token(request):
authorized = True
# No authorization was found for the request # No authorization was found for the request
if not authorized: if not authorized:
path = request.path_info path = request.path_info

View File

@ -128,14 +128,10 @@ class CreateAPI(CleanMixin, generics.CreateAPIView):
class RetrieveAPI(generics.RetrieveAPIView): class RetrieveAPI(generics.RetrieveAPIView):
"""View for retrieve API.""" """View for retrieve API."""
pass
class RetrieveUpdateAPI(CleanMixin, generics.RetrieveUpdateAPIView): class RetrieveUpdateAPI(CleanMixin, generics.RetrieveUpdateAPIView):
"""View for retrieve and update API.""" """View for retrieve and update API."""
pass
class CustomDestroyModelMixin: class CustomDestroyModelMixin:
"""This mixin was created pass the kwargs from the API to the models.""" """This mixin was created pass the kwargs from the API to the models."""

View File

@ -390,10 +390,7 @@ class ReferenceIndexingMixin(models.Model):
except Exception: except Exception:
# If anything goes wrong, return the most recent reference # If anything goes wrong, return the most recent reference
recent = cls.get_most_recent_item() recent = cls.get_most_recent_item()
if recent: reference = recent.reference if recent else ''
reference = recent.reference
else:
reference = ''
return reference return reference
@ -410,14 +407,14 @@ class ReferenceIndexingMixin(models.Model):
}) })
# Check that only 'allowed' keys are provided # Check that only 'allowed' keys are provided
for key in info.keys(): for key in info:
if key not in ctx.keys(): if key not in ctx:
raise ValidationError({ raise ValidationError({
'value': _('Unknown format key specified') + f": '{key}'" 'value': _('Unknown format key specified') + f": '{key}'"
}) })
# Check that the 'ref' variable is specified # Check that the 'ref' variable is specified
if 'ref' not in info.keys(): if 'ref' not in info:
raise ValidationError({ raise ValidationError({
'value': _('Missing required format key') + ": 'ref'" 'value': _('Missing required format key') + ": 'ref'"
}) })
@ -859,7 +856,7 @@ class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel):
Returns: Returns:
List of category names from the top level to this category List of category names from the top level to this category
""" """
return self.parentpath + [self] return [*self.parentpath, self]
def get_path(self): def get_path(self):
"""Return a list of element in the item tree. """Return a list of element in the item tree.
@ -1099,4 +1096,4 @@ def after_error_logged(sender, instance: Error, created: bool, **kwargs):
except Exception as exc: except Exception as exc:
"""We do not want to throw an exception while reporting an exception""" """We do not want to throw an exception while reporting an exception"""
logger.error(exc) # noqa: LOG005 logger.error(exc)

View File

@ -105,8 +105,7 @@ class IsStaffOrReadOnly(permissions.IsAdminUser):
def has_permission(self, request, view): def has_permission(self, request, view):
"""Check if the user is a superuser.""" """Check if the user is a superuser."""
return bool( return bool(
request.user (request.user and request.user.is_staff)
and request.user.is_staff
or request.method in permissions.SAFE_METHODS or request.method in permissions.SAFE_METHODS
) )

View File

@ -11,43 +11,37 @@ def isInTestMode():
def isImportingData(): def isImportingData():
"""Returns True if the database is currently importing (or exporting) data, e.g. 'loaddata' command is performed.""" """Returns True if the database is currently importing (or exporting) data, e.g. 'loaddata' command is performed."""
return any((x in sys.argv for x in ['flush', 'loaddata', 'dumpdata'])) return any(x in sys.argv for x in ['flush', 'loaddata', 'dumpdata'])
def isRunningMigrations(): def isRunningMigrations():
"""Return True if the database is currently running migrations.""" """Return True if the database is currently running migrations."""
return any( return any(
( x in sys.argv
x in sys.argv for x in ['migrate', 'makemigrations', 'showmigrations', 'runmigrations']
for x in ['migrate', 'makemigrations', 'showmigrations', 'runmigrations']
)
) )
def isRebuildingData(): def isRebuildingData():
"""Return true if any of the rebuilding commands are being executed.""" """Return true if any of the rebuilding commands are being executed."""
return any( return any(
( x in sys.argv
x in sys.argv for x in ['prerender', 'rebuild_models', 'rebuild_thumbnails', 'rebuild']
for x in ['prerender', 'rebuild_models', 'rebuild_thumbnails', 'rebuild']
)
) )
def isRunningBackup(): def isRunningBackup():
"""Return true if any of the backup commands are being executed.""" """Return true if any of the backup commands are being executed."""
return any( return any(
( x in sys.argv
x in sys.argv for x in [
for x in [ 'backup',
'backup', 'restore',
'restore', 'dbbackup',
'dbbackup', 'dbresotore',
'dbresotore', 'mediabackup',
'mediabackup', 'mediarestore',
'mediarestore', ]
]
)
) )
@ -64,10 +58,7 @@ def isInServerThread():
if 'runserver' in sys.argv: if 'runserver' in sys.argv:
return True return True
if 'gunicorn' in sys.argv[0]: return 'gunicorn' in sys.argv[0]
return True
return False
def isInMainThread(): def isInMainThread():
@ -128,11 +119,7 @@ def canAppAccessDatabase(
if not allow_plugins: if not allow_plugins:
excluded_commands.extend(['collectplugins']) excluded_commands.extend(['collectplugins'])
for cmd in excluded_commands: return all(cmd not in sys.argv for cmd in excluded_commands)
if cmd in sys.argv:
return False
return True
def isPluginRegistryLoaded(): def isPluginRegistryLoaded():

View File

@ -89,7 +89,7 @@ class InvenTreeCurrencySerializer(serializers.ChoiceField):
) )
if allow_blank: if allow_blank:
choices = [('', '---------')] + choices choices = [('', '---------'), *choices]
kwargs['choices'] = choices kwargs['choices'] = choices
@ -379,7 +379,7 @@ class InvenTreeTaggitSerializer(TaggitSerializer):
tag_object = super().update(instance, validated_data) tag_object = super().update(instance, validated_data)
for key in to_be_tagged.keys(): for key in to_be_tagged:
# re-add the tagmanager # re-add the tagmanager
new_tagobject = tag_object.__class__.objects.get(id=tag_object.id) new_tagobject = tag_object.__class__.objects.get(id=tag_object.id)
setattr(tag_object, key, getattr(new_tagobject, key)) setattr(tag_object, key, getattr(new_tagobject, key))
@ -390,8 +390,6 @@ class InvenTreeTaggitSerializer(TaggitSerializer):
class InvenTreeTagModelSerializer(InvenTreeTaggitSerializer, InvenTreeModelSerializer): class InvenTreeTagModelSerializer(InvenTreeTaggitSerializer, InvenTreeModelSerializer):
"""Combination of InvenTreeTaggitSerializer and InvenTreeModelSerializer.""" """Combination of InvenTreeTaggitSerializer and InvenTreeModelSerializer."""
pass
class UserSerializer(InvenTreeModelSerializer): class UserSerializer(InvenTreeModelSerializer):
"""Serializer for a User.""" """Serializer for a User."""
@ -426,14 +424,15 @@ class ExendedUserSerializer(UserSerializer):
class Meta(UserSerializer.Meta): class Meta(UserSerializer.Meta):
"""Metaclass defines serializer fields.""" """Metaclass defines serializer fields."""
fields = UserSerializer.Meta.fields + [ fields = [
*UserSerializer.Meta.fields,
'groups', 'groups',
'is_staff', 'is_staff',
'is_superuser', 'is_superuser',
'is_active', 'is_active',
] ]
read_only_fields = UserSerializer.Meta.read_only_fields + ['groups'] read_only_fields = [*UserSerializer.Meta.read_only_fields, 'groups']
is_staff = serializers.BooleanField( is_staff = serializers.BooleanField(
label=_('Staff'), help_text=_('Does this user have staff permissions') label=_('Staff'), help_text=_('Does this user have staff permissions')
@ -704,7 +703,6 @@ class DataFileUploadSerializer(serializers.Serializer):
def save(self): def save(self):
"""Empty overwrite for save.""" """Empty overwrite for save."""
...
class DataFileExtractSerializer(serializers.Serializer): class DataFileExtractSerializer(serializers.Serializer):
@ -806,11 +804,10 @@ class DataFileExtractSerializer(serializers.Serializer):
required = field.get('required', False) required = field.get('required', False)
# Check for missing required columns # Check for missing required columns
if required: if required and name not in self.columns:
if name not in self.columns: raise serializers.ValidationError(
raise serializers.ValidationError( _(f"Missing required column: '{name}'")
_(f"Missing required column: '{name}'") )
)
for col in self.columns: for col in self.columns:
if not col: if not col:
@ -824,7 +821,6 @@ class DataFileExtractSerializer(serializers.Serializer):
def save(self): def save(self):
"""No "save" action for this serializer.""" """No "save" action for this serializer."""
pass
class NotesFieldMixin: class NotesFieldMixin:

View File

@ -281,7 +281,7 @@ QUERYCOUNT = {
'MIN_TIME_TO_LOG': 0.1, 'MIN_TIME_TO_LOG': 0.1,
'MIN_QUERY_COUNT_TO_LOG': 25, 'MIN_QUERY_COUNT_TO_LOG': 25,
}, },
'IGNORE_REQUEST_PATTERNS': ['^(?!\/(api)?(plugin)?\/).*'], 'IGNORE_REQUEST_PATTERNS': [r'^(?!\/(api)?(plugin)?\/).*'],
'IGNORE_SQL_PATTERNS': [], 'IGNORE_SQL_PATTERNS': [],
'DISPLAY_DUPLICATES': 1, 'DISPLAY_DUPLICATES': 1,
'RESPONSE_HEADER': 'X-Django-Query-Count', 'RESPONSE_HEADER': 'X-Django-Query-Count',
@ -298,7 +298,7 @@ if (
and INVENTREE_ADMIN_ENABLED and INVENTREE_ADMIN_ENABLED
and not TESTING and not TESTING
and get_boolean_setting('INVENTREE_DEBUG_SHELL', 'debug_shell', False) and get_boolean_setting('INVENTREE_DEBUG_SHELL', 'debug_shell', False)
): # noqa ):
try: try:
import django_admin_shell # noqa: F401 import django_admin_shell # noqa: F401
@ -949,10 +949,7 @@ USE_I18N = True
# Do not use native timezone support in "test" mode # Do not use native timezone support in "test" mode
# It generates a *lot* of cruft in the logs # It generates a *lot* of cruft in the logs
if not TESTING: USE_TZ = bool(not TESTING)
USE_TZ = True # pragma: no cover
else:
USE_TZ = False
DATE_INPUT_FORMATS = ['%Y-%m-%d'] DATE_INPUT_FORMATS = ['%Y-%m-%d']
@ -1065,7 +1062,7 @@ COOKIE_MODE = (
valid_cookie_modes = {'lax': 'Lax', 'strict': 'Strict', 'none': None, 'null': None} valid_cookie_modes = {'lax': 'Lax', 'strict': 'Strict', 'none': None, 'null': None}
if COOKIE_MODE not in valid_cookie_modes.keys(): if COOKIE_MODE not in valid_cookie_modes:
logger.error('Invalid cookie samesite mode: %s', COOKIE_MODE) logger.error('Invalid cookie samesite mode: %s', COOKIE_MODE)
sys.exit(-1) sys.exit(-1)

View File

@ -94,20 +94,19 @@ for name, provider in providers.registry.provider_map.items():
urls = [] urls = []
if len(adapters) == 1: if len(adapters) == 1:
urls = handle_oauth2(adapter=adapters[0]) urls = handle_oauth2(adapter=adapters[0])
elif provider.id in legacy:
logger.warning(
'`%s` is not supported on platform UI. Use `%s` instead.',
provider.id,
legacy[provider.id],
)
continue
else: else:
if provider.id in legacy: logger.error(
logger.warning( 'Found handler that is not yet ready for platform UI: `%s`. Open an feature request on GitHub if you need it implemented.',
'`%s` is not supported on platform UI. Use `%s` instead.', provider.id,
provider.id, )
legacy[provider.id], continue
)
continue
else:
logger.error(
'Found handler that is not yet ready for platform UI: `%s`. Open an feature request on GitHub if you need it implemented.',
provider.id,
)
continue
provider_urlpatterns += [path(f'{provider.id}/', include(urls))] provider_urlpatterns += [path(f'{provider.id}/', include(urls))]

View File

@ -1,5 +1,4 @@
"""Provides system status functionality checks.""" """Provides system status functionality checks."""
# -*- coding: utf-8 -*-
import logging import logging
from datetime import timedelta from datetime import timedelta
@ -10,7 +9,7 @@ from django.utils.translation import gettext_lazy as _
from django_q.models import Success from django_q.models import Success
from django_q.status import Stat from django_q.status import Stat
import InvenTree.email import InvenTree.helpers_email
import InvenTree.ready import InvenTree.ready
logger = logging.getLogger('inventree') logger = logging.getLogger('inventree')
@ -65,7 +64,7 @@ def check_system_health(**kwargs):
result = False result = False
logger.warning(_('Background worker check failed')) logger.warning(_('Background worker check failed'))
if not InvenTree.email.is_email_configured(): # pragma: no cover if not InvenTree.helpers_email.is_email_configured(): # pragma: no cover
result = False result = False
logger.warning(_('Email backend not configured')) logger.warning(_('Email backend not configured'))

View File

@ -9,7 +9,7 @@ import time
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Callable from typing import Callable, Optional
from django.conf import settings from django.conf import settings
from django.core.exceptions import AppRegistryNotReady from django.core.exceptions import AppRegistryNotReady
@ -206,7 +206,7 @@ def offload_task(
raise_warning(f"WARNING: '{taskname}' not offloaded - Function not found") raise_warning(f"WARNING: '{taskname}' not offloaded - Function not found")
return False return False
except Exception as exc: except Exception as exc:
raise_warning(f"WARNING: '{taskname}' not offloaded due to {str(exc)}") raise_warning(f"WARNING: '{taskname}' not offloaded due to {exc!s}")
log_error('InvenTree.offload_task') log_error('InvenTree.offload_task')
return False return False
else: else:
@ -256,7 +256,7 @@ def offload_task(
_func(*args, **kwargs) _func(*args, **kwargs)
except Exception as exc: except Exception as exc:
log_error('InvenTree.offload_task') log_error('InvenTree.offload_task')
raise_warning(f"WARNING: '{taskname}' failed due to {str(exc)}") raise_warning(f"WARNING: '{taskname}' failed due to {exc!s}")
raise exc raise exc
# Finally, task either completed successfully or was offloaded # Finally, task either completed successfully or was offloaded
@ -291,7 +291,7 @@ class TaskRegister:
task_list: list[ScheduledTask] = [] task_list: list[ScheduledTask] = []
def register(self, task, schedule, minutes: int = None): def register(self, task, schedule, minutes: Optional[int] = None):
"""Register a task with the que.""" """Register a task with the que."""
self.task_list.append(ScheduledTask(task, schedule, minutes)) self.task_list.append(ScheduledTask(task, schedule, minutes))
@ -299,7 +299,9 @@ class TaskRegister:
tasks = TaskRegister() tasks = TaskRegister()
def scheduled_task(interval: str, minutes: int = None, tasklist: TaskRegister = None): def scheduled_task(
interval: str, minutes: Optional[int] = None, tasklist: TaskRegister = None
):
"""Register the given task as a scheduled task. """Register the given task as a scheduled task.
Example: Example:

View File

@ -80,7 +80,7 @@ def do_translate(parser, token):
""" """
bits = token.split_contents() bits = token.split_contents()
if len(bits) < 2: if len(bits) < 2:
raise TemplateSyntaxError("'%s' takes at least one argument" % bits[0]) raise TemplateSyntaxError(f"'{bits[0]}' takes at least one argument")
message_string = parser.compile_filter(bits[1]) message_string = parser.compile_filter(bits[1])
remaining = bits[2:] remaining = bits[2:]
@ -95,7 +95,7 @@ def do_translate(parser, token):
option = remaining.pop(0) option = remaining.pop(0)
if option in seen: if option in seen:
raise TemplateSyntaxError( raise TemplateSyntaxError(
"The '%s' option was specified more than once." % option f"The '{option}' option was specified more than once."
) )
elif option == 'noop': elif option == 'noop':
noop = True noop = True
@ -104,13 +104,12 @@ def do_translate(parser, token):
value = remaining.pop(0) value = remaining.pop(0)
except IndexError: except IndexError:
raise TemplateSyntaxError( raise TemplateSyntaxError(
"No argument provided to the '%s' tag for the context option." f"No argument provided to the '{bits[0]}' tag for the context option."
% bits[0]
) )
if value in invalid_context: if value in invalid_context:
raise TemplateSyntaxError( raise TemplateSyntaxError(
"Invalid argument '%s' provided to the '%s' tag for the context " f"Invalid argument '{value}' provided to the '{bits[0]}' tag for the context "
'option' % (value, bits[0]) 'option'
) )
message_context = parser.compile_filter(value) message_context = parser.compile_filter(value)
elif option == 'as': elif option == 'as':
@ -118,16 +117,15 @@ def do_translate(parser, token):
value = remaining.pop(0) value = remaining.pop(0)
except IndexError: except IndexError:
raise TemplateSyntaxError( raise TemplateSyntaxError(
"No argument provided to the '%s' tag for the as option." % bits[0] f"No argument provided to the '{bits[0]}' tag for the as option."
) )
asvar = value asvar = value
elif option == 'escape': elif option == 'escape':
escape = True escape = True
else: else:
raise TemplateSyntaxError( raise TemplateSyntaxError(
"Unknown argument for '%s' tag: '%s'. The only options " f"Unknown argument for '{bits[0]}' tag: '{option}'. The only options "
"available are 'noop', 'context' \"xxx\", and 'as VAR'." "available are 'noop', 'context' \"xxx\", and 'as VAR'."
% (bits[0], option)
) )
seen.add(option) seen.add(option)

View File

@ -410,10 +410,7 @@ def progress_bar(val, max_val, *args, **kwargs):
else: else:
style = '' style = ''
if max_val != 0: percent = float(val / max_val) * 100 if max_val != 0 else 0
percent = float(val / max_val) * 100
else:
percent = 0
if percent > 100: if percent > 100:
percent = 100 percent = 100
@ -498,7 +495,7 @@ def primitive_to_javascript(primitive):
elif type(primitive) in [int, float]: elif type(primitive) in [int, float]:
return primitive return primitive
# Wrap with quotes # Wrap with quotes
return format_html("'{}'", primitive) # noqa: P103 return format_html("'{}'", primitive)
@register.simple_tag() @register.simple_tag()

View File

@ -70,11 +70,11 @@ class APITests(InvenTreeAPITestCase):
"""Helper function to use basic auth.""" """Helper function to use basic auth."""
# Use basic authentication # Use basic authentication
authstring = bytes('{u}:{p}'.format(u=self.username, p=self.password), 'ascii') authstring = bytes(f'{self.username}:{self.password}', 'ascii')
# Use "basic" auth by default # Use "basic" auth by default
auth = b64encode(authstring).decode('ascii') auth = b64encode(authstring).decode('ascii')
self.client.credentials(HTTP_AUTHORIZATION='Basic {auth}'.format(auth=auth)) self.client.credentials(HTTP_AUTHORIZATION=f'Basic {auth}')
def tokenAuth(self): def tokenAuth(self):
"""Helper function to use token auth.""" """Helper function to use token auth."""

View File

@ -70,7 +70,7 @@ class URLTest(TestCase):
pattern = '{% url [\'"]([^\'"]+)[\'"]([^%]*)%}' pattern = '{% url [\'"]([^\'"]+)[\'"]([^%]*)%}'
with open(input_file, 'r') as f: with open(input_file, encoding='utf-8') as f:
data = f.read() data = f.read()
results = re.findall(pattern, data) results = re.findall(pattern, data)

View File

@ -543,22 +543,22 @@ class FormatTest(TestCase):
def test_currency_formatting(self): def test_currency_formatting(self):
"""Test that currency formatting works correctly for multiple currencies.""" """Test that currency formatting works correctly for multiple currencies."""
test_data = ( test_data = (
(Money(3651.285718, 'USD'), 4, True, '$3,651.2857'), # noqa: E201,E202 (Money(3651.285718, 'USD'), 4, True, '$3,651.2857'),
(Money(487587.849178, 'CAD'), 5, True, 'CA$487,587.84918'), # noqa: E201,E202 (Money(487587.849178, 'CAD'), 5, True, 'CA$487,587.84918'),
(Money(0.348102, 'EUR'), 1, False, '0.3'), # noqa: E201,E202 (Money(0.348102, 'EUR'), 1, False, '0.3'),
(Money(0.916530, 'GBP'), 1, True, '£0.9'), # noqa: E201,E202 (Money(0.916530, 'GBP'), 1, True, '£0.9'),
(Money(61.031024, 'JPY'), 3, False, '61.031'), # noqa: E201,E202 (Money(61.031024, 'JPY'), 3, False, '61.031'),
(Money(49609.694602, 'JPY'), 1, True, '¥49,609.7'), # noqa: E201,E202 (Money(49609.694602, 'JPY'), 1, True, '¥49,609.7'),
(Money(155565.264777, 'AUD'), 2, False, '155,565.26'), # noqa: E201,E202 (Money(155565.264777, 'AUD'), 2, False, '155,565.26'),
(Money(0.820437, 'CNY'), 4, True, 'CN¥0.8204'), # noqa: E201,E202 (Money(0.820437, 'CNY'), 4, True, 'CN¥0.8204'),
(Money(7587.849178, 'EUR'), 0, True, '€7,588'), # noqa: E201,E202 (Money(7587.849178, 'EUR'), 0, True, '€7,588'),
(Money(0.348102, 'GBP'), 3, False, '0.348'), # noqa: E201,E202 (Money(0.348102, 'GBP'), 3, False, '0.348'),
(Money(0.652923, 'CHF'), 0, True, 'CHF1'), # noqa: E201,E202 (Money(0.652923, 'CHF'), 0, True, 'CHF1'),
(Money(0.820437, 'CNY'), 1, True, 'CN¥0.8'), # noqa: E201,E202 (Money(0.820437, 'CNY'), 1, True, 'CN¥0.8'),
(Money(98789.5295680, 'CHF'), 0, False, '98,790'), # noqa: E201,E202 (Money(98789.5295680, 'CHF'), 0, False, '98,790'),
(Money(0.585787, 'USD'), 1, True, '$0.6'), # noqa: E201,E202 (Money(0.585787, 'USD'), 1, True, '$0.6'),
(Money(0.690541, 'CAD'), 3, True, 'CA$0.691'), # noqa: E201,E202 (Money(0.690541, 'CAD'), 3, True, 'CA$0.691'),
(Money(427.814104, 'AUD'), 5, True, 'A$427.81410'), # noqa: E201,E202 (Money(427.814104, 'AUD'), 5, True, 'A$427.81410'),
) )
with self.settings(LANGUAGE_CODE='en-us'): with self.settings(LANGUAGE_CODE='en-us'):
@ -794,7 +794,7 @@ class TestDownloadFile(TestCase):
def test_download(self): def test_download(self):
"""Tests for DownloadFile.""" """Tests for DownloadFile."""
helpers.DownloadFile('hello world', 'out.txt') helpers.DownloadFile('hello world', 'out.txt')
helpers.DownloadFile(bytes(b'hello world'), 'out.bin') helpers.DownloadFile(b'hello world', 'out.bin')
class TestMPTT(TestCase): class TestMPTT(TestCase):
@ -1033,12 +1033,12 @@ class TestVersionNumber(TestCase):
# Check that the current .git values work too # Check that the current .git values work too
hash = str( git_hash = str(
subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8' subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8'
).strip() ).strip()
# On some systems the hash is a different length, so just check the first 6 characters # On some systems the hash is a different length, so just check the first 6 characters
self.assertEqual(hash[:6], version.inventreeCommitHash()[:6]) self.assertEqual(git_hash[:6], version.inventreeCommitHash()[:6])
d = ( d = (
str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8') str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8')
@ -1438,8 +1438,8 @@ class BarcodeMixinTest(InvenTreeTestCase):
'{"part": 17, "stockitem": 12}': 'c88c11ed0628eb7fef0d59b098b96975', '{"part": 17, "stockitem": 12}': 'c88c11ed0628eb7fef0d59b098b96975',
} }
for barcode, hash in hashing_tests.items(): for barcode, expected in hashing_tests.items():
self.assertEqual(InvenTree.helpers.hash_barcode(barcode), hash) self.assertEqual(InvenTree.helpers.hash_barcode(barcode), expected)
class SanitizerTest(TestCase): class SanitizerTest(TestCase):
@ -1557,15 +1557,12 @@ class ClassValidationMixinTest(TestCase):
def test(self): def test(self):
"""Test function.""" """Test function."""
...
def test1(self): def test1(self):
"""Test function.""" """Test function."""
...
def test2(self): def test2(self):
"""Test function.""" """Test function."""
...
required_attributes = ['NAME'] required_attributes = ['NAME']
required_overrides = [test, [test1, test2]] required_overrides = [test, [test1, test2]]
@ -1573,8 +1570,6 @@ class ClassValidationMixinTest(TestCase):
class InvalidClass: class InvalidClass:
"""An invalid class that does not inherit from ClassValidationMixin.""" """An invalid class that does not inherit from ClassValidationMixin."""
pass
def test_valid_class(self): def test_valid_class(self):
"""Test that a valid class passes the validation.""" """Test that a valid class passes the validation."""
@ -1585,11 +1580,9 @@ class ClassValidationMixinTest(TestCase):
def test(self): def test(self):
"""Test function.""" """Test function."""
...
def test2(self): def test2(self):
"""Test function.""" """Test function."""
...
TestClass.validate() TestClass.validate()
@ -1612,7 +1605,6 @@ class ClassValidationMixinTest(TestCase):
def test2(self): def test2(self):
"""Test function.""" """Test function."""
...
with self.assertRaisesRegex( with self.assertRaisesRegex(
NotImplementedError, NotImplementedError,
@ -1627,8 +1619,6 @@ class ClassProviderMixinTest(TestCase):
class TestClass(ClassProviderMixin): class TestClass(ClassProviderMixin):
"""This class is a dummy class to test the ClassProviderMixin.""" """This class is a dummy class to test the ClassProviderMixin."""
pass
def test_get_provider_file(self): def test_get_provider_file(self):
"""Test the get_provider_file function.""" """Test the get_provider_file function."""
self.assertEqual(self.TestClass.get_provider_file(), __file__) self.assertEqual(self.TestClass.get_provider_file(), __file__)

View File

@ -15,7 +15,7 @@ def reload_translation_stats():
STATS_FILE = settings.BASE_DIR.joinpath('InvenTree/locale_stats.json').absolute() STATS_FILE = settings.BASE_DIR.joinpath('InvenTree/locale_stats.json').absolute()
try: try:
with open(STATS_FILE, 'r') as f: with open(STATS_FILE, encoding='utf-8') as f:
_translation_stats = json.load(f) _translation_stats = json.load(f)
except Exception: except Exception:
_translation_stats = None _translation_stats = None

View File

@ -246,8 +246,6 @@ class ExchangeRateMixin:
class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase): class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase):
"""Testcase with user setup build in.""" """Testcase with user setup build in."""
pass
class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase): class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
"""Base class for running InvenTree API tests.""" """Base class for running InvenTree API tests."""
@ -283,9 +281,7 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
) # pragma: no cover ) # pragma: no cover
if verbose or n >= value: if verbose or n >= value:
msg = '\r\n%s' % json.dumps( msg = f'\r\n{json.dumps(context.captured_queries, indent=4)}' # pragma: no cover
context.captured_queries, indent=4
) # pragma: no cover
else: else:
msg = None msg = None

View File

@ -456,7 +456,7 @@ urlpatterns = []
if settings.INVENTREE_ADMIN_ENABLED: if settings.INVENTREE_ADMIN_ENABLED:
admin_url = settings.INVENTREE_ADMIN_URL admin_url = settings.INVENTREE_ADMIN_URL
if settings.ADMIN_SHELL_ENABLE: # noqa if settings.ADMIN_SHELL_ENABLE:
urlpatterns += [path(f'{admin_url}/shell/', include('django_admin_shell.urls'))] urlpatterns += [path(f'{admin_url}/shell/', include('django_admin_shell.urls'))]
urlpatterns += [ urlpatterns += [

View File

@ -96,7 +96,6 @@ def validate_sales_order_reference(value):
def validate_tree_name(value): def validate_tree_name(value):
"""Placeholder for legacy function used in migrations.""" """Placeholder for legacy function used in migrations."""
...
def validate_overage(value): def validate_overage(value):

View File

@ -180,7 +180,7 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
AjaxUpdateView: 'change', AjaxUpdateView: 'change',
} }
for view_class in permission_map.keys(): for view_class in permission_map:
if issubclass(type(self), view_class): if issubclass(type(self), view_class):
return permission_map[view_class] return permission_map[view_class]
@ -238,7 +238,6 @@ class AjaxMixin(InvenTreeRoleMixin):
Ref: https://docs.djangoproject.com/en/dev/topics/forms/ Ref: https://docs.djangoproject.com/en/dev/topics/forms/
""" """
# Do nothing by default # Do nothing by default
pass
def renderJsonResponse(self, request, form=None, data=None, context=None): def renderJsonResponse(self, request, form=None, data=None, context=None):
"""Render a JSON response based on specific class context. """Render a JSON response based on specific class context.
@ -286,7 +285,7 @@ class AjaxMixin(InvenTreeRoleMixin):
# Custom feedback`data # Custom feedback`data
fb = self.get_data() fb = self.get_data()
for key in fb.keys(): for key in fb:
data[key] = fb[key] data[key] = fb[key]
return JsonResponse(data, safe=False) return JsonResponse(data, safe=False)
@ -329,11 +328,11 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
request, self.get_form(), context=self.get_context_data() request, self.get_form(), context=self.get_context_data()
) )
def save(self, object, form, **kwargs): def save(self, obj, form, **kwargs):
"""Method for updating the object in the database. Default implementation is very simple, but can be overridden if required. """Method for updating the object in the database. Default implementation is very simple, but can be overridden if required.
Args: Args:
object: The current object, to be updated obj: The current object, to be updated
form: The validated form form: The validated form
Returns: Returns:
@ -578,14 +577,10 @@ class UserSessionOverride:
class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView): class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView):
"""Revert to settings after session delete.""" """Revert to settings after session delete."""
pass
class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView): class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView):
"""Revert to settings after session delete.""" """Revert to settings after session delete."""
pass
class CustomLoginView(LoginView): class CustomLoginView(LoginView):
"""Custom login view that allows login with urlargs.""" """Custom login view that allows login with urlargs."""

View File

@ -1,34 +1,31 @@
"""Background task definitions for the BuildOrder app""" """Background task definitions for the BuildOrder app."""
import logging
from datetime import timedelta from datetime import timedelta
from decimal import Decimal from decimal import Decimal
import logging
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.utils.translation import gettext_lazy as _
from django.template.loader import render_to_string from django.template.loader import render_to_string
from django.utils.translation import gettext_lazy as _
from allauth.account.models import EmailAddress from allauth.account.models import EmailAddress
from plugin.events import trigger_event
import common.notifications
import build.models import build.models
import InvenTree.email import common.notifications
import InvenTree.helpers import InvenTree.helpers
import InvenTree.helpers_email
import InvenTree.helpers_model import InvenTree.helpers_model
import InvenTree.tasks import InvenTree.tasks
from InvenTree.ready import isImportingData
from build.status_codes import BuildStatusGroups
import part.models as part_models import part.models as part_models
from build.status_codes import BuildStatusGroups
from InvenTree.ready import isImportingData
from plugin.events import trigger_event
logger = logging.getLogger('inventree') logger = logging.getLogger('inventree')
def auto_allocate_build(build_id: int, **kwargs): def auto_allocate_build(build_id: int, **kwargs):
"""Run auto-allocation for a specified BuildOrder.""" """Run auto-allocation for a specified BuildOrder."""
build_order = build.models.Build.objects.filter(pk=build_id).first() build_order = build.models.Build.objects.filter(pk=build_id).first()
if not build_order: if not build_order:
@ -40,7 +37,6 @@ def auto_allocate_build(build_id: int, **kwargs):
def complete_build_allocations(build_id: int, user_id: int): def complete_build_allocations(build_id: int, user_id: int):
"""Complete build allocations for a specified BuildOrder.""" """Complete build allocations for a specified BuildOrder."""
build_order = build.models.Build.objects.filter(pk=build_id).first() build_order = build.models.Build.objects.filter(pk=build_id).first()
if user_id: if user_id:
@ -185,7 +181,7 @@ def check_build_stock(build: build.models.Build):
recipients = emails.values_list('email', flat=True) recipients = emails.values_list('email', flat=True)
InvenTree.email.send_email(subject, '', recipients, html_message=html_message) InvenTree.helpers_email.send_email(subject, '', recipients, html_message=html_message)
def create_child_builds(build_id: int) -> None: def create_child_builds(build_id: int) -> None:
@ -225,7 +221,7 @@ def create_child_builds(build_id: int) -> None:
def notify_overdue_build_order(bo: build.models.Build): def notify_overdue_build_order(bo: build.models.Build):
"""Notify appropriate users that a Build has just become 'overdue'""" """Notify appropriate users that a Build has just become 'overdue'."""
targets = [] targets = []
if bo.issued_by: if bo.issued_by:
@ -265,7 +261,7 @@ def notify_overdue_build_order(bo: build.models.Build):
@InvenTree.tasks.scheduled_task(InvenTree.tasks.ScheduledTask.DAILY) @InvenTree.tasks.scheduled_task(InvenTree.tasks.ScheduledTask.DAILY)
def check_overdue_build_orders(): def check_overdue_build_orders():
"""Check if any outstanding BuildOrders have just become overdue """Check if any outstanding BuildOrders have just become overdue.
- This check is performed daily - This check is performed daily
- Look at the 'target_date' of any outstanding BuildOrder objects - Look at the 'target_date' of any outstanding BuildOrder objects

View File

@ -47,7 +47,7 @@ from plugin.models import NotificationUserSetting
from plugin.serializers import NotificationUserSettingSerializer from plugin.serializers import NotificationUserSettingSerializer
class CsrfExemptMixin(object): class CsrfExemptMixin:
"""Exempts the view from CSRF requirements.""" """Exempts the view from CSRF requirements."""
@method_decorator(csrf_exempt) @method_decorator(csrf_exempt)
@ -136,7 +136,7 @@ class CurrencyExchangeView(APIView):
serializer_class = None serializer_class = None
@extend_schema(responses={200: common.serializers.CurrencyExchangeSerializer}) @extend_schema(responses={200: common.serializers.CurrencyExchangeSerializer})
def get(self, request, format=None): def get(self, request, fmt=None):
"""Return information on available currency conversions.""" """Return information on available currency conversions."""
# Extract a list of all available rates # Extract a list of all available rates
try: try:
@ -244,10 +244,7 @@ class GlobalSettingsDetail(RetrieveUpdateAPI):
"""Attempt to find a global setting object with the provided key.""" """Attempt to find a global setting object with the provided key."""
key = str(self.kwargs['key']).upper() key = str(self.kwargs['key']).upper()
if ( if key.startswith('_') or key not in common.models.InvenTreeSetting.SETTINGS:
key.startswith('_')
or key not in common.models.InvenTreeSetting.SETTINGS.keys()
):
raise NotFound() raise NotFound()
return common.models.InvenTreeSetting.get_setting_object( return common.models.InvenTreeSetting.get_setting_object(
@ -318,7 +315,7 @@ class UserSettingsDetail(RetrieveUpdateAPI):
if ( if (
key.startswith('_') key.startswith('_')
or key not in common.models.InvenTreeUserSetting.SETTINGS.keys() or key not in common.models.InvenTreeUserSetting.SETTINGS
): ):
raise NotFound() raise NotFound()
@ -566,7 +563,7 @@ class BackgroundTaskOverview(APIView):
permission_classes = [permissions.IsAuthenticated, IsAdminUser] permission_classes = [permissions.IsAuthenticated, IsAdminUser]
serializer_class = None serializer_class = None
def get(self, request, format=None): def get(self, request, fmt=None):
"""Return information about the current status of the background task queue.""" """Return information about the current status of the background task queue."""
import django_q.models as q_models import django_q.models as q_models

View File

@ -65,7 +65,7 @@ class MatchFieldForm(forms.Form):
for col in columns: for col in columns:
field_name = col['name'] field_name = col['name']
self.fields[field_name] = forms.ChoiceField( self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + headers_choices, choices=[('', '-' * 10), *headers_choices],
required=False, required=False,
widget=forms.Select(attrs={'class': 'select fieldselect'}), widget=forms.Select(attrs={'class': 'select fieldselect'}),
) )
@ -83,10 +83,7 @@ class MatchItemForm(forms.Form):
if 'file_manager' in kwargs: if 'file_manager' in kwargs:
file_manager = kwargs.pop('file_manager') file_manager = kwargs.pop('file_manager')
if 'row_data' in kwargs: row_data = kwargs.pop('row_data') if 'row_data' in kwargs else None
row_data = kwargs.pop('row_data')
else:
row_data = None
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -134,7 +131,7 @@ class MatchItemForm(forms.Form):
item_match = row['match_' + col_guess] item_match = row['match_' + col_guess]
# Set field select box # Set field select box
self.fields[field_name] = forms.ChoiceField( self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + item_options, choices=[('', '-' * 10), *item_options],
required=False, required=False,
widget=forms.Select(attrs={'class': 'select bomselect'}), widget=forms.Select(attrs={'class': 'select bomselect'}),
) )
@ -154,7 +151,7 @@ class MatchItemForm(forms.Form):
field_name = 'item_select-' + str(row['index']) field_name = 'item_select-' + str(row['index'])
# Set field select box # Set field select box
self.fields[field_name] = forms.ChoiceField( self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + item_options, choices=[('', '-' * 10), *item_options],
required=False, required=False,
widget=forms.Select(attrs={'class': 'select bomselect'}), widget=forms.Select(attrs={'class': 'select bomselect'}),
) )

View File

@ -55,7 +55,7 @@ def get_icon_packs():
tabler_icons_path = Path(__file__).parent.parent.joinpath( tabler_icons_path = Path(__file__).parent.parent.joinpath(
'InvenTree/static/tabler-icons/icons.json' 'InvenTree/static/tabler-icons/icons.json'
) )
with open(tabler_icons_path, 'r') as tabler_icons_file: with open(tabler_icons_path, encoding='utf-8') as tabler_icons_file:
tabler_icons = json.load(tabler_icons_file) tabler_icons = json.load(tabler_icons_file)
icon_packs = [ icon_packs = [

View File

@ -248,7 +248,7 @@ class BaseInvenTreeSetting(models.Model):
If a particular setting is not present, create it with the default value If a particular setting is not present, create it with the default value
""" """
cache_key = f'BUILD_DEFAULT_VALUES:{str(cls.__name__)}' cache_key = f'BUILD_DEFAULT_VALUES:{cls.__name__!s}'
try: try:
if InvenTree.helpers.str2bool(cache.get(cache_key, False)): if InvenTree.helpers.str2bool(cache.get(cache_key, False)):
@ -331,7 +331,7 @@ class BaseInvenTreeSetting(models.Model):
- The unique KEY string - The unique KEY string
- Any key:value kwargs associated with the particular setting type (e.g. user-id) - Any key:value kwargs associated with the particular setting type (e.g. user-id)
""" """
key = f'{str(cls.__name__)}:{setting_key}' key = f'{cls.__name__!s}:{setting_key}'
for k, v in kwargs.items(): for k, v in kwargs.items():
key += f'_{k}:{v}' key += f'_{k}:{v}'
@ -2062,7 +2062,7 @@ class InvenTreeSetting(BaseInvenTreeSetting):
'description': _( 'description': _(
'Check that all plugins are installed on startup - enable in container environments' 'Check that all plugins are installed on startup - enable in container environments'
), ),
'default': str(os.getenv('INVENTREE_DOCKER', False)).lower() 'default': str(os.getenv('INVENTREE_DOCKER', 'False')).lower()
in ['1', 'true'], in ['1', 'true'],
'validator': bool, 'validator': bool,
'requires_restart': True, 'requires_restart': True,
@ -3092,13 +3092,10 @@ class CustomUnit(models.Model):
"""Ensure that the custom unit is unique.""" """Ensure that the custom unit is unique."""
super().validate_unique(exclude) super().validate_unique(exclude)
if self.symbol: if self.symbol and (
if ( CustomUnit.objects.filter(symbol=self.symbol).exclude(pk=self.pk).exists()
CustomUnit.objects.filter(symbol=self.symbol) ):
.exclude(pk=self.pk) raise ValidationError({'symbol': _('Unit symbol must be unique')})
.exists()
):
raise ValidationError({'symbol': _('Unit symbol must be unique')})
def clean(self): def clean(self):
"""Validate that the provided custom unit is indeed valid.""" """Validate that the provided custom unit is indeed valid."""
@ -3433,13 +3430,13 @@ class InvenTreeCustomUserStateModel(models.Model):
'reference_status': _('Reference status set not found') 'reference_status': _('Reference status set not found')
}) })
ref_set = ref_set[0] ref_set = ref_set[0]
if self.key in ref_set.keys(): if self.key in ref_set.keys(): # noqa: SIM118
raise ValidationError({ raise ValidationError({
'key': _( 'key': _(
'Key must be different from the logical keys of the reference status' 'Key must be different from the logical keys of the reference status'
) )
}) })
if self.logical_key not in ref_set.keys(): if self.logical_key not in ref_set.keys(): # noqa: SIM118
raise ValidationError({ raise ValidationError({
'logical_key': _( 'logical_key': _(
'Logical key must be in the logical keys of the reference status' 'Logical key must be in the logical keys of the reference status'

View File

@ -143,11 +143,7 @@ class NotificationMethod:
# Check if method globally enabled # Check if method globally enabled
plg_instance = registry.get_plugin(plg_cls.NAME.lower()) plg_instance = registry.get_plugin(plg_cls.NAME.lower())
if plg_instance and not plg_instance.get_setting(self.GLOBAL_SETTING): return plg_instance and not plg_instance.get_setting(self.GLOBAL_SETTING)
return True
# Lets go!
return False
def usersetting(self, target): def usersetting(self, target):
"""Returns setting for this method for a given user.""" """Returns setting for this method for a given user."""
@ -365,7 +361,7 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
obj_ref_value = getattr(obj, 'id', None) obj_ref_value = getattr(obj, 'id', None)
if not obj_ref_value: if not obj_ref_value:
raise KeyError( raise KeyError(
f"Could not resolve an object reference for '{str(obj)}' with {obj_ref}, pk, id" f"Could not resolve an object reference for '{obj!s}' with {obj_ref}, pk, id"
) )
# Check if we have notified recently... # Check if we have notified recently...
@ -432,9 +428,9 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
deliver_notification(method, obj, category, target_users, context) deliver_notification(method, obj, category, target_users, context)
except NotImplementedError as error: except NotImplementedError as error:
# Allow any single notification method to fail, without failing the others # Allow any single notification method to fail, without failing the others
logger.error(error) # noqa: LOG005 logger.error(error)
except Exception as error: except Exception as error:
logger.error(error) # noqa: LOG005 logger.error(error)
# Set delivery flag # Set delivery flag
common.models.NotificationEntry.notify(category, obj_ref_value) common.models.NotificationEntry.notify(category, obj_ref_value)

View File

@ -574,9 +574,8 @@ class AttachmentSerializer(InvenTreeModelSerializer):
model_type = self.validated_data.get('model_type', None) model_type = self.validated_data.get('model_type', None)
if model_type is None: if model_type is None and self.instance:
if self.instance: model_type = self.instance.model_type
model_type = self.instance.model_type
# Ensure that the user has permission to attach files to the specified model # Ensure that the user has permission to attach files to the specified model
user = self.context.get('request').user user = self.context.get('request').user

View File

@ -393,7 +393,7 @@ class SettingsTest(InvenTreeTestCase):
'before_save', 'before_save',
] ]
for k in setting.keys(): for k in setting:
self.assertIn(k, allowed_keys) self.assertIn(k, allowed_keys)
# Check default value for boolean settings # Check default value for boolean settings
@ -431,7 +431,7 @@ class SettingsTest(InvenTreeTestCase):
@override_settings(SITE_URL=None, PLUGIN_TESTING=True, PLUGIN_TESTING_SETUP=True) @override_settings(SITE_URL=None, PLUGIN_TESTING=True, PLUGIN_TESTING_SETUP=True)
def test_defaults(self): def test_defaults(self):
"""Populate the settings with default values.""" """Populate the settings with default values."""
for key in InvenTreeSetting.SETTINGS.keys(): for key in InvenTreeSetting.SETTINGS:
value = InvenTreeSetting.get_setting_default(key) value = InvenTreeSetting.get_setting_default(key)
try: try:
@ -520,7 +520,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
n_public_settings = len([ n_public_settings = len([
k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_') k for k in InvenTreeSetting.SETTINGS if not k.startswith('_')
]) ])
# Number of results should match the number of settings # Number of results should match the number of settings
@ -832,11 +832,9 @@ class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase):
def test_invalid_setting_key(self): def test_invalid_setting_key(self):
"""Test that an invalid setting key returns a 404.""" """Test that an invalid setting key returns a 404."""
...
def test_uninitialized_setting(self): def test_uninitialized_setting(self):
"""Test that requesting an uninitialized setting creates the setting.""" """Test that requesting an uninitialized setting creates the setting."""
...
class ErrorReportTest(InvenTreeAPITestCase): class ErrorReportTest(InvenTreeAPITestCase):
@ -934,7 +932,7 @@ class WebhookMessageTests(TestCase):
def test_bad_token(self): def test_bad_token(self):
"""Test that a wrong token is not working.""" """Test that a wrong token is not working."""
response = self.client.post( response = self.client.post(
self.url, content_type=CONTENT_TYPE_JSON, **{'HTTP_TOKEN': '1234567fghj'} self.url, content_type=CONTENT_TYPE_JSON, HTTP_TOKEN='1234567fghj'
) )
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status_code == HTTPStatus.FORBIDDEN
@ -957,7 +955,7 @@ class WebhookMessageTests(TestCase):
self.url, self.url,
data="{'this': 123}", data="{'this': 123}",
content_type=CONTENT_TYPE_JSON, content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': str(self.endpoint_def.token)}, HTTP_TOKEN=str(self.endpoint_def.token),
) )
assert response.status_code == HTTPStatus.NOT_ACCEPTABLE assert response.status_code == HTTPStatus.NOT_ACCEPTABLE
@ -1005,7 +1003,7 @@ class WebhookMessageTests(TestCase):
response = self.client.post( response = self.client.post(
self.url, self.url,
content_type=CONTENT_TYPE_JSON, content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': str('68MXtc/OiXdA5e2Nq9hATEVrZFpLb3Zb0oau7n8s31I=')}, HTTP_TOKEN='68MXtc/OiXdA5e2Nq9hATEVrZFpLb3Zb0oau7n8s31I=',
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
@ -1020,7 +1018,7 @@ class WebhookMessageTests(TestCase):
self.url, self.url,
data={'this': 'is a message'}, data={'this': 'is a message'},
content_type=CONTENT_TYPE_JSON, content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': str(self.endpoint_def.token)}, HTTP_TOKEN=str(self.endpoint_def.token),
) )
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK

View File

@ -312,7 +312,6 @@ class FileManagementFormView(MultiStepFormView):
This method is very specific to the type of data found in the file, This method is very specific to the type of data found in the file,
therefore overwrite it in the subclass. therefore overwrite it in the subclass.
""" """
pass
def get_clean_items(self): def get_clean_items(self):
"""Returns dict with all cleaned values.""" """Returns dict with all cleaned values."""
@ -363,10 +362,7 @@ class FileManagementFormView(MultiStepFormView):
duplicates = [] duplicates = []
for col in self.column_names: for col in self.column_names:
if col in self.column_selections: guess = self.column_selections.get(col, None)
guess = self.column_selections[col]
else:
guess = None
if guess: if guess:
n = list(self.column_selections.values()).count( n = list(self.column_selections.values()).count(

View File

@ -10,4 +10,3 @@ class CompanyConfig(AppConfig):
def ready(self): def ready(self):
"""This function is called whenever the Company app is loaded.""" """This function is called whenever the Company app is loaded."""
pass

View File

@ -46,10 +46,7 @@ def rename_company_image(instance, filename):
""" """
base = 'company_images' base = 'company_images'
if filename.count('.') > 0: ext = filename.split('.')[-1] if filename.count('.') > 0 else ''
ext = filename.split('.')[-1]
else:
ext = ''
fn = f'company_{instance.pk}_img' fn = f'company_{instance.pk}_img'

View File

@ -11,8 +11,8 @@ from .transition import StateTransitionMixin, TransitionMethod, storage
__all__ = [ __all__ = [
'ColorEnum', 'ColorEnum',
'StatusCode',
'storage',
'TransitionMethod',
'StateTransitionMixin', 'StateTransitionMixin',
'StatusCode',
'TransitionMethod',
'storage',
] ]

View File

@ -5,7 +5,7 @@ import re
from enum import Enum from enum import Enum
class BaseEnum(enum.IntEnum): class BaseEnum(enum.IntEnum): # noqa: PLW1641
"""An `Enum` capabile of having its members have docstrings. """An `Enum` capabile of having its members have docstrings.
Based on https://stackoverflow.com/questions/19330460/how-do-i-put-docstrings-on-enums Based on https://stackoverflow.com/questions/19330460/how-do-i-put-docstrings-on-enums
@ -100,9 +100,7 @@ class StatusCode(BaseEnum):
return False return False
if callable(value): if callable(value):
return False return False
if not isinstance(value.value, int): return isinstance(value.value, int)
return False
return True
@classmethod @classmethod
def values(cls, key=None): def values(cls, key=None):

View File

@ -27,7 +27,6 @@ class GeneralStatus(StatusCode):
def GHI(self): # This should be ignored def GHI(self): # This should be ignored
"""A invalid function.""" """A invalid function."""
...
class GeneralStateTest(InvenTreeTestCase): class GeneralStateTest(InvenTreeTestCase):

View File

@ -3,6 +3,6 @@
from django import template from django import template
register = template.Library() register = template.Library()
from generic.states.tags import status_label # noqa: E402 from generic.states.tags import status_label
__all__ = [status_label] __all__ = ['status_label']

View File

@ -213,8 +213,6 @@ class DataImportExportSerializerMixin(
): ):
"""Mixin class for adding data import/export functionality to a DRF serializer.""" """Mixin class for adding data import/export functionality to a DRF serializer."""
pass
class DataExportViewMixin: class DataExportViewMixin:
"""Mixin class for exporting a dataset via the API. """Mixin class for exporting a dataset via the API.

View File

@ -2,6 +2,7 @@
import json import json
import logging import logging
from typing import Optional
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.exceptions import ValidationError as DjangoValidationError from django.core.exceptions import ValidationError as DjangoValidationError
@ -117,8 +118,8 @@ class DataImportSession(models.Model):
""" """
mapping = {} mapping = {}
for map in self.column_mappings.all(): for i in self.column_mappings.all():
mapping[map.field] = map.column mapping[i.field] = i.column
return mapping return mapping
@ -209,13 +210,13 @@ class DataImportSession(models.Model):
missing_fields = [] missing_fields = []
for field in required_fields.keys(): for field in required_fields:
# An override value exists # An override value exists
if field in field_overrides: if field in field_overrides:
continue continue
# A default value exists # A default value exists
if field in field_defaults and field_defaults[field]: if field_defaults.get(field):
continue continue
# The field has been mapped to a data column # The field has been mapped to a data column
@ -537,7 +538,10 @@ class DataImportRow(models.Model):
return overrides return overrides
def extract_data( def extract_data(
self, available_fields: dict = None, field_mapping: dict = None, commit=True self,
available_fields: Optional[dict] = None,
field_mapping: Optional[dict] = None,
commit=True,
): ):
"""Extract row data from the provided data dictionary.""" """Extract row data from the provided data dictionary."""
if not field_mapping: if not field_mapping:

View File

@ -36,7 +36,7 @@ def load_data_file(data_file, file_format=None):
try: try:
data = file_object.read() data = file_object.read()
except (IOError, FileNotFoundError): except (OSError, FileNotFoundError):
raise ValidationError(_('Failed to open data file')) raise ValidationError(_('Failed to open data file'))
# Excel formats expect binary data # Excel formats expect binary data
@ -113,9 +113,8 @@ def get_field_label(field) -> str:
Returns: Returns:
str: Field label str: Field label
""" """
if field: if field and (label := getattr(field, 'label', None)):
if label := getattr(field, 'label', None): return label
return label
# TODO: Check if the field is a model field # TODO: Check if the field is a model field

View File

@ -19,7 +19,7 @@ class ImporterTest(InvenTreeTestCase):
fn = os.path.join(os.path.dirname(__file__), 'test_data', 'companies.csv') fn = os.path.join(os.path.dirname(__file__), 'test_data', 'companies.csv')
with open(fn, 'r') as input_file: with open(fn, encoding='utf-8') as input_file:
data = input_file.read() data = input_file.read()
session = DataImportSession.objects.create( session = DataImportSession.objects.create(
@ -61,4 +61,3 @@ class ImporterTest(InvenTreeTestCase):
def test_field_defaults(self): def test_field_defaults(self):
"""Test default field values.""" """Test default field values."""
...

View File

@ -36,7 +36,7 @@ def validate_importer_model_type(value):
"""Validate that the given model type is supported for importing.""" """Validate that the given model type is supported for importing."""
from importer.registry import supported_models from importer.registry import supported_models
if value not in supported_models().keys(): if value not in supported_models():
raise ValidationError(f"Unsupported model type '{value}'") raise ValidationError(f"Unsupported model type '{value}'")

View File

@ -1,4 +1,4 @@
from machine.machine_type import BaseDriver, BaseMachineType, MachineStatus from machine.machine_type import BaseDriver, BaseMachineType, MachineStatus
from machine.registry import registry from machine.registry import registry
__all__ = ['registry', 'BaseMachineType', 'BaseDriver', 'MachineStatus'] __all__ = ['BaseDriver', 'BaseMachineType', 'MachineStatus', 'registry']

View File

@ -139,7 +139,7 @@ class BaseDriver(
Arguments: Arguments:
error: Exception or string error: Exception or string
""" """
self.set_shared_state('errors', self.errors + [error]) self.set_shared_state('errors', [*self.errors, error])
# --- state getters/setters # --- state getters/setters
@property @property
@ -317,7 +317,7 @@ class BaseMachineType(
Arguments: Arguments:
error: Exception or string error: Exception or string
""" """
self.set_shared_state('errors', self.errors + [error]) self.set_shared_state('errors', [*self.errors, error])
def reset_errors(self): def reset_errors(self):
"""Helper function for resetting the error list for a machine.""" """Helper function for resetting the error list for a machine."""

View File

@ -4,8 +4,8 @@ from machine.machine_types.label_printer import (
) )
__all__ = [ __all__ = [
# machine types
'LabelPrinterMachine',
# base drivers # base drivers
'LabelPrinterBaseDriver', 'LabelPrinterBaseDriver',
# machine types
'LabelPrinterMachine',
] ]

View File

@ -38,7 +38,7 @@ class MachineRegistry(
def handle_error(self, error: Union[Exception, str]): def handle_error(self, error: Union[Exception, str]):
"""Helper function for capturing errors with the machine registry.""" """Helper function for capturing errors with the machine registry."""
self.set_shared_state('errors', self.errors + [error]) self.set_shared_state('errors', [*self.errors, error])
def initialize(self, main: bool = False): def initialize(self, main: bool = False):
"""Initialize the machine registry.""" """Initialize the machine registry."""

View File

@ -84,13 +84,11 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
machine_type, machine_type,
{ {
**machine_type, **machine_type,
**{ 'slug': 'label-printer',
'slug': 'label-printer', 'name': 'Label Printer',
'name': 'Label Printer', 'description': 'Directly print labels for various items.',
'description': 'Directly print labels for various items.', 'provider_plugin': None,
'provider_plugin': None, 'is_builtin': True,
'is_builtin': True,
},
}, },
) )
self.assertTrue( self.assertTrue(
@ -109,15 +107,13 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
driver, driver,
{ {
**driver, **driver,
**{ 'slug': 'test-label-printer-api',
'slug': 'test-label-printer-api', 'name': 'Test label printer',
'name': 'Test label printer', 'description': 'This is a test label printer driver for testing.',
'description': 'This is a test label printer driver for testing.', 'provider_plugin': None,
'provider_plugin': None, 'is_builtin': True,
'is_builtin': True, 'machine_type': 'label-printer',
'machine_type': 'label-printer', 'driver_errors': [],
'driver_errors': [],
},
}, },
) )
self.assertEqual(driver['provider_file'], __file__) self.assertEqual(driver['provider_file'], __file__)
@ -173,17 +169,15 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
response.data[0], response.data[0],
{ {
**response.data[0], **response.data[0],
**{ 'name': 'Test Machine',
'name': 'Test Machine', 'machine_type': 'label-printer',
'machine_type': 'label-printer', 'driver': 'test-label-printer-api',
'driver': 'test-label-printer-api', 'initialized': True,
'initialized': True, 'active': True,
'active': True, 'status': 101,
'status': 101, 'status_model': 'LabelPrinterStatus',
'status_model': 'LabelPrinterStatus', 'status_text': '',
'status_text': '', 'is_driver_available': True,
'is_driver_available': True,
},
}, },
) )
@ -216,9 +210,7 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
reverse('api-machine-detail', kwargs={'pk': pk}), reverse('api-machine-detail', kwargs={'pk': pk}),
{'name': 'Updated Machine'}, {'name': 'Updated Machine'},
) )
self.assertEqual( self.assertEqual(response.data, {**response.data, 'name': 'Updated Machine'})
response.data, {**response.data, **{'name': 'Updated Machine'}}
)
self.assertEqual(MachineConfig.objects.get(pk=pk).name, 'Updated Machine') self.assertEqual(MachineConfig.objects.get(pk=pk).name, 'Updated Machine')
# Delete the machine # Delete the machine

View File

@ -337,8 +337,6 @@ class PurchaseOrderList(PurchaseOrderMixin, DataExportViewMixin, ListCreateAPI):
class PurchaseOrderDetail(PurchaseOrderMixin, RetrieveUpdateDestroyAPI): class PurchaseOrderDetail(PurchaseOrderMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a PurchaseOrder object.""" """API endpoint for detail view of a PurchaseOrder object."""
pass
class PurchaseOrderContextMixin: class PurchaseOrderContextMixin:
"""Mixin to add purchase order object as serializer context variable.""" """Mixin to add purchase order object as serializer context variable."""
@ -603,8 +601,6 @@ class PurchaseOrderLineItemList(
class PurchaseOrderLineItemDetail(PurchaseOrderLineItemMixin, RetrieveUpdateDestroyAPI): class PurchaseOrderLineItemDetail(PurchaseOrderLineItemMixin, RetrieveUpdateDestroyAPI):
"""Detail API endpoint for PurchaseOrderLineItem object.""" """Detail API endpoint for PurchaseOrderLineItem object."""
pass
class PurchaseOrderExtraLineList(GeneralExtraLineList, ListCreateAPI): class PurchaseOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
"""API endpoint for accessing a list of PurchaseOrderExtraLine objects.""" """API endpoint for accessing a list of PurchaseOrderExtraLine objects."""
@ -746,8 +742,6 @@ class SalesOrderList(SalesOrderMixin, DataExportViewMixin, ListCreateAPI):
class SalesOrderDetail(SalesOrderMixin, RetrieveUpdateDestroyAPI): class SalesOrderDetail(SalesOrderMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a SalesOrder object.""" """API endpoint for detail view of a SalesOrder object."""
pass
class SalesOrderLineItemFilter(LineItemFilter): class SalesOrderLineItemFilter(LineItemFilter):
"""Custom filters for SalesOrderLineItemList endpoint.""" """Custom filters for SalesOrderLineItemList endpoint."""
@ -865,8 +859,6 @@ class SalesOrderLineItemList(
class SalesOrderLineItemDetail(SalesOrderLineItemMixin, RetrieveUpdateDestroyAPI): class SalesOrderLineItemDetail(SalesOrderLineItemMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a SalesOrderLineItem object.""" """API endpoint for detail view of a SalesOrderLineItem object."""
pass
class SalesOrderExtraLineList(GeneralExtraLineList, ListCreateAPI): class SalesOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
"""API endpoint for accessing a list of SalesOrderExtraLine objects.""" """API endpoint for accessing a list of SalesOrderExtraLine objects."""
@ -1181,8 +1173,6 @@ class ReturnOrderList(ReturnOrderMixin, DataExportViewMixin, ListCreateAPI):
class ReturnOrderDetail(ReturnOrderMixin, RetrieveUpdateDestroyAPI): class ReturnOrderDetail(ReturnOrderMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single ReturnOrder object.""" """API endpoint for detail view of a single ReturnOrder object."""
pass
class ReturnOrderContextMixin: class ReturnOrderContextMixin:
"""Simple mixin class to add a ReturnOrder to the serializer context.""" """Simple mixin class to add a ReturnOrder to the serializer context."""
@ -1310,8 +1300,6 @@ class ReturnOrderLineItemList(
class ReturnOrderLineItemDetail(ReturnOrderLineItemMixin, RetrieveUpdateDestroyAPI): class ReturnOrderLineItemDetail(ReturnOrderLineItemMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a ReturnOrderLineItem object.""" """API endpoint for detail view of a ReturnOrderLineItem object."""
pass
class ReturnOrderExtraLineList(GeneralExtraLineList, ListCreateAPI): class ReturnOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
"""API endpoint for accessing a list of ReturnOrderExtraLine objects.""" """API endpoint for accessing a list of ReturnOrderExtraLine objects."""
@ -1368,10 +1356,9 @@ class OrderCalendarExport(ICalFeed):
if auth[0].lower() == 'basic': if auth[0].lower() == 'basic':
uname, passwd = base64.b64decode(auth[1]).decode('ascii').split(':') uname, passwd = base64.b64decode(auth[1]).decode('ascii').split(':')
user = authenticate(username=uname, password=passwd) user = authenticate(username=uname, password=passwd)
if user is not None: if user is not None and user.is_active:
if user.is_active: login(request, user)
login(request, user) request.user = user
request.user = user
# Check again # Check again
if request.user.is_authenticated: if request.user.is_authenticated:

View File

@ -1048,19 +1048,11 @@ class SalesOrder(TotalPriceMixin, Order):
def is_fully_allocated(self): def is_fully_allocated(self):
"""Return True if all line items are fully allocated.""" """Return True if all line items are fully allocated."""
for line in self.lines.all(): return all(line.is_fully_allocated() for line in self.lines.all())
if not line.is_fully_allocated():
return False
return True
def is_overallocated(self): def is_overallocated(self):
"""Return true if any lines in the order are over-allocated.""" """Return true if any lines in the order are over-allocated."""
for line in self.lines.all(): return any(line.is_overallocated() for line in self.lines.all())
if line.is_overallocated():
return True
return False
def is_completed(self): def is_completed(self):
"""Check if this order is "shipped" (all line items delivered).""" """Check if this order is "shipped" (all line items delivered)."""

View File

@ -166,7 +166,8 @@ class AbstractOrderSerializer(DataImportExportSerializerMixin, serializers.Seria
'notes', 'notes',
'barcode_hash', 'barcode_hash',
'overdue', 'overdue',
] + extra_fields *extra_fields,
]
class AbstractLineItemSerializer: class AbstractLineItemSerializer:
@ -433,7 +434,7 @@ class PurchaseOrderLineItemSerializer(
def skip_create_fields(self): def skip_create_fields(self):
"""Return a list of fields to skip when creating a new object.""" """Return a list of fields to skip when creating a new object."""
return ['auto_pricing', 'merge_items'] + super().skip_create_fields() return ['auto_pricing', 'merge_items', *super().skip_create_fields()]
@staticmethod @staticmethod
def annotate_queryset(queryset): def annotate_queryset(queryset):
@ -740,13 +741,12 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer):
base_quantity = line_item.part.base_quantity(quantity) base_quantity = line_item.part.base_quantity(quantity)
# Does the quantity need to be "integer" (for trackable parts?) # Does the quantity need to be "integer" (for trackable parts?)
if base_part.trackable: if base_part.trackable and Decimal(base_quantity) != int(base_quantity):
if Decimal(base_quantity) != int(base_quantity): raise ValidationError({
raise ValidationError({ 'quantity': _(
'quantity': _( 'An integer quantity must be provided for trackable parts'
'An integer quantity must be provided for trackable parts' )
) })
})
# If serial numbers are provided # If serial numbers are provided
if serial_numbers: if serial_numbers:

View File

@ -1527,7 +1527,7 @@ class SalesOrderTest(OrderTest):
self.download_file( self.download_file(
reverse('api-so-list'), reverse('api-so-list'),
{'export': fmt}, {'export': fmt},
decode=True if fmt == 'csv' else False, decode=fmt == 'csv',
expected_code=200, expected_code=200,
expected_fn=r'InvenTree_SalesOrder_.+', expected_fn=r'InvenTree_SalesOrder_.+',
) )

View File

@ -294,7 +294,7 @@ class SalesOrderExport(AjaxView):
export_format = request.GET.get('format', 'csv') export_format = request.GET.get('format', 'csv')
filename = f'{str(order)} - {order.customer.name}.{export_format}' filename = f'{order!s} - {order.customer.name}.{export_format}'
dataset = SalesOrderLineItemResource().export(queryset=order.lines.all()) dataset = SalesOrderLineItemResource().export(queryset=order.lines.all())
@ -321,7 +321,7 @@ class PurchaseOrderExport(AjaxView):
export_format = request.GET.get('format', 'csv') export_format = request.GET.get('format', 'csv')
filename = f'{str(order)} - {order.supplier.name}.{export_format}' filename = f'{order!s} - {order.supplier.name}.{export_format}'
dataset = PurchaseOrderLineItemResource().export(queryset=order.lines.all()) dataset = PurchaseOrderLineItemResource().export(queryset=order.lines.all())

View File

@ -167,10 +167,9 @@ class CategoryFilter(rest_filters.FilterSet):
top_level = str2bool(self.data.get('top_level', None)) top_level = str2bool(self.data.get('top_level', None))
# If the parent is *not* provided, update the results based on the "cascade" value # If the parent is *not* provided, update the results based on the "cascade" value
if not parent or top_level: if (not parent or top_level) and not value:
if not value: # If "cascade" is False, only return top-level categories
# If "cascade" is False, only return top-level categories queryset = queryset.filter(parent=None)
queryset = queryset.filter(parent=None)
return queryset return queryset
@ -466,8 +465,6 @@ class PartTestTemplateMixin:
class PartTestTemplateDetail(PartTestTemplateMixin, RetrieveUpdateDestroyAPI): class PartTestTemplateDetail(PartTestTemplateMixin, RetrieveUpdateDestroyAPI):
"""Detail endpoint for PartTestTemplate model.""" """Detail endpoint for PartTestTemplate model."""
pass
class PartTestTemplateList(PartTestTemplateMixin, DataExportViewMixin, ListCreateAPI): class PartTestTemplateList(PartTestTemplateMixin, DataExportViewMixin, ListCreateAPI):
"""API endpoint for listing (and creating) a PartTestTemplate.""" """API endpoint for listing (and creating) a PartTestTemplate."""
@ -1570,8 +1567,6 @@ class PartParameterTemplateList(
class PartParameterTemplateDetail(PartParameterTemplateMixin, RetrieveUpdateDestroyAPI): class PartParameterTemplateDetail(PartParameterTemplateMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for accessing the detail view for a PartParameterTemplate object.""" """API endpoint for accessing the detail view for a PartParameterTemplate object."""
pass
class PartParameterAPIMixin: class PartParameterAPIMixin:
"""Mixin class for PartParameter API endpoints.""" """Mixin class for PartParameter API endpoints."""
@ -1663,8 +1658,6 @@ class PartParameterList(PartParameterAPIMixin, DataExportViewMixin, ListCreateAP
class PartParameterDetail(PartParameterAPIMixin, RetrieveUpdateDestroyAPI): class PartParameterDetail(PartParameterAPIMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single PartParameter object.""" """API endpoint for detail view of a single PartParameter object."""
pass
class PartStocktakeFilter(rest_filters.FilterSet): class PartStocktakeFilter(rest_filters.FilterSet):
"""Custom filter for the PartStocktakeList endpoint.""" """Custom filter for the PartStocktakeList endpoint."""
@ -1922,8 +1915,6 @@ class BomList(BomMixin, DataExportViewMixin, ListCreateDestroyAPIView):
class BomDetail(BomMixin, RetrieveUpdateDestroyAPI): class BomDetail(BomMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single BomItem object.""" """API endpoint for detail view of a single BomItem object."""
pass
class BomImportUpload(CreateAPI): class BomImportUpload(CreateAPI):
"""API endpoint for uploading a complete Bill of Materials. """API endpoint for uploading a complete Bill of Materials.

View File

@ -4,6 +4,7 @@ Primarily BOM upload tools.
""" """
from collections import OrderedDict from collections import OrderedDict
from typing import Optional
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
@ -40,7 +41,11 @@ def MakeBomTemplate(fmt):
def ExportBom( def ExportBom(
part: Part, fmt='csv', cascade: bool = False, max_levels: int = None, **kwargs part: Part,
fmt='csv',
cascade: bool = False,
max_levels: Optional[int] = None,
**kwargs,
): ):
"""Export a BOM (Bill of Materials) for a given part. """Export a BOM (Bill of Materials) for a given part.
@ -244,10 +249,7 @@ def ExportBom(
manufacturer_name = '' manufacturer_name = ''
# Extract the "MPN" field from the Manufacturer Part # Extract the "MPN" field from the Manufacturer Part
if mp_part: manufacturer_mpn = mp_part.MPN if mp_part else ''
manufacturer_mpn = mp_part.MPN
else:
manufacturer_mpn = ''
# Generate a column name for this manufacturer # Generate a column name for this manufacturer
k_man = f'{_("Manufacturer")}_{mp_idx}' k_man = f'{_("Manufacturer")}_{mp_idx}'
@ -270,10 +272,7 @@ def ExportBom(
else: else:
supplier_name = '' supplier_name = ''
if sp_part: supplier_sku = sp_part.SKU if sp_part else ''
supplier_sku = sp_part.SKU
else:
supplier_sku = ''
# Generate column names for this supplier # Generate column names for this supplier
k_sup = ( k_sup = (
@ -307,10 +306,7 @@ def ExportBom(
supplier_parts_used.add(sp_part) supplier_parts_used.add(sp_part)
if sp_part.supplier: supplier_name = sp_part.supplier.name if sp_part.supplier else ''
supplier_name = sp_part.supplier.name
else:
supplier_name = ''
supplier_sku = sp_part.SKU supplier_sku = sp_part.SKU

View File

@ -231,10 +231,7 @@ class PartCategory(InvenTree.models.InvenTreeTree):
"""Get all unique parameter names for all parts from this category.""" """Get all unique parameter names for all parts from this category."""
unique_parameters_names = [] unique_parameters_names = []
if prefetch: parts = prefetch or self.prefetch_parts_parameters(cascade=cascade)
parts = prefetch
else:
parts = self.prefetch_parts_parameters(cascade=cascade)
for part in parts: for part in parts:
for parameter in part.parameters.all(): for parameter in part.parameters.all():
@ -248,10 +245,7 @@ class PartCategory(InvenTree.models.InvenTreeTree):
"""Get all parameter names and values for all parts from this category.""" """Get all parameter names and values for all parts from this category."""
category_parameters = [] category_parameters = []
if prefetch: parts = prefetch or self.prefetch_parts_parameters(cascade=cascade)
parts = prefetch
else:
parts = self.prefetch_parts_parameters(cascade=cascade)
for part in parts: for part in parts:
part_parameters = { part_parameters = {
@ -933,24 +927,26 @@ class Part(
'IPN': _('Duplicate IPN not allowed in part settings') 'IPN': _('Duplicate IPN not allowed in part settings')
}) })
if self.revision_of and self.revision: if (
if ( self.revision_of
and self.revision
and (
Part.objects.exclude(pk=self.pk) Part.objects.exclude(pk=self.pk)
.filter(revision_of=self.revision_of, revision=self.revision) .filter(revision_of=self.revision_of, revision=self.revision)
.exists() .exists()
): )
raise ValidationError(_('Duplicate part revision already exists.')) ):
raise ValidationError(_('Duplicate part revision already exists.'))
# Ensure unique across (Name, revision, IPN) (as specified) # Ensure unique across (Name, revision, IPN) (as specified)
if self.revision or self.IPN: if (self.revision or self.IPN) and (
if ( Part.objects.exclude(pk=self.pk)
Part.objects.exclude(pk=self.pk) .filter(name=self.name, revision=self.revision, IPN=self.IPN)
.filter(name=self.name, revision=self.revision, IPN=self.IPN) .exists()
.exists() ):
): raise ValidationError(
raise ValidationError( _('Part with this Name, IPN and Revision already exists.')
_('Part with this Name, IPN and Revision already exists.') )
)
def clean(self): def clean(self):
"""Perform cleaning operations for the Part model. """Perform cleaning operations for the Part model.
@ -3888,16 +3884,18 @@ class PartParameter(InvenTree.models.InvenTreeMetadataModel):
super().clean() super().clean()
# Validate the parameter data against the template units # Validate the parameter data against the template units
if get_global_setting( if (
'PART_PARAMETER_ENFORCE_UNITS', True, cache=False, create=False get_global_setting(
'PART_PARAMETER_ENFORCE_UNITS', True, cache=False, create=False
)
and self.template.units
): ):
if self.template.units: try:
try: InvenTree.conversion.convert_physical_value(
InvenTree.conversion.convert_physical_value( self.data, self.template.units
self.data, self.template.units )
) except ValidationError as e:
except ValidationError as e: raise ValidationError({'data': e.message})
raise ValidationError({'data': e.message})
# Validate the parameter data against the template choices # Validate the parameter data against the template choices
if choices := self.template.get_choices(): if choices := self.template.get_choices():
@ -4205,9 +4203,8 @@ class BomItem(
# Check if the part was changed # Check if the part was changed
deltas = self.get_field_deltas() deltas = self.get_field_deltas()
if 'part' in deltas: if 'part' in deltas and (old_part := deltas['part'].get('old', None)):
if old_part := deltas['part'].get('old', None): self.check_part_lock(old_part)
self.check_part_lock(old_part)
# Update the 'validated' field based on checksum calculation # Update the 'validated' field based on checksum calculation
self.validated = self.is_line_valid self.validated = self.is_line_valid
@ -4344,7 +4341,7 @@ class BomItem(
- allow_variants - allow_variants
""" """
# Seed the hash with the ID of this BOM item # Seed the hash with the ID of this BOM item
result_hash = hashlib.md5(''.encode()) result_hash = hashlib.md5(b'')
# The following components are used to calculate the checksum # The following components are used to calculate the checksum
components = [ components = [
@ -4438,8 +4435,7 @@ class BomItem(
try: try:
ovg = float(overage) ovg = float(overage)
if ovg < 0: ovg = max(ovg, 0)
ovg = 0
return ovg return ovg
except ValueError: except ValueError:
@ -4451,10 +4447,8 @@ class BomItem(
try: try:
percent = float(overage) / 100.0 percent = float(overage) / 100.0
if percent > 1: percent = min(percent, 1)
percent = 1 percent = max(percent, 0)
if percent < 0:
percent = 0
# Must be represented as a decimal # Must be represented as a decimal
percent = Decimal(percent) percent = Decimal(percent)

View File

@ -2000,9 +2000,8 @@ class BomImportExtractSerializer(InvenTree.serializers.DataFileExtractSerializer
if part is None: if part is None:
row['errors']['part'] = _('No matching part found') row['errors']['part'] = _('No matching part found')
else: elif not part.component:
if not part.component: row['errors']['part'] = _('Part is not designated as a component')
row['errors']['part'] = _('Part is not designated as a component')
# Update the 'part' value in the row # Update the 'part' value in the row
row['part'] = part.pk if part is not None else None row['part'] = part.pk if part is not None else None

View File

@ -1618,7 +1618,7 @@ class PartDetailTests(PartAPITestBase):
# Try to upload a non-image file # Try to upload a non-image file
test_path = BASE_DIR / '_testfolder' / 'dummy_image' test_path = BASE_DIR / '_testfolder' / 'dummy_image'
with open(f'{test_path}.txt', 'w') as dummy_image: with open(f'{test_path}.txt', 'w', encoding='utf-8') as dummy_image:
dummy_image.write('hello world') dummy_image.write('hello world')
with open(f'{test_path}.txt', 'rb') as dummy_image: with open(f'{test_path}.txt', 'rb') as dummy_image:

View File

@ -49,7 +49,7 @@ class BomExportTest(InvenTreeTestCase):
with open(filename, 'wb') as f: with open(filename, 'wb') as f:
f.write(response.getvalue()) f.write(response.getvalue())
with open(filename, 'r') as f: with open(filename, encoding='utf-8') as f:
reader = csv.reader(f, delimiter=',') reader = csv.reader(f, delimiter=',')
for line in reader: for line in reader:
@ -96,7 +96,7 @@ class BomExportTest(InvenTreeTestCase):
f.write(response.getvalue()) f.write(response.getvalue())
# Read the file # Read the file
with open(filename, 'r') as f: with open(filename, encoding='utf-8') as f:
reader = csv.reader(f, delimiter=',') reader = csv.reader(f, delimiter=',')
for line in reader: for line in reader:

View File

@ -419,7 +419,7 @@ class PartParameterTest(InvenTreeAPITestCase):
response = self.get( response = self.get(
url, url,
{'ordering': 'parameter_{pk}'.format(pk=template.pk), 'parameters': 'true'}, {'ordering': f'parameter_{template.pk}', 'parameters': 'true'},
expected_code=200, expected_code=200,
) )
@ -436,10 +436,7 @@ class PartParameterTest(InvenTreeAPITestCase):
# Next, check reverse ordering # Next, check reverse ordering
response = self.get( response = self.get(
url, url,
{ {'ordering': f'-parameter_{template.pk}', 'parameters': 'true'},
'ordering': '-parameter_{pk}'.format(pk=template.pk),
'parameters': 'true',
},
expected_code=200, expected_code=200,
) )

View File

@ -111,7 +111,6 @@ class PartPricingTests(InvenTreeTestCase):
def test_invalid_rate(self): def test_invalid_rate(self):
"""Ensure that conversion behaves properly with missing rates.""" """Ensure that conversion behaves properly with missing rates."""
...
def test_simple(self): def test_simple(self):
"""Tests for hard-coded values.""" """Tests for hard-coded values."""

View File

@ -180,9 +180,9 @@ class PartImport(FileManagementFormView):
if idx in self.file_manager.OPTIONAL_MATCH_HEADERS: if idx in self.file_manager.OPTIONAL_MATCH_HEADERS:
try: try:
exact_match = self.allowed_items[idx].get(**{ exact_match = self.allowed_items[idx].get(
a: data for a in self.matches[idx] **dict.fromkeys(self.matches[idx], data)
}) )
except ( except (
ValueError, ValueError,
self.allowed_items[idx].model.DoesNotExist, self.allowed_items[idx].model.DoesNotExist,
@ -414,7 +414,7 @@ class PartDetailFromIPN(PartDetail):
if not self.object: if not self.object:
return HttpResponseRedirect(reverse('part-index')) return HttpResponseRedirect(reverse('part-index'))
return super(PartDetailFromIPN, self).get(request, *args, **kwargs) return super().get(request, *args, **kwargs)
class PartImageSelect(AjaxUpdateView): class PartImageSelect(AjaxUpdateView):

View File

@ -5,8 +5,8 @@ from .plugin import InvenTreePlugin
from .registry import registry from .registry import registry
__all__ = [ __all__ = [
'registry',
'InvenTreePlugin', 'InvenTreePlugin',
'MixinNotImplementedError',
'MixinImplementationError', 'MixinImplementationError',
'MixinNotImplementedError',
'registry',
] ]

View File

@ -63,7 +63,7 @@ class PluginFilter(rest_filters.FilterSet):
match = True match = True
for mixin in mixins: for mixin in mixins:
if mixin not in result.mixins().keys(): if mixin not in result.mixins():
match = False match = False
break break

View File

@ -335,9 +335,8 @@ class BarcodePOAllocate(BarcodeView):
if not part and not supplier_part and not manufacturer_part: if not part and not supplier_part and not manufacturer_part:
raise ValidationError({'error': _('No matching part data found')}) raise ValidationError({'error': _('No matching part data found')})
if part: if part and (part_id := part.get('pk', None)):
if part_id := part.get('pk', None): supplier_parts = supplier_parts.filter(part__pk=part_id)
supplier_parts = supplier_parts.filter(part__pk=part_id)
if supplier_part: if supplier_part:
if supplier_part_id := supplier_part.get('pk', None): if supplier_part_id := supplier_part.get('pk', None):
@ -594,7 +593,7 @@ class BarcodeSOAllocate(BarcodeView):
raise ValidationError(response) raise ValidationError(response)
# If we have sufficient information, we can allocate the stock item # If we have sufficient information, we can allocate the stock item
if all((x is not None for x in [line_item, sales_order, shipment, quantity])): if all(x is not None for x in [line_item, sales_order, shipment, quantity]):
order.models.SalesOrderAllocation.objects.create( order.models.SalesOrderAllocation.objects.create(
line=line_item, shipment=shipment, item=stock_item, quantity=quantity line=line_item, shipment=shipment, item=stock_item, quantity=quantity
) )

View File

@ -384,7 +384,9 @@ class SupplierBarcodeMixin(BarcodeMixin):
return orders_intersection if orders_intersection else orders_union return orders_intersection if orders_intersection else orders_union
@staticmethod @staticmethod
def get_supplier_parts(sku: str = None, supplier: Company = None, mpn: str = None): def get_supplier_parts(
sku: str | None = None, supplier: Company = None, mpn: str | None = None
):
"""Get a supplier part from SKU or by supplier and MPN.""" """Get a supplier part from SKU or by supplier and MPN."""
if not (sku or supplier or mpn): if not (sku or supplier or mpn):
return SupplierPart.objects.none() return SupplierPart.objects.none()
@ -420,10 +422,10 @@ class SupplierBarcodeMixin(BarcodeMixin):
def receive_purchase_order_item( def receive_purchase_order_item(
supplier_part: SupplierPart, supplier_part: SupplierPart,
user: User, user: User,
quantity: Decimal | str = None, quantity: Decimal | str | None = None,
purchase_order: PurchaseOrder = None, purchase_order: PurchaseOrder = None,
location: StockLocation = None, location: StockLocation = None,
barcode: str = None, barcode: str | None = None,
) -> dict: ) -> dict:
"""Try to receive a purchase order item. """Try to receive a purchase order item.
@ -471,9 +473,9 @@ class SupplierBarcodeMixin(BarcodeMixin):
# 2. check if it's defined on the part # 2. check if it's defined on the part
# 3. check if there's 1 or 0 stock locations defined in InvenTree # 3. check if there's 1 or 0 stock locations defined in InvenTree
# -> assume all stock is going into that location (or no location) # -> assume all stock is going into that location (or no location)
if location := line_item.destination: if (location := line_item.destination) or (
pass location := supplier_part.part.get_default_location()
elif location := supplier_part.part.get_default_location(): ):
pass pass
elif StockLocation.objects.count() <= 1: elif StockLocation.objects.count() <= 1:
if not (location := StockLocation.objects.first()): if not (location := StockLocation.objects.first()):

View File

@ -41,7 +41,7 @@ class BarcodeGenerateSerializer(serializers.Serializer):
plugin.base.barcodes.helper.get_supported_barcode_models_map() plugin.base.barcodes.helper.get_supported_barcode_models_map()
) )
if model not in supported_models.keys(): if model not in supported_models:
raise ValidationError(_('Model is not supported')) raise ValidationError(_('Model is not supported'))
return model return model

View File

@ -147,10 +147,7 @@ def allow_table_event(table_name):
'part_partstocktakereport', 'part_partstocktakereport',
] ]
if table_name in ignore_tables: return not table_name in ignore_tables
return False
return True
@receiver(post_save) @receiver(post_save)

View File

@ -3,6 +3,7 @@
import json as json_pkg import json as json_pkg
import logging import logging
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional
import requests import requests
@ -117,10 +118,10 @@ class APICallMixin:
self, self,
endpoint: str, endpoint: str,
method: str = 'GET', method: str = 'GET',
url_args: dict = None, url_args: Optional[dict] = None,
data=None, data=None,
json=None, json=None,
headers: dict = None, headers: Optional[dict] = None,
simple_response: bool = True, simple_response: bool = True,
endpoint_is_url: bool = False, endpoint_is_url: bool = False,
): ):

View File

@ -33,7 +33,6 @@ class ReportMixin:
request: The request object which initiated the report generation request: The request object which initiated the report generation
context: The context dictionary to add to context: The context dictionary to add to
""" """
pass
def add_label_context(self, label_instance, model_instance, request, context): def add_label_context(self, label_instance, model_instance, request, context):
"""Add extra context to the provided label instance. """Add extra context to the provided label instance.
@ -46,7 +45,6 @@ class ReportMixin:
request: The request object which initiated the label generation request: The request object which initiated the label generation
context: The context dictionary to add to context: The context dictionary to add to
""" """
pass
def report_callback(self, template, instance, report, request): def report_callback(self, template, instance, report, request):
"""Callback function called after a report is generated. """Callback function called after a report is generated.
@ -59,4 +57,3 @@ class ReportMixin:
The default implementation does nothing. The default implementation does nothing.
""" """
pass

View File

@ -148,7 +148,7 @@ class ScheduleMixin:
def get_task_names(self): def get_task_names(self):
"""All defined task names.""" """All defined task names."""
# Returns a list of all task names associated with this plugin instance # Returns a list of all task names associated with this plugin instance
return [self.get_task_name(key) for key in self.scheduled_tasks.keys()] return [self.get_task_name(key) for key in self.scheduled_tasks]
def register_tasks(self): def register_tasks(self):
"""Register the tasks with the database.""" """Register the tasks with the database."""
@ -200,7 +200,7 @@ class ScheduleMixin:
try: try:
from django_q.models import Schedule from django_q.models import Schedule
for key, _ in self.scheduled_tasks.items(): for key in self.scheduled_tasks:
task_name = self.get_task_name(key) task_name = self.get_task_name(key)
try: try:

View File

@ -15,8 +15,6 @@ else:
class SettingsKeyType: class SettingsKeyType:
"""Dummy class, so that python throws no error.""" """Dummy class, so that python throws no error."""
pass
class SettingsMixin: class SettingsMixin:
"""Mixin that enables global settings for the plugin.""" """Mixin that enables global settings for the plugin."""

View File

@ -1,5 +1,7 @@
"""Validation mixin class definition.""" """Validation mixin class definition."""
from typing import Optional
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models import Model from django.db.models import Model
@ -67,7 +69,9 @@ class ValidationMixin:
""" """
return None return None
def validate_model_instance(self, instance: Model, deltas: dict = None) -> None: def validate_model_instance(
self, instance: Model, deltas: Optional[dict] = None
) -> None:
"""Run custom validation on a database model instance. """Run custom validation on a database model instance.
This method is called when a model instance is being validated. This method is called when a model instance is being validated.
@ -218,4 +222,3 @@ class ValidationMixin:
Raises: Raises:
ValidationError: If the proposed parameter value is objectionable ValidationError: If the proposed parameter value is objectionable
""" """
pass

Some files were not shown because too many files have changed in this diff Show More