290 lines
11 KiB
Python
290 lines
11 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
模型管理模块
|
||
提供统一的模型API抽象、路由和配置管理功能
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import List, Dict, Any, Optional, Union
|
||
from openai import OpenAI
|
||
import os
|
||
import json
|
||
from config import settings
|
||
from exceptions import GenerationError as ModelError
|
||
|
||
|
||
class ModelProvider(ABC):
    """Abstract base class for model providers.

    Holds the connection metadata shared by every provider and defines the
    interface (client access, text generation, embeddings) that concrete
    providers must implement.
    """

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        """Store provider identity and connection settings.

        Args:
            name: Provider name used for routing.
            api_key: API key used to authenticate requests.
            base_url: Base URL of the provider's API endpoint.
            models: Model identifiers served by this provider.
        """
        self.name = name
        self.api_key = api_key
        self.base_url = base_url
        self.models = models

    @abstractmethod
    def get_client(self) -> Any:
        """Return the underlying API client object."""

    @abstractmethod
    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Generate a chat completion and return its text content."""

    @abstractmethod
    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text."""
|
||
|
||
|
||
class OpenAIProvider(ModelProvider):
    """Model provider backed by the OpenAI chat/embeddings API."""

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        """Create the provider and its OpenAI client.

        Args:
            name: Provider name used for routing.
            api_key: OpenAI API key.
            base_url: API endpoint base URL.
            models: Model identifiers served by this provider.
        """
        super().__init__(name, api_key, base_url, models)
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def get_client(self) -> OpenAI:
        """Return the shared OpenAI client instance."""
        return self.client

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Generate a chat completion and return its text content.

        Args:
            model: Model identifier to use.
            messages: Chat messages in OpenAI format.
            **kwargs: Extra arguments passed through to the API call.

        Raises:
            ModelError: If the underlying API call fails.
        """
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs
            )
            # Content may be None (e.g. tool-call responses); normalize to "".
            return response.choices[0].message.content or ""
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ModelError(f"OpenAI文本生成失败: {str(e)}") from e

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text.

        Raises:
            ModelError: If the underlying API call fails.
        """
        try:
            response = self.client.embeddings.create(
                model=model,
                input=texts
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            raise ModelError(f"OpenAI嵌入获取失败: {str(e)}") from e
|
||
|
||
|
||
class OpenRouterProvider(ModelProvider):
    """Model provider for OpenRouter's OpenAI-compatible API."""

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        """Create the provider and its API client.

        Args:
            name: Provider name used for routing.
            api_key: OpenRouter API key.
            base_url: API endpoint base URL.
            models: Model identifiers served by this provider.
        """
        super().__init__(name, api_key, base_url, models)
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def get_client(self) -> OpenAI:
        """Return the shared API client instance."""
        return self.client

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Generate a chat completion and return its text content.

        OpenRouter accepts extra HTTP headers (passed via the
        ``extra_headers`` keyword), e.g. for request attribution.

        Raises:
            ModelError: If the underlying API call fails.
        """
        try:
            # OpenRouter supports extra request headers.
            extra_headers = kwargs.pop("extra_headers", {})
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                extra_headers=extra_headers,
                **kwargs
            )
            # Content may be None (e.g. tool-call responses); normalize to "".
            return response.choices[0].message.content or ""
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ModelError(f"OpenRouter文本生成失败: {str(e)}") from e

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text.

        NOTE: OpenRouter is primarily a generation gateway; embedding
        support depends on the selected model.

        Raises:
            ModelError: If the underlying API call fails.
        """
        try:
            response = self.client.embeddings.create(
                model=model,
                input=texts
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            raise ModelError(f"OpenRouter嵌入获取失败: {str(e)}") from e
|
||
|
||
|
||
class SiliconFlowProvider(ModelProvider):
    """Model provider for SiliconFlow's OpenAI-compatible API."""

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        """Create the provider and its API client.

        Args:
            name: Provider name used for routing.
            api_key: SiliconFlow API key.
            base_url: API endpoint base URL.
            models: Model identifiers served by this provider.
        """
        super().__init__(name, api_key, base_url, models)
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def get_client(self) -> OpenAI:
        """Return the shared API client instance."""
        return self.client

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Generate a chat completion and return its text content.

        Raises:
            ModelError: If the underlying API call fails.
        """
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs
            )
            # Content may be None (e.g. tool-call responses); normalize to "".
            return response.choices[0].message.content or ""
        except Exception as e:
            # Chain the original exception so the root cause is preserved.
            raise ModelError(f"硅基流动文本生成失败: {str(e)}") from e

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text.

        Raises:
            ModelError: If the underlying API call fails.
        """
        try:
            response = self.client.embeddings.create(
                model=model,
                input=texts
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            raise ModelError(f"硅基流动嵌入获取失败: {str(e)}") from e
|
||
|
||
|
||
class ModelRouter:
    """Model router supporting dynamic routing across providers.

    Maintains a registry of providers, a model-name -> provider-name
    mapping, and optional per-model routing rules loaded from
    ``model_config.json``.
    """

    def __init__(self):
        self.providers: Dict[str, "ModelProvider"] = {}
        self.model_mapping: Dict[str, str] = {}  # model_name -> provider_name
        self.routes: Dict[str, Dict] = {}  # per-model routing rules
        self.load_default_providers()
        self.load_config()

    def _register(self, provider: "ModelProvider") -> None:
        """Store *provider* and map each of its models to it."""
        self.providers[provider.name] = provider
        for model in provider.models:
            self.model_mapping[model] = provider.name

    def load_default_providers(self):
        """Register a provider for every API key configured in settings."""
        # OpenAI provider
        if settings.OPENAI_API_KEY:
            self._register(OpenAIProvider(
                name="openai",
                api_key=settings.OPENAI_API_KEY,
                base_url=settings.OPENAI_API_BASE or "https://api.openai.com/v1",
                models=["gpt-3.5-turbo", "gpt-4", "gpt-4o", "text-embedding-ada-002"]
            ))

        # Anthropic provider (accessed through the OpenAI-compatible client)
        if settings.ANTHROPIC_API_KEY:
            self._register(OpenAIProvider(
                name="anthropic",
                api_key=settings.ANTHROPIC_API_KEY,
                base_url=settings.ANTHROPIC_API_BASE or "https://api.anthropic.com/v1",
                models=["claude-3-haiku", "claude-3-sonnet", "claude-3-opus"]
            ))

        # Qwen (Tongyi Qianwen) provider
        if settings.QWEN_API_KEY:
            self._register(OpenAIProvider(
                name="qwen",
                api_key=settings.QWEN_API_KEY,
                base_url=settings.QWEN_API_BASE or "https://dashscope.aliyuncs.com/api/v1",
                models=["qwen-turbo", "qwen-plus", "qwen-max"]
            ))

    def load_config(self):
        """Load routing configuration from ``model_config.json`` if present.

        Recognized keys: ``providers`` (list of provider config dicts),
        ``model_mapping`` (model -> provider name) and ``routes``
        (per-model routing rules).
        """
        config_file = "model_config.json"
        if not os.path.exists(config_file):
            return
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)

            # Custom providers
            for provider_config in config.get("providers", []):
                self.add_provider(provider_config)

            # Model mapping overrides
            self.model_mapping.update(config.get("model_mapping", {}))

            # Routing rules
            if "routes" in config:
                self.routes = config["routes"]
        except Exception as e:
            # Best-effort: a broken config file must not prevent startup.
            print(f"加载模型配置失败: {e}")

    def add_provider(self, provider_config: Dict):
        """Register a provider described by a config dict.

        Expected keys: ``type``, ``name``, ``api_key``, ``base_url`` and
        optionally ``models`` (list of model names served).

        Raises:
            ModelError: If required keys are missing/empty or the provider
                type is not supported.
        """
        provider_type = provider_config.get("type")
        name = provider_config.get("name")
        api_key = provider_config.get("api_key")
        base_url = provider_config.get("base_url")
        models = provider_config.get("models", [])

        # Single validation pass; the original had a second, unreachable
        # duplicate of this check which has been removed.
        if not all([provider_type, name, api_key, base_url]):
            raise ModelError("提供商配置不完整")

        provider_classes = {
            "openai": OpenAIProvider,
            "openrouter": OpenRouterProvider,
            "siliconflow": SiliconFlowProvider,
        }
        provider_cls = provider_classes.get(provider_type)
        if provider_cls is None:
            raise ModelError(f"不支持的提供商类型: {provider_type}")

        self._register(provider_cls(name, api_key, base_url, models))

    def get_provider_for_model(self, model_name: str) -> Optional["ModelProvider"]:
        """Resolve the provider responsible for *model_name*.

        Resolution order: explicit routing rules, then the model mapping,
        then any provider that lists the model. Returns None if no
        provider is found.
        """
        # 1) Explicit routing rules
        if model_name in self.routes:
            provider_name = self.routes[model_name].get("provider")
            if provider_name and provider_name in self.providers:
                return self.providers[provider_name]

        # 2) Model -> provider mapping
        if model_name in self.model_mapping:
            provider_name = self.model_mapping[model_name]
            if provider_name in self.providers:
                return self.providers[provider_name]

        # 3) Fallback: any provider advertising the model
        for provider in self.providers.values():
            if model_name in provider.models:
                return provider

        return None

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Route a text-generation request to the matching provider.

        Raises:
            ModelError: If no provider serves *model*.
        """
        provider = self.get_provider_for_model(model)
        if not provider:
            raise ModelError(f"未找到模型 {model} 的提供商")

        return provider.generate_text(model, messages, **kwargs)

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Route an embeddings request to the matching provider.

        Raises:
            ModelError: If no provider serves *model*.
        """
        provider = self.get_provider_for_model(model)
        if not provider:
            raise ModelError(f"未找到模型 {model} 的提供商")

        return provider.get_embeddings(model, texts)

    def list_models(self) -> List[str]:
        """Return the names of all mapped models."""
        return list(self.model_mapping.keys())

    def list_providers(self) -> List[str]:
        """Return the names of all registered providers."""
        return list(self.providers.keys())
|
||
|
||
|
||
# Global model router instance, created at import time so the whole
# application shares one provider registry.
model_router = ModelRouter()