test(trader): add comprehensive unit tests and CI coverage reporting (#823)

* chore(config): add Python and uv support to project

- Add comprehensive Python .gitignore rules (pycache, venv, pytest, etc.)
- Add uv package manager specific ignores (.uv/, uv.lock)
- Initialize pyproject.toml for Python tooling

Co-authored-by: tinkle-community <tinklefund@gmail.com>

* chore(deps): add testing dependencies

- Add github.com/stretchr/testify v1.11.1 for test assertions
- Add github.com/agiledragon/gomonkey/v2 v2.13.0 for mocking
- Promote github.com/rs/zerolog to direct dependency

Co-authored-by: tinkle-community <tinklefund@gmail.com>

* ci(workflow): add PR test coverage reporting

Add GitHub Actions workflow to run unit tests and report coverage on PRs:
- Run Go tests with race detection and coverage profiling
- Calculate coverage statistics and generate detailed reports
- Post coverage results as PR comments with visual indicators
- Fix Go version to 1.23 (was incorrectly set to 1.25.0)

Coverage guidelines:
- Green (>=80%): excellent
- Yellow (>=60%): good
- Orange (>=40%): fair
- Red (<40%): needs improvement

This workflow is advisory only and does not block PR merging.

Co-authored-by: tinkle-community <tinklefund@gmail.com>

* test(trader): add comprehensive unit tests for trader modules

Add unit test suites for multiple trader implementations:
- aster_trader_test.go: AsterTrader functionality tests
- auto_trader_test.go: AutoTrader lifecycle and operations tests
- binance_futures_test.go: Binance futures trader tests
- hyperliquid_trader_test.go: Hyperliquid trader tests
- trader_test_suite.go: Common test suite utilities and helpers

Also fix minor formatting issue in auto_trader.go (trailing whitespace)

Co-authored-by: tinkle-community <tinklefund@gmail.com>

* test(trader): preserve existing calculatePnLPercentage unit tests

Merge existing calculatePnLPercentage tests with incoming comprehensive test suite:
- Preserve TestCalculatePnLPercentage with 9 test cases covering edge cases
- Preserve TestCalculatePnLPercentage_RealWorldScenarios with 3 trading scenarios
- Add math package import for floating-point precision comparison
- All tests validate PnL percentage calculation with different leverage scenarios

Co-authored-by: tinkle-community <tinklefund@gmail.com>

---------

Co-authored-by: tinkle-community <tinklefund@gmail.com>
This commit is contained in:
WquGuru
2025-11-09 17:43:28 +08:00
committed by GitHub
parent 0188ffb778
commit ae09647468
14 changed files with 3766 additions and 42 deletions

View File

@@ -0,0 +1,78 @@
name: Go Test Coverage
on:
pull_request:
types: [opened, synchronize, reopened]
branches:
- dev
- main
push:
branches:
- dev
- main
permissions:
contents: read
pull-requests: write
jobs:
test-coverage:
name: Go Unit Tests & Coverage
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: '1.23'
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install -r .github/workflows/scripts/requirements.txt
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Download dependencies
run: go mod download
- name: Run tests with coverage
run: |
go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
- name: Calculate coverage and generate report
id: coverage
run: |
chmod +x .github/workflows/scripts/calculate_coverage.py
python .github/workflows/scripts/calculate_coverage.py coverage.out coverage_report.md
- name: Comment PR with coverage
if: github.event_name == 'pull_request'
continue-on-error: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
chmod +x .github/workflows/scripts/comment_pr.py
python .github/workflows/scripts/comment_pr.py \
${{ github.event.pull_request.number }} \
"${{ steps.coverage.outputs.coverage }}" \
"${{ steps.coverage.outputs.emoji }}" \
"${{ steps.coverage.outputs.status }}" \
"${{ steps.coverage.outputs.badge_color }}" \
coverage_report.md

View File

@@ -0,0 +1,192 @@
#!/usr/bin/env python3
"""
Calculate Go test coverage and generate reports.
This script parses the coverage.out file generated by `go test -coverprofile`,
extracts coverage statistics, and generates formatted reports.
"""
import sys
import re
import os
from typing import Dict, List, Tuple
def parse_coverage_file(coverage_file: str) -> Tuple[float, Dict[str, float]]:
"""
Parse coverage output file and extract coverage data.
Args:
coverage_file: Path to coverage.out file
Returns:
Tuple of (total_coverage, package_coverage_dict)
"""
if not os.path.exists(coverage_file):
print(f"Error: Coverage file {coverage_file} not found", file=sys.stderr)
sys.exit(1)
# Run go tool cover to get coverage data
import subprocess
try:
result = subprocess.run(
['go', 'tool', 'cover', '-func', coverage_file],
capture_output=True,
text=True,
check=True
)
except subprocess.CalledProcessError as e:
print(f"Error running go tool cover: {e}", file=sys.stderr)
sys.exit(1)
lines = result.stdout.strip().split('\n')
package_coverage = {}
total_coverage = 0.0
for line in lines:
# Skip empty lines
if not line.strip():
continue
# Check for total coverage line
if line.startswith('total:'):
# Extract percentage from "total: (statements) XX.X%"
match = re.search(r'(\d+\.\d+)%', line)
if match:
total_coverage = float(match.group(1))
continue
# Parse package/file coverage
# Format: "package/file.go:function statements coverage%"
parts = line.split()
if len(parts) >= 3:
file_path = parts[0]
coverage_str = parts[-1]
# Extract package name from file path
package = file_path.split(':')[0]
package_name = '/'.join(package.split('/')[:-1]) if '/' in package else package
# Extract coverage percentage
match = re.search(r'(\d+\.\d+)%', coverage_str)
if match:
coverage_pct = float(match.group(1))
# Aggregate by package
if package_name not in package_coverage:
package_coverage[package_name] = []
package_coverage[package_name].append(coverage_pct)
# Calculate average coverage per package
package_avg = {
pkg: sum(coverages) / len(coverages)
for pkg, coverages in package_coverage.items()
}
return total_coverage, package_avg
def get_coverage_status(coverage: float) -> Tuple[str, str, str]:
"""
Get coverage status based on percentage.
Args:
coverage: Coverage percentage
Returns:
Tuple of (emoji, status_text, badge_color)
"""
if coverage >= 80:
return '🟢', 'excellent', 'brightgreen'
elif coverage >= 60:
return '🟡', 'good', 'yellow'
elif coverage >= 40:
return '🟠', 'fair', 'orange'
else:
return '🔴', 'needs improvement', 'red'
def generate_coverage_report(coverage_file: str, output_file: str) -> None:
"""
Generate a detailed coverage report in markdown format.
Args:
coverage_file: Path to coverage.out file
output_file: Path to output markdown file
"""
import subprocess
try:
result = subprocess.run(
['go', 'tool', 'cover', '-func', coverage_file],
capture_output=True,
text=True,
check=True
)
except subprocess.CalledProcessError as e:
print(f"Error generating coverage report: {e}", file=sys.stderr)
sys.exit(1)
with open(output_file, 'w') as f:
f.write("## Coverage by Package\n\n")
f.write("```\n")
f.write(result.stdout)
f.write("```\n")
def set_github_output(name: str, value: str) -> None:
"""
Set GitHub Actions output variable.
Args:
name: Output variable name
value: Output variable value
"""
github_output = os.environ.get('GITHUB_OUTPUT')
if github_output:
with open(github_output, 'a') as f:
f.write(f"{name}={value}\n")
else:
print(f"::set-output name={name}::{value}")
def main():
"""Main entry point."""
if len(sys.argv) < 2:
print("Usage: calculate_coverage.py <coverage_file> [output_file]", file=sys.stderr)
sys.exit(1)
coverage_file = sys.argv[1]
output_file = sys.argv[2] if len(sys.argv) > 2 else 'coverage_report.md'
# Parse coverage data
total_coverage, package_coverage = parse_coverage_file(coverage_file)
# Get coverage status
emoji, status, badge_color = get_coverage_status(total_coverage)
# Generate detailed report
generate_coverage_report(coverage_file, output_file)
# Output results
print(f"Total Coverage: {total_coverage}%")
print(f"Status: {status}")
print(f"Badge Color: {badge_color}")
# Set GitHub Actions outputs
set_github_output('coverage', f'{total_coverage}%')
set_github_output('coverage_num', str(total_coverage))
set_github_output('status', status)
set_github_output('emoji', emoji)
set_github_output('badge_color', badge_color)
# Print package breakdown
if package_coverage:
print("\nCoverage by Package:")
for package, coverage in sorted(package_coverage.items()):
print(f" {package}: {coverage:.1f}%")
if __name__ == '__main__':
main()

246
.github/workflows/scripts/comment_pr.py vendored Executable file
View File

@@ -0,0 +1,246 @@
#!/usr/bin/env python3
"""
Post or update coverage report comment on GitHub Pull Request.
This script generates a formatted coverage report comment and posts it to a PR,
or updates an existing coverage comment if one already exists.
"""
import os
import sys
import json
import requests
from typing import Optional
def read_file(file_path: str) -> str:
"""Read file content."""
try:
with open(file_path, 'r') as f:
return f.read()
except FileNotFoundError:
print(f"Warning: File {file_path} not found", file=sys.stderr)
return ""
def generate_comment_body(coverage: str, emoji: str, status: str,
badge_color: str, coverage_report_path: str) -> str:
"""
Generate the PR comment body.
Args:
coverage: Coverage percentage (e.g., "75.5%")
emoji: Status emoji
status: Status text
badge_color: Badge color
coverage_report_path: Path to detailed coverage report
Returns:
Formatted comment body in markdown
"""
coverage_report = read_file(coverage_report_path)
# URL encode the coverage percentage for the badge
coverage_encoded = coverage.replace('%', '%25')
comment = f"""## {emoji} Go Test Coverage Report
**Total Coverage:** `{coverage}` ({status})
![Coverage](https://img.shields.io/badge/coverage-{coverage_encoded}-{badge_color})
<details>
<summary>📊 Detailed Coverage Report (click to expand)</summary>
{coverage_report}
</details>
### Coverage Guidelines
- 🟢 >= 80%: Excellent
- 🟡 >= 60%: Good
- 🟠 >= 40%: Fair
- 🔴 < 40%: Needs improvement
---
*This is an automated coverage report. The coverage requirement is advisory and does not block PR merging.*
"""
return comment
def find_existing_comment(token: str, repo: str, pr_number: int) -> Optional[int]:
"""
Find existing coverage comment in the PR.
Args:
token: GitHub token
repo: Repository in format "owner/repo"
pr_number: Pull request number
Returns:
Comment ID if found, None otherwise
"""
url = f"https://api.github.com/repos/{repo}/issues/{pr_number}/comments"
headers = {
'Authorization': f'token {token}',
'Accept': 'application/vnd.github.v3+json'
}
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
comments = response.json()
# Look for existing coverage comment
for comment in comments:
if (comment.get('user', {}).get('type') == 'Bot' and
'Go Test Coverage Report' in comment.get('body', '')):
return comment['id']
except requests.exceptions.RequestException as e:
print(f"Error fetching comments: {e}", file=sys.stderr)
return None
def post_comment(token: str, repo: str, pr_number: int, body: str) -> bool:
"""
Post a new comment to the PR.
Args:
token: GitHub token
repo: Repository in format "owner/repo"
pr_number: Pull request number
body: Comment body
Returns:
True if successful, False otherwise
"""
url = f"https://api.github.com/repos/{repo}/issues/{pr_number}/comments"
headers = {
'Authorization': f'token {token}',
'Accept': 'application/vnd.github.v3+json'
}
data = {'body': body}
try:
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
print("✅ Coverage comment posted successfully")
return True
except requests.exceptions.RequestException as e:
print(f"Error posting comment: {e}", file=sys.stderr)
if hasattr(e, 'response') and e.response is not None:
print(f"Response: {e.response.text}", file=sys.stderr)
return False
def update_comment(token: str, repo: str, comment_id: int, body: str) -> bool:
"""
Update an existing comment.
Args:
token: GitHub token
repo: Repository in format "owner/repo"
comment_id: Comment ID to update
body: New comment body
Returns:
True if successful, False otherwise
"""
url = f"https://api.github.com/repos/{repo}/issues/comments/{comment_id}"
headers = {
'Authorization': f'token {token}',
'Accept': 'application/vnd.github.v3+json'
}
data = {'body': body}
try:
response = requests.patch(url, headers=headers, json=data)
response.raise_for_status()
print("✅ Coverage comment updated successfully")
return True
except requests.exceptions.RequestException as e:
print(f"Error updating comment: {e}", file=sys.stderr)
if hasattr(e, 'response') and e.response is not None:
print(f"Response: {e.response.text}", file=sys.stderr)
return False
def is_fork_pr(event_path: str) -> bool:
"""
Check if the PR is from a fork.
Args:
event_path: Path to GitHub event JSON file
Returns:
True if fork PR, False otherwise
"""
try:
with open(event_path, 'r') as f:
event = json.load(f)
pr = event.get('pull_request', {})
head_repo = pr.get('head', {}).get('repo', {}).get('full_name')
base_repo = pr.get('base', {}).get('repo', {}).get('full_name')
return head_repo != base_repo
except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
print(f"Warning: Could not determine if fork PR: {e}", file=sys.stderr)
return False
def main():
"""Main entry point."""
# Get environment variables
token = os.environ.get('GITHUB_TOKEN')
repo = os.environ.get('GITHUB_REPOSITORY')
event_path = os.environ.get('GITHUB_EVENT_PATH', '')
# Get arguments
if len(sys.argv) < 6:
print("Usage: comment_pr.py <pr_number> <coverage> <emoji> <status> <badge_color> [coverage_report_path]",
file=sys.stderr)
sys.exit(1)
pr_number = int(sys.argv[1])
coverage = sys.argv[2]
emoji = sys.argv[3]
status = sys.argv[4]
badge_color = sys.argv[5]
coverage_report_path = sys.argv[6] if len(sys.argv) > 6 else 'coverage_report.md'
# Validate environment
if not token:
print("Error: GITHUB_TOKEN environment variable not set", file=sys.stderr)
sys.exit(1)
if not repo:
print("Error: GITHUB_REPOSITORY environment variable not set", file=sys.stderr)
sys.exit(1)
# Check if fork PR
if event_path and is_fork_pr(event_path):
print(" Fork PR detected - skipping comment (no write permissions)")
sys.exit(0)
# Generate comment body
comment_body = generate_comment_body(coverage, emoji, status, badge_color, coverage_report_path)
# Check for existing comment
existing_comment_id = find_existing_comment(token, repo, pr_number)
# Post or update comment
if existing_comment_id:
print(f"Found existing comment (ID: {existing_comment_id}), updating...")
success = update_comment(token, repo, existing_comment_id, comment_body)
else:
print("No existing comment found, creating new one...")
success = post_comment(token, repo, pr_number, comment_body)
sys.exit(0 if success else 1)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,2 @@
# Python dependencies for GitHub Actions scripts
requests>=2.31.0

58
.gitignore vendored
View File

@@ -64,3 +64,61 @@ rsa_key*
# 加密相关
DATA_ENCRYPTION_KEY=*
*.enc
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# Python 虚拟环境
.venv/
venv/
ENV/
env/
.env/
# uv
.uv/
uv.lock
# Pytest
.pytest_cache/
.coverage
htmlcov/
*.cover
.hypothesis/
# Jupyter Notebook
.ipynb_checkpoints
*.ipynb
# pyenv
.python-version
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/

7
go.mod
View File

@@ -4,6 +4,7 @@ go 1.25.0
require (
github.com/adshao/go-binance/v2 v2.8.7
github.com/agiledragon/gomonkey/v2 v2.13.0
github.com/ethereum/go-ethereum v1.16.5
github.com/gin-gonic/gin v1.11.0
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
@@ -12,8 +13,10 @@ require (
github.com/gorilla/websocket v1.5.3
github.com/joho/godotenv v1.5.1
github.com/pquerna/otp v1.4.0
github.com/rs/zerolog v1.34.0
github.com/sirupsen/logrus v1.9.3
github.com/sonirico/go-hyperliquid v0.17.0
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.42.0
modernc.org/sqlite v1.40.0
)
@@ -29,6 +32,7 @@ require (
github.com/consensys/gnark-crypto v0.19.0 // indirect
github.com/crate-crypto/go-eth-kzg v1.4.0 // indirect
github.com/crate-crypto/go-ipa v0.0.0-20240724233137-53bbb0ceb27a // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/elastic/go-sysinfo v1.15.4 // indirect
@@ -56,11 +60,11 @@ require (
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/procfs v0.17.0 // indirect
github.com/quic-go/qpack v0.5.1 // indirect
github.com/quic-go/quic-go v0.54.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rs/zerolog v1.34.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/sonirico/vago v0.9.0 // indirect
github.com/sonirico/vago/lol v0.0.0-20250901170347-2d1d82c510bd // indirect
@@ -83,6 +87,7 @@ require (
golang.org/x/text v0.29.0 // indirect
golang.org/x/tools v0.36.0 // indirect
google.golang.org/protobuf v1.36.9 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
howett.net/plist v1.0.1 // indirect
modernc.org/libc v1.66.10 // indirect
modernc.org/mathutil v1.7.1 // indirect

13
go.sum
View File

@@ -2,6 +2,8 @@ github.com/StackExchange/wmi v1.2.1 h1:VIkavFPXSjcnS+O8yTq7NI32k0R5Aj+v39y29VYDO
github.com/StackExchange/wmi v1.2.1/go.mod h1:rcmrprowKIVzvc+NUiLncP2uuArMWLCbu9SBzvHz7e8=
github.com/adshao/go-binance/v2 v2.8.7 h1:n7jkhwIHMdtd/9ZU2gTqFV15XVSbUCjyFlOUAtTd8uU=
github.com/adshao/go-binance/v2 v2.8.7/go.mod h1:XkkuecSyJKPolaCGf/q4ovJYB3t0P+7RUYTbGr+LMGM=
github.com/agiledragon/gomonkey/v2 v2.13.0 h1:B24Jg6wBI1iB8EFR1c+/aoTg7QN/Cum7YffG8KMIyYo=
github.com/agiledragon/gomonkey/v2 v2.13.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY=
github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI=
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
@@ -88,6 +90,7 @@ github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17k
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/holiman/uint256 v1.3.2 h1:a9EgMPSC1AAaj1SZL5zIQD3WbwTuHrMGOerLjGmM/TA=
@@ -101,6 +104,7 @@ github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2E
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -165,6 +169,8 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/sonirico/go-hyperliquid v0.17.0 h1:eXYACWupwu41O1VtKw17dqe9oOLQ1A2nRElGhg5Ox+4=
github.com/sonirico/go-hyperliquid v0.17.0/go.mod h1:sH51Vsu+tPUwc95TL2MoQ8YXSewLWBEJirgzo7sZx6w=
github.com/sonirico/vago v0.9.0 h1:DF2OWW2Aaf1xPZmnFv79kBrHmjKX3mVvMbP08vERlKo=
@@ -209,29 +215,36 @@ go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw=
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/dnaeon/go-vcr.v4 v4.0.5 h1:I0hpTIvD5rII+8LgYGrHMA2d4SQPoL6u7ZvJakWKsiA=
gopkg.in/dnaeon/go-vcr.v4 v4.0.5/go.mod h1:dRos81TkW9C1WJt6tTaE+uV2Lo8qJT3AG2b35+CB/nQ=
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=

7
pyproject.toml Normal file
View File

@@ -0,0 +1,7 @@
[project]
name = "nofx"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []

299
trader/aster_trader_test.go Normal file
View File

@@ -0,0 +1,299 @@
package trader
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/assert"
)
// ============================================================
// 一、AsterTraderTestSuite - 继承 base test suite
// ============================================================
// AsterTraderTestSuite Aster交易器测试套件
// 继承 TraderTestSuite 并添加 Aster 特定的 mock 逻辑
type AsterTraderTestSuite struct {
*TraderTestSuite // 嵌入基础测试套件
mockServer *httptest.Server
}
// NewAsterTraderTestSuite 创建 Aster 测试套件
func NewAsterTraderTestSuite(t *testing.T) *AsterTraderTestSuite {
// 创建 mock HTTP 服务器
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 根据不同的 URL 路径返回不同的 mock 响应
path := r.URL.Path
var respBody interface{}
switch {
// Mock GetBalance - /fapi/v3/balance (返回数组)
case path == "/fapi/v3/balance":
respBody = []map[string]interface{}{
{
"asset": "USDT",
"walletBalance": "10000.00",
"unrealizedProfit": "100.50",
"marginBalance": "10100.50",
"maintMargin": "200.00",
"initialMargin": "2000.00",
"maxWithdrawAmount": "8000.00",
"crossWalletBalance": "10000.00",
"crossUnPnl": "100.50",
"availableBalance": "8000.00",
},
}
// Mock GetPositions - /fapi/v3/positionRisk
case path == "/fapi/v3/positionRisk":
respBody = []map[string]interface{}{
{
"symbol": "BTCUSDT",
"positionAmt": "0.5",
"entryPrice": "50000.00",
"markPrice": "50500.00",
"unRealizedProfit": "250.00",
"liquidationPrice": "45000.00",
"leverage": "10",
"positionSide": "LONG",
},
}
// Mock GetMarketPrice - /fapi/v3/ticker/price (返回单个对象)
case path == "/fapi/v3/ticker/price":
// 从查询参数获取symbol
symbol := r.URL.Query().Get("symbol")
if symbol == "" {
symbol = "BTCUSDT"
}
// 根据symbol返回不同价格
price := "50000.00"
if symbol == "ETHUSDT" {
price = "3000.00"
} else if symbol == "INVALIDUSDT" {
// 返回错误响应
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]interface{}{
"code": -1121,
"msg": "Invalid symbol",
})
return
}
respBody = map[string]interface{}{
"symbol": symbol,
"price": price,
}
// Mock ExchangeInfo - /fapi/v3/exchangeInfo
case path == "/fapi/v3/exchangeInfo":
respBody = map[string]interface{}{
"symbols": []map[string]interface{}{
{
"symbol": "BTCUSDT",
"pricePrecision": 1,
"quantityPrecision": 3,
"baseAssetPrecision": 8,
"quotePrecision": 8,
"filters": []map[string]interface{}{
{
"filterType": "PRICE_FILTER",
"tickSize": "0.1",
},
{
"filterType": "LOT_SIZE",
"stepSize": "0.001",
},
},
},
{
"symbol": "ETHUSDT",
"pricePrecision": 2,
"quantityPrecision": 3,
"baseAssetPrecision": 8,
"quotePrecision": 8,
"filters": []map[string]interface{}{
{
"filterType": "PRICE_FILTER",
"tickSize": "0.01",
},
{
"filterType": "LOT_SIZE",
"stepSize": "0.001",
},
},
},
},
}
// Mock CreateOrder - /fapi/v1/order and /fapi/v3/order
case (path == "/fapi/v1/order" || path == "/fapi/v3/order") && r.Method == "POST":
// 从请求中解析参数以确定symbol
bodyBytes, _ := io.ReadAll(r.Body)
var orderParams map[string]interface{}
json.Unmarshal(bodyBytes, &orderParams)
symbol := "BTCUSDT"
if s, ok := orderParams["symbol"].(string); ok {
symbol = s
}
respBody = map[string]interface{}{
"orderId": 123456,
"symbol": symbol,
"status": "FILLED",
"side": orderParams["side"],
"type": orderParams["type"],
}
// Mock CancelOrder - /fapi/v1/order (DELETE)
case path == "/fapi/v1/order" && r.Method == "DELETE":
respBody = map[string]interface{}{
"orderId": 123456,
"symbol": "BTCUSDT",
"status": "CANCELED",
}
// Mock ListOpenOrders - /fapi/v1/openOrders and /fapi/v3/openOrders
case path == "/fapi/v1/openOrders" || path == "/fapi/v3/openOrders":
respBody = []map[string]interface{}{}
// Mock SetLeverage - /fapi/v1/leverage
case path == "/fapi/v1/leverage":
respBody = map[string]interface{}{
"leverage": 10,
"symbol": "BTCUSDT",
}
// Mock SetMarginMode - /fapi/v1/marginType
case path == "/fapi/v1/marginType":
respBody = map[string]interface{}{
"code": 200,
"msg": "success",
}
// Default: empty response
default:
respBody = map[string]interface{}{}
}
// 序列化响应
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
// 生成一个测试用的私钥
privateKey, _ := crypto.GenerateKey()
// 创建 mock trader使用 mock server 的 URL
trader := &AsterTrader{
ctx: context.Background(),
user: "0x1234567890123456789012345678901234567890",
signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd",
privateKey: privateKey,
client: mockServer.Client(),
baseURL: mockServer.URL, // 使用 mock server 的 URL
symbolPrecision: make(map[string]SymbolPrecision),
}
// 创建基础套件
baseSuite := NewTraderTestSuite(t, trader)
return &AsterTraderTestSuite{
TraderTestSuite: baseSuite,
mockServer: mockServer,
}
}
// Cleanup 清理资源
func (s *AsterTraderTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
}
s.TraderTestSuite.Cleanup()
}
// ============================================================
// 二、使用 AsterTraderTestSuite 运行通用测试
// ============================================================
// TestAsterTrader_InterfaceCompliance 测试接口兼容性
func TestAsterTrader_InterfaceCompliance(t *testing.T) {
var _ Trader = (*AsterTrader)(nil)
}
// TestAsterTrader_CommonInterface 使用测试套件运行所有通用接口测试
func TestAsterTrader_CommonInterface(t *testing.T) {
// 创建测试套件
suite := NewAsterTraderTestSuite(t)
defer suite.Cleanup()
// 运行所有通用接口测试
suite.RunAllTests()
}
// ============================================================
// 三、Aster 特定功能的单元测试
// ============================================================
// TestNewAsterTrader 测试创建 Aster 交易器
func TestNewAsterTrader(t *testing.T) {
tests := []struct {
name string
user string
signer string
privateKeyHex string
wantError bool
errorContains string
}{
{
name: "成功创建",
user: "0x1234567890123456789012345678901234567890",
signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd",
privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
wantError: false,
},
{
name: "无效私钥格式",
user: "0x1234567890123456789012345678901234567890",
signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd",
privateKeyHex: "invalid_key",
wantError: true,
errorContains: "解析私钥失败",
},
{
name: "带0x前缀的私钥",
user: "0x1234567890123456789012345678901234567890",
signer: "0xabcdefabcdefabcdefabcdefabcdefabcdefabcd",
privateKeyHex: "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
trader, err := NewAsterTrader(tt.user, tt.signer, tt.privateKeyHex)
if tt.wantError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, trader)
} else {
assert.NoError(t, err)
assert.NotNil(t, trader)
if trader != nil {
assert.Equal(t, tt.user, trader.user)
assert.Equal(t, tt.signer, trader.signer)
assert.NotNil(t, trader.privateKey)
}
}
})
}
}

View File

@@ -241,7 +241,7 @@ func (at *AutoTrader) Run() error {
at.isRunning = true
at.stopMonitorCh = make(chan struct{})
at.startTime = time.Now()
log.Println("🚀 AI驱动自动交易系统启动")
log.Printf("💰 初始余额: %.2f USDT", at.initialBalance)
log.Printf("⚙️ 扫描间隔: %v", at.config.ScanInterval)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,420 @@
package trader
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/adshao/go-binance/v2/futures"
"github.com/stretchr/testify/assert"
)
// ============================================================
// 一、BinanceFuturesTestSuite - 继承 base test suite
// ============================================================
// BinanceFuturesTestSuite 币安合约交易器测试套件
// 继承 TraderTestSuite 并添加 Binance Futures 特定的 mock 逻辑
type BinanceFuturesTestSuite struct {
*TraderTestSuite // 嵌入基础测试套件
mockServer *httptest.Server
}
// NewBinanceFuturesTestSuite 创建币安合约测试套件
func NewBinanceFuturesTestSuite(t *testing.T) *BinanceFuturesTestSuite {
// 创建 mock HTTP 服务器
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 根据不同的 URL 路径返回不同的 mock 响应
path := r.URL.Path
var respBody interface{}
switch {
// Mock GetBalance - /fapi/v2/balance
case path == "/fapi/v2/balance":
respBody = []map[string]interface{}{
{
"accountAlias": "test",
"asset": "USDT",
"balance": "10000.00",
"crossWalletBalance": "10000.00",
"crossUnPnl": "100.50",
"availableBalance": "8000.00",
"maxWithdrawAmount": "8000.00",
},
}
// Mock GetAccount - /fapi/v2/account
case path == "/fapi/v2/account":
respBody = map[string]interface{}{
"totalWalletBalance": "10000.00",
"availableBalance": "8000.00",
"totalUnrealizedProfit": "100.50",
"assets": []map[string]interface{}{
{
"asset": "USDT",
"walletBalance": "10000.00",
"unrealizedProfit": "100.50",
"marginBalance": "10100.50",
"maintMargin": "200.00",
"initialMargin": "2000.00",
"positionInitialMargin": "2000.00",
"openOrderInitialMargin": "0.00",
"crossWalletBalance": "10000.00",
"crossUnPnl": "100.50",
"availableBalance": "8000.00",
"maxWithdrawAmount": "8000.00",
},
},
}
// Mock GetPositions - /fapi/v2/positionRisk
case path == "/fapi/v2/positionRisk":
respBody = []map[string]interface{}{
{
"symbol": "BTCUSDT",
"positionAmt": "0.5",
"entryPrice": "50000.00",
"markPrice": "50500.00",
"unRealizedProfit": "250.00",
"liquidationPrice": "45000.00",
"leverage": "10",
"positionSide": "LONG",
},
}
// Mock GetMarketPrice - /fapi/v1/ticker/price and /fapi/v2/ticker/price
case path == "/fapi/v1/ticker/price" || path == "/fapi/v2/ticker/price":
symbol := r.URL.Query().Get("symbol")
if symbol == "" {
// 返回所有价格
respBody = []map[string]interface{}{
{"Symbol": "BTCUSDT", "Price": "50000.00", "Time": 1234567890},
{"Symbol": "ETHUSDT", "Price": "3000.00", "Time": 1234567890},
}
} else if symbol == "INVALIDUSDT" {
// 返回错误
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]interface{}{
"code": -1121,
"msg": "Invalid symbol.",
})
return
} else {
// 返回单个价格(注意:即使有 symbol 参数,也要返回数组)
price := "50000.00"
if symbol == "ETHUSDT" {
price = "3000.00"
}
respBody = []map[string]interface{}{
{
"Symbol": symbol,
"Price": price,
"Time": 1234567890,
},
}
}
// Mock ExchangeInfo - /fapi/v1/exchangeInfo
case path == "/fapi/v1/exchangeInfo":
respBody = map[string]interface{}{
"symbols": []map[string]interface{}{
{
"symbol": "BTCUSDT",
"status": "TRADING",
"baseAsset": "BTC",
"quoteAsset": "USDT",
"pricePrecision": 2,
"quantityPrecision": 3,
"baseAssetPrecision": 8,
"quotePrecision": 8,
"filters": []map[string]interface{}{
{
"filterType": "PRICE_FILTER",
"minPrice": "0.01",
"maxPrice": "1000000",
"tickSize": "0.01",
},
{
"filterType": "LOT_SIZE",
"minQty": "0.001",
"maxQty": "10000",
"stepSize": "0.001",
},
},
},
{
"symbol": "ETHUSDT",
"status": "TRADING",
"baseAsset": "ETH",
"quoteAsset": "USDT",
"pricePrecision": 2,
"quantityPrecision": 3,
"baseAssetPrecision": 8,
"quotePrecision": 8,
"filters": []map[string]interface{}{
{
"filterType": "PRICE_FILTER",
"minPrice": "0.01",
"maxPrice": "100000",
"tickSize": "0.01",
},
{
"filterType": "LOT_SIZE",
"minQty": "0.001",
"maxQty": "10000",
"stepSize": "0.001",
},
},
},
},
}
// Mock CreateOrder - /fapi/v1/order (POST)
case path == "/fapi/v1/order" && r.Method == "POST":
symbol := r.FormValue("symbol")
if symbol == "" {
symbol = "BTCUSDT"
}
respBody = map[string]interface{}{
"orderId": 123456,
"symbol": symbol,
"status": "FILLED",
"clientOrderId": r.FormValue("newClientOrderId"),
"price": r.FormValue("price"),
"avgPrice": r.FormValue("price"),
"origQty": r.FormValue("quantity"),
"executedQty": r.FormValue("quantity"),
"cumQty": r.FormValue("quantity"),
"cumQuote": "1000.00",
"timeInForce": r.FormValue("timeInForce"),
"type": r.FormValue("type"),
"reduceOnly": r.FormValue("reduceOnly") == "true",
"side": r.FormValue("side"),
"positionSide": r.FormValue("positionSide"),
"stopPrice": r.FormValue("stopPrice"),
"workingType": r.FormValue("workingType"),
}
// Mock CancelOrder - /fapi/v1/order (DELETE)
case path == "/fapi/v1/order" && r.Method == "DELETE":
respBody = map[string]interface{}{
"orderId": 123456,
"symbol": r.URL.Query().Get("symbol"),
"status": "CANCELED",
}
// Mock ListOpenOrders - /fapi/v1/openOrders
case path == "/fapi/v1/openOrders":
respBody = []map[string]interface{}{}
// Mock CancelAllOrders - /fapi/v1/allOpenOrders (DELETE)
case path == "/fapi/v1/allOpenOrders" && r.Method == "DELETE":
respBody = map[string]interface{}{
"code": 200,
"msg": "The operation of cancel all open order is done.",
}
// Mock SetLeverage - /fapi/v1/leverage
case path == "/fapi/v1/leverage":
// 将字符串转换为整数
leverageStr := r.FormValue("leverage")
leverage := 10 // 默认值
if leverageStr != "" {
// 注意:这里我们直接返回整数,而不是字符串
fmt.Sscanf(leverageStr, "%d", &leverage)
}
respBody = map[string]interface{}{
"leverage": leverage,
"maxNotionalValue": "1000000",
"symbol": r.FormValue("symbol"),
}
// Mock SetMarginType - /fapi/v1/marginType
case path == "/fapi/v1/marginType":
respBody = map[string]interface{}{
"code": 200,
"msg": "success",
}
// Mock ChangePositionMode - /fapi/v1/positionSide/dual
case path == "/fapi/v1/positionSide/dual":
respBody = map[string]interface{}{
"code": 200,
"msg": "success",
}
// Mock ServerTime - /fapi/v1/time
case path == "/fapi/v1/time":
respBody = map[string]interface{}{
"serverTime": 1234567890000,
}
// Default: empty response
default:
respBody = map[string]interface{}{}
}
// 序列化响应
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
// 创建 futures.Client 并设置为使用 mock 服务器
client := futures.NewClient("test_api_key", "test_secret_key")
client.BaseURL = mockServer.URL
client.HTTPClient = mockServer.Client()
// 创建 FuturesTrader
trader := &FuturesTrader{
client: client,
cacheDuration: 0, // 禁用缓存以便测试
}
// 创建基础套件
baseSuite := NewTraderTestSuite(t, trader)
return &BinanceFuturesTestSuite{
TraderTestSuite: baseSuite,
mockServer: mockServer,
}
}
// Cleanup 清理资源
func (s *BinanceFuturesTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
}
s.TraderTestSuite.Cleanup()
}
// ============================================================
// 二、使用 BinanceFuturesTestSuite 运行通用测试
// ============================================================
// TestFuturesTrader_InterfaceCompliance 测试接口兼容性
func TestFuturesTrader_InterfaceCompliance(t *testing.T) {
var _ Trader = (*FuturesTrader)(nil)
}
// TestFuturesTrader_CommonInterface 使用测试套件运行所有通用接口测试
func TestFuturesTrader_CommonInterface(t *testing.T) {
// 创建测试套件
suite := NewBinanceFuturesTestSuite(t)
defer suite.Cleanup()
// 运行所有通用接口测试
suite.RunAllTests()
}
// ============================================================
// 三、币安合约特定功能的单元测试
// ============================================================
// TestNewFuturesTrader 测试创建币安合约交易器
func TestNewFuturesTrader(t *testing.T) {
// 创建 mock HTTP 服务器
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
var respBody interface{}
switch path {
case "/fapi/v1/time":
respBody = map[string]interface{}{
"serverTime": 1234567890000,
}
case "/fapi/v1/positionSide/dual":
respBody = map[string]interface{}{
"code": 200,
"msg": "success",
}
default:
respBody = map[string]interface{}{}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
defer mockServer.Close()
// 测试成功创建
trader := NewFuturesTrader("test_api_key", "test_secret_key", "test_user")
// 修改 client 使用 mock server
trader.client.BaseURL = mockServer.URL
trader.client.HTTPClient = mockServer.Client()
assert.NotNil(t, trader)
assert.NotNil(t, trader.client)
assert.Equal(t, 15*time.Second, trader.cacheDuration)
}
// TestCalculatePositionSize 测试仓位计算
func TestCalculatePositionSize(t *testing.T) {
trader := &FuturesTrader{}
tests := []struct {
name string
balance float64
riskPercent float64
price float64
leverage int
wantQuantity float64
}{
{
name: "正常计算",
balance: 10000,
riskPercent: 2,
price: 50000,
leverage: 10,
wantQuantity: 0.04, // (10000 * 0.02 * 10) / 50000 = 0.04
},
{
name: "高杠杆",
balance: 10000,
riskPercent: 1,
price: 3000,
leverage: 20,
wantQuantity: 0.6667, // (10000 * 0.01 * 20) / 3000 = 0.6667
},
{
name: "低风险",
balance: 5000,
riskPercent: 0.5,
price: 50000,
leverage: 5,
wantQuantity: 0.0025, // (5000 * 0.005 * 5) / 50000 = 0.0025
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
quantity := trader.CalculatePositionSize(tt.balance, tt.riskPercent, tt.price, tt.leverage)
assert.InDelta(t, tt.wantQuantity, quantity, 0.0001, "计算的仓位数量不正确")
})
}
}
// TestGetBrOrderID 测试订单ID生成
func TestGetBrOrderID(t *testing.T) {
// 测试3次确保每次生成的ID都不同
ids := make(map[string]bool)
for i := 0; i < 3; i++ {
id := getBrOrderID()
// 检查格式
assert.True(t, strings.HasPrefix(id, "x-KzrpZaP9"), "订单ID应以x-KzrpZaP9开头")
// 检查长度(应该 <= 32
assert.LessOrEqual(t, len(id), 32, "订单ID长度不应超过32字符")
// 检查唯一性
assert.False(t, ids[id], "订单ID应该唯一")
ids[id] = true
}
}

View File

@@ -0,0 +1,646 @@
package trader
import (
"context"
"crypto/ecdsa"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/sonirico/go-hyperliquid"
"github.com/stretchr/testify/assert"
)
// ============================================================
// 一、HyperliquidTestSuite - 继承 base test suite
// ============================================================
// HyperliquidTestSuite Hyperliquid 交易器测试套件
// 继承 TraderTestSuite 并添加 Hyperliquid 特定的 mock 逻辑
type HyperliquidTestSuite struct {
*TraderTestSuite // 嵌入基础测试套件
mockServer *httptest.Server
privateKey *ecdsa.PrivateKey
}
// NewHyperliquidTestSuite 创建 Hyperliquid 测试套件
func NewHyperliquidTestSuite(t *testing.T) *HyperliquidTestSuite {
// 创建测试用私钥
privateKey, err := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
if err != nil {
t.Fatalf("创建测试私钥失败: %v", err)
}
// 创建 mock HTTP 服务器
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 根据不同的请求路径返回不同的 mock 响应
var respBody interface{}
// Hyperliquid API 使用 POST 请求,请求体是 JSON
// 我们需要根据请求体中的 "type" 字段来区分不同的请求
var reqBody map[string]interface{}
if r.Method == "POST" {
json.NewDecoder(r.Body).Decode(&reqBody)
}
// Try to get type from top level first, then from action object
reqType, _ := reqBody["type"].(string)
if reqType == "" && reqBody["action"] != nil {
if action, ok := reqBody["action"].(map[string]interface{}); ok {
reqType, _ = action["type"].(string)
}
}
switch reqType {
// Mock Meta - 获取市场元数据
case "meta":
respBody = map[string]interface{}{
"universe": []map[string]interface{}{
{
"name": "BTC",
"szDecimals": 4,
"maxLeverage": 50,
"onlyIsolated": false,
"isDelisted": false,
"marginTableId": 0,
},
{
"name": "ETH",
"szDecimals": 3,
"maxLeverage": 50,
"onlyIsolated": false,
"isDelisted": false,
"marginTableId": 0,
},
},
"marginTables": []interface{}{},
}
// Mock UserState - 获取用户账户状态(用于 GetBalance 和 GetPositions
case "clearinghouseState":
user, _ := reqBody["user"].(string)
// 检查是否是查询 Agent 钱包余额(用于安全检查)
agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
if user == agentAddr {
// Agent 钱包余额应该很低
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "5.00",
"totalMarginUsed": "0.00",
},
"withdrawable": "5.00",
"assetPositions": []interface{}{},
}
} else {
// 主钱包账户状态
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "10000.00",
"totalMarginUsed": "2000.00",
},
"withdrawable": "8000.00",
"assetPositions": []map[string]interface{}{
{
"position": map[string]interface{}{
"coin": "BTC",
"szi": "0.5",
"entryPx": "50000.00",
"liquidationPx": "45000.00",
"positionValue": "25000.00",
"unrealizedPnl": "100.50",
"leverage": map[string]interface{}{
"type": "cross",
"value": 10,
},
},
},
},
}
}
// Mock SpotUserState - 获取现货账户状态
case "spotClearinghouseState":
respBody = map[string]interface{}{
"balances": []map[string]interface{}{
{
"coin": "USDC",
"total": "500.00",
},
},
}
// Mock SpotMeta - 获取现货市场元数据
case "spotMeta":
respBody = map[string]interface{}{
"universe": []map[string]interface{}{},
"tokens": []map[string]interface{}{},
}
// Mock AllMids - 获取所有市场价格
case "allMids":
respBody = map[string]string{
"BTC": "50000.00",
"ETH": "3000.00",
}
// Mock OpenOrders - 获取挂单列表
case "openOrders":
respBody = []interface{}{}
// Mock Order - 创建订单(开仓、平仓、止损、止盈)
case "order":
respBody = map[string]interface{}{
"status": "ok",
"response": map[string]interface{}{
"type": "order",
"data": map[string]interface{}{
"statuses": []map[string]interface{}{
{
"filled": map[string]interface{}{
"totalSz": "0.01",
"avgPx": "50000.00",
},
},
},
},
},
}
// Mock UpdateLeverage - 设置杠杆
case "updateLeverage":
respBody = map[string]interface{}{
"status": "ok",
}
// Mock Cancel - 取消订单
case "cancel":
respBody = map[string]interface{}{
"status": "ok",
}
default:
// 默认返回成功响应
respBody = map[string]interface{}{
"status": "ok",
}
}
// 序列化响应
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
// 创建 HyperliquidTrader使用 mock 服务器 URL
walletAddr := "0x9999999999999999999999999999999999999999"
ctx := context.Background()
// 创建 Exchange 客户端,指向 mock 服务器
exchange := hyperliquid.NewExchange(
ctx,
privateKey,
mockServer.URL, // 使用 mock 服务器 URL
nil,
"",
walletAddr,
nil,
)
// 创建 meta模拟获取成功
meta := &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "BTC", SzDecimals: 4},
{Name: "ETH", SzDecimals: 3},
},
}
trader := &HyperliquidTrader{
exchange: exchange,
ctx: ctx,
walletAddr: walletAddr,
meta: meta,
isCrossMargin: true,
}
// 创建基础套件
baseSuite := NewTraderTestSuite(t, trader)
return &HyperliquidTestSuite{
TraderTestSuite: baseSuite,
mockServer: mockServer,
privateKey: privateKey,
}
}
// Cleanup 清理资源
func (s *HyperliquidTestSuite) Cleanup() {
if s.mockServer != nil {
s.mockServer.Close()
}
s.TraderTestSuite.Cleanup()
}
// ============================================================
// 二、使用 HyperliquidTestSuite 运行通用测试
// ============================================================
// TestHyperliquidTrader_InterfaceCompliance 测试接口兼容性
func TestHyperliquidTrader_InterfaceCompliance(t *testing.T) {
var _ Trader = (*HyperliquidTrader)(nil)
}
// TestHyperliquidTrader_CommonInterface 使用测试套件运行所有通用接口测试
func TestHyperliquidTrader_CommonInterface(t *testing.T) {
// 创建测试套件
suite := NewHyperliquidTestSuite(t)
defer suite.Cleanup()
// 运行所有通用接口测试
suite.RunAllTests()
}
// ============================================================
// 三、Hyperliquid 特定功能的单元测试
// ============================================================
// TestNewHyperliquidTrader 测试创建 Hyperliquid 交易器
func TestNewHyperliquidTrader(t *testing.T) {
tests := []struct {
name string
privateKeyHex string
walletAddr string
testnet bool
wantError bool
errorContains string
}{
{
name: "无效私钥格式",
privateKeyHex: "invalid_key",
walletAddr: "0x1234567890123456789012345678901234567890",
testnet: true,
wantError: true,
errorContains: "解析私钥失败",
},
{
name: "钱包地址为空",
privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
walletAddr: "",
testnet: true,
wantError: true,
errorContains: "Configuration error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
trader, err := NewHyperliquidTrader(tt.privateKeyHex, tt.walletAddr, tt.testnet)
if tt.wantError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
assert.Nil(t, trader)
} else {
assert.NoError(t, err)
assert.NotNil(t, trader)
if trader != nil {
assert.Equal(t, tt.walletAddr, trader.walletAddr)
assert.NotNil(t, trader.exchange)
}
}
})
}
}
// TestNewHyperliquidTrader_Success 测试成功创建交易器(需要 mock HTTP
func TestNewHyperliquidTrader_Success(t *testing.T) {
// 创建测试用私钥
privateKey, _ := crypto.HexToECDSA("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
agentAddr := crypto.PubkeyToAddress(privateKey.PublicKey).Hex()
// 创建 mock HTTP 服务器
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
reqType, _ := reqBody["type"].(string)
var respBody interface{}
switch reqType {
case "meta":
respBody = map[string]interface{}{
"universe": []map[string]interface{}{
{
"name": "BTC",
"szDecimals": 4,
"maxLeverage": 50,
"onlyIsolated": false,
"isDelisted": false,
"marginTableId": 0,
},
},
"marginTables": []interface{}{},
}
case "clearinghouseState":
user, _ := reqBody["user"].(string)
if user == agentAddr {
// Agent 钱包余额低
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "5.00",
},
"assetPositions": []interface{}{},
}
} else {
// 主钱包
respBody = map[string]interface{}{
"crossMarginSummary": map[string]interface{}{
"accountValue": "10000.00",
},
"assetPositions": []interface{}{},
}
}
default:
respBody = map[string]interface{}{"status": "ok"}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(respBody)
}))
defer mockServer.Close()
// 注意:这个测试会真正调用 NewHyperliquidTrader但会失败
// 因为 hyperliquid SDK 不允许我们在构造函数中注入自定义 URL
// 所以这个测试仅用于验证参数处理逻辑
t.Skip("跳过此测试hyperliquid SDK 在构造时会调用真实 API无法注入 mock URL")
}
// ============================================================
// 四、工具函数单元测试Hyperliquid 特有)
// ============================================================
// TestConvertSymbolToHyperliquid 测试 symbol 转换函数
func TestConvertSymbolToHyperliquid(t *testing.T) {
tests := []struct {
name string
symbol string
expected string
}{
{
name: "BTCUSDT转换",
symbol: "BTCUSDT",
expected: "BTC",
},
{
name: "ETHUSDT转换",
symbol: "ETHUSDT",
expected: "ETH",
},
{
name: "无USDT后缀",
symbol: "BTC",
expected: "BTC",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertSymbolToHyperliquid(tt.symbol)
assert.Equal(t, tt.expected, result)
})
}
}
// TestAbsFloat 测试绝对值函数
func TestAbsFloat(t *testing.T) {
tests := []struct {
name string
input float64
expected float64
}{
{
name: "正数",
input: 10.5,
expected: 10.5,
},
{
name: "负数",
input: -10.5,
expected: 10.5,
},
{
name: "零",
input: 0,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := absFloat(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestHyperliquidTrader_RoundToSzDecimals 测试数量精度处理
func TestHyperliquidTrader_RoundToSzDecimals(t *testing.T) {
trader := &HyperliquidTrader{
meta: &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "BTC", SzDecimals: 4},
{Name: "ETH", SzDecimals: 3},
},
},
}
tests := []struct {
name string
coin string
quantity float64
expected float64
}{
{
name: "BTC_四舍五入到4位",
coin: "BTC",
quantity: 1.23456789,
expected: 1.2346,
},
{
name: "ETH_四舍五入到3位",
coin: "ETH",
quantity: 10.12345,
expected: 10.123,
},
{
name: "未知币种_使用默认精度4位",
coin: "UNKNOWN",
quantity: 1.23456789,
expected: 1.2346,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := trader.roundToSzDecimals(tt.coin, tt.quantity)
assert.InDelta(t, tt.expected, result, 0.0001)
})
}
}
// TestHyperliquidTrader_RoundPriceToSigfigs 测试价格有效数字处理
func TestHyperliquidTrader_RoundPriceToSigfigs(t *testing.T) {
trader := &HyperliquidTrader{}
tests := []struct {
name string
price float64
expected float64
}{
{
name: "BTC价格_5位有效数字",
price: 50123.456789,
expected: 50123.0,
},
{
name: "小数价格_5位有效数字",
price: 0.0012345678,
expected: 0.0012346,
},
{
name: "零价格",
price: 0,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := trader.roundPriceToSigfigs(tt.price)
assert.InDelta(t, tt.expected, result, tt.expected*0.001)
})
}
}
// TestHyperliquidTrader_GetSzDecimals 测试获取精度
func TestHyperliquidTrader_GetSzDecimals(t *testing.T) {
tests := []struct {
name string
meta *hyperliquid.Meta
coin string
expected int
}{
{
name: "meta为nil_返回默认精度",
meta: nil,
coin: "BTC",
expected: 4,
},
{
name: "找到BTC_返回正确精度",
meta: &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "BTC", SzDecimals: 5},
},
},
coin: "BTC",
expected: 5,
},
{
name: "未找到币种_返回默认精度",
meta: &hyperliquid.Meta{
Universe: []hyperliquid.AssetInfo{
{Name: "ETH", SzDecimals: 3},
},
},
coin: "BTC",
expected: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
trader := &HyperliquidTrader{meta: tt.meta}
result := trader.getSzDecimals(tt.coin)
assert.Equal(t, tt.expected, result)
})
}
}
// TestHyperliquidTrader_SetMarginMode 测试设置保证金模式
func TestHyperliquidTrader_SetMarginMode(t *testing.T) {
trader := &HyperliquidTrader{
ctx: context.Background(),
isCrossMargin: true,
}
tests := []struct {
name string
symbol string
isCrossMargin bool
wantError bool
}{
{
name: "设置为全仓模式",
symbol: "BTCUSDT",
isCrossMargin: true,
wantError: false,
},
{
name: "设置为逐仓模式",
symbol: "ETHUSDT",
isCrossMargin: false,
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := trader.SetMarginMode(tt.symbol, tt.isCrossMargin)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.isCrossMargin, trader.isCrossMargin)
}
})
}
}
// TestNewHyperliquidTrader_PrivateKeyProcessing 测试私钥处理
func TestNewHyperliquidTrader_PrivateKeyProcessing(t *testing.T) {
tests := []struct {
name string
privateKeyHex string
shouldStripOx bool
expectedLength int
}{
{
name: "带0x前缀的私钥",
privateKeyHex: "0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
shouldStripOx: true,
expectedLength: 64,
},
{
name: "无前缀的私钥",
privateKeyHex: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
shouldStripOx: false,
expectedLength: 64,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 测试私钥前缀处理逻辑(不实际创建 trader
processed := tt.privateKeyHex
if len(processed) > 2 && (processed[:2] == "0x" || processed[:2] == "0X") {
processed = processed[2:]
}
assert.Equal(t, tt.expectedLength, len(processed))
})
}
}

664
trader/trader_test_suite.go Normal file
View File

@@ -0,0 +1,664 @@
package trader
import (
"testing"
"github.com/agiledragon/gomonkey/v2"
"github.com/stretchr/testify/assert"
)
// TraderTestSuite 通用的 Trader 接口测试套件(基础套件)
// 用于黑盒测试任何实现了 Trader 接口的交易器
//
// 使用方式:
// 1. 创建具体的测试套件结构体,嵌入 TraderTestSuite
// 2. 实现 SetupMocks() 方法来配置 gomonkey mock
// 3. 调用 RunAllTests() 运行所有通用测试
type TraderTestSuite struct {
T *testing.T
Trader Trader
Patches *gomonkey.Patches
}
// NewTraderTestSuite 创建新的基础测试套件
func NewTraderTestSuite(t *testing.T, trader Trader) *TraderTestSuite {
return &TraderTestSuite{
T: t,
Trader: trader,
Patches: gomonkey.NewPatches(),
}
}
// Cleanup 清理 mock patches
func (s *TraderTestSuite) Cleanup() {
if s.Patches != nil {
s.Patches.Reset()
}
}
// RunAllTests 运行所有通用接口测试
// 注意:调用此方法前,请先通过 SetupMocks 设置好所需的 mock
func (s *TraderTestSuite) RunAllTests() {
// 基础查询方法
s.T.Run("GetBalance", func(t *testing.T) { s.TestGetBalance() })
s.T.Run("GetPositions", func(t *testing.T) { s.TestGetPositions() })
s.T.Run("GetMarketPrice", func(t *testing.T) { s.TestGetMarketPrice() })
// 配置方法
s.T.Run("SetLeverage", func(t *testing.T) { s.TestSetLeverage() })
s.T.Run("SetMarginMode", func(t *testing.T) { s.TestSetMarginMode() })
s.T.Run("FormatQuantity", func(t *testing.T) { s.TestFormatQuantity() })
// 核心交易方法
s.T.Run("OpenLong", func(t *testing.T) { s.TestOpenLong() })
s.T.Run("OpenShort", func(t *testing.T) { s.TestOpenShort() })
s.T.Run("CloseLong", func(t *testing.T) { s.TestCloseLong() })
s.T.Run("CloseShort", func(t *testing.T) { s.TestCloseShort() })
// 止损止盈
s.T.Run("SetStopLoss", func(t *testing.T) { s.TestSetStopLoss() })
s.T.Run("SetTakeProfit", func(t *testing.T) { s.TestSetTakeProfit() })
// 订单管理
s.T.Run("CancelAllOrders", func(t *testing.T) { s.TestCancelAllOrders() })
s.T.Run("CancelStopOrders", func(t *testing.T) { s.TestCancelStopOrders() })
s.T.Run("CancelStopLossOrders", func(t *testing.T) { s.TestCancelStopLossOrders() })
s.T.Run("CancelTakeProfitOrders", func(t *testing.T) { s.TestCancelTakeProfitOrders() })
}
// TestGetBalance 测试获取账户余额
func (s *TraderTestSuite) TestGetBalance() {
tests := []struct {
name string
wantError bool
validate func(*testing.T, map[string]interface{})
}{
{
name: "成功获取余额",
wantError: false,
validate: func(t *testing.T, result map[string]interface{}) {
assert.NotNil(t, result)
assert.Contains(t, result, "totalWalletBalance")
assert.Contains(t, result, "availableBalance")
},
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
result, err := s.Trader.GetBalance()
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.validate != nil {
tt.validate(t, result)
}
}
})
}
}
// TestGetPositions 测试获取持仓
func (s *TraderTestSuite) TestGetPositions() {
tests := []struct {
name string
wantError bool
validate func(*testing.T, []map[string]interface{})
}{
{
name: "成功获取持仓列表",
wantError: false,
validate: func(t *testing.T, positions []map[string]interface{}) {
assert.NotNil(t, positions)
// 持仓可以为空数组
for _, pos := range positions {
assert.Contains(t, pos, "symbol")
assert.Contains(t, pos, "side")
assert.Contains(t, pos, "positionAmt")
}
},
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
result, err := s.Trader.GetPositions()
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.validate != nil {
tt.validate(t, result)
}
}
})
}
}
// TestGetMarketPrice 测试获取市场价格
func (s *TraderTestSuite) TestGetMarketPrice() {
tests := []struct {
name string
symbol string
wantError bool
validate func(*testing.T, float64)
}{
{
name: "成功获取BTC价格",
symbol: "BTCUSDT",
wantError: false,
validate: func(t *testing.T, price float64) {
assert.Greater(t, price, 0.0)
},
},
{
name: "无效交易对返回错误",
symbol: "INVALIDUSDT",
wantError: true,
validate: nil,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
price, err := s.Trader.GetMarketPrice(tt.symbol)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.validate != nil {
tt.validate(t, price)
}
}
})
}
}
// TestSetLeverage 测试设置杠杆
func (s *TraderTestSuite) TestSetLeverage() {
tests := []struct {
name string
symbol string
leverage int
wantError bool
}{
{
name: "设置10倍杠杆",
symbol: "BTCUSDT",
leverage: 10,
wantError: false,
},
{
name: "设置1倍杠杆",
symbol: "ETHUSDT",
leverage: 1,
wantError: false,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
err := s.Trader.SetLeverage(tt.symbol, tt.leverage)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestSetMarginMode 测试设置仓位模式
func (s *TraderTestSuite) TestSetMarginMode() {
tests := []struct {
name string
symbol string
isCrossMargin bool
wantError bool
}{
{
name: "设置全仓模式",
symbol: "BTCUSDT",
isCrossMargin: true,
wantError: false,
},
{
name: "设置逐仓模式",
symbol: "ETHUSDT",
isCrossMargin: false,
wantError: false,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
err := s.Trader.SetMarginMode(tt.symbol, tt.isCrossMargin)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestFormatQuantity 测试数量格式化
func (s *TraderTestSuite) TestFormatQuantity() {
tests := []struct {
name string
symbol string
quantity float64
wantError bool
validate func(*testing.T, string)
}{
{
name: "格式化BTC数量",
symbol: "BTCUSDT",
quantity: 1.23456789,
wantError: false,
validate: func(t *testing.T, result string) {
assert.NotEmpty(t, result)
},
},
{
name: "格式化小数量",
symbol: "ETHUSDT",
quantity: 0.001,
wantError: false,
validate: func(t *testing.T, result string) {
assert.NotEmpty(t, result)
},
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
result, err := s.Trader.FormatQuantity(tt.symbol, tt.quantity)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.validate != nil {
tt.validate(t, result)
}
}
})
}
}
// TestCancelAllOrders 测试取消所有订单
func (s *TraderTestSuite) TestCancelAllOrders() {
tests := []struct {
name string
symbol string
wantError bool
}{
{
name: "取消BTC所有订单",
symbol: "BTCUSDT",
wantError: false,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
err := s.Trader.CancelAllOrders(tt.symbol)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// ============================================================
// 核心交易方法测试
// ============================================================
// TestOpenLong 测试开多仓
func (s *TraderTestSuite) TestOpenLong() {
tests := []struct {
name string
symbol string
quantity float64
leverage int
wantError bool
validate func(*testing.T, map[string]interface{})
}{
{
name: "成功开多仓",
symbol: "BTCUSDT",
quantity: 0.01,
leverage: 10,
wantError: false,
validate: func(t *testing.T, result map[string]interface{}) {
assert.NotNil(t, result)
assert.Contains(t, result, "symbol")
assert.Equal(t, "BTCUSDT", result["symbol"])
},
},
{
name: "小数量开仓",
symbol: "ETHUSDT",
quantity: 0.004, // 增加到 0.004 以满足 Binance Futures 的 10 USDT 最小订单金额要求 (0.004 * 3000 = 12 USDT)
leverage: 5,
wantError: false,
validate: func(t *testing.T, result map[string]interface{}) {
assert.NotNil(t, result)
},
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
result, err := s.Trader.OpenLong(tt.symbol, tt.quantity, tt.leverage)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.validate != nil {
tt.validate(t, result)
}
}
})
}
}
// TestOpenShort 测试开空仓
func (s *TraderTestSuite) TestOpenShort() {
tests := []struct {
name string
symbol string
quantity float64
leverage int
wantError bool
validate func(*testing.T, map[string]interface{})
}{
{
name: "成功开空仓",
symbol: "BTCUSDT",
quantity: 0.01,
leverage: 10,
wantError: false,
validate: func(t *testing.T, result map[string]interface{}) {
assert.NotNil(t, result)
assert.Contains(t, result, "symbol")
assert.Equal(t, "BTCUSDT", result["symbol"])
},
},
{
name: "小数量开空仓",
symbol: "ETHUSDT",
quantity: 0.004, // 增加到 0.004 以满足 Binance Futures 的 10 USDT 最小订单金额要求 (0.004 * 3000 = 12 USDT)
leverage: 5,
wantError: false,
validate: func(t *testing.T, result map[string]interface{}) {
assert.NotNil(t, result)
},
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
result, err := s.Trader.OpenShort(tt.symbol, tt.quantity, tt.leverage)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.validate != nil {
tt.validate(t, result)
}
}
})
}
}
// TestCloseLong 测试平多仓
func (s *TraderTestSuite) TestCloseLong() {
tests := []struct {
name string
symbol string
quantity float64
wantError bool
validate func(*testing.T, map[string]interface{})
}{
{
name: "平指定数量",
symbol: "BTCUSDT",
quantity: 0.01,
wantError: false,
validate: func(t *testing.T, result map[string]interface{}) {
assert.NotNil(t, result)
assert.Contains(t, result, "symbol")
},
},
{
name: "全部平仓_quantity为0_无持仓返回错误",
symbol: "ETHUSDT",
quantity: 0,
wantError: true, // 当没有持仓时quantity=0 应该返回错误
validate: nil,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
result, err := s.Trader.CloseLong(tt.symbol, tt.quantity)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.validate != nil {
tt.validate(t, result)
}
}
})
}
}
// TestCloseShort 测试平空仓
func (s *TraderTestSuite) TestCloseShort() {
tests := []struct {
name string
symbol string
quantity float64
wantError bool
validate func(*testing.T, map[string]interface{})
}{
{
name: "平指定数量",
symbol: "BTCUSDT",
quantity: 0.01,
wantError: false,
validate: func(t *testing.T, result map[string]interface{}) {
assert.NotNil(t, result)
assert.Contains(t, result, "symbol")
},
},
{
name: "全部平仓_quantity为0_无持仓返回错误",
symbol: "ETHUSDT",
quantity: 0,
wantError: true, // 当没有持仓时quantity=0 应该返回错误
validate: nil,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
result, err := s.Trader.CloseShort(tt.symbol, tt.quantity)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.validate != nil {
tt.validate(t, result)
}
}
})
}
}
// ============================================================
// 止损止盈测试
// ============================================================
// TestSetStopLoss 测试设置止损
func (s *TraderTestSuite) TestSetStopLoss() {
tests := []struct {
name string
symbol string
positionSide string
quantity float64
stopPrice float64
wantError bool
}{
{
name: "多头止损",
symbol: "BTCUSDT",
positionSide: "LONG",
quantity: 0.01,
stopPrice: 45000.0,
wantError: false,
},
{
name: "空头止损",
symbol: "ETHUSDT",
positionSide: "SHORT",
quantity: 0.1,
stopPrice: 3200.0,
wantError: false,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
err := s.Trader.SetStopLoss(tt.symbol, tt.positionSide, tt.quantity, tt.stopPrice)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestSetTakeProfit 测试设置止盈
func (s *TraderTestSuite) TestSetTakeProfit() {
tests := []struct {
name string
symbol string
positionSide string
quantity float64
takeProfitPrice float64
wantError bool
}{
{
name: "多头止盈",
symbol: "BTCUSDT",
positionSide: "LONG",
quantity: 0.01,
takeProfitPrice: 55000.0,
wantError: false,
},
{
name: "空头止盈",
symbol: "ETHUSDT",
positionSide: "SHORT",
quantity: 0.1,
takeProfitPrice: 2800.0,
wantError: false,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
err := s.Trader.SetTakeProfit(tt.symbol, tt.positionSide, tt.quantity, tt.takeProfitPrice)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestCancelStopOrders 测试取消止盈止损单
func (s *TraderTestSuite) TestCancelStopOrders() {
tests := []struct {
name string
symbol string
wantError bool
}{
{
name: "取消BTC止盈止损单",
symbol: "BTCUSDT",
wantError: false,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
err := s.Trader.CancelStopOrders(tt.symbol)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestCancelStopLossOrders 测试取消止损单
func (s *TraderTestSuite) TestCancelStopLossOrders() {
tests := []struct {
name string
symbol string
wantError bool
}{
{
name: "取消BTC止损单",
symbol: "BTCUSDT",
wantError: false,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
err := s.Trader.CancelStopLossOrders(tt.symbol)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// TestCancelTakeProfitOrders 测试取消止盈单
func (s *TraderTestSuite) TestCancelTakeProfitOrders() {
tests := []struct {
name string
symbol string
wantError bool
}{
{
name: "取消BTC止盈单",
symbol: "BTCUSDT",
wantError: false,
},
}
for _, tt := range tests {
s.T.Run(tt.name, func(t *testing.T) {
err := s.Trader.CancelTakeProfitOrders(tt.symbol)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}