#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Model management module.

Provides a unified abstraction over model APIs, plus routing and
configuration management that maps model names to the provider serving them.
"""
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type, Union

from openai import OpenAI

from config import settings
from exceptions import GenerationError as ModelError

logger = logging.getLogger(__name__)


class ModelProvider(ABC):
    """Abstract base class for a model provider.

    Attributes:
        name: Provider identifier; used as the key in ``ModelRouter.providers``.
        api_key: Credential handed to the provider's client.
        base_url: API endpoint the client talks to.
        models: Model names this provider serves.
    """

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        self.name = name
        self.api_key = api_key
        self.base_url = base_url
        self.models = models

    @abstractmethod
    def get_client(self) -> Any:
        """Return the underlying API client object."""

    @abstractmethod
    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Run a chat completion with ``model`` and return the generated text."""

    @abstractmethod
    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``."""


class _OpenAICompatibleProvider(ModelProvider):
    """Shared implementation for providers reached through the OpenAI SDK.

    Subclasses only override ``_label``, which prefixes error messages so a
    failure identifies which provider class raised it.
    """

    _label = "OpenAI"

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        super().__init__(name, api_key, base_url, models)
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def get_client(self) -> OpenAI:
        """Return the ``OpenAI`` client bound to this provider's endpoint."""
        return self.client

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Run ``chat.completions.create`` and return the first choice's text.

        Extra keyword arguments (e.g. ``temperature`` or ``extra_headers``)
        are forwarded unchanged to the SDK call.

        Returns:
            The message content of the first choice, or ``""`` if it is empty.

        Raises:
            ModelError: wrapping any error from the request or from reading
                the response (the original exception is chained as the cause).
        """
        try:
            response = self.client.chat.completions.create(
                model=model, messages=messages, **kwargs
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            raise ModelError(f"{self._label}文本生成失败: {e}") from e

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with ``model``; results follow ``response.data`` order.

        Raises:
            ModelError: wrapping any error from the SDK call (cause chained).
        """
        try:
            response = self.client.embeddings.create(model=model, input=texts)
            return [item.embedding for item in response.data]
        except Exception as e:
            raise ModelError(f"{self._label}嵌入获取失败: {e}") from e


class OpenAIProvider(_OpenAICompatibleProvider):
    """OpenAI provider; also used for other OpenAI-compatible endpoints."""

    _label = "OpenAI"


class OpenRouterProvider(_OpenAICompatibleProvider):
    """OpenRouter provider.

    Per-request HTTP headers can be passed through the ``extra_headers``
    keyword of ``generate_text``; it is forwarded to the OpenAI SDK.
    """

    _label = "OpenRouter"


class SiliconFlowProvider(_OpenAICompatibleProvider):
    """SiliconFlow provider."""

    _label = "硅基流动"


class ModelRouter:
    """Route model requests to the provider that serves each model.

    Providers come from ``settings`` (see ``load_default_providers``) and
    from an optional JSON config file (see ``load_config``). A model name is
    resolved in this order:

    1. an entry in ``routes`` whose ``"provider"`` is a registered provider;
    2. ``model_mapping`` (model name -> provider name);
    3. any registered provider whose ``models`` list contains the name.
    """

    # Config "type" value -> provider class, used by add_provider.
    _PROVIDER_TYPES: Dict[str, Type[ModelProvider]] = {
        "openai": OpenAIProvider,
        "openrouter": OpenRouterProvider,
        "siliconflow": SiliconFlowProvider,
    }

    def __init__(self):
        self.providers: Dict[str, ModelProvider] = {}
        self.model_mapping: Dict[str, str] = {}  # model name -> provider name
        self.routes: Dict[str, Dict] = {}  # model name -> {"provider": name, ...}
        self.load_default_providers()
        self.load_config()

    def _register(self, provider: ModelProvider) -> None:
        """Store ``provider`` and map each of its models to it (later wins)."""
        self.providers[provider.name] = provider
        for model in provider.models:
            self.model_mapping[model] = provider.name

    def load_default_providers(self):
        """Register the built-in providers whose API key is set in ``settings``.

        Anthropic and Qwen are both driven through ``OpenAIProvider``, i.e.
        through their OpenAI-compatible HTTP endpoints.
        """
        if settings.OPENAI_API_KEY:
            self._register(OpenAIProvider(
                name="openai",
                api_key=settings.OPENAI_API_KEY,
                base_url=settings.OPENAI_API_BASE or "https://api.openai.com/v1",
                models=["gpt-3.5-turbo", "gpt-4", "gpt-4o", "text-embedding-ada-002"],
            ))

        if settings.ANTHROPIC_API_KEY:
            # NOTE(review): Anthropic's API generally expects dated model IDs
            # (e.g. "claude-3-haiku-20240307"); confirm these names resolve.
            self._register(OpenAIProvider(
                name="anthropic",
                api_key=settings.ANTHROPIC_API_KEY,
                base_url=settings.ANTHROPIC_API_BASE or "https://api.anthropic.com/v1",
                models=["claude-3-haiku", "claude-3-sonnet", "claude-3-opus"],
            ))

        if settings.QWEN_API_KEY:
            # DashScope's OpenAI-compatible API lives under /compatible-mode/v1;
            # /api/v1 is the native DashScope protocol, which the OpenAI client
            # cannot speak.
            self._register(OpenAIProvider(
                name="qwen",
                api_key=settings.QWEN_API_KEY,
                base_url=settings.QWEN_API_BASE
                or "https://dashscope.aliyuncs.com/compatible-mode/v1",
                models=["qwen-turbo", "qwen-plus", "qwen-max"],
            ))

    def load_config(self, config_file: str = "model_config.json"):
        """Load extra providers, model mappings and routes from a JSON file.

        Loading is best-effort: a missing file is ignored, an unreadable or
        malformed file is logged and skipped, and each bad provider entry is
        logged and skipped without discarding the rest of the config.

        Args:
            config_file: Path to the JSON config, relative to the current
                working directory unless absolute.
        """
        if not os.path.exists(config_file):
            return
        try:
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError) as e:  # JSONDecodeError is a ValueError
            logger.warning("加载模型配置失败: %s", e)
            return
        if not isinstance(config, dict):
            logger.warning("加载模型配置失败: %s", "config root must be a JSON object")
            return

        # Custom providers; one broken entry must not prevent the others.
        for provider_config in config.get("providers") or []:
            try:
                self.add_provider(provider_config)
            except Exception as e:
                logger.warning("加载模型配置失败: %s", e)

        # Explicit model mappings override those derived from provider lists.
        mapping = config.get("model_mapping")
        if isinstance(mapping, dict):
            self.model_mapping.update(mapping)

        routes = config.get("routes")
        if isinstance(routes, dict):
            self.routes = routes

    def add_provider(self, provider_config: Dict):
        """Create a provider from a config dict and register it.

        Expected keys: ``type`` (``"openai"``, ``"openrouter"`` or
        ``"siliconflow"``), ``name``, ``api_key``, ``base_url`` and optionally
        ``models`` (list of model names; ``null``/missing means none).

        Raises:
            ModelError: if a required field is missing or empty, or the
                ``type`` is not supported.
        """
        provider_type = provider_config.get("type")
        name = provider_config.get("name")
        api_key = provider_config.get("api_key")
        base_url = provider_config.get("base_url")
        models = list(provider_config.get("models") or [])

        if not all([provider_type, name, api_key, base_url]):
            raise ModelError("提供商配置不完整")

        provider_cls = (
            self._PROVIDER_TYPES.get(provider_type)
            if isinstance(provider_type, str)
            else None
        )
        if provider_cls is None:
            raise ModelError(f"不支持的提供商类型: {provider_type}")

        self._register(provider_cls(name, api_key, base_url, models))

    def get_provider_for_model(self, model_name: str) -> Optional[ModelProvider]:
        """Resolve ``model_name`` to a registered provider, or ``None``.

        Checks ``routes``, then ``model_mapping``, then every provider's
        ``models`` list; entries pointing at unregistered providers are skipped.
        """
        route = self.routes.get(model_name)
        if isinstance(route, dict):
            provider = self.providers.get(route.get("provider"))
            if provider is not None:
                return provider

        provider = self.providers.get(self.model_mapping.get(model_name))
        if provider is not None:
            return provider

        return next(
            (p for p in self.providers.values() if model_name in p.models), None
        )

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Route a text-generation request to the provider serving ``model``.

        Raises:
            ModelError: if no provider serves ``model``, or the provider fails.
        """
        provider = self.get_provider_for_model(model)
        if not provider:
            raise ModelError(f"未找到模型 {model} 的提供商")
        return provider.generate_text(model, messages, **kwargs)

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Route an embedding request to the provider serving ``model``.

        Raises:
            ModelError: if no provider serves ``model``, or the provider fails.
        """
        provider = self.get_provider_for_model(model)
        if not provider:
            raise ModelError(f"未找到模型 {model} 的提供商")
        return provider.get_embeddings(model, texts)

    def list_models(self) -> List[str]:
        """Return every model name present in ``model_mapping``."""
        return list(self.model_mapping.keys())

    def list_providers(self) -> List[str]:
        """Return the names of all registered providers."""
        return list(self.providers.keys())


# Module-level router; constructing it reads ``settings`` and the config file.
model_router = ModelRouter()