mirror of
https://github.com/inventree/InvenTree
synced 2024-08-30 18:33:04 +00:00
Merge pull request #735 from SchrodingersGat/token-auth-fix
Improvements for token authentication
This commit is contained in:
commit
4f3f78f55a
@ -4,9 +4,10 @@ from rest_framework.test import APITestCase
|
|||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
|
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
|
|
||||||
|
from base64 import b64encode
|
||||||
|
|
||||||
|
|
||||||
class APITests(APITestCase):
|
class APITests(APITestCase):
|
||||||
""" Tests for the InvenTree API """
|
""" Tests for the InvenTree API """
|
||||||
@ -21,24 +22,48 @@ class APITests(APITestCase):
|
|||||||
username = 'test_user'
|
username = 'test_user'
|
||||||
password = 'test_pass'
|
password = 'test_pass'
|
||||||
|
|
||||||
|
token = None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
||||||
# Create a user (but do not log in!)
|
# Create a user (but do not log in!)
|
||||||
User = get_user_model()
|
User = get_user_model()
|
||||||
User.objects.create_user(self.username, 'user@email.com', self.password)
|
User.objects.create_user(self.username, 'user@email.com', self.password)
|
||||||
|
|
||||||
def get_token(self):
|
def basicAuth(self):
|
||||||
|
# Use basic authentication
|
||||||
|
|
||||||
|
authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii")
|
||||||
|
|
||||||
|
# Use "basic" auth by default
|
||||||
|
auth = b64encode(authstring).decode("ascii")
|
||||||
|
self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth))
|
||||||
|
|
||||||
|
def tokenAuth(self):
|
||||||
|
|
||||||
|
self.basicAuth()
|
||||||
token_url = reverse('api-token')
|
token_url = reverse('api-token')
|
||||||
|
response = self.client.get(token_url, format='json', data={})
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertIn('token', response.data)
|
||||||
|
|
||||||
# POST to retreive a token
|
|
||||||
response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
|
|
||||||
|
|
||||||
token = response.data['token']
|
token = response.data['token']
|
||||||
|
|
||||||
self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
|
|
||||||
|
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
|
def token_failure(self):
|
||||||
|
# Test token endpoint without basic auth
|
||||||
|
url = reverse('api-token')
|
||||||
|
response = self.client.get(url, format='json')
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||||
|
self.assertIsNone(self.token)
|
||||||
|
|
||||||
|
def token_success(self):
|
||||||
|
|
||||||
|
self.tokenAuth()
|
||||||
|
self.assertIsNotNone(self.token)
|
||||||
|
|
||||||
def test_info_view(self):
|
def test_info_view(self):
|
||||||
"""
|
"""
|
||||||
Test that we can read the 'info-view' endpoint.
|
Test that we can read the 'info-view' endpoint.
|
||||||
@ -55,51 +80,18 @@ class APITests(APITestCase):
|
|||||||
|
|
||||||
self.assertEquals('InvenTree', data['server'])
|
self.assertEquals('InvenTree', data['server'])
|
||||||
|
|
||||||
def test_get_token_fail(self):
|
def test_barcode_fail(self):
|
||||||
""" Ensure that an invalid user cannot get a token """
|
# Test barcode endpoint without auth
|
||||||
|
response = self.client.post(reverse('api-barcode-plugin'), format='json')
|
||||||
token_url = reverse('api-token')
|
|
||||||
|
|
||||||
response = self.client.post(token_url, format='json', data={'username': 'bad', 'password': 'also_bad'})
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
|
|
||||||
self.assertFalse('token' in response.data)
|
|
||||||
|
|
||||||
def test_get_token_pass(self):
|
|
||||||
""" Ensure that a valid user can request an API token """
|
|
||||||
|
|
||||||
token_url = reverse('api-token')
|
|
||||||
|
|
||||||
# POST to retreive a token
|
|
||||||
response = self.client.post(token_url, format='json', data={'username': self.username, 'password': self.password})
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
||||||
self.assertTrue('token' in response.data)
|
|
||||||
self.assertTrue('pk' in response.data)
|
|
||||||
self.assertTrue(len(response.data['token']) > 0)
|
|
||||||
|
|
||||||
# Now, use the token to access other data
|
|
||||||
token = response.data['token']
|
|
||||||
|
|
||||||
part_url = reverse('api-part-list')
|
|
||||||
|
|
||||||
# Try to access without a token
|
|
||||||
response = self.client.get(part_url, format='json')
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
# Now, with the token
|
|
||||||
self.client.credentials(HTTP_AUTHORIZATION='Token ' + token)
|
|
||||||
response = self.client.get(part_url, format='json')
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
|
||||||
|
|
||||||
def test_barcode(self):
|
def test_barcode(self):
|
||||||
""" Test the barcode endpoint """
|
""" Test the barcode endpoint """
|
||||||
|
|
||||||
url = reverse('api-barcode-plugin')
|
self.tokenAuth()
|
||||||
|
|
||||||
self.get_token()
|
url = reverse('api-barcode-plugin')
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'barcode': {
|
'barcode': {
|
||||||
|
@ -3,7 +3,7 @@ from django.contrib.auth.models import User
|
|||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
from .serializers import UserSerializer
|
from .serializers import UserSerializer
|
||||||
|
|
||||||
from rest_framework.authtoken.views import ObtainAuthToken
|
from rest_framework.views import APIView
|
||||||
from rest_framework.authtoken.models import Token
|
from rest_framework.authtoken.models import Token
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
@ -25,28 +25,32 @@ class UserList(generics.ListAPIView):
|
|||||||
permission_classes = (permissions.IsAuthenticated,)
|
permission_classes = (permissions.IsAuthenticated,)
|
||||||
|
|
||||||
|
|
||||||
class GetAuthToken(ObtainAuthToken):
|
class GetAuthToken(APIView):
|
||||||
""" Return authentication token for an authenticated user. """
|
""" Return authentication token for an authenticated user. """
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
permission_classes = [
|
||||||
|
permissions.IsAuthenticated,
|
||||||
|
]
|
||||||
|
|
||||||
|
def get(self, request, *args, **kwargs):
|
||||||
return self.login(request)
|
return self.login(request)
|
||||||
|
|
||||||
def delete(self, request):
|
def delete(self, request):
|
||||||
return self.logout(request)
|
return self.logout(request)
|
||||||
|
|
||||||
def login(self, request):
|
def login(self, request):
|
||||||
serializer = self.serializer_class(data=request.data,
|
|
||||||
context={'request': request})
|
|
||||||
serializer.is_valid(raise_exception=True)
|
|
||||||
user = serializer.validated_data['user']
|
|
||||||
token, created = Token.objects.get_or_create(user=user)
|
|
||||||
|
|
||||||
return Response({
|
if request.user.is_authenticated:
|
||||||
'token': token.key,
|
# Get the user token (or create one if it does not exist)
|
||||||
'pk': user.pk,
|
token, created = Token.objects.get_or_create(user=request.user)
|
||||||
'username': user.username,
|
return Response({
|
||||||
'email': user.email
|
'token': token.key,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
else:
|
||||||
|
return Response({
|
||||||
|
'error': 'User not authenticated',
|
||||||
|
})
|
||||||
|
|
||||||
def logout(self, request):
|
def logout(self, request):
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user