Fix fixtures and fix flaky tests

This commit is contained in:
2025-03-12 18:04:37 +02:00
parent d6211d12b8
commit bf9ecaff73
13 changed files with 153 additions and 82 deletions
+17 -11
View File
@@ -14,7 +14,22 @@
abort @external 401 abort @external 401
} }
*.{$MAIN_DOMAIN} { stoneedge.{$MAIN_DOMAIN} {
handle / {
reverse_proxy stoneedge:8000
}
}
stoneedge-staging.{$MAIN_DOMAIN} {
handle / {
import protect
abort
reverse_proxy stoneedge-staging:8000
}
}
{$MAIN_DOMAIN} {
encode zstd gzip encode zstd gzip
handle /ping { handle /ping {
@@ -24,16 +39,7 @@
} }
handle / { handle / {
abort abort 404
}
handle @stoneedge {
reverse_proxy stoneedge:8000
}
handle @stoneedge-staging {
import protect
abort
} }
handle { handle {
+3 -3
View File
@@ -7,7 +7,7 @@ class Settings(BaseSettings):
PROJECT_NAME: str = "StoneEdge Asset Management System" PROJECT_NAME: str = "StoneEdge Asset Management System"
PROJECT_VERSION: str = "0.0.1" PROJECT_VERSION: str = "0.0.1"
PROJECT_SUMMARY: str = "Product API for StoneEdge." PROJECT_SUMMARY: str = "Product API for StoneEdge."
PROJECT_PUBLIC_URL: str = "" PROJECT_PUBLIC_URL: str = "localhost"
SECRET_KEY: str | None = None SECRET_KEY: str | None = None
PSQL_USERNAME: str = "user" PSQL_USERNAME: str = "user"
PSQL_PASSWORD: str = "password" PSQL_PASSWORD: str = "password"
@@ -15,8 +15,8 @@ class Settings(BaseSettings):
PSQL_PORT: int = 5432 PSQL_PORT: int = 5432
PSQL_DB_NAME: str = "stoneedge" PSQL_DB_NAME: str = "stoneedge"
PSQL_TEST_DB_NAME: str = "stoneedge_testing" PSQL_TEST_DB_NAME: str = "stoneedge_testing"
ACCESS_TOKEN_EXPIRE_MIN: int = 30 ACCESS_TOKEN_EXPIRE_MIN: int = 10
REFRESH_TOKEN_EXPIRE_MIN: int = 60 REFRESH_TOKEN_EXPIRE_MIN: int = 20
BACKEND_CORS_ORIGINS: list = ["*"] BACKEND_CORS_ORIGINS: list = ["*"]
CRYPT: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto") CRYPT: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto")
OAUTH2_SCHEME: OAuth2PasswordBearer = OAuth2PasswordBearer(tokenUrl="token") OAUTH2_SCHEME: OAuth2PasswordBearer = OAuth2PasswordBearer(tokenUrl="token")
+5
View File
@@ -11,6 +11,8 @@ from modules.auth.router import router as auth_router
from modules.users.router import router as users_router from modules.users.router import router as users_router
from modules.organizations.router import router as organizations_router from modules.organizations.router import router as organizations_router
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
@asynccontextmanager @asynccontextmanager
async def lifespan(_: FastAPI): async def lifespan(_: FastAPI):
@@ -27,6 +29,9 @@ app = FastAPI(
default_response_class=msgspec_jsonresponse, default_response_class=msgspec_jsonresponse,
) )
app.add_middleware(HTTPSRedirectMiddleware)
app.add_middleware(TrustedHostMiddleware, allowed_hosts=[settings.PROJECT_PUBLIC_URL,])
# Set all CORS enabled origins # Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS: if settings.BACKEND_CORS_ORIGINS:
app.add_middleware( app.add_middleware(
+2 -2
View File
@@ -15,8 +15,8 @@ class Token(Model, CMDMixin):
Creates the access tokens for the User Creates the access tokens for the User
""" """
id: uuid = fields.UUIDField(primary_key=True) id: uuid.UUID = fields.UUIDField(primary_key=True)
user: uuid = fields.ForeignKeyField("models.User") user: uuid.UUID = fields.ForeignKeyField("models.User")
token_type: str = fields.CharField(max_length=128, default="Bearer") token_type: str = fields.CharField(max_length=128, default="Bearer")
access_token: str = fields.TextField(null=True) access_token: str = fields.TextField(null=True)
refresh_token: str = fields.TextField(null=True) refresh_token: str = fields.TextField(null=True)
+31 -6
View File
@@ -3,6 +3,7 @@ from typing import Annotated
import uuid import uuid
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pydantic import EmailStr
import pytz import pytz
from modules.users.utils import get_current_active_user from modules.users.utils import get_current_active_user
from modules.auth.utils import create_jwt_tokens, get_tokens_from_logged_in_user from modules.auth.utils import create_jwt_tokens, get_tokens_from_logged_in_user
@@ -15,7 +16,7 @@ from config import settings
router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
email_error: str = "E-Mail Address or password is incorrect" account_error: str = "E-Mail Address or password is incorrect"
token_error: str = "Refresh token not found or something went wrong." token_error: str = "Refresh token not found or something went wrong."
crypt = settings.CRYPT crypt = settings.CRYPT
@@ -23,16 +24,21 @@ crypt = settings.CRYPT
@router.post("/") @router.post("/")
async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]): async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
"""
Login
Logs the user into our API, creates tokens and passes them back to User.
"""
user: User | None = await User.filter(email=form.username).first() user: User | None = await User.filter(email=form.username).first()
if user is None: if user is None:
raise HTTPException(status_code=401, detail=email_error) raise HTTPException(status_code=401, detail=account_error)
if user.check_against_password(form.password) is False: if user.check_against_password(form.password) is False:
raise HTTPException(status_code=401, detail=email_error) raise HTTPException(status_code=401, detail=account_error)
if user.disabled is True: if user.disabled is True:
raise HTTPException(status_code=401, detail=email_error) raise HTTPException(status_code=401, detail=account_error)
tokens = await create_jwt_tokens(user) tokens = await create_jwt_tokens(user)
@@ -41,6 +47,11 @@ async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
@router.get("/logout", status_code=204) @router.get("/logout", status_code=204)
async def logout(user: Annotated[User, Depends(get_current_active_user)]): async def logout(user: Annotated[User, Depends(get_current_active_user)]):
"""
Logout
Logout destroys all tokens for User that are currently active.
"""
get_all_tokens = await Token.filter(Q(user__id=user.id)) get_all_tokens = await Token.filter(Q(user__id=user.id))
if get_all_tokens is None: if get_all_tokens is None:
raise HTTPException( raise HTTPException(
@@ -55,6 +66,12 @@ async def logout(user: Annotated[User, Depends(get_current_active_user)]):
async def refresh_login( async def refresh_login(
refresh_token: Annotated[Token | None, Depends(get_tokens_from_logged_in_user)] refresh_token: Annotated[Token | None, Depends(get_tokens_from_logged_in_user)]
): ):
"""
Refresh
After ging this route a token that is active and not disabled, we disable ALL other tokens and pass along new tokens.
Tokens are alive for about 10 minutes. Refresh tokens are alive for 20 minutes.
"""
if refresh_token is None: if refresh_token is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -78,7 +95,11 @@ async def refresh_login(
detail=token_error, detail=token_error,
) )
await refresh_token.delete() get_all_tokens = await Token.filter(Q(user__id=refresh_token.user_id))
for token in get_all_tokens:
if token.id != refresh_token.id:
await token.delete()
tokens = await create_jwt_tokens( tokens = await create_jwt_tokens(
user=await User.filter(Q(id=refresh_token.user_id)).first() user=await User.filter(Q(id=refresh_token.user_id)).first()
@@ -88,5 +109,9 @@ async def refresh_login(
@router.post("/register") @router.post("/register")
async def register(): async def register(email: EmailStr, name: str, surname: str, password: str, validate_password: str):
pass
@router.post("/2fa")
async def twofa():
pass pass
@@ -2,4 +2,5 @@ from tortoise.contrib.pydantic import pydantic_model_creator
from modules.auth.models import Token from modules.auth.models import Token
TokenModel = pydantic_model_creator(Token) token_model = pydantic_model_creator(Token)
@@ -8,7 +8,6 @@ from tortoise.models import Model
from tortoise import fields from tortoise import fields
from mixins.CMDMixin import CMDMixin from mixins.CMDMixin import CMDMixin
from config import settings
class EnumField(fields.CharField): class EnumField(fields.CharField):
""" """
@@ -43,7 +42,6 @@ class OrganizationType(Enum):
2. What size is it? 2. What size is it?
All choices should be representative of the org. All choices should be representative of the org.
There are no seat costs.
""" """
HOME: str = "home" # Home use (Any size) HOME: str = "home" # Home use (Any size)
@@ -62,10 +60,10 @@ class Organization(Model, CMDMixin):
and makes sure that we can add users. and makes sure that we can add users.
""" """
id: uuid = fields.UUIDField(primary_key=True) id: uuid.UUID = fields.UUIDField(primary_key=True)
name: str = fields.CharField(max_length=128) name: str = fields.CharField(max_length=128)
type: str = EnumField(OrganizationType) type: str = EnumField(OrganizationType)
users: uuid = fields.ManyToManyField( users: uuid.UUID = fields.ManyToManyField(
"models.User", "models.User",
related_name="members", related_name="members",
through="Membership", through="Membership",
@@ -18,7 +18,7 @@ class User(Model, CMDMixin):
This holds all of our users This holds all of our users
""" """
id: uuid = fields.UUIDField(primary_key=True) id: uuid.UUID = fields.UUIDField(primary_key=True)
email: EmailStr = fields.CharField(max_length=128) email: EmailStr = fields.CharField(max_length=128)
username: str = fields.TextField(max_length=128) username: str = fields.TextField(max_length=128)
name: str = fields.TextField(max_length=128) name: str = fields.TextField(max_length=128)
@@ -72,7 +72,7 @@ class ACL(Model):
Access control lists, every invited user gets an ACL and this decides whether you grant / deny access to certain parts of our system. Access control lists, every invited user gets an ACL and this decides whether you grant / deny access to certain parts of our system.
""" """
id: uuid = fields.UUIDField(primary_key=True) id: uuid.UUID = fields.UUIDField(primary_key=True)
READ: bool = fields.BooleanField(default=False) READ: bool = fields.BooleanField(default=False)
WRITE: bool = fields.BooleanField(default=False) WRITE: bool = fields.BooleanField(default=False)
REPORT: bool = fields.BooleanField(default=False) REPORT: bool = fields.BooleanField(default=False)
@@ -97,7 +97,7 @@ class Membership(Model, CMDMixin):
Creates a connection between an user and a company together with an ACL. Creates a connection between an user and a company together with an ACL.
""" """
id: uuid = fields.UUIDField(primary_key=True) id: uuid.UUID = fields.UUIDField(primary_key=True)
organization: Organization = fields.ForeignKeyField("models.Organization") organization: Organization = fields.ForeignKeyField("models.Organization")
user: User = fields.ForeignKeyField("models.User") user: User = fields.ForeignKeyField("models.User")
acl: ACL = fields.ForeignKeyField("models.ACL") acl: ACL = fields.ForeignKeyField("models.ACL")
+18 -7
View File
@@ -5,19 +5,22 @@ from joserfc.jwk import OctKey # type: ignore
from tortoise.expressions import Q from tortoise.expressions import Q
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
# from modules.users.schemas import UserModel
from modules.users.models import User from modules.users.models import User
from config import settings from config import settings
async def get_user_from_token(token: Annotated[str, Depends(settings.OAUTH2_SCHEME)]) -> User: async def get_user_from_token(
token: Annotated[str, Depends(settings.OAUTH2_SCHEME)]
) -> User | None:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="An issue occurred with the token.", detail="An issue occurred with the token.",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
try: try:
payload: jwt.Token = jwt.decode(token, OctKey.import_key(settings.SECRET_KEY), algorithms=["HS256"]) payload: jwt.Token = jwt.decode(
token, OctKey.import_key(settings.SECRET_KEY), algorithms=["HS256"]
)
id: str | None = payload.claims.get("sub", None) id: str | None = payload.claims.get("sub", None)
if id is None: if id is None:
raise credentials_exception raise credentials_exception
@@ -25,12 +28,20 @@ async def get_user_from_token(token: Annotated[str, Depends(settings.OAUTH2_SCHE
except: except:
raise credentials_exception raise credentials_exception
return await User.filter(Q(id=user_id)).get_or_none() return await User.filter(Q(id=user_id)).first()
async def get_current_active_user( async def get_current_active_user(
user: Annotated[User, Depends(get_user_from_token)], user: Annotated[User | None, Depends(get_user_from_token)],
): ) -> User:
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User is not found or active",
)
if user.disabled: if user.disabled:
raise HTTPException(status_code=400, detail="User is not found or active") raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User is not found or active",
)
return user return user
+1 -1
View File
@@ -40,7 +40,7 @@ def event_loop():
@asynccontextmanager @asynccontextmanager
async def client_manager(app, base_url="http://localhost", **kw) -> ClientManagerType: async def client_manager(app, base_url="https://localhost", **kw) -> ClientManagerType:
app.state.testing = True app.state.testing = True
async with LifespanManager(app): async with LifespanManager(app):
transport = httpx.ASGITransport(app=app) transport = httpx.ASGITransport(app=app)
+42 -28
View File
@@ -1,6 +1,6 @@
from modules.organizations.models import Organization, OrganizationType from modules.organizations.models import Organization, OrganizationType
from modules.users.models import ACL, Membership, User from modules.users.models import ACL, Membership, User
import pytest # type: ignore import pytest # type=ignore
from config import settings from config import settings
crypt = settings.CRYPT crypt = settings.CRYPT
@@ -10,30 +10,38 @@ crypt = settings.CRYPT
async def use_user_account(): async def use_user_account():
org, _ = await Organization.get_or_create( org, _ = await Organization.get_or_create(
id="6ad4c94e-0522-4912-8d16-02d451f4c92d", id="6ad4c94e-0522-4912-8d16-02d451f4c92d",
name="User's Organization", defaults={
type=OrganizationType.HOME, "name": "User's Organization",
"type": OrganizationType.HOME,
},
) )
acl, _ = await ACL.get_or_create( acl, _ = await ACL.get_or_create(
id="a4e927a3-36e5-4761-badb-0a44ade6616f", id="a4e927a3-36e5-4761-badb-0a44ade6616f",
READ=True, defaults={
WRITE=True, "READ": True,
REPORT=True, "WRITE": True,
MANAGE=False, "REPORT": True,
ADMIN=False, "MANAGE": False,
"ADMIN": False,
},
) )
user, _ = await User.get_or_create( user, _ = await User.get_or_create(
id="24235427-9662-4ba3-a9c5-00000000000b", id="24235427-9662-4ba3-a9c5-00000000000b",
email="user@localhost.com", defaults={
username="user", "email": "user@localhost.com",
name="awesome", "username": "user",
surname="user", "name": "awesome",
password=crypt.hash("userpassword"), "surname": "user",
"password": crypt.hash("userpassword"),
},
) )
membership, _ = await Membership.get_or_create( membership, _ = await Membership.get_or_create(
id="833b9511-b2da-4760-8fa4-1a5c7059911e", id="833b9511-b2da-4760-8fa4-1a5c7059911e",
organization=org, defaults={
user=user, "organization": org,
acl=acl, "user": user,
"acl": acl,
},
) )
return org, acl, user, membership return org, acl, user, membership
@@ -42,31 +50,37 @@ async def use_user_account():
async def use_admin_account(): async def use_admin_account():
org, _ = await Organization.get_or_create( org, _ = await Organization.get_or_create(
id="de001f44-1bb8-4667-9f9d-2d62d6ad7270", id="de001f44-1bb8-4667-9f9d-2d62d6ad7270",
name="Admin's Organization", defaults={
type=OrganizationType.EXTRA_LARGE_ORGANIZATION, "name": "Admin's Organization",
"type": OrganizationType.EXTRA_LARGE_ORGANIZATION,
},
) )
acl, _ = await ACL.get_or_create( acl, _ = await ACL.get_or_create(
id="83c1bfe6-c2ed-4ba1-be03-0e5c1960ec31", id="83c1bfe6-c2ed-4ba1-be03-0e5c1960ec31",
READ=True, defaults={
WRITE=True, "READ": True,
REPORT=True, "WRITE": True,
MANAGE=True, "REPORT": True,
ADMIN=True, "MANAGE": True,
"ADMIN": True,
},
) )
user, _ = await User.get_or_create( user, _ = await User.get_or_create(
id="24235427-9662-4ba3-a9c5-00000000000a",
defaults={ defaults={
"id": "24235427-9662-4ba3-a9c5-00000000000a",
"email": "admin@localhost.com", "email": "admin@localhost.com",
"username": "admin", "username": "admin",
"name": "awesome", "name": "awesome",
"surname": "admin", "surname": "admin",
"password": crypt.hash("adminpassword"), "password": crypt.hash("adminpassword"),
} },
) )
membership, _ = await Membership.get_or_create( membership, _ = await Membership.get_or_create(
id="393473ee-c218-4bcf-82cd-cb676c4d8a33", id="393473ee-c218-4bcf-82cd-cb676c4d8a33",
organization=org, defaults={
user=user, "organization": org,
acl=acl, "user": user,
"acl": acl,
},
) )
return org, acl, user, membership return org, acl, user, membership
@@ -12,7 +12,7 @@ class TestAuthentication(object):
self, client: AsyncClient self, client: AsyncClient
): ):
response = await client.post( response = await client.post(
"http://localhost/api/v1/auth/", "https://localhost/api/v1/auth/",
data={ data={
"username": "non-existing@localhost.com", "username": "non-existing@localhost.com",
"password": "password", "password": "password",
@@ -28,7 +28,7 @@ class TestAuthentication(object):
): ):
_, _, _, _ = use_admin_account _, _, _, _ = use_admin_account
response = await client.post( response = await client.post(
"http://localhost/api/v1/auth/", "https://localhost/api/v1/auth/",
data={ data={
"username": "admin@localhost.com", "username": "admin@localhost.com",
"password": "password", "password": "password",
@@ -44,7 +44,7 @@ class TestAuthentication(object):
): ):
_, _, admin, _ = use_admin_account _, _, admin, _ = use_admin_account
response = await client.post( response = await client.post(
"http://localhost/api/v1/auth/", "https://localhost/api/v1/auth/",
data={ data={
"username": "admin@localhost.com", "username": "admin@localhost.com",
"password": "adminpassword", "password": "adminpassword",
@@ -68,14 +68,14 @@ class TestAuthentication(object):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_logging_out_destroys_tokens( async def test_logging_out_destroys_tokens(
self, client: AsyncClient, use_admin_account self, client: AsyncClient, use_user_account
): ):
_, _, admin, _ = use_admin_account _, _, user, _ = use_user_account
response = await client.post( response = await client.post(
"http://localhost/api/v1/auth/", "https://localhost/api/v1/auth/",
data={ data={
"username": "admin@localhost.com", "username": "user@localhost.com",
"password": "adminpassword", "password": "userpassword",
"grant_type": "password", "grant_type": "password",
}, },
) )
@@ -83,7 +83,7 @@ class TestAuthentication(object):
assert response.json() == { assert response.json() == {
"jwt": { "jwt": {
"created_at": ANY, "created_at": ANY,
"user_id": str(admin.id), "user_id": str(user.id),
"id": ANY, "id": ANY,
"modified_at": ANY, "modified_at": ANY,
"disabled_at": None, "disabled_at": None,
@@ -95,20 +95,31 @@ class TestAuthentication(object):
} }
access_token = response.json()["jwt"]["access_token"] access_token = response.json()["jwt"]["access_token"]
refresh_token = response.json()["jwt"]["refresh_token"]
logout = await client.get( logout = await client.get(
"http://localhost/api/v1/auth/logout", "https://localhost/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
) )
assert logout.status_code == 204 assert logout.status_code == 204
refresh_request = await client.post(
"https://localhost/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {refresh_token}"},
)
assert refresh_request.status_code == 401
assert refresh_request.json() == {
"detail": "Refresh token not found or something went wrong."
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_new_tokens_upon_refresh( async def test_create_new_tokens_upon_refresh(
self, client: AsyncClient, use_admin_account self, client: AsyncClient, use_admin_account
): ):
_, _, admin, _ = use_admin_account _, _, admin, _ = use_admin_account
token = await client.post( token = await client.post(
"http://localhost/api/v1/auth/", "https://localhost/api/v1/auth/",
data={ data={
"username": "admin@localhost.com", "username": "admin@localhost.com",
"password": "adminpassword", "password": "adminpassword",
@@ -133,7 +144,7 @@ class TestAuthentication(object):
refresh_token = token.json()["jwt"]["refresh_token"] refresh_token = token.json()["jwt"]["refresh_token"]
response2 = await client.post( response2 = await client.post(
"http://localhost/api/v1/auth/refresh", "https://localhost/api/v1/auth/refresh",
headers={"Authorization": f"Bearer {refresh_token}"}, headers={"Authorization": f"Bearer {refresh_token}"},
) )
@@ -1,13 +1,13 @@
import pytest import pytest # type: ignore
from httpx import AsyncClient from httpx import AsyncClient
class TestRootRoute(object): class TestRootRoute(object):
async def test_read_docs_on_main_route(self, client: AsyncClient): async def test_read_docs_on_main_route(self, client: AsyncClient):
response = await client.get("http://localhost/api/v1/") response = await client.get("https://localhost/api/v1/")
assert response.status_code == 307 assert response.status_code == 307
async def test_get_pong(self, client: AsyncClient): async def test_get_pong(self, client: AsyncClient):
response = await client.get("http://localhost/api/v1/ping") response = await client.get("https://localhost/api/v1/ping")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"ping": "pong!"} assert response.json() == {"ping": "pong!"}