112 lines
3.0 KiB
Python
112 lines
3.0 KiB
Python
"""
|
||
数据库连接和会话管理
|
||
"""
|
||
from sqlalchemy import create_engine, event
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.pool import QueuePool
|
||
from .config import settings
|
||
|
||
# 创建数据库引擎,支持MySQL连接池
|
||
db_settings = settings.database
|
||
|
||
if "sqlite" in db_settings.DATABASE_URL:
|
||
# SQLite 配置
|
||
engine = create_engine(
|
||
db_settings.DATABASE_URL,
|
||
echo=False,
|
||
future=True,
|
||
connect_args={"check_same_thread": False} # SQLite 特定配置
|
||
)
|
||
else:
|
||
# MySQL 配置
|
||
engine = create_engine(
|
||
db_settings.DATABASE_URL,
|
||
# 连接池配置
|
||
poolclass=QueuePool,
|
||
pool_size=db_settings.DB_POOL_SIZE,
|
||
max_overflow=db_settings.DB_MAX_OVERFLOW,
|
||
pool_timeout=db_settings.DB_POOL_TIMEOUT,
|
||
pool_recycle=db_settings.DB_POOL_RECYCLE,
|
||
pool_pre_ping=True, # 每次使用连接前检测是否有效,防止使用已断开的连接
|
||
echo=False,
|
||
future=True,
|
||
# 设置连接时使用的字符集
|
||
connect_args={
|
||
"charset": "utf8mb4",
|
||
"collation": "utf8mb4_unicode_ci",
|
||
# 添加连接超时设置
|
||
"connect_timeout": 60,
|
||
"read_timeout": 60,
|
||
"write_timeout": 60
|
||
}
|
||
)
|
||
|
||
# 如果是MySQL,启用严格模式
|
||
if "mysql" in db_settings.DATABASE_URL:
|
||
@event.listens_for(engine, "connect")
|
||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||
"""
|
||
为MySQL连接设置参数(如果使用MySQL)
|
||
"""
|
||
# 设置字符集为utf8mb4
|
||
try:
|
||
cursor = dbapi_connection.cursor()
|
||
cursor.execute("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci")
|
||
cursor.close()
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
# 创建会话
|
||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||
|
||
|
||
def get_db():
|
||
"""
|
||
获取数据库会话
|
||
"""
|
||
db = SessionLocal()
|
||
try:
|
||
yield db
|
||
except Exception as e:
|
||
# 如果发生异常,先回滚未提交的事务
|
||
try:
|
||
if db.is_active:
|
||
db.rollback()
|
||
except Exception:
|
||
pass
|
||
raise e
|
||
finally:
|
||
# 安全地关闭数据库连接
|
||
try:
|
||
# 检查会话状态
|
||
if db.is_active:
|
||
db.close()
|
||
else:
|
||
# 如果会话已经不活跃,尝试无效化连接
|
||
try:
|
||
db.close()
|
||
except Exception:
|
||
pass
|
||
except Exception:
|
||
# 忽略关闭时的错误,避免掩盖原始异常
|
||
pass
|
||
|
||
|
||
def create_tables():
|
||
"""
|
||
创建所有表
|
||
"""
|
||
from ..models.base import Base
|
||
from ..models import user, game
|
||
Base.metadata.create_all(bind=engine)
|
||
|
||
|
||
def drop_tables():
|
||
"""
|
||
删除所有表(仅用于开发环境)
|
||
"""
|
||
from ..models.base import Base
|
||
from ..models import user, game
|
||
Base.metadata.drop_all(bind=engine)
|