Files
quant-factor-research/scripts/auto_tune_etf_trend_fast.py

230 lines
8.5 KiB
Python
Raw Permalink Normal View History

2026-03-13 17:10:49 +08:00
from __future__ import annotations
import argparse
import itertools
import json
from dataclasses import replace
from pathlib import Path
import numpy as np
import pandas as pd
from qfr.strategy.etf_trend import Constraints, TrendParams, UniverseAsset, compute_features, portfolio_vol, risk_parity_weights, select_portfolio
def load_universe(config_path: Path) -> tuple[list[UniverseAsset], Constraints, str, str]:
    """Read the ETF universe config and return ``(assets, constraints, risk_proxy, rates_fallback)``.

    The JSON file must contain an ``assets`` list; each entry is passed as keyword
    arguments to ``UniverseAsset``. The optional ``constraints`` object may set:

    * ``max_positions`` (default 4),
    * ``must_include`` with ``commodity`` / ``rates`` / ``equity`` counts (default 0 each),
    * ``risk_proxy`` (default ``"510300.SH"``),
    * ``rates_fallback`` (default ``"511010.SH"``).
    """
    config = json.loads(config_path.read_text(encoding="utf-8"))
    assets = [UniverseAsset(**spec) for spec in config["assets"]]

    section = config.get("constraints", {})
    must = section.get("must_include", {})
    limits = Constraints(
        max_positions=int(section.get("max_positions", 4)),
        must_commodity=int(must.get("commodity", 0)),
        must_rates=int(must.get("rates", 0)),
        must_equity=int(must.get("equity", 0)),
    )
    return (
        assets,
        limits,
        section.get("risk_proxy", "510300.SH"),
        section.get("rates_fallback", "511010.SH"),
    )
def load_prices(raw_dir: Path, universe: list[UniverseAsset], start: str, end: str) -> dict[str, pd.DataFrame]:
    """Load each asset's raw parquet file and keep rows whose trade_date is in [start, end].

    Files are looked up as ``<raw_dir>/<ts_code without dots>.parquet``. The
    ``trade_date`` column is converted to ``str`` and the window is applied with
    inclusive string comparisons, so ``start``/``end`` should use the same text
    format as the stored dates (e.g. ``"YYYYMMDD"``).

    Returns:
        Mapping ``ts_code -> filtered DataFrame`` in universe order.
    """
    frames: dict[str, pd.DataFrame] = {}
    for asset in universe:
        path = raw_dir / (asset.ts_code.replace(".", "") + ".parquet")
        table = pd.read_parquet(path).copy()
        table["trade_date"] = table["trade_date"].astype(str)
        in_window = table["trade_date"].between(start, end)
        frames[asset.ts_code] = table[in_window]
    return frames
def perf_stats(equity: pd.Series, periods_per_year: float = 252) -> dict[str, float]:
    """Summarise an equity curve: annualised return/volatility, max drawdown and Calmar.

    Args:
        equity: Equity (NAV) curve in time order, one observation per period.
        periods_per_year: Observations per year used for annualisation
            (252 trading days by default).

    Returns:
        Dict with keys ``ann_return``, ``ann_vol``, ``max_drawdown`` and ``calmar``,
        or an empty dict when there are no period returns (fewer than two
        observations). ``ann_return`` is NaN when the overall growth ratio
        ``equity[-1] / equity[0]`` is not a positive finite number, since a
        geometric annualisation is undefined there. ``calmar`` is NaN when the
        curve never draws down.
    """
    r = equity.pct_change().dropna()
    if r.empty:
        return {}

    growth = float(equity.iloc[-1] / equity.iloc[0])
    # A non-positive growth ratio has no real geometric annualisation; raising it
    # to a fractional/even power would yield NaN or a sign-flipped, misleading value.
    if np.isfinite(growth) and growth > 0:
        # np.power (not **) so an extreme ratio overflows to inf instead of raising.
        ann_ret = float(np.power(growth, periods_per_year / len(r)) - 1.0)
    else:
        ann_ret = float("nan")

    ann_vol = float(r.std(ddof=1) * (periods_per_year ** 0.5))
    dd = float((equity / equity.cummax() - 1.0).min())
    calmar = float(ann_ret / abs(dd)) if dd < 0 else float("nan")
    return {"ann_return": ann_ret, "ann_vol": ann_vol, "max_drawdown": dd, "calmar": calmar}
def _align_on_dates(frame: pd.DataFrame, all_dates: list[str]) -> pd.DataFrame:
    """Return *frame* indexed by ``trade_date`` (as str) and reindexed onto *all_dates*."""
    # Key by the same str form used to build the common calendar, so an int/str
    # mismatch in trade_date cannot silently produce an all-NaN aligned frame.
    keyed = frame.assign(trade_date=frame["trade_date"].astype(str))
    return keyed.set_index("trade_date").reindex(all_dates)


def _rebalance_snapshot(aligned: dict[str, pd.DataFrame], codes: list[str], d: str) -> pd.DataFrame:
    """Build the per-asset ``trend_ok`` / ``score_raw`` / ``vol`` cross-section on date *d*."""
    rows = []
    for ts in codes:
        f = aligned[ts].loc[d]
        rows.append((
            ts,
            bool(f["trend_ok"]) if pd.notna(f["trend_ok"]) else False,
            float(f["score_raw"]) if pd.notna(f["score_raw"]) else float("nan"),
            float(f["vol"]) if pd.notna(f["vol"]) else float("nan"),
        ))
    return pd.DataFrame(rows, columns=["ts_code", "trend_ok", "score_raw", "vol"]).set_index("ts_code")


def run_backtest_cached(
    feats: dict[str, pd.DataFrame],
    universe: list[UniverseAsset],
    constraints: Constraints,
    params: TrendParams,
    rates_fallback: str,
    risk_proxy: str,
) -> pd.DataFrame:
    """Simulate the ETF trend strategy on precomputed per-asset features.

    The loop makes one pass per common trade date ``d``:

    * Weights are carried forward from the previous date.
    * A held asset is exited (weight set to 0) when ``ma_fast < ma_slow``, or when
      its close falls below ``highest_close - params.atr_mult * atr``. This is a
      chandelier stop measured from the highest close since the position was
      (re)entered.
    * Every ``params.rebalance_every`` dates the book is rebuilt in three steps:
      - assets are picked with ``select_portfolio``;
      - they are weighted with ``risk_parity_weights`` (``max_w=0.50``);
      - the weights are multiplied by ``min(1, params.target_ann_vol / pvol)``,
        where ``pvol = portfolio_vol(...)`` is computed on the last
        ``params.port_vol_window`` rows of daily returns.

      Weight left unallocated goes to ``rates_fallback`` when that code is
      present in ``feats``.

    Weights decided on date ``d`` earn the next date's returns (``weights.shift(1)``).

    Args:
        feats: ``ts_code -> feature frame``. Each frame has a ``trade_date`` column
            and the columns ``close``, ``atr``, ``ma_fast``, ``ma_slow``,
            ``trend_ok``, ``score_raw`` and ``vol`` (all read here).
        universe: Forwarded to ``select_portfolio``.
        constraints: Forwarded to ``select_portfolio``.
        params: Reads ``atr_mult``, ``rebalance_every``, ``port_vol_window`` and
            ``target_ann_vol``.
        rates_fallback: ``ts_code`` that absorbs weight not allocated at a rebalance.
        risk_proxy: Must be one of the ``feats`` keys. It is only validated here.

    Returns:
        DataFrame with a single ``equity`` column, the cumulative product of
        ``1 + portfolio return``, indexed by the sorted common trade dates (str).

    Raises:
        RuntimeError: If the frames share no trade date, or if ``risk_proxy`` is
            not among them.
    """
    # Common calendar: only dates on which every asset has a row.
    dates = None
    for f in feats.values():
        d = set(f["trade_date"].astype(str))
        dates = d if dates is None else dates.intersection(d)
    if not dates:
        raise RuntimeError("No overlapping trade_date")
    all_dates = sorted(dates)

    # Align every feature frame onto the calendar ONCE; the loop below only does
    # label lookups instead of re-indexing each frame on every rebalance date.
    aligned = {ts: _align_on_dates(f, all_dates) for ts, f in feats.items()}

    close_px = pd.DataFrame(index=all_dates)
    ret1 = pd.DataFrame(index=all_dates)
    for ts, g in aligned.items():
        close_px[ts] = g["close"].astype(float)
        ret1[ts] = close_px[ts].pct_change().fillna(0.0)
    if risk_proxy not in close_px.columns:
        raise RuntimeError("risk_proxy missing")

    codes = list(close_px.columns)
    weights = pd.DataFrame(0.0, index=all_dates, columns=close_px.columns)
    in_pos: set[str] = set()
    highest_close: dict[str, float] = {}
    atr_map = {ts: aligned[ts]["atr"].astype(float) for ts in codes}
    mf_map = {ts: aligned[ts]["ma_fast"].astype(float) for ts in codes}
    ms_map = {ts: aligned[ts]["ma_slow"].astype(float) for ts in codes}
    last_reb = -10**9  # forces a rebalance on the first date

    for i, d in enumerate(all_dates):
        if i > 0:
            # Carry yesterday's book forward; exits/rebalance below override it.
            weights.loc[d] = weights.iloc[i - 1]

        # Update the trailing high of every open position.
        for ts in in_pos:
            c = float(close_px.loc[d, ts])
            if np.isfinite(c):
                highest_close[ts] = max(highest_close.get(ts, c), c)

        # Exits: trend break (fast MA below slow MA) or chandelier stop.
        for ts in list(in_pos):
            c = float(close_px.loc[d, ts])
            mf = float(mf_map[ts].loc[d])
            ms = float(ms_map[ts].loc[d])
            atr = float(atr_map[ts].loc[d])
            h = highest_close.get(ts, c)
            trend_break = np.isfinite(mf) and np.isfinite(ms) and (mf < ms)
            chand_break = np.isfinite(atr) and c < (h - params.atr_mult * atr)
            if trend_break or chand_break:
                weights.loc[d, ts] = 0.0
                in_pos.remove(ts)
                highest_close.pop(ts, None)

        if (i - last_reb) >= params.rebalance_every:
            snap = _rebalance_snapshot(aligned, codes, d)
            picks = select_portfolio(snap, universe, constraints)
            vol = snap.loc[picks, "vol"].copy()
            w = risk_parity_weights(vol, max_w=0.50)
            trailing = ret1[picks].iloc[max(0, i - params.port_vol_window + 1) : i + 1]
            pvol = portfolio_vol(trailing, w)
            scale = 1.0
            if np.isfinite(pvol) and pvol > 0:
                scale = min(1.0, params.target_ann_vol / pvol)
            w_exec = w * scale

            weights.loc[d] = 0.0
            for ts, wi in w_exec.items():
                weights.loc[d, ts] = float(wi)
            rem = 1.0 - float(w_exec.sum())
            if rem > 1e-12 and rates_fallback in weights.columns:
                weights.loc[d, rates_fallback] += rem

            new_pos = {ts for ts in codes if weights.loc[d, ts] > 1e-12}
            # Forget trailing highs of positions this rebalance dropped, so a later
            # re-entry starts a fresh chandelier stop instead of inheriting a stale,
            # higher peak that could stop it out immediately.
            highest_close = {ts: h for ts, h in highest_close.items() if ts in new_pos}
            in_pos = new_pos
            for ts in in_pos:
                c = float(close_px.loc[d, ts])
                # Skip NaN closes: max() with a NaN would freeze the high at NaN and
                # permanently disable the chandelier stop for this asset.
                if np.isfinite(c):
                    highest_close[ts] = max(highest_close.get(ts, c), c)
            last_reb = i

    # Decisions made at the close of d apply to the next date's return.
    w_lag = weights.shift(1).fillna(0.0)
    port_ret = (ret1 * w_lag).sum(axis=1)
    equity = (1.0 + port_ret).cumprod().to_frame("equity")
    return equity
def main() -> None:
    """Grid-search trend parameters and save the combinations under a volatility cap.

    Every grid combination with ``sma_fast < sma_slow`` is backtested with
    ``run_backtest_cached`` and summarised by ``perf_stats``. Rows whose
    ``ann_vol`` is at most ``--max-ann-vol`` (default 0.18) are handled as follows:

    * they are sorted by ``ann_return`` then ``calmar`` (both descending);
    * they are written to ``--out`` as parquet;
    * the top 10 are printed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default="configs/etf_universe.json")
    ap.add_argument("--rawdir", default="data/raw")
    ap.add_argument("--start", default="20200101")
    ap.add_argument("--end", default="20251231")
    ap.add_argument("--out", default="data/tune_results_fast.parquet")
    ap.add_argument(
        "--max-ann-vol",
        type=float,
        default=0.18,
        help="keep only combinations with annualised volatility <= this value (default: 0.18)",
    )
    args = ap.parse_args()

    universe, constraints, risk_proxy, rates_fallback = load_universe(Path(args.config))
    prices = load_prices(Path(args.rawdir), universe, args.start, args.end)
    base = TrendParams(rebalance_every=1)

    # grid (keep small)
    fast_list = [3, 5, 8]
    slow_list = [15, 20, 30]
    atr_mult_list = [2.0, 2.5, 3.0]
    vol_window_list = [10, 20]
    port_vol_window_list = [40, 60]
    max_positions_list = [3, 4]

    rows = []
    for sma_fast, sma_slow in itertools.product(fast_list, slow_list):
        if sma_fast >= sma_slow:
            continue  # the fast MA must be strictly shorter than the slow MA
        for atr_mult, vol_window, port_vol_window, max_positions in itertools.product(
            atr_mult_list, vol_window_list, port_vol_window_list, max_positions_list
        ):
            params = replace(
                base,
                max_positions=max_positions,
                sma_fast=sma_fast,
                sma_slow=sma_slow,
                atr_mult=atr_mult,
                vol_window=vol_window,
                port_vol_window=port_vol_window,
            )
            cons = replace(constraints, max_positions=max_positions)
            # NOTE(review): features are recomputed for every combination. If
            # compute_features depends only on a subset of params (e.g. MA/vol
            # windows), caching by that subset would cut runtime — confirm first.
            feats = {ts: compute_features(px, params) for ts, px in prices.items()}
            equity = run_backtest_cached(feats, universe, cons, params, rates_fallback, risk_proxy)
            st = perf_stats(equity["equity"])
            if not st:
                continue
            rows.append({
                "sma_fast": sma_fast,
                "sma_slow": sma_slow,
                "atr_mult": atr_mult,
                "vol_window": vol_window,
                "port_vol_window": port_vol_window,
                "max_positions": max_positions,
                **st,
            })

    results = pd.DataFrame(rows)
    if results.empty:
        print("no results")
        return

    filt = results[results["ann_vol"] <= args.max_ann_vol].sort_values(
        ["ann_return", "calmar"], ascending=False
    )
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    filt.to_parquet(out, index=False)
    if filt.empty:
        print(f"no combination with ann_vol <= {args.max_ann_vol}")
        return

    cols = ["ann_return", "ann_vol", "max_drawdown", "calmar", "sma_fast", "sma_slow", "atr_mult", "vol_window", "port_vol_window", "max_positions"]
    print("top10")
    print(filt[cols].head(10).to_string(index=False))


if __name__ == "__main__":
    main()