diff --git a/InvenTree/InvenTree/test_api.py b/InvenTree/InvenTree/test_api.py index e75a111e5d..7068593e6b 100644 --- a/InvenTree/InvenTree/test_api.py +++ b/InvenTree/InvenTree/test_api.py @@ -4,9 +4,10 @@ from rest_framework.test import APITestCase from rest_framework import status from django.urls import reverse - from django.contrib.auth import get_user_model +from base64 import b64encode + class APITests(APITestCase): """ Tests for the InvenTree API """ @@ -21,24 +22,48 @@ class APITests(APITestCase): username = 'test_user' password = 'test_pass' + token = None + def setUp(self): # Create a user (but do not log in!) User = get_user_model() User.objects.create_user(self.username, 'user@email.com', self.password) - def get_token(self): + def basicAuth(self): + # Use basic authentication + + authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii") + + # Use "basic" auth by default + auth = b64encode(authstring).decode("ascii") + self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth)) + + def tokenAuth(self): + + self.basicAuth() token_url = reverse('api-token') + response = self.client.get(token_url, format='json', data={}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('token', response.data) - # POST to retreive a token - response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password}) - token = response.data['token'] - - self.client.credentials(HTTP_AUTHORIZATION='Token ' + token) - self.token = token + def token_failure(self): + # Test token endpoint without basic auth + url = reverse('api-token') + response = self.client.get(url, format='json') + + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + self.assertIsNone(self.token) + + def token_success(self): + + self.tokenAuth() + self.assertIsNotNone(self.token) + def test_info_view(self): """ Test that we can read the 'info-view' endpoint. @@ -55,51 +80,18 @@ class APITests(APITestCase): self.assertEquals('InvenTree', data['server']) - def test_get_token_fail(self): - """ Ensure that an invalid user cannot get a token """ - - token_url = reverse('api-token') - - response = self.client.post(token_url, format='json', data={'username': 'bad', 'password': 'also_bad'}) - - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertFalse('token' in response.data) - - def test_get_token_pass(self): - """ Ensure that a valid user can request an API token """ - - token_url = reverse('api-token') - - # POST to retreive a token - response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password}) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertTrue('token' in response.data) - self.assertTrue('pk' in response.data) - self.assertTrue(len(response.data['token']) > 0) - - # Now, use the token to access other data - token = response.data['token'] - - part_url = reverse('api-part-list') - - # Try to access without a token - response = self.client.get(part_url, format='json') + def test_barcode_fail(self): + # Test barcode endpoint without auth + response = self.client.post(reverse('api-barcode-plugin'), format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - - # Now, with the token - self.client.credentials(HTTP_AUTHORIZATION='Token ' + token) - response = self.client.get(part_url, format='json') - - self.assertEqual(response.status_code, status.HTTP_200_OK) def test_barcode(self): """ Test the barcode endpoint """ - url = reverse('api-barcode-plugin') + self.tokenAuth() - self.get_token() + url = reverse('api-barcode-plugin') data = { 'barcode': { diff --git a/InvenTree/users/views.py b/InvenTree/users/views.py index 01d9834c69..97e5f48355 100644 --- a/InvenTree/users/views.py +++ b/InvenTree/users/views.py @@ -3,7 +3,7 @@ from django.contrib.auth.models import User from django.core.exceptions import ObjectDoesNotExist from .serializers import UserSerializer -from rest_framework.authtoken.views import ObtainAuthToken +from rest_framework.views import APIView from rest_framework.authtoken.models import Token from rest_framework.response import Response from rest_framework import status @@ -25,28 +25,32 @@ class UserList(generics.ListAPIView): permission_classes = (permissions.IsAuthenticated,) -class GetAuthToken(ObtainAuthToken): +class GetAuthToken(APIView): """ Return authentication token for an authenticated user. """ - def post(self, request, *args, **kwargs): + permission_classes = [ + permissions.IsAuthenticated, + ] + + def get(self, request, *args, **kwargs): return self.login(request) def delete(self, request): return self.logout(request) def login(self, request): - serializer = self.serializer_class(data=request.data, - context={'request': request}) - serializer.is_valid(raise_exception=True) - user = serializer.validated_data['user'] - token, created = Token.objects.get_or_create(user=user) - return Response({ - 'token': token.key, - 'pk': user.pk, - 'username': user.username, - 'email': user.email - }) + if request.user.is_authenticated: + # Get the user token (or create one if it does not exist) + token, created = Token.objects.get_or_create(user=request.user) + return Response({ + 'token': token.key, + }) + + else: + return Response({ + 'error': 'User not authenticated', + }) def logout(self, request): try: