Compare commits
6 Commits
main
...
14adcaa3e6
| Author | SHA1 | Date | |
|---|---|---|---|
| 14adcaa3e6 | |||
| 1dbedf0f52 | |||
| 184b95969c | |||
| ddab62a883 | |||
| 87b094220d | |||
| b8334efa5a |
33
.env.example
Normal file
33
.env.example
Normal file
@@ -0,0 +1,33 @@
|
||||
# Environment mode: dev or prod
|
||||
ENVIRONMENT=dev
|
||||
|
||||
# App settings
|
||||
APP_NAME=Blog API
|
||||
APP_DEBUG=false
|
||||
APP_HOST=0.0.0.0
|
||||
APP_PORT=8000
|
||||
|
||||
# Database settings
|
||||
# For dev (SQLite): DB_URL=sqlite+aiosqlite:///./blog.db
|
||||
# For prod (PostgreSQL): DB_URL=postgresql+asyncpg://user:pass@host:port/db
|
||||
# Or use individual DB_* vars for prod (see below)
|
||||
DB_URL=
|
||||
DB_ECHO=false
|
||||
|
||||
# PostgreSQL-specific settings (used in prod when DB_URL is not set)
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_USER=postgres
|
||||
DB_PASSWORD=postgres
|
||||
DB_NAME=blog
|
||||
|
||||
# Security settings (REQUIRED)
|
||||
SECURITY_SECRET_KEY=your-secret-key-here-change-in-production
|
||||
SECURITY_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
|
||||
# Keycloak settings (REQUIRED for authentication)
|
||||
KC_SERVER_URL=http://localhost:8080
|
||||
KC_REALM=blog
|
||||
KC_CLIENT_ID=blog-api
|
||||
KC_CLIENT_SECRET=your-keycloak-client-secret-here
|
||||
KC_TOKEN_CACHE_TTL=60
|
||||
30
.github/PULL_REQUEST_TEMPLATE.md
vendored
30
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -1,30 +0,0 @@
|
||||
## Description
|
||||
<!-- Brief description of changes -->
|
||||
|
||||
## Type of Change
|
||||
<!-- Mark with [x] -->
|
||||
- [ ] 🚀 Feature (`feat`)
|
||||
- [ ] 🐛 Bug Fix (`fix`)
|
||||
- [ ] 📝 Documentation (`docs`)
|
||||
- [ ] ♻️ Refactor (`refactor`)
|
||||
- [ ] 🎨 Code Style (`style`)
|
||||
- [ ] ✅ Tests (`test`)
|
||||
- [ ] 🔧 Chore (`chore`)
|
||||
|
||||
## Checklist
|
||||
- [ ] Code follows project style guidelines (ruff, isort)
|
||||
- [ ] Tests added/updated (if applicable)
|
||||
- [ ] Documentation updated (if applicable)
|
||||
- [ ] Commit message follows convention (`type: description`)
|
||||
- [ ] Branch rebased to single commit before merge
|
||||
- [ ] No cache files in commit (`__pycache__`, `*.pyc`)
|
||||
|
||||
## Testing
|
||||
<!-- Describe how changes were tested -->
|
||||
|
||||
## Related Issues
|
||||
<!-- Link to issues if applicable -->
|
||||
Fixes #
|
||||
|
||||
## Screenshots (if applicable)
|
||||
<!-- Add screenshots for UI changes -->
|
||||
14
.gitignore
vendored
14
.gitignore
vendored
@@ -8,14 +8,6 @@ site/
|
||||
*.pyc
|
||||
*.pyo
|
||||
|
||||
# opencode skills (agent-only)
|
||||
.opencode/
|
||||
AGENTS.md
|
||||
.github/
|
||||
|
||||
# Scripts (except hooks)
|
||||
scripts/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
@@ -36,13 +28,9 @@ htmlcov/
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.example
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# uv cache
|
||||
.uv/
|
||||
|
||||
# Scripts cache
|
||||
scripts/__pycache__/
|
||||
|
||||
blog.db
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
when:
|
||||
event: [push, pull_request]
|
||||
|
||||
steps:
|
||||
- name: comment
|
||||
image: mcs94/gitea-comment
|
||||
settings:
|
||||
gitea_address: https://git.pyaqa.ru
|
||||
gitea_token:
|
||||
from_secret: gitea_token
|
||||
comment: >
|
||||
✅ Build ${CI_BUILD_EVENT} of `${CI_REPO_NAME}` has status `${CI_BUILD_STATUS}`.
|
||||
|
||||
📝 Commit by ${CI_COMMIT_AUTHOR} on `${CI_COMMIT_BRANCH}`:
|
||||
|
||||
`${CI_COMMIT_MESSAGE}`
|
||||
|
||||
🌐 ${CI_BUILD_LINK}
|
||||
|
||||
depends_on:
|
||||
- lint
|
||||
- type
|
||||
- test
|
||||
|
||||
@@ -4,11 +4,11 @@ when:
|
||||
|
||||
steps:
|
||||
- name: lint
|
||||
image: python:3.11
|
||||
image: python:3.13
|
||||
commands:
|
||||
- pip install uv
|
||||
- uv sync --no-dev --only-group lints
|
||||
- uv run black --check .
|
||||
- uv run ruff check .
|
||||
- uv run ruff format --check .
|
||||
- uv run isort --check-only .
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ when:
|
||||
|
||||
steps:
|
||||
- name: test
|
||||
image: python:3.11
|
||||
image: python:3.13
|
||||
commands:
|
||||
- pip install uv
|
||||
- uv sync --no-dev --group tests
|
||||
|
||||
@@ -4,7 +4,7 @@ when:
|
||||
|
||||
steps:
|
||||
- name: type
|
||||
image: python:3.11
|
||||
image: python:3.13
|
||||
commands:
|
||||
- pip install uv
|
||||
- uv sync --no-dev --only-group types
|
||||
|
||||
144
AGENTS.md
Normal file
144
AGENTS.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Blog AGENTS.md
|
||||
|
||||
## Stack
|
||||
- Python 3.13+, FastAPI, pydantic, uvicorn
|
||||
- SQLAlchemy 2.0 (async), aiosqlite
|
||||
- Package manager: `uv`
|
||||
- CI: Woodpecker (lint, test, type on push/PR to `dev`)
|
||||
|
||||
## Commands
|
||||
```bash
|
||||
uv sync --group dev # Install all dev dependencies
|
||||
uv run pytest # Run tests (coverage >= 70% required)
|
||||
uv run pytest tests/unit/ # Run single test directory
|
||||
uv run ruff check . --fix # Lint
|
||||
uv run ruff format # Format
|
||||
uv run isort . # Sort imports
|
||||
uv run mypy . # Type check (strict mode)
|
||||
uv run blog # Start dev server (port 8000)
|
||||
```
|
||||
|
||||
## Pre-commit order
|
||||
`ruff check --fix` → `ruff format` → `isort` → `mypy`
|
||||
|
||||
## DDD Architecture
|
||||
|
||||
### Layer Structure
|
||||
```
|
||||
app/
|
||||
├── domain/ # Domain Layer - business logic, no dependencies
|
||||
│ ├── entities/ # Domain entities (Post, User, etc.)
|
||||
│ │ ├── base.py # Base entity class
|
||||
│ │ └── post.py # Post entity with business logic
|
||||
│ ├── value_objects/ # Value objects (Title, Content, Slug)
|
||||
│ │ ├── base.py
|
||||
│ │ ├── title.py
|
||||
│ │ ├── content.py
|
||||
│ │ └── slug.py
|
||||
│ ├── repositories/ # Repository interfaces (abstract)
|
||||
│ │ ├── base.py
|
||||
│ │ └── post.py
|
||||
│ └── exceptions.py # Domain exceptions
|
||||
│
|
||||
├── application/ # Application Layer - use cases
|
||||
│ ├── dtos/ # Data Transfer Objects
|
||||
│ │ └── post.py
|
||||
│ ├── interfaces/ # Abstract interfaces (UoW)
|
||||
│ │ └── unit_of_work.py
|
||||
│ └── use_cases/ # Use cases (CQRS-like)
|
||||
│ ├── create_post.py
|
||||
│ ├── get_post.py
|
||||
│ ├── update_post.py
|
||||
│ ├── delete_post.py
|
||||
│ ├── list_posts.py
|
||||
│ └── publish_post.py
|
||||
│
|
||||
├── infrastructure/ # Infrastructure Layer - external concerns
|
||||
│ ├── config/ # Configuration
|
||||
│ │ └── settings.py
|
||||
│ ├── database/ # Database connection & ORM models
|
||||
│ │ ├── connection.py
|
||||
│ │ └── models.py
|
||||
│ ├── repositories/ # Repository implementations
|
||||
│ │ ├── post.py # SQLAlchemyPostRepository
|
||||
│ │ └── unit_of_work.py # SQLAlchemyUnitOfWork
|
||||
│ ├── di/ # Dependency Injection
|
||||
│ │ └── container.py
|
||||
│ └── middleware/ # Exception handlers
|
||||
│ └── error_handler.py
|
||||
│
|
||||
├── presentation/ # Presentation Layer - API
|
||||
│ ├── api/ # FastAPI routes
|
||||
│ │ ├── v1/ # API version 1
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ └── posts.py # Posts endpoints
|
||||
│ │ ├── deps.py # FastAPI dependencies
|
||||
│ │ └── __init__.py
|
||||
│ └── schemas/ # Pydantic schemas
|
||||
│ └── post.py
|
||||
│
|
||||
└── main.py # Application entry point
|
||||
|
||||
tests/
|
||||
├── unit/ # Unit tests (domain, use cases)
|
||||
│ ├── domain/ # Domain layer tests
|
||||
│ ├── application/ # Application layer tests
|
||||
│ └── infrastructure/ # Infrastructure tests
|
||||
├── integration/ # Integration tests (DB, repos)
|
||||
├── api/ # API endpoint tests
|
||||
└── e2e/ # End-to-end tests
|
||||
```
|
||||
|
||||
## Key Conventions
|
||||
|
||||
### Dependency Rule
|
||||
- Domain layer has **NO dependencies** on other layers
|
||||
- Application layer depends only on Domain
|
||||
- Infrastructure depends on Domain and Application
|
||||
- Presentation depends on all other layers
|
||||
|
||||
### Testing
|
||||
- **Unit tests**: Test domain logic without DB/external services
|
||||
- **Integration tests**: Test repository implementations with real DB
|
||||
- **API tests**: Test endpoints with mocked use cases
|
||||
- **E2E tests**: Full workflow testing
|
||||
|
||||
### Code Patterns
|
||||
- Use **dataclasses** for entities and value objects
|
||||
- Use **frozen dataclasses** for value objects (immutable)
|
||||
- Use **Unit of Work** pattern for transactions
|
||||
- Use **Repository** pattern for data access
|
||||
- Use **Dependency Injection** via FastAPI's Depends()
|
||||
|
||||
## DDD Concepts Used
|
||||
|
||||
### Entities
|
||||
- Have identity (UUID)
|
||||
- Mutable state
|
||||
- Business logic methods (publish, update_title, etc.)
|
||||
- Example: `Post` entity
|
||||
|
||||
### Value Objects
|
||||
- Immutable
|
||||
- Defined by attributes
|
||||
- Validated on creation
|
||||
- Examples: `Title`, `Content`, `Slug`
|
||||
|
||||
### Aggregates & Repositories
|
||||
- `Post` is an aggregate root
|
||||
- `PostRepository` interface in Domain
|
||||
- `SQLAlchemyPostRepository` implementation in Infrastructure
|
||||
|
||||
### Domain Events
|
||||
- Placeholder for future implementation
|
||||
- Can be added via event bus in application layer
|
||||
|
||||
## Configuration
|
||||
- `.env` file loaded by pydantic-settings
|
||||
- Settings available via `app.infrastructure.config.settings`
|
||||
|
||||
## Database
|
||||
- SQLAlchemy 2.0 with async support
|
||||
- SQLite by default (aiosqlite)
|
||||
- Tables auto-created on startup
|
||||
- Use `init_db()` and `close_db()` in lifespan
|
||||
@@ -1 +0,0 @@
|
||||
"""API module - HTTP routes and endpoints."""
|
||||
@@ -1 +0,0 @@
|
||||
"""API version 1 endpoints."""
|
||||
28
app/application/__init__.py
Normal file
28
app/application/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Application layer exports."""
|
||||
|
||||
from app.application.dtos import CreatePostDTO, PostResponseDTO, UpdatePostDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.application.use_cases import (
|
||||
CreatePostUseCase,
|
||||
DeletePostUseCase,
|
||||
GetPostUseCase,
|
||||
ListPostsUseCase,
|
||||
PublishPostUseCase,
|
||||
UpdatePostUseCase,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# DTOs
|
||||
"CreatePostDTO",
|
||||
"UpdatePostDTO",
|
||||
"PostResponseDTO",
|
||||
# Interfaces
|
||||
"TransactionManager",
|
||||
# Use Cases
|
||||
"CreatePostUseCase",
|
||||
"GetPostUseCase",
|
||||
"UpdatePostUseCase",
|
||||
"DeletePostUseCase",
|
||||
"ListPostsUseCase",
|
||||
"PublishPostUseCase",
|
||||
]
|
||||
5
app/application/dtos/__init__.py
Normal file
5
app/application/dtos/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Application DTOs."""
|
||||
|
||||
from app.application.dtos.post import CreatePostDTO, PostResponseDTO, UpdatePostDTO
|
||||
|
||||
__all__ = ["CreatePostDTO", "UpdatePostDTO", "PostResponseDTO"]
|
||||
39
app/application/dtos/post.py
Normal file
39
app/application/dtos/post.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""DTOs for post use cases."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CreatePostDTO:
|
||||
"""DTO for creating a post."""
|
||||
|
||||
title: str
|
||||
content: str
|
||||
author_id: str
|
||||
tags: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UpdatePostDTO:
|
||||
"""DTO for updating a post."""
|
||||
|
||||
title: str | None = None
|
||||
content: str | None = None
|
||||
tags: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostResponseDTO:
|
||||
"""DTO for post response."""
|
||||
|
||||
id: UUID
|
||||
title: str
|
||||
content: str
|
||||
slug: str
|
||||
author_id: str
|
||||
published: bool
|
||||
tags: list[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
5
app/application/interfaces/__init__.py
Normal file
5
app/application/interfaces/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Application interfaces."""
|
||||
|
||||
from app.application.interfaces.transaction_manager import TransactionManager
|
||||
|
||||
__all__ = ["TransactionManager"]
|
||||
17
app/application/interfaces/transaction_manager.py
Normal file
17
app/application/interfaces/transaction_manager.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Transaction Manager interface for managing database transactions."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class TransactionManager(ABC):
|
||||
"""Abstract Transaction Manager for controlling transaction boundaries."""
|
||||
|
||||
@abstractmethod
|
||||
async def commit(self) -> None:
|
||||
"""Commit the current transaction."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def rollback(self) -> None:
|
||||
"""Rollback the current transaction."""
|
||||
...
|
||||
17
app/application/use_cases/__init__.py
Normal file
17
app/application/use_cases/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Use cases."""
|
||||
|
||||
from app.application.use_cases.create_post import CreatePostUseCase
|
||||
from app.application.use_cases.delete_post import DeletePostUseCase
|
||||
from app.application.use_cases.get_post import GetPostUseCase
|
||||
from app.application.use_cases.list_posts import ListPostsUseCase
|
||||
from app.application.use_cases.publish_post import PublishPostUseCase
|
||||
from app.application.use_cases.update_post import UpdatePostUseCase
|
||||
|
||||
__all__ = [
|
||||
"CreatePostUseCase",
|
||||
"GetPostUseCase",
|
||||
"UpdatePostUseCase",
|
||||
"DeletePostUseCase",
|
||||
"ListPostsUseCase",
|
||||
"PublishPostUseCase",
|
||||
]
|
||||
60
app/application/use_cases/create_post.py
Normal file
60
app/application/use_cases/create_post.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Create post use case."""
|
||||
|
||||
from app.application.dtos.post import CreatePostDTO, PostResponseDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import AlreadyExistsException
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class CreatePostUseCase:
|
||||
"""Use case for creating a new blog post."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def execute(self, dto: CreatePostDTO) -> PostResponseDTO:
|
||||
"""Execute the use case."""
|
||||
# Generate slug from title
|
||||
from app.domain.value_objects import Slug
|
||||
|
||||
slug = Slug.from_title(dto.title)
|
||||
|
||||
# Check if slug already exists
|
||||
if await self._post_repo.slug_exists(slug.value):
|
||||
raise AlreadyExistsException(f"Post with slug '{slug.value}' already exists")
|
||||
|
||||
# Create domain entity
|
||||
post = Post.create(
|
||||
title_str=dto.title,
|
||||
content_str=dto.content,
|
||||
author_id=dto.author_id,
|
||||
tags=dto.tags or [],
|
||||
)
|
||||
|
||||
# Persist entity
|
||||
await self._post_repo.add(post)
|
||||
|
||||
# Commit transaction
|
||||
await self._tx_manager.commit()
|
||||
|
||||
return self._map_to_dto(post)
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
35
app/application/use_cases/delete_post.py
Normal file
35
app/application/use_cases/delete_post.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Delete post use case."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.exceptions import ForbiddenException, NotFoundException
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class DeletePostUseCase:
|
||||
"""Use case for deleting a blog post."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def execute(self, post_id: UUID, current_user_id: str) -> None:
|
||||
"""Execute the use case."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
|
||||
# Check authorization
|
||||
if post.author_id != current_user_id:
|
||||
raise ForbiddenException("You can only delete your own posts")
|
||||
|
||||
# Delete the post
|
||||
await self._post_repo.delete(post_id)
|
||||
|
||||
# Commit transaction
|
||||
await self._tx_manager.commit()
|
||||
49
app/application/use_cases/get_post.py
Normal file
49
app/application/use_cases/get_post.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Get post use case."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.application.dtos.post import PostResponseDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import NotFoundException
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class GetPostUseCase:
|
||||
"""Use case for retrieving a post by ID or slug."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def by_id(self, post_id: UUID) -> PostResponseDTO:
|
||||
"""Get post by ID."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
return self._map_to_dto(post)
|
||||
|
||||
async def by_slug(self, slug: str) -> PostResponseDTO:
|
||||
"""Get post by slug."""
|
||||
post = await self._post_repo.get_by_slug(slug)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with slug '{slug}' not found")
|
||||
return self._map_to_dto(post)
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
76
app/application/use_cases/list_posts.py
Normal file
76
app/application/use_cases/list_posts.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""List posts use case."""
|
||||
|
||||
from app.application.dtos.post import PostResponseDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class ListPostsUseCase:
|
||||
"""Use case for listing blog posts with filtering."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def all_posts(self) -> list[PostResponseDTO]:
|
||||
"""Get all posts."""
|
||||
posts = await self._post_repo.get_all()
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
async def published_posts(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[PostResponseDTO]:
|
||||
"""Get all published posts."""
|
||||
posts = await self._post_repo.get_published(limit=limit, offset=offset)
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
async def by_author(
|
||||
self,
|
||||
author_id: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[PostResponseDTO]:
|
||||
"""Get posts by author."""
|
||||
posts = await self._post_repo.get_by_author(author_id, limit=limit, offset=offset)
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
async def by_tag(
|
||||
self,
|
||||
tag: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[PostResponseDTO]:
|
||||
"""Get posts by tag."""
|
||||
posts = await self._post_repo.get_by_tag(tag, limit=limit, offset=offset)
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[PostResponseDTO]:
|
||||
"""Search posts."""
|
||||
posts = await self._post_repo.search(query, limit=limit, offset=offset)
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
65
app/application/use_cases/publish_post.py
Normal file
65
app/application/use_cases/publish_post.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Publish post use case."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.application.dtos.post import PostResponseDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import ForbiddenException, NotFoundException
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class PublishPostUseCase:
|
||||
"""Use case for publishing/unpublishing a blog post."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def publish(self, post_id: UUID, current_user_id: str) -> PostResponseDTO:
|
||||
"""Publish a post."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
|
||||
if post.author_id != current_user_id:
|
||||
raise ForbiddenException("You can only publish your own posts")
|
||||
|
||||
post.publish()
|
||||
await self._post_repo.update(post)
|
||||
await self._tx_manager.commit()
|
||||
|
||||
return self._map_to_dto(post)
|
||||
|
||||
async def unpublish(self, post_id: UUID, current_user_id: str) -> PostResponseDTO:
|
||||
"""Unpublish a post."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
|
||||
if post.author_id != current_user_id:
|
||||
raise ForbiddenException("You can only unpublish your own posts")
|
||||
|
||||
post.unpublish()
|
||||
await self._post_repo.update(post)
|
||||
await self._tx_manager.commit()
|
||||
|
||||
return self._map_to_dto(post)
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
73
app/application/use_cases/update_post.py
Normal file
73
app/application/use_cases/update_post.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Update post use case."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.application.dtos.post import PostResponseDTO, UpdatePostDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import ForbiddenException, NotFoundException
|
||||
from app.domain.repositories import PostRepository
|
||||
from app.domain.value_objects import Content, Title
|
||||
|
||||
|
||||
class UpdatePostUseCase:
|
||||
"""Use case for updating a blog post."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
post_id: UUID,
|
||||
dto: UpdatePostDTO,
|
||||
current_user_id: str,
|
||||
) -> PostResponseDTO:
|
||||
"""Execute the use case."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
|
||||
# Check authorization
|
||||
if post.author_id != current_user_id:
|
||||
raise ForbiddenException("You can only update your own posts")
|
||||
|
||||
# Update fields
|
||||
if dto.title is not None:
|
||||
post.update_title(Title(dto.title))
|
||||
|
||||
if dto.content is not None:
|
||||
post.update_content(Content(dto.content))
|
||||
|
||||
if dto.tags is not None:
|
||||
# Replace all tags
|
||||
for tag in post.tags[:]:
|
||||
post.remove_tag(tag)
|
||||
for tag in dto.tags:
|
||||
post.add_tag(tag)
|
||||
|
||||
# Persist changes
|
||||
await self._post_repo.update(post)
|
||||
|
||||
# Commit transaction
|
||||
await self._tx_manager.commit()
|
||||
|
||||
return self._map_to_dto(post)
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
"""Common utilities and shared components."""
|
||||
@@ -1,48 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from app.core.exceptions import AppException
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
status_code: int
|
||||
message: str
|
||||
details: dict[str, str] | None = None
|
||||
timestamp: str
|
||||
|
||||
|
||||
async def app_exception_handler(request: Request, exc: AppException) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"status_code": exc.status_code,
|
||||
"message": exc.message,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"status_code": exc.status_code,
|
||||
"message": str(exc.detail),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
app.add_exception_handler(
|
||||
AppException,
|
||||
app_exception_handler, # type: ignore[arg-type]
|
||||
)
|
||||
app.add_exception_handler(
|
||||
HTTPException,
|
||||
http_exception_handler, # type: ignore[arg-type]
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
"""Core module - shared functionality and configuration."""
|
||||
@@ -1,15 +0,0 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
app_name: str = "Blog API"
|
||||
debug: bool = False
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
database_url: str | None = None
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,25 +0,0 @@
|
||||
class AppException(Exception):
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class NotFoundError(AppException):
|
||||
def __init__(self, message: str = "Resource not found"):
|
||||
super().__init__(message, status_code=404)
|
||||
|
||||
|
||||
class ValidationError(AppException):
|
||||
def __init__(self, message: str = "Validation failed"):
|
||||
super().__init__(message, status_code=400)
|
||||
|
||||
|
||||
class UnauthorizedError(AppException):
|
||||
def __init__(self, message: str = "Unauthorized"):
|
||||
super().__init__(message, status_code=401)
|
||||
|
||||
|
||||
class ForbiddenError(AppException):
|
||||
def __init__(self, message: str = "Forbidden"):
|
||||
super().__init__(message, status_code=403)
|
||||
34
app/domain/__init__.py
Normal file
34
app/domain/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Domain layer exports."""
|
||||
|
||||
from app.domain.entities import BaseEntity, Post
|
||||
from app.domain.exceptions import (
|
||||
AlreadyExistsException,
|
||||
DomainException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
UnauthorizedException,
|
||||
ValidationException,
|
||||
)
|
||||
from app.domain.repositories import PostRepository, Repository
|
||||
from app.domain.value_objects import Content, Slug, Title, ValueObject
|
||||
|
||||
__all__ = [
|
||||
# Entities
|
||||
"BaseEntity",
|
||||
"Post",
|
||||
# Value Objects
|
||||
"ValueObject",
|
||||
"Title",
|
||||
"Content",
|
||||
"Slug",
|
||||
# Repositories
|
||||
"Repository",
|
||||
"PostRepository",
|
||||
# Exceptions
|
||||
"DomainException",
|
||||
"ValidationException",
|
||||
"NotFoundException",
|
||||
"AlreadyExistsException",
|
||||
"UnauthorizedException",
|
||||
"ForbiddenException",
|
||||
]
|
||||
6
app/domain/entities/__init__.py
Normal file
6
app/domain/entities/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Domain entities."""
|
||||
|
||||
from app.domain.entities.base import BaseEntity
|
||||
from app.domain.entities.post import Post
|
||||
|
||||
__all__ = ["BaseEntity", "Post"]
|
||||
33
app/domain/entities/base.py
Normal file
33
app/domain/entities/base.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Base entity for DDD domain layer."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class BaseEntity(ABC):
|
||||
"""Base class for all domain entities."""
|
||||
|
||||
id: UUID = field(default_factory=uuid4)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, BaseEntity):
|
||||
return NotImplemented
|
||||
return self.id == other.id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Update the updated_at timestamp."""
|
||||
self.updated_at = datetime.now(UTC)
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert entity to dictionary."""
|
||||
...
|
||||
88
app/domain/entities/post.py
Normal file
88
app/domain/entities/post.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Domain entity for Blog Post."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.domain.entities.base import BaseEntity
|
||||
from app.domain.value_objects.content import Content
|
||||
from app.domain.value_objects.slug import Slug
|
||||
from app.domain.value_objects.title import Title
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Post(BaseEntity):
|
||||
"""Blog post domain entity."""
|
||||
|
||||
title: Title
|
||||
content: Content
|
||||
slug: Slug
|
||||
author_id: str
|
||||
published: bool = False
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
def publish(self) -> None:
|
||||
"""Publish the post."""
|
||||
self.published = True
|
||||
self.touch()
|
||||
|
||||
def unpublish(self) -> None:
|
||||
"""Unpublish the post."""
|
||||
self.published = False
|
||||
self.touch()
|
||||
|
||||
def update_content(self, content: Content) -> None:
|
||||
"""Update post content."""
|
||||
self.content = content
|
||||
self.touch()
|
||||
|
||||
def update_title(self, title: Title) -> None:
|
||||
"""Update post title and regenerate slug."""
|
||||
self.title = title
|
||||
self.slug = Slug.from_title(title.value)
|
||||
self.touch()
|
||||
|
||||
def add_tag(self, tag: str) -> None:
|
||||
"""Add a tag to the post."""
|
||||
if tag not in self.tags:
|
||||
self.tags.append(tag)
|
||||
self.touch()
|
||||
|
||||
def remove_tag(self, tag: str) -> None:
|
||||
"""Remove a tag from the post."""
|
||||
if tag in self.tags:
|
||||
self.tags.remove(tag)
|
||||
self.touch()
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert entity to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"title": self.title.value,
|
||||
"content": self.content.value,
|
||||
"slug": self.slug.value,
|
||||
"author_id": self.author_id,
|
||||
"published": self.published,
|
||||
"tags": self.tags.copy(),
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
title_str: str,
|
||||
content_str: str,
|
||||
author_id: str,
|
||||
tags: list[str] | None = None,
|
||||
) -> "Post":
|
||||
"""Factory method to create a new post."""
|
||||
title = Title(title_str)
|
||||
content = Content(content_str)
|
||||
slug = Slug.from_title(title_str)
|
||||
return cls(
|
||||
title=title,
|
||||
content=content,
|
||||
slug=slug,
|
||||
author_id=author_id,
|
||||
tags=tags or [],
|
||||
)
|
||||
39
app/domain/exceptions.py
Normal file
39
app/domain/exceptions.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Domain exceptions."""
|
||||
|
||||
|
||||
class DomainException(Exception):
|
||||
"""Base exception for domain layer."""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ValidationException(DomainException):
|
||||
"""Raised when validation fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NotFoundException(DomainException):
|
||||
"""Raised when an entity is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AlreadyExistsException(DomainException):
|
||||
"""Raised when trying to create an entity that already exists."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnauthorizedException(DomainException):
|
||||
"""Raised when user is not authorized."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ForbiddenException(DomainException):
|
||||
"""Raised when access is forbidden."""
|
||||
|
||||
pass
|
||||
6
app/domain/repositories/__init__.py
Normal file
6
app/domain/repositories/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Repository interfaces."""
|
||||
|
||||
from app.domain.repositories.base import Repository
|
||||
from app.domain.repositories.post import PostRepository
|
||||
|
||||
__all__ = ["Repository", "PostRepository"]
|
||||
43
app/domain/repositories/base.py
Normal file
43
app/domain/repositories/base.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Base repository interface for DDD."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.domain.entities.base import BaseEntity
|
||||
|
||||
T = TypeVar("T", bound=BaseEntity)
|
||||
|
||||
|
||||
class Repository(ABC, Generic[T]):
|
||||
"""Generic repository interface."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_id(self, entity_id: UUID) -> T | None:
|
||||
"""Get entity by ID."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_all(self) -> list[T]:
|
||||
"""Get all entities."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, entity: T) -> None:
|
||||
"""Add new entity."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update(self, entity: T) -> None:
|
||||
"""Update existing entity."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, entity_id: UUID) -> None:
|
||||
"""Delete entity by ID."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, entity_id: UUID) -> bool:
|
||||
"""Check if entity exists."""
|
||||
...
|
||||
59
app/domain/repositories/post.py
Normal file
59
app/domain/repositories/post.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Post repository interface."""
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from app.domain.entities.post import Post
|
||||
from app.domain.repositories.base import Repository
|
||||
|
||||
|
||||
class PostRepository(Repository[Post]):
|
||||
"""Repository interface for Blog Posts."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_slug(self, slug: str) -> Post | None:
|
||||
"""Get post by slug."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_author(
|
||||
self,
|
||||
author_id: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[Post]:
|
||||
"""Get all posts by author."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_published(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[Post]:
|
||||
"""Get all published posts."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_tag(
|
||||
self,
|
||||
tag: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[Post]:
|
||||
"""Get posts by tag."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def slug_exists(self, slug: str) -> bool:
|
||||
"""Check if slug already exists."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[Post]:
|
||||
"""Search posts by query string."""
|
||||
...
|
||||
103
app/domain/roles.py
Normal file
103
app/domain/roles.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Role-based access control definitions."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from app.domain.exceptions import ForbiddenException
|
||||
|
||||
|
||||
class Role(str, Enum):
|
||||
"""User roles in the system."""
|
||||
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
GUEST = "guest"
|
||||
|
||||
|
||||
class Permission:
|
||||
"""Permission definitions."""
|
||||
|
||||
# Post permissions
|
||||
POST_CREATE = "post:create"
|
||||
POST_READ = "post:read"
|
||||
POST_READ_UNPUBLISHED = "post:read_unpublished"
|
||||
POST_UPDATE = "post:update"
|
||||
POST_DELETE = "post:delete"
|
||||
POST_PUBLISH = "post:publish"
|
||||
|
||||
|
||||
# Role-based permission mapping
|
||||
ROLE_PERMISSIONS: dict[Role, list[str]] = {
|
||||
Role.ADMIN: [
|
||||
Permission.POST_CREATE,
|
||||
Permission.POST_READ,
|
||||
Permission.POST_READ_UNPUBLISHED,
|
||||
Permission.POST_UPDATE,
|
||||
Permission.POST_DELETE,
|
||||
Permission.POST_PUBLISH,
|
||||
],
|
||||
Role.USER: [
|
||||
Permission.POST_CREATE,
|
||||
Permission.POST_READ,
|
||||
Permission.POST_UPDATE,
|
||||
Permission.POST_DELETE,
|
||||
Permission.POST_PUBLISH,
|
||||
],
|
||||
Role.GUEST: [
|
||||
Permission.POST_READ,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def has_permission(role: Role, permission: str) -> bool:
|
||||
"""Check if role has specific permission."""
|
||||
return permission in ROLE_PERMISSIONS.get(role, [])
|
||||
|
||||
|
||||
def require_permission(
|
||||
permission: str,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Decorator to require specific permission."""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
# Get token_info from kwargs
|
||||
token_info = kwargs.get("token_info")
|
||||
if not token_info:
|
||||
raise ForbiddenException("Authentication required")
|
||||
|
||||
# Determine role from token or default to guest
|
||||
roles = getattr(token_info, "roles", [])
|
||||
if Role.ADMIN.value in roles:
|
||||
role = Role.ADMIN
|
||||
elif Role.USER.value in roles:
|
||||
role = Role.USER
|
||||
else:
|
||||
role = Role.GUEST
|
||||
|
||||
if not has_permission(role, permission):
|
||||
raise ForbiddenException(
|
||||
f"Permission '{permission}' required for role '{role.value}'"
|
||||
)
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_effective_role(roles: list[str]) -> Role:
|
||||
"""Determine effective role from list of roles.
|
||||
|
||||
Priority: admin > user > guest
|
||||
"""
|
||||
if Role.ADMIN.value in roles:
|
||||
return Role.ADMIN
|
||||
elif Role.USER.value in roles:
|
||||
return Role.USER
|
||||
else:
|
||||
return Role.GUEST
|
||||
8
app/domain/value_objects/__init__.py
Normal file
8
app/domain/value_objects/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Value objects."""
|
||||
|
||||
from app.domain.value_objects.base import ValueObject
|
||||
from app.domain.value_objects.content import Content
|
||||
from app.domain.value_objects.slug import Slug
|
||||
from app.domain.value_objects.title import Title
|
||||
|
||||
__all__ = ["ValueObject", "Title", "Content", "Slug"]
|
||||
37
app/domain/value_objects/base.py
Normal file
37
app/domain/value_objects/base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Base value object for DDD domain layer."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ValueObject(ABC, Generic[T]):
|
||||
"""Base class for all value objects."""
|
||||
|
||||
value: T
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._validate()
|
||||
|
||||
@abstractmethod
|
||||
def _validate(self) -> None:
|
||||
"""Validate the value object. Raise ValueError if invalid."""
|
||||
...
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, ValueObject):
|
||||
return False
|
||||
return bool(self.value == other.value)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
def to_primitive(self) -> Any:
|
||||
"""Convert value object to primitive type."""
|
||||
return self.value
|
||||
23
app/domain/value_objects/content.py
Normal file
23
app/domain/value_objects/content.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Content value object."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.domain.value_objects.base import ValueObject
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Content(ValueObject[str]):
|
||||
"""Blog post content value object."""
|
||||
|
||||
MIN_LENGTH: int = 10
|
||||
MAX_LENGTH: int = 50000
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not isinstance(self.value, str):
|
||||
raise ValueError("Content must be a string")
|
||||
if not self.value.strip():
|
||||
raise ValueError("Content cannot be empty or whitespace")
|
||||
if len(self.value) < self.MIN_LENGTH:
|
||||
raise ValueError(f"Content must be at least {self.MIN_LENGTH} characters")
|
||||
if len(self.value) > self.MAX_LENGTH:
|
||||
raise ValueError(f"Content must be at most {self.MAX_LENGTH} characters")
|
||||
39
app/domain/value_objects/slug.py
Normal file
39
app/domain/value_objects/slug.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Slug value object for URL-friendly identifiers."""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.domain.value_objects.base import ValueObject
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Slug(ValueObject[str]):
|
||||
"""URL slug value object."""
|
||||
|
||||
MAX_LENGTH: int = 200
|
||||
SLUG_PATTERN: str = r"^[a-z0-9]+(?:-[a-z0-9]+)*$"
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not isinstance(self.value, str):
|
||||
raise ValueError("Slug must be a string")
|
||||
if len(self.value) > self.MAX_LENGTH:
|
||||
raise ValueError(f"Slug must be at most {self.MAX_LENGTH} characters")
|
||||
if not re.match(self.SLUG_PATTERN, self.value):
|
||||
raise ValueError("Slug must contain only lowercase letters, numbers, and hyphens")
|
||||
|
||||
@classmethod
|
||||
def from_title(cls, title: str) -> "Slug":
|
||||
"""Generate slug from title."""
|
||||
# Convert to lowercase, replace spaces with hyphens
|
||||
slug = title.lower().strip()
|
||||
# Keep only alphanumeric, spaces, and hyphens
|
||||
slug = re.sub(r"[^a-z0-9\s-]", "", slug)
|
||||
# Replace spaces and multiple hyphens with single hyphen
|
||||
slug = re.sub(r"[-\s]+", "-", slug)
|
||||
# Limit length and strip hyphens
|
||||
max_len = 200 # Same as MAX_LENGTH
|
||||
slug = slug[:max_len].strip("-")
|
||||
# Ensure we have at least one character
|
||||
if not slug:
|
||||
slug = "post"
|
||||
return cls(value=slug)
|
||||
23
app/domain/value_objects/title.py
Normal file
23
app/domain/value_objects/title.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Title value object."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.domain.value_objects.base import ValueObject
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Title(ValueObject[str]):
|
||||
"""Blog post title value object."""
|
||||
|
||||
MIN_LENGTH: int = 3
|
||||
MAX_LENGTH: int = 200
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not isinstance(self.value, str):
|
||||
raise ValueError("Title must be a string")
|
||||
if len(self.value) < self.MIN_LENGTH:
|
||||
raise ValueError(f"Title must be at least {self.MIN_LENGTH} characters")
|
||||
if len(self.value) > self.MAX_LENGTH:
|
||||
raise ValueError(f"Title must be at most {self.MAX_LENGTH} characters")
|
||||
if not self.value.strip():
|
||||
raise ValueError("Title cannot be empty or whitespace")
|
||||
35
app/infrastructure/__init__.py
Normal file
35
app/infrastructure/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Infrastructure layer exports."""
|
||||
|
||||
from app.infrastructure.config import Settings, settings
|
||||
from app.infrastructure.database import (
|
||||
AsyncSessionLocal,
|
||||
Base,
|
||||
PostORM,
|
||||
close_db,
|
||||
engine,
|
||||
get_session,
|
||||
init_db,
|
||||
)
|
||||
from app.infrastructure.di import create_container
|
||||
from app.infrastructure.middleware import register_exception_handlers
|
||||
from app.infrastructure.repositories import SQLAlchemyPostRepository
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"Settings",
|
||||
"settings",
|
||||
# Database
|
||||
"Base",
|
||||
"PostORM",
|
||||
"engine",
|
||||
"AsyncSessionLocal",
|
||||
"get_session",
|
||||
"init_db",
|
||||
"close_db",
|
||||
# Repositories
|
||||
"SQLAlchemyPostRepository",
|
||||
# DI
|
||||
"create_container",
|
||||
# Middleware
|
||||
"register_exception_handlers",
|
||||
]
|
||||
6
app/infrastructure/auth/__init__.py
Normal file
6
app/infrastructure/auth/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Authentication infrastructure package."""
|
||||
|
||||
from app.infrastructure.auth.client import KeycloakAuthClient
|
||||
from app.infrastructure.auth.models import KeycloakUser, TokenInfo
|
||||
|
||||
__all__ = ["KeycloakAuthClient", "KeycloakUser", "TokenInfo"]
|
||||
127
app/infrastructure/auth/client.py
Normal file
127
app/infrastructure/auth/client.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Keycloak authentication client."""
|
||||
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
from app.infrastructure.auth.models import KeycloakUser, TokenInfo
|
||||
from app.infrastructure.config.settings import Settings
|
||||
|
||||
|
||||
class KeycloakAuthClient:
|
||||
"""Client for Keycloak authentication operations."""
|
||||
|
||||
def __init__(self, settings: Settings) -> None:
|
||||
"""Initialize Keycloak client with settings."""
|
||||
self._settings = settings
|
||||
self._base_url = f"{settings.kc.server_url}/realms/{settings.kc.realm}"
|
||||
self._client_id = settings.kc.client_id
|
||||
self._client_secret = settings.kc.client_secret
|
||||
self._cache: dict[str, tuple[TokenInfo, float]] = {}
|
||||
self._cache_ttl = settings.kc.token_cache_ttl
|
||||
|
||||
def _get_introspection_url(self) -> str:
|
||||
"""Get token introspection endpoint URL."""
|
||||
return f"{self._base_url}/protocol/openid-connect/token/introspection"
|
||||
|
||||
def _get_userinfo_url(self) -> str:
|
||||
"""Get userinfo endpoint URL."""
|
||||
return f"{self._base_url}/protocol/openid-connect/userinfo"
|
||||
|
||||
def _get_cached_token(self, token: str) -> TokenInfo | None:
|
||||
"""Get cached token info if valid."""
|
||||
if token not in self._cache:
|
||||
return None
|
||||
|
||||
token_info, cached_at = self._cache[token]
|
||||
if time.time() - cached_at > self._cache_ttl:
|
||||
del self._cache[token]
|
||||
return None
|
||||
|
||||
return token_info
|
||||
|
||||
def _cache_token(self, token: str, token_info: TokenInfo) -> None:
|
||||
"""Cache token info."""
|
||||
self._cache[token] = (token_info, time.time())
|
||||
# Simple cleanup of old entries
|
||||
current_time = time.time()
|
||||
expired_keys = [
|
||||
k for k, (_, t) in self._cache.items() if current_time - t > self._cache_ttl
|
||||
]
|
||||
for k in expired_keys:
|
||||
del self._cache[k]
|
||||
|
||||
async def introspect_token(self, token: str) -> TokenInfo:
|
||||
"""Introspect access token using Keycloak."""
|
||||
# Check cache first
|
||||
cached = self._get_cached_token(token)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# Prepare introspection request
|
||||
data = {
|
||||
"token": token,
|
||||
"client_id": self._client_id,
|
||||
"client_secret": self._client_secret,
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self._get_introspection_url(),
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=10.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except httpx.HTTPError as e:
|
||||
return TokenInfo(active=False, raw_claims={"error": str(e)})
|
||||
|
||||
if not result.get("active", False):
|
||||
return TokenInfo(active=False, raw_claims=result)
|
||||
|
||||
# Extract roles from realm_access or resource_access
|
||||
roles: list[str] = []
|
||||
realm_access = result.get("realm_access", {})
|
||||
if isinstance(realm_access, dict):
|
||||
roles.extend(realm_access.get("roles", []))
|
||||
|
||||
token_info = TokenInfo(
|
||||
active=True,
|
||||
user_id=result.get("sub", ""),
|
||||
username=result.get("preferred_username", ""),
|
||||
email=result.get("email", ""),
|
||||
roles=roles,
|
||||
raw_claims=result,
|
||||
)
|
||||
|
||||
# Cache valid token
|
||||
self._cache_token(token, token_info)
|
||||
|
||||
return token_info
|
||||
|
||||
async def get_userinfo(self, token: str) -> KeycloakUser | None:
|
||||
"""Get user information from Keycloak using access token."""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self._get_userinfo_url(),
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
timeout=10.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.HTTPError:
|
||||
return None
|
||||
|
||||
return KeycloakUser(
|
||||
id=data.get("sub", ""),
|
||||
username=data.get("preferred_username", ""),
|
||||
email=data.get("email", ""),
|
||||
first_name=data.get("given_name", ""),
|
||||
last_name=data.get("family_name", ""),
|
||||
roles=data.get("realm_access", {}).get("roles", [])
|
||||
if isinstance(data.get("realm_access"), dict)
|
||||
else [],
|
||||
)
|
||||
34
app/infrastructure/auth/models.py
Normal file
34
app/infrastructure/auth/models.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Keycloak authentication models."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenInfo:
|
||||
"""Information about validated token from Keycloak."""
|
||||
|
||||
active: bool
|
||||
user_id: str = ""
|
||||
username: str = ""
|
||||
email: str = ""
|
||||
roles: list[str] = field(default_factory=list)
|
||||
raw_claims: dict[str, Any] = field(default_factory=dict, repr=False)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if token is valid and active."""
|
||||
return self.active and bool(self.user_id)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KeycloakUser:
|
||||
"""User information from Keycloak."""
|
||||
|
||||
id: str
|
||||
username: str
|
||||
email: str
|
||||
first_name: str = ""
|
||||
last_name: str = ""
|
||||
roles: list[str] = field(default_factory=list)
|
||||
is_active: bool = True
|
||||
21
app/infrastructure/config/__init__.py
Normal file
21
app/infrastructure/config/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Infrastructure configuration."""
|
||||
|
||||
from app.infrastructure.config.settings import (
|
||||
AppConfig,
|
||||
DBConfig,
|
||||
Environment,
|
||||
KCConfig,
|
||||
SecurityConfig,
|
||||
Settings,
|
||||
settings,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AppConfig",
|
||||
"DBConfig",
|
||||
"KCConfig",
|
||||
"SecurityConfig",
|
||||
"Environment",
|
||||
"Settings",
|
||||
"settings",
|
||||
]
|
||||
173
app/infrastructure/config/settings.py
Normal file
173
app/infrastructure/config/settings.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Application settings with composition pattern."""
|
||||
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
|
||||
from pydantic import Field, PostgresDsn, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Environment(str, Enum):
|
||||
"""Application environment modes."""
|
||||
|
||||
DEV = "dev"
|
||||
PROD = "prod"
|
||||
|
||||
|
||||
class AppConfig(BaseSettings):
|
||||
"""Application configuration."""
|
||||
|
||||
name: str = "Blog API"
|
||||
debug: bool = False
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="APP_",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
|
||||
class DBConfig(BaseSettings):
|
||||
"""Database configuration."""
|
||||
|
||||
# For dev: sqlite+aiosqlite:///./blog.db
|
||||
# For prod: postgresql+asyncpg://user:pass@host:port/db
|
||||
url: str | None = None
|
||||
echo: bool = False
|
||||
|
||||
# PostgreSQL-specific settings (used in prod)
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
user: str = "postgres"
|
||||
password: str = "postgres"
|
||||
name: str = "blog"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="DB_",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, v: str | None) -> str | None:
|
||||
"""Validate database URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not any(v.startswith(prefix) for prefix in ("sqlite+", "postgresql+")):
|
||||
raise ValueError("Database URL must start with 'sqlite+' or 'postgresql+'")
|
||||
return v
|
||||
|
||||
|
||||
class KCConfig(BaseSettings):
|
||||
"""Keycloak configuration."""
|
||||
|
||||
server_url: str = "http://localhost:8080"
|
||||
realm: str = "blog"
|
||||
client_id: str = "blog-api"
|
||||
client_secret: str = Field(
|
||||
default="",
|
||||
description="Keycloak client secret - must be set via env in production",
|
||||
)
|
||||
token_cache_ttl: int = 60 # seconds
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="KC_",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Keycloak is properly configured."""
|
||||
return bool(self.client_secret)
|
||||
|
||||
|
||||
class SecurityConfig(BaseSettings):
|
||||
"""Security configuration."""
|
||||
|
||||
secret_key: str = Field(
|
||||
default="", description="Secret key for JWT - must be set via env in production"
|
||||
)
|
||||
access_token_expire_minutes: int = 30
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="SECURITY_",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if security is properly configured."""
|
||||
return bool(self.secret_key)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application configuration settings with composition."""
|
||||
|
||||
# Environment mode
|
||||
environment: Environment = Environment.DEV
|
||||
|
||||
# Sub-configurations
|
||||
app: AppConfig = Field(default_factory=AppConfig)
|
||||
db: DBConfig = Field(default_factory=DBConfig)
|
||||
kc: KCConfig = Field(default_factory=KCConfig)
|
||||
security: SecurityConfig = Field(default_factory=SecurityConfig)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
env_nested_delimiter="__",
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: object) -> None:
|
||||
"""Validate settings after initialization."""
|
||||
if self.is_prod:
|
||||
if not self.security.is_configured:
|
||||
raise ValueError("SECURITY_SECRET_KEY must be set in production mode")
|
||||
if not self.kc.is_configured:
|
||||
raise ValueError("KC_CLIENT_SECRET must be set in production mode")
|
||||
|
||||
@cached_property
|
||||
def database_url(self) -> str:
|
||||
"""Get database URL based on environment.
|
||||
|
||||
- In dev: uses SQLite if no URL provided
|
||||
- In prod: uses PostgreSQL if no URL provided
|
||||
"""
|
||||
if self.db.url:
|
||||
return self.db.url
|
||||
|
||||
if self.environment == Environment.PROD:
|
||||
# Build PostgreSQL URL from components
|
||||
return str(
|
||||
PostgresDsn.build(
|
||||
scheme="postgresql+asyncpg",
|
||||
username=self.db.user,
|
||||
password=self.db.password,
|
||||
host=self.db.host,
|
||||
port=self.db.port,
|
||||
path=self.db.name,
|
||||
)
|
||||
)
|
||||
|
||||
# Default dev SQLite URL
|
||||
return "sqlite+aiosqlite:///./blog.db"
|
||||
|
||||
@property
|
||||
def is_dev(self) -> bool:
|
||||
"""Check if running in development mode."""
|
||||
return self.environment == Environment.DEV
|
||||
|
||||
@property
|
||||
def is_prod(self) -> bool:
|
||||
"""Check if running in production mode."""
|
||||
return self.environment == Environment.PROD
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = Settings()
|
||||
22
app/infrastructure/database/__init__.py
Normal file
22
app/infrastructure/database/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Database infrastructure."""
|
||||
|
||||
from app.infrastructure.database.connection import (
|
||||
AsyncSessionLocal,
|
||||
close_db,
|
||||
engine,
|
||||
get_session,
|
||||
get_session_context,
|
||||
init_db,
|
||||
)
|
||||
from app.infrastructure.database.models import Base, PostORM
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"PostORM",
|
||||
"engine",
|
||||
"AsyncSessionLocal",
|
||||
"get_session",
|
||||
"get_session_context",
|
||||
"init_db",
|
||||
"close_db",
|
||||
]
|
||||
70
app/infrastructure/database/connection.py
Normal file
70
app/infrastructure/database/connection.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Database connection and session management."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.infrastructure.config import settings
|
||||
|
||||
|
||||
# Convert SQLite URL to async format if needed
|
||||
def _get_database_url() -> str:
|
||||
url = settings.database_url
|
||||
if url.startswith("sqlite:///") and not url.startswith("sqlite+aiosqlite:///"):
|
||||
return url.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||||
return url
|
||||
|
||||
|
||||
# Create async engine
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
_get_database_url(),
|
||||
echo=settings.db.echo,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession]:
|
||||
"""Get database session."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session_context() -> AsyncGenerator[AsyncSession]:
|
||||
"""Get database session as context manager."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialize database tables."""
|
||||
from app.infrastructure.database.models import Base
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""Close database connections."""
|
||||
await engine.dispose()
|
||||
34
app/infrastructure/database/models.py
Normal file
34
app/infrastructure/database/models.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""SQLAlchemy ORM models."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import JSON, Boolean, DateTime, String, Text
|
||||
from sqlalchemy.orm import Mapped, declarative_base, mapped_column
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class PostORM(Base): # type: ignore[valid-type,misc]
|
||||
"""SQLAlchemy model for Blog Post."""
|
||||
|
||||
__tablename__ = "posts"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid4()))
|
||||
title: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
slug: Mapped[str] = mapped_column(String(200), nullable=False, unique=True, index=True)
|
||||
author_id: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||
published: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
|
||||
tags: Mapped[list[str]] = mapped_column(JSON, default=list)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
7
app/infrastructure/di/__init__.py
Normal file
7
app/infrastructure/di/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Dependency Injection using Dishka."""
|
||||
|
||||
from app.infrastructure.di.container import create_container
|
||||
|
||||
__all__ = [
|
||||
"create_container",
|
||||
]
|
||||
20
app/infrastructure/di/container.py
Normal file
20
app/infrastructure/di/container.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Dishka container setup."""
|
||||
|
||||
from dishka import AsyncContainer, make_async_container
|
||||
|
||||
from app.infrastructure.di.providers import (
|
||||
DatabaseProvider,
|
||||
RepositoryProvider,
|
||||
TransactionManagerProvider,
|
||||
UseCaseProvider,
|
||||
)
|
||||
|
||||
|
||||
def create_container() -> AsyncContainer:
|
||||
"""Create and configure Dishka container."""
|
||||
return make_async_container(
|
||||
DatabaseProvider(),
|
||||
RepositoryProvider(),
|
||||
TransactionManagerProvider(),
|
||||
UseCaseProvider(),
|
||||
)
|
||||
144
app/infrastructure/di/providers.py
Normal file
144
app/infrastructure/di/providers.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Dishka providers for dependency injection."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from dishka import Provider, Scope, provide
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
|
||||
from app.application import (
|
||||
CreatePostUseCase,
|
||||
DeletePostUseCase,
|
||||
GetPostUseCase,
|
||||
ListPostsUseCase,
|
||||
PublishPostUseCase,
|
||||
UpdatePostUseCase,
|
||||
)
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.repositories import PostRepository
|
||||
from app.infrastructure.auth import KeycloakAuthClient
|
||||
from app.infrastructure.config.settings import settings
|
||||
from app.infrastructure.database.connection import AsyncSessionLocal, engine
|
||||
from app.infrastructure.repositories.post import SQLAlchemyPostRepository
|
||||
|
||||
|
||||
class DatabaseProvider(Provider):
|
||||
"""Provider for database-related dependencies."""
|
||||
|
||||
@provide(scope=Scope.APP)
|
||||
def get_engine(self) -> AsyncEngine:
|
||||
"""Provide SQLAlchemy engine."""
|
||||
return engine
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_session(self) -> AsyncGenerator[AsyncSession]:
|
||||
"""Provide database session per request."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
class RepositoryProvider(Provider):
|
||||
"""Provider for repository implementations."""
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_post_repository(self, session: AsyncSession) -> PostRepository:
|
||||
"""Provide PostRepository implementation."""
|
||||
return SQLAlchemyPostRepository(session)
|
||||
|
||||
|
||||
class TransactionManagerProvider(Provider):
|
||||
"""Provider for transaction manager."""
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_transaction_manager(self, session: AsyncSession) -> TransactionManager:
|
||||
"""Provide TransactionManager implementation."""
|
||||
from app.infrastructure.di.transaction_manager import SessionTransactionManager
|
||||
|
||||
return SessionTransactionManager(session)
|
||||
|
||||
|
||||
class UseCaseProvider(Provider):
|
||||
"""Provider for use cases."""
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_create_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> CreatePostUseCase:
|
||||
"""Provide CreatePostUseCase."""
|
||||
return CreatePostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_get_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> GetPostUseCase:
|
||||
"""Provide GetPostUseCase."""
|
||||
return GetPostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_update_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> UpdatePostUseCase:
|
||||
"""Provide UpdatePostUseCase."""
|
||||
return UpdatePostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_delete_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> DeletePostUseCase:
|
||||
"""Provide DeletePostUseCase."""
|
||||
return DeletePostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_list_posts_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> ListPostsUseCase:
|
||||
"""Provide ListPostsUseCase."""
|
||||
return ListPostsUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_publish_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> PublishPostUseCase:
|
||||
"""Provide PublishPostUseCase."""
|
||||
return PublishPostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
|
||||
class KeycloakProvider(Provider):
|
||||
"""Provider for Keycloak authentication client."""
|
||||
|
||||
@provide(scope=Scope.APP)
|
||||
def get_keycloak_client(self) -> KeycloakAuthClient:
|
||||
"""Provide KeycloakAuthClient singleton."""
|
||||
return KeycloakAuthClient(settings)
|
||||
24
app/infrastructure/di/transaction_manager.py
Normal file
24
app/infrastructure/di/transaction_manager.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""SQLAlchemy implementation of Transaction Manager."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.application.interfaces import TransactionManager
|
||||
|
||||
|
||||
class SessionTransactionManager(TransactionManager):
|
||||
"""SQLAlchemy Session-based Transaction Manager."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._committed: bool = False
|
||||
|
||||
async def commit(self) -> None:
|
||||
"""Commit the current transaction."""
|
||||
if not self._committed:
|
||||
await self._session.commit()
|
||||
self._committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
"""Rollback the current transaction."""
|
||||
if not self._committed:
|
||||
await self._session.rollback()
|
||||
15
app/infrastructure/middleware/__init__.py
Normal file
15
app/infrastructure/middleware/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Infrastructure middleware."""
|
||||
|
||||
from app.infrastructure.middleware.error_handler import (
|
||||
domain_exception_handler,
|
||||
generic_exception_handler,
|
||||
http_exception_handler,
|
||||
register_exception_handlers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"domain_exception_handler",
|
||||
"http_exception_handler",
|
||||
"generic_exception_handler",
|
||||
"register_exception_handlers",
|
||||
]
|
||||
89
app/infrastructure/middleware/error_handler.py
Normal file
89
app/infrastructure/middleware/error_handler.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Exception handling middleware."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from app.domain.exceptions import (
|
||||
AlreadyExistsException,
|
||||
DomainException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
UnauthorizedException,
|
||||
ValidationException,
|
||||
)
|
||||
|
||||
|
||||
def get_status_code(exc: DomainException) -> int:
|
||||
"""Map domain exceptions to HTTP status codes."""
|
||||
match exc:
|
||||
case ValidationException():
|
||||
return 400
|
||||
case UnauthorizedException():
|
||||
return 401
|
||||
case ForbiddenException():
|
||||
return 403
|
||||
case NotFoundException():
|
||||
return 404
|
||||
case AlreadyExistsException():
|
||||
return 409
|
||||
case _:
|
||||
return 500
|
||||
|
||||
|
||||
async def domain_exception_handler(request: Request, exc: DomainException) -> JSONResponse:
|
||||
"""Handle domain exceptions."""
|
||||
status_code = get_status_code(exc)
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": exc.__class__.__name__,
|
||||
"message": exc.message,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"path": str(request.url.path),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
|
||||
"""Handle HTTP exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": "HTTPException",
|
||||
"message": str(exc.detail),
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"path": str(request.url.path),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Handle generic exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": "InternalServerError",
|
||||
"message": "An unexpected error occurred",
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"path": str(request.url.path),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register all exception handlers with FastAPI app."""
|
||||
if not isinstance(app, FastAPI):
|
||||
raise TypeError("app must be a FastAPI instance")
|
||||
|
||||
# Domain exceptions
|
||||
app.add_exception_handler(DomainException, domain_exception_handler) # type: ignore[arg-type]
|
||||
|
||||
# HTTP exceptions
|
||||
app.add_exception_handler(StarletteHTTPException, http_exception_handler) # type: ignore[arg-type]
|
||||
|
||||
# Generic exceptions (only in production)
|
||||
# In development, let FastAPI show detailed traceback
|
||||
# app.add_exception_handler(Exception, generic_exception_handler)
|
||||
5
app/infrastructure/repositories/__init__.py
Normal file
5
app/infrastructure/repositories/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Repository implementations."""
|
||||
|
||||
from app.infrastructure.repositories.post import SQLAlchemyPostRepository
|
||||
|
||||
__all__ = ["SQLAlchemyPostRepository"]
|
||||
170
app/infrastructure/repositories/post.py
Normal file
170
app/infrastructure/repositories/post.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""SQLAlchemy implementation of PostRepository."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.entities import Post
|
||||
from app.domain.repositories import PostRepository
|
||||
from app.domain.value_objects import Content, Slug, Title
|
||||
from app.infrastructure.database.models import PostORM
|
||||
|
||||
|
||||
class SQLAlchemyPostRepository(PostRepository):
|
||||
"""SQLAlchemy implementation of Post repository."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
def _to_domain(self, orm: PostORM) -> Post:
|
||||
"""Convert ORM model to domain entity."""
|
||||
return Post(
|
||||
id=UUID(orm.id),
|
||||
title=Title(orm.title),
|
||||
content=Content(orm.content),
|
||||
slug=Slug(orm.slug),
|
||||
author_id=orm.author_id,
|
||||
published=orm.published,
|
||||
tags=orm.tags or [],
|
||||
created_at=orm.created_at,
|
||||
updated_at=orm.updated_at,
|
||||
)
|
||||
|
||||
def _to_orm(self, post: Post) -> PostORM:
|
||||
"""Convert domain entity to ORM model."""
|
||||
return PostORM(
|
||||
id=str(post.id),
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags,
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
|
||||
async def get_by_id(self, entity_id: UUID) -> Post | None:
|
||||
"""Get post by ID."""
|
||||
result = await self._session.execute(select(PostORM).where(PostORM.id == str(entity_id)))
|
||||
orm = result.scalar_one_or_none()
|
||||
return self._to_domain(orm) if orm else None
|
||||
|
||||
async def get_all(self) -> list[Post]:
|
||||
"""Get all posts."""
|
||||
result = await self._session.execute(select(PostORM))
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
|
||||
async def add(self, entity: Post) -> None:
|
||||
"""Add new post."""
|
||||
orm = self._to_orm(entity)
|
||||
self._session.add(orm)
|
||||
# Commit делает TransactionManager
|
||||
|
||||
async def update(self, entity: Post) -> None:
|
||||
"""Update existing post."""
|
||||
result = await self._session.execute(select(PostORM).where(PostORM.id == str(entity.id)))
|
||||
orm = result.scalar_one()
|
||||
|
||||
orm.title = entity.title.value
|
||||
orm.content = entity.content.value
|
||||
orm.slug = entity.slug.value
|
||||
orm.published = entity.published
|
||||
orm.tags = entity.tags
|
||||
orm.updated_at = entity.updated_at
|
||||
|
||||
# Commit делает TransactionManager
|
||||
|
||||
async def delete(self, entity_id: UUID) -> None:
|
||||
"""Delete post by ID."""
|
||||
result = await self._session.execute(select(PostORM).where(PostORM.id == str(entity_id)))
|
||||
orm = result.scalar_one_or_none()
|
||||
if orm:
|
||||
await self._session.delete(orm)
|
||||
|
||||
async def exists(self, entity_id: UUID) -> bool:
|
||||
"""Check if post exists."""
|
||||
result = await self._session.execute(select(PostORM).where(PostORM.id == str(entity_id)))
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
async def get_by_slug(self, slug: str) -> Post | None:
|
||||
"""Get post by slug."""
|
||||
result = await self._session.execute(select(PostORM).where(PostORM.slug == slug))
|
||||
orm = result.scalar_one_or_none()
|
||||
return self._to_domain(orm) if orm else None
|
||||
|
||||
async def get_by_author(
|
||||
self,
|
||||
author_id: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[Post]:
|
||||
"""Get posts by author."""
|
||||
query = select(PostORM).where(PostORM.author_id == author_id)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
if offset is not None:
|
||||
query = query.offset(offset)
|
||||
result = await self._session.execute(query)
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
|
||||
async def get_published(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[Post]:
|
||||
"""Get published posts."""
|
||||
query = select(PostORM).where(PostORM.published.is_(True))
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
if offset is not None:
|
||||
query = query.offset(offset)
|
||||
result = await self._session.execute(query)
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
|
||||
async def get_by_tag(
|
||||
self,
|
||||
tag: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[Post]:
|
||||
"""Get posts by tag."""
|
||||
query = select(PostORM).where(PostORM.tags.contains([tag]))
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
if offset is not None:
|
||||
query = query.offset(offset)
|
||||
result = await self._session.execute(query)
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
|
||||
async def slug_exists(self, slug: str) -> bool:
|
||||
"""Check if slug exists."""
|
||||
result = await self._session.execute(select(PostORM).where(PostORM.slug == slug))
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[Post]:
|
||||
"""Search posts."""
|
||||
search_pattern = f"%{query}%"
|
||||
stmt = select(PostORM).where(
|
||||
or_(
|
||||
PostORM.title.ilike(search_pattern),
|
||||
PostORM.content.ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
if offset is not None:
|
||||
stmt = stmt.offset(offset)
|
||||
result = await self._session.execute(stmt)
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
76
app/main.py
76
app/main.py
@@ -1,22 +1,90 @@
|
||||
"""Application entry point with DDD architecture."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import uvicorn
|
||||
from dishka import make_async_container
|
||||
from dishka.integrations.fastapi import setup_dishka
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.infrastructure import close_db, init_db, register_exception_handlers, settings
|
||||
from app.infrastructure.di.providers import (
|
||||
DatabaseProvider,
|
||||
KeycloakProvider,
|
||||
RepositoryProvider,
|
||||
TransactionManagerProvider,
|
||||
UseCaseProvider,
|
||||
)
|
||||
from app.presentation import router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
|
||||
"""Application lifespan manager."""
|
||||
# Startup
|
||||
await init_db()
|
||||
yield
|
||||
# Shutdown
|
||||
await close_db()
|
||||
|
||||
|
||||
def app_factory() -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
"""Create and configure FastAPI application."""
|
||||
app = FastAPI(
|
||||
title=settings.app.name,
|
||||
debug=settings.app.debug,
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs" if settings.is_dev else None,
|
||||
redoc_url="/redoc" if settings.is_dev else None,
|
||||
)
|
||||
|
||||
# Setup Dishka DI container
|
||||
container = make_async_container(
|
||||
DatabaseProvider(),
|
||||
RepositoryProvider(),
|
||||
TransactionManagerProvider(),
|
||||
UseCaseProvider(),
|
||||
KeycloakProvider(),
|
||||
)
|
||||
setup_dishka(container, app)
|
||||
|
||||
# Register exception handlers
|
||||
register_exception_handlers(app)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include API routes
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
# Health check endpoint
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health_check() -> dict[str, str]:
|
||||
return {
|
||||
"status": "ok",
|
||||
"app": settings.app.name,
|
||||
"env": settings.environment.value,
|
||||
}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main() -> None:
|
||||
uvicorn.run(app_factory, factory=True, host="0.0.0.0", port=8000)
|
||||
"""Run the application."""
|
||||
uvicorn.run(
|
||||
app_factory,
|
||||
factory=True,
|
||||
host=settings.app.host,
|
||||
port=settings.app.port,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Feature modules - business logic organized by domain."""
|
||||
17
app/presentation/__init__.py
Normal file
17
app/presentation/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Presentation layer exports."""
|
||||
|
||||
from app.presentation.api import router
|
||||
from app.presentation.schemas import (
|
||||
PostCreateSchema,
|
||||
PostListResponseSchema,
|
||||
PostResponseSchema,
|
||||
PostUpdateSchema,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
"PostCreateSchema",
|
||||
"PostUpdateSchema",
|
||||
"PostResponseSchema",
|
||||
"PostListResponseSchema",
|
||||
]
|
||||
8
app/presentation/api/__init__.py
Normal file
8
app/presentation/api/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""API router configuration."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.presentation.api.v1 import router as v1_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(v1_router)
|
||||
131
app/presentation/api/deps.py
Normal file
131
app/presentation/api/deps.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""API dependencies using Dishka."""
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
from dishka.integrations.fastapi import FromDishka
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.application import (
|
||||
CreatePostUseCase,
|
||||
DeletePostUseCase,
|
||||
GetPostUseCase,
|
||||
ListPostsUseCase,
|
||||
PublishPostUseCase,
|
||||
UpdatePostUseCase,
|
||||
)
|
||||
from app.domain.exceptions import ForbiddenException, UnauthorizedException
|
||||
from app.domain.roles import Role, get_effective_role
|
||||
from app.infrastructure.auth import KeycloakAuthClient, TokenInfo
|
||||
|
||||
# Use case dependencies - injected via Dishka
|
||||
CreatePostDep = FromDishka[CreatePostUseCase]
|
||||
GetPostDep = FromDishka[GetPostUseCase]
|
||||
UpdatePostDep = FromDishka[UpdatePostUseCase]
|
||||
DeletePostDep = FromDishka[DeletePostUseCase]
|
||||
ListPostsDep = FromDishka[ListPostsUseCase]
|
||||
PublishPostDep = FromDishka[PublishPostUseCase]
|
||||
|
||||
# Security scheme
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def get_keycloak_client(request: Request) -> KeycloakAuthClient:
|
||||
"""Get Keycloak client from DI container via request state."""
|
||||
client: KeycloakAuthClient = request.state.dishka_container.get(KeycloakAuthClient)
|
||||
return client
|
||||
|
||||
|
||||
async def get_current_token_info(
|
||||
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
|
||||
request: Request,
|
||||
) -> TokenInfo:
|
||||
"""Validate token and return token info from Keycloak."""
|
||||
if not credentials:
|
||||
raise UnauthorizedException("Authentication required")
|
||||
|
||||
keycloak_client = get_keycloak_client(request)
|
||||
token = credentials.credentials
|
||||
token_info = await keycloak_client.introspect_token(token)
|
||||
|
||||
if not token_info.is_valid:
|
||||
raise UnauthorizedException("Invalid or expired token")
|
||||
|
||||
return token_info
|
||||
|
||||
|
||||
async def get_current_user_id(
|
||||
token_info: Annotated[TokenInfo, Depends(get_current_token_info)],
|
||||
) -> str:
|
||||
"""Get current user ID from validated token."""
|
||||
return token_info.user_id
|
||||
|
||||
|
||||
CurrentUserDep = Annotated[str, Depends(get_current_user_id)]
|
||||
TokenInfoDep = Annotated[TokenInfo, Depends(get_current_token_info)]
|
||||
|
||||
|
||||
# Optional auth - doesn't require authentication but provides user info if available
|
||||
async def get_optional_token_info(
|
||||
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
|
||||
request: Request,
|
||||
) -> TokenInfo | None:
|
||||
"""Get token info if valid token provided, otherwise None (guest)."""
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
keycloak_client = get_keycloak_client(request)
|
||||
token = credentials.credentials
|
||||
token_info = await keycloak_client.introspect_token(token)
|
||||
|
||||
if token_info.is_valid:
|
||||
return token_info
|
||||
|
||||
return None
|
||||
|
||||
|
||||
OptionalTokenInfoDep = Annotated[TokenInfo | None, Depends(get_optional_token_info)]
|
||||
|
||||
|
||||
async def get_optional_user_id(
|
||||
token_info: OptionalTokenInfoDep,
|
||||
) -> str | None:
|
||||
"""Get current user ID if token is valid, otherwise None."""
|
||||
if token_info:
|
||||
return token_info.user_id
|
||||
return None
|
||||
|
||||
|
||||
OptionalUserDep = Annotated[str | None, Depends(get_optional_user_id)]
|
||||
|
||||
|
||||
def get_current_role(token_info: OptionalTokenInfoDep) -> Role:
|
||||
"""Get effective role from token info.
|
||||
|
||||
Returns GUEST if no valid token provided.
|
||||
"""
|
||||
if token_info and token_info.roles:
|
||||
return get_effective_role(token_info.roles)
|
||||
return Role.GUEST
|
||||
|
||||
|
||||
CurrentRoleDep = Annotated[Role, Depends(get_current_role)]
|
||||
|
||||
|
||||
def require_roles(allowed_roles: list[Role]) -> Any:
|
||||
"""Create dependency that checks if user has one of the allowed roles."""
|
||||
|
||||
async def check_role(role: CurrentRoleDep) -> Role:
|
||||
if role not in allowed_roles:
|
||||
raise ForbiddenException(
|
||||
f"Access denied. Required roles: {[r.value for r in allowed_roles]}"
|
||||
)
|
||||
return role
|
||||
|
||||
return Depends(check_role)
|
||||
|
||||
|
||||
# Predefined role requirements
|
||||
RequireAdmin = require_roles([Role.ADMIN])
|
||||
RequireUser = require_roles([Role.USER, Role.ADMIN])
|
||||
RequireAny = require_roles([Role.GUEST, Role.USER, Role.ADMIN])
|
||||
8
app/presentation/api/v1/__init__.py
Normal file
8
app/presentation/api/v1/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""API v1 router."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.presentation.api.v1.posts import router as posts_router
|
||||
|
||||
router = APIRouter(prefix="/v1")
|
||||
router.include_router(posts_router)
|
||||
241
app/presentation/api/v1/posts.py
Normal file
241
app/presentation/api/v1/posts.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Posts API routes."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from dishka.integrations.fastapi import DishkaRoute
|
||||
from fastapi import APIRouter, status
|
||||
|
||||
from app.application.dtos import CreatePostDTO, UpdatePostDTO
|
||||
from app.domain.exceptions import ForbiddenException
|
||||
from app.domain.roles import Permission, has_permission
|
||||
from app.presentation.api.deps import (
|
||||
CreatePostDep,
|
||||
CurrentRoleDep,
|
||||
CurrentUserDep,
|
||||
DeletePostDep,
|
||||
GetPostDep,
|
||||
ListPostsDep,
|
||||
PublishPostDep,
|
||||
UpdatePostDep,
|
||||
)
|
||||
from app.presentation.schemas import (
|
||||
PostCreateSchema,
|
||||
PostListResponseSchema,
|
||||
PostResponseSchema,
|
||||
PostUpdateSchema,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/posts", tags=["posts"], route_class=DishkaRoute)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=PostResponseSchema,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new post",
|
||||
)
|
||||
async def create_post(
|
||||
schema: PostCreateSchema,
|
||||
use_case: CreatePostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Create a new blog post."""
|
||||
dto = CreatePostDTO(
|
||||
title=schema.title,
|
||||
content=schema.content,
|
||||
author_id=current_user_id,
|
||||
tags=schema.tags,
|
||||
)
|
||||
result = await use_case.execute(dto)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="List posts",
|
||||
)
|
||||
async def list_posts(
|
||||
use_case: ListPostsDep,
|
||||
role: CurrentRoleDep,
|
||||
include_unpublished: bool = False,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
) -> PostListResponseSchema:
|
||||
"""Get blog posts with optional filtering and pagination.
|
||||
|
||||
Args:
|
||||
include_unpublished: If True, returns all posts including drafts.
|
||||
Only admins can use this parameter.
|
||||
limit: Maximum number of posts to return (default: 10, max: 100).
|
||||
offset: Number of posts to skip (default: 0).
|
||||
|
||||
Raises:
|
||||
ForbiddenException: If non-admin tries to include unpublished posts.
|
||||
"""
|
||||
# Clamp limit to reasonable range
|
||||
limit = max(1, min(limit, 100))
|
||||
offset = max(0, offset)
|
||||
|
||||
# Check permissions for unpublished posts
|
||||
if include_unpublished:
|
||||
if not has_permission(role, Permission.POST_READ_UNPUBLISHED):
|
||||
raise ForbiddenException("Only admins can view unpublished posts")
|
||||
results = await use_case.all_posts()
|
||||
else:
|
||||
results = await use_case.published_posts(limit=limit, offset=offset)
|
||||
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/published",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="List published posts",
|
||||
)
|
||||
async def list_published_posts(
|
||||
use_case: ListPostsDep,
|
||||
) -> PostListResponseSchema:
|
||||
"""Get all published blog posts."""
|
||||
results = await use_case.published_posts()
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="Search posts",
|
||||
)
|
||||
async def search_posts(
|
||||
query: str,
|
||||
use_case: ListPostsDep,
|
||||
) -> PostListResponseSchema:
|
||||
"""Search posts by query."""
|
||||
results = await use_case.search(query)
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/by-tag/{tag}",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="Get posts by tag",
|
||||
)
|
||||
async def get_posts_by_tag(
|
||||
tag: str,
|
||||
use_case: ListPostsDep,
|
||||
) -> PostListResponseSchema:
|
||||
"""Get posts by tag."""
|
||||
results = await use_case.by_tag(tag)
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/by-author/{author_id}",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="Get posts by author",
|
||||
)
|
||||
async def get_posts_by_author(
|
||||
author_id: str,
|
||||
use_case: ListPostsDep,
|
||||
) -> PostListResponseSchema:
|
||||
"""Get posts by author."""
|
||||
results = await use_case.by_author(author_id)
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{post_id}",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Get post by ID",
|
||||
)
|
||||
async def get_post(
|
||||
post_id: UUID,
|
||||
use_case: GetPostDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Get a post by its ID."""
|
||||
result = await use_case.by_id(post_id)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/slug/{slug}",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Get post by slug",
|
||||
)
|
||||
async def get_post_by_slug(
|
||||
slug: str,
|
||||
use_case: GetPostDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Get a post by its slug."""
|
||||
result = await use_case.by_slug(slug)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{post_id}",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Update post",
|
||||
)
|
||||
async def update_post(
|
||||
post_id: UUID,
|
||||
schema: PostUpdateSchema,
|
||||
use_case: UpdatePostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Update a post."""
|
||||
dto = UpdatePostDTO(
|
||||
title=schema.title,
|
||||
content=schema.content,
|
||||
tags=schema.tags,
|
||||
)
|
||||
result = await use_case.execute(post_id, dto, current_user_id)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{post_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete post",
|
||||
)
|
||||
async def delete_post(
|
||||
post_id: UUID,
|
||||
use_case: DeletePostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> None:
|
||||
"""Delete a post."""
|
||||
await use_case.execute(post_id, current_user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{post_id}/publish",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Publish post",
|
||||
)
|
||||
async def publish_post(
|
||||
post_id: UUID,
|
||||
use_case: PublishPostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Publish a post."""
|
||||
result = await use_case.publish(post_id, current_user_id)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{post_id}/unpublish",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Unpublish post",
|
||||
)
|
||||
async def unpublish_post(
|
||||
post_id: UUID,
|
||||
use_case: PublishPostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Unpublish a post."""
|
||||
result = await use_case.unpublish(post_id, current_user_id)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
21
app/presentation/schemas/__init__.py
Normal file
21
app/presentation/schemas/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Presentation schemas."""
|
||||
|
||||
from app.presentation.schemas.post import (
|
||||
PostBaseSchema,
|
||||
PostCreateSchema,
|
||||
PostListResponseSchema,
|
||||
PostPublishSchema,
|
||||
PostResponseSchema,
|
||||
PostSearchSchema,
|
||||
PostUpdateSchema,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PostBaseSchema",
|
||||
"PostCreateSchema",
|
||||
"PostUpdateSchema",
|
||||
"PostResponseSchema",
|
||||
"PostListResponseSchema",
|
||||
"PostSearchSchema",
|
||||
"PostPublishSchema",
|
||||
]
|
||||
66
app/presentation/schemas/post.py
Normal file
66
app/presentation/schemas/post.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""API schemas for posts."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PostBaseSchema(BaseModel):
|
||||
"""Base schema for posts."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
title: str = Field(..., min_length=3, max_length=200)
|
||||
content: str = Field(..., min_length=10, max_length=50000)
|
||||
|
||||
|
||||
class PostCreateSchema(PostBaseSchema):
|
||||
"""Schema for creating a post."""
|
||||
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PostUpdateSchema(BaseModel):
|
||||
"""Schema for updating a post."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
title: str | None = Field(None, min_length=3, max_length=200)
|
||||
content: str | None = Field(None, min_length=10, max_length=50000)
|
||||
tags: list[str] | None = None
|
||||
|
||||
|
||||
class PostResponseSchema(BaseModel):
|
||||
"""Schema for post response."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
title: str
|
||||
content: str
|
||||
slug: str
|
||||
author_id: str
|
||||
published: bool
|
||||
tags: list[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class PostListResponseSchema(BaseModel):
|
||||
"""Schema for list of posts response."""
|
||||
|
||||
items: list[PostResponseSchema]
|
||||
total: int
|
||||
|
||||
|
||||
class PostSearchSchema(BaseModel):
|
||||
"""Schema for searching posts."""
|
||||
|
||||
query: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
class PostPublishSchema(BaseModel):
|
||||
"""Schema for publishing/unpublishing a post."""
|
||||
|
||||
published: bool
|
||||
@@ -1,17 +0,0 @@
|
||||
# API Endpoints
|
||||
|
||||
## Overview
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| GET | `/` | Health check |
|
||||
|
||||
## Health Check
|
||||
|
||||
```http
|
||||
GET /
|
||||
```
|
||||
|
||||
**Response:** `200 OK`
|
||||
|
||||
Returns application status.
|
||||
@@ -1,13 +0,0 @@
|
||||
# API Reference
|
||||
|
||||
This section contains auto-generated API documentation from source code docstrings.
|
||||
|
||||
## Modules
|
||||
|
||||
::: app.main
|
||||
handler: python
|
||||
options:
|
||||
members:
|
||||
- lifespan
|
||||
- app_factory
|
||||
- main
|
||||
@@ -1,43 +0,0 @@
|
||||
# Code Style
|
||||
|
||||
## Linting & Formatting
|
||||
|
||||
```bash
|
||||
# Run all linters
|
||||
uv run ruff check . --fix
|
||||
uv run ruff format .
|
||||
uv run isort . --profile black --filter-files
|
||||
|
||||
# Type checking
|
||||
uv run mypy .
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
```bash
|
||||
# Check docstring style
|
||||
uv run pydocstyle app/
|
||||
|
||||
# Check documentation coverage
|
||||
uv run interrogate app/ -v
|
||||
|
||||
# Build documentation
|
||||
uv run mkdocs build
|
||||
|
||||
# Serve documentation locally
|
||||
uv run mkdocs serve
|
||||
```
|
||||
|
||||
## Pre-commit Hooks
|
||||
|
||||
This project uses pre-commit hooks to ensure code quality:
|
||||
|
||||
- ruff check
|
||||
- ruff format
|
||||
- isort
|
||||
- mypy
|
||||
|
||||
Install hooks:
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
@@ -1,31 +0,0 @@
|
||||
# Setup Guide
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.13+
|
||||
- uv package manager
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/pyaqa/blog.git
|
||||
cd blog
|
||||
|
||||
# Install dependencies
|
||||
uv sync
|
||||
|
||||
# Run tests
|
||||
uv run pytest
|
||||
|
||||
# Start development server
|
||||
uv run python -m app.main
|
||||
```
|
||||
|
||||
## Development Server
|
||||
|
||||
The server runs on `http://0.0.0.0:8000` by default.
|
||||
|
||||
Access interactive API docs at:
|
||||
- Swagger UI: `http://localhost:8000/docs`
|
||||
- ReDoc: `http://localhost:8000/redoc`
|
||||
@@ -1,28 +0,0 @@
|
||||
# Blog API
|
||||
|
||||
Welcome to the Blog API documentation.
|
||||
|
||||
## Features
|
||||
|
||||
- FastAPI-based REST API
|
||||
- Python 3.13+
|
||||
- Async support
|
||||
- Type hints throughout
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Install dependencies
|
||||
uv sync
|
||||
|
||||
# Run development server
|
||||
uv run python -m app.main
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
See [API Reference](api/endpoints.md) for detailed endpoint documentation.
|
||||
|
||||
## Development
|
||||
|
||||
See [Development Guide](development/setup.md) for setup instructions.
|
||||
50
mkdocs.yml
50
mkdocs.yml
@@ -1,50 +0,0 @@
|
||||
site_name: Blog API Documentation
|
||||
site_description: FastAPI Blog Application Documentation
|
||||
site_author: Blog Team
|
||||
repo_url: https://github.com/pyaqa/blog
|
||||
|
||||
theme:
|
||||
name: mkdocs
|
||||
palette:
|
||||
- scheme: default
|
||||
primary: indigo
|
||||
accent: indigo
|
||||
toggle:
|
||||
icon: material/brightness-7
|
||||
name: Switch to dark mode
|
||||
- scheme: slate
|
||||
primary: indigo
|
||||
accent: indigo
|
||||
toggle:
|
||||
icon: material/brightness-4
|
||||
name: Switch to light mode
|
||||
|
||||
plugins:
|
||||
- search
|
||||
- mkdocstrings:
|
||||
handlers:
|
||||
python:
|
||||
options:
|
||||
docstring_style: google
|
||||
show_root_heading: true
|
||||
show_source: true
|
||||
show_bases: true
|
||||
|
||||
markdown_extensions:
|
||||
- pymdownx.highlight:
|
||||
anchor_linenums: true
|
||||
- pymdownx.inlinehilite
|
||||
- pymdownx.snippets
|
||||
- pymdownx.superfences
|
||||
- admonition
|
||||
- pymdownx.details
|
||||
- tables
|
||||
|
||||
nav:
|
||||
- Home: index.md
|
||||
- API Reference:
|
||||
- Overview: api/index.md
|
||||
- Endpoints: api/endpoints.md
|
||||
- Development:
|
||||
- Setup: development/setup.md
|
||||
- Code Style: development/codestyle.md
|
||||
@@ -9,8 +9,20 @@ dependencies = [
|
||||
"pydantic>=2.13.2",
|
||||
"pydantic-settings>=2.14.0",
|
||||
"uvicorn>=0.44.0",
|
||||
"sqlalchemy>=2.0.0",
|
||||
"aiosqlite>=0.21.0",
|
||||
"asyncpg>=0.30.0",
|
||||
"dishka>=1.5.0",
|
||||
"httpx>=0.28.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["app"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
{include-group = "lints"},
|
||||
@@ -35,10 +47,13 @@ types = [
|
||||
"mypy>=1.20.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
blog = "app.main:main"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
addopts = "--cov=src --cov-report=term"
|
||||
addopts = "--cov=app --cov-report=term-missing --cov-report=html"
|
||||
pythonpath = "."
|
||||
testpaths = "tests"
|
||||
xfail_strict = true
|
||||
@@ -47,6 +62,14 @@ xfail_strict = true
|
||||
strict = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py313"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "W", "B", "C4", "SIM"]
|
||||
ignore = ["E501"]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
filter_files = true
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
# Development Scripts
|
||||
|
||||
## clean_cache.sh
|
||||
|
||||
Clean all Python cache files:
|
||||
|
||||
```bash
|
||||
bash scripts/clean_cache.sh
|
||||
```
|
||||
|
||||
Removes:
|
||||
- `__pycache__/` directories
|
||||
- `*.pyc`, `*.pyo` files
|
||||
- `.pytest_cache/`
|
||||
- `.mypy_cache/`
|
||||
- `.ruff_cache/`
|
||||
- `.coverage`
|
||||
- `htmlcov/`
|
||||
|
||||
## update_readme.py
|
||||
|
||||
Update README.md with latest project information:
|
||||
|
||||
```bash
|
||||
uv run python scripts/update_readme.py
|
||||
```
|
||||
|
||||
Check if update needed (for CI):
|
||||
|
||||
```bash
|
||||
uv run python scripts/update_readme.py --check
|
||||
```
|
||||
|
||||
## post-commit
|
||||
|
||||
Git hook for auto-updating README after commits.
|
||||
|
||||
Install:
|
||||
|
||||
```bash
|
||||
cp scripts/post-commit .git/hooks/post-commit
|
||||
chmod +x .git/hooks/post-commit
|
||||
```
|
||||
|
||||
## Disable Python Cache During Development
|
||||
|
||||
Set environment variables before running Python:
|
||||
|
||||
```bash
|
||||
# Option 1: Export variables
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
export UV_NO_CACHE=1
|
||||
|
||||
# Option 2: Use with command
|
||||
PYTHONDONTWRITEBYTECODE=1 uv run python -m app.main
|
||||
|
||||
# Option 3: Add to .env (not committed)
|
||||
echo "PYTHONDONTWRITEBYTECODE=1" >> .env
|
||||
```
|
||||
|
||||
Or use the clean script periodically:
|
||||
|
||||
```bash
|
||||
bash scripts/clean_cache.sh
|
||||
```
|
||||
@@ -1,22 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
echo "Cleaning Python cache files..."
|
||||
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
|
||||
find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
|
||||
find . -type f -name "*.pyo" -delete 2>/dev/null || true
|
||||
|
||||
rm -rf .pytest_cache/ 2>/dev/null || true
|
||||
|
||||
rm -rf .mypy_cache/ 2>/dev/null || true
|
||||
|
||||
rm -rf .ruff_cache/ 2>/dev/null || true
|
||||
|
||||
rm -f .coverage 2>/dev/null || true
|
||||
rm -rf htmlcov/ 2>/dev/null || true
|
||||
|
||||
echo "✓ Cache cleaned"
|
||||
@@ -1,64 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
COMMIT_MSG_FILE="$1"
|
||||
if [ -z "$COMMIT_MSG_FILE" ]; then
|
||||
echo "Checking for cache files in staged changes..."
|
||||
|
||||
CACHE_FILES=$(git diff --cached --name-only | grep -E "__pycache__|\.pyc$|\.pyo$" || true)
|
||||
|
||||
if [ -n "$CACHE_FILES" ]; then
|
||||
echo "❌ Attempting to commit Python cache files!"
|
||||
echo ""
|
||||
echo "Files:"
|
||||
echo "$CACHE_FILES"
|
||||
echo ""
|
||||
echo "Run: bash scripts/clean_cache.sh"
|
||||
echo "Or: git reset HEAD <files>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ No cache files in staged changes"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
COMMIT_MSG=$(cat "$COMMIT_MSG_FILE")
|
||||
|
||||
if ! echo "$COMMIT_MSG" | grep -qE "^(feat|fix|docs|style|refactor|test|chore): [a-z].{0,49}$"; then
|
||||
echo "❌ Invalid commit message format!"
|
||||
echo ""
|
||||
echo "Current message: $COMMIT_MSG"
|
||||
echo ""
|
||||
echo "Expected format: <type>: <short description>"
|
||||
echo ""
|
||||
echo "Types:"
|
||||
echo " feat - New feature"
|
||||
echo " fix - Bug fix"
|
||||
echo " docs - Documentation"
|
||||
echo " style - Code style"
|
||||
echo " refactor - Refactoring"
|
||||
echo " test - Tests"
|
||||
echo " chore - Maintenance"
|
||||
echo ""
|
||||
echo "Rules:"
|
||||
echo " - Max 50 characters"
|
||||
echo " - Lowercase after type"
|
||||
echo " - Imperative mood (add, not added)"
|
||||
echo " - No period at end"
|
||||
echo ""
|
||||
echo "Good examples:"
|
||||
echo " feat: add user authentication"
|
||||
echo " fix: resolve database timeout"
|
||||
echo " docs: update API docs"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if echo "$COMMIT_MSG" | grep -qE "\.$"; then
|
||||
echo "❌ Commit message should not end with a period"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ Commit message valid: $COMMIT_MSG"
|
||||
exit 0
|
||||
@@ -1,18 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Post-commit hook: Update README.md automatically
|
||||
|
||||
set -e
|
||||
|
||||
echo "Updating README.md..."
|
||||
|
||||
# Run README update script
|
||||
uv run python scripts/update_readme.py
|
||||
|
||||
# Check if README changed
|
||||
if ! git diff --quiet README.md; then
|
||||
echo "✓ README.md was updated"
|
||||
echo " Review changes and commit if needed:"
|
||||
echo " git add README.md && git commit -m 'docs: update README [skip ci]'"
|
||||
else
|
||||
echo "✓ README.md is up to date"
|
||||
fi
|
||||
@@ -1,358 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
import tomllib
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
return Path(__file__).parent.parent
|
||||
|
||||
|
||||
def get_pyproject() -> dict[str, Any]:
|
||||
root = get_project_root()
|
||||
with open(root / "pyproject.toml", "rb") as f:
|
||||
return tomllib.load(f)
|
||||
|
||||
|
||||
def get_latest_commits(count: int = 10) -> list[dict[str, str]]:
|
||||
result = subprocess.run(
|
||||
["git", "log", "--format=%H|%s|%ad|%an", "--date=short", f"-n{count}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=get_project_root(),
|
||||
)
|
||||
|
||||
commits = []
|
||||
for line in result.stdout.strip().split("\n"):
|
||||
if line:
|
||||
parts = line.split("|")
|
||||
if len(parts) >= 4:
|
||||
commits.append(
|
||||
{
|
||||
"hash": parts[0][:7],
|
||||
"message": parts[1],
|
||||
"date": parts[2],
|
||||
"author": parts[3],
|
||||
}
|
||||
)
|
||||
return commits
|
||||
|
||||
|
||||
def get_last_tag() -> str | None:
|
||||
result = subprocess.run(
|
||||
["git", "describe", "--tags", "--abbrev=0"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=get_project_root(),
|
||||
)
|
||||
return result.stdout.strip() if result.returncode == 0 else None
|
||||
|
||||
|
||||
def get_ignored_files() -> set[str]:
|
||||
gitignore_path = get_project_root() / ".gitignore"
|
||||
ignored = set()
|
||||
if gitignore_path.exists():
|
||||
for line in gitignore_path.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
ignored.add(line.rstrip("/"))
|
||||
return ignored
|
||||
|
||||
|
||||
def commit_has_tracked_changes(commit_hash: str) -> bool:
|
||||
result = subprocess.run(
|
||||
["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=get_project_root(),
|
||||
)
|
||||
if not result.stdout.strip():
|
||||
return False
|
||||
|
||||
ignored = get_ignored_files()
|
||||
for file_path in result.stdout.strip().split("\n"):
|
||||
if not file_path:
|
||||
continue
|
||||
parts = file_path.split("/")
|
||||
is_ignored = False
|
||||
for i in range(len(parts)):
|
||||
path_part = "/".join(parts[: i + 1])
|
||||
for pattern in ignored:
|
||||
if pattern.endswith("*"):
|
||||
if path_part.startswith(pattern[:-1]):
|
||||
is_ignored = True
|
||||
break
|
||||
elif path_part == pattern or parts[-1] == pattern:
|
||||
is_ignored = True
|
||||
break
|
||||
if is_ignored:
|
||||
break
|
||||
if not is_ignored:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def commit_has_skip_ci_message(commit_hash: str) -> bool:
|
||||
result = subprocess.run(
|
||||
["git", "log", "-1", "--format=%s", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=get_project_root(),
|
||||
)
|
||||
msg = result.stdout.strip().lower()
|
||||
return "[skip ci]" in msg or "[skip-ci]" in msg or "[ci skip]" in msg
|
||||
|
||||
|
||||
def commit_only_changes_readme(commit_hash: str) -> bool:
|
||||
result = subprocess.run(
|
||||
["git", "diff-tree", "--no-commit-id", "--name-only", "-r", commit_hash],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=get_project_root(),
|
||||
)
|
||||
files = [f.strip() for f in result.stdout.strip().split("\n") if f.strip()]
|
||||
return files == ["README.md"]
|
||||
|
||||
|
||||
def get_commits_since_tag(tag: str | None) -> list[dict[str, str]]:
|
||||
if tag:
|
||||
result = subprocess.run(
|
||||
["git", "log", "--format=%H|%s|%ad|%an", "--date=short", f"{tag}..HEAD"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=get_project_root(),
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["git", "log", "--format=%H|%s|%ad|%an", "--date=short", "-n10"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=get_project_root(),
|
||||
)
|
||||
|
||||
commits = []
|
||||
for line in result.stdout.strip().split("\n"):
|
||||
if line:
|
||||
parts = line.split("|")
|
||||
if len(parts) >= 4:
|
||||
commit_hash = parts[0]
|
||||
if commit_has_skip_ci_message(commit_hash):
|
||||
continue
|
||||
if commit_only_changes_readme(commit_hash):
|
||||
continue
|
||||
if not commit_has_tracked_changes(commit_hash):
|
||||
continue
|
||||
commits.append(
|
||||
{
|
||||
"hash": commit_hash[:7],
|
||||
"message": parts[1],
|
||||
"date": parts[2],
|
||||
"author": parts[3],
|
||||
}
|
||||
)
|
||||
return commits
|
||||
|
||||
|
||||
def categorize_commits(commits: list[dict[str, str]]) -> dict[str, list[str]]:
|
||||
categories: dict[str, list[str]] = {
|
||||
"Added": [],
|
||||
"Changed": [],
|
||||
"Fixed": [],
|
||||
"Removed": [],
|
||||
"Other": [],
|
||||
}
|
||||
|
||||
for commit in commits:
|
||||
msg = commit["message"].lower()
|
||||
entry = f"- {commit['message']} ({commit['hash']})"
|
||||
|
||||
if msg.startswith("feat") or "add" in msg:
|
||||
categories["Added"].append(entry)
|
||||
elif msg.startswith("fix") or "fix" in msg:
|
||||
categories["Fixed"].append(entry)
|
||||
elif msg.startswith("change") or "update" in msg:
|
||||
categories["Changed"].append(entry)
|
||||
elif msg.startswith("remove") or "delete" in msg:
|
||||
categories["Removed"].append(entry)
|
||||
else:
|
||||
categories["Other"].append(entry)
|
||||
|
||||
return categories
|
||||
|
||||
|
||||
def format_changelog(commits: list[dict[str, str]], version: str = "v0.1.0") -> str:
|
||||
categorized = categorize_commits(commits)
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
lines = [f"### [{version}] - {today}"]
|
||||
|
||||
for section, entries in categorized.items():
|
||||
if entries:
|
||||
lines.append(f"\n#### {section}")
|
||||
lines.extend(entries)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_dependencies(pyproject: dict[str, Any]) -> dict[str, list[str]]:
|
||||
deps: dict[str, list[str]] = {
|
||||
"runtime": [],
|
||||
"tests": [],
|
||||
"lints": [],
|
||||
"types": [],
|
||||
"docs": [],
|
||||
}
|
||||
|
||||
for dep in pyproject.get("project", {}).get("dependencies", []):
|
||||
deps["runtime"].append(dep)
|
||||
|
||||
dep_groups = pyproject.get("dependency-groups", {})
|
||||
|
||||
if "tests" in dep_groups:
|
||||
for dep in dep_groups["tests"]:
|
||||
if isinstance(dep, str):
|
||||
deps["tests"].append(dep)
|
||||
|
||||
if "lints" in dep_groups:
|
||||
for dep in dep_groups["lints"]:
|
||||
if isinstance(dep, str):
|
||||
deps["lints"].append(dep)
|
||||
|
||||
if "types" in dep_groups:
|
||||
for dep in dep_groups["types"]:
|
||||
if isinstance(dep, str):
|
||||
deps["types"].append(dep)
|
||||
|
||||
if "docs" in dep_groups:
|
||||
for dep in dep_groups["docs"]:
|
||||
if isinstance(dep, str):
|
||||
deps["docs"].append(dep)
|
||||
|
||||
return deps
|
||||
|
||||
|
||||
def get_available_commands() -> list[dict[str, str]]:
|
||||
commands = [
|
||||
{"cmd": "uv sync", "desc": "Install dependencies"},
|
||||
{"cmd": "uv run python -m app.main", "desc": "Start development server"},
|
||||
{
|
||||
"cmd": "uv run pytest --cov=app --cov-fail-under=70",
|
||||
"desc": "Run tests with coverage",
|
||||
},
|
||||
{"cmd": "uv run ruff check . --fix", "desc": "Run linters"},
|
||||
{"cmd": "uv run ruff format .", "desc": "Format code"},
|
||||
{
|
||||
"cmd": "uv run isort . --profile black --filter-files",
|
||||
"desc": "Sort imports",
|
||||
},
|
||||
{"cmd": "uv run mypy .", "desc": "Type checking"},
|
||||
{"cmd": "uv run mkdocs build", "desc": "Build documentation"},
|
||||
{"cmd": "uv run mkdocs serve", "desc": "Serve documentation locally"},
|
||||
]
|
||||
return commands
|
||||
|
||||
|
||||
def update_dependencies_section(content: str, deps: dict[str, list[str]]) -> str:
|
||||
section_pattern = r"(## Dependencies\n.*?)(\n## |\Z)"
|
||||
|
||||
deps_text = "## Dependencies\n\n"
|
||||
|
||||
if deps["runtime"]:
|
||||
deps_text += "### Runtime\n"
|
||||
for dep in sorted(deps["runtime"]):
|
||||
deps_text += f"- {dep}\n"
|
||||
deps_text += "\n"
|
||||
|
||||
if deps["tests"]:
|
||||
deps_text += "### Development\n"
|
||||
deps_text += "- **Tests**: " + ", ".join(sorted(deps["tests"])) + "\n"
|
||||
if deps["lints"]:
|
||||
deps_text += "- **Lint**: " + ", ".join(sorted(deps["lints"])) + "\n"
|
||||
if deps["types"]:
|
||||
deps_text += "- **Types**: " + ", ".join(sorted(deps["types"])) + "\n"
|
||||
if deps["docs"]:
|
||||
deps_text += "- **Docs**: " + ", ".join(sorted(deps["docs"])) + "\n"
|
||||
|
||||
deps_text += "\n"
|
||||
|
||||
replacement = f"{deps_text}\\2"
|
||||
return re.sub(section_pattern, replacement, content, flags=re.DOTALL)
|
||||
|
||||
|
||||
def update_commands_section(content: str, commands: list[dict[str, str]]) -> str:
|
||||
section_pattern = r"(## Available Commands\n.*?\|.*?\n\|---\|.*?\n)(.*?)(\n## |\Z)"
|
||||
|
||||
commands_table = "| Command | Description |\n|---------|-------------|\n"
|
||||
for cmd in commands:
|
||||
commands_table += f"| `{cmd['cmd']}` | {cmd['desc']} |\n"
|
||||
|
||||
commands_table += "\n"
|
||||
|
||||
replacement = f"\\1{commands_table}\\3"
|
||||
return re.sub(section_pattern, replacement, content, flags=re.DOTALL)
|
||||
|
||||
|
||||
def update_changelog_section(content: str, changelog: str) -> str:
|
||||
section_pattern = r"(## Changelog\n)(.*?)(\Z)"
|
||||
|
||||
replacement = f"\\1\n{changelog}\n\\3"
|
||||
return re.sub(section_pattern, replacement, content, flags=re.DOTALL)
|
||||
|
||||
|
||||
def update_readme(check_only: bool = False) -> bool:
|
||||
readme_path = get_project_root() / "README.md"
|
||||
|
||||
if not readme_path.exists():
|
||||
print("README.md not found")
|
||||
return False
|
||||
|
||||
content = readme_path.read_text()
|
||||
original_content = content
|
||||
|
||||
pyproject = get_pyproject()
|
||||
commits = get_commits_since_tag(get_last_tag())
|
||||
deps = get_dependencies(pyproject)
|
||||
commands = get_available_commands()
|
||||
|
||||
version = get_last_tag() or "v0.1.0"
|
||||
changelog = format_changelog(commits, version)
|
||||
|
||||
content = update_changelog_section(content, changelog)
|
||||
content = update_dependencies_section(content, deps)
|
||||
content = update_commands_section(content, commands)
|
||||
|
||||
if check_only:
|
||||
needs_update = content != original_content
|
||||
if needs_update:
|
||||
print("README.md needs update")
|
||||
else:
|
||||
print("README.md is up to date")
|
||||
return needs_update
|
||||
|
||||
if content != original_content:
|
||||
readme_path.write_text(content)
|
||||
print("README.md updated successfully")
|
||||
return True
|
||||
else:
|
||||
print("No changes needed")
|
||||
return False
|
||||
|
||||
|
||||
def main() -> None:
|
||||
import sys
|
||||
|
||||
check_only = "--check" in sys.argv
|
||||
|
||||
updated = update_readme(check_only=check_only)
|
||||
|
||||
if check_only and updated:
|
||||
sys.exit(1)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,24 +1,57 @@
|
||||
# API test fixtures
|
||||
# Provides: httpx.AsyncClient, authentication helpers, test API data
|
||||
"""API test fixtures."""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.infrastructure.auth.models import TokenInfo
|
||||
from app.main import app_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client() -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create async HTTP client for API testing."""
|
||||
from app.main import app_factory
|
||||
def mock_keycloak_client() -> MagicMock:
|
||||
"""Create mock Keycloak client for testing."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.introspect_token.return_value = TokenInfo(
|
||||
active=True,
|
||||
user_id="test-user-id",
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["user"],
|
||||
)
|
||||
return mock_client
|
||||
|
||||
app = app_factory()
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
@pytest.fixture
|
||||
async def client(mock_keycloak_client: MagicMock) -> AsyncGenerator[AsyncClient]:
|
||||
"""Create async HTTP client for API testing."""
|
||||
with patch(
|
||||
"app.presentation.api.deps.KeycloakAuthClient",
|
||||
return_value=mock_keycloak_client,
|
||||
):
|
||||
app = app_factory()
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers() -> dict[str, str]:
|
||||
"""Return mock authentication headers."""
|
||||
return {"Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthorized_keycloak_client() -> MagicMock:
|
||||
"""Create mock Keycloak client that returns invalid token."""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.introspect_token.return_value = TokenInfo(
|
||||
active=False,
|
||||
user_id="",
|
||||
username="",
|
||||
email="",
|
||||
roles=[],
|
||||
)
|
||||
return mock_client
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# E2E test fixtures
|
||||
# Provides: full application state, end-to-end workflows, cleanup
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def e2e_app() -> AsyncGenerator[FastAPI, None]:
|
||||
async def e2e_app() -> AsyncGenerator[FastAPI]:
|
||||
"""Create full application instance for E2E testing."""
|
||||
from app.main import app_factory
|
||||
|
||||
|
||||
@@ -1,20 +1,58 @@
|
||||
# Integration test fixtures
|
||||
# Provides: test database, external service connections
|
||||
"""Integration test fixtures."""
|
||||
|
||||
from typing import Generator
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.infrastructure.database.models import Base
|
||||
|
||||
# Use in-memory SQLite for tests
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db_connection() -> Generator[str, None, None]:
|
||||
"""Create test database connection."""
|
||||
# TODO: Implement when DB is added to project
|
||||
yield "test_db"
|
||||
@pytest.fixture(scope="session")
|
||||
def engine() -> AsyncEngine:
|
||||
"""Create test engine."""
|
||||
return create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db() -> Generator[None, None, None]:
|
||||
"""Cleanup database after test."""
|
||||
@pytest.fixture(scope="session")
|
||||
def session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
|
||||
"""Create test session factory."""
|
||||
return async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_db(engine: AsyncEngine) -> AsyncGenerator[None]:
|
||||
"""Setup database tables for each test."""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
# TODO: Implement cleanup logic
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
) -> AsyncGenerator[AsyncSession]:
|
||||
"""Create database session for testing."""
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
# Предполагаем, что тестируемый модуль называется `myapp`
|
||||
# Импортируем из него нужные объекты
|
||||
from app.main import app_factory, lifespan, main
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan() -> None:
|
||||
"""Проверяет, что lifespan является корректным асинхронным контекстным менеджером."""
|
||||
app = FastAPI()
|
||||
# Проверяем, что lifespan - это asynccontextmanager
|
||||
assert isinstance(lifespan, asynccontextmanager(lifespan).__class__) # type: ignore[arg-type]
|
||||
|
||||
# Проверяем, что контекстный менеджер работает (ничего не ломается)
|
||||
async with lifespan(app):
|
||||
pass # Просто убеждаемся, что yield отрабатывает
|
||||
|
||||
|
||||
def test_app_factory() -> None:
|
||||
"""Проверяет, что app_factory создаёт приложение FastAPI с переданным lifespan."""
|
||||
app = app_factory()
|
||||
assert isinstance(app, FastAPI)
|
||||
# Проверяем, что lifespan приложения установлен на функцию lifespan
|
||||
assert app.router.lifespan_context == lifespan
|
||||
|
||||
|
||||
@patch("app.main.uvicorn.run")
|
||||
def test_main(mock_uvicorn_run: Mock) -> None:
|
||||
"""Проверяет, что main вызывает uvicorn.run с правильными параметрами."""
|
||||
main()
|
||||
mock_uvicorn_run.assert_called_once_with(
|
||||
app_factory,
|
||||
factory=True,
|
||||
host="0.0.0.0",
|
||||
port=8000, # Предполагаемый порт (в коде обрезано, но обычно 8000)
|
||||
)
|
||||
0
tests/unit/application/__init__.py
Normal file
0
tests/unit/application/__init__.py
Normal file
273
tests/unit/application/test_use_cases.py
Normal file
273
tests/unit/application/test_use_cases.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Tests for application use cases."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.application.dtos.post import CreatePostDTO, UpdatePostDTO
|
||||
from app.application.use_cases import (
|
||||
CreatePostUseCase,
|
||||
DeletePostUseCase,
|
||||
GetPostUseCase,
|
||||
ListPostsUseCase,
|
||||
PublishPostUseCase,
|
||||
UpdatePostUseCase,
|
||||
)
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import (
|
||||
AlreadyExistsException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_post() -> Post:
|
||||
"""Create a test post."""
|
||||
return Post.create(
|
||||
title_str="Test Post",
|
||||
content_str="This is test content with enough characters",
|
||||
author_id="user-123",
|
||||
tags=["test"],
|
||||
)
|
||||
|
||||
|
||||
class TestCreatePostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_post_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
) -> None:
|
||||
"""Test successful post creation."""
|
||||
# Setup
|
||||
mock_post_repository.slug_exists = AsyncMock(return_value=False)
|
||||
mock_post_repository.add = AsyncMock()
|
||||
|
||||
use_case = CreatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = CreatePostDTO(
|
||||
title="New Post",
|
||||
content="Content with enough characters",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await use_case.execute(dto)
|
||||
|
||||
# Assert
|
||||
assert result.title == "New Post"
|
||||
assert result.author_id == "user-123"
|
||||
mock_post_repository.add.assert_called_once()
|
||||
mock_transaction_manager.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_post_slug_exists(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
) -> None:
|
||||
"""Test post creation with existing slug."""
|
||||
# Setup
|
||||
mock_post_repository.slug_exists = AsyncMock(return_value=True)
|
||||
|
||||
use_case = CreatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = CreatePostDTO(
|
||||
title="Existing Post",
|
||||
content="Content with enough characters",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(AlreadyExistsException):
|
||||
await use_case.execute(dto)
|
||||
|
||||
mock_post_repository.add.assert_not_called()
|
||||
mock_transaction_manager.commit.assert_not_called()
|
||||
|
||||
|
||||
class TestGetPostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_post_by_id_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test successful get post by ID."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
|
||||
use_case = GetPostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute
|
||||
result = await use_case.by_id(test_post.id)
|
||||
|
||||
# Assert
|
||||
assert result.id == test_post.id
|
||||
assert result.title == test_post.title.value
|
||||
mock_post_repository.get_by_id.assert_called_once_with(test_post.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_post_by_id_not_found(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
) -> None:
|
||||
"""Test get post by ID when not found."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
use_case = GetPostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
post_id = uuid4()
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(NotFoundException):
|
||||
await use_case.by_id(post_id)
|
||||
|
||||
|
||||
class TestUpdatePostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_post_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test successful post update."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
mock_post_repository.update = AsyncMock()
|
||||
|
||||
use_case = UpdatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = UpdatePostDTO(title="Updated Title")
|
||||
|
||||
# Execute
|
||||
result = await use_case.execute(test_post.id, dto, "user-123")
|
||||
|
||||
# Assert
|
||||
assert result.title == "Updated Title"
|
||||
mock_post_repository.update.assert_called_once()
|
||||
mock_transaction_manager.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_post_not_found(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
) -> None:
|
||||
"""Test update post when not found."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
use_case = UpdatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = UpdatePostDTO(title="Updated Title")
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(NotFoundException):
|
||||
await use_case.execute(uuid4(), dto, "user-123")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_post_forbidden(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test update post by different user."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
|
||||
use_case = UpdatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = UpdatePostDTO(title="Updated Title")
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(ForbiddenException):
|
||||
await use_case.execute(test_post.id, dto, "other-user")
|
||||
|
||||
|
||||
class TestDeletePostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_post_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test successful post deletion."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
mock_post_repository.delete = AsyncMock()
|
||||
|
||||
use_case = DeletePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute
|
||||
await use_case.execute(test_post.id, "user-123")
|
||||
|
||||
# Assert
|
||||
mock_post_repository.delete.assert_called_once_with(test_post.id)
|
||||
mock_transaction_manager.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_post_forbidden(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test delete post by different user."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
|
||||
use_case = DeletePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(ForbiddenException):
|
||||
await use_case.execute(test_post.id, "other-user")
|
||||
|
||||
|
||||
class TestPublishPostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_post_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test successful post publish."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
mock_post_repository.update = AsyncMock()
|
||||
|
||||
use_case = PublishPostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute
|
||||
result = await use_case.publish(test_post.id, "user-123")
|
||||
|
||||
# Assert
|
||||
assert result.published is True
|
||||
mock_post_repository.update.assert_called_once()
|
||||
mock_transaction_manager.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestListPostsUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_posts(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test listing all posts."""
|
||||
# Setup
|
||||
mock_post_repository.get_all = AsyncMock(return_value=[test_post])
|
||||
|
||||
use_case = ListPostsUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute
|
||||
results = await use_case.all_posts()
|
||||
|
||||
# Assert
|
||||
assert len(results) == 1
|
||||
assert results[0].id == test_post.id
|
||||
mock_post_repository.get_all.assert_called_once()
|
||||
@@ -1,18 +1,29 @@
|
||||
# Unit test fixtures
|
||||
# Provides: mocks, stubs, isolated test data
|
||||
"""Unit test fixtures."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service() -> Mock:
|
||||
"""Create a mock service for unit testing."""
|
||||
return Mock()
|
||||
def mock_post_repository() -> Mock:
|
||||
"""Create a mock post repository."""
|
||||
return Mock(spec=PostRepository)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_transaction_manager() -> Mock:
|
||||
"""Create a mock transaction manager."""
|
||||
tx_manager = Mock(spec=TransactionManager)
|
||||
tx_manager.commit = AsyncMock()
|
||||
tx_manager.rollback = AsyncMock()
|
||||
return tx_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_service() -> AsyncMock:
|
||||
"""Create an async mock service for unit testing."""
|
||||
"""Create an async mock service."""
|
||||
return AsyncMock()
|
||||
|
||||
0
tests/unit/domain/__init__.py
Normal file
0
tests/unit/domain/__init__.py
Normal file
128
tests/unit/domain/test_entities.py
Normal file
128
tests/unit/domain/test_entities.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Tests for domain entities."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.domain.entities import Post
|
||||
from app.domain.value_objects import Content, Title
|
||||
|
||||
|
||||
class TestPost:
|
||||
def test_post_creation(self) -> None:
|
||||
"""Test creating a post."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
tags=["test", "python"],
|
||||
)
|
||||
|
||||
assert isinstance(post.id, UUID)
|
||||
assert post.title.value == "Test Title"
|
||||
assert post.content.value == "This is test content that is long enough"
|
||||
assert post.slug.value == "test-title"
|
||||
assert post.author_id == "user-123"
|
||||
assert post.published is False
|
||||
assert post.tags == ["test", "python"]
|
||||
|
||||
def test_post_publish(self) -> None:
|
||||
"""Test publishing a post."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
assert post.published is False
|
||||
post.publish()
|
||||
assert post.published is True
|
||||
|
||||
def test_post_unpublish(self) -> None:
|
||||
"""Test unpublishing a post."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
post.publish()
|
||||
assert post.published is True
|
||||
post.unpublish()
|
||||
assert post.published is False
|
||||
|
||||
def test_post_update_title(self) -> None:
|
||||
"""Test updating post title."""
|
||||
post = Post.create(
|
||||
title_str="Original Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
old_updated_at = post.updated_at
|
||||
post.update_title(Title("New Title"))
|
||||
|
||||
assert post.title.value == "New Title"
|
||||
assert post.slug.value == "new-title"
|
||||
assert post.updated_at > old_updated_at
|
||||
|
||||
def test_post_update_content(self) -> None:
|
||||
"""Test updating post content."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
old_updated_at = post.updated_at
|
||||
post.update_content(Content("Updated content that is also long enough"))
|
||||
|
||||
assert post.content.value == "Updated content that is also long enough"
|
||||
assert post.updated_at > old_updated_at
|
||||
|
||||
def test_post_add_tag(self) -> None:
|
||||
"""Test adding a tag."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
post.add_tag("python")
|
||||
assert "python" in post.tags
|
||||
|
||||
# Adding same tag twice should not duplicate
|
||||
post.add_tag("python")
|
||||
assert post.tags.count("python") == 1
|
||||
|
||||
def test_post_remove_tag(self) -> None:
|
||||
"""Test removing a tag."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
tags=["python", "fastapi"],
|
||||
)
|
||||
|
||||
post.remove_tag("python")
|
||||
assert "python" not in post.tags
|
||||
assert "fastapi" in post.tags
|
||||
|
||||
def test_post_to_dict(self) -> None:
|
||||
"""Test converting post to dict."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
tags=["test"],
|
||||
)
|
||||
|
||||
data = post.to_dict()
|
||||
|
||||
assert data["title"] == "Test Title"
|
||||
assert data["content"] == "This is test content that is long enough"
|
||||
assert data["slug"] == "test-title"
|
||||
assert data["author_id"] == "user-123"
|
||||
assert data["published"] is False
|
||||
assert data["tags"] == ["test"]
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
assert "updated_at" in data
|
||||
48
tests/unit/domain/test_exceptions.py
Normal file
48
tests/unit/domain/test_exceptions.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for domain exceptions."""
|
||||
|
||||
from app.domain.exceptions import (
|
||||
AlreadyExistsException,
|
||||
DomainException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
UnauthorizedException,
|
||||
ValidationException,
|
||||
)
|
||||
|
||||
|
||||
class TestDomainExceptions:
|
||||
def test_base_exception(self) -> None:
|
||||
"""Test base domain exception."""
|
||||
exc = DomainException("Something went wrong")
|
||||
assert exc.message == "Something went wrong"
|
||||
assert str(exc) == "Something went wrong"
|
||||
|
||||
def test_validation_exception(self) -> None:
|
||||
"""Test validation exception."""
|
||||
exc = ValidationException("Invalid input")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Invalid input"
|
||||
|
||||
def test_not_found_exception(self) -> None:
|
||||
"""Test not found exception."""
|
||||
exc = NotFoundException("Resource not found")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Resource not found"
|
||||
|
||||
def test_already_exists_exception(self) -> None:
|
||||
"""Test already exists exception."""
|
||||
exc = AlreadyExistsException("Already exists")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Already exists"
|
||||
|
||||
def test_unauthorized_exception(self) -> None:
|
||||
"""Test unauthorized exception."""
|
||||
exc = UnauthorizedException("Unauthorized")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Unauthorized"
|
||||
|
||||
def test_forbidden_exception(self) -> None:
|
||||
"""Test forbidden exception."""
|
||||
exc = ForbiddenException("Forbidden")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Forbidden"
|
||||
123
tests/unit/domain/test_roles.py
Normal file
123
tests/unit/domain/test_roles.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Tests for role-based access control."""
|
||||
|
||||
from app.domain.roles import (
|
||||
ROLE_PERMISSIONS,
|
||||
Permission,
|
||||
Role,
|
||||
get_effective_role,
|
||||
has_permission,
|
||||
)
|
||||
|
||||
|
||||
class TestRole:
|
||||
"""Test Role enum."""
|
||||
|
||||
def test_role_values(self) -> None:
|
||||
"""Test role enum values."""
|
||||
assert Role.ADMIN.value == "admin"
|
||||
assert Role.USER.value == "user"
|
||||
assert Role.GUEST.value == "guest"
|
||||
|
||||
def test_role_comparison(self) -> None:
|
||||
"""Test role comparison."""
|
||||
assert Role.ADMIN == Role.ADMIN
|
||||
# USER and ADMIN are different enum values with different string values
|
||||
assert Role.USER.value != Role.ADMIN.value # type: ignore[comparison-overlap]
|
||||
|
||||
|
||||
class TestPermissions:
|
||||
"""Test permission definitions."""
|
||||
|
||||
def test_permission_values(self) -> None:
|
||||
"""Test permission constants."""
|
||||
assert Permission.POST_CREATE == "post:create"
|
||||
assert Permission.POST_READ == "post:read"
|
||||
assert Permission.POST_READ_UNPUBLISHED == "post:read_unpublished"
|
||||
assert Permission.POST_UPDATE == "post:update"
|
||||
assert Permission.POST_DELETE == "post:delete"
|
||||
assert Permission.POST_PUBLISH == "post:publish"
|
||||
|
||||
|
||||
class TestRolePermissions:
|
||||
"""Test role-based permission mapping."""
|
||||
|
||||
def test_admin_has_all_permissions(self) -> None:
|
||||
"""Test admin has all permissions."""
|
||||
admin_perms = ROLE_PERMISSIONS[Role.ADMIN]
|
||||
assert Permission.POST_CREATE in admin_perms
|
||||
assert Permission.POST_READ in admin_perms
|
||||
assert Permission.POST_READ_UNPUBLISHED in admin_perms
|
||||
assert Permission.POST_UPDATE in admin_perms
|
||||
assert Permission.POST_DELETE in admin_perms
|
||||
assert Permission.POST_PUBLISH in admin_perms
|
||||
|
||||
def test_user_permissions(self) -> None:
|
||||
"""Test user permissions."""
|
||||
user_perms = ROLE_PERMISSIONS[Role.USER]
|
||||
assert Permission.POST_CREATE in user_perms
|
||||
assert Permission.POST_READ in user_perms
|
||||
assert Permission.POST_UPDATE in user_perms
|
||||
assert Permission.POST_DELETE in user_perms
|
||||
assert Permission.POST_PUBLISH in user_perms
|
||||
# User cannot read unpublished
|
||||
assert Permission.POST_READ_UNPUBLISHED not in user_perms
|
||||
|
||||
def test_guest_permissions(self) -> None:
|
||||
"""Test guest permissions."""
|
||||
guest_perms = ROLE_PERMISSIONS[Role.GUEST]
|
||||
assert Permission.POST_READ in guest_perms
|
||||
# Guest has very limited permissions
|
||||
assert Permission.POST_CREATE not in guest_perms
|
||||
assert Permission.POST_UPDATE not in guest_perms
|
||||
assert Permission.POST_DELETE not in guest_perms
|
||||
assert Permission.POST_READ_UNPUBLISHED not in guest_perms
|
||||
|
||||
|
||||
class TestHasPermission:
|
||||
"""Test has_permission function."""
|
||||
|
||||
def test_admin_has_all_permissions_check(self) -> None:
|
||||
"""Test admin permission checks."""
|
||||
assert has_permission(Role.ADMIN, Permission.POST_CREATE) is True
|
||||
assert has_permission(Role.ADMIN, Permission.POST_READ_UNPUBLISHED) is True
|
||||
assert has_permission(Role.ADMIN, "unknown:permission") is False
|
||||
|
||||
def test_user_limited_permissions(self) -> None:
|
||||
"""Test user limited permissions."""
|
||||
assert has_permission(Role.USER, Permission.POST_CREATE) is True
|
||||
assert has_permission(Role.USER, Permission.POST_READ_UNPUBLISHED) is False
|
||||
assert has_permission(Role.USER, Permission.POST_READ) is True
|
||||
|
||||
def test_guest_read_only(self) -> None:
|
||||
"""Test guest read-only access."""
|
||||
assert has_permission(Role.GUEST, Permission.POST_READ) is True
|
||||
assert has_permission(Role.GUEST, Permission.POST_CREATE) is False
|
||||
assert has_permission(Role.GUEST, Permission.POST_UPDATE) is False
|
||||
|
||||
|
||||
class TestGetEffectiveRole:
|
||||
"""Test get_effective_role function."""
|
||||
|
||||
def test_admin_from_roles_list(self) -> None:
|
||||
"""Test admin role detection."""
|
||||
assert get_effective_role(["admin"]) == Role.ADMIN
|
||||
assert get_effective_role(["user", "admin"]) == Role.ADMIN
|
||||
assert get_effective_role(["admin", "user"]) == Role.ADMIN
|
||||
|
||||
def test_user_from_roles_list(self) -> None:
|
||||
"""Test user role detection."""
|
||||
assert get_effective_role(["user"]) == Role.USER
|
||||
assert get_effective_role(["user", "moderator"]) == Role.USER
|
||||
|
||||
def test_guest_from_roles_list(self) -> None:
|
||||
"""Test guest role detection."""
|
||||
assert get_effective_role([]) == Role.GUEST
|
||||
assert get_effective_role(["unknown"]) == Role.GUEST
|
||||
assert get_effective_role(["guest"]) == Role.GUEST
|
||||
|
||||
def test_role_priority(self) -> None:
|
||||
"""Test that admin > user > guest."""
|
||||
# Admin takes precedence
|
||||
assert get_effective_role(["user", "admin", "guest"]) == Role.ADMIN
|
||||
# User takes precedence over guest
|
||||
assert get_effective_role(["guest", "user"]) == Role.USER
|
||||
93
tests/unit/domain/test_value_objects.py
Normal file
93
tests/unit/domain/test_value_objects.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Tests for domain value objects."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.domain.value_objects import Content, Slug, Title
|
||||
|
||||
|
||||
class TestTitle:
|
||||
def test_valid_title(self) -> None:
|
||||
"""Test creating a valid title."""
|
||||
title = Title("Valid Title")
|
||||
assert title.value == "Valid Title"
|
||||
|
||||
def test_title_too_short(self) -> None:
|
||||
"""Test title that is too short."""
|
||||
with pytest.raises(ValueError, match="at least"):
|
||||
Title("ab")
|
||||
|
||||
def test_title_too_long(self) -> None:
|
||||
"""Test title that is too long."""
|
||||
with pytest.raises(ValueError, match="at most"):
|
||||
Title("a" * 201)
|
||||
|
||||
def test_title_empty(self) -> None:
|
||||
"""Test empty title."""
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
Title(" ")
|
||||
|
||||
def test_title_not_string(self) -> None:
|
||||
"""Test non-string title."""
|
||||
with pytest.raises(ValueError, match="string"):
|
||||
Title(123) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestContent:
|
||||
def test_valid_content(self) -> None:
|
||||
"""Test creating valid content."""
|
||||
content = Content("This is valid content with enough characters")
|
||||
assert content.value == "This is valid content with enough characters"
|
||||
|
||||
def test_content_too_short(self) -> None:
|
||||
"""Test content that is too short."""
|
||||
with pytest.raises(ValueError, match="at least"):
|
||||
Content("short")
|
||||
|
||||
def test_content_too_long(self) -> None:
|
||||
"""Test content that is too long."""
|
||||
with pytest.raises(ValueError, match="at most"):
|
||||
Content("a" * 50001)
|
||||
|
||||
def test_content_empty(self) -> None:
|
||||
"""Test empty content."""
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
Content(" ")
|
||||
|
||||
|
||||
class TestSlug:
|
||||
def test_valid_slug(self) -> None:
|
||||
"""Test creating a valid slug."""
|
||||
slug = Slug("valid-slug")
|
||||
assert slug.value == "valid-slug"
|
||||
|
||||
def test_slug_from_title(self) -> None:
|
||||
"""Test generating slug from title."""
|
||||
slug = Slug.from_title("Hello World Post")
|
||||
assert slug.value == "hello-world-post"
|
||||
|
||||
def test_slug_from_title_with_special_chars(self) -> None:
|
||||
"""Test generating slug from title with special characters."""
|
||||
slug = Slug.from_title("Hello, World! Post @#$%")
|
||||
assert slug.value == "hello-world-post"
|
||||
|
||||
def test_slug_from_title_only_special_chars(self) -> None:
|
||||
"""Test generating slug from title with only special characters."""
|
||||
slug = Slug.from_title("!@#$%")
|
||||
assert slug.value == "post"
|
||||
|
||||
def test_slug_invalid_chars(self) -> None:
|
||||
"""Test slug with invalid characters."""
|
||||
with pytest.raises(ValueError, match="lowercase"):
|
||||
Slug("Invalid_Slug")
|
||||
|
||||
def test_slug_uppercase(self) -> None:
|
||||
"""Test slug with uppercase letters."""
|
||||
with pytest.raises(ValueError, match="lowercase"):
|
||||
Slug("Uppercase-Slug")
|
||||
|
||||
def test_slug_equality(self) -> None:
|
||||
"""Test slug value equality."""
|
||||
slug1 = Slug("test-slug")
|
||||
slug2 = Slug("test-slug")
|
||||
assert slug1 == slug2
|
||||
assert hash(slug1) == hash(slug2)
|
||||
0
tests/unit/infrastructure/__init__.py
Normal file
0
tests/unit/infrastructure/__init__.py
Normal file
303
tests/unit/infrastructure/test_auth.py
Normal file
303
tests/unit/infrastructure/test_auth.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Tests for Keycloak authentication client."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.infrastructure.auth import KeycloakAuthClient, KeycloakUser, TokenInfo
|
||||
from app.infrastructure.config.settings import Settings
|
||||
|
||||
|
||||
class TestTokenInfo:
|
||||
"""Test TokenInfo dataclass."""
|
||||
|
||||
def test_token_info_valid(self) -> None:
|
||||
"""Test valid token info."""
|
||||
token_info = TokenInfo(
|
||||
active=True,
|
||||
user_id="user-123",
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["user"],
|
||||
)
|
||||
assert token_info.is_valid is True
|
||||
assert token_info.user_id == "user-123"
|
||||
assert token_info.username == "testuser"
|
||||
assert token_info.email == "test@example.com"
|
||||
assert token_info.roles == ["user"]
|
||||
|
||||
def test_token_info_invalid_not_active(self) -> None:
|
||||
"""Test invalid token when not active."""
|
||||
token_info = TokenInfo(
|
||||
active=False,
|
||||
user_id="user-123",
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["user"],
|
||||
)
|
||||
assert token_info.is_valid is False
|
||||
|
||||
def test_token_info_invalid_no_user_id(self) -> None:
|
||||
"""Test invalid token when no user_id."""
|
||||
token_info = TokenInfo(
|
||||
active=True,
|
||||
user_id="",
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=["user"],
|
||||
)
|
||||
assert token_info.is_valid is False
|
||||
|
||||
def test_token_info_empty_roles(self) -> None:
|
||||
"""Test token info with empty roles."""
|
||||
token_info = TokenInfo(
|
||||
active=True,
|
||||
user_id="user-123",
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
roles=[],
|
||||
)
|
||||
assert token_info.is_valid is True
|
||||
assert token_info.roles == []
|
||||
|
||||
|
||||
class TestKeycloakUser:
|
||||
"""Test KeycloakUser dataclass."""
|
||||
|
||||
def test_keycloak_user_creation(self) -> None:
|
||||
"""Test KeycloakUser creation."""
|
||||
user = KeycloakUser(
|
||||
id="user-123",
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
first_name="Test",
|
||||
last_name="User",
|
||||
roles=["user", "admin"],
|
||||
is_active=True,
|
||||
)
|
||||
assert user.id == "user-123"
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "test@example.com"
|
||||
assert user.first_name == "Test"
|
||||
assert user.last_name == "User"
|
||||
assert user.roles == ["user", "admin"]
|
||||
assert user.is_active is True
|
||||
|
||||
def test_keycloak_user_defaults(self) -> None:
|
||||
"""Test KeycloakUser with default values."""
|
||||
user = KeycloakUser(
|
||||
id="user-123",
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
)
|
||||
assert user.first_name == ""
|
||||
assert user.last_name == ""
|
||||
assert user.roles == []
|
||||
assert user.is_active is True
|
||||
|
||||
|
||||
class TestKeycloakAuthClient:
|
||||
"""Test KeycloakAuthClient."""
|
||||
|
||||
@pytest.fixture
|
||||
def settings(self) -> Settings:
|
||||
"""Create test settings."""
|
||||
from app.infrastructure.config import KCConfig, SecurityConfig
|
||||
|
||||
return Settings(
|
||||
environment="dev",
|
||||
kc=KCConfig(
|
||||
server_url="http://localhost:8080",
|
||||
realm="test-realm",
|
||||
client_id="test-client",
|
||||
client_secret="test-secret",
|
||||
token_cache_ttl=60,
|
||||
),
|
||||
security=SecurityConfig(
|
||||
secret_key="test-secret-key-for-jwt-tokens",
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, settings: Settings) -> KeycloakAuthClient:
|
||||
"""Create Keycloak client."""
|
||||
return KeycloakAuthClient(settings)
|
||||
|
||||
def test_client_initialization(self, client: KeycloakAuthClient, settings: Settings) -> None:
|
||||
"""Test client initialization."""
|
||||
assert client._settings == settings
|
||||
assert client._base_url == "http://localhost:8080/realms/test-realm"
|
||||
assert client._client_id == "test-client"
|
||||
assert client._client_secret == "test-secret"
|
||||
assert client._cache_ttl == 60
|
||||
|
||||
def test_get_introspection_url(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test introspection URL generation."""
|
||||
url = client._get_introspection_url()
|
||||
assert (
|
||||
url
|
||||
== "http://localhost:8080/realms/test-realm/protocol/openid-connect/token/introspection"
|
||||
)
|
||||
|
||||
def test_get_userinfo_url(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test userinfo URL generation."""
|
||||
url = client._get_userinfo_url()
|
||||
assert url == "http://localhost:8080/realms/test-realm/protocol/openid-connect/userinfo"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_introspect_token_success(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test successful token introspection."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"active": True,
|
||||
"sub": "user-123",
|
||||
"preferred_username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"realm_access": {"roles": ["user", "admin"]},
|
||||
}
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||
result = await client.introspect_token("test-token")
|
||||
|
||||
assert result.active is True
|
||||
assert result.user_id == "user-123"
|
||||
assert result.username == "testuser"
|
||||
assert result.email == "test@example.com"
|
||||
assert result.roles == ["user", "admin"]
|
||||
assert result.is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_introspect_token_inactive(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test introspection with inactive token."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"active": False}
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||
result = await client.introspect_token("test-token")
|
||||
|
||||
assert result.active is False
|
||||
assert result.is_valid is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_introspect_token_http_error(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test introspection with HTTP error."""
|
||||
import httpx
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_async_client.post = AsyncMock(side_effect=httpx.HTTPError("Connection error"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||
result = await client.introspect_token("test-token")
|
||||
|
||||
assert result.active is False
|
||||
assert result.is_valid is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_introspect_token_uses_cache(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test that token introspection uses cache."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"active": True,
|
||||
"sub": "user-123",
|
||||
"preferred_username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"realm_access": {"roles": ["user"]},
|
||||
}
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||
# First call
|
||||
result1 = await client.introspect_token("test-token")
|
||||
# Second call should use cache
|
||||
result2 = await client.introspect_token("test-token")
|
||||
|
||||
# HTTP client should only be called once
|
||||
assert mock_async_client.post.call_count == 1
|
||||
assert result1.user_id == result2.user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_userinfo_success(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test successful userinfo retrieval."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"sub": "user-123",
|
||||
"preferred_username": "testuser",
|
||||
"email": "test@example.com",
|
||||
"given_name": "Test",
|
||||
"family_name": "User",
|
||||
"realm_access": {"roles": ["user"]},
|
||||
}
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_async_client.get = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||
result = await client.get_userinfo("test-token")
|
||||
|
||||
assert result is not None
|
||||
assert result.id == "user-123"
|
||||
assert result.username == "testuser"
|
||||
assert result.email == "test@example.com"
|
||||
assert result.first_name == "Test"
|
||||
assert result.last_name == "User"
|
||||
assert result.roles == ["user"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_userinfo_error(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test userinfo retrieval with error."""
|
||||
import httpx
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_async_client.get = AsyncMock(side_effect=httpx.HTTPError("Connection error"))
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||
result = await client.get_userinfo("test-token")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_introspect_token_no_realm_roles(self, client: KeycloakAuthClient) -> None:
|
||||
"""Test introspection without realm_access roles."""
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"active": True,
|
||||
"sub": "user-123",
|
||||
"preferred_username": "testuser",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
mock_response.raise_for_status = Mock()
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.__aenter__ = AsyncMock(return_value=mock_async_client)
|
||||
mock_async_client.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_async_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_async_client):
|
||||
result = await client.introspect_token("test-token")
|
||||
|
||||
assert result.active is True
|
||||
assert result.roles == []
|
||||
244
tests/unit/infrastructure/test_config.py
Normal file
244
tests/unit/infrastructure/test_config.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Tests for infrastructure config."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.infrastructure.config import (
|
||||
AppConfig,
|
||||
DBConfig,
|
||||
Environment,
|
||||
KCConfig,
|
||||
SecurityConfig,
|
||||
Settings,
|
||||
)
|
||||
|
||||
|
||||
class TestSettings:
|
||||
"""Test Settings with composition pattern."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default settings values by creating settings without env file."""
|
||||
# Create settings with required secrets and no env file
|
||||
s = Settings(
|
||||
_env_file=None,
|
||||
security=SecurityConfig(secret_key="test-secret-key"),
|
||||
kc=KCConfig(client_secret="test-client-secret"),
|
||||
)
|
||||
assert s.app.name == "Blog API"
|
||||
assert s.app.debug is False
|
||||
assert s.app.host == "0.0.0.0"
|
||||
assert s.app.port == 8000
|
||||
assert s.database_url == "sqlite+aiosqlite:///./blog.db"
|
||||
assert s.db.echo is False
|
||||
assert s.security.secret_key == "test-secret-key"
|
||||
assert s.kc.client_secret == "test-client-secret"
|
||||
assert s.environment == Environment.DEV
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Test custom settings values."""
|
||||
s = Settings(
|
||||
_env_file=None,
|
||||
environment=Environment.PROD,
|
||||
app=AppConfig(
|
||||
name="Test API",
|
||||
debug=True,
|
||||
host="localhost",
|
||||
port=9000,
|
||||
),
|
||||
db=DBConfig(url="postgresql+asyncpg://user:pass@host/db"),
|
||||
security=SecurityConfig(secret_key="test-secret"),
|
||||
kc=KCConfig(client_secret="test-client-secret"),
|
||||
)
|
||||
assert s.app.name == "Test API"
|
||||
assert s.app.debug is True
|
||||
assert s.app.host == "localhost"
|
||||
assert s.app.port == 9000
|
||||
assert s.database_url == "postgresql+asyncpg://user:pass@host/db"
|
||||
assert s.security.secret_key == "test-secret"
|
||||
assert s.kc.client_secret == "test-client-secret"
|
||||
assert s.environment == Environment.PROD
|
||||
|
||||
def test_model_config(self) -> None:
|
||||
"""Test settings model config."""
|
||||
assert "env_file" in Settings.model_config
|
||||
|
||||
def test_is_dev_property(self) -> None:
|
||||
"""Test is_dev property."""
|
||||
s = Settings(
|
||||
_env_file=None,
|
||||
environment=Environment.DEV,
|
||||
security=SecurityConfig(secret_key="test"),
|
||||
kc=KCConfig(client_secret="test"),
|
||||
)
|
||||
assert s.is_dev is True
|
||||
assert s.is_prod is False
|
||||
|
||||
def test_is_prod_property(self) -> None:
|
||||
"""Test is_prod property."""
|
||||
s = Settings(
|
||||
_env_file=None,
|
||||
environment=Environment.PROD,
|
||||
security=SecurityConfig(secret_key="test"),
|
||||
kc=KCConfig(client_secret="test"),
|
||||
)
|
||||
assert s.is_prod is True
|
||||
assert s.is_dev is False
|
||||
|
||||
def test_prod_requires_security_secret(self) -> None:
|
||||
"""Test that prod mode requires security secret_key."""
|
||||
with pytest.raises(ValueError, match="SECURITY_SECRET_KEY"):
|
||||
Settings(
|
||||
_env_file=None,
|
||||
environment=Environment.PROD,
|
||||
security=SecurityConfig(secret_key=""),
|
||||
kc=KCConfig(client_secret="test"),
|
||||
)
|
||||
|
||||
def test_prod_requires_kc_secret(self) -> None:
|
||||
"""Test that prod mode requires KC client_secret."""
|
||||
with pytest.raises(ValueError, match="KC_CLIENT_SECRET"):
|
||||
Settings(
|
||||
_env_file=None,
|
||||
environment=Environment.PROD,
|
||||
security=SecurityConfig(secret_key="test"),
|
||||
kc=KCConfig(client_secret=""),
|
||||
)
|
||||
|
||||
def test_database_url_dev_default(self) -> None:
|
||||
"""Test default database URL in dev mode."""
|
||||
s = Settings(
|
||||
_env_file=None,
|
||||
environment=Environment.DEV,
|
||||
security=SecurityConfig(secret_key="test"),
|
||||
kc=KCConfig(client_secret="test"),
|
||||
)
|
||||
assert s.database_url == "sqlite+aiosqlite:///./blog.db"
|
||||
|
||||
def test_database_url_prod_builds_postgres(self) -> None:
|
||||
"""Test that database URL builds from components in prod."""
|
||||
s = Settings(
|
||||
_env_file=None,
|
||||
environment=Environment.PROD,
|
||||
db=DBConfig(
|
||||
url=None, # Force building from components
|
||||
host="db.example.com",
|
||||
port=5433,
|
||||
user="admin",
|
||||
password="secret",
|
||||
name="mydb",
|
||||
),
|
||||
security=SecurityConfig(secret_key="test"),
|
||||
kc=KCConfig(client_secret="test"),
|
||||
)
|
||||
assert s.database_url == "postgresql+asyncpg://admin:secret@db.example.com:5433/mydb"
|
||||
|
||||
def test_database_url_override(self) -> None:
|
||||
"""Test that explicit database URL overrides auto-building."""
|
||||
s = Settings(
|
||||
_env_file=None,
|
||||
environment=Environment.PROD,
|
||||
db=DBConfig(
|
||||
url="postgresql+asyncpg://custom/url",
|
||||
host="ignored",
|
||||
user="ignored",
|
||||
),
|
||||
security=SecurityConfig(secret_key="test"),
|
||||
kc=KCConfig(client_secret="test"),
|
||||
)
|
||||
assert s.database_url == "postgresql+asyncpg://custom/url"
|
||||
|
||||
|
||||
class TestAppConfig:
|
||||
"""Test AppConfig."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test AppConfig default values."""
|
||||
cfg = AppConfig()
|
||||
assert cfg.name == "Blog API"
|
||||
assert cfg.debug is False
|
||||
assert cfg.host == "0.0.0.0"
|
||||
assert cfg.port == 8000
|
||||
|
||||
|
||||
class TestDBConfig:
|
||||
"""Test DBConfig."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test DBConfig default values."""
|
||||
cfg = DBConfig()
|
||||
assert cfg.url is None
|
||||
assert cfg.echo is False
|
||||
assert cfg.host == "localhost"
|
||||
assert cfg.port == 5432
|
||||
assert cfg.user == "postgres"
|
||||
assert cfg.password == "postgres"
|
||||
assert cfg.name == "blog"
|
||||
|
||||
def test_postgres_url_validation(self) -> None:
|
||||
"""Test URL validation for postgres."""
|
||||
cfg = DBConfig(url="postgresql+asyncpg://user:pass@host/db")
|
||||
assert cfg.url == "postgresql+asyncpg://user:pass@host/db"
|
||||
|
||||
def test_sqlite_url_validation(self) -> None:
|
||||
"""Test URL validation for sqlite."""
|
||||
cfg = DBConfig(url="sqlite+aiosqlite:///./test.db")
|
||||
assert cfg.url == "sqlite+aiosqlite:///./test.db"
|
||||
|
||||
def test_invalid_url_validation(self) -> None:
|
||||
"""Test URL validation rejects invalid URLs."""
|
||||
with pytest.raises(ValueError, match="sqlite+.*postgresql+"):
|
||||
DBConfig(url="mysql://invalid")
|
||||
|
||||
|
||||
class TestKCConfig:
|
||||
"""Test KCConfig."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test KCConfig default values."""
|
||||
cfg = KCConfig(client_secret="test-secret")
|
||||
assert cfg.server_url == "http://localhost:8080"
|
||||
assert cfg.realm == "blog"
|
||||
assert cfg.client_id == "blog-api"
|
||||
assert cfg.client_secret == "test-secret"
|
||||
assert cfg.token_cache_ttl == 60
|
||||
|
||||
def test_is_configured_with_secret(self) -> None:
|
||||
"""Test is_configured returns True when secret is set."""
|
||||
cfg = KCConfig(client_secret="test-secret")
|
||||
assert cfg.is_configured is True
|
||||
|
||||
def test_is_configured_without_secret(self) -> None:
|
||||
"""Test is_configured returns False when secret is empty."""
|
||||
cfg = KCConfig(client_secret="")
|
||||
assert cfg.is_configured is False
|
||||
|
||||
|
||||
class TestSecurityConfig:
|
||||
"""Test SecurityConfig."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test SecurityConfig default values."""
|
||||
cfg = SecurityConfig(secret_key="test-key")
|
||||
assert cfg.secret_key == "test-key"
|
||||
assert cfg.access_token_expire_minutes == 30
|
||||
|
||||
def test_is_configured_with_secret(self) -> None:
|
||||
"""Test is_configured returns True when secret is set."""
|
||||
cfg = SecurityConfig(secret_key="test-secret")
|
||||
assert cfg.is_configured is True
|
||||
|
||||
def test_is_configured_without_secret(self) -> None:
|
||||
"""Test is_configured returns False when secret is empty."""
|
||||
cfg = SecurityConfig(secret_key="")
|
||||
assert cfg.is_configured is False
|
||||
|
||||
|
||||
class TestEnvironment:
|
||||
"""Test Environment enum."""
|
||||
|
||||
def test_dev_value(self) -> None:
|
||||
"""Test DEV environment value."""
|
||||
assert Environment.DEV.value == "dev"
|
||||
|
||||
def test_prod_value(self) -> None:
|
||||
"""Test PROD environment value."""
|
||||
assert Environment.PROD.value == "prod"
|
||||
@@ -1,52 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.core.config import Settings
|
||||
|
||||
|
||||
class TestSettings:
|
||||
def test_default_values(self) -> None:
|
||||
settings = Settings()
|
||||
assert settings.app_name == "Blog API"
|
||||
assert settings.debug is False
|
||||
assert settings.host == "0.0.0.0"
|
||||
assert settings.port == 8000
|
||||
assert settings.database_url is None
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
settings = Settings(
|
||||
app_name="Test API",
|
||||
debug=True,
|
||||
host="localhost",
|
||||
port=9000,
|
||||
database_url="postgresql://test",
|
||||
)
|
||||
assert settings.app_name == "Test API"
|
||||
assert settings.debug is True
|
||||
assert settings.host == "localhost"
|
||||
assert settings.port == 9000
|
||||
assert settings.database_url == "postgresql://test"
|
||||
|
||||
def test_settings_from_env(self) -> None:
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"APP_NAME": "Env API",
|
||||
"DEBUG": "true",
|
||||
"HOST": "127.0.0.1",
|
||||
"PORT": "8080",
|
||||
"DATABASE_URL": "sqlite:///test.db",
|
||||
},
|
||||
):
|
||||
settings = Settings()
|
||||
assert settings.app_name == "Env API"
|
||||
assert settings.debug is True
|
||||
assert settings.host == "127.0.0.1"
|
||||
assert settings.port == 8080
|
||||
assert settings.database_url == "sqlite:///test.db"
|
||||
|
||||
def test_global_settings_instance(self) -> None:
|
||||
from app.core.config import settings
|
||||
|
||||
assert isinstance(settings, Settings)
|
||||
assert settings.app_name == "Blog API"
|
||||
@@ -1,110 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from app.common.error_handler import (
|
||||
ErrorResponse,
|
||||
app_exception_handler,
|
||||
http_exception_handler,
|
||||
register_exception_handlers,
|
||||
)
|
||||
from app.core.exceptions import AppException
|
||||
|
||||
|
||||
class TestErrorResponse:
|
||||
def test_error_response_creation(self) -> None:
|
||||
response = ErrorResponse(
|
||||
status_code=400,
|
||||
message="Bad request",
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.message == "Bad request"
|
||||
assert response.details is None
|
||||
|
||||
def test_error_response_with_details(self) -> None:
|
||||
response = ErrorResponse(
|
||||
status_code=500,
|
||||
message="Internal error",
|
||||
details={"field": "value"},
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert response.message == "Internal error"
|
||||
assert response.details == {"field": "value"}
|
||||
|
||||
|
||||
class TestAppExceptionHandler:
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_exception_handler(self) -> None:
|
||||
request = Mock(spec=Request)
|
||||
exc = AppException(message="Test error", status_code=400)
|
||||
|
||||
response = await app_exception_handler(request, exc)
|
||||
|
||||
assert response.status_code == 400
|
||||
body = bytes(response.body).decode()
|
||||
assert "Test error" in body
|
||||
assert "400" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_exception_handler_content(self) -> None:
|
||||
request = Mock(spec=Request)
|
||||
exc = AppException(message="Validation error", status_code=422)
|
||||
|
||||
with patch("app.common.error_handler.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value.isoformat.return_value = (
|
||||
"2024-01-01T00:00:00"
|
||||
)
|
||||
|
||||
response = await app_exception_handler(request, exc)
|
||||
|
||||
content = bytes(response.body).decode()
|
||||
assert "Validation error" in content
|
||||
assert "422" in content
|
||||
assert "2024-01-01T00:00:00" in content
|
||||
|
||||
|
||||
class TestHttpExceptionHandler:
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_exception_handler(self) -> None:
|
||||
request = Mock(spec=Request)
|
||||
exc = HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
response = await http_exception_handler(request, exc)
|
||||
|
||||
assert response.status_code == 404
|
||||
body = bytes(response.body).decode()
|
||||
assert "Not found" in body
|
||||
assert "404" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_exception_handler_content(self) -> None:
|
||||
request = Mock(spec=Request)
|
||||
exc = HTTPException(status_code=503, detail="Service unavailable")
|
||||
|
||||
with patch("app.common.error_handler.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value.isoformat.return_value = (
|
||||
"2024-01-01T12:00:00"
|
||||
)
|
||||
|
||||
response = await http_exception_handler(request, exc)
|
||||
|
||||
content = bytes(response.body).decode()
|
||||
assert "Service unavailable" in content
|
||||
assert "503" in content
|
||||
assert "2024-01-01T12:00:00" in content
|
||||
|
||||
|
||||
class TestRegisterExceptionHandlers:
|
||||
def test_register_exception_handlers(self) -> None:
|
||||
app = Mock(spec=FastAPI)
|
||||
|
||||
register_exception_handlers(app)
|
||||
|
||||
assert app.add_exception_handler.call_count == 2
|
||||
app.add_exception_handler.assert_any_call(AppException, app_exception_handler)
|
||||
app.add_exception_handler.assert_any_call(HTTPException, http_exception_handler)
|
||||
@@ -1,87 +0,0 @@
|
||||
from app.core.exceptions import (
|
||||
AppException,
|
||||
ForbiddenError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
|
||||
class TestAppException:
|
||||
def test_default_status_code(self) -> None:
|
||||
exc = AppException(message="Test error")
|
||||
assert exc.message == "Test error"
|
||||
assert exc.status_code == 500
|
||||
|
||||
def test_custom_status_code(self) -> None:
|
||||
exc = AppException(message="Custom error", status_code=400)
|
||||
assert exc.message == "Custom error"
|
||||
assert exc.status_code == 400
|
||||
|
||||
def test_string_representation(self) -> None:
|
||||
exc = AppException(message="Error message")
|
||||
assert str(exc) == "Error message"
|
||||
|
||||
|
||||
class TestNotFoundError:
|
||||
def test_default_message(self) -> None:
|
||||
exc = NotFoundError()
|
||||
assert exc.message == "Resource not found"
|
||||
assert exc.status_code == 404
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = NotFoundError(message="Item not found")
|
||||
assert exc.message == "Item not found"
|
||||
assert exc.status_code == 404
|
||||
|
||||
def test_is_subclass_of_app_exception(self) -> None:
|
||||
exc = NotFoundError()
|
||||
assert isinstance(exc, AppException)
|
||||
|
||||
|
||||
class TestValidationError:
|
||||
def test_default_message(self) -> None:
|
||||
exc = ValidationError()
|
||||
assert exc.message == "Validation failed"
|
||||
assert exc.status_code == 400
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = ValidationError(message="Invalid email format")
|
||||
assert exc.message == "Invalid email format"
|
||||
assert exc.status_code == 400
|
||||
|
||||
def test_is_subclass_of_app_exception(self) -> None:
|
||||
exc = ValidationError()
|
||||
assert isinstance(exc, AppException)
|
||||
|
||||
|
||||
class TestUnauthorizedError:
|
||||
def test_default_message(self) -> None:
|
||||
exc = UnauthorizedError()
|
||||
assert exc.message == "Unauthorized"
|
||||
assert exc.status_code == 401
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = UnauthorizedError(message="Invalid credentials")
|
||||
assert exc.message == "Invalid credentials"
|
||||
assert exc.status_code == 401
|
||||
|
||||
def test_is_subclass_of_app_exception(self) -> None:
|
||||
exc = UnauthorizedError()
|
||||
assert isinstance(exc, AppException)
|
||||
|
||||
|
||||
class TestForbiddenError:
|
||||
def test_default_message(self) -> None:
|
||||
exc = ForbiddenError()
|
||||
assert exc.message == "Forbidden"
|
||||
assert exc.status_code == 403
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = ForbiddenError(message="Access denied")
|
||||
assert exc.message == "Access denied"
|
||||
assert exc.status_code == 403
|
||||
|
||||
def test_is_subclass_of_app_exception(self) -> None:
|
||||
exc = ForbiddenError()
|
||||
assert isinstance(exc, AppException)
|
||||
49
tests/unit/test_main.py
Normal file
49
tests/unit/test_main.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Tests for main application."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.main import app_factory, lifespan, main
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan() -> None:
|
||||
"""Test lifespan context manager."""
|
||||
app = FastAPI()
|
||||
|
||||
with (
|
||||
patch("app.main.init_db") as mock_init,
|
||||
patch("app.main.close_db") as mock_close,
|
||||
):
|
||||
async with lifespan(app):
|
||||
mock_init.assert_called_once()
|
||||
mock_close.assert_not_called()
|
||||
mock_close.assert_called_once()
|
||||
|
||||
|
||||
def test_app_factory() -> None:
|
||||
"""Test app factory creates FastAPI app."""
|
||||
app = app_factory()
|
||||
assert isinstance(app, FastAPI)
|
||||
|
||||
|
||||
def test_app_factory_has_routes() -> None:
|
||||
"""Test app has registered routes."""
|
||||
app = app_factory()
|
||||
routes = [str(route.path) for route in app.routes if hasattr(route, "path")]
|
||||
assert "/health" in routes
|
||||
# Check that API routes are included
|
||||
assert any("api" in path for path in routes)
|
||||
|
||||
|
||||
@patch("app.main.uvicorn.run")
|
||||
def test_main(mock_uvicorn_run: Mock) -> None:
|
||||
"""Test main function starts uvicorn."""
|
||||
main()
|
||||
mock_uvicorn_run.assert_called_once()
|
||||
call_kwargs = mock_uvicorn_run.call_args.kwargs
|
||||
assert call_kwargs.get("factory") is True
|
||||
assert call_kwargs.get("host") == "0.0.0.0"
|
||||
assert call_kwargs.get("port") == 8000
|
||||
@@ -1,33 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.main import app_factory, lifespan, main
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan() -> None:
|
||||
app = FastAPI()
|
||||
assert isinstance(lifespan, asynccontextmanager(lifespan).__class__) # type: ignore[arg-type]
|
||||
|
||||
async with lifespan(app):
|
||||
pass
|
||||
|
||||
|
||||
def test_app_factory() -> None:
|
||||
app = app_factory()
|
||||
assert isinstance(app, FastAPI)
|
||||
assert app.router.lifespan_context == lifespan
|
||||
|
||||
|
||||
@patch("app.main.uvicorn.run")
|
||||
def test_main(mock_uvicorn_run: Mock) -> None:
|
||||
main()
|
||||
mock_uvicorn_run.assert_called_once_with(
|
||||
app_factory,
|
||||
factory=True,
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
)
|
||||
Reference in New Issue
Block a user