baoxiang/backend/app/core/database.py

112 lines
3.0 KiB
Python
Raw Normal View History

2025-12-16 18:06:50 +08:00
"""
数据库连接和会话管理
"""
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
from .config import settings
# 创建数据库引擎支持MySQL连接池
db_settings = settings.database
if "sqlite" in db_settings.DATABASE_URL:
# SQLite 配置
engine = create_engine(
db_settings.DATABASE_URL,
echo=False,
future=True,
connect_args={"check_same_thread": False} # SQLite 特定配置
)
else:
# MySQL 配置
engine = create_engine(
db_settings.DATABASE_URL,
# 连接池配置
poolclass=QueuePool,
pool_size=db_settings.DB_POOL_SIZE,
max_overflow=db_settings.DB_MAX_OVERFLOW,
pool_timeout=db_settings.DB_POOL_TIMEOUT,
pool_recycle=db_settings.DB_POOL_RECYCLE,
2025-12-18 13:28:29 +08:00
pool_pre_ping=True, # 每次使用连接前检测是否有效,防止使用已断开的连接
2025-12-16 18:06:50 +08:00
echo=False,
future=True,
2025-12-17 11:43:50 +08:00
# 设置连接时使用的字符集
2025-12-18 13:28:29 +08:00
connect_args={
"charset": "utf8mb4",
"collation": "utf8mb4_unicode_ci",
# 添加连接超时设置
"connect_timeout": 60,
"read_timeout": 60,
"write_timeout": 60
}
2025-12-16 18:06:50 +08:00
)
# 如果是MySQL启用严格模式
if "mysql" in db_settings.DATABASE_URL:
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
"""
为MySQL连接设置参数如果使用MySQL
"""
2025-12-17 11:43:50 +08:00
# 设置字符集为utf8mb4
try:
cursor = dbapi_connection.cursor()
cursor.execute("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci")
cursor.close()
except Exception:
pass
2025-12-16 18:06:50 +08:00
# 创建会话
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db():
"""
获取数据库会话
"""
db = SessionLocal()
try:
yield db
2025-12-18 13:28:29 +08:00
except Exception as e:
# 如果发生异常,先回滚未提交的事务
try:
if db.is_active:
db.rollback()
except Exception:
pass
raise e
2025-12-16 18:06:50 +08:00
finally:
2025-12-18 13:28:29 +08:00
# 安全地关闭数据库连接
try:
# 检查会话状态
if db.is_active:
db.close()
else:
# 如果会话已经不活跃,尝试无效化连接
try:
db.close()
except Exception:
pass
except Exception:
# 忽略关闭时的错误,避免掩盖原始异常
pass
2025-12-16 18:06:50 +08:00
def create_tables():
"""
创建所有表
"""
from ..models.base import Base
from ..models import user, game
Base.metadata.create_all(bind=engine)
def drop_tables():
"""
删除所有表仅用于开发环境
"""
from ..models.base import Base
from ..models import user, game
Base.metadata.drop_all(bind=engine)