from __future__ import annotations import argparse import itertools import json import random from dataclasses import asdict, replace from pathlib import Path import numpy as np import pandas as pd from qfr.strategy.etf_trend import Constraints, TrendParams, UniverseAsset, run_backtest def load_universe(config_path: Path) -> tuple[list[UniverseAsset], Constraints, str, str]: conf = json.loads(config_path.read_text(encoding="utf-8")) universe = [UniverseAsset(**a) for a in conf["assets"]] cons = conf.get("constraints", {}) constraints = Constraints( max_positions=int(cons.get("max_positions", 4)), must_commodity=int(cons.get("must_include", {}).get("commodity", 0)), must_rates=int(cons.get("must_include", {}).get("rates", 0)), must_equity=int(cons.get("must_include", {}).get("equity", 0)), ) risk_proxy = cons.get("risk_proxy", "510300.SH") rates_fallback = cons.get("rates_fallback", "511010.SH") return universe, constraints, risk_proxy, rates_fallback def load_prices(raw_dir: Path, universe: list[UniverseAsset], start: str, end: str) -> dict[str, pd.DataFrame]: out: dict[str, pd.DataFrame] = {} for a in universe: fn = raw_dir / f"{a.ts_code.replace('.', '')}.parquet" df = pd.read_parquet(fn) df = df.copy() df["trade_date"] = df["trade_date"].astype(str) df = df[(df["trade_date"] >= start) & (df["trade_date"] <= end)] out[a.ts_code] = df return out def perf_stats(equity: pd.Series) -> dict[str, float]: r = equity.pct_change().dropna() if r.empty: return {} ann_ret = float((equity.iloc[-1] / equity.iloc[0]) ** (252 / len(r)) - 1) ann_vol = float(r.std(ddof=1) * (252 ** 0.5)) dd = float((equity / equity.cummax() - 1.0).min()) sharpe = float(ann_ret / ann_vol) if ann_vol > 0 else float("nan") return {"ann_return": ann_ret, "ann_vol": ann_vol, "max_drawdown": dd, "sharpe": sharpe} def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--config", default="configs/etf_universe.json") ap.add_argument("--rawdir", default="data/raw") ap.add_argument("--start", default="20200101") ap.add_argument("--end", default="20251231") ap.add_argument("--out", default="data/grid_search_results.parquet") ap.add_argument("--seed", type=int, default=1) ap.add_argument("--max_combos", type=int, default=400, help="Randomly sample at most this many combos") args = ap.parse_args() universe, constraints, risk_proxy, rates_fallback = load_universe(Path(args.config)) prices = load_prices(Path(args.rawdir), universe, args.start, args.end) base = TrendParams(target_ann_vol=0.25) # Keep grid small. We will sample max_combos from the full cartesian product. grid = { "sma_fast": [3, 5, 8], "sma_slow": [15, 20, 30, 40], "lazy_days": [2, 5], "rebalance_band": [0.03, 0.06], "atr_mult": [2.5, 3.2, 4.0], "profit_tighten_atr": [3.0, 4.0], "atr_mult_profit": [1.5, 2.0], "stop_loss_atr": [2.5, 3.2], "bias_exit": [0.12, 0.18], "vol_ratio_exit": [2.0, 3.0], "max_weight_per_asset": [0.7, 0.9], "concentration_power": [1.6, 2.2], } keys = list(grid.keys()) combos = list(itertools.product(*(grid[k] for k in keys))) random.seed(int(args.seed)) if int(args.max_combos) > 0 and len(combos) > int(args.max_combos): combos = random.sample(combos, int(args.max_combos)) rows = [] for vals in combos: kw = dict(zip(keys, vals)) if int(kw["sma_fast"]) >= int(kw["sma_slow"]): continue params = replace(base, **kw, rebalance_every=1, max_positions=constraints.max_positions) try: equity, _w, _tr = run_backtest( prices, universe, constraints, params, rates_fallback=rates_fallback, risk_proxy=risk_proxy, ) except Exception: continue st = perf_stats(equity["equity"]) if not st: continue row = {**st, **asdict(params)} rows.append(row) df = pd.DataFrame(rows) if df.empty: print("no results") return df = df[df["ann_vol"] <= 0.25].copy() df = df.sort_values(["ann_return", "sharpe"], ascending=False) out = Path(args.out) out.parent.mkdir(parents=True, exist_ok=True) df.to_parquet(out, index=False) cols = [ "ann_return", "ann_vol", "max_drawdown", "sharpe", "sma_fast", "sma_slow", "lazy_days", "rebalance_band", "atr_mult", "profit_tighten_atr", "atr_mult_profit", "stop_loss_atr", "bias_exit", "vol_ratio_exit", "max_weight_per_asset", "concentration_power", ] print("top10") print(df[cols].head(10).to_string(index=False)) if __name__ == "__main__": main()