from __future__ import annotations import argparse import json import os import random import sqlite3 from dataclasses import asdict, replace from datetime import datetime, timezone from pathlib import Path from typing import Any import numpy as np import pandas as pd from qfr.strategy.etf_trend import Constraints, TrendParams, UniverseAsset, run_backtest # Globals for multiprocessing (fork mode shares memory COW) _G_PRICES: dict[str, pd.DataFrame] | None = None _G_UNIVERSE: list[UniverseAsset] | None = None _G_CONSTRAINTS: Constraints | None = None _G_RISK_PROXY: str | None = None _G_RATES_FALLBACK: str | None = None def load_universe(config_path: Path) -> tuple[list[UniverseAsset], Constraints, str, str]: conf = json.loads(config_path.read_text(encoding="utf-8")) universe = [UniverseAsset(**a) for a in conf["assets"]] cons = conf.get("constraints", {}) constraints = Constraints( max_positions=int(cons.get("max_positions", 3)), must_commodity=int(cons.get("must_include", {}).get("commodity", 0)), must_rates=int(cons.get("must_include", {}).get("rates", 0)), must_equity=int(cons.get("must_include", {}).get("equity", 0)), ) risk_proxy = cons.get("risk_proxy") or (universe[0].ts_code if universe else "510300.SH") rates_fallback = cons.get("rates_fallback", "511010.SH") return universe, constraints, str(risk_proxy), str(rates_fallback) def load_prices(raw_dir: Path, universe: list[UniverseAsset], start: str, end: str) -> dict[str, pd.DataFrame]: out: dict[str, pd.DataFrame] = {} for a in universe: fn = raw_dir / (a.ts_code.replace(".", "") + ".parquet") df = pd.read_parquet(fn) df = df.copy() df["trade_date"] = df["trade_date"].astype(str) df = df[(df["trade_date"] >= start) & (df["trade_date"] <= end)] out[a.ts_code] = df return out def perf_stats(equity: pd.Series) -> dict[str, float]: r = equity.pct_change().dropna() if r.empty: return {} ann_ret = float((equity.iloc[-1] / equity.iloc[0]) ** (252 / len(r)) - 1) ann_vol = float(r.std(ddof=1) * (252**0.5)) dd = float((equity / equity.cummax() - 1.0).min()) sharpe = float(ann_ret / ann_vol) if ann_vol > 0 else float("nan") return {"ann_return": ann_ret, "ann_vol": ann_vol, "max_drawdown": dd, "sharpe": sharpe} def trades_per_year(trades: pd.DataFrame, start: str, end: str) -> float: if trades is None or trades.empty: return 0.0 years = max(1, (int(end[:4]) - int(start[:4]) + 1)) return float(len(trades) / years) def load_state(path: Path) -> dict: if path.exists(): return json.loads(path.read_text(encoding="utf-8")) return {"best": None, "last_reported_ann_return": None, "history": []} def save_state(path: Path, state: dict) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(state, ensure_ascii=True, indent=2) + "\n", encoding="utf-8") def infer_code_version(repo_dir: Path) -> str: # Prefer git commit hash if available. head = repo_dir / ".git" / "HEAD" if head.exists(): try: txt = head.read_text(encoding="utf-8").strip() if txt.startswith("ref:"): ref = txt.split(" ", 1)[1] ref_path = repo_dir / ".git" / ref if ref_path.exists(): return ref_path.read_text(encoding="utf-8").strip() return txt except Exception: return "unknown" return "nogit" def ensure_db(db_path: Path, param_cols: list[str]) -> None: db_path.parent.mkdir(parents=True, exist_ok=True) with sqlite3.connect(str(db_path)) as con: con.execute("PRAGMA journal_mode=WAL") con.execute("PRAGMA synchronous=NORMAL") con.execute( """ CREATE TABLE IF NOT EXISTS trials ( id INTEGER PRIMARY KEY AUTOINCREMENT, run_id TEXT NOT NULL, ts_utc TEXT NOT NULL, code_version TEXT, config_path TEXT, start TEXT, end TEXT, seed INTEGER, trial INTEGER, jobs INTEGER, ann_return REAL, ann_vol REAL, max_drawdown REAL, sharpe REAL, trades_per_year REAL ) """ ) # Add param columns if missing (structured fields) for c in param_cols: try: con.execute(f"ALTER TABLE trials ADD COLUMN {c} REAL") except sqlite3.OperationalError: pass def insert_rows(db_path: Path, param_cols: list[str], rows: list[dict[str, Any]]) -> None: if not rows: return cols = [ "run_id", "ts_utc", "code_version", "config_path", "start", "end", "seed", "trial", "jobs", "ann_return", "ann_vol", "max_drawdown", "sharpe", "trades_per_year", *param_cols, ] q = ",".join(["?"] * len(cols)) join_cols = ",".join(cols) sql = f"INSERT INTO trials ({join_cols}) VALUES ({q})" vals = [] for r in rows: vals.append([r.get(c) for c in cols]) with sqlite3.connect(str(db_path)) as con: con.executemany(sql, vals) con.commit() def reservoir_sample_product(rng, iterables, k: int): """Sample up to k combos from cartesian product.""" import itertools sample = [] n = 0 for combo in itertools.product(*iterables): n += 1 if len(sample) < k: sample.append(combo) else: j = rng.randrange(n) if j < k: sample[j] = combo return sample def _init_globals(prices: dict[str, pd.DataFrame], universe: list[UniverseAsset], constraints: Constraints, risk_proxy: str, rates_fallback: str) -> None: global _G_PRICES, _G_UNIVERSE, _G_CONSTRAINTS, _G_RISK_PROXY, _G_RATES_FALLBACK _G_PRICES = prices _G_UNIVERSE = universe _G_CONSTRAINTS = constraints _G_RISK_PROXY = risk_proxy _G_RATES_FALLBACK = rates_fallback def _eval_one(task: dict[str, Any]) -> dict[str, Any] | None: assert _G_PRICES is not None assert _G_UNIVERSE is not None assert _G_CONSTRAINTS is not None assert _G_RISK_PROXY is not None assert _G_RATES_FALLBACK is not None params = TrendParams() params = replace(params, **task["params"]) try: equity, _w, tr = run_backtest( _G_PRICES, _G_UNIVERSE, _G_CONSTRAINTS, params, rates_fallback=_G_RATES_FALLBACK, risk_proxy=_G_RISK_PROXY, ) except Exception: return None st = perf_stats(equity["equity"]) if not st: return None tpy = trades_per_year(tr, task["start"], task["end"]) if tpy > float(task["max_trades_per_year"]): return None row = {**st, "trades_per_year": float(tpy), **asdict(params)} row["trial"] = int(task["trial"]) row["seed"] = int(task["seed"]) return row MAX_GRID_COMBOS = 128 def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--config", default="configs/etf_universe_industry_profiled.json") ap.add_argument("--rawdir", default="data/raw") ap.add_argument("--start", default="20200101") ap.add_argument("--end", default="20251231") ap.add_argument("--trials", type=int, default=240) ap.add_argument("--mode", choices=["random", "grid"], default="random") ap.add_argument("--max_grid", type=int, default=MAX_GRID_COMBOS) ap.add_argument("--seed", type=int, default=1) ap.add_argument("--jobs", type=int, default=1, help="Parallel workers (processes), up to 8") ap.add_argument("--state", default="data/opt_state.json") ap.add_argument("--db", default="data/experiments.sqlite") ap.add_argument("--baseline", type=float, default=None) ap.add_argument("--report_step", type=float, default=0.05) ap.add_argument("--max_trades_per_year", type=float, default=80.0) ap.add_argument("--progress_every", type=int, default=25) args = ap.parse_args() jobs = max(1, min(8, int(args.jobs))) random.seed(args.seed) np.random.seed(args.seed) config_path = Path(args.config) universe, constraints, risk_proxy, rates_fallback = load_universe(config_path) prices = load_prices(Path(args.rawdir), universe, args.start, args.end) _init_globals(prices, universe, constraints, risk_proxy, rates_fallback) state_path = Path(args.state) state = load_state(state_path) best = state.get("best") best_ann = float(best["ann_return"]) if best else float("-inf") baseline = args.baseline if baseline is None: baseline = best_ann if np.isfinite(best_ann) else 0.0 last_rep = state.get("last_reported_ann_return") if last_rep is None: last_rep = baseline params0 = TrendParams(max_positions=constraints.max_positions) params0_dict = asdict(params0) # Parameter columns to persist as structured fields in SQLite param_cols = sorted(params0_dict.keys()) db_path = Path(args.db) ensure_db(db_path, param_cols=param_cols) run_id = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + f"_seed{int(args.seed)}" code_version = infer_code_version(Path(".")) tasks: list[dict[str, Any]] = [] rng = random.Random(int(args.seed)) if str(args.mode) == "grid": grids = { "sma_fast": [3, 5], "sma_slow": [15, 20, 30], "lazy_days": [4, 5, 6, 8], "min_hold_days": [2, 3, 5], "replace_score_gap": [0.5, 0.8, 1.2, 1.6], "min_score": [0.0, 0.2, 0.4, 0.6], "desired_positions_min": [1, 2], "macro_min_breadth": [0.10, 0.15, 0.20, 0.30], "macro_down_frac": [0.75, 0.80, 0.85], "atr_mult": [2.5, 3.2, 4.0], "stop_loss_atr": [2.0, 2.5, 3.2], "profit_tighten_atr": [4.0, 6.0, 8.0], "atr_mult_profit": [1.5, 2.0, 2.5], "bias_exit": [0.12, 0.18, 0.25], "vol_ratio_exit": [3.0, 4.0], } keys = list(grids.keys()) iters = [list(grids[k]) for k in keys] total = 1 for xs in iters: total *= max(1, len(xs)) max_grid = max(1, int(args.max_grid)) if total > max_grid: print(f"grid combos {total} > {max_grid}; sampling combos", flush=True) combos = reservoir_sample_product(rng, iters, max_grid) else: import itertools combos = list(itertools.product(*iters)) for t, combo in enumerate(combos): vals = dict(zip(keys, combo)) sma_fast = int(vals["sma_fast"]) sma_slow = int(vals["sma_slow"]) if sma_fast >= sma_slow: continue p = replace( params0, sma_fast=sma_fast, sma_slow=sma_slow, lazy_days=int(vals["lazy_days"]), min_hold_days=int(vals["min_hold_days"]), replace_score_gap=float(vals["replace_score_gap"]), min_score=float(vals["min_score"]), desired_positions_min=int(vals["desired_positions_min"]), desired_positions_max=int(3), macro_min_breadth=float(vals["macro_min_breadth"]), macro_down_frac=float(vals["macro_down_frac"]), atr_mult=float(vals["atr_mult"]), stop_loss_atr=float(vals["stop_loss_atr"]), profit_tighten_atr=float(vals["profit_tighten_atr"]), atr_mult_profit=float(vals["atr_mult_profit"]), bias_exit=float(vals["bias_exit"]), vol_ratio_exit=float(vals["vol_ratio_exit"]), rebalance_every=1, ) tasks.append({ "trial": int(t), "seed": int(args.seed), "start": str(args.start), "end": str(args.end), "max_trades_per_year": float(args.max_trades_per_year), "params": {k: asdict(p)[k] for k in param_cols}, }) else: for t in range(int(args.trials)): sma_fast = rng.choice([3, 5]) sma_slow = rng.choice([15, 20, 30]) if sma_fast >= sma_slow: continue lazy_days = rng.choice([4, 5, 6, 8]) min_hold = rng.choice([2, 3, 5]) replace_gap = rng.choice([0.5, 0.8, 1.2, 1.6]) min_score = rng.choice([0.0, 0.2, 0.4, 0.6]) dmin = rng.choice([1, 2]) dmax = 3 macro_min_breadth = rng.choice([0.10, 0.15, 0.20, 0.30]) macro_down_frac = rng.choice([0.75, 0.80, 0.85]) atr_mult = rng.choice([2.5, 3.2, 4.0]) stop_loss_atr = rng.choice([2.0, 2.5, 3.2]) profit_tighten_atr = rng.choice([4.0, 6.0, 8.0]) atr_mult_profit = rng.choice([1.5, 2.0, 2.5]) bias_exit = rng.choice([0.12, 0.18, 0.25]) vol_ratio_exit = rng.choice([3.0, 4.0]) p = replace(params0, sma_fast=int(sma_fast), sma_slow=int(sma_slow), lazy_days=int(lazy_days), min_hold_days=int(min_hold), replace_score_gap=float(replace_gap), min_score=float(min_score), desired_positions_min=int(dmin), desired_positions_max=int(dmax), macro_min_breadth=float(macro_min_breadth), macro_down_frac=float(macro_down_frac), atr_mult=float(atr_mult), stop_loss_atr=float(stop_loss_atr), profit_tighten_atr=float(profit_tighten_atr), atr_mult_profit=float(atr_mult_profit), bias_exit=float(bias_exit), vol_ratio_exit=float(vol_ratio_exit), rebalance_every=1) tasks.append({"trial": int(t), "seed": int(args.seed), "start": str(args.start), "end": str(args.end), "max_trades_per_year": float(args.max_trades_per_year), "params": {k: asdict(p)[k] for k in param_cols}}) results: list[dict[str, Any]] = [] rows_for_db: list[dict[str, Any]] = [] def record_row(row: dict[str, Any]) -> None: nonlocal best_ann results.append(row) if float(row["ann_return"]) > best_ann: best_ann = float(row["ann_return"]) state["best"] = row save_state(state_path, state) db_row = { "run_id": run_id, "ts_utc": datetime.now(timezone.utc).isoformat(), "code_version": code_version, "config_path": str(config_path), "start": str(args.start), "end": str(args.end), "seed": int(args.seed), "trial": int(row.get("trial", -1)), "jobs": int(jobs), "ann_return": float(row["ann_return"]), "ann_vol": float(row["ann_vol"]), "max_drawdown": float(row["max_drawdown"]), "sharpe": float(row["sharpe"]), "trades_per_year": float(row["trades_per_year"]), } for c in param_cols: db_row[c] = row.get(c) rows_for_db.append(db_row) if len(rows_for_db) >= 200: insert_rows(db_path, param_cols=param_cols, rows=rows_for_db) rows_for_db.clear() if jobs == 1: for task in tasks: row = _eval_one(task) if row is None: continue record_row(row) if int(args.progress_every) > 0 and (len(results) % int(args.progress_every) == 0): print(f"progress valid={len(results)} best_ann={best_ann:.4f}", flush=True) else: import multiprocessing as mp from concurrent.futures import ProcessPoolExecutor, as_completed ctx = mp.get_context("fork") with ProcessPoolExecutor(max_workers=jobs, mp_context=ctx) as ex: futs = [ex.submit(_eval_one, task) for task in tasks] for fut in as_completed(futs): row = fut.result() if row is None: continue record_row(row) if int(args.progress_every) > 0 and (len(results) % int(args.progress_every) == 0): print(f"progress valid={len(results)} best_ann={best_ann:.4f}", flush=True) if rows_for_db: insert_rows(db_path, param_cols=param_cols, rows=rows_for_db) rows_for_db.clear() state["history"].append( { "timestamp": datetime.now(timezone.utc).isoformat(), "run_id": run_id, "code_version": code_version, "config": str(args.config), "start": str(args.start), "end": str(args.end), "trials": int(args.trials), "jobs": int(jobs), "best_ann_return": float(best_ann) if np.isfinite(best_ann) else None, "db": str(args.db), } ) save_state(state_path, state) if not results: print("no valid trials") return df = pd.DataFrame(results).sort_values(["ann_return"], ascending=False) cols = [ "ann_return", "ann_vol", "max_drawdown", "sharpe", "trades_per_year", "sma_fast", "sma_slow", "lazy_days", "min_hold_days", "replace_score_gap", "min_score", "macro_min_breadth", "macro_down_frac", "desired_positions_min", "atr_mult", "stop_loss_atr", "profit_tighten_atr", "atr_mult_profit", "bias_exit", "vol_ratio_exit", ] cols = [c for c in cols if c in df.columns] print(df[cols].head(12).to_string(index=False)) if best_ann >= float(last_rep) + float(args.report_step): state["last_reported_ann_return"] = float(best_ann) save_state(state_path, state) print("REPORT_TRIGGER", float(best_ann), "baseline", float(last_rep)) if __name__ == "__main__": main()