8 Commits

29 changed files with 503 additions and 303 deletions
+3
View File
@@ -91,3 +91,6 @@
/web/**/psd /web/**/psd
/web/**/thumb /web/**/thumb
/web/**/sketch /web/**/sketch
# Prevent uploading DB files
*.sqlite*
+2 -2
View File
@@ -1,7 +1,6 @@
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from pydantic_settings import BaseSettings, SettingsConfigDict # type: ignore from pydantic_settings import BaseSettings, SettingsConfigDict # type: ignore
from passlib.context import CryptContext # type: ignore from passlib.context import CryptContext # type: ignore
import pytz
class Settings(BaseSettings): class Settings(BaseSettings):
PROJECT_NAME: str = "StoneEdge Asset Management System" PROJECT_NAME: str = "StoneEdge Asset Management System"
@@ -9,12 +8,13 @@ class Settings(BaseSettings):
PROJECT_SUMMARY: str = "Product API for StoneEdge." PROJECT_SUMMARY: str = "Product API for StoneEdge."
PROJECT_PUBLIC_URL: str = "localhost" PROJECT_PUBLIC_URL: str = "localhost"
SECRET_KEY: str | None = None SECRET_KEY: str | None = None
USE_HTTPS_ONLY: bool = False
IS_TESTING: bool = False # Testing uses a SQLite DB!
PSQL_USERNAME: str = "user" PSQL_USERNAME: str = "user"
PSQL_PASSWORD: str = "password" PSQL_PASSWORD: str = "password"
PSQL_HOSTNAME: str = "localhost" PSQL_HOSTNAME: str = "localhost"
PSQL_PORT: int = 5432 PSQL_PORT: int = 5432
PSQL_DB_NAME: str = "stoneedge" PSQL_DB_NAME: str = "stoneedge"
PSQL_TEST_DB_NAME: str = "stoneedge_testing"
ACCESS_TOKEN_EXPIRE_MIN: int = 10 ACCESS_TOKEN_EXPIRE_MIN: int = 10
REFRESH_TOKEN_EXPIRE_MIN: int = 20 REFRESH_TOKEN_EXPIRE_MIN: int = 20
BACKEND_CORS_ORIGINS: list = ["*"] BACKEND_CORS_ORIGINS: list = ["*"]
+8 -3
View File
@@ -5,7 +5,6 @@ from aerich import Command
modules: dict[str, Any] = { modules: dict[str, Any] = {
"models": [ "models": [
"modules.assets.models",
"modules.auth.models", "modules.auth.models",
"modules.users.models", "modules.users.models",
"modules.organizations.models", "modules.organizations.models",
@@ -14,6 +13,12 @@ modules: dict[str, Any] = {
TORTOISE_ORM = { TORTOISE_ORM = {
"connections": { "connections": {
"testing": {
"engine": "tortoise.backends.sqlite",
"credentials": {
"file_path": "stoneedge.sqlite"
}
},
"default": { "default": {
"engine": "tortoise.backends.asyncpg", "engine": "tortoise.backends.asyncpg",
"credentials": { "credentials": {
@@ -22,13 +27,13 @@ TORTOISE_ORM = {
"user": settings.PSQL_USERNAME, "user": settings.PSQL_USERNAME,
"password": settings.PSQL_PASSWORD, "password": settings.PSQL_PASSWORD,
"port": settings.PSQL_PORT, "port": settings.PSQL_PORT,
}, }
} }
}, },
"apps": { "apps": {
"models": { "models": {
"models": modules.get("models", []) + ["aerich.models"], "models": modules.get("models", []) + ["aerich.models"],
"default_connection": "default", "default_connection": "testing" if settings.IS_TESTING else "default",
}, },
}, },
} }
+3 -3
View File
@@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from router import router as root_router from router import router as root_router
from modules.assets.router import router as asset_router
from modules.auth.router import router as auth_router from modules.auth.router import router as auth_router
from modules.users.router import router as users_router from modules.users.router import router as users_router
from modules.organizations.router import router as organizations_router from modules.organizations.router import router as organizations_router
@@ -29,9 +28,11 @@ app = FastAPI(
default_response_class=msgspec_jsonresponse, default_response_class=msgspec_jsonresponse,
) )
app.add_middleware(HTTPSRedirectMiddleware)
app.add_middleware(TrustedHostMiddleware, allowed_hosts=[settings.PROJECT_PUBLIC_URL,]) app.add_middleware(TrustedHostMiddleware, allowed_hosts=[settings.PROJECT_PUBLIC_URL,])
if settings.USE_HTTPS_ONLY:
app.add_middleware(HTTPSRedirectMiddleware)
# Set all CORS enabled origins # Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS: if settings.BACKEND_CORS_ORIGINS:
app.add_middleware( app.add_middleware(
@@ -46,4 +47,3 @@ app.include_router(root_router)
app.include_router(auth_router) app.include_router(auth_router)
app.include_router(users_router) app.include_router(users_router)
app.include_router(organizations_router) app.include_router(organizations_router)
app.include_router(asset_router)
@@ -1,75 +0,0 @@
from tortoise import BaseDBAsyncClient
async def upgrade(db: BaseDBAsyncClient) -> str:
return """
CREATE TABLE IF NOT EXISTS "asset" (
"id" UUID NOT NULL PRIMARY KEY,
"name" VARCHAR(128) NOT NULL
);
CREATE TABLE IF NOT EXISTS "acl" (
"id" UUID NOT NULL PRIMARY KEY,
"READ" BOOL NOT NULL DEFAULT False,
"WRITE" BOOL NOT NULL DEFAULT False,
"REPORT" BOOL NOT NULL DEFAULT False,
"MANAGE" BOOL NOT NULL DEFAULT False,
"ADMIN" BOOL NOT NULL DEFAULT False
);
COMMENT ON TABLE "acl" IS 'ACL';
CREATE TABLE IF NOT EXISTS "organization" (
"created_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
"modified_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
"disabled_at" TIMESTAMPTZ,
"id" UUID NOT NULL PRIMARY KEY,
"name" VARCHAR(128) NOT NULL,
"type" VARCHAR(128) NOT NULL,
"disabled" BOOL NOT NULL DEFAULT False
);
COMMENT ON TABLE "organization" IS 'Organization';
CREATE TABLE IF NOT EXISTS "user" (
"created_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
"modified_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
"disabled_at" TIMESTAMPTZ,
"id" UUID NOT NULL PRIMARY KEY,
"email" VARCHAR(128) NOT NULL,
"username" TEXT NOT NULL,
"name" TEXT NOT NULL,
"surname" TEXT NOT NULL,
"password" VARCHAR(128),
"disabled" BOOL NOT NULL DEFAULT False
);
COMMENT ON TABLE "user" IS 'User';
CREATE TABLE IF NOT EXISTS "token" (
"created_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
"modified_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
"disabled_at" TIMESTAMPTZ,
"id" UUID NOT NULL PRIMARY KEY,
"token_type" VARCHAR(128) NOT NULL DEFAULT 'Bearer',
"access_token" VARCHAR(128),
"refresh_token" VARCHAR(128),
"disabled" BOOL NOT NULL DEFAULT False,
"user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
);
COMMENT ON TABLE "token" IS 'Token';
CREATE TABLE IF NOT EXISTS "membership" (
"created_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
"modified_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
"disabled_at" TIMESTAMPTZ,
"id" UUID NOT NULL PRIMARY KEY,
"disabled" BOOL NOT NULL DEFAULT False,
"acl_id" UUID NOT NULL REFERENCES "acl" ("id") ON DELETE CASCADE,
"organization_id" UUID NOT NULL REFERENCES "organization" ("id") ON DELETE CASCADE,
"user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
);
COMMENT ON TABLE "membership" IS 'Membership';
CREATE TABLE IF NOT EXISTS "aerich" (
"id" SERIAL NOT NULL PRIMARY KEY,
"version" VARCHAR(255) NOT NULL,
"app" VARCHAR(100) NOT NULL,
"content" JSONB NOT NULL
);"""
async def downgrade(db: BaseDBAsyncClient) -> str:
return """
"""
@@ -0,0 +1,71 @@
from tortoise import BaseDBAsyncClient
async def upgrade(db: BaseDBAsyncClient) -> str:
return """
CREATE TABLE IF NOT EXISTS "acl" (
"id" CHAR(36) NOT NULL PRIMARY KEY,
"READ" INT NOT NULL DEFAULT 0,
"WRITE" INT NOT NULL DEFAULT 0,
"REPORT" INT NOT NULL DEFAULT 0,
"MANAGE" INT NOT NULL DEFAULT 0,
"ADMIN" INT NOT NULL DEFAULT 0
) /* ACL */;
CREATE TABLE IF NOT EXISTS "organization" (
"created_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"modified_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"disabled_at" TIMESTAMP,
"id" CHAR(36) NOT NULL PRIMARY KEY,
"name" VARCHAR(128) NOT NULL,
"type" VARCHAR(128) NOT NULL,
"street_name" TEXT,
"zip_code" VARCHAR(128),
"state" VARCHAR(128),
"city" VARCHAR(128),
"country" VARCHAR(128),
"disabled" INT NOT NULL DEFAULT 0
) /* Organization */;
CREATE TABLE IF NOT EXISTS "user" (
"created_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"modified_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"disabled_at" TIMESTAMP,
"id" CHAR(36) NOT NULL PRIMARY KEY,
"email" VARCHAR(128) NOT NULL,
"username" TEXT NOT NULL,
"name" TEXT NOT NULL,
"surname" TEXT NOT NULL,
"password" VARCHAR(128),
"disabled" INT NOT NULL DEFAULT 0
) /* User */;
CREATE TABLE IF NOT EXISTS "token" (
"created_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"modified_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"disabled_at" TIMESTAMP,
"id" CHAR(36) NOT NULL PRIMARY KEY,
"token_type" VARCHAR(128) NOT NULL DEFAULT 'Bearer',
"access_token" TEXT,
"refresh_token" TEXT,
"disabled" INT NOT NULL DEFAULT 0,
"user_id" CHAR(36) NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
) /* Token */;
CREATE TABLE IF NOT EXISTS "membership" (
"created_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"modified_at" TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
"disabled_at" TIMESTAMP,
"id" CHAR(36) NOT NULL PRIMARY KEY,
"disabled" INT NOT NULL DEFAULT 0,
"acl_id" CHAR(36) NOT NULL REFERENCES "acl" ("id") ON DELETE CASCADE,
"organization_id" CHAR(36) NOT NULL REFERENCES "organization" ("id") ON DELETE CASCADE,
"user_id" CHAR(36) NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
) /* Membership */;
CREATE TABLE IF NOT EXISTS "aerich" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
"version" VARCHAR(255) NOT NULL,
"app" VARCHAR(100) NOT NULL,
"content" JSON NOT NULL
);"""
async def downgrade(db: BaseDBAsyncClient) -> str:
return """
"""
@@ -1,15 +0,0 @@
from tortoise import BaseDBAsyncClient
async def upgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "asset" ADD "disabled_at" TIMESTAMPTZ;
ALTER TABLE "asset" ADD "modified_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP;
ALTER TABLE "asset" ADD "created_at" TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP;"""
async def downgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "asset" DROP COLUMN "disabled_at";
ALTER TABLE "asset" DROP COLUMN "modified_at";
ALTER TABLE "asset" DROP COLUMN "created_at";"""
@@ -1,13 +0,0 @@
from tortoise import BaseDBAsyncClient
async def upgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "token" ALTER COLUMN "refresh_token" TYPE TEXT USING "refresh_token"::TEXT;
ALTER TABLE "token" ALTER COLUMN "access_token" TYPE TEXT USING "access_token"::TEXT;"""
async def downgrade(db: BaseDBAsyncClient) -> str:
return """
ALTER TABLE "token" ALTER COLUMN "refresh_token" TYPE VARCHAR(128) USING "refresh_token"::VARCHAR(128);
ALTER TABLE "token" ALTER COLUMN "access_token" TYPE VARCHAR(128) USING "access_token"::VARCHAR(128);"""
@@ -0,0 +1,5 @@
from fastapi import APIRouter
router = APIRouter(prefix="/api/v1/acls", tags=["acl"])
@@ -1,7 +0,0 @@
from tortoise.models import Model
from tortoise import fields
from mixins.CMDMixin import CMDMixin
class Asset(Model, CMDMixin):
id = fields.UUIDField(primary_key=True)
name = fields.CharField(max_length=128)
@@ -1,23 +0,0 @@
from uuid import UUID
from fastapi.routing import APIRouter
router = APIRouter(
prefix="/assets"
)
@router.get("/")
async def get_all_assets():
pass
@router.post("/")
async def create_asset(name: str):
pass
@router.delete("/", status_code=204)
async def delete_asset(remove_id: UUID):
pass
@router.get("/{asset_id}")
async def get_asset(asset_id: UUID):
pass
+12 -7
View File
@@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from typing import Annotated from typing import Annotated, List
import uuid from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
import pytz import pytz
@@ -25,13 +25,18 @@ crypt = settings.CRYPT
@router.post("/login") @router.post("/login")
async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]): async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]) -> JSONResponse:
""" """
Login Login
Logs the user into our API, creates tokens and passes them back to User. Logs the user into our API, creates tokens and passes them back to User.
""" """
user: User | None = await User.filter(email=form.username).first() user: User | None = await User.filter(
Q(email=form.username) & Q(password=form.password)
).first()
print(await User.all())
print(form.username, form.password, user.__dict__ if user else None)
if user is None: if user is None:
raise HTTPException(status_code=401, detail=account_error) raise HTTPException(status_code=401, detail=account_error)
@@ -48,7 +53,7 @@ async def login(form: Annotated[OAuth2PasswordRequestForm, Depends()]):
@router.get("/logout", status_code=204) @router.get("/logout", status_code=204)
async def logout(user: Annotated[User, Depends(get_current_active_user)]): async def logout(user: Annotated[User, Depends(get_current_active_user)]) -> None:
""" """
Logout Logout
@@ -67,7 +72,7 @@ async def logout(user: Annotated[User, Depends(get_current_active_user)]):
@router.post("/refresh") @router.post("/refresh")
async def refresh_login( async def refresh_login(
refresh_token: Annotated[Token | None, Depends(get_tokens_from_logged_in_user)], refresh_token: Annotated[Token | None, Depends(get_tokens_from_logged_in_user)],
): ) -> JSONResponse:
""" """
Refresh Refresh
@@ -111,7 +116,7 @@ async def refresh_login(
@router.post("/register", status_code=201, response_model=user_model) @router.post("/register", status_code=201, response_model=user_model)
async def register(user: register_model): async def register(user: register_model) -> User:
# Prevent existing users from reapplying for our system. # Prevent existing users from reapplying for our system.
existing_user: User | None = await User.filter( existing_user: User | None = await User.filter(
Q(email=user.email) Q(email=user.email)
@@ -9,6 +9,7 @@ from tortoise import fields
from mixins.CMDMixin import CMDMixin from mixins.CMDMixin import CMDMixin
class EnumField(fields.CharField): class EnumField(fields.CharField):
""" """
Serializes Enums to and from a str representation in the DB. Serializes Enums to and from a str representation in the DB.
@@ -52,7 +53,6 @@ class OrganizationType(Enum):
EXTRA_LARGE_ORGANIZATION: str = "xl_org" # 1000 - 5000+ EXTRA_LARGE_ORGANIZATION: str = "xl_org" # 1000 - 5000+
class Organization(Model, CMDMixin): class Organization(Model, CMDMixin):
""" """
Organization Organization
@@ -64,10 +64,15 @@ class Organization(Model, CMDMixin):
id: uuid.UUID = fields.UUIDField(primary_key=True) id: uuid.UUID = fields.UUIDField(primary_key=True)
name: str = fields.CharField(max_length=128) name: str = fields.CharField(max_length=128)
type: str = EnumField(OrganizationType) type: str = EnumField(OrganizationType)
street_name: str | None = fields.TextField(null=True)
zip_code: str | None = fields.CharField(max_length=128, null=True)
state: str | None = fields.CharField(max_length=128, null=True)
city: str | None = fields.CharField(max_length=128, null=True)
country: str | None = fields.CharField(max_length=128, null=True)
users: uuid.UUID = fields.ManyToManyField( users: uuid.UUID = fields.ManyToManyField(
"models.User", "models.User",
related_name="members", related_name="members",
through="Membership", through="membership",
forward_key="user_id", forward_key="user_id",
backward_key="organization_id", backward_key="organization_id",
null=True, null=True,
@@ -85,5 +90,3 @@ class Organization(Model, CMDMixin):
self.disabled = True self.disabled = True
self.disabled_at = datetime.now(tz=pytz.UTC) self.disabled_at = datetime.now(tz=pytz.UTC)
await self.save() await self.save()
@@ -1,17 +1,103 @@
from fastapi import APIRouter import uuid
from fastapi import APIRouter, Depends, HTTPException
from typing import Annotated, List
from modules.organizations.models import Organization
from modules.organizations.schemas import organization_model, register_organization
from modules.users.utils import get_current_active_user
from modules.users.models import ACL, Membership, User
from tortoise.expressions import Q
router = APIRouter(prefix="/api/v1/organizations", tags=["orgs"])
router = APIRouter(prefix="/api/v1/organizations") @router.get("/", response_model=List[organization_model])
async def all_active_organizations(
user: Annotated[User, Depends(get_current_active_user)],
) -> List[Organization]:
memberships: List[Membership] = list(
await Membership.filter(
Q(user_id=user.id) & Q(disabled=False)
).prefetch_related("organization")
)
organizations: List[Organization] = []
@router.get("/") if len(memberships) < 1:
def all_organizations(): raise HTTPException(status_code=404, detail="No active organizations found!")
pass
@router.delete("/") for member in memberships:
def delete_organization(): organizations.append(member.organization)
pass
@router.post("/create") return organizations
def create_organization():
pass
@router.delete("/{org_id}", status_code=204)
async def delete_organization(
user: Annotated[User, Depends(get_current_active_user)], org_id: uuid.UUID
) -> None:
membership: Membership | None = (
await Membership.filter(Q(user_id=user.id) & Q(organization_id=org_id))
.get_or_none()
.prefetch_related("acl", "user", "organization")
)
if not membership:
raise HTTPException(
status_code=403,
detail="You are not part of the organization you wish to leave or remove.",
)
if membership.acl.ADMIN:
# Prepare to remove ALL members in the organization.
# We've already checked whether user is ADMIN.
all_memberships: List[Membership] = list(
await Membership.filter(Q(organization_id=org_id)).prefetch_related(
"acl", "user", "organization"
)
)
for member in all_memberships:
await member.acl.delete()
await member.delete()
# Completely remove organization.
await membership.organization.delete()
else:
await membership.delete()
return
@router.post("/", response_model=organization_model)
async def create_organization(
user: Annotated[User, Depends(get_current_active_user)],
register_organization: register_organization,
) -> Organization:
acl: ACL = await ACL.create(
READ=True, WRITE=True, REPORT=True, MANAGE=True, ADMIN=True
)
org: Organization = await Organization.create(
name=register_organization.name,
type=register_organization.type,
street_name=register_organization.street_name,
zip_code=register_organization.zip_code,
state=register_organization.state,
city=register_organization.city,
country=register_organization.country,
)
await Membership.create(organization=org, user=user, acl=acl)
return org
@router.put("/{org_id}", response_model=organization_model)
async def update_organization(
user: Annotated[User, Depends(get_current_active_user)],
org_id: uuid.UUID,
alter_organization: register_organization,
) -> Organization:
org: Organization | None = Organization.filter(
Q(users__id=user.id) & Q(id=org_id)
).get_or_none()
if not org:
raise HTTPException(status_code=404, detail="Organization could not be found.")
return await org.update_from_dict(**alter_organization)
@@ -1,6 +1,15 @@
from pydantic import BaseModel
from tortoise.contrib.pydantic import pydantic_model_creator from tortoise.contrib.pydantic import pydantic_model_creator
from modules.organizations.models import Organization from modules.organizations.models import Organization, OrganizationType
OrganizationModel = pydantic_model_creator(Organization) organization_model = pydantic_model_creator(Organization)
class register_organization(BaseModel):
name: str
type: OrganizationType
street_name: str | None
zip_code: str | None
state: str | None
city: str | None
country: str | None
@@ -1,4 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import List
import uuid import uuid
from pydantic import EmailStr from pydantic import EmailStr
import pytz import pytz
@@ -25,17 +26,16 @@ class User(Model, CMDMixin):
name: str = fields.TextField(max_length=128) name: str = fields.TextField(max_length=128)
surname: str = fields.TextField(max_length=128) surname: str = fields.TextField(max_length=128)
password: str = fields.CharField(max_length=128, null=True) password: str = fields.CharField(max_length=128, null=True)
organizations: uuid = fields.ManyToManyField( organizations: List[Organization] = fields.ManyToManyField(
"models.Organization", "models.Organization",
related_name="members", related_name="members",
through="Membership", through="membership",
forward_key="organization_id", forward_key="organization_id",
backward_key="user_id", backward_key="user_id",
null=True, null=True,
on_delete=fields.NO_ACTION, on_delete=fields.NO_ACTION,
) )
disabled: bool = fields.BooleanField(default=False) disabled: bool = fields.BooleanField(default=False)
# tokens = fields.ForeignKeyField("models.Token")
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.id} - {self.name} {self.surname}" return f"{self.id} - {self.name} {self.surname}"
@@ -98,9 +98,9 @@ class Membership(Model, CMDMixin):
""" """
id: uuid.UUID = fields.UUIDField(primary_key=True) id: uuid.UUID = fields.UUIDField(primary_key=True)
organization: Organization = fields.ForeignKeyField("models.Organization") organization: Organization | None = fields.ForeignKeyField("models.Organization")
user: User = fields.ForeignKeyField("models.User") user: User | None = fields.ForeignKeyField("models.User")
acl: ACL = fields.ForeignKeyField("models.ACL") acl: ACL | None = fields.ForeignKeyField("models.ACL")
disabled: bool = fields.BooleanField(default=False) disabled: bool = fields.BooleanField(default=False)
async def delete(self, force: bool = False) -> None: async def delete(self, force: bool = False) -> None:
+1 -1
View File
@@ -15,4 +15,4 @@ asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session" asyncio_default_fixture_loop_scope = "session"
testpaths = [ testpaths = [
"tests/", "tests/",
] ]
@@ -1,21 +1,21 @@
aerich>=0.8.0 aerich>=0.9.0
fastapi[all]>=0.115.5 fastapi[all]>=0.115.12
python-dotenv>=0.21.0 tortoise-orm[asyncpg]>=0.25.1
tortoise-orm[asyncpg]>=0.22.1 uvicorn>=0.34.3
uvicorn>=0.31.1 black>=25.1.0
black>=24.10.0 joserfc>=1.1.0
joserfc>=1.0.1
passlib>=1.7.4 passlib>=1.7.4
pytz>=2024.2 pytz>=2025.2
ptpython>=0.25 ptpython>=3.0.30
msgspec>=0.19.0 msgspec>=0.19.0
bcrypt>=4.2.1 bcrypt>=4.3.0
tomlkit>=0.13.3
# Test Suite # Test Suite
httpx>=0.28.1 httpx>=0.28.1
pytest>=8.3.4 mock>=5.2.0
mock>=5.1.0 pytest>=8.4.0
asyncio>=3.4.3 asyncio>=3.4.3
pytest-mock>=3.14.0 pytest-mock>=3.14.1
pytest-asyncio>=0.25.3 pytest-asyncio>=1.0.0
asgi-lifespan>=2.1.0 asgi-lifespan>=2.1.0
+5
View File
@@ -0,0 +1,5 @@
import pytest
@pytest.mark.usefixtures("create_user_with_org")
class Test():
pass
+3 -12
View File
@@ -2,18 +2,9 @@ import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator from typing import AsyncGenerator
import httpx, pytest import httpx, pytest
from config import settings
from glob import glob
from asgi_lifespan import LifespanManager # type: ignore from asgi_lifespan import LifespanManager # type: ignore
from tests.fixtures.account import *
settings.PSQL_DB_NAME = settings.PSQL_TEST_DB_NAME
pytest_plugins = [
fixture.replace("/", ".").replace("\\", ".").replace(".py", "")
for fixture in glob("tests/fixtures/*.py")
if "__" not in fixture
]
try: try:
from main import app from main import app
@@ -27,12 +18,12 @@ except ImportError:
ClientManagerType = AsyncGenerator[httpx.AsyncClient, None] ClientManagerType = AsyncGenerator[httpx.AsyncClient, None]
@pytest.fixture(scope="session") @pytest.fixture
def anyio_backend(): def anyio_backend():
return "asyncio" return "asyncio"
@pytest.fixture(scope="session") @pytest.fixture
def event_loop(): def event_loop():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
yield loop yield loop
+59
View File
@@ -0,0 +1,59 @@
import pytest
from dataclasses import dataclass
from modules.auth.utils import create_jwt_tokens
from modules.organizations.models import Organization, OrganizationType
from modules.users.models import ACL, Membership, User
from modules.auth.models import Token
from config import settings
crypt = settings.CRYPT
@dataclass
class user_creation_return_type:
user: User
organization: Organization
acl: ACL
tokens: Token
@pytest.fixture()
async def create_user_with_org():
async def inner_function(email="user@localhost.com",
username="user",
name="awesome",
surname="user",
password="password-dont-use",
organization_name="simple organization",
organization_type=OrganizationType.HOME,
is_admin=False) -> user_creation_return_type:
org: Organization = await Organization.create(
name=organization_name,
type=organization_type
)
acl: ACL = await ACL.create(
READ=True,
WRITE=True,
REPORT=True,
MANAGE=True if is_admin else False,
ADMIN=True if is_admin else False,
)
user: User = await User.create(
email=email,
username=username,
name=name,
surname=surname,
password=crypt.hash(password),
)
await Membership.create(
organization=org,
user=user,
acl=acl
)
tokens: Token = await create_jwt_tokens(user=user)
return user, org, acl, tokens
return inner_function
@@ -1,86 +0,0 @@
from modules.organizations.models import Organization, OrganizationType
from modules.users.models import ACL, Membership, User
import pytest # type=ignore
from config import settings
crypt = settings.CRYPT
@pytest.fixture()
async def use_user_account():
org, _ = await Organization.get_or_create(
id="6ad4c94e-0522-4912-8d16-02d451f4c92d",
defaults={
"name": "User's Organization",
"type": OrganizationType.HOME,
},
)
acl, _ = await ACL.get_or_create(
id="a4e927a3-36e5-4761-badb-0a44ade6616f",
defaults={
"READ": True,
"WRITE": True,
"REPORT": True,
"MANAGE": False,
"ADMIN": False,
},
)
user, _ = await User.get_or_create(
id="24235427-9662-4ba3-a9c5-00000000000b",
defaults={
"email": "user@localhost.com",
"username": "user",
"name": "awesome",
"surname": "user",
"password": crypt.hash("userpassword"),
},
)
membership, _ = await Membership.get_or_create(
id="833b9511-b2da-4760-8fa4-1a5c7059911e",
defaults={
"organization": org,
"user": user,
"acl": acl,
},
)
return org, acl, user, membership
@pytest.fixture()
async def use_admin_account():
org, _ = await Organization.get_or_create(
id="de001f44-1bb8-4667-9f9d-2d62d6ad7270",
defaults={
"name": "Admin's Organization",
"type": OrganizationType.EXTRA_LARGE_ORGANIZATION,
},
)
acl, _ = await ACL.get_or_create(
id="83c1bfe6-c2ed-4ba1-be03-0e5c1960ec31",
defaults={
"READ": True,
"WRITE": True,
"REPORT": True,
"MANAGE": True,
"ADMIN": True,
},
)
user, _ = await User.get_or_create(
id="24235427-9662-4ba3-a9c5-00000000000a",
defaults={
"email": "admin@localhost.com",
"username": "admin",
"name": "awesome",
"surname": "admin",
"password": crypt.hash("adminpassword"),
},
)
membership, _ = await Membership.get_or_create(
id="393473ee-c218-4bcf-82cd-cb676c4d8a33",
defaults={
"organization": org,
"user": user,
"acl": acl,
},
)
return org, acl, user, membership
@@ -1,15 +1,14 @@
from modules.users.models import User from modules.users.models import User
import pytest # type: ignore
from httpx import AsyncClient from httpx import AsyncClient
from config import settings from config import settings
from unittest.mock import ANY from unittest.mock import ANY
from tortoise.expressions import Q from tortoise.expressions import Q
from tests.base_test import Test
crypt = settings.CRYPT crypt = settings.CRYPT
class TestAuthentication(object): class TestAuthentication(Test):
@pytest.mark.asyncio
async def test_authentication_with_non_existing_user_and_password( async def test_authentication_with_non_existing_user_and_password(
self, client: AsyncClient self, client: AsyncClient
): ):
@@ -24,11 +23,10 @@ class TestAuthentication(object):
assert response.status_code == 401 assert response.status_code == 401
assert response.json() == {"detail": "E-Mail Address or password is incorrect"} assert response.json() == {"detail": "E-Mail Address or password is incorrect"}
@pytest.mark.asyncio
async def test_authentication_with_existing_user_and_wrong_password( async def test_authentication_with_existing_user_and_wrong_password(
self, client: AsyncClient, use_admin_account self, client: AsyncClient, create_user_with_org
): ):
_, _, _, _ = use_admin_account _, _, _, _ = await create_user_with_org(email="admin@localhost.com")
response = await client.post( response = await client.post(
"https://localhost/api/v1/auth/login", "https://localhost/api/v1/auth/login",
data={ data={
@@ -40,11 +38,10 @@ class TestAuthentication(object):
assert response.status_code == 401 assert response.status_code == 401
assert response.json() == {"detail": "E-Mail Address or password is incorrect"} assert response.json() == {"detail": "E-Mail Address or password is incorrect"}
@pytest.mark.asyncio
async def test_authentication_with_existing_user_and_password( async def test_authentication_with_existing_user_and_password(
self, client: AsyncClient, use_admin_account self, client: AsyncClient, create_user_with_org
): ):
_, _, admin, _ = use_admin_account user, _, _, _ = await create_user_with_org(email="admin@localhost.com", password="adminpassword")
response = await client.post( response = await client.post(
"https://localhost/api/v1/auth/login", "https://localhost/api/v1/auth/login",
data={ data={
@@ -57,7 +54,7 @@ class TestAuthentication(object):
assert response.json() == { assert response.json() == {
"jwt": { "jwt": {
"created_at": ANY, "created_at": ANY,
"user_id": str(admin.id), "user_id": str(user.id),
"id": ANY, "id": ANY,
"modified_at": ANY, "modified_at": ANY,
"disabled_at": None, "disabled_at": None,
@@ -68,11 +65,10 @@ class TestAuthentication(object):
} }
} }
@pytest.mark.asyncio
async def test_logging_out_destroys_tokens( async def test_logging_out_destroys_tokens(
self, client: AsyncClient, use_user_account self, client: AsyncClient, create_user_with_org
): ):
_, _, user, _ = use_user_account user, _, _, _ = await create_user_with_org(email="user@localhost.com", password="userpassword")
response = await client.post( response = await client.post(
"https://localhost/api/v1/auth/login", "https://localhost/api/v1/auth/login",
data={ data={
@@ -115,11 +111,10 @@ class TestAuthentication(object):
"detail": "Refresh token not found or something went wrong." "detail": "Refresh token not found or something went wrong."
} }
@pytest.mark.asyncio
async def test_create_new_tokens_upon_refresh( async def test_create_new_tokens_upon_refresh(
self, client: AsyncClient, use_admin_account self, client: AsyncClient, create_user_with_org
): ):
_, _, admin, _ = use_admin_account user, _, _, _ = await create_user_with_org(email="admin@localhost.com", password="adminpassword")
token = await client.post( token = await client.post(
"https://localhost/api/v1/auth/login", "https://localhost/api/v1/auth/login",
data={ data={
@@ -128,11 +123,13 @@ class TestAuthentication(object):
"grant_type": "password", "grant_type": "password",
}, },
) )
assert token.__dict__ == True
assert token.status_code == 200 assert token.status_code == 200
assert token.json() == { assert token.json() == {
"jwt": { "jwt": {
"created_at": ANY, "created_at": ANY,
"user_id": str(admin.id), "user_id": str(user.id),
"id": ANY, "id": ANY,
"modified_at": ANY, "modified_at": ANY,
"disabled_at": None, "disabled_at": None,
@@ -154,7 +151,7 @@ class TestAuthentication(object):
assert response2.json() == { assert response2.json() == {
"jwt": { "jwt": {
"created_at": ANY, "created_at": ANY,
"user_id": str(admin.id), "user_id": str(user.id),
"id": ANY, "id": ANY,
"modified_at": ANY, "modified_at": ANY,
"disabled_at": None, "disabled_at": None,
@@ -165,7 +162,6 @@ class TestAuthentication(object):
} }
} }
@pytest.mark.asyncio
async def test_setup_new_account(self, client: AsyncClient): async def test_setup_new_account(self, client: AsyncClient):
# Ensure account is never available. Prevents account already being available. # Ensure account is never available. Prevents account already being available.
check_if_account_exists: User | None = await User.filter( check_if_account_exists: User | None = await User.filter(
@@ -0,0 +1,181 @@
import pytest
from httpx import AsyncClient
from config import settings
from unittest.mock import ANY
from tests.base_test import Test
crypt = settings.CRYPT
class TestOrganizationRoute(Test):
@pytest.mark.asyncio
async def test_get_organizations_from_api(
self, client: AsyncClient, create_user_with_org
):
_,_,_,tokens = await create_user_with_org()
organizations = await client.get(
"https://localhost/api/v1/organizations/",
headers={"Authorization": f"Bearer {tokens.access_token}"},
)
assert organizations.status_code == 200
assert organizations.json() == [
{
"created_at": ANY,
"disabled": False,
"disabled_at": None,
"id": ANY,
"modified_at": ANY,
"name": "simple organization",
"type": "home",
"street_name": None,
"zip_code": None,
"state": None,
"city": None,
"country": None,
},
]
@pytest.mark.asyncio
async def test_create_organization(
self, client: AsyncClient, create_user_with_org
):
_,_,_,tokens = await create_user_with_org()
organizations = await client.post(
"https://localhost/api/v1/organizations/",
json={
"name": "My new organization",
"type": "xl_org",
"street_name": "Alakaventie 5 A 188",
"zip_code": "00920",
"state": "uusimaa",
"city": "Helsinki",
"country": "Finland",
},
headers={"Authorization": f"Bearer {tokens.access_token}"},
)
assert organizations.status_code == 200
assert organizations.json() == {
"created_at": ANY,
"modified_at": ANY,
"disabled_at": None,
"id": ANY,
"name": "My new organization",
"type": "xl_org",
"street_name": "Alakaventie 5 A 188",
"zip_code": "00920",
"state": "uusimaa",
"city": "Helsinki",
"country": "Finland",
"disabled": False,
}
@pytest.mark.asyncio
async def test_delete_organization(
self, client: AsyncClient, create_user_with_org
):
_,_,_,tokens = await create_user_with_org()
organizations = await client.post(
"https://localhost/api/v1/organizations/",
json={
"name": "My new organization",
"type": "xl_org",
"street_name": "Alakaventie 5 A 188",
"zip_code": "00920",
"state": "uusimaa",
"city": "Helsinki",
"country": "Finland",
},
headers={"Authorization": f"Bearer {tokens.access_token}"},
)
assert organizations.status_code == 200
assert organizations.json() == {
"created_at": ANY,
"modified_at": ANY,
"disabled_at": None,
"id": ANY,
"name": "My new organization",
"type": "xl_org",
"street_name": "Alakaventie 5 A 188",
"zip_code": "00920",
"state": "uusimaa",
"city": "Helsinki",
"country": "Finland",
"disabled": False,
}
org_id = organizations.json()["id"]
deleted_org = await client.delete(
f"https://localhost/api/v1/organizations/{org_id}",
headers={"Authorization": f"Bearer {tokens.access_token}"},
)
assert deleted_org.status_code == 204
# @pytest.mark.asyncio
# async def test_update_organization(
# self, client: AsyncClient, get_admin_login_token
# ):
# access_token, _ = get_admin_login_token
# organizations = await client.post(
# "https://localhost/api/v1/organizations/",
# json={
# "name": "My new organization",
# "type": "xl_org",
# "street_name": "Alakaventie 5 A 188",
# "zip_code": "00920",
# "state": "uusimaa",
# "city": "Helsinki",
# "country": "Finland",
# },
# headers={"Authorization": f"Bearer {access_token}"},
# )
# assert organizations.status_code == 200
# assert organizations.json() == {
# "created_at": ANY,
# "modified_at": ANY,
# "disabled_at": None,
# "id": ANY,
# "name": "My new organization",
# "type": "xl_org",
# "street_name": "Alakaventie 5 A 188",
# "zip_code": "00920",
# "state": "uusimaa",
# "city": "Helsinki",
# "country": "Finland",
# "disabled": False,
# }
# org_id = organizations.json()["id"]
# update_org = await client.put(
# f"https://localhost/api/v1/organizations/{org_id}",
# json={
# "name": "My awesome organization",
# },
# headers={"Authorization": f"Bearer {access_token}"},
# )
# assert update_org.json() == {
# "created_at": ANY,
# "modified_at": ANY,
# "disabled_at": None,
# "id": ANY,
# "name": "My new organization",
# "type": "xl_org",
# "street_name": "Alakaventie 5 A 188",
# "zip_code": "00920",
# "state": "uusimaa",
# "city": "Helsinki",
# "country": "Finland",
# "disabled": False,
# }