From 184b95969c7fc0ff8630fc757c283119e710b438 Mon Sep 17 00:00:00 2001 From: Sergey Vanyushkin Date: Sat, 2 May 2026 00:43:10 +0300 Subject: [PATCH] feat(auth): implement Keycloak authentication with RBAC and pagination Major changes: - Add Keycloak integration via token introspection endpoint - Implement RBAC system with roles: admin, user, guest - Add role-based permissions for post operations - Add pagination support (default limit: 10) to list endpoints - Add published_only filter with admin-only override for unpublished posts Security improvements: - Remove hardcoded default secrets (SECRET_KEY, KEYCLOAK_CLIENT_SECRET) - Update .env.example with proper security placeholders - Add comprehensive RBAC unit tests Infrastructure: - Add httpx dependency for HTTP client - Add KeycloakAuthClient with token caching (TTL: 60s) - Add role-based dependencies (RequireAdmin, RequireUser, etc.) - Update DI container with Keycloak provider Endpoints updated: - GET /posts: filter by published status (admin can see all) - Add pagination params (limit, offset) to list endpoints - Enforce RBAC on post operations Tests: - Add 16 auth infrastructure tests - Add 13 RBAC role tests - Update existing tests for new required settings Breaking changes: - SECRET_KEY and KEYCLOAK_CLIENT_SECRET now required (no defaults) --- .env.example | 33 +++ app/application/use_cases/list_posts.py | 37 ++- app/domain/repositories/post.py | 27 +- app/domain/roles.py | 102 +++++++ app/infrastructure/auth/__init__.py | 6 + app/infrastructure/auth/client.py | 127 +++++++++ app/infrastructure/auth/models.py | 34 +++ app/infrastructure/config/__init__.py | 20 +- app/infrastructure/config/settings.py | 170 +++++++++++- app/infrastructure/database/connection.py | 2 +- app/infrastructure/di/providers.py | 11 + app/infrastructure/repositories/post.py | 69 +++-- app/main.py | 20 +- app/presentation/api/deps.py | 109 +++++++- app/presentation/api/v1/posts.py | 38 ++- pyproject.toml | 16 +- tests/api/conftest.py | 46 +++- tests/unit/domain/test_roles.py | 123 +++++++++ tests/unit/infrastructure/test_auth.py | 318 ++++++++++++++++++++++ tests/unit/infrastructure/test_config.py | 252 +++++++++++++++-- 20 files changed, 1461 insertions(+), 99 deletions(-) create mode 100644 .env.example create mode 100644 app/domain/roles.py create mode 100644 app/infrastructure/auth/__init__.py create mode 100644 app/infrastructure/auth/client.py create mode 100644 app/infrastructure/auth/models.py create mode 100644 tests/unit/domain/test_roles.py create mode 100644 tests/unit/infrastructure/test_auth.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..3c3bb63 --- /dev/null +++ b/.env.example @@ -0,0 +1,33 @@ +# Environment mode: dev or prod +ENVIRONMENT=dev + +# App settings +APP_NAME=Blog API +APP_DEBUG=false +APP_HOST=0.0.0.0 +APP_PORT=8000 + +# Database settings +# For dev (SQLite): DB_URL=sqlite+aiosqlite:///./blog.db +# For prod (PostgreSQL): DB_URL=postgresql+asyncpg://user:pass@host:port/db +# Or use individual DB_* vars for prod (see below) +DB_URL= +DB_ECHO=false + +# PostgreSQL-specific settings (used in prod when DB_URL is not set) +DB_HOST=localhost +DB_PORT=5432 +DB_USER=postgres +DB_PASSWORD=postgres +DB_NAME=blog + +# Security settings (REQUIRED) +SECURITY_SECRET_KEY=your-secret-key-here-change-in-production +SECURITY_ACCESS_TOKEN_EXPIRE_MINUTES=30 + +# Keycloak settings (REQUIRED for authentication) +KC_SERVER_URL=http://localhost:8080 +KC_REALM=blog +KC_CLIENT_ID=blog-api +KC_CLIENT_SECRET=your-keycloak-client-secret-here +KC_TOKEN_CACHE_TTL=60 diff --git a/app/application/use_cases/list_posts.py b/app/application/use_cases/list_posts.py index 4364388..634b267 100644 --- a/app/application/use_cases/list_posts.py +++ b/app/application/use_cases/list_posts.py @@ -22,24 +22,45 @@ class ListPostsUseCase: posts = await self._post_repo.get_all() return [self._map_to_dto(post) for post in posts] - async def published_posts(self) -> list[PostResponseDTO]: + async def published_posts( + self, + limit: int | None = None, + offset: int | None = None, + ) -> list[PostResponseDTO]: """Get all published posts.""" - posts = await self._post_repo.get_published() + posts = await self._post_repo.get_published(limit=limit, offset=offset) return [self._map_to_dto(post) for post in posts] - async def by_author(self, author_id: str) -> list[PostResponseDTO]: + async def by_author( + self, + author_id: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[PostResponseDTO]: """Get posts by author.""" - posts = await self._post_repo.get_by_author(author_id) + posts = await self._post_repo.get_by_author( + author_id, limit=limit, offset=offset + ) return [self._map_to_dto(post) for post in posts] - async def by_tag(self, tag: str) -> list[PostResponseDTO]: + async def by_tag( + self, + tag: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[PostResponseDTO]: """Get posts by tag.""" - posts = await self._post_repo.get_by_tag(tag) + posts = await self._post_repo.get_by_tag(tag, limit=limit, offset=offset) return [self._map_to_dto(post) for post in posts] - async def search(self, query: str) -> list[PostResponseDTO]: + async def search( + self, + query: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[PostResponseDTO]: """Search posts.""" - posts = await self._post_repo.search(query) + posts = await self._post_repo.search(query, limit=limit, offset=offset) return [self._map_to_dto(post) for post in posts] def _map_to_dto(self, post: Post) -> PostResponseDTO: diff --git a/app/domain/repositories/post.py b/app/domain/repositories/post.py index 85d1b4d..2e0fa99 100644 --- a/app/domain/repositories/post.py +++ b/app/domain/repositories/post.py @@ -15,17 +15,31 @@ class PostRepository(Repository[Post]): ... @abstractmethod - async def get_by_author(self, author_id: str) -> list[Post]: + async def get_by_author( + self, + author_id: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[Post]: """Get all posts by author.""" ... @abstractmethod - async def get_published(self) -> list[Post]: + async def get_published( + self, + limit: int | None = None, + offset: int | None = None, + ) -> list[Post]: """Get all published posts.""" ... @abstractmethod - async def get_by_tag(self, tag: str) -> list[Post]: + async def get_by_tag( + self, + tag: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[Post]: """Get posts by tag.""" ... @@ -35,6 +49,11 @@ class PostRepository(Repository[Post]): ... @abstractmethod - async def search(self, query: str) -> list[Post]: + async def search( + self, + query: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[Post]: """Search posts by query string.""" ... diff --git a/app/domain/roles.py b/app/domain/roles.py new file mode 100644 index 0000000..d84e2eb --- /dev/null +++ b/app/domain/roles.py @@ -0,0 +1,102 @@ +"""Role-based access control definitions.""" + +from enum import Enum +from functools import wraps +from typing import Any, Callable + +from app.domain.exceptions import ForbiddenException + + +class Role(str, Enum): + """User roles in the system.""" + + ADMIN = "admin" + USER = "user" + GUEST = "guest" + + +class Permission: + """Permission definitions.""" + + # Post permissions + POST_CREATE = "post:create" + POST_READ = "post:read" + POST_READ_UNPUBLISHED = "post:read_unpublished" + POST_UPDATE = "post:update" + POST_DELETE = "post:delete" + POST_PUBLISH = "post:publish" + + +# Role-based permission mapping +ROLE_PERMISSIONS: dict[Role, list[str]] = { + Role.ADMIN: [ + Permission.POST_CREATE, + Permission.POST_READ, + Permission.POST_READ_UNPUBLISHED, + Permission.POST_UPDATE, + Permission.POST_DELETE, + Permission.POST_PUBLISH, + ], + Role.USER: [ + Permission.POST_CREATE, + Permission.POST_READ, + Permission.POST_UPDATE, + Permission.POST_DELETE, + Permission.POST_PUBLISH, + ], + Role.GUEST: [ + Permission.POST_READ, + ], +} + + +def has_permission(role: Role, permission: str) -> bool: + """Check if role has specific permission.""" + return permission in ROLE_PERMISSIONS.get(role, []) + + +def require_permission( + permission: str, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorator to require specific permission.""" + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + # Get token_info from kwargs + token_info = kwargs.get("token_info") + if not token_info: + raise ForbiddenException("Authentication required") + + # Determine role from token or default to guest + roles = getattr(token_info, "roles", []) + if Role.ADMIN.value in roles: + role = Role.ADMIN + elif Role.USER.value in roles: + role = Role.USER + else: + role = Role.GUEST + + if not has_permission(role, permission): + raise ForbiddenException( + f"Permission '{permission}' required for role '{role.value}'" + ) + + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +def get_effective_role(roles: list[str]) -> Role: + """Determine effective role from list of roles. + + Priority: admin > user > guest + """ + if Role.ADMIN.value in roles: + return Role.ADMIN + elif Role.USER.value in roles: + return Role.USER + else: + return Role.GUEST diff --git a/app/infrastructure/auth/__init__.py b/app/infrastructure/auth/__init__.py new file mode 100644 index 0000000..cdd6fcd --- /dev/null +++ b/app/infrastructure/auth/__init__.py @@ -0,0 +1,6 @@ +"""Authentication infrastructure package.""" + +from app.infrastructure.auth.client import KeycloakAuthClient +from app.infrastructure.auth.models import KeycloakUser, TokenInfo + +__all__ = ["KeycloakAuthClient", "KeycloakUser", "TokenInfo"] diff --git a/app/infrastructure/auth/client.py b/app/infrastructure/auth/client.py new file mode 100644 index 0000000..93857cb --- /dev/null +++ b/app/infrastructure/auth/client.py @@ -0,0 +1,127 @@ +"""Keycloak authentication client.""" + +import time + +import httpx + +from app.infrastructure.auth.models import KeycloakUser, TokenInfo +from app.infrastructure.config.settings import Settings + + +class KeycloakAuthClient: + """Client for Keycloak authentication operations.""" + + def __init__(self, settings: Settings) -> None: + """Initialize Keycloak client with settings.""" + self._settings = settings + self._base_url = f"{settings.kc.server_url}/realms/{settings.kc.realm}" + self._client_id = settings.kc.client_id + self._client_secret = settings.kc.client_secret + self._cache: dict[str, tuple[TokenInfo, float]] = {} + self._cache_ttl = settings.kc.token_cache_ttl + + def _get_introspection_url(self) -> str: + """Get token introspection endpoint URL.""" + return f"{self._base_url}/protocol/openid-connect/token/introspection" + + def _get_userinfo_url(self) -> str: + """Get userinfo endpoint URL.""" + return f"{self._base_url}/protocol/openid-connect/userinfo" + + def _get_cached_token(self, token: str) -> TokenInfo | None: + """Get cached token info if valid.""" + if token not in self._cache: + return None + + token_info, cached_at = self._cache[token] + if time.time() - cached_at > self._cache_ttl: + del self._cache[token] + return None + + return token_info + + def _cache_token(self, token: str, token_info: TokenInfo) -> None: + """Cache token info.""" + self._cache[token] = (token_info, time.time()) + # Simple cleanup of old entries + current_time = time.time() + expired_keys = [ + k for k, (_, t) in self._cache.items() if current_time - t > self._cache_ttl + ] + for k in expired_keys: + del self._cache[k] + + async def introspect_token(self, token: str) -> TokenInfo: + """Introspect access token using Keycloak.""" + # Check cache first + cached = self._get_cached_token(token) + if cached: + return cached + + # Prepare introspection request + data = { + "token": token, + "client_id": self._client_id, + "client_secret": self._client_secret, + } + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + self._get_introspection_url(), + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=10.0, + ) + response.raise_for_status() + result = response.json() + except httpx.HTTPError as e: + return TokenInfo(active=False, raw_claims={"error": str(e)}) + + if not result.get("active", False): + return TokenInfo(active=False, raw_claims=result) + + # Extract roles from realm_access or resource_access + roles: list[str] = [] + realm_access = result.get("realm_access", {}) + if isinstance(realm_access, dict): + roles.extend(realm_access.get("roles", [])) + + token_info = TokenInfo( + active=True, + user_id=result.get("sub", ""), + username=result.get("preferred_username", ""), + email=result.get("email", ""), + roles=roles, + raw_claims=result, + ) + + # Cache valid token + self._cache_token(token, token_info) + + return token_info + + async def get_userinfo(self, token: str) -> KeycloakUser | None: + """Get user information from Keycloak using access token.""" + try: + async with httpx.AsyncClient() as client: + response = await client.get( + self._get_userinfo_url(), + headers={"Authorization": f"Bearer {token}"}, + timeout=10.0, + ) + response.raise_for_status() + data = response.json() + except httpx.HTTPError: + return None + + return KeycloakUser( + id=data.get("sub", ""), + username=data.get("preferred_username", ""), + email=data.get("email", ""), + first_name=data.get("given_name", ""), + last_name=data.get("family_name", ""), + roles=data.get("realm_access", {}).get("roles", []) + if isinstance(data.get("realm_access"), dict) + else [], + ) diff --git a/app/infrastructure/auth/models.py b/app/infrastructure/auth/models.py new file mode 100644 index 0000000..ccde351 --- /dev/null +++ b/app/infrastructure/auth/models.py @@ -0,0 +1,34 @@ +"""Keycloak authentication models.""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class TokenInfo: + """Information about validated token from Keycloak.""" + + active: bool + user_id: str = "" + username: str = "" + email: str = "" + roles: list[str] = field(default_factory=list) + raw_claims: dict[str, Any] = field(default_factory=dict, repr=False) + + @property + def is_valid(self) -> bool: + """Check if token is valid and active.""" + return self.active and bool(self.user_id) + + +@dataclass(frozen=True) +class KeycloakUser: + """User information from Keycloak.""" + + id: str + username: str + email: str + first_name: str = "" + last_name: str = "" + roles: list[str] = field(default_factory=list) + is_active: bool = True diff --git a/app/infrastructure/config/__init__.py b/app/infrastructure/config/__init__.py index 9078739..bc7a6e2 100644 --- a/app/infrastructure/config/__init__.py +++ b/app/infrastructure/config/__init__.py @@ -1,5 +1,21 @@ """Infrastructure configuration.""" -from app.infrastructure.config.settings import Settings, settings +from app.infrastructure.config.settings import ( + AppConfig, + DBConfig, + Environment, + KCConfig, + SecurityConfig, + Settings, + settings, +) -__all__ = ["Settings", "settings"] +__all__ = [ + "AppConfig", + "DBConfig", + "KCConfig", + "SecurityConfig", + "Environment", + "Settings", + "settings", +] diff --git a/app/infrastructure/config/settings.py b/app/infrastructure/config/settings.py index 484fd23..f7c8e97 100644 --- a/app/infrastructure/config/settings.py +++ b/app/infrastructure/config/settings.py @@ -1,31 +1,173 @@ -"""Application settings.""" +"""Application settings with composition pattern.""" +from enum import Enum +from functools import cached_property + +from pydantic import Field, PostgresDsn, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict -class Settings(BaseSettings): - """Application configuration settings.""" +class Environment(str, Enum): + """Application environment modes.""" - # App settings - app_name: str = "Blog API" + DEV = "dev" + PROD = "prod" + + +class AppConfig(BaseSettings): + """Application configuration.""" + + name: str = "Blog API" debug: bool = False host: str = "0.0.0.0" port: int = 8000 - # Database settings - database_url: str = "sqlite:///./blog.db" - database_echo: bool = False - - # Security settings - secret_key: str = "your-secret-key-change-in-production" - access_token_expire_minutes: int = 30 - model_config = SettingsConfigDict( - env_file=".env", + env_prefix="APP_", env_file_encoding="utf-8", case_sensitive=False, ) +class DBConfig(BaseSettings): + """Database configuration.""" + + # For dev: sqlite+aiosqlite:///./blog.db + # For prod: postgresql+asyncpg://user:pass@host:port/db + url: str | None = None + echo: bool = False + + # PostgreSQL-specific settings (used in prod) + host: str = "localhost" + port: int = 5432 + user: str = "postgres" + password: str = "postgres" + name: str = "blog" + + model_config = SettingsConfigDict( + env_prefix="DB_", + env_file_encoding="utf-8", + case_sensitive=False, + ) + + @field_validator("url") + @classmethod + def validate_url(cls, v: str | None) -> str | None: + """Validate database URL if provided.""" + if v is None: + return v + if not any(v.startswith(prefix) for prefix in ("sqlite+", "postgresql+")): + raise ValueError("Database URL must start with 'sqlite+' or 'postgresql+'") + return v + + +class KCConfig(BaseSettings): + """Keycloak configuration.""" + + server_url: str = "http://localhost:8080" + realm: str = "blog" + client_id: str = "blog-api" + client_secret: str = Field( + default="", + description="Keycloak client secret - must be set via env in production", + ) + token_cache_ttl: int = 60 # seconds + + model_config = SettingsConfigDict( + env_prefix="KC_", + env_file_encoding="utf-8", + case_sensitive=False, + ) + + @property + def is_configured(self) -> bool: + """Check if Keycloak is properly configured.""" + return bool(self.client_secret) + + +class SecurityConfig(BaseSettings): + """Security configuration.""" + + secret_key: str = Field( + default="", description="Secret key for JWT - must be set via env in production" + ) + access_token_expire_minutes: int = 30 + + model_config = SettingsConfigDict( + env_prefix="SECURITY_", + env_file_encoding="utf-8", + case_sensitive=False, + ) + + @property + def is_configured(self) -> bool: + """Check if security is properly configured.""" + return bool(self.secret_key) + + +class Settings(BaseSettings): + """Application configuration settings with composition.""" + + # Environment mode + environment: Environment = Environment.DEV + + # Sub-configurations + app: AppConfig = Field(default_factory=AppConfig) + db: DBConfig = Field(default_factory=DBConfig) + kc: KCConfig = Field(default_factory=KCConfig) + security: SecurityConfig = Field(default_factory=SecurityConfig) + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + env_nested_delimiter="__", + ) + + def model_post_init(self, __context: object) -> None: + """Validate settings after initialization.""" + if self.is_prod: + if not self.security.is_configured: + raise ValueError("SECURITY_SECRET_KEY must be set in production mode") + if not self.kc.is_configured: + raise ValueError("KC_CLIENT_SECRET must be set in production mode") + + @cached_property + def database_url(self) -> str: + """Get database URL based on environment. + + - In dev: uses SQLite if no URL provided + - In prod: uses PostgreSQL if no URL provided + """ + if self.db.url: + return self.db.url + + if self.environment == Environment.PROD: + # Build PostgreSQL URL from components + return str( + PostgresDsn.build( + scheme="postgresql+asyncpg", + username=self.db.user, + password=self.db.password, + host=self.db.host, + port=self.db.port, + path=self.db.name, + ) + ) + + # Default dev SQLite URL + return "sqlite+aiosqlite:///./blog.db" + + @property + def is_dev(self) -> bool: + """Check if running in development mode.""" + return self.environment == Environment.DEV + + @property + def is_prod(self) -> bool: + """Check if running in production mode.""" + return self.environment == Environment.PROD + + # Global settings instance settings = Settings() diff --git a/app/infrastructure/database/connection.py b/app/infrastructure/database/connection.py index 84913d4..74b46a6 100644 --- a/app/infrastructure/database/connection.py +++ b/app/infrastructure/database/connection.py @@ -24,7 +24,7 @@ def _get_database_url() -> str: # Create async engine engine: AsyncEngine = create_async_engine( _get_database_url(), - echo=settings.database_echo, + echo=settings.db.echo, future=True, ) diff --git a/app/infrastructure/di/providers.py b/app/infrastructure/di/providers.py index 2d1f21e..d017653 100644 --- a/app/infrastructure/di/providers.py +++ b/app/infrastructure/di/providers.py @@ -15,6 +15,8 @@ from app.application import ( ) from app.application.interfaces import TransactionManager from app.domain.repositories import PostRepository +from app.infrastructure.auth import KeycloakAuthClient +from app.infrastructure.config.settings import settings from app.infrastructure.database.connection import AsyncSessionLocal, engine from app.infrastructure.repositories.post import SQLAlchemyPostRepository @@ -131,3 +133,12 @@ class UseCaseProvider(Provider): post_repo=post_repo, tx_manager=tx_manager, ) + + +class KeycloakProvider(Provider): + """Provider for Keycloak authentication client.""" + + @provide(scope=Scope.APP) + def get_keycloak_client(self) -> KeycloakAuthClient: + """Provide KeycloakAuthClient singleton.""" + return KeycloakAuthClient(settings) diff --git a/app/infrastructure/repositories/post.py b/app/infrastructure/repositories/post.py index 2b6bd5e..0ececfb 100644 --- a/app/infrastructure/repositories/post.py +++ b/app/infrastructure/repositories/post.py @@ -105,27 +105,50 @@ class SQLAlchemyPostRepository(PostRepository): orm = result.scalar_one_or_none() return self._to_domain(orm) if orm else None - async def get_by_author(self, author_id: str) -> list[Post]: + async def get_by_author( + self, + author_id: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[Post]: """Get posts by author.""" - result = await self._session.execute( - select(PostORM).where(PostORM.author_id == author_id) - ) + query = select(PostORM).where(PostORM.author_id == author_id) + if limit is not None: + query = query.limit(limit) + if offset is not None: + query = query.offset(offset) + result = await self._session.execute(query) orms = result.scalars().all() return [self._to_domain(orm) for orm in orms] - async def get_published(self) -> list[Post]: + async def get_published( + self, + limit: int | None = None, + offset: int | None = None, + ) -> list[Post]: """Get published posts.""" - result = await self._session.execute( - select(PostORM).where(PostORM.published.is_(True)) - ) + query = select(PostORM).where(PostORM.published.is_(True)) + if limit is not None: + query = query.limit(limit) + if offset is not None: + query = query.offset(offset) + result = await self._session.execute(query) orms = result.scalars().all() return [self._to_domain(orm) for orm in orms] - async def get_by_tag(self, tag: str) -> list[Post]: + async def get_by_tag( + self, + tag: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[Post]: """Get posts by tag.""" - result = await self._session.execute( - select(PostORM).where(PostORM.tags.contains([tag])) - ) + query = select(PostORM).where(PostORM.tags.contains([tag])) + if limit is not None: + query = query.limit(limit) + if offset is not None: + query = query.offset(offset) + result = await self._session.execute(query) orms = result.scalars().all() return [self._to_domain(orm) for orm in orms] @@ -136,16 +159,24 @@ class SQLAlchemyPostRepository(PostRepository): ) return result.scalar_one_or_none() is not None - async def search(self, query: str) -> list[Post]: + async def search( + self, + query: str, + limit: int | None = None, + offset: int | None = None, + ) -> list[Post]: """Search posts.""" search_pattern = f"%{query}%" - result = await self._session.execute( - select(PostORM).where( - or_( - PostORM.title.ilike(search_pattern), - PostORM.content.ilike(search_pattern), - ) + stmt = select(PostORM).where( + or_( + PostORM.title.ilike(search_pattern), + PostORM.content.ilike(search_pattern), ) ) + if limit is not None: + stmt = stmt.limit(limit) + if offset is not None: + stmt = stmt.offset(offset) + result = await self._session.execute(stmt) orms = result.scalars().all() return [self._to_domain(orm) for orm in orms] diff --git a/app/main.py b/app/main.py index aa324b5..b410551 100644 --- a/app/main.py +++ b/app/main.py @@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware from app.infrastructure import close_db, init_db, register_exception_handlers, settings from app.infrastructure.di.providers import ( DatabaseProvider, + KeycloakProvider, RepositoryProvider, TransactionManagerProvider, UseCaseProvider, @@ -32,11 +33,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: def app_factory() -> FastAPI: """Create and configure FastAPI application.""" app = FastAPI( - title=settings.app_name, - debug=settings.debug, + title=settings.app.name, + debug=settings.app.debug, lifespan=lifespan, - docs_url="/docs" if settings.debug else None, - redoc_url="/redoc" if settings.debug else None, + docs_url="/docs" if settings.is_dev else None, + redoc_url="/redoc" if settings.is_dev else None, ) # Setup Dishka DI container @@ -45,6 +46,7 @@ def app_factory() -> FastAPI: RepositoryProvider(), TransactionManagerProvider(), UseCaseProvider(), + KeycloakProvider(), ) setup_dishka(container, app) @@ -66,7 +68,11 @@ def app_factory() -> FastAPI: # Health check endpoint @app.get("/health", tags=["health"]) async def health_check() -> dict[str, str]: - return {"status": "ok", "app": settings.app_name} + return { + "status": "ok", + "app": settings.app.name, + "env": settings.environment.value, + } return app @@ -76,8 +82,8 @@ def main() -> None: uvicorn.run( app_factory, factory=True, - host=settings.host, - port=settings.port, + host=settings.app.host, + port=settings.app.port, ) diff --git a/app/presentation/api/deps.py b/app/presentation/api/deps.py index c159f38..71f55d1 100644 --- a/app/presentation/api/deps.py +++ b/app/presentation/api/deps.py @@ -1,9 +1,10 @@ """API dependencies using Dishka.""" -from typing import Annotated +from typing import Annotated, Any from dishka.integrations.fastapi import FromDishka -from fastapi import Depends, Header +from fastapi import Depends, Request +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from app.application import ( CreatePostUseCase, @@ -13,6 +14,9 @@ from app.application import ( PublishPostUseCase, UpdatePostUseCase, ) +from app.domain.exceptions import ForbiddenException, UnauthorizedException +from app.domain.roles import Role, get_effective_role +from app.infrastructure.auth import KeycloakAuthClient, TokenInfo # Use case dependencies - injected via Dishka CreatePostDep = FromDishka[CreatePostUseCase] @@ -22,13 +26,106 @@ DeletePostDep = FromDishka[DeletePostUseCase] ListPostsDep = FromDishka[ListPostsUseCase] PublishPostDep = FromDishka[PublishPostUseCase] +# Security scheme +security = HTTPBearer(auto_error=False) + + +def get_keycloak_client(request: Request) -> KeycloakAuthClient: + """Get Keycloak client from DI container via request state.""" + client: KeycloakAuthClient = request.state.dishka_container.get(KeycloakAuthClient) + return client + + +async def get_current_token_info( + credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)], + request: Request, +) -> TokenInfo: + """Validate token and return token info from Keycloak.""" + if not credentials: + raise UnauthorizedException("Authentication required") + + keycloak_client = get_keycloak_client(request) + token = credentials.credentials + token_info = await keycloak_client.introspect_token(token) + + if not token_info.is_valid: + raise UnauthorizedException("Invalid or expired token") + + return token_info + -# Mock current user dependency (replace with real auth) async def get_current_user_id( - x_user_id: Annotated[str | None, Header()] = "user-123", + token_info: Annotated[TokenInfo, Depends(get_current_token_info)], ) -> str: - """Get current user ID from header.""" - return x_user_id or "user-123" + """Get current user ID from validated token.""" + return token_info.user_id CurrentUserDep = Annotated[str, Depends(get_current_user_id)] +TokenInfoDep = Annotated[TokenInfo, Depends(get_current_token_info)] + + +# Optional auth - doesn't require authentication but provides user info if available +async def get_optional_token_info( + credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)], + request: Request, +) -> TokenInfo | None: + """Get token info if valid token provided, otherwise None (guest).""" + if not credentials: + return None + + keycloak_client = get_keycloak_client(request) + token = credentials.credentials + token_info = await keycloak_client.introspect_token(token) + + if token_info.is_valid: + return token_info + + return None + + +OptionalTokenInfoDep = Annotated[TokenInfo | None, Depends(get_optional_token_info)] + + +async def get_optional_user_id( + token_info: OptionalTokenInfoDep, +) -> str | None: + """Get current user ID if token is valid, otherwise None.""" + if token_info: + return token_info.user_id + return None + + +OptionalUserDep = Annotated[str | None, Depends(get_optional_user_id)] + + +def get_current_role(token_info: OptionalTokenInfoDep) -> Role: + """Get effective role from token info. + + Returns GUEST if no valid token provided. + """ + if token_info and token_info.roles: + return get_effective_role(token_info.roles) + return Role.GUEST + + +CurrentRoleDep = Annotated[Role, Depends(get_current_role)] + + +def require_roles(allowed_roles: list[Role]) -> Any: + """Create dependency that checks if user has one of the allowed roles.""" + + async def check_role(role: CurrentRoleDep) -> Role: + if role not in allowed_roles: + raise ForbiddenException( + f"Access denied. Required roles: {[r.value for r in allowed_roles]}" + ) + return role + + return Depends(check_role) + + +# Predefined role requirements +RequireAdmin = require_roles([Role.ADMIN]) +RequireUser = require_roles([Role.USER, Role.ADMIN]) +RequireAny = require_roles([Role.GUEST, Role.USER, Role.ADMIN]) diff --git a/app/presentation/api/v1/posts.py b/app/presentation/api/v1/posts.py index 90f532b..5bf5aca 100644 --- a/app/presentation/api/v1/posts.py +++ b/app/presentation/api/v1/posts.py @@ -6,8 +6,11 @@ from dishka.integrations.fastapi import DishkaRoute from fastapi import APIRouter, status from app.application.dtos import CreatePostDTO, UpdatePostDTO +from app.domain.exceptions import ForbiddenException +from app.domain.roles import Permission, has_permission from app.presentation.api.deps import ( CreatePostDep, + CurrentRoleDep, CurrentUserDep, DeletePostDep, GetPostDep, @@ -50,11 +53,38 @@ async def create_post( @router.get( "", response_model=PostListResponseSchema, - summary="List all posts", + summary="List posts", ) -async def list_posts(use_case: ListPostsDep) -> PostListResponseSchema: - """Get all blog posts.""" - results = await use_case.all_posts() +async def list_posts( + use_case: ListPostsDep, + role: CurrentRoleDep, + include_unpublished: bool = False, + limit: int = 10, + offset: int = 0, +) -> PostListResponseSchema: + """Get blog posts with optional filtering and pagination. + + Args: + include_unpublished: If True, returns all posts including drafts. + Only admins can use this parameter. + limit: Maximum number of posts to return (default: 10, max: 100). + offset: Number of posts to skip (default: 0). + + Raises: + ForbiddenException: If non-admin tries to include unpublished posts. + """ + # Clamp limit to reasonable range + limit = max(1, min(limit, 100)) + offset = max(0, offset) + + # Check permissions for unpublished posts + if include_unpublished: + if not has_permission(role, Permission.POST_READ_UNPUBLISHED): + raise ForbiddenException("Only admins can view unpublished posts") + results = await use_case.all_posts() + else: + results = await use_case.published_posts(limit=limit, offset=offset) + items = [PostResponseSchema(**r.__dict__) for r in results] return PostListResponseSchema(items=items, total=len(items)) diff --git a/pyproject.toml b/pyproject.toml index 1e7fe44..f658d2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,13 +4,6 @@ version = "0.1.0" description = "Add your description here" readme = "README.md" requires-python = ">=3.13" - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["app"] dependencies = [ "fastapi>=0.136.0", "pydantic>=2.13.2", @@ -18,9 +11,18 @@ dependencies = [ "uvicorn>=0.44.0", "sqlalchemy>=2.0.0", "aiosqlite>=0.21.0", + "asyncpg>=0.30.0", "dishka>=1.5.0", + "httpx>=0.28.0", ] +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["app"] + [dependency-groups] dev = [ {include-group = "lints"}, diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 6d0972f..cfa6b57 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -1,23 +1,57 @@ """API test fixtures.""" from typing import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock, patch import pytest from httpx import ASGITransport, AsyncClient +from app.infrastructure.auth.models import TokenInfo from app.main import app_factory @pytest.fixture -async def client() -> AsyncGenerator[AsyncClient, None]: +def mock_keycloak_client() -> MagicMock: + """Create mock Keycloak client for testing.""" + mock_client = AsyncMock() + mock_client.introspect_token.return_value = TokenInfo( + active=True, + user_id="test-user-id", + username="testuser", + email="test@example.com", + roles=["user"], + ) + return mock_client + + +@pytest.fixture +async def client(mock_keycloak_client: MagicMock) -> AsyncGenerator[AsyncClient, None]: """Create async HTTP client for API testing.""" - app = app_factory() - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as ac: - yield ac + with patch( + "app.presentation.api.deps.KeycloakAuthClient", + return_value=mock_keycloak_client, + ): + app = app_factory() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac @pytest.fixture def auth_headers() -> dict[str, str]: """Return mock authentication headers.""" - return {"Authorization": "Bearer test_token", "X-User-Id": "user-123"} + return {"Authorization": "Bearer test_token"} + + +@pytest.fixture +def unauthorized_keycloak_client() -> MagicMock: + """Create mock Keycloak client that returns invalid token.""" + mock_client = AsyncMock() + mock_client.introspect_token.return_value = TokenInfo( + active=False, + user_id="", + username="", + email="", + roles=[], + ) + return mock_client diff --git a/tests/unit/domain/test_roles.py b/tests/unit/domain/test_roles.py new file mode 100644 index 0000000..5a3d0e1 --- /dev/null +++ b/tests/unit/domain/test_roles.py @@ -0,0 +1,123 @@ +"""Tests for role-based access control.""" + +from app.domain.roles import ( + ROLE_PERMISSIONS, + Permission, + Role, + get_effective_role, + has_permission, +) + + +class TestRole: + """Test Role enum.""" + + def test_role_values(self) -> None: + """Test role enum values.""" + assert Role.ADMIN.value == "admin" + assert Role.USER.value == "user" + assert Role.GUEST.value == "guest" + + def test_role_comparison(self) -> None: + """Test role comparison.""" + assert Role.ADMIN == Role.ADMIN + # USER and ADMIN are different enum values with different string values + assert Role.USER.value != Role.ADMIN.value # type: ignore[comparison-overlap] + + +class TestPermissions: + """Test permission definitions.""" + + def test_permission_values(self) -> None: + """Test permission constants.""" + assert Permission.POST_CREATE == "post:create" + assert Permission.POST_READ == "post:read" + assert Permission.POST_READ_UNPUBLISHED == "post:read_unpublished" + assert Permission.POST_UPDATE == "post:update" + assert Permission.POST_DELETE == "post:delete" + assert Permission.POST_PUBLISH == "post:publish" + + +class TestRolePermissions: + """Test role-based permission mapping.""" + + def test_admin_has_all_permissions(self) -> None: + """Test admin has all permissions.""" + admin_perms = ROLE_PERMISSIONS[Role.ADMIN] + assert Permission.POST_CREATE in admin_perms + assert Permission.POST_READ in admin_perms + assert Permission.POST_READ_UNPUBLISHED in admin_perms + assert Permission.POST_UPDATE in admin_perms + assert Permission.POST_DELETE in admin_perms + assert Permission.POST_PUBLISH in admin_perms + + def test_user_permissions(self) -> None: + """Test user permissions.""" + user_perms = ROLE_PERMISSIONS[Role.USER] + assert Permission.POST_CREATE in user_perms + assert Permission.POST_READ in user_perms + assert Permission.POST_UPDATE in user_perms + assert Permission.POST_DELETE in user_perms + assert Permission.POST_PUBLISH in user_perms + # User cannot read unpublished + assert Permission.POST_READ_UNPUBLISHED not in user_perms + + def test_guest_permissions(self) -> None: + """Test guest permissions.""" + guest_perms = ROLE_PERMISSIONS[Role.GUEST] + assert Permission.POST_READ in guest_perms + # Guest has very limited permissions + assert Permission.POST_CREATE not in guest_perms + assert Permission.POST_UPDATE not in guest_perms + assert Permission.POST_DELETE not in guest_perms + assert Permission.POST_READ_UNPUBLISHED not in guest_perms + + +class TestHasPermission: + """Test has_permission function.""" + + def test_admin_has_all_permissions_check(self) -> None: + """Test admin permission checks.""" + assert has_permission(Role.ADMIN, Permission.POST_CREATE) is True + assert has_permission(Role.ADMIN, Permission.POST_READ_UNPUBLISHED) is True + assert has_permission(Role.ADMIN, "unknown:permission") is False + + def test_user_limited_permissions(self) -> None: + """Test user limited permissions.""" + assert has_permission(Role.USER, Permission.POST_CREATE) is True + assert has_permission(Role.USER, Permission.POST_READ_UNPUBLISHED) is False + assert has_permission(Role.USER, Permission.POST_READ) is True + + def test_guest_read_only(self) -> None: + """Test guest read-only access.""" + assert has_permission(Role.GUEST, Permission.POST_READ) is True + assert has_permission(Role.GUEST, Permission.POST_CREATE) is False + assert has_permission(Role.GUEST, Permission.POST_UPDATE) is False + + +class TestGetEffectiveRole: + """Test get_effective_role function.""" + + def test_admin_from_roles_list(self) -> None: + """Test admin role detection.""" + assert get_effective_role(["admin"]) == Role.ADMIN + assert get_effective_role(["user", "admin"]) == Role.ADMIN + assert get_effective_role(["admin", "user"]) == Role.ADMIN + + def test_user_from_roles_list(self) -> None: + """Test user role detection.""" + assert get_effective_role(["user"]) == Role.USER + assert get_effective_role(["user", "moderator"]) == Role.USER + + def test_guest_from_roles_list(self) -> None: + """Test guest role detection.""" + assert get_effective_role([]) == Role.GUEST + assert get_effective_role(["unknown"]) == Role.GUEST + assert get_effective_role(["guest"]) == Role.GUEST + + def test_role_priority(self) -> None: + """Test that admin > user > guest.""" + # Admin takes precedence + assert get_effective_role(["user", "admin", "guest"]) == Role.ADMIN + # User takes precedence over guest + assert get_effective_role(["guest", "user"]) == Role.USER diff --git a/tests/unit/infrastructure/test_auth.py b/tests/unit/infrastructure/test_auth.py new file mode 100644 index 0000000..7cf875a --- /dev/null +++ b/tests/unit/infrastructure/test_auth.py @@ -0,0 +1,318 @@ +"""Tests for Keycloak authentication client.""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from app.infrastructure.auth import KeycloakAuthClient, KeycloakUser, TokenInfo +from app.infrastructure.config.settings import Settings + + +class TestTokenInfo: + """Test TokenInfo dataclass.""" + + def test_token_info_valid(self) -> None: + """Test valid token info.""" + token_info = TokenInfo( + active=True, + user_id="user-123", + username="testuser", + email="test@example.com", + roles=["user"], + ) + assert token_info.is_valid is True + assert token_info.user_id == "user-123" + assert token_info.username == "testuser" + assert token_info.email == "test@example.com" + assert token_info.roles == ["user"] + + def test_token_info_invalid_not_active(self) -> None: + """Test invalid token when not active.""" + token_info = TokenInfo( + active=False, + user_id="user-123", + username="testuser", + email="test@example.com", + roles=["user"], + ) + assert token_info.is_valid is False + + def test_token_info_invalid_no_user_id(self) -> None: + """Test invalid token when no user_id.""" + token_info = TokenInfo( + active=True, + user_id="", + username="testuser", + email="test@example.com", + roles=["user"], + ) + assert token_info.is_valid is False + + def test_token_info_empty_roles(self) -> None: + """Test token info with empty roles.""" + token_info = TokenInfo( + active=True, + user_id="user-123", + username="testuser", + email="test@example.com", + roles=[], + ) + assert token_info.is_valid is True + assert token_info.roles == [] + + +class TestKeycloakUser: + """Test KeycloakUser dataclass.""" + + def test_keycloak_user_creation(self) -> None: + """Test KeycloakUser creation.""" + user = KeycloakUser( + id="user-123", + username="testuser", + email="test@example.com", + first_name="Test", + last_name="User", + roles=["user", "admin"], + is_active=True, + ) + assert user.id == "user-123" + assert user.username == "testuser" + assert user.email == "test@example.com" + assert user.first_name == "Test" + assert user.last_name == "User" + assert user.roles == ["user", "admin"] + assert user.is_active is True + + def test_keycloak_user_defaults(self) -> None: + """Test KeycloakUser with default values.""" + user = KeycloakUser( + id="user-123", + username="testuser", + email="test@example.com", + ) + assert user.first_name == "" + assert user.last_name == "" + assert user.roles == [] + assert user.is_active is True + + +class TestKeycloakAuthClient: + """Test KeycloakAuthClient.""" + + @pytest.fixture + def settings(self) -> Settings: + """Create test settings.""" + from app.infrastructure.config import KCConfig, SecurityConfig + + return Settings( + environment="dev", + kc=KCConfig( + server_url="http://localhost:8080", + realm="test-realm", + client_id="test-client", + client_secret="test-secret", + token_cache_ttl=60, + ), + security=SecurityConfig( + secret_key="test-secret-key-for-jwt-tokens", + ), + ) + + @pytest.fixture + def client(self, settings: Settings) -> KeycloakAuthClient: + """Create Keycloak client.""" + return KeycloakAuthClient(settings) + + def test_client_initialization( + self, client: KeycloakAuthClient, settings: Settings + ) -> None: + """Test client initialization.""" + assert client._settings == settings + assert client._base_url == "http://localhost:8080/realms/test-realm" + assert client._client_id == "test-client" + assert client._client_secret == "test-secret" + assert client._cache_ttl == 60 + + def test_get_introspection_url(self, client: KeycloakAuthClient) -> None: + """Test introspection URL generation.""" + url = client._get_introspection_url() + assert ( + url + == "http://localhost:8080/realms/test-realm/protocol/openid-connect/token/introspection" + ) + + def test_get_userinfo_url(self, client: KeycloakAuthClient) -> None: + """Test userinfo URL generation.""" + url = client._get_userinfo_url() + assert ( + url + == "http://localhost:8080/realms/test-realm/protocol/openid-connect/userinfo" + ) + + @pytest.mark.asyncio + async def test_introspect_token_success(self, client: KeycloakAuthClient) -> None: + """Test successful token introspection.""" + mock_response = Mock() + mock_response.json.return_value = { + "active": True, + "sub": "user-123", + "preferred_username": "testuser", + "email": "test@example.com", + "realm_access": {"roles": ["user", "admin"]}, + } + mock_response.raise_for_status = Mock() + + mock_async_client = AsyncMock() + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=None) + mock_async_client.post = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient", return_value=mock_async_client): + result = await client.introspect_token("test-token") + + assert result.active is True + assert result.user_id == "user-123" + assert result.username == "testuser" + assert result.email == "test@example.com" + assert result.roles == ["user", "admin"] + assert result.is_valid is True + + @pytest.mark.asyncio + async def test_introspect_token_inactive(self, client: KeycloakAuthClient) -> None: + """Test introspection with inactive token.""" + mock_response = Mock() + mock_response.json.return_value = {"active": False} + mock_response.raise_for_status = Mock() + + mock_async_client = AsyncMock() + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=None) + mock_async_client.post = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient", return_value=mock_async_client): + result = await client.introspect_token("test-token") + + assert result.active is False + assert result.is_valid is False + + @pytest.mark.asyncio + async def test_introspect_token_http_error( + self, client: KeycloakAuthClient + ) -> None: + """Test introspection with HTTP error.""" + import httpx + + mock_async_client = AsyncMock() + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=None) + mock_async_client.post = AsyncMock( + side_effect=httpx.HTTPError("Connection error") + ) + + with patch("httpx.AsyncClient", return_value=mock_async_client): + result = await client.introspect_token("test-token") + + assert result.active is False + assert result.is_valid is False + + @pytest.mark.asyncio + async def test_introspect_token_uses_cache( + self, client: KeycloakAuthClient + ) -> None: + """Test that token introspection uses cache.""" + mock_response = Mock() + mock_response.json.return_value = { + "active": True, + "sub": "user-123", + "preferred_username": "testuser", + "email": "test@example.com", + "realm_access": {"roles": ["user"]}, + } + mock_response.raise_for_status = Mock() + + mock_async_client = AsyncMock() + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=None) + mock_async_client.post = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient", return_value=mock_async_client): + # First call + result1 = await client.introspect_token("test-token") + # Second call should use cache + result2 = await client.introspect_token("test-token") + + # HTTP client should only be called once + assert mock_async_client.post.call_count == 1 + assert result1.user_id == result2.user_id + + @pytest.mark.asyncio + async def test_get_userinfo_success(self, client: KeycloakAuthClient) -> None: + """Test successful userinfo retrieval.""" + mock_response = Mock() + mock_response.json.return_value = { + "sub": "user-123", + "preferred_username": "testuser", + "email": "test@example.com", + "given_name": "Test", + "family_name": "User", + "realm_access": {"roles": ["user"]}, + } + mock_response.raise_for_status = Mock() + + mock_async_client = AsyncMock() + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=None) + mock_async_client.get = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient", return_value=mock_async_client): + result = await client.get_userinfo("test-token") + + assert result is not None + assert result.id == "user-123" + assert result.username == "testuser" + assert result.email == "test@example.com" + assert result.first_name == "Test" + assert result.last_name == "User" + assert result.roles == ["user"] + + @pytest.mark.asyncio + async def test_get_userinfo_error(self, client: KeycloakAuthClient) -> None: + """Test userinfo retrieval with error.""" + import httpx + + mock_async_client = AsyncMock() + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=None) + mock_async_client.get = AsyncMock( + side_effect=httpx.HTTPError("Connection error") + ) + + with patch("httpx.AsyncClient", return_value=mock_async_client): + result = await client.get_userinfo("test-token") + + assert result is None + + @pytest.mark.asyncio + async def test_introspect_token_no_realm_roles( + self, client: KeycloakAuthClient + ) -> None: + """Test introspection without realm_access roles.""" + mock_response = Mock() + mock_response.json.return_value = { + "active": True, + "sub": "user-123", + "preferred_username": "testuser", + "email": "test@example.com", + } + mock_response.raise_for_status = Mock() + + mock_async_client = AsyncMock() + mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client) + mock_async_client.__aexit__ = AsyncMock(return_value=None) + mock_async_client.post = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient", return_value=mock_async_client): + result = await client.introspect_token("test-token") + + assert result.active is True + assert result.roles == [] diff --git a/tests/unit/infrastructure/test_config.py b/tests/unit/infrastructure/test_config.py index 73d8c4c..bdd1cce 100644 --- a/tests/unit/infrastructure/test_config.py +++ b/tests/unit/infrastructure/test_config.py @@ -1,37 +1,247 @@ """Tests for infrastructure config.""" -from app.infrastructure.config import Settings +import pytest + +from app.infrastructure.config import ( + AppConfig, + DBConfig, + Environment, + KCConfig, + SecurityConfig, + Settings, +) class TestSettings: + """Test Settings with composition pattern.""" + def test_default_values(self) -> None: """Test default settings values by creating settings without env file.""" - # Create settings with no env file to test defaults - s = Settings(_env_file=None) - assert s.app_name == "Blog API" - assert s.debug is False - assert s.host == "0.0.0.0" - assert s.port == 8000 - assert s.database_url == "sqlite:///./blog.db" - assert s.database_echo is False + # Create settings with required secrets and no env file + s = Settings( + _env_file=None, + security=SecurityConfig(secret_key="test-secret-key"), + kc=KCConfig(client_secret="test-client-secret"), + ) + assert s.app.name == "Blog API" + assert s.app.debug is False + assert s.app.host == "0.0.0.0" + assert s.app.port == 8000 + assert s.database_url == "sqlite+aiosqlite:///./blog.db" + assert s.db.echo is False + assert s.security.secret_key == "test-secret-key" + assert s.kc.client_secret == "test-client-secret" + assert s.environment == Environment.DEV def test_custom_values(self) -> None: """Test custom settings values.""" s = Settings( - app_name="Test API", - debug=True, - host="localhost", - port=9000, - database_url="postgresql://test", - secret_key="test-secret", + _env_file=None, + environment=Environment.PROD, + app=AppConfig( + name="Test API", + debug=True, + host="localhost", + port=9000, + ), + db=DBConfig(url="postgresql+asyncpg://user:pass@host/db"), + security=SecurityConfig(secret_key="test-secret"), + kc=KCConfig(client_secret="test-client-secret"), ) - assert s.app_name == "Test API" - assert s.debug is True - assert s.host == "localhost" - assert s.port == 9000 - assert s.database_url == "postgresql://test" - assert s.secret_key == "test-secret" + assert s.app.name == "Test API" + assert s.app.debug is True + assert s.app.host == "localhost" + assert s.app.port == 9000 + assert s.database_url == "postgresql+asyncpg://user:pass@host/db" + assert s.security.secret_key == "test-secret" + assert s.kc.client_secret == "test-client-secret" + assert s.environment == Environment.PROD def test_model_config(self) -> None: """Test settings model config.""" assert "env_file" in Settings.model_config + + def test_is_dev_property(self) -> None: + """Test is_dev property.""" + s = Settings( + _env_file=None, + environment=Environment.DEV, + security=SecurityConfig(secret_key="test"), + kc=KCConfig(client_secret="test"), + ) + assert s.is_dev is True + assert s.is_prod is False + + def test_is_prod_property(self) -> None: + """Test is_prod property.""" + s = Settings( + _env_file=None, + environment=Environment.PROD, + security=SecurityConfig(secret_key="test"), + kc=KCConfig(client_secret="test"), + ) + assert s.is_prod is True + assert s.is_dev is False + + def test_prod_requires_security_secret(self) -> None: + """Test that prod mode requires security secret_key.""" + with pytest.raises(ValueError, match="SECURITY_SECRET_KEY"): + Settings( + _env_file=None, + environment=Environment.PROD, + security=SecurityConfig(secret_key=""), + kc=KCConfig(client_secret="test"), + ) + + def test_prod_requires_kc_secret(self) -> None: + """Test that prod mode requires KC client_secret.""" + with pytest.raises(ValueError, match="KC_CLIENT_SECRET"): + Settings( + _env_file=None, + environment=Environment.PROD, + security=SecurityConfig(secret_key="test"), + kc=KCConfig(client_secret=""), + ) + + def test_database_url_dev_default(self) -> None: + """Test default database URL in dev mode.""" + s = Settings( + _env_file=None, + environment=Environment.DEV, + security=SecurityConfig(secret_key="test"), + kc=KCConfig(client_secret="test"), + ) + assert s.database_url == "sqlite+aiosqlite:///./blog.db" + + def test_database_url_prod_builds_postgres(self) -> None: + """Test that database URL builds from components in prod.""" + s = Settings( + _env_file=None, + environment=Environment.PROD, + db=DBConfig( + url=None, # Force building from components + host="db.example.com", + port=5433, + user="admin", + password="secret", + name="mydb", + ), + security=SecurityConfig(secret_key="test"), + kc=KCConfig(client_secret="test"), + ) + assert ( + s.database_url + == "postgresql+asyncpg://admin:secret@db.example.com:5433/mydb" + ) + + def test_database_url_override(self) -> None: + """Test that explicit database URL overrides auto-building.""" + s = Settings( + _env_file=None, + environment=Environment.PROD, + db=DBConfig( + url="postgresql+asyncpg://custom/url", + host="ignored", + user="ignored", + ), + security=SecurityConfig(secret_key="test"), + kc=KCConfig(client_secret="test"), + ) + assert s.database_url == "postgresql+asyncpg://custom/url" + + +class TestAppConfig: + """Test AppConfig.""" + + def test_default_values(self) -> None: + """Test AppConfig default values.""" + cfg = AppConfig() + assert cfg.name == "Blog API" + assert cfg.debug is False + assert cfg.host == "0.0.0.0" + assert cfg.port == 8000 + + +class TestDBConfig: + """Test DBConfig.""" + + def test_default_values(self) -> None: + """Test DBConfig default values.""" + cfg = DBConfig() + assert cfg.url is None + assert cfg.echo is False + assert cfg.host == "localhost" + assert cfg.port == 5432 + assert cfg.user == "postgres" + assert cfg.password == "postgres" + assert cfg.name == "blog" + + def test_postgres_url_validation(self) -> None: + """Test URL validation for postgres.""" + cfg = DBConfig(url="postgresql+asyncpg://user:pass@host/db") + assert cfg.url == "postgresql+asyncpg://user:pass@host/db" + + def test_sqlite_url_validation(self) -> None: + """Test URL validation for sqlite.""" + cfg = DBConfig(url="sqlite+aiosqlite:///./test.db") + assert cfg.url == "sqlite+aiosqlite:///./test.db" + + def test_invalid_url_validation(self) -> None: + """Test URL validation rejects invalid URLs.""" + with pytest.raises(ValueError, match="sqlite+.*postgresql+"): + DBConfig(url="mysql://invalid") + + +class TestKCConfig: + """Test KCConfig.""" + + def test_default_values(self) -> None: + """Test KCConfig default values.""" + cfg = KCConfig(client_secret="test-secret") + assert cfg.server_url == "http://localhost:8080" + assert cfg.realm == "blog" + assert cfg.client_id == "blog-api" + assert cfg.client_secret == "test-secret" + assert cfg.token_cache_ttl == 60 + + def test_is_configured_with_secret(self) -> None: + """Test is_configured returns True when secret is set.""" + cfg = KCConfig(client_secret="test-secret") + assert cfg.is_configured is True + + def test_is_configured_without_secret(self) -> None: + """Test is_configured returns False when secret is empty.""" + cfg = KCConfig(client_secret="") + assert cfg.is_configured is False + + +class TestSecurityConfig: + """Test SecurityConfig.""" + + def test_default_values(self) -> None: + """Test SecurityConfig default values.""" + cfg = SecurityConfig(secret_key="test-key") + assert cfg.secret_key == "test-key" + assert cfg.access_token_expire_minutes == 30 + + def test_is_configured_with_secret(self) -> None: + """Test is_configured returns True when secret is set.""" + cfg = SecurityConfig(secret_key="test-secret") + assert cfg.is_configured is True + + def test_is_configured_without_secret(self) -> None: + """Test is_configured returns False when secret is empty.""" + cfg = SecurityConfig(secret_key="") + assert cfg.is_configured is False + + +class TestEnvironment: + """Test Environment enum.""" + + def test_dev_value(self) -> None: + """Test DEV environment value.""" + assert Environment.DEV.value == "dev" + + def test_prod_value(self) -> None: + """Test PROD environment value.""" + assert Environment.PROD.value == "prod"