#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Model management module (nodebookls/model_manager.py).

Provides a unified model API abstraction, model routing, and configuration management.
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from openai import OpenAI
import os
import json
from config import settings
from exceptions import GenerationError as ModelError


class ModelProvider(ABC):
    """Abstract base class for model providers."""

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        self.name = name
        self.api_key = api_key
        self.base_url = base_url
        self.models = models

    @abstractmethod
    def get_client(self) -> Any:
        """Return the underlying model client."""
        pass

    @abstractmethod
    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Generate text from a list of chat messages."""
        pass

    @abstractmethod
    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Return embeddings for the given texts."""
        pass


class OpenAIProvider(ModelProvider):
    """OpenAI model provider implementation."""

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        super().__init__(name, api_key, base_url, models)
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def get_client(self) -> OpenAI:
        return self.client

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            raise ModelError(f"OpenAI text generation failed: {str(e)}")

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        try:
            response = self.client.embeddings.create(
                model=model,
                input=texts
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            raise ModelError(f"OpenAI embedding request failed: {str(e)}")


class OpenRouterProvider(ModelProvider):
    """OpenRouter model provider implementation."""

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        super().__init__(name, api_key, base_url, models)
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def get_client(self) -> OpenAI:
        return self.client

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        try:
            # OpenRouter supports extra request headers
            extra_headers = kwargs.pop("extra_headers", {})
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                extra_headers=extra_headers,
                **kwargs
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            raise ModelError(f"OpenRouter text generation failed: {str(e)}")

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        # OpenRouter is mainly intended for generation models; embedding models may need special handling
        try:
            response = self.client.embeddings.create(
                model=model,
                input=texts
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            raise ModelError(f"OpenRouter embedding request failed: {str(e)}")


class SiliconFlowProvider(ModelProvider):
    """SiliconFlow model provider implementation."""

    def __init__(self, name: str, api_key: str, base_url: str, models: List[str]):
        super().__init__(name, api_key, base_url, models)
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def get_client(self) -> OpenAI:
        return self.client

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                **kwargs
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            raise ModelError(f"SiliconFlow text generation failed: {str(e)}")

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        try:
            response = self.client.embeddings.create(
                model=model,
                input=texts
            )
            return [item.embedding for item in response.data]
        except Exception as e:
            raise ModelError(f"SiliconFlow embedding request failed: {str(e)}")


class ModelRouter:
    """Model router with support for dynamic routing and load balancing."""

    def __init__(self):
        self.providers: Dict[str, ModelProvider] = {}
        self.model_mapping: Dict[str, str] = {}  # model_name -> provider_name
        self.routes: Dict[str, Dict] = {}  # routing rules
        self.load_default_providers()
        self.load_config()

    def load_default_providers(self):
        """Load the default model providers."""
        # OpenAI provider
        if settings.OPENAI_API_KEY:
            openai_provider = OpenAIProvider(
                name="openai",
                api_key=settings.OPENAI_API_KEY,
                base_url=settings.OPENAI_API_BASE or "https://api.openai.com/v1",
                models=["gpt-3.5-turbo", "gpt-4", "gpt-4o", "text-embedding-ada-002"]
            )
            self.providers["openai"] = openai_provider
            for model in openai_provider.models:
                self.model_mapping[model] = "openai"
        # Anthropic provider (accessed via the OpenAI-compatible client)
        if settings.ANTHROPIC_API_KEY:
            anthropic_provider = OpenAIProvider(
                name="anthropic",
                api_key=settings.ANTHROPIC_API_KEY,
                base_url=settings.ANTHROPIC_API_BASE or "https://api.anthropic.com/v1",
                models=["claude-3-haiku", "claude-3-sonnet", "claude-3-opus"]
            )
            self.providers["anthropic"] = anthropic_provider
            for model in anthropic_provider.models:
                self.model_mapping[model] = "anthropic"
        # Qwen (Tongyi Qianwen) provider
        if settings.QWEN_API_KEY:
            qwen_provider = OpenAIProvider(
                name="qwen",
                api_key=settings.QWEN_API_KEY,
                base_url=settings.QWEN_API_BASE or "https://dashscope.aliyuncs.com/api/v1",
                models=["qwen-turbo", "qwen-plus", "qwen-max"]
            )
            self.providers["qwen"] = qwen_provider
            for model in qwen_provider.models:
                self.model_mapping[model] = "qwen"

    def load_config(self):
        """Load model routing configuration from the config file."""
        config_file = "model_config.json"
        if os.path.exists(config_file):
            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                # Load custom providers
                if "providers" in config:
                    for provider_config in config["providers"]:
                        self.add_provider(provider_config)
                # Load model mapping
                if "model_mapping" in config:
                    self.model_mapping.update(config["model_mapping"])
                # Load routing rules
                if "routes" in config:
                    self.routes = config["routes"]
            except Exception as e:
                print(f"Failed to load model configuration: {e}")

    def add_provider(self, provider_config: Dict):
        """Add a model provider from a configuration dict."""
        provider_type = provider_config.get("type")
        name = provider_config.get("name")
        api_key = provider_config.get("api_key")
        base_url = provider_config.get("base_url")
        models = provider_config.get("models", [])
        if not all([provider_type, name, api_key, base_url]):
            raise ModelError("Provider configuration is incomplete")
        # Ensure the required fields are non-empty
        if not name or not api_key or not base_url:
            raise ModelError("Provider configuration is incomplete: name, api_key, and base_url must not be empty")
        if provider_type == "openai":
            provider = OpenAIProvider(name, api_key, base_url, models)
        elif provider_type == "openrouter":
            provider = OpenRouterProvider(name, api_key, base_url, models)
        elif provider_type == "siliconflow":
            provider = SiliconFlowProvider(name, api_key, base_url, models)
        else:
            raise ModelError(f"Unsupported provider type: {provider_type}")
        self.providers[name] = provider
        # Update the model mapping
        for model in models:
            self.model_mapping[model] = name

    def get_provider_for_model(self, model_name: str) -> Optional[ModelProvider]:
        """Return the provider responsible for the given model name."""
        # Check routing rules first
        if model_name in self.routes:
            provider_name = self.routes[model_name].get("provider")
            if provider_name and provider_name in self.providers:
                return self.providers[provider_name]
        # Then check the model mapping
        if model_name in self.model_mapping:
            provider_name = self.model_mapping[model_name]
            if provider_name in self.providers:
                return self.providers[provider_name]
        # Finally, fall back to any provider that lists the model
        for provider in self.providers.values():
            if model_name in provider.models:
                return provider
        return None

    def generate_text(self, model: str, messages: List[Dict], **kwargs) -> str:
        """Route a text-generation request to the appropriate provider."""
        provider = self.get_provider_for_model(model)
        if not provider:
            raise ModelError(f"No provider found for model {model}")
        return provider.generate_text(model, messages, **kwargs)

    def get_embeddings(self, model: str, texts: List[str]) -> List[List[float]]:
        """Route an embedding request to the appropriate provider."""
        provider = self.get_provider_for_model(model)
        if not provider:
            raise ModelError(f"No provider found for model {model}")
        return provider.get_embeddings(model, texts)

    def list_models(self) -> List[str]:
        """List all available models."""
        return list(self.model_mapping.keys())

    def list_providers(self) -> List[str]:
        """List all providers."""
        return list(self.providers.keys())


# Global model router instance
model_router = ModelRouter()
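

# Minimal usage sketch (illustrative only, not part of the routing API).
# It assumes at least one provider has been configured via the environment
# settings or model_config.json; the prompt text and the choice of the first
# listed model are placeholders for demonstration.
if __name__ == "__main__":
    print("Providers:", model_router.list_providers())
    print("Models:", model_router.list_models())
    if model_router.list_models():
        example_model = model_router.list_models()[0]
        reply = model_router.generate_text(
            model=example_model,
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
        )
        print(f"{example_model}: {reply}")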