Enable more Ruff rules (#7930)

* bump version

* Add more checks

* add simplify rules

* Add RUF rules

* small perf imrpovements

* pylint checks

* more style fixes

* fix a number of A002 cases

* fix A001 cases

* disable unsafe fixes

* remove unneeded branches
fixes SIM102

* re-enable .keys for specific case

* Revert "remove unneeded branches"

This reverts commit f74d41bc07.

* fix reference
This commit is contained in:
Matthias Mair 2024-08-27 01:04:55 +02:00 committed by GitHub
parent bcbbae0a18
commit 1634258783
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
127 changed files with 526 additions and 740 deletions

View File

@ -71,7 +71,7 @@ def check_prohibited_tags(data):
for filename in pathlib.Path(js_i18n_dir).rglob('*.js'):
print(f"Checking file 'translated/{os.path.basename(filename)}':")
with open(filename, 'r') as js_file:
with open(filename, encoding='utf-8') as js_file:
data = js_file.readlines()
errors += check_invalid_tag(data)
@ -81,7 +81,7 @@ for filename in pathlib.Path(js_dynamic_dir).rglob('*.js'):
print(f"Checking file 'dynamic/{os.path.basename(filename)}':")
# Check that the 'dynamic' files do not contains any translated strings
with open(filename, 'r') as js_file:
with open(filename, encoding='utf-8') as js_file:
data = js_file.readlines()
invalid_tags = ['blocktrans', 'blocktranslate', 'trans', 'translate']

View File

@ -20,9 +20,9 @@ for line in str(out.decode()).split('\n'):
if len(migrations) == 0:
sys.exit(0)
print('There are {n} unstaged migration files:'.format(n=len(migrations)))
print(f'There are {len(migrations)} unstaged migration files:')
for m in migrations:
print(' - {m}'.format(m=m))
print(f' - {m}')
sys.exit(len(migrations))

View File

@ -89,7 +89,7 @@ def check_version_number(version_string, allow_duplicate=False):
if release > version_tuple:
highest_release = False
print(f'Found newer release: {str(release)}')
print(f'Found newer release: {release!s}')
return highest_release
@ -134,7 +134,7 @@ if __name__ == '__main__':
version = None
with open(version_file, 'r') as f:
with open(version_file, encoding='utf-8') as f:
text = f.read()
# Extract the InvenTree software version
@ -175,10 +175,7 @@ if __name__ == '__main__':
print(f"Version number '{version}' does not match tag '{version_tag}'")
sys.exit
if highest_release:
docker_tags = [version_tag, 'stable']
else:
docker_tags = [version_tag]
docker_tags = [version_tag, 'stable'] if highest_release else [version_tag]
elif GITHUB_REF_TYPE == 'branch':
# Otherwise we know we are targeting the 'master' branch
@ -202,7 +199,7 @@ if __name__ == '__main__':
target_repos = [REPO.lower(), f'ghcr.io/{REPO.lower()}']
# Ref: https://getridbug.com/python/how-to-set-environment-variables-in-github-actions-using-python/
with open(os.getenv('GITHUB_ENV'), 'a') as env_file:
with open(os.getenv('GITHUB_ENV'), 'a', encoding='utf-8') as env_file:
# Construct tag string
tag_list = [[f'{r}:{t}' for t in docker_tags] for r in target_repos]
tags = ','.join(itertools.chain(*tag_list))

View File

@ -17,17 +17,18 @@ repos:
- id: check-yaml
- id: mixed-line-ending
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.1
rev: v0.6.1
hooks:
- id: ruff-format
args: [--preview]
- id: ruff
args: [
--fix,
# --unsafe-fixes,
--preview
]
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.2.13
rev: 0.2.37
hooks:
- id: pip-compile
name: pip-compile requirements-dev.in
@ -77,7 +78,7 @@ repos:
- "prettier@^2.4.1"
- "@trivago/prettier-plugin-sort-imports"
- repo: https://github.com/pre-commit/mirrors-eslint
rev: "v9.6.0"
rev: "v9.9.0"
hooks:
- id: eslint
additional_dependencies:

View File

@ -10,7 +10,7 @@ tld = os.path.abspath(os.path.join(here, '..'))
config_file = os.path.join(tld, 'mkdocs.yml')
with open(config_file, 'r') as f:
with open(config_file, encoding='utf-8') as f:
data = yaml.load(f, yaml.BaseLoader)
assert data['strict'] == 'true'

View File

@ -57,7 +57,7 @@ def fetch_rtd_versions():
versions = sorted(versions, key=lambda x: StrictVersion(x['version']), reverse=True)
# Add "latest" version first
if not any((x['title'] == 'latest' for x in versions)):
if not any(x['title'] == 'latest' for x in versions):
versions.insert(
0,
{
@ -70,7 +70,7 @@ def fetch_rtd_versions():
# Ensure we have the 'latest' version
current_version = os.environ.get('READTHEDOCS_VERSION', None)
if current_version and not any((x['title'] == current_version for x in versions)):
if current_version and not any(x['title'] == current_version for x in versions):
versions.append({
'version': current_version,
'title': current_version,
@ -82,7 +82,7 @@ def fetch_rtd_versions():
print('Discovered the following versions:')
print(versions)
with open(output_filename, 'w') as file:
with open(output_filename, 'w', encoding='utf-8') as file:
json.dump(versions, file, indent=2)
@ -100,7 +100,7 @@ def get_release_data():
# Release information has been cached to file
print("Loading release information from 'releases.json'")
with open(json_file) as f:
with open(json_file, encoding='utf-8') as f:
return json.loads(f.read())
# Download release information via the GitHub API
@ -127,7 +127,7 @@ def get_release_data():
page += 1
# Cache these results to file
with open(json_file, 'w') as f:
with open(json_file, 'w', encoding='utf-8') as f:
print("Saving release information to 'releases.json'")
f.write(json.dumps(releases))
@ -173,7 +173,7 @@ def on_config(config, *args, **kwargs):
# Add *all* readthedocs related keys
readthedocs = {}
for key in os.environ.keys():
for key in os.environ:
if key.startswith('READTHEDOCS_'):
k = key.replace('READTHEDOCS_', '').lower()
readthedocs[k] = os.environ[key]

View File

@ -46,7 +46,7 @@ def top_level_path(path: str) -> str:
key = path.split('/')[1]
if key in SPECIAL_PATHS.keys():
if key in SPECIAL_PATHS:
return key
return GENERAL_PATH
@ -54,9 +54,7 @@ def top_level_path(path: str) -> str:
def generate_schema_file(key: str) -> None:
"""Generate a schema file for the provided key."""
description = (
SPECIAL_PATHS[key] if key in SPECIAL_PATHS else 'General API Endpoints'
)
description = SPECIAL_PATHS.get(key, 'General API Endpoints')
output = f"""
---
@ -75,7 +73,7 @@ def generate_schema_file(key: str) -> None:
print('Writing schema file to:', output_file)
with open(output_file, 'w') as f:
with open(output_file, 'w', encoding='utf-8') as f:
f.write(output)
@ -121,7 +119,7 @@ def generate_index_file(version: str):
print('Writing index file to:', output_file)
with open(output_file, 'w') as f:
with open(output_file, 'w', encoding='utf-8') as f:
f.write(output)
@ -173,7 +171,7 @@ def parse_api_file(filename: str):
The intent is to make the API schema easier to peruse on the documentation.
"""
with open(filename, 'r') as f:
with open(filename, encoding='utf-8') as f:
data = yaml.safe_load(f)
paths = data['paths']
@ -213,7 +211,7 @@ def parse_api_file(filename: str):
output_file = os.path.abspath(output_file)
with open(output_file, 'w') as f:
with open(output_file, 'w', encoding='utf-8') as f:
yaml.dump(output, f)
# Generate a markdown file for the schema

View File

@ -16,7 +16,7 @@ global USER_SETTINGS
here = os.path.dirname(__file__)
settings_file = os.path.join(here, 'inventree_settings.json')
with open(settings_file, 'r') as sf:
with open(settings_file, encoding='utf-8') as sf:
settings = json.load(sf)
GLOBAL_SETTINGS = settings['global']
@ -27,7 +27,7 @@ def get_repo_url(raw=False):
"""Return the repository URL for the current project."""
mkdocs_yml = os.path.join(os.path.dirname(__file__), 'mkdocs.yml')
with open(mkdocs_yml, 'r') as f:
with open(mkdocs_yml, encoding='utf-8') as f:
mkdocs_config = yaml.safe_load(f)
repo_name = mkdocs_config['repo_name']
@ -47,7 +47,7 @@ def check_link(url) -> bool:
# Keep a local cache file of URLs we have already checked
if os.path.exists(CACHE_FILE):
with open(CACHE_FILE, 'r') as f:
with open(CACHE_FILE, encoding='utf-8') as f:
cache = f.read().splitlines()
if url in cache:
@ -59,7 +59,7 @@ def check_link(url) -> bool:
response = requests.head(url, timeout=5000)
if response.status_code == 200:
# Update the cache file
with open(CACHE_FILE, 'a') as f:
with open(CACHE_FILE, 'a', encoding='utf-8') as f:
f.write(f'{url}\n')
return True
@ -177,7 +177,7 @@ def define_env(env):
assert subprocess.call(command, shell=True) == 0
with open(output, 'r') as f:
with open(output, encoding='utf-8') as f:
content = f.read()
return content
@ -200,12 +200,13 @@ def define_env(env):
return assets
@env.macro
def includefile(filename: str, title: str, format: str = ''):
def includefile(filename: str, title: str, fmt: str = ''):
"""Include a file in the documentation, in a 'collapse' block.
Arguments:
- filename: The name of the file to include (relative to the top-level directory)
- title:
- fmt:
"""
here = os.path.dirname(__file__)
path = os.path.join(here, '..', filename)
@ -214,11 +215,11 @@ def define_env(env):
if not os.path.exists(path):
raise FileNotFoundError(f'Required file {path} does not exist.')
with open(path, 'r') as f:
with open(path, encoding='utf-8') as f:
content = f.read()
data = f'??? abstract "{title}"\n\n'
data += f' ```{format}\n'
data += f' ```{fmt}\n'
data += textwrap.indent(content, ' ')
data += '\n\n'
data += ' ```\n\n'
@ -233,15 +234,15 @@ def define_env(env):
'src', 'backend', 'InvenTree', 'report', 'templates', filename
)
return includefile(fn, f'Template: {base}', format='html')
return includefile(fn, f'Template: {base}', fmt='html')
@env.macro
def rendersetting(setting: dict):
"""Render a provided setting object into a table row."""
name = setting['name']
description = setting['description']
default = setting.get('default', None)
units = setting.get('units', None)
default = setting.get('default')
units = setting.get('units')
return f'| {name} | {description} | {default if default is not None else ""} | {units if units is not None else ""} |'

View File

@ -20,13 +20,30 @@ src = ["src/backend/InvenTree"]
"__init__.py" = ["D104"]
[tool.ruff.lint]
select = ["A", "B", "C4", "D", "I", "N", "F"]
select = ["A", "B", "C", "C4", "D", "F", "I", "N", "SIM", "PIE", "PLE", "PLW", "RUF", "UP", "W"]
# Things that should be enabled in the future:
# - LOG
# - DJ # for Django stuff
# - S # for security stuff (bandit)
ignore = [
"PLE1205",
# - PLE1205 - Too many arguments for logging format string
"PLW2901",
# - PLW2901 - Outer {outer_kind} variable {name} overwritten by inner {inner_kind} target
"PLW0602","PLW0603","PLW0604", # global variable things
"RUF015",
# - RUF015 - Prefer next({iterable}) over single element slice
"RUF012",
# - RUF012 - Mutable class attributes should be annotated with typing.ClassVar
"SIM117",
# - SIM117 - Use a single with statement with multiple contexts instead of nested with statements
"SIM102",
# - SIM102 - Use a single if statement instead of nested if statements
"SIM105",
# - SIM105 - Use contextlib.suppress({exception}) instead of try-except-pass
"C901",
# - C901 - function is too complex
"N999",
# - N802 - function name should be lowercase
"N802",
@ -42,7 +59,7 @@ ignore = [
"B904",
# Remove fast
"A001", "A002","A003","B018"
"A002", "B018"
]
[tool.ruff.lint.pydocstyle]

View File

@ -104,13 +104,15 @@ class InvenTreeResource(ModelResource):
attribute = getattr(field, 'attribute', field_name)
# Check if the associated database field is a non-nullable string
if db_field := db_fields.get(attribute):
if (
(db_field := db_fields.get(attribute))
and (
isinstance(db_field, CharField)
and db_field.blank
and not db_field.null
)
and column not in self.CONVERT_NULL_FIELDS
):
if column not in self.CONVERT_NULL_FIELDS:
self.CONVERT_NULL_FIELDS.append(column)
return super().before_import(dataset, using_transactions, dry_run, **kwargs)

View File

@ -26,7 +26,7 @@ from part.models import Part
from plugin.serializers import MetadataSerializer
from users.models import ApiToken
from .email import is_email_configured
from .helpers_email import is_email_configured
from .mixins import ListAPI, RetrieveUpdateAPI
from .status import check_system_health, is_worker_running
from .version import inventreeApiText
@ -77,7 +77,7 @@ class LicenseView(APIView):
# Ensure we do not have any duplicate 'name' values in the list
for entry in data:
name = None
for key in entry.keys():
for key in entry:
if key.lower() == 'name':
name = entry[key]
break
@ -321,7 +321,6 @@ class BulkDeleteMixin:
Raises:
ValidationError: If the deletion should not proceed
"""
pass
def filter_delete_queryset(self, queryset, request):
"""Provide custom filtering for the queryset *before* it is deleted.
@ -398,8 +397,6 @@ class BulkDeleteMixin:
class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI):
"""Custom API endpoint which provides BulkDelete functionality in addition to List and Create."""
...
class APISearchViewSerializer(serializers.Serializer):
"""Serializer for the APISearchView."""

View File

@ -125,7 +125,7 @@ class InvenTreeConfig(AppConfig):
for task in tasks:
ref_name = f'{task.func.__module__}.{task.func.__name__}'
if ref_name in existing_tasks.keys():
if ref_name in existing_tasks:
# This task already exists - update the details if required
existing_task = existing_tasks[ref_name]

View File

@ -131,7 +131,7 @@ def load_config_data(set_cache: bool = False) -> map:
cfg_file = get_config_file()
with open(cfg_file, 'r') as cfg:
with open(cfg_file, encoding='utf-8') as cfg:
data = yaml.safe_load(cfg)
# Set the cache if requested

View File

@ -1,8 +1,6 @@
# -*- coding: utf-8 -*-
"""Provides extra global data to all templates."""
import InvenTree.email
import InvenTree.helpers_email
import InvenTree.ready
import InvenTree.status
from generic.states.custom import get_custom_classes
@ -27,7 +25,7 @@ def health_status(request):
status = {
'django_q_running': InvenTree.status.is_worker_running(),
'email_configured': InvenTree.email.is_email_configured(),
'email_configured': InvenTree.helpers_email.is_email_configured(),
}
# The following keys are required to denote system health
@ -75,7 +73,7 @@ def user_roles(request):
roles = {}
for role in RuleSet.get_ruleset_models().keys():
for role in RuleSet.get_ruleset_models():
permissions = {}
for perm in ['view', 'add', 'change', 'delete']:

View File

@ -2,6 +2,7 @@
import logging
import re
from typing import Optional
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
@ -95,7 +96,7 @@ def from_engineering_notation(value):
"""
value = str(value).strip()
pattern = '(\d+)([a-zA-Z]+)(\d+)(.*)'
pattern = r'(\d+)([a-zA-Z]+)(\d+)(.*)'
if match := re.match(pattern, value):
left, prefix, right, suffix = match.groups()
@ -133,7 +134,7 @@ def convert_value(value, unit):
return value
def convert_physical_value(value: str, unit: str = None, strip_units=True):
def convert_physical_value(value: str, unit: Optional[str] = None, strip_units=True):
"""Validate that the provided value is a valid physical quantity.
Arguments:
@ -245,8 +246,4 @@ def is_dimensionless(value):
if value.units == ureg.dimensionless:
return True
if value.to_base_units().units == ureg.dimensionless:
return True
# At this point, the value is not dimensionless
return False
return value.to_base_units().units == ureg.dimensionless

View File

@ -1,7 +1,6 @@
"""Custom exception handling for the DRF API."""
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import logging
import sys
@ -41,10 +40,7 @@ def log_error(path, error_name=None, error_info=None, error_data=None):
if kind in settings.IGNORED_ERRORS:
return
if error_name:
kind = error_name
else:
kind = getattr(kind, '__name__', 'Unknown Error')
kind = error_name or getattr(kind, '__name__', 'Unknown Error')
if error_info:
info = error_info

View File

@ -31,10 +31,7 @@ class InvenTreeExchange(SimpleExchangeBackend):
# Find the selected exchange rate plugin
slug = get_global_setting('CURRENCY_UPDATE_PLUGIN', create=False)
if slug:
plugin = registry.get_plugin(slug)
else:
plugin = None
plugin = registry.get_plugin(slug) if slug else None
if not plugin:
# Find the first active currency exchange plugin

View File

@ -93,8 +93,7 @@ class InvenTreeModelMoneyField(ModelMoneyField):
allow_negative = kwargs.pop('allow_negative', False)
# If no validators are provided, add some "standard" ones
if len(validators) == 0:
if not allow_negative:
if len(validators) == 0 and not allow_negative:
validators.append(MinMoneyValidator(0))
kwargs['validators'] = validators

View File

@ -118,10 +118,7 @@ class InvenTreeOrderingFilter(filters.OrderingFilter):
field = field[1:]
# Are aliases defined for this field?
if field in aliases:
alias = aliases[field]
else:
alias = field
alias = aliases.get(field, field)
"""
Potentially, a single field could be "aliased" to multiple field,

View File

@ -2,6 +2,7 @@
import re
import string
from typing import Optional
from django.conf import settings
from django.utils import translation
@ -106,10 +107,7 @@ def construct_format_regex(fmt_string: str) -> str:
# Add a named capture group for the format entry
if name:
# Check if integer values are required
if _fmt.endswith('d'):
c = '\d'
else:
c = '.'
c = '\\d' if _fmt.endswith('d') else '.'
# Specify width
# TODO: Introspect required width
@ -160,7 +158,7 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str:
"""
info = parse_format_string(fmt_string)
if name not in info.keys():
if name not in info:
raise NameError(_(f"Value '{name}' does not appear in pattern format"))
# Construct a regular expression for matching against the provided format string
@ -182,8 +180,8 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str:
def format_money(
money: Money,
decimal_places: int = None,
format: str = None,
decimal_places: Optional[int] = None,
fmt: Optional[str] = None,
include_symbol: bool = True,
) -> str:
"""Format money object according to the currently set local.
@ -191,7 +189,7 @@ def format_money(
Args:
money (Money): The money object to format
decimal_places (int): Number of decimal places to use
format (str): Format pattern according LDML / the babel format pattern syntax (https://babel.pocoo.org/en/latest/numbers.html)
fmt (str): Format pattern according LDML / the babel format pattern syntax (https://babel.pocoo.org/en/latest/numbers.html)
Returns:
str: The formatted string
@ -199,10 +197,10 @@ def format_money(
Raises:
ValueError: format string is incorrectly specified
"""
language = None and translation.get_language() or settings.LANGUAGE_CODE
language = (None) or settings.LANGUAGE_CODE
locale = Locale.parse(translation.to_locale(language))
if format:
pattern = parse_pattern(format)
if fmt:
pattern = parse_pattern(fmt)
else:
pattern = locale.currency_formats['standard']
if decimal_places is not None:

View File

@ -266,8 +266,7 @@ class RegistratonMixin:
raise forms.ValidationError(
_('The provided primary email address is not valid.')
)
else:
if split_email[1] == option[1:]:
elif split_email[1] == option[1:]:
return super().clean_email(email)
logger.info('The provided email domain for %s is not approved', email)

View File

@ -9,15 +9,15 @@ import os.path
import re
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import TypeVar, Union
from typing import Optional, TypeVar, Union
from wsgiref.util import FileWrapper
import django.utils.timezone as timezone
from django.conf import settings
from django.contrib.staticfiles.storage import StaticFilesStorage
from django.core.exceptions import FieldError, ValidationError
from django.core.files.storage import Storage, default_storage
from django.http import StreamingHttpResponse
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
import pytz
@ -97,10 +97,7 @@ def generateTestKey(test_name: str) -> str:
if char.isidentifier():
return True
if char.isalnum():
return True
return False
return bool(char.isalnum())
# Remove any characters that cannot be used to represent a variable
key = ''.join([c for c in key if valid_char(c)])
@ -490,10 +487,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
except ValueError:
raise ValidationError([_('Invalid quantity provided')])
if input_string:
input_string = str(input_string).strip()
else:
input_string = ''
input_string = str(input_string).strip() if input_string else ''
if len(input_string) == 0:
raise ValidationError([_('Empty serial number string')])
@ -800,10 +794,10 @@ def remove_non_printable_characters(
if remove_unicode:
# Remove Unicode control characters
if remove_newline:
cleaned = regex.sub('[^\P{C}]+', '', cleaned)
cleaned = regex.sub(r'[^\P{C}]+', '', cleaned)
else:
# Use 'negative-lookahead' to exclude newline character
cleaned = regex.sub('(?![\x0a])[^\P{C}]+', '', cleaned)
cleaned = regex.sub('(?![\x0a])[^\\P{C}]+', '', cleaned)
return cleaned
@ -827,7 +821,7 @@ def hash_barcode(barcode_data):
def hash_file(filename: Union[str, Path], storage: Union[Storage, None] = None):
"""Return the MD5 hash of a file."""
content = (
open(filename, 'rb').read()
open(filename, 'rb').read() # noqa: SIM115
if storage is None
else storage.open(str(filename), 'rb').read()
)
@ -865,7 +859,7 @@ def server_timezone() -> str:
return settings.TIME_ZONE
def to_local_time(time, target_tz: str = None):
def to_local_time(time, target_tz: Optional[str] = None):
"""Convert the provided time object to the local timezone.
Arguments:

View File

@ -62,7 +62,7 @@ def send_email(subject, body, recipients, from_email=None, html_message=None):
# If we are importing data, don't send emails
return
if not InvenTree.email.is_email_configured() and not settings.TESTING:
if not is_email_configured() and not settings.TESTING:
# Email is not configured / enabled
return

View File

@ -114,10 +114,7 @@ def download_image_from_url(remote_url, timeout=2.5):
# Add user specified user-agent to request (if specified)
user_agent = get_global_setting('INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT')
if user_agent:
headers = {'User-Agent': user_agent}
else:
headers = None
headers = {'User-Agent': user_agent} if user_agent else None
try:
response = requests.get(
@ -130,7 +127,7 @@ def download_image_from_url(remote_url, timeout=2.5):
# Throw an error if anything goes wrong
response.raise_for_status()
except requests.exceptions.ConnectionError as exc:
raise Exception(_('Connection error') + f': {str(exc)}')
raise Exception(_('Connection error') + f': {exc!s}')
except requests.exceptions.Timeout as exc:
raise exc
except requests.exceptions.HTTPError:
@ -138,7 +135,7 @@ def download_image_from_url(remote_url, timeout=2.5):
_('Server responded with invalid status code') + f': {response.status_code}'
)
except Exception as exc:
raise Exception(_('Exception occurred') + f': {str(exc)}')
raise Exception(_('Exception occurred') + f': {exc!s}')
if response.status_code != 200:
raise Exception(

View File

@ -47,7 +47,7 @@ class Command(BaseCommand):
filename = kwargs.get('filename', 'inventree_settings.json')
with open(filename, 'w') as f:
with open(filename, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=4)
print(f"Exported InvenTree settings definitions to '{filename}'")

View File

@ -103,14 +103,14 @@ class Command(BaseCommand):
})
self.stdout.write(f'Writing icon map for {len(icons.keys())} icons')
with open(kwargs['output_file'], 'w') as f:
with open(kwargs['output_file'], 'w', encoding='utf-8') as f:
json.dump(icons, f, indent=2)
self.stdout.write(f'Icon map written to {kwargs["output_file"]}')
# Import icon map file
if kwargs['input_file']:
with open(kwargs['input_file'], 'r') as f:
with open(kwargs['input_file'], encoding='utf-8') as f:
icons = json.load(f)
self.stdout.write(f'Loaded icon map for {len(icons.keys())} icons')

View File

@ -19,8 +19,9 @@ def render_file(file_name, source, target, locales, ctx):
target_file = os.path.join(target, locale + '.' + file_name)
with open(target_file, 'w') as localised_file:
with lang_over(locale):
with open(target_file, 'w', encoding='utf-8') as localised_file, lang_over(
locale
):
rendered = render_to_string(os.path.join(source, file_name), ctx)
localised_file.write(rendered)

View File

@ -35,7 +35,7 @@ class Command(BaseCommand):
img_paths.append(x.path)
if len(img_paths) > 0:
if all((os.path.exists(path) for path in img_paths)):
if all(os.path.exists(path) for path in img_paths):
# All images exist - skip further work
return

View File

@ -35,4 +35,4 @@ class Command(BaseCommand):
mfa_user[0].staticdevice_set.all().delete()
# TOTP tokens
mfa_user[0].totpdevice_set.all().delete()
print(f'Removed all MFA methods for user {str(mfa_user[0])}')
print(f'Removed all MFA methods for user {mfa_user[0]!s}')

View File

@ -204,7 +204,7 @@ class InvenTreeMetadata(SimpleMetadata):
# Iterate through simple fields
for name, field in model_fields.fields.items():
if name in serializer_info.keys():
if name in serializer_info:
if name in read_only_fields:
serializer_info[name]['read_only'] = True
@ -236,7 +236,7 @@ class InvenTreeMetadata(SimpleMetadata):
# Iterate through relations
for name, relation in model_fields.relations.items():
if name not in serializer_info.keys():
if name not in serializer_info:
# Skip relation not defined in serializer
continue
@ -307,12 +307,12 @@ class InvenTreeMetadata(SimpleMetadata):
instance_filters = instance.api_instance_filters()
for field_name, field_filters in instance_filters.items():
if field_name not in serializer_info.keys():
if field_name not in serializer_info:
# The field might be missing, but is added later on
# This function seems to get called multiple times?
continue
if 'instance_filters' not in serializer_info[field_name].keys():
if 'instance_filters' not in serializer_info[field_name]:
serializer_info[field_name]['instance_filters'] = {}
for key, value in field_filters.items():

View File

@ -36,7 +36,7 @@ def get_token_from_request(request):
return None
class AuthRequiredMiddleware(object):
class AuthRequiredMiddleware:
"""Check for user to be authenticated."""
def __init__(self, get_response):
@ -92,23 +92,18 @@ class AuthRequiredMiddleware(object):
# Allow static files to be accessed without auth
# Important for e.g. login page
if request.path_info.startswith('/static/'):
authorized = True
# Unauthorized users can access the login page
elif request.path_info.startswith('/accounts/'):
authorized = True
elif (
if (
request.path_info.startswith('/static/')
or request.path_info.startswith('/accounts/')
or (
request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/')
or request.path_info.startswith('/assets/')
or request.path_info == f'/{settings.FRONTEND_URL_BASE}'
)
or self.check_token(request)
):
authorized = True
elif self.check_token(request):
authorized = True
# No authorization was found for the request
if not authorized:
path = request.path_info

View File

@ -128,14 +128,10 @@ class CreateAPI(CleanMixin, generics.CreateAPIView):
class RetrieveAPI(generics.RetrieveAPIView):
"""View for retrieve API."""
pass
class RetrieveUpdateAPI(CleanMixin, generics.RetrieveUpdateAPIView):
"""View for retrieve and update API."""
pass
class CustomDestroyModelMixin:
"""This mixin was created pass the kwargs from the API to the models."""

View File

@ -390,10 +390,7 @@ class ReferenceIndexingMixin(models.Model):
except Exception:
# If anything goes wrong, return the most recent reference
recent = cls.get_most_recent_item()
if recent:
reference = recent.reference
else:
reference = ''
reference = recent.reference if recent else ''
return reference
@ -410,14 +407,14 @@ class ReferenceIndexingMixin(models.Model):
})
# Check that only 'allowed' keys are provided
for key in info.keys():
if key not in ctx.keys():
for key in info:
if key not in ctx:
raise ValidationError({
'value': _('Unknown format key specified') + f": '{key}'"
})
# Check that the 'ref' variable is specified
if 'ref' not in info.keys():
if 'ref' not in info:
raise ValidationError({
'value': _('Missing required format key') + ": 'ref'"
})
@ -859,7 +856,7 @@ class InvenTreeTree(MetadataMixin, PluginValidationMixin, MPTTModel):
Returns:
List of category names from the top level to this category
"""
return self.parentpath + [self]
return [*self.parentpath, self]
def get_path(self):
"""Return a list of element in the item tree.
@ -1099,4 +1096,4 @@ def after_error_logged(sender, instance: Error, created: bool, **kwargs):
except Exception as exc:
"""We do not want to throw an exception while reporting an exception"""
logger.error(exc) # noqa: LOG005
logger.error(exc)

View File

@ -105,8 +105,7 @@ class IsStaffOrReadOnly(permissions.IsAdminUser):
def has_permission(self, request, view):
"""Check if the user is a superuser."""
return bool(
request.user
and request.user.is_staff
(request.user and request.user.is_staff)
or request.method in permissions.SAFE_METHODS
)

View File

@ -11,33 +11,28 @@ def isInTestMode():
def isImportingData():
"""Returns True if the database is currently importing (or exporting) data, e.g. 'loaddata' command is performed."""
return any((x in sys.argv for x in ['flush', 'loaddata', 'dumpdata']))
return any(x in sys.argv for x in ['flush', 'loaddata', 'dumpdata'])
def isRunningMigrations():
"""Return True if the database is currently running migrations."""
return any(
(
x in sys.argv
for x in ['migrate', 'makemigrations', 'showmigrations', 'runmigrations']
)
)
def isRebuildingData():
"""Return true if any of the rebuilding commands are being executed."""
return any(
(
x in sys.argv
for x in ['prerender', 'rebuild_models', 'rebuild_thumbnails', 'rebuild']
)
)
def isRunningBackup():
"""Return true if any of the backup commands are being executed."""
return any(
(
x in sys.argv
for x in [
'backup',
@ -48,7 +43,6 @@ def isRunningBackup():
'mediarestore',
]
)
)
def isInWorkerThread():
@ -64,10 +58,7 @@ def isInServerThread():
if 'runserver' in sys.argv:
return True
if 'gunicorn' in sys.argv[0]:
return True
return False
return 'gunicorn' in sys.argv[0]
def isInMainThread():
@ -128,11 +119,7 @@ def canAppAccessDatabase(
if not allow_plugins:
excluded_commands.extend(['collectplugins'])
for cmd in excluded_commands:
if cmd in sys.argv:
return False
return True
return all(cmd not in sys.argv for cmd in excluded_commands)
def isPluginRegistryLoaded():

View File

@ -89,7 +89,7 @@ class InvenTreeCurrencySerializer(serializers.ChoiceField):
)
if allow_blank:
choices = [('', '---------')] + choices
choices = [('', '---------'), *choices]
kwargs['choices'] = choices
@ -379,7 +379,7 @@ class InvenTreeTaggitSerializer(TaggitSerializer):
tag_object = super().update(instance, validated_data)
for key in to_be_tagged.keys():
for key in to_be_tagged:
# re-add the tagmanager
new_tagobject = tag_object.__class__.objects.get(id=tag_object.id)
setattr(tag_object, key, getattr(new_tagobject, key))
@ -390,8 +390,6 @@ class InvenTreeTaggitSerializer(TaggitSerializer):
class InvenTreeTagModelSerializer(InvenTreeTaggitSerializer, InvenTreeModelSerializer):
"""Combination of InvenTreeTaggitSerializer and InvenTreeModelSerializer."""
pass
class UserSerializer(InvenTreeModelSerializer):
"""Serializer for a User."""
@ -426,14 +424,15 @@ class ExendedUserSerializer(UserSerializer):
class Meta(UserSerializer.Meta):
"""Metaclass defines serializer fields."""
fields = UserSerializer.Meta.fields + [
fields = [
*UserSerializer.Meta.fields,
'groups',
'is_staff',
'is_superuser',
'is_active',
]
read_only_fields = UserSerializer.Meta.read_only_fields + ['groups']
read_only_fields = [*UserSerializer.Meta.read_only_fields, 'groups']
is_staff = serializers.BooleanField(
label=_('Staff'), help_text=_('Does this user have staff permissions')
@ -704,7 +703,6 @@ class DataFileUploadSerializer(serializers.Serializer):
def save(self):
"""Empty overwrite for save."""
...
class DataFileExtractSerializer(serializers.Serializer):
@ -806,8 +804,7 @@ class DataFileExtractSerializer(serializers.Serializer):
required = field.get('required', False)
# Check for missing required columns
if required:
if name not in self.columns:
if required and name not in self.columns:
raise serializers.ValidationError(
_(f"Missing required column: '{name}'")
)
@ -824,7 +821,6 @@ class DataFileExtractSerializer(serializers.Serializer):
def save(self):
"""No "save" action for this serializer."""
pass
class NotesFieldMixin:

View File

@ -281,7 +281,7 @@ QUERYCOUNT = {
'MIN_TIME_TO_LOG': 0.1,
'MIN_QUERY_COUNT_TO_LOG': 25,
},
'IGNORE_REQUEST_PATTERNS': ['^(?!\/(api)?(plugin)?\/).*'],
'IGNORE_REQUEST_PATTERNS': [r'^(?!\/(api)?(plugin)?\/).*'],
'IGNORE_SQL_PATTERNS': [],
'DISPLAY_DUPLICATES': 1,
'RESPONSE_HEADER': 'X-Django-Query-Count',
@ -298,7 +298,7 @@ if (
and INVENTREE_ADMIN_ENABLED
and not TESTING
and get_boolean_setting('INVENTREE_DEBUG_SHELL', 'debug_shell', False)
): # noqa
):
try:
import django_admin_shell # noqa: F401
@ -949,10 +949,7 @@ USE_I18N = True
# Do not use native timezone support in "test" mode
# It generates a *lot* of cruft in the logs
if not TESTING:
USE_TZ = True # pragma: no cover
else:
USE_TZ = False
USE_TZ = bool(not TESTING)
DATE_INPUT_FORMATS = ['%Y-%m-%d']
@ -1065,7 +1062,7 @@ COOKIE_MODE = (
valid_cookie_modes = {'lax': 'Lax', 'strict': 'Strict', 'none': None, 'null': None}
if COOKIE_MODE not in valid_cookie_modes.keys():
if COOKIE_MODE not in valid_cookie_modes:
logger.error('Invalid cookie samesite mode: %s', COOKIE_MODE)
sys.exit(-1)

View File

@ -94,8 +94,7 @@ for name, provider in providers.registry.provider_map.items():
urls = []
if len(adapters) == 1:
urls = handle_oauth2(adapter=adapters[0])
else:
if provider.id in legacy:
elif provider.id in legacy:
logger.warning(
'`%s` is not supported on platform UI. Use `%s` instead.',
provider.id,

View File

@ -1,5 +1,4 @@
"""Provides system status functionality checks."""
# -*- coding: utf-8 -*-
import logging
from datetime import timedelta
@ -10,7 +9,7 @@ from django.utils.translation import gettext_lazy as _
from django_q.models import Success
from django_q.status import Stat
import InvenTree.email
import InvenTree.helpers_email
import InvenTree.ready
logger = logging.getLogger('inventree')
@ -65,7 +64,7 @@ def check_system_health(**kwargs):
result = False
logger.warning(_('Background worker check failed'))
if not InvenTree.email.is_email_configured(): # pragma: no cover
if not InvenTree.helpers_email.is_email_configured(): # pragma: no cover
result = False
logger.warning(_('Email backend not configured'))

View File

@ -9,7 +9,7 @@ import time
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Callable
from typing import Callable, Optional
from django.conf import settings
from django.core.exceptions import AppRegistryNotReady
@ -206,7 +206,7 @@ def offload_task(
raise_warning(f"WARNING: '{taskname}' not offloaded - Function not found")
return False
except Exception as exc:
raise_warning(f"WARNING: '{taskname}' not offloaded due to {str(exc)}")
raise_warning(f"WARNING: '{taskname}' not offloaded due to {exc!s}")
log_error('InvenTree.offload_task')
return False
else:
@ -256,7 +256,7 @@ def offload_task(
_func(*args, **kwargs)
except Exception as exc:
log_error('InvenTree.offload_task')
raise_warning(f"WARNING: '{taskname}' failed due to {str(exc)}")
raise_warning(f"WARNING: '{taskname}' failed due to {exc!s}")
raise exc
# Finally, task either completed successfully or was offloaded
@ -291,7 +291,7 @@ class TaskRegister:
task_list: list[ScheduledTask] = []
def register(self, task, schedule, minutes: int = None):
def register(self, task, schedule, minutes: Optional[int] = None):
"""Register a task with the que."""
self.task_list.append(ScheduledTask(task, schedule, minutes))
@ -299,7 +299,9 @@ class TaskRegister:
tasks = TaskRegister()
def scheduled_task(interval: str, minutes: int = None, tasklist: TaskRegister = None):
def scheduled_task(
interval: str, minutes: Optional[int] = None, tasklist: TaskRegister = None
):
"""Register the given task as a scheduled task.
Example:

View File

@ -80,7 +80,7 @@ def do_translate(parser, token):
"""
bits = token.split_contents()
if len(bits) < 2:
raise TemplateSyntaxError("'%s' takes at least one argument" % bits[0])
raise TemplateSyntaxError(f"'{bits[0]}' takes at least one argument")
message_string = parser.compile_filter(bits[1])
remaining = bits[2:]
@ -95,7 +95,7 @@ def do_translate(parser, token):
option = remaining.pop(0)
if option in seen:
raise TemplateSyntaxError(
"The '%s' option was specified more than once." % option
f"The '{option}' option was specified more than once."
)
elif option == 'noop':
noop = True
@ -104,13 +104,12 @@ def do_translate(parser, token):
value = remaining.pop(0)
except IndexError:
raise TemplateSyntaxError(
"No argument provided to the '%s' tag for the context option."
% bits[0]
f"No argument provided to the '{bits[0]}' tag for the context option."
)
if value in invalid_context:
raise TemplateSyntaxError(
"Invalid argument '%s' provided to the '%s' tag for the context "
'option' % (value, bits[0])
f"Invalid argument '{value}' provided to the '{bits[0]}' tag for the context "
'option'
)
message_context = parser.compile_filter(value)
elif option == 'as':
@ -118,16 +117,15 @@ def do_translate(parser, token):
value = remaining.pop(0)
except IndexError:
raise TemplateSyntaxError(
"No argument provided to the '%s' tag for the as option." % bits[0]
f"No argument provided to the '{bits[0]}' tag for the as option."
)
asvar = value
elif option == 'escape':
escape = True
else:
raise TemplateSyntaxError(
"Unknown argument for '%s' tag: '%s'. The only options "
f"Unknown argument for '{bits[0]}' tag: '{option}'. The only options "
"available are 'noop', 'context' \"xxx\", and 'as VAR'."
% (bits[0], option)
)
seen.add(option)

View File

@ -410,10 +410,7 @@ def progress_bar(val, max_val, *args, **kwargs):
else:
style = ''
if max_val != 0:
percent = float(val / max_val) * 100
else:
percent = 0
percent = float(val / max_val) * 100 if max_val != 0 else 0
if percent > 100:
percent = 100
@ -498,7 +495,7 @@ def primitive_to_javascript(primitive):
elif type(primitive) in [int, float]:
return primitive
# Wrap with quotes
return format_html("'{}'", primitive) # noqa: P103
return format_html("'{}'", primitive)
@register.simple_tag()

View File

@ -70,11 +70,11 @@ class APITests(InvenTreeAPITestCase):
"""Helper function to use basic auth."""
# Use basic authentication
authstring = bytes('{u}:{p}'.format(u=self.username, p=self.password), 'ascii')
authstring = bytes(f'{self.username}:{self.password}', 'ascii')
# Use "basic" auth by default
auth = b64encode(authstring).decode('ascii')
self.client.credentials(HTTP_AUTHORIZATION='Basic {auth}'.format(auth=auth))
self.client.credentials(HTTP_AUTHORIZATION=f'Basic {auth}')
def tokenAuth(self):
"""Helper function to use token auth."""

View File

@ -70,7 +70,7 @@ class URLTest(TestCase):
pattern = '{% url [\'"]([^\'"]+)[\'"]([^%]*)%}'
with open(input_file, 'r') as f:
with open(input_file, encoding='utf-8') as f:
data = f.read()
results = re.findall(pattern, data)

View File

@ -543,22 +543,22 @@ class FormatTest(TestCase):
def test_currency_formatting(self):
"""Test that currency formatting works correctly for multiple currencies."""
test_data = (
(Money(3651.285718, 'USD'), 4, True, '$3,651.2857'), # noqa: E201,E202
(Money(487587.849178, 'CAD'), 5, True, 'CA$487,587.84918'), # noqa: E201,E202
(Money(0.348102, 'EUR'), 1, False, '0.3'), # noqa: E201,E202
(Money(0.916530, 'GBP'), 1, True, '£0.9'), # noqa: E201,E202
(Money(61.031024, 'JPY'), 3, False, '61.031'), # noqa: E201,E202
(Money(49609.694602, 'JPY'), 1, True, '¥49,609.7'), # noqa: E201,E202
(Money(155565.264777, 'AUD'), 2, False, '155,565.26'), # noqa: E201,E202
(Money(0.820437, 'CNY'), 4, True, 'CN¥0.8204'), # noqa: E201,E202
(Money(7587.849178, 'EUR'), 0, True, '€7,588'), # noqa: E201,E202
(Money(0.348102, 'GBP'), 3, False, '0.348'), # noqa: E201,E202
(Money(0.652923, 'CHF'), 0, True, 'CHF1'), # noqa: E201,E202
(Money(0.820437, 'CNY'), 1, True, 'CN¥0.8'), # noqa: E201,E202
(Money(98789.5295680, 'CHF'), 0, False, '98,790'), # noqa: E201,E202
(Money(0.585787, 'USD'), 1, True, '$0.6'), # noqa: E201,E202
(Money(0.690541, 'CAD'), 3, True, 'CA$0.691'), # noqa: E201,E202
(Money(427.814104, 'AUD'), 5, True, 'A$427.81410'), # noqa: E201,E202
(Money(3651.285718, 'USD'), 4, True, '$3,651.2857'),
(Money(487587.849178, 'CAD'), 5, True, 'CA$487,587.84918'),
(Money(0.348102, 'EUR'), 1, False, '0.3'),
(Money(0.916530, 'GBP'), 1, True, '£0.9'),
(Money(61.031024, 'JPY'), 3, False, '61.031'),
(Money(49609.694602, 'JPY'), 1, True, '¥49,609.7'),
(Money(155565.264777, 'AUD'), 2, False, '155,565.26'),
(Money(0.820437, 'CNY'), 4, True, 'CN¥0.8204'),
(Money(7587.849178, 'EUR'), 0, True, '€7,588'),
(Money(0.348102, 'GBP'), 3, False, '0.348'),
(Money(0.652923, 'CHF'), 0, True, 'CHF1'),
(Money(0.820437, 'CNY'), 1, True, 'CN¥0.8'),
(Money(98789.5295680, 'CHF'), 0, False, '98,790'),
(Money(0.585787, 'USD'), 1, True, '$0.6'),
(Money(0.690541, 'CAD'), 3, True, 'CA$0.691'),
(Money(427.814104, 'AUD'), 5, True, 'A$427.81410'),
)
with self.settings(LANGUAGE_CODE='en-us'):
@ -794,7 +794,7 @@ class TestDownloadFile(TestCase):
def test_download(self):
"""Tests for DownloadFile."""
helpers.DownloadFile('hello world', 'out.txt')
helpers.DownloadFile(bytes(b'hello world'), 'out.bin')
helpers.DownloadFile(b'hello world', 'out.bin')
class TestMPTT(TestCase):
@ -1033,12 +1033,12 @@ class TestVersionNumber(TestCase):
# Check that the current .git values work too
hash = str(
git_hash = str(
subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8'
).strip()
# On some systems the hash is a different length, so just check the first 6 characters
self.assertEqual(hash[:6], version.inventreeCommitHash()[:6])
self.assertEqual(git_hash[:6], version.inventreeCommitHash()[:6])
d = (
str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8')
@ -1438,8 +1438,8 @@ class BarcodeMixinTest(InvenTreeTestCase):
'{"part": 17, "stockitem": 12}': 'c88c11ed0628eb7fef0d59b098b96975',
}
for barcode, hash in hashing_tests.items():
self.assertEqual(InvenTree.helpers.hash_barcode(barcode), hash)
for barcode, expected in hashing_tests.items():
self.assertEqual(InvenTree.helpers.hash_barcode(barcode), expected)
class SanitizerTest(TestCase):
@ -1557,15 +1557,12 @@ class ClassValidationMixinTest(TestCase):
def test(self):
"""Test function."""
...
def test1(self):
"""Test function."""
...
def test2(self):
"""Test function."""
...
required_attributes = ['NAME']
required_overrides = [test, [test1, test2]]
@ -1573,8 +1570,6 @@ class ClassValidationMixinTest(TestCase):
class InvalidClass:
"""An invalid class that does not inherit from ClassValidationMixin."""
pass
def test_valid_class(self):
"""Test that a valid class passes the validation."""
@ -1585,11 +1580,9 @@ class ClassValidationMixinTest(TestCase):
def test(self):
"""Test function."""
...
def test2(self):
"""Test function."""
...
TestClass.validate()
@ -1612,7 +1605,6 @@ class ClassValidationMixinTest(TestCase):
def test2(self):
"""Test function."""
...
with self.assertRaisesRegex(
NotImplementedError,
@ -1627,8 +1619,6 @@ class ClassProviderMixinTest(TestCase):
class TestClass(ClassProviderMixin):
"""This class is a dummy class to test the ClassProviderMixin."""
pass
def test_get_provider_file(self):
"""Test the get_provider_file function."""
self.assertEqual(self.TestClass.get_provider_file(), __file__)

View File

@ -15,7 +15,7 @@ def reload_translation_stats():
STATS_FILE = settings.BASE_DIR.joinpath('InvenTree/locale_stats.json').absolute()
try:
with open(STATS_FILE, 'r') as f:
with open(STATS_FILE, encoding='utf-8') as f:
_translation_stats = json.load(f)
except Exception:
_translation_stats = None

View File

@ -246,8 +246,6 @@ class ExchangeRateMixin:
class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase):
"""Testcase with user setup build in."""
pass
class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
"""Base class for running InvenTree API tests."""
@ -283,9 +281,7 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
) # pragma: no cover
if verbose or n >= value:
msg = '\r\n%s' % json.dumps(
context.captured_queries, indent=4
) # pragma: no cover
msg = f'\r\n{json.dumps(context.captured_queries, indent=4)}' # pragma: no cover
else:
msg = None

View File

@ -456,7 +456,7 @@ urlpatterns = []
if settings.INVENTREE_ADMIN_ENABLED:
admin_url = settings.INVENTREE_ADMIN_URL
if settings.ADMIN_SHELL_ENABLE: # noqa
if settings.ADMIN_SHELL_ENABLE:
urlpatterns += [path(f'{admin_url}/shell/', include('django_admin_shell.urls'))]
urlpatterns += [

View File

@ -96,7 +96,6 @@ def validate_sales_order_reference(value):
def validate_tree_name(value):
"""Placeholder for legacy function used in migrations."""
...
def validate_overage(value):

View File

@ -180,7 +180,7 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
AjaxUpdateView: 'change',
}
for view_class in permission_map.keys():
for view_class in permission_map:
if issubclass(type(self), view_class):
return permission_map[view_class]
@ -238,7 +238,6 @@ class AjaxMixin(InvenTreeRoleMixin):
Ref: https://docs.djangoproject.com/en/dev/topics/forms/
"""
# Do nothing by default
pass
def renderJsonResponse(self, request, form=None, data=None, context=None):
"""Render a JSON response based on specific class context.
@ -286,7 +285,7 @@ class AjaxMixin(InvenTreeRoleMixin):
# Custom feedback`data
fb = self.get_data()
for key in fb.keys():
for key in fb:
data[key] = fb[key]
return JsonResponse(data, safe=False)
@ -329,11 +328,11 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
request, self.get_form(), context=self.get_context_data()
)
def save(self, object, form, **kwargs):
def save(self, obj, form, **kwargs):
"""Method for updating the object in the database. Default implementation is very simple, but can be overridden if required.
Args:
object: The current object, to be updated
obj: The current object, to be updated
form: The validated form
Returns:
@ -578,14 +577,10 @@ class UserSessionOverride:
class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView):
"""Revert to settings after session delete."""
pass
class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView):
"""Revert to settings after session delete."""
pass
class CustomLoginView(LoginView):
"""Custom login view that allows login with urlargs."""

View File

@ -1,34 +1,31 @@
"""Background task definitions for the BuildOrder app"""
"""Background task definitions for the BuildOrder app."""
import logging
from datetime import timedelta
from decimal import Decimal
import logging
from django.contrib.auth.models import User
from django.utils.translation import gettext_lazy as _
from django.template.loader import render_to_string
from django.utils.translation import gettext_lazy as _
from allauth.account.models import EmailAddress
from plugin.events import trigger_event
import common.notifications
import build.models
import InvenTree.email
import common.notifications
import InvenTree.helpers
import InvenTree.helpers_email
import InvenTree.helpers_model
import InvenTree.tasks
from InvenTree.ready import isImportingData
from build.status_codes import BuildStatusGroups
import part.models as part_models
from build.status_codes import BuildStatusGroups
from InvenTree.ready import isImportingData
from plugin.events import trigger_event
logger = logging.getLogger('inventree')
def auto_allocate_build(build_id: int, **kwargs):
"""Run auto-allocation for a specified BuildOrder."""
build_order = build.models.Build.objects.filter(pk=build_id).first()
if not build_order:
@ -40,7 +37,6 @@ def auto_allocate_build(build_id: int, **kwargs):
def complete_build_allocations(build_id: int, user_id: int):
"""Complete build allocations for a specified BuildOrder."""
build_order = build.models.Build.objects.filter(pk=build_id).first()
if user_id:
@ -185,7 +181,7 @@ def check_build_stock(build: build.models.Build):
recipients = emails.values_list('email', flat=True)
InvenTree.email.send_email(subject, '', recipients, html_message=html_message)
InvenTree.helpers_email.send_email(subject, '', recipients, html_message=html_message)
def create_child_builds(build_id: int) -> None:
@ -225,7 +221,7 @@ def create_child_builds(build_id: int) -> None:
def notify_overdue_build_order(bo: build.models.Build):
"""Notify appropriate users that a Build has just become 'overdue'"""
"""Notify appropriate users that a Build has just become 'overdue'."""
targets = []
if bo.issued_by:
@ -265,7 +261,7 @@ def notify_overdue_build_order(bo: build.models.Build):
@InvenTree.tasks.scheduled_task(InvenTree.tasks.ScheduledTask.DAILY)
def check_overdue_build_orders():
"""Check if any outstanding BuildOrders have just become overdue
"""Check if any outstanding BuildOrders have just become overdue.
- This check is performed daily
- Look at the 'target_date' of any outstanding BuildOrder objects

View File

@ -47,7 +47,7 @@ from plugin.models import NotificationUserSetting
from plugin.serializers import NotificationUserSettingSerializer
class CsrfExemptMixin(object):
class CsrfExemptMixin:
"""Exempts the view from CSRF requirements."""
@method_decorator(csrf_exempt)
@ -136,7 +136,7 @@ class CurrencyExchangeView(APIView):
serializer_class = None
@extend_schema(responses={200: common.serializers.CurrencyExchangeSerializer})
def get(self, request, format=None):
def get(self, request, fmt=None):
"""Return information on available currency conversions."""
# Extract a list of all available rates
try:
@ -244,10 +244,7 @@ class GlobalSettingsDetail(RetrieveUpdateAPI):
"""Attempt to find a global setting object with the provided key."""
key = str(self.kwargs['key']).upper()
if (
key.startswith('_')
or key not in common.models.InvenTreeSetting.SETTINGS.keys()
):
if key.startswith('_') or key not in common.models.InvenTreeSetting.SETTINGS:
raise NotFound()
return common.models.InvenTreeSetting.get_setting_object(
@ -318,7 +315,7 @@ class UserSettingsDetail(RetrieveUpdateAPI):
if (
key.startswith('_')
or key not in common.models.InvenTreeUserSetting.SETTINGS.keys()
or key not in common.models.InvenTreeUserSetting.SETTINGS
):
raise NotFound()
@ -566,7 +563,7 @@ class BackgroundTaskOverview(APIView):
permission_classes = [permissions.IsAuthenticated, IsAdminUser]
serializer_class = None
def get(self, request, format=None):
def get(self, request, fmt=None):
"""Return information about the current status of the background task queue."""
import django_q.models as q_models

View File

@ -65,7 +65,7 @@ class MatchFieldForm(forms.Form):
for col in columns:
field_name = col['name']
self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + headers_choices,
choices=[('', '-' * 10), *headers_choices],
required=False,
widget=forms.Select(attrs={'class': 'select fieldselect'}),
)
@ -83,10 +83,7 @@ class MatchItemForm(forms.Form):
if 'file_manager' in kwargs:
file_manager = kwargs.pop('file_manager')
if 'row_data' in kwargs:
row_data = kwargs.pop('row_data')
else:
row_data = None
row_data = kwargs.pop('row_data') if 'row_data' in kwargs else None
super().__init__(*args, **kwargs)
@ -134,7 +131,7 @@ class MatchItemForm(forms.Form):
item_match = row['match_' + col_guess]
# Set field select box
self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + item_options,
choices=[('', '-' * 10), *item_options],
required=False,
widget=forms.Select(attrs={'class': 'select bomselect'}),
)
@ -154,7 +151,7 @@ class MatchItemForm(forms.Form):
field_name = 'item_select-' + str(row['index'])
# Set field select box
self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + item_options,
choices=[('', '-' * 10), *item_options],
required=False,
widget=forms.Select(attrs={'class': 'select bomselect'}),
)

View File

@ -55,7 +55,7 @@ def get_icon_packs():
tabler_icons_path = Path(__file__).parent.parent.joinpath(
'InvenTree/static/tabler-icons/icons.json'
)
with open(tabler_icons_path, 'r') as tabler_icons_file:
with open(tabler_icons_path, encoding='utf-8') as tabler_icons_file:
tabler_icons = json.load(tabler_icons_file)
icon_packs = [

View File

@ -248,7 +248,7 @@ class BaseInvenTreeSetting(models.Model):
If a particular setting is not present, create it with the default value
"""
cache_key = f'BUILD_DEFAULT_VALUES:{str(cls.__name__)}'
cache_key = f'BUILD_DEFAULT_VALUES:{cls.__name__!s}'
try:
if InvenTree.helpers.str2bool(cache.get(cache_key, False)):
@ -331,7 +331,7 @@ class BaseInvenTreeSetting(models.Model):
- The unique KEY string
- Any key:value kwargs associated with the particular setting type (e.g. user-id)
"""
key = f'{str(cls.__name__)}:{setting_key}'
key = f'{cls.__name__!s}:{setting_key}'
for k, v in kwargs.items():
key += f'_{k}:{v}'
@ -2062,7 +2062,7 @@ class InvenTreeSetting(BaseInvenTreeSetting):
'description': _(
'Check that all plugins are installed on startup - enable in container environments'
),
'default': str(os.getenv('INVENTREE_DOCKER', False)).lower()
'default': str(os.getenv('INVENTREE_DOCKER', 'False')).lower()
in ['1', 'true'],
'validator': bool,
'requires_restart': True,
@ -3092,11 +3092,8 @@ class CustomUnit(models.Model):
"""Ensure that the custom unit is unique."""
super().validate_unique(exclude)
if self.symbol:
if (
CustomUnit.objects.filter(symbol=self.symbol)
.exclude(pk=self.pk)
.exists()
if self.symbol and (
CustomUnit.objects.filter(symbol=self.symbol).exclude(pk=self.pk).exists()
):
raise ValidationError({'symbol': _('Unit symbol must be unique')})
@ -3433,13 +3430,13 @@ class InvenTreeCustomUserStateModel(models.Model):
'reference_status': _('Reference status set not found')
})
ref_set = ref_set[0]
if self.key in ref_set.keys():
if self.key in ref_set.keys(): # noqa: SIM118
raise ValidationError({
'key': _(
'Key must be different from the logical keys of the reference status'
)
})
if self.logical_key not in ref_set.keys():
if self.logical_key not in ref_set.keys(): # noqa: SIM118
raise ValidationError({
'logical_key': _(
'Logical key must be in the logical keys of the reference status'

View File

@ -143,11 +143,7 @@ class NotificationMethod:
# Check if method globally enabled
plg_instance = registry.get_plugin(plg_cls.NAME.lower())
if plg_instance and not plg_instance.get_setting(self.GLOBAL_SETTING):
return True
# Lets go!
return False
return plg_instance and not plg_instance.get_setting(self.GLOBAL_SETTING)
def usersetting(self, target):
"""Returns setting for this method for a given user."""
@ -365,7 +361,7 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
obj_ref_value = getattr(obj, 'id', None)
if not obj_ref_value:
raise KeyError(
f"Could not resolve an object reference for '{str(obj)}' with {obj_ref}, pk, id"
f"Could not resolve an object reference for '{obj!s}' with {obj_ref}, pk, id"
)
# Check if we have notified recently...
@ -432,9 +428,9 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
deliver_notification(method, obj, category, target_users, context)
except NotImplementedError as error:
# Allow any single notification method to fail, without failing the others
logger.error(error) # noqa: LOG005
logger.error(error)
except Exception as error:
logger.error(error) # noqa: LOG005
logger.error(error)
# Set delivery flag
common.models.NotificationEntry.notify(category, obj_ref_value)

View File

@ -574,8 +574,7 @@ class AttachmentSerializer(InvenTreeModelSerializer):
model_type = self.validated_data.get('model_type', None)
if model_type is None:
if self.instance:
if model_type is None and self.instance:
model_type = self.instance.model_type
# Ensure that the user has permission to attach files to the specified model

View File

@ -393,7 +393,7 @@ class SettingsTest(InvenTreeTestCase):
'before_save',
]
for k in setting.keys():
for k in setting:
self.assertIn(k, allowed_keys)
# Check default value for boolean settings
@ -431,7 +431,7 @@ class SettingsTest(InvenTreeTestCase):
@override_settings(SITE_URL=None, PLUGIN_TESTING=True, PLUGIN_TESTING_SETUP=True)
def test_defaults(self):
"""Populate the settings with default values."""
for key in InvenTreeSetting.SETTINGS.keys():
for key in InvenTreeSetting.SETTINGS:
value = InvenTreeSetting.get_setting_default(key)
try:
@ -520,7 +520,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
response = self.get(url, expected_code=200)
n_public_settings = len([
k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_')
k for k in InvenTreeSetting.SETTINGS if not k.startswith('_')
])
# Number of results should match the number of settings
@ -832,11 +832,9 @@ class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase):
def test_invalid_setting_key(self):
"""Test that an invalid setting key returns a 404."""
...
def test_uninitialized_setting(self):
"""Test that requesting an uninitialized setting creates the setting."""
...
class ErrorReportTest(InvenTreeAPITestCase):
@ -934,7 +932,7 @@ class WebhookMessageTests(TestCase):
def test_bad_token(self):
"""Test that a wrong token is not working."""
response = self.client.post(
self.url, content_type=CONTENT_TYPE_JSON, **{'HTTP_TOKEN': '1234567fghj'}
self.url, content_type=CONTENT_TYPE_JSON, HTTP_TOKEN='1234567fghj'
)
assert response.status_code == HTTPStatus.FORBIDDEN
@ -957,7 +955,7 @@ class WebhookMessageTests(TestCase):
self.url,
data="{'this': 123}",
content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': str(self.endpoint_def.token)},
HTTP_TOKEN=str(self.endpoint_def.token),
)
assert response.status_code == HTTPStatus.NOT_ACCEPTABLE
@ -1005,7 +1003,7 @@ class WebhookMessageTests(TestCase):
response = self.client.post(
self.url,
content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': str('68MXtc/OiXdA5e2Nq9hATEVrZFpLb3Zb0oau7n8s31I=')},
HTTP_TOKEN='68MXtc/OiXdA5e2Nq9hATEVrZFpLb3Zb0oau7n8s31I=',
)
assert response.status_code == HTTPStatus.OK
@ -1020,7 +1018,7 @@ class WebhookMessageTests(TestCase):
self.url,
data={'this': 'is a message'},
content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': str(self.endpoint_def.token)},
HTTP_TOKEN=str(self.endpoint_def.token),
)
assert response.status_code == HTTPStatus.OK

View File

@ -312,7 +312,6 @@ class FileManagementFormView(MultiStepFormView):
This method is very specific to the type of data found in the file,
therefore overwrite it in the subclass.
"""
pass
def get_clean_items(self):
"""Returns dict with all cleaned values."""
@ -363,10 +362,7 @@ class FileManagementFormView(MultiStepFormView):
duplicates = []
for col in self.column_names:
if col in self.column_selections:
guess = self.column_selections[col]
else:
guess = None
guess = self.column_selections.get(col, None)
if guess:
n = list(self.column_selections.values()).count(

View File

@ -10,4 +10,3 @@ class CompanyConfig(AppConfig):
def ready(self):
"""This function is called whenever the Company app is loaded."""
pass

View File

@ -46,10 +46,7 @@ def rename_company_image(instance, filename):
"""
base = 'company_images'
if filename.count('.') > 0:
ext = filename.split('.')[-1]
else:
ext = ''
ext = filename.split('.')[-1] if filename.count('.') > 0 else ''
fn = f'company_{instance.pk}_img'

View File

@ -11,8 +11,8 @@ from .transition import StateTransitionMixin, TransitionMethod, storage
__all__ = [
'ColorEnum',
'StatusCode',
'storage',
'TransitionMethod',
'StateTransitionMixin',
'StatusCode',
'TransitionMethod',
'storage',
]

View File

@ -5,7 +5,7 @@ import re
from enum import Enum
class BaseEnum(enum.IntEnum):
class BaseEnum(enum.IntEnum): # noqa: PLW1641
"""An `Enum` capabile of having its members have docstrings.
Based on https://stackoverflow.com/questions/19330460/how-do-i-put-docstrings-on-enums
@ -100,9 +100,7 @@ class StatusCode(BaseEnum):
return False
if callable(value):
return False
if not isinstance(value.value, int):
return False
return True
return isinstance(value.value, int)
@classmethod
def values(cls, key=None):

View File

@ -27,7 +27,6 @@ class GeneralStatus(StatusCode):
def GHI(self): # This should be ignored
"""A invalid function."""
...
class GeneralStateTest(InvenTreeTestCase):

View File

@ -3,6 +3,6 @@
from django import template
register = template.Library()
from generic.states.tags import status_label # noqa: E402
from generic.states.tags import status_label
__all__ = [status_label]
__all__ = ['status_label']

View File

@ -213,8 +213,6 @@ class DataImportExportSerializerMixin(
):
"""Mixin class for adding data import/export functionality to a DRF serializer."""
pass
class DataExportViewMixin:
"""Mixin class for exporting a dataset via the API.

View File

@ -2,6 +2,7 @@
import json
import logging
from typing import Optional
from django.contrib.auth.models import User
from django.core.exceptions import ValidationError as DjangoValidationError
@ -117,8 +118,8 @@ class DataImportSession(models.Model):
"""
mapping = {}
for map in self.column_mappings.all():
mapping[map.field] = map.column
for i in self.column_mappings.all():
mapping[i.field] = i.column
return mapping
@ -209,13 +210,13 @@ class DataImportSession(models.Model):
missing_fields = []
for field in required_fields.keys():
for field in required_fields:
# An override value exists
if field in field_overrides:
continue
# A default value exists
if field in field_defaults and field_defaults[field]:
if field_defaults.get(field):
continue
# The field has been mapped to a data column
@ -537,7 +538,10 @@ class DataImportRow(models.Model):
return overrides
def extract_data(
self, available_fields: dict = None, field_mapping: dict = None, commit=True
self,
available_fields: Optional[dict] = None,
field_mapping: Optional[dict] = None,
commit=True,
):
"""Extract row data from the provided data dictionary."""
if not field_mapping:

View File

@ -36,7 +36,7 @@ def load_data_file(data_file, file_format=None):
try:
data = file_object.read()
except (IOError, FileNotFoundError):
except (OSError, FileNotFoundError):
raise ValidationError(_('Failed to open data file'))
# Excel formats expect binary data
@ -113,8 +113,7 @@ def get_field_label(field) -> str:
Returns:
str: Field label
"""
if field:
if label := getattr(field, 'label', None):
if field and (label := getattr(field, 'label', None)):
return label
# TODO: Check if the field is a model field

View File

@ -19,7 +19,7 @@ class ImporterTest(InvenTreeTestCase):
fn = os.path.join(os.path.dirname(__file__), 'test_data', 'companies.csv')
with open(fn, 'r') as input_file:
with open(fn, encoding='utf-8') as input_file:
data = input_file.read()
session = DataImportSession.objects.create(
@ -61,4 +61,3 @@ class ImporterTest(InvenTreeTestCase):
def test_field_defaults(self):
"""Test default field values."""
...

View File

@ -36,7 +36,7 @@ def validate_importer_model_type(value):
"""Validate that the given model type is supported for importing."""
from importer.registry import supported_models
if value not in supported_models().keys():
if value not in supported_models():
raise ValidationError(f"Unsupported model type '{value}'")

View File

@ -1,4 +1,4 @@
from machine.machine_type import BaseDriver, BaseMachineType, MachineStatus
from machine.registry import registry
__all__ = ['registry', 'BaseMachineType', 'BaseDriver', 'MachineStatus']
__all__ = ['BaseDriver', 'BaseMachineType', 'MachineStatus', 'registry']

View File

@ -139,7 +139,7 @@ class BaseDriver(
Arguments:
error: Exception or string
"""
self.set_shared_state('errors', self.errors + [error])
self.set_shared_state('errors', [*self.errors, error])
# --- state getters/setters
@property
@ -317,7 +317,7 @@ class BaseMachineType(
Arguments:
error: Exception or string
"""
self.set_shared_state('errors', self.errors + [error])
self.set_shared_state('errors', [*self.errors, error])
def reset_errors(self):
"""Helper function for resetting the error list for a machine."""

View File

@ -4,8 +4,8 @@ from machine.machine_types.label_printer import (
)
__all__ = [
# machine types
'LabelPrinterMachine',
# base drivers
'LabelPrinterBaseDriver',
# machine types
'LabelPrinterMachine',
]

View File

@ -38,7 +38,7 @@ class MachineRegistry(
def handle_error(self, error: Union[Exception, str]):
"""Helper function for capturing errors with the machine registry."""
self.set_shared_state('errors', self.errors + [error])
self.set_shared_state('errors', [*self.errors, error])
def initialize(self, main: bool = False):
"""Initialize the machine registry."""

View File

@ -84,14 +84,12 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
machine_type,
{
**machine_type,
**{
'slug': 'label-printer',
'name': 'Label Printer',
'description': 'Directly print labels for various items.',
'provider_plugin': None,
'is_builtin': True,
},
},
)
self.assertTrue(
machine_type['provider_file'].endswith(
@ -109,7 +107,6 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
driver,
{
**driver,
**{
'slug': 'test-label-printer-api',
'name': 'Test label printer',
'description': 'This is a test label printer driver for testing.',
@ -118,7 +115,6 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
'machine_type': 'label-printer',
'driver_errors': [],
},
},
)
self.assertEqual(driver['provider_file'], __file__)
@ -173,7 +169,6 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
response.data[0],
{
**response.data[0],
**{
'name': 'Test Machine',
'machine_type': 'label-printer',
'driver': 'test-label-printer-api',
@ -184,7 +179,6 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
'status_text': '',
'is_driver_available': True,
},
},
)
def test_machine_detail(self):
@ -216,9 +210,7 @@ class MachineAPITest(TestMachineRegistryMixin, InvenTreeAPITestCase):
reverse('api-machine-detail', kwargs={'pk': pk}),
{'name': 'Updated Machine'},
)
self.assertEqual(
response.data, {**response.data, **{'name': 'Updated Machine'}}
)
self.assertEqual(response.data, {**response.data, 'name': 'Updated Machine'})
self.assertEqual(MachineConfig.objects.get(pk=pk).name, 'Updated Machine')
# Delete the machine

View File

@ -337,8 +337,6 @@ class PurchaseOrderList(PurchaseOrderMixin, DataExportViewMixin, ListCreateAPI):
class PurchaseOrderDetail(PurchaseOrderMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a PurchaseOrder object."""
pass
class PurchaseOrderContextMixin:
"""Mixin to add purchase order object as serializer context variable."""
@ -603,8 +601,6 @@ class PurchaseOrderLineItemList(
class PurchaseOrderLineItemDetail(PurchaseOrderLineItemMixin, RetrieveUpdateDestroyAPI):
"""Detail API endpoint for PurchaseOrderLineItem object."""
pass
class PurchaseOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
"""API endpoint for accessing a list of PurchaseOrderExtraLine objects."""
@ -746,8 +742,6 @@ class SalesOrderList(SalesOrderMixin, DataExportViewMixin, ListCreateAPI):
class SalesOrderDetail(SalesOrderMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a SalesOrder object."""
pass
class SalesOrderLineItemFilter(LineItemFilter):
"""Custom filters for SalesOrderLineItemList endpoint."""
@ -865,8 +859,6 @@ class SalesOrderLineItemList(
class SalesOrderLineItemDetail(SalesOrderLineItemMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a SalesOrderLineItem object."""
pass
class SalesOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
"""API endpoint for accessing a list of SalesOrderExtraLine objects."""
@ -1181,8 +1173,6 @@ class ReturnOrderList(ReturnOrderMixin, DataExportViewMixin, ListCreateAPI):
class ReturnOrderDetail(ReturnOrderMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single ReturnOrder object."""
pass
class ReturnOrderContextMixin:
"""Simple mixin class to add a ReturnOrder to the serializer context."""
@ -1310,8 +1300,6 @@ class ReturnOrderLineItemList(
class ReturnOrderLineItemDetail(ReturnOrderLineItemMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a ReturnOrderLineItem object."""
pass
class ReturnOrderExtraLineList(GeneralExtraLineList, ListCreateAPI):
"""API endpoint for accessing a list of ReturnOrderExtraLine objects."""
@ -1368,8 +1356,7 @@ class OrderCalendarExport(ICalFeed):
if auth[0].lower() == 'basic':
uname, passwd = base64.b64decode(auth[1]).decode('ascii').split(':')
user = authenticate(username=uname, password=passwd)
if user is not None:
if user.is_active:
if user is not None and user.is_active:
login(request, user)
request.user = user

View File

@ -1048,19 +1048,11 @@ class SalesOrder(TotalPriceMixin, Order):
def is_fully_allocated(self):
"""Return True if all line items are fully allocated."""
for line in self.lines.all():
if not line.is_fully_allocated():
return False
return True
return all(line.is_fully_allocated() for line in self.lines.all())
def is_overallocated(self):
"""Return true if any lines in the order are over-allocated."""
for line in self.lines.all():
if line.is_overallocated():
return True
return False
return any(line.is_overallocated() for line in self.lines.all())
def is_completed(self):
"""Check if this order is "shipped" (all line items delivered)."""

View File

@ -166,7 +166,8 @@ class AbstractOrderSerializer(DataImportExportSerializerMixin, serializers.Seria
'notes',
'barcode_hash',
'overdue',
] + extra_fields
*extra_fields,
]
class AbstractLineItemSerializer:
@ -433,7 +434,7 @@ class PurchaseOrderLineItemSerializer(
def skip_create_fields(self):
"""Return a list of fields to skip when creating a new object."""
return ['auto_pricing', 'merge_items'] + super().skip_create_fields()
return ['auto_pricing', 'merge_items', *super().skip_create_fields()]
@staticmethod
def annotate_queryset(queryset):
@ -740,8 +741,7 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer):
base_quantity = line_item.part.base_quantity(quantity)
# Does the quantity need to be "integer" (for trackable parts?)
if base_part.trackable:
if Decimal(base_quantity) != int(base_quantity):
if base_part.trackable and Decimal(base_quantity) != int(base_quantity):
raise ValidationError({
'quantity': _(
'An integer quantity must be provided for trackable parts'

View File

@ -1527,7 +1527,7 @@ class SalesOrderTest(OrderTest):
self.download_file(
reverse('api-so-list'),
{'export': fmt},
decode=True if fmt == 'csv' else False,
decode=fmt == 'csv',
expected_code=200,
expected_fn=r'InvenTree_SalesOrder_.+',
)

View File

@ -294,7 +294,7 @@ class SalesOrderExport(AjaxView):
export_format = request.GET.get('format', 'csv')
filename = f'{str(order)} - {order.customer.name}.{export_format}'
filename = f'{order!s} - {order.customer.name}.{export_format}'
dataset = SalesOrderLineItemResource().export(queryset=order.lines.all())
@ -321,7 +321,7 @@ class PurchaseOrderExport(AjaxView):
export_format = request.GET.get('format', 'csv')
filename = f'{str(order)} - {order.supplier.name}.{export_format}'
filename = f'{order!s} - {order.supplier.name}.{export_format}'
dataset = PurchaseOrderLineItemResource().export(queryset=order.lines.all())

View File

@ -167,8 +167,7 @@ class CategoryFilter(rest_filters.FilterSet):
top_level = str2bool(self.data.get('top_level', None))
# If the parent is *not* provided, update the results based on the "cascade" value
if not parent or top_level:
if not value:
if (not parent or top_level) and not value:
# If "cascade" is False, only return top-level categories
queryset = queryset.filter(parent=None)
@ -466,8 +465,6 @@ class PartTestTemplateMixin:
class PartTestTemplateDetail(PartTestTemplateMixin, RetrieveUpdateDestroyAPI):
"""Detail endpoint for PartTestTemplate model."""
pass
class PartTestTemplateList(PartTestTemplateMixin, DataExportViewMixin, ListCreateAPI):
"""API endpoint for listing (and creating) a PartTestTemplate."""
@ -1570,8 +1567,6 @@ class PartParameterTemplateList(
class PartParameterTemplateDetail(PartParameterTemplateMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for accessing the detail view for a PartParameterTemplate object."""
pass
class PartParameterAPIMixin:
"""Mixin class for PartParameter API endpoints."""
@ -1663,8 +1658,6 @@ class PartParameterList(PartParameterAPIMixin, DataExportViewMixin, ListCreateAP
class PartParameterDetail(PartParameterAPIMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single PartParameter object."""
pass
class PartStocktakeFilter(rest_filters.FilterSet):
"""Custom filter for the PartStocktakeList endpoint."""
@ -1922,8 +1915,6 @@ class BomList(BomMixin, DataExportViewMixin, ListCreateDestroyAPIView):
class BomDetail(BomMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single BomItem object."""
pass
class BomImportUpload(CreateAPI):
"""API endpoint for uploading a complete Bill of Materials.

View File

@ -4,6 +4,7 @@ Primarily BOM upload tools.
"""
from collections import OrderedDict
from typing import Optional
from django.utils.translation import gettext as _
@ -40,7 +41,11 @@ def MakeBomTemplate(fmt):
def ExportBom(
part: Part, fmt='csv', cascade: bool = False, max_levels: int = None, **kwargs
part: Part,
fmt='csv',
cascade: bool = False,
max_levels: Optional[int] = None,
**kwargs,
):
"""Export a BOM (Bill of Materials) for a given part.
@ -244,10 +249,7 @@ def ExportBom(
manufacturer_name = ''
# Extract the "MPN" field from the Manufacturer Part
if mp_part:
manufacturer_mpn = mp_part.MPN
else:
manufacturer_mpn = ''
manufacturer_mpn = mp_part.MPN if mp_part else ''
# Generate a column name for this manufacturer
k_man = f'{_("Manufacturer")}_{mp_idx}'
@ -270,10 +272,7 @@ def ExportBom(
else:
supplier_name = ''
if sp_part:
supplier_sku = sp_part.SKU
else:
supplier_sku = ''
supplier_sku = sp_part.SKU if sp_part else ''
# Generate column names for this supplier
k_sup = (
@ -307,10 +306,7 @@ def ExportBom(
supplier_parts_used.add(sp_part)
if sp_part.supplier:
supplier_name = sp_part.supplier.name
else:
supplier_name = ''
supplier_name = sp_part.supplier.name if sp_part.supplier else ''
supplier_sku = sp_part.SKU

View File

@ -231,10 +231,7 @@ class PartCategory(InvenTree.models.InvenTreeTree):
"""Get all unique parameter names for all parts from this category."""
unique_parameters_names = []
if prefetch:
parts = prefetch
else:
parts = self.prefetch_parts_parameters(cascade=cascade)
parts = prefetch or self.prefetch_parts_parameters(cascade=cascade)
for part in parts:
for parameter in part.parameters.all():
@ -248,10 +245,7 @@ class PartCategory(InvenTree.models.InvenTreeTree):
"""Get all parameter names and values for all parts from this category."""
category_parameters = []
if prefetch:
parts = prefetch
else:
parts = self.prefetch_parts_parameters(cascade=cascade)
parts = prefetch or self.prefetch_parts_parameters(cascade=cascade)
for part in parts:
part_parameters = {
@ -933,17 +927,19 @@ class Part(
'IPN': _('Duplicate IPN not allowed in part settings')
})
if self.revision_of and self.revision:
if (
self.revision_of
and self.revision
and (
Part.objects.exclude(pk=self.pk)
.filter(revision_of=self.revision_of, revision=self.revision)
.exists()
)
):
raise ValidationError(_('Duplicate part revision already exists.'))
# Ensure unique across (Name, revision, IPN) (as specified)
if self.revision or self.IPN:
if (
if (self.revision or self.IPN) and (
Part.objects.exclude(pk=self.pk)
.filter(name=self.name, revision=self.revision, IPN=self.IPN)
.exists()
@ -3888,10 +3884,12 @@ class PartParameter(InvenTree.models.InvenTreeMetadataModel):
super().clean()
# Validate the parameter data against the template units
if get_global_setting(
if (
get_global_setting(
'PART_PARAMETER_ENFORCE_UNITS', True, cache=False, create=False
)
and self.template.units
):
if self.template.units:
try:
InvenTree.conversion.convert_physical_value(
self.data, self.template.units
@ -4205,8 +4203,7 @@ class BomItem(
# Check if the part was changed
deltas = self.get_field_deltas()
if 'part' in deltas:
if old_part := deltas['part'].get('old', None):
if 'part' in deltas and (old_part := deltas['part'].get('old', None)):
self.check_part_lock(old_part)
# Update the 'validated' field based on checksum calculation
@ -4344,7 +4341,7 @@ class BomItem(
- allow_variants
"""
# Seed the hash with the ID of this BOM item
result_hash = hashlib.md5(''.encode())
result_hash = hashlib.md5(b'')
# The following components are used to calculate the checksum
components = [
@ -4438,8 +4435,7 @@ class BomItem(
try:
ovg = float(overage)
if ovg < 0:
ovg = 0
ovg = max(ovg, 0)
return ovg
except ValueError:
@ -4451,10 +4447,8 @@ class BomItem(
try:
percent = float(overage) / 100.0
if percent > 1:
percent = 1
if percent < 0:
percent = 0
percent = min(percent, 1)
percent = max(percent, 0)
# Must be represented as a decimal
percent = Decimal(percent)

View File

@ -2000,8 +2000,7 @@ class BomImportExtractSerializer(InvenTree.serializers.DataFileExtractSerializer
if part is None:
row['errors']['part'] = _('No matching part found')
else:
if not part.component:
elif not part.component:
row['errors']['part'] = _('Part is not designated as a component')
# Update the 'part' value in the row

View File

@ -1618,7 +1618,7 @@ class PartDetailTests(PartAPITestBase):
# Try to upload a non-image file
test_path = BASE_DIR / '_testfolder' / 'dummy_image'
with open(f'{test_path}.txt', 'w') as dummy_image:
with open(f'{test_path}.txt', 'w', encoding='utf-8') as dummy_image:
dummy_image.write('hello world')
with open(f'{test_path}.txt', 'rb') as dummy_image:

View File

@ -49,7 +49,7 @@ class BomExportTest(InvenTreeTestCase):
with open(filename, 'wb') as f:
f.write(response.getvalue())
with open(filename, 'r') as f:
with open(filename, encoding='utf-8') as f:
reader = csv.reader(f, delimiter=',')
for line in reader:
@ -96,7 +96,7 @@ class BomExportTest(InvenTreeTestCase):
f.write(response.getvalue())
# Read the file
with open(filename, 'r') as f:
with open(filename, encoding='utf-8') as f:
reader = csv.reader(f, delimiter=',')
for line in reader:

View File

@ -419,7 +419,7 @@ class PartParameterTest(InvenTreeAPITestCase):
response = self.get(
url,
{'ordering': 'parameter_{pk}'.format(pk=template.pk), 'parameters': 'true'},
{'ordering': f'parameter_{template.pk}', 'parameters': 'true'},
expected_code=200,
)
@ -436,10 +436,7 @@ class PartParameterTest(InvenTreeAPITestCase):
# Next, check reverse ordering
response = self.get(
url,
{
'ordering': '-parameter_{pk}'.format(pk=template.pk),
'parameters': 'true',
},
{'ordering': f'-parameter_{template.pk}', 'parameters': 'true'},
expected_code=200,
)

View File

@ -111,7 +111,6 @@ class PartPricingTests(InvenTreeTestCase):
def test_invalid_rate(self):
"""Ensure that conversion behaves properly with missing rates."""
...
def test_simple(self):
"""Tests for hard-coded values."""

View File

@ -180,9 +180,9 @@ class PartImport(FileManagementFormView):
if idx in self.file_manager.OPTIONAL_MATCH_HEADERS:
try:
exact_match = self.allowed_items[idx].get(**{
a: data for a in self.matches[idx]
})
exact_match = self.allowed_items[idx].get(
**dict.fromkeys(self.matches[idx], data)
)
except (
ValueError,
self.allowed_items[idx].model.DoesNotExist,
@ -414,7 +414,7 @@ class PartDetailFromIPN(PartDetail):
if not self.object:
return HttpResponseRedirect(reverse('part-index'))
return super(PartDetailFromIPN, self).get(request, *args, **kwargs)
return super().get(request, *args, **kwargs)
class PartImageSelect(AjaxUpdateView):

View File

@ -5,8 +5,8 @@ from .plugin import InvenTreePlugin
from .registry import registry
__all__ = [
'registry',
'InvenTreePlugin',
'MixinNotImplementedError',
'MixinImplementationError',
'MixinNotImplementedError',
'registry',
]

View File

@ -63,7 +63,7 @@ class PluginFilter(rest_filters.FilterSet):
match = True
for mixin in mixins:
if mixin not in result.mixins().keys():
if mixin not in result.mixins():
match = False
break

View File

@ -335,8 +335,7 @@ class BarcodePOAllocate(BarcodeView):
if not part and not supplier_part and not manufacturer_part:
raise ValidationError({'error': _('No matching part data found')})
if part:
if part_id := part.get('pk', None):
if part and (part_id := part.get('pk', None)):
supplier_parts = supplier_parts.filter(part__pk=part_id)
if supplier_part:
@ -594,7 +593,7 @@ class BarcodeSOAllocate(BarcodeView):
raise ValidationError(response)
# If we have sufficient information, we can allocate the stock item
if all((x is not None for x in [line_item, sales_order, shipment, quantity])):
if all(x is not None for x in [line_item, sales_order, shipment, quantity]):
order.models.SalesOrderAllocation.objects.create(
line=line_item, shipment=shipment, item=stock_item, quantity=quantity
)

View File

@ -384,7 +384,9 @@ class SupplierBarcodeMixin(BarcodeMixin):
return orders_intersection if orders_intersection else orders_union
@staticmethod
def get_supplier_parts(sku: str = None, supplier: Company = None, mpn: str = None):
def get_supplier_parts(
sku: str | None = None, supplier: Company = None, mpn: str | None = None
):
"""Get a supplier part from SKU or by supplier and MPN."""
if not (sku or supplier or mpn):
return SupplierPart.objects.none()
@ -420,10 +422,10 @@ class SupplierBarcodeMixin(BarcodeMixin):
def receive_purchase_order_item(
supplier_part: SupplierPart,
user: User,
quantity: Decimal | str = None,
quantity: Decimal | str | None = None,
purchase_order: PurchaseOrder = None,
location: StockLocation = None,
barcode: str = None,
barcode: str | None = None,
) -> dict:
"""Try to receive a purchase order item.
@ -471,9 +473,9 @@ class SupplierBarcodeMixin(BarcodeMixin):
# 2. check if it's defined on the part
# 3. check if there's 1 or 0 stock locations defined in InvenTree
# -> assume all stock is going into that location (or no location)
if location := line_item.destination:
pass
elif location := supplier_part.part.get_default_location():
if (location := line_item.destination) or (
location := supplier_part.part.get_default_location()
):
pass
elif StockLocation.objects.count() <= 1:
if not (location := StockLocation.objects.first()):

View File

@ -41,7 +41,7 @@ class BarcodeGenerateSerializer(serializers.Serializer):
plugin.base.barcodes.helper.get_supported_barcode_models_map()
)
if model not in supported_models.keys():
if model not in supported_models:
raise ValidationError(_('Model is not supported'))
return model

View File

@ -147,10 +147,7 @@ def allow_table_event(table_name):
'part_partstocktakereport',
]
if table_name in ignore_tables:
return False
return True
return not table_name in ignore_tables
@receiver(post_save)

View File

@ -3,6 +3,7 @@
import json as json_pkg
import logging
from collections.abc import Iterable
from typing import Optional
import requests
@ -117,10 +118,10 @@ class APICallMixin:
self,
endpoint: str,
method: str = 'GET',
url_args: dict = None,
url_args: Optional[dict] = None,
data=None,
json=None,
headers: dict = None,
headers: Optional[dict] = None,
simple_response: bool = True,
endpoint_is_url: bool = False,
):

View File

@ -33,7 +33,6 @@ class ReportMixin:
request: The request object which initiated the report generation
context: The context dictionary to add to
"""
pass
def add_label_context(self, label_instance, model_instance, request, context):
"""Add extra context to the provided label instance.
@ -46,7 +45,6 @@ class ReportMixin:
request: The request object which initiated the label generation
context: The context dictionary to add to
"""
pass
def report_callback(self, template, instance, report, request):
"""Callback function called after a report is generated.
@ -59,4 +57,3 @@ class ReportMixin:
The default implementation does nothing.
"""
pass

View File

@ -148,7 +148,7 @@ class ScheduleMixin:
def get_task_names(self):
"""All defined task names."""
# Returns a list of all task names associated with this plugin instance
return [self.get_task_name(key) for key in self.scheduled_tasks.keys()]
return [self.get_task_name(key) for key in self.scheduled_tasks]
def register_tasks(self):
"""Register the tasks with the database."""
@ -200,7 +200,7 @@ class ScheduleMixin:
try:
from django_q.models import Schedule
for key, _ in self.scheduled_tasks.items():
for key in self.scheduled_tasks:
task_name = self.get_task_name(key)
try:

View File

@ -15,8 +15,6 @@ else:
class SettingsKeyType:
"""Dummy class, so that python throws no error."""
pass
class SettingsMixin:
"""Mixin that enables global settings for the plugin."""

View File

@ -1,5 +1,7 @@
"""Validation mixin class definition."""
from typing import Optional
from django.core.exceptions import ValidationError
from django.db.models import Model
@ -67,7 +69,9 @@ class ValidationMixin:
"""
return None
def validate_model_instance(self, instance: Model, deltas: dict = None) -> None:
def validate_model_instance(
self, instance: Model, deltas: Optional[dict] = None
) -> None:
"""Run custom validation on a database model instance.
This method is called when a model instance is being validated.
@ -218,4 +222,3 @@ class ValidationMixin:
Raises:
ValidationError: If the proposed parameter value is objectionable
"""
pass

Some files were not shown because too many files have changed in this diff Show More