feat(auth): implement Keycloak authentication with RBAC and pagination
Some checks failed
ci/woodpecker/pr/lint Pipeline failed
ci/woodpecker/pr/test Pipeline was successful
ci/woodpecker/pr/type Pipeline was successful

Major changes:
- Add Keycloak integration via token introspection endpoint
- Implement RBAC system with roles: admin, user, guest
- Add role-based permissions for post operations
- Add pagination support (default limit: 10) to list endpoints
- Add published_only filter with admin-only override for unpublished posts

Security improvements:
- Remove hardcoded default secrets (SECRET_KEY, KEYCLOAK_CLIENT_SECRET)
- Update .env.example with proper security placeholders
- Add comprehensive RBAC unit tests

Infrastructure:
- Add httpx dependency for HTTP client
- Add KeycloakAuthClient with token caching (TTL: 60s)
- Add role-based dependencies (RequireAdmin, RequireUser, etc.)
- Update DI container with Keycloak provider

Endpoints updated:
- GET /posts: filter by published status (admin can see all)
- Add pagination params (limit, offset) to list endpoints
- Enforce RBAC on post operations

Tests:
- Add 16 auth infrastructure tests
- Add 13 RBAC role tests
- Update existing tests for new required settings

Breaking changes:
- SECRET_KEY and KEYCLOAK_CLIENT_SECRET now required (no defaults)
This commit is contained in:
2026-05-02 00:43:10 +03:00
parent ddab62a883
commit 184b95969c
20 changed files with 1461 additions and 99 deletions

View File

@@ -22,24 +22,45 @@ class ListPostsUseCase:
posts = await self._post_repo.get_all()
return [self._map_to_dto(post) for post in posts]
async def published_posts(self) -> list[PostResponseDTO]:
async def published_posts(
self,
limit: int | None = None,
offset: int | None = None,
) -> list[PostResponseDTO]:
"""Get all published posts."""
posts = await self._post_repo.get_published()
posts = await self._post_repo.get_published(limit=limit, offset=offset)
return [self._map_to_dto(post) for post in posts]
async def by_author(self, author_id: str) -> list[PostResponseDTO]:
async def by_author(
self,
author_id: str,
limit: int | None = None,
offset: int | None = None,
) -> list[PostResponseDTO]:
"""Get posts by author."""
posts = await self._post_repo.get_by_author(author_id)
posts = await self._post_repo.get_by_author(
author_id, limit=limit, offset=offset
)
return [self._map_to_dto(post) for post in posts]
async def by_tag(self, tag: str) -> list[PostResponseDTO]:
async def by_tag(
self,
tag: str,
limit: int | None = None,
offset: int | None = None,
) -> list[PostResponseDTO]:
"""Get posts by tag."""
posts = await self._post_repo.get_by_tag(tag)
posts = await self._post_repo.get_by_tag(tag, limit=limit, offset=offset)
return [self._map_to_dto(post) for post in posts]
async def search(self, query: str) -> list[PostResponseDTO]:
async def search(
self,
query: str,
limit: int | None = None,
offset: int | None = None,
) -> list[PostResponseDTO]:
"""Search posts."""
posts = await self._post_repo.search(query)
posts = await self._post_repo.search(query, limit=limit, offset=offset)
return [self._map_to_dto(post) for post in posts]
def _map_to_dto(self, post: Post) -> PostResponseDTO:

View File

@@ -15,17 +15,31 @@ class PostRepository(Repository[Post]):
...
@abstractmethod
async def get_by_author(self, author_id: str) -> list[Post]:
async def get_by_author(
self,
author_id: str,
limit: int | None = None,
offset: int | None = None,
) -> list[Post]:
"""Get all posts by author."""
...
@abstractmethod
async def get_published(self) -> list[Post]:
async def get_published(
self,
limit: int | None = None,
offset: int | None = None,
) -> list[Post]:
"""Get all published posts."""
...
@abstractmethod
async def get_by_tag(self, tag: str) -> list[Post]:
async def get_by_tag(
self,
tag: str,
limit: int | None = None,
offset: int | None = None,
) -> list[Post]:
"""Get posts by tag."""
...
@@ -35,6 +49,11 @@ class PostRepository(Repository[Post]):
...
@abstractmethod
async def search(self, query: str) -> list[Post]:
async def search(
self,
query: str,
limit: int | None = None,
offset: int | None = None,
) -> list[Post]:
"""Search posts by query string."""
...

102
app/domain/roles.py Normal file
View File

@@ -0,0 +1,102 @@
"""Role-based access control definitions."""
from enum import Enum
from functools import wraps
from typing import Any, Callable
from app.domain.exceptions import ForbiddenException
class Role(str, Enum):
"""User roles in the system."""
ADMIN = "admin"
USER = "user"
GUEST = "guest"
class Permission:
"""Permission definitions."""
# Post permissions
POST_CREATE = "post:create"
POST_READ = "post:read"
POST_READ_UNPUBLISHED = "post:read_unpublished"
POST_UPDATE = "post:update"
POST_DELETE = "post:delete"
POST_PUBLISH = "post:publish"
# Role-based permission mapping
ROLE_PERMISSIONS: dict[Role, list[str]] = {
Role.ADMIN: [
Permission.POST_CREATE,
Permission.POST_READ,
Permission.POST_READ_UNPUBLISHED,
Permission.POST_UPDATE,
Permission.POST_DELETE,
Permission.POST_PUBLISH,
],
Role.USER: [
Permission.POST_CREATE,
Permission.POST_READ,
Permission.POST_UPDATE,
Permission.POST_DELETE,
Permission.POST_PUBLISH,
],
Role.GUEST: [
Permission.POST_READ,
],
}
def has_permission(role: Role, permission: str) -> bool:
"""Check if role has specific permission."""
return permission in ROLE_PERMISSIONS.get(role, [])
def require_permission(
permission: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Decorator to require specific permission."""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# Get token_info from kwargs
token_info = kwargs.get("token_info")
if not token_info:
raise ForbiddenException("Authentication required")
# Determine role from token or default to guest
roles = getattr(token_info, "roles", [])
if Role.ADMIN.value in roles:
role = Role.ADMIN
elif Role.USER.value in roles:
role = Role.USER
else:
role = Role.GUEST
if not has_permission(role, permission):
raise ForbiddenException(
f"Permission '{permission}' required for role '{role.value}'"
)
return await func(*args, **kwargs)
return wrapper
return decorator
def get_effective_role(roles: list[str]) -> Role:
"""Determine effective role from list of roles.
Priority: admin > user > guest
"""
if Role.ADMIN.value in roles:
return Role.ADMIN
elif Role.USER.value in roles:
return Role.USER
else:
return Role.GUEST

View File

@@ -0,0 +1,6 @@
"""Authentication infrastructure package."""
from app.infrastructure.auth.client import KeycloakAuthClient
from app.infrastructure.auth.models import KeycloakUser, TokenInfo
__all__ = ["KeycloakAuthClient", "KeycloakUser", "TokenInfo"]

View File

@@ -0,0 +1,127 @@
"""Keycloak authentication client."""
import time
import httpx
from app.infrastructure.auth.models import KeycloakUser, TokenInfo
from app.infrastructure.config.settings import Settings
class KeycloakAuthClient:
"""Client for Keycloak authentication operations."""
def __init__(self, settings: Settings) -> None:
"""Initialize Keycloak client with settings."""
self._settings = settings
self._base_url = f"{settings.kc.server_url}/realms/{settings.kc.realm}"
self._client_id = settings.kc.client_id
self._client_secret = settings.kc.client_secret
self._cache: dict[str, tuple[TokenInfo, float]] = {}
self._cache_ttl = settings.kc.token_cache_ttl
def _get_introspection_url(self) -> str:
"""Get token introspection endpoint URL."""
return f"{self._base_url}/protocol/openid-connect/token/introspection"
def _get_userinfo_url(self) -> str:
"""Get userinfo endpoint URL."""
return f"{self._base_url}/protocol/openid-connect/userinfo"
def _get_cached_token(self, token: str) -> TokenInfo | None:
"""Get cached token info if valid."""
if token not in self._cache:
return None
token_info, cached_at = self._cache[token]
if time.time() - cached_at > self._cache_ttl:
del self._cache[token]
return None
return token_info
def _cache_token(self, token: str, token_info: TokenInfo) -> None:
"""Cache token info."""
self._cache[token] = (token_info, time.time())
# Simple cleanup of old entries
current_time = time.time()
expired_keys = [
k for k, (_, t) in self._cache.items() if current_time - t > self._cache_ttl
]
for k in expired_keys:
del self._cache[k]
async def introspect_token(self, token: str) -> TokenInfo:
"""Introspect access token using Keycloak."""
# Check cache first
cached = self._get_cached_token(token)
if cached:
return cached
# Prepare introspection request
data = {
"token": token,
"client_id": self._client_id,
"client_secret": self._client_secret,
}
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self._get_introspection_url(),
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
timeout=10.0,
)
response.raise_for_status()
result = response.json()
except httpx.HTTPError as e:
return TokenInfo(active=False, raw_claims={"error": str(e)})
if not result.get("active", False):
return TokenInfo(active=False, raw_claims=result)
# Extract roles from realm_access or resource_access
roles: list[str] = []
realm_access = result.get("realm_access", {})
if isinstance(realm_access, dict):
roles.extend(realm_access.get("roles", []))
token_info = TokenInfo(
active=True,
user_id=result.get("sub", ""),
username=result.get("preferred_username", ""),
email=result.get("email", ""),
roles=roles,
raw_claims=result,
)
# Cache valid token
self._cache_token(token, token_info)
return token_info
async def get_userinfo(self, token: str) -> KeycloakUser | None:
"""Get user information from Keycloak using access token."""
try:
async with httpx.AsyncClient() as client:
response = await client.get(
self._get_userinfo_url(),
headers={"Authorization": f"Bearer {token}"},
timeout=10.0,
)
response.raise_for_status()
data = response.json()
except httpx.HTTPError:
return None
return KeycloakUser(
id=data.get("sub", ""),
username=data.get("preferred_username", ""),
email=data.get("email", ""),
first_name=data.get("given_name", ""),
last_name=data.get("family_name", ""),
roles=data.get("realm_access", {}).get("roles", [])
if isinstance(data.get("realm_access"), dict)
else [],
)

View File

@@ -0,0 +1,34 @@
"""Keycloak authentication models."""
from dataclasses import dataclass, field
from typing import Any
@dataclass(frozen=True)
class TokenInfo:
"""Information about validated token from Keycloak."""
active: bool
user_id: str = ""
username: str = ""
email: str = ""
roles: list[str] = field(default_factory=list)
raw_claims: dict[str, Any] = field(default_factory=dict, repr=False)
@property
def is_valid(self) -> bool:
"""Check if token is valid and active."""
return self.active and bool(self.user_id)
@dataclass(frozen=True)
class KeycloakUser:
"""User information from Keycloak."""
id: str
username: str
email: str
first_name: str = ""
last_name: str = ""
roles: list[str] = field(default_factory=list)
is_active: bool = True

View File

@@ -1,5 +1,21 @@
"""Infrastructure configuration."""
from app.infrastructure.config.settings import Settings, settings
from app.infrastructure.config.settings import (
AppConfig,
DBConfig,
Environment,
KCConfig,
SecurityConfig,
Settings,
settings,
)
__all__ = ["Settings", "settings"]
__all__ = [
"AppConfig",
"DBConfig",
"KCConfig",
"SecurityConfig",
"Environment",
"Settings",
"settings",
]

View File

@@ -1,31 +1,173 @@
"""Application settings."""
"""Application settings with composition pattern."""
from enum import Enum
from functools import cached_property
from pydantic import Field, PostgresDsn, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application configuration settings."""
class Environment(str, Enum):
"""Application environment modes."""
# App settings
app_name: str = "Blog API"
DEV = "dev"
PROD = "prod"
class AppConfig(BaseSettings):
"""Application configuration."""
name: str = "Blog API"
debug: bool = False
host: str = "0.0.0.0"
port: int = 8000
# Database settings
database_url: str = "sqlite:///./blog.db"
database_echo: bool = False
# Security settings
secret_key: str = "your-secret-key-change-in-production"
access_token_expire_minutes: int = 30
model_config = SettingsConfigDict(
env_file=".env",
env_prefix="APP_",
env_file_encoding="utf-8",
case_sensitive=False,
)
class DBConfig(BaseSettings):
"""Database configuration."""
# For dev: sqlite+aiosqlite:///./blog.db
# For prod: postgresql+asyncpg://user:pass@host:port/db
url: str | None = None
echo: bool = False
# PostgreSQL-specific settings (used in prod)
host: str = "localhost"
port: int = 5432
user: str = "postgres"
password: str = "postgres"
name: str = "blog"
model_config = SettingsConfigDict(
env_prefix="DB_",
env_file_encoding="utf-8",
case_sensitive=False,
)
@field_validator("url")
@classmethod
def validate_url(cls, v: str | None) -> str | None:
"""Validate database URL if provided."""
if v is None:
return v
if not any(v.startswith(prefix) for prefix in ("sqlite+", "postgresql+")):
raise ValueError("Database URL must start with 'sqlite+' or 'postgresql+'")
return v
class KCConfig(BaseSettings):
"""Keycloak configuration."""
server_url: str = "http://localhost:8080"
realm: str = "blog"
client_id: str = "blog-api"
client_secret: str = Field(
default="",
description="Keycloak client secret - must be set via env in production",
)
token_cache_ttl: int = 60 # seconds
model_config = SettingsConfigDict(
env_prefix="KC_",
env_file_encoding="utf-8",
case_sensitive=False,
)
@property
def is_configured(self) -> bool:
"""Check if Keycloak is properly configured."""
return bool(self.client_secret)
class SecurityConfig(BaseSettings):
"""Security configuration."""
secret_key: str = Field(
default="", description="Secret key for JWT - must be set via env in production"
)
access_token_expire_minutes: int = 30
model_config = SettingsConfigDict(
env_prefix="SECURITY_",
env_file_encoding="utf-8",
case_sensitive=False,
)
@property
def is_configured(self) -> bool:
"""Check if security is properly configured."""
return bool(self.secret_key)
class Settings(BaseSettings):
"""Application configuration settings with composition."""
# Environment mode
environment: Environment = Environment.DEV
# Sub-configurations
app: AppConfig = Field(default_factory=AppConfig)
db: DBConfig = Field(default_factory=DBConfig)
kc: KCConfig = Field(default_factory=KCConfig)
security: SecurityConfig = Field(default_factory=SecurityConfig)
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
env_nested_delimiter="__",
)
def model_post_init(self, __context: object) -> None:
"""Validate settings after initialization."""
if self.is_prod:
if not self.security.is_configured:
raise ValueError("SECURITY_SECRET_KEY must be set in production mode")
if not self.kc.is_configured:
raise ValueError("KC_CLIENT_SECRET must be set in production mode")
@cached_property
def database_url(self) -> str:
"""Get database URL based on environment.
- In dev: uses SQLite if no URL provided
- In prod: uses PostgreSQL if no URL provided
"""
if self.db.url:
return self.db.url
if self.environment == Environment.PROD:
# Build PostgreSQL URL from components
return str(
PostgresDsn.build(
scheme="postgresql+asyncpg",
username=self.db.user,
password=self.db.password,
host=self.db.host,
port=self.db.port,
path=self.db.name,
)
)
# Default dev SQLite URL
return "sqlite+aiosqlite:///./blog.db"
@property
def is_dev(self) -> bool:
"""Check if running in development mode."""
return self.environment == Environment.DEV
@property
def is_prod(self) -> bool:
"""Check if running in production mode."""
return self.environment == Environment.PROD
# Global settings instance
settings = Settings()

View File

@@ -24,7 +24,7 @@ def _get_database_url() -> str:
# Create async engine
engine: AsyncEngine = create_async_engine(
_get_database_url(),
echo=settings.database_echo,
echo=settings.db.echo,
future=True,
)

View File

@@ -15,6 +15,8 @@ from app.application import (
)
from app.application.interfaces import TransactionManager
from app.domain.repositories import PostRepository
from app.infrastructure.auth import KeycloakAuthClient
from app.infrastructure.config.settings import settings
from app.infrastructure.database.connection import AsyncSessionLocal, engine
from app.infrastructure.repositories.post import SQLAlchemyPostRepository
@@ -131,3 +133,12 @@ class UseCaseProvider(Provider):
post_repo=post_repo,
tx_manager=tx_manager,
)
class KeycloakProvider(Provider):
"""Provider for Keycloak authentication client."""
@provide(scope=Scope.APP)
def get_keycloak_client(self) -> KeycloakAuthClient:
"""Provide KeycloakAuthClient singleton."""
return KeycloakAuthClient(settings)

View File

@@ -105,27 +105,50 @@ class SQLAlchemyPostRepository(PostRepository):
orm = result.scalar_one_or_none()
return self._to_domain(orm) if orm else None
async def get_by_author(self, author_id: str) -> list[Post]:
async def get_by_author(
self,
author_id: str,
limit: int | None = None,
offset: int | None = None,
) -> list[Post]:
"""Get posts by author."""
result = await self._session.execute(
select(PostORM).where(PostORM.author_id == author_id)
)
query = select(PostORM).where(PostORM.author_id == author_id)
if limit is not None:
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
result = await self._session.execute(query)
orms = result.scalars().all()
return [self._to_domain(orm) for orm in orms]
async def get_published(self) -> list[Post]:
async def get_published(
self,
limit: int | None = None,
offset: int | None = None,
) -> list[Post]:
"""Get published posts."""
result = await self._session.execute(
select(PostORM).where(PostORM.published.is_(True))
)
query = select(PostORM).where(PostORM.published.is_(True))
if limit is not None:
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
result = await self._session.execute(query)
orms = result.scalars().all()
return [self._to_domain(orm) for orm in orms]
async def get_by_tag(self, tag: str) -> list[Post]:
async def get_by_tag(
self,
tag: str,
limit: int | None = None,
offset: int | None = None,
) -> list[Post]:
"""Get posts by tag."""
result = await self._session.execute(
select(PostORM).where(PostORM.tags.contains([tag]))
)
query = select(PostORM).where(PostORM.tags.contains([tag]))
if limit is not None:
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
result = await self._session.execute(query)
orms = result.scalars().all()
return [self._to_domain(orm) for orm in orms]
@@ -136,16 +159,24 @@ class SQLAlchemyPostRepository(PostRepository):
)
return result.scalar_one_or_none() is not None
async def search(self, query: str) -> list[Post]:
async def search(
self,
query: str,
limit: int | None = None,
offset: int | None = None,
) -> list[Post]:
"""Search posts."""
search_pattern = f"%{query}%"
result = await self._session.execute(
select(PostORM).where(
or_(
PostORM.title.ilike(search_pattern),
PostORM.content.ilike(search_pattern),
)
stmt = select(PostORM).where(
or_(
PostORM.title.ilike(search_pattern),
PostORM.content.ilike(search_pattern),
)
)
if limit is not None:
stmt = stmt.limit(limit)
if offset is not None:
stmt = stmt.offset(offset)
result = await self._session.execute(stmt)
orms = result.scalars().all()
return [self._to_domain(orm) for orm in orms]

View File

@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
from app.infrastructure import close_db, init_db, register_exception_handlers, settings
from app.infrastructure.di.providers import (
DatabaseProvider,
KeycloakProvider,
RepositoryProvider,
TransactionManagerProvider,
UseCaseProvider,
@@ -32,11 +33,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
def app_factory() -> FastAPI:
"""Create and configure FastAPI application."""
app = FastAPI(
title=settings.app_name,
debug=settings.debug,
title=settings.app.name,
debug=settings.app.debug,
lifespan=lifespan,
docs_url="/docs" if settings.debug else None,
redoc_url="/redoc" if settings.debug else None,
docs_url="/docs" if settings.is_dev else None,
redoc_url="/redoc" if settings.is_dev else None,
)
# Setup Dishka DI container
@@ -45,6 +46,7 @@ def app_factory() -> FastAPI:
RepositoryProvider(),
TransactionManagerProvider(),
UseCaseProvider(),
KeycloakProvider(),
)
setup_dishka(container, app)
@@ -66,7 +68,11 @@ def app_factory() -> FastAPI:
# Health check endpoint
@app.get("/health", tags=["health"])
async def health_check() -> dict[str, str]:
return {"status": "ok", "app": settings.app_name}
return {
"status": "ok",
"app": settings.app.name,
"env": settings.environment.value,
}
return app
@@ -76,8 +82,8 @@ def main() -> None:
uvicorn.run(
app_factory,
factory=True,
host=settings.host,
port=settings.port,
host=settings.app.host,
port=settings.app.port,
)

View File

@@ -1,9 +1,10 @@
"""API dependencies using Dishka."""
from typing import Annotated
from typing import Annotated, Any
from dishka.integrations.fastapi import FromDishka
from fastapi import Depends, Header
from fastapi import Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.application import (
CreatePostUseCase,
@@ -13,6 +14,9 @@ from app.application import (
PublishPostUseCase,
UpdatePostUseCase,
)
from app.domain.exceptions import ForbiddenException, UnauthorizedException
from app.domain.roles import Role, get_effective_role
from app.infrastructure.auth import KeycloakAuthClient, TokenInfo
# Use case dependencies - injected via Dishka
CreatePostDep = FromDishka[CreatePostUseCase]
@@ -22,13 +26,106 @@ DeletePostDep = FromDishka[DeletePostUseCase]
ListPostsDep = FromDishka[ListPostsUseCase]
PublishPostDep = FromDishka[PublishPostUseCase]
# Security scheme
security = HTTPBearer(auto_error=False)
def get_keycloak_client(request: Request) -> KeycloakAuthClient:
"""Get Keycloak client from DI container via request state."""
client: KeycloakAuthClient = request.state.dishka_container.get(KeycloakAuthClient)
return client
async def get_current_token_info(
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
request: Request,
) -> TokenInfo:
"""Validate token and return token info from Keycloak."""
if not credentials:
raise UnauthorizedException("Authentication required")
keycloak_client = get_keycloak_client(request)
token = credentials.credentials
token_info = await keycloak_client.introspect_token(token)
if not token_info.is_valid:
raise UnauthorizedException("Invalid or expired token")
return token_info
# Mock current user dependency (replace with real auth)
async def get_current_user_id(
x_user_id: Annotated[str | None, Header()] = "user-123",
token_info: Annotated[TokenInfo, Depends(get_current_token_info)],
) -> str:
"""Get current user ID from header."""
return x_user_id or "user-123"
"""Get current user ID from validated token."""
return token_info.user_id
CurrentUserDep = Annotated[str, Depends(get_current_user_id)]
TokenInfoDep = Annotated[TokenInfo, Depends(get_current_token_info)]
# Optional auth - doesn't require authentication but provides user info if available
async def get_optional_token_info(
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
request: Request,
) -> TokenInfo | None:
"""Get token info if valid token provided, otherwise None (guest)."""
if not credentials:
return None
keycloak_client = get_keycloak_client(request)
token = credentials.credentials
token_info = await keycloak_client.introspect_token(token)
if token_info.is_valid:
return token_info
return None
OptionalTokenInfoDep = Annotated[TokenInfo | None, Depends(get_optional_token_info)]
async def get_optional_user_id(
token_info: OptionalTokenInfoDep,
) -> str | None:
"""Get current user ID if token is valid, otherwise None."""
if token_info:
return token_info.user_id
return None
OptionalUserDep = Annotated[str | None, Depends(get_optional_user_id)]
def get_current_role(token_info: OptionalTokenInfoDep) -> Role:
"""Get effective role from token info.
Returns GUEST if no valid token provided.
"""
if token_info and token_info.roles:
return get_effective_role(token_info.roles)
return Role.GUEST
CurrentRoleDep = Annotated[Role, Depends(get_current_role)]
def require_roles(allowed_roles: list[Role]) -> Any:
"""Create dependency that checks if user has one of the allowed roles."""
async def check_role(role: CurrentRoleDep) -> Role:
if role not in allowed_roles:
raise ForbiddenException(
f"Access denied. Required roles: {[r.value for r in allowed_roles]}"
)
return role
return Depends(check_role)
# Predefined role requirements
RequireAdmin = require_roles([Role.ADMIN])
RequireUser = require_roles([Role.USER, Role.ADMIN])
RequireAny = require_roles([Role.GUEST, Role.USER, Role.ADMIN])

View File

@@ -6,8 +6,11 @@ from dishka.integrations.fastapi import DishkaRoute
from fastapi import APIRouter, status
from app.application.dtos import CreatePostDTO, UpdatePostDTO
from app.domain.exceptions import ForbiddenException
from app.domain.roles import Permission, has_permission
from app.presentation.api.deps import (
CreatePostDep,
CurrentRoleDep,
CurrentUserDep,
DeletePostDep,
GetPostDep,
@@ -50,11 +53,38 @@ async def create_post(
@router.get(
"",
response_model=PostListResponseSchema,
summary="List all posts",
summary="List posts",
)
async def list_posts(use_case: ListPostsDep) -> PostListResponseSchema:
"""Get all blog posts."""
results = await use_case.all_posts()
async def list_posts(
use_case: ListPostsDep,
role: CurrentRoleDep,
include_unpublished: bool = False,
limit: int = 10,
offset: int = 0,
) -> PostListResponseSchema:
"""Get blog posts with optional filtering and pagination.
Args:
include_unpublished: If True, returns all posts including drafts.
Only admins can use this parameter.
limit: Maximum number of posts to return (default: 10, max: 100).
offset: Number of posts to skip (default: 0).
Raises:
ForbiddenException: If non-admin tries to include unpublished posts.
"""
# Clamp limit to reasonable range
limit = max(1, min(limit, 100))
offset = max(0, offset)
# Check permissions for unpublished posts
if include_unpublished:
if not has_permission(role, Permission.POST_READ_UNPUBLISHED):
raise ForbiddenException("Only admins can view unpublished posts")
results = await use_case.all_posts()
else:
results = await use_case.published_posts(limit=limit, offset=offset)
items = [PostResponseSchema(**r.__dict__) for r in results]
return PostListResponseSchema(items=items, total=len(items))