Auth impl #10
134
AGENTS.md
134
AGENTS.md
@@ -2,6 +2,7 @@
|
||||
|
||||
## Stack
|
||||
- Python 3.13+, FastAPI, pydantic, uvicorn
|
||||
- SQLAlchemy 2.0 (async), aiosqlite
|
||||
- Package manager: `uv`
|
||||
- CI: Woodpecker (lint, test, type on push/PR to `dev`)
|
||||
|
||||
@@ -20,25 +21,124 @@ uv run blog # Start dev server (port 8000)
|
||||
## Pre-commit order
|
||||
`ruff check --fix` → `ruff format` → `isort` → `mypy`
|
||||
|
||||
## Architecture
|
||||
## DDD Architecture
|
||||
|
||||
### Layer Structure
|
||||
```
|
||||
app/
|
||||
main.py # Entry point, uvicorn.run(app_factory)
|
||||
core/config.py # Settings from .env via pydantic-settings
|
||||
core/exceptions.py
|
||||
common/error_handler.py
|
||||
api/v1/
|
||||
modules/
|
||||
├── domain/ # Domain Layer - business logic, no dependencies
|
||||
│ ├── entities/ # Domain entities (Post, User, etc.)
|
||||
│ │ ├── base.py # Base entity class
|
||||
│ │ └── post.py # Post entity with business logic
|
||||
│ ├── value_objects/ # Value objects (Title, Content, Slug)
|
||||
│ │ ├── base.py
|
||||
│ │ ├── title.py
|
||||
│ │ ├── content.py
|
||||
│ │ └── slug.py
|
||||
│ ├── repositories/ # Repository interfaces (abstract)
|
||||
│ │ ├── base.py
|
||||
│ │ └── post.py
|
||||
│ └── exceptions.py # Domain exceptions
|
||||
│
|
||||
├── application/ # Application Layer - use cases
|
||||
│ ├── dtos/ # Data Transfer Objects
|
||||
│ │ └── post.py
|
||||
│ ├── interfaces/ # Abstract interfaces (UoW)
|
||||
│ │ └── unit_of_work.py
|
||||
│ └── use_cases/ # Use cases (CQRS-like)
|
||||
│ ├── create_post.py
|
||||
│ ├── get_post.py
|
||||
│ ├── update_post.py
|
||||
│ ├── delete_post.py
|
||||
│ ├── list_posts.py
|
||||
│ └── publish_post.py
|
||||
│
|
||||
├── infrastructure/ # Infrastructure Layer - external concerns
|
||||
│ ├── config/ # Configuration
|
||||
│ │ └── settings.py
|
||||
│ ├── database/ # Database connection & ORM models
|
||||
│ │ ├── connection.py
|
||||
│ │ └── models.py
|
||||
│ ├── repositories/ # Repository implementations
|
||||
│ │ ├── post.py # SQLAlchemyPostRepository
|
||||
│ │ └── unit_of_work.py # SQLAlchemyUnitOfWork
|
||||
│ ├── di/ # Dependency Injection
|
||||
│ │ └── container.py
|
||||
│ └── middleware/ # Exception handlers
|
||||
│ └── error_handler.py
|
||||
│
|
||||
├── presentation/ # Presentation Layer - API
|
||||
│ ├── api/ # FastAPI routes
|
||||
│ │ ├── v1/ # API version 1
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ └── posts.py # Posts endpoints
|
||||
│ │ ├── deps.py # FastAPI dependencies
|
||||
│ │ └── __init__.py
|
||||
│ └── schemas/ # Pydantic schemas
|
||||
│ └── post.py
|
||||
│
|
||||
└── main.py # Application entry point
|
||||
|
||||
tests/
|
||||
unit/
|
||||
integration/
|
||||
e2e/
|
||||
api/
|
||||
├── unit/ # Unit tests (domain, use cases)
|
||||
│ ├── domain/ # Domain layer tests
|
||||
│ ├── application/ # Application layer tests
|
||||
│ └── infrastructure/ # Infrastructure tests
|
||||
├── integration/ # Integration tests (DB, repos)
|
||||
├── api/ # API endpoint tests
|
||||
└── e2e/ # End-to-end tests
|
||||
```
|
||||
|
||||
## Key conventions
|
||||
- All commands use `uv run` prefix
|
||||
- pytest: asyncio_mode=auto, coverage on `app/`
|
||||
- mypy: strict=true with pydantic plugin
|
||||
- isort: black profile, filter_files=true
|
||||
- `.env` loaded by pydantic-settings (not in repo)
|
||||
## Key Conventions
|
||||
|
||||
### Dependency Rule
|
||||
- Domain layer has **NO dependencies** on other layers
|
||||
- Application layer depends only on Domain
|
||||
- Infrastructure depends on Domain and Application
|
||||
- Presentation depends on all other layers
|
||||
|
||||
### Testing
|
||||
- **Unit tests**: Test domain logic without DB/external services
|
||||
- **Integration tests**: Test repository implementations with real DB
|
||||
- **API tests**: Test endpoints with mocked use cases
|
||||
- **E2E tests**: Full workflow testing
|
||||
|
||||
### Code Patterns
|
||||
- Use **dataclasses** for entities and value objects
|
||||
- Use **frozen dataclasses** for value objects (immutable)
|
||||
- Use **Unit of Work** pattern for transactions
|
||||
- Use **Repository** pattern for data access
|
||||
- Use **Dependency Injection** via FastAPI's Depends()
|
||||
|
||||
## DDD Concepts Used
|
||||
|
||||
### Entities
|
||||
- Have identity (UUID)
|
||||
- Mutable state
|
||||
- Business logic methods (publish, update_title, etc.)
|
||||
- Example: `Post` entity
|
||||
|
||||
### Value Objects
|
||||
- Immutable
|
||||
- Defined by attributes
|
||||
- Validated on creation
|
||||
- Examples: `Title`, `Content`, `Slug`
|
||||
|
||||
### Aggregates & Repositories
|
||||
- `Post` is an aggregate root
|
||||
- `PostRepository` interface in Domain
|
||||
- `SQLAlchemyPostRepository` implementation in Infrastructure
|
||||
|
||||
### Domain Events
|
||||
- Placeholder for future implementation
|
||||
- Can be added via event bus in application layer
|
||||
|
||||
## Configuration
|
||||
- `.env` file loaded by pydantic-settings
|
||||
- Settings available via `app.infrastructure.config.settings`
|
||||
|
||||
## Database
|
||||
- SQLAlchemy 2.0 with async support
|
||||
- SQLite by default (aiosqlite)
|
||||
- Tables auto-created on startup
|
||||
- Use `init_db()` and `close_db()` in lifespan
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""API module - HTTP routes and endpoints."""
|
||||
@@ -1 +0,0 @@
|
||||
"""API version 1 endpoints."""
|
||||
28
app/application/__init__.py
Normal file
28
app/application/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Application layer exports."""
|
||||
|
||||
from app.application.dtos import CreatePostDTO, PostResponseDTO, UpdatePostDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.application.use_cases import (
|
||||
CreatePostUseCase,
|
||||
DeletePostUseCase,
|
||||
GetPostUseCase,
|
||||
ListPostsUseCase,
|
||||
PublishPostUseCase,
|
||||
UpdatePostUseCase,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# DTOs
|
||||
"CreatePostDTO",
|
||||
"UpdatePostDTO",
|
||||
"PostResponseDTO",
|
||||
# Interfaces
|
||||
"TransactionManager",
|
||||
# Use Cases
|
||||
"CreatePostUseCase",
|
||||
"GetPostUseCase",
|
||||
"UpdatePostUseCase",
|
||||
"DeletePostUseCase",
|
||||
"ListPostsUseCase",
|
||||
"PublishPostUseCase",
|
||||
]
|
||||
5
app/application/dtos/__init__.py
Normal file
5
app/application/dtos/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Application DTOs."""
|
||||
|
||||
from app.application.dtos.post import CreatePostDTO, PostResponseDTO, UpdatePostDTO
|
||||
|
||||
__all__ = ["CreatePostDTO", "UpdatePostDTO", "PostResponseDTO"]
|
||||
39
app/application/dtos/post.py
Normal file
39
app/application/dtos/post.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""DTOs for post use cases."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CreatePostDTO:
|
||||
"""DTO for creating a post."""
|
||||
|
||||
title: str
|
||||
content: str
|
||||
author_id: str
|
||||
tags: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UpdatePostDTO:
|
||||
"""DTO for updating a post."""
|
||||
|
||||
title: str | None = None
|
||||
content: str | None = None
|
||||
tags: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostResponseDTO:
|
||||
"""DTO for post response."""
|
||||
|
||||
id: UUID
|
||||
title: str
|
||||
content: str
|
||||
slug: str
|
||||
author_id: str
|
||||
published: bool
|
||||
tags: list[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
5
app/application/interfaces/__init__.py
Normal file
5
app/application/interfaces/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Application interfaces."""
|
||||
|
||||
from app.application.interfaces.transaction_manager import TransactionManager
|
||||
|
||||
__all__ = ["TransactionManager"]
|
||||
17
app/application/interfaces/transaction_manager.py
Normal file
17
app/application/interfaces/transaction_manager.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Transaction Manager interface for managing database transactions."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class TransactionManager(ABC):
|
||||
"""Abstract Transaction Manager for controlling transaction boundaries."""
|
||||
|
||||
@abstractmethod
|
||||
async def commit(self) -> None:
|
||||
"""Commit the current transaction."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def rollback(self) -> None:
|
||||
"""Rollback the current transaction."""
|
||||
...
|
||||
17
app/application/use_cases/__init__.py
Normal file
17
app/application/use_cases/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Use cases."""
|
||||
|
||||
from app.application.use_cases.create_post import CreatePostUseCase
|
||||
from app.application.use_cases.delete_post import DeletePostUseCase
|
||||
from app.application.use_cases.get_post import GetPostUseCase
|
||||
from app.application.use_cases.list_posts import ListPostsUseCase
|
||||
from app.application.use_cases.publish_post import PublishPostUseCase
|
||||
from app.application.use_cases.update_post import UpdatePostUseCase
|
||||
|
||||
__all__ = [
|
||||
"CreatePostUseCase",
|
||||
"GetPostUseCase",
|
||||
"UpdatePostUseCase",
|
||||
"DeletePostUseCase",
|
||||
"ListPostsUseCase",
|
||||
"PublishPostUseCase",
|
||||
]
|
||||
62
app/application/use_cases/create_post.py
Normal file
62
app/application/use_cases/create_post.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Create post use case."""
|
||||
|
||||
from app.application.dtos.post import CreatePostDTO, PostResponseDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import AlreadyExistsException
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class CreatePostUseCase:
|
||||
"""Use case for creating a new blog post."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def execute(self, dto: CreatePostDTO) -> PostResponseDTO:
|
||||
"""Execute the use case."""
|
||||
# Generate slug from title
|
||||
from app.domain.value_objects import Slug
|
||||
|
||||
slug = Slug.from_title(dto.title)
|
||||
|
||||
# Check if slug already exists
|
||||
if await self._post_repo.slug_exists(slug.value):
|
||||
raise AlreadyExistsException(
|
||||
f"Post with slug '{slug.value}' already exists"
|
||||
)
|
||||
|
||||
# Create domain entity
|
||||
post = Post.create(
|
||||
title_str=dto.title,
|
||||
content_str=dto.content,
|
||||
author_id=dto.author_id,
|
||||
tags=dto.tags or [],
|
||||
)
|
||||
|
||||
# Persist entity
|
||||
await self._post_repo.add(post)
|
||||
|
||||
# Commit transaction
|
||||
await self._tx_manager.commit()
|
||||
|
||||
return self._map_to_dto(post)
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
35
app/application/use_cases/delete_post.py
Normal file
35
app/application/use_cases/delete_post.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Delete post use case."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.exceptions import ForbiddenException, NotFoundException
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class DeletePostUseCase:
|
||||
"""Use case for deleting a blog post."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def execute(self, post_id: UUID, current_user_id: str) -> None:
|
||||
"""Execute the use case."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
|
||||
# Check authorization
|
||||
if post.author_id != current_user_id:
|
||||
raise ForbiddenException("You can only delete your own posts")
|
||||
|
||||
# Delete the post
|
||||
await self._post_repo.delete(post_id)
|
||||
|
||||
# Commit transaction
|
||||
await self._tx_manager.commit()
|
||||
49
app/application/use_cases/get_post.py
Normal file
49
app/application/use_cases/get_post.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Get post use case."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.application.dtos.post import PostResponseDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import NotFoundException
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class GetPostUseCase:
|
||||
"""Use case for retrieving a post by ID or slug."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def by_id(self, post_id: UUID) -> PostResponseDTO:
|
||||
"""Get post by ID."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
return self._map_to_dto(post)
|
||||
|
||||
async def by_slug(self, slug: str) -> PostResponseDTO:
|
||||
"""Get post by slug."""
|
||||
post = await self._post_repo.get_by_slug(slug)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with slug '{slug}' not found")
|
||||
return self._map_to_dto(post)
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
57
app/application/use_cases/list_posts.py
Normal file
57
app/application/use_cases/list_posts.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""List posts use case."""
|
||||
|
||||
from app.application.dtos.post import PostResponseDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class ListPostsUseCase:
|
||||
"""Use case for listing blog posts with filtering."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def all_posts(self) -> list[PostResponseDTO]:
|
||||
"""Get all posts."""
|
||||
posts = await self._post_repo.get_all()
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
async def published_posts(self) -> list[PostResponseDTO]:
|
||||
"""Get all published posts."""
|
||||
posts = await self._post_repo.get_published()
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
async def by_author(self, author_id: str) -> list[PostResponseDTO]:
|
||||
"""Get posts by author."""
|
||||
posts = await self._post_repo.get_by_author(author_id)
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
async def by_tag(self, tag: str) -> list[PostResponseDTO]:
|
||||
"""Get posts by tag."""
|
||||
posts = await self._post_repo.get_by_tag(tag)
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
async def search(self, query: str) -> list[PostResponseDTO]:
|
||||
"""Search posts."""
|
||||
posts = await self._post_repo.search(query)
|
||||
return [self._map_to_dto(post) for post in posts]
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
65
app/application/use_cases/publish_post.py
Normal file
65
app/application/use_cases/publish_post.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Publish post use case."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.application.dtos.post import PostResponseDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import ForbiddenException, NotFoundException
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
class PublishPostUseCase:
|
||||
"""Use case for publishing/unpublishing a blog post."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def publish(self, post_id: UUID, current_user_id: str) -> PostResponseDTO:
|
||||
"""Publish a post."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
|
||||
if post.author_id != current_user_id:
|
||||
raise ForbiddenException("You can only publish your own posts")
|
||||
|
||||
post.publish()
|
||||
await self._post_repo.update(post)
|
||||
await self._tx_manager.commit()
|
||||
|
||||
return self._map_to_dto(post)
|
||||
|
||||
async def unpublish(self, post_id: UUID, current_user_id: str) -> PostResponseDTO:
|
||||
"""Unpublish a post."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
|
||||
if post.author_id != current_user_id:
|
||||
raise ForbiddenException("You can only unpublish your own posts")
|
||||
|
||||
post.unpublish()
|
||||
await self._post_repo.update(post)
|
||||
await self._tx_manager.commit()
|
||||
|
||||
return self._map_to_dto(post)
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
73
app/application/use_cases/update_post.py
Normal file
73
app/application/use_cases/update_post.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Update post use case."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.application.dtos.post import PostResponseDTO, UpdatePostDTO
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import ForbiddenException, NotFoundException
|
||||
from app.domain.repositories import PostRepository
|
||||
from app.domain.value_objects import Content, Title
|
||||
|
||||
|
||||
class UpdatePostUseCase:
|
||||
"""Use case for updating a blog post."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> None:
|
||||
self._post_repo = post_repo
|
||||
self._tx_manager = tx_manager
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
post_id: UUID,
|
||||
dto: UpdatePostDTO,
|
||||
current_user_id: str,
|
||||
) -> PostResponseDTO:
|
||||
"""Execute the use case."""
|
||||
post = await self._post_repo.get_by_id(post_id)
|
||||
if not post:
|
||||
raise NotFoundException(f"Post with id '{post_id}' not found")
|
||||
|
||||
# Check authorization
|
||||
if post.author_id != current_user_id:
|
||||
raise ForbiddenException("You can only update your own posts")
|
||||
|
||||
# Update fields
|
||||
if dto.title is not None:
|
||||
post.update_title(Title(dto.title))
|
||||
|
||||
if dto.content is not None:
|
||||
post.update_content(Content(dto.content))
|
||||
|
||||
if dto.tags is not None:
|
||||
# Replace all tags
|
||||
for tag in post.tags[:]:
|
||||
post.remove_tag(tag)
|
||||
for tag in dto.tags:
|
||||
post.add_tag(tag)
|
||||
|
||||
# Persist changes
|
||||
await self._post_repo.update(post)
|
||||
|
||||
# Commit transaction
|
||||
await self._tx_manager.commit()
|
||||
|
||||
return self._map_to_dto(post)
|
||||
|
||||
def _map_to_dto(self, post: Post) -> PostResponseDTO:
|
||||
"""Map domain entity to response DTO."""
|
||||
return PostResponseDTO(
|
||||
id=post.id,
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags.copy(),
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
"""Common utilities and shared components."""
|
||||
@@ -1,48 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from app.core.exceptions import AppException
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
status_code: int
|
||||
message: str
|
||||
details: dict[str, str] | None = None
|
||||
timestamp: str
|
||||
|
||||
|
||||
async def app_exception_handler(request: Request, exc: AppException) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"status_code": exc.status_code,
|
||||
"message": exc.message,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"status_code": exc.status_code,
|
||||
"message": str(exc.detail),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
app.add_exception_handler(
|
||||
AppException,
|
||||
app_exception_handler, # type: ignore[arg-type]
|
||||
)
|
||||
app.add_exception_handler(
|
||||
HTTPException,
|
||||
http_exception_handler, # type: ignore[arg-type]
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
"""Core module - shared functionality and configuration."""
|
||||
@@ -1,15 +0,0 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
app_name: str = "Blog API"
|
||||
debug: bool = False
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
database_url: str | None = None
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,25 +0,0 @@
|
||||
class AppException(Exception):
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class NotFoundError(AppException):
|
||||
def __init__(self, message: str = "Resource not found"):
|
||||
super().__init__(message, status_code=404)
|
||||
|
||||
|
||||
class ValidationError(AppException):
|
||||
def __init__(self, message: str = "Validation failed"):
|
||||
super().__init__(message, status_code=400)
|
||||
|
||||
|
||||
class UnauthorizedError(AppException):
|
||||
def __init__(self, message: str = "Unauthorized"):
|
||||
super().__init__(message, status_code=401)
|
||||
|
||||
|
||||
class ForbiddenError(AppException):
|
||||
def __init__(self, message: str = "Forbidden"):
|
||||
super().__init__(message, status_code=403)
|
||||
34
app/domain/__init__.py
Normal file
34
app/domain/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Domain layer exports."""
|
||||
|
||||
from app.domain.entities import BaseEntity, Post
|
||||
from app.domain.exceptions import (
|
||||
AlreadyExistsException,
|
||||
DomainException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
UnauthorizedException,
|
||||
ValidationException,
|
||||
)
|
||||
from app.domain.repositories import PostRepository, Repository
|
||||
from app.domain.value_objects import Content, Slug, Title, ValueObject
|
||||
|
||||
__all__ = [
|
||||
# Entities
|
||||
"BaseEntity",
|
||||
"Post",
|
||||
# Value Objects
|
||||
"ValueObject",
|
||||
"Title",
|
||||
"Content",
|
||||
"Slug",
|
||||
# Repositories
|
||||
"Repository",
|
||||
"PostRepository",
|
||||
# Exceptions
|
||||
"DomainException",
|
||||
"ValidationException",
|
||||
"NotFoundException",
|
||||
"AlreadyExistsException",
|
||||
"UnauthorizedException",
|
||||
"ForbiddenException",
|
||||
]
|
||||
6
app/domain/entities/__init__.py
Normal file
6
app/domain/entities/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Domain entities."""
|
||||
|
||||
from app.domain.entities.base import BaseEntity
|
||||
from app.domain.entities.post import Post
|
||||
|
||||
__all__ = ["BaseEntity", "Post"]
|
||||
33
app/domain/entities/base.py
Normal file
33
app/domain/entities/base.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Base entity for DDD domain layer."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class BaseEntity(ABC):
|
||||
"""Base class for all domain entities."""
|
||||
|
||||
id: UUID = field(default_factory=uuid4)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, BaseEntity):
|
||||
return NotImplemented
|
||||
return self.id == other.id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.id)
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Update the updated_at timestamp."""
|
||||
self.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert entity to dictionary."""
|
||||
...
|
||||
88
app/domain/entities/post.py
Normal file
88
app/domain/entities/post.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Domain entity for Blog Post."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from app.domain.entities.base import BaseEntity
|
||||
from app.domain.value_objects.content import Content
|
||||
from app.domain.value_objects.slug import Slug
|
||||
from app.domain.value_objects.title import Title
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Post(BaseEntity):
|
||||
"""Blog post domain entity."""
|
||||
|
||||
title: Title
|
||||
content: Content
|
||||
slug: Slug
|
||||
author_id: str
|
||||
published: bool = False
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
def publish(self) -> None:
|
||||
"""Publish the post."""
|
||||
self.published = True
|
||||
self.touch()
|
||||
|
||||
def unpublish(self) -> None:
|
||||
"""Unpublish the post."""
|
||||
self.published = False
|
||||
self.touch()
|
||||
|
||||
def update_content(self, content: Content) -> None:
|
||||
"""Update post content."""
|
||||
self.content = content
|
||||
self.touch()
|
||||
|
||||
def update_title(self, title: Title) -> None:
|
||||
"""Update post title and regenerate slug."""
|
||||
self.title = title
|
||||
self.slug = Slug.from_title(title.value)
|
||||
self.touch()
|
||||
|
||||
def add_tag(self, tag: str) -> None:
|
||||
"""Add a tag to the post."""
|
||||
if tag not in self.tags:
|
||||
self.tags.append(tag)
|
||||
self.touch()
|
||||
|
||||
def remove_tag(self, tag: str) -> None:
|
||||
"""Remove a tag from the post."""
|
||||
if tag in self.tags:
|
||||
self.tags.remove(tag)
|
||||
self.touch()
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert entity to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"title": self.title.value,
|
||||
"content": self.content.value,
|
||||
"slug": self.slug.value,
|
||||
"author_id": self.author_id,
|
||||
"published": self.published,
|
||||
"tags": self.tags.copy(),
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
title_str: str,
|
||||
content_str: str,
|
||||
author_id: str,
|
||||
tags: list[str] | None = None,
|
||||
) -> "Post":
|
||||
"""Factory method to create a new post."""
|
||||
title = Title(title_str)
|
||||
content = Content(content_str)
|
||||
slug = Slug.from_title(title_str)
|
||||
return cls(
|
||||
title=title,
|
||||
content=content,
|
||||
slug=slug,
|
||||
author_id=author_id,
|
||||
tags=tags or [],
|
||||
)
|
||||
39
app/domain/exceptions.py
Normal file
39
app/domain/exceptions.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Domain exceptions."""
|
||||
|
||||
|
||||
class DomainException(Exception):
|
||||
"""Base exception for domain layer."""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ValidationException(DomainException):
|
||||
"""Raised when validation fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NotFoundException(DomainException):
|
||||
"""Raised when an entity is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AlreadyExistsException(DomainException):
|
||||
"""Raised when trying to create an entity that already exists."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnauthorizedException(DomainException):
|
||||
"""Raised when user is not authorized."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ForbiddenException(DomainException):
|
||||
"""Raised when access is forbidden."""
|
||||
|
||||
pass
|
||||
6
app/domain/repositories/__init__.py
Normal file
6
app/domain/repositories/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Repository interfaces."""
|
||||
|
||||
from app.domain.repositories.base import Repository
|
||||
from app.domain.repositories.post import PostRepository
|
||||
|
||||
__all__ = ["Repository", "PostRepository"]
|
||||
43
app/domain/repositories/base.py
Normal file
43
app/domain/repositories/base.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Base repository interface for DDD."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.domain.entities.base import BaseEntity
|
||||
|
||||
T = TypeVar("T", bound=BaseEntity)
|
||||
|
||||
|
||||
class Repository(ABC, Generic[T]):
|
||||
"""Generic repository interface."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_id(self, entity_id: UUID) -> T | None:
|
||||
"""Get entity by ID."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_all(self) -> list[T]:
|
||||
"""Get all entities."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, entity: T) -> None:
|
||||
"""Add new entity."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update(self, entity: T) -> None:
|
||||
"""Update existing entity."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, entity_id: UUID) -> None:
|
||||
"""Delete entity by ID."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, entity_id: UUID) -> bool:
|
||||
"""Check if entity exists."""
|
||||
...
|
||||
40
app/domain/repositories/post.py
Normal file
40
app/domain/repositories/post.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Post repository interface."""
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from app.domain.entities.post import Post
|
||||
from app.domain.repositories.base import Repository
|
||||
|
||||
|
||||
class PostRepository(Repository[Post]):
|
||||
"""Repository interface for Blog Posts."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_slug(self, slug: str) -> Post | None:
|
||||
"""Get post by slug."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_author(self, author_id: str) -> list[Post]:
|
||||
"""Get all posts by author."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_published(self) -> list[Post]:
|
||||
"""Get all published posts."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_tag(self, tag: str) -> list[Post]:
|
||||
"""Get posts by tag."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def slug_exists(self, slug: str) -> bool:
|
||||
"""Check if slug already exists."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, query: str) -> list[Post]:
|
||||
"""Search posts by query string."""
|
||||
...
|
||||
8
app/domain/value_objects/__init__.py
Normal file
8
app/domain/value_objects/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Value objects."""
|
||||
|
||||
from app.domain.value_objects.base import ValueObject
|
||||
from app.domain.value_objects.content import Content
|
||||
from app.domain.value_objects.slug import Slug
|
||||
from app.domain.value_objects.title import Title
|
||||
|
||||
__all__ = ["ValueObject", "Title", "Content", "Slug"]
|
||||
37
app/domain/value_objects/base.py
Normal file
37
app/domain/value_objects/base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Base value object for DDD domain layer."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ValueObject(ABC, Generic[T]):
|
||||
"""Base class for all value objects."""
|
||||
|
||||
value: T
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._validate()
|
||||
|
||||
@abstractmethod
|
||||
def _validate(self) -> None:
|
||||
"""Validate the value object. Raise ValueError if invalid."""
|
||||
...
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, ValueObject):
|
||||
return False
|
||||
return bool(self.value == other.value)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
def to_primitive(self) -> Any:
|
||||
"""Convert value object to primitive type."""
|
||||
return self.value
|
||||
23
app/domain/value_objects/content.py
Normal file
23
app/domain/value_objects/content.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Content value object."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.domain.value_objects.base import ValueObject
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Content(ValueObject[str]):
|
||||
"""Blog post content value object."""
|
||||
|
||||
MIN_LENGTH: int = 10
|
||||
MAX_LENGTH: int = 50000
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not isinstance(self.value, str):
|
||||
raise ValueError("Content must be a string")
|
||||
if not self.value.strip():
|
||||
raise ValueError("Content cannot be empty or whitespace")
|
||||
if len(self.value) < self.MIN_LENGTH:
|
||||
raise ValueError(f"Content must be at least {self.MIN_LENGTH} characters")
|
||||
if len(self.value) > self.MAX_LENGTH:
|
||||
raise ValueError(f"Content must be at most {self.MAX_LENGTH} characters")
|
||||
41
app/domain/value_objects/slug.py
Normal file
41
app/domain/value_objects/slug.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Slug value object for URL-friendly identifiers."""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.domain.value_objects.base import ValueObject
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Slug(ValueObject[str]):
|
||||
"""URL slug value object."""
|
||||
|
||||
MAX_LENGTH: int = 200
|
||||
SLUG_PATTERN: str = r"^[a-z0-9]+(?:-[a-z0-9]+)*$"
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not isinstance(self.value, str):
|
||||
raise ValueError("Slug must be a string")
|
||||
if len(self.value) > self.MAX_LENGTH:
|
||||
raise ValueError(f"Slug must be at most {self.MAX_LENGTH} characters")
|
||||
if not re.match(self.SLUG_PATTERN, self.value):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_title(cls, title: str) -> "Slug":
|
||||
"""Generate slug from title."""
|
||||
# Convert to lowercase, replace spaces with hyphens
|
||||
slug = title.lower().strip()
|
||||
# Keep only alphanumeric, spaces, and hyphens
|
||||
slug = re.sub(r"[^a-z0-9\s-]", "", slug)
|
||||
# Replace spaces and multiple hyphens with single hyphen
|
||||
slug = re.sub(r"[-\s]+", "-", slug)
|
||||
# Limit length and strip hyphens
|
||||
max_len = 200 # Same as MAX_LENGTH
|
||||
slug = slug[:max_len].strip("-")
|
||||
# Ensure we have at least one character
|
||||
if not slug:
|
||||
slug = "post"
|
||||
return cls(value=slug)
|
||||
23
app/domain/value_objects/title.py
Normal file
23
app/domain/value_objects/title.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Title value object."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.domain.value_objects.base import ValueObject
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Title(ValueObject[str]):
|
||||
"""Blog post title value object."""
|
||||
|
||||
MIN_LENGTH: int = 3
|
||||
MAX_LENGTH: int = 200
|
||||
|
||||
def _validate(self) -> None:
|
||||
if not isinstance(self.value, str):
|
||||
raise ValueError("Title must be a string")
|
||||
if len(self.value) < self.MIN_LENGTH:
|
||||
raise ValueError(f"Title must be at least {self.MIN_LENGTH} characters")
|
||||
if len(self.value) > self.MAX_LENGTH:
|
||||
raise ValueError(f"Title must be at most {self.MAX_LENGTH} characters")
|
||||
if not self.value.strip():
|
||||
raise ValueError("Title cannot be empty or whitespace")
|
||||
35
app/infrastructure/__init__.py
Normal file
35
app/infrastructure/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Infrastructure layer exports."""
|
||||
|
||||
from app.infrastructure.config import Settings, settings
|
||||
from app.infrastructure.database import (
|
||||
AsyncSessionLocal,
|
||||
Base,
|
||||
PostORM,
|
||||
close_db,
|
||||
engine,
|
||||
get_session,
|
||||
init_db,
|
||||
)
|
||||
from app.infrastructure.di import create_container
|
||||
from app.infrastructure.middleware import register_exception_handlers
|
||||
from app.infrastructure.repositories import SQLAlchemyPostRepository
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"Settings",
|
||||
"settings",
|
||||
# Database
|
||||
"Base",
|
||||
"PostORM",
|
||||
"engine",
|
||||
"AsyncSessionLocal",
|
||||
"get_session",
|
||||
"init_db",
|
||||
"close_db",
|
||||
# Repositories
|
||||
"SQLAlchemyPostRepository",
|
||||
# DI
|
||||
"create_container",
|
||||
# Middleware
|
||||
"register_exception_handlers",
|
||||
]
|
||||
5
app/infrastructure/config/__init__.py
Normal file
5
app/infrastructure/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Infrastructure configuration."""
|
||||
|
||||
from app.infrastructure.config.settings import Settings, settings
|
||||
|
||||
__all__ = ["Settings", "settings"]
|
||||
31
app/infrastructure/config/settings.py
Normal file
31
app/infrastructure/config/settings.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Application settings."""
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application configuration settings."""
|
||||
|
||||
# App settings
|
||||
app_name: str = "Blog API"
|
||||
debug: bool = False
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
# Database settings
|
||||
database_url: str = "sqlite:///./blog.db"
|
||||
database_echo: bool = False
|
||||
|
||||
# Security settings
|
||||
secret_key: str = "your-secret-key-change-in-production"
|
||||
access_token_expire_minutes: int = 30
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = Settings()
|
||||
22
app/infrastructure/database/__init__.py
Normal file
22
app/infrastructure/database/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Database infrastructure."""
|
||||
|
||||
from app.infrastructure.database.connection import (
|
||||
AsyncSessionLocal,
|
||||
close_db,
|
||||
engine,
|
||||
get_session,
|
||||
get_session_context,
|
||||
init_db,
|
||||
)
|
||||
from app.infrastructure.database.models import Base, PostORM
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"PostORM",
|
||||
"engine",
|
||||
"AsyncSessionLocal",
|
||||
"get_session",
|
||||
"get_session_context",
|
||||
"init_db",
|
||||
"close_db",
|
||||
]
|
||||
70
app/infrastructure/database/connection.py
Normal file
70
app/infrastructure/database/connection.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Database connection and session management."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.infrastructure.config import settings
|
||||
|
||||
|
||||
# Convert SQLite URL to async format if needed
|
||||
def _get_database_url() -> str:
|
||||
url = settings.database_url
|
||||
if url.startswith("sqlite:///") and not url.startswith("sqlite+aiosqlite:///"):
|
||||
return url.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||||
return url
|
||||
|
||||
|
||||
# Create async engine
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
_get_database_url(),
|
||||
echo=settings.database_echo,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get database session."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session_context() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get database session as context manager."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Initialize database tables."""
|
||||
from app.infrastructure.database.models import Base
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""Close database connections."""
|
||||
await engine.dispose()
|
||||
40
app/infrastructure/database/models.py
Normal file
40
app/infrastructure/database/models.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""SQLAlchemy ORM models."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import JSON, Boolean, DateTime, String, Text
|
||||
from sqlalchemy.orm import Mapped, declarative_base, mapped_column
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class PostORM(Base): # type: ignore[valid-type,misc]
|
||||
"""SQLAlchemy model for Blog Post."""
|
||||
|
||||
__tablename__ = "posts"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid4())
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
slug: Mapped[str] = mapped_column(
|
||||
String(200), nullable=False, unique=True, index=True
|
||||
)
|
||||
author_id: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||
published: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False, index=True
|
||||
)
|
||||
tags: Mapped[list[str]] = mapped_column(JSON, default=list)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
7
app/infrastructure/di/__init__.py
Normal file
7
app/infrastructure/di/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Dependency Injection using Dishka."""
|
||||
|
||||
from app.infrastructure.di.container import create_container
|
||||
|
||||
__all__ = [
|
||||
"create_container",
|
||||
]
|
||||
20
app/infrastructure/di/container.py
Normal file
20
app/infrastructure/di/container.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Dishka container setup."""
|
||||
|
||||
from dishka import AsyncContainer, make_async_container
|
||||
|
||||
from app.infrastructure.di.providers import (
|
||||
DatabaseProvider,
|
||||
RepositoryProvider,
|
||||
TransactionManagerProvider,
|
||||
UseCaseProvider,
|
||||
)
|
||||
|
||||
|
||||
def create_container() -> AsyncContainer:
|
||||
"""Create and configure Dishka container."""
|
||||
return make_async_container(
|
||||
DatabaseProvider(),
|
||||
RepositoryProvider(),
|
||||
TransactionManagerProvider(),
|
||||
UseCaseProvider(),
|
||||
)
|
||||
133
app/infrastructure/di/providers.py
Normal file
133
app/infrastructure/di/providers.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Dishka providers for dependency injection."""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from dishka import Provider, Scope, provide
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||
|
||||
from app.application import (
|
||||
CreatePostUseCase,
|
||||
DeletePostUseCase,
|
||||
GetPostUseCase,
|
||||
ListPostsUseCase,
|
||||
PublishPostUseCase,
|
||||
UpdatePostUseCase,
|
||||
)
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.repositories import PostRepository
|
||||
from app.infrastructure.database.connection import AsyncSessionLocal, engine
|
||||
from app.infrastructure.repositories.post import SQLAlchemyPostRepository
|
||||
|
||||
|
||||
class DatabaseProvider(Provider):
|
||||
"""Provider for database-related dependencies."""
|
||||
|
||||
@provide(scope=Scope.APP)
|
||||
def get_engine(self) -> AsyncEngine:
|
||||
"""Provide SQLAlchemy engine."""
|
||||
return engine
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Provide database session per request."""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
class RepositoryProvider(Provider):
|
||||
"""Provider for repository implementations."""
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_post_repository(self, session: AsyncSession) -> PostRepository:
|
||||
"""Provide PostRepository implementation."""
|
||||
return SQLAlchemyPostRepository(session)
|
||||
|
||||
|
||||
class TransactionManagerProvider(Provider):
|
||||
"""Provider for transaction manager."""
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_transaction_manager(self, session: AsyncSession) -> TransactionManager:
|
||||
"""Provide TransactionManager implementation."""
|
||||
from app.infrastructure.di.transaction_manager import SessionTransactionManager
|
||||
|
||||
return SessionTransactionManager(session)
|
||||
|
||||
|
||||
class UseCaseProvider(Provider):
|
||||
"""Provider for use cases."""
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_create_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> CreatePostUseCase:
|
||||
"""Provide CreatePostUseCase."""
|
||||
return CreatePostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_get_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> GetPostUseCase:
|
||||
"""Provide GetPostUseCase."""
|
||||
return GetPostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_update_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> UpdatePostUseCase:
|
||||
"""Provide UpdatePostUseCase."""
|
||||
return UpdatePostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_delete_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> DeletePostUseCase:
|
||||
"""Provide DeletePostUseCase."""
|
||||
return DeletePostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_list_posts_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> ListPostsUseCase:
|
||||
"""Provide ListPostsUseCase."""
|
||||
return ListPostsUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
|
||||
@provide(scope=Scope.REQUEST)
|
||||
def get_publish_post_use_case(
|
||||
self,
|
||||
post_repo: PostRepository,
|
||||
tx_manager: TransactionManager,
|
||||
) -> PublishPostUseCase:
|
||||
"""Provide PublishPostUseCase."""
|
||||
return PublishPostUseCase(
|
||||
post_repo=post_repo,
|
||||
tx_manager=tx_manager,
|
||||
)
|
||||
24
app/infrastructure/di/transaction_manager.py
Normal file
24
app/infrastructure/di/transaction_manager.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""SQLAlchemy implementation of Transaction Manager."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.application.interfaces import TransactionManager
|
||||
|
||||
|
||||
class SessionTransactionManager(TransactionManager):
|
||||
"""SQLAlchemy Session-based Transaction Manager."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._committed: bool = False
|
||||
|
||||
async def commit(self) -> None:
|
||||
"""Commit the current transaction."""
|
||||
if not self._committed:
|
||||
await self._session.commit()
|
||||
self._committed = True
|
||||
|
||||
async def rollback(self) -> None:
|
||||
"""Rollback the current transaction."""
|
||||
if not self._committed:
|
||||
await self._session.rollback()
|
||||
15
app/infrastructure/middleware/__init__.py
Normal file
15
app/infrastructure/middleware/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Infrastructure middleware."""
|
||||
|
||||
from app.infrastructure.middleware.error_handler import (
|
||||
domain_exception_handler,
|
||||
generic_exception_handler,
|
||||
http_exception_handler,
|
||||
register_exception_handlers,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"domain_exception_handler",
|
||||
"http_exception_handler",
|
||||
"generic_exception_handler",
|
||||
"register_exception_handlers",
|
||||
]
|
||||
93
app/infrastructure/middleware/error_handler.py
Normal file
93
app/infrastructure/middleware/error_handler.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Exception handling middleware."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
|
||||
from app.domain.exceptions import (
|
||||
AlreadyExistsException,
|
||||
DomainException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
UnauthorizedException,
|
||||
ValidationException,
|
||||
)
|
||||
|
||||
|
||||
def get_status_code(exc: DomainException) -> int:
|
||||
"""Map domain exceptions to HTTP status codes."""
|
||||
match exc:
|
||||
case ValidationException():
|
||||
return 400
|
||||
case UnauthorizedException():
|
||||
return 401
|
||||
case ForbiddenException():
|
||||
return 403
|
||||
case NotFoundException():
|
||||
return 404
|
||||
case AlreadyExistsException():
|
||||
return 409
|
||||
case _:
|
||||
return 500
|
||||
|
||||
|
||||
async def domain_exception_handler(
|
||||
request: Request, exc: DomainException
|
||||
) -> JSONResponse:
|
||||
"""Handle domain exceptions."""
|
||||
status_code = get_status_code(exc)
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": exc.__class__.__name__,
|
||||
"message": exc.message,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"path": str(request.url.path),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def http_exception_handler(
|
||||
request: Request, exc: StarletteHTTPException
|
||||
) -> JSONResponse:
|
||||
"""Handle HTTP exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"error": "HTTPException",
|
||||
"message": str(exc.detail),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"path": str(request.url.path),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def generic_exception_handler(request: Request, exc: Exception) -> JSONResponse:
|
||||
"""Handle generic exceptions."""
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": "InternalServerError",
|
||||
"message": "An unexpected error occurred",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"path": str(request.url.path),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register all exception handlers with FastAPI app."""
|
||||
if not isinstance(app, FastAPI):
|
||||
raise TypeError("app must be a FastAPI instance")
|
||||
|
||||
# Domain exceptions
|
||||
app.add_exception_handler(DomainException, domain_exception_handler) # type: ignore[arg-type]
|
||||
|
||||
# HTTP exceptions
|
||||
app.add_exception_handler(StarletteHTTPException, http_exception_handler) # type: ignore[arg-type]
|
||||
|
||||
# Generic exceptions (only in production)
|
||||
# In development, let FastAPI show detailed traceback
|
||||
# app.add_exception_handler(Exception, generic_exception_handler)
|
||||
5
app/infrastructure/repositories/__init__.py
Normal file
5
app/infrastructure/repositories/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Repository implementations."""
|
||||
|
||||
from app.infrastructure.repositories.post import SQLAlchemyPostRepository
|
||||
|
||||
__all__ = ["SQLAlchemyPostRepository"]
|
||||
151
app/infrastructure/repositories/post.py
Normal file
151
app/infrastructure/repositories/post.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""SQLAlchemy implementation of PostRepository."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.entities import Post
|
||||
from app.domain.repositories import PostRepository
|
||||
from app.domain.value_objects import Content, Slug, Title
|
||||
from app.infrastructure.database.models import PostORM
|
||||
|
||||
|
||||
class SQLAlchemyPostRepository(PostRepository):
|
||||
"""SQLAlchemy implementation of Post repository."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
def _to_domain(self, orm: PostORM) -> Post:
|
||||
"""Convert ORM model to domain entity."""
|
||||
return Post(
|
||||
id=UUID(orm.id),
|
||||
title=Title(orm.title),
|
||||
content=Content(orm.content),
|
||||
slug=Slug(orm.slug),
|
||||
author_id=orm.author_id,
|
||||
published=orm.published,
|
||||
tags=orm.tags or [],
|
||||
created_at=orm.created_at,
|
||||
updated_at=orm.updated_at,
|
||||
)
|
||||
|
||||
def _to_orm(self, post: Post) -> PostORM:
|
||||
"""Convert domain entity to ORM model."""
|
||||
return PostORM(
|
||||
id=str(post.id),
|
||||
title=post.title.value,
|
||||
content=post.content.value,
|
||||
slug=post.slug.value,
|
||||
author_id=post.author_id,
|
||||
published=post.published,
|
||||
tags=post.tags,
|
||||
created_at=post.created_at,
|
||||
updated_at=post.updated_at,
|
||||
)
|
||||
|
||||
async def get_by_id(self, entity_id: UUID) -> Post | None:
|
||||
"""Get post by ID."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.id == str(entity_id))
|
||||
)
|
||||
orm = result.scalar_one_or_none()
|
||||
return self._to_domain(orm) if orm else None
|
||||
|
||||
async def get_all(self) -> list[Post]:
|
||||
"""Get all posts."""
|
||||
result = await self._session.execute(select(PostORM))
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
|
||||
async def add(self, entity: Post) -> None:
|
||||
"""Add new post."""
|
||||
orm = self._to_orm(entity)
|
||||
self._session.add(orm)
|
||||
# Commit делает TransactionManager
|
||||
|
||||
async def update(self, entity: Post) -> None:
|
||||
"""Update existing post."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.id == str(entity.id))
|
||||
)
|
||||
orm = result.scalar_one()
|
||||
|
||||
orm.title = entity.title.value
|
||||
orm.content = entity.content.value
|
||||
orm.slug = entity.slug.value
|
||||
orm.published = entity.published
|
||||
orm.tags = entity.tags
|
||||
orm.updated_at = entity.updated_at
|
||||
|
||||
# Commit делает TransactionManager
|
||||
|
||||
async def delete(self, entity_id: UUID) -> None:
|
||||
"""Delete post by ID."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.id == str(entity_id))
|
||||
)
|
||||
orm = result.scalar_one_or_none()
|
||||
if orm:
|
||||
await self._session.delete(orm)
|
||||
|
||||
async def exists(self, entity_id: UUID) -> bool:
|
||||
"""Check if post exists."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.id == str(entity_id))
|
||||
)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
async def get_by_slug(self, slug: str) -> Post | None:
|
||||
"""Get post by slug."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.slug == slug)
|
||||
)
|
||||
orm = result.scalar_one_or_none()
|
||||
return self._to_domain(orm) if orm else None
|
||||
|
||||
async def get_by_author(self, author_id: str) -> list[Post]:
|
||||
"""Get posts by author."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.author_id == author_id)
|
||||
)
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
|
||||
async def get_published(self) -> list[Post]:
|
||||
"""Get published posts."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.published.is_(True))
|
||||
)
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
|
||||
async def get_by_tag(self, tag: str) -> list[Post]:
|
||||
"""Get posts by tag."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.tags.contains([tag]))
|
||||
)
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
|
||||
async def slug_exists(self, slug: str) -> bool:
|
||||
"""Check if slug exists."""
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(PostORM.slug == slug)
|
||||
)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
async def search(self, query: str) -> list[Post]:
|
||||
"""Search posts."""
|
||||
search_pattern = f"%{query}%"
|
||||
result = await self._session.execute(
|
||||
select(PostORM).where(
|
||||
or_(
|
||||
PostORM.title.ilike(search_pattern),
|
||||
PostORM.content.ilike(search_pattern),
|
||||
)
|
||||
)
|
||||
)
|
||||
orms = result.scalars().all()
|
||||
return [self._to_domain(orm) for orm in orms]
|
||||
66
app/main.py
66
app/main.py
@@ -1,22 +1,84 @@
|
||||
"""Application entry point with DDD architecture."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import uvicorn
|
||||
from dishka import make_async_container
|
||||
from dishka.integrations.fastapi import setup_dishka
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.infrastructure import close_db, init_db, register_exception_handlers, settings
|
||||
from app.infrastructure.di.providers import (
|
||||
DatabaseProvider,
|
||||
RepositoryProvider,
|
||||
TransactionManagerProvider,
|
||||
UseCaseProvider,
|
||||
)
|
||||
from app.presentation import router
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan manager."""
|
||||
# Startup
|
||||
await init_db()
|
||||
yield
|
||||
# Shutdown
|
||||
await close_db()
|
||||
|
||||
|
||||
def app_factory() -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
"""Create and configure FastAPI application."""
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
debug=settings.debug,
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs" if settings.debug else None,
|
||||
redoc_url="/redoc" if settings.debug else None,
|
||||
)
|
||||
|
||||
# Setup Dishka DI container
|
||||
container = make_async_container(
|
||||
DatabaseProvider(),
|
||||
RepositoryProvider(),
|
||||
TransactionManagerProvider(),
|
||||
UseCaseProvider(),
|
||||
)
|
||||
setup_dishka(container, app)
|
||||
|
||||
# Register exception handlers
|
||||
register_exception_handlers(app)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include API routes
|
||||
app.include_router(router, prefix="/api")
|
||||
|
||||
# Health check endpoint
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health_check() -> dict[str, str]:
|
||||
return {"status": "ok", "app": settings.app_name}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main() -> None:
|
||||
uvicorn.run(app_factory, factory=True, host="0.0.0.0", port=8000)
|
||||
"""Run the application."""
|
||||
uvicorn.run(
|
||||
app_factory,
|
||||
factory=True,
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Feature modules - business logic organized by domain."""
|
||||
17
app/presentation/__init__.py
Normal file
17
app/presentation/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Presentation layer exports."""
|
||||
|
||||
from app.presentation.api import router
|
||||
from app.presentation.schemas import (
|
||||
PostCreateSchema,
|
||||
PostListResponseSchema,
|
||||
PostResponseSchema,
|
||||
PostUpdateSchema,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
"PostCreateSchema",
|
||||
"PostUpdateSchema",
|
||||
"PostResponseSchema",
|
||||
"PostListResponseSchema",
|
||||
]
|
||||
8
app/presentation/api/__init__.py
Normal file
8
app/presentation/api/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""API router configuration."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.presentation.api.v1 import router as v1_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(v1_router)
|
||||
34
app/presentation/api/deps.py
Normal file
34
app/presentation/api/deps.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""API dependencies using Dishka."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from dishka.integrations.fastapi import FromDishka
|
||||
from fastapi import Depends, Header
|
||||
|
||||
from app.application import (
|
||||
CreatePostUseCase,
|
||||
DeletePostUseCase,
|
||||
GetPostUseCase,
|
||||
ListPostsUseCase,
|
||||
PublishPostUseCase,
|
||||
UpdatePostUseCase,
|
||||
)
|
||||
|
||||
# Use case dependencies - injected via Dishka
|
||||
CreatePostDep = FromDishka[CreatePostUseCase]
|
||||
GetPostDep = FromDishka[GetPostUseCase]
|
||||
UpdatePostDep = FromDishka[UpdatePostUseCase]
|
||||
DeletePostDep = FromDishka[DeletePostUseCase]
|
||||
ListPostsDep = FromDishka[ListPostsUseCase]
|
||||
PublishPostDep = FromDishka[PublishPostUseCase]
|
||||
|
||||
|
||||
# Mock current user dependency (replace with real auth)
|
||||
async def get_current_user_id(
|
||||
x_user_id: Annotated[str | None, Header()] = "user-123",
|
||||
) -> str:
|
||||
"""Get current user ID from header."""
|
||||
return x_user_id or "user-123"
|
||||
|
||||
|
||||
CurrentUserDep = Annotated[str, Depends(get_current_user_id)]
|
||||
8
app/presentation/api/v1/__init__.py
Normal file
8
app/presentation/api/v1/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""API v1 router."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.presentation.api.v1.posts import router as posts_router
|
||||
|
||||
router = APIRouter(prefix="/v1")
|
||||
router.include_router(posts_router)
|
||||
211
app/presentation/api/v1/posts.py
Normal file
211
app/presentation/api/v1/posts.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Posts API routes."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from dishka.integrations.fastapi import DishkaRoute
|
||||
from fastapi import APIRouter, status
|
||||
|
||||
from app.application.dtos import CreatePostDTO, UpdatePostDTO
|
||||
from app.presentation.api.deps import (
|
||||
CreatePostDep,
|
||||
CurrentUserDep,
|
||||
DeletePostDep,
|
||||
GetPostDep,
|
||||
ListPostsDep,
|
||||
PublishPostDep,
|
||||
UpdatePostDep,
|
||||
)
|
||||
from app.presentation.schemas import (
|
||||
PostCreateSchema,
|
||||
PostListResponseSchema,
|
||||
PostResponseSchema,
|
||||
PostUpdateSchema,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/posts", tags=["posts"], route_class=DishkaRoute)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=PostResponseSchema,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="Create a new post",
|
||||
)
|
||||
async def create_post(
|
||||
schema: PostCreateSchema,
|
||||
use_case: CreatePostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Create a new blog post."""
|
||||
dto = CreatePostDTO(
|
||||
title=schema.title,
|
||||
content=schema.content,
|
||||
author_id=current_user_id,
|
||||
tags=schema.tags,
|
||||
)
|
||||
result = await use_case.execute(dto)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="List all posts",
|
||||
)
|
||||
async def list_posts(use_case: ListPostsDep) -> PostListResponseSchema:
|
||||
"""Get all blog posts."""
|
||||
results = await use_case.all_posts()
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/published",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="List published posts",
|
||||
)
|
||||
async def list_published_posts(
|
||||
use_case: ListPostsDep,
|
||||
) -> PostListResponseSchema:
|
||||
"""Get all published blog posts."""
|
||||
results = await use_case.published_posts()
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="Search posts",
|
||||
)
|
||||
async def search_posts(
|
||||
query: str,
|
||||
use_case: ListPostsDep,
|
||||
) -> PostListResponseSchema:
|
||||
"""Search posts by query."""
|
||||
results = await use_case.search(query)
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/by-tag/{tag}",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="Get posts by tag",
|
||||
)
|
||||
async def get_posts_by_tag(
|
||||
tag: str,
|
||||
use_case: ListPostsDep,
|
||||
) -> PostListResponseSchema:
|
||||
"""Get posts by tag."""
|
||||
results = await use_case.by_tag(tag)
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/by-author/{author_id}",
|
||||
response_model=PostListResponseSchema,
|
||||
summary="Get posts by author",
|
||||
)
|
||||
async def get_posts_by_author(
|
||||
author_id: str,
|
||||
use_case: ListPostsDep,
|
||||
) -> PostListResponseSchema:
|
||||
"""Get posts by author."""
|
||||
results = await use_case.by_author(author_id)
|
||||
items = [PostResponseSchema(**r.__dict__) for r in results]
|
||||
return PostListResponseSchema(items=items, total=len(items))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{post_id}",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Get post by ID",
|
||||
)
|
||||
async def get_post(
|
||||
post_id: UUID,
|
||||
use_case: GetPostDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Get a post by its ID."""
|
||||
result = await use_case.by_id(post_id)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/slug/{slug}",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Get post by slug",
|
||||
)
|
||||
async def get_post_by_slug(
|
||||
slug: str,
|
||||
use_case: GetPostDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Get a post by its slug."""
|
||||
result = await use_case.by_slug(slug)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{post_id}",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Update post",
|
||||
)
|
||||
async def update_post(
|
||||
post_id: UUID,
|
||||
schema: PostUpdateSchema,
|
||||
use_case: UpdatePostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Update a post."""
|
||||
dto = UpdatePostDTO(
|
||||
title=schema.title,
|
||||
content=schema.content,
|
||||
tags=schema.tags,
|
||||
)
|
||||
result = await use_case.execute(post_id, dto, current_user_id)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{post_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete post",
|
||||
)
|
||||
async def delete_post(
|
||||
post_id: UUID,
|
||||
use_case: DeletePostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> None:
|
||||
"""Delete a post."""
|
||||
await use_case.execute(post_id, current_user_id)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{post_id}/publish",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Publish post",
|
||||
)
|
||||
async def publish_post(
|
||||
post_id: UUID,
|
||||
use_case: PublishPostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Publish a post."""
|
||||
result = await use_case.publish(post_id, current_user_id)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{post_id}/unpublish",
|
||||
response_model=PostResponseSchema,
|
||||
summary="Unpublish post",
|
||||
)
|
||||
async def unpublish_post(
|
||||
post_id: UUID,
|
||||
use_case: PublishPostDep,
|
||||
current_user_id: CurrentUserDep,
|
||||
) -> PostResponseSchema:
|
||||
"""Unpublish a post."""
|
||||
result = await use_case.unpublish(post_id, current_user_id)
|
||||
return PostResponseSchema(**result.__dict__)
|
||||
21
app/presentation/schemas/__init__.py
Normal file
21
app/presentation/schemas/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Presentation schemas."""
|
||||
|
||||
from app.presentation.schemas.post import (
|
||||
PostBaseSchema,
|
||||
PostCreateSchema,
|
||||
PostListResponseSchema,
|
||||
PostPublishSchema,
|
||||
PostResponseSchema,
|
||||
PostSearchSchema,
|
||||
PostUpdateSchema,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PostBaseSchema",
|
||||
"PostCreateSchema",
|
||||
"PostUpdateSchema",
|
||||
"PostResponseSchema",
|
||||
"PostListResponseSchema",
|
||||
"PostSearchSchema",
|
||||
"PostPublishSchema",
|
||||
]
|
||||
66
app/presentation/schemas/post.py
Normal file
66
app/presentation/schemas/post.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""API schemas for posts."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class PostBaseSchema(BaseModel):
|
||||
"""Base schema for posts."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
title: str = Field(..., min_length=3, max_length=200)
|
||||
content: str = Field(..., min_length=10, max_length=50000)
|
||||
|
||||
|
||||
class PostCreateSchema(PostBaseSchema):
|
||||
"""Schema for creating a post."""
|
||||
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PostUpdateSchema(BaseModel):
|
||||
"""Schema for updating a post."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
title: str | None = Field(None, min_length=3, max_length=200)
|
||||
content: str | None = Field(None, min_length=10, max_length=50000)
|
||||
tags: list[str] | None = None
|
||||
|
||||
|
||||
class PostResponseSchema(BaseModel):
|
||||
"""Schema for post response."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID
|
||||
title: str
|
||||
content: str
|
||||
slug: str
|
||||
author_id: str
|
||||
published: bool
|
||||
tags: list[str]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class PostListResponseSchema(BaseModel):
|
||||
"""Schema for list of posts response."""
|
||||
|
||||
items: list[PostResponseSchema]
|
||||
total: int
|
||||
|
||||
|
||||
class PostSearchSchema(BaseModel):
|
||||
"""Schema for searching posts."""
|
||||
|
||||
query: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
class PostPublishSchema(BaseModel):
|
||||
"""Schema for publishing/unpublishing a post."""
|
||||
|
||||
published: bool
|
||||
@@ -4,11 +4,21 @@ version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["app"]
|
||||
dependencies = [
|
||||
"fastapi>=0.136.0",
|
||||
"pydantic>=2.13.2",
|
||||
"pydantic-settings>=2.14.0",
|
||||
"uvicorn>=0.44.0",
|
||||
"sqlalchemy>=2.0.0",
|
||||
"aiosqlite>=0.21.0",
|
||||
"dishka>=1.5.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
@@ -35,10 +45,13 @@ types = [
|
||||
"mypy>=1.20.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
blog = "app.main:main"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
addopts = "--cov=src --cov-report=term"
|
||||
addopts = "--cov=app --cov-report=term-missing --cov-report=html"
|
||||
pythonpath = "."
|
||||
testpaths = "tests"
|
||||
xfail_strict = true
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
# API test fixtures
|
||||
# Provides: httpx.AsyncClient, authentication helpers, test API data
|
||||
"""API test fixtures."""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.main import app_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client() -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create async HTTP client for API testing."""
|
||||
from app.main import app_factory
|
||||
|
||||
app = app_factory()
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
@@ -21,4 +20,4 @@ async def client() -> AsyncGenerator[AsyncClient, None]:
|
||||
@pytest.fixture
|
||||
def auth_headers() -> dict[str, str]:
|
||||
"""Return mock authentication headers."""
|
||||
return {"Authorization": "Bearer test_token"}
|
||||
return {"Authorization": "Bearer test_token", "X-User-Id": "user-123"}
|
||||
|
||||
@@ -1,20 +1,58 @@
|
||||
# Integration test fixtures
|
||||
# Provides: test database, external service connections
|
||||
"""Integration test fixtures."""
|
||||
|
||||
from typing import Generator
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from app.infrastructure.database.models import Base
|
||||
|
||||
# Use in-memory SQLite for tests
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_db_connection() -> Generator[str, None, None]:
|
||||
"""Create test database connection."""
|
||||
# TODO: Implement when DB is added to project
|
||||
yield "test_db"
|
||||
@pytest.fixture(scope="session")
|
||||
def engine() -> AsyncEngine:
|
||||
"""Create test engine."""
|
||||
return create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_db() -> Generator[None, None, None]:
|
||||
"""Cleanup database after test."""
|
||||
@pytest.fixture(scope="session")
|
||||
def session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
|
||||
"""Create test session factory."""
|
||||
return async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def setup_db(engine: AsyncEngine) -> AsyncGenerator[None, None]:
|
||||
"""Setup database tables for each test."""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
# TODO: Implement cleanup logic
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create database session for testing."""
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
# Предполагаем, что тестируемый модуль называется `myapp`
|
||||
# Импортируем из него нужные объекты
|
||||
from app.main import app_factory, lifespan, main
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan() -> None:
|
||||
"""Проверяет, что lifespan является корректным асинхронным контекстным менеджером."""
|
||||
app = FastAPI()
|
||||
# Проверяем, что lifespan - это asynccontextmanager
|
||||
assert isinstance(lifespan, asynccontextmanager(lifespan).__class__) # type: ignore[arg-type]
|
||||
|
||||
# Проверяем, что контекстный менеджер работает (ничего не ломается)
|
||||
async with lifespan(app):
|
||||
pass # Просто убеждаемся, что yield отрабатывает
|
||||
|
||||
|
||||
def test_app_factory() -> None:
|
||||
"""Проверяет, что app_factory создаёт приложение FastAPI с переданным lifespan."""
|
||||
app = app_factory()
|
||||
assert isinstance(app, FastAPI)
|
||||
# Проверяем, что lifespan приложения установлен на функцию lifespan
|
||||
assert app.router.lifespan_context == lifespan
|
||||
|
||||
|
||||
@patch("app.main.uvicorn.run")
|
||||
def test_main(mock_uvicorn_run: Mock) -> None:
|
||||
"""Проверяет, что main вызывает uvicorn.run с правильными параметрами."""
|
||||
main()
|
||||
mock_uvicorn_run.assert_called_once_with(
|
||||
app_factory,
|
||||
factory=True,
|
||||
host="0.0.0.0",
|
||||
port=8000, # Предполагаемый порт (в коде обрезано, но обычно 8000)
|
||||
)
|
||||
0
tests/unit/application/__init__.py
Normal file
0
tests/unit/application/__init__.py
Normal file
273
tests/unit/application/test_use_cases.py
Normal file
273
tests/unit/application/test_use_cases.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Tests for application use cases."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.application.dtos.post import CreatePostDTO, UpdatePostDTO
|
||||
from app.application.use_cases import (
|
||||
CreatePostUseCase,
|
||||
DeletePostUseCase,
|
||||
GetPostUseCase,
|
||||
ListPostsUseCase,
|
||||
PublishPostUseCase,
|
||||
UpdatePostUseCase,
|
||||
)
|
||||
from app.domain.entities import Post
|
||||
from app.domain.exceptions import (
|
||||
AlreadyExistsException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_post() -> Post:
|
||||
"""Create a test post."""
|
||||
return Post.create(
|
||||
title_str="Test Post",
|
||||
content_str="This is test content with enough characters",
|
||||
author_id="user-123",
|
||||
tags=["test"],
|
||||
)
|
||||
|
||||
|
||||
class TestCreatePostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_post_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
) -> None:
|
||||
"""Test successful post creation."""
|
||||
# Setup
|
||||
mock_post_repository.slug_exists = AsyncMock(return_value=False)
|
||||
mock_post_repository.add = AsyncMock()
|
||||
|
||||
use_case = CreatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = CreatePostDTO(
|
||||
title="New Post",
|
||||
content="Content with enough characters",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await use_case.execute(dto)
|
||||
|
||||
# Assert
|
||||
assert result.title == "New Post"
|
||||
assert result.author_id == "user-123"
|
||||
mock_post_repository.add.assert_called_once()
|
||||
mock_transaction_manager.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_post_slug_exists(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
) -> None:
|
||||
"""Test post creation with existing slug."""
|
||||
# Setup
|
||||
mock_post_repository.slug_exists = AsyncMock(return_value=True)
|
||||
|
||||
use_case = CreatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = CreatePostDTO(
|
||||
title="Existing Post",
|
||||
content="Content with enough characters",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(AlreadyExistsException):
|
||||
await use_case.execute(dto)
|
||||
|
||||
mock_post_repository.add.assert_not_called()
|
||||
mock_transaction_manager.commit.assert_not_called()
|
||||
|
||||
|
||||
class TestGetPostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_post_by_id_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test successful get post by ID."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
|
||||
use_case = GetPostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute
|
||||
result = await use_case.by_id(test_post.id)
|
||||
|
||||
# Assert
|
||||
assert result.id == test_post.id
|
||||
assert result.title == test_post.title.value
|
||||
mock_post_repository.get_by_id.assert_called_once_with(test_post.id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_post_by_id_not_found(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
) -> None:
|
||||
"""Test get post by ID when not found."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
use_case = GetPostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
post_id = uuid4()
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(NotFoundException):
|
||||
await use_case.by_id(post_id)
|
||||
|
||||
|
||||
class TestUpdatePostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_post_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test successful post update."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
mock_post_repository.update = AsyncMock()
|
||||
|
||||
use_case = UpdatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = UpdatePostDTO(title="Updated Title")
|
||||
|
||||
# Execute
|
||||
result = await use_case.execute(test_post.id, dto, "user-123")
|
||||
|
||||
# Assert
|
||||
assert result.title == "Updated Title"
|
||||
mock_post_repository.update.assert_called_once()
|
||||
mock_transaction_manager.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_post_not_found(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
) -> None:
|
||||
"""Test update post when not found."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=None)
|
||||
|
||||
use_case = UpdatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = UpdatePostDTO(title="Updated Title")
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(NotFoundException):
|
||||
await use_case.execute(uuid4(), dto, "user-123")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_post_forbidden(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test update post by different user."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
|
||||
use_case = UpdatePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
dto = UpdatePostDTO(title="Updated Title")
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(ForbiddenException):
|
||||
await use_case.execute(test_post.id, dto, "other-user")
|
||||
|
||||
|
||||
class TestDeletePostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_post_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test successful post deletion."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
mock_post_repository.delete = AsyncMock()
|
||||
|
||||
use_case = DeletePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute
|
||||
await use_case.execute(test_post.id, "user-123")
|
||||
|
||||
# Assert
|
||||
mock_post_repository.delete.assert_called_once_with(test_post.id)
|
||||
mock_transaction_manager.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_post_forbidden(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test delete post by different user."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
|
||||
use_case = DeletePostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute & Assert
|
||||
with pytest.raises(ForbiddenException):
|
||||
await use_case.execute(test_post.id, "other-user")
|
||||
|
||||
|
||||
class TestPublishPostUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_post_success(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test successful post publish."""
|
||||
# Setup
|
||||
mock_post_repository.get_by_id = AsyncMock(return_value=test_post)
|
||||
mock_post_repository.update = AsyncMock()
|
||||
|
||||
use_case = PublishPostUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute
|
||||
result = await use_case.publish(test_post.id, "user-123")
|
||||
|
||||
# Assert
|
||||
assert result.published is True
|
||||
mock_post_repository.update.assert_called_once()
|
||||
mock_transaction_manager.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestListPostsUseCase:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_posts(
|
||||
self,
|
||||
mock_post_repository: Mock,
|
||||
mock_transaction_manager: Mock,
|
||||
test_post: Post,
|
||||
) -> None:
|
||||
"""Test listing all posts."""
|
||||
# Setup
|
||||
mock_post_repository.get_all = AsyncMock(return_value=[test_post])
|
||||
|
||||
use_case = ListPostsUseCase(mock_post_repository, mock_transaction_manager)
|
||||
|
||||
# Execute
|
||||
results = await use_case.all_posts()
|
||||
|
||||
# Assert
|
||||
assert len(results) == 1
|
||||
assert results[0].id == test_post.id
|
||||
mock_post_repository.get_all.assert_called_once()
|
||||
@@ -1,18 +1,29 @@
|
||||
# Unit test fixtures
|
||||
# Provides: mocks, stubs, isolated test data
|
||||
"""Unit test fixtures."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.application.interfaces import TransactionManager
|
||||
from app.domain.repositories import PostRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_service() -> Mock:
|
||||
"""Create a mock service for unit testing."""
|
||||
return Mock()
|
||||
def mock_post_repository() -> Mock:
|
||||
"""Create a mock post repository."""
|
||||
return Mock(spec=PostRepository)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_transaction_manager() -> Mock:
|
||||
"""Create a mock transaction manager."""
|
||||
tx_manager = Mock(spec=TransactionManager)
|
||||
tx_manager.commit = AsyncMock()
|
||||
tx_manager.rollback = AsyncMock()
|
||||
return tx_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_service() -> AsyncMock:
|
||||
"""Create an async mock service for unit testing."""
|
||||
"""Create an async mock service."""
|
||||
return AsyncMock()
|
||||
|
||||
0
tests/unit/domain/__init__.py
Normal file
0
tests/unit/domain/__init__.py
Normal file
128
tests/unit/domain/test_entities.py
Normal file
128
tests/unit/domain/test_entities.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Tests for domain entities."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from app.domain.entities import Post
|
||||
from app.domain.value_objects import Content, Title
|
||||
|
||||
|
||||
class TestPost:
|
||||
def test_post_creation(self) -> None:
|
||||
"""Test creating a post."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
tags=["test", "python"],
|
||||
)
|
||||
|
||||
assert isinstance(post.id, UUID)
|
||||
assert post.title.value == "Test Title"
|
||||
assert post.content.value == "This is test content that is long enough"
|
||||
assert post.slug.value == "test-title"
|
||||
assert post.author_id == "user-123"
|
||||
assert post.published is False
|
||||
assert post.tags == ["test", "python"]
|
||||
|
||||
def test_post_publish(self) -> None:
|
||||
"""Test publishing a post."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
assert post.published is False
|
||||
post.publish()
|
||||
assert post.published is True
|
||||
|
||||
def test_post_unpublish(self) -> None:
|
||||
"""Test unpublishing a post."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
post.publish()
|
||||
assert post.published is True
|
||||
post.unpublish()
|
||||
assert post.published is False
|
||||
|
||||
def test_post_update_title(self) -> None:
|
||||
"""Test updating post title."""
|
||||
post = Post.create(
|
||||
title_str="Original Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
old_updated_at = post.updated_at
|
||||
post.update_title(Title("New Title"))
|
||||
|
||||
assert post.title.value == "New Title"
|
||||
assert post.slug.value == "new-title"
|
||||
assert post.updated_at > old_updated_at
|
||||
|
||||
def test_post_update_content(self) -> None:
|
||||
"""Test updating post content."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
old_updated_at = post.updated_at
|
||||
post.update_content(Content("Updated content that is also long enough"))
|
||||
|
||||
assert post.content.value == "Updated content that is also long enough"
|
||||
assert post.updated_at > old_updated_at
|
||||
|
||||
def test_post_add_tag(self) -> None:
|
||||
"""Test adding a tag."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
)
|
||||
|
||||
post.add_tag("python")
|
||||
assert "python" in post.tags
|
||||
|
||||
# Adding same tag twice should not duplicate
|
||||
post.add_tag("python")
|
||||
assert post.tags.count("python") == 1
|
||||
|
||||
def test_post_remove_tag(self) -> None:
|
||||
"""Test removing a tag."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
tags=["python", "fastapi"],
|
||||
)
|
||||
|
||||
post.remove_tag("python")
|
||||
assert "python" not in post.tags
|
||||
assert "fastapi" in post.tags
|
||||
|
||||
def test_post_to_dict(self) -> None:
|
||||
"""Test converting post to dict."""
|
||||
post = Post.create(
|
||||
title_str="Test Title",
|
||||
content_str="This is test content that is long enough",
|
||||
author_id="user-123",
|
||||
tags=["test"],
|
||||
)
|
||||
|
||||
data = post.to_dict()
|
||||
|
||||
assert data["title"] == "Test Title"
|
||||
assert data["content"] == "This is test content that is long enough"
|
||||
assert data["slug"] == "test-title"
|
||||
assert data["author_id"] == "user-123"
|
||||
assert data["published"] is False
|
||||
assert data["tags"] == ["test"]
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
assert "updated_at" in data
|
||||
48
tests/unit/domain/test_exceptions.py
Normal file
48
tests/unit/domain/test_exceptions.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for domain exceptions."""
|
||||
|
||||
from app.domain.exceptions import (
|
||||
AlreadyExistsException,
|
||||
DomainException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
UnauthorizedException,
|
||||
ValidationException,
|
||||
)
|
||||
|
||||
|
||||
class TestDomainExceptions:
|
||||
def test_base_exception(self) -> None:
|
||||
"""Test base domain exception."""
|
||||
exc = DomainException("Something went wrong")
|
||||
assert exc.message == "Something went wrong"
|
||||
assert str(exc) == "Something went wrong"
|
||||
|
||||
def test_validation_exception(self) -> None:
|
||||
"""Test validation exception."""
|
||||
exc = ValidationException("Invalid input")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Invalid input"
|
||||
|
||||
def test_not_found_exception(self) -> None:
|
||||
"""Test not found exception."""
|
||||
exc = NotFoundException("Resource not found")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Resource not found"
|
||||
|
||||
def test_already_exists_exception(self) -> None:
|
||||
"""Test already exists exception."""
|
||||
exc = AlreadyExistsException("Already exists")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Already exists"
|
||||
|
||||
def test_unauthorized_exception(self) -> None:
|
||||
"""Test unauthorized exception."""
|
||||
exc = UnauthorizedException("Unauthorized")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Unauthorized"
|
||||
|
||||
def test_forbidden_exception(self) -> None:
|
||||
"""Test forbidden exception."""
|
||||
exc = ForbiddenException("Forbidden")
|
||||
assert isinstance(exc, DomainException)
|
||||
assert exc.message == "Forbidden"
|
||||
93
tests/unit/domain/test_value_objects.py
Normal file
93
tests/unit/domain/test_value_objects.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Tests for domain value objects."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.domain.value_objects import Content, Slug, Title
|
||||
|
||||
|
||||
class TestTitle:
|
||||
def test_valid_title(self) -> None:
|
||||
"""Test creating a valid title."""
|
||||
title = Title("Valid Title")
|
||||
assert title.value == "Valid Title"
|
||||
|
||||
def test_title_too_short(self) -> None:
|
||||
"""Test title that is too short."""
|
||||
with pytest.raises(ValueError, match="at least"):
|
||||
Title("ab")
|
||||
|
||||
def test_title_too_long(self) -> None:
|
||||
"""Test title that is too long."""
|
||||
with pytest.raises(ValueError, match="at most"):
|
||||
Title("a" * 201)
|
||||
|
||||
def test_title_empty(self) -> None:
|
||||
"""Test empty title."""
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
Title(" ")
|
||||
|
||||
def test_title_not_string(self) -> None:
|
||||
"""Test non-string title."""
|
||||
with pytest.raises(ValueError, match="string"):
|
||||
Title(123) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestContent:
|
||||
def test_valid_content(self) -> None:
|
||||
"""Test creating valid content."""
|
||||
content = Content("This is valid content with enough characters")
|
||||
assert content.value == "This is valid content with enough characters"
|
||||
|
||||
def test_content_too_short(self) -> None:
|
||||
"""Test content that is too short."""
|
||||
with pytest.raises(ValueError, match="at least"):
|
||||
Content("short")
|
||||
|
||||
def test_content_too_long(self) -> None:
|
||||
"""Test content that is too long."""
|
||||
with pytest.raises(ValueError, match="at most"):
|
||||
Content("a" * 50001)
|
||||
|
||||
def test_content_empty(self) -> None:
|
||||
"""Test empty content."""
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
Content(" ")
|
||||
|
||||
|
||||
class TestSlug:
|
||||
def test_valid_slug(self) -> None:
|
||||
"""Test creating a valid slug."""
|
||||
slug = Slug("valid-slug")
|
||||
assert slug.value == "valid-slug"
|
||||
|
||||
def test_slug_from_title(self) -> None:
|
||||
"""Test generating slug from title."""
|
||||
slug = Slug.from_title("Hello World Post")
|
||||
assert slug.value == "hello-world-post"
|
||||
|
||||
def test_slug_from_title_with_special_chars(self) -> None:
|
||||
"""Test generating slug from title with special characters."""
|
||||
slug = Slug.from_title("Hello, World! Post @#$%")
|
||||
assert slug.value == "hello-world-post"
|
||||
|
||||
def test_slug_from_title_only_special_chars(self) -> None:
|
||||
"""Test generating slug from title with only special characters."""
|
||||
slug = Slug.from_title("!@#$%")
|
||||
assert slug.value == "post"
|
||||
|
||||
def test_slug_invalid_chars(self) -> None:
|
||||
"""Test slug with invalid characters."""
|
||||
with pytest.raises(ValueError, match="lowercase"):
|
||||
Slug("Invalid_Slug")
|
||||
|
||||
def test_slug_uppercase(self) -> None:
|
||||
"""Test slug with uppercase letters."""
|
||||
with pytest.raises(ValueError, match="lowercase"):
|
||||
Slug("Uppercase-Slug")
|
||||
|
||||
def test_slug_equality(self) -> None:
|
||||
"""Test slug value equality."""
|
||||
slug1 = Slug("test-slug")
|
||||
slug2 = Slug("test-slug")
|
||||
assert slug1 == slug2
|
||||
assert hash(slug1) == hash(slug2)
|
||||
0
tests/unit/infrastructure/__init__.py
Normal file
0
tests/unit/infrastructure/__init__.py
Normal file
37
tests/unit/infrastructure/test_config.py
Normal file
37
tests/unit/infrastructure/test_config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Tests for infrastructure config."""
|
||||
|
||||
from app.infrastructure.config import Settings
|
||||
|
||||
|
||||
class TestSettings:
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default settings values by creating settings without env file."""
|
||||
# Create settings with no env file to test defaults
|
||||
s = Settings(_env_file=None)
|
||||
assert s.app_name == "Blog API"
|
||||
assert s.debug is False
|
||||
assert s.host == "0.0.0.0"
|
||||
assert s.port == 8000
|
||||
assert s.database_url == "sqlite:///./blog.db"
|
||||
assert s.database_echo is False
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Test custom settings values."""
|
||||
s = Settings(
|
||||
app_name="Test API",
|
||||
debug=True,
|
||||
host="localhost",
|
||||
port=9000,
|
||||
database_url="postgresql://test",
|
||||
secret_key="test-secret",
|
||||
)
|
||||
assert s.app_name == "Test API"
|
||||
assert s.debug is True
|
||||
assert s.host == "localhost"
|
||||
assert s.port == 9000
|
||||
assert s.database_url == "postgresql://test"
|
||||
assert s.secret_key == "test-secret"
|
||||
|
||||
def test_model_config(self) -> None:
|
||||
"""Test settings model config."""
|
||||
assert "env_file" in Settings.model_config
|
||||
@@ -1,52 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.core.config import Settings
|
||||
|
||||
|
||||
class TestSettings:
|
||||
def test_default_values(self) -> None:
|
||||
settings = Settings()
|
||||
assert settings.app_name == "Blog API"
|
||||
assert settings.debug is False
|
||||
assert settings.host == "0.0.0.0"
|
||||
assert settings.port == 8000
|
||||
assert settings.database_url is None
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
settings = Settings(
|
||||
app_name="Test API",
|
||||
debug=True,
|
||||
host="localhost",
|
||||
port=9000,
|
||||
database_url="postgresql://test",
|
||||
)
|
||||
assert settings.app_name == "Test API"
|
||||
assert settings.debug is True
|
||||
assert settings.host == "localhost"
|
||||
assert settings.port == 9000
|
||||
assert settings.database_url == "postgresql://test"
|
||||
|
||||
def test_settings_from_env(self) -> None:
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"APP_NAME": "Env API",
|
||||
"DEBUG": "true",
|
||||
"HOST": "127.0.0.1",
|
||||
"PORT": "8080",
|
||||
"DATABASE_URL": "sqlite:///test.db",
|
||||
},
|
||||
):
|
||||
settings = Settings()
|
||||
assert settings.app_name == "Env API"
|
||||
assert settings.debug is True
|
||||
assert settings.host == "127.0.0.1"
|
||||
assert settings.port == 8080
|
||||
assert settings.database_url == "sqlite:///test.db"
|
||||
|
||||
def test_global_settings_instance(self) -> None:
|
||||
from app.core.config import settings
|
||||
|
||||
assert isinstance(settings, Settings)
|
||||
assert settings.app_name == "Blog API"
|
||||
@@ -1,110 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from app.common.error_handler import (
|
||||
ErrorResponse,
|
||||
app_exception_handler,
|
||||
http_exception_handler,
|
||||
register_exception_handlers,
|
||||
)
|
||||
from app.core.exceptions import AppException
|
||||
|
||||
|
||||
class TestErrorResponse:
|
||||
def test_error_response_creation(self) -> None:
|
||||
response = ErrorResponse(
|
||||
status_code=400,
|
||||
message="Bad request",
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.message == "Bad request"
|
||||
assert response.details is None
|
||||
|
||||
def test_error_response_with_details(self) -> None:
|
||||
response = ErrorResponse(
|
||||
status_code=500,
|
||||
message="Internal error",
|
||||
details={"field": "value"},
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert response.message == "Internal error"
|
||||
assert response.details == {"field": "value"}
|
||||
|
||||
|
||||
class TestAppExceptionHandler:
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_exception_handler(self) -> None:
|
||||
request = Mock(spec=Request)
|
||||
exc = AppException(message="Test error", status_code=400)
|
||||
|
||||
response = await app_exception_handler(request, exc)
|
||||
|
||||
assert response.status_code == 400
|
||||
body = bytes(response.body).decode()
|
||||
assert "Test error" in body
|
||||
assert "400" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_exception_handler_content(self) -> None:
|
||||
request = Mock(spec=Request)
|
||||
exc = AppException(message="Validation error", status_code=422)
|
||||
|
||||
with patch("app.common.error_handler.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value.isoformat.return_value = (
|
||||
"2024-01-01T00:00:00"
|
||||
)
|
||||
|
||||
response = await app_exception_handler(request, exc)
|
||||
|
||||
content = bytes(response.body).decode()
|
||||
assert "Validation error" in content
|
||||
assert "422" in content
|
||||
assert "2024-01-01T00:00:00" in content
|
||||
|
||||
|
||||
class TestHttpExceptionHandler:
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_exception_handler(self) -> None:
|
||||
request = Mock(spec=Request)
|
||||
exc = HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
response = await http_exception_handler(request, exc)
|
||||
|
||||
assert response.status_code == 404
|
||||
body = bytes(response.body).decode()
|
||||
assert "Not found" in body
|
||||
assert "404" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_exception_handler_content(self) -> None:
|
||||
request = Mock(spec=Request)
|
||||
exc = HTTPException(status_code=503, detail="Service unavailable")
|
||||
|
||||
with patch("app.common.error_handler.datetime") as mock_datetime:
|
||||
mock_datetime.now.return_value.isoformat.return_value = (
|
||||
"2024-01-01T12:00:00"
|
||||
)
|
||||
|
||||
response = await http_exception_handler(request, exc)
|
||||
|
||||
content = bytes(response.body).decode()
|
||||
assert "Service unavailable" in content
|
||||
assert "503" in content
|
||||
assert "2024-01-01T12:00:00" in content
|
||||
|
||||
|
||||
class TestRegisterExceptionHandlers:
|
||||
def test_register_exception_handlers(self) -> None:
|
||||
app = Mock(spec=FastAPI)
|
||||
|
||||
register_exception_handlers(app)
|
||||
|
||||
assert app.add_exception_handler.call_count == 2
|
||||
app.add_exception_handler.assert_any_call(AppException, app_exception_handler)
|
||||
app.add_exception_handler.assert_any_call(HTTPException, http_exception_handler)
|
||||
@@ -1,87 +0,0 @@
|
||||
from app.core.exceptions import (
|
||||
AppException,
|
||||
ForbiddenError,
|
||||
NotFoundError,
|
||||
UnauthorizedError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
|
||||
class TestAppException:
|
||||
def test_default_status_code(self) -> None:
|
||||
exc = AppException(message="Test error")
|
||||
assert exc.message == "Test error"
|
||||
assert exc.status_code == 500
|
||||
|
||||
def test_custom_status_code(self) -> None:
|
||||
exc = AppException(message="Custom error", status_code=400)
|
||||
assert exc.message == "Custom error"
|
||||
assert exc.status_code == 400
|
||||
|
||||
def test_string_representation(self) -> None:
|
||||
exc = AppException(message="Error message")
|
||||
assert str(exc) == "Error message"
|
||||
|
||||
|
||||
class TestNotFoundError:
|
||||
def test_default_message(self) -> None:
|
||||
exc = NotFoundError()
|
||||
assert exc.message == "Resource not found"
|
||||
assert exc.status_code == 404
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = NotFoundError(message="Item not found")
|
||||
assert exc.message == "Item not found"
|
||||
assert exc.status_code == 404
|
||||
|
||||
def test_is_subclass_of_app_exception(self) -> None:
|
||||
exc = NotFoundError()
|
||||
assert isinstance(exc, AppException)
|
||||
|
||||
|
||||
class TestValidationError:
|
||||
def test_default_message(self) -> None:
|
||||
exc = ValidationError()
|
||||
assert exc.message == "Validation failed"
|
||||
assert exc.status_code == 400
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = ValidationError(message="Invalid email format")
|
||||
assert exc.message == "Invalid email format"
|
||||
assert exc.status_code == 400
|
||||
|
||||
def test_is_subclass_of_app_exception(self) -> None:
|
||||
exc = ValidationError()
|
||||
assert isinstance(exc, AppException)
|
||||
|
||||
|
||||
class TestUnauthorizedError:
|
||||
def test_default_message(self) -> None:
|
||||
exc = UnauthorizedError()
|
||||
assert exc.message == "Unauthorized"
|
||||
assert exc.status_code == 401
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = UnauthorizedError(message="Invalid credentials")
|
||||
assert exc.message == "Invalid credentials"
|
||||
assert exc.status_code == 401
|
||||
|
||||
def test_is_subclass_of_app_exception(self) -> None:
|
||||
exc = UnauthorizedError()
|
||||
assert isinstance(exc, AppException)
|
||||
|
||||
|
||||
class TestForbiddenError:
|
||||
def test_default_message(self) -> None:
|
||||
exc = ForbiddenError()
|
||||
assert exc.message == "Forbidden"
|
||||
assert exc.status_code == 403
|
||||
|
||||
def test_custom_message(self) -> None:
|
||||
exc = ForbiddenError(message="Access denied")
|
||||
assert exc.message == "Access denied"
|
||||
assert exc.status_code == 403
|
||||
|
||||
def test_is_subclass_of_app_exception(self) -> None:
|
||||
exc = ForbiddenError()
|
||||
assert isinstance(exc, AppException)
|
||||
49
tests/unit/test_main.py
Normal file
49
tests/unit/test_main.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Tests for main application."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.main import app_factory, lifespan, main
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan() -> None:
|
||||
"""Test lifespan context manager."""
|
||||
app = FastAPI()
|
||||
|
||||
with (
|
||||
patch("app.main.init_db") as mock_init,
|
||||
patch("app.main.close_db") as mock_close,
|
||||
):
|
||||
async with lifespan(app):
|
||||
mock_init.assert_called_once()
|
||||
mock_close.assert_not_called()
|
||||
mock_close.assert_called_once()
|
||||
|
||||
|
||||
def test_app_factory() -> None:
|
||||
"""Test app factory creates FastAPI app."""
|
||||
app = app_factory()
|
||||
assert isinstance(app, FastAPI)
|
||||
|
||||
|
||||
def test_app_factory_has_routes() -> None:
|
||||
"""Test app has registered routes."""
|
||||
app = app_factory()
|
||||
routes = [str(route.path) for route in app.routes if hasattr(route, "path")]
|
||||
assert "/health" in routes
|
||||
# Check that API routes are included
|
||||
assert any("api" in path for path in routes)
|
||||
|
||||
|
||||
@patch("app.main.uvicorn.run")
|
||||
def test_main(mock_uvicorn_run: Mock) -> None:
|
||||
"""Test main function starts uvicorn."""
|
||||
main()
|
||||
mock_uvicorn_run.assert_called_once()
|
||||
call_kwargs = mock_uvicorn_run.call_args.kwargs
|
||||
assert call_kwargs.get("factory") is True
|
||||
assert call_kwargs.get("host") == "0.0.0.0"
|
||||
assert call_kwargs.get("port") == 8000
|
||||
@@ -1,33 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.main import app_factory, lifespan, main
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan() -> None:
|
||||
app = FastAPI()
|
||||
assert isinstance(lifespan, asynccontextmanager(lifespan).__class__) # type: ignore[arg-type]
|
||||
|
||||
async with lifespan(app):
|
||||
pass
|
||||
|
||||
|
||||
def test_app_factory() -> None:
|
||||
app = app_factory()
|
||||
assert isinstance(app, FastAPI)
|
||||
assert app.router.lifespan_context == lifespan
|
||||
|
||||
|
||||
@patch("app.main.uvicorn.run")
|
||||
def test_main(mock_uvicorn_run: Mock) -> None:
|
||||
main()
|
||||
mock_uvicorn_run.assert_called_once_with(
|
||||
app_factory,
|
||||
factory=True,
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
)
|
||||
Reference in New Issue
Block a user