Fix login, add logout, add refresh for tokens and tests for all
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict # type: ignore
|
from pydantic_settings import BaseSettings, SettingsConfigDict # type: ignore
|
||||||
from passlib.context import CryptContext # type: ignore
|
from passlib.context import CryptContext # type: ignore
|
||||||
import pytz
|
import pytz
|
||||||
@@ -6,6 +7,7 @@ class Settings(BaseSettings):
|
|||||||
PROJECT_NAME: str = "StoneEdge Asset Management System"
|
PROJECT_NAME: str = "StoneEdge Asset Management System"
|
||||||
PROJECT_VERSION: str = "0.0.1"
|
PROJECT_VERSION: str = "0.0.1"
|
||||||
PROJECT_SUMMARY: str = "Product API for StoneEdge."
|
PROJECT_SUMMARY: str = "Product API for StoneEdge."
|
||||||
|
PROJECT_PUBLIC_URL: str = ""
|
||||||
SECRET_KEY: str | None = None
|
SECRET_KEY: str | None = None
|
||||||
PSQL_USERNAME: str = "user"
|
PSQL_USERNAME: str = "user"
|
||||||
PSQL_PASSWORD: str = "password"
|
PSQL_PASSWORD: str = "password"
|
||||||
@@ -16,8 +18,8 @@ class Settings(BaseSettings):
|
|||||||
ACCESS_TOKEN_EXPIRE_MIN: int = 30
|
ACCESS_TOKEN_EXPIRE_MIN: int = 30
|
||||||
REFRESH_TOKEN_EXPIRE_MIN: int = 60
|
REFRESH_TOKEN_EXPIRE_MIN: int = 60
|
||||||
BACKEND_CORS_ORIGINS: list = ["*"]
|
BACKEND_CORS_ORIGINS: list = ["*"]
|
||||||
DEFAULT_TIMEZONE: str = pytz.UTC._tzname
|
|
||||||
CRYPT: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
CRYPT: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
OAUTH2_SCHEME: OAuth2PasswordBearer = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env")
|
model_config = SettingsConfigDict(env_file=".env")
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import pytz
|
||||||
from tortoise.models import Model
|
from tortoise.models import Model
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
import uuid
|
import uuid
|
||||||
@@ -21,8 +22,8 @@ class Token(Model, CMDMixin):
|
|||||||
refresh_token: str = fields.TextField(null=True)
|
refresh_token: str = fields.TextField(null=True)
|
||||||
disabled: bool = fields.BooleanField(default=False)
|
disabled: bool = fields.BooleanField(default=False)
|
||||||
|
|
||||||
def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
self.disabled = True
|
self.disabled = True
|
||||||
self.disabled_at = datetime.now(tz=settings.DEFAULT_TIMEZONE)
|
self.disabled_at = datetime.now(tz=pytz.UTC)
|
||||||
self.save()
|
await self.save()
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,22 @@
|
|||||||
from datetime import timedelta
|
from datetime import datetime
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi.responses import JSONResponse
|
import uuid
|
||||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from modules.auth.utils import create_token
|
import pytz
|
||||||
|
from modules.users.utils import get_current_active_user
|
||||||
|
from modules.auth.utils import create_jwt_tokens, get_tokens_from_logged_in_user
|
||||||
from modules.auth.models import Token
|
from modules.auth.models import Token
|
||||||
from modules.users.models import User
|
from modules.users.models import User
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from tortoise.expressions import Q
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
email_error: str = "E-Mail Address or password is incorrect"
|
||||||
|
token_error: str = "Refresh token not found or something went wrong."
|
||||||
error: str = "E-Mail Address or password is incorrect"
|
|
||||||
|
|
||||||
crypt = settings.CRYPT
|
crypt = settings.CRYPT
|
||||||
|
|
||||||
@@ -25,34 +26,65 @@ async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
|
|||||||
user: User | None = await User.filter(email=form.username).first()
|
user: User | None = await User.filter(email=form.username).first()
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status_code=401, detail=error)
|
raise HTTPException(status_code=401, detail=email_error)
|
||||||
|
|
||||||
if user.check_against_password(form.password) is False:
|
if user.check_against_password(form.password) is False:
|
||||||
raise HTTPException(status_code=401, detail=error)
|
raise HTTPException(status_code=401, detail=email_error)
|
||||||
|
|
||||||
if user.disabled is True:
|
if user.disabled is True:
|
||||||
raise HTTPException(status_code=401, detail=error)
|
raise HTTPException(status_code=401, detail=email_error)
|
||||||
|
|
||||||
auth_token = create_token(
|
tokens = await create_jwt_tokens(user)
|
||||||
user_id=user.id, offset=timedelta(settings.ACCESS_TOKEN_EXPIRE_MIN)
|
|
||||||
)
|
|
||||||
|
|
||||||
refresh_token = create_token(
|
return {"jwt": tokens}
|
||||||
user_id=user.id, offset=timedelta(settings.REFRESH_TOKEN_EXPIRE_MIN)
|
|
||||||
)
|
|
||||||
|
|
||||||
token = await Token.create(
|
|
||||||
user=user,
|
|
||||||
access_token=auth_token,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"jwt": token}
|
@router.get("/logout", status_code=204)
|
||||||
|
async def logout(user: Annotated[User, Depends(get_current_active_user)]):
|
||||||
|
get_all_tokens = await Token.filter(Q(user__id=user.id))
|
||||||
|
if get_all_tokens is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_204_NO_CONTENT, detail="An error occurred."
|
||||||
|
)
|
||||||
|
for token in get_all_tokens:
|
||||||
|
await token.delete()
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh")
|
@router.post("/refresh")
|
||||||
async def refresh_login():
|
async def refresh_login(
|
||||||
pass
|
refresh_token: Annotated[Token | None, Depends(get_tokens_from_logged_in_user)]
|
||||||
|
):
|
||||||
|
if refresh_token is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=token_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Disable tokens if used after expiration.
|
||||||
|
if (
|
||||||
|
refresh_token.created_at >= datetime.now(tz=pytz.utc)
|
||||||
|
and refresh_token.disabled is False
|
||||||
|
):
|
||||||
|
refresh_token.delete()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=token_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
if refresh_token.disabled is True:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=token_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
await refresh_token.delete()
|
||||||
|
|
||||||
|
tokens = await create_jwt_tokens(
|
||||||
|
user=await User.filter(Q(id=refresh_token.user_id)).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"jwt": tokens}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register")
|
@router.post("/register")
|
||||||
|
|||||||
@@ -1,8 +1,16 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Annotated
|
||||||
import uuid, time
|
import uuid, time
|
||||||
|
|
||||||
|
from tortoise.expressions import Q
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
|
||||||
|
from modules.users.models import User
|
||||||
|
from modules.auth.models import Token
|
||||||
from config import settings
|
from config import settings
|
||||||
from joserfc import jwt # type: ignore
|
from joserfc import jwt # type: ignore
|
||||||
from joserfc.jwk import OctKey # type: ignore
|
from joserfc.jwk import OctKey # type: ignore
|
||||||
|
from config import settings
|
||||||
|
|
||||||
crypt = settings.CRYPT
|
crypt = settings.CRYPT
|
||||||
|
|
||||||
@@ -17,7 +25,7 @@ def create_token(user_id: uuid, offset: timedelta) -> str:
|
|||||||
return jwt.encode(
|
return jwt.encode(
|
||||||
{"alg": "HS256", "typ": "JWT"},
|
{"alg": "HS256", "typ": "JWT"},
|
||||||
{
|
{
|
||||||
"iss": "",
|
"iss": f"{settings.PROJECT_PUBLIC_URL}",
|
||||||
"sub": f"id:{user}",
|
"sub": f"id:{user}",
|
||||||
"nbf": curr_time,
|
"nbf": curr_time,
|
||||||
"iat": curr_time,
|
"iat": curr_time,
|
||||||
@@ -27,5 +35,48 @@ def create_token(user_id: uuid, offset: timedelta) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def decode_token(token: str):
|
async def create_jwt_tokens(user: User) -> Token:
|
||||||
pass
|
"""
|
||||||
|
Create a Token class with the following entities:
|
||||||
|
|
||||||
|
1) A user that is attached to the Token
|
||||||
|
2) A fresh Auth Token
|
||||||
|
3) A fresh Refresh Token.
|
||||||
|
|
||||||
|
This is then returned in the form of an Token class.
|
||||||
|
"""
|
||||||
|
auth_token = create_token(
|
||||||
|
user_id=user.id, offset=timedelta(settings.ACCESS_TOKEN_EXPIRE_MIN)
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_token = create_token(
|
||||||
|
user_id=user.id, offset=timedelta(settings.REFRESH_TOKEN_EXPIRE_MIN)
|
||||||
|
)
|
||||||
|
|
||||||
|
return await Token.create(
|
||||||
|
user=user,
|
||||||
|
access_token=auth_token,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tokens_from_logged_in_user(
|
||||||
|
token: Annotated[str, Depends(settings.OAUTH2_SCHEME)]
|
||||||
|
) -> User | None:
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="An issue occurred with the token.",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload: jwt.Token = jwt.decode(
|
||||||
|
token, OctKey.import_key(settings.SECRET_KEY), algorithms=["HS256"]
|
||||||
|
)
|
||||||
|
id: str | None = payload.claims.get("sub", None)
|
||||||
|
if id is None:
|
||||||
|
raise credentials_exception
|
||||||
|
user_id = id.split(":")[1]
|
||||||
|
except:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
return await Token.filter(Q(refresh_token=token) & Q(user__id=user_id)).first()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from datetime import datetime
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Type
|
from typing import Type
|
||||||
import uuid
|
import uuid
|
||||||
|
import pytz
|
||||||
from tortoise.exceptions import ConfigurationError
|
from tortoise.exceptions import ConfigurationError
|
||||||
from tortoise.models import Model
|
from tortoise.models import Model
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
@@ -78,9 +79,9 @@ class Organization(Model, CMDMixin):
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"{self.id} - {self.name}"
|
return f"{self.id} - {self.name}"
|
||||||
|
|
||||||
def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
self.disabled = True
|
self.disabled = True
|
||||||
self.disabled_at = datetime.now(tz=settings.DEFAULT_TIMEZONE)
|
self.disabled_at = datetime.now(tz=pytz.UTC)
|
||||||
self.save()
|
await self.save()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from pydantic import EmailStr
|
from pydantic import EmailStr
|
||||||
|
import pytz
|
||||||
from tortoise.models import Model
|
from tortoise.models import Model
|
||||||
from tortoise import fields
|
from tortoise import fields
|
||||||
|
|
||||||
@@ -57,10 +58,10 @@ class User(Model, CMDMixin):
|
|||||||
return False
|
return False
|
||||||
self.set_password(new_password)
|
self.set_password(new_password)
|
||||||
|
|
||||||
def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
self.disabled = True
|
self.disabled = True
|
||||||
self.disabled_at = datetime.now(tz=settings.DEFAULT_TIMEZONE)
|
self.disabled_at = datetime.now(tz=pytz.UTC)
|
||||||
self.save()
|
await self.save()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -102,7 +103,7 @@ class Membership(Model, CMDMixin):
|
|||||||
acl: ACL = fields.ForeignKeyField("models.ACL")
|
acl: ACL = fields.ForeignKeyField("models.ACL")
|
||||||
disabled: bool = fields.BooleanField(default=False)
|
disabled: bool = fields.BooleanField(default=False)
|
||||||
|
|
||||||
def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
self.disabled = True
|
self.disabled = True
|
||||||
self.disabled_at = datetime.now(tz=settings.DEFAULT_TIMEZONE)
|
self.disabled_at = datetime.now(tz=pytz.UTC)
|
||||||
self.save()
|
await self.save()
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
from joserfc import jwt # type: ignore
|
||||||
|
from joserfc.jwk import OctKey # type: ignore
|
||||||
|
|
||||||
|
from tortoise.expressions import Q
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
|
||||||
|
# from modules.users.schemas import UserModel
|
||||||
|
from modules.users.models import User
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_from_token(token: Annotated[str, Depends(settings.OAUTH2_SCHEME)]) -> User:
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="An issue occurred with the token.",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload: jwt.Token = jwt.decode(token, OctKey.import_key(settings.SECRET_KEY), algorithms=["HS256"])
|
||||||
|
id: str | None = payload.claims.get("sub", None)
|
||||||
|
if id is None:
|
||||||
|
raise credentials_exception
|
||||||
|
user_id = id.split(":")[1]
|
||||||
|
except:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
return await User.filter(Q(id=user_id)).get_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_active_user(
|
||||||
|
user: Annotated[User, Depends(get_user_from_token)],
|
||||||
|
):
|
||||||
|
if user.disabled:
|
||||||
|
raise HTTPException(status_code=400, detail="User is not found or active")
|
||||||
|
return user
|
||||||
@@ -65,3 +65,89 @@ class TestAuthentication(object):
|
|||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_logging_out_destroys_tokens(
|
||||||
|
self, client: AsyncClient, use_admin_account
|
||||||
|
):
|
||||||
|
_, _, admin, _ = use_admin_account
|
||||||
|
response = await client.post(
|
||||||
|
"http://localhost/api/v1/auth/",
|
||||||
|
data={
|
||||||
|
"username": "admin@localhost.com",
|
||||||
|
"password": "adminpassword",
|
||||||
|
"grant_type": "password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {
|
||||||
|
"jwt": {
|
||||||
|
"created_at": ANY,
|
||||||
|
"user_id": str(admin.id),
|
||||||
|
"id": ANY,
|
||||||
|
"modified_at": ANY,
|
||||||
|
"disabled_at": None,
|
||||||
|
"refresh_token": ANY,
|
||||||
|
"disabled": False,
|
||||||
|
"access_token": ANY,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
access_token = response.json()["jwt"]["access_token"]
|
||||||
|
|
||||||
|
logout = await client.get(
|
||||||
|
"http://localhost/api/v1/auth/logout",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
assert logout.status_code == 204
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_new_tokens_upon_refresh(
|
||||||
|
self, client: AsyncClient, use_admin_account
|
||||||
|
):
|
||||||
|
_, _, admin, _ = use_admin_account
|
||||||
|
token = await client.post(
|
||||||
|
"http://localhost/api/v1/auth/",
|
||||||
|
data={
|
||||||
|
"username": "admin@localhost.com",
|
||||||
|
"password": "adminpassword",
|
||||||
|
"grant_type": "password",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert token.status_code == 200
|
||||||
|
assert token.json() == {
|
||||||
|
"jwt": {
|
||||||
|
"created_at": ANY,
|
||||||
|
"user_id": str(admin.id),
|
||||||
|
"id": ANY,
|
||||||
|
"modified_at": ANY,
|
||||||
|
"disabled_at": None,
|
||||||
|
"refresh_token": ANY,
|
||||||
|
"disabled": False,
|
||||||
|
"access_token": ANY,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
refresh_token = token.json()["jwt"]["refresh_token"]
|
||||||
|
|
||||||
|
response2 = await client.post(
|
||||||
|
"http://localhost/api/v1/auth/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {refresh_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response2.status_code == 200
|
||||||
|
assert response2.json() == {
|
||||||
|
"jwt": {
|
||||||
|
"created_at": ANY,
|
||||||
|
"user_id": str(admin.id),
|
||||||
|
"id": ANY,
|
||||||
|
"modified_at": ANY,
|
||||||
|
"disabled_at": None,
|
||||||
|
"refresh_token": ANY,
|
||||||
|
"disabled": False,
|
||||||
|
"access_token": ANY,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user