mirror of
https://github.com/NoFxAiOS/nofx.git
synced 2025-12-06 13:54:41 +08:00
Refactor/trading actions (#1169)
* refactor: 简化交易动作,移除 update_stop_loss/update_take_profit/partial_close - 移除 Decision 结构体中的 NewStopLoss, NewTakeProfit, ClosePercentage 字段 - 删除 executeUpdateStopLossWithRecord, executeUpdateTakeProfitWithRecord, executePartialCloseWithRecord 函数 - 简化 logger 中的 partial_close 聚合逻辑 - 更新 AI prompt 和验证逻辑,只保留 6 个核心动作 - 清理相关测试代码 保留的交易动作: open_long, open_short, close_long, close_short, hold, wait * refactor: 移除 AI学习与反思 模块 - 删除前端 AILearning.tsx 组件和相关引用 - 删除后端 /performance API 接口 - 删除 logger 中 AnalyzePerformance、calculateSharpeRatio 等函数 - 删除 PerformanceAnalysis、TradeOutcome、SymbolPerformance 等结构体 - 删除 Context 中的 Performance 字段 - 移除 AI prompt 中夏普比率自我进化相关内容 - 清理 i18n 翻译文件中的相关条目 该模块基于磁盘存储计算,经常出错,做减法移除 * refactor: 将数据库操作统一迁移到 store 包 - 新增 store/ 包,统一管理所有数据库操作 - store.go: 主 Store 结构,懒加载各子模块 - user.go, ai_model.go, exchange.go, trader.go 等子模块 - 支持加密/解密函数注入 (SetCryptoFuncs) - 更新 main.go 使用 store.New() 替代 config.NewDatabase() - 更新 api/server.go 使用 *store.Store 替代 *config.Database - 更新 manager/trader_manager.go: - 新增 LoadTradersFromStore, LoadUserTradersFromStore 方法 - 删除旧版 LoadUserTraders, LoadTraderByID, loadSingleTrader 等方法 - 移除 nofx/config 依赖 - 删除 config/database.go 和 config/database_test.go - 更新 api/server_test.go 使用 store.Trader 类型 - 清理 logger/ 包中未使用的 telegram 相关代码 * refactor: unify encryption key management via .env - Remove redundant EncryptionManager and SecureStorage - Simplify CryptoService to load keys from environment variables only - RSA_PRIVATE_KEY: RSA private key for client-server encryption - DATA_ENCRYPTION_KEY: AES-256 key for database encryption - JWT_SECRET: JWT signing key for authentication - Update start.sh to auto-generate missing keys on first run - Remove secrets/ directory and file-based key storage - Delete obsolete encryption setup scripts - Update .env.example with all required keys * refactor: unify logger usage across mcp package - Add MCPLogger adapter in logger package to implement mcp.Logger interface - Update mcp/config.go to use global logger by default - Remove redundant defaultLogger from mcp/logger.go - Keep noopLogger for testing purposes * chore: remove leftover test RSA key file * chore: remove unused bootstrap package * refactor: unify logging to use logger package instead of fmt/log - Replace all fmt.Print/log.Print calls with logger package - Add auto-initialization in logger package init() for test compatibility - Update main.go to initialize logger at startup - Migrate all packages: api, backtest, config, decision, manager, market, store, trader * refactor: rename database file from config.db to data.db - Update main.go, start.sh, docker-compose.yml - Update migration script and documentation - Update .gitignore and translations * fix: add RSA_PRIVATE_KEY to docker-compose environment * fix: add registration_enabled to /api/config response * fix: Fix navigation between login and register pages Use window.location.href instead of react-router's navigate() to fix the issue where URL changes but the page doesn't reload due to App.tsx using custom route state management. * fix: Switch SQLite from WAL to DELETE mode for Docker compatibility WAL mode causes data sync issues with Docker bind mounts on macOS due to incompatible file locking mechanisms between the container and host. DELETE mode (traditional journaling) ensures data is written directly to the main database file. * refactor: Remove default user from database initialization The default user was a legacy placeholder that is no longer needed now that proper user registration is in place. * feat: Add order tracking system with centralized status sync - Add trader_orders table for tracking all order lifecycle - Implement GetOrderStatus interface for all exchanges (Binance, Bybit, Hyperliquid, Aster, Lighter) - Create OrderSyncManager for centralized order status polling - Add trading statistics (Sharpe ratio, win rate, profit factor) to AI context - Include recent completed orders in AI decision input - Remove per-order goroutine polling in favor of global sync manager * feat: Add TradingView K-line chart to dashboard - Create TradingViewChart component with exchange/symbol selectors - Support Binance, Bybit, OKX, Coinbase, Kraken, KuCoin exchanges - Add popular symbols quick selection - Support multiple timeframes (1m to 1W) - Add fullscreen mode - Integrate with Dashboard page below equity chart - Add i18n translations for zh/en * refactor: Replace separate charts with tabbed ChartTabs component - Create ChartTabs component with tab switching between equity curve and K-line - Add embedded mode support for EquityChart and TradingViewChart - User can now switch between account equity and market chart in same area * fix: Use ChartTabs in App.tsx and fix embedded mode in EquityChart - Replace EquityChart with ChartTabs in App.tsx (the actual dashboard renderer) - Fix EquityChart embedded mode for error and empty data states - Rename interval state to timeInterval to avoid shadowing window.setInterval - Add debug logging to ChartTabs component * feat: Add position tracking system for accurate trade history - Add trader_positions table to track complete open/close trades - Add PositionSyncManager to detect manual closes via polling - Record position on open, update on close with PnL calculation - Use positions table for trading stats and recent trades (replacing orders table) - Fix TradingView chart symbol format (add .P suffix for futures) - Fix DecisionCard wait/hold action color (gray instead of red) - Auto-append USDT suffix for custom symbol input * update ---------
This commit is contained in:
42
.env.example
42
.env.example
@@ -1,14 +1,46 @@
|
||||
# NOFX Environment Variables Template
|
||||
# Copy this file to .env and modify the values as needed
|
||||
|
||||
# Ports Configuration
|
||||
# Backend API server port (internal: 8080, external: configurable)
|
||||
# ===========================================
|
||||
# Server Configuration
|
||||
# ===========================================
|
||||
|
||||
# Backend API server port
|
||||
NOFX_BACKEND_PORT=8080
|
||||
|
||||
# Frontend web interface port (Nginx listens on port 80 internally)
|
||||
# Frontend web interface port
|
||||
NOFX_FRONTEND_PORT=3000
|
||||
|
||||
# Timezone Setting
|
||||
# System timezone for container time synchronization
|
||||
# Timezone
|
||||
NOFX_TIMEZONE=Asia/Shanghai
|
||||
|
||||
# ===========================================
|
||||
# Authentication (Required)
|
||||
# ===========================================
|
||||
|
||||
# JWT signing secret (any random string, at least 32 characters)
|
||||
# Generate with: openssl rand -base64 32
|
||||
JWT_SECRET=your-jwt-secret-change-this-in-production
|
||||
|
||||
# ===========================================
|
||||
# Encryption Keys (Required)
|
||||
# ===========================================
|
||||
|
||||
# AES-256 data encryption key (Base64 encoded, 32 bytes)
|
||||
# Used for encrypting sensitive data in database (API keys, secrets)
|
||||
# Generate with: openssl rand -base64 32
|
||||
DATA_ENCRYPTION_KEY=your-base64-encoded-32-byte-key
|
||||
|
||||
# RSA private key for client-server encryption (PEM format)
|
||||
# Used for end-to-end encryption of sensitive data from browser
|
||||
# Generate with: openssl genrsa 2048
|
||||
# Note: Replace newlines with \n for single-line format
|
||||
RSA_PRIVATE_KEY=-----BEGIN RSA PRIVATE KEY-----\nYOUR_KEY_HERE\n-----END RSA PRIVATE KEY-----
|
||||
|
||||
# ===========================================
|
||||
# Optional: External Services
|
||||
# ===========================================
|
||||
|
||||
# Telegram notifications (optional)
|
||||
# TELEGRAM_BOT_TOKEN=your-bot-token
|
||||
# TELEGRAM_CHAT_ID=your-chat-id
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -30,8 +30,7 @@ Thumbs.db
|
||||
# 环境变量
|
||||
.env
|
||||
config.json
|
||||
config.db*
|
||||
nofx.db
|
||||
data.db*
|
||||
configbak.json
|
||||
|
||||
# 决策日志
|
||||
|
||||
@@ -116,7 +116,7 @@ If needed, rollback is simple:
|
||||
|
||||
```bash
|
||||
# Restore backup
|
||||
cp config.db.backup config.db
|
||||
cp data.db.backup data.db
|
||||
|
||||
# Comment out 3 lines in main.go
|
||||
# (encryption initialization)
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"time"
|
||||
|
||||
"nofx/backtest"
|
||||
"nofx/config"
|
||||
"nofx/decision"
|
||||
"nofx/store"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -486,9 +486,6 @@ func (s *Server) ensureBacktestRunOwnership(runID, userID string) (*backtest.Run
|
||||
if owner == "" {
|
||||
return meta, nil
|
||||
}
|
||||
if owner == "default" && userID == "admin" {
|
||||
return meta, nil
|
||||
}
|
||||
if owner != userID {
|
||||
return nil, errBacktestForbidden
|
||||
}
|
||||
@@ -514,7 +511,7 @@ func (s *Server) resolveBacktestAIConfig(cfg *backtest.BacktestConfig, userID st
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config is nil")
|
||||
}
|
||||
if s.database == nil {
|
||||
if s.store == nil {
|
||||
return fmt.Errorf("系统数据库未就绪,无法加载AI模型配置")
|
||||
}
|
||||
|
||||
@@ -527,7 +524,7 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config is nil")
|
||||
}
|
||||
if s.database == nil {
|
||||
if s.store == nil {
|
||||
return fmt.Errorf("系统数据库未就绪,无法加载AI模型配置")
|
||||
}
|
||||
|
||||
@@ -535,17 +532,17 @@ func (s *Server) hydrateBacktestAIConfig(cfg *backtest.BacktestConfig) error {
|
||||
modelID := strings.TrimSpace(cfg.AIModelID)
|
||||
|
||||
var (
|
||||
model *config.AIModelConfig
|
||||
model *store.AIModel
|
||||
err error
|
||||
)
|
||||
|
||||
if modelID != "" {
|
||||
model, err = s.database.GetAIModel(cfg.UserID, modelID)
|
||||
model, err = s.store.AIModel().Get(cfg.UserID, modelID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加载AI模型失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
model, err = s.database.GetDefaultAIModel(cfg.UserID)
|
||||
model, err = s.store.AIModel().GetDefault(cfg.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("未找到可用的AI模型: %w", err)
|
||||
}
|
||||
|
||||
403
api/server.go
403
api/server.go
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"nofx/config"
|
||||
"nofx/store"
|
||||
)
|
||||
|
||||
// TestUpdateTraderRequest_SystemPromptTemplate 测试更新交易员时 SystemPromptTemplate 字段是否存在
|
||||
@@ -100,12 +100,12 @@ func TestUpdateTraderRequest_SystemPromptTemplate(t *testing.T) {
|
||||
func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
traderConfig *config.TraderRecord
|
||||
traderConfig *store.Trader
|
||||
expectedTemplate string
|
||||
}{
|
||||
{
|
||||
name: "获取配置应该返回 system_prompt_template=nof1",
|
||||
traderConfig: &config.TraderRecord{
|
||||
traderConfig: &store.Trader{
|
||||
ID: "trader-123",
|
||||
UserID: "user-1",
|
||||
Name: "Test Trader",
|
||||
@@ -126,7 +126,7 @@ func TestGetTraderConfigResponse_SystemPromptTemplate(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "获取配置应该返回 system_prompt_template=default",
|
||||
traderConfig: &config.TraderRecord{
|
||||
traderConfig: &store.Trader{
|
||||
ID: "trader-456",
|
||||
UserID: "user-1",
|
||||
Name: "Test Trader 2",
|
||||
@@ -229,7 +229,7 @@ func TestUpdateTraderRequest_CompleteFields(t *testing.T) {
|
||||
// TestTraderListResponse_SystemPromptTemplate 测试 handleTraderList API 返回的 trader 对象是否包含 system_prompt_template 字段
|
||||
func TestTraderListResponse_SystemPromptTemplate(t *testing.T) {
|
||||
// 模拟 handleTraderList 中的 trader 对象构造
|
||||
trader := &config.TraderRecord{
|
||||
trader := &store.Trader{
|
||||
ID: "trader-001",
|
||||
UserID: "user-1",
|
||||
Name: "My Trader",
|
||||
|
||||
@@ -4,14 +4,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"nofx/logger"
|
||||
"nofx/mcp"
|
||||
"nofx/store"
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
@@ -377,7 +377,7 @@ func (m *Manager) Status(runID string) *StatusPayload {
|
||||
func (m *Manager) launchWatcher(runID string, runner *Runner) {
|
||||
go func() {
|
||||
if err := runner.Wait(); err != nil {
|
||||
log.Printf("backtest run %s finished with error: %v", runID, err)
|
||||
logger.Infof("backtest run %s finished with error: %v", runID, err)
|
||||
}
|
||||
runner.PersistMetadata()
|
||||
meta := runner.CurrentMetadata()
|
||||
@@ -419,7 +419,7 @@ func (m *Manager) storeMetadata(runID string, meta *RunMetadata) {
|
||||
m.mu.Unlock()
|
||||
_ = SaveRunMetadata(meta)
|
||||
if err := updateRunIndex(meta, nil); err != nil {
|
||||
log.Printf("failed to update run index for %s: %v", runID, err)
|
||||
logger.Infof("failed to update run index for %s: %v", runID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -445,7 +445,7 @@ func (m *Manager) resolveAIConfig(cfg *BacktestConfig) error {
|
||||
return resolver(cfg)
|
||||
}
|
||||
|
||||
func (m *Manager) GetTrace(runID string, cycle int) (*logger.DecisionRecord, error) {
|
||||
func (m *Manager) GetTrace(runID string, cycle int) (*store.DecisionRecord, error) {
|
||||
return LoadDecisionTrace(runID, cycle)
|
||||
}
|
||||
|
||||
@@ -462,18 +462,18 @@ func (m *Manager) RestoreRuns() error {
|
||||
for _, runID := range runIDs {
|
||||
meta, err := LoadRunMetadata(runID)
|
||||
if err != nil {
|
||||
log.Printf("skip run %s: %v", runID, err)
|
||||
logger.Infof("skip run %s: %v", runID, err)
|
||||
continue
|
||||
}
|
||||
if meta.State == RunStateRunning {
|
||||
lock, err := loadRunLock(runID)
|
||||
if err != nil || lockIsStale(lock) {
|
||||
if err := deleteRunLock(runID); err != nil {
|
||||
log.Printf("failed to cleanup lock for %s: %v", runID, err)
|
||||
logger.Infof("failed to cleanup lock for %s: %v", runID, err)
|
||||
}
|
||||
meta.State = RunStatePaused
|
||||
if err := SaveRunMetadata(meta); err != nil {
|
||||
log.Printf("failed to mark %s paused: %v", runID, err)
|
||||
logger.Infof("failed to mark %s paused: %v", runID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -481,7 +481,7 @@ func (m *Manager) RestoreRuns() error {
|
||||
m.metadata[runID] = meta
|
||||
m.mu.Unlock()
|
||||
if err := updateRunIndex(meta, nil); err != nil {
|
||||
log.Printf("failed to sync index for %s: %v", runID, err)
|
||||
logger.Infof("failed to sync index for %s: %v", runID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package backtest
|
||||
|
||||
import (
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"os"
|
||||
"sort"
|
||||
"time"
|
||||
@@ -56,13 +56,13 @@ func enforceRetention(maxRuns int) {
|
||||
for i := 0; i < toRemove; i++ {
|
||||
runID := candidates[i].entry.RunID
|
||||
if err := os.RemoveAll(runDir(runID)); err != nil {
|
||||
log.Printf("failed to prune run %s: %v", runID, err)
|
||||
logger.Infof("failed to prune run %s: %v", runID, err)
|
||||
continue
|
||||
}
|
||||
delete(idx.Runs, runID)
|
||||
}
|
||||
if err := saveRunIndex(idx); err != nil {
|
||||
log.Printf("failed to save index after pruning: %v", err)
|
||||
logger.Infof("failed to save index after pruning: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,11 +91,11 @@ func enforceRetentionDB(maxRuns int) {
|
||||
continue
|
||||
}
|
||||
if err := deleteRunDB(runID); err != nil {
|
||||
log.Printf("failed to remove run %s: %v", runID, err)
|
||||
logger.Infof("failed to remove run %s: %v", runID, err)
|
||||
continue
|
||||
}
|
||||
if err := os.RemoveAll(runDir(runID)); err != nil {
|
||||
log.Printf("failed to remove run dir %s: %v", runID, err)
|
||||
logger.Infof("failed to remove run dir %s: %v", runID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
@@ -14,9 +14,9 @@ import (
|
||||
"time"
|
||||
|
||||
"nofx/decision"
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"nofx/mcp"
|
||||
"nofx/store"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -35,7 +35,7 @@ type Runner struct {
|
||||
feed *DataFeed
|
||||
account *BacktestAccount
|
||||
|
||||
decisionLogger logger.IDecisionLogger
|
||||
decisionLogDir string
|
||||
mcpClient mcp.AIClient
|
||||
|
||||
statusMu sync.RWMutex
|
||||
@@ -83,7 +83,7 @@ func NewRunner(cfg BacktestConfig, mcpClient mcp.AIClient) (*Runner, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dLog := logger.NewDecisionLogger(decisionLogDir(cfg.RunID))
|
||||
dLogDir := decisionLogDir(cfg.RunID)
|
||||
account := NewBacktestAccount(cfg.InitialBalance, cfg.FeeBps, cfg.SlippageBps)
|
||||
|
||||
createdAt := time.Now().UTC()
|
||||
@@ -119,7 +119,7 @@ func NewRunner(cfg BacktestConfig, mcpClient mcp.AIClient) (*Runner, error) {
|
||||
cfg: cfg,
|
||||
feed: feed,
|
||||
account: account,
|
||||
decisionLogger: dLog,
|
||||
decisionLogDir: dLogDir,
|
||||
mcpClient: client,
|
||||
status: RunStateCreated,
|
||||
state: state,
|
||||
@@ -160,7 +160,7 @@ func (r *Runner) lockHeartbeatLoop() {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := updateRunLockHeartbeat(r.lockInfo); err != nil {
|
||||
log.Printf("failed to update lock heartbeat for %s: %v", r.cfg.RunID, err)
|
||||
logger.Infof("failed to update lock heartbeat for %s: %v", r.cfg.RunID, err)
|
||||
}
|
||||
case <-r.lockStop:
|
||||
return
|
||||
@@ -174,7 +174,7 @@ func (r *Runner) releaseLock() {
|
||||
r.lockStop = nil
|
||||
}
|
||||
if err := deleteRunLock(r.cfg.RunID); err != nil {
|
||||
log.Printf("failed to release lock for %s: %v", r.cfg.RunID, err)
|
||||
logger.Infof("failed to release lock for %s: %v", r.cfg.RunID, err)
|
||||
}
|
||||
r.lockInfo = nil
|
||||
}
|
||||
@@ -279,8 +279,8 @@ func (r *Runner) stepOnce() error {
|
||||
shouldDecide := r.shouldTriggerDecision(state.BarIndex)
|
||||
|
||||
var (
|
||||
record *logger.DecisionRecord
|
||||
decisionActions []logger.DecisionAction
|
||||
record *store.DecisionRecord
|
||||
decisionActions []store.DecisionAction
|
||||
tradeEvents = make([]TradeEvent, 0)
|
||||
execLog []string
|
||||
hadError bool
|
||||
@@ -317,7 +317,7 @@ func (r *Runner) stepOnce() error {
|
||||
return decisionErr
|
||||
}
|
||||
} else {
|
||||
log.Printf("failed to compute ai cache key: %v", err)
|
||||
logger.Infof("failed to compute ai cache key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,7 +334,7 @@ func (r *Runner) stepOnce() error {
|
||||
fullDecision = fd
|
||||
if r.cfg.CacheAI && r.aiCache != nil && cacheKey != "" {
|
||||
if err := r.aiCache.Put(cacheKey, r.cfg.PromptVariant, ts, fullDecision); err != nil {
|
||||
log.Printf("failed to persist ai cache for %s: %v", r.cfg.RunID, err)
|
||||
logger.Infof("failed to persist ai cache for %s: %v", r.cfg.RunID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -346,7 +346,7 @@ func (r *Runner) stepOnce() error {
|
||||
sorted := sortDecisionsByPriority(fullDecision.Decisions)
|
||||
|
||||
prevLogs := execLog
|
||||
decisionActions = make([]logger.DecisionAction, 0, len(sorted))
|
||||
decisionActions = make([]store.DecisionAction, 0, len(sorted))
|
||||
execLog = make([]string, 0, len(sorted)+len(prevLogs))
|
||||
if len(prevLogs) > 0 {
|
||||
execLog = append(execLog, prevLogs...)
|
||||
@@ -464,7 +464,7 @@ func (r *Runner) stepOnce() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Data, multiTF map[string]map[string]*market.Data, priceMap map[string]float64, callCount int) (*decision.Context, *logger.DecisionRecord, error) {
|
||||
func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Data, multiTF map[string]map[string]*market.Data, priceMap map[string]float64, callCount int) (*decision.Context, *store.DecisionRecord, error) {
|
||||
equity, unrealized, _ := r.account.TotalEquity(priceMap)
|
||||
available := r.account.Cash()
|
||||
marginUsed := r.totalMarginUsed()
|
||||
@@ -505,8 +505,8 @@ func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Da
|
||||
AltcoinLeverage: r.cfg.Leverage.AltcoinLeverage,
|
||||
}
|
||||
|
||||
record := &logger.DecisionRecord{
|
||||
AccountState: logger.AccountSnapshot{
|
||||
record := &store.DecisionRecord{
|
||||
AccountState: store.AccountSnapshot{
|
||||
TotalBalance: accountInfo.TotalEquity,
|
||||
AvailableBalance: accountInfo.AvailableBalance,
|
||||
TotalUnrealizedProfit: unrealized,
|
||||
@@ -524,7 +524,7 @@ func (r *Runner) buildDecisionContext(ts int64, marketData map[string]*market.Da
|
||||
return ctx, record, nil
|
||||
}
|
||||
|
||||
func (r *Runner) fillDecisionRecord(record *logger.DecisionRecord, full *decision.FullDecision) {
|
||||
func (r *Runner) fillDecisionRecord(record *store.DecisionRecord, full *decision.FullDecision) {
|
||||
record.InputPrompt = full.UserPrompt
|
||||
record.CoTTrace = full.CoTTrace
|
||||
if len(full.Decisions) > 0 {
|
||||
@@ -554,10 +554,10 @@ func (r *Runner) invokeAIWithRetry(ctx *decision.Context) (*decision.FullDecisio
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (r *Runner) executeDecision(dec decision.Decision, priceMap map[string]float64, ts int64, cycle int) (logger.DecisionAction, []TradeEvent, string, error) {
|
||||
func (r *Runner) executeDecision(dec decision.Decision, priceMap map[string]float64, ts int64, cycle int) (store.DecisionAction, []TradeEvent, string, error) {
|
||||
symbol := dec.Symbol
|
||||
usedLeverage := r.resolveLeverage(dec.Leverage, symbol)
|
||||
actionRecord := logger.DecisionAction{
|
||||
actionRecord := store.DecisionAction{
|
||||
Action: dec.Action,
|
||||
Symbol: symbol,
|
||||
Leverage: usedLeverage,
|
||||
@@ -748,12 +748,12 @@ func (r *Runner) remainingPosition(symbol, side string) float64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *Runner) snapshotPositions(priceMap map[string]float64) []logger.PositionSnapshot {
|
||||
func (r *Runner) snapshotPositions(priceMap map[string]float64) []store.PositionSnapshot {
|
||||
positions := r.account.Positions()
|
||||
list := make([]logger.PositionSnapshot, 0, len(positions))
|
||||
list := make([]store.PositionSnapshot, 0, len(positions))
|
||||
for _, pos := range positions {
|
||||
price := priceMap[pos.Symbol]
|
||||
list = append(list, logger.PositionSnapshot{
|
||||
list = append(list, store.PositionSnapshot{
|
||||
Symbol: pos.Symbol,
|
||||
Side: pos.Side,
|
||||
PositionAmt: pos.Quantity,
|
||||
@@ -1124,21 +1124,18 @@ func (r *Runner) persistMetadata() {
|
||||
meta := r.buildMetadata(state, r.Status())
|
||||
meta.CreatedAt = r.createdAt
|
||||
if err := SaveRunMetadata(meta); err != nil {
|
||||
log.Printf("failed to save run metadata for %s: %v", r.cfg.RunID, err)
|
||||
logger.Infof("failed to save run metadata for %s: %v", r.cfg.RunID, err)
|
||||
} else {
|
||||
if err := updateRunIndex(meta, &r.cfg); err != nil {
|
||||
log.Printf("failed to update index for %s: %v", r.cfg.RunID, err)
|
||||
logger.Infof("failed to update index for %s: %v", r.cfg.RunID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) logDecision(record *logger.DecisionRecord) error {
|
||||
func (r *Runner) logDecision(record *store.DecisionRecord) error {
|
||||
if record == nil {
|
||||
return nil
|
||||
}
|
||||
if err := r.decisionLogger.LogDecision(record); err != nil {
|
||||
return err
|
||||
}
|
||||
persistDecisionRecord(r.cfg.RunID, record)
|
||||
return nil
|
||||
}
|
||||
@@ -1157,14 +1154,14 @@ func (r *Runner) persistMetrics(force bool) {
|
||||
state := r.snapshotState()
|
||||
metrics, err := CalculateMetrics(r.cfg.RunID, &r.cfg, &state)
|
||||
if err != nil {
|
||||
log.Printf("failed to compute metrics for %s: %v", r.cfg.RunID, err)
|
||||
logger.Infof("failed to compute metrics for %s: %v", r.cfg.RunID, err)
|
||||
return
|
||||
}
|
||||
if metrics == nil {
|
||||
return
|
||||
}
|
||||
if err := PersistMetrics(r.cfg.RunID, metrics); err != nil {
|
||||
log.Printf("failed to persist metrics for %s: %v", r.cfg.RunID, err)
|
||||
logger.Infof("failed to persist metrics for %s: %v", r.cfg.RunID, err)
|
||||
return
|
||||
}
|
||||
r.lastMetricsWrite = time.Now()
|
||||
@@ -1264,7 +1261,7 @@ func (r *Runner) saveCheckpoint(state BacktestState) error {
|
||||
func (r *Runner) forceCheckpoint() {
|
||||
state := r.snapshotState()
|
||||
if err := r.saveCheckpoint(state); err != nil {
|
||||
log.Printf("failed to save checkpoint for %s: %v", r.cfg.RunID, err)
|
||||
logger.Infof("failed to save checkpoint for %s: %v", r.cfg.RunID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1281,7 +1278,6 @@ func (r *Runner) applyCheckpoint(ckpt *Checkpoint) error {
|
||||
return fmt.Errorf("checkpoint is nil")
|
||||
}
|
||||
r.account.RestoreFromSnapshots(ckpt.Cash, ckpt.RealizedPnL, ckpt.Positions)
|
||||
r.decisionLogger.SetCycleNumber(ckpt.DecisionCycle)
|
||||
r.stateMu.Lock()
|
||||
defer r.stateMu.Unlock()
|
||||
r.state.BarIndex = ckpt.BarIndex
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"nofx/logger"
|
||||
"nofx/store"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -380,7 +380,7 @@ func PersistMetrics(runID string, metrics *Metrics) error {
|
||||
return saveMetrics(runID, metrics)
|
||||
}
|
||||
|
||||
func LoadDecisionTrace(runID string, cycle int) (*logger.DecisionRecord, error) {
|
||||
func LoadDecisionTrace(runID string, cycle int) (*store.DecisionRecord, error) {
|
||||
if usingDB() {
|
||||
return loadDecisionTraceDB(runID, cycle)
|
||||
}
|
||||
@@ -418,7 +418,7 @@ func LoadDecisionTrace(runID string, cycle int) (*logger.DecisionRecord, error)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var record logger.DecisionRecord
|
||||
var record store.DecisionRecord
|
||||
if err := json.Unmarshal(data, &record); err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -429,7 +429,7 @@ func LoadDecisionTrace(runID string, cycle int) (*logger.DecisionRecord, error)
|
||||
return nil, fmt.Errorf("decision trace not found for run %s cycle %d", runID, cycle)
|
||||
}
|
||||
|
||||
func LoadDecisionRecords(runID string, limit, offset int) ([]*logger.DecisionRecord, error) {
|
||||
func LoadDecisionRecords(runID string, limit, offset int) ([]*store.DecisionRecord, error) {
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
@@ -443,7 +443,7 @@ func LoadDecisionRecords(runID string, limit, offset int) ([]*logger.DecisionRec
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return []*logger.DecisionRecord{}, nil
|
||||
return []*store.DecisionRecord{}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -471,19 +471,19 @@ func LoadDecisionRecords(runID string, limit, offset int) ([]*logger.DecisionRec
|
||||
return infoI.ModTime().After(infoJ.ModTime())
|
||||
})
|
||||
if offset >= len(files) {
|
||||
return []*logger.DecisionRecord{}, nil
|
||||
return []*store.DecisionRecord{}, nil
|
||||
}
|
||||
end := offset + limit
|
||||
if end > len(files) {
|
||||
end = len(files)
|
||||
}
|
||||
records := make([]*logger.DecisionRecord, 0, end-offset)
|
||||
records := make([]*store.DecisionRecord, 0, end-offset)
|
||||
for _, file := range files[offset:end] {
|
||||
data, err := os.ReadFile(file.path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var record logger.DecisionRecord
|
||||
var record store.DecisionRecord
|
||||
if err := json.Unmarshal(data, &record); err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -553,7 +553,7 @@ func CreateRunExport(runID string) (string, error) {
|
||||
return tmpFile.Name(), nil
|
||||
}
|
||||
|
||||
func persistDecisionRecord(runID string, record *logger.DecisionRecord) {
|
||||
func persistDecisionRecord(runID string, record *store.DecisionRecord) {
|
||||
if !usingDB() || record == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"nofx/logger"
|
||||
"nofx/store"
|
||||
)
|
||||
|
||||
func saveCheckpointDB(runID string, ckpt *Checkpoint) error {
|
||||
@@ -273,7 +273,7 @@ func saveProgressDB(runID string, payload progressPayload) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func loadDecisionTraceDB(runID string, cycle int) (*logger.DecisionRecord, error) {
|
||||
func loadDecisionTraceDB(runID string, cycle int) (*store.DecisionRecord, error) {
|
||||
query := `SELECT payload FROM backtest_decisions WHERE run_id = ?`
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
@@ -293,14 +293,14 @@ func loadDecisionTraceDB(runID string, cycle int) (*logger.DecisionRecord, error
|
||||
if err := rows.Scan(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var record logger.DecisionRecord
|
||||
var record store.DecisionRecord
|
||||
if err := json.Unmarshal(payload, &record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func saveDecisionRecordDB(runID string, record *logger.DecisionRecord) error {
|
||||
func saveDecisionRecordDB(runID string, record *store.DecisionRecord) error {
|
||||
if record == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -315,7 +315,7 @@ func saveDecisionRecordDB(runID string, record *logger.DecisionRecord) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func loadDecisionRecordsDB(runID string, limit, offset int) ([]*logger.DecisionRecord, error) {
|
||||
func loadDecisionRecordsDB(runID string, limit, offset int) ([]*store.DecisionRecord, error) {
|
||||
rows, err := persistenceDB.Query(`
|
||||
SELECT payload FROM backtest_decisions
|
||||
WHERE run_id = ?
|
||||
@@ -326,13 +326,13 @@ func loadDecisionRecordsDB(runID string, limit, offset int) ([]*logger.DecisionR
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
records := make([]*logger.DecisionRecord, 0, limit)
|
||||
records := make([]*store.DecisionRecord, 0, limit)
|
||||
for rows.Next() {
|
||||
var payload []byte
|
||||
if err := rows.Scan(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var record logger.DecisionRecord
|
||||
var record store.DecisionRecord
|
||||
if err := json.Unmarshal(payload, &record); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,455 +0,0 @@
|
||||
# Bootstrap 模块初始化框架
|
||||
|
||||
## 概述
|
||||
|
||||
Bootstrap 是一个模块化的初始化框架,允许各个模块通过注册钩子的方式自动完成初始化,支持优先级控制、条件初始化、错误策略等高级特性。
|
||||
|
||||
## 核心特性
|
||||
|
||||
- ✅ **优先级排序** - 保证模块按正确的顺序初始化
|
||||
- ✅ **钩子命名** - 每个钩子都有清晰的名称,便于日志追踪和错误定位
|
||||
- ✅ **上下文传递** - 模块之间可以共享数据(如数据库实例)
|
||||
- ✅ **条件初始化** - 根据配置动态决定是否初始化某个模块
|
||||
- ✅ **灵活的错误处理** - 支持快速失败、继续执行、警告三种策略
|
||||
- ✅ **详细日志** - 显示初始化进度、耗时统计
|
||||
- ✅ **线程安全** - 使用互斥锁保护全局状态
|
||||
- ✅ **测试友好** - 提供 Clear() 方法清除钩子
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 在模块中注册初始化钩子
|
||||
|
||||
在你的模块包中创建 `init.go` 文件:
|
||||
|
||||
```go
|
||||
// proxy/init.go
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"nofx/bootstrap"
|
||||
"nofx/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 注册初始化钩子
|
||||
bootstrap.Register("Proxy模块", bootstrap.PriorityCore, initProxyModule)
|
||||
}
|
||||
|
||||
func initProxyModule(ctx *bootstrap.Context) error {
|
||||
// 从配置中读取 proxy 配置
|
||||
proxyConfig := ctx.Config.Proxy
|
||||
|
||||
// 初始化代理管理器
|
||||
if err := InitGlobalProxyManager(proxyConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 将实例存储到上下文,供其他模块使用
|
||||
ctx.Set("proxy_manager", GetGlobalProxyManager())
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 在 main.go 中运行初始化
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"nofx/bootstrap"
|
||||
"nofx/config"
|
||||
|
||||
// 导入需要初始化的模块(触发 init() 注册)
|
||||
_ "nofx/proxy"
|
||||
_ "nofx/market"
|
||||
_ "nofx/trader"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 加载配置
|
||||
cfg, err := config.LoadConfig("config.json")
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建初始化上下文
|
||||
ctx := bootstrap.NewContext(cfg)
|
||||
|
||||
// 执行所有初始化钩子
|
||||
if err := bootstrap.Run(ctx); err != nil {
|
||||
log.Fatalf("初始化失败: %v", err)
|
||||
}
|
||||
|
||||
// 启动业务逻辑...
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 运行效果
|
||||
|
||||
```
|
||||
🔄 开始初始化 3 个模块...
|
||||
[1/3] 初始化: Database模块 (优先级: 20)
|
||||
✓ 完成: Database模块 (耗时: 120ms)
|
||||
[2/3] 初始化: Proxy模块 (优先级: 50)
|
||||
↳ 代理自动刷新已启动 (间隔: 30m0s)
|
||||
↳ 代理池状态: 总计=5, 黑名单=0, 可用=5
|
||||
✓ 完成: Proxy模块 (耗时: 35ms)
|
||||
[3/3] 初始化: Market模块 (优先级: 100)
|
||||
✓ 完成: Market模块 (耗时: 200ms)
|
||||
✅ 所有模块初始化完成 (总耗时: 355ms)
|
||||
📊 统计: 成功=3, 跳过=0
|
||||
```
|
||||
|
||||
## 优先级常量
|
||||
|
||||
系统预定义了以下优先级常量(数值越小越先执行):
|
||||
|
||||
| 常量 | 值 | 用途 | 示例 |
|
||||
|------|-----|------|------|
|
||||
| `PriorityInfrastructure` | 10 | 基础设施 | 日志系统、配置加载 |
|
||||
| `PriorityDatabase` | 20 | 数据库连接 | SQLite、Redis |
|
||||
| `PriorityCore` | 50 | 核心模块 | Proxy、Market Monitor |
|
||||
| `PriorityBusiness` | 100 | 业务模块 | Trader、API Server |
|
||||
| `PriorityBackground` | 200 | 后台任务 | 定时任务、监控 |
|
||||
|
||||
### 使用示例
|
||||
|
||||
```go
|
||||
// 数据库模块(最先初始化)
|
||||
bootstrap.Register("Database", bootstrap.PriorityDatabase, initDatabase)
|
||||
|
||||
// 代理模块(核心模块)
|
||||
bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy)
|
||||
|
||||
// Trader模块(依赖数据库和代理)
|
||||
bootstrap.Register("Trader", bootstrap.PriorityBusiness, initTrader)
|
||||
```
|
||||
|
||||
## 高级特性
|
||||
|
||||
### 1. 条件初始化
|
||||
|
||||
某些模块只在特定条件下才需要初始化:
|
||||
|
||||
```go
|
||||
bootstrap.Register("Proxy模块", bootstrap.PriorityCore, initProxy).
|
||||
EnabledIf(func(ctx *bootstrap.Context) bool {
|
||||
// 只在配置中启用 proxy 时才初始化
|
||||
return ctx.Config.Proxy != nil && ctx.Config.Proxy.Enabled
|
||||
})
|
||||
```
|
||||
|
||||
**输出**:
|
||||
```
|
||||
[2/5] 跳过: Proxy模块 (条件未满足)
|
||||
```
|
||||
|
||||
### 2. 错误处理策略
|
||||
|
||||
支持三种错误处理策略:
|
||||
|
||||
#### FailFast(默认)- 遇到错误立即停止
|
||||
|
||||
```go
|
||||
bootstrap.Register("Database", bootstrap.PriorityDatabase, initDatabase)
|
||||
// 默认就是 FailFast,无需显式设置
|
||||
```
|
||||
|
||||
**效果**:Database 初始化失败,整个系统停止启动
|
||||
|
||||
#### ContinueOnError - 继续执行,收集所有错误
|
||||
|
||||
```go
|
||||
bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy).
|
||||
OnError(bootstrap.ContinueOnError)
|
||||
```
|
||||
|
||||
**效果**:Proxy 失败不影响其他模块,最后汇总所有错误
|
||||
|
||||
#### WarnOnError - 继续执行,只打印警告
|
||||
|
||||
```go
|
||||
bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy).
|
||||
OnError(bootstrap.WarnOnError)
|
||||
```
|
||||
|
||||
**效果**:Proxy 失败只打印警告,不影响系统运行
|
||||
|
||||
**输出**:
|
||||
```
|
||||
[2/5] 初始化: Proxy模块 (优先级: 50)
|
||||
⚠️ 警告: Proxy模块 (耗时: 15ms) - 连接代理服务器超时
|
||||
```
|
||||
|
||||
### 3. 上下文数据共享
|
||||
|
||||
模块之间可以通过 Context 共享数据:
|
||||
|
||||
```go
|
||||
// database/init.go - 存储数据库实例
|
||||
func initDatabase(ctx *bootstrap.Context) error {
|
||||
db, err := sql.Open("sqlite", "config.db")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 存储到上下文
|
||||
ctx.Set("database", db)
|
||||
return nil
|
||||
}
|
||||
|
||||
// trader/init.go - 获取数据库实例
|
||||
func initTrader(ctx *bootstrap.Context) error {
|
||||
// 从上下文获取数据库实例
|
||||
db, ok := ctx.Get("database")
|
||||
if !ok {
|
||||
return fmt.Errorf("database 未初始化")
|
||||
}
|
||||
|
||||
database := db.(*sql.DB)
|
||||
// 使用 database 初始化 trader...
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
**安全获取**:
|
||||
```go
|
||||
// 使用 MustGet,不存在会 panic(适合必需的依赖)
|
||||
db := ctx.MustGet("database").(*sql.DB)
|
||||
```
|
||||
|
||||
### 4. 链式调用
|
||||
|
||||
支持流畅的链式调用:
|
||||
|
||||
```go
|
||||
bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy).
|
||||
EnabledIf(func(ctx *bootstrap.Context) bool {
|
||||
return ctx.Config.Proxy != nil && ctx.Config.Proxy.Enabled
|
||||
}).
|
||||
OnError(bootstrap.WarnOnError)
|
||||
```
|
||||
|
||||
### 5. 自定义错误策略
|
||||
|
||||
在 Run 时可以指定全局默认错误策略:
|
||||
|
||||
```go
|
||||
// 所有钩子默认使用 ContinueOnError,除非钩子自己指定了 FailFast
|
||||
err := bootstrap.RunWithPolicy(ctx, bootstrap.ContinueOnError)
|
||||
```
|
||||
|
||||
## 完整示例
|
||||
|
||||
### 示例1:Database 模块
|
||||
|
||||
```go
|
||||
// database/init.go
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"nofx/bootstrap"
|
||||
)
|
||||
|
||||
func init() {
|
||||
bootstrap.Register("Database", bootstrap.PriorityDatabase, initDatabase)
|
||||
}
|
||||
|
||||
func initDatabase(ctx *bootstrap.Context) error {
|
||||
db, err := sql.Open("sqlite", "config.db")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := db.Ping(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 存储到上下文
|
||||
ctx.Set("database", db)
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### 示例2:Proxy 模块(条件初始化 + 警告策略)
|
||||
|
||||
```go
|
||||
// proxy/init.go
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"nofx/bootstrap"
|
||||
"nofx/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy).
|
||||
EnabledIf(func(ctx *bootstrap.Context) bool {
|
||||
return ctx.Config.Proxy != nil && ctx.Config.Proxy.Enabled
|
||||
}).
|
||||
OnError(bootstrap.WarnOnError) // Proxy 失败不影响系统
|
||||
}
|
||||
|
||||
func initProxy(ctx *bootstrap.Context) error {
|
||||
proxyConfig := convertConfig(ctx.Config.Proxy)
|
||||
|
||||
if err := InitGlobalProxyManager(proxyConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Set("proxy_manager", GetGlobalProxyManager())
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### 示例3:Trader 模块(依赖其他模块)
|
||||
|
||||
```go
|
||||
// trader/init.go
|
||||
package trader
|
||||
|
||||
import (
|
||||
"nofx/bootstrap"
|
||||
)
|
||||
|
||||
func init() {
|
||||
bootstrap.Register("Trader", bootstrap.PriorityBusiness, initTrader)
|
||||
}
|
||||
|
||||
func initTrader(ctx *bootstrap.Context) error {
|
||||
// 获取依赖
|
||||
db := ctx.MustGet("database").(*sql.DB)
|
||||
|
||||
// 可选依赖
|
||||
var proxyMgr *proxy.ProxyManager
|
||||
if pm, ok := ctx.Get("proxy_manager"); ok {
|
||||
proxyMgr = pm.(*proxy.ProxyManager)
|
||||
}
|
||||
|
||||
// 使用依赖初始化 trader...
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## 调试和测试
|
||||
|
||||
### 查看已注册的钩子
|
||||
|
||||
```go
|
||||
hooks := bootstrap.GetRegistered()
|
||||
for _, hook := range hooks {
|
||||
fmt.Printf("钩子: %s, 优先级: %d\n", hook.Name, hook.Priority)
|
||||
}
|
||||
```
|
||||
|
||||
### 清除钩子(用于测试)
|
||||
|
||||
```go
|
||||
func TestMyModule(t *testing.T) {
|
||||
// 清除之前注册的钩子
|
||||
bootstrap.Clear()
|
||||
|
||||
// 注册测试钩子
|
||||
bootstrap.Register("Test", 10, func(ctx *bootstrap.Context) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// 运行测试...
|
||||
}
|
||||
```
|
||||
|
||||
### 统计钩子数量
|
||||
|
||||
```go
|
||||
count := bootstrap.Count()
|
||||
fmt.Printf("已注册 %d 个初始化钩子\n", count)
|
||||
```
|
||||
|
||||
## 错误处理最佳实践
|
||||
|
||||
### 1. 关键模块使用 FailFast
|
||||
|
||||
```go
|
||||
// 数据库是关键依赖,失败必须停止
|
||||
bootstrap.Register("Database", bootstrap.PriorityDatabase, initDatabase)
|
||||
// 默认是 FailFast,无需显式设置
|
||||
```
|
||||
|
||||
### 2. 可选模块使用 WarnOnError
|
||||
|
||||
```go
|
||||
// Proxy 是可选的,失败可以使用直连
|
||||
bootstrap.Register("Proxy", bootstrap.PriorityCore, initProxy).
|
||||
OnError(bootstrap.WarnOnError)
|
||||
```
|
||||
|
||||
### 3. 批量初始化使用 ContinueOnError
|
||||
|
||||
```go
|
||||
// 批量加载插件,希望看到所有失败的插件
|
||||
for _, plugin := range plugins {
|
||||
bootstrap.Register(plugin.Name, 150, plugin.Init).
|
||||
OnError(bootstrap.ContinueOnError)
|
||||
}
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q1: 如何保证模块A在模块B之前初始化?
|
||||
|
||||
使用优先级控制:
|
||||
```go
|
||||
bootstrap.Register("ModuleA", 50, initA) // 先执行
|
||||
bootstrap.Register("ModuleB", 100, initB) // 后执行
|
||||
```
|
||||
|
||||
### Q2: 如何在初始化失败时获取详细信息?
|
||||
|
||||
钩子名称会自动包含在错误信息中:
|
||||
```
|
||||
Error: [Proxy模块] 初始化失败: 连接代理服务器超时
|
||||
```
|
||||
|
||||
### Q3: 可以动态注册钩子吗?
|
||||
|
||||
可以,但建议在 `init()` 函数中注册:
|
||||
```go
|
||||
// 推荐:在 init() 中注册(包加载时自动执行)
|
||||
func init() {
|
||||
bootstrap.Register("MyModule", 100, initModule)
|
||||
}
|
||||
|
||||
// 不推荐:在运行时注册(可能导致顺序问题)
|
||||
func main() {
|
||||
bootstrap.Register("MyModule", 100, initModule)
|
||||
}
|
||||
```
|
||||
|
||||
### Q4: 如何在钩子中访问命令行参数?
|
||||
|
||||
通过 Context 的 Data 字段传递:
|
||||
```go
|
||||
// main.go
|
||||
ctx := bootstrap.NewContext(cfg)
|
||||
ctx.Set("args", os.Args)
|
||||
|
||||
// module/init.go
|
||||
func initModule(ctx *bootstrap.Context) error {
|
||||
args := ctx.MustGet("args").([]string)
|
||||
// 使用 args...
|
||||
}
|
||||
```
|
||||
## 性能考虑
|
||||
|
||||
- 钩子注册是线程安全的,但注册本身有轻微的锁开销
|
||||
- 建议在 `init()` 函数中注册,避免运行时动态注册
|
||||
- 钩子执行是顺序的,不会并发执行
|
||||
- 每个钩子的耗时会被记录并显示
|
||||
|
||||
## 许可证
|
||||
|
||||
本模块为 NOFX 项目内部模块,遵循项目整体许可证。
|
||||
@@ -1,169 +0,0 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Priority 初始化优先级常量
|
||||
const (
|
||||
PriorityInfrastructure = 10 // 基础设施(日志、配置等)
|
||||
PriorityDatabase = 20 // 数据库连接
|
||||
PriorityCore = 50 // 核心模块(Proxy、Market等)
|
||||
PriorityBusiness = 100 // 业务模块(Trader、API等)
|
||||
PriorityBackground = 200 // 后台任务
|
||||
)
|
||||
|
||||
// ErrorPolicy 错误处理策略
|
||||
type ErrorPolicy int
|
||||
|
||||
const (
|
||||
// FailFast 遇到错误立即停止(默认)
|
||||
FailFast ErrorPolicy = iota
|
||||
// ContinueOnError 继续执行,收集所有错误
|
||||
ContinueOnError
|
||||
// WarnOnError 继续执行,只打印警告
|
||||
WarnOnError
|
||||
)
|
||||
|
||||
var (
|
||||
hooks []Hook
|
||||
hooksMu sync.Mutex
|
||||
)
|
||||
|
||||
// Register 注册初始化钩子
|
||||
// name: 模块名称(如 "Proxy", "Database")
|
||||
// priority: 优先级(建议使用常量:PriorityCore、PriorityBusiness等)
|
||||
// fn: 初始化函数
|
||||
func Register(name string, priority int, fn func(*Context) error) *HookBuilder {
|
||||
hooksMu.Lock()
|
||||
defer hooksMu.Unlock()
|
||||
|
||||
hook := Hook{
|
||||
Name: name,
|
||||
Priority: priority,
|
||||
Func: fn,
|
||||
Enabled: nil, // 默认启用
|
||||
ErrorPolicy: FailFast,
|
||||
}
|
||||
|
||||
hooks = append(hooks, hook)
|
||||
|
||||
return &HookBuilder{hook: &hooks[len(hooks)-1]}
|
||||
}
|
||||
|
||||
// Run 执行所有已注册的钩子
|
||||
func Run(ctx *Context) error {
|
||||
return RunWithPolicy(ctx, FailFast)
|
||||
}
|
||||
|
||||
// RunWithPolicy 使用指定的默认错误策略执行所有钩子
|
||||
func RunWithPolicy(ctx *Context, defaultPolicy ErrorPolicy) error {
|
||||
hooksMu.Lock()
|
||||
hooksCopy := make([]Hook, len(hooks))
|
||||
copy(hooksCopy, hooks)
|
||||
hooksMu.Unlock()
|
||||
|
||||
if len(hooksCopy) == 0 {
|
||||
log.Printf("⚠️ 没有注册任何初始化钩子")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 按优先级排序
|
||||
sort.Slice(hooksCopy, func(i, j int) bool {
|
||||
return hooksCopy[i].Priority < hooksCopy[j].Priority
|
||||
})
|
||||
|
||||
log.Printf("🔄 开始初始化 %d 个模块...", len(hooksCopy))
|
||||
startTime := time.Now()
|
||||
|
||||
var errors []error
|
||||
successCount := 0
|
||||
skippedCount := 0
|
||||
|
||||
for i, hook := range hooksCopy {
|
||||
// 检查是否启用
|
||||
if hook.Enabled != nil && !hook.Enabled(ctx) {
|
||||
log.Printf(" [%d/%d] 跳过: %s (条件未满足)",
|
||||
i+1, len(hooksCopy), hook.Name)
|
||||
skippedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf(" [%d/%d] 初始化: %s (优先级: %d)",
|
||||
i+1, len(hooksCopy), hook.Name, hook.Priority)
|
||||
|
||||
hookStart := time.Now()
|
||||
err := hook.Func(ctx)
|
||||
elapsed := time.Since(hookStart)
|
||||
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("[%s] 初始化失败: %w", hook.Name, err)
|
||||
|
||||
// 根据错误策略处理
|
||||
policy := hook.ErrorPolicy
|
||||
if policy == FailFast && defaultPolicy != FailFast {
|
||||
policy = defaultPolicy
|
||||
}
|
||||
|
||||
switch policy {
|
||||
case FailFast:
|
||||
log.Printf(" ❌ 失败: %s (耗时: %v)", hook.Name, elapsed)
|
||||
return errMsg
|
||||
case ContinueOnError:
|
||||
log.Printf(" ❌ 失败: %s (耗时: %v) - 继续执行", hook.Name, elapsed)
|
||||
errors = append(errors, errMsg)
|
||||
case WarnOnError:
|
||||
log.Printf(" ⚠️ 警告: %s (耗时: %v) - %v", hook.Name, elapsed, err)
|
||||
}
|
||||
} else {
|
||||
log.Printf(" ✓ 完成: %s (耗时: %v)", hook.Name, elapsed)
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
totalElapsed := time.Since(startTime)
|
||||
|
||||
// 汇总结果
|
||||
if len(errors) > 0 {
|
||||
logger.Log.Warnf("⚠️ 初始化完成,但有 %d 个模块失败 (总耗时: %v)",
|
||||
len(errors), totalElapsed)
|
||||
log.Printf("📊 统计: 成功=%d, 失败=%d, 跳过=%d",
|
||||
successCount, len(errors), skippedCount)
|
||||
|
||||
// 返回合并的错误
|
||||
return fmt.Errorf("以下模块初始化失败: %v", errors)
|
||||
}
|
||||
|
||||
log.Printf("✅ 所有模块初始化完成 (总耗时: %v)", totalElapsed)
|
||||
log.Printf("📊 统计: 成功=%d, 跳过=%d", successCount, skippedCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRegistered 获取已注册的钩子列表(用于调试)
|
||||
func GetRegistered() []Hook {
|
||||
hooksMu.Lock()
|
||||
defer hooksMu.Unlock()
|
||||
|
||||
hooksCopy := make([]Hook, len(hooks))
|
||||
copy(hooksCopy, hooks)
|
||||
return hooksCopy
|
||||
}
|
||||
|
||||
// Clear 清除所有钩子(用于测试)
|
||||
func Clear() {
|
||||
hooksMu.Lock()
|
||||
defer hooksMu.Unlock()
|
||||
hooks = nil
|
||||
}
|
||||
|
||||
// Count 返回已注册的钩子数量
|
||||
func Count() int {
|
||||
hooksMu.Lock()
|
||||
defer hooksMu.Unlock()
|
||||
return len(hooks)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"nofx/config"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Context 初始化上下文,用于在钩子之间传递数据
|
||||
type Context struct {
|
||||
Config *config.Config
|
||||
Data map[string]interface{} // 存储模块之间共享的数据(如数据库实例)
|
||||
ctx context.Context
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewContext 创建新的初始化上下文
|
||||
func NewContext(cfg *config.Config) *Context {
|
||||
return &Context{
|
||||
Config: cfg,
|
||||
Data: make(map[string]interface{}),
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
// Set 存储数据到上下文
|
||||
func (c *Context) Set(key string, value interface{}) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Data[key] = value
|
||||
}
|
||||
|
||||
// Get 从上下文获取数据
|
||||
func (c *Context) Get(key string) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
val, ok := c.Data[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// MustGet 从上下文获取数据,不存在则 panic
|
||||
func (c *Context) MustGet(key string) interface{} {
|
||||
val, ok := c.Get(key)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("context key '%s' not found", key))
|
||||
}
|
||||
return val
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package bootstrap
|
||||
|
||||
// Hook 初始化钩子
|
||||
type Hook struct {
|
||||
Name string // 钩子名称(模块名)
|
||||
Priority int // 优先级(越小越先执行)
|
||||
Func func(*Context) error // 初始化函数
|
||||
Enabled func(*Context) bool // 条件函数,返回 false 则跳过
|
||||
ErrorPolicy ErrorPolicy // 错误处理策略
|
||||
}
|
||||
|
||||
// HookBuilder 钩子构建器(用于链式调用)
|
||||
type HookBuilder struct {
|
||||
hook *Hook
|
||||
}
|
||||
|
||||
// EnabledIf 设置条件函数(链式调用)
|
||||
func (b *HookBuilder) EnabledIf(fn func(*Context) bool) *HookBuilder {
|
||||
b.hook.Enabled = fn
|
||||
return b
|
||||
}
|
||||
|
||||
// OnError 设置错误处理策略(链式调用)
|
||||
func (b *HookBuilder) OnError(policy ErrorPolicy) *HookBuilder {
|
||||
b.hook.ErrorPolicy = policy
|
||||
return b
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package bootstrap
|
||||
|
||||
import "nofx/config"
|
||||
|
||||
type InitHook func(config *config.Config) error
|
||||
|
||||
var InitHooks []InitHook
|
||||
|
||||
// RegisterInitHook 注册初始化钩子
|
||||
func RegisterInitHook(hook InitHook) {
|
||||
InitHooks = append(InitHooks, hook)
|
||||
}
|
||||
|
||||
// RunInitHooks 运行所有注册的初始化钩子
|
||||
func RunInitHooks(c *config.Config) error {
|
||||
for _, hookF := range InitHooks {
|
||||
if err := hookF(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package config
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"os"
|
||||
)
|
||||
|
||||
@@ -15,16 +15,7 @@ type LeverageConfig struct {
|
||||
|
||||
// LogConfig 日志配置
|
||||
type LogConfig struct {
|
||||
Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info)
|
||||
Telegram *TelegramConfig `json:"telegram"` // Telegram推送配置(可选)
|
||||
}
|
||||
|
||||
// TelegramConfig Telegram推送配置(简化版,只保留必需字段)
|
||||
type TelegramConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用(默认: false)
|
||||
BotToken string `json:"bot_token"` // Bot Token
|
||||
ChatID int64 `json:"chat_id"` // Chat ID
|
||||
MinLevel string `json:"min_level"` // 最低日志级别,该级别及以上的日志会推送到Telegram(可选,默认: error)
|
||||
Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info)
|
||||
}
|
||||
|
||||
// Config 总配置
|
||||
@@ -41,14 +32,14 @@ type Config struct {
|
||||
Leverage LeverageConfig `json:"leverage"`
|
||||
JWTSecret string `json:"jwt_secret"`
|
||||
DataKLineTime string `json:"data_k_line_time"`
|
||||
Log *LogConfig `json:"log"` // 日志配置
|
||||
Log *LogConfig `json:"nofx/logger"` // 日志配置
|
||||
}
|
||||
|
||||
// LoadConfig 从文件加载配置
|
||||
func LoadConfig(filename string) (*Config, error) {
|
||||
// 检查filename是否存在
|
||||
if _, err := os.Stat(filename); os.IsNotExist(err) {
|
||||
log.Printf("📄 %s不存在,使用默认配置", filename)
|
||||
logger.Infof("📄 %s不存在,使用默认配置", filename)
|
||||
return &Config{}, nil
|
||||
}
|
||||
|
||||
|
||||
1735
config/database.go
1735
config/database.go
File diff suppressed because it is too large
Load Diff
@@ -1,850 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"nofx/crypto"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestUpdateExchange_EmptyValuesShouldNotOverwrite 测试空值不应覆盖现有数据
|
||||
// 这是 Bug 的核心:当前实现会用空字符串覆盖现有的私钥
|
||||
func TestUpdateExchange_EmptyValuesShouldNotOverwrite(t *testing.T) {
|
||||
// 准备测试数据库
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
userID := "test-user-001"
|
||||
|
||||
// 步骤 1: 创建初始配置(包含私钥)
|
||||
initialAPIKey := "initial-api-key-12345"
|
||||
initialSecretKey := "initial-secret-key-67890"
|
||||
|
||||
err := db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
true, // enabled
|
||||
initialAPIKey,
|
||||
initialSecretKey,
|
||||
false, // testnet
|
||||
"0xWalletAddress",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"", // lighter_wallet_addr
|
||||
"", // lighter_private_key
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化失败: %v", err)
|
||||
}
|
||||
|
||||
// 步骤 2: 验证初始数据已保存
|
||||
exchanges, err := db.GetExchanges(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取配置失败: %v", err)
|
||||
}
|
||||
if len(exchanges) == 0 {
|
||||
t.Fatal("未找到配置")
|
||||
}
|
||||
|
||||
// 解密后应该能看到原始值
|
||||
if exchanges[0].APIKey != initialAPIKey {
|
||||
t.Errorf("初始 APIKey 不正确,期望 %s,实际 %s", initialAPIKey, exchanges[0].APIKey)
|
||||
}
|
||||
|
||||
// 步骤 3: 用空值更新(模拟前端发送空值的场景)
|
||||
// 🐛 Bug 重现:这应该 NOT 覆盖现有的私钥,但当前实现会覆盖
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
false, // 只改变 enabled 状态
|
||||
"", // 空 apiKey - 不应该覆盖
|
||||
"", // 空 secretKey - 不应该覆盖
|
||||
true, // 改变 testnet 状态
|
||||
"0xWalletAddress",
|
||||
"",
|
||||
"",
|
||||
"", // 空 aster_private_key - 不应该覆盖
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("更新失败: %v", err)
|
||||
}
|
||||
|
||||
// 步骤 4: 验证私钥没有被空值覆盖
|
||||
exchanges, err = db.GetExchanges(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取更新后配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 🎯 关键断言:私钥应该保持不变
|
||||
if exchanges[0].APIKey != initialAPIKey {
|
||||
t.Errorf("❌ Bug 确认:APIKey 被空值覆盖了!期望 %s,实际 %s", initialAPIKey, exchanges[0].APIKey)
|
||||
}
|
||||
if exchanges[0].SecretKey != initialSecretKey {
|
||||
t.Errorf("❌ Bug 确认:SecretKey 被空值覆盖了!期望 %s,实际 %s", initialSecretKey, exchanges[0].SecretKey)
|
||||
}
|
||||
|
||||
// 验证非敏感字段正常更新
|
||||
if exchanges[0].Enabled {
|
||||
t.Error("enabled 应该被更新为 false")
|
||||
}
|
||||
if !exchanges[0].Testnet {
|
||||
t.Error("testnet 应该被更新为 true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateExchange_AsterEmptyValuesShouldNotOverwrite 测试 Aster 私钥不被空值覆盖
|
||||
func TestUpdateExchange_AsterEmptyValuesShouldNotOverwrite(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
userID := "test-user-002"
|
||||
|
||||
// 步骤 1: 创建 Aster 配置
|
||||
initialAsterKey := "aster-private-key-xyz123"
|
||||
|
||||
err := db.UpdateExchange(
|
||||
userID,
|
||||
"aster",
|
||||
true,
|
||||
"",
|
||||
"",
|
||||
false,
|
||||
"",
|
||||
"0xAsterUser",
|
||||
"0xAsterSigner",
|
||||
initialAsterKey,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化 Aster 失败: %v", err)
|
||||
}
|
||||
|
||||
// 步骤 2: 用空值更新
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"aster",
|
||||
false, // 只改 enabled
|
||||
"",
|
||||
"",
|
||||
false,
|
||||
"",
|
||||
"0xAsterUser",
|
||||
"0xAsterSigner",
|
||||
"", // 空 aster_private_key
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("更新失败: %v", err)
|
||||
}
|
||||
|
||||
// 步骤 3: 验证 aster_private_key 没有被覆盖
|
||||
exchanges, err := db.GetExchanges(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取配置失败: %v", err)
|
||||
}
|
||||
|
||||
if exchanges[0].AsterPrivateKey != initialAsterKey {
|
||||
t.Errorf("❌ Bug 确认:AsterPrivateKey 被空值覆盖了!期望 %s,实际 %s", initialAsterKey, exchanges[0].AsterPrivateKey)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateExchange_NonEmptyValuesShouldUpdate 测试非空值应该正常更新
|
||||
func TestUpdateExchange_NonEmptyValuesShouldUpdate(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
userID := "test-user-003"
|
||||
|
||||
// 步骤 1: 创建初始配置
|
||||
err := db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
true,
|
||||
"old-api-key",
|
||||
"old-secret-key",
|
||||
false,
|
||||
"0xOldWallet",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化失败: %v", err)
|
||||
}
|
||||
|
||||
// 步骤 2: 用非空值更新
|
||||
newAPIKey := "new-api-key-456"
|
||||
newSecretKey := "new-secret-key-789"
|
||||
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
true,
|
||||
newAPIKey,
|
||||
newSecretKey,
|
||||
false,
|
||||
"0xNewWallet",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("更新失败: %v", err)
|
||||
}
|
||||
|
||||
// 步骤 3: 验证新值已更新
|
||||
exchanges, err := db.GetExchanges(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取配置失败: %v", err)
|
||||
}
|
||||
|
||||
if exchanges[0].APIKey != newAPIKey {
|
||||
t.Errorf("APIKey 未更新,期望 %s,实际 %s", newAPIKey, exchanges[0].APIKey)
|
||||
}
|
||||
if exchanges[0].SecretKey != newSecretKey {
|
||||
t.Errorf("SecretKey 未更新,期望 %s,实际 %s", newSecretKey, exchanges[0].SecretKey)
|
||||
}
|
||||
if exchanges[0].HyperliquidWalletAddr != "0xNewWallet" {
|
||||
t.Errorf("WalletAddr 未更新")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateExchange_PartialUpdateShouldWork 测试部分字段更新
|
||||
func TestUpdateExchange_PartialUpdateShouldWork(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
userID := "test-user-005"
|
||||
|
||||
// 创建初始配置
|
||||
err := db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
true,
|
||||
"api-key-123",
|
||||
"secret-key-456",
|
||||
false,
|
||||
"0xWallet1",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化失败: %v", err)
|
||||
}
|
||||
|
||||
// 只更新 enabled 和 testnet,私钥留空
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
false,
|
||||
"", // 留空
|
||||
"", // 留空
|
||||
true,
|
||||
"0xWallet2",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("部分更新失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证
|
||||
exchanges, err := db.GetExchanges(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 私钥应该保持不变
|
||||
if exchanges[0].APIKey != "api-key-123" {
|
||||
t.Errorf("APIKey 不应改变,期望 api-key-123,实际 %s", exchanges[0].APIKey)
|
||||
}
|
||||
if exchanges[0].SecretKey != "secret-key-456" {
|
||||
t.Errorf("SecretKey 不应改变,期望 secret-key-456,实际 %s", exchanges[0].SecretKey)
|
||||
}
|
||||
|
||||
// 其他字段应该更新
|
||||
if exchanges[0].Enabled {
|
||||
t.Error("enabled 应该更新为 false")
|
||||
}
|
||||
if !exchanges[0].Testnet {
|
||||
t.Error("testnet 应该更新为 true")
|
||||
}
|
||||
if exchanges[0].HyperliquidWalletAddr != "0xWallet2" {
|
||||
t.Error("wallet 地址应该更新")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateExchange_MultipleExchangeTypes 测试不同交易所类型
|
||||
func TestUpdateExchange_MultipleExchangeTypes(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
userID := "test-user-006"
|
||||
|
||||
testCases := []struct {
|
||||
exchangeID string
|
||||
name string
|
||||
typ string
|
||||
}{
|
||||
{"binance", "Binance Futures", "cex"},
|
||||
{"hyperliquid", "Hyperliquid", "dex"},
|
||||
{"aster", "Aster DEX", "dex"},
|
||||
{"unknown-exchange", "unknown-exchange Exchange", "cex"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.exchangeID, func(t *testing.T) {
|
||||
err := db.UpdateExchange(
|
||||
userID,
|
||||
tc.exchangeID,
|
||||
true,
|
||||
"api-key-"+tc.exchangeID,
|
||||
"secret-key-"+tc.exchangeID,
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("创建 %s 失败: %v", tc.exchangeID, err)
|
||||
}
|
||||
|
||||
// 验证创建成功
|
||||
exchanges, err := db.GetExchanges(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取配置失败: %v", err)
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, ex := range exchanges {
|
||||
if ex.ID == tc.exchangeID {
|
||||
found = true
|
||||
if ex.Name != tc.name {
|
||||
t.Errorf("交易所名称不正确,期望 %s,实际 %s", tc.name, ex.Name)
|
||||
}
|
||||
if ex.Type != tc.typ {
|
||||
t.Errorf("交易所类型不正确,期望 %s,实际 %s", tc.typ, ex.Type)
|
||||
}
|
||||
if ex.APIKey != "api-key-"+tc.exchangeID {
|
||||
t.Errorf("APIKey 不正确")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("未找到交易所 %s", tc.exchangeID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateExchange_MixedSensitiveFields 测试混合更新敏感和非敏感字段
|
||||
func TestUpdateExchange_MixedSensitiveFields(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
userID := "test-user-007"
|
||||
|
||||
// 创建初始配置
|
||||
err := db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
true,
|
||||
"old-api-key",
|
||||
"old-secret-key",
|
||||
false,
|
||||
"0xOldWallet",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化失败: %v", err)
|
||||
}
|
||||
|
||||
// 场景1: 只更新 apiKey,secretKey 留空
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
false,
|
||||
"new-api-key",
|
||||
"", // 留空
|
||||
true,
|
||||
"0xNewWallet",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("更新1失败: %v", err)
|
||||
}
|
||||
|
||||
exchanges, _ := db.GetExchanges(userID)
|
||||
if exchanges[0].APIKey != "new-api-key" {
|
||||
t.Error("APIKey 应该更新")
|
||||
}
|
||||
if exchanges[0].SecretKey != "old-secret-key" {
|
||||
t.Error("SecretKey 应该保持不变")
|
||||
}
|
||||
|
||||
// 场景2: 只更新 secretKey,apiKey 留空
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"hyperliquid",
|
||||
true,
|
||||
"", // 留空
|
||||
"new-secret-key",
|
||||
false,
|
||||
"0xFinalWallet",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("更新2失败: %v", err)
|
||||
}
|
||||
|
||||
exchanges, _ = db.GetExchanges(userID)
|
||||
if exchanges[0].APIKey != "new-api-key" {
|
||||
t.Error("APIKey 应该保持不变")
|
||||
}
|
||||
if exchanges[0].SecretKey != "new-secret-key" {
|
||||
t.Error("SecretKey 应该更新")
|
||||
}
|
||||
if exchanges[0].Enabled != true {
|
||||
t.Error("Enabled 应该更新为 true")
|
||||
}
|
||||
if exchanges[0].HyperliquidWalletAddr != "0xFinalWallet" {
|
||||
t.Error("WalletAddr 应该更新")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateExchange_OnlyNonSensitiveFields 测试只更新非敏感字段
|
||||
func TestUpdateExchange_OnlyNonSensitiveFields(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
userID := "test-user-008"
|
||||
|
||||
// 创建初始配置(包含所有私钥)
|
||||
err := db.UpdateExchange(
|
||||
userID,
|
||||
"aster",
|
||||
true,
|
||||
"binance-api",
|
||||
"binance-secret",
|
||||
false,
|
||||
"",
|
||||
"0xUser1",
|
||||
"0xSigner1",
|
||||
"aster-private-key-1",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化失败: %v", err)
|
||||
}
|
||||
|
||||
// 只更新非敏感字段(所有私钥字段留空)
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"aster",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
true,
|
||||
"",
|
||||
"0xUser2",
|
||||
"0xSigner2",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("更新失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证所有私钥保持不变
|
||||
exchanges, _ := db.GetExchanges(userID)
|
||||
if exchanges[0].APIKey != "binance-api" {
|
||||
t.Errorf("APIKey 应该保持不变,实际 %s", exchanges[0].APIKey)
|
||||
}
|
||||
if exchanges[0].SecretKey != "binance-secret" {
|
||||
t.Errorf("SecretKey 应该保持不变,实际 %s", exchanges[0].SecretKey)
|
||||
}
|
||||
if exchanges[0].AsterPrivateKey != "aster-private-key-1" {
|
||||
t.Errorf("AsterPrivateKey 应该保持不变,实际 %s", exchanges[0].AsterPrivateKey)
|
||||
}
|
||||
|
||||
// 验证非敏感字段已更新
|
||||
if exchanges[0].Enabled != false {
|
||||
t.Error("Enabled 应该更新为 false")
|
||||
}
|
||||
if exchanges[0].Testnet != true {
|
||||
t.Error("Testnet 应该更新为 true")
|
||||
}
|
||||
if exchanges[0].AsterUser != "0xUser2" {
|
||||
t.Error("AsterUser 应该更新")
|
||||
}
|
||||
if exchanges[0].AsterSigner != "0xSigner2" {
|
||||
t.Error("AsterSigner 应该更新")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateExchange_AllSensitiveFieldsUpdate 测试同时更新所有敏感字段
|
||||
func TestUpdateExchange_AllSensitiveFieldsUpdate(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
userID := "test-user-009"
|
||||
|
||||
// 创建初始配置
|
||||
err := db.UpdateExchange(
|
||||
userID,
|
||||
"binance",
|
||||
true,
|
||||
"old-api",
|
||||
"old-secret",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"old-aster-key",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化失败: %v", err)
|
||||
}
|
||||
|
||||
// 同时更新所有敏感字段
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"binance",
|
||||
false,
|
||||
"new-api",
|
||||
"new-secret",
|
||||
true,
|
||||
"0xWallet",
|
||||
"0xUser",
|
||||
"0xSigner",
|
||||
"new-aster-key",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("更新失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证所有字段都更新了
|
||||
exchanges, _ := db.GetExchanges(userID)
|
||||
if exchanges[0].APIKey != "new-api" {
|
||||
t.Error("APIKey 应该更新")
|
||||
}
|
||||
if exchanges[0].SecretKey != "new-secret" {
|
||||
t.Error("SecretKey 应该更新")
|
||||
}
|
||||
if exchanges[0].AsterPrivateKey != "new-aster-key" {
|
||||
t.Error("AsterPrivateKey 应该更新")
|
||||
}
|
||||
if !exchanges[0].Testnet {
|
||||
t.Error("Testnet 应该更新为 true")
|
||||
}
|
||||
}
|
||||
|
||||
// setupTestDB 创建测试数据库
|
||||
func setupTestDB(t *testing.T) (*Database, func()) {
|
||||
// 创建临时数据库文件
|
||||
tmpFile := t.TempDir() + "/test.db"
|
||||
|
||||
db, err := NewDatabase(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("创建测试数据库失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建测试用户
|
||||
testUsers := []string{
|
||||
"test-user-001", "test-user-002", "test-user-003", "test-user-004", "test-user-005",
|
||||
"test-user-006", "test-user-007", "test-user-008", "test-user-009",
|
||||
"test-user-persistence", "user1", "user2",
|
||||
}
|
||||
for _, userID := range testUsers {
|
||||
user := &User{
|
||||
ID: userID,
|
||||
Email: userID + "@test.com",
|
||||
PasswordHash: "hash",
|
||||
OTPSecret: "",
|
||||
OTPVerified: false,
|
||||
}
|
||||
_ = db.CreateUser(user)
|
||||
}
|
||||
|
||||
// 设置加密服务(用于测试加密功能)
|
||||
// 创建临时 RSA 密钥
|
||||
rsaKeyPath := t.TempDir() + "/test_rsa_key"
|
||||
cryptoService, err := crypto.NewCryptoService(rsaKeyPath)
|
||||
if err != nil {
|
||||
// 如果创建失败,继续测试但不使用加密
|
||||
t.Logf("警告:无法创建加密服务,将在无加密模式下测试: %v", err)
|
||||
} else {
|
||||
db.SetCryptoService(cryptoService)
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
os.RemoveAll(tmpFile)
|
||||
os.RemoveAll(rsaKeyPath)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
// TestWALModeEnabled 测试 WAL 模式是否启用
|
||||
// TDD: 这个测试应该失败,因为当前代码没有启用 WAL 模式
|
||||
func TestWALModeEnabled(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
// 查询当前的 journal_mode
|
||||
var journalMode string
|
||||
err := db.db.QueryRow("PRAGMA journal_mode").Scan(&journalMode)
|
||||
if err != nil {
|
||||
t.Fatalf("查询 journal_mode 失败: %v", err)
|
||||
}
|
||||
|
||||
// 期望是 WAL 模式
|
||||
if journalMode != "wal" {
|
||||
t.Errorf("期望 journal_mode=wal,实际是 %s", journalMode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSynchronousMode 测试 synchronous 模式设置
|
||||
// TDD: 验证数据持久性设置
|
||||
func TestSynchronousMode(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
// 查询 synchronous 设置
|
||||
var synchronous int
|
||||
err := db.db.QueryRow("PRAGMA synchronous").Scan(&synchronous)
|
||||
if err != nil {
|
||||
t.Fatalf("查询 synchronous 失败: %v", err)
|
||||
}
|
||||
|
||||
// 期望是 FULL (2) 以确保数据持久性
|
||||
if synchronous != 2 {
|
||||
t.Errorf("期望 synchronous=2 (FULL),实际是 %d", synchronous)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDataPersistenceAcrossReopen 测试数据在数据库关闭并重新打开后是否持久化
|
||||
// TDD: 模拟 Docker restart 场景
|
||||
func TestDataPersistenceAcrossReopen(t *testing.T) {
|
||||
// 创建临时数据库文件
|
||||
tmpFile, err := os.CreateTemp("", "test_persistence_*.db")
|
||||
if err != nil {
|
||||
t.Fatalf("创建临时文件失败: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
dbPath := tmpFile.Name()
|
||||
defer os.Remove(dbPath)
|
||||
|
||||
// 设置加密服务
|
||||
rsaKeyPath := "test_rsa_key.pem"
|
||||
cryptoService, err := crypto.NewCryptoService(rsaKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("初始化加密服务失败: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(rsaKeyPath)
|
||||
|
||||
userID := "test-user-persistence"
|
||||
testAPIKey := "test-api-key-should-persist"
|
||||
testSecretKey := "test-secret-key-should-persist"
|
||||
|
||||
// 第一次打开数据库并写入数据
|
||||
{
|
||||
db, err := NewDatabase(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("第一次创建数据库失败: %v", err)
|
||||
}
|
||||
db.SetCryptoService(cryptoService)
|
||||
|
||||
// 创建持久化测试用户,避免外键约束失败
|
||||
_ = db.CreateUser(&User{
|
||||
ID: userID,
|
||||
Email: userID + "@test.com",
|
||||
PasswordHash: "hash",
|
||||
OTPSecret: "",
|
||||
OTPVerified: true,
|
||||
})
|
||||
|
||||
// 写入交易所配置
|
||||
err = db.UpdateExchange(
|
||||
userID,
|
||||
"binance",
|
||||
true,
|
||||
testAPIKey,
|
||||
testSecretKey,
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("写入数据失败: %v", err)
|
||||
}
|
||||
|
||||
// 模拟正常关闭
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("关闭数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 第二次打开数据库并验证数据是否还在
|
||||
{
|
||||
db, err := NewDatabase(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("第二次打开数据库失败: %v", err)
|
||||
}
|
||||
db.SetCryptoService(cryptoService)
|
||||
defer db.Close()
|
||||
|
||||
// 读取数据
|
||||
exchanges, err := db.GetExchanges(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("读取数据失败: %v", err)
|
||||
}
|
||||
|
||||
if len(exchanges) == 0 {
|
||||
t.Fatal("数据丢失:没有找到任何交易所配置")
|
||||
}
|
||||
|
||||
// 验证数据完整性
|
||||
found := false
|
||||
for _, ex := range exchanges {
|
||||
if ex.ID == "binance" {
|
||||
found = true
|
||||
if ex.APIKey != testAPIKey {
|
||||
t.Errorf("API Key 丢失或损坏,期望 %s,实际 %s", testAPIKey, ex.APIKey)
|
||||
}
|
||||
if ex.SecretKey != testSecretKey {
|
||||
t.Errorf("Secret Key 丢失或损坏,期望 %s,实际 %s", testSecretKey, ex.SecretKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Error("数据丢失:找不到 binance 配置")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentWritesWithWAL 测试 WAL 模式下的并发写入
|
||||
// TDD: WAL 模式应该支持更好的并发性能
|
||||
func TestConcurrentWritesWithWAL(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
// 这个测试验证多个并发写入可以成功
|
||||
// WAL 模式下并发性能更好,但 SQLite 仍然可能出现短暂的锁
|
||||
done := make(chan bool, 2)
|
||||
errors := make(chan error, 10)
|
||||
|
||||
// 并发写入1
|
||||
go func() {
|
||||
for i := 0; i < 3; i++ {
|
||||
err := db.UpdateExchange(
|
||||
"user1",
|
||||
"binance",
|
||||
true,
|
||||
"key1",
|
||||
"secret1",
|
||||
false,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
// 小延迟减少锁冲突
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// 并发写入2
|
||||
go func() {
|
||||
for i := 0; i < 3; i++ {
|
||||
err := db.UpdateExchange(
|
||||
"user2",
|
||||
"hyperliquid",
|
||||
true,
|
||||
"key2",
|
||||
"secret2",
|
||||
false,
|
||||
"0xWallet",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
// 小延迟减少锁冲突
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// 等待两个 goroutine 完成
|
||||
<-done
|
||||
<-done
|
||||
close(errors)
|
||||
|
||||
// 检查是否有错误
|
||||
errorCount := 0
|
||||
for err := range errors {
|
||||
t.Logf("并发写入错误: %v", err)
|
||||
errorCount++
|
||||
}
|
||||
|
||||
// WAL 模式下应该能处理并发,但可能有少量锁错误
|
||||
// 我们允许最多 2 个错误
|
||||
if errorCount > 2 {
|
||||
t.Errorf("并发写入失败次数过多: %d", errorCount)
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4Y666RzY5LLi6PiYL+vC
|
||||
7+fcr122Fd8BC7IdqUSYKQ33Nsi9J7J5fDgcMf7ZAnIBpxMV7+e1KEoiwtGmxwHj
|
||||
mYo0ZV0E6JXdiK26S052+Shquri0IXkwGFraDuNKqmGrj6vZuXtq2L2gdSyZCxrI
|
||||
veN9g6LxBvLBP1Rx7UEmZeyokRYvChcxAQXuS/0br44BOHGtwAElk6AGLISz55AG
|
||||
oM40b3ktiza+8THKMz3GiylQQYpBltbM3yAXPlnXJ2MtUZiaHNhEQI4++PMvEErN
|
||||
Izm8cIgcvUAXJ5vBfa4kD0kSgBJFuEQ2im3qcWTuEPRKztEeJDY7XAVHc1Xy6d4N
|
||||
vQIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
278
crypto/crypto.go
278
crypto/crypto.go
@@ -13,10 +13,7 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -24,8 +21,12 @@ import (
|
||||
const (
|
||||
storagePrefix = "ENC:v1:"
|
||||
storageDelimiter = ":"
|
||||
dataKeyEnvName = "DATA_ENCRYPTION_KEY"
|
||||
dataKeyFilePath = "secrets/data_key"
|
||||
)
|
||||
|
||||
// 环境变量名称
|
||||
const (
|
||||
EnvDataEncryptionKey = "DATA_ENCRYPTION_KEY" // AES 数据加密密钥 (Base64)
|
||||
EnvRSAPrivateKey = "RSA_PRIVATE_KEY" // RSA 私钥 (PEM 格式,换行用 \n)
|
||||
)
|
||||
|
||||
type EncryptedPayload struct {
|
||||
@@ -50,29 +51,18 @@ type CryptoService struct {
|
||||
dataKey []byte
|
||||
}
|
||||
|
||||
func NewCryptoService(privateKeyPath string) (*CryptoService, error) {
|
||||
// 读取私钥文件
|
||||
privateKeyPEM, err := ioutil.ReadFile(privateKeyPath)
|
||||
// NewCryptoService 创建加密服务(从环境变量加载密钥)
|
||||
func NewCryptoService() (*CryptoService, error) {
|
||||
// 1. 加载 RSA 私钥
|
||||
privateKey, err := loadRSAPrivateKeyFromEnv()
|
||||
if err != nil {
|
||||
// 如果私钥文件不存在,生成新的密钥对
|
||||
if err := GenerateRSAKeyPair(privateKeyPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate RSA key pair: %w", err)
|
||||
}
|
||||
privateKeyPEM, err = ioutil.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read generated private key: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("RSA 私钥加载失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析私钥
|
||||
privateKey, err := ParseRSAPrivateKeyFromPEM(privateKeyPEM)
|
||||
// 2. 加载 AES 数据加密密钥
|
||||
dataKey, err := loadDataKeyFromEnv()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
dataKey, err := resolveDataKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load data encryption key: %w", err)
|
||||
return nil, fmt.Errorf("数据加密密钥加载失败: %w", err)
|
||||
}
|
||||
|
||||
return &CryptoService{
|
||||
@@ -82,56 +72,43 @@ func NewCryptoService(privateKeyPath string) (*CryptoService, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GenerateRSAKeyPair(privateKeyPath string) error {
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(privateKeyPath)
|
||||
if dir != "." {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
// loadRSAPrivateKeyFromEnv 从环境变量加载 RSA 私钥
|
||||
func loadRSAPrivateKeyFromEnv() (*rsa.PrivateKey, error) {
|
||||
keyPEM := os.Getenv(EnvRSAPrivateKey)
|
||||
if keyPEM == "" {
|
||||
return nil, fmt.Errorf("环境变量 %s 未设置,请在 .env 中配置 RSA 私钥", EnvRSAPrivateKey)
|
||||
}
|
||||
|
||||
// 生成 RSA 密钥对
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 处理环境变量中的换行符(\n -> 实际换行)
|
||||
keyPEM = strings.ReplaceAll(keyPEM, "\\n", "\n")
|
||||
|
||||
// 编码私钥
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
// 保存私钥
|
||||
if err := ioutil.WriteFile(privateKeyPath, privateKeyPEM, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 编码公钥
|
||||
publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyDER,
|
||||
})
|
||||
|
||||
// 保存公钥
|
||||
publicKeyPath := privateKeyPath + ".pub"
|
||||
if err := ioutil.WriteFile(publicKeyPath, publicKeyPEM, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return ParseRSAPrivateKeyFromPEM([]byte(keyPEM))
|
||||
}
|
||||
|
||||
// loadDataKeyFromEnv 从环境变量加载 AES 数据加密密钥
|
||||
func loadDataKeyFromEnv() ([]byte, error) {
|
||||
keyStr := strings.TrimSpace(os.Getenv(EnvDataEncryptionKey))
|
||||
if keyStr == "" {
|
||||
return nil, fmt.Errorf("环境变量 %s 未设置,请在 .env 中配置数据加密密钥", EnvDataEncryptionKey)
|
||||
}
|
||||
|
||||
// 尝试解码
|
||||
if key, ok := decodePossibleKey(keyStr); ok {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// 如果无法解码,使用 SHA256 哈希作为密钥
|
||||
sum := sha256.Sum256([]byte(keyStr))
|
||||
key := make([]byte, len(sum))
|
||||
copy(key, sum[:])
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ParseRSAPrivateKeyFromPEM 解析 PEM 格式的 RSA 私钥
|
||||
func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemBytes)
|
||||
if block == nil {
|
||||
return nil, errors.New("no PEM block found")
|
||||
return nil, errors.New("无效的 PEM 格式")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
@@ -144,100 +121,15 @@ func ParseRSAPrivateKeyFromPEM(pemBytes []byte) (*rsa.PrivateKey, error) {
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.New("not an RSA key")
|
||||
return nil, errors.New("不是 RSA 密钥")
|
||||
}
|
||||
return rsaKey, nil
|
||||
default:
|
||||
return nil, errors.New("unsupported key type: " + block.Type)
|
||||
return nil, errors.New("不支持的密钥类型: " + block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveDataKey() ([]byte, error) {
|
||||
if key, ok := loadDataKeyFromEnv(); ok {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
key, _, err := loadOrCreateDataKeyFile(dataKeyFilePath)
|
||||
return key, err
|
||||
}
|
||||
|
||||
func loadDataKeyFromEnv() ([]byte, bool) {
|
||||
keyStr := strings.TrimSpace(os.Getenv(dataKeyEnvName))
|
||||
if keyStr == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if key, ok := decodePossibleKey(keyStr); ok {
|
||||
return key, true
|
||||
}
|
||||
|
||||
sum := sha256.Sum256([]byte(keyStr))
|
||||
key := make([]byte, len(sum))
|
||||
copy(key, sum[:])
|
||||
return key, true
|
||||
}
|
||||
|
||||
var errInvalidDataKeyMaterial = errors.New("invalid data encryption key material")
|
||||
|
||||
func loadOrCreateDataKeyFile(path string) ([]byte, bool, error) {
|
||||
key, err := readDataKeyFromFile(path)
|
||||
if err == nil {
|
||||
log.Printf("🔐 使用本地数据加密密钥: %s", path)
|
||||
return key, false, nil
|
||||
}
|
||||
|
||||
if !errors.Is(err, os.ErrNotExist) && !errors.Is(err, errInvalidDataKeyMaterial) {
|
||||
log.Printf("⚠️ 无法读取数据加密密钥文件 (%s): %v,尝试重新生成", path, err)
|
||||
}
|
||||
|
||||
key, err = generateAndPersistDataKey(path)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return key, true, nil
|
||||
}
|
||||
|
||||
func readDataKeyFromFile(path string) ([]byte, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encoded := strings.TrimSpace(string(data))
|
||||
if encoded == "" {
|
||||
return nil, errInvalidDataKeyMaterial
|
||||
}
|
||||
|
||||
if key, ok := decodePossibleKey(encoded); ok {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
return nil, errInvalidDataKeyMaterial
|
||||
}
|
||||
|
||||
func generateAndPersistDataKey(path string) ([]byte, error) {
|
||||
raw := make([]byte, 32)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
if dir != "" && dir != "." {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(raw)
|
||||
if err := os.WriteFile(path, []byte(encoded+"\n"), 0600); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("🆕 已生成新的数据加密密钥并保存到 %s", path)
|
||||
log.Printf(" 若需在生产或容器环境复用,请设置 %s 为该值", dataKeyEnvName)
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// decodePossibleKey 尝试用多种编码方式解码密钥
|
||||
func decodePossibleKey(value string) ([]byte, bool) {
|
||||
decoders := []func(string) ([]byte, error){
|
||||
base64.StdEncoding.DecodeString,
|
||||
@@ -256,6 +148,7 @@ func decodePossibleKey(value string) ([]byte, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// normalizeAESKey 标准化 AES 密钥长度
|
||||
func normalizeAESKey(raw []byte) ([]byte, bool) {
|
||||
switch len(raw) {
|
||||
case 16, 24, 32:
|
||||
@@ -293,7 +186,7 @@ func (cs *CryptoService) EncryptForStorage(plaintext string, aadParts ...string)
|
||||
return "", nil
|
||||
}
|
||||
if !cs.HasDataKey() {
|
||||
return "", errors.New("data encryption key not configured")
|
||||
return "", errors.New("数据加密密钥未配置")
|
||||
}
|
||||
if isEncryptedStorageValue(plaintext) {
|
||||
return plaintext, nil
|
||||
@@ -327,26 +220,26 @@ func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (s
|
||||
return "", nil
|
||||
}
|
||||
if !cs.HasDataKey() {
|
||||
return "", errors.New("data encryption key not configured")
|
||||
return "", errors.New("数据加密密钥未配置")
|
||||
}
|
||||
if !isEncryptedStorageValue(value) {
|
||||
return "", errors.New("value is not encrypted")
|
||||
return "", errors.New("数据未加密")
|
||||
}
|
||||
|
||||
payload := strings.TrimPrefix(value, storagePrefix)
|
||||
parts := strings.SplitN(payload, storageDelimiter, 2)
|
||||
if len(parts) != 2 {
|
||||
return "", errors.New("invalid encrypted payload format")
|
||||
return "", errors.New("无效的加密数据格式")
|
||||
}
|
||||
|
||||
nonce, err := base64.StdEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode nonce failed: %w", err)
|
||||
return "", fmt.Errorf("解码 nonce 失败: %w", err)
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode ciphertext failed: %w", err)
|
||||
return "", fmt.Errorf("解码密文失败: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(cs.dataKey)
|
||||
@@ -360,13 +253,13 @@ func (cs *CryptoService) DecryptFromStorage(value string, aadParts ...string) (s
|
||||
}
|
||||
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return "", fmt.Errorf("invalid nonce size: expected %d, got %d", gcm.NonceSize(), len(nonce))
|
||||
return "", fmt.Errorf("无效的 nonce 长度: 期望 %d, 实际 %d", gcm.NonceSize(), len(nonce))
|
||||
}
|
||||
|
||||
aad := composeAAD(aadParts)
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, aad)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decryption failed: %w", err)
|
||||
return "", fmt.Errorf("解密失败: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
@@ -392,66 +285,63 @@ func (cs *CryptoService) DecryptPayload(payload *EncryptedPayload) ([]byte, erro
|
||||
if payload.TS != 0 {
|
||||
elapsed := time.Since(time.Unix(payload.TS, 0))
|
||||
if elapsed > 5*time.Minute || elapsed < -1*time.Minute {
|
||||
return nil, errors.New("timestamp invalid or expired")
|
||||
return nil, errors.New("时间戳无效或已过期")
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 解码 base64url
|
||||
wrappedKey, err := base64.RawURLEncoding.DecodeString(payload.WrappedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode wrapped key: %w", err)
|
||||
return nil, fmt.Errorf("解码 wrapped key 失败: %w", err)
|
||||
}
|
||||
|
||||
iv, err := base64.RawURLEncoding.DecodeString(payload.IV)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode IV: %w", err)
|
||||
return nil, fmt.Errorf("解码 IV 失败: %w", err)
|
||||
}
|
||||
|
||||
ciphertext, err := base64.RawURLEncoding.DecodeString(payload.Ciphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode ciphertext: %w", err)
|
||||
return nil, fmt.Errorf("解码密文失败: %w", err)
|
||||
}
|
||||
|
||||
var aad []byte
|
||||
if payload.AAD != "" {
|
||||
aad, err = base64.RawURLEncoding.DecodeString(payload.AAD)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AAD: %w", err)
|
||||
return nil, fmt.Errorf("解码 AAD 失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证 AAD
|
||||
var aadData AADData
|
||||
if err := json.Unmarshal(aad, &aadData); err == nil {
|
||||
// 可以在这里添加额外的验证逻辑
|
||||
// 例如:验证 sessionID、userID 等
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 使用 RSA-OAEP 解密 AES 密钥
|
||||
aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, cs.privateKey, wrappedKey, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unwrap AES key: %w", err)
|
||||
return nil, fmt.Errorf("RSA 解密失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 使用 AES-GCM 解密数据
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
|
||||
return nil, fmt.Errorf("创建 AES cipher 失败: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
return nil, fmt.Errorf("创建 GCM 失败: %w", err)
|
||||
}
|
||||
|
||||
if len(iv) != gcm.NonceSize() {
|
||||
return nil, fmt.Errorf("invalid IV size: expected %d, got %d", gcm.NonceSize(), len(iv))
|
||||
return nil, fmt.Errorf("无效的 IV 长度: 期望 %d, 实际 %d", gcm.NonceSize(), len(iv))
|
||||
}
|
||||
|
||||
// 解密并验证认证标签
|
||||
plaintext, err := gcm.Open(nil, iv, ciphertext, aad)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authentication/decryption failed: %w", err)
|
||||
return nil, fmt.Errorf("解密验证失败: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
@@ -464,3 +354,41 @@ func (cs *CryptoService) DecryptSensitiveData(payload *EncryptedPayload) (string
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// GenerateKeyPair 生成 RSA 密钥对(用于初始化时生成密钥)
|
||||
// 返回 PEM 格式的私钥和公钥
|
||||
func GenerateKeyPair() (privateKeyPEM, publicKeyPEM string, err error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 编码私钥
|
||||
privPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
})
|
||||
|
||||
// 编码公钥
|
||||
publicKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyDER,
|
||||
})
|
||||
|
||||
return string(privPEM), string(pubPEM), nil
|
||||
}
|
||||
|
||||
// GenerateDataKey 生成 AES 数据加密密钥
|
||||
// 返回 Base64 编码的 32 字节密钥
|
||||
func GenerateDataKey() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
@@ -1,373 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// EncryptionManager 加密管理器(單例模式)
|
||||
type EncryptionManager struct {
|
||||
privateKey *rsa.PrivateKey
|
||||
publicKeyPEM string
|
||||
masterKey []byte // 用於數據庫加密的主密鑰
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
instance *EncryptionManager
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
// GetEncryptionManager 獲取加密管理器實例
|
||||
func GetEncryptionManager() (*EncryptionManager, error) {
|
||||
var initErr error
|
||||
once.Do(func() {
|
||||
instance, initErr = newEncryptionManager()
|
||||
})
|
||||
return instance, initErr
|
||||
}
|
||||
|
||||
// newEncryptionManager 初始化加密管理器
|
||||
func newEncryptionManager() (*EncryptionManager, error) {
|
||||
em := &EncryptionManager{}
|
||||
|
||||
// 1. 加載或生成 RSA 密鑰對
|
||||
if err := em.loadOrGenerateRSAKeyPair(); err != nil {
|
||||
return nil, fmt.Errorf("初始化 RSA 密鑰失敗: %w", err)
|
||||
}
|
||||
|
||||
// 2. 加載或生成數據庫主密鑰
|
||||
if err := em.loadOrGenerateMasterKey(); err != nil {
|
||||
return nil, fmt.Errorf("初始化主密鑰失敗: %w", err)
|
||||
}
|
||||
|
||||
log.Println("🔐 加密管理器初始化成功")
|
||||
return em, nil
|
||||
}
|
||||
|
||||
// ==================== RSA 密鑰管理 ====================
|
||||
|
||||
const (
|
||||
rsaKeySize = 4096
|
||||
rsaPrivateKeyFile = ".secrets/rsa_private.pem"
|
||||
rsaPublicKeyFile = ".secrets/rsa_public.pem"
|
||||
masterKeyFile = ".secrets/master.key"
|
||||
)
|
||||
|
||||
// loadOrGenerateRSAKeyPair 加載或生成 RSA 密鑰對
|
||||
func (em *EncryptionManager) loadOrGenerateRSAKeyPair() error {
|
||||
// 確保 .secrets 目錄存在
|
||||
if err := os.MkdirAll(".secrets", 0700); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 嘗試加載現有密鑰
|
||||
if _, err := os.Stat(rsaPrivateKeyFile); err == nil {
|
||||
return em.loadRSAKeyPair()
|
||||
}
|
||||
|
||||
// 生成新密鑰對
|
||||
log.Println("🔑 生成新的 RSA-4096 密鑰對...")
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
em.privateKey = privateKey
|
||||
|
||||
// 保存私鑰
|
||||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privateKeyBytes,
|
||||
})
|
||||
if err := os.WriteFile(rsaPrivateKeyFile, privateKeyPEM, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存公鑰
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: publicKeyBytes,
|
||||
})
|
||||
if err := os.WriteFile(rsaPublicKeyFile, publicKeyPEM, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
em.publicKeyPEM = string(publicKeyPEM)
|
||||
log.Println("✅ RSA 密鑰對已生成並保存")
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRSAKeyPair 加載 RSA 密鑰對
|
||||
func (em *EncryptionManager) loadRSAKeyPair() error {
|
||||
// 加載私鑰
|
||||
privateKeyPEM, err := os.ReadFile(rsaPrivateKeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(privateKeyPEM)
|
||||
if block == nil || block.Type != "RSA PRIVATE KEY" {
|
||||
return errors.New("無效的私鑰 PEM 格式")
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
em.privateKey = privateKey
|
||||
|
||||
// 加載公鑰
|
||||
publicKeyPEM, err := os.ReadFile(rsaPublicKeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
em.publicKeyPEM = string(publicKeyPEM)
|
||||
|
||||
log.Println("✅ RSA 密鑰對已加載")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPublicKeyPEM 獲取公鑰 (PEM 格式)
|
||||
func (em *EncryptionManager) GetPublicKeyPEM() string {
|
||||
em.mu.RLock()
|
||||
defer em.mu.RUnlock()
|
||||
return em.publicKeyPEM
|
||||
}
|
||||
|
||||
// ==================== 混合解密 (RSA + AES) ====================
|
||||
|
||||
// DecryptWithPrivateKey 使用私鑰解密數據
|
||||
// 數據格式: [加密的 AES 密鑰長度(4字節)] + [加密的 AES 密鑰] + [IV(12字節)] + [加密數據]
|
||||
func (em *EncryptionManager) DecryptWithPrivateKey(encryptedBase64 string) (string, error) {
|
||||
em.mu.RLock()
|
||||
defer em.mu.RUnlock()
|
||||
|
||||
// Base64 解碼
|
||||
encryptedData, err := base64.StdEncoding.DecodeString(encryptedBase64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Base64 解碼失敗: %w", err)
|
||||
}
|
||||
|
||||
if len(encryptedData) < 4+256+12 { // 最小長度檢查
|
||||
return "", errors.New("加密數據長度不足")
|
||||
}
|
||||
|
||||
// 1. 讀取加密的 AES 密鑰長度
|
||||
aesKeyLen := binary.BigEndian.Uint32(encryptedData[:4])
|
||||
if aesKeyLen > 1024 { // 防止過大的長度值
|
||||
return "", errors.New("無效的 AES 密鑰長度")
|
||||
}
|
||||
|
||||
offset := 4
|
||||
// 2. 提取加密的 AES 密鑰
|
||||
encryptedAESKey := encryptedData[offset : offset+int(aesKeyLen)]
|
||||
offset += int(aesKeyLen)
|
||||
|
||||
// 3. 使用 RSA 私鑰解密 AES 密鑰
|
||||
aesKey, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, em.privateKey, encryptedAESKey, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("RSA 解密失敗: %w", err)
|
||||
}
|
||||
|
||||
// 4. 提取 IV
|
||||
iv := encryptedData[offset : offset+12]
|
||||
offset += 12
|
||||
|
||||
// 5. 提取加密數據
|
||||
ciphertext := encryptedData[offset:]
|
||||
|
||||
// 6. 使用 AES-GCM 解密
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
plaintext, err := aesGCM.Open(nil, iv, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("AES 解密失敗: %w", err)
|
||||
}
|
||||
|
||||
// 清除敏感數據
|
||||
for i := range aesKey {
|
||||
aesKey[i] = 0
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// ==================== 數據庫加密 (AES-256-GCM) ====================
|
||||
|
||||
// loadOrGenerateMasterKey 加載或生成數據庫主密鑰
|
||||
func (em *EncryptionManager) loadOrGenerateMasterKey() error {
|
||||
// 優先從環境變數加載
|
||||
if envKey := os.Getenv("NOFX_MASTER_KEY"); envKey != "" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(envKey)
|
||||
if err == nil && len(decoded) == 32 {
|
||||
em.masterKey = decoded
|
||||
log.Println("✅ 從環境變數加載主密鑰")
|
||||
return nil
|
||||
}
|
||||
log.Println("⚠️ 環境變數中的主密鑰無效,使用文件密鑰")
|
||||
}
|
||||
|
||||
// 嘗試從文件加載
|
||||
if _, err := os.Stat(masterKeyFile); err == nil {
|
||||
keyBytes, err := os.ReadFile(masterKeyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(string(keyBytes))
|
||||
if err != nil || len(decoded) != 32 {
|
||||
return errors.New("主密鑰文件損壞")
|
||||
}
|
||||
em.masterKey = decoded
|
||||
log.Println("✅ 從文件加載主密鑰")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成新主密鑰
|
||||
log.Println("🔑 生成新的數據庫主密鑰 (AES-256)...")
|
||||
masterKey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, masterKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
em.masterKey = masterKey
|
||||
|
||||
// 保存到文件
|
||||
encoded := base64.StdEncoding.EncodeToString(masterKey)
|
||||
if err := os.WriteFile(masterKeyFile, []byte(encoded), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("✅ 主密鑰已生成並保存")
|
||||
log.Printf("📁 主密鑰文件位置: %s (權限: 0600)", masterKeyFile)
|
||||
log.Println("🔐 生產環境請設置環境變數: NOFX_MASTER_KEY=<從文件讀取>")
|
||||
log.Println("⚠️ 請妥善保管 .secrets 目錄,切勿將密鑰提交到版本控制系統")
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncryptForDatabase 使用主密鑰加密數據(用於數據庫存儲)
|
||||
func (em *EncryptionManager) EncryptForDatabase(plaintext string) (string, error) {
|
||||
em.mu.RLock()
|
||||
defer em.mu.RUnlock()
|
||||
|
||||
block, err := aes.NewCipher(em.masterKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, aesGCM.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := aesGCM.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// DecryptFromDatabase 使用主密鑰解密數據(從數據庫讀取)
|
||||
func (em *EncryptionManager) DecryptFromDatabase(encryptedBase64 string) (string, error) {
|
||||
em.mu.RLock()
|
||||
defer em.mu.RUnlock()
|
||||
|
||||
// 處理空字符串(未加密的舊數據)
|
||||
if encryptedBase64 == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedBase64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(em.masterKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
aesGCM, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := aesGCM.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return "", errors.New("加密數據過短")
|
||||
}
|
||||
|
||||
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// ==================== 密鑰輪換 ====================
|
||||
|
||||
// RotateMasterKey 輪換主密鑰(需要重新加密所有數據)
|
||||
func (em *EncryptionManager) RotateMasterKey() error {
|
||||
em.mu.Lock()
|
||||
defer em.mu.Unlock()
|
||||
|
||||
log.Println("🔄 開始輪換主密鑰...")
|
||||
|
||||
// 生成新主密鑰
|
||||
newMasterKey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, newMasterKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 備份舊密鑰
|
||||
oldMasterKey := em.masterKey
|
||||
|
||||
// 更新密鑰
|
||||
em.masterKey = newMasterKey
|
||||
|
||||
// 保存新密鑰
|
||||
encoded := base64.StdEncoding.EncodeToString(newMasterKey)
|
||||
backupFile := fmt.Sprintf("%s.backup.%d", masterKeyFile, os.Getpid())
|
||||
if err := os.WriteFile(backupFile, []byte(base64.StdEncoding.EncodeToString(oldMasterKey)), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(masterKeyFile, []byte(encoded), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("✅ 主密鑰已輪換")
|
||||
log.Printf("⚠️ 舊密鑰已備份到: %s", backupFile)
|
||||
log.Printf("🔐 新主密鑰: %s", encoded)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,159 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestRSAKeyPairGeneration 測試 RSA 密鑰對生成
|
||||
func TestRSAKeyPairGeneration(t *testing.T) {
|
||||
em, err := GetEncryptionManager()
|
||||
if err != nil {
|
||||
t.Fatalf("初始化加密管理器失敗: %v", err)
|
||||
}
|
||||
|
||||
publicKey := em.GetPublicKeyPEM()
|
||||
if publicKey == "" {
|
||||
t.Fatal("公鑰為空")
|
||||
}
|
||||
|
||||
if len(publicKey) < 100 {
|
||||
t.Fatal("公鑰長度異常")
|
||||
}
|
||||
|
||||
t.Logf("✅ RSA 密鑰對生成成功,公鑰長度: %d", len(publicKey))
|
||||
}
|
||||
|
||||
// TestDatabaseEncryption 測試數據庫加密/解密
|
||||
func TestDatabaseEncryption(t *testing.T) {
|
||||
em, err := GetEncryptionManager()
|
||||
if err != nil {
|
||||
t.Fatalf("初始化加密管理器失敗: %v", err)
|
||||
}
|
||||
|
||||
testCases := []string{
|
||||
"0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef",
|
||||
"test_api_key_12345",
|
||||
"very_secret_password",
|
||||
"",
|
||||
}
|
||||
|
||||
for _, plaintext := range testCases {
|
||||
// 加密
|
||||
encrypted, err := em.EncryptForDatabase(plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("加密失敗: %v (明文: %s)", err, plaintext)
|
||||
}
|
||||
|
||||
// 驗證加密後不等於明文
|
||||
if encrypted == plaintext && plaintext != "" {
|
||||
t.Fatalf("加密失敗:加密後仍為明文")
|
||||
}
|
||||
|
||||
// 解密
|
||||
decrypted, err := em.DecryptFromDatabase(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("解密失敗: %v (密文: %s)", err, encrypted)
|
||||
}
|
||||
|
||||
// 驗證解密後等於明文
|
||||
if decrypted != plaintext {
|
||||
t.Fatalf("解密結果不匹配: 期望 %s, 得到 %s", plaintext, decrypted)
|
||||
}
|
||||
|
||||
t.Logf("✅ 加密/解密測試通過: %s", plaintext[:min(len(plaintext), 20)])
|
||||
}
|
||||
}
|
||||
|
||||
// TestHybridEncryption 測試混合加密(前端 → 後端場景)
|
||||
func TestHybridEncryption(t *testing.T) {
|
||||
_, err := GetEncryptionManager()
|
||||
if err != nil {
|
||||
t.Fatalf("初始化加密管理器失敗: %v", err)
|
||||
}
|
||||
// 模擬前端加密私鑰
|
||||
// plaintext := "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
// 注意:這裡需要前端的 encryptWithServerPublicKey 實現
|
||||
// 為了測試,我們直接使用後端的加密函數(實際前端使用 Web Crypto API)
|
||||
|
||||
// 由於前端加密邏輯較複雜,這裡僅測試解密流程
|
||||
// 實際測試需要端到端測試
|
||||
t.Log("⚠️ 混合加密測試需要完整的前後端環境,請執行端到端測試")
|
||||
}
|
||||
|
||||
// TestEmptyString 測試空字串處理
|
||||
func TestEmptyString(t *testing.T) {
|
||||
em, err := GetEncryptionManager()
|
||||
if err != nil {
|
||||
t.Fatalf("初始化加密管理器失敗: %v", err)
|
||||
}
|
||||
|
||||
encrypted, err := em.EncryptForDatabase("")
|
||||
if err != nil {
|
||||
t.Fatalf("加密空字串失敗: %v", err)
|
||||
}
|
||||
|
||||
decrypted, err := em.DecryptFromDatabase(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("解密空字串失敗: %v", err)
|
||||
}
|
||||
|
||||
if decrypted != "" {
|
||||
t.Fatalf("空字串處理錯誤: 期望空字串, 得到 %s", decrypted)
|
||||
}
|
||||
|
||||
t.Log("✅ 空字串處理正確")
|
||||
}
|
||||
|
||||
// TestInvalidCiphertext 測試無效密文處理
|
||||
func TestInvalidCiphertext(t *testing.T) {
|
||||
em, err := GetEncryptionManager()
|
||||
if err != nil {
|
||||
t.Fatalf("初始化加密管理器失敗: %v", err)
|
||||
}
|
||||
|
||||
invalidCiphertexts := []string{
|
||||
"not_base64!@#$%",
|
||||
"dGVzdA==", // 有效 Base64,但內容太短
|
||||
"",
|
||||
}
|
||||
|
||||
for _, ciphertext := range invalidCiphertexts {
|
||||
_, err := em.DecryptFromDatabase(ciphertext)
|
||||
if err == nil && ciphertext != "" {
|
||||
t.Fatalf("應該拒絕無效密文: %s", ciphertext)
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("✅ 無效密文處理正確")
|
||||
}
|
||||
|
||||
// BenchmarkEncryption 性能測試:加密
|
||||
func BenchmarkEncryption(b *testing.B) {
|
||||
em, _ := GetEncryptionManager()
|
||||
plaintext := "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = em.EncryptForDatabase(plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDecryption 性能測試:解密
|
||||
func BenchmarkDecryption(b *testing.B) {
|
||||
em, _ := GetEncryptionManager()
|
||||
plaintext := "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
encrypted, _ := em.EncryptForDatabase(plaintext)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = em.DecryptFromDatabase(encrypted)
|
||||
}
|
||||
}
|
||||
|
||||
// min 工具函數
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -1,302 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SecureStorage 安全存儲層(自動加密/解密數據庫中的敏感字段)
|
||||
type SecureStorage struct {
|
||||
db *sql.DB
|
||||
em *EncryptionManager
|
||||
}
|
||||
|
||||
// NewSecureStorage 創建安全存儲實例
|
||||
func NewSecureStorage(db *sql.DB) (*SecureStorage, error) {
|
||||
em, err := GetEncryptionManager()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ss := &SecureStorage{
|
||||
db: db,
|
||||
em: em,
|
||||
}
|
||||
|
||||
// 初始化審計日誌表
|
||||
if err := ss.initAuditLog(); err != nil {
|
||||
return nil, fmt.Errorf("初始化審計日誌失敗: %w", err)
|
||||
}
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// ==================== 交易所配置加密存儲 ====================
|
||||
|
||||
// SaveEncryptedExchangeConfig 保存加密的交易所配置
|
||||
func (ss *SecureStorage) SaveEncryptedExchangeConfig(userID, exchangeID, apiKey, secretKey, asterPrivateKey string) error {
|
||||
// 加密敏感字段
|
||||
encryptedAPIKey, err := ss.em.EncryptForDatabase(apiKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 API Key 失敗: %w", err)
|
||||
}
|
||||
|
||||
encryptedSecretKey, err := ss.em.EncryptForDatabase(secretKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Secret Key 失敗: %w", err)
|
||||
}
|
||||
|
||||
encryptedPrivateKey := ""
|
||||
if asterPrivateKey != "" {
|
||||
encryptedPrivateKey, err = ss.em.EncryptForDatabase(asterPrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Private Key 失敗: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新數據庫
|
||||
_, err = ss.db.Exec(`
|
||||
UPDATE exchanges
|
||||
SET api_key = ?, secret_key = ?, aster_private_key = ?, updated_at = datetime('now')
|
||||
WHERE user_id = ? AND id = ?
|
||||
`, encryptedAPIKey, encryptedSecretKey, encryptedPrivateKey, userID, exchangeID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 記錄審計日誌
|
||||
ss.logAudit(userID, "exchange_config_update", exchangeID, "密鑰已更新")
|
||||
|
||||
log.Printf("🔐 [%s] 交易所 %s 的密鑰已加密保存", userID, exchangeID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadDecryptedExchangeConfig 加載並解密交易所配置
|
||||
func (ss *SecureStorage) LoadDecryptedExchangeConfig(userID, exchangeID string) (apiKey, secretKey, asterPrivateKey string, err error) {
|
||||
var encryptedAPIKey, encryptedSecretKey, encryptedPrivateKey sql.NullString
|
||||
|
||||
err = ss.db.QueryRow(`
|
||||
SELECT api_key, secret_key, aster_private_key
|
||||
FROM exchanges
|
||||
WHERE user_id = ? AND id = ?
|
||||
`, userID, exchangeID).Scan(&encryptedAPIKey, &encryptedSecretKey, &encryptedPrivateKey)
|
||||
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
// 解密 API Key
|
||||
if encryptedAPIKey.Valid && encryptedAPIKey.String != "" {
|
||||
apiKey, err = ss.em.DecryptFromDatabase(encryptedAPIKey.String)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("解密 API Key 失敗: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 解密 Secret Key
|
||||
if encryptedSecretKey.Valid && encryptedSecretKey.String != "" {
|
||||
secretKey, err = ss.em.DecryptFromDatabase(encryptedSecretKey.String)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("解密 Secret Key 失敗: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 解密 Private Key
|
||||
if encryptedPrivateKey.Valid && encryptedPrivateKey.String != "" {
|
||||
asterPrivateKey, err = ss.em.DecryptFromDatabase(encryptedPrivateKey.String)
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("解密 Private Key 失敗: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 記錄審計日誌
|
||||
ss.logAudit(userID, "exchange_config_read", exchangeID, "密鑰已讀取")
|
||||
|
||||
return apiKey, secretKey, asterPrivateKey, nil
|
||||
}
|
||||
|
||||
// ==================== AI 模型配置加密存儲 ====================
|
||||
|
||||
// SaveEncryptedAIModelConfig 保存加密的 AI 模型 API Key
|
||||
func (ss *SecureStorage) SaveEncryptedAIModelConfig(userID, modelID, apiKey string) error {
|
||||
encryptedAPIKey, err := ss.em.EncryptForDatabase(apiKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 API Key 失敗: %w", err)
|
||||
}
|
||||
|
||||
_, err = ss.db.Exec(`
|
||||
UPDATE ai_models
|
||||
SET api_key = ?, updated_at = datetime('now')
|
||||
WHERE user_id = ? AND id = ?
|
||||
`, encryptedAPIKey, userID, modelID)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ss.logAudit(userID, "ai_model_config_update", modelID, "API Key 已更新")
|
||||
log.Printf("🔐 [%s] AI 模型 %s 的 API Key 已加密保存", userID, modelID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadDecryptedAIModelConfig 加載並解密 AI 模型配置
|
||||
func (ss *SecureStorage) LoadDecryptedAIModelConfig(userID, modelID string) (string, error) {
|
||||
var encryptedAPIKey sql.NullString
|
||||
|
||||
err := ss.db.QueryRow(`
|
||||
SELECT api_key FROM ai_models WHERE user_id = ? AND id = ?
|
||||
`, userID, modelID).Scan(&encryptedAPIKey)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !encryptedAPIKey.Valid || encryptedAPIKey.String == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
apiKey, err := ss.em.DecryptFromDatabase(encryptedAPIKey.String)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解密 API Key 失敗: %w", err)
|
||||
}
|
||||
|
||||
ss.logAudit(userID, "ai_model_config_read", modelID, "API Key 已讀取")
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// ==================== 審計日誌 ====================
|
||||
|
||||
// initAuditLog 初始化審計日誌表
|
||||
func (ss *SecureStorage) initAuditLog() error {
|
||||
_, err := ss.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
resource TEXT NOT NULL,
|
||||
details TEXT,
|
||||
ip_address TEXT,
|
||||
user_agent TEXT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
INDEX idx_user_time (user_id, timestamp),
|
||||
INDEX idx_action (action)
|
||||
)
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// logAudit 記錄審計日誌
|
||||
func (ss *SecureStorage) logAudit(userID, action, resource, details string) {
|
||||
_, err := ss.db.Exec(`
|
||||
INSERT INTO audit_logs (user_id, action, resource, details)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, userID, action, resource, details)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 審計日誌記錄失敗: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuditLogs 查詢審計日誌
|
||||
func (ss *SecureStorage) GetAuditLogs(userID string, limit int) ([]AuditLog, error) {
|
||||
rows, err := ss.db.Query(`
|
||||
SELECT id, user_id, action, resource, details, timestamp
|
||||
FROM audit_logs
|
||||
WHERE user_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`, userID, limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var logs []AuditLog
|
||||
for rows.Next() {
|
||||
var log AuditLog
|
||||
err := rows.Scan(&log.ID, &log.UserID, &log.Action, &log.Resource, &log.Details, &log.Timestamp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logs = append(logs, log)
|
||||
}
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// AuditLog 審計日誌結構
|
||||
type AuditLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Action string `json:"action"`
|
||||
Resource string `json:"resource"`
|
||||
Details string `json:"details"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// ==================== 數據遷移工具 ====================
|
||||
|
||||
// MigrateToEncrypted 將舊的明文數據遷移到加密格式
|
||||
func (ss *SecureStorage) MigrateToEncrypted() error {
|
||||
log.Println("🔄 開始遷移明文數據到加密格式...")
|
||||
|
||||
tx, err := ss.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 遷移交易所配置
|
||||
rows, err := tx.Query(`
|
||||
SELECT user_id, id, api_key, secret_key, aster_private_key
|
||||
FROM exchanges
|
||||
WHERE api_key != '' AND api_key NOT LIKE '%==%' -- 過濾已加密數據
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var count int
|
||||
for rows.Next() {
|
||||
var userID, exchangeID, apiKey, secretKey string
|
||||
var asterPrivateKey sql.NullString
|
||||
if err := rows.Scan(&userID, &exchangeID, &apiKey, &secretKey, &asterPrivateKey); err != nil {
|
||||
rows.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// 加密
|
||||
encAPIKey, _ := ss.em.EncryptForDatabase(apiKey)
|
||||
encSecretKey, _ := ss.em.EncryptForDatabase(secretKey)
|
||||
encPrivateKey := ""
|
||||
if asterPrivateKey.Valid && asterPrivateKey.String != "" {
|
||||
encPrivateKey, _ = ss.em.EncryptForDatabase(asterPrivateKey.String)
|
||||
}
|
||||
|
||||
// 更新
|
||||
_, err = tx.Exec(`
|
||||
UPDATE exchanges
|
||||
SET api_key = ?, secret_key = ?, aster_private_key = ?
|
||||
WHERE user_id = ? AND id = ?
|
||||
`, encAPIKey, encSecretKey, encPrivateKey, userID, exchangeID)
|
||||
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
count++
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("✅ 已遷移 %d 個交易所配置到加密格式", count)
|
||||
return nil
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package decision
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"math"
|
||||
"nofx/market"
|
||||
"nofx/mcp"
|
||||
@@ -72,6 +72,29 @@ type OITopData struct {
|
||||
NetShort float64 // 净空仓
|
||||
}
|
||||
|
||||
// TradingStats 交易统计(用于AI输入)
|
||||
type TradingStats struct {
|
||||
TotalTrades int `json:"total_trades"` // 总交易数(已平仓)
|
||||
WinRate float64 `json:"win_rate"` // 胜率 (%)
|
||||
ProfitFactor float64 `json:"profit_factor"` // 盈亏比
|
||||
SharpeRatio float64 `json:"sharpe_ratio"` // 夏普比
|
||||
TotalPnL float64 `json:"total_pnl"` // 总盈亏
|
||||
AvgWin float64 `json:"avg_win"` // 平均盈利
|
||||
AvgLoss float64 `json:"avg_loss"` // 平均亏损
|
||||
MaxDrawdownPct float64 `json:"max_drawdown_pct"` // 最大回撤 (%)
|
||||
}
|
||||
|
||||
// RecentOrder 最近完成的订单(用于AI输入)
|
||||
type RecentOrder struct {
|
||||
Symbol string `json:"symbol"` // 交易对
|
||||
Side string `json:"side"` // long/short
|
||||
EntryPrice float64 `json:"entry_price"` // 开仓价
|
||||
ExitPrice float64 `json:"exit_price"` // 平仓价
|
||||
RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏
|
||||
PnLPct float64 `json:"pnl_pct"` // 盈亏百分比
|
||||
FilledAt string `json:"filled_at"` // 成交时间
|
||||
}
|
||||
|
||||
// Context 交易上下文(传递给AI的完整信息)
|
||||
type Context struct {
|
||||
CurrentTime string `json:"current_time"`
|
||||
@@ -81,10 +104,11 @@ type Context struct {
|
||||
Positions []PositionInfo `json:"positions"`
|
||||
CandidateCoins []CandidateCoin `json:"candidate_coins"`
|
||||
PromptVariant string `json:"prompt_variant,omitempty"`
|
||||
MarketDataMap map[string]*market.Data `json:"-"` // 不序列化,但内部使用
|
||||
TradingStats *TradingStats `json:"trading_stats,omitempty"` // 交易统计指标
|
||||
RecentOrders []RecentOrder `json:"recent_orders,omitempty"` // 最近完成的订单(10条)
|
||||
MarketDataMap map[string]*market.Data `json:"-"` // 不序列化,但内部使用
|
||||
MultiTFMarket map[string]map[string]*market.Data `json:"-"`
|
||||
OITopDataMap map[string]*OITopData `json:"-"` // OI Top数据映射
|
||||
Performance interface{} `json:"-"` // 历史表现分析(logger.PerformanceAnalysis)
|
||||
BTCETHLeverage int `json:"-"` // BTC/ETH杠杆倍数(从配置读取)
|
||||
AltcoinLeverage int `json:"-"` // 山寨币杠杆倍数(从配置读取)
|
||||
}
|
||||
@@ -92,7 +116,7 @@ type Context struct {
|
||||
// Decision AI的交易决策
|
||||
type Decision struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short", "update_stop_loss", "update_take_profit", "partial_close", "hold", "wait"
|
||||
Action string `json:"action"` // "open_long", "open_short", "close_long", "close_short", "hold", "wait"
|
||||
|
||||
// 开仓参数
|
||||
Leverage int `json:"leverage,omitempty"`
|
||||
@@ -100,11 +124,6 @@ type Decision struct {
|
||||
StopLoss float64 `json:"stop_loss,omitempty"`
|
||||
TakeProfit float64 `json:"take_profit,omitempty"`
|
||||
|
||||
// 调整参数(新增)
|
||||
NewStopLoss float64 `json:"new_stop_loss,omitempty"` // 用于 update_stop_loss
|
||||
NewTakeProfit float64 `json:"new_take_profit,omitempty"` // 用于 update_take_profit
|
||||
ClosePercentage float64 `json:"close_percentage,omitempty"` // 用于 partial_close (0-100)
|
||||
|
||||
// 通用参数
|
||||
Confidence int `json:"confidence,omitempty"` // 信心度 (0-100)
|
||||
RiskUSD float64 `json:"risk_usd,omitempty"` // 最大美元风险
|
||||
@@ -232,7 +251,7 @@ func fetchMarketDataForContext(ctx *Context) error {
|
||||
oiValue := data.OpenInterest.Latest * data.CurrentPrice
|
||||
oiValueInMillions := oiValue / 1_000_000 // 转换为百万美元单位
|
||||
if oiValueInMillions < minOIThresholdMillions {
|
||||
log.Printf("⚠️ %s 持仓价值过低(%.2fM USD < %.1fM),跳过此币种 [持仓量:%.0f × 价格:%.4f]",
|
||||
logger.Infof("⚠️ %s 持仓价值过低(%.2fM USD < %.1fM),跳过此币种 [持仓量:%.0f × 价格:%.4f]",
|
||||
symbol, oiValueInMillions, minOIThresholdMillions, data.OpenInterest.Latest, data.CurrentPrice)
|
||||
continue
|
||||
}
|
||||
@@ -329,11 +348,11 @@ func buildSystemPrompt(accountEquity float64, btcEthLeverage, altcoinLeverage in
|
||||
template, err := GetPromptTemplate(templateName)
|
||||
if err != nil {
|
||||
// 如果模板不存在,记录错误并使用 default
|
||||
log.Printf("⚠️ 提示词模板 '%s' 不存在,使用 default: %v", templateName, err)
|
||||
logger.Infof("⚠️ 提示词模板 '%s' 不存在,使用 default: %v", templateName, err)
|
||||
template, err = GetPromptTemplate("default")
|
||||
if err != nil {
|
||||
// 如果连 default 都不存在,使用内置的简化版本
|
||||
log.Printf("❌ 无法加载任何提示词模板,使用内置简化版本")
|
||||
logger.Infof("❌ 无法加载任何提示词模板,使用内置简化版本")
|
||||
sb.WriteString("你是专业的加密货币交易AI。请根据市场数据做出交易决策。\n\n")
|
||||
} else {
|
||||
sb.WriteString(template.Content)
|
||||
@@ -379,19 +398,11 @@ func buildSystemPrompt(accountEquity float64, btcEthLeverage, altcoinLeverage in
|
||||
sb.WriteString("- AI500 / OI_Top 筛选标签(若有)\n\n")
|
||||
sb.WriteString("自由运用任何有效的分析方法,但**信心度 ≥75** 才能开仓;避免单一指标、信号矛盾、横盘震荡、刚平仓即重启等低质量行为。\n\n")
|
||||
|
||||
// 5. 夏普比率驱动的自适应
|
||||
sb.WriteString("# 🧬 夏普比率自我进化\n\n")
|
||||
sb.WriteString("- Sharpe < -0.5:立即停止交易,至少观望6个周期并深度复盘\n")
|
||||
sb.WriteString("- -0.5 ~ 0:只做信心度>80的交易,并降低频率\n")
|
||||
sb.WriteString("- 0 ~ 0.7:保持当前策略\n")
|
||||
sb.WriteString("- >0.7:允许适度加仓,但仍遵守风控\n\n")
|
||||
|
||||
// 6. 决策流程提示
|
||||
// 5. 决策流程提示
|
||||
sb.WriteString("# 📋 决策流程\n\n")
|
||||
sb.WriteString("1. 回顾夏普比率/盈亏 → 是否需要降频或暂停\n")
|
||||
sb.WriteString("2. 检查持仓 → 是否该止盈/止损/调整\n")
|
||||
sb.WriteString("3. 扫描候选币 + 多时间框 → 是否存在强信号\n")
|
||||
sb.WriteString("4. 先写思维链,再输出结构化JSON\n\n")
|
||||
sb.WriteString("1. 检查持仓 → 是否该止盈/止损\n")
|
||||
sb.WriteString("2. 扫描候选币 + 多时间框 → 是否存在强信号\n")
|
||||
sb.WriteString("3. 先写思维链,再输出结构化JSON\n\n")
|
||||
|
||||
// 7. 输出格式 - 动态生成
|
||||
sb.WriteString("# 输出格式 (严格遵守)\n\n")
|
||||
@@ -405,17 +416,13 @@ func buildSystemPrompt(accountEquity float64, btcEthLeverage, altcoinLeverage in
|
||||
sb.WriteString("第二步: JSON决策数组\n\n")
|
||||
sb.WriteString("```json\n[\n")
|
||||
sb.WriteString(fmt.Sprintf(" {\"symbol\": \"BTCUSDT\", \"action\": \"open_short\", \"leverage\": %d, \"position_size_usd\": %.0f, \"stop_loss\": 97000, \"take_profit\": 91000, \"confidence\": 85, \"risk_usd\": 300},\n", btcEthLeverage, accountEquity*5))
|
||||
sb.WriteString(" {\"symbol\": \"SOLUSDT\", \"action\": \"update_stop_loss\", \"new_stop_loss\": 155},\n")
|
||||
sb.WriteString(" {\"symbol\": \"ETHUSDT\", \"action\": \"close_long\"}\n")
|
||||
sb.WriteString("]\n```\n")
|
||||
sb.WriteString("</decision>\n\n")
|
||||
sb.WriteString("## 字段说明\n\n")
|
||||
sb.WriteString("- `action`: open_long | open_short | close_long | close_short | update_stop_loss | update_take_profit | partial_close | hold | wait\n")
|
||||
sb.WriteString("- `action`: open_long | open_short | close_long | close_short | hold | wait\n")
|
||||
sb.WriteString("- `confidence`: 0-100(开仓建议≥75)\n")
|
||||
sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n")
|
||||
sb.WriteString("- update_stop_loss 时必填: new_stop_loss (注意是 new_stop_loss,不是 stop_loss)\n")
|
||||
sb.WriteString("- update_take_profit 时必填: new_take_profit (注意是 new_take_profit,不是 take_profit)\n")
|
||||
sb.WriteString("- partial_close 时必填: close_percentage (0-100)\n\n")
|
||||
sb.WriteString("- 开仓时必填: leverage, position_size_usd, stop_loss, take_profit, confidence, risk_usd\n\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -462,7 +469,7 @@ func buildUserPrompt(ctx *Context) string {
|
||||
}
|
||||
}
|
||||
|
||||
// 计算仓位价值(用于 partial_close 检查)
|
||||
// 计算仓位价值
|
||||
positionValue := math.Abs(pos.Quantity) * pos.MarkPrice
|
||||
|
||||
sb.WriteString(fmt.Sprintf("%d. %s %s | 入场价%.4f 当前价%.4f | 数量%.4f | 仓位价值%.2f USDT | 盈亏%+.2f%% | 盈亏金额%+.2f USDT | 最高收益率%.2f%% | 杠杆%dx | 保证金%.0f | 强平价%.4f%s\n\n",
|
||||
@@ -480,6 +487,38 @@ func buildUserPrompt(ctx *Context) string {
|
||||
sb.WriteString("当前持仓: 无\n\n")
|
||||
}
|
||||
|
||||
// 交易统计(如果有)
|
||||
if ctx.TradingStats != nil && ctx.TradingStats.TotalTrades > 0 {
|
||||
sb.WriteString("## 历史交易统计\n")
|
||||
sb.WriteString(fmt.Sprintf("总交易数: %d | 胜率: %.1f%% | 盈亏比: %.2f | 夏普比: %.2f\n",
|
||||
ctx.TradingStats.TotalTrades,
|
||||
ctx.TradingStats.WinRate,
|
||||
ctx.TradingStats.ProfitFactor,
|
||||
ctx.TradingStats.SharpeRatio))
|
||||
sb.WriteString(fmt.Sprintf("总盈亏: %.2f USDT | 平均盈利: %.2f | 平均亏损: %.2f | 最大回撤: %.1f%%\n\n",
|
||||
ctx.TradingStats.TotalPnL,
|
||||
ctx.TradingStats.AvgWin,
|
||||
ctx.TradingStats.AvgLoss,
|
||||
ctx.TradingStats.MaxDrawdownPct))
|
||||
}
|
||||
|
||||
// 最近完成的订单(如果有)
|
||||
if len(ctx.RecentOrders) > 0 {
|
||||
sb.WriteString("## 最近完成的交易\n")
|
||||
for i, order := range ctx.RecentOrders {
|
||||
resultStr := "盈利"
|
||||
if order.RealizedPnL < 0 {
|
||||
resultStr = "亏损"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%d. %s %s | 入场%.4f 出场%.4f | %s: %+.2f USDT (%+.2f%%) | %s\n",
|
||||
i+1, order.Symbol, order.Side,
|
||||
order.EntryPrice, order.ExitPrice,
|
||||
resultStr, order.RealizedPnL, order.PnLPct,
|
||||
order.FilledAt))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// 候选币种(完整市场数据)
|
||||
sb.WriteString(fmt.Sprintf("## 候选币种 (%d个)\n\n", len(ctx.MarketDataMap)))
|
||||
displayedCount := 0
|
||||
@@ -504,20 +543,6 @@ func buildUserPrompt(ctx *Context) string {
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
// 夏普比率(直接传值,不要复杂格式化)
|
||||
if ctx.Performance != nil {
|
||||
// 直接从interface{}中提取SharpeRatio
|
||||
type PerformanceData struct {
|
||||
SharpeRatio float64 `json:"sharpe_ratio"`
|
||||
}
|
||||
var perfData PerformanceData
|
||||
if jsonData, err := json.Marshal(ctx.Performance); err == nil {
|
||||
if err := json.Unmarshal(jsonData, &perfData); err == nil {
|
||||
sb.WriteString(fmt.Sprintf("## 📊 夏普比率: %.2f\n\n", perfData.SharpeRatio))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("---\n\n")
|
||||
sb.WriteString("现在请分析并输出决策(思维链 + JSON)\n")
|
||||
|
||||
@@ -556,20 +581,20 @@ func parseFullDecisionResponse(aiResponse string, accountEquity float64, btcEthL
|
||||
func extractCoTTrace(response string) string {
|
||||
// 方法1: 优先尝试提取 <reasoning> 标签内容
|
||||
if match := reReasoningTag.FindStringSubmatch(response); match != nil && len(match) > 1 {
|
||||
log.Printf("✓ 使用 <reasoning> 标签提取思维链")
|
||||
logger.Infof("✓ 使用 <reasoning> 标签提取思维链")
|
||||
return strings.TrimSpace(match[1])
|
||||
}
|
||||
|
||||
// 方法2: 如果没有 <reasoning> 标签,但有 <decision> 标签,提取 <decision> 之前的内容
|
||||
if decisionIdx := strings.Index(response, "<decision>"); decisionIdx > 0 {
|
||||
log.Printf("✓ 提取 <decision> 标签之前的内容作为思维链")
|
||||
logger.Infof("✓ 提取 <decision> 标签之前的内容作为思维链")
|
||||
return strings.TrimSpace(response[:decisionIdx])
|
||||
}
|
||||
|
||||
// 方法3: 后备方案 - 查找JSON数组的开始位置
|
||||
jsonStart := strings.Index(response, "[")
|
||||
if jsonStart > 0 {
|
||||
log.Printf("⚠️ 使用旧版格式([ 字符分离)提取思维链")
|
||||
logger.Infof("⚠️ 使用旧版格式([ 字符分离)提取思维链")
|
||||
return strings.TrimSpace(response[:jsonStart])
|
||||
}
|
||||
|
||||
@@ -591,11 +616,11 @@ func extractDecisions(response string) ([]Decision, error) {
|
||||
var jsonPart string
|
||||
if match := reDecisionTag.FindStringSubmatch(s); match != nil && len(match) > 1 {
|
||||
jsonPart = strings.TrimSpace(match[1])
|
||||
log.Printf("✓ 使用 <decision> 标签提取JSON")
|
||||
logger.Infof("✓ 使用 <decision> 标签提取JSON")
|
||||
} else {
|
||||
// 后备方案:使用整个响应
|
||||
jsonPart = s
|
||||
log.Printf("⚠️ 未找到 <decision> 标签,使用全文搜索JSON")
|
||||
logger.Infof("⚠️ 未找到 <decision> 标签,使用全文搜索JSON")
|
||||
}
|
||||
|
||||
// 修复 jsonPart 中的全角字符
|
||||
@@ -621,7 +646,7 @@ func extractDecisions(response string) ([]Decision, error) {
|
||||
jsonContent := strings.TrimSpace(reJSONArray.FindString(jsonPart))
|
||||
if jsonContent == "" {
|
||||
// 🔧 安全回退 (Safe Fallback):当AI只输出思维链没有JSON时,生成保底决策(避免系统崩溃)
|
||||
log.Printf("⚠️ [SafeFallback] AI未输出JSON决策,进入安全等待模式 (AI response without JSON, entering safe wait mode)")
|
||||
logger.Infof("⚠️ [SafeFallback] AI未输出JSON决策,进入安全等待模式 (AI response without JSON, entering safe wait mode)")
|
||||
|
||||
// 提取思维链摘要(最多 240 字符)
|
||||
cotSummary := jsonPart
|
||||
@@ -773,15 +798,12 @@ func findMatchingBracket(s string, start int) int {
|
||||
func validateDecision(d *Decision, accountEquity float64, btcEthLeverage, altcoinLeverage int) error {
|
||||
// 验证action
|
||||
validActions := map[string]bool{
|
||||
"open_long": true,
|
||||
"open_short": true,
|
||||
"close_long": true,
|
||||
"close_short": true,
|
||||
"update_stop_loss": true,
|
||||
"update_take_profit": true,
|
||||
"partial_close": true,
|
||||
"hold": true,
|
||||
"wait": true,
|
||||
"open_long": true,
|
||||
"open_short": true,
|
||||
"close_long": true,
|
||||
"close_short": true,
|
||||
"hold": true,
|
||||
"wait": true,
|
||||
}
|
||||
|
||||
if !validActions[d.Action] {
|
||||
@@ -803,7 +825,7 @@ func validateDecision(d *Decision, accountEquity float64, btcEthLeverage, altcoi
|
||||
return fmt.Errorf("杠杆必须大于0: %d", d.Leverage)
|
||||
}
|
||||
if d.Leverage > maxLeverage {
|
||||
log.Printf("⚠️ [Leverage Fallback] %s 杠杆超限 (%dx > %dx),自动调整为上限值 %dx",
|
||||
logger.Infof("⚠️ [Leverage Fallback] %s 杠杆超限 (%dx > %dx),自动调整为上限值 %dx",
|
||||
d.Symbol, d.Leverage, maxLeverage, maxLeverage)
|
||||
d.Leverage = maxLeverage // 自动修正为上限值
|
||||
}
|
||||
@@ -883,26 +905,5 @@ func validateDecision(d *Decision, accountEquity float64, btcEthLeverage, altcoi
|
||||
}
|
||||
}
|
||||
|
||||
// 动态调整止损验证
|
||||
if d.Action == "update_stop_loss" {
|
||||
if d.NewStopLoss <= 0 {
|
||||
return fmt.Errorf("新止损价格必须大于0: %.2f", d.NewStopLoss)
|
||||
}
|
||||
}
|
||||
|
||||
// 动态调整止盈验证
|
||||
if d.Action == "update_take_profit" {
|
||||
if d.NewTakeProfit <= 0 {
|
||||
return fmt.Errorf("新止盈价格必须大于0: %.2f", d.NewTakeProfit)
|
||||
}
|
||||
}
|
||||
|
||||
// 部分平仓验证
|
||||
if d.Action == "partial_close" {
|
||||
if d.ClosePercentage <= 0 || d.ClosePercentage > 100 {
|
||||
return fmt.Errorf("平仓百分比必须在0-100之间: %.1f", d.ClosePercentage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,9 +13,6 @@ func TestBuildSystemPrompt_ContainsAllValidActions(t *testing.T) {
|
||||
"open_short",
|
||||
"close_long",
|
||||
"close_short",
|
||||
"update_stop_loss",
|
||||
"update_take_profit",
|
||||
"partial_close",
|
||||
"hold",
|
||||
"wait",
|
||||
}
|
||||
@@ -30,21 +27,3 @@ func TestBuildSystemPrompt_ContainsAllValidActions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildSystemPrompt_ActionListCompleteness 测试 action 列表的完整性
|
||||
func TestBuildSystemPrompt_ActionListCompleteness(t *testing.T) {
|
||||
prompt := buildSystemPrompt(1000.0, 10, 5, "default", "")
|
||||
|
||||
// 检查是否包含关键的缺失 action
|
||||
missingActions := []string{
|
||||
"update_stop_loss",
|
||||
"update_take_profit",
|
||||
"partial_close",
|
||||
}
|
||||
|
||||
for _, action := range missingActions {
|
||||
if !strings.Contains(prompt, action) {
|
||||
t.Errorf("Prompt 缺少关键 action: %s(这会导致 AI 返回无效决策)", action)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,185 +99,6 @@ func TestLeverageFallback(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateStopLossValidation 测试 update_stop_loss 动作的字段验证
|
||||
func TestUpdateStopLossValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decision Decision
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "正确使用new_stop_loss字段",
|
||||
decision: Decision{
|
||||
Symbol: "SOLUSDT",
|
||||
Action: "update_stop_loss",
|
||||
NewStopLoss: 155.5,
|
||||
Reasoning: "移动止损至保本位",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "new_stop_loss为0应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "SOLUSDT",
|
||||
Action: "update_stop_loss",
|
||||
NewStopLoss: 0,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "新止损价格必须大于0",
|
||||
},
|
||||
{
|
||||
name: "new_stop_loss为负数应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "SOLUSDT",
|
||||
Action: "update_stop_loss",
|
||||
NewStopLoss: -100,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "新止损价格必须大于0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateDecision(&tt.decision, 1000.0, 10, 5)
|
||||
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantError && err != nil {
|
||||
if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTakeProfitValidation 测试 update_take_profit 动作的字段验证
|
||||
func TestUpdateTakeProfitValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decision Decision
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "正确使用new_take_profit字段",
|
||||
decision: Decision{
|
||||
Symbol: "BTCUSDT",
|
||||
Action: "update_take_profit",
|
||||
NewTakeProfit: 98000,
|
||||
Reasoning: "调整止盈至关键阻力位",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "new_take_profit为0应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "BTCUSDT",
|
||||
Action: "update_take_profit",
|
||||
NewTakeProfit: 0,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "新止盈价格必须大于0",
|
||||
},
|
||||
{
|
||||
name: "new_take_profit为负数应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "BTCUSDT",
|
||||
Action: "update_take_profit",
|
||||
NewTakeProfit: -1000,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "新止盈价格必须大于0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateDecision(&tt.decision, 1000.0, 10, 5)
|
||||
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantError && err != nil {
|
||||
if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPartialCloseValidation 测试 partial_close 动作的字段验证
|
||||
func TestPartialCloseValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
decision Decision
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "正确使用close_percentage字段",
|
||||
decision: Decision{
|
||||
Symbol: "ETHUSDT",
|
||||
Action: "partial_close",
|
||||
ClosePercentage: 50.0,
|
||||
Reasoning: "锁定一半利润",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "close_percentage为0应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "ETHUSDT",
|
||||
Action: "partial_close",
|
||||
ClosePercentage: 0,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "平仓百分比必须在0-100之间",
|
||||
},
|
||||
{
|
||||
name: "close_percentage超过100应该报错",
|
||||
decision: Decision{
|
||||
Symbol: "ETHUSDT",
|
||||
Action: "partial_close",
|
||||
ClosePercentage: 150,
|
||||
Reasoning: "测试错误情况",
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "平仓百分比必须在0-100之间",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateDecision(&tt.decision, 1000.0, 10, 5)
|
||||
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("validateDecision() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantError && err != nil {
|
||||
if tt.errorMsg != "" && !contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("错误信息不匹配: got %q, want to contain %q", err.Error(), tt.errorMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// contains 检查字符串是否包含子串(辅助函数)
|
||||
func contains(s, substr string) bool {
|
||||
|
||||
@@ -1,286 +0,0 @@
|
||||
#!/bin/bash
|
||||
# NOFX 加密系統一鍵部署腳本
|
||||
# 使用方式: chmod +x deploy_encryption.sh && ./deploy_encryption.sh
|
||||
|
||||
set -e # 遇到錯誤立即退出
|
||||
|
||||
# 顏色定義
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# 輔助函數
|
||||
log_info() {
|
||||
echo -e "${BLUE}ℹ️ $1${NC}"
|
||||
}
|
||||
|
||||
log_success() {
|
||||
echo -e "${GREEN}✅ $1${NC}"
|
||||
}
|
||||
|
||||
log_warning() {
|
||||
echo -e "${YELLOW}⚠️ $1${NC}"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}❌ $1${NC}"
|
||||
}
|
||||
|
||||
# 檢查必要工具
|
||||
check_dependencies() {
|
||||
log_info "檢查依賴工具..."
|
||||
|
||||
if ! command -v go &> /dev/null; then
|
||||
log_error "Go 未安裝,請先安裝 Go 1.21+"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v npm &> /dev/null; then
|
||||
log_error "npm 未安裝,請先安裝 Node.js 18+"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v sqlite3 &> /dev/null; then
|
||||
log_warning "sqlite3 未安裝,部分驗證功能不可用"
|
||||
fi
|
||||
|
||||
log_success "依賴檢查通過"
|
||||
}
|
||||
|
||||
# 備份數據庫
|
||||
backup_database() {
|
||||
log_info "備份現有數據庫..."
|
||||
|
||||
if [ -f "config.db" ]; then
|
||||
BACKUP_FILE="config.db.pre_encryption.$(date +%Y%m%d_%H%M%S).backup"
|
||||
cp config.db "$BACKUP_FILE"
|
||||
log_success "數據庫已備份到: $BACKUP_FILE"
|
||||
else
|
||||
log_warning "未找到 config.db,跳過備份(首次安裝)"
|
||||
fi
|
||||
}
|
||||
|
||||
# 創建密鑰目錄
|
||||
setup_secrets_dir() {
|
||||
log_info "設置密鑰目錄..."
|
||||
|
||||
if [ ! -d ".secrets" ]; then
|
||||
mkdir -p .secrets
|
||||
chmod 700 .secrets
|
||||
log_success "密鑰目錄已創建: .secrets/"
|
||||
else
|
||||
log_warning "密鑰目錄已存在,跳過創建"
|
||||
fi
|
||||
}
|
||||
|
||||
# 更新 .gitignore
|
||||
update_gitignore() {
|
||||
log_info "更新 .gitignore..."
|
||||
|
||||
if ! grep -q ".secrets/" .gitignore 2>/dev/null; then
|
||||
echo ".secrets/" >> .gitignore
|
||||
log_success "已添加 .secrets/ 到 .gitignore"
|
||||
fi
|
||||
|
||||
if ! grep -q "config.db.backup" .gitignore 2>/dev/null; then
|
||||
echo "config.db.*.backup" >> .gitignore
|
||||
log_success "已添加備份檔案規則到 .gitignore"
|
||||
fi
|
||||
}
|
||||
|
||||
# 安裝依賴
|
||||
install_dependencies() {
|
||||
log_info "安裝 Go 依賴..."
|
||||
go mod tidy
|
||||
log_success "Go 依賴已更新"
|
||||
|
||||
log_info "安裝前端依賴..."
|
||||
cd web
|
||||
if [ ! -d "node_modules" ]; then
|
||||
npm install
|
||||
fi
|
||||
npm install tweetnacl tweetnacl-util @noble/secp256k1 --save
|
||||
cd ..
|
||||
log_success "前端依賴已安裝"
|
||||
}
|
||||
|
||||
# 運行測試
|
||||
run_tests() {
|
||||
log_info "運行加密系統測試..."
|
||||
|
||||
if go test ./crypto -v > /tmp/nofx_test.log 2>&1; then
|
||||
log_success "加密系統測試通過"
|
||||
cat /tmp/nofx_test.log | grep "✅"
|
||||
else
|
||||
log_error "加密系統測試失敗,詳情:"
|
||||
cat /tmp/nofx_test.log
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 遷移數據
|
||||
migrate_data() {
|
||||
log_info "遷移現有數據到加密格式..."
|
||||
|
||||
if [ -f "config.db" ]; then
|
||||
# 檢查是否已經加密過
|
||||
if sqlite3 config.db "SELECT api_key FROM exchanges LIMIT 1;" 2>/dev/null | grep -q "=="; then
|
||||
log_warning "數據庫似乎已經加密過,跳過遷移"
|
||||
read -p "是否強制重新遷移?(y/N): " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
if go run scripts/migrate_encryption.go; then
|
||||
log_success "數據遷移完成"
|
||||
else
|
||||
log_error "數據遷移失敗"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
log_warning "未找到數據庫,跳過遷移"
|
||||
fi
|
||||
}
|
||||
|
||||
# 設置環境變數
|
||||
setup_env_vars() {
|
||||
log_info "設置環境變數..."
|
||||
|
||||
if [ -f ".secrets/master.key" ]; then
|
||||
MASTER_KEY=$(cat .secrets/master.key)
|
||||
|
||||
# 添加到當前 shell 配置
|
||||
SHELL_RC="$HOME/.bashrc"
|
||||
if [ -f "$HOME/.zshrc" ]; then
|
||||
SHELL_RC="$HOME/.zshrc"
|
||||
fi
|
||||
|
||||
if ! grep -q "NOFX_MASTER_KEY" "$SHELL_RC" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_RC"
|
||||
echo "# NOFX 加密系統主密鑰" >> "$SHELL_RC"
|
||||
echo "export NOFX_MASTER_KEY='$MASTER_KEY'" >> "$SHELL_RC"
|
||||
log_success "主密鑰已添加到 $SHELL_RC"
|
||||
else
|
||||
log_warning "主密鑰已存在於 $SHELL_RC"
|
||||
fi
|
||||
|
||||
# 導出到當前 session
|
||||
export NOFX_MASTER_KEY="$MASTER_KEY"
|
||||
log_success "主密鑰已導出到當前 session"
|
||||
else
|
||||
log_warning "主密鑰文件未生成,請先運行應用初始化"
|
||||
fi
|
||||
}
|
||||
|
||||
# 驗證部署
|
||||
verify_deployment() {
|
||||
log_info "驗證部署結果..."
|
||||
|
||||
# 1. 檢查密鑰檔案
|
||||
if [ -f ".secrets/rsa_private.pem" ] && [ -f ".secrets/rsa_public.pem" ] && [ -f ".secrets/master.key" ]; then
|
||||
log_success "密鑰檔案完整"
|
||||
else
|
||||
log_error "密鑰檔案缺失,請檢查日誌"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# 2. 檢查檔案權限
|
||||
PERM=$(stat -f "%Lp" .secrets 2>/dev/null || stat -c "%a" .secrets 2>/dev/null)
|
||||
if [ "$PERM" = "700" ]; then
|
||||
log_success "密鑰目錄權限正確 (700)"
|
||||
else
|
||||
log_warning "密鑰目錄權限為 $PERM,建議修改為 700"
|
||||
chmod 700 .secrets
|
||||
fi
|
||||
|
||||
# 3. 檢查資料庫加密
|
||||
if [ -f "config.db" ] && command -v sqlite3 &> /dev/null; then
|
||||
SAMPLE=$(sqlite3 config.db "SELECT api_key FROM exchanges WHERE api_key != '' LIMIT 1;" 2>/dev/null || echo "")
|
||||
if echo "$SAMPLE" | grep -q "=="; then
|
||||
log_success "數據庫密鑰已加密(Base64 格式)"
|
||||
else
|
||||
log_warning "數據庫可能未加密或無數據"
|
||||
fi
|
||||
fi
|
||||
|
||||
log_success "部署驗證通過"
|
||||
}
|
||||
|
||||
# 打印後續步驟
|
||||
print_next_steps() {
|
||||
echo ""
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo -e "${GREEN}🎉 加密系統部署成功!${NC}"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo ""
|
||||
echo "📝 後續步驟:"
|
||||
echo ""
|
||||
echo " 1️⃣ 啟動後端服務:"
|
||||
echo " $ go run main.go"
|
||||
echo ""
|
||||
echo " 2️⃣ 啟動前端服務:"
|
||||
echo " $ cd web && npm run dev"
|
||||
echo ""
|
||||
echo " 3️⃣ 驗證加密功能:"
|
||||
echo " $ curl http://localhost:8080/api/crypto/public-key"
|
||||
echo ""
|
||||
echo " 4️⃣ 查看審計日誌:"
|
||||
echo " $ sqlite3 config.db 'SELECT * FROM audit_logs ORDER BY timestamp DESC LIMIT 10;'"
|
||||
echo ""
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo ""
|
||||
echo "⚠️ 重要提醒:"
|
||||
echo ""
|
||||
echo " • 請妥善保管 .secrets/ 目錄(已設置為 700 權限)"
|
||||
echo " • 生產環境務必使用環境變數管理主密鑰"
|
||||
echo " • 定期執行密鑰輪換(建議每季度一次)"
|
||||
echo " • 數據庫備份已保存,驗證無誤後可手動刪除"
|
||||
echo ""
|
||||
echo "📚 詳細文檔:"
|
||||
echo " - 快速開始: cat SECURITY_QUICKSTART.md"
|
||||
echo " - 完整指南: cat ENCRYPTION_DEPLOYMENT.md"
|
||||
echo ""
|
||||
}
|
||||
|
||||
# 主函數
|
||||
main() {
|
||||
echo ""
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo -e "${BLUE}🔐 NOFX 加密系統部署腳本${NC}"
|
||||
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
|
||||
echo ""
|
||||
|
||||
# 確認執行
|
||||
log_warning "此腳本將:"
|
||||
echo " 1. 備份現有數據庫"
|
||||
echo " 2. 生成 RSA-4096 密鑰對"
|
||||
echo " 3. 生成 AES-256 主密鑰"
|
||||
echo " 4. 遷移現有數據到加密格式"
|
||||
echo " 5. 設置環境變數"
|
||||
echo ""
|
||||
read -p "是否繼續?(y/N): " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
log_info "已取消部署"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# 執行部署步驟
|
||||
check_dependencies
|
||||
backup_database
|
||||
setup_secrets_dir
|
||||
update_gitignore
|
||||
install_dependencies
|
||||
run_tests
|
||||
migrate_data
|
||||
setup_env_vars
|
||||
verify_deployment
|
||||
print_next_steps
|
||||
}
|
||||
|
||||
# 執行主函數
|
||||
main
|
||||
@@ -11,17 +11,17 @@ services:
|
||||
- "${NOFX_BACKEND_PORT:-8080}:8080"
|
||||
volumes:
|
||||
- ./config.json:/app/config.json:ro
|
||||
- ./config.db:/app/config.db
|
||||
- ./data.db:/app/data.db
|
||||
- ./beta_codes.txt:/app/beta_codes.txt:ro
|
||||
- ./decision_logs:/app/decision_logs
|
||||
- ./prompts:/app/prompts
|
||||
- ./secrets:/app/secrets:ro # RSA密钥文件
|
||||
- /etc/localtime:/etc/localtime:ro # Sync host time
|
||||
environment:
|
||||
- TZ=${NOFX_TIMEZONE:-Asia/Shanghai} # Set timezone
|
||||
- AI_MAX_TOKENS=4000 # AI响应的最大token数(默认2000,建议4000-8000)
|
||||
- DATA_ENCRYPTION_KEY=${DATA_ENCRYPTION_KEY} # 数据库加密密钥
|
||||
- JWT_SECRET=${JWT_SECRET} # JWT认证密钥
|
||||
- RSA_PRIVATE_KEY=${RSA_PRIVATE_KEY} # RSA私钥(客户端加密)
|
||||
networks:
|
||||
- nofx-network
|
||||
healthcheck:
|
||||
|
||||
@@ -1,21 +1,8 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Config 日志配置(简化版)
|
||||
type Config struct {
|
||||
Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info)
|
||||
Telegram *TelegramConfig `json:"telegram"` // Telegram推送配置(可选)
|
||||
}
|
||||
|
||||
// TelegramConfig Telegram推送配置(简化版,高级参数使用默认值)
|
||||
type TelegramConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用(默认: false)
|
||||
BotToken string `json:"bot_token"` // Bot Token
|
||||
ChatID int64 `json:"chat_id"` // Chat ID
|
||||
MinLevel string `json:"min_level"` // 最低日志级别,该级别及以上的日志会推送到Telegram(可选,默认: error)
|
||||
Level string `json:"level"` // 日志级别: debug, info, warn, error (默认: info)
|
||||
}
|
||||
|
||||
// SetDefaults 设置默认值
|
||||
@@ -24,41 +11,3 @@ func (c *Config) SetDefaults() {
|
||||
c.Level = "info"
|
||||
}
|
||||
}
|
||||
|
||||
// GetLogrusLevels 返回要推送到Telegram的日志级别
|
||||
// 根据配置的MinLevel返回该级别及以上的所有日志级别
|
||||
// 如果未配置或配置无效,默认返回error, fatal, panic(向后兼容)
|
||||
func (tc *TelegramConfig) GetLogrusLevels() []logrus.Level {
|
||||
// 如果未配置,使用默认值error(向后兼容)
|
||||
minLevelStr := tc.MinLevel
|
||||
if minLevelStr == "" {
|
||||
minLevelStr = "error"
|
||||
}
|
||||
|
||||
// 解析配置的日志级别
|
||||
minLevel, err := logrus.ParseLevel(minLevelStr)
|
||||
if err != nil {
|
||||
// 如果解析失败,使用默认值error(向后兼容)
|
||||
minLevel = logrus.ErrorLevel
|
||||
}
|
||||
|
||||
// 定义所有日志级别(从高到低:panic, fatal, error, warn, info, debug)
|
||||
allLevels := []logrus.Level{
|
||||
logrus.PanicLevel,
|
||||
logrus.FatalLevel,
|
||||
logrus.ErrorLevel,
|
||||
logrus.WarnLevel,
|
||||
logrus.InfoLevel,
|
||||
logrus.DebugLevel,
|
||||
}
|
||||
|
||||
// 返回所有大于等于minLevel的日志级别
|
||||
var result []logrus.Level
|
||||
for _, level := range allLevels {
|
||||
if level <= minLevel {
|
||||
result = append(result, level)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
{
|
||||
"traders": [
|
||||
{
|
||||
"id": "trader1",
|
||||
"name": "AI Trader 1",
|
||||
"enabled": true,
|
||||
"ai_model": "deepseek",
|
||||
"exchange": "binance",
|
||||
"binance_api_key": "your_api_key",
|
||||
"binance_secret_key": "your_secret_key",
|
||||
"deepseek_key": "your_deepseek_key",
|
||||
"initial_balance": 1000,
|
||||
"scan_interval_minutes": 3
|
||||
}
|
||||
],
|
||||
"use_default_coins": true,
|
||||
"default_coins": ["BTCUSDT", "ETHUSDT", "SOLUSDT"],
|
||||
"api_server_port": 8080,
|
||||
"leverage": {
|
||||
"btc_eth_leverage": 5,
|
||||
"altcoin_leverage": 5
|
||||
},
|
||||
"log": {
|
||||
"level": "info",
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"bot_token": "79472419:feafe231414",
|
||||
"chat_id": -100323252626,
|
||||
"min_level": "error"
|
||||
}
|
||||
},
|
||||
"_comment": "日志配置说明:level 可选值为 debug/info/warn/error,默认 info。telegram 部分作为可选配置, Telegram 推送默认为 error/fatal/panic 级别,min_level 如果设置为warn,则推送warn级别及以上的日志"
|
||||
}
|
||||
@@ -1,768 +0,0 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DecisionRecord 决策记录
|
||||
type DecisionRecord struct {
|
||||
Timestamp time.Time `json:"timestamp"` // 决策时间
|
||||
CycleNumber int `json:"cycle_number"` // 周期编号
|
||||
SystemPrompt string `json:"system_prompt"` // 系统提示词(发送给AI的系统prompt)
|
||||
InputPrompt string `json:"input_prompt"` // 发送给AI的输入prompt
|
||||
CoTTrace string `json:"cot_trace"` // AI思维链(输出)
|
||||
DecisionJSON string `json:"decision_json"` // 决策JSON
|
||||
AccountState AccountSnapshot `json:"account_state"` // 账户状态快照
|
||||
Positions []PositionSnapshot `json:"positions"` // 持仓快照
|
||||
CandidateCoins []string `json:"candidate_coins"` // 候选币种列表
|
||||
Decisions []DecisionAction `json:"decisions"` // 执行的决策
|
||||
ExecutionLog []string `json:"execution_log"` // 执行日志
|
||||
Success bool `json:"success"` // 是否成功
|
||||
ErrorMessage string `json:"error_message"` // 错误信息(如果有)
|
||||
// AIRequestDurationMs 记录 AI API 调用耗时(毫秒),方便评估调用性能
|
||||
AIRequestDurationMs int64 `json:"ai_request_duration_ms,omitempty"`
|
||||
}
|
||||
|
||||
// AccountSnapshot 账户状态快照
|
||||
type AccountSnapshot struct {
|
||||
TotalBalance float64 `json:"total_balance"`
|
||||
AvailableBalance float64 `json:"available_balance"`
|
||||
TotalUnrealizedProfit float64 `json:"total_unrealized_profit"`
|
||||
PositionCount int `json:"position_count"`
|
||||
MarginUsedPct float64 `json:"margin_used_pct"`
|
||||
InitialBalance float64 `json:"initial_balance"` // 记录当时的初始余额基准
|
||||
}
|
||||
|
||||
// PositionSnapshot 持仓快照
|
||||
type PositionSnapshot struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Side string `json:"side"`
|
||||
PositionAmt float64 `json:"position_amt"`
|
||||
EntryPrice float64 `json:"entry_price"`
|
||||
MarkPrice float64 `json:"mark_price"`
|
||||
UnrealizedProfit float64 `json:"unrealized_profit"`
|
||||
Leverage float64 `json:"leverage"`
|
||||
LiquidationPrice float64 `json:"liquidation_price"`
|
||||
}
|
||||
|
||||
// DecisionAction 决策动作
|
||||
type DecisionAction struct {
|
||||
Action string `json:"action"` // open_long, open_short, close_long, close_short, update_stop_loss, update_take_profit, partial_close
|
||||
Symbol string `json:"symbol"` // 币种
|
||||
Quantity float64 `json:"quantity"` // 数量(部分平仓时使用)
|
||||
Leverage int `json:"leverage"` // 杠杆(开仓时)
|
||||
Price float64 `json:"price"` // 执行价格
|
||||
OrderID int64 `json:"order_id"` // 订单ID
|
||||
Timestamp time.Time `json:"timestamp"` // 执行时间
|
||||
Success bool `json:"success"` // 是否成功
|
||||
Error string `json:"error"` // 错误信息
|
||||
}
|
||||
|
||||
// IDecisionLogger 决策日志记录器接口
|
||||
type IDecisionLogger interface {
|
||||
// LogDecision 记录决策
|
||||
LogDecision(record *DecisionRecord) error
|
||||
// GetLatestRecords 获取最近N条记录(按时间正序:从旧到新)
|
||||
GetLatestRecords(n int) ([]*DecisionRecord, error)
|
||||
// GetRecordByDate 获取指定日期的所有记录
|
||||
GetRecordByDate(date time.Time) ([]*DecisionRecord, error)
|
||||
// CleanOldRecords 清理N天前的旧记录
|
||||
CleanOldRecords(days int) error
|
||||
// GetStatistics 获取统计信息
|
||||
GetStatistics() (*Statistics, error)
|
||||
// AnalyzePerformance 分析最近N个周期的交易表现
|
||||
AnalyzePerformance(lookbackCycles int) (*PerformanceAnalysis, error)
|
||||
// SetCycleNumber 允许恢复内部计数(用于回测恢复)
|
||||
SetCycleNumber(n int)
|
||||
}
|
||||
|
||||
// DecisionLogger 决策日志记录器
|
||||
type DecisionLogger struct {
|
||||
logDir string
|
||||
cycleNumber int
|
||||
}
|
||||
|
||||
// NewDecisionLogger 创建决策日志记录器
|
||||
func NewDecisionLogger(logDir string) IDecisionLogger {
|
||||
if logDir == "" {
|
||||
logDir = "decision_logs"
|
||||
}
|
||||
|
||||
// 确保日志目录存在(使用安全权限:只有所有者可访问)
|
||||
if err := os.MkdirAll(logDir, 0700); err != nil {
|
||||
fmt.Printf("⚠ 创建日志目录失败: %v\n", err)
|
||||
}
|
||||
|
||||
// 强制设置目录权限(即使目录已存在)- 确保安全
|
||||
if err := os.Chmod(logDir, 0700); err != nil {
|
||||
fmt.Printf("⚠ 设置日志目录权限失败: %v\n", err)
|
||||
}
|
||||
|
||||
return &DecisionLogger{
|
||||
logDir: logDir,
|
||||
cycleNumber: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// SetCycleNumber 允许外部恢复内部的周期计数(用于回测恢复)。
|
||||
func (l *DecisionLogger) SetCycleNumber(n int) {
|
||||
if n > 0 {
|
||||
l.cycleNumber = n
|
||||
}
|
||||
}
|
||||
|
||||
// LogDecision 记录决策
|
||||
func (l *DecisionLogger) LogDecision(record *DecisionRecord) error {
|
||||
l.cycleNumber++
|
||||
record.CycleNumber = l.cycleNumber
|
||||
if record.Timestamp.IsZero() {
|
||||
record.Timestamp = time.Now().UTC()
|
||||
} else {
|
||||
record.Timestamp = record.Timestamp.UTC()
|
||||
}
|
||||
|
||||
// 生成文件名:decision_YYYYMMDD_HHMMSS_cycleN.json
|
||||
filename := fmt.Sprintf("decision_%s_cycle%d.json",
|
||||
record.Timestamp.Format("20060102_150405"),
|
||||
record.CycleNumber)
|
||||
|
||||
filepath := filepath.Join(l.logDir, filename)
|
||||
|
||||
// 序列化为JSON(带缩进,方便阅读)
|
||||
data, err := json.MarshalIndent(record, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化决策记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入文件(使用安全权限:只有所有者可读写)
|
||||
if err := ioutil.WriteFile(filepath, data, 0600); err != nil {
|
||||
return fmt.Errorf("写入决策记录失败: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("📝 决策记录已保存: %s\n", filename)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLatestRecords 获取最近N条记录(按时间正序:从旧到新)
|
||||
func (l *DecisionLogger) GetLatestRecords(n int) ([]*DecisionRecord, error) {
|
||||
files, err := ioutil.ReadDir(l.logDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取日志目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 先按修改时间倒序收集(最新的在前)
|
||||
var records []*DecisionRecord
|
||||
count := 0
|
||||
for i := len(files) - 1; i >= 0 && count < n; i-- {
|
||||
file := files[i]
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
filepath := filepath.Join(l.logDir, file.Name())
|
||||
data, err := ioutil.ReadFile(filepath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var record DecisionRecord
|
||||
if err := json.Unmarshal(data, &record); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
records = append(records, &record)
|
||||
count++
|
||||
}
|
||||
|
||||
// 反转数组,让时间从旧到新排列(用于图表显示)
|
||||
for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
|
||||
records[i], records[j] = records[j], records[i]
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// GetRecordByDate 获取指定日期的所有记录
|
||||
func (l *DecisionLogger) GetRecordByDate(date time.Time) ([]*DecisionRecord, error) {
|
||||
dateStr := date.Format("20060102")
|
||||
pattern := filepath.Join(l.logDir, fmt.Sprintf("decision_%s_*.json", dateStr))
|
||||
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找日志文件失败: %w", err)
|
||||
}
|
||||
|
||||
var records []*DecisionRecord
|
||||
for _, filepath := range files {
|
||||
data, err := ioutil.ReadFile(filepath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var record DecisionRecord
|
||||
if err := json.Unmarshal(data, &record); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
records = append(records, &record)
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// CleanOldRecords 清理N天前的旧记录
|
||||
func (l *DecisionLogger) CleanOldRecords(days int) error {
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
files, err := ioutil.ReadDir(l.logDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取日志目录失败: %w", err)
|
||||
}
|
||||
|
||||
removedCount := 0
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
if file.ModTime().Before(cutoffTime) {
|
||||
filepath := filepath.Join(l.logDir, file.Name())
|
||||
if err := os.Remove(filepath); err != nil {
|
||||
fmt.Printf("⚠ 删除旧记录失败 %s: %v\n", file.Name(), err)
|
||||
continue
|
||||
}
|
||||
removedCount++
|
||||
}
|
||||
}
|
||||
|
||||
if removedCount > 0 {
|
||||
fmt.Printf("🗑️ 已清理 %d 条旧记录(%d天前)\n", removedCount, days)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatistics 获取统计信息
|
||||
func (l *DecisionLogger) GetStatistics() (*Statistics, error) {
|
||||
files, err := ioutil.ReadDir(l.logDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取日志目录失败: %w", err)
|
||||
}
|
||||
|
||||
stats := &Statistics{}
|
||||
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
filepath := filepath.Join(l.logDir, file.Name())
|
||||
data, err := ioutil.ReadFile(filepath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var record DecisionRecord
|
||||
if err := json.Unmarshal(data, &record); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
stats.TotalCycles++
|
||||
|
||||
for _, action := range record.Decisions {
|
||||
if action.Success {
|
||||
switch action.Action {
|
||||
case "open_long", "open_short":
|
||||
stats.TotalOpenPositions++
|
||||
case "close_long", "close_short", "auto_close_long", "auto_close_short":
|
||||
stats.TotalClosePositions++
|
||||
// 🔧 BUG FIX:partial_close 不計入 TotalClosePositions,避免重複計數
|
||||
// case "partial_close": // 不計數,因為只有完全平倉才算一次
|
||||
// update_stop_loss 和 update_take_profit 不計入統計
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if record.Success {
|
||||
stats.SuccessfulCycles++
|
||||
} else {
|
||||
stats.FailedCycles++
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// Statistics 统计信息
|
||||
type Statistics struct {
|
||||
TotalCycles int `json:"total_cycles"`
|
||||
SuccessfulCycles int `json:"successful_cycles"`
|
||||
FailedCycles int `json:"failed_cycles"`
|
||||
TotalOpenPositions int `json:"total_open_positions"`
|
||||
TotalClosePositions int `json:"total_close_positions"`
|
||||
}
|
||||
|
||||
// TradeOutcome 单笔交易结果
|
||||
type TradeOutcome struct {
|
||||
Symbol string `json:"symbol"` // 币种
|
||||
Side string `json:"side"` // long/short
|
||||
Quantity float64 `json:"quantity"` // 仓位数量
|
||||
Leverage int `json:"leverage"` // 杠杆倍数
|
||||
OpenPrice float64 `json:"open_price"` // 开仓价
|
||||
ClosePrice float64 `json:"close_price"` // 平仓价
|
||||
PositionValue float64 `json:"position_value"` // 仓位价值(quantity × openPrice)
|
||||
MarginUsed float64 `json:"margin_used"` // 保证金使用(positionValue / leverage)
|
||||
PnL float64 `json:"pn_l"` // 盈亏(USDT)
|
||||
PnLPct float64 `json:"pn_l_pct"` // 盈亏百分比(相对保证金)
|
||||
Duration string `json:"duration"` // 持仓时长
|
||||
OpenTime time.Time `json:"open_time"` // 开仓时间
|
||||
CloseTime time.Time `json:"close_time"` // 平仓时间
|
||||
WasStopLoss bool `json:"was_stop_loss"` // 是否止损
|
||||
}
|
||||
|
||||
// PerformanceAnalysis 交易表现分析
|
||||
type PerformanceAnalysis struct {
|
||||
TotalTrades int `json:"total_trades"` // 总交易数
|
||||
WinningTrades int `json:"winning_trades"` // 盈利交易数
|
||||
LosingTrades int `json:"losing_trades"` // 亏损交易数
|
||||
WinRate float64 `json:"win_rate"` // 胜率
|
||||
AvgWin float64 `json:"avg_win"` // 平均盈利
|
||||
AvgLoss float64 `json:"avg_loss"` // 平均亏损
|
||||
ProfitFactor float64 `json:"profit_factor"` // 盈亏比
|
||||
SharpeRatio float64 `json:"sharpe_ratio"` // 夏普比率(风险调整后收益)
|
||||
RecentTrades []TradeOutcome `json:"recent_trades"` // 最近N笔交易
|
||||
SymbolStats map[string]*SymbolPerformance `json:"symbol_stats"` // 各币种表现
|
||||
BestSymbol string `json:"best_symbol"` // 表现最好的币种
|
||||
WorstSymbol string `json:"worst_symbol"` // 表现最差的币种
|
||||
}
|
||||
|
||||
// SymbolPerformance 币种表现统计
|
||||
type SymbolPerformance struct {
|
||||
Symbol string `json:"symbol"` // 币种
|
||||
TotalTrades int `json:"total_trades"` // 交易次数
|
||||
WinningTrades int `json:"winning_trades"` // 盈利次数
|
||||
LosingTrades int `json:"losing_trades"` // 亏损次数
|
||||
WinRate float64 `json:"win_rate"` // 胜率
|
||||
TotalPnL float64 `json:"total_pn_l"` // 总盈亏
|
||||
AvgPnL float64 `json:"avg_pn_l"` // 平均盈亏
|
||||
}
|
||||
|
||||
// AnalyzePerformance 分析最近N个周期的交易表现
|
||||
func (l *DecisionLogger) AnalyzePerformance(lookbackCycles int) (*PerformanceAnalysis, error) {
|
||||
records, err := l.GetLatestRecords(lookbackCycles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取历史记录失败: %w", err)
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
return &PerformanceAnalysis{
|
||||
RecentTrades: []TradeOutcome{},
|
||||
SymbolStats: make(map[string]*SymbolPerformance),
|
||||
}, nil
|
||||
}
|
||||
|
||||
analysis := &PerformanceAnalysis{
|
||||
RecentTrades: []TradeOutcome{},
|
||||
SymbolStats: make(map[string]*SymbolPerformance),
|
||||
}
|
||||
|
||||
// 追踪持仓状态:symbol_side -> {side, openPrice, openTime, quantity, leverage}
|
||||
openPositions := make(map[string]map[string]interface{})
|
||||
|
||||
// 为了避免开仓记录在窗口外导致匹配失败,需要先从所有历史记录中找出未平仓的持仓
|
||||
// 获取更多历史记录来构建完整的持仓状态(使用更大的窗口)
|
||||
allRecords, err := l.GetLatestRecords(lookbackCycles * 3) // 扩大3倍窗口
|
||||
if err == nil && len(allRecords) > len(records) {
|
||||
// 先从扩大的窗口中收集所有开仓记录
|
||||
for _, record := range allRecords {
|
||||
for _, action := range record.Decisions {
|
||||
if !action.Success {
|
||||
continue
|
||||
}
|
||||
|
||||
symbol := action.Symbol
|
||||
side := ""
|
||||
if action.Action == "open_long" || action.Action == "close_long" || action.Action == "partial_close" || action.Action == "auto_close_long" {
|
||||
side = "long"
|
||||
} else if action.Action == "open_short" || action.Action == "close_short" || action.Action == "auto_close_short" {
|
||||
side = "short"
|
||||
}
|
||||
|
||||
// partial_close 需要根據持倉判斷方向
|
||||
if action.Action == "partial_close" && side == "" {
|
||||
for key, pos := range openPositions {
|
||||
if posSymbol, _ := pos["side"].(string); key == symbol+"_"+posSymbol {
|
||||
side = posSymbol
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
posKey := symbol + "_" + side
|
||||
|
||||
switch action.Action {
|
||||
case "open_long", "open_short":
|
||||
// 记录开仓
|
||||
openPositions[posKey] = map[string]interface{}{
|
||||
"side": side,
|
||||
"openPrice": action.Price,
|
||||
"openTime": action.Timestamp,
|
||||
"quantity": action.Quantity,
|
||||
"leverage": action.Leverage,
|
||||
}
|
||||
case "close_long", "close_short", "auto_close_long", "auto_close_short":
|
||||
// 移除已平仓记录
|
||||
delete(openPositions, posKey)
|
||||
// partial_close 不處理,保留持倉記錄
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 遍历分析窗口内的记录,生成交易结果
|
||||
for _, record := range records {
|
||||
for _, action := range record.Decisions {
|
||||
if !action.Success {
|
||||
continue
|
||||
}
|
||||
|
||||
symbol := action.Symbol
|
||||
side := ""
|
||||
if action.Action == "open_long" || action.Action == "close_long" || action.Action == "partial_close" || action.Action == "auto_close_long" {
|
||||
side = "long"
|
||||
} else if action.Action == "open_short" || action.Action == "close_short" || action.Action == "auto_close_short" {
|
||||
side = "short"
|
||||
}
|
||||
|
||||
// partial_close 需要根據持倉判斷方向
|
||||
if action.Action == "partial_close" {
|
||||
// 從 openPositions 中查找持倉方向
|
||||
for key, pos := range openPositions {
|
||||
if posSymbol, _ := pos["side"].(string); key == symbol+"_"+posSymbol {
|
||||
side = posSymbol
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
posKey := symbol + "_" + side // 使用symbol_side作为key,区分多空持仓
|
||||
|
||||
switch action.Action {
|
||||
case "open_long", "open_short":
|
||||
// 更新开仓记录(可能已经在预填充时记录过了)
|
||||
openPositions[posKey] = map[string]interface{}{
|
||||
"side": side,
|
||||
"openPrice": action.Price,
|
||||
"openTime": action.Timestamp,
|
||||
"quantity": action.Quantity,
|
||||
"leverage": action.Leverage,
|
||||
"remainingQuantity": action.Quantity, // 🔧 BUG FIX:追蹤剩餘數量
|
||||
"accumulatedPnL": 0.0, // 🔧 BUG FIX:累積部分平倉盈虧
|
||||
"partialCloseCount": 0, // 🔧 BUG FIX:部分平倉次數
|
||||
"partialCloseVolume": 0.0, // 🔧 BUG FIX:部分平倉總量
|
||||
}
|
||||
|
||||
case "close_long", "close_short", "partial_close", "auto_close_long", "auto_close_short":
|
||||
// 查找对应的开仓记录(可能来自预填充或当前窗口)
|
||||
if openPos, exists := openPositions[posKey]; exists {
|
||||
openPrice := openPos["openPrice"].(float64)
|
||||
openTime := openPos["openTime"].(time.Time)
|
||||
side := openPos["side"].(string)
|
||||
quantity := openPos["quantity"].(float64)
|
||||
leverage := openPos["leverage"].(int)
|
||||
|
||||
// 🔧 BUG FIX:取得追蹤字段(若不存在則初始化)
|
||||
remainingQty, _ := openPos["remainingQuantity"].(float64)
|
||||
if remainingQty == 0 {
|
||||
remainingQty = quantity // 兼容舊數據(沒有 remainingQuantity 字段)
|
||||
}
|
||||
accumulatedPnL, _ := openPos["accumulatedPnL"].(float64)
|
||||
partialCloseCount, _ := openPos["partialCloseCount"].(int)
|
||||
partialCloseVolume, _ := openPos["partialCloseVolume"].(float64)
|
||||
|
||||
// 对于 partial_close,使用实际平仓数量;否则使用剩余仓位数量
|
||||
actualQuantity := remainingQty
|
||||
if action.Action == "partial_close" {
|
||||
actualQuantity = action.Quantity
|
||||
}
|
||||
|
||||
// 计算本次平仓的盈亏(USDT)
|
||||
var pnl float64
|
||||
if side == "long" {
|
||||
pnl = actualQuantity * (action.Price - openPrice)
|
||||
} else {
|
||||
pnl = actualQuantity * (openPrice - action.Price)
|
||||
}
|
||||
|
||||
// 🔧 BUG FIX:處理 partial_close 聚合邏輯
|
||||
if action.Action == "partial_close" {
|
||||
// 累積盈虧和數量
|
||||
accumulatedPnL += pnl
|
||||
remainingQty -= actualQuantity
|
||||
partialCloseCount++
|
||||
partialCloseVolume += actualQuantity
|
||||
|
||||
// 更新 openPositions(保留持倉記錄,但更新追蹤數據)
|
||||
openPos["remainingQuantity"] = remainingQty
|
||||
openPos["accumulatedPnL"] = accumulatedPnL
|
||||
openPos["partialCloseCount"] = partialCloseCount
|
||||
openPos["partialCloseVolume"] = partialCloseVolume
|
||||
|
||||
// 判斷是否已完全平倉
|
||||
if remainingQty <= 0.0001 { // 使用小閾值避免浮點誤差
|
||||
// ✅ 完全平倉:記錄為一筆完整交易
|
||||
positionValue := quantity * openPrice
|
||||
marginUsed := positionValue / float64(leverage)
|
||||
pnlPct := 0.0
|
||||
if marginUsed > 0 {
|
||||
pnlPct = (accumulatedPnL / marginUsed) * 100
|
||||
}
|
||||
|
||||
outcome := TradeOutcome{
|
||||
Symbol: symbol,
|
||||
Side: side,
|
||||
Quantity: quantity, // 使用原始總量
|
||||
Leverage: leverage,
|
||||
OpenPrice: openPrice,
|
||||
ClosePrice: action.Price, // 最後一次平倉價格
|
||||
PositionValue: positionValue,
|
||||
MarginUsed: marginUsed,
|
||||
PnL: accumulatedPnL, // 🔧 使用累積盈虧
|
||||
PnLPct: pnlPct,
|
||||
Duration: action.Timestamp.Sub(openTime).String(),
|
||||
OpenTime: openTime,
|
||||
CloseTime: action.Timestamp,
|
||||
}
|
||||
|
||||
analysis.RecentTrades = append(analysis.RecentTrades, outcome)
|
||||
analysis.TotalTrades++ // 🔧 只在完全平倉時計數
|
||||
|
||||
// 分类交易
|
||||
if accumulatedPnL > 0 {
|
||||
analysis.WinningTrades++
|
||||
analysis.AvgWin += accumulatedPnL
|
||||
} else if accumulatedPnL < 0 {
|
||||
analysis.LosingTrades++
|
||||
analysis.AvgLoss += accumulatedPnL
|
||||
}
|
||||
|
||||
// 更新币种统计
|
||||
if _, exists := analysis.SymbolStats[symbol]; !exists {
|
||||
analysis.SymbolStats[symbol] = &SymbolPerformance{
|
||||
Symbol: symbol,
|
||||
}
|
||||
}
|
||||
stats := analysis.SymbolStats[symbol]
|
||||
stats.TotalTrades++
|
||||
stats.TotalPnL += accumulatedPnL
|
||||
if accumulatedPnL > 0 {
|
||||
stats.WinningTrades++
|
||||
} else if accumulatedPnL < 0 {
|
||||
stats.LosingTrades++
|
||||
}
|
||||
|
||||
// 刪除持倉記錄
|
||||
delete(openPositions, posKey)
|
||||
}
|
||||
// ⚠️ 否則不做任何操作(等待後續 partial_close 或 full close)
|
||||
|
||||
} else {
|
||||
// 🔧 完全平倉(close_long/close_short/auto_close)
|
||||
// 如果之前有部分平倉,需要加上累積的 PnL
|
||||
totalPnL := accumulatedPnL + pnl
|
||||
|
||||
positionValue := quantity * openPrice
|
||||
marginUsed := positionValue / float64(leverage)
|
||||
pnlPct := 0.0
|
||||
if marginUsed > 0 {
|
||||
pnlPct = (totalPnL / marginUsed) * 100
|
||||
}
|
||||
|
||||
outcome := TradeOutcome{
|
||||
Symbol: symbol,
|
||||
Side: side,
|
||||
Quantity: quantity, // 使用原始總量
|
||||
Leverage: leverage,
|
||||
OpenPrice: openPrice,
|
||||
ClosePrice: action.Price,
|
||||
PositionValue: positionValue,
|
||||
MarginUsed: marginUsed,
|
||||
PnL: totalPnL, // 🔧 包含之前部分平倉的 PnL
|
||||
PnLPct: pnlPct,
|
||||
Duration: action.Timestamp.Sub(openTime).String(),
|
||||
OpenTime: openTime,
|
||||
CloseTime: action.Timestamp,
|
||||
}
|
||||
|
||||
analysis.RecentTrades = append(analysis.RecentTrades, outcome)
|
||||
analysis.TotalTrades++
|
||||
|
||||
// 分类交易
|
||||
if totalPnL > 0 {
|
||||
analysis.WinningTrades++
|
||||
analysis.AvgWin += totalPnL
|
||||
} else if totalPnL < 0 {
|
||||
analysis.LosingTrades++
|
||||
analysis.AvgLoss += totalPnL
|
||||
}
|
||||
|
||||
// 更新币种统计
|
||||
if _, exists := analysis.SymbolStats[symbol]; !exists {
|
||||
analysis.SymbolStats[symbol] = &SymbolPerformance{
|
||||
Symbol: symbol,
|
||||
}
|
||||
}
|
||||
stats := analysis.SymbolStats[symbol]
|
||||
stats.TotalTrades++
|
||||
stats.TotalPnL += totalPnL
|
||||
if totalPnL > 0 {
|
||||
stats.WinningTrades++
|
||||
} else if totalPnL < 0 {
|
||||
stats.LosingTrades++
|
||||
}
|
||||
|
||||
// 刪除持倉記錄
|
||||
delete(openPositions, posKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算统计指标
|
||||
if analysis.TotalTrades > 0 {
|
||||
analysis.WinRate = (float64(analysis.WinningTrades) / float64(analysis.TotalTrades)) * 100
|
||||
|
||||
// 计算总盈利和总亏损
|
||||
totalWinAmount := analysis.AvgWin // 当前是累加的总和
|
||||
totalLossAmount := analysis.AvgLoss // 当前是累加的总和(负数)
|
||||
|
||||
if analysis.WinningTrades > 0 {
|
||||
analysis.AvgWin /= float64(analysis.WinningTrades)
|
||||
}
|
||||
if analysis.LosingTrades > 0 {
|
||||
analysis.AvgLoss /= float64(analysis.LosingTrades)
|
||||
}
|
||||
|
||||
// Profit Factor = 总盈利 / 总亏损(绝对值)
|
||||
// 注意:totalLossAmount 是负数,所以取负号得到绝对值
|
||||
if totalLossAmount != 0 {
|
||||
analysis.ProfitFactor = totalWinAmount / (-totalLossAmount)
|
||||
} else if totalWinAmount > 0 {
|
||||
// 只有盈利没有亏损的情况,设置为一个很大的值表示完美策略
|
||||
analysis.ProfitFactor = 999.0
|
||||
}
|
||||
}
|
||||
|
||||
// 计算各币种胜率和平均盈亏
|
||||
bestPnL := -999999.0
|
||||
worstPnL := 999999.0
|
||||
for symbol, stats := range analysis.SymbolStats {
|
||||
if stats.TotalTrades > 0 {
|
||||
stats.WinRate = (float64(stats.WinningTrades) / float64(stats.TotalTrades)) * 100
|
||||
stats.AvgPnL = stats.TotalPnL / float64(stats.TotalTrades)
|
||||
|
||||
if stats.TotalPnL > bestPnL {
|
||||
bestPnL = stats.TotalPnL
|
||||
analysis.BestSymbol = symbol
|
||||
}
|
||||
if stats.TotalPnL < worstPnL {
|
||||
worstPnL = stats.TotalPnL
|
||||
analysis.WorstSymbol = symbol
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 只保留最近的交易(倒序:最新的在前)
|
||||
if len(analysis.RecentTrades) > 10 {
|
||||
// 反转数组,让最新的在前
|
||||
for i, j := 0, len(analysis.RecentTrades)-1; i < j; i, j = i+1, j-1 {
|
||||
analysis.RecentTrades[i], analysis.RecentTrades[j] = analysis.RecentTrades[j], analysis.RecentTrades[i]
|
||||
}
|
||||
analysis.RecentTrades = analysis.RecentTrades[:10]
|
||||
} else if len(analysis.RecentTrades) > 0 {
|
||||
// 反转数组
|
||||
for i, j := 0, len(analysis.RecentTrades)-1; i < j; i, j = i+1, j-1 {
|
||||
analysis.RecentTrades[i], analysis.RecentTrades[j] = analysis.RecentTrades[j], analysis.RecentTrades[i]
|
||||
}
|
||||
}
|
||||
|
||||
// 计算夏普比率(需要至少2个数据点)
|
||||
analysis.SharpeRatio = l.calculateSharpeRatio(records)
|
||||
|
||||
return analysis, nil
|
||||
}
|
||||
|
||||
// calculateSharpeRatio 计算夏普比率
|
||||
// 基于账户净值的变化计算风险调整后收益
|
||||
func (l *DecisionLogger) calculateSharpeRatio(records []*DecisionRecord) float64 {
|
||||
if len(records) < 2 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 提取每个周期的账户净值
|
||||
// 注意:TotalBalance字段实际存储的是TotalEquity(账户总净值)
|
||||
// TotalUnrealizedProfit字段实际存储的是TotalPnL(相对初始余额的盈亏)
|
||||
var equities []float64
|
||||
for _, record := range records {
|
||||
// 直接使用TotalBalance,因为它已经是完整的账户净值
|
||||
equity := record.AccountState.TotalBalance
|
||||
if equity > 0 {
|
||||
equities = append(equities, equity)
|
||||
}
|
||||
}
|
||||
|
||||
if len(equities) < 2 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 计算周期收益率(period returns)
|
||||
var returns []float64
|
||||
for i := 1; i < len(equities); i++ {
|
||||
if equities[i-1] > 0 {
|
||||
periodReturn := (equities[i] - equities[i-1]) / equities[i-1]
|
||||
returns = append(returns, periodReturn)
|
||||
}
|
||||
}
|
||||
|
||||
if len(returns) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 计算平均收益率
|
||||
sumReturns := 0.0
|
||||
for _, r := range returns {
|
||||
sumReturns += r
|
||||
}
|
||||
meanReturn := sumReturns / float64(len(returns))
|
||||
|
||||
// 计算收益率标准差
|
||||
sumSquaredDiff := 0.0
|
||||
for _, r := range returns {
|
||||
diff := r - meanReturn
|
||||
sumSquaredDiff += diff * diff
|
||||
}
|
||||
variance := sumSquaredDiff / float64(len(returns))
|
||||
stdDev := math.Sqrt(variance)
|
||||
|
||||
// 避免除以零
|
||||
if stdDev == 0 {
|
||||
if meanReturn > 0 {
|
||||
return 999.0 // 无波动的正收益
|
||||
} else if meanReturn < 0 {
|
||||
return -999.0 // 无波动的负收益
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 计算夏普比率(假设无风险利率为0)
|
||||
// 注:直接返回周期级别的夏普比率(非年化),正常范围 -2 到 +2
|
||||
sharpeRatio := meanReturn / stdDev
|
||||
return sharpeRatio
|
||||
}
|
||||
129
logger/logger.go
129
logger/logger.go
@@ -1,7 +1,6 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"nofx/config"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -10,11 +9,20 @@ import (
|
||||
var (
|
||||
// Log 全局logger实例
|
||||
Log *logrus.Logger
|
||||
|
||||
// telegramHook 保存hook引用,用于优雅关闭
|
||||
telegramHook *TelegramHook
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 自动初始化默认 logger,确保在 Init 被调用前也能使用
|
||||
Log = logrus.New()
|
||||
Log.SetLevel(logrus.InfoLevel)
|
||||
Log.SetFormatter(&logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
TimestampFormat: "2006-01-02 15:04:05",
|
||||
ForceColors: true,
|
||||
})
|
||||
Log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 初始化函数
|
||||
// ============================================================================
|
||||
@@ -52,26 +60,6 @@ func Init(cfg *Config) error {
|
||||
// 启用调用位置信息
|
||||
Log.SetReportCaller(true)
|
||||
|
||||
// 添加Telegram Hook(可选)
|
||||
if cfg.Telegram != nil && cfg.Telegram.Enabled {
|
||||
if err := setupTelegramHook(cfg.Telegram); err != nil {
|
||||
Log.Warnf("初始化Telegram推送失败,将继续使用普通日志: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupTelegramHook 设置Telegram Hook
|
||||
func setupTelegramHook(telegramCfg *TelegramConfig) error {
|
||||
hook, err := NewTelegramHook(telegramCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Log.AddHook(hook)
|
||||
telegramHook = hook
|
||||
Log.Info("✅ Telegram日志推送已启用")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -81,69 +69,9 @@ func InitWithSimpleConfig(level string) error {
|
||||
return Init(&Config{Level: level})
|
||||
}
|
||||
|
||||
// InitWithTelegram 使用Telegram配置初始化logger
|
||||
func InitWithTelegram(botToken string, chatID int64) error {
|
||||
return Init(&Config{
|
||||
Level: "info",
|
||||
Telegram: &TelegramConfig{
|
||||
Enabled: true,
|
||||
BotToken: botToken,
|
||||
ChatID: chatID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// InitFromLogConfig 从config.LogConfig初始化logger
|
||||
func InitFromLogConfig(logConfig *config.LogConfig) error {
|
||||
if logConfig == nil {
|
||||
return InitWithSimpleConfig("info")
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
Level: logConfig.Level,
|
||||
}
|
||||
|
||||
if cfg.Level == "" {
|
||||
cfg.Level = "info"
|
||||
}
|
||||
|
||||
// 如果启用了Telegram,添加配置
|
||||
if logConfig.Telegram != nil && logConfig.Telegram.Enabled {
|
||||
if botToken := logConfig.Telegram.BotToken; botToken != "" && logConfig.Telegram.ChatID != 0 {
|
||||
cfg.Telegram = &TelegramConfig{
|
||||
Enabled: true,
|
||||
BotToken: botToken,
|
||||
ChatID: logConfig.Telegram.ChatID,
|
||||
MinLevel: logConfig.Telegram.MinLevel,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Init(cfg)
|
||||
}
|
||||
|
||||
// InitFromParams 从参数初始化logger
|
||||
// 适用于不依赖config包的场景
|
||||
func InitFromParams(level string, telegramEnabled bool, botToken string, chatID int64) error {
|
||||
cfg := &Config{Level: level}
|
||||
|
||||
if telegramEnabled && botToken != "" && chatID != 0 {
|
||||
cfg.Telegram = &TelegramConfig{
|
||||
Enabled: true,
|
||||
BotToken: botToken,
|
||||
ChatID: chatID,
|
||||
}
|
||||
}
|
||||
|
||||
return Init(cfg)
|
||||
}
|
||||
|
||||
// Shutdown 优雅关闭logger(主要用于关闭Telegram发送器)
|
||||
// Shutdown 优雅关闭logger
|
||||
func Shutdown() {
|
||||
if telegramHook != nil {
|
||||
telegramHook.Stop()
|
||||
telegramHook = nil
|
||||
}
|
||||
// 预留用于未来扩展
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -208,3 +136,32 @@ func Panic(args ...interface{}) {
|
||||
func Panicf(format string, args ...interface{}) {
|
||||
Log.Panicf(format, args...)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MCP Logger 适配器
|
||||
// ============================================================================
|
||||
|
||||
// MCPLogger 适配器,使 MCP 包使用全局 logger
|
||||
// 实现 mcp.Logger 接口
|
||||
type MCPLogger struct{}
|
||||
|
||||
// NewMCPLogger 创建 MCP 日志适配器
|
||||
func NewMCPLogger() *MCPLogger {
|
||||
return &MCPLogger{}
|
||||
}
|
||||
|
||||
func (l *MCPLogger) Debugf(format string, args ...any) {
|
||||
Log.Debugf(format, args...)
|
||||
}
|
||||
|
||||
func (l *MCPLogger) Infof(format string, args ...any) {
|
||||
Log.Infof(format, args...)
|
||||
}
|
||||
|
||||
func (l *MCPLogger) Warnf(format string, args ...any) {
|
||||
Log.Warnf(format, args...)
|
||||
}
|
||||
|
||||
func (l *MCPLogger) Errorf(format string, args ...any) {
|
||||
Log.Errorf(format, args...)
|
||||
}
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TelegramHook 实现logrus.Hook接口,将日志推送到Telegram
|
||||
type TelegramHook struct {
|
||||
sender *TelegramSender
|
||||
levels []logrus.Level
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// NewTelegramHook 创建Telegram Hook
|
||||
func NewTelegramHook(config *TelegramConfig) (*TelegramHook, error) {
|
||||
if !config.Enabled {
|
||||
return &TelegramHook{enabled: false}, nil
|
||||
}
|
||||
|
||||
if config.BotToken == "" || config.ChatID == 0 {
|
||||
return nil, fmt.Errorf("telegram配置不完整: bot_token和chat_id不能为空")
|
||||
}
|
||||
|
||||
// 创建发送器(使用默认参数)
|
||||
sender, err := NewTelegramSender(config.BotToken, config.ChatID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建telegram发送器失败: %w", err)
|
||||
}
|
||||
|
||||
hook := &TelegramHook{
|
||||
sender: sender,
|
||||
levels: config.GetLogrusLevels(),
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
return hook, nil
|
||||
}
|
||||
|
||||
// Levels 返回需要触发的日志级别
|
||||
func (h *TelegramHook) Levels() []logrus.Level {
|
||||
if !h.enabled {
|
||||
return []logrus.Level{}
|
||||
}
|
||||
return h.levels
|
||||
}
|
||||
|
||||
// Fire 当日志触发时调用
|
||||
func (h *TelegramHook) Fire(entry *logrus.Entry) error {
|
||||
if !h.enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 格式化消息
|
||||
message := h.formatMessage(entry)
|
||||
|
||||
// 异步发送(非阻塞)
|
||||
h.sender.SendAsync(message)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatMessage 格式化日志消息为Telegram格式
|
||||
func (h *TelegramHook) formatMessage(entry *logrus.Entry) string {
|
||||
// 级别emoji
|
||||
levelEmoji := h.getLevelEmoji(entry.Level)
|
||||
|
||||
// 基本信息
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("%s *%s*: 系统日志警报\n", levelEmoji, strings.ToUpper(entry.Level.String())))
|
||||
builder.WriteString(fmt.Sprintf("📝 消息: `%s`\n", escapeMarkdown(entry.Message)))
|
||||
|
||||
// 字段信息
|
||||
if len(entry.Data) > 0 {
|
||||
builder.WriteString("📊 字段:\n")
|
||||
for key, value := range entry.Data {
|
||||
builder.WriteString(fmt.Sprintf(" • %s: `%v`\n", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
// 调用位置
|
||||
if entry.HasCaller() {
|
||||
file := entry.Caller.File
|
||||
// 只保留相对路径
|
||||
if idx := strings.Index(file, "nofx/"); idx >= 0 {
|
||||
file = file[idx:]
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("📍 位置: `%s:%d`\n", file, entry.Caller.Line))
|
||||
} else {
|
||||
// 如果entry没有caller,手动获取
|
||||
if _, file, line, ok := runtime.Caller(8); ok {
|
||||
if idx := strings.Index(file, "nofx/"); idx >= 0 {
|
||||
file = file[idx:]
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("📍 位置: `%s:%d`\n", file, line))
|
||||
}
|
||||
}
|
||||
|
||||
// 时间戳
|
||||
builder.WriteString(fmt.Sprintf("🕐 时间: `%s`", entry.Time.Format("2006-01-02 15:04:05")))
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// getLevelEmoji 获取日志级别对应的emoji
|
||||
func (h *TelegramHook) getLevelEmoji(level logrus.Level) string {
|
||||
switch level {
|
||||
case logrus.PanicLevel:
|
||||
return "🔴"
|
||||
case logrus.FatalLevel:
|
||||
return "🔴"
|
||||
case logrus.ErrorLevel:
|
||||
return "🟠"
|
||||
case logrus.WarnLevel:
|
||||
return "🟡"
|
||||
case logrus.InfoLevel:
|
||||
return "🟢"
|
||||
case logrus.DebugLevel:
|
||||
return "🔵"
|
||||
default:
|
||||
return "⚪"
|
||||
}
|
||||
}
|
||||
|
||||
// escapeMarkdown 转义Markdown特殊字符
|
||||
func escapeMarkdown(text string) string {
|
||||
replacer := strings.NewReplacer(
|
||||
"_", "\\_",
|
||||
"*", "\\*",
|
||||
"[", "\\[",
|
||||
"]", "\\]",
|
||||
"(", "\\(",
|
||||
")", "\\)",
|
||||
"~", "\\~",
|
||||
"`", "\\`",
|
||||
">", "\\>",
|
||||
"#", "\\#",
|
||||
"+", "\\+",
|
||||
"-", "\\-",
|
||||
"=", "\\=",
|
||||
"|", "\\|",
|
||||
"{", "\\{",
|
||||
"}", "\\}",
|
||||
".", "\\.",
|
||||
"!", "\\!",
|
||||
)
|
||||
return replacer.Replace(text)
|
||||
}
|
||||
|
||||
// Stop 停止Hook(优雅关闭)
|
||||
func (h *TelegramHook) Stop() {
|
||||
if h.enabled && h.sender != nil {
|
||||
h.sender.Stop()
|
||||
}
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
|
||||
)
|
||||
|
||||
// TelegramSender Telegram消息发送器(异步)
|
||||
type TelegramSender struct {
|
||||
bot *tgbotapi.BotAPI
|
||||
chatID int64
|
||||
msgChan chan string
|
||||
retryCount int
|
||||
retryInterval time.Duration
|
||||
wg sync.WaitGroup
|
||||
stopChan chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
// NewTelegramSender 创建Telegram发送器(使用默认参数)
|
||||
func NewTelegramSender(botToken string, chatID int64) (*TelegramSender, error) {
|
||||
bot, err := tgbotapi.NewBotAPI(botToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建telegram bot失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置为静默模式(不打印bot信息)
|
||||
bot.Debug = false
|
||||
|
||||
sender := &TelegramSender{
|
||||
bot: bot,
|
||||
chatID: chatID,
|
||||
msgChan: make(chan string, 20), // 固定缓冲区大小: 20
|
||||
retryCount: 3, // 固定重试次数: 3
|
||||
retryInterval: 3 * time.Second, // 固定重试间隔: 3秒
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 启动异步发送协程
|
||||
sender.Start()
|
||||
|
||||
return sender, nil
|
||||
}
|
||||
|
||||
// Start 启动异步发送协程
|
||||
func (s *TelegramSender) Start() {
|
||||
s.wg.Add(1)
|
||||
go s.listenAndSend()
|
||||
}
|
||||
|
||||
// SendAsync 异步发送消息(非阻塞)
|
||||
func (s *TelegramSender) SendAsync(message string) {
|
||||
select {
|
||||
case s.msgChan <- message:
|
||||
// 成功写入缓冲区
|
||||
default:
|
||||
// 缓冲区满,丢弃消息(不阻塞主流程)
|
||||
fmt.Printf("[Telegram] 消息缓冲区已满,消息被丢弃\n")
|
||||
}
|
||||
}
|
||||
|
||||
// listenAndSend 监听channel并发送消息
|
||||
func (s *TelegramSender) listenAndSend() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-s.msgChan:
|
||||
s.sendWithRetry(msg)
|
||||
case <-s.stopChan:
|
||||
// 清空缓冲区后退出
|
||||
for len(s.msgChan) > 0 {
|
||||
msg := <-s.msgChan
|
||||
s.sendWithRetry(msg)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendWithRetry 发送消息(带重试)
|
||||
func (s *TelegramSender) sendWithRetry(message string) {
|
||||
var err error
|
||||
for i := 0; i < s.retryCount; i++ {
|
||||
err = s.send(message)
|
||||
if err == nil {
|
||||
return // 发送成功
|
||||
}
|
||||
|
||||
// 重试前等待
|
||||
if i < s.retryCount-1 {
|
||||
time.Sleep(s.retryInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// 所有重试都失败
|
||||
if err != nil {
|
||||
fmt.Printf("[Telegram] 发送消息失败(已重试%d次): %v\n", s.retryCount, err)
|
||||
}
|
||||
}
|
||||
|
||||
// send 发送单条消息
|
||||
func (s *TelegramSender) send(message string) error {
|
||||
msg := tgbotapi.NewMessage(s.chatID, message)
|
||||
msg.ParseMode = tgbotapi.ModeMarkdown
|
||||
|
||||
_, err := s.bot.Send(msg)
|
||||
return err
|
||||
}
|
||||
|
||||
// Stop 停止发送器(优雅关闭)
|
||||
func (s *TelegramSender) Stop() {
|
||||
s.once.Do(func() {
|
||||
close(s.stopChan)
|
||||
s.wg.Wait()
|
||||
})
|
||||
}
|
||||
241
main.go
241
main.go
@@ -3,21 +3,24 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/api"
|
||||
"nofx/auth"
|
||||
"nofx/backtest"
|
||||
"nofx/config"
|
||||
"nofx/crypto"
|
||||
"nofx/logger"
|
||||
"nofx/manager"
|
||||
"nofx/market"
|
||||
"nofx/mcp"
|
||||
"nofx/pool"
|
||||
"nofx/store"
|
||||
"nofx/trader"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
@@ -44,7 +47,7 @@ type ConfigFile struct {
|
||||
func loadConfigFile() (*ConfigFile, error) {
|
||||
// 检查config.json是否存在
|
||||
if _, err := os.Stat("config.json"); os.IsNotExist(err) {
|
||||
log.Printf("📄 config.json不存在,使用默认配置")
|
||||
logger.Info("📄 config.json不存在,使用默认配置")
|
||||
return &ConfigFile{}, nil
|
||||
}
|
||||
|
||||
@@ -64,12 +67,12 @@ func loadConfigFile() (*ConfigFile, error) {
|
||||
}
|
||||
|
||||
// syncConfigToDatabase 将配置同步到数据库
|
||||
func syncConfigToDatabase(database *config.Database, configFile *ConfigFile) error {
|
||||
func syncConfigToDatabase(st *store.Store, configFile *ConfigFile) error {
|
||||
if configFile == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf("🔄 开始同步config.json到数据库...")
|
||||
logger.Info("🔄 开始同步config.json到数据库...")
|
||||
|
||||
// 同步各配置项到数据库
|
||||
configs := map[string]string{
|
||||
@@ -106,24 +109,24 @@ func syncConfigToDatabase(database *config.Database, configFile *ConfigFile) err
|
||||
|
||||
// 更新数据库配置
|
||||
for key, value := range configs {
|
||||
if err := database.SetSystemConfig(key, value); err != nil {
|
||||
log.Printf("⚠️ 更新配置 %s 失败: %v", key, err)
|
||||
if err := st.SystemConfig().Set(key, value); err != nil {
|
||||
logger.Warnf("⚠️ 更新配置 %s 失败: %v", key, err)
|
||||
} else {
|
||||
log.Printf("✓ 同步配置: %s = %s", key, value)
|
||||
logger.Infof("✓ 同步配置: %s = %s", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("✅ config.json同步完成")
|
||||
logger.Info("✅ config.json同步完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadBetaCodesToDatabase 加载内测码文件到数据库
|
||||
func loadBetaCodesToDatabase(database *config.Database) error {
|
||||
func loadBetaCodesToDatabase(st *store.Store) error {
|
||||
betaCodeFile := "beta_codes.txt"
|
||||
|
||||
// 检查内测码文件是否存在
|
||||
if _, err := os.Stat(betaCodeFile); os.IsNotExist(err) {
|
||||
log.Printf("📄 内测码文件 %s 不存在,跳过加载", betaCodeFile)
|
||||
logger.Infof("📄 内测码文件 %s 不存在,跳过加载", betaCodeFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -133,37 +136,39 @@ func loadBetaCodesToDatabase(database *config.Database) error {
|
||||
return fmt.Errorf("获取内测码文件信息失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("🔄 发现内测码文件 %s (%.1f KB),开始加载...", betaCodeFile, float64(fileInfo.Size())/1024)
|
||||
logger.Infof("🔄 发现内测码文件 %s (%.1f KB),开始加载...", betaCodeFile, float64(fileInfo.Size())/1024)
|
||||
|
||||
// 加载内测码到数据库
|
||||
err = database.LoadBetaCodesFromFile(betaCodeFile)
|
||||
err = st.BetaCode().LoadFromFile(betaCodeFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加载内测码失败: %w", err)
|
||||
}
|
||||
|
||||
// 显示统计信息
|
||||
total, used, err := database.GetBetaCodeStats()
|
||||
total, used, err := st.BetaCode().GetStats()
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 获取内测码统计失败: %v", err)
|
||||
logger.Warnf("⚠️ 获取内测码统计失败: %v", err)
|
||||
} else {
|
||||
log.Printf("✅ 内测码加载完成: 总计 %d 个,已使用 %d 个,剩余 %d 个", total, used, total-used)
|
||||
logger.Infof("✅ 内测码加载完成: 总计 %d 个,已使用 %d 个,剩余 %d 个", total, used, total-used)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
fmt.Println("╔════════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ 🤖 AI多模型交易系统 - 支持 DeepSeek & Qwen ║")
|
||||
fmt.Println("╚════════════════════════════════════════════════════════════╝")
|
||||
fmt.Println()
|
||||
|
||||
// Load environment variables from .env file if present (for local/dev runs)
|
||||
// In Docker Compose, variables are injected by the runtime and this is harmless.
|
||||
_ = godotenv.Load()
|
||||
|
||||
// 初始化日志
|
||||
logger.Init(nil)
|
||||
|
||||
logger.Info("╔════════════════════════════════════════════════════════════╗")
|
||||
logger.Info("║ 🤖 AI多模型交易系统 - 支持 DeepSeek & Qwen ║")
|
||||
logger.Info("╚════════════════════════════════════════════════════════════╝")
|
||||
|
||||
// 初始化数据库配置
|
||||
dbPath := "config.db"
|
||||
dbPath := "data.db"
|
||||
if len(os.Args) > 1 {
|
||||
dbPath = os.Args[1]
|
||||
}
|
||||
@@ -171,163 +176,174 @@ func main() {
|
||||
// 读取配置文件
|
||||
configFile, err := loadConfigFile()
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 读取config.json失败: %v", err)
|
||||
logger.Fatalf("❌ 读取config.json失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("📋 初始化配置数据库: %s", dbPath)
|
||||
database, err := config.NewDatabase(dbPath)
|
||||
logger.Infof("📋 初始化配置数据库: %s", dbPath)
|
||||
st, err := store.New(dbPath)
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 初始化数据库失败: %v", err)
|
||||
logger.Fatalf("❌ 初始化数据库失败: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
backtest.UseDatabase(database.Conn())
|
||||
defer st.Close()
|
||||
backtest.UseDatabase(st.DB())
|
||||
|
||||
// 初始化加密服务
|
||||
log.Printf("🔐 初始化加密服务...")
|
||||
cryptoService, err := crypto.NewCryptoService("secrets/rsa_key")
|
||||
logger.Info("🔐 初始化加密服务...")
|
||||
cryptoService, err := crypto.NewCryptoService()
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 初始化加密服务失败: %v", err)
|
||||
logger.Fatalf("❌ 初始化加密服务失败: %v", err)
|
||||
}
|
||||
database.SetCryptoService(cryptoService)
|
||||
log.Printf("✅ 加密服务初始化成功")
|
||||
// 创建加密/解密包装函数
|
||||
encryptFunc := func(plaintext string) string {
|
||||
if plaintext == "" {
|
||||
return plaintext
|
||||
}
|
||||
encrypted, err := cryptoService.EncryptForStorage(plaintext)
|
||||
if err != nil {
|
||||
logger.Warnf("⚠️ 加密失败: %v", err)
|
||||
return plaintext
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
decryptFunc := func(encrypted string) string {
|
||||
if encrypted == "" {
|
||||
return encrypted
|
||||
}
|
||||
if !cryptoService.IsEncryptedStorageValue(encrypted) {
|
||||
return encrypted
|
||||
}
|
||||
decrypted, err := cryptoService.DecryptFromStorage(encrypted)
|
||||
if err != nil {
|
||||
logger.Warnf("⚠️ 解密失败: %v", err)
|
||||
return encrypted
|
||||
}
|
||||
return decrypted
|
||||
}
|
||||
st.SetCryptoFuncs(encryptFunc, decryptFunc)
|
||||
logger.Info("✅ 加密服务初始化成功")
|
||||
|
||||
// 同步config.json到数据库
|
||||
if err := syncConfigToDatabase(database, configFile); err != nil {
|
||||
log.Printf("⚠️ 同步config.json到数据库失败: %v", err)
|
||||
if err := syncConfigToDatabase(st, configFile); err != nil {
|
||||
logger.Warnf("⚠️ 同步config.json到数据库失败: %v", err)
|
||||
}
|
||||
|
||||
// 加载内测码到数据库
|
||||
if err := loadBetaCodesToDatabase(database); err != nil {
|
||||
log.Printf("⚠️ 加载内测码到数据库失败: %v", err)
|
||||
if err := loadBetaCodesToDatabase(st); err != nil {
|
||||
logger.Warnf("⚠️ 加载内测码到数据库失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取系统配置
|
||||
useDefaultCoinsStr, _ := database.GetSystemConfig("use_default_coins")
|
||||
useDefaultCoinsStr, _ := st.SystemConfig().Get("use_default_coins")
|
||||
useDefaultCoins := useDefaultCoinsStr == "true"
|
||||
apiPortStr, _ := database.GetSystemConfig("api_server_port")
|
||||
apiPortStr, _ := st.SystemConfig().Get("api_server_port")
|
||||
|
||||
// 设置JWT密钥(优先使用环境变量)
|
||||
jwtSecret := strings.TrimSpace(os.Getenv("JWT_SECRET"))
|
||||
if jwtSecret == "" {
|
||||
// 回退到数据库配置
|
||||
jwtSecret, _ = database.GetSystemConfig("jwt_secret")
|
||||
jwtSecret, _ = st.SystemConfig().Get("jwt_secret")
|
||||
if jwtSecret == "" {
|
||||
jwtSecret = "your-jwt-secret-key-change-in-production-make-it-long-and-random"
|
||||
log.Printf("⚠️ 使用默认JWT密钥,建议使用加密设置脚本生成安全密钥")
|
||||
logger.Warn("⚠️ 使用默认JWT密钥,建议使用加密设置脚本生成安全密钥")
|
||||
} else {
|
||||
log.Printf("🔑 使用数据库中JWT密钥")
|
||||
logger.Info("🔑 使用数据库中JWT密钥")
|
||||
}
|
||||
} else {
|
||||
log.Printf("🔑 使用环境变量JWT密钥")
|
||||
logger.Info("🔑 使用环境变量JWT密钥")
|
||||
}
|
||||
auth.SetJWTSecret(jwtSecret)
|
||||
|
||||
// 管理员模式下需要管理员密码,缺失则退出
|
||||
|
||||
log.Printf("✓ 配置数据库初始化成功")
|
||||
fmt.Println()
|
||||
logger.Info("✓ 配置数据库初始化成功")
|
||||
|
||||
// 从数据库读取默认主流币种列表
|
||||
defaultCoinsJSON, _ := database.GetSystemConfig("default_coins")
|
||||
defaultCoinsJSON, _ := st.SystemConfig().Get("default_coins")
|
||||
var defaultCoins []string
|
||||
|
||||
if defaultCoinsJSON != "" {
|
||||
// 尝试从JSON解析
|
||||
if err := json.Unmarshal([]byte(defaultCoinsJSON), &defaultCoins); err != nil {
|
||||
log.Printf("⚠️ 解析default_coins配置失败: %v,使用硬编码默认值", err)
|
||||
logger.Warnf("⚠️ 解析default_coins配置失败: %v,使用硬编码默认值", err)
|
||||
defaultCoins = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT", "XRPUSDT", "DOGEUSDT", "ADAUSDT", "HYPEUSDT"}
|
||||
} else {
|
||||
log.Printf("✓ 从数据库加载默认币种列表(共%d个): %v", len(defaultCoins), defaultCoins)
|
||||
logger.Infof("✓ 从数据库加载默认币种列表(共%d个): %v", len(defaultCoins), defaultCoins)
|
||||
}
|
||||
} else {
|
||||
// 如果数据库中没有配置,使用硬编码默认值
|
||||
defaultCoins = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT", "XRPUSDT", "DOGEUSDT", "ADAUSDT", "HYPEUSDT"}
|
||||
log.Printf("⚠️ 数据库中未配置default_coins,使用硬编码默认值")
|
||||
logger.Warn("⚠️ 数据库中未配置default_coins,使用硬编码默认值")
|
||||
}
|
||||
|
||||
pool.SetDefaultCoins(defaultCoins)
|
||||
// 设置是否使用默认主流币种
|
||||
pool.SetUseDefaultCoins(useDefaultCoins)
|
||||
if useDefaultCoins {
|
||||
log.Printf("✓ 已启用默认主流币种列表")
|
||||
logger.Info("✓ 已启用默认主流币种列表")
|
||||
}
|
||||
|
||||
// 设置币种池API URL
|
||||
coinPoolAPIURL, _ := database.GetSystemConfig("coin_pool_api_url")
|
||||
coinPoolAPIURL, _ := st.SystemConfig().Get("coin_pool_api_url")
|
||||
if coinPoolAPIURL != "" {
|
||||
pool.SetCoinPoolAPI(coinPoolAPIURL)
|
||||
log.Printf("✓ 已配置AI500币种池API")
|
||||
logger.Info("✓ 已配置AI500币种池API")
|
||||
}
|
||||
|
||||
oiTopAPIURL, _ := database.GetSystemConfig("oi_top_api_url")
|
||||
oiTopAPIURL, _ := st.SystemConfig().Get("oi_top_api_url")
|
||||
if oiTopAPIURL != "" {
|
||||
pool.SetOITopAPI(oiTopAPIURL)
|
||||
log.Printf("✓ 已配置OI Top API")
|
||||
logger.Info("✓ 已配置OI Top API")
|
||||
}
|
||||
|
||||
// 创建TraderManager 与 BacktestManager
|
||||
cfgForAI, cfgErr := config.LoadConfig("config.json")
|
||||
if cfgErr != nil {
|
||||
log.Printf("⚠️ 加载config.json用于AI客户端失败: %v", cfgErr)
|
||||
logger.Warnf("⚠️ 加载config.json用于AI客户端失败: %v", cfgErr)
|
||||
}
|
||||
|
||||
traderManager := manager.NewTraderManager()
|
||||
mcpClient := newSharedMCPClient(cfgForAI)
|
||||
backtestManager := backtest.NewManager(mcpClient)
|
||||
if err := backtestManager.RestoreRuns(); err != nil {
|
||||
log.Printf("⚠️ 恢复历史回测失败: %v", err)
|
||||
logger.Warnf("⚠️ 恢复历史回测失败: %v", err)
|
||||
}
|
||||
|
||||
// 从数据库加载所有交易员到内存
|
||||
err = traderManager.LoadTradersFromDatabase(database)
|
||||
err = traderManager.LoadTradersFromStore(st)
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 加载交易员失败: %v", err)
|
||||
logger.Fatalf("❌ 加载交易员失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取数据库中的所有交易员配置(用于显示,使用default用户)
|
||||
traders, err := database.GetTraders("default")
|
||||
traders, err := st.Trader().List("default")
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 获取交易员列表失败: %v", err)
|
||||
logger.Fatalf("❌ 获取交易员列表失败: %v", err)
|
||||
}
|
||||
|
||||
// 显示加载的交易员信息
|
||||
fmt.Println()
|
||||
fmt.Println("🤖 数据库中的AI交易员配置:")
|
||||
logger.Info("🤖 数据库中的AI交易员配置:")
|
||||
if len(traders) == 0 {
|
||||
fmt.Println(" • 暂无配置的交易员,请通过Web界面创建")
|
||||
logger.Info(" • 暂无配置的交易员,请通过Web界面创建")
|
||||
} else {
|
||||
for _, trader := range traders {
|
||||
status := "停止"
|
||||
if trader.IsRunning {
|
||||
status = "运行中"
|
||||
}
|
||||
fmt.Printf(" • %s (%s + %s) - 初始资金: %.0f USDT [%s]\n",
|
||||
logger.Infof(" • %s (%s + %s) - 初始资金: %.0f USDT [%s]",
|
||||
trader.Name, strings.ToUpper(trader.AIModelID), strings.ToUpper(trader.ExchangeID),
|
||||
trader.InitialBalance, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建初始化上下文
|
||||
// TODO : 传入实际配置, 现在并未实际使用,未来所有模块初始化都将通过上下文传递配置
|
||||
// ctx := bootstrap.NewContext(&config.Config{})
|
||||
|
||||
// // 执行所有初始化钩子
|
||||
// if err := bootstrap.Run(ctx); err != nil {
|
||||
// log.Fatalf("初始化失败: %v", err)
|
||||
// }
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("🤖 AI全权决策模式:")
|
||||
fmt.Printf(" • AI将自主决定每笔交易的杠杆倍数(山寨币最高5倍,BTC/ETH最高5倍)\n")
|
||||
fmt.Println(" • AI将自主决定每笔交易的仓位大小")
|
||||
fmt.Println(" • AI将自主设置止损和止盈价格")
|
||||
fmt.Println(" • AI将基于市场数据、技术指标、账户状态做出全面分析")
|
||||
fmt.Println()
|
||||
fmt.Println("⚠️ 风险提示: AI自动交易有风险,建议小额资金测试!")
|
||||
fmt.Println()
|
||||
fmt.Println("按 Ctrl+C 停止运行")
|
||||
fmt.Println(strings.Repeat("=", 60))
|
||||
fmt.Println()
|
||||
logger.Info("🤖 AI全权决策模式:")
|
||||
logger.Info(" • AI将自主决定每笔交易的杠杆倍数(山寨币最高5倍,BTC/ETH最高5倍)")
|
||||
logger.Info(" • AI将自主决定每笔交易的仓位大小")
|
||||
logger.Info(" • AI将自主设置止损和止盈价格")
|
||||
logger.Info(" • AI将基于市场数据、技术指标、账户状态做出全面分析")
|
||||
logger.Warn("⚠️ 风险提示: AI自动交易有风险,建议小额资金测试!")
|
||||
logger.Info("按 Ctrl+C 停止运行")
|
||||
logger.Info(strings.Repeat("=", 60))
|
||||
|
||||
// 获取API服务器端口(优先级:环境变量 > 数据库配置 > 默认值)
|
||||
apiPort := 8080 // 默认端口
|
||||
@@ -336,30 +352,38 @@ func main() {
|
||||
if envPort := strings.TrimSpace(os.Getenv("NOFX_BACKEND_PORT")); envPort != "" {
|
||||
if port, err := strconv.Atoi(envPort); err == nil && port > 0 {
|
||||
apiPort = port
|
||||
log.Printf("🔌 使用环境变量端口: %d (NOFX_BACKEND_PORT)", apiPort)
|
||||
logger.Infof("🔌 使用环境变量端口: %d (NOFX_BACKEND_PORT)", apiPort)
|
||||
} else {
|
||||
log.Printf("⚠️ 环境变量 NOFX_BACKEND_PORT 无效: %s", envPort)
|
||||
logger.Warnf("⚠️ 环境变量 NOFX_BACKEND_PORT 无效: %s", envPort)
|
||||
}
|
||||
} else if apiPortStr != "" {
|
||||
// 2. 从数据库配置读取(config.json 同步过来的)
|
||||
if port, err := strconv.Atoi(apiPortStr); err == nil && port > 0 {
|
||||
apiPort = port
|
||||
log.Printf("🔌 使用数据库配置端口: %d (api_server_port)", apiPort)
|
||||
logger.Infof("🔌 使用数据库配置端口: %d (api_server_port)", apiPort)
|
||||
}
|
||||
} else {
|
||||
log.Printf("🔌 使用默认端口: %d", apiPort)
|
||||
logger.Infof("🔌 使用默认端口: %d", apiPort)
|
||||
}
|
||||
|
||||
// 启动订单同步管理器
|
||||
orderSyncManager := trader.NewOrderSyncManager(st, 10*time.Second)
|
||||
orderSyncManager.Start()
|
||||
|
||||
// 启动仓位同步管理器(检测手动平仓等变化)
|
||||
positionSyncManager := trader.NewPositionSyncManager(st, 10*time.Second)
|
||||
positionSyncManager.Start()
|
||||
|
||||
// 创建并启动API服务器
|
||||
apiServer := api.NewServer(traderManager, database, cryptoService, backtestManager, apiPort)
|
||||
apiServer := api.NewServer(traderManager, st, cryptoService, backtestManager, apiPort)
|
||||
go func() {
|
||||
if err := apiServer.Start(); err != nil {
|
||||
log.Printf("❌ API服务器错误: %v", err)
|
||||
logger.Errorf("❌ API服务器错误: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动流行情数据 - 默认使用所有交易员设置的币种 如果没有设置币种 则优先使用系统默认
|
||||
go market.NewWSMonitor(150).Start(database.GetCustomCoins())
|
||||
go market.NewWSMonitor(150).Start(st.Trader().GetCustomCoins())
|
||||
//go market.NewWSMonitor(150).Start([]string{}) //这里是一个使用方式 传入空的话 则使用market市场的所有币种
|
||||
// 设置优雅退出
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
@@ -370,33 +394,36 @@ func main() {
|
||||
|
||||
// 等待退出信号
|
||||
<-sigChan
|
||||
fmt.Println()
|
||||
fmt.Println()
|
||||
log.Println("📛 收到退出信号,正在优雅关闭...")
|
||||
logger.Info("📛 收到退出信号,正在优雅关闭...")
|
||||
|
||||
// 步骤 1: 停止所有交易员
|
||||
log.Println("⏸️ 停止所有交易员...")
|
||||
logger.Info("⏸️ 停止所有交易员...")
|
||||
traderManager.StopAll()
|
||||
log.Println("✅ 所有交易员已停止")
|
||||
logger.Info("✅ 所有交易员已停止")
|
||||
|
||||
// 步骤 2: 关闭 API 服务器
|
||||
log.Println("🛑 停止 API 服务器...")
|
||||
// 步骤 2: 停止订单同步管理器和仓位同步管理器
|
||||
logger.Info("📦 停止订单同步管理器...")
|
||||
orderSyncManager.Stop()
|
||||
logger.Info("📊 停止仓位同步管理器...")
|
||||
positionSyncManager.Stop()
|
||||
|
||||
// 步骤 3: 关闭 API 服务器
|
||||
logger.Info("🛑 停止 API 服务器...")
|
||||
if err := apiServer.Shutdown(); err != nil {
|
||||
log.Printf("⚠️ 关闭 API 服务器时出错: %v", err)
|
||||
logger.Warnf("⚠️ 关闭 API 服务器时出错: %v", err)
|
||||
} else {
|
||||
log.Println("✅ API 服务器已安全关闭")
|
||||
logger.Info("✅ API 服务器已安全关闭")
|
||||
}
|
||||
|
||||
// 步骤 3: 关闭数据库连接 (确保所有写入完成)
|
||||
log.Println("💾 关闭数据库连接...")
|
||||
if err := database.Close(); err != nil {
|
||||
log.Printf("❌ 关闭数据库失败: %v", err)
|
||||
// 步骤 4: 关闭数据库连接 (确保所有写入完成)
|
||||
logger.Info("💾 关闭数据库连接...")
|
||||
if err := st.Close(); err != nil {
|
||||
logger.Errorf("❌ 关闭数据库失败: %v", err)
|
||||
} else {
|
||||
log.Println("✅ 数据库已安全关闭,所有数据已持久化")
|
||||
logger.Info("✅ 数据库已安全关闭,所有数据已持久化")
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("👋 感谢使用AI交易系统!")
|
||||
logger.Info("👋 感谢使用AI交易系统!")
|
||||
}
|
||||
|
||||
func newSharedMCPClient(cfg *config.Config) mcp.AIClient {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -38,7 +38,7 @@ func Get(symbol string) (*Data, error) {
|
||||
|
||||
// Data staleness detection: Prevent DOGEUSDT-style price freeze issues
|
||||
if isStaleData(klines3m, symbol) {
|
||||
log.Printf("⚠️ WARNING: %s detected stale data (consecutive price freeze), skipping symbol", symbol)
|
||||
logger.Infof("⚠️ WARNING: %s detected stale data (consecutive price freeze), skipping symbol", symbol)
|
||||
return nil, fmt.Errorf("%s data is stale, possible cache failure", symbol)
|
||||
}
|
||||
|
||||
@@ -633,11 +633,11 @@ func isStaleData(klines []Kline, symbol string) bool {
|
||||
}
|
||||
|
||||
if allVolumeZero {
|
||||
log.Printf("⚠️ %s stale data confirmed: price freeze + zero volume", symbol)
|
||||
logger.Infof("⚠️ %s stale data confirmed: price freeze + zero volume", symbol)
|
||||
return true
|
||||
}
|
||||
|
||||
// Price frozen but has volume: might be extremely low volatility market, allow but log warning
|
||||
log.Printf("⚠️ %s detected extreme price stability (no fluctuation for %d consecutive periods), but volume is normal", symbol, stalePriceThreshold)
|
||||
logger.Infof("⚠️ %s detected extreme price stability (no fluctuation for %d consecutive periods), but volume is normal", symbol, stalePriceThreshold)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"nofx/logger"
|
||||
)
|
||||
|
||||
// Config 客户端配置(集中管理所有配置)
|
||||
@@ -44,8 +46,8 @@ func DefaultConfig() *Config {
|
||||
Timeout: DefaultTimeout,
|
||||
RetryableErrors: retryableErrors,
|
||||
|
||||
// 默认依赖
|
||||
Logger: &defaultLogger{},
|
||||
// 默认依赖(使用全局 logger)
|
||||
Logger: logger.NewMCPLogger(),
|
||||
HTTPClient: &http.Client{Timeout: DefaultTimeout},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package mcp
|
||||
|
||||
import "log"
|
||||
|
||||
// Logger 日志接口(抽象依赖)
|
||||
// 使用 Printf 风格的方法名,方便集成 logrus、zap 等主流日志库
|
||||
// 默认使用全局 logger 包(见 mcp/config.go)
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...any)
|
||||
Infof(format string, args ...any)
|
||||
@@ -11,25 +10,6 @@ type Logger interface {
|
||||
Errorf(format string, args ...any)
|
||||
}
|
||||
|
||||
// defaultLogger 默认日志实现(包装标准库 log)
|
||||
type defaultLogger struct{}
|
||||
|
||||
func (l *defaultLogger) Debugf(format string, args ...any) {
|
||||
log.Printf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Infof(format string, args ...any) {
|
||||
log.Printf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warnf(format string, args ...any) {
|
||||
log.Printf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Errorf(format string, args ...any) {
|
||||
log.Printf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
// noopLogger 空日志实现(测试时使用)
|
||||
type noopLogger struct{}
|
||||
|
||||
@@ -42,27 +22,3 @@ func (l *noopLogger) Errorf(format string, args ...any) {}
|
||||
func NewNoopLogger() Logger {
|
||||
return &noopLogger{}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 适配第三方日志库示例
|
||||
// ============================================================
|
||||
|
||||
// Logrus 适配示例:
|
||||
// type LogrusLogger struct {
|
||||
// logger *logrus.Logger
|
||||
// }
|
||||
//
|
||||
// func (l *LogrusLogger) Infof(format string, args ...any) {
|
||||
// l.logger.Infof(format, args...)
|
||||
// }
|
||||
//
|
||||
// Zap 适配示例:
|
||||
// type ZapLogger struct {
|
||||
// logger *zap.Logger
|
||||
// }
|
||||
//
|
||||
// func (l *ZapLogger) Infof(format string, args ...any) {
|
||||
// l.logger.Sugar().Infof(format, args...)
|
||||
// }
|
||||
//
|
||||
// 然后通过 WithLogger(logger) 注入
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 380 KiB |
@@ -203,7 +203,7 @@ spec:
|
||||
./scripts/generate_data_key.sh
|
||||
|
||||
# 2. 备份旧数据库
|
||||
cp config.db config.db.backup
|
||||
cp data.db data.db.backup
|
||||
|
||||
# 3. 重启服务 (会自动处理密钥迁移)
|
||||
source .env && ./mars
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 数据加密密钥生成脚本 - 用于Mars AI交易系统数据库加密
|
||||
# 生成用于AES-256-GCM数据库加密的随机密钥
|
||||
|
||||
set -e # 遇到错误立即退出
|
||||
|
||||
# 颜色定义
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
PURPLE='\033[0;35m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo -e "${BLUE}╔══════════════════════════════════════════════════════════════════╗${NC}"
|
||||
echo -e "${BLUE}║ Mars AI交易系统 安全密钥生成器 ║${NC}"
|
||||
echo -e "${BLUE}║ AES-256-GCM数据密钥 + JWT认证密钥 ║${NC}"
|
||||
echo -e "${BLUE}╚══════════════════════════════════════════════════════════════════╝${NC}"
|
||||
echo
|
||||
|
||||
# 检查是否安装了 OpenSSL
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo -e "${RED}❌ 错误: 系统中未安装 OpenSSL${NC}"
|
||||
echo -e "请安装 OpenSSL:"
|
||||
echo -e " macOS: ${YELLOW}brew install openssl${NC}"
|
||||
echo -e " Ubuntu/Debian: ${YELLOW}sudo apt-get install openssl${NC}"
|
||||
echo -e " CentOS/RHEL: ${YELLOW}sudo yum install openssl${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}✓ OpenSSL 已安装: $(openssl version)${NC}"
|
||||
|
||||
# 生成安全密钥
|
||||
echo -e "${BLUE}🔐 生成安全密钥...${NC}"
|
||||
echo
|
||||
|
||||
# 生成 AES-256 数据加密密钥
|
||||
echo -e "${YELLOW}1/2: 生成 AES-256 数据加密密钥...${NC}"
|
||||
DATA_KEY=$(openssl rand -base64 32)
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "${GREEN} ✓ 数据加密密钥生成成功${NC}"
|
||||
else
|
||||
echo -e "${RED} ❌ 数据加密密钥生成失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 生成 JWT 认证密钥
|
||||
echo -e "${YELLOW}2/2: 生成 JWT 认证密钥...${NC}"
|
||||
JWT_KEY=$(openssl rand -base64 64)
|
||||
if [ $? -eq 0 ]; then
|
||||
echo -e "${GREEN} ✓ JWT认证密钥生成成功${NC}"
|
||||
else
|
||||
echo -e "${RED} ❌ JWT认证密钥生成失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 显示密钥
|
||||
echo
|
||||
echo -e "${GREEN}🎉 安全密钥生成完成!${NC}"
|
||||
echo
|
||||
echo -e "${BLUE}📋 生成的密钥:${NC}"
|
||||
echo -e "${PURPLE}1. 数据加密密钥 (AES-256):${NC}"
|
||||
echo -e "${YELLOW}$DATA_KEY${NC}"
|
||||
echo
|
||||
echo -e "${PURPLE}2. JWT认证密钥 (512-bit):${NC}"
|
||||
echo -e "${YELLOW}$JWT_KEY${NC}"
|
||||
echo
|
||||
|
||||
# 显示使用方法
|
||||
echo -e "${YELLOW}📋 使用方法:${NC}"
|
||||
echo
|
||||
echo -e "${BLUE}1. 环境变量设置:${NC}"
|
||||
echo -e " export DATA_ENCRYPTION_KEY=\"$DATA_KEY\""
|
||||
echo -e " export JWT_SECRET=\"$JWT_KEY\""
|
||||
echo
|
||||
echo -e "${BLUE}2. .env 文件设置:${NC}"
|
||||
echo -e " DATA_ENCRYPTION_KEY=$DATA_KEY"
|
||||
echo -e " JWT_SECRET=$JWT_KEY"
|
||||
echo
|
||||
echo -e "${BLUE}3. Docker环境设置:${NC}"
|
||||
echo -e " docker run -e DATA_ENCRYPTION_KEY=\"$DATA_KEY\" -e JWT_SECRET=\"$JWT_KEY\" ..."
|
||||
echo
|
||||
echo -e "${BLUE}4. Kubernetes Secret:${NC}"
|
||||
echo -e " kubectl create secret generic mars-crypto-key \\"
|
||||
echo -e " --from-literal=DATA_ENCRYPTION_KEY=\"$DATA_KEY\" \\"
|
||||
echo -e " --from-literal=JWT_SECRET=\"$JWT_KEY\""
|
||||
echo
|
||||
|
||||
# 显示密钥特性
|
||||
echo -e "${BLUE}🔍 密钥特性:${NC}"
|
||||
echo -e " • 数据加密: ${YELLOW}AES-256-GCM (256 bits)${NC}"
|
||||
echo -e " • JWT认证: ${YELLOW}HS256 (512 bits)${NC}"
|
||||
echo -e " • 格式: ${YELLOW}Base64 编码${NC}"
|
||||
echo -e " • 用途: ${YELLOW}数据库加密 + 用户认证${NC}"
|
||||
|
||||
# 安全提醒
|
||||
echo
|
||||
echo -e "${RED}⚠️ 安全提醒:${NC}"
|
||||
echo -e " • 请妥善保管此密钥,丢失后无法恢复加密的数据"
|
||||
echo -e " • 不要将密钥提交到版本控制系统"
|
||||
echo -e " • 建议在不同环境使用不同的密钥"
|
||||
echo -e " • 定期更换密钥并重新加密数据"
|
||||
echo -e " • 在生产环境中,建议使用密钥管理服务"
|
||||
|
||||
echo
|
||||
echo -e "${GREEN}✅ 数据加密密钥生成完成!${NC}"
|
||||
|
||||
# 可选:保存到 .env 文件
|
||||
echo
|
||||
read -p "是否将密钥保存到 .env 文件? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
if [ -f ".env" ]; then
|
||||
# 检查是否已存在 DATA_ENCRYPTION_KEY
|
||||
if grep -q "^DATA_ENCRYPTION_KEY=" .env; then
|
||||
echo -e "${YELLOW}⚠️ .env 文件中已存在 DATA_ENCRYPTION_KEY${NC}"
|
||||
read -p "是否覆盖现有密钥? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
# 替换现有密钥
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
# macOS
|
||||
sed -i '' "s/^DATA_ENCRYPTION_KEY=.*/DATA_ENCRYPTION_KEY=$RAW_KEY/" .env
|
||||
else
|
||||
# Linux
|
||||
sed -i "s/^DATA_ENCRYPTION_KEY=.*/DATA_ENCRYPTION_KEY=$RAW_KEY/" .env
|
||||
fi
|
||||
echo -e "${GREEN}✓ .env 文件中的密钥已更新${NC}"
|
||||
else
|
||||
echo -e "${BLUE}ℹ️ 保持现有密钥不变${NC}"
|
||||
fi
|
||||
else
|
||||
# 追加新密钥
|
||||
echo "DATA_ENCRYPTION_KEY=$RAW_KEY" >> .env
|
||||
echo -e "${GREEN}✓ 密钥已保存到 .env 文件${NC}"
|
||||
fi
|
||||
else
|
||||
# 创建新的 .env 文件
|
||||
echo "DATA_ENCRYPTION_KEY=$RAW_KEY" > .env
|
||||
echo -e "${GREEN}✓ 密钥已保存到 .env 文件${NC}"
|
||||
fi
|
||||
fi
|
||||
@@ -1,149 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# RSA密钥对生成脚本 - 用于Mars AI交易系统加密服务
|
||||
# 生成用于混合加密的RSA-2048密钥对
|
||||
|
||||
set -e # 遇到错误立即退出
|
||||
|
||||
# 颜色定义
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# 配置
|
||||
RSA_KEY_SIZE=2048
|
||||
SECRETS_DIR="secrets"
|
||||
PRIVATE_KEY_FILE="$SECRETS_DIR/rsa_key"
|
||||
PUBLIC_KEY_FILE="$SECRETS_DIR/rsa_key.pub"
|
||||
|
||||
echo -e "${BLUE}╔══════════════════════════════════════════════════════════════════╗${NC}"
|
||||
echo -e "${BLUE}║ Mars AI交易系统 RSA密钥生成器 ║${NC}"
|
||||
echo -e "${BLUE}║ RSA-2048 混合加密密钥对 ║${NC}"
|
||||
echo -e "${BLUE}╚══════════════════════════════════════════════════════════════════╝${NC}"
|
||||
echo
|
||||
|
||||
# 检查是否安装了 OpenSSL
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo -e "${RED}❌ 错误: 系统中未安装 OpenSSL${NC}"
|
||||
echo -e "请安装 OpenSSL:"
|
||||
echo -e " macOS: ${YELLOW}brew install openssl${NC}"
|
||||
echo -e " Ubuntu/Debian: ${YELLOW}sudo apt-get install openssl${NC}"
|
||||
echo -e " CentOS/RHEL: ${YELLOW}sudo yum install openssl${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}✓ OpenSSL 已安装: $(openssl version)${NC}"
|
||||
|
||||
# 创建 secrets 目录
|
||||
if [ ! -d "$SECRETS_DIR" ]; then
|
||||
echo -e "${YELLOW}📁 创建 $SECRETS_DIR 目录...${NC}"
|
||||
mkdir -p "$SECRETS_DIR"
|
||||
chmod 700 "$SECRETS_DIR"
|
||||
echo -e "${GREEN}✓ 目录创建成功${NC}"
|
||||
else
|
||||
echo -e "${GREEN}✓ $SECRETS_DIR 目录已存在${NC}"
|
||||
fi
|
||||
|
||||
# 检查现有密钥
|
||||
if [ -f "$PRIVATE_KEY_FILE" ] || [ -f "$PUBLIC_KEY_FILE" ]; then
|
||||
echo
|
||||
echo -e "${YELLOW}⚠️ 检测到现有的RSA密钥文件:${NC}"
|
||||
[ -f "$PRIVATE_KEY_FILE" ] && echo -e " • $PRIVATE_KEY_FILE"
|
||||
[ -f "$PUBLIC_KEY_FILE" ] && echo -e " • $PUBLIC_KEY_FILE"
|
||||
echo
|
||||
read -p "是否覆盖现有密钥? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo -e "${BLUE}ℹ️ 操作已取消${NC}"
|
||||
exit 0
|
||||
fi
|
||||
echo -e "${YELLOW}🗑️ 删除现有密钥文件...${NC}"
|
||||
rm -f "$PRIVATE_KEY_FILE" "$PUBLIC_KEY_FILE"
|
||||
fi
|
||||
|
||||
echo
|
||||
echo -e "${BLUE}🔐 开始生成 RSA-$RSA_KEY_SIZE 密钥对...${NC}"
|
||||
|
||||
# 生成私钥
|
||||
echo -e "${YELLOW}📝 步骤 1/3: 生成 RSA 私钥 ($RSA_KEY_SIZE bits)...${NC}"
|
||||
if openssl genrsa -out "$PRIVATE_KEY_FILE" $RSA_KEY_SIZE 2>/dev/null; then
|
||||
echo -e "${GREEN}✓ 私钥生成成功${NC}"
|
||||
else
|
||||
echo -e "${RED}❌ 私钥生成失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 设置私钥权限
|
||||
chmod 600 "$PRIVATE_KEY_FILE"
|
||||
echo -e "${GREEN}✓ 私钥权限设置为 600${NC}"
|
||||
|
||||
# 生成公钥
|
||||
echo -e "${YELLOW}📝 步骤 2/3: 从私钥提取公钥...${NC}"
|
||||
if openssl rsa -in "$PRIVATE_KEY_FILE" -pubout -out "$PUBLIC_KEY_FILE" 2>/dev/null; then
|
||||
echo -e "${GREEN}✓ 公钥生成成功${NC}"
|
||||
else
|
||||
echo -e "${RED}❌ 公钥生成失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 设置公钥权限
|
||||
chmod 644 "$PUBLIC_KEY_FILE"
|
||||
echo -e "${GREEN}✓ 公钥权限设置为 644${NC}"
|
||||
|
||||
# 验证密钥
|
||||
echo -e "${YELLOW}📝 步骤 3/3: 验证密钥对...${NC}"
|
||||
if openssl rsa -in "$PRIVATE_KEY_FILE" -check -noout 2>/dev/null; then
|
||||
echo -e "${GREEN}✓ 私钥验证通过${NC}"
|
||||
else
|
||||
echo -e "${RED}❌ 私钥验证失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if openssl rsa -in "$PUBLIC_KEY_FILE" -pubin -text -noout &>/dev/null; then
|
||||
echo -e "${GREEN}✓ 公钥验证通过${NC}"
|
||||
else
|
||||
echo -e "${RED}❌ 公钥验证失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 显示密钥信息
|
||||
echo
|
||||
echo -e "${GREEN}🎉 RSA密钥对生成成功!${NC}"
|
||||
echo
|
||||
echo -e "${BLUE}📋 密钥信息:${NC}"
|
||||
echo -e " 私钥文件: ${YELLOW}$PRIVATE_KEY_FILE${NC}"
|
||||
echo -e " 公钥文件: ${YELLOW}$PUBLIC_KEY_FILE${NC}"
|
||||
echo -e " 密钥大小: ${YELLOW}$RSA_KEY_SIZE bits${NC}"
|
||||
echo
|
||||
|
||||
# 显示文件大小
|
||||
PRIVATE_SIZE=$(stat -f%z "$PRIVATE_KEY_FILE" 2>/dev/null || stat -c%s "$PRIVATE_KEY_FILE" 2>/dev/null || echo "未知")
|
||||
PUBLIC_SIZE=$(stat -f%z "$PUBLIC_KEY_FILE" 2>/dev/null || stat -c%s "$PUBLIC_KEY_FILE" 2>/dev/null || echo "未知")
|
||||
|
||||
echo -e "${BLUE}📏 文件大小:${NC}"
|
||||
echo -e " 私钥: ${YELLOW}$PRIVATE_SIZE bytes${NC}"
|
||||
echo -e " 公钥: ${YELLOW}$PUBLIC_SIZE bytes${NC}"
|
||||
|
||||
# 显示公钥内容预览
|
||||
echo
|
||||
echo -e "${BLUE}🔍 公钥内容预览:${NC}"
|
||||
head -n 5 "$PUBLIC_KEY_FILE" | sed 's/^/ /'
|
||||
echo -e " ${YELLOW}...${NC}"
|
||||
tail -n 2 "$PUBLIC_KEY_FILE" | sed 's/^/ /'
|
||||
|
||||
echo
|
||||
echo -e "${GREEN}✅ RSA密钥对生成完成!${NC}"
|
||||
echo
|
||||
echo -e "${YELLOW}📋 使用说明:${NC}"
|
||||
echo -e " 1. 私钥文件 ($PRIVATE_KEY_FILE) 用于服务器端解密"
|
||||
echo -e " 2. 公钥文件 ($PUBLIC_KEY_FILE) 可以分发给客户端用于加密"
|
||||
echo -e " 3. 确保私钥文件的安全性,不要泄露给第三方"
|
||||
echo -e " 4. 在生产环境中,建议将私钥存储在安全的密钥管理服务中"
|
||||
echo
|
||||
echo -e "${RED}⚠️ 安全提醒:${NC}"
|
||||
echo -e " • 私钥文件权限已设置为 600 (仅所有者可读写)"
|
||||
echo -e " • 请定期备份密钥文件"
|
||||
echo -e " • 建议在不同环境使用不同的密钥对"
|
||||
echo
|
||||
@@ -12,71 +12,71 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.Println("🔄 開始遷移數據庫到加密格式...")
|
||||
log.Println("🔄 开始迁移数据库到加密格式...")
|
||||
|
||||
// 1. 檢查數據庫檔案
|
||||
dbPath := "config.db"
|
||||
// 1. 检查数据库文件
|
||||
dbPath := "data.db"
|
||||
if len(os.Args) > 1 {
|
||||
dbPath = os.Args[1]
|
||||
}
|
||||
|
||||
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
|
||||
log.Fatalf("❌ 數據庫檔案不存在: %s", dbPath)
|
||||
log.Fatalf("❌ 数据库文件不存在: %s", dbPath)
|
||||
}
|
||||
|
||||
// 2. 備份數據庫
|
||||
// 2. 备份数据库
|
||||
backupPath := fmt.Sprintf("%s.pre_encryption_backup", dbPath)
|
||||
log.Printf("📦 備份數據庫到: %s", backupPath)
|
||||
log.Printf("📦 备份数据库到: %s", backupPath)
|
||||
|
||||
input, err := os.ReadFile(dbPath)
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 讀取數據庫失敗: %v", err)
|
||||
log.Fatalf("❌ 读取数据库失败: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(backupPath, input, 0600); err != nil {
|
||||
log.Fatalf("❌ 備份失敗: %v", err)
|
||||
log.Fatalf("❌ 备份失败: %v", err)
|
||||
}
|
||||
|
||||
// 3. 打開數據庫
|
||||
// 3. 打开数据库
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 打開數據庫失敗: %v", err)
|
||||
log.Fatalf("❌ 打开数据库失败: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// 4. 初始化加密管理器
|
||||
em, err := crypto.GetEncryptionManager()
|
||||
// 4. 初始化 CryptoService(从环境变量加载密钥)
|
||||
cs, err := crypto.NewCryptoService()
|
||||
if err != nil {
|
||||
log.Fatalf("❌ 初始化加密管理器失敗: %v", err)
|
||||
log.Fatalf("❌ 初始化加密服务失败: %v", err)
|
||||
}
|
||||
|
||||
// 5. 遷移交易所配置
|
||||
if err := migrateExchanges(db, em); err != nil {
|
||||
log.Fatalf("❌ 遷移交易所配置失敗: %v", err)
|
||||
// 5. 迁移交易所配置
|
||||
if err := migrateExchanges(db, cs); err != nil {
|
||||
log.Fatalf("❌ 迁移交易所配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 6. 遷移 AI 模型配置
|
||||
if err := migrateAIModels(db, em); err != nil {
|
||||
log.Fatalf("❌ 遷移 AI 模型配置失敗: %v", err)
|
||||
// 6. 迁移 AI 模型配置
|
||||
if err := migrateAIModels(db, cs); err != nil {
|
||||
log.Fatalf("❌ 迁移 AI 模型配置失败: %v", err)
|
||||
}
|
||||
|
||||
log.Println("✅ 數據遷移完成!")
|
||||
log.Printf("📝 原始數據備份位於: %s", backupPath)
|
||||
log.Println("⚠️ 請驗證系統功能正常後,手動刪除備份檔案")
|
||||
log.Println("✅ 数据迁移完成!")
|
||||
log.Printf("📝 原始数据备份位于: %s", backupPath)
|
||||
log.Println("⚠️ 请验证系统功能正常后,手动删除备份文件")
|
||||
}
|
||||
|
||||
// migrateExchanges 遷移交易所配置
|
||||
func migrateExchanges(db *sql.DB, em *crypto.EncryptionManager) error {
|
||||
log.Println("🔄 遷移交易所配置...")
|
||||
// migrateExchanges 迁移交易所配置
|
||||
func migrateExchanges(db *sql.DB, cs *crypto.CryptoService) error {
|
||||
log.Println("🔄 迁移交易所配置...")
|
||||
|
||||
// 查詢所有未加密的記錄(假設加密數據都包含 '==' Base64 特徵)
|
||||
// 查询所有未加密的记录(加密数据以 ENC:v1: 开头)
|
||||
rows, err := db.Query(`
|
||||
SELECT user_id, id, api_key, secret_key,
|
||||
COALESCE(hyperliquid_private_key, ''),
|
||||
COALESCE(aster_private_key, '')
|
||||
FROM exchanges
|
||||
WHERE (api_key != '' AND api_key NOT LIKE '%==%')
|
||||
OR (secret_key != '' AND secret_key NOT LIKE '%==%')
|
||||
WHERE (api_key != '' AND api_key NOT LIKE 'ENC:v1:%')
|
||||
OR (secret_key != '' AND secret_key NOT LIKE 'ENC:v1:%')
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -96,34 +96,34 @@ func migrateExchanges(db *sql.DB, em *crypto.EncryptionManager) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 加密每個字段
|
||||
encAPIKey, err := em.EncryptForDatabase(apiKey)
|
||||
// 加密每个字段
|
||||
encAPIKey, err := cs.EncryptForStorage(apiKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 API Key 失敗: %w", err)
|
||||
return fmt.Errorf("加密 API Key 失败: %w", err)
|
||||
}
|
||||
|
||||
encSecretKey, err := em.EncryptForDatabase(secretKey)
|
||||
encSecretKey, err := cs.EncryptForStorage(secretKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Secret Key 失敗: %w", err)
|
||||
return fmt.Errorf("加密 Secret Key 失败: %w", err)
|
||||
}
|
||||
|
||||
encHLPrivateKey := ""
|
||||
if hlPrivateKey != "" {
|
||||
encHLPrivateKey, err = em.EncryptForDatabase(hlPrivateKey)
|
||||
encHLPrivateKey, err = cs.EncryptForStorage(hlPrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Hyperliquid Private Key 失敗: %w", err)
|
||||
return fmt.Errorf("加密 Hyperliquid Private Key 失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
encAsterPrivateKey := ""
|
||||
if asterPrivateKey != "" {
|
||||
encAsterPrivateKey, err = em.EncryptForDatabase(asterPrivateKey)
|
||||
encAsterPrivateKey, err = cs.EncryptForStorage(asterPrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 Aster Private Key 失敗: %w", err)
|
||||
return fmt.Errorf("加密 Aster Private Key 失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新數據庫
|
||||
// 更新数据库
|
||||
_, err = tx.Exec(`
|
||||
UPDATE exchanges
|
||||
SET api_key = ?, secret_key = ?,
|
||||
@@ -132,7 +132,7 @@ func migrateExchanges(db *sql.DB, em *crypto.EncryptionManager) error {
|
||||
`, encAPIKey, encSecretKey, encHLPrivateKey, encAsterPrivateKey, userID, exchangeID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新數據庫失敗: %w", err)
|
||||
return fmt.Errorf("更新数据库失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 已加密: [%s] %s", userID, exchangeID)
|
||||
@@ -143,18 +143,18 @@ func migrateExchanges(db *sql.DB, em *crypto.EncryptionManager) error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("✅ 已遷移 %d 個交易所配置", count)
|
||||
log.Printf("✅ 已迁移 %d 个交易所配置", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateAIModels 遷移 AI 模型配置
|
||||
func migrateAIModels(db *sql.DB, em *crypto.EncryptionManager) error {
|
||||
log.Println("🔄 遷移 AI 模型配置...")
|
||||
// migrateAIModels 迁移 AI 模型配置
|
||||
func migrateAIModels(db *sql.DB, cs *crypto.CryptoService) error {
|
||||
log.Println("🔄 迁移 AI 模型配置...")
|
||||
|
||||
rows, err := db.Query(`
|
||||
SELECT user_id, id, api_key
|
||||
FROM ai_models
|
||||
WHERE api_key != '' AND api_key NOT LIKE '%==%'
|
||||
WHERE api_key != '' AND api_key NOT LIKE 'ENC:v1:%'
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -174,9 +174,9 @@ func migrateAIModels(db *sql.DB, em *crypto.EncryptionManager) error {
|
||||
return err
|
||||
}
|
||||
|
||||
encAPIKey, err := em.EncryptForDatabase(apiKey)
|
||||
encAPIKey, err := cs.EncryptForStorage(apiKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加密 API Key 失敗: %w", err)
|
||||
return fmt.Errorf("加密 API Key 失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
@@ -184,7 +184,7 @@ func migrateAIModels(db *sql.DB, em *crypto.EncryptionManager) error {
|
||||
`, encAPIKey, userID, modelID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新數據庫失敗: %w", err)
|
||||
return fmt.Errorf("更新数据库失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 已加密: [%s] %s", userID, modelID)
|
||||
@@ -195,6 +195,6 @@ func migrateAIModels(db *sql.DB, em *crypto.EncryptionManager) error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("✅ 已遷移 %d 個 AI 模型配置", count)
|
||||
log.Printf("✅ 已迁移 %d 个 AI 模型配置", count)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,319 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Mars AI交易系统加密环境设置脚本
|
||||
# 一键生成RSA密钥对和数据加密密钥,完整设置加密环境
|
||||
|
||||
set -e # 遇到错误立即退出
|
||||
|
||||
# 颜色定义
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
PURPLE='\033[0;35m'
|
||||
CYAN='\033[0;36m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# 获取脚本所在目录
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
echo -e "${PURPLE}╔════════════════════════════════════════════════════════════════════════╗${NC}"
|
||||
echo -e "${PURPLE}║ Mars AI交易系统 ║${NC}"
|
||||
echo -e "${PURPLE}║ 🔐 加密环境一键设置工具 ║${NC}"
|
||||
echo -e "${PURPLE}║ ║${NC}"
|
||||
echo -e "${PURPLE}║ 功能: 生成RSA密钥对 + 数据加密密钥 + 配置环境变量 ║${NC}"
|
||||
echo -e "${PURPLE}╚════════════════════════════════════════════════════════════════════════╝${NC}"
|
||||
echo
|
||||
|
||||
# 检查依赖
|
||||
echo -e "${CYAN}🔍 检查系统依赖...${NC}"
|
||||
|
||||
# 检查 OpenSSL
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
echo -e "${RED}❌ 错误: 系统中未安装 OpenSSL${NC}"
|
||||
echo -e "请安装 OpenSSL:"
|
||||
echo -e " macOS: ${YELLOW}brew install openssl${NC}"
|
||||
echo -e " Ubuntu/Debian: ${YELLOW}sudo apt-get install openssl${NC}"
|
||||
echo -e " CentOS/RHEL: ${YELLOW}sudo yum install openssl${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}✓ OpenSSL: $(openssl version)${NC}"
|
||||
|
||||
# 进入项目根目录
|
||||
cd "$PROJECT_ROOT"
|
||||
echo -e "${GREEN}✓ 工作目录: $(pwd)${NC}"
|
||||
|
||||
# 配置参数
|
||||
RSA_KEY_SIZE=2048
|
||||
SECRETS_DIR="secrets"
|
||||
PRIVATE_KEY_FILE="$SECRETS_DIR/rsa_key"
|
||||
PUBLIC_KEY_FILE="$SECRETS_DIR/rsa_key.pub"
|
||||
|
||||
echo
|
||||
echo -e "${BLUE}📋 配置参数:${NC}"
|
||||
echo -e " • RSA密钥大小: ${YELLOW}$RSA_KEY_SIZE bits${NC}"
|
||||
echo -e " • 私钥文件: ${YELLOW}$PRIVATE_KEY_FILE${NC}"
|
||||
echo -e " • 公钥文件: ${YELLOW}$PUBLIC_KEY_FILE${NC}"
|
||||
echo -e " • AES密钥: ${YELLOW}256 bits (自动生成)${NC}"
|
||||
|
||||
# 询问用户确认
|
||||
echo
|
||||
read -p "是否继续设置加密环境? [Y/n]: " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Nn]$ ]]; then
|
||||
echo -e "${BLUE}ℹ️ 操作已取消${NC}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo
|
||||
echo -e "${CYAN}🚀 开始设置加密环境...${NC}"
|
||||
|
||||
# ============= 步骤1: 创建目录 =============
|
||||
echo
|
||||
echo -e "${YELLOW}📁 步骤 1/4: 创建必要目录...${NC}"
|
||||
|
||||
if [ ! -d "$SECRETS_DIR" ]; then
|
||||
mkdir -p "$SECRETS_DIR"
|
||||
chmod 700 "$SECRETS_DIR"
|
||||
echo -e "${GREEN}✓ 创建 $SECRETS_DIR 目录${NC}"
|
||||
else
|
||||
echo -e "${GREEN}✓ $SECRETS_DIR 目录已存在${NC}"
|
||||
fi
|
||||
|
||||
if [ ! -d "scripts" ]; then
|
||||
mkdir -p "scripts"
|
||||
echo -e "${GREEN}✓ 创建 scripts 目录${NC}"
|
||||
else
|
||||
echo -e "${GREEN}✓ scripts 目录已存在${NC}"
|
||||
fi
|
||||
|
||||
# ============= 步骤2: 生成RSA密钥对 =============
|
||||
echo
|
||||
echo -e "${YELLOW}🔐 步骤 2/4: 生成 RSA-$RSA_KEY_SIZE 密钥对...${NC}"
|
||||
|
||||
# 检查现有RSA密钥
|
||||
if [ -f "$PRIVATE_KEY_FILE" ] || [ -f "$PUBLIC_KEY_FILE" ]; then
|
||||
echo -e "${YELLOW}⚠️ 检测到现有的RSA密钥文件${NC}"
|
||||
read -p "是否重新生成RSA密钥? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
rm -f "$PRIVATE_KEY_FILE" "$PUBLIC_KEY_FILE"
|
||||
echo -e "${YELLOW}🗑️ 删除旧密钥${NC}"
|
||||
else
|
||||
echo -e "${BLUE}ℹ️ 保持现有RSA密钥${NC}"
|
||||
RSA_SKIPPED=true
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$RSA_SKIPPED" != "true" ]; then
|
||||
# 生成私钥
|
||||
echo -e " ${CYAN}生成RSA私钥...${NC}"
|
||||
openssl genrsa -out "$PRIVATE_KEY_FILE" $RSA_KEY_SIZE 2>/dev/null
|
||||
chmod 600 "$PRIVATE_KEY_FILE"
|
||||
echo -e "${GREEN} ✓ 私钥生成完成${NC}"
|
||||
|
||||
# 生成公钥
|
||||
echo -e " ${CYAN}提取RSA公钥...${NC}"
|
||||
openssl rsa -in "$PRIVATE_KEY_FILE" -pubout -out "$PUBLIC_KEY_FILE" 2>/dev/null
|
||||
chmod 644 "$PUBLIC_KEY_FILE"
|
||||
echo -e "${GREEN} ✓ 公钥生成完成${NC}"
|
||||
|
||||
# 验证密钥
|
||||
echo -e " ${CYAN}验证密钥对...${NC}"
|
||||
openssl rsa -in "$PRIVATE_KEY_FILE" -check -noout 2>/dev/null
|
||||
echo -e "${GREEN} ✓ 密钥验证通过${NC}"
|
||||
fi
|
||||
|
||||
# ============= 步骤3: 生成数据加密密钥和JWT密钥 =============
|
||||
echo
|
||||
echo -e "${YELLOW}🔑 步骤 3/4: 生成 AES-256 数据加密密钥和JWT认证密钥...${NC}"
|
||||
|
||||
# 检查现有密钥
|
||||
DATA_KEY_EXISTS=false
|
||||
JWT_KEY_EXISTS=false
|
||||
|
||||
if [ -f ".env" ]; then
|
||||
if grep -q "^DATA_ENCRYPTION_KEY=" .env; then
|
||||
DATA_KEY_EXISTS=true
|
||||
fi
|
||||
if grep -q "^JWT_SECRET=" .env; then
|
||||
JWT_KEY_EXISTS=true
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$DATA_KEY_EXISTS" = "true" ] || [ "$JWT_KEY_EXISTS" = "true" ]; then
|
||||
echo -e "${YELLOW}⚠️ 检测到现有的密钥配置${NC}"
|
||||
if [ "$DATA_KEY_EXISTS" = "true" ]; then
|
||||
echo -e " • 数据加密密钥已存在"
|
||||
fi
|
||||
if [ "$JWT_KEY_EXISTS" = "true" ]; then
|
||||
echo -e " • JWT认证密钥已存在"
|
||||
fi
|
||||
read -p "是否重新生成所有密钥? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||
echo -e "${BLUE}ℹ️ 保持现有密钥${NC}"
|
||||
KEY_SKIPPED=true
|
||||
# 读取现有密钥
|
||||
if [ "$DATA_KEY_EXISTS" = "true" ]; then
|
||||
DATA_KEY=$(grep "^DATA_ENCRYPTION_KEY=" .env | cut -d'=' -f2)
|
||||
fi
|
||||
if [ "$JWT_KEY_EXISTS" = "true" ]; then
|
||||
JWT_KEY=$(grep "^JWT_SECRET=" .env | cut -d'=' -f2)
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$KEY_SKIPPED" != "true" ]; then
|
||||
# 生成新的密钥
|
||||
echo -e " ${CYAN}生成AES-256数据加密密钥...${NC}"
|
||||
DATA_KEY=$(openssl rand -base64 32)
|
||||
echo -e "${GREEN} ✓ 数据加密密钥生成完成${NC}"
|
||||
|
||||
echo -e " ${CYAN}生成JWT认证密钥...${NC}"
|
||||
JWT_KEY=$(openssl rand -base64 64)
|
||||
echo -e "${GREEN} ✓ JWT认证密钥生成完成${NC}"
|
||||
|
||||
# 保存到.env文件
|
||||
if [ -f ".env" ]; then
|
||||
# 更新现有文件
|
||||
if grep -q "^DATA_ENCRYPTION_KEY=" .env; then
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
sed -i '' "s/^DATA_ENCRYPTION_KEY=.*/DATA_ENCRYPTION_KEY=$DATA_KEY/" .env
|
||||
else
|
||||
sed -i "s/^DATA_ENCRYPTION_KEY=.*/DATA_ENCRYPTION_KEY=$DATA_KEY/" .env
|
||||
fi
|
||||
else
|
||||
echo "DATA_ENCRYPTION_KEY=$DATA_KEY" >> .env
|
||||
fi
|
||||
|
||||
if grep -q "^JWT_SECRET=" .env; then
|
||||
# 使用替代分隔符避免 / 字符冲突,并用引号保护值
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
sed -i '' "s|^JWT_SECRET=.*|JWT_SECRET=\"$JWT_KEY\"|" .env
|
||||
else
|
||||
sed -i "s|^JWT_SECRET=.*|JWT_SECRET=\"$JWT_KEY\"|" .env
|
||||
fi
|
||||
else
|
||||
# 使用引号确保值在同一行
|
||||
printf "JWT_SECRET=\"%s\"\n" "$JWT_KEY" >> .env
|
||||
fi
|
||||
else
|
||||
# 创建新文件
|
||||
echo "DATA_ENCRYPTION_KEY=$DATA_KEY" > .env
|
||||
printf "JWT_SECRET=\"%s\"\n" "$JWT_KEY" >> .env
|
||||
fi
|
||||
chmod 600 .env
|
||||
echo -e "${GREEN} ✓ 密钥已保存到 .env 文件${NC}"
|
||||
elif [ "$DATA_KEY_EXISTS" != "true" ] || [ "$JWT_KEY_EXISTS" != "true" ]; then
|
||||
# 生成缺失的密钥
|
||||
if [ "$DATA_KEY_EXISTS" != "true" ]; then
|
||||
echo -e " ${CYAN}生成缺失的AES-256数据加密密钥...${NC}"
|
||||
DATA_KEY=$(openssl rand -base64 32)
|
||||
echo "DATA_ENCRYPTION_KEY=$DATA_KEY" >> .env
|
||||
echo -e "${GREEN} ✓ 数据加密密钥生成完成${NC}"
|
||||
fi
|
||||
|
||||
if [ "$JWT_KEY_EXISTS" != "true" ]; then
|
||||
echo -e " ${CYAN}生成缺失的JWT认证密钥...${NC}"
|
||||
JWT_KEY=$(openssl rand -base64 64)
|
||||
printf "JWT_SECRET=\"%s\"\n" "$JWT_KEY" >> .env
|
||||
echo -e "${GREEN} ✓ JWT认证密钥生成完成${NC}"
|
||||
fi
|
||||
|
||||
chmod 600 .env
|
||||
echo -e "${GREEN} ✓ 密钥已保存到 .env 文件${NC}"
|
||||
fi
|
||||
|
||||
# ============= 步骤4: 验证和总结 =============
|
||||
echo
|
||||
echo -e "${YELLOW}✅ 步骤 4/4: 环境验证和总结...${NC}"
|
||||
|
||||
# 验证文件存在性和权限
|
||||
echo -e " ${CYAN}验证文件和权限...${NC}"
|
||||
|
||||
if [ -f "$PRIVATE_KEY_FILE" ]; then
|
||||
PRIVATE_PERM=$(stat -f "%A" "$PRIVATE_KEY_FILE" 2>/dev/null || stat -c "%a" "$PRIVATE_KEY_FILE" 2>/dev/null)
|
||||
echo -e "${GREEN} ✓ 私钥文件: $PRIVATE_KEY_FILE (权限: $PRIVATE_PERM)${NC}"
|
||||
else
|
||||
echo -e "${RED} ❌ 私钥文件不存在${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -f "$PUBLIC_KEY_FILE" ]; then
|
||||
PUBLIC_PERM=$(stat -f "%A" "$PUBLIC_KEY_FILE" 2>/dev/null || stat -c "%a" "$PUBLIC_KEY_FILE" 2>/dev/null)
|
||||
echo -e "${GREEN} ✓ 公钥文件: $PUBLIC_KEY_FILE (权限: $PUBLIC_PERM)${NC}"
|
||||
else
|
||||
echo -e "${RED} ❌ 公钥文件不存在${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -f ".env" ] && grep -q "^DATA_ENCRYPTION_KEY=" .env && grep -q "^JWT_SECRET=" .env; then
|
||||
ENV_PERM=$(stat -f "%A" ".env" 2>/dev/null || stat -c "%a" ".env" 2>/dev/null)
|
||||
echo -e "${GREEN} ✓ 环境文件: .env (权限: $ENV_PERM)${NC}"
|
||||
echo -e "${GREEN} 包含: DATA_ENCRYPTION_KEY, JWT_SECRET${NC}"
|
||||
else
|
||||
echo -e "${RED} ❌ 环境文件不存在或缺少必要密钥${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 测试密钥功能
|
||||
echo -e " ${CYAN}测试密钥功能...${NC}"
|
||||
TEST_DATA="Hello Mars AI Trading System"
|
||||
ENCRYPTED=$(echo "$TEST_DATA" | openssl rsautl -encrypt -pubin -inkey "$PUBLIC_KEY_FILE" | base64)
|
||||
DECRYPTED=$(echo "$ENCRYPTED" | base64 -d | openssl rsautl -decrypt -inkey "$PRIVATE_KEY_FILE")
|
||||
|
||||
if [ "$DECRYPTED" = "$TEST_DATA" ]; then
|
||||
echo -e "${GREEN} ✓ RSA加密/解密测试通过${NC}"
|
||||
else
|
||||
echo -e "${RED} ❌ RSA加密/解密测试失败${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 显示最终结果
|
||||
echo
|
||||
echo -e "${GREEN}🎉 加密环境设置完成!${NC}"
|
||||
echo
|
||||
echo -e "${PURPLE}╔════════════════════════════════════════════════════════════════════════╗${NC}"
|
||||
echo -e "${PURPLE}║ 设置完成摘要 ║${NC}"
|
||||
echo -e "${PURPLE}╠════════════════════════════════════════════════════════════════════════╣${NC}"
|
||||
echo -e "${PURPLE}║${NC} ${BLUE}RSA密钥对:${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}║${NC} 私钥: ${YELLOW}$PRIVATE_KEY_FILE${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}║${NC} 公钥: ${YELLOW}$PUBLIC_KEY_FILE${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}║${NC} 大小: ${YELLOW}$RSA_KEY_SIZE bits${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}║${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}║${NC} ${BLUE}安全密钥配置:${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}║${NC} 文件: ${YELLOW}.env${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}║${NC} 数据加密: ${YELLOW}DATA_ENCRYPTION_KEY (AES-256-GCM)${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}║${NC} JWT认证: ${YELLOW}JWT_SECRET (HS256)${NC} ${PURPLE}║${NC}"
|
||||
echo -e "${PURPLE}╚════════════════════════════════════════════════════════════════════════╝${NC}"
|
||||
|
||||
# 使用指南
|
||||
echo
|
||||
echo -e "${BLUE}📋 使用指南:${NC}"
|
||||
echo
|
||||
echo -e "${YELLOW}1. 启动Mars AI交易系统:${NC}"
|
||||
echo -e " source .env && ./mars"
|
||||
echo
|
||||
echo -e "${YELLOW}2. Docker部署:${NC}"
|
||||
echo -e " docker run --env-file .env mars-ai-trading"
|
||||
echo
|
||||
echo -e "${YELLOW}3. 查看公钥内容:${NC}"
|
||||
echo -e " cat $PUBLIC_KEY_FILE"
|
||||
echo
|
||||
echo -e "${YELLOW}4. 测试加密API:${NC}"
|
||||
echo -e " curl http://localhost:8080/api/crypto/public-key"
|
||||
|
||||
# 安全提醒
|
||||
echo
|
||||
echo -e "${RED}🔒 安全提醒:${NC}"
|
||||
echo -e " • 私钥文件 ($PRIVATE_KEY_FILE) 权限已设置为 600"
|
||||
echo -e " • 环境文件 (.env) 权限已设置为 600"
|
||||
echo -e " • 请勿将私钥和数据密钥提交到版本控制系统"
|
||||
echo -e " • 建议在生产环境中使用密钥管理服务"
|
||||
echo -e " • 定期备份密钥文件"
|
||||
|
||||
echo
|
||||
echo -e "${GREEN}✅ Mars AI交易系统加密环境设置完成!${NC}"
|
||||
302
start.sh
302
start.sh
@@ -14,6 +14,7 @@ RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
CYAN='\033[0;36m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
@@ -70,95 +71,109 @@ check_env() {
|
||||
if [ ! -f ".env" ]; then
|
||||
print_warning ".env 不存在,从模板复制..."
|
||||
cp .env.example .env
|
||||
print_info "✓ 已使用默认环境变量创建 .env"
|
||||
print_info "💡 如需修改端口等设置,可编辑 .env 文件"
|
||||
print_info "已创建 .env 文件"
|
||||
fi
|
||||
print_success "环境变量文件存在"
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Validation: Encryption Environment (RSA Keys + Data Encryption Key)
|
||||
# Helper: Check if env var is set and not placeholder
|
||||
# ------------------------------------------------------------------------
|
||||
check_encryption() {
|
||||
local need_setup=false
|
||||
|
||||
print_info "检查加密环境..."
|
||||
|
||||
# 检查RSA密钥对
|
||||
if [ ! -f "secrets/rsa_key" ] || [ ! -f "secrets/rsa_key.pub" ]; then
|
||||
print_warning "RSA密钥对不存在"
|
||||
need_setup=true
|
||||
fi
|
||||
|
||||
# 检查数据加密密钥
|
||||
if [ ! -f ".env" ] || ! grep -q "^DATA_ENCRYPTION_KEY=" .env; then
|
||||
print_warning "数据加密密钥未配置"
|
||||
need_setup=true
|
||||
fi
|
||||
|
||||
# 检查JWT认证密钥
|
||||
if [ ! -f ".env" ] || ! grep -q "^JWT_SECRET=" .env; then
|
||||
print_warning "JWT认证密钥未配置"
|
||||
need_setup=true
|
||||
fi
|
||||
|
||||
# 如果需要设置加密环境,直接自动设置
|
||||
if [ "$need_setup" = "true" ]; then
|
||||
print_info "🔐 检测到加密环境未配置,正在自动设置..."
|
||||
print_info "加密环境用于保护敏感数据(API密钥、私钥等)"
|
||||
echo ""
|
||||
is_env_configured() {
|
||||
local var_name="$1"
|
||||
local value=$(grep "^${var_name}=" .env 2>/dev/null | cut -d'=' -f2-)
|
||||
|
||||
# 检查加密设置脚本是否存在
|
||||
if [ -f "scripts/setup_encryption.sh" ]; then
|
||||
print_info "加密系统将保护: API密钥、私钥、Hyperliquid代理钱包"
|
||||
echo ""
|
||||
# 去除引号
|
||||
value=$(echo "$value" | tr -d '"'"'")
|
||||
|
||||
# 自动运行加密设置脚本
|
||||
echo -e "Y\nn\nn" | bash scripts/setup_encryption.sh
|
||||
if [ $? -eq 0 ]; then
|
||||
echo ""
|
||||
print_success "🔐 加密环境设置完成!"
|
||||
print_info " • RSA-2048密钥对已生成"
|
||||
print_info " • AES-256数据加密密钥已配置"
|
||||
print_info " • JWT认证密钥已配置"
|
||||
print_info " • 所有敏感数据现在都受加密保护"
|
||||
echo ""
|
||||
else
|
||||
print_error "加密环境设置失败"
|
||||
exit 1
|
||||
fi
|
||||
# 检查是否为空或占位符
|
||||
if [ -z "$value" ]; then
|
||||
return 1
|
||||
fi
|
||||
|
||||
# 检查是否是示例值
|
||||
case "$value" in
|
||||
*your-*|*YOUR_*|*change-this*|*CHANGE_THIS*|*example*|*EXAMPLE*)
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Helper: Generate and set env var in .env file
|
||||
# ------------------------------------------------------------------------
|
||||
set_env_var() {
|
||||
local var_name="$1"
|
||||
local var_value="$2"
|
||||
|
||||
# 如果变量已存在(即使是占位符),替换它
|
||||
if grep -q "^${var_name}=" .env 2>/dev/null; then
|
||||
# macOS 和 Linux 兼容的 sed
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
sed -i '' "s|^${var_name}=.*|${var_name}=${var_value}|" .env
|
||||
else
|
||||
print_error "加密设置脚本不存在: scripts/setup_encryption.sh"
|
||||
print_info "请手动运行: ./scripts/setup_encryption.sh"
|
||||
exit 1
|
||||
sed -i "s|^${var_name}=.*|${var_name}=${var_value}|" .env
|
||||
fi
|
||||
else
|
||||
print_success "🔐 加密环境已配置"
|
||||
print_info " • RSA密钥对: secrets/rsa_key + secrets/rsa_key.pub"
|
||||
print_info " • 数据加密密钥: .env (DATA_ENCRYPTION_KEY)"
|
||||
print_info " • JWT认证密钥: .env (JWT_SECRET)"
|
||||
print_info " • 加密算法: RSA-OAEP-2048 + AES-256-GCM + HS256"
|
||||
print_info " • 保护数据: API密钥、私钥、Hyperliquid代理钱包、用户认证"
|
||||
|
||||
# 验证密钥文件权限
|
||||
if [ -f "secrets/rsa_key" ]; then
|
||||
local perm=$(stat -f "%A" "secrets/rsa_key" 2>/dev/null || stat -c "%a" "secrets/rsa_key" 2>/dev/null)
|
||||
if [ "$perm" != "600" ]; then
|
||||
print_warning "修复RSA私钥权限..."
|
||||
chmod 600 secrets/rsa_key
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -f ".env" ]; then
|
||||
local perm=$(stat -f "%A" ".env" 2>/dev/null || stat -c "%a" ".env" 2>/dev/null)
|
||||
if [ "$perm" != "600" ]; then
|
||||
print_warning "修复环境文件权限..."
|
||||
chmod 600 .env
|
||||
fi
|
||||
fi
|
||||
# 变量不存在,追加
|
||||
echo "${var_name}=${var_value}" >> .env
|
||||
fi
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Validation: Encryption Keys in .env
|
||||
# ------------------------------------------------------------------------
|
||||
check_encryption() {
|
||||
print_info "检查加密密钥配置..."
|
||||
|
||||
local generated=false
|
||||
|
||||
# 检查并生成 JWT_SECRET
|
||||
if ! is_env_configured "JWT_SECRET"; then
|
||||
print_warning "JWT_SECRET 未配置,正在生成..."
|
||||
local jwt_secret=$(openssl rand -base64 32)
|
||||
set_env_var "JWT_SECRET" "$jwt_secret"
|
||||
print_success "JWT_SECRET 已生成"
|
||||
generated=true
|
||||
fi
|
||||
|
||||
# 检查并生成 DATA_ENCRYPTION_KEY
|
||||
if ! is_env_configured "DATA_ENCRYPTION_KEY"; then
|
||||
print_warning "DATA_ENCRYPTION_KEY 未配置,正在生成..."
|
||||
local data_key=$(openssl rand -base64 32)
|
||||
set_env_var "DATA_ENCRYPTION_KEY" "$data_key"
|
||||
print_success "DATA_ENCRYPTION_KEY 已生成"
|
||||
generated=true
|
||||
fi
|
||||
|
||||
# 检查并生成 RSA_PRIVATE_KEY
|
||||
if ! is_env_configured "RSA_PRIVATE_KEY"; then
|
||||
print_warning "RSA_PRIVATE_KEY 未配置,正在生成..."
|
||||
# 生成 RSA 密钥并转换为单行格式(\n 替换为 \\n)
|
||||
local rsa_key=$(openssl genrsa 2048 2>/dev/null | awk '{printf "%s\\n", $0}')
|
||||
set_env_var "RSA_PRIVATE_KEY" "\"$rsa_key\""
|
||||
print_success "RSA_PRIVATE_KEY 已生成"
|
||||
generated=true
|
||||
fi
|
||||
|
||||
if [ "$generated" = true ]; then
|
||||
echo ""
|
||||
print_success "所有缺失的密钥已自动生成并保存到 .env"
|
||||
print_warning "请妥善保管 .env 文件,不要提交到版本控制系统"
|
||||
echo ""
|
||||
fi
|
||||
|
||||
print_success "加密密钥检查完成"
|
||||
print_info " • JWT_SECRET: OK"
|
||||
print_info " • DATA_ENCRYPTION_KEY: OK"
|
||||
print_info " • RSA_PRIVATE_KEY: OK"
|
||||
|
||||
# 修复 .env 文件权限
|
||||
chmod 600 .env 2>/dev/null || true
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Validation: Configuration File (config.json) - BASIC SETTINGS ONLY
|
||||
# ------------------------------------------------------------------------
|
||||
@@ -166,9 +181,7 @@ check_config() {
|
||||
if [ ! -f "config.json" ]; then
|
||||
print_warning "config.json 不存在,从模板复制..."
|
||||
cp config.json.example config.json
|
||||
print_info "✓ 已使用默认配置创建 config.json"
|
||||
print_info "💡 如需修改基础设置(杠杆大小、开仓币种、管理员模式、JWT密钥等),可编辑 config.json"
|
||||
print_info "💡 模型/交易所/交易员配置请使用Web界面"
|
||||
print_info "已使用默认配置创建 config.json"
|
||||
fi
|
||||
print_success "配置文件存在"
|
||||
}
|
||||
@@ -178,101 +191,55 @@ check_config() {
|
||||
# ------------------------------------------------------------------------
|
||||
read_env_vars() {
|
||||
if [ -f ".env" ]; then
|
||||
# 读取端口配置,设置默认值
|
||||
NOFX_FRONTEND_PORT=$(grep "^NOFX_FRONTEND_PORT=" .env 2>/dev/null | cut -d'=' -f2 || echo "3000")
|
||||
NOFX_BACKEND_PORT=$(grep "^NOFX_BACKEND_PORT=" .env 2>/dev/null | cut -d'=' -f2 || echo "8080")
|
||||
|
||||
# 去除可能的引号和空格
|
||||
|
||||
NOFX_FRONTEND_PORT=$(echo "$NOFX_FRONTEND_PORT" | tr -d '"'"'" | tr -d ' ')
|
||||
NOFX_BACKEND_PORT=$(echo "$NOFX_BACKEND_PORT" | tr -d '"'"'" | tr -d ' ')
|
||||
|
||||
# 如果为空则使用默认值
|
||||
|
||||
NOFX_FRONTEND_PORT=${NOFX_FRONTEND_PORT:-3000}
|
||||
NOFX_BACKEND_PORT=${NOFX_BACKEND_PORT:-8080}
|
||||
else
|
||||
# 如果.env不存在,使用默认端口
|
||||
NOFX_FRONTEND_PORT=3000
|
||||
NOFX_BACKEND_PORT=8080
|
||||
fi
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Validation: Database File (config.db)
|
||||
# Validation: Database File (data.db)
|
||||
# ------------------------------------------------------------------------
|
||||
check_database() {
|
||||
if [ -d "config.db" ]; then
|
||||
# 如果存在的是目录,删除它
|
||||
print_warning "config.db 是目录而非文件,正在删除目录..."
|
||||
rm -rf config.db
|
||||
print_info "✓ 已删除目录,现在创建文件..."
|
||||
install -m 600 /dev/null config.db
|
||||
print_success "✓ 已创建空数据库文件(权限: 600),系统将在启动时初始化"
|
||||
elif [ ! -f "config.db" ]; then
|
||||
# 如果不存在文件,创建它
|
||||
if [ -d "data.db" ]; then
|
||||
print_warning "data.db 是目录而非文件,正在删除目录..."
|
||||
rm -rf data.db
|
||||
install -m 600 /dev/null data.db
|
||||
print_success "已创建空数据库文件"
|
||||
elif [ ! -f "data.db" ]; then
|
||||
print_warning "数据库文件不存在,创建空数据库文件..."
|
||||
# 创建空文件以避免Docker创建目录(使用安全权限600)
|
||||
install -m 600 /dev/null config.db
|
||||
print_info "✓ 已创建空数据库文件(权限: 600),系统将在启动时初始化"
|
||||
install -m 600 /dev/null data.db
|
||||
print_info "已创建空数据库文件,系统将在启动时初始化"
|
||||
else
|
||||
# 文件存在
|
||||
print_success "数据库文件存在"
|
||||
fi
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Build: Frontend (Node.js Based)
|
||||
# ------------------------------------------------------------------------
|
||||
# build_frontend() {
|
||||
# print_info "检查前端构建环境..."
|
||||
|
||||
# if ! command -v node &> /dev/null; then
|
||||
# print_error "Node.js 未安装!请先安装 Node.js"
|
||||
# exit 1
|
||||
# fi
|
||||
|
||||
# if ! command -v npm &> /dev/null; then
|
||||
# print_error "npm 未安装!请先安装 npm"
|
||||
# exit 1
|
||||
# fi
|
||||
|
||||
# print_info "正在构建前端..."
|
||||
# cd web
|
||||
|
||||
# print_info "安装 Node.js 依赖..."
|
||||
# npm install
|
||||
|
||||
# print_info "构建前端应用..."
|
||||
# npm run build
|
||||
|
||||
# cd ..
|
||||
# print_success "前端构建完成"
|
||||
# }
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Service Management: Start
|
||||
# ------------------------------------------------------------------------
|
||||
start() {
|
||||
print_info "正在启动 NOFX AI Trading System..."
|
||||
|
||||
# 读取环境变量
|
||||
read_env_vars
|
||||
|
||||
# 确保必要的文件和目录存在(修复 Docker volume 挂载问题)
|
||||
if [ ! -f "config.db" ]; then
|
||||
if [ ! -f "data.db" ]; then
|
||||
print_info "创建数据库文件..."
|
||||
install -m 600 /dev/null config.db
|
||||
install -m 600 /dev/null data.db
|
||||
fi
|
||||
if [ ! -d "decision_logs" ]; then
|
||||
print_info "创建日志目录..."
|
||||
install -m 700 -d decision_logs
|
||||
fi
|
||||
|
||||
# Auto-build frontend if missing or forced
|
||||
# if [ ! -d "web/dist" ] || [ "$1" == "--build" ]; then
|
||||
# build_frontend
|
||||
# fi
|
||||
|
||||
# Rebuild images if flag set
|
||||
if [ "$1" == "--build" ]; then
|
||||
print_info "重新构建镜像..."
|
||||
$COMPOSE_CMD up -d --build
|
||||
@@ -322,9 +289,8 @@ logs() {
|
||||
# Monitoring: Status
|
||||
# ------------------------------------------------------------------------
|
||||
status() {
|
||||
# 读取环境变量
|
||||
read_env_vars
|
||||
|
||||
|
||||
print_info "服务状态:"
|
||||
$COMPOSE_CMD ps
|
||||
echo ""
|
||||
@@ -358,18 +324,42 @@ update() {
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# Encryption: Manual Setup
|
||||
# Command: Regenerate all keys (force)
|
||||
# ------------------------------------------------------------------------
|
||||
setup_encryption_manual() {
|
||||
print_info "🔐 手动设置加密环境"
|
||||
|
||||
if [ -f "scripts/setup_encryption.sh" ]; then
|
||||
bash scripts/setup_encryption.sh
|
||||
else
|
||||
print_error "加密设置脚本不存在: scripts/setup_encryption.sh"
|
||||
print_info "请确保项目文件完整"
|
||||
exit 1
|
||||
regenerate_keys() {
|
||||
print_warning "这将重新生成所有加密密钥!"
|
||||
print_warning "如果已有加密数据,重新生成后将无法解密!"
|
||||
echo ""
|
||||
read -p "确认重新生成?(yes/no): " confirm
|
||||
if [ "$confirm" != "yes" ]; then
|
||||
print_info "已取消"
|
||||
return
|
||||
fi
|
||||
|
||||
check_env
|
||||
|
||||
print_info "正在生成新的密钥..."
|
||||
|
||||
# 生成 JWT_SECRET
|
||||
local jwt_secret=$(openssl rand -base64 32)
|
||||
set_env_var "JWT_SECRET" "$jwt_secret"
|
||||
print_success "JWT_SECRET 已生成"
|
||||
|
||||
# 生成 DATA_ENCRYPTION_KEY
|
||||
local data_key=$(openssl rand -base64 32)
|
||||
set_env_var "DATA_ENCRYPTION_KEY" "$data_key"
|
||||
print_success "DATA_ENCRYPTION_KEY 已生成"
|
||||
|
||||
# 生成 RSA_PRIVATE_KEY
|
||||
local rsa_key=$(openssl genrsa 2048 2>/dev/null | awk '{printf "%s\\n", $0}')
|
||||
set_env_var "RSA_PRIVATE_KEY" "\"$rsa_key\""
|
||||
print_success "RSA_PRIVATE_KEY 已生成"
|
||||
|
||||
chmod 600 .env 2>/dev/null || true
|
||||
|
||||
echo ""
|
||||
print_success "所有密钥已重新生成并保存到 .env"
|
||||
print_warning "请妥善保管 .env 文件"
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
@@ -388,18 +378,16 @@ show_help() {
|
||||
echo " status 查看服务状态"
|
||||
echo " clean 清理所有容器和数据"
|
||||
echo " update 更新代码并重启"
|
||||
echo " setup-encryption 设置加密环境(RSA密钥+数据加密)"
|
||||
echo " regenerate-keys 重新生成所有加密密钥(慎用)"
|
||||
echo " help 显示此帮助信息"
|
||||
echo ""
|
||||
echo "示例:"
|
||||
echo " ./start.sh start --build # 构建并启动"
|
||||
echo " ./start.sh logs backend # 查看后端日志"
|
||||
echo " ./start.sh status # 查看状态"
|
||||
echo " ./start.sh setup-encryption # 手动设置加密环境"
|
||||
echo ""
|
||||
echo "🔐 关于加密:"
|
||||
echo " 系统自动检测加密环境,首次运行时会自动设置"
|
||||
echo " 手动设置: ./scripts/setup_encryption.sh"
|
||||
echo "首次使用:"
|
||||
echo " 直接运行 ./start.sh 即可,缺失的密钥会自动生成"
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
@@ -434,8 +422,8 @@ main() {
|
||||
update)
|
||||
update
|
||||
;;
|
||||
setup-encryption)
|
||||
setup_encryption_manual
|
||||
regenerate-keys)
|
||||
regenerate_keys
|
||||
;;
|
||||
help|--help|-h)
|
||||
show_help
|
||||
@@ -449,4 +437,4 @@ main() {
|
||||
}
|
||||
|
||||
# Execute Main
|
||||
main "$@"
|
||||
main "$@"
|
||||
|
||||
294
store/ai_model.go
Normal file
294
store/ai_model.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"nofx/logger"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AIModelStore AI模型存储
|
||||
type AIModelStore struct {
|
||||
db *sql.DB
|
||||
encryptFunc func(string) string
|
||||
decryptFunc func(string) string
|
||||
}
|
||||
|
||||
// AIModel AI模型配置
|
||||
type AIModel struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
CustomAPIURL string `json:"customApiUrl"`
|
||||
CustomModelName string `json:"customModelName"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (s *AIModelStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS ai_models (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL DEFAULT 'default',
|
||||
name TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
enabled BOOLEAN DEFAULT 0,
|
||||
api_key TEXT DEFAULT '',
|
||||
custom_api_url TEXT DEFAULT '',
|
||||
custom_model_name TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 触发器
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_ai_models_updated_at
|
||||
AFTER UPDATE ON ai_models
|
||||
BEGIN
|
||||
UPDATE ai_models SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 向后兼容:添加可能缺失的列
|
||||
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_api_url TEXT DEFAULT ''`)
|
||||
s.db.Exec(`ALTER TABLE ai_models ADD COLUMN custom_model_name TEXT DEFAULT ''`)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AIModelStore) initDefaultData() error {
|
||||
models := []struct {
|
||||
id, name, provider string
|
||||
}{
|
||||
{"deepseek", "DeepSeek", "deepseek"},
|
||||
{"qwen", "Qwen", "qwen"},
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled)
|
||||
VALUES (?, 'default', ?, ?, 0)
|
||||
`, model.id, model.name, model.provider)
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化AI模型失败: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AIModelStore) encrypt(plaintext string) string {
|
||||
if s.encryptFunc != nil {
|
||||
return s.encryptFunc(plaintext)
|
||||
}
|
||||
return plaintext
|
||||
}
|
||||
|
||||
func (s *AIModelStore) decrypt(encrypted string) string {
|
||||
if s.decryptFunc != nil {
|
||||
return s.decryptFunc(encrypted)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
|
||||
// List 获取用户的AI模型列表
|
||||
func (s *AIModelStore) List(userID string) ([]*AIModel, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, provider, enabled, api_key,
|
||||
COALESCE(custom_api_url, '') as custom_api_url,
|
||||
COALESCE(custom_model_name, '') as custom_model_name,
|
||||
created_at, updated_at
|
||||
FROM ai_models WHERE user_id = ? ORDER BY id
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
models := make([]*AIModel, 0)
|
||||
for rows.Next() {
|
||||
var model AIModel
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
model.APIKey = s.decrypt(model.APIKey)
|
||||
models = append(models, &model)
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// Get 获取单个AI模型
|
||||
func (s *AIModelStore) Get(userID, modelID string) (*AIModel, error) {
|
||||
if modelID == "" {
|
||||
return nil, fmt.Errorf("模型ID不能为空")
|
||||
}
|
||||
|
||||
candidates := []string{}
|
||||
if userID != "" {
|
||||
candidates = append(candidates, userID)
|
||||
}
|
||||
if userID != "default" {
|
||||
candidates = append(candidates, "default")
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
candidates = append(candidates, "default")
|
||||
}
|
||||
|
||||
for _, uid := range candidates {
|
||||
var model AIModel
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, provider, enabled, api_key,
|
||||
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
|
||||
FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1
|
||||
`, uid, modelID).Scan(
|
||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
if err == nil {
|
||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
model.APIKey = s.decrypt(model.APIKey)
|
||||
return &model, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, sql.ErrNoRows
|
||||
}
|
||||
|
||||
// GetDefault 获取默认启用的AI模型
|
||||
func (s *AIModelStore) GetDefault(userID string) (*AIModel, error) {
|
||||
if userID == "" {
|
||||
userID = "default"
|
||||
}
|
||||
model, err := s.firstEnabled(userID)
|
||||
if err == nil {
|
||||
return model, nil
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
if userID != "default" {
|
||||
return s.firstEnabled("default")
|
||||
}
|
||||
return nil, fmt.Errorf("请先在系统中配置可用的AI模型")
|
||||
}
|
||||
|
||||
func (s *AIModelStore) firstEnabled(userID string) (*AIModel, error) {
|
||||
var model AIModel
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, name, provider, enabled, api_key,
|
||||
COALESCE(custom_api_url, ''), COALESCE(custom_model_name, ''), created_at, updated_at
|
||||
FROM ai_models WHERE user_id = ? AND enabled = 1
|
||||
ORDER BY datetime(updated_at) DESC, id ASC LIMIT 1
|
||||
`, userID).Scan(
|
||||
&model.ID, &model.UserID, &model.Name, &model.Provider,
|
||||
&model.Enabled, &model.APIKey, &model.CustomAPIURL, &model.CustomModelName,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
model.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
model.APIKey = s.decrypt(model.APIKey)
|
||||
return &model, nil
|
||||
}
|
||||
|
||||
// Update 更新AI模型,不存在则创建
|
||||
func (s *AIModelStore) Update(userID, id string, enabled bool, apiKey, customAPIURL, customModelName string) error {
|
||||
// 先尝试精确匹配ID
|
||||
var existingID string
|
||||
err := s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND id = ? LIMIT 1`, userID, id).Scan(&existingID)
|
||||
if err == nil {
|
||||
encryptedAPIKey := s.encrypt(apiKey)
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// 尝试兼容旧逻辑:将id作为provider查找
|
||||
provider := id
|
||||
err = s.db.QueryRow(`SELECT id FROM ai_models WHERE user_id = ? AND provider = ? LIMIT 1`, userID, provider).Scan(&existingID)
|
||||
if err == nil {
|
||||
logger.Warnf("⚠️ 使用旧版 provider 匹配更新模型: %s -> %s", provider, existingID)
|
||||
encryptedAPIKey := s.encrypt(apiKey)
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE ai_models SET enabled = ?, api_key = ?, custom_api_url = ?, custom_model_name = ?, updated_at = datetime('now')
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, enabled, encryptedAPIKey, customAPIURL, customModelName, existingID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建新记录
|
||||
if provider == id && (provider == "deepseek" || provider == "qwen") {
|
||||
provider = id
|
||||
} else {
|
||||
parts := strings.Split(id, "_")
|
||||
if len(parts) >= 2 {
|
||||
provider = parts[len(parts)-1]
|
||||
} else {
|
||||
provider = id
|
||||
}
|
||||
}
|
||||
|
||||
var name string
|
||||
err = s.db.QueryRow(`SELECT name FROM ai_models WHERE provider = ? LIMIT 1`, provider).Scan(&name)
|
||||
if err != nil {
|
||||
if provider == "deepseek" {
|
||||
name = "DeepSeek AI"
|
||||
} else if provider == "qwen" {
|
||||
name = "Qwen AI"
|
||||
} else {
|
||||
name = provider + " AI"
|
||||
}
|
||||
}
|
||||
|
||||
newModelID := id
|
||||
if id == provider {
|
||||
newModelID = fmt.Sprintf("%s_%s", userID, provider)
|
||||
}
|
||||
|
||||
logger.Infof("✓ 创建新的 AI 模型配置: ID=%s, Provider=%s, Name=%s", newModelID, provider, name)
|
||||
encryptedAPIKey := s.encrypt(apiKey)
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url, custom_model_name, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))
|
||||
`, newModelID, userID, name, provider, enabled, encryptedAPIKey, customAPIURL, customModelName)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create 创建AI模型
|
||||
func (s *AIModelStore) Create(userID, id, name, provider string, enabled bool, apiKey, customAPIURL string) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR IGNORE INTO ai_models (id, user_id, name, provider, enabled, api_key, custom_api_url)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`, id, userID, name, provider, enabled, apiKey, customAPIURL)
|
||||
return err
|
||||
}
|
||||
583
store/backtest.go
Normal file
583
store/backtest.go
Normal file
@@ -0,0 +1,583 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BacktestStore 回测数据存储
|
||||
type BacktestStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// RunState 回测状态
|
||||
type RunState string
|
||||
|
||||
const (
|
||||
RunStateCreated RunState = "created"
|
||||
RunStateRunning RunState = "running"
|
||||
RunStatePaused RunState = "paused"
|
||||
RunStateCompleted RunState = "completed"
|
||||
RunStateFailed RunState = "failed"
|
||||
)
|
||||
|
||||
// RunMetadata 回测元数据
|
||||
type RunMetadata struct {
|
||||
RunID string `json:"run_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Version int `json:"version"`
|
||||
State RunState `json:"state"`
|
||||
Label string `json:"label"`
|
||||
LastError string `json:"last_error"`
|
||||
Summary RunSummary `json:"summary"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// RunSummary 回测摘要
|
||||
type RunSummary struct {
|
||||
SymbolCount int `json:"symbol_count"`
|
||||
DecisionTF string `json:"decision_tf"`
|
||||
ProcessedBars int `json:"processed_bars"`
|
||||
ProgressPct float64 `json:"progress_pct"`
|
||||
EquityLast float64 `json:"equity_last"`
|
||||
MaxDrawdownPct float64 `json:"max_drawdown_pct"`
|
||||
Liquidated bool `json:"liquidated"`
|
||||
LiquidationNote string `json:"liquidation_note"`
|
||||
}
|
||||
|
||||
// EquityPoint 权益点
|
||||
type EquityPoint struct {
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Equity float64 `json:"equity"`
|
||||
Available float64 `json:"available"`
|
||||
PnL float64 `json:"pnl"`
|
||||
PnLPct float64 `json:"pnl_pct"`
|
||||
DrawdownPct float64 `json:"drawdown_pct"`
|
||||
Cycle int `json:"cycle"`
|
||||
}
|
||||
|
||||
// TradeEvent 交易事件
|
||||
type TradeEvent struct {
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Symbol string `json:"symbol"`
|
||||
Action string `json:"action"`
|
||||
Side string `json:"side"`
|
||||
Quantity float64 `json:"quantity"`
|
||||
Price float64 `json:"price"`
|
||||
Fee float64 `json:"fee"`
|
||||
Slippage float64 `json:"slippage"`
|
||||
OrderValue float64 `json:"order_value"`
|
||||
RealizedPnL float64 `json:"realized_pnl"`
|
||||
Leverage int `json:"leverage"`
|
||||
Cycle int `json:"cycle"`
|
||||
PositionAfter float64 `json:"position_after"`
|
||||
LiquidationFlag bool `json:"liquidation_flag"`
|
||||
Note string `json:"note"`
|
||||
}
|
||||
|
||||
// RunIndexEntry 回测索引条目
|
||||
type RunIndexEntry struct {
|
||||
RunID string `json:"run_id"`
|
||||
State string `json:"state"`
|
||||
Symbols []string `json:"symbols"`
|
||||
DecisionTF string `json:"decision_tf"`
|
||||
EquityLast float64 `json:"equity_last"`
|
||||
MaxDrawdownPct float64 `json:"max_drawdown_pct"`
|
||||
StartTS int64 `json:"start_ts"`
|
||||
EndTS int64 `json:"end_ts"`
|
||||
CreatedAtISO string `json:"created_at"`
|
||||
UpdatedAtISO string `json:"updated_at"`
|
||||
}
|
||||
|
||||
// initTables 初始化回测相关表
|
||||
func (s *BacktestStore) initTables() error {
|
||||
queries := []string{
|
||||
// 回测运行主表
|
||||
`CREATE TABLE IF NOT EXISTS backtest_runs (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL DEFAULT '',
|
||||
config_json TEXT NOT NULL DEFAULT '',
|
||||
state TEXT NOT NULL DEFAULT 'created',
|
||||
label TEXT DEFAULT '',
|
||||
symbol_count INTEGER DEFAULT 0,
|
||||
decision_tf TEXT DEFAULT '',
|
||||
processed_bars INTEGER DEFAULT 0,
|
||||
progress_pct REAL DEFAULT 0,
|
||||
equity_last REAL DEFAULT 0,
|
||||
max_drawdown_pct REAL DEFAULT 0,
|
||||
liquidated BOOLEAN DEFAULT 0,
|
||||
liquidation_note TEXT DEFAULT '',
|
||||
prompt_template TEXT DEFAULT '',
|
||||
custom_prompt TEXT DEFAULT '',
|
||||
override_prompt BOOLEAN DEFAULT 0,
|
||||
ai_provider TEXT DEFAULT '',
|
||||
ai_model TEXT DEFAULT '',
|
||||
last_error TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
|
||||
// 回测检查点
|
||||
`CREATE TABLE IF NOT EXISTS backtest_checkpoints (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
payload BLOB NOT NULL,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// 回测权益曲线
|
||||
`CREATE TABLE IF NOT EXISTS backtest_equity (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id TEXT NOT NULL,
|
||||
ts INTEGER NOT NULL,
|
||||
equity REAL NOT NULL,
|
||||
available REAL NOT NULL,
|
||||
pnl REAL NOT NULL,
|
||||
pnl_pct REAL NOT NULL,
|
||||
dd_pct REAL NOT NULL,
|
||||
cycle INTEGER NOT NULL,
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// 回测交易记录
|
||||
`CREATE TABLE IF NOT EXISTS backtest_trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id TEXT NOT NULL,
|
||||
ts INTEGER NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
side TEXT DEFAULT '',
|
||||
qty REAL DEFAULT 0,
|
||||
price REAL DEFAULT 0,
|
||||
fee REAL DEFAULT 0,
|
||||
slippage REAL DEFAULT 0,
|
||||
order_value REAL DEFAULT 0,
|
||||
realized_pnl REAL DEFAULT 0,
|
||||
leverage INTEGER DEFAULT 0,
|
||||
cycle INTEGER DEFAULT 0,
|
||||
position_after REAL DEFAULT 0,
|
||||
liquidation BOOLEAN DEFAULT 0,
|
||||
note TEXT DEFAULT '',
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// 回测指标
|
||||
`CREATE TABLE IF NOT EXISTS backtest_metrics (
|
||||
run_id TEXT PRIMARY KEY,
|
||||
payload BLOB NOT NULL,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// 回测决策日志
|
||||
`CREATE TABLE IF NOT EXISTS backtest_decisions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
run_id TEXT NOT NULL,
|
||||
cycle INTEGER NOT NULL,
|
||||
payload BLOB NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (run_id) REFERENCES backtest_runs(run_id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// 索引
|
||||
`CREATE INDEX IF NOT EXISTS idx_backtest_runs_state ON backtest_runs(state, updated_at)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_backtest_equity_run_ts ON backtest_equity(run_id, ts)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_backtest_trades_run_ts ON backtest_trades(run_id, ts)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_backtest_decisions_run_cycle ON backtest_decisions(run_id, cycle)`,
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
if _, err := s.db.Exec(query); err != nil {
|
||||
return fmt.Errorf("执行SQL失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加可能缺失的列(向后兼容)
|
||||
s.addColumnIfNotExists("backtest_runs", "label", "TEXT DEFAULT ''")
|
||||
s.addColumnIfNotExists("backtest_runs", "last_error", "TEXT DEFAULT ''")
|
||||
s.addColumnIfNotExists("backtest_trades", "leverage", "INTEGER DEFAULT 0")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BacktestStore) addColumnIfNotExists(table, column, definition string) {
|
||||
rows, err := s.db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var cid int
|
||||
var name, ctype string
|
||||
var notnull, pk int
|
||||
var dflt interface{}
|
||||
if err := rows.Scan(&cid, &name, &ctype, ¬null, &dflt, &pk); err != nil {
|
||||
continue
|
||||
}
|
||||
if name == column {
|
||||
return // 列已存在
|
||||
}
|
||||
}
|
||||
|
||||
s.db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition))
|
||||
}
|
||||
|
||||
// SaveCheckpoint 保存检查点
|
||||
func (s *BacktestStore) SaveCheckpoint(runID string, payload []byte) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_checkpoints (run_id, payload, updated_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
|
||||
`, runID, payload)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadCheckpoint 加载检查点
|
||||
func (s *BacktestStore) LoadCheckpoint(runID string) ([]byte, error) {
|
||||
var payload []byte
|
||||
err := s.db.QueryRow(`SELECT payload FROM backtest_checkpoints WHERE run_id = ?`, runID).Scan(&payload)
|
||||
return payload, err
|
||||
}
|
||||
|
||||
// SaveRunMetadata 保存运行元数据
|
||||
func (s *BacktestStore) SaveRunMetadata(meta *RunMetadata) error {
|
||||
created := meta.CreatedAt.UTC().Format(time.RFC3339)
|
||||
updated := meta.UpdatedAt.UTC().Format(time.RFC3339)
|
||||
userID := meta.UserID
|
||||
|
||||
if _, err := s.db.Exec(`
|
||||
INSERT INTO backtest_runs (run_id, user_id, label, last_error, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(run_id) DO NOTHING
|
||||
`, meta.RunID, userID, meta.Label, meta.LastError, created, updated); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE backtest_runs
|
||||
SET user_id = ?, state = ?, symbol_count = ?, decision_tf = ?, processed_bars = ?,
|
||||
progress_pct = ?, equity_last = ?, max_drawdown_pct = ?, liquidated = ?,
|
||||
liquidation_note = ?, label = ?, last_error = ?, updated_at = ?
|
||||
WHERE run_id = ?
|
||||
`, userID, string(meta.State), meta.Summary.SymbolCount, meta.Summary.DecisionTF,
|
||||
meta.Summary.ProcessedBars, meta.Summary.ProgressPct, meta.Summary.EquityLast,
|
||||
meta.Summary.MaxDrawdownPct, meta.Summary.Liquidated, meta.Summary.LiquidationNote,
|
||||
meta.Label, meta.LastError, updated, meta.RunID)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadRunMetadata 加载运行元数据
|
||||
func (s *BacktestStore) LoadRunMetadata(runID string) (*RunMetadata, error) {
|
||||
var (
|
||||
userID string
|
||||
state string
|
||||
label string
|
||||
lastErr string
|
||||
symbolCount int
|
||||
decisionTF string
|
||||
processedBars int
|
||||
progressPct float64
|
||||
equityLast float64
|
||||
maxDD float64
|
||||
liquidated bool
|
||||
liquidationNote string
|
||||
createdISO string
|
||||
updatedISO string
|
||||
)
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT user_id, state, label, last_error, symbol_count, decision_tf, processed_bars,
|
||||
progress_pct, equity_last, max_drawdown_pct, liquidated, liquidation_note,
|
||||
created_at, updated_at
|
||||
FROM backtest_runs WHERE run_id = ?
|
||||
`, runID).Scan(&userID, &state, &label, &lastErr, &symbolCount, &decisionTF,
|
||||
&processedBars, &progressPct, &equityLast, &maxDD, &liquidated, &liquidationNote,
|
||||
&createdISO, &updatedISO)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := &RunMetadata{
|
||||
RunID: runID,
|
||||
UserID: userID,
|
||||
Version: 1,
|
||||
State: RunState(state),
|
||||
Label: label,
|
||||
LastError: lastErr,
|
||||
Summary: RunSummary{
|
||||
SymbolCount: symbolCount,
|
||||
DecisionTF: decisionTF,
|
||||
ProcessedBars: processedBars,
|
||||
ProgressPct: progressPct,
|
||||
EquityLast: equityLast,
|
||||
MaxDrawdownPct: maxDD,
|
||||
Liquidated: liquidated,
|
||||
LiquidationNote: liquidationNote,
|
||||
},
|
||||
}
|
||||
|
||||
meta.CreatedAt, _ = time.Parse(time.RFC3339, createdISO)
|
||||
meta.UpdatedAt, _ = time.Parse(time.RFC3339, updatedISO)
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
// ListRunIDs 列出所有运行ID
|
||||
func (s *BacktestStore) ListRunIDs() ([]string, error) {
|
||||
rows, err := s.db.Query(`SELECT run_id FROM backtest_runs ORDER BY datetime(updated_at) DESC`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []string
|
||||
for rows.Next() {
|
||||
var runID string
|
||||
if err := rows.Scan(&runID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, runID)
|
||||
}
|
||||
return ids, rows.Err()
|
||||
}
|
||||
|
||||
// AppendEquityPoint 添加权益点
|
||||
func (s *BacktestStore) AppendEquityPoint(runID string, point EquityPoint) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_equity (run_id, ts, equity, available, pnl, pnl_pct, dd_pct, cycle)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, runID, point.Timestamp, point.Equity, point.Available, point.PnL,
|
||||
point.PnLPct, point.DrawdownPct, point.Cycle)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadEquityPoints 加载权益点
|
||||
func (s *BacktestStore) LoadEquityPoints(runID string) ([]EquityPoint, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT ts, equity, available, pnl, pnl_pct, dd_pct, cycle
|
||||
FROM backtest_equity WHERE run_id = ? ORDER BY ts ASC
|
||||
`, runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
points := make([]EquityPoint, 0)
|
||||
for rows.Next() {
|
||||
var point EquityPoint
|
||||
if err := rows.Scan(&point.Timestamp, &point.Equity, &point.Available,
|
||||
&point.PnL, &point.PnLPct, &point.DrawdownPct, &point.Cycle); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
points = append(points, point)
|
||||
}
|
||||
return points, rows.Err()
|
||||
}
|
||||
|
||||
// AppendTradeEvent 添加交易事件
|
||||
func (s *BacktestStore) AppendTradeEvent(runID string, event TradeEvent) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_trades (run_id, ts, symbol, action, side, qty, price, fee,
|
||||
slippage, order_value, realized_pnl, leverage, cycle,
|
||||
position_after, liquidation, note)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, runID, event.Timestamp, event.Symbol, event.Action, event.Side, event.Quantity,
|
||||
event.Price, event.Fee, event.Slippage, event.OrderValue, event.RealizedPnL,
|
||||
event.Leverage, event.Cycle, event.PositionAfter, event.LiquidationFlag, event.Note)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadTradeEvents 加载交易事件
|
||||
func (s *BacktestStore) LoadTradeEvents(runID string) ([]TradeEvent, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT ts, symbol, action, side, qty, price, fee, slippage, order_value,
|
||||
realized_pnl, leverage, cycle, position_after, liquidation, note
|
||||
FROM backtest_trades WHERE run_id = ? ORDER BY ts ASC
|
||||
`, runID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
events := make([]TradeEvent, 0)
|
||||
for rows.Next() {
|
||||
var event TradeEvent
|
||||
if err := rows.Scan(&event.Timestamp, &event.Symbol, &event.Action, &event.Side,
|
||||
&event.Quantity, &event.Price, &event.Fee, &event.Slippage, &event.OrderValue,
|
||||
&event.RealizedPnL, &event.Leverage, &event.Cycle, &event.PositionAfter,
|
||||
&event.LiquidationFlag, &event.Note); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
return events, rows.Err()
|
||||
}
|
||||
|
||||
// SaveMetrics 保存指标
|
||||
func (s *BacktestStore) SaveMetrics(runID string, payload []byte) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_metrics (run_id, payload, updated_at)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(run_id) DO UPDATE SET payload=excluded.payload, updated_at=CURRENT_TIMESTAMP
|
||||
`, runID, payload)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadMetrics 加载指标
|
||||
func (s *BacktestStore) LoadMetrics(runID string) ([]byte, error) {
|
||||
var payload []byte
|
||||
err := s.db.QueryRow(`SELECT payload FROM backtest_metrics WHERE run_id = ?`, runID).Scan(&payload)
|
||||
return payload, err
|
||||
}
|
||||
|
||||
// SaveDecisionRecord 保存决策记录
|
||||
func (s *BacktestStore) SaveDecisionRecord(runID string, cycle int, payload []byte) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_decisions (run_id, cycle, payload)
|
||||
VALUES (?, ?, ?)
|
||||
`, runID, cycle, payload)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadDecisionRecords 加载决策记录
|
||||
func (s *BacktestStore) LoadDecisionRecords(runID string, limit, offset int) ([]json.RawMessage, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT payload FROM backtest_decisions
|
||||
WHERE run_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`, runID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
records := make([]json.RawMessage, 0, limit)
|
||||
for rows.Next() {
|
||||
var payload []byte
|
||||
if err := rows.Scan(&payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records = append(records, json.RawMessage(payload))
|
||||
}
|
||||
return records, rows.Err()
|
||||
}
|
||||
|
||||
// LoadLatestDecision 加载最新决策
|
||||
func (s *BacktestStore) LoadLatestDecision(runID string, cycle int) ([]byte, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if cycle > 0 {
|
||||
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? AND cycle = ? ORDER BY datetime(created_at) DESC LIMIT 1`
|
||||
args = []interface{}{runID, cycle}
|
||||
} else {
|
||||
query = `SELECT payload FROM backtest_decisions WHERE run_id = ? ORDER BY datetime(created_at) DESC LIMIT 1`
|
||||
args = []interface{}{runID}
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
err := s.db.QueryRow(query, args...).Scan(&payload)
|
||||
return payload, err
|
||||
}
|
||||
|
||||
// UpdateProgress 更新进度
|
||||
func (s *BacktestStore) UpdateProgress(runID string, progressPct, equity float64, barIndex int, liquidated bool) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE backtest_runs
|
||||
SET progress_pct = ?, equity_last = ?, processed_bars = ?, liquidated = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE run_id = ?
|
||||
`, progressPct, equity, barIndex, liquidated, runID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListIndexEntries 列出索引条目
|
||||
func (s *BacktestStore) ListIndexEntries() ([]RunIndexEntry, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT run_id, state, symbol_count, decision_tf, equity_last, max_drawdown_pct,
|
||||
created_at, updated_at, config_json
|
||||
FROM backtest_runs
|
||||
ORDER BY datetime(updated_at) DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []RunIndexEntry
|
||||
for rows.Next() {
|
||||
var entry RunIndexEntry
|
||||
var symbolCnt int
|
||||
var cfgJSON []byte
|
||||
var createdISO, updatedISO string
|
||||
|
||||
if err := rows.Scan(&entry.RunID, &entry.State, &symbolCnt, &entry.DecisionTF,
|
||||
&entry.EquityLast, &entry.MaxDrawdownPct, &createdISO, &updatedISO, &cfgJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entry.CreatedAtISO = createdISO
|
||||
entry.UpdatedAtISO = updatedISO
|
||||
entry.Symbols = make([]string, 0, symbolCnt)
|
||||
|
||||
// 尝试从配置中提取更多信息
|
||||
if len(cfgJSON) > 0 {
|
||||
var cfg struct {
|
||||
Symbols []string `json:"symbols"`
|
||||
StartTS int64 `json:"start_ts"`
|
||||
EndTS int64 `json:"end_ts"`
|
||||
}
|
||||
if json.Unmarshal(cfgJSON, &cfg) == nil {
|
||||
entry.Symbols = cfg.Symbols
|
||||
entry.StartTS = cfg.StartTS
|
||||
entry.EndTS = cfg.EndTS
|
||||
}
|
||||
}
|
||||
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteRun 删除运行
|
||||
func (s *BacktestStore) DeleteRun(runID string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM backtest_runs WHERE run_id = ?`, runID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SaveConfig 保存配置
|
||||
func (s *BacktestStore) SaveConfig(runID, userID, template, customPrompt, provider, model string, override bool, configJSON []byte) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
if userID == "" {
|
||||
userID = "default"
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO backtest_runs (run_id, user_id, config_json, prompt_template, custom_prompt,
|
||||
override_prompt, ai_provider, ai_model, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(run_id) DO NOTHING
|
||||
`, runID, userID, configJSON, template, customPrompt, override, provider, model, now, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(`
|
||||
UPDATE backtest_runs
|
||||
SET user_id = ?, config_json = ?, prompt_template = ?, custom_prompt = ?,
|
||||
override_prompt = ?, ai_provider = ?, ai_model = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE run_id = ?
|
||||
`, userID, configJSON, template, customPrompt, override, provider, model, runID)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadConfig 加载配置
|
||||
func (s *BacktestStore) LoadConfig(runID string) ([]byte, error) {
|
||||
var payload []byte
|
||||
err := s.db.QueryRow(`SELECT config_json FROM backtest_runs WHERE run_id = ?`, runID).Scan(&payload)
|
||||
return payload, err
|
||||
}
|
||||
121
store/beta_code.go
Normal file
121
store/beta_code.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"nofx/logger"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BetaCodeStore 内测码存储
|
||||
type BetaCodeStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (s *BetaCodeStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS beta_codes (
|
||||
code TEXT PRIMARY KEY,
|
||||
used BOOLEAN DEFAULT 0,
|
||||
used_by TEXT DEFAULT '',
|
||||
used_at DATETIME DEFAULT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadFromFile 从文件加载内测码
|
||||
func (s *BetaCodeStore) LoadFromFile(filePath string) error {
|
||||
content, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取内测码文件失败: %w", err)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(content), "\n")
|
||||
var codes []string
|
||||
for _, line := range lines {
|
||||
code := strings.TrimSpace(line)
|
||||
if code != "" && !strings.HasPrefix(code, "#") {
|
||||
codes = append(codes, code)
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开始事务失败: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.Prepare(`INSERT OR IGNORE INTO beta_codes (code) VALUES (?)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("准备语句失败: %w", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
insertedCount := 0
|
||||
for _, code := range codes {
|
||||
result, err := stmt.Exec(code)
|
||||
if err != nil {
|
||||
logger.Warnf("插入内测码 %s 失败: %v", code, err)
|
||||
continue
|
||||
}
|
||||
if rowsAffected, _ := result.RowsAffected(); rowsAffected > 0 {
|
||||
insertedCount++
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("提交事务失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("✅ 成功加载 %d 个内测码到数据库 (总计 %d 个)", insertedCount, len(codes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate 验证内测码是否有效
|
||||
func (s *BetaCodeStore) Validate(code string) (bool, error) {
|
||||
var used bool
|
||||
err := s.db.QueryRow(`SELECT used FROM beta_codes WHERE code = ?`, code).Scan(&used)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return !used, nil
|
||||
}
|
||||
|
||||
// Use 使用内测码
|
||||
func (s *BetaCodeStore) Use(code, userEmail string) error {
|
||||
result, err := s.db.Exec(`
|
||||
UPDATE beta_codes SET used = 1, used_by = ?, used_at = CURRENT_TIMESTAMP
|
||||
WHERE code = ? AND used = 0
|
||||
`, userEmail, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("内测码无效或已被使用")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats 获取内测码统计
|
||||
func (s *BetaCodeStore) GetStats() (total, used int, err error) {
|
||||
err = s.db.QueryRow(`SELECT COUNT(*) FROM beta_codes`).Scan(&total)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
err = s.db.QueryRow(`SELECT COUNT(*) FROM beta_codes WHERE used = 1`).Scan(&used)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return total, used, nil
|
||||
}
|
||||
530
store/decision.go
Normal file
530
store/decision.go
Normal file
@@ -0,0 +1,530 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DecisionStore 决策日志存储
|
||||
type DecisionStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// DecisionRecord 决策记录
|
||||
type DecisionRecord struct {
|
||||
ID int64 `json:"id"`
|
||||
TraderID string `json:"trader_id"`
|
||||
CycleNumber int `json:"cycle_number"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
SystemPrompt string `json:"system_prompt"`
|
||||
InputPrompt string `json:"input_prompt"`
|
||||
CoTTrace string `json:"cot_trace"`
|
||||
DecisionJSON string `json:"decision_json"`
|
||||
CandidateCoins []string `json:"candidate_coins"`
|
||||
ExecutionLog []string `json:"execution_log"`
|
||||
Success bool `json:"success"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
AIRequestDurationMs int64 `json:"ai_request_duration_ms"`
|
||||
AccountState AccountSnapshot `json:"account_state"`
|
||||
Positions []PositionSnapshot `json:"positions"`
|
||||
Decisions []DecisionAction `json:"decisions"`
|
||||
}
|
||||
|
||||
// AccountSnapshot 账户状态快照
|
||||
type AccountSnapshot struct {
|
||||
TotalBalance float64 `json:"total_balance"`
|
||||
AvailableBalance float64 `json:"available_balance"`
|
||||
TotalUnrealizedProfit float64 `json:"total_unrealized_profit"`
|
||||
PositionCount int `json:"position_count"`
|
||||
MarginUsedPct float64 `json:"margin_used_pct"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
}
|
||||
|
||||
// PositionSnapshot 持仓快照
|
||||
type PositionSnapshot struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Side string `json:"side"`
|
||||
PositionAmt float64 `json:"position_amt"`
|
||||
EntryPrice float64 `json:"entry_price"`
|
||||
MarkPrice float64 `json:"mark_price"`
|
||||
UnrealizedProfit float64 `json:"unrealized_profit"`
|
||||
Leverage float64 `json:"leverage"`
|
||||
LiquidationPrice float64 `json:"liquidation_price"`
|
||||
}
|
||||
|
||||
// DecisionAction 决策动作
|
||||
type DecisionAction struct {
|
||||
Action string `json:"action"`
|
||||
Symbol string `json:"symbol"`
|
||||
Quantity float64 `json:"quantity"`
|
||||
Leverage int `json:"leverage"`
|
||||
Price float64 `json:"price"`
|
||||
OrderID int64 `json:"order_id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// Statistics 统计信息
|
||||
type Statistics struct {
|
||||
TotalCycles int `json:"total_cycles"`
|
||||
SuccessfulCycles int `json:"successful_cycles"`
|
||||
FailedCycles int `json:"failed_cycles"`
|
||||
TotalOpenPositions int `json:"total_open_positions"`
|
||||
TotalClosePositions int `json:"total_close_positions"`
|
||||
}
|
||||
|
||||
// initTables 初始化决策相关表
|
||||
func (s *DecisionStore) initTables() error {
|
||||
queries := []string{
|
||||
// 决策记录主表
|
||||
`CREATE TABLE IF NOT EXISTS decision_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trader_id TEXT NOT NULL,
|
||||
cycle_number INTEGER NOT NULL,
|
||||
timestamp DATETIME NOT NULL,
|
||||
system_prompt TEXT DEFAULT '',
|
||||
input_prompt TEXT DEFAULT '',
|
||||
cot_trace TEXT DEFAULT '',
|
||||
decision_json TEXT DEFAULT '',
|
||||
candidate_coins TEXT DEFAULT '',
|
||||
execution_log TEXT DEFAULT '',
|
||||
success BOOLEAN DEFAULT 0,
|
||||
error_message TEXT DEFAULT '',
|
||||
ai_request_duration_ms INTEGER DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
|
||||
// 账户状态快照表
|
||||
`CREATE TABLE IF NOT EXISTS decision_account_snapshots (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
decision_id INTEGER NOT NULL,
|
||||
total_balance REAL DEFAULT 0,
|
||||
available_balance REAL DEFAULT 0,
|
||||
total_unrealized_profit REAL DEFAULT 0,
|
||||
position_count INTEGER DEFAULT 0,
|
||||
margin_used_pct REAL DEFAULT 0,
|
||||
initial_balance REAL DEFAULT 0,
|
||||
FOREIGN KEY (decision_id) REFERENCES decision_records(id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// 持仓快照表
|
||||
`CREATE TABLE IF NOT EXISTS decision_position_snapshots (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
decision_id INTEGER NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
side TEXT DEFAULT '',
|
||||
position_amt REAL DEFAULT 0,
|
||||
entry_price REAL DEFAULT 0,
|
||||
mark_price REAL DEFAULT 0,
|
||||
unrealized_profit REAL DEFAULT 0,
|
||||
leverage REAL DEFAULT 0,
|
||||
liquidation_price REAL DEFAULT 0,
|
||||
FOREIGN KEY (decision_id) REFERENCES decision_records(id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// 决策动作表(订单详情)
|
||||
`CREATE TABLE IF NOT EXISTS decision_actions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
decision_id INTEGER NOT NULL,
|
||||
trader_id TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
quantity REAL DEFAULT 0,
|
||||
leverage INTEGER DEFAULT 0,
|
||||
price REAL DEFAULT 0,
|
||||
order_id INTEGER DEFAULT 0,
|
||||
timestamp DATETIME NOT NULL,
|
||||
success BOOLEAN DEFAULT 0,
|
||||
error TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (decision_id) REFERENCES decision_records(id) ON DELETE CASCADE
|
||||
)`,
|
||||
|
||||
// 索引
|
||||
`CREATE INDEX IF NOT EXISTS idx_decision_records_trader_time ON decision_records(trader_id, timestamp DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_decision_records_timestamp ON decision_records(timestamp DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_decision_actions_trader ON decision_actions(trader_id, timestamp DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_decision_actions_symbol ON decision_actions(symbol, timestamp DESC)`,
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
if _, err := s.db.Exec(query); err != nil {
|
||||
return fmt.Errorf("执行SQL失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogDecision 记录决策
|
||||
func (s *DecisionStore) LogDecision(record *DecisionRecord) error {
|
||||
if record.Timestamp.IsZero() {
|
||||
record.Timestamp = time.Now().UTC()
|
||||
} else {
|
||||
record.Timestamp = record.Timestamp.UTC()
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开始事务失败: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// 序列化候选币种和执行日志为 JSON
|
||||
candidateCoinsJSON, _ := json.Marshal(record.CandidateCoins)
|
||||
executionLogJSON, _ := json.Marshal(record.ExecutionLog)
|
||||
|
||||
// 插入决策记录主表
|
||||
result, err := tx.Exec(`
|
||||
INSERT INTO decision_records (
|
||||
trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
||||
cot_trace, decision_json, candidate_coins, execution_log,
|
||||
success, error_message, ai_request_duration_ms
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
record.TraderID, record.CycleNumber, record.Timestamp.Format(time.RFC3339),
|
||||
record.SystemPrompt, record.InputPrompt, record.CoTTrace, record.DecisionJSON,
|
||||
string(candidateCoinsJSON), string(executionLogJSON),
|
||||
record.Success, record.ErrorMessage, record.AIRequestDurationMs,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("插入决策记录失败: %w", err)
|
||||
}
|
||||
|
||||
decisionID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取决策ID失败: %w", err)
|
||||
}
|
||||
record.ID = decisionID
|
||||
|
||||
// 插入账户状态快照
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO decision_account_snapshots (
|
||||
decision_id, total_balance, available_balance, total_unrealized_profit,
|
||||
position_count, margin_used_pct, initial_balance
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
decisionID, record.AccountState.TotalBalance, record.AccountState.AvailableBalance,
|
||||
record.AccountState.TotalUnrealizedProfit, record.AccountState.PositionCount,
|
||||
record.AccountState.MarginUsedPct, record.AccountState.InitialBalance,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("插入账户快照失败: %w", err)
|
||||
}
|
||||
|
||||
// 插入持仓快照
|
||||
for _, pos := range record.Positions {
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO decision_position_snapshots (
|
||||
decision_id, symbol, side, position_amt, entry_price,
|
||||
mark_price, unrealized_profit, leverage, liquidation_price
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
decisionID, pos.Symbol, pos.Side, pos.PositionAmt, pos.EntryPrice,
|
||||
pos.MarkPrice, pos.UnrealizedProfit, pos.Leverage, pos.LiquidationPrice,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("插入持仓快照失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 插入决策动作(订单详情)
|
||||
for _, action := range record.Decisions {
|
||||
actionTimestamp := action.Timestamp
|
||||
if actionTimestamp.IsZero() {
|
||||
actionTimestamp = record.Timestamp
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO decision_actions (
|
||||
decision_id, trader_id, action, symbol, quantity, leverage,
|
||||
price, order_id, timestamp, success, error
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
decisionID, record.TraderID, action.Action, action.Symbol, action.Quantity,
|
||||
action.Leverage, action.Price, action.OrderID,
|
||||
actionTimestamp.Format(time.RFC3339), action.Success, action.Error,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("插入决策动作失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("提交事务失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLatestRecords 获取指定交易员最近N条记录(按时间正序:从旧到新)
|
||||
func (s *DecisionStore) GetLatestRecords(traderID string, n int) ([]*DecisionRecord, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
||||
cot_trace, decision_json, candidate_coins, execution_log,
|
||||
success, error_message, ai_request_duration_ms
|
||||
FROM decision_records
|
||||
WHERE trader_id = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`, traderID, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询决策记录失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []*DecisionRecord
|
||||
for rows.Next() {
|
||||
record, err := s.scanDecisionRecord(rows)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
// 填充关联数据
|
||||
for _, record := range records {
|
||||
s.fillRecordDetails(record)
|
||||
}
|
||||
|
||||
// 反转数组,让时间从旧到新排列
|
||||
for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
|
||||
records[i], records[j] = records[j], records[i]
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// GetAllLatestRecords 获取所有交易员最近N条记录
|
||||
func (s *DecisionStore) GetAllLatestRecords(n int) ([]*DecisionRecord, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
||||
cot_trace, decision_json, candidate_coins, execution_log,
|
||||
success, error_message, ai_request_duration_ms
|
||||
FROM decision_records
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
`, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询决策记录失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []*DecisionRecord
|
||||
for rows.Next() {
|
||||
record, err := s.scanDecisionRecord(rows)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
// 反转数组
|
||||
for i, j := 0, len(records)-1; i < j; i, j = i+1, j-1 {
|
||||
records[i], records[j] = records[j], records[i]
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// GetRecordsByDate 获取指定交易员指定日期的所有记录
|
||||
func (s *DecisionStore) GetRecordsByDate(traderID string, date time.Time) ([]*DecisionRecord, error) {
|
||||
dateStr := date.Format("2006-01-02")
|
||||
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, cycle_number, timestamp, system_prompt, input_prompt,
|
||||
cot_trace, decision_json, candidate_coins, execution_log,
|
||||
success, error_message, ai_request_duration_ms
|
||||
FROM decision_records
|
||||
WHERE trader_id = ? AND DATE(timestamp) = ?
|
||||
ORDER BY timestamp ASC
|
||||
`, traderID, dateStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询决策记录失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []*DecisionRecord
|
||||
for rows.Next() {
|
||||
record, err := s.scanDecisionRecord(rows)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
records = append(records, record)
|
||||
}
|
||||
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// CleanOldRecords 清理N天前的旧记录
|
||||
func (s *DecisionStore) CleanOldRecords(traderID string, days int) (int64, error) {
|
||||
cutoffTime := time.Now().AddDate(0, 0, -days).Format(time.RFC3339)
|
||||
|
||||
result, err := s.db.Exec(`
|
||||
DELETE FROM decision_records
|
||||
WHERE trader_id = ? AND timestamp < ?
|
||||
`, traderID, cutoffTime)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("清理旧记录失败: %w", err)
|
||||
}
|
||||
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// GetStatistics 获取指定交易员的统计信息
|
||||
func (s *DecisionStore) GetStatistics(traderID string) (*Statistics, error) {
|
||||
stats := &Statistics{}
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM decision_records WHERE trader_id = ?
|
||||
`, traderID).Scan(&stats.TotalCycles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询总周期数失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM decision_records WHERE trader_id = ? AND success = 1
|
||||
`, traderID).Scan(&stats.SuccessfulCycles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询成功周期数失败: %w", err)
|
||||
}
|
||||
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
|
||||
|
||||
err = s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM decision_actions
|
||||
WHERE trader_id = ? AND success = 1 AND action IN ('open_long', 'open_short')
|
||||
`, traderID).Scan(&stats.TotalOpenPositions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询开仓次数失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM decision_actions
|
||||
WHERE trader_id = ? AND success = 1 AND action IN ('close_long', 'close_short', 'auto_close_long', 'auto_close_short')
|
||||
`, traderID).Scan(&stats.TotalClosePositions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询平仓次数失败: %w", err)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetAllStatistics 获取所有交易员的统计信息
|
||||
func (s *DecisionStore) GetAllStatistics() (*Statistics, error) {
|
||||
stats := &Statistics{}
|
||||
|
||||
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records`).Scan(&stats.TotalCycles)
|
||||
s.db.QueryRow(`SELECT COUNT(*) FROM decision_records WHERE success = 1`).Scan(&stats.SuccessfulCycles)
|
||||
stats.FailedCycles = stats.TotalCycles - stats.SuccessfulCycles
|
||||
|
||||
s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM decision_actions
|
||||
WHERE success = 1 AND action IN ('open_long', 'open_short')
|
||||
`).Scan(&stats.TotalOpenPositions)
|
||||
|
||||
s.db.QueryRow(`
|
||||
SELECT COUNT(*) FROM decision_actions
|
||||
WHERE success = 1 AND action IN ('close_long', 'close_short', 'auto_close_long', 'auto_close_short')
|
||||
`).Scan(&stats.TotalClosePositions)
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetLastCycleNumber 获取指定交易员的最后周期编号
|
||||
func (s *DecisionStore) GetLastCycleNumber(traderID string) (int, error) {
|
||||
var cycleNumber int
|
||||
err := s.db.QueryRow(`
|
||||
SELECT COALESCE(MAX(cycle_number), 0) FROM decision_records WHERE trader_id = ?
|
||||
`, traderID).Scan(&cycleNumber)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return cycleNumber, nil
|
||||
}
|
||||
|
||||
// scanDecisionRecord 从行中扫描决策记录
|
||||
func (s *DecisionStore) scanDecisionRecord(rows *sql.Rows) (*DecisionRecord, error) {
|
||||
var record DecisionRecord
|
||||
var timestampStr string
|
||||
var candidateCoinsJSON, executionLogJSON string
|
||||
|
||||
err := rows.Scan(
|
||||
&record.ID, &record.TraderID, &record.CycleNumber, ×tampStr,
|
||||
&record.SystemPrompt, &record.InputPrompt, &record.CoTTrace,
|
||||
&record.DecisionJSON, &candidateCoinsJSON, &executionLogJSON,
|
||||
&record.Success, &record.ErrorMessage, &record.AIRequestDurationMs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
record.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
||||
json.Unmarshal([]byte(candidateCoinsJSON), &record.CandidateCoins)
|
||||
json.Unmarshal([]byte(executionLogJSON), &record.ExecutionLog)
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// fillRecordDetails 填充决策记录的关联数据
|
||||
func (s *DecisionStore) fillRecordDetails(record *DecisionRecord) {
|
||||
// 查询账户状态
|
||||
s.db.QueryRow(`
|
||||
SELECT total_balance, available_balance, total_unrealized_profit,
|
||||
position_count, margin_used_pct, initial_balance
|
||||
FROM decision_account_snapshots
|
||||
WHERE decision_id = ?
|
||||
`, record.ID).Scan(
|
||||
&record.AccountState.TotalBalance,
|
||||
&record.AccountState.AvailableBalance,
|
||||
&record.AccountState.TotalUnrealizedProfit,
|
||||
&record.AccountState.PositionCount,
|
||||
&record.AccountState.MarginUsedPct,
|
||||
&record.AccountState.InitialBalance,
|
||||
)
|
||||
|
||||
// 查询持仓快照
|
||||
posRows, err := s.db.Query(`
|
||||
SELECT symbol, side, position_amt, entry_price, mark_price,
|
||||
unrealized_profit, leverage, liquidation_price
|
||||
FROM decision_position_snapshots
|
||||
WHERE decision_id = ?
|
||||
`, record.ID)
|
||||
if err == nil {
|
||||
defer posRows.Close()
|
||||
for posRows.Next() {
|
||||
var pos PositionSnapshot
|
||||
posRows.Scan(
|
||||
&pos.Symbol, &pos.Side, &pos.PositionAmt, &pos.EntryPrice,
|
||||
&pos.MarkPrice, &pos.UnrealizedProfit, &pos.Leverage,
|
||||
&pos.LiquidationPrice,
|
||||
)
|
||||
record.Positions = append(record.Positions, pos)
|
||||
}
|
||||
}
|
||||
|
||||
// 查询决策动作
|
||||
actionRows, err := s.db.Query(`
|
||||
SELECT action, symbol, quantity, leverage, price, order_id,
|
||||
timestamp, success, error
|
||||
FROM decision_actions
|
||||
WHERE decision_id = ?
|
||||
`, record.ID)
|
||||
if err == nil {
|
||||
defer actionRows.Close()
|
||||
for actionRows.Next() {
|
||||
var action DecisionAction
|
||||
var timestampStr string
|
||||
actionRows.Scan(
|
||||
&action.Action, &action.Symbol, &action.Quantity,
|
||||
&action.Leverage, &action.Price, &action.OrderID,
|
||||
×tampStr, &action.Success, &action.Error,
|
||||
)
|
||||
action.Timestamp, _ = time.Parse(time.RFC3339, timestampStr)
|
||||
record.Decisions = append(record.Decisions, action)
|
||||
}
|
||||
}
|
||||
}
|
||||
245
store/exchange.go
Normal file
245
store/exchange.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"nofx/logger"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExchangeStore 交易所存储
|
||||
type ExchangeStore struct {
|
||||
db *sql.DB
|
||||
encryptFunc func(string) string
|
||||
decryptFunc func(string) string
|
||||
}
|
||||
|
||||
// Exchange 交易所配置
|
||||
type Exchange struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
APIKey string `json:"apiKey"`
|
||||
SecretKey string `json:"secretKey"`
|
||||
Testnet bool `json:"testnet"`
|
||||
HyperliquidWalletAddr string `json:"hyperliquidWalletAddr"`
|
||||
AsterUser string `json:"asterUser"`
|
||||
AsterSigner string `json:"asterSigner"`
|
||||
AsterPrivateKey string `json:"asterPrivateKey"`
|
||||
LighterWalletAddr string `json:"lighterWalletAddr"`
|
||||
LighterPrivateKey string `json:"lighterPrivateKey"`
|
||||
LighterAPIKeyPrivateKey string `json:"lighterAPIKeyPrivateKey"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (s *ExchangeStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS exchanges (
|
||||
id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL DEFAULT 'default',
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
enabled BOOLEAN DEFAULT 0,
|
||||
api_key TEXT DEFAULT '',
|
||||
secret_key TEXT DEFAULT '',
|
||||
testnet BOOLEAN DEFAULT 0,
|
||||
hyperliquid_wallet_addr TEXT DEFAULT '',
|
||||
aster_user TEXT DEFAULT '',
|
||||
aster_signer TEXT DEFAULT '',
|
||||
aster_private_key TEXT DEFAULT '',
|
||||
lighter_wallet_addr TEXT DEFAULT '',
|
||||
lighter_private_key TEXT DEFAULT '',
|
||||
lighter_api_key_private_key TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (id, user_id),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 触发器
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_exchanges_updated_at
|
||||
AFTER UPDATE ON exchanges
|
||||
BEGIN
|
||||
UPDATE exchanges SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id AND user_id = NEW.user_id;
|
||||
END
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ExchangeStore) initDefaultData() error {
|
||||
exchanges := []struct {
|
||||
id, name, typ string
|
||||
}{
|
||||
{"binance", "Binance Futures", "binance"},
|
||||
{"bybit", "Bybit Futures", "bybit"},
|
||||
{"hyperliquid", "Hyperliquid", "hyperliquid"},
|
||||
{"aster", "Aster DEX", "aster"},
|
||||
{"lighter", "LIGHTER DEX", "lighter"},
|
||||
}
|
||||
|
||||
for _, exchange := range exchanges {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR IGNORE INTO exchanges (id, user_id, name, type, enabled)
|
||||
VALUES (?, 'default', ?, ?, 0)
|
||||
`, exchange.id, exchange.name, exchange.typ)
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化交易所失败: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ExchangeStore) encrypt(plaintext string) string {
|
||||
if s.encryptFunc != nil {
|
||||
return s.encryptFunc(plaintext)
|
||||
}
|
||||
return plaintext
|
||||
}
|
||||
|
||||
func (s *ExchangeStore) decrypt(encrypted string) string {
|
||||
if s.decryptFunc != nil {
|
||||
return s.decryptFunc(encrypted)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
|
||||
// List 获取用户的交易所列表
|
||||
func (s *ExchangeStore) List(userID string) ([]*Exchange, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, type, enabled, api_key, secret_key, testnet,
|
||||
COALESCE(hyperliquid_wallet_addr, '') as hyperliquid_wallet_addr,
|
||||
COALESCE(aster_user, '') as aster_user,
|
||||
COALESCE(aster_signer, '') as aster_signer,
|
||||
COALESCE(aster_private_key, '') as aster_private_key,
|
||||
COALESCE(lighter_wallet_addr, '') as lighter_wallet_addr,
|
||||
COALESCE(lighter_private_key, '') as lighter_private_key,
|
||||
COALESCE(lighter_api_key_private_key, '') as lighter_api_key_private_key,
|
||||
created_at, updated_at
|
||||
FROM exchanges WHERE user_id = ? ORDER BY id
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
exchanges := make([]*Exchange, 0)
|
||||
for rows.Next() {
|
||||
var e Exchange
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&e.ID, &e.UserID, &e.Name, &e.Type,
|
||||
&e.Enabled, &e.APIKey, &e.SecretKey, &e.Testnet,
|
||||
&e.HyperliquidWalletAddr, &e.AsterUser, &e.AsterSigner, &e.AsterPrivateKey,
|
||||
&e.LighterWalletAddr, &e.LighterPrivateKey, &e.LighterAPIKeyPrivateKey,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
e.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
e.APIKey = s.decrypt(e.APIKey)
|
||||
e.SecretKey = s.decrypt(e.SecretKey)
|
||||
e.AsterPrivateKey = s.decrypt(e.AsterPrivateKey)
|
||||
e.LighterPrivateKey = s.decrypt(e.LighterPrivateKey)
|
||||
e.LighterAPIKeyPrivateKey = s.decrypt(e.LighterAPIKeyPrivateKey)
|
||||
exchanges = append(exchanges, &e)
|
||||
}
|
||||
return exchanges, nil
|
||||
}
|
||||
|
||||
// Update 更新交易所配置
|
||||
func (s *ExchangeStore) Update(userID, id string, enabled bool, apiKey, secretKey string, testnet bool,
|
||||
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey, lighterWalletAddr, lighterPrivateKey string) error {
|
||||
|
||||
logger.Debugf("🔧 ExchangeStore.Update: userID=%s, id=%s, enabled=%v", userID, id, enabled)
|
||||
|
||||
setClauses := []string{
|
||||
"enabled = ?",
|
||||
"testnet = ?",
|
||||
"hyperliquid_wallet_addr = ?",
|
||||
"aster_user = ?",
|
||||
"aster_signer = ?",
|
||||
"lighter_wallet_addr = ?",
|
||||
"updated_at = datetime('now')",
|
||||
}
|
||||
args := []interface{}{enabled, testnet, hyperliquidWalletAddr, asterUser, asterSigner, lighterWalletAddr}
|
||||
|
||||
if apiKey != "" {
|
||||
setClauses = append(setClauses, "api_key = ?")
|
||||
args = append(args, s.encrypt(apiKey))
|
||||
}
|
||||
if secretKey != "" {
|
||||
setClauses = append(setClauses, "secret_key = ?")
|
||||
args = append(args, s.encrypt(secretKey))
|
||||
}
|
||||
if asterPrivateKey != "" {
|
||||
setClauses = append(setClauses, "aster_private_key = ?")
|
||||
args = append(args, s.encrypt(asterPrivateKey))
|
||||
}
|
||||
if lighterPrivateKey != "" {
|
||||
setClauses = append(setClauses, "lighter_private_key = ?")
|
||||
args = append(args, s.encrypt(lighterPrivateKey))
|
||||
}
|
||||
|
||||
args = append(args, id, userID)
|
||||
query := fmt.Sprintf(`UPDATE exchanges SET %s WHERE id = ? AND user_id = ?`, strings.Join(setClauses, ", "))
|
||||
|
||||
result, err := s.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
// 创建新记录
|
||||
var name, typ string
|
||||
switch id {
|
||||
case "binance":
|
||||
name, typ = "Binance Futures", "cex"
|
||||
case "bybit":
|
||||
name, typ = "Bybit Futures", "cex"
|
||||
case "hyperliquid":
|
||||
name, typ = "Hyperliquid", "dex"
|
||||
case "aster":
|
||||
name, typ = "Aster DEX", "dex"
|
||||
case "lighter":
|
||||
name, typ = "LIGHTER DEX", "dex"
|
||||
default:
|
||||
name, typ = id+" Exchange", "cex"
|
||||
}
|
||||
|
||||
_, err = s.db.Exec(`
|
||||
INSERT INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet,
|
||||
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
|
||||
lighter_wallet_addr, lighter_private_key, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))
|
||||
`, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet,
|
||||
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey),
|
||||
lighterWalletAddr, s.encrypt(lighterPrivateKey))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create 创建交易所配置
|
||||
func (s *ExchangeStore) Create(userID, id, name, typ string, enabled bool, apiKey, secretKey string, testnet bool,
|
||||
hyperliquidWalletAddr, asterUser, asterSigner, asterPrivateKey string) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR IGNORE INTO exchanges (id, user_id, name, type, enabled, api_key, secret_key, testnet,
|
||||
hyperliquid_wallet_addr, aster_user, aster_signer, aster_private_key,
|
||||
lighter_wallet_addr, lighter_private_key)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', '')
|
||||
`, id, userID, name, typ, enabled, s.encrypt(apiKey), s.encrypt(secretKey), testnet,
|
||||
hyperliquidWalletAddr, asterUser, asterSigner, s.encrypt(asterPrivateKey))
|
||||
return err
|
||||
}
|
||||
511
store/order.go
Normal file
511
store/order.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TraderOrder 交易员订单记录
|
||||
type TraderOrder struct {
|
||||
ID int64 `json:"id"`
|
||||
TraderID string `json:"trader_id"` // 交易员ID
|
||||
OrderID string `json:"order_id"` // 交易所订单ID
|
||||
ClientOrderID string `json:"client_order_id"` // 客户端订单ID
|
||||
Symbol string `json:"symbol"` // 交易对
|
||||
Side string `json:"side"` // BUY/SELL
|
||||
PositionSide string `json:"position_side"` // LONG/SHORT/BOTH
|
||||
Action string `json:"action"` // open_long/close_long/open_short/close_short
|
||||
OrderType string `json:"order_type"` // MARKET/LIMIT
|
||||
Quantity float64 `json:"quantity"` // 订单数量
|
||||
Price float64 `json:"price"` // 订单价格
|
||||
AvgPrice float64 `json:"avg_price"` // 实际成交均价
|
||||
ExecutedQty float64 `json:"executed_qty"` // 已成交数量
|
||||
Leverage int `json:"leverage"` // 杠杆倍数
|
||||
Status string `json:"status"` // NEW/FILLED/CANCELED/EXPIRED
|
||||
Fee float64 `json:"fee"` // 手续费
|
||||
FeeAsset string `json:"fee_asset"` // 手续费资产
|
||||
RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏(平仓时)
|
||||
EntryPrice float64 `json:"entry_price"` // 开仓价(平仓时记录)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
FilledAt time.Time `json:"filled_at"` // 成交时间
|
||||
}
|
||||
|
||||
// TraderStats 交易统计指标
|
||||
type TraderStats struct {
|
||||
TotalTrades int `json:"total_trades"` // 总交易数(已平仓)
|
||||
WinTrades int `json:"win_trades"` // 盈利交易数
|
||||
LossTrades int `json:"loss_trades"` // 亏损交易数
|
||||
WinRate float64 `json:"win_rate"` // 胜率 (%)
|
||||
ProfitFactor float64 `json:"profit_factor"` // 盈亏比
|
||||
SharpeRatio float64 `json:"sharpe_ratio"` // 夏普比
|
||||
TotalPnL float64 `json:"total_pnl"` // 总盈亏
|
||||
TotalFee float64 `json:"total_fee"` // 总手续费
|
||||
AvgWin float64 `json:"avg_win"` // 平均盈利
|
||||
AvgLoss float64 `json:"avg_loss"` // 平均亏损
|
||||
MaxDrawdownPct float64 `json:"max_drawdown_pct"` // 最大回撤 (%)
|
||||
}
|
||||
|
||||
// CompletedOrder 已完成订单(用于AI输入)
|
||||
type CompletedOrder struct {
|
||||
Symbol string `json:"symbol"` // 交易对
|
||||
Action string `json:"action"` // close_long/close_short
|
||||
Side string `json:"side"` // long/short
|
||||
Quantity float64 `json:"quantity"` // 数量
|
||||
EntryPrice float64 `json:"entry_price"` // 开仓价
|
||||
ExitPrice float64 `json:"exit_price"` // 平仓价
|
||||
RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏
|
||||
PnLPct float64 `json:"pnl_pct"` // 盈亏百分比
|
||||
Fee float64 `json:"fee"` // 手续费
|
||||
Leverage int `json:"leverage"` // 杠杆
|
||||
FilledAt time.Time `json:"filled_at"` // 成交时间
|
||||
}
|
||||
|
||||
// OrderStore 订单存储
|
||||
type OrderStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewOrderStore 创建订单存储实例
|
||||
func NewOrderStore(db *sql.DB) *OrderStore {
|
||||
return &OrderStore{db: db}
|
||||
}
|
||||
|
||||
// InitTables 初始化订单表
|
||||
func (s *OrderStore) InitTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS trader_orders (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trader_id TEXT NOT NULL,
|
||||
order_id TEXT NOT NULL,
|
||||
client_order_id TEXT DEFAULT '',
|
||||
symbol TEXT NOT NULL,
|
||||
side TEXT NOT NULL,
|
||||
position_side TEXT DEFAULT '',
|
||||
action TEXT NOT NULL,
|
||||
order_type TEXT DEFAULT 'MARKET',
|
||||
quantity REAL NOT NULL,
|
||||
price REAL DEFAULT 0,
|
||||
avg_price REAL DEFAULT 0,
|
||||
executed_qty REAL DEFAULT 0,
|
||||
leverage INTEGER DEFAULT 1,
|
||||
status TEXT DEFAULT 'NEW',
|
||||
fee REAL DEFAULT 0,
|
||||
fee_asset TEXT DEFAULT 'USDT',
|
||||
realized_pnl REAL DEFAULT 0,
|
||||
entry_price REAL DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
filled_at DATETIME,
|
||||
UNIQUE(trader_id, order_id)
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建trader_orders表失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建索引
|
||||
indices := []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_trader_orders_trader ON trader_orders(trader_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_trader_orders_status ON trader_orders(trader_id, status)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_trader_orders_symbol ON trader_orders(trader_id, symbol)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_trader_orders_filled ON trader_orders(trader_id, filled_at DESC)`,
|
||||
}
|
||||
for _, idx := range indices {
|
||||
if _, err := s.db.Exec(idx); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create 创建订单记录
|
||||
func (s *OrderStore) Create(order *TraderOrder) error {
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
result, err := s.db.Exec(`
|
||||
INSERT INTO trader_orders (
|
||||
trader_id, order_id, client_order_id, symbol, side, position_side,
|
||||
action, order_type, quantity, price, avg_price, executed_qty,
|
||||
leverage, status, fee, fee_asset, realized_pnl, entry_price,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
order.TraderID, order.OrderID, order.ClientOrderID, order.Symbol,
|
||||
order.Side, order.PositionSide, order.Action, order.OrderType,
|
||||
order.Quantity, order.Price, order.AvgPrice, order.ExecutedQty,
|
||||
order.Leverage, order.Status, order.Fee, order.FeeAsset,
|
||||
order.RealizedPnL, order.EntryPrice, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建订单记录失败: %w", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
order.ID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update 更新订单记录
|
||||
func (s *OrderStore) Update(order *TraderOrder) error {
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
filledAt := ""
|
||||
if !order.FilledAt.IsZero() {
|
||||
filledAt = order.FilledAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE trader_orders SET
|
||||
avg_price = ?, executed_qty = ?, status = ?, fee = ?,
|
||||
realized_pnl = ?, entry_price = ?, updated_at = ?, filled_at = ?
|
||||
WHERE trader_id = ? AND order_id = ?
|
||||
`,
|
||||
order.AvgPrice, order.ExecutedQty, order.Status, order.Fee,
|
||||
order.RealizedPnL, order.EntryPrice, now, filledAt,
|
||||
order.TraderID, order.OrderID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新订单记录失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByOrderID 根据订单ID获取订单
|
||||
func (s *OrderStore) GetByOrderID(traderID, orderID string) (*TraderOrder, error) {
|
||||
var order TraderOrder
|
||||
var createdAt, updatedAt, filledAt sql.NullString
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side,
|
||||
action, order_type, quantity, price, avg_price, executed_qty,
|
||||
leverage, status, fee, fee_asset, realized_pnl, entry_price,
|
||||
created_at, updated_at, filled_at
|
||||
FROM trader_orders WHERE trader_id = ? AND order_id = ?
|
||||
`, traderID, orderID).Scan(
|
||||
&order.ID, &order.TraderID, &order.OrderID, &order.ClientOrderID,
|
||||
&order.Symbol, &order.Side, &order.PositionSide, &order.Action,
|
||||
&order.OrderType, &order.Quantity, &order.Price, &order.AvgPrice,
|
||||
&order.ExecutedQty, &order.Leverage, &order.Status, &order.Fee,
|
||||
&order.FeeAsset, &order.RealizedPnL, &order.EntryPrice,
|
||||
&createdAt, &updatedAt, &filledAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
order.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String)
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
order.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
|
||||
}
|
||||
if filledAt.Valid {
|
||||
order.FilledAt, _ = time.Parse(time.RFC3339, filledAt.String)
|
||||
}
|
||||
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
// GetLatestOpenOrder 获取某币种最近的开仓订单(用于计算平仓盈亏)
|
||||
func (s *OrderStore) GetLatestOpenOrder(traderID, symbol, side string) (*TraderOrder, error) {
|
||||
// side: long -> 找 open_long, short -> 找 open_short
|
||||
action := "open_long"
|
||||
if side == "short" {
|
||||
action = "open_short"
|
||||
}
|
||||
|
||||
var order TraderOrder
|
||||
var createdAt, updatedAt, filledAt sql.NullString
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side,
|
||||
action, order_type, quantity, price, avg_price, executed_qty,
|
||||
leverage, status, fee, fee_asset, realized_pnl, entry_price,
|
||||
created_at, updated_at, filled_at
|
||||
FROM trader_orders
|
||||
WHERE trader_id = ? AND symbol = ? AND action = ? AND status = 'FILLED'
|
||||
ORDER BY filled_at DESC LIMIT 1
|
||||
`, traderID, symbol, action).Scan(
|
||||
&order.ID, &order.TraderID, &order.OrderID, &order.ClientOrderID,
|
||||
&order.Symbol, &order.Side, &order.PositionSide, &order.Action,
|
||||
&order.OrderType, &order.Quantity, &order.Price, &order.AvgPrice,
|
||||
&order.ExecutedQty, &order.Leverage, &order.Status, &order.Fee,
|
||||
&order.FeeAsset, &order.RealizedPnL, &order.EntryPrice,
|
||||
&createdAt, &updatedAt, &filledAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
order.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String)
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
order.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
|
||||
}
|
||||
if filledAt.Valid {
|
||||
order.FilledAt, _ = time.Parse(time.RFC3339, filledAt.String)
|
||||
}
|
||||
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
// GetRecentCompletedOrders 获取最近已完成的平仓订单
|
||||
func (s *OrderStore) GetRecentCompletedOrders(traderID string, limit int) ([]CompletedOrder, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT symbol, action, side, executed_qty, entry_price, avg_price,
|
||||
realized_pnl, fee, leverage, filled_at
|
||||
FROM trader_orders
|
||||
WHERE trader_id = ? AND status = 'FILLED'
|
||||
AND (action = 'close_long' OR action = 'close_short')
|
||||
ORDER BY filled_at DESC
|
||||
LIMIT ?
|
||||
`, traderID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询已完成订单失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var orders []CompletedOrder
|
||||
for rows.Next() {
|
||||
var o CompletedOrder
|
||||
var filledAt sql.NullString
|
||||
var side sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&o.Symbol, &o.Action, &side, &o.Quantity, &o.EntryPrice, &o.ExitPrice,
|
||||
&o.RealizedPnL, &o.Fee, &o.Leverage, &filledAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 根据action推断side
|
||||
if o.Action == "close_long" {
|
||||
o.Side = "long"
|
||||
} else if o.Action == "close_short" {
|
||||
o.Side = "short"
|
||||
} else if side.Valid {
|
||||
o.Side = side.String
|
||||
}
|
||||
|
||||
// 计算盈亏百分比
|
||||
if o.EntryPrice > 0 {
|
||||
if o.Side == "long" {
|
||||
o.PnLPct = (o.ExitPrice - o.EntryPrice) / o.EntryPrice * 100 * float64(o.Leverage)
|
||||
} else {
|
||||
o.PnLPct = (o.EntryPrice - o.ExitPrice) / o.EntryPrice * 100 * float64(o.Leverage)
|
||||
}
|
||||
}
|
||||
|
||||
if filledAt.Valid {
|
||||
o.FilledAt, _ = time.Parse(time.RFC3339, filledAt.String)
|
||||
}
|
||||
|
||||
orders = append(orders, o)
|
||||
}
|
||||
|
||||
return orders, nil
|
||||
}
|
||||
|
||||
// GetTraderStats 获取交易统计指标
|
||||
func (s *OrderStore) GetTraderStats(traderID string) (*TraderStats, error) {
|
||||
stats := &TraderStats{}
|
||||
|
||||
// 查询所有已完成的平仓订单
|
||||
rows, err := s.db.Query(`
|
||||
SELECT realized_pnl, fee, filled_at
|
||||
FROM trader_orders
|
||||
WHERE trader_id = ? AND status = 'FILLED'
|
||||
AND (action = 'close_long' OR action = 'close_short')
|
||||
ORDER BY filled_at ASC
|
||||
`, traderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询订单统计失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var pnls []float64
|
||||
var totalWin, totalLoss float64
|
||||
|
||||
for rows.Next() {
|
||||
var pnl, fee float64
|
||||
var filledAt sql.NullString
|
||||
if err := rows.Scan(&pnl, &fee, &filledAt); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
stats.TotalTrades++
|
||||
stats.TotalPnL += pnl
|
||||
stats.TotalFee += fee
|
||||
pnls = append(pnls, pnl)
|
||||
|
||||
if pnl > 0 {
|
||||
stats.WinTrades++
|
||||
totalWin += pnl
|
||||
} else if pnl < 0 {
|
||||
stats.LossTrades++
|
||||
totalLoss += math.Abs(pnl)
|
||||
}
|
||||
}
|
||||
|
||||
// 计算胜率
|
||||
if stats.TotalTrades > 0 {
|
||||
stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100
|
||||
}
|
||||
|
||||
// 计算盈亏比
|
||||
if totalLoss > 0 {
|
||||
stats.ProfitFactor = totalWin / totalLoss
|
||||
}
|
||||
|
||||
// 计算平均盈亏
|
||||
if stats.WinTrades > 0 {
|
||||
stats.AvgWin = totalWin / float64(stats.WinTrades)
|
||||
}
|
||||
if stats.LossTrades > 0 {
|
||||
stats.AvgLoss = totalLoss / float64(stats.LossTrades)
|
||||
}
|
||||
|
||||
// 计算夏普比(使用盈亏序列)
|
||||
if len(pnls) > 1 {
|
||||
stats.SharpeRatio = calculateSharpeRatio(pnls)
|
||||
}
|
||||
|
||||
// 计算最大回撤
|
||||
if len(pnls) > 0 {
|
||||
stats.MaxDrawdownPct = calculateMaxDrawdown(pnls)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// calculateSharpeRatio 计算夏普比
|
||||
func calculateSharpeRatio(pnls []float64) float64 {
|
||||
if len(pnls) < 2 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 计算平均收益
|
||||
var sum float64
|
||||
for _, pnl := range pnls {
|
||||
sum += pnl
|
||||
}
|
||||
mean := sum / float64(len(pnls))
|
||||
|
||||
// 计算标准差
|
||||
var variance float64
|
||||
for _, pnl := range pnls {
|
||||
variance += (pnl - mean) * (pnl - mean)
|
||||
}
|
||||
stdDev := math.Sqrt(variance / float64(len(pnls)-1))
|
||||
|
||||
if stdDev == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 夏普比 = 平均收益 / 标准差
|
||||
return mean / stdDev
|
||||
}
|
||||
|
||||
// calculateMaxDrawdown 计算最大回撤
|
||||
func calculateMaxDrawdown(pnls []float64) float64 {
|
||||
if len(pnls) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// 计算累计权益曲线
|
||||
var cumulative float64
|
||||
var peak float64
|
||||
var maxDD float64
|
||||
|
||||
for _, pnl := range pnls {
|
||||
cumulative += pnl
|
||||
if cumulative > peak {
|
||||
peak = cumulative
|
||||
}
|
||||
if peak > 0 {
|
||||
dd := (peak - cumulative) / peak * 100
|
||||
if dd > maxDD {
|
||||
maxDD = dd
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return maxDD
|
||||
}
|
||||
|
||||
// GetPendingOrders 获取未成交的订单(用于轮询)
|
||||
func (s *OrderStore) GetPendingOrders(traderID string) ([]*TraderOrder, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side,
|
||||
action, order_type, quantity, price, avg_price, executed_qty,
|
||||
leverage, status, fee, fee_asset, realized_pnl, entry_price,
|
||||
created_at, updated_at, filled_at
|
||||
FROM trader_orders
|
||||
WHERE trader_id = ? AND status = 'NEW'
|
||||
ORDER BY created_at ASC
|
||||
`, traderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询未成交订单失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanOrders(rows)
|
||||
}
|
||||
|
||||
// GetAllPendingOrders 获取所有未成交的订单(用于全局同步)
|
||||
func (s *OrderStore) GetAllPendingOrders() ([]*TraderOrder, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, order_id, client_order_id, symbol, side, position_side,
|
||||
action, order_type, quantity, price, avg_price, executed_qty,
|
||||
leverage, status, fee, fee_asset, realized_pnl, entry_price,
|
||||
created_at, updated_at, filled_at
|
||||
FROM trader_orders
|
||||
WHERE status = 'NEW'
|
||||
ORDER BY trader_id, created_at ASC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询未成交订单失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanOrders(rows)
|
||||
}
|
||||
|
||||
// scanOrders 扫描订单行到结构体
|
||||
func (s *OrderStore) scanOrders(rows *sql.Rows) ([]*TraderOrder, error) {
|
||||
var orders []*TraderOrder
|
||||
for rows.Next() {
|
||||
var order TraderOrder
|
||||
var createdAt, updatedAt, filledAt sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&order.ID, &order.TraderID, &order.OrderID, &order.ClientOrderID,
|
||||
&order.Symbol, &order.Side, &order.PositionSide, &order.Action,
|
||||
&order.OrderType, &order.Quantity, &order.Price, &order.AvgPrice,
|
||||
&order.ExecutedQty, &order.Leverage, &order.Status, &order.Fee,
|
||||
&order.FeeAsset, &order.RealizedPnL, &order.EntryPrice,
|
||||
&createdAt, &updatedAt, &filledAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if createdAt.Valid {
|
||||
order.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String)
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
order.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
|
||||
}
|
||||
if filledAt.Valid {
|
||||
order.FilledAt, _ = time.Parse(time.RFC3339, filledAt.String)
|
||||
}
|
||||
|
||||
orders = append(orders, &order)
|
||||
}
|
||||
|
||||
return orders, nil
|
||||
}
|
||||
473
store/position.go
Normal file
473
store/position.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TraderPosition 仓位记录(完整的开平仓追踪)
|
||||
type TraderPosition struct {
|
||||
ID int64 `json:"id"`
|
||||
TraderID string `json:"trader_id"`
|
||||
Symbol string `json:"symbol"`
|
||||
Side string `json:"side"` // LONG/SHORT
|
||||
Quantity float64 `json:"quantity"` // 开仓数量
|
||||
EntryPrice float64 `json:"entry_price"` // 开仓均价
|
||||
EntryOrderID string `json:"entry_order_id"` // 开仓订单ID
|
||||
EntryTime time.Time `json:"entry_time"` // 开仓时间
|
||||
ExitPrice float64 `json:"exit_price"` // 平仓均价
|
||||
ExitOrderID string `json:"exit_order_id"` // 平仓订单ID
|
||||
ExitTime *time.Time `json:"exit_time"` // 平仓时间
|
||||
RealizedPnL float64 `json:"realized_pnl"` // 已实现盈亏
|
||||
Fee float64 `json:"fee"` // 手续费
|
||||
Leverage int `json:"leverage"` // 杠杆倍数
|
||||
Status string `json:"status"` // OPEN/CLOSED
|
||||
CloseReason string `json:"close_reason"` // 平仓原因: ai_decision/manual/stop_loss/take_profit
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// PositionStore 仓位存储
|
||||
type PositionStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewPositionStore 创建仓位存储实例
|
||||
func NewPositionStore(db *sql.DB) *PositionStore {
|
||||
return &PositionStore{db: db}
|
||||
}
|
||||
|
||||
// InitTables 初始化仓位表
|
||||
func (s *PositionStore) InitTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS trader_positions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
trader_id TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
side TEXT NOT NULL,
|
||||
quantity REAL NOT NULL,
|
||||
entry_price REAL NOT NULL,
|
||||
entry_order_id TEXT DEFAULT '',
|
||||
entry_time DATETIME NOT NULL,
|
||||
exit_price REAL DEFAULT 0,
|
||||
exit_order_id TEXT DEFAULT '',
|
||||
exit_time DATETIME,
|
||||
realized_pnl REAL DEFAULT 0,
|
||||
fee REAL DEFAULT 0,
|
||||
leverage INTEGER DEFAULT 1,
|
||||
status TEXT DEFAULT 'OPEN',
|
||||
close_reason TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建trader_positions表失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建索引
|
||||
indices := []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_positions_trader ON trader_positions(trader_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_positions_status ON trader_positions(trader_id, status)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_positions_symbol ON trader_positions(trader_id, symbol, side, status)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_positions_entry ON trader_positions(trader_id, entry_time DESC)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_positions_exit ON trader_positions(trader_id, exit_time DESC)`,
|
||||
}
|
||||
for _, idx := range indices {
|
||||
if _, err := s.db.Exec(idx); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create 创建仓位记录(开仓时调用)
|
||||
func (s *PositionStore) Create(pos *TraderPosition) error {
|
||||
now := time.Now()
|
||||
pos.CreatedAt = now
|
||||
pos.UpdatedAt = now
|
||||
pos.Status = "OPEN"
|
||||
|
||||
result, err := s.db.Exec(`
|
||||
INSERT INTO trader_positions (
|
||||
trader_id, symbol, side, quantity, entry_price, entry_order_id,
|
||||
entry_time, leverage, status, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`,
|
||||
pos.TraderID, pos.Symbol, pos.Side, pos.Quantity, pos.EntryPrice,
|
||||
pos.EntryOrderID, pos.EntryTime.Format(time.RFC3339), pos.Leverage,
|
||||
pos.Status, now.Format(time.RFC3339), now.Format(time.RFC3339),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建仓位记录失败: %w", err)
|
||||
}
|
||||
|
||||
id, _ := result.LastInsertId()
|
||||
pos.ID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClosePosition 平仓(更新仓位记录)
|
||||
func (s *PositionStore) ClosePosition(id int64, exitPrice float64, exitOrderID string, realizedPnL float64, fee float64, closeReason string) error {
|
||||
now := time.Now()
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE trader_positions SET
|
||||
exit_price = ?, exit_order_id = ?, exit_time = ?,
|
||||
realized_pnl = ?, fee = ?, status = 'CLOSED',
|
||||
close_reason = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`,
|
||||
exitPrice, exitOrderID, now.Format(time.RFC3339),
|
||||
realizedPnL, fee, closeReason, now.Format(time.RFC3339), id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新仓位记录失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOpenPositions 获取所有未平仓位
|
||||
func (s *PositionStore) GetOpenPositions(traderID string) ([]*TraderPosition, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, symbol, side, quantity, entry_price, entry_order_id,
|
||||
entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee,
|
||||
leverage, status, close_reason, created_at, updated_at
|
||||
FROM trader_positions
|
||||
WHERE trader_id = ? AND status = 'OPEN'
|
||||
ORDER BY entry_time DESC
|
||||
`, traderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询未平仓位失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanPositions(rows)
|
||||
}
|
||||
|
||||
// GetOpenPositionBySymbol 获取指定币种方向的未平仓位
|
||||
func (s *PositionStore) GetOpenPositionBySymbol(traderID, symbol, side string) (*TraderPosition, error) {
|
||||
var pos TraderPosition
|
||||
var entryTime, exitTime, createdAt, updatedAt sql.NullString
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, trader_id, symbol, side, quantity, entry_price, entry_order_id,
|
||||
entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee,
|
||||
leverage, status, close_reason, created_at, updated_at
|
||||
FROM trader_positions
|
||||
WHERE trader_id = ? AND symbol = ? AND side = ? AND status = 'OPEN'
|
||||
ORDER BY entry_time DESC LIMIT 1
|
||||
`, traderID, symbol, side).Scan(
|
||||
&pos.ID, &pos.TraderID, &pos.Symbol, &pos.Side, &pos.Quantity,
|
||||
&pos.EntryPrice, &pos.EntryOrderID, &entryTime, &pos.ExitPrice,
|
||||
&pos.ExitOrderID, &exitTime, &pos.RealizedPnL, &pos.Fee,
|
||||
&pos.Leverage, &pos.Status, &pos.CloseReason, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.parsePositionTimes(&pos, entryTime, exitTime, createdAt, updatedAt)
|
||||
return &pos, nil
|
||||
}
|
||||
|
||||
// GetClosedPositions 获取已平仓位(历史记录)
|
||||
func (s *PositionStore) GetClosedPositions(traderID string, limit int) ([]*TraderPosition, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, symbol, side, quantity, entry_price, entry_order_id,
|
||||
entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee,
|
||||
leverage, status, close_reason, created_at, updated_at
|
||||
FROM trader_positions
|
||||
WHERE trader_id = ? AND status = 'CLOSED'
|
||||
ORDER BY exit_time DESC
|
||||
LIMIT ?
|
||||
`, traderID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询已平仓位失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanPositions(rows)
|
||||
}
|
||||
|
||||
// GetAllOpenPositions 获取所有trader的未平仓位(用于全局同步)
|
||||
func (s *PositionStore) GetAllOpenPositions() ([]*TraderPosition, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, trader_id, symbol, side, quantity, entry_price, entry_order_id,
|
||||
entry_time, exit_price, exit_order_id, exit_time, realized_pnl, fee,
|
||||
leverage, status, close_reason, created_at, updated_at
|
||||
FROM trader_positions
|
||||
WHERE status = 'OPEN'
|
||||
ORDER BY trader_id, entry_time DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询所有未平仓位失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return s.scanPositions(rows)
|
||||
}
|
||||
|
||||
// GetPositionStats 获取仓位统计(简单版)
|
||||
func (s *PositionStore) GetPositionStats(traderID string) (map[string]interface{}, error) {
|
||||
stats := make(map[string]interface{})
|
||||
|
||||
// 总交易数
|
||||
var totalTrades, winTrades int
|
||||
var totalPnL, totalFee float64
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN realized_pnl > 0 THEN 1 ELSE 0 END) as wins,
|
||||
COALESCE(SUM(realized_pnl), 0) as total_pnl,
|
||||
COALESCE(SUM(fee), 0) as total_fee
|
||||
FROM trader_positions
|
||||
WHERE trader_id = ? AND status = 'CLOSED'
|
||||
`, traderID).Scan(&totalTrades, &winTrades, &totalPnL, &totalFee)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats["total_trades"] = totalTrades
|
||||
stats["win_trades"] = winTrades
|
||||
stats["total_pnl"] = totalPnL
|
||||
stats["total_fee"] = totalFee
|
||||
if totalTrades > 0 {
|
||||
stats["win_rate"] = float64(winTrades) / float64(totalTrades) * 100
|
||||
} else {
|
||||
stats["win_rate"] = 0.0
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetFullStats 获取完整的交易统计(与 TraderStats 兼容)
|
||||
func (s *PositionStore) GetFullStats(traderID string) (*TraderStats, error) {
|
||||
stats := &TraderStats{}
|
||||
|
||||
// 查询所有已平仓位
|
||||
rows, err := s.db.Query(`
|
||||
SELECT realized_pnl, fee, exit_time
|
||||
FROM trader_positions
|
||||
WHERE trader_id = ? AND status = 'CLOSED'
|
||||
ORDER BY exit_time ASC
|
||||
`, traderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询仓位统计失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var pnls []float64
|
||||
var totalWin, totalLoss float64
|
||||
|
||||
for rows.Next() {
|
||||
var pnl, fee float64
|
||||
var exitTime sql.NullString
|
||||
if err := rows.Scan(&pnl, &fee, &exitTime); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
stats.TotalTrades++
|
||||
stats.TotalPnL += pnl
|
||||
stats.TotalFee += fee
|
||||
pnls = append(pnls, pnl)
|
||||
|
||||
if pnl > 0 {
|
||||
stats.WinTrades++
|
||||
totalWin += pnl
|
||||
} else if pnl < 0 {
|
||||
stats.LossTrades++
|
||||
totalLoss += -pnl // 转为正数
|
||||
}
|
||||
}
|
||||
|
||||
// 计算胜率
|
||||
if stats.TotalTrades > 0 {
|
||||
stats.WinRate = float64(stats.WinTrades) / float64(stats.TotalTrades) * 100
|
||||
}
|
||||
|
||||
// 计算盈亏比
|
||||
if totalLoss > 0 {
|
||||
stats.ProfitFactor = totalWin / totalLoss
|
||||
}
|
||||
|
||||
// 计算平均盈亏
|
||||
if stats.WinTrades > 0 {
|
||||
stats.AvgWin = totalWin / float64(stats.WinTrades)
|
||||
}
|
||||
if stats.LossTrades > 0 {
|
||||
stats.AvgLoss = totalLoss / float64(stats.LossTrades)
|
||||
}
|
||||
|
||||
// 计算夏普比
|
||||
if len(pnls) > 1 {
|
||||
stats.SharpeRatio = calculateSharpeRatioFromPnls(pnls)
|
||||
}
|
||||
|
||||
// 计算最大回撤
|
||||
if len(pnls) > 0 {
|
||||
stats.MaxDrawdownPct = calculateMaxDrawdownFromPnls(pnls)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// RecentTrade 最近的交易记录(用于AI输入)
|
||||
type RecentTrade struct {
|
||||
Symbol string `json:"symbol"`
|
||||
Side string `json:"side"` // long/short
|
||||
EntryPrice float64 `json:"entry_price"`
|
||||
ExitPrice float64 `json:"exit_price"`
|
||||
RealizedPnL float64 `json:"realized_pnl"`
|
||||
PnLPct float64 `json:"pnl_pct"`
|
||||
ExitTime string `json:"exit_time"`
|
||||
}
|
||||
|
||||
// GetRecentTrades 获取最近的已平仓交易
|
||||
func (s *PositionStore) GetRecentTrades(traderID string, limit int) ([]RecentTrade, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT symbol, side, entry_price, exit_price, realized_pnl, leverage, exit_time
|
||||
FROM trader_positions
|
||||
WHERE trader_id = ? AND status = 'CLOSED'
|
||||
ORDER BY exit_time DESC
|
||||
LIMIT ?
|
||||
`, traderID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询最近交易失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var trades []RecentTrade
|
||||
for rows.Next() {
|
||||
var t RecentTrade
|
||||
var leverage int
|
||||
var exitTime sql.NullString
|
||||
|
||||
err := rows.Scan(&t.Symbol, &t.Side, &t.EntryPrice, &t.ExitPrice, &t.RealizedPnL, &leverage, &exitTime)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 转换 side 格式
|
||||
if t.Side == "LONG" {
|
||||
t.Side = "long"
|
||||
} else if t.Side == "SHORT" {
|
||||
t.Side = "short"
|
||||
}
|
||||
|
||||
// 计算盈亏百分比
|
||||
if t.EntryPrice > 0 {
|
||||
if t.Side == "long" {
|
||||
t.PnLPct = (t.ExitPrice - t.EntryPrice) / t.EntryPrice * 100 * float64(leverage)
|
||||
} else {
|
||||
t.PnLPct = (t.EntryPrice - t.ExitPrice) / t.EntryPrice * 100 * float64(leverage)
|
||||
}
|
||||
}
|
||||
|
||||
// 格式化时间
|
||||
if exitTime.Valid {
|
||||
if parsed, err := time.Parse(time.RFC3339, exitTime.String); err == nil {
|
||||
t.ExitTime = parsed.Format("01-02 15:04")
|
||||
}
|
||||
}
|
||||
|
||||
trades = append(trades, t)
|
||||
}
|
||||
|
||||
return trades, nil
|
||||
}
|
||||
|
||||
// calculateSharpeRatioFromPnls 计算夏普比
|
||||
func calculateSharpeRatioFromPnls(pnls []float64) float64 {
|
||||
if len(pnls) < 2 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var sum float64
|
||||
for _, pnl := range pnls {
|
||||
sum += pnl
|
||||
}
|
||||
mean := sum / float64(len(pnls))
|
||||
|
||||
var variance float64
|
||||
for _, pnl := range pnls {
|
||||
variance += (pnl - mean) * (pnl - mean)
|
||||
}
|
||||
stdDev := math.Sqrt(variance / float64(len(pnls)-1))
|
||||
|
||||
if stdDev == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return mean / stdDev
|
||||
}
|
||||
|
||||
// calculateMaxDrawdownFromPnls 计算最大回撤
|
||||
func calculateMaxDrawdownFromPnls(pnls []float64) float64 {
|
||||
if len(pnls) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var cumulative, peak, maxDD float64
|
||||
for _, pnl := range pnls {
|
||||
cumulative += pnl
|
||||
if cumulative > peak {
|
||||
peak = cumulative
|
||||
}
|
||||
if peak > 0 {
|
||||
dd := (peak - cumulative) / peak * 100
|
||||
if dd > maxDD {
|
||||
maxDD = dd
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return maxDD
|
||||
}
|
||||
|
||||
// scanPositions 扫描仓位行到结构体
|
||||
func (s *PositionStore) scanPositions(rows *sql.Rows) ([]*TraderPosition, error) {
|
||||
var positions []*TraderPosition
|
||||
for rows.Next() {
|
||||
var pos TraderPosition
|
||||
var entryTime, exitTime, createdAt, updatedAt sql.NullString
|
||||
|
||||
err := rows.Scan(
|
||||
&pos.ID, &pos.TraderID, &pos.Symbol, &pos.Side, &pos.Quantity,
|
||||
&pos.EntryPrice, &pos.EntryOrderID, &entryTime, &pos.ExitPrice,
|
||||
&pos.ExitOrderID, &exitTime, &pos.RealizedPnL, &pos.Fee,
|
||||
&pos.Leverage, &pos.Status, &pos.CloseReason, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
s.parsePositionTimes(&pos, entryTime, exitTime, createdAt, updatedAt)
|
||||
positions = append(positions, &pos)
|
||||
}
|
||||
|
||||
return positions, nil
|
||||
}
|
||||
|
||||
// parsePositionTimes 解析时间字段
|
||||
func (s *PositionStore) parsePositionTimes(pos *TraderPosition, entryTime, exitTime, createdAt, updatedAt sql.NullString) {
|
||||
if entryTime.Valid {
|
||||
pos.EntryTime, _ = time.Parse(time.RFC3339, entryTime.String)
|
||||
}
|
||||
if exitTime.Valid {
|
||||
t, _ := time.Parse(time.RFC3339, exitTime.String)
|
||||
pos.ExitTime = &t
|
||||
}
|
||||
if createdAt.Valid {
|
||||
pos.CreatedAt, _ = time.Parse(time.RFC3339, createdAt.String)
|
||||
}
|
||||
if updatedAt.Valid {
|
||||
pos.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
|
||||
}
|
||||
}
|
||||
86
store/signal_source.go
Normal file
86
store/signal_source.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SignalSourceStore 信号源存储
|
||||
type SignalSourceStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// SignalSource 用户信号源配置
|
||||
type SignalSource struct {
|
||||
ID int `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
CoinPoolURL string `json:"coin_pool_url"`
|
||||
OITopURL string `json:"oi_top_url"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
func (s *SignalSourceStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS user_signal_sources (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL,
|
||||
coin_pool_url TEXT DEFAULT '',
|
||||
oi_top_url TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE,
|
||||
UNIQUE(user_id)
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 触发器
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_user_signal_sources_updated_at
|
||||
AFTER UPDATE ON user_signal_sources
|
||||
BEGIN
|
||||
UPDATE user_signal_sources SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// Create 创建信号源配置
|
||||
func (s *SignalSourceStore) Create(userID, coinPoolURL, oiTopURL string) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT OR REPLACE INTO user_signal_sources (user_id, coin_pool_url, oi_top_url, updated_at)
|
||||
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
|
||||
`, userID, coinPoolURL, oiTopURL)
|
||||
return err
|
||||
}
|
||||
|
||||
// Get 获取信号源配置
|
||||
func (s *SignalSourceStore) Get(userID string) (*SignalSource, error) {
|
||||
var source SignalSource
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, user_id, coin_pool_url, oi_top_url, created_at, updated_at
|
||||
FROM user_signal_sources WHERE user_id = ?
|
||||
`, userID).Scan(
|
||||
&source.ID, &source.UserID, &source.CoinPoolURL, &source.OITopURL,
|
||||
&createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
source.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
source.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &source, nil
|
||||
}
|
||||
|
||||
// Update 更新信号源配置
|
||||
func (s *SignalSourceStore) Update(userID, coinPoolURL, oiTopURL string) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE user_signal_sources SET coin_pool_url = ?, oi_top_url = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE user_id = ?
|
||||
`, coinPoolURL, oiTopURL, userID)
|
||||
return err
|
||||
}
|
||||
319
store/store.go
Normal file
319
store/store.go
Normal file
@@ -0,0 +1,319 @@
|
||||
// Package store 提供统一的数据库存储层
|
||||
// 所有数据库操作都应该通过这个包进行
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"nofx/logger"
|
||||
"sync"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// Store 统一的数据存储接口
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
|
||||
// 子存储(延迟初始化)
|
||||
user *UserStore
|
||||
aiModel *AIModelStore
|
||||
exchange *ExchangeStore
|
||||
trader *TraderStore
|
||||
systemConfig *SystemConfigStore
|
||||
betaCode *BetaCodeStore
|
||||
signalSource *SignalSourceStore
|
||||
decision *DecisionStore
|
||||
backtest *BacktestStore
|
||||
order *OrderStore
|
||||
position *PositionStore
|
||||
|
||||
// 加密函数
|
||||
encryptFunc func(string) string
|
||||
decryptFunc func(string) string
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// New 创建新的 Store 实例
|
||||
func New(dbPath string) (*Store, error) {
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// SQLite 配置
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
|
||||
// 启用外键约束
|
||||
if _, err := db.Exec(`PRAGMA foreign_keys = ON`); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("启用外键失败: %w", err)
|
||||
}
|
||||
|
||||
// 使用 DELETE 模式(传统模式)以确保 Docker bind mount 兼容性
|
||||
// 注意:WAL 模式在 macOS Docker 下会导致数据同步问题
|
||||
if _, err := db.Exec("PRAGMA journal_mode=DELETE"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("设置journal_mode失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置 synchronous=FULL
|
||||
if _, err := db.Exec("PRAGMA synchronous=FULL"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("设置synchronous失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置 busy_timeout
|
||||
if _, err := db.Exec("PRAGMA busy_timeout = 5000"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("设置busy_timeout失败: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{db: db}
|
||||
|
||||
// 初始化所有表结构
|
||||
if err := s.initTables(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("初始化表结构失败: %w", err)
|
||||
}
|
||||
|
||||
// 初始化默认数据
|
||||
if err := s.initDefaultData(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("初始化默认数据失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("✅ 数据库已启用 DELETE 模式和 FULL 同步")
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewFromDB 从现有数据库连接创建 Store
|
||||
func NewFromDB(db *sql.DB) *Store {
|
||||
return &Store{db: db}
|
||||
}
|
||||
|
||||
// SetCryptoFuncs 设置加密解密函数
|
||||
func (s *Store) SetCryptoFuncs(encrypt, decrypt func(string) string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.encryptFunc = encrypt
|
||||
s.decryptFunc = decrypt
|
||||
|
||||
// 更新已初始化的子存储
|
||||
if s.aiModel != nil {
|
||||
s.aiModel.encryptFunc = encrypt
|
||||
s.aiModel.decryptFunc = decrypt
|
||||
}
|
||||
if s.exchange != nil {
|
||||
s.exchange.encryptFunc = encrypt
|
||||
s.exchange.decryptFunc = decrypt
|
||||
}
|
||||
if s.trader != nil {
|
||||
s.trader.decryptFunc = decrypt
|
||||
}
|
||||
}
|
||||
|
||||
// initTables 初始化所有数据库表
|
||||
func (s *Store) initTables() error {
|
||||
// 按依赖顺序初始化
|
||||
if err := s.User().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化用户表失败: %w", err)
|
||||
}
|
||||
if err := s.AIModel().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化AI模型表失败: %w", err)
|
||||
}
|
||||
if err := s.Exchange().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化交易所表失败: %w", err)
|
||||
}
|
||||
if err := s.Trader().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化交易员表失败: %w", err)
|
||||
}
|
||||
if err := s.SystemConfig().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化系统配置表失败: %w", err)
|
||||
}
|
||||
if err := s.BetaCode().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化内测码表失败: %w", err)
|
||||
}
|
||||
if err := s.SignalSource().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化信号源表失败: %w", err)
|
||||
}
|
||||
if err := s.Decision().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化决策日志表失败: %w", err)
|
||||
}
|
||||
if err := s.Backtest().initTables(); err != nil {
|
||||
return fmt.Errorf("初始化回测表失败: %w", err)
|
||||
}
|
||||
if err := s.Order().InitTables(); err != nil {
|
||||
return fmt.Errorf("初始化订单表失败: %w", err)
|
||||
}
|
||||
if err := s.Position().InitTables(); err != nil {
|
||||
return fmt.Errorf("初始化仓位表失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// initDefaultData 初始化默认数据
|
||||
func (s *Store) initDefaultData() error {
|
||||
if err := s.AIModel().initDefaultData(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.Exchange().initDefaultData(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.SystemConfig().initDefaultData(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// User 获取用户存储
|
||||
func (s *Store) User() *UserStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.user == nil {
|
||||
s.user = &UserStore{db: s.db}
|
||||
}
|
||||
return s.user
|
||||
}
|
||||
|
||||
// AIModel 获取AI模型存储
|
||||
func (s *Store) AIModel() *AIModelStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.aiModel == nil {
|
||||
s.aiModel = &AIModelStore{
|
||||
db: s.db,
|
||||
encryptFunc: s.encryptFunc,
|
||||
decryptFunc: s.decryptFunc,
|
||||
}
|
||||
}
|
||||
return s.aiModel
|
||||
}
|
||||
|
||||
// Exchange 获取交易所存储
|
||||
func (s *Store) Exchange() *ExchangeStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.exchange == nil {
|
||||
s.exchange = &ExchangeStore{
|
||||
db: s.db,
|
||||
encryptFunc: s.encryptFunc,
|
||||
decryptFunc: s.decryptFunc,
|
||||
}
|
||||
}
|
||||
return s.exchange
|
||||
}
|
||||
|
||||
// Trader 获取交易员存储
|
||||
func (s *Store) Trader() *TraderStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.trader == nil {
|
||||
s.trader = &TraderStore{
|
||||
db: s.db,
|
||||
decryptFunc: s.decryptFunc,
|
||||
}
|
||||
}
|
||||
return s.trader
|
||||
}
|
||||
|
||||
// SystemConfig 获取系统配置存储
|
||||
func (s *Store) SystemConfig() *SystemConfigStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.systemConfig == nil {
|
||||
s.systemConfig = &SystemConfigStore{db: s.db}
|
||||
}
|
||||
return s.systemConfig
|
||||
}
|
||||
|
||||
// BetaCode 获取内测码存储
|
||||
func (s *Store) BetaCode() *BetaCodeStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.betaCode == nil {
|
||||
s.betaCode = &BetaCodeStore{db: s.db}
|
||||
}
|
||||
return s.betaCode
|
||||
}
|
||||
|
||||
// SignalSource 获取信号源存储
|
||||
func (s *Store) SignalSource() *SignalSourceStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.signalSource == nil {
|
||||
s.signalSource = &SignalSourceStore{db: s.db}
|
||||
}
|
||||
return s.signalSource
|
||||
}
|
||||
|
||||
// Decision 获取决策日志存储
|
||||
func (s *Store) Decision() *DecisionStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.decision == nil {
|
||||
s.decision = &DecisionStore{db: s.db}
|
||||
}
|
||||
return s.decision
|
||||
}
|
||||
|
||||
// Backtest 获取回测数据存储
|
||||
func (s *Store) Backtest() *BacktestStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.backtest == nil {
|
||||
s.backtest = &BacktestStore{db: s.db}
|
||||
}
|
||||
return s.backtest
|
||||
}
|
||||
|
||||
// Order 获取订单存储
|
||||
func (s *Store) Order() *OrderStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.order == nil {
|
||||
s.order = NewOrderStore(s.db)
|
||||
}
|
||||
return s.order
|
||||
}
|
||||
|
||||
// Position 获取仓位存储
|
||||
func (s *Store) Position() *PositionStore {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.position == nil {
|
||||
s.position = NewPositionStore(s.db)
|
||||
}
|
||||
return s.position
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
// DB 获取底层数据库连接(仅用于兼容旧代码,逐步废弃)
|
||||
// Deprecated: 使用 Store 的方法代替
|
||||
func (s *Store) DB() *sql.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// Transaction 执行事务
|
||||
func (s *Store) Transaction(fn func(tx *sql.Tx) error) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("开始事务失败: %w", err)
|
||||
}
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("提交事务失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
70
store/system_config.go
Normal file
70
store/system_config.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// SystemConfigStore 系统配置存储
|
||||
type SystemConfigStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (s *SystemConfigStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS system_config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 触发器
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_system_config_updated_at
|
||||
AFTER UPDATE ON system_config
|
||||
BEGIN
|
||||
UPDATE system_config SET updated_at = CURRENT_TIMESTAMP WHERE key = NEW.key;
|
||||
END
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SystemConfigStore) initDefaultData() error {
|
||||
configs := map[string]string{
|
||||
"beta_mode": "false",
|
||||
"api_server_port": "8080",
|
||||
"use_default_coins": "true",
|
||||
"default_coins": `["BTCUSDT","ETHUSDT","SOLUSDT","BNBUSDT","XRPUSDT","DOGEUSDT","ADAUSDT","HYPEUSDT"]`,
|
||||
"max_daily_loss": "10.0",
|
||||
"max_drawdown": "20.0",
|
||||
"stop_trading_minutes": "60",
|
||||
"btc_eth_leverage": "5",
|
||||
"altcoin_leverage": "5",
|
||||
"jwt_secret": "",
|
||||
"registration_enabled": "true",
|
||||
}
|
||||
|
||||
for key, value := range configs {
|
||||
_, err := s.db.Exec(`INSERT OR IGNORE INTO system_config (key, value) VALUES (?, ?)`, key, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 获取配置值
|
||||
func (s *SystemConfigStore) Get(key string) (string, error) {
|
||||
var value string
|
||||
err := s.db.QueryRow(`SELECT value FROM system_config WHERE key = ?`, key).Scan(&value)
|
||||
return value, err
|
||||
}
|
||||
|
||||
// Set 设置配置值
|
||||
func (s *SystemConfigStore) Set(key, value string) error {
|
||||
_, err := s.db.Exec(`INSERT OR REPLACE INTO system_config (key, value) VALUES (?, ?)`, key, value)
|
||||
return err
|
||||
}
|
||||
344
store/trader.go
Normal file
344
store/trader.go
Normal file
@@ -0,0 +1,344 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TraderStore 交易员存储
|
||||
type TraderStore struct {
|
||||
db *sql.DB
|
||||
decryptFunc func(string) string
|
||||
}
|
||||
|
||||
// Trader 交易员配置
|
||||
type Trader struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
AIModelID string `json:"ai_model_id"`
|
||||
ExchangeID string `json:"exchange_id"`
|
||||
InitialBalance float64 `json:"initial_balance"`
|
||||
ScanIntervalMinutes int `json:"scan_interval_minutes"`
|
||||
IsRunning bool `json:"is_running"`
|
||||
BTCETHLeverage int `json:"btc_eth_leverage"`
|
||||
AltcoinLeverage int `json:"altcoin_leverage"`
|
||||
TradingSymbols string `json:"trading_symbols"`
|
||||
UseCoinPool bool `json:"use_coin_pool"`
|
||||
UseOITop bool `json:"use_oi_top"`
|
||||
CustomPrompt string `json:"custom_prompt"`
|
||||
OverrideBasePrompt bool `json:"override_base_prompt"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
IsCrossMargin bool `json:"is_cross_margin"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// TraderFullConfig 交易员完整配置(包含AI模型和交易所)
|
||||
type TraderFullConfig struct {
|
||||
Trader *Trader
|
||||
AIModel *AIModel
|
||||
Exchange *Exchange
|
||||
}
|
||||
|
||||
func (s *TraderStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS traders (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL DEFAULT 'default',
|
||||
name TEXT NOT NULL,
|
||||
ai_model_id TEXT NOT NULL,
|
||||
exchange_id TEXT NOT NULL,
|
||||
initial_balance REAL NOT NULL,
|
||||
scan_interval_minutes INTEGER DEFAULT 3,
|
||||
is_running BOOLEAN DEFAULT 0,
|
||||
btc_eth_leverage INTEGER DEFAULT 5,
|
||||
altcoin_leverage INTEGER DEFAULT 5,
|
||||
trading_symbols TEXT DEFAULT '',
|
||||
use_coin_pool BOOLEAN DEFAULT 0,
|
||||
use_oi_top BOOLEAN DEFAULT 0,
|
||||
custom_prompt TEXT DEFAULT '',
|
||||
override_base_prompt BOOLEAN DEFAULT 0,
|
||||
system_prompt_template TEXT DEFAULT 'default',
|
||||
is_cross_margin BOOLEAN DEFAULT 1,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 触发器
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_traders_updated_at
|
||||
AFTER UPDATE ON traders
|
||||
BEGIN
|
||||
UPDATE traders SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 向后兼容
|
||||
alterQueries := []string{
|
||||
`ALTER TABLE traders ADD COLUMN custom_prompt TEXT DEFAULT ''`,
|
||||
`ALTER TABLE traders ADD COLUMN override_base_prompt BOOLEAN DEFAULT 0`,
|
||||
`ALTER TABLE traders ADD COLUMN is_cross_margin BOOLEAN DEFAULT 1`,
|
||||
`ALTER TABLE traders ADD COLUMN btc_eth_leverage INTEGER DEFAULT 5`,
|
||||
`ALTER TABLE traders ADD COLUMN altcoin_leverage INTEGER DEFAULT 5`,
|
||||
`ALTER TABLE traders ADD COLUMN trading_symbols TEXT DEFAULT ''`,
|
||||
`ALTER TABLE traders ADD COLUMN use_coin_pool BOOLEAN DEFAULT 0`,
|
||||
`ALTER TABLE traders ADD COLUMN use_oi_top BOOLEAN DEFAULT 0`,
|
||||
`ALTER TABLE traders ADD COLUMN system_prompt_template TEXT DEFAULT 'default'`,
|
||||
}
|
||||
for _, q := range alterQueries {
|
||||
s.db.Exec(q)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TraderStore) decrypt(encrypted string) string {
|
||||
if s.decryptFunc != nil {
|
||||
return s.decryptFunc(encrypted)
|
||||
}
|
||||
return encrypted
|
||||
}
|
||||
|
||||
// Create 创建交易员
|
||||
func (s *TraderStore) Create(trader *Trader) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO traders (id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes,
|
||||
is_running, btc_eth_leverage, altcoin_leverage, trading_symbols, use_coin_pool,
|
||||
use_oi_top, custom_prompt, override_base_prompt, system_prompt_template, is_cross_margin)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, trader.ID, trader.UserID, trader.Name, trader.AIModelID, trader.ExchangeID, trader.InitialBalance,
|
||||
trader.ScanIntervalMinutes, trader.IsRunning, trader.BTCETHLeverage, trader.AltcoinLeverage,
|
||||
trader.TradingSymbols, trader.UseCoinPool, trader.UseOITop, trader.CustomPrompt,
|
||||
trader.OverrideBasePrompt, trader.SystemPromptTemplate, trader.IsCrossMargin)
|
||||
return err
|
||||
}
|
||||
|
||||
// List 获取用户的交易员列表
|
||||
func (s *TraderStore) List(userID string) ([]*Trader, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, is_running,
|
||||
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
|
||||
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
|
||||
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
|
||||
COALESCE(is_cross_margin, 1), created_at, updated_at
|
||||
FROM traders WHERE user_id = ? ORDER BY created_at DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var traders []*Trader
|
||||
for rows.Next() {
|
||||
var t Trader
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID,
|
||||
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning,
|
||||
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
|
||||
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
|
||||
&t.SystemPromptTemplate, &t.IsCrossMargin, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
traders = append(traders, &t)
|
||||
}
|
||||
return traders, nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新交易员运行状态
|
||||
func (s *TraderStore) UpdateStatus(userID, id string, isRunning bool) error {
|
||||
_, err := s.db.Exec(`UPDATE traders SET is_running = ? WHERE id = ? AND user_id = ?`, isRunning, id, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// Update 更新交易员配置
|
||||
func (s *TraderStore) Update(trader *Trader) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE traders SET
|
||||
name = ?, ai_model_id = ?, exchange_id = ?, scan_interval_minutes = ?,
|
||||
btc_eth_leverage = ?, altcoin_leverage = ?, trading_symbols = ?,
|
||||
custom_prompt = ?, override_base_prompt = ?, system_prompt_template = ?,
|
||||
is_cross_margin = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ? AND user_id = ?
|
||||
`, trader.Name, trader.AIModelID, trader.ExchangeID, trader.ScanIntervalMinutes,
|
||||
trader.BTCETHLeverage, trader.AltcoinLeverage, trader.TradingSymbols,
|
||||
trader.CustomPrompt, trader.OverrideBasePrompt, trader.SystemPromptTemplate,
|
||||
trader.IsCrossMargin, trader.ID, trader.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateInitialBalance 更新初始余额
|
||||
func (s *TraderStore) UpdateInitialBalance(userID, id string, newBalance float64) error {
|
||||
_, err := s.db.Exec(`UPDATE traders SET initial_balance = ? WHERE id = ? AND user_id = ?`, newBalance, id, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateCustomPrompt 更新自定义提示词
|
||||
func (s *TraderStore) UpdateCustomPrompt(userID, id string, customPrompt string, overrideBase bool) error {
|
||||
_, err := s.db.Exec(`UPDATE traders SET custom_prompt = ?, override_base_prompt = ? WHERE id = ? AND user_id = ?`,
|
||||
customPrompt, overrideBase, id, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除交易员
|
||||
func (s *TraderStore) Delete(userID, id string) error {
|
||||
_, err := s.db.Exec(`DELETE FROM traders WHERE id = ? AND user_id = ?`, id, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetFullConfig 获取交易员完整配置
|
||||
func (s *TraderStore) GetFullConfig(userID, traderID string) (*TraderFullConfig, error) {
|
||||
var trader Trader
|
||||
var aiModel AIModel
|
||||
var exchange Exchange
|
||||
var traderCreatedAt, traderUpdatedAt string
|
||||
var aiModelCreatedAt, aiModelUpdatedAt string
|
||||
var exchangeCreatedAt, exchangeUpdatedAt string
|
||||
|
||||
err := s.db.QueryRow(`
|
||||
SELECT
|
||||
t.id, t.user_id, t.name, t.ai_model_id, t.exchange_id, t.initial_balance, t.scan_interval_minutes, t.is_running,
|
||||
COALESCE(t.btc_eth_leverage, 5), COALESCE(t.altcoin_leverage, 5), COALESCE(t.trading_symbols, ''),
|
||||
COALESCE(t.use_coin_pool, 0), COALESCE(t.use_oi_top, 0), COALESCE(t.custom_prompt, ''),
|
||||
COALESCE(t.override_base_prompt, 0), COALESCE(t.system_prompt_template, 'default'),
|
||||
COALESCE(t.is_cross_margin, 1), t.created_at, t.updated_at,
|
||||
a.id, a.user_id, a.name, a.provider, a.enabled, a.api_key,
|
||||
COALESCE(a.custom_api_url, ''), COALESCE(a.custom_model_name, ''), a.created_at, a.updated_at,
|
||||
e.id, e.user_id, e.name, e.type, e.enabled, e.api_key, e.secret_key, e.testnet,
|
||||
COALESCE(e.hyperliquid_wallet_addr, ''), COALESCE(e.aster_user, ''), COALESCE(e.aster_signer, ''),
|
||||
COALESCE(e.aster_private_key, ''), COALESCE(e.lighter_wallet_addr, ''), COALESCE(e.lighter_private_key, ''),
|
||||
COALESCE(e.lighter_api_key_private_key, ''), e.created_at, e.updated_at
|
||||
FROM traders t
|
||||
JOIN ai_models a ON t.ai_model_id = a.id AND t.user_id = a.user_id
|
||||
JOIN exchanges e ON t.exchange_id = e.id AND t.user_id = e.user_id
|
||||
WHERE t.id = ? AND t.user_id = ?
|
||||
`, traderID, userID).Scan(
|
||||
&trader.ID, &trader.UserID, &trader.Name, &trader.AIModelID, &trader.ExchangeID,
|
||||
&trader.InitialBalance, &trader.ScanIntervalMinutes, &trader.IsRunning,
|
||||
&trader.BTCETHLeverage, &trader.AltcoinLeverage, &trader.TradingSymbols,
|
||||
&trader.UseCoinPool, &trader.UseOITop, &trader.CustomPrompt, &trader.OverrideBasePrompt,
|
||||
&trader.SystemPromptTemplate, &trader.IsCrossMargin, &traderCreatedAt, &traderUpdatedAt,
|
||||
&aiModel.ID, &aiModel.UserID, &aiModel.Name, &aiModel.Provider, &aiModel.Enabled, &aiModel.APIKey,
|
||||
&aiModel.CustomAPIURL, &aiModel.CustomModelName, &aiModelCreatedAt, &aiModelUpdatedAt,
|
||||
&exchange.ID, &exchange.UserID, &exchange.Name, &exchange.Type, &exchange.Enabled,
|
||||
&exchange.APIKey, &exchange.SecretKey, &exchange.Testnet, &exchange.HyperliquidWalletAddr,
|
||||
&exchange.AsterUser, &exchange.AsterSigner, &exchange.AsterPrivateKey,
|
||||
&exchange.LighterWalletAddr, &exchange.LighterPrivateKey, &exchange.LighterAPIKeyPrivateKey,
|
||||
&exchangeCreatedAt, &exchangeUpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
trader.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", traderCreatedAt)
|
||||
trader.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", traderUpdatedAt)
|
||||
aiModel.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelCreatedAt)
|
||||
aiModel.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", aiModelUpdatedAt)
|
||||
exchange.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeCreatedAt)
|
||||
exchange.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", exchangeUpdatedAt)
|
||||
|
||||
// 解密
|
||||
aiModel.APIKey = s.decrypt(aiModel.APIKey)
|
||||
exchange.APIKey = s.decrypt(exchange.APIKey)
|
||||
exchange.SecretKey = s.decrypt(exchange.SecretKey)
|
||||
exchange.AsterPrivateKey = s.decrypt(exchange.AsterPrivateKey)
|
||||
exchange.LighterPrivateKey = s.decrypt(exchange.LighterPrivateKey)
|
||||
exchange.LighterAPIKeyPrivateKey = s.decrypt(exchange.LighterAPIKeyPrivateKey)
|
||||
|
||||
return &TraderFullConfig{
|
||||
Trader: &trader,
|
||||
AIModel: &aiModel,
|
||||
Exchange: &exchange,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetCustomCoins 获取所有交易员自定义币种
|
||||
func (s *TraderStore) GetCustomCoins() []string {
|
||||
var symbol string
|
||||
var symbols []string
|
||||
_ = s.db.QueryRow(`
|
||||
SELECT GROUP_CONCAT(trading_symbols, ',') as symbol
|
||||
FROM traders WHERE trading_symbols != ''
|
||||
`).Scan(&symbol)
|
||||
|
||||
// 如果没有自定义币种,返回默认币种
|
||||
if symbol == "" {
|
||||
var symbolJSON string
|
||||
_ = s.db.QueryRow(`SELECT value FROM system_config WHERE key = 'default_coins'`).Scan(&symbolJSON)
|
||||
if symbolJSON != "" {
|
||||
if err := json.Unmarshal([]byte(symbolJSON), &symbols); err != nil {
|
||||
logger.Warnf("⚠️ 解析default_coins配置失败: %v,使用硬编码默认值", err)
|
||||
symbols = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT"}
|
||||
}
|
||||
} else {
|
||||
symbols = []string{"BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT"}
|
||||
}
|
||||
return symbols
|
||||
}
|
||||
|
||||
// 处理并去重币种列表
|
||||
for _, s := range strings.Split(symbol, ",") {
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
coin := market.Normalize(s)
|
||||
if !slices.Contains(symbols, coin) {
|
||||
symbols = append(symbols, coin)
|
||||
}
|
||||
}
|
||||
return symbols
|
||||
}
|
||||
|
||||
// ListAll 获取所有用户的交易员列表
|
||||
func (s *TraderStore) ListAll() ([]*Trader, error) {
|
||||
rows, err := s.db.Query(`
|
||||
SELECT id, user_id, name, ai_model_id, exchange_id, initial_balance, scan_interval_minutes, is_running,
|
||||
COALESCE(btc_eth_leverage, 5), COALESCE(altcoin_leverage, 5), COALESCE(trading_symbols, ''),
|
||||
COALESCE(use_coin_pool, 0), COALESCE(use_oi_top, 0), COALESCE(custom_prompt, ''),
|
||||
COALESCE(override_base_prompt, 0), COALESCE(system_prompt_template, 'default'),
|
||||
COALESCE(is_cross_margin, 1), created_at, updated_at
|
||||
FROM traders ORDER BY created_at DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var traders []*Trader
|
||||
for rows.Next() {
|
||||
var t Trader
|
||||
var createdAt, updatedAt string
|
||||
err := rows.Scan(
|
||||
&t.ID, &t.UserID, &t.Name, &t.AIModelID, &t.ExchangeID,
|
||||
&t.InitialBalance, &t.ScanIntervalMinutes, &t.IsRunning,
|
||||
&t.BTCETHLeverage, &t.AltcoinLeverage, &t.TradingSymbols,
|
||||
&t.UseCoinPool, &t.UseOITop, &t.CustomPrompt, &t.OverrideBasePrompt,
|
||||
&t.SystemPromptTemplate, &t.IsCrossMargin, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
t.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
traders = append(traders, &t)
|
||||
}
|
||||
return traders, nil
|
||||
}
|
||||
164
store/user.go
Normal file
164
store/user.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base32"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UserStore 用户存储
|
||||
type UserStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// User 用户
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
PasswordHash string `json:"-"`
|
||||
OTPSecret string `json:"-"`
|
||||
OTPVerified bool `json:"otp_verified"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// GenerateOTPSecret 生成OTP密钥
|
||||
func GenerateOTPSecret() (string, error) {
|
||||
secret := make([]byte, 20)
|
||||
_, err := rand.Read(secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base32.StdEncoding.EncodeToString(secret), nil
|
||||
}
|
||||
|
||||
func (s *UserStore) initTables() error {
|
||||
_, err := s.db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
otp_secret TEXT,
|
||||
otp_verified BOOLEAN DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 触发器
|
||||
_, err = s.db.Exec(`
|
||||
CREATE TRIGGER IF NOT EXISTS update_users_updated_at
|
||||
AFTER UPDATE ON users
|
||||
BEGIN
|
||||
UPDATE users SET updated_at = CURRENT_TIMESTAMP WHERE id = NEW.id;
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (s *UserStore) Create(user *User) error {
|
||||
_, err := s.db.Exec(`
|
||||
INSERT INTO users (id, email, password_hash, otp_secret, otp_verified)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`, user.ID, user.Email, user.PasswordHash, user.OTPSecret, user.OTPVerified)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetByEmail 通过邮箱获取用户
|
||||
func (s *UserStore) GetByEmail(email string) (*User, error) {
|
||||
var user User
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
|
||||
FROM users WHERE email = ?
|
||||
`, email).Scan(
|
||||
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
|
||||
&user.OTPVerified, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByID 通过ID获取用户
|
||||
func (s *UserStore) GetByID(userID string) (*User, error) {
|
||||
var user User
|
||||
var createdAt, updatedAt string
|
||||
err := s.db.QueryRow(`
|
||||
SELECT id, email, password_hash, otp_secret, otp_verified, created_at, updated_at
|
||||
FROM users WHERE id = ?
|
||||
`, userID).Scan(
|
||||
&user.ID, &user.Email, &user.PasswordHash, &user.OTPSecret,
|
||||
&user.OTPVerified, &createdAt, &updatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
user.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetAllIDs 获取所有用户ID
|
||||
func (s *UserStore) GetAllIDs() ([]string, error) {
|
||||
rows, err := s.db.Query(`SELECT id FROM users ORDER BY id`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var userIDs []string
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return userIDs, nil
|
||||
}
|
||||
|
||||
// UpdateOTPVerified 更新OTP验证状态
|
||||
func (s *UserStore) UpdateOTPVerified(userID string, verified bool) error {
|
||||
_, err := s.db.Exec(`UPDATE users SET otp_verified = ? WHERE id = ?`, verified, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdatePassword 更新密码
|
||||
func (s *UserStore) UpdatePassword(userID, passwordHash string) error {
|
||||
_, err := s.db.Exec(`
|
||||
UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?
|
||||
`, passwordHash, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// EnsureAdmin 确保admin用户存在
|
||||
func (s *UserStore) EnsureAdmin() error {
|
||||
var count int
|
||||
err := s.db.QueryRow(`SELECT COUNT(*) FROM users WHERE id = 'admin'`).Scan(&count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
return s.Create(&User{
|
||||
ID: "admin",
|
||||
Email: "admin@localhost",
|
||||
PasswordHash: "",
|
||||
OTPSecret: "",
|
||||
OTPVerified: true,
|
||||
})
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"math"
|
||||
"math/big"
|
||||
"net/http"
|
||||
@@ -469,13 +469,13 @@ func (t *AsterTrader) GetBalance() (map[string]interface{}, error) {
|
||||
}
|
||||
|
||||
if !foundUSDT {
|
||||
log.Printf("⚠️ 未找到USDT资产记录!")
|
||||
logger.Infof("⚠️ 未找到USDT资产记录!")
|
||||
}
|
||||
|
||||
// 获取持仓计算保证金占用和真实未实现盈亏
|
||||
positions, err := t.GetPositions()
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 获取持仓信息失败: %v", err)
|
||||
logger.Infof("⚠️ 获取持仓信息失败: %v", err)
|
||||
// fallback: 无法获取持仓时使用简单计算
|
||||
return map[string]interface{}{
|
||||
"totalWalletBalance": crossWalletBalance,
|
||||
@@ -577,7 +577,7 @@ func (t *AsterTrader) GetPositions() ([]map[string]interface{}, error) {
|
||||
func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// 开仓前先取消所有挂单,防止残留挂单导致仓位叠加
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败(继续开仓): %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败(继续开仓): %v", err)
|
||||
}
|
||||
|
||||
// 先设置杠杆
|
||||
@@ -614,7 +614,7 @@ func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (m
|
||||
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
|
||||
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
|
||||
|
||||
log.Printf(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)",
|
||||
logger.Infof(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)",
|
||||
limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision)
|
||||
|
||||
params := map[string]interface{}{
|
||||
@@ -644,7 +644,7 @@ func (t *AsterTrader) OpenLong(symbol string, quantity float64, leverage int) (m
|
||||
func (t *AsterTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// 开仓前先取消所有挂单,防止残留挂单导致仓位叠加
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败(继续开仓): %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败(继续开仓): %v", err)
|
||||
}
|
||||
|
||||
// 先设置杠杆
|
||||
@@ -681,7 +681,7 @@ func (t *AsterTrader) OpenShort(symbol string, quantity float64, leverage int) (
|
||||
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
|
||||
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
|
||||
|
||||
log.Printf(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)",
|
||||
logger.Infof(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)",
|
||||
limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision)
|
||||
|
||||
params := map[string]interface{}{
|
||||
@@ -726,7 +726,7 @@ func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]int
|
||||
if quantity == 0 {
|
||||
return nil, fmt.Errorf("没有找到 %s 的多仓", symbol)
|
||||
}
|
||||
log.Printf(" 📊 获取到多仓数量: %.8f", quantity)
|
||||
logger.Infof(" 📊 获取到多仓数量: %.8f", quantity)
|
||||
}
|
||||
|
||||
price, err := t.GetMarketPrice(symbol)
|
||||
@@ -756,7 +756,7 @@ func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]int
|
||||
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
|
||||
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
|
||||
|
||||
log.Printf(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)",
|
||||
logger.Infof(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)",
|
||||
limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision)
|
||||
|
||||
params := map[string]interface{}{
|
||||
@@ -779,11 +779,11 @@ func (t *AsterTrader) CloseLong(symbol string, quantity float64) (map[string]int
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("✓ 平多仓成功: %s 数量: %s", symbol, qtyStr)
|
||||
logger.Infof("✓ 平多仓成功: %s 数量: %s", symbol, qtyStr)
|
||||
|
||||
// 平仓后取消该币种的所有挂单(止损止盈单)
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败: %v", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@@ -809,7 +809,7 @@ func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]in
|
||||
if quantity == 0 {
|
||||
return nil, fmt.Errorf("没有找到 %s 的空仓", symbol)
|
||||
}
|
||||
log.Printf(" 📊 获取到空仓数量: %.8f", quantity)
|
||||
logger.Infof(" 📊 获取到空仓数量: %.8f", quantity)
|
||||
}
|
||||
|
||||
price, err := t.GetMarketPrice(symbol)
|
||||
@@ -839,7 +839,7 @@ func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]in
|
||||
priceStr := t.formatFloatWithPrecision(formattedPrice, prec.PricePrecision)
|
||||
qtyStr := t.formatFloatWithPrecision(formattedQty, prec.QuantityPrecision)
|
||||
|
||||
log.Printf(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)",
|
||||
logger.Infof(" 📏 精度处理: 价格 %.8f -> %s (精度=%d), 数量 %.8f -> %s (精度=%d)",
|
||||
limitPrice, priceStr, prec.PricePrecision, quantity, qtyStr, prec.QuantityPrecision)
|
||||
|
||||
params := map[string]interface{}{
|
||||
@@ -862,11 +862,11 @@ func (t *AsterTrader) CloseShort(symbol string, quantity float64) (map[string]in
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Printf("✓ 平空仓成功: %s 数量: %s", symbol, qtyStr)
|
||||
logger.Infof("✓ 平空仓成功: %s 数量: %s", symbol, qtyStr)
|
||||
|
||||
// 平仓后取消该币种的所有挂单(止损止盈单)
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败: %v", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@@ -892,30 +892,30 @@ func (t *AsterTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
|
||||
// 如果错误表示无需更改,忽略错误
|
||||
if strings.Contains(err.Error(), "No need to change") ||
|
||||
strings.Contains(err.Error(), "Margin type cannot be changed") {
|
||||
log.Printf(" ✓ %s 仓位模式已是 %s 或有持仓无法更改", symbol, marginType)
|
||||
logger.Infof(" ✓ %s 仓位模式已是 %s 或有持仓无法更改", symbol, marginType)
|
||||
return nil
|
||||
}
|
||||
// 检测多资产模式(错误码 -4168)
|
||||
if strings.Contains(err.Error(), "Multi-Assets mode") ||
|
||||
strings.Contains(err.Error(), "-4168") ||
|
||||
strings.Contains(err.Error(), "4168") {
|
||||
log.Printf(" ⚠️ %s 检测到多资产模式,强制使用全仓模式", symbol)
|
||||
log.Printf(" 💡 提示:如需使用逐仓模式,请在交易所关闭多资产模式")
|
||||
logger.Infof(" ⚠️ %s 检测到多资产模式,强制使用全仓模式", symbol)
|
||||
logger.Infof(" 💡 提示:如需使用逐仓模式,请在交易所关闭多资产模式")
|
||||
return nil
|
||||
}
|
||||
// 检测统一账户 API
|
||||
if strings.Contains(err.Error(), "unified") ||
|
||||
strings.Contains(err.Error(), "portfolio") ||
|
||||
strings.Contains(err.Error(), "Portfolio") {
|
||||
log.Printf(" ❌ %s 检测到统一账户 API,无法进行合约交易", symbol)
|
||||
logger.Infof(" ❌ %s 检测到统一账户 API,无法进行合约交易", symbol)
|
||||
return fmt.Errorf("请使用「现货与合约交易」API 权限,不要使用「统一账户 API」")
|
||||
}
|
||||
log.Printf(" ⚠️ 设置仓位模式失败: %v", err)
|
||||
logger.Infof(" ⚠️ 设置仓位模式失败: %v", err)
|
||||
// 不返回错误,让交易继续
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf(" ✓ %s 仓位模式已设置为 %s", symbol, marginType)
|
||||
logger.Infof(" ✓ %s 仓位模式已设置为 %s", symbol, marginType)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1075,19 +1075,19 @@ func (t *AsterTrader) CancelStopLossOrders(symbol string) error {
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("订单ID %d: %v", int64(orderID), err)
|
||||
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
|
||||
log.Printf(" ⚠ 取消止损单失败: %s", errMsg)
|
||||
logger.Infof(" ⚠ 取消止损单失败: %s", errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
canceledCount++
|
||||
log.Printf(" ✓ 已取消止损单 (订单ID: %d, 类型: %s, 方向: %s)", int64(orderID), orderType, positionSide)
|
||||
logger.Infof(" ✓ 已取消止损单 (订单ID: %d, 类型: %s, 方向: %s)", int64(orderID), orderType, positionSide)
|
||||
}
|
||||
}
|
||||
|
||||
if canceledCount == 0 && len(cancelErrors) == 0 {
|
||||
log.Printf(" ℹ %s 没有止损单需要取消", symbol)
|
||||
logger.Infof(" ℹ %s 没有止损单需要取消", symbol)
|
||||
} else if canceledCount > 0 {
|
||||
log.Printf(" ✓ 已取消 %s 的 %d 个止损单", symbol, canceledCount)
|
||||
logger.Infof(" ✓ 已取消 %s 的 %d 个止损单", symbol, canceledCount)
|
||||
}
|
||||
|
||||
// 如果所有取消都失败了,返回错误
|
||||
@@ -1134,19 +1134,19 @@ func (t *AsterTrader) CancelTakeProfitOrders(symbol string) error {
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("订单ID %d: %v", int64(orderID), err)
|
||||
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
|
||||
log.Printf(" ⚠ 取消止盈单失败: %s", errMsg)
|
||||
logger.Infof(" ⚠ 取消止盈单失败: %s", errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
canceledCount++
|
||||
log.Printf(" ✓ 已取消止盈单 (订单ID: %d, 类型: %s, 方向: %s)", int64(orderID), orderType, positionSide)
|
||||
logger.Infof(" ✓ 已取消止盈单 (订单ID: %d, 类型: %s, 方向: %s)", int64(orderID), orderType, positionSide)
|
||||
}
|
||||
}
|
||||
|
||||
if canceledCount == 0 && len(cancelErrors) == 0 {
|
||||
log.Printf(" ℹ %s 没有止盈单需要取消", symbol)
|
||||
logger.Infof(" ℹ %s 没有止盈单需要取消", symbol)
|
||||
} else if canceledCount > 0 {
|
||||
log.Printf(" ✓ 已取消 %s 的 %d 个止盈单", symbol, canceledCount)
|
||||
logger.Infof(" ✓ 已取消 %s 的 %d 个止盈单", symbol, canceledCount)
|
||||
}
|
||||
|
||||
// 如果所有取消都失败了,返回错误
|
||||
@@ -1203,20 +1203,20 @@ func (t *AsterTrader) CancelStopOrders(symbol string) error {
|
||||
|
||||
_, err := t.request("DELETE", "/fapi/v3/order", cancelParams)
|
||||
if err != nil {
|
||||
log.Printf(" ⚠ 取消订单 %d 失败: %v", int64(orderID), err)
|
||||
logger.Infof(" ⚠ 取消订单 %d 失败: %v", int64(orderID), err)
|
||||
continue
|
||||
}
|
||||
|
||||
canceledCount++
|
||||
log.Printf(" ✓ 已取消 %s 的止盈/止损单 (订单ID: %d, 类型: %s)",
|
||||
logger.Infof(" ✓ 已取消 %s 的止盈/止损单 (订单ID: %d, 类型: %s)",
|
||||
symbol, int64(orderID), orderType)
|
||||
}
|
||||
}
|
||||
|
||||
if canceledCount == 0 {
|
||||
log.Printf(" ℹ %s 没有止盈/止损单需要取消", symbol)
|
||||
logger.Infof(" ℹ %s 没有止盈/止损单需要取消", symbol)
|
||||
} else {
|
||||
log.Printf(" ✓ 已取消 %s 的 %d 个止盈/止损单", symbol, canceledCount)
|
||||
logger.Infof(" ✓ 已取消 %s 的 %d 个止盈/止损单", symbol, canceledCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1230,3 +1230,52 @@ func (t *AsterTrader) FormatQuantity(symbol string, quantity float64) (string, e
|
||||
}
|
||||
return fmt.Sprintf("%v", formatted), nil
|
||||
}
|
||||
|
||||
// GetOrderStatus 获取订单状态
|
||||
func (t *AsterTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
|
||||
params := map[string]interface{}{
|
||||
"symbol": symbol,
|
||||
"orderId": orderID,
|
||||
}
|
||||
|
||||
body, err := t.request("GET", "/fapi/v3/order", params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取订单状态失败: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析订单响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 标准化返回字段
|
||||
response := map[string]interface{}{
|
||||
"orderId": result["orderId"],
|
||||
"symbol": result["symbol"],
|
||||
"status": result["status"],
|
||||
"side": result["side"],
|
||||
"type": result["type"],
|
||||
"time": result["time"],
|
||||
"updateTime": result["updateTime"],
|
||||
"commission": 0.0, // Aster 可能需要单独查询
|
||||
}
|
||||
|
||||
// 解析数值字段
|
||||
if avgPrice, ok := result["avgPrice"].(string); ok {
|
||||
if v, err := strconv.ParseFloat(avgPrice, 64); err == nil {
|
||||
response["avgPrice"] = v
|
||||
}
|
||||
} else if avgPrice, ok := result["avgPrice"].(float64); ok {
|
||||
response["avgPrice"] = avgPrice
|
||||
}
|
||||
|
||||
if executedQty, ok := result["executedQty"].(string); ok {
|
||||
if v, err := strconv.ParseFloat(executedQty, 64); err == nil {
|
||||
response["executedQty"] = v
|
||||
}
|
||||
} else if executedQty, ok := result["executedQty"].(float64); ok {
|
||||
response["executedQty"] = executedQty
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,9 +8,9 @@ import (
|
||||
"time"
|
||||
|
||||
"nofx/decision"
|
||||
"nofx/logger"
|
||||
"nofx/market"
|
||||
"nofx/pool"
|
||||
"nofx/store"
|
||||
|
||||
"github.com/agiledragon/gomonkey/v2"
|
||||
"github.com/stretchr/testify/suite"
|
||||
@@ -30,8 +30,7 @@ type AutoTraderTestSuite struct {
|
||||
|
||||
// Mock 依赖
|
||||
mockTrader *MockTrader
|
||||
mockDB *MockDatabase
|
||||
mockLogger logger.IDecisionLogger
|
||||
mockStore *store.Store
|
||||
|
||||
// gomonkey patches
|
||||
patches *gomonkey.Patches
|
||||
@@ -65,10 +64,9 @@ func (s *AutoTraderTestSuite) SetupTest() {
|
||||
positions: []map[string]interface{}{},
|
||||
}
|
||||
|
||||
s.mockDB = &MockDatabase{}
|
||||
|
||||
// 创建临时决策日志记录器
|
||||
s.mockLogger = logger.NewDecisionLogger("/tmp/test_decision_logs")
|
||||
// 创建临时store(使用nil表示测试中不需要实际的store)
|
||||
s.mockStore = nil
|
||||
|
||||
// 设置默认配置
|
||||
s.config = AutoTraderConfig{
|
||||
@@ -93,7 +91,7 @@ func (s *AutoTraderTestSuite) SetupTest() {
|
||||
config: s.config,
|
||||
trader: s.mockTrader,
|
||||
mcpClient: nil, // 测试中不需要实际的 MCP Client
|
||||
decisionLogger: s.mockLogger,
|
||||
store: s.mockStore,
|
||||
initialBalance: s.config.InitialBalance,
|
||||
systemPromptTemplate: s.config.SystemPromptTemplate,
|
||||
defaultCoins: []string{"BTC", "ETH"},
|
||||
@@ -106,7 +104,6 @@ func (s *AutoTraderTestSuite) SetupTest() {
|
||||
stopMonitorCh: make(chan struct{}),
|
||||
peakPnLCache: make(map[string]float64),
|
||||
lastBalanceSyncTime: time.Now(),
|
||||
database: s.mockDB,
|
||||
userID: "test_user",
|
||||
}
|
||||
}
|
||||
@@ -134,9 +131,8 @@ func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() {
|
||||
{Action: "open_long", Symbol: "BTCUSDT"},
|
||||
{Action: "close_short", Symbol: "ETHUSDT"},
|
||||
{Action: "hold", Symbol: "BNBUSDT"},
|
||||
{Action: "update_stop_loss", Symbol: "SOLUSDT"},
|
||||
{Action: "open_short", Symbol: "ADAUSDT"},
|
||||
{Action: "partial_close", Symbol: "DOGEUSDT"},
|
||||
{Action: "close_long", Symbol: "DOGEUSDT"},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -150,14 +146,12 @@ func (s *AutoTraderTestSuite) TestSortDecisionsByPriority() {
|
||||
// 验证优先级是否递增
|
||||
getActionPriority := func(action string) int {
|
||||
switch action {
|
||||
case "close_long", "close_short", "partial_close":
|
||||
case "close_long", "close_short":
|
||||
return 1
|
||||
case "update_stop_loss", "update_take_profit":
|
||||
return 2
|
||||
case "open_long", "open_short":
|
||||
return 3
|
||||
return 2
|
||||
case "hold", "wait":
|
||||
return 4
|
||||
return 3
|
||||
default:
|
||||
return 999
|
||||
}
|
||||
@@ -413,14 +407,14 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
|
||||
existingSide string
|
||||
availBalance float64
|
||||
expectedErr string
|
||||
executeFn func(*decision.Decision, *logger.DecisionAction) error
|
||||
executeFn func(*decision.Decision, *store.DecisionAction) error
|
||||
}{
|
||||
{
|
||||
name: "成功开多仓",
|
||||
action: "open_long",
|
||||
expectedOrder: 123456,
|
||||
availBalance: 8000.0,
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
|
||||
return s.autoTrader.executeOpenLongWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
@@ -429,7 +423,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
|
||||
action: "open_short",
|
||||
expectedOrder: 123457,
|
||||
availBalance: 8000.0,
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
|
||||
return s.autoTrader.executeOpenShortWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
@@ -438,7 +432,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
|
||||
action: "open_long",
|
||||
availBalance: 0.0,
|
||||
expectedErr: "保证金不足",
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
|
||||
return s.autoTrader.executeOpenLongWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
@@ -447,7 +441,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
|
||||
action: "open_short",
|
||||
availBalance: 0.0,
|
||||
expectedErr: "保证金不足",
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
|
||||
return s.autoTrader.executeOpenShortWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
@@ -457,7 +451,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
|
||||
existingSide: "long",
|
||||
availBalance: 8000.0,
|
||||
expectedErr: "已有多仓",
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
|
||||
return s.autoTrader.executeOpenLongWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
@@ -467,7 +461,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
|
||||
existingSide: "short",
|
||||
availBalance: 8000.0,
|
||||
expectedErr: "已有空仓",
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
|
||||
return s.autoTrader.executeOpenShortWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
@@ -488,7 +482,7 @@ func (s *AutoTraderTestSuite) TestExecuteOpenPosition() {
|
||||
}
|
||||
|
||||
decision := &decision.Decision{Action: tt.action, Symbol: "BTCUSDT", PositionSizeUSD: 1000.0, Leverage: 10}
|
||||
actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"}
|
||||
actionRecord := &store.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"}
|
||||
|
||||
err := tt.executeFn(decision, actionRecord)
|
||||
|
||||
@@ -516,14 +510,14 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() {
|
||||
action string
|
||||
currentPrice float64
|
||||
expectedOrder int64
|
||||
executeFn func(*decision.Decision, *logger.DecisionAction) error
|
||||
executeFn func(*decision.Decision, *store.DecisionAction) error
|
||||
}{
|
||||
{
|
||||
name: "成功平多仓",
|
||||
action: "close_long",
|
||||
currentPrice: 51000.0,
|
||||
expectedOrder: 123458,
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
|
||||
return s.autoTrader.executeCloseLongWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
@@ -532,7 +526,7 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() {
|
||||
action: "close_short",
|
||||
currentPrice: 49000.0,
|
||||
expectedOrder: 123459,
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
executeFn: func(d *decision.Decision, a *store.DecisionAction) error {
|
||||
return s.autoTrader.executeCloseShortWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
@@ -546,7 +540,7 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() {
|
||||
})
|
||||
|
||||
decision := &decision.Decision{Action: tt.action, Symbol: "BTCUSDT"}
|
||||
actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"}
|
||||
actionRecord := &store.DecisionAction{Action: tt.action, Symbol: "BTCUSDT"}
|
||||
|
||||
err := tt.executeFn(decision, actionRecord)
|
||||
|
||||
@@ -557,221 +551,6 @@ func (s *AutoTraderTestSuite) TestExecuteClosePosition() {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExecuteUpdateStopOrTakeProfit 测试更新止损/止盈(多空通用)
|
||||
func (s *AutoTraderTestSuite) TestExecuteUpdateStopOrTakeProfit() {
|
||||
// 使用指针变量来控制 market.Get 的返回值
|
||||
var testPrice *float64
|
||||
s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) {
|
||||
price := 50000.0
|
||||
if testPrice != nil {
|
||||
price = *testPrice
|
||||
}
|
||||
return &market.Data{Symbol: symbol, CurrentPrice: price}, nil
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
action string
|
||||
symbol string
|
||||
side string
|
||||
currentPrice float64
|
||||
newPrice float64
|
||||
hasPosition bool
|
||||
expectedErr string
|
||||
executeFn func(*decision.Decision, *logger.DecisionAction) error
|
||||
}{
|
||||
{
|
||||
name: "成功更新多头止损",
|
||||
action: "update_stop_loss",
|
||||
symbol: "BTCUSDT",
|
||||
side: "long",
|
||||
currentPrice: 52000.0,
|
||||
newPrice: 51000.0,
|
||||
hasPosition: true,
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
return s.autoTrader.executeUpdateStopLossWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "成功更新空头止损",
|
||||
action: "update_stop_loss",
|
||||
symbol: "ETHUSDT",
|
||||
side: "short",
|
||||
currentPrice: 2900.0,
|
||||
newPrice: 2950.0,
|
||||
hasPosition: true,
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
return s.autoTrader.executeUpdateStopLossWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "成功更新多头止盈",
|
||||
action: "update_take_profit",
|
||||
symbol: "BTCUSDT",
|
||||
side: "long",
|
||||
currentPrice: 52000.0,
|
||||
newPrice: 55000.0,
|
||||
hasPosition: true,
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "成功更新空头止盈",
|
||||
action: "update_take_profit",
|
||||
symbol: "ETHUSDT",
|
||||
side: "short",
|
||||
currentPrice: 2900.0,
|
||||
newPrice: 2800.0,
|
||||
hasPosition: true,
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "多头止损价格不合理",
|
||||
action: "update_stop_loss",
|
||||
symbol: "BTCUSDT",
|
||||
side: "long",
|
||||
currentPrice: 50000.0,
|
||||
newPrice: 51000.0,
|
||||
hasPosition: true,
|
||||
expectedErr: "多单止损必须低于当前价格",
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
return s.autoTrader.executeUpdateStopLossWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "多头止盈价格不合理",
|
||||
action: "update_take_profit",
|
||||
symbol: "BTCUSDT",
|
||||
side: "long",
|
||||
currentPrice: 50000.0,
|
||||
newPrice: 49000.0,
|
||||
hasPosition: true,
|
||||
expectedErr: "多单止盈必须高于当前价格",
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "止损_持仓不存在",
|
||||
action: "update_stop_loss",
|
||||
symbol: "BTCUSDT",
|
||||
currentPrice: 50000.0,
|
||||
newPrice: 49000.0,
|
||||
hasPosition: false,
|
||||
expectedErr: "持仓不存在",
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
return s.autoTrader.executeUpdateStopLossWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "止盈_持仓不存在",
|
||||
action: "update_take_profit",
|
||||
symbol: "BTCUSDT",
|
||||
currentPrice: 50000.0,
|
||||
newPrice: 55000.0,
|
||||
hasPosition: false,
|
||||
expectedErr: "持仓不存在",
|
||||
executeFn: func(d *decision.Decision, a *logger.DecisionAction) error {
|
||||
return s.autoTrader.executeUpdateTakeProfitWithRecord(d, a)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
time.Sleep(time.Millisecond)
|
||||
s.Run(tt.name, func() {
|
||||
// 设置当前测试用例的价格
|
||||
testPrice = &tt.currentPrice
|
||||
|
||||
if tt.hasPosition {
|
||||
s.mockTrader.positions = []map[string]interface{}{
|
||||
{"symbol": tt.symbol, "side": tt.side, "positionAmt": 0.1},
|
||||
}
|
||||
} else {
|
||||
s.mockTrader.positions = []map[string]interface{}{}
|
||||
}
|
||||
|
||||
decision := &decision.Decision{Action: tt.action, Symbol: tt.symbol}
|
||||
if tt.action == "update_stop_loss" {
|
||||
decision.NewStopLoss = tt.newPrice
|
||||
} else {
|
||||
decision.NewTakeProfit = tt.newPrice
|
||||
}
|
||||
actionRecord := &logger.DecisionAction{Action: tt.action, Symbol: tt.symbol}
|
||||
|
||||
err := tt.executeFn(decision, actionRecord)
|
||||
|
||||
if tt.expectedErr != "" {
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), tt.expectedErr)
|
||||
} else {
|
||||
s.NoError(err)
|
||||
s.Equal(tt.currentPrice, actionRecord.Price)
|
||||
}
|
||||
|
||||
// 恢复默认状态
|
||||
s.mockTrader.positions = []map[string]interface{}{}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AutoTraderTestSuite) TestExecutePartialCloseWithRecord() {
|
||||
s.Run("成功部分平仓", func() {
|
||||
// 设置持仓
|
||||
s.mockTrader.positions = []map[string]interface{}{
|
||||
{
|
||||
"symbol": "BTCUSDT",
|
||||
"side": "long",
|
||||
"positionAmt": 0.1,
|
||||
"entryPrice": 50000.0,
|
||||
"markPrice": 52000.0,
|
||||
},
|
||||
}
|
||||
|
||||
// Mock market.Get
|
||||
s.patches.ApplyFunc(market.Get, func(symbol string) (*market.Data, error) {
|
||||
return &market.Data{
|
||||
Symbol: symbol,
|
||||
CurrentPrice: 52000.0,
|
||||
}, nil
|
||||
})
|
||||
|
||||
decision := &decision.Decision{
|
||||
Action: "partial_close",
|
||||
Symbol: "BTCUSDT",
|
||||
ClosePercentage: 50.0,
|
||||
}
|
||||
|
||||
actionRecord := &logger.DecisionAction{
|
||||
Action: "partial_close",
|
||||
Symbol: "BTCUSDT",
|
||||
}
|
||||
|
||||
err := s.autoTrader.executePartialCloseWithRecord(decision, actionRecord)
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(0.05, actionRecord.Quantity) // 50% of 0.1
|
||||
})
|
||||
|
||||
s.Run("无效的平仓百分比", func() {
|
||||
decision := &decision.Decision{
|
||||
Action: "partial_close",
|
||||
Symbol: "BTCUSDT",
|
||||
ClosePercentage: 150.0, // 无效
|
||||
}
|
||||
|
||||
actionRecord := &logger.DecisionAction{}
|
||||
|
||||
err := s.autoTrader.executePartialCloseWithRecord(decision, actionRecord)
|
||||
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "平仓百分比必须在 0-100 之间")
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 层次 10: executeDecisionWithRecord 路由测试
|
||||
// ============================================================
|
||||
@@ -792,7 +571,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
|
||||
PositionSizeUSD: 1000.0,
|
||||
Leverage: 10,
|
||||
}
|
||||
actionRecord := &logger.DecisionAction{}
|
||||
actionRecord := &store.DecisionAction{}
|
||||
|
||||
err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord)
|
||||
s.NoError(err)
|
||||
@@ -803,7 +582,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
|
||||
Action: "close_long",
|
||||
Symbol: "BTCUSDT",
|
||||
}
|
||||
actionRecord := &logger.DecisionAction{}
|
||||
actionRecord := &store.DecisionAction{}
|
||||
|
||||
err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord)
|
||||
s.NoError(err)
|
||||
@@ -814,7 +593,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
|
||||
Action: "hold",
|
||||
Symbol: "BTCUSDT",
|
||||
}
|
||||
actionRecord := &logger.DecisionAction{}
|
||||
actionRecord := &store.DecisionAction{}
|
||||
|
||||
err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord)
|
||||
s.NoError(err)
|
||||
@@ -825,7 +604,7 @@ func (s *AutoTraderTestSuite) TestExecuteDecisionWithRecord() {
|
||||
Action: "unknown_action",
|
||||
Symbol: "BTCUSDT",
|
||||
}
|
||||
actionRecord := &logger.DecisionAction{}
|
||||
actionRecord := &store.DecisionAction{}
|
||||
|
||||
err := s.autoTrader.executeDecisionWithRecord(decision, actionRecord)
|
||||
s.Error(err)
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/hook"
|
||||
"nofx/logger"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -80,7 +80,7 @@ func NewFuturesTrader(apiKey, secretKey string, userId string) *FuturesTrader {
|
||||
// 设置双向持仓模式(Hedge Mode)
|
||||
// 这是必需的,因为代码中使用了 PositionSide (LONG/SHORT)
|
||||
if err := trader.setDualSidePosition(); err != nil {
|
||||
log.Printf("⚠️ 设置双向持仓模式失败: %v (如果已是双向模式则忽略此警告)", err)
|
||||
logger.Infof("⚠️ 设置双向持仓模式失败: %v (如果已是双向模式则忽略此警告)", err)
|
||||
}
|
||||
|
||||
return trader
|
||||
@@ -96,15 +96,15 @@ func (t *FuturesTrader) setDualSidePosition() error {
|
||||
if err != nil {
|
||||
// 如果错误信息包含"No need to change",说明已经是双向持仓模式
|
||||
if strings.Contains(err.Error(), "No need to change position side") {
|
||||
log.Printf(" ✓ 账户已是双向持仓模式(Hedge Mode)")
|
||||
logger.Infof(" ✓ 账户已是双向持仓模式(Hedge Mode)")
|
||||
return nil
|
||||
}
|
||||
// 其他错误则返回(但在调用方不会中断初始化)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 账户已切换为双向持仓模式(Hedge Mode)")
|
||||
log.Printf(" ℹ️ 双向持仓模式允许同时持有多单和空单")
|
||||
logger.Infof(" ✓ 账户已切换为双向持仓模式(Hedge Mode)")
|
||||
logger.Infof(" ℹ️ 双向持仓模式允许同时持有多单和空单")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,14 +112,14 @@ func (t *FuturesTrader) setDualSidePosition() error {
|
||||
func syncBinanceServerTime(client *futures.Client) {
|
||||
serverTime, err := client.NewServerTimeService().Do(context.Background())
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 同步币安服务器时间失败: %v", err)
|
||||
logger.Infof("⚠️ 同步币安服务器时间失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
offset := now - serverTime
|
||||
client.TimeOffset = offset
|
||||
log.Printf("⏱ 已同步币安服务器时间,偏移 %dms", offset)
|
||||
logger.Infof("⏱ 已同步币安服务器时间,偏移 %dms", offset)
|
||||
}
|
||||
|
||||
// GetBalance 获取账户余额(带缓存)
|
||||
@@ -129,16 +129,16 @@ func (t *FuturesTrader) GetBalance() (map[string]interface{}, error) {
|
||||
if t.cachedBalance != nil && time.Since(t.balanceCacheTime) < t.cacheDuration {
|
||||
cacheAge := time.Since(t.balanceCacheTime)
|
||||
t.balanceCacheMutex.RUnlock()
|
||||
log.Printf("✓ 使用缓存的账户余额(缓存时间: %.1f秒前)", cacheAge.Seconds())
|
||||
logger.Infof("✓ 使用缓存的账户余额(缓存时间: %.1f秒前)", cacheAge.Seconds())
|
||||
return t.cachedBalance, nil
|
||||
}
|
||||
t.balanceCacheMutex.RUnlock()
|
||||
|
||||
// 缓存过期或不存在,调用API
|
||||
log.Printf("🔄 缓存过期,正在调用币安API获取账户余额...")
|
||||
logger.Infof("🔄 缓存过期,正在调用币安API获取账户余额...")
|
||||
account, err := t.client.NewGetAccountService().Do(context.Background())
|
||||
if err != nil {
|
||||
log.Printf("❌ 币安API调用失败: %v", err)
|
||||
logger.Infof("❌ 币安API调用失败: %v", err)
|
||||
return nil, fmt.Errorf("获取账户信息失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -147,7 +147,7 @@ func (t *FuturesTrader) GetBalance() (map[string]interface{}, error) {
|
||||
result["availableBalance"], _ = strconv.ParseFloat(account.AvailableBalance, 64)
|
||||
result["totalUnrealizedProfit"], _ = strconv.ParseFloat(account.TotalUnrealizedProfit, 64)
|
||||
|
||||
log.Printf("✓ 币安API返回: 总余额=%s, 可用=%s, 未实现盈亏=%s",
|
||||
logger.Infof("✓ 币安API返回: 总余额=%s, 可用=%s, 未实现盈亏=%s",
|
||||
account.TotalWalletBalance,
|
||||
account.AvailableBalance,
|
||||
account.TotalUnrealizedProfit)
|
||||
@@ -168,13 +168,13 @@ func (t *FuturesTrader) GetPositions() ([]map[string]interface{}, error) {
|
||||
if t.cachedPositions != nil && time.Since(t.positionsCacheTime) < t.cacheDuration {
|
||||
cacheAge := time.Since(t.positionsCacheTime)
|
||||
t.positionsCacheMutex.RUnlock()
|
||||
log.Printf("✓ 使用缓存的持仓信息(缓存时间: %.1f秒前)", cacheAge.Seconds())
|
||||
logger.Infof("✓ 使用缓存的持仓信息(缓存时间: %.1f秒前)", cacheAge.Seconds())
|
||||
return t.cachedPositions, nil
|
||||
}
|
||||
t.positionsCacheMutex.RUnlock()
|
||||
|
||||
// 缓存过期或不存在,调用API
|
||||
log.Printf("🔄 缓存过期,正在调用币安API获取持仓信息...")
|
||||
logger.Infof("🔄 缓存过期,正在调用币安API获取持仓信息...")
|
||||
positions, err := t.client.NewGetPositionRiskService().Do(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取持仓失败: %w", err)
|
||||
@@ -238,31 +238,31 @@ func (t *FuturesTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
|
||||
if err != nil {
|
||||
// 如果错误信息包含"No need to change",说明仓位模式已经是目标值
|
||||
if contains(err.Error(), "No need to change margin type") {
|
||||
log.Printf(" ✓ %s 仓位模式已是 %s", symbol, marginModeStr)
|
||||
logger.Infof(" ✓ %s 仓位模式已是 %s", symbol, marginModeStr)
|
||||
return nil
|
||||
}
|
||||
// 如果有持仓,无法更改仓位模式,但不影响交易
|
||||
if contains(err.Error(), "Margin type cannot be changed if there exists position") {
|
||||
log.Printf(" ⚠️ %s 有持仓,无法更改仓位模式,继续使用当前模式", symbol)
|
||||
logger.Infof(" ⚠️ %s 有持仓,无法更改仓位模式,继续使用当前模式", symbol)
|
||||
return nil
|
||||
}
|
||||
// 检测多资产模式(错误码 -4168)
|
||||
if contains(err.Error(), "Multi-Assets mode") || contains(err.Error(), "-4168") || contains(err.Error(), "4168") {
|
||||
log.Printf(" ⚠️ %s 检测到多资产模式,强制使用全仓模式", symbol)
|
||||
log.Printf(" 💡 提示:如需使用逐仓模式,请在币安关闭多资产模式")
|
||||
logger.Infof(" ⚠️ %s 检测到多资产模式,强制使用全仓模式", symbol)
|
||||
logger.Infof(" 💡 提示:如需使用逐仓模式,请在币安关闭多资产模式")
|
||||
return nil
|
||||
}
|
||||
// 检测统一账户 API(Portfolio Margin)
|
||||
if contains(err.Error(), "unified") || contains(err.Error(), "portfolio") || contains(err.Error(), "Portfolio") {
|
||||
log.Printf(" ❌ %s 检测到统一账户 API,无法进行合约交易", symbol)
|
||||
logger.Infof(" ❌ %s 检测到统一账户 API,无法进行合约交易", symbol)
|
||||
return fmt.Errorf("请使用「现货与合约交易」API 权限,不要使用「统一账户 API」")
|
||||
}
|
||||
log.Printf(" ⚠️ 设置仓位模式失败: %v", err)
|
||||
logger.Infof(" ⚠️ 设置仓位模式失败: %v", err)
|
||||
// 不返回错误,让交易继续
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Printf(" ✓ %s 仓位模式已设置为 %s", symbol, marginModeStr)
|
||||
logger.Infof(" ✓ %s 仓位模式已设置为 %s", symbol, marginModeStr)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -284,7 +284,7 @@ func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error {
|
||||
|
||||
// 如果当前杠杆已经是目标杠杆,跳过
|
||||
if currentLeverage == leverage && currentLeverage > 0 {
|
||||
log.Printf(" ✓ %s 杠杆已是 %dx,无需切换", symbol, leverage)
|
||||
logger.Infof(" ✓ %s 杠杆已是 %dx,无需切换", symbol, leverage)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -297,16 +297,16 @@ func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error {
|
||||
if err != nil {
|
||||
// 如果错误信息包含"No need to change",说明杠杆已经是目标值
|
||||
if contains(err.Error(), "No need to change") {
|
||||
log.Printf(" ✓ %s 杠杆已是 %dx", symbol, leverage)
|
||||
logger.Infof(" ✓ %s 杠杆已是 %dx", symbol, leverage)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("设置杠杆失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" ✓ %s 杠杆已切换为 %dx", symbol, leverage)
|
||||
logger.Infof(" ✓ %s 杠杆已切换为 %dx", symbol, leverage)
|
||||
|
||||
// 切换杠杆后等待5秒(避免冷却期错误)
|
||||
log.Printf(" ⏱ 等待5秒冷却期...")
|
||||
logger.Infof(" ⏱ 等待5秒冷却期...")
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
return nil
|
||||
@@ -316,7 +316,7 @@ func (t *FuturesTrader) SetLeverage(symbol string, leverage int) error {
|
||||
func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// 先取消该币种的所有委托单(清理旧的止损止盈单)
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消旧委托单失败(可能没有委托单): %v", err)
|
||||
logger.Infof(" ⚠ 取消旧委托单失败(可能没有委托单): %v", err)
|
||||
}
|
||||
|
||||
// 设置杠杆
|
||||
@@ -357,8 +357,8 @@ func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int)
|
||||
return nil, fmt.Errorf("开多仓失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 开多仓成功: %s 数量: %s", symbol, quantityStr)
|
||||
log.Printf(" 订单ID: %d", order.OrderID)
|
||||
logger.Infof("✓ 开多仓成功: %s 数量: %s", symbol, quantityStr)
|
||||
logger.Infof(" 订单ID: %d", order.OrderID)
|
||||
|
||||
result := make(map[string]interface{})
|
||||
result["orderId"] = order.OrderID
|
||||
@@ -371,7 +371,7 @@ func (t *FuturesTrader) OpenLong(symbol string, quantity float64, leverage int)
|
||||
func (t *FuturesTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// 先取消该币种的所有委托单(清理旧的止损止盈单)
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消旧委托单失败(可能没有委托单): %v", err)
|
||||
logger.Infof(" ⚠ 取消旧委托单失败(可能没有委托单): %v", err)
|
||||
}
|
||||
|
||||
// 设置杠杆
|
||||
@@ -412,8 +412,8 @@ func (t *FuturesTrader) OpenShort(symbol string, quantity float64, leverage int)
|
||||
return nil, fmt.Errorf("开空仓失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 开空仓成功: %s 数量: %s", symbol, quantityStr)
|
||||
log.Printf(" 订单ID: %d", order.OrderID)
|
||||
logger.Infof("✓ 开空仓成功: %s 数量: %s", symbol, quantityStr)
|
||||
logger.Infof(" 订单ID: %d", order.OrderID)
|
||||
|
||||
result := make(map[string]interface{})
|
||||
result["orderId"] = order.OrderID
|
||||
@@ -463,11 +463,11 @@ func (t *FuturesTrader) CloseLong(symbol string, quantity float64) (map[string]i
|
||||
return nil, fmt.Errorf("平多仓失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 平多仓成功: %s 数量: %s", symbol, quantityStr)
|
||||
logger.Infof("✓ 平多仓成功: %s 数量: %s", symbol, quantityStr)
|
||||
|
||||
// 平仓后取消该币种的所有挂单(止损止盈单)
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败: %v", err)
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
@@ -518,11 +518,11 @@ func (t *FuturesTrader) CloseShort(symbol string, quantity float64) (map[string]
|
||||
return nil, fmt.Errorf("平空仓失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 平空仓成功: %s 数量: %s", symbol, quantityStr)
|
||||
logger.Infof("✓ 平空仓成功: %s 数量: %s", symbol, quantityStr)
|
||||
|
||||
// 平仓后取消该币种的所有挂单(止损止盈单)
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败: %v", err)
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
@@ -559,19 +559,19 @@ func (t *FuturesTrader) CancelStopLossOrders(symbol string) error {
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("订单ID %d: %v", order.OrderID, err)
|
||||
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
|
||||
log.Printf(" ⚠ 取消止损单失败: %s", errMsg)
|
||||
logger.Infof(" ⚠ 取消止损单失败: %s", errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
canceledCount++
|
||||
log.Printf(" ✓ 已取消止损单 (订单ID: %d, 类型: %s, 方向: %s)", order.OrderID, orderType, order.PositionSide)
|
||||
logger.Infof(" ✓ 已取消止损单 (订单ID: %d, 类型: %s, 方向: %s)", order.OrderID, orderType, order.PositionSide)
|
||||
}
|
||||
}
|
||||
|
||||
if canceledCount == 0 && len(cancelErrors) == 0 {
|
||||
log.Printf(" ℹ %s 没有止损单需要取消", symbol)
|
||||
logger.Infof(" ℹ %s 没有止损单需要取消", symbol)
|
||||
} else if canceledCount > 0 {
|
||||
log.Printf(" ✓ 已取消 %s 的 %d 个止损单", symbol, canceledCount)
|
||||
logger.Infof(" ✓ 已取消 %s 的 %d 个止损单", symbol, canceledCount)
|
||||
}
|
||||
|
||||
// 如果所有取消都失败了,返回错误
|
||||
@@ -609,19 +609,19 @@ func (t *FuturesTrader) CancelTakeProfitOrders(symbol string) error {
|
||||
if err != nil {
|
||||
errMsg := fmt.Sprintf("订单ID %d: %v", order.OrderID, err)
|
||||
cancelErrors = append(cancelErrors, fmt.Errorf("%s", errMsg))
|
||||
log.Printf(" ⚠ 取消止盈单失败: %s", errMsg)
|
||||
logger.Infof(" ⚠ 取消止盈单失败: %s", errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
canceledCount++
|
||||
log.Printf(" ✓ 已取消止盈单 (订单ID: %d, 类型: %s, 方向: %s)", order.OrderID, orderType, order.PositionSide)
|
||||
logger.Infof(" ✓ 已取消止盈单 (订单ID: %d, 类型: %s, 方向: %s)", order.OrderID, orderType, order.PositionSide)
|
||||
}
|
||||
}
|
||||
|
||||
if canceledCount == 0 && len(cancelErrors) == 0 {
|
||||
log.Printf(" ℹ %s 没有止盈单需要取消", symbol)
|
||||
logger.Infof(" ℹ %s 没有止盈单需要取消", symbol)
|
||||
} else if canceledCount > 0 {
|
||||
log.Printf(" ✓ 已取消 %s 的 %d 个止盈单", symbol, canceledCount)
|
||||
logger.Infof(" ✓ 已取消 %s 的 %d 个止盈单", symbol, canceledCount)
|
||||
}
|
||||
|
||||
// 如果所有取消都失败了,返回错误
|
||||
@@ -642,7 +642,7 @@ func (t *FuturesTrader) CancelAllOrders(symbol string) error {
|
||||
return fmt.Errorf("取消挂单失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 已取消 %s 的所有挂单", symbol)
|
||||
logger.Infof(" ✓ 已取消 %s 的所有挂单", symbol)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -674,20 +674,20 @@ func (t *FuturesTrader) CancelStopOrders(symbol string) error {
|
||||
Do(context.Background())
|
||||
|
||||
if err != nil {
|
||||
log.Printf(" ⚠ 取消订单 %d 失败: %v", order.OrderID, err)
|
||||
logger.Infof(" ⚠ 取消订单 %d 失败: %v", order.OrderID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
canceledCount++
|
||||
log.Printf(" ✓ 已取消 %s 的止盈/止损单 (订单ID: %d, 类型: %s)",
|
||||
logger.Infof(" ✓ 已取消 %s 的止盈/止损单 (订单ID: %d, 类型: %s)",
|
||||
symbol, order.OrderID, orderType)
|
||||
}
|
||||
}
|
||||
|
||||
if canceledCount == 0 {
|
||||
log.Printf(" ℹ %s 没有止盈/止损单需要取消", symbol)
|
||||
logger.Infof(" ℹ %s 没有止盈/止损单需要取消", symbol)
|
||||
} else {
|
||||
log.Printf(" ✓ 已取消 %s 的 %d 个止盈/止损单", symbol, canceledCount)
|
||||
logger.Infof(" ✓ 已取消 %s 的 %d 个止盈/止损单", symbol, canceledCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -748,13 +748,14 @@ func (t *FuturesTrader) SetStopLoss(symbol string, positionSide string, quantity
|
||||
Quantity(quantityStr).
|
||||
WorkingType(futures.WorkingTypeContractPrice).
|
||||
ClosePosition(true).
|
||||
NewClientOrderID(getBrOrderID()).
|
||||
Do(context.Background())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("设置止损失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" 止损价设置: %.4f", stopPrice)
|
||||
logger.Infof(" 止损价设置: %.4f", stopPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -786,13 +787,14 @@ func (t *FuturesTrader) SetTakeProfit(symbol string, positionSide string, quanti
|
||||
Quantity(quantityStr).
|
||||
WorkingType(futures.WorkingTypeContractPrice).
|
||||
ClosePosition(true).
|
||||
NewClientOrderID(getBrOrderID()).
|
||||
Do(context.Background())
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("设置止盈失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" 止盈价设置: %.4f", takeProfitPrice)
|
||||
logger.Infof(" 止盈价设置: %.4f", takeProfitPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -836,14 +838,14 @@ func (t *FuturesTrader) GetSymbolPrecision(symbol string) (int, error) {
|
||||
if filter["filterType"] == "LOT_SIZE" {
|
||||
stepSize := filter["stepSize"].(string)
|
||||
precision := calculatePrecision(stepSize)
|
||||
log.Printf(" %s 数量精度: %d (stepSize: %s)", symbol, precision, stepSize)
|
||||
logger.Infof(" %s 数量精度: %d (stepSize: %s)", symbol, precision, stepSize)
|
||||
return precision, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf(" ⚠ %s 未找到精度信息,使用默认精度3", symbol)
|
||||
logger.Infof(" ⚠ %s 未找到精度信息,使用默认精度3", symbol)
|
||||
return 3, nil // 默认精度为3
|
||||
}
|
||||
|
||||
@@ -915,3 +917,42 @@ func stringContains(s, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetOrderStatus 获取订单状态
|
||||
func (t *FuturesTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
|
||||
// 将 orderID 转换为 int64
|
||||
orderIDInt, err := strconv.ParseInt(orderID, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效的订单ID: %s", orderID)
|
||||
}
|
||||
|
||||
order, err := t.client.NewGetOrderService().
|
||||
Symbol(symbol).
|
||||
OrderID(orderIDInt).
|
||||
Do(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取订单状态失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析成交价格
|
||||
avgPrice, _ := strconv.ParseFloat(order.AvgPrice, 64)
|
||||
executedQty, _ := strconv.ParseFloat(order.ExecutedQuantity, 64)
|
||||
|
||||
result := map[string]interface{}{
|
||||
"orderId": order.OrderID,
|
||||
"symbol": order.Symbol,
|
||||
"status": string(order.Status),
|
||||
"avgPrice": avgPrice,
|
||||
"executedQty": executedQty,
|
||||
"side": string(order.Side),
|
||||
"type": string(order.Type),
|
||||
"time": order.Time,
|
||||
"updateTime": order.UpdateTime,
|
||||
}
|
||||
|
||||
// 币安合约的手续费需要通过 GetUserTrades 获取,这里暂时不获取
|
||||
// 后续可以通过 WebSocket 或单独查询获取
|
||||
result["commission"] = 0.0
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package trader
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -55,7 +55,7 @@ func NewBybitTrader(apiKey, secretKey string) *BybitTrader {
|
||||
cacheDuration: 15 * time.Second,
|
||||
}
|
||||
|
||||
log.Printf("🔵 [Bybit] 交易器已初始化")
|
||||
logger.Infof("🔵 [Bybit] 交易器已初始化")
|
||||
|
||||
return trader
|
||||
}
|
||||
@@ -224,7 +224,7 @@ func (t *BybitTrader) GetPositions() ([]map[string]interface{}, error) {
|
||||
func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// 先设置杠杆
|
||||
if err := t.SetLeverage(symbol, leverage); err != nil {
|
||||
log.Printf("⚠️ [Bybit] 设置杠杆失败: %v", err)
|
||||
logger.Infof("⚠️ [Bybit] 设置杠杆失败: %v", err)
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
@@ -251,7 +251,7 @@ func (t *BybitTrader) OpenLong(symbol string, quantity float64, leverage int) (m
|
||||
func (t *BybitTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// 先设置杠杆
|
||||
if err := t.SetLeverage(symbol, leverage); err != nil {
|
||||
log.Printf("⚠️ [Bybit] 设置杠杆失败: %v", err)
|
||||
logger.Infof("⚠️ [Bybit] 设置杠杆失败: %v", err)
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
@@ -485,7 +485,7 @@ func (t *BybitTrader) SetStopLoss(symbol string, positionSide string, quantity,
|
||||
return fmt.Errorf("设置止损失败: %s", result.RetMsg)
|
||||
}
|
||||
|
||||
log.Printf(" ✓ [Bybit] 止损单已设置: %s @ %.2f", symbol, stopPrice)
|
||||
logger.Infof(" ✓ [Bybit] 止损单已设置: %s @ %.2f", symbol, stopPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -528,7 +528,7 @@ func (t *BybitTrader) SetTakeProfit(symbol string, positionSide string, quantity
|
||||
return fmt.Errorf("设置止盈失败: %s", result.RetMsg)
|
||||
}
|
||||
|
||||
log.Printf(" ✓ [Bybit] 止盈单已设置: %s @ %.2f", symbol, takeProfitPrice)
|
||||
logger.Infof(" ✓ [Bybit] 止盈单已设置: %s @ %.2f", symbol, takeProfitPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -560,10 +560,10 @@ func (t *BybitTrader) CancelAllOrders(symbol string) error {
|
||||
// CancelStopOrders 取消所有止盈止损单
|
||||
func (t *BybitTrader) CancelStopOrders(symbol string) error {
|
||||
if err := t.CancelStopLossOrders(symbol); err != nil {
|
||||
log.Printf("⚠️ [Bybit] 取消止损单失败: %v", err)
|
||||
logger.Infof("⚠️ [Bybit] 取消止损单失败: %v", err)
|
||||
}
|
||||
if err := t.CancelTakeProfitOrders(symbol); err != nil {
|
||||
log.Printf("⚠️ [Bybit] 取消止盈单失败: %v", err)
|
||||
logger.Infof("⚠️ [Bybit] 取消止盈单失败: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -604,6 +604,67 @@ func (t *BybitTrader) parseOrderResult(result *bybit.ServerResponse) (map[string
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetOrderStatus 获取订单状态
|
||||
func (t *BybitTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
|
||||
params := map[string]interface{}{
|
||||
"category": "linear",
|
||||
"symbol": symbol,
|
||||
"orderId": orderID,
|
||||
}
|
||||
|
||||
result, err := t.client.NewUtaBybitServiceWithParams(params).GetOrderHistory(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取订单状态失败: %w", err)
|
||||
}
|
||||
|
||||
if result.RetCode != 0 {
|
||||
return nil, fmt.Errorf("API 错误: %s", result.RetMsg)
|
||||
}
|
||||
|
||||
resultData, ok := result.Result.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("返回格式错误")
|
||||
}
|
||||
|
||||
list, _ := resultData["list"].([]interface{})
|
||||
if len(list) == 0 {
|
||||
return nil, fmt.Errorf("未找到订单 %s", orderID)
|
||||
}
|
||||
|
||||
order, _ := list[0].(map[string]interface{})
|
||||
|
||||
// 解析订单数据
|
||||
status, _ := order["orderStatus"].(string)
|
||||
avgPriceStr, _ := order["avgPrice"].(string)
|
||||
cumExecQtyStr, _ := order["cumExecQty"].(string)
|
||||
cumExecFeeStr, _ := order["cumExecFee"].(string)
|
||||
|
||||
avgPrice, _ := strconv.ParseFloat(avgPriceStr, 64)
|
||||
executedQty, _ := strconv.ParseFloat(cumExecQtyStr, 64)
|
||||
commission, _ := strconv.ParseFloat(cumExecFeeStr, 64)
|
||||
|
||||
// 转换状态为统一格式
|
||||
unifiedStatus := status
|
||||
switch status {
|
||||
case "Filled":
|
||||
unifiedStatus = "FILLED"
|
||||
case "New", "Created":
|
||||
unifiedStatus = "NEW"
|
||||
case "Cancelled", "Rejected":
|
||||
unifiedStatus = "CANCELED"
|
||||
case "PartiallyFilled":
|
||||
unifiedStatus = "PARTIALLY_FILLED"
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"orderId": orderID,
|
||||
"status": unifiedStatus,
|
||||
"avgPrice": avgPrice,
|
||||
"executedQty": executedQty,
|
||||
"commission": commission,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *BybitTrader) cancelConditionalOrders(symbol string, orderType string) error {
|
||||
// 先获取所有条件单
|
||||
params := map[string]interface{}{
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -56,14 +56,14 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool)
|
||||
|
||||
// Check if user accidentally uses main wallet private key (security risk)
|
||||
if strings.EqualFold(walletAddr, agentAddr) {
|
||||
log.Printf("⚠️⚠️⚠️ WARNING: Main wallet address (%s) matches Agent wallet address!", walletAddr)
|
||||
log.Printf(" This indicates you may be using your main wallet private key, which poses extremely high security risks!")
|
||||
log.Printf(" Recommendation: Immediately create a separate Agent Wallet on Hyperliquid official website")
|
||||
log.Printf(" Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets")
|
||||
logger.Infof("⚠️⚠️⚠️ WARNING: Main wallet address (%s) matches Agent wallet address!", walletAddr)
|
||||
logger.Infof(" This indicates you may be using your main wallet private key, which poses extremely high security risks!")
|
||||
logger.Infof(" Recommendation: Immediately create a separate Agent Wallet on Hyperliquid official website")
|
||||
logger.Infof(" Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets")
|
||||
} else {
|
||||
log.Printf("✓ Using Agent Wallet mode (secure)")
|
||||
log.Printf(" └─ Agent wallet address: %s (for signing)", agentAddr)
|
||||
log.Printf(" └─ Main wallet address: %s (holds funds)", walletAddr)
|
||||
logger.Infof("✓ Using Agent Wallet mode (secure)")
|
||||
logger.Infof(" └─ Agent wallet address: %s (for signing)", agentAddr)
|
||||
logger.Infof(" └─ Main wallet address: %s (holds funds)", walletAddr)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -79,7 +79,7 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool)
|
||||
nil, // SpotMeta will be fetched automatically
|
||||
)
|
||||
|
||||
log.Printf("✓ Hyperliquid交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr)
|
||||
logger.Infof("✓ Hyperliquid交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr)
|
||||
|
||||
// 获取meta信息(包含精度等配置)
|
||||
meta, err := exchange.Info().Meta(ctx)
|
||||
@@ -97,26 +97,26 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool)
|
||||
|
||||
if agentBalance > 100 {
|
||||
// Critical: Agent wallet holds too much funds
|
||||
log.Printf("🚨🚨🚨 CRITICAL SECURITY WARNING 🚨🚨🚨")
|
||||
log.Printf(" Agent wallet balance: %.2f USDC (exceeds safe threshold of 100 USDC)", agentBalance)
|
||||
log.Printf(" Agent wallet address: %s", agentAddr)
|
||||
log.Printf(" ⚠️ Agent wallets should only be used for signing and hold minimal/zero balance")
|
||||
log.Printf(" ⚠️ High balance in Agent wallet poses security risks")
|
||||
log.Printf(" 📖 Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets")
|
||||
log.Printf(" 💡 Recommendation: Transfer funds to main wallet and keep Agent wallet balance near 0")
|
||||
logger.Infof("🚨🚨🚨 CRITICAL SECURITY WARNING 🚨🚨🚨")
|
||||
logger.Infof(" Agent wallet balance: %.2f USDC (exceeds safe threshold of 100 USDC)", agentBalance)
|
||||
logger.Infof(" Agent wallet address: %s", agentAddr)
|
||||
logger.Infof(" ⚠️ Agent wallets should only be used for signing and hold minimal/zero balance")
|
||||
logger.Infof(" ⚠️ High balance in Agent wallet poses security risks")
|
||||
logger.Infof(" 📖 Reference: https://hyperliquid.gitbook.io/hyperliquid-docs/for-developers/api/nonces-and-api-wallets")
|
||||
logger.Infof(" 💡 Recommendation: Transfer funds to main wallet and keep Agent wallet balance near 0")
|
||||
return nil, fmt.Errorf("security check failed: Agent wallet balance too high (%.2f USDC), exceeds 100 USDC threshold", agentBalance)
|
||||
} else if agentBalance > 10 {
|
||||
// Warning: Agent wallet has some balance (acceptable but not ideal)
|
||||
log.Printf("⚠️ Notice: Agent wallet address (%s) has some balance: %.2f USDC", agentAddr, agentBalance)
|
||||
log.Printf(" While not critical, it's recommended to keep Agent wallet balance near 0 for security")
|
||||
logger.Infof("⚠️ Notice: Agent wallet address (%s) has some balance: %.2f USDC", agentAddr, agentBalance)
|
||||
logger.Infof(" While not critical, it's recommended to keep Agent wallet balance near 0 for security")
|
||||
} else {
|
||||
// OK: Agent wallet balance is safe
|
||||
log.Printf("✓ Agent wallet balance is safe: %.2f USDC (near zero as recommended)", agentBalance)
|
||||
logger.Infof("✓ Agent wallet balance is safe: %.2f USDC (near zero as recommended)", agentBalance)
|
||||
}
|
||||
} else if err != nil {
|
||||
// Failed to query agent balance - log warning but don't block initialization
|
||||
log.Printf("⚠️ Could not verify Agent wallet balance (query failed): %v", err)
|
||||
log.Printf(" Proceeding with initialization, but please manually verify Agent wallet balance is near 0")
|
||||
logger.Infof("⚠️ Could not verify Agent wallet balance (query failed): %v", err)
|
||||
logger.Infof(" Proceeding with initialization, but please manually verify Agent wallet balance is near 0")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,18 +131,18 @@ func NewHyperliquidTrader(privateKeyHex string, walletAddr string, testnet bool)
|
||||
|
||||
// GetBalance 获取账户余额
|
||||
func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) {
|
||||
log.Printf("🔄 正在调用Hyperliquid API获取账户余额...")
|
||||
logger.Infof("🔄 正在调用Hyperliquid API获取账户余额...")
|
||||
|
||||
// ✅ Step 1: 查询 Spot 现货账户余额
|
||||
spotState, err := t.exchange.Info().SpotUserState(t.ctx, t.walletAddr)
|
||||
var spotUSDCBalance float64 = 0.0
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 查询 Spot 余额失败(可能无现货资产): %v", err)
|
||||
logger.Infof("⚠️ 查询 Spot 余额失败(可能无现货资产): %v", err)
|
||||
} else if spotState != nil && len(spotState.Balances) > 0 {
|
||||
for _, balance := range spotState.Balances {
|
||||
if balance.Coin == "USDC" {
|
||||
spotUSDCBalance, _ = strconv.ParseFloat(balance.Total, 64)
|
||||
log.Printf("✓ 发现 Spot 现货余额: %.2f USDC", spotUSDCBalance)
|
||||
logger.Infof("✓ 发现 Spot 现货余额: %.2f USDC", spotUSDCBalance)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -151,7 +151,7 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) {
|
||||
// ✅ Step 2: 查询 Perpetuals 合约账户状态
|
||||
accountState, err := t.exchange.Info().UserState(t.ctx, t.walletAddr)
|
||||
if err != nil {
|
||||
log.Printf("❌ Hyperliquid Perpetuals API调用失败: %v", err)
|
||||
logger.Infof("❌ Hyperliquid Perpetuals API调用失败: %v", err)
|
||||
return nil, fmt.Errorf("获取账户信息失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -179,8 +179,8 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) {
|
||||
|
||||
// 🔍 调试:打印API返回的完整摘要结构
|
||||
summaryJSON, _ := json.MarshalIndent(summary, " ", " ")
|
||||
log.Printf("🔍 [DEBUG] Hyperliquid API %s 完整数据:", summaryType)
|
||||
log.Printf("%s", string(summaryJSON))
|
||||
logger.Infof("🔍 [DEBUG] Hyperliquid API %s 完整数据:", summaryType)
|
||||
logger.Infof("%s", string(summaryJSON))
|
||||
|
||||
// ⚠️ 关键修复:从所有持仓中累加真正的未实现盈亏
|
||||
totalUnrealizedPnl := 0.0
|
||||
@@ -204,7 +204,7 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) {
|
||||
withdrawable, err := strconv.ParseFloat(accountState.Withdrawable, 64)
|
||||
if err == nil && withdrawable > 0 {
|
||||
availableBalance = withdrawable
|
||||
log.Printf("✓ 使用 Withdrawable 作为可用余额: %.2f", availableBalance)
|
||||
logger.Infof("✓ 使用 Withdrawable 作为可用余额: %.2f", availableBalance)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,7 +212,7 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) {
|
||||
if availableBalance == 0 && accountState.Withdrawable == "" {
|
||||
availableBalance = accountValue - totalMarginUsed
|
||||
if availableBalance < 0 {
|
||||
log.Printf("⚠️ 计算出的可用余额为负数 (%.2f),重置为 0", availableBalance)
|
||||
logger.Infof("⚠️ 计算出的可用余额为负数 (%.2f),重置为 0", availableBalance)
|
||||
availableBalance = 0
|
||||
}
|
||||
}
|
||||
@@ -227,16 +227,16 @@ func (t *HyperliquidTrader) GetBalance() (map[string]interface{}, error) {
|
||||
result["totalUnrealizedProfit"] = totalUnrealizedPnl // 未实现盈亏(仅来自 Perpetuals)
|
||||
result["spotBalance"] = spotUSDCBalance // Spot 现货余额(单独返回)
|
||||
|
||||
log.Printf("✓ Hyperliquid 完整账户:")
|
||||
log.Printf(" • Spot 现货余额: %.2f USDC (需手动转账到 Perpetuals 才能开仓)", spotUSDCBalance)
|
||||
log.Printf(" • Perpetuals 合约净值: %.2f USDC (钱包%.2f + 未实现%.2f)",
|
||||
logger.Infof("✓ Hyperliquid 完整账户:")
|
||||
logger.Infof(" • Spot 现货余额: %.2f USDC (需手动转账到 Perpetuals 才能开仓)", spotUSDCBalance)
|
||||
logger.Infof(" • Perpetuals 合约净值: %.2f USDC (钱包%.2f + 未实现%.2f)",
|
||||
accountValue,
|
||||
walletBalanceWithoutUnrealized,
|
||||
totalUnrealizedPnl)
|
||||
log.Printf(" • Perpetuals 可用余额: %.2f USDC (可直接用于开仓)", availableBalance)
|
||||
log.Printf(" • 保证金占用: %.2f USDC", totalMarginUsed)
|
||||
log.Printf(" • 总资产 (Perp+Spot): %.2f USDC", totalWalletBalance)
|
||||
log.Printf(" ⭐ 总资产: %.2f USDC | Perp 可用: %.2f USDC | Spot 余额: %.2f USDC",
|
||||
logger.Infof(" • Perpetuals 可用余额: %.2f USDC (可直接用于开仓)", availableBalance)
|
||||
logger.Infof(" • 保证金占用: %.2f USDC", totalMarginUsed)
|
||||
logger.Infof(" • 总资产 (Perp+Spot): %.2f USDC", totalWalletBalance)
|
||||
logger.Infof(" ⭐ 总资产: %.2f USDC | Perp 可用: %.2f USDC | Spot 余额: %.2f USDC",
|
||||
totalWalletBalance, availableBalance, spotUSDCBalance)
|
||||
|
||||
return result, nil
|
||||
@@ -316,7 +316,7 @@ func (t *HyperliquidTrader) SetMarginMode(symbol string, isCrossMargin bool) err
|
||||
if !isCrossMargin {
|
||||
marginModeStr = "逐仓"
|
||||
}
|
||||
log.Printf(" ✓ %s 将使用 %s 模式", symbol, marginModeStr)
|
||||
logger.Infof(" ✓ %s 将使用 %s 模式", symbol, marginModeStr)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -332,7 +332,7 @@ func (t *HyperliquidTrader) SetLeverage(symbol string, leverage int) error {
|
||||
return fmt.Errorf("设置杠杆失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" ✓ %s 杠杆已切换为 %dx", symbol, leverage)
|
||||
logger.Infof(" ✓ %s 杠杆已切换为 %dx", symbol, leverage)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -343,7 +343,7 @@ func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error {
|
||||
return nil // Meta 正常,无需刷新
|
||||
}
|
||||
|
||||
log.Printf("⚠️ %s 的 Asset ID 为 0,尝试刷新 Meta 信息...", coin)
|
||||
logger.Infof("⚠️ %s 的 Asset ID 为 0,尝试刷新 Meta 信息...", coin)
|
||||
|
||||
// 刷新 Meta 信息
|
||||
meta, err := t.exchange.Info().Meta(t.ctx)
|
||||
@@ -356,7 +356,7 @@ func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error {
|
||||
t.meta = meta
|
||||
t.metaMutex.Unlock()
|
||||
|
||||
log.Printf("✅ Meta 信息已刷新,包含 %d 个资产", len(meta.Universe))
|
||||
logger.Infof("✅ Meta 信息已刷新,包含 %d 个资产", len(meta.Universe))
|
||||
|
||||
// 验证刷新后的 Asset ID
|
||||
assetID = t.exchange.Info().NameToAsset(coin)
|
||||
@@ -367,7 +367,7 @@ func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error {
|
||||
" 3. API 连接问题", coin)
|
||||
}
|
||||
|
||||
log.Printf("✅ 刷新后 Asset ID 检查通过: %s -> %d", coin, assetID)
|
||||
logger.Infof("✅ 刷新后 Asset ID 检查通过: %s -> %d", coin, assetID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -375,7 +375,7 @@ func (t *HyperliquidTrader) refreshMetaIfNeeded(coin string) error {
|
||||
func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// 先取消该币种的所有委托单
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消旧委托单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消旧委托单失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置杠杆
|
||||
@@ -394,11 +394,11 @@ func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage i
|
||||
|
||||
// ⚠️ 关键:根据币种精度要求,四舍五入数量
|
||||
roundedQuantity := t.roundToSzDecimals(coin, quantity)
|
||||
log.Printf(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin))
|
||||
logger.Infof(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin))
|
||||
|
||||
// ⚠️ 关键:价格也需要处理为5位有效数字
|
||||
aggressivePrice := t.roundPriceToSigfigs(price * 1.01)
|
||||
log.Printf(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*1.01, aggressivePrice)
|
||||
logger.Infof(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*1.01, aggressivePrice)
|
||||
|
||||
// 创建市价买入订单(使用IOC limit order with aggressive price)
|
||||
order := hyperliquid.CreateOrderRequest{
|
||||
@@ -419,7 +419,7 @@ func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage i
|
||||
return nil, fmt.Errorf("开多仓失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 开多仓成功: %s 数量: %.4f", symbol, roundedQuantity)
|
||||
logger.Infof("✓ 开多仓成功: %s 数量: %.4f", symbol, roundedQuantity)
|
||||
|
||||
result := make(map[string]interface{})
|
||||
result["orderId"] = 0 // Hyperliquid没有返回order ID
|
||||
@@ -433,7 +433,7 @@ func (t *HyperliquidTrader) OpenLong(symbol string, quantity float64, leverage i
|
||||
func (t *HyperliquidTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// 先取消该币种的所有委托单
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消旧委托单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消旧委托单失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置杠杆
|
||||
@@ -452,11 +452,11 @@ func (t *HyperliquidTrader) OpenShort(symbol string, quantity float64, leverage
|
||||
|
||||
// ⚠️ 关键:根据币种精度要求,四舍五入数量
|
||||
roundedQuantity := t.roundToSzDecimals(coin, quantity)
|
||||
log.Printf(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin))
|
||||
logger.Infof(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin))
|
||||
|
||||
// ⚠️ 关键:价格也需要处理为5位有效数字
|
||||
aggressivePrice := t.roundPriceToSigfigs(price * 0.99)
|
||||
log.Printf(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*0.99, aggressivePrice)
|
||||
logger.Infof(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*0.99, aggressivePrice)
|
||||
|
||||
// 创建市价卖出订单
|
||||
order := hyperliquid.CreateOrderRequest{
|
||||
@@ -477,7 +477,7 @@ func (t *HyperliquidTrader) OpenShort(symbol string, quantity float64, leverage
|
||||
return nil, fmt.Errorf("开空仓失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 开空仓成功: %s 数量: %.4f", symbol, roundedQuantity)
|
||||
logger.Infof("✓ 开空仓成功: %s 数量: %.4f", symbol, roundedQuantity)
|
||||
|
||||
result := make(map[string]interface{})
|
||||
result["orderId"] = 0
|
||||
@@ -519,11 +519,11 @@ func (t *HyperliquidTrader) CloseLong(symbol string, quantity float64) (map[stri
|
||||
|
||||
// ⚠️ 关键:根据币种精度要求,四舍五入数量
|
||||
roundedQuantity := t.roundToSzDecimals(coin, quantity)
|
||||
log.Printf(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin))
|
||||
logger.Infof(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin))
|
||||
|
||||
// ⚠️ 关键:价格也需要处理为5位有效数字
|
||||
aggressivePrice := t.roundPriceToSigfigs(price * 0.99)
|
||||
log.Printf(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*0.99, aggressivePrice)
|
||||
logger.Infof(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*0.99, aggressivePrice)
|
||||
|
||||
// 创建平仓订单(卖出 + ReduceOnly)
|
||||
order := hyperliquid.CreateOrderRequest{
|
||||
@@ -544,11 +544,11 @@ func (t *HyperliquidTrader) CloseLong(symbol string, quantity float64) (map[stri
|
||||
return nil, fmt.Errorf("平多仓失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 平多仓成功: %s 数量: %.4f", symbol, roundedQuantity)
|
||||
logger.Infof("✓ 平多仓成功: %s 数量: %.4f", symbol, roundedQuantity)
|
||||
|
||||
// 平仓后取消该币种的所有挂单
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败: %v", err)
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
@@ -591,11 +591,11 @@ func (t *HyperliquidTrader) CloseShort(symbol string, quantity float64) (map[str
|
||||
|
||||
// ⚠️ 关键:根据币种精度要求,四舍五入数量
|
||||
roundedQuantity := t.roundToSzDecimals(coin, quantity)
|
||||
log.Printf(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin))
|
||||
logger.Infof(" 📏 数量精度处理: %.8f -> %.8f (szDecimals=%d)", quantity, roundedQuantity, t.getSzDecimals(coin))
|
||||
|
||||
// ⚠️ 关键:价格也需要处理为5位有效数字
|
||||
aggressivePrice := t.roundPriceToSigfigs(price * 1.01)
|
||||
log.Printf(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*1.01, aggressivePrice)
|
||||
logger.Infof(" 💰 价格精度处理: %.8f -> %.8f (5位有效数字)", price*1.01, aggressivePrice)
|
||||
|
||||
// 创建平仓订单(买入 + ReduceOnly)
|
||||
order := hyperliquid.CreateOrderRequest{
|
||||
@@ -616,11 +616,11 @@ func (t *HyperliquidTrader) CloseShort(symbol string, quantity float64) (map[str
|
||||
return nil, fmt.Errorf("平空仓失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ 平空仓成功: %s 数量: %.4f", symbol, roundedQuantity)
|
||||
logger.Infof("✓ 平空仓成功: %s 数量: %.4f", symbol, roundedQuantity)
|
||||
|
||||
// 平仓后取消该币种的所有挂单
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败: %v", err)
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
@@ -637,7 +637,7 @@ func (t *HyperliquidTrader) CloseShort(symbol string, quantity float64) (map[str
|
||||
func (t *HyperliquidTrader) CancelStopLossOrders(symbol string) error {
|
||||
// Hyperliquid SDK 的 OpenOrder 结构不暴露 trigger 字段
|
||||
// 无法区分止损和止盈单,因此取消该币种的所有挂单
|
||||
log.Printf(" ⚠️ Hyperliquid 无法区分止损/止盈单,将取消所有挂单")
|
||||
logger.Infof(" ⚠️ Hyperliquid 无法区分止损/止盈单,将取消所有挂单")
|
||||
return t.CancelStopOrders(symbol)
|
||||
}
|
||||
|
||||
@@ -645,7 +645,7 @@ func (t *HyperliquidTrader) CancelStopLossOrders(symbol string) error {
|
||||
func (t *HyperliquidTrader) CancelTakeProfitOrders(symbol string) error {
|
||||
// Hyperliquid SDK 的 OpenOrder 结构不暴露 trigger 字段
|
||||
// 无法区分止损和止盈单,因此取消该币种的所有挂单
|
||||
log.Printf(" ⚠️ Hyperliquid 无法区分止损/止盈单,将取消所有挂单")
|
||||
logger.Infof(" ⚠️ Hyperliquid 无法区分止损/止盈单,将取消所有挂单")
|
||||
return t.CancelStopOrders(symbol)
|
||||
}
|
||||
|
||||
@@ -664,12 +664,12 @@ func (t *HyperliquidTrader) CancelAllOrders(symbol string) error {
|
||||
if order.Coin == coin {
|
||||
_, err := t.exchange.Cancel(t.ctx, coin, order.Oid)
|
||||
if err != nil {
|
||||
log.Printf(" ⚠ 取消订单失败 (oid=%d): %v", order.Oid, err)
|
||||
logger.Infof(" ⚠ 取消订单失败 (oid=%d): %v", order.Oid, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf(" ✓ 已取消 %s 的所有挂单", symbol)
|
||||
logger.Infof(" ✓ 已取消 %s 的所有挂单", symbol)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -691,7 +691,7 @@ func (t *HyperliquidTrader) CancelStopOrders(symbol string) error {
|
||||
if order.Coin == coin {
|
||||
_, err := t.exchange.Cancel(t.ctx, coin, order.Oid)
|
||||
if err != nil {
|
||||
log.Printf(" ⚠ 取消订单失败 (oid=%d): %v", order.Oid, err)
|
||||
logger.Infof(" ⚠ 取消订单失败 (oid=%d): %v", order.Oid, err)
|
||||
continue
|
||||
}
|
||||
canceledCount++
|
||||
@@ -699,9 +699,9 @@ func (t *HyperliquidTrader) CancelStopOrders(symbol string) error {
|
||||
}
|
||||
|
||||
if canceledCount == 0 {
|
||||
log.Printf(" ℹ %s 没有挂单需要取消", symbol)
|
||||
logger.Infof(" ℹ %s 没有挂单需要取消", symbol)
|
||||
} else {
|
||||
log.Printf(" ✓ 已取消 %s 的 %d 个挂单(包括止盈/止损单)", symbol, canceledCount)
|
||||
logger.Infof(" ✓ 已取消 %s 的 %d 个挂单(包括止盈/止损单)", symbol, canceledCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -762,7 +762,7 @@ func (t *HyperliquidTrader) SetStopLoss(symbol string, positionSide string, quan
|
||||
return fmt.Errorf("设置止损失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" 止损价设置: %.4f", roundedStopPrice)
|
||||
logger.Infof(" 止损价设置: %.4f", roundedStopPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -799,7 +799,7 @@ func (t *HyperliquidTrader) SetTakeProfit(symbol string, positionSide string, qu
|
||||
return fmt.Errorf("设置止盈失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf(" 止盈价设置: %.4f", roundedTakeProfitPrice)
|
||||
logger.Infof(" 止盈价设置: %.4f", roundedTakeProfitPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -820,7 +820,7 @@ func (t *HyperliquidTrader) getSzDecimals(coin string) int {
|
||||
defer t.metaMutex.RUnlock()
|
||||
|
||||
if t.meta == nil {
|
||||
log.Printf("⚠️ meta信息为空,使用默认精度4")
|
||||
logger.Infof("⚠️ meta信息为空,使用默认精度4")
|
||||
return 4 // 默认精度
|
||||
}
|
||||
|
||||
@@ -831,7 +831,7 @@ func (t *HyperliquidTrader) getSzDecimals(coin string) int {
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("⚠️ 未找到 %s 的精度信息,使用默认精度4", coin)
|
||||
logger.Infof("⚠️ 未找到 %s 的精度信息,使用默认精度4", coin)
|
||||
return 4 // 默认精度
|
||||
}
|
||||
|
||||
@@ -897,6 +897,53 @@ func convertSymbolToHyperliquid(symbol string) string {
|
||||
return symbol
|
||||
}
|
||||
|
||||
// GetOrderStatus 获取订单状态
|
||||
// Hyperliquid 使用 IOC 订单,通常立即成交或取消
|
||||
// 对于已完成的订单,需要查询历史记录
|
||||
func (t *HyperliquidTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
|
||||
// Hyperliquid 的 IOC 订单几乎立即完成
|
||||
// 如果订单是通过本系统下单的,返回的 status 都是 FILLED
|
||||
// 这里尝试查询开放订单来判断是否还在等待
|
||||
coin := convertSymbolToHyperliquid(symbol)
|
||||
|
||||
// 首先检查是否在开放订单中
|
||||
openOrders, err := t.exchange.Info().OpenOrders(t.ctx, t.walletAddr)
|
||||
if err != nil {
|
||||
// 如果查询失败,假设订单已完成
|
||||
return map[string]interface{}{
|
||||
"orderId": orderID,
|
||||
"status": "FILLED",
|
||||
"avgPrice": 0.0,
|
||||
"executedQty": 0.0,
|
||||
"commission": 0.0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 检查订单是否在开放订单列表中
|
||||
for _, order := range openOrders {
|
||||
if order.Coin == coin && fmt.Sprintf("%d", order.Oid) == orderID {
|
||||
// 订单仍在等待
|
||||
return map[string]interface{}{
|
||||
"orderId": orderID,
|
||||
"status": "NEW",
|
||||
"avgPrice": 0.0,
|
||||
"executedQty": 0.0,
|
||||
"commission": 0.0,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 订单不在开放列表中,说明已完成或已取消
|
||||
// Hyperliquid IOC 订单如果不在开放列表中,通常是已成交
|
||||
return map[string]interface{}{
|
||||
"orderId": orderID,
|
||||
"status": "FILLED",
|
||||
"avgPrice": 0.0, // Hyperliquid 不直接返回成交价格,需要从持仓信息获取
|
||||
"executedQty": 0.0,
|
||||
"commission": 0.0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// absFloat 返回浮点数的绝对值
|
||||
func absFloat(x float64) float64 {
|
||||
if x < 0 {
|
||||
|
||||
@@ -50,4 +50,8 @@ type Trader interface {
|
||||
|
||||
// FormatQuantity 格式化数量到正确的精度
|
||||
FormatQuantity(symbol string, quantity float64) (string, error)
|
||||
|
||||
// GetOrderStatus 获取订单状态
|
||||
// 返回: status(FILLED/NEW/CANCELED), avgPrice, executedQty, commission
|
||||
GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@@ -62,7 +62,7 @@ func (t *LighterTrader) CreateOrder(symbol, side string, quantity, price float64
|
||||
return "", err
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER订单已创建 - ID: %s, Symbol: %s, Side: %s, Qty: %.4f",
|
||||
logger.Infof("✓ LIGHTER订单已创建 - ID: %s, Symbol: %s, Side: %s, Qty: %.4f",
|
||||
orderResp.OrderID, symbol, side, quantity)
|
||||
|
||||
return orderResp.OrderID, nil
|
||||
@@ -143,7 +143,7 @@ func (t *LighterTrader) CancelOrder(symbol, orderID string) error {
|
||||
return fmt.Errorf("取消订单失败 (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER订单已取消 - ID: %s", orderID)
|
||||
logger.Infof("✓ LIGHTER订单已取消 - ID: %s", orderID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -160,18 +160,18 @@ func (t *LighterTrader) CancelAllOrders(symbol string) error {
|
||||
}
|
||||
|
||||
if len(orders) == 0 {
|
||||
log.Printf("✓ LIGHTER - 无需取消订单(无活跃订单)")
|
||||
logger.Infof("✓ LIGHTER - 无需取消订单(无活跃订单)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 批量取消
|
||||
for _, order := range orders {
|
||||
if err := t.CancelOrder(symbol, order.OrderID); err != nil {
|
||||
log.Printf("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err)
|
||||
logger.Infof("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER - 已取消 %d 个订单", len(orders))
|
||||
logger.Infof("✓ LIGHTER - 已取消 %d 个订单", len(orders))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -223,8 +223,8 @@ func (t *LighterTrader) GetActiveOrders(symbol string) ([]OrderResponse, error)
|
||||
return orders, nil
|
||||
}
|
||||
|
||||
// GetOrderStatus 获取订单状态
|
||||
func (t *LighterTrader) GetOrderStatus(orderID string) (*OrderResponse, error) {
|
||||
// GetOrderStatus 获取订单状态(实现 Trader 接口)
|
||||
func (t *LighterTrader) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
|
||||
if err := t.ensureAuthToken(); err != nil {
|
||||
return nil, fmt.Errorf("认证令牌无效: %w", err)
|
||||
}
|
||||
@@ -261,20 +261,37 @@ func (t *LighterTrader) GetOrderStatus(orderID string) (*OrderResponse, error) {
|
||||
return nil, fmt.Errorf("解析订单响应失败: %w", err)
|
||||
}
|
||||
|
||||
return &order, nil
|
||||
// 转换状态为统一格式
|
||||
unifiedStatus := order.Status
|
||||
switch order.Status {
|
||||
case "filled":
|
||||
unifiedStatus = "FILLED"
|
||||
case "open":
|
||||
unifiedStatus = "NEW"
|
||||
case "cancelled":
|
||||
unifiedStatus = "CANCELED"
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"orderId": order.OrderID,
|
||||
"status": unifiedStatus,
|
||||
"avgPrice": order.Price,
|
||||
"executedQty": order.FilledQty,
|
||||
"commission": 0.0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CancelStopLossOrders 仅取消止损单(LIGHTER 暂无法区分,取消所有止盈止损单)
|
||||
func (t *LighterTrader) CancelStopLossOrders(symbol string) error {
|
||||
// LIGHTER 暂时无法区分止损和止盈单,取消所有止盈止损单
|
||||
log.Printf(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单")
|
||||
logger.Infof(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单")
|
||||
return t.CancelStopOrders(symbol)
|
||||
}
|
||||
|
||||
// CancelTakeProfitOrders 仅取消止盈单(LIGHTER 暂无法区分,取消所有止盈止损单)
|
||||
func (t *LighterTrader) CancelTakeProfitOrders(symbol string) error {
|
||||
// LIGHTER 暂时无法区分止损和止盈单,取消所有止盈止损单
|
||||
log.Printf(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单")
|
||||
logger.Infof(" ⚠️ LIGHTER 无法区分止损/止盈单,将取消所有止盈止损单")
|
||||
return t.CancelStopOrders(symbol)
|
||||
}
|
||||
|
||||
@@ -295,12 +312,12 @@ func (t *LighterTrader) CancelStopOrders(symbol string) error {
|
||||
// TODO: 需要检查订单类型,只取消止盈止损单
|
||||
// 暂时取消所有订单
|
||||
if err := t.CancelOrder(symbol, order.OrderID); err != nil {
|
||||
log.Printf("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err)
|
||||
logger.Infof("⚠️ 取消订单失败 (ID: %s): %v", order.OrderID, err)
|
||||
} else {
|
||||
canceledCount++
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER - 已取消 %d 个止盈止损单", canceledCount)
|
||||
logger.Infof("✓ LIGHTER - 已取消 %d 个止盈止损单", canceledCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -59,7 +59,7 @@ func NewLighterTrader(privateKeyHex string, walletAddr string, testnet bool) (*L
|
||||
// 从私钥派生钱包地址(如果未提供)
|
||||
if walletAddr == "" {
|
||||
walletAddr = crypto.PubkeyToAddress(*privateKey.Public().(*ecdsa.PublicKey)).Hex()
|
||||
log.Printf("✓ 从私钥派生钱包地址: %s", walletAddr)
|
||||
logger.Infof("✓ 从私钥派生钱包地址: %s", walletAddr)
|
||||
}
|
||||
|
||||
// 选择API URL
|
||||
@@ -78,7 +78,7 @@ func NewLighterTrader(privateKeyHex string, walletAddr string, testnet bool) (*L
|
||||
symbolPrecision: make(map[string]SymbolPrecision),
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr)
|
||||
logger.Infof("✓ LIGHTER交易器初始化成功 (testnet=%v, wallet=%s)", testnet, walletAddr)
|
||||
|
||||
// 初始化账户信息(获取账户索引和API密钥)
|
||||
if err := trader.initializeAccount(); err != nil {
|
||||
@@ -100,7 +100,7 @@ func (t *LighterTrader) initializeAccount() error {
|
||||
t.accountIndex = accountInfo["index"].(int)
|
||||
t.accountMutex.Unlock()
|
||||
|
||||
log.Printf("✓ LIGHTER账户索引: %d", t.accountIndex)
|
||||
logger.Infof("✓ LIGHTER账户索引: %d", t.accountIndex)
|
||||
|
||||
// 2. 生成认证令牌(有效期8小时)
|
||||
if err := t.refreshAuthToken(); err != nil {
|
||||
@@ -153,7 +153,7 @@ func (t *LighterTrader) refreshAuthToken() error {
|
||||
|
||||
// 临时实现:设置过期时间为8小时后
|
||||
t.tokenExpiry = time.Now().Add(8 * time.Hour)
|
||||
log.Printf("✓ 认证令牌已生成(有效期至: %s)", t.tokenExpiry.Format(time.RFC3339))
|
||||
logger.Infof("✓ 认证令牌已生成(有效期至: %s)", t.tokenExpiry.Format(time.RFC3339))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (t *LighterTrader) ensureAuthToken() error {
|
||||
t.accountMutex.RUnlock()
|
||||
|
||||
if expired {
|
||||
log.Println("🔄 认证令牌即将过期,刷新中...")
|
||||
logger.Info("🔄 认证令牌即将过期,刷新中...")
|
||||
return t.refreshAuthToken()
|
||||
}
|
||||
|
||||
@@ -204,12 +204,12 @@ func (t *LighterTrader) GetExchangeType() string {
|
||||
|
||||
// Close 关闭交易器
|
||||
func (t *LighterTrader) Close() error {
|
||||
log.Println("✓ LIGHTER交易器已关闭")
|
||||
logger.Info("✓ LIGHTER交易器已关闭")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run 运行交易器(实现Trader接口)
|
||||
func (t *LighterTrader) Run() error {
|
||||
log.Println("⚠️ LIGHTER交易器的Run方法应由AutoTrader调用")
|
||||
logger.Info("⚠️ LIGHTER交易器的Run方法应由AutoTrader调用")
|
||||
return fmt.Errorf("请使用AutoTrader管理交易器生命周期")
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -76,7 +76,7 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string,
|
||||
// 2. 如果沒有提供錢包地址,從私鑰派生
|
||||
if walletAddr == "" {
|
||||
walletAddr = crypto.PubkeyToAddress(*l1PrivateKey.Public().(*ecdsa.PublicKey)).Hex()
|
||||
log.Printf("✓ 從私鑰派生錢包地址: %s", walletAddr)
|
||||
logger.Infof("✓ 從私鑰派生錢包地址: %s", walletAddr)
|
||||
}
|
||||
|
||||
// 3. 確定 API URL 和 Chain ID
|
||||
@@ -112,8 +112,8 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string,
|
||||
|
||||
// 6. 如果沒有 API Key,提示用戶需要生成
|
||||
if apiKeyPrivateKeyHex == "" {
|
||||
log.Printf("⚠️ 未提供 API Key 私鑰,請調用 GenerateAndRegisterAPIKey() 生成")
|
||||
log.Printf(" 或者從 LIGHTER 官網獲取現有的 API Key")
|
||||
logger.Infof("⚠️ 未提供 API Key 私鑰,請調用 GenerateAndRegisterAPIKey() 生成")
|
||||
logger.Infof(" 或者從 LIGHTER 官網獲取現有的 API Key")
|
||||
return trader, nil
|
||||
}
|
||||
|
||||
@@ -133,12 +133,12 @@ func NewLighterTraderV2(l1PrivateKeyHex, walletAddr, apiKeyPrivateKeyHex string,
|
||||
|
||||
// 8. 驗證 API Key 是否正確
|
||||
if err := trader.checkClient(); err != nil {
|
||||
log.Printf("⚠️ API Key 驗證失敗: %v", err)
|
||||
log.Printf(" 您可能需要重新生成 API Key 或檢查配置")
|
||||
logger.Infof("⚠️ API Key 驗證失敗: %v", err)
|
||||
logger.Infof(" 您可能需要重新生成 API Key 或檢查配置")
|
||||
return trader, err
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER 交易器初始化成功 (account=%d, apiKey=%d, testnet=%v)",
|
||||
logger.Infof("✓ LIGHTER 交易器初始化成功 (account=%d, apiKey=%d, testnet=%v)",
|
||||
trader.accountIndex, trader.apiKeyIndex, testnet)
|
||||
|
||||
return trader, nil
|
||||
@@ -156,7 +156,7 @@ func (t *LighterTraderV2) initializeAccount() error {
|
||||
t.accountIndex = accountInfo.AccountIndex
|
||||
t.accountMutex.Unlock()
|
||||
|
||||
log.Printf("✓ 賬戶索引: %d", t.accountIndex)
|
||||
logger.Infof("✓ 賬戶索引: %d", t.accountIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ func (t *LighterTraderV2) checkClient() error {
|
||||
return fmt.Errorf("API Key 不匹配:本地=%s, 服務器=%s", localPubKey, publicKey)
|
||||
}
|
||||
|
||||
log.Printf("✓ API Key 驗證通過")
|
||||
logger.Infof("✓ API Key 驗證通過")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -249,7 +249,7 @@ func (t *LighterTraderV2) refreshAuthToken() error {
|
||||
t.tokenExpiry = deadline
|
||||
t.accountMutex.Unlock()
|
||||
|
||||
log.Printf("✓ 認證令牌已生成(有效期至: %s)", t.tokenExpiry.Format(time.RFC3339))
|
||||
logger.Infof("✓ 認證令牌已生成(有效期至: %s)", t.tokenExpiry.Format(time.RFC3339))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -260,7 +260,7 @@ func (t *LighterTraderV2) ensureAuthToken() error {
|
||||
t.accountMutex.RUnlock()
|
||||
|
||||
if expired {
|
||||
log.Println("🔄 認證令牌即將過期,刷新中...")
|
||||
logger.Info("🔄 認證令牌即將過期,刷新中...")
|
||||
return t.refreshAuthToken()
|
||||
}
|
||||
|
||||
@@ -274,6 +274,6 @@ func (t *LighterTraderV2) GetExchangeType() string {
|
||||
|
||||
// Cleanup 清理資源
|
||||
func (t *LighterTraderV2) Cleanup() error {
|
||||
log.Println("⏹ LIGHTER 交易器清理完成")
|
||||
logger.Info("⏹ LIGHTER 交易器清理完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -18,7 +18,7 @@ func (t *LighterTraderV2) SetStopLoss(symbol string, positionSide string, quanti
|
||||
return fmt.Errorf("TxClient 未初始化")
|
||||
}
|
||||
|
||||
log.Printf("🛑 LIGHTER 設置止損: %s %s qty=%.4f, stop=%.2f", symbol, positionSide, quantity, stopPrice)
|
||||
logger.Infof("🛑 LIGHTER 設置止損: %s %s qty=%.4f, stop=%.2f", symbol, positionSide, quantity, stopPrice)
|
||||
|
||||
// 確定訂單方向(做空止損用買單,做多止損用賣單)
|
||||
isAsk := (positionSide == "LONG" || positionSide == "long")
|
||||
@@ -29,7 +29,7 @@ func (t *LighterTraderV2) SetStopLoss(symbol string, positionSide string, quanti
|
||||
return fmt.Errorf("設置止損失敗: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER 止損已設置: %.2f", stopPrice)
|
||||
logger.Infof("✓ LIGHTER 止損已設置: %.2f", stopPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (t *LighterTraderV2) SetTakeProfit(symbol string, positionSide string, quan
|
||||
return fmt.Errorf("TxClient 未初始化")
|
||||
}
|
||||
|
||||
log.Printf("🎯 LIGHTER 設置止盈: %s %s qty=%.4f, tp=%.2f", symbol, positionSide, quantity, takeProfitPrice)
|
||||
logger.Infof("🎯 LIGHTER 設置止盈: %s %s qty=%.4f, tp=%.2f", symbol, positionSide, quantity, takeProfitPrice)
|
||||
|
||||
// 確定訂單方向(做空止盈用買單,做多止盈用賣單)
|
||||
isAsk := (positionSide == "LONG" || positionSide == "long")
|
||||
@@ -50,7 +50,7 @@ func (t *LighterTraderV2) SetTakeProfit(symbol string, positionSide string, quan
|
||||
return fmt.Errorf("設置止盈失敗: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER 止盈已設置: %.2f", takeProfitPrice)
|
||||
logger.Infof("✓ LIGHTER 止盈已設置: %.2f", takeProfitPrice)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ func (t *LighterTraderV2) CancelAllOrders(symbol string) error {
|
||||
}
|
||||
|
||||
if len(orders) == 0 {
|
||||
log.Printf("✓ LIGHTER - 無需取消訂單(無活躍訂單)")
|
||||
logger.Infof("✓ LIGHTER - 無需取消訂單(無活躍訂單)")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -79,27 +79,101 @@ func (t *LighterTraderV2) CancelAllOrders(symbol string) error {
|
||||
canceledCount := 0
|
||||
for _, order := range orders {
|
||||
if err := t.CancelOrder(symbol, order.OrderID); err != nil {
|
||||
log.Printf("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err)
|
||||
logger.Infof("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err)
|
||||
} else {
|
||||
canceledCount++
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER - 已取消 %d 個訂單", canceledCount)
|
||||
logger.Infof("✓ LIGHTER - 已取消 %d 個訂單", canceledCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOrderStatus 獲取訂單狀態(實現 Trader 接口)
|
||||
func (t *LighterTraderV2) GetOrderStatus(symbol string, orderID string) (map[string]interface{}, error) {
|
||||
// LIGHTER 使用市價單通常立即成交
|
||||
// 嘗試查詢訂單狀態
|
||||
if err := t.ensureAuthToken(); err != nil {
|
||||
return nil, fmt.Errorf("認證令牌無效: %w", err)
|
||||
}
|
||||
|
||||
// 構建請求 URL
|
||||
endpoint := fmt.Sprintf("%s/api/v1/order/%s", t.baseURL, orderID)
|
||||
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", t.authToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := t.client.Do(req)
|
||||
if err != nil {
|
||||
// 如果查詢失敗,假設訂單已完成
|
||||
return map[string]interface{}{
|
||||
"orderId": orderID,
|
||||
"status": "FILLED",
|
||||
"avgPrice": 0.0,
|
||||
"executedQty": 0.0,
|
||||
"commission": 0.0,
|
||||
}, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"orderId": orderID,
|
||||
"status": "FILLED",
|
||||
"avgPrice": 0.0,
|
||||
"executedQty": 0.0,
|
||||
"commission": 0.0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var order OrderResponse
|
||||
if err := json.Unmarshal(body, &order); err != nil {
|
||||
return map[string]interface{}{
|
||||
"orderId": orderID,
|
||||
"status": "FILLED",
|
||||
"avgPrice": 0.0,
|
||||
"executedQty": 0.0,
|
||||
"commission": 0.0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 轉換狀態為統一格式
|
||||
unifiedStatus := order.Status
|
||||
switch order.Status {
|
||||
case "filled":
|
||||
unifiedStatus = "FILLED"
|
||||
case "open":
|
||||
unifiedStatus = "NEW"
|
||||
case "cancelled":
|
||||
unifiedStatus = "CANCELED"
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"orderId": order.OrderID,
|
||||
"status": unifiedStatus,
|
||||
"avgPrice": order.Price,
|
||||
"executedQty": order.FilledQty,
|
||||
"commission": 0.0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CancelStopLossOrders 僅取消止損單(實現 Trader 接口)
|
||||
func (t *LighterTraderV2) CancelStopLossOrders(symbol string) error {
|
||||
// LIGHTER 暫時無法區分止損和止盈單,取消所有止盈止損單
|
||||
log.Printf("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單")
|
||||
logger.Infof("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單")
|
||||
return t.CancelStopOrders(symbol)
|
||||
}
|
||||
|
||||
// CancelTakeProfitOrders 僅取消止盈單(實現 Trader 接口)
|
||||
func (t *LighterTraderV2) CancelTakeProfitOrders(symbol string) error {
|
||||
// LIGHTER 暫時無法區分止損和止盈單,取消所有止盈止損單
|
||||
log.Printf("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單")
|
||||
logger.Infof("⚠️ LIGHTER 無法區分止損/止盈單,將取消所有止盈止損單")
|
||||
return t.CancelStopOrders(symbol)
|
||||
}
|
||||
|
||||
@@ -124,13 +198,13 @@ func (t *LighterTraderV2) CancelStopOrders(symbol string) error {
|
||||
// TODO: 檢查訂單類型,只取消止盈止損單
|
||||
// 暫時取消所有訂單
|
||||
if err := t.CancelOrder(symbol, order.OrderID); err != nil {
|
||||
log.Printf("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err)
|
||||
logger.Infof("⚠️ 取消訂單失敗 (ID: %s): %v", order.OrderID, err)
|
||||
} else {
|
||||
canceledCount++
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER - 已取消 %d 個止盈止損單", canceledCount)
|
||||
logger.Infof("✓ LIGHTER - 已取消 %d 個止盈止損單", canceledCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -186,7 +260,7 @@ func (t *LighterTraderV2) GetActiveOrders(symbol string) ([]OrderResponse, error
|
||||
return nil, fmt.Errorf("獲取活躍訂單失敗 (code %d): %s", apiResp.Code, apiResp.Message)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER - 獲取到 %d 個活躍訂單", len(apiResp.Data))
|
||||
logger.Infof("✓ LIGHTER - 獲取到 %d 個活躍訂單", len(apiResp.Data))
|
||||
return apiResp.Data, nil
|
||||
}
|
||||
|
||||
@@ -235,7 +309,7 @@ func (t *LighterTraderV2) CancelOrder(symbol, orderID string) error {
|
||||
return fmt.Errorf("提交取消訂單失敗: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER訂單已取消 - ID: %s", orderID)
|
||||
logger.Infof("✓ LIGHTER訂單已取消 - ID: %s", orderID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -291,6 +365,6 @@ func (t *LighterTraderV2) submitCancelOrder(signedTx []byte) (map[string]interfa
|
||||
"status": "cancelled",
|
||||
}
|
||||
|
||||
log.Printf("✓ 取消訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"])
|
||||
logger.Infof("✓ 取消訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"])
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -18,11 +18,11 @@ func (t *LighterTraderV2) OpenLong(symbol string, quantity float64, leverage int
|
||||
return nil, fmt.Errorf("TxClient 未初始化,請先設置 API Key")
|
||||
}
|
||||
|
||||
log.Printf("📈 LIGHTER 開多倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage)
|
||||
logger.Infof("📈 LIGHTER 開多倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage)
|
||||
|
||||
// 1. 設置杠杆(如果需要)
|
||||
if err := t.SetLeverage(symbol, leverage); err != nil {
|
||||
log.Printf("⚠️ 設置杠杆失敗: %v", err)
|
||||
logger.Infof("⚠️ 設置杠杆失敗: %v", err)
|
||||
}
|
||||
|
||||
// 2. 獲取市場價格
|
||||
@@ -37,7 +37,7 @@ func (t *LighterTraderV2) OpenLong(symbol string, quantity float64, leverage int
|
||||
return nil, fmt.Errorf("開多倉失敗: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER 開多倉成功: %s @ %.2f", symbol, marketPrice)
|
||||
logger.Infof("✓ LIGHTER 開多倉成功: %s @ %.2f", symbol, marketPrice)
|
||||
|
||||
return map[string]interface{}{
|
||||
"orderId": orderResult["orderId"],
|
||||
@@ -54,11 +54,11 @@ func (t *LighterTraderV2) OpenShort(symbol string, quantity float64, leverage in
|
||||
return nil, fmt.Errorf("TxClient 未初始化,請先設置 API Key")
|
||||
}
|
||||
|
||||
log.Printf("📉 LIGHTER 開空倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage)
|
||||
logger.Infof("📉 LIGHTER 開空倉: %s, qty=%.4f, leverage=%dx", symbol, quantity, leverage)
|
||||
|
||||
// 1. 設置杠杆
|
||||
if err := t.SetLeverage(symbol, leverage); err != nil {
|
||||
log.Printf("⚠️ 設置杠杆失敗: %v", err)
|
||||
logger.Infof("⚠️ 設置杠杆失敗: %v", err)
|
||||
}
|
||||
|
||||
// 2. 獲取市場價格
|
||||
@@ -73,7 +73,7 @@ func (t *LighterTraderV2) OpenShort(symbol string, quantity float64, leverage in
|
||||
return nil, fmt.Errorf("開空倉失敗: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER 開空倉成功: %s @ %.2f", symbol, marketPrice)
|
||||
logger.Infof("✓ LIGHTER 開空倉成功: %s @ %.2f", symbol, marketPrice)
|
||||
|
||||
return map[string]interface{}{
|
||||
"orderId": orderResult["orderId"],
|
||||
@@ -105,7 +105,7 @@ func (t *LighterTraderV2) CloseLong(symbol string, quantity float64) (map[string
|
||||
quantity = pos.Size
|
||||
}
|
||||
|
||||
log.Printf("🔻 LIGHTER 平多倉: %s, qty=%.4f", symbol, quantity)
|
||||
logger.Infof("🔻 LIGHTER 平多倉: %s, qty=%.4f", symbol, quantity)
|
||||
|
||||
// 創建市價賣出單平倉(reduceOnly=true)
|
||||
orderResult, err := t.CreateOrder(symbol, true, quantity, 0, "market")
|
||||
@@ -115,10 +115,10 @@ func (t *LighterTraderV2) CloseLong(symbol string, quantity float64) (map[string
|
||||
|
||||
// 平倉後取消所有掛單
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf("⚠️ 取消掛單失敗: %v", err)
|
||||
logger.Infof("⚠️ 取消掛單失敗: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER 平多倉成功: %s", symbol)
|
||||
logger.Infof("✓ LIGHTER 平多倉成功: %s", symbol)
|
||||
|
||||
return map[string]interface{}{
|
||||
"orderId": orderResult["orderId"],
|
||||
@@ -148,7 +148,7 @@ func (t *LighterTraderV2) CloseShort(symbol string, quantity float64) (map[strin
|
||||
quantity = pos.Size
|
||||
}
|
||||
|
||||
log.Printf("🔺 LIGHTER 平空倉: %s, qty=%.4f", symbol, quantity)
|
||||
logger.Infof("🔺 LIGHTER 平空倉: %s, qty=%.4f", symbol, quantity)
|
||||
|
||||
// 創建市價買入單平倉(reduceOnly=true)
|
||||
orderResult, err := t.CreateOrder(symbol, false, quantity, 0, "market")
|
||||
@@ -158,10 +158,10 @@ func (t *LighterTraderV2) CloseShort(symbol string, quantity float64) (map[strin
|
||||
|
||||
// 平倉後取消所有掛單
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf("⚠️ 取消掛單失敗: %v", err)
|
||||
logger.Infof("⚠️ 取消掛單失敗: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER 平空倉成功: %s", symbol)
|
||||
logger.Infof("✓ LIGHTER 平空倉成功: %s", symbol)
|
||||
|
||||
return map[string]interface{}{
|
||||
"orderId": orderResult["orderId"],
|
||||
@@ -235,7 +235,7 @@ func (t *LighterTraderV2) CreateOrder(symbol string, isAsk bool, quantity float6
|
||||
if isAsk {
|
||||
side = "sell"
|
||||
}
|
||||
log.Printf("✓ LIGHTER訂單已創建: %s %s qty=%.4f", symbol, side, quantity)
|
||||
logger.Infof("✓ LIGHTER訂單已創建: %s %s qty=%.4f", symbol, side, quantity)
|
||||
|
||||
return orderResp, nil
|
||||
}
|
||||
@@ -315,7 +315,7 @@ func (t *LighterTraderV2) submitOrder(signedTx []byte) (map[string]interface{},
|
||||
result["orderId"] = txHash
|
||||
}
|
||||
|
||||
log.Printf("✓ 訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"])
|
||||
logger.Infof("✓ 訂單已提交到 LIGHTER - tx_hash: %v", sendResp.Data["tx_hash"])
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -334,7 +334,7 @@ func (t *LighterTraderV2) getMarketIndex(symbol string) (uint8, error) {
|
||||
markets, err := t.fetchMarketList()
|
||||
if err != nil {
|
||||
// 如果 API 失敗,回退到硬編碼映射
|
||||
log.Printf("⚠️ 從 API 獲取市場列表失敗,使用硬編碼映射: %v", err)
|
||||
logger.Infof("⚠️ 從 API 獲取市場列表失敗,使用硬編碼映射: %v", err)
|
||||
return t.getFallbackMarketIndex(symbol)
|
||||
}
|
||||
|
||||
@@ -412,7 +412,7 @@ func (t *LighterTraderV2) fetchMarketList() ([]MarketInfo, error) {
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("✓ 獲取到 %d 個市場", len(markets))
|
||||
logger.Infof("✓ 獲取到 %d 個市場", len(markets))
|
||||
return markets, nil
|
||||
}
|
||||
|
||||
@@ -428,7 +428,7 @@ func (t *LighterTraderV2) getFallbackMarketIndex(symbol string) (uint8, error) {
|
||||
}
|
||||
|
||||
if index, ok := fallbackMap[symbol]; ok {
|
||||
log.Printf("✓ 使用硬編碼市場索引: %s -> %d", symbol, index)
|
||||
logger.Infof("✓ 使用硬編碼市場索引: %s -> %d", symbol, index)
|
||||
return index, nil
|
||||
}
|
||||
|
||||
@@ -442,7 +442,7 @@ func (t *LighterTraderV2) SetLeverage(symbol string, leverage int) error {
|
||||
}
|
||||
|
||||
// TODO: 使用SDK簽名並提交SetLeverage交易
|
||||
log.Printf("⚙️ 設置杠杆: %s = %dx", symbol, leverage)
|
||||
logger.Infof("⚙️ 設置杠杆: %s = %dx", symbol, leverage)
|
||||
|
||||
return nil // 暫時返回成功
|
||||
}
|
||||
@@ -458,7 +458,7 @@ func (t *LighterTraderV2) SetMarginMode(symbol string, isCrossMargin bool) error
|
||||
modeStr = "全倉"
|
||||
}
|
||||
|
||||
log.Printf("⚙️ 設置倉位模式: %s = %s", symbol, modeStr)
|
||||
logger.Infof("⚙️ 設置倉位模式: %s = %s", symbol, modeStr)
|
||||
|
||||
// TODO: 使用SDK簽名並提交SetMarginMode交易
|
||||
return nil
|
||||
|
||||
@@ -2,13 +2,13 @@ package trader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"nofx/logger"
|
||||
)
|
||||
|
||||
// OpenLong 开多仓
|
||||
func (t *LighterTrader) OpenLong(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// TODO: 实现完整的开多仓逻辑
|
||||
log.Printf("🚧 LIGHTER OpenLong 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage)
|
||||
logger.Infof("🚧 LIGHTER OpenLong 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage)
|
||||
|
||||
// 使用市价买入单
|
||||
orderID, err := t.CreateOrder(symbol, "buy", quantity, 0, "market")
|
||||
@@ -26,7 +26,7 @@ func (t *LighterTrader) OpenLong(symbol string, quantity float64, leverage int)
|
||||
// OpenShort 开空仓
|
||||
func (t *LighterTrader) OpenShort(symbol string, quantity float64, leverage int) (map[string]interface{}, error) {
|
||||
// TODO: 实现完整的开空仓逻辑
|
||||
log.Printf("🚧 LIGHTER OpenShort 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage)
|
||||
logger.Infof("🚧 LIGHTER OpenShort 暂未完全实现 (symbol=%s, qty=%.4f, leverage=%d)", symbol, quantity, leverage)
|
||||
|
||||
// 使用市价卖出单
|
||||
orderID, err := t.CreateOrder(symbol, "sell", quantity, 0, "market")
|
||||
@@ -66,7 +66,7 @@ func (t *LighterTrader) CloseLong(symbol string, quantity float64) (map[string]i
|
||||
|
||||
// 平仓后取消所有挂单
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
@@ -101,7 +101,7 @@ func (t *LighterTrader) CloseShort(symbol string, quantity float64) (map[string]
|
||||
|
||||
// 平仓后取消所有挂单
|
||||
if err := t.CancelAllOrders(symbol); err != nil {
|
||||
log.Printf(" ⚠ 取消挂单失败: %v", err)
|
||||
logger.Infof(" ⚠ 取消挂单失败: %v", err)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
@@ -114,7 +114,7 @@ func (t *LighterTrader) CloseShort(symbol string, quantity float64) (map[string]
|
||||
// SetStopLoss 设置止损单
|
||||
func (t *LighterTrader) SetStopLoss(symbol string, positionSide string, quantity, stopPrice float64) error {
|
||||
// TODO: 实现完整的止损单逻辑
|
||||
log.Printf("🚧 LIGHTER SetStopLoss 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, stop=%.2f)", symbol, positionSide, quantity, stopPrice)
|
||||
logger.Infof("🚧 LIGHTER SetStopLoss 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, stop=%.2f)", symbol, positionSide, quantity, stopPrice)
|
||||
|
||||
// 确定订单方向(做空止损用买单,做多止损用卖单)
|
||||
side := "sell"
|
||||
@@ -128,14 +128,14 @@ func (t *LighterTrader) SetStopLoss(symbol string, positionSide string, quantity
|
||||
return fmt.Errorf("设置止损失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER - 止损已设置: %.2f (side: %s)", stopPrice, side)
|
||||
logger.Infof("✓ LIGHTER - 止损已设置: %.2f (side: %s)", stopPrice, side)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTakeProfit 设置止盈单
|
||||
func (t *LighterTrader) SetTakeProfit(symbol string, positionSide string, quantity, takeProfitPrice float64) error {
|
||||
// TODO: 实现完整的止盈单逻辑
|
||||
log.Printf("🚧 LIGHTER SetTakeProfit 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, tp=%.2f)", symbol, positionSide, quantity, takeProfitPrice)
|
||||
logger.Infof("🚧 LIGHTER SetTakeProfit 暂未完全实现 (symbol=%s, side=%s, qty=%.4f, tp=%.2f)", symbol, positionSide, quantity, takeProfitPrice)
|
||||
|
||||
// 确定订单方向(做空止盈用买单,做多止盈用卖单)
|
||||
side := "sell"
|
||||
@@ -149,7 +149,7 @@ func (t *LighterTrader) SetTakeProfit(symbol string, positionSide string, quanti
|
||||
return fmt.Errorf("设置止盈失败: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("✓ LIGHTER - 止盈已设置: %.2f (side: %s)", takeProfitPrice, side)
|
||||
logger.Infof("✓ LIGHTER - 止盈已设置: %.2f (side: %s)", takeProfitPrice, side)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ func (t *LighterTrader) SetMarginMode(symbol string, isCrossMargin bool) error {
|
||||
if isCrossMargin {
|
||||
modeStr = "全仓"
|
||||
}
|
||||
log.Printf("🚧 LIGHTER SetMarginMode 暂未实现 (symbol=%s, mode=%s)", symbol, modeStr)
|
||||
logger.Infof("🚧 LIGHTER SetMarginMode 暂未实现 (symbol=%s, mode=%s)", symbol, modeStr)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
309
trader/order_sync.go
Normal file
309
trader/order_sync.go
Normal file
@@ -0,0 +1,309 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"nofx/logger"
|
||||
"nofx/store"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OrderSyncManager 订单状态同步管理器
|
||||
// 负责定期扫描所有 NEW 状态的订单,并更新其状态
|
||||
type OrderSyncManager struct {
|
||||
store *store.Store
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
traderCache map[string]Trader // trader_id -> Trader 实例缓存
|
||||
configCache map[string]*store.TraderFullConfig // trader_id -> 配置缓存
|
||||
cacheMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewOrderSyncManager 创建订单同步管理器
|
||||
func NewOrderSyncManager(st *store.Store, interval time.Duration) *OrderSyncManager {
|
||||
if interval == 0 {
|
||||
interval = 10 * time.Second
|
||||
}
|
||||
return &OrderSyncManager{
|
||||
store: st,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
traderCache: make(map[string]Trader),
|
||||
configCache: make(map[string]*store.TraderFullConfig),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动订单同步服务
|
||||
func (m *OrderSyncManager) Start() {
|
||||
m.wg.Add(1)
|
||||
go m.run()
|
||||
logger.Info("📦 订单同步管理器已启动")
|
||||
}
|
||||
|
||||
// Stop 停止订单同步服务
|
||||
func (m *OrderSyncManager) Stop() {
|
||||
close(m.stopCh)
|
||||
m.wg.Wait()
|
||||
|
||||
// 清理缓存
|
||||
m.cacheMutex.Lock()
|
||||
m.traderCache = make(map[string]Trader)
|
||||
m.configCache = make(map[string]*store.TraderFullConfig)
|
||||
m.cacheMutex.Unlock()
|
||||
|
||||
logger.Info("📦 订单同步管理器已停止")
|
||||
}
|
||||
|
||||
// run 主循环
|
||||
func (m *OrderSyncManager) run() {
|
||||
defer m.wg.Done()
|
||||
|
||||
// 启动时立即执行一次
|
||||
m.syncOrders()
|
||||
|
||||
ticker := time.NewTicker(m.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.syncOrders()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// syncOrders 同步所有待处理订单
|
||||
func (m *OrderSyncManager) syncOrders() {
|
||||
// 获取所有 NEW 状态的订单
|
||||
orders, err := m.store.Order().GetAllPendingOrders()
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ 获取待处理订单失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(orders) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof("📦 开始同步 %d 个待处理订单...", len(orders))
|
||||
|
||||
// 按 trader_id 分组
|
||||
ordersByTrader := make(map[string][]*store.TraderOrder)
|
||||
for _, order := range orders {
|
||||
ordersByTrader[order.TraderID] = append(ordersByTrader[order.TraderID], order)
|
||||
}
|
||||
|
||||
// 逐个 trader 处理
|
||||
for traderID, traderOrders := range ordersByTrader {
|
||||
m.syncTraderOrders(traderID, traderOrders)
|
||||
}
|
||||
}
|
||||
|
||||
// syncTraderOrders 同步单个 trader 的订单
|
||||
func (m *OrderSyncManager) syncTraderOrders(traderID string, orders []*store.TraderOrder) {
|
||||
// 获取或创建 trader 实例
|
||||
trader, err := m.getOrCreateTrader(traderID)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ 获取 trader 实例失败 (ID: %s): %v", traderID, err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, order := range orders {
|
||||
m.syncSingleOrder(trader, order)
|
||||
}
|
||||
}
|
||||
|
||||
// syncSingleOrder 同步单个订单状态
|
||||
func (m *OrderSyncManager) syncSingleOrder(trader Trader, order *store.TraderOrder) {
|
||||
status, err := trader.GetOrderStatus(order.Symbol, order.OrderID)
|
||||
if err != nil {
|
||||
// 查询失败,检查订单创建时间,超过一定时间假设已成交
|
||||
if time.Since(order.CreatedAt) > 5*time.Minute {
|
||||
logger.Infof("⚠️ 订单查询超时,假设已成交 (ID: %s)", order.OrderID)
|
||||
m.markOrderFilled(order, 0, 0, 0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
statusStr, _ := status["status"].(string)
|
||||
|
||||
switch statusStr {
|
||||
case "FILLED":
|
||||
avgPrice, _ := status["avgPrice"].(float64)
|
||||
executedQty, _ := status["executedQty"].(float64)
|
||||
commission, _ := status["commission"].(float64)
|
||||
|
||||
// 如果 API 未返回数量,使用原始数量
|
||||
if executedQty == 0 {
|
||||
executedQty = order.Quantity
|
||||
}
|
||||
|
||||
m.markOrderFilled(order, avgPrice, executedQty, commission)
|
||||
|
||||
case "CANCELED", "EXPIRED":
|
||||
order.Status = statusStr
|
||||
if err := m.store.Order().Update(order); err != nil {
|
||||
logger.Infof("⚠️ 更新订单状态失败: %v", err)
|
||||
} else {
|
||||
logger.Infof("📦 订单状态更新: %s (ID: %s)", statusStr, order.OrderID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// markOrderFilled 标记订单已成交
|
||||
func (m *OrderSyncManager) markOrderFilled(order *store.TraderOrder, avgPrice, executedQty, commission float64) {
|
||||
// 如果 avgPrice 为 0,使用订单价格
|
||||
if avgPrice == 0 {
|
||||
avgPrice = order.Price
|
||||
}
|
||||
if executedQty == 0 {
|
||||
executedQty = order.Quantity
|
||||
}
|
||||
|
||||
// 计算已实现盈亏(仅平仓订单)
|
||||
var realizedPnL float64
|
||||
if (order.Action == "close_long" || order.Action == "close_short") && order.EntryPrice > 0 && avgPrice > 0 {
|
||||
if order.Action == "close_long" {
|
||||
// 平多盈亏 = (平仓价 - 开仓价) * 数量
|
||||
realizedPnL = (avgPrice - order.EntryPrice) * executedQty
|
||||
} else {
|
||||
// 平空盈亏 = (开仓价 - 平仓价) * 数量
|
||||
realizedPnL = (order.EntryPrice - avgPrice) * executedQty
|
||||
}
|
||||
}
|
||||
|
||||
order.AvgPrice = avgPrice
|
||||
order.ExecutedQty = executedQty
|
||||
order.Status = "FILLED"
|
||||
order.Fee = commission
|
||||
order.RealizedPnL = realizedPnL
|
||||
order.FilledAt = time.Now()
|
||||
|
||||
if err := m.store.Order().Update(order); err != nil {
|
||||
logger.Infof("⚠️ 更新订单状态失败: %v", err)
|
||||
} else {
|
||||
if realizedPnL != 0 {
|
||||
logger.Infof("✅ 订单已成交 (ID: %s, avgPrice: %.4f, qty: %.4f, PnL: %.2f)",
|
||||
order.OrderID, avgPrice, executedQty, realizedPnL)
|
||||
} else {
|
||||
logger.Infof("✅ 订单已成交 (ID: %s, avgPrice: %.4f, qty: %.4f)",
|
||||
order.OrderID, avgPrice, executedQty)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateTrader 获取或创建 trader 实例
|
||||
func (m *OrderSyncManager) getOrCreateTrader(traderID string) (Trader, error) {
|
||||
m.cacheMutex.RLock()
|
||||
trader, exists := m.traderCache[traderID]
|
||||
m.cacheMutex.RUnlock()
|
||||
|
||||
if exists && trader != nil {
|
||||
return trader, nil
|
||||
}
|
||||
|
||||
// 需要创建新的 trader 实例
|
||||
// 首先获取 trader 配置
|
||||
config, err := m.getTraderConfig(traderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 trader 配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 根据交易所类型创建 trader
|
||||
trader, err = m.createTrader(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 trader 实例失败: %w", err)
|
||||
}
|
||||
|
||||
m.cacheMutex.Lock()
|
||||
m.traderCache[traderID] = trader
|
||||
m.cacheMutex.Unlock()
|
||||
|
||||
return trader, nil
|
||||
}
|
||||
|
||||
// getTraderConfig 获取 trader 配置
|
||||
func (m *OrderSyncManager) getTraderConfig(traderID string) (*store.TraderFullConfig, error) {
|
||||
m.cacheMutex.RLock()
|
||||
config, exists := m.configCache[traderID]
|
||||
m.cacheMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// 从数据库获取 - 需要找到 trader 对应的 userID
|
||||
// 首先查询所有 traders 找到对应的 userID
|
||||
traders, err := m.store.Trader().ListAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 trader 列表失败: %w", err)
|
||||
}
|
||||
|
||||
var userID string
|
||||
for _, t := range traders {
|
||||
if t.ID == traderID {
|
||||
userID = t.UserID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("找不到 trader: %s", traderID)
|
||||
}
|
||||
|
||||
config, err = m.store.Trader().GetFullConfig(userID, traderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.cacheMutex.Lock()
|
||||
m.configCache[traderID] = config
|
||||
m.cacheMutex.Unlock()
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// createTrader 根据配置创建 trader 实例
|
||||
func (m *OrderSyncManager) createTrader(config *store.TraderFullConfig) (Trader, error) {
|
||||
exchange := config.Exchange
|
||||
|
||||
switch exchange.Type {
|
||||
case "binance":
|
||||
return NewFuturesTrader(exchange.APIKey, exchange.SecretKey, config.Trader.UserID), nil
|
||||
|
||||
case "bybit":
|
||||
return NewBybitTrader(exchange.APIKey, exchange.SecretKey), nil
|
||||
|
||||
case "hyperliquid":
|
||||
return NewHyperliquidTrader(exchange.SecretKey, exchange.HyperliquidWalletAddr, exchange.Testnet)
|
||||
|
||||
case "aster":
|
||||
return NewAsterTrader(exchange.AsterUser, exchange.AsterSigner, exchange.AsterPrivateKey)
|
||||
|
||||
case "lighter":
|
||||
if exchange.LighterAPIKeyPrivateKey != "" {
|
||||
return NewLighterTraderV2(
|
||||
exchange.LighterPrivateKey,
|
||||
exchange.LighterWalletAddr,
|
||||
exchange.LighterAPIKeyPrivateKey,
|
||||
exchange.Testnet,
|
||||
)
|
||||
}
|
||||
return NewLighterTrader(exchange.LighterPrivateKey, exchange.LighterWalletAddr, exchange.Testnet)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的交易所类型: %s", exchange.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache 使缓存失效(当配置变更时调用)
|
||||
func (m *OrderSyncManager) InvalidateCache(traderID string) {
|
||||
m.cacheMutex.Lock()
|
||||
defer m.cacheMutex.Unlock()
|
||||
|
||||
delete(m.traderCache, traderID)
|
||||
delete(m.configCache, traderID)
|
||||
}
|
||||
@@ -1,393 +0,0 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"nofx/decision"
|
||||
"nofx/logger"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// MockPartialCloseTrader 用於測試 partial close 邏輯
|
||||
type MockPartialCloseTrader struct {
|
||||
positions []map[string]interface{}
|
||||
closePartialCalled bool
|
||||
closeLongCalled bool
|
||||
closeShortCalled bool
|
||||
stopLossCalled bool
|
||||
takeProfitCalled bool
|
||||
lastStopLoss float64
|
||||
lastTakeProfit float64
|
||||
}
|
||||
|
||||
func (m *MockPartialCloseTrader) GetPositions() ([]map[string]interface{}, error) {
|
||||
return m.positions, nil
|
||||
}
|
||||
|
||||
func (m *MockPartialCloseTrader) ClosePartialLong(symbol string, quantity float64) (map[string]interface{}, error) {
|
||||
m.closePartialCalled = true
|
||||
return map[string]interface{}{"orderId": "12345"}, nil
|
||||
}
|
||||
|
||||
func (m *MockPartialCloseTrader) ClosePartialShort(symbol string, quantity float64) (map[string]interface{}, error) {
|
||||
m.closePartialCalled = true
|
||||
return map[string]interface{}{"orderId": "12345"}, nil
|
||||
}
|
||||
|
||||
func (m *MockPartialCloseTrader) CloseLong(symbol string, quantity float64) (map[string]interface{}, error) {
|
||||
m.closeLongCalled = true
|
||||
return map[string]interface{}{"orderId": "12346"}, nil
|
||||
}
|
||||
|
||||
func (m *MockPartialCloseTrader) CloseShort(symbol string, quantity float64) (map[string]interface{}, error) {
|
||||
m.closeShortCalled = true
|
||||
return map[string]interface{}{"orderId": "12346"}, nil
|
||||
}
|
||||
|
||||
func (m *MockPartialCloseTrader) SetStopLoss(symbol, side string, quantity, price float64) error {
|
||||
m.stopLossCalled = true
|
||||
m.lastStopLoss = price
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockPartialCloseTrader) SetTakeProfit(symbol, side string, quantity, price float64) error {
|
||||
m.takeProfitCalled = true
|
||||
m.lastTakeProfit = price
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestPartialCloseMinPositionCheck 測試最小倉位檢查邏輯
|
||||
func TestPartialCloseMinPositionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
totalQuantity float64
|
||||
markPrice float64
|
||||
closePercentage float64
|
||||
expectFullClose bool // 是否應該觸發全平邏輯
|
||||
expectRemainValue float64
|
||||
}{
|
||||
{
|
||||
name: "正常部分平倉_剩餘價值充足",
|
||||
totalQuantity: 1.0,
|
||||
markPrice: 100.0,
|
||||
closePercentage: 50.0,
|
||||
expectFullClose: false,
|
||||
expectRemainValue: 50.0, // 剩餘 0.5 * 100 = 50 USDT
|
||||
},
|
||||
{
|
||||
name: "部分平倉_剩餘價值小於10USDT_應該全平",
|
||||
totalQuantity: 0.2,
|
||||
markPrice: 100.0,
|
||||
closePercentage: 95.0, // 平倉 95%,剩餘 1 USDT (0.2 * 5% * 100)
|
||||
expectFullClose: true,
|
||||
expectRemainValue: 1.0,
|
||||
},
|
||||
{
|
||||
name: "部分平倉_剩餘價值剛好10USDT_應該全平",
|
||||
totalQuantity: 1.0,
|
||||
markPrice: 100.0,
|
||||
closePercentage: 90.0, // 剩餘 10 USDT (1.0 * 10% * 100),邊界測試 (<=)
|
||||
expectFullClose: true,
|
||||
expectRemainValue: 10.0,
|
||||
},
|
||||
{
|
||||
name: "部分平倉_剩餘價值11USDT_不應全平",
|
||||
totalQuantity: 1.1,
|
||||
markPrice: 100.0,
|
||||
closePercentage: 90.0, // 剩餘 11 USDT (1.1 * 10% * 100)
|
||||
expectFullClose: false,
|
||||
expectRemainValue: 11.0,
|
||||
},
|
||||
{
|
||||
name: "大倉位部分平倉_剩餘價值遠大於10USDT",
|
||||
totalQuantity: 10.0,
|
||||
markPrice: 1000.0,
|
||||
closePercentage: 80.0,
|
||||
expectFullClose: false,
|
||||
expectRemainValue: 2000.0, // 剩餘 2 * 1000 = 2000 USDT
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 計算剩餘價值
|
||||
closeQuantity := tt.totalQuantity * (tt.closePercentage / 100.0)
|
||||
remainingQuantity := tt.totalQuantity - closeQuantity
|
||||
remainingValue := remainingQuantity * tt.markPrice
|
||||
|
||||
// 驗證計算(使用浮點數比較允許微小誤差)
|
||||
const epsilon = 0.001
|
||||
if remainingValue-tt.expectRemainValue > epsilon || tt.expectRemainValue-remainingValue > epsilon {
|
||||
t.Errorf("計算錯誤: 剩餘價值 = %.2f, 期望 = %.2f",
|
||||
remainingValue, tt.expectRemainValue)
|
||||
}
|
||||
|
||||
// 驗證最小倉位檢查邏輯
|
||||
const MIN_POSITION_VALUE = 10.0
|
||||
shouldFullClose := remainingValue > 0 && remainingValue <= MIN_POSITION_VALUE
|
||||
|
||||
if shouldFullClose != tt.expectFullClose {
|
||||
t.Errorf("最小倉位檢查失敗: shouldFullClose = %v, 期望 = %v (剩餘價值 = %.2f USDT)",
|
||||
shouldFullClose, tt.expectFullClose, remainingValue)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPartialCloseWithStopLossTakeProfitRecovery 測試止盈止損恢復邏輯
|
||||
func TestPartialCloseWithStopLossTakeProfitRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
newStopLoss float64
|
||||
newTakeProfit float64
|
||||
expectStopLoss bool
|
||||
expectTakeProfit bool
|
||||
}{
|
||||
{
|
||||
name: "有新止損和止盈_應該恢復兩者",
|
||||
newStopLoss: 95.0,
|
||||
newTakeProfit: 110.0,
|
||||
expectStopLoss: true,
|
||||
expectTakeProfit: true,
|
||||
},
|
||||
{
|
||||
name: "只有新止損_僅恢復止損",
|
||||
newStopLoss: 95.0,
|
||||
newTakeProfit: 0,
|
||||
expectStopLoss: true,
|
||||
expectTakeProfit: false,
|
||||
},
|
||||
{
|
||||
name: "只有新止盈_僅恢復止盈",
|
||||
newStopLoss: 0,
|
||||
newTakeProfit: 110.0,
|
||||
expectStopLoss: false,
|
||||
expectTakeProfit: true,
|
||||
},
|
||||
{
|
||||
name: "沒有新止損止盈_不恢復",
|
||||
newStopLoss: 0,
|
||||
newTakeProfit: 0,
|
||||
expectStopLoss: false,
|
||||
expectTakeProfit: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模擬止盈止損恢復邏輯
|
||||
stopLossRecovered := tt.newStopLoss > 0
|
||||
takeProfitRecovered := tt.newTakeProfit > 0
|
||||
|
||||
if stopLossRecovered != tt.expectStopLoss {
|
||||
t.Errorf("止損恢復邏輯錯誤: recovered = %v, 期望 = %v",
|
||||
stopLossRecovered, tt.expectStopLoss)
|
||||
}
|
||||
|
||||
if takeProfitRecovered != tt.expectTakeProfit {
|
||||
t.Errorf("止盈恢復邏輯錯誤: recovered = %v, 期望 = %v",
|
||||
takeProfitRecovered, tt.expectTakeProfit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPartialCloseEdgeCases 測試邊界情況
|
||||
func TestPartialCloseEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
closePercentage float64
|
||||
totalQuantity float64
|
||||
markPrice float64
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "平倉百分比為0_應該報錯",
|
||||
closePercentage: 0,
|
||||
totalQuantity: 1.0,
|
||||
markPrice: 100.0,
|
||||
expectError: true,
|
||||
errorContains: "0-100",
|
||||
},
|
||||
{
|
||||
name: "平倉百分比超過100_應該報錯",
|
||||
closePercentage: 101.0,
|
||||
totalQuantity: 1.0,
|
||||
markPrice: 100.0,
|
||||
expectError: true,
|
||||
errorContains: "0-100",
|
||||
},
|
||||
{
|
||||
name: "平倉百分比為負數_應該報錯",
|
||||
closePercentage: -10.0,
|
||||
totalQuantity: 1.0,
|
||||
markPrice: 100.0,
|
||||
expectError: true,
|
||||
errorContains: "0-100",
|
||||
},
|
||||
{
|
||||
name: "正常範圍_不應報錯",
|
||||
closePercentage: 50.0,
|
||||
totalQuantity: 1.0,
|
||||
markPrice: 100.0,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模擬百分比驗證邏輯
|
||||
var err error
|
||||
if tt.closePercentage <= 0 || tt.closePercentage > 100 {
|
||||
err = fmt.Errorf("平仓百分比必须在 0-100 之间,当前: %.1f", tt.closePercentage)
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("期望報錯但沒有報錯")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("不應報錯但報錯了: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPartialCloseIntegration 整合測試(使用 mock trader)
|
||||
func TestPartialCloseIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
symbol string
|
||||
side string
|
||||
totalQuantity float64
|
||||
markPrice float64
|
||||
closePercentage float64
|
||||
newStopLoss float64
|
||||
newTakeProfit float64
|
||||
expectFullClose bool
|
||||
expectStopLossCall bool
|
||||
expectTakeProfitCall bool
|
||||
}{
|
||||
{
|
||||
name: "LONG倉_正常部分平倉_有止盈止損",
|
||||
symbol: "BTCUSDT",
|
||||
side: "LONG",
|
||||
totalQuantity: 1.0,
|
||||
markPrice: 50000.0,
|
||||
closePercentage: 50.0,
|
||||
newStopLoss: 48000.0,
|
||||
newTakeProfit: 52000.0,
|
||||
expectFullClose: false,
|
||||
expectStopLossCall: true,
|
||||
expectTakeProfitCall: true,
|
||||
},
|
||||
{
|
||||
name: "SHORT倉_剩餘價值過小_應自動全平",
|
||||
symbol: "ETHUSDT",
|
||||
side: "SHORT",
|
||||
totalQuantity: 0.02,
|
||||
markPrice: 3000.0, // 總價值 60 USDT
|
||||
closePercentage: 95.0, // 剩餘 3 USDT < 10 USDT
|
||||
newStopLoss: 3100.0,
|
||||
newTakeProfit: 2900.0,
|
||||
expectFullClose: true,
|
||||
expectStopLossCall: false, // 全平不需要恢復止盈止損
|
||||
expectTakeProfitCall: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 創建 mock trader
|
||||
mockTrader := &MockPartialCloseTrader{
|
||||
positions: []map[string]interface{}{
|
||||
{
|
||||
"symbol": tt.symbol,
|
||||
"side": tt.side,
|
||||
"quantity": tt.totalQuantity,
|
||||
"markPrice": tt.markPrice,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 創建決策
|
||||
dec := &decision.Decision{
|
||||
Symbol: tt.symbol,
|
||||
Action: "partial_close",
|
||||
ClosePercentage: tt.closePercentage,
|
||||
NewStopLoss: tt.newStopLoss,
|
||||
NewTakeProfit: tt.newTakeProfit,
|
||||
}
|
||||
|
||||
// 創建 actionRecord
|
||||
actionRecord := &logger.DecisionAction{}
|
||||
|
||||
// 計算剩餘價值
|
||||
closeQuantity := tt.totalQuantity * (tt.closePercentage / 100.0)
|
||||
remainingQuantity := tt.totalQuantity - closeQuantity
|
||||
remainingValue := remainingQuantity * tt.markPrice
|
||||
|
||||
// 驗證最小倉位檢查
|
||||
const MIN_POSITION_VALUE = 10.0
|
||||
shouldFullClose := remainingValue > 0 && remainingValue <= MIN_POSITION_VALUE
|
||||
|
||||
if shouldFullClose != tt.expectFullClose {
|
||||
t.Errorf("最小倉位檢查不符: shouldFullClose = %v, 期望 = %v (剩餘 %.2f USDT)",
|
||||
shouldFullClose, tt.expectFullClose, remainingValue)
|
||||
}
|
||||
|
||||
// 模擬執行邏輯
|
||||
if shouldFullClose {
|
||||
// 應該轉為全平
|
||||
if tt.side == "LONG" {
|
||||
mockTrader.CloseLong(tt.symbol, tt.totalQuantity)
|
||||
} else {
|
||||
mockTrader.CloseShort(tt.symbol, tt.totalQuantity)
|
||||
}
|
||||
} else {
|
||||
// 正常部分平倉
|
||||
if tt.side == "LONG" {
|
||||
mockTrader.ClosePartialLong(tt.symbol, closeQuantity)
|
||||
} else {
|
||||
mockTrader.ClosePartialShort(tt.symbol, closeQuantity)
|
||||
}
|
||||
|
||||
// 恢復止盈止損
|
||||
if dec.NewStopLoss > 0 {
|
||||
mockTrader.SetStopLoss(tt.symbol, tt.side, remainingQuantity, dec.NewStopLoss)
|
||||
}
|
||||
if dec.NewTakeProfit > 0 {
|
||||
mockTrader.SetTakeProfit(tt.symbol, tt.side, remainingQuantity, dec.NewTakeProfit)
|
||||
}
|
||||
}
|
||||
|
||||
// 驗證調用
|
||||
if tt.expectFullClose {
|
||||
if !mockTrader.closeLongCalled && !mockTrader.closeShortCalled {
|
||||
t.Error("期望調用全平但沒有調用")
|
||||
}
|
||||
if mockTrader.closePartialCalled {
|
||||
t.Error("不應該調用部分平倉")
|
||||
}
|
||||
} else {
|
||||
if !mockTrader.closePartialCalled {
|
||||
t.Error("期望調用部分平倉但沒有調用")
|
||||
}
|
||||
}
|
||||
|
||||
if mockTrader.stopLossCalled != tt.expectStopLossCall {
|
||||
t.Errorf("止損調用不符: called = %v, 期望 = %v",
|
||||
mockTrader.stopLossCalled, tt.expectStopLossCall)
|
||||
}
|
||||
|
||||
if mockTrader.takeProfitCalled != tt.expectTakeProfitCall {
|
||||
t.Errorf("止盈調用不符: called = %v, 期望 = %v",
|
||||
mockTrader.takeProfitCalled, tt.expectTakeProfitCall)
|
||||
}
|
||||
|
||||
_ = actionRecord // 避免未使用警告
|
||||
})
|
||||
}
|
||||
}
|
||||
318
trader/position_sync.go
Normal file
318
trader/position_sync.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package trader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"nofx/logger"
|
||||
"nofx/store"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PositionSyncManager 仓位状态同步管理器
|
||||
// 负责定期同步交易所仓位,检测手动平仓等变化
|
||||
type PositionSyncManager struct {
|
||||
store *store.Store
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
traderCache map[string]Trader // trader_id -> Trader 实例缓存
|
||||
configCache map[string]*store.TraderFullConfig // trader_id -> 配置缓存
|
||||
cacheMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPositionSyncManager 创建仓位同步管理器
|
||||
func NewPositionSyncManager(st *store.Store, interval time.Duration) *PositionSyncManager {
|
||||
if interval == 0 {
|
||||
interval = 10 * time.Second
|
||||
}
|
||||
return &PositionSyncManager{
|
||||
store: st,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
traderCache: make(map[string]Trader),
|
||||
configCache: make(map[string]*store.TraderFullConfig),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动仓位同步服务
|
||||
func (m *PositionSyncManager) Start() {
|
||||
m.wg.Add(1)
|
||||
go m.run()
|
||||
logger.Info("📊 仓位同步管理器已启动")
|
||||
}
|
||||
|
||||
// Stop 停止仓位同步服务
|
||||
func (m *PositionSyncManager) Stop() {
|
||||
close(m.stopCh)
|
||||
m.wg.Wait()
|
||||
|
||||
// 清理缓存
|
||||
m.cacheMutex.Lock()
|
||||
m.traderCache = make(map[string]Trader)
|
||||
m.configCache = make(map[string]*store.TraderFullConfig)
|
||||
m.cacheMutex.Unlock()
|
||||
|
||||
logger.Info("📊 仓位同步管理器已停止")
|
||||
}
|
||||
|
||||
// run 主循环
|
||||
func (m *PositionSyncManager) run() {
|
||||
defer m.wg.Done()
|
||||
|
||||
// 启动时立即执行一次
|
||||
m.syncPositions()
|
||||
|
||||
ticker := time.NewTicker(m.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.syncPositions()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// syncPositions 同步所有仓位状态
|
||||
func (m *PositionSyncManager) syncPositions() {
|
||||
// 获取所有 OPEN 状态的仓位
|
||||
localPositions, err := m.store.Position().GetAllOpenPositions()
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ 获取本地仓位失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(localPositions) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 按 trader_id 分组
|
||||
positionsByTrader := make(map[string][]*store.TraderPosition)
|
||||
for _, pos := range localPositions {
|
||||
positionsByTrader[pos.TraderID] = append(positionsByTrader[pos.TraderID], pos)
|
||||
}
|
||||
|
||||
// 逐个 trader 处理
|
||||
for traderID, traderPositions := range positionsByTrader {
|
||||
m.syncTraderPositions(traderID, traderPositions)
|
||||
}
|
||||
}
|
||||
|
||||
// syncTraderPositions 同步单个 trader 的仓位
|
||||
func (m *PositionSyncManager) syncTraderPositions(traderID string, localPositions []*store.TraderPosition) {
|
||||
// 获取或创建 trader 实例
|
||||
trader, err := m.getOrCreateTrader(traderID)
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ 获取 trader 实例失败 (ID: %s): %v", traderID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取交易所当前仓位
|
||||
exchangePositions, err := trader.GetPositions()
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ 获取交易所仓位失败 (ID: %s): %v", traderID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建交易所仓位 map: symbol_side -> position
|
||||
exchangeMap := make(map[string]map[string]interface{})
|
||||
for _, pos := range exchangePositions {
|
||||
symbol, _ := pos["symbol"].(string)
|
||||
side, _ := pos["positionSide"].(string)
|
||||
if symbol == "" || side == "" {
|
||||
continue
|
||||
}
|
||||
key := fmt.Sprintf("%s_%s", symbol, side)
|
||||
exchangeMap[key] = pos
|
||||
}
|
||||
|
||||
// 对比本地和交易所仓位
|
||||
for _, localPos := range localPositions {
|
||||
key := fmt.Sprintf("%s_%s", localPos.Symbol, localPos.Side)
|
||||
exchangePos, exists := exchangeMap[key]
|
||||
|
||||
if !exists {
|
||||
// 交易所没有这个仓位了 → 已被平仓
|
||||
m.closeLocalPosition(localPos, trader, "manual")
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查数量是否为0或很小
|
||||
qty := getFloatFromMap(exchangePos, "positionAmt")
|
||||
if qty < 0 {
|
||||
qty = -qty // 空仓数量是负的
|
||||
}
|
||||
|
||||
if qty < 0.0000001 {
|
||||
// 数量为0,仓位已平
|
||||
m.closeLocalPosition(localPos, trader, "manual")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeLocalPosition 标记本地仓位为已平仓
|
||||
func (m *PositionSyncManager) closeLocalPosition(pos *store.TraderPosition, trader Trader, reason string) {
|
||||
// 尝试获取最后成交价作为平仓价
|
||||
exitPrice := pos.EntryPrice // 默认用开仓价
|
||||
|
||||
// 尝试从交易所获取最新价格
|
||||
if price, err := trader.GetMarketPrice(pos.Symbol); err == nil && price > 0 {
|
||||
exitPrice = price
|
||||
}
|
||||
|
||||
// 计算盈亏
|
||||
var realizedPnL float64
|
||||
if pos.Side == "LONG" {
|
||||
realizedPnL = (exitPrice - pos.EntryPrice) * pos.Quantity
|
||||
} else {
|
||||
realizedPnL = (pos.EntryPrice - exitPrice) * pos.Quantity
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
err := m.store.Position().ClosePosition(
|
||||
pos.ID,
|
||||
exitPrice,
|
||||
"", // 手动平仓没有订单ID
|
||||
realizedPnL,
|
||||
0, // 手动平仓无法获取手续费
|
||||
reason,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Infof("⚠️ 更新仓位状态失败: %v", err)
|
||||
} else {
|
||||
logger.Infof("📊 仓位已平仓 [%s] %s %s @ %.4f → %.4f, PnL: %.2f (%s)",
|
||||
pos.TraderID[:8], pos.Symbol, pos.Side, pos.EntryPrice, exitPrice, realizedPnL, reason)
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateTrader 获取或创建 trader 实例
|
||||
func (m *PositionSyncManager) getOrCreateTrader(traderID string) (Trader, error) {
|
||||
m.cacheMutex.RLock()
|
||||
trader, exists := m.traderCache[traderID]
|
||||
m.cacheMutex.RUnlock()
|
||||
|
||||
if exists && trader != nil {
|
||||
return trader, nil
|
||||
}
|
||||
|
||||
// 需要创建新的 trader 实例
|
||||
config, err := m.getTraderConfig(traderID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 trader 配置失败: %w", err)
|
||||
}
|
||||
|
||||
trader, err = m.createTrader(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 trader 实例失败: %w", err)
|
||||
}
|
||||
|
||||
m.cacheMutex.Lock()
|
||||
m.traderCache[traderID] = trader
|
||||
m.cacheMutex.Unlock()
|
||||
|
||||
return trader, nil
|
||||
}
|
||||
|
||||
// getTraderConfig 获取 trader 配置
|
||||
func (m *PositionSyncManager) getTraderConfig(traderID string) (*store.TraderFullConfig, error) {
|
||||
m.cacheMutex.RLock()
|
||||
config, exists := m.configCache[traderID]
|
||||
m.cacheMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// 从数据库获取
|
||||
traders, err := m.store.Trader().ListAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取 trader 列表失败: %w", err)
|
||||
}
|
||||
|
||||
var userID string
|
||||
for _, t := range traders {
|
||||
if t.ID == traderID {
|
||||
userID = t.UserID
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("找不到 trader: %s", traderID)
|
||||
}
|
||||
|
||||
config, err = m.store.Trader().GetFullConfig(userID, traderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.cacheMutex.Lock()
|
||||
m.configCache[traderID] = config
|
||||
m.cacheMutex.Unlock()
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// createTrader 根据配置创建 trader 实例
|
||||
func (m *PositionSyncManager) createTrader(config *store.TraderFullConfig) (Trader, error) {
|
||||
exchange := config.Exchange
|
||||
|
||||
switch exchange.Type {
|
||||
case "binance":
|
||||
return NewFuturesTrader(exchange.APIKey, exchange.SecretKey, config.Trader.UserID), nil
|
||||
|
||||
case "bybit":
|
||||
return NewBybitTrader(exchange.APIKey, exchange.SecretKey), nil
|
||||
|
||||
case "hyperliquid":
|
||||
return NewHyperliquidTrader(exchange.SecretKey, exchange.HyperliquidWalletAddr, exchange.Testnet)
|
||||
|
||||
case "aster":
|
||||
return NewAsterTrader(exchange.AsterUser, exchange.AsterSigner, exchange.AsterPrivateKey)
|
||||
|
||||
case "lighter":
|
||||
if exchange.LighterAPIKeyPrivateKey != "" {
|
||||
return NewLighterTraderV2(
|
||||
exchange.LighterPrivateKey,
|
||||
exchange.LighterWalletAddr,
|
||||
exchange.LighterAPIKeyPrivateKey,
|
||||
exchange.Testnet,
|
||||
)
|
||||
}
|
||||
return NewLighterTrader(exchange.LighterPrivateKey, exchange.LighterWalletAddr, exchange.Testnet)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的交易所类型: %s", exchange.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache 使缓存失效
|
||||
func (m *PositionSyncManager) InvalidateCache(traderID string) {
|
||||
m.cacheMutex.Lock()
|
||||
defer m.cacheMutex.Unlock()
|
||||
|
||||
delete(m.traderCache, traderID)
|
||||
delete(m.configCache, traderID)
|
||||
}
|
||||
|
||||
// getFloatFromMap 从 map 中获取 float64 值
|
||||
func getFloatFromMap(m map[string]interface{}, key string) float64 {
|
||||
if v, ok := m[key]; ok {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val
|
||||
case int64:
|
||||
return float64(val)
|
||||
case int:
|
||||
return float64(val)
|
||||
case string:
|
||||
var f float64
|
||||
fmt.Sscanf(val, "%f", &f)
|
||||
return f
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
35
web/package-lock.json
generated
35
web/package-lock.json
generated
@@ -121,7 +121,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@babel/core/-/core-7.28.5.tgz",
|
||||
"integrity": "sha512-e7jT4DxYvIDLk1ZHmU/m/mB19rex9sv0c2ftBtjSBv+kVM/902eh0fINUzD7UwLLNR+jU585GxUJ8/EBfAM5fw==",
|
||||
"dev": true,
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/code-frame": "^7.27.1",
|
||||
"@babel/generator": "^7.28.5",
|
||||
@@ -453,7 +452,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
},
|
||||
@@ -477,7 +475,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
@@ -2037,7 +2034,8 @@
|
||||
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.4.tgz",
|
||||
"integrity": "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/@types/babel__core": {
|
||||
"version": "7.20.5",
|
||||
@@ -2158,7 +2156,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.26.tgz",
|
||||
"integrity": "sha512-RFA/bURkcKzx/X9oumPG9Vp3D3JUgus/d0b67KB0t5S/raciymilkOa66olh78MUI92QLbEJevO7rvqU/kjwKA==",
|
||||
"devOptional": true,
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/prop-types": "*",
|
||||
"csstype": "^3.0.2"
|
||||
@@ -2169,7 +2166,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.7.tgz",
|
||||
"integrity": "sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==",
|
||||
"devOptional": true,
|
||||
"peer": true,
|
||||
"peerDependencies": {
|
||||
"@types/react": "^18.0.0"
|
||||
}
|
||||
@@ -2210,7 +2206,6 @@
|
||||
"integrity": "sha512-6m1I5RmHBGTnUGS113G04DMu3CpSdxCAU/UvtjNWL4Nuf3MW9tQhiJqRlHzChIkhy6kZSAQmc+I1bcGjE3yNKg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@typescript-eslint/scope-manager": "8.46.3",
|
||||
"@typescript-eslint/types": "8.46.3",
|
||||
@@ -2535,7 +2530,6 @@
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -2969,7 +2963,6 @@
|
||||
"url": "https://github.com/sponsors/ai"
|
||||
}
|
||||
],
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"baseline-browser-mapping": "^2.8.19",
|
||||
"caniuse-lite": "^1.0.30001751",
|
||||
@@ -3697,7 +3690,8 @@
|
||||
"resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.5.16.tgz",
|
||||
"integrity": "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/dom-helpers": {
|
||||
"version": "5.2.1",
|
||||
@@ -4015,7 +4009,6 @@
|
||||
"integrity": "sha512-BhHmn2yNOFA9H9JmmIVKJmd288g9hrVRDkdoIgRCRuSySRUHH7r/DI6aAXW9T1WwUuY3DFgrcaqB+deURBLR5g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.8.0",
|
||||
"@eslint-community/regexpp": "^4.12.1",
|
||||
@@ -4076,7 +4069,6 @@
|
||||
"integrity": "sha512-82GZUjRS0p/jganf6q1rEO25VSoHH0hKPCTrgillPjdI/3bgBhAE1QzHrHTizjpRvy6pGAvKjDJtk2pF9NDq8w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"eslint-config-prettier": "bin/cli.js"
|
||||
},
|
||||
@@ -5590,7 +5582,6 @@
|
||||
"resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz",
|
||||
"integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==",
|
||||
"dev": true,
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"jiti": "bin/jiti.js"
|
||||
}
|
||||
@@ -5619,7 +5610,6 @@
|
||||
"integrity": "sha512-8i7LzZj7BF8uplX+ZyOlIz86V6TAsSs+np6m1kpW9u0JWi4z/1t+FzcK1aek+ybTnAC4KhBL4uXCNT0wcUIeCw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"cssstyle": "^4.1.0",
|
||||
"data-urls": "^5.0.0",
|
||||
@@ -5994,6 +5984,7 @@
|
||||
"integrity": "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"lz-string": "bin/bin.js"
|
||||
}
|
||||
@@ -6581,7 +6572,6 @@
|
||||
"url": "https://github.com/sponsors/ai"
|
||||
}
|
||||
],
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"nanoid": "^3.3.11",
|
||||
"picocolors": "^1.1.1",
|
||||
@@ -6735,7 +6725,6 @@
|
||||
"integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -6765,6 +6754,7 @@
|
||||
"integrity": "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"ansi-regex": "^5.0.1",
|
||||
"ansi-styles": "^5.0.0",
|
||||
@@ -6780,6 +6770,7 @@
|
||||
"integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
@@ -6790,6 +6781,7 @@
|
||||
"integrity": "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=10"
|
||||
},
|
||||
@@ -6802,7 +6794,8 @@
|
||||
"resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz",
|
||||
"integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/prop-types": {
|
||||
"version": "15.8.1",
|
||||
@@ -6859,7 +6852,6 @@
|
||||
"version": "18.3.1",
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz",
|
||||
"integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"loose-envify": "^1.1.0"
|
||||
},
|
||||
@@ -6871,7 +6863,6 @@
|
||||
"version": "18.3.1",
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz",
|
||||
"integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"loose-envify": "^1.1.0",
|
||||
"scheduler": "^0.23.2"
|
||||
@@ -8063,7 +8054,6 @@
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -8280,7 +8270,6 @@
|
||||
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.9.3.tgz",
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -8431,7 +8420,6 @@
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
@@ -9036,7 +9024,6 @@
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -9573,7 +9560,6 @@
|
||||
"integrity": "sha512-o5a9xKjbtuhY6Bi5S3+HvbRERmouabWbyUcpXXUA1u+GNUKoROi9byOJ8M0nHbHYHkYICiMlqxkg1KkYmm25Sw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.21.3",
|
||||
"postcss": "^8.4.43",
|
||||
@@ -9987,7 +9973,6 @@
|
||||
"integrity": "sha512-JInaHOamG8pt5+Ey8kGmdcAcg3OL9reK8ltczgHTAwNhMys/6ThXHityHxVV2p3fkw/c+MAvBHFVYHFZDmjMCQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/colinhacks"
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
import useSWR from 'swr'
|
||||
import { api } from './lib/api'
|
||||
import { EquityChart } from './components/EquityChart'
|
||||
import { ChartTabs } from './components/ChartTabs'
|
||||
import { AITradersPage } from './components/AITradersPage'
|
||||
import { LoginPage } from './components/LoginPage'
|
||||
import { RegisterPage } from './components/RegisterPage'
|
||||
@@ -10,7 +10,6 @@ import { CompetitionPage } from './components/CompetitionPage'
|
||||
import { LandingPage } from './pages/LandingPage'
|
||||
import { FAQPage } from './pages/FAQPage'
|
||||
import HeaderBar from './components/HeaderBar'
|
||||
import AILearning from './components/AILearning'
|
||||
import { LanguageProvider, useLanguage } from './contexts/LanguageContext'
|
||||
import { AuthProvider, useAuth } from './contexts/AuthContext'
|
||||
import { ConfirmDialogProvider } from './components/ConfirmDialog'
|
||||
@@ -780,9 +779,9 @@ function TraderDetailsPage({
|
||||
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6 mb-6">
|
||||
{/* 左侧:图表 + 持仓 */}
|
||||
<div className="space-y-6">
|
||||
{/* Equity Chart */}
|
||||
{/* Chart Tabs (Equity / K-line) */}
|
||||
<div className="animate-slide-in" style={{ animationDelay: '0.1s' }}>
|
||||
<EquityChart traderId={selectedTrader.trader_id} />
|
||||
<ChartTabs traderId={selectedTrader.trader_id} />
|
||||
</div>
|
||||
|
||||
{/* Current Positions */}
|
||||
@@ -1002,10 +1001,6 @@ function TraderDetailsPage({
|
||||
{/* 右侧结束 */}
|
||||
</div>
|
||||
|
||||
{/* AI Learning & Performance Analysis */}
|
||||
<div className="mb-6 animate-slide-in" style={{ animationDelay: '0.3s' }}>
|
||||
<AILearning traderId={selectedTrader.trader_id} />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
89
web/src/components/ChartTabs.tsx
Normal file
89
web/src/components/ChartTabs.tsx
Normal file
@@ -0,0 +1,89 @@
|
||||
import { useState } from 'react'
|
||||
import { EquityChart } from './EquityChart'
|
||||
import { TradingViewChart } from './TradingViewChart'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { BarChart3, CandlestickChart } from 'lucide-react'
|
||||
|
||||
interface ChartTabsProps {
|
||||
traderId: string
|
||||
}
|
||||
|
||||
type ChartTab = 'equity' | 'kline'
|
||||
|
||||
export function ChartTabs({ traderId }: ChartTabsProps) {
|
||||
const { language } = useLanguage()
|
||||
const [activeTab, setActiveTab] = useState<ChartTab>('equity')
|
||||
|
||||
console.log('[ChartTabs] rendering, activeTab:', activeTab)
|
||||
|
||||
return (
|
||||
<div className="binance-card">
|
||||
{/* Tab Headers - 这是Tab切换按钮区域 */}
|
||||
<div
|
||||
className="flex items-center gap-2 p-3"
|
||||
style={{
|
||||
borderBottom: '1px solid #2B3139',
|
||||
background: '#0B0E11',
|
||||
}}
|
||||
>
|
||||
<button
|
||||
onClick={() => {
|
||||
console.log('[ChartTabs] switching to equity')
|
||||
setActiveTab('equity')
|
||||
}}
|
||||
className="flex items-center gap-2 px-4 py-2 rounded-lg text-sm font-semibold"
|
||||
style={
|
||||
activeTab === 'equity'
|
||||
? {
|
||||
background: 'rgba(240, 185, 11, 0.15)',
|
||||
color: '#F0B90B',
|
||||
border: '1px solid rgba(240, 185, 11, 0.3)',
|
||||
}
|
||||
: {
|
||||
background: 'transparent',
|
||||
color: '#848E9C',
|
||||
border: '1px solid transparent',
|
||||
}
|
||||
}
|
||||
>
|
||||
<BarChart3 className="w-4 h-4" />
|
||||
{t('accountEquityCurve', language)}
|
||||
</button>
|
||||
|
||||
<button
|
||||
onClick={() => {
|
||||
console.log('[ChartTabs] switching to kline')
|
||||
setActiveTab('kline')
|
||||
}}
|
||||
className="flex items-center gap-2 px-4 py-2 rounded-lg text-sm font-semibold"
|
||||
style={
|
||||
activeTab === 'kline'
|
||||
? {
|
||||
background: 'rgba(240, 185, 11, 0.15)',
|
||||
color: '#F0B90B',
|
||||
border: '1px solid rgba(240, 185, 11, 0.3)',
|
||||
}
|
||||
: {
|
||||
background: 'transparent',
|
||||
color: '#848E9C',
|
||||
border: '1px solid transparent',
|
||||
}
|
||||
}
|
||||
>
|
||||
<CandlestickChart className="w-4 h-4" />
|
||||
{t('marketChart', language)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Tab Content */}
|
||||
<div>
|
||||
{activeTab === 'equity' ? (
|
||||
<EquityChart traderId={traderId} embedded />
|
||||
) : (
|
||||
<TradingViewChart height={400} embedded />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -126,6 +126,11 @@ export function DecisionCard({ decision, language }: DecisionCardProps) {
|
||||
background: 'rgba(14, 203, 129, 0.1)',
|
||||
color: '#0ECB81',
|
||||
}
|
||||
: action.action === 'wait' || action.action === 'hold'
|
||||
? {
|
||||
background: 'rgba(132, 142, 156, 0.1)',
|
||||
color: '#848E9C',
|
||||
}
|
||||
: {
|
||||
background: 'rgba(248, 113, 113, 0.1)',
|
||||
color: '#F87171',
|
||||
|
||||
@@ -33,9 +33,10 @@ interface EquityPoint {
|
||||
|
||||
interface EquityChartProps {
|
||||
traderId?: string
|
||||
embedded?: boolean // 嵌入模式(不显示外层卡片)
|
||||
}
|
||||
|
||||
export function EquityChart({ traderId }: EquityChartProps) {
|
||||
export function EquityChart({ traderId, embedded = false }: EquityChartProps) {
|
||||
const { language } = useLanguage()
|
||||
const { user, token } = useAuth()
|
||||
const [displayMode, setDisplayMode] = useState<'dollar' | 'percent'>('dollar')
|
||||
@@ -62,7 +63,7 @@ export function EquityChart({ traderId }: EquityChartProps) {
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="binance-card p-6">
|
||||
<div className={embedded ? 'p-6' : 'binance-card p-6'}>
|
||||
<div
|
||||
className="flex items-center gap-3 p-4 rounded"
|
||||
style={{
|
||||
@@ -89,10 +90,12 @@ export function EquityChart({ traderId }: EquityChartProps) {
|
||||
|
||||
if (!validHistory || validHistory.length === 0) {
|
||||
return (
|
||||
<div className="binance-card p-6">
|
||||
<h3 className="text-lg font-semibold mb-6" style={{ color: '#EAECEF' }}>
|
||||
{t('accountEquityCurve', language)}
|
||||
</h3>
|
||||
<div className={embedded ? 'p-6' : 'binance-card p-6'}>
|
||||
{!embedded && (
|
||||
<h3 className="text-lg font-semibold mb-6" style={{ color: '#EAECEF' }}>
|
||||
{t('accountEquityCurve', language)}
|
||||
</h3>
|
||||
)}
|
||||
<div className="text-center py-16" style={{ color: '#848E9C' }}>
|
||||
<div className="mb-4 flex justify-center opacity-50">
|
||||
<BarChart3 className="w-16 h-16" />
|
||||
@@ -193,16 +196,18 @@ export function EquityChart({ traderId }: EquityChartProps) {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="binance-card p-3 sm:p-5 animate-fade-in">
|
||||
<div className={embedded ? 'p-3 sm:p-5' : 'binance-card p-3 sm:p-5 animate-fade-in'}>
|
||||
{/* Header */}
|
||||
<div className="flex flex-col gap-3 sm:flex-row sm:items-center sm:justify-between mb-4">
|
||||
<div className="flex-1">
|
||||
<h3
|
||||
className="text-base sm:text-lg font-bold mb-2"
|
||||
style={{ color: '#EAECEF' }}
|
||||
>
|
||||
{t('accountEquityCurve', language)}
|
||||
</h3>
|
||||
{!embedded && (
|
||||
<h3
|
||||
className="text-base sm:text-lg font-bold mb-2"
|
||||
style={{ color: '#EAECEF' }}
|
||||
>
|
||||
{t('accountEquityCurve', language)}
|
||||
</h3>
|
||||
)}
|
||||
<div className="flex flex-col sm:flex-row sm:items-baseline gap-2 sm:gap-4">
|
||||
<span
|
||||
className="text-2xl sm:text-3xl font-bold mono"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
@@ -11,7 +10,6 @@ import { useSystemConfig } from '../hooks/useSystemConfig'
|
||||
export function LoginPage() {
|
||||
const { language } = useLanguage()
|
||||
const { login, loginAdmin, verifyOTP } = useAuth()
|
||||
const navigate = useNavigate()
|
||||
const [step, setStep] = useState<'login' | 'otp'>('login')
|
||||
const [email, setEmail] = useState('')
|
||||
const [password, setPassword] = useState('')
|
||||
@@ -236,7 +234,9 @@ export function LoginPage() {
|
||||
<div className="text-right mt-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => navigate('/reset-password')}
|
||||
onClick={() => {
|
||||
window.location.href = '/reset-password'
|
||||
}}
|
||||
className="text-xs hover:underline"
|
||||
style={{ color: '#F0B90B' }}
|
||||
>
|
||||
@@ -348,7 +348,9 @@ export function LoginPage() {
|
||||
<p className="text-sm" style={{ color: 'var(--text-secondary)' }}>
|
||||
还没有账户?{' '}
|
||||
<button
|
||||
onClick={() => navigate('/register')}
|
||||
onClick={() => {
|
||||
window.location.href = '/register'
|
||||
}}
|
||||
className="font-semibold hover:underline transition-colors"
|
||||
style={{ color: 'var(--brand-yellow)' }}
|
||||
>
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import { useNavigate } from 'react-router-dom'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
@@ -14,7 +13,6 @@ import { RegistrationDisabled } from './RegistrationDisabled'
|
||||
export function RegisterPage() {
|
||||
const { language } = useLanguage()
|
||||
const { register, completeRegistration } = useAuth()
|
||||
const navigate = useNavigate()
|
||||
const [step, setStep] = useState<'register' | 'setup-otp' | 'verify-otp'>(
|
||||
'register'
|
||||
)
|
||||
@@ -530,7 +528,9 @@ export function RegisterPage() {
|
||||
<p className="text-sm" style={{ color: 'var(--text-secondary)' }}>
|
||||
已有账户?{' '}
|
||||
<button
|
||||
onClick={() => navigate('/login')}
|
||||
onClick={() => {
|
||||
window.location.href = '/login'
|
||||
}}
|
||||
className="font-semibold hover:underline transition-colors"
|
||||
style={{ color: 'var(--brand-yellow)' }}
|
||||
>
|
||||
|
||||
377
web/src/components/TradingViewChart.tsx
Normal file
377
web/src/components/TradingViewChart.tsx
Normal file
@@ -0,0 +1,377 @@
|
||||
import { useEffect, useRef, useState, memo } from 'react'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { t } from '../i18n/translations'
|
||||
import { ChevronDown, TrendingUp, X } from 'lucide-react'
|
||||
|
||||
// 支持的交易所列表 (合约格式)
|
||||
const EXCHANGES = [
|
||||
{ id: 'BINANCE', name: 'Binance', prefix: 'BINANCE:', suffix: '.P' },
|
||||
{ id: 'BYBIT', name: 'Bybit', prefix: 'BYBIT:', suffix: '.P' },
|
||||
{ id: 'OKX', name: 'OKX', prefix: 'OKX:', suffix: '.P' },
|
||||
{ id: 'BITGET', name: 'Bitget', prefix: 'BITGET:', suffix: '.P' },
|
||||
{ id: 'MEXC', name: 'MEXC', prefix: 'MEXC:', suffix: '.P' },
|
||||
{ id: 'GATEIO', name: 'Gate.io', prefix: 'GATEIO:', suffix: '.P' },
|
||||
] as const
|
||||
|
||||
// 热门交易对
|
||||
const POPULAR_SYMBOLS = [
|
||||
'BTCUSDT',
|
||||
'ETHUSDT',
|
||||
'SOLUSDT',
|
||||
'BNBUSDT',
|
||||
'XRPUSDT',
|
||||
'DOGEUSDT',
|
||||
'ADAUSDT',
|
||||
'AVAXUSDT',
|
||||
'DOTUSDT',
|
||||
'LINKUSDT',
|
||||
'MATICUSDT',
|
||||
'LTCUSDT',
|
||||
]
|
||||
|
||||
// 时间周期选项
|
||||
const INTERVALS = [
|
||||
{ id: '1', label: '1m' },
|
||||
{ id: '5', label: '5m' },
|
||||
{ id: '15', label: '15m' },
|
||||
{ id: '30', label: '30m' },
|
||||
{ id: '60', label: '1H' },
|
||||
{ id: '240', label: '4H' },
|
||||
{ id: 'D', label: '1D' },
|
||||
{ id: 'W', label: '1W' },
|
||||
]
|
||||
|
||||
interface TradingViewChartProps {
|
||||
defaultSymbol?: string
|
||||
defaultExchange?: string
|
||||
height?: number
|
||||
showToolbar?: boolean
|
||||
embedded?: boolean // 嵌入模式(不显示外层卡片)
|
||||
}
|
||||
|
||||
function TradingViewChartComponent({
|
||||
defaultSymbol = 'BTCUSDT',
|
||||
defaultExchange = 'BINANCE',
|
||||
height = 400,
|
||||
showToolbar = true,
|
||||
embedded = false,
|
||||
}: TradingViewChartProps) {
|
||||
const { language } = useLanguage()
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
const [exchange, setExchange] = useState(defaultExchange)
|
||||
const [symbol, setSymbol] = useState(defaultSymbol)
|
||||
const [timeInterval, setTimeInterval] = useState('60')
|
||||
const [customSymbol, setCustomSymbol] = useState('')
|
||||
const [showExchangeDropdown, setShowExchangeDropdown] = useState(false)
|
||||
const [showSymbolDropdown, setShowSymbolDropdown] = useState(false)
|
||||
const [isFullscreen, setIsFullscreen] = useState(false)
|
||||
|
||||
// 获取完整的交易对符号 (合约格式: BINANCE:BTCUSDT.P)
|
||||
const getFullSymbol = () => {
|
||||
const exchangeInfo = EXCHANGES.find((e) => e.id === exchange)
|
||||
const prefix = exchangeInfo?.prefix || 'BINANCE:'
|
||||
const suffix = exchangeInfo?.suffix || '.P'
|
||||
return `${prefix}${symbol}${suffix}`
|
||||
}
|
||||
|
||||
// 加载 TradingView Widget
|
||||
useEffect(() => {
|
||||
if (!containerRef.current) return
|
||||
|
||||
// 清空容器
|
||||
containerRef.current.innerHTML = ''
|
||||
|
||||
// 创建 widget 容器
|
||||
const widgetContainer = document.createElement('div')
|
||||
widgetContainer.className = 'tradingview-widget-container'
|
||||
widgetContainer.style.height = '100%'
|
||||
widgetContainer.style.width = '100%'
|
||||
|
||||
const widgetDiv = document.createElement('div')
|
||||
widgetDiv.className = 'tradingview-widget-container__widget'
|
||||
widgetDiv.style.height = '100%'
|
||||
widgetDiv.style.width = '100%'
|
||||
|
||||
widgetContainer.appendChild(widgetDiv)
|
||||
containerRef.current.appendChild(widgetContainer)
|
||||
|
||||
// 加载 TradingView 脚本
|
||||
const script = document.createElement('script')
|
||||
script.src =
|
||||
'https://s3.tradingview.com/external-embedding/embed-widget-advanced-chart.js'
|
||||
script.type = 'text/javascript'
|
||||
script.async = true
|
||||
script.innerHTML = JSON.stringify({
|
||||
autosize: true,
|
||||
symbol: getFullSymbol(),
|
||||
interval: timeInterval,
|
||||
timezone: 'Etc/UTC',
|
||||
theme: 'dark',
|
||||
style: '1',
|
||||
locale: language === 'zh' ? 'zh_CN' : 'en',
|
||||
enable_publishing: false,
|
||||
backgroundColor: 'rgba(11, 14, 17, 1)',
|
||||
gridColor: 'rgba(43, 49, 57, 0.5)',
|
||||
hide_top_toolbar: !showToolbar,
|
||||
hide_legend: false,
|
||||
save_image: false,
|
||||
calendar: false,
|
||||
hide_volume: false,
|
||||
support_host: 'https://www.tradingview.com',
|
||||
})
|
||||
|
||||
widgetContainer.appendChild(script)
|
||||
|
||||
return () => {
|
||||
if (containerRef.current) {
|
||||
containerRef.current.innerHTML = ''
|
||||
}
|
||||
}
|
||||
}, [exchange, symbol, timeInterval, language, showToolbar])
|
||||
|
||||
// 处理自定义交易对输入
|
||||
const handleCustomSymbolSubmit = () => {
|
||||
if (customSymbol.trim()) {
|
||||
let sym = customSymbol.trim().toUpperCase()
|
||||
// 如果没有 USDT 后缀,自动加上
|
||||
if (!sym.endsWith('USDT')) {
|
||||
sym = sym + 'USDT'
|
||||
}
|
||||
setSymbol(sym)
|
||||
setCustomSymbol('')
|
||||
setShowSymbolDropdown(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${embedded ? '' : 'binance-card'} overflow-hidden ${embedded ? '' : 'animate-fade-in'} ${
|
||||
isFullscreen
|
||||
? 'fixed inset-0 z-50 rounded-none'
|
||||
: ''
|
||||
}`}
|
||||
>
|
||||
{/* Header */}
|
||||
<div
|
||||
className="flex flex-wrap items-center gap-2 p-3 sm:p-4"
|
||||
style={{ borderBottom: embedded ? 'none' : '1px solid #2B3139' }}
|
||||
>
|
||||
{!embedded && (
|
||||
<div className="flex items-center gap-2">
|
||||
<TrendingUp className="w-5 h-5" style={{ color: '#F0B90B' }} />
|
||||
<h3
|
||||
className="text-base sm:text-lg font-bold"
|
||||
style={{ color: '#EAECEF' }}
|
||||
>
|
||||
{t('marketChart', language)}
|
||||
</h3>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Controls */}
|
||||
<div className={`flex flex-wrap items-center gap-2 ${embedded ? '' : 'ml-auto'}`}>
|
||||
{/* Exchange Selector */}
|
||||
<div className="relative">
|
||||
<button
|
||||
onClick={() => {
|
||||
setShowExchangeDropdown(!showExchangeDropdown)
|
||||
setShowSymbolDropdown(false)
|
||||
}}
|
||||
className="flex items-center gap-1 px-3 py-1.5 rounded text-sm font-medium transition-all"
|
||||
style={{
|
||||
background: '#1E2329',
|
||||
border: '1px solid #2B3139',
|
||||
color: '#EAECEF',
|
||||
}}
|
||||
>
|
||||
{EXCHANGES.find((e) => e.id === exchange)?.name || exchange}
|
||||
<ChevronDown className="w-4 h-4" style={{ color: '#848E9C' }} />
|
||||
</button>
|
||||
|
||||
{showExchangeDropdown && (
|
||||
<div
|
||||
className="absolute top-full left-0 mt-1 py-1 rounded-lg shadow-xl z-20 min-w-[120px]"
|
||||
style={{
|
||||
background: '#1E2329',
|
||||
border: '1px solid #2B3139',
|
||||
}}
|
||||
>
|
||||
{EXCHANGES.map((ex) => (
|
||||
<button
|
||||
key={ex.id}
|
||||
onClick={() => {
|
||||
setExchange(ex.id)
|
||||
setShowExchangeDropdown(false)
|
||||
}}
|
||||
className="w-full px-4 py-2 text-left text-sm transition-all hover:bg-opacity-50"
|
||||
style={{
|
||||
color: exchange === ex.id ? '#F0B90B' : '#EAECEF',
|
||||
background:
|
||||
exchange === ex.id
|
||||
? 'rgba(240, 185, 11, 0.1)'
|
||||
: 'transparent',
|
||||
}}
|
||||
>
|
||||
{ex.name}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Symbol Selector */}
|
||||
<div className="relative">
|
||||
<button
|
||||
onClick={() => {
|
||||
setShowSymbolDropdown(!showSymbolDropdown)
|
||||
setShowExchangeDropdown(false)
|
||||
}}
|
||||
className="flex items-center gap-1 px-3 py-1.5 rounded text-sm font-bold transition-all"
|
||||
style={{
|
||||
background: 'rgba(240, 185, 11, 0.1)',
|
||||
border: '1px solid rgba(240, 185, 11, 0.3)',
|
||||
color: '#F0B90B',
|
||||
}}
|
||||
>
|
||||
{symbol}
|
||||
<ChevronDown className="w-4 h-4" />
|
||||
</button>
|
||||
|
||||
{showSymbolDropdown && (
|
||||
<div
|
||||
className="absolute top-full left-0 mt-1 py-2 rounded-lg shadow-xl z-20 w-[280px]"
|
||||
style={{
|
||||
background: '#1E2329',
|
||||
border: '1px solid #2B3139',
|
||||
}}
|
||||
>
|
||||
{/* Custom Input */}
|
||||
<div className="px-3 pb-2" style={{ borderBottom: '1px solid #2B3139' }}>
|
||||
<div className="flex gap-2">
|
||||
<input
|
||||
type="text"
|
||||
value={customSymbol}
|
||||
onChange={(e) => setCustomSymbol(e.target.value.toUpperCase())}
|
||||
onKeyDown={(e) => e.key === 'Enter' && handleCustomSymbolSubmit()}
|
||||
placeholder={t('enterSymbol', language)}
|
||||
className="flex-1 px-3 py-1.5 rounded text-sm"
|
||||
style={{
|
||||
background: '#0B0E11',
|
||||
border: '1px solid #2B3139',
|
||||
color: '#EAECEF',
|
||||
}}
|
||||
/>
|
||||
<button
|
||||
onClick={handleCustomSymbolSubmit}
|
||||
className="px-3 py-1.5 rounded text-sm font-medium"
|
||||
style={{
|
||||
background: '#F0B90B',
|
||||
color: '#0B0E11',
|
||||
}}
|
||||
>
|
||||
OK
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Popular Symbols */}
|
||||
<div className="px-2 pt-2">
|
||||
<div
|
||||
className="text-xs px-2 py-1 mb-1"
|
||||
style={{ color: '#848E9C' }}
|
||||
>
|
||||
{t('popularSymbols', language)}
|
||||
</div>
|
||||
<div className="grid grid-cols-3 gap-1">
|
||||
{POPULAR_SYMBOLS.map((sym) => (
|
||||
<button
|
||||
key={sym}
|
||||
onClick={() => {
|
||||
setSymbol(sym)
|
||||
setShowSymbolDropdown(false)
|
||||
}}
|
||||
className="px-2 py-1.5 rounded text-xs font-medium transition-all"
|
||||
style={{
|
||||
color: symbol === sym ? '#F0B90B' : '#EAECEF',
|
||||
background:
|
||||
symbol === sym
|
||||
? 'rgba(240, 185, 11, 0.1)'
|
||||
: 'rgba(43, 49, 57, 0.3)',
|
||||
}}
|
||||
>
|
||||
{sym.replace('USDT', '')}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Interval Selector */}
|
||||
<div
|
||||
className="flex gap-0.5 p-0.5 rounded"
|
||||
style={{ background: '#0B0E11', border: '1px solid #2B3139' }}
|
||||
>
|
||||
{INTERVALS.map((int) => (
|
||||
<button
|
||||
key={int.id}
|
||||
onClick={() => setTimeInterval(int.id)}
|
||||
className="px-2 py-1 rounded text-xs font-medium transition-all"
|
||||
style={{
|
||||
background: timeInterval === int.id ? '#F0B90B' : 'transparent',
|
||||
color: timeInterval === int.id ? '#0B0E11' : '#848E9C',
|
||||
}}
|
||||
>
|
||||
{int.label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Fullscreen Toggle */}
|
||||
<button
|
||||
onClick={() => setIsFullscreen(!isFullscreen)}
|
||||
className="p-1.5 rounded transition-all"
|
||||
style={{
|
||||
background: isFullscreen ? '#F0B90B' : 'transparent',
|
||||
color: isFullscreen ? '#0B0E11' : '#848E9C',
|
||||
border: '1px solid #2B3139',
|
||||
}}
|
||||
title={isFullscreen ? t('exitFullscreen', language) : t('fullscreen', language)}
|
||||
>
|
||||
{isFullscreen ? (
|
||||
<X className="w-4 h-4" />
|
||||
) : (
|
||||
<svg className="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2">
|
||||
<path d="M8 3H5a2 2 0 00-2 2v3m18 0V5a2 2 0 00-2-2h-3m0 18h3a2 2 0 002-2v-3M3 16v3a2 2 0 002 2h3" />
|
||||
</svg>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Chart Container */}
|
||||
<div
|
||||
ref={containerRef}
|
||||
style={{
|
||||
height: isFullscreen ? 'calc(100vh - 60px)' : height,
|
||||
background: '#0B0E11',
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* Click outside to close dropdowns */}
|
||||
{(showExchangeDropdown || showSymbolDropdown) && (
|
||||
<div
|
||||
className="fixed inset-0 z-10"
|
||||
onClick={() => {
|
||||
setShowExchangeDropdown(false)
|
||||
setShowSymbolDropdown(false)
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// 使用 memo 避免不必要的重渲染
|
||||
export const TradingViewChart = memo(TradingViewChartComponent)
|
||||
@@ -83,6 +83,13 @@ export const translations = {
|
||||
currentGap: 'Current Gap',
|
||||
count: '{count} pts',
|
||||
|
||||
// TradingView Chart
|
||||
marketChart: 'Market Chart',
|
||||
enterSymbol: 'Enter symbol...',
|
||||
popularSymbols: 'Popular Symbols',
|
||||
fullscreen: 'Fullscreen',
|
||||
exitFullscreen: 'Exit Fullscreen',
|
||||
|
||||
// Backtest Page
|
||||
backtestPage: {
|
||||
title: 'Backtest Lab',
|
||||
@@ -264,40 +271,6 @@ export const translations = {
|
||||
pnl: 'P&L',
|
||||
pos: 'Pos',
|
||||
|
||||
// AI Learning
|
||||
aiLearning: 'AI Learning & Reflection',
|
||||
tradesAnalyzed: '{count} trades analyzed · Real-time evolution',
|
||||
latestReflection: 'Latest Reflection',
|
||||
fullCoT: 'Full Chain of Thought',
|
||||
totalTrades: 'Total Trades',
|
||||
winRate: 'Win Rate',
|
||||
avgWin: 'Avg Win',
|
||||
avgLoss: 'Avg Loss',
|
||||
profitFactor: 'Profit Factor',
|
||||
avgWinDivLoss: 'Avg Win ÷ Avg Loss',
|
||||
excellent: '🔥 Excellent - Strong profitability',
|
||||
good: '✓ Good - Stable profits',
|
||||
fair: '⚠️ Fair - Needs optimization',
|
||||
poor: '❌ Poor - Losses exceed gains',
|
||||
bestPerformer: 'Best Performer',
|
||||
worstPerformer: 'Worst Performer',
|
||||
symbolPerformance: 'Symbol Performance',
|
||||
tradeHistory: 'Trade History',
|
||||
completedTrades: 'Recent {count} completed trades',
|
||||
noCompletedTrades: 'No completed trades yet',
|
||||
completedTradesWillAppear: 'Completed trades will appear here',
|
||||
entry: 'Entry',
|
||||
exit: 'Exit',
|
||||
stopLoss: 'Stop Loss',
|
||||
latest: 'Latest',
|
||||
|
||||
// AI Learning Description
|
||||
howAILearns: 'How AI Learns & Evolves',
|
||||
aiLearningPoint1: 'Analyzes last 20 trading cycles before each decision',
|
||||
aiLearningPoint2: 'Identifies best & worst performing symbols',
|
||||
aiLearningPoint3: 'Optimizes position sizing based on win rate',
|
||||
aiLearningPoint4: 'Avoids repeating past mistakes',
|
||||
|
||||
// AI Traders Management
|
||||
manageAITraders: 'Manage your AI trading bots',
|
||||
aiModels: 'AI Models',
|
||||
@@ -499,9 +472,6 @@ export const translations = {
|
||||
|
||||
// Loading & Error
|
||||
loading: 'Loading...',
|
||||
loadingError: '⚠️ Failed to load AI learning data',
|
||||
noCompleteData:
|
||||
'No complete trading data (needs to complete open → close cycle)',
|
||||
|
||||
// AI Traders Page - Additional
|
||||
inUse: 'In Use',
|
||||
@@ -954,7 +924,7 @@ export const translations = {
|
||||
// Data & Privacy
|
||||
faqDataStorage: 'Where is my data stored?',
|
||||
faqDataStorageAnswer:
|
||||
'All data is stored locally on your machine in SQLite databases: config.db (trader configurations), trading.db (trade history), and decision_logs/ (AI decision records).',
|
||||
'All data is stored locally on your machine in SQLite databases: data.db (all configurations and trade history), and decision_logs/ (AI decision records).',
|
||||
|
||||
faqApiKeySecurity: 'Is my API key secure?',
|
||||
faqApiKeySecurityAnswer:
|
||||
@@ -1109,6 +1079,13 @@ export const translations = {
|
||||
currentGap: '当前差距',
|
||||
count: '{count} 个',
|
||||
|
||||
// TradingView Chart
|
||||
marketChart: '行情图表',
|
||||
enterSymbol: '输入币种...',
|
||||
popularSymbols: '热门币种',
|
||||
fullscreen: '全屏',
|
||||
exitFullscreen: '退出全屏',
|
||||
|
||||
// Backtest Page
|
||||
backtestPage: {
|
||||
title: '回测实验室',
|
||||
@@ -1288,40 +1265,6 @@ export const translations = {
|
||||
pnl: '收益',
|
||||
pos: '持仓',
|
||||
|
||||
// AI Learning
|
||||
aiLearning: 'AI学习与反思',
|
||||
tradesAnalyzed: '已分析 {count} 笔交易 · 实时演化',
|
||||
latestReflection: '最新反思',
|
||||
fullCoT: '📋 完整思维链',
|
||||
totalTrades: '总交易数',
|
||||
winRate: '胜率',
|
||||
avgWin: '平均盈利',
|
||||
avgLoss: '平均亏损',
|
||||
profitFactor: '盈亏比',
|
||||
avgWinDivLoss: '平均盈利 ÷ 平均亏损',
|
||||
excellent: '🔥 优秀 - 盈利能力强',
|
||||
good: '✓ 良好 - 稳定盈利',
|
||||
fair: '⚠️ 一般 - 需要优化',
|
||||
poor: '❌ 较差 - 亏损超过盈利',
|
||||
bestPerformer: '最佳表现',
|
||||
worstPerformer: '最差表现',
|
||||
symbolPerformance: '📊 币种表现',
|
||||
tradeHistory: '历史成交',
|
||||
completedTrades: '最近 {count} 笔已完成交易',
|
||||
noCompletedTrades: '暂无完成的交易',
|
||||
completedTradesWillAppear: '已完成的交易将显示在这里',
|
||||
entry: '入场',
|
||||
exit: '出场',
|
||||
stopLoss: '止损',
|
||||
latest: '最新',
|
||||
|
||||
// AI Learning Description
|
||||
howAILearns: '💡 AI如何学习和进化',
|
||||
aiLearningPoint1: '每次决策前分析最近20个交易周期',
|
||||
aiLearningPoint2: '识别表现最好和最差的币种',
|
||||
aiLearningPoint3: '根据胜率优化仓位大小',
|
||||
aiLearningPoint4: '避免重复过去的错误',
|
||||
|
||||
// AI Traders Management
|
||||
manageAITraders: '管理您的AI交易机器人',
|
||||
aiModels: 'AI模型',
|
||||
@@ -1512,8 +1455,6 @@ export const translations = {
|
||||
|
||||
// Loading & Error
|
||||
loading: '加载中...',
|
||||
loadingError: '⚠️ 加载AI学习数据失败',
|
||||
noCompleteData: '暂无完整交易数据(需要完成开仓→平仓的完整周期)',
|
||||
|
||||
// AI Traders Page - Additional
|
||||
inUse: '正在使用',
|
||||
@@ -1927,7 +1868,7 @@ export const translations = {
|
||||
// Data & Privacy
|
||||
faqDataStorage: '我的数据存储在哪里?',
|
||||
faqDataStorageAnswer:
|
||||
'所有数据都本地存储在您的机器上,使用 SQLite 数据库:config.db(交易员配置)、trading.db(交易历史)、decision_logs/(AI 决策记录)。',
|
||||
'所有数据都本地存储在您的机器上,使用 SQLite 数据库:data.db(所有配置和交易历史)、decision_logs/(AI 决策记录)。',
|
||||
|
||||
faqApiKeySecurity: 'API 密钥安全吗?',
|
||||
faqApiKeySecurityAnswer:
|
||||
|
||||
@@ -337,16 +337,6 @@ export const api = {
|
||||
return result.data!
|
||||
},
|
||||
|
||||
// 获取AI学习表现分析(支持trader_id)
|
||||
async getPerformance(traderId?: string): Promise<any> {
|
||||
const url = traderId
|
||||
? `${API_BASE}/performance?trader_id=${traderId}`
|
||||
: `${API_BASE}/performance`
|
||||
const result = await httpClient.get<any>(url)
|
||||
if (!result.success) throw new Error('获取AI学习数据失败')
|
||||
return result.data!
|
||||
},
|
||||
|
||||
// 获取竞赛数据(无需认证)
|
||||
async getCompetition(): Promise<CompetitionData> {
|
||||
const result = await httpClient.get<CompetitionData>(
|
||||
|
||||
@@ -2,8 +2,7 @@ import { useEffect, useState } from 'react'
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom'
|
||||
import useSWR from 'swr'
|
||||
import { api } from '../lib/api'
|
||||
import { EquityChart } from '../components/EquityChart'
|
||||
import AILearning from '../components/AILearning'
|
||||
import { ChartTabs } from '../components/ChartTabs'
|
||||
import { useLanguage } from '../contexts/LanguageContext'
|
||||
import { useAuth } from '../contexts/AuthContext'
|
||||
import { t, type Language } from '../i18n/translations'
|
||||
@@ -419,9 +418,9 @@ export default function TraderDashboard() {
|
||||
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6 mb-6">
|
||||
{/* 左侧:图表 + 持仓 */}
|
||||
<div className="space-y-6">
|
||||
{/* Equity Chart */}
|
||||
{/* Chart Tabs (Equity / K-line) */}
|
||||
<div className="animate-slide-in" style={{ animationDelay: '0.1s' }}>
|
||||
<EquityChart traderId={selectedTrader.trader_id} />
|
||||
<ChartTabs traderId={selectedTrader.trader_id} />
|
||||
</div>
|
||||
|
||||
{/* Current Positions */}
|
||||
@@ -669,10 +668,6 @@ export default function TraderDashboard() {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* AI Learning & Performance Analysis */}
|
||||
<div className="mb-6 animate-slide-in" style={{ animationDelay: '0.3s' }}>
|
||||
<AILearning traderId={selectedTrader.trader_id} />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user