mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2025-12-06 13:54:41 +08:00
Improve(interface): replace some struct with interface for testing (#994)
* fix(trader): get peakPnlPct using posKey * fix(docs): keep readme at the same page * improve(interface): replace with interface * refactor mcp --------- Co-authored-by: zbhan <zbhan@freewheel.tv>
This commit is contained in:
@@ -121,12 +121,12 @@ type FullDecision struct {
|
||||
}
|
||||
|
||||
// GetFullDecision 获取AI的完整交易决策(批量分析所有币种和持仓)
|
||||
func GetFullDecision(ctx *Context, mcpClient *mcp.Client) (*FullDecision, error) {
|
||||
func GetFullDecision(ctx *Context, mcpClient mcp.AIClient) (*FullDecision, error) {
|
||||
return GetFullDecisionWithCustomPrompt(ctx, mcpClient, "", false, "")
|
||||
}
|
||||
|
||||
// GetFullDecisionWithCustomPrompt 获取AI的完整交易决策(支持自定义prompt和模板选择)
|
||||
func GetFullDecisionWithCustomPrompt(ctx *Context, mcpClient *mcp.Client, customPrompt string, overrideBase bool, templateName string) (*FullDecision, error) {
|
||||
func GetFullDecisionWithCustomPrompt(ctx *Context, mcpClient mcp.AIClient, customPrompt string, overrideBase bool, templateName string) (*FullDecision, error) {
|
||||
// 1. 为所有币种获取市场数据
|
||||
if err := fetchMarketDataForContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("获取市场数据失败: %w", err)
|
||||
|
||||
@@ -64,6 +64,22 @@ type DecisionAction struct {
|
||||
Error string `json:"error"` // 错误信息
|
||||
}
|
||||
|
||||
// IDecisionLogger 决策日志记录器接口
|
||||
type IDecisionLogger interface {
|
||||
// LogDecision 记录决策
|
||||
LogDecision(record *DecisionRecord) error
|
||||
// GetLatestRecords 获取最近N条记录(按时间正序:从旧到新)
|
||||
GetLatestRecords(n int) ([]*DecisionRecord, error)
|
||||
// GetRecordByDate 获取指定日期的所有记录
|
||||
GetRecordByDate(date time.Time) ([]*DecisionRecord, error)
|
||||
// CleanOldRecords 清理N天前的旧记录
|
||||
CleanOldRecords(days int) error
|
||||
// GetStatistics 获取统计信息
|
||||
GetStatistics() (*Statistics, error)
|
||||
// AnalyzePerformance 分析最近N个周期的交易表现
|
||||
AnalyzePerformance(lookbackCycles int) (*PerformanceAnalysis, error)
|
||||
}
|
||||
|
||||
// DecisionLogger 决策日志记录器
|
||||
type DecisionLogger struct {
|
||||
logDir string
|
||||
@@ -71,7 +87,7 @@ type DecisionLogger struct {
|
||||
}
|
||||
|
||||
// NewDecisionLogger 创建决策日志记录器
|
||||
func NewDecisionLogger(logDir string) *DecisionLogger {
|
||||
func NewDecisionLogger(logDir string) IDecisionLogger {
|
||||
if logDir == "" {
|
||||
logDir = "decision_logs"
|
||||
}
|
||||
|
||||
101
mcp/client.go
101
mcp/client.go
@@ -13,18 +13,17 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider AI提供商类型
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
ProviderDeepSeek Provider = "deepseek"
|
||||
ProviderQwen Provider = "qwen"
|
||||
ProviderCustom Provider = "custom"
|
||||
ProviderCustom = "custom"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultTimeout = 120 * time.Second
|
||||
)
|
||||
|
||||
// Client AI API配置
|
||||
type Client struct {
|
||||
Provider Provider
|
||||
Provider string
|
||||
APIKey string
|
||||
BaseURL string
|
||||
Model string
|
||||
@@ -33,7 +32,7 @@ type Client struct {
|
||||
MaxTokens int // AI响应的最大token数
|
||||
}
|
||||
|
||||
func New() *Client {
|
||||
func New() AIClient {
|
||||
// 从环境变量读取 MaxTokens,默认 2000
|
||||
maxTokens := 2000
|
||||
if envMaxTokens := os.Getenv("AI_MAX_TOKENS"); envMaxTokens != "" {
|
||||
@@ -48,65 +47,15 @@ func New() *Client {
|
||||
// 默认配置
|
||||
return &Client{
|
||||
Provider: ProviderDeepSeek,
|
||||
BaseURL: "https://api.deepseek.com/v1",
|
||||
Model: "deepseek-chat",
|
||||
Timeout: 120 * time.Second, // 增加到120秒,因为AI需要分析大量数据
|
||||
BaseURL: DefaultDeepSeekBaseURL,
|
||||
Model: DefaultDeepSeekModel,
|
||||
Timeout: DefaultTimeout,
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// SetDeepSeekAPIKey 设置DeepSeek API密钥
|
||||
// customURL 为空时使用默认URL,customModel 为空时使用默认模型
|
||||
func (client *Client) SetDeepSeekAPIKey(apiKey string, customURL string, customModel string) {
|
||||
client.Provider = ProviderDeepSeek
|
||||
client.APIKey = apiKey
|
||||
if customURL != "" {
|
||||
client.BaseURL = customURL
|
||||
log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL)
|
||||
} else {
|
||||
client.BaseURL = "https://api.deepseek.com/v1"
|
||||
log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", client.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
client.Model = customModel
|
||||
log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel)
|
||||
} else {
|
||||
client.Model = "deepseek-chat"
|
||||
log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", client.Model)
|
||||
}
|
||||
// 打印 API Key 的前后各4位用于验证
|
||||
if len(apiKey) > 8 {
|
||||
log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
}
|
||||
|
||||
// SetQwenAPIKey 设置阿里云Qwen API密钥
|
||||
// customURL 为空时使用默认URL,customModel 为空时使用默认模型
|
||||
func (client *Client) SetQwenAPIKey(apiKey string, customURL string, customModel string) {
|
||||
client.Provider = ProviderQwen
|
||||
client.APIKey = apiKey
|
||||
if customURL != "" {
|
||||
client.BaseURL = customURL
|
||||
log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL)
|
||||
} else {
|
||||
client.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", client.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
client.Model = customModel
|
||||
log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel)
|
||||
} else {
|
||||
client.Model = "qwen3-max"
|
||||
log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", client.Model)
|
||||
}
|
||||
// 打印 API Key 的前后各4位用于验证
|
||||
if len(apiKey) > 8 {
|
||||
log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
}
|
||||
|
||||
// SetCustomAPI 设置自定义OpenAI兼容API
|
||||
func (client *Client) SetCustomAPI(apiURL, apiKey, modelName string) {
|
||||
func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
|
||||
client.Provider = ProviderCustom
|
||||
client.APIKey = apiKey
|
||||
|
||||
@@ -119,22 +68,14 @@ func (client *Client) SetCustomAPI(apiURL, apiKey, modelName string) {
|
||||
client.UseFullURL = false
|
||||
}
|
||||
|
||||
client.Model = modelName
|
||||
client.Model = customModel
|
||||
client.Timeout = 120 * time.Second
|
||||
}
|
||||
|
||||
// SetClient 设置完整的AI配置(高级用户)
|
||||
func (client *Client) SetClient(Client Client) {
|
||||
if Client.Timeout == 0 {
|
||||
Client.Timeout = 30 * time.Second
|
||||
}
|
||||
client = &Client
|
||||
}
|
||||
|
||||
// CallWithMessages 使用 system + user prompt 调用AI API(推荐)
|
||||
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
|
||||
if client.APIKey == "" {
|
||||
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetDeepSeekAPIKey() 或 SetQwenAPIKey()")
|
||||
return "", fmt.Errorf("AI API密钥未设置,请先调用 SetAPIKey")
|
||||
}
|
||||
|
||||
// 重试配置
|
||||
@@ -171,6 +112,10 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
|
||||
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func (client *Client) setAuthHeader(reqHeader http.Header) {
|
||||
reqHeader.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
|
||||
}
|
||||
|
||||
// callOnce 单次调用AI API(内部使用)
|
||||
func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) {
|
||||
// 打印当前 AI 配置
|
||||
@@ -234,17 +179,7 @@ func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error)
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 根据不同的Provider设置认证方式
|
||||
switch client.Provider {
|
||||
case ProviderDeepSeek:
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
|
||||
case ProviderQwen:
|
||||
// 阿里云Qwen使用API-Key认证
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
|
||||
// 注意:如果使用的不是兼容模式,可能需要不同的认证方式
|
||||
default:
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
|
||||
}
|
||||
client.setAuthHeader(req.Header)
|
||||
|
||||
// 发送请求
|
||||
httpClient := &http.Client{Timeout: client.Timeout}
|
||||
|
||||
53
mcp/deepseek_client.go
Normal file
53
mcp/deepseek_client.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderDeepSeek = "deepseek"
|
||||
DefaultDeepSeekBaseURL = "https://api.deepseek.com/v1"
|
||||
DefaultDeepSeekModel = "deepseek-chat"
|
||||
)
|
||||
|
||||
type DeepSeekClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
func NewDeepSeekClient() AIClient {
|
||||
client := New().(*Client)
|
||||
client.Provider = ProviderDeepSeek
|
||||
client.Model = DefaultDeepSeekModel
|
||||
client.BaseURL = DefaultDeepSeekBaseURL
|
||||
return &DeepSeekClient{
|
||||
Client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
if dsClient.Client == nil {
|
||||
dsClient.Client = New().(*Client)
|
||||
}
|
||||
dsClient.Client.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
dsClient.Client.BaseURL = customURL
|
||||
log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL)
|
||||
} else {
|
||||
log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", dsClient.Client.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
dsClient.Client.Model = customModel
|
||||
log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel)
|
||||
} else {
|
||||
log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", dsClient.Client.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func (dsClient *DeepSeekClient) setAuthHeader(reqHeaders http.Header) {
|
||||
dsClient.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
12
mcp/interface.go
Normal file
12
mcp/interface.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package mcp
|
||||
|
||||
import "net/http"
|
||||
|
||||
// AIClient AI客户端接口
|
||||
type AIClient interface {
|
||||
SetAPIKey(apiKey string, customURL string, customModel string)
|
||||
// CallWithMessages 使用 system + user prompt 调用AI API
|
||||
CallWithMessages(systemPrompt, userPrompt string) (string, error)
|
||||
|
||||
setAuthHeader(reqHeaders http.Header)
|
||||
}
|
||||
53
mcp/qwen_client.go
Normal file
53
mcp/qwen_client.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderQwen = "qwen"
|
||||
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
DefaultQwenModel = "qwen3-max"
|
||||
)
|
||||
|
||||
type QwenClient struct {
|
||||
*Client
|
||||
}
|
||||
|
||||
func NewQwenClient() AIClient {
|
||||
client := New().(*Client)
|
||||
client.Provider = ProviderQwen
|
||||
client.Model = DefaultQwenModel
|
||||
client.BaseURL = DefaultQwenBaseURL
|
||||
return &QwenClient{
|
||||
Client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) {
|
||||
if qwenClient.Client == nil {
|
||||
qwenClient.Client = New().(*Client)
|
||||
}
|
||||
qwenClient.Client.APIKey = apiKey
|
||||
|
||||
if len(apiKey) > 8 {
|
||||
log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
|
||||
}
|
||||
if customURL != "" {
|
||||
qwenClient.Client.BaseURL = customURL
|
||||
log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL)
|
||||
} else {
|
||||
log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", qwenClient.Client.BaseURL)
|
||||
}
|
||||
if customModel != "" {
|
||||
qwenClient.Client.Model = customModel
|
||||
log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel)
|
||||
} else {
|
||||
log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", qwenClient.Client.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func (qwenClient *QwenClient) setAuthHeader(reqHeaders http.Header) {
|
||||
qwenClient.Client.setAuthHeader(reqHeaders)
|
||||
}
|
||||
@@ -85,8 +85,8 @@ type AutoTrader struct {
|
||||
exchange string // 交易平台名称
|
||||
config AutoTraderConfig
|
||||
trader Trader // 使用Trader接口(支持多平台)
|
||||
mcpClient *mcp.Client
|
||||
decisionLogger *logger.DecisionLogger // 决策日志记录器
|
||||
mcpClient mcp.AIClient
|
||||
decisionLogger logger.IDecisionLogger // 决策日志记录器
|
||||
initialBalance float64
|
||||
dailyPnL float64
|
||||
customPrompt string // 自定义交易策略prompt
|
||||
@@ -131,11 +131,12 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string)
|
||||
// 初始化AI
|
||||
if config.AIModel == "custom" {
|
||||
// 使用自定义API
|
||||
mcpClient.SetCustomAPI(config.CustomAPIURL, config.CustomAPIKey, config.CustomModelName)
|
||||
mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName)
|
||||
log.Printf("🤖 [%s] 使用自定义AI API: %s (模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName)
|
||||
} else if config.UseQwen || config.AIModel == "qwen" {
|
||||
// 使用Qwen (支持自定义URL和Model)
|
||||
mcpClient.SetQwenAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName)
|
||||
mcpClient = mcp.NewQwenClient()
|
||||
mcpClient.SetAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName)
|
||||
if config.CustomAPIURL != "" || config.CustomModelName != "" {
|
||||
log.Printf("🤖 [%s] 使用阿里云Qwen AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName)
|
||||
} else {
|
||||
@@ -143,7 +144,8 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string)
|
||||
}
|
||||
} else {
|
||||
// 默认使用DeepSeek (支持自定义URL和Model)
|
||||
mcpClient.SetDeepSeekAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName)
|
||||
mcpClient = mcp.NewDeepSeekClient()
|
||||
mcpClient.SetAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName)
|
||||
if config.CustomAPIURL != "" || config.CustomModelName != "" {
|
||||
log.Printf("🤖 [%s] 使用DeepSeek AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName)
|
||||
} else {
|
||||
@@ -1205,7 +1207,7 @@ func (at *AutoTrader) GetSystemPromptTemplate() string {
|
||||
}
|
||||
|
||||
// GetDecisionLogger 获取决策日志记录器
|
||||
func (at *AutoTrader) GetDecisionLogger() *logger.DecisionLogger {
|
||||
func (at *AutoTrader) GetDecisionLogger() logger.IDecisionLogger {
|
||||
return at.decisionLogger
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ type AutoTraderTestSuite struct {
|
||||
// Mock 依赖
|
||||
mockTrader *MockTrader
|
||||
mockDB *MockDatabase
|
||||
mockLogger *logger.DecisionLogger
|
||||
mockLogger logger.IDecisionLogger
|
||||
|
||||
// gomonkey patches
|
||||
patches *gomonkey.Patches
|
||||
|
||||
Reference in New Issue
Block a user