Improve(interface): replace some struct with interface for testing (#994)

* fix(trader): get peakPnlPct using posKey

* fix(docs): keep readme at the same page

* improve(interface): replace with interface

* refactor mcp

---------

Co-authored-by: zbhan <zbhan@freewheel.tv>
This commit is contained in:
Shui
2025-11-13 22:22:05 -05:00
committed by tangmengqiu
parent 79358d4776
commit 3f5f964a67
8 changed files with 164 additions and 93 deletions

View File

@@ -121,12 +121,12 @@ type FullDecision struct {
}
// GetFullDecision 获取AI的完整交易决策批量分析所有币种和持仓
func GetFullDecision(ctx *Context, mcpClient *mcp.Client) (*FullDecision, error) {
func GetFullDecision(ctx *Context, mcpClient mcp.AIClient) (*FullDecision, error) {
return GetFullDecisionWithCustomPrompt(ctx, mcpClient, "", false, "")
}
// GetFullDecisionWithCustomPrompt 获取AI的完整交易决策支持自定义prompt和模板选择
func GetFullDecisionWithCustomPrompt(ctx *Context, mcpClient *mcp.Client, customPrompt string, overrideBase bool, templateName string) (*FullDecision, error) {
func GetFullDecisionWithCustomPrompt(ctx *Context, mcpClient mcp.AIClient, customPrompt string, overrideBase bool, templateName string) (*FullDecision, error) {
// 1. 为所有币种获取市场数据
if err := fetchMarketDataForContext(ctx); err != nil {
return nil, fmt.Errorf("获取市场数据失败: %w", err)

View File

@@ -64,6 +64,22 @@ type DecisionAction struct {
Error string `json:"error"` // 错误信息
}
// IDecisionLogger 决策日志记录器接口
type IDecisionLogger interface {
// LogDecision 记录决策
LogDecision(record *DecisionRecord) error
// GetLatestRecords 获取最近N条记录按时间正序从旧到新
GetLatestRecords(n int) ([]*DecisionRecord, error)
// GetRecordByDate 获取指定日期的所有记录
GetRecordByDate(date time.Time) ([]*DecisionRecord, error)
// CleanOldRecords 清理N天前的旧记录
CleanOldRecords(days int) error
// GetStatistics 获取统计信息
GetStatistics() (*Statistics, error)
// AnalyzePerformance 分析最近N个周期的交易表现
AnalyzePerformance(lookbackCycles int) (*PerformanceAnalysis, error)
}
// DecisionLogger 决策日志记录器
type DecisionLogger struct {
logDir string
@@ -71,7 +87,7 @@ type DecisionLogger struct {
}
// NewDecisionLogger 创建决策日志记录器
func NewDecisionLogger(logDir string) *DecisionLogger {
func NewDecisionLogger(logDir string) IDecisionLogger {
if logDir == "" {
logDir = "decision_logs"
}

View File

@@ -13,18 +13,17 @@ import (
"time"
)
// Provider AI提供商类型
type Provider string
const (
ProviderDeepSeek Provider = "deepseek"
ProviderQwen Provider = "qwen"
ProviderCustom Provider = "custom"
ProviderCustom = "custom"
)
var (
DefaultTimeout = 120 * time.Second
)
// Client AI API配置
type Client struct {
Provider Provider
Provider string
APIKey string
BaseURL string
Model string
@@ -33,7 +32,7 @@ type Client struct {
MaxTokens int // AI响应的最大token数
}
func New() *Client {
func New() AIClient {
// 从环境变量读取 MaxTokens默认 2000
maxTokens := 2000
if envMaxTokens := os.Getenv("AI_MAX_TOKENS"); envMaxTokens != "" {
@@ -48,65 +47,15 @@ func New() *Client {
// 默认配置
return &Client{
Provider: ProviderDeepSeek,
BaseURL: "https://api.deepseek.com/v1",
Model: "deepseek-chat",
Timeout: 120 * time.Second, // 增加到120秒因为AI需要分析大量数据
BaseURL: DefaultDeepSeekBaseURL,
Model: DefaultDeepSeekModel,
Timeout: DefaultTimeout,
MaxTokens: maxTokens,
}
}
// SetDeepSeekAPIKey 设置DeepSeek API密钥
// customURL 为空时使用默认URLcustomModel 为空时使用默认模型
func (client *Client) SetDeepSeekAPIKey(apiKey string, customURL string, customModel string) {
client.Provider = ProviderDeepSeek
client.APIKey = apiKey
if customURL != "" {
client.BaseURL = customURL
log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL)
} else {
client.BaseURL = "https://api.deepseek.com/v1"
log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", client.BaseURL)
}
if customModel != "" {
client.Model = customModel
log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel)
} else {
client.Model = "deepseek-chat"
log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", client.Model)
}
// 打印 API Key 的前后各4位用于验证
if len(apiKey) > 8 {
log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
}
// SetQwenAPIKey 设置阿里云Qwen API密钥
// customURL 为空时使用默认URLcustomModel 为空时使用默认模型
func (client *Client) SetQwenAPIKey(apiKey string, customURL string, customModel string) {
client.Provider = ProviderQwen
client.APIKey = apiKey
if customURL != "" {
client.BaseURL = customURL
log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL)
} else {
client.BaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", client.BaseURL)
}
if customModel != "" {
client.Model = customModel
log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel)
} else {
client.Model = "qwen3-max"
log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", client.Model)
}
// 打印 API Key 的前后各4位用于验证
if len(apiKey) > 8 {
log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
}
// SetCustomAPI 设置自定义OpenAI兼容API
func (client *Client) SetCustomAPI(apiURL, apiKey, modelName string) {
func (client *Client) SetAPIKey(apiKey, apiURL, customModel string) {
client.Provider = ProviderCustom
client.APIKey = apiKey
@@ -119,22 +68,14 @@ func (client *Client) SetCustomAPI(apiURL, apiKey, modelName string) {
client.UseFullURL = false
}
client.Model = modelName
client.Model = customModel
client.Timeout = 120 * time.Second
}
// SetClient 设置完整的AI配置高级用户
func (client *Client) SetClient(Client Client) {
if Client.Timeout == 0 {
Client.Timeout = 30 * time.Second
}
client = &Client
}
// CallWithMessages 使用 system + user prompt 调用AI API推荐
func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string, error) {
if client.APIKey == "" {
return "", fmt.Errorf("AI API密钥未设置请先调用 SetDeepSeekAPIKey() 或 SetQwenAPIKey()")
return "", fmt.Errorf("AI API密钥未设置请先调用 SetAPIKey")
}
// 重试配置
@@ -171,6 +112,10 @@ func (client *Client) CallWithMessages(systemPrompt, userPrompt string) (string,
return "", fmt.Errorf("重试%d次后仍然失败: %w", maxRetries, lastErr)
}
func (client *Client) setAuthHeader(reqHeader http.Header) {
reqHeader.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
}
// callOnce 单次调用AI API内部使用
func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error) {
// 打印当前 AI 配置
@@ -234,17 +179,7 @@ func (client *Client) callOnce(systemPrompt, userPrompt string) (string, error)
req.Header.Set("Content-Type", "application/json")
// 根据不同的Provider设置认证方式
switch client.Provider {
case ProviderDeepSeek:
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
case ProviderQwen:
// 阿里云Qwen使用API-Key认证
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
// 注意:如果使用的不是兼容模式,可能需要不同的认证方式
default:
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", client.APIKey))
}
client.setAuthHeader(req.Header)
// 发送请求
httpClient := &http.Client{Timeout: client.Timeout}

53
mcp/deepseek_client.go Normal file
View File

@@ -0,0 +1,53 @@
package mcp
import (
"log"
"net/http"
)
const (
ProviderDeepSeek = "deepseek"
DefaultDeepSeekBaseURL = "https://api.deepseek.com/v1"
DefaultDeepSeekModel = "deepseek-chat"
)
type DeepSeekClient struct {
*Client
}
func NewDeepSeekClient() AIClient {
client := New().(*Client)
client.Provider = ProviderDeepSeek
client.Model = DefaultDeepSeekModel
client.BaseURL = DefaultDeepSeekBaseURL
return &DeepSeekClient{
Client: client,
}
}
func (dsClient *DeepSeekClient) SetAPIKey(apiKey string, customURL string, customModel string) {
if dsClient.Client == nil {
dsClient.Client = New().(*Client)
}
dsClient.Client.APIKey = apiKey
if len(apiKey) > 8 {
log.Printf("🔧 [MCP] DeepSeek API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
dsClient.Client.BaseURL = customURL
log.Printf("🔧 [MCP] DeepSeek 使用自定义 BaseURL: %s", customURL)
} else {
log.Printf("🔧 [MCP] DeepSeek 使用默认 BaseURL: %s", dsClient.Client.BaseURL)
}
if customModel != "" {
dsClient.Client.Model = customModel
log.Printf("🔧 [MCP] DeepSeek 使用自定义 Model: %s", customModel)
} else {
log.Printf("🔧 [MCP] DeepSeek 使用默认 Model: %s", dsClient.Client.Model)
}
}
func (dsClient *DeepSeekClient) setAuthHeader(reqHeaders http.Header) {
dsClient.Client.setAuthHeader(reqHeaders)
}

12
mcp/interface.go Normal file
View File

@@ -0,0 +1,12 @@
package mcp
import "net/http"
// AIClient AI客户端接口
type AIClient interface {
SetAPIKey(apiKey string, customURL string, customModel string)
// CallWithMessages 使用 system + user prompt 调用AI API
CallWithMessages(systemPrompt, userPrompt string) (string, error)
setAuthHeader(reqHeaders http.Header)
}

53
mcp/qwen_client.go Normal file
View File

@@ -0,0 +1,53 @@
package mcp
import (
"log"
"net/http"
)
const (
ProviderQwen = "qwen"
DefaultQwenBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
DefaultQwenModel = "qwen3-max"
)
type QwenClient struct {
*Client
}
func NewQwenClient() AIClient {
client := New().(*Client)
client.Provider = ProviderQwen
client.Model = DefaultQwenModel
client.BaseURL = DefaultQwenBaseURL
return &QwenClient{
Client: client,
}
}
func (qwenClient *QwenClient) SetAPIKey(apiKey string, customURL string, customModel string) {
if qwenClient.Client == nil {
qwenClient.Client = New().(*Client)
}
qwenClient.Client.APIKey = apiKey
if len(apiKey) > 8 {
log.Printf("🔧 [MCP] Qwen API Key: %s...%s", apiKey[:4], apiKey[len(apiKey)-4:])
}
if customURL != "" {
qwenClient.Client.BaseURL = customURL
log.Printf("🔧 [MCP] Qwen 使用自定义 BaseURL: %s", customURL)
} else {
log.Printf("🔧 [MCP] Qwen 使用默认 BaseURL: %s", qwenClient.Client.BaseURL)
}
if customModel != "" {
qwenClient.Client.Model = customModel
log.Printf("🔧 [MCP] Qwen 使用自定义 Model: %s", customModel)
} else {
log.Printf("🔧 [MCP] Qwen 使用默认 Model: %s", qwenClient.Client.Model)
}
}
func (qwenClient *QwenClient) setAuthHeader(reqHeaders http.Header) {
qwenClient.Client.setAuthHeader(reqHeaders)
}

View File

@@ -85,8 +85,8 @@ type AutoTrader struct {
exchange string // 交易平台名称
config AutoTraderConfig
trader Trader // 使用Trader接口支持多平台
mcpClient *mcp.Client
decisionLogger *logger.DecisionLogger // 决策日志记录器
mcpClient mcp.AIClient
decisionLogger logger.IDecisionLogger // 决策日志记录器
initialBalance float64
dailyPnL float64
customPrompt string // 自定义交易策略prompt
@@ -131,11 +131,12 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string)
// 初始化AI
if config.AIModel == "custom" {
// 使用自定义API
mcpClient.SetCustomAPI(config.CustomAPIURL, config.CustomAPIKey, config.CustomModelName)
mcpClient.SetAPIKey(config.CustomAPIKey, config.CustomAPIURL, config.CustomModelName)
log.Printf("🤖 [%s] 使用自定义AI API: %s (模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName)
} else if config.UseQwen || config.AIModel == "qwen" {
// 使用Qwen (支持自定义URL和Model)
mcpClient.SetQwenAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName)
mcpClient = mcp.NewQwenClient()
mcpClient.SetAPIKey(config.QwenKey, config.CustomAPIURL, config.CustomModelName)
if config.CustomAPIURL != "" || config.CustomModelName != "" {
log.Printf("🤖 [%s] 使用阿里云Qwen AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName)
} else {
@@ -143,7 +144,8 @@ func NewAutoTrader(config AutoTraderConfig, database interface{}, userID string)
}
} else {
// 默认使用DeepSeek (支持自定义URL和Model)
mcpClient.SetDeepSeekAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName)
mcpClient = mcp.NewDeepSeekClient()
mcpClient.SetAPIKey(config.DeepSeekKey, config.CustomAPIURL, config.CustomModelName)
if config.CustomAPIURL != "" || config.CustomModelName != "" {
log.Printf("🤖 [%s] 使用DeepSeek AI (自定义URL: %s, 模型: %s)", config.Name, config.CustomAPIURL, config.CustomModelName)
} else {
@@ -1205,7 +1207,7 @@ func (at *AutoTrader) GetSystemPromptTemplate() string {
}
// GetDecisionLogger 获取决策日志记录器
func (at *AutoTrader) GetDecisionLogger() *logger.DecisionLogger {
func (at *AutoTrader) GetDecisionLogger() logger.IDecisionLogger {
return at.decisionLogger
}

View File

@@ -31,7 +31,7 @@ type AutoTraderTestSuite struct {
// Mock 依赖
mockTrader *MockTrader
mockDB *MockDatabase
mockLogger *logger.DecisionLogger
mockLogger logger.IDecisionLogger
// gomonkey patches
patches *gomonkey.Patches