mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2025-12-06 13:54:41 +08:00
fix(decision): clarify field names for update_stop_loss and update_take_profit actions (#993)
* fix(decision): clarify field names for update_stop_loss and update_take_profit actions 修复 AI 决策中的字段名混淆问题: **问题**: AI 在使用 update_stop_loss 时错误地使用了 `stop_loss` 字段, 导致解析失败(backend 期望 `new_stop_loss` 字段) **根因**: 系统 prompt 的字段说明不够明确,AI 无法知道 update_stop_loss 应该使用 new_stop_loss 字段而非 stop_loss **修复**: 1. 在字段说明中明确标注: - update_stop_loss 时必填: new_stop_loss (不是 stop_loss) - update_take_profit 时必填: new_take_profit (不是 take_profit) 2. 在 JSON 示例中增加 update_stop_loss 的具体用法示例 **验证**: decision_logs 中的错误 "新止损价格必须大于0: 0.00" 应该消失 * test(decision): add validation tests for update actions 添加针对 update_stop_loss、update_take_profit 和 partial_close 动作的字段验证单元测试: **测试覆盖**: 1. TestUpdateStopLossValidation - 验证 new_stop_loss 字段 - 正确使用 new_stop_loss 字段(应通过) - new_stop_loss 为 0(应报错) - new_stop_loss 为负数(应报错) 2. TestUpdateTakeProfitValidation - 验证 new_take_profit 字段 - 正确使用 new_take_profit 字段(应通过) - new_take_profit 为 0(应报错) - new_take_profit 为负数(应报错) 3. TestPartialCloseValidation - 验证 close_percentage 字段 - 正确使用 close_percentage 字段(应通过) - close_percentage 为 0(应报错) - close_percentage 超过 100(应报错) **测试结果**:所有测试用例通过 ✓ --------- Co-authored-by: Shui <88711385+hzb1115@users.noreply.github.com>
This commit is contained in:
committed by
tangmengqiu
parent
a74aed5a20
commit
11e4022867
@@ -346,13 +346,17 @@ func buildSystemPrompt(accountEquity float64, btcEthLeverage, altcoinLeverage in
|
||||
sb.WriteString("<decision>\n")
|
||||
sb.WriteString("```json\n[\n")
|
||||
sb.WriteString(fmt.Sprintf(" {\"symbol\": \"BTCUSDT\", \"action\": \"open_short\", \"leverage\": %d, \"position_size_usd\": %.0f, \"stop_loss\": 97000, \"take_profit\": 91000, \"confidence\": 85, \"risk_usd\": 300, \"reasoning\": \"下跌趋势+MACD死叉\"},\n", btcEthLeverage, accountEquity*5))
|
||||
sb.WriteString(" {\"symbol\": \"SOLUSDT\", \"action\": \"update_stop_loss\", \"new_stop_loss\": 155, \"reasoning\": \"移动止损至保本位\"},\n")
|
||||
sb.WriteString(" {\"symbol\": \"ETHUSDT\", \"action\": \"close_long\", \"reasoning\": \"止盈离场\"}\n")
|
||||
sb.WriteString("]\n```\n")
|
||||
sb.WriteString("</decision>\n\n")
|
||||
sb.WriteString("## 字段说明\n\n")
|
||||
sb.WriteString("- `action`: open_long | open_short | close_long | close_short | update_stop_loss | update_take_profit | partial_close | hold | wait\n")
|
||||
sb.WriteString("- `confidence`: 0-100(开仓建议≥75)\n")
|
||||
sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd, reasoning\n\n")
|
||||
sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd, reasoning\n")
|
||||
sb.WriteString("- update_stop_loss 时必填: new_stop_loss (注意是 new_stop_loss,不是 stop_loss)\n")
|
||||
sb.WriteString("- update_take_profit 时必填: new_take_profit (注意是 new_take_profit,不是 take_profit)\n")
|
||||
sb.WriteString("- partial_close 时必填: close_percentage (0-100)\n\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -98,3 +98,198 @@ func TestLeverageFallback(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateStopLossValidation 测试 update_stop_loss 动作的字段验证
|
||||
func TestUpdateStopLossValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decision Decision
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "正确使用new_stop_loss字段",
|
||||
decision: Decision{
|
||||
Symbol: "SOLUSDT",
|
||||
Action: "update_stop_loss",
|
||||
NewStopLoss: 155.5,
|
||||
Reasoning: "移动止损至保本位",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "new_stop_loss为0应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "SOLUSDT",
|
||||
Action: "update_stop_loss",
|
||||
NewStopLoss: 0,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "新止损价格必须大于0",
|
||||
},
|
||||
{
|
||||
name: "new_stop_loss为负数应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "SOLUSDT",
|
||||
Action: "update_stop_loss",
|
||||
NewStopLoss: -100,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "新止损价格必须大于0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateDecision(&tt.decision, 1000.0, 10, 5)
|
||||
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantError && err != nil {
|
||||
if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTakeProfitValidation 测试 update_take_profit 动作的字段验证
|
||||
func TestUpdateTakeProfitValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decision Decision
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "正确使用new_take_profit字段",
|
||||
decision: Decision{
|
||||
Symbol: "BTCUSDT",
|
||||
Action: "update_take_profit",
|
||||
NewTakeProfit: 98000,
|
||||
Reasoning: "调整止盈至关键阻力位",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "new_take_profit为0应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "BTCUSDT",
|
||||
Action: "update_take_profit",
|
||||
NewTakeProfit: 0,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "新止盈价格必须大于0",
|
||||
},
|
||||
{
|
||||
name: "new_take_profit为负数应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "BTCUSDT",
|
||||
Action: "update_take_profit",
|
||||
NewTakeProfit: -1000,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "新止盈价格必须大于0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateDecision(&tt.decision, 1000.0, 10, 5)
|
||||
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantError && err != nil {
|
||||
if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPartialCloseValidation 测试 partial_close 动作的字段验证
|
||||
func TestPartialCloseValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decision Decision
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "正确使用close_percentage字段",
|
||||
decision: Decision{
|
||||
Symbol: "ETHUSDT",
|
||||
Action: "partial_close",
|
||||
ClosePercentage: 50.0,
|
||||
Reasoning: "锁定一半利润",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "close_percentage为0应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "ETHUSDT",
|
||||
Action: "partial_close",
|
||||
ClosePercentage: 0,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "平仓百分比必须在0-100之间",
|
||||
},
|
||||
{
|
||||
name: "close_percentage超过100应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "ETHUSDT",
|
||||
Action: "partial_close",
|
||||
ClosePercentage: 150,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "平仓百分比必须在0-100之间",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateDecision(&tt.decision, 1000.0, 10, 5)
|
||||
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantError && err != nil {
|
||||
if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// contains 检查字符串是否包含子串(辅助函数)
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
|
||||
(len(s) > 0 && len(substr) > 0 && stringContains(s, substr)))
|
||||
}
|
||||
|
||||
func stringContains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user