diff --git a/InvenTree/key.py b/InvenTree/key.py index c28b9c6a30..edcdac0854 100644 --- a/InvenTree/key.py +++ b/InvenTree/key.py @@ -3,8 +3,11 @@ import random import string import os +import sys +import argparse -fn = 'secret_key.txt' +KEY_FN = 'secret_key.txt' +KEY_DIR = os.path.dirname(os.path.realpath(__file__)) def generate_key(): @@ -14,11 +17,29 @@ def generate_key(): if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='Generate Django SECRET_KEY file') + parser.add_argument('--output', help='Specify key file path', default=None) + parser.add_argument('--force', '-f', help='Override key file (if it exists)', action='store_true') + parser.add_argument('--dummy', '-d', help='Dummy run (display key only', action='store_true') + + args = parser.parse_args() - # Ensure key file is placed in same directory as this script - path = os.path.dirname(os.path.realpath(__file__)) - key_file = os.path.join(path, fn) + if args.output: + key_filename = args.output + else: + key_filename = os.path.join(KEY_DIR, KEY_FN) - with open(key_file, 'w') as kf: - kf.write(generate_key()) - print('Generated SECRET_KEY to {f}'.format(f=key_file)) + key_data = generate_key() + + if args.dummy: + print('SECRET_KEY: {k}'.format(k=key_data)) + sys.exit(0) + + if not args.force and os.path.exists(key_filename): + print("Key file already exists - '{f}'".format(f=key_filename)) + sys.exit(0) + + with open(key_filename, 'w') as key_file: + print("Generating SECRET_KEY file - '{f}'".format(f=key_filename)) + key_file.write(key_data) diff --git a/InvenTree/part/api.py b/InvenTree/part/api.py index 96606c643a..e600a303bb 100644 --- a/InvenTree/part/api.py +++ b/InvenTree/part/api.py @@ -55,6 +55,11 @@ class CategoryList(generics.ListCreateAPIView): ] +class CategoryDetail(generics.RetrieveUpdateDestroyAPIView): + serializer_class = CategorySerializer + queryset = PartCategory.objects.all() + + class PartDetail(generics.RetrieveUpdateDestroyAPIView): queryset = Part.objects.all() serializer_class = PartSerializer @@ -205,6 +210,9 @@ class SupplierPriceBreakList(generics.ListCreateAPIView): cat_api_urls = [ + + url(r'^(?P\d+)/?', CategoryDetail.as_view(), name='api-part-category-detail'), + url(r'^$', CategoryList.as_view(), name='api-part-category-list'), ] diff --git a/InvenTree/part/test_api.py b/InvenTree/part/test_api.py new file mode 100644 index 0000000000..6546ea7109 --- /dev/null +++ b/InvenTree/part/test_api.py @@ -0,0 +1,115 @@ +from rest_framework.test import APITestCase +from rest_framework import status +from django.urls import reverse +from django.contrib.auth import get_user_model + +from .models import Part, PartCategory + + +class BomAPITest(APITestCase): + + def setUp(self): + # Create a user for auth + User = get_user_model() + User.objects.create_user('testuser', 'test@testing.com', 'password') + + self.client.login(username='testuser', password='password') + + def test_category_list_empty(self): + # Check that we can retrieve an (empty) category list + url = reverse('api-part-category-list') + response = self.client.get(url, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 0) + + def test_add_categories(self): + # Check that we can add categories + data = { + 'name': 'Animals', + 'description': 'All animals go here' + } + + url = reverse('api-part-category-list') + response = self.client.post(url, data=data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['pk'], 1) + + # Add some sub-categories to the top-level 'Animals' category + for animal in ['cat', 'dog', 'zebra']: + data = { + 'name': animal, + 'description': 'A sort of animal', + 'parent': 1, + } + response = self.client.post(url, data=data, format='json') + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.data['parent'], 1) + self.assertEqual(response.data['name'], animal) + self.assertEqual(response.data['pathstring'], 'Animals/' + animal) + + # There should be now 4 categories + response = self.client.get(url, format='json') + self.assertEqual(len(response.data), 4) + + +class PartAPITest(APITestCase): + + def setUp(self): + # Create a user for auth + User = get_user_model() + User.objects.create_user('testuser', 'test@testing.com', 'password') + + self.client.login(username='testuser', password='password') + + # Create some test data + TOP = PartCategory.objects.create(name='Top', description='Top level category') + + A = PartCategory.objects.create(name='A', description='Cat A', parent=TOP) + B = PartCategory.objects.create(name='B', description='Cat B', parent=TOP) + C = PartCategory.objects.create(name='C', description='Cat C', parent=TOP) + + Part.objects.create(name='Top.t', description='t in TOP', category=TOP) + + Part.objects.create(name='A.a', description='a in A', category=A) + Part.objects.create(name='B.b', description='b in B', category=B) + Part.objects.create(name='C.c1', description='c1 in C', category=C) + Part.objects.create(name='C.c2', description='c2 in C', category=C) + + def test_get_all_parts(self): + url = reverse('api-part-list') + response = self.client.get(url, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 5) + + def test_get_parts_by_cat(self): + url = reverse('api-part-list') + data = {'category': 4} + response = self.client.get(url, data=data, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # There should only be 2 objects in category C + self.assertEqual(len(response.data), 2) + + for part in response.data: + self.assertEqual(part['category'], 4) + + def test_include_children(self): + """ Test the special 'include_child_categories' flag + If provided, parts are provided for ANY child category (recursive) + """ + url = reverse('api-part-list') + data = {'category': 1} + + response = self.client.get(url, data=data, format='json') + + # There should be 1 part in this category + self.assertEqual(len(response.data), 1) + + data['include_child_categories'] = 1 + + # Now request to include child categories + response = self.client.get(url, data=data, format='json') + + # Now there should be 5 total parts + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 5)