from __future__ import annotations import argparse import json from dataclasses import fields from pathlib import Path import pandas as pd from qfr.strategy.etf_trend import Constraints, TrendParams, UniverseAsset, run_backtest def load_prices(raw_dir: Path, universe: list[UniverseAsset]) -> dict[str, pd.DataFrame]: out: dict[str, pd.DataFrame] = {} for a in universe: fn = raw_dir / f"{a.ts_code.replace('.', '')}.parquet" if not fn.exists(): raise FileNotFoundError(f"missing data file: {fn}") df = pd.read_parquet(fn) out[a.ts_code] = df return out def perf_stats(equity: pd.Series) -> dict[str, float]: r = equity.pct_change().dropna() if r.empty: return {} ann_ret = float((equity.iloc[-1] / equity.iloc[0]) ** (252 / len(r)) - 1) ann_vol = float(r.std(ddof=1) * (252**0.5)) dd = (equity / equity.cummax() - 1.0).min() return {"ann_return": ann_ret, "ann_vol": ann_vol, "max_drawdown": float(dd)} def add_trendparams_args(p: argparse.ArgumentParser) -> None: # Expose a subset of TrendParams for fast experiments / grid search verification. # Keep names stable and CLI-friendly (kebab-case). tp_fields = {f.name: f for f in fields(TrendParams)} def add(name: str, arg: str, typ, help_: str) -> None: if name not in tp_fields: return p.add_argument(arg, type=typ, default=None, help=help_) add("sma_fast", "--sma-fast", int, "SMA fast window") add("sma_slow", "--sma-slow", int, "SMA slow window") add("lazy_days", "--lazy-days", int, "Min days between switches") add("min_hold_days", "--min-hold-days", int, "Min hold days before trend-exit/switch") add("replace_score_gap", "--replace-score-gap", float, "Replace weakest only if score gap >= this") add("min_score", "--min-score", float, "Entry score threshold (allow empty if not met)") add("macro_down_frac", "--macro-down-frac", float, "Down-day breadth threshold for consistent down") add("desired_positions_min", "--desired-positions-min", int, "Desired min positions (allow empty)") add("desired_positions_max", "--desired-positions-max", int, "Desired max positions") add("rebalance_band", "--rebalance-band", float, "Ignore small weight changes") add("atr_mult", "--atr-mult", float, "Chandelier ATR multiple") add("profit_tighten_atr", "--profit-tighten-atr", float, "Tighten trailing after profit >= N*ATR") add("atr_mult_profit", "--atr-mult-profit", float, "Chandelier ATR multiple after tighten") add("stop_loss_atr", "--stop-loss-atr", float, "Hard stop loss from entry in ATR") add("bias_exit", "--bias-exit", float, "Exit when abs(bias) >= threshold") add("vol_ratio_exit", "--vol-ratio-exit", float, "Exit when volume/amount ratio >= threshold") add("max_weight_per_asset", "--max-weight-per-asset", float, "Max weight per risky asset") add("concentration_power", "--concentration-power", float, "Weight concentration power") add("macro_min_breadth", "--macro-min-breadth", float, "Min equity breadth to be risk-on") add("macro_scale_risk_off", "--macro-scale-risk-off", float, "Scale risky weights in risk-off") def main() -> None: p = argparse.ArgumentParser() p.add_argument("--config", default="configs/etf_universe.json") p.add_argument("--rawdir", default="data/raw") p.add_argument("--out", default="data/etf_trend_equity.parquet") p.add_argument("--start", default="20200101", help="Filter start trade_date YYYYMMDD (inclusive)") p.add_argument("--end", default="20251231", help="Filter end trade_date YYYYMMDD (inclusive)") add_trendparams_args(p) args = p.parse_args() conf = json.loads(Path(args.config).read_text(encoding="utf-8")) universe = [UniverseAsset(**a) for a in conf["assets"]] cons = conf.get("constraints", {}) constraints = Constraints( max_positions=int(cons.get("max_positions", 4)), must_commodity=int(cons.get("must_include", {}).get("commodity", 1)), must_rates=int(cons.get("must_include", {}).get("rates", 1)), must_equity=int(cons.get("must_include", {}).get("equity", 1)), ) params = TrendParams(max_positions=constraints.max_positions) # apply CLI overrides overrides = { "sma_fast": args.sma_fast, "sma_slow": args.sma_slow, "lazy_days": args.lazy_days, "min_hold_days": getattr(args, "min_hold_days", None), "replace_score_gap": getattr(args, "replace_score_gap", None), "min_score": getattr(args, "min_score", None), "macro_down_frac": getattr(args, "macro_down_frac", None), "desired_positions_min": getattr(args, "desired_positions_min", None), "desired_positions_max": getattr(args, "desired_positions_max", None), "rebalance_band": args.rebalance_band, "atr_mult": args.atr_mult, "profit_tighten_atr": args.profit_tighten_atr, "atr_mult_profit": args.atr_mult_profit, "stop_loss_atr": args.stop_loss_atr, "bias_exit": args.bias_exit, "vol_ratio_exit": args.vol_ratio_exit, "max_weight_per_asset": args.max_weight_per_asset, "concentration_power": args.concentration_power, "macro_min_breadth": args.macro_min_breadth, "macro_scale_risk_off": args.macro_scale_risk_off, } overrides = {k: v for k, v in overrides.items() if v is not None} if overrides: params = TrendParams(**{**params.__dict__, **overrides}) risk_proxy = cons.get("risk_proxy", "510300.SH") rates_fallback = cons.get("rates_fallback") if rates_fallback is None: for a in universe: if a.asset_class.startswith("rates"): rates_fallback = a.ts_code break if not rates_fallback: raise RuntimeError("universe must include a rates asset for fallback") prices = load_prices(Path(args.rawdir), universe) for k, df in prices.items(): d = df.copy() d["trade_date"] = d["trade_date"].astype(str) d = d[(d["trade_date"] >= str(args.start)) & (d["trade_date"] <= str(args.end))] prices[k] = d equity, weights, trades = run_backtest(prices, universe, constraints, params, rates_fallback=rates_fallback, risk_proxy=risk_proxy) out = Path(args.out) out.parent.mkdir(parents=True, exist_ok=True) equity.to_parquet(out) weights_path = out.with_name(out.stem + "_weights" + out.suffix) trades_path = out.with_name(out.stem + "_trades" + out.suffix) weights.to_parquet(weights_path) if trades is not None and not trades.empty: trades.to_parquet(trades_path, index=False) print(f"wrote trades -> {trades_path}") st = perf_stats(equity["equity"]) print("perf", st) print("last equity", float(equity["equity"].iloc[-1])) print("last weights", weights.iloc[-1].sort_values(ascending=False).head(10).to_dict()) if __name__ == "__main__": main()