mirror of
https://github.com/inventree/InvenTree
synced 2024-08-30 18:33:04 +00:00
Merge pull request #735 from SchrodingersGat/token-auth-fix
Improvements for token authentication
This commit is contained in:
commit
4f3f78f55a
@ -4,9 +4,10 @@ from rest_framework.test import APITestCase
|
||||
from rest_framework import status
|
||||
|
||||
from django.urls import reverse
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
from base64 import b64encode
|
||||
|
||||
|
||||
class APITests(APITestCase):
|
||||
""" Tests for the InvenTree API """
|
||||
@ -21,24 +22,48 @@ class APITests(APITestCase):
|
||||
username = 'test_user'
|
||||
password = 'test_pass'
|
||||
|
||||
token = None
|
||||
|
||||
def setUp(self):
|
||||
|
||||
# Create a user (but do not log in!)
|
||||
User = get_user_model()
|
||||
User.objects.create_user(self.username, 'user@email.com', self.password)
|
||||
|
||||
def get_token(self):
|
||||
def basicAuth(self):
|
||||
# Use basic authentication
|
||||
|
||||
authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii")
|
||||
|
||||
# Use "basic" auth by default
|
||||
auth = b64encode(authstring).decode("ascii")
|
||||
self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth))
|
||||
|
||||
def tokenAuth(self):
|
||||
|
||||
self.basicAuth()
|
||||
token_url = reverse('api-token')
|
||||
response = self.client.get(token_url, format='json', data={})
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertIn('token', response.data)
|
||||
|
||||
# POST to retreive a token
|
||||
response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
|
||||
|
||||
token = response.data['token']
|
||||
|
||||
self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
|
||||
|
||||
self.token = token
|
||||
|
||||
def token_failure(self):
|
||||
# Test token endpoint without basic auth
|
||||
url = reverse('api-token')
|
||||
response = self.client.get(url, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
self.assertIsNone(self.token)
|
||||
|
||||
def token_success(self):
|
||||
|
||||
self.tokenAuth()
|
||||
self.assertIsNotNone(self.token)
|
||||
|
||||
def test_info_view(self):
|
||||
"""
|
||||
Test that we can read the 'info-view' endpoint.
|
||||
@ -55,51 +80,18 @@ class APITests(APITestCase):
|
||||
|
||||
self.assertEquals('InvenTree', data['server'])
|
||||
|
||||
def test_get_token_fail(self):
|
||||
""" Ensure that an invalid user cannot get a token """
|
||||
|
||||
token_url = reverse('api-token')
|
||||
|
||||
response = self.client.post(token_url, format='json', data={'username': 'bad', 'password': 'also_bad'})
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
||||
self.assertFalse('token' in response.data)
|
||||
|
||||
def test_get_token_pass(self):
|
||||
""" Ensure that a valid user can request an API token """
|
||||
|
||||
token_url = reverse('api-token')
|
||||
|
||||
# POST to retreive a token
|
||||
response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertTrue('token' in response.data)
|
||||
self.assertTrue('pk' in response.data)
|
||||
self.assertTrue(len(response.data['token']) > 0)
|
||||
|
||||
# Now, use the token to access other data
|
||||
token = response.data['token']
|
||||
|
||||
part_url = reverse('api-part-list')
|
||||
|
||||
# Try to access without a token
|
||||
response = self.client.get(part_url, format='json')
|
||||
def test_barcode_fail(self):
|
||||
# Test barcode endpoint without auth
|
||||
response = self.client.post(reverse('api-barcode-plugin'), format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
# Now, with the token
|
||||
self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
|
||||
response = self.client.get(part_url, format='json')
|
||||
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
|
||||
def test_barcode(self):
|
||||
""" Test the barcode endpoint """
|
||||
|
||||
url = reverse('api-barcode-plugin')
|
||||
self.tokenAuth()
|
||||
|
||||
self.get_token()
|
||||
url = reverse('api-barcode-plugin')
|
||||
|
||||
data = {
|
||||
'barcode': {
|
||||
|
@ -3,7 +3,7 @@ from django.contrib.auth.models import User
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from .serializers import UserSerializer
|
||||
|
||||
from rest_framework.authtoken.views import ObtainAuthToken
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import status
|
||||
@ -25,28 +25,32 @@ class UserList(generics.ListAPIView):
|
||||
permission_classes = (permissions.IsAuthenticated,)
|
||||
|
||||
|
||||
class GetAuthToken(ObtainAuthToken):
|
||||
class GetAuthToken(APIView):
|
||||
""" Return authentication token for an authenticated user. """
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
permission_classes = [
|
||||
permissions.IsAuthenticated,
|
||||
]
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
return self.login(request)
|
||||
|
||||
def delete(self, request):
|
||||
return self.logout(request)
|
||||
|
||||
def login(self, request):
|
||||
serializer = self.serializer_class(data=request.data,
|
||||
context={'request': request})
|
||||
serializer.is_valid(raise_exception=True)
|
||||
user = serializer.validated_data['user']
|
||||
token, created = Token.objects.get_or_create(user=user)
|
||||
|
||||
return Response({
|
||||
'token': token.key,
|
||||
'pk': user.pk,
|
||||
'username': user.username,
|
||||
'email': user.email
|
||||
})
|
||||
if request.user.is_authenticated:
|
||||
# Get the user token (or create one if it does not exist)
|
||||
token, created = Token.objects.get_or_create(user=request.user)
|
||||
return Response({
|
||||
'token': token.key,
|
||||
})
|
||||
|
||||
else:
|
||||
return Response({
|
||||
'error': 'User not authenticated',
|
||||
})
|
||||
|
||||
def logout(self, request):
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user