add simplify rules

Matthias Mair 2024-08-19 23:09:20 +02:00
parent 3b9d485403
commit a2dfac593b
57 changed files with 151 additions and 271 deletions
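This commit enables ruff's flake8-simplify ("SIM") rule set and applies the rewrites it suggests across the codebase, collapsing nested or verbose conditionals into their one-line equivalents. As a rough orientation for the hunks below, here is a small illustrative sketch of the most common patterns involved; the example code is written for this summary and is not taken from the InvenTree sources, and the SIM codes given are ruff's usual designations for these checks.

    import os

    # Before: patterns that flake8-simplify flags
    for key in os.environ.keys():  # SIM118: iterate the mapping directly
        pass

    if os.sep == '/':
        style = 'posix'
    else:
        style = 'windows'  # SIM108: collapse to a conditional expression

    def all_positive(values):
        for value in values:
            if value <= 0:
                return False
        return True  # SIM110: rewrite as all()/any()

    # After: the simplified equivalents
    for key in os.environ:
        pass

    style = 'posix' if os.sep == '/' else 'windows'

    def all_positive(values):
        return all(value > 0 for value in values)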

View File

@@ -175,10 +175,7 @@ if __name__ == '__main__':
             print(f"Version number '{version}' does not match tag '{version_tag}'")
             sys.exit
-        if highest_release:
-            docker_tags = [version_tag, 'stable']
-        else:
-            docker_tags = [version_tag]
+        docker_tags = [version_tag, 'stable'] if highest_release else [version_tag]
     elif GITHUB_REF_TYPE == 'branch':
         # Otherwise we know we are targeting the 'master' branch

View File

@@ -173,7 +173,7 @@ def on_config(config, *args, **kwargs):
     # Add *all* readthedocs related keys
     readthedocs = {}
-    for key in os.environ.keys():
+    for key in os.environ:
         if key.startswith('READTHEDOCS_'):
             k = key.replace('READTHEDOCS_', '').lower()
             readthedocs[k] = os.environ[key]

View File

@@ -54,9 +54,7 @@ def top_level_path(path: str) -> str:
 def generate_schema_file(key: str) -> None:
     """Generate a schema file for the provided key."""
-    description = (
-        SPECIAL_PATHS[key] if key in SPECIAL_PATHS else 'General API Endpoints'
-    )
+    description = SPECIAL_PATHS.get(key, 'General API Endpoints')
     output = f"""
 ---

View File

@ -20,13 +20,19 @@ src = ["src/backend/InvenTree"]
"__init__.py" = ["D104"] "__init__.py" = ["D104"]
[tool.ruff.lint] [tool.ruff.lint]
select = ["A", "B", "C", "C4", "D", "F", "I", "N", "PIE", "UP", "W"] select = ["A", "B", "C", "C4", "D", "F", "I", "N", "SIM", "PIE", "UP", "W"]
# Things that should be enabled in the future: # Things that should be enabled in the future:
# - LOG # - LOG
# - DJ # for Django stuff # - DJ # for Django stuff
# - S # for security stuff (bandit) # - S # for security stuff (bandit)
ignore = [ ignore = [
"SIM117",
# - SIM117 - Use a single with statement with multiple contexts instead of nested with statements
"SIM102",
# - SIM102 - Use a single if statement instead of nested if statements
"SIM105",
# - SIM105 - Use contextlib.suppress({exception}) instead of try-except-pass
"C901", "C901",
# - C901 - function is too complex # - C901 - function is too complex
"N999", "N999",

View File

@ -104,14 +104,16 @@ class InvenTreeResource(ModelResource):
attribute = getattr(field, 'attribute', field_name) attribute = getattr(field, 'attribute', field_name)
# Check if the associated database field is a non-nullable string # Check if the associated database field is a non-nullable string
if db_field := db_fields.get(attribute): if (
if ( (db_field := db_fields.get(attribute))
and (
isinstance(db_field, CharField) isinstance(db_field, CharField)
and db_field.blank and db_field.blank
and not db_field.null and not db_field.null
): )
if column not in self.CONVERT_NULL_FIELDS: and column not in self.CONVERT_NULL_FIELDS
self.CONVERT_NULL_FIELDS.append(column) ):
self.CONVERT_NULL_FIELDS.append(column)
return super().before_import(dataset, using_transactions, dry_run, **kwargs) return super().before_import(dataset, using_transactions, dry_run, **kwargs)

View File

@@ -77,7 +77,7 @@ class LicenseView(APIView):
         # Ensure we do not have any duplicate 'name' values in the list
         for entry in data:
             name = None
-            for key in entry.keys():
+            for key in entry:
                 if key.lower() == 'name':
                     name = entry[key]
                     break

View File

@@ -70,7 +70,7 @@ def user_roles(request):
     roles = {}
-    for role in RuleSet.get_ruleset_models().keys():
+    for role in RuleSet.get_ruleset_models():
         permissions = {}
         for perm in ['view', 'add', 'change', 'delete']:

View File

@@ -245,8 +245,4 @@ def is_dimensionless(value):
     if value.units == ureg.dimensionless:
         return True
-    if value.to_base_units().units == ureg.dimensionless:
-        return True
-    # At this point, the value is not dimensionless
-    return False
+    return value.to_base_units().units == ureg.dimensionless

View File

@@ -40,10 +40,7 @@ def log_error(path, error_name=None, error_info=None, error_data=None):
     if kind in settings.IGNORED_ERRORS:
         return
-    if error_name:
-        kind = error_name
-    else:
-        kind = getattr(kind, '__name__', 'Unknown Error')
+    kind = error_name or getattr(kind, '__name__', 'Unknown Error')
     if error_info:
         info = error_info

View File

@@ -31,10 +31,7 @@ class InvenTreeExchange(SimpleExchangeBackend):
         # Find the selected exchange rate plugin
         slug = get_global_setting('CURRENCY_UPDATE_PLUGIN', create=False)
-        if slug:
-            plugin = registry.get_plugin(slug)
-        else:
-            plugin = None
+        plugin = registry.get_plugin(slug) if slug else None
         if not plugin:
             # Find the first active currency exchange plugin

View File

@@ -93,9 +93,8 @@ class InvenTreeModelMoneyField(ModelMoneyField):
         allow_negative = kwargs.pop('allow_negative', False)
         # If no validators are provided, add some "standard" ones
-        if len(validators) == 0:
-            if not allow_negative:
-                validators.append(MinMoneyValidator(0))
+        if len(validators) == 0 and not allow_negative:
+            validators.append(MinMoneyValidator(0))
         kwargs['validators'] = validators

View File

@@ -118,10 +118,7 @@ class InvenTreeOrderingFilter(filters.OrderingFilter):
                 field = field[1:]
             # Are aliases defined for this field?
-            if field in aliases:
-                alias = aliases[field]
-            else:
-                alias = field
+            alias = aliases.get(field, field)
             """
             Potentially, a single field could be "aliased" to multiple field,

View File

@@ -106,10 +106,7 @@ def construct_format_regex(fmt_string: str) -> str:
         # Add a named capture group for the format entry
         if name:
             # Check if integer values are required
-            if _fmt.endswith('d'):
-                c = r'\d'
-            else:
-                c = '.'
+            c = '\\d' if _fmt.endswith('d') else '.'
             # Specify width
             # TODO: Introspect required width
@@ -160,7 +157,7 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str:
     """
     info = parse_format_string(fmt_string)
-    if name not in info.keys():
+    if name not in info:
         raise NameError(_(f"Value '{name}' does not appear in pattern format"))
     # Construct a regular expression for matching against the provided format string
@@ -199,7 +196,7 @@ def format_money(
     Raises:
         ValueError: format string is incorrectly specified
     """
-    language = (None and translation.get_language()) or settings.LANGUAGE_CODE
+    language = (None) or settings.LANGUAGE_CODE
     locale = Locale.parse(translation.to_locale(language))
     if format:
         pattern = parse_pattern(format)

View File

@@ -97,10 +97,7 @@ def generateTestKey(test_name: str) -> str:
         if char.isidentifier():
             return True
-        if char.isalnum():
-            return True
-        return False
+        return bool(char.isalnum())
     # Remove any characters that cannot be used to represent a variable
     key = ''.join([c for c in key if valid_char(c)])
@@ -490,10 +487,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
     except ValueError:
         raise ValidationError([_('Invalid quantity provided')])
-    if input_string:
-        input_string = str(input_string).strip()
-    else:
-        input_string = ''
+    input_string = str(input_string).strip() if input_string else ''
     if len(input_string) == 0:
         raise ValidationError([_('Empty serial number string')])
@@ -827,7 +821,7 @@ def hash_barcode(barcode_data):
 def hash_file(filename: Union[str, Path], storage: Union[Storage, None] = None):
     """Return the MD5 hash of a file."""
     content = (
-        open(filename, 'rb').read()
+        open(filename, 'rb').read()  # noqa: SIM115
         if storage is None
         else storage.open(str(filename), 'rb').read()
     )
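The inline `# noqa: SIM115` keeps the new rule quiet here: SIM115 flags open() calls made outside a context manager, and in this conditional expression the warning is suppressed rather than restructuring the code. For reference, a minimal sketch of the shape SIM115 normally prefers; this is illustrative only, with a hypothetical md5_of_file() helper rather than the project's hash_file() implementation.

    import hashlib

    def md5_of_file(filename: str) -> str:
        """Read the file inside a context manager and return its MD5 hash."""
        with open(filename, 'rb') as fh:
            return hashlib.md5(fh.read()).hexdigest()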

View File

@@ -114,10 +114,7 @@ def download_image_from_url(remote_url, timeout=2.5):
     # Add user specified user-agent to request (if specified)
     user_agent = get_global_setting('INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT')
-    if user_agent:
-        headers = {'User-Agent': user_agent}
-    else:
-        headers = None
+    headers = {'User-Agent': user_agent} if user_agent else None
     try:
         response = requests.get(

View File

@@ -19,10 +19,9 @@ def render_file(file_name, source, target, locales, ctx):
         target_file = os.path.join(target, locale + '.' + file_name)
-        with open(target_file, 'w') as localised_file:
-            with lang_over(locale):
-                rendered = render_to_string(os.path.join(source, file_name), ctx)
-                localised_file.write(rendered)
+        with open(target_file, 'w') as localised_file, lang_over(locale):
+            rendered = render_to_string(os.path.join(source, file_name), ctx)
+            localised_file.write(rendered)
 class Command(BaseCommand):

View File

@@ -204,7 +204,7 @@ class InvenTreeMetadata(SimpleMetadata):
             # Iterate through simple fields
             for name, field in model_fields.fields.items():
-                if name in serializer_info.keys():
+                if name in serializer_info:
                     if name in read_only_fields:
                         serializer_info[name]['read_only'] = True
@@ -236,7 +236,7 @@ class InvenTreeMetadata(SimpleMetadata):
             # Iterate through relations
             for name, relation in model_fields.relations.items():
-                if name not in serializer_info.keys():
+                if name not in serializer_info:
                     # Skip relation not defined in serializer
                     continue
@@ -307,12 +307,12 @@ class InvenTreeMetadata(SimpleMetadata):
             instance_filters = instance.api_instance_filters()
             for field_name, field_filters in instance_filters.items():
-                if field_name not in serializer_info.keys():
+                if field_name not in serializer_info:
                     # The field might be missing, but is added later on
                     # This function seems to get called multiple times?
                     continue
-                if 'instance_filters' not in serializer_info[field_name].keys():
+                if 'instance_filters' not in serializer_info[field_name]:
                     serializer_info[field_name]['instance_filters'] = {}
                 for key, value in field_filters.items():

View File

@@ -390,10 +390,7 @@ class ReferenceIndexingMixin(models.Model):
         except Exception:
             # If anything goes wrong, return the most recent reference
             recent = cls.get_most_recent_item()
-            if recent:
-                reference = recent.reference
-            else:
-                reference = ''
+            reference = recent.reference if recent else ''
         return reference
@@ -410,14 +407,14 @@ class ReferenceIndexingMixin(models.Model):
             })
         # Check that only 'allowed' keys are provided
-        for key in info.keys():
-            if key not in ctx.keys():
+        for key in info:
+            if key not in ctx:
                 raise ValidationError({
                     'value': _('Unknown format key specified') + f": '{key}'"
                 })
         # Check that the 'ref' variable is specified
-        if 'ref' not in info.keys():
+        if 'ref' not in info:
             raise ValidationError({
                 'value': _('Missing required format key') + ": 'ref'"
             })

View File

@@ -58,10 +58,7 @@ def isInServerThread():
     if 'runserver' in sys.argv:
         return True
-    if 'gunicorn' in sys.argv[0]:
-        return True
-    return False
+    return 'gunicorn' in sys.argv[0]
 def isInMainThread():
@@ -122,11 +119,7 @@ def canAppAccessDatabase(
     if not allow_plugins:
         excluded_commands.extend(['collectplugins'])
-    for cmd in excluded_commands:
-        if cmd in sys.argv:
-            return False
-    return True
+    return all(cmd not in sys.argv for cmd in excluded_commands)
 def isPluginRegistryLoaded():

View File

@@ -379,7 +379,7 @@ class InvenTreeTaggitSerializer(TaggitSerializer):
         tag_object = super().update(instance, validated_data)
-        for key in to_be_tagged.keys():
+        for key in to_be_tagged:
             # re-add the tagmanager
             new_tagobject = tag_object.__class__.objects.get(id=tag_object.id)
             setattr(tag_object, key, getattr(new_tagobject, key))
@@ -803,11 +803,10 @@ class DataFileExtractSerializer(serializers.Serializer):
             required = field.get('required', False)
             # Check for missing required columns
-            if required:
-                if name not in self.columns:
-                    raise serializers.ValidationError(
-                        _(f"Missing required column: '{name}'")
-                    )
+            if required and name not in self.columns:
+                raise serializers.ValidationError(
+                    _(f"Missing required column: '{name}'")
+                )
         for col in self.columns:
             if not col:

View File

@@ -949,10 +949,7 @@ USE_I18N = True
 # Do not use native timezone support in "test" mode
 # It generates a *lot* of cruft in the logs
-if not TESTING:
-    USE_TZ = True  # pragma: no cover
-else:
-    USE_TZ = False
+USE_TZ = bool(not TESTING)
 DATE_INPUT_FORMATS = ['%Y-%m-%d']

View File

@@ -410,10 +410,7 @@ def progress_bar(val, max_val, *args, **kwargs):
     else:
         style = ''
-    if max_val != 0:
-        percent = float(val / max_val) * 100
-    else:
-        percent = 0
+    percent = float(val / max_val) * 100 if max_val != 0 else 0
     if percent > 100:
         percent = 100

View File

@@ -285,7 +285,7 @@ class AjaxMixin(InvenTreeRoleMixin):
         # Custom feedback`data
         fb = self.get_data()
-        for key in fb.keys():
+        for key in fb:
             data[key] = fb[key]
         return JsonResponse(data, safe=False)

View File

@@ -244,10 +244,7 @@ class GlobalSettingsDetail(RetrieveUpdateAPI):
         """Attempt to find a global setting object with the provided key."""
         key = str(self.kwargs['key']).upper()
-        if (
-            key.startswith('_')
-            or key not in common.models.InvenTreeSetting.SETTINGS.keys()
-        ):
+        if key.startswith('_') or key not in common.models.InvenTreeSetting.SETTINGS:
             raise NotFound()
         return common.models.InvenTreeSetting.get_setting_object(
@@ -318,7 +315,7 @@ class UserSettingsDetail(RetrieveUpdateAPI):
         if (
             key.startswith('_')
-            or key not in common.models.InvenTreeUserSetting.SETTINGS.keys()
+            or key not in common.models.InvenTreeUserSetting.SETTINGS
         ):
             raise NotFound()

View File

@@ -83,10 +83,7 @@ class MatchItemForm(forms.Form):
         if 'file_manager' in kwargs:
             file_manager = kwargs.pop('file_manager')
-        if 'row_data' in kwargs:
-            row_data = kwargs.pop('row_data')
-        else:
-            row_data = None
+        row_data = kwargs.pop('row_data') if 'row_data' in kwargs else None
         super().__init__(*args, **kwargs)

View File

@@ -3090,13 +3090,10 @@ class CustomUnit(models.Model):
         """Ensure that the custom unit is unique."""
         super().validate_unique(exclude)
-        if self.symbol:
-            if (
-                CustomUnit.objects.filter(symbol=self.symbol)
-                .exclude(pk=self.pk)
-                .exists()
-            ):
-                raise ValidationError({'symbol': _('Unit symbol must be unique')})
+        if self.symbol and (
+            CustomUnit.objects.filter(symbol=self.symbol).exclude(pk=self.pk).exists()
+        ):
+            raise ValidationError({'symbol': _('Unit symbol must be unique')})
     def clean(self):
         """Validate that the provided custom unit is indeed valid."""

View File

@@ -143,11 +143,7 @@ class NotificationMethod:
         # Check if method globally enabled
         plg_instance = registry.get_plugin(plg_cls.NAME.lower())
-        if plg_instance and not plg_instance.get_setting(self.GLOBAL_SETTING):
-            return True
-        # Lets go!
-        return False
+        return plg_instance and not plg_instance.get_setting(self.GLOBAL_SETTING)
     def usersetting(self, target):
         """Returns setting for this method for a given user."""

View File

@@ -547,9 +547,8 @@ class AttachmentSerializer(InvenTreeModelSerializer):
         model_type = self.validated_data.get('model_type', None)
-        if model_type is None:
-            if self.instance:
-                model_type = self.instance.model_type
+        if model_type is None and self.instance:
+            model_type = self.instance.model_type
         # Ensure that the user has permission to attach files to the specified model
         user = self.context.get('request').user

View File

@@ -392,7 +392,7 @@ class SettingsTest(InvenTreeTestCase):
             'before_save',
         ]
-        for k in setting.keys():
+        for k in setting:
             self.assertIn(k, allowed_keys)
         # Check default value for boolean settings
@@ -430,7 +430,7 @@ class SettingsTest(InvenTreeTestCase):
     @override_settings(SITE_URL=None, PLUGIN_TESTING=True, PLUGIN_TESTING_SETUP=True)
     def test_defaults(self):
         """Populate the settings with default values."""
-        for key in InvenTreeSetting.SETTINGS.keys():
+        for key in InvenTreeSetting.SETTINGS:
             value = InvenTreeSetting.get_setting_default(key)
             try:
@@ -519,7 +519,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
         response = self.get(url, expected_code=200)
         n_public_settings = len([
-            k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_')
+            k for k in InvenTreeSetting.SETTINGS if not k.startswith('_')
         ])
         # Number of results should match the number of settings

View File

@@ -362,10 +362,7 @@ class FileManagementFormView(MultiStepFormView):
         duplicates = []
         for col in self.column_names:
-            if col in self.column_selections:
-                guess = self.column_selections[col]
-            else:
-                guess = None
+            guess = self.column_selections.get(col, None)
             if guess:
                 n = list(self.column_selections.values()).count(

View File

@@ -46,10 +46,7 @@ def rename_company_image(instance, filename):
     """
     base = 'company_images'
-    if filename.count('.') > 0:
-        ext = filename.split('.')[-1]
-    else:
-        ext = ''
+    ext = filename.split('.')[-1] if filename.count('.') > 0 else ''
     fn = f'company_{instance.pk}_img'

View File

@@ -86,9 +86,7 @@ class StatusCode(BaseEnum):
             return False
         if callable(value):
             return False
-        if not isinstance(value.value, int):
-            return False
-        return True
+        return isinstance(value.value, int)
     @classmethod
     def values(cls, key=None):

View File

@@ -209,7 +209,7 @@ class DataImportSession(models.Model):
         missing_fields = []
-        for field in required_fields.keys():
+        for field in required_fields:
             # An override value exists
             if field in field_overrides:
                 continue

View File

@@ -113,9 +113,8 @@ def get_field_label(field) -> str:
     Returns:
         str: Field label
     """
-    if field:
-        if label := getattr(field, 'label', None):
-            return label
+    if field and (label := getattr(field, 'label', None)):
+        return label
     # TODO: Check if the field is a model field

View File

@@ -36,7 +36,7 @@ def validate_importer_model_type(value):
     """Validate that the given model type is supported for importing."""
     from importer.registry import supported_models
-    if value not in supported_models().keys():
+    if value not in supported_models():
         raise ValidationError(f"Unsupported model type '{value}'")

View File

@@ -1356,10 +1356,9 @@ class OrderCalendarExport(ICalFeed):
                 if auth[0].lower() == 'basic':
                     uname, passwd = base64.b64decode(auth[1]).decode('ascii').split(':')
                     user = authenticate(username=uname, password=passwd)
-                    if user is not None:
-                        if user.is_active:
-                            login(request, user)
-                            request.user = user
+                    if user is not None and user.is_active:
+                        login(request, user)
+                        request.user = user
         # Check again
         if request.user.is_authenticated:

View File

@@ -1047,19 +1047,11 @@ class SalesOrder(TotalPriceMixin, Order):
     def is_fully_allocated(self):
         """Return True if all line items are fully allocated."""
-        for line in self.lines.all():
-            if not line.is_fully_allocated():
-                return False
-        return True
+        return all(line.is_fully_allocated() for line in self.lines.all())
     def is_overallocated(self):
         """Return true if any lines in the order are over-allocated."""
-        for line in self.lines.all():
-            if line.is_overallocated():
-                return True
-        return False
+        return any(line.is_overallocated() for line in self.lines.all())
     def is_completed(self):
         """Check if this order is "shipped" (all line items delivered)."""

View File

@@ -740,13 +740,12 @@ class PurchaseOrderLineItemReceiveSerializer(serializers.Serializer):
         base_quantity = line_item.part.base_quantity(quantity)
         # Does the quantity need to be "integer" (for trackable parts?)
-        if base_part.trackable:
-            if Decimal(base_quantity) != int(base_quantity):
-                raise ValidationError({
-                    'quantity': _(
-                        'An integer quantity must be provided for trackable parts'
-                    )
-                })
+        if base_part.trackable and Decimal(base_quantity) != int(base_quantity):
+            raise ValidationError({
+                'quantity': _(
+                    'An integer quantity must be provided for trackable parts'
+                )
+            })
         # If serial numbers are provided
         if serial_numbers:

View File

@@ -1527,7 +1527,7 @@ class SalesOrderTest(OrderTest):
             self.download_file(
                 reverse('api-so-list'),
                 {'export': fmt},
-                decode=True if fmt == 'csv' else False,
+                decode=fmt == 'csv',
                 expected_code=200,
                 expected_fn=r'InvenTree_SalesOrder_.+',
             )

View File

@@ -167,10 +167,9 @@ class CategoryFilter(rest_filters.FilterSet):
         top_level = str2bool(self.data.get('top_level', None))
         # If the parent is *not* provided, update the results based on the "cascade" value
-        if not parent or top_level:
-            if not value:
-                # If "cascade" is False, only return top-level categories
-                queryset = queryset.filter(parent=None)
+        if (not parent or top_level) and not value:
+            # If "cascade" is False, only return top-level categories
+            queryset = queryset.filter(parent=None)
         return queryset

View File

@@ -244,10 +244,7 @@ def ExportBom(
             manufacturer_name = ''
         # Extract the "MPN" field from the Manufacturer Part
-        if mp_part:
-            manufacturer_mpn = mp_part.MPN
-        else:
-            manufacturer_mpn = ''
+        manufacturer_mpn = mp_part.MPN if mp_part else ''
         # Generate a column name for this manufacturer
         k_man = f'{_("Manufacturer")}_{mp_idx}'
@@ -270,10 +267,7 @@ def ExportBom(
         else:
             supplier_name = ''
-        if sp_part:
-            supplier_sku = sp_part.SKU
-        else:
-            supplier_sku = ''
+        supplier_sku = sp_part.SKU if sp_part else ''
         # Generate column names for this supplier
         k_sup = (
@@ -307,10 +301,7 @@ def ExportBom(
             supplier_parts_used.add(sp_part)
-            if sp_part.supplier:
-                supplier_name = sp_part.supplier.name
-            else:
-                supplier_name = ''
+            supplier_name = sp_part.supplier.name if sp_part.supplier else ''
             supplier_sku = sp_part.SKU

View File

@@ -230,10 +230,7 @@ class PartCategory(InvenTree.models.InvenTreeTree):
         """Get all unique parameter names for all parts from this category."""
        unique_parameters_names = []
-        if prefetch:
-            parts = prefetch
-        else:
-            parts = self.prefetch_parts_parameters(cascade=cascade)
+        parts = prefetch or self.prefetch_parts_parameters(cascade=cascade)
         for part in parts:
             for parameter in part.parameters.all():
@@ -247,10 +244,7 @@ class PartCategory(InvenTree.models.InvenTreeTree):
         """Get all parameter names and values for all parts from this category."""
         category_parameters = []
-        if prefetch:
-            parts = prefetch
-        else:
-            parts = self.prefetch_parts_parameters(cascade=cascade)
+        parts = prefetch or self.prefetch_parts_parameters(cascade=cascade)
         for part in parts:
             part_parameters = {
@@ -917,24 +911,26 @@ class Part(
                     'IPN': _('Duplicate IPN not allowed in part settings')
                 })
-        if self.revision_of and self.revision:
-            if (
-                Part.objects.exclude(pk=self.pk)
-                .filter(revision_of=self.revision_of, revision=self.revision)
-                .exists()
-            ):
-                raise ValidationError(_('Duplicate part revision already exists.'))
+        if (
+            self.revision_of
+            and self.revision
+            and (
+                Part.objects.exclude(pk=self.pk)
+                .filter(revision_of=self.revision_of, revision=self.revision)
+                .exists()
+            )
+        ):
+            raise ValidationError(_('Duplicate part revision already exists.'))
         # Ensure unique across (Name, revision, IPN) (as specified)
-        if self.revision or self.IPN:
-            if (
-                Part.objects.exclude(pk=self.pk)
-                .filter(name=self.name, revision=self.revision, IPN=self.IPN)
-                .exists()
-            ):
-                raise ValidationError(
-                    _('Part with this Name, IPN and Revision already exists.')
-                )
+        if (self.revision or self.IPN) and (
+            Part.objects.exclude(pk=self.pk)
+            .filter(name=self.name, revision=self.revision, IPN=self.IPN)
+            .exists()
+        ):
+            raise ValidationError(
+                _('Part with this Name, IPN and Revision already exists.')
+            )
     def clean(self):
         """Perform cleaning operations for the Part model.
@@ -3872,16 +3868,18 @@ class PartParameter(InvenTree.models.InvenTreeMetadataModel):
         super().clean()
         # Validate the parameter data against the template units
-        if get_global_setting(
-            'PART_PARAMETER_ENFORCE_UNITS', True, cache=False, create=False
-        ):
-            if self.template.units:
-                try:
-                    InvenTree.conversion.convert_physical_value(
-                        self.data, self.template.units
-                    )
-                except ValidationError as e:
-                    raise ValidationError({'data': e.message})
+        if (
+            get_global_setting(
+                'PART_PARAMETER_ENFORCE_UNITS', True, cache=False, create=False
+            )
+            and self.template.units
+        ):
+            try:
+                InvenTree.conversion.convert_physical_value(
+                    self.data, self.template.units
+                )
+            except ValidationError as e:
+                raise ValidationError({'data': e.message})
         # Validate the parameter data against the template choices
         if choices := self.template.get_choices():
@@ -4189,9 +4187,8 @@ class BomItem(
         # Check if the part was changed
         deltas = self.get_field_deltas()
-        if 'part' in deltas:
-            if old_part := deltas['part'].get('old', None):
-                self.check_part_lock(old_part)
+        if 'part' in deltas and (old_part := deltas['part'].get('old', None)):
+            self.check_part_lock(old_part)
         # Update the 'validated' field based on checksum calculation
         self.validated = self.is_line_valid

View File

@@ -63,7 +63,7 @@ class PluginFilter(rest_filters.FilterSet):
         match = True
         for mixin in mixins:
-            if mixin not in result.mixins().keys():
+            if mixin not in result.mixins():
                 match = False
                 break

View File

@@ -335,9 +335,8 @@ class BarcodePOAllocate(BarcodeView):
         if not part and not supplier_part and not manufacturer_part:
             raise ValidationError({'error': _('No matching part data found')})
-        if part:
-            if part_id := part.get('pk', None):
-                supplier_parts = supplier_parts.filter(part__pk=part_id)
+        if part and (part_id := part.get('pk', None)):
+            supplier_parts = supplier_parts.filter(part__pk=part_id)
         if supplier_part:
             if supplier_part_id := supplier_part.get('pk', None):

View File

@@ -41,7 +41,7 @@ class BarcodeGenerateSerializer(serializers.Serializer):
             plugin.base.barcodes.helper.get_supported_barcode_models_map()
         )
-        if model not in supported_models.keys():
+        if model not in supported_models:
             raise ValidationError(_('Model is not supported'))
         return model

View File

@@ -147,10 +147,7 @@ def allow_table_event(table_name):
         'part_partstocktakereport',
     ]
-    if table_name in ignore_tables:
-        return False
-    return True
+    return not table_name in ignore_tables
 @receiver(post_save)

View File

@@ -148,7 +148,7 @@ class ScheduleMixin:
     def get_task_names(self):
         """All defined task names."""
         # Returns a list of all task names associated with this plugin instance
-        return [self.get_task_name(key) for key in self.scheduled_tasks.keys()]
+        return [self.get_task_name(key) for key in self.scheduled_tasks]
     def register_tasks(self):
         """Register the tasks with the database."""

View File

@@ -216,10 +216,7 @@ def install_plugin(url=None, packagename=None, user=None, version=None):
         identifier in url for identifier in ['git+https', 'hg+https', 'svn+svn']
     ]:
         # using a VCS provider
-        if packagename:
-            full_pkg = f'{packagename}@{url}'
-        else:
-            full_pkg = url
+        full_pkg = f'{packagename}@{url}' if packagename else url
     elif url:
         install_name.append('-i')
         full_pkg = url
@@ -244,10 +241,9 @@ def install_plugin(url=None, packagename=None, user=None, version=None):
         ret['result'] = ret['success'] = _('Installed plugin successfully')
         ret['output'] = str(result, 'utf-8')
-        if packagename:
-            if path := check_package_path(packagename):
-                # Override result information
-                ret['result'] = _(f'Installed plugin into {path}')
+        if packagename and (path := check_package_path(packagename)):
+            # Override result information
+            ret['result'] = _(f'Installed plugin into {path}')
     except subprocess.CalledProcessError as error:
         handle_pip_error(error, 'plugin_install')

View File

@@ -39,7 +39,7 @@ def collect_plugins_static_files():
     logger.info('Collecting static files for all installed plugins.')
-    for slug in registry.plugins.keys():
+    for slug in registry.plugins:
         copy_plugin_static_files(slug, check_reload=False)

View File

@@ -112,10 +112,7 @@ def plugin_static(context, file: str, **kwargs):
     """
     plugin = context.get('plugin', None)
-    if plugin:
-        plugin = plugin.slug
-    else:
-        plugin = kwargs.get('plugin', None)
+    plugin = plugin.slug if plugin else kwargs.get('plugin', None)
     if not plugin:
         return file

View File

@@ -59,10 +59,7 @@ if __name__ == '__main__':
         (total, covered, uncovered) = stats
-        if total > 0:
-            percentage = int(covered / total * 100)
-        else:
-            percentage = 0
+        percentage = int(covered / total * 100) if total > 0 else 0
         if verbose:
             print(f"| {locale.ljust(4, ' ')} : {str(percentage).rjust(4, ' ')}% |")
@@ -78,9 +75,6 @@ if __name__ == '__main__':
     with open(STAT_FILE, 'w') as target:
         json.dump(locales_perc, target)
-    if len(percentages) > 0:
-        avg = int(sum(percentages) / len(percentages))
-    else:
-        avg = 0
+    avg = int(sum(percentages) / len(percentages)) if len(percentages) > 0 else 0
     print(f'InvenTree translation coverage: {avg}%')

View File

@@ -354,10 +354,9 @@ class StockLocationFilter(rest_filters.FilterSet):
         top_level = str2bool(self.data.get('top_level', None))
         # If the parent is *not* provided, update the results based on the "cascade" value
-        if not parent or top_level:
-            if not value:
-                # If "cascade" is False, only return top-level location
-                queryset = queryset.filter(parent=None)
+        if (not parent or top_level) and not value:
+            # If "cascade" is False, only return top-level location
+            queryset = queryset.filter(parent=None)
         return queryset
@@ -1452,17 +1451,17 @@ class StockTrackingList(DataExportViewMixin, ListAPI):
         delta_models = self.get_delta_model_map()
         # Construct a set of related models we need to lookup for later
-        related_model_lookups = {key: set() for key in delta_models.keys()}
+        related_model_lookups = {key: set() for key in delta_models}
         # Run a first pass through the data to determine which related models we need to lookup
         for item in data:
             deltas = item['deltas'] or {}
-            for key in delta_models.keys():
+            for key in delta_models:
                 if key in deltas:
                     related_model_lookups[key].add(deltas[key])
-        for key in delta_models.keys():
+        for key in delta_models:
             model, serializer = delta_models[key]
             # Fetch all related models in one go

View File

@@ -1183,10 +1183,7 @@ class StockItem(
         if self.allocations.count() > 0:
             return True
-        if self.sales_order_allocations.count() > 0:
-            return True
-        return False
+        return self.sales_order_allocations.count() > 0
     def build_allocation_count(self):
         """Return the total quantity allocated to builds."""
@@ -1260,10 +1257,7 @@ class StockItem(
         if self.installed_item_count() > 0:
             return False
-        if self.sales_order is not None:
-            return False
-        return True
+        return not self.sales_order is not None
     def get_installed_items(self, cascade: bool = False) -> set[StockItem]:
         """Return all stock items which are *installed* in this one!
@@ -1422,10 +1416,7 @@ class StockItem(
         if self.belongs_to is not None:
             return False
-        if self.sales_order is not None:
-            return False
-        return True
+        return not self.sales_order is not None
     @property
     def tracking_info_count(self):
@@ -2204,7 +2195,7 @@ class StockItem(
         for item in installed_items:
             item_results = item.testResultMap()
-            for key in item_results.keys():
+            for key in item_results:
                 # Results from sub items should not override master ones
                 if key not in result_map:
                     result_map[key] = item_results[key]
@@ -2353,7 +2344,7 @@ class StockItemTracking(InvenTree.models.InvenTreeModel):
     def label(self):
         """Return label."""
-        if self.tracking_type in StockHistoryCode.keys():
+        if self.tracking_type in StockHistoryCode:
             return StockHistoryCode.label(self.tracking_type)
         return getattr(self, 'title', '')

View File

@@ -226,7 +226,7 @@ class TestTestResultMigration(MigratorTestCase):
         # Two more test templates should have been created
         self.assertEqual(PartTestTemplate.objects.count(), 3)
-        for k in self.test_keys.keys():
+        for k in self.test_keys:
             self.assertTrue(PartTestTemplate.objects.filter(key=k).exists())
         for result in StockItemTestResult.objects.all():

View File

@@ -32,10 +32,7 @@ class ApiTokenAdmin(admin.ModelAdmin):
     def get_fields(self, request, obj=None):
         """Return list of fields to display."""
-        if obj:
-            fields = ['token']
-        else:
-            fields = ['key']
+        fields = ['token'] if obj else ['key']
         fields += [
             'user',

View File

@@ -489,10 +489,7 @@ def split_model(model):
     *app, model = model.split('_')
     # handle models that have
-    if len(app) > 1:
-        app = '_'.join(app)
-    else:
-        app = app[0]
+    app = '_'.join(app) if len(app) > 1 else app[0]
     return model, app
@@ -674,7 +671,7 @@ def clear_user_role_cache(user: User):
     Args:
         user: The User object to be expunged from the cache
     """
-    for role in RuleSet.get_ruleset_models().keys():
+    for role in RuleSet.get_ruleset_models():
         for perm in ['add', 'change', 'view', 'delete']:
             key = f'role_{user.pk}_{role}_{perm}'
             cache.delete(key)

View File

@@ -62,7 +62,7 @@ class RuleSetModelTest(TestCase):
         assigned_models = set()
         # Now check that each defined model is a valid table name
-        for key in RuleSet.get_ruleset_models().keys():
+        for key in RuleSet.get_ruleset_models():
             models = RuleSet.get_ruleset_models()[key]
             for m in models: