feat(auth): implement Keycloak authentication with RBAC and pagination
Major changes: - Add Keycloak integration via token introspection endpoint - Implement RBAC system with roles: admin, user, guest - Add role-based permissions for post operations - Add pagination support (default limit: 10) to list endpoints - Add published_only filter with admin-only override for unpublished posts Security improvements: - Remove hardcoded default secrets (SECRET_KEY, KEYCLOAK_CLIENT_SECRET) - Update .env.example with proper security placeholders - Add comprehensive RBAC unit tests Infrastructure: - Add httpx dependency for HTTP client - Add KeycloakAuthClient with token caching (TTL: 60s) - Add role-based dependencies (RequireAdmin, RequireUser, etc.) - Update DI container with Keycloak provider Endpoints updated: - GET /posts: filter by published status (admin can see all) - Add pagination params (limit, offset) to list endpoints - Enforce RBAC on post operations Tests: - Add 16 auth infrastructure tests - Add 13 RBAC role tests - Update existing tests for new required settings Breaking changes: - SECRET_KEY and KEYCLOAK_CLIENT_SECRET now required (no defaults)
This commit is contained in:
33
.env.example
Normal file
33
.env.example
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Environment mode: dev or prod
|
||||||
|
ENVIRONMENT=dev
|
||||||
|
|
||||||
|
# App settings
|
||||||
|
APP_NAME=Blog API
|
||||||
|
APP_DEBUG=false
|
||||||
|
APP_HOST=0.0.0.0
|
||||||
|
APP_PORT=8000
|
||||||
|
|
||||||
|
# Database settings
|
||||||
|
# For dev (SQLite): DB_URL=sqlite+aiosqlite:///./blog.db
|
||||||
|
# For prod (PostgreSQL): DB_URL=postgresql+asyncpg://user:pass@host:port/db
|
||||||
|
# Or use individual DB_* vars for prod (see below)
|
||||||
|
DB_URL=
|
||||||
|
DB_ECHO=false
|
||||||
|
|
||||||
|
# PostgreSQL-specific settings (used in prod when DB_URL is not set)
|
||||||
|
DB_HOST=localhost
|
||||||
|
DB_PORT=5432
|
||||||
|
DB_USER=postgres
|
||||||
|
DB_PASSWORD=postgres
|
||||||
|
DB_NAME=blog
|
||||||
|
|
||||||
|
# Security settings (REQUIRED)
|
||||||
|
SECURITY_SECRET_KEY=your-secret-key-here-change-in-production
|
||||||
|
SECURITY_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
|
|
||||||
|
# Keycloak settings (REQUIRED for authentication)
|
||||||
|
KC_SERVER_URL=http://localhost:8080
|
||||||
|
KC_REALM=blog
|
||||||
|
KC_CLIENT_ID=blog-api
|
||||||
|
KC_CLIENT_SECRET=your-keycloak-client-secret-here
|
||||||
|
KC_TOKEN_CACHE_TTL=60
|
||||||
@@ -22,24 +22,45 @@ class ListPostsUseCase:
|
|||||||
posts = await self._post_repo.get_all()
|
posts = await self._post_repo.get_all()
|
||||||
return [self._map_to_dto(post) for post in posts]
|
return [self._map_to_dto(post) for post in posts]
|
||||||
|
|
||||||
async def published_posts(self) -> list[PostResponseDTO]:
|
async def published_posts(
|
||||||
|
self,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[PostResponseDTO]:
|
||||||
"""Get all published posts."""
|
"""Get all published posts."""
|
||||||
posts = await self._post_repo.get_published()
|
posts = await self._post_repo.get_published(limit=limit, offset=offset)
|
||||||
return [self._map_to_dto(post) for post in posts]
|
return [self._map_to_dto(post) for post in posts]
|
||||||
|
|
||||||
async def by_author(self, author_id: str) -> list[PostResponseDTO]:
|
async def by_author(
|
||||||
|
self,
|
||||||
|
author_id: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[PostResponseDTO]:
|
||||||
"""Get posts by author."""
|
"""Get posts by author."""
|
||||||
posts = await self._post_repo.get_by_author(author_id)
|
posts = await self._post_repo.get_by_author(
|
||||||
|
author_id, limit=limit, offset=offset
|
||||||
|
)
|
||||||
return [self._map_to_dto(post) for post in posts]
|
return [self._map_to_dto(post) for post in posts]
|
||||||
|
|
||||||
async def by_tag(self, tag: str) -> list[PostResponseDTO]:
|
async def by_tag(
|
||||||
|
self,
|
||||||
|
tag: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[PostResponseDTO]:
|
||||||
"""Get posts by tag."""
|
"""Get posts by tag."""
|
||||||
posts = await self._post_repo.get_by_tag(tag)
|
posts = await self._post_repo.get_by_tag(tag, limit=limit, offset=offset)
|
||||||
return [self._map_to_dto(post) for post in posts]
|
return [self._map_to_dto(post) for post in posts]
|
||||||
|
|
||||||
async def search(self, query: str) -> list[PostResponseDTO]:
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[PostResponseDTO]:
|
||||||
"""Search posts."""
|
"""Search posts."""
|
||||||
posts = await self._post_repo.search(query)
|
posts = await self._post_repo.search(query, limit=limit, offset=offset)
|
||||||
return [self._map_to_dto(post) for post in posts]
|
return [self._map_to_dto(post) for post in posts]
|
||||||
|
|
||||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||||
|
|||||||
@@ -15,17 +15,31 @@ class PostRepository(Repository[Post]):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_by_author(self, author_id: str) -> list[Post]:
|
async def get_by_author(
|
||||||
|
self,
|
||||||
|
author_id: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[Post]:
|
||||||
"""Get all posts by author."""
|
"""Get all posts by author."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_published(self) -> list[Post]:
|
async def get_published(
|
||||||
|
self,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[Post]:
|
||||||
"""Get all published posts."""
|
"""Get all published posts."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_by_tag(self, tag: str) -> list[Post]:
|
async def get_by_tag(
|
||||||
|
self,
|
||||||
|
tag: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[Post]:
|
||||||
"""Get posts by tag."""
|
"""Get posts by tag."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@@ -35,6 +49,11 @@ class PostRepository(Repository[Post]):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def search(self, query: str) -> list[Post]:
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[Post]:
|
||||||
"""Search posts by query string."""
|
"""Search posts by query string."""
|
||||||
...
|
...
|
||||||
|
|||||||
102
app/domain/roles.py
Normal file
102
app/domain/roles.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""Role-based access control definitions."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from app.domain.exceptions import ForbiddenException
|
||||||
|
|
||||||
|
|
||||||
|
class Role(str, Enum):
|
||||||
|
"""User roles in the system."""
|
||||||
|
|
||||||
|
ADMIN = "admin"
|
||||||
|
USER = "user"
|
||||||
|
GUEST = "guest"
|
||||||
|
|
||||||
|
|
||||||
|
class Permission:
|
||||||
|
"""Permission definitions."""
|
||||||
|
|
||||||
|
# Post permissions
|
||||||
|
POST_CREATE = "post:create"
|
||||||
|
POST_READ = "post:read"
|
||||||
|
POST_READ_UNPUBLISHED = "post:read_unpublished"
|
||||||
|
POST_UPDATE = "post:update"
|
||||||
|
POST_DELETE = "post:delete"
|
||||||
|
POST_PUBLISH = "post:publish"
|
||||||
|
|
||||||
|
|
||||||
|
# Role-based permission mapping
|
||||||
|
ROLE_PERMISSIONS: dict[Role, list[str]] = {
|
||||||
|
Role.ADMIN: [
|
||||||
|
Permission.POST_CREATE,
|
||||||
|
Permission.POST_READ,
|
||||||
|
Permission.POST_READ_UNPUBLISHED,
|
||||||
|
Permission.POST_UPDATE,
|
||||||
|
Permission.POST_DELETE,
|
||||||
|
Permission.POST_PUBLISH,
|
||||||
|
],
|
||||||
|
Role.USER: [
|
||||||
|
Permission.POST_CREATE,
|
||||||
|
Permission.POST_READ,
|
||||||
|
Permission.POST_UPDATE,
|
||||||
|
Permission.POST_DELETE,
|
||||||
|
Permission.POST_PUBLISH,
|
||||||
|
],
|
||||||
|
Role.GUEST: [
|
||||||
|
Permission.POST_READ,
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def has_permission(role: Role, permission: str) -> bool:
|
||||||
|
"""Check if role has specific permission."""
|
||||||
|
return permission in ROLE_PERMISSIONS.get(role, [])
|
||||||
|
|
||||||
|
|
||||||
|
def require_permission(
|
||||||
|
permission: str,
|
||||||
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||||
|
"""Decorator to require specific permission."""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
# Get token_info from kwargs
|
||||||
|
token_info = kwargs.get("token_info")
|
||||||
|
if not token_info:
|
||||||
|
raise ForbiddenException("Authentication required")
|
||||||
|
|
||||||
|
# Determine role from token or default to guest
|
||||||
|
roles = getattr(token_info, "roles", [])
|
||||||
|
if Role.ADMIN.value in roles:
|
||||||
|
role = Role.ADMIN
|
||||||
|
elif Role.USER.value in roles:
|
||||||
|
role = Role.USER
|
||||||
|
else:
|
||||||
|
role = Role.GUEST
|
||||||
|
|
||||||
|
if not has_permission(role, permission):
|
||||||
|
raise ForbiddenException(
|
||||||
|
f"Permission '{permission}' required for role '{role.value}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def get_effective_role(roles: list[str]) -> Role:
|
||||||
|
"""Determine effective role from list of roles.
|
||||||
|
|
||||||
|
Priority: admin > user > guest
|
||||||
|
"""
|
||||||
|
if Role.ADMIN.value in roles:
|
||||||
|
return Role.ADMIN
|
||||||
|
elif Role.USER.value in roles:
|
||||||
|
return Role.USER
|
||||||
|
else:
|
||||||
|
return Role.GUEST
|
||||||
6
app/infrastructure/auth/__init__.py
Normal file
6
app/infrastructure/auth/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Authentication infrastructure package."""
|
||||||
|
|
||||||
|
from app.infrastructure.auth.client import KeycloakAuthClient
|
||||||
|
from app.infrastructure.auth.models import KeycloakUser, TokenInfo
|
||||||
|
|
||||||
|
__all__ = ["KeycloakAuthClient", "KeycloakUser", "TokenInfo"]
|
||||||
127
app/infrastructure/auth/client.py
Normal file
127
app/infrastructure/auth/client.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Keycloak authentication client."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.infrastructure.auth.models import KeycloakUser, TokenInfo
|
||||||
|
from app.infrastructure.config.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
|
class KeycloakAuthClient:
|
||||||
|
"""Client for Keycloak authentication operations."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings) -> None:
|
||||||
|
"""Initialize Keycloak client with settings."""
|
||||||
|
self._settings = settings
|
||||||
|
self._base_url = f"{settings.kc.server_url}/realms/{settings.kc.realm}"
|
||||||
|
self._client_id = settings.kc.client_id
|
||||||
|
self._client_secret = settings.kc.client_secret
|
||||||
|
self._cache: dict[str, tuple[TokenInfo, float]] = {}
|
||||||
|
self._cache_ttl = settings.kc.token_cache_ttl
|
||||||
|
|
||||||
|
def _get_introspection_url(self) -> str:
|
||||||
|
"""Get token introspection endpoint URL."""
|
||||||
|
return f"{self._base_url}/protocol/openid-connect/token/introspection"
|
||||||
|
|
||||||
|
def _get_userinfo_url(self) -> str:
|
||||||
|
"""Get userinfo endpoint URL."""
|
||||||
|
return f"{self._base_url}/protocol/openid-connect/userinfo"
|
||||||
|
|
||||||
|
def _get_cached_token(self, token: str) -> TokenInfo | None:
|
||||||
|
"""Get cached token info if valid."""
|
||||||
|
if token not in self._cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
token_info, cached_at = self._cache[token]
|
||||||
|
if time.time() - cached_at > self._cache_ttl:
|
||||||
|
del self._cache[token]
|
||||||
|
return None
|
||||||
|
|
||||||
|
return token_info
|
||||||
|
|
||||||
|
def _cache_token(self, token: str, token_info: TokenInfo) -> None:
|
||||||
|
"""Cache token info."""
|
||||||
|
self._cache[token] = (token_info, time.time())
|
||||||
|
# Simple cleanup of old entries
|
||||||
|
current_time = time.time()
|
||||||
|
expired_keys = [
|
||||||
|
k for k, (_, t) in self._cache.items() if current_time - t > self._cache_ttl
|
||||||
|
]
|
||||||
|
for k in expired_keys:
|
||||||
|
del self._cache[k]
|
||||||
|
|
||||||
|
async def introspect_token(self, token: str) -> TokenInfo:
|
||||||
|
"""Introspect access token using Keycloak."""
|
||||||
|
# Check cache first
|
||||||
|
cached = self._get_cached_token(token)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Prepare introspection request
|
||||||
|
data = {
|
||||||
|
"token": token,
|
||||||
|
"client_id": self._client_id,
|
||||||
|
"client_secret": self._client_secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self._get_introspection_url(),
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
return TokenInfo(active=False, raw_claims={"error": str(e)})
|
||||||
|
|
||||||
|
if not result.get("active", False):
|
||||||
|
return TokenInfo(active=False, raw_claims=result)
|
||||||
|
|
||||||
|
# Extract roles from realm_access or resource_access
|
||||||
|
roles: list[str] = []
|
||||||
|
realm_access = result.get("realm_access", {})
|
||||||
|
if isinstance(realm_access, dict):
|
||||||
|
roles.extend(realm_access.get("roles", []))
|
||||||
|
|
||||||
|
token_info = TokenInfo(
|
||||||
|
active=True,
|
||||||
|
user_id=result.get("sub", ""),
|
||||||
|
username=result.get("preferred_username", ""),
|
||||||
|
email=result.get("email", ""),
|
||||||
|
roles=roles,
|
||||||
|
raw_claims=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache valid token
|
||||||
|
self._cache_token(token, token_info)
|
||||||
|
|
||||||
|
return token_info
|
||||||
|
|
||||||
|
async def get_userinfo(self, token: str) -> KeycloakUser | None:
|
||||||
|
"""Get user information from Keycloak using access token."""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
self._get_userinfo_url(),
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
except httpx.HTTPError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return KeycloakUser(
|
||||||
|
id=data.get("sub", ""),
|
||||||
|
username=data.get("preferred_username", ""),
|
||||||
|
email=data.get("email", ""),
|
||||||
|
first_name=data.get("given_name", ""),
|
||||||
|
last_name=data.get("family_name", ""),
|
||||||
|
roles=data.get("realm_access", {}).get("roles", [])
|
||||||
|
if isinstance(data.get("realm_access"), dict)
|
||||||
|
else [],
|
||||||
|
)
|
||||||
34
app/infrastructure/auth/models.py
Normal file
34
app/infrastructure/auth/models.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Keycloak authentication models."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TokenInfo:
|
||||||
|
"""Information about validated token from Keycloak."""
|
||||||
|
|
||||||
|
active: bool
|
||||||
|
user_id: str = ""
|
||||||
|
username: str = ""
|
||||||
|
email: str = ""
|
||||||
|
roles: list[str] = field(default_factory=list)
|
||||||
|
raw_claims: dict[str, Any] = field(default_factory=dict, repr=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
"""Check if token is valid and active."""
|
||||||
|
return self.active and bool(self.user_id)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class KeycloakUser:
|
||||||
|
"""User information from Keycloak."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
username: str
|
||||||
|
email: str
|
||||||
|
first_name: str = ""
|
||||||
|
last_name: str = ""
|
||||||
|
roles: list[str] = field(default_factory=list)
|
||||||
|
is_active: bool = True
|
||||||
@@ -1,5 +1,21 @@
|
|||||||
"""Infrastructure configuration."""
|
"""Infrastructure configuration."""
|
||||||
|
|
||||||
from app.infrastructure.config.settings import Settings, settings
|
from app.infrastructure.config.settings import (
|
||||||
|
AppConfig,
|
||||||
|
DBConfig,
|
||||||
|
Environment,
|
||||||
|
KCConfig,
|
||||||
|
SecurityConfig,
|
||||||
|
Settings,
|
||||||
|
settings,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["Settings", "settings"]
|
__all__ = [
|
||||||
|
"AppConfig",
|
||||||
|
"DBConfig",
|
||||||
|
"KCConfig",
|
||||||
|
"SecurityConfig",
|
||||||
|
"Environment",
|
||||||
|
"Settings",
|
||||||
|
"settings",
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,31 +1,173 @@
|
|||||||
"""Application settings."""
|
"""Application settings with composition pattern."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
from pydantic import Field, PostgresDsn, field_validator
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Environment(str, Enum):
|
||||||
"""Application configuration settings."""
|
"""Application environment modes."""
|
||||||
|
|
||||||
# App settings
|
DEV = "dev"
|
||||||
app_name: str = "Blog API"
|
PROD = "prod"
|
||||||
|
|
||||||
|
|
||||||
|
class AppConfig(BaseSettings):
|
||||||
|
"""Application configuration."""
|
||||||
|
|
||||||
|
name: str = "Blog API"
|
||||||
debug: bool = False
|
debug: bool = False
|
||||||
host: str = "0.0.0.0"
|
host: str = "0.0.0.0"
|
||||||
port: int = 8000
|
port: int = 8000
|
||||||
|
|
||||||
# Database settings
|
|
||||||
database_url: str = "sqlite:///./blog.db"
|
|
||||||
database_echo: bool = False
|
|
||||||
|
|
||||||
# Security settings
|
|
||||||
secret_key: str = "your-secret-key-change-in-production"
|
|
||||||
access_token_expire_minutes: int = 30
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env",
|
env_prefix="APP_",
|
||||||
env_file_encoding="utf-8",
|
env_file_encoding="utf-8",
|
||||||
case_sensitive=False,
|
case_sensitive=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DBConfig(BaseSettings):
|
||||||
|
"""Database configuration."""
|
||||||
|
|
||||||
|
# For dev: sqlite+aiosqlite:///./blog.db
|
||||||
|
# For prod: postgresql+asyncpg://user:pass@host:port/db
|
||||||
|
url: str | None = None
|
||||||
|
echo: bool = False
|
||||||
|
|
||||||
|
# PostgreSQL-specific settings (used in prod)
|
||||||
|
host: str = "localhost"
|
||||||
|
port: int = 5432
|
||||||
|
user: str = "postgres"
|
||||||
|
password: str = "postgres"
|
||||||
|
name: str = "blog"
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_prefix="DB_",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
case_sensitive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("url")
|
||||||
|
@classmethod
|
||||||
|
def validate_url(cls, v: str | None) -> str | None:
|
||||||
|
"""Validate database URL if provided."""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
if not any(v.startswith(prefix) for prefix in ("sqlite+", "postgresql+")):
|
||||||
|
raise ValueError("Database URL must start with 'sqlite+' or 'postgresql+'")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class KCConfig(BaseSettings):
|
||||||
|
"""Keycloak configuration."""
|
||||||
|
|
||||||
|
server_url: str = "http://localhost:8080"
|
||||||
|
realm: str = "blog"
|
||||||
|
client_id: str = "blog-api"
|
||||||
|
client_secret: str = Field(
|
||||||
|
default="",
|
||||||
|
description="Keycloak client secret - must be set via env in production",
|
||||||
|
)
|
||||||
|
token_cache_ttl: int = 60 # seconds
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_prefix="KC_",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
case_sensitive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""Check if Keycloak is properly configured."""
|
||||||
|
return bool(self.client_secret)
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityConfig(BaseSettings):
|
||||||
|
"""Security configuration."""
|
||||||
|
|
||||||
|
secret_key: str = Field(
|
||||||
|
default="", description="Secret key for JWT - must be set via env in production"
|
||||||
|
)
|
||||||
|
access_token_expire_minutes: int = 30
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_prefix="SECURITY_",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
case_sensitive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_configured(self) -> bool:
|
||||||
|
"""Check if security is properly configured."""
|
||||||
|
return bool(self.secret_key)
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Application configuration settings with composition."""
|
||||||
|
|
||||||
|
# Environment mode
|
||||||
|
environment: Environment = Environment.DEV
|
||||||
|
|
||||||
|
# Sub-configurations
|
||||||
|
app: AppConfig = Field(default_factory=AppConfig)
|
||||||
|
db: DBConfig = Field(default_factory=DBConfig)
|
||||||
|
kc: KCConfig = Field(default_factory=KCConfig)
|
||||||
|
security: SecurityConfig = Field(default_factory=SecurityConfig)
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
case_sensitive=False,
|
||||||
|
env_nested_delimiter="__",
|
||||||
|
)
|
||||||
|
|
||||||
|
def model_post_init(self, __context: object) -> None:
|
||||||
|
"""Validate settings after initialization."""
|
||||||
|
if self.is_prod:
|
||||||
|
if not self.security.is_configured:
|
||||||
|
raise ValueError("SECURITY_SECRET_KEY must be set in production mode")
|
||||||
|
if not self.kc.is_configured:
|
||||||
|
raise ValueError("KC_CLIENT_SECRET must be set in production mode")
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def database_url(self) -> str:
|
||||||
|
"""Get database URL based on environment.
|
||||||
|
|
||||||
|
- In dev: uses SQLite if no URL provided
|
||||||
|
- In prod: uses PostgreSQL if no URL provided
|
||||||
|
"""
|
||||||
|
if self.db.url:
|
||||||
|
return self.db.url
|
||||||
|
|
||||||
|
if self.environment == Environment.PROD:
|
||||||
|
# Build PostgreSQL URL from components
|
||||||
|
return str(
|
||||||
|
PostgresDsn.build(
|
||||||
|
scheme="postgresql+asyncpg",
|
||||||
|
username=self.db.user,
|
||||||
|
password=self.db.password,
|
||||||
|
host=self.db.host,
|
||||||
|
port=self.db.port,
|
||||||
|
path=self.db.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default dev SQLite URL
|
||||||
|
return "sqlite+aiosqlite:///./blog.db"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_dev(self) -> bool:
|
||||||
|
"""Check if running in development mode."""
|
||||||
|
return self.environment == Environment.DEV
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_prod(self) -> bool:
|
||||||
|
"""Check if running in production mode."""
|
||||||
|
return self.environment == Environment.PROD
|
||||||
|
|
||||||
|
|
||||||
# Global settings instance
|
# Global settings instance
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def _get_database_url() -> str:
|
|||||||
# Create async engine
|
# Create async engine
|
||||||
engine: AsyncEngine = create_async_engine(
|
engine: AsyncEngine = create_async_engine(
|
||||||
_get_database_url(),
|
_get_database_url(),
|
||||||
echo=settings.database_echo,
|
echo=settings.db.echo,
|
||||||
future=True,
|
future=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ from app.application import (
|
|||||||
)
|
)
|
||||||
from app.application.interfaces import TransactionManager
|
from app.application.interfaces import TransactionManager
|
||||||
from app.domain.repositories import PostRepository
|
from app.domain.repositories import PostRepository
|
||||||
|
from app.infrastructure.auth import KeycloakAuthClient
|
||||||
|
from app.infrastructure.config.settings import settings
|
||||||
from app.infrastructure.database.connection import AsyncSessionLocal, engine
|
from app.infrastructure.database.connection import AsyncSessionLocal, engine
|
||||||
from app.infrastructure.repositories.post import SQLAlchemyPostRepository
|
from app.infrastructure.repositories.post import SQLAlchemyPostRepository
|
||||||
|
|
||||||
@@ -131,3 +133,12 @@ class UseCaseProvider(Provider):
|
|||||||
post_repo=post_repo,
|
post_repo=post_repo,
|
||||||
tx_manager=tx_manager,
|
tx_manager=tx_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class KeycloakProvider(Provider):
|
||||||
|
"""Provider for Keycloak authentication client."""
|
||||||
|
|
||||||
|
@provide(scope=Scope.APP)
|
||||||
|
def get_keycloak_client(self) -> KeycloakAuthClient:
|
||||||
|
"""Provide KeycloakAuthClient singleton."""
|
||||||
|
return KeycloakAuthClient(settings)
|
||||||
|
|||||||
@@ -105,27 +105,50 @@ class SQLAlchemyPostRepository(PostRepository):
|
|||||||
orm = result.scalar_one_or_none()
|
orm = result.scalar_one_or_none()
|
||||||
return self._to_domain(orm) if orm else None
|
return self._to_domain(orm) if orm else None
|
||||||
|
|
||||||
async def get_by_author(self, author_id: str) -> list[Post]:
|
async def get_by_author(
|
||||||
|
self,
|
||||||
|
author_id: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[Post]:
|
||||||
"""Get posts by author."""
|
"""Get posts by author."""
|
||||||
result = await self._session.execute(
|
query = select(PostORM).where(PostORM.author_id == author_id)
|
||||||
select(PostORM).where(PostORM.author_id == author_id)
|
if limit is not None:
|
||||||
)
|
query = query.limit(limit)
|
||||||
|
if offset is not None:
|
||||||
|
query = query.offset(offset)
|
||||||
|
result = await self._session.execute(query)
|
||||||
orms = result.scalars().all()
|
orms = result.scalars().all()
|
||||||
return [self._to_domain(orm) for orm in orms]
|
return [self._to_domain(orm) for orm in orms]
|
||||||
|
|
||||||
async def get_published(self) -> list[Post]:
|
async def get_published(
|
||||||
|
self,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[Post]:
|
||||||
"""Get published posts."""
|
"""Get published posts."""
|
||||||
result = await self._session.execute(
|
query = select(PostORM).where(PostORM.published.is_(True))
|
||||||
select(PostORM).where(PostORM.published.is_(True))
|
if limit is not None:
|
||||||
)
|
query = query.limit(limit)
|
||||||
|
if offset is not None:
|
||||||
|
query = query.offset(offset)
|
||||||
|
result = await self._session.execute(query)
|
||||||
orms = result.scalars().all()
|
orms = result.scalars().all()
|
||||||
return [self._to_domain(orm) for orm in orms]
|
return [self._to_domain(orm) for orm in orms]
|
||||||
|
|
||||||
async def get_by_tag(self, tag: str) -> list[Post]:
|
async def get_by_tag(
|
||||||
|
self,
|
||||||
|
tag: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[Post]:
|
||||||
"""Get posts by tag."""
|
"""Get posts by tag."""
|
||||||
result = await self._session.execute(
|
query = select(PostORM).where(PostORM.tags.contains([tag]))
|
||||||
select(PostORM).where(PostORM.tags.contains([tag]))
|
if limit is not None:
|
||||||
)
|
query = query.limit(limit)
|
||||||
|
if offset is not None:
|
||||||
|
query = query.offset(offset)
|
||||||
|
result = await self._session.execute(query)
|
||||||
orms = result.scalars().all()
|
orms = result.scalars().all()
|
||||||
return [self._to_domain(orm) for orm in orms]
|
return [self._to_domain(orm) for orm in orms]
|
||||||
|
|
||||||
@@ -136,16 +159,24 @@ class SQLAlchemyPostRepository(PostRepository):
|
|||||||
)
|
)
|
||||||
return result.scalar_one_or_none() is not None
|
return result.scalar_one_or_none() is not None
|
||||||
|
|
||||||
async def search(self, query: str) -> list[Post]:
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
offset: int | None = None,
|
||||||
|
) -> list[Post]:
|
||||||
"""Search posts."""
|
"""Search posts."""
|
||||||
search_pattern = f"%{query}%"
|
search_pattern = f"%{query}%"
|
||||||
result = await self._session.execute(
|
stmt = select(PostORM).where(
|
||||||
select(PostORM).where(
|
|
||||||
or_(
|
or_(
|
||||||
PostORM.title.ilike(search_pattern),
|
PostORM.title.ilike(search_pattern),
|
||||||
PostORM.content.ilike(search_pattern),
|
PostORM.content.ilike(search_pattern),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
if limit is not None:
|
||||||
|
stmt = stmt.limit(limit)
|
||||||
|
if offset is not None:
|
||||||
|
stmt = stmt.offset(offset)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
orms = result.scalars().all()
|
orms = result.scalars().all()
|
||||||
return [self._to_domain(orm) for orm in orms]
|
return [self._to_domain(orm) for orm in orms]
|
||||||
|
|||||||
20
app/main.py
20
app/main.py
@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from app.infrastructure import close_db, init_db, register_exception_handlers, settings
|
from app.infrastructure import close_db, init_db, register_exception_handlers, settings
|
||||||
from app.infrastructure.di.providers import (
|
from app.infrastructure.di.providers import (
|
||||||
DatabaseProvider,
|
DatabaseProvider,
|
||||||
|
KeycloakProvider,
|
||||||
RepositoryProvider,
|
RepositoryProvider,
|
||||||
TransactionManagerProvider,
|
TransactionManagerProvider,
|
||||||
UseCaseProvider,
|
UseCaseProvider,
|
||||||
@@ -32,11 +33,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
def app_factory() -> FastAPI:
|
def app_factory() -> FastAPI:
|
||||||
"""Create and configure FastAPI application."""
|
"""Create and configure FastAPI application."""
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=settings.app_name,
|
title=settings.app.name,
|
||||||
debug=settings.debug,
|
debug=settings.app.debug,
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
docs_url="/docs" if settings.debug else None,
|
docs_url="/docs" if settings.is_dev else None,
|
||||||
redoc_url="/redoc" if settings.debug else None,
|
redoc_url="/redoc" if settings.is_dev else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup Dishka DI container
|
# Setup Dishka DI container
|
||||||
@@ -45,6 +46,7 @@ def app_factory() -> FastAPI:
|
|||||||
RepositoryProvider(),
|
RepositoryProvider(),
|
||||||
TransactionManagerProvider(),
|
TransactionManagerProvider(),
|
||||||
UseCaseProvider(),
|
UseCaseProvider(),
|
||||||
|
KeycloakProvider(),
|
||||||
)
|
)
|
||||||
setup_dishka(container, app)
|
setup_dishka(container, app)
|
||||||
|
|
||||||
@@ -66,7 +68,11 @@ def app_factory() -> FastAPI:
|
|||||||
# Health check endpoint
|
# Health check endpoint
|
||||||
@app.get("/health", tags=["health"])
|
@app.get("/health", tags=["health"])
|
||||||
async def health_check() -> dict[str, str]:
|
async def health_check() -> dict[str, str]:
|
||||||
return {"status": "ok", "app": settings.app_name}
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"app": settings.app.name,
|
||||||
|
"env": settings.environment.value,
|
||||||
|
}
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
@@ -76,8 +82,8 @@ def main() -> None:
|
|||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app_factory,
|
app_factory,
|
||||||
factory=True,
|
factory=True,
|
||||||
host=settings.host,
|
host=settings.app.host,
|
||||||
port=settings.port,
|
port=settings.app.port,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""API dependencies using Dishka."""
|
"""API dependencies using Dishka."""
|
||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from dishka.integrations.fastapi import FromDishka
|
from dishka.integrations.fastapi import FromDishka
|
||||||
from fastapi import Depends, Header
|
from fastapi import Depends, Request
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
|
||||||
from app.application import (
|
from app.application import (
|
||||||
CreatePostUseCase,
|
CreatePostUseCase,
|
||||||
@@ -13,6 +14,9 @@ from app.application import (
|
|||||||
PublishPostUseCase,
|
PublishPostUseCase,
|
||||||
UpdatePostUseCase,
|
UpdatePostUseCase,
|
||||||
)
|
)
|
||||||
|
from app.domain.exceptions import ForbiddenException, UnauthorizedException
|
||||||
|
from app.domain.roles import Role, get_effective_role
|
||||||
|
from app.infrastructure.auth import KeycloakAuthClient, TokenInfo
|
||||||
|
|
||||||
# Use case dependencies - injected via Dishka
|
# Use case dependencies - injected via Dishka
|
||||||
CreatePostDep = FromDishka[CreatePostUseCase]
|
CreatePostDep = FromDishka[CreatePostUseCase]
|
||||||
@@ -22,13 +26,106 @@ DeletePostDep = FromDishka[DeletePostUseCase]
|
|||||||
ListPostsDep = FromDishka[ListPostsUseCase]
|
ListPostsDep = FromDishka[ListPostsUseCase]
|
||||||
PublishPostDep = FromDishka[PublishPostUseCase]
|
PublishPostDep = FromDishka[PublishPostUseCase]
|
||||||
|
|
||||||
|
# Security scheme
|
||||||
|
security = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_keycloak_client(request: Request) -> KeycloakAuthClient:
|
||||||
|
"""Get Keycloak client from DI container via request state."""
|
||||||
|
client: KeycloakAuthClient = request.state.dishka_container.get(KeycloakAuthClient)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_token_info(
|
||||||
|
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
|
||||||
|
request: Request,
|
||||||
|
) -> TokenInfo:
|
||||||
|
"""Validate token and return token info from Keycloak."""
|
||||||
|
if not credentials:
|
||||||
|
raise UnauthorizedException("Authentication required")
|
||||||
|
|
||||||
|
keycloak_client = get_keycloak_client(request)
|
||||||
|
token = credentials.credentials
|
||||||
|
token_info = await keycloak_client.introspect_token(token)
|
||||||
|
|
||||||
|
if not token_info.is_valid:
|
||||||
|
raise UnauthorizedException("Invalid or expired token")
|
||||||
|
|
||||||
|
return token_info
|
||||||
|
|
||||||
|
|
||||||
# Mock current user dependency (replace with real auth)
|
|
||||||
async def get_current_user_id(
|
async def get_current_user_id(
|
||||||
x_user_id: Annotated[str | None, Header()] = "user-123",
|
token_info: Annotated[TokenInfo, Depends(get_current_token_info)],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get current user ID from header."""
|
"""Get current user ID from validated token."""
|
||||||
return x_user_id or "user-123"
|
return token_info.user_id
|
||||||
|
|
||||||
|
|
||||||
CurrentUserDep = Annotated[str, Depends(get_current_user_id)]
|
CurrentUserDep = Annotated[str, Depends(get_current_user_id)]
|
||||||
|
TokenInfoDep = Annotated[TokenInfo, Depends(get_current_token_info)]
|
||||||
|
|
||||||
|
|
||||||
|
# Optional auth - doesn't require authentication but provides user info if available
|
||||||
|
async def get_optional_token_info(
|
||||||
|
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
|
||||||
|
request: Request,
|
||||||
|
) -> TokenInfo | None:
|
||||||
|
"""Get token info if valid token provided, otherwise None (guest)."""
|
||||||
|
if not credentials:
|
||||||
|
return None
|
||||||
|
|
||||||
|
keycloak_client = get_keycloak_client(request)
|
||||||
|
token = credentials.credentials
|
||||||
|
token_info = await keycloak_client.introspect_token(token)
|
||||||
|
|
||||||
|
if token_info.is_valid:
|
||||||
|
return token_info
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
OptionalTokenInfoDep = Annotated[TokenInfo | None, Depends(get_optional_token_info)]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_optional_user_id(
|
||||||
|
token_info: OptionalTokenInfoDep,
|
||||||
|
) -> str | None:
|
||||||
|
"""Get current user ID if token is valid, otherwise None."""
|
||||||
|
if token_info:
|
||||||
|
return token_info.user_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
OptionalUserDep = Annotated[str | None, Depends(get_optional_user_id)]
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_role(token_info: OptionalTokenInfoDep) -> Role:
|
||||||
|
"""Get effective role from token info.
|
||||||
|
|
||||||
|
Returns GUEST if no valid token provided.
|
||||||
|
"""
|
||||||
|
if token_info and token_info.roles:
|
||||||
|
return get_effective_role(token_info.roles)
|
||||||
|
return Role.GUEST
|
||||||
|
|
||||||
|
|
||||||
|
CurrentRoleDep = Annotated[Role, Depends(get_current_role)]
|
||||||
|
|
||||||
|
|
||||||
|
def require_roles(allowed_roles: list[Role]) -> Any:
|
||||||
|
"""Create dependency that checks if user has one of the allowed roles."""
|
||||||
|
|
||||||
|
async def check_role(role: CurrentRoleDep) -> Role:
|
||||||
|
if role not in allowed_roles:
|
||||||
|
raise ForbiddenException(
|
||||||
|
f"Access denied. Required roles: {[r.value for r in allowed_roles]}"
|
||||||
|
)
|
||||||
|
return role
|
||||||
|
|
||||||
|
return Depends(check_role)
|
||||||
|
|
||||||
|
|
||||||
|
# Predefined role requirements
|
||||||
|
RequireAdmin = require_roles([Role.ADMIN])
|
||||||
|
RequireUser = require_roles([Role.USER, Role.ADMIN])
|
||||||
|
RequireAny = require_roles([Role.GUEST, Role.USER, Role.ADMIN])
|
||||||
|
|||||||
@@ -6,8 +6,11 @@ from dishka.integrations.fastapi import DishkaRoute
|
|||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status
|
||||||
|
|
||||||
from app.application.dtos import CreatePostDTO, UpdatePostDTO
|
from app.application.dtos import CreatePostDTO, UpdatePostDTO
|
||||||
|
from app.domain.exceptions import ForbiddenException
|
||||||
|
from app.domain.roles import Permission, has_permission
|
||||||
from app.presentation.api.deps import (
|
from app.presentation.api.deps import (
|
||||||
CreatePostDep,
|
CreatePostDep,
|
||||||
|
CurrentRoleDep,
|
||||||
CurrentUserDep,
|
CurrentUserDep,
|
||||||
DeletePostDep,
|
DeletePostDep,
|
||||||
GetPostDep,
|
GetPostDep,
|
||||||
@@ -50,11 +53,38 @@ async def create_post(
|
|||||||
@router.get(
|
@router.get(
|
||||||
"",
|
"",
|
||||||
response_model=PostListResponseSchema,
|
response_model=PostListResponseSchema,
|
||||||
summary="List all posts",
|
summary="List posts",
|
||||||
)
|
)
|
||||||
async def list_posts(use_case: ListPostsDep) -> PostListResponseSchema:
|
async def list_posts(
|
||||||
"""Get all blog posts."""
|
use_case: ListPostsDep,
|
||||||
|
role: CurrentRoleDep,
|
||||||
|
include_unpublished: bool = False,
|
||||||
|
limit: int = 10,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> PostListResponseSchema:
|
||||||
|
"""Get blog posts with optional filtering and pagination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_unpublished: If True, returns all posts including drafts.
|
||||||
|
Only admins can use this parameter.
|
||||||
|
limit: Maximum number of posts to return (default: 10, max: 100).
|
||||||
|
offset: Number of posts to skip (default: 0).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ForbiddenException: If non-admin tries to include unpublished posts.
|
||||||
|
"""
|
||||||
|
# Clamp limit to reasonable range
|
||||||
|
limit = max(1, min(limit, 100))
|
||||||
|
offset = max(0, offset)
|
||||||
|
|
||||||
|
# Check permissions for unpublished posts
|
||||||
|
if include_unpublished:
|
||||||
|
if not has_permission(role, Permission.POST_READ_UNPUBLISHED):
|
||||||
|
raise ForbiddenException("Only admins can view unpublished posts")
|
||||||
results = await use_case.all_posts()
|
results = await use_case.all_posts()
|
||||||
|
else:
|
||||||
|
results = await use_case.published_posts(limit=limit, offset=offset)
|
||||||
|
|
||||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||||
return PostListResponseSchema(items=items, total=len(items))
|
return PostListResponseSchema(items=items, total=len(items))
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,6 @@ version = "0.1.0"
|
|||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["hatchling"]
|
|
||||||
build-backend = "hatchling.build"
|
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
|
||||||
packages = ["app"]
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi>=0.136.0",
|
"fastapi>=0.136.0",
|
||||||
"pydantic>=2.13.2",
|
"pydantic>=2.13.2",
|
||||||
@@ -18,9 +11,18 @@ dependencies = [
|
|||||||
"uvicorn>=0.44.0",
|
"uvicorn>=0.44.0",
|
||||||
"sqlalchemy>=2.0.0",
|
"sqlalchemy>=2.0.0",
|
||||||
"aiosqlite>=0.21.0",
|
"aiosqlite>=0.21.0",
|
||||||
|
"asyncpg>=0.30.0",
|
||||||
"dishka>=1.5.0",
|
"dishka>=1.5.0",
|
||||||
|
"httpx>=0.28.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["app"]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
{include-group = "lints"},
|
{include-group = "lints"},
|
||||||
|
|||||||
@@ -1,16 +1,36 @@
|
|||||||
"""API test fixtures."""
|
"""API test fixtures."""
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.infrastructure.auth.models import TokenInfo
|
||||||
from app.main import app_factory
|
from app.main import app_factory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def client() -> AsyncGenerator[AsyncClient, None]:
|
def mock_keycloak_client() -> MagicMock:
|
||||||
|
"""Create mock Keycloak client for testing."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.introspect_token.return_value = TokenInfo(
|
||||||
|
active=True,
|
||||||
|
user_id="test-user-id",
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
roles=["user"],
|
||||||
|
)
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(mock_keycloak_client: MagicMock) -> AsyncGenerator[AsyncClient, None]:
|
||||||
"""Create async HTTP client for API testing."""
|
"""Create async HTTP client for API testing."""
|
||||||
|
with patch(
|
||||||
|
"app.presentation.api.deps.KeycloakAuthClient",
|
||||||
|
return_value=mock_keycloak_client,
|
||||||
|
):
|
||||||
app = app_factory()
|
app = app_factory()
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||||
@@ -20,4 +40,18 @@ async def client() -> AsyncGenerator[AsyncClient, None]:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def auth_headers() -> dict[str, str]:
|
def auth_headers() -> dict[str, str]:
|
||||||
"""Return mock authentication headers."""
|
"""Return mock authentication headers."""
|
||||||
return {"Authorization": "Bearer test_token", "X-User-Id": "user-123"}
|
return {"Authorization": "Bearer test_token"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def unauthorized_keycloak_client() -> MagicMock:
|
||||||
|
"""Create mock Keycloak client that returns invalid token."""
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.introspect_token.return_value = TokenInfo(
|
||||||
|
active=False,
|
||||||
|
user_id="",
|
||||||
|
username="",
|
||||||
|
email="",
|
||||||
|
roles=[],
|
||||||
|
)
|
||||||
|
return mock_client
|
||||||
|
|||||||
123
tests/unit/domain/test_roles.py
Normal file
123
tests/unit/domain/test_roles.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""Tests for role-based access control."""
|
||||||
|
|
||||||
|
from app.domain.roles import (
|
||||||
|
ROLE_PERMISSIONS,
|
||||||
|
Permission,
|
||||||
|
Role,
|
||||||
|
get_effective_role,
|
||||||
|
has_permission,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRole:
|
||||||
|
"""Test Role enum."""
|
||||||
|
|
||||||
|
def test_role_values(self) -> None:
|
||||||
|
"""Test role enum values."""
|
||||||
|
assert Role.ADMIN.value == "admin"
|
||||||
|
assert Role.USER.value == "user"
|
||||||
|
assert Role.GUEST.value == "guest"
|
||||||
|
|
||||||
|
def test_role_comparison(self) -> None:
|
||||||
|
"""Test role comparison."""
|
||||||
|
assert Role.ADMIN == Role.ADMIN
|
||||||
|
# USER and ADMIN are different enum values with different string values
|
||||||
|
assert Role.USER.value != Role.ADMIN.value # type: ignore[comparison-overlap]
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissions:
|
||||||
|
"""Test permission definitions."""
|
||||||
|
|
||||||
|
def test_permission_values(self) -> None:
|
||||||
|
"""Test permission constants."""
|
||||||
|
assert Permission.POST_CREATE == "post:create"
|
||||||
|
assert Permission.POST_READ == "post:read"
|
||||||
|
assert Permission.POST_READ_UNPUBLISHED == "post:read_unpublished"
|
||||||
|
assert Permission.POST_UPDATE == "post:update"
|
||||||
|
assert Permission.POST_DELETE == "post:delete"
|
||||||
|
assert Permission.POST_PUBLISH == "post:publish"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRolePermissions:
|
||||||
|
"""Test role-based permission mapping."""
|
||||||
|
|
||||||
|
def test_admin_has_all_permissions(self) -> None:
|
||||||
|
"""Test admin has all permissions."""
|
||||||
|
admin_perms = ROLE_PERMISSIONS[Role.ADMIN]
|
||||||
|
assert Permission.POST_CREATE in admin_perms
|
||||||
|
assert Permission.POST_READ in admin_perms
|
||||||
|
assert Permission.POST_READ_UNPUBLISHED in admin_perms
|
||||||
|
assert Permission.POST_UPDATE in admin_perms
|
||||||
|
assert Permission.POST_DELETE in admin_perms
|
||||||
|
assert Permission.POST_PUBLISH in admin_perms
|
||||||
|
|
||||||
|
def test_user_permissions(self) -> None:
|
||||||
|
"""Test user permissions."""
|
||||||
|
user_perms = ROLE_PERMISSIONS[Role.USER]
|
||||||
|
assert Permission.POST_CREATE in user_perms
|
||||||
|
assert Permission.POST_READ in user_perms
|
||||||
|
assert Permission.POST_UPDATE in user_perms
|
||||||
|
assert Permission.POST_DELETE in user_perms
|
||||||
|
assert Permission.POST_PUBLISH in user_perms
|
||||||
|
# User cannot read unpublished
|
||||||
|
assert Permission.POST_READ_UNPUBLISHED not in user_perms
|
||||||
|
|
||||||
|
def test_guest_permissions(self) -> None:
|
||||||
|
"""Test guest permissions."""
|
||||||
|
guest_perms = ROLE_PERMISSIONS[Role.GUEST]
|
||||||
|
assert Permission.POST_READ in guest_perms
|
||||||
|
# Guest has very limited permissions
|
||||||
|
assert Permission.POST_CREATE not in guest_perms
|
||||||
|
assert Permission.POST_UPDATE not in guest_perms
|
||||||
|
assert Permission.POST_DELETE not in guest_perms
|
||||||
|
assert Permission.POST_READ_UNPUBLISHED not in guest_perms
|
||||||
|
|
||||||
|
|
||||||
|
class TestHasPermission:
|
||||||
|
"""Test has_permission function."""
|
||||||
|
|
||||||
|
def test_admin_has_all_permissions_check(self) -> None:
|
||||||
|
"""Test admin permission checks."""
|
||||||
|
assert has_permission(Role.ADMIN, Permission.POST_CREATE) is True
|
||||||
|
assert has_permission(Role.ADMIN, Permission.POST_READ_UNPUBLISHED) is True
|
||||||
|
assert has_permission(Role.ADMIN, "unknown:permission") is False
|
||||||
|
|
||||||
|
def test_user_limited_permissions(self) -> None:
|
||||||
|
"""Test user limited permissions."""
|
||||||
|
assert has_permission(Role.USER, Permission.POST_CREATE) is True
|
||||||
|
assert has_permission(Role.USER, Permission.POST_READ_UNPUBLISHED) is False
|
||||||
|
assert has_permission(Role.USER, Permission.POST_READ) is True
|
||||||
|
|
||||||
|
def test_guest_read_only(self) -> None:
|
||||||
|
"""Test guest read-only access."""
|
||||||
|
assert has_permission(Role.GUEST, Permission.POST_READ) is True
|
||||||
|
assert has_permission(Role.GUEST, Permission.POST_CREATE) is False
|
||||||
|
assert has_permission(Role.GUEST, Permission.POST_UPDATE) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetEffectiveRole:
|
||||||
|
"""Test get_effective_role function."""
|
||||||
|
|
||||||
|
def test_admin_from_roles_list(self) -> None:
|
||||||
|
"""Test admin role detection."""
|
||||||
|
assert get_effective_role(["admin"]) == Role.ADMIN
|
||||||
|
assert get_effective_role(["user", "admin"]) == Role.ADMIN
|
||||||
|
assert get_effective_role(["admin", "user"]) == Role.ADMIN
|
||||||
|
|
||||||
|
def test_user_from_roles_list(self) -> None:
|
||||||
|
"""Test user role detection."""
|
||||||
|
assert get_effective_role(["user"]) == Role.USER
|
||||||
|
assert get_effective_role(["user", "moderator"]) == Role.USER
|
||||||
|
|
||||||
|
def test_guest_from_roles_list(self) -> None:
|
||||||
|
"""Test guest role detection."""
|
||||||
|
assert get_effective_role([]) == Role.GUEST
|
||||||
|
assert get_effective_role(["unknown"]) == Role.GUEST
|
||||||
|
assert get_effective_role(["guest"]) == Role.GUEST
|
||||||
|
|
||||||
|
def test_role_priority(self) -> None:
|
||||||
|
"""Test that admin > user > guest."""
|
||||||
|
# Admin takes precedence
|
||||||
|
assert get_effective_role(["user", "admin", "guest"]) == Role.ADMIN
|
||||||
|
# User takes precedence over guest
|
||||||
|
assert get_effective_role(["guest", "user"]) == Role.USER
|
||||||
318
tests/unit/infrastructure/test_auth.py
Normal file
318
tests/unit/infrastructure/test_auth.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
"""Tests for Keycloak authentication client."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.infrastructure.auth import KeycloakAuthClient, KeycloakUser, TokenInfo
|
||||||
|
from app.infrastructure.config.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenInfo:
|
||||||
|
"""Test TokenInfo dataclass."""
|
||||||
|
|
||||||
|
def test_token_info_valid(self) -> None:
|
||||||
|
"""Test valid token info."""
|
||||||
|
token_info = TokenInfo(
|
||||||
|
active=True,
|
||||||
|
user_id="user-123",
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
roles=["user"],
|
||||||
|
)
|
||||||
|
assert token_info.is_valid is True
|
||||||
|
assert token_info.user_id == "user-123"
|
||||||
|
assert token_info.username == "testuser"
|
||||||
|
assert token_info.email == "test@example.com"
|
||||||
|
assert token_info.roles == ["user"]
|
||||||
|
|
||||||
|
def test_token_info_invalid_not_active(self) -> None:
|
||||||
|
"""Test invalid token when not active."""
|
||||||
|
token_info = TokenInfo(
|
||||||
|
active=False,
|
||||||
|
user_id="user-123",
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
roles=["user"],
|
||||||
|
)
|
||||||
|
assert token_info.is_valid is False
|
||||||
|
|
||||||
|
def test_token_info_invalid_no_user_id(self) -> None:
|
||||||
|
"""Test invalid token when no user_id."""
|
||||||
|
token_info = TokenInfo(
|
||||||
|
active=True,
|
||||||
|
user_id="",
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
roles=["user"],
|
||||||
|
)
|
||||||
|
assert token_info.is_valid is False
|
||||||
|
|
||||||
|
def test_token_info_empty_roles(self) -> None:
|
||||||
|
"""Test token info with empty roles."""
|
||||||
|
token_info = TokenInfo(
|
||||||
|
active=True,
|
||||||
|
user_id="user-123",
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
roles=[],
|
||||||
|
)
|
||||||
|
assert token_info.is_valid is True
|
||||||
|
assert token_info.roles == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeycloakUser:
|
||||||
|
"""Test KeycloakUser dataclass."""
|
||||||
|
|
||||||
|
def test_keycloak_user_creation(self) -> None:
|
||||||
|
"""Test KeycloakUser creation."""
|
||||||
|
user = KeycloakUser(
|
||||||
|
id="user-123",
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
first_name="Test",
|
||||||
|
last_name="User",
|
||||||
|
roles=["user", "admin"],
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
assert user.id == "user-123"
|
||||||
|
assert user.username == "testuser"
|
||||||
|
assert user.email == "test@example.com"
|
||||||
|
assert user.first_name == "Test"
|
||||||
|
assert user.last_name == "User"
|
||||||
|
assert user.roles == ["user", "admin"]
|
||||||
|
assert user.is_active is True
|
||||||
|
|
||||||
|
def test_keycloak_user_defaults(self) -> None:
|
||||||
|
"""Test KeycloakUser with default values."""
|
||||||
|
user = KeycloakUser(
|
||||||
|
id="user-123",
|
||||||
|
username="testuser",
|
||||||
|
email="test@example.com",
|
||||||
|
)
|
||||||
|
assert user.first_name == ""
|
||||||
|
assert user.last_name == ""
|
||||||
|
assert user.roles == []
|
||||||
|
assert user.is_active is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeycloakAuthClient:
|
||||||
|
"""Test KeycloakAuthClient."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def settings(self) -> Settings:
|
||||||
|
"""Create test settings."""
|
||||||
|
from app.infrastructure.config import KCConfig, SecurityConfig
|
||||||
|
|
||||||
|
return Settings(
|
||||||
|
environment="dev",
|
||||||
|
kc=KCConfig(
|
||||||
|
server_url="http://localhost:8080",
|
||||||
|
realm="test-realm",
|
||||||
|
client_id="test-client",
|
||||||
|
client_secret="test-secret",
|
||||||
|
token_cache_ttl=60,
|
||||||
|
),
|
||||||
|
security=SecurityConfig(
|
||||||
|
secret_key="test-secret-key-for-jwt-tokens",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(self, settings: Settings) -> KeycloakAuthClient:
|
||||||
|
"""Create Keycloak client."""
|
||||||
|
return KeycloakAuthClient(settings)
|
||||||
|
|
||||||
|
def test_client_initialization(
|
||||||
|
self, client: KeycloakAuthClient, settings: Settings
|
||||||
|
) -> None:
|
||||||
|
"""Test client initialization."""
|
||||||
|
assert client._settings == settings
|
||||||
|
assert client._base_url == "http://localhost:8080/realms/test-realm"
|
||||||
|
assert client._client_id == "test-client"
|
||||||
|
assert client._client_secret == "test-secret"
|
||||||
|
assert client._cache_ttl == 60
|
||||||
|
|
||||||
|
def test_get_introspection_url(self, client: KeycloakAuthClient) -> None:
|
||||||
|
"""Test introspection URL generation."""
|
||||||
|
url = client._get_introspection_url()
|
||||||
|
assert (
|
||||||
|
url
|
||||||
|
== "http://localhost:8080/realms/test-realm/protocol/openid-connect/token/introspection"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_userinfo_url(self, client: KeycloakAuthClient) -> None:
|
||||||
|
"""Test userinfo URL generation."""
|
||||||
|
url = client._get_userinfo_url()
|
||||||
|
assert (
|
||||||
|
url
|
||||||
|
== "http://localhost:8080/realms/test-realm/protocol/openid-connect/userinfo"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_introspect_token_success(self, client: KeycloakAuthClient) -> None:
|
||||||
|
"""Test successful token introspection."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"active": True,
|
||||||
|
"sub": "user-123",
|
||||||
|
"preferred_username": "testuser",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"realm_access": {"roles": ["user", "admin"]},
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
mock_async_client = AsyncMock()
|
||||||
|
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||||
|
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||||
|
result = await client.introspect_token("test-token")
|
||||||
|
|
||||||
|
assert result.active is True
|
||||||
|
assert result.user_id == "user-123"
|
||||||
|
assert result.username == "testuser"
|
||||||
|
assert result.email == "test@example.com"
|
||||||
|
assert result.roles == ["user", "admin"]
|
||||||
|
assert result.is_valid is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_introspect_token_inactive(self, client: KeycloakAuthClient) -> None:
|
||||||
|
"""Test introspection with inactive token."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {"active": False}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
mock_async_client = AsyncMock()
|
||||||
|
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||||
|
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||||
|
result = await client.introspect_token("test-token")
|
||||||
|
|
||||||
|
assert result.active is False
|
||||||
|
assert result.is_valid is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_introspect_token_http_error(
|
||||||
|
self, client: KeycloakAuthClient
|
||||||
|
) -> None:
|
||||||
|
"""Test introspection with HTTP error."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
mock_async_client = AsyncMock()
|
||||||
|
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||||
|
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_async_client.post = AsyncMock(
|
||||||
|
side_effect=httpx.HTTPError("Connection error")
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||||
|
result = await client.introspect_token("test-token")
|
||||||
|
|
||||||
|
assert result.active is False
|
||||||
|
assert result.is_valid is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_introspect_token_uses_cache(
|
||||||
|
self, client: KeycloakAuthClient
|
||||||
|
) -> None:
|
||||||
|
"""Test that token introspection uses cache."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"active": True,
|
||||||
|
"sub": "user-123",
|
||||||
|
"preferred_username": "testuser",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"realm_access": {"roles": ["user"]},
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
mock_async_client = AsyncMock()
|
||||||
|
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||||
|
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||||
|
# First call
|
||||||
|
result1 = await client.introspect_token("test-token")
|
||||||
|
# Second call should use cache
|
||||||
|
result2 = await client.introspect_token("test-token")
|
||||||
|
|
||||||
|
# HTTP client should only be called once
|
||||||
|
assert mock_async_client.post.call_count == 1
|
||||||
|
assert result1.user_id == result2.user_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_userinfo_success(self, client: KeycloakAuthClient) -> None:
|
||||||
|
"""Test successful userinfo retrieval."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"sub": "user-123",
|
||||||
|
"preferred_username": "testuser",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"given_name": "Test",
|
||||||
|
"family_name": "User",
|
||||||
|
"realm_access": {"roles": ["user"]},
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
mock_async_client = AsyncMock()
|
||||||
|
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||||
|
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_async_client.get = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||||
|
result = await client.get_userinfo("test-token")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.id == "user-123"
|
||||||
|
assert result.username == "testuser"
|
||||||
|
assert result.email == "test@example.com"
|
||||||
|
assert result.first_name == "Test"
|
||||||
|
assert result.last_name == "User"
|
||||||
|
assert result.roles == ["user"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_userinfo_error(self, client: KeycloakAuthClient) -> None:
|
||||||
|
"""Test userinfo retrieval with error."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
mock_async_client = AsyncMock()
|
||||||
|
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||||
|
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_async_client.get = AsyncMock(
|
||||||
|
side_effect=httpx.HTTPError("Connection error")
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||||
|
result = await client.get_userinfo("test-token")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_introspect_token_no_realm_roles(
|
||||||
|
self, client: KeycloakAuthClient
|
||||||
|
) -> None:
|
||||||
|
"""Test introspection without realm_access roles."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"active": True,
|
||||||
|
"sub": "user-123",
|
||||||
|
"preferred_username": "testuser",
|
||||||
|
"email": "test@example.com",
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock()
|
||||||
|
|
||||||
|
mock_async_client = AsyncMock()
|
||||||
|
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||||
|
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||||
|
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||||
|
result = await client.introspect_token("test-token")
|
||||||
|
|
||||||
|
assert result.active is True
|
||||||
|
assert result.roles == []
|
||||||
@@ -1,37 +1,247 @@
|
|||||||
"""Tests for infrastructure config."""
|
"""Tests for infrastructure config."""
|
||||||
|
|
||||||
from app.infrastructure.config import Settings
|
import pytest
|
||||||
|
|
||||||
|
from app.infrastructure.config import (
|
||||||
|
AppConfig,
|
||||||
|
DBConfig,
|
||||||
|
Environment,
|
||||||
|
KCConfig,
|
||||||
|
SecurityConfig,
|
||||||
|
Settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSettings:
|
class TestSettings:
|
||||||
|
"""Test Settings with composition pattern."""
|
||||||
|
|
||||||
def test_default_values(self) -> None:
|
def test_default_values(self) -> None:
|
||||||
"""Test default settings values by creating settings without env file."""
|
"""Test default settings values by creating settings without env file."""
|
||||||
# Create settings with no env file to test defaults
|
# Create settings with required secrets and no env file
|
||||||
s = Settings(_env_file=None)
|
s = Settings(
|
||||||
assert s.app_name == "Blog API"
|
_env_file=None,
|
||||||
assert s.debug is False
|
security=SecurityConfig(secret_key="test-secret-key"),
|
||||||
assert s.host == "0.0.0.0"
|
kc=KCConfig(client_secret="test-client-secret"),
|
||||||
assert s.port == 8000
|
)
|
||||||
assert s.database_url == "sqlite:///./blog.db"
|
assert s.app.name == "Blog API"
|
||||||
assert s.database_echo is False
|
assert s.app.debug is False
|
||||||
|
assert s.app.host == "0.0.0.0"
|
||||||
|
assert s.app.port == 8000
|
||||||
|
assert s.database_url == "sqlite+aiosqlite:///./blog.db"
|
||||||
|
assert s.db.echo is False
|
||||||
|
assert s.security.secret_key == "test-secret-key"
|
||||||
|
assert s.kc.client_secret == "test-client-secret"
|
||||||
|
assert s.environment == Environment.DEV
|
||||||
|
|
||||||
def test_custom_values(self) -> None:
|
def test_custom_values(self) -> None:
|
||||||
"""Test custom settings values."""
|
"""Test custom settings values."""
|
||||||
s = Settings(
|
s = Settings(
|
||||||
app_name="Test API",
|
_env_file=None,
|
||||||
|
environment=Environment.PROD,
|
||||||
|
app=AppConfig(
|
||||||
|
name="Test API",
|
||||||
debug=True,
|
debug=True,
|
||||||
host="localhost",
|
host="localhost",
|
||||||
port=9000,
|
port=9000,
|
||||||
database_url="postgresql://test",
|
),
|
||||||
secret_key="test-secret",
|
db=DBConfig(url="postgresql+asyncpg://user:pass@host/db"),
|
||||||
|
security=SecurityConfig(secret_key="test-secret"),
|
||||||
|
kc=KCConfig(client_secret="test-client-secret"),
|
||||||
)
|
)
|
||||||
assert s.app_name == "Test API"
|
assert s.app.name == "Test API"
|
||||||
assert s.debug is True
|
assert s.app.debug is True
|
||||||
assert s.host == "localhost"
|
assert s.app.host == "localhost"
|
||||||
assert s.port == 9000
|
assert s.app.port == 9000
|
||||||
assert s.database_url == "postgresql://test"
|
assert s.database_url == "postgresql+asyncpg://user:pass@host/db"
|
||||||
assert s.secret_key == "test-secret"
|
assert s.security.secret_key == "test-secret"
|
||||||
|
assert s.kc.client_secret == "test-client-secret"
|
||||||
|
assert s.environment == Environment.PROD
|
||||||
|
|
||||||
def test_model_config(self) -> None:
|
def test_model_config(self) -> None:
|
||||||
"""Test settings model config."""
|
"""Test settings model config."""
|
||||||
assert "env_file" in Settings.model_config
|
assert "env_file" in Settings.model_config
|
||||||
|
|
||||||
|
def test_is_dev_property(self) -> None:
|
||||||
|
"""Test is_dev property."""
|
||||||
|
s = Settings(
|
||||||
|
_env_file=None,
|
||||||
|
environment=Environment.DEV,
|
||||||
|
security=SecurityConfig(secret_key="test"),
|
||||||
|
kc=KCConfig(client_secret="test"),
|
||||||
|
)
|
||||||
|
assert s.is_dev is True
|
||||||
|
assert s.is_prod is False
|
||||||
|
|
||||||
|
def test_is_prod_property(self) -> None:
|
||||||
|
"""Test is_prod property."""
|
||||||
|
s = Settings(
|
||||||
|
_env_file=None,
|
||||||
|
environment=Environment.PROD,
|
||||||
|
security=SecurityConfig(secret_key="test"),
|
||||||
|
kc=KCConfig(client_secret="test"),
|
||||||
|
)
|
||||||
|
assert s.is_prod is True
|
||||||
|
assert s.is_dev is False
|
||||||
|
|
||||||
|
def test_prod_requires_security_secret(self) -> None:
|
||||||
|
"""Test that prod mode requires security secret_key."""
|
||||||
|
with pytest.raises(ValueError, match="SECURITY_SECRET_KEY"):
|
||||||
|
Settings(
|
||||||
|
_env_file=None,
|
||||||
|
environment=Environment.PROD,
|
||||||
|
security=SecurityConfig(secret_key=""),
|
||||||
|
kc=KCConfig(client_secret="test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prod_requires_kc_secret(self) -> None:
|
||||||
|
"""Test that prod mode requires KC client_secret."""
|
||||||
|
with pytest.raises(ValueError, match="KC_CLIENT_SECRET"):
|
||||||
|
Settings(
|
||||||
|
_env_file=None,
|
||||||
|
environment=Environment.PROD,
|
||||||
|
security=SecurityConfig(secret_key="test"),
|
||||||
|
kc=KCConfig(client_secret=""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_database_url_dev_default(self) -> None:
|
||||||
|
"""Test default database URL in dev mode."""
|
||||||
|
s = Settings(
|
||||||
|
_env_file=None,
|
||||||
|
environment=Environment.DEV,
|
||||||
|
security=SecurityConfig(secret_key="test"),
|
||||||
|
kc=KCConfig(client_secret="test"),
|
||||||
|
)
|
||||||
|
assert s.database_url == "sqlite+aiosqlite:///./blog.db"
|
||||||
|
|
||||||
|
def test_database_url_prod_builds_postgres(self) -> None:
|
||||||
|
"""Test that database URL builds from components in prod."""
|
||||||
|
s = Settings(
|
||||||
|
_env_file=None,
|
||||||
|
environment=Environment.PROD,
|
||||||
|
db=DBConfig(
|
||||||
|
url=None, # Force building from components
|
||||||
|
host="db.example.com",
|
||||||
|
port=5433,
|
||||||
|
user="admin",
|
||||||
|
password="secret",
|
||||||
|
name="mydb",
|
||||||
|
),
|
||||||
|
security=SecurityConfig(secret_key="test"),
|
||||||
|
kc=KCConfig(client_secret="test"),
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
s.database_url
|
||||||
|
== "postgresql+asyncpg://admin:secret@db.example.com:5433/mydb"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_database_url_override(self) -> None:
|
||||||
|
"""Test that explicit database URL overrides auto-building."""
|
||||||
|
s = Settings(
|
||||||
|
_env_file=None,
|
||||||
|
environment=Environment.PROD,
|
||||||
|
db=DBConfig(
|
||||||
|
url="postgresql+asyncpg://custom/url",
|
||||||
|
host="ignored",
|
||||||
|
user="ignored",
|
||||||
|
),
|
||||||
|
security=SecurityConfig(secret_key="test"),
|
||||||
|
kc=KCConfig(client_secret="test"),
|
||||||
|
)
|
||||||
|
assert s.database_url == "postgresql+asyncpg://custom/url"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppConfig:
|
||||||
|
"""Test AppConfig."""
|
||||||
|
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
"""Test AppConfig default values."""
|
||||||
|
cfg = AppConfig()
|
||||||
|
assert cfg.name == "Blog API"
|
||||||
|
assert cfg.debug is False
|
||||||
|
assert cfg.host == "0.0.0.0"
|
||||||
|
assert cfg.port == 8000
|
||||||
|
|
||||||
|
|
||||||
|
class TestDBConfig:
|
||||||
|
"""Test DBConfig."""
|
||||||
|
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
"""Test DBConfig default values."""
|
||||||
|
cfg = DBConfig()
|
||||||
|
assert cfg.url is None
|
||||||
|
assert cfg.echo is False
|
||||||
|
assert cfg.host == "localhost"
|
||||||
|
assert cfg.port == 5432
|
||||||
|
assert cfg.user == "postgres"
|
||||||
|
assert cfg.password == "postgres"
|
||||||
|
assert cfg.name == "blog"
|
||||||
|
|
||||||
|
def test_postgres_url_validation(self) -> None:
|
||||||
|
"""Test URL validation for postgres."""
|
||||||
|
cfg = DBConfig(url="postgresql+asyncpg://user:pass@host/db")
|
||||||
|
assert cfg.url == "postgresql+asyncpg://user:pass@host/db"
|
||||||
|
|
||||||
|
def test_sqlite_url_validation(self) -> None:
|
||||||
|
"""Test URL validation for sqlite."""
|
||||||
|
cfg = DBConfig(url="sqlite+aiosqlite:///./test.db")
|
||||||
|
assert cfg.url == "sqlite+aiosqlite:///./test.db"
|
||||||
|
|
||||||
|
def test_invalid_url_validation(self) -> None:
|
||||||
|
"""Test URL validation rejects invalid URLs."""
|
||||||
|
with pytest.raises(ValueError, match="sqlite+.*postgresql+"):
|
||||||
|
DBConfig(url="mysql://invalid")
|
||||||
|
|
||||||
|
|
||||||
|
class TestKCConfig:
|
||||||
|
"""Test KCConfig."""
|
||||||
|
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
"""Test KCConfig default values."""
|
||||||
|
cfg = KCConfig(client_secret="test-secret")
|
||||||
|
assert cfg.server_url == "http://localhost:8080"
|
||||||
|
assert cfg.realm == "blog"
|
||||||
|
assert cfg.client_id == "blog-api"
|
||||||
|
assert cfg.client_secret == "test-secret"
|
||||||
|
assert cfg.token_cache_ttl == 60
|
||||||
|
|
||||||
|
def test_is_configured_with_secret(self) -> None:
|
||||||
|
"""Test is_configured returns True when secret is set."""
|
||||||
|
cfg = KCConfig(client_secret="test-secret")
|
||||||
|
assert cfg.is_configured is True
|
||||||
|
|
||||||
|
def test_is_configured_without_secret(self) -> None:
|
||||||
|
"""Test is_configured returns False when secret is empty."""
|
||||||
|
cfg = KCConfig(client_secret="")
|
||||||
|
assert cfg.is_configured is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecurityConfig:
|
||||||
|
"""Test SecurityConfig."""
|
||||||
|
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
"""Test SecurityConfig default values."""
|
||||||
|
cfg = SecurityConfig(secret_key="test-key")
|
||||||
|
assert cfg.secret_key == "test-key"
|
||||||
|
assert cfg.access_token_expire_minutes == 30
|
||||||
|
|
||||||
|
def test_is_configured_with_secret(self) -> None:
|
||||||
|
"""Test is_configured returns True when secret is set."""
|
||||||
|
cfg = SecurityConfig(secret_key="test-secret")
|
||||||
|
assert cfg.is_configured is True
|
||||||
|
|
||||||
|
def test_is_configured_without_secret(self) -> None:
|
||||||
|
"""Test is_configured returns False when secret is empty."""
|
||||||
|
cfg = SecurityConfig(secret_key="")
|
||||||
|
assert cfg.is_configured is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvironment:
|
||||||
|
"""Test Environment enum."""
|
||||||
|
|
||||||
|
def test_dev_value(self) -> None:
|
||||||
|
"""Test DEV environment value."""
|
||||||
|
assert Environment.DEV.value == "dev"
|
||||||
|
|
||||||
|
def test_prod_value(self) -> None:
|
||||||
|
"""Test PROD environment value."""
|
||||||
|
assert Environment.PROD.value == "prod"
|
||||||
|
|||||||
Reference in New Issue
Block a user