Fix ASYNC tests and managing the individual tests
This commit is contained in:
@@ -8,7 +8,6 @@ class Settings(BaseSettings):
|
||||
PROJECT_VERSION: str = "0.0.1"
|
||||
PROJECT_SUMMARY: str = "Product API for StoneEdge."
|
||||
SECRET_KEY: str | None = None
|
||||
HASHING_SCHEME: str = "HS512"
|
||||
PSQL_USERNAME: str = "user"
|
||||
PSQL_PASSWORD: str = "password"
|
||||
PSQL_HOSTNAME: str = "localhost"
|
||||
@@ -17,6 +16,7 @@ class Settings(BaseSettings):
|
||||
PSQL_TEST_DB_NAME: str = "stoneedge_testing"
|
||||
ACCESS_TOKEN_EXPIRE_MIN: int = 30
|
||||
REFRESH_TOKEN_EXPIRE_MIN: int = 60
|
||||
BACKEND_CORS_ORIGINS: list = ["*"]
|
||||
DEFAULT_TIMEZONE: str = pytz.UTC._tzname
|
||||
CRYPT: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing_extensions import Any
|
||||
from tortoise import Tortoise
|
||||
from config import settings
|
||||
from aerich import Command
|
||||
|
||||
modules: dict[str, Any] = {
|
||||
"models": [
|
||||
@@ -33,10 +34,11 @@ TORTOISE_ORM = {
|
||||
}
|
||||
|
||||
|
||||
async def init_db():
|
||||
async def migrate_db():
|
||||
aerich = Command(tortoise_config=TORTOISE_ORM)
|
||||
await aerich.init()
|
||||
await aerich.upgrade(run_in_transaction=True)
|
||||
await Tortoise.init(config=TORTOISE_ORM)
|
||||
|
||||
|
||||
async def migrate_db():
|
||||
await init_db()
|
||||
await Tortoise.generate_schemas(safe=True)
|
||||
async def end_connections_to_db():
|
||||
await Tortoise.close_connections()
|
||||
@@ -1,8 +1,10 @@
|
||||
from fastapi import FastAPI
|
||||
from tortoise import run_async
|
||||
from tortoise import Tortoise
|
||||
from config import settings
|
||||
from database import migrate_db
|
||||
from database import end_connections_to_db, migrate_db
|
||||
from responses import msgspec_jsonresponse
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from router import router as root_router
|
||||
from modules.assets.router import router as asset_router
|
||||
@@ -10,14 +12,32 @@ from modules.auth.router import router as auth_router
|
||||
from modules.users.router import router as users_router
|
||||
from modules.organizations.router import router as organizations_router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
await migrate_db()
|
||||
yield
|
||||
print(_.state.testing)
|
||||
await end_connections_to_db()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title=settings.PROJECT_NAME,
|
||||
version=settings.PROJECT_VERSION,
|
||||
summary=settings.PROJECT_SUMMARY,
|
||||
default_response_class=msgspec_jsonresponse
|
||||
default_response_class=msgspec_jsonresponse,
|
||||
)
|
||||
|
||||
run_async(migrate_db())
|
||||
# Set all CORS enabled origins
|
||||
if settings.BACKEND_CORS_ORIGINS:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(root_router)
|
||||
app.include_router(auth_router)
|
||||
|
||||
@@ -67,17 +67,7 @@ CREATE TABLE IF NOT EXISTS "aerich" (
|
||||
"version" VARCHAR(255) NOT NULL,
|
||||
"app" VARCHAR(100) NOT NULL,
|
||||
"content" JSONB NOT NULL
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "Membership" (
|
||||
"organization_id" UUID NOT NULL REFERENCES "organization" ("id") ON DELETE NO ACTION,
|
||||
"user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE NO ACTION
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_Membership_organiz_b0a446" ON "Membership" ("organization_id", "user_id");
|
||||
CREATE TABLE IF NOT EXISTS "Membership" (
|
||||
"user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE NO ACTION,
|
||||
"organization_id" UUID NOT NULL REFERENCES "organization" ("id") ON DELETE NO ACTION
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "uidx_Membership_user_id_cc48d3" ON "Membership" ("user_id", "organization_id");"""
|
||||
);"""
|
||||
|
||||
|
||||
async def downgrade(db: BaseDBAsyncClient) -> str:
|
||||
|
||||
@@ -9,7 +9,6 @@ from modules.auth.models import Token
|
||||
from modules.users.models import User
|
||||
from fastapi import Depends, HTTPException
|
||||
from config import settings
|
||||
from tortoise.expressions import Q
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
@@ -23,28 +22,29 @@ crypt = settings.CRYPT
|
||||
|
||||
@router.post("/")
|
||||
async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
|
||||
user: User | None = await User.filter(
|
||||
Q(email=form.username)
|
||||
).get_or_none()
|
||||
|
||||
user: User | None = await User.filter(email=form.username).first()
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail=error)
|
||||
|
||||
if user.check_against_password(form.password) is False:
|
||||
raise HTTPException(status_code=401, detail=error)
|
||||
|
||||
return JSONResponse(
|
||||
await Token.create(
|
||||
user=user.id,
|
||||
access_token=create_token(
|
||||
user_id=user.id, offset=timedelta(settings.ACCESS_TOKEN_EXPIRE_MIN)
|
||||
),
|
||||
refresh_token=create_token(
|
||||
user_id=user.id, offset=timedelta(settings.REFRESH_TOKEN_EXPIRE_MIN)
|
||||
),
|
||||
)
|
||||
auth_token = create_token(
|
||||
user_id=user.id, offset=timedelta(settings.ACCESS_TOKEN_EXPIRE_MIN)
|
||||
)
|
||||
|
||||
refresh_token = create_token(
|
||||
user_id=user.id, offset=timedelta(settings.REFRESH_TOKEN_EXPIRE_MIN)
|
||||
)
|
||||
|
||||
token = await Token.create(
|
||||
user=user.id,
|
||||
access_token=auth_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
return {"jwt": token}
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_login():
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import timedelta
|
||||
import uuid, time
|
||||
from config import settings
|
||||
from joserfc import jwt # type: ignore
|
||||
@@ -5,20 +6,21 @@ from joserfc import jwt # type: ignore
|
||||
crypt = settings.CRYPT
|
||||
|
||||
|
||||
def create_token(user_id: uuid, offset: float) -> str:
|
||||
def create_token(user_id: uuid, offset: timedelta) -> str:
|
||||
"""
|
||||
Creates a JWT token
|
||||
"""
|
||||
user = str(user_id)
|
||||
curr_time = int(time.time())
|
||||
|
||||
return jwt.encode(
|
||||
{"alg": settings.HASHING_SCHEME, "typ": "JWT"},
|
||||
{"alg": "HS256", "typ": "JWT"},
|
||||
{
|
||||
"iss": "",
|
||||
"sub": f"id:{user_id}",
|
||||
"sub": f"id:{user}",
|
||||
"nbf": curr_time,
|
||||
"iat": curr_time,
|
||||
"exp": int(curr_time + offset),
|
||||
"exp": int(time.time() + offset.total_seconds()),
|
||||
},
|
||||
settings.SECRET_KEY,
|
||||
)
|
||||
|
||||
@@ -40,16 +40,14 @@ class User(Model, CMDMixin):
|
||||
|
||||
def set_password(self, password: str) -> None:
|
||||
self.password = crypt.hash(
|
||||
password,
|
||||
settings.HASHING_SCHEME
|
||||
password
|
||||
)
|
||||
self.save() # Make sure to save the model in DB
|
||||
|
||||
def check_against_password(self, password: str) -> bool:
|
||||
return crypt.verify(
|
||||
password,
|
||||
self.password,
|
||||
settings.HASHING_SCHEME
|
||||
self.password
|
||||
)
|
||||
|
||||
def update_password(self, old_password, new_password: str, verify_new_password: str) -> bool:
|
||||
|
||||
@@ -9,12 +9,13 @@ passlib>=1.7.4
|
||||
pytz>=2024.2
|
||||
ptpython>=0.25
|
||||
msgspec>=0.19.0
|
||||
bcrypt>=4.2.1
|
||||
|
||||
# Test Suite
|
||||
httpx>=0.28.1
|
||||
pytest>=8.3.4
|
||||
mock>=5.1.0
|
||||
asyncio>=3.4.3
|
||||
pytest-mock>=3.14.0
|
||||
pytest-asyncio>=0.25.3
|
||||
asyncio>=3.4.3
|
||||
asgi-lifespan>=2.1.0
|
||||
@@ -4,6 +4,7 @@ from fastapi.responses import JSONResponse, RedirectResponse
|
||||
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def main() -> RedirectResponse:
|
||||
return RedirectResponse(url="/docs")
|
||||
@@ -11,4 +12,4 @@ async def main() -> RedirectResponse:
|
||||
|
||||
@router.get("/ping")
|
||||
async def ping() -> JSONResponse:
|
||||
return JSONResponse("PONG")
|
||||
return {"ping": "pong!"}
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
from typing import AsyncGenerator, Optional, Self
|
||||
import httpx
|
||||
from tortoise import Tortoise
|
||||
import pytest # type: ignore
|
||||
from database import modules
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
import httpx, pytest
|
||||
from config import settings
|
||||
from glob import glob
|
||||
|
||||
from asgi_lifespan import LifespanManager # type: ignore
|
||||
from asgi_lifespan import LifespanManager # type: ignore
|
||||
|
||||
settings.PSQL_DB_NAME = settings.PSQL_TEST_DB_NAME
|
||||
|
||||
pytest_plugins = [
|
||||
fixture.replace("/", ".").replace("\\", ".").replace(".py", "")
|
||||
for fixture in glob("tests/fixtures/*.py")
|
||||
if "__" not in fixture
|
||||
]
|
||||
|
||||
try:
|
||||
from main import app
|
||||
@@ -16,26 +24,7 @@ except ImportError:
|
||||
sys.path.append(str(Path(__file__).parent.parent))
|
||||
from main import app
|
||||
|
||||
TORTOISE_ORM = {
|
||||
"connections": {
|
||||
"default": {
|
||||
"engine": "tortoise.backends.asyncpg",
|
||||
"credentials": {
|
||||
"host": settings.PSQL_HOSTNAME,
|
||||
"database": settings.PSQL_TEST_DB_NAME,
|
||||
"user": settings.PSQL_USERNAME,
|
||||
"password": settings.PSQL_PASSWORD,
|
||||
"port": settings.PSQL_PORT,
|
||||
},
|
||||
}
|
||||
},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": modules.get("models", []) + ["aerich.models"],
|
||||
"default_connection": "default",
|
||||
},
|
||||
},
|
||||
}
|
||||
ClientManagerType = AsyncGenerator[httpx.AsyncClient, None]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -43,43 +32,23 @@ def anyio_backend():
|
||||
return "asyncio"
|
||||
|
||||
|
||||
class TestClient(httpx.AsyncClient):
|
||||
def __init__(self, app, base_url="http://localhost", mount_lifespan=True, **kw) -> None:
|
||||
self.mount_lifespan = mount_lifespan
|
||||
self._manager: Optional[LifespanManager] = None
|
||||
super().__init__(transport=httpx.ASGITransport(app), base_url=base_url, **kw)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
if self.mount_lifespan:
|
||||
app = self._transport.app # type:ignore
|
||||
self._manager = await LifespanManager(app).__aenter__()
|
||||
self._transport = httpx.ASGITransport(app=self._manager.app)
|
||||
return await super().__aenter__()
|
||||
|
||||
async def __aexit__(self, *args, **kw):
|
||||
await super().__aexit__(*args, **kw)
|
||||
if self._manager is not None:
|
||||
await self._manager.__aexit__(*args, **kw)
|
||||
|
||||
async def init_db(create_db: bool = True, schemas: bool = True) -> None:
|
||||
"""Initial database connection"""
|
||||
await Tortoise.init(
|
||||
config=TORTOISE_ORM, timezone="Europe/Helsinki"
|
||||
)
|
||||
if create_db:
|
||||
print(f"Database created!")
|
||||
if schemas:
|
||||
await Tortoise.generate_schemas()
|
||||
print("Success to generate schemas")
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def initialize_tests():
|
||||
await init_db()
|
||||
yield
|
||||
await Tortoise._drop_databases()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def client() -> AsyncGenerator[TestClient, None]:
|
||||
async with TestClient(app) as c:
|
||||
yield c
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def client_manager(app, base_url="http://localhost", **kw) -> ClientManagerType:
|
||||
app.state.testing = True
|
||||
async with LifespanManager(app):
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url=base_url, **kw) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def client() -> ClientManagerType:
|
||||
async with client_manager(app) as c:
|
||||
yield c
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
import uuid
|
||||
from modules.organizations.models import Organization
|
||||
from modules.users.models import ACL, Membership, User
|
||||
import pytest # type: ignore
|
||||
from config import settings
|
||||
|
||||
crypt = settings.CRYPT
|
||||
|
||||
@pytest.fixture()
|
||||
async def use_user_account():
|
||||
org = await Organization.create(name="User's Organization", type="home")
|
||||
acl = await ACL.create(
|
||||
READ=True, WRITE=True, REPORT=True, MANAGE=True, ADMIN=True
|
||||
)
|
||||
user = await User.create(
|
||||
email="user@localhost.com",
|
||||
username="user",
|
||||
name="awesome",
|
||||
surname="user",
|
||||
password=crypt.hash("userpassword"),
|
||||
)
|
||||
membership = await Membership.create(
|
||||
organization=org,
|
||||
user=user,
|
||||
acl=acl,
|
||||
)
|
||||
return org, acl, user, membership
|
||||
|
||||
@pytest.fixture()
|
||||
async def use_admin_account():
|
||||
org = await Organization.create(name="Admin's Organization", type="home")
|
||||
acl = await ACL.create(
|
||||
READ=True, WRITE=True, REPORT=True, MANAGE=True, ADMIN=True
|
||||
)
|
||||
user = await User.create(
|
||||
email="admin@localhost.com",
|
||||
username="admin",
|
||||
name="awesome",
|
||||
surname="admin",
|
||||
password=crypt.hash("adminpassword"),
|
||||
)
|
||||
membership = await Membership.create(
|
||||
organization=org,
|
||||
user=user,
|
||||
acl=acl,
|
||||
)
|
||||
return org, acl, user, membership
|
||||
|
||||
@@ -1,39 +1,52 @@
|
||||
import pytest
|
||||
import pytest # type: ignore
|
||||
from httpx import AsyncClient
|
||||
from modules.organizations.models import Organization
|
||||
from modules.users.models import ACL, Membership, User
|
||||
from config import settings
|
||||
|
||||
crypt = settings.CRYPT
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def setup_function():
|
||||
org = await Organization.create(name="Admin's Organization", type="home")
|
||||
user = await User.create(
|
||||
email="admin@localhost.com",
|
||||
username="admin",
|
||||
name="admin",
|
||||
surname="admin",
|
||||
password=crypt.hash("password")
|
||||
)
|
||||
acl = await ACL.create(READ=True, WRITE=True, REPORT=True, MANAGE=True, ADMIN=True)
|
||||
await Membership.create(organization=org, user=user, acl=acl)
|
||||
|
||||
class TestAuthentication(object):
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_with_non_existing_user_and_password(
|
||||
self, client: AsyncClient
|
||||
):
|
||||
response = await client.post(
|
||||
"http://localhost/api/v1/auth/",
|
||||
data={
|
||||
"username": "non-existing@localhost.com",
|
||||
"password": "password",
|
||||
"grant_type": "password",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {"detail": "E-Mail Address or password is incorrect"}
|
||||
|
||||
# def teardown_function():
|
||||
# Organization.all().delete()
|
||||
# User.all().delete()
|
||||
# ACL.all().delete()
|
||||
# Membership.all().delete()
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_with_existing_user_and_wrong_password(
|
||||
self, client: AsyncClient, use_admin_account
|
||||
):
|
||||
response = await client.post(
|
||||
"http://localhost/api/v1/auth/",
|
||||
data={
|
||||
"username": "admin@localhost.com",
|
||||
"password": "password",
|
||||
"grant_type": "password",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert response.json() == {"detail": "E-Mail Address or password is incorrect"}
|
||||
|
||||
async def test_read_main(client: AsyncClient):
|
||||
print("start")
|
||||
response = await client.post(
|
||||
"http://localhost/api/v1/auth",
|
||||
data={
|
||||
"username": "admin@localhost.com",
|
||||
"password": "password",
|
||||
"grant_type": "password",
|
||||
},
|
||||
)
|
||||
assert response.json() == {}
|
||||
assert response.status_code == 200
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_with_existing_user_and_password(
|
||||
self, client: AsyncClient, use_admin_account
|
||||
):
|
||||
response = await client.post(
|
||||
"http://localhost/api/v1/auth/",
|
||||
data={
|
||||
"username": "admin@localhost.com",
|
||||
"password": "adminpassword",
|
||||
"grant_type": "password",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.text == ""
|
||||
|
||||
@@ -2,14 +2,12 @@ import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_read_main(client: AsyncClient):
|
||||
response = await client.get("http://localhost:8000/api/v1/")
|
||||
assert response.status_code == 307
|
||||
class TestRootRoute(object):
|
||||
async def test_read_docs_on_main_route(self, client: AsyncClient):
|
||||
response = await client.get("http://localhost/api/v1/")
|
||||
assert response.status_code == 307
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_pong(client: AsyncClient):
|
||||
response = await client.get("http://localhost:8000/api/v1/ping")
|
||||
assert response.status_code == 200
|
||||
assert response.text == '"PONG"'
|
||||
async def test_get_pong(self, client: AsyncClient):
|
||||
response = await client.get("http://localhost/api/v1/ping")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"ping": "pong!"}
|
||||
|
||||
Reference in New Issue
Block a user